fix(proxy,handler,server,stream,ratelimit): fix resource leaks and functional bugs
- proxy/proxy.go: decrement connection count on dangerous path rejection (line 724) to prevent connection count leak - handler/sendfile_linux.go: return *os.File from getSocketFile and let linuxSendfile close it, fixing EBADF from deferred close in getSocketFd - proxy/websocket.go: return bufio.Reader from readWebSocketUpgradeResponse and wrap targetConn with bufferedConn to consume pre-buffered frame data, preventing first-frame loss - server/pool.go: use non-blocking send after starting new worker to avoid deadlock when queue is full - stream/stream.go: check stopCh on non-timeout UDP read errors to prevent infinite loop and shutdown deadlock - middleware/ratelimit: replace select-based close guard with sync.Once in StopCleanup to prevent double-close panic
This commit is contained in:
parent
fe0dee4da3
commit
27e00b84a8
@ -80,10 +80,12 @@ func SendFile(ctx *fasthttp.RequestCtx, file *os.File, offset, length int64) err
|
||||
// 返回值:
|
||||
// - error: 传输过程中的错误,nil 表示成功
|
||||
func linuxSendfile(conn net.Conn, fileFd uintptr, _, length int64) error {
|
||||
socketFd, err := getSocketFd(conn)
|
||||
socketFile, err := getSocketFile(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = socketFile.Close() }()
|
||||
socketFd := socketFile.Fd()
|
||||
|
||||
// Linux sendfile: sendfile(out_fd, in_fd, offset, count)
|
||||
var sent int64
|
||||
@ -148,23 +150,13 @@ func linuxSendfile(conn net.Conn, fileFd uintptr, _, length int64) error {
|
||||
// 返回值:
|
||||
// - uintptr: socket 文件描述符,失败时返回 0
|
||||
// - error: 获取失败时的错误,不支持的连接类型返回 ENOTSUP
|
||||
func getSocketFd(conn net.Conn) (uintptr, error) {
|
||||
func getSocketFile(conn net.Conn) (*os.File, error) {
|
||||
switch c := conn.(type) {
|
||||
case *net.TCPConn:
|
||||
file, err := c.File()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
return file.Fd(), nil
|
||||
return c.File()
|
||||
case *net.UnixConn:
|
||||
file, err := c.File()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
return file.Fd(), nil
|
||||
return c.File()
|
||||
default:
|
||||
return 0, syscall.ENOTSUP
|
||||
return nil, syscall.ENOTSUP
|
||||
}
|
||||
}
|
||||
|
||||
@ -114,7 +114,7 @@ func TestCopyFile(t *testing.T) {
|
||||
|
||||
// TestGetSocketFd_NilConn 测试 nil 连接的情况
|
||||
func TestGetSocketFd_NilConn(t *testing.T) {
|
||||
_, err := getSocketFd(nil)
|
||||
_, err := getSocketFile(nil)
|
||||
if err == nil {
|
||||
t.Error("expected error for nil connection")
|
||||
}
|
||||
@ -123,7 +123,7 @@ func TestGetSocketFd_NilConn(t *testing.T) {
|
||||
// TestGetSocketFd_UnsupportedType 测试不支持的连接类型
|
||||
func TestGetSocketFd_UnsupportedType(t *testing.T) {
|
||||
conn := &mockConn{}
|
||||
_, err := getSocketFd(conn)
|
||||
_, err := getSocketFile(conn)
|
||||
if err != syscall.ENOTSUP {
|
||||
t.Errorf("expected ENOTSUP for unsupported conn type, got: %v", err)
|
||||
}
|
||||
@ -131,7 +131,7 @@ func TestGetSocketFd_UnsupportedType(t *testing.T) {
|
||||
|
||||
// mockConn 是一个不实现 TCPConn/UnixConn 的模拟连接。
|
||||
//
|
||||
// 用于测试 getSocketFd 对不支持连接类型的处理。
|
||||
// 用于测试 getSocketFile 对不支持连接类型的处理。
|
||||
type mockConn struct{}
|
||||
|
||||
func (m *mockConn) Read([]byte) (n int, err error) { return 0, nil }
|
||||
@ -449,11 +449,11 @@ func TestGetSocketFd_UnixConn(t *testing.T) {
|
||||
}
|
||||
defer clientConn.Close()
|
||||
|
||||
fd, err := getSocketFd(clientConn)
|
||||
f, err := getSocketFile(clientConn)
|
||||
if err != nil {
|
||||
t.Errorf("getSocketFd failed: %v", err)
|
||||
t.Errorf("getSocketFile failed: %v", err)
|
||||
}
|
||||
if fd == 0 {
|
||||
if f == nil {
|
||||
t.Error("Expected non-zero fd")
|
||||
}
|
||||
|
||||
@ -616,11 +616,11 @@ func TestGetSocketFd_TCPConn(t *testing.T) {
|
||||
}
|
||||
defer clientConn.Close()
|
||||
|
||||
fd, err := getSocketFd(clientConn)
|
||||
f, err := getSocketFile(clientConn)
|
||||
if err != nil {
|
||||
t.Errorf("getSocketFd failed for TCPConn: %v", err)
|
||||
t.Errorf("getSocketFile failed for TCPConn: %v", err)
|
||||
}
|
||||
if fd == 0 {
|
||||
if f == nil {
|
||||
t.Error("Expected non-zero fd for TCPConn")
|
||||
}
|
||||
|
||||
|
||||
@ -70,6 +70,7 @@ type RateLimiter struct {
|
||||
cleanupTicker *time.Ticker
|
||||
stopCleanupCh chan struct{}
|
||||
cleanupDone chan struct{}
|
||||
stopOnce sync.Once
|
||||
rate float64
|
||||
burst float64
|
||||
}
|
||||
@ -494,18 +495,9 @@ func (rl *RateLimiter) startCleanup(interval time.Duration) {
|
||||
// 发送停止信号并等待 goroutine 完成,确保资源正确释放。
|
||||
// 该方法应在限流器不再使用时调用(如服务器关闭时)。
|
||||
func (rl *RateLimiter) StopCleanup() {
|
||||
// 使用原子操作或简单的标志检查来避免竞争
|
||||
// 关闭 stopCleanupCh 会广播给所有等待的 goroutine
|
||||
select {
|
||||
case <-rl.stopCleanupCh:
|
||||
// 已经关闭
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if rl.cleanupTicker != nil {
|
||||
rl.cleanupTicker.Stop()
|
||||
close(rl.stopCleanupCh)
|
||||
rl.stopOnce.Do(func() { close(rl.stopCleanupCh) })
|
||||
<-rl.cleanupDone
|
||||
rl.cleanupTicker = nil // 防止重复关闭
|
||||
}
|
||||
|
||||
@ -721,6 +721,7 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
|
||||
if bytes.ContainsAny(path, "@\r\n") {
|
||||
logging.Warn().Msgf("rejected suspicious proxy path containing dangerous chars: %s", path)
|
||||
upstreamStatus = 502
|
||||
loadbalance.DecrementConnections(target)
|
||||
utils.SendErrorWithDetail(ctx, utils.ErrBadGateway, "invalid proxy path")
|
||||
return
|
||||
}
|
||||
|
||||
@ -752,7 +752,7 @@ func TestWebSocket_UpgradeRejected(t *testing.T) {
|
||||
// 发送一个请求让服务端触发
|
||||
_, _ = conn.Write([]byte("GET /ws HTTP/1.1\r\nHost: localhost\r\n\r\n"))
|
||||
|
||||
resp, err := readWebSocketUpgradeResponse(conn, 1*time.Second)
|
||||
resp, _, err := readWebSocketUpgradeResponse(conn, 1*time.Second)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 400, resp.StatusCode)
|
||||
}
|
||||
@ -1274,7 +1274,7 @@ func TestReadWebSocketUpgradeResponse_ReadError(t *testing.T) {
|
||||
conn1, conn2 := net.Pipe()
|
||||
_ = conn2.Close()
|
||||
|
||||
_, err := readWebSocketUpgradeResponse(conn1, 100*time.Millisecond)
|
||||
_, _, err := readWebSocketUpgradeResponse(conn1, 100*time.Millisecond)
|
||||
assert.Error(t, err)
|
||||
_ = conn1.Close()
|
||||
}
|
||||
|
||||
@ -332,20 +332,20 @@ func buildWebSocketUpgradeRequest(ctx *fasthttp.RequestCtx, targetHost string, h
|
||||
// 返回值:
|
||||
// - *http.Response: HTTP 响应对象
|
||||
// - error: 读取失败时返回错误
|
||||
func readWebSocketUpgradeResponse(conn net.Conn, timeout time.Duration) (*http.Response, error) {
|
||||
func readWebSocketUpgradeResponse(conn net.Conn, timeout time.Duration) (*http.Response, *bufio.Reader, error) {
|
||||
// 设置读取超时
|
||||
if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 使用 bufio.Reader 读取 HTTP 响应
|
||||
reader := bufio.NewReader(conn)
|
||||
resp, err := http.ReadResponse(reader, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read upgrade response: %w", err)
|
||||
return nil, nil, fmt.Errorf("failed to read upgrade response: %w", err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
return resp, reader, nil
|
||||
}
|
||||
|
||||
// WebSocket 处理 WebSocket 代理请求。
|
||||
@ -400,7 +400,7 @@ func WebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, timeout tim
|
||||
}
|
||||
|
||||
// 步骤4: 读取升级响应
|
||||
resp, err := readWebSocketUpgradeResponse(targetConn, timeout)
|
||||
resp, bufferedReader, err := readWebSocketUpgradeResponse(targetConn, timeout)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read upgrade response: %w", err)
|
||||
}
|
||||
@ -422,6 +422,16 @@ func WebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, timeout tim
|
||||
// 注意: WebSocket 升级成功后,resp.Body 不需要显式关闭
|
||||
// 因为底层连接已被 bridge 用于双向数据传输
|
||||
|
||||
// 如果 bufferedReader 已经缓冲了 WebSocket frame 数据,
|
||||
// 需要包装连接使后续读取先消耗缓冲区
|
||||
if bufferedReader != nil && bufferedReader.Buffered() > 0 {
|
||||
targetConn = &bufferedConn{
|
||||
Conn: targetConn,
|
||||
reader: bufferedReader,
|
||||
}
|
||||
bridge.targetConn = targetConn
|
||||
}
|
||||
|
||||
// 步骤7: 启动桥接(阻塞直到连接关闭)
|
||||
return bridge.Bridge()
|
||||
}
|
||||
@ -469,3 +479,19 @@ func writeUpgradeResponse(conn net.Conn, resp *http.Response) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// bufferedConn 包装 net.Conn,优先从 bufio.Reader 的缓冲区读取数据。
|
||||
//
|
||||
// 用于 WebSocket 升级响应后,消耗 bufio.Reader 可能已缓冲的 frame 数据。
|
||||
type bufferedConn struct {
|
||||
net.Conn
|
||||
reader *bufio.Reader
|
||||
}
|
||||
|
||||
// Read 优先从内部 bufio.Reader 读取,若缓冲区为空则回退到原始连接。
|
||||
func (bc *bufferedConn) Read(p []byte) (int, error) {
|
||||
if bc.reader.Buffered() > 0 {
|
||||
return bc.reader.Read(p)
|
||||
}
|
||||
return bc.Conn.Read(p)
|
||||
}
|
||||
|
||||
@ -557,7 +557,7 @@ func TestReadWebSocketUpgradeResponse(t *testing.T) {
|
||||
}()
|
||||
|
||||
// 读取响应
|
||||
resp, err := readWebSocketUpgradeResponse(conn1, 1*time.Second)
|
||||
resp, _, err := readWebSocketUpgradeResponse(conn1, 1*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("readWebSocketUpgradeResponse failed: %v", err)
|
||||
}
|
||||
@ -579,7 +579,7 @@ func TestReadWebSocketUpgradeResponse_Timeout(t *testing.T) {
|
||||
defer func() { _ = conn2.Close() }()
|
||||
|
||||
// 使用很短的超时
|
||||
_, err := readWebSocketUpgradeResponse(conn1, 10*time.Millisecond)
|
||||
_, _, err := readWebSocketUpgradeResponse(conn1, 10*time.Millisecond)
|
||||
if err == nil {
|
||||
t.Error("Expected timeout error, got nil")
|
||||
}
|
||||
|
||||
@ -179,9 +179,14 @@ func (p *GoroutinePool) Submit(ctx *fasthttp.RequestCtx, task Task) error {
|
||||
// 队列满,需要启动新 worker 或直接执行
|
||||
if atomic.LoadInt32(&p.workers) < p.maxWorkers {
|
||||
p.startWorker()
|
||||
// 重新尝试入队
|
||||
p.taskQueue <- task
|
||||
return nil
|
||||
// 非阻塞尝试入队,避免新 worker 尚未就绪时死锁
|
||||
select {
|
||||
case p.taskQueue <- task:
|
||||
return nil
|
||||
default:
|
||||
task(ctx)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// 达到最大 worker,直接执行(fallback)
|
||||
|
||||
@ -982,7 +982,13 @@ func (s *udpServer) serve() {
|
||||
continue
|
||||
}
|
||||
}
|
||||
continue
|
||||
// 非超时错误(如连接关闭),检查 stopCh 后退出
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// 获取或创建会话
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user