Files
sub2api-cn-relay-manager/internal/app/reconcile_background.go

313 lines
9.4 KiB
Go

package app
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"strings"
"time"
"sub2api-cn-relay-manager/internal/host/sub2api"
"sub2api-cn-relay-manager/internal/pack"
"sub2api-cn-relay-manager/internal/provision"
"sub2api-cn-relay-manager/internal/reconcile"
"sub2api-cn-relay-manager/internal/store/sqlite"
"sub2api-cn-relay-manager/internal/worker"
)
const sqliteTimestampLayout = "2006-01-02 15:04:05"
func runReconcileBackgroundScheduler(ctx context.Context, sqliteDSN string, interval time.Duration) {
if interval <= 0 {
return
}
worker.NewRunner(
[]worker.Job{reconcileSweepJob{sqliteDSN: sqliteDSN, interval: interval}},
interval,
log.Printf,
).Start(ctx)
}
// reconcileSweepJob is the worker.Job that performs one background reconcile
// sweep against the SQLite database identified by sqliteDSN.
type reconcileSweepJob struct {
	sqliteDSN string        // passed to sqlite.Open on every run
	interval  time.Duration // minimum time between reconcile runs for one provider/host pair
}

// Name returns the job label.
func (j reconcileSweepJob) Name() string {
	const jobName = "reconcile background scheduler"
	return jobName
}

// Run opens a fresh store for this sweep and runs it using the current time.
// The store is closed before Run returns.
func (j reconcileSweepJob) Run(ctx context.Context) error {
	db, openErr := sqlite.Open(ctx, j.sqliteDSN)
	if openErr != nil {
		return openErr
	}
	defer db.Close()

	startedAt := time.Now()
	return runReconcileBackgroundSweep(ctx, db, j.interval, startedAt)
}
// runReconcileBackgroundSweep makes one pass over the latest reconcilable
// import batches. For each batch it starts a reconcile when the batch's
// provider/host pair has not been reconciled within interval, measured
// against now.
//
// A failure on one batch does not stop the sweep. Failures are collected and
// returned together via errors.Join.
//
// If ctx is cancelled mid-sweep, the loop stops and the context error is
// returned:
//   - on its own when no batch had failed, so == comparisons still match;
//   - otherwise joined with the failures collected so far, so they are not lost.
func runReconcileBackgroundSweep(ctx context.Context, store *sqlite.DB, interval time.Duration, now time.Time) error {
	if store == nil {
		return errors.New("store is required")
	}
	candidates, err := store.ImportBatches().ListLatestReconcilable(ctx)
	if err != nil {
		return err
	}
	var errs []error
	for _, batch := range candidates {
		if ctxErr := ctx.Err(); ctxErr != nil {
			if len(errs) == 0 {
				return ctxErr
			}
			return errors.Join(append(errs, ctxErr)...)
		}
		lastRun, err := latestReconcileRunForBatch(ctx, store, batch.ProviderID, batch.HostID)
		if err != nil {
			errs = append(errs, fmt.Errorf("load latest reconcile run for batch %d: %w", batch.ID, err))
			continue
		}
		if !reconcileRunDue(now, lastRun, interval) {
			continue
		}
		if err := runReconcileCandidate(ctx, store, batch); err != nil {
			errs = append(errs, fmt.Errorf("run reconcile for batch %d: %w", batch.ID, err))
		}
	}
	return errors.Join(errs...)
}
// latestReconcileRunForBatch returns the reconcile run the sweep treats as the
// most recent for the given provider/host pair. It returns (nil, nil) when no
// run has been recorded yet.
//
// NOTE(review): this takes the first row returned by
// GetByProviderIDAndHostID, which assumes that query orders runs
// newest-first. Confirm against the store implementation.
func latestReconcileRunForBatch(ctx context.Context, store *sqlite.DB, providerID, hostID int64) (*sqlite.ReconcileRun, error) {
	runs, err := store.ReconcileRuns().GetByProviderIDAndHostID(ctx, providerID, hostID)
	switch {
	case err != nil:
		return nil, err
	case len(runs) == 0:
		return nil, nil
	default:
		return &runs[0], nil
	}
}
// reconcileRunDue reports whether a new reconcile run should start at now.
// run is the most recent previous run, or nil when there is none.
//
// A run is due when:
//   - there is no previous run;
//   - interval is non-positive;
//   - the previous run's CreatedAt cannot be parsed (fail open: reconcile
//     rather than skip silently);
//   - at least interval has elapsed since CreatedAt.
//
// CreatedAt is accepted in SQLite CURRENT_TIMESTAMP form
// ("2006-01-02 15:04:05", interpreted as UTC). It is also accepted in RFC 3339
// and Go time.Time.String() forms, which some SQLite drivers produce when a
// DATETIME column is scanned into a string. Without those fallbacks, such a
// driver would make every sweep treat every pair as due.
func reconcileRunDue(now time.Time, run *sqlite.ReconcileRun, interval time.Duration) bool {
	if run == nil || interval <= 0 {
		return true
	}
	lastRunAt, ok := parseReconcileRunTimestamp(run.CreatedAt)
	if !ok {
		return true
	}
	return now.Sub(lastRunAt) >= interval
}

// reconcileRunTimestampLayouts lists, in order, the CreatedAt layouts that
// parseReconcileRunTimestamp tries. Layouts without a zone are read as UTC.
var reconcileRunTimestampLayouts = []string{
	sqliteTimestampLayout,
	"2006-01-02 15:04:05Z07:00",
	"2006-01-02T15:04:05",
	time.RFC3339Nano,
	"2006-01-02 15:04:05 -0700 MST",
}

// parseReconcileRunTimestamp parses a stored CreatedAt value with the first
// matching layout in reconcileRunTimestampLayouts.
func parseReconcileRunTimestamp(raw string) (time.Time, bool) {
	raw = strings.TrimSpace(raw)
	for _, layout := range reconcileRunTimestampLayouts {
		if ts, err := time.ParseInLocation(layout, raw, time.UTC); err == nil {
			return ts, true
		}
	}
	return time.Time{}, false
}
// runReconcileCandidate rebuilds a reconcile request for one import batch from
// stored state (host, pack, provider, access closures) and runs
// reconcile.Service against the batch's host.
//
// The stored pack and provider manifests are decoded before the access probe
// key is resolved. For subscription closures, resolving the key calls the host
// API, so a corrupt manifest row now fails before any host-side side effects.
// Lookup failures are wrapped with the record they were loading.
func runReconcileCandidate(ctx context.Context, store *sqlite.DB, batch sqlite.ImportBatch) error {
	hostRow, err := store.Hosts().GetByID(ctx, batch.HostID)
	if err != nil {
		return fmt.Errorf("load host %d: %w", batch.HostID, err)
	}
	packRow, err := store.Packs().GetByID(ctx, batch.PackID)
	if err != nil {
		return fmt.Errorf("load pack %v: %w", batch.PackID, err)
	}
	providerRow, err := store.Providers().GetByID(ctx, batch.ProviderID)
	if err != nil {
		return fmt.Errorf("load provider %d: %w", batch.ProviderID, err)
	}

	// Pure local decoding first: fail fast before any remote calls.
	loadedPack, err := storedLoadedPack(packRow)
	if err != nil {
		return err
	}
	providerManifest, err := storedProviderManifest(providerRow)
	if err != nil {
		return err
	}

	accessClosures, err := store.AccessClosures().GetByBatchID(ctx, batch.ID)
	if err != nil {
		return fmt.Errorf("load access closures: %w", err)
	}
	accessProbeAPIKey, err := reconcileProbeAPIKey(ctx, store, hostRow, batch, accessClosures)
	if err != nil {
		return fmt.Errorf("resolve access probe api key: %w", err)
	}
	client, err := newSub2APIClient(hostRow.BaseURL, authFromStoredHost(hostRow))
	if err != nil {
		return fmt.Errorf("create sub2api client: %w", err)
	}
	_, err = reconcile.NewService(store, client).Reconcile(ctx, reconcile.Request{
		HostID:            hostRow.HostID,
		HostBaseURL:       hostRow.BaseURL,
		AccessProbeAPIKey: accessProbeAPIKey,
		Pack:              loadedPack,
		Provider:          providerManifest,
	})
	return err
}
// reconcileProbeAPIKey returns the API key that a reconcile run uses to probe
// access on the host. The key is derived from the batch's access closure.
//
// Only the last element of accessClosures is consulted.
// NOTE(review): this assumes AccessClosures().GetByBatchID returns closures
// oldest-first, so the last entry is the most recent. Confirm against the query.
//
// Behavior per closure type:
//   - provision.AccessModeSelfService: reads the key from the closure's
//     details JSON, preferring "access_api_key" and falling back to
//     "probe_api_key".
//   - provision.AccessModeSubscription: calls the host to ensure subscription
//     access for the first entry in "subscription_users" on the batch's managed
//     "group" resource. When "subscription_days" > 0 it also assigns a
//     subscription of that length. It then returns the API key the host
//     reported.
//
// It returns an error for a missing closure, a missing or empty key, missing
// subscription users, any failed lookup or host call, or an unsupported
// closure type.
func reconcileProbeAPIKey(ctx context.Context, store *sqlite.DB, hostRow sqlite.Host, batch sqlite.ImportBatch, accessClosures []sqlite.AccessClosureRecord) (string, error) {
	if len(accessClosures) == 0 {
		return "", fmt.Errorf("access closure not found for batch %d", batch.ID)
	}
	latestClosure := accessClosures[len(accessClosures)-1]
	switch strings.TrimSpace(latestClosure.ClosureType) {
	case provision.AccessModeSelfService:
		details := parseAccessClosureDetails(latestClosure.DetailsJSON)
		// Non-string or absent values yield "" and trigger the fallback key.
		apiKey, _ := details["access_api_key"].(string)
		if strings.TrimSpace(apiKey) == "" {
			apiKey, _ = details["probe_api_key"].(string)
		}
		if strings.TrimSpace(apiKey) == "" {
			return "", fmt.Errorf("self_service access closure missing probe api key")
		}
		return strings.TrimSpace(apiKey), nil
	case provision.AccessModeSubscription:
		details := parseAccessClosureDetails(latestClosure.DetailsJSON)
		subscriptionUsers := parseJSONStringArray(details["subscription_users"])
		if len(subscriptionUsers) == 0 {
			return "", fmt.Errorf("subscription access closure missing subscription_users")
		}
		// Absent or non-numeric subscription_days parses as 0, which skips assignment below.
		subscriptionDays := parseJSONInt(details["subscription_days"])
		groupID, err := resolveManagedResourceHostIDByBatch(ctx, store, batch.ID, "group")
		if err != nil {
			return "", err
		}
		client, err := newSub2APIClient(hostRow.BaseURL, authFromStoredHost(hostRow))
		if err != nil {
			return "", err
		}
		// Only the first subscription user is used for probing.
		accessRef, err := client.EnsureSubscriptionAccess(ctx, sub2api.EnsureSubscriptionAccessRequest{
			UserSelector: subscriptionUsers[0],
			GroupID:      groupID,
		})
		if err != nil {
			return "", err
		}
		// If the host reports no user ID, fall back to the configured selector.
		userID := strings.TrimSpace(accessRef.UserID)
		if userID == "" {
			userID = subscriptionUsers[0]
		}
		// NOTE(review): this path runs on every due background sweep, so
		// AssignSubscription is called repeatedly with the same duration.
		// Confirm the host treats it as idempotent rather than extending the
		// subscription each time.
		if subscriptionDays > 0 {
			if _, err := client.AssignSubscription(ctx, sub2api.AssignSubscriptionRequest{
				UserID:       userID,
				GroupID:      groupID,
				DurationDays: subscriptionDays,
			}); err != nil {
				return "", err
			}
		}
		// The key is validated only after the host-side calls above have run.
		if strings.TrimSpace(accessRef.APIKey) == "" {
			return "", fmt.Errorf("subscription access api key is empty")
		}
		return strings.TrimSpace(accessRef.APIKey), nil
	default:
		return "", fmt.Errorf("unsupported access closure type %q", latestClosure.ClosureType)
	}
}
// parseAccessClosureDetails decodes an access closure's details JSON object.
//
// The result is never nil. Invalid JSON, non-object JSON, empty input, and the
// literal null all yield an empty map. Previously, json.Unmarshal of "null"
// would reset the map to nil. Numbers decode as float64, per encoding/json
// rules for interface values.
func parseAccessClosureDetails(raw string) map[string]any {
	var payload map[string]any
	if err := json.Unmarshal([]byte(strings.TrimSpace(raw)), &payload); err != nil || payload == nil {
		return map[string]any{}
	}
	return payload
}
// parseJSONStringArray extracts the non-blank strings from a decoded JSON
// array, trimming surrounding whitespace from each.
//
// Non-string elements and blank strings are skipped. If raw is not a []any,
// it returns nil. An array with no usable strings yields an empty, non-nil
// slice.
func parseJSONStringArray(raw any) []string {
	items, isSlice := raw.([]any)
	if !isSlice {
		return nil
	}
	out := make([]string, 0, len(items))
	for _, item := range items {
		s, isString := item.(string)
		if !isString {
			continue
		}
		if s = strings.TrimSpace(s); s == "" {
			continue
		}
		out = append(out, s)
	}
	return out
}
// parseJSONInt converts a decoded JSON number to an int. It returns 0 for
// anything that is not numeric, including numeric strings.
//
// Accepted inputs:
//   - float64: what encoding/json produces for numbers decoded into any;
//   - json.Number: what a decoder with UseNumber produces;
//   - int and int64: values built directly in Go.
//
// Fractional values are truncated toward zero.
func parseJSONInt(raw any) int {
	switch value := raw.(type) {
	case float64:
		return int(value)
	case int:
		return value
	case int64:
		return int(value)
	case json.Number:
		if n, err := value.Int64(); err == nil {
			return int(n)
		}
		if f, err := value.Float64(); err == nil {
			return int(f)
		}
		return 0
	default:
		return 0
	}
}
// storedLoadedPack rebuilds a pack.LoadedPack from a stored pack row.
//
// The row's ManifestJSON is decoded when it is neither blank nor "{}". Any
// manifest identity field that is still empty or whitespace-only is then
// filled from the matching trimmed column of the row. The checksum always
// comes from the row. A malformed ManifestJSON is reported as an error.
func storedLoadedPack(packRow sqlite.Pack) (pack.LoadedPack, error) {
	var manifest pack.Manifest
	if raw := strings.TrimSpace(packRow.ManifestJSON); raw != "" && raw != "{}" {
		if err := json.Unmarshal([]byte(raw), &manifest); err != nil {
			return pack.LoadedPack{}, fmt.Errorf("decode stored pack manifest: %w", err)
		}
	}

	// fillBlank copies the trimmed column value into a blank manifest field.
	fillBlank := func(field *string, column string) {
		if strings.TrimSpace(*field) == "" {
			*field = strings.TrimSpace(column)
		}
	}
	fillBlank(&manifest.PackID, packRow.PackID)
	fillBlank(&manifest.Version, packRow.Version)
	fillBlank(&manifest.Vendor, packRow.Vendor)
	fillBlank(&manifest.TargetHost, packRow.TargetHost)
	fillBlank(&manifest.MinHostVersion, packRow.MinHostVersion)
	fillBlank(&manifest.MaxHostVersion, packRow.MaxHostVersion)

	loaded := pack.LoadedPack{
		Manifest: manifest,
		Checksum: strings.TrimSpace(packRow.Checksum),
	}
	return loaded, nil
}
// storedProviderManifest rebuilds a pack.ProviderManifest from a stored
// provider row.
//
// The row's ManifestJSON is decoded when it is neither blank nor "{}". Any
// manifest field that is still empty or whitespace-only is then filled from
// the matching trimmed column of the row. A malformed ManifestJSON is reported
// as an error.
func storedProviderManifest(providerRow sqlite.Provider) (pack.ProviderManifest, error) {
	var manifest pack.ProviderManifest
	if raw := strings.TrimSpace(providerRow.ManifestJSON); raw != "" && raw != "{}" {
		if err := json.Unmarshal([]byte(raw), &manifest); err != nil {
			return pack.ProviderManifest{}, fmt.Errorf("decode stored provider manifest: %w", err)
		}
	}

	// fillBlank copies the trimmed column value into a blank manifest field.
	fillBlank := func(field *string, column string) {
		if strings.TrimSpace(*field) == "" {
			*field = strings.TrimSpace(column)
		}
	}
	fillBlank(&manifest.ProviderID, providerRow.ProviderID)
	fillBlank(&manifest.DisplayName, providerRow.DisplayName)
	fillBlank(&manifest.BaseURL, providerRow.BaseURL)
	fillBlank(&manifest.Platform, providerRow.Platform)
	fillBlank(&manifest.AccountType, providerRow.AccountType)
	fillBlank(&manifest.SmokeTestModel, providerRow.SmokeTestModel)

	return manifest, nil
}
// resolveManagedResourceHostIDByBatch returns the trimmed host-side ID of the
// first managed resource of resourceType recorded for the batch that has a
// non-blank HostResourceID. Resource types are compared after trimming
// whitespace. It returns an error when the lookup fails or no such resource
// exists.
func resolveManagedResourceHostIDByBatch(ctx context.Context, store *sqlite.DB, batchID int64, resourceType string) (string, error) {
	resources, err := store.ManagedResources().GetByBatchID(ctx, batchID)
	if err != nil {
		return "", err
	}
	wantType := strings.TrimSpace(resourceType)
	for _, res := range resources {
		if strings.TrimSpace(res.ResourceType) != wantType {
			continue
		}
		if hostID := strings.TrimSpace(res.HostResourceID); hostID != "" {
			return hostID, nil
		}
	}
	return "", fmt.Errorf("managed resource %q not found for batch %d", wantType, batchID)
}