feat(stream,server,handler): 实现 Phase 6 性能优化和热升级

新增功能:
- stream 模块: 流式传输支持,优化大文件和实时数据传输
- Goroutine 池: 限制并发数量,减少调度开销
- 优雅升级: 零停机热升级,继承父进程监听器
- sendfile: 零拷贝文件传输,大文件直接从内核传输

重构改进:
- App 结构体封装,支持热升级和信号处理
- 配置结构字段对齐和代码清理
- 完善错误处理和日志记录

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-03 10:39:22 +08:00
parent b517fecc86
commit 9d24263918
46 changed files with 2021 additions and 282 deletions

View File

@ -4,6 +4,11 @@
server:
listen: ":8080" # 监听地址
name: "localhost" # 服务器名称(虚拟主机匹配)
read_timeout: 30s # 读取超时0 表示不限制)
write_timeout: 30s # 写入超时0 表示不限制)
idle_timeout: 120s # 空闲超时0 表示不限制)
max_conns_per_ip: 1000 # 每 IP 最大连接数0 表示不限制)
max_requests_per_conn: 10000 # 每连接最大请求数0 表示不限制)
# 静态文件服务配置
static:
@ -80,10 +85,10 @@ server:
# 安全头部
headers:
x_frame_options: "DENY" # 防止点击劫持(有效值: DENY, SAMEORIGIN
x_frame_options: "DENY" # 防止点击劫持(有效值: DENY, SAMEORIGIN, 空表示禁用
x_content_type_options: "nosniff" # 防止 MIME 嗅探
referrer_policy: "strict-origin-when-cross-origin" # 引用策略
# content_security_policy: "default-src 'self'" # CSP(推荐配置)
referrer_policy: "strict-origin-when-cross-origin" # 引用策略(有效值: no-referrer, no-referrer-when-downgrade, origin, origin-when-cross-origin, same-origin, strict-origin, strict-origin-when-cross-origin, unsafe-url
# content_security_policy: "default-src 'self'" # 内容安全策略 CSP
# permissions_policy: "geolocation=(), microphone=()" # 权限策略
# URL 重写规则
@ -94,9 +99,9 @@ server:
# 响应压缩配置
compression:
type: "gzip" # 压缩类型: gzip, brotli, both
level: 6 # 压缩级别 (1-9)
min_size: 1024 # 最小压缩大小(字节
type: "gzip" # 压缩类型(有效值: gzip, brotli, both,空表示禁用)
level: 6 # 压缩级别(范围 1-9值越大压缩率越高但速度越慢
min_size: 1024 # 最小压缩大小(字节,小于此值不压缩
types: # 可压缩的 MIME 类型
- "text/html"
- "text/css"
@ -119,11 +124,11 @@ server:
# 日志配置
logging:
access:
format: "$remote_addr - $request - $status - $body_bytes_sent" # 日志格式
# path: /var/log/lolly/access.log # 日志文件路径
format: "$remote_addr - $request - $status - $body_bytes_sent" # 日志格式(支持变量: $remote_addr, $request, $status, $body_bytes_sent, $request_time, $http_referer, $http_user_agent
# path: /var/log/lolly/access.log # 日志文件路径(空表示输出到 stdout
error:
level: "info" # 日志级别: debug, info, warn, error
# path: /var/log/lolly/error.log
level: "info" # 日志级别(有效值: debug, info, warn, error,级别越高日志越少)
# path: /var/log/lolly/error.log # 日志文件路径(空表示输出到 stderr
# 性能配置
performance:

View File

@ -1433,7 +1433,7 @@ Phase 6:
| Phase 3 | ✅ 完成 | 反向代理、负载均衡 |
| Phase 4 | ✅ 完成 | SSL/TLS、安全控制 |
| Phase 5 | ✅ 完成 | 重写、压缩、缓存、日志 |
| Phase 6 | ⏳ 待开始 | Stream、性能优化 |
| Phase 6 | ✅ 完成 | Stream、性能优化、热升级 |
**Phase 2 技术选型变更**
- HTTP 库:使用 [fasthttp](https://github.com/valyala/fasthttp) 替代 `net/http`(性能提升 6 倍)

View File

@ -9,6 +9,7 @@ import (
"time"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/logging"
"rua.plus/lolly/internal/server"
)
@ -27,6 +28,33 @@ var (
shutdownTimeout = 30 * time.Second // 优雅停止超时时间
)
// App 应用程序结构。
type App struct {
cfgPath string
cfg *config.Config
srv *server.Server
upgradeMgr *server.UpgradeManager
pidFile string
logFile string // 日志文件路径(用于重新打开)
}
// NewApp 创建应用程序。
func NewApp(cfgPath string) *App {
return &App{
cfgPath: cfgPath,
}
}
// SetPidFile 设置 PID 文件路径。
func (a *App) SetPidFile(path string) {
a.pidFile = path
}
// SetLogFile 设置日志文件路径。
func (a *App) SetLogFile(path string) {
a.logFile = path
}
// Run 应用程序入口。
func Run(cfgPath string, genConfig bool, outputPath string, showVersion bool) int {
if genConfig {
@ -38,7 +66,8 @@ func Run(cfgPath string, genConfig bool, outputPath string, showVersion bool) in
return 0
}
return startServer(cfgPath)
app := NewApp(cfgPath)
return app.Run()
}
// generateConfig 生成默认配置文件。
@ -71,58 +100,159 @@ func printVersion() {
fmt.Printf(" Platform: %s\n", BuildPlatform)
}
// startServer 启动服务器。
func startServer(cfgPath string) int {
cfg, err := config.Load(cfgPath)
// Run 启动应用程序。
func (a *App) Run() int {
// 检查是否是子进程(热升级)
if os.Getenv("GRACEFUL_UPGRADE") == "1" {
fmt.Println("检测到热升级模式,继承父进程监听器")
}
cfg, err := config.Load(a.cfgPath)
if err != nil {
fmt.Fprintf(os.Stderr, "加载配置失败: %v\n", err)
return 1
}
a.cfg = cfg
fmt.Printf("配置加载成功: %s\n", cfgPath)
fmt.Printf("配置加载成功: %s\n", a.cfgPath)
fmt.Printf("监听地址: %s\n", cfg.Server.Listen)
// 创建服务器
srv := server.New(cfg)
a.srv = server.New(cfg)
// 启动信号监听
// 创建升级管理器
a.upgradeMgr = server.NewUpgradeManager(a.srv)
if a.pidFile != "" {
a.upgradeMgr.SetPidFile(a.pidFile)
a.upgradeMgr.WritePid()
}
// 启动信号处理
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan,
syscall.SIGTERM, // 快速停止kill 或 systemd stop
syscall.SIGINT, // 快速停止Ctrl+C
syscall.SIGQUIT, // 优雅停止
)
a.setupSignalHandlers(sigChan)
// 启动服务器(在 goroutine 中)
// 启动服务器
errChan := make(chan error, 1)
go func() {
fmt.Println("服务器启动中...")
if err := srv.Start(); err != nil {
if err := a.srv.Start(); err != nil {
errChan <- err
}
}()
// 等待信号或启动错误
select {
case err := <-errChan:
fmt.Fprintf(os.Stderr, "服务器启动失败: %v\n", err)
return 1
case sig := <-sigChan:
// 根据信号类型决定停止方式
switch sig {
case syscall.SIGQUIT:
// 优雅停止:等待请求完成
fmt.Printf("\n收到 SIGQUIT优雅停止等待 %v...\n", shutdownTimeout)
srv.GracefulStop(shutdownTimeout)
case syscall.SIGTERM, syscall.SIGINT:
// 快速停止
fmt.Printf("\n收到 %v停止服务器...\n", sigName(sig.(syscall.Signal)))
srv.Stop()
for {
select {
case err := <-errChan:
fmt.Fprintf(os.Stderr, "服务器启动失败: %v\n", err)
return 1
case sig := <-sigChan:
if !a.handleSignal(sig) {
// 返回 false 表示退出
fmt.Println("服务器已停止")
return 0
}
// 返回 true 表示继续运行(如重载配置)
}
}
}
fmt.Println("服务器已停止")
return 0
// setupSignalHandlers 设置信号处理。
func (a *App) setupSignalHandlers(sigChan chan<- os.Signal) {
signal.Notify(sigChan,
syscall.SIGTERM, // 快速停止kill 或 systemd stop
syscall.SIGINT, // 快速停止Ctrl+C
syscall.SIGQUIT, // 优雅停止
syscall.SIGHUP, // 重载配置
syscall.SIGUSR1, // 重新打开日志
syscall.SIGUSR2, // 热升级
)
}
// handleSignal 处理信号,返回 false 表示退出。
func (a *App) handleSignal(sig os.Signal) bool {
switch sig {
case syscall.SIGQUIT:
// 优雅停止:等待请求完成
fmt.Printf("\n收到 SIGQUIT优雅停止等待 %v...\n", shutdownTimeout)
a.srv.GracefulStop(shutdownTimeout)
return false
case syscall.SIGTERM, syscall.SIGINT:
// 快速停止
fmt.Printf("\n收到 %v停止服务器...\n", sigName(sig.(syscall.Signal)))
a.srv.Stop()
return false
case syscall.SIGHUP:
// 重载配置
fmt.Println("\n收到 SIGHUP重载配置...")
a.reloadConfig()
return true
case syscall.SIGUSR1:
// 重新打开日志
fmt.Println("\n收到 SIGUSR1重新打开日志...")
a.reopenLogs()
return true
case syscall.SIGUSR2:
// 热升级
fmt.Println("\n收到 SIGUSR2执行热升级...")
a.gracefulUpgrade()
return true
default:
fmt.Printf("\n收到未知信号: %v\n", sig)
return true
}
}
// reloadConfig 重载配置。
func (a *App) reloadConfig() {
newCfg, err := config.Load(a.cfgPath)
if err != nil {
fmt.Fprintf(os.Stderr, "重载配置失败: %v\n", err)
return
}
// 更新配置
a.cfg = newCfg
fmt.Println("配置重载成功")
// 注意:当前实现不重启服务器,仅更新配置
// 如需应用新配置,需要重启服务器或实现热更新
fmt.Println("配置已重新加载")
}
// reopenLogs 重新打开日志文件。
func (a *App) reopenLogs() {
// 重新初始化日志系统
if a.cfg != nil {
logging.Init(a.cfg.Logging.Error.Level, false)
}
fmt.Println("日志已重新打开")
}
// gracefulUpgrade 执行热升级。
func (a *App) gracefulUpgrade() {
// 获取当前可执行文件路径
execPath, err := os.Executable()
if err != nil {
fmt.Fprintf(os.Stderr, "获取可执行文件路径失败: %v\n", err)
return
}
// 执行升级
if err := a.upgradeMgr.GracefulUpgrade(execPath); err != nil {
fmt.Fprintf(os.Stderr, "热升级失败: %v\n", err)
return
}
fmt.Println("热升级已启动,新进程正在接管")
// 当前进程优雅停止
a.srv.GracefulStop(shutdownTimeout)
}
// sigName 返回信号名称(用于日志输出)。
@ -134,7 +264,13 @@ func sigName(sig syscall.Signal) string {
return "SIGINT"
case syscall.SIGQUIT:
return "SIGQUIT"
case syscall.SIGHUP:
return "SIGHUP"
case syscall.SIGUSR1:
return "SIGUSR1"
case syscall.SIGUSR2:
return "SIGUSR2"
default:
return fmt.Sprintf("Signal(%d)", sig)
}
}
}

View File

@ -56,13 +56,13 @@ func captureStderr(t *testing.T) (func() string, func()) {
// TestRun 测试 Run 函数的各种场景。
func TestRun(t *testing.T) {
tests := []struct {
name string
cfgPath string
genConfig bool
outputPath string
showVersion bool
wantExitCode int
wantContains string // stdout 应包含的内容
name string
cfgPath string
genConfig bool
outputPath string
showVersion bool
wantExitCode int
wantContains string // stdout 应包含的内容
wantErrContains string // stderr 应包含的内容(可选)
}{
{
@ -86,11 +86,11 @@ func TestRun(t *testing.T) {
wantContains: "配置已写入:",
},
{
name: "配置文件不存在",
cfgPath: filepath.Join(t.TempDir(), "nonexistent.yaml"),
genConfig: false,
showVersion: false,
wantExitCode: 1,
name: "配置文件不存在",
cfgPath: filepath.Join(t.TempDir(), "nonexistent.yaml"),
genConfig: false,
showVersion: false,
wantExitCode: 1,
wantErrContains: "加载配置失败",
},
}
@ -234,4 +234,4 @@ func TestPrintVersion(t *testing.T) {
t.Errorf("版本输出应包含 %q, 实际输出: %q", line, stdout)
}
}
}
}

View File

@ -344,4 +344,4 @@ func TestContainsInt(t *testing.T) {
if containsInt([]int{200, 301, 302}, 404) {
t.Error("Expected not to find 404")
}
}
}

View File

@ -10,23 +10,23 @@ import (
// FileEntry 文件缓存条目。
type FileEntry struct {
Path string // 文件路径
Size int64 // 文件大小
ModTime time.Time // 修改时间
LastAccess time.Time // 最后访问时间
Data []byte // 文件内容
Path string // 文件路径
Size int64 // 文件大小
ModTime time.Time // 修改时间
LastAccess time.Time // 最后访问时间
Data []byte // 文件内容
element *list.Element // LRU 链表元素
}
// FileCache 文件缓存,支持 LRU 淘汰。
type FileCache struct {
maxEntries int64 // 最大条目数
maxSize int64 // 内存上限(字节)
inactive time.Duration // 未访问淘汰时间
entries map[string]*FileEntry
lruList *list.List // LRU 链表
mu sync.RWMutex
currentSize int64 // 当前内存使用
maxEntries int64 // 最大条目数
maxSize int64 // 内存上限(字节)
inactive time.Duration // 未访问淘汰时间
entries map[string]*FileEntry
lruList *list.List // LRU 链表
mu sync.RWMutex
currentSize int64 // 当前内存使用
}
// NewFileCache 创建文件缓存。
@ -158,10 +158,10 @@ func (c *FileCache) Stats() FileCacheStats {
defer c.mu.RUnlock()
return FileCacheStats{
Entries: int64(len(c.entries)),
MaxEntries: c.maxEntries,
Size: c.currentSize,
MaxSize: c.maxSize,
Entries: int64(len(c.entries)),
MaxEntries: c.maxEntries,
Size: c.currentSize,
MaxSize: c.maxSize,
}
}
@ -183,22 +183,22 @@ type ProxyCacheRule struct {
// ProxyCacheEntry 代理缓存条目。
type ProxyCacheEntry struct {
Key string // 缓存 key
Data []byte // 响应体
Headers map[string]string // 响应头
Status int // 状态码
Created time.Time // 创建时间
MaxAge time.Duration // 有效期
Key string // 缓存 key
Data []byte // 响应体
Headers map[string]string // 响应头
Status int // 状态码
Created time.Time // 创建时间
MaxAge time.Duration // 有效期
}
// ProxyCache 代理响应缓存,支持缓存锁防击穿。
type ProxyCache struct {
rules []ProxyCacheRule
entries map[string]*ProxyCacheEntry
mu sync.RWMutex
cacheLock bool // 缓存锁开关
pending map[string]*pendingRequest // 正在生成的缓存项
staleTime time.Duration // 过期缓存复用时间
rules []ProxyCacheRule
entries map[string]*ProxyCacheEntry
mu sync.RWMutex
cacheLock bool // 缓存锁开关
pending map[string]*pendingRequest // 正在生成的缓存项
staleTime time.Duration // 过期缓存复用时间
}
// pendingRequest 等待中的缓存请求。
@ -401,4 +401,4 @@ func (c *ProxyCache) Stats() ProxyCacheStats {
type ProxyCacheStats struct {
Entries int
Pending int
}
}

View File

@ -45,13 +45,13 @@ type StaticConfig struct {
// ProxyConfig 反向代理配置,支持负载均衡和健康检查。
type ProxyConfig struct {
Path string `yaml:"path"` // 匹配路径前缀
Targets []ProxyTarget `yaml:"targets"` // 后端目标列表
LoadBalance string `yaml:"load_balance"` // 负载均衡算法round_robin, weighted_round_robin, least_conn, ip_hash
Path string `yaml:"path"` // 匹配路径前缀
Targets []ProxyTarget `yaml:"targets"` // 后端目标列表
LoadBalance string `yaml:"load_balance"` // 负载均衡算法round_robin, weighted_round_robin, least_conn, ip_hash
HealthCheck HealthCheckConfig `yaml:"health_check"` // 健康检查配置
Timeout ProxyTimeout `yaml:"timeout"` // 超时配置
Headers ProxyHeaders `yaml:"headers"` // 请求/响应头修改
Cache ProxyCacheConfig `yaml:"cache"` // 代理缓存配置
Timeout ProxyTimeout `yaml:"timeout"` // 超时配置
Headers ProxyHeaders `yaml:"headers"` // 请求/响应头修改
Cache ProxyCacheConfig `yaml:"cache"` // 代理缓存配置
}
// ProxyTarget 后端目标配置。
@ -83,21 +83,21 @@ type ProxyHeaders struct {
// ProxyCacheConfig 代理缓存配置。
type ProxyCacheConfig struct {
Enabled bool `yaml:"enabled"` // 是否启用缓存
MaxAge time.Duration `yaml:"max_age"` // 缓存有效期
CacheLock bool `yaml:"cache_lock"` // 缓存锁,防止击穿
Enabled bool `yaml:"enabled"` // 是否启用缓存
MaxAge time.Duration `yaml:"max_age"` // 缓存有效期
CacheLock bool `yaml:"cache_lock"` // 缓存锁,防止击穿
StaleWhileRevalidate time.Duration `yaml:"stale_while_revalidate"` // 过期缓存复用时间
}
// SSLConfig SSL/TLS 配置。
type SSLConfig struct {
Cert string `yaml:"cert"` // 证书文件路径
Key string `yaml:"key"` // 私钥文件路径
CertChain string `yaml:"cert_chain"` // 证书链文件路径
Protocols []string `yaml:"protocols"` // TLS 版本,默认 ["TLSv1.2", "TLSv1.3"]
Ciphers []string `yaml:"ciphers"` // 加密套件(仅 TLS 1.2 有效)
OCSPStapling bool `yaml:"ocsp_stapling"` // OCSP Stapling 支持
HSTS HSTSConfig `yaml:"hsts"` // HSTS 配置
Cert string `yaml:"cert"` // 证书文件路径
Key string `yaml:"key"` // 私钥文件路径
CertChain string `yaml:"cert_chain"` // 证书链文件路径
Protocols []string `yaml:"protocols"` // TLS 版本,默认 ["TLSv1.2", "TLSv1.3"]
Ciphers []string `yaml:"ciphers"` // 加密套件(仅 TLS 1.2 有效)
OCSPStapling bool `yaml:"ocsp_stapling"` // OCSP Stapling 支持
HSTS HSTSConfig `yaml:"hsts"` // HSTS 配置
}
// HSTSConfig HTTP Strict Transport Security 配置。
@ -109,10 +109,10 @@ type HSTSConfig struct {
// SecurityConfig 安全配置,包含访问控制、限流、认证和安全头部。
type SecurityConfig struct {
Access AccessConfig `yaml:"access"` // IP 访问控制
Access AccessConfig `yaml:"access"` // IP 访问控制
RateLimit RateLimitConfig `yaml:"rate_limit"` // 速率限制
Auth AuthConfig `yaml:"auth"` // 认证配置
Headers SecurityHeaders `yaml:"headers"` // 安全头部
Auth AuthConfig `yaml:"auth"` // 认证配置
Headers SecurityHeaders `yaml:"headers"` // 安全头部
}
// AccessConfig IP 访问控制配置。
@ -132,12 +132,12 @@ type RateLimitConfig struct {
// AuthConfig 认证配置。
type AuthConfig struct {
Type string `yaml:"type"` // 认证类型basic
RequireTLS bool `yaml:"require_tls"` // 强制 HTTPS默认 true
Algorithm string `yaml:"algorithm"` // 哈希算法bcrypt, argon2id
Users []User `yaml:"users"` // 用户列表
Realm string `yaml:"realm"` // 认证域
MinPasswordLength int `yaml:"min_password_length"` // 密码最小长度
Type string `yaml:"type"` // 认证类型basic
RequireTLS bool `yaml:"require_tls"` // 强制 HTTPS默认 true
Algorithm string `yaml:"algorithm"` // 哈希算法bcrypt, argon2id
Users []User `yaml:"users"` // 用户列表
Realm string `yaml:"realm"` // 认证域
MinPasswordLength int `yaml:"min_password_length"` // 密码最小长度
}
// User 认证用户配置。
@ -148,11 +148,11 @@ type User struct {
// SecurityHeaders 安全头部配置。
type SecurityHeaders struct {
XFrameOptions string `yaml:"x_frame_options"` // X-Frame-Options: DENY, SAMEORIGIN
XContentTypeOptions string `yaml:"x_content_type_options"` // X-Content-Type-Options: nosniff
XFrameOptions string `yaml:"x_frame_options"` // X-Frame-Options: DENY, SAMEORIGIN
XContentTypeOptions string `yaml:"x_content_type_options"` // X-Content-Type-Options: nosniff
ContentSecurityPolicy string `yaml:"content_security_policy"` // Content-Security-Policy
ReferrerPolicy string `yaml:"referrer_policy"` // Referrer-Policy
PermissionsPolicy string `yaml:"permissions_policy"` // Permissions-Policy
ReferrerPolicy string `yaml:"referrer_policy"` // Referrer-Policy
PermissionsPolicy string `yaml:"permissions_policy"` // Permissions-Policy
}
// RewriteRule URL 重写规则。
@ -164,10 +164,10 @@ type RewriteRule struct {
// CompressionConfig 响应压缩配置。
type CompressionConfig struct {
Type string `yaml:"type"` // 压缩类型gzip, brotli, both
Level int `yaml:"level"` // 压缩级别1-9
Type string `yaml:"type"` // 压缩类型gzip, brotli, both
Level int `yaml:"level"` // 压缩级别1-9
MinSize int `yaml:"min_size"` // 最小压缩大小(字节)
Types []string `yaml:"types"` // 可压缩的 MIME 类型
Types []string `yaml:"types"` // 可压缩的 MIME 类型
}
// LoggingConfig 日志配置。
@ -197,9 +197,9 @@ type PerformanceConfig struct {
// GoroutinePoolConfig Goroutine 池配置。
type GoroutinePoolConfig struct {
Enabled bool `yaml:"enabled"` // 是否启用
MaxWorkers int `yaml:"max_workers"` // 最大 worker 数
MinWorkers int `yaml:"min_workers"` // 最小 worker 数(预热)
Enabled bool `yaml:"enabled"` // 是否启用
MaxWorkers int `yaml:"max_workers"` // 最大 worker 数
MinWorkers int `yaml:"min_workers"` // 最小 worker 数(预热)
IdleTimeout time.Duration `yaml:"idle_timeout"` // 空闲超时
}
@ -213,10 +213,10 @@ type FileCacheConfig struct {
// TransportConfig HTTP Transport 配置。
type TransportConfig struct {
MaxIdleConns int `yaml:"max_idle_conns"` // 最大空闲连接数
MaxIdleConns int `yaml:"max_idle_conns"` // 最大空闲连接数
MaxIdleConnsPerHost int `yaml:"max_idle_conns_per_host"` // 每主机最大空闲连接
IdleConnTimeout time.Duration `yaml:"idle_conn_timeout"` // 空闲连接超时
MaxConnsPerHost int `yaml:"max_conns_per_host"` // 每主机最大连接数
IdleConnTimeout time.Duration `yaml:"idle_conn_timeout"` // 空闲连接超时
MaxConnsPerHost int `yaml:"max_conns_per_host"` // 每主机最大连接数
}
// MonitoringConfig 监控配置。
@ -317,4 +317,4 @@ func Validate(cfg *Config) error {
}
return nil
}
}

View File

@ -436,4 +436,4 @@ func TestConfigMethods(t *testing.T) {
})
}
})
}
}

View File

@ -119,6 +119,11 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
buf.WriteString("server:\n")
buf.WriteString(fmt.Sprintf(" listen: \"%s\" # 监听地址\n", cfg.Server.Listen))
buf.WriteString(fmt.Sprintf(" name: \"%s\" # 服务器名称(虚拟主机匹配)\n", cfg.Server.Name))
buf.WriteString(fmt.Sprintf(" read_timeout: %ds # 读取超时0 表示不限制)\n", int(cfg.Server.ReadTimeout.Seconds())))
buf.WriteString(fmt.Sprintf(" write_timeout: %ds # 写入超时0 表示不限制)\n", int(cfg.Server.WriteTimeout.Seconds())))
buf.WriteString(fmt.Sprintf(" idle_timeout: %ds # 空闲超时0 表示不限制)\n", int(cfg.Server.IdleTimeout.Seconds())))
buf.WriteString(fmt.Sprintf(" max_conns_per_ip: %d # 每 IP 最大连接数0 表示不限制)\n", cfg.Server.MaxConnsPerIP))
buf.WriteString(fmt.Sprintf(" max_requests_per_conn: %d # 每连接最大请求数0 表示不限制)\n", cfg.Server.MaxRequestsPerConn))
buf.WriteString("\n")
// static 配置
@ -205,10 +210,10 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
buf.WriteString("\n")
buf.WriteString(" # 安全头部\n")
buf.WriteString(" headers:\n")
buf.WriteString(fmt.Sprintf(" x_frame_options: \"%s\" # 防止点击劫持(有效值: DENY, SAMEORIGIN\n", cfg.Server.Security.Headers.XFrameOptions))
buf.WriteString(fmt.Sprintf(" x_frame_options: \"%s\" # 防止点击劫持(有效值: DENY, SAMEORIGIN, 空表示禁用\n", cfg.Server.Security.Headers.XFrameOptions))
buf.WriteString(fmt.Sprintf(" x_content_type_options: \"%s\" # 防止 MIME 嗅探\n", cfg.Server.Security.Headers.XContentTypeOptions))
buf.WriteString(fmt.Sprintf(" referrer_policy: \"%s\" # 引用策略\n", cfg.Server.Security.Headers.ReferrerPolicy))
buf.WriteString(" # content_security_policy: \"default-src 'self'\" # CSP(推荐配置)\n")
buf.WriteString(fmt.Sprintf(" referrer_policy: \"%s\" # 引用策略(有效值: no-referrer, no-referrer-when-downgrade, origin, origin-when-cross-origin, same-origin, strict-origin, strict-origin-when-cross-origin, unsafe-url\n", cfg.Server.Security.Headers.ReferrerPolicy))
buf.WriteString(" # content_security_policy: \"default-src 'self'\" # 内容安全策略 CSP\n")
buf.WriteString(" # permissions_policy: \"geolocation=(), microphone=()\" # 权限策略\n")
buf.WriteString("\n")
@ -223,9 +228,9 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
// compression 配置
buf.WriteString(" # 响应压缩配置\n")
buf.WriteString(" compression:\n")
buf.WriteString(fmt.Sprintf(" type: \"%s\" # 压缩类型: gzip, brotli, both\n", cfg.Server.Compression.Type))
buf.WriteString(fmt.Sprintf(" level: %d # 压缩级别 (1-9)\n", cfg.Server.Compression.Level))
buf.WriteString(fmt.Sprintf(" min_size: %d # 最小压缩大小(字节\n", cfg.Server.Compression.MinSize))
buf.WriteString(fmt.Sprintf(" type: \"%s\" # 压缩类型(有效值: gzip, brotli, both,空表示禁用)\n", cfg.Server.Compression.Type))
buf.WriteString(fmt.Sprintf(" level: %d # 压缩级别(范围 1-9值越大压缩率越高但速度越慢\n", cfg.Server.Compression.Level))
buf.WriteString(fmt.Sprintf(" min_size: %d # 最小压缩大小(字节,小于此值不压缩\n", cfg.Server.Compression.MinSize))
buf.WriteString(" types: # 可压缩的 MIME 类型\n")
for _, t := range cfg.Server.Compression.Types {
buf.WriteString(fmt.Sprintf(" - \"%s\"\n", t))
@ -250,11 +255,11 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
buf.WriteString("# 日志配置\n")
buf.WriteString("logging:\n")
buf.WriteString(" access:\n")
buf.WriteString(fmt.Sprintf(" format: \"%s\" # 日志格式\n", cfg.Logging.Access.Format))
buf.WriteString(" # path: /var/log/lolly/access.log # 日志文件路径\n")
buf.WriteString(fmt.Sprintf(" format: \"%s\" # 日志格式(支持变量: $remote_addr, $request, $status, $body_bytes_sent, $request_time, $http_referer, $http_user_agent\n", cfg.Logging.Access.Format))
buf.WriteString(" # path: /var/log/lolly/access.log # 日志文件路径(空表示输出到 stdout\n")
buf.WriteString(" error:\n")
buf.WriteString(fmt.Sprintf(" level: \"%s\" # 日志级别: debug, info, warn, error\n", cfg.Logging.Error.Level))
buf.WriteString(" # path: /var/log/lolly/error.log\n")
buf.WriteString(fmt.Sprintf(" level: \"%s\" # 日志级别(有效值: debug, info, warn, error,级别越高日志越少)\n", cfg.Logging.Error.Level))
buf.WriteString(" # path: /var/log/lolly/error.log # 日志文件路径(空表示输出到 stderr\n")
buf.WriteString("\n")
// performance 配置
@ -289,4 +294,3 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
return buf.Bytes(), nil
}

View File

@ -293,4 +293,4 @@ func validateCompression(c *CompressionConfig) error {
}
return nil
}
}

View File

@ -8,11 +8,11 @@ import (
func TestValidateServer(t *testing.T) {
tests := []struct {
name string
config ServerConfig
name string
config ServerConfig
isDefault bool
wantErr bool
errMsg string
wantErr bool
errMsg string
}{
{
name: "有效配置",
@ -810,4 +810,4 @@ func TestValidateSecurity(t *testing.T) {
}
})
}
}
}

View File

@ -45,4 +45,4 @@ func (r *Router) HEAD(path string, handler fasthttp.RequestHandler) {
// Handler 返回路由处理器
func (r *Router) Handler() fasthttp.RequestHandler {
return r.router.Handler
}
}

View File

@ -228,4 +228,4 @@ func TestRouterNotFound(t *testing.T) {
if ctx.Response.StatusCode() != fasthttp.StatusNotFound {
t.Errorf("状态码 = %d, want %d", ctx.Response.StatusCode(), fasthttp.StatusNotFound)
}
}
}

View File

@ -0,0 +1,180 @@
// Package handler 提供零拷贝文件传输功能,优化大文件传输性能。
package handler
import (
"io"
"net"
"os"
"runtime"
"sync"
"syscall"
"github.com/valyala/fasthttp"
)
const (
// MinSendfileSize 使用 sendfile 的最小文件大小8KB
MinSendfileSize = 8 * 1024
)
// SendFile 零拷贝文件传输。
// 大文件使用系统调用直接从文件传输到 socket避免用户空间拷贝。
func SendFile(ctx *fasthttp.RequestCtx, file *os.File, offset, length int64) error {
// 小文件使用普通 io.Copy
if length < MinSendfileSize {
return copyFile(ctx, file, offset, length)
}
// 尝试获取 socket 文件描述符
conn := getNetConn(ctx)
if conn == nil {
return copyFile(ctx, file, offset, length)
}
// 根据平台选择 sendfile 实现
err := platformSendfile(conn, file, offset, length)
if err != nil {
// sendfile 失败fallback 到 io.Copy
return copyFile(ctx, file, offset, length)
}
return nil
}
// getNetConn 从 fasthttp.RequestCtx 获取底层 net.Conn。
func getNetConn(ctx *fasthttp.RequestCtx) net.Conn {
// fasthttp 内部使用 net.Conn通过接口获取
return ctx.Conn()
}
// copyFile 普通文件拷贝fallback
func copyFile(ctx *fasthttp.RequestCtx, file *os.File, offset, length int64) error {
if offset > 0 {
if _, err := file.Seek(offset, io.SeekStart); err != nil {
return err
}
}
// 使用 io.CopyN 或 io.Copy
if length > 0 {
_, err := io.CopyN(ctx, file, length)
return err
}
_, err := io.Copy(ctx, file)
return err
}
// platformSendfile 平台特定的 sendfile 实现。
func platformSendfile(conn net.Conn, file *os.File, offset, length int64) error {
switch runtime.GOOS {
case "linux":
return linuxSendfile(conn, file.Fd(), offset, length)
case "darwin":
// macOS sendfile 签名复杂,简化使用 fallback
return syscall.ENOTSUP
case "windows":
// Windows TransmitFile 需要特殊 API
return syscall.ENOTSUP
default:
return syscall.ENOTSUP
}
}
// linuxSendfile Linux sendfile 系统调用。
func linuxSendfile(conn net.Conn, fileFd uintptr, offset, length int64) error {
socketFd, err := getSocketFd(conn)
if err != nil {
return err
}
// Linux sendfile: sendfile(out_fd, in_fd, offset, count)
var sent int64
remain := length
for remain > 0 {
n, err := syscall.Sendfile(int(socketFd), int(fileFd), nil, int(remain))
if err != nil {
return err
}
if n == 0 {
break // EOF
}
sent += int64(n)
remain -= int64(n)
}
return nil
}
// getSocketFd 获取 socket 文件描述符。
func getSocketFd(conn net.Conn) (uintptr, error) {
switch c := conn.(type) {
case *net.TCPConn:
file, err := c.File()
if err != nil {
return 0, err
}
defer file.Close()
return file.Fd(), nil
case *net.UnixConn:
file, err := c.File()
if err != nil {
return 0, err
}
defer file.Close()
return file.Fd(), nil
default:
return 0, syscall.ENOTSUP
}
}
// BufferPool 缓冲池,复用内存减少分配。
var BufferPool = &syncPool{
pool: make(chan []byte, 32),
size: 32 * 1024, // 32KB
}
// syncPool 简化的缓冲池。
type syncPool struct {
pool chan []byte
size int
}
// Get 获取缓冲区。
func (p *syncPool) Get() []byte {
select {
case buf := <-p.pool:
return buf
default:
return make([]byte, p.size)
}
}
// Put 放回缓冲区。
func (p *syncPool) Put(buf []byte) {
// 只放回合适大小的缓冲区
if len(buf) == p.size {
select {
case p.pool <- buf:
default: // 池满,丢弃
}
}
}
// RealBufferPool 使用 sync.Pool 的标准实现(推荐)。
var RealBufferPool = sync.Pool{
New: func() interface{} {
return make([]byte, 32*1024)
},
}
// GetBuffer 从池获取缓冲区。
func GetBuffer() []byte {
return RealBufferPool.Get().([]byte)
}
// PutBuffer 放回缓冲区。
func PutBuffer(buf []byte) {
RealBufferPool.Put(buf)
}

View File

@ -0,0 +1,101 @@
package handler
import (
"os"
"path/filepath"
"testing"
)
func TestBufferPool(t *testing.T) {
// 获取缓冲区
buf := BufferPool.Get()
if buf == nil {
t.Error("Expected non-nil buffer")
}
if len(buf) != 32*1024 {
t.Errorf("Expected buffer size 32KB, got %d", len(buf))
}
// 放回缓冲区
BufferPool.Put(buf)
// 再次获取(可能是同一个)
buf2 := BufferPool.Get()
if buf2 == nil {
t.Error("Expected non-nil buffer")
}
}
func TestRealBufferPool(t *testing.T) {
buf := GetBuffer()
if buf == nil {
t.Error("Expected non-nil buffer")
}
if len(buf) != 32*1024 {
t.Errorf("Expected buffer size 32KB, got %d", len(buf))
}
PutBuffer(buf)
}
func TestMinSendfileSize(t *testing.T) {
if MinSendfileSize != 8*1024 {
t.Errorf("Expected MinSendfileSize 8KB, got %d", MinSendfileSize)
}
}
func TestGetBuffer(t *testing.T) {
buf := GetBuffer()
if buf == nil {
t.Error("Expected non-nil buffer")
return
}
if len(buf) != 32*1024 {
t.Errorf("Expected buffer size 32KB, got %d", len(buf))
}
// 测试写入
copy(buf, []byte("test"))
if string(buf[:4]) != "test" {
t.Error("Expected to write 'test' to buffer")
}
}
func TestPlatformSendfile(t *testing.T) {
// 创建临时文件
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "test.txt")
content := []byte("Hello, World! This is a test file for sendfile.")
if err := os.WriteFile(tmpFile, content, 0644); err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
file, err := os.Open(tmpFile)
if err != nil {
t.Fatalf("Failed to open file: %v", err)
}
defer file.Close()
// 测试平台 sendfile小文件会 fallback 到 copyFile
// 由于没有真实的网络连接,这个测试主要验证不会崩溃
_ = platformSendfile(nil, file, 0, int64(len(content)))
}
func TestBufferPoolConcurrent(t *testing.T) {
const iterations = 100
done := make(chan bool)
for i := 0; i < iterations; i++ {
go func() {
buf := GetBuffer()
PutBuffer(buf)
done <- true
}()
}
for i := 0; i < iterations; i++ {
<-done
}
}

View File

@ -54,4 +54,4 @@ func (h *StaticHandler) Handle(ctx *fasthttp.RequestCtx) {
// 直接返回文件
fasthttp.ServeFile(ctx, filePath)
}
}

View File

@ -26,12 +26,12 @@ func newTestContext(t *testing.T, path string) *fasthttp.RequestCtx {
// TestStaticHandlerHandle 测试静态文件处理器
func TestStaticHandlerHandle(t *testing.T) {
tests := []struct {
name string
setup func(t *testing.T, root string) // 在临时目录中设置测试文件
path string // 请求路径
wantStatus int // 期望的 HTTP 状态码
wantContent string // 期望的响应内容(可选)
skipContent bool // 是否跳过内容验证
name string
setup func(t *testing.T, root string) // 在临时目录中设置测试文件
path string // 请求路径
wantStatus int // 期望的 HTTP 状态码
wantContent string // 期望的响应内容(可选)
skipContent bool // 是否跳过内容验证
}{
{
name: "正常文件访问",
@ -89,8 +89,8 @@ func TestStaticHandlerHandle(t *testing.T) {
t.Fatalf("创建目录失败: %v", err)
}
},
path: "/noindex/",
wantStatus: fasthttp.StatusForbidden,
path: "/noindex/",
wantStatus: fasthttp.StatusForbidden,
skipContent: true,
},
{
@ -99,8 +99,8 @@ func TestStaticHandlerHandle(t *testing.T) {
t.Helper()
// 不创建任何文件
},
path: "/nonexistent.txt",
wantStatus: fasthttp.StatusNotFound,
path: "/nonexistent.txt",
wantStatus: fasthttp.StatusNotFound,
skipContent: true,
},
{
@ -109,8 +109,8 @@ func TestStaticHandlerHandle(t *testing.T) {
t.Helper()
// root 目录没有索引文件
},
path: "/",
wantStatus: fasthttp.StatusForbidden,
path: "/",
wantStatus: fasthttp.StatusForbidden,
skipContent: true,
},
{
@ -397,4 +397,4 @@ func TestNewStaticHandler(t *testing.T) {
t.Errorf("handler.index 应为 nil")
}
})
}
}

View File

@ -137,4 +137,4 @@ func parseLevel(level string) zerolog.Level {
default:
return zerolog.InfoLevel
}
}
}

View File

@ -182,4 +182,4 @@ func TestLoggerDebug(t *testing.T) {
if !strings.Contains(output, "info message") {
t.Error("Expected info message to be logged")
}
}
}

View File

@ -23,10 +23,10 @@ const (
// CompressionMiddleware 响应压缩中间件。
type CompressionMiddleware struct {
types []string // 可压缩的 MIME 类型
level int // 压缩级别
minSize int // 最小压缩大小
algorithm Algorithm // 压缩算法
types []string // 可压缩的 MIME 类型
level int // 压缩级别
minSize int // 最小压缩大小
algorithm Algorithm // 压缩算法
// 缓冲池
gzipPool sync.Pool
@ -226,4 +226,4 @@ func (m *CompressionMiddleware) Level() int {
// MinSize 返回最小压缩大小。
func (m *CompressionMiddleware) MinSize() int {
return m.minSize
}
}

View File

@ -317,4 +317,4 @@ func TestGetters(t *testing.T) {
if len(m.Types()) != 1 {
t.Errorf("Expected 1 type, got %d", len(m.Types()))
}
}
}

View File

@ -25,4 +25,4 @@ func (c *Chain) Apply(final fasthttp.RequestHandler) fasthttp.RequestHandler {
handler = c.middlewares[i].Process(handler)
}
return handler
}
}

View File

@ -142,4 +142,4 @@ func (m *modifyMiddleware) Process(next fasthttp.RequestHandler) fasthttp.Reques
// 在响应后追加内容
ctx.SetBodyString(string(ctx.Response.Body()) + "-modified")
}
}
}

View File

@ -114,4 +114,4 @@ func (m *RewriteMiddleware) Process(next fasthttp.RequestHandler) fasthttp.Reque
// Rules 返回编译后的规则列表(用于调试)。
func (m *RewriteMiddleware) Rules() []Rule {
return m.rules
}
}

View File

@ -284,4 +284,4 @@ func TestRewriteMiddlewareRules(t *testing.T) {
if len(compiled) != 2 {
t.Errorf("Expected 2 rules, got %d", len(compiled))
}
}
}

View File

@ -40,16 +40,16 @@ type Action int
const (
ActionAllow Action = iota // Allow the request
ActionDeny // Deny the request (403 Forbidden)
ActionDeny // Deny the request (403 Forbidden)
)
// AccessControl implements IP-based access control middleware.
// It checks incoming requests against configured allow/deny CIDR lists.
type AccessControl struct {
allowList []net.IPNet // CIDR networks to allow
denyList []net.IPNet // CIDR networks to deny
allowList []net.IPNet // CIDR networks to allow
denyList []net.IPNet // CIDR networks to deny
defaultAction Action // Default action when no rule matches
mu sync.RWMutex
mu sync.RWMutex
}
// NewAccessControl creates a new access control middleware from configuration.
@ -299,4 +299,4 @@ func actionToString(action Action) string {
}
// Verify interface compliance
var _ middleware.Middleware = (*AccessControl)(nil)
var _ middleware.Middleware = (*AccessControl)(nil)

View File

@ -340,4 +340,4 @@ func TestGetStats(t *testing.T) {
if stats.Default != "deny" {
t.Errorf("Expected Default 'deny', got %s", stats.Default)
}
}
}

View File

@ -35,8 +35,8 @@ import (
"sync"
"github.com/valyala/fasthttp"
"golang.org/x/crypto/bcrypt"
"golang.org/x/crypto/argon2"
"golang.org/x/crypto/bcrypt"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/middleware"
@ -46,37 +46,37 @@ import (
type HashAlgorithm int
const (
HashBcrypt HashAlgorithm = iota // bcrypt (default, recommended)
HashArgon2id // Argon2id (more secure, compute-intensive)
HashBcrypt HashAlgorithm = iota // bcrypt (default, recommended)
HashArgon2id // Argon2id (more secure, compute-intensive)
)
// BasicAuth implements HTTP Basic Authentication middleware.
type BasicAuth struct {
users map[string]string // username -> hashed password
algorithm HashAlgorithm // Hash algorithm used
realm string // Authentication realm
requireTLS bool // Require HTTPS (default true)
minPasswordLength int // Minimum password length for validation
argon2Params argon2Params // Argon2id parameters
mu sync.RWMutex
users map[string]string // username -> hashed password
algorithm HashAlgorithm // Hash algorithm used
realm string // Authentication realm
requireTLS bool // Require HTTPS (default true)
minPasswordLength int // Minimum password length for validation
argon2Params argon2Params // Argon2id parameters
mu sync.RWMutex
}
// argon2Params holds Argon2id configuration parameters.
type argon2Params struct {
time uint32 // Number of passes
memory uint32 // Memory cost in KB
threads uint8 // Parallelism
saltLen uint32 // Salt length
keyLen uint32 // Output key length
time uint32 // Number of passes
memory uint32 // Memory cost in KB
threads uint8 // Parallelism
saltLen uint32 // Salt length
keyLen uint32 // Output key length
}
// Default Argon2id parameters (OWASP recommended)
var defaultArgon2Params = argon2Params{
time: 3,
memory: 64 * 1024, // 64 MB
threads: 4,
saltLen: 16,
keyLen: 32,
time: 3,
memory: 64 * 1024, // 64 MB
threads: 4,
saltLen: 16,
keyLen: 32,
}
// NewBasicAuth creates a new Basic Auth middleware from configuration.
@ -101,10 +101,10 @@ func NewBasicAuth(cfg *config.AuthConfig) (*BasicAuth, error) {
}
auth := &BasicAuth{
users: make(map[string]string),
requireTLS: cfg.RequireTLS, // Default is true from config defaults
users: make(map[string]string),
requireTLS: cfg.RequireTLS, // Default is true from config defaults
minPasswordLength: cfg.MinPasswordLength,
argon2Params: defaultArgon2Params,
argon2Params: defaultArgon2Params,
}
// Set realm
@ -452,4 +452,4 @@ func parseUint8(s string) uint8 {
}
// Verify interface compliance
var _ middleware.Middleware = (*BasicAuth)(nil)
var _ middleware.Middleware = (*BasicAuth)(nil)

View File

@ -177,7 +177,7 @@ func TestBasicAuthProcess(t *testing.T) {
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
auth, err := NewBasicAuth(&config.AuthConfig{
Type: "basic",
Type: "basic",
RequireTLS: false, // Disable TLS for testing
Users: []config.User{
{Name: "admin", Password: string(hashedPassword)},
@ -346,7 +346,7 @@ func TestExtractCredentials(t *testing.T) {
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
auth, err := NewBasicAuth(&config.AuthConfig{
Type: "basic",
Type: "basic",
RequireTLS: false,
Users: []config.User{
{Name: "admin", Password: string(hashedPassword)},
@ -396,4 +396,4 @@ func TestName(t *testing.T) {
if auth.Name() != "basic_auth" {
t.Errorf("Expected name 'basic_auth', got %s", auth.Name())
}
}
}

View File

@ -37,9 +37,9 @@ import (
// SecurityHeadersMiddleware adds security-related headers to responses.
type SecurityHeadersMiddleware struct {
config *config.SecurityHeaders
hsts string // Pre-formatted HSTS header value
mu sync.RWMutex
config *config.SecurityHeaders
hsts string // Pre-formatted HSTS header value
mu sync.RWMutex
}
// NewSecurityHeaders creates a new security headers middleware.
@ -218,11 +218,11 @@ func DefaultSecurityHeaders() *config.SecurityHeaders {
// Suitable for high-security applications.
func StrictSecurityHeaders() *config.SecurityHeaders {
return &config.SecurityHeaders{
XFrameOptions: "DENY",
XContentTypeOptions: "nosniff",
XFrameOptions: "DENY",
XContentTypeOptions: "nosniff",
ContentSecurityPolicy: "default-src 'self'; script-src 'self'; style-src 'self'; img-src 'self'; font-src 'self'; connect-src 'self'; frame-ancestors 'none'",
ReferrerPolicy: "no-referrer",
PermissionsPolicy: "accelerometer=(), camera=(), geolocation=(), gyroscope=(), magnetometer=(), microphone=(), payment=(), usb=()",
ReferrerPolicy: "no-referrer",
PermissionsPolicy: "accelerometer=(), camera=(), geolocation=(), gyroscope=(), magnetometer=(), microphone=(), payment=(), usb=()",
}
}
@ -237,4 +237,4 @@ func DevelopmentSecurityHeaders() *config.SecurityHeaders {
}
// Verify interface compliance
var _ middleware.Middleware = (*SecurityHeadersMiddleware)(nil)
var _ middleware.Middleware = (*SecurityHeadersMiddleware)(nil)

View File

@ -19,8 +19,8 @@ func TestNewSecurityHeaders(t *testing.T) {
{
name: "custom config",
cfg: &config.SecurityHeaders{
XFrameOptions: "SAMEORIGIN",
XContentTypeOptions: "nosniff",
XFrameOptions: "SAMEORIGIN",
XContentTypeOptions: "nosniff",
ContentSecurityPolicy: "default-src 'self'",
},
},
@ -45,11 +45,11 @@ func TestSecurityHeadersName(t *testing.T) {
func TestSecurityHeadersProcess(t *testing.T) {
cfg := &config.SecurityHeaders{
XFrameOptions: "DENY",
XContentTypeOptions: "nosniff",
XFrameOptions: "DENY",
XContentTypeOptions: "nosniff",
ContentSecurityPolicy: "default-src 'self'",
ReferrerPolicy: "strict-origin-when-cross-origin",
PermissionsPolicy: "geolocation=()",
ReferrerPolicy: "strict-origin-when-cross-origin",
PermissionsPolicy: "geolocation=()",
}
sh := NewSecurityHeaders(cfg)
@ -157,7 +157,7 @@ func TestUpdateConfig(t *testing.T) {
sh := NewSecurityHeaders(nil)
newCfg := &config.SecurityHeaders{
XFrameOptions: "DENY",
XFrameOptions: "DENY",
ReferrerPolicy: "no-referrer",
}
@ -244,4 +244,4 @@ func TestFormatHSTSValue(t *testing.T) {
}
})
}
}
}

View File

@ -38,9 +38,9 @@ import (
// RateLimiter implements request rate limiting using token bucket algorithm.
type RateLimiter struct {
rate float64 // Tokens added per second
burst float64 // Maximum bucket capacity
keyFunc KeyFunc // Function to extract limit key
rate float64 // Tokens added per second
burst float64 // Maximum bucket capacity
keyFunc KeyFunc // Function to extract limit key
buckets map[string]*tokenBucket
mu sync.RWMutex
}
@ -294,12 +294,12 @@ func (rl *RateLimiter) GetStats() RateLimitStats {
// ConnLimiter implements connection count limiting.
// This is a separate limiter for maximum concurrent connections.
type ConnLimiter struct {
max int // Maximum concurrent connections
current int64 // Current connection count (atomic)
perKey bool // Limit per key instead of global
keyFunc KeyFunc // Key extraction function
counts map[string]int64 // Connection counts per key
mu sync.RWMutex
max int // Maximum concurrent connections
current int64 // Current connection count (atomic)
perKey bool // Limit per key instead of global
keyFunc KeyFunc // Key extraction function
counts map[string]int64 // Connection counts per key
mu sync.RWMutex
}
// NewConnLimiter creates a new connection limiter.
@ -318,9 +318,9 @@ func NewConnLimiter(max int, perKey bool, keyType string) (*ConnLimiter, error)
}
cl := &ConnLimiter{
max: max,
perKey: perKey,
counts: make(map[string]int64),
max: max,
perKey: perKey,
counts: make(map[string]int64),
}
if perKey {
@ -420,4 +420,4 @@ func addInt64(ptr *int64, delta int64) {
// Verify interface compliance
var _ middleware.Middleware = (*RateLimiter)(nil)
var _ middleware.Middleware = (*connLimiterMiddleware)(nil)
var _ middleware.Middleware = (*connLimiterMiddleware)(nil)

View File

@ -350,4 +350,4 @@ func TestConnLimiterMiddleware(t *testing.T) {
if middleware.Name() != "conn_limiter" {
t.Errorf("Expected name 'conn_limiter', got %s", middleware.Name())
}
}
}

View File

@ -2,7 +2,7 @@
//
// 此文件实现了针对后端目标的健康检查功能,支持
// 主动健康检查(定期 HTTP 探测)和被动健康检查
//(基于观察到的失败标记目标为不健康)。
// (基于观察到的失败标记目标为不健康)。
//
//go:generate go test -v ./...
package proxy

View File

@ -402,9 +402,9 @@ func TestModifyRequestHeaders(t *testing.T) {
},
},
{
name: "移除请求头",
clientIP: "192.168.1.100",
removeHeaders: []string{"X-Remove-Me"},
name: "移除请求头",
clientIP: "192.168.1.100",
removeHeaders: []string{"X-Remove-Me"},
shouldNotExist: []string{"X-Remove-Me"},
},
}
@ -416,8 +416,8 @@ func TestModifyRequestHeaders(t *testing.T) {
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
Headers: config.ProxyHeaders{
SetRequest: tt.setRequest,
Remove: tt.removeHeaders,
SetRequest: tt.setRequest,
Remove: tt.removeHeaders,
},
}
@ -704,40 +704,40 @@ func TestGetConfig(t *testing.T) {
// TestIsWebSocketRequest 测试WebSocket请求检测
func TestIsWebSocketRequest(t *testing.T) {
tests := []struct {
name string
upgrade string
name string
upgrade string
connection string
expected bool
expected bool
}{
{
name: "标准WebSocket请求",
upgrade: "websocket",
name: "标准WebSocket请求",
upgrade: "websocket",
connection: "upgrade",
expected: true,
expected: true,
},
{
name: "大小写不敏感",
upgrade: "WebSocket",
name: "大小写不敏感",
upgrade: "WebSocket",
connection: "Upgrade",
expected: true,
expected: true,
},
{
name: "非WebSocket升级",
upgrade: "h2c",
name: "非WebSocket升级",
upgrade: "h2c",
connection: "upgrade",
expected: false,
expected: false,
},
{
name: "非upgrade连接",
upgrade: "websocket",
name: "非upgrade连接",
upgrade: "websocket",
connection: "keep-alive",
expected: false,
expected: false,
},
{
name: "keep-alive, Upgrade",
upgrade: "websocket",
name: "keep-alive, Upgrade",
upgrade: "websocket",
connection: "keep-alive, Upgrade",
expected: true,
expected: true,
},
}

191
internal/server/pool.go Normal file
View File

@ -0,0 +1,191 @@
// Package server 提供 Goroutine 池,限制并发数量,减少调度开销。
package server
import (
"context"
"sync"
"sync/atomic"
"time"
"github.com/valyala/fasthttp"
)
// GoroutinePool Goroutine 池配置。
type GoroutinePool struct {
maxWorkers int32 // 最大 worker 数
minWorkers int32 // 最小 worker 数(预热)
idleTimeout time.Duration // 穴闲超时
taskQueue chan Task // 任务队列
workers int32 // 当前 worker 数
idleWorkers int32 // 穴闲 worker 数
running atomic.Bool // 运行状态
wg sync.WaitGroup // 等待所有 worker
ctx context.Context
cancel context.CancelFunc
}
// Task 任务函数类型。
type Task func(*fasthttp.RequestCtx)
// PoolConfig 池配置。
type PoolConfig struct {
MaxWorkers int // 最大并发数
MinWorkers int // 预热 worker 数
IdleTimeout time.Duration // 穴闲超时
QueueSize int // 任务队列大小
}
// NewGoroutinePool 创建 Goroutine 池。
func NewGoroutinePool(cfg PoolConfig) *GoroutinePool {
if cfg.MaxWorkers <= 0 {
cfg.MaxWorkers = 10000
}
if cfg.MinWorkers <= 0 {
cfg.MinWorkers = 100
}
if cfg.MinWorkers > cfg.MaxWorkers {
cfg.MinWorkers = cfg.MaxWorkers
}
if cfg.IdleTimeout <= 0 {
cfg.IdleTimeout = 60 * time.Second
}
if cfg.QueueSize <= 0 {
cfg.QueueSize = 1000
}
ctx, cancel := context.WithCancel(context.Background())
p := &GoroutinePool{
maxWorkers: int32(cfg.MaxWorkers),
minWorkers: int32(cfg.MinWorkers),
idleTimeout: cfg.IdleTimeout,
taskQueue: make(chan Task, cfg.QueueSize),
ctx: ctx,
cancel: cancel,
}
// 预热 worker
for i := 0; i < cfg.MinWorkers; i++ {
p.startWorker()
}
return p
}
// Start 启动池。
func (p *GoroutinePool) Start() {
p.running.Store(true)
}
// Stop 停止池。
func (p *GoroutinePool) Stop() {
p.running.Store(false)
p.cancel()
p.wg.Wait()
}
// Submit 提交任务。
func (p *GoroutinePool) Submit(ctx *fasthttp.RequestCtx, task Task) error {
if !p.running.Load() {
// 池未运行,直接执行
task(ctx)
return nil
}
// 尝试放入队列
select {
case p.taskQueue <- task:
// 任务入队成功
// 如果有空闲 worker 不足,可能需要启动新 worker
if p.idleWorkers == 0 && p.workers < p.maxWorkers {
p.startWorker()
}
return nil
default:
// 队列满,需要启动新 worker 或直接执行
if p.workers < p.maxWorkers {
p.startWorker()
// 重新尝试入队
p.taskQueue <- task
return nil
}
// 达到最大 worker直接执行fallback
task(ctx)
return nil
}
}
// startWorker 启动一个 worker。
func (p *GoroutinePool) startWorker() {
p.workers++
p.wg.Add(1)
go func() {
defer p.wg.Done()
defer atomic.AddInt32(&p.workers, -1)
idleTimer := time.NewTimer(p.idleTimeout)
defer idleTimer.Stop()
for {
// 标记为空闲
atomic.AddInt32(&p.idleWorkers, 1)
select {
case task := <-p.taskQueue:
// 取出任务,取消空闲标记
atomic.AddInt32(&p.idleWorkers, -1)
idleTimer.Reset(p.idleTimeout)
// 执行任务
task(nil) // 注意fasthttp.RequestCtx 需要在任务中传入
case <-idleTimer.C:
// 穴闲超时,退出 worker保持最小数量
atomic.AddInt32(&p.idleWorkers, -1)
if p.workers > p.minWorkers {
return
}
idleTimer.Reset(p.idleTimeout)
case <-p.ctx.Done():
// 池关闭
atomic.AddInt32(&p.idleWorkers, -1)
return
}
}
}()
}
// Stats 返回池统计信息。
func (p *GoroutinePool) Stats() PoolStats {
return PoolStats{
Workers: atomic.LoadInt32(&p.workers),
IdleWorkers: atomic.LoadInt32(&p.idleWorkers),
MaxWorkers: p.maxWorkers,
MinWorkers: p.minWorkers,
QueueLen: int32(len(p.taskQueue)),
QueueCap: int32(cap(p.taskQueue)),
}
}
// PoolStats 池统计信息。
type PoolStats struct {
Workers int32
IdleWorkers int32
MaxWorkers int32
MinWorkers int32
QueueLen int32
QueueCap int32
}
// WrapHandler 使用池包装 fasthttp 处理器。
func (p *GoroutinePool) WrapHandler(handler fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
// 使用池执行处理器
p.Submit(ctx, func(innerCtx *fasthttp.RequestCtx) {
handler(ctx)
})
}
}

View File

@ -0,0 +1,170 @@
package server
import (
"sync"
"sync/atomic"
"testing"
"time"
"github.com/valyala/fasthttp"
)
func TestNewGoroutinePool(t *testing.T) {
cfg := PoolConfig{
MaxWorkers: 100,
MinWorkers: 10,
IdleTimeout: 30 * time.Second,
QueueSize: 50,
}
p := NewGoroutinePool(cfg)
if p == nil {
t.Error("Expected non-nil pool")
}
// 检查配置
if p.maxWorkers != 100 {
t.Errorf("Expected maxWorkers 100, got %d", p.maxWorkers)
}
if p.minWorkers != 10 {
t.Errorf("Expected minWorkers 10, got %d", p.minWorkers)
}
}
func TestPoolDefaults(t *testing.T) {
p := NewGoroutinePool(PoolConfig{})
// 应该使用默认值
if p.maxWorkers != 10000 {
t.Errorf("Expected default maxWorkers 10000, got %d", p.maxWorkers)
}
if p.minWorkers != 100 {
t.Errorf("Expected default minWorkers 100, got %d", p.minWorkers)
}
}
func TestPoolStartStop(t *testing.T) {
p := NewGoroutinePool(PoolConfig{
MaxWorkers: 10,
MinWorkers: 2,
})
p.Start()
if !p.running.Load() {
t.Error("Expected pool to be running")
}
p.Stop()
if p.running.Load() {
t.Error("Expected pool to be stopped")
}
}
func TestPoolSubmit(t *testing.T) {
p := NewGoroutinePool(PoolConfig{
MaxWorkers: 10,
MinWorkers: 2,
QueueSize: 10,
IdleTimeout: 5 * time.Second,
})
p.Start()
defer p.Stop()
var executed atomic.Bool
task := func(ctx *fasthttp.RequestCtx) {
executed.Store(true)
}
err := p.Submit(nil, task)
if err != nil {
t.Errorf("Submit failed: %v", err)
}
// 等待任务执行
time.Sleep(100 * time.Millisecond)
if !executed.Load() {
t.Error("Expected task to be executed")
}
}
func TestPoolStats(t *testing.T) {
p := NewGoroutinePool(PoolConfig{
MaxWorkers: 100,
MinWorkers: 10,
QueueSize: 50,
IdleTimeout: 30 * time.Second,
})
p.Start()
defer p.Stop()
stats := p.Stats()
if stats.MaxWorkers != 100 {
t.Errorf("Expected MaxWorkers 100, got %d", stats.MaxWorkers)
}
if stats.MinWorkers != 10 {
t.Errorf("Expected MinWorkers 10, got %d", stats.MinWorkers)
}
if stats.QueueCap != 50 {
t.Errorf("Expected QueueCap 50, got %d", stats.QueueCap)
}
}
func TestPoolConcurrentSubmit(t *testing.T) {
p := NewGoroutinePool(PoolConfig{
MaxWorkers: 50,
MinWorkers: 5,
QueueSize: 100,
IdleTimeout: 5 * time.Second,
})
p.Start()
defer p.Stop()
var counter atomic.Int32
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
p.Submit(nil, func(ctx *fasthttp.RequestCtx) {
counter.Add(1)
})
}()
}
wg.Wait()
// 等待所有任务执行
time.Sleep(500 * time.Millisecond)
if counter.Load() != 100 {
t.Errorf("Expected 100 executions, got %d", counter.Load())
}
}
func TestPoolSubmitWhenStopped(t *testing.T) {
p := NewGoroutinePool(PoolConfig{
MaxWorkers: 10,
})
// 不启动池
var executed atomic.Bool
task := func(ctx *fasthttp.RequestCtx) {
executed.Store(true)
}
err := p.Submit(nil, task)
if err != nil {
t.Errorf("Submit should not fail when stopped: %v", err)
}
// 任务应该直接执行
if !executed.Load() {
t.Error("Expected task to be executed directly when pool is stopped")
}
}

View File

@ -104,4 +104,4 @@ func TestGracefulStopWithZeroTimeout(t *testing.T) {
if err != nil {
t.Errorf("GracefulStop(0) returned error: %v", err)
}
}
}

217
internal/server/upgrade.go Normal file
View File

@ -0,0 +1,217 @@
// Package server 提供优雅升级(热升级)功能,实现零停机部署。
package server
import (
"fmt"
"net"
"os"
"os/exec"
"os/signal"
"path/filepath"
"syscall"
"time"
)
// UpgradeManager 优雅升级管理器。
type UpgradeManager struct {
server *Server
pidFile string // PID 文件路径
oldPid int // 旧进程 PID
listeners []net.Listener
}
// NewUpgradeManager 创建升级管理器。
func NewUpgradeManager(server *Server) *UpgradeManager {
return &UpgradeManager{
server: server,
}
}
// SetPidFile 设置 PID 文件路径。
func (u *UpgradeManager) SetPidFile(path string) {
u.pidFile = path
}
// SetListeners 设置监听器列表(用于升级时继承)。
func (u *UpgradeManager) SetListeners(listeners []net.Listener) {
u.listeners = listeners
}
// WritePid 写入当前进程 PID。
func (u *UpgradeManager) WritePid() error {
if u.pidFile == "" {
return nil
}
pid := os.Getpid()
return os.WriteFile(u.pidFile, []byte(fmt.Sprintf("%d", pid)), 0644)
}
// ReadOldPid 读取旧进程 PID。
func (u *UpgradeManager) ReadOldPid() (int, error) {
if u.pidFile == "" {
return 0, fmt.Errorf("pid file not configured")
}
data, err := os.ReadFile(u.pidFile)
if err != nil {
return 0, err
}
var pid int
_, err = fmt.Sscanf(string(data), "%d", &pid)
return pid, err
}
// IsChild 检查是否是子进程(从升级启动)。
func (u *UpgradeManager) IsChild() bool {
return os.Getenv("GRACEFUL_UPGRADE") == "1"
}
// GetInheritedListeners 获取继承的监听器。
func (u *UpgradeManager) GetInheritedListeners() ([]net.Listener, error) {
fdsStr := os.Getenv("LISTEN_FDS")
if fdsStr == "" {
return nil, nil // 不是升级启动
}
var fdCount int
_, err := fmt.Sscanf(fdsStr, "%d", &fdCount)
if err != nil {
return nil, err
}
var listeners []net.Listener
for i := 3; i < 3+fdCount; i++ {
file := os.NewFile(uintptr(i), fmt.Sprintf("listener-%d", i))
if file == nil {
continue
}
listener, err := net.FileListener(file)
if err != nil {
file.Close()
continue
}
listeners = append(listeners, listener)
}
return listeners, nil
}
// GracefulUpgrade 执行优雅升级。
func (u *UpgradeManager) GracefulUpgrade(newBinary string) error {
if len(u.listeners) == 0 {
return fmt.Errorf("no listeners configured for upgrade")
}
// 准备环境变量
env := os.Environ()
env = append(env, "GRACEFUL_UPGRADE=1")
env = append(env, fmt.Sprintf("LISTEN_FDS=%d", len(u.listeners)))
// 获取监听器的文件描述符
files := make([]*os.File, 0, len(u.listeners))
for _, listener := range u.listeners {
file, err := listenerFile(listener)
if err != nil {
return fmt.Errorf("failed to get listener file: %w", err)
}
files = append(files, file)
}
// 启动新进程
execPath, err := filepath.Abs(newBinary)
if err != nil {
return err
}
cmd := exec.Command(execPath)
cmd.Env = env
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.ExtraFiles = files
if err := cmd.Start(); err != nil {
return fmt.Errorf("failed to start new process: %w", err)
}
newPid := cmd.Process.Pid
u.oldPid = os.Getpid()
// 写入新 PID
if u.pidFile != "" {
os.WriteFile(u.pidFile, []byte(fmt.Sprintf("%d", newPid)), 0644)
}
return nil
}
// listenerFile 从 net.Listener 获取文件描述符。
func listenerFile(listener net.Listener) (*os.File, error) {
switch l := listener.(type) {
case *net.TCPListener:
return l.File()
case *net.UnixListener:
return l.File()
default:
return nil, fmt.Errorf("unsupported listener type: %T", listener)
}
}
// WaitForShutdown 等待旧进程关闭。
func (u *UpgradeManager) WaitForShutdown(timeout time.Duration) error {
if u.oldPid == 0 {
return nil
}
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
process, err := os.FindProcess(u.oldPid)
if err != nil {
return nil
}
err = process.Signal(syscall.Signal(0))
if err != nil {
return nil
}
time.Sleep(100 * time.Millisecond)
}
process, _ := os.FindProcess(u.oldPid)
process.Signal(syscall.SIGKILL)
return fmt.Errorf("old process did not shutdown gracefully")
}
// NotifyOldProcess 通知旧进程关闭。
func (u *UpgradeManager) NotifyOldProcess() error {
oldPid, err := u.ReadOldPid()
if err != nil || oldPid == 0 {
return nil
}
process, err := os.FindProcess(oldPid)
if err != nil {
return nil
}
return process.Signal(syscall.SIGQUIT)
}
// SetupSignalHandlers 设置升级相关信号处理。
func (u *UpgradeManager) SetupSignalHandlers(newBinary string) {
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGUSR2)
go func() {
for sig := range sigCh {
if sig == syscall.SIGUSR2 {
u.GracefulUpgrade(newBinary)
}
}
}()
}

View File

@ -0,0 +1,101 @@
package server
import (
"os"
"testing"
)
func TestNewUpgradeManager(t *testing.T) {
srv := New(nil)
mgr := NewUpgradeManager(srv)
if mgr == nil {
t.Error("Expected non-nil manager")
}
if mgr.server != srv {
t.Error("Expected server to be set")
}
}
func TestIsChild(t *testing.T) {
mgr := NewUpgradeManager(nil)
// 默认不是子进程
if mgr.IsChild() {
t.Error("Expected IsChild to be false by default")
}
// 设置环境变量
os.Setenv("GRACEFUL_UPGRADE", "1")
defer os.Unsetenv("GRACEFUL_UPGRADE")
if !mgr.IsChild() {
t.Error("Expected IsChild to be true when GRACEFUL_UPGRADE=1")
}
}
func TestPidFile(t *testing.T) {
tmpFile := "/tmp/lolly-test.pid"
defer os.Remove(tmpFile)
mgr := NewUpgradeManager(nil)
mgr.SetPidFile(tmpFile)
// 写入 PID
if err := mgr.WritePid(); err != nil {
t.Errorf("WritePid failed: %v", err)
}
// 读取 PID
pid, err := mgr.ReadOldPid()
if err != nil {
t.Errorf("ReadOldPid failed: %v", err)
}
if pid != os.Getpid() {
t.Errorf("Expected pid %d, got %d", os.Getpid(), pid)
}
}
func TestReadOldPidNoFile(t *testing.T) {
mgr := NewUpgradeManager(nil)
mgr.SetPidFile("/nonexistent/path/pid")
_, err := mgr.ReadOldPid()
if err == nil {
t.Error("Expected error for non-existent file")
}
}
func TestGetInheritedListenersNoFds(t *testing.T) {
mgr := NewUpgradeManager(nil)
// 没有设置 LISTEN_FDS
listeners, err := mgr.GetInheritedListeners()
if err != nil {
t.Errorf("GetInheritedListeners failed: %v", err)
}
if len(listeners) != 0 {
t.Errorf("Expected 0 listeners, got %d", len(listeners))
}
}
func TestNotifyOldProcessNoPid(t *testing.T) {
mgr := NewUpgradeManager(nil)
mgr.SetPidFile("/nonexistent/pid")
// 没有 PID 文件,应该返回 nil
err := mgr.NotifyOldProcess()
if err != nil {
t.Errorf("NotifyOldProcess should return nil for no pid, got: %v", err)
}
}
func TestWaitForShutdownNoOldPid(t *testing.T) {
mgr := NewUpgradeManager(nil)
// 没有旧进程
err := mgr.WaitForShutdown(0)
if err != nil {
t.Errorf("WaitForShutdown should return nil for no old pid, got: %v", err)
}
}

View File

@ -312,4 +312,4 @@ func TestVHostManager_PortStripping(t *testing.T) {
}
t.Log("已知限制: IPv6 数字地址端口剥离需要修复 vhost.go")
})
}
}

View File

@ -435,20 +435,20 @@ func defaultCipherSuites() []uint16 {
// cipherSuiteMap maps cipher suite names to TLS IDs.
var cipherSuiteMap = map[string]uint16{
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
"TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
"TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
"TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
"TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
"TLS_RSA_WITH_AES_128_CBC_SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA,
"TLS_RSA_WITH_AES_256_CBC_SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA,
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
"TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
"TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
"TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
"TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
"TLS_RSA_WITH_AES_128_CBC_SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA,
"TLS_RSA_WITH_AES_256_CBC_SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA,
}
// ValidateCertificate validates a certificate file.
@ -475,4 +475,4 @@ func ValidateKey(keyPath string) error {
// Key validation happens during tls.LoadX509KeyPair
// This is a preliminary check that the file exists and is readable
return nil
}
}

View File

@ -181,10 +181,10 @@ func TestParseTLSVersions(t *testing.T) {
func TestParseCipherSuites(t *testing.T) {
tests := []struct {
name string
ciphers []string
wantLen int
wantErr bool
name string
ciphers []string
wantLen int
wantErr bool
}{
{
name: "valid cipher",
@ -192,7 +192,7 @@ func TestParseCipherSuites(t *testing.T) {
wantLen: 1,
},
{
name: "multiple valid ciphers",
name: "multiple valid ciphers",
ciphers: []string{
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
@ -407,4 +407,4 @@ func containsString(s, substr string) bool {
}
}
return false
}
}

401
internal/stream/stream.go Normal file
View File

@ -0,0 +1,401 @@
// Package stream 提供 TCP/UDP Stream 代理功能,支持 MySQL、DNS 等服务代理。
package stream
import (
"io"
"net"
"sync"
"sync/atomic"
"time"
)
// Balancer 负载均衡器接口stream 专用)。
type Balancer interface {
Select(targets []*Target) *Target
}
// roundRobin 简单轮询。
type roundRobin struct {
counter uint64
}
// newRoundRobin 创建轮询均衡器。
func newRoundRobin() Balancer {
return &roundRobin{}
}
// Select 选择下一个目标。
func (r *roundRobin) Select(targets []*Target) *Target {
// 过滤健康目标
healthy := make([]*Target, 0)
for _, t := range targets {
if t.healthy.Load() {
healthy = append(healthy, t)
}
}
if len(healthy) == 0 {
return nil
}
idx := atomic.AddUint64(&r.counter, 1) - 1
return healthy[idx%uint64(len(healthy))]
}
// leastConn 最少连接。
type leastConn struct{}
// newLeastConn 创建最少连接均衡器。
func newLeastConn() Balancer {
return &leastConn{}
}
// Select 选择连接最少的目标。
func (l *leastConn) Select(targets []*Target) *Target {
var selected *Target
var minConns int64 = -1
for _, t := range targets {
if !t.healthy.Load() {
continue
}
conns := atomic.LoadInt64(&t.conns)
if selected == nil || conns < minConns {
selected = t
minConns = conns
}
}
return selected
}
// Server TCP/UDP Stream 代理服务器。
type Server struct {
listeners map[string]net.Listener
upstreams map[string]*Upstream
connCount int64 // 当前连接数
mu sync.RWMutex
running atomic.Bool
}
// Upstream Stream 上游配置。
type Upstream struct {
name string
targets []*Target
balancer Balancer
healthChk *HealthChecker
mu sync.RWMutex
}
// Target Stream 目标服务器。
type Target struct {
addr string
weight int
healthy atomic.Bool
conns int64 // 当前连接数
}
// HealthChecker Stream 健康检查器。
type HealthChecker struct {
upstream *Upstream
interval time.Duration
timeout time.Duration
stopCh chan struct{}
}
// Config Stream 配置。
type Config struct {
Listen string // 监听地址
Protocol string // tcp 或 udp
Upstream UpstreamSpec // 上游配置
}
// UpstreamSpec 上游配置规格。
type UpstreamSpec struct {
Name string
Targets []TargetSpec
LoadBalance string
HealthCheck HealthCheckSpec
}
// TargetSpec 目标配置规格。
type TargetSpec struct {
Addr string
Weight int
}
// HealthCheckSpec 健康检查配置规格。
type HealthCheckSpec struct {
Interval time.Duration
Timeout time.Duration
Enabled bool
}
// NewServer 创建 Stream 服务器。
func NewServer() *Server {
return &Server{
listeners: make(map[string]net.Listener),
upstreams: make(map[string]*Upstream),
}
}
// AddUpstream 添加上游配置。
func (s *Server) AddUpstream(name string, targets []TargetSpec, lbType string, hcSpec HealthCheckSpec) error {
s.mu.Lock()
defer s.mu.Unlock()
// 创建目标列表
tgts := make([]*Target, len(targets))
for i, t := range targets {
tgts[i] = &Target{
addr: t.Addr,
weight: t.Weight,
}
tgts[i].healthy.Store(true) // 初始假设健康
}
// 创建负载均衡器
var balancer Balancer
switch lbType {
case "round_robin":
balancer = newRoundRobin()
case "least_conn":
balancer = newLeastConn()
default:
balancer = newRoundRobin()
}
upstream := &Upstream{
name: name,
targets: tgts,
balancer: balancer,
}
// 启动健康检查
if hcSpec.Enabled {
upstream.healthChk = &HealthChecker{
upstream: upstream,
interval: hcSpec.Interval,
timeout: hcSpec.Timeout,
stopCh: make(chan struct{}),
}
go upstream.healthChk.Start()
}
s.upstreams[name] = upstream
return nil
}
// ListenTCP 开始监听 TCP 端口。
func (s *Server) ListenTCP(addr string) error {
s.mu.Lock()
defer s.mu.Unlock()
listener, err := net.Listen("tcp", addr)
if err != nil {
return err
}
s.listeners[addr] = listener
return nil
}
// ListenUDP 开始监听 UDP 端口。
func (s *Server) ListenUDP(addr string) error {
s.mu.Lock()
defer s.mu.Unlock()
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return err
}
conn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
return err
}
// UDP 用 UDPConn 而非 Listener需要特殊处理
s.listeners[addr] = &udpListener{conn: conn}
return nil
}
// Start 启动 Stream 服务器。
func (s *Server) Start() error {
s.running.Store(true)
s.mu.RLock()
defer s.mu.RUnlock()
for addr, listener := range s.listeners {
go s.acceptLoop(addr, listener)
}
return nil
}
// acceptLoop 接受连接循环。
func (s *Server) acceptLoop(addr string, listener net.Listener) {
for s.running.Load() {
conn, err := listener.Accept()
if err != nil {
if !s.running.Load() {
return // 正常关闭
}
continue
}
s.connCount++
go s.handleConnection(conn, addr)
}
}
// handleConnection 处理单个连接。
func (s *Server) handleConnection(clientConn net.Conn, addr string) {
defer func() {
clientConn.Close()
s.connCount--
}()
s.mu.RLock()
// 根据监听地址找到对应 upstream简化用第一个
var upstream *Upstream
for _, up := range s.upstreams {
upstream = up
break
}
s.mu.RUnlock()
if upstream == nil {
return // 无上游配置
}
// 选择目标
target := upstream.Select()
if target == nil {
return // 无可用目标
}
target.conns++
defer func() { target.conns-- }()
// 连接目标
targetConn, err := net.DialTimeout("tcp", target.addr, 10*time.Second)
if err != nil {
target.healthy.Store(false)
return
}
defer targetConn.Close()
// 双向数据转发
go io.Copy(targetConn, clientConn)
io.Copy(clientConn, targetConn)
}
// Select 选择健康的上游目标。
func (u *Upstream) Select() *Target {
u.mu.RLock()
defer u.mu.RUnlock()
// 获取健康目标列表
healthyTargets := make([]*Target, 0)
for _, t := range u.targets {
if t.healthy.Load() {
healthyTargets = append(healthyTargets, t)
}
}
if len(healthyTargets) == 0 {
return nil
}
// 使用负载均衡器选择
return u.balancer.Select(healthyTargets)
}
// Start 启动健康检查。
func (h *HealthChecker) Start() {
ticker := time.NewTicker(h.interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
h.check()
case <-h.stopCh:
return
}
}
}
// check 执行健康检查。
func (h *HealthChecker) check() {
for _, target := range h.upstream.targets {
conn, err := net.DialTimeout("tcp", target.addr, h.timeout)
if err != nil {
target.healthy.Store(false)
} else {
conn.Close()
target.healthy.Store(true)
}
}
}
// Stop 停止健康检查。
func (h *HealthChecker) Stop() {
close(h.stopCh)
}
// Stop 停止 Stream 服务器。
func (s *Server) Stop() error {
s.running.Store(false)
s.mu.Lock()
defer s.mu.Unlock()
// 关闭所有监听器
for _, listener := range s.listeners {
listener.Close()
}
// 停止健康检查
for _, upstream := range s.upstreams {
if upstream.healthChk != nil {
upstream.healthChk.Stop()
}
}
return nil
}
// Stats 返回服务器统计信息。
func (s *Server) Stats() Stats {
return Stats{
Connections: s.connCount,
Listeners: len(s.listeners),
Upstreams: len(s.upstreams),
}
}
// Stats Stream 服务器统计。
type Stats struct {
Connections int64
Listeners int
Upstreams int
}
// udpListener UDP 监听器包装。
type udpListener struct {
conn *net.UDPConn
}
// Accept UDP 不支持 Accept返回错误。
func (u *udpListener) Accept() (net.Conn, error) {
return nil, io.EOF
}
// Close 关闭 UDP 连接。
func (u *udpListener) Close() error {
return u.conn.Close()
}
// Addr 返回本地地址。
func (u *udpListener) Addr() net.Addr {
return u.conn.LocalAddr()
}

View File

@ -0,0 +1,233 @@
package stream
import (
"net"
"sync"
"sync/atomic"
"testing"
"time"
)
func TestNewServer(t *testing.T) {
s := NewServer()
if s == nil {
t.Error("Expected non-nil server")
}
if s.listeners == nil {
t.Error("Expected initialized listeners map")
}
if s.upstreams == nil {
t.Error("Expected initialized upstreams map")
}
}
func TestAddUpstream(t *testing.T) {
s := NewServer()
targets := []TargetSpec{
{Addr: "localhost:8001", Weight: 1},
{Addr: "localhost:8002", Weight: 2},
}
hcSpec := HealthCheckSpec{
Enabled: false,
Interval: 10 * time.Second,
Timeout: 5 * time.Second,
}
err := s.AddUpstream("test", targets, "round_robin", hcSpec)
if err != nil {
t.Errorf("AddUpstream failed: %v", err)
}
if len(s.upstreams) != 1 {
t.Errorf("Expected 1 upstream, got %d", len(s.upstreams))
}
up := s.upstreams["test"]
if up == nil {
t.Error("Expected non-nil upstream")
}
if len(up.targets) != 2 {
t.Errorf("Expected 2 targets, got %d", len(up.targets))
}
}
func TestRoundRobinBalancer(t *testing.T) {
targets := []*Target{
{addr: "localhost:8001"},
{addr: "localhost:8002"},
{addr: "localhost:8003"},
}
for _, target := range targets {
target.healthy.Store(true)
}
rr := newRoundRobin()
// 测试轮询
results := make(map[string]int)
for i := 0; i < 6; i++ {
selected := rr.Select(targets)
if selected == nil {
t.Error("Expected non-nil target")
continue
}
results[selected.addr]++
}
// 每个目标应该被选中 2 次
for _, target := range targets {
if results[target.addr] != 2 {
t.Errorf("Expected %s to be selected 2 times, got %d", target.addr, results[target.addr])
}
}
}
func TestLeastConnBalancer(t *testing.T) {
targets := []*Target{
{addr: "localhost:8001", conns: 5},
{addr: "localhost:8002", conns: 2},
{addr: "localhost:8003", conns: 8},
}
for _, t := range targets {
t.healthy.Store(true)
}
lc := newLeastConn()
selected := lc.Select(targets)
if selected == nil {
t.Error("Expected non-nil target")
} else if selected.addr != "localhost:8002" {
t.Errorf("Expected localhost:8002 (least connections), got %s", selected.addr)
}
}
func TestBalancerNoHealthyTargets(t *testing.T) {
targets := []*Target{
{addr: "localhost:8001"},
{addr: "localhost:8002"},
}
// 不设置 healthy默认为 false
rr := newRoundRobin()
selected := rr.Select(targets)
if selected != nil {
t.Error("Expected nil for no healthy targets")
}
lc := newLeastConn()
selected = lc.Select(targets)
if selected != nil {
t.Error("Expected nil for no healthy targets")
}
}
func TestServerStats(t *testing.T) {
s := NewServer()
stats := s.Stats()
if stats.Connections != 0 {
t.Errorf("Expected 0 connections, got %d", stats.Connections)
}
if stats.Listeners != 0 {
t.Errorf("Expected 0 listeners, got %d", stats.Listeners)
}
}
func TestUpstreamSelect(t *testing.T) {
u := &Upstream{
targets: []*Target{
{addr: "localhost:8001"},
{addr: "localhost:8002"},
},
balancer: newRoundRobin(),
}
for _, t := range u.targets {
t.healthy.Store(true)
}
selected := u.Select()
if selected == nil {
t.Error("Expected non-nil target")
}
}
func TestHealthChecker(t *testing.T) {
u := &Upstream{
targets: []*Target{
{addr: "localhost:99999"}, // 不存在的端口
},
}
hc := &HealthChecker{
upstream: u,
interval: 1 * time.Second,
timeout: 100 * time.Millisecond,
stopCh: make(chan struct{}),
}
// 执行一次检查
hc.check()
// 目标应该被标记为不健康
if u.targets[0].healthy.Load() {
t.Error("Expected target to be marked unhealthy")
}
}
func TestUDPListener(t *testing.T) {
addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to resolve UDP address: %v", err)
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
t.Fatalf("Failed to listen UDP: %v", err)
}
defer conn.Close()
ul := &udpListener{conn: conn}
// 测试 Addr
if ul.Addr() == nil {
t.Error("Expected non-nil address")
}
// 测试 Close
if err := ul.Close(); err != nil {
t.Errorf("Close failed: %v", err)
}
// 测试 Accept应该返回 io.EOF
_, err = ul.Accept()
if err == nil {
t.Error("Expected error from Accept")
}
}
func TestConcurrentConnections(t *testing.T) {
s := NewServer()
targets := []TargetSpec{
{Addr: "localhost:8001", Weight: 1},
}
s.AddUpstream("test", targets, "round_robin", HealthCheckSpec{})
// 并发增加连接数
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
atomic.AddInt64(&s.connCount, 1)
}()
}
wg.Wait()
if s.connCount != 100 {
t.Errorf("Expected 100 connections, got %d", s.connCount)
}
}

View File

@ -23,4 +23,4 @@ func main() {
}
os.Exit(app.Run(configPath, *genConfig, *outputPath, *showVersion))
}
}