fix(sticky): fix cookie format, shard keying, and tests

- Encode cookie as base64(target_url + | + timestamp) per spec
- Use cookie value (not targetURL) for shard key and session map keys
- Add missing sticky.Start() calls in tests
- Fix time precision in cookie encode/decode tests
This commit is contained in:
xfy 2026-06-08 17:36:41 +08:00
parent f69a11ea05
commit 66752a47f0
2 changed files with 71 additions and 42 deletions

View File

@ -2,6 +2,8 @@ package loadbalance
import (
"encoding/base64"
"strconv"
"strings"
"sync"
"time"
@ -98,8 +100,8 @@ func (s *StickySession) Select(ctx *fasthttp.RequestCtx, targets []*Target) *Tar
// 检查现有 cookie
cookieValue := ctx.Request.Header.Cookie(s.config.Name)
if len(cookieValue) > 0 {
decodedURL, err := decodeStickyCookie(string(cookieValue))
if err == nil && decodedURL != "" {
decodedURL, _, ok := decodeStickyCookie(string(cookieValue))
if ok && decodedURL != "" {
// 查找对应的目标
for _, target := range targets {
if target.URL == decodedURL && target.IsAvailable() {
@ -107,15 +109,15 @@ func (s *StickySession) Select(ctx *fasthttp.RequestCtx, targets []*Target) *Tar
}
}
// 目标不可用,删除会话记录
s.deleteSession(decodedURL)
s.deleteSession(string(cookieValue))
}
}
// 使用 fallback 选择目标
selected := s.fallback.Select(targets)
if selected != nil {
s.setCookie(ctx, selected.URL)
s.recordSession(selected.URL)
newCookieValue := s.setCookie(ctx, selected.URL)
s.recordSession(newCookieValue, selected.URL)
}
return selected
}
@ -125,14 +127,16 @@ func (s *StickySession) SelectExcluding(targets []*Target, excluded []*Target) *
return s.fallback.SelectExcluding(targets, excluded)
}
// setCookie 设置会话 cookie 到响应头。
func (s *StickySession) setCookie(ctx *fasthttp.RequestCtx, targetURL string) {
// setCookie 设置会话 cookie 到响应头,返回编码后的 cookie 值。
func (s *StickySession) setCookie(ctx *fasthttp.RequestCtx, targetURL string) string {
expires := time.Now().Add(s.config.Expires)
cookie := &fasthttp.Cookie{}
cookie.SetKey(s.config.Name)
cookie.SetValue(encodeStickyCookie(targetURL))
encoded := encodeStickyCookie(targetURL, expires)
cookie.SetValue(encoded)
if s.config.Expires > 0 {
cookie.SetExpire(time.Now().Add(s.config.Expires))
cookie.SetExpire(expires)
}
if s.config.Domain != "" {
cookie.SetDomain(s.config.Domain)
@ -159,13 +163,14 @@ func (s *StickySession) setCookie(ctx *fasthttp.RequestCtx, targetURL string) {
}
ctx.Response.Header.SetCookie(cookie)
return encoded
}
// recordSession 记录会话到 shard 中。
func (s *StickySession) recordSession(targetURL string) {
shard := s.getShard(targetURL)
func (s *StickySession) recordSession(cookieValue, targetURL string) {
shard := s.getShard(cookieValue)
shard.mu.Lock()
shard.entries[targetURL] = &stickyEntry{
shard.entries[cookieValue] = &stickyEntry{
targetURL: targetURL,
expires: time.Now().Add(s.config.Expires),
}
@ -173,31 +178,40 @@ func (s *StickySession) recordSession(targetURL string) {
}
// deleteSession 从 shard 中删除会话记录。
func (s *StickySession) deleteSession(targetURL string) {
shard := s.getShard(targetURL)
func (s *StickySession) deleteSession(cookieValue string) {
shard := s.getShard(cookieValue)
shard.mu.Lock()
delete(shard.entries, targetURL)
delete(shard.entries, cookieValue)
shard.mu.Unlock()
}
// getShard 根据 targetURL 选择对应的 shard。
func (s *StickySession) getShard(targetURL string) *stickyShard {
hash := fnvHash64a(targetURL)
// getShard 根据 cookieValue 选择对应的 shard。
func (s *StickySession) getShard(cookieValue string) *stickyShard {
hash := fnvHash64a(cookieValue)
return s.shards[hash%stickyShardCount]
}
// encodeStickyCookie 将目标 URL 编码为 cookie 值base64
func encodeStickyCookie(targetURL string) string {
return base64.URLEncoding.EncodeToString([]byte(targetURL))
// encodeStickyCookie 将目标 URL 和过期时间编码为 cookie 值base64
func encodeStickyCookie(targetURL string, expires time.Time) string {
raw := targetURL + "|" + strconv.FormatInt(expires.Unix(), 10)
return base64.URLEncoding.EncodeToString([]byte(raw))
}
// decodeStickyCookie 解码 cookie 值为目标 URL
func decodeStickyCookie(value string) (string, error) {
decoded, err := base64.URLEncoding.DecodeString(value)
// decodeStickyCookie 解码 cookie 值为目标 URL 和过期时间
func decodeStickyCookie(value string) (targetURL string, expires time.Time, ok bool) {
raw, err := base64.URLEncoding.DecodeString(value)
if err != nil {
return "", err
return
}
return string(decoded), nil
parts := strings.Split(string(raw), "|")
if len(parts) != 2 {
return
}
ts, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil {
return
}
return parts[0], time.Unix(ts, 0), true
}
// Ensure StickySession implements the SelectExcluding part of Balancer interface.

View File

@ -3,6 +3,7 @@ package loadbalance
import (
"sync"
"testing"
"time"
"github.com/valyala/fasthttp"
)
@ -16,6 +17,7 @@ func TestStickySession_BasicRoute(t *testing.T) {
config.Enabled = true
fallback := NewRoundRobin()
sticky := NewStickySession(config, fallback)
sticky.Start()
defer sticky.Stop()
targets := []*Target{
@ -41,6 +43,7 @@ func TestStickySession_BasicRoute(t *testing.T) {
config.Enabled = true
fallback := NewRoundRobin()
sticky := NewStickySession(config, fallback)
sticky.Start()
defer sticky.Stop()
targets := []*Target{
@ -80,6 +83,7 @@ func TestStickySession_BasicRoute(t *testing.T) {
config.Enabled = false
fallback := NewRoundRobin()
sticky := NewStickySession(config, fallback)
sticky.Start()
defer sticky.Stop()
targets := []*Target{
@ -107,6 +111,7 @@ func TestStickySession_TargetUnavailable(t *testing.T) {
config.Enabled = true
fallback := NewRoundRobin()
sticky := NewStickySession(config, fallback)
sticky.Start()
defer sticky.Stop()
targets := []*Target{
@ -155,36 +160,44 @@ func TestStickySession_CookieEncodeDecode(t *testing.T) {
t.Parallel()
t.Run("编码解码round-trip", func(_ *testing.T) {
url := "http://backend1:8080"
encoded := encodeStickyCookie(url)
expires := time.Now().Add(time.Hour)
encoded := encodeStickyCookie(url, expires)
if encoded == "" {
t.Fatal("encodeStickyCookie() 返回空字符串")
}
decoded, err := decodeStickyCookie(encoded)
if err != nil {
t.Fatalf("decodeStickyCookie() 错误: %v", err)
decodedURL, decodedExpires, ok := decodeStickyCookie(encoded)
if !ok {
t.Fatal("decodeStickyCookie() returned ok=false")
}
if decoded != url {
t.Errorf("解码后 URL = %q, want %q", decoded, url)
if decodedURL != url {
t.Errorf("解码后 URL = %q, want %q", decodedURL, url)
}
if decodedExpires.Unix() != expires.Unix() {
t.Errorf("解码后 expires = %v, want %v", decodedExpires, expires)
}
})
t.Run("空URL编码解码", func(_ *testing.T) {
encoded := encodeStickyCookie("")
decoded, err := decodeStickyCookie(encoded)
if err != nil {
t.Fatalf("decodeStickyCookie() 错误: %v", err)
expires := time.Now().Add(time.Hour)
encoded := encodeStickyCookie("", expires)
decodedURL, decodedExpires, ok := decodeStickyCookie(encoded)
if !ok {
t.Fatal("decodeStickyCookie() returned ok=false")
}
if decoded != "" {
t.Errorf("解码后 URL = %q, want 空字符串", decoded)
if decodedURL != "" {
t.Errorf("解码后 URL = %q, want 空字符串", decodedURL)
}
if decodedExpires.Unix() != expires.Unix() {
t.Errorf("解码后 expires = %v, want %v", decodedExpires, expires)
}
})
t.Run("无效编码", func(_ *testing.T) {
_, err := decodeStickyCookie("invalid-base64!!!")
if err == nil {
t.Error("decodeStickyCookie() 应返回错误")
decodedURL, decodedExpires, ok := decodeStickyCookie("invalid-base64!!!")
if ok {
t.Errorf("decodeStickyCookie() = (%q, %v, %v), want ok=false", decodedURL, decodedExpires, ok)
}
})
}
@ -197,6 +210,7 @@ func TestStickySession_Concurrent(t *testing.T) {
config.Enabled = true
fallback := NewRoundRobin()
sticky := NewStickySession(config, fallback)
sticky.Start()
defer sticky.Stop()
targets := []*Target{
@ -213,7 +227,7 @@ func TestStickySession_Concurrent(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
// 交替使用有 cookie 和没有 cookie 的请求
if idx%2 == 0 {
ctx.Request.Header.SetCookie(config.Name, encodeStickyCookie("http://backend1:8080"))
ctx.Request.Header.SetCookie(config.Name, encodeStickyCookie("http://backend1:8080", time.Now().Add(time.Hour)))
}
got := sticky.Select(ctx, targets)
if got == nil {
@ -232,6 +246,7 @@ func TestStickySession_SelectExcluding(t *testing.T) {
config.Enabled = true
fallback := NewRoundRobin()
sticky := NewStickySession(config, fallback)
sticky.Start()
defer sticky.Stop()
targets := []*Target{