289 lines
9.4 KiB
Go
289 lines
9.4 KiB
Go
package provision
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"sub2api-cn-relay-manager/internal/pack"
|
|
"sub2api-cn-relay-manager/internal/store/sqlite"
|
|
)
|
|
|
|
type RuntimeImportRequest struct {
|
|
HostID string
|
|
HostBaseURL string
|
|
Pack pack.LoadedPack
|
|
Provider pack.ProviderManifest
|
|
Mode string
|
|
Access AccessRequest
|
|
Keys []string
|
|
}
|
|
|
|
type RuntimeImportResult struct {
|
|
BatchID int64
|
|
Report ImportReport
|
|
}
|
|
|
|
type RuntimeImportService struct {
|
|
store *sqlite.DB
|
|
host hostAdapter
|
|
}
|
|
|
|
func NewRuntimeImportService(store *sqlite.DB, host hostAdapter) *RuntimeImportService {
|
|
return &RuntimeImportService{store: store, host: host}
|
|
}
|
|
|
|
func (s *RuntimeImportService) Import(ctx context.Context, req RuntimeImportRequest) (RuntimeImportResult, error) {
|
|
if s == nil || s.store == nil {
|
|
return RuntimeImportResult{}, fmt.Errorf("store is required")
|
|
}
|
|
if s.host == nil {
|
|
return RuntimeImportResult{}, fmt.Errorf("host adapter is required")
|
|
}
|
|
req.HostID = strings.TrimSpace(req.HostID)
|
|
req.HostBaseURL = strings.TrimSpace(req.HostBaseURL)
|
|
if req.HostID == "" {
|
|
if req.HostBaseURL == "" {
|
|
return RuntimeImportResult{}, fmt.Errorf("host_id is required")
|
|
}
|
|
hostRow, err := s.store.Hosts().GetByBaseURL(ctx, req.HostBaseURL)
|
|
if err != nil {
|
|
return RuntimeImportResult{}, fmt.Errorf("host_id is required for unregistered host_base_url %q: %w", req.HostBaseURL, err)
|
|
}
|
|
req.HostID = hostRow.HostID
|
|
}
|
|
|
|
hostVersion, err := s.host.GetHostVersion(ctx)
|
|
if err != nil {
|
|
return RuntimeImportResult{}, fmt.Errorf("get host version: %w", err)
|
|
}
|
|
if err := pack.CheckHostCompatibility(req.Pack.Manifest, hostVersion); err != nil {
|
|
return RuntimeImportResult{}, err
|
|
}
|
|
capabilities, err := s.host.ProbeCapabilities(ctx)
|
|
if err != nil {
|
|
return RuntimeImportResult{}, fmt.Errorf("probe host capabilities: %w", err)
|
|
}
|
|
capabilityProbeJSON, err := json.Marshal(capabilities)
|
|
if err != nil {
|
|
return RuntimeImportResult{}, fmt.Errorf("marshal host capabilities: %w", err)
|
|
}
|
|
|
|
hostRow, err := s.ensureHost(ctx, req.HostID, req.HostBaseURL, hostVersion, string(capabilityProbeJSON))
|
|
if err != nil {
|
|
return RuntimeImportResult{}, err
|
|
}
|
|
packRow, err := s.ensurePack(ctx, req.Pack)
|
|
if err != nil {
|
|
return RuntimeImportResult{}, err
|
|
}
|
|
providerRow, err := s.ensureProvider(ctx, packRow.ID, req.Provider)
|
|
if err != nil {
|
|
return RuntimeImportResult{}, err
|
|
}
|
|
|
|
batchID, err := s.store.ImportBatches().Create(ctx, sqlite.ImportBatch{
|
|
HostID: hostRow.ID,
|
|
PackID: packRow.ID,
|
|
ProviderID: providerRow.ID,
|
|
Mode: req.Mode,
|
|
BatchStatus: "running",
|
|
AccessStatus: "pending",
|
|
})
|
|
if err != nil {
|
|
return RuntimeImportResult{}, err
|
|
}
|
|
|
|
report, importErr := NewImportService(s.host).Import(ctx, ImportRequest{
|
|
Provider: req.Provider,
|
|
Mode: req.Mode,
|
|
Access: req.Access,
|
|
Keys: req.Keys,
|
|
})
|
|
if report.BatchStatus == "" {
|
|
report.BatchStatus = BatchStatusFailed
|
|
}
|
|
if report.AccessStatus == "" {
|
|
report.AccessStatus = AccessStatusBroken
|
|
}
|
|
|
|
includeManagedResources := importErr == nil || req.Mode != ImportModeStrict
|
|
if persistErr := s.persistRuntimeArtifacts(ctx, batchID, hostRow.ID, req.Access.Mode, report, includeManagedResources); persistErr != nil {
|
|
return RuntimeImportResult{}, persistErr
|
|
}
|
|
if err := s.store.ImportBatches().UpdateStatus(ctx, batchID, report.BatchStatus, report.AccessStatus); err != nil {
|
|
return RuntimeImportResult{}, err
|
|
}
|
|
if importErr != nil {
|
|
return RuntimeImportResult{BatchID: batchID, Report: report}, importErr
|
|
}
|
|
return RuntimeImportResult{BatchID: batchID, Report: report}, nil
|
|
}
|
|
|
|
func (s *RuntimeImportService) ensureHost(ctx context.Context, hostID, baseURL, hostVersion, capabilityProbeJSON string) (sqlite.Host, error) {
|
|
host, err := s.store.Hosts().GetByHostID(ctx, hostID)
|
|
if err != nil {
|
|
return sqlite.Host{}, fmt.Errorf("registered host %q not found: %w", hostID, err)
|
|
}
|
|
if baseURL != "" && strings.TrimSpace(host.BaseURL) != strings.TrimSpace(baseURL) {
|
|
return sqlite.Host{}, fmt.Errorf("host %q base_url mismatch: registered=%s runtime=%s", hostID, host.BaseURL, baseURL)
|
|
}
|
|
if err := s.store.Hosts().UpdateProbeByHostID(ctx, hostID, hostVersion, capabilityProbeJSON); err != nil {
|
|
return sqlite.Host{}, err
|
|
}
|
|
return s.store.Hosts().GetByHostID(ctx, hostID)
|
|
}
|
|
|
|
func (s *RuntimeImportService) ensurePack(ctx context.Context, loaded pack.LoadedPack) (sqlite.Pack, error) {
|
|
packRow, err := s.store.Packs().GetByPackID(ctx, loaded.Manifest.PackID)
|
|
if err == nil {
|
|
if err := validateExistingPack(packRow, loaded); err != nil {
|
|
return sqlite.Pack{}, err
|
|
}
|
|
}
|
|
packRecord, err := buildPackRecord(loaded)
|
|
if err != nil {
|
|
return sqlite.Pack{}, err
|
|
}
|
|
if _, err := s.store.Packs().Upsert(ctx, packRecord); err != nil {
|
|
return sqlite.Pack{}, err
|
|
}
|
|
return s.store.Packs().GetByPackID(ctx, loaded.Manifest.PackID)
|
|
}
|
|
|
|
func (s *RuntimeImportService) ensureProvider(ctx context.Context, packID int64, provider pack.ProviderManifest) (sqlite.Provider, error) {
|
|
if _, err := s.store.Providers().GetByPackIDAndProviderID(ctx, packID, provider.ProviderID); err == nil {
|
|
// continue into upsert path so metadata stays fresh.
|
|
}
|
|
providerRecord, err := buildProviderRecord(packID, provider)
|
|
if err != nil {
|
|
return sqlite.Provider{}, err
|
|
}
|
|
if _, err := s.store.Providers().Upsert(ctx, providerRecord); err != nil {
|
|
return sqlite.Provider{}, err
|
|
}
|
|
return s.store.Providers().GetByPackIDAndProviderID(ctx, packID, provider.ProviderID)
|
|
}
|
|
|
|
func (s *RuntimeImportService) persistRuntimeArtifacts(ctx context.Context, batchID, hostID int64, accessMode string, report ImportReport, includeManagedResources bool) error {
|
|
for i, account := range report.Accounts {
|
|
validationStatus := account.ValidationStatus()
|
|
payload, err := json.Marshal(map[string]any{
|
|
"account_id": account.Ref.ID,
|
|
"probe_ok": account.Probe.OK,
|
|
"probe_status": account.Probe.Status,
|
|
"probe_message": account.Probe.Message,
|
|
"models": account.Models,
|
|
"smoke_model_seen": account.SmokeModelSeen,
|
|
"probe_advisory": account.HasAdvisoryWarning(),
|
|
"validation_status": validationStatus,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("marshal account probe summary: %w", err)
|
|
}
|
|
itemID, err := s.store.ImportBatchItems().Create(ctx, sqlite.ImportBatchItem{
|
|
BatchID: batchID,
|
|
KeyFingerprint: fingerprintKey(report.AcceptedKeys, i),
|
|
AccountStatus: validationStatus,
|
|
ProbeSummaryJSON: string(payload),
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if _, err := s.store.ProbeResults().Create(ctx, sqlite.ProbeResult{
|
|
BatchItemID: itemID,
|
|
ProbeType: "account_smoke",
|
|
Status: validationStatus,
|
|
SummaryJSON: string(payload),
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if includeManagedResources {
|
|
if err := s.persistManagedResourceIfAbsent(ctx, batchID, hostID, "group", report.Group.ID, report.Group.Name); err != nil {
|
|
return err
|
|
}
|
|
if err := s.persistManagedResourceIfAbsent(ctx, batchID, hostID, "channel", report.Channel.ID, report.Channel.Name); err != nil {
|
|
return err
|
|
}
|
|
if report.Plan != nil {
|
|
if err := s.persistManagedResourceIfAbsent(ctx, batchID, hostID, "plan", report.Plan.ID, report.Plan.Name); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
for _, account := range report.Accounts {
|
|
if err := s.persistManagedResourceIfAbsent(ctx, batchID, hostID, "account", account.Ref.ID, account.Ref.Name); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
accessPayload, err := json.Marshal(map[string]any{
|
|
"status_code": report.Gateway.StatusCode,
|
|
"ok": report.Gateway.OK,
|
|
"has_expected_model": report.Gateway.HasExpectedModel,
|
|
"models": report.Gateway.Models,
|
|
"completion_ok": report.Gateway.CompletionOK,
|
|
"completion_status": report.Gateway.CompletionStatus,
|
|
"completion_type": report.Gateway.CompletionType,
|
|
"completion_preview": report.Gateway.CompletionBody,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("marshal gateway access summary: %w", err)
|
|
}
|
|
if _, err := s.store.AccessClosures().Create(ctx, sqlite.AccessClosureRecord{
|
|
BatchID: batchID,
|
|
ClosureType: firstNonEmpty(strings.TrimSpace(accessMode), "unknown"),
|
|
Status: firstNonEmpty(report.AccessStatus, AccessStatusBroken),
|
|
DetailsJSON: string(accessPayload),
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *RuntimeImportService) persistManagedResourceIfAbsent(ctx context.Context, batchID, hostID int64, resourceType, hostResourceID, resourceName string) error {
|
|
resourceType = strings.TrimSpace(resourceType)
|
|
hostResourceID = strings.TrimSpace(hostResourceID)
|
|
resourceName = firstNonEmpty(resourceName, hostResourceID)
|
|
if resourceType == "" || hostResourceID == "" {
|
|
return nil
|
|
}
|
|
if _, err := s.store.ManagedResources().GetByResourceIdentity(ctx, hostID, resourceType, hostResourceID); err == nil {
|
|
return nil
|
|
}
|
|
if _, err := s.store.ManagedResources().Create(ctx, sqlite.ManagedResource{
|
|
BatchID: batchID,
|
|
HostID: hostID,
|
|
ResourceType: resourceType,
|
|
HostResourceID: hostResourceID,
|
|
ResourceName: resourceName,
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// fingerprintKey returns a non-reversible label for keys[index]: "sha256:"
// followed by the lowercase hex SHA-256 digest of the trimmed key. If index
// is out of range or the key is blank, it returns the 1-based positional
// label "key-N" (N = index+1) instead.
func fingerprintKey(keys []string, index int) string {
	if index < 0 || index >= len(keys) {
		return fmt.Sprintf("key-%d", index+1)
	}
	trimmed := strings.TrimSpace(keys[index])
	if trimmed == "" {
		return fmt.Sprintf("key-%d", index+1)
	}
	digest := sha256.Sum256([]byte(trimmed))
	return fmt.Sprintf("sha256:%x", digest[:])
}
|
|
|
|
// firstNonEmpty returns the first value that is non-empty after trimming
// surrounding whitespace, in its trimmed form. If every value is blank, or
// none are given, it returns "".
func firstNonEmpty(values ...string) string {
	for i := range values {
		candidate := strings.TrimSpace(values[i])
		if len(candidate) > 0 {
			return candidate
		}
	}
	return ""
}
|