diff --git a/internal/handler/sendfile_linux.go b/internal/handler/sendfile_linux.go index a4c8d57..cfb5060 100644 --- a/internal/handler/sendfile_linux.go +++ b/internal/handler/sendfile_linux.go @@ -80,10 +80,12 @@ func SendFile(ctx *fasthttp.RequestCtx, file *os.File, offset, length int64) err // 返回值: // - error: 传输过程中的错误,nil 表示成功 func linuxSendfile(conn net.Conn, fileFd uintptr, _, length int64) error { - socketFd, err := getSocketFd(conn) + socketFile, err := getSocketFile(conn) if err != nil { return err } + defer func() { _ = socketFile.Close() }() + socketFd := socketFile.Fd() // Linux sendfile: sendfile(out_fd, in_fd, offset, count) var sent int64 @@ -148,23 +150,13 @@ func linuxSendfile(conn net.Conn, fileFd uintptr, _, length int64) error { // 返回值: // - uintptr: socket 文件描述符,失败时返回 0 // - error: 获取失败时的错误,不支持的连接类型返回 ENOTSUP -func getSocketFd(conn net.Conn) (uintptr, error) { +func getSocketFile(conn net.Conn) (*os.File, error) { switch c := conn.(type) { case *net.TCPConn: - file, err := c.File() - if err != nil { - return 0, err - } - defer func() { _ = file.Close() }() - return file.Fd(), nil + return c.File() case *net.UnixConn: - file, err := c.File() - if err != nil { - return 0, err - } - defer func() { _ = file.Close() }() - return file.Fd(), nil + return c.File() default: - return 0, syscall.ENOTSUP + return nil, syscall.ENOTSUP } } diff --git a/internal/handler/sendfile_test.go b/internal/handler/sendfile_test.go index 247469f..43bb15e 100644 --- a/internal/handler/sendfile_test.go +++ b/internal/handler/sendfile_test.go @@ -114,7 +114,7 @@ func TestCopyFile(t *testing.T) { // TestGetSocketFd_NilConn 测试 nil 连接的情况 func TestGetSocketFd_NilConn(t *testing.T) { - _, err := getSocketFd(nil) + _, err := getSocketFile(nil) if err == nil { t.Error("expected error for nil connection") } @@ -123,7 +123,7 @@ func TestGetSocketFd_NilConn(t *testing.T) { // TestGetSocketFd_UnsupportedType 测试不支持的连接类型 func TestGetSocketFd_UnsupportedType(t *testing.T) { conn := &mockConn{} - _, err := getSocketFd(conn) + _, err := getSocketFile(conn) if err != syscall.ENOTSUP { t.Errorf("expected ENOTSUP for unsupported conn type, got: %v", err) } @@ -131,7 +131,7 @@ func TestGetSocketFd_UnsupportedType(t *testing.T) { // mockConn 是一个不实现 TCPConn/UnixConn 的模拟连接。 // -// 用于测试 getSocketFd 对不支持连接类型的处理。 +// 用于测试 getSocketFile 对不支持连接类型的处理。 type mockConn struct{} func (m *mockConn) Read([]byte) (n int, err error) { return 0, nil } @@ -449,11 +449,11 @@ func TestGetSocketFd_UnixConn(t *testing.T) { } defer clientConn.Close() - fd, err := getSocketFd(clientConn) + f, err := getSocketFile(clientConn) if err != nil { - t.Errorf("getSocketFd failed: %v", err) + t.Errorf("getSocketFile failed: %v", err) } - if fd == 0 { + if f == nil { t.Error("Expected non-zero fd") } @@ -616,11 +616,11 @@ func TestGetSocketFd_TCPConn(t *testing.T) { } defer clientConn.Close() - fd, err := getSocketFd(clientConn) + f, err := getSocketFile(clientConn) if err != nil { - t.Errorf("getSocketFd failed for TCPConn: %v", err) + t.Errorf("getSocketFile failed for TCPConn: %v", err) } - if fd == 0 { + if f == nil { t.Error("Expected non-zero fd for TCPConn") } diff --git a/internal/middleware/security/ratelimit.go b/internal/middleware/security/ratelimit.go index cc12ccf..ecf5440 100644 --- a/internal/middleware/security/ratelimit.go +++ b/internal/middleware/security/ratelimit.go @@ -70,6 +70,7 @@ type RateLimiter struct { cleanupTicker *time.Ticker stopCleanupCh chan struct{} cleanupDone chan struct{} + stopOnce sync.Once rate float64 burst float64 } @@ -494,18 +495,9 @@ func (rl *RateLimiter) startCleanup(interval time.Duration) { // 发送停止信号并等待 goroutine 完成,确保资源正确释放。 // 该方法应在限流器不再使用时调用(如服务器关闭时)。 func (rl *RateLimiter) StopCleanup() { - // 使用原子操作或简单的标志检查来避免竞争 - // 关闭 stopCleanupCh 会广播给所有等待的 goroutine - select { - case <-rl.stopCleanupCh: - // 已经关闭 - return - default: - } - if rl.cleanupTicker != nil { rl.cleanupTicker.Stop() - close(rl.stopCleanupCh) + rl.stopOnce.Do(func() { close(rl.stopCleanupCh) }) <-rl.cleanupDone rl.cleanupTicker = nil // 防止重复关闭 } diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index a1b35a0..852376c 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -721,6 +721,7 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) { if bytes.ContainsAny(path, "@\r\n") { logging.Warn().Msgf("rejected suspicious proxy path containing dangerous chars: %s", path) upstreamStatus = 502 + loadbalance.DecrementConnections(target) utils.SendErrorWithDetail(ctx, utils.ErrBadGateway, "invalid proxy path") return } diff --git a/internal/proxy/proxy_low_coverage_test.go b/internal/proxy/proxy_low_coverage_test.go index f2dbf51..b904d34 100644 --- a/internal/proxy/proxy_low_coverage_test.go +++ b/internal/proxy/proxy_low_coverage_test.go @@ -752,7 +752,7 @@ func TestWebSocket_UpgradeRejected(t *testing.T) { // 发送一个请求让服务端触发 _, _ = conn.Write([]byte("GET /ws HTTP/1.1\r\nHost: localhost\r\n\r\n")) - resp, err := readWebSocketUpgradeResponse(conn, 1*time.Second) + resp, _, err := readWebSocketUpgradeResponse(conn, 1*time.Second) require.NoError(t, err) assert.Equal(t, 400, resp.StatusCode) } @@ -1274,7 +1274,7 @@ func TestReadWebSocketUpgradeResponse_ReadError(t *testing.T) { conn1, conn2 := net.Pipe() _ = conn2.Close() - _, err := readWebSocketUpgradeResponse(conn1, 100*time.Millisecond) + _, _, err := readWebSocketUpgradeResponse(conn1, 100*time.Millisecond) assert.Error(t, err) _ = conn1.Close() } diff --git a/internal/proxy/websocket.go b/internal/proxy/websocket.go index 59e7c17..1b91599 100644 --- a/internal/proxy/websocket.go +++ b/internal/proxy/websocket.go @@ -332,20 +332,20 @@ func buildWebSocketUpgradeRequest(ctx *fasthttp.RequestCtx, targetHost string, h // 返回值: // - *http.Response: HTTP 响应对象 // - error: 读取失败时返回错误 -func readWebSocketUpgradeResponse(conn net.Conn, timeout time.Duration) (*http.Response, error) { +func readWebSocketUpgradeResponse(conn net.Conn, timeout time.Duration) (*http.Response, *bufio.Reader, error) { // 设置读取超时 if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil { - return nil, err + return nil, nil, err } // 使用 bufio.Reader 读取 HTTP 响应 reader := bufio.NewReader(conn) resp, err := http.ReadResponse(reader, nil) if err != nil { - return nil, fmt.Errorf("failed to read upgrade response: %w", err) + return nil, nil, fmt.Errorf("failed to read upgrade response: %w", err) } - return resp, nil + return resp, reader, nil } // WebSocket 处理 WebSocket 代理请求。 @@ -400,7 +400,7 @@ func WebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, timeout tim } // 步骤4: 读取升级响应 - resp, err := readWebSocketUpgradeResponse(targetConn, timeout) + resp, bufferedReader, err := readWebSocketUpgradeResponse(targetConn, timeout) if err != nil { return fmt.Errorf("failed to read upgrade response: %w", err) } @@ -422,6 +422,16 @@ func WebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, timeout tim // 注意: WebSocket 升级成功后,resp.Body 不需要显式关闭 // 因为底层连接已被 bridge 用于双向数据传输 + // 如果 bufferedReader 已经缓冲了 WebSocket frame 数据, + // 需要包装连接使后续读取先消耗缓冲区 + if bufferedReader != nil && bufferedReader.Buffered() > 0 { + targetConn = &bufferedConn{ + Conn: targetConn, + reader: bufferedReader, + } + bridge.targetConn = targetConn + } + // 步骤7: 启动桥接(阻塞直到连接关闭) return bridge.Bridge() } @@ -469,3 +479,19 @@ func writeUpgradeResponse(conn net.Conn, resp *http.Response) error { return nil } + +// bufferedConn 包装 net.Conn,优先从 bufio.Reader 的缓冲区读取数据。 +// +// 用于 WebSocket 升级响应后,消耗 bufio.Reader 可能已缓冲的 frame 数据。 +type bufferedConn struct { + net.Conn + reader *bufio.Reader +} + +// Read 优先从内部 bufio.Reader 读取,若缓冲区为空则回退到原始连接。 +func (bc *bufferedConn) Read(p []byte) (int, error) { + if bc.reader.Buffered() > 0 { + return bc.reader.Read(p) + } + return bc.Conn.Read(p) +} diff --git a/internal/proxy/websocket_test.go b/internal/proxy/websocket_test.go index 78b2eaf..f4c93c9 100644 --- a/internal/proxy/websocket_test.go +++ b/internal/proxy/websocket_test.go @@ -557,7 +557,7 @@ func TestReadWebSocketUpgradeResponse(t *testing.T) { }() // 读取响应 - resp, err := readWebSocketUpgradeResponse(conn1, 1*time.Second) + resp, _, err := readWebSocketUpgradeResponse(conn1, 1*time.Second) if err != nil { t.Fatalf("readWebSocketUpgradeResponse failed: %v", err) } @@ -579,7 +579,7 @@ func TestReadWebSocketUpgradeResponse_Timeout(t *testing.T) { defer func() { _ = conn2.Close() }() // 使用很短的超时 - _, err := readWebSocketUpgradeResponse(conn1, 10*time.Millisecond) + _, _, err := readWebSocketUpgradeResponse(conn1, 10*time.Millisecond) if err == nil { t.Error("Expected timeout error, got nil") } diff --git a/internal/server/pool.go b/internal/server/pool.go index 42f9267..3ebc5f1 100644 --- a/internal/server/pool.go +++ b/internal/server/pool.go @@ -179,9 +179,14 @@ func (p *GoroutinePool) Submit(ctx *fasthttp.RequestCtx, task Task) error { // 队列满,需要启动新 worker 或直接执行 if atomic.LoadInt32(&p.workers) < p.maxWorkers { p.startWorker() - // 重新尝试入队 - p.taskQueue <- task - return nil + // 非阻塞尝试入队,避免新 worker 尚未就绪时死锁 + select { + case p.taskQueue <- task: + return nil + default: + task(ctx) + return nil + } } // 达到最大 worker,直接执行(fallback) diff --git a/internal/stream/stream.go b/internal/stream/stream.go index 56402a3..9f3e70e 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -982,7 +982,13 @@ func (s *udpServer) serve() { continue } } - continue + // 非超时错误(如连接关闭),检查 stopCh 后退出 + select { + case <-s.stopCh: + return + default: + continue + } } // 获取或创建会话