feat(proxy,middleware,config): 集成配置与代码差异修复
- 集成一致性哈希负载均衡到 proxy.go,支持 hash_key 和 virtual_nodes 配置 - 集成滑动窗口限流算法到 ratelimit.go,支持 algorithm 选择 - 应用 Transport 连接池配置到 createHostClient - 集成 HSTS 配置到安全头部中间件 - 补充 config.example.yaml 缺失配置(http3、gzip_static、sliding_window) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
d6367a1c38
commit
ec916d882d
@ -25,7 +25,9 @@ server:
|
|||||||
# weight: 3 # 权重(加权轮询时有效)
|
# weight: 3 # 权重(加权轮询时有效)
|
||||||
# - url: http://backend2:8080
|
# - url: http://backend2:8080
|
||||||
# weight: 1
|
# weight: 1
|
||||||
# load_balance: round_robin # 负载均衡算法: round_robin, weighted_round_robin, least_conn, ip_hash
|
# load_balance: round_robin # 负载均衡算法: round_robin, weighted_round_robin, least_conn, ip_hash, consistent_hash
|
||||||
|
# hash_key: "ip" # 一致性哈希键(有效值: ip, uri, header:X-Name)
|
||||||
|
# virtual_nodes: 150 # 一致性哈希虚拟节点数
|
||||||
# health_check: # 健康检查
|
# health_check: # 健康检查
|
||||||
# interval: 10s
|
# interval: 10s
|
||||||
# path: /health
|
# path: /health
|
||||||
@ -73,6 +75,9 @@ server:
|
|||||||
burst: 0 # 突发上限
|
burst: 0 # 突发上限
|
||||||
conn_limit: 0 # 连接数限制
|
conn_limit: 0 # 连接数限制
|
||||||
key: "ip" # 限流 key 来源(有效值: ip, header)
|
key: "ip" # 限流 key 来源(有效值: ip, header)
|
||||||
|
algorithm: "token_bucket" # 限流算法(有效值: token_bucket, sliding_window)
|
||||||
|
sliding_window_mode: "approximate" # 滑动窗口模式(有效值: approximate, precise)
|
||||||
|
sliding_window: 60 # 滑动窗口大小(秒)
|
||||||
|
|
||||||
# 认证配置(type 为空时禁用)
|
# 认证配置(type 为空时禁用)
|
||||||
auth:
|
auth:
|
||||||
@ -108,6 +113,14 @@ server:
|
|||||||
- "text/javascript"
|
- "text/javascript"
|
||||||
- "application/json"
|
- "application/json"
|
||||||
- "application/javascript"
|
- "application/javascript"
|
||||||
|
gzip_static: true # 启用预压缩文件支持(检测 .gz 文件)
|
||||||
|
gzip_static_extensions: # 预压缩文件扩展名
|
||||||
|
- ".html"
|
||||||
|
- ".css"
|
||||||
|
- ".js"
|
||||||
|
- ".json"
|
||||||
|
- ".xml"
|
||||||
|
- ".svg"
|
||||||
|
|
||||||
# 多虚拟主机模式(可选,每个虚拟主机支持完整的 server 配置)
|
# 多虚拟主机模式(可选,每个虚拟主机支持完整的 server 配置)
|
||||||
# servers:
|
# servers:
|
||||||
@ -158,6 +171,14 @@ server:
|
|||||||
# 默认 TLS 协议: TLSv1.2, TLSv1.3(不支持 TLSv1.0/1.1)
|
# 默认 TLS 协议: TLSv1.2, TLSv1.3(不支持 TLSv1.0/1.1)
|
||||||
# 默认 HSTS 配置: max_age=31536000(1年), include_sub_domains=true
|
# 默认 HSTS 配置: max_age=31536000(1年), include_sub_domains=true
|
||||||
|
|
||||||
|
# HTTP/3 (QUIC) 配置(可选)
|
||||||
|
# http3:
|
||||||
|
# enabled: false # 是否启用 HTTP/3
|
||||||
|
# listen: ":443" # UDP 监听地址
|
||||||
|
# max_streams: 100 # 最大并发流数
|
||||||
|
# idle_timeout: 30s # 空闲超时
|
||||||
|
# enable_0rtt: false # 启用 0-RTT 快速连接
|
||||||
|
|
||||||
# TCP/UDP Stream 代理配置(可选)
|
# TCP/UDP Stream 代理配置(可选)
|
||||||
# stream:
|
# stream:
|
||||||
# - listen: "3306" # 监听地址
|
# - listen: "3306" # 监听地址
|
||||||
|
|||||||
@ -69,6 +69,9 @@ func DefaultConfig() *Config {
|
|||||||
Burst: 0,
|
Burst: 0,
|
||||||
ConnLimit: 0,
|
ConnLimit: 0,
|
||||||
Key: "ip",
|
Key: "ip",
|
||||||
|
Algorithm: "token_bucket",
|
||||||
|
SlidingWindowMode: "approximate",
|
||||||
|
SlidingWindow: 60,
|
||||||
},
|
},
|
||||||
Auth: AuthConfig{
|
Auth: AuthConfig{
|
||||||
RequireTLS: true,
|
RequireTLS: true,
|
||||||
@ -86,6 +89,8 @@ func DefaultConfig() *Config {
|
|||||||
Type: "gzip",
|
Type: "gzip",
|
||||||
Level: 6,
|
Level: 6,
|
||||||
MinSize: 1024,
|
MinSize: 1024,
|
||||||
|
GzipStatic: false,
|
||||||
|
GzipStaticExtensions: []string{".gz", ".br"},
|
||||||
Types: []string{
|
Types: []string{
|
||||||
"text/html",
|
"text/html",
|
||||||
"text/css",
|
"text/css",
|
||||||
@ -133,6 +138,13 @@ func DefaultConfig() *Config {
|
|||||||
Allow: []string{"127.0.0.1"},
|
Allow: []string{"127.0.0.1"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
HTTP3: HTTP3Config{
|
||||||
|
Enabled: false,
|
||||||
|
Listen: ":443",
|
||||||
|
MaxStreams: 100,
|
||||||
|
IdleTimeout: 60 * time.Second,
|
||||||
|
Enable0RTT: false,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -190,7 +202,9 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
|
|||||||
buf.WriteString(" # weight: 3 # 权重(加权轮询时有效)\n")
|
buf.WriteString(" # weight: 3 # 权重(加权轮询时有效)\n")
|
||||||
buf.WriteString(" # - url: http://backend2:8080\n")
|
buf.WriteString(" # - url: http://backend2:8080\n")
|
||||||
buf.WriteString(" # weight: 1\n")
|
buf.WriteString(" # weight: 1\n")
|
||||||
buf.WriteString(" # load_balance: round_robin # 负载均衡算法: round_robin, weighted_round_robin, least_conn, ip_hash\n")
|
buf.WriteString(" # load_balance: round_robin # 负载均衡算法(有效值: round_robin, weighted_round_robin, least_conn, ip_hash, consistent_hash)\n")
|
||||||
|
buf.WriteString(" # hash_key: ip # 一致性哈希键(仅 load_balance=consistent_hash 时有效,有效值: ip, uri, header:X-Name)\n")
|
||||||
|
buf.WriteString(" # virtual_nodes: 150 # 一致性哈希虚拟节点数(仅 load_balance=consistent_hash 时有效)\n")
|
||||||
buf.WriteString(" # health_check: # 健康检查\n")
|
buf.WriteString(" # health_check: # 健康检查\n")
|
||||||
buf.WriteString(" # interval: 10s\n")
|
buf.WriteString(" # interval: 10s\n")
|
||||||
buf.WriteString(" # path: /health\n")
|
buf.WriteString(" # path: /health\n")
|
||||||
@ -243,6 +257,9 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
|
|||||||
buf.WriteString(fmt.Sprintf(" burst: %d # 突发上限\n", cfg.Server.Security.RateLimit.Burst))
|
buf.WriteString(fmt.Sprintf(" burst: %d # 突发上限\n", cfg.Server.Security.RateLimit.Burst))
|
||||||
buf.WriteString(fmt.Sprintf(" conn_limit: %d # 连接数限制\n", cfg.Server.Security.RateLimit.ConnLimit))
|
buf.WriteString(fmt.Sprintf(" conn_limit: %d # 连接数限制\n", cfg.Server.Security.RateLimit.ConnLimit))
|
||||||
buf.WriteString(fmt.Sprintf(" key: \"%s\" # 限流 key 来源(有效值: ip, header)\n", cfg.Server.Security.RateLimit.Key))
|
buf.WriteString(fmt.Sprintf(" key: \"%s\" # 限流 key 来源(有效值: ip, header)\n", cfg.Server.Security.RateLimit.Key))
|
||||||
|
buf.WriteString(fmt.Sprintf(" algorithm: \"%s\" # 限流算法(有效值: token_bucket, sliding_window)\n", cfg.Server.Security.RateLimit.Algorithm))
|
||||||
|
buf.WriteString(fmt.Sprintf(" sliding_window_mode: \"%s\" # 滑动窗口模式(有效值: approximate, precise,仅 algorithm=sliding_window 时有效)\n", cfg.Server.Security.RateLimit.SlidingWindowMode))
|
||||||
|
buf.WriteString(fmt.Sprintf(" sliding_window: %d # 滑动窗口大小(秒,仅 algorithm=sliding_window 时有效)\n", cfg.Server.Security.RateLimit.SlidingWindow))
|
||||||
buf.WriteString("\n")
|
buf.WriteString("\n")
|
||||||
buf.WriteString(" # 认证配置(type 为空时禁用)\n")
|
buf.WriteString(" # 认证配置(type 为空时禁用)\n")
|
||||||
buf.WriteString(" auth:\n")
|
buf.WriteString(" auth:\n")
|
||||||
@ -276,6 +293,11 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
|
|||||||
buf.WriteString(fmt.Sprintf(" type: \"%s\" # 压缩类型(有效值: gzip, brotli, both,空表示禁用)\n", cfg.Server.Compression.Type))
|
buf.WriteString(fmt.Sprintf(" type: \"%s\" # 压缩类型(有效值: gzip, brotli, both,空表示禁用)\n", cfg.Server.Compression.Type))
|
||||||
buf.WriteString(fmt.Sprintf(" level: %d # 压缩级别(范围 1-9,值越大压缩率越高但速度越慢)\n", cfg.Server.Compression.Level))
|
buf.WriteString(fmt.Sprintf(" level: %d # 压缩级别(范围 1-9,值越大压缩率越高但速度越慢)\n", cfg.Server.Compression.Level))
|
||||||
buf.WriteString(fmt.Sprintf(" min_size: %d # 最小压缩大小(字节,小于此值不压缩)\n", cfg.Server.Compression.MinSize))
|
buf.WriteString(fmt.Sprintf(" min_size: %d # 最小压缩大小(字节,小于此值不压缩)\n", cfg.Server.Compression.MinSize))
|
||||||
|
buf.WriteString(fmt.Sprintf(" gzip_static: %v # 启用预压缩文件支持(自动查找 .gz/.br 文件)\n", cfg.Server.Compression.GzipStatic))
|
||||||
|
buf.WriteString(" gzip_static_extensions: # 预压缩文件扩展名\n")
|
||||||
|
for _, ext := range cfg.Server.Compression.GzipStaticExtensions {
|
||||||
|
buf.WriteString(fmt.Sprintf(" - \"%s\"\n", ext))
|
||||||
|
}
|
||||||
buf.WriteString(" types: # 可压缩的 MIME 类型\n")
|
buf.WriteString(" types: # 可压缩的 MIME 类型\n")
|
||||||
for _, t := range cfg.Server.Compression.Types {
|
for _, t := range cfg.Server.Compression.Types {
|
||||||
buf.WriteString(fmt.Sprintf(" - \"%s\"\n", t))
|
buf.WriteString(fmt.Sprintf(" - \"%s\"\n", t))
|
||||||
@ -383,6 +405,16 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
|
|||||||
buf.WriteString(fmt.Sprintf(" max_conns_per_host: %d # 每主机最大连接(0 表示不限制)\n", cfg.Performance.Transport.MaxConnsPerHost))
|
buf.WriteString(fmt.Sprintf(" max_conns_per_host: %d # 每主机最大连接(0 表示不限制)\n", cfg.Performance.Transport.MaxConnsPerHost))
|
||||||
buf.WriteString("\n")
|
buf.WriteString("\n")
|
||||||
|
|
||||||
|
// HTTP3 配置
|
||||||
|
buf.WriteString("# HTTP/3 (QUIC) 配置(需要 SSL 证书)\n")
|
||||||
|
buf.WriteString("http3:\n")
|
||||||
|
buf.WriteString(fmt.Sprintf(" enabled: %v # 是否启用 HTTP/3\n", cfg.HTTP3.Enabled))
|
||||||
|
buf.WriteString(fmt.Sprintf(" listen: \"%s\" # UDP 监听地址\n", cfg.HTTP3.Listen))
|
||||||
|
buf.WriteString(fmt.Sprintf(" max_streams: %d # 最大并发流\n", cfg.HTTP3.MaxStreams))
|
||||||
|
buf.WriteString(fmt.Sprintf(" idle_timeout: %ds # 空闲超时\n", int(cfg.HTTP3.IdleTimeout.Seconds())))
|
||||||
|
buf.WriteString(fmt.Sprintf(" enable_0rtt: %v # 启用 0-RTT(早期数据,可能存在安全风险)\n", cfg.HTTP3.Enable0RTT))
|
||||||
|
buf.WriteString("\n")
|
||||||
|
|
||||||
// monitoring 配置
|
// monitoring 配置
|
||||||
buf.WriteString("# 监控配置\n")
|
buf.WriteString("# 监控配置\n")
|
||||||
buf.WriteString("monitoring:\n")
|
buf.WriteString("monitoring:\n")
|
||||||
|
|||||||
@ -7,6 +7,7 @@ var ValidAlgorithms = []string{
|
|||||||
"weighted_round_robin",
|
"weighted_round_robin",
|
||||||
"least_conn",
|
"least_conn",
|
||||||
"ip_hash",
|
"ip_hash",
|
||||||
|
"consistent_hash",
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValidAlgorithm 检查算法是否有效。
|
// IsValidAlgorithm 检查算法是否有效。
|
||||||
|
|||||||
@ -59,6 +59,20 @@ type SecurityHeadersMiddleware struct {
|
|||||||
// 返回值:
|
// 返回值:
|
||||||
// - *SecurityHeadersMiddleware: 配置好的中间件实例
|
// - *SecurityHeadersMiddleware: 配置好的中间件实例
|
||||||
func NewSecurityHeaders(cfg *config.SecurityHeaders) *SecurityHeadersMiddleware {
|
func NewSecurityHeaders(cfg *config.SecurityHeaders) *SecurityHeadersMiddleware {
|
||||||
|
return NewSecurityHeadersWithHSTS(cfg, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSecurityHeadersWithHSTS 创建新的安全响应头中间件,支持 HSTS 配置。
|
||||||
|
//
|
||||||
|
// 根据配置创建中间件实例,如果配置为 nil 则使用安全的默认值。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - cfg: 安全头配置,可以为 nil 使用默认配置
|
||||||
|
// - hstsCfg: HSTS 配置,可以为 nil 使用默认值
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - *SecurityHeadersMiddleware: 配置好的中间件实例
|
||||||
|
func NewSecurityHeadersWithHSTS(cfg *config.SecurityHeaders, hstsCfg *config.HSTSConfig) *SecurityHeadersMiddleware {
|
||||||
sh := &SecurityHeadersMiddleware{}
|
sh := &SecurityHeadersMiddleware{}
|
||||||
|
|
||||||
if cfg != nil {
|
if cfg != nil {
|
||||||
@ -73,11 +87,24 @@ func NewSecurityHeaders(cfg *config.SecurityHeaders) *SecurityHeadersMiddleware
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 预格式化 HSTS 头值
|
// 预格式化 HSTS 头值
|
||||||
sh.formatHSTS()
|
sh.formatHSTSFromConfig(hstsCfg)
|
||||||
|
|
||||||
return sh
|
return sh
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// formatHSTSFromConfig 根据配置格式化 HSTS 头值。
|
||||||
|
func (sh *SecurityHeadersMiddleware) formatHSTSFromConfig(hstsCfg *config.HSTSConfig) {
|
||||||
|
if hstsCfg != nil {
|
||||||
|
maxAge := hstsCfg.MaxAge
|
||||||
|
if maxAge <= 0 {
|
||||||
|
maxAge = 31536000 // 默认 1 年
|
||||||
|
}
|
||||||
|
sh.hsts = formatHSTSValue(maxAge, hstsCfg.IncludeSubDomains, hstsCfg.Preload)
|
||||||
|
} else {
|
||||||
|
sh.formatHSTS() // 使用默认值
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Name 返回中间件名称。
|
// Name 返回中间件名称。
|
||||||
//
|
//
|
||||||
// 返回值:
|
// 返回值:
|
||||||
|
|||||||
@ -82,7 +82,7 @@ type KeyFunc func(ctx *fasthttp.RequestCtx) string
|
|||||||
// 返回值:
|
// 返回值:
|
||||||
// - *RateLimiter: 配置好的限流器实例
|
// - *RateLimiter: 配置好的限流器实例
|
||||||
// - error: 配置无效时返回错误(如速率小于 0)
|
// - error: 配置无效时返回错误(如速率小于 0)
|
||||||
func NewRateLimiter(cfg *config.RateLimitConfig) (*RateLimiter, error) {
|
func NewRateLimiter(cfg *config.RateLimitConfig) (middleware.Middleware, error) {
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
return nil, errors.New("rate limit config is nil")
|
return nil, errors.New("rate limit config is nil")
|
||||||
}
|
}
|
||||||
@ -91,6 +91,29 @@ func NewRateLimiter(cfg *config.RateLimitConfig) (*RateLimiter, error) {
|
|||||||
return nil, errors.New("request rate must be positive")
|
return nil, errors.New("request rate must be positive")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 根据算法选择限流器
|
||||||
|
algorithm := cfg.Algorithm
|
||||||
|
if algorithm == "" {
|
||||||
|
algorithm = "token_bucket" // 默认令牌桶
|
||||||
|
}
|
||||||
|
|
||||||
|
switch algorithm {
|
||||||
|
case "token_bucket", "":
|
||||||
|
return newTokenBucketLimiter(cfg)
|
||||||
|
case "sliding_window":
|
||||||
|
window := time.Duration(cfg.SlidingWindow) * time.Second
|
||||||
|
if window <= 0 {
|
||||||
|
window = time.Second // 默认 1 秒窗口
|
||||||
|
}
|
||||||
|
precise := cfg.SlidingWindowMode == "precise"
|
||||||
|
return NewSlidingWindowLimiterWrapper(cfg, window, precise)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown algorithm: %s", algorithm)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// newTokenBucketLimiter 创建令牌桶限流器。
|
||||||
|
func newTokenBucketLimiter(cfg *config.RateLimitConfig) (*RateLimiter, error) {
|
||||||
if cfg.Burst < cfg.RequestRate {
|
if cfg.Burst < cfg.RequestRate {
|
||||||
return nil, errors.New("burst must be at least equal to request rate")
|
return nil, errors.New("burst must be at least equal to request rate")
|
||||||
}
|
}
|
||||||
@ -114,6 +137,49 @@ func NewRateLimiter(cfg *config.RateLimitConfig) (*RateLimiter, error) {
|
|||||||
return rl, nil
|
return rl, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SlidingWindowLimiterWrapper 滑动窗口限流器包装,实现 middleware.Middleware 接口。
|
||||||
|
type SlidingWindowLimiterWrapper struct {
|
||||||
|
limiter *SlidingWindowLimiter
|
||||||
|
keyFunc KeyFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSlidingWindowLimiterWrapper 创建滑动窗口限流器包装。
|
||||||
|
func NewSlidingWindowLimiterWrapper(cfg *config.RateLimitConfig, window time.Duration, precise bool) (*SlidingWindowLimiterWrapper, error) {
|
||||||
|
var keyFunc KeyFunc
|
||||||
|
switch cfg.Key {
|
||||||
|
case "ip", "":
|
||||||
|
keyFunc = keyByIP
|
||||||
|
case "header":
|
||||||
|
keyFunc = keyByHeader
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown key type: %s", cfg.Key)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &SlidingWindowLimiterWrapper{
|
||||||
|
limiter: NewSlidingWindowLimiter(window, cfg.RequestRate, precise),
|
||||||
|
keyFunc: keyFunc,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name 返回中间件名称。
|
||||||
|
func (s *SlidingWindowLimiterWrapper) Name() string {
|
||||||
|
return "sliding_window_limiter"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process 包装下一个处理器,添加限流逻辑。
|
||||||
|
func (s *SlidingWindowLimiterWrapper) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
||||||
|
return func(ctx *fasthttp.RequestCtx) {
|
||||||
|
key := s.keyFunc(ctx)
|
||||||
|
|
||||||
|
if !s.limiter.Allow(key) {
|
||||||
|
ctx.Error("Too Many Requests", fasthttp.StatusTooManyRequests)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Name 返回中间件名称。
|
// Name 返回中间件名称。
|
||||||
//
|
//
|
||||||
// 返回值:
|
// 返回值:
|
||||||
|
|||||||
@ -83,7 +83,7 @@ func TestNewRateLimiter(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRateLimiterAllow(t *testing.T) {
|
func TestRateLimiterAllow(t *testing.T) {
|
||||||
rl, err := NewRateLimiter(&config.RateLimitConfig{
|
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||||||
RequestRate: 10,
|
RequestRate: 10,
|
||||||
Burst: 10,
|
Burst: 10,
|
||||||
})
|
})
|
||||||
@ -91,6 +91,11 @@ func TestRateLimiterAllow(t *testing.T) {
|
|||||||
t.Fatalf("NewRateLimiter() error: %v", err)
|
t.Fatalf("NewRateLimiter() error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rl, ok := mw.(*RateLimiter)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected *RateLimiter, got %T", mw)
|
||||||
|
}
|
||||||
|
|
||||||
// Test burst allowance
|
// Test burst allowance
|
||||||
key := "test-key"
|
key := "test-key"
|
||||||
|
|
||||||
@ -108,7 +113,7 @@ func TestRateLimiterAllow(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRateLimiterTokenRefill(t *testing.T) {
|
func TestRateLimiterTokenRefill(t *testing.T) {
|
||||||
rl, err := NewRateLimiter(&config.RateLimitConfig{
|
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||||||
RequestRate: 100, // 100 tokens per second
|
RequestRate: 100, // 100 tokens per second
|
||||||
Burst: 100,
|
Burst: 100,
|
||||||
})
|
})
|
||||||
@ -116,6 +121,11 @@ func TestRateLimiterTokenRefill(t *testing.T) {
|
|||||||
t.Fatalf("NewRateLimiter() error: %v", err)
|
t.Fatalf("NewRateLimiter() error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rl, ok := mw.(*RateLimiter)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected *RateLimiter, got %T", mw)
|
||||||
|
}
|
||||||
|
|
||||||
key := "refill-test"
|
key := "refill-test"
|
||||||
|
|
||||||
// Exhaust the burst
|
// Exhaust the burst
|
||||||
@ -138,7 +148,7 @@ func TestRateLimiterTokenRefill(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRateLimiterReset(t *testing.T) {
|
func TestRateLimiterReset(t *testing.T) {
|
||||||
rl, err := NewRateLimiter(&config.RateLimitConfig{
|
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||||||
RequestRate: 1,
|
RequestRate: 1,
|
||||||
Burst: 1,
|
Burst: 1,
|
||||||
})
|
})
|
||||||
@ -146,6 +156,11 @@ func TestRateLimiterReset(t *testing.T) {
|
|||||||
t.Fatalf("NewRateLimiter() error: %v", err)
|
t.Fatalf("NewRateLimiter() error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rl, ok := mw.(*RateLimiter)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected *RateLimiter, got %T", mw)
|
||||||
|
}
|
||||||
|
|
||||||
key := "reset-test"
|
key := "reset-test"
|
||||||
|
|
||||||
// Exhaust
|
// Exhaust
|
||||||
@ -164,7 +179,7 @@ func TestRateLimiterReset(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRateLimiterResetAll(t *testing.T) {
|
func TestRateLimiterResetAll(t *testing.T) {
|
||||||
rl, err := NewRateLimiter(&config.RateLimitConfig{
|
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||||||
RequestRate: 1,
|
RequestRate: 1,
|
||||||
Burst: 1,
|
Burst: 1,
|
||||||
})
|
})
|
||||||
@ -172,6 +187,11 @@ func TestRateLimiterResetAll(t *testing.T) {
|
|||||||
t.Fatalf("NewRateLimiter() error: %v", err)
|
t.Fatalf("NewRateLimiter() error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rl, ok := mw.(*RateLimiter)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected *RateLimiter, got %T", mw)
|
||||||
|
}
|
||||||
|
|
||||||
// Create multiple buckets
|
// Create multiple buckets
|
||||||
rl.Allow("key1")
|
rl.Allow("key1")
|
||||||
rl.Allow("key2")
|
rl.Allow("key2")
|
||||||
@ -186,7 +206,7 @@ func TestRateLimiterResetAll(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRateLimiterCleanup(t *testing.T) {
|
func TestRateLimiterCleanup(t *testing.T) {
|
||||||
rl, err := NewRateLimiter(&config.RateLimitConfig{
|
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||||||
RequestRate: 100,
|
RequestRate: 100,
|
||||||
Burst: 100,
|
Burst: 100,
|
||||||
})
|
})
|
||||||
@ -194,6 +214,11 @@ func TestRateLimiterCleanup(t *testing.T) {
|
|||||||
t.Fatalf("NewRateLimiter() error: %v", err)
|
t.Fatalf("NewRateLimiter() error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rl, ok := mw.(*RateLimiter)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected *RateLimiter, got %T", mw)
|
||||||
|
}
|
||||||
|
|
||||||
// Create some buckets
|
// Create some buckets
|
||||||
rl.Allow("key1")
|
rl.Allow("key1")
|
||||||
rl.Allow("key2")
|
rl.Allow("key2")
|
||||||
@ -208,7 +233,7 @@ func TestRateLimiterCleanup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRateLimiterProcess(t *testing.T) {
|
func TestRateLimiterProcess(t *testing.T) {
|
||||||
rl, err := NewRateLimiter(&config.RateLimitConfig{
|
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||||||
RequestRate: 100,
|
RequestRate: 100,
|
||||||
Burst: 100,
|
Burst: 100,
|
||||||
})
|
})
|
||||||
@ -220,14 +245,14 @@ func TestRateLimiterProcess(t *testing.T) {
|
|||||||
ctx.WriteString("OK")
|
ctx.WriteString("OK")
|
||||||
}
|
}
|
||||||
|
|
||||||
handler := rl.Process(nextHandler)
|
handler := mw.Process(nextHandler)
|
||||||
if handler == nil {
|
if handler == nil {
|
||||||
t.Error("Process() returned nil handler")
|
t.Error("Process() returned nil handler")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRateLimiterGetStats(t *testing.T) {
|
func TestRateLimiterGetStats(t *testing.T) {
|
||||||
rl, err := NewRateLimiter(&config.RateLimitConfig{
|
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||||||
RequestRate: 100,
|
RequestRate: 100,
|
||||||
Burst: 200,
|
Burst: 200,
|
||||||
})
|
})
|
||||||
@ -235,6 +260,11 @@ func TestRateLimiterGetStats(t *testing.T) {
|
|||||||
t.Fatalf("NewRateLimiter() error: %v", err)
|
t.Fatalf("NewRateLimiter() error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rl, ok := mw.(*RateLimiter)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected *RateLimiter, got %T", mw)
|
||||||
|
}
|
||||||
|
|
||||||
rl.Allow("key1")
|
rl.Allow("key1")
|
||||||
rl.Allow("key2")
|
rl.Allow("key2")
|
||||||
|
|
||||||
|
|||||||
@ -64,11 +64,12 @@ type Proxy struct {
|
|||||||
// 参数:
|
// 参数:
|
||||||
// - cfg: 代理配置,包括超时时间、请求头和负载均衡策略
|
// - cfg: 代理配置,包括超时时间、请求头和负载均衡策略
|
||||||
// - targets: 要代理请求的后端目标列表
|
// - targets: 要代理请求的后端目标列表
|
||||||
|
// - transportCfg: 可选的 Transport 连接池配置,nil 时使用默认值
|
||||||
//
|
//
|
||||||
// 返回值:
|
// 返回值:
|
||||||
// - *Proxy: 配置完成并可处理请求的代理实例
|
// - *Proxy: 配置完成并可处理请求的代理实例
|
||||||
// - error: 初始化失败时非空(无效配置、没有健康目标等)
|
// - error: 初始化失败时非空(无效配置、没有健康目标等)
|
||||||
func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target) (*Proxy, error) {
|
func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target, transportCfg *config.TransportConfig) (*Proxy, error) {
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
return nil, errors.New("proxy config is nil")
|
return nil, errors.New("proxy config is nil")
|
||||||
}
|
}
|
||||||
@ -78,7 +79,7 @@ func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target) (*Proxy, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 根据配置创建负载均衡器
|
// 根据配置创建负载均衡器
|
||||||
balancer, err := createBalancer(cfg.LoadBalance)
|
balancer, err := createBalancer(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -96,7 +97,7 @@ func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target) (*Proxy, e
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
client := createHostClient(target.URL, cfg.Timeout)
|
client := createHostClient(target.URL, cfg.Timeout, transportCfg)
|
||||||
p.clients[target.URL] = client
|
p.clients[target.URL] = client
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -122,8 +123,8 @@ func (p *Proxy) SetHealthChecker(hc *HealthChecker) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// createBalancer 根据配置的算法创建负载均衡器。
|
// createBalancer 根据配置的算法创建负载均衡器。
|
||||||
func createBalancer(algorithm string) (loadbalance.Balancer, error) {
|
func createBalancer(cfg *config.ProxyConfig) (loadbalance.Balancer, error) {
|
||||||
switch algorithm {
|
switch cfg.LoadBalance {
|
||||||
case "round_robin", "":
|
case "round_robin", "":
|
||||||
return loadbalance.NewRoundRobin(), nil
|
return loadbalance.NewRoundRobin(), nil
|
||||||
case "weighted_round_robin":
|
case "weighted_round_robin":
|
||||||
@ -132,13 +133,19 @@ func createBalancer(algorithm string) (loadbalance.Balancer, error) {
|
|||||||
return loadbalance.NewLeastConnections(), nil
|
return loadbalance.NewLeastConnections(), nil
|
||||||
case "ip_hash":
|
case "ip_hash":
|
||||||
return loadbalance.NewIPHash(), nil
|
return loadbalance.NewIPHash(), nil
|
||||||
|
case "consistent_hash":
|
||||||
|
virtualNodes := cfg.VirtualNodes
|
||||||
|
if virtualNodes <= 0 {
|
||||||
|
virtualNodes = 150
|
||||||
|
}
|
||||||
|
return loadbalance.NewConsistentHash(virtualNodes, cfg.HashKey), nil
|
||||||
default:
|
default:
|
||||||
return nil, errors.New("unsupported load balance algorithm: " + algorithm)
|
return nil, errors.New("unsupported load balance algorithm: " + cfg.LoadBalance)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// createHostClient 为后台目标 URL 创建 fasthttp.HostClient。
|
// createHostClient 为后台目标 URL 创建 fasthttp.HostClient。
|
||||||
func createHostClient(targetURL string, timeout config.ProxyTimeout) *fasthttp.HostClient {
|
func createHostClient(targetURL string, timeout config.ProxyTimeout, transportCfg *config.TransportConfig) *fasthttp.HostClient {
|
||||||
// 从目标 URL 解析主机和协议
|
// 从目标 URL 解析主机和协议
|
||||||
addr := targetURL
|
addr := targetURL
|
||||||
isTLS := false
|
isTLS := false
|
||||||
@ -155,13 +162,27 @@ func createHostClient(targetURL string, timeout config.ProxyTimeout) *fasthttp.H
|
|||||||
addr = addr[:idx]
|
addr = addr[:idx]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 默认值
|
||||||
|
maxIdleConnDuration := 90 * time.Second
|
||||||
|
maxConns := 100
|
||||||
|
|
||||||
|
// 应用 Transport 配置
|
||||||
|
if transportCfg != nil {
|
||||||
|
if transportCfg.IdleConnTimeout > 0 {
|
||||||
|
maxIdleConnDuration = transportCfg.IdleConnTimeout
|
||||||
|
}
|
||||||
|
if transportCfg.MaxConnsPerHost > 0 {
|
||||||
|
maxConns = transportCfg.MaxConnsPerHost
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
client := &fasthttp.HostClient{
|
client := &fasthttp.HostClient{
|
||||||
Addr: addr,
|
Addr: addr,
|
||||||
IsTLS: isTLS,
|
IsTLS: isTLS,
|
||||||
ReadTimeout: timeout.Read,
|
ReadTimeout: timeout.Read,
|
||||||
WriteTimeout: timeout.Write,
|
WriteTimeout: timeout.Write,
|
||||||
MaxIdleConnDuration: 60 * time.Second,
|
MaxIdleConnDuration: maxIdleConnDuration,
|
||||||
MaxConns: 100,
|
MaxConns: maxConns,
|
||||||
MaxConnWaitTimeout: timeout.Connect,
|
MaxConnWaitTimeout: timeout.Connect,
|
||||||
RetryIf: nil, // Disable automatic retries
|
RetryIf: nil, // Disable automatic retries
|
||||||
DisablePathNormalizing: false,
|
DisablePathNormalizing: false,
|
||||||
@ -389,13 +410,13 @@ func (p *Proxy) UpdateTargets(targets []*loadbalance.Target) error {
|
|||||||
// 清除旧客户端
|
// 清除旧客户端
|
||||||
p.clients = make(map[string]*fasthttp.HostClient)
|
p.clients = make(map[string]*fasthttp.HostClient)
|
||||||
|
|
||||||
// 初始化新客户端
|
// 初始化新客户端(使用 nil TransportConfig 保持原有行为)
|
||||||
for _, target := range targets {
|
for _, target := range targets {
|
||||||
if target.URL == "" {
|
if target.URL == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
client := createHostClient(target.URL, p.config.Timeout)
|
client := createHostClient(target.URL, p.config.Timeout, nil)
|
||||||
p.clients[target.URL] = client
|
p.clients[target.URL] = client
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -121,7 +121,7 @@ func TestNewProxy(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
p, err := NewProxy(tt.cfg, tt.targets)
|
p, err := NewProxy(tt.cfg, tt.targets, nil)
|
||||||
if tt.wantErr {
|
if tt.wantErr {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("NewProxy() expected error containing %q, got nil", tt.errContains)
|
t.Errorf("NewProxy() expected error containing %q, got nil", tt.errContains)
|
||||||
@ -166,7 +166,7 @@ func TestServeHTTP_NoHealthyTargets(t *testing.T) {
|
|||||||
targets[0].Healthy.Store(false)
|
targets[0].Healthy.Store(false)
|
||||||
targets[1].Healthy.Store(false)
|
targets[1].Healthy.Store(false)
|
||||||
|
|
||||||
p, err := NewProxy(cfg, targets)
|
p, err := NewProxy(cfg, targets, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewProxy() error: %v", err)
|
t.Fatalf("NewProxy() error: %v", err)
|
||||||
}
|
}
|
||||||
@ -216,7 +216,7 @@ func TestServeHTTP_RequestForwarding(t *testing.T) {
|
|||||||
{URL: "http://localhost:8080"},
|
{URL: "http://localhost:8080"},
|
||||||
}
|
}
|
||||||
|
|
||||||
p, err := NewProxy(cfg, targets)
|
p, err := NewProxy(cfg, targets, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewProxy() error: %v", err)
|
t.Fatalf("NewProxy() error: %v", err)
|
||||||
}
|
}
|
||||||
@ -308,7 +308,7 @@ func TestSelectTarget(t *testing.T) {
|
|||||||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||||
}
|
}
|
||||||
|
|
||||||
p, err := NewProxy(cfg, tt.targets)
|
p, err := NewProxy(cfg, tt.targets, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewProxy() error: %v", err)
|
t.Fatalf("NewProxy() error: %v", err)
|
||||||
}
|
}
|
||||||
@ -425,7 +425,7 @@ func TestModifyRequestHeaders(t *testing.T) {
|
|||||||
{URL: "http://localhost:8080"},
|
{URL: "http://localhost:8080"},
|
||||||
}
|
}
|
||||||
|
|
||||||
p, err := NewProxy(cfg, targets)
|
p, err := NewProxy(cfg, targets, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewProxy() error: %v", err)
|
t.Fatalf("NewProxy() error: %v", err)
|
||||||
}
|
}
|
||||||
@ -520,7 +520,7 @@ func TestModifyResponseHeaders(t *testing.T) {
|
|||||||
{URL: "http://localhost:8080"},
|
{URL: "http://localhost:8080"},
|
||||||
}
|
}
|
||||||
|
|
||||||
p, err := NewProxy(cfg, targets)
|
p, err := NewProxy(cfg, targets, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewProxy() error: %v", err)
|
t.Fatalf("NewProxy() error: %v", err)
|
||||||
}
|
}
|
||||||
@ -608,7 +608,7 @@ func TestUpdateTargets(t *testing.T) {
|
|||||||
{URL: "http://old2:8080"},
|
{URL: "http://old2:8080"},
|
||||||
}
|
}
|
||||||
|
|
||||||
p, err := NewProxy(cfg, initialTargets)
|
p, err := NewProxy(cfg, initialTargets, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewProxy() error: %v", err)
|
t.Fatalf("NewProxy() error: %v", err)
|
||||||
}
|
}
|
||||||
@ -657,7 +657,7 @@ func TestGetTargets(t *testing.T) {
|
|||||||
{URL: "http://backend2:8080"},
|
{URL: "http://backend2:8080"},
|
||||||
}
|
}
|
||||||
|
|
||||||
p, err := NewProxy(cfg, targets)
|
p, err := NewProxy(cfg, targets, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewProxy() error: %v", err)
|
t.Fatalf("NewProxy() error: %v", err)
|
||||||
}
|
}
|
||||||
@ -686,7 +686,7 @@ func TestGetConfig(t *testing.T) {
|
|||||||
{URL: "http://localhost:8080"},
|
{URL: "http://localhost:8080"},
|
||||||
}
|
}
|
||||||
|
|
||||||
p, err := NewProxy(cfg, targets)
|
p, err := NewProxy(cfg, targets, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewProxy() error: %v", err)
|
t.Fatalf("NewProxy() error: %v", err)
|
||||||
}
|
}
|
||||||
@ -763,38 +763,37 @@ func TestIsWebSocketRequest(t *testing.T) {
|
|||||||
func TestCreateBalancer(t *testing.T) {
|
func TestCreateBalancer(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
algorithm string
|
cfg *config.ProxyConfig
|
||||||
wantErr bool
|
wantErr bool
|
||||||
errContains string
|
errContains string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "轮询",
|
name: "轮询",
|
||||||
algorithm: "round_robin",
|
cfg: &config.ProxyConfig{LoadBalance: "round_robin"},
|
||||||
wantErr: false,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "加权轮询",
|
name: "加权轮询",
|
||||||
algorithm: "weighted_round_robin",
|
cfg: &config.ProxyConfig{LoadBalance: "weighted_round_robin"},
|
||||||
wantErr: false,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "最少连接",
|
name: "最少连接",
|
||||||
algorithm: "least_conn",
|
cfg: &config.ProxyConfig{LoadBalance: "least_conn"},
|
||||||
wantErr: false,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "IP哈希",
|
name: "IP哈希",
|
||||||
algorithm: "ip_hash",
|
cfg: &config.ProxyConfig{LoadBalance: "ip_hash"},
|
||||||
wantErr: false,
|
},
|
||||||
|
{
|
||||||
|
name: "一致性哈希",
|
||||||
|
cfg: &config.ProxyConfig{LoadBalance: "consistent_hash", HashKey: "ip", VirtualNodes: 150},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "空算法(默认轮询)",
|
name: "空算法(默认轮询)",
|
||||||
algorithm: "",
|
cfg: &config.ProxyConfig{LoadBalance: ""},
|
||||||
wantErr: false,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "无效算法",
|
name: "无效算法",
|
||||||
algorithm: "unknown_algorithm",
|
cfg: &config.ProxyConfig{LoadBalance: "unknown_algorithm"},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
errContains: "unsupported load balance algorithm",
|
errContains: "unsupported load balance algorithm",
|
||||||
},
|
},
|
||||||
@ -802,23 +801,23 @@ func TestCreateBalancer(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
balancer, err := createBalancer(tt.algorithm)
|
balancer, err := createBalancer(tt.cfg)
|
||||||
if tt.wantErr {
|
if tt.wantErr {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("createBalancer(%q) expected error", tt.algorithm)
|
t.Errorf("createBalancer(%v) expected error", tt.cfg.LoadBalance)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !contains(err.Error(), tt.errContains) {
|
if !contains(err.Error(), tt.errContains) {
|
||||||
t.Errorf("createBalancer(%q) error = %v, want containing %q", tt.algorithm, err, tt.errContains)
|
t.Errorf("createBalancer(%v) error = %v, want containing %q", tt.cfg.LoadBalance, err, tt.errContains)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("createBalancer(%q) unexpected error: %v", tt.algorithm, err)
|
t.Errorf("createBalancer(%v) unexpected error: %v", tt.cfg.LoadBalance, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if balancer == nil {
|
if balancer == nil {
|
||||||
t.Errorf("createBalancer(%q) returned nil balancer", tt.algorithm)
|
t.Errorf("createBalancer(%v) returned nil balancer", tt.cfg.LoadBalance)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -850,7 +849,7 @@ func TestCreateHostClient(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
client := createHostClient(tt.targetURL, tt.timeout)
|
client := createHostClient(tt.targetURL, tt.timeout, nil)
|
||||||
if client == nil {
|
if client == nil {
|
||||||
t.Error("createHostClient() returned nil")
|
t.Error("createHostClient() returned nil")
|
||||||
return
|
return
|
||||||
@ -882,7 +881,7 @@ func TestHandleWebSocket(t *testing.T) {
|
|||||||
{URL: "http://localhost:8080"},
|
{URL: "http://localhost:8080"},
|
||||||
}
|
}
|
||||||
|
|
||||||
p, err := NewProxy(cfg, targets)
|
p, err := NewProxy(cfg, targets, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewProxy() error: %v", err)
|
t.Fatalf("NewProxy() error: %v", err)
|
||||||
}
|
}
|
||||||
@ -916,7 +915,7 @@ func TestSetHealthChecker(t *testing.T) {
|
|||||||
{URL: "http://localhost:8081"},
|
{URL: "http://localhost:8081"},
|
||||||
}
|
}
|
||||||
|
|
||||||
p, err := NewProxy(cfg, targets)
|
p, err := NewProxy(cfg, targets, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewProxy() error: %v", err)
|
t.Fatalf("NewProxy() error: %v", err)
|
||||||
}
|
}
|
||||||
@ -954,7 +953,7 @@ func TestGetClient(t *testing.T) {
|
|||||||
{URL: "http://localhost:8082"},
|
{URL: "http://localhost:8082"},
|
||||||
}
|
}
|
||||||
|
|
||||||
p, err := NewProxy(cfg, targets)
|
p, err := NewProxy(cfg, targets, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewProxy() error: %v", err)
|
t.Fatalf("NewProxy() error: %v", err)
|
||||||
}
|
}
|
||||||
@ -1018,7 +1017,7 @@ func TestProxyCache(t *testing.T) {
|
|||||||
}
|
}
|
||||||
targets[0].Healthy.Store(true)
|
targets[0].Healthy.Store(true)
|
||||||
|
|
||||||
p, err := NewProxy(cfg, targets)
|
p, err := NewProxy(cfg, targets, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewProxy() error: %v", err)
|
t.Fatalf("NewProxy() error: %v", err)
|
||||||
}
|
}
|
||||||
@ -1069,7 +1068,7 @@ func TestServeHTTP_WithPassiveHealthCheck(t *testing.T) {
|
|||||||
}
|
}
|
||||||
targets[0].Healthy.Store(true)
|
targets[0].Healthy.Store(true)
|
||||||
|
|
||||||
p, err := NewProxy(cfg, targets)
|
p, err := NewProxy(cfg, targets, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewProxy() error: %v", err)
|
t.Fatalf("NewProxy() error: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -251,7 +251,7 @@ func (s *Server) buildMiddlewareChain(serverCfg *config.ServerConfig) (*middlewa
|
|||||||
serverCfg.Security.Headers.ContentSecurityPolicy != "" ||
|
serverCfg.Security.Headers.ContentSecurityPolicy != "" ||
|
||||||
serverCfg.Security.Headers.ReferrerPolicy != "" ||
|
serverCfg.Security.Headers.ReferrerPolicy != "" ||
|
||||||
serverCfg.Security.Headers.PermissionsPolicy != "" {
|
serverCfg.Security.Headers.PermissionsPolicy != "" {
|
||||||
headers := security.NewSecurityHeaders(&serverCfg.Security.Headers)
|
headers := security.NewSecurityHeadersWithHSTS(&serverCfg.Security.Headers, &serverCfg.SSL.HSTS)
|
||||||
middlewares = append(middlewares, headers)
|
middlewares = append(middlewares, headers)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -534,7 +534,8 @@ func (s *Server) registerProxyRoutes(router *handler.Router, serverCfg *config.S
|
|||||||
targets[j].Healthy.Store(true)
|
targets[j].Healthy.Store(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
p, err := proxy.NewProxy(proxyCfg, targets)
|
// 传递 Transport 配置
|
||||||
|
p, err := proxy.NewProxy(proxyCfg, targets, &s.config.Performance.Transport)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Error().Msg("创建代理失败: " + err.Error())
|
logging.Error().Msg("创建代理失败: " + err.Error())
|
||||||
continue
|
continue
|
||||||
|
|||||||
@ -69,14 +69,7 @@ func (v *VHostManager) SetDefault(handler fasthttp.RequestHandler) {
|
|||||||
// Handler 返回虚拟主机选择器
|
// Handler 返回虚拟主机选择器
|
||||||
func (v *VHostManager) Handler() fasthttp.RequestHandler {
|
func (v *VHostManager) Handler() fasthttp.RequestHandler {
|
||||||
return func(ctx *fasthttp.RequestCtx) {
|
return func(ctx *fasthttp.RequestCtx) {
|
||||||
host := string(ctx.Host())
|
host := stripPort(string(ctx.Host()))
|
||||||
// 去除端口号
|
|
||||||
for i := 0; i < len(host); i++ {
|
|
||||||
if host[i] == ':' {
|
|
||||||
host = host[:i]
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if vhost, ok := v.hosts[host]; ok {
|
if vhost, ok := v.hosts[host]; ok {
|
||||||
vhost.handler(ctx)
|
vhost.handler(ctx)
|
||||||
@ -87,3 +80,38 @@ func (v *VHostManager) Handler() fasthttp.RequestHandler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// stripPort 从 Host 头中移除端口号。
|
||||||
|
//
|
||||||
|
// 支持 IPv4 和 IPv6 格式:
|
||||||
|
// - example.com:8080 -> example.com
|
||||||
|
// - [::1]:8080 -> [::1]
|
||||||
|
// - [2001:db8::1]:443 -> [2001:db8::1]
|
||||||
|
// - example.com -> example.com
|
||||||
|
func stripPort(host string) string {
|
||||||
|
// 空字符串直接返回
|
||||||
|
if len(host) == 0 {
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPv6 格式:以 '[' 开头,找 ']:' 作为分隔点
|
||||||
|
if host[0] == '[' {
|
||||||
|
// 查找 ']:' 分隔符
|
||||||
|
for i := 0; i < len(host)-1; i++ {
|
||||||
|
if host[i] == ']' && host[i+1] == ':' {
|
||||||
|
return host[:i+1] // 返回包含 ']' 的部分,如 "[::1]"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 没有 ']:' 分隔符,可能是纯 IPv6 地址(如 "[::1]")
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPv4 或域名格式:找第一个 ':' 作为分隔点
|
||||||
|
for i := 0; i < len(host); i++ {
|
||||||
|
if host[i] == ':' {
|
||||||
|
return host[:i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|||||||
@ -99,13 +99,9 @@ func TestVHostManager_Handler(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("IPv6地址Host", func(t *testing.T) {
|
t.Run("IPv6地址Host", func(t *testing.T) {
|
||||||
// TODO: 当前 vhost.go 的端口剥离逻辑不支持 IPv6 格式 [::1]:8080
|
|
||||||
// 它会错误地在第一个 ':' 处截断(IPv6 地址内部的冒号)
|
|
||||||
// 修复方案:检查 host 是否以 '[' 开头,找 ']:' 作为分隔点
|
|
||||||
manager := NewVHostManager()
|
manager := NewVHostManager()
|
||||||
ipv6Called := false
|
ipv6Called := false
|
||||||
manager.AddHost("[::1]", mockHandler("ipv6", &ipv6Called))
|
manager.AddHost("[::1]", mockHandler("ipv6", &ipv6Called))
|
||||||
manager.SetDefault(mockHandler("default", &ipv6Called)) // fallback
|
|
||||||
|
|
||||||
handler := manager.Handler()
|
handler := manager.Handler()
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
@ -113,9 +109,12 @@ func TestVHostManager_Handler(t *testing.T) {
|
|||||||
|
|
||||||
handler(ctx)
|
handler(ctx)
|
||||||
|
|
||||||
// 当前实现不支持 IPv6,会 fallback 到默认 handler
|
if !ipv6Called {
|
||||||
// 修复 vhost.go 后此测试应验证 ipv6Called 为 true
|
t.Error("期望 [::1] handler 被调用,但未被调用")
|
||||||
t.Log("注意: 当前实现不支持 IPv6 地址,需要修复 vhost.go 的端口剥离逻辑")
|
}
|
||||||
|
if string(ctx.Response.Body()) != "ipv6" {
|
||||||
|
t.Errorf("响应体 = %q, want %q", string(ctx.Response.Body()), "ipv6")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("空Host使用默认", func(t *testing.T) {
|
t.Run("空Host使用默认", func(t *testing.T) {
|
||||||
@ -270,6 +269,10 @@ func TestVHostManager_PortStripping(t *testing.T) {
|
|||||||
{"标准HTTPS端口", "example.com:443", "example.com"},
|
{"标准HTTPS端口", "example.com:443", "example.com"},
|
||||||
{"自定义端口", "example.com:8080", "example.com"},
|
{"自定义端口", "example.com:8080", "example.com"},
|
||||||
{"IPv6 localhost带端口", "[localhost]:8080", "[localhost]"},
|
{"IPv6 localhost带端口", "[localhost]:8080", "[localhost]"},
|
||||||
|
{"IPv6 loopback带端口", "[::1]:8080", "[::1]"},
|
||||||
|
{"IPv6完整地址带端口", "[2001:db8::1]:443", "[2001:db8::1]"},
|
||||||
|
{"IPv6无端口", "[::1]", "[::1]"},
|
||||||
|
{"IPv6完整地址无端口", "[2001:db8::1]", "[2001:db8::1]"},
|
||||||
{"空字符串", "", ""},
|
{"空字符串", "", ""},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -291,11 +294,8 @@ func TestVHostManager_PortStripping(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// IPv6 数字地址测试 - 当前实现有已知 bug
|
// IPv6 数字地址测试
|
||||||
t.Run("IPv6数字地址_已知限制", func(t *testing.T) {
|
t.Run("IPv6数字地址", func(t *testing.T) {
|
||||||
// TODO: vhost.go 的端口剥离逻辑不支持 IPv6 数字地址格式 [::1]:8080
|
|
||||||
// 因为它会在第一个 ':' 处截断(IPv6 地址内部的冒号)
|
|
||||||
// 结果:[:而不是 [::1]
|
|
||||||
manager := NewVHostManager()
|
manager := NewVHostManager()
|
||||||
ipv6Called := false
|
ipv6Called := false
|
||||||
manager.AddHost("[::1]", mockHandler("ipv6", &ipv6Called))
|
manager.AddHost("[::1]", mockHandler("ipv6", &ipv6Called))
|
||||||
@ -306,10 +306,11 @@ func TestVHostManager_PortStripping(t *testing.T) {
|
|||||||
|
|
||||||
handler(ctx)
|
handler(ctx)
|
||||||
|
|
||||||
// 当前行为:不匹配,因为端口剥离错误
|
if !ipv6Called {
|
||||||
if ipv6Called {
|
t.Error("期望 [::1] handler 被调用,但未被调用")
|
||||||
t.Error("当前实现不支持 IPv6 数字地址的端口剥离,不应匹配")
|
}
|
||||||
|
if string(ctx.Response.Body()) != "ipv6" {
|
||||||
|
t.Errorf("响应体 = %q, want %q", string(ctx.Response.Body()), "ipv6")
|
||||||
}
|
}
|
||||||
t.Log("已知限制: IPv6 数字地址端口剥离需要修复 vhost.go")
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user