- headers.go: 添加协议常量 protoHTTP/protoHTTPS - redirect_rewrite.go: 添加模式常量,修正缩进 - proxy_ssl_test.go: 表格测试字段对齐 - server.go: 添加 ServerModeAuto 分支防御性处理 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
289 lines
8.8 KiB
Go
289 lines
8.8 KiB
Go
// Package proxy 反向代理包,为 Lolly HTTP 服务器提供反向代理功能。
|
||
package proxy
|
||
|
||
import (
|
||
"regexp"
|
||
"strings"
|
||
|
||
"github.com/valyala/fasthttp"
|
||
"rua.plus/lolly/internal/config"
|
||
"rua.plus/lolly/internal/variable"
|
||
)
|
||
|
||
// RedirectRewrite 模式常量
|
||
const (
|
||
redirectModeDefault = "default"
|
||
redirectModeOff = "off"
|
||
redirectModeCustom = "custom"
|
||
)
|
||
|
||
// compiledRule 预编译的改写规则
|
||
type compiledRule struct {
|
||
pattern *regexp.Regexp // 正则模式,nil 表示非正则匹配
|
||
replacement string // 替换模板(含变量)
|
||
exactMatch string // 精确匹配前缀(用于 prefix 匹配)
|
||
caseInsensitive bool // 正则大小写不敏感(~* 前缀)
|
||
}
|
||
|
||
// RedirectRewriter Location/Refresh 头改写器
|
||
type RedirectRewriter struct {
|
||
proxyPath string // 用于 default 模式(当前代理路径)
|
||
mode string // "default" | "off" | "custom"(空字符串视为 default)
|
||
rules []compiledRule // 仅 custom 模式预编译
|
||
}
|
||
|
||
// NewRedirectRewriter 创建改写器
|
||
// proxyPath: 当前代理路径(如 "/api/")
|
||
// 注意:mode 为空字符串时默认为 "default"
|
||
func NewRedirectRewriter(cfg *config.RedirectRewriteConfig, proxyPath string) (*RedirectRewriter, error) {
|
||
if cfg == nil {
|
||
// 未配置时默认启用 default 模式
|
||
return &RedirectRewriter{
|
||
mode: redirectModeDefault,
|
||
proxyPath: proxyPath,
|
||
}, nil
|
||
}
|
||
|
||
rw := &RedirectRewriter{
|
||
mode: cfg.Mode,
|
||
proxyPath: proxyPath,
|
||
}
|
||
|
||
// custom 模式:预编译规则
|
||
if cfg.Mode == redirectModeCustom {
|
||
rules := make([]compiledRule, 0, len(cfg.Rules))
|
||
for _, rule := range cfg.Rules {
|
||
cr := compiledRule{
|
||
replacement: rule.Replacement,
|
||
}
|
||
|
||
if strings.HasPrefix(rule.Pattern, "~") {
|
||
// 正则模式
|
||
var patternStr string
|
||
if strings.HasPrefix(rule.Pattern, "~*") {
|
||
cr.caseInsensitive = true
|
||
patternStr = rule.Pattern[2:]
|
||
} else {
|
||
patternStr = rule.Pattern[1:]
|
||
}
|
||
re, err := regexp.Compile(patternStr)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
cr.pattern = re
|
||
} else {
|
||
// 非正则:使用前缀匹配
|
||
cr.exactMatch = rule.Pattern
|
||
}
|
||
|
||
rules = append(rules, cr)
|
||
}
|
||
rw.rules = rules
|
||
}
|
||
|
||
return rw, nil
|
||
}
|
||
|
||
// Mode 返回当前模式(处理空字符串默认值)
|
||
func (r *RedirectRewriter) Mode() string {
|
||
if r.mode == "" {
|
||
return redirectModeDefault
|
||
}
|
||
return r.mode
|
||
}
|
||
|
||
// RewriteResponse 改写响应中的 Location 和 Refresh 头
|
||
// targetURL: 实际选中的上游地址(用于 default 模式)
|
||
// originalClientHost: 客户端原始 Host(在 modifyRequestHeaders 改写前保存)
|
||
// 调用位置:必须在 modifyResponseHeaders 之前
|
||
// 内部逻辑:
|
||
// - 检查 resp.StatusCode(),仅 3xx 状态码处理 Location 头
|
||
// - 所有状态码都处理 Refresh 头
|
||
func (r *RedirectRewriter) RewriteResponse(resp *fasthttp.Response, ctx *fasthttp.RequestCtx, targetURL string, originalClientHost string) {
|
||
statusCode := resp.StatusCode()
|
||
|
||
// 仅 3xx 状态码处理 Location 头
|
||
if statusCode >= 300 && statusCode < 400 {
|
||
location := resp.Header.Peek("Location")
|
||
if len(location) > 0 {
|
||
rewritten := r.rewriteURL(string(location), ctx, targetURL, originalClientHost)
|
||
if rewritten != "" {
|
||
resp.Header.Set("Location", rewritten)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 所有状态码都处理 Refresh 头
|
||
refresh := resp.Header.Peek("Refresh")
|
||
if len(refresh) > 0 {
|
||
rewritten := r.rewriteRefresh(string(refresh), ctx, targetURL, originalClientHost)
|
||
if rewritten != "" {
|
||
resp.Header.Set("Refresh", rewritten)
|
||
}
|
||
}
|
||
}
|
||
|
||
// RewriteRefreshOnly 仅改写 Refresh 头(用于缓存响应路径)
|
||
// Location 头在缓存响应中不存在(缓存仅存储 2xx),故跳过
|
||
func (r *RedirectRewriter) RewriteRefreshOnly(resp *fasthttp.Response, ctx *fasthttp.RequestCtx, targetURL string, originalClientHost string) {
|
||
refresh := resp.Header.Peek("Refresh")
|
||
if len(refresh) > 0 {
|
||
rewritten := r.rewriteRefresh(string(refresh), ctx, targetURL, originalClientHost)
|
||
if rewritten != "" {
|
||
resp.Header.Set("Refresh", rewritten)
|
||
}
|
||
}
|
||
}
|
||
|
||
// rewriteURL 改写单个 URL 值(Location 或 Refresh 中的 URL 部分)
|
||
// originalClientHost: 客户端原始 Host(用于 default 模式构建 replacement)
|
||
func (r *RedirectRewriter) rewriteURL(headerValue string, ctx *fasthttp.RequestCtx, targetURL string, originalClientHost string) string {
|
||
if headerValue == "" {
|
||
return ""
|
||
}
|
||
|
||
switch r.Mode() {
|
||
case redirectModeOff:
|
||
return headerValue
|
||
|
||
case redirectModeCustom:
|
||
return r.rewriteCustom(headerValue, ctx)
|
||
|
||
case redirectModeDefault, "":
|
||
return r.rewriteDefault(headerValue, ctx, targetURL, originalClientHost)
|
||
|
||
default:
|
||
return headerValue
|
||
}
|
||
}
|
||
|
||
// rewriteDefault 动态生成 default 规则(运行时)
|
||
// 使用前缀匹配:如果 headerValue 以 targetURL 开头,替换为 replacement + 原路径后缀
|
||
// replacement 使用 originalClientHost 构建:"$scheme://originalClientHost/"
|
||
// 例如:targetURL="http://backend:8000", headerValue="http://backend:8000/api/v2/users"
|
||
//
|
||
// → 替换为 "$scheme://originalClientHost/api/v2/users"
|
||
func (r *RedirectRewriter) rewriteDefault(headerValue string, ctx *fasthttp.RequestCtx, targetURL string, originalClientHost string) string {
|
||
if targetURL == "" {
|
||
return headerValue
|
||
}
|
||
|
||
// 精确前缀匹配:headerValue 以 targetURL 开头,且后面是 / ? # 或结束
|
||
// 防止 "https://www.google.com" 匹配 "https://www.google.com.hk"(后者后面是 .hk)
|
||
if strings.HasPrefix(headerValue, targetURL) {
|
||
remaining := headerValue[len(targetURL):]
|
||
// 检查剩余部分是否以合法分隔符开头
|
||
if len(remaining) == 0 || remaining[0] == '/' || remaining[0] == '?' || remaining[0] == '#' {
|
||
// 使用客户端原始 host 构建 replacement
|
||
scheme := protoHTTP
|
||
if ctx.IsTLS() {
|
||
scheme = protoHTTPS
|
||
}
|
||
replacement := scheme + "://" + originalClientHost
|
||
return replacement + remaining
|
||
}
|
||
}
|
||
|
||
return headerValue
|
||
}
|
||
|
||
// rewriteCustom 使用预编译的 custom 规则改写 URL
|
||
// 规则按顺序匹配,第一个成功的生效
|
||
func (r *RedirectRewriter) rewriteCustom(headerValue string, ctx *fasthttp.RequestCtx) string {
|
||
vc := variable.NewContext(ctx)
|
||
defer variable.ReleaseContext(vc)
|
||
|
||
for _, rule := range r.rules {
|
||
if rule.pattern != nil {
|
||
// 正则匹配
|
||
if rule.caseInsensitive {
|
||
// 大小写不敏感:先将 headerValue 转为小写匹配,但替换时保留原始值
|
||
lowerValue := strings.ToLower(headerValue)
|
||
loc := rule.pattern.FindStringIndex(lowerValue)
|
||
if loc != nil {
|
||
expanded := vc.Expand(rule.replacement)
|
||
result := headerValue[:loc[0]] + expanded + headerValue[loc[1]:]
|
||
return result
|
||
}
|
||
} else {
|
||
loc := rule.pattern.FindStringIndex(headerValue)
|
||
if loc != nil {
|
||
expanded := vc.Expand(rule.replacement)
|
||
result := headerValue[:loc[0]] + expanded + headerValue[loc[1]:]
|
||
return result
|
||
}
|
||
}
|
||
} else if rule.exactMatch != "" {
|
||
// 前缀匹配
|
||
if strings.HasPrefix(headerValue, rule.exactMatch) {
|
||
expanded := vc.Expand(rule.replacement)
|
||
suffix := strings.TrimPrefix(headerValue, rule.exactMatch)
|
||
return expanded + suffix
|
||
}
|
||
}
|
||
}
|
||
|
||
return headerValue
|
||
}
|
||
|
||
// rewriteRefresh 改写 Refresh 头
|
||
// 格式:`N; url=URL` 或 `N;url=URL`(无空格)或纯数字 `N`
|
||
func (r *RedirectRewriter) rewriteRefresh(value string, ctx *fasthttp.RequestCtx, targetURL string, originalClientHost string) string {
|
||
delay, url, valid := parseRefreshHeader(value)
|
||
if !valid || url == "" {
|
||
// 无法解析或无 URL,原样返回
|
||
return value
|
||
}
|
||
|
||
rewrittenURL := r.rewriteURL(url, ctx, targetURL, originalClientHost)
|
||
if rewrittenURL == url {
|
||
// URL 未变化,原样返回
|
||
return value
|
||
}
|
||
|
||
return delay + "; url=" + rewrittenURL
|
||
}
|
||
|
||
// parseRefreshHeader 解析 Refresh 头格式
|
||
// 格式:`N; url=URL` 或 `N;url=URL`(无空格)或纯数字 `N`
|
||
// 返回:delay(N), url(URL), 是否有效
|
||
// 边缘处理:忽略引号、忽略多余参数
|
||
func parseRefreshHeader(value string) (delay string, url string, valid bool) {
|
||
value = strings.TrimSpace(value)
|
||
if value == "" {
|
||
return "", "", false
|
||
}
|
||
|
||
// 查找 url= 部分
|
||
urlIdx := strings.Index(strings.ToLower(value), "url=")
|
||
if urlIdx == -1 {
|
||
// 纯数字格式,有效但无 URL
|
||
return value, "", true
|
||
}
|
||
|
||
// 提取 delay 部分(url= 之前的部分)
|
||
delay = strings.TrimSpace(value[:urlIdx])
|
||
// 去除 delay 末尾的分号
|
||
delay = strings.TrimSuffix(delay, ";")
|
||
delay = strings.TrimSpace(delay)
|
||
if delay == "" {
|
||
return "", "", false
|
||
}
|
||
|
||
// 提取 URL 部分(url= 之后)
|
||
url = strings.TrimSpace(value[urlIdx+4:])
|
||
if url == "" {
|
||
return delay, "", true
|
||
}
|
||
|
||
// 去除 URL 两端的引号(如果有)
|
||
if len(url) >= 2 {
|
||
if (url[0] == '"' && url[len(url)-1] == '"') ||
|
||
(url[0] == '\'' && url[len(url)-1] == '\'') {
|
||
url = url[1 : len(url)-1]
|
||
}
|
||
}
|
||
|
||
return delay, url, true
|
||
}
|