461 lines
12 KiB
Go
461 lines
12 KiB
Go
package routing
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestBuildStickyKey(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
scope string
|
|
want string
|
|
wantErr bool
|
|
}{
|
|
{name: "conversation", scope: StickyScopeConversation, want: "lg:gpt-shared:m:gpt-5.4:conv:conversation-1"},
|
|
{name: "session", scope: StickyScopeSession, want: "lg:gpt-shared:m:gpt-5.4:sess:session-1"},
|
|
{name: "user", scope: StickyScopeUser, want: "lg:gpt-shared:m:gpt-5.4:user:user-1"},
|
|
{name: "invalid", scope: "bad", wantErr: true},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
got, err := BuildStickyKey(tt.scope, "gpt-shared", "gpt-5.4", tt.name+"-1")
|
|
if tt.wantErr {
|
|
if err == nil {
|
|
t.Fatal("BuildStickyKey() error = nil, want error")
|
|
}
|
|
return
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("BuildStickyKey() error = %v", err)
|
|
}
|
|
if got != tt.want {
|
|
t.Fatalf("BuildStickyKey() = %q, want %q", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestInMemoryStickyStoreBindingFailureAndCooldown(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store := NewInMemoryStickyStore()
|
|
ctx := context.Background()
|
|
key, err := BuildStickyKey(StickyScopeConversation, "gpt-shared", "gpt-5.4", "conv-1")
|
|
if err != nil {
|
|
t.Fatalf("BuildStickyKey() error = %v", err)
|
|
}
|
|
|
|
if err := store.Set(ctx, key, StickyBinding{
|
|
LogicalGroupID: "gpt-shared",
|
|
PublicModel: "gpt-5.4",
|
|
RouteID: "asxs",
|
|
ShadowGroupID: "gpt-shared__asxs",
|
|
}, 2*time.Second); err != nil {
|
|
t.Fatalf("Set() error = %v", err)
|
|
}
|
|
binding, ok, err := store.Get(ctx, key)
|
|
if err != nil || !ok {
|
|
t.Fatalf("Get() = (%+v, %v, %v), want binding", binding, ok, err)
|
|
}
|
|
if binding.RouteID != "asxs" {
|
|
t.Fatalf("binding.RouteID = %q, want asxs", binding.RouteID)
|
|
}
|
|
if err := store.Delete(ctx, key); err != nil {
|
|
t.Fatalf("Delete() error = %v", err)
|
|
}
|
|
if _, ok, err := store.Get(ctx, key); err != nil || ok {
|
|
t.Fatalf("Get() after delete = (ok=%v, err=%v), want false nil", ok, err)
|
|
}
|
|
|
|
if err := store.SetRouteFailure(ctx, "asxs", RouteFailureState{
|
|
FailureCount: 2,
|
|
LastErrorClass: "timeout",
|
|
}, time.Second); err != nil {
|
|
t.Fatalf("SetRouteFailure() error = %v", err)
|
|
}
|
|
failure, ok, err := store.GetRouteFailure(ctx, "asxs")
|
|
if err != nil || !ok || failure.FailureCount != 2 {
|
|
t.Fatalf("GetRouteFailure() = (%+v, %v, %v), want count 2", failure, ok, err)
|
|
}
|
|
if err := store.ClearRouteFailure(ctx, "asxs"); err != nil {
|
|
t.Fatalf("ClearRouteFailure() error = %v", err)
|
|
}
|
|
|
|
if err := store.SetCooldown(ctx, "asxs", RouteCooldownState{
|
|
Reason: "cooldown",
|
|
}, time.Second); err != nil {
|
|
t.Fatalf("SetCooldown() error = %v", err)
|
|
}
|
|
cooldown, ok, err := store.GetCooldown(ctx, "asxs")
|
|
if err != nil || !ok || cooldown.RouteID != "asxs" {
|
|
t.Fatalf("GetCooldown() = (%+v, %v, %v), want route asxs", cooldown, ok, err)
|
|
}
|
|
if err := store.ClearCooldown(ctx, "asxs"); err != nil {
|
|
t.Fatalf("ClearCooldown() error = %v", err)
|
|
}
|
|
}
|
|
|
|
func TestInMemoryStickyStoreTTlExpiry(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
store := NewInMemoryStickyStore()
|
|
ctx := context.Background()
|
|
key, err := BuildStickyKey(StickyScopeUser, "gpt-shared", "gpt-5.4", "user-1")
|
|
if err != nil {
|
|
t.Fatalf("BuildStickyKey() error = %v", err)
|
|
}
|
|
if err := store.Set(ctx, key, StickyBinding{
|
|
LogicalGroupID: "gpt-shared",
|
|
PublicModel: "gpt-5.4",
|
|
RouteID: "asxs",
|
|
ShadowGroupID: "gpt-shared__asxs",
|
|
}, 40*time.Millisecond); err != nil {
|
|
t.Fatalf("Set() error = %v", err)
|
|
}
|
|
time.Sleep(60 * time.Millisecond)
|
|
if _, ok, err := store.Get(ctx, key); err != nil || ok {
|
|
t.Fatalf("Get() after ttl = (ok=%v, err=%v), want false nil", ok, err)
|
|
}
|
|
}
|
|
|
|
func TestRedisStickyStoreRoundTripWithFakeServer(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := newFakeRedisServer(t)
|
|
defer server.Close()
|
|
|
|
store, err := NewRedisStickyStore(context.Background(), RedisConfig{
|
|
Addr: server.Addr(),
|
|
Password: "secret",
|
|
DB: 2,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewRedisStickyStore() error = %v", err)
|
|
}
|
|
|
|
ctx := context.Background()
|
|
key, err := BuildStickyKey(StickyScopeSession, "gpt-shared", "gpt-5.4", "sess-1")
|
|
if err != nil {
|
|
t.Fatalf("BuildStickyKey() error = %v", err)
|
|
}
|
|
if err := store.Set(ctx, key, StickyBinding{
|
|
LogicalGroupID: "gpt-shared",
|
|
PublicModel: "gpt-5.4",
|
|
RouteID: "asxs",
|
|
ShadowGroupID: "gpt-shared__asxs",
|
|
}, time.Minute); err != nil {
|
|
t.Fatalf("Set() error = %v", err)
|
|
}
|
|
if binding, ok, err := store.Get(ctx, key); err != nil || !ok || binding.RouteID != "asxs" {
|
|
t.Fatalf("Get() = (%+v, %v, %v), want route asxs", binding, ok, err)
|
|
}
|
|
if err := store.SetRouteFailure(ctx, "asxs", RouteFailureState{
|
|
FailureCount: 3,
|
|
LastErrorClass: "timeout",
|
|
}, time.Minute); err != nil {
|
|
t.Fatalf("SetRouteFailure() error = %v", err)
|
|
}
|
|
if state, ok, err := store.GetRouteFailure(ctx, "asxs"); err != nil || !ok || state.FailureCount != 3 {
|
|
t.Fatalf("GetRouteFailure() = (%+v, %v, %v), want count 3", state, ok, err)
|
|
}
|
|
if err := store.SetCooldown(ctx, "asxs", RouteCooldownState{
|
|
Reason: "degraded",
|
|
}, time.Minute); err != nil {
|
|
t.Fatalf("SetCooldown() error = %v", err)
|
|
}
|
|
if state, ok, err := store.GetCooldown(ctx, "asxs"); err != nil || !ok || state.Reason != "degraded" {
|
|
t.Fatalf("GetCooldown() = (%+v, %v, %v), want reason degraded", state, ok, err)
|
|
}
|
|
if err := store.Delete(ctx, key); err != nil {
|
|
t.Fatalf("Delete() error = %v", err)
|
|
}
|
|
if _, ok, err := store.Get(ctx, key); err != nil || ok {
|
|
t.Fatalf("Get() after delete = (ok=%v, err=%v), want false nil", ok, err)
|
|
}
|
|
}
|
|
|
|
type fakeRedisServer struct {
|
|
t *testing.T
|
|
listener net.Listener
|
|
password string
|
|
mu sync.Mutex
|
|
values map[int]map[string]fakeRedisValue
|
|
}
|
|
|
|
// fakeRedisValue is one stored string together with the instant it expires.
// getValue treats a zero expiresAt as "never expires" and evicts the entry
// once expiresAt is no longer in the future.
type fakeRedisValue struct {
	expiresAt time.Time
	value     string
}
|
|
|
|
func newFakeRedisServer(t *testing.T) *fakeRedisServer {
|
|
t.Helper()
|
|
|
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("net.Listen() error = %v", err)
|
|
}
|
|
server := &fakeRedisServer{
|
|
t: t,
|
|
listener: ln,
|
|
password: "secret",
|
|
values: make(map[int]map[string]fakeRedisValue),
|
|
}
|
|
go server.serve()
|
|
return server
|
|
}
|
|
|
|
func (s *fakeRedisServer) Addr() string {
|
|
return s.listener.Addr().String()
|
|
}
|
|
|
|
func (s *fakeRedisServer) Close() {
|
|
_ = s.listener.Close()
|
|
}
|
|
|
|
func (s *fakeRedisServer) serve() {
|
|
for {
|
|
conn, err := s.listener.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
go s.handleConn(conn)
|
|
}
|
|
}
|
|
|
|
func (s *fakeRedisServer) handleConn(conn net.Conn) {
|
|
defer conn.Close()
|
|
|
|
reader := bufio.NewReader(conn)
|
|
currentDB := 0
|
|
authed := false
|
|
for {
|
|
command, err := readRESPArray(reader)
|
|
if err != nil {
|
|
if err == io.EOF {
|
|
return
|
|
}
|
|
s.writeError(conn, err.Error())
|
|
return
|
|
}
|
|
if len(command) == 0 {
|
|
s.writeError(conn, "empty command")
|
|
continue
|
|
}
|
|
|
|
switch strings.ToUpper(command[0]) {
|
|
case "PING":
|
|
s.writeSimpleString(conn, "PONG")
|
|
case "AUTH":
|
|
if len(command) != 2 || command[1] != s.password {
|
|
s.writeError(conn, "ERR invalid password")
|
|
continue
|
|
}
|
|
authed = true
|
|
s.writeSimpleString(conn, "OK")
|
|
case "SELECT":
|
|
if len(command) != 2 {
|
|
s.writeError(conn, "ERR bad select")
|
|
continue
|
|
}
|
|
db, err := strconv.Atoi(command[1])
|
|
if err != nil {
|
|
s.writeError(conn, "ERR bad db")
|
|
continue
|
|
}
|
|
currentDB = db
|
|
s.writeSimpleString(conn, "OK")
|
|
case "SET":
|
|
if !authed {
|
|
s.writeError(conn, "NOAUTH")
|
|
continue
|
|
}
|
|
if len(command) != 5 || strings.ToUpper(command[3]) != "EX" {
|
|
s.writeError(conn, "ERR bad set")
|
|
continue
|
|
}
|
|
ttl, err := strconv.Atoi(command[4])
|
|
if err != nil {
|
|
s.writeError(conn, "ERR bad ttl")
|
|
continue
|
|
}
|
|
s.setValue(currentDB, command[1], command[2], time.Duration(ttl)*time.Second)
|
|
s.writeSimpleString(conn, "OK")
|
|
case "GET":
|
|
if !authed {
|
|
s.writeError(conn, "NOAUTH")
|
|
continue
|
|
}
|
|
if len(command) != 2 {
|
|
s.writeError(conn, "ERR bad get")
|
|
continue
|
|
}
|
|
value, ok := s.getValue(currentDB, command[1])
|
|
if !ok {
|
|
s.writeNullBulk(conn)
|
|
continue
|
|
}
|
|
s.writeBulk(conn, value)
|
|
case "DEL":
|
|
if !authed {
|
|
s.writeError(conn, "NOAUTH")
|
|
continue
|
|
}
|
|
if len(command) != 2 {
|
|
s.writeError(conn, "ERR bad del")
|
|
continue
|
|
}
|
|
s.deleteValue(currentDB, command[1])
|
|
s.writeInteger(conn, 1)
|
|
default:
|
|
s.writeError(conn, "ERR unknown command")
|
|
}
|
|
}
|
|
}
|
|
|
|
// readRESPArray reads one RESP command from reader and returns its bulk-string
// payloads in order. The expected wire form is an array header
// ("*<count>\r\n") followed by count bulk strings ("$<len>\r\n<bytes>\r\n").
// Lines terminated by a bare "\n" are also accepted for headers.
//
// Errors from the underlying reader are returned unwrapped. In particular,
// io.EOF on the first read lets callers detect a clean disconnect with
// err == io.EOF. A negative or oversized array or bulk length is rejected,
// as is a bulk payload not followed by "\r\n".
func readRESPArray(reader *bufio.Reader) ([]string, error) {
	// Limits mirror Redis's own protocol limits. They stop a malformed header
	// from triggering a huge allocation; a negative length would panic in make.
	const (
		maxArrayLen = 1024 * 1024
		maxBulkLen  = 512 * 1024 * 1024
	)

	line, err := reader.ReadString('\n')
	if err != nil {
		return nil, err
	}
	line = strings.TrimSuffix(strings.TrimSuffix(line, "\n"), "\r")
	if !strings.HasPrefix(line, "*") {
		return nil, fmt.Errorf("expected array, got %q", line)
	}
	count, err := strconv.Atoi(strings.TrimPrefix(line, "*"))
	if err != nil {
		return nil, err
	}
	if count < 0 || count > maxArrayLen {
		return nil, fmt.Errorf("invalid array length %d", count)
	}

	parts := make([]string, 0, count)
	for i := 0; i < count; i++ {
		header, err := reader.ReadString('\n')
		if err != nil {
			return nil, err
		}
		header = strings.TrimSuffix(strings.TrimSuffix(header, "\n"), "\r")
		if !strings.HasPrefix(header, "$") {
			return nil, fmt.Errorf("expected bulk header, got %q", header)
		}
		size, err := strconv.Atoi(strings.TrimPrefix(header, "$"))
		if err != nil {
			return nil, err
		}
		if size < 0 || size > maxBulkLen {
			return nil, fmt.Errorf("invalid bulk length %d", size)
		}

		// Read the payload plus its mandatory trailing CRLF in one go.
		payload := make([]byte, size+2)
		if _, err := io.ReadFull(reader, payload); err != nil {
			return nil, err
		}
		if payload[size] != '\r' || payload[size+1] != '\n' {
			return nil, fmt.Errorf("bulk string of length %d not terminated by CRLF", size)
		}
		parts = append(parts, string(payload[:size]))
	}
	return parts, nil
}
|
|
|
|
func (s *fakeRedisServer) writeSimpleString(w io.Writer, value string) {
|
|
_, _ = io.WriteString(w, "+"+value+"\r\n")
|
|
}
|
|
|
|
func (s *fakeRedisServer) writeBulk(w io.Writer, value string) {
|
|
_, _ = io.WriteString(w, fmt.Sprintf("$%d\r\n%s\r\n", len(value), value))
|
|
}
|
|
|
|
func (s *fakeRedisServer) writeNullBulk(w io.Writer) {
|
|
_, _ = io.WriteString(w, "$-1\r\n")
|
|
}
|
|
|
|
func (s *fakeRedisServer) writeInteger(w io.Writer, value int) {
|
|
_, _ = io.WriteString(w, fmt.Sprintf(":%d\r\n", value))
|
|
}
|
|
|
|
func (s *fakeRedisServer) writeError(w io.Writer, message string) {
|
|
_, _ = io.WriteString(w, "-"+message+"\r\n")
|
|
}
|
|
|
|
func (s *fakeRedisServer) setValue(db int, key, value string, ttl time.Duration) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.values[db] == nil {
|
|
s.values[db] = make(map[string]fakeRedisValue)
|
|
}
|
|
s.values[db][key] = fakeRedisValue{value: value, expiresAt: time.Now().Add(ttl)}
|
|
}
|
|
|
|
func (s *fakeRedisServer) getValue(db int, key string) (string, bool) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
value, ok := s.values[db][key]
|
|
if !ok {
|
|
return "", false
|
|
}
|
|
if !value.expiresAt.IsZero() && !value.expiresAt.After(time.Now()) {
|
|
delete(s.values[db], key)
|
|
return "", false
|
|
}
|
|
return value.value, true
|
|
}
|
|
|
|
func (s *fakeRedisServer) deleteValue(db int, key string) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.values[db] == nil {
|
|
return
|
|
}
|
|
delete(s.values[db], key)
|
|
}
|
|
|
|
func TestRedisStickyStoreRequiresAddr(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
if _, err := NewRedisStickyStore(context.Background(), RedisConfig{}); err == nil {
|
|
t.Fatal("NewRedisStickyStore() error = nil, want missing addr")
|
|
}
|
|
}
|
|
|
|
func TestNormalizeRuntimeBackend(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
if got, err := NormalizeRuntimeBackend(""); err != nil || got != RuntimeBackendMemory {
|
|
t.Fatalf("NormalizeRuntimeBackend(\"\") = (%q, %v), want memory nil", got, err)
|
|
}
|
|
if got, err := NormalizeRuntimeBackend("redis"); err != nil || got != RuntimeBackendRedis {
|
|
t.Fatalf("NormalizeRuntimeBackend(redis) = (%q, %v), want redis nil", got, err)
|
|
}
|
|
if _, err := NormalizeRuntimeBackend("bad"); err == nil {
|
|
t.Fatal("NormalizeRuntimeBackend(bad) error = nil, want error")
|
|
}
|
|
}
|
|
|
|
func TestRouteFailureAndCooldownKeyBuilders(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
failureKey, err := BuildRouteFailureKey("asxs")
|
|
if err != nil || failureKey != "routefail:asxs" {
|
|
t.Fatalf("BuildRouteFailureKey() = (%q, %v), want routefail:asxs nil", failureKey, err)
|
|
}
|
|
cooldownKey, err := BuildRouteCooldownKey("asxs")
|
|
if err != nil || cooldownKey != "routecool:asxs" {
|
|
t.Fatalf("BuildRouteCooldownKey() = (%q, %v), want routecool:asxs nil", cooldownKey, err)
|
|
}
|
|
}
|
|
|
|
// TestRedisStickyStoreFixturePathExists checks that t.TempDir() yields a
// directory whose base name is non-empty.
//
// NOTE(review): despite its name, this test touches no Redis fixture and only
// exercises t.TempDir(). Confirm which fixture it was meant to verify.
func TestRedisStickyStoreFixturePathExists(t *testing.T) {
	t.Parallel()

	dir := t.TempDir()
	if base := filepath.Base(dir); base == "" {
		t.Fatal("temp dir base should not be empty")
	}
}
|