Go Socket 编程教程
📖 关于本教程本教程系统讲解 Go 网络编程的核心知识:TCP/UDP 协议、长连接、粘包处理、并发安全、以及一个完整的高并发 RPC Server 实战项目。
1. TCP 和 UDP 协议解读
1.1 OSI 模型与 TCP/IP 分层
text
OSI 七层模型 TCP/IP 四层模型 对应协议 / 技术
───────────────────────────────────────────────────────────
应用层 ─┐
表示层 ├───────── 应用层 HTTP, gRPC, DNS, SMTP, SSH
会话层 ─┘
传输层 ────────── 传输层 TCP, UDP
网络层 ────────── 网络层 IP, ICMP
数据链路层 ─┐
物理层 ─┴───── 网络接口层 Ethernet, Wi-Fi
数据在网络中的流动(以 HTTP 请求为例):
发送端:
应用数据 → [HTTP头 + 数据] → [TCP头 + ...] → [IP头 + ...] → [帧头 + ... + 帧尾]
封装 ↓ 封装 ↓ 封装 ↓ 封装 ↓
接收端:
[帧头 + ... + 帧尾] → [IP头 + ...] → [TCP头 + ...] → [HTTP头 + 数据] → 应用数据
解封 ↓ 解封 ↓ 解封 ↓ 解封 ↓1.2 TCP 协议
text
TCP(Transmission Control Protocol)传输控制协议
特点:
✅ 面向连接(三次握手建立,四次挥手断开)
✅ 可靠传输(确认应答、超时重传、校验和)
✅ 有序传输(序列号保证顺序)
✅ 流量控制(滑动窗口)
✅ 拥塞控制(慢启动、拥塞避免)
❌ 开销较大(头部 20+ 字节)
❌ 存在粘包问题(字节流,无消息边界)
三次握手(建立连接):
客户端 服务端
│ │
│──── SYN (seq=x) ─────→│ ① 客户端发起连接请求
│ │
│←── SYN+ACK ───────────│ ② 服务端确认并发起自己的连接请求
│ (seq=y, ack=x+1) │
│ │
│──── ACK (ack=y+1) ───→│ ③ 客户端确认
│ │
│ 连接已建立 │
四次挥手(断开连接):
客户端 服务端
│ │
│──── FIN ──────────────→│ ① 客户端:我要关闭了
│←── ACK ───────────────│ ② 服务端:收到,我还有数据要发
│ ... 服务端继续发数据 ... │
│←── FIN ───────────────│ ③ 服务端:我也关闭了
│──── ACK ──────────────→│ ④ 客户端:收到
│ │
│ 连接已断开 │
适用场景:HTTP、gRPC、数据库连接、文件传输1.3 UDP 协议
text
UDP(User Datagram Protocol)用户数据报协议
特点:
✅ 无连接(不需要握手,直接发送)
✅ 低延迟(无重传、无拥塞控制)
✅ 有消息边界(每个报文独立,无粘包)
✅ 支持广播和多播
✅ 头部开销小(仅 8 字节)
❌ 不可靠(可能丢包、乱序、重复)
❌ 无流量控制
UDP 报文结构(简单):
┌──────────┬──────────┬──────────┬──────────┐
│ 源端口 │ 目标端口 │ 长度 │ 校验和 │ ← 8 字节头部
├──────────┴──────────┴──────────┴──────────┤
│ 数据 │
└────────────────────────────────────────────┘
适用场景:DNS 查询、视频直播、游戏、VoIP、IoT 传感器数据1.4 TCP vs UDP 对比
| 特性 | TCP | UDP |
|---|---|---|
| 连接方式 | 面向连接 | 无连接 |
| 可靠性 | 可靠(确认 + 重传) | 不可靠 |
| 有序性 | 保证顺序 | 不保证 |
| 消息边界 | 无(字节流) | 有(数据报) |
| 速度 | 较慢 | 快 |
| 头部开销 | 20+ 字节 | 8 字节 |
| 粘包问题 | 有 | 无 |
| 广播/多播 | 不支持 | 支持 |
2. TCP 编程
2.1 TCP Server
go
package main
import (
"bufio"
"fmt"
"io"
"net"
"strings"
"time"
)
func main() {
// 监听端口
listener, err := net.Listen("tcp", ":8080")
if err != nil {
fmt.Println("监听失败:", err)
return
}
defer listener.Close()
fmt.Println("TCP Server 启动,监听 :8080")
for {
// Accept 阻塞等待客户端连接
conn, err := listener.Accept()
if err != nil {
fmt.Println("接受连接失败:", err)
continue
}
// 每个连接启动一个协程处理(并发)
go handleConnection(conn)
}
}
func handleConnection(conn net.Conn) {
defer conn.Close()
remoteAddr := conn.RemoteAddr().String()
fmt.Printf("[%s] 客户端已连接\n", remoteAddr)
// 设置读取超时
conn.SetReadDeadline(time.Now().Add(5 * time.Minute))
reader := bufio.NewReader(conn)
for {
// 读取一行数据(以 \n 结尾)
message, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
fmt.Printf("[%s] 客户端断开连接\n", remoteAddr)
} else {
fmt.Printf("[%s] 读取错误: %v\n", remoteAddr, err)
}
return
}
message = strings.TrimSpace(message)
fmt.Printf("[%s] 收到: %s\n", remoteAddr, message)
// 处理命令
var response string
switch {
case message == "ping":
response = "pong"
case message == "time":
response = time.Now().Format("2006-01-02 15:04:05")
case message == "quit":
conn.Write([]byte("bye\n"))
return
default:
response = "echo: " + message
}
// 发送响应
_, err = conn.Write([]byte(response + "\n"))
if err != nil {
fmt.Printf("[%s] 写入错误: %v\n", remoteAddr, err)
return
}
}
}2.2 TCP Client
go
package main
import (
"bufio"
"fmt"
"net"
"os"
"strings"
)
func main() {
// 连接服务器
conn, err := net.Dial("tcp", "localhost:8080")
if err != nil {
fmt.Println("连接失败:", err)
return
}
defer conn.Close()
fmt.Println("已连接到服务器")
// 启动一个协程读取服务器响应
go func() {
reader := bufio.NewReader(conn)
for {
response, err := reader.ReadString('\n')
if err != nil {
fmt.Println("连接断开")
os.Exit(0)
}
fmt.Print("服务器: " + response)
}
}()
// 主协程读取用户输入
scanner := bufio.NewScanner(os.Stdin)
for {
fmt.Print("> ")
if !scanner.Scan() {
break
}
input := strings.TrimSpace(scanner.Text())
if input == "" {
continue
}
// 发送到服务器
_, err := conn.Write([]byte(input + "\n"))
if err != nil {
fmt.Println("发送失败:", err)
break
}
if input == "quit" {
break
}
}
}2.3 net.Conn 常用方法
go
package main
import (
"net"
"time"
)
func connMethods(conn net.Conn) {
// ==================== 读写 ====================
buf := make([]byte, 1024)
n, _ := conn.Read(buf) // 读取数据到 buf,返回读到的字节数
_ = buf[:n]
conn.Write([]byte("hello")) // 写入数据
// ==================== 地址信息 ====================
_ = conn.LocalAddr() // 本端地址 (ip:port)
_ = conn.RemoteAddr() // 对端地址 (ip:port)
// ==================== 超时控制 ====================
conn.SetDeadline(time.Now().Add(30 * time.Second)) // 读写超时
conn.SetReadDeadline(time.Now().Add(30 * time.Second)) // 仅读超时
conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) // 仅写超时
// 超时后 Read/Write 返回 net.Error(Timeout() == true)
// ==================== 关闭 ====================
conn.Close() // 关闭连接(双向)
// TCP 连接可以半关闭
if tcpConn, ok := conn.(*net.TCPConn); ok {
tcpConn.CloseRead() // 关闭读端(不再接收数据)
tcpConn.CloseWrite() // 关闭写端(发送 FIN,不再发送数据)
}
}3. UDP 编程
3.1 UDP Server
go
package main
import (
"fmt"
"net"
"strings"
"time"
)
func main() {
// 监听 UDP 端口
addr, err := net.ResolveUDPAddr("udp", ":9090")
if err != nil {
fmt.Println("解析地址失败:", err)
return
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
fmt.Println("监听失败:", err)
return
}
defer conn.Close()
fmt.Println("UDP Server 启动,监听 :9090")
buf := make([]byte, 1024)
for {
// ReadFromUDP 读取数据和来源地址
// UDP 无连接,每次读取都能知道是谁发来的
n, clientAddr, err := conn.ReadFromUDP(buf)
if err != nil {
fmt.Println("读取失败:", err)
continue
}
message := strings.TrimSpace(string(buf[:n]))
fmt.Printf("[%s] 收到: %s\n", clientAddr, message)
// 构造响应
response := fmt.Sprintf("[%s] echo: %s",
time.Now().Format("15:04:05"), message)
// WriteToUDP 发送响应到指定地址
_, err = conn.WriteToUDP([]byte(response), clientAddr)
if err != nil {
fmt.Println("发送失败:", err)
}
}
}3.2 UDP Client
go
package main
import (
"fmt"
"net"
"time"
)
func main() {
// 解析服务器地址
serverAddr, err := net.ResolveUDPAddr("udp", "localhost:9090")
if err != nil {
fmt.Println("解析地址失败:", err)
return
}
// 创建 UDP 连接(实际上不会真的建立连接,只是绑定地址)
conn, err := net.DialUDP("udp", nil, serverAddr)
if err != nil {
fmt.Println("连接失败:", err)
return
}
defer conn.Close()
// 发送数据
messages := []string{"hello", "world", "ping", "Go UDP"}
buf := make([]byte, 1024)
for _, msg := range messages {
// 发送
_, err := conn.Write([]byte(msg))
if err != nil {
fmt.Println("发送失败:", err)
continue
}
// 设置读超时
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
// 接收响应
n, err := conn.Read(buf)
if err != nil {
fmt.Println("接收超时或失败:", err)
continue
}
fmt.Printf("服务器响应: %s\n", string(buf[:n]))
}
}3.3 UDP 广播
go
package main
import (
"fmt"
"net"
"time"
)
// UDP 广播发送端
func broadcastSender() {
// 广播地址:255.255.255.255 或子网广播地址
addr, _ := net.ResolveUDPAddr("udp", "255.255.255.255:9999")
conn, _ := net.DialUDP("udp", nil, addr)
defer conn.Close()
for i := 0; i < 5; i++ {
msg := fmt.Sprintf("广播消息 #%d", i)
conn.Write([]byte(msg))
fmt.Println("发送:", msg)
time.Sleep(time.Second)
}
}
// UDP 广播接收端
func broadcastReceiver() {
addr, _ := net.ResolveUDPAddr("udp", ":9999")
conn, _ := net.ListenUDP("udp", addr)
defer conn.Close()
buf := make([]byte, 1024)
for {
n, src, _ := conn.ReadFromUDP(buf)
fmt.Printf("收到广播 [%s]: %s\n", src, string(buf[:n]))
}
}4. 长连接
4.1 长连接 vs 短连接
text
短连接:
客户端 ──连接──→ 发送请求 ──→ 接收响应 ──→ 断开
每次通信都要重新建立连接,开销大
长连接:
客户端 ──连接──→ 发送请求1 → 接收响应1
发送请求2 → 接收响应2
...(保持连接不断开)
发送请求N → 接收响应N ──→ 断开
一次连接,多次通信
长连接优势:
✅ 减少 TCP 握手/挥手开销
✅ 减少端口占用(TIME_WAIT 问题)
✅ 服务端可以主动推送
✅ 适合频繁通信的场景
长连接挑战:
⚠️ 需要心跳机制检测连接是否存活
⚠️ 需要处理连接的异常断开和重连
⚠️ 服务端需要管理大量连接4.2 带心跳的长连接 Server
go
package main
import (
"bufio"
"fmt"
"io"
"net"
"strings"
"sync"
"time"
)
// ConnManager 连接管理器
type ConnManager struct {
mu sync.RWMutex
conns map[string]net.Conn
}
func NewConnManager() *ConnManager {
return &ConnManager{conns: make(map[string]net.Conn)}
}
func (cm *ConnManager) Add(addr string, conn net.Conn) {
cm.mu.Lock()
defer cm.mu.Unlock()
cm.conns[addr] = conn
}
func (cm *ConnManager) Remove(addr string) {
cm.mu.Lock()
defer cm.mu.Unlock()
delete(cm.conns, addr)
}
func (cm *ConnManager) Count() int {
cm.mu.RLock()
defer cm.mu.RUnlock()
return len(cm.conns)
}
// Broadcast 向所有连接广播消息
func (cm *ConnManager) Broadcast(msg string) {
cm.mu.RLock()
defer cm.mu.RUnlock()
for addr, conn := range cm.conns {
_, err := conn.Write([]byte(msg + "\n"))
if err != nil {
fmt.Printf("[%s] 广播失败: %v\n", addr, err)
}
}
}
var manager = NewConnManager()
func main() {
listener, err := net.Listen("tcp", ":8080")
if err != nil {
fmt.Println("监听失败:", err)
return
}
defer listener.Close()
fmt.Println("长连接 Server 启动,监听 :8080")
// 定期打印连接数
go func() {
for range time.Tick(10 * time.Second) {
fmt.Printf("当前连接数: %d\n", manager.Count())
}
}()
for {
conn, err := listener.Accept()
if err != nil {
continue
}
go handleLongConnection(conn)
}
}
func handleLongConnection(conn net.Conn) {
addr := conn.RemoteAddr().String()
manager.Add(addr, conn)
fmt.Printf("[%s] 新连接\n", addr)
defer func() {
conn.Close()
manager.Remove(addr)
fmt.Printf("[%s] 连接关闭,剩余连接: %d\n", addr, manager.Count())
}()
// 心跳超时时间
heartbeatTimeout := 30 * time.Second
reader := bufio.NewReader(conn)
for {
// 每次读取前重置超时
conn.SetReadDeadline(time.Now().Add(heartbeatTimeout))
message, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
fmt.Printf("[%s] 客户端主动断开\n", addr)
} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
fmt.Printf("[%s] 心跳超时,断开连接\n", addr)
} else {
fmt.Printf("[%s] 读取错误: %v\n", addr, err)
}
return
}
message = strings.TrimSpace(message)
if message == "HEARTBEAT" {
conn.Write([]byte("HEARTBEAT_ACK\n"))
continue
}
fmt.Printf("[%s] 收到: %s\n", addr, message)
conn.Write([]byte("echo: " + message + "\n"))
}
}4.3 带心跳和自动重连的 Client
go
package main
import (
"bufio"
"fmt"
"net"
"sync"
"time"
)
// LongConnClient 长连接客户端
type LongConnClient struct {
addr string
conn net.Conn
mu sync.Mutex
connected bool
stopCh chan struct{}
}
func NewLongConnClient(addr string) *LongConnClient {
return &LongConnClient{
addr: addr,
stopCh: make(chan struct{}),
}
}
// Connect 建立连接
func (c *LongConnClient) Connect() error {
c.mu.Lock()
defer c.mu.Unlock()
conn, err := net.DialTimeout("tcp", c.addr, 5*time.Second)
if err != nil {
return err
}
c.conn = conn
c.connected = true
fmt.Println("已连接到", c.addr)
// 启动心跳
go c.heartbeat()
// 启动接收
go c.receive()
return nil
}
// heartbeat 定时发送心跳
func (c *LongConnClient) heartbeat() {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := c.Send("HEARTBEAT"); err != nil {
fmt.Println("心跳发送失败,尝试重连...")
c.reconnect()
return
}
case <-c.stopCh:
return
}
}
}
// receive 接收服务器数据
func (c *LongConnClient) receive() {
reader := bufio.NewReader(c.conn)
for {
response, err := reader.ReadString('\n')
if err != nil {
fmt.Println("接收断开:", err)
c.connected = false
return
}
fmt.Print("← " + response)
}
}
// reconnect 自动重连
func (c *LongConnClient) reconnect() {
c.mu.Lock()
c.connected = false
if c.conn != nil {
c.conn.Close()
}
c.mu.Unlock()
for i := 0; i < 10; i++ {
delay := time.Duration(1<<uint(i)) * time.Second // 指数退避
if delay > 30*time.Second {
delay = 30 * time.Second
}
fmt.Printf("第 %d 次重连,等待 %s...\n", i+1, delay)
time.Sleep(delay)
if err := c.Connect(); err == nil {
fmt.Println("重连成功")
return
}
}
fmt.Println("重连失败,已达最大重试次数")
}
// Send 发送数据
func (c *LongConnClient) Send(msg string) error {
c.mu.Lock()
defer c.mu.Unlock()
if !c.connected || c.conn == nil {
return fmt.Errorf("not connected")
}
_, err := c.conn.Write([]byte(msg + "\n"))
return err
}
// Close 关闭连接
func (c *LongConnClient) Close() {
close(c.stopCh)
c.mu.Lock()
defer c.mu.Unlock()
if c.conn != nil {
c.conn.Close()
}
c.connected = false
}
func main() {
client := NewLongConnClient("localhost:8080")
if err := client.Connect(); err != nil {
fmt.Println("连接失败:", err)
return
}
defer client.Close()
// 模拟业务发送
for i := 0; i < 5; i++ {
client.Send(fmt.Sprintf("消息 #%d", i))
time.Sleep(3 * time.Second)
}
}5. TCP 粘包问题解决
5.1 粘包是怎么发生的
❌ TCP 是字节流协议,没有消息边界!发送端发了 3 条消息:"hello" "world" "go",接收端可能收到的是 "helloworldgo" 或 "hellow" + "orldgo" 等任意组合。
text
发送端连续发送三条消息:
send("hello")
send("world")
send("go!!!")
接收端可能收到的情况:
情况1(正常):收到 "hello"、"world"、"go!!!"(三次 Read)
情况2(粘包):收到 "helloworld"、"go!!!"(两条粘在一起)
情况3(拆包):收到 "hel"、"loworld"、"go!!!"(一条被拆开)
情况4(混合):收到 "helloworldgo!!!"(全部粘在一起)
原因:
1. Nagle 算法:TCP 会合并小数据包一起发送
2. 接收缓冲区:Read 读取的是缓冲区中所有可用数据
3. 分片重组:大数据包在网络层可能被分片5.2 解决方案1:固定长度
go
package main
import (
"fmt"
"io"
"net"
)
const MessageSize = 32 // 每条消息固定 32 字节
// 发送固定长度消息
func sendFixed(conn net.Conn, msg string) error {
buf := make([]byte, MessageSize)
copy(buf, msg) // 不足 32 字节自动补零
_, err := conn.Write(buf)
return err
}
// 接收固定长度消息
func recvFixed(conn net.Conn) (string, error) {
buf := make([]byte, MessageSize)
_, err := io.ReadFull(conn, buf) // 严格读满 32 字节
if err != nil {
return "", err
}
// 去掉末尾的零字节
end := len(buf)
for end > 0 && buf[end-1] == 0 {
end--
}
return string(buf[:end]), nil
}
// 优点:实现简单
// 缺点:浪费带宽(短消息补零),无法处理超长消息5.3 解决方案2:分隔符
go
package main
import (
"bufio"
"fmt"
"net"
"strings"
)
const Delimiter = '\n' // 用换行符作为分隔符
// 发送带分隔符的消息
func sendDelimited(conn net.Conn, msg string) error {
// 确保消息本身不含分隔符
msg = strings.ReplaceAll(msg, "\n", "\\n")
_, err := conn.Write([]byte(msg + "\n"))
return err
}
// 接收带分隔符的消息
func recvDelimited(reader *bufio.Reader) (string, error) {
msg, err := reader.ReadString(Delimiter)
if err != nil {
return "", err
}
msg = strings.TrimRight(msg, "\n")
msg = strings.ReplaceAll(msg, "\\n", "\n")
return msg, nil
}
// 优点:实现简单,支持变长消息
// 缺点:需要转义消息中的分隔符,不适合二进制数据5.4 解决方案3:长度前缀(最通用,推荐)
go
package main
import (
"encoding/binary"
"fmt"
"io"
"net"
)
// ==================== 协议格式 ====================
// ┌────────────┬────────────────────┐
// │ 4 字节长度 │ 消息体 │
// │ (大端序) │ (变长,最大 10MB) │
// └────────────┴────────────────────┘
const MaxMessageSize = 10 * 1024 * 1024 // 10MB
// Encode 编码消息:长度前缀 + 消息体
func Encode(data []byte) []byte {
length := uint32(len(data))
buf := make([]byte, 4+len(data))
binary.BigEndian.PutUint32(buf[:4], length) // 前 4 字节写长度
copy(buf[4:], data) // 后面写数据
return buf
}
// Decode 从连接中解码一条完整消息
func Decode(conn net.Conn) ([]byte, error) {
// 第 1 步:读取 4 字节长度头
header := make([]byte, 4)
_, err := io.ReadFull(conn, header)
if err != nil {
return nil, err
}
// 解析长度
length := binary.BigEndian.Uint32(header)
if length > MaxMessageSize {
return nil, fmt.Errorf("message too large: %d", length)
}
if length == 0 {
return []byte{}, nil
}
// 第 2 步:根据长度读取消息体
body := make([]byte, length)
_, err = io.ReadFull(conn, body)
if err != nil {
return nil, err
}
return body, nil
}
// ==================== 使用示例 ====================
func serverExample() {
listener, _ := net.Listen("tcp", ":8080")
defer listener.Close()
for {
conn, _ := listener.Accept()
go func(c net.Conn) {
defer c.Close()
for {
// 解码一条完整消息(自动处理粘包)
data, err := Decode(c)
if err != nil {
return
}
fmt.Printf("收到完整消息: %s (长度 %d)\n", string(data), len(data))
// 回复
response := Encode([]byte("OK: " + string(data)))
c.Write(response)
}
}(conn)
}
}
func clientExample() {
conn, _ := net.Dial("tcp", "localhost:8080")
defer conn.Close()
// 连续快速发送多条消息(即使粘包也能正确解析)
messages := []string{
"hello",
"这是一条中文消息",
"short",
"这是一条比较长的消息,用来测试变长消息的处理能力",
}
for _, msg := range messages {
data := Encode([]byte(msg))
conn.Write(data)
}
// 读取响应
for range messages {
response, _ := Decode(conn)
fmt.Println("响应:", string(response))
}
}💡 长度前缀是工业级方案 gRPC(HTTP/2 帧)、Kafka、Redis 协议、MySQL 协议等几乎所有二进制协议都使用长度前缀来解决粘包问题。这是最推荐的方案。
6. socket 连接的并发安全性
go
package main
import (
"encoding/binary"
"fmt"
"io"
"net"
"sync"
)
// ==================== 问题:net.Conn 的并发安全性 ====================
//
// Go 官方文档:
// "Multiple goroutines may invoke methods on a Conn simultaneously."
//
// 但这里有微妙之处:
// ✅ 并发调用 Read 和 Write 是安全的(一个读,一个写)
// ❌ 多个协程并发 Write 不安全(数据会交错)
// ❌ 多个协程并发 Read 不安全(数据被分散到不同协程)
//
// 单次 Write 调用是原子的,但连续调用不是:
// 协程A: conn.Write(header) → conn.Write(body)
// 协程B: conn.Write(header) → conn.Write(body)
// 实际可能:headerA → headerB → bodyA → bodyB(数据错乱)
// SafeConn 并发安全的连接包装
type SafeConn struct {
conn net.Conn
writeMu sync.Mutex // 写锁
readMu sync.Mutex // 读锁
}
func NewSafeConn(conn net.Conn) *SafeConn {
return &SafeConn{conn: conn}
}
// SendMessage 并发安全的消息发送(长度前缀协议)
func (sc *SafeConn) SendMessage(data []byte) error {
sc.writeMu.Lock()
defer sc.writeMu.Unlock()
// header + body 在同一把锁内写入,保证原子性
header := make([]byte, 4)
binary.BigEndian.PutUint32(header, uint32(len(data)))
if _, err := sc.conn.Write(header); err != nil {
return err
}
if _, err := sc.conn.Write(data); err != nil {
return err
}
return nil
}
// RecvMessage 并发安全的消息接收
func (sc *SafeConn) RecvMessage() ([]byte, error) {
sc.readMu.Lock()
defer sc.readMu.Unlock()
// 读取长度头
header := make([]byte, 4)
if _, err := io.ReadFull(sc.conn, header); err != nil {
return nil, err
}
length := binary.BigEndian.Uint32(header)
if length == 0 {
return []byte{}, nil
}
// 读取消息体
body := make([]byte, length)
if _, err := io.ReadFull(sc.conn, body); err != nil {
return nil, err
}
return body, nil
}
func (sc *SafeConn) Close() error {
return sc.conn.Close()
}
// ==================== 使用示例 ====================
func concurrentSendExample() {
conn, _ := net.Dial("tcp", "localhost:8080")
safeConn := NewSafeConn(conn)
defer safeConn.Close()
var wg sync.WaitGroup
// 10 个协程并发发送,不会数据交错
for i := 0; i < 10; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
msg := fmt.Sprintf("来自协程 %d 的消息", id)
if err := safeConn.SendMessage([]byte(msg)); err != nil {
fmt.Printf("协程 %d 发送失败: %v\n", id, err)
}
}(i)
}
wg.Wait()
}⚠️ 最佳实践实际项目中通常只有一个协程负责 Write(通过 channel 接收要发送的数据),一个协程负责 Read。这样完全不需要加锁,也更容易管理。
7. 练习:基于 socket 编程实现支持高并发的 RPC Server
💡 目标实现一个完整的 RPC 框架,包含:自定义二进制协议、服务注册、方法调用、并发处理、超时控制、连接池。不依赖任何第三方 RPC 框架。
7.1 协议设计
text
自定义 RPC 协议格式:
请求:
┌──────────┬──────────┬───────────────┬──────────┬────────────┬──────────┬───────────┐
│ MagicNum │ Version │ RequestID │ MethodLen│ Method │ BodyLen │ Body │
│ 4 bytes │ 1 byte │ 8 bytes │ 2 bytes │ var bytes │ 4 bytes │ var bytes │
│ 0x47525043│ 0x01 │ uint64 │ uint16 │ string │ uint32 │ JSON │
└──────────┴──────────┴───────────────┴──────────┴────────────┴──────────┴───────────┘
响应:
┌──────────┬──────────┬───────────────┬──────────┬──────────┬───────────┐
│ MagicNum │ Version │ RequestID │ Status │ BodyLen │ Body │
│ 4 bytes │ 1 byte │ 8 bytes │ 1 byte │ 4 bytes │ var bytes │
│ 0x47525043│ 0x01 │ uint64 │ 0=OK │ uint32 │ JSON │
└──────────┴──────────┴───────────────┴──────────┴──────────┴───────────┘
MagicNum = "GRPC" 的 ASCII 值(0x47525043)用于校验协议7.2 协议编解码
go
package rpc
import (
"encoding/binary"
"encoding/json"
"fmt"
"io"
"net"
)
const (
MagicNumber uint32 = 0x47525043 // "GRPC"
Version byte = 0x01
)
// 响应状态码
const (
StatusOK byte = 0
StatusError byte = 1
StatusNotFound byte = 2
StatusTimeout byte = 3
)
// Request RPC 请求
type Request struct {
RequestID uint64
Method string
Body []byte // JSON 编码的参数
}
// Response RPC 响应
type Response struct {
RequestID uint64
Status byte
Body []byte // JSON 编码的结果
}
// WriteRequest 写入请求到连接
func WriteRequest(conn net.Conn, req *Request) error {
methodBytes := []byte(req.Method)
// 计算总长度并一次性写入(避免并发问题)
totalLen := 4 + 1 + 8 + 2 + len(methodBytes) + 4 + len(req.Body)
buf := make([]byte, totalLen)
offset := 0
// MagicNumber (4 bytes)
binary.BigEndian.PutUint32(buf[offset:], MagicNumber)
offset += 4
// Version (1 byte)
buf[offset] = Version
offset++
// RequestID (8 bytes)
binary.BigEndian.PutUint64(buf[offset:], req.RequestID)
offset += 8
// MethodLen + Method
binary.BigEndian.PutUint16(buf[offset:], uint16(len(methodBytes)))
offset += 2
copy(buf[offset:], methodBytes)
offset += len(methodBytes)
// BodyLen + Body
binary.BigEndian.PutUint32(buf[offset:], uint32(len(req.Body)))
offset += 4
copy(buf[offset:], req.Body)
_, err := conn.Write(buf)
return err
}
// ReadRequest 从连接读取请求
func ReadRequest(conn net.Conn) (*Request, error) {
// MagicNumber
header := make([]byte, 4+1+8+2)
if _, err := io.ReadFull(conn, header); err != nil {
return nil, err
}
magic := binary.BigEndian.Uint32(header[0:4])
if magic != MagicNumber {
return nil, fmt.Errorf("invalid magic number: 0x%x", magic)
}
// version := header[4]
requestID := binary.BigEndian.Uint64(header[5:13])
methodLen := binary.BigEndian.Uint16(header[13:15])
// Method
methodBuf := make([]byte, methodLen)
if _, err := io.ReadFull(conn, methodBuf); err != nil {
return nil, err
}
// BodyLen + Body
bodyLenBuf := make([]byte, 4)
if _, err := io.ReadFull(conn, bodyLenBuf); err != nil {
return nil, err
}
bodyLen := binary.BigEndian.Uint32(bodyLenBuf)
body := make([]byte, bodyLen)
if bodyLen > 0 {
if _, err := io.ReadFull(conn, body); err != nil {
return nil, err
}
}
return &Request{
RequestID: requestID,
Method: string(methodBuf),
Body: body,
}, nil
}
// WriteResponse 写入响应
func WriteResponse(conn net.Conn, resp *Response) error {
totalLen := 4 + 1 + 8 + 1 + 4 + len(resp.Body)
buf := make([]byte, totalLen)
offset := 0
binary.BigEndian.PutUint32(buf[offset:], MagicNumber)
offset += 4
buf[offset] = Version
offset++
binary.BigEndian.PutUint64(buf[offset:], resp.RequestID)
offset += 8
buf[offset] = resp.Status
offset++
binary.BigEndian.PutUint32(buf[offset:], uint32(len(resp.Body)))
offset += 4
copy(buf[offset:], resp.Body)
_, err := conn.Write(buf)
return err
}
// ReadResponse 读取响应
func ReadResponse(conn net.Conn) (*Response, error) {
header := make([]byte, 4+1+8+1+4)
if _, err := io.ReadFull(conn, header); err != nil {
return nil, err
}
magic := binary.BigEndian.Uint32(header[0:4])
if magic != MagicNumber {
return nil, fmt.Errorf("invalid magic number")
}
requestID := binary.BigEndian.Uint64(header[5:13])
status := header[13]
bodyLen := binary.BigEndian.Uint32(header[14:18])
body := make([]byte, bodyLen)
if bodyLen > 0 {
if _, err := io.ReadFull(conn, body); err != nil {
return nil, err
}
}
return &Response{
RequestID: requestID,
Status: status,
Body: body,
}, nil
}7.3 RPC Server
go
package main
import (
"context"
"encoding/json"
"fmt"
"log"
"net"
"reflect"
"runtime/debug"
"strings"
"sync"
"sync/atomic"
"time"
)
// HandlerFunc RPC 处理函数签名
type HandlerFunc func(ctx context.Context, params json.RawMessage) (interface{}, error)
// RPCServer 高并发 RPC 服务端
type RPCServer struct {
listener net.Listener
handlers map[string]HandlerFunc // 服务名.方法名 → 处理函数
mu sync.RWMutex
activeConn sync.WaitGroup
shutdown chan struct{}
// 统计
totalRequests atomic.Int64
activeRequests atomic.Int64
}
func NewRPCServer() *RPCServer {
return &RPCServer{
handlers: make(map[string]HandlerFunc),
shutdown: make(chan struct{}),
}
}
// Register 注册 RPC 方法
func (s *RPCServer) Register(method string, handler HandlerFunc) {
s.mu.Lock()
defer s.mu.Unlock()
s.handlers[method] = handler
log.Printf("注册方法: %s", method)
}
// RegisterService 注册结构体的所有导出方法
func (s *RPCServer) RegisterService(name string, service interface{}) {
val := reflect.ValueOf(service)
typ := val.Type()
for i := 0; i < typ.NumMethod(); i++ {
method := typ.Method(i)
methodName := name + "." + method.Name
// 捕获 method 变量
m := method
s.Register(methodName, func(ctx context.Context, params json.RawMessage) (interface{}, error) {
// 简化处理:直接调用方法,传入 JSON 字符串参数
args := []reflect.Value{val, reflect.ValueOf(ctx), reflect.ValueOf(params)}
results := m.Func.Call(args)
var result interface{}
var err error
if len(results) > 0 {
result = results[0].Interface()
}
if len(results) > 1 && !results[1].IsNil() {
err = results[1].Interface().(error)
}
return result, err
})
}
}
// Start 启动服务
func (s *RPCServer) Start(addr string) error {
listener, err := net.Listen("tcp", addr)
if err != nil {
return err
}
s.listener = listener
log.Printf("RPC Server 启动,监听 %s", addr)
// 统计打印
go s.statsLoop()
for {
conn, err := listener.Accept()
if err != nil {
select {
case <-s.shutdown:
return nil // 正常关闭
default:
log.Printf("Accept 错误: %v", err)
continue
}
}
s.activeConn.Add(1)
go s.handleConn(conn)
}
}
// Shutdown 优雅关闭
func (s *RPCServer) Shutdown(timeout time.Duration) {
close(s.shutdown)
s.listener.Close()
// 等待所有连接处理完
done := make(chan struct{})
go func() {
s.activeConn.Wait()
close(done)
}()
select {
case <-done:
log.Println("所有连接已处理完毕")
case <-time.After(timeout):
log.Println("关闭超时,强制退出")
}
}
func (s *RPCServer) handleConn(conn net.Conn) {
defer s.activeConn.Done()
defer conn.Close()
addr := conn.RemoteAddr().String()
log.Printf("[%s] 新连接", addr)
for {
// 读超时
conn.SetReadDeadline(time.Now().Add(60 * time.Second))
req, err := ReadRequest(conn)
if err != nil {
if !strings.Contains(err.Error(), "EOF") {
log.Printf("[%s] 读取请求失败: %v", addr, err)
}
return
}
s.totalRequests.Add(1)
s.activeRequests.Add(1)
// 每个请求在独立协程中处理(同一连接上的请求可以并发)
go s.handleRequest(conn, req, addr)
}
}
func (s *RPCServer) handleRequest(conn net.Conn, req *Request, addr string) {
defer s.activeRequests.Add(-1)
defer func() {
if r := recover(); r != nil {
log.Printf("[%s] panic: %v\n%s", addr, r, debug.Stack())
resp := &Response{
RequestID: req.RequestID,
Status: StatusError,
Body: []byte(`"internal server error"`),
}
WriteResponse(conn, resp)
}
}()
// 请求超时控制
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// 查找处理函数
s.mu.RLock()
handler, ok := s.handlers[req.Method]
s.mu.RUnlock()
if !ok {
resp := &Response{
RequestID: req.RequestID,
Status: StatusNotFound,
}
errMsg, _ := json.Marshal(fmt.Sprintf("method not found: %s", req.Method))
resp.Body = errMsg
WriteResponse(conn, resp)
return
}
// 执行 RPC 调用
resultCh := make(chan struct {
result interface{}
err error
}, 1)
go func() {
result, err := handler(ctx, req.Body)
resultCh <- struct {
result interface{}
err error
}{result, err}
}()
// 等待结果或超时
select {
case res := <-resultCh:
resp := &Response{RequestID: req.RequestID}
if res.err != nil {
resp.Status = StatusError
resp.Body, _ = json.Marshal(res.err.Error())
} else {
resp.Status = StatusOK
resp.Body, _ = json.Marshal(res.result)
}
WriteResponse(conn, resp)
case <-ctx.Done():
resp := &Response{
RequestID: req.RequestID,
Status: StatusTimeout,
}
resp.Body, _ = json.Marshal("request timeout")
WriteResponse(conn, resp)
}
}
func (s *RPCServer) statsLoop() {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
log.Printf("统计 | 总请求: %d, 活跃请求: %d",
s.totalRequests.Load(), s.activeRequests.Load())
case <-s.shutdown:
return
}
}
}7.4 RPC Client(带连接池)
go
package main
import (
"encoding/json"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
)
// RPCClient 带连接池的 RPC 客户端
type RPCClient struct {
addr string
pool chan net.Conn // 连接池
poolSize int
requestID atomic.Uint64
mu sync.Mutex
}
func NewRPCClient(addr string, poolSize int) (*RPCClient, error) {
client := &RPCClient{
addr: addr,
pool: make(chan net.Conn, poolSize),
poolSize: poolSize,
}
// 预创建连接
for i := 0; i < poolSize; i++ {
conn, err := net.DialTimeout("tcp", addr, 5*time.Second)
if err != nil {
return nil, fmt.Errorf("create connection %d: %w", i, err)
}
client.pool <- conn
}
return client, nil
}
// getConn 从连接池获取连接
func (c *RPCClient) getConn() (net.Conn, error) {
select {
case conn := <-c.pool:
return conn, nil
case <-time.After(5 * time.Second):
// 池为空,创建新连接
return net.DialTimeout("tcp", c.addr, 5*time.Second)
}
}
// putConn 归还连接到池
func (c *RPCClient) putConn(conn net.Conn) {
select {
case c.pool <- conn:
// 归还成功
default:
// 池已满,关闭连接
conn.Close()
}
}
// Call 发起 RPC 调用
func (c *RPCClient) Call(method string, params interface{}) (json.RawMessage, error) {
// 序列化参数
body, err := json.Marshal(params)
if err != nil {
return nil, fmt.Errorf("marshal params: %w", err)
}
// 构建请求
req := &Request{
RequestID: c.requestID.Add(1),
Method: method,
Body: body,
}
// 获取连接
conn, err := c.getConn()
if err != nil {
return nil, fmt.Errorf("get connection: %w", err)
}
// 设置超时
conn.SetDeadline(time.Now().Add(10 * time.Second))
// 发送请求
if err := WriteRequest(conn, req); err != nil {
conn.Close() // 连接可能已损坏,不要归还
return nil, fmt.Errorf("send request: %w", err)
}
// 读取响应
resp, err := ReadResponse(conn)
if err != nil {
conn.Close()
return nil, fmt.Errorf("read response: %w", err)
}
// 归还连接
c.putConn(conn)
// 检查响应状态
switch resp.Status {
case StatusOK:
return resp.Body, nil
case StatusNotFound:
return nil, fmt.Errorf("method not found: %s", method)
case StatusTimeout:
return nil, fmt.Errorf("request timeout")
default:
var errMsg string
json.Unmarshal(resp.Body, &errMsg)
return nil, fmt.Errorf("rpc error: %s", errMsg)
}
}
// Close 关闭客户端和所有连接
func (c *RPCClient) Close() {
close(c.pool)
for conn := range c.pool {
conn.Close()
}
}7.5 业务服务定义和使用
go
package main
import (
"context"
"encoding/json"
"fmt"
"log"
"math"
"sync"
"time"
)
// ==================== 定义业务服务 ====================
// MathService 数学计算服务
type MathService struct{}
type AddParams struct {
A float64 `json:"a"`
B float64 `json:"b"`
}
type AddResult struct {
Result float64 `json:"result"`
}
func (s *MathService) Add(ctx context.Context, data json.RawMessage) (interface{}, error) {
var params AddParams
if err := json.Unmarshal(data, ¶ms); err != nil {
return nil, fmt.Errorf("invalid params: %w", err)
}
return AddResult{Result: params.A + params.B}, nil
}
func (s *MathService) Sqrt(ctx context.Context, data json.RawMessage) (interface{}, error) {
var params struct {
N float64 `json:"n"`
}
if err := json.Unmarshal(data, ¶ms); err != nil {
return nil, err
}
if params.N < 0 {
return nil, fmt.Errorf("cannot sqrt negative number: %f", params.N)
}
return map[string]float64{"result": math.Sqrt(params.N)}, nil
}
// UserService 用户服务
type UserService struct {
users map[int]string
mu sync.RWMutex
}
func NewUserService() *UserService {
return &UserService{
users: map[int]string{
1: "Alice",
2: "Bob",
3: "Charlie",
},
}
}
func (s *UserService) GetUser(ctx context.Context, data json.RawMessage) (interface{}, error) {
var params struct {
ID int `json:"id"`
}
if err := json.Unmarshal(data, ¶ms); err != nil {
return nil, err
}
s.mu.RLock()
defer s.mu.RUnlock()
name, ok := s.users[params.ID]
if !ok {
return nil, fmt.Errorf("user %d not found", params.ID)
}
return map[string]interface{}{
"id": params.ID,
"name": name,
}, nil
}
// 模拟慢请求(测试超时)
func (s *UserService) SlowQuery(ctx context.Context, data json.RawMessage) (interface{}, error) {
select {
case <-time.After(15 * time.Second):
return "done", nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
// ==================== 服务端启动 ====================
func startServer() {
server := NewRPCServer()
// 注册方法(手动方式)
mathSvc := &MathService{}
server.Register("Math.Add", mathSvc.Add)
server.Register("Math.Sqrt", mathSvc.Sqrt)
userSvc := NewUserService()
server.Register("User.GetUser", userSvc.GetUser)
server.Register("User.SlowQuery", userSvc.SlowQuery)
// 启动
if err := server.Start(":9000"); err != nil {
log.Fatal(err)
}
}
// ==================== 客户端使用 ====================
func startClient() {
// 创建带连接池的客户端
client, err := NewRPCClient("localhost:9000", 10)
if err != nil {
log.Fatal("连接失败:", err)
}
defer client.Close()
// 调用 Math.Add
result, err := client.Call("Math.Add", AddParams{A: 10.5, B: 20.3})
if err != nil {
log.Fatal(err)
}
fmt.Println("Math.Add:", string(result)) // {"result":30.8}
// 调用 Math.Sqrt
result, err = client.Call("Math.Sqrt", map[string]float64{"n": 144})
if err != nil {
log.Fatal(err)
}
fmt.Println("Math.Sqrt:", string(result)) // {"result":12}
// 调用 User.GetUser
result, err = client.Call("User.GetUser", map[string]int{"id": 1})
if err != nil {
log.Fatal(err)
}
fmt.Println("User.GetUser:", string(result)) // {"id":1,"name":"Alice"}
// 调用不存在的方法
_, err = client.Call("Unknown.Method", nil)
fmt.Println("未知方法:", err) // method not found
// ==================== 并发压测 ====================
fmt.Println("\n=== 并发压测 ===")
var wg sync.WaitGroup
start := time.Now()
concurrency := 100
requestsPerGoroutine := 100
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < requestsPerGoroutine; j++ {
_, err := client.Call("Math.Add", AddParams{
A: float64(id),
B: float64(j),
})
if err != nil {
fmt.Printf("请求失败: %v\n", err)
}
}
}(i)
}
wg.Wait()
totalRequests := concurrency * requestsPerGoroutine
elapsed := time.Since(start)
fmt.Printf("完成 %d 个请求,耗时 %s\n", totalRequests, elapsed)
fmt.Printf("QPS: %.0f\n", float64(totalRequests)/elapsed.Seconds())
}
// ==================== 主函数 ====================
func main() {
// 启动服务端
go startServer()
time.Sleep(500 * time.Millisecond) // 等待服务端就绪
// 启动客户端
startClient()
}7.6 架构总结
text
RPC 框架架构:
客户端 服务端
┌────────────────────┐ ┌────────────────────────┐
│ RPCClient │ │ RPCServer │
│ ┌───────────────┐ │ TCP 连接 │ ┌──────────────────┐ │
│ │ 连接池 │ │ ◄════════════► │ │ Accept 循环 │ │
│ │ (chan Conn) │ │ │ └──────┬───────────┘ │
│ └───────┬───────┘ │ │ │ │
│ │ │ │ ┌──────▼───────────┐ │
│ ┌───────▼───────┐ │ │ │ handleConn │ │
│ │ Call() │ │ 二进制协议 │ │ (每连接一个协程) │ │
│ │ 序列化参数 │──┼──────────────→ │ │ 读取请求 │ │
│ │ 发送请求 │ │ │ └──────┬───────────┘ │
│ │ 读取响应 │←─┼────────────── │ │ │
│ │ 反序列化结果 │ │ │ ┌──────▼───────────┐ │
│ └───────────────┘ │ │ │ handleRequest │ │
│ │ │ │ (每请求一个协程) │ │
│ RequestID 匹配 │ │ │ 路由 → Handler │ │
│ 原子递增 ID │ │ │ 超时控制 │ │
└────────────────────┘ │ │ panic recover │ │
│ └──────────────────┘ │
│ │
│ handlers map: │
│ Math.Add → func │
│ Math.Sqrt → func │
│ User.GetUser → func │
└────────────────────────┘
关键设计:
1. 二进制协议:4 字节长度前缀,解决粘包
2. 连接池:复用 TCP 连接,减少握手开销
3. 每连接一协程:Accept 后立即分发
4. 每请求一协程:同一连接上的请求并发处理
5. RequestID:请求/响应匹配(本实现简化为串行)
6. context 超时:10 秒请求超时,60 秒连接超时
7. panic recover:单个请求崩溃不影响连接
8. 优雅关闭:等待活跃连接处理完毕附录:网络编程速查
text
函数 说明
──────────────────────────────────────────────────
net.Listen("tcp", ":8080") 监听 TCP 端口
listener.Accept() 接受连接(阻塞)
net.Dial("tcp", "host:port") 连接服务器
net.DialTimeout(...) 带超时的连接
conn.Read(buf) 读取数据
conn.Write(data) 写入数据
conn.Close() 关闭连接
conn.SetDeadline(t) 设置读写超时
conn.RemoteAddr() 获取对端地址
io.ReadFull(conn, buf) 精确读取 n 字节
net.ListenUDP("udp", addr) 监听 UDP 端口
conn.ReadFromUDP(buf) UDP 读取 + 来源地址
conn.WriteToUDP(data, addr) UDP 发送到指定地址
binary.BigEndian.PutUint32(...) 大端序编码
binary.BigEndian.Uint32(...) 大端序解码