diff --git a/internal/app/app_test.go b/internal/app/app_test.go index 45ae8f6..80cc6f1 100644 --- a/internal/app/app_test.go +++ b/internal/app/app_test.go @@ -200,6 +200,7 @@ func TestRun(t *testing.T) { name string cfgPath string outputPath string + importPath string wantContains string wantErrContains string wantExitCode int @@ -234,6 +235,25 @@ func TestRun(t *testing.T) { wantExitCode: 1, wantErrContains: "加载配置失败", }, + { + name: "generate 与 import 互斥", + genConfig: true, + importPath: "/tmp/nginx.conf", + wantExitCode: 1, + wantErrContains: "mutually exclusive", + }, + { + name: "o 参数无 generate 或 import", + outputPath: "output.yaml", + wantExitCode: 1, + wantErrContains: "-o requires", + }, + { + name: "导入 nginx 配置文件不存在", + importPath: "/tmp/nginx.conf", + wantExitCode: 1, + wantErrContains: "解析 nginx 配置失败", + }, } for _, tt := range tests { @@ -241,7 +261,7 @@ func TestRun(t *testing.T) { getStdout, restoreStdout := captureStdout(t) getStderr, restoreStderr := captureStderr(t) - exitCode := Run(tt.cfgPath, tt.genConfig, tt.outputPath, tt.showVersion) + exitCode := Run(tt.cfgPath, tt.genConfig, tt.outputPath, tt.importPath, tt.showVersion) restoreStderr() restoreStdout() diff --git a/internal/converter/nginx/converter.go b/internal/converter/nginx/converter.go new file mode 100644 index 0000000..af78e9c --- /dev/null +++ b/internal/converter/nginx/converter.go @@ -0,0 +1,886 @@ +// Package nginx provides a converter from nginx configuration to lolly configuration. +package nginx + +import ( + "fmt" + "strconv" + "strings" + "time" + + "rua.plus/lolly/internal/config" +) + +// Warning represents a conversion warning for unsupported or partially supported directives. +type Warning struct { + Directive string + Line int + File string + Message string +} + +func (w Warning) String() string { + return fmt.Sprintf("warning: %s:%d: %s", w.File, w.Line, w.Message) +} + +// ConvertResult holds the conversion output. +type ConvertResult struct { + Config *config.Config + Warnings []Warning +} + +// upstreamInfo holds parsed upstream data for later reference. +type upstreamInfo struct { + Targets []config.ProxyTarget + LoadBalance string +} + +// locationClassification classifies a location block for conversion. +type locationClassification struct { + LocType string // "proxy", "static", "redirect", "unsupported" + Path string // location path (without modifier) + Modifier string // "=", "^~", "~", "~*", "@" + Directives []Directive // original directives in the location block +} + +// unsupportedDirectives are known nginx directives that have no lolly equivalent. +var unsupportedDirectives = map[string]string{ + "if": "the 'if' directive is not supported; consider using map or rewrite", + "map": "the 'map' directive is not supported; use variables config instead", + "set": "the 'set' directive is not supported; use variables config instead", + "limit_req": "the 'limit_req' directive is not supported; use rate_limit config instead", + "limit_conn": "the 'limit_conn' directive is not supported", + "add_header": "the 'add_header' directive is not supported; use security.headers config instead", + "more_set_headers": "the 'more_set_headers' directive is not supported; use security.headers config instead", + "auth_request": "the 'auth_request' directive is not supported; use security.auth_request config instead", + "split_clients": "the 'split_clients' directive is not supported", + "geo": "the 'geo' directive is not supported; use access.geoip config instead", + "range": "the 'range' directive is not supported", + "return": "the 'return' directive is not supported for non-redirect status codes; only 301/302 are supported", +} + +// Convert converts a parsed nginx configuration to a lolly configuration. +func Convert(nginxCfg *NginxConfig) (*ConvertResult, error) { + result := &ConvertResult{ + Config: &config.Config{}, + } + + // 1. Build upstream map from top-level and http-level upstream blocks. + upstreams := make(map[string]*upstreamInfo) + for i := range nginxCfg.Directives { + d := &nginxCfg.Directives[i] + if d.Name == "upstream" { + info := convertUpstream(d, result) + if len(d.Args) > 0 { + upstreams[d.Args[0]] = info + } + } + if d.Name == "http" { + for j := range d.Block { + bd := &d.Block[j] + if bd.Name == "upstream" { + info := convertUpstream(bd, result) + if len(bd.Args) > 0 { + upstreams[bd.Args[0]] = info + } + } + } + } + } + + // 2. Find all server blocks: inside http blocks, or at top level. + var serverBlocks []Directive + for i := range nginxCfg.Directives { + d := &nginxCfg.Directives[i] + if d.Name == "http" { + // Check for unsupported directives at the http level. + for j := range d.Block { + bd := &d.Block[j] + if bd.Name == "server" { + serverBlocks = append(serverBlocks, d.Block[j]) + } else if msg, ok := unsupportedDirectives[bd.Name]; ok { + result.Warnings = append(result.Warnings, Warning{ + Directive: bd.Name, + Line: bd.Line, + File: bd.File, + Message: msg, + }) + } + } + } else if d.Name == "server" { + serverBlocks = append(serverBlocks, *d) + } + } + + // 3. Convert each server block. + for i := range serverBlocks { + serverCfg := convertServerBlock(&serverBlocks[i], upstreams, result) + result.Config.Servers = append(result.Config.Servers, serverCfg) + } + + return result, nil +} + +// convertUpstream converts an upstream block to upstreamInfo. +func convertUpstream(d *Directive, result *ConvertResult) *upstreamInfo { + info := &upstreamInfo{} + + for i := range d.Block { + bd := &d.Block[i] + + switch bd.Name { + case "server": + target := convertUpstreamServer(bd) + info.Targets = append(info.Targets, target) + case "least_conn": + info.LoadBalance = "least_conn" + case "ip_hash": + info.LoadBalance = "ip_hash" + case "hash": + // hash $variable consistent → consistent_hash + if len(bd.Args) > 0 { + info.LoadBalance = "consistent_hash" + } + case "random": + info.LoadBalance = "random" + default: + result.Warnings = append(result.Warnings, Warning{ + Directive: bd.Name, + Line: bd.Line, + File: bd.File, + Message: fmt.Sprintf("unsupported directive in upstream block: %s", bd.Name), + }) + } + } + + return info +} + +// convertUpstreamServer parses a server directive inside an upstream block. +func convertUpstreamServer(d *Directive) config.ProxyTarget { + target := config.ProxyTarget{} + + if len(d.Args) > 0 { + target.URL = d.Args[0] + } + + for _, arg := range d.Args[1:] { + if after, ok := strings.CutPrefix(arg, "weight="); ok { + if v, err := strconv.Atoi(after); err == nil { + target.Weight = v + } + } else if after, ok := strings.CutPrefix(arg, "max_fails="); ok { + if v, err := strconv.Atoi(after); err == nil { + target.MaxFails = v + } + } else if after, ok := strings.CutPrefix(arg, "fail_timeout="); ok { + target.FailTimeout = parseDuration(after) + } else if arg == "backup" { + target.Backup = true + } else if arg == "down" { + target.Down = true + } + } + + return target +} + +// convertServerBlock converts a server block directive to a ServerConfig. +func convertServerBlock(d *Directive, upstreams map[string]*upstreamInfo, result *ConvertResult) config.ServerConfig { + server := config.ServerConfig{} + var sslDetected bool + + for i := range d.Block { + bd := &d.Block[i] + + switch bd.Name { + case "listen": + if parseListen(bd, &server) { + sslDetected = true + } + case "server_name": + parseServerName(bd, &server) + case "ssl_certificate": + if len(bd.Args) > 0 { + server.SSL.Cert = bd.Args[0] + } + case "ssl_certificate_key": + if len(bd.Args) > 0 { + server.SSL.Key = bd.Args[0] + } + case "gzip": + parseGzip(bd, &server) + case "gzip_types": + server.Compression.Types = bd.Args + case "gzip_min_length": + if len(bd.Args) > 0 { + if v, err := strconv.Atoi(bd.Args[0]); err == nil { + server.Compression.MinSize = v + } + } + case "client_max_body_size": + if len(bd.Args) > 0 { + server.ClientMaxBodySize = bd.Args[0] + } + case "server_tokens": + if len(bd.Args) > 0 { + server.ServerTokens = bd.Args[0] != "off" + } + case "access_log": + parseAccessLog(bd, result) + case "error_log": + parseErrorLog(bd, result) + case "return": + parseServerReturn(bd, &server, result) + case "rewrite": + parseRewrite(bd, &server) + case "location": + classification := classifyLocation(bd, result) + convertLocation(classification, &server, upstreams, result) + case "error_page": + parseErrorPage(bd, &server) + case "auth_basic": + parseAuthBasic(bd, &server) + case "auth_basic_user_file": + if len(bd.Args) > 0 { + result.Warnings = append(result.Warnings, Warning{ + Directive: "auth_basic_user_file", + Line: bd.Line, + File: bd.File, + Message: fmt.Sprintf("auth_basic_user_file (%s) cannot be directly converted; htpasswd file must be manually migrated to auth.users", bd.Args[0]), + }) + } + default: + if msg, ok := unsupportedDirectives[bd.Name]; ok { + result.Warnings = append(result.Warnings, Warning{ + Directive: bd.Name, + Line: bd.Line, + File: bd.File, + Message: msg, + }) + } + } + } + + // Warn if SSL was detected (listen ... ssl) but cert/key are not configured. + if sslDetected && (server.SSL.Cert == "" || server.SSL.Key == "") { + result.Warnings = append(result.Warnings, Warning{ + Directive: "listen", + Message: "SSL is enabled via listen directive but ssl_certificate and/or ssl_certificate_key are not configured; SSL config will be incomplete", + }) + } + + // Default listen address if no listen directive was specified. + if server.Listen == "" { + server.Listen = "0.0.0.0:80" + } + + return server +} + +// parseListen parses a listen directive. +func parseListen(d *Directive, server *config.ServerConfig) bool { + if len(d.Args) == 0 { + return false + } + + addr := d.Args[0] + isSSL := false + isDefault := false + + for _, arg := range d.Args[1:] { + if arg == "ssl" { + isSSL = true + } + if arg == "default_server" { + isDefault = true + } + } + + // If addr is just a port number like "80" or "8080", prefix with ":". + if port, err := strconv.Atoi(addr); err == nil { + server.Listen = fmt.Sprintf(":%d", port) + } else if strings.Contains(addr, ":") { + server.Listen = addr + } else { + server.Listen = ":" + addr + } + + // Set default_server flag. + if isDefault { + server.Default = true + } + + // Enable SSL if specified. + if isSSL { + server.SSL.Cert = "" // Marker cleared; cert/key set by ssl_certificate directives. + server.SSL.Key = "" // If cert/key remain empty, a warning is added after processing. + } + + return isSSL +} + +// parseServerName parses a server_name directive. +func parseServerName(d *Directive, server *config.ServerConfig) { + if len(d.Args) == 0 { + return + } + + server.Name = d.Args[0] + server.ServerNames = append(server.ServerNames, d.Args...) +} + +// parseGzip parses a gzip directive. +func parseGzip(d *Directive, server *config.ServerConfig) { + if len(d.Args) > 0 && d.Args[0] == "on" { + server.Compression.Type = "gzip" + } +} + +// parseAccessLog parses an access_log directive. +func parseAccessLog(d *Directive, result *ConvertResult) { + if len(d.Args) > 0 { + result.Config.Logging.Access.Path = d.Args[0] + } + if len(d.Args) > 1 { + result.Config.Logging.Access.Format = d.Args[1] + } +} + +// parseErrorLog parses an error_log directive. +func parseErrorLog(d *Directive, result *ConvertResult) { + if len(d.Args) > 0 { + result.Config.Logging.Error.Path = d.Args[0] + } + if len(d.Args) > 1 { + result.Config.Logging.Error.Level = d.Args[1] + } +} + +// parseServerReturn parses a return directive at server level. +func parseServerReturn(d *Directive, server *config.ServerConfig, result *ConvertResult) { + if len(d.Args) == 0 { + return + } + + code, err := strconv.Atoi(d.Args[0]) + if err != nil { + return + } + + switch code { + case 301: + url := "" + if len(d.Args) > 1 { + url = d.Args[1] + } + server.Rewrite = append(server.Rewrite, config.RewriteRule{ + Pattern: "^/", + Replacement: url, + Flag: "permanent", + }) + case 302: + url := "" + if len(d.Args) > 1 { + url = d.Args[1] + } + server.Rewrite = append(server.Rewrite, config.RewriteRule{ + Pattern: "^/", + Replacement: url, + Flag: "redirect", + }) + default: + result.Warnings = append(result.Warnings, Warning{ + Directive: "return", + Line: d.Line, + File: d.File, + Message: fmt.Sprintf("return %d is not a redirect; only 301/302 are supported at server level", code), + }) + } +} + +// parseRewrite parses a rewrite directive. +func parseRewrite(d *Directive, server *config.ServerConfig) { + if len(d.Args) < 2 { + return + } + + rule := config.RewriteRule{ + Pattern: d.Args[0], + Replacement: d.Args[1], + } + + if len(d.Args) > 2 { + rule.Flag = d.Args[2] + } + + server.Rewrite = append(server.Rewrite, rule) +} + +// parseErrorPage parses an error_page directive. +func parseErrorPage(d *Directive, server *config.ServerConfig) { + // error_page 404 500 50x.html + // error_page 404 /404.html + if len(d.Args) < 2 { + return + } + + // Last arg is the page path. + pagePath := d.Args[len(d.Args)-1] + + if server.Security.ErrorPage.Pages == nil { + server.Security.ErrorPage.Pages = make(map[int]string) + } + + for _, arg := range d.Args[:len(d.Args)-1] { + if code, err := strconv.Atoi(arg); err == nil { + server.Security.ErrorPage.Pages[code] = pagePath + } + } +} + +// parseAuthBasic parses an auth_basic directive. +func parseAuthBasic(d *Directive, server *config.ServerConfig) { + if len(d.Args) > 0 { + if d.Args[0] != "off" { + server.Security.Auth.Type = "basic" + server.Security.Auth.Realm = d.Args[0] + } + } +} + +// classifyLocation classifies a location block based on its directives. +func classifyLocation(d *Directive, result *ConvertResult) locationClassification { + class := locationClassification{ + Directives: d.Block, + } + + // Parse location path and modifier. + if len(d.Args) > 0 { + first := d.Args[0] + switch first { + case "=", "^~", "~", "~*": + class.Modifier = first + if len(d.Args) > 1 { + class.Path = d.Args[1] + } + default: + if strings.HasPrefix(first, "@") { + class.Modifier = "@" + class.Path = first[1:] + } else { + class.Path = first + } + } + } + + // Classify based on content. + hasProxyPass := false + hasRootOrAlias := false + hasRedirect := false + + for i := range d.Block { + switch d.Block[i].Name { + case "proxy_pass": + hasProxyPass = true + case "root", "alias": + hasRootOrAlias = true + case "return": + if len(d.Block[i].Args) > 0 { + code, err := strconv.Atoi(d.Block[i].Args[0]) + if err == nil && (code == 301 || code == 302) { + hasRedirect = true + } + } + } + } + + switch { + case hasProxyPass: + class.LocType = "proxy" + if hasRootOrAlias { + result.Warnings = append(result.Warnings, Warning{ + Directive: "location", + Line: d.Line, + File: d.File, + Message: "location has both proxy_pass and root/alias; proxy_pass takes priority", + }) + } + case hasRootOrAlias: + class.LocType = "static" + case hasRedirect: + class.LocType = "redirect" + default: + class.LocType = "unsupported" + } + + return class +} + +// convertLocation converts a classified location to the appropriate config entries. +func convertLocation(class locationClassification, server *config.ServerConfig, upstreams map[string]*upstreamInfo, result *ConvertResult) { + locType := modifierToLocationType(class.Modifier) + + switch class.LocType { + case "proxy": + proxy := config.ProxyConfig{ + Path: class.Path, + LocationType: locType, + } + + if class.Modifier == "@" { + proxy.LocationName = class.Path + } + + convertProxyDirectives(class.Directives, &proxy, upstreams, result) + server.Proxy = append(server.Proxy, proxy) + + case "static": + static := config.StaticConfig{ + Path: class.Path, + LocationType: locType, + } + + convertStaticDirectives(class.Directives, &static, result) + server.Static = append(server.Static, static) + + case "redirect": + convertRedirectDirectives(class.Directives, class.Path, server, result) + + case "unsupported": + if len(class.Directives) == 0 { + result.Warnings = append(result.Warnings, Warning{ + Directive: "location", + Message: fmt.Sprintf("location %s has no content and is unsupported", class.Path), + }) + } + for i := range class.Directives { + bd := &class.Directives[i] + result.Warnings = append(result.Warnings, Warning{ + Directive: bd.Name, + Line: bd.Line, + File: bd.File, + Message: fmt.Sprintf("unsupported directive in location: %s", bd.Name), + }) + } + } +} + +// modifierToLocationType maps nginx location modifiers to lolly location types. +func modifierToLocationType(modifier string) string { + switch modifier { + case "=": + return "exact" + case "^~": + return "prefix_priority" + case "~": + return "regex" + case "~*": + return "regex_caseless" + case "@": + return "named" + default: + return "" + } +} + +// convertProxyDirectives converts directives within a proxy location block. +func convertProxyDirectives(directives []Directive, proxy *config.ProxyConfig, upstreams map[string]*upstreamInfo, result *ConvertResult) { + for i := range directives { + d := &directives[i] + + switch d.Name { + case "proxy_pass": + if len(d.Args) > 0 { + url := d.Args[0] + // Check if URL references an upstream name (no scheme). + if upstreamName := extractUpstreamName(url); upstreamName != "" { + if info, ok := upstreams[upstreamName]; ok { + proxy.Targets = append(proxy.Targets, info.Targets...) + if info.LoadBalance != "" && proxy.LoadBalance == "" { + proxy.LoadBalance = info.LoadBalance + } + } else { + // Upstream not found; use URL as-is. + proxy.Targets = append(proxy.Targets, config.ProxyTarget{URL: url}) + } + } else { + proxy.Targets = append(proxy.Targets, config.ProxyTarget{URL: url}) + } + } + case "proxy_set_header": + if len(d.Args) >= 2 { + if proxy.Headers.SetRequest == nil { + proxy.Headers.SetRequest = make(map[string]string) + } + proxy.Headers.SetRequest[d.Args[0]] = mapVariable(d.Args[1], result, d) + } + case "proxy_hide_header": + if len(d.Args) > 0 { + proxy.Headers.HideResponse = append(proxy.Headers.HideResponse, d.Args[0]) + } + case "proxy_pass_header": + if len(d.Args) > 0 { + proxy.Headers.PassResponse = append(proxy.Headers.PassResponse, d.Args[0]) + } + case "proxy_redirect": + convertProxyRedirect(d, proxy) + case "proxy_connect_timeout": + if len(d.Args) > 0 { + proxy.Timeout.Connect = parseDuration(d.Args[0]) + } + case "proxy_read_timeout": + if len(d.Args) > 0 { + proxy.Timeout.Read = parseDuration(d.Args[0]) + } + case "proxy_send_timeout": + if len(d.Args) > 0 { + proxy.Timeout.Write = parseDuration(d.Args[0]) + } + case "proxy_cache": + proxy.Cache.Enabled = true + case "proxy_cache_valid": + parseProxyCacheValid(d, proxy) + default: + if msg, ok := unsupportedDirectives[d.Name]; ok { + result.Warnings = append(result.Warnings, Warning{ + Directive: d.Name, + Line: d.Line, + File: d.File, + Message: msg, + }) + } + } + } +} + +// extractUpstreamName extracts an upstream name from a proxy_pass URL. +// If the URL has no scheme (e.g., "http://upstream_name" where upstream_name +// has no port), it returns the host portion. Otherwise returns empty string. +func extractUpstreamName(url string) string { + if _, rest, ok := strings.Cut(url, "://"); ok { + host := rest + if slashIdx := strings.IndexAny(host, "/?"); slashIdx >= 0 { + host = host[:slashIdx] + } + // Check if this is an upstream reference by looking up known upstream names. + // An upstream name is a host with no port and no dot (not an IP or domain). + if !strings.Contains(host, ":") && !strings.Contains(host, ".") && host != "" { + return host + } + } + return "" +} + +// mapVariable replaces nginx variables with lolly equivalents. +func mapVariable(value string, result *ConvertResult, d *Directive) string { + if strings.Contains(value, "$proxy_add_x_forwarded_for") { + result.Warnings = append(result.Warnings, Warning{ + Directive: "proxy_set_header", + Line: d.Line, + File: d.File, + Message: "$proxy_add_x_forwarded_for is replaced with $remote_addr; lolly automatically appends to X-Forwarded-For", + }) + return strings.ReplaceAll(value, "$proxy_add_x_forwarded_for", "$remote_addr") + } + return value +} + +// convertProxyRedirect handles proxy_redirect directive. +func convertProxyRedirect(d *Directive, proxy *config.ProxyConfig) { + if len(d.Args) == 0 { + return + } + + rr := &config.RedirectRewriteConfig{} + + switch d.Args[0] { + case "off": + rr.Mode = "off" + case "default": + rr.Mode = "default" + default: + rr.Mode = "custom" + if len(d.Args) >= 2 { + rr.Rules = append(rr.Rules, config.RedirectRewriteRule{ + Pattern: d.Args[0], + Replacement: d.Args[1], + }) + } + } + + proxy.RedirectRewrite = rr +} + +// parseProxyCacheValid parses proxy_cache_valid directive. +func parseProxyCacheValid(d *Directive, proxy *config.ProxyConfig) { + // proxy_cache_valid 200 10m + // proxy_cache_valid 301 302 1h + // proxy_cache_valid any 1m + if len(d.Args) < 2 { + return + } + + if proxy.CacheValid == nil { + proxy.CacheValid = &config.ProxyCacheValidConfig{} + } + + // Last arg is the duration. + dur := parseDuration(d.Args[len(d.Args)-1]) + + for _, arg := range d.Args[:len(d.Args)-1] { + switch arg { + case "200", "201", "202", "203", "204", "205", "206", "207", "208", "226": + proxy.CacheValid.OK = dur + case "301", "302": + proxy.CacheValid.Redirect = dur + case "404": + proxy.CacheValid.NotFound = dur + case "any": + proxy.CacheValid.OK = dur + proxy.CacheValid.Redirect = dur + proxy.CacheValid.NotFound = dur + proxy.CacheValid.ClientError = dur + proxy.CacheValid.ServerError = dur + default: + code, err := strconv.Atoi(arg) + if err == nil { + switch { + case code >= 400 && code < 500 && code != 404: + proxy.CacheValid.ClientError = dur + case code >= 500: + proxy.CacheValid.ServerError = dur + } + } + } + } +} + +// convertStaticDirectives converts directives within a static location block. +func convertStaticDirectives(directives []Directive, static *config.StaticConfig, result *ConvertResult) { + for i := range directives { + d := &directives[i] + + switch d.Name { + case "root": + if len(d.Args) > 0 { + static.Root = d.Args[0] + } + case "alias": + if len(d.Args) > 0 { + static.Root = d.Args[0] + result.Warnings = append(result.Warnings, Warning{ + Directive: "alias", + Line: d.Line, + File: d.File, + Message: "alias is converted to root; semantic differences may exist for locations with non-trailing paths", + }) + } + case "index": + static.Index = append(static.Index, d.Args...) + case "try_files": + static.TryFiles = append(static.TryFiles, d.Args...) + default: + if msg, ok := unsupportedDirectives[d.Name]; ok { + result.Warnings = append(result.Warnings, Warning{ + Directive: d.Name, + Line: d.Line, + File: d.File, + Message: msg, + }) + } + } + } +} + +// convertRedirectDirectives converts redirect directives within a location block. +func convertRedirectDirectives(directives []Directive, locPath string, server *config.ServerConfig, result *ConvertResult) { + for i := range directives { + d := &directives[i] + + if d.Name != "return" { + continue + } + + if len(d.Args) < 2 { + continue + } + + code, err := strconv.Atoi(d.Args[0]) + if err != nil { + continue + } + + url := d.Args[1] + pattern := "^" + locPath + "$" + + switch code { + case 301: + server.Rewrite = append(server.Rewrite, config.RewriteRule{ + Pattern: pattern, + Replacement: url, + Flag: "permanent", + }) + result.Warnings = append(result.Warnings, Warning{ + Directive: "return", + Line: d.Line, + File: d.File, + Message: "return 301 converted to rewrite rule with permanent flag", + }) + case 302: + server.Rewrite = append(server.Rewrite, config.RewriteRule{ + Pattern: pattern, + Replacement: url, + Flag: "redirect", + }) + result.Warnings = append(result.Warnings, Warning{ + Directive: "return", + Line: d.Line, + File: d.File, + Message: "return 302 converted to rewrite rule with redirect flag", + }) + default: + result.Warnings = append(result.Warnings, Warning{ + Directive: "return", + Line: d.Line, + File: d.File, + Message: fmt.Sprintf("return %d in location is not a redirect; only 301/302 are supported", code), + }) + } + } +} + +// parseDuration parses a time duration string. +// Supports nginx-style durations: "10s", "5m", "1h", "1d". +func parseDuration(s string) time.Duration { + if s == "" { + return 0 + } + + // Try standard Go duration first. + if d, err := time.ParseDuration(s); err == nil { + return d + } + + // Handle nginx-style durations without Go support. + s = strings.TrimSpace(s) + numStr := s[:len(s)-1] + unit := s[len(s)-1] + + value, err := strconv.ParseInt(numStr, 10, 64) + if err != nil { + return 0 + } + + switch unit { + case 's': + return time.Duration(value) * time.Second + case 'm': + return time.Duration(value) * time.Minute + case 'h': + return time.Duration(value) * time.Hour + case 'd': + return time.Duration(value) * 24 * time.Hour + default: + return 0 + } +} diff --git a/internal/converter/nginx/converter_test.go b/internal/converter/nginx/converter_test.go new file mode 100644 index 0000000..a484d47 --- /dev/null +++ b/internal/converter/nginx/converter_test.go @@ -0,0 +1,1328 @@ +package nginx + +import ( + "fmt" + "strings" + "testing" + "time" +) + +// helper: parse nginx config string and convert. +func convertString(t *testing.T, input string) (*ConvertResult, error) { + t.Helper() + cfg, err := Parse(input) + if err != nil { + t.Fatalf("parse error: %v", err) + } + return Convert(cfg) +} + +// helper: check if any warning contains substring. +func hasWarningContaining(warnings []Warning, substr string) bool { + for _, w := range warnings { + if strings.Contains(w.Message, substr) { + return true + } + } + return false +} + +func TestConvertServerBlock(t *testing.T) { + input := ` +http { + server { + listen 8080; + server_name example.com www.example.com; + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + if len(result.Config.Servers) != 1 { + t.Fatalf("expected 1 server, got %d", len(result.Config.Servers)) + } + + s := result.Config.Servers[0] + if s.Listen != ":8080" { + t.Errorf("expected listen :8080, got %s", s.Listen) + } + if s.Name != "example.com" { + t.Errorf("expected name example.com, got %s", s.Name) + } + if len(s.ServerNames) != 2 { + t.Fatalf("expected 2 server_names, got %d", len(s.ServerNames)) + } + if s.ServerNames[0] != "example.com" || s.ServerNames[1] != "www.example.com" { + t.Errorf("expected [example.com www.example.com], got %v", s.ServerNames) + } +} + +func TestConvertLocationProxyPass(t *testing.T) { + input := ` +http { + server { + listen 80; + location /api/ { + proxy_pass http://backend:8080; + } + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + s := result.Config.Servers[0] + if len(s.Proxy) != 1 { + t.Fatalf("expected 1 proxy, got %d", len(s.Proxy)) + } + + p := s.Proxy[0] + if p.Path != "/api/" { + t.Errorf("expected path /api/, got %s", p.Path) + } + if len(p.Targets) != 1 { + t.Fatalf("expected 1 target, got %d", len(p.Targets)) + } + if p.Targets[0].URL != "http://backend:8080" { + t.Errorf("expected target URL http://backend:8080, got %s", p.Targets[0].URL) + } +} + +func TestConvertLocationRoot(t *testing.T) { + input := ` +http { + server { + listen 80; + location /static/ { + root /var/www/html; + index index.html index.htm; + } + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + s := result.Config.Servers[0] + if len(s.Static) != 1 { + t.Fatalf("expected 1 static, got %d", len(s.Static)) + } + + st := s.Static[0] + if st.Path != "/static/" { + t.Errorf("expected path /static/, got %s", st.Path) + } + if st.Root != "/var/www/html" { + t.Errorf("expected root /var/www/html, got %s", st.Root) + } + if len(st.Index) != 2 || st.Index[0] != "index.html" || st.Index[1] != "index.htm" { + t.Errorf("expected [index.html index.htm], got %v", st.Index) + } +} + +func TestConvertUpstream(t *testing.T) { + input := ` +upstream backend { + server 10.0.0.1:8080 weight=3; + server 10.0.0.2:8080 weight=1 max_fails=3 fail_timeout=30s; + server 10.0.0.3:8080 backup; +} + +http { + server { + listen 80; + location / { + proxy_pass http://backend; + } + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + s := result.Config.Servers[0] + if len(s.Proxy) != 1 { + t.Fatalf("expected 1 proxy, got %d", len(s.Proxy)) + } + + p := s.Proxy[0] + if len(p.Targets) != 3 { + t.Fatalf("expected 3 targets, got %d", len(p.Targets)) + } + + if p.Targets[0].URL != "10.0.0.1:8080" { + t.Errorf("target[0] URL = %s, want 10.0.0.1:8080", p.Targets[0].URL) + } + if p.Targets[0].Weight != 3 { + t.Errorf("target[0] Weight = %d, want 3", p.Targets[0].Weight) + } + if p.Targets[1].MaxFails != 3 { + t.Errorf("target[1] MaxFails = %d, want 3", p.Targets[1].MaxFails) + } + if p.Targets[1].FailTimeout != 30*time.Second { + t.Errorf("target[1] FailTimeout = %v, want 30s", p.Targets[1].FailTimeout) + } + if !p.Targets[2].Backup { + t.Error("target[2] Backup = false, want true") + } +} + +func TestConvertLocationModifiers(t *testing.T) { + tests := []struct { + name string + input string + wantType string + }{ + { + name: "exact", + input: ` +http { + server { + listen 80; + location = /exact { + proxy_pass http://backend; + } + } +}`, + wantType: "exact", + }, + { + name: "prefix_priority", + input: ` +http { + server { + listen 80; + location ^~ /prefix { + proxy_pass http://backend; + } + } +}`, + wantType: "prefix_priority", + }, + { + name: "regex", + input: ` +http { + server { + listen 80; + location ~ \.php$ { + proxy_pass http://backend; + } + } +}`, + wantType: "regex", + }, + { + name: "regex_caseless", + input: ` +http { + server { + listen 80; + location ~* \.php$ { + proxy_pass http://backend; + } + } +}`, + wantType: "regex_caseless", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := convertString(t, tt.input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + s := result.Config.Servers[0] + if len(s.Proxy) != 1 { + t.Fatalf("expected 1 proxy, got %d", len(s.Proxy)) + } + if s.Proxy[0].LocationType != tt.wantType { + t.Errorf("LocationType = %s, want %s", s.Proxy[0].LocationType, tt.wantType) + } + }) + } +} + +func TestConvertGzipConfig(t *testing.T) { + input := ` +http { + server { + listen 80; + gzip on; + gzip_types text/plain text/css application/json; + gzip_min_length 1024; + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + s := result.Config.Servers[0] + if s.Compression.Type != "gzip" { + t.Errorf("Compression.Type = %s, want gzip", s.Compression.Type) + } + if len(s.Compression.Types) != 3 { + t.Fatalf("Compression.Types length = %d, want 3", len(s.Compression.Types)) + } + if s.Compression.Types[0] != "text/plain" { + t.Errorf("Compression.Types[0] = %s, want text/plain", s.Compression.Types[0]) + } + if s.Compression.MinSize != 1024 { + t.Errorf("Compression.MinSize = %d, want 1024", s.Compression.MinSize) + } +} + +func TestConvertRewrite(t *testing.T) { + input := ` +http { + server { + listen 80; + rewrite ^/old/(.*)$ /new/$1 last; + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + s := result.Config.Servers[0] + if len(s.Rewrite) != 1 { + t.Fatalf("expected 1 rewrite rule, got %d", len(s.Rewrite)) + } + + r := s.Rewrite[0] + if r.Pattern != "^/old/(.*)$" { + t.Errorf("Pattern = %s, want ^/old/(.*)$", r.Pattern) + } + if r.Replacement != "/new/$1" { + t.Errorf("Replacement = %s, want /new/$1", r.Replacement) + } + if r.Flag != "last" { + t.Errorf("Flag = %s, want last", r.Flag) + } +} + +func TestConvertReturn301(t *testing.T) { + input := ` +http { + server { + listen 80; + location /old { + return 301 https://example.com/new; + } + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + s := result.Config.Servers[0] + if len(s.Rewrite) != 1 { + t.Fatalf("expected 1 rewrite rule, got %d", len(s.Rewrite)) + } + + r := s.Rewrite[0] + if r.Pattern != "^/old$" { + t.Errorf("Pattern = %s, want ^/old$", r.Pattern) + } + if r.Replacement != "https://example.com/new" { + t.Errorf("Replacement = %s, want https://example.com/new", r.Replacement) + } + if r.Flag != "permanent" { + t.Errorf("Flag = %s, want permanent", r.Flag) + } + + if !hasWarningContaining(result.Warnings, "return 301") { + t.Error("expected warning about return 301 conversion") + } +} + +func TestConvertReturn302(t *testing.T) { + input := ` +http { + server { + listen 80; + location /temp { + return 302 https://example.com/temp-new; + } + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + s := result.Config.Servers[0] + if len(s.Rewrite) != 1 { + t.Fatalf("expected 1 rewrite rule, got %d", len(s.Rewrite)) + } + + r := s.Rewrite[0] + if r.Flag != "redirect" { + t.Errorf("Flag = %s, want redirect", r.Flag) + } + + if !hasWarningContaining(result.Warnings, "return 302") { + t.Error("expected warning about return 302 conversion") + } +} + +func TestConvertSSLConfig(t *testing.T) { + input := ` +http { + server { + listen 443 ssl; + ssl_certificate /etc/ssl/server.crt; + ssl_certificate_key /etc/ssl/server.key; + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + s := result.Config.Servers[0] + if s.SSL.Cert != "/etc/ssl/server.crt" { + t.Errorf("SSL.Cert = %s, want /etc/ssl/server.crt", s.SSL.Cert) + } + if s.SSL.Key != "/etc/ssl/server.key" { + t.Errorf("SSL.Key = %s, want /etc/ssl/server.key", s.SSL.Key) + } + if s.Listen != ":443" { + t.Errorf("Listen = %s, want :443", s.Listen) + } +} + +func TestConvertProxyAddXForwardedFor(t *testing.T) { + input := ` +http { + server { + listen 80; + location / { + proxy_pass http://backend; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + } + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + s := result.Config.Servers[0] + if len(s.Proxy) != 1 { + t.Fatalf("expected 1 proxy, got %d", len(s.Proxy)) + } + + val := s.Proxy[0].Headers.SetRequest["X-Forwarded-For"] + if val != "$remote_addr" { + t.Errorf("X-Forwarded-For = %s, want $remote_addr", val) + } + + if !hasWarningContaining(result.Warnings, "$proxy_add_x_forwarded_for") { + t.Error("expected warning about $proxy_add_x_forwarded_for replacement") + } +} + +func TestConvertUnsupportedDirective(t *testing.T) { + input := ` +http { + server { + listen 80; + if ($host = example.com) { + return 301 https://example.com; + } + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + if !hasWarningContaining(result.Warnings, "'if' directive") { + t.Error("expected warning about unsupported 'if' directive") + } +} + +func TestConvertEmptyLocation(t *testing.T) { + input := ` +http { + server { + listen 80; + location /empty { + } + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + if !hasWarningContaining(result.Warnings, "no content") { + t.Error("expected warning about empty location") + } +} + +func TestConvertConflictingLocation(t *testing.T) { + input := ` +http { + server { + listen 80; + location /conflict { + proxy_pass http://backend; + root /var/www; + } + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + s := result.Config.Servers[0] + // Should be classified as proxy, not static. + if len(s.Proxy) != 1 { + t.Fatalf("expected 1 proxy, got %d", len(s.Proxy)) + } + if len(s.Static) != 0 { + t.Errorf("expected 0 static, got %d", len(s.Static)) + } + + if !hasWarningContaining(result.Warnings, "proxy_pass takes priority") { + t.Error("expected warning about proxy_pass taking priority over root/alias") + } +} + +func TestConvertAliasDirective(t *testing.T) { + input := ` +http { + server { + listen 80; + location /images/ { + alias /data/photos/; + } + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + s := result.Config.Servers[0] + if len(s.Static) != 1 { + t.Fatalf("expected 1 static, got %d", len(s.Static)) + } + + st := s.Static[0] + if st.Root != "/data/photos/" { + t.Errorf("Root = %s, want /data/photos/", st.Root) + } + + if !hasWarningContaining(result.Warnings, "alias") { + t.Error("expected warning about alias conversion to root") + } +} + +func TestConvertReturnNonRedirect(t *testing.T) { + input := ` +http { + server { + listen 80; + location /health { + return 200 "OK"; + } + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + if len(result.Warnings) == 0 { + t.Fatal("expected at least one warning, got none") + } + if !hasWarningContaining(result.Warnings, "'return' directive") && !hasWarningContaining(result.Warnings, "return") { + t.Errorf("expected warning about non-redirect return, got warnings: %v", result.Warnings) + } +} + +func TestConvertServerLevelReturn(t *testing.T) { + input := ` +http { + server { + listen 80; + return 301 https://example.com; + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + s := result.Config.Servers[0] + if len(s.Rewrite) != 1 { + t.Fatalf("expected 1 rewrite rule, got %d", len(s.Rewrite)) + } + + r := s.Rewrite[0] + if r.Pattern != "^/" { + t.Errorf("Pattern = %s, want ^/", r.Pattern) + } + if r.Replacement != "https://example.com" { + t.Errorf("Replacement = %s, want https://example.com", r.Replacement) + } + if r.Flag != "permanent" { + t.Errorf("Flag = %s, want permanent", r.Flag) + } +} + +func TestConvertProxySetHeader(t *testing.T) { + input := ` +http { + server { + listen 80; + location / { + proxy_pass http://backend; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-Proto $scheme; + } + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + s := result.Config.Servers[0] + p := s.Proxy[0] + + if p.Headers.SetRequest["Host"] != "$host" { + t.Errorf("Host = %s, want $host", p.Headers.SetRequest["Host"]) + } + if p.Headers.SetRequest["X-Real-IP"] != "$remote_addr" { + t.Errorf("X-Real-IP = %s, want $remote_addr", p.Headers.SetRequest["X-Real-IP"]) + } + if p.Headers.SetRequest["X-Forwarded-Proto"] != "$scheme" { + t.Errorf("X-Forwarded-Proto = %s, want $scheme", p.Headers.SetRequest["X-Forwarded-Proto"]) + } +} + +func TestConvertProxyTimeout(t *testing.T) { + input := ` +http { + server { + listen 80; + location / { + proxy_pass http://backend; + proxy_connect_timeout 5s; + proxy_read_timeout 30s; + proxy_send_timeout 15s; + } + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + s := result.Config.Servers[0] + p := s.Proxy[0] + + if p.Timeout.Connect != 5*time.Second { + t.Errorf("Connect = %v, want 5s", p.Timeout.Connect) + } + if p.Timeout.Read != 30*time.Second { + t.Errorf("Read = %v, want 30s", p.Timeout.Read) + } + if p.Timeout.Write != 15*time.Second { + t.Errorf("Write = %v, want 15s", p.Timeout.Write) + } +} + +func TestConvertClientMaxBodySize(t *testing.T) { + input := ` +http { + server { + listen 80; + client_max_body_size 10m; + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + s := result.Config.Servers[0] + if s.ClientMaxBodySize != "10m" { + t.Errorf("ClientMaxBodySize = %s, want 10m", s.ClientMaxBodySize) + } +} + +func TestConvertMultipleServerBlocks(t *testing.T) { + input := ` +http { + server { + listen 80; + server_name example.com; + } + server { + listen 443 ssl; + server_name secure.example.com; + ssl_certificate /etc/ssl/cert.pem; + ssl_certificate_key /etc/ssl/key.pem; + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + if len(result.Config.Servers) != 2 { + t.Fatalf("expected 2 servers, got %d", len(result.Config.Servers)) + } + + if result.Config.Servers[0].Listen != ":80" { + t.Errorf("server[0] Listen = %s, want :80", result.Config.Servers[0].Listen) + } + if result.Config.Servers[1].Listen != ":443" { + t.Errorf("server[1] Listen = %s, want :443", result.Config.Servers[1].Listen) + } + if result.Config.Servers[1].Name != "secure.example.com" { + t.Errorf("server[1] Name = %s, want secure.example.com", result.Config.Servers[1].Name) + } +} + +func TestConvertDefaultServer(t *testing.T) { + input := ` +http { + server { + listen 80 default_server; + server_name _; + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + s := result.Config.Servers[0] + if !s.Default { + t.Error("Default = false, want true") + } + if s.Listen != ":80" { + t.Errorf("Listen = %s, want :80", s.Listen) + } +} + +func TestConvertProxyRedirect(t *testing.T) { + tests := []struct { + name string + input string + mode string + rules int + }{ + { + name: "off", + input: ` +http { + server { + listen 80; + location / { + proxy_pass http://backend; + proxy_redirect off; + } + } +}`, + mode: "off", + rules: 0, + }, + { + name: "default", + input: ` +http { + server { + listen 80; + location / { + proxy_pass http://backend; + proxy_redirect default; + } + } +}`, + mode: "default", + rules: 0, + }, + { + name: "custom", + input: ` +http { + server { + listen 80; + location / { + proxy_pass http://backend; + proxy_redirect http://backend/ /; + } + } +}`, + mode: "custom", + rules: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := convertString(t, tt.input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + p := result.Config.Servers[0].Proxy[0] + if p.RedirectRewrite == nil { + t.Fatal("RedirectRewrite is nil") + } + if p.RedirectRewrite.Mode != tt.mode { + t.Errorf("Mode = %s, want %s", p.RedirectRewrite.Mode, tt.mode) + } + if len(p.RedirectRewrite.Rules) != tt.rules { + t.Errorf("Rules count = %d, want %d", len(p.RedirectRewrite.Rules), tt.rules) + } + }) + } +} + +func TestConvertErrorPage(t *testing.T) { + input := ` +http { + server { + listen 80; + error_page 404 /404.html; + error_page 500 502 503 /50x.html; + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + s := result.Config.Servers[0] + pages := s.Security.ErrorPage.Pages + + if pages[404] != "/404.html" { + t.Errorf("error_page[404] = %s, want /404.html", pages[404]) + } + if pages[500] != "/50x.html" { + t.Errorf("error_page[500] = %s, want /50x.html", pages[500]) + } + if pages[502] != "/50x.html" { + t.Errorf("error_page[502] = %s, want /50x.html", pages[502]) + } + if pages[503] != "/50x.html" { + t.Errorf("error_page[503] = %s, want /50x.html", pages[503]) + } +} + +func TestConvertAuthBasic(t *testing.T) { + input := ` +http { + server { + listen 80; + auth_basic "Restricted Area"; + auth_basic_user_file /etc/nginx/.htpasswd; + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + s := result.Config.Servers[0] + if s.Security.Auth.Type != "basic" { + t.Errorf("Auth.Type = %s, want basic", s.Security.Auth.Type) + } + if s.Security.Auth.Realm != "Restricted Area" { + t.Errorf("Auth.Realm = %s, want Restricted Area", s.Security.Auth.Realm) + } + + if !hasWarningContaining(result.Warnings, "auth_basic_user_file") { + t.Error("expected warning about auth_basic_user_file migration") + } +} + +func TestConvertTopLevelServer(t *testing.T) { + // Server block without http wrapper. + input := ` +server { + listen 8080; + server_name direct.example.com; +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + if len(result.Config.Servers) != 1 { + t.Fatalf("expected 1 server, got %d", len(result.Config.Servers)) + } + s := result.Config.Servers[0] + if s.Listen != ":8080" { + t.Errorf("Listen = %s, want :8080", s.Listen) + } + if s.Name != "direct.example.com" { + t.Errorf("Name = %s, want direct.example.com", s.Name) + } +} + +func TestConvertLoggingConfig(t *testing.T) { + input := ` +http { + server { + listen 80; + access_log /var/log/nginx/access.log combined; + error_log /var/log/nginx/error.log warn; + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + if result.Config.Logging.Access.Path != "/var/log/nginx/access.log" { + t.Errorf("Access.Path = %s, want /var/log/nginx/access.log", result.Config.Logging.Access.Path) + } + if result.Config.Logging.Access.Format != "combined" { + t.Errorf("Access.Format = %s, want combined", result.Config.Logging.Access.Format) + } + if result.Config.Logging.Error.Path != "/var/log/nginx/error.log" { + t.Errorf("Error.Path = %s, want /var/log/nginx/error.log", result.Config.Logging.Error.Path) + } + if result.Config.Logging.Error.Level != "warn" { + t.Errorf("Error.Level = %s, want warn", result.Config.Logging.Error.Level) + } +} + +func TestConvertServerTokens(t *testing.T) { + input := ` +http { + server { + listen 80; + server_tokens off; + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + s := result.Config.Servers[0] + if s.ServerTokens { + t.Error("ServerTokens = true, want false") + } +} + +func TestConvertUpstreamLoadBalance(t *testing.T) { + tests := []struct { + name string + upstream string + wantLB string + }{ + { + name: "least_conn", + upstream: "least_conn;", + wantLB: "least_conn", + }, + { + name: "ip_hash", + upstream: "ip_hash;", + wantLB: "ip_hash", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + input := fmt.Sprintf(` +upstream backend { + server 10.0.0.1:8080; + %s +} + +http { + server { + listen 80; + location / { + proxy_pass http://backend; + } + } +} +`, tt.upstream) + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + p := result.Config.Servers[0].Proxy[0] + if p.LoadBalance != tt.wantLB { + t.Errorf("LoadBalance = %s, want %s", p.LoadBalance, tt.wantLB) + } + }) + } +} + +func TestConvertTryFiles(t *testing.T) { + input := ` +http { + server { + listen 80; + location / { + root /var/www; + try_files $uri $uri/ /index.html; + } + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + s := result.Config.Servers[0] + st := s.Static[0] + if len(st.TryFiles) != 3 { + t.Fatalf("TryFiles length = %d, want 3", len(st.TryFiles)) + } + if st.TryFiles[0] != "$uri" || st.TryFiles[1] != "$uri/" || st.TryFiles[2] != "/index.html" { + t.Errorf("TryFiles = %v, want [$uri $uri/ /index.html]", st.TryFiles) + } +} + +func TestConvertProxyHidePassHeaders(t *testing.T) { + input := ` +http { + server { + listen 80; + location / { + proxy_pass http://backend; + proxy_hide_header X-Powered-By; + proxy_pass_header X-My-Custom; + } + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + p := result.Config.Servers[0].Proxy[0] + if len(p.Headers.HideResponse) != 1 || p.Headers.HideResponse[0] != "X-Powered-By" { + t.Errorf("HideResponse = %v, want [X-Powered-By]", p.Headers.HideResponse) + } + if len(p.Headers.PassResponse) != 1 || p.Headers.PassResponse[0] != "X-My-Custom" { + t.Errorf("PassResponse = %v, want [X-My-Custom]", p.Headers.PassResponse) + } +} + +func TestConvertProxyCache(t *testing.T) { + input := ` +http { + server { + listen 80; + location / { + proxy_pass http://backend; + proxy_cache my_cache; + proxy_cache_valid 200 10m; + proxy_cache_valid 404 1m; + } + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + p := result.Config.Servers[0].Proxy[0] + if !p.Cache.Enabled { + t.Error("Cache.Enabled = false, want true") + } + if p.CacheValid == nil { + t.Fatal("CacheValid is nil") + } + if p.CacheValid.OK != 10*time.Minute { + t.Errorf("CacheValid.OK = %v, want 10m", p.CacheValid.OK) + } + if p.CacheValid.NotFound != time.Minute { + t.Errorf("CacheValid.NotFound = %v, want 1m", p.CacheValid.NotFound) + } +} + +func TestConvertNamedLocation(t *testing.T) { + input := ` +http { + server { + listen 80; + location @fallback { + proxy_pass http://fallback-backend; + } + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + p := result.Config.Servers[0].Proxy[0] + if p.LocationType != "named" { + t.Errorf("LocationType = %s, want named", p.LocationType) + } + if p.LocationName != "fallback" { + t.Errorf("LocationName = %s, want fallback", p.LocationName) + } +} + +func TestConvertWarningString(t *testing.T) { + w := Warning{ + Directive: "if", + Line: 10, + File: "test.conf", + Message: "unsupported", + } + s := w.String() + if !strings.Contains(s, "test.conf:10") { + t.Errorf("Warning.String() = %s, should contain test.conf:10", s) + } + if !strings.Contains(s, "unsupported") { + t.Errorf("Warning.String() = %s, should contain 'unsupported'", s) + } +} + +func TestConvertEmptyConfig(t *testing.T) { + input := `` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + if len(result.Config.Servers) != 0 { + t.Errorf("expected 0 servers, got %d", len(result.Config.Servers)) + } +} + +func TestConvertServerTokensOn(t *testing.T) { + input := ` +http { + server { + listen 80; + server_tokens on; + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + s := result.Config.Servers[0] + if !s.ServerTokens { + t.Error("ServerTokens = false, want true") + } +} + +func TestConvertMapDirectiveWarning(t *testing.T) { + input := ` +http { + map $host $backend { + default backend1; + } + server { + listen 80; + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + if !hasWarningContaining(result.Warnings, "'map' directive") { + t.Error("expected warning about unsupported 'map' directive") + } +} + +func TestConvertUpstreamDownServer(t *testing.T) { + input := ` +upstream backend { + server 10.0.0.1:8080; + server 10.0.0.2:8080 down; +} + +http { + server { + listen 80; + location / { + proxy_pass http://backend; + } + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + p := result.Config.Servers[0].Proxy[0] + if len(p.Targets) != 2 { + t.Fatalf("expected 2 targets, got %d", len(p.Targets)) + } + if !p.Targets[1].Down { + t.Error("target[1].Down = false, want true") + } +} + +func TestConvertUpstreamInsideHttpBlock(t *testing.T) { + // Upstream blocks inside http block should be found. + input := ` +http { + upstream backend { + server 10.0.0.1:8080; + server 10.0.0.2:8080; + } + server { + listen 80; + location / { + proxy_pass http://backend; + } + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + s := result.Config.Servers[0] + if len(s.Proxy) != 1 { + t.Fatalf("expected 1 proxy, got %d", len(s.Proxy)) + } + + p := s.Proxy[0] + if len(p.Targets) != 2 { + t.Fatalf("expected 2 targets from upstream inside http block, got %d", len(p.Targets)) + } + if p.Targets[0].URL != "10.0.0.1:8080" { + t.Errorf("target[0] URL = %s, want 10.0.0.1:8080", p.Targets[0].URL) + } + if p.Targets[1].URL != "10.0.0.2:8080" { + t.Errorf("target[1] URL = %s, want 10.0.0.2:8080", p.Targets[1].URL) + } +} + +func TestConvertSSLWithoutCertKey(t *testing.T) { + // SSL enabled via listen directive but no ssl_certificate/key should produce a warning. + input := ` +http { + server { + listen 443 ssl; + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + s := result.Config.Servers[0] + if s.SSL.Cert != "" { + t.Errorf("SSL.Cert = %s, want empty string (no auto)", s.SSL.Cert) + } + if s.SSL.Key != "" { + t.Errorf("SSL.Key = %s, want empty string (no auto)", s.SSL.Key) + } + if !hasWarningContaining(result.Warnings, "ssl_certificate") { + t.Error("expected warning about missing ssl_certificate/ssl_certificate_key") + } +} + +func TestConvertSSLWithCertKeyNoWarning(t *testing.T) { + // SSL enabled with both cert and key should NOT produce a warning. + input := ` +http { + server { + listen 443 ssl; + ssl_certificate /etc/ssl/server.crt; + ssl_certificate_key /etc/ssl/server.key; + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + s := result.Config.Servers[0] + if s.SSL.Cert != "/etc/ssl/server.crt" { + t.Errorf("SSL.Cert = %s, want /etc/ssl/server.crt", s.SSL.Cert) + } + if s.SSL.Key != "/etc/ssl/server.key" { + t.Errorf("SSL.Key = %s, want /etc/ssl/server.key", s.SSL.Key) + } + if hasWarningContaining(result.Warnings, "ssl_certificate") { + t.Error("unexpected warning about missing ssl_certificate when both are provided") + } +} + +func TestConvertServerNoListenDefault(t *testing.T) { + // Server block without listen directive should get default 0.0.0.0:80. + input := ` +http { + server { + server_name example.com; + } +} +` + result, err := convertString(t, input) + if err != nil { + t.Fatalf("convert error: %v", err) + } + + s := result.Config.Servers[0] + if s.Listen != "0.0.0.0:80" { + t.Errorf("Listen = %s, want 0.0.0.0:80 when no listen directive", s.Listen) + } +} + diff --git a/internal/converter/nginx/parser.go b/internal/converter/nginx/parser.go new file mode 100644 index 0000000..3b565cc --- /dev/null +++ b/internal/converter/nginx/parser.go @@ -0,0 +1,344 @@ +// Package nginx provides a recursive descent parser for nginx configuration files. +package nginx + +import ( + "fmt" + "os" + "path/filepath" +) + +// Directive represents a single nginx directive. +type Directive struct { + Name string // directive name (e.g., "server", "listen", "proxy_pass") + Args []string // directive arguments + Block []Directive // child directives for block directives (e.g., server { ... }) + Line int // line number in source file + File string // source file path (for include tracking) +} + +// NginxConfig represents a parsed nginx configuration. +type NginxConfig struct { + Directives []Directive +} + +// ParseError represents a parse error with file and line information. +type ParseError struct { + File string + Line int + Message string +} + +func (e *ParseError) Error() string { + return fmt.Sprintf("%s:%d: %s", e.File, e.Line, e.Message) +} + +type parser struct { + input []byte + pos int + line int + file string + includeStack map[string]bool + depth int + extraDirectives []Directive // directives injected by include expansion +} + +const maxDepth = 10 + +// Parse parses nginx configuration from a string. +func Parse(input string) (*NginxConfig, error) { + p := &parser{ + input: []byte(input), + pos: 0, + line: 1, + file: "", + includeStack: make(map[string]bool), + depth: 0, + } + directives, err := p.parseDirectives() + if err != nil { + return nil, err + } + return &NginxConfig{Directives: directives}, nil +} + +// ParseFile parses an nginx configuration file, handling include directives. +func ParseFile(path string) (*NginxConfig, error) { + absPath, err := filepath.Abs(path) + if err != nil { + return nil, &ParseError{File: path, Line: 1, Message: fmt.Sprintf("resolve absolute path: %v", err)} + } + resolved, err := filepath.EvalSymlinks(absPath) + if err != nil { + return nil, &ParseError{File: path, Line: 1, Message: fmt.Sprintf("resolve symlinks: %v", err)} + } + return parseFileWithStack(resolved, map[string]bool{resolved: true}, 0) +} + +func parseFileWithStack(path string, includeStack map[string]bool, depth int) (*NginxConfig, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, &ParseError{File: path, Line: 1, Message: fmt.Sprintf("read file: %v", err)} + } + + newStack := make(map[string]bool, len(includeStack)+1) + for k := range includeStack { + newStack[k] = true + } + newStack[path] = true + + p := &parser{ + input: data, + pos: 0, + line: 1, + file: path, + includeStack: newStack, + depth: depth, + } + directives, err := p.parseDirectives() + if err != nil { + return nil, err + } + return &NginxConfig{Directives: directives}, nil +} + +func (p *parser) errorf(msg string, args ...any) error { + return &ParseError{File: p.file, Line: p.line, Message: fmt.Sprintf(msg, args...)} +} + +func (p *parser) parseDirectives() ([]Directive, error) { + var directives []Directive + for { + p.skipWhitespaceAndComments() + if p.pos >= len(p.input) { + break + } + if p.input[p.pos] == '}' { + break + } + d, err := p.parseDirective() + if err != nil { + return nil, err + } + // handleInclude may produce zero directives (glob no match) or + // multiple directives (include expands to several files). + if d == nil { + continue + } + directives = append(directives, *d) + + // Drain any extra directives injected by include expansion. + for _, extra := range p.extraDirectives { + directives = append(directives, extra) + } + p.extraDirectives = nil + } + return directives, nil +} + +func (p *parser) parseDirective() (*Directive, error) { + p.skipWhitespaceAndComments() + + line := p.line + name, err := p.readToken() + if err != nil { + return nil, err + } + if name == "" { + return nil, p.errorf("expected directive name") + } + + d := &Directive{ + Name: name, + Line: line, + File: p.file, + } + + // Read arguments until ; or { + for { + p.skipWhitespaceAndComments() + if p.pos >= len(p.input) { + return nil, p.errorf("unexpected end of input, expected ';' or '{'") + } + + ch := p.input[p.pos] + if ch == ';' { + p.pos++ + break + } + if ch == '{' { + p.pos++ + block, err := p.parseDirectives() + if err != nil { + return nil, err + } + p.skipWhitespaceAndComments() + if p.pos >= len(p.input) || p.input[p.pos] != '}' { + return nil, p.errorf("expected '}'") + } + p.pos++ + d.Block = block + break + } + + arg, err := p.readToken() + if err != nil { + return nil, err + } + if arg == "" { + return nil, p.errorf("unexpected character %q", p.input[p.pos]) + } + d.Args = append(d.Args, arg) + } + + // Handle include directive: replace with expanded content. + if d.Name == "include" && len(d.Args) > 0 { + return p.handleInclude(d.Args[0]) + } + + return d, nil +} + +func (p *parser) handleInclude(pattern string) (*Directive, error) { + if p.depth >= maxDepth { + return nil, p.errorf("include depth exceeds maximum of %d", maxDepth) + } + + var fullPattern string + if filepath.IsAbs(pattern) { + fullPattern = pattern + } else { + baseDir := filepath.Dir(p.file) + fullPattern = filepath.Join(baseDir, pattern) + } + + matches, err := filepath.Glob(fullPattern) + if err != nil { + return nil, p.errorf("invalid include pattern %q: %v", pattern, err) + } + + if len(matches) == 0 { + // If the pattern contains no glob metacharacters, it's a literal + // file path that should exist. Return an error if it doesn't. + if !isGlobPattern(pattern) { + return nil, p.errorf("include file not found: %s", fullPattern) + } + // Glob pattern with no matches — silently skip (matches nginx behavior). + return nil, nil + } + + var allDirectives []Directive + for _, match := range matches { + resolved, err := filepath.EvalSymlinks(match) + if err != nil { + return nil, p.errorf("resolve symlinks for %q: %v", match, err) + } + if p.includeStack[resolved] { + return nil, p.errorf("circular include detected: %s", resolved) + } + + cfg, err := parseFileWithStack(resolved, p.includeStack, p.depth+1) + if err != nil { + return nil, err + } + allDirectives = append(allDirectives, cfg.Directives...) + } + + if len(allDirectives) == 0 { + return nil, nil + } + + // Return the first directive; stash the rest for parseDirectives to drain. + p.extraDirectives = allDirectives[1:] + return &allDirectives[0], nil +} + +func (p *parser) skipWhitespaceAndComments() { + for p.pos < len(p.input) { + ch := p.input[p.pos] + if ch == ' ' || ch == '\t' || ch == '\r' { + p.pos++ + continue + } + if ch == '\n' { + p.pos++ + p.line++ + continue + } + if ch == '#' { + for p.pos < len(p.input) && p.input[p.pos] != '\n' { + p.pos++ + } + continue + } + break + } +} + +func (p *parser) readToken() (string, error) { + if p.pos >= len(p.input) { + return "", nil + } + + ch := p.input[p.pos] + + if ch == '"' || ch == '\'' { + return p.readQuotedString(ch) + } + + if ch == '{' || ch == '}' || ch == ';' { + return "", nil + } + + start := p.pos + for p.pos < len(p.input) { + ch = p.input[p.pos] + if ch == ' ' || ch == '\t' || ch == '\r' || ch == '\n' || + ch == '{' || ch == '}' || ch == ';' || ch == '#' || + ch == '"' || ch == '\'' { + break + } + p.pos++ + } + if p.pos == start { + return "", nil + } + return string(p.input[start:p.pos]), nil +} + +func (p *parser) readQuotedString(quote byte) (string, error) { + p.pos++ // skip opening quote + var buf []byte + for p.pos < len(p.input) { + ch := p.input[p.pos] + if ch == '\\' { + p.pos++ + if p.pos >= len(p.input) { + return "", p.errorf("unterminated escape in quoted string") + } + buf = append(buf, p.input[p.pos]) + p.pos++ + continue + } + if ch == quote { + p.pos++ // skip closing quote + return string(buf), nil + } + if ch == '\n' { + p.line++ + } + buf = append(buf, ch) + p.pos++ + } + return "", p.errorf("unterminated quoted string") +} + +// isGlobPattern returns true if the path contains glob metacharacters. +func isGlobPattern(path string) bool { + for i := 0; i < len(path); i++ { + ch := path[i] + if ch == '*' || ch == '?' || ch == '[' { + return true + } + } + return false +} diff --git a/internal/converter/nginx/parser_test.go b/internal/converter/nginx/parser_test.go new file mode 100644 index 0000000..35fa35d --- /dev/null +++ b/internal/converter/nginx/parser_test.go @@ -0,0 +1,430 @@ +package nginx + +import ( + "fmt" + "os" + "path/filepath" + "testing" +) + +func TestParseSimpleDirective(t *testing.T) { + cfg, err := Parse("listen 80;") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(cfg.Directives) != 1 { + t.Fatalf("expected 1 directive, got %d", len(cfg.Directives)) + } + d := cfg.Directives[0] + if d.Name != "listen" { + t.Errorf("expected name %q, got %q", "listen", d.Name) + } + if len(d.Args) != 1 || d.Args[0] != "80" { + t.Errorf("expected args [80], got %v", d.Args) + } +} + +func TestParseBlockDirective(t *testing.T) { + cfg, err := Parse("server { listen 80; }") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(cfg.Directives) != 1 { + t.Fatalf("expected 1 directive, got %d", len(cfg.Directives)) + } + d := cfg.Directives[0] + if d.Name != "server" { + t.Errorf("expected name %q, got %q", "server", d.Name) + } + if len(d.Block) != 1 { + t.Fatalf("expected 1 child, got %d", len(d.Block)) + } + if d.Block[0].Name != "listen" { + t.Errorf("expected child name %q, got %q", "listen", d.Block[0].Name) + } + if len(d.Block[0].Args) != 1 || d.Block[0].Args[0] != "80" { + t.Errorf("expected child args [80], got %v", d.Block[0].Args) + } +} + +func TestParseComment(t *testing.T) { + cfg, err := Parse("# this is a comment\nlisten 80;") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(cfg.Directives) != 1 { + t.Fatalf("expected 1 directive, got %d", len(cfg.Directives)) + } + if cfg.Directives[0].Name != "listen" { + t.Errorf("expected name %q, got %q", "listen", cfg.Directives[0].Name) + } +} + +func TestParseQuotedString(t *testing.T) { + tests := []struct { + name string + input string + dirName string + expected []string + }{ + { + name: "double quoted", + input: `proxy_set_header Host "example.com";`, + dirName: "proxy_set_header", + expected: []string{"Host", "example.com"}, + }, + { + name: "single quoted", + input: `proxy_set_header Host 'example.com';`, + dirName: "proxy_set_header", + expected: []string{"Host", "example.com"}, + }, + { + name: "escaped quote inside double", + input: `set $x "hello\"world";`, + dirName: "set", + expected: []string{"$x", `hello"world`}, + }, + { + name: "escaped quote inside single", + input: `set $x 'hello\'world';`, + dirName: "set", + expected: []string{"$x", "hello'world"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg, err := Parse(tt.input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(cfg.Directives) != 1 { + t.Fatalf("expected 1 directive, got %d", len(cfg.Directives)) + } + d := cfg.Directives[0] + if d.Name != tt.dirName { + t.Errorf("expected name %q, got %q", tt.dirName, d.Name) + } + if len(d.Args) != len(tt.expected) { + t.Fatalf("expected %d args, got %d", len(tt.expected), len(d.Args)) + } + for i, want := range tt.expected { + if d.Args[i] != want { + t.Errorf("arg[%d]: expected %q, got %q", i, want, d.Args[i]) + } + } + }) + } +} + +func TestParseMultipleDirectives(t *testing.T) { + input := `listen 80; +server_name example.com;` + cfg, err := Parse(input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(cfg.Directives) != 2 { + t.Fatalf("expected 2 directives, got %d", len(cfg.Directives)) + } + if cfg.Directives[0].Name != "listen" { + t.Errorf("directive[0]: expected %q, got %q", "listen", cfg.Directives[0].Name) + } + if cfg.Directives[1].Name != "server_name" { + t.Errorf("directive[1]: expected %q, got %q", "server_name", cfg.Directives[1].Name) + } +} + +func TestParseNestedBlocks(t *testing.T) { + input := `http { server { location / { root /var/www; } } }` + cfg, err := Parse(input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(cfg.Directives) != 1 { + t.Fatalf("expected 1 top-level directive, got %d", len(cfg.Directives)) + } + http := cfg.Directives[0] + if http.Name != "http" || len(http.Block) != 1 { + t.Fatalf("expected http with 1 child") + } + srv := http.Block[0] + if srv.Name != "server" || len(srv.Block) != 1 { + t.Fatalf("expected server with 1 child") + } + loc := srv.Block[0] + if loc.Name != "location" || len(loc.Block) != 1 { + t.Fatalf("expected location with 1 child") + } + if loc.Block[0].Name != "root" { + t.Errorf("expected root, got %q", loc.Block[0].Name) + } +} + +func TestParseUnclosedBlock(t *testing.T) { + _, err := Parse("server { listen 80;") + if err == nil { + t.Fatal("expected error for unclosed block") + } + pe, ok := err.(*ParseError) + if !ok { + t.Fatalf("expected *ParseError, got %T", err) + } + if pe.Line == 0 { + t.Error("expected non-zero line number") + } +} + +func TestParseMissingSemicolon(t *testing.T) { + _, err := Parse("listen 80") + if err == nil { + t.Fatal("expected error for missing semicolon") + } + pe, ok := err.(*ParseError) + if !ok { + t.Fatalf("expected *ParseError, got %T", err) + } + if pe.Line == 0 { + t.Error("expected non-zero line number") + } +} + +func TestParseFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "nginx.conf") + content := "listen 80;\nserver_name example.com;" + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatalf("write file: %v", err) + } + cfg, err := ParseFile(path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(cfg.Directives) != 2 { + t.Fatalf("expected 2 directives, got %d", len(cfg.Directives)) + } + if cfg.Directives[0].Name != "listen" { + t.Errorf("directive[0]: expected %q, got %q", "listen", cfg.Directives[0].Name) + } + if cfg.Directives[1].Name != "server_name" { + t.Errorf("directive[1]: expected %q, got %q", "server_name", cfg.Directives[1].Name) + } +} + +func TestParseIncludeGlob(t *testing.T) { + dir := t.TempDir() + + // Create a subdirectory for included files to avoid matching nginx.conf itself. + incDir := filepath.Join(dir, "includes") + if err := os.Mkdir(incDir, 0755); err != nil { + t.Fatal(err) + } + + if err := os.WriteFile(filepath.Join(incDir, "a.conf"), []byte("listen 80;\n"), 0644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(incDir, "b.conf"), []byte("server_name a.com;\n"), 0644); err != nil { + t.Fatal(err) + } + + // Create main config with include. + main := filepath.Join(dir, "nginx.conf") + content := "include " + incDir + "/*.conf;" + if err := os.WriteFile(main, []byte(content), 0644); err != nil { + t.Fatal(err) + } + + cfg, err := ParseFile(main) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(cfg.Directives) != 2 { + t.Fatalf("expected 2 directives from included files, got %d", len(cfg.Directives)) + } + names := map[string]bool{} + for _, d := range cfg.Directives { + names[d.Name] = true + } + if !names["listen"] || !names["server_name"] { + t.Errorf("expected listen and server_name, got %v", cfg.Directives) + } +} + +func TestParseIncludeCircular(t *testing.T) { + dir := t.TempDir() + + a := filepath.Join(dir, "a.conf") + b := filepath.Join(dir, "b.conf") + + if err := os.WriteFile(a, []byte("include "+b+";"), 0644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(b, []byte("include "+a+";"), 0644); err != nil { + t.Fatal(err) + } + + _, err := ParseFile(a) + if err == nil { + t.Fatal("expected error for circular include") + } + pe, ok := err.(*ParseError) + if !ok { + t.Fatalf("expected *ParseError, got %T", err) + } + if pe.Message == "" { + t.Error("expected non-empty error message") + } +} + +func TestParseIncludeMaxDepth(t *testing.T) { + dir := t.TempDir() + + // Create a chain of includes: 0.conf includes 1.conf, 1.conf includes 2.conf, etc. + for i := 0; i <= maxDepth+1; i++ { + path := filepath.Join(dir, fmt.Sprintf("%d.conf", i)) + var content string + if i < maxDepth+1 { + next := filepath.Join(dir, fmt.Sprintf("%d.conf", i+1)) + content = "include " + next + ";" + } else { + content = "listen 80;" + } + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatal(err) + } + } + + _, err := ParseFile(filepath.Join(dir, "0.conf")) + if err == nil { + t.Fatal("expected error for max include depth") + } + if _, ok := err.(*ParseError); !ok { + t.Fatalf("expected *ParseError, got %T", err) + } +} + +func TestParseIncludeNotFound(t *testing.T) { + dir := t.TempDir() + main := filepath.Join(dir, "nginx.conf") + content := "include /nonexistent/path.conf;" + if err := os.WriteFile(main, []byte(content), 0644); err != nil { + t.Fatal(err) + } + _, err := ParseFile(main) + if err == nil { + t.Fatal("expected error for include of nonexistent file") + } +} + +func TestParseIncludeGlobNoMatch(t *testing.T) { + dir := t.TempDir() + main := filepath.Join(dir, "nginx.conf") + content := "include " + dir + "/nonexistent/*.conf;" + if err := os.WriteFile(main, []byte(content), 0644); err != nil { + t.Fatal(err) + } + cfg, err := ParseFile(main) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(cfg.Directives) != 0 { + t.Fatalf("expected 0 directives for glob with no matches, got %d", len(cfg.Directives)) + } +} + +func TestParseLocationModifiers(t *testing.T) { + tests := []struct { + name string + input string + dirName string + args []string + }{ + { + name: "exact match", + input: "location = /path {}", + dirName: "location", + args: []string{"=", "/path"}, + }, + { + name: "regex", + input: `location ~ \.php$ {}`, + dirName: "location", + args: []string{"~", `\.php$`}, + }, + { + name: "case insensitive regex", + input: `location ~* \.jpg$ {}`, + dirName: "location", + args: []string{"~*", `\.jpg$`}, + }, + { + name: "prefix with continuation", + input: "location ^~ /images/ {}", + dirName: "location", + args: []string{"^~", "/images/"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg, err := Parse(tt.input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(cfg.Directives) != 1 { + t.Fatalf("expected 1 directive, got %d", len(cfg.Directives)) + } + d := cfg.Directives[0] + if d.Name != tt.dirName { + t.Errorf("expected name %q, got %q", tt.dirName, d.Name) + } + if len(d.Args) != len(tt.args) { + t.Fatalf("expected %d args, got %d: %v", len(tt.args), len(d.Args), d.Args) + } + for i, want := range tt.args { + if d.Args[i] != want { + t.Errorf("arg[%d]: expected %q, got %q", i, want, d.Args[i]) + } + } + }) + } +} + +func TestParseEmptyBlock(t *testing.T) { + cfg, err := Parse("server {}") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(cfg.Directives) != 1 { + t.Fatalf("expected 1 directive, got %d", len(cfg.Directives)) + } + d := cfg.Directives[0] + if d.Name != "server" { + t.Errorf("expected name %q, got %q", "server", d.Name) + } + if len(d.Block) != 0 { + t.Errorf("expected empty block, got %d children", len(d.Block)) + } +} + +func TestParseMultipleArgs(t *testing.T) { + cfg, err := Parse("return 301 https://example.com;") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(cfg.Directives) != 1 { + t.Fatalf("expected 1 directive, got %d", len(cfg.Directives)) + } + d := cfg.Directives[0] + if d.Name != "return" { + t.Errorf("expected name %q, got %q", "return", d.Name) + } + expected := []string{"301", "https://example.com"} + if len(d.Args) != len(expected) { + t.Fatalf("expected %d args, got %d", len(expected), len(d.Args)) + } + for i, want := range expected { + if d.Args[i] != want { + t.Errorf("arg[%d]: expected %q, got %q", i, want, d.Args[i]) + } + } +} diff --git a/main.go b/main.go index e8a95ab..989ce89 100644 --- a/main.go +++ b/main.go @@ -3,6 +3,7 @@ // 该文件包含命令行参数解析和应用程序启动逻辑: // - 配置文件路径指定(-c/--config) // - 默认配置生成(--generate-config) +// - nginx 配置导入(--import/-i) // - 版本信息显示(-v) // // 使用示例: @@ -26,7 +27,9 @@ func main() { cfgPathLong := flag.String("config", "", "配置文件路径(长参数)") genConfig := flag.Bool("generate-config", false, "生成默认配置") genConfigShort := flag.Bool("g", false, "生成默认配置(短参数)") - outputPath := flag.String("o", "", "输出文件路径(配合 --generate-config)") + outputPath := flag.String("o", "", "输出文件路径(配合 --generate-config 或 --import)") + importPath := flag.String("import", "", "导入 nginx 配置文件") + importPathShort := flag.String("i", "", "导入 nginx 配置文件(短参数)") showVersion := flag.Bool("v", false, "显示版本") flag.Parse() @@ -37,6 +40,10 @@ func main() { configPath = *cfgPathLong } generate := *genConfig || *genConfigShort + nginxImport := *importPath + if *importPathShort != "" { + nginxImport = *importPathShort + } - os.Exit(app.Run(configPath, generate, *outputPath, *showVersion)) + os.Exit(app.Run(configPath, generate, *outputPath, nginxImport, *showVersion)) }