diff --git a/internal/config/config.go b/internal/config/config.go index a4038fe..a323c8c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -432,5 +432,12 @@ func Validate(cfg *Config) error { } } + // 验证 Stream 配置 + for i := range cfg.Stream { + if err := validateStream(&cfg.Stream[i]); err != nil { + return fmt.Errorf("stream[%d]: %w", i, err) + } + } + return nil } diff --git a/internal/config/validate.go b/internal/config/validate.go index 17671c6..a0e1773 100644 --- a/internal/config/validate.go +++ b/internal/config/validate.go @@ -290,10 +290,9 @@ func validateAuth(a *AuthConfig) error { } // 启用 Basic Auth 时检查是否强制 HTTPS - if a.RequireTLS { - // 注意:SSL 配置在 ServerConfig 中,这里无法直接检查 - // 需要在上层验证中检查 SSL 与 Auth 的关联 - } + // 注意:SSL 配置在 ServerConfig 中,这里无法直接检查 + // 需要在上层验证中检查 SSL 与 Auth 的关联 + _ = a.RequireTLS // 避免空分支警告 // 验证哈希算法 validAlgorithms := []string{"", "bcrypt", "argon2id"} @@ -417,3 +416,43 @@ func validateCompression(c *CompressionConfig) error { return nil } + +// validateStream 验证 Stream 代理配置。 +// +// 检查监听地址、协议类型和上游配置的有效性。 +// +// 参数: +// - s: Stream 配置对象 +// +// 返回值: +// - error: 验证失败时返回错误信息,成功返回 nil +// +// 验证规则: +// - listen 必填 +// - protocol 仅允许 tcp 或 udp +// - upstream.targets 至少需要一个目标 +func validateStream(s *StreamConfig) error { + // 监听地址必填 + if s.Listen == "" { + return errors.New("listen 地址必填") + } + + // 验证协议类型 + if s.Protocol != "tcp" && s.Protocol != "udp" { + return fmt.Errorf("无效的协议类型: %s(仅允许 tcp 或 udp)", s.Protocol) + } + + // 验证上游目标 + if len(s.Upstream.Targets) == 0 { + return errors.New("upstream.targets 至少需要一个目标地址") + } + + // 验证每个目标地址 + for i, t := range s.Upstream.Targets { + if t.Addr == "" { + return fmt.Errorf("upstream.targets[%d].addr 必填", i) + } + } + + return nil +} diff --git a/internal/config/validate_test.go b/internal/config/validate_test.go index 55d58ba..f9f51d7 100644 --- a/internal/config/validate_test.go +++ b/internal/config/validate_test.go @@ -811,3 +811,98 @@ func TestValidateSecurity(t *testing.T) { }) } } + +func TestValidateStream(t *testing.T) { + tests := []struct { + name string + config StreamConfig + wantErr bool + errMsg string + }{ + { + name: "valid tcp stream", + config: StreamConfig{ + Listen: ":3306", + Protocol: "tcp", + Upstream: StreamUpstream{ + Targets: []StreamTarget{{Addr: "db1:3306"}}, + LoadBalance: "round_robin", + }, + }, + wantErr: false, + }, + { + name: "valid udp stream", + config: StreamConfig{ + Listen: ":53", + Protocol: "udp", + Upstream: StreamUpstream{ + Targets: []StreamTarget{{Addr: "dns1:53"}}, + LoadBalance: "least_conn", + }, + }, + wantErr: false, + }, + { + name: "empty listen", + config: StreamConfig{ + Listen: "", + Protocol: "tcp", + }, + wantErr: true, + errMsg: "listen 地址必填", + }, + { + name: "invalid protocol", + config: StreamConfig{ + Listen: ":3306", + Protocol: "http", + }, + wantErr: true, + errMsg: "无效的协议类型", + }, + { + name: "no targets", + config: StreamConfig{ + Listen: ":3306", + Protocol: "tcp", + Upstream: StreamUpstream{ + Targets: []StreamTarget{}, + }, + }, + wantErr: true, + errMsg: "upstream.targets 至少需要一个目标地址", + }, + { + name: "empty target addr", + config: StreamConfig{ + Listen: ":3306", + Protocol: "tcp", + Upstream: StreamUpstream{ + Targets: []StreamTarget{{Addr: ""}}, + }, + }, + wantErr: true, + errMsg: "addr 必填", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateStream(&tt.config) + if tt.wantErr { + if err == nil { + t.Errorf("validateStream() 期望返回错误,但返回 nil") + return + } + if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("validateStream() 错误消息不匹配,期望包含 %q,实际 %q", tt.errMsg, err.Error()) + } + } else { + if err != nil { + t.Errorf("validateStream() 期望返回 nil,但返回错误: %v", err) + } + } + }) + } +} diff --git a/internal/stream/stream.go b/internal/stream/stream.go index 786131c..498ef3b 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -29,6 +29,7 @@ package stream import ( + "hash/fnv" "io" "net" "sync" @@ -92,6 +93,92 @@ func (l *leastConn) Select(targets []*Target) *Target { return selected } +// weightedRoundRobin 加权轮询。 +type weightedRoundRobin struct { + counter uint64 +} + +// newWeightedRoundRobin 创建加权轮询均衡器。 +func newWeightedRoundRobin() Balancer { + return &weightedRoundRobin{} +} + +// Select 基于权重分布选择目标。 +func (w *weightedRoundRobin) Select(targets []*Target) *Target { + healthy := make([]*Target, 0) + for _, t := range targets { + if t.healthy.Load() { + healthy = append(healthy, t) + } + } + if len(healthy) == 0 { + return nil + } + + // 计算总权重 + totalWeight := 0 + for _, t := range healthy { + if t.weight <= 0 { + totalWeight += 1 // 最小权重为 1 + } else { + totalWeight += t.weight + } + } + + // 使用原子计数器确定位置 + idx := atomic.AddUint64(&w.counter, 1) - 1 + pos := int(idx % uint64(totalWeight)) + + // 找到对应位置的目标 + currentWeight := 0 + for _, t := range healthy { + weight := t.weight + if weight <= 0 { + weight = 1 + } + currentWeight += weight + if pos < currentWeight { + return t + } + } + + return healthy[len(healthy)-1] +} + +// ipHash IP 哈希。 +type ipHash struct{} + +// newIPHash 创建 IP 哈希均衡器。 +func newIPHash() Balancer { + return &ipHash{} +} + +// Select 默认选择(IP Hash 需要具体 IP)。 +func (i *ipHash) Select(targets []*Target) *Target { + return i.SelectByIP(targets, "") +} + +// SelectByIP 基于客户端 IP 哈希选择目标。 +func (i *ipHash) SelectByIP(targets []*Target, clientIP string) *Target { + healthy := make([]*Target, 0) + for _, t := range targets { + if t.healthy.Load() { + healthy = append(healthy, t) + } + } + if len(healthy) == 0 { + return nil + } + + // 使用 FNV-64a 哈希 + h := fnv.New64a() + h.Write([]byte(clientIP)) + hash := h.Sum64() + + idx := hash % uint64(len(healthy)) + return healthy[idx] +} + // Server TCP/UDP Stream 代理服务器。 type Server struct { // listeners TCP 监听器映射,按 upstream 名称索引 @@ -213,10 +300,14 @@ func (s *Server) AddUpstream(name string, targets []TargetSpec, lbType string, h // 创建负载均衡器 var balancer Balancer switch lbType { - case "round_robin": + case "round_robin", "": balancer = newRoundRobin() + case "weighted_round_robin": + balancer = newWeightedRoundRobin() case "least_conn": balancer = newLeastConn() + case "ip_hash": + balancer = newIPHash() default: balancer = newRoundRobin() } @@ -324,7 +415,7 @@ func (s *Server) acceptLoop(addr string, listener net.Listener) { // handleConnection 处理单个连接。 func (s *Server) handleConnection(clientConn net.Conn, addr string) { defer func() { - clientConn.Close() + _ = clientConn.Close() s.connCount-- }() @@ -356,11 +447,11 @@ func (s *Server) handleConnection(clientConn net.Conn, addr string) { target.healthy.Store(false) return } - defer targetConn.Close() + defer func() { _ = targetConn.Close() }() // 双向数据转发 - go io.Copy(targetConn, clientConn) - io.Copy(clientConn, targetConn) + go func() { _, _ = io.Copy(targetConn, clientConn) }() + _, _ = io.Copy(clientConn, targetConn) } // Select 选择健康的上游目标。 @@ -406,7 +497,7 @@ func (h *HealthChecker) check() { if err != nil { target.healthy.Store(false) } else { - conn.Close() + _ = conn.Close() target.healthy.Store(true) } } @@ -426,7 +517,7 @@ func (s *Server) Stop() error { // 关闭所有 TCP 监听器 for _, listener := range s.listeners { - listener.Close() + _ = listener.Close() } // 停止所有 UDP 服务器 @@ -594,7 +685,7 @@ func (s *udpServer) removeSession(clientAddr *net.UDPAddr) { func (sess *udpSession) close() { sess.closeOnce.Do(func() { if sess.targetConn != nil { - sess.targetConn.Close() + _ = sess.targetConn.Close() } }) } @@ -606,7 +697,7 @@ func (sess *udpSession) handleBackendResponse() { buf := make([]byte, 65535) for { // 设置读取超时 - sess.targetConn.SetReadDeadline(time.Now().Add(sess.srv.timeout)) + _ = sess.targetConn.SetReadDeadline(time.Now().Add(sess.srv.timeout)) n, err := sess.targetConn.Read(buf) if err != nil { @@ -650,7 +741,7 @@ func (s *udpServer) serve() { buf := make([]byte, 65535) for s.running.Load() { // 设置读取超时,以便定期检查 stopCh - s.conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + _ = s.conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) n, clientAddr, err := s.conn.ReadFromUDP(buf) if err != nil { @@ -731,5 +822,5 @@ func (s *udpServer) stop() { s.wg.Wait() // 关闭连接 - s.conn.Close() + _ = s.conn.Close() } diff --git a/internal/stream/stream_test.go b/internal/stream/stream_test.go index 85ddd1a..981da20 100644 --- a/internal/stream/stream_test.go +++ b/internal/stream/stream_test.go @@ -11,7 +11,7 @@ import ( func TestNewServer(t *testing.T) { s := NewServer() if s == nil { - t.Error("Expected non-nil server") + t.Fatal("Expected non-nil server") } if s.listeners == nil { t.Error("Expected initialized listeners map") @@ -46,7 +46,7 @@ func TestAddUpstream(t *testing.T) { up := s.upstreams["test"] if up == nil { - t.Error("Expected non-nil upstream") + t.Fatal("Expected non-nil upstream") } if len(up.targets) != 2 { t.Errorf("Expected 2 targets, got %d", len(up.targets)) @@ -124,6 +124,67 @@ func TestBalancerNoHealthyTargets(t *testing.T) { } } +func TestWeightedRoundRobinBalancer(t *testing.T) { + targets := []*Target{ + {addr: "localhost:8001", weight: 3}, + {addr: "localhost:8002", weight: 1}, + } + for _, target := range targets { + target.healthy.Store(true) + } + + wrr := newWeightedRoundRobin() + + // 测试加权分布:3:1 比例 + results := make(map[string]int) + for i := 0; i < 8; i++ { + selected := wrr.Select(targets) + if selected == nil { + t.Error("Expected non-nil target") + continue + } + results[selected.addr]++ + } + + // localhost:8001 应被选中 6 次,localhost:8002 应被选中 2 次 + if results["localhost:8001"] != 6 { + t.Errorf("Expected localhost:8001 to be selected 6 times, got %d", results["localhost:8001"]) + } + if results["localhost:8002"] != 2 { + t.Errorf("Expected localhost:8002 to be selected 2 times, got %d", results["localhost:8002"]) + } +} + +func TestIPHashBalancer(t *testing.T) { + targets := []*Target{ + {addr: "localhost:8001"}, + {addr: "localhost:8002"}, + {addr: "localhost:8003"}, + } + for _, target := range targets { + target.healthy.Store(true) + } + + ih := newIPHash() + + // 相同 IP 应始终选择同一目标 + ip1 := "192.168.1.1" + selected1 := ih.(*ipHash).SelectByIP(targets, ip1) + selected2 := ih.(*ipHash).SelectByIP(targets, ip1) + + if selected1 != selected2 { + t.Error("Same IP should select same target") + } + + // 不同 IP 可能选择不同目标 + ip2 := "10.0.0.1" + selected3 := ih.(*ipHash).SelectByIP(targets, ip2) + // 验证返回非空 + if selected3 == nil { + t.Error("Expected non-nil target for different IP") + } +} + func TestServerStats(t *testing.T) { s := NewServer() @@ -231,7 +292,7 @@ func TestConcurrentConnections(t *testing.T) { targets := []TargetSpec{ {Addr: "localhost:8001", Weight: 1}, } - s.AddUpstream("test", targets, "round_robin", HealthCheckSpec{}) + _ = s.AddUpstream("test", targets, "round_robin", HealthCheckSpec{}) // 并发增加连接数 var wg sync.WaitGroup @@ -298,7 +359,7 @@ func TestUDPServerStartAndStop(t *testing.T) { targets := []TargetSpec{ {Addr: "127.0.0.1:19001", Weight: 1}, } - s.AddUpstream("udp_stop_test", targets, "round_robin", HealthCheckSpec{}) + _ = s.AddUpstream("udp_stop_test", targets, "round_robin", HealthCheckSpec{}) // 监听 UDP err := s.ListenUDP("127.0.0.1:19000", "udp_stop_test", 500*time.Millisecond) @@ -344,7 +405,7 @@ func TestNewUDPServer(t *testing.T) { // 创建 UDP 连接 udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0") conn, _ := net.ListenUDP("udp", udpAddr) - defer conn.Close() + defer func() { _ = conn.Close() }() // 创建上游 upstream := &Upstream{ @@ -395,7 +456,7 @@ func TestServerStartStopWithTCP(t *testing.T) { targets := []TargetSpec{ {Addr: "127.0.0.1:19003", Weight: 1}, } - s.AddUpstream("tcp_test", targets, "round_robin", HealthCheckSpec{}) + _ = s.AddUpstream("tcp_test", targets, "round_robin", HealthCheckSpec{}) // 监听 TCP err := s.ListenTCP("127.0.0.1:19004") @@ -535,7 +596,7 @@ func TestCleanupExpiredSessions(t *testing.T) { // 创建 UDP 连接 udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0") conn, _ := net.ListenUDP("udp", udpAddr) - defer conn.Close() + defer func() { _ = conn.Close() }() // 创建上游 upstream := &Upstream{ @@ -603,12 +664,12 @@ func TestUDPSessionOperations(t *testing.T) { // 创建 UDP 连接 udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0") conn, _ := net.ListenUDP("udp", udpAddr) - defer conn.Close() + defer func() { _ = conn.Close() }() // 创建目标服务器(用于模拟连接) targetAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:19007") targetConn, _ := net.ListenUDP("udp", targetAddr) - defer targetConn.Close() + defer func() { _ = targetConn.Close() }() // 创建上游 upstream := &Upstream{ @@ -670,7 +731,7 @@ func TestUDPSessionClose(t *testing.T) { // 第二次调用 close 不应该出错(使用 sync.Once) session.close() - conn1.Close() + _ = conn1.Close() } func TestHealthCheckerCheckWithHealthyTarget(t *testing.T) { @@ -679,7 +740,7 @@ func TestHealthCheckerCheckWithHealthyTarget(t *testing.T) { if err != nil { t.Fatalf("Failed to create listener: %v", err) } - defer listener.Close() + defer func() { _ = listener.Close() }() // 在后台运行服务器 go func() {