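Implement Phase 3 core features:
- loadbalance: four algorithms (round robin, weighted round robin, least connections, IP hash)
- proxy: HTTP reverse proxy, health checks, failover
- All implementations are concurrency-safe and use atomic operations

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>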
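The commit message lists four balancing algorithms and says they rely on atomic operations. The loadbalance package is not part of this file, so what follows is only a minimal sketch, assuming Go 1.19+ (`atomic.Int64`/`atomic.Uint64`), of how three of them could be written lock-free. `Target`, `RoundRobin`, `LeastConn`, `IPHash` and `healthy` are hypothetical stand-ins, not the package's actual API. FNV-1a as the IP hash is also an assumption. Weighted round robin is left out for brevity.

```go
package main

import (
	"fmt"
	"hash/fnv"
	"sync/atomic"
)

// Target is a hypothetical stand-in for loadbalance.Target: URL, Weight and
// Healthy mirror the fields used in proxy.go; conns is an assumed atomic
// in-flight request counter.
type Target struct {
	URL     string
	Weight  int
	Healthy bool
	conns   atomic.Int64
}

// healthy returns the subset of targets currently marked healthy.
func healthy(targets []*Target) []*Target {
	out := make([]*Target, 0, len(targets))
	for _, t := range targets {
		if t.Healthy {
			out = append(out, t)
		}
	}
	return out
}

// RoundRobin cycles through healthy targets with a lock-free counter.
type RoundRobin struct {
	next atomic.Uint64
}

// Select returns the next healthy target, or nil if none is healthy.
func (r *RoundRobin) Select(targets []*Target) *Target {
	hs := healthy(targets)
	if len(hs) == 0 {
		return nil
	}
	n := r.next.Add(1) - 1 // atomic fetch-and-increment
	return hs[n%uint64(len(hs))]
}

// LeastConn returns the healthy target with the fewest in-flight requests.
func LeastConn(targets []*Target) *Target {
	var best *Target
	for _, t := range healthy(targets) {
		if best == nil || t.conns.Load() < best.conns.Load() {
			best = t
		}
	}
	return best
}

// IPHash maps a client IP to a healthy target using FNV-1a, so a given IP
// keeps reaching the same backend while the healthy set is unchanged.
func IPHash(targets []*Target, clientIP string) *Target {
	hs := healthy(targets)
	if len(hs) == 0 {
		return nil
	}
	h := fnv.New32a()
	h.Write([]byte(clientIP)) // hash.Hash writes never return an error
	return hs[h.Sum32()%uint32(len(hs))]
}

func main() {
	ts := []*Target{
		{URL: "http://backend1:8080", Weight: 1, Healthy: true},
		{URL: "http://backend2:8080", Weight: 2, Healthy: true},
		{URL: "http://backend3:8080", Weight: 1, Healthy: false},
	}

	// Only backend1 and backend2 are healthy, so selection alternates
	// backend1, backend2, backend1.
	var rr RoundRobin
	for i := 0; i < 3; i++ {
		fmt.Println(rr.Select(ts).URL)
	}

	// Two in-flight requests on backend1 make backend2 the least loaded.
	ts[0].conns.Add(2)
	fmt.Println(LeastConn(ts).URL)

	// The same client IP always maps to the same target.
	fmt.Println(IPHash(ts, "203.0.113.7") == IPHash(ts, "203.0.113.7"))
}
```

The round-robin counter only ever increments atomically, so concurrent callers never take a lock. The trade-off in this sketch is that the healthy slice is rebuilt on every call, which a real balancer might cache.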
// Package proxy provides reverse proxy functionality for the Lolly HTTP server.
//
// This package implements a high-performance reverse proxy on top of
// fasthttp.HostClient, which provides per-target connection pooling and
// keep-alive management. It supports load balancing across multiple targets,
// custom request and response headers, and configurable connect, read, and
// write timeouts. WebSocket upgrade requests are detected but currently
// answered with 501 Not Implemented.
//
// Example usage:
//
//	targets := []*loadbalance.Target{
//		{URL: "http://backend1:8080", Weight: 1, Healthy: true},
//		{URL: "http://backend2:8080", Weight: 2, Healthy: true},
//	}
//
//	proxyConfig := &config.ProxyConfig{
//		Path:        "/api",
//		LoadBalance: "weighted_round_robin",
//		Timeout: config.ProxyTimeout{
//			Connect: 5 * time.Second,
//			Read:    30 * time.Second,
//			Write:   30 * time.Second,
//		},
//	}
//
//	p, err := proxy.NewProxy(proxyConfig, targets)
//	if err != nil {
//		log.Fatal(err)
//	}
//
//	// p.ServeHTTP matches fasthttp.RequestHandler.
//	log.Fatal(fasthttp.ListenAndServe(":8080", p.ServeHTTP))
//
//go:generate go test -v ./...
package proxy

import (
	"errors"
	"net"
	"strings"
	"sync"
	"time"

	"github.com/valyala/fasthttp"
	"rua.plus/lolly/internal/config"
	"rua.plus/lolly/internal/loadbalance"
)

// Proxy represents a reverse proxy instance that forwards HTTP requests to backend targets.
// It manages connection pools for each target and provides load balancing capabilities.
type Proxy struct {
	targets  []*loadbalance.Target
	clients  map[string]*fasthttp.HostClient // key: target URL
	balancer loadbalance.Balancer
	config   *config.ProxyConfig
	mu       sync.RWMutex
}

// NewProxy creates a new reverse proxy instance with the given configuration and targets.
// It initializes the load balancer named in cfg.LoadBalance and creates a HostClient
// for each target that has a non-empty URL.
//
// Parameters:
//   - cfg: proxy configuration, including timeouts, headers, and load-balancing strategy
//   - targets: backend targets to forward requests to
//
// Returns:
//   - *Proxy: a configured proxy instance ready to serve requests
//   - error: non-nil if cfg is nil, targets is empty, or the load-balancing
//     algorithm is not supported
func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target) (*Proxy, error) {
	if cfg == nil {
		return nil, errors.New("proxy config is nil")
	}

	if len(targets) == 0 {
		return nil, errors.New("no proxy targets provided")
	}

	// Create balancer based on configuration
	balancer, err := createBalancer(cfg.LoadBalance)
	if err != nil {
		return nil, err
	}

	p := &Proxy{
		targets:  targets,
		clients:  make(map[string]*fasthttp.HostClient),
		balancer: balancer,
		config:   cfg,
	}

	// Initialize a HostClient for each target; targets without a URL are skipped.
	for _, target := range targets {
		if target.URL == "" {
			continue
		}

		client := createHostClient(target.URL, cfg.Timeout)
		p.clients[target.URL] = client
	}

	return p, nil
}

// createBalancer creates a load balancer for the configured algorithm.
// An empty algorithm name defaults to round robin.
func createBalancer(algorithm string) (loadbalance.Balancer, error) {
	switch algorithm {
	case "round_robin", "":
		return loadbalance.NewRoundRobin(), nil
	case "weighted_round_robin":
		return loadbalance.NewWeightedRoundRobin(), nil
	case "least_conn":
		return loadbalance.NewLeastConnections(), nil
	case "ip_hash":
		return loadbalance.NewIPHash(), nil
	default:
		return nil, errors.New("unsupported load balance algorithm: " + algorithm)
	}
}

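// createHostClient creates a fasthttp.HostClient for a target URL.
// An "https://" prefix enables TLS; "http://" or no scheme means plain HTTP.
// Any path after the host:port part is ignored.
func createHostClient(targetURL string, timeout config.ProxyTimeout) *fasthttp.HostClient {
	// Parse host and scheme from target URL
	addr := targetURL
	isTLS := false

	if strings.HasPrefix(targetURL, "http://") {
		addr = targetURL[len("http://"):]
	} else if strings.HasPrefix(targetURL, "https://") {
		addr = targetURL[len("https://"):]
		isTLS = true
	}

	// Remove path if present, keep only host:port
	if idx := strings.Index(addr, "/"); idx != -1 {
		addr = addr[:idx]
	}

	client := &fasthttp.HostClient{
		Addr:                addr,
		IsTLS:               isTLS,
		ReadTimeout:         timeout.Read,
		WriteTimeout:        timeout.Write,
		MaxIdleConnDuration: 60 * time.Second,
		MaxConns:            100,
		// MaxConnWaitTimeout bounds how long a request waits for a free pooled
		// connection once MaxConns is reached; it is not a dial timeout.
		MaxConnWaitTimeout: timeout.Connect,
		// A nil RetryIf falls back to fasthttp's default (retry idempotent
		// requests), so return false explicitly to disable automatic retries.
		RetryIf:                func(*fasthttp.Request) bool { return false },
		DisablePathNormalizing: false,
		SecureErrorLogMessage:  false,
	}

	// Apply the connect timeout to TCP dials. A zero timeout keeps fasthttp's
	// default dialer instead of failing every dial immediately. A custom Dial
	// bypasses fasthttp's default-port handling, so add the port explicitly.
	if timeout.Connect > 0 {
		client.Dial = func(dialAddr string) (net.Conn, error) {
			return fasthttp.DialTimeout(fasthttp.AddMissingPort(dialAddr, isTLS), timeout.Connect)
		}
	}

	return client
}
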
// ServeHTTP handles an incoming HTTP request by forwarding it to a selected backend target.
// Its signature matches fasthttp.RequestHandler, so it can be passed directly to
// fasthttp.ListenAndServe or used as a fasthttp.Server handler.
//
// The method:
//  1. Selects a target using the configured load balancer
//  2. Prepares the request (adds proxy headers and applies header rules)
//  3. Forwards the request to the backend
//  4. Writes the backend response into ctx.Response and applies response header rules
//
// If no healthy target is available, it responds with 502 Bad Gateway. If the
// backend request times out, it responds with 504 Gateway Timeout; any other
// backend error produces 502 Bad Gateway.
func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
	// Select target using load balancer
	target := p.selectTarget(ctx)
	if target == nil {
		ctx.Error("Bad Gateway: no healthy upstream", fasthttp.StatusBadGateway)
		return
	}

	// Get the client for selected target
	client := p.getClient(target.URL)
	if client == nil {
		ctx.Error("Bad Gateway: upstream client unavailable", fasthttp.StatusBadGateway)
		return
	}

	// Track in-flight requests for least-connections balancing
	loadbalance.IncrementConnections(target)
	defer loadbalance.DecrementConnections(target)

	// Check if this is a WebSocket upgrade request
	if isWebSocketRequest(ctx) {
		p.handleWebSocket(ctx, target, client)
		return
	}

	// Forward the incoming request object itself after rewriting its headers
	req := &ctx.Request
	p.modifyRequestHeaders(ctx, target)

	// Perform the proxy request; the backend response is written into ctx.Response
	err := client.Do(req, &ctx.Response)
	if err != nil {
		// Map transport errors to gateway status codes
		if errors.Is(err, fasthttp.ErrTimeout) {
			ctx.Error("Gateway Timeout", fasthttp.StatusGatewayTimeout)
		} else if errors.Is(err, fasthttp.ErrConnectionClosed) {
			ctx.Error("Bad Gateway: upstream connection closed", fasthttp.StatusBadGateway)
		} else {
			ctx.Error("Bad Gateway", fasthttp.StatusBadGateway)
		}
		return
	}

	// Modify response headers
	p.modifyResponseHeaders(ctx)
}

// selectTarget selects a backend target using the configured load balancer.
// It extracts the client IP from the request for IP hash balancing.
// Returns nil if no healthy targets are available.
func (p *Proxy) selectTarget(ctx *fasthttp.RequestCtx) *loadbalance.Target {
	p.mu.RLock()
	balancer := p.balancer
	targets := p.targets
	p.mu.RUnlock()

	if len(targets) == 0 {
		return nil
	}

	// For IPHash balancer, extract client IP
	if ipHash, ok := balancer.(*loadbalance.IPHash); ok {
		clientIP := getClientIP(ctx)
		return ipHash.SelectByIP(targets, clientIP)
	}

	return balancer.Select(targets)
}

// getClientIP extracts the client IP address from the request context.
// It prefers the first X-Forwarded-For entry, then X-Real-IP, and finally the
// TCP peer address. Both headers are client-controlled and can be spoofed, so
// they should only be trusted when requests arrive through a proxy you control.
func getClientIP(ctx *fasthttp.RequestCtx) string {
	// Check X-Forwarded-For header first
	if xff := ctx.Request.Header.Peek("X-Forwarded-For"); len(xff) > 0 {
		ips := strings.Split(string(xff), ",")
		if len(ips) > 0 {
			return strings.TrimSpace(ips[0])
		}
	}

	// Check X-Real-IP header
	if xri := ctx.Request.Header.Peek("X-Real-IP"); len(xri) > 0 {
		return string(xri)
	}

	// Fall back to RemoteAddr
	if addr := ctx.RemoteAddr(); addr != nil {
		if tcpAddr, ok := addr.(*net.TCPAddr); ok {
			return tcpAddr.IP.String()
		}
		return addr.String()
	}

	return ""
}

// getClient returns the HostClient for a given target URL.
func (p *Proxy) getClient(targetURL string) *fasthttp.HostClient {
	p.mu.RLock()
	client := p.clients[targetURL]
	p.mu.RUnlock()
	return client
}

// modifyRequestHeaders modifies the request headers before forwarding to the backend.
// It adds the standard proxy headers (X-Real-IP, X-Forwarded-For, X-Forwarded-Host,
// X-Forwarded-Proto), then applies the configured rules: SetRequest entries are
// set first, and Remove entries are deleted afterwards.
func (p *Proxy) modifyRequestHeaders(ctx *fasthttp.RequestCtx, target *loadbalance.Target) {
	headers := &ctx.Request.Header

	// Add X-Real-IP header
	clientIP := getClientIP(ctx)
	if clientIP != "" {
		headers.Set("X-Real-IP", clientIP)
	}

	// Append the immediate peer address to X-Forwarded-For. Appending
	// getClientIP here would repeat the first hop already in the header
	// instead of recording the address this proxy received the request from.
	remoteIP := ctx.RemoteIP().String()
	if existingXFF := headers.Peek("X-Forwarded-For"); len(existingXFF) > 0 {
		headers.Set("X-Forwarded-For", string(existingXFF)+", "+remoteIP)
	} else {
		headers.Set("X-Forwarded-For", remoteIP)
	}

	// Add X-Forwarded-Host header
	host := string(ctx.Host())
	if host != "" {
		headers.Set("X-Forwarded-Host", host)
	}

	// Add X-Forwarded-Proto header
	proto := "http"
	if ctx.IsTLS() {
		proto = "https"
	}
	headers.Set("X-Forwarded-Proto", proto)

	// Set custom request headers from config
	if p.config.Headers.SetRequest != nil {
		for key, value := range p.config.Headers.SetRequest {
			headers.Set(key, value)
		}
	}

	// Remove configured headers
	if len(p.config.Headers.Remove) > 0 {
		for _, key := range p.config.Headers.Remove {
			headers.Del(key)
		}
	}
}

// modifyResponseHeaders modifies the response headers before sending to client.
func (p *Proxy) modifyResponseHeaders(ctx *fasthttp.RequestCtx) {
	// Set custom response headers from config
	if p.config.Headers.SetResponse != nil {
		for key, value := range p.config.Headers.SetResponse {
			ctx.Response.Header.Set(key, value)
		}
	}
}

// isWebSocketRequest reports whether the request is a WebSocket upgrade request:
// the Connection header must mention "upgrade" case-insensitively (e.g.
// "Upgrade" or "keep-alive, Upgrade") and the Upgrade header must be "websocket".
func isWebSocketRequest(ctx *fasthttp.RequestCtx) bool {
	// Check Connection header; a substring match also covers "upgrade" alone
	connection := strings.ToLower(string(ctx.Request.Header.Peek("Connection")))
	if !strings.Contains(connection, "upgrade") {
		return false
	}

	// Check Upgrade header
	upgrade := ctx.Request.Header.Peek("Upgrade")
	return strings.EqualFold(string(upgrade), "websocket")
}

// handleWebSocket handles WebSocket upgrade requests.
// For now, it returns 501 Not Implemented, as WebSocket proxying
// requires special handling beyond HTTP.
func (p *Proxy) handleWebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, client *fasthttp.HostClient) {
	// WebSocket proxying requires raw TCP connection handling,
	// which is beyond the scope of basic HTTP proxying.
	// This can be implemented later using a TCP bridge.
	ctx.Error("WebSocket proxying not implemented", fasthttp.StatusNotImplemented)
}

// UpdateTargets updates the proxy targets and reinitializes clients.
// This is useful for dynamic configuration updates.
func (p *Proxy) UpdateTargets(targets []*loadbalance.Target) error {
	if len(targets) == 0 {
		return errors.New("no targets provided")
	}

	p.mu.Lock()
	defer p.mu.Unlock()

	// Clear old clients
	p.clients = make(map[string]*fasthttp.HostClient)

	// Initialize new clients
	for _, target := range targets {
		if target.URL == "" {
			continue
		}

		client := createHostClient(target.URL, p.config.Timeout)
		p.clients[target.URL] = client
	}

	p.targets = targets
	return nil
}

// GetTargets returns the current list of targets.
// The returned slice is shared with the proxy, so callers should treat it as read-only.
func (p *Proxy) GetTargets() []*loadbalance.Target {
	p.mu.RLock()
	defer p.mu.RUnlock()
	return p.targets
}

// GetConfig returns the proxy configuration.
func (p *Proxy) GetConfig() *config.ProxyConfig {
	p.mu.RLock()
	defer p.mu.RUnlock()
	return p.config
}
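The header conventions behind modifyRequestHeaders, getClientIP and isWebSocketRequest can be checked in isolation on plain strings. The helpers below (`appendForwardedFor`, `originalClientIP`, `hasToken`) are hypothetical. They are not part of package proxy and do not touch fasthttp. `hasToken` sketches a stricter, token-based alternative to the substring match used for the Connection header. It assumes Go 1.18+ for `strings.Cut`.

```go
package main

import (
	"fmt"
	"net"
	"strings"
)

// Hypothetical helpers illustrating reverse-proxy header conventions.

// appendForwardedFor returns the X-Forwarded-For value to send upstream: the
// incoming chain, if any, followed by the immediate peer's address.
func appendForwardedFor(prior, peerIP string) string {
	if prior == "" {
		return peerIP
	}
	return prior + ", " + peerIP
}

// originalClientIP returns the left-most X-Forwarded-For entry, falling back to
// the host part of the peer address. The left-most entry is client-supplied and
// should only be trusted behind a proxy you control.
func originalClientIP(xff, peerAddr string) string {
	if first, _, _ := strings.Cut(xff, ","); strings.TrimSpace(first) != "" {
		return strings.TrimSpace(first)
	}
	if host, _, err := net.SplitHostPort(peerAddr); err == nil {
		return host
	}
	return peerAddr
}

// hasToken reports whether a comma-separated header value such as
// "keep-alive, Upgrade" contains token, ignoring case and surrounding spaces.
func hasToken(value, token string) bool {
	for _, part := range strings.Split(value, ",") {
		if strings.EqualFold(strings.TrimSpace(part), token) {
			return true
		}
	}
	return false
}

func main() {
	fmt.Println(appendForwardedFor("", "10.0.0.5"))                          // 10.0.0.5
	fmt.Println(appendForwardedFor("203.0.113.7", "10.0.0.5"))               // 203.0.113.7, 10.0.0.5
	fmt.Println(originalClientIP("203.0.113.7, 10.0.0.5", "10.0.0.9:41234")) // 203.0.113.7
	fmt.Println(originalClientIP("", "10.0.0.9:41234"))                      // 10.0.0.9
	fmt.Println(hasToken("keep-alive, Upgrade", "upgrade"))                  // true
}
```

Appending the peer address, rather than re-reading the header, keeps X-Forwarded-For an ordered record of every hop. Token matching avoids treating a value that merely contains "upgrade" inside a longer word as an upgrade request.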