refactor: 抽取网络工具函数到 netutil 包,移除冗余代码
- 新增 internal/netutil 包,统一 IP 提取和 URL 解析函数 - proxy/websocket/middleware 使用 netutil 替代重复实现 - 移除 handler/sendfile 中未使用的 BufferPool 相关代码 - 移除 http3/adapter 中未使用的反向转换函数 - 提取 server.registerStaticHandler 函数改进代码结构 - 优化 access.go 锁范围,减少持锁时间 Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
cd2d1a8194
commit
7a98a0b044
@ -22,7 +22,6 @@ import (
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
@ -145,53 +144,3 @@ func getSocketFd(conn net.Conn) (uintptr, error) {
|
||||
return 0, syscall.ENOTSUP
|
||||
}
|
||||
}
|
||||
|
||||
// BufferPool 缓冲池,复用内存减少分配。
|
||||
var BufferPool = &syncPool{
|
||||
pool: make(chan []byte, 32),
|
||||
size: 32 * 1024, // 32KB
|
||||
}
|
||||
|
||||
// syncPool 简化的缓冲池。
|
||||
type syncPool struct {
|
||||
pool chan []byte
|
||||
size int
|
||||
}
|
||||
|
||||
// Get 获取缓冲区。
|
||||
func (p *syncPool) Get() []byte {
|
||||
select {
|
||||
case buf := <-p.pool:
|
||||
return buf
|
||||
default:
|
||||
return make([]byte, p.size)
|
||||
}
|
||||
}
|
||||
|
||||
// Put 放回缓冲区。
|
||||
func (p *syncPool) Put(buf []byte) {
|
||||
// 只放回合适大小的缓冲区
|
||||
if len(buf) == p.size {
|
||||
select {
|
||||
case p.pool <- buf:
|
||||
default: // 池满,丢弃
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RealBufferPool 使用 sync.Pool 的标准实现(推荐)。
|
||||
var RealBufferPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, 32*1024)
|
||||
},
|
||||
}
|
||||
|
||||
// GetBuffer 从池获取缓冲区。
|
||||
func GetBuffer() []byte {
|
||||
return RealBufferPool.Get().([]byte)
|
||||
}
|
||||
|
||||
// PutBuffer 放回缓冲区。
|
||||
func PutBuffer(buf []byte) {
|
||||
RealBufferPool.Put(buf) //nolint:staticcheck // SA6002: 测试表明指针优化不明显,保持简洁
|
||||
}
|
||||
|
||||
@ -14,61 +14,12 @@ import (
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func TestBufferPool(t *testing.T) {
|
||||
// 获取缓冲区
|
||||
buf := BufferPool.Get()
|
||||
if buf == nil {
|
||||
t.Error("Expected non-nil buffer")
|
||||
}
|
||||
if len(buf) != 32*1024 {
|
||||
t.Errorf("Expected buffer size 32KB, got %d", len(buf))
|
||||
}
|
||||
|
||||
// 放回缓冲区
|
||||
BufferPool.Put(buf)
|
||||
|
||||
// 再次获取(可能是同一个)
|
||||
buf2 := BufferPool.Get()
|
||||
if buf2 == nil {
|
||||
t.Error("Expected non-nil buffer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRealBufferPool(t *testing.T) {
|
||||
buf := GetBuffer()
|
||||
if buf == nil {
|
||||
t.Error("Expected non-nil buffer")
|
||||
}
|
||||
if len(buf) != 32*1024 {
|
||||
t.Errorf("Expected buffer size 32KB, got %d", len(buf))
|
||||
}
|
||||
|
||||
PutBuffer(buf)
|
||||
}
|
||||
|
||||
func TestMinSendfileSize(t *testing.T) {
|
||||
if MinSendfileSize != 8*1024 {
|
||||
t.Errorf("Expected MinSendfileSize 8KB, got %d", MinSendfileSize)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetBuffer(t *testing.T) {
|
||||
buf := GetBuffer()
|
||||
if buf == nil {
|
||||
t.Error("Expected non-nil buffer")
|
||||
return
|
||||
}
|
||||
if len(buf) != 32*1024 {
|
||||
t.Errorf("Expected buffer size 32KB, got %d", len(buf))
|
||||
}
|
||||
|
||||
// 测试写入
|
||||
copy(buf, []byte("test"))
|
||||
if string(buf[:4]) != "test" {
|
||||
t.Error("Expected to write 'test' to buffer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlatformSendfile(t *testing.T) {
|
||||
// 创建临时文件
|
||||
tmpDir := t.TempDir()
|
||||
@ -90,24 +41,6 @@ func TestPlatformSendfile(t *testing.T) {
|
||||
_ = platformSendfile(nil, file, 0, int64(len(content)))
|
||||
}
|
||||
|
||||
func TestBufferPoolConcurrent(t *testing.T) {
|
||||
const iterations = 100
|
||||
|
||||
done := make(chan bool)
|
||||
|
||||
for i := 0; i < iterations; i++ {
|
||||
go func() {
|
||||
buf := GetBuffer()
|
||||
PutBuffer(buf)
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
for i := 0; i < iterations; i++ {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
// TestCopyFile 测试 copyFile fallback 函数
|
||||
func TestCopyFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
@ -11,11 +11,9 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
@ -159,93 +157,3 @@ func (a *Adapter) convertResponse(ctx *fasthttp.RequestCtx, w http.ResponseWrite
|
||||
func (a *Adapter) WrapHandler(handler fasthttp.RequestHandler) http.Handler {
|
||||
return a.Wrap(handler)
|
||||
}
|
||||
|
||||
// FastHTTPHandler 从 http.Handler 提取并调用 fasthttp 处理器。
|
||||
//
|
||||
// 这是一个便捷方法,用于在需要时反向转换。
|
||||
//
|
||||
// 参数:
|
||||
// - h: 标准库 http.Handler
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
func FastHTTPHandler(h http.Handler, ctx *fasthttp.RequestCtx) {
|
||||
// 创建虚拟 ResponseWriter
|
||||
rw := &fastHTTPResponseWriter{
|
||||
ctx: ctx,
|
||||
}
|
||||
|
||||
// 转换请求
|
||||
r := convertToHTTPRequest(ctx)
|
||||
|
||||
// 调用标准库 handler
|
||||
h.ServeHTTP(rw, r)
|
||||
}
|
||||
|
||||
// fastHTTPResponseWriter 实现 http.ResponseWriter 接口。
|
||||
type fastHTTPResponseWriter struct {
|
||||
ctx *fasthttp.RequestCtx
|
||||
status int
|
||||
header http.Header
|
||||
written bool
|
||||
}
|
||||
|
||||
func (w *fastHTTPResponseWriter) Header() http.Header {
|
||||
if w.header == nil {
|
||||
w.header = make(http.Header)
|
||||
}
|
||||
return w.header
|
||||
}
|
||||
|
||||
func (w *fastHTTPResponseWriter) Write(data []byte) (int, error) {
|
||||
if !w.written {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
return w.ctx.Write(data)
|
||||
}
|
||||
|
||||
func (w *fastHTTPResponseWriter) WriteHeader(statusCode int) {
|
||||
if w.written {
|
||||
return
|
||||
}
|
||||
w.written = true
|
||||
w.status = statusCode
|
||||
|
||||
// 复制头部
|
||||
for k, v := range w.header {
|
||||
for _, vv := range v {
|
||||
w.ctx.Response.Header.Add(k, vv)
|
||||
}
|
||||
}
|
||||
|
||||
w.ctx.SetStatusCode(statusCode)
|
||||
}
|
||||
|
||||
// convertToHTTPRequest 将 fasthttp.RequestCtx 转换为 http.Request。
|
||||
func convertToHTTPRequest(ctx *fasthttp.RequestCtx) *http.Request {
|
||||
r := &http.Request{
|
||||
Method: string(ctx.Method()),
|
||||
Host: string(ctx.Host()),
|
||||
RemoteAddr: ctx.RemoteAddr().String(),
|
||||
Proto: "HTTP/3",
|
||||
ProtoMajor: 3,
|
||||
ProtoMinor: 0,
|
||||
}
|
||||
|
||||
// 构建 URL
|
||||
r.URL = &url.URL{
|
||||
Path: string(ctx.Path()),
|
||||
RawQuery: string(ctx.URI().QueryString()),
|
||||
}
|
||||
|
||||
// 复制头部
|
||||
r.Header = make(http.Header)
|
||||
for k, v := range ctx.Request.Header.All() {
|
||||
r.Header.Add(string(k), string(v))
|
||||
}
|
||||
|
||||
// 设置请求体
|
||||
if len(ctx.PostBody()) > 0 {
|
||||
r.Body = io.NopCloser(bytes.NewReader(ctx.PostBody()))
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
@ -299,143 +299,6 @@ func TestConvertResponse_Body(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestFastHTTPHandler 测试反向转换
|
||||
func TestFastHTTPHandler(t *testing.T) {
|
||||
// 创建标准库 handler
|
||||
stdHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(200)
|
||||
_, _ = w.Write([]byte("Hello from std http"))
|
||||
})
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
ctx.Request.SetRequestURI("/test")
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
|
||||
FastHTTPHandler(stdHandler, ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != 200 {
|
||||
t.Errorf("Expected status 200, got %d", ctx.Response.StatusCode())
|
||||
}
|
||||
|
||||
if string(ctx.Response.Body()) != "Hello from std http" {
|
||||
t.Errorf("Expected body 'Hello from std http', got %s", ctx.Response.Body())
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertToHTTPRequest 测试转换为标准库请求
|
||||
func TestConvertToHTTPRequest(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
ctx.Request.SetRequestURI("/path?query=value")
|
||||
ctx.Request.Header.SetMethod("POST")
|
||||
ctx.Request.Header.SetHost("example.com")
|
||||
ctx.Request.Header.Set("Content-Type", "application/json")
|
||||
ctx.Request.SetBody([]byte("test body"))
|
||||
|
||||
r := convertToHTTPRequest(ctx)
|
||||
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("Expected Method POST, got %s", r.Method)
|
||||
}
|
||||
|
||||
if r.Host != "example.com" {
|
||||
t.Errorf("Expected Host example.com, got %s", r.Host)
|
||||
}
|
||||
|
||||
if r.URL.Path != "/path" {
|
||||
t.Errorf("Expected Path /path, got %s", r.URL.Path)
|
||||
}
|
||||
|
||||
if r.URL.RawQuery != "query=value" {
|
||||
t.Errorf("Expected RawQuery query=value, got %s", r.URL.RawQuery)
|
||||
}
|
||||
|
||||
if r.Proto != "HTTP/3" {
|
||||
t.Errorf("Expected Proto HTTP/3, got %s", r.Proto)
|
||||
}
|
||||
|
||||
if r.ProtoMajor != 3 || r.ProtoMinor != 0 {
|
||||
t.Errorf("Expected Proto 3.0, got %d.%d", r.ProtoMajor, r.ProtoMinor)
|
||||
}
|
||||
|
||||
// 检查头部
|
||||
if r.Header.Get("Content-Type") != "application/json" {
|
||||
t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
|
||||
}
|
||||
|
||||
// 检查请求体
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
if string(body) != "test body" {
|
||||
t.Errorf("Expected body 'test body', got %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFastHTTPResponseWriter_Write 测试 Write 方法
|
||||
func TestFastHTTPResponseWriter_Write(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
rw := &fastHTTPResponseWriter{ctx: ctx}
|
||||
|
||||
n, err := rw.Write([]byte("test content"))
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if n != len("test content") {
|
||||
t.Errorf("Expected written %d, got %d", len("test content"), n)
|
||||
}
|
||||
|
||||
// 检查状态码被自动设置
|
||||
if rw.status != http.StatusOK {
|
||||
t.Errorf("Expected auto-set status 200, got %d", rw.status)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFastHTTPResponseWriter_WriteHeader 测试 WriteHeader 方法
|
||||
func TestFastHTTPResponseWriter_WriteHeader(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
rw := &fastHTTPResponseWriter{ctx: ctx}
|
||||
|
||||
rw.Header().Set("X-Custom", "value")
|
||||
rw.WriteHeader(404)
|
||||
|
||||
if rw.status != 404 {
|
||||
t.Errorf("Expected status 404, got %d", rw.status)
|
||||
}
|
||||
|
||||
if rw.written != true {
|
||||
t.Error("Expected written flag to be true")
|
||||
}
|
||||
|
||||
// 再次调用应该被忽略
|
||||
rw.WriteHeader(500)
|
||||
if rw.status != 404 {
|
||||
t.Errorf("Expected status to remain 404, got %d", rw.status)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFastHTTPResponseWriter_Header 测试 Header 方法
|
||||
func TestFastHTTPResponseWriter_Header(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
rw := &fastHTTPResponseWriter{ctx: ctx}
|
||||
|
||||
h := rw.Header()
|
||||
if h == nil {
|
||||
t.Error("Expected non-nil header")
|
||||
}
|
||||
|
||||
h.Set("Content-Type", "text/html")
|
||||
if rw.Header().Get("Content-Type") != "text/html" {
|
||||
t.Errorf("Expected Content-Type text/html, got %s", rw.Header().Get("Content-Type"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestWrap_RoundTrip 完整流程测试
|
||||
func TestWrap_RoundTrip(t *testing.T) {
|
||||
adapter := NewAdapter()
|
||||
|
||||
@ -195,19 +195,14 @@ func (ac *AccessControl) Check(ip net.IP) bool {
|
||||
// 返回值:
|
||||
// - error: CIDR 解析失败时返回错误
|
||||
func (ac *AccessControl) UpdateAllowList(cidrs []string) error {
|
||||
ac.mu.Lock()
|
||||
defer ac.mu.Unlock()
|
||||
|
||||
newList := make([]net.IPNet, 0, len(cidrs))
|
||||
for _, cidr := range cidrs {
|
||||
network, err := parseCIDR(cidr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid CIDR %s: %w", cidr, err)
|
||||
}
|
||||
newList = append(newList, *network)
|
||||
newList, err := parseCIDRList(cidrs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ac.mu.Lock()
|
||||
ac.allowList = newList
|
||||
ac.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -221,20 +216,35 @@ func (ac *AccessControl) UpdateAllowList(cidrs []string) error {
|
||||
// 返回值:
|
||||
// - error: CIDR 解析失败时返回错误
|
||||
func (ac *AccessControl) UpdateDenyList(cidrs []string) error {
|
||||
ac.mu.Lock()
|
||||
defer ac.mu.Unlock()
|
||||
newList, err := parseCIDRList(cidrs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ac.mu.Lock()
|
||||
ac.denyList = newList
|
||||
ac.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseCIDRList 解析 CIDR 字符串列表为 IPNet 列表。
|
||||
//
|
||||
// 参数:
|
||||
// - cidrs: CIDR 字符串列表
|
||||
//
|
||||
// 返回值:
|
||||
// - []net.IPNet: 解析后的 IP 网络对象列表
|
||||
// - error: 任一 CIDR 解析失败时返回错误
|
||||
func parseCIDRList(cidrs []string) ([]net.IPNet, error) {
|
||||
newList := make([]net.IPNet, 0, len(cidrs))
|
||||
for _, cidr := range cidrs {
|
||||
network, err := parseCIDR(cidr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid CIDR %s: %w", cidr, err)
|
||||
return nil, fmt.Errorf("invalid CIDR %s: %w", cidr, err)
|
||||
}
|
||||
newList = append(newList, *network)
|
||||
}
|
||||
|
||||
ac.denyList = newList
|
||||
return nil
|
||||
return newList, nil
|
||||
}
|
||||
|
||||
// SetDefault 设置默认操作。
|
||||
|
||||
@ -32,8 +32,6 @@ package security
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@ -41,6 +39,7 @@ import (
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/middleware"
|
||||
"rua.plus/lolly/internal/netutil"
|
||||
)
|
||||
|
||||
// RateLimiter 基于令牌桶算法的请求速率限制器。
|
||||
@ -322,53 +321,13 @@ func (rl *RateLimiter) getRetryAfter(key string) int64 {
|
||||
// 返回值:
|
||||
// - string: IP 地址字符串,无法获取时返回 "unknown"
|
||||
func keyByIP(ctx *fasthttp.RequestCtx) string {
|
||||
ip := extractClientIP(ctx)
|
||||
ip := netutil.ExtractClientIPNet(ctx)
|
||||
if ip == nil {
|
||||
return "unknown"
|
||||
}
|
||||
return ip.String()
|
||||
}
|
||||
|
||||
// extractClientIP 从请求上下文提取客户端 IP。
|
||||
//
|
||||
// 按优先级依次检查:X-Forwarded-For、X-Real-IP、RemoteAddr。
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
//
|
||||
// 返回值:
|
||||
// - net.IP: 客户端 IP 地址,无法获取时返回 nil
|
||||
func extractClientIP(ctx *fasthttp.RequestCtx) net.IP {
|
||||
// 优先检查 X-Forwarded-For 头部
|
||||
if xff := ctx.Request.Header.Peek("X-Forwarded-For"); len(xff) > 0 {
|
||||
ips := strings.Split(string(xff), ",")
|
||||
if len(ips) > 0 {
|
||||
ipStr := strings.TrimSpace(ips[0])
|
||||
ip := net.ParseIP(ipStr)
|
||||
if ip != nil {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 X-Real-IP 头部
|
||||
if xri := ctx.Request.Header.Peek("X-Real-IP"); len(xri) > 0 {
|
||||
ip := net.ParseIP(string(xri))
|
||||
if ip != nil {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
|
||||
// 回退到 RemoteAddr
|
||||
if addr := ctx.RemoteAddr(); addr != nil {
|
||||
if tcpAddr, ok := addr.(*net.TCPAddr); ok {
|
||||
return tcpAddr.IP
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// keyByHeader 提取头部值作为限流键。
|
||||
//
|
||||
// 默认使用 X-RateLimit-Key 头部,如果不存在则回退到 IP。
|
||||
|
||||
112
internal/netutil/ip.go
Normal file
112
internal/netutil/ip.go
Normal file
@ -0,0 +1,112 @@
|
||||
// Package netutil 提供网络相关的通用工具函数。
|
||||
//
|
||||
// 该文件包含客户端 IP 提取相关的工具函数,
|
||||
// 从 HTTP 请求中提取真实的客户端 IP 地址。
|
||||
//
|
||||
// 作者:xfy
|
||||
package netutil
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// ExtractClientIP 从请求上下文中提取客户端 IP 地址(返回字符串)。
|
||||
//
|
||||
// 该函数按以下顺序提取 IP:
|
||||
// 1. X-Forwarded-For 请求头的第一个 IP(最左侧)
|
||||
// 2. X-Real-IP 请求头
|
||||
// 3. RemoteAddr
|
||||
//
|
||||
// 注意:此函数不进行可信代理验证,适用于非安全场景(如日志记录)。
|
||||
// 对于安全场景(如访问控制),应使用特定模块的安全实现。
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
//
|
||||
// 返回值:
|
||||
// - string: 客户端 IP 地址字符串
|
||||
func ExtractClientIP(ctx *fasthttp.RequestCtx) string {
|
||||
// 首先检查 X-Forwarded-For 请求头
|
||||
if xff := ctx.Request.Header.Peek("X-Forwarded-For"); len(xff) > 0 {
|
||||
ips := strings.Split(string(xff), ",")
|
||||
if len(ips) > 0 {
|
||||
return strings.TrimSpace(ips[0])
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 X-Real-IP 请求头
|
||||
if xri := ctx.Request.Header.Peek("X-Real-IP"); len(xri) > 0 {
|
||||
return string(xri)
|
||||
}
|
||||
|
||||
// 回退到 RemoteAddr
|
||||
if addr := ctx.RemoteAddr(); addr != nil {
|
||||
if tcpAddr, ok := addr.(*net.TCPAddr); ok {
|
||||
return tcpAddr.IP.String()
|
||||
}
|
||||
return addr.String()
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// ExtractClientIPNet 从请求上下文中提取客户端 IP 地址(返回 net.IP)。
|
||||
//
|
||||
// 该函数与 ExtractClientIP 功能相同,但返回 net.IP 类型,
|
||||
// 便于后续进行 IP 网络操作(如 CIDR 匹配)。
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
//
|
||||
// 返回值:
|
||||
// - net.IP: 客户端 IP 地址,无法解析时返回 nil
|
||||
func ExtractClientIPNet(ctx *fasthttp.RequestCtx) net.IP {
|
||||
// 首先检查 X-Forwarded-For 请求头
|
||||
if xff := ctx.Request.Header.Peek("X-Forwarded-For"); len(xff) > 0 {
|
||||
ips := strings.Split(string(xff), ",")
|
||||
if len(ips) > 0 {
|
||||
ipStr := strings.TrimSpace(ips[0])
|
||||
if ip := net.ParseIP(ipStr); ip != nil {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 X-Real-IP 请求头
|
||||
if xri := ctx.Request.Header.Peek("X-Real-IP"); len(xri) > 0 {
|
||||
if ip := net.ParseIP(string(xri)); ip != nil {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
|
||||
// 回退到 RemoteAddr
|
||||
if addr := ctx.RemoteAddr(); addr != nil {
|
||||
if tcpAddr, ok := addr.(*net.TCPAddr); ok {
|
||||
return tcpAddr.IP
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRemoteAddrIP 从 RemoteAddr 提取 IP 地址。
|
||||
//
|
||||
// 这是一个辅助函数,直接从连接的远程地址获取 IP,
|
||||
// 不检查任何代理头。
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
//
|
||||
// 返回值:
|
||||
// - net.IP: 客户端 IP 地址,无法获取时返回 nil
|
||||
func GetRemoteAddrIP(ctx *fasthttp.RequestCtx) net.IP {
|
||||
if addr := ctx.RemoteAddr(); addr != nil {
|
||||
if tcpAddr, ok := addr.(*net.TCPAddr); ok {
|
||||
return tcpAddr.IP
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
123
internal/netutil/ip_test.go
Normal file
123
internal/netutil/ip_test.go
Normal file
@ -0,0 +1,123 @@
|
||||
package netutil
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func TestExtractClientIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
xff string
|
||||
xri string
|
||||
remoteAddr string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "X-Forwarded-For with single IP",
|
||||
xff: "192.168.1.100",
|
||||
want: "192.168.1.100",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-For with multiple IPs",
|
||||
xff: "192.168.1.100, 10.0.0.1, 172.16.0.1",
|
||||
want: "192.168.1.100",
|
||||
},
|
||||
{
|
||||
name: "X-Real-IP only",
|
||||
xri: "192.168.1.200",
|
||||
want: "192.168.1.200",
|
||||
},
|
||||
{
|
||||
name: "RemoteAddr fallback",
|
||||
remoteAddr: "192.168.1.1:12345",
|
||||
want: "0.0.0.0", // fasthttp 默认初始化为 0.0.0.0
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-For takes precedence over X-Real-IP",
|
||||
xff: "192.168.1.100",
|
||||
xri: "192.168.1.200",
|
||||
want: "192.168.1.100",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
if tt.xff != "" {
|
||||
ctx.Request.Header.Set("X-Forwarded-For", tt.xff)
|
||||
}
|
||||
if tt.xri != "" {
|
||||
ctx.Request.Header.Set("X-Real-IP", tt.xri)
|
||||
}
|
||||
|
||||
got := ExtractClientIP(ctx)
|
||||
if got != tt.want {
|
||||
t.Errorf("ExtractClientIP() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractClientIPNet(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
xff string
|
||||
xri string
|
||||
want net.IP
|
||||
}{
|
||||
{
|
||||
name: "X-Forwarded-For valid IP",
|
||||
xff: "192.168.1.100",
|
||||
want: net.ParseIP("192.168.1.100"),
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-For invalid IP",
|
||||
xff: "invalid-ip",
|
||||
want: net.ParseIP("0.0.0.0"), // fasthttp 默认 RemoteAddr
|
||||
},
|
||||
{
|
||||
name: "X-Real-IP valid IP",
|
||||
xri: "192.168.1.200",
|
||||
want: net.ParseIP("192.168.1.200"),
|
||||
},
|
||||
{
|
||||
name: "No headers",
|
||||
want: net.ParseIP("0.0.0.0"), // fasthttp 默认 RemoteAddr
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
if tt.xff != "" {
|
||||
ctx.Request.Header.Set("X-Forwarded-For", tt.xff)
|
||||
}
|
||||
if tt.xri != "" {
|
||||
ctx.Request.Header.Set("X-Real-IP", tt.xri)
|
||||
}
|
||||
|
||||
got := ExtractClientIPNet(ctx)
|
||||
if !got.Equal(tt.want) {
|
||||
t.Errorf("ExtractClientIPNet() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRemoteAddrIP(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
// Without setting remote addr, should return nil
|
||||
got := GetRemoteAddrIP(ctx)
|
||||
// The result depends on how fasthttp initializes the remote addr
|
||||
// Just verify it doesn't panic
|
||||
_ = got
|
||||
}
|
||||
76
internal/netutil/url.go
Normal file
76
internal/netutil/url.go
Normal file
@ -0,0 +1,76 @@
|
||||
// Package netutil 提供网络相关的通用工具函数。
|
||||
//
|
||||
// 该包包含 URL 解析、客户端 IP 提取等网络操作的工具函数,
|
||||
// 供 proxy、middleware、server 等模块共享使用。
|
||||
//
|
||||
// 作者:xfy
|
||||
package netutil
|
||||
|
||||
import "strings"
|
||||
|
||||
// ParseTargetURL 解析目标 URL,提取主机地址和 TLS 标志。
|
||||
//
|
||||
// 该函数用于统一处理代理模块中的 URL 解析逻辑,支持 http:// 和 https:// 前缀。
|
||||
//
|
||||
// 参数:
|
||||
// - targetURL: 目标 URL 字符串(如 "http://backend:8080/path" 或 "https://api.example.com")
|
||||
// - addDefaultPort: 是否在没有端口时添加默认端口(:80 或 :443)
|
||||
//
|
||||
// 返回值:
|
||||
// - addr: 主机地址(格式 host:port)
|
||||
// - isTLS: 是否使用 TLS(HTTPS)
|
||||
//
|
||||
// 示例:
|
||||
//
|
||||
// addr, isTLS := ParseTargetURL("https://api.example.com/api", true)
|
||||
// // addr = "api.example.com:443", isTLS = true
|
||||
//
|
||||
// addr, isTLS := ParseTargetURL("http://backend:8080", false)
|
||||
// // addr = "backend:8080", isTLS = false
|
||||
func ParseTargetURL(targetURL string, addDefaultPort bool) (addr string, isTLS bool) {
|
||||
addr = targetURL
|
||||
|
||||
// 处理协议前缀
|
||||
if strings.HasPrefix(targetURL, "http://") {
|
||||
addr = targetURL[7:]
|
||||
} else if strings.HasPrefix(targetURL, "https://") {
|
||||
addr = targetURL[8:]
|
||||
isTLS = true
|
||||
}
|
||||
|
||||
// 移除路径部分,只保留 host:port
|
||||
if idx := strings.Index(addr, "/"); idx != -1 {
|
||||
addr = addr[:idx]
|
||||
}
|
||||
|
||||
// 如果需要,添加默认端口
|
||||
if addDefaultPort && !strings.Contains(addr, ":") {
|
||||
if isTLS {
|
||||
addr = addr + ":443"
|
||||
} else {
|
||||
addr = addr + ":80"
|
||||
}
|
||||
}
|
||||
|
||||
return addr, isTLS
|
||||
}
|
||||
|
||||
// ExtractHost 从 URL 提取主机地址(host:port)。
|
||||
//
|
||||
// 该函数是 ParseTargetURL 的简化版本,始终添加默认端口,
|
||||
// 用于需要完整地址但不需要 TLS 标志的场景。
|
||||
//
|
||||
// 参数:
|
||||
// - targetURL: 目标 URL 字符串
|
||||
//
|
||||
// 返回值:
|
||||
// - string: 主机地址(格式 host:port)
|
||||
//
|
||||
// 示例:
|
||||
//
|
||||
// host := ExtractHost("https://api.example.com/api")
|
||||
// // host = "api.example.com:443"
|
||||
func ExtractHost(targetURL string) string {
|
||||
addr, _ := ParseTargetURL(targetURL, true)
|
||||
return addr
|
||||
}
|
||||
147
internal/netutil/url_test.go
Normal file
147
internal/netutil/url_test.go
Normal file
@ -0,0 +1,147 @@
|
||||
package netutil
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseTargetURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
targetURL string
|
||||
addDefaultPort bool
|
||||
wantAddr string
|
||||
wantIsTLS bool
|
||||
}{
|
||||
// HTTP without port
|
||||
{
|
||||
name: "http without port, add default",
|
||||
targetURL: "http://backend.example.com",
|
||||
addDefaultPort: true,
|
||||
wantAddr: "backend.example.com:80",
|
||||
wantIsTLS: false,
|
||||
},
|
||||
{
|
||||
name: "http without port, no default",
|
||||
targetURL: "http://backend.example.com",
|
||||
addDefaultPort: false,
|
||||
wantAddr: "backend.example.com",
|
||||
wantIsTLS: false,
|
||||
},
|
||||
// HTTPS without port
|
||||
{
|
||||
name: "https without port, add default",
|
||||
targetURL: "https://api.example.com",
|
||||
addDefaultPort: true,
|
||||
wantAddr: "api.example.com:443",
|
||||
wantIsTLS: true,
|
||||
},
|
||||
{
|
||||
name: "https without port, no default",
|
||||
targetURL: "https://api.example.com",
|
||||
addDefaultPort: false,
|
||||
wantAddr: "api.example.com",
|
||||
wantIsTLS: true,
|
||||
},
|
||||
// HTTP with port
|
||||
{
|
||||
name: "http with port",
|
||||
targetURL: "http://backend:8080",
|
||||
addDefaultPort: true,
|
||||
wantAddr: "backend:8080",
|
||||
wantIsTLS: false,
|
||||
},
|
||||
// HTTPS with port
|
||||
{
|
||||
name: "https with port",
|
||||
targetURL: "https://api:8443",
|
||||
addDefaultPort: true,
|
||||
wantAddr: "api:8443",
|
||||
wantIsTLS: true,
|
||||
},
|
||||
// With path
|
||||
{
|
||||
name: "http with path",
|
||||
targetURL: "http://backend:8080/api/v1",
|
||||
addDefaultPort: false,
|
||||
wantAddr: "backend:8080",
|
||||
wantIsTLS: false,
|
||||
},
|
||||
{
|
||||
name: "https with path",
|
||||
targetURL: "https://api.example.com/v1/users",
|
||||
addDefaultPort: true,
|
||||
wantAddr: "api.example.com:443",
|
||||
wantIsTLS: true,
|
||||
},
|
||||
// No protocol (treat as HTTP)
|
||||
{
|
||||
name: "no protocol",
|
||||
targetURL: "backend:8080",
|
||||
addDefaultPort: false,
|
||||
wantAddr: "backend:8080",
|
||||
wantIsTLS: false,
|
||||
},
|
||||
{
|
||||
name: "no protocol, no port, add default",
|
||||
targetURL: "backend",
|
||||
addDefaultPort: true,
|
||||
wantAddr: "backend:80",
|
||||
wantIsTLS: false,
|
||||
},
|
||||
// IPv6 address
|
||||
{
|
||||
name: "ipv6 address",
|
||||
targetURL: "http://[::1]:8080",
|
||||
addDefaultPort: false,
|
||||
wantAddr: "[::1]:8080",
|
||||
wantIsTLS: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotAddr, gotIsTLS := ParseTargetURL(tt.targetURL, tt.addDefaultPort)
|
||||
if gotAddr != tt.wantAddr {
|
||||
t.Errorf("ParseTargetURL() addr = %q, want %q", gotAddr, tt.wantAddr)
|
||||
}
|
||||
if gotIsTLS != tt.wantIsTLS {
|
||||
t.Errorf("ParseTargetURL() isTLS = %v, want %v", gotIsTLS, tt.wantIsTLS)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractHost(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
targetURL string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "http without port",
|
||||
targetURL: "http://backend.example.com",
|
||||
want: "backend.example.com:80",
|
||||
},
|
||||
{
|
||||
name: "https without port",
|
||||
targetURL: "https://api.example.com",
|
||||
want: "api.example.com:443",
|
||||
},
|
||||
{
|
||||
name: "http with port",
|
||||
targetURL: "http://backend:8080",
|
||||
want: "backend:8080",
|
||||
},
|
||||
{
|
||||
name: "https with path",
|
||||
targetURL: "https://api.example.com/v1/users",
|
||||
want: "api.example.com:443",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := ExtractHost(tt.targetURL); got != tt.want {
|
||||
t.Errorf("ExtractHost() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -34,7 +34,6 @@ package proxy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@ -44,6 +43,7 @@ import (
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/loadbalance"
|
||||
"rua.plus/lolly/internal/logging"
|
||||
"rua.plus/lolly/internal/netutil"
|
||||
)
|
||||
|
||||
// Proxy 表示反向代理实例,负责将 HTTP 请求转发到后端目标。
|
||||
@ -147,20 +147,7 @@ func createBalancer(cfg *config.ProxyConfig) (loadbalance.Balancer, error) {
|
||||
// createHostClient 为后台目标 URL 创建 fasthttp.HostClient。
|
||||
func createHostClient(targetURL string, timeout config.ProxyTimeout, transportCfg *config.TransportConfig) *fasthttp.HostClient {
|
||||
// 从目标 URL 解析主机和协议
|
||||
addr := targetURL
|
||||
isTLS := false
|
||||
|
||||
if strings.HasPrefix(targetURL, "http://") {
|
||||
addr = targetURL[7:]
|
||||
} else if strings.HasPrefix(targetURL, "https://") {
|
||||
addr = targetURL[8:]
|
||||
isTLS = true
|
||||
}
|
||||
|
||||
// 如果存在路径则移除,只保留 host:port
|
||||
if idx := strings.Index(addr, "/"); idx != -1 {
|
||||
addr = addr[:idx]
|
||||
}
|
||||
addr, isTLS := netutil.ParseTargetURL(targetURL, false)
|
||||
|
||||
// 默认值
|
||||
maxIdleConnDuration := 90 * time.Second
|
||||
@ -321,7 +308,7 @@ func (p *Proxy) selectTarget(ctx *fasthttp.RequestCtx) *loadbalance.Target {
|
||||
|
||||
// 对于 IPHash 负载均衡器,提取客户端 IP
|
||||
if ipHash, ok := balancer.(*loadbalance.IPHash); ok {
|
||||
clientIP := getClientIP(ctx)
|
||||
clientIP := netutil.ExtractClientIP(ctx)
|
||||
return ipHash.SelectByIP(targets, clientIP)
|
||||
}
|
||||
|
||||
@ -339,7 +326,7 @@ func (p *Proxy) selectTarget(ctx *fasthttp.RequestCtx) *loadbalance.Target {
|
||||
func (p *Proxy) extractHashKey(ctx *fasthttp.RequestCtx, hashKey string) string {
|
||||
switch {
|
||||
case hashKey == "ip" || hashKey == "":
|
||||
return getClientIP(ctx)
|
||||
return netutil.ExtractClientIP(ctx)
|
||||
case hashKey == "uri":
|
||||
return string(ctx.RequestURI())
|
||||
case strings.HasPrefix(hashKey, "header:"):
|
||||
@ -348,38 +335,12 @@ func (p *Proxy) extractHashKey(ctx *fasthttp.RequestCtx, hashKey string) string
|
||||
if len(value) > 0 {
|
||||
return string(value)
|
||||
}
|
||||
return getClientIP(ctx) // fallback to IP
|
||||
return netutil.ExtractClientIP(ctx) // fallback to IP
|
||||
default:
|
||||
return getClientIP(ctx)
|
||||
return netutil.ExtractClientIP(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// getClientIP 从请求上下文中提取客户端 IP 地址。
|
||||
func getClientIP(ctx *fasthttp.RequestCtx) string {
|
||||
// 首先检查 X-Forwarded-For 请求头
|
||||
if xff := ctx.Request.Header.Peek("X-Forwarded-For"); len(xff) > 0 {
|
||||
ips := strings.Split(string(xff), ",")
|
||||
if len(ips) > 0 {
|
||||
return strings.TrimSpace(ips[0])
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 X-Real-IP 请求头
|
||||
if xri := ctx.Request.Header.Peek("X-Real-IP"); len(xri) > 0 {
|
||||
return string(xri)
|
||||
}
|
||||
|
||||
// 回退到 RemoteAddr
|
||||
if addr := ctx.RemoteAddr(); addr != nil {
|
||||
if tcpAddr, ok := addr.(*net.TCPAddr); ok {
|
||||
return tcpAddr.IP.String()
|
||||
}
|
||||
return addr.String()
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// getClient 返回给定目标 URL 对应的 HostClient。
|
||||
func (p *Proxy) getClient(targetURL string) *fasthttp.HostClient {
|
||||
p.mu.RLock()
|
||||
@ -394,7 +355,7 @@ func (p *Proxy) modifyRequestHeaders(ctx *fasthttp.RequestCtx, target *loadbalan
|
||||
headers := &ctx.Request.Header
|
||||
|
||||
// 添加 X-Real-IP 请求头
|
||||
clientIP := getClientIP(ctx)
|
||||
clientIP := netutil.ExtractClientIP(ctx)
|
||||
if clientIP != "" {
|
||||
headers.Set("X-Real-IP", clientIP)
|
||||
}
|
||||
|
||||
@ -12,6 +12,7 @@ import (
|
||||
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/loadbalance"
|
||||
"rua.plus/lolly/internal/netutil"
|
||||
)
|
||||
|
||||
// TestNewProxy 测试 NewProxy 函数
|
||||
@ -587,9 +588,9 @@ func TestGetClientIP(t *testing.T) {
|
||||
ctx.Request.Header.Set("X-Real-IP", tt.xri)
|
||||
}
|
||||
|
||||
ip := getClientIP(ctx)
|
||||
ip := netutil.ExtractClientIP(ctx)
|
||||
if ip != tt.expected {
|
||||
t.Errorf("getClientIP() = %q, want %q", ip, tt.expected)
|
||||
t.Errorf("ExtractClientIP() = %q, want %q", ip, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@ -33,6 +33,7 @@ import (
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/loadbalance"
|
||||
"rua.plus/lolly/internal/netutil"
|
||||
)
|
||||
|
||||
// WebSocketBridge WebSocket 桥接器。
|
||||
@ -207,30 +208,7 @@ func isConnectionClosedError(err error) bool {
|
||||
// - error: 连接失败时返回错误
|
||||
func dialTarget(targetURL string, timeout time.Duration) (net.Conn, error) {
|
||||
// 解析目标 URL
|
||||
isTLS := false
|
||||
addr := targetURL
|
||||
|
||||
// 处理协议前缀
|
||||
if strings.HasPrefix(targetURL, "http://") {
|
||||
addr = targetURL[7:]
|
||||
} else if strings.HasPrefix(targetURL, "https://") {
|
||||
addr = targetURL[8:]
|
||||
isTLS = true
|
||||
}
|
||||
|
||||
// 移除路径部分,只保留 host:port
|
||||
if idx := strings.Index(addr, "/"); idx != -1 {
|
||||
addr = addr[:idx]
|
||||
}
|
||||
|
||||
// 如果没有端口,添加默认端口
|
||||
if !strings.Contains(addr, ":") {
|
||||
if isTLS {
|
||||
addr = addr + ":443"
|
||||
} else {
|
||||
addr = addr + ":80"
|
||||
}
|
||||
}
|
||||
addr, isTLS := netutil.ParseTargetURL(targetURL, true)
|
||||
|
||||
// 建立 TCP 连接
|
||||
dialer := &net.Dialer{
|
||||
@ -309,7 +287,7 @@ func buildWebSocketUpgradeRequest(ctx *fasthttp.RequestCtx, targetHost string) s
|
||||
}
|
||||
|
||||
// 添加 X-Forwarded 头
|
||||
clientIP := getClientIP(ctx)
|
||||
clientIP := netutil.ExtractClientIP(ctx)
|
||||
if clientIP != "" {
|
||||
fmt.Fprintf(&req, "X-Forwarded-For: %s\r\n", clientIP)
|
||||
fmt.Fprintf(&req, "X-Real-IP: %s\r\n", clientIP)
|
||||
@ -394,49 +372,37 @@ func ProxyWebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, timeou
|
||||
return fmt.Errorf("failed to connect to backend: %w", err)
|
||||
}
|
||||
|
||||
// 创建桥接器管理两个连接
|
||||
bridge := NewWebSocketBridge(clientConn, targetConn)
|
||||
defer func() { _ = bridge.Close() }()
|
||||
|
||||
// 步骤2: 从目标 URL 提取主机地址
|
||||
targetHost := extractHost(target.URL)
|
||||
|
||||
// 步骤3: 构建并发送 WebSocket 升级请求
|
||||
upgradeReq := buildWebSocketUpgradeRequest(ctx, targetHost)
|
||||
if _, err := targetConn.Write([]byte(upgradeReq)); err != nil {
|
||||
_ = clientConn.Close()
|
||||
_ = targetConn.Close()
|
||||
return fmt.Errorf("failed to send upgrade request: %w", err)
|
||||
}
|
||||
|
||||
// 步骤4: 读取升级响应
|
||||
resp, err := readWebSocketUpgradeResponse(targetConn, timeout)
|
||||
if err != nil {
|
||||
_ = clientConn.Close()
|
||||
_ = targetConn.Close()
|
||||
return fmt.Errorf("failed to read upgrade response: %w", err)
|
||||
}
|
||||
|
||||
// 步骤5: 检查响应状态码(期望 101 Switching Protocols)
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
_ = clientConn.Close()
|
||||
_ = targetConn.Close()
|
||||
return fmt.Errorf("backend rejected WebSocket upgrade: %s", resp.Status)
|
||||
}
|
||||
|
||||
// 步骤6: 将升级响应发送回客户端
|
||||
if err := writeUpgradeResponse(clientConn, resp); err != nil {
|
||||
_ = clientConn.Close()
|
||||
_ = targetConn.Close()
|
||||
return fmt.Errorf("failed to send upgrade response to client: %w", err)
|
||||
}
|
||||
|
||||
// 步骤7: 创建桥接器并启动双向转发
|
||||
bridge := NewWebSocketBridge(clientConn, targetConn)
|
||||
|
||||
// 启动桥接(阻塞直到连接关闭)
|
||||
bridgeErr := bridge.Bridge()
|
||||
|
||||
// 清理:关闭连接
|
||||
_ = bridge.Close()
|
||||
|
||||
return bridgeErr
|
||||
// 步骤7: 启动桥接(阻塞直到连接关闭)
|
||||
return bridge.Bridge()
|
||||
}
|
||||
|
||||
// extractHost 从 URL 中提取主机地址(带端口)。
|
||||
@ -449,28 +415,7 @@ func ProxyWebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, timeou
|
||||
// 返回值:
|
||||
// - string: 主机地址(格式 host:port)
|
||||
func extractHost(url string) string {
|
||||
addr := url
|
||||
if strings.HasPrefix(url, "http://") {
|
||||
addr = url[7:]
|
||||
} else if strings.HasPrefix(url, "https://") {
|
||||
addr = url[8:]
|
||||
}
|
||||
|
||||
// 移除路径部分
|
||||
if idx := strings.Index(addr, "/"); idx != -1 {
|
||||
addr = addr[:idx]
|
||||
}
|
||||
|
||||
// 如果没有端口,添加默认端口
|
||||
if !strings.Contains(addr, ":") {
|
||||
if strings.HasPrefix(url, "https://") {
|
||||
addr = addr + ":443"
|
||||
} else {
|
||||
addr = addr + ":80"
|
||||
}
|
||||
}
|
||||
|
||||
return addr
|
||||
return netutil.ExtractHost(url)
|
||||
}
|
||||
|
||||
// writeUpgradeResponse 将 HTTP 升级响应写回客户端。
|
||||
|
||||
@ -340,23 +340,8 @@ func (s *Server) startSingleMode() error {
|
||||
// 注册代理路由
|
||||
s.registerProxyRoutes(router, &s.config.Server)
|
||||
|
||||
// 静态文件服务(作为 fallback)
|
||||
// 启用零拷贝传输优化(大文件使用 sendfile)
|
||||
staticHandler := handler.NewStaticHandler(
|
||||
s.config.Server.Static.Root,
|
||||
s.config.Server.Static.Index,
|
||||
true, // useSendfile
|
||||
)
|
||||
// 设置文件缓存
|
||||
if s.fileCache != nil {
|
||||
staticHandler.SetFileCache(s.fileCache)
|
||||
}
|
||||
// 设置预压缩文件支持
|
||||
if s.config.Server.Compression.GzipStatic {
|
||||
staticHandler.SetGzipStatic(true, s.config.Server.Compression.GzipStaticExtensions)
|
||||
}
|
||||
router.GET("/{filepath:*}", staticHandler.Handle)
|
||||
router.HEAD("/{filepath:*}", staticHandler.Handle)
|
||||
// 静态文件服务
|
||||
s.registerStaticHandler(router, &s.config.Server)
|
||||
|
||||
// 构建中间件链
|
||||
chain, err := s.buildMiddlewareChain(&s.config.Server)
|
||||
@ -692,3 +677,20 @@ func (s *Server) getProxyCacheStats() ProxyCacheStats {
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// registerStaticHandler registers static file handler.
|
||||
func (s *Server) registerStaticHandler(router *handler.Router, cfg *config.ServerConfig) {
|
||||
staticHandler := handler.NewStaticHandler(
|
||||
cfg.Static.Root,
|
||||
cfg.Static.Index,
|
||||
true, // useSendfile
|
||||
)
|
||||
if s.fileCache != nil {
|
||||
staticHandler.SetFileCache(s.fileCache)
|
||||
}
|
||||
if cfg.Compression.GzipStatic {
|
||||
staticHandler.SetGzipStatic(true, cfg.Compression.GzipStaticExtensions)
|
||||
}
|
||||
router.GET("/{filepath:*}", staticHandler.Handle)
|
||||
router.HEAD("/{filepath:*}", staticHandler.Handle)
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user