Files
sub2api-cn-relay-manager/internal/routing/sticky_test.go
2026-05-29 07:43:29 +08:00

461 lines
12 KiB
Go

package routing
import (
"bufio"
"context"
"fmt"
"io"
"net"
"path/filepath"
"strconv"
"strings"
"sync"
"testing"
"time"
)
// TestBuildStickyKey checks the key layout produced for every supported sticky
// scope and that an unrecognised scope is rejected with an error.
func TestBuildStickyKey(t *testing.T) {
	t.Parallel()
	cases := []struct {
		name    string
		scope   string
		want    string
		wantErr bool
	}{
		{name: "conversation", scope: StickyScopeConversation, want: "lg:gpt-shared:m:gpt-5.4:conv:conversation-1"},
		{name: "session", scope: StickyScopeSession, want: "lg:gpt-shared:m:gpt-5.4:sess:session-1"},
		{name: "user", scope: StickyScopeUser, want: "lg:gpt-shared:m:gpt-5.4:user:user-1"},
		{name: "invalid", scope: "bad", wantErr: true},
	}
	for _, tc := range cases {
		tc := tc // per-iteration copy for the parallel subtest (needed before Go 1.22)
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			subject := tc.name + "-1"
			got, err := BuildStickyKey(tc.scope, "gpt-shared", "gpt-5.4", subject)
			switch {
			case tc.wantErr && err == nil:
				t.Fatal("BuildStickyKey() error = nil, want error")
			case tc.wantErr:
				return
			case err != nil:
				t.Fatalf("BuildStickyKey() error = %v", err)
			case got != tc.want:
				t.Fatalf("BuildStickyKey() = %q, want %q", got, tc.want)
			}
		})
	}
}
// TestInMemoryStickyStoreBindingFailureAndCooldown exercises the full
// lifecycle of each record kind held by the in-memory store:
//   - a sticky binding: Set, Get, Delete, then Get reports it missing;
//   - a route-failure record: SetRouteFailure, GetRouteFailure,
//     ClearRouteFailure, then GetRouteFailure reports it missing;
//   - a cooldown: SetCooldown, GetCooldown (RouteID is filled from the route
//     argument), ClearCooldown, then GetCooldown reports it missing.
func TestInMemoryStickyStoreBindingFailureAndCooldown(t *testing.T) {
	t.Parallel()
	store := NewInMemoryStickyStore()
	ctx := context.Background()
	key, err := BuildStickyKey(StickyScopeConversation, "gpt-shared", "gpt-5.4", "conv-1")
	if err != nil {
		t.Fatalf("BuildStickyKey() error = %v", err)
	}
	if err := store.Set(ctx, key, StickyBinding{
		LogicalGroupID: "gpt-shared",
		PublicModel:    "gpt-5.4",
		RouteID:        "asxs",
		ShadowGroupID:  "gpt-shared__asxs",
	}, 2*time.Second); err != nil {
		t.Fatalf("Set() error = %v", err)
	}
	binding, ok, err := store.Get(ctx, key)
	if err != nil || !ok {
		t.Fatalf("Get() = (%+v, %v, %v), want binding", binding, ok, err)
	}
	if binding.RouteID != "asxs" {
		t.Fatalf("binding.RouteID = %q, want asxs", binding.RouteID)
	}
	if err := store.Delete(ctx, key); err != nil {
		t.Fatalf("Delete() error = %v", err)
	}
	if _, ok, err := store.Get(ctx, key); err != nil || ok {
		t.Fatalf("Get() after delete = (ok=%v, err=%v), want false nil", ok, err)
	}

	if err := store.SetRouteFailure(ctx, "asxs", RouteFailureState{
		FailureCount:   2,
		LastErrorClass: "timeout",
	}, time.Second); err != nil {
		t.Fatalf("SetRouteFailure() error = %v", err)
	}
	failure, ok, err := store.GetRouteFailure(ctx, "asxs")
	if err != nil || !ok || failure.FailureCount != 2 {
		t.Fatalf("GetRouteFailure() = (%+v, %v, %v), want count 2", failure, ok, err)
	}
	if err := store.ClearRouteFailure(ctx, "asxs"); err != nil {
		t.Fatalf("ClearRouteFailure() error = %v", err)
	}
	// Verify the clear actually took effect, mirroring the Delete check above.
	if _, ok, err := store.GetRouteFailure(ctx, "asxs"); err != nil || ok {
		t.Fatalf("GetRouteFailure() after clear = (ok=%v, err=%v), want false nil", ok, err)
	}

	if err := store.SetCooldown(ctx, "asxs", RouteCooldownState{
		Reason: "cooldown",
	}, time.Second); err != nil {
		t.Fatalf("SetCooldown() error = %v", err)
	}
	cooldown, ok, err := store.GetCooldown(ctx, "asxs")
	if err != nil || !ok || cooldown.RouteID != "asxs" {
		t.Fatalf("GetCooldown() = (%+v, %v, %v), want route asxs", cooldown, ok, err)
	}
	if err := store.ClearCooldown(ctx, "asxs"); err != nil {
		t.Fatalf("ClearCooldown() error = %v", err)
	}
	if _, ok, err := store.GetCooldown(ctx, "asxs"); err != nil || ok {
		t.Fatalf("GetCooldown() after clear = (ok=%v, err=%v), want false nil", ok, err)
	}
}
// TestInMemoryStickyStoreTTlExpiry stores a binding with a short TTL, waits
// past that TTL, and checks that Get no longer returns the binding.
func TestInMemoryStickyStoreTTlExpiry(t *testing.T) {
	t.Parallel()
	const (
		ttl  = 40 * time.Millisecond
		wait = 60 * time.Millisecond // sleep past ttl before reading back
	)
	ctx := context.Background()
	store := NewInMemoryStickyStore()

	key, err := BuildStickyKey(StickyScopeUser, "gpt-shared", "gpt-5.4", "user-1")
	if err != nil {
		t.Fatalf("BuildStickyKey() error = %v", err)
	}
	binding := StickyBinding{
		LogicalGroupID: "gpt-shared",
		PublicModel:    "gpt-5.4",
		RouteID:        "asxs",
		ShadowGroupID:  "gpt-shared__asxs",
	}
	if err := store.Set(ctx, key, binding, ttl); err != nil {
		t.Fatalf("Set() error = %v", err)
	}

	time.Sleep(wait)

	_, found, err := store.Get(ctx, key)
	if err != nil || found {
		t.Fatalf("Get() after ttl = (ok=%v, err=%v), want false nil", found, err)
	}
}
// TestRedisStickyStoreRoundTripWithFakeServer drives the Redis-backed store
// against the in-process fake server, configured with the fake's password and
// a non-default DB (2). It round-trips a sticky binding, a route-failure
// record and a cooldown, then checks that deleting the binding removes it.
func TestRedisStickyStoreRoundTripWithFakeServer(t *testing.T) {
	t.Parallel()
	server := newFakeRedisServer(t)
	t.Cleanup(server.Close)

	store, err := NewRedisStickyStore(context.Background(), RedisConfig{
		Addr:     server.Addr(),
		Password: "secret",
		DB:       2,
	})
	if err != nil {
		t.Fatalf("NewRedisStickyStore() error = %v", err)
	}

	ctx := context.Background()
	key, err := BuildStickyKey(StickyScopeSession, "gpt-shared", "gpt-5.4", "sess-1")
	if err != nil {
		t.Fatalf("BuildStickyKey() error = %v", err)
	}

	// Sticky binding.
	stored := StickyBinding{
		LogicalGroupID: "gpt-shared",
		PublicModel:    "gpt-5.4",
		RouteID:        "asxs",
		ShadowGroupID:  "gpt-shared__asxs",
	}
	if err := store.Set(ctx, key, stored, time.Minute); err != nil {
		t.Fatalf("Set() error = %v", err)
	}
	binding, ok, err := store.Get(ctx, key)
	if err != nil || !ok || binding.RouteID != "asxs" {
		t.Fatalf("Get() = (%+v, %v, %v), want route asxs", binding, ok, err)
	}

	// Route-failure record.
	failureIn := RouteFailureState{FailureCount: 3, LastErrorClass: "timeout"}
	if err := store.SetRouteFailure(ctx, "asxs", failureIn, time.Minute); err != nil {
		t.Fatalf("SetRouteFailure() error = %v", err)
	}
	failure, ok, err := store.GetRouteFailure(ctx, "asxs")
	if err != nil || !ok || failure.FailureCount != 3 {
		t.Fatalf("GetRouteFailure() = (%+v, %v, %v), want count 3", failure, ok, err)
	}

	// Cooldown.
	cooldownIn := RouteCooldownState{Reason: "degraded"}
	if err := store.SetCooldown(ctx, "asxs", cooldownIn, time.Minute); err != nil {
		t.Fatalf("SetCooldown() error = %v", err)
	}
	cooldown, ok, err := store.GetCooldown(ctx, "asxs")
	if err != nil || !ok || cooldown.Reason != "degraded" {
		t.Fatalf("GetCooldown() = (%+v, %v, %v), want reason degraded", cooldown, ok, err)
	}

	// Deleting the binding must make it unreadable.
	if err := store.Delete(ctx, key); err != nil {
		t.Fatalf("Delete() error = %v", err)
	}
	_, ok, err = store.Get(ctx, key)
	if err != nil || ok {
		t.Fatalf("Get() after delete = (ok=%v, err=%v), want false nil", ok, err)
	}
}
type (
	// fakeRedisServer is a minimal in-process stand-in for a Redis server,
	// speaking just enough RESP for the sticky-store tests.
	fakeRedisServer struct {
		t        *testing.T   // owning test; stored by the constructor
		listener net.Listener // accepts client connections
		password string       // value AUTH must present

		mu     sync.Mutex                         // guards values
		values map[int]map[string]fakeRedisValue // logical DB -> key -> entry
	}

	// fakeRedisValue is one stored string plus its expiry instant. A zero
	// expiresAt is treated by getValue as "never expires".
	fakeRedisValue struct {
		value     string
		expiresAt time.Time
	}
)
// newFakeRedisServer starts a fake Redis server on an ephemeral loopback port
// and begins accepting connections in a background goroutine. Clients must
// AUTH with the password "secret". The test is failed immediately if the
// listener cannot be opened; callers should Close the server when done.
func newFakeRedisServer(t *testing.T) *fakeRedisServer {
	t.Helper()
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("net.Listen() error = %v", err)
	}
	srv := &fakeRedisServer{
		t:        t,
		listener: listener,
		password: "secret",
		values:   map[int]map[string]fakeRedisValue{},
	}
	go srv.serve()
	return srv
}
// Addr reports the host:port the fake server is listening on.
func (s *fakeRedisServer) Addr() string {
	bound := s.listener.Addr()
	return bound.String()
}
// Close stops the accept loop by closing the listener. Connections that are
// already open are not closed here; each handler returns once its client
// disconnects. The close error is deliberately ignored.
func (s *fakeRedisServer) Close() {
	ln := s.listener
	_ = ln.Close()
}
// serve accepts connections until Accept fails (e.g. after Close) and hands
// each connection to its own handleConn goroutine.
func (s *fakeRedisServer) serve() {
	var (
		conn net.Conn
		err  error
	)
	for conn, err = s.listener.Accept(); err == nil; conn, err = s.listener.Accept() {
		go s.handleConn(conn)
	}
}
// handleConn serves one client connection until the client disconnects or
// sends a frame that cannot be parsed. The selected DB and the
// authentication state are tracked per connection.
//
// Supported commands (a small subset of Redis):
//
//	PING                          -> +PONG
//	AUTH <password>               -> +OK, or an error on mismatch
//	SELECT <db>                   -> +OK
//	SET <key> <value> EX|PX <ttl> -> +OK (requires AUTH; ttl must be > 0)
//	GET <key>                     -> bulk string, or null bulk if absent (requires AUTH)
//	DEL <key> [key ...]           -> number of live keys removed (requires AUTH)
//
// Anything else is answered with "ERR unknown command".
func (s *fakeRedisServer) handleConn(conn net.Conn) {
	defer conn.Close()
	reader := bufio.NewReader(conn)
	currentDB := 0
	authed := false
	for {
		command, err := readRESPArray(reader)
		if err != nil {
			if err == io.EOF {
				return
			}
			// Report the protocol error (best effort; the peer may be gone) and
			// drop the connection, since the stream can no longer be framed.
			s.writeError(conn, err.Error())
			return
		}
		if len(command) == 0 {
			s.writeError(conn, "empty command")
			continue
		}
		name := strings.ToUpper(command[0])
		// Data commands need a prior successful AUTH; PING, AUTH and SELECT are
		// accepted on an unauthenticated connection.
		if !authed && (name == "SET" || name == "GET" || name == "DEL") {
			s.writeError(conn, "NOAUTH")
			continue
		}
		switch name {
		case "PING":
			s.writeSimpleString(conn, "PONG")
		case "AUTH":
			if len(command) != 2 || command[1] != s.password {
				s.writeError(conn, "ERR invalid password")
				continue
			}
			authed = true
			s.writeSimpleString(conn, "OK")
		case "SELECT":
			if len(command) != 2 {
				s.writeError(conn, "ERR bad select")
				continue
			}
			db, err := strconv.Atoi(command[1])
			if err != nil {
				s.writeError(conn, "ERR bad db")
				continue
			}
			currentDB = db
			s.writeSimpleString(conn, "OK")
		case "SET":
			if len(command) != 5 {
				s.writeError(conn, "ERR bad set")
				continue
			}
			// A client may encode the TTL as EX (seconds) or PX (milliseconds);
			// accept both so sub-second or non-whole-second TTLs also work.
			var unit time.Duration
			switch strings.ToUpper(command[3]) {
			case "EX":
				unit = time.Second
			case "PX":
				unit = time.Millisecond
			default:
				s.writeError(conn, "ERR bad set")
				continue
			}
			ttl, err := strconv.Atoi(command[4])
			if err != nil || ttl <= 0 {
				s.writeError(conn, "ERR bad ttl")
				continue
			}
			s.setValue(currentDB, command[1], command[2], time.Duration(ttl)*unit)
			s.writeSimpleString(conn, "OK")
		case "GET":
			if len(command) != 2 {
				s.writeError(conn, "ERR bad get")
				continue
			}
			value, ok := s.getValue(currentDB, command[1])
			if !ok {
				s.writeNullBulk(conn)
				continue
			}
			s.writeBulk(conn, value)
		case "DEL":
			if len(command) < 2 {
				s.writeError(conn, "ERR bad del")
				continue
			}
			removed := 0
			for _, key := range command[1:] {
				// getValue treats expired entries as absent, so only live keys
				// are counted; a repeated key is counted once.
				if _, ok := s.getValue(currentDB, key); ok {
					removed++
				}
				s.deleteValue(currentDB, key)
			}
			s.writeInteger(conn, removed)
		default:
			s.writeError(conn, "ERR unknown command")
		}
	}
}
// readRESPArray reads one RESP array of bulk strings from reader and returns
// its elements. This is the framing Redis clients use to send commands.
//
// Errors from reader, including io.EOF, are returned unwrapped. That lets
// callers tell a closed connection apart from a protocol error.
//
// The following are reported as errors rather than causing a panic or a
// silently misaligned stream:
//   - a non-array frame;
//   - a negative (null) array length or bulk length;
//   - a bulk payload that is not terminated by CRLF.
func readRESPArray(reader *bufio.Reader) ([]string, error) {
	line, err := reader.ReadString('\n')
	if err != nil {
		return nil, err
	}
	line = strings.TrimSuffix(strings.TrimSuffix(line, "\n"), "\r")
	if !strings.HasPrefix(line, "*") {
		return nil, fmt.Errorf("expected array, got %q", line)
	}
	count, err := strconv.Atoi(strings.TrimPrefix(line, "*"))
	if err != nil {
		return nil, err
	}
	if count < 0 {
		// "*-1" is a RESP null array. It is never a valid command, and a
		// negative capacity would make the make() below panic.
		return nil, fmt.Errorf("invalid array length %d", count)
	}
	parts := make([]string, 0, count)
	for i := 0; i < count; i++ {
		header, err := reader.ReadString('\n')
		if err != nil {
			return nil, err
		}
		header = strings.TrimSuffix(strings.TrimSuffix(header, "\n"), "\r")
		if !strings.HasPrefix(header, "$") {
			return nil, fmt.Errorf("expected bulk header, got %q", header)
		}
		size, err := strconv.Atoi(strings.TrimPrefix(header, "$"))
		if err != nil {
			return nil, err
		}
		if size < 0 {
			// "$-1" is a RESP null bulk string; commands never contain one.
			return nil, fmt.Errorf("invalid bulk length %d", size)
		}
		// Read the payload together with its trailing CRLF.
		payload := make([]byte, size+2)
		if _, err := io.ReadFull(reader, payload); err != nil {
			return nil, err
		}
		if payload[size] != '\r' || payload[size+1] != '\n' {
			return nil, fmt.Errorf("bulk string of length %d not terminated by CRLF", size)
		}
		parts = append(parts, string(payload[:size]))
	}
	return parts, nil
}
// writeSimpleString writes a RESP simple string ("+<value>\r\n") to w.
// Write errors are deliberately ignored.
func (s *fakeRedisServer) writeSimpleString(w io.Writer, value string) {
	frame := "+" + value + "\r\n"
	_, _ = io.WriteString(w, frame)
}
// writeBulk writes value as a RESP bulk string ("$<len>\r\n<value>\r\n").
// The length prefix is len(value), i.e. the byte length. Write errors are
// deliberately ignored.
func (s *fakeRedisServer) writeBulk(w io.Writer, value string) {
	// Format straight into w instead of building an intermediate string.
	_, _ = fmt.Fprintf(w, "$%d\r\n%s\r\n", len(value), value)
}
// writeNullBulk writes the RESP null bulk string ("$-1\r\n"), the reply
// used for a missing key. Write errors are deliberately ignored.
func (s *fakeRedisServer) writeNullBulk(w io.Writer) {
	const nullReply = "$-1\r\n"
	_, _ = io.WriteString(w, nullReply)
}
// writeInteger writes value as a RESP integer reply (":<value>\r\n").
// Write errors are deliberately ignored.
func (s *fakeRedisServer) writeInteger(w io.Writer, value int) {
	// strconv.Itoa produces the same decimal text as %d without fmt's
	// reflection-based formatting.
	_, _ = io.WriteString(w, ":"+strconv.Itoa(value)+"\r\n")
}
// writeError writes message as a RESP error reply ("-<message>\r\n").
// Write errors are deliberately ignored.
func (s *fakeRedisServer) writeError(w io.Writer, message string) {
	reply := "-" + message + "\r\n"
	_, _ = io.WriteString(w, reply)
}
// setValue stores value under key in logical DB db, creating the DB's map on
// first use.
//
// A positive ttl makes the entry expire ttl from now. A non-positive ttl
// stores the entry with a zero expiresAt, which getValue treats as "never
// expires". Previously every entry got now+ttl, so that branch of getValue
// was unreachable and a zero ttl produced an entry that was already expired.
func (s *fakeRedisServer) setValue(db int, key, value string, ttl time.Duration) {
	s.mu.Lock()
	defer s.mu.Unlock()
	bucket := s.values[db]
	if bucket == nil {
		bucket = make(map[string]fakeRedisValue)
		s.values[db] = bucket
	}
	entry := fakeRedisValue{value: value}
	if ttl > 0 {
		entry.expiresAt = time.Now().Add(ttl)
	}
	bucket[key] = entry
}
// getValue returns the value stored under key in logical DB db and whether it
// was found. An entry whose non-zero expiresAt is not after the current time
// is removed and reported as missing; a zero expiresAt never expires.
func (s *fakeRedisServer) getValue(db int, key string) (string, bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	entry, found := s.values[db][key]
	switch {
	case !found:
		return "", false
	case entry.expiresAt.IsZero() || entry.expiresAt.After(time.Now()):
		return entry.value, true
	default:
		// Expired: evict lazily on read.
		delete(s.values[db], key)
		return "", false
	}
}
// deleteValue removes key from logical DB db. Deleting a missing key, or from
// a DB that was never written, is a no-op.
func (s *fakeRedisServer) deleteValue(db int, key string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	// delete on a nil map is a no-op, so no existence check is needed.
	delete(s.values[db], key)
}
// TestRedisStickyStoreRequiresAddr verifies that constructing a Redis store
// from an empty config (no Addr) returns an error.
func TestRedisStickyStoreRequiresAddr(t *testing.T) {
	t.Parallel()
	_, err := NewRedisStickyStore(context.Background(), RedisConfig{})
	if err == nil {
		t.Fatal("NewRedisStickyStore() error = nil, want missing addr")
	}
}
// TestNormalizeRuntimeBackend checks that an empty backend name defaults to
// memory, "redis" is accepted, and an unknown name is rejected.
func TestNormalizeRuntimeBackend(t *testing.T) {
	t.Parallel()

	defaulted, err := NormalizeRuntimeBackend("")
	if err != nil || defaulted != RuntimeBackendMemory {
		t.Fatalf("NormalizeRuntimeBackend(\"\") = (%q, %v), want memory nil", defaulted, err)
	}

	explicit, err := NormalizeRuntimeBackend("redis")
	if err != nil || explicit != RuntimeBackendRedis {
		t.Fatalf("NormalizeRuntimeBackend(redis) = (%q, %v), want redis nil", explicit, err)
	}

	if _, err = NormalizeRuntimeBackend("bad"); err == nil {
		t.Fatal("NormalizeRuntimeBackend(bad) error = nil, want error")
	}
}
// TestRouteFailureAndCooldownKeyBuilders pins the key prefixes used for
// per-route failure ("routefail:") and cooldown ("routecool:") records.
func TestRouteFailureAndCooldownKeyBuilders(t *testing.T) {
	t.Parallel()
	const routeID = "asxs"

	if key, err := BuildRouteFailureKey(routeID); err != nil || key != "routefail:asxs" {
		t.Fatalf("BuildRouteFailureKey() = (%q, %v), want routefail:asxs nil", key, err)
	}
	if key, err := BuildRouteCooldownKey(routeID); err != nil || key != "routecool:asxs" {
		t.Fatalf("BuildRouteCooldownKey() = (%q, %v), want routecool:asxs nil", key, err)
	}
}
// TestRedisStickyStoreFixturePathExists is a smoke test for the per-test
// scratch directory returned by t.TempDir.
//
// The previous assertion compared filepath.Base(t.TempDir()) with "". That
// can never fail, because filepath.Base never returns an empty string. This
// version instead checks two things:
//   - the directory has a real name (Base is neither "." nor a bare
//     separator);
//   - successive calls yield distinct directories, as the testing package
//     documents.
func TestRedisStickyStoreFixturePathExists(t *testing.T) {
	t.Parallel()
	first := t.TempDir()
	second := t.TempDir()
	if base := filepath.Base(first); base == "." || base == string(filepath.Separator) {
		t.Fatalf("t.TempDir() = %q, want a named directory", first)
	}
	if first == second {
		t.Fatalf("t.TempDir() returned %q twice, want unique directories", first)
	}
}