lolly/internal/proxy/redirect_rewrite.go
xfy 470c82d940 style(proxy,server): 代码风格优化
- headers.go: 添加协议常量 protoHTTP/protoHTTPS
- redirect_rewrite.go: 添加模式常量,修正缩进
- proxy_ssl_test.go: 表格测试字段对齐
- server.go: 添加 ServerModeAuto 分支防御性处理

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-16 09:54:09 +08:00

289 lines
8.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package proxy 反向代理包,为 Lolly HTTP 服务器提供反向代理功能。
package proxy
import (
"regexp"
"strings"
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/variable"
)
// RedirectRewrite 模式常量
const (
redirectModeDefault = "default"
redirectModeOff = "off"
redirectModeCustom = "custom"
)
// compiledRule 预编译的改写规则
type compiledRule struct {
pattern *regexp.Regexp // 正则模式nil 表示非正则匹配
replacement string // 替换模板(含变量)
exactMatch string // 精确匹配前缀(用于 prefix 匹配)
caseInsensitive bool // 正则大小写不敏感(~* 前缀)
}
// RedirectRewriter Location/Refresh 头改写器
type RedirectRewriter struct {
proxyPath string // 用于 default 模式(当前代理路径)
mode string // "default" | "off" | "custom"(空字符串视为 default
rules []compiledRule // 仅 custom 模式预编译
}
// NewRedirectRewriter 创建改写器
// proxyPath: 当前代理路径(如 "/api/"
// 注意mode 为空字符串时默认为 "default"
func NewRedirectRewriter(cfg *config.RedirectRewriteConfig, proxyPath string) (*RedirectRewriter, error) {
if cfg == nil {
// 未配置时默认启用 default 模式
return &RedirectRewriter{
mode: redirectModeDefault,
proxyPath: proxyPath,
}, nil
}
rw := &RedirectRewriter{
mode: cfg.Mode,
proxyPath: proxyPath,
}
// custom 模式:预编译规则
if cfg.Mode == redirectModeCustom {
rules := make([]compiledRule, 0, len(cfg.Rules))
for _, rule := range cfg.Rules {
cr := compiledRule{
replacement: rule.Replacement,
}
if strings.HasPrefix(rule.Pattern, "~") {
// 正则模式
var patternStr string
if strings.HasPrefix(rule.Pattern, "~*") {
cr.caseInsensitive = true
patternStr = rule.Pattern[2:]
} else {
patternStr = rule.Pattern[1:]
}
re, err := regexp.Compile(patternStr)
if err != nil {
return nil, err
}
cr.pattern = re
} else {
// 非正则:使用前缀匹配
cr.exactMatch = rule.Pattern
}
rules = append(rules, cr)
}
rw.rules = rules
}
return rw, nil
}
// Mode 返回当前模式(处理空字符串默认值)
func (r *RedirectRewriter) Mode() string {
if r.mode == "" {
return redirectModeDefault
}
return r.mode
}
// RewriteResponse 改写响应中的 Location 和 Refresh 头
// targetURL: 实际选中的上游地址(用于 default 模式)
// originalClientHost: 客户端原始 Host在 modifyRequestHeaders 改写前保存)
// 调用位置:必须在 modifyResponseHeaders 之前
// 内部逻辑:
// - 检查 resp.StatusCode(),仅 3xx 状态码处理 Location 头
// - 所有状态码都处理 Refresh 头
func (r *RedirectRewriter) RewriteResponse(resp *fasthttp.Response, ctx *fasthttp.RequestCtx, targetURL string, originalClientHost string) {
statusCode := resp.StatusCode()
// 仅 3xx 状态码处理 Location 头
if statusCode >= 300 && statusCode < 400 {
location := resp.Header.Peek("Location")
if len(location) > 0 {
rewritten := r.rewriteURL(string(location), ctx, targetURL, originalClientHost)
if rewritten != "" {
resp.Header.Set("Location", rewritten)
}
}
}
// 所有状态码都处理 Refresh 头
refresh := resp.Header.Peek("Refresh")
if len(refresh) > 0 {
rewritten := r.rewriteRefresh(string(refresh), ctx, targetURL, originalClientHost)
if rewritten != "" {
resp.Header.Set("Refresh", rewritten)
}
}
}
// RewriteRefreshOnly 仅改写 Refresh 头(用于缓存响应路径)
// Location 头在缓存响应中不存在(缓存仅存储 2xx故跳过
func (r *RedirectRewriter) RewriteRefreshOnly(resp *fasthttp.Response, ctx *fasthttp.RequestCtx, targetURL string, originalClientHost string) {
refresh := resp.Header.Peek("Refresh")
if len(refresh) > 0 {
rewritten := r.rewriteRefresh(string(refresh), ctx, targetURL, originalClientHost)
if rewritten != "" {
resp.Header.Set("Refresh", rewritten)
}
}
}
// rewriteURL 改写单个 URL 值Location 或 Refresh 中的 URL 部分)
// originalClientHost: 客户端原始 Host用于 default 模式构建 replacement
func (r *RedirectRewriter) rewriteURL(headerValue string, ctx *fasthttp.RequestCtx, targetURL string, originalClientHost string) string {
if headerValue == "" {
return ""
}
switch r.Mode() {
case redirectModeOff:
return headerValue
case redirectModeCustom:
return r.rewriteCustom(headerValue, ctx)
case redirectModeDefault, "":
return r.rewriteDefault(headerValue, ctx, targetURL, originalClientHost)
default:
return headerValue
}
}
// rewriteDefault 动态生成 default 规则(运行时)
// 使用前缀匹配:如果 headerValue 以 targetURL 开头,替换为 replacement + 原路径后缀
// replacement 使用 originalClientHost 构建:"$scheme://originalClientHost/"
// 例如targetURL="http://backend:8000", headerValue="http://backend:8000/api/v2/users"
//
// → 替换为 "$scheme://originalClientHost/api/v2/users"
func (r *RedirectRewriter) rewriteDefault(headerValue string, ctx *fasthttp.RequestCtx, targetURL string, originalClientHost string) string {
if targetURL == "" {
return headerValue
}
// 精确前缀匹配headerValue 以 targetURL 开头,且后面是 / ? # 或结束
// 防止 "https://www.google.com" 匹配 "https://www.google.com.hk"(后者后面是 .hk
if strings.HasPrefix(headerValue, targetURL) {
remaining := headerValue[len(targetURL):]
// 检查剩余部分是否以合法分隔符开头
if len(remaining) == 0 || remaining[0] == '/' || remaining[0] == '?' || remaining[0] == '#' {
// 使用客户端原始 host 构建 replacement
scheme := protoHTTP
if ctx.IsTLS() {
scheme = protoHTTPS
}
replacement := scheme + "://" + originalClientHost
return replacement + remaining
}
}
return headerValue
}
// rewriteCustom 使用预编译的 custom 规则改写 URL
// 规则按顺序匹配,第一个成功的生效
func (r *RedirectRewriter) rewriteCustom(headerValue string, ctx *fasthttp.RequestCtx) string {
vc := variable.NewContext(ctx)
defer variable.ReleaseContext(vc)
for _, rule := range r.rules {
if rule.pattern != nil {
// 正则匹配
if rule.caseInsensitive {
// 大小写不敏感:先将 headerValue 转为小写匹配,但替换时保留原始值
lowerValue := strings.ToLower(headerValue)
loc := rule.pattern.FindStringIndex(lowerValue)
if loc != nil {
expanded := vc.Expand(rule.replacement)
result := headerValue[:loc[0]] + expanded + headerValue[loc[1]:]
return result
}
} else {
loc := rule.pattern.FindStringIndex(headerValue)
if loc != nil {
expanded := vc.Expand(rule.replacement)
result := headerValue[:loc[0]] + expanded + headerValue[loc[1]:]
return result
}
}
} else if rule.exactMatch != "" {
// 前缀匹配
if strings.HasPrefix(headerValue, rule.exactMatch) {
expanded := vc.Expand(rule.replacement)
suffix := strings.TrimPrefix(headerValue, rule.exactMatch)
return expanded + suffix
}
}
}
return headerValue
}
// rewriteRefresh 改写 Refresh 头
// 格式:`N; url=URL` 或 `N;url=URL`(无空格)或纯数字 `N`
func (r *RedirectRewriter) rewriteRefresh(value string, ctx *fasthttp.RequestCtx, targetURL string, originalClientHost string) string {
delay, url, valid := parseRefreshHeader(value)
if !valid || url == "" {
// 无法解析或无 URL原样返回
return value
}
rewrittenURL := r.rewriteURL(url, ctx, targetURL, originalClientHost)
if rewrittenURL == url {
// URL 未变化,原样返回
return value
}
return delay + "; url=" + rewrittenURL
}
// parseRefreshHeader 解析 Refresh 头格式
// 格式:`N; url=URL` 或 `N;url=URL`(无空格)或纯数字 `N`
// 返回delay(N), url(URL), 是否有效
// 边缘处理:忽略引号、忽略多余参数
func parseRefreshHeader(value string) (delay string, url string, valid bool) {
value = strings.TrimSpace(value)
if value == "" {
return "", "", false
}
// 查找 url= 部分
urlIdx := strings.Index(strings.ToLower(value), "url=")
if urlIdx == -1 {
// 纯数字格式,有效但无 URL
return value, "", true
}
// 提取 delay 部分url= 之前的部分)
delay = strings.TrimSpace(value[:urlIdx])
// 去除 delay 末尾的分号
delay = strings.TrimSuffix(delay, ";")
delay = strings.TrimSpace(delay)
if delay == "" {
return "", "", false
}
// 提取 URL 部分url= 之后)
url = strings.TrimSpace(value[urlIdx+4:])
if url == "" {
return delay, "", true
}
// 去除 URL 两端的引号(如果有)
if len(url) >= 2 {
if (url[0] == '"' && url[len(url)-1] == '"') ||
(url[0] == '\'' && url[len(url)-1] == '\'') {
url = url[1 : len(url)-1]
}
}
return delay, url, true
}