176 lines
5.3 KiB
Go
176 lines
5.3 KiB
Go
package provision
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"sub2api-cn-relay-manager/internal/host/sub2api"
|
|
packdef "sub2api-cn-relay-manager/internal/pack"
|
|
"sub2api-cn-relay-manager/internal/store/sqlite"
|
|
)
|
|
|
|
type PackInstallRequest struct {
|
|
Pack packdef.LoadedPack
|
|
}
|
|
|
|
type PackInstallResult struct {
|
|
Pack sqlite.Pack
|
|
Providers []sqlite.Provider
|
|
HostVersion string
|
|
AlreadyInstalled bool
|
|
}
|
|
|
|
type PackInstallService struct {
|
|
store *sqlite.DB
|
|
host sub2api.HostAdapter
|
|
}
|
|
|
|
func NewPackInstallService(store *sqlite.DB, host sub2api.HostAdapter) *PackInstallService {
|
|
return &PackInstallService{store: store, host: host}
|
|
}
|
|
|
|
func (s *PackInstallService) Install(ctx context.Context, req PackInstallRequest) (PackInstallResult, error) {
|
|
if s == nil || s.store == nil {
|
|
return PackInstallResult{}, fmt.Errorf("store is required")
|
|
}
|
|
if s.host == nil {
|
|
return PackInstallResult{}, fmt.Errorf("host adapter is required")
|
|
}
|
|
if strings.TrimSpace(req.Pack.Manifest.PackID) == "" {
|
|
return PackInstallResult{}, fmt.Errorf("pack manifest is required")
|
|
}
|
|
|
|
hostVersion, err := s.host.GetHostVersion(ctx)
|
|
if err != nil {
|
|
return PackInstallResult{}, fmt.Errorf("get host version: %w", err)
|
|
}
|
|
if err := packdef.CheckHostCompatibility(req.Pack.Manifest, hostVersion); err != nil {
|
|
return PackInstallResult{}, err
|
|
}
|
|
|
|
result := PackInstallResult{HostVersion: hostVersion}
|
|
if err := s.store.WithTx(ctx, func(queries *sqlite.Queries) error {
|
|
existing, err := queries.Packs.GetByPackID(ctx, req.Pack.Manifest.PackID)
|
|
if err == nil {
|
|
if err := validateExistingPack(existing, req.Pack); err != nil {
|
|
return err
|
|
}
|
|
result.AlreadyInstalled = true
|
|
} else if !errors.Is(err, sql.ErrNoRows) {
|
|
return err
|
|
}
|
|
|
|
packRow, err := buildPackRecord(req.Pack)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if _, err := queries.Packs.Upsert(ctx, packRow); err != nil {
|
|
return err
|
|
}
|
|
persistedPack, err := queries.Packs.GetByPackID(ctx, req.Pack.Manifest.PackID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
result.Pack = persistedPack
|
|
|
|
providers := make([]sqlite.Provider, 0, len(req.Pack.Providers))
|
|
for _, providerManifest := range req.Pack.Providers {
|
|
providerRow, err := buildProviderRecord(persistedPack.ID, providerManifest)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if _, err := queries.Providers.Upsert(ctx, providerRow); err != nil {
|
|
return err
|
|
}
|
|
persistedProvider, err := queries.Providers.GetByPackIDAndProviderID(ctx, persistedPack.ID, providerManifest.ProviderID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
providers = append(providers, persistedProvider)
|
|
}
|
|
result.Providers = providers
|
|
return nil
|
|
}); err != nil {
|
|
return PackInstallResult{}, err
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func validateExistingPack(existing sqlite.Pack, loaded packdef.LoadedPack) error {
|
|
if strings.TrimSpace(existing.PackID) != strings.TrimSpace(loaded.Manifest.PackID) {
|
|
return fmt.Errorf("existing pack %q does not match loaded pack %q", existing.PackID, loaded.Manifest.PackID)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func buildPackRecord(loaded packdef.LoadedPack) (sqlite.Pack, error) {
|
|
manifestJSON, err := json.Marshal(loaded.Manifest)
|
|
if err != nil {
|
|
return sqlite.Pack{}, fmt.Errorf("marshal pack manifest: %w", err)
|
|
}
|
|
return sqlite.Pack{
|
|
PackID: loaded.Manifest.PackID,
|
|
Version: loaded.Manifest.Version,
|
|
Checksum: loaded.Checksum,
|
|
Vendor: loaded.Manifest.Vendor,
|
|
TargetHost: loaded.Manifest.TargetHost,
|
|
MinHostVersion: loaded.Manifest.MinHostVersion,
|
|
MaxHostVersion: loaded.Manifest.MaxHostVersion,
|
|
ManifestJSON: string(manifestJSON),
|
|
}, nil
|
|
}
|
|
|
|
func buildProviderRecord(packID int64, provider packdef.ProviderManifest) (sqlite.Provider, error) {
|
|
defaultModelsJSON, err := marshalJSONString(provider.DefaultModels)
|
|
if err != nil {
|
|
return sqlite.Provider{}, fmt.Errorf("marshal provider default models: %w", err)
|
|
}
|
|
groupTemplateJSON, err := marshalJSONString(provider.GroupTemplate)
|
|
if err != nil {
|
|
return sqlite.Provider{}, fmt.Errorf("marshal group template: %w", err)
|
|
}
|
|
channelTemplateJSON, err := marshalJSONString(provider.ChannelTemplate)
|
|
if err != nil {
|
|
return sqlite.Provider{}, fmt.Errorf("marshal channel template: %w", err)
|
|
}
|
|
planTemplateJSON, err := marshalJSONString(provider.PlanTemplate)
|
|
if err != nil {
|
|
return sqlite.Provider{}, fmt.Errorf("marshal plan template: %w", err)
|
|
}
|
|
importOptionsJSON, err := marshalJSONString(provider.Import)
|
|
if err != nil {
|
|
return sqlite.Provider{}, fmt.Errorf("marshal import options: %w", err)
|
|
}
|
|
manifestJSON, err := marshalJSONString(provider)
|
|
if err != nil {
|
|
return sqlite.Provider{}, fmt.Errorf("marshal provider manifest: %w", err)
|
|
}
|
|
return sqlite.Provider{
|
|
PackID: packID,
|
|
ProviderID: provider.ProviderID,
|
|
DisplayName: provider.DisplayName,
|
|
BaseURL: provider.BaseURL,
|
|
Platform: provider.Platform,
|
|
AccountType: provider.AccountType,
|
|
DefaultModelsJSON: defaultModelsJSON,
|
|
SmokeTestModel: provider.SmokeTestModel,
|
|
GroupTemplateJSON: groupTemplateJSON,
|
|
ChannelTemplateJSON: channelTemplateJSON,
|
|
PlanTemplateJSON: planTemplateJSON,
|
|
ImportOptionsJSON: importOptionsJSON,
|
|
ManifestJSON: manifestJSON,
|
|
}, nil
|
|
}
|
|
|
|
// marshalJSONString JSON-encodes value with encoding/json and returns the
// result as a string. On failure it returns "" and the encoder's error,
// unwrapped.
func marshalJSONString(value any) (string, error) {
	var text string
	encoded, err := json.Marshal(value)
	if err == nil {
		text = string(encoded)
	}
	return text, err
}
|