diff --git a/config.example.yaml b/config.example.yaml index a44e66d..daea49e 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -25,7 +25,9 @@ server: # weight: 3 # 权重(加权轮询时有效) # - url: http://backend2:8080 # weight: 1 - # load_balance: round_robin # 负载均衡算法: round_robin, weighted_round_robin, least_conn, ip_hash + # load_balance: round_robin # 负载均衡算法: round_robin, weighted_round_robin, least_conn, ip_hash, consistent_hash + # hash_key: "ip" # 一致性哈希键(有效值: ip, uri, header:X-Name) + # virtual_nodes: 150 # 一致性哈希虚拟节点数 # health_check: # 健康检查 # interval: 10s # path: /health @@ -73,6 +75,9 @@ server: burst: 0 # 突发上限 conn_limit: 0 # 连接数限制 key: "ip" # 限流 key 来源(有效值: ip, header) + algorithm: "token_bucket" # 限流算法(有效值: token_bucket, sliding_window) + sliding_window_mode: "approximate" # 滑动窗口模式(有效值: approximate, precise) + sliding_window: 60 # 滑动窗口大小(秒) # 认证配置(type 为空时禁用) auth: @@ -108,6 +113,14 @@ server: - "text/javascript" - "application/json" - "application/javascript" + gzip_static: true # 启用预压缩文件支持(检测 .gz 文件) + gzip_static_extensions: # 预压缩文件扩展名 + - ".html" + - ".css" + - ".js" + - ".json" + - ".xml" + - ".svg" # 多虚拟主机模式(可选,每个虚拟主机支持完整的 server 配置) # servers: @@ -158,6 +171,14 @@ server: # 默认 TLS 协议: TLSv1.2, TLSv1.3(不支持 TLSv1.0/1.1) # 默认 HSTS 配置: max_age=31536000(1年), include_sub_domains=true +# HTTP/3 (QUIC) 配置(可选) +# http3: +# enabled: false # 是否启用 HTTP/3 +# listen: ":443" # UDP 监听地址 +# max_streams: 100 # 最大并发流数 +# idle_timeout: 30s # 空闲超时 +# enable_0rtt: false # 启用 0-RTT 快速连接 + # TCP/UDP Stream 代理配置(可选) # stream: # - listen: "3306" # 监听地址 diff --git a/internal/config/defaults.go b/internal/config/defaults.go index c28c17e..b0cf510 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -65,10 +65,13 @@ func DefaultConfig() *Config { Default: "allow", }, RateLimit: RateLimitConfig{ - RequestRate: 0, - Burst: 0, - ConnLimit: 0, - Key: "ip", + RequestRate: 0, + Burst: 0, + ConnLimit: 0, + Key: "ip", + Algorithm: "token_bucket", + SlidingWindowMode: "approximate", + SlidingWindow: 60, }, Auth: AuthConfig{ RequireTLS: true, @@ -83,9 +86,11 @@ func DefaultConfig() *Config { }, }, Compression: CompressionConfig{ - Type: "gzip", - Level: 6, - MinSize: 1024, + Type: "gzip", + Level: 6, + MinSize: 1024, + GzipStatic: false, + GzipStaticExtensions: []string{".gz", ".br"}, Types: []string{ "text/html", "text/css", @@ -133,6 +138,13 @@ func DefaultConfig() *Config { Allow: []string{"127.0.0.1"}, }, }, + HTTP3: HTTP3Config{ + Enabled: false, + Listen: ":443", + MaxStreams: 100, + IdleTimeout: 60 * time.Second, + Enable0RTT: false, + }, } } @@ -190,7 +202,9 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) { buf.WriteString(" # weight: 3 # 权重(加权轮询时有效)\n") buf.WriteString(" # - url: http://backend2:8080\n") buf.WriteString(" # weight: 1\n") - buf.WriteString(" # load_balance: round_robin # 负载均衡算法: round_robin, weighted_round_robin, least_conn, ip_hash\n") + buf.WriteString(" # load_balance: round_robin # 负载均衡算法(有效值: round_robin, weighted_round_robin, least_conn, ip_hash, consistent_hash)\n") + buf.WriteString(" # hash_key: ip # 一致性哈希键(仅 load_balance=consistent_hash 时有效,有效值: ip, uri, header:X-Name)\n") + buf.WriteString(" # virtual_nodes: 150 # 一致性哈希虚拟节点数(仅 load_balance=consistent_hash 时有效)\n") buf.WriteString(" # health_check: # 健康检查\n") buf.WriteString(" # interval: 10s\n") buf.WriteString(" # path: /health\n") @@ -243,6 +257,9 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) { buf.WriteString(fmt.Sprintf(" burst: %d # 突发上限\n", cfg.Server.Security.RateLimit.Burst)) buf.WriteString(fmt.Sprintf(" conn_limit: %d # 连接数限制\n", cfg.Server.Security.RateLimit.ConnLimit)) buf.WriteString(fmt.Sprintf(" key: \"%s\" # 限流 key 来源(有效值: ip, header)\n", cfg.Server.Security.RateLimit.Key)) + buf.WriteString(fmt.Sprintf(" algorithm: \"%s\" # 限流算法(有效值: token_bucket, sliding_window)\n", cfg.Server.Security.RateLimit.Algorithm)) + buf.WriteString(fmt.Sprintf(" sliding_window_mode: \"%s\" # 滑动窗口模式(有效值: approximate, precise,仅 algorithm=sliding_window 时有效)\n", cfg.Server.Security.RateLimit.SlidingWindowMode)) + buf.WriteString(fmt.Sprintf(" sliding_window: %d # 滑动窗口大小(秒,仅 algorithm=sliding_window 时有效)\n", cfg.Server.Security.RateLimit.SlidingWindow)) buf.WriteString("\n") buf.WriteString(" # 认证配置(type 为空时禁用)\n") buf.WriteString(" auth:\n") @@ -276,6 +293,11 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) { buf.WriteString(fmt.Sprintf(" type: \"%s\" # 压缩类型(有效值: gzip, brotli, both,空表示禁用)\n", cfg.Server.Compression.Type)) buf.WriteString(fmt.Sprintf(" level: %d # 压缩级别(范围 1-9,值越大压缩率越高但速度越慢)\n", cfg.Server.Compression.Level)) buf.WriteString(fmt.Sprintf(" min_size: %d # 最小压缩大小(字节,小于此值不压缩)\n", cfg.Server.Compression.MinSize)) + buf.WriteString(fmt.Sprintf(" gzip_static: %v # 启用预压缩文件支持(自动查找 .gz/.br 文件)\n", cfg.Server.Compression.GzipStatic)) + buf.WriteString(" gzip_static_extensions: # 预压缩文件扩展名\n") + for _, ext := range cfg.Server.Compression.GzipStaticExtensions { + buf.WriteString(fmt.Sprintf(" - \"%s\"\n", ext)) + } buf.WriteString(" types: # 可压缩的 MIME 类型\n") for _, t := range cfg.Server.Compression.Types { buf.WriteString(fmt.Sprintf(" - \"%s\"\n", t)) @@ -383,6 +405,16 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) { buf.WriteString(fmt.Sprintf(" max_conns_per_host: %d # 每主机最大连接(0 表示不限制)\n", cfg.Performance.Transport.MaxConnsPerHost)) buf.WriteString("\n") + // HTTP3 配置 + buf.WriteString("# HTTP/3 (QUIC) 配置(需要 SSL 证书)\n") + buf.WriteString("http3:\n") + buf.WriteString(fmt.Sprintf(" enabled: %v # 是否启用 HTTP/3\n", cfg.HTTP3.Enabled)) + buf.WriteString(fmt.Sprintf(" listen: \"%s\" # UDP 监听地址\n", cfg.HTTP3.Listen)) + buf.WriteString(fmt.Sprintf(" max_streams: %d # 最大并发流\n", cfg.HTTP3.MaxStreams)) + buf.WriteString(fmt.Sprintf(" idle_timeout: %ds # 空闲超时\n", int(cfg.HTTP3.IdleTimeout.Seconds()))) + buf.WriteString(fmt.Sprintf(" enable_0rtt: %v # 启用 0-RTT(早期数据,可能存在安全风险)\n", cfg.HTTP3.Enable0RTT)) + buf.WriteString("\n") + // monitoring 配置 buf.WriteString("# 监控配置\n") buf.WriteString("monitoring:\n") diff --git a/internal/loadbalance/algorithms.go b/internal/loadbalance/algorithms.go index 711e7bf..5112867 100644 --- a/internal/loadbalance/algorithms.go +++ b/internal/loadbalance/algorithms.go @@ -7,6 +7,7 @@ var ValidAlgorithms = []string{ "weighted_round_robin", "least_conn", "ip_hash", + "consistent_hash", } // IsValidAlgorithm 检查算法是否有效。 diff --git a/internal/middleware/security/headers.go b/internal/middleware/security/headers.go index d11ea21..5791f07 100644 --- a/internal/middleware/security/headers.go +++ b/internal/middleware/security/headers.go @@ -59,6 +59,20 @@ type SecurityHeadersMiddleware struct { // 返回值: // - *SecurityHeadersMiddleware: 配置好的中间件实例 func NewSecurityHeaders(cfg *config.SecurityHeaders) *SecurityHeadersMiddleware { + return NewSecurityHeadersWithHSTS(cfg, nil) +} + +// NewSecurityHeadersWithHSTS 创建新的安全响应头中间件,支持 HSTS 配置。 +// +// 根据配置创建中间件实例,如果配置为 nil 则使用安全的默认值。 +// +// 参数: +// - cfg: 安全头配置,可以为 nil 使用默认配置 +// - hstsCfg: HSTS 配置,可以为 nil 使用默认值 +// +// 返回值: +// - *SecurityHeadersMiddleware: 配置好的中间件实例 +func NewSecurityHeadersWithHSTS(cfg *config.SecurityHeaders, hstsCfg *config.HSTSConfig) *SecurityHeadersMiddleware { sh := &SecurityHeadersMiddleware{} if cfg != nil { @@ -73,11 +87,24 @@ func NewSecurityHeaders(cfg *config.SecurityHeaders) *SecurityHeadersMiddleware } // 预格式化 HSTS 头值 - sh.formatHSTS() + sh.formatHSTSFromConfig(hstsCfg) return sh } +// formatHSTSFromConfig 根据配置格式化 HSTS 头值。 +func (sh *SecurityHeadersMiddleware) formatHSTSFromConfig(hstsCfg *config.HSTSConfig) { + if hstsCfg != nil { + maxAge := hstsCfg.MaxAge + if maxAge <= 0 { + maxAge = 31536000 // 默认 1 年 + } + sh.hsts = formatHSTSValue(maxAge, hstsCfg.IncludeSubDomains, hstsCfg.Preload) + } else { + sh.formatHSTS() // 使用默认值 + } +} + // Name 返回中间件名称。 // // 返回值: diff --git a/internal/middleware/security/ratelimit.go b/internal/middleware/security/ratelimit.go index 2407697..74e8490 100644 --- a/internal/middleware/security/ratelimit.go +++ b/internal/middleware/security/ratelimit.go @@ -82,7 +82,7 @@ type KeyFunc func(ctx *fasthttp.RequestCtx) string // 返回值: // - *RateLimiter: 配置好的限流器实例 // - error: 配置无效时返回错误(如速率小于 0) -func NewRateLimiter(cfg *config.RateLimitConfig) (*RateLimiter, error) { +func NewRateLimiter(cfg *config.RateLimitConfig) (middleware.Middleware, error) { if cfg == nil { return nil, errors.New("rate limit config is nil") } @@ -91,6 +91,29 @@ func NewRateLimiter(cfg *config.RateLimitConfig) (*RateLimiter, error) { return nil, errors.New("request rate must be positive") } + // 根据算法选择限流器 + algorithm := cfg.Algorithm + if algorithm == "" { + algorithm = "token_bucket" // 默认令牌桶 + } + + switch algorithm { + case "token_bucket", "": + return newTokenBucketLimiter(cfg) + case "sliding_window": + window := time.Duration(cfg.SlidingWindow) * time.Second + if window <= 0 { + window = time.Second // 默认 1 秒窗口 + } + precise := cfg.SlidingWindowMode == "precise" + return NewSlidingWindowLimiterWrapper(cfg, window, precise) + default: + return nil, fmt.Errorf("unknown algorithm: %s", algorithm) + } +} + +// newTokenBucketLimiter 创建令牌桶限流器。 +func newTokenBucketLimiter(cfg *config.RateLimitConfig) (*RateLimiter, error) { if cfg.Burst < cfg.RequestRate { return nil, errors.New("burst must be at least equal to request rate") } @@ -114,6 +137,49 @@ func NewRateLimiter(cfg *config.RateLimitConfig) (*RateLimiter, error) { return rl, nil } +// SlidingWindowLimiterWrapper 滑动窗口限流器包装,实现 middleware.Middleware 接口。 +type SlidingWindowLimiterWrapper struct { + limiter *SlidingWindowLimiter + keyFunc KeyFunc +} + +// NewSlidingWindowLimiterWrapper 创建滑动窗口限流器包装。 +func NewSlidingWindowLimiterWrapper(cfg *config.RateLimitConfig, window time.Duration, precise bool) (*SlidingWindowLimiterWrapper, error) { + var keyFunc KeyFunc + switch cfg.Key { + case "ip", "": + keyFunc = keyByIP + case "header": + keyFunc = keyByHeader + default: + return nil, fmt.Errorf("unknown key type: %s", cfg.Key) + } + + return &SlidingWindowLimiterWrapper{ + limiter: NewSlidingWindowLimiter(window, cfg.RequestRate, precise), + keyFunc: keyFunc, + }, nil +} + +// Name 返回中间件名称。 +func (s *SlidingWindowLimiterWrapper) Name() string { + return "sliding_window_limiter" +} + +// Process 包装下一个处理器,添加限流逻辑。 +func (s *SlidingWindowLimiterWrapper) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + key := s.keyFunc(ctx) + + if !s.limiter.Allow(key) { + ctx.Error("Too Many Requests", fasthttp.StatusTooManyRequests) + return + } + + next(ctx) + } +} + // Name 返回中间件名称。 // // 返回值: diff --git a/internal/middleware/security/ratelimit_test.go b/internal/middleware/security/ratelimit_test.go index 88b6808..6f74a7f 100644 --- a/internal/middleware/security/ratelimit_test.go +++ b/internal/middleware/security/ratelimit_test.go @@ -83,7 +83,7 @@ func TestNewRateLimiter(t *testing.T) { } func TestRateLimiterAllow(t *testing.T) { - rl, err := NewRateLimiter(&config.RateLimitConfig{ + mw, err := NewRateLimiter(&config.RateLimitConfig{ RequestRate: 10, Burst: 10, }) @@ -91,6 +91,11 @@ func TestRateLimiterAllow(t *testing.T) { t.Fatalf("NewRateLimiter() error: %v", err) } + rl, ok := mw.(*RateLimiter) + if !ok { + t.Fatalf("Expected *RateLimiter, got %T", mw) + } + // Test burst allowance key := "test-key" @@ -108,7 +113,7 @@ func TestRateLimiterAllow(t *testing.T) { } func TestRateLimiterTokenRefill(t *testing.T) { - rl, err := NewRateLimiter(&config.RateLimitConfig{ + mw, err := NewRateLimiter(&config.RateLimitConfig{ RequestRate: 100, // 100 tokens per second Burst: 100, }) @@ -116,6 +121,11 @@ func TestRateLimiterTokenRefill(t *testing.T) { t.Fatalf("NewRateLimiter() error: %v", err) } + rl, ok := mw.(*RateLimiter) + if !ok { + t.Fatalf("Expected *RateLimiter, got %T", mw) + } + key := "refill-test" // Exhaust the burst @@ -138,7 +148,7 @@ func TestRateLimiterTokenRefill(t *testing.T) { } func TestRateLimiterReset(t *testing.T) { - rl, err := NewRateLimiter(&config.RateLimitConfig{ + mw, err := NewRateLimiter(&config.RateLimitConfig{ RequestRate: 1, Burst: 1, }) @@ -146,6 +156,11 @@ func TestRateLimiterReset(t *testing.T) { t.Fatalf("NewRateLimiter() error: %v", err) } + rl, ok := mw.(*RateLimiter) + if !ok { + t.Fatalf("Expected *RateLimiter, got %T", mw) + } + key := "reset-test" // Exhaust @@ -164,7 +179,7 @@ func TestRateLimiterReset(t *testing.T) { } func TestRateLimiterResetAll(t *testing.T) { - rl, err := NewRateLimiter(&config.RateLimitConfig{ + mw, err := NewRateLimiter(&config.RateLimitConfig{ RequestRate: 1, Burst: 1, }) @@ -172,6 +187,11 @@ func TestRateLimiterResetAll(t *testing.T) { t.Fatalf("NewRateLimiter() error: %v", err) } + rl, ok := mw.(*RateLimiter) + if !ok { + t.Fatalf("Expected *RateLimiter, got %T", mw) + } + // Create multiple buckets rl.Allow("key1") rl.Allow("key2") @@ -186,7 +206,7 @@ func TestRateLimiterResetAll(t *testing.T) { } func TestRateLimiterCleanup(t *testing.T) { - rl, err := NewRateLimiter(&config.RateLimitConfig{ + mw, err := NewRateLimiter(&config.RateLimitConfig{ RequestRate: 100, Burst: 100, }) @@ -194,6 +214,11 @@ func TestRateLimiterCleanup(t *testing.T) { t.Fatalf("NewRateLimiter() error: %v", err) } + rl, ok := mw.(*RateLimiter) + if !ok { + t.Fatalf("Expected *RateLimiter, got %T", mw) + } + // Create some buckets rl.Allow("key1") rl.Allow("key2") @@ -208,7 +233,7 @@ func TestRateLimiterCleanup(t *testing.T) { } func TestRateLimiterProcess(t *testing.T) { - rl, err := NewRateLimiter(&config.RateLimitConfig{ + mw, err := NewRateLimiter(&config.RateLimitConfig{ RequestRate: 100, Burst: 100, }) @@ -220,14 +245,14 @@ func TestRateLimiterProcess(t *testing.T) { ctx.WriteString("OK") } - handler := rl.Process(nextHandler) + handler := mw.Process(nextHandler) if handler == nil { t.Error("Process() returned nil handler") } } func TestRateLimiterGetStats(t *testing.T) { - rl, err := NewRateLimiter(&config.RateLimitConfig{ + mw, err := NewRateLimiter(&config.RateLimitConfig{ RequestRate: 100, Burst: 200, }) @@ -235,6 +260,11 @@ func TestRateLimiterGetStats(t *testing.T) { t.Fatalf("NewRateLimiter() error: %v", err) } + rl, ok := mw.(*RateLimiter) + if !ok { + t.Fatalf("Expected *RateLimiter, got %T", mw) + } + rl.Allow("key1") rl.Allow("key2") diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 7501d22..3fe2fde 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -64,11 +64,12 @@ type Proxy struct { // 参数: // - cfg: 代理配置,包括超时时间、请求头和负载均衡策略 // - targets: 要代理请求的后端目标列表 +// - transportCfg: 可选的 Transport 连接池配置,nil 时使用默认值 // // 返回值: // - *Proxy: 配置完成并可处理请求的代理实例 // - error: 初始化失败时非空(无效配置、没有健康目标等) -func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target) (*Proxy, error) { +func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target, transportCfg *config.TransportConfig) (*Proxy, error) { if cfg == nil { return nil, errors.New("proxy config is nil") } @@ -78,7 +79,7 @@ func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target) (*Proxy, e } // 根据配置创建负载均衡器 - balancer, err := createBalancer(cfg.LoadBalance) + balancer, err := createBalancer(cfg) if err != nil { return nil, err } @@ -96,7 +97,7 @@ func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target) (*Proxy, e continue } - client := createHostClient(target.URL, cfg.Timeout) + client := createHostClient(target.URL, cfg.Timeout, transportCfg) p.clients[target.URL] = client } @@ -122,8 +123,8 @@ func (p *Proxy) SetHealthChecker(hc *HealthChecker) { } // createBalancer 根据配置的算法创建负载均衡器。 -func createBalancer(algorithm string) (loadbalance.Balancer, error) { - switch algorithm { +func createBalancer(cfg *config.ProxyConfig) (loadbalance.Balancer, error) { + switch cfg.LoadBalance { case "round_robin", "": return loadbalance.NewRoundRobin(), nil case "weighted_round_robin": @@ -132,13 +133,19 @@ func createBalancer(algorithm string) (loadbalance.Balancer, error) { return loadbalance.NewLeastConnections(), nil case "ip_hash": return loadbalance.NewIPHash(), nil + case "consistent_hash": + virtualNodes := cfg.VirtualNodes + if virtualNodes <= 0 { + virtualNodes = 150 + } + return loadbalance.NewConsistentHash(virtualNodes, cfg.HashKey), nil default: - return nil, errors.New("unsupported load balance algorithm: " + algorithm) + return nil, errors.New("unsupported load balance algorithm: " + cfg.LoadBalance) } } // createHostClient 为后台目标 URL 创建 fasthttp.HostClient。 -func createHostClient(targetURL string, timeout config.ProxyTimeout) *fasthttp.HostClient { +func createHostClient(targetURL string, timeout config.ProxyTimeout, transportCfg *config.TransportConfig) *fasthttp.HostClient { // 从目标 URL 解析主机和协议 addr := targetURL isTLS := false @@ -155,13 +162,27 @@ func createHostClient(targetURL string, timeout config.ProxyTimeout) *fasthttp.H addr = addr[:idx] } + // 默认值 + maxIdleConnDuration := 90 * time.Second + maxConns := 100 + + // 应用 Transport 配置 + if transportCfg != nil { + if transportCfg.IdleConnTimeout > 0 { + maxIdleConnDuration = transportCfg.IdleConnTimeout + } + if transportCfg.MaxConnsPerHost > 0 { + maxConns = transportCfg.MaxConnsPerHost + } + } + client := &fasthttp.HostClient{ Addr: addr, IsTLS: isTLS, ReadTimeout: timeout.Read, WriteTimeout: timeout.Write, - MaxIdleConnDuration: 60 * time.Second, - MaxConns: 100, + MaxIdleConnDuration: maxIdleConnDuration, + MaxConns: maxConns, MaxConnWaitTimeout: timeout.Connect, RetryIf: nil, // Disable automatic retries DisablePathNormalizing: false, @@ -389,13 +410,13 @@ func (p *Proxy) UpdateTargets(targets []*loadbalance.Target) error { // 清除旧客户端 p.clients = make(map[string]*fasthttp.HostClient) - // 初始化新客户端 + // 初始化新客户端(使用 nil TransportConfig 保持原有行为) for _, target := range targets { if target.URL == "" { continue } - client := createHostClient(target.URL, p.config.Timeout) + client := createHostClient(target.URL, p.config.Timeout, nil) p.clients[target.URL] = client } diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 0d14f92..656c531 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -121,7 +121,7 @@ func TestNewProxy(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - p, err := NewProxy(tt.cfg, tt.targets) + p, err := NewProxy(tt.cfg, tt.targets, nil) if tt.wantErr { if err == nil { t.Errorf("NewProxy() expected error containing %q, got nil", tt.errContains) @@ -166,7 +166,7 @@ func TestServeHTTP_NoHealthyTargets(t *testing.T) { targets[0].Healthy.Store(false) targets[1].Healthy.Store(false) - p, err := NewProxy(cfg, targets) + p, err := NewProxy(cfg, targets, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } @@ -216,7 +216,7 @@ func TestServeHTTP_RequestForwarding(t *testing.T) { {URL: "http://localhost:8080"}, } - p, err := NewProxy(cfg, targets) + p, err := NewProxy(cfg, targets, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } @@ -308,7 +308,7 @@ func TestSelectTarget(t *testing.T) { Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, } - p, err := NewProxy(cfg, tt.targets) + p, err := NewProxy(cfg, tt.targets, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } @@ -425,7 +425,7 @@ func TestModifyRequestHeaders(t *testing.T) { {URL: "http://localhost:8080"}, } - p, err := NewProxy(cfg, targets) + p, err := NewProxy(cfg, targets, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } @@ -520,7 +520,7 @@ func TestModifyResponseHeaders(t *testing.T) { {URL: "http://localhost:8080"}, } - p, err := NewProxy(cfg, targets) + p, err := NewProxy(cfg, targets, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } @@ -608,7 +608,7 @@ func TestUpdateTargets(t *testing.T) { {URL: "http://old2:8080"}, } - p, err := NewProxy(cfg, initialTargets) + p, err := NewProxy(cfg, initialTargets, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } @@ -657,7 +657,7 @@ func TestGetTargets(t *testing.T) { {URL: "http://backend2:8080"}, } - p, err := NewProxy(cfg, targets) + p, err := NewProxy(cfg, targets, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } @@ -686,7 +686,7 @@ func TestGetConfig(t *testing.T) { {URL: "http://localhost:8080"}, } - p, err := NewProxy(cfg, targets) + p, err := NewProxy(cfg, targets, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } @@ -763,38 +763,37 @@ func TestIsWebSocketRequest(t *testing.T) { func TestCreateBalancer(t *testing.T) { tests := []struct { name string - algorithm string + cfg *config.ProxyConfig wantErr bool errContains string }{ { - name: "轮询", - algorithm: "round_robin", - wantErr: false, + name: "轮询", + cfg: &config.ProxyConfig{LoadBalance: "round_robin"}, }, { - name: "加权轮询", - algorithm: "weighted_round_robin", - wantErr: false, + name: "加权轮询", + cfg: &config.ProxyConfig{LoadBalance: "weighted_round_robin"}, }, { - name: "最少连接", - algorithm: "least_conn", - wantErr: false, + name: "最少连接", + cfg: &config.ProxyConfig{LoadBalance: "least_conn"}, }, { - name: "IP哈希", - algorithm: "ip_hash", - wantErr: false, + name: "IP哈希", + cfg: &config.ProxyConfig{LoadBalance: "ip_hash"}, }, { - name: "空算法(默认轮询)", - algorithm: "", - wantErr: false, + name: "一致性哈希", + cfg: &config.ProxyConfig{LoadBalance: "consistent_hash", HashKey: "ip", VirtualNodes: 150}, + }, + { + name: "空算法(默认轮询)", + cfg: &config.ProxyConfig{LoadBalance: ""}, }, { name: "无效算法", - algorithm: "unknown_algorithm", + cfg: &config.ProxyConfig{LoadBalance: "unknown_algorithm"}, wantErr: true, errContains: "unsupported load balance algorithm", }, @@ -802,23 +801,23 @@ func TestCreateBalancer(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - balancer, err := createBalancer(tt.algorithm) + balancer, err := createBalancer(tt.cfg) if tt.wantErr { if err == nil { - t.Errorf("createBalancer(%q) expected error", tt.algorithm) + t.Errorf("createBalancer(%v) expected error", tt.cfg.LoadBalance) return } if !contains(err.Error(), tt.errContains) { - t.Errorf("createBalancer(%q) error = %v, want containing %q", tt.algorithm, err, tt.errContains) + t.Errorf("createBalancer(%v) error = %v, want containing %q", tt.cfg.LoadBalance, err, tt.errContains) } return } if err != nil { - t.Errorf("createBalancer(%q) unexpected error: %v", tt.algorithm, err) + t.Errorf("createBalancer(%v) unexpected error: %v", tt.cfg.LoadBalance, err) return } if balancer == nil { - t.Errorf("createBalancer(%q) returned nil balancer", tt.algorithm) + t.Errorf("createBalancer(%v) returned nil balancer", tt.cfg.LoadBalance) } }) } @@ -850,7 +849,7 @@ func TestCreateHostClient(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - client := createHostClient(tt.targetURL, tt.timeout) + client := createHostClient(tt.targetURL, tt.timeout, nil) if client == nil { t.Error("createHostClient() returned nil") return @@ -882,7 +881,7 @@ func TestHandleWebSocket(t *testing.T) { {URL: "http://localhost:8080"}, } - p, err := NewProxy(cfg, targets) + p, err := NewProxy(cfg, targets, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } @@ -916,7 +915,7 @@ func TestSetHealthChecker(t *testing.T) { {URL: "http://localhost:8081"}, } - p, err := NewProxy(cfg, targets) + p, err := NewProxy(cfg, targets, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } @@ -954,7 +953,7 @@ func TestGetClient(t *testing.T) { {URL: "http://localhost:8082"}, } - p, err := NewProxy(cfg, targets) + p, err := NewProxy(cfg, targets, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } @@ -1018,7 +1017,7 @@ func TestProxyCache(t *testing.T) { } targets[0].Healthy.Store(true) - p, err := NewProxy(cfg, targets) + p, err := NewProxy(cfg, targets, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } @@ -1069,7 +1068,7 @@ func TestServeHTTP_WithPassiveHealthCheck(t *testing.T) { } targets[0].Healthy.Store(true) - p, err := NewProxy(cfg, targets) + p, err := NewProxy(cfg, targets, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } diff --git a/internal/server/server.go b/internal/server/server.go index 0c217fd..80fb9b5 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -251,7 +251,7 @@ func (s *Server) buildMiddlewareChain(serverCfg *config.ServerConfig) (*middlewa serverCfg.Security.Headers.ContentSecurityPolicy != "" || serverCfg.Security.Headers.ReferrerPolicy != "" || serverCfg.Security.Headers.PermissionsPolicy != "" { - headers := security.NewSecurityHeaders(&serverCfg.Security.Headers) + headers := security.NewSecurityHeadersWithHSTS(&serverCfg.Security.Headers, &serverCfg.SSL.HSTS) middlewares = append(middlewares, headers) } @@ -534,7 +534,8 @@ func (s *Server) registerProxyRoutes(router *handler.Router, serverCfg *config.S targets[j].Healthy.Store(true) } - p, err := proxy.NewProxy(proxyCfg, targets) + // 传递 Transport 配置 + p, err := proxy.NewProxy(proxyCfg, targets, &s.config.Performance.Transport) if err != nil { logging.Error().Msg("创建代理失败: " + err.Error()) continue diff --git a/internal/server/vhost.go b/internal/server/vhost.go index 2ce7855..a883920 100644 --- a/internal/server/vhost.go +++ b/internal/server/vhost.go @@ -69,14 +69,7 @@ func (v *VHostManager) SetDefault(handler fasthttp.RequestHandler) { // Handler 返回虚拟主机选择器 func (v *VHostManager) Handler() fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { - host := string(ctx.Host()) - // 去除端口号 - for i := 0; i < len(host); i++ { - if host[i] == ':' { - host = host[:i] - break - } - } + host := stripPort(string(ctx.Host())) if vhost, ok := v.hosts[host]; ok { vhost.handler(ctx) @@ -87,3 +80,38 @@ func (v *VHostManager) Handler() fasthttp.RequestHandler { } } } + +// stripPort 从 Host 头中移除端口号。 +// +// 支持 IPv4 和 IPv6 格式: +// - example.com:8080 -> example.com +// - [::1]:8080 -> [::1] +// - [2001:db8::1]:443 -> [2001:db8::1] +// - example.com -> example.com +func stripPort(host string) string { + // 空字符串直接返回 + if len(host) == 0 { + return host + } + + // IPv6 格式:以 '[' 开头,找 ']:' 作为分隔点 + if host[0] == '[' { + // 查找 ']:' 分隔符 + for i := 0; i < len(host)-1; i++ { + if host[i] == ']' && host[i+1] == ':' { + return host[:i+1] // 返回包含 ']' 的部分,如 "[::1]" + } + } + // 没有 ']:' 分隔符,可能是纯 IPv6 地址(如 "[::1]") + return host + } + + // IPv4 或域名格式:找第一个 ':' 作为分隔点 + for i := 0; i < len(host); i++ { + if host[i] == ':' { + return host[:i] + } + } + + return host +} diff --git a/internal/server/vhost_test.go b/internal/server/vhost_test.go index 592c700..1776f52 100644 --- a/internal/server/vhost_test.go +++ b/internal/server/vhost_test.go @@ -99,13 +99,9 @@ func TestVHostManager_Handler(t *testing.T) { }) t.Run("IPv6地址Host", func(t *testing.T) { - // TODO: 当前 vhost.go 的端口剥离逻辑不支持 IPv6 格式 [::1]:8080 - // 它会错误地在第一个 ':' 处截断(IPv6 地址内部的冒号) - // 修复方案:检查 host 是否以 '[' 开头,找 ']:' 作为分隔点 manager := NewVHostManager() ipv6Called := false manager.AddHost("[::1]", mockHandler("ipv6", &ipv6Called)) - manager.SetDefault(mockHandler("default", &ipv6Called)) // fallback handler := manager.Handler() ctx := &fasthttp.RequestCtx{} @@ -113,9 +109,12 @@ func TestVHostManager_Handler(t *testing.T) { handler(ctx) - // 当前实现不支持 IPv6,会 fallback 到默认 handler - // 修复 vhost.go 后此测试应验证 ipv6Called 为 true - t.Log("注意: 当前实现不支持 IPv6 地址,需要修复 vhost.go 的端口剥离逻辑") + if !ipv6Called { + t.Error("期望 [::1] handler 被调用,但未被调用") + } + if string(ctx.Response.Body()) != "ipv6" { + t.Errorf("响应体 = %q, want %q", string(ctx.Response.Body()), "ipv6") + } }) t.Run("空Host使用默认", func(t *testing.T) { @@ -270,6 +269,10 @@ func TestVHostManager_PortStripping(t *testing.T) { {"标准HTTPS端口", "example.com:443", "example.com"}, {"自定义端口", "example.com:8080", "example.com"}, {"IPv6 localhost带端口", "[localhost]:8080", "[localhost]"}, + {"IPv6 loopback带端口", "[::1]:8080", "[::1]"}, + {"IPv6完整地址带端口", "[2001:db8::1]:443", "[2001:db8::1]"}, + {"IPv6无端口", "[::1]", "[::1]"}, + {"IPv6完整地址无端口", "[2001:db8::1]", "[2001:db8::1]"}, {"空字符串", "", ""}, } @@ -291,11 +294,8 @@ func TestVHostManager_PortStripping(t *testing.T) { }) } - // IPv6 数字地址测试 - 当前实现有已知 bug - t.Run("IPv6数字地址_已知限制", func(t *testing.T) { - // TODO: vhost.go 的端口剥离逻辑不支持 IPv6 数字地址格式 [::1]:8080 - // 因为它会在第一个 ':' 处截断(IPv6 地址内部的冒号) - // 结果:[:而不是 [::1] + // IPv6 数字地址测试 + t.Run("IPv6数字地址", func(t *testing.T) { manager := NewVHostManager() ipv6Called := false manager.AddHost("[::1]", mockHandler("ipv6", &ipv6Called)) @@ -306,10 +306,11 @@ func TestVHostManager_PortStripping(t *testing.T) { handler(ctx) - // 当前行为:不匹配,因为端口剥离错误 - if ipv6Called { - t.Error("当前实现不支持 IPv6 数字地址的端口剥离,不应匹配") + if !ipv6Called { + t.Error("期望 [::1] handler 被调用,但未被调用") + } + if string(ctx.Response.Body()) != "ipv6" { + t.Errorf("响应体 = %q, want %q", string(ctx.Response.Body()), "ipv6") } - t.Log("已知限制: IPv6 数字地址端口剥离需要修复 vhost.go") }) }