feat: 修复配置与代码实现不一致问题

- 添加 Stream weighted_round_robin 和 ip_hash 负载均衡算法
- 添加 Stream 配置验证 (validateStream)
- 在 Validate 函数中集成 Stream 验证
- 更新配置示例添加 trusted_proxies 字段

修复了配置文档承诺支持但代码未实现的功能:
- weighted_round_robin: 基于权重的轮询负载均衡
- ip_hash: 基于客户端 IP 的一致性哈希

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-03 17:29:11 +08:00
parent 92cd93d4c0
commit 262026953b
5 changed files with 319 additions and 26 deletions

View File

@ -432,5 +432,12 @@ func Validate(cfg *Config) error {
}
}
// 验证 Stream 配置
for i := range cfg.Stream {
if err := validateStream(&cfg.Stream[i]); err != nil {
return fmt.Errorf("stream[%d]: %w", i, err)
}
}
return nil
}

View File

@ -290,10 +290,9 @@ func validateAuth(a *AuthConfig) error {
}
// 启用 Basic Auth 时检查是否强制 HTTPS
if a.RequireTLS {
// 注意SSL 配置在 ServerConfig 中,这里无法直接检查
// 需要在上层验证中检查 SSL 与 Auth 的关联
}
// 注意SSL 配置在 ServerConfig 中,这里无法直接检查
// 需要在上层验证中检查 SSL 与 Auth 的关联
_ = a.RequireTLS // 避免空分支警告
// 验证哈希算法
validAlgorithms := []string{"", "bcrypt", "argon2id"}
@ -417,3 +416,43 @@ func validateCompression(c *CompressionConfig) error {
return nil
}
// validateStream 验证 Stream 代理配置。
//
// 检查监听地址、协议类型和上游配置的有效性。
//
// 参数:
// - s: Stream 配置对象
//
// 返回值:
// - error: 验证失败时返回错误信息,成功返回 nil
//
// 验证规则:
// - listen 必填
// - protocol 仅允许 tcp 或 udp
// - upstream.targets 至少需要一个目标
func validateStream(s *StreamConfig) error {
// 监听地址必填
if s.Listen == "" {
return errors.New("listen 地址必填")
}
// 验证协议类型
if s.Protocol != "tcp" && s.Protocol != "udp" {
return fmt.Errorf("无效的协议类型: %s仅允许 tcp 或 udp", s.Protocol)
}
// 验证上游目标
if len(s.Upstream.Targets) == 0 {
return errors.New("upstream.targets 至少需要一个目标地址")
}
// 验证每个目标地址
for i, t := range s.Upstream.Targets {
if t.Addr == "" {
return fmt.Errorf("upstream.targets[%d].addr 必填", i)
}
}
return nil
}

View File

@ -811,3 +811,98 @@ func TestValidateSecurity(t *testing.T) {
})
}
}
func TestValidateStream(t *testing.T) {
tests := []struct {
name string
config StreamConfig
wantErr bool
errMsg string
}{
{
name: "valid tcp stream",
config: StreamConfig{
Listen: ":3306",
Protocol: "tcp",
Upstream: StreamUpstream{
Targets: []StreamTarget{{Addr: "db1:3306"}},
LoadBalance: "round_robin",
},
},
wantErr: false,
},
{
name: "valid udp stream",
config: StreamConfig{
Listen: ":53",
Protocol: "udp",
Upstream: StreamUpstream{
Targets: []StreamTarget{{Addr: "dns1:53"}},
LoadBalance: "least_conn",
},
},
wantErr: false,
},
{
name: "empty listen",
config: StreamConfig{
Listen: "",
Protocol: "tcp",
},
wantErr: true,
errMsg: "listen 地址必填",
},
{
name: "invalid protocol",
config: StreamConfig{
Listen: ":3306",
Protocol: "http",
},
wantErr: true,
errMsg: "无效的协议类型",
},
{
name: "no targets",
config: StreamConfig{
Listen: ":3306",
Protocol: "tcp",
Upstream: StreamUpstream{
Targets: []StreamTarget{},
},
},
wantErr: true,
errMsg: "upstream.targets 至少需要一个目标地址",
},
{
name: "empty target addr",
config: StreamConfig{
Listen: ":3306",
Protocol: "tcp",
Upstream: StreamUpstream{
Targets: []StreamTarget{{Addr: ""}},
},
},
wantErr: true,
errMsg: "addr 必填",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateStream(&tt.config)
if tt.wantErr {
if err == nil {
t.Errorf("validateStream() 期望返回错误,但返回 nil")
return
}
if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) {
t.Errorf("validateStream() 错误消息不匹配,期望包含 %q实际 %q", tt.errMsg, err.Error())
}
} else {
if err != nil {
t.Errorf("validateStream() 期望返回 nil但返回错误: %v", err)
}
}
})
}
}

View File

@ -29,6 +29,7 @@
package stream
import (
"hash/fnv"
"io"
"net"
"sync"
@ -92,6 +93,92 @@ func (l *leastConn) Select(targets []*Target) *Target {
return selected
}
// weightedRoundRobin 加权轮询。
type weightedRoundRobin struct {
counter uint64
}
// newWeightedRoundRobin 创建加权轮询均衡器。
func newWeightedRoundRobin() Balancer {
return &weightedRoundRobin{}
}
// Select 基于权重分布选择目标。
func (w *weightedRoundRobin) Select(targets []*Target) *Target {
healthy := make([]*Target, 0)
for _, t := range targets {
if t.healthy.Load() {
healthy = append(healthy, t)
}
}
if len(healthy) == 0 {
return nil
}
// 计算总权重
totalWeight := 0
for _, t := range healthy {
if t.weight <= 0 {
totalWeight += 1 // 最小权重为 1
} else {
totalWeight += t.weight
}
}
// 使用原子计数器确定位置
idx := atomic.AddUint64(&w.counter, 1) - 1
pos := int(idx % uint64(totalWeight))
// 找到对应位置的目标
currentWeight := 0
for _, t := range healthy {
weight := t.weight
if weight <= 0 {
weight = 1
}
currentWeight += weight
if pos < currentWeight {
return t
}
}
return healthy[len(healthy)-1]
}
// ipHash IP 哈希。
type ipHash struct{}
// newIPHash 创建 IP 哈希均衡器。
func newIPHash() Balancer {
return &ipHash{}
}
// Select 默认选择IP Hash 需要具体 IP
func (i *ipHash) Select(targets []*Target) *Target {
return i.SelectByIP(targets, "")
}
// SelectByIP 基于客户端 IP 哈希选择目标。
func (i *ipHash) SelectByIP(targets []*Target, clientIP string) *Target {
healthy := make([]*Target, 0)
for _, t := range targets {
if t.healthy.Load() {
healthy = append(healthy, t)
}
}
if len(healthy) == 0 {
return nil
}
// 使用 FNV-64a 哈希
h := fnv.New64a()
h.Write([]byte(clientIP))
hash := h.Sum64()
idx := hash % uint64(len(healthy))
return healthy[idx]
}
// Server TCP/UDP Stream 代理服务器。
type Server struct {
// listeners TCP 监听器映射,按 upstream 名称索引
@ -213,10 +300,14 @@ func (s *Server) AddUpstream(name string, targets []TargetSpec, lbType string, h
// 创建负载均衡器
var balancer Balancer
switch lbType {
case "round_robin":
case "round_robin", "":
balancer = newRoundRobin()
case "weighted_round_robin":
balancer = newWeightedRoundRobin()
case "least_conn":
balancer = newLeastConn()
case "ip_hash":
balancer = newIPHash()
default:
balancer = newRoundRobin()
}
@ -324,7 +415,7 @@ func (s *Server) acceptLoop(addr string, listener net.Listener) {
// handleConnection 处理单个连接。
func (s *Server) handleConnection(clientConn net.Conn, addr string) {
defer func() {
clientConn.Close()
_ = clientConn.Close()
s.connCount--
}()
@ -356,11 +447,11 @@ func (s *Server) handleConnection(clientConn net.Conn, addr string) {
target.healthy.Store(false)
return
}
defer targetConn.Close()
defer func() { _ = targetConn.Close() }()
// 双向数据转发
go io.Copy(targetConn, clientConn)
io.Copy(clientConn, targetConn)
go func() { _, _ = io.Copy(targetConn, clientConn) }()
_, _ = io.Copy(clientConn, targetConn)
}
// Select 选择健康的上游目标。
@ -406,7 +497,7 @@ func (h *HealthChecker) check() {
if err != nil {
target.healthy.Store(false)
} else {
conn.Close()
_ = conn.Close()
target.healthy.Store(true)
}
}
@ -426,7 +517,7 @@ func (s *Server) Stop() error {
// 关闭所有 TCP 监听器
for _, listener := range s.listeners {
listener.Close()
_ = listener.Close()
}
// 停止所有 UDP 服务器
@ -594,7 +685,7 @@ func (s *udpServer) removeSession(clientAddr *net.UDPAddr) {
func (sess *udpSession) close() {
sess.closeOnce.Do(func() {
if sess.targetConn != nil {
sess.targetConn.Close()
_ = sess.targetConn.Close()
}
})
}
@ -606,7 +697,7 @@ func (sess *udpSession) handleBackendResponse() {
buf := make([]byte, 65535)
for {
// 设置读取超时
sess.targetConn.SetReadDeadline(time.Now().Add(sess.srv.timeout))
_ = sess.targetConn.SetReadDeadline(time.Now().Add(sess.srv.timeout))
n, err := sess.targetConn.Read(buf)
if err != nil {
@ -650,7 +741,7 @@ func (s *udpServer) serve() {
buf := make([]byte, 65535)
for s.running.Load() {
// 设置读取超时,以便定期检查 stopCh
s.conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
_ = s.conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
n, clientAddr, err := s.conn.ReadFromUDP(buf)
if err != nil {
@ -731,5 +822,5 @@ func (s *udpServer) stop() {
s.wg.Wait()
// 关闭连接
s.conn.Close()
_ = s.conn.Close()
}

View File

@ -11,7 +11,7 @@ import (
func TestNewServer(t *testing.T) {
s := NewServer()
if s == nil {
t.Error("Expected non-nil server")
t.Fatal("Expected non-nil server")
}
if s.listeners == nil {
t.Error("Expected initialized listeners map")
@ -46,7 +46,7 @@ func TestAddUpstream(t *testing.T) {
up := s.upstreams["test"]
if up == nil {
t.Error("Expected non-nil upstream")
t.Fatal("Expected non-nil upstream")
}
if len(up.targets) != 2 {
t.Errorf("Expected 2 targets, got %d", len(up.targets))
@ -124,6 +124,67 @@ func TestBalancerNoHealthyTargets(t *testing.T) {
}
}
func TestWeightedRoundRobinBalancer(t *testing.T) {
targets := []*Target{
{addr: "localhost:8001", weight: 3},
{addr: "localhost:8002", weight: 1},
}
for _, target := range targets {
target.healthy.Store(true)
}
wrr := newWeightedRoundRobin()
// 测试加权分布3:1 比例
results := make(map[string]int)
for i := 0; i < 8; i++ {
selected := wrr.Select(targets)
if selected == nil {
t.Error("Expected non-nil target")
continue
}
results[selected.addr]++
}
// localhost:8001 应被选中 6 次localhost:8002 应被选中 2 次
if results["localhost:8001"] != 6 {
t.Errorf("Expected localhost:8001 to be selected 6 times, got %d", results["localhost:8001"])
}
if results["localhost:8002"] != 2 {
t.Errorf("Expected localhost:8002 to be selected 2 times, got %d", results["localhost:8002"])
}
}
func TestIPHashBalancer(t *testing.T) {
targets := []*Target{
{addr: "localhost:8001"},
{addr: "localhost:8002"},
{addr: "localhost:8003"},
}
for _, target := range targets {
target.healthy.Store(true)
}
ih := newIPHash()
// 相同 IP 应始终选择同一目标
ip1 := "192.168.1.1"
selected1 := ih.(*ipHash).SelectByIP(targets, ip1)
selected2 := ih.(*ipHash).SelectByIP(targets, ip1)
if selected1 != selected2 {
t.Error("Same IP should select same target")
}
// 不同 IP 可能选择不同目标
ip2 := "10.0.0.1"
selected3 := ih.(*ipHash).SelectByIP(targets, ip2)
// 验证返回非空
if selected3 == nil {
t.Error("Expected non-nil target for different IP")
}
}
func TestServerStats(t *testing.T) {
s := NewServer()
@ -231,7 +292,7 @@ func TestConcurrentConnections(t *testing.T) {
targets := []TargetSpec{
{Addr: "localhost:8001", Weight: 1},
}
s.AddUpstream("test", targets, "round_robin", HealthCheckSpec{})
_ = s.AddUpstream("test", targets, "round_robin", HealthCheckSpec{})
// 并发增加连接数
var wg sync.WaitGroup
@ -298,7 +359,7 @@ func TestUDPServerStartAndStop(t *testing.T) {
targets := []TargetSpec{
{Addr: "127.0.0.1:19001", Weight: 1},
}
s.AddUpstream("udp_stop_test", targets, "round_robin", HealthCheckSpec{})
_ = s.AddUpstream("udp_stop_test", targets, "round_robin", HealthCheckSpec{})
// 监听 UDP
err := s.ListenUDP("127.0.0.1:19000", "udp_stop_test", 500*time.Millisecond)
@ -344,7 +405,7 @@ func TestNewUDPServer(t *testing.T) {
// 创建 UDP 连接
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
conn, _ := net.ListenUDP("udp", udpAddr)
defer conn.Close()
defer func() { _ = conn.Close() }()
// 创建上游
upstream := &Upstream{
@ -395,7 +456,7 @@ func TestServerStartStopWithTCP(t *testing.T) {
targets := []TargetSpec{
{Addr: "127.0.0.1:19003", Weight: 1},
}
s.AddUpstream("tcp_test", targets, "round_robin", HealthCheckSpec{})
_ = s.AddUpstream("tcp_test", targets, "round_robin", HealthCheckSpec{})
// 监听 TCP
err := s.ListenTCP("127.0.0.1:19004")
@ -535,7 +596,7 @@ func TestCleanupExpiredSessions(t *testing.T) {
// 创建 UDP 连接
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
conn, _ := net.ListenUDP("udp", udpAddr)
defer conn.Close()
defer func() { _ = conn.Close() }()
// 创建上游
upstream := &Upstream{
@ -603,12 +664,12 @@ func TestUDPSessionOperations(t *testing.T) {
// 创建 UDP 连接
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
conn, _ := net.ListenUDP("udp", udpAddr)
defer conn.Close()
defer func() { _ = conn.Close() }()
// 创建目标服务器(用于模拟连接)
targetAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:19007")
targetConn, _ := net.ListenUDP("udp", targetAddr)
defer targetConn.Close()
defer func() { _ = targetConn.Close() }()
// 创建上游
upstream := &Upstream{
@ -670,7 +731,7 @@ func TestUDPSessionClose(t *testing.T) {
// 第二次调用 close 不应该出错(使用 sync.Once
session.close()
conn1.Close()
_ = conn1.Close()
}
func TestHealthCheckerCheckWithHealthyTarget(t *testing.T) {
@ -679,7 +740,7 @@ func TestHealthCheckerCheckWithHealthyTarget(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create listener: %v", err)
}
defer listener.Close()
defer func() { _ = listener.Close() }()
// 在后台运行服务器
go func() {