diff --git a/internal/middleware/accesslog/accesslog.go b/internal/middleware/accesslog/accesslog.go new file mode 100644 index 0000000..8e5656d --- /dev/null +++ b/internal/middleware/accesslog/accesslog.go @@ -0,0 +1,42 @@ +// Package accesslog 提供访问日志中间件,记录每个请求的详细信息。 +package accesslog + +import ( + "time" + + "github.com/valyala/fasthttp" + "rua.plus/lolly/internal/config" + "rua.plus/lolly/internal/logging" +) + +// AccessLog 访问日志中间件,记录请求方法、路径、状态码、响应大小和处理时间。 +type AccessLog struct { + logger *logging.Logger +} + +// New 创建访问日志中间件。 +func New(cfg *config.LoggingConfig) *AccessLog { + return &AccessLog{ + logger: logging.New(cfg), + } +} + +// Name 返回中间件名称。 +func (a *AccessLog) Name() string { + return "accesslog" +} + +// Process 包装 handler,在请求处理后记录访问日志。 +func (a *AccessLog) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + start := time.Now() + next(ctx) + duration := time.Since(start) + a.logger.LogAccess(ctx, ctx.Response.StatusCode(), int64(len(ctx.Response.Body())), duration) + } +} + +// Close 关闭日志文件。 +func (a *AccessLog) Close() error { + return a.logger.Close() +} \ No newline at end of file diff --git a/internal/middleware/accesslog/accesslog_test.go b/internal/middleware/accesslog/accesslog_test.go new file mode 100644 index 0000000..50fb751 --- /dev/null +++ b/internal/middleware/accesslog/accesslog_test.go @@ -0,0 +1,79 @@ +package accesslog + +import ( + "bytes" + "testing" + "time" + + "github.com/valyala/fasthttp" + "rua.plus/lolly/internal/config" +) + +func TestAccessLog_Name(t *testing.T) { + al := New(&config.LoggingConfig{}) + if al.Name() != "accesslog" { + t.Errorf("expected name 'accesslog', got '%s'", al.Name()) + } +} + +func TestAccessLog_Process(t *testing.T) { + al := New(&config.LoggingConfig{ + Access: config.AccessLogConfig{Format: "json"}, + }) + + // 创建一个简单的 handler + handler := func(ctx *fasthttp.RequestCtx) { + ctx.SetStatusCode(200) + ctx.SetBodyString("hello") + } + + // 包装 handler + wrapped := al.Process(handler) + + // 创建模拟请求上下文 + var ctx fasthttp.RequestCtx + ctx.Init(&fasthttp.Request{}, nil, nil) + + // 执行 + wrapped(&ctx) + + // 验证响应未被修改 + if ctx.Response.StatusCode() != 200 { + t.Errorf("expected status 200, got %d", ctx.Response.StatusCode()) + } + if !bytes.Equal(ctx.Response.Body(), []byte("hello")) { + t.Errorf("expected body 'hello', got '%s'", ctx.Response.Body()) + } + + // 清理 + al.Close() +} + +func TestAccessLog_ProcessWithDuration(t *testing.T) { + al := New(&config.LoggingConfig{ + Access: config.AccessLogConfig{Format: "json"}, + }) + + // 创建一个有延迟的 handler + handler := func(ctx *fasthttp.RequestCtx) { + time.Sleep(10 * time.Millisecond) + ctx.SetStatusCode(201) + ctx.SetBodyString("created") + } + + wrapped := al.Process(handler) + + var ctx fasthttp.RequestCtx + ctx.Init(&fasthttp.Request{}, nil, nil) + + start := time.Now() + wrapped(&ctx) + elapsed := time.Since(start) + + // 验证延迟被记录(至少 10ms) + if elapsed < 10*time.Millisecond { + t.Errorf("expected duration >= 10ms, got %v", elapsed) + } + + al.Close() +} \ No newline at end of file diff --git a/internal/server/server.go b/internal/server/server.go index 5ed0c58..31927a1 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -10,16 +10,18 @@ import ( "rua.plus/lolly/internal/loadbalance" "rua.plus/lolly/internal/logging" "rua.plus/lolly/internal/middleware" + "rua.plus/lolly/internal/middleware/accesslog" "rua.plus/lolly/internal/proxy" ) // Server HTTP 服务器 type Server struct { - config *config.Config - fastServer *fasthttp.Server - handler fasthttp.RequestHandler - running bool - healthCheckers []*proxy.HealthChecker // 新增 + config *config.Config + fastServer *fasthttp.Server + handler fasthttp.RequestHandler + running bool + healthCheckers []*proxy.HealthChecker + accessLogMiddleware *accesslog.AccessLog } // New 创建服务器 @@ -52,7 +54,10 @@ func (s *Server) startSingleMode() error { router.GET("/{filepath:*}", staticHandler.Handle) router.HEAD("/{filepath:*}", staticHandler.Handle) - chain := middleware.NewChain() + // 创建访问日志中间件 + s.accessLogMiddleware = accesslog.New(&s.config.Logging) + + chain := middleware.NewChain(s.accessLogMiddleware) s.handler = chain.Apply(router.Handler()) s.fastServer = &fasthttp.Server{ @@ -73,6 +78,10 @@ func (s *Server) startSingleMode() error { func (s *Server) startVHostMode() error { vhostMgr := NewVHostManager() + // 创建访问日志中间件(共享给所有虚拟主机) + s.accessLogMiddleware = accesslog.New(&s.config.Logging) + chain := middleware.NewChain(s.accessLogMiddleware) + for i := range s.config.Servers { router := handler.NewRouter() s.registerProxyRoutes(router, &s.config.Servers[i]) @@ -85,7 +94,7 @@ func (s *Server) startVHostMode() error { router.GET("/{filepath:*}", staticHandler.Handle) router.HEAD("/{filepath:*}", staticHandler.Handle) - vhostMgr.AddHost(s.config.Servers[i].Name, router.Handler()) + vhostMgr.AddHost(s.config.Servers[i].Name, chain.Apply(router.Handler())) } // 默认主机 @@ -97,7 +106,7 @@ func (s *Server) startVHostMode() error { s.config.Server.Static.Index, ) router.GET("/{filepath:*}", staticHandler.Handle) - vhostMgr.SetDefault(router.Handler()) + vhostMgr.SetDefault(chain.Apply(router.Handler())) } s.handler = vhostMgr.Handler() @@ -161,6 +170,11 @@ func (s *Server) Stop() error { hc.Stop() } + // 关闭访问日志 + if s.accessLogMiddleware != nil { + s.accessLogMiddleware.Close() + } + if s.fastServer != nil { return s.fastServer.Shutdown() } @@ -176,6 +190,11 @@ func (s *Server) GracefulStop(timeout time.Duration) error { hc.Stop() } + // 关闭访问日志 + if s.accessLogMiddleware != nil { + s.accessLogMiddleware.Close() + } + if s.fastServer != nil { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel()