feat(proxy,middleware,config): 集成配置与代码差异修复
- 集成一致性哈希负载均衡到 proxy.go,支持 hash_key 和 virtual_nodes 配置 - 集成滑动窗口限流算法到 ratelimit.go,支持 algorithm 选择 - 应用 Transport 连接池配置到 createHostClient - 集成 HSTS 配置到安全头部中间件 - 补充 config.example.yaml 缺失配置(http3、gzip_static、sliding_window) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
d6367a1c38
commit
ec916d882d
@ -25,7 +25,9 @@ server:
|
||||
# weight: 3 # 权重(加权轮询时有效)
|
||||
# - url: http://backend2:8080
|
||||
# weight: 1
|
||||
# load_balance: round_robin # 负载均衡算法: round_robin, weighted_round_robin, least_conn, ip_hash
|
||||
# load_balance: round_robin # 负载均衡算法: round_robin, weighted_round_robin, least_conn, ip_hash, consistent_hash
|
||||
# hash_key: "ip" # 一致性哈希键(有效值: ip, uri, header:X-Name)
|
||||
# virtual_nodes: 150 # 一致性哈希虚拟节点数
|
||||
# health_check: # 健康检查
|
||||
# interval: 10s
|
||||
# path: /health
|
||||
@ -73,6 +75,9 @@ server:
|
||||
burst: 0 # 突发上限
|
||||
conn_limit: 0 # 连接数限制
|
||||
key: "ip" # 限流 key 来源(有效值: ip, header)
|
||||
algorithm: "token_bucket" # 限流算法(有效值: token_bucket, sliding_window)
|
||||
sliding_window_mode: "approximate" # 滑动窗口模式(有效值: approximate, precise)
|
||||
sliding_window: 60 # 滑动窗口大小(秒)
|
||||
|
||||
# 认证配置(type 为空时禁用)
|
||||
auth:
|
||||
@ -108,6 +113,14 @@ server:
|
||||
- "text/javascript"
|
||||
- "application/json"
|
||||
- "application/javascript"
|
||||
gzip_static: true # 启用预压缩文件支持(检测 .gz 文件)
|
||||
gzip_static_extensions: # 预压缩文件扩展名
|
||||
- ".html"
|
||||
- ".css"
|
||||
- ".js"
|
||||
- ".json"
|
||||
- ".xml"
|
||||
- ".svg"
|
||||
|
||||
# 多虚拟主机模式(可选,每个虚拟主机支持完整的 server 配置)
|
||||
# servers:
|
||||
@ -158,6 +171,14 @@ server:
|
||||
# 默认 TLS 协议: TLSv1.2, TLSv1.3(不支持 TLSv1.0/1.1)
|
||||
# 默认 HSTS 配置: max_age=31536000(1年), include_sub_domains=true
|
||||
|
||||
# HTTP/3 (QUIC) 配置(可选)
|
||||
# http3:
|
||||
# enabled: false # 是否启用 HTTP/3
|
||||
# listen: ":443" # UDP 监听地址
|
||||
# max_streams: 100 # 最大并发流数
|
||||
# idle_timeout: 30s # 空闲超时
|
||||
# enable_0rtt: false # 启用 0-RTT 快速连接
|
||||
|
||||
# TCP/UDP Stream 代理配置(可选)
|
||||
# stream:
|
||||
# - listen: "3306" # 监听地址
|
||||
|
||||
@ -65,10 +65,13 @@ func DefaultConfig() *Config {
|
||||
Default: "allow",
|
||||
},
|
||||
RateLimit: RateLimitConfig{
|
||||
RequestRate: 0,
|
||||
Burst: 0,
|
||||
ConnLimit: 0,
|
||||
Key: "ip",
|
||||
RequestRate: 0,
|
||||
Burst: 0,
|
||||
ConnLimit: 0,
|
||||
Key: "ip",
|
||||
Algorithm: "token_bucket",
|
||||
SlidingWindowMode: "approximate",
|
||||
SlidingWindow: 60,
|
||||
},
|
||||
Auth: AuthConfig{
|
||||
RequireTLS: true,
|
||||
@ -83,9 +86,11 @@ func DefaultConfig() *Config {
|
||||
},
|
||||
},
|
||||
Compression: CompressionConfig{
|
||||
Type: "gzip",
|
||||
Level: 6,
|
||||
MinSize: 1024,
|
||||
Type: "gzip",
|
||||
Level: 6,
|
||||
MinSize: 1024,
|
||||
GzipStatic: false,
|
||||
GzipStaticExtensions: []string{".gz", ".br"},
|
||||
Types: []string{
|
||||
"text/html",
|
||||
"text/css",
|
||||
@ -133,6 +138,13 @@ func DefaultConfig() *Config {
|
||||
Allow: []string{"127.0.0.1"},
|
||||
},
|
||||
},
|
||||
HTTP3: HTTP3Config{
|
||||
Enabled: false,
|
||||
Listen: ":443",
|
||||
MaxStreams: 100,
|
||||
IdleTimeout: 60 * time.Second,
|
||||
Enable0RTT: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@ -190,7 +202,9 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
|
||||
buf.WriteString(" # weight: 3 # 权重(加权轮询时有效)\n")
|
||||
buf.WriteString(" # - url: http://backend2:8080\n")
|
||||
buf.WriteString(" # weight: 1\n")
|
||||
buf.WriteString(" # load_balance: round_robin # 负载均衡算法: round_robin, weighted_round_robin, least_conn, ip_hash\n")
|
||||
buf.WriteString(" # load_balance: round_robin # 负载均衡算法(有效值: round_robin, weighted_round_robin, least_conn, ip_hash, consistent_hash)\n")
|
||||
buf.WriteString(" # hash_key: ip # 一致性哈希键(仅 load_balance=consistent_hash 时有效,有效值: ip, uri, header:X-Name)\n")
|
||||
buf.WriteString(" # virtual_nodes: 150 # 一致性哈希虚拟节点数(仅 load_balance=consistent_hash 时有效)\n")
|
||||
buf.WriteString(" # health_check: # 健康检查\n")
|
||||
buf.WriteString(" # interval: 10s\n")
|
||||
buf.WriteString(" # path: /health\n")
|
||||
@ -243,6 +257,9 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
|
||||
buf.WriteString(fmt.Sprintf(" burst: %d # 突发上限\n", cfg.Server.Security.RateLimit.Burst))
|
||||
buf.WriteString(fmt.Sprintf(" conn_limit: %d # 连接数限制\n", cfg.Server.Security.RateLimit.ConnLimit))
|
||||
buf.WriteString(fmt.Sprintf(" key: \"%s\" # 限流 key 来源(有效值: ip, header)\n", cfg.Server.Security.RateLimit.Key))
|
||||
buf.WriteString(fmt.Sprintf(" algorithm: \"%s\" # 限流算法(有效值: token_bucket, sliding_window)\n", cfg.Server.Security.RateLimit.Algorithm))
|
||||
buf.WriteString(fmt.Sprintf(" sliding_window_mode: \"%s\" # 滑动窗口模式(有效值: approximate, precise,仅 algorithm=sliding_window 时有效)\n", cfg.Server.Security.RateLimit.SlidingWindowMode))
|
||||
buf.WriteString(fmt.Sprintf(" sliding_window: %d # 滑动窗口大小(秒,仅 algorithm=sliding_window 时有效)\n", cfg.Server.Security.RateLimit.SlidingWindow))
|
||||
buf.WriteString("\n")
|
||||
buf.WriteString(" # 认证配置(type 为空时禁用)\n")
|
||||
buf.WriteString(" auth:\n")
|
||||
@ -276,6 +293,11 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
|
||||
buf.WriteString(fmt.Sprintf(" type: \"%s\" # 压缩类型(有效值: gzip, brotli, both,空表示禁用)\n", cfg.Server.Compression.Type))
|
||||
buf.WriteString(fmt.Sprintf(" level: %d # 压缩级别(范围 1-9,值越大压缩率越高但速度越慢)\n", cfg.Server.Compression.Level))
|
||||
buf.WriteString(fmt.Sprintf(" min_size: %d # 最小压缩大小(字节,小于此值不压缩)\n", cfg.Server.Compression.MinSize))
|
||||
buf.WriteString(fmt.Sprintf(" gzip_static: %v # 启用预压缩文件支持(自动查找 .gz/.br 文件)\n", cfg.Server.Compression.GzipStatic))
|
||||
buf.WriteString(" gzip_static_extensions: # 预压缩文件扩展名\n")
|
||||
for _, ext := range cfg.Server.Compression.GzipStaticExtensions {
|
||||
buf.WriteString(fmt.Sprintf(" - \"%s\"\n", ext))
|
||||
}
|
||||
buf.WriteString(" types: # 可压缩的 MIME 类型\n")
|
||||
for _, t := range cfg.Server.Compression.Types {
|
||||
buf.WriteString(fmt.Sprintf(" - \"%s\"\n", t))
|
||||
@ -383,6 +405,16 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
|
||||
buf.WriteString(fmt.Sprintf(" max_conns_per_host: %d # 每主机最大连接(0 表示不限制)\n", cfg.Performance.Transport.MaxConnsPerHost))
|
||||
buf.WriteString("\n")
|
||||
|
||||
// HTTP3 配置
|
||||
buf.WriteString("# HTTP/3 (QUIC) 配置(需要 SSL 证书)\n")
|
||||
buf.WriteString("http3:\n")
|
||||
buf.WriteString(fmt.Sprintf(" enabled: %v # 是否启用 HTTP/3\n", cfg.HTTP3.Enabled))
|
||||
buf.WriteString(fmt.Sprintf(" listen: \"%s\" # UDP 监听地址\n", cfg.HTTP3.Listen))
|
||||
buf.WriteString(fmt.Sprintf(" max_streams: %d # 最大并发流\n", cfg.HTTP3.MaxStreams))
|
||||
buf.WriteString(fmt.Sprintf(" idle_timeout: %ds # 空闲超时\n", int(cfg.HTTP3.IdleTimeout.Seconds())))
|
||||
buf.WriteString(fmt.Sprintf(" enable_0rtt: %v # 启用 0-RTT(早期数据,可能存在安全风险)\n", cfg.HTTP3.Enable0RTT))
|
||||
buf.WriteString("\n")
|
||||
|
||||
// monitoring 配置
|
||||
buf.WriteString("# 监控配置\n")
|
||||
buf.WriteString("monitoring:\n")
|
||||
|
||||
@ -7,6 +7,7 @@ var ValidAlgorithms = []string{
|
||||
"weighted_round_robin",
|
||||
"least_conn",
|
||||
"ip_hash",
|
||||
"consistent_hash",
|
||||
}
|
||||
|
||||
// IsValidAlgorithm 检查算法是否有效。
|
||||
|
||||
@ -59,6 +59,20 @@ type SecurityHeadersMiddleware struct {
|
||||
// 返回值:
|
||||
// - *SecurityHeadersMiddleware: 配置好的中间件实例
|
||||
func NewSecurityHeaders(cfg *config.SecurityHeaders) *SecurityHeadersMiddleware {
|
||||
return NewSecurityHeadersWithHSTS(cfg, nil)
|
||||
}
|
||||
|
||||
// NewSecurityHeadersWithHSTS 创建新的安全响应头中间件,支持 HSTS 配置。
|
||||
//
|
||||
// 根据配置创建中间件实例,如果配置为 nil 则使用安全的默认值。
|
||||
//
|
||||
// 参数:
|
||||
// - cfg: 安全头配置,可以为 nil 使用默认配置
|
||||
// - hstsCfg: HSTS 配置,可以为 nil 使用默认值
|
||||
//
|
||||
// 返回值:
|
||||
// - *SecurityHeadersMiddleware: 配置好的中间件实例
|
||||
func NewSecurityHeadersWithHSTS(cfg *config.SecurityHeaders, hstsCfg *config.HSTSConfig) *SecurityHeadersMiddleware {
|
||||
sh := &SecurityHeadersMiddleware{}
|
||||
|
||||
if cfg != nil {
|
||||
@ -73,11 +87,24 @@ func NewSecurityHeaders(cfg *config.SecurityHeaders) *SecurityHeadersMiddleware
|
||||
}
|
||||
|
||||
// 预格式化 HSTS 头值
|
||||
sh.formatHSTS()
|
||||
sh.formatHSTSFromConfig(hstsCfg)
|
||||
|
||||
return sh
|
||||
}
|
||||
|
||||
// formatHSTSFromConfig 根据配置格式化 HSTS 头值。
|
||||
func (sh *SecurityHeadersMiddleware) formatHSTSFromConfig(hstsCfg *config.HSTSConfig) {
|
||||
if hstsCfg != nil {
|
||||
maxAge := hstsCfg.MaxAge
|
||||
if maxAge <= 0 {
|
||||
maxAge = 31536000 // 默认 1 年
|
||||
}
|
||||
sh.hsts = formatHSTSValue(maxAge, hstsCfg.IncludeSubDomains, hstsCfg.Preload)
|
||||
} else {
|
||||
sh.formatHSTS() // 使用默认值
|
||||
}
|
||||
}
|
||||
|
||||
// Name 返回中间件名称。
|
||||
//
|
||||
// 返回值:
|
||||
|
||||
@ -82,7 +82,7 @@ type KeyFunc func(ctx *fasthttp.RequestCtx) string
|
||||
// 返回值:
|
||||
// - *RateLimiter: 配置好的限流器实例
|
||||
// - error: 配置无效时返回错误(如速率小于 0)
|
||||
func NewRateLimiter(cfg *config.RateLimitConfig) (*RateLimiter, error) {
|
||||
func NewRateLimiter(cfg *config.RateLimitConfig) (middleware.Middleware, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("rate limit config is nil")
|
||||
}
|
||||
@ -91,6 +91,29 @@ func NewRateLimiter(cfg *config.RateLimitConfig) (*RateLimiter, error) {
|
||||
return nil, errors.New("request rate must be positive")
|
||||
}
|
||||
|
||||
// 根据算法选择限流器
|
||||
algorithm := cfg.Algorithm
|
||||
if algorithm == "" {
|
||||
algorithm = "token_bucket" // 默认令牌桶
|
||||
}
|
||||
|
||||
switch algorithm {
|
||||
case "token_bucket", "":
|
||||
return newTokenBucketLimiter(cfg)
|
||||
case "sliding_window":
|
||||
window := time.Duration(cfg.SlidingWindow) * time.Second
|
||||
if window <= 0 {
|
||||
window = time.Second // 默认 1 秒窗口
|
||||
}
|
||||
precise := cfg.SlidingWindowMode == "precise"
|
||||
return NewSlidingWindowLimiterWrapper(cfg, window, precise)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown algorithm: %s", algorithm)
|
||||
}
|
||||
}
|
||||
|
||||
// newTokenBucketLimiter 创建令牌桶限流器。
|
||||
func newTokenBucketLimiter(cfg *config.RateLimitConfig) (*RateLimiter, error) {
|
||||
if cfg.Burst < cfg.RequestRate {
|
||||
return nil, errors.New("burst must be at least equal to request rate")
|
||||
}
|
||||
@ -114,6 +137,49 @@ func NewRateLimiter(cfg *config.RateLimitConfig) (*RateLimiter, error) {
|
||||
return rl, nil
|
||||
}
|
||||
|
||||
// SlidingWindowLimiterWrapper 滑动窗口限流器包装,实现 middleware.Middleware 接口。
|
||||
type SlidingWindowLimiterWrapper struct {
|
||||
limiter *SlidingWindowLimiter
|
||||
keyFunc KeyFunc
|
||||
}
|
||||
|
||||
// NewSlidingWindowLimiterWrapper 创建滑动窗口限流器包装。
|
||||
func NewSlidingWindowLimiterWrapper(cfg *config.RateLimitConfig, window time.Duration, precise bool) (*SlidingWindowLimiterWrapper, error) {
|
||||
var keyFunc KeyFunc
|
||||
switch cfg.Key {
|
||||
case "ip", "":
|
||||
keyFunc = keyByIP
|
||||
case "header":
|
||||
keyFunc = keyByHeader
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown key type: %s", cfg.Key)
|
||||
}
|
||||
|
||||
return &SlidingWindowLimiterWrapper{
|
||||
limiter: NewSlidingWindowLimiter(window, cfg.RequestRate, precise),
|
||||
keyFunc: keyFunc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Name 返回中间件名称。
|
||||
func (s *SlidingWindowLimiterWrapper) Name() string {
|
||||
return "sliding_window_limiter"
|
||||
}
|
||||
|
||||
// Process 包装下一个处理器,添加限流逻辑。
|
||||
func (s *SlidingWindowLimiterWrapper) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
||||
return func(ctx *fasthttp.RequestCtx) {
|
||||
key := s.keyFunc(ctx)
|
||||
|
||||
if !s.limiter.Allow(key) {
|
||||
ctx.Error("Too Many Requests", fasthttp.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
next(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// Name 返回中间件名称。
|
||||
//
|
||||
// 返回值:
|
||||
|
||||
@ -83,7 +83,7 @@ func TestNewRateLimiter(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRateLimiterAllow(t *testing.T) {
|
||||
rl, err := NewRateLimiter(&config.RateLimitConfig{
|
||||
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||||
RequestRate: 10,
|
||||
Burst: 10,
|
||||
})
|
||||
@ -91,6 +91,11 @@ func TestRateLimiterAllow(t *testing.T) {
|
||||
t.Fatalf("NewRateLimiter() error: %v", err)
|
||||
}
|
||||
|
||||
rl, ok := mw.(*RateLimiter)
|
||||
if !ok {
|
||||
t.Fatalf("Expected *RateLimiter, got %T", mw)
|
||||
}
|
||||
|
||||
// Test burst allowance
|
||||
key := "test-key"
|
||||
|
||||
@ -108,7 +113,7 @@ func TestRateLimiterAllow(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRateLimiterTokenRefill(t *testing.T) {
|
||||
rl, err := NewRateLimiter(&config.RateLimitConfig{
|
||||
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||||
RequestRate: 100, // 100 tokens per second
|
||||
Burst: 100,
|
||||
})
|
||||
@ -116,6 +121,11 @@ func TestRateLimiterTokenRefill(t *testing.T) {
|
||||
t.Fatalf("NewRateLimiter() error: %v", err)
|
||||
}
|
||||
|
||||
rl, ok := mw.(*RateLimiter)
|
||||
if !ok {
|
||||
t.Fatalf("Expected *RateLimiter, got %T", mw)
|
||||
}
|
||||
|
||||
key := "refill-test"
|
||||
|
||||
// Exhaust the burst
|
||||
@ -138,7 +148,7 @@ func TestRateLimiterTokenRefill(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRateLimiterReset(t *testing.T) {
|
||||
rl, err := NewRateLimiter(&config.RateLimitConfig{
|
||||
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||||
RequestRate: 1,
|
||||
Burst: 1,
|
||||
})
|
||||
@ -146,6 +156,11 @@ func TestRateLimiterReset(t *testing.T) {
|
||||
t.Fatalf("NewRateLimiter() error: %v", err)
|
||||
}
|
||||
|
||||
rl, ok := mw.(*RateLimiter)
|
||||
if !ok {
|
||||
t.Fatalf("Expected *RateLimiter, got %T", mw)
|
||||
}
|
||||
|
||||
key := "reset-test"
|
||||
|
||||
// Exhaust
|
||||
@ -164,7 +179,7 @@ func TestRateLimiterReset(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRateLimiterResetAll(t *testing.T) {
|
||||
rl, err := NewRateLimiter(&config.RateLimitConfig{
|
||||
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||||
RequestRate: 1,
|
||||
Burst: 1,
|
||||
})
|
||||
@ -172,6 +187,11 @@ func TestRateLimiterResetAll(t *testing.T) {
|
||||
t.Fatalf("NewRateLimiter() error: %v", err)
|
||||
}
|
||||
|
||||
rl, ok := mw.(*RateLimiter)
|
||||
if !ok {
|
||||
t.Fatalf("Expected *RateLimiter, got %T", mw)
|
||||
}
|
||||
|
||||
// Create multiple buckets
|
||||
rl.Allow("key1")
|
||||
rl.Allow("key2")
|
||||
@ -186,7 +206,7 @@ func TestRateLimiterResetAll(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRateLimiterCleanup(t *testing.T) {
|
||||
rl, err := NewRateLimiter(&config.RateLimitConfig{
|
||||
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||||
RequestRate: 100,
|
||||
Burst: 100,
|
||||
})
|
||||
@ -194,6 +214,11 @@ func TestRateLimiterCleanup(t *testing.T) {
|
||||
t.Fatalf("NewRateLimiter() error: %v", err)
|
||||
}
|
||||
|
||||
rl, ok := mw.(*RateLimiter)
|
||||
if !ok {
|
||||
t.Fatalf("Expected *RateLimiter, got %T", mw)
|
||||
}
|
||||
|
||||
// Create some buckets
|
||||
rl.Allow("key1")
|
||||
rl.Allow("key2")
|
||||
@ -208,7 +233,7 @@ func TestRateLimiterCleanup(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRateLimiterProcess(t *testing.T) {
|
||||
rl, err := NewRateLimiter(&config.RateLimitConfig{
|
||||
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||||
RequestRate: 100,
|
||||
Burst: 100,
|
||||
})
|
||||
@ -220,14 +245,14 @@ func TestRateLimiterProcess(t *testing.T) {
|
||||
ctx.WriteString("OK")
|
||||
}
|
||||
|
||||
handler := rl.Process(nextHandler)
|
||||
handler := mw.Process(nextHandler)
|
||||
if handler == nil {
|
||||
t.Error("Process() returned nil handler")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterGetStats(t *testing.T) {
|
||||
rl, err := NewRateLimiter(&config.RateLimitConfig{
|
||||
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||||
RequestRate: 100,
|
||||
Burst: 200,
|
||||
})
|
||||
@ -235,6 +260,11 @@ func TestRateLimiterGetStats(t *testing.T) {
|
||||
t.Fatalf("NewRateLimiter() error: %v", err)
|
||||
}
|
||||
|
||||
rl, ok := mw.(*RateLimiter)
|
||||
if !ok {
|
||||
t.Fatalf("Expected *RateLimiter, got %T", mw)
|
||||
}
|
||||
|
||||
rl.Allow("key1")
|
||||
rl.Allow("key2")
|
||||
|
||||
|
||||
@ -64,11 +64,12 @@ type Proxy struct {
|
||||
// 参数:
|
||||
// - cfg: 代理配置,包括超时时间、请求头和负载均衡策略
|
||||
// - targets: 要代理请求的后端目标列表
|
||||
// - transportCfg: 可选的 Transport 连接池配置,nil 时使用默认值
|
||||
//
|
||||
// 返回值:
|
||||
// - *Proxy: 配置完成并可处理请求的代理实例
|
||||
// - error: 初始化失败时非空(无效配置、没有健康目标等)
|
||||
func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target) (*Proxy, error) {
|
||||
func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target, transportCfg *config.TransportConfig) (*Proxy, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("proxy config is nil")
|
||||
}
|
||||
@ -78,7 +79,7 @@ func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target) (*Proxy, e
|
||||
}
|
||||
|
||||
// 根据配置创建负载均衡器
|
||||
balancer, err := createBalancer(cfg.LoadBalance)
|
||||
balancer, err := createBalancer(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -96,7 +97,7 @@ func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target) (*Proxy, e
|
||||
continue
|
||||
}
|
||||
|
||||
client := createHostClient(target.URL, cfg.Timeout)
|
||||
client := createHostClient(target.URL, cfg.Timeout, transportCfg)
|
||||
p.clients[target.URL] = client
|
||||
}
|
||||
|
||||
@ -122,8 +123,8 @@ func (p *Proxy) SetHealthChecker(hc *HealthChecker) {
|
||||
}
|
||||
|
||||
// createBalancer 根据配置的算法创建负载均衡器。
|
||||
func createBalancer(algorithm string) (loadbalance.Balancer, error) {
|
||||
switch algorithm {
|
||||
func createBalancer(cfg *config.ProxyConfig) (loadbalance.Balancer, error) {
|
||||
switch cfg.LoadBalance {
|
||||
case "round_robin", "":
|
||||
return loadbalance.NewRoundRobin(), nil
|
||||
case "weighted_round_robin":
|
||||
@ -132,13 +133,19 @@ func createBalancer(algorithm string) (loadbalance.Balancer, error) {
|
||||
return loadbalance.NewLeastConnections(), nil
|
||||
case "ip_hash":
|
||||
return loadbalance.NewIPHash(), nil
|
||||
case "consistent_hash":
|
||||
virtualNodes := cfg.VirtualNodes
|
||||
if virtualNodes <= 0 {
|
||||
virtualNodes = 150
|
||||
}
|
||||
return loadbalance.NewConsistentHash(virtualNodes, cfg.HashKey), nil
|
||||
default:
|
||||
return nil, errors.New("unsupported load balance algorithm: " + algorithm)
|
||||
return nil, errors.New("unsupported load balance algorithm: " + cfg.LoadBalance)
|
||||
}
|
||||
}
|
||||
|
||||
// createHostClient 为后台目标 URL 创建 fasthttp.HostClient。
|
||||
func createHostClient(targetURL string, timeout config.ProxyTimeout) *fasthttp.HostClient {
|
||||
func createHostClient(targetURL string, timeout config.ProxyTimeout, transportCfg *config.TransportConfig) *fasthttp.HostClient {
|
||||
// 从目标 URL 解析主机和协议
|
||||
addr := targetURL
|
||||
isTLS := false
|
||||
@ -155,13 +162,27 @@ func createHostClient(targetURL string, timeout config.ProxyTimeout) *fasthttp.H
|
||||
addr = addr[:idx]
|
||||
}
|
||||
|
||||
// 默认值
|
||||
maxIdleConnDuration := 90 * time.Second
|
||||
maxConns := 100
|
||||
|
||||
// 应用 Transport 配置
|
||||
if transportCfg != nil {
|
||||
if transportCfg.IdleConnTimeout > 0 {
|
||||
maxIdleConnDuration = transportCfg.IdleConnTimeout
|
||||
}
|
||||
if transportCfg.MaxConnsPerHost > 0 {
|
||||
maxConns = transportCfg.MaxConnsPerHost
|
||||
}
|
||||
}
|
||||
|
||||
client := &fasthttp.HostClient{
|
||||
Addr: addr,
|
||||
IsTLS: isTLS,
|
||||
ReadTimeout: timeout.Read,
|
||||
WriteTimeout: timeout.Write,
|
||||
MaxIdleConnDuration: 60 * time.Second,
|
||||
MaxConns: 100,
|
||||
MaxIdleConnDuration: maxIdleConnDuration,
|
||||
MaxConns: maxConns,
|
||||
MaxConnWaitTimeout: timeout.Connect,
|
||||
RetryIf: nil, // Disable automatic retries
|
||||
DisablePathNormalizing: false,
|
||||
@ -389,13 +410,13 @@ func (p *Proxy) UpdateTargets(targets []*loadbalance.Target) error {
|
||||
// 清除旧客户端
|
||||
p.clients = make(map[string]*fasthttp.HostClient)
|
||||
|
||||
// 初始化新客户端
|
||||
// 初始化新客户端(使用 nil TransportConfig 保持原有行为)
|
||||
for _, target := range targets {
|
||||
if target.URL == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
client := createHostClient(target.URL, p.config.Timeout)
|
||||
client := createHostClient(target.URL, p.config.Timeout, nil)
|
||||
p.clients[target.URL] = client
|
||||
}
|
||||
|
||||
|
||||
@ -121,7 +121,7 @@ func TestNewProxy(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p, err := NewProxy(tt.cfg, tt.targets)
|
||||
p, err := NewProxy(tt.cfg, tt.targets, nil)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("NewProxy() expected error containing %q, got nil", tt.errContains)
|
||||
@ -166,7 +166,7 @@ func TestServeHTTP_NoHealthyTargets(t *testing.T) {
|
||||
targets[0].Healthy.Store(false)
|
||||
targets[1].Healthy.Store(false)
|
||||
|
||||
p, err := NewProxy(cfg, targets)
|
||||
p, err := NewProxy(cfg, targets, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
@ -216,7 +216,7 @@ func TestServeHTTP_RequestForwarding(t *testing.T) {
|
||||
{URL: "http://localhost:8080"},
|
||||
}
|
||||
|
||||
p, err := NewProxy(cfg, targets)
|
||||
p, err := NewProxy(cfg, targets, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
@ -308,7 +308,7 @@ func TestSelectTarget(t *testing.T) {
|
||||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||
}
|
||||
|
||||
p, err := NewProxy(cfg, tt.targets)
|
||||
p, err := NewProxy(cfg, tt.targets, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
@ -425,7 +425,7 @@ func TestModifyRequestHeaders(t *testing.T) {
|
||||
{URL: "http://localhost:8080"},
|
||||
}
|
||||
|
||||
p, err := NewProxy(cfg, targets)
|
||||
p, err := NewProxy(cfg, targets, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
@ -520,7 +520,7 @@ func TestModifyResponseHeaders(t *testing.T) {
|
||||
{URL: "http://localhost:8080"},
|
||||
}
|
||||
|
||||
p, err := NewProxy(cfg, targets)
|
||||
p, err := NewProxy(cfg, targets, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
@ -608,7 +608,7 @@ func TestUpdateTargets(t *testing.T) {
|
||||
{URL: "http://old2:8080"},
|
||||
}
|
||||
|
||||
p, err := NewProxy(cfg, initialTargets)
|
||||
p, err := NewProxy(cfg, initialTargets, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
@ -657,7 +657,7 @@ func TestGetTargets(t *testing.T) {
|
||||
{URL: "http://backend2:8080"},
|
||||
}
|
||||
|
||||
p, err := NewProxy(cfg, targets)
|
||||
p, err := NewProxy(cfg, targets, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
@ -686,7 +686,7 @@ func TestGetConfig(t *testing.T) {
|
||||
{URL: "http://localhost:8080"},
|
||||
}
|
||||
|
||||
p, err := NewProxy(cfg, targets)
|
||||
p, err := NewProxy(cfg, targets, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
@ -763,38 +763,37 @@ func TestIsWebSocketRequest(t *testing.T) {
|
||||
func TestCreateBalancer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
algorithm string
|
||||
cfg *config.ProxyConfig
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "轮询",
|
||||
algorithm: "round_robin",
|
||||
wantErr: false,
|
||||
name: "轮询",
|
||||
cfg: &config.ProxyConfig{LoadBalance: "round_robin"},
|
||||
},
|
||||
{
|
||||
name: "加权轮询",
|
||||
algorithm: "weighted_round_robin",
|
||||
wantErr: false,
|
||||
name: "加权轮询",
|
||||
cfg: &config.ProxyConfig{LoadBalance: "weighted_round_robin"},
|
||||
},
|
||||
{
|
||||
name: "最少连接",
|
||||
algorithm: "least_conn",
|
||||
wantErr: false,
|
||||
name: "最少连接",
|
||||
cfg: &config.ProxyConfig{LoadBalance: "least_conn"},
|
||||
},
|
||||
{
|
||||
name: "IP哈希",
|
||||
algorithm: "ip_hash",
|
||||
wantErr: false,
|
||||
name: "IP哈希",
|
||||
cfg: &config.ProxyConfig{LoadBalance: "ip_hash"},
|
||||
},
|
||||
{
|
||||
name: "空算法(默认轮询)",
|
||||
algorithm: "",
|
||||
wantErr: false,
|
||||
name: "一致性哈希",
|
||||
cfg: &config.ProxyConfig{LoadBalance: "consistent_hash", HashKey: "ip", VirtualNodes: 150},
|
||||
},
|
||||
{
|
||||
name: "空算法(默认轮询)",
|
||||
cfg: &config.ProxyConfig{LoadBalance: ""},
|
||||
},
|
||||
{
|
||||
name: "无效算法",
|
||||
algorithm: "unknown_algorithm",
|
||||
cfg: &config.ProxyConfig{LoadBalance: "unknown_algorithm"},
|
||||
wantErr: true,
|
||||
errContains: "unsupported load balance algorithm",
|
||||
},
|
||||
@ -802,23 +801,23 @@ func TestCreateBalancer(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
balancer, err := createBalancer(tt.algorithm)
|
||||
balancer, err := createBalancer(tt.cfg)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("createBalancer(%q) expected error", tt.algorithm)
|
||||
t.Errorf("createBalancer(%v) expected error", tt.cfg.LoadBalance)
|
||||
return
|
||||
}
|
||||
if !contains(err.Error(), tt.errContains) {
|
||||
t.Errorf("createBalancer(%q) error = %v, want containing %q", tt.algorithm, err, tt.errContains)
|
||||
t.Errorf("createBalancer(%v) error = %v, want containing %q", tt.cfg.LoadBalance, err, tt.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("createBalancer(%q) unexpected error: %v", tt.algorithm, err)
|
||||
t.Errorf("createBalancer(%v) unexpected error: %v", tt.cfg.LoadBalance, err)
|
||||
return
|
||||
}
|
||||
if balancer == nil {
|
||||
t.Errorf("createBalancer(%q) returned nil balancer", tt.algorithm)
|
||||
t.Errorf("createBalancer(%v) returned nil balancer", tt.cfg.LoadBalance)
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -850,7 +849,7 @@ func TestCreateHostClient(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
client := createHostClient(tt.targetURL, tt.timeout)
|
||||
client := createHostClient(tt.targetURL, tt.timeout, nil)
|
||||
if client == nil {
|
||||
t.Error("createHostClient() returned nil")
|
||||
return
|
||||
@ -882,7 +881,7 @@ func TestHandleWebSocket(t *testing.T) {
|
||||
{URL: "http://localhost:8080"},
|
||||
}
|
||||
|
||||
p, err := NewProxy(cfg, targets)
|
||||
p, err := NewProxy(cfg, targets, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
@ -916,7 +915,7 @@ func TestSetHealthChecker(t *testing.T) {
|
||||
{URL: "http://localhost:8081"},
|
||||
}
|
||||
|
||||
p, err := NewProxy(cfg, targets)
|
||||
p, err := NewProxy(cfg, targets, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
@ -954,7 +953,7 @@ func TestGetClient(t *testing.T) {
|
||||
{URL: "http://localhost:8082"},
|
||||
}
|
||||
|
||||
p, err := NewProxy(cfg, targets)
|
||||
p, err := NewProxy(cfg, targets, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
@ -1018,7 +1017,7 @@ func TestProxyCache(t *testing.T) {
|
||||
}
|
||||
targets[0].Healthy.Store(true)
|
||||
|
||||
p, err := NewProxy(cfg, targets)
|
||||
p, err := NewProxy(cfg, targets, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
@ -1069,7 +1068,7 @@ func TestServeHTTP_WithPassiveHealthCheck(t *testing.T) {
|
||||
}
|
||||
targets[0].Healthy.Store(true)
|
||||
|
||||
p, err := NewProxy(cfg, targets)
|
||||
p, err := NewProxy(cfg, targets, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
|
||||
@ -251,7 +251,7 @@ func (s *Server) buildMiddlewareChain(serverCfg *config.ServerConfig) (*middlewa
|
||||
serverCfg.Security.Headers.ContentSecurityPolicy != "" ||
|
||||
serverCfg.Security.Headers.ReferrerPolicy != "" ||
|
||||
serverCfg.Security.Headers.PermissionsPolicy != "" {
|
||||
headers := security.NewSecurityHeaders(&serverCfg.Security.Headers)
|
||||
headers := security.NewSecurityHeadersWithHSTS(&serverCfg.Security.Headers, &serverCfg.SSL.HSTS)
|
||||
middlewares = append(middlewares, headers)
|
||||
}
|
||||
|
||||
@ -534,7 +534,8 @@ func (s *Server) registerProxyRoutes(router *handler.Router, serverCfg *config.S
|
||||
targets[j].Healthy.Store(true)
|
||||
}
|
||||
|
||||
p, err := proxy.NewProxy(proxyCfg, targets)
|
||||
// 传递 Transport 配置
|
||||
p, err := proxy.NewProxy(proxyCfg, targets, &s.config.Performance.Transport)
|
||||
if err != nil {
|
||||
logging.Error().Msg("创建代理失败: " + err.Error())
|
||||
continue
|
||||
|
||||
@ -69,14 +69,7 @@ func (v *VHostManager) SetDefault(handler fasthttp.RequestHandler) {
|
||||
// Handler 返回虚拟主机选择器
|
||||
func (v *VHostManager) Handler() fasthttp.RequestHandler {
|
||||
return func(ctx *fasthttp.RequestCtx) {
|
||||
host := string(ctx.Host())
|
||||
// 去除端口号
|
||||
for i := 0; i < len(host); i++ {
|
||||
if host[i] == ':' {
|
||||
host = host[:i]
|
||||
break
|
||||
}
|
||||
}
|
||||
host := stripPort(string(ctx.Host()))
|
||||
|
||||
if vhost, ok := v.hosts[host]; ok {
|
||||
vhost.handler(ctx)
|
||||
@ -87,3 +80,38 @@ func (v *VHostManager) Handler() fasthttp.RequestHandler {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// stripPort 从 Host 头中移除端口号。
|
||||
//
|
||||
// 支持 IPv4 和 IPv6 格式:
|
||||
// - example.com:8080 -> example.com
|
||||
// - [::1]:8080 -> [::1]
|
||||
// - [2001:db8::1]:443 -> [2001:db8::1]
|
||||
// - example.com -> example.com
|
||||
func stripPort(host string) string {
|
||||
// 空字符串直接返回
|
||||
if len(host) == 0 {
|
||||
return host
|
||||
}
|
||||
|
||||
// IPv6 格式:以 '[' 开头,找 ']:' 作为分隔点
|
||||
if host[0] == '[' {
|
||||
// 查找 ']:' 分隔符
|
||||
for i := 0; i < len(host)-1; i++ {
|
||||
if host[i] == ']' && host[i+1] == ':' {
|
||||
return host[:i+1] // 返回包含 ']' 的部分,如 "[::1]"
|
||||
}
|
||||
}
|
||||
// 没有 ']:' 分隔符,可能是纯 IPv6 地址(如 "[::1]")
|
||||
return host
|
||||
}
|
||||
|
||||
// IPv4 或域名格式:找第一个 ':' 作为分隔点
|
||||
for i := 0; i < len(host); i++ {
|
||||
if host[i] == ':' {
|
||||
return host[:i]
|
||||
}
|
||||
}
|
||||
|
||||
return host
|
||||
}
|
||||
|
||||
@ -99,13 +99,9 @@ func TestVHostManager_Handler(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("IPv6地址Host", func(t *testing.T) {
|
||||
// TODO: 当前 vhost.go 的端口剥离逻辑不支持 IPv6 格式 [::1]:8080
|
||||
// 它会错误地在第一个 ':' 处截断(IPv6 地址内部的冒号)
|
||||
// 修复方案:检查 host 是否以 '[' 开头,找 ']:' 作为分隔点
|
||||
manager := NewVHostManager()
|
||||
ipv6Called := false
|
||||
manager.AddHost("[::1]", mockHandler("ipv6", &ipv6Called))
|
||||
manager.SetDefault(mockHandler("default", &ipv6Called)) // fallback
|
||||
|
||||
handler := manager.Handler()
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
@ -113,9 +109,12 @@ func TestVHostManager_Handler(t *testing.T) {
|
||||
|
||||
handler(ctx)
|
||||
|
||||
// 当前实现不支持 IPv6,会 fallback 到默认 handler
|
||||
// 修复 vhost.go 后此测试应验证 ipv6Called 为 true
|
||||
t.Log("注意: 当前实现不支持 IPv6 地址,需要修复 vhost.go 的端口剥离逻辑")
|
||||
if !ipv6Called {
|
||||
t.Error("期望 [::1] handler 被调用,但未被调用")
|
||||
}
|
||||
if string(ctx.Response.Body()) != "ipv6" {
|
||||
t.Errorf("响应体 = %q, want %q", string(ctx.Response.Body()), "ipv6")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("空Host使用默认", func(t *testing.T) {
|
||||
@ -270,6 +269,10 @@ func TestVHostManager_PortStripping(t *testing.T) {
|
||||
{"标准HTTPS端口", "example.com:443", "example.com"},
|
||||
{"自定义端口", "example.com:8080", "example.com"},
|
||||
{"IPv6 localhost带端口", "[localhost]:8080", "[localhost]"},
|
||||
{"IPv6 loopback带端口", "[::1]:8080", "[::1]"},
|
||||
{"IPv6完整地址带端口", "[2001:db8::1]:443", "[2001:db8::1]"},
|
||||
{"IPv6无端口", "[::1]", "[::1]"},
|
||||
{"IPv6完整地址无端口", "[2001:db8::1]", "[2001:db8::1]"},
|
||||
{"空字符串", "", ""},
|
||||
}
|
||||
|
||||
@ -291,11 +294,8 @@ func TestVHostManager_PortStripping(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// IPv6 数字地址测试 - 当前实现有已知 bug
|
||||
t.Run("IPv6数字地址_已知限制", func(t *testing.T) {
|
||||
// TODO: vhost.go 的端口剥离逻辑不支持 IPv6 数字地址格式 [::1]:8080
|
||||
// 因为它会在第一个 ':' 处截断(IPv6 地址内部的冒号)
|
||||
// 结果:[:而不是 [::1]
|
||||
// IPv6 数字地址测试
|
||||
t.Run("IPv6数字地址", func(t *testing.T) {
|
||||
manager := NewVHostManager()
|
||||
ipv6Called := false
|
||||
manager.AddHost("[::1]", mockHandler("ipv6", &ipv6Called))
|
||||
@ -306,10 +306,11 @@ func TestVHostManager_PortStripping(t *testing.T) {
|
||||
|
||||
handler(ctx)
|
||||
|
||||
// 当前行为:不匹配,因为端口剥离错误
|
||||
if ipv6Called {
|
||||
t.Error("当前实现不支持 IPv6 数字地址的端口剥离,不应匹配")
|
||||
if !ipv6Called {
|
||||
t.Error("期望 [::1] handler 被调用,但未被调用")
|
||||
}
|
||||
if string(ctx.Response.Body()) != "ipv6" {
|
||||
t.Errorf("响应体 = %q, want %q", string(ctx.Response.Body()), "ipv6")
|
||||
}
|
||||
t.Log("已知限制: IPv6 数字地址端口剥离需要修复 vhost.go")
|
||||
})
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user