// Source: sub2api-cn-relay-manager/internal/batch/service_test.go
// (repository viewer metadata: 366 lines, 12 KiB, Go)

package batch
import (
"context"
"encoding/json"
"testing"
"sub2api-cn-relay-manager/internal/probe"
"sub2api-cn-relay-manager/internal/store/sqlite"
)
func TestBatchImport_StartRun(t *testing.T) {
t.Parallel()
t.Run("creates run items and backfills legacy linkage after provision", func(t *testing.T) {
t.Parallel()
runStore := &fakeRunStore{}
itemStore := &fakeItemStore{}
provisioner := &fakeProvisioner{
provisionResult: ProvisionResult{
LegacyBatchID: int64Ptr(81),
LegacyProviderID: "legacy-provider",
},
}
service := BatchImportService{
RunStore: runStore,
ItemStore: itemStore,
ProbeModels: func(context.Context, string, string) (*probe.ModelsResult, error) {
return &probe.ModelsResult{RawModels: []string{"deepseek-ai/DeepSeek-V4-Pro"}}, nil
},
ProbeCapabilities: func(context.Context, string, string, []string) (*probe.CapabilityProfile, error) {
return &probe.CapabilityProfile{
TransportProfile: probe.TransportProfile{SupportsOpenAIChatCompletions: true},
ModelProfiles: []probe.ModelCapabilityProfile{
{RawModelID: "deepseek-ai/DeepSeek-V4-Pro", CanonicalModelFamily: "deepseek-v4-pro", SmokeChatOK: true},
},
}, nil
},
InspectReuse: func(context.Context, ReuseLookupInput) (ReuseLookupResult, error) {
return ReuseLookupResult{}, nil
},
Provisioner: provisioner,
}
result, err := service.StartRun(context.Background(), BatchImportRunRequest{
RunID: "run-1",
Mode: "strict",
AccessMode: "subscription",
HostID: "host-1",
HostBaseURL: "https://relay.example.com",
Entries: []BatchImportEntry{
{BaseURL: "https://api.deepseek.com/v1", APIKey: "sk-live", RequestedModels: []string{"DeepSeek V4 Pro"}},
},
})
if err != nil {
t.Fatalf("StartRun() error = %v", err)
}
if result.RunID != "run-1" {
t.Fatalf("RunID = %q, want run-1", result.RunID)
}
if len(runStore.created) != 1 || runStore.created[0].RunID != "run-1" {
t.Fatalf("created runs = %#v, want run-1 persisted", runStore.created)
}
if provisioner.provisionCalls != 1 {
t.Fatalf("provision calls = %d, want 1", provisioner.provisionCalls)
}
if len(itemStore.upserts) != 2 {
t.Fatalf("item upserts = %d, want initial + final", len(itemStore.upserts))
}
finalItem := itemStore.upserts[len(itemStore.upserts)-1]
if finalItem.LegacyBatchID == nil || *finalItem.LegacyBatchID != 81 {
t.Fatalf("LegacyBatchID = %#v, want 81", finalItem.LegacyBatchID)
}
if finalItem.LegacyProviderID != "legacy-provider" {
t.Fatalf("LegacyProviderID = %q, want legacy-provider", finalItem.LegacyProviderID)
}
if finalItem.CurrentStage != string(ItemStageConfirm) {
t.Fatalf("CurrentStage = %q, want confirm", finalItem.CurrentStage)
}
})
t.Run("active duplicate account is reused without provision", func(t *testing.T) {
t.Parallel()
providerID := NormalizeProviderID("https://api.kimi.com/v1")
itemStore := &fakeItemStore{}
provisioner := &fakeProvisioner{}
service := BatchImportService{
RunStore: &fakeRunStore{},
ItemStore: itemStore,
Provisioner: provisioner,
ProbeModels: func(context.Context, string, string) (*probe.ModelsResult, error) {
return &probe.ModelsResult{RawModels: []string{"kimi-k2.6"}}, nil
},
ProbeCapabilities: func(context.Context, string, string, []string) (*probe.CapabilityProfile, error) {
return &probe.CapabilityProfile{
ModelProfiles: []probe.ModelCapabilityProfile{{RawModelID: "kimi-k2.6", CanonicalModelFamily: "kimi-2.6", SmokeChatOK: true}},
}, nil
},
InspectReuse: func(context.Context, ReuseLookupInput) (ReuseLookupResult, error) {
return ReuseLookupResult{
ExistingProviderID: providerID,
ExistingAccessStatus: AccessStatusActive,
ExistingCanonicalFamilys: []string{"kimi 2.6"},
MatchedAccountID: 201,
MatchedAccountState: MatchedAccountStateActive,
}, nil
},
}
_, err := service.StartRun(context.Background(), BatchImportRunRequest{
RunID: "run-2",
Mode: "strict",
AccessMode: "subscription",
HostID: "host-1",
Entries: []BatchImportEntry{
{BaseURL: "https://api.kimi.com/v1", APIKey: "sk-live", RequestedModels: []string{"kimi 2.6"}},
},
})
if err != nil {
t.Fatalf("StartRun() error = %v", err)
}
if provisioner.provisionCalls != 0 {
t.Fatalf("provision calls = %d, want 0", provisioner.provisionCalls)
}
finalItem := itemStore.upserts[len(itemStore.upserts)-1]
if !finalItem.ProvisionReused {
t.Fatal("ProvisionReused = false, want true")
}
if finalItem.MatchedAccountState != string(MatchedAccountStateActive) {
t.Fatalf("MatchedAccountState = %q, want active", finalItem.MatchedAccountState)
}
if finalItem.AccountResolution != string(AccountResolutionReused) {
t.Fatalf("AccountResolution = %q, want reused", finalItem.AccountResolution)
}
})
t.Run("deprecated duplicate account becomes reactivated", func(t *testing.T) {
t.Parallel()
providerID := NormalizeProviderID("https://api.kimi.com/v1")
itemStore := &fakeItemStore{}
service := BatchImportService{
RunStore: &fakeRunStore{},
ItemStore: itemStore,
Provisioner: &fakeProvisioner{},
ProbeModels: func(context.Context, string, string) (*probe.ModelsResult, error) {
return &probe.ModelsResult{RawModels: []string{"kimi-k2.6"}}, nil
},
ProbeCapabilities: func(context.Context, string, string, []string) (*probe.CapabilityProfile, error) {
return &probe.CapabilityProfile{
ModelProfiles: []probe.ModelCapabilityProfile{{RawModelID: "kimi-k2.6", CanonicalModelFamily: "kimi-2.6", SmokeChatOK: true}},
}, nil
},
InspectReuse: func(context.Context, ReuseLookupInput) (ReuseLookupResult, error) {
return ReuseLookupResult{
ExistingProviderID: providerID,
ExistingAccessStatus: AccessStatusActive,
ExistingCanonicalFamilys: []string{"kimi-2.6"},
MatchedAccountID: 301,
MatchedAccountState: MatchedAccountStateDeprecated,
}, nil
},
}
_, err := service.StartRun(context.Background(), BatchImportRunRequest{
RunID: "run-3",
Mode: "strict",
AccessMode: "subscription",
HostID: "host-1",
Entries: []BatchImportEntry{
{BaseURL: "https://api.kimi.com/v1", APIKey: "sk-live", RequestedModels: []string{"kimi 2.6"}},
},
})
if err != nil {
t.Fatalf("StartRun() error = %v", err)
}
finalItem := itemStore.upserts[len(itemStore.upserts)-1]
if finalItem.AccountResolution != string(AccountResolutionReactivated) {
t.Fatalf("AccountResolution = %q, want reactivated", finalItem.AccountResolution)
}
if !finalItem.ProvisionReused {
t.Fatal("ProvisionReused = false, want true")
}
})
t.Run("same family new alias only patches mapping", func(t *testing.T) {
t.Parallel()
providerID := NormalizeProviderID("https://api.kimi.com/v1")
itemStore := &fakeItemStore{}
provisioner := &fakeProvisioner{}
service := BatchImportService{
RunStore: &fakeRunStore{},
ItemStore: itemStore,
Provisioner: provisioner,
ProbeModels: func(context.Context, string, string) (*probe.ModelsResult, error) {
return &probe.ModelsResult{RawModels: []string{"Kimi-K2.6"}}, nil
},
ProbeCapabilities: func(context.Context, string, string, []string) (*probe.CapabilityProfile, error) {
return &probe.CapabilityProfile{
ModelProfiles: []probe.ModelCapabilityProfile{{RawModelID: "Kimi-K2.6", CanonicalModelFamily: "kimi-2.6", SmokeChatOK: true}},
}, nil
},
InspectReuse: func(context.Context, ReuseLookupInput) (ReuseLookupResult, error) {
return ReuseLookupResult{
ExistingProviderID: providerID,
ExistingAccessStatus: AccessStatusActive,
ExistingCanonicalFamilys: []string{"kimi 2.6"},
MatchedAccountID: 401,
MatchedAccountState: MatchedAccountStateActive,
ExistingModelMapping: map[string]string{"kimi-k2.6": "kimi-2.6"},
}, nil
},
}
_, err := service.StartRun(context.Background(), BatchImportRunRequest{
RunID: "run-4",
Mode: "strict",
AccessMode: "subscription",
HostID: "host-1",
Entries: []BatchImportEntry{
{BaseURL: "https://api.kimi.com/v1", APIKey: "sk-live", RequestedModels: []string{"kimi 2.6"}},
},
})
if err != nil {
t.Fatalf("StartRun() error = %v", err)
}
if provisioner.provisionCalls != 0 {
t.Fatalf("provision calls = %d, want 0", provisioner.provisionCalls)
}
if provisioner.patchCalls != 1 {
t.Fatalf("patch calls = %d, want 1", provisioner.patchCalls)
}
if provisioner.lastPatch.Contract.ModelMapping["Kimi-K2.6"] != "kimi-2.6" {
t.Fatalf("patch mapping = %#v, want raw alias mapped to canonical family", provisioner.lastPatch.Contract.ModelMapping)
}
finalItem := itemStore.upserts[len(itemStore.upserts)-1]
if !finalItem.ProvisionReused {
t.Fatal("ProvisionReused = false, want true for patch-only flow")
}
})
t.Run("probe failure marks run failed instead of leaving running half state", func(t *testing.T) {
t.Parallel()
runStore := &fakeRunStore{}
itemStore := &fakeItemStore{}
service := BatchImportService{
RunStore: runStore,
ItemStore: itemStore,
ProbeModels: func(context.Context, string, string) (*probe.ModelsResult, error) {
return nil, context.DeadlineExceeded
},
ProbeCapabilities: func(context.Context, string, string, []string) (*probe.CapabilityProfile, error) {
t.Fatal("ProbeCapabilities should not be called after probe failure")
return nil, nil
},
}
result, err := service.StartRun(context.Background(), BatchImportRunRequest{
RunID: "run-probe-fail",
Mode: "strict",
AccessMode: "self_service",
HostID: "host-1",
Entries: []BatchImportEntry{
{BaseURL: "https://api.deepseek.com/v1", APIKey: "sk-live", RequestedModels: []string{"DeepSeek V4 Pro"}},
},
})
if err != nil {
t.Fatalf("StartRun() error = %v, want persisted failed run without transport error", err)
}
if result.RunID != "run-probe-fail" {
t.Fatalf("result.RunID = %q, want run-probe-fail", result.RunID)
}
if len(runStore.updated) == 0 {
t.Fatal("run store was not updated to failed state")
}
gotRun := runStore.updated[len(runStore.updated)-1]
if gotRun.State != string(RunStateFailed) {
t.Fatalf("run.State = %q, want failed", gotRun.State)
}
if gotRun.CompletedItems != 1 || gotRun.BrokenItems != 1 {
t.Fatalf("run counters = %+v, want completed_items=1 broken_items=1", gotRun)
}
if len(itemStore.upserts) < 2 {
t.Fatalf("item upserts = %d, want initial + failed terminal state", len(itemStore.upserts))
}
gotItem := itemStore.upserts[len(itemStore.upserts)-1]
if gotItem.CurrentStage != string(ItemStageDone) {
t.Fatalf("item.CurrentStage = %q, want done", gotItem.CurrentStage)
}
if gotItem.ConfirmationStatus != string(ConfirmationFailed) {
t.Fatalf("item.ConfirmationStatus = %q, want failed", gotItem.ConfirmationStatus)
}
if gotItem.AccessStatus != string(AccessStatusBroken) {
t.Fatalf("item.AccessStatus = %q, want broken", gotItem.AccessStatus)
}
if gotItem.LastErrorStage != string(ItemStageProbe) {
t.Fatalf("item.LastErrorStage = %q, want probe", gotItem.LastErrorStage)
}
})
}
// fakeRunStore is an in-memory stand-in for the run store. It records every
// run handed to Create and Update, in call order, and never returns an error.
type fakeRunStore struct {
	created, updated []sqlite.ImportRun
}

// Create appends run to f.created and reports success.
func (f *fakeRunStore) Create(_ context.Context, run sqlite.ImportRun) error {
	f.created = append(f.created, run)
	return nil
}

// Update appends run to f.updated and reports success.
func (f *fakeRunStore) Update(_ context.Context, run sqlite.ImportRun) error {
	f.updated = append(f.updated, run)
	return nil
}
// fakeItemStore is an in-memory stand-in for the item store. It keeps every
// upserted item snapshot, in call order, so tests can inspect intermediate
// and final states.
type fakeItemStore struct {
	upserts []sqlite.ImportRunItem
}

// Upsert records item and always succeeds.
func (f *fakeItemStore) Upsert(_ context.Context, item sqlite.ImportRunItem) error {
	f.upserts = append(f.upserts, item)
	return nil
}
// fakeProvisioner counts Provision and Patch calls. Provision always returns
// the canned provisionResult. Patch stores the most recent request in
// lastPatch. Neither method ever returns an error.
type fakeProvisioner struct {
	provisionCalls, patchCalls int
	provisionResult            ProvisionResult
	lastPatch                  PatchProvisionRequest
}

// Provision increments provisionCalls and returns the configured result.
func (f *fakeProvisioner) Provision(_ context.Context, _ ProvisionRequest) (ProvisionResult, error) {
	f.provisionCalls++
	return f.provisionResult, nil
}

// Patch increments patchCalls and keeps req for later assertions.
func (f *fakeProvisioner) Patch(_ context.Context, req PatchProvisionRequest) error {
	f.patchCalls++
	f.lastPatch = req
	return nil
}
// int64Ptr returns a pointer to a freshly allocated int64 holding value. It is
// handy for filling optional *int64 fields in struct literals.
func int64Ptr(value int64) *int64 {
	ptr := new(int64)
	*ptr = value
	return ptr
}
// mustJSON marshals value with encoding/json and returns the result as a
// string. If marshaling fails, it aborts the calling test via t.Fatalf.
func mustJSON(t *testing.T, value any) string {
	t.Helper()
	encoded, err := json.Marshal(value)
	if err == nil {
		return string(encoded)
	}
	t.Fatalf("json.Marshal() error = %v", err)
	return "" // unreachable: t.Fatalf stops the test goroutine
}