fix(proxy,handler,server,stream,ratelimit): fix resource leaks and functional bugs
- proxy/proxy.go: decrement connection count on dangerous path rejection (line 724) to prevent connection count leak - handler/sendfile_linux.go: return *os.File from getSocketFile and let linuxSendfile close it, fixing EBADF from deferred close in getSocketFd - proxy/websocket.go: return bufio.Reader from readWebSocketUpgradeResponse and wrap targetConn with bufferedConn to consume pre-buffered frame data, preventing first-frame loss - server/pool.go: use non-blocking send after starting new worker to avoid deadlock when queue is full - stream/stream.go: check stopCh on non-timeout UDP read errors to prevent infinite loop and shutdown deadlock - middleware/ratelimit: replace select-based close guard with sync.Once in StopCleanup to prevent double-close panic
This commit is contained in:
parent
fe0dee4da3
commit
27e00b84a8
@ -80,10 +80,12 @@ func SendFile(ctx *fasthttp.RequestCtx, file *os.File, offset, length int64) err
|
|||||||
// 返回值:
|
// 返回值:
|
||||||
// - error: 传输过程中的错误,nil 表示成功
|
// - error: 传输过程中的错误,nil 表示成功
|
||||||
func linuxSendfile(conn net.Conn, fileFd uintptr, _, length int64) error {
|
func linuxSendfile(conn net.Conn, fileFd uintptr, _, length int64) error {
|
||||||
socketFd, err := getSocketFd(conn)
|
socketFile, err := getSocketFile(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
defer func() { _ = socketFile.Close() }()
|
||||||
|
socketFd := socketFile.Fd()
|
||||||
|
|
||||||
// Linux sendfile: sendfile(out_fd, in_fd, offset, count)
|
// Linux sendfile: sendfile(out_fd, in_fd, offset, count)
|
||||||
var sent int64
|
var sent int64
|
||||||
@ -148,23 +150,13 @@ func linuxSendfile(conn net.Conn, fileFd uintptr, _, length int64) error {
|
|||||||
// 返回值:
|
// 返回值:
|
||||||
// - uintptr: socket 文件描述符,失败时返回 0
|
// - uintptr: socket 文件描述符,失败时返回 0
|
||||||
// - error: 获取失败时的错误,不支持的连接类型返回 ENOTSUP
|
// - error: 获取失败时的错误,不支持的连接类型返回 ENOTSUP
|
||||||
func getSocketFd(conn net.Conn) (uintptr, error) {
|
func getSocketFile(conn net.Conn) (*os.File, error) {
|
||||||
switch c := conn.(type) {
|
switch c := conn.(type) {
|
||||||
case *net.TCPConn:
|
case *net.TCPConn:
|
||||||
file, err := c.File()
|
return c.File()
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
defer func() { _ = file.Close() }()
|
|
||||||
return file.Fd(), nil
|
|
||||||
case *net.UnixConn:
|
case *net.UnixConn:
|
||||||
file, err := c.File()
|
return c.File()
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
defer func() { _ = file.Close() }()
|
|
||||||
return file.Fd(), nil
|
|
||||||
default:
|
default:
|
||||||
return 0, syscall.ENOTSUP
|
return nil, syscall.ENOTSUP
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -114,7 +114,7 @@ func TestCopyFile(t *testing.T) {
|
|||||||
|
|
||||||
// TestGetSocketFd_NilConn 测试 nil 连接的情况
|
// TestGetSocketFd_NilConn 测试 nil 连接的情况
|
||||||
func TestGetSocketFd_NilConn(t *testing.T) {
|
func TestGetSocketFd_NilConn(t *testing.T) {
|
||||||
_, err := getSocketFd(nil)
|
_, err := getSocketFile(nil)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error for nil connection")
|
t.Error("expected error for nil connection")
|
||||||
}
|
}
|
||||||
@ -123,7 +123,7 @@ func TestGetSocketFd_NilConn(t *testing.T) {
|
|||||||
// TestGetSocketFd_UnsupportedType 测试不支持的连接类型
|
// TestGetSocketFd_UnsupportedType 测试不支持的连接类型
|
||||||
func TestGetSocketFd_UnsupportedType(t *testing.T) {
|
func TestGetSocketFd_UnsupportedType(t *testing.T) {
|
||||||
conn := &mockConn{}
|
conn := &mockConn{}
|
||||||
_, err := getSocketFd(conn)
|
_, err := getSocketFile(conn)
|
||||||
if err != syscall.ENOTSUP {
|
if err != syscall.ENOTSUP {
|
||||||
t.Errorf("expected ENOTSUP for unsupported conn type, got: %v", err)
|
t.Errorf("expected ENOTSUP for unsupported conn type, got: %v", err)
|
||||||
}
|
}
|
||||||
@ -131,7 +131,7 @@ func TestGetSocketFd_UnsupportedType(t *testing.T) {
|
|||||||
|
|
||||||
// mockConn 是一个不实现 TCPConn/UnixConn 的模拟连接。
|
// mockConn 是一个不实现 TCPConn/UnixConn 的模拟连接。
|
||||||
//
|
//
|
||||||
// 用于测试 getSocketFd 对不支持连接类型的处理。
|
// 用于测试 getSocketFile 对不支持连接类型的处理。
|
||||||
type mockConn struct{}
|
type mockConn struct{}
|
||||||
|
|
||||||
func (m *mockConn) Read([]byte) (n int, err error) { return 0, nil }
|
func (m *mockConn) Read([]byte) (n int, err error) { return 0, nil }
|
||||||
@ -449,11 +449,11 @@ func TestGetSocketFd_UnixConn(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer clientConn.Close()
|
defer clientConn.Close()
|
||||||
|
|
||||||
fd, err := getSocketFd(clientConn)
|
f, err := getSocketFile(clientConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("getSocketFd failed: %v", err)
|
t.Errorf("getSocketFile failed: %v", err)
|
||||||
}
|
}
|
||||||
if fd == 0 {
|
if f == nil {
|
||||||
t.Error("Expected non-zero fd")
|
t.Error("Expected non-zero fd")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -616,11 +616,11 @@ func TestGetSocketFd_TCPConn(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer clientConn.Close()
|
defer clientConn.Close()
|
||||||
|
|
||||||
fd, err := getSocketFd(clientConn)
|
f, err := getSocketFile(clientConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("getSocketFd failed for TCPConn: %v", err)
|
t.Errorf("getSocketFile failed for TCPConn: %v", err)
|
||||||
}
|
}
|
||||||
if fd == 0 {
|
if f == nil {
|
||||||
t.Error("Expected non-zero fd for TCPConn")
|
t.Error("Expected non-zero fd for TCPConn")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -70,6 +70,7 @@ type RateLimiter struct {
|
|||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
stopCleanupCh chan struct{}
|
stopCleanupCh chan struct{}
|
||||||
cleanupDone chan struct{}
|
cleanupDone chan struct{}
|
||||||
|
stopOnce sync.Once
|
||||||
rate float64
|
rate float64
|
||||||
burst float64
|
burst float64
|
||||||
}
|
}
|
||||||
@ -494,18 +495,9 @@ func (rl *RateLimiter) startCleanup(interval time.Duration) {
|
|||||||
// 发送停止信号并等待 goroutine 完成,确保资源正确释放。
|
// 发送停止信号并等待 goroutine 完成,确保资源正确释放。
|
||||||
// 该方法应在限流器不再使用时调用(如服务器关闭时)。
|
// 该方法应在限流器不再使用时调用(如服务器关闭时)。
|
||||||
func (rl *RateLimiter) StopCleanup() {
|
func (rl *RateLimiter) StopCleanup() {
|
||||||
// 使用原子操作或简单的标志检查来避免竞争
|
|
||||||
// 关闭 stopCleanupCh 会广播给所有等待的 goroutine
|
|
||||||
select {
|
|
||||||
case <-rl.stopCleanupCh:
|
|
||||||
// 已经关闭
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
if rl.cleanupTicker != nil {
|
if rl.cleanupTicker != nil {
|
||||||
rl.cleanupTicker.Stop()
|
rl.cleanupTicker.Stop()
|
||||||
close(rl.stopCleanupCh)
|
rl.stopOnce.Do(func() { close(rl.stopCleanupCh) })
|
||||||
<-rl.cleanupDone
|
<-rl.cleanupDone
|
||||||
rl.cleanupTicker = nil // 防止重复关闭
|
rl.cleanupTicker = nil // 防止重复关闭
|
||||||
}
|
}
|
||||||
|
|||||||
@ -721,6 +721,7 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
|
|||||||
if bytes.ContainsAny(path, "@\r\n") {
|
if bytes.ContainsAny(path, "@\r\n") {
|
||||||
logging.Warn().Msgf("rejected suspicious proxy path containing dangerous chars: %s", path)
|
logging.Warn().Msgf("rejected suspicious proxy path containing dangerous chars: %s", path)
|
||||||
upstreamStatus = 502
|
upstreamStatus = 502
|
||||||
|
loadbalance.DecrementConnections(target)
|
||||||
utils.SendErrorWithDetail(ctx, utils.ErrBadGateway, "invalid proxy path")
|
utils.SendErrorWithDetail(ctx, utils.ErrBadGateway, "invalid proxy path")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@ -752,7 +752,7 @@ func TestWebSocket_UpgradeRejected(t *testing.T) {
|
|||||||
// 发送一个请求让服务端触发
|
// 发送一个请求让服务端触发
|
||||||
_, _ = conn.Write([]byte("GET /ws HTTP/1.1\r\nHost: localhost\r\n\r\n"))
|
_, _ = conn.Write([]byte("GET /ws HTTP/1.1\r\nHost: localhost\r\n\r\n"))
|
||||||
|
|
||||||
resp, err := readWebSocketUpgradeResponse(conn, 1*time.Second)
|
resp, _, err := readWebSocketUpgradeResponse(conn, 1*time.Second)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 400, resp.StatusCode)
|
assert.Equal(t, 400, resp.StatusCode)
|
||||||
}
|
}
|
||||||
@ -1274,7 +1274,7 @@ func TestReadWebSocketUpgradeResponse_ReadError(t *testing.T) {
|
|||||||
conn1, conn2 := net.Pipe()
|
conn1, conn2 := net.Pipe()
|
||||||
_ = conn2.Close()
|
_ = conn2.Close()
|
||||||
|
|
||||||
_, err := readWebSocketUpgradeResponse(conn1, 100*time.Millisecond)
|
_, _, err := readWebSocketUpgradeResponse(conn1, 100*time.Millisecond)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
_ = conn1.Close()
|
_ = conn1.Close()
|
||||||
}
|
}
|
||||||
|
|||||||
@ -332,20 +332,20 @@ func buildWebSocketUpgradeRequest(ctx *fasthttp.RequestCtx, targetHost string, h
|
|||||||
// 返回值:
|
// 返回值:
|
||||||
// - *http.Response: HTTP 响应对象
|
// - *http.Response: HTTP 响应对象
|
||||||
// - error: 读取失败时返回错误
|
// - error: 读取失败时返回错误
|
||||||
func readWebSocketUpgradeResponse(conn net.Conn, timeout time.Duration) (*http.Response, error) {
|
func readWebSocketUpgradeResponse(conn net.Conn, timeout time.Duration) (*http.Response, *bufio.Reader, error) {
|
||||||
// 设置读取超时
|
// 设置读取超时
|
||||||
if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
|
if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用 bufio.Reader 读取 HTTP 响应
|
// 使用 bufio.Reader 读取 HTTP 响应
|
||||||
reader := bufio.NewReader(conn)
|
reader := bufio.NewReader(conn)
|
||||||
resp, err := http.ReadResponse(reader, nil)
|
resp, err := http.ReadResponse(reader, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to read upgrade response: %w", err)
|
return nil, nil, fmt.Errorf("failed to read upgrade response: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return resp, nil
|
return resp, reader, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// WebSocket 处理 WebSocket 代理请求。
|
// WebSocket 处理 WebSocket 代理请求。
|
||||||
@ -400,7 +400,7 @@ func WebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, timeout tim
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 步骤4: 读取升级响应
|
// 步骤4: 读取升级响应
|
||||||
resp, err := readWebSocketUpgradeResponse(targetConn, timeout)
|
resp, bufferedReader, err := readWebSocketUpgradeResponse(targetConn, timeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to read upgrade response: %w", err)
|
return fmt.Errorf("failed to read upgrade response: %w", err)
|
||||||
}
|
}
|
||||||
@ -422,6 +422,16 @@ func WebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, timeout tim
|
|||||||
// 注意: WebSocket 升级成功后,resp.Body 不需要显式关闭
|
// 注意: WebSocket 升级成功后,resp.Body 不需要显式关闭
|
||||||
// 因为底层连接已被 bridge 用于双向数据传输
|
// 因为底层连接已被 bridge 用于双向数据传输
|
||||||
|
|
||||||
|
// 如果 bufferedReader 已经缓冲了 WebSocket frame 数据,
|
||||||
|
// 需要包装连接使后续读取先消耗缓冲区
|
||||||
|
if bufferedReader != nil && bufferedReader.Buffered() > 0 {
|
||||||
|
targetConn = &bufferedConn{
|
||||||
|
Conn: targetConn,
|
||||||
|
reader: bufferedReader,
|
||||||
|
}
|
||||||
|
bridge.targetConn = targetConn
|
||||||
|
}
|
||||||
|
|
||||||
// 步骤7: 启动桥接(阻塞直到连接关闭)
|
// 步骤7: 启动桥接(阻塞直到连接关闭)
|
||||||
return bridge.Bridge()
|
return bridge.Bridge()
|
||||||
}
|
}
|
||||||
@ -469,3 +479,19 @@ func writeUpgradeResponse(conn net.Conn, resp *http.Response) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// bufferedConn 包装 net.Conn,优先从 bufio.Reader 的缓冲区读取数据。
|
||||||
|
//
|
||||||
|
// 用于 WebSocket 升级响应后,消耗 bufio.Reader 可能已缓冲的 frame 数据。
|
||||||
|
type bufferedConn struct {
|
||||||
|
net.Conn
|
||||||
|
reader *bufio.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read 优先从内部 bufio.Reader 读取,若缓冲区为空则回退到原始连接。
|
||||||
|
func (bc *bufferedConn) Read(p []byte) (int, error) {
|
||||||
|
if bc.reader.Buffered() > 0 {
|
||||||
|
return bc.reader.Read(p)
|
||||||
|
}
|
||||||
|
return bc.Conn.Read(p)
|
||||||
|
}
|
||||||
|
|||||||
@ -557,7 +557,7 @@ func TestReadWebSocketUpgradeResponse(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// 读取响应
|
// 读取响应
|
||||||
resp, err := readWebSocketUpgradeResponse(conn1, 1*time.Second)
|
resp, _, err := readWebSocketUpgradeResponse(conn1, 1*time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("readWebSocketUpgradeResponse failed: %v", err)
|
t.Fatalf("readWebSocketUpgradeResponse failed: %v", err)
|
||||||
}
|
}
|
||||||
@ -579,7 +579,7 @@ func TestReadWebSocketUpgradeResponse_Timeout(t *testing.T) {
|
|||||||
defer func() { _ = conn2.Close() }()
|
defer func() { _ = conn2.Close() }()
|
||||||
|
|
||||||
// 使用很短的超时
|
// 使用很短的超时
|
||||||
_, err := readWebSocketUpgradeResponse(conn1, 10*time.Millisecond)
|
_, _, err := readWebSocketUpgradeResponse(conn1, 10*time.Millisecond)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Expected timeout error, got nil")
|
t.Error("Expected timeout error, got nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -179,9 +179,14 @@ func (p *GoroutinePool) Submit(ctx *fasthttp.RequestCtx, task Task) error {
|
|||||||
// 队列满,需要启动新 worker 或直接执行
|
// 队列满,需要启动新 worker 或直接执行
|
||||||
if atomic.LoadInt32(&p.workers) < p.maxWorkers {
|
if atomic.LoadInt32(&p.workers) < p.maxWorkers {
|
||||||
p.startWorker()
|
p.startWorker()
|
||||||
// 重新尝试入队
|
// 非阻塞尝试入队,避免新 worker 尚未就绪时死锁
|
||||||
p.taskQueue <- task
|
select {
|
||||||
return nil
|
case p.taskQueue <- task:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
task(ctx)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 达到最大 worker,直接执行(fallback)
|
// 达到最大 worker,直接执行(fallback)
|
||||||
|
|||||||
@ -982,7 +982,13 @@ func (s *udpServer) serve() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
continue
|
// 非超时错误(如连接关闭),检查 stopCh 后退出
|
||||||
|
select {
|
||||||
|
case <-s.stopCh:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取或创建会话
|
// 获取或创建会话
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user