refactor: 抽取网络工具函数到 netutil 包,移除冗余代码

- 新增 internal/netutil 包,统一 IP 提取和 URL 解析函数
- proxy/websocket/middleware 使用 netutil 替代重复实现
- 移除 handler/sendfile 中未使用的 BufferPool 相关代码
- 移除 http3/adapter 中未使用的反向转换函数
- 提取 server.registerStaticHandler 函数改进代码结构
- 优化 access.go 锁范围,减少持锁时间

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-03 18:24:21 +08:00
parent cd2d1a8194
commit 7a98a0b044
14 changed files with 525 additions and 536 deletions

View File

@ -22,7 +22,6 @@ import (
"net"
"os"
"runtime"
"sync"
"syscall"
"github.com/valyala/fasthttp"
@ -145,53 +144,3 @@ func getSocketFd(conn net.Conn) (uintptr, error) {
return 0, syscall.ENOTSUP
}
}
// BufferPool 缓冲池,复用内存减少分配。
var BufferPool = &syncPool{
pool: make(chan []byte, 32),
size: 32 * 1024, // 32KB
}
// syncPool 简化的缓冲池。
type syncPool struct {
pool chan []byte
size int
}
// Get 获取缓冲区。
func (p *syncPool) Get() []byte {
select {
case buf := <-p.pool:
return buf
default:
return make([]byte, p.size)
}
}
// Put 放回缓冲区。
func (p *syncPool) Put(buf []byte) {
// 只放回合适大小的缓冲区
if len(buf) == p.size {
select {
case p.pool <- buf:
default: // 池满,丢弃
}
}
}
// RealBufferPool 使用 sync.Pool 的标准实现(推荐)。
var RealBufferPool = sync.Pool{
New: func() interface{} {
return make([]byte, 32*1024)
},
}
// GetBuffer 从池获取缓冲区。
func GetBuffer() []byte {
return RealBufferPool.Get().([]byte)
}
// PutBuffer 放回缓冲区。
func PutBuffer(buf []byte) {
RealBufferPool.Put(buf) //nolint:staticcheck // SA6002: 测试表明指针优化不明显,保持简洁
}

View File

@ -14,61 +14,12 @@ import (
"github.com/valyala/fasthttp"
)
func TestBufferPool(t *testing.T) {
// 获取缓冲区
buf := BufferPool.Get()
if buf == nil {
t.Error("Expected non-nil buffer")
}
if len(buf) != 32*1024 {
t.Errorf("Expected buffer size 32KB, got %d", len(buf))
}
// 放回缓冲区
BufferPool.Put(buf)
// 再次获取(可能是同一个)
buf2 := BufferPool.Get()
if buf2 == nil {
t.Error("Expected non-nil buffer")
}
}
func TestRealBufferPool(t *testing.T) {
buf := GetBuffer()
if buf == nil {
t.Error("Expected non-nil buffer")
}
if len(buf) != 32*1024 {
t.Errorf("Expected buffer size 32KB, got %d", len(buf))
}
PutBuffer(buf)
}
func TestMinSendfileSize(t *testing.T) {
if MinSendfileSize != 8*1024 {
t.Errorf("Expected MinSendfileSize 8KB, got %d", MinSendfileSize)
}
}
func TestGetBuffer(t *testing.T) {
buf := GetBuffer()
if buf == nil {
t.Error("Expected non-nil buffer")
return
}
if len(buf) != 32*1024 {
t.Errorf("Expected buffer size 32KB, got %d", len(buf))
}
// 测试写入
copy(buf, []byte("test"))
if string(buf[:4]) != "test" {
t.Error("Expected to write 'test' to buffer")
}
}
func TestPlatformSendfile(t *testing.T) {
// 创建临时文件
tmpDir := t.TempDir()
@ -90,24 +41,6 @@ func TestPlatformSendfile(t *testing.T) {
_ = platformSendfile(nil, file, 0, int64(len(content)))
}
func TestBufferPoolConcurrent(t *testing.T) {
const iterations = 100
done := make(chan bool)
for i := 0; i < iterations; i++ {
go func() {
buf := GetBuffer()
PutBuffer(buf)
done <- true
}()
}
for i := 0; i < iterations; i++ {
<-done
}
}
// TestCopyFile 测试 copyFile fallback 函数
func TestCopyFile(t *testing.T) {
tmpDir := t.TempDir()

View File

@ -11,11 +11,9 @@
package http3
import (
"bytes"
"io"
"net"
"net/http"
"net/url"
"sync"
"github.com/valyala/fasthttp"
@ -159,93 +157,3 @@ func (a *Adapter) convertResponse(ctx *fasthttp.RequestCtx, w http.ResponseWrite
func (a *Adapter) WrapHandler(handler fasthttp.RequestHandler) http.Handler {
return a.Wrap(handler)
}
// FastHTTPHandler 从 http.Handler 提取并调用 fasthttp 处理器。
//
// 这是一个便捷方法,用于在需要时反向转换。
//
// 参数:
// - h: 标准库 http.Handler
// - ctx: FastHTTP 请求上下文
func FastHTTPHandler(h http.Handler, ctx *fasthttp.RequestCtx) {
// 创建虚拟 ResponseWriter
rw := &fastHTTPResponseWriter{
ctx: ctx,
}
// 转换请求
r := convertToHTTPRequest(ctx)
// 调用标准库 handler
h.ServeHTTP(rw, r)
}
// fastHTTPResponseWriter 实现 http.ResponseWriter 接口。
type fastHTTPResponseWriter struct {
ctx *fasthttp.RequestCtx
status int
header http.Header
written bool
}
func (w *fastHTTPResponseWriter) Header() http.Header {
if w.header == nil {
w.header = make(http.Header)
}
return w.header
}
func (w *fastHTTPResponseWriter) Write(data []byte) (int, error) {
if !w.written {
w.WriteHeader(http.StatusOK)
}
return w.ctx.Write(data)
}
func (w *fastHTTPResponseWriter) WriteHeader(statusCode int) {
if w.written {
return
}
w.written = true
w.status = statusCode
// 复制头部
for k, v := range w.header {
for _, vv := range v {
w.ctx.Response.Header.Add(k, vv)
}
}
w.ctx.SetStatusCode(statusCode)
}
// convertToHTTPRequest 将 fasthttp.RequestCtx 转换为 http.Request。
func convertToHTTPRequest(ctx *fasthttp.RequestCtx) *http.Request {
r := &http.Request{
Method: string(ctx.Method()),
Host: string(ctx.Host()),
RemoteAddr: ctx.RemoteAddr().String(),
Proto: "HTTP/3",
ProtoMajor: 3,
ProtoMinor: 0,
}
// 构建 URL
r.URL = &url.URL{
Path: string(ctx.Path()),
RawQuery: string(ctx.URI().QueryString()),
}
// 复制头部
r.Header = make(http.Header)
for k, v := range ctx.Request.Header.All() {
r.Header.Add(string(k), string(v))
}
// 设置请求体
if len(ctx.PostBody()) > 0 {
r.Body = io.NopCloser(bytes.NewReader(ctx.PostBody()))
}
return r
}

View File

@ -299,143 +299,6 @@ func TestConvertResponse_Body(t *testing.T) {
}
}
// TestFastHTTPHandler 测试反向转换
func TestFastHTTPHandler(t *testing.T) {
// 创建标准库 handler
stdHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(200)
_, _ = w.Write([]byte("Hello from std http"))
})
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
ctx.Request.SetRequestURI("/test")
ctx.Request.Header.SetMethod("GET")
FastHTTPHandler(stdHandler, ctx)
if ctx.Response.StatusCode() != 200 {
t.Errorf("Expected status 200, got %d", ctx.Response.StatusCode())
}
if string(ctx.Response.Body()) != "Hello from std http" {
t.Errorf("Expected body 'Hello from std http', got %s", ctx.Response.Body())
}
}
// TestConvertToHTTPRequest 测试转换为标准库请求
func TestConvertToHTTPRequest(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
ctx.Request.SetRequestURI("/path?query=value")
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.SetHost("example.com")
ctx.Request.Header.Set("Content-Type", "application/json")
ctx.Request.SetBody([]byte("test body"))
r := convertToHTTPRequest(ctx)
if r.Method != "POST" {
t.Errorf("Expected Method POST, got %s", r.Method)
}
if r.Host != "example.com" {
t.Errorf("Expected Host example.com, got %s", r.Host)
}
if r.URL.Path != "/path" {
t.Errorf("Expected Path /path, got %s", r.URL.Path)
}
if r.URL.RawQuery != "query=value" {
t.Errorf("Expected RawQuery query=value, got %s", r.URL.RawQuery)
}
if r.Proto != "HTTP/3" {
t.Errorf("Expected Proto HTTP/3, got %s", r.Proto)
}
if r.ProtoMajor != 3 || r.ProtoMinor != 0 {
t.Errorf("Expected Proto 3.0, got %d.%d", r.ProtoMajor, r.ProtoMinor)
}
// 检查头部
if r.Header.Get("Content-Type") != "application/json" {
t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
}
// 检查请求体
body, _ := io.ReadAll(r.Body)
if string(body) != "test body" {
t.Errorf("Expected body 'test body', got %s", body)
}
}
// TestFastHTTPResponseWriter_Write 测试 Write 方法
func TestFastHTTPResponseWriter_Write(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
rw := &fastHTTPResponseWriter{ctx: ctx}
n, err := rw.Write([]byte("test content"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if n != len("test content") {
t.Errorf("Expected written %d, got %d", len("test content"), n)
}
// 检查状态码被自动设置
if rw.status != http.StatusOK {
t.Errorf("Expected auto-set status 200, got %d", rw.status)
}
}
// TestFastHTTPResponseWriter_WriteHeader 测试 WriteHeader 方法
func TestFastHTTPResponseWriter_WriteHeader(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
rw := &fastHTTPResponseWriter{ctx: ctx}
rw.Header().Set("X-Custom", "value")
rw.WriteHeader(404)
if rw.status != 404 {
t.Errorf("Expected status 404, got %d", rw.status)
}
if rw.written != true {
t.Error("Expected written flag to be true")
}
// 再次调用应该被忽略
rw.WriteHeader(500)
if rw.status != 404 {
t.Errorf("Expected status to remain 404, got %d", rw.status)
}
}
// TestFastHTTPResponseWriter_Header 测试 Header 方法
func TestFastHTTPResponseWriter_Header(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
rw := &fastHTTPResponseWriter{ctx: ctx}
h := rw.Header()
if h == nil {
t.Error("Expected non-nil header")
}
h.Set("Content-Type", "text/html")
if rw.Header().Get("Content-Type") != "text/html" {
t.Errorf("Expected Content-Type text/html, got %s", rw.Header().Get("Content-Type"))
}
}
// TestWrap_RoundTrip 完整流程测试
func TestWrap_RoundTrip(t *testing.T) {
adapter := NewAdapter()

View File

@ -195,19 +195,14 @@ func (ac *AccessControl) Check(ip net.IP) bool {
// 返回值:
// - error: CIDR 解析失败时返回错误
func (ac *AccessControl) UpdateAllowList(cidrs []string) error {
ac.mu.Lock()
defer ac.mu.Unlock()
newList := make([]net.IPNet, 0, len(cidrs))
for _, cidr := range cidrs {
network, err := parseCIDR(cidr)
if err != nil {
return fmt.Errorf("invalid CIDR %s: %w", cidr, err)
}
newList = append(newList, *network)
newList, err := parseCIDRList(cidrs)
if err != nil {
return err
}
ac.mu.Lock()
ac.allowList = newList
ac.mu.Unlock()
return nil
}
@ -221,20 +216,35 @@ func (ac *AccessControl) UpdateAllowList(cidrs []string) error {
// 返回值:
// - error: CIDR 解析失败时返回错误
func (ac *AccessControl) UpdateDenyList(cidrs []string) error {
ac.mu.Lock()
defer ac.mu.Unlock()
newList, err := parseCIDRList(cidrs)
if err != nil {
return err
}
ac.mu.Lock()
ac.denyList = newList
ac.mu.Unlock()
return nil
}
// parseCIDRList 解析 CIDR 字符串列表为 IPNet 列表。
//
// 参数:
// - cidrs: CIDR 字符串列表
//
// 返回值:
// - []net.IPNet: 解析后的 IP 网络对象列表
// - error: 任一 CIDR 解析失败时返回错误
func parseCIDRList(cidrs []string) ([]net.IPNet, error) {
newList := make([]net.IPNet, 0, len(cidrs))
for _, cidr := range cidrs {
network, err := parseCIDR(cidr)
if err != nil {
return fmt.Errorf("invalid CIDR %s: %w", cidr, err)
return nil, fmt.Errorf("invalid CIDR %s: %w", cidr, err)
}
newList = append(newList, *network)
}
ac.denyList = newList
return nil
return newList, nil
}
// SetDefault 设置默认操作。

View File

@ -32,8 +32,6 @@ package security
import (
"errors"
"fmt"
"net"
"strings"
"sync"
"sync/atomic"
"time"
@ -41,6 +39,7 @@ import (
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/middleware"
"rua.plus/lolly/internal/netutil"
)
// RateLimiter 基于令牌桶算法的请求速率限制器。
@ -322,53 +321,13 @@ func (rl *RateLimiter) getRetryAfter(key string) int64 {
// 返回值:
// - string: IP 地址字符串,无法获取时返回 "unknown"
func keyByIP(ctx *fasthttp.RequestCtx) string {
ip := extractClientIP(ctx)
ip := netutil.ExtractClientIPNet(ctx)
if ip == nil {
return "unknown"
}
return ip.String()
}
// extractClientIP 从请求上下文提取客户端 IP。
//
// 按优先级依次检查X-Forwarded-For、X-Real-IP、RemoteAddr。
//
// 参数:
// - ctx: FastHTTP 请求上下文
//
// 返回值:
// - net.IP: 客户端 IP 地址,无法获取时返回 nil
func extractClientIP(ctx *fasthttp.RequestCtx) net.IP {
// 优先检查 X-Forwarded-For 头部
if xff := ctx.Request.Header.Peek("X-Forwarded-For"); len(xff) > 0 {
ips := strings.Split(string(xff), ",")
if len(ips) > 0 {
ipStr := strings.TrimSpace(ips[0])
ip := net.ParseIP(ipStr)
if ip != nil {
return ip
}
}
}
// 检查 X-Real-IP 头部
if xri := ctx.Request.Header.Peek("X-Real-IP"); len(xri) > 0 {
ip := net.ParseIP(string(xri))
if ip != nil {
return ip
}
}
// 回退到 RemoteAddr
if addr := ctx.RemoteAddr(); addr != nil {
if tcpAddr, ok := addr.(*net.TCPAddr); ok {
return tcpAddr.IP
}
}
return nil
}
// keyByHeader 提取头部值作为限流键。
//
// 默认使用 X-RateLimit-Key 头部,如果不存在则回退到 IP。

112
internal/netutil/ip.go Normal file
View File

@ -0,0 +1,112 @@
// Package netutil 提供网络相关的通用工具函数。
//
// 该文件包含客户端 IP 提取相关的工具函数,
// 从 HTTP 请求中提取真实的客户端 IP 地址。
//
// 作者xfy
package netutil
import (
"net"
"strings"
"github.com/valyala/fasthttp"
)
// ExtractClientIP 从请求上下文中提取客户端 IP 地址(返回字符串)。
//
// 该函数按以下顺序提取 IP
// 1. X-Forwarded-For 请求头的第一个 IP最左侧
// 2. X-Real-IP 请求头
// 3. RemoteAddr
//
// 注意:此函数不进行可信代理验证,适用于非安全场景(如日志记录)。
// 对于安全场景(如访问控制),应使用特定模块的安全实现。
//
// 参数:
// - ctx: FastHTTP 请求上下文
//
// 返回值:
// - string: 客户端 IP 地址字符串
func ExtractClientIP(ctx *fasthttp.RequestCtx) string {
// 首先检查 X-Forwarded-For 请求头
if xff := ctx.Request.Header.Peek("X-Forwarded-For"); len(xff) > 0 {
ips := strings.Split(string(xff), ",")
if len(ips) > 0 {
return strings.TrimSpace(ips[0])
}
}
// 检查 X-Real-IP 请求头
if xri := ctx.Request.Header.Peek("X-Real-IP"); len(xri) > 0 {
return string(xri)
}
// 回退到 RemoteAddr
if addr := ctx.RemoteAddr(); addr != nil {
if tcpAddr, ok := addr.(*net.TCPAddr); ok {
return tcpAddr.IP.String()
}
return addr.String()
}
return ""
}
// ExtractClientIPNet 从请求上下文中提取客户端 IP 地址(返回 net.IP
//
// 该函数与 ExtractClientIP 功能相同,但返回 net.IP 类型,
// 便于后续进行 IP 网络操作(如 CIDR 匹配)。
//
// 参数:
// - ctx: FastHTTP 请求上下文
//
// 返回值:
// - net.IP: 客户端 IP 地址,无法解析时返回 nil
func ExtractClientIPNet(ctx *fasthttp.RequestCtx) net.IP {
// 首先检查 X-Forwarded-For 请求头
if xff := ctx.Request.Header.Peek("X-Forwarded-For"); len(xff) > 0 {
ips := strings.Split(string(xff), ",")
if len(ips) > 0 {
ipStr := strings.TrimSpace(ips[0])
if ip := net.ParseIP(ipStr); ip != nil {
return ip
}
}
}
// 检查 X-Real-IP 请求头
if xri := ctx.Request.Header.Peek("X-Real-IP"); len(xri) > 0 {
if ip := net.ParseIP(string(xri)); ip != nil {
return ip
}
}
// 回退到 RemoteAddr
if addr := ctx.RemoteAddr(); addr != nil {
if tcpAddr, ok := addr.(*net.TCPAddr); ok {
return tcpAddr.IP
}
}
return nil
}
// GetRemoteAddrIP 从 RemoteAddr 提取 IP 地址。
//
// 这是一个辅助函数,直接从连接的远程地址获取 IP
// 不检查任何代理头。
//
// 参数:
// - ctx: FastHTTP 请求上下文
//
// 返回值:
// - net.IP: 客户端 IP 地址,无法获取时返回 nil
func GetRemoteAddrIP(ctx *fasthttp.RequestCtx) net.IP {
if addr := ctx.RemoteAddr(); addr != nil {
if tcpAddr, ok := addr.(*net.TCPAddr); ok {
return tcpAddr.IP
}
}
return nil
}

123
internal/netutil/ip_test.go Normal file
View File

@ -0,0 +1,123 @@
package netutil
import (
"net"
"testing"
"github.com/valyala/fasthttp"
)
func TestExtractClientIP(t *testing.T) {
tests := []struct {
name string
xff string
xri string
remoteAddr string
want string
}{
{
name: "X-Forwarded-For with single IP",
xff: "192.168.1.100",
want: "192.168.1.100",
},
{
name: "X-Forwarded-For with multiple IPs",
xff: "192.168.1.100, 10.0.0.1, 172.16.0.1",
want: "192.168.1.100",
},
{
name: "X-Real-IP only",
xri: "192.168.1.200",
want: "192.168.1.200",
},
{
name: "RemoteAddr fallback",
remoteAddr: "192.168.1.1:12345",
want: "0.0.0.0", // fasthttp 默认初始化为 0.0.0.0
},
{
name: "X-Forwarded-For takes precedence over X-Real-IP",
xff: "192.168.1.100",
xri: "192.168.1.200",
want: "192.168.1.100",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
if tt.xff != "" {
ctx.Request.Header.Set("X-Forwarded-For", tt.xff)
}
if tt.xri != "" {
ctx.Request.Header.Set("X-Real-IP", tt.xri)
}
got := ExtractClientIP(ctx)
if got != tt.want {
t.Errorf("ExtractClientIP() = %q, want %q", got, tt.want)
}
})
}
}
func TestExtractClientIPNet(t *testing.T) {
tests := []struct {
name string
xff string
xri string
want net.IP
}{
{
name: "X-Forwarded-For valid IP",
xff: "192.168.1.100",
want: net.ParseIP("192.168.1.100"),
},
{
name: "X-Forwarded-For invalid IP",
xff: "invalid-ip",
want: net.ParseIP("0.0.0.0"), // fasthttp 默认 RemoteAddr
},
{
name: "X-Real-IP valid IP",
xri: "192.168.1.200",
want: net.ParseIP("192.168.1.200"),
},
{
name: "No headers",
want: net.ParseIP("0.0.0.0"), // fasthttp 默认 RemoteAddr
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
if tt.xff != "" {
ctx.Request.Header.Set("X-Forwarded-For", tt.xff)
}
if tt.xri != "" {
ctx.Request.Header.Set("X-Real-IP", tt.xri)
}
got := ExtractClientIPNet(ctx)
if !got.Equal(tt.want) {
t.Errorf("ExtractClientIPNet() = %v, want %v", got, tt.want)
}
})
}
}
func TestGetRemoteAddrIP(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
// Without setting remote addr, should return nil
got := GetRemoteAddrIP(ctx)
// The result depends on how fasthttp initializes the remote addr
// Just verify it doesn't panic
_ = got
}

76
internal/netutil/url.go Normal file
View File

@ -0,0 +1,76 @@
// Package netutil 提供网络相关的通用工具函数。
//
// 该包包含 URL 解析、客户端 IP 提取等网络操作的工具函数,
// 供 proxy、middleware、server 等模块共享使用。
//
// 作者xfy
package netutil
import "strings"
// ParseTargetURL 解析目标 URL提取主机地址和 TLS 标志。
//
// 该函数用于统一处理代理模块中的 URL 解析逻辑,支持 http:// 和 https:// 前缀。
//
// 参数:
// - targetURL: 目标 URL 字符串(如 "http://backend:8080/path" 或 "https://api.example.com"
// - addDefaultPort: 是否在没有端口时添加默认端口(:80 或 :443
//
// 返回值:
// - addr: 主机地址(格式 host:port
// - isTLS: 是否使用 TLSHTTPS
//
// 示例:
//
// addr, isTLS := ParseTargetURL("https://api.example.com/api", true)
// // addr = "api.example.com:443", isTLS = true
//
// addr, isTLS := ParseTargetURL("http://backend:8080", false)
// // addr = "backend:8080", isTLS = false
func ParseTargetURL(targetURL string, addDefaultPort bool) (addr string, isTLS bool) {
addr = targetURL
// 处理协议前缀
if strings.HasPrefix(targetURL, "http://") {
addr = targetURL[7:]
} else if strings.HasPrefix(targetURL, "https://") {
addr = targetURL[8:]
isTLS = true
}
// 移除路径部分,只保留 host:port
if idx := strings.Index(addr, "/"); idx != -1 {
addr = addr[:idx]
}
// 如果需要,添加默认端口
if addDefaultPort && !strings.Contains(addr, ":") {
if isTLS {
addr = addr + ":443"
} else {
addr = addr + ":80"
}
}
return addr, isTLS
}
// ExtractHost 从 URL 提取主机地址host:port
//
// 该函数是 ParseTargetURL 的简化版本,始终添加默认端口,
// 用于需要完整地址但不需要 TLS 标志的场景。
//
// 参数:
// - targetURL: 目标 URL 字符串
//
// 返回值:
// - string: 主机地址(格式 host:port
//
// 示例:
//
// host := ExtractHost("https://api.example.com/api")
// // host = "api.example.com:443"
func ExtractHost(targetURL string) string {
addr, _ := ParseTargetURL(targetURL, true)
return addr
}

View File

@ -0,0 +1,147 @@
package netutil
import "testing"
func TestParseTargetURL(t *testing.T) {
tests := []struct {
name string
targetURL string
addDefaultPort bool
wantAddr string
wantIsTLS bool
}{
// HTTP without port
{
name: "http without port, add default",
targetURL: "http://backend.example.com",
addDefaultPort: true,
wantAddr: "backend.example.com:80",
wantIsTLS: false,
},
{
name: "http without port, no default",
targetURL: "http://backend.example.com",
addDefaultPort: false,
wantAddr: "backend.example.com",
wantIsTLS: false,
},
// HTTPS without port
{
name: "https without port, add default",
targetURL: "https://api.example.com",
addDefaultPort: true,
wantAddr: "api.example.com:443",
wantIsTLS: true,
},
{
name: "https without port, no default",
targetURL: "https://api.example.com",
addDefaultPort: false,
wantAddr: "api.example.com",
wantIsTLS: true,
},
// HTTP with port
{
name: "http with port",
targetURL: "http://backend:8080",
addDefaultPort: true,
wantAddr: "backend:8080",
wantIsTLS: false,
},
// HTTPS with port
{
name: "https with port",
targetURL: "https://api:8443",
addDefaultPort: true,
wantAddr: "api:8443",
wantIsTLS: true,
},
// With path
{
name: "http with path",
targetURL: "http://backend:8080/api/v1",
addDefaultPort: false,
wantAddr: "backend:8080",
wantIsTLS: false,
},
{
name: "https with path",
targetURL: "https://api.example.com/v1/users",
addDefaultPort: true,
wantAddr: "api.example.com:443",
wantIsTLS: true,
},
// No protocol (treat as HTTP)
{
name: "no protocol",
targetURL: "backend:8080",
addDefaultPort: false,
wantAddr: "backend:8080",
wantIsTLS: false,
},
{
name: "no protocol, no port, add default",
targetURL: "backend",
addDefaultPort: true,
wantAddr: "backend:80",
wantIsTLS: false,
},
// IPv6 address
{
name: "ipv6 address",
targetURL: "http://[::1]:8080",
addDefaultPort: false,
wantAddr: "[::1]:8080",
wantIsTLS: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotAddr, gotIsTLS := ParseTargetURL(tt.targetURL, tt.addDefaultPort)
if gotAddr != tt.wantAddr {
t.Errorf("ParseTargetURL() addr = %q, want %q", gotAddr, tt.wantAddr)
}
if gotIsTLS != tt.wantIsTLS {
t.Errorf("ParseTargetURL() isTLS = %v, want %v", gotIsTLS, tt.wantIsTLS)
}
})
}
}
func TestExtractHost(t *testing.T) {
tests := []struct {
name string
targetURL string
want string
}{
{
name: "http without port",
targetURL: "http://backend.example.com",
want: "backend.example.com:80",
},
{
name: "https without port",
targetURL: "https://api.example.com",
want: "api.example.com:443",
},
{
name: "http with port",
targetURL: "http://backend:8080",
want: "backend:8080",
},
{
name: "https with path",
targetURL: "https://api.example.com/v1/users",
want: "api.example.com:443",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := ExtractHost(tt.targetURL); got != tt.want {
t.Errorf("ExtractHost() = %q, want %q", got, tt.want)
}
})
}
}

View File

@ -34,7 +34,6 @@ package proxy
import (
"errors"
"net"
"strings"
"sync"
"time"
@ -44,6 +43,7 @@ import (
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/loadbalance"
"rua.plus/lolly/internal/logging"
"rua.plus/lolly/internal/netutil"
)
// Proxy 表示反向代理实例,负责将 HTTP 请求转发到后端目标。
@ -147,20 +147,7 @@ func createBalancer(cfg *config.ProxyConfig) (loadbalance.Balancer, error) {
// createHostClient 为后台目标 URL 创建 fasthttp.HostClient。
func createHostClient(targetURL string, timeout config.ProxyTimeout, transportCfg *config.TransportConfig) *fasthttp.HostClient {
// 从目标 URL 解析主机和协议
addr := targetURL
isTLS := false
if strings.HasPrefix(targetURL, "http://") {
addr = targetURL[7:]
} else if strings.HasPrefix(targetURL, "https://") {
addr = targetURL[8:]
isTLS = true
}
// 如果存在路径则移除,只保留 host:port
if idx := strings.Index(addr, "/"); idx != -1 {
addr = addr[:idx]
}
addr, isTLS := netutil.ParseTargetURL(targetURL, false)
// 默认值
maxIdleConnDuration := 90 * time.Second
@ -321,7 +308,7 @@ func (p *Proxy) selectTarget(ctx *fasthttp.RequestCtx) *loadbalance.Target {
// 对于 IPHash 负载均衡器,提取客户端 IP
if ipHash, ok := balancer.(*loadbalance.IPHash); ok {
clientIP := getClientIP(ctx)
clientIP := netutil.ExtractClientIP(ctx)
return ipHash.SelectByIP(targets, clientIP)
}
@ -339,7 +326,7 @@ func (p *Proxy) selectTarget(ctx *fasthttp.RequestCtx) *loadbalance.Target {
func (p *Proxy) extractHashKey(ctx *fasthttp.RequestCtx, hashKey string) string {
switch {
case hashKey == "ip" || hashKey == "":
return getClientIP(ctx)
return netutil.ExtractClientIP(ctx)
case hashKey == "uri":
return string(ctx.RequestURI())
case strings.HasPrefix(hashKey, "header:"):
@ -348,38 +335,12 @@ func (p *Proxy) extractHashKey(ctx *fasthttp.RequestCtx, hashKey string) string
if len(value) > 0 {
return string(value)
}
return getClientIP(ctx) // fallback to IP
return netutil.ExtractClientIP(ctx) // fallback to IP
default:
return getClientIP(ctx)
return netutil.ExtractClientIP(ctx)
}
}
// getClientIP 从请求上下文中提取客户端 IP 地址。
func getClientIP(ctx *fasthttp.RequestCtx) string {
// 首先检查 X-Forwarded-For 请求头
if xff := ctx.Request.Header.Peek("X-Forwarded-For"); len(xff) > 0 {
ips := strings.Split(string(xff), ",")
if len(ips) > 0 {
return strings.TrimSpace(ips[0])
}
}
// 检查 X-Real-IP 请求头
if xri := ctx.Request.Header.Peek("X-Real-IP"); len(xri) > 0 {
return string(xri)
}
// 回退到 RemoteAddr
if addr := ctx.RemoteAddr(); addr != nil {
if tcpAddr, ok := addr.(*net.TCPAddr); ok {
return tcpAddr.IP.String()
}
return addr.String()
}
return ""
}
// getClient 返回给定目标 URL 对应的 HostClient。
func (p *Proxy) getClient(targetURL string) *fasthttp.HostClient {
p.mu.RLock()
@ -394,7 +355,7 @@ func (p *Proxy) modifyRequestHeaders(ctx *fasthttp.RequestCtx, target *loadbalan
headers := &ctx.Request.Header
// 添加 X-Real-IP 请求头
clientIP := getClientIP(ctx)
clientIP := netutil.ExtractClientIP(ctx)
if clientIP != "" {
headers.Set("X-Real-IP", clientIP)
}

View File

@ -12,6 +12,7 @@ import (
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/loadbalance"
"rua.plus/lolly/internal/netutil"
)
// TestNewProxy 测试 NewProxy 函数
@ -587,9 +588,9 @@ func TestGetClientIP(t *testing.T) {
ctx.Request.Header.Set("X-Real-IP", tt.xri)
}
ip := getClientIP(ctx)
ip := netutil.ExtractClientIP(ctx)
if ip != tt.expected {
t.Errorf("getClientIP() = %q, want %q", ip, tt.expected)
t.Errorf("ExtractClientIP() = %q, want %q", ip, tt.expected)
}
})
}

View File

@ -33,6 +33,7 @@ import (
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/loadbalance"
"rua.plus/lolly/internal/netutil"
)
// WebSocketBridge WebSocket 桥接器。
@ -207,30 +208,7 @@ func isConnectionClosedError(err error) bool {
// - error: 连接失败时返回错误
func dialTarget(targetURL string, timeout time.Duration) (net.Conn, error) {
// 解析目标 URL
isTLS := false
addr := targetURL
// 处理协议前缀
if strings.HasPrefix(targetURL, "http://") {
addr = targetURL[7:]
} else if strings.HasPrefix(targetURL, "https://") {
addr = targetURL[8:]
isTLS = true
}
// 移除路径部分,只保留 host:port
if idx := strings.Index(addr, "/"); idx != -1 {
addr = addr[:idx]
}
// 如果没有端口,添加默认端口
if !strings.Contains(addr, ":") {
if isTLS {
addr = addr + ":443"
} else {
addr = addr + ":80"
}
}
addr, isTLS := netutil.ParseTargetURL(targetURL, true)
// 建立 TCP 连接
dialer := &net.Dialer{
@ -309,7 +287,7 @@ func buildWebSocketUpgradeRequest(ctx *fasthttp.RequestCtx, targetHost string) s
}
// 添加 X-Forwarded 头
clientIP := getClientIP(ctx)
clientIP := netutil.ExtractClientIP(ctx)
if clientIP != "" {
fmt.Fprintf(&req, "X-Forwarded-For: %s\r\n", clientIP)
fmt.Fprintf(&req, "X-Real-IP: %s\r\n", clientIP)
@ -394,49 +372,37 @@ func ProxyWebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, timeou
return fmt.Errorf("failed to connect to backend: %w", err)
}
// 创建桥接器管理两个连接
bridge := NewWebSocketBridge(clientConn, targetConn)
defer func() { _ = bridge.Close() }()
// 步骤2: 从目标 URL 提取主机地址
targetHost := extractHost(target.URL)
// 步骤3: 构建并发送 WebSocket 升级请求
upgradeReq := buildWebSocketUpgradeRequest(ctx, targetHost)
if _, err := targetConn.Write([]byte(upgradeReq)); err != nil {
_ = clientConn.Close()
_ = targetConn.Close()
return fmt.Errorf("failed to send upgrade request: %w", err)
}
// 步骤4: 读取升级响应
resp, err := readWebSocketUpgradeResponse(targetConn, timeout)
if err != nil {
_ = clientConn.Close()
_ = targetConn.Close()
return fmt.Errorf("failed to read upgrade response: %w", err)
}
// 步骤5: 检查响应状态码(期望 101 Switching Protocols
if resp.StatusCode != http.StatusSwitchingProtocols {
_ = clientConn.Close()
_ = targetConn.Close()
return fmt.Errorf("backend rejected WebSocket upgrade: %s", resp.Status)
}
// 步骤6: 将升级响应发送回客户端
if err := writeUpgradeResponse(clientConn, resp); err != nil {
_ = clientConn.Close()
_ = targetConn.Close()
return fmt.Errorf("failed to send upgrade response to client: %w", err)
}
// 步骤7: 创建桥接器并启动双向转发
bridge := NewWebSocketBridge(clientConn, targetConn)
// 启动桥接(阻塞直到连接关闭)
bridgeErr := bridge.Bridge()
// 清理:关闭连接
_ = bridge.Close()
return bridgeErr
// 步骤7: 启动桥接(阻塞直到连接关闭)
return bridge.Bridge()
}
// extractHost 从 URL 中提取主机地址(带端口)。
@ -449,28 +415,7 @@ func ProxyWebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, timeou
// 返回值:
// - string: 主机地址(格式 host:port
func extractHost(url string) string {
addr := url
if strings.HasPrefix(url, "http://") {
addr = url[7:]
} else if strings.HasPrefix(url, "https://") {
addr = url[8:]
}
// 移除路径部分
if idx := strings.Index(addr, "/"); idx != -1 {
addr = addr[:idx]
}
// 如果没有端口,添加默认端口
if !strings.Contains(addr, ":") {
if strings.HasPrefix(url, "https://") {
addr = addr + ":443"
} else {
addr = addr + ":80"
}
}
return addr
return netutil.ExtractHost(url)
}
// writeUpgradeResponse 将 HTTP 升级响应写回客户端。

View File

@ -340,23 +340,8 @@ func (s *Server) startSingleMode() error {
// 注册代理路由
s.registerProxyRoutes(router, &s.config.Server)
// 静态文件服务(作为 fallback
// 启用零拷贝传输优化(大文件使用 sendfile
staticHandler := handler.NewStaticHandler(
s.config.Server.Static.Root,
s.config.Server.Static.Index,
true, // useSendfile
)
// 设置文件缓存
if s.fileCache != nil {
staticHandler.SetFileCache(s.fileCache)
}
// 设置预压缩文件支持
if s.config.Server.Compression.GzipStatic {
staticHandler.SetGzipStatic(true, s.config.Server.Compression.GzipStaticExtensions)
}
router.GET("/{filepath:*}", staticHandler.Handle)
router.HEAD("/{filepath:*}", staticHandler.Handle)
// 静态文件服务
s.registerStaticHandler(router, &s.config.Server)
// 构建中间件链
chain, err := s.buildMiddlewareChain(&s.config.Server)
@ -692,3 +677,20 @@ func (s *Server) getProxyCacheStats() ProxyCacheStats {
}
return total
}
// registerStaticHandler registers static file handler.
func (s *Server) registerStaticHandler(router *handler.Router, cfg *config.ServerConfig) {
staticHandler := handler.NewStaticHandler(
cfg.Static.Root,
cfg.Static.Index,
true, // useSendfile
)
if s.fileCache != nil {
staticHandler.SetFileCache(s.fileCache)
}
if cfg.Compression.GzipStatic {
staticHandler.SetGzipStatic(true, cfg.Compression.GzipStaticExtensions)
}
router.GET("/{filepath:*}", staticHandler.Handle)
router.HEAD("/{filepath:*}", staticHandler.Handle)
}