Files
sub2api-cn-relay-manager/internal/provision/runtime_import_service.go
2026-05-23 17:06:52 +08:00

297 lines
9.5 KiB
Go

package provision
import (
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"strings"
"sub2api-cn-relay-manager/internal/pack"
"sub2api-cn-relay-manager/internal/store/sqlite"
)
// RuntimeImportRequest carries the inputs for one RuntimeImportService.Import
// call.
type RuntimeImportRequest struct {
// HostID identifies the registered host. Import trims it and, when it is
// empty, resolves it by looking HostBaseURL up in the host store.
HostID string
// HostBaseURL is trimmed by Import. It is the lookup key when HostID is
// empty and, when non-empty, must match the registered host's base URL.
HostBaseURL string
// Pack is the loaded pack; its manifest is checked against the host version
// and the pack is upserted into the store.
Pack pack.LoadedPack
// Provider is upserted under the pack and forwarded to ImportService.
Provider pack.ProviderManifest
// Mode is stored on the import batch and forwarded to ImportService. After
// a failed import, managed resources are only recorded when Mode differs
// from ImportModeStrict.
Mode string
// Access is forwarded to ImportService and used to build the access closure
// record.
Access AccessRequest
// Keys is forwarded unchanged to ImportService.
Keys []string
}
// RuntimeImportResult is what RuntimeImportService.Import returns.
type RuntimeImportResult struct {
// BatchID is the ID returned by the store when the import batch was created.
BatchID int64
// Report is the report produced by ImportService; Import replaces an empty
// BatchStatus or AccessStatus with BatchStatusFailed / AccessStatusBroken.
Report ImportReport
}
// RuntimeImportService runs an import against a host and records the batch,
// its items, probe results, managed resources and access closure in the store.
type RuntimeImportService struct {
// store is the database used for every record the service reads or writes.
store *sqlite.DB
// host is queried for its version and capabilities and is handed to
// ImportService for the host-side import.
host hostAdapter
}
// NewRuntimeImportService builds a RuntimeImportService around the given store
// and host adapter. Neither argument is validated here; Import rejects a nil
// store or host when it is called.
func NewRuntimeImportService(store *sqlite.DB, host hostAdapter) *RuntimeImportService {
	svc := new(RuntimeImportService)
	svc.store, svc.host = store, host
	return svc
}
// Import runs a provider import against the configured host and records the
// outcome in the local store.
//
// The sequence is:
//  1. Resolve the host: an empty HostID is looked up by HostBaseURL.
//  2. Read the host version, check it against the pack manifest, and probe
//     the host capabilities.
//  3. Refresh the registered host row and upsert the pack and provider rows.
//  4. Create an import batch with status "running" and access status "pending".
//  5. Delegate the host-side work to ImportService.
//  6. Persist the per-account artifacts and write the final batch status.
//
// Failures before the batch exists return a zero result. A failed host import
// still has its artifacts persisted and is returned together with the batch ID
// and the report. If the bookkeeping in step 6 fails, a zero result is returned
// with an error that names the batch and wraps the store error; when the host
// import had failed as well, its message is appended so the root cause is not
// lost.
func (s *RuntimeImportService) Import(ctx context.Context, req RuntimeImportRequest) (RuntimeImportResult, error) {
	if s == nil || s.store == nil {
		return RuntimeImportResult{}, fmt.Errorf("store is required")
	}
	if s.host == nil {
		return RuntimeImportResult{}, fmt.Errorf("host adapter is required")
	}

	req.HostID = strings.TrimSpace(req.HostID)
	req.HostBaseURL = strings.TrimSpace(req.HostBaseURL)
	if req.HostID == "" {
		if req.HostBaseURL == "" {
			return RuntimeImportResult{}, fmt.Errorf("host_id is required")
		}
		// No explicit host ID: fall back to the host registered under this base URL.
		registered, err := s.store.Hosts().GetByBaseURL(ctx, req.HostBaseURL)
		if err != nil {
			return RuntimeImportResult{}, fmt.Errorf("host_id is required for unregistered host_base_url %q: %w", req.HostBaseURL, err)
		}
		req.HostID = registered.HostID
	}

	hostVersion, err := s.host.GetHostVersion(ctx)
	if err != nil {
		return RuntimeImportResult{}, fmt.Errorf("get host version: %w", err)
	}
	if err := pack.CheckHostCompatibility(req.Pack.Manifest, hostVersion); err != nil {
		return RuntimeImportResult{}, err
	}
	capabilities, err := s.host.ProbeCapabilities(ctx)
	if err != nil {
		return RuntimeImportResult{}, fmt.Errorf("probe host capabilities: %w", err)
	}
	capabilityProbeJSON, err := json.Marshal(capabilities)
	if err != nil {
		return RuntimeImportResult{}, fmt.Errorf("marshal host capabilities: %w", err)
	}

	hostRow, err := s.ensureHost(ctx, req.HostID, req.HostBaseURL, hostVersion, string(capabilityProbeJSON))
	if err != nil {
		return RuntimeImportResult{}, err
	}
	packRow, err := s.ensurePack(ctx, req.Pack)
	if err != nil {
		return RuntimeImportResult{}, err
	}
	providerRow, err := s.ensureProvider(ctx, packRow.ID, req.Provider)
	if err != nil {
		return RuntimeImportResult{}, err
	}

	batchID, err := s.store.ImportBatches().Create(ctx, sqlite.ImportBatch{
		HostID:       hostRow.ID,
		PackID:       packRow.ID,
		ProviderID:   providerRow.ID,
		Mode:         req.Mode,
		BatchStatus:  "running",
		AccessStatus: "pending",
	})
	if err != nil {
		return RuntimeImportResult{}, err
	}

	report, importErr := NewImportService(s.host).Import(ctx, ImportRequest{
		Provider: req.Provider,
		Mode:     req.Mode,
		Access:   req.Access,
		Keys:     req.Keys,
	})
	// A report without statuses is recorded as a failed batch with broken access.
	if report.BatchStatus == "" {
		report.BatchStatus = BatchStatusFailed
	}
	if report.AccessStatus == "" {
		report.AccessStatus = AccessStatusBroken
	}

	// bookkeepingError names the batch in a store failure and keeps the host
	// import error visible when both steps went wrong.
	bookkeepingError := func(action string, storeErr error) error {
		if importErr != nil {
			return fmt.Errorf("%s for import batch %d: %w (host import had already failed: %v)", action, batchID, storeErr, importErr)
		}
		return fmt.Errorf("%s for import batch %d: %w", action, batchID, storeErr)
	}

	// Managed resources are recorded unless the import failed in strict mode.
	includeManagedResources := importErr == nil || req.Mode != ImportModeStrict
	if err := s.persistRuntimeArtifacts(ctx, batchID, hostRow.ID, req.Access, report, includeManagedResources); err != nil {
		// NOTE(review): the batch row keeps its initial "running" status on
		// this path — confirm whether something else reconciles such batches.
		return RuntimeImportResult{}, bookkeepingError("persist runtime artifacts", err)
	}
	if err := s.store.ImportBatches().UpdateStatus(ctx, batchID, report.BatchStatus, report.AccessStatus); err != nil {
		return RuntimeImportResult{}, bookkeepingError("update status", err)
	}

	// importErr is nil on success; on failure the caller still gets the batch
	// ID and the report alongside the error.
	return RuntimeImportResult{BatchID: batchID, Report: report}, importErr
}
// ensureHost loads the host registered under hostID, checks that a non-empty
// baseURL agrees with the registered base URL (both compared after trimming
// whitespace), stores the freshly probed version and capability JSON, and
// returns the row as re-read from the store.
func (s *RuntimeImportService) ensureHost(ctx context.Context, hostID, baseURL, hostVersion, capabilityProbeJSON string) (sqlite.Host, error) {
	registered, err := s.store.Hosts().GetByHostID(ctx, hostID)
	if err != nil {
		return sqlite.Host{}, fmt.Errorf("registered host %q not found: %w", hostID, err)
	}

	// An empty runtime base URL skips the consistency check entirely.
	if baseURL != "" {
		want := strings.TrimSpace(registered.BaseURL)
		got := strings.TrimSpace(baseURL)
		if want != got {
			return sqlite.Host{}, fmt.Errorf("host %q base_url mismatch: registered=%s runtime=%s", hostID, registered.BaseURL, baseURL)
		}
	}

	err = s.store.Hosts().UpdateProbeByHostID(ctx, hostID, hostVersion, capabilityProbeJSON)
	if err != nil {
		return sqlite.Host{}, err
	}

	// Re-read so the caller sees the row including the probe update.
	return s.store.Hosts().GetByHostID(ctx, hostID)
}
// ensurePack upserts the loaded pack into the store and returns the stored row.
// If a row with the same pack ID can already be read, it is first validated
// against the loaded pack and a validation failure aborts the upsert. Any
// lookup error is treated as "nothing to validate" and the upsert proceeds.
func (s *RuntimeImportService) ensurePack(ctx context.Context, loaded pack.LoadedPack) (sqlite.Pack, error) {
	packID := loaded.Manifest.PackID

	if existing, lookupErr := s.store.Packs().GetByPackID(ctx, packID); lookupErr == nil {
		if err := validateExistingPack(existing, loaded); err != nil {
			return sqlite.Pack{}, err
		}
	}

	record, err := buildPackRecord(loaded)
	if err != nil {
		return sqlite.Pack{}, err
	}
	_, err = s.store.Packs().Upsert(ctx, record)
	if err != nil {
		return sqlite.Pack{}, err
	}

	// Re-read to return the row as stored.
	return s.store.Packs().GetByPackID(ctx, packID)
}
// ensureProvider upserts the provider manifest under the given pack row ID and
// returns the stored provider row.
//
// The upsert runs whether or not a row already exists, so the stored metadata
// always reflects the manifest that was passed in; no prior lookup is needed.
func (s *RuntimeImportService) ensureProvider(ctx context.Context, packID int64, provider pack.ProviderManifest) (sqlite.Provider, error) {
	providerRecord, err := buildProviderRecord(packID, provider)
	if err != nil {
		return sqlite.Provider{}, err
	}
	if _, err := s.store.Providers().Upsert(ctx, providerRecord); err != nil {
		return sqlite.Provider{}, err
	}
	// Re-read to return the row as stored.
	return s.store.Providers().GetByPackIDAndProviderID(ctx, packID, provider.ProviderID)
}
// persistRuntimeArtifacts writes the store records that describe one import
// run for the given batch:
//
//   - per account in the report: a batch item plus an "account_smoke" probe
//     result, both carrying the same JSON probe summary;
//   - when includeManagedResources is true: a managed-resource row for the
//     group, the channel, the plan (if the report has one) and every account,
//     each only if not already recorded;
//   - one access closure record built from the access request and the
//     gateway section of the report.
//
// The first failing write aborts the function and its error is returned.
func (s *RuntimeImportService) persistRuntimeArtifacts(ctx context.Context, batchID, hostID int64, access AccessRequest, report ImportReport, includeManagedResources bool) error {
	for pos, acct := range report.Accounts {
		status := acct.ValidationStatus()
		summary := map[string]any{
			"account_id":        acct.Ref.ID,
			"probe_ok":          acct.Probe.OK,
			"probe_status":      acct.Probe.Status,
			"probe_message":     acct.Probe.Message,
			"models":            acct.Models,
			"smoke_model_seen":  acct.SmokeModelSeen,
			"probe_advisory":    acct.HasAdvisoryWarning(),
			"validation_status": status,
		}
		encoded, err := json.Marshal(summary)
		if err != nil {
			return fmt.Errorf("marshal account probe summary: %w", err)
		}
		summaryJSON := string(encoded)

		// The account at position pos is fingerprinted with the accepted key
		// at the same position.
		item := sqlite.ImportBatchItem{
			BatchID:          batchID,
			KeyFingerprint:   fingerprintKey(report.AcceptedKeys, pos),
			AccountStatus:    status,
			ProbeSummaryJSON: summaryJSON,
		}
		itemID, err := s.store.ImportBatchItems().Create(ctx, item)
		if err != nil {
			return err
		}

		probe := sqlite.ProbeResult{
			BatchItemID: itemID,
			ProbeType:   "account_smoke",
			Status:      status,
			SummaryJSON: summaryJSON,
		}
		if _, err := s.store.ProbeResults().Create(ctx, probe); err != nil {
			return err
		}
	}

	if includeManagedResources {
		// record binds the batch and host so each call only names the resource.
		record := func(kind, id, name string) error {
			return s.persistManagedResourceIfAbsent(ctx, batchID, hostID, kind, id, name)
		}

		if err := record("group", report.Group.ID, report.Group.Name); err != nil {
			return err
		}
		if err := record("channel", report.Channel.ID, report.Channel.Name); err != nil {
			return err
		}
		if report.Plan != nil {
			if err := record("plan", report.Plan.ID, report.Plan.Name); err != nil {
				return err
			}
		}
		for _, acct := range report.Accounts {
			if err := record("account", acct.Ref.ID, acct.Ref.Name); err != nil {
				return err
			}
		}
	}

	details, err := json.Marshal(BuildAccessClosureDetails(access, report.Gateway))
	if err != nil {
		return fmt.Errorf("marshal gateway access summary: %w", err)
	}
	closure := sqlite.AccessClosureRecord{
		BatchID:     batchID,
		ClosureType: firstNonEmpty(strings.TrimSpace(access.Mode), "unknown"),
		Status:      firstNonEmpty(report.AccessStatus, AccessStatusBroken),
		DetailsJSON: string(details),
	}
	_, err = s.store.AccessClosures().Create(ctx, closure)
	return err
}
// subscriptionUserIDs returns the whitespace-trimmed user IDs of the targets,
// in order, leaving out targets whose user ID is blank. The result is never
// nil, even for an empty input.
func subscriptionUserIDs(targets []SubscriptionTarget) []string {
	ids := make([]string, 0, len(targets))
	for i := range targets {
		id := strings.TrimSpace(targets[i].UserID)
		if id == "" {
			continue
		}
		ids = append(ids, id)
	}
	return ids
}
// subscriptionDurationDays reports the duration, in days, of the first target.
// Later targets are not consulted; an empty list yields 0.
func subscriptionDurationDays(targets []SubscriptionTarget) int {
	if len(targets) > 0 {
		return targets[0].DurationDays
	}
	return 0
}
// persistManagedResourceIfAbsent records a managed resource for the host unless
// a row with the same host, type and host resource ID can already be read.
//
// The type and ID are trimmed first; if either is blank the call is a no-op.
// A blank resource name falls back to the trimmed ID. Any lookup error is
// treated as "not recorded yet" and leads to a create attempt.
func (s *RuntimeImportService) persistManagedResourceIfAbsent(ctx context.Context, batchID, hostID int64, resourceType, hostResourceID, resourceName string) error {
	kind := strings.TrimSpace(resourceType)
	id := strings.TrimSpace(hostResourceID)
	if kind == "" || id == "" {
		return nil
	}

	_, lookupErr := s.store.ManagedResources().GetByResourceIdentity(ctx, hostID, kind, id)
	if lookupErr == nil {
		// Already tracked; keep the existing row untouched.
		return nil
	}

	row := sqlite.ManagedResource{
		BatchID:        batchID,
		HostID:         hostID,
		ResourceType:   kind,
		HostResourceID: id,
		ResourceName:   firstNonEmpty(resourceName, id),
	}
	_, err := s.store.ManagedResources().Create(ctx, row)
	return err
}
// fingerprintKey returns a printable fingerprint for keys[index]: the string
// "sha256:" followed by the lowercase hex SHA-256 of the whitespace-trimmed
// key. When index is out of range or the trimmed key is empty, it returns the
// positional placeholder "key-N" with N = index+1 instead.
func fingerprintKey(keys []string, index int) string {
	var key string
	if index >= 0 && index < len(keys) {
		key = strings.TrimSpace(keys[index])
	}
	if key == "" {
		return fmt.Sprintf("key-%d", index+1)
	}
	digest := sha256.Sum256([]byte(key))
	return fmt.Sprintf("sha256:%x", digest[:])
}
// firstNonEmpty returns the first argument that is not blank after trimming
// whitespace, in its trimmed form. It returns "" when every argument is blank
// or no arguments are given.
func firstNonEmpty(values ...string) string {
	for i := range values {
		trimmed := strings.TrimSpace(values[i])
		if trimmed == "" {
			continue
		}
		return trimmed
	}
	return ""
}