Files
sub2api-cn-relay-manager/internal/pack/loader.go

386 lines
14 KiB
Go

package pack
import (
	"bufio"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"net/url"
	"os"
	"path/filepath"
	"sort"
	"strings"
)
// envAllowInsecureProviderBaseURL names the environment variable that relaxes
// the provider base_url check. When its value, with surrounding whitespace
// trimmed, equals "1" or "true" (case-insensitive), validateProviders also
// accepts http:// base URLs instead of requiring https://.
const envAllowInsecureProviderBaseURL = "SUB2API_CRM_ALLOW_INSECURE_PROVIDER_BASE_URLS"
// Manifest is the decoded form of a pack's top-level pack.json file.
//
// validateManifest requires pack_id, version, target_host, providers_dir and
// checksum_file to be non-blank. The remaining fields are decoded but are not
// checked by this loader.
type Manifest struct {
// PackID identifies the pack (required).
PackID string `json:"pack_id"`
// Version is required; its format is not checked here.
Version string `json:"version"`
// Vendor is decoded but not validated by this loader.
Vendor string `json:"vendor"`
// TargetHost is required. Presumably it names the host application the pack
// targets; this loader does not interpret it (verify against consumers).
TargetHost string `json:"target_host"`
// MinHostVersion and MaxHostVersion are decoded but neither parsed nor
// enforced by this loader.
MinHostVersion string `json:"min_host_version"`
MaxHostVersion string `json:"max_host_version"`
// ProvidersDir is the directory, joined onto the pack root, whose *.json
// files are decoded as ProviderManifest values (required).
ProvidersDir string `json:"providers_dir"`
// ChecksumFile is the path, joined onto the pack root, of the list of
// SHA-256 digests used to verify pack contents (required). The SHA-256 of
// this file itself becomes LoadedPack.Checksum.
ChecksumFile string `json:"checksum_file"`
}
// ProviderManifest is the decoded form of one *.json file in the pack's
// providers directory. validateProviders enforces the constraints noted on
// the fields below. Fields without a note are decoded but not checked by this
// loader.
type ProviderManifest struct {
// ProviderID is required and must be unique (after trimming) within the
// pack. Loaded providers are sorted by this value.
ProviderID string `json:"provider_id"`
// DisplayName is required.
DisplayName string `json:"display_name"`
// BaseURL must be an https URL. http is accepted only when the
// SUB2API_CRM_ALLOW_INSECURE_PROVIDER_BASE_URLS environment variable is
// set to "1" or "true".
BaseURL string `json:"base_url"`
// Platform and AccountType are required; their values are not interpreted here.
Platform string `json:"platform"`
AccountType string `json:"account_type"`
// ForceDisableOpenAIResponsesAPI is optional and unused by this loader.
// Presumably it is consumed when the provider is applied to the host (verify).
ForceDisableOpenAIResponsesAPI bool `json:"force_disable_openai_responses_api,omitempty"`
// HostOverlays is optional; each entry is validated by validateHostOverlays.
HostOverlays []HostOverlay `json:"host_overlays,omitempty"`
// DefaultModels must be non-empty and must contain SmokeTestModel. Each
// non-blank entry must appear as a key or a value in ChannelTemplate.ModelMapping.
DefaultModels []string `json:"default_models"`
// SmokeTestModel is required. It must appear in DefaultModels and as a key
// or a value in ChannelTemplate.ModelMapping.
SmokeTestModel string `json:"smoke_test_model"`
GroupTemplate GroupTemplate `json:"group_template"`
ChannelTemplate ChannelTemplate `json:"channel_template"`
PlanTemplate PlanTemplate `json:"plan_template"`
// Import is decoded but not validated by this loader.
Import ImportOptions `json:"import"`
}
// HostOverlay is one entry of a provider manifest's optional host_overlays
// list. validateHostOverlays requires OverlayID (unique within the provider),
// DisplayName, TargetHost and Reason to be non-blank.
type HostOverlay struct {
OverlayID string `json:"overlay_id"`
DisplayName string `json:"display_name"`
TargetHost string `json:"target_host"`
// MinHostVersion and MaxHostVersion are optional bounds that must parse when
// set. MaxHostVersion may be a ".x" wildcard. A non-wildcard maximum must not
// compare below the minimum.
MinHostVersion string `json:"min_host_version,omitempty"`
MaxHostVersion string `json:"max_host_version,omitempty"`
// ApplyMode is decoded but not validated by this loader.
ApplyMode string `json:"apply_mode,omitempty"`
// PatchPath and NotesPath are optional file references relative to the pack
// root. When set they must not be absolute and the referenced path must exist.
PatchPath string `json:"patch_path,omitempty"`
NotesPath string `json:"notes_path,omitempty"`
Reason string `json:"reason"`
}
// GroupTemplate is the provider's group_template section. validateProviders
// requires Name to be non-blank. RateMultiplier is decoded but not validated
// by this loader.
type GroupTemplate struct {
Name string `json:"name"`
RateMultiplier float64 `json:"rate_multiplier"`
}
// ChannelTemplate is the provider's channel_template section.
// validateProviders requires a non-blank Name and a non-empty ModelMapping.
// The mapping must mention the provider's smoke test model and every
// non-blank default model, either as a key or as a value (compared after
// trimming whitespace).
// NOTE(review): the key/value direction is presumably requested model to
// upstream model. The loader accepts a match on either side, so confirm.
type ChannelTemplate struct {
Name string `json:"name"`
ModelMapping map[string]string `json:"model_mapping"`
}
// PlanTemplate is the provider's plan_template section. validateProviders
// requires a non-blank Name and a positive ValidityDays. Price and
// ValidityUnit are decoded but not validated by this loader.
type PlanTemplate struct {
Name string `json:"name"`
Price float64 `json:"price"`
ValidityDays int `json:"validity_days"`
ValidityUnit string `json:"validity_unit"`
}
// ImportOptions is the provider's import section. Its flags are decoded but
// neither validated nor used by this loader. Presumably they advertise which
// import modes the provider supports (verify against consumers).
type ImportOptions struct {
SupportsMultiKey bool `json:"supports_multi_key"`
SupportsStrict bool `json:"supports_strict"`
SupportsPartial bool `json:"supports_partial"`
}
// LoadedPack is the result of a successful LoadDir call.
type LoadedPack struct {
// Dir is the pack directory passed to LoadDir, with surrounding whitespace trimmed.
Dir string
// Manifest is the validated pack.json.
Manifest Manifest
// Providers holds the validated provider manifests, sorted by ProviderID.
Providers []ProviderManifest
// Checksum is the lowercase hex SHA-256 of the pack's checksum file.
Checksum string
}
// LoadDir reads and validates the content pack rooted at dir.
//
// The directory must contain a pack.json manifest. Loading runs in this order:
//  1. decode and validate pack.json;
//  2. verify every entry listed in the manifest's checksum file;
//  3. decode the provider manifests (*.json) found in providers_dir and
//     require at least one;
//  4. validate the providers;
//  5. fingerprint the checksum file itself.
//
// The first failure aborts loading and is returned together with a zero
// LoadedPack.
//
// NOTE(review): the checksum file is read twice (once to verify, once to
// fingerprint), and provider files are read again after verification. A
// concurrent change to the pack directory between those reads goes unnoticed.
// Provider files that are not listed in the checksum file are loaded without
// any integrity check.
func LoadDir(dir string) (LoadedPack, error) {
	root := strings.TrimSpace(dir)
	if len(root) == 0 {
		return LoadedPack{}, fmt.Errorf("pack dir is required")
	}

	raw, readErr := os.ReadFile(filepath.Join(root, "pack.json"))
	if readErr != nil {
		return LoadedPack{}, fmt.Errorf("read pack.json: %w", readErr)
	}
	manifest := Manifest{}
	if decodeErr := json.Unmarshal(raw, &manifest); decodeErr != nil {
		return LoadedPack{}, fmt.Errorf("decode pack.json: %w", decodeErr)
	}

	if err := validateManifest(manifest); err != nil {
		return LoadedPack{}, err
	}
	if err := validateChecksums(root, manifest.ChecksumFile); err != nil {
		return LoadedPack{}, err
	}

	providers, err := loadProviders(root, manifest.ProvidersDir)
	switch {
	case err != nil:
		return LoadedPack{}, err
	case len(providers) == 0:
		return LoadedPack{}, fmt.Errorf("providers dir %q does not contain provider manifests", manifest.ProvidersDir)
	}
	if err := validateProviders(root, providers); err != nil {
		return LoadedPack{}, err
	}

	fingerprint, err := computeAggregateChecksum(root, manifest.ChecksumFile)
	if err != nil {
		return LoadedPack{}, err
	}
	return LoadedPack{
		Dir:       root,
		Manifest:  manifest,
		Providers: providers,
		Checksum:  fingerprint,
	}, nil
}
// validateManifest checks that pack.json declares every field the loader
// relies on. It returns the first problem found, or nil.
//
// pack_id, version, target_host, providers_dir and checksum_file must be
// non-blank. providers_dir and checksum_file are later joined onto the pack
// root, so they must also be relative paths that stay inside it: no absolute
// paths, no volume names, and no leading ".." once cleaned.
func validateManifest(manifest Manifest) error {
	switch {
	case strings.TrimSpace(manifest.PackID) == "":
		return fmt.Errorf("pack.json: pack_id is required")
	case strings.TrimSpace(manifest.Version) == "":
		return fmt.Errorf("pack.json: version is required")
	case strings.TrimSpace(manifest.TargetHost) == "":
		return fmt.Errorf("pack.json: target_host is required")
	case strings.TrimSpace(manifest.ProvidersDir) == "":
		return fmt.Errorf("pack.json: providers_dir is required")
	case strings.TrimSpace(manifest.ChecksumFile) == "":
		return fmt.Errorf("pack.json: checksum_file is required")
	}
	// Check the raw values, because those are what get passed to filepath.Join
	// with the pack root. filepath.Join resolves "..", so a value such as
	// "../shared" would otherwise read from outside the pack directory.
	for _, ref := range [...]struct{ field, value string }{
		{"providers_dir", manifest.ProvidersDir},
		{"checksum_file", manifest.ChecksumFile},
	} {
		cleaned := filepath.Clean(ref.value)
		if filepath.IsAbs(cleaned) || filepath.VolumeName(cleaned) != "" ||
			cleaned == ".." || strings.HasPrefix(cleaned, ".."+string(filepath.Separator)) {
			return fmt.Errorf("pack.json: %s %q must be a relative path inside the pack root", ref.field, ref.value)
		}
	}
	return nil
}
// loadProviders decodes every non-directory entry ending in ".json" directly
// inside root/providersDir into a ProviderManifest. Subdirectories and
// entries with other names are skipped. The result is sorted by ProviderID
// and is not validated here. Read and decode errors name the offending file.
func loadProviders(root string, providersDir string) ([]ProviderManifest, error) {
	providersPath := filepath.Join(root, providersDir)
	dirEntries, err := os.ReadDir(providersPath)
	if err != nil {
		return nil, fmt.Errorf("read providers dir %q: %w", providersDir, err)
	}

	result := make([]ProviderManifest, 0, len(dirEntries))
	for i := range dirEntries {
		name := dirEntries[i].Name()
		if dirEntries[i].IsDir() {
			continue
		}
		if !strings.HasSuffix(name, ".json") {
			continue
		}
		body, readErr := os.ReadFile(filepath.Join(providersPath, name))
		if readErr != nil {
			return nil, fmt.Errorf("read provider %q: %w", name, readErr)
		}
		var decoded ProviderManifest
		if decodeErr := json.Unmarshal(body, &decoded); decodeErr != nil {
			return nil, fmt.Errorf("decode provider %q: %w", name, decodeErr)
		}
		result = append(result, decoded)
	}

	sort.Slice(result, func(a, b int) bool { return result[a].ProviderID < result[b].ProviderID })
	return result, nil
}
// validateProviders checks every decoded provider manifest. It returns the
// first problem found, or nil when all providers are valid.
//
// Each provider needs:
//   - non-blank provider_id, display_name, platform, account_type,
//     smoke_test_model, and group/channel/plan template names;
//   - an allowed base_url: https, or http as well when the
//     SUB2API_CRM_ALLOW_INSECURE_PROVIDER_BASE_URLS environment variable,
//     trimmed, is "1" or "true" (case-insensitive);
//   - a non-empty default_models that contains smoke_test_model;
//   - a non-empty channel model_mapping that covers smoke_test_model and every
//     non-blank default model;
//   - a positive plan validity_days;
//   - valid host overlays, with file references resolved against root.
//
// Trimmed provider_id values must be unique across the pack.
func validateProviders(root string, providers []ProviderManifest) error {
	insecureSetting := strings.TrimSpace(os.Getenv(envAllowInsecureProviderBaseURL))
	allowInsecure := insecureSetting == "1" || strings.EqualFold(insecureSetting, "true")

	seenIDs := make(map[string]struct{}, len(providers))
	for idx := range providers {
		p := providers[idx]
		id := strings.TrimSpace(p.ProviderID)
		uncovered := firstMissingDefaultModel(p.DefaultModels, p.ChannelTemplate.ModelMapping)

		switch {
		case id == "":
			return fmt.Errorf("provider manifest: provider_id is required")
		case strings.TrimSpace(p.DisplayName) == "":
			return fmt.Errorf("provider %q: display_name is required", id)
		case !hasAllowedProviderBaseURL(strings.TrimSpace(p.BaseURL), allowInsecure):
			return fmt.Errorf("provider %q: base_url must use https", id)
		case strings.TrimSpace(p.Platform) == "":
			return fmt.Errorf("provider %q: platform is required", id)
		case strings.TrimSpace(p.AccountType) == "":
			return fmt.Errorf("provider %q: account_type is required", id)
		case len(p.DefaultModels) == 0:
			return fmt.Errorf("provider %q: default_models must not be empty", id)
		case strings.TrimSpace(p.SmokeTestModel) == "":
			return fmt.Errorf("provider %q: smoke_test_model is required", id)
		case !contains(p.DefaultModels, p.SmokeTestModel):
			return fmt.Errorf("provider %q: smoke_test_model must be present in default_models", id)
		case strings.TrimSpace(p.GroupTemplate.Name) == "":
			return fmt.Errorf("provider %q: group_template.name is required", id)
		case strings.TrimSpace(p.ChannelTemplate.Name) == "":
			return fmt.Errorf("provider %q: channel_template.name is required", id)
		case len(p.ChannelTemplate.ModelMapping) == 0:
			return fmt.Errorf("provider %q: channel_template.model_mapping must not be empty", id)
		case !containsProviderModel(p.ChannelTemplate.ModelMapping, p.SmokeTestModel):
			return fmt.Errorf("provider %q: channel_template.model_mapping must include smoke_test_model %q", id, p.SmokeTestModel)
		case uncovered != "":
			return fmt.Errorf("provider %q: channel_template.model_mapping must cover default_models, missing %q", id, uncovered)
		case strings.TrimSpace(p.PlanTemplate.Name) == "":
			return fmt.Errorf("provider %q: plan_template.name is required", id)
		case p.PlanTemplate.ValidityDays <= 0:
			return fmt.Errorf("provider %q: plan_template.validity_days must be positive", id)
		}

		if _, dup := seenIDs[id]; dup {
			return fmt.Errorf("duplicate provider_id %q", id)
		}
		if err := validateHostOverlays(root, p); err != nil {
			return err
		}
		seenIDs[id] = struct{}{}
	}
	return nil
}
// validateHostOverlays checks the optional host_overlays list of one provider.
//
// Each overlay needs a non-blank overlay_id (unique within the provider),
// display_name, target_host and reason. Its version bounds must also pass
// validateHostOverlayVersionRange, and its patch_path/notes_path references
// must pass validateOverlayFileRef against the pack root. The first problem is
// returned, prefixed with the provider ID and, where known, the overlay ID.
// A provider without overlays is valid.
func validateHostOverlays(root string, provider ProviderManifest) error {
	ids := make(map[string]struct{}, len(provider.HostOverlays))
	for _, overlay := range provider.HostOverlays {
		id := strings.TrimSpace(overlay.OverlayID)
		if id == "" {
			return fmt.Errorf("provider %q: host_overlays.overlay_id is required", provider.ProviderID)
		}
		if strings.TrimSpace(overlay.DisplayName) == "" {
			return fmt.Errorf("provider %q overlay %q: display_name is required", provider.ProviderID, id)
		}
		if strings.TrimSpace(overlay.TargetHost) == "" {
			return fmt.Errorf("provider %q overlay %q: target_host is required", provider.ProviderID, id)
		}
		if strings.TrimSpace(overlay.Reason) == "" {
			return fmt.Errorf("provider %q overlay %q: reason is required", provider.ProviderID, id)
		}
		if _, exists := ids[id]; exists {
			return fmt.Errorf("provider %q: duplicate host overlay %q", provider.ProviderID, id)
		}

		// Deeper checks all share the same error prefix; they run in this order
		// and the closures are invoked within the current iteration.
		checks := [...]func() error{
			func() error { return validateHostOverlayVersionRange(overlay) },
			func() error { return validateOverlayFileRef(root, overlay.PatchPath) },
			func() error { return validateOverlayFileRef(root, overlay.NotesPath) },
		}
		for _, check := range checks {
			if err := check(); err != nil {
				return fmt.Errorf("provider %q overlay %q: %w", provider.ProviderID, id, err)
			}
		}
		ids[id] = struct{}{}
	}
	return nil
}
// validateHostOverlayVersionRange checks the optional host version bounds of
// an overlay. Blank bounds are skipped.
//
//   - A non-blank min_host_version must be accepted by parseVersion.
//   - A non-blank max_host_version must be accepted by parseVersion or, when
//     its normalizeVersion form ends in ".x", by matchesMaxConstraint as a
//     wildcard constraint.
//   - When both bounds are set and max is not a wildcard, compareVersions must
//     not report min above max.
//
// NOTE(review): when max_host_version is a ".x" wildcard, min is never
// compared against it, so an inverted range with a wildcard max is accepted.
// Confirm whether that is intended.
func validateHostOverlayVersionRange(overlay HostOverlay) error {
minVersion := strings.TrimSpace(overlay.MinHostVersion)
if minVersion != "" {
if _, err := parseVersion(minVersion); err != nil {
return fmt.Errorf("parse min_host_version: %w", err)
}
}
maxVersion := strings.TrimSpace(overlay.MaxHostVersion)
if maxVersion != "" {
// Wildcard maxima (".x" suffix) are not plain versions. Probe them with a
// dummy "0.0.0" through matchesMaxConstraint only to surface parse errors;
// the match result itself is discarded.
if strings.HasSuffix(normalizeVersion(maxVersion), ".x") {
if _, err := matchesMaxConstraint("0.0.0", maxVersion); err != nil {
return fmt.Errorf("parse max_host_version: %w", err)
}
} else if _, err := parseVersion(maxVersion); err != nil {
return fmt.Errorf("parse max_host_version: %w", err)
}
}
// Ordering check only applies to two concrete versions.
if minVersion != "" && maxVersion != "" && !strings.HasSuffix(normalizeVersion(maxVersion), ".x") {
cmp, err := compareVersions(minVersion, maxVersion)
if err != nil {
return fmt.Errorf("compare version range: %w", err)
}
if cmp > 0 {
return fmt.Errorf("min_host_version %q is above max_host_version %q", minVersion, maxVersion)
}
}
return nil
}
// validateOverlayFileRef checks an optional file reference from a host overlay
// (patch_path or notes_path).
//
// An empty or whitespace-only reference is allowed and returns nil. Otherwise
// the trimmed reference must:
//   - be a relative path;
//   - stay inside root after cleaning, so no ".." escape;
//   - name an existing entry that is not a directory.
func validateOverlayFileRef(root string, relativePath string) error {
	trimmed := strings.TrimSpace(relativePath)
	if trimmed == "" {
		return nil
	}
	if filepath.IsAbs(trimmed) {
		return fmt.Errorf("file reference %q must be relative to pack root", trimmed)
	}
	// filepath.Join resolves "..", so a reference such as "../../etc/passwd"
	// would otherwise point outside the pack directory.
	cleaned := filepath.Clean(trimmed)
	if filepath.VolumeName(cleaned) != "" || cleaned == ".." ||
		strings.HasPrefix(cleaned, ".."+string(filepath.Separator)) {
		return fmt.Errorf("file reference %q must stay inside pack root", trimmed)
	}
	info, err := os.Stat(filepath.Join(root, cleaned))
	if err != nil {
		return fmt.Errorf("read file reference %q: %w", trimmed, err)
	}
	if info.IsDir() {
		return fmt.Errorf("file reference %q is a directory, not a file", trimmed)
	}
	return nil
}
// hasAllowedProviderBaseURL reports whether baseURL is a URL with a non-empty
// host name whose scheme is https, or http when allowInsecureBaseURL is true.
//
// The scheme is matched case-insensitively. Values that fail to parse, or that
// have no host name (for example "https://" or "https://:443"), are rejected.
// A plain prefix check would let those through.
func hasAllowedProviderBaseURL(baseURL string, allowInsecureBaseURL bool) bool {
	parsed, err := url.Parse(baseURL)
	if err != nil || parsed.Hostname() == "" {
		return false
	}
	switch strings.ToLower(parsed.Scheme) {
	case "https":
		return true
	case "http":
		return allowInsecureBaseURL
	default:
		return false
	}
}
// validateChecksums verifies every entry listed in the pack's checksum file.
//
// The checksum file is resolved relative to root and uses the sha256sum
// layout: one "<hex sha256> <path>" entry per line. The path may carry the
// optional "*" binary-mode marker, and blank lines are ignored. For each entry:
//   - the digest must be 64 hex characters;
//   - the path must be relative and stay inside root once cleaned;
//   - the file's SHA-256 must equal the digest (compared case-insensitively).
//
// The first malformed line, unreadable file or mismatch is returned as an error.
//
// Paths containing whitespace are not supported, because each line is split
// into whitespace-separated fields. Files that are not listed are not checked.
func validateChecksums(root string, checksumFile string) error {
	file, err := os.Open(filepath.Join(root, checksumFile))
	if err != nil {
		return fmt.Errorf("read checksum file %q: %w", checksumFile, err)
	}
	defer file.Close()

	scanner := bufio.NewScanner(file)
	lineNumber := 0
	for scanner.Scan() {
		lineNumber++
		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			continue
		}
		parts := strings.Fields(line)
		if len(parts) != 2 {
			return fmt.Errorf("checksum file %q line %d: invalid format", checksumFile, lineNumber)
		}
		expected := parts[0]
		if len(expected) != hex.EncodedLen(sha256.Size) {
			return fmt.Errorf("checksum file %q line %d: invalid sha256 digest", checksumFile, lineNumber)
		}
		// `sha256sum -b` writes "<digest> *<path>"; the '*' is a mode marker,
		// not part of the file name.
		relativePath := strings.TrimPrefix(parts[1], "*")
		// filepath.Join resolves "..", so reject entries that would hash files
		// outside the pack root.
		cleaned := filepath.Clean(relativePath)
		if relativePath == "" || filepath.IsAbs(cleaned) || filepath.VolumeName(cleaned) != "" ||
			cleaned == ".." || strings.HasPrefix(cleaned, ".."+string(filepath.Separator)) {
			return fmt.Errorf("checksum file %q line %d: path %q must be a relative path inside the pack root", checksumFile, lineNumber, relativePath)
		}
		body, err := os.ReadFile(filepath.Join(root, cleaned))
		if err != nil {
			return fmt.Errorf("checksum file %q line %d: read %q: %w", checksumFile, lineNumber, relativePath, err)
		}
		sum := sha256.Sum256(body)
		if !strings.EqualFold(expected, hex.EncodeToString(sum[:])) {
			return fmt.Errorf("checksum mismatch for %s", relativePath)
		}
	}
	if err := scanner.Err(); err != nil {
		return fmt.Errorf("scan checksum file %q: %w", checksumFile, err)
	}
	return nil
}
// computeAggregateChecksum fingerprints the pack. It hashes the raw bytes of
// the checksum file (joined onto root) with SHA-256 and returns the lowercase
// hex digest. The file lists per-file digests, so editing any listed entry
// changes the fingerprint. On error it returns "" and a wrapped read error.
func computeAggregateChecksum(root string, checksumFile string) (string, error) {
	contents, readErr := os.ReadFile(filepath.Join(root, checksumFile))
	if readErr != nil {
		return "", fmt.Errorf("read checksum file %q: %w", checksumFile, readErr)
	}
	digest := sha256.Sum256(contents)
	encoded := hex.EncodeToString(digest[:])
	return encoded, nil
}
// contains reports whether any element of items equals target, ignoring
// leading and trailing whitespace on both sides of the comparison.
func contains(items []string, target string) bool {
	want := strings.TrimSpace(target)
	for i := range items {
		if strings.TrimSpace(items[i]) == want {
			return true
		}
	}
	return false
}
// containsProviderModel reports whether target (trimmed) appears as either a
// key or a value of modelMapping, also compared after trimming. A blank target
// never matches.
func containsProviderModel(modelMapping map[string]string, target string) bool {
	needle := strings.TrimSpace(target)
	if len(needle) == 0 {
		return false
	}
	for from, to := range modelMapping {
		switch needle {
		case strings.TrimSpace(from), strings.TrimSpace(to):
			return true
		}
	}
	return false
}
// firstMissingDefaultModel returns the first non-blank entry of defaultModels
// (trimmed) that containsProviderModel cannot find in modelMapping. It returns
// "" when every non-blank default model is covered. Blank entries are skipped.
func firstMissingDefaultModel(defaultModels []string, modelMapping map[string]string) string {
	for i := 0; i < len(defaultModels); i++ {
		candidate := strings.TrimSpace(defaultModels[i])
		if candidate != "" && !containsProviderModel(modelMapping, candidate) {
			return candidate
		}
	}
	return ""
}