388 lines
12 KiB
Go
388 lines
12 KiB
Go
package batch
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"sub2api-cn-relay-manager/internal/probe"
|
|
"sub2api-cn-relay-manager/internal/store/sqlite"
|
|
)
|
|
|
|
// BatchImportEntry describes one upstream relay to import in a batch run.
//
// BaseURL and APIKey are the endpoint and credential that are probed and,
// if needed, provisioned. RequestedModels lists the models the caller would
// like to use; it is handed to the smoke-model resolver alongside the probed
// model list.
type BatchImportEntry struct {
	BaseURL, APIKey string
	RequestedModels []string
}
|
|
|
|
type BatchImportRunRequest struct {
|
|
RunID string
|
|
Mode string
|
|
AccessMode string
|
|
HostID string
|
|
HostBaseURL string
|
|
SubscriptionUsers []string
|
|
SubscriptionDays int
|
|
ProbeAPIKey string
|
|
Entries []BatchImportEntry
|
|
}
|
|
|
|
// BatchImportRunResult reports the run StartRun created and the item IDs it
// assigned, in entry order. On a fail-fast stop, ItemIDs only covers the
// entries reached so far.
type BatchImportRunResult struct {
	RunID   string
	ItemIDs []string
}
|
|
|
|
type RunStateStore interface {
|
|
Create(ctx context.Context, run sqlite.ImportRun) error
|
|
Update(ctx context.Context, run sqlite.ImportRun) error
|
|
}
|
|
|
|
type ItemStateStore interface {
|
|
Upsert(ctx context.Context, item sqlite.ImportRunItem) error
|
|
}
|
|
|
|
// ReuseLookupInput identifies a candidate relay when asking whether an
// existing provider or account can be reused for it.
//
// APIKeyFingerprint is a truncated hash of the key, never the key itself.
// CanonicalModelFamilies are the distinct model families the relay exposes.
type ReuseLookupInput struct {
	HostID, ProviderID, BaseURL string
	APIKeyFingerprint           string
	CanonicalModelFamilies      []string
}
|
|
|
|
type ReuseLookupResult struct {
|
|
ExistingProviderID string
|
|
ExistingAccessStatus AccessStatus
|
|
ExistingCanonicalFamilys []string
|
|
MatchedAccountID int64
|
|
MatchedAccountState MatchedAccountState
|
|
ExistingModelMapping map[string]string
|
|
}
|
|
|
|
type ProvisionRequest struct {
|
|
RunID string
|
|
ItemID string
|
|
Entry BatchImportEntry
|
|
ProviderID string
|
|
ResolvedModel string
|
|
RoutingStrategy ImportRoutingStrategy
|
|
CapabilityProfile *probe.CapabilityProfile
|
|
}
|
|
|
|
// ProvisionResult carries identifiers from a successful provision.
//
// StartRun copies both onto the stored item. LegacyProviderID is
// whitespace-trimmed before it is stored.
type ProvisionResult struct {
	LegacyBatchID    *int64
	LegacyProviderID string
}
|
|
|
|
type PatchProvisionRequest struct {
|
|
ProviderID string
|
|
Contract ChannelPatchContract
|
|
}
|
|
|
|
type BatchProvisioner interface {
|
|
Provision(ctx context.Context, req ProvisionRequest) (ProvisionResult, error)
|
|
Patch(ctx context.Context, req PatchProvisionRequest) error
|
|
}
|
|
|
|
type BatchImportService struct {
|
|
RunStore RunStateStore
|
|
ItemStore ItemStateStore
|
|
ProbeModels func(ctx context.Context, baseURL, apiKey string) (*probe.ModelsResult, error)
|
|
ProbeCapabilities func(ctx context.Context, baseURL, apiKey string, rawModels []string) (*probe.CapabilityProfile, error)
|
|
InspectReuse func(ctx context.Context, input ReuseLookupInput) (ReuseLookupResult, error)
|
|
Provisioner BatchProvisioner
|
|
}
|
|
|
|
func (s BatchImportService) StartRun(ctx context.Context, req BatchImportRunRequest) (BatchImportRunResult, error) {
|
|
if s.RunStore == nil {
|
|
return BatchImportRunResult{}, fmt.Errorf("run store is required")
|
|
}
|
|
if s.ItemStore == nil {
|
|
return BatchImportRunResult{}, fmt.Errorf("item store is required")
|
|
}
|
|
if s.ProbeModels == nil {
|
|
return BatchImportRunResult{}, fmt.Errorf("model probe is required")
|
|
}
|
|
if s.ProbeCapabilities == nil {
|
|
return BatchImportRunResult{}, fmt.Errorf("capability probe is required")
|
|
}
|
|
|
|
runID := strings.TrimSpace(req.RunID)
|
|
if runID == "" {
|
|
runID = fmt.Sprintf("run-%d", time.Now().UnixNano())
|
|
}
|
|
if len(req.Entries) == 0 {
|
|
return BatchImportRunResult{}, fmt.Errorf("entries are required")
|
|
}
|
|
|
|
if err := s.RunStore.Create(ctx, sqlite.ImportRun{
|
|
RunID: runID,
|
|
HostID: strings.TrimSpace(req.HostID),
|
|
Mode: strings.TrimSpace(req.Mode),
|
|
AccessMode: strings.TrimSpace(req.AccessMode),
|
|
SubscriptionUsersJSON: mustMarshalJSON(req.SubscriptionUsers, "[]"),
|
|
SubscriptionDays: req.SubscriptionDays,
|
|
ProbeAPIKey: strings.TrimSpace(req.ProbeAPIKey),
|
|
State: string(RunStateRunning),
|
|
TotalItems: len(req.Entries),
|
|
}); err != nil {
|
|
return BatchImportRunResult{}, err
|
|
}
|
|
|
|
result := BatchImportRunResult{
|
|
RunID: runID,
|
|
ItemIDs: make([]string, 0, len(req.Entries)),
|
|
}
|
|
|
|
for idx, entry := range req.Entries {
|
|
itemID := fmt.Sprintf("%s-item-%d", runID, idx+1)
|
|
result.ItemIDs = append(result.ItemIDs, itemID)
|
|
|
|
providerID := NormalizeProviderID(entry.BaseURL)
|
|
fingerprint := fingerprintAPIKey(entry.APIKey)
|
|
initialItem := sqlite.ImportRunItem{
|
|
ItemID: itemID,
|
|
RunID: runID,
|
|
BaseURL: strings.TrimSpace(entry.BaseURL),
|
|
ProviderID: providerID,
|
|
APIKeyFingerprint: fingerprint,
|
|
CurrentStage: string(ItemStageProbe),
|
|
ConfirmationStatus: string(ConfirmationPending),
|
|
AccessStatus: string(AccessStatusUnknown),
|
|
MatchedAccountState: string(MatchedAccountStateNone),
|
|
AccountResolution: string(AccountResolutionCreated),
|
|
}
|
|
if err := s.ItemStore.Upsert(ctx, initialItem); err != nil {
|
|
return BatchImportRunResult{}, err
|
|
}
|
|
|
|
modelsResult, err := s.ProbeModels(ctx, entry.BaseURL, entry.APIKey)
|
|
if err != nil {
|
|
if failErr := s.failRun(ctx, req, initialItem, ItemStageProbe, err); failErr != nil {
|
|
return BatchImportRunResult{}, failErr
|
|
}
|
|
return result, nil
|
|
}
|
|
rawModels := append([]string(nil), modelsResult.RawModels...)
|
|
capabilityProfile, err := s.ProbeCapabilities(ctx, entry.BaseURL, entry.APIKey, rawModels)
|
|
if err != nil {
|
|
if failErr := s.failRun(ctx, req, initialItem, ItemStageProbe, err); failErr != nil {
|
|
return BatchImportRunResult{}, failErr
|
|
}
|
|
return result, nil
|
|
}
|
|
routingStrategy := BuildImportRoutingStrategy(capabilityProfile)
|
|
resolvedSmokeModel, recommendedModels, err := probe.ResolveSmokeModel(entry.RequestedModels, rawModels, capabilityProfile)
|
|
if err != nil {
|
|
if failErr := s.failRun(ctx, req, initialItem, ItemStageProbe, err); failErr != nil {
|
|
return BatchImportRunResult{}, failErr
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
canonicalFamilies := uniqueCanonicalFamilies(rawModels)
|
|
reuseLookup := ReuseLookupResult{}
|
|
if s.InspectReuse != nil {
|
|
reuseLookup, err = s.InspectReuse(ctx, ReuseLookupInput{
|
|
HostID: strings.TrimSpace(req.HostID),
|
|
ProviderID: providerID,
|
|
BaseURL: entry.BaseURL,
|
|
APIKeyFingerprint: fingerprint,
|
|
CanonicalModelFamilies: canonicalFamilies,
|
|
})
|
|
if err != nil {
|
|
if failErr := s.failRun(ctx, req, initialItem, ItemStageProbe, err); failErr != nil {
|
|
return BatchImportRunResult{}, failErr
|
|
}
|
|
return result, nil
|
|
}
|
|
}
|
|
|
|
reuseDecision := DecideReuse(ReuseInput{
|
|
ProviderID: providerID,
|
|
CanonicalModelFamilies: canonicalFamilies,
|
|
MatchedAccountID: reuseLookup.MatchedAccountID,
|
|
MatchedAccountState: reuseLookup.MatchedAccountState,
|
|
ExistingProviderID: reuseLookup.ExistingProviderID,
|
|
ExistingAccessStatus: reuseLookup.ExistingAccessStatus,
|
|
ExistingCanonicalFamilys: reuseLookup.ExistingCanonicalFamilys,
|
|
})
|
|
|
|
finalItem := sqlite.ImportRunItem{
|
|
ItemID: itemID,
|
|
RunID: runID,
|
|
BaseURL: strings.TrimSpace(entry.BaseURL),
|
|
ProviderID: providerID,
|
|
APIKeyFingerprint: fingerprint,
|
|
RequestedModelsJSON: mustMarshalJSON(entry.RequestedModels, "[]"),
|
|
RawModelsJSON: mustMarshalJSON(rawModels, "[]"),
|
|
NormalizedModelsJSON: mustMarshalJSON(uniqueNormalizedModels(rawModels), "[]"),
|
|
CanonicalFamiliesJSON: mustMarshalJSON(canonicalFamilies, "[]"),
|
|
RecommendedModelsJSON: mustMarshalJSON(recommendedModels, "[]"),
|
|
ResolvedSmokeModel: resolvedSmokeModel,
|
|
CapabilityProfileJSON: mustMarshalJSON(capabilityProfile, "{}"),
|
|
CurrentStage: string(ItemStageConfirm),
|
|
ConfirmationStatus: string(ConfirmationPending),
|
|
AccessStatus: string(AccessStatusUnknown),
|
|
MatchedAccountState: string(reuseDecision.MatchedAccountState),
|
|
AccountResolution: string(reuseDecision.AccountResolution),
|
|
ProvisionReused: reuseDecision.ProvisionReused,
|
|
ReusedFromProviderID: reuseDecision.ReusedFromProviderID,
|
|
ReusedFromAccountID: int64PtrIfSet(reuseDecision.ReusedFromAccountID),
|
|
}
|
|
|
|
if reuseDecision.ProvisionReused {
|
|
patchContract := ModelMappingDelta(reuseLookup.ExistingModelMapping, probe.BuildAliasTable(rawModels))
|
|
if shouldPatchAliases(reuseLookup.ExistingModelMapping, patchContract.ModelMapping) {
|
|
if s.Provisioner == nil {
|
|
if failErr := s.failRun(ctx, req, initialItem, ItemStageProvision, fmt.Errorf("provisioner is required for patch-only flow")); failErr != nil {
|
|
return BatchImportRunResult{}, failErr
|
|
}
|
|
return result, nil
|
|
}
|
|
if err := s.Provisioner.Patch(ctx, PatchProvisionRequest{
|
|
ProviderID: reuseDecision.ReusedFromProviderID,
|
|
Contract: patchContract,
|
|
}); err != nil {
|
|
if failErr := s.failRun(ctx, req, initialItem, ItemStageProvision, err); failErr != nil {
|
|
return BatchImportRunResult{}, failErr
|
|
}
|
|
return result, nil
|
|
}
|
|
}
|
|
} else {
|
|
if s.Provisioner == nil {
|
|
if failErr := s.failRun(ctx, req, initialItem, ItemStageProvision, fmt.Errorf("provisioner is required")); failErr != nil {
|
|
return BatchImportRunResult{}, failErr
|
|
}
|
|
return result, nil
|
|
}
|
|
provisionResult, err := s.Provisioner.Provision(ctx, ProvisionRequest{
|
|
RunID: runID,
|
|
ItemID: itemID,
|
|
Entry: entry,
|
|
ProviderID: providerID,
|
|
ResolvedModel: resolvedSmokeModel,
|
|
RoutingStrategy: routingStrategy,
|
|
CapabilityProfile: capabilityProfile,
|
|
})
|
|
if err != nil {
|
|
if failErr := s.failRun(ctx, req, initialItem, ItemStageProvision, err); failErr != nil {
|
|
return BatchImportRunResult{}, failErr
|
|
}
|
|
return result, nil
|
|
}
|
|
finalItem.LegacyBatchID = provisionResult.LegacyBatchID
|
|
finalItem.LegacyProviderID = strings.TrimSpace(provisionResult.LegacyProviderID)
|
|
}
|
|
|
|
if err := s.ItemStore.Upsert(ctx, finalItem); err != nil {
|
|
return BatchImportRunResult{}, err
|
|
}
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func (s BatchImportService) failRun(ctx context.Context, req BatchImportRunRequest, item sqlite.ImportRunItem, stage ItemStage, cause error) error {
|
|
item.CurrentStage = string(ItemStageDone)
|
|
item.ConfirmationStatus = string(ConfirmationFailed)
|
|
item.AccessStatus = string(AccessStatusBroken)
|
|
item.LastErrorStage = string(stage)
|
|
item.LastError = strings.TrimSpace(cause.Error())
|
|
item.LeaseOwner = ""
|
|
item.LeaseUntil = ""
|
|
item.NextRetryAt = ""
|
|
if err := s.ItemStore.Upsert(ctx, item); err != nil {
|
|
return err
|
|
}
|
|
|
|
return s.RunStore.Update(ctx, sqlite.ImportRun{
|
|
RunID: strings.TrimSpace(req.RunID),
|
|
HostID: strings.TrimSpace(req.HostID),
|
|
Mode: strings.TrimSpace(req.Mode),
|
|
AccessMode: strings.TrimSpace(req.AccessMode),
|
|
SubscriptionUsersJSON: mustMarshalJSON(req.SubscriptionUsers, "[]"),
|
|
SubscriptionDays: req.SubscriptionDays,
|
|
ProbeAPIKey: strings.TrimSpace(req.ProbeAPIKey),
|
|
State: string(RunStateFailed),
|
|
TotalItems: len(req.Entries),
|
|
CompletedItems: 1,
|
|
BrokenItems: 1,
|
|
FinishedAt: time.Now().UTC().Format(time.RFC3339),
|
|
})
|
|
}
|
|
|
|
func uniqueCanonicalFamilies(rawModels []string) []string {
|
|
seen := make(map[string]struct{}, len(rawModels))
|
|
families := make([]string, 0, len(rawModels))
|
|
for _, rawModel := range rawModels {
|
|
family := probe.CanonicalModelFamily(rawModel)
|
|
if family == "" {
|
|
continue
|
|
}
|
|
if _, ok := seen[family]; ok {
|
|
continue
|
|
}
|
|
seen[family] = struct{}{}
|
|
families = append(families, family)
|
|
}
|
|
return families
|
|
}
|
|
|
|
func uniqueNormalizedModels(rawModels []string) []string {
|
|
seen := make(map[string]struct{}, len(rawModels))
|
|
models := make([]string, 0, len(rawModels))
|
|
for _, rawModel := range rawModels {
|
|
normalized := probe.NormalizeModelID(rawModel)
|
|
if normalized == "" {
|
|
continue
|
|
}
|
|
if _, ok := seen[normalized]; ok {
|
|
continue
|
|
}
|
|
seen[normalized] = struct{}{}
|
|
models = append(models, normalized)
|
|
}
|
|
return models
|
|
}
|
|
|
|
// mustMarshalJSON returns the JSON encoding of value as a string, or fallback
// when encoding fails.
//
// fallback is used only on an encoding error. A nil slice, map or pointer
// encodes successfully as "null" and is returned as such.
func mustMarshalJSON(value any, fallback string) string {
	if payload, err := json.Marshal(value); err == nil {
		return string(payload)
	}
	return fallback
}
|
|
|
|
// fingerprintAPIKey returns a short, non-reversible identifier for apiKey.
// The value is "sha256:" followed by the hex of the first 8 bytes of the
// SHA-256 digest of the whitespace-trimmed key. A blank key yields "".
func fingerprintAPIKey(apiKey string) string {
	key := strings.TrimSpace(apiKey)
	if len(key) == 0 {
		return ""
	}
	digest := sha256.Sum256([]byte(key))
	return "sha256:" + fmt.Sprintf("%x", digest[:8])
}
|
|
|
|
// int64PtrIfSet returns a pointer to a fresh copy of value. It returns nil
// when value is zero, which is used here to mean "not set".
func int64PtrIfSet(value int64) *int64 {
	if value != 0 {
		return &value
	}
	return nil
}
|
|
|
|
// shouldPatchAliases reports whether next adds an alias missing from
// existing or maps an alias to a different target.
//
// When existing is empty it always reports false, so no patch is issued for
// a provider without a current mapping. Aliases present only in existing
// never trigger a patch.
func shouldPatchAliases(existing map[string]string, next map[string]string) bool {
	if len(existing) > 0 {
		for alias, target := range next {
			if current, found := existing[alias]; !found || current != target {
				return true
			}
		}
	}
	return false
}
|