fix(sticky): fix cookie format, shard keying, and tests

- Encode cookie as base64(target_url + | + timestamp) per spec
- Use cookie value (not targetURL) for shard key and session map keys
- Add missing sticky.Start() calls in tests
- Fix time precision in cookie encode/decode tests
This commit is contained in:
xfy 2026-06-08 17:36:41 +08:00
parent f69a11ea05
commit 66752a47f0
2 changed files with 71 additions and 42 deletions

View File

@ -2,6 +2,8 @@ package loadbalance
import ( import (
"encoding/base64" "encoding/base64"
"strconv"
"strings"
"sync" "sync"
"time" "time"
@ -98,8 +100,8 @@ func (s *StickySession) Select(ctx *fasthttp.RequestCtx, targets []*Target) *Tar
// 检查现有 cookie // 检查现有 cookie
cookieValue := ctx.Request.Header.Cookie(s.config.Name) cookieValue := ctx.Request.Header.Cookie(s.config.Name)
if len(cookieValue) > 0 { if len(cookieValue) > 0 {
decodedURL, err := decodeStickyCookie(string(cookieValue)) decodedURL, _, ok := decodeStickyCookie(string(cookieValue))
if err == nil && decodedURL != "" { if ok && decodedURL != "" {
// 查找对应的目标 // 查找对应的目标
for _, target := range targets { for _, target := range targets {
if target.URL == decodedURL && target.IsAvailable() { if target.URL == decodedURL && target.IsAvailable() {
@ -107,15 +109,15 @@ func (s *StickySession) Select(ctx *fasthttp.RequestCtx, targets []*Target) *Tar
} }
} }
// 目标不可用,删除会话记录 // 目标不可用,删除会话记录
s.deleteSession(decodedURL) s.deleteSession(string(cookieValue))
} }
} }
// 使用 fallback 选择目标 // 使用 fallback 选择目标
selected := s.fallback.Select(targets) selected := s.fallback.Select(targets)
if selected != nil { if selected != nil {
s.setCookie(ctx, selected.URL) newCookieValue := s.setCookie(ctx, selected.URL)
s.recordSession(selected.URL) s.recordSession(newCookieValue, selected.URL)
} }
return selected return selected
} }
@ -125,14 +127,16 @@ func (s *StickySession) SelectExcluding(targets []*Target, excluded []*Target) *
return s.fallback.SelectExcluding(targets, excluded) return s.fallback.SelectExcluding(targets, excluded)
} }
// setCookie 设置会话 cookie 到响应头。 // setCookie 设置会话 cookie 到响应头,返回编码后的 cookie 值。
func (s *StickySession) setCookie(ctx *fasthttp.RequestCtx, targetURL string) { func (s *StickySession) setCookie(ctx *fasthttp.RequestCtx, targetURL string) string {
expires := time.Now().Add(s.config.Expires)
cookie := &fasthttp.Cookie{} cookie := &fasthttp.Cookie{}
cookie.SetKey(s.config.Name) cookie.SetKey(s.config.Name)
cookie.SetValue(encodeStickyCookie(targetURL)) encoded := encodeStickyCookie(targetURL, expires)
cookie.SetValue(encoded)
if s.config.Expires > 0 { if s.config.Expires > 0 {
cookie.SetExpire(time.Now().Add(s.config.Expires)) cookie.SetExpire(expires)
} }
if s.config.Domain != "" { if s.config.Domain != "" {
cookie.SetDomain(s.config.Domain) cookie.SetDomain(s.config.Domain)
@ -159,13 +163,14 @@ func (s *StickySession) setCookie(ctx *fasthttp.RequestCtx, targetURL string) {
} }
ctx.Response.Header.SetCookie(cookie) ctx.Response.Header.SetCookie(cookie)
return encoded
} }
// recordSession 记录会话到 shard 中。 // recordSession 记录会话到 shard 中。
func (s *StickySession) recordSession(targetURL string) { func (s *StickySession) recordSession(cookieValue, targetURL string) {
shard := s.getShard(targetURL) shard := s.getShard(cookieValue)
shard.mu.Lock() shard.mu.Lock()
shard.entries[targetURL] = &stickyEntry{ shard.entries[cookieValue] = &stickyEntry{
targetURL: targetURL, targetURL: targetURL,
expires: time.Now().Add(s.config.Expires), expires: time.Now().Add(s.config.Expires),
} }
@ -173,31 +178,40 @@ func (s *StickySession) recordSession(targetURL string) {
} }
// deleteSession 从 shard 中删除会话记录。 // deleteSession 从 shard 中删除会话记录。
func (s *StickySession) deleteSession(targetURL string) { func (s *StickySession) deleteSession(cookieValue string) {
shard := s.getShard(targetURL) shard := s.getShard(cookieValue)
shard.mu.Lock() shard.mu.Lock()
delete(shard.entries, targetURL) delete(shard.entries, cookieValue)
shard.mu.Unlock() shard.mu.Unlock()
} }
// getShard 根据 targetURL 选择对应的 shard。 // getShard 根据 cookieValue 选择对应的 shard。
func (s *StickySession) getShard(targetURL string) *stickyShard { func (s *StickySession) getShard(cookieValue string) *stickyShard {
hash := fnvHash64a(targetURL) hash := fnvHash64a(cookieValue)
return s.shards[hash%stickyShardCount] return s.shards[hash%stickyShardCount]
} }
// encodeStickyCookie 将目标 URL 编码为 cookie 值base64 // encodeStickyCookie 将目标 URL 和过期时间编码为 cookie 值base64
func encodeStickyCookie(targetURL string) string { func encodeStickyCookie(targetURL string, expires time.Time) string {
return base64.URLEncoding.EncodeToString([]byte(targetURL)) raw := targetURL + "|" + strconv.FormatInt(expires.Unix(), 10)
return base64.URLEncoding.EncodeToString([]byte(raw))
} }
// decodeStickyCookie 解码 cookie 值为目标 URL // decodeStickyCookie 解码 cookie 值为目标 URL 和过期时间
func decodeStickyCookie(value string) (string, error) { func decodeStickyCookie(value string) (targetURL string, expires time.Time, ok bool) {
decoded, err := base64.URLEncoding.DecodeString(value) raw, err := base64.URLEncoding.DecodeString(value)
if err != nil { if err != nil {
return "", err return
} }
return string(decoded), nil parts := strings.Split(string(raw), "|")
if len(parts) != 2 {
return
}
ts, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil {
return
}
return parts[0], time.Unix(ts, 0), true
} }
// Ensure StickySession implements the SelectExcluding part of Balancer interface. // Ensure StickySession implements the SelectExcluding part of Balancer interface.

View File

@ -3,6 +3,7 @@ package loadbalance
import ( import (
"sync" "sync"
"testing" "testing"
"time"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -16,6 +17,7 @@ func TestStickySession_BasicRoute(t *testing.T) {
config.Enabled = true config.Enabled = true
fallback := NewRoundRobin() fallback := NewRoundRobin()
sticky := NewStickySession(config, fallback) sticky := NewStickySession(config, fallback)
sticky.Start()
defer sticky.Stop() defer sticky.Stop()
targets := []*Target{ targets := []*Target{
@ -41,6 +43,7 @@ func TestStickySession_BasicRoute(t *testing.T) {
config.Enabled = true config.Enabled = true
fallback := NewRoundRobin() fallback := NewRoundRobin()
sticky := NewStickySession(config, fallback) sticky := NewStickySession(config, fallback)
sticky.Start()
defer sticky.Stop() defer sticky.Stop()
targets := []*Target{ targets := []*Target{
@ -80,6 +83,7 @@ func TestStickySession_BasicRoute(t *testing.T) {
config.Enabled = false config.Enabled = false
fallback := NewRoundRobin() fallback := NewRoundRobin()
sticky := NewStickySession(config, fallback) sticky := NewStickySession(config, fallback)
sticky.Start()
defer sticky.Stop() defer sticky.Stop()
targets := []*Target{ targets := []*Target{
@ -107,6 +111,7 @@ func TestStickySession_TargetUnavailable(t *testing.T) {
config.Enabled = true config.Enabled = true
fallback := NewRoundRobin() fallback := NewRoundRobin()
sticky := NewStickySession(config, fallback) sticky := NewStickySession(config, fallback)
sticky.Start()
defer sticky.Stop() defer sticky.Stop()
targets := []*Target{ targets := []*Target{
@ -155,36 +160,44 @@ func TestStickySession_CookieEncodeDecode(t *testing.T) {
t.Parallel() t.Parallel()
t.Run("编码解码round-trip", func(_ *testing.T) { t.Run("编码解码round-trip", func(_ *testing.T) {
url := "http://backend1:8080" url := "http://backend1:8080"
encoded := encodeStickyCookie(url) expires := time.Now().Add(time.Hour)
encoded := encodeStickyCookie(url, expires)
if encoded == "" { if encoded == "" {
t.Fatal("encodeStickyCookie() 返回空字符串") t.Fatal("encodeStickyCookie() 返回空字符串")
} }
decoded, err := decodeStickyCookie(encoded) decodedURL, decodedExpires, ok := decodeStickyCookie(encoded)
if err != nil { if !ok {
t.Fatalf("decodeStickyCookie() 错误: %v", err) t.Fatal("decodeStickyCookie() returned ok=false")
} }
if decoded != url { if decodedURL != url {
t.Errorf("解码后 URL = %q, want %q", decoded, url) t.Errorf("解码后 URL = %q, want %q", decodedURL, url)
}
if decodedExpires.Unix() != expires.Unix() {
t.Errorf("解码后 expires = %v, want %v", decodedExpires, expires)
} }
}) })
t.Run("空URL编码解码", func(_ *testing.T) { t.Run("空URL编码解码", func(_ *testing.T) {
encoded := encodeStickyCookie("") expires := time.Now().Add(time.Hour)
decoded, err := decodeStickyCookie(encoded) encoded := encodeStickyCookie("", expires)
if err != nil { decodedURL, decodedExpires, ok := decodeStickyCookie(encoded)
t.Fatalf("decodeStickyCookie() 错误: %v", err) if !ok {
t.Fatal("decodeStickyCookie() returned ok=false")
} }
if decoded != "" { if decodedURL != "" {
t.Errorf("解码后 URL = %q, want 空字符串", decoded) t.Errorf("解码后 URL = %q, want 空字符串", decodedURL)
}
if decodedExpires.Unix() != expires.Unix() {
t.Errorf("解码后 expires = %v, want %v", decodedExpires, expires)
} }
}) })
t.Run("无效编码", func(_ *testing.T) { t.Run("无效编码", func(_ *testing.T) {
_, err := decodeStickyCookie("invalid-base64!!!") decodedURL, decodedExpires, ok := decodeStickyCookie("invalid-base64!!!")
if err == nil { if ok {
t.Error("decodeStickyCookie() 应返回错误") t.Errorf("decodeStickyCookie() = (%q, %v, %v), want ok=false", decodedURL, decodedExpires, ok)
} }
}) })
} }
@ -197,6 +210,7 @@ func TestStickySession_Concurrent(t *testing.T) {
config.Enabled = true config.Enabled = true
fallback := NewRoundRobin() fallback := NewRoundRobin()
sticky := NewStickySession(config, fallback) sticky := NewStickySession(config, fallback)
sticky.Start()
defer sticky.Stop() defer sticky.Stop()
targets := []*Target{ targets := []*Target{
@ -213,7 +227,7 @@ func TestStickySession_Concurrent(t *testing.T) {
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
// 交替使用有 cookie 和没有 cookie 的请求 // 交替使用有 cookie 和没有 cookie 的请求
if idx%2 == 0 { if idx%2 == 0 {
ctx.Request.Header.SetCookie(config.Name, encodeStickyCookie("http://backend1:8080")) ctx.Request.Header.SetCookie(config.Name, encodeStickyCookie("http://backend1:8080", time.Now().Add(time.Hour)))
} }
got := sticky.Select(ctx, targets) got := sticky.Select(ctx, targets)
if got == nil { if got == nil {
@ -232,6 +246,7 @@ func TestStickySession_SelectExcluding(t *testing.T) {
config.Enabled = true config.Enabled = true
fallback := NewRoundRobin() fallback := NewRoundRobin()
sticky := NewStickySession(config, fallback) sticky := NewStickySession(config, fallback)
sticky.Start()
defer sticky.Stop() defer sticky.Stop()
targets := []*Target{ targets := []*Target{