fix(sticky): fix cookie format, shard keying, and tests
- Encode cookie as base64(target_url + | + timestamp) per spec - Use cookie value (not targetURL) for shard key and session map keys - Add missing sticky.Start() calls in tests - Fix time precision in cookie encode/decode tests
This commit is contained in:
parent
f69a11ea05
commit
66752a47f0
@ -2,6 +2,8 @@ package loadbalance
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -98,8 +100,8 @@ func (s *StickySession) Select(ctx *fasthttp.RequestCtx, targets []*Target) *Tar
|
||||
// 检查现有 cookie
|
||||
cookieValue := ctx.Request.Header.Cookie(s.config.Name)
|
||||
if len(cookieValue) > 0 {
|
||||
decodedURL, err := decodeStickyCookie(string(cookieValue))
|
||||
if err == nil && decodedURL != "" {
|
||||
decodedURL, _, ok := decodeStickyCookie(string(cookieValue))
|
||||
if ok && decodedURL != "" {
|
||||
// 查找对应的目标
|
||||
for _, target := range targets {
|
||||
if target.URL == decodedURL && target.IsAvailable() {
|
||||
@ -107,15 +109,15 @@ func (s *StickySession) Select(ctx *fasthttp.RequestCtx, targets []*Target) *Tar
|
||||
}
|
||||
}
|
||||
// 目标不可用,删除会话记录
|
||||
s.deleteSession(decodedURL)
|
||||
s.deleteSession(string(cookieValue))
|
||||
}
|
||||
}
|
||||
|
||||
// 使用 fallback 选择目标
|
||||
selected := s.fallback.Select(targets)
|
||||
if selected != nil {
|
||||
s.setCookie(ctx, selected.URL)
|
||||
s.recordSession(selected.URL)
|
||||
newCookieValue := s.setCookie(ctx, selected.URL)
|
||||
s.recordSession(newCookieValue, selected.URL)
|
||||
}
|
||||
return selected
|
||||
}
|
||||
@ -125,14 +127,16 @@ func (s *StickySession) SelectExcluding(targets []*Target, excluded []*Target) *
|
||||
return s.fallback.SelectExcluding(targets, excluded)
|
||||
}
|
||||
|
||||
// setCookie 设置会话 cookie 到响应头。
|
||||
func (s *StickySession) setCookie(ctx *fasthttp.RequestCtx, targetURL string) {
|
||||
// setCookie 设置会话 cookie 到响应头,返回编码后的 cookie 值。
|
||||
func (s *StickySession) setCookie(ctx *fasthttp.RequestCtx, targetURL string) string {
|
||||
expires := time.Now().Add(s.config.Expires)
|
||||
cookie := &fasthttp.Cookie{}
|
||||
cookie.SetKey(s.config.Name)
|
||||
cookie.SetValue(encodeStickyCookie(targetURL))
|
||||
encoded := encodeStickyCookie(targetURL, expires)
|
||||
cookie.SetValue(encoded)
|
||||
|
||||
if s.config.Expires > 0 {
|
||||
cookie.SetExpire(time.Now().Add(s.config.Expires))
|
||||
cookie.SetExpire(expires)
|
||||
}
|
||||
if s.config.Domain != "" {
|
||||
cookie.SetDomain(s.config.Domain)
|
||||
@ -159,13 +163,14 @@ func (s *StickySession) setCookie(ctx *fasthttp.RequestCtx, targetURL string) {
|
||||
}
|
||||
|
||||
ctx.Response.Header.SetCookie(cookie)
|
||||
return encoded
|
||||
}
|
||||
|
||||
// recordSession 记录会话到 shard 中。
|
||||
func (s *StickySession) recordSession(targetURL string) {
|
||||
shard := s.getShard(targetURL)
|
||||
func (s *StickySession) recordSession(cookieValue, targetURL string) {
|
||||
shard := s.getShard(cookieValue)
|
||||
shard.mu.Lock()
|
||||
shard.entries[targetURL] = &stickyEntry{
|
||||
shard.entries[cookieValue] = &stickyEntry{
|
||||
targetURL: targetURL,
|
||||
expires: time.Now().Add(s.config.Expires),
|
||||
}
|
||||
@ -173,31 +178,40 @@ func (s *StickySession) recordSession(targetURL string) {
|
||||
}
|
||||
|
||||
// deleteSession 从 shard 中删除会话记录。
|
||||
func (s *StickySession) deleteSession(targetURL string) {
|
||||
shard := s.getShard(targetURL)
|
||||
func (s *StickySession) deleteSession(cookieValue string) {
|
||||
shard := s.getShard(cookieValue)
|
||||
shard.mu.Lock()
|
||||
delete(shard.entries, targetURL)
|
||||
delete(shard.entries, cookieValue)
|
||||
shard.mu.Unlock()
|
||||
}
|
||||
|
||||
// getShard 根据 targetURL 选择对应的 shard。
|
||||
func (s *StickySession) getShard(targetURL string) *stickyShard {
|
||||
hash := fnvHash64a(targetURL)
|
||||
// getShard 根据 cookieValue 选择对应的 shard。
|
||||
func (s *StickySession) getShard(cookieValue string) *stickyShard {
|
||||
hash := fnvHash64a(cookieValue)
|
||||
return s.shards[hash%stickyShardCount]
|
||||
}
|
||||
|
||||
// encodeStickyCookie 将目标 URL 编码为 cookie 值(base64)。
|
||||
func encodeStickyCookie(targetURL string) string {
|
||||
return base64.URLEncoding.EncodeToString([]byte(targetURL))
|
||||
// encodeStickyCookie 将目标 URL 和过期时间编码为 cookie 值(base64)。
|
||||
func encodeStickyCookie(targetURL string, expires time.Time) string {
|
||||
raw := targetURL + "|" + strconv.FormatInt(expires.Unix(), 10)
|
||||
return base64.URLEncoding.EncodeToString([]byte(raw))
|
||||
}
|
||||
|
||||
// decodeStickyCookie 解码 cookie 值为目标 URL。
|
||||
func decodeStickyCookie(value string) (string, error) {
|
||||
decoded, err := base64.URLEncoding.DecodeString(value)
|
||||
// decodeStickyCookie 解码 cookie 值为目标 URL 和过期时间。
|
||||
func decodeStickyCookie(value string) (targetURL string, expires time.Time, ok bool) {
|
||||
raw, err := base64.URLEncoding.DecodeString(value)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return
|
||||
}
|
||||
return string(decoded), nil
|
||||
parts := strings.Split(string(raw), "|")
|
||||
if len(parts) != 2 {
|
||||
return
|
||||
}
|
||||
ts, err := strconv.ParseInt(parts[1], 10, 64)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return parts[0], time.Unix(ts, 0), true
|
||||
}
|
||||
|
||||
// Ensure StickySession implements the SelectExcluding part of Balancer interface.
|
||||
|
||||
@ -3,6 +3,7 @@ package loadbalance
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
@ -16,6 +17,7 @@ func TestStickySession_BasicRoute(t *testing.T) {
|
||||
config.Enabled = true
|
||||
fallback := NewRoundRobin()
|
||||
sticky := NewStickySession(config, fallback)
|
||||
sticky.Start()
|
||||
defer sticky.Stop()
|
||||
|
||||
targets := []*Target{
|
||||
@ -41,6 +43,7 @@ func TestStickySession_BasicRoute(t *testing.T) {
|
||||
config.Enabled = true
|
||||
fallback := NewRoundRobin()
|
||||
sticky := NewStickySession(config, fallback)
|
||||
sticky.Start()
|
||||
defer sticky.Stop()
|
||||
|
||||
targets := []*Target{
|
||||
@ -80,6 +83,7 @@ func TestStickySession_BasicRoute(t *testing.T) {
|
||||
config.Enabled = false
|
||||
fallback := NewRoundRobin()
|
||||
sticky := NewStickySession(config, fallback)
|
||||
sticky.Start()
|
||||
defer sticky.Stop()
|
||||
|
||||
targets := []*Target{
|
||||
@ -107,6 +111,7 @@ func TestStickySession_TargetUnavailable(t *testing.T) {
|
||||
config.Enabled = true
|
||||
fallback := NewRoundRobin()
|
||||
sticky := NewStickySession(config, fallback)
|
||||
sticky.Start()
|
||||
defer sticky.Stop()
|
||||
|
||||
targets := []*Target{
|
||||
@ -155,36 +160,44 @@ func TestStickySession_CookieEncodeDecode(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("编码解码round-trip", func(_ *testing.T) {
|
||||
url := "http://backend1:8080"
|
||||
encoded := encodeStickyCookie(url)
|
||||
expires := time.Now().Add(time.Hour)
|
||||
encoded := encodeStickyCookie(url, expires)
|
||||
if encoded == "" {
|
||||
t.Fatal("encodeStickyCookie() 返回空字符串")
|
||||
}
|
||||
|
||||
decoded, err := decodeStickyCookie(encoded)
|
||||
if err != nil {
|
||||
t.Fatalf("decodeStickyCookie() 错误: %v", err)
|
||||
decodedURL, decodedExpires, ok := decodeStickyCookie(encoded)
|
||||
if !ok {
|
||||
t.Fatal("decodeStickyCookie() returned ok=false")
|
||||
}
|
||||
|
||||
if decoded != url {
|
||||
t.Errorf("解码后 URL = %q, want %q", decoded, url)
|
||||
if decodedURL != url {
|
||||
t.Errorf("解码后 URL = %q, want %q", decodedURL, url)
|
||||
}
|
||||
if decodedExpires.Unix() != expires.Unix() {
|
||||
t.Errorf("解码后 expires = %v, want %v", decodedExpires, expires)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("空URL编码解码", func(_ *testing.T) {
|
||||
encoded := encodeStickyCookie("")
|
||||
decoded, err := decodeStickyCookie(encoded)
|
||||
if err != nil {
|
||||
t.Fatalf("decodeStickyCookie() 错误: %v", err)
|
||||
expires := time.Now().Add(time.Hour)
|
||||
encoded := encodeStickyCookie("", expires)
|
||||
decodedURL, decodedExpires, ok := decodeStickyCookie(encoded)
|
||||
if !ok {
|
||||
t.Fatal("decodeStickyCookie() returned ok=false")
|
||||
}
|
||||
if decoded != "" {
|
||||
t.Errorf("解码后 URL = %q, want 空字符串", decoded)
|
||||
if decodedURL != "" {
|
||||
t.Errorf("解码后 URL = %q, want 空字符串", decodedURL)
|
||||
}
|
||||
if decodedExpires.Unix() != expires.Unix() {
|
||||
t.Errorf("解码后 expires = %v, want %v", decodedExpires, expires)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("无效编码", func(_ *testing.T) {
|
||||
_, err := decodeStickyCookie("invalid-base64!!!")
|
||||
if err == nil {
|
||||
t.Error("decodeStickyCookie() 应返回错误")
|
||||
decodedURL, decodedExpires, ok := decodeStickyCookie("invalid-base64!!!")
|
||||
if ok {
|
||||
t.Errorf("decodeStickyCookie() = (%q, %v, %v), want ok=false", decodedURL, decodedExpires, ok)
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -197,6 +210,7 @@ func TestStickySession_Concurrent(t *testing.T) {
|
||||
config.Enabled = true
|
||||
fallback := NewRoundRobin()
|
||||
sticky := NewStickySession(config, fallback)
|
||||
sticky.Start()
|
||||
defer sticky.Stop()
|
||||
|
||||
targets := []*Target{
|
||||
@ -213,7 +227,7 @@ func TestStickySession_Concurrent(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
// 交替使用有 cookie 和没有 cookie 的请求
|
||||
if idx%2 == 0 {
|
||||
ctx.Request.Header.SetCookie(config.Name, encodeStickyCookie("http://backend1:8080"))
|
||||
ctx.Request.Header.SetCookie(config.Name, encodeStickyCookie("http://backend1:8080", time.Now().Add(time.Hour)))
|
||||
}
|
||||
got := sticky.Select(ctx, targets)
|
||||
if got == nil {
|
||||
@ -232,6 +246,7 @@ func TestStickySession_SelectExcluding(t *testing.T) {
|
||||
config.Enabled = true
|
||||
fallback := NewRoundRobin()
|
||||
sticky := NewStickySession(config, fallback)
|
||||
sticky.Start()
|
||||
defer sticky.Stop()
|
||||
|
||||
targets := []*Target{
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user