lolly/internal/proxy/websocket.go
xfy 27e00b84a8 fix(proxy,handler,server,stream,ratelimit): fix resource leaks and functional bugs
- proxy/proxy.go: decrement connection count on dangerous path rejection
  (line 724) to prevent connection count leak
- handler/sendfile_linux.go: return *os.File from getSocketFile and let
  linuxSendfile close it, fixing EBADF from deferred close in getSocketFd
- proxy/websocket.go: return bufio.Reader from readWebSocketUpgradeResponse
  and wrap targetConn with bufferedConn to consume pre-buffered frame data,
  preventing first-frame loss
- server/pool.go: use non-blocking send after starting new worker to avoid
  deadlock when queue is full
- stream/stream.go: check stopCh on non-timeout UDP read errors to prevent
  infinite loop and shutdown deadlock
- middleware/ratelimit: replace select-based close guard with sync.Once in
  StopCleanup to prevent double-close panic
2026-06-11 16:35:10 +08:00

498 lines
13 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package proxy 提供反向代理功能,支持 HTTP、WebSocket 和流式代理。
//
// 该文件实现了 WebSocket 代理桥接器,用于在客户端和后端服务器之间
// 建立 WebSocket 连接并进行双向数据转发。
//
// 主要功能:
// - WebSocket 连接升级:处理 HTTP 到 WebSocket 的协议升级
// - 双向数据转发:在客户端和后端之间透明转发数据帧
// - TLS 支持:支持 ws:// 和 wss:// 协议
// - 超时控制:可配置的连接和读写超时
//
// 使用示例:
//
// err := proxy.WebSocket(ctx, target, 30*time.Second)
// if err != nil {
// log.Printf("WebSocket proxy error: %v", err)
// }
//
// 作者xfy
package proxy
import (
"bufio"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"time"
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/loadbalance"
"rua.plus/lolly/internal/netutil"
)
// wsBufPool WebSocket 数据转发 buffer pool。
// 复用 32KB buffer 避免每次 copyData 调用分配。
var wsBufPool = sync.Pool{
New: func() any {
buf := make([]byte, 32*1024)
return &buf
},
}
// WebSocketBridge WebSocket 桥接器。
//
// 在客户端和后端服务器之间建立双向数据通道,透明转发 WebSocket 数据帧。
// 支持并发读写,使用互斥锁保护关闭状态。
//
// 注意事项:
// - 调用 Bridge() 会阻塞直到连接关闭
// - 使用完毕后应调用 Close() 释放资源
type WebSocketBridge struct {
clientConn net.Conn // 客户端 TCP 连接
targetConn net.Conn // 后端目标 TCP 连接
mu sync.Mutex // 保护 closed 字段的互斥锁
closed bool // 连接关闭标志
}
// NewWebSocketBridge 创建新的 WebSocket 桥接器。
//
// 参数:
// - clientConn: 客户端网络连接
// - targetConn: 后端目标网络连接
//
// 返回值:
// - *WebSocketBridge: 初始化的桥接器实例
func NewWebSocketBridge(clientConn, targetConn net.Conn) *WebSocketBridge {
return &WebSocketBridge{
clientConn: clientConn,
targetConn: targetConn,
closed: false,
}
}
// Bridge 启动双向数据转发。
//
// 创建两个 goroutine 分别处理客户端到后端和后端到客户端的数据流,
// 阻塞直到两个方向的转发都完成。
//
// 返回值:
// - error: 转发过程中的错误(连接正常关闭返回 nil
func (b *WebSocketBridge) Bridge() error {
var wg sync.WaitGroup
wg.Add(2)
var copyErr1, copyErr2 error
// 客户端 -> 后端方向
go func() {
defer wg.Done()
copyErr1 = b.copyData(b.clientConn, b.targetConn, "client->target")
}()
// 后端 -> 客户端方向
go func() {
defer wg.Done()
copyErr2 = b.copyData(b.targetConn, b.clientConn, "target->client")
}()
// 等待双向转发完成
wg.Wait()
// 返回第一个非 nil 的错误(忽略连接关闭错误)
if copyErr1 != nil && !isConnectionClosedError(copyErr1) {
return copyErr1
}
if copyErr2 != nil && !isConnectionClosedError(copyErr2) {
return copyErr2
}
return nil
}
// copyData 在两个连接之间复制数据。
//
// 使用 32KB 缓冲区进行数据拷贝,遇到连接关闭错误时返回 nil。
//
// 参数:
// - dst: 目标连接(写入端)
// - src: 源连接(读取端)
// - direction: 方向描述,用于错误信息
//
// 返回值:
// - error: 读写错误,连接正常关闭返回 nil
func (b *WebSocketBridge) copyData(dst, src net.Conn, direction string) error {
bufPtr := wsBufPool.Get().(*[]byte) //nolint:errcheck // pool always returns valid *[]byte
buf := *bufPtr
defer wsBufPool.Put(bufPtr)
for {
n, err := src.Read(buf)
if err != nil {
if isConnectionClosedError(err) {
return nil
}
return fmt.Errorf("read error (%s): %w", direction, err)
}
if n > 0 {
_, err = dst.Write(buf[:n])
if err != nil {
if isConnectionClosedError(err) {
return nil
}
return fmt.Errorf("write error (%s): %w", direction, err)
}
}
}
}
// Close 关闭桥接器的两个连接。
//
// 关闭客户端和后端连接,使用互斥锁确保只关闭一次。
//
// 返回值:
// - error: 关闭过程中的错误
func (b *WebSocketBridge) Close() error {
b.mu.Lock()
if b.closed {
b.mu.Unlock()
return nil
}
b.closed = true
b.mu.Unlock()
var err1, err2 error
if b.clientConn != nil {
err1 = b.clientConn.Close()
}
if b.targetConn != nil {
err2 = b.targetConn.Close()
}
if err1 != nil {
return err1
}
return err2
}
// isConnectionClosedError 检查错误是否表示连接已关闭。
//
// 判断 EOF、网络超时和使用已关闭连接等正常关闭情况。
//
// 参数:
// - err: 待检查的错误
//
// 返回值:
// - bool: true 表示是连接关闭错误
func isConnectionClosedError(err error) bool {
if err == nil {
return false
}
if err == io.EOF {
return true
}
if netErr, ok := err.(net.Error); ok {
// 检查是否为 "use of closed network connection" 错误
if strings.Contains(err.Error(), "use of closed network connection") {
return true
}
return netErr.Timeout()
}
return false
}
// dialTarget 建立到后端目标的 TCP 连接。
//
// 解析目标 URL支持 HTTP 和 HTTPS 协议,自动添加默认端口。
//
// 参数:
// - targetURL: 目标 URL如 http://example.com 或 https://example.com:8443
// - timeout: 连接超时时间
//
// 返回值:
// - net.Conn: 建立的连接TLS 连接或普通 TCP 连接)
// - error: 连接失败时返回错误
func dialTarget(targetURL string, timeout time.Duration) (net.Conn, error) {
// 解析目标 URL
addr, isTLS := netutil.ParseTargetURL(targetURL, true)
// 建立 TCP 连接
dialer := &net.Dialer{
Timeout: timeout,
}
conn, err := dialer.Dial("tcp", addr)
if err != nil {
return nil, fmt.Errorf("failed to connect to target: %w", err)
}
// 如果是 HTTPS建立 TLS 连接
if isTLS {
tlsConn := tls.Client(conn, &tls.Config{
InsecureSkipVerify: false,
ServerName: strings.Split(addr, ":")[0],
})
if err := tlsConn.SetDeadline(time.Now().Add(timeout)); err != nil {
_ = conn.Close()
return nil, fmt.Errorf("failed to set TLS deadline: %w", err)
}
if err := tlsConn.Handshake(); err != nil {
_ = conn.Close()
return nil, fmt.Errorf("TLS handshake failed: %w", err)
}
return tlsConn, nil
}
return conn, nil
}
// buildWebSocketUpgradeRequest 构建 WebSocket 升级 HTTP 请求。
//
// 根据客户端请求构建发往后端的 WebSocket 升级请求,
// 复制必要的请求头并添加 X-Forwarded 系列代理头。
//
// 参数:
// - ctx: FastHTTP 请求上下文
// - targetHost: 目标主机地址
// - headersConfig: 代理头配置,控制 X-Forwarded-Host/Proto 的设置
//
// 返回值:
// - string: 完整的 HTTP 请求字符串
func buildWebSocketUpgradeRequest(ctx *fasthttp.RequestCtx, targetHost string, headersConfig *config.ProxyHeaders) string {
// 构建请求行
path := string(ctx.Path())
if path == "" {
path = "/"
}
// 添加查询参数
query := string(ctx.QueryArgs().QueryString())
if query != "" {
path = path + "?" + query
}
// 构建请求头
var req strings.Builder
fmt.Fprintf(&req, "GET %s HTTP/1.1\r\n", path)
fmt.Fprintf(&req, "Host: %s\r\n", targetHost)
// 复制原始请求的关键头
copyHeaders := []string{
"Upgrade",
"Connection",
"Sec-WebSocket-Key",
"Sec-WebSocket-Version",
"Sec-WebSocket-Protocol",
"Sec-WebSocket-Extensions",
"Origin",
}
for _, header := range copyHeaders {
if value := ctx.Request.Header.Peek(header); len(value) > 0 {
fmt.Fprintf(&req, "%s: %s\r\n", header, string(value))
}
}
// 添加 X-Forwarded 头
fh := ExtractForwardedHeaders(ctx)
// 根据配置决定是否设置 X-Forwarded-Host 和 X-Forwarded-Proto
setHost := true // 默认值(向后兼容)
if headersConfig != nil && headersConfig.SetForwardedHost != nil {
setHost = *headersConfig.SetForwardedHost
}
setProto := true // 默认值(向后兼容)
if headersConfig != nil && headersConfig.SetForwardedProto != nil {
setProto = *headersConfig.SetForwardedProto
}
WriteForwardedHeaders(&req, fh, setHost, setProto)
// 结束请求头
req.WriteString("\r\n")
return req.String()
}
// readWebSocketUpgradeResponse 读取 WebSocket 升级响应。
//
// 从后端连接读取 HTTP 响应,解析响应头和状态码。
//
// 参数:
// - conn: 后端网络连接
// - timeout: 读取超时时间
//
// 返回值:
// - *http.Response: HTTP 响应对象
// - error: 读取失败时返回错误
func readWebSocketUpgradeResponse(conn net.Conn, timeout time.Duration) (*http.Response, *bufio.Reader, error) {
// 设置读取超时
if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
return nil, nil, err
}
// 使用 bufio.Reader 读取 HTTP 响应
reader := bufio.NewReader(conn)
resp, err := http.ReadResponse(reader, nil)
if err != nil {
return nil, nil, fmt.Errorf("failed to read upgrade response: %w", err)
}
return resp, reader, nil
}
// WebSocket 处理 WebSocket 代理请求。
//
// 完整流程:
// 1. 劫持客户端连接
// 2. 建立到后端的 TCP/TLS 连接
// 3. 发送 WebSocket 升级请求
// 4. 验证后端升级响应
// 5. 启动双向数据转发
//
// 参数:
// - ctx: FastHTTP 请求上下文
// - target: 负载均衡目标,包含后端 URL
// - timeout: 连接和 I/O 超时时间
// - headersConfig: 代理头配置,控制 X-Forwarded-Host/Proto 的设置
//
// 返回值:
// - error: 代理过程中的错误
func WebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, timeout time.Duration, headersConfig *config.ProxyHeaders) error {
// 使用 Hijack 获取客户端 TCP 连接
var clientConn net.Conn
ctx.Hijack(func(c net.Conn) {
clientConn = c
})
if clientConn == nil {
return errors.New("failed to hijack connection")
}
// 步骤1: 建立到后端目标的连接
targetConn, err := dialTarget(target.URL, timeout)
if err != nil {
_ = clientConn.Close()
return fmt.Errorf("failed to connect to backend: %w", err)
}
// 创建桥接器管理两个连接
bridge := NewWebSocketBridge(clientConn, targetConn)
defer func() {
_ = bridge.Close()
}()
// 步骤2: 从目标 URL 提取主机地址
targetHost := extractHost(target.URL)
// 步骤3: 构建并发送 WebSocket 升级请求
upgradeReq := buildWebSocketUpgradeRequest(ctx, targetHost, headersConfig)
if _, writeErr := targetConn.Write([]byte(upgradeReq)); writeErr != nil {
return fmt.Errorf("failed to send upgrade request: %w", writeErr)
}
// 步骤4: 读取升级响应
resp, bufferedReader, err := readWebSocketUpgradeResponse(targetConn, timeout)
if err != nil {
return fmt.Errorf("failed to read upgrade response: %w", err)
}
// 步骤5: 检查响应状态码(期望 101 Switching Protocols
if resp.StatusCode != http.StatusSwitchingProtocols {
// 关闭响应 body升级失败时
_ = resp.Body.Close()
return fmt.Errorf("backend rejected WebSocket upgrade: %s", resp.Status)
}
// 步骤6: 将升级响应发送回客户端
if err := writeUpgradeResponse(clientConn, resp); err != nil {
// 关闭响应 body写入失败时
_ = resp.Body.Close()
return fmt.Errorf("failed to send upgrade response to client: %w", err)
}
// 注意: WebSocket 升级成功后resp.Body 不需要显式关闭
// 因为底层连接已被 bridge 用于双向数据传输
// 如果 bufferedReader 已经缓冲了 WebSocket frame 数据,
// 需要包装连接使后续读取先消耗缓冲区
if bufferedReader != nil && bufferedReader.Buffered() > 0 {
targetConn = &bufferedConn{
Conn: targetConn,
reader: bufferedReader,
}
bridge.targetConn = targetConn
}
// 步骤7: 启动桥接(阻塞直到连接关闭)
return bridge.Bridge()
}
// extractHost 从 URL 中提取主机地址(带端口)。
//
// 处理 http:// 和 https:// 前缀,自动添加默认端口。
//
// 参数:
// - url: 完整的 URL 字符串
//
// 返回值:
// - string: 主机地址(格式 host:port
func extractHost(url string) string {
return netutil.ExtractHost(url)
}
// writeUpgradeResponse 将 HTTP 升级响应写回客户端。
//
// 将后端返回的 101 Switching Protocols 响应转发给客户端。
//
// 参数:
// - conn: 客户端网络连接
// - resp: HTTP 响应对象
//
// 返回值:
// - error: 写入失败时返回错误
func writeUpgradeResponse(conn net.Conn, resp *http.Response) error {
// 构建响应行
var respStr strings.Builder
fmt.Fprintf(&respStr, "HTTP/%d.%d %s\r\n", resp.ProtoMajor, resp.ProtoMinor, resp.Status)
// 写入响应头
for key, values := range resp.Header {
for _, value := range values {
fmt.Fprintf(&respStr, "%s: %s\r\n", key, value)
}
}
respStr.WriteString("\r\n")
if _, err := conn.Write([]byte(respStr.String())); err != nil {
return err
}
return nil
}
// bufferedConn 包装 net.Conn优先从 bufio.Reader 的缓冲区读取数据。
//
// 用于 WebSocket 升级响应后,消耗 bufio.Reader 可能已缓冲的 frame 数据。
type bufferedConn struct {
net.Conn
reader *bufio.Reader
}
// Read 优先从内部 bufio.Reader 读取,若缓冲区为空则回退到原始连接。
func (bc *bufferedConn) Read(p []byte) (int, error) {
if bc.reader.Buffered() > 0 {
return bc.reader.Read(p)
}
return bc.Conn.Read(p)
}