Files
sub2api-cn-relay-manager/internal/routing/sticky_memory.go
2026-05-29 07:43:29 +08:00

176 lines
4.2 KiB
Go

package routing
import (
"context"
"strings"
"sync"
"time"
)
// In-memory map values used by InMemoryStickyStore. Each pairs the stored
// payload with the wall-clock instant after which it is considered stale.
// The store computes expiresAt as now+ttl on write; a zero expiresAt is
// treated as non-expiring by the store's expired helper.
type (
	// memoryStickyEntry holds one sticky binding and its deadline.
	memoryStickyEntry struct {
		binding   StickyBinding
		expiresAt time.Time
	}

	// memoryRouteFailureEntry holds one route's failure state and its deadline.
	memoryRouteFailureEntry struct {
		state     RouteFailureState
		expiresAt time.Time
	}

	// memoryRouteCooldownEntry holds one route's cooldown state and its deadline.
	memoryRouteCooldownEntry struct {
		state     RouteCooldownState
		expiresAt time.Time
	}
)
// InMemoryStickyStore keeps sticky bindings, route failure states and route
// cooldowns in process memory, each with a TTL. Expired entries are evicted
// lazily, when a lookup finds them stale. Every method takes mu before
// touching the maps, so a single store can be shared between goroutines.
// Construct it with NewInMemoryStickyStore; the zero value has nil maps and
// no clock.
type InMemoryStickyStore struct {
	// now is the clock used for TTL arithmetic (time.Now by default).
	now func() time.Time

	// mu guards the three maps below. Lookups mutate them (lazy eviction),
	// so every method takes the exclusive lock.
	mu sync.RWMutex

	bindings     map[string]memoryStickyEntry        // sticky key -> binding
	routeFailure map[string]memoryRouteFailureEntry  // route ID -> failure state
	cooldowns    map[string]memoryRouteCooldownEntry // route ID -> cooldown state
}
// NewInMemoryStickyStore returns an empty store that uses time.Now as its
// clock and has all of its maps allocated, ready for use.
func NewInMemoryStickyStore() *InMemoryStickyStore {
	store := new(InMemoryStickyStore)
	store.now = time.Now
	store.bindings = map[string]memoryStickyEntry{}
	store.routeFailure = map[string]memoryRouteFailureEntry{}
	store.cooldowns = map[string]memoryRouteCooldownEntry{}
	return store
}
// Get looks up the sticky binding stored under key, after trimming
// surrounding whitespace from key. A blank key never matches. An entry whose
// deadline has passed is deleted and reported as absent. The returned error
// is always nil; the context is ignored.
func (s *InMemoryStickyStore) Get(_ context.Context, key string) (StickyBinding, bool, error) {
	var none StickyBinding
	k := strings.TrimSpace(key)
	if k == "" {
		return none, false, nil
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	if entry, found := s.bindings[k]; found {
		if !s.expired(entry.expiresAt) {
			return entry.binding, true, nil
		}
		// Lazy eviction: drop the stale entry now that we have seen it.
		delete(s.bindings, k)
	}
	return none, false, nil
}
// Set stores binding under key (whitespace-trimmed) for ttl.
//
// The binding is first passed through normalizeStickyBinding together with
// ttl and the current time; any error from it is returned unchanged. That
// validation runs even for a blank key, but a blank key is then ignored and
// nil is returned. The context is ignored.
func (s *InMemoryStickyStore) Set(_ context.Context, key string, binding StickyBinding, ttl time.Duration) error {
	key = strings.TrimSpace(key)
	// Read the clock exactly once. The normalized binding and the entry's
	// deadline are then derived from the same instant, instead of two
	// slightly different ones.
	now := s.now()
	binding, err := normalizeStickyBinding(binding, ttl, now)
	if err != nil {
		return err
	}
	if key == "" {
		return nil
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.bindings[key] = memoryStickyEntry{binding: binding, expiresAt: now.UTC().Add(ttl)}
	return nil
}
// Delete removes the sticky binding stored under key (whitespace-trimmed).
// Deleting a blank or unknown key is a no-op. It always returns nil; the
// context is ignored.
func (s *InMemoryStickyStore) Delete(_ context.Context, key string) error {
	if k := strings.TrimSpace(key); k != "" {
		s.mu.Lock()
		defer s.mu.Unlock()
		delete(s.bindings, k)
	}
	return nil
}
// GetRouteFailure returns the failure state recorded for routeID
// (whitespace-trimmed). A blank ID never matches. A stale entry is deleted
// and reported as absent. The error is always nil; the context is ignored.
func (s *InMemoryStickyStore) GetRouteFailure(_ context.Context, routeID string) (RouteFailureState, bool, error) {
	id := strings.TrimSpace(routeID)
	if id == "" {
		return RouteFailureState{}, false, nil
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	entry, present := s.routeFailure[id]
	switch {
	case !present:
		return RouteFailureState{}, false, nil
	case s.expired(entry.expiresAt):
		delete(s.routeFailure, id)
		return RouteFailureState{}, false, nil
	default:
		return entry.state, true, nil
	}
}
// SetRouteFailure records state for routeID with the given ttl.
//
// routeID, state, ttl and the current time are passed to
// normalizeRouteFailureState, and its error, if any, is returned unchanged.
// The entry is stored under the RouteID of the normalized state. That is
// presumably the trimmed routeID; confirm against normalizeRouteFailureState.
// The context is ignored.
func (s *InMemoryStickyStore) SetRouteFailure(_ context.Context, routeID string, state RouteFailureState, ttl time.Duration) error {
	// Single clock read, so the normalized state and the entry deadline agree.
	now := s.now()
	state, err := normalizeRouteFailureState(routeID, state, ttl, now)
	if err != nil {
		return err
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.routeFailure[state.RouteID] = memoryRouteFailureEntry{state: state, expiresAt: now.UTC().Add(ttl)}
	return nil
}
// ClearRouteFailure forgets any failure state recorded for routeID
// (whitespace-trimmed). Blank or unknown IDs are ignored. It always returns
// nil; the context is ignored.
func (s *InMemoryStickyStore) ClearRouteFailure(_ context.Context, routeID string) error {
	id := strings.TrimSpace(routeID)
	if id != "" {
		s.mu.Lock()
		defer s.mu.Unlock()
		delete(s.routeFailure, id)
	}
	return nil
}
// GetCooldown returns the cooldown state recorded for routeID
// (whitespace-trimmed). A blank ID never matches. A stale entry is deleted
// and reported as absent. The error is always nil; the context is ignored.
func (s *InMemoryStickyStore) GetCooldown(_ context.Context, routeID string) (RouteCooldownState, bool, error) {
	var empty RouteCooldownState
	id := strings.TrimSpace(routeID)
	if id == "" {
		return empty, false, nil
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	entry, present := s.cooldowns[id]
	if !present {
		return empty, false, nil
	}
	if stale := s.expired(entry.expiresAt); stale {
		// Evict lazily so expired cooldowns do not linger once observed.
		delete(s.cooldowns, id)
		return empty, false, nil
	}
	return entry.state, true, nil
}
// SetCooldown records a cooldown state for routeID with the given ttl.
//
// routeID, state, ttl and the current time are passed to
// normalizeRouteCooldownState, and its error, if any, is returned unchanged.
// The entry is stored under the RouteID of the normalized state. That is
// presumably the trimmed routeID; confirm against normalizeRouteCooldownState.
// The context is ignored.
func (s *InMemoryStickyStore) SetCooldown(_ context.Context, routeID string, state RouteCooldownState, ttl time.Duration) error {
	// Single clock read, so the normalized state and the entry deadline agree.
	now := s.now()
	state, err := normalizeRouteCooldownState(routeID, state, ttl, now)
	if err != nil {
		return err
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.cooldowns[state.RouteID] = memoryRouteCooldownEntry{state: state, expiresAt: now.UTC().Add(ttl)}
	return nil
}
// ClearCooldown removes any cooldown recorded for routeID (whitespace-
// trimmed). Blank or unknown IDs are ignored. It always returns nil; the
// context is ignored.
func (s *InMemoryStickyStore) ClearCooldown(_ context.Context, routeID string) error {
	if id := strings.TrimSpace(routeID); id != "" {
		s.mu.Lock()
		defer s.mu.Unlock()
		delete(s.cooldowns, id)
	}
	return nil
}
// expired reports whether a deadline has been reached according to the
// store's clock. A zero deadline never expires. A deadline equal to the
// current instant counts as expired.
func (s *InMemoryStickyStore) expired(expiresAt time.Time) bool {
	if expiresAt.IsZero() {
		return false
	}
	current := s.now().UTC()
	return !current.Before(expiresAt)
}