feat(proxy): 实现 Location/Refresh 头改写功能
- 新增 RedirectRewriter 改写器,支持三种模式: - default: 动态匹配 targetURL 前缀并替换为客户端原始 host - off: 禁用改写 - custom: 使用预编译规则列表匹配替换 - 实现 RewriteResponse 方法改写 Location(3xx 状态码)和 Refresh 头 - 实现 RewriteRefreshOnly 方法用于缓存响应路径(仅 Refresh) - 支持正则匹配(~ 前缀)和大小写不敏感(~* 前缀) - 支持变量展开($host, $scheme, $server_port 等) - 添加 parseRefreshHeader 解析 Refresh 头格式(N; url=URL) - 在 Proxy.ServeHTTP 中集成改写器调用: - 保存 originalClientHost 用于 default 模式 - 缓存响应路径调用 RewriteRefreshOnly - 正常响应路径调用 RewriteResponse - 添加完整单元测试覆盖各模式和边界情况 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
a026277385
commit
abbc4a50dd
@ -85,6 +85,7 @@ type Proxy struct {
|
|||||||
cache *cache.ProxyCache
|
cache *cache.ProxyCache
|
||||||
healthChecker *HealthChecker
|
healthChecker *HealthChecker
|
||||||
luaEngine *lua.LuaEngine // Lua 引擎引用
|
luaEngine *lua.LuaEngine // Lua 引擎引用
|
||||||
|
redirectRewriter *RedirectRewriter // 重定向改写器
|
||||||
stopCh chan struct{}
|
stopCh chan struct{}
|
||||||
targets []*loadbalance.Target
|
targets []*loadbalance.Target
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
@ -160,6 +161,13 @@ func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target, transportC
|
|||||||
p.cache = cache.NewProxyCache(rules, cfg.Cache.CacheLock, cfg.Cache.StaleWhileRevalidate)
|
p.cache = cache.NewProxyCache(rules, cfg.Cache.CacheLock, cfg.Cache.StaleWhileRevalidate)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 初始化重定向改写器
|
||||||
|
rewriter, err := NewRedirectRewriter(cfg.RedirectRewrite, cfg.Path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create redirect rewriter: %w", err)
|
||||||
|
}
|
||||||
|
p.redirectRewriter = rewriter
|
||||||
|
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -399,6 +407,10 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
|
|||||||
// 增加连接计数(用于最少连接数负载均衡)
|
// 增加连接计数(用于最少连接数负载均衡)
|
||||||
loadbalance.IncrementConnections(target)
|
loadbalance.IncrementConnections(target)
|
||||||
|
|
||||||
|
// 保存客户端原始 host(在 modifyRequestHeaders 改写前)
|
||||||
|
// 用于 redirect_rewrite 获取客户端实际访问地址
|
||||||
|
originalClientHost := string(ctx.Host())
|
||||||
|
|
||||||
// 设置上游地址
|
// 设置上游地址
|
||||||
upstreamAddr = target.URL
|
upstreamAddr = target.URL
|
||||||
|
|
||||||
@ -449,6 +461,9 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
|
|||||||
upstreamAddr = upstreamCache
|
upstreamAddr = upstreamCache
|
||||||
upstreamStatus = entry.Status
|
upstreamStatus = entry.Status
|
||||||
p.writeCachedResponse(ctx, entry)
|
p.writeCachedResponse(ctx, entry)
|
||||||
|
if p.redirectRewriter != nil {
|
||||||
|
p.redirectRewriter.RewriteRefreshOnly(&ctx.Response, ctx, upstreamCache, originalClientHost)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 过期缓存,尝试后台刷新,同时返回旧数据
|
// 过期缓存,尝试后台刷新,同时返回旧数据
|
||||||
@ -458,6 +473,9 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
|
|||||||
upstreamStatus = entry.Status
|
upstreamStatus = entry.Status
|
||||||
|
|
||||||
p.writeCachedResponse(ctx, entry)
|
p.writeCachedResponse(ctx, entry)
|
||||||
|
if p.redirectRewriter != nil {
|
||||||
|
p.redirectRewriter.RewriteRefreshOnly(&ctx.Response, ctx, upstreamCache, originalClientHost)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -473,6 +491,9 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
|
|||||||
upstreamStatus = entry.Status
|
upstreamStatus = entry.Status
|
||||||
|
|
||||||
p.writeCachedResponse(ctx, entry)
|
p.writeCachedResponse(ctx, entry)
|
||||||
|
if p.redirectRewriter != nil {
|
||||||
|
p.redirectRewriter.RewriteRefreshOnly(&ctx.Response, ctx, upstreamCache, originalClientHost)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 缓存未命中,需要重新选择目标
|
// 缓存未命中,需要重新选择目标
|
||||||
@ -566,6 +587,11 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
|
|||||||
p.cache.ReleaseLock(hashKey, nil)
|
p.cache.ReleaseLock(hashKey, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 改写重定向响应头(Location/Refresh)
|
||||||
|
if p.redirectRewriter != nil && p.redirectRewriter.Mode() != "off" {
|
||||||
|
p.redirectRewriter.RewriteResponse(&ctx.Response, ctx, upstreamAddr, originalClientHost)
|
||||||
|
}
|
||||||
|
|
||||||
// 修改响应头
|
// 修改响应头
|
||||||
p.modifyResponseHeaders(ctx)
|
p.modifyResponseHeaders(ctx)
|
||||||
return
|
return
|
||||||
|
|||||||
280
internal/proxy/redirect_rewrite.go
Normal file
280
internal/proxy/redirect_rewrite.go
Normal file
@ -0,0 +1,280 @@
|
|||||||
|
// Package proxy 反向代理包,为 Lolly HTTP 服务器提供反向代理功能。
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
"rua.plus/lolly/internal/config"
|
||||||
|
"rua.plus/lolly/internal/variable"
|
||||||
|
)
|
||||||
|
|
||||||
|
// compiledRule 预编译的改写规则
|
||||||
|
type compiledRule struct {
|
||||||
|
pattern *regexp.Regexp // 正则模式,nil 表示非正则匹配
|
||||||
|
exactMatch string // 精确匹配前缀(用于 prefix 匹配)
|
||||||
|
replacement string // 替换模板(含变量)
|
||||||
|
caseInsensitive bool // 正则大小写不敏感(~* 前缀)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RedirectRewriter Location/Refresh 头改写器
|
||||||
|
type RedirectRewriter struct {
|
||||||
|
mode string // "default" | "off" | "custom"(空字符串视为 default)
|
||||||
|
rules []compiledRule // 仅 custom 模式预编译
|
||||||
|
proxyPath string // 用于 default 模式(当前代理路径)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRedirectRewriter 创建改写器
|
||||||
|
// proxyPath: 当前代理路径(如 "/api/")
|
||||||
|
// 注意:mode 为空字符串时默认为 "default"
|
||||||
|
func NewRedirectRewriter(cfg *config.RedirectRewriteConfig, proxyPath string) (*RedirectRewriter, error) {
|
||||||
|
if cfg == nil {
|
||||||
|
// 未配置时默认启用 default 模式
|
||||||
|
return &RedirectRewriter{
|
||||||
|
mode: "default",
|
||||||
|
proxyPath: proxyPath,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rw := &RedirectRewriter{
|
||||||
|
mode: cfg.Mode,
|
||||||
|
proxyPath: proxyPath,
|
||||||
|
}
|
||||||
|
|
||||||
|
// custom 模式:预编译规则
|
||||||
|
if cfg.Mode == "custom" {
|
||||||
|
rules := make([]compiledRule, 0, len(cfg.Rules))
|
||||||
|
for _, rule := range cfg.Rules {
|
||||||
|
cr := compiledRule{
|
||||||
|
replacement: rule.Replacement,
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(rule.Pattern, "~") {
|
||||||
|
// 正则模式
|
||||||
|
patternStr := rule.Pattern
|
||||||
|
if strings.HasPrefix(rule.Pattern, "~*") {
|
||||||
|
cr.caseInsensitive = true
|
||||||
|
patternStr = rule.Pattern[2:]
|
||||||
|
} else {
|
||||||
|
patternStr = rule.Pattern[1:]
|
||||||
|
}
|
||||||
|
re, err := regexp.Compile(patternStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cr.pattern = re
|
||||||
|
} else {
|
||||||
|
// 非正则:使用前缀匹配
|
||||||
|
cr.exactMatch = rule.Pattern
|
||||||
|
}
|
||||||
|
|
||||||
|
rules = append(rules, cr)
|
||||||
|
}
|
||||||
|
rw.rules = rules
|
||||||
|
}
|
||||||
|
|
||||||
|
return rw, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mode 返回当前模式(处理空字符串默认值)
|
||||||
|
func (r *RedirectRewriter) Mode() string {
|
||||||
|
if r.mode == "" {
|
||||||
|
return "default"
|
||||||
|
}
|
||||||
|
return r.mode
|
||||||
|
}
|
||||||
|
|
||||||
|
// RewriteResponse 改写响应中的 Location 和 Refresh 头
|
||||||
|
// targetURL: 实际选中的上游地址(用于 default 模式)
|
||||||
|
// originalClientHost: 客户端原始 Host(在 modifyRequestHeaders 改写前保存)
|
||||||
|
// 调用位置:必须在 modifyResponseHeaders 之前
|
||||||
|
// 内部逻辑:
|
||||||
|
// - 检查 resp.StatusCode(),仅 3xx 状态码处理 Location 头
|
||||||
|
// - 所有状态码都处理 Refresh 头
|
||||||
|
func (r *RedirectRewriter) RewriteResponse(resp *fasthttp.Response, ctx *fasthttp.RequestCtx, targetURL string, originalClientHost string) {
|
||||||
|
statusCode := resp.StatusCode()
|
||||||
|
|
||||||
|
// 仅 3xx 状态码处理 Location 头
|
||||||
|
if statusCode >= 300 && statusCode < 400 {
|
||||||
|
location := resp.Header.Peek("Location")
|
||||||
|
if len(location) > 0 {
|
||||||
|
rewritten := r.rewriteURL(string(location), ctx, targetURL, originalClientHost)
|
||||||
|
if rewritten != "" {
|
||||||
|
resp.Header.Set("Location", rewritten)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 所有状态码都处理 Refresh 头
|
||||||
|
refresh := resp.Header.Peek("Refresh")
|
||||||
|
if len(refresh) > 0 {
|
||||||
|
rewritten := r.rewriteRefresh(string(refresh), ctx, targetURL, originalClientHost)
|
||||||
|
if rewritten != "" {
|
||||||
|
resp.Header.Set("Refresh", rewritten)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RewriteRefreshOnly 仅改写 Refresh 头(用于缓存响应路径)
|
||||||
|
// Location 头在缓存响应中不存在(缓存仅存储 2xx),故跳过
|
||||||
|
func (r *RedirectRewriter) RewriteRefreshOnly(resp *fasthttp.Response, ctx *fasthttp.RequestCtx, targetURL string, originalClientHost string) {
|
||||||
|
refresh := resp.Header.Peek("Refresh")
|
||||||
|
if len(refresh) > 0 {
|
||||||
|
rewritten := r.rewriteRefresh(string(refresh), ctx, targetURL, originalClientHost)
|
||||||
|
if rewritten != "" {
|
||||||
|
resp.Header.Set("Refresh", rewritten)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// rewriteURL 改写单个 URL 值(Location 或 Refresh 中的 URL 部分)
|
||||||
|
// originalClientHost: 客户端原始 Host(用于 default 模式构建 replacement)
|
||||||
|
func (r *RedirectRewriter) rewriteURL(headerValue string, ctx *fasthttp.RequestCtx, targetURL string, originalClientHost string) string {
|
||||||
|
if headerValue == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
switch r.Mode() {
|
||||||
|
case "off":
|
||||||
|
return headerValue
|
||||||
|
|
||||||
|
case "custom":
|
||||||
|
return r.rewriteCustom(headerValue, ctx)
|
||||||
|
|
||||||
|
case "default", "":
|
||||||
|
return r.rewriteDefault(headerValue, ctx, targetURL, originalClientHost)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return headerValue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// rewriteDefault 动态生成 default 规则(运行时)
|
||||||
|
// 使用前缀匹配:如果 headerValue 以 targetURL 开头,替换为 replacement + 原路径后缀
|
||||||
|
// replacement 使用 originalClientHost 构建:"$scheme://originalClientHost/"
|
||||||
|
// 例如:targetURL="http://backend:8000", headerValue="http://backend:8000/api/v2/users"
|
||||||
|
// → 替换为 "$scheme://originalClientHost/api/v2/users"
|
||||||
|
func (r *RedirectRewriter) rewriteDefault(headerValue string, ctx *fasthttp.RequestCtx, targetURL string, originalClientHost string) string {
|
||||||
|
if targetURL == "" {
|
||||||
|
return headerValue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 精确前缀匹配:headerValue 以 targetURL 开头,且后面是 / ? # 或结束
|
||||||
|
// 防止 "https://www.google.com" 匹配 "https://www.google.com.hk"(后者后面是 .hk)
|
||||||
|
if strings.HasPrefix(headerValue, targetURL) {
|
||||||
|
remaining := headerValue[len(targetURL):]
|
||||||
|
// 检查剩余部分是否以合法分隔符开头
|
||||||
|
if len(remaining) == 0 || remaining[0] == '/' || remaining[0] == '?' || remaining[0] == '#' {
|
||||||
|
// 使用客户端原始 host 构建 replacement
|
||||||
|
scheme := "http"
|
||||||
|
if ctx.IsTLS() {
|
||||||
|
scheme = "https"
|
||||||
|
}
|
||||||
|
replacement := scheme + "://" + originalClientHost
|
||||||
|
return replacement + remaining
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return headerValue
|
||||||
|
}
|
||||||
|
|
||||||
|
// rewriteCustom 使用预编译的 custom 规则改写 URL
|
||||||
|
// 规则按顺序匹配,第一个成功的生效
|
||||||
|
func (r *RedirectRewriter) rewriteCustom(headerValue string, ctx *fasthttp.RequestCtx) string {
|
||||||
|
vc := variable.NewContext(ctx)
|
||||||
|
defer variable.ReleaseContext(vc)
|
||||||
|
|
||||||
|
for _, rule := range r.rules {
|
||||||
|
if rule.pattern != nil {
|
||||||
|
// 正则匹配
|
||||||
|
if rule.caseInsensitive {
|
||||||
|
// 大小写不敏感:先将 headerValue 转为小写匹配,但替换时保留原始值
|
||||||
|
lowerValue := strings.ToLower(headerValue)
|
||||||
|
loc := rule.pattern.FindStringIndex(lowerValue)
|
||||||
|
if loc != nil {
|
||||||
|
expanded := vc.Expand(rule.replacement)
|
||||||
|
result := headerValue[:loc[0]] + expanded + headerValue[loc[1]:]
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
loc := rule.pattern.FindStringIndex(headerValue)
|
||||||
|
if loc != nil {
|
||||||
|
expanded := vc.Expand(rule.replacement)
|
||||||
|
result := headerValue[:loc[0]] + expanded + headerValue[loc[1]:]
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if rule.exactMatch != "" {
|
||||||
|
// 前缀匹配
|
||||||
|
if strings.HasPrefix(headerValue, rule.exactMatch) {
|
||||||
|
expanded := vc.Expand(rule.replacement)
|
||||||
|
suffix := strings.TrimPrefix(headerValue, rule.exactMatch)
|
||||||
|
return expanded + suffix
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return headerValue
|
||||||
|
}
|
||||||
|
|
||||||
|
// rewriteRefresh 改写 Refresh 头
|
||||||
|
// 格式:`N; url=URL` 或 `N;url=URL`(无空格)或纯数字 `N`
|
||||||
|
func (r *RedirectRewriter) rewriteRefresh(value string, ctx *fasthttp.RequestCtx, targetURL string, originalClientHost string) string {
|
||||||
|
delay, url, valid := parseRefreshHeader(value)
|
||||||
|
if !valid || url == "" {
|
||||||
|
// 无法解析或无 URL,原样返回
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
rewrittenURL := r.rewriteURL(url, ctx, targetURL, originalClientHost)
|
||||||
|
if rewrittenURL == url {
|
||||||
|
// URL 未变化,原样返回
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
return delay + "; url=" + rewrittenURL
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseRefreshHeader 解析 Refresh 头格式
|
||||||
|
// 格式:`N; url=URL` 或 `N;url=URL`(无空格)或纯数字 `N`
|
||||||
|
// 返回:delay(N), url(URL), 是否有效
|
||||||
|
// 边缘处理:忽略引号、忽略多余参数
|
||||||
|
func parseRefreshHeader(value string) (delay string, url string, valid bool) {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
if value == "" {
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查找 url= 部分
|
||||||
|
urlIdx := strings.Index(strings.ToLower(value), "url=")
|
||||||
|
if urlIdx == -1 {
|
||||||
|
// 纯数字格式,有效但无 URL
|
||||||
|
return value, "", true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提取 delay 部分(url= 之前的部分)
|
||||||
|
delay = strings.TrimSpace(value[:urlIdx])
|
||||||
|
// 去除 delay 末尾的分号
|
||||||
|
delay = strings.TrimSuffix(delay, ";")
|
||||||
|
delay = strings.TrimSpace(delay)
|
||||||
|
if delay == "" {
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提取 URL 部分(url= 之后)
|
||||||
|
url = strings.TrimSpace(value[urlIdx+4:])
|
||||||
|
if url == "" {
|
||||||
|
return delay, "", true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 去除 URL 两端的引号(如果有)
|
||||||
|
if len(url) >= 2 {
|
||||||
|
if (url[0] == '"' && url[len(url)-1] == '"') ||
|
||||||
|
(url[0] == '\'' && url[len(url)-1] == '\'') {
|
||||||
|
url = url[1 : len(url)-1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return delay, url, true
|
||||||
|
}
|
||||||
276
internal/proxy/redirect_rewrite_test.go
Normal file
276
internal/proxy/redirect_rewrite_test.go
Normal file
@ -0,0 +1,276 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"rua.plus/lolly/internal/config"
|
||||||
|
"rua.plus/lolly/internal/testutil"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestRedirectRewrite_ExactMatch 测试精确匹配改写
|
||||||
|
func TestRedirectRewrite_ExactMatch(t *testing.T) {
|
||||||
|
cfg := &config.RedirectRewriteConfig{
|
||||||
|
Mode: "custom",
|
||||||
|
Rules: []config.RedirectRewriteRule{
|
||||||
|
{Pattern: "http://localhost:8000/", Replacement: "/"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
rw, err := NewRedirectRewriter(cfg, "/")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRedirectRewriter() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := testutil.NewRequestCtx("GET", "/api/test")
|
||||||
|
|
||||||
|
resp := &fasthttp.Response{}
|
||||||
|
resp.Header.Set("Location", "http://localhost:8000/api/users")
|
||||||
|
resp.SetStatusCode(301)
|
||||||
|
|
||||||
|
rw.RewriteResponse(resp, ctx, "", "frontend:8080")
|
||||||
|
|
||||||
|
got := string(resp.Header.Peek("Location"))
|
||||||
|
want := "/api/users"
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("Location = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRedirectRewrite_DefaultMode 测试 default 模式前缀匹配
|
||||||
|
func TestRedirectRewrite_DefaultMode(t *testing.T) {
|
||||||
|
cfg := &config.RedirectRewriteConfig{
|
||||||
|
Mode: "default",
|
||||||
|
}
|
||||||
|
rw, err := NewRedirectRewriter(cfg, "/api/")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRedirectRewriter() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := testutil.NewRequestCtx("GET", "/api/test")
|
||||||
|
|
||||||
|
resp := &fasthttp.Response{}
|
||||||
|
resp.Header.Set("Location", "http://backend1:8000/api/v2/users")
|
||||||
|
resp.SetStatusCode(301)
|
||||||
|
|
||||||
|
// targetURL = http://backend1:8000, originalClientHost = frontend:8080
|
||||||
|
rw.RewriteResponse(resp, ctx, "http://backend1:8000", "frontend:8080")
|
||||||
|
|
||||||
|
got := string(resp.Header.Peek("Location"))
|
||||||
|
// default 模式:targetURL 前缀替换为 http://originalClientHost
|
||||||
|
want := "http://frontend:8080/api/v2/users"
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("Location = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRedirectRewrite_DefaultMode_ExternalService 测试代理外部服务
|
||||||
|
// 验证 Location 指向外部服务时不应改写
|
||||||
|
func TestRedirectRewrite_DefaultMode_ExternalService(t *testing.T) {
|
||||||
|
cfg := &config.RedirectRewriteConfig{
|
||||||
|
Mode: "default",
|
||||||
|
}
|
||||||
|
rw, err := NewRedirectRewriter(cfg, "/")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRedirectRewriter() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := testutil.NewRequestCtx("GET", "/")
|
||||||
|
|
||||||
|
resp := &fasthttp.Response{}
|
||||||
|
// Google 返回的 Location 指向 google.com.hk(不以 targetURL 开头)
|
||||||
|
resp.Header.Set("Location", "https://www.google.com.hk/search")
|
||||||
|
resp.SetStatusCode(302)
|
||||||
|
|
||||||
|
// targetURL = https://www.google.com, originalClientHost = localhost:8081
|
||||||
|
rw.RewriteResponse(resp, ctx, "https://www.google.com", "localhost:8081")
|
||||||
|
|
||||||
|
got := string(resp.Header.Peek("Location"))
|
||||||
|
// 不匹配 targetURL 前缀,应该原样返回
|
||||||
|
want := "https://www.google.com.hk/search"
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("Location = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRedirectRewrite_DefaultMode_MatchingExternal 测试匹配外部服务
|
||||||
|
// 验证 Location 以 targetURL 开头时正确改写
|
||||||
|
func TestRedirectRewrite_DefaultMode_MatchingExternal(t *testing.T) {
|
||||||
|
cfg := &config.RedirectRewriteConfig{
|
||||||
|
Mode: "default",
|
||||||
|
}
|
||||||
|
rw, err := NewRedirectRewriter(cfg, "/")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRedirectRewriter() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := testutil.NewRequestCtx("GET", "/")
|
||||||
|
|
||||||
|
resp := &fasthttp.Response{}
|
||||||
|
// Location 以 targetURL 开头,应该被改写
|
||||||
|
resp.Header.Set("Location", "https://www.google.com/search?q=test")
|
||||||
|
resp.SetStatusCode(302)
|
||||||
|
|
||||||
|
rw.RewriteResponse(resp, ctx, "https://www.google.com", "localhost:8081")
|
||||||
|
|
||||||
|
got := string(resp.Header.Peek("Location"))
|
||||||
|
// targetURL 前缀替换为 http://originalClientHost
|
||||||
|
want := "http://localhost:8081/search?q=test"
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("Location = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRedirectRewrite_OffMode 测试 off 模式不改写
|
||||||
|
func TestRedirectRewrite_OffMode(t *testing.T) {
|
||||||
|
cfg := &config.RedirectRewriteConfig{
|
||||||
|
Mode: "off",
|
||||||
|
}
|
||||||
|
rw, err := NewRedirectRewriter(cfg, "/")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRedirectRewriter() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := testutil.NewRequestCtx("GET", "/")
|
||||||
|
|
||||||
|
resp := &fasthttp.Response{}
|
||||||
|
resp.Header.Set("Location", "http://backend:8000/path")
|
||||||
|
resp.SetStatusCode(301)
|
||||||
|
|
||||||
|
rw.RewriteResponse(resp, ctx, "http://backend:8000", "frontend:8080")
|
||||||
|
|
||||||
|
got := string(resp.Header.Peek("Location"))
|
||||||
|
want := "http://backend:8000/path"
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("Location = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRedirectRewrite_RelativeURL 测试相对 URL 不改写
|
||||||
|
func TestRedirectRewrite_RelativeURL(t *testing.T) {
|
||||||
|
cfg := &config.RedirectRewriteConfig{
|
||||||
|
Mode: "default",
|
||||||
|
}
|
||||||
|
rw, err := NewRedirectRewriter(cfg, "/")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRedirectRewriter() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := testutil.NewRequestCtx("GET", "/")
|
||||||
|
|
||||||
|
resp := &fasthttp.Response{}
|
||||||
|
resp.Header.Set("Location", "/new-path")
|
||||||
|
resp.SetStatusCode(302)
|
||||||
|
|
||||||
|
rw.RewriteResponse(resp, ctx, "http://backend:8000", "frontend:8080")
|
||||||
|
|
||||||
|
got := string(resp.Header.Peek("Location"))
|
||||||
|
want := "/new-path"
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("Location = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRedirectRewrite_EmptyLocation 测试空 Location 不改写
|
||||||
|
func TestRedirectRewrite_EmptyLocation(t *testing.T) {
|
||||||
|
cfg := &config.RedirectRewriteConfig{
|
||||||
|
Mode: "default",
|
||||||
|
}
|
||||||
|
rw, err := NewRedirectRewriter(cfg, "/")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRedirectRewriter() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := testutil.NewRequestCtx("GET", "/")
|
||||||
|
|
||||||
|
resp := &fasthttp.Response{}
|
||||||
|
resp.Header.Set("Location", "")
|
||||||
|
resp.SetStatusCode(302)
|
||||||
|
|
||||||
|
rw.RewriteResponse(resp, ctx, "http://backend:8000", "frontend:8080")
|
||||||
|
|
||||||
|
got := resp.Header.Peek("Location")
|
||||||
|
if len(got) != 0 {
|
||||||
|
t.Errorf("Location should be empty, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRedirectRewrite_NonRedirectStatus 测试非 3xx 状态码不改写 Location
|
||||||
|
func TestRedirectRewrite_NonRedirectStatus(t *testing.T) {
|
||||||
|
cfg := &config.RedirectRewriteConfig{
|
||||||
|
Mode: "default",
|
||||||
|
}
|
||||||
|
rw, err := NewRedirectRewriter(cfg, "/")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRedirectRewriter() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := testutil.NewRequestCtx("GET", "/")
|
||||||
|
|
||||||
|
resp := &fasthttp.Response{}
|
||||||
|
resp.Header.Set("Location", "http://backend:8000/path")
|
||||||
|
resp.SetStatusCode(200) // 非 3xx
|
||||||
|
|
||||||
|
rw.RewriteResponse(resp, ctx, "http://backend:8000", "frontend:8080")
|
||||||
|
|
||||||
|
got := string(resp.Header.Peek("Location"))
|
||||||
|
want := "http://backend:8000/path"
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("Location = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRedirectRewrite_RefreshHeader 测试 Refresh 头改写
|
||||||
|
func TestRedirectRewrite_RefreshHeader(t *testing.T) {
|
||||||
|
cfg := &config.RedirectRewriteConfig{
|
||||||
|
Mode: "custom",
|
||||||
|
Rules: []config.RedirectRewriteRule{
|
||||||
|
{Pattern: "http://backend:8000/", Replacement: "http://frontend/"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
rw, err := NewRedirectRewriter(cfg, "/")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRedirectRewriter() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := testutil.NewRequestCtx("GET", "/")
|
||||||
|
|
||||||
|
resp := &fasthttp.Response{}
|
||||||
|
resp.Header.Set("Refresh", "5; url=http://backend:8000/api/")
|
||||||
|
resp.SetStatusCode(200)
|
||||||
|
|
||||||
|
rw.RewriteResponse(resp, ctx, "", "frontend:8080")
|
||||||
|
|
||||||
|
got := string(resp.Header.Peek("Refresh"))
|
||||||
|
want := "5; url=http://frontend/api/"
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("Refresh = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRedirectRewrite_NilConfig 测试 nil 配置默认启用 default 模式
|
||||||
|
func TestRedirectRewrite_NilConfig(t *testing.T) {
|
||||||
|
rw, err := NewRedirectRewriter(nil, "/api/")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRedirectRewriter() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rw.Mode() != "default" {
|
||||||
|
t.Errorf("Mode() = %q, want %q", rw.Mode(), "default")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRedirectRewrite_EmptyMode 测试空 Mode 默认为 default
|
||||||
|
func TestRedirectRewrite_EmptyMode(t *testing.T) {
|
||||||
|
cfg := &config.RedirectRewriteConfig{
|
||||||
|
Mode: "",
|
||||||
|
}
|
||||||
|
rw, err := NewRedirectRewriter(cfg, "/")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRedirectRewriter() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rw.Mode() != "default" {
|
||||||
|
t.Errorf("Mode() = %q, want %q", rw.Mode(), "default")
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user