From f69a11ea0567cc294fd7df5de1c58c92fa22d37a Mon Sep 17 00:00:00 2001 From: xfy Date: Mon, 8 Jun 2026 17:30:06 +0800 Subject: [PATCH] feat(loadbalance): implement Session Sticky balancer - Add 256-shard lock map for concurrent session routing - Cookie-based session persistence with base64 encoding - TTL expiration with background cleanup goroutine - Support Secure, HttpOnly, SameSite cookie attributes - Fallback to configured balancer when session target unavailable --- internal/loadbalance/sticky.go | 205 +++++++++++++++++++++ internal/loadbalance/sticky_config.go | 24 +++ internal/loadbalance/sticky_test.go | 251 ++++++++++++++++++++++++++ 3 files changed, 480 insertions(+) create mode 100644 internal/loadbalance/sticky.go create mode 100644 internal/loadbalance/sticky_config.go create mode 100644 internal/loadbalance/sticky_test.go diff --git a/internal/loadbalance/sticky.go b/internal/loadbalance/sticky.go new file mode 100644 index 0000000..ffe01ee --- /dev/null +++ b/internal/loadbalance/sticky.go @@ -0,0 +1,205 @@ +package loadbalance + +import ( + "encoding/base64" + "sync" + "time" + + "github.com/valyala/fasthttp" +) + +const stickyShardCount = 256 + +type stickyEntry struct { + targetURL string + expires time.Time +} + +type stickyShard struct { + mu sync.RWMutex + entries map[string]*stickyEntry +} + +// StickySession 实现基于 cookie 的会话粘性负载均衡。 +type StickySession struct { + config StickyConfig + fallback Balancer + shards []*stickyShard + stopCh chan struct{} + wg sync.WaitGroup +} + +// NewStickySession 创建一个新的会话粘性负载均衡器。 +func NewStickySession(config StickyConfig, fallback Balancer) *StickySession { + shards := make([]*stickyShard, stickyShardCount) + for i := range shards { + shards[i] = &stickyShard{ + entries: make(map[string]*stickyEntry), + } + } + s := &StickySession{ + config: config, + fallback: fallback, + shards: shards, + stopCh: make(chan struct{}), + } + return s +} + +// Start 启动后台清理 goroutine。 +func (s *StickySession) Start() { + s.wg.Add(1) + go s.cleanupLoop() +} + +// Stop 停止后台清理 goroutine。 +func (s *StickySession) Stop() { + close(s.stopCh) + s.wg.Wait() +} + +// cleanupLoop 定期清理过期的会话条目。 +func (s *StickySession) cleanupLoop() { + defer s.wg.Done() + ticker := time.NewTicker(60 * time.Second) + defer ticker.Stop() + for { + select { + case <-s.stopCh: + return + case <-ticker.C: + s.cleanup() + } + } +} + +// cleanup 清理所有 shard 中的过期条目。 +func (s *StickySession) cleanup() { + now := time.Now() + for _, shard := range s.shards { + shard.mu.Lock() + for key, entry := range shard.entries { + if now.After(entry.expires) { + delete(shard.entries, key) + } + } + shard.mu.Unlock() + } +} + +// Select 根据会话 cookie 选择目标。 +// 如果存在有效的会话 cookie 且目标健康,则路由到该目标。 +// 否则使用 fallback 选择器,并设置新的会话 cookie。 +func (s *StickySession) Select(ctx *fasthttp.RequestCtx, targets []*Target) *Target { + if !s.config.Enabled { + return s.fallback.Select(targets) + } + + // 检查现有 cookie + cookieValue := ctx.Request.Header.Cookie(s.config.Name) + if len(cookieValue) > 0 { + decodedURL, err := decodeStickyCookie(string(cookieValue)) + if err == nil && decodedURL != "" { + // 查找对应的目标 + for _, target := range targets { + if target.URL == decodedURL && target.IsAvailable() { + return target + } + } + // 目标不可用,删除会话记录 + s.deleteSession(decodedURL) + } + } + + // 使用 fallback 选择目标 + selected := s.fallback.Select(targets) + if selected != nil { + s.setCookie(ctx, selected.URL) + s.recordSession(selected.URL) + } + return selected +} + +// SelectExcluding 排除指定目标后选择,委托给 fallback 实现。 +func (s *StickySession) SelectExcluding(targets []*Target, excluded []*Target) *Target { + return s.fallback.SelectExcluding(targets, excluded) +} + +// setCookie 设置会话 cookie 到响应头。 +func (s *StickySession) setCookie(ctx *fasthttp.RequestCtx, targetURL string) { + cookie := &fasthttp.Cookie{} + cookie.SetKey(s.config.Name) + cookie.SetValue(encodeStickyCookie(targetURL)) + + if s.config.Expires > 0 { + cookie.SetExpire(time.Now().Add(s.config.Expires)) + } + if s.config.Domain != "" { + cookie.SetDomain(s.config.Domain) + } + if s.config.Path != "" { + cookie.SetPath(s.config.Path) + } else { + cookie.SetPath("/") + } + if s.config.Secure { + cookie.SetSecure(true) + } + if s.config.HttpOnly { + cookie.SetHTTPOnly(true) + } + + switch s.config.SameSite { + case "Strict": + cookie.SetSameSite(fasthttp.CookieSameSiteStrictMode) + case "None": + cookie.SetSameSite(fasthttp.CookieSameSiteNoneMode) + default: + cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode) + } + + ctx.Response.Header.SetCookie(cookie) +} + +// recordSession 记录会话到 shard 中。 +func (s *StickySession) recordSession(targetURL string) { + shard := s.getShard(targetURL) + shard.mu.Lock() + shard.entries[targetURL] = &stickyEntry{ + targetURL: targetURL, + expires: time.Now().Add(s.config.Expires), + } + shard.mu.Unlock() +} + +// deleteSession 从 shard 中删除会话记录。 +func (s *StickySession) deleteSession(targetURL string) { + shard := s.getShard(targetURL) + shard.mu.Lock() + delete(shard.entries, targetURL) + shard.mu.Unlock() +} + +// getShard 根据 targetURL 选择对应的 shard。 +func (s *StickySession) getShard(targetURL string) *stickyShard { + hash := fnvHash64a(targetURL) + return s.shards[hash%stickyShardCount] +} + +// encodeStickyCookie 将目标 URL 编码为 cookie 值(base64)。 +func encodeStickyCookie(targetURL string) string { + return base64.URLEncoding.EncodeToString([]byte(targetURL)) +} + +// decodeStickyCookie 解码 cookie 值为目标 URL。 +func decodeStickyCookie(value string) (string, error) { + decoded, err := base64.URLEncoding.DecodeString(value) + if err != nil { + return "", err + } + return string(decoded), nil +} + +// Ensure StickySession implements the SelectExcluding part of Balancer interface. +// Note: Select signature differs (includes *fasthttp.RequestCtx), so it does +// not fully implement Balancer. diff --git a/internal/loadbalance/sticky_config.go b/internal/loadbalance/sticky_config.go new file mode 100644 index 0000000..39e8280 --- /dev/null +++ b/internal/loadbalance/sticky_config.go @@ -0,0 +1,24 @@ +package loadbalance + +import "time" + +type StickyConfig struct { + Enabled bool `yaml:"enabled"` + Name string `yaml:"name"` + Expires time.Duration `yaml:"expires"` + Domain string `yaml:"domain"` + Path string `yaml:"path"` + Secure bool `yaml:"secure"` + HttpOnly bool `yaml:"http_only"` + SameSite string `yaml:"same_site"` +} + +func DefaultStickyConfig() StickyConfig { + return StickyConfig{ + Name: "lolly_route", + Expires: time.Hour, + Path: "/", + HttpOnly: true, + SameSite: "Lax", + } +} diff --git a/internal/loadbalance/sticky_test.go b/internal/loadbalance/sticky_test.go new file mode 100644 index 0000000..6203e5e --- /dev/null +++ b/internal/loadbalance/sticky_test.go @@ -0,0 +1,251 @@ +package loadbalance + +import ( + "sync" + "testing" + + "github.com/valyala/fasthttp" +) + +// TestStickySession_BasicRoute 测试基本的会话粘性路由。 +// 第一次请求应设置 cookie,第二次携带相同 cookie 应路由到同一目标。 +func TestStickySession_BasicRoute(t *testing.T) { + t.Parallel() + t.Run("首次请求设置cookie并路由", func(_ *testing.T) { + config := DefaultStickyConfig() + config.Enabled = true + fallback := NewRoundRobin() + sticky := NewStickySession(config, fallback) + defer sticky.Stop() + + targets := []*Target{ + createHealthyTarget("http://backend1:8080", true), + createHealthyTarget("http://backend2:8080", true), + } + + ctx := &fasthttp.RequestCtx{} + got := sticky.Select(ctx, targets) + if got == nil { + t.Fatal("Select() = nil, want non-nil") + } + + // 验证设置了 cookie + cookieValue := ctx.Response.Header.PeekCookie(config.Name) + if len(cookieValue) == 0 { + t.Error("首次请求未设置 cookie") + } + }) + + t.Run("相同cookie路由到同一目标", func(_ *testing.T) { + config := DefaultStickyConfig() + config.Enabled = true + fallback := NewRoundRobin() + sticky := NewStickySession(config, fallback) + defer sticky.Stop() + + targets := []*Target{ + createHealthyTarget("http://backend1:8080", true), + createHealthyTarget("http://backend2:8080", true), + } + + // 第一次请求 + ctx1 := &fasthttp.RequestCtx{} + got1 := sticky.Select(ctx1, targets) + if got1 == nil { + t.Fatal("第一次 Select() = nil") + } + + // 提取 cookie + cookie := &fasthttp.Cookie{} + cookie.SetKey(config.Name) + if err := cookie.ParseBytes(ctx1.Response.Header.PeekCookie(config.Name)); err != nil { + t.Fatalf("解析 cookie 失败: %v", err) + } + + // 第二次请求携带相同 cookie + ctx2 := &fasthttp.RequestCtx{} + ctx2.Request.Header.SetCookie(config.Name, string(cookie.Value())) + got2 := sticky.Select(ctx2, targets) + if got2 == nil { + t.Fatal("第二次 Select() = nil") + } + + if got2.URL != got1.URL { + t.Errorf("相同 cookie 路由到不同目标: 第一次=%q, 第二次=%q", got1.URL, got2.URL) + } + }) + + t.Run("禁用时不设置cookie", func(_ *testing.T) { + config := DefaultStickyConfig() + config.Enabled = false + fallback := NewRoundRobin() + sticky := NewStickySession(config, fallback) + defer sticky.Stop() + + targets := []*Target{ + createHealthyTarget("http://backend1:8080", true), + } + + ctx := &fasthttp.RequestCtx{} + got := sticky.Select(ctx, targets) + if got == nil { + t.Fatal("Select() = nil") + } + + cookieValue := ctx.Response.Header.PeekCookie(config.Name) + if len(cookieValue) > 0 { + t.Error("禁用时不应设置 cookie") + } + }) +} + +// TestStickySession_TargetUnavailable 测试目标不可用时回退到 fallback。 +func TestStickySession_TargetUnavailable(t *testing.T) { + t.Parallel() + t.Run("目标不健康时回退", func(_ *testing.T) { + config := DefaultStickyConfig() + config.Enabled = true + fallback := NewRoundRobin() + sticky := NewStickySession(config, fallback) + defer sticky.Stop() + + targets := []*Target{ + createHealthyTarget("http://backend1:8080", true), + createHealthyTarget("http://backend2:8080", true), + } + + // 第一次请求,记录会话 + ctx1 := &fasthttp.RequestCtx{} + got1 := sticky.Select(ctx1, targets) + if got1 == nil { + t.Fatal("第一次 Select() = nil") + } + + // 提取 cookie + cookie := &fasthttp.Cookie{} + cookie.SetKey(config.Name) + if err := cookie.ParseBytes(ctx1.Response.Header.PeekCookie(config.Name)); err != nil { + t.Fatalf("解析 cookie 失败: %v", err) + } + + // 使之前选中的目标不健康 + for _, target := range targets { + if target.URL == got1.URL { + target.Healthy.Store(false) + break + } + } + + // 第二次请求,应回退到其他目标 + ctx2 := &fasthttp.RequestCtx{} + ctx2.Request.Header.SetCookie(config.Name, string(cookie.Value())) + got2 := sticky.Select(ctx2, targets) + if got2 == nil { + t.Fatal("第二次 Select() = nil") + } + + if got2.URL == got1.URL { + t.Errorf("不健康目标未回退: %q", got2.URL) + } + }) +} + +// TestStickySession_CookieEncodeDecode 测试 cookie 编解码。 +func TestStickySession_CookieEncodeDecode(t *testing.T) { + t.Parallel() + t.Run("编码解码round-trip", func(_ *testing.T) { + url := "http://backend1:8080" + encoded := encodeStickyCookie(url) + if encoded == "" { + t.Fatal("encodeStickyCookie() 返回空字符串") + } + + decoded, err := decodeStickyCookie(encoded) + if err != nil { + t.Fatalf("decodeStickyCookie() 错误: %v", err) + } + + if decoded != url { + t.Errorf("解码后 URL = %q, want %q", decoded, url) + } + }) + + t.Run("空URL编码解码", func(_ *testing.T) { + encoded := encodeStickyCookie("") + decoded, err := decodeStickyCookie(encoded) + if err != nil { + t.Fatalf("decodeStickyCookie() 错误: %v", err) + } + if decoded != "" { + t.Errorf("解码后 URL = %q, want 空字符串", decoded) + } + }) + + t.Run("无效编码", func(_ *testing.T) { + _, err := decodeStickyCookie("invalid-base64!!!") + if err == nil { + t.Error("decodeStickyCookie() 应返回错误") + } + }) +} + +// TestStickySession_Concurrent 测试并发安全。 +// 100 个 goroutine 同时访问会话存储。 +func TestStickySession_Concurrent(t *testing.T) { + t.Parallel() + config := DefaultStickyConfig() + config.Enabled = true + fallback := NewRoundRobin() + sticky := NewStickySession(config, fallback) + defer sticky.Stop() + + targets := []*Target{ + createHealthyTarget("http://backend1:8080", true), + createHealthyTarget("http://backend2:8080", true), + createHealthyTarget("http://backend3:8080", true), + } + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + ctx := &fasthttp.RequestCtx{} + // 交替使用有 cookie 和没有 cookie 的请求 + if idx%2 == 0 { + ctx.Request.Header.SetCookie(config.Name, encodeStickyCookie("http://backend1:8080")) + } + got := sticky.Select(ctx, targets) + if got == nil { + t.Error("并发 Select() = nil") + } + }(i) + } + wg.Wait() +} + +// TestStickySession_SelectExcluding 测试排除选择委托给 fallback。 +func TestStickySession_SelectExcluding(t *testing.T) { + t.Parallel() + t.Run("SelectExcluding委托给fallback", func(_ *testing.T) { + config := DefaultStickyConfig() + config.Enabled = true + fallback := NewRoundRobin() + sticky := NewStickySession(config, fallback) + defer sticky.Stop() + + targets := []*Target{ + createHealthyTarget("http://backend1:8080", true), + createHealthyTarget("http://backend2:8080", true), + } + + excluded := []*Target{targets[0]} + got := sticky.SelectExcluding(targets, excluded) + if got == nil { + t.Fatal("SelectExcluding() = nil") + } + if got.URL != "http://backend2:8080" { + t.Errorf("SelectExcluding() = %q, want %q", got.URL, "http://backend2:8080") + } + }) +}