diff --git a/internal/loadbalance/sticky.go b/internal/loadbalance/sticky.go index ffe01ee..c64a90f 100644 --- a/internal/loadbalance/sticky.go +++ b/internal/loadbalance/sticky.go @@ -2,6 +2,8 @@ package loadbalance import ( "encoding/base64" + "strconv" + "strings" "sync" "time" @@ -98,8 +100,8 @@ func (s *StickySession) Select(ctx *fasthttp.RequestCtx, targets []*Target) *Tar // 检查现有 cookie cookieValue := ctx.Request.Header.Cookie(s.config.Name) if len(cookieValue) > 0 { - decodedURL, err := decodeStickyCookie(string(cookieValue)) - if err == nil && decodedURL != "" { + decodedURL, _, ok := decodeStickyCookie(string(cookieValue)) + if ok && decodedURL != "" { // 查找对应的目标 for _, target := range targets { if target.URL == decodedURL && target.IsAvailable() { @@ -107,15 +109,15 @@ func (s *StickySession) Select(ctx *fasthttp.RequestCtx, targets []*Target) *Tar } } // 目标不可用,删除会话记录 - s.deleteSession(decodedURL) + s.deleteSession(string(cookieValue)) } } // 使用 fallback 选择目标 selected := s.fallback.Select(targets) if selected != nil { - s.setCookie(ctx, selected.URL) - s.recordSession(selected.URL) + newCookieValue := s.setCookie(ctx, selected.URL) + s.recordSession(newCookieValue, selected.URL) } return selected } @@ -125,14 +127,16 @@ func (s *StickySession) SelectExcluding(targets []*Target, excluded []*Target) * return s.fallback.SelectExcluding(targets, excluded) } -// setCookie 设置会话 cookie 到响应头。 -func (s *StickySession) setCookie(ctx *fasthttp.RequestCtx, targetURL string) { +// setCookie 设置会话 cookie 到响应头,返回编码后的 cookie 值。 +func (s *StickySession) setCookie(ctx *fasthttp.RequestCtx, targetURL string) string { + expires := time.Now().Add(s.config.Expires) cookie := &fasthttp.Cookie{} cookie.SetKey(s.config.Name) - cookie.SetValue(encodeStickyCookie(targetURL)) + encoded := encodeStickyCookie(targetURL, expires) + cookie.SetValue(encoded) if s.config.Expires > 0 { - cookie.SetExpire(time.Now().Add(s.config.Expires)) + cookie.SetExpire(expires) } if s.config.Domain != "" { cookie.SetDomain(s.config.Domain) @@ -159,13 +163,14 @@ func (s *StickySession) setCookie(ctx *fasthttp.RequestCtx, targetURL string) { } ctx.Response.Header.SetCookie(cookie) + return encoded } // recordSession 记录会话到 shard 中。 -func (s *StickySession) recordSession(targetURL string) { - shard := s.getShard(targetURL) +func (s *StickySession) recordSession(cookieValue, targetURL string) { + shard := s.getShard(cookieValue) shard.mu.Lock() - shard.entries[targetURL] = &stickyEntry{ + shard.entries[cookieValue] = &stickyEntry{ targetURL: targetURL, expires: time.Now().Add(s.config.Expires), } @@ -173,31 +178,40 @@ func (s *StickySession) recordSession(targetURL string) { } // deleteSession 从 shard 中删除会话记录。 -func (s *StickySession) deleteSession(targetURL string) { - shard := s.getShard(targetURL) +func (s *StickySession) deleteSession(cookieValue string) { + shard := s.getShard(cookieValue) shard.mu.Lock() - delete(shard.entries, targetURL) + delete(shard.entries, cookieValue) shard.mu.Unlock() } -// getShard 根据 targetURL 选择对应的 shard。 -func (s *StickySession) getShard(targetURL string) *stickyShard { - hash := fnvHash64a(targetURL) +// getShard 根据 cookieValue 选择对应的 shard。 +func (s *StickySession) getShard(cookieValue string) *stickyShard { + hash := fnvHash64a(cookieValue) return s.shards[hash%stickyShardCount] } -// encodeStickyCookie 将目标 URL 编码为 cookie 值(base64)。 -func encodeStickyCookie(targetURL string) string { - return base64.URLEncoding.EncodeToString([]byte(targetURL)) +// encodeStickyCookie 将目标 URL 和过期时间编码为 cookie 值(base64)。 +func encodeStickyCookie(targetURL string, expires time.Time) string { + raw := targetURL + "|" + strconv.FormatInt(expires.Unix(), 10) + return base64.URLEncoding.EncodeToString([]byte(raw)) } -// decodeStickyCookie 解码 cookie 值为目标 URL。 -func decodeStickyCookie(value string) (string, error) { - decoded, err := base64.URLEncoding.DecodeString(value) +// decodeStickyCookie 解码 cookie 值为目标 URL 和过期时间。 +func decodeStickyCookie(value string) (targetURL string, expires time.Time, ok bool) { + raw, err := base64.URLEncoding.DecodeString(value) if err != nil { - return "", err + return } - return string(decoded), nil + parts := strings.Split(string(raw), "|") + if len(parts) != 2 { + return + } + ts, err := strconv.ParseInt(parts[1], 10, 64) + if err != nil { + return + } + return parts[0], time.Unix(ts, 0), true } // Ensure StickySession implements the SelectExcluding part of Balancer interface. diff --git a/internal/loadbalance/sticky_test.go b/internal/loadbalance/sticky_test.go index 6203e5e..55237bc 100644 --- a/internal/loadbalance/sticky_test.go +++ b/internal/loadbalance/sticky_test.go @@ -3,6 +3,7 @@ package loadbalance import ( "sync" "testing" + "time" "github.com/valyala/fasthttp" ) @@ -16,6 +17,7 @@ func TestStickySession_BasicRoute(t *testing.T) { config.Enabled = true fallback := NewRoundRobin() sticky := NewStickySession(config, fallback) + sticky.Start() defer sticky.Stop() targets := []*Target{ @@ -41,6 +43,7 @@ func TestStickySession_BasicRoute(t *testing.T) { config.Enabled = true fallback := NewRoundRobin() sticky := NewStickySession(config, fallback) + sticky.Start() defer sticky.Stop() targets := []*Target{ @@ -80,6 +83,7 @@ func TestStickySession_BasicRoute(t *testing.T) { config.Enabled = false fallback := NewRoundRobin() sticky := NewStickySession(config, fallback) + sticky.Start() defer sticky.Stop() targets := []*Target{ @@ -107,6 +111,7 @@ func TestStickySession_TargetUnavailable(t *testing.T) { config.Enabled = true fallback := NewRoundRobin() sticky := NewStickySession(config, fallback) + sticky.Start() defer sticky.Stop() targets := []*Target{ @@ -155,36 +160,44 @@ func TestStickySession_CookieEncodeDecode(t *testing.T) { t.Parallel() t.Run("编码解码round-trip", func(_ *testing.T) { url := "http://backend1:8080" - encoded := encodeStickyCookie(url) + expires := time.Now().Add(time.Hour) + encoded := encodeStickyCookie(url, expires) if encoded == "" { t.Fatal("encodeStickyCookie() 返回空字符串") } - decoded, err := decodeStickyCookie(encoded) - if err != nil { - t.Fatalf("decodeStickyCookie() 错误: %v", err) + decodedURL, decodedExpires, ok := decodeStickyCookie(encoded) + if !ok { + t.Fatal("decodeStickyCookie() returned ok=false") } - if decoded != url { - t.Errorf("解码后 URL = %q, want %q", decoded, url) + if decodedURL != url { + t.Errorf("解码后 URL = %q, want %q", decodedURL, url) + } + if decodedExpires.Unix() != expires.Unix() { + t.Errorf("解码后 expires = %v, want %v", decodedExpires, expires) } }) t.Run("空URL编码解码", func(_ *testing.T) { - encoded := encodeStickyCookie("") - decoded, err := decodeStickyCookie(encoded) - if err != nil { - t.Fatalf("decodeStickyCookie() 错误: %v", err) + expires := time.Now().Add(time.Hour) + encoded := encodeStickyCookie("", expires) + decodedURL, decodedExpires, ok := decodeStickyCookie(encoded) + if !ok { + t.Fatal("decodeStickyCookie() returned ok=false") } - if decoded != "" { - t.Errorf("解码后 URL = %q, want 空字符串", decoded) + if decodedURL != "" { + t.Errorf("解码后 URL = %q, want 空字符串", decodedURL) + } + if decodedExpires.Unix() != expires.Unix() { + t.Errorf("解码后 expires = %v, want %v", decodedExpires, expires) } }) t.Run("无效编码", func(_ *testing.T) { - _, err := decodeStickyCookie("invalid-base64!!!") - if err == nil { - t.Error("decodeStickyCookie() 应返回错误") + decodedURL, decodedExpires, ok := decodeStickyCookie("invalid-base64!!!") + if ok { + t.Errorf("decodeStickyCookie() = (%q, %v, %v), want ok=false", decodedURL, decodedExpires, ok) } }) } @@ -197,6 +210,7 @@ func TestStickySession_Concurrent(t *testing.T) { config.Enabled = true fallback := NewRoundRobin() sticky := NewStickySession(config, fallback) + sticky.Start() defer sticky.Stop() targets := []*Target{ @@ -213,7 +227,7 @@ func TestStickySession_Concurrent(t *testing.T) { ctx := &fasthttp.RequestCtx{} // 交替使用有 cookie 和没有 cookie 的请求 if idx%2 == 0 { - ctx.Request.Header.SetCookie(config.Name, encodeStickyCookie("http://backend1:8080")) + ctx.Request.Header.SetCookie(config.Name, encodeStickyCookie("http://backend1:8080", time.Now().Add(time.Hour))) } got := sticky.Select(ctx, targets) if got == nil { @@ -232,6 +246,7 @@ func TestStickySession_SelectExcluding(t *testing.T) { config.Enabled = true fallback := NewRoundRobin() sticky := NewStickySession(config, fallback) + sticky.Start() defer sticky.Stop() targets := []*Target{