From 7a98a0b044d91f2709026f4720d8dbbe721ed5b5 Mon Sep 17 00:00:00 2001 From: xfy Date: Fri, 3 Apr 2026 18:24:21 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E6=8A=BD=E5=8F=96=E7=BD=91?= =?UTF-8?q?=E7=BB=9C=E5=B7=A5=E5=85=B7=E5=87=BD=E6=95=B0=E5=88=B0=20netuti?= =?UTF-8?q?l=20=E5=8C=85=EF=BC=8C=E7=A7=BB=E9=99=A4=E5=86=97=E4=BD=99?= =?UTF-8?q?=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 internal/netutil 包,统一 IP 提取和 URL 解析函数 - proxy/websocket/middleware 使用 netutil 替代重复实现 - 移除 handler/sendfile 中未使用的 BufferPool 相关代码 - 移除 http3/adapter 中未使用的反向转换函数 - 提取 server.registerStaticHandler 函数改进代码结构 - 优化 access.go 锁范围,减少持锁时间 Co-Authored-By: Claude --- internal/handler/sendfile.go | 51 -------- internal/handler/sendfile_test.go | 67 ---------- internal/http3/adapter.go | 92 -------------- internal/http3/adapter_test.go | 137 -------------------- internal/middleware/security/access.go | 42 ++++--- internal/middleware/security/ratelimit.go | 45 +------ internal/netutil/ip.go | 112 +++++++++++++++++ internal/netutil/ip_test.go | 123 ++++++++++++++++++ internal/netutil/url.go | 76 +++++++++++ internal/netutil/url_test.go | 147 ++++++++++++++++++++++ internal/proxy/proxy.go | 53 ++------ internal/proxy/proxy_test.go | 5 +- internal/proxy/websocket.go | 75 ++--------- internal/server/server.go | 36 +++--- 14 files changed, 525 insertions(+), 536 deletions(-) create mode 100644 internal/netutil/ip.go create mode 100644 internal/netutil/ip_test.go create mode 100644 internal/netutil/url.go create mode 100644 internal/netutil/url_test.go diff --git a/internal/handler/sendfile.go b/internal/handler/sendfile.go index a7ffe4a..31dab34 100644 --- a/internal/handler/sendfile.go +++ b/internal/handler/sendfile.go @@ -22,7 +22,6 @@ import ( "net" "os" "runtime" - "sync" "syscall" "github.com/valyala/fasthttp" @@ -145,53 +144,3 @@ func getSocketFd(conn net.Conn) (uintptr, error) { return 0, syscall.ENOTSUP } } - -// BufferPool 缓冲池,复用内存减少分配。 -var BufferPool = &syncPool{ - pool: make(chan []byte, 32), - size: 32 * 1024, // 32KB -} - -// syncPool 简化的缓冲池。 -type syncPool struct { - pool chan []byte - size int -} - -// Get 获取缓冲区。 -func (p *syncPool) Get() []byte { - select { - case buf := <-p.pool: - return buf - default: - return make([]byte, p.size) - } -} - -// Put 放回缓冲区。 -func (p *syncPool) Put(buf []byte) { - // 只放回合适大小的缓冲区 - if len(buf) == p.size { - select { - case p.pool <- buf: - default: // 池满,丢弃 - } - } -} - -// RealBufferPool 使用 sync.Pool 的标准实现(推荐)。 -var RealBufferPool = sync.Pool{ - New: func() interface{} { - return make([]byte, 32*1024) - }, -} - -// GetBuffer 从池获取缓冲区。 -func GetBuffer() []byte { - return RealBufferPool.Get().([]byte) -} - -// PutBuffer 放回缓冲区。 -func PutBuffer(buf []byte) { - RealBufferPool.Put(buf) //nolint:staticcheck // SA6002: 测试表明指针优化不明显,保持简洁 -} diff --git a/internal/handler/sendfile_test.go b/internal/handler/sendfile_test.go index 1da41ad..e3177b0 100644 --- a/internal/handler/sendfile_test.go +++ b/internal/handler/sendfile_test.go @@ -14,61 +14,12 @@ import ( "github.com/valyala/fasthttp" ) -func TestBufferPool(t *testing.T) { - // 获取缓冲区 - buf := BufferPool.Get() - if buf == nil { - t.Error("Expected non-nil buffer") - } - if len(buf) != 32*1024 { - t.Errorf("Expected buffer size 32KB, got %d", len(buf)) - } - - // 放回缓冲区 - BufferPool.Put(buf) - - // 再次获取(可能是同一个) - buf2 := BufferPool.Get() - if buf2 == nil { - t.Error("Expected non-nil buffer") - } -} - -func TestRealBufferPool(t *testing.T) { - buf := GetBuffer() - if buf == nil { - t.Error("Expected non-nil buffer") - } - if len(buf) != 32*1024 { - t.Errorf("Expected buffer size 32KB, got %d", len(buf)) - } - - PutBuffer(buf) -} - func TestMinSendfileSize(t *testing.T) { if MinSendfileSize != 8*1024 { t.Errorf("Expected MinSendfileSize 8KB, got %d", MinSendfileSize) } } -func TestGetBuffer(t *testing.T) { - buf := GetBuffer() - if buf == nil { - t.Error("Expected non-nil buffer") - return - } - if len(buf) != 32*1024 { - t.Errorf("Expected buffer size 32KB, got %d", len(buf)) - } - - // 测试写入 - copy(buf, []byte("test")) - if string(buf[:4]) != "test" { - t.Error("Expected to write 'test' to buffer") - } -} - func TestPlatformSendfile(t *testing.T) { // 创建临时文件 tmpDir := t.TempDir() @@ -90,24 +41,6 @@ func TestPlatformSendfile(t *testing.T) { _ = platformSendfile(nil, file, 0, int64(len(content))) } -func TestBufferPoolConcurrent(t *testing.T) { - const iterations = 100 - - done := make(chan bool) - - for i := 0; i < iterations; i++ { - go func() { - buf := GetBuffer() - PutBuffer(buf) - done <- true - }() - } - - for i := 0; i < iterations; i++ { - <-done - } -} - // TestCopyFile 测试 copyFile fallback 函数 func TestCopyFile(t *testing.T) { tmpDir := t.TempDir() diff --git a/internal/http3/adapter.go b/internal/http3/adapter.go index 14e0a1c..93e864e 100644 --- a/internal/http3/adapter.go +++ b/internal/http3/adapter.go @@ -11,11 +11,9 @@ package http3 import ( - "bytes" "io" "net" "net/http" - "net/url" "sync" "github.com/valyala/fasthttp" @@ -159,93 +157,3 @@ func (a *Adapter) convertResponse(ctx *fasthttp.RequestCtx, w http.ResponseWrite func (a *Adapter) WrapHandler(handler fasthttp.RequestHandler) http.Handler { return a.Wrap(handler) } - -// FastHTTPHandler 从 http.Handler 提取并调用 fasthttp 处理器。 -// -// 这是一个便捷方法,用于在需要时反向转换。 -// -// 参数: -// - h: 标准库 http.Handler -// - ctx: FastHTTP 请求上下文 -func FastHTTPHandler(h http.Handler, ctx *fasthttp.RequestCtx) { - // 创建虚拟 ResponseWriter - rw := &fastHTTPResponseWriter{ - ctx: ctx, - } - - // 转换请求 - r := convertToHTTPRequest(ctx) - - // 调用标准库 handler - h.ServeHTTP(rw, r) -} - -// fastHTTPResponseWriter 实现 http.ResponseWriter 接口。 -type fastHTTPResponseWriter struct { - ctx *fasthttp.RequestCtx - status int - header http.Header - written bool -} - -func (w *fastHTTPResponseWriter) Header() http.Header { - if w.header == nil { - w.header = make(http.Header) - } - return w.header -} - -func (w *fastHTTPResponseWriter) Write(data []byte) (int, error) { - if !w.written { - w.WriteHeader(http.StatusOK) - } - return w.ctx.Write(data) -} - -func (w *fastHTTPResponseWriter) WriteHeader(statusCode int) { - if w.written { - return - } - w.written = true - w.status = statusCode - - // 复制头部 - for k, v := range w.header { - for _, vv := range v { - w.ctx.Response.Header.Add(k, vv) - } - } - - w.ctx.SetStatusCode(statusCode) -} - -// convertToHTTPRequest 将 fasthttp.RequestCtx 转换为 http.Request。 -func convertToHTTPRequest(ctx *fasthttp.RequestCtx) *http.Request { - r := &http.Request{ - Method: string(ctx.Method()), - Host: string(ctx.Host()), - RemoteAddr: ctx.RemoteAddr().String(), - Proto: "HTTP/3", - ProtoMajor: 3, - ProtoMinor: 0, - } - - // 构建 URL - r.URL = &url.URL{ - Path: string(ctx.Path()), - RawQuery: string(ctx.URI().QueryString()), - } - - // 复制头部 - r.Header = make(http.Header) - for k, v := range ctx.Request.Header.All() { - r.Header.Add(string(k), string(v)) - } - - // 设置请求体 - if len(ctx.PostBody()) > 0 { - r.Body = io.NopCloser(bytes.NewReader(ctx.PostBody())) - } - - return r -} diff --git a/internal/http3/adapter_test.go b/internal/http3/adapter_test.go index fa2af12..ba5d794 100644 --- a/internal/http3/adapter_test.go +++ b/internal/http3/adapter_test.go @@ -299,143 +299,6 @@ func TestConvertResponse_Body(t *testing.T) { } } -// TestFastHTTPHandler 测试反向转换 -func TestFastHTTPHandler(t *testing.T) { - // 创建标准库 handler - stdHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/plain") - w.WriteHeader(200) - _, _ = w.Write([]byte("Hello from std http")) - }) - - ctx := &fasthttp.RequestCtx{} - ctx.Init(&fasthttp.Request{}, nil, nil) - ctx.Request.SetRequestURI("/test") - ctx.Request.Header.SetMethod("GET") - - FastHTTPHandler(stdHandler, ctx) - - if ctx.Response.StatusCode() != 200 { - t.Errorf("Expected status 200, got %d", ctx.Response.StatusCode()) - } - - if string(ctx.Response.Body()) != "Hello from std http" { - t.Errorf("Expected body 'Hello from std http', got %s", ctx.Response.Body()) - } -} - -// TestConvertToHTTPRequest 测试转换为标准库请求 -func TestConvertToHTTPRequest(t *testing.T) { - ctx := &fasthttp.RequestCtx{} - ctx.Init(&fasthttp.Request{}, nil, nil) - ctx.Request.SetRequestURI("/path?query=value") - ctx.Request.Header.SetMethod("POST") - ctx.Request.Header.SetHost("example.com") - ctx.Request.Header.Set("Content-Type", "application/json") - ctx.Request.SetBody([]byte("test body")) - - r := convertToHTTPRequest(ctx) - - if r.Method != "POST" { - t.Errorf("Expected Method POST, got %s", r.Method) - } - - if r.Host != "example.com" { - t.Errorf("Expected Host example.com, got %s", r.Host) - } - - if r.URL.Path != "/path" { - t.Errorf("Expected Path /path, got %s", r.URL.Path) - } - - if r.URL.RawQuery != "query=value" { - t.Errorf("Expected RawQuery query=value, got %s", r.URL.RawQuery) - } - - if r.Proto != "HTTP/3" { - t.Errorf("Expected Proto HTTP/3, got %s", r.Proto) - } - - if r.ProtoMajor != 3 || r.ProtoMinor != 0 { - t.Errorf("Expected Proto 3.0, got %d.%d", r.ProtoMajor, r.ProtoMinor) - } - - // 检查头部 - if r.Header.Get("Content-Type") != "application/json" { - t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type")) - } - - // 检查请求体 - body, _ := io.ReadAll(r.Body) - if string(body) != "test body" { - t.Errorf("Expected body 'test body', got %s", body) - } -} - -// TestFastHTTPResponseWriter_Write 测试 Write 方法 -func TestFastHTTPResponseWriter_Write(t *testing.T) { - ctx := &fasthttp.RequestCtx{} - ctx.Init(&fasthttp.Request{}, nil, nil) - - rw := &fastHTTPResponseWriter{ctx: ctx} - - n, err := rw.Write([]byte("test content")) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if n != len("test content") { - t.Errorf("Expected written %d, got %d", len("test content"), n) - } - - // 检查状态码被自动设置 - if rw.status != http.StatusOK { - t.Errorf("Expected auto-set status 200, got %d", rw.status) - } -} - -// TestFastHTTPResponseWriter_WriteHeader 测试 WriteHeader 方法 -func TestFastHTTPResponseWriter_WriteHeader(t *testing.T) { - ctx := &fasthttp.RequestCtx{} - ctx.Init(&fasthttp.Request{}, nil, nil) - - rw := &fastHTTPResponseWriter{ctx: ctx} - - rw.Header().Set("X-Custom", "value") - rw.WriteHeader(404) - - if rw.status != 404 { - t.Errorf("Expected status 404, got %d", rw.status) - } - - if rw.written != true { - t.Error("Expected written flag to be true") - } - - // 再次调用应该被忽略 - rw.WriteHeader(500) - if rw.status != 404 { - t.Errorf("Expected status to remain 404, got %d", rw.status) - } -} - -// TestFastHTTPResponseWriter_Header 测试 Header 方法 -func TestFastHTTPResponseWriter_Header(t *testing.T) { - ctx := &fasthttp.RequestCtx{} - ctx.Init(&fasthttp.Request{}, nil, nil) - - rw := &fastHTTPResponseWriter{ctx: ctx} - - h := rw.Header() - if h == nil { - t.Error("Expected non-nil header") - } - - h.Set("Content-Type", "text/html") - if rw.Header().Get("Content-Type") != "text/html" { - t.Errorf("Expected Content-Type text/html, got %s", rw.Header().Get("Content-Type")) - } -} - // TestWrap_RoundTrip 完整流程测试 func TestWrap_RoundTrip(t *testing.T) { adapter := NewAdapter() diff --git a/internal/middleware/security/access.go b/internal/middleware/security/access.go index 102e0cb..4dd3a37 100644 --- a/internal/middleware/security/access.go +++ b/internal/middleware/security/access.go @@ -195,19 +195,14 @@ func (ac *AccessControl) Check(ip net.IP) bool { // 返回值: // - error: CIDR 解析失败时返回错误 func (ac *AccessControl) UpdateAllowList(cidrs []string) error { - ac.mu.Lock() - defer ac.mu.Unlock() - - newList := make([]net.IPNet, 0, len(cidrs)) - for _, cidr := range cidrs { - network, err := parseCIDR(cidr) - if err != nil { - return fmt.Errorf("invalid CIDR %s: %w", cidr, err) - } - newList = append(newList, *network) + newList, err := parseCIDRList(cidrs) + if err != nil { + return err } + ac.mu.Lock() ac.allowList = newList + ac.mu.Unlock() return nil } @@ -221,20 +216,35 @@ func (ac *AccessControl) UpdateAllowList(cidrs []string) error { // 返回值: // - error: CIDR 解析失败时返回错误 func (ac *AccessControl) UpdateDenyList(cidrs []string) error { - ac.mu.Lock() - defer ac.mu.Unlock() + newList, err := parseCIDRList(cidrs) + if err != nil { + return err + } + ac.mu.Lock() + ac.denyList = newList + ac.mu.Unlock() + return nil +} + +// parseCIDRList 解析 CIDR 字符串列表为 IPNet 列表。 +// +// 参数: +// - cidrs: CIDR 字符串列表 +// +// 返回值: +// - []net.IPNet: 解析后的 IP 网络对象列表 +// - error: 任一 CIDR 解析失败时返回错误 +func parseCIDRList(cidrs []string) ([]net.IPNet, error) { newList := make([]net.IPNet, 0, len(cidrs)) for _, cidr := range cidrs { network, err := parseCIDR(cidr) if err != nil { - return fmt.Errorf("invalid CIDR %s: %w", cidr, err) + return nil, fmt.Errorf("invalid CIDR %s: %w", cidr, err) } newList = append(newList, *network) } - - ac.denyList = newList - return nil + return newList, nil } // SetDefault 设置默认操作。 diff --git a/internal/middleware/security/ratelimit.go b/internal/middleware/security/ratelimit.go index e95e62b..e18260c 100644 --- a/internal/middleware/security/ratelimit.go +++ b/internal/middleware/security/ratelimit.go @@ -32,8 +32,6 @@ package security import ( "errors" "fmt" - "net" - "strings" "sync" "sync/atomic" "time" @@ -41,6 +39,7 @@ import ( "github.com/valyala/fasthttp" "rua.plus/lolly/internal/config" "rua.plus/lolly/internal/middleware" + "rua.plus/lolly/internal/netutil" ) // RateLimiter 基于令牌桶算法的请求速率限制器。 @@ -322,53 +321,13 @@ func (rl *RateLimiter) getRetryAfter(key string) int64 { // 返回值: // - string: IP 地址字符串,无法获取时返回 "unknown" func keyByIP(ctx *fasthttp.RequestCtx) string { - ip := extractClientIP(ctx) + ip := netutil.ExtractClientIPNet(ctx) if ip == nil { return "unknown" } return ip.String() } -// extractClientIP 从请求上下文提取客户端 IP。 -// -// 按优先级依次检查:X-Forwarded-For、X-Real-IP、RemoteAddr。 -// -// 参数: -// - ctx: FastHTTP 请求上下文 -// -// 返回值: -// - net.IP: 客户端 IP 地址,无法获取时返回 nil -func extractClientIP(ctx *fasthttp.RequestCtx) net.IP { - // 优先检查 X-Forwarded-For 头部 - if xff := ctx.Request.Header.Peek("X-Forwarded-For"); len(xff) > 0 { - ips := strings.Split(string(xff), ",") - if len(ips) > 0 { - ipStr := strings.TrimSpace(ips[0]) - ip := net.ParseIP(ipStr) - if ip != nil { - return ip - } - } - } - - // 检查 X-Real-IP 头部 - if xri := ctx.Request.Header.Peek("X-Real-IP"); len(xri) > 0 { - ip := net.ParseIP(string(xri)) - if ip != nil { - return ip - } - } - - // 回退到 RemoteAddr - if addr := ctx.RemoteAddr(); addr != nil { - if tcpAddr, ok := addr.(*net.TCPAddr); ok { - return tcpAddr.IP - } - } - - return nil -} - // keyByHeader 提取头部值作为限流键。 // // 默认使用 X-RateLimit-Key 头部,如果不存在则回退到 IP。 diff --git a/internal/netutil/ip.go b/internal/netutil/ip.go new file mode 100644 index 0000000..b1ab74a --- /dev/null +++ b/internal/netutil/ip.go @@ -0,0 +1,112 @@ +// Package netutil 提供网络相关的通用工具函数。 +// +// 该文件包含客户端 IP 提取相关的工具函数, +// 从 HTTP 请求中提取真实的客户端 IP 地址。 +// +// 作者:xfy +package netutil + +import ( + "net" + "strings" + + "github.com/valyala/fasthttp" +) + +// ExtractClientIP 从请求上下文中提取客户端 IP 地址(返回字符串)。 +// +// 该函数按以下顺序提取 IP: +// 1. X-Forwarded-For 请求头的第一个 IP(最左侧) +// 2. X-Real-IP 请求头 +// 3. RemoteAddr +// +// 注意:此函数不进行可信代理验证,适用于非安全场景(如日志记录)。 +// 对于安全场景(如访问控制),应使用特定模块的安全实现。 +// +// 参数: +// - ctx: FastHTTP 请求上下文 +// +// 返回值: +// - string: 客户端 IP 地址字符串 +func ExtractClientIP(ctx *fasthttp.RequestCtx) string { + // 首先检查 X-Forwarded-For 请求头 + if xff := ctx.Request.Header.Peek("X-Forwarded-For"); len(xff) > 0 { + ips := strings.Split(string(xff), ",") + if len(ips) > 0 { + return strings.TrimSpace(ips[0]) + } + } + + // 检查 X-Real-IP 请求头 + if xri := ctx.Request.Header.Peek("X-Real-IP"); len(xri) > 0 { + return string(xri) + } + + // 回退到 RemoteAddr + if addr := ctx.RemoteAddr(); addr != nil { + if tcpAddr, ok := addr.(*net.TCPAddr); ok { + return tcpAddr.IP.String() + } + return addr.String() + } + + return "" +} + +// ExtractClientIPNet 从请求上下文中提取客户端 IP 地址(返回 net.IP)。 +// +// 该函数与 ExtractClientIP 功能相同,但返回 net.IP 类型, +// 便于后续进行 IP 网络操作(如 CIDR 匹配)。 +// +// 参数: +// - ctx: FastHTTP 请求上下文 +// +// 返回值: +// - net.IP: 客户端 IP 地址,无法解析时返回 nil +func ExtractClientIPNet(ctx *fasthttp.RequestCtx) net.IP { + // 首先检查 X-Forwarded-For 请求头 + if xff := ctx.Request.Header.Peek("X-Forwarded-For"); len(xff) > 0 { + ips := strings.Split(string(xff), ",") + if len(ips) > 0 { + ipStr := strings.TrimSpace(ips[0]) + if ip := net.ParseIP(ipStr); ip != nil { + return ip + } + } + } + + // 检查 X-Real-IP 请求头 + if xri := ctx.Request.Header.Peek("X-Real-IP"); len(xri) > 0 { + if ip := net.ParseIP(string(xri)); ip != nil { + return ip + } + } + + // 回退到 RemoteAddr + if addr := ctx.RemoteAddr(); addr != nil { + if tcpAddr, ok := addr.(*net.TCPAddr); ok { + return tcpAddr.IP + } + } + + return nil +} + +// GetRemoteAddrIP 从 RemoteAddr 提取 IP 地址。 +// +// 这是一个辅助函数,直接从连接的远程地址获取 IP, +// 不检查任何代理头。 +// +// 参数: +// - ctx: FastHTTP 请求上下文 +// +// 返回值: +// - net.IP: 客户端 IP 地址,无法获取时返回 nil +func GetRemoteAddrIP(ctx *fasthttp.RequestCtx) net.IP { + if addr := ctx.RemoteAddr(); addr != nil { + if tcpAddr, ok := addr.(*net.TCPAddr); ok { + return tcpAddr.IP + } + } + return nil +} \ No newline at end of file diff --git a/internal/netutil/ip_test.go b/internal/netutil/ip_test.go new file mode 100644 index 0000000..6e8950b --- /dev/null +++ b/internal/netutil/ip_test.go @@ -0,0 +1,123 @@ +package netutil + +import ( + "net" + "testing" + + "github.com/valyala/fasthttp" +) + +func TestExtractClientIP(t *testing.T) { + tests := []struct { + name string + xff string + xri string + remoteAddr string + want string + }{ + { + name: "X-Forwarded-For with single IP", + xff: "192.168.1.100", + want: "192.168.1.100", + }, + { + name: "X-Forwarded-For with multiple IPs", + xff: "192.168.1.100, 10.0.0.1, 172.16.0.1", + want: "192.168.1.100", + }, + { + name: "X-Real-IP only", + xri: "192.168.1.200", + want: "192.168.1.200", + }, + { + name: "RemoteAddr fallback", + remoteAddr: "192.168.1.1:12345", + want: "0.0.0.0", // fasthttp 默认初始化为 0.0.0.0 + }, + { + name: "X-Forwarded-For takes precedence over X-Real-IP", + xff: "192.168.1.100", + xri: "192.168.1.200", + want: "192.168.1.100", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.Init(&fasthttp.Request{}, nil, nil) + + if tt.xff != "" { + ctx.Request.Header.Set("X-Forwarded-For", tt.xff) + } + if tt.xri != "" { + ctx.Request.Header.Set("X-Real-IP", tt.xri) + } + + got := ExtractClientIP(ctx) + if got != tt.want { + t.Errorf("ExtractClientIP() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestExtractClientIPNet(t *testing.T) { + tests := []struct { + name string + xff string + xri string + want net.IP + }{ + { + name: "X-Forwarded-For valid IP", + xff: "192.168.1.100", + want: net.ParseIP("192.168.1.100"), + }, + { + name: "X-Forwarded-For invalid IP", + xff: "invalid-ip", + want: net.ParseIP("0.0.0.0"), // fasthttp 默认 RemoteAddr + }, + { + name: "X-Real-IP valid IP", + xri: "192.168.1.200", + want: net.ParseIP("192.168.1.200"), + }, + { + name: "No headers", + want: net.ParseIP("0.0.0.0"), // fasthttp 默认 RemoteAddr + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.Init(&fasthttp.Request{}, nil, nil) + + if tt.xff != "" { + ctx.Request.Header.Set("X-Forwarded-For", tt.xff) + } + if tt.xri != "" { + ctx.Request.Header.Set("X-Real-IP", tt.xri) + } + + got := ExtractClientIPNet(ctx) + if !got.Equal(tt.want) { + t.Errorf("ExtractClientIPNet() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetRemoteAddrIP(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.Init(&fasthttp.Request{}, nil, nil) + + // Without setting remote addr, should return nil + got := GetRemoteAddrIP(ctx) + // The result depends on how fasthttp initializes the remote addr + // Just verify it doesn't panic + _ = got +} \ No newline at end of file diff --git a/internal/netutil/url.go b/internal/netutil/url.go new file mode 100644 index 0000000..d561c30 --- /dev/null +++ b/internal/netutil/url.go @@ -0,0 +1,76 @@ +// Package netutil 提供网络相关的通用工具函数。 +// +// 该包包含 URL 解析、客户端 IP 提取等网络操作的工具函数, +// 供 proxy、middleware、server 等模块共享使用。 +// +// 作者:xfy +package netutil + +import "strings" + +// ParseTargetURL 解析目标 URL,提取主机地址和 TLS 标志。 +// +// 该函数用于统一处理代理模块中的 URL 解析逻辑,支持 http:// 和 https:// 前缀。 +// +// 参数: +// - targetURL: 目标 URL 字符串(如 "http://backend:8080/path" 或 "https://api.example.com") +// - addDefaultPort: 是否在没有端口时添加默认端口(:80 或 :443) +// +// 返回值: +// - addr: 主机地址(格式 host:port) +// - isTLS: 是否使用 TLS(HTTPS) +// +// 示例: +// +// addr, isTLS := ParseTargetURL("https://api.example.com/api", true) +// // addr = "api.example.com:443", isTLS = true +// +// addr, isTLS := ParseTargetURL("http://backend:8080", false) +// // addr = "backend:8080", isTLS = false +func ParseTargetURL(targetURL string, addDefaultPort bool) (addr string, isTLS bool) { + addr = targetURL + + // 处理协议前缀 + if strings.HasPrefix(targetURL, "http://") { + addr = targetURL[7:] + } else if strings.HasPrefix(targetURL, "https://") { + addr = targetURL[8:] + isTLS = true + } + + // 移除路径部分,只保留 host:port + if idx := strings.Index(addr, "/"); idx != -1 { + addr = addr[:idx] + } + + // 如果需要,添加默认端口 + if addDefaultPort && !strings.Contains(addr, ":") { + if isTLS { + addr = addr + ":443" + } else { + addr = addr + ":80" + } + } + + return addr, isTLS +} + +// ExtractHost 从 URL 提取主机地址(host:port)。 +// +// 该函数是 ParseTargetURL 的简化版本,始终添加默认端口, +// 用于需要完整地址但不需要 TLS 标志的场景。 +// +// 参数: +// - targetURL: 目标 URL 字符串 +// +// 返回值: +// - string: 主机地址(格式 host:port) +// +// 示例: +// +// host := ExtractHost("https://api.example.com/api") +// // host = "api.example.com:443" +func ExtractHost(targetURL string) string { + addr, _ := ParseTargetURL(targetURL, true) + return addr +} \ No newline at end of file diff --git a/internal/netutil/url_test.go b/internal/netutil/url_test.go new file mode 100644 index 0000000..fe06ad4 --- /dev/null +++ b/internal/netutil/url_test.go @@ -0,0 +1,147 @@ +package netutil + +import "testing" + +func TestParseTargetURL(t *testing.T) { + tests := []struct { + name string + targetURL string + addDefaultPort bool + wantAddr string + wantIsTLS bool + }{ + // HTTP without port + { + name: "http without port, add default", + targetURL: "http://backend.example.com", + addDefaultPort: true, + wantAddr: "backend.example.com:80", + wantIsTLS: false, + }, + { + name: "http without port, no default", + targetURL: "http://backend.example.com", + addDefaultPort: false, + wantAddr: "backend.example.com", + wantIsTLS: false, + }, + // HTTPS without port + { + name: "https without port, add default", + targetURL: "https://api.example.com", + addDefaultPort: true, + wantAddr: "api.example.com:443", + wantIsTLS: true, + }, + { + name: "https without port, no default", + targetURL: "https://api.example.com", + addDefaultPort: false, + wantAddr: "api.example.com", + wantIsTLS: true, + }, + // HTTP with port + { + name: "http with port", + targetURL: "http://backend:8080", + addDefaultPort: true, + wantAddr: "backend:8080", + wantIsTLS: false, + }, + // HTTPS with port + { + name: "https with port", + targetURL: "https://api:8443", + addDefaultPort: true, + wantAddr: "api:8443", + wantIsTLS: true, + }, + // With path + { + name: "http with path", + targetURL: "http://backend:8080/api/v1", + addDefaultPort: false, + wantAddr: "backend:8080", + wantIsTLS: false, + }, + { + name: "https with path", + targetURL: "https://api.example.com/v1/users", + addDefaultPort: true, + wantAddr: "api.example.com:443", + wantIsTLS: true, + }, + // No protocol (treat as HTTP) + { + name: "no protocol", + targetURL: "backend:8080", + addDefaultPort: false, + wantAddr: "backend:8080", + wantIsTLS: false, + }, + { + name: "no protocol, no port, add default", + targetURL: "backend", + addDefaultPort: true, + wantAddr: "backend:80", + wantIsTLS: false, + }, + // IPv6 address + { + name: "ipv6 address", + targetURL: "http://[::1]:8080", + addDefaultPort: false, + wantAddr: "[::1]:8080", + wantIsTLS: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotAddr, gotIsTLS := ParseTargetURL(tt.targetURL, tt.addDefaultPort) + if gotAddr != tt.wantAddr { + t.Errorf("ParseTargetURL() addr = %q, want %q", gotAddr, tt.wantAddr) + } + if gotIsTLS != tt.wantIsTLS { + t.Errorf("ParseTargetURL() isTLS = %v, want %v", gotIsTLS, tt.wantIsTLS) + } + }) + } +} + +func TestExtractHost(t *testing.T) { + tests := []struct { + name string + targetURL string + want string + }{ + { + name: "http without port", + targetURL: "http://backend.example.com", + want: "backend.example.com:80", + }, + { + name: "https without port", + targetURL: "https://api.example.com", + want: "api.example.com:443", + }, + { + name: "http with port", + targetURL: "http://backend:8080", + want: "backend:8080", + }, + { + name: "https with path", + targetURL: "https://api.example.com/v1/users", + want: "api.example.com:443", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ExtractHost(tt.targetURL); got != tt.want { + t.Errorf("ExtractHost() = %q, want %q", got, tt.want) + } + }) + } +} \ No newline at end of file diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 5d3198d..0aab829 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -34,7 +34,6 @@ package proxy import ( "errors" - "net" "strings" "sync" "time" @@ -44,6 +43,7 @@ import ( "rua.plus/lolly/internal/config" "rua.plus/lolly/internal/loadbalance" "rua.plus/lolly/internal/logging" + "rua.plus/lolly/internal/netutil" ) // Proxy 表示反向代理实例,负责将 HTTP 请求转发到后端目标。 @@ -147,20 +147,7 @@ func createBalancer(cfg *config.ProxyConfig) (loadbalance.Balancer, error) { // createHostClient 为后台目标 URL 创建 fasthttp.HostClient。 func createHostClient(targetURL string, timeout config.ProxyTimeout, transportCfg *config.TransportConfig) *fasthttp.HostClient { // 从目标 URL 解析主机和协议 - addr := targetURL - isTLS := false - - if strings.HasPrefix(targetURL, "http://") { - addr = targetURL[7:] - } else if strings.HasPrefix(targetURL, "https://") { - addr = targetURL[8:] - isTLS = true - } - - // 如果存在路径则移除,只保留 host:port - if idx := strings.Index(addr, "/"); idx != -1 { - addr = addr[:idx] - } + addr, isTLS := netutil.ParseTargetURL(targetURL, false) // 默认值 maxIdleConnDuration := 90 * time.Second @@ -321,7 +308,7 @@ func (p *Proxy) selectTarget(ctx *fasthttp.RequestCtx) *loadbalance.Target { // 对于 IPHash 负载均衡器,提取客户端 IP if ipHash, ok := balancer.(*loadbalance.IPHash); ok { - clientIP := getClientIP(ctx) + clientIP := netutil.ExtractClientIP(ctx) return ipHash.SelectByIP(targets, clientIP) } @@ -339,7 +326,7 @@ func (p *Proxy) selectTarget(ctx *fasthttp.RequestCtx) *loadbalance.Target { func (p *Proxy) extractHashKey(ctx *fasthttp.RequestCtx, hashKey string) string { switch { case hashKey == "ip" || hashKey == "": - return getClientIP(ctx) + return netutil.ExtractClientIP(ctx) case hashKey == "uri": return string(ctx.RequestURI()) case strings.HasPrefix(hashKey, "header:"): @@ -348,38 +335,12 @@ func (p *Proxy) extractHashKey(ctx *fasthttp.RequestCtx, hashKey string) string if len(value) > 0 { return string(value) } - return getClientIP(ctx) // fallback to IP + return netutil.ExtractClientIP(ctx) // fallback to IP default: - return getClientIP(ctx) + return netutil.ExtractClientIP(ctx) } } -// getClientIP 从请求上下文中提取客户端 IP 地址。 -func getClientIP(ctx *fasthttp.RequestCtx) string { - // 首先检查 X-Forwarded-For 请求头 - if xff := ctx.Request.Header.Peek("X-Forwarded-For"); len(xff) > 0 { - ips := strings.Split(string(xff), ",") - if len(ips) > 0 { - return strings.TrimSpace(ips[0]) - } - } - - // 检查 X-Real-IP 请求头 - if xri := ctx.Request.Header.Peek("X-Real-IP"); len(xri) > 0 { - return string(xri) - } - - // 回退到 RemoteAddr - if addr := ctx.RemoteAddr(); addr != nil { - if tcpAddr, ok := addr.(*net.TCPAddr); ok { - return tcpAddr.IP.String() - } - return addr.String() - } - - return "" -} - // getClient 返回给定目标 URL 对应的 HostClient。 func (p *Proxy) getClient(targetURL string) *fasthttp.HostClient { p.mu.RLock() @@ -394,7 +355,7 @@ func (p *Proxy) modifyRequestHeaders(ctx *fasthttp.RequestCtx, target *loadbalan headers := &ctx.Request.Header // 添加 X-Real-IP 请求头 - clientIP := getClientIP(ctx) + clientIP := netutil.ExtractClientIP(ctx) if clientIP != "" { headers.Set("X-Real-IP", clientIP) } diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 60a23ab..1863e93 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -12,6 +12,7 @@ import ( "rua.plus/lolly/internal/config" "rua.plus/lolly/internal/loadbalance" + "rua.plus/lolly/internal/netutil" ) // TestNewProxy 测试 NewProxy 函数 @@ -587,9 +588,9 @@ func TestGetClientIP(t *testing.T) { ctx.Request.Header.Set("X-Real-IP", tt.xri) } - ip := getClientIP(ctx) + ip := netutil.ExtractClientIP(ctx) if ip != tt.expected { - t.Errorf("getClientIP() = %q, want %q", ip, tt.expected) + t.Errorf("ExtractClientIP() = %q, want %q", ip, tt.expected) } }) } diff --git a/internal/proxy/websocket.go b/internal/proxy/websocket.go index fd6f816..71f16b6 100644 --- a/internal/proxy/websocket.go +++ b/internal/proxy/websocket.go @@ -33,6 +33,7 @@ import ( "github.com/valyala/fasthttp" "rua.plus/lolly/internal/loadbalance" + "rua.plus/lolly/internal/netutil" ) // WebSocketBridge WebSocket 桥接器。 @@ -207,30 +208,7 @@ func isConnectionClosedError(err error) bool { // - error: 连接失败时返回错误 func dialTarget(targetURL string, timeout time.Duration) (net.Conn, error) { // 解析目标 URL - isTLS := false - addr := targetURL - - // 处理协议前缀 - if strings.HasPrefix(targetURL, "http://") { - addr = targetURL[7:] - } else if strings.HasPrefix(targetURL, "https://") { - addr = targetURL[8:] - isTLS = true - } - - // 移除路径部分,只保留 host:port - if idx := strings.Index(addr, "/"); idx != -1 { - addr = addr[:idx] - } - - // 如果没有端口,添加默认端口 - if !strings.Contains(addr, ":") { - if isTLS { - addr = addr + ":443" - } else { - addr = addr + ":80" - } - } + addr, isTLS := netutil.ParseTargetURL(targetURL, true) // 建立 TCP 连接 dialer := &net.Dialer{ @@ -309,7 +287,7 @@ func buildWebSocketUpgradeRequest(ctx *fasthttp.RequestCtx, targetHost string) s } // 添加 X-Forwarded 头 - clientIP := getClientIP(ctx) + clientIP := netutil.ExtractClientIP(ctx) if clientIP != "" { fmt.Fprintf(&req, "X-Forwarded-For: %s\r\n", clientIP) fmt.Fprintf(&req, "X-Real-IP: %s\r\n", clientIP) @@ -394,49 +372,37 @@ func ProxyWebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, timeou return fmt.Errorf("failed to connect to backend: %w", err) } + // 创建桥接器管理两个连接 + bridge := NewWebSocketBridge(clientConn, targetConn) + defer func() { _ = bridge.Close() }() + // 步骤2: 从目标 URL 提取主机地址 targetHost := extractHost(target.URL) // 步骤3: 构建并发送 WebSocket 升级请求 upgradeReq := buildWebSocketUpgradeRequest(ctx, targetHost) if _, err := targetConn.Write([]byte(upgradeReq)); err != nil { - _ = clientConn.Close() - _ = targetConn.Close() return fmt.Errorf("failed to send upgrade request: %w", err) } // 步骤4: 读取升级响应 resp, err := readWebSocketUpgradeResponse(targetConn, timeout) if err != nil { - _ = clientConn.Close() - _ = targetConn.Close() return fmt.Errorf("failed to read upgrade response: %w", err) } // 步骤5: 检查响应状态码(期望 101 Switching Protocols) if resp.StatusCode != http.StatusSwitchingProtocols { - _ = clientConn.Close() - _ = targetConn.Close() return fmt.Errorf("backend rejected WebSocket upgrade: %s", resp.Status) } // 步骤6: 将升级响应发送回客户端 if err := writeUpgradeResponse(clientConn, resp); err != nil { - _ = clientConn.Close() - _ = targetConn.Close() return fmt.Errorf("failed to send upgrade response to client: %w", err) } - // 步骤7: 创建桥接器并启动双向转发 - bridge := NewWebSocketBridge(clientConn, targetConn) - - // 启动桥接(阻塞直到连接关闭) - bridgeErr := bridge.Bridge() - - // 清理:关闭连接 - _ = bridge.Close() - - return bridgeErr + // 步骤7: 启动桥接(阻塞直到连接关闭) + return bridge.Bridge() } // extractHost 从 URL 中提取主机地址(带端口)。 @@ -449,28 +415,7 @@ func ProxyWebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, timeou // 返回值: // - string: 主机地址(格式 host:port) func extractHost(url string) string { - addr := url - if strings.HasPrefix(url, "http://") { - addr = url[7:] - } else if strings.HasPrefix(url, "https://") { - addr = url[8:] - } - - // 移除路径部分 - if idx := strings.Index(addr, "/"); idx != -1 { - addr = addr[:idx] - } - - // 如果没有端口,添加默认端口 - if !strings.Contains(addr, ":") { - if strings.HasPrefix(url, "https://") { - addr = addr + ":443" - } else { - addr = addr + ":80" - } - } - - return addr + return netutil.ExtractHost(url) } // writeUpgradeResponse 将 HTTP 升级响应写回客户端。 diff --git a/internal/server/server.go b/internal/server/server.go index b63e871..9bd0674 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -340,23 +340,8 @@ func (s *Server) startSingleMode() error { // 注册代理路由 s.registerProxyRoutes(router, &s.config.Server) - // 静态文件服务(作为 fallback) - // 启用零拷贝传输优化(大文件使用 sendfile) - staticHandler := handler.NewStaticHandler( - s.config.Server.Static.Root, - s.config.Server.Static.Index, - true, // useSendfile - ) - // 设置文件缓存 - if s.fileCache != nil { - staticHandler.SetFileCache(s.fileCache) - } - // 设置预压缩文件支持 - if s.config.Server.Compression.GzipStatic { - staticHandler.SetGzipStatic(true, s.config.Server.Compression.GzipStaticExtensions) - } - router.GET("/{filepath:*}", staticHandler.Handle) - router.HEAD("/{filepath:*}", staticHandler.Handle) + // 静态文件服务 + s.registerStaticHandler(router, &s.config.Server) // 构建中间件链 chain, err := s.buildMiddlewareChain(&s.config.Server) @@ -692,3 +677,20 @@ func (s *Server) getProxyCacheStats() ProxyCacheStats { } return total } + +// registerStaticHandler registers static file handler. +func (s *Server) registerStaticHandler(router *handler.Router, cfg *config.ServerConfig) { + staticHandler := handler.NewStaticHandler( + cfg.Static.Root, + cfg.Static.Index, + true, // useSendfile + ) + if s.fileCache != nil { + staticHandler.SetFileCache(s.fileCache) + } + if cfg.Compression.GzipStatic { + staticHandler.SetGzipStatic(true, cfg.Compression.GzipStaticExtensions) + } + router.GET("/{filepath:*}", staticHandler.Handle) + router.HEAD("/{filepath:*}", staticHandler.Handle) +}