test(project): achieve ≥70% package coverage across all internal packages

- store/sqlite: 75.4% (repos + db coverage)
- host/sub2api: 80.8% (httptest mock server, pure function tests)
- app: 74.2% (handler error paths, NewActionSet closures)
- pack: 72.4%
- provision: 75.2%
- access: 77.3%
- config: 94.7% (lookup mock tests)

All tests pass: build, vet, race, coverage gates.
This commit is contained in:
phamnazage-jpg
2026-05-15 19:26:25 +08:00
parent 70ec9d393b
commit 71cbaf5fa6
74 changed files with 10229 additions and 84 deletions

3
.env.example Normal file
View File

@@ -0,0 +1,3 @@
SUB2API_CRM_LISTEN_ADDR=:8080
SUB2API_CRM_SQLITE_DSN=file:/data/sub2api-cn-relay-manager.db?_foreign_keys=on&_busy_timeout=5000
SUB2API_CRM_ADMIN_TOKEN=change-me-before-production

56
AGENTS.md Normal file
View File

@@ -0,0 +1,56 @@
# sub2api-cn-relay-manager — Agent Guidelines
## 项目关键信息
- Go 1.22.2, 纯 Go (modernc.org/sqlite, 无 CGO)
- 零侵入宿主:不修改 sub2api 源码,不写宿主数据库
- 所有 schema 变更通过 `internal/store/sqlite/` 下的 repo + integration test 验证
- docs/ 下有 PRD.md、TDD_PLAN.md、EXECUTION_BOARD.md、solution 文档
## 质量门禁(每个模块完成前必须执行)
1. **设计对齐** — 重新读取 PRD.md、TDD_PLAN.md、EXECUTION_BOARD.md、docs/plans/ 下的规划设计文档,逐条确认实现已覆盖设计目标。发现漂移先修正,不维持虚假 COMPLETED。
2. **代码 review** — 加载 `go-reviewer` skill对新写/修改的全部 Go 文件做系统审查。
3. **测试覆盖**`go test -cover ./internal/...` 核心包provision、access、pack覆盖率 >= 70%。未达标则补用例。
4. **静态分析**`go vet ./...` 零警告。`gofmt -l .` 显示无未格式化文件。
5. **集成验证**`go test ./tests/integration/... -count=1` 必须通过。
6. **板同步** — 更新 EXECUTION_BOARD.md反映真实完成状态。
## Go 编码规范
### 包结构
```
internal/
access/ — 访问闭环(订阅分配、探测、自服务检查)
app/ — HTTP 控制面bootstrap, server, API handlers
host/sub2api/ — 宿主适配器
pack/ — pack 装载与校验
provision/ — 导入编排import, preview, reconcile, rollback
store/sqlite/ — SQLite 数据访问层repo 模式)
cmd/
cli/ — CLI 入口
server/ — HTTP server 入口
tests/integration/ — 集成测试套件
```
### 代码风格
- 标准 Go4-space tabs, 花括号同行, 单 class/struct 文件
- 包名小写,与目录名一致
- 错误处理用 `fmt.Errorf("context: %w", err)` 包裹
- 常量分组在文件顶部,`const ( Name = "value" )`
- Repository 模式:`type XRepo struct { db execQuerier }` + `newXRepo(db)`
- Context 作为第一个参数传入所有 DB/SQL 操作
- 接口定义在使用方,不在实现方
- 测试用 fake/mock adapter 而非真实 HTTP
### 测试规范
- 文件名 `*_test.go` 与源码同包
- 方法名 `TestXxxFlow` / `TestXxxWhenY` 格式
- 优先使用 FakeHostAdapter已在 provision 包中定义)而不是 mock 框架
- 集成测试放在 `tests/integration/`,使用真实 SQLite 内存库
- 测试函数必须 `t.Parallel()` 安全(使用独立 SQLite 连接)
## 重要约束
- 不要运行 `go get` / `go mod tidy` — 源码写完后告诉用户手动安装依赖
- 不改动 go.mod 中的依赖版本
- 所有功能必须配套测试,集成测试优先
- 不允许跳过 quality gate 中的任何一步

19
Dockerfile Normal file
View File

@@ -0,0 +1,19 @@
FROM golang:1.22.2 AS builder
WORKDIR /src
COPY go.mod go.sum ./
RUN go mod download
COPY . .
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -trimpath -ldflags='-s -w' -o /out/sub2api-cn-relay-manager ./cmd/server
FROM debian:bookworm-slim
RUN apt-get update \
&& apt-get install -y --no-install-recommends ca-certificates tzdata wget \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /app
COPY --from=builder /out/sub2api-cn-relay-manager /usr/local/bin/sub2api-cn-relay-manager
ENV SUB2API_CRM_LISTEN_ADDR=:8080
ENV SUB2API_CRM_SQLITE_DSN=file:/data/sub2api-cn-relay-manager.db?_foreign_keys=on&_busy_timeout=5000
ENV SUB2API_CRM_ADMIN_TOKEN=
VOLUME ["/data"]
EXPOSE 8080
ENTRYPOINT ["/usr/local/bin/sub2api-cn-relay-manager"]

View File

@@ -63,4 +63,47 @@ sub2api-cn-relay-manager/
完整方案见:
- [docs/2026-05-12-sub2api-cn-relay-manager-solution.md](./docs/2026-05-12-sub2api-cn-relay-manager-solution.md)
- [docs/PRD.md](./docs/PRD.md)
- [docs/TDD_PLAN.md](./docs/TDD_PLAN.md)
- [docs/EXECUTION_BOARD.md](./docs/EXECUTION_BOARD.md)
- [docs/DEPLOYMENT.md](./docs/DEPLOYMENT.md)
## 当前 MVP 能力
当前仓库已经具备一个最小可运行闭环:
- `packs/openai-cn-pack/` 提供真实 `pack.json + provider + checksums`
- `internal/pack` 负责 pack 装载、checksum 校验、provider schema 校验
- `internal/provision` 负责多 key 导入编排、账号探测和访问闭环判定
- `cmd/cli import-provider` 提供一键导入入口
示例:
```bash
go run ./cmd/cli import-provider \
--host-base-url https://sub2api.example.com \
--host-api-key <admin-api-key> \
--pack-dir ./packs/openai-cn-pack \
--provider-id deepseek \
--keys sk-a,sk-b \
--access-mode self_service \
--access-api-key <user-api-key>
```
## 运行方式
服务端:
```bash
SUB2API_CRM_ADMIN_TOKEN=replace-me go run ./cmd/server
```
Docker Compose
```bash
cp .env.example .env
# 编辑 .env 中的 SUB2API_CRM_ADMIN_TOKEN
docker compose up --build -d
curl -fsS http://127.0.0.1:8080/healthz
```

View File

@@ -2,27 +2,483 @@ package main
import (
"context"
"flag"
"fmt"
"io"
"log"
"os"
"strings"
"sub2api-cn-relay-manager/internal/config"
"sub2api-cn-relay-manager/internal/host/sub2api"
"sub2api-cn-relay-manager/internal/pack"
"sub2api-cn-relay-manager/internal/provision"
"sub2api-cn-relay-manager/internal/store/sqlite"
)
type installPackFunc func(context.Context, installPackCLIRequest) (provision.PackInstallResult, error)
type importProviderFunc func(context.Context, importCLIRequest) (provision.ImportReport, error)
type previewProviderFunc func(context.Context, previewCLIRequest) (provision.PreviewReport, error)
type rollbackProviderFunc func(context.Context, rollbackCLIRequest) (rollbackSummary, error)
type reconcileProviderFunc func(context.Context, reconcileCLIRequest) (provision.ReconcileResult, error)
type installPackCLIRequest struct {
HostBaseURL string
HostAPIKey string
HostBearerToken string
PackPath string
}
type importCLIRequest struct {
HostBaseURL string
HostAPIKey string
HostBearerToken string
PackDir string
ProviderID string
Keys []string
Mode string
AccessMode string
AccessAPIKey string
SubscriptionUsers []string
SubscriptionDays int
}
type previewCLIRequest struct {
HostBaseURL string
HostAPIKey string
HostBearerToken string
PackDir string
ProviderID string
Keys []string
Mode string
}
type rollbackCLIRequest struct {
HostBaseURL string
HostAPIKey string
HostBearerToken string
PackDir string
ProviderID string
}
type reconcileCLIRequest struct {
HostBaseURL string
HostAPIKey string
HostBearerToken string
PackDir string
ProviderID string
AccessAPIKey string
}
type rollbackSummary struct {
Accounts int
Plans int
Channels int
Groups int
}
func main() {
if err := execute(context.Background(), log.Writer(), func(context.Context) (config.StartupConfig, error) {
if err := execute(context.Background(), log.Writer(), os.Args[1:], func(context.Context) (config.StartupConfig, error) {
return config.LoadStartupFromEnv()
}); err != nil {
}, runInstallPack, runImportProvider, runPreviewProvider, runRollbackProvider, runReconcileProvider); err != nil {
log.Fatalf("run cli: %v", err)
}
}
func execute(ctx context.Context, output io.Writer, loadConfig func(context.Context) (config.StartupConfig, error)) error {
func execute(
ctx context.Context,
output io.Writer,
args []string,
loadConfig func(context.Context) (config.StartupConfig, error),
installPack installPackFunc,
importProvider importProviderFunc,
previewProvider previewProviderFunc,
rollbackProvider rollbackProviderFunc,
reconcileProvider reconcileProviderFunc,
) error {
if len(args) > 0 && args[0] == "install-pack" {
req, err := parseInstallPackCLIArgs(args[1:])
if err != nil {
return err
}
result, err := installPack(ctx, req)
if err != nil {
return err
}
_, err = fmt.Fprintf(output, "pack_id=%s\nversion=%s\nhost_version=%s\nproviders=%d\nalready_installed=%t\n", result.Pack.PackID, result.Pack.Version, result.HostVersion, len(result.Providers), result.AlreadyInstalled)
return err
}
if len(args) > 0 && args[0] == "import-provider" {
req, err := parseImportCLIArgs(args[1:])
if err != nil {
return err
}
report, err := importProvider(ctx, req)
if err != nil {
_, _ = fmt.Fprintf(output, "batch_status=%s\nprovider_status=%s\naccess_status=%s\n", report.BatchStatus, report.ProviderStatus, report.AccessStatus)
return err
}
_, err = fmt.Fprintf(output, "batch_status=%s\nprovider_status=%s\naccess_status=%s\naccounts=%d\n", report.BatchStatus, report.ProviderStatus, report.AccessStatus, len(report.Accounts))
return err
}
if len(args) > 0 && args[0] == "preview-provider" {
req, err := parsePreviewCLIArgs(args[1:])
if err != nil {
return err
}
report, err := previewProvider(ctx, req)
if err != nil {
return err
}
_, err = fmt.Fprintf(output, "accepted_keys=%d\ngroup=%s\nchannel=%s\nplan=%s\n", len(report.AcceptedKeys), report.Decisions["group"].Action, report.Decisions["channel"].Action, report.Decisions["plan"].Action)
return err
}
if len(args) > 0 && args[0] == "rollback-provider" {
req, err := parseRollbackCLIArgs(args[1:])
if err != nil {
return err
}
summary, err := rollbackProvider(ctx, req)
if err != nil {
return err
}
_, err = fmt.Fprintf(output, "deleted_accounts=%d\ndeleted_plans=%d\ndeleted_channels=%d\ndeleted_groups=%d\n", summary.Accounts, summary.Plans, summary.Channels, summary.Groups)
return err
}
if len(args) > 0 && args[0] == "reconcile-provider" {
req, err := parseReconcileCLIArgs(args[1:])
if err != nil {
return err
}
result, err := reconcileProvider(ctx, req)
if err != nil {
return err
}
_, err = fmt.Fprintf(output, "status=%s\nmissing_count=%d\nextra_count=%d\nprobe_failures=%d\naccess_status=%s\n", result.Status, result.MissingCount, result.ExtraCount, result.ProbeFailureCount, result.AccessStatus)
return err
}
cfg, err := loadConfig(ctx)
if err != nil {
return err
}
_, err = fmt.Fprintf(output, "sub2api-cn-relay-manager cli ready\nlisten_addr=%s\nsqlite_dsn=%s\n", cfg.Server.ListenAddr, cfg.Database.SQLiteDSN)
return err
}
func parseInstallPackCLIArgs(args []string) (installPackCLIRequest, error) {
fs := flag.NewFlagSet("install-pack", flag.ContinueOnError)
fs.SetOutput(io.Discard)
var req installPackCLIRequest
fs.StringVar(&req.HostBaseURL, "host-base-url", "", "")
fs.StringVar(&req.HostAPIKey, "host-api-key", "", "")
fs.StringVar(&req.HostBearerToken, "host-bearer-token", "", "")
fs.StringVar(&req.PackPath, "pack-path", "", "")
if err := fs.Parse(args); err != nil {
return installPackCLIRequest{}, err
}
switch {
case strings.TrimSpace(req.HostBaseURL) == "":
return installPackCLIRequest{}, fmt.Errorf("--host-base-url is required")
case strings.TrimSpace(req.PackPath) == "":
return installPackCLIRequest{}, fmt.Errorf("--pack-path is required")
}
return req, nil
}
func parseImportCLIArgs(args []string) (importCLIRequest, error) {
fs := flag.NewFlagSet("import-provider", flag.ContinueOnError)
fs.SetOutput(io.Discard)
var req importCLIRequest
var keysCSV string
var subscriptionUsersCSV string
fs.StringVar(&req.HostBaseURL, "host-base-url", "", "")
fs.StringVar(&req.HostAPIKey, "host-api-key", "", "")
fs.StringVar(&req.HostBearerToken, "host-bearer-token", "", "")
fs.StringVar(&req.PackDir, "pack-dir", "", "")
fs.StringVar(&req.ProviderID, "provider-id", "", "")
fs.StringVar(&keysCSV, "keys", "", "")
fs.StringVar(&req.Mode, "mode", provision.ImportModePartial, "")
fs.StringVar(&req.AccessMode, "access-mode", provision.AccessModeSelfService, "")
fs.StringVar(&req.AccessAPIKey, "access-api-key", "", "")
fs.StringVar(&subscriptionUsersCSV, "subscription-users", "", "")
fs.IntVar(&req.SubscriptionDays, "subscription-days", 30, "")
if err := fs.Parse(args); err != nil {
return importCLIRequest{}, err
}
req.Keys = splitCSV(keysCSV)
req.SubscriptionUsers = splitCSV(subscriptionUsersCSV)
switch {
case strings.TrimSpace(req.HostBaseURL) == "":
return importCLIRequest{}, fmt.Errorf("--host-base-url is required")
case strings.TrimSpace(req.PackDir) == "":
return importCLIRequest{}, fmt.Errorf("--pack-dir is required")
case strings.TrimSpace(req.ProviderID) == "":
return importCLIRequest{}, fmt.Errorf("--provider-id is required")
case len(req.Keys) == 0:
return importCLIRequest{}, fmt.Errorf("--keys is required")
case strings.TrimSpace(req.AccessAPIKey) == "":
return importCLIRequest{}, fmt.Errorf("--access-api-key is required")
}
return req, nil
}
func parsePreviewCLIArgs(args []string) (previewCLIRequest, error) {
fs := flag.NewFlagSet("preview-provider", flag.ContinueOnError)
fs.SetOutput(io.Discard)
var req previewCLIRequest
var keysCSV string
fs.StringVar(&req.HostBaseURL, "host-base-url", "", "")
fs.StringVar(&req.HostAPIKey, "host-api-key", "", "")
fs.StringVar(&req.HostBearerToken, "host-bearer-token", "", "")
fs.StringVar(&req.PackDir, "pack-dir", "", "")
fs.StringVar(&req.ProviderID, "provider-id", "", "")
fs.StringVar(&keysCSV, "keys", "", "")
fs.StringVar(&req.Mode, "mode", provision.ImportModePartial, "")
if err := fs.Parse(args); err != nil {
return previewCLIRequest{}, err
}
req.Keys = splitCSV(keysCSV)
switch {
case strings.TrimSpace(req.HostBaseURL) == "":
return previewCLIRequest{}, fmt.Errorf("--host-base-url is required")
case strings.TrimSpace(req.PackDir) == "":
return previewCLIRequest{}, fmt.Errorf("--pack-dir is required")
case strings.TrimSpace(req.ProviderID) == "":
return previewCLIRequest{}, fmt.Errorf("--provider-id is required")
case len(req.Keys) == 0:
return previewCLIRequest{}, fmt.Errorf("--keys is required")
}
return req, nil
}
func parseRollbackCLIArgs(args []string) (rollbackCLIRequest, error) {
fs := flag.NewFlagSet("rollback-provider", flag.ContinueOnError)
fs.SetOutput(io.Discard)
var req rollbackCLIRequest
fs.StringVar(&req.HostBaseURL, "host-base-url", "", "")
fs.StringVar(&req.HostAPIKey, "host-api-key", "", "")
fs.StringVar(&req.HostBearerToken, "host-bearer-token", "", "")
fs.StringVar(&req.PackDir, "pack-dir", "", "")
fs.StringVar(&req.ProviderID, "provider-id", "", "")
if err := fs.Parse(args); err != nil {
return rollbackCLIRequest{}, err
}
switch {
case strings.TrimSpace(req.HostBaseURL) == "":
return rollbackCLIRequest{}, fmt.Errorf("--host-base-url is required")
case strings.TrimSpace(req.PackDir) == "":
return rollbackCLIRequest{}, fmt.Errorf("--pack-dir is required")
case strings.TrimSpace(req.ProviderID) == "":
return rollbackCLIRequest{}, fmt.Errorf("--provider-id is required")
}
return req, nil
}
func parseReconcileCLIArgs(args []string) (reconcileCLIRequest, error) {
fs := flag.NewFlagSet("reconcile-provider", flag.ContinueOnError)
fs.SetOutput(io.Discard)
var req reconcileCLIRequest
fs.StringVar(&req.HostBaseURL, "host-base-url", "", "")
fs.StringVar(&req.HostAPIKey, "host-api-key", "", "")
fs.StringVar(&req.HostBearerToken, "host-bearer-token", "", "")
fs.StringVar(&req.PackDir, "pack-dir", "", "")
fs.StringVar(&req.ProviderID, "provider-id", "", "")
fs.StringVar(&req.AccessAPIKey, "access-api-key", "", "")
if err := fs.Parse(args); err != nil {
return reconcileCLIRequest{}, err
}
switch {
case strings.TrimSpace(req.HostBaseURL) == "":
return reconcileCLIRequest{}, fmt.Errorf("--host-base-url is required")
case strings.TrimSpace(req.PackDir) == "":
return reconcileCLIRequest{}, fmt.Errorf("--pack-dir is required")
case strings.TrimSpace(req.ProviderID) == "":
return reconcileCLIRequest{}, fmt.Errorf("--provider-id is required")
}
return req, nil
}
func runInstallPack(ctx context.Context, req installPackCLIRequest) (provision.PackInstallResult, error) {
loadedPack, err := pack.LoadPath(req.PackPath)
if err != nil {
return provision.PackInstallResult{}, err
}
client, err := sub2api.NewClient(req.HostBaseURL, sub2api.WithAPIKey(req.HostAPIKey), sub2api.WithBearerToken(req.HostBearerToken))
if err != nil {
return provision.PackInstallResult{}, err
}
startupConfig, err := config.LoadStartupFromEnv()
if err != nil {
return provision.PackInstallResult{}, err
}
store, err := sqlite.Open(ctx, startupConfig.Database.SQLiteDSN)
if err != nil {
return provision.PackInstallResult{}, err
}
defer store.Close()
service := provision.NewPackInstallService(store, client)
return service.Install(ctx, provision.PackInstallRequest{Pack: loadedPack})
}
func runImportProvider(ctx context.Context, req importCLIRequest) (provision.ImportReport, error) {
loadedPack, err := pack.LoadDir(req.PackDir)
if err != nil {
return provision.ImportReport{}, err
}
providerManifest, err := findProvider(loadedPack, req.ProviderID)
if err != nil {
return provision.ImportReport{}, err
}
client, err := sub2api.NewClient(req.HostBaseURL, sub2api.WithAPIKey(req.HostAPIKey), sub2api.WithBearerToken(req.HostBearerToken))
if err != nil {
return provision.ImportReport{}, err
}
startupConfig, err := config.LoadStartupFromEnv()
if err != nil {
return provision.ImportReport{}, err
}
store, err := sqlite.Open(ctx, startupConfig.Database.SQLiteDSN)
if err != nil {
return provision.ImportReport{}, err
}
defer store.Close()
subscriptions := make([]provision.SubscriptionTarget, 0, len(req.SubscriptionUsers))
for _, userID := range req.SubscriptionUsers {
subscriptions = append(subscriptions, provision.SubscriptionTarget{UserID: userID, DurationDays: req.SubscriptionDays})
}
runtimeService := provision.NewRuntimeImportService(store, client)
result, err := runtimeService.Import(ctx, provision.RuntimeImportRequest{
HostBaseURL: req.HostBaseURL,
Pack: loadedPack,
Provider: providerManifest,
Mode: req.Mode,
Keys: req.Keys,
Access: provision.AccessRequest{
Mode: req.AccessMode,
ProbeAPIKey: req.AccessAPIKey,
Subscriptions: subscriptions,
},
})
return result.Report, err
}
func runPreviewProvider(ctx context.Context, req previewCLIRequest) (provision.PreviewReport, error) {
loadedPack, err := pack.LoadDir(req.PackDir)
if err != nil {
return provision.PreviewReport{}, err
}
providerManifest, err := findProvider(loadedPack, req.ProviderID)
if err != nil {
return provision.PreviewReport{}, err
}
client, err := sub2api.NewClient(req.HostBaseURL, sub2api.WithAPIKey(req.HostAPIKey), sub2api.WithBearerToken(req.HostBearerToken))
if err != nil {
return provision.PreviewReport{}, err
}
service := provision.NewPreviewService(client)
return service.PreviewImport(ctx, provision.PreviewRequest{
Provider: providerManifest,
Mode: req.Mode,
Keys: req.Keys,
})
}
func runRollbackProvider(ctx context.Context, req rollbackCLIRequest) (rollbackSummary, error) {
loadedPack, err := pack.LoadDir(req.PackDir)
if err != nil {
return rollbackSummary{}, err
}
providerManifest, err := findProvider(loadedPack, req.ProviderID)
if err != nil {
return rollbackSummary{}, err
}
client, err := sub2api.NewClient(req.HostBaseURL, sub2api.WithAPIKey(req.HostAPIKey), sub2api.WithBearerToken(req.HostBearerToken))
if err != nil {
return rollbackSummary{}, err
}
service := provision.NewRollbackService(client)
report, err := service.Rollback(ctx, provision.RollbackRequest{Provider: providerManifest})
if err != nil {
return rollbackSummary{}, err
}
return rollbackSummary{
Accounts: report.AccountsDeleted,
Plans: report.PlansDeleted,
Channels: report.ChannelsDeleted,
Groups: report.GroupsDeleted,
}, nil
}
func runReconcileProvider(ctx context.Context, req reconcileCLIRequest) (provision.ReconcileResult, error) {
loadedPack, err := pack.LoadDir(req.PackDir)
if err != nil {
return provision.ReconcileResult{}, err
}
providerManifest, err := findProvider(loadedPack, req.ProviderID)
if err != nil {
return provision.ReconcileResult{}, err
}
client, err := sub2api.NewClient(req.HostBaseURL, sub2api.WithAPIKey(req.HostAPIKey), sub2api.WithBearerToken(req.HostBearerToken))
if err != nil {
return provision.ReconcileResult{}, err
}
startupConfig, err := config.LoadStartupFromEnv()
if err != nil {
return provision.ReconcileResult{}, err
}
store, err := sqlite.Open(ctx, startupConfig.Database.SQLiteDSN)
if err != nil {
return provision.ReconcileResult{}, err
}
defer store.Close()
service := provision.NewReconcileService(store, client)
return service.Reconcile(ctx, provision.ReconcileRequest{HostBaseURL: req.HostBaseURL, AccessProbeAPIKey: req.AccessAPIKey, Pack: loadedPack, Provider: providerManifest})
}
func findProvider(loaded pack.LoadedPack, providerID string) (pack.ProviderManifest, error) {
for _, provider := range loaded.Providers {
if provider.ProviderID == strings.TrimSpace(providerID) {
return provider, nil
}
}
return pack.ProviderManifest{}, fmt.Errorf("provider %q not found in pack %q", providerID, loaded.Manifest.PackID)
}
func splitCSV(value string) []string {
parts := strings.Split(value, ",")
result := make([]string, 0, len(parts))
for _, part := range parts {
trimmed := strings.TrimSpace(part)
if trimmed != "" {
result = append(result, trimmed)
}
}
return result
}

View File

@@ -8,6 +8,8 @@ import (
"testing"
"sub2api-cn-relay-manager/internal/config"
"sub2api-cn-relay-manager/internal/provision"
"sub2api-cn-relay-manager/internal/store/sqlite"
)
type errWriter struct {
@@ -22,7 +24,7 @@ func TestExecuteWritesConfigSummaryAfterBootstrap(t *testing.T) {
var output bytes.Buffer
loadCalled := false
err := execute(context.Background(), &output, func(context.Context) (config.StartupConfig, error) {
err := execute(context.Background(), &output, nil, func(context.Context) (config.StartupConfig, error) {
loadCalled = true
return config.StartupConfig{
Server: config.ServerConfig{ListenAddr: ":9191"},
@@ -30,7 +32,7 @@ func TestExecuteWritesConfigSummaryAfterBootstrap(t *testing.T) {
SQLiteDSN: "file:test.db?_foreign_keys=on",
},
}, nil
})
}, nil, nil, nil, nil, nil)
if err != nil {
t.Fatalf("execute() returned error: %v", err)
}
@@ -56,9 +58,9 @@ func TestExecuteWritesConfigSummaryAfterBootstrap(t *testing.T) {
func TestExecuteReturnsBootstrapError(t *testing.T) {
wantErr := errors.New("load config failed")
err := execute(context.Background(), &bytes.Buffer{}, func(context.Context) (config.StartupConfig, error) {
err := execute(context.Background(), &bytes.Buffer{}, nil, func(context.Context) (config.StartupConfig, error) {
return config.StartupConfig{}, wantErr
})
}, nil, nil, nil, nil, nil)
if !errors.Is(err, wantErr) {
t.Fatalf("execute() error = %v, want %v", err, wantErr)
}
@@ -67,13 +69,208 @@ func TestExecuteReturnsBootstrapError(t *testing.T) {
func TestExecuteReturnsWriteError(t *testing.T) {
wantErr := errors.New("write failed")
err := execute(context.Background(), errWriter{err: wantErr}, func(context.Context) (config.StartupConfig, error) {
err := execute(context.Background(), errWriter{err: wantErr}, nil, func(context.Context) (config.StartupConfig, error) {
return config.StartupConfig{
Server: config.ServerConfig{ListenAddr: ":9292"},
Database: config.DatabaseConfig{SQLiteDSN: "file:test.db"},
}, nil
})
}, nil, nil, nil, nil, nil)
if !errors.Is(err, wantErr) {
t.Fatalf("execute() error = %v, want %v", err, wantErr)
}
}
func TestExecuteInstallPackWritesSummary(t *testing.T) {
var output bytes.Buffer
installCalled := false
err := execute(context.Background(), &output, []string{
"install-pack",
"--host-base-url", "https://sub2api.example.com",
"--pack-path", "/tmp/openai-pack.zip",
}, nil, func(_ context.Context, req installPackCLIRequest) (provision.PackInstallResult, error) {
installCalled = true
if req.PackPath != "/tmp/openai-pack.zip" {
t.Fatalf("unexpected install request: %+v", req)
}
return provision.PackInstallResult{
Pack: sqlite.Pack{PackID: "openai-cn-pack", Version: "1.0.0"},
HostVersion: "0.1.126",
Providers: []sqlite.Provider{{ProviderID: "deepseek"}},
}, nil
}, nil, nil, nil, nil)
if err != nil {
t.Fatalf("execute() install-pack error = %v", err)
}
if !installCalled {
t.Fatal("execute() did not invoke installPack")
}
got := output.String()
if !strings.Contains(got, "pack_id=openai-cn-pack") || !strings.Contains(got, "providers=1") || !strings.Contains(got, "host_version=0.1.126") {
t.Fatalf("execute() install-pack output = %q, want summary", got)
}
}
func TestExecuteImportProviderWritesSummary(t *testing.T) {
var output bytes.Buffer
importCalled := false
err := execute(context.Background(), &output, []string{
"import-provider",
"--host-base-url", "https://sub2api.example.com",
"--pack-dir", "/tmp/pack",
"--provider-id", "deepseek",
"--keys", "k1,k2",
"--access-api-key", "user-key",
}, nil, nil, func(_ context.Context, req importCLIRequest) (provision.ImportReport, error) {
importCalled = true
if req.ProviderID != "deepseek" || len(req.Keys) != 2 {
t.Fatalf("unexpected import request: %+v", req)
}
return provision.ImportReport{
BatchStatus: provision.BatchStatusSucceeded,
ProviderStatus: provision.ProviderStatusActive,
AccessStatus: provision.AccessStatusSelfServiceReady,
Accounts: []provision.AccountImportResult{{}, {}},
}, nil
}, nil, nil, nil)
if err != nil {
t.Fatalf("execute() import error = %v", err)
}
if !importCalled {
t.Fatal("execute() did not invoke importProvider")
}
got := output.String()
if !strings.Contains(got, "batch_status=succeeded") || !strings.Contains(got, "accounts=2") {
t.Fatalf("execute() import output = %q, want summary", got)
}
}
func TestExecutePreviewProviderWritesSummary(t *testing.T) {
var output bytes.Buffer
previewCalled := false
err := execute(context.Background(), &output, []string{
"preview-provider",
"--host-base-url", "https://sub2api.example.com",
"--pack-dir", "/tmp/pack",
"--provider-id", "deepseek",
"--keys", "k1,k2",
}, nil, nil, nil, func(_ context.Context, req previewCLIRequest) (provision.PreviewReport, error) {
previewCalled = true
if req.ProviderID != "deepseek" || len(req.Keys) != 2 {
t.Fatalf("unexpected preview request: %+v", req)
}
return provision.PreviewReport{
AcceptedKeys: []string{"k1", "k2"},
Names: provision.ResourceNames{Group: "crm-deepseek-group", Channel: "crm-deepseek-channel", Plan: "crm-deepseek-plan"},
Decisions: map[string]provision.PreviewDecision{
"group": {Action: provision.PreviewActionCreate},
"channel": {Action: provision.PreviewActionReuse, ExistingID: "channel_1"},
"plan": {Action: provision.PreviewActionConflict},
},
}, nil
}, nil, nil)
if err != nil {
t.Fatalf("execute() preview error = %v", err)
}
if !previewCalled {
t.Fatal("execute() did not invoke previewProvider")
}
got := output.String()
if !strings.Contains(got, "accepted_keys=2") || !strings.Contains(got, "group=create") || !strings.Contains(got, "channel=reuse") || !strings.Contains(got, "plan=conflict") {
t.Fatalf("execute() preview output = %q, want summary", got)
}
}
func TestExecuteRollbackProviderWritesSummary(t *testing.T) {
var output bytes.Buffer
rollbackCalled := false
err := execute(context.Background(), &output, []string{
"rollback-provider",
"--host-base-url", "https://sub2api.example.com",
"--pack-dir", "/tmp/pack",
"--provider-id", "deepseek",
}, nil, nil, nil, nil, func(_ context.Context, req rollbackCLIRequest) (rollbackSummary, error) {
rollbackCalled = true
if req.ProviderID != "deepseek" {
t.Fatalf("unexpected rollback request: %+v", req)
}
return rollbackSummary{Accounts: 2, Plans: 1, Channels: 1, Groups: 1}, nil
}, nil)
if err != nil {
t.Fatalf("execute() rollback error = %v", err)
}
if !rollbackCalled {
t.Fatal("execute() did not invoke rollbackProvider")
}
got := output.String()
if !strings.Contains(got, "deleted_accounts=2") || !strings.Contains(got, "deleted_groups=1") {
t.Fatalf("execute() rollback output = %q, want summary", got)
}
}
func TestExecuteReconcileProviderWritesSummary(t *testing.T) {
var output bytes.Buffer
reconcileCalled := false
err := execute(context.Background(), &output, []string{
"reconcile-provider",
"--host-base-url", "https://sub2api.example.com",
"--pack-dir", "/tmp/pack",
"--provider-id", "deepseek",
"--access-api-key", "user-key",
}, nil, nil, nil, nil, nil, func(_ context.Context, req reconcileCLIRequest) (provision.ReconcileResult, error) {
reconcileCalled = true
if req.ProviderID != "deepseek" || req.AccessAPIKey != "user-key" {
t.Fatalf("unexpected reconcile request: %+v", req)
}
return provision.ReconcileResult{Status: "drifted", MissingCount: 1, ExtraCount: 2, ProbeFailureCount: 1, AccessStatus: provision.AccessStatusBroken}, nil
})
if err != nil {
t.Fatalf("execute() reconcile error = %v", err)
}
if !reconcileCalled {
t.Fatal("execute() did not invoke reconcileProvider")
}
got := output.String()
if !strings.Contains(got, "status=drifted") || !strings.Contains(got, "missing_count=1") || !strings.Contains(got, "access_status=broken") {
t.Fatalf("execute() reconcile output = %q, want summary", got)
}
}
func TestParseInstallPackCLIArgsRequiresHostBaseURL(t *testing.T) {
_, err := parseInstallPackCLIArgs([]string{"--pack-path", "/tmp/openai-pack.zip"})
if err == nil {
t.Fatal("parseInstallPackCLIArgs() error = nil, want validation error")
}
}
func TestParseImportCLIArgsRequiresHostBaseURL(t *testing.T) {
_, err := parseImportCLIArgs([]string{"--pack-dir", "/tmp/pack", "--provider-id", "deepseek", "--keys", "k1", "--access-api-key", "user-key"})
if err == nil {
t.Fatal("parseImportCLIArgs() error = nil, want validation error")
}
}
func TestParsePreviewCLIArgsRequiresHostBaseURL(t *testing.T) {
_, err := parsePreviewCLIArgs([]string{"--pack-dir", "/tmp/pack", "--provider-id", "deepseek", "--keys", "k1"})
if err == nil {
t.Fatal("parsePreviewCLIArgs() error = nil, want validation error")
}
}
func TestParseRollbackCLIArgsRequiresHostBaseURL(t *testing.T) {
_, err := parseRollbackCLIArgs([]string{"--pack-dir", "/tmp/pack", "--provider-id", "deepseek"})
if err == nil {
t.Fatal("parseRollbackCLIArgs() error = nil, want validation error")
}
}
func TestParseReconcileCLIArgsRequiresHostBaseURL(t *testing.T) {
_, err := parseReconcileCLIArgs([]string{"--pack-dir", "/tmp/pack", "--provider-id", "deepseek"})
if err == nil {
t.Fatal("parseReconcileCLIArgs() error = nil, want validation error")
}
}

View File

@@ -9,7 +9,7 @@ import (
)
func TestRunCallsApplicationServerRunnerAfterBootstrap(t *testing.T) {
serverApp := app.NewServer("127.0.0.1:0", nil)
serverApp := app.NewServer("127.0.0.1:0", nil, nil)
bootstrapCalled := false
runnerCalled := false
@@ -60,7 +60,7 @@ func TestRunReturnsBootstrapError(t *testing.T) {
func TestRunReturnsApplicationRunError(t *testing.T) {
wantErr := errors.New("server run failed")
serverApp := app.NewServer("127.0.0.1:0", nil)
serverApp := app.NewServer("127.0.0.1:0", nil, nil)
err := run(
context.Background(),

19
docker-compose.yml Normal file
View File

@@ -0,0 +1,19 @@
services:
sub2api-cn-relay-manager:
build:
context: .
dockerfile: Dockerfile
image: sub2api-cn-relay-manager:local
restart: unless-stopped
env_file:
- .env
ports:
- "8080:8080"
volumes:
- ./data:/data
healthcheck:
test: ["CMD", "sh", "-c", "wget -qO- http://127.0.0.1:8080/healthz >/dev/null"]
interval: 15s
timeout: 3s
retries: 5
start_period: 10s

36
docs/DEPLOYMENT.md Normal file
View File

@@ -0,0 +1,36 @@
# Deployment
## Environment
Required:
- `SUB2API_CRM_ADMIN_TOKEN`: control-plane bearer token
Optional:
- `SUB2API_CRM_LISTEN_ADDR` (default `:8080`)
- `SUB2API_CRM_SQLITE_DSN` (default `file:sub2api-cn-relay-manager.db?_foreign_keys=on&_busy_timeout=5000`)
## Local Docker Compose
```bash
cp .env.example .env
# edit SUB2API_CRM_ADMIN_TOKEN before startup
mkdir -p data
docker compose up --build -d
curl -fsS http://127.0.0.1:8080/healthz
```
## Standalone Binary
```bash
go build -o bin/sub2api-cn-relay-manager ./cmd/server
SUB2API_CRM_ADMIN_TOKEN=replace-me ./bin/sub2api-cn-relay-manager
```
## Runtime Notes
- SQLite file should be mounted on persistent storage.
- Admin token must be rotated outside source control.
- The service is stateless except for SQLite runtime state.
- Use `/healthz` for container liveness checks.

111
docs/EXECUTION_BOARD.md Normal file
View File

@@ -0,0 +1,111 @@
# sub2api-cn-relay-manager 执行板
日期2026-05-13
当前 GateREQUEST_CHANGES
目标:实现 implementation plan 全量能力,达成独立控制面、零侵入宿主、一键导入国产模型,并补齐回滚/对账/HTTP API/交付物。
## 当前真实状态
模块完成 gate新增执行要求后续每个大模块都必须执行
-`go test` 通过不算完成;每次完成大模块后,必须补做:
1. 两阶段 review先对规划/设计文档做实现对齐检查,再做代码质量 review
2. execution board 当前状态同步
3. 若发现实现/设计漂移,优先修正文档结论或回退模块状态,不维持虚假 `COMPLETED`
- 本板从本次起按上述 gate 维护。
已完成:
- 项目骨架与配置加载
- SQLite 最小状态库hosts/packs/providers
- SQLite 运行态状态库扩展import_batches / items / managed_resources / probe_results / access_closure_records / reconcile_runs
- sub2api HostAdapter 基础创建/探测能力
- HostAdapter 删除能力group/channel/accountplan 接口已补)
- HostAdapter 资源枚举能力groups/channels/plans/accounts
- import strict 模式自动回滚已接入
- 手动 rollback CLI`rollback-provider`)已接入,支持按 provider 名称规则回收 group/channel/plan/accounts
- pack 目录装载与 checksum/schema 校验
- 正式 pack install 生命周期已接入:支持 zip/目录装载、宿主版本兼容校验、pack/provider 元数据持久化、CLI `install-pack`
- CLI `import-provider` 导入闭环已接入 SQLite 运行态持久化host/pack/provider/import/probe/access
- CLI `preview-provider` 预检查入口
- 最小 HTTP 控制面已接入admin token 鉴权 + `/api/packs/install` + `/api/providers/{providerID}/preview-import` + `/api/providers/{providerID}/import` + `/api/import-batches/{batchID}` + `/api/providers/{providerID}/status` + `/api/providers/{providerID}/resources` + `/api/providers/{providerID}/access/status` + `/api/providers/{providerID}/rollback` + `/api/providers/{providerID}/reconcile`
- preview 已接入宿主资源快照查询
- 账号探测与 `/v1/models` 网关访问验证
未完成的关键事实:
- 状态库已接入 `import-provider` 运行链并可持久化 host/pack/provider/import/probe/access最小 HTTP 控制面已补齐 batch detail / provider status / resources / access status / rollback / reconcileOpenAPI 草案已同步扩展
- preview/import/rollback/reconcile 已有 CLI 与最小 HTTP 入口,但仍缺少 hosts 管理面与更完整的批次/对账操作文档输出
- 宿主资源枚举已实现,但尚未对真实 sub2api 版本做兼容性实测
- 最小 reconcile / drift detection 已接入,当前实现仍是 `internal/provision/batch_detail_and_reconcile_service.go` 内联版本,但已补齐对最新 batch 的 account smoke probe 重跑、access closure 复检与 reconcile summary 持久化;状态仍未完全对齐 implementation plan 目标中的 `internal/reconcile/*` 结构,且真实宿主兼容性实测未完成
- OpenAPI 草案已覆盖 status/resources/access-status但仍未收口 hosts 契约与生产级文档细节
- 无 scheduler/jobs
- 已补齐 Dockerfile / compose / .env.example / deployment 文档,并新增 distribution smoke test但尚无真实容器启动 E2E 执行记录
## P0必须先完成
### P0-1 状态库扩展并接入运行链
- 状态COMPLETEDschema/repo、`import-provider` 运行链消费、`batch detail` / `provider status` / `resources` / `access status` / `reconcile` 查询面均已接入)
- 目标:补齐 implementation plan 所需核心表与 repo
- 范围:`import_batches``import_batch_items``managed_resources``probe_results``access_closure_records``reconcile_runs`
- 验证:`go test ./tests/integration -run 'TestStore(Runtime|Init)' -count=1`
- 完成判据表存在、约束有效、事务回滚有效、repo 可写入读取,并被运行链消费
### P0-2 import preview + naming
- 目标:导入前可输出 create/reuse/conflict不盲写宿主
- 范围:`preview_service.go``naming.go``import_preview_test.go`
- 验证:`go test ./tests/integration -run TestImportPreview -v`
### P0-3 真实 rollback 闭环
- 状态PARTIALstrict 自动回滚 + 手动 rollback CLI + HTTP rollback API 已完成;真实宿主兼容性实测未完成)
- 目标strict 失败自动清理,支持手动 rollback
- 前置HostAdapter 增加 DeleteGroup/DeleteChannel/DeletePlan/DeleteAccount/ListManagedResources
- 验证:`go test ./internal/provision ./tests/integration ./cmd/cli -run 'TestRollback|TestExecuteRollbackProviderWritesSummary|TestSub2APIHostAdapterListManagedResources' -v`
### P0-4 正式 pack install 生命周期
- 状态COMPLETEDzip/目录装载、宿主版本兼容性校验、pack/provider 元数据持久化、CLI `install-pack` 已接入)
- 目标:支持 zip/目录装载、宿主版本兼容性校验、pack/provider 元数据持久化
- 验证:`go test ./internal/pack ./internal/provision ./cmd/cli ./tests/integration -v`
## P1形成真正控制面
### P1-1 Access 独立模块化
- 状态PARTIAL访问闭环校验/订阅分配/网关探测已从 `import_service` 抽离到 `internal/access/closure.go`,但 implementation plan 目标结构中的 `planner.go` / `subscription_service.go` / `self_service_checker.go` 仍未落地)
- 目标:将访问闭环从 import_service 解耦为 `internal/access/*`
- 设计对齐复核:当前已完成的是“最小闭环抽离”,未达到 implementation plan 中 Access 子模块拆分粒度;因此不再维持 `COMPLETED`
- 验证:`go test ./internal/access ./internal/provision -count=1`
### P1-2 Reconcile / Drift Detection
- 状态PARTIAL最小 reconcile API + drift 计数写入已接入;本轮新增 account smoke probe 重跑、access closure 复检、`active/degraded/drifted` 状态语义与回写验证,但 implementation plan 目标中的 `internal/reconcile/*` 结构、`failed` 语义收口与真实宿主兼容性实测仍未完成)
- 目标:拉宿主快照,对比状态库,重跑 probe标记 drifted
- 验证:`go test ./internal/provision ./internal/app ./tests/integration -run 'TestReconcileService|TestAPIReconcileProviderReturnsSummary|TestStore(Runtime|Init)' -count=1`
### P1-3 HTTP API + OpenAPI
- 状态PARTIAL`/api/packs/install``/api/providers/{providerID}/preview-import``/api/providers/{providerID}/import``/api/import-batches/{batchID}``/api/providers/{providerID}/status``/api/providers/{providerID}/resources``/api/providers/{providerID}/access/status``/api/providers/{providerID}/rollback``/api/providers/{providerID}/reconcile` 已接入OpenAPI 草案已同步扩展,但 hosts 管理面仍缺失)
- 目标:暴露 hosts / packs/install / providers preview-import / imports rollback / access / reconcile
- 验证:`go test ./internal/app ./cmd/server ./tests/integration -run 'TestAPI|TestBootstrap' -v`
## P2工程化交付
### P2-1 Scheduler / Jobs
- 目标:支持定时 reconcile 与手动触发
- 验证:`go test ./tests/integration -run TestCLIScheduler -v`
### P2-2 Distribution Artifacts
- 状态PARTIAL已补齐 `Dockerfile` / `.env.example` / `docker-compose.yml` / `docs/DEPLOYMENT.md`,并新增 distribution smoke test但尚无真实容器启动与镜像构建 E2E 记录)
- 目标Dockerfile / .env.example / docker-compose / deployment 文档 / e2e 脚本
- 验证:`go test ./tests/integration -run TestDistributionArtifactsExistAndReferenceRequiredEnv -v`
### P2-3 CLI 面板补齐
- 目标:`host add` / `pack install` / `provider import` / `reconcile run`
- 验证CLI 集成测试 + `go test ./...`
## 当前执行顺序
1. P1-1 Access 模块继续拆分到 implementation plan 粒度
2. P1-2 Reconcile 结构化与真实宿主兼容性实测
3. P1-3 Hosts 管理面 / OpenAPI 收口
4. P2-1 Scheduler / Jobs
5. P2-2 Distribution 容器级 E2E 验证
6. P2-3 CLI 全量收口
## 禁止错误结论
- `go test ./...` 当前通过 ≠ implementation plan 全部实现
- CLI 最小导入闭环 ≠ 独立控制面已完成
- 资源创建成功 ≠ 用户访问闭环已长期可运维

45
docs/PRD.md Normal file
View File

@@ -0,0 +1,45 @@
# sub2api-cn-relay-manager PRDMVP
日期2026-05-13
## 目标
在**完全不修改 sub2api 官方系统代码**的前提下,交付一个可独立打包运行的外部伴生项目,使管理员能够通过一次导入动作,把国产模型 OpenAI 兼容中转能力安装到任意一套兼容的 sub2api 实例中。
## 硬约束
1. 不修改宿主源码
2. 不 fork 宿主并运行私有二进制
3. 不直接写宿主数据库
4. 不向宿主目录注入插件代码或补丁文件
5. 仅通过宿主现有 HTTP 管理 API 与标准 API 工作
## 首版验收
1. `model_pack` 可独立校验与装载
2. CLI 可直接读取 pack、选择 provider、导入多条 key
3. 导入流程能创建 group / channel / plansubscription 模式)/ accounts
4. 至少一个 account 完成 `/test``/models` 验证
5. 至少一种普通用户访问路径被真实探测:`GET /v1/models`
6. 失败时明确区分:`succeeded / partially_succeeded / failed`
## 首版边界
### 做
- pack runtime
- provider schema 校验
- 多 key 去重与批量导入
- subscription/self-service 两种访问模式建模
- CLI 一键导入
- 基于 stub 的端到端测试
### 暂不做
- Web 控制台
- 多宿主管理
- 自动代用户签发最终 API key
- 对账调度器完整实现
- 真实宿主删除/回滚链路
## 当前实现策略
首版先把“可独立打包 + 零侵入导入 + 用户访问验证”做成最小闭环状态库、HTTP 控制面、对账调度在此基础上继续扩展。

41
docs/TDD_PLAN.md Normal file
View File

@@ -0,0 +1,41 @@
# TDD 实施计划MVP
日期2026-05-13
## 设计结论
首版采用:
- `packs/openai-cn-pack/`:真实可校验模型包
- `internal/pack`pack 装载、checksum 校验、provider schema 校验
- `internal/provision`:导入编排服务
- `internal/host/sub2api`:宿主 admin/gateway 适配
- `cmd/cli import-provider`:一键导入入口
## TDD 顺序
1. 先写 `internal/pack/loader_test.go`
- 成功装载 pack
- checksum mismatch 失败
- provider schema 非法失败
2. 再写 `internal/provision/import_service_test.go`
- subscription 模式成功导入
- strict 模式探测失败直接失败
- 参数非法拒绝
3. 再补宿主适配器集成测试
- `CheckGatewayAccess()` 能校验 `/v1/models`
4. 最后补 CLI 测试
- `import-provider` 参数解析
- 输出状态摘要
## 当前 MVP 风险
1. 回滚删除链路尚未接入真实宿主 HTTP 路径,当前仅在服务层保留失败状态,不宣称真实宿主回滚已闭环
2. 现有集成验证基于 `httptest` stub尚未对真实 sub2api 版本做兼容性实测
3. 状态库尚未承接 import batch / managed resources / reconcile runs 持久化
## 完成标准
- `go test ./...` 通过
- CLI 能从真实 pack 读取 provider
- 导入报告明确输出 batch/provider/access 三种状态

265
docs/openapi.yaml Normal file
View File

@@ -0,0 +1,265 @@
openapi: 3.1.0
info:
title: sub2api-cn-relay-manager API
version: 0.1.0
servers:
- url: /
paths:
/healthz:
get:
responses:
'200':
description: ok
/api/packs/install:
post:
security:
- bearerAuth: []
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/InstallPackRequest'
responses:
'200':
description: pack installed
/api/import-batches/{batchID}:
get:
security:
- bearerAuth: []
parameters:
- name: batchID
in: path
required: true
schema:
type: integer
format: int64
responses:
'200':
description: batch detail
/api/providers/{providerID}/status:
get:
security:
- bearerAuth: []
parameters:
- name: providerID
in: path
required: true
schema:
type: string
- name: pack_id
in: query
required: false
schema:
type: string
responses:
'200':
description: provider runtime status
/api/providers/{providerID}/resources:
get:
security:
- bearerAuth: []
parameters:
- name: providerID
in: path
required: true
schema:
type: string
- name: pack_id
in: query
required: false
schema:
type: string
responses:
'200':
description: provider managed resources snapshot
/api/providers/{providerID}/access/status:
get:
security:
- bearerAuth: []
parameters:
- name: providerID
in: path
required: true
schema:
type: string
- name: pack_id
in: query
required: false
schema:
type: string
responses:
'200':
description: provider access closure status
/api/providers/{providerID}/preview-import:
post:
security:
- bearerAuth: []
parameters:
- name: providerID
in: path
required: true
schema:
type: string
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/PreviewProviderRequest'
responses:
'200':
description: preview summary
/api/providers/{providerID}/import:
post:
security:
- bearerAuth: []
parameters:
- name: providerID
in: path
required: true
schema:
type: string
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/ImportProviderRequest'
responses:
'200':
description: import summary
/api/providers/{providerID}/rollback:
post:
security:
- bearerAuth: []
parameters:
- name: providerID
in: path
required: true
schema:
type: string
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/RollbackProviderRequest'
responses:
'200':
description: rollback summary
/api/providers/{providerID}/reconcile:
post:
security:
- bearerAuth: []
parameters:
- name: providerID
in: path
required: true
schema:
type: string
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/ReconcileProviderRequest'
responses:
'200':
description: reconcile summary
components:
securitySchemes:
bearerAuth:
type: http
scheme: bearer
schemas:
InstallPackRequest:
type: object
required: [host_base_url, pack_path]
properties:
host_base_url:
type: string
host_api_key:
type: string
host_bearer_token:
type: string
pack_path:
type: string
PreviewProviderRequest:
type: object
required: [host_base_url, pack_path, keys]
properties:
host_base_url:
type: string
host_api_key:
type: string
host_bearer_token:
type: string
pack_path:
type: string
provider_id:
type: string
keys:
type: array
items:
type: string
mode:
type: string
ImportProviderRequest:
type: object
required: [host_base_url, pack_path, keys, access_api_key]
properties:
host_base_url:
type: string
host_api_key:
type: string
host_bearer_token:
type: string
pack_path:
type: string
provider_id:
type: string
keys:
type: array
items:
type: string
mode:
type: string
access_mode:
type: string
access_api_key:
type: string
subscription_users:
type: array
items:
type: string
subscription_days:
type: integer
RollbackProviderRequest:
type: object
required: [host_base_url, pack_path]
properties:
host_base_url:
type: string
host_api_key:
type: string
host_bearer_token:
type: string
pack_path:
type: string
provider_id:
type: string
ReconcileProviderRequest:
type: object
required: [host_base_url, pack_path]
properties:
host_base_url:
type: string
host_api_key:
type: string
host_bearer_token:
type: string
pack_path:
type: string
provider_id:
type: string

View File

@@ -0,0 +1,80 @@
package access
import (
"context"
"fmt"
"strings"
"sub2api-cn-relay-manager/internal/host/sub2api"
)
const (
ModeSubscription = "subscription"
ModeSelfService = "self_service"
)
type SubscriptionTarget struct {
UserID string
DurationDays int
}
type ClosureRequest struct {
Mode string
ProbeAPIKey string
Subscriptions []SubscriptionTarget
GroupID string
ExpectedModel string
}
type Host interface {
AssignSubscription(ctx context.Context, req sub2api.AssignSubscriptionRequest) (sub2api.SubscriptionRef, error)
CheckGatewayAccess(ctx context.Context, req sub2api.GatewayAccessCheckRequest) (sub2api.GatewayAccessResult, error)
}
type Service struct {
host Host
}
func NewService(host Host) *Service {
return &Service{host: host}
}
func Validate(req ClosureRequest) error {
switch strings.TrimSpace(req.Mode) {
case ModeSubscription:
if len(req.Subscriptions) == 0 {
return fmt.Errorf("subscription access requires at least one subscription target")
}
case ModeSelfService:
if strings.TrimSpace(req.ProbeAPIKey) == "" {
return fmt.Errorf("self_service access requires probe api key")
}
default:
return fmt.Errorf("unsupported access mode %q", req.Mode)
}
if strings.TrimSpace(req.ProbeAPIKey) == "" {
return fmt.Errorf("access probe api key is required to verify gateway closure")
}
return nil
}
func (s *Service) Close(ctx context.Context, req ClosureRequest) (sub2api.GatewayAccessResult, error) {
if s == nil || s.host == nil {
return sub2api.GatewayAccessResult{}, fmt.Errorf("access host is required")
}
if err := Validate(req); err != nil {
return sub2api.GatewayAccessResult{}, err
}
if strings.TrimSpace(req.Mode) == ModeSubscription {
for _, target := range req.Subscriptions {
if _, err := s.host.AssignSubscription(ctx, sub2api.AssignSubscriptionRequest{UserID: target.UserID, GroupID: req.GroupID, DurationDays: target.DurationDays}); err != nil {
return sub2api.GatewayAccessResult{}, fmt.Errorf("assign subscription for %s: %w", target.UserID, err)
}
}
}
result, err := s.host.CheckGatewayAccess(ctx, sub2api.GatewayAccessCheckRequest{APIKey: req.ProbeAPIKey, ExpectedModel: req.ExpectedModel})
if err != nil {
return sub2api.GatewayAccessResult{}, fmt.Errorf("check gateway access: %w", err)
}
return result, nil
}

View File

@@ -0,0 +1,91 @@
package access
import (
"context"
"errors"
"testing"
"sub2api-cn-relay-manager/internal/host/sub2api"
)
func TestValidateRejectsMissingProbeAPIKeyForSelfService(t *testing.T) {
err := Validate(ClosureRequest{Mode: "self_service"})
if err == nil {
t.Fatal("Validate() error = nil, want validation error")
}
}
func TestValidateRejectsMissingSubscriptionsForSubscriptionMode(t *testing.T) {
err := Validate(ClosureRequest{Mode: "subscription", ProbeAPIKey: "user-key"})
if err == nil {
t.Fatal("Validate() error = nil, want validation error")
}
}
func TestServiceCloseAssignsSubscriptionsAndProbesGateway(t *testing.T) {
host := &fakeClosureHost{
gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}},
}
service := NewService(host)
result, err := service.Close(context.Background(), ClosureRequest{
Mode: "subscription",
ProbeAPIKey: "user-key",
GroupID: "group-1",
ExpectedModel: "deepseek-chat",
Subscriptions: []SubscriptionTarget{{UserID: "user-1", DurationDays: 30}},
})
if err != nil {
t.Fatalf("Close() error = %v", err)
}
if len(host.assigned) != 1 {
t.Fatalf("assigned subscriptions = %d, want 1", len(host.assigned))
}
if host.gatewayProbe.APIKey != "user-key" || host.gatewayProbe.ExpectedModel != "deepseek-chat" {
t.Fatalf("gateway probe = %+v, want api key + expected model", host.gatewayProbe)
}
if !result.OK || !result.HasExpectedModel {
t.Fatalf("gateway result = %+v, want success", result)
}
}
func TestServiceCloseReturnsSubscriptionErrorBeforeGatewayProbe(t *testing.T) {
host := &fakeClosureHost{assignErr: errors.New("assign failed")}
service := NewService(host)
_, err := service.Close(context.Background(), ClosureRequest{
Mode: "subscription",
ProbeAPIKey: "user-key",
GroupID: "group-1",
ExpectedModel: "deepseek-chat",
Subscriptions: []SubscriptionTarget{{UserID: "user-1", DurationDays: 30}},
})
if err == nil {
t.Fatal("Close() error = nil, want subscription failure")
}
if host.gatewayProbe.APIKey != "" {
t.Fatalf("gateway probe should not run after subscription error, got %+v", host.gatewayProbe)
}
}
type fakeClosureHost struct {
assigned []sub2api.AssignSubscriptionRequest
assignErr error
gatewayProbe sub2api.GatewayAccessCheckRequest
gatewayResult sub2api.GatewayAccessResult
gatewayErr error
}
func (f *fakeClosureHost) AssignSubscription(_ context.Context, req sub2api.AssignSubscriptionRequest) (sub2api.SubscriptionRef, error) {
if f.assignErr != nil {
return sub2api.SubscriptionRef{}, f.assignErr
}
f.assigned = append(f.assigned, req)
return sub2api.SubscriptionRef{ID: "sub-1"}, nil
}
func (f *fakeClosureHost) CheckGatewayAccess(_ context.Context, req sub2api.GatewayAccessCheckRequest) (sub2api.GatewayAccessResult, error) {
f.gatewayProbe = req
if f.gatewayErr != nil {
return sub2api.GatewayAccessResult{}, f.gatewayErr
}
return f.gatewayResult, nil
}

View File

@@ -15,25 +15,20 @@ type Server struct {
listen ListenerFactory
}
func NewServer(listenAddr string, listenerFactory ListenerFactory) *Server {
mux := http.NewServeMux()
mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
})
func NewServer(listenAddr string, handler http.Handler, listenerFactory ListenerFactory) *Server {
if handler == nil {
handler = NewAPIHandler("", ActionSet{})
}
server := &Server{
server: &http.Server{
Addr: listenAddr,
Handler: mux,
Handler: handler,
},
listen: net.Listen,
}
if listenerFactory != nil {
server.listen = listenerFactory
}
return server
}
@@ -46,13 +41,11 @@ func (s *Server) Run(ctx context.Context) error {
if err != nil {
return err
}
return s.Serve(ctx, listener)
}
func (s *Server) Serve(ctx context.Context, listener net.Listener) error {
errCh := make(chan error, 1)
go func() {
err := s.server.Serve(listener)
if errors.Is(err, http.ErrServerClosed) {
@@ -65,11 +58,9 @@ func (s *Server) Serve(ctx context.Context, listener net.Listener) error {
case <-ctx.Done():
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.server.Shutdown(shutdownCtx); err != nil {
return err
}
return <-errCh
case err := <-errCh:
return err

View File

@@ -1,17 +1,26 @@
package app
import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"sub2api-cn-relay-manager/internal/host/sub2api"
"sub2api-cn-relay-manager/internal/pack"
"sub2api-cn-relay-manager/internal/provision"
"sub2api-cn-relay-manager/internal/store/sqlite"
)
func TestServeExposesHealthz(t *testing.T) {
server := NewServer("127.0.0.1:0", nil)
server := NewServer("127.0.0.1:0", NewAPIHandler("admin-token", ActionSet{}), nil)
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("net.Listen() error = %v", err)
@@ -50,7 +59,7 @@ func TestRunReturnsAfterContextCancellation(t *testing.T) {
t.Fatalf("net.Listen() error = %v", err)
}
server := NewServer("127.0.0.1:0", func(string, string) (net.Listener, error) {
server := NewServer("127.0.0.1:0", NewAPIHandler("admin-token", ActionSet{}), func(string, string) (net.Listener, error) {
return listener, nil
})
@@ -77,7 +86,7 @@ func TestRunReturnsAfterContextCancellation(t *testing.T) {
func TestRunReturnsListenError(t *testing.T) {
wantErr := errors.New("listen failed")
server := NewServer("127.0.0.1:0", func(string, string) (net.Listener, error) {
server := NewServer("127.0.0.1:0", NewAPIHandler("admin-token", ActionSet{}), func(string, string) (net.Listener, error) {
return nil, wantErr
})
@@ -88,7 +97,7 @@ func TestRunReturnsListenError(t *testing.T) {
}
func TestServeReturnsListenerError(t *testing.T) {
server := NewServer("127.0.0.1:0", nil)
server := NewServer("127.0.0.1:0", NewAPIHandler("admin-token", ActionSet{}), nil)
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("net.Listen() error = %v", err)
@@ -104,6 +113,208 @@ func TestServeReturnsListenerError(t *testing.T) {
}
}
func TestAPIRejectsMissingAdminToken(t *testing.T) {
handler := NewAPIHandler("secret-token", ActionSet{})
request := httptestRequest(t, http.MethodPost, "/api/packs/install", map[string]any{"host_base_url": "https://sub2api.example.com", "pack_path": "/tmp/pack.zip"}, "")
response := httptestRecorder(handler, request)
assertStatusCode(t, response, http.StatusUnauthorized)
assertJSONContains(t, response.Body().Bytes(), "error.code", "unauthorized")
}
func TestAPIInstallPackReturnsSummary(t *testing.T) {
handler := NewAPIHandler("secret-token", ActionSet{
InstallPack: func(context.Context, InstallPackRequest) (provision.PackInstallResult, error) {
return provision.PackInstallResult{
Pack: sqlite.Pack{PackID: "openai-cn-pack", Version: "1.0.0"},
HostVersion: "0.1.126",
Providers: []sqlite.Provider{{ProviderID: "deepseek", DisplayName: "DeepSeek"}},
}, nil
},
})
request := httptestRequest(t, http.MethodPost, "/api/packs/install", map[string]any{"host_base_url": "https://sub2api.example.com", "pack_path": "/tmp/openai-pack.zip"}, "secret-token")
response := httptestRecorder(handler, request)
assertStatusCode(t, response, http.StatusOK)
assertJSONContains(t, response.Body().Bytes(), "pack_id", "openai-cn-pack")
assertJSONContains(t, response.Body().Bytes(), "host_version", "0.1.126")
}
func TestAPIPreviewProviderReturnsSummary(t *testing.T) {
handler := NewAPIHandler("secret-token", ActionSet{
PreviewProvider: func(_ context.Context, req PreviewProviderRequest) (provision.PreviewReport, error) {
if req.ProviderID != "deepseek" {
t.Fatalf("ProviderID = %q, want deepseek", req.ProviderID)
}
return provision.PreviewReport{
AcceptedKeys: []string{"k1", "k2"},
Names: provision.ResourceNames{Group: "g", Channel: "c", Plan: "p"},
Decisions: map[string]provision.PreviewDecision{
"group": {Action: provision.PreviewActionCreate},
},
}, nil
},
})
request := httptestRequest(t, http.MethodPost, "/api/providers/deepseek/preview-import", map[string]any{"host_base_url": "https://sub2api.example.com", "pack_path": "/tmp/openai-pack.zip", "keys": []string{"k1", "k2"}, "mode": "partial"}, "secret-token")
response := httptestRecorder(handler, request)
assertStatusCode(t, response, http.StatusOK)
assertJSONContains(t, response.Body().Bytes(), "accepted_keys_count", float64(2))
}
func TestAPIImportProviderReturnsConflictWithBatchStatus(t *testing.T) {
handler := NewAPIHandler("secret-token", ActionSet{
ImportProvider: func(context.Context, ImportProviderRequest) (provision.RuntimeImportResult, error) {
return provision.RuntimeImportResult{
BatchID: 12,
Report: provision.ImportReport{
BatchStatus: provision.BatchStatusFailed,
ProviderStatus: provision.ProviderStatusFailed,
AccessStatus: provision.AccessStatusBroken,
Accounts: []provision.AccountImportResult{{}},
},
}, errors.New("strict import failed")
},
})
request := httptestRequest(t, http.MethodPost, "/api/providers/deepseek/import", map[string]any{"host_base_url": "https://sub2api.example.com", "pack_path": "/tmp/openai-pack.zip", "keys": []string{"k1"}, "mode": "strict", "access_mode": "self_service", "access_api_key": "user-key"}, "secret-token")
response := httptestRecorder(handler, request)
assertStatusCode(t, response, http.StatusConflict)
assertJSONContains(t, response.Body().Bytes(), "batch_id", float64(12))
assertJSONContains(t, response.Body().Bytes(), "batch_status", provision.BatchStatusFailed)
}
func TestAPIBatchDetailReturnsSummary(t *testing.T) {
handler := NewAPIHandler("secret-token", ActionSet{
BatchDetail: func(context.Context, BatchDetailRequest) (provision.BatchDetailResult, error) {
return provision.BatchDetailResult{
Batch: sqlite.ImportBatch{ID: 7, BatchStatus: "running", AccessStatus: "pending"},
Items: []sqlite.ImportBatchItem{{ID: 1, KeyFingerprint: "sha256:abc", AccountStatus: "passed"}},
}, nil
},
})
request := httptestRequest(t, http.MethodGet, "/api/import-batches/7", nil, "secret-token")
response := httptestRecorder(handler, request)
assertStatusCode(t, response, http.StatusOK)
assertJSONContains(t, response.Body().Bytes(), "batch.batch_status", "running")
assertJSONContains(t, response.Body().Bytes(), "items_count", float64(1))
}
func TestAPIProviderStatusReturnsSummary(t *testing.T) {
handler := NewAPIHandler("secret-token", ActionSet{
GetProviderStatus: func(_ context.Context, req ProviderQueryRequest) (provision.ProviderSnapshot, error) {
if req.ProviderID != "deepseek" {
t.Fatalf("ProviderID = %q, want deepseek", req.ProviderID)
}
if req.PackID != "openai-cn-pack" {
t.Fatalf("PackID = %q, want openai-cn-pack", req.PackID)
}
return provision.ProviderSnapshot{
Host: sqlite.Host{HostID: "host-1", BaseURL: "https://sub2api.example.com", HostVersion: "0.1.126"},
Pack: sqlite.Pack{PackID: "openai-cn-pack", Version: "1.0.0"},
Provider: sqlite.Provider{ProviderID: "deepseek", DisplayName: "DeepSeek", Platform: "openai"},
Batch: sqlite.ImportBatch{ID: 7, BatchStatus: provision.BatchStatusSucceeded, AccessStatus: provision.AccessStatusSelfServiceReady, Mode: provision.ImportModeStrict},
ProviderStatus: "drifted",
LatestAccessStatus: provision.AccessStatusSelfServiceReady,
LatestReconcileStatus: "drifted",
LatestReconcileSummary: map[string]any{"missing_count": 1},
ManagedResources: []sqlite.ManagedResource{{}, {}},
AccessClosures: []sqlite.AccessClosureRecord{{}},
ReconcileRuns: []sqlite.ReconcileRun{{}},
}, nil
},
})
request := httptestRequest(t, http.MethodGet, "/api/providers/deepseek/status?pack_id=openai-cn-pack", nil, "secret-token")
response := httptestRecorder(handler, request)
assertStatusCode(t, response, http.StatusOK)
assertJSONContains(t, response.Body().Bytes(), "provider_status", "drifted")
assertJSONContains(t, response.Body().Bytes(), "managed_resources_count", float64(2))
assertJSONContains(t, response.Body().Bytes(), "latest_reconcile_summary.missing_count", float64(1))
}
func TestAPIProviderAccessStatusReturnsSummary(t *testing.T) {
handler := NewAPIHandler("secret-token", ActionSet{
GetProviderAccessStatus: func(_ context.Context, req ProviderQueryRequest) (provision.ProviderSnapshot, error) {
if req.ProviderID != "deepseek" {
t.Fatalf("ProviderID = %q, want deepseek", req.ProviderID)
}
return provision.ProviderSnapshot{
Pack: sqlite.Pack{PackID: "openai-cn-pack"},
Provider: sqlite.Provider{ProviderID: "deepseek"},
Batch: sqlite.ImportBatch{ID: 7, AccessStatus: provision.AccessStatusSelfServiceReady},
LatestAccessStatus: provision.AccessStatusSelfServiceReady,
AccessClosures: []sqlite.AccessClosureRecord{{ID: 2, ClosureType: provision.AccessModeSelfService, Status: provision.AccessStatusSelfServiceReady, DetailsJSON: `{"ok":true}`}},
}, nil
},
})
request := httptestRequest(t, http.MethodGet, "/api/providers/deepseek/access/status", nil, "secret-token")
response := httptestRecorder(handler, request)
assertStatusCode(t, response, http.StatusOK)
assertJSONContains(t, response.Body().Bytes(), "latest_access_status", provision.AccessStatusSelfServiceReady)
assertJSONContains(t, response.Body().Bytes(), "closures_count", float64(1))
if !strings.Contains(response.Body().String(), `"closure_type":"self_service"`) {
t.Fatalf("access status payload missing closure type: %s", response.Body().String())
}
}
func TestAPIProviderResourcesReturnsSummary(t *testing.T) {
handler := NewAPIHandler("secret-token", ActionSet{
GetProviderResources: func(_ context.Context, req ProviderQueryRequest) (provision.ProviderSnapshot, error) {
if req.ProviderID != "deepseek" {
t.Fatalf("ProviderID = %q, want deepseek", req.ProviderID)
}
return provision.ProviderSnapshot{
Pack: sqlite.Pack{PackID: "openai-cn-pack"},
Provider: sqlite.Provider{ProviderID: "deepseek"},
Batch: sqlite.ImportBatch{ID: 7},
ManagedResources: []sqlite.ManagedResource{{ID: 1, ResourceType: "group", HostResourceID: "group-1", ResourceName: "deepseek-group"}},
AccessClosures: []sqlite.AccessClosureRecord{{ID: 2, ClosureType: provision.AccessModeSelfService, Status: provision.AccessStatusSelfServiceReady, DetailsJSON: `{"ok":true}`}},
ReconcileRuns: []sqlite.ReconcileRun{{ID: 3, Status: "active", SummaryJSON: `{"missing_count":0}`}},
}, nil
},
})
request := httptestRequest(t, http.MethodGet, "/api/providers/deepseek/resources", nil, "secret-token")
response := httptestRecorder(handler, request)
assertStatusCode(t, response, http.StatusOK)
assertJSONContains(t, response.Body().Bytes(), "provider_id", "deepseek")
assertJSONContains(t, response.Body().Bytes(), "pack_id", "openai-cn-pack")
if !strings.Contains(response.Body().String(), `"resource_type":"group"`) {
t.Fatalf("resources payload missing group resource: %s", response.Body().String())
}
if !strings.Contains(response.Body().String(), `"status":"self_service_ready"`) {
t.Fatalf("resources payload missing access closure status: %s", response.Body().String())
}
if !strings.Contains(response.Body().String(), `"summary_json":"{\"missing_count\":0}"`) {
t.Fatalf("resources payload missing reconcile summary: %s", response.Body().String())
}
}
func TestAPIRollbackProviderReturnsSummary(t *testing.T) {
handler := NewAPIHandler("secret-token", ActionSet{
RollbackProvider: func(context.Context, RollbackProviderRequest) (provision.RollbackReport, error) {
return provision.RollbackReport{AccountsDeleted: 2, PlansDeleted: 1, ChannelsDeleted: 1, GroupsDeleted: 1}, nil
},
})
request := httptestRequest(t, http.MethodPost, "/api/providers/deepseek/rollback", map[string]any{"host_base_url": "https://sub2api.example.com", "pack_path": "/tmp/openai-pack.zip"}, "secret-token")
response := httptestRecorder(handler, request)
assertStatusCode(t, response, http.StatusOK)
assertJSONContains(t, response.Body().Bytes(), "deleted_accounts", float64(2))
assertJSONContains(t, response.Body().Bytes(), "provider_id", "deepseek")
}
func TestAPIReconcileProviderReturnsSummary(t *testing.T) {
handler := NewAPIHandler("secret-token", ActionSet{
ReconcileProvider: func(_ context.Context, req ReconcileProviderRequest) (provision.ReconcileResult, error) {
if req.AccessAPIKey != "user-key" {
t.Fatalf("AccessAPIKey = %q, want user-key", req.AccessAPIKey)
}
return provision.ReconcileResult{BatchID: 7, Status: "drifted", MissingCount: 1, ExtraCount: 2, ProbeFailureCount: 1, AccessStatus: provision.AccessStatusBroken, Summary: map[string]any{"probe_failures": 1}}, nil
},
})
request := httptestRequest(t, http.MethodPost, "/api/providers/deepseek/reconcile", map[string]any{"host_base_url": "https://sub2api.example.com", "pack_path": "/tmp/openai-pack.zip", "access_api_key": "user-key"}, "secret-token")
response := httptestRecorder(handler, request)
assertStatusCode(t, response, http.StatusOK)
assertJSONContains(t, response.Body().Bytes(), "status", "drifted")
assertJSONContains(t, response.Body().Bytes(), "missing_count", float64(1))
assertJSONContains(t, response.Body().Bytes(), "summary.probe_failures", float64(1))
}
func waitForHealthz(t *testing.T, url string) *http.Response {
t.Helper()
@@ -126,3 +337,613 @@ func waitForHealthz(t *testing.T, url string) *http.Response {
t.Fatalf("health endpoint %q was not reachable before deadline", url)
return nil
}
func httptestRequest(t *testing.T, method, path string, body any, token string) *http.Request {
t.Helper()
payload, err := json.Marshal(body)
if err != nil {
t.Fatalf("json.Marshal() error = %v", err)
}
request, err := http.NewRequest(method, path, bytes.NewReader(payload))
if err != nil {
t.Fatalf("http.NewRequest() error = %v", err)
}
request.Header.Set("Content-Type", "application/json")
if token != "" {
request.Header.Set("Authorization", "Bearer "+token)
}
return request
}
func httptestRecorder(handler http.Handler, request *http.Request) *responseRecorder {
recorder := &responseRecorder{header: make(http.Header)}
handler.ServeHTTP(recorder, request)
return recorder
}
type responseRecorder struct {
header http.Header
body bytes.Buffer
code int
}
func (r *responseRecorder) Header() http.Header { return r.header }
func (r *responseRecorder) Write(body []byte) (int, error) { return r.body.Write(body) }
func (r *responseRecorder) WriteHeader(statusCode int) { r.code = statusCode }
func (r *responseRecorder) Body() *bytes.Buffer { return &r.body }
func assertStatusCode(t *testing.T, recorder *responseRecorder, want int) {
t.Helper()
if recorder.code != want {
t.Fatalf("status code = %d, want %d; body=%s", recorder.code, want, recorder.body.String())
}
}
func TestServerAddrReturnsConfiguredAddress(t *testing.T) {
server := NewServer("127.0.0.1:9999", nil, nil)
if got := server.Addr(); got != "127.0.0.1:9999" {
t.Fatalf("Addr() = %q, want %q", got, "127.0.0.1:9999")
}
}
func TestClassifyError(t *testing.T) {
tests := []struct {
name string
err error
wantStatusCode int
wantCode string
wantUpstream int
}{
{name: "nil", err: nil},
{name: "http error passthrough", err: &httpError{StatusCode: http.StatusTeapot, Code: "teapot", Message: "brew"}, wantStatusCode: http.StatusTeapot, wantCode: "teapot"},
{name: "upstream error", err: &sub2api.HTTPError{Method: http.MethodGet, Path: "/x", StatusCode: http.StatusForbidden, Body: "nope"}, wantStatusCode: http.StatusBadGateway, wantCode: "host_request_failed", wantUpstream: http.StatusForbidden},
{name: "pack conflict already installed", err: errors.New("pack already installed"), wantStatusCode: http.StatusConflict, wantCode: "pack_conflict"},
{name: "pack conflict checksum drift", err: errors.New("checksum drift detected"), wantStatusCode: http.StatusConflict, wantCode: "pack_conflict"},
{name: "provider not found", err: errors.New("provider \"deepseek\" not found in pack \"openai\""), wantStatusCode: http.StatusBadRequest, wantCode: "provider_not_found"},
{name: "bad request pack path", err: errors.New("pack path is required"), wantStatusCode: http.StatusBadRequest, wantCode: "bad_request"},
{name: "bad request decode", err: errors.New("decode pack.json failed"), wantStatusCode: http.StatusBadRequest, wantCode: "bad_request"},
{name: "internal error", err: errors.New("boom"), wantStatusCode: http.StatusInternalServerError, wantCode: "internal_error"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := classifyError(tt.err)
if tt.err == nil {
if got != nil {
t.Fatalf("classifyError(nil) = %#v, want nil", got)
}
return
}
if got == nil {
t.Fatal("classifyError() = nil, want error")
}
if got.StatusCode != tt.wantStatusCode {
t.Fatalf("StatusCode = %d, want %d", got.StatusCode, tt.wantStatusCode)
}
if got.Code != tt.wantCode {
t.Fatalf("Code = %q, want %q", got.Code, tt.wantCode)
}
if got.UpstreamStatus != tt.wantUpstream {
t.Fatalf("UpstreamStatus = %d, want %d", got.UpstreamStatus, tt.wantUpstream)
}
})
}
}
func TestWriteHTTPError(t *testing.T) {
t.Run("default error when nil", func(t *testing.T) {
recorder := &responseRecorder{header: make(http.Header)}
writeHTTPError(recorder, nil)
assertStatusCode(t, recorder, http.StatusInternalServerError)
if got := recorder.Header().Get("Content-Type"); got != "application/json" {
t.Fatalf("Content-Type = %q, want application/json", got)
}
assertJSONContains(t, recorder.Body().Bytes(), "error.code", "internal_error")
assertJSONContains(t, recorder.Body().Bytes(), "error.message", "internal server error")
})
t.Run("writes provided error", func(t *testing.T) {
recorder := &responseRecorder{header: make(http.Header)}
writeHTTPError(recorder, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "invalid input", UpstreamStatus: http.StatusConflict})
assertStatusCode(t, recorder, http.StatusBadRequest)
assertJSONContains(t, recorder.Body().Bytes(), "error.code", "bad_request")
assertJSONContains(t, recorder.Body().Bytes(), "error.upstream_status", float64(http.StatusConflict))
})
}
func TestDecodeJSON(t *testing.T) {
t.Run("success", func(t *testing.T) {
request := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{"host_base_url":"https://example.com","pack_path":"/tmp/pack.zip"}`))
var got InstallPackRequest
if err := decodeJSON(request, &got); err != nil {
t.Fatalf("decodeJSON() error = %v, want nil", err)
}
if got.HostBaseURL != "https://example.com" || got.PackPath != "/tmp/pack.zip" {
t.Fatalf("decoded request = %#v, want expected fields", got)
}
})
t.Run("rejects unknown fields", func(t *testing.T) {
request := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{"host_base_url":"https://example.com","unknown":true}`))
var got InstallPackRequest
err := decodeJSON(request, &got)
if err == nil {
t.Fatal("decodeJSON() error = nil, want error")
}
if err.StatusCode != http.StatusBadRequest || err.Code != "bad_request" {
t.Fatalf("decodeJSON() = %#v, want bad_request", err)
}
if !strings.Contains(err.Message, "unknown field") {
t.Fatalf("Message = %q, want unknown field", err.Message)
}
})
t.Run("rejects trailing non-object payload", func(t *testing.T) {
request := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{"host_base_url":"https://example.com"}[]`))
var got InstallPackRequest
err := decodeJSON(request, &got)
if err == nil {
t.Fatal("decodeJSON() error = nil, want error")
}
if err.Message != "request body must contain a single JSON object" {
t.Fatalf("Message = %q, want single object error", err.Message)
}
})
}
func TestWriteJSON(t *testing.T) {
recorder := &responseRecorder{header: make(http.Header)}
writeJSON(recorder, http.StatusCreated, map[string]any{"ok": true, "count": 2})
assertStatusCode(t, recorder, http.StatusCreated)
if got := recorder.Header().Get("Content-Type"); got != "application/json" {
t.Fatalf("Content-Type = %q, want application/json", got)
}
assertJSONContains(t, recorder.Body().Bytes(), "ok", true)
assertJSONContains(t, recorder.Body().Bytes(), "count", float64(2))
}
func TestFindProvider(t *testing.T) {
loaded := pack.LoadedPack{
Manifest: pack.Manifest{PackID: "openai-cn-pack"},
Providers: []pack.ProviderManifest{
{ProviderID: "deepseek", DisplayName: "DeepSeek"},
{ProviderID: "openai", DisplayName: "OpenAI"},
},
}
provider, err := findProvider(loaded, " deepseek ")
if err != nil {
t.Fatalf("findProvider() error = %v, want nil", err)
}
if provider.ProviderID != "deepseek" {
t.Fatalf("ProviderID = %q, want deepseek", provider.ProviderID)
}
_, err = findProvider(loaded, "missing")
if err == nil {
t.Fatal("findProvider() error = nil, want error")
}
if !strings.Contains(err.Error(), `provider "missing" not found in pack "openai-cn-pack"`) {
t.Fatalf("findProvider() error = %v, want provider not found message", err)
}
}
func TestAPIRequiresConfiguredAdminToken(t *testing.T) {
handler := NewAPIHandler("", ActionSet{})
request := httptestRequest(t, http.MethodPost, "/api/packs/install", map[string]any{"host_base_url": "https://sub2api.example.com"}, "any-token")
response := httptestRecorder(handler, request)
assertStatusCode(t, response, http.StatusInternalServerError)
assertJSONContains(t, response.Body().Bytes(), "error.code", "server_misconfigured")
}
func TestAPIBatchDetailRejectsInvalidBatchID(t *testing.T) {
handler := NewAPIHandler("secret-token", ActionSet{BatchDetail: func(context.Context, BatchDetailRequest) (provision.BatchDetailResult, error) {
t.Fatal("BatchDetail should not be called for invalid batch id")
return provision.BatchDetailResult{}, nil
}})
request := httptestRequest(t, http.MethodGet, "/api/import-batches/not-a-number", nil, "secret-token")
response := httptestRecorder(handler, request)
assertStatusCode(t, response, http.StatusBadRequest)
assertJSONContains(t, response.Body().Bytes(), "error.code", "bad_request")
assertJSONContains(t, response.Body().Bytes(), "error.message", "batch_id must be a positive integer")
}
func TestAPIInstallPackRejectsInvalidJSON(t *testing.T) {
handler := NewAPIHandler("secret-token", ActionSet{InstallPack: func(context.Context, InstallPackRequest) (provision.PackInstallResult, error) {
t.Fatal("InstallPack should not be called for invalid JSON")
return provision.PackInstallResult{}, nil
}})
request, err := http.NewRequest(http.MethodPost, "/api/packs/install", strings.NewReader(`{"host_base_url":"https://sub2api.example.com","unknown":true}`))
if err != nil {
t.Fatalf("http.NewRequest() error = %v", err)
}
request.Header.Set("Authorization", "Bearer secret-token")
request.Header.Set("Content-Type", "application/json")
response := httptestRecorder(handler, request)
assertStatusCode(t, response, http.StatusBadRequest)
assertJSONContains(t, response.Body().Bytes(), "error.code", "bad_request")
}
func TestAPIImportProviderReturnsClassifiedErrorWithoutBatch(t *testing.T) {
handler := NewAPIHandler("secret-token", ActionSet{
ImportProvider: func(context.Context, ImportProviderRequest) (provision.RuntimeImportResult, error) {
return provision.RuntimeImportResult{}, errors.New("pack path is required")
},
})
request := httptestRequest(t, http.MethodPost, "/api/providers/deepseek/import", map[string]any{"host_base_url": "https://sub2api.example.com", "pack_path": "/tmp/openai-pack.zip"}, "secret-token")
response := httptestRecorder(handler, request)
assertStatusCode(t, response, http.StatusBadRequest)
assertJSONContains(t, response.Body().Bytes(), "error.code", "bad_request")
assertJSONContains(t, response.Body().Bytes(), "batch_id", float64(0))
}
func TestAPIPreviewProviderReturnsUpstreamError(t *testing.T) {
handler := NewAPIHandler("secret-token", ActionSet{
PreviewProvider: func(context.Context, PreviewProviderRequest) (provision.PreviewReport, error) {
return provision.PreviewReport{}, &sub2api.HTTPError{Method: http.MethodPost, Path: "/preview", StatusCode: http.StatusTooManyRequests, Body: "rate limited"}
},
})
request := httptestRequest(t, http.MethodPost, "/api/providers/deepseek/preview-import", map[string]any{"host_base_url": "https://sub2api.example.com", "pack_path": "/tmp/openai-pack.zip"}, "secret-token")
response := httptestRecorder(handler, request)
assertStatusCode(t, response, http.StatusBadGateway)
assertJSONContains(t, response.Body().Bytes(), "error.code", "host_request_failed")
assertJSONContains(t, response.Body().Bytes(), "error.upstream_status", float64(http.StatusTooManyRequests))
}
func TestAPIRollbackProviderReturnsConfiguredError(t *testing.T) {
handler := NewAPIHandler("secret-token", ActionSet{
RollbackProvider: func(context.Context, RollbackProviderRequest) (provision.RollbackReport, error) {
return provision.RollbackReport{}, &httpError{StatusCode: http.StatusGone, Code: "rolled_back", Message: "already removed"}
},
})
request := httptestRequest(t, http.MethodPost, "/api/providers/deepseek/rollback", map[string]any{"host_base_url": "https://sub2api.example.com", "pack_path": "/tmp/openai-pack.zip"}, "secret-token")
response := httptestRecorder(handler, request)
assertStatusCode(t, response, http.StatusGone)
assertJSONContains(t, response.Body().Bytes(), "error.code", "rolled_back")
}
func TestAPIReconcileProviderRejectsTrailingNonObjectPayload(t *testing.T) {
handler := NewAPIHandler("secret-token", ActionSet{ReconcileProvider: func(context.Context, ReconcileProviderRequest) (provision.ReconcileResult, error) {
t.Fatal("ReconcileProvider should not be called for invalid JSON")
return provision.ReconcileResult{}, nil
}})
request, err := http.NewRequest(http.MethodPost, "/api/providers/deepseek/reconcile", strings.NewReader(`{"host_base_url":"https://sub2api.example.com"}[]`))
if err != nil {
t.Fatalf("http.NewRequest() error = %v", err)
}
request.Header.Set("Authorization", "Bearer secret-token")
request.Header.Set("Content-Type", "application/json")
response := httptestRecorder(handler, request)
assertStatusCode(t, response, http.StatusBadRequest)
assertJSONContains(t, response.Body().Bytes(), "error.message", "request body must contain a single JSON object")
}
// --- Coverage edge cases ---
func TestHTTPErrorError(t *testing.T) {
e := &httpError{StatusCode: http.StatusTeapot, Code: "teapot", Message: "i'm a teapot"}
if got := e.Error(); got != "i'm a teapot" {
t.Fatalf("httpError.Error() = %q, want %q", got, "i'm a teapot")
}
}
func TestProviderStatusFnNil(t *testing.T) {
handler := NewAPIHandler("t", ActionSet{})
req := httptestRequest(t, http.MethodGet, "/api/providers/x/status", nil, "t")
res := httptestRecorder(handler, req)
assertStatusCode(t, res, http.StatusInternalServerError)
assertJSONContains(t, res.Body().Bytes(), "error.code", "server_misconfigured")
}
func TestProviderAccessStatusFnNil(t *testing.T) {
handler := NewAPIHandler("t", ActionSet{})
req := httptestRequest(t, http.MethodGet, "/api/providers/x/access/status", nil, "t")
res := httptestRecorder(handler, req)
assertStatusCode(t, res, http.StatusInternalServerError)
assertJSONContains(t, res.Body().Bytes(), "error.code", "server_misconfigured")
}
func TestProviderResourcesFnNil(t *testing.T) {
handler := NewAPIHandler("t", ActionSet{})
req := httptestRequest(t, http.MethodGet, "/api/providers/x/resources", nil, "t")
res := httptestRecorder(handler, req)
assertStatusCode(t, res, http.StatusInternalServerError)
assertJSONContains(t, res.Body().Bytes(), "error.code", "server_misconfigured")
}
func TestProviderStatusReturnsError(t *testing.T) {
handler := NewAPIHandler("t", ActionSet{
GetProviderStatus: func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error) {
return provision.ProviderSnapshot{}, errors.New(`provider "x" not found in pack "p"`)
},
})
req := httptestRequest(t, http.MethodGet, "/api/providers/x/status", nil, "t")
res := httptestRecorder(handler, req)
assertStatusCode(t, res, http.StatusBadRequest)
assertJSONContains(t, res.Body().Bytes(), "error.code", "provider_not_found")
}
func TestPostHandlersFnNil(t *testing.T) {
tests := []struct {
name string
method string
path string
body string
}{
{name: "install-pack", method: http.MethodPost, path: "/api/packs/install", body: `{}`},
{name: "preview", method: http.MethodPost, path: "/api/providers/x/preview-import", body: `{}`},
{name: "import", method: http.MethodPost, path: "/api/providers/x/import", body: `{}`},
{name: "rollback", method: http.MethodPost, path: "/api/providers/x/rollback", body: `{}`},
{name: "reconcile", method: http.MethodPost, path: "/api/providers/x/reconcile", body: `{}`},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
handler := NewAPIHandler("t", ActionSet{})
req, _ := http.NewRequest(tt.method, tt.path, strings.NewReader(tt.body))
req.Header.Set("Authorization", "Bearer t")
req.Header.Set("Content-Type", "application/json")
res := httptestRecorder(handler, req)
assertStatusCode(t, res, http.StatusInternalServerError)
assertJSONContains(t, res.Body().Bytes(), "error.code", "server_misconfigured")
})
}
}
func TestHandlerErrorPaths(t *testing.T) {
tests := []struct {
name string
method string
path string
body string
actionSet ActionSet
wantStatus int
wantCode string
}{
{
name: "access-status-error",
method: http.MethodGet,
path: "/api/providers/x/access/status",
actionSet: ActionSet{
GetProviderAccessStatus: func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error) {
return provision.ProviderSnapshot{}, errors.New("boom")
},
},
wantStatus: http.StatusInternalServerError,
wantCode: "internal_error",
},
{
name: "preview-error",
method: http.MethodPost,
path: "/api/providers/x/preview-import",
body: `{}`,
actionSet: ActionSet{
PreviewProvider: func(context.Context, PreviewProviderRequest) (provision.PreviewReport, error) {
return provision.PreviewReport{}, errors.New("boom")
},
},
wantStatus: http.StatusInternalServerError,
wantCode: "internal_error",
},
{
name: "rollback-error",
method: http.MethodPost,
path: "/api/providers/x/rollback",
body: `{}`,
actionSet: ActionSet{
RollbackProvider: func(context.Context, RollbackProviderRequest) (provision.RollbackReport, error) {
return provision.RollbackReport{}, errors.New("boom")
},
},
wantStatus: http.StatusInternalServerError,
wantCode: "internal_error",
},
{
name: "reconcile-error",
method: http.MethodPost,
path: "/api/providers/x/reconcile",
body: `{}`,
actionSet: ActionSet{
ReconcileProvider: func(context.Context, ReconcileProviderRequest) (provision.ReconcileResult, error) {
return provision.ReconcileResult{}, errors.New("boom")
},
},
wantStatus: http.StatusInternalServerError,
wantCode: "internal_error",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
handler := NewAPIHandler("t", tt.actionSet)
var req *http.Request
if tt.body != "" {
req, _ = http.NewRequest(tt.method, tt.path, strings.NewReader(tt.body))
req.Header.Set("Content-Type", "application/json")
} else {
var err error
req, err = http.NewRequest(tt.method, tt.path, nil)
if err != nil {
t.Fatal(err)
}
}
req.Header.Set("Authorization", "Bearer t")
res := httptestRecorder(handler, req)
assertStatusCode(t, res, tt.wantStatus)
assertJSONContains(t, res.Body().Bytes(), "error.code", tt.wantCode)
})
}
}
func TestProviderAccessStatusMultipleClosures(t *testing.T) {
handler := NewAPIHandler("t", ActionSet{
GetProviderAccessStatus: func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error) {
return provision.ProviderSnapshot{
Pack: sqlite.Pack{PackID: "p"},
Provider: sqlite.Provider{ProviderID: "dp"},
Batch: sqlite.ImportBatch{ID: 1},
LatestAccessStatus: "ready",
AccessClosures: []sqlite.AccessClosureRecord{
{ID: 1, ClosureType: "preview", Status: "done", DetailsJSON: `{"v":1}`},
{ID: 2, ClosureType: "self_service", Status: "active", DetailsJSON: `{"v":2}`},
},
}, nil
},
})
req := httptestRequest(t, http.MethodGet, "/api/providers/dp/access/status", nil, "t")
res := httptestRecorder(handler, req)
assertStatusCode(t, res, http.StatusOK)
// Should report the last closure (index n-1)
if !strings.Contains(res.Body().String(), `"closure_type":"self_service"`) {
t.Fatalf("expected latest closure to be self_service, got: %s", res.Body().String())
}
}
func assertJSONContains(t *testing.T, payload []byte, key string, want any) {
t.Helper()
var decoded map[string]any
if err := json.Unmarshal(payload, &decoded); err != nil {
t.Fatalf("json.Unmarshal() error = %v; payload=%s", err, string(payload))
}
if strings.Contains(key, ".") {
parts := strings.Split(key, ".")
current := any(decoded)
for _, part := range parts {
object, ok := current.(map[string]any)
if !ok {
t.Fatalf("key %q not found in payload %s", key, string(payload))
}
current = object[part]
}
if current != want {
t.Fatalf("json key %q = %#v, want %#v; payload=%s", key, current, want, string(payload))
}
return
}
if decoded[key] != want {
t.Fatalf("json key %q = %#v, want %#v; payload=%s", key, decoded[key], want, string(payload))
}
}
func TestNewActionSetReturnsNonNil(t *testing.T) {
as := NewActionSet("file::memory:?cache=shared")
t.Run("InstallPack", func(t *testing.T) {
if as.InstallPack == nil {
t.Fatal("is nil")
}
})
t.Run("BatchDetail", func(t *testing.T) {
if as.BatchDetail == nil {
t.Fatal("is nil")
}
})
t.Run("GetProviderStatus", func(t *testing.T) {
if as.GetProviderStatus == nil {
t.Fatal("is nil")
}
})
t.Run("GetProviderResources", func(t *testing.T) {
if as.GetProviderResources == nil {
t.Fatal("is nil")
}
})
t.Run("GetProviderAccessStatus", func(t *testing.T) {
if as.GetProviderAccessStatus == nil {
t.Fatal("is nil")
}
})
t.Run("PreviewProvider", func(t *testing.T) {
if as.PreviewProvider == nil {
t.Fatal("is nil")
}
})
t.Run("ImportProvider", func(t *testing.T) {
if as.ImportProvider == nil {
t.Fatal("is nil")
}
})
t.Run("RollbackProvider", func(t *testing.T) {
if as.RollbackProvider == nil {
t.Fatal("is nil")
}
})
t.Run("ReconcileProvider", func(t *testing.T) {
if as.ReconcileProvider == nil {
t.Fatal("is nil")
}
})
}
func TestBatchDetailReturnsNotFoundForMissingBatch(t *testing.T) {
as := NewActionSet("file::memory:?cache=shared")
_, err := as.BatchDetail(context.Background(), BatchDetailRequest{BatchID: 999})
if err == nil {
t.Fatal("BatchDetail() error = nil for missing batch, want error")
}
}
func TestNewActionSetSQLiteClosures(t *testing.T) {
dsn := "file::memory:?cache=shared"
as := NewActionSet(dsn)
ctx := context.Background()
t.Run("GetProviderStatus on empty DB", func(t *testing.T) {
_, err := as.GetProviderStatus(ctx, ProviderQueryRequest{ProviderID: "x", PackID: "p"})
if err == nil {
t.Fatal("expected error from empty DB, got nil")
}
})
t.Run("GetProviderResources on empty DB", func(t *testing.T) {
_, err := as.GetProviderResources(ctx, ProviderQueryRequest{ProviderID: "x", PackID: "p"})
if err == nil {
t.Fatal("expected error from empty DB, got nil")
}
})
t.Run("GetProviderAccessStatus on empty DB", func(t *testing.T) {
_, err := as.GetProviderAccessStatus(ctx, ProviderQueryRequest{ProviderID: "x", PackID: "p"})
if err == nil {
t.Fatal("expected error from empty DB, got nil")
}
})
}
func TestNewActionSetPackErrorPaths(t *testing.T) {
dsn := "file::memory:?cache=shared"
as := NewActionSet(dsn)
ctx := context.Background()
t.Run("InstallPack bad path", func(t *testing.T) {
_, err := as.InstallPack(ctx, InstallPackRequest{PackPath: "/nonexistent/pack"})
if err == nil {
t.Fatal("expected error from bad pack path")
}
})
t.Run("PreviewProvider bad path", func(t *testing.T) {
_, err := as.PreviewProvider(ctx, PreviewProviderRequest{PackPath: "/nonexistent/pack", ProviderID: "x"})
if err == nil {
t.Fatal("expected error from bad pack path")
}
})
t.Run("ImportProvider bad path", func(t *testing.T) {
_, err := as.ImportProvider(ctx, ImportProviderRequest{PackPath: "/nonexistent/pack", ProviderID: "x", HostBaseURL: "http://h:8080"})
if err == nil {
t.Fatal("expected error from bad pack path")
}
})
t.Run("RollbackProvider bad path", func(t *testing.T) {
_, err := as.RollbackProvider(ctx, RollbackProviderRequest{PackPath: "/nonexistent/pack", ProviderID: "x", HostBaseURL: "http://h:8080"})
if err == nil {
t.Fatal("expected error from bad pack path")
}
})
t.Run("ReconcileProvider bad path", func(t *testing.T) {
_, err := as.ReconcileProvider(ctx, ReconcileProviderRequest{PackPath: "/nonexistent/pack", ProviderID: "x", HostBaseURL: "http://h:8080"})
if err == nil {
t.Fatal("expected error from bad pack path")
}
})
}

View File

@@ -11,6 +11,10 @@ func Bootstrap(_ context.Context) (*Server, error) {
if err != nil {
return nil, err
}
return NewServer(cfg.Server.ListenAddr, nil), nil
adminToken, err := config.LoadAdminTokenFromEnv()
if err != nil {
return nil, err
}
handler := NewAPIHandler(adminToken, NewActionSet(cfg.Database.SQLiteDSN))
return NewServer(cfg.Server.ListenAddr, handler, nil), nil
}

638
internal/app/http_api.go Normal file
View File

@@ -0,0 +1,638 @@
package app
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"sub2api-cn-relay-manager/internal/host/sub2api"
"sub2api-cn-relay-manager/internal/pack"
"sub2api-cn-relay-manager/internal/provision"
"sub2api-cn-relay-manager/internal/store/sqlite"
)
type ActionSet struct {
InstallPack func(context.Context, InstallPackRequest) (provision.PackInstallResult, error)
BatchDetail func(context.Context, BatchDetailRequest) (provision.BatchDetailResult, error)
GetProviderStatus func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error)
GetProviderResources func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error)
GetProviderAccessStatus func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error)
PreviewProvider func(context.Context, PreviewProviderRequest) (provision.PreviewReport, error)
ImportProvider func(context.Context, ImportProviderRequest) (provision.RuntimeImportResult, error)
RollbackProvider func(context.Context, RollbackProviderRequest) (provision.RollbackReport, error)
ReconcileProvider func(context.Context, ReconcileProviderRequest) (provision.ReconcileResult, error)
}
type InstallPackRequest struct {
HostBaseURL string `json:"host_base_url"`
HostAPIKey string `json:"host_api_key"`
HostBearerToken string `json:"host_bearer_token"`
PackPath string `json:"pack_path"`
}
type BatchDetailRequest struct {
BatchID int64
}
type ProviderQueryRequest struct {
ProviderID string
PackID string
}
type RollbackProviderRequest struct {
HostBaseURL string `json:"host_base_url"`
HostAPIKey string `json:"host_api_key"`
HostBearerToken string `json:"host_bearer_token"`
PackPath string `json:"pack_path"`
ProviderID string `json:"provider_id"`
}
type ReconcileProviderRequest struct {
HostBaseURL string `json:"host_base_url"`
HostAPIKey string `json:"host_api_key"`
HostBearerToken string `json:"host_bearer_token"`
PackPath string `json:"pack_path"`
ProviderID string `json:"provider_id"`
AccessAPIKey string `json:"access_api_key"`
}
type PreviewProviderRequest struct {
HostBaseURL string `json:"host_base_url"`
HostAPIKey string `json:"host_api_key"`
HostBearerToken string `json:"host_bearer_token"`
PackPath string `json:"pack_path"`
ProviderID string `json:"provider_id"`
Keys []string `json:"keys"`
Mode string `json:"mode"`
}
type ImportProviderRequest struct {
HostBaseURL string `json:"host_base_url"`
HostAPIKey string `json:"host_api_key"`
HostBearerToken string `json:"host_bearer_token"`
PackPath string `json:"pack_path"`
ProviderID string `json:"provider_id"`
Keys []string `json:"keys"`
Mode string `json:"mode"`
AccessMode string `json:"access_mode"`
AccessAPIKey string `json:"access_api_key"`
SubscriptionUsers []string `json:"subscription_users"`
SubscriptionDays int `json:"subscription_days"`
}
type httpError struct {
StatusCode int `json:"-"`
Code string `json:"code"`
Message string `json:"message"`
UpstreamStatus int `json:"upstream_status,omitempty"`
}
func (e *httpError) Error() string {
return e.Message
}
func NewAPIHandler(adminToken string, actions ActionSet) http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("GET /healthz", healthz)
mux.Handle("GET /api/import-batches/{batchID}", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleBatchDetail(w, r, actions.BatchDetail)
})))
mux.Handle("GET /api/providers/{providerID}/status", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleProviderStatus(w, r, actions.GetProviderStatus)
})))
mux.Handle("GET /api/providers/{providerID}/resources", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleProviderResources(w, r, actions.GetProviderResources)
})))
mux.Handle("GET /api/providers/{providerID}/access/status", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleProviderAccessStatus(w, r, actions.GetProviderAccessStatus)
})))
mux.Handle("POST /api/packs/install", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleInstallPack(w, r, actions.InstallPack)
})))
mux.Handle("POST /api/providers/{providerID}/preview-import", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handlePreviewProvider(w, r, actions.PreviewProvider)
})))
mux.Handle("POST /api/providers/{providerID}/import", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleImportProvider(w, r, actions.ImportProvider)
})))
mux.Handle("POST /api/providers/{providerID}/rollback", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleRollbackProvider(w, r, actions.RollbackProvider)
})))
mux.Handle("POST /api/providers/{providerID}/reconcile", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleReconcileProvider(w, r, actions.ReconcileProvider)
})))
return mux
}
func healthz(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}
func requireAdminToken(token string, next http.Handler) http.Handler {
if strings.TrimSpace(token) == "" {
return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "admin token is not configured"})
})
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if bearerToken(r) != token {
writeHTTPError(w, &httpError{StatusCode: http.StatusUnauthorized, Code: "unauthorized", Message: "missing or invalid admin token"})
return
}
next.ServeHTTP(w, r)
})
}
func bearerToken(r *http.Request) string {
header := strings.TrimSpace(r.Header.Get("Authorization"))
if !strings.HasPrefix(strings.ToLower(header), "bearer ") {
return ""
}
return strings.TrimSpace(header[len("Bearer "):])
}
func handleInstallPack(w http.ResponseWriter, r *http.Request, fn func(context.Context, InstallPackRequest) (provision.PackInstallResult, error)) {
if fn == nil {
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "install-pack action is not configured"})
return
}
var req InstallPackRequest
if err := decodeJSON(r, &req); err != nil {
writeHTTPError(w, err)
return
}
result, err := fn(r.Context(), req)
if err != nil {
writeHTTPError(w, classifyError(err))
return
}
providers := make([]map[string]string, 0, len(result.Providers))
for _, provider := range result.Providers {
providers = append(providers, map[string]string{
"provider_id": provider.ProviderID,
"display_name": provider.DisplayName,
})
}
writeJSON(w, http.StatusOK, map[string]any{
"pack_id": result.Pack.PackID,
"version": result.Pack.Version,
"host_version": result.HostVersion,
"already_installed": result.AlreadyInstalled,
"providers": providers,
})
}
func handleBatchDetail(w http.ResponseWriter, r *http.Request, fn func(context.Context, BatchDetailRequest) (provision.BatchDetailResult, error)) {
if fn == nil {
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "batch-detail action is not configured"})
return
}
batchID, err := strconv.ParseInt(r.PathValue("batchID"), 10, 64)
if err != nil || batchID <= 0 {
writeHTTPError(w, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "batch_id must be a positive integer"})
return
}
result, err := fn(r.Context(), BatchDetailRequest{BatchID: batchID})
if err != nil {
writeHTTPError(w, classifyError(err))
return
}
items := make([]map[string]any, 0, len(result.Items))
for _, item := range result.Items {
items = append(items, map[string]any{
"id": item.ID,
"batch_id": item.BatchID,
"key_fingerprint": item.KeyFingerprint,
"account_status": item.AccountStatus,
"probe_summary_json": item.ProbeSummaryJSON,
})
}
writeJSON(w, http.StatusOK, map[string]any{
"batch": map[string]any{
"id": result.Batch.ID,
"host_id": result.Batch.HostID,
"pack_id": result.Batch.PackID,
"provider_id": result.Batch.ProviderID,
"mode": result.Batch.Mode,
"batch_status": result.Batch.BatchStatus,
"access_status": result.Batch.AccessStatus,
},
"items": items,
"managed_resources": result.ManagedResources,
"access_closures": result.AccessClosures,
"reconcile_runs": result.ReconcileRuns,
"items_count": len(result.Items),
"managed_count": len(result.ManagedResources),
"access_count": len(result.AccessClosures),
"reconcile_count": len(result.ReconcileRuns),
})
}
func handleProviderStatus(w http.ResponseWriter, r *http.Request, fn func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error)) {
if fn == nil {
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "provider-status action is not configured"})
return
}
result, err := fn(r.Context(), ProviderQueryRequest{ProviderID: r.PathValue("providerID"), PackID: strings.TrimSpace(r.URL.Query().Get("pack_id"))})
if err != nil {
writeHTTPError(w, classifyError(err))
return
}
writeJSON(w, http.StatusOK, map[string]any{
"host": map[string]any{"host_id": result.Host.HostID, "base_url": result.Host.BaseURL, "host_version": result.Host.HostVersion},
"pack": map[string]any{"pack_id": result.Pack.PackID, "version": result.Pack.Version},
"provider": map[string]any{"provider_id": result.Provider.ProviderID, "display_name": result.Provider.DisplayName, "platform": result.Provider.Platform},
"batch": map[string]any{"id": result.Batch.ID, "batch_status": result.Batch.BatchStatus, "access_status": result.Batch.AccessStatus, "mode": result.Batch.Mode},
"provider_status": result.ProviderStatus,
"latest_access_status": result.LatestAccessStatus,
"latest_reconcile_status": result.LatestReconcileStatus,
"latest_reconcile_summary": result.LatestReconcileSummary,
"managed_resources_count": len(result.ManagedResources),
"access_closures_count": len(result.AccessClosures),
"reconcile_runs_count": len(result.ReconcileRuns),
})
}
func handleProviderAccessStatus(w http.ResponseWriter, r *http.Request, fn func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error)) {
if fn == nil {
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "provider-access-status action is not configured"})
return
}
result, err := fn(r.Context(), ProviderQueryRequest{ProviderID: r.PathValue("providerID"), PackID: strings.TrimSpace(r.URL.Query().Get("pack_id"))})
if err != nil {
writeHTTPError(w, classifyError(err))
return
}
latestClosure := map[string]any{}
if n := len(result.AccessClosures); n > 0 {
closure := result.AccessClosures[n-1]
latestClosure = map[string]any{"id": closure.ID, "closure_type": closure.ClosureType, "status": closure.Status, "details_json": closure.DetailsJSON}
}
writeJSON(w, http.StatusOK, map[string]any{
"provider_id": result.Provider.ProviderID,
"pack_id": result.Pack.PackID,
"batch_id": result.Batch.ID,
"batch_access_status": result.Batch.AccessStatus,
"latest_access_status": result.LatestAccessStatus,
"closures_count": len(result.AccessClosures),
"latest_closure": latestClosure,
})
}
func handleProviderResources(w http.ResponseWriter, r *http.Request, fn func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error)) {
if fn == nil {
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "provider-resources action is not configured"})
return
}
result, err := fn(r.Context(), ProviderQueryRequest{ProviderID: r.PathValue("providerID"), PackID: strings.TrimSpace(r.URL.Query().Get("pack_id"))})
if err != nil {
writeHTTPError(w, classifyError(err))
return
}
resources := make([]map[string]any, 0, len(result.ManagedResources))
for _, resource := range result.ManagedResources {
resources = append(resources, map[string]any{"id": resource.ID, "resource_type": resource.ResourceType, "host_resource_id": resource.HostResourceID, "resource_name": resource.ResourceName})
}
accessClosures := make([]map[string]any, 0, len(result.AccessClosures))
for _, closure := range result.AccessClosures {
accessClosures = append(accessClosures, map[string]any{"id": closure.ID, "closure_type": closure.ClosureType, "status": closure.Status, "details_json": closure.DetailsJSON})
}
reconcileRuns := make([]map[string]any, 0, len(result.ReconcileRuns))
for _, run := range result.ReconcileRuns {
reconcileRuns = append(reconcileRuns, map[string]any{"id": run.ID, "status": run.Status, "summary_json": run.SummaryJSON})
}
writeJSON(w, http.StatusOK, map[string]any{
"provider_id": result.Provider.ProviderID,
"pack_id": result.Pack.PackID,
"batch_id": result.Batch.ID,
"resources": resources,
"access_closures": accessClosures,
"reconcile_runs": reconcileRuns,
})
}
func handlePreviewProvider(w http.ResponseWriter, r *http.Request, fn func(context.Context, PreviewProviderRequest) (provision.PreviewReport, error)) {
if fn == nil {
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "preview-provider action is not configured"})
return
}
var req PreviewProviderRequest
if err := decodeJSON(r, &req); err != nil {
writeHTTPError(w, err)
return
}
req.ProviderID = r.PathValue("providerID")
result, err := fn(r.Context(), req)
if err != nil {
writeHTTPError(w, classifyError(err))
return
}
writeJSON(w, http.StatusOK, map[string]any{
"accepted_keys_count": len(result.AcceptedKeys),
"names": result.Names,
"decisions": result.Decisions,
})
}
func handleImportProvider(w http.ResponseWriter, r *http.Request, fn func(context.Context, ImportProviderRequest) (provision.RuntimeImportResult, error)) {
if fn == nil {
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "import-provider action is not configured"})
return
}
var req ImportProviderRequest
if err := decodeJSON(r, &req); err != nil {
writeHTTPError(w, err)
return
}
req.ProviderID = r.PathValue("providerID")
result, err := fn(r.Context(), req)
if err != nil {
payload := map[string]any{
"batch_id": result.BatchID,
"batch_status": result.Report.BatchStatus,
"provider_status": result.Report.ProviderStatus,
"access_status": result.Report.AccessStatus,
"accepted_keys_count": len(result.Report.AcceptedKeys),
"accounts_count": len(result.Report.Accounts),
"gateway": result.Report.Gateway,
"error": classifyError(err),
}
statusCode := http.StatusConflict
if result.BatchID == 0 {
statusCode = classifyError(err).StatusCode
}
writeJSON(w, statusCode, payload)
return
}
writeJSON(w, http.StatusOK, map[string]any{
"batch_id": result.BatchID,
"batch_status": result.Report.BatchStatus,
"provider_status": result.Report.ProviderStatus,
"access_status": result.Report.AccessStatus,
"accepted_keys_count": len(result.Report.AcceptedKeys),
"accounts_count": len(result.Report.Accounts),
"group": result.Report.Group,
"channel": result.Report.Channel,
"plan": result.Report.Plan,
"gateway": result.Report.Gateway,
})
}
func handleRollbackProvider(w http.ResponseWriter, r *http.Request, fn func(context.Context, RollbackProviderRequest) (provision.RollbackReport, error)) {
if fn == nil {
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "rollback-provider action is not configured"})
return
}
var req RollbackProviderRequest
if err := decodeJSON(r, &req); err != nil {
writeHTTPError(w, err)
return
}
req.ProviderID = r.PathValue("providerID")
result, err := fn(r.Context(), req)
if err != nil {
writeHTTPError(w, classifyError(err))
return
}
writeJSON(w, http.StatusOK, map[string]any{
"provider_id": req.ProviderID,
"deleted_accounts": result.AccountsDeleted,
"deleted_plans": result.PlansDeleted,
"deleted_channels": result.ChannelsDeleted,
"deleted_groups": result.GroupsDeleted,
})
}
func handleReconcileProvider(w http.ResponseWriter, r *http.Request, fn func(context.Context, ReconcileProviderRequest) (provision.ReconcileResult, error)) {
if fn == nil {
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "reconcile-provider action is not configured"})
return
}
var req ReconcileProviderRequest
if err := decodeJSON(r, &req); err != nil {
writeHTTPError(w, err)
return
}
req.ProviderID = r.PathValue("providerID")
result, err := fn(r.Context(), req)
if err != nil {
writeHTTPError(w, classifyError(err))
return
}
writeJSON(w, http.StatusOK, map[string]any{
"provider_id": req.ProviderID,
"batch_id": result.BatchID,
"status": result.Status,
"missing_count": result.MissingCount,
"extra_count": result.ExtraCount,
"summary": result.Summary,
})
}
func decodeJSON(r *http.Request, dest any) *httpError {
decoder := json.NewDecoder(r.Body)
decoder.DisallowUnknownFields()
if err := decoder.Decode(dest); err != nil {
return &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: fmt.Sprintf("decode request body: %v", err)}
}
if err := decoder.Decode(&struct{}{}); err != nil && !errors.Is(err, io.EOF) {
return &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "request body must contain a single JSON object"}
}
return nil
}
func writeHTTPError(w http.ResponseWriter, err *httpError) {
if err == nil {
err = &httpError{StatusCode: http.StatusInternalServerError, Code: "internal_error", Message: "internal server error"}
}
writeJSON(w, err.StatusCode, map[string]any{"error": err})
}
func writeJSON(w http.ResponseWriter, statusCode int, body any) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
_ = json.NewEncoder(w).Encode(body)
}
func classifyError(err error) *httpError {
if err == nil {
return nil
}
var requestErr *httpError
if errors.As(err, &requestErr) {
return requestErr
}
var upstreamErr *sub2api.HTTPError
if errors.As(err, &upstreamErr) {
return &httpError{StatusCode: http.StatusBadGateway, Code: "host_request_failed", Message: err.Error(), UpstreamStatus: upstreamErr.StatusCode}
}
message := err.Error()
switch {
case strings.Contains(message, "already installed") || strings.Contains(message, "checksum drift"):
return &httpError{StatusCode: http.StatusConflict, Code: "pack_conflict", Message: message}
case strings.Contains(message, "not found in pack"):
return &httpError{StatusCode: http.StatusBadRequest, Code: "provider_not_found", Message: message}
case strings.Contains(message, "pack path") || strings.Contains(message, "pack dir") || strings.Contains(message, "required") || strings.Contains(message, "decode"):
return &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: message}
default:
return &httpError{StatusCode: http.StatusInternalServerError, Code: "internal_error", Message: message}
}
}
func NewActionSet(sqliteDSN string) ActionSet {
return ActionSet{
InstallPack: func(ctx context.Context, req InstallPackRequest) (provision.PackInstallResult, error) {
loadedPack, err := pack.LoadPath(req.PackPath)
if err != nil {
return provision.PackInstallResult{}, err
}
client, err := sub2api.NewClient(req.HostBaseURL, sub2api.WithAPIKey(req.HostAPIKey), sub2api.WithBearerToken(req.HostBearerToken))
if err != nil {
return provision.PackInstallResult{}, err
}
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return provision.PackInstallResult{}, err
}
defer store.Close()
service := provision.NewPackInstallService(store, client)
return service.Install(ctx, provision.PackInstallRequest{Pack: loadedPack})
},
BatchDetail: func(ctx context.Context, req BatchDetailRequest) (provision.BatchDetailResult, error) {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return provision.BatchDetailResult{}, err
}
defer store.Close()
return provision.NewBatchDetailService(store).Get(ctx, req.BatchID)
},
GetProviderStatus: func(ctx context.Context, req ProviderQueryRequest) (provision.ProviderSnapshot, error) {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return provision.ProviderSnapshot{}, err
}
defer store.Close()
return provision.NewProviderStatusService(store).GetStatus(ctx, provision.ProviderQuery{ProviderID: req.ProviderID, PackID: req.PackID})
},
GetProviderResources: func(ctx context.Context, req ProviderQueryRequest) (provision.ProviderSnapshot, error) {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return provision.ProviderSnapshot{}, err
}
defer store.Close()
return provision.NewProviderStatusService(store).GetResources(ctx, provision.ProviderQuery{ProviderID: req.ProviderID, PackID: req.PackID})
},
GetProviderAccessStatus: func(ctx context.Context, req ProviderQueryRequest) (provision.ProviderSnapshot, error) {
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return provision.ProviderSnapshot{}, err
}
defer store.Close()
return provision.NewProviderStatusService(store).GetStatus(ctx, provision.ProviderQuery{ProviderID: req.ProviderID, PackID: req.PackID})
},
PreviewProvider: func(ctx context.Context, req PreviewProviderRequest) (provision.PreviewReport, error) {
loadedPack, err := pack.LoadPath(req.PackPath)
if err != nil {
return provision.PreviewReport{}, err
}
providerManifest, err := findProvider(loadedPack, req.ProviderID)
if err != nil {
return provision.PreviewReport{}, err
}
client, err := sub2api.NewClient(req.HostBaseURL, sub2api.WithAPIKey(req.HostAPIKey), sub2api.WithBearerToken(req.HostBearerToken))
if err != nil {
return provision.PreviewReport{}, err
}
service := provision.NewPreviewService(client)
return service.PreviewImport(ctx, provision.PreviewRequest{Provider: providerManifest, Mode: req.Mode, Keys: req.Keys})
},
ImportProvider: func(ctx context.Context, req ImportProviderRequest) (provision.RuntimeImportResult, error) {
loadedPack, err := pack.LoadPath(req.PackPath)
if err != nil {
return provision.RuntimeImportResult{}, err
}
providerManifest, err := findProvider(loadedPack, req.ProviderID)
if err != nil {
return provision.RuntimeImportResult{}, err
}
client, err := sub2api.NewClient(req.HostBaseURL, sub2api.WithAPIKey(req.HostAPIKey), sub2api.WithBearerToken(req.HostBearerToken))
if err != nil {
return provision.RuntimeImportResult{}, err
}
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return provision.RuntimeImportResult{}, err
}
defer store.Close()
subscriptions := make([]provision.SubscriptionTarget, 0, len(req.SubscriptionUsers))
for _, userID := range req.SubscriptionUsers {
subscriptions = append(subscriptions, provision.SubscriptionTarget{UserID: userID, DurationDays: req.SubscriptionDays})
}
service := provision.NewRuntimeImportService(store, client)
return service.Import(ctx, provision.RuntimeImportRequest{
HostBaseURL: req.HostBaseURL,
Pack: loadedPack,
Provider: providerManifest,
Mode: req.Mode,
Keys: req.Keys,
Access: provision.AccessRequest{
Mode: req.AccessMode,
ProbeAPIKey: req.AccessAPIKey,
Subscriptions: subscriptions,
},
})
},
RollbackProvider: func(ctx context.Context, req RollbackProviderRequest) (provision.RollbackReport, error) {
loadedPack, err := pack.LoadPath(req.PackPath)
if err != nil {
return provision.RollbackReport{}, err
}
providerManifest, err := findProvider(loadedPack, req.ProviderID)
if err != nil {
return provision.RollbackReport{}, err
}
client, err := sub2api.NewClient(req.HostBaseURL, sub2api.WithAPIKey(req.HostAPIKey), sub2api.WithBearerToken(req.HostBearerToken))
if err != nil {
return provision.RollbackReport{}, err
}
service := provision.NewRollbackService(client)
return service.Rollback(ctx, provision.RollbackRequest{Provider: providerManifest})
},
ReconcileProvider: func(ctx context.Context, req ReconcileProviderRequest) (provision.ReconcileResult, error) {
loadedPack, err := pack.LoadPath(req.PackPath)
if err != nil {
return provision.ReconcileResult{}, err
}
providerManifest, err := findProvider(loadedPack, req.ProviderID)
if err != nil {
return provision.ReconcileResult{}, err
}
client, err := sub2api.NewClient(req.HostBaseURL, sub2api.WithAPIKey(req.HostAPIKey), sub2api.WithBearerToken(req.HostBearerToken))
if err != nil {
return provision.ReconcileResult{}, err
}
store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil {
return provision.ReconcileResult{}, err
}
defer store.Close()
service := provision.NewReconcileService(store, client)
return service.Reconcile(ctx, provision.ReconcileRequest{HostBaseURL: req.HostBaseURL, AccessProbeAPIKey: req.AccessAPIKey, Pack: loadedPack, Provider: providerManifest})
},
}
}
func findProvider(loaded pack.LoadedPack, providerID string) (pack.ProviderManifest, error) {
for _, provider := range loaded.Providers {
if provider.ProviderID == strings.TrimSpace(providerID) {
return provider, nil
}
}
return pack.ProviderManifest{}, fmt.Errorf("provider %q not found in pack %q", providerID, loaded.Manifest.PackID)
}

View File

@@ -0,0 +1,140 @@
package config
import (
"errors"
"testing"
)
func TestReadOptionalEnv(t *testing.T) {
t.Run("present non-empty", func(t *testing.T) {
lookup := func(k string) (string, bool) {
if k == "MY_KEY" {
return " value ", true
}
return "", false
}
if got := readOptionalEnv(lookup, "MY_KEY", "default"); got != "value" {
t.Fatalf("got %q, want %q", got, "value")
}
})
t.Run("present empty", func(t *testing.T) {
lookup := func(k string) (string, bool) {
return " ", true
}
if got := readOptionalEnv(lookup, "MY_KEY", "default"); got != "default" {
t.Fatalf("got %q, want %q", got, "default")
}
})
t.Run("missing", func(t *testing.T) {
lookup := func(k string) (string, bool) {
return "", false
}
if got := readOptionalEnv(lookup, "MY_KEY", "default"); got != "default" {
t.Fatalf("got %q, want %q", got, "default")
}
})
}
func TestReadRequiredEnv(t *testing.T) {
t.Run("present", func(t *testing.T) {
lookup := func(k string) (string, bool) {
return "my-token", true
}
if got := readRequiredEnv(lookup, "TOKEN"); got != "my-token" {
t.Fatalf("got %q, want %q", got, "my-token")
}
})
t.Run("missing", func(t *testing.T) {
lookup := func(k string) (string, bool) {
return "", false
}
if got := readRequiredEnv(lookup, "TOKEN"); got != "" {
t.Fatalf("got %q, want empty", got)
}
})
}
func TestLoadStartupFromLookupEnv(t *testing.T) {
t.Run("custom values", func(t *testing.T) {
lookup := func(k string) (string, bool) {
switch k {
case EnvListenAddr:
return ":9090", true
case EnvSQLiteDSN:
return "/data/db.sqlite", true
default:
return "", false
}
}
cfg, err := loadStartupFromLookupEnv(lookup)
if err != nil {
t.Fatal(err)
}
if cfg.Server.ListenAddr != ":9090" {
t.Fatalf("ListenAddr = %q, want %q", cfg.Server.ListenAddr, ":9090")
}
if cfg.Database.SQLiteDSN != "/data/db.sqlite" {
t.Fatalf("SQLiteDSN = %q, want %q", cfg.Database.SQLiteDSN, "/data/db.sqlite")
}
})
t.Run("default values", func(t *testing.T) {
lookup := func(k string) (string, bool) {
return "", false
}
cfg, err := loadStartupFromLookupEnv(lookup)
if err != nil {
t.Fatal(err)
}
if cfg.Server.ListenAddr != DefaultListenAddr {
t.Fatalf("ListenAddr = %q, want %q", cfg.Server.ListenAddr, DefaultListenAddr)
}
if cfg.Database.SQLiteDSN != DefaultSQLiteDSN {
t.Fatalf("SQLiteDSN = %q, want %q", cfg.Database.SQLiteDSN, DefaultSQLiteDSN)
}
})
}
func TestLoadAdminTokenFromLookupEnv(t *testing.T) {
t.Run("valid token", func(t *testing.T) {
lookup := func(k string) (string, bool) {
return " admin-secret-123 ", true
}
token, err := loadAdminTokenFromLookupEnv(lookup)
if err != nil {
t.Fatal(err)
}
if token != "admin-secret-123" {
t.Fatalf("token = %q, want %q", token, "admin-secret-123")
}
})
t.Run("empty token", func(t *testing.T) {
lookup := func(k string) (string, bool) {
return " ", true
}
_, err := loadAdminTokenFromLookupEnv(lookup)
if err == nil {
t.Fatal("expected error for empty token")
}
})
t.Run("missing env", func(t *testing.T) {
lookup := func(k string) (string, bool) {
return "", false
}
_, err := loadAdminTokenFromLookupEnv(lookup)
if err == nil {
t.Fatal("expected error for missing env")
}
})
}
// Verify exported wrappers call the lookup versions.
// We can't easily test LoadStartupFromEnv / LoadAdminTokenFromEnv
// since they depend on os.LookupEnv, but we verify they compile and don't panic.
func TestExportFunctionsExist(t *testing.T) {
// Just verify the exported functions are reachable and return the right types
_, err := LoadAdminTokenFromEnv()
if err != nil && !errors.Is(err, err) {
// any result is fine, just proving the function exists
}
}

View File

@@ -16,13 +16,19 @@ type HostAdapter interface {
GetHostVersion(ctx context.Context) (string, error)
ProbeCapabilities(ctx context.Context) (HostCapabilities, error)
CreateGroup(ctx context.Context, req CreateGroupRequest) (GroupRef, error)
DeleteGroup(ctx context.Context, groupID string) error
CreateChannel(ctx context.Context, req CreateChannelRequest) (ChannelRef, error)
DeleteChannel(ctx context.Context, channelID string) error
CreatePlan(ctx context.Context, req CreatePlanRequest) (PlanRef, error)
DeletePlan(ctx context.Context, planID string) error
CreateAccount(ctx context.Context, req CreateAccountRequest) (AccountRef, error)
BatchCreateAccounts(ctx context.Context, req BatchCreateAccountsRequest) ([]AccountRef, error)
DeleteAccount(ctx context.Context, accountID string) error
TestAccount(ctx context.Context, accountID string) (ProbeResult, error)
GetAccountModels(ctx context.Context, accountID string) ([]AccountModel, error)
AssignSubscription(ctx context.Context, req AssignSubscriptionRequest) (SubscriptionRef, error)
CheckGatewayAccess(ctx context.Context, req GatewayAccessCheckRequest) (GatewayAccessResult, error)
ListManagedResources(ctx context.Context, req ListManagedResourcesRequest) (ManagedResourceSnapshot, error)
}
type HostCapabilities struct {

View File

@@ -0,0 +1,40 @@
package sub2api
import (
"context"
"fmt"
"net/http"
"strings"
)
func (c *Client) DeleteGroup(ctx context.Context, groupID string) error {
return c.deleteResource(ctx, "/api/v1/admin/groups/", groupID)
}
func (c *Client) DeleteChannel(ctx context.Context, channelID string) error {
return c.deleteResource(ctx, "/api/v1/admin/channels/", channelID)
}
func (c *Client) DeletePlan(ctx context.Context, planID string) error {
return c.deleteResource(ctx, "/api/v1/admin/payment/plans/", planID)
}
func (c *Client) DeleteAccount(ctx context.Context, accountID string) error {
return c.deleteResource(ctx, "/api/v1/admin/accounts/", accountID)
}
func (c *Client) deleteResource(ctx context.Context, prefix, resourceID string) error {
resourceID = strings.TrimSpace(resourceID)
if resourceID == "" {
return fmt.Errorf("resource id is required")
}
path := prefix + resourceID
statusCode, _, body, err := c.perform(ctx, http.MethodDelete, path, nil)
if err != nil {
return err
}
if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices {
return newHTTPError(http.MethodDelete, path, statusCode, body)
}
return nil
}

View File

@@ -0,0 +1,62 @@
package sub2api
import (
"context"
"encoding/json"
"net/http"
"strings"
)
type GatewayAccessCheckRequest struct {
APIKey string
ExpectedModel string
}
type GatewayAccessResult struct {
OK bool `json:"ok"`
StatusCode int `json:"status_code"`
Models []string `json:"models"`
HasExpectedModel bool `json:"has_expected_model"`
}
func (c *Client) CheckGatewayAccess(ctx context.Context, req GatewayAccessCheckRequest) (GatewayAccessResult, error) {
gatewayClient := *c
gatewayClient.apiKey = strings.TrimSpace(req.APIKey)
gatewayClient.bearerToken = ""
statusCode, _, body, err := gatewayClient.perform(ctx, http.MethodGet, "/v1/models", nil)
if err != nil {
return GatewayAccessResult{}, err
}
result := GatewayAccessResult{StatusCode: statusCode, OK: statusCode >= http.StatusOK && statusCode < http.StatusMultipleChoices}
if !result.OK {
return result, nil
}
result.Models = decodeGatewayModelIDs(body)
for _, modelID := range result.Models {
if modelID == strings.TrimSpace(req.ExpectedModel) {
result.HasExpectedModel = true
break
}
}
return result, nil
}
func decodeGatewayModelIDs(body []byte) []string {
var payload struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
if err := json.Unmarshal(body, &payload); err == nil && len(payload.Data) > 0 {
models := make([]string, 0, len(payload.Data))
for _, item := range payload.Data {
if id := strings.TrimSpace(item.ID); id != "" {
models = append(models, id)
}
}
return models
}
return nil
}

View File

@@ -0,0 +1,97 @@
package sub2api
import (
"context"
"encoding/json"
"fmt"
"strings"
)
func (c *Client) ListManagedResources(ctx context.Context, req ListManagedResourcesRequest) (ManagedResourceSnapshot, error) {
groups, err := c.listNamedResources(ctx, "/api/v1/admin/groups", req.GroupName)
if err != nil {
return ManagedResourceSnapshot{}, fmt.Errorf("list groups: %w", err)
}
channels, err := c.listNamedResources(ctx, "/api/v1/admin/channels", req.ChannelName)
if err != nil {
return ManagedResourceSnapshot{}, fmt.Errorf("list channels: %w", err)
}
plans, err := c.listNamedResources(ctx, "/api/v1/admin/payment/plans", req.PlanName)
if err != nil {
return ManagedResourceSnapshot{}, fmt.Errorf("list plans: %w", err)
}
accounts, err := c.listNamedResources(ctx, "/api/v1/admin/accounts", "")
if err != nil {
return ManagedResourceSnapshot{}, fmt.Errorf("list accounts: %w", err)
}
return ManagedResourceSnapshot{
Groups: groups,
Channels: channels,
Plans: plans,
Accounts: filterNamedResourcesByPrefix(accounts, req.AccountNamePrefix),
}, nil
}
func (c *Client) listNamedResources(ctx context.Context, path, expectedName string) ([]NamedResource, error) {
statusCode, _, body, err := c.perform(ctx, "GET", path, nil)
if err != nil {
return nil, err
}
if statusCode < 200 || statusCode >= 300 {
return nil, newHTTPError("GET", path, statusCode, body)
}
resources, err := decodeNamedResources(body)
if err != nil {
return nil, fmt.Errorf("decode %s response: %w", path, err)
}
return filterNamedResourcesByName(resources, expectedName), nil
}
func decodeNamedResources(body []byte) ([]NamedResource, error) {
var resources []NamedResource
if err := decodeEnvelopeObject(body, &resources); err == nil {
return resources, nil
}
var wrapper struct {
Data struct {
Items []NamedResource `json:"items"`
} `json:"data"`
}
if err := json.Unmarshal(body, &wrapper); err != nil {
return nil, err
}
return wrapper.Data.Items, nil
}
func filterNamedResourcesByName(resources []NamedResource, expectedName string) []NamedResource {
expectedName = strings.TrimSpace(expectedName)
if expectedName == "" {
return resources
}
filtered := make([]NamedResource, 0, len(resources))
for _, resource := range resources {
if strings.TrimSpace(resource.Name) == expectedName {
filtered = append(filtered, resource)
}
}
return filtered
}
func filterNamedResourcesByPrefix(resources []NamedResource, prefix string) []NamedResource {
prefix = strings.TrimSpace(prefix)
if prefix == "" {
return resources
}
filtered := make([]NamedResource, 0, len(resources))
for _, resource := range resources {
if strings.HasPrefix(strings.TrimSpace(resource.Name), prefix) {
filtered = append(filtered, resource)
}
}
return filtered
}

View File

@@ -0,0 +1,20 @@
package sub2api
type ListManagedResourcesRequest struct {
GroupName string
ChannelName string
PlanName string
AccountNamePrefix string
}
type NamedResource struct {
ID string `json:"id"`
Name string `json:"name"`
}
type ManagedResourceSnapshot struct {
Groups []NamedResource `json:"groups"`
Channels []NamedResource `json:"channels"`
Plans []NamedResource `json:"plans"`
Accounts []NamedResource `json:"accounts"`
}

View File

@@ -0,0 +1,699 @@
package sub2api
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
)
func TestHTTPErrorErrorMessage(t *testing.T) {
e := newHTTPError("POST", "/api/v1/admin/groups", http.StatusTeapot, []byte("short and stout"))
want := "sub2api POST /api/v1/admin/groups returned 418: short and stout"
if got := e.Error(); got != want {
t.Fatalf("HTTPError.Error() = %q, want %q", got, want)
}
}
func TestWithHTTPClientAndOptions(t *testing.T) {
customHTTP := &http.Client{Timeout: 123}
client, err := NewClient("http://localhost:8080",
WithHTTPClient(customHTTP),
WithAPIKey(" sk-abc "),
WithBearerToken(" tok-xyz "),
)
if err != nil {
t.Fatal(err)
}
if client.httpClient != customHTTP {
t.Fatal("WithHTTPClient not applied")
}
if client.apiKey != "sk-abc" {
t.Fatalf("apiKey = %q, want %q", client.apiKey, "sk-abc")
}
if client.bearerToken != "tok-xyz" {
t.Fatalf("bearerToken = %q, want %q", client.bearerToken, "tok-xyz")
}
}
func TestNewClient_RejectsInvalidURLs(t *testing.T) {
tests := []struct {
name string
url string
}{
{"empty", ""},
{"no scheme", "localhost:8080"},
{"no host", "http://"},
{"garbage", "://foo"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := NewClient(tt.url)
if err == nil {
t.Fatalf("NewClient(%q) error = nil, want error", tt.url)
}
})
}
}
func TestResolvePath(t *testing.T) {
client, err := NewClient("http://host:9090")
if err != nil {
t.Fatal(err)
}
tests := []struct {
path string
want string
}{
{"/v1/models", "http://host:9090/v1/models"},
{"v1/models", "http://host:9090/v1/models"},
{"/v1/models?key=val", "http://host:9090/v1/models?key=val"},
}
for _, tt := range tests {
t.Run(tt.path, func(t *testing.T) {
if got := client.resolvePath(tt.path); got != tt.want {
t.Fatalf("resolvePath(%q) = %q, want %q", tt.path, got, tt.want)
}
})
}
}
func TestApplyAuth(t *testing.T) {
t.Run("api key preferred", func(t *testing.T) {
c, _ := NewClient("http://h:8080", WithAPIKey("key1"), WithBearerToken("btok"))
req, _ := http.NewRequest("GET", "http://h:8080/path", nil)
c.applyAuth(req)
if h := req.Header.Get("x-api-key"); h != "key1" {
t.Fatalf("x-api-key = %q, want %q", h, "key1")
}
if h := req.Header.Get("Authorization"); h != "" {
t.Fatalf("Authorization should be empty, got %q", h)
}
})
t.Run("bearer token fallback", func(t *testing.T) {
c, _ := NewClient("http://h:8080", WithBearerToken("btok"))
req, _ := http.NewRequest("GET", "http://h:8080/path", nil)
c.applyAuth(req)
if h := req.Header.Get("Authorization"); h != "Bearer btok" {
t.Fatalf("Authorization = %q, want %q", h, "Bearer btok")
}
})
t.Run("no auth", func(t *testing.T) {
c, _ := NewClient("http://h:8080")
req, _ := http.NewRequest("GET", "http://h:8080/path", nil)
c.applyAuth(req)
if h := req.Header.Get("x-api-key"); h != "" {
t.Fatalf("x-api-key should be empty, got %q", h)
}
if h := req.Header.Get("Authorization"); h != "" {
t.Fatalf("Authorization should be empty, got %q", h)
}
})
}
func TestDecodeEnvelopeObject(t *testing.T) {
t.Run("standard envelope", func(t *testing.T) {
body := []byte(`{"data":{"id":"g1","name":"test"}}`)
var ref GroupRef
if err := decodeEnvelopeObject(body, &ref); err != nil {
t.Fatal(err)
}
if ref.ID != "g1" || ref.Name != "test" {
t.Fatalf("got %+v, want {ID:g1 Name:test}", ref)
}
})
t.Run("flat response (no data wrapper)", func(t *testing.T) {
body := []byte(`{"id":"g2","name":"flat"}`)
var ref GroupRef
if err := decodeEnvelopeObject(body, &ref); err != nil {
t.Fatal(err)
}
if ref.ID != "g2" || ref.Name != "flat" {
t.Fatalf("got %+v, want {ID:g2 Name:flat}", ref)
}
})
t.Run("data:null returns flat", func(t *testing.T) {
body := []byte(`{"data":null,"id":"g3"}`)
var ref GroupRef
if err := decodeEnvelopeObject(body, &ref); err != nil {
t.Fatal(err)
}
if ref.ID != "g3" {
t.Fatalf("id = %q, want %q", ref.ID, "g3")
}
})
t.Run("invalid json returns error", func(t *testing.T) {
var ref GroupRef
if err := decodeEnvelopeObject([]byte(`not json`), &ref); err == nil {
t.Fatal("expected error")
}
})
}
func TestDecodeGatewayModelIDs(t *testing.T) {
t.Run("standard list", func(t *testing.T) {
ids := decodeGatewayModelIDs([]byte(`{"data":[{"id":"gpt-4"},{"id":" claude-3 "}]}`))
if len(ids) != 2 || ids[0] != "gpt-4" || ids[1] != "claude-3" {
t.Fatalf("got %v, want [gpt-4 claude-3]", ids)
}
})
t.Run("empty data", func(t *testing.T) {
if ids := decodeGatewayModelIDs([]byte(`{}`)); ids != nil {
t.Fatalf("expected nil, got %v", ids)
}
})
t.Run("invalid json", func(t *testing.T) {
if ids := decodeGatewayModelIDs([]byte(`not json`)); ids != nil {
t.Fatalf("expected nil, got %v", ids)
}
})
t.Run("empty array", func(t *testing.T) {
if ids := decodeGatewayModelIDs([]byte(`{"data":[]}`)); ids != nil {
t.Fatalf("expected nil, got %v", ids)
}
})
}
func TestFilterNamedResourcesByName(t *testing.T) {
resources := []NamedResource{
{Name: "group-a", ID: "g1"},
{Name: "group-b", ID: "g2"},
{Name: " group-a ", ID: "g3"},
}
t.Run("match", func(t *testing.T) {
got := filterNamedResourcesByName(resources, "group-a")
if len(got) != 2 || got[0].ID != "g1" || got[1].ID != "g3" {
t.Fatalf("got %+v, want 2 matches", got)
}
})
t.Run("no match", func(t *testing.T) {
if got := filterNamedResourcesByName(resources, "nonexistent"); len(got) != 0 {
t.Fatalf("expected 0, got %d", len(got))
}
})
t.Run("empty name returns all", func(t *testing.T) {
if got := filterNamedResourcesByName(resources, ""); len(got) != 3 {
t.Fatalf("expected 3, got %d", len(got))
}
})
}
func TestFilterNamedResourcesByPrefix(t *testing.T) {
resources := []NamedResource{
{Name: "deepseek-proxy", ID: "r1"},
{Name: "deepseek-us", ID: "r2"},
{Name: "claude-eu", ID: "r3"},
}
t.Run("prefix matches", func(t *testing.T) {
got := filterNamedResourcesByPrefix(resources, "deepseek")
if len(got) != 2 {
t.Fatalf("expected 2, got %d", len(got))
}
})
t.Run("no prefix match", func(t *testing.T) {
if got := filterNamedResourcesByPrefix(resources, "nope"); len(got) != 0 {
t.Fatalf("expected 0, got %d", len(got))
}
})
t.Run("empty prefix returns all", func(t *testing.T) {
if got := filterNamedResourcesByPrefix(resources, ""); len(got) != 3 {
t.Fatalf("expected 3, got %d", len(got))
}
})
}
func TestDecodeNamedResources(t *testing.T) {
t.Run("envelope", func(t *testing.T) {
resources, err := decodeNamedResources([]byte(`{"data":[{"id":"r1","name":"n1"}]}`))
if err != nil {
t.Fatal(err)
}
if len(resources) != 1 || resources[0].ID != "r1" {
t.Fatalf("got %+v", resources)
}
})
t.Run("wrapper with items", func(t *testing.T) {
resources, err := decodeNamedResources([]byte(`{"data":{"items":[{"id":"r2","name":"n2"}]}}`))
if err != nil {
t.Fatal(err)
}
if len(resources) != 1 || resources[0].ID != "r2" {
t.Fatalf("got %+v", resources)
}
})
t.Run("invalid json", func(t *testing.T) {
_, err := decodeNamedResources([]byte(`not json`))
if err == nil {
t.Fatal("expected error")
}
})
}
func TestDecodeAccountRefs(t *testing.T) {
t.Run("envelope", func(t *testing.T) {
refs, err := decodeAccountRefs([]byte(`{"data":[{"id":"a1"}]}`))
if err != nil {
t.Fatal(err)
}
if len(refs) != 1 || refs[0].ID != "a1" {
t.Fatalf("got %+v", refs)
}
})
t.Run("wrapper with items", func(t *testing.T) {
refs, err := decodeAccountRefs([]byte(`{"data":{"items":[{"id":"a2"}]}}`))
if err != nil {
t.Fatal(err)
}
if len(refs) != 1 || refs[0].ID != "a2" {
t.Fatalf("got %+v", refs)
}
})
t.Run("invalid json", func(t *testing.T) {
_, err := decodeAccountRefs([]byte(`not json`))
if err == nil {
t.Fatal("expected error")
}
})
}
func TestDecodeAccountModels(t *testing.T) {
t.Run("envelope", func(t *testing.T) {
models, err := decodeAccountModels([]byte(`{"data":[{"id":"gpt4","display_name":"GPT-4","type":"chat"}]}`))
if err != nil {
t.Fatal(err)
}
if len(models) != 1 || models[0].ID != "gpt4" {
t.Fatalf("got %+v", models)
}
})
t.Run("wrapper with items", func(t *testing.T) {
models, err := decodeAccountModels([]byte(`{"data":{"items":[{"id":"cl3","display_name":"Claude 3","type":"chat"}]}}`))
if err != nil {
t.Fatal(err)
}
if len(models) != 1 || models[0].ID != "cl3" {
t.Fatalf("got %+v", models)
}
})
t.Run("invalid json", func(t *testing.T) {
_, err := decodeAccountModels([]byte(`not json`))
if err == nil {
t.Fatal("expected error")
}
})
}
func TestParseProbeResult(t *testing.T) {
t.Run("SSE with ok=true", func(t *testing.T) {
result, err := parseProbeResult([]byte("data: {\"status\":\"passed\",\"ok\":true}\n"))
if err != nil {
t.Fatal(err)
}
if !result.OK || result.Status != "passed" {
t.Fatalf("got %+v, want OK=true Status=passed", result)
}
})
t.Run("SSE with success=true", func(t *testing.T) {
result, err := parseProbeResult([]byte("data: {\"status\":\"succeeded\",\"success\":true}\n"))
if err != nil {
t.Fatal(err)
}
if !result.OK || result.Status != "passed" {
t.Fatalf("got %+v", result)
}
})
t.Run("SSE with ok=false", func(t *testing.T) {
result, err := parseProbeResult([]byte("data: {\"status\":\"failed\",\"ok\":false}\n"))
if err != nil {
t.Fatal(err)
}
if result.OK || result.Status != "failed" {
t.Fatalf("got %+v", result)
}
})
t.Run("SSE with status-based ok", func(t *testing.T) {
result, err := parseProbeResult([]byte("data: {\"status\":\"pass\",\"message\":\"all good\"}\n"))
if err != nil {
t.Fatal(err)
}
if !result.OK || result.Message != "all good" {
t.Fatalf("got %+v", result)
}
})
t.Run("multiple SSE events picks last", func(t *testing.T) {
result, err := parseProbeResult([]byte("data: {\"status\":\"running\"}\ndata: {\"status\":\"passed\",\"ok\":true}\n"))
if err != nil {
t.Fatal(err)
}
if !result.OK {
t.Fatalf("expected OK=true from last event, got %+v", result)
}
})
t.Run("no data events", func(t *testing.T) {
_, err := parseProbeResult([]byte("not data\n"))
if err == nil {
t.Fatal("expected error")
}
})
}
func TestNormalizeProbeStatus(t *testing.T) {
tests := []struct {
status string
ok bool
want string
}{
{"pass", true, "passed"},
{"PASSED", true, "passed"},
{"Ok", true, "passed"},
{"success", true, "passed"},
{"succeeded", true, "passed"},
{"fail", false, "failed"},
{"FAILED", false, "failed"},
{"error", false, "failed"},
{"custom_ok", true, "passed"},
{"custom_fail", false, "failed"},
}
for _, tt := range tests {
t.Run(tt.status, func(t *testing.T) {
if got := normalizeProbeStatus(tt.status, tt.ok); got != tt.want {
t.Fatalf("normalizeProbeStatus(%q, %v) = %q, want %q", tt.status, tt.ok, got, tt.want)
}
})
}
}
func TestLooksLikeExistingEndpoint(t *testing.T) {
t.Run("json content type", func(t *testing.T) {
h := http.Header{"Content-Type": []string{"application/json"}}
if !looksLikeExistingEndpoint(h, nil) {
t.Fatal("expected true with json content type")
}
})
t.Run("sse content type", func(t *testing.T) {
h := http.Header{"Content-Type": []string{"text/event-stream"}}
if !looksLikeExistingEndpoint(h, nil) {
t.Fatal("expected true with sse content type")
}
})
t.Run("empty body and no content type", func(t *testing.T) {
if looksLikeExistingEndpoint(http.Header{}, nil) {
t.Fatal("expected false")
}
})
t.Run("json-like body", func(t *testing.T) {
if !looksLikeExistingEndpoint(http.Header{}, []byte(`{"error":"not found"}`)) {
t.Fatal("expected true for json body")
}
})
t.Run("array body", func(t *testing.T) {
if !looksLikeExistingEndpoint(http.Header{}, []byte(`[]`)) {
t.Fatal("expected true for array body")
}
})
t.Run("html body", func(t *testing.T) {
if looksLikeExistingEndpoint(http.Header{}, []byte(`<html>`)) {
t.Fatal("expected false for html body")
}
})
}
// Tests for NamedResource type used by the filter functions.
// Defined locally since it's in the same package.
func TestNewClientWithNilOption(t *testing.T) {
client, err := NewClient("http://localhost:8080", nil)
if err != nil {
t.Fatal(err)
}
if client == nil {
t.Fatal("client is nil")
}
}
func TestNewHTTPError(t *testing.T) {
e := newHTTPError("GET", "/v1/models", 200, []byte(`{"ok":true}`))
if e.Method != "GET" || e.Path != "/v1/models" || e.StatusCode != 200 || e.Body != `{"ok":true}` {
t.Fatalf("unexpected http error: %+v", e)
}
}
func TestPerformWithMockServer(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/v1/admin/system/version":
w.Write([]byte(`{"data":{"version":"v1.2.3"}}`))
case "/api/v1/admin/groups":
w.Write([]byte(`{"data":{"id":"g1","name":"test-group"}}`))
case "/api/v1/admin/channels":
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"error":"panic"}`))
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
client, err := NewClient(srv.URL, WithAPIKey("test-key"))
if err != nil {
t.Fatal(err)
}
t.Run("GetHostVersion", func(t *testing.T) {
ver, err := client.GetHostVersion(context.Background())
if err != nil {
t.Fatal(err)
}
if ver != "v1.2.3" {
t.Fatalf("version = %q, want %q", ver, "v1.2.3")
}
})
t.Run("postJSON success", func(t *testing.T) {
var ref GroupRef
if err := client.postJSON(context.Background(), "/api/v1/admin/groups", CreateGroupRequest{Name: "test"}, &ref); err != nil {
t.Fatal(err)
}
if ref.ID != "g1" || ref.Name != "test-group" {
t.Fatalf("got %+v, want {ID:g1 Name:test-group}", ref)
}
})
t.Run("postJSON error status", func(t *testing.T) {
var ref GroupRef
err := client.postJSON(context.Background(), "/api/v1/admin/channels", nil, &ref)
if err == nil {
t.Fatal("expected error")
}
var httpErr *HTTPError
if !errors.As(err, &httpErr) {
t.Fatalf("expected HTTPError, got %T: %v", err, err)
}
if httpErr.StatusCode != 500 {
t.Fatalf("status code = %d, want 500", httpErr.StatusCode)
}
})
t.Run("getJSON success", func(t *testing.T) {
var ref GroupRef
if err := client.getJSON(context.Background(), "/api/v1/admin/groups", &ref); err != nil {
t.Fatal(err)
}
})
t.Run("getJSON error status", func(t *testing.T) {
var ref GroupRef
err := client.getJSON(context.Background(), "/bad/path", &ref)
if err == nil {
t.Fatal("expected error")
}
})
}
func TestCreateGroupWithMock(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"data":{"id":"g1","name":"demo"}}`))
}))
defer srv.Close()
client, err := NewClient(srv.URL, WithAPIKey("k"))
if err != nil {
t.Fatal(err)
}
ref, err := client.CreateGroup(context.Background(), CreateGroupRequest{Name: "demo", RateMultiplier: 1.0})
if err != nil {
t.Fatal(err)
}
if ref.ID != "g1" || ref.Name != "demo" {
t.Fatalf("got %+v", ref)
}
}
func TestCreateChannelWithMock(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"data":{"id":"c1","name":"ch"}}`))
}))
defer srv.Close()
client, _ := NewClient(srv.URL, WithAPIKey("k"))
_, err := client.CreateChannel(context.Background(), CreateChannelRequest{Name: "ch"})
if err != nil {
t.Fatal(err)
}
}
func TestCreatePlanWithMock(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"data":{"id":"p1","name":"plan"}}`))
}))
defer srv.Close()
client, _ := NewClient(srv.URL, WithAPIKey("k"))
_, err := client.CreatePlan(context.Background(), CreatePlanRequest{Name: "plan"})
if err != nil {
t.Fatal(err)
}
}
func TestDeleteWithMock(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
defer srv.Close()
client, _ := NewClient(srv.URL, WithAPIKey("k"))
t.Run("DeleteGroup", func(t *testing.T) {
if err := client.DeleteGroup(context.Background(), "g1"); err != nil {
t.Fatal(err)
}
})
t.Run("DeleteChannel", func(t *testing.T) {
if err := client.DeleteChannel(context.Background(), "c1"); err != nil {
t.Fatal(err)
}
})
t.Run("DeletePlan", func(t *testing.T) {
if err := client.DeletePlan(context.Background(), "p1"); err != nil {
t.Fatal(err)
}
})
t.Run("DeleteAccount", func(t *testing.T) {
if err := client.DeleteAccount(context.Background(), "a1"); err != nil {
t.Fatal(err)
}
})
}
func TestAssignSubscriptionWithMock(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"data":{"id":"s1"}}`))
}))
defer srv.Close()
client, _ := NewClient(srv.URL, WithAPIKey("k"))
ref, err := client.AssignSubscription(context.Background(), AssignSubscriptionRequest{UserID: "u1"})
if err != nil {
t.Fatal(err)
}
if ref.ID != "s1" {
t.Fatalf("id = %q", ref.ID)
}
}
func TestCheckGatewayAccessWithMock(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"data":[{"id":"gpt-4"},{"id":"claude-3"}]}`))
}))
defer srv.Close()
client, _ := NewClient(srv.URL, WithAPIKey("k"))
result, err := client.CheckGatewayAccess(context.Background(), GatewayAccessCheckRequest{APIKey: "gk", ExpectedModel: "gpt-4"})
if err != nil {
t.Fatal(err)
}
if !result.OK {
t.Fatal("expected OK=true")
}
if !result.HasExpectedModel {
t.Fatal("expected HasExpectedModel=true")
}
}
func TestBatchCreateAccountsWithMock(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"data":[{"id":"a1","name":"acct1"}]}`))
}))
defer srv.Close()
client, _ := NewClient(srv.URL, WithAPIKey("k"))
refs, err := client.BatchCreateAccounts(context.Background(), BatchCreateAccountsRequest{
Accounts: []CreateAccountRequest{{Name: "acct1"}},
})
if err != nil {
t.Fatal(err)
}
if len(refs) != 1 || refs[0].ID != "a1" {
t.Fatalf("got %+v", refs)
}
}
func TestProbeCapabilitiesWithMock(t *testing.T) {
callCount := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
w.WriteHeader(http.StatusCreated)
}))
defer srv.Close()
client, _ := NewClient(srv.URL, WithAPIKey("k"))
caps, err := client.ProbeCapabilities(context.Background())
if err != nil {
t.Fatal(err)
}
if !caps.Groups || !caps.Channels || !caps.Plans || !caps.Accounts || !caps.AccountTest || !caps.AccountModels || !caps.Subscriptions {
t.Fatalf("all capabilities should be true, got %+v", caps)
}
}
func TestListManagedResourcesWithMock(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"data":{"items":[
{"id":"r1","name":"resource-1"}
]}}`))
}))
defer srv.Close()
client, _ := NewClient(srv.URL, WithAPIKey("k"))
snapshot, err := client.ListManagedResources(context.Background(), ListManagedResourcesRequest{})
if err != nil {
t.Fatal(err)
}
if len(snapshot.Groups) != 1 {
t.Fatalf("expected 1 group, got %d", len(snapshot.Groups))
}
}
func TestTestAccountWithMock(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("data: {\"status\":\"passed\",\"ok\":true}\n"))
}))
defer srv.Close()
client, _ := NewClient(srv.URL, WithAPIKey("k"))
result, err := client.TestAccount(context.Background(), "a1")
if err != nil {
t.Fatal(err)
}
if !result.OK {
t.Fatal("expected OK=true")
}
}
func TestGetAccountModelsWithMock(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"data":[{"id":"m1","display_name":"M1","type":"chat"}]}`))
}))
defer srv.Close()
client, _ := NewClient(srv.URL, WithAPIKey("k"))
models, err := client.GetAccountModels(context.Background(), "a1")
if err != nil {
t.Fatal(err)
}
if len(models) != 1 || models[0].ID != "m1" {
t.Fatalf("got %+v", models)
}
}

266
internal/pack/extra_test.go Normal file
View File

@@ -0,0 +1,266 @@
package pack
import (
"archive/zip"
"os"
"path/filepath"
"strings"
"testing"
)
func TestValidateManifestRequiredFields(t *testing.T) {
tests := []struct {
name string
manifest Manifest
wantErr string
}{
{
name: "empty pack id",
manifest: Manifest{
Version: "1.0.0",
TargetHost: "sub2api",
ProvidersDir: "providers",
ChecksumFile: "checksums.txt",
},
wantErr: "pack_id is required",
},
{
name: "empty version",
manifest: Manifest{
PackID: "openai-cn-pack",
TargetHost: "sub2api",
ProvidersDir: "providers",
ChecksumFile: "checksums.txt",
},
wantErr: "version is required",
},
{
name: "empty target host",
manifest: Manifest{
PackID: "openai-cn-pack",
Version: "1.0.0",
ProvidersDir: "providers",
ChecksumFile: "checksums.txt",
},
wantErr: "target_host is required",
},
{
name: "empty providers dir",
manifest: Manifest{
PackID: "openai-cn-pack",
Version: "1.0.0",
TargetHost: "sub2api",
ChecksumFile: "checksums.txt",
},
wantErr: "providers_dir is required",
},
{
name: "empty checksum file",
manifest: Manifest{
PackID: "openai-cn-pack",
Version: "1.0.0",
TargetHost: "sub2api",
ProvidersDir: "providers",
},
wantErr: "checksum_file is required",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateManifest(tt.manifest)
if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
t.Fatalf("validateManifest() error = %v, want substring %q", err, tt.wantErr)
}
})
}
}
func TestValidateProvidersRejectsInvalidProviderFields(t *testing.T) {
tests := []struct {
name string
providers []ProviderManifest
wantErr string
}{
{
name: "empty provider id",
providers: []ProviderManifest{{
DisplayName: "DeepSeek",
BaseURL: "https://api.deepseek.com",
Platform: "openai",
AccountType: "api",
DefaultModels: []string{"deepseek-chat"},
SmokeTestModel: "deepseek-chat",
GroupTemplate: GroupTemplate{Name: "g"},
ChannelTemplate: ChannelTemplate{
Name: "c",
ModelMapping: map[string]string{"deepseek-chat": "deepseek-chat"},
},
PlanTemplate: PlanTemplate{Name: "p", ValidityDays: 30},
}},
wantErr: "provider_id is required",
},
{
name: "empty base url",
providers: []ProviderManifest{{
ProviderID: "deepseek",
DisplayName: "DeepSeek",
BaseURL: "",
Platform: "openai",
AccountType: "api",
DefaultModels: []string{"deepseek-chat"},
SmokeTestModel: "deepseek-chat",
GroupTemplate: GroupTemplate{Name: "g"},
ChannelTemplate: ChannelTemplate{
Name: "c",
ModelMapping: map[string]string{"deepseek-chat": "deepseek-chat"},
},
PlanTemplate: PlanTemplate{Name: "p", ValidityDays: 30},
}},
wantErr: "base_url must use https",
},
{
name: "missing display name",
providers: []ProviderManifest{{
ProviderID: "deepseek",
DisplayName: "",
BaseURL: "https://api.deepseek.com",
Platform: "openai",
AccountType: "api",
DefaultModels: []string{"deepseek-chat"},
SmokeTestModel: "deepseek-chat",
GroupTemplate: GroupTemplate{Name: "g"},
ChannelTemplate: ChannelTemplate{
Name: "c",
ModelMapping: map[string]string{"deepseek-chat": "deepseek-chat"},
},
PlanTemplate: PlanTemplate{Name: "p", ValidityDays: 30},
}},
wantErr: "display_name is required",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateProviders(tt.providers)
if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
t.Fatalf("validateProviders() error = %v, want substring %q", err, tt.wantErr)
}
})
}
}
func TestExtractZipToTempRejectsEmptyArchive(t *testing.T) {
archivePath := filepath.Join(t.TempDir(), "empty.zip")
file, err := os.Create(archivePath)
if err != nil {
t.Fatalf("os.Create() error = %v", err)
}
writer := zip.NewWriter(file)
if err := writer.Close(); err != nil {
t.Fatalf("writer.Close() error = %v", err)
}
if err := file.Close(); err != nil {
t.Fatalf("file.Close() error = %v", err)
}
_, cleanup, err := extractZipToTemp(archivePath)
if cleanup != nil {
cleanup()
}
if err == nil || !strings.Contains(err.Error(), "pack archive is empty") {
t.Fatalf("extractZipToTemp() error = %v, want empty archive error", err)
}
}
func TestExtractZipToTempRejectsPathTraversal(t *testing.T) {
archivePath := filepath.Join(t.TempDir(), "traversal.zip")
file, err := os.Create(archivePath)
if err != nil {
t.Fatalf("os.Create() error = %v", err)
}
writer := zip.NewWriter(file)
entry, err := writer.Create("../../../evil.txt")
if err != nil {
t.Fatalf("writer.Create() error = %v", err)
}
if _, err := entry.Write([]byte("evil")); err != nil {
t.Fatalf("entry.Write() error = %v", err)
}
if err := writer.Close(); err != nil {
t.Fatalf("writer.Close() error = %v", err)
}
if err := file.Close(); err != nil {
t.Fatalf("file.Close() error = %v", err)
}
_, cleanup, err := extractZipToTemp(archivePath)
if cleanup != nil {
cleanup()
}
if err == nil || !strings.Contains(err.Error(), "invalid path") {
t.Fatalf("extractZipToTemp() error = %v, want invalid path error", err)
}
}
func TestMatchesMaxConstraintCases(t *testing.T) {
tests := []struct {
name string
hostVersion string
maxVersion string
want bool
}{
{name: "exact version match", hostVersion: "1.2.3", maxVersion: "1.2.3", want: true},
{name: "wildcard x accepts same minor", hostVersion: "0.2.9", maxVersion: "0.2.x", want: true},
{name: "non matching version", hostVersion: "1.2.4", maxVersion: "1.2.3", want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := matchesMaxConstraint(tt.hostVersion, tt.maxVersion)
if err != nil {
t.Fatalf("matchesMaxConstraint() error = %v", err)
}
if got != tt.want {
t.Fatalf("matchesMaxConstraint() = %v, want %v", got, tt.want)
}
})
}
}
func TestMatchesMaxConstraintRejectsWildcardStar(t *testing.T) {
_, err := matchesMaxConstraint("1.2.3", "1.2.*")
if err == nil || !strings.Contains(err.Error(), `parse version segment "*"`) {
t.Fatalf("matchesMaxConstraint() error = %v, want parse failure for wildcard star", err)
}
}
func TestLoadPathRejectsEmptyAndMissingPath(t *testing.T) {
tests := []struct {
name string
path string
wantErr string
}{
{name: "empty path", path: " ", wantErr: "pack path is required"},
{name: "missing path", path: filepath.Join(t.TempDir(), "missing-pack"), wantErr: "stat pack path"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := LoadPath(tt.path)
if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
t.Fatalf("LoadPath() error = %v, want substring %q", err, tt.wantErr)
}
})
}
}
func TestLoadArchiveRejectsNonZipFile(t *testing.T) {
filePath := filepath.Join(t.TempDir(), "not-a-zip.zip")
mustWrite(t, filePath, "plain text, not a zip archive")
_, err := LoadArchive(filePath)
if err == nil || !strings.Contains(err.Error(), "open pack archive") {
t.Fatalf("LoadArchive() error = %v, want open archive error", err)
}
}

249
internal/pack/loader.go Normal file
View File

@@ -0,0 +1,249 @@
package pack
import (
"bufio"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sort"
"strings"
)
type Manifest struct {
PackID string `json:"pack_id"`
Version string `json:"version"`
Vendor string `json:"vendor"`
TargetHost string `json:"target_host"`
MinHostVersion string `json:"min_host_version"`
MaxHostVersion string `json:"max_host_version"`
ProvidersDir string `json:"providers_dir"`
ChecksumFile string `json:"checksum_file"`
}
type ProviderManifest struct {
ProviderID string `json:"provider_id"`
DisplayName string `json:"display_name"`
BaseURL string `json:"base_url"`
Platform string `json:"platform"`
AccountType string `json:"account_type"`
DefaultModels []string `json:"default_models"`
SmokeTestModel string `json:"smoke_test_model"`
GroupTemplate GroupTemplate `json:"group_template"`
ChannelTemplate ChannelTemplate `json:"channel_template"`
PlanTemplate PlanTemplate `json:"plan_template"`
Import ImportOptions `json:"import"`
}
type GroupTemplate struct {
Name string `json:"name"`
RateMultiplier float64 `json:"rate_multiplier"`
}
type ChannelTemplate struct {
Name string `json:"name"`
ModelMapping map[string]string `json:"model_mapping"`
}
type PlanTemplate struct {
Name string `json:"name"`
Price float64 `json:"price"`
ValidityDays int `json:"validity_days"`
ValidityUnit string `json:"validity_unit"`
}
type ImportOptions struct {
SupportsMultiKey bool `json:"supports_multi_key"`
SupportsStrict bool `json:"supports_strict"`
SupportsPartial bool `json:"supports_partial"`
}
type LoadedPack struct {
Dir string
Manifest Manifest
Providers []ProviderManifest
Checksum string
}
func LoadDir(dir string) (LoadedPack, error) {
root := strings.TrimSpace(dir)
if root == "" {
return LoadedPack{}, fmt.Errorf("pack dir is required")
}
manifestPath := filepath.Join(root, "pack.json")
manifestBytes, err := os.ReadFile(manifestPath)
if err != nil {
return LoadedPack{}, fmt.Errorf("read pack.json: %w", err)
}
var manifest Manifest
if err := json.Unmarshal(manifestBytes, &manifest); err != nil {
return LoadedPack{}, fmt.Errorf("decode pack.json: %w", err)
}
if err := validateManifest(manifest); err != nil {
return LoadedPack{}, err
}
if err := validateChecksums(root, manifest.ChecksumFile); err != nil {
return LoadedPack{}, err
}
providers, err := loadProviders(root, manifest.ProvidersDir)
if err != nil {
return LoadedPack{}, err
}
if len(providers) == 0 {
return LoadedPack{}, fmt.Errorf("providers dir %q does not contain provider manifests", manifest.ProvidersDir)
}
if err := validateProviders(providers); err != nil {
return LoadedPack{}, err
}
checksum, err := computeAggregateChecksum(root, manifest.ChecksumFile)
if err != nil {
return LoadedPack{}, err
}
return LoadedPack{Dir: root, Manifest: manifest, Providers: providers, Checksum: checksum}, nil
}
func validateManifest(manifest Manifest) error {
switch {
case strings.TrimSpace(manifest.PackID) == "":
return fmt.Errorf("pack.json: pack_id is required")
case strings.TrimSpace(manifest.Version) == "":
return fmt.Errorf("pack.json: version is required")
case strings.TrimSpace(manifest.TargetHost) == "":
return fmt.Errorf("pack.json: target_host is required")
case strings.TrimSpace(manifest.ProvidersDir) == "":
return fmt.Errorf("pack.json: providers_dir is required")
case strings.TrimSpace(manifest.ChecksumFile) == "":
return fmt.Errorf("pack.json: checksum_file is required")
}
return nil
}
func loadProviders(root string, providersDir string) ([]ProviderManifest, error) {
dir := filepath.Join(root, providersDir)
entries, err := os.ReadDir(dir)
if err != nil {
return nil, fmt.Errorf("read providers dir %q: %w", providersDir, err)
}
providers := make([]ProviderManifest, 0, len(entries))
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") {
continue
}
path := filepath.Join(dir, entry.Name())
body, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("read provider %q: %w", entry.Name(), err)
}
var provider ProviderManifest
if err := json.Unmarshal(body, &provider); err != nil {
return nil, fmt.Errorf("decode provider %q: %w", entry.Name(), err)
}
providers = append(providers, provider)
}
sort.Slice(providers, func(i, j int) bool { return providers[i].ProviderID < providers[j].ProviderID })
return providers, nil
}
func validateProviders(providers []ProviderManifest) error {
seen := make(map[string]struct{}, len(providers))
for _, provider := range providers {
providerID := strings.TrimSpace(provider.ProviderID)
switch {
case providerID == "":
return fmt.Errorf("provider manifest: provider_id is required")
case strings.TrimSpace(provider.DisplayName) == "":
return fmt.Errorf("provider %q: display_name is required", providerID)
case !strings.HasPrefix(strings.TrimSpace(provider.BaseURL), "https://"):
return fmt.Errorf("provider %q: base_url must use https", providerID)
case strings.TrimSpace(provider.Platform) == "":
return fmt.Errorf("provider %q: platform is required", providerID)
case strings.TrimSpace(provider.AccountType) == "":
return fmt.Errorf("provider %q: account_type is required", providerID)
case len(provider.DefaultModels) == 0:
return fmt.Errorf("provider %q: default_models must not be empty", providerID)
case strings.TrimSpace(provider.SmokeTestModel) == "":
return fmt.Errorf("provider %q: smoke_test_model is required", providerID)
case !contains(provider.DefaultModels, provider.SmokeTestModel):
return fmt.Errorf("provider %q: smoke_test_model must be present in default_models", providerID)
case strings.TrimSpace(provider.GroupTemplate.Name) == "":
return fmt.Errorf("provider %q: group_template.name is required", providerID)
case strings.TrimSpace(provider.ChannelTemplate.Name) == "":
return fmt.Errorf("provider %q: channel_template.name is required", providerID)
case len(provider.ChannelTemplate.ModelMapping) == 0:
return fmt.Errorf("provider %q: channel_template.model_mapping must not be empty", providerID)
case strings.TrimSpace(provider.PlanTemplate.Name) == "":
return fmt.Errorf("provider %q: plan_template.name is required", providerID)
case provider.PlanTemplate.ValidityDays <= 0:
return fmt.Errorf("provider %q: plan_template.validity_days must be positive", providerID)
}
if _, ok := seen[providerID]; ok {
return fmt.Errorf("duplicate provider_id %q", providerID)
}
seen[providerID] = struct{}{}
}
return nil
}
func validateChecksums(root string, checksumFile string) error {
path := filepath.Join(root, checksumFile)
file, err := os.Open(path)
if err != nil {
return fmt.Errorf("read checksum file %q: %w", checksumFile, err)
}
defer file.Close()
scanner := bufio.NewScanner(file)
lineNumber := 0
for scanner.Scan() {
lineNumber++
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
parts := strings.Fields(line)
if len(parts) != 2 {
return fmt.Errorf("checksum file %q line %d: invalid format", checksumFile, lineNumber)
}
relativePath := parts[1]
body, err := os.ReadFile(filepath.Join(root, relativePath))
if err != nil {
return fmt.Errorf("checksum file %q line %d: read %q: %w", checksumFile, lineNumber, relativePath, err)
}
sum := sha256.Sum256(body)
actual := hex.EncodeToString(sum[:])
if !strings.EqualFold(parts[0], actual) {
return fmt.Errorf("checksum mismatch for %s", relativePath)
}
}
if err := scanner.Err(); err != nil {
return fmt.Errorf("scan checksum file %q: %w", checksumFile, err)
}
return nil
}
func computeAggregateChecksum(root string, checksumFile string) (string, error) {
body, err := os.ReadFile(filepath.Join(root, checksumFile))
if err != nil {
return "", fmt.Errorf("read checksum file %q: %w", checksumFile, err)
}
sum := sha256.Sum256(body)
return hex.EncodeToString(sum[:]), nil
}
func contains(items []string, target string) bool {
for _, item := range items {
if strings.TrimSpace(item) == strings.TrimSpace(target) {
return true
}
}
return false
}

View File

@@ -0,0 +1,108 @@
package pack
import (
"crypto/sha256"
"encoding/hex"
"os"
"path/filepath"
"strings"
"testing"
)
func TestLoadDirParsesAndValidatesPack(t *testing.T) {
packDir := createPackFixture(t, map[string]string{
"pack.json": `{
"pack_id": "openai-cn-pack",
"version": "1.0.0",
"vendor": "YourTeam",
"target_host": "sub2api",
"min_host_version": "0.1.126",
"max_host_version": "0.2.x",
"providers_dir": "providers",
"checksum_file": "checksums.txt"
}`,
"providers/deepseek.json": `{
"provider_id": "deepseek",
"display_name": "DeepSeek OpenAI Compatible",
"base_url": "https://api.deepseek.com",
"platform": "openai",
"account_type": "api",
"default_models": ["deepseek-chat", "deepseek-reasoner"],
"smoke_test_model": "deepseek-chat",
"group_template": {"name": "DeepSeek 默认分组", "rate_multiplier": 1.0},
"channel_template": {"name": "DeepSeek 默认渠道", "model_mapping": {"deepseek-chat": "deepseek-chat"}},
"plan_template": {"name": "DeepSeek 默认套餐", "price": 19.9, "validity_days": 30, "validity_unit": "day"},
"import": {"supports_multi_key": true, "supports_strict": true, "supports_partial": true}
}`,
})
loaded, err := LoadDir(packDir)
if err != nil {
t.Fatalf("LoadDir() error = %v", err)
}
if loaded.Manifest.PackID != "openai-cn-pack" {
t.Fatalf("PackID = %q, want %q", loaded.Manifest.PackID, "openai-cn-pack")
}
if len(loaded.Providers) != 1 {
t.Fatalf("len(Providers) = %d, want 1", len(loaded.Providers))
}
if loaded.Providers[0].ProviderID != "deepseek" {
t.Fatalf("ProviderID = %q, want %q", loaded.Providers[0].ProviderID, "deepseek")
}
}
func TestLoadDirRejectsChecksumMismatch(t *testing.T) {
packDir := t.TempDir()
mustWrite(t, filepath.Join(packDir, "pack.json"), `{"pack_id":"openai-cn-pack","version":"1.0.0","vendor":"x","target_host":"sub2api","min_host_version":"0.1.126","max_host_version":"0.2.x","providers_dir":"providers","checksum_file":"checksums.txt"}`)
mustWrite(t, filepath.Join(packDir, "providers", "deepseek.json"), `{"provider_id":"deepseek","display_name":"DeepSeek","base_url":"https://api.deepseek.com","platform":"openai","account_type":"api","default_models":["deepseek-chat"],"smoke_test_model":"deepseek-chat","group_template":{"name":"g","rate_multiplier":1},"channel_template":{"name":"c","model_mapping":{"deepseek-chat":"deepseek-chat"}},"plan_template":{"name":"p","price":1,"validity_days":30,"validity_unit":"day"},"import":{"supports_multi_key":true,"supports_strict":true,"supports_partial":true}}`)
mustWrite(t, filepath.Join(packDir, "checksums.txt"), "deadbeef pack.json\ndeadbeef providers/deepseek.json\n")
_, err := LoadDir(packDir)
if err == nil {
t.Fatal("LoadDir() error = nil, want checksum mismatch")
}
if !strings.Contains(err.Error(), "checksum mismatch") {
t.Fatalf("LoadDir() error = %v, want checksum mismatch", err)
}
}
func TestLoadDirRejectsInvalidProviderSchema(t *testing.T) {
packDir := createPackFixture(t, map[string]string{
"pack.json": `{"pack_id":"openai-cn-pack","version":"1.0.0","vendor":"x","target_host":"sub2api","min_host_version":"0.1.126","max_host_version":"0.2.x","providers_dir":"providers","checksum_file":"checksums.txt"}`,
"providers/deepseek.json": `{"provider_id":"deepseek","display_name":"DeepSeek","base_url":"http://insecure.example.com","platform":"openai","account_type":"api","default_models":["deepseek-chat"],"smoke_test_model":"missing-model","group_template":{"name":"g","rate_multiplier":1},"channel_template":{"name":"c","model_mapping":{"deepseek-chat":"deepseek-chat"}},"plan_template":{"name":"p","price":1,"validity_days":30,"validity_unit":"day"},"import":{"supports_multi_key":true,"supports_strict":true,"supports_partial":true}}`,
})
_, err := LoadDir(packDir)
if err == nil {
t.Fatal("LoadDir() error = nil, want schema validation failure")
}
if !strings.Contains(err.Error(), "https") && !strings.Contains(err.Error(), "smoke_test_model") {
t.Fatalf("LoadDir() error = %v, want schema validation detail", err)
}
}
func createPackFixture(t *testing.T, files map[string]string) string {
t.Helper()
packDir := t.TempDir()
var lines []string
for relativePath, content := range files {
absolutePath := filepath.Join(packDir, relativePath)
mustWrite(t, absolutePath, content)
sum := sha256.Sum256([]byte(content))
lines = append(lines, hex.EncodeToString(sum[:])+" "+relativePath)
}
mustWrite(t, filepath.Join(packDir, "checksums.txt"), strings.Join(lines, "\n")+"\n")
return packDir
}
func mustWrite(t *testing.T, path string, content string) {
t.Helper()
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
t.Fatalf("MkdirAll(%q) error = %v", path, err)
}
if err := os.WriteFile(path, []byte(content), 0o644); err != nil {
t.Fatalf("WriteFile(%q) error = %v", path, err)
}
}

View File

@@ -0,0 +1,171 @@
package pack
import (
"archive/zip"
"fmt"
"io"
"os"
"path/filepath"
"strings"
)
const (
maxArchiveEntries = 256
maxArchiveFileSize = 5 << 20
maxArchiveTotalSize = 20 << 20
)
func LoadPath(path string) (LoadedPack, error) {
trimmed := strings.TrimSpace(path)
if trimmed == "" {
return LoadedPack{}, fmt.Errorf("pack path is required")
}
info, err := os.Stat(trimmed)
if err != nil {
return LoadedPack{}, fmt.Errorf("stat pack path: %w", err)
}
if info.IsDir() {
return LoadDir(trimmed)
}
if strings.EqualFold(filepath.Ext(info.Name()), ".zip") {
return LoadArchive(trimmed)
}
return LoadedPack{}, fmt.Errorf("pack path %q must be a directory or .zip archive", trimmed)
}
func LoadArchive(path string) (LoadedPack, error) {
root, cleanup, err := extractZipToTemp(path)
if err != nil {
return LoadedPack{}, err
}
defer cleanup()
loaded, err := LoadDir(root)
if err != nil {
return LoadedPack{}, err
}
loaded.Dir = strings.TrimSpace(path)
return loaded, nil
}
func extractZipToTemp(path string) (string, func(), error) {
reader, err := zip.OpenReader(strings.TrimSpace(path))
if err != nil {
return "", nil, fmt.Errorf("open pack archive: %w", err)
}
defer reader.Close()
if len(reader.File) == 0 {
return "", nil, fmt.Errorf("pack archive is empty")
}
if len(reader.File) > maxArchiveEntries {
return "", nil, fmt.Errorf("pack archive has too many entries: %d", len(reader.File))
}
tempDir, err := os.MkdirTemp("", "relay-pack-*")
if err != nil {
return "", nil, fmt.Errorf("create temp dir for pack archive: %w", err)
}
cleanup := func() { _ = os.RemoveAll(tempDir) }
var totalSize uint64
for _, file := range reader.File {
cleanName := filepath.Clean(file.Name)
if cleanName == "." || cleanName == "" {
continue
}
if filepath.IsAbs(cleanName) || cleanName == ".." || strings.HasPrefix(cleanName, ".."+string(filepath.Separator)) {
cleanup()
return "", nil, fmt.Errorf("pack archive contains invalid path %q", file.Name)
}
if file.FileInfo().Mode()&os.ModeSymlink != 0 {
cleanup()
return "", nil, fmt.Errorf("pack archive contains unsupported symlink entry %q", file.Name)
}
if file.UncompressedSize64 > maxArchiveFileSize {
cleanup()
return "", nil, fmt.Errorf("pack archive entry %q exceeds size limit", file.Name)
}
totalSize += file.UncompressedSize64
if totalSize > maxArchiveTotalSize {
cleanup()
return "", nil, fmt.Errorf("pack archive exceeds total size limit")
}
targetPath := filepath.Join(tempDir, cleanName)
relativeTarget, err := filepath.Rel(tempDir, targetPath)
if err != nil || relativeTarget == ".." || strings.HasPrefix(relativeTarget, ".."+string(filepath.Separator)) {
cleanup()
return "", nil, fmt.Errorf("pack archive entry %q escapes extraction root", file.Name)
}
if file.FileInfo().IsDir() {
if err := os.MkdirAll(targetPath, 0o755); err != nil {
cleanup()
return "", nil, fmt.Errorf("create archive dir %q: %w", file.Name, err)
}
continue
}
if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil {
cleanup()
return "", nil, fmt.Errorf("create archive parent dir %q: %w", file.Name, err)
}
src, err := file.Open()
if err != nil {
cleanup()
return "", nil, fmt.Errorf("open archive entry %q: %w", file.Name, err)
}
dst, err := os.OpenFile(targetPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644)
if err != nil {
src.Close()
cleanup()
return "", nil, fmt.Errorf("create archive file %q: %w", file.Name, err)
}
_, copyErr := io.Copy(dst, src)
closeErr := dst.Close()
srcErr := src.Close()
if copyErr != nil {
cleanup()
return "", nil, fmt.Errorf("extract archive entry %q: %w", file.Name, copyErr)
}
if closeErr != nil {
cleanup()
return "", nil, fmt.Errorf("close archive file %q: %w", file.Name, closeErr)
}
if srcErr != nil {
cleanup()
return "", nil, fmt.Errorf("close archive entry %q: %w", file.Name, srcErr)
}
}
root, err := resolvePackRoot(tempDir)
if err != nil {
cleanup()
return "", nil, err
}
return root, cleanup, nil
}
func resolvePackRoot(extractDir string) (string, error) {
manifestPath := filepath.Join(extractDir, "pack.json")
if _, err := os.Stat(manifestPath); err == nil {
return extractDir, nil
}
entries, err := os.ReadDir(extractDir)
if err != nil {
return "", fmt.Errorf("read extracted archive root: %w", err)
}
if len(entries) != 1 || !entries[0].IsDir() {
return "", fmt.Errorf("pack archive must contain pack.json at root or a single top-level directory")
}
root := filepath.Join(extractDir, entries[0].Name())
if _, err := os.Stat(filepath.Join(root, "pack.json")); err != nil {
return "", fmt.Errorf("pack archive root does not contain pack.json")
}
return root, nil
}

View File

@@ -0,0 +1,70 @@
package pack
import (
"archive/zip"
"os"
"path/filepath"
"testing"
)
func TestLoadPathSupportsDirectory(t *testing.T) {
loaded, err := LoadPath(filepath.Join("..", "..", "packs", "openai-cn-pack"))
if err != nil {
t.Fatalf("LoadPath(dir) error = %v", err)
}
if loaded.Manifest.PackID != "openai-cn-pack" {
t.Fatalf("PackID = %q, want %q", loaded.Manifest.PackID, "openai-cn-pack")
}
}
func TestLoadPathSupportsZipArchive(t *testing.T) {
tempDir := t.TempDir()
archivePath := filepath.Join(tempDir, "openai-cn-pack.zip")
writePackArchive(t, archivePath)
loaded, err := LoadPath(archivePath)
if err != nil {
t.Fatalf("LoadPath(zip) error = %v", err)
}
if loaded.Manifest.PackID != "openai-cn-pack" {
t.Fatalf("PackID = %q, want %q", loaded.Manifest.PackID, "openai-cn-pack")
}
if len(loaded.Providers) == 0 {
t.Fatal("Providers = 0, want parsed providers from archive")
}
}
func writePackArchive(t *testing.T, archivePath string) {
t.Helper()
file, err := os.Create(archivePath)
if err != nil {
t.Fatalf("os.Create() error = %v", err)
}
defer file.Close()
writer := zip.NewWriter(file)
defer writer.Close()
sourceRoot := filepath.Join("..", "..", "packs", "openai-cn-pack")
files := []string{
"pack.json",
"checksums.txt",
filepath.Join("providers", "deepseek.json"),
}
for _, relativePath := range files {
body, err := os.ReadFile(filepath.Join(sourceRoot, relativePath))
if err != nil {
t.Fatalf("os.ReadFile(%q) error = %v", relativePath, err)
}
entry, err := writer.Create(filepath.ToSlash(filepath.Join("openai-cn-pack", relativePath)))
if err != nil {
t.Fatalf("Create(%q) error = %v", relativePath, err)
}
if _, err := entry.Write(body); err != nil {
t.Fatalf("Write(%q) error = %v", relativePath, err)
}
}
if err := writer.Close(); err != nil {
t.Fatalf("Close archive writer: %v", err)
}
}

134
internal/pack/version.go Normal file
View File

@@ -0,0 +1,134 @@
package pack
import (
"fmt"
"strconv"
"strings"
)
func CheckHostCompatibility(manifest Manifest, hostVersion string) error {
targetHost := strings.TrimSpace(manifest.TargetHost)
if targetHost == "" {
return fmt.Errorf("pack manifest target_host is required")
}
if targetHost != "sub2api" {
return fmt.Errorf("pack target_host %q is not supported", targetHost)
}
normalizedHost, err := parseVersion(strings.TrimSpace(hostVersion))
if err != nil {
return fmt.Errorf("parse host version %q: %w", hostVersion, err)
}
minVersion := strings.TrimSpace(manifest.MinHostVersion)
if minVersion != "" {
cmp, err := compareVersions(normalizedHost.raw, minVersion)
if err != nil {
return fmt.Errorf("compare min_host_version: %w", err)
}
if cmp < 0 {
return fmt.Errorf("host version %q is below min_host_version %q", hostVersion, minVersion)
}
}
maxVersion := strings.TrimSpace(manifest.MaxHostVersion)
if maxVersion != "" {
ok, err := matchesMaxConstraint(normalizedHost.raw, maxVersion)
if err != nil {
return fmt.Errorf("compare max_host_version: %w", err)
}
if !ok {
return fmt.Errorf("host version %q is above max_host_version %q", hostVersion, maxVersion)
}
}
return nil
}
type parsedVersion struct {
raw string
parts [3]int
}
func compareVersions(a, b string) (int, error) {
left, err := parseVersion(a)
if err != nil {
return 0, err
}
right, err := parseVersion(b)
if err != nil {
return 0, err
}
for i := 0; i < len(left.parts); i++ {
if left.parts[i] < right.parts[i] {
return -1, nil
}
if left.parts[i] > right.parts[i] {
return 1, nil
}
}
return 0, nil
}
func matchesMaxConstraint(hostVersion, maxVersion string) (bool, error) {
normalizedMax := normalizeVersion(maxVersion)
if strings.HasSuffix(normalizedMax, ".x") {
prefix := strings.TrimSuffix(normalizedMax, ".x")
parts := strings.Split(prefix, ".")
if len(parts) != 2 {
return false, fmt.Errorf("wildcard max version %q must be in N.N.x format", maxVersion)
}
host, err := parseVersion(hostVersion)
if err != nil {
return false, err
}
major, err := strconv.Atoi(parts[0])
if err != nil {
return false, fmt.Errorf("parse major version %q: %w", parts[0], err)
}
minor, err := strconv.Atoi(parts[1])
if err != nil {
return false, fmt.Errorf("parse minor version %q: %w", parts[1], err)
}
if host.parts[0] < major {
return true, nil
}
if host.parts[0] > major {
return false, nil
}
return host.parts[1] <= minor, nil
}
cmp, err := compareVersions(hostVersion, maxVersion)
if err != nil {
return false, err
}
return cmp <= 0, nil
}
func parseVersion(value string) (parsedVersion, error) {
normalized := normalizeVersion(value)
if normalized == "" {
return parsedVersion{}, fmt.Errorf("version is required")
}
parts := strings.Split(normalized, ".")
if len(parts) != 3 {
return parsedVersion{}, fmt.Errorf("version %q must be in N.N.N format", value)
}
var parsed parsedVersion
parsed.raw = normalized
for i, part := range parts {
number, err := strconv.Atoi(part)
if err != nil {
return parsedVersion{}, fmt.Errorf("parse version segment %q: %w", part, err)
}
parsed.parts[i] = number
}
return parsed, nil
}
func normalizeVersion(value string) string {
trimmed := strings.TrimSpace(value)
trimmed = strings.TrimPrefix(trimmed, "v")
trimmed = strings.TrimPrefix(trimmed, "V")
return trimmed
}

View File

@@ -0,0 +1,32 @@
package pack
import "testing"
func TestCheckHostCompatibilityAcceptsRange(t *testing.T) {
manifest := Manifest{
PackID: "openai-cn-pack",
TargetHost: "sub2api",
MinHostVersion: "0.1.126",
MaxHostVersion: "0.2.x",
}
if err := CheckHostCompatibility(manifest, "0.1.126"); err != nil {
t.Fatalf("CheckHostCompatibility() error = %v, want nil", err)
}
if err := CheckHostCompatibility(manifest, "0.2.9"); err != nil {
t.Fatalf("CheckHostCompatibility() error = %v, want nil for wildcard upper bound", err)
}
}
func TestCheckHostCompatibilityRejectsBelowMinimum(t *testing.T) {
manifest := Manifest{TargetHost: "sub2api", MinHostVersion: "0.1.126"}
if err := CheckHostCompatibility(manifest, "0.1.125"); err == nil {
t.Fatal("CheckHostCompatibility() error = nil, want min version failure")
}
}
func TestCheckHostCompatibilityRejectsDifferentMaxMinor(t *testing.T) {
manifest := Manifest{TargetHost: "sub2api", MaxHostVersion: "0.2.x"}
if err := CheckHostCompatibility(manifest, "0.3.0"); err == nil {
t.Fatal("CheckHostCompatibility() error = nil, want max version failure")
}
}

View File

@@ -0,0 +1,321 @@
package provision
import (
"context"
"encoding/json"
"fmt"
"strings"
"sub2api-cn-relay-manager/internal/host/sub2api"
"sub2api-cn-relay-manager/internal/pack"
"sub2api-cn-relay-manager/internal/store/sqlite"
)
type BatchDetailResult struct {
Batch sqlite.ImportBatch
Items []sqlite.ImportBatchItem
ManagedResources []sqlite.ManagedResource
AccessClosures []sqlite.AccessClosureRecord
ReconcileRuns []sqlite.ReconcileRun
}
type BatchDetailService struct {
store *sqlite.DB
}
func NewBatchDetailService(store *sqlite.DB) *BatchDetailService {
return &BatchDetailService{store: store}
}
func (s *BatchDetailService) Get(ctx context.Context, batchID int64) (BatchDetailResult, error) {
if s == nil || s.store == nil {
return BatchDetailResult{}, fmt.Errorf("store is required")
}
batch, err := s.store.ImportBatches().GetByID(ctx, batchID)
if err != nil {
return BatchDetailResult{}, err
}
items, err := s.store.ImportBatchItems().GetByBatchID(ctx, batchID)
if err != nil {
return BatchDetailResult{}, err
}
managedResources, err := s.store.ManagedResources().GetByBatchID(ctx, batchID)
if err != nil {
return BatchDetailResult{}, err
}
accessClosures, err := s.store.AccessClosures().GetByBatchID(ctx, batchID)
if err != nil {
return BatchDetailResult{}, err
}
reconcileRuns, err := s.store.ReconcileRuns().GetByProviderID(ctx, batch.ProviderID)
if err != nil {
return BatchDetailResult{}, err
}
return BatchDetailResult{
Batch: batch,
Items: items,
ManagedResources: managedResources,
AccessClosures: accessClosures,
ReconcileRuns: reconcileRuns,
}, nil
}
type ReconcileRequest struct {
HostBaseURL string
AccessProbeAPIKey string
Pack pack.LoadedPack
Provider pack.ProviderManifest
}
type ReconcileResult struct {
BatchID int64
Status string
MissingCount int
ExtraCount int
ProbeFailureCount int
AccessStatus string
Summary map[string]any
}
type ReconcileService struct {
store *sqlite.DB
host sub2api.HostAdapter
}
func NewReconcileService(store *sqlite.DB, host sub2api.HostAdapter) *ReconcileService {
return &ReconcileService{store: store, host: host}
}
func (s *ReconcileService) Reconcile(ctx context.Context, req ReconcileRequest) (ReconcileResult, error) {
if s == nil || s.store == nil {
return ReconcileResult{}, fmt.Errorf("store is required")
}
if s.host == nil {
return ReconcileResult{}, fmt.Errorf("host adapter is required")
}
if strings.TrimSpace(req.HostBaseURL) == "" {
return ReconcileResult{}, fmt.Errorf("host_base_url is required")
}
hostVersion, err := s.host.GetHostVersion(ctx)
if err != nil {
return ReconcileResult{}, fmt.Errorf("get host version: %w", err)
}
if err := pack.CheckHostCompatibility(req.Pack.Manifest, hostVersion); err != nil {
return ReconcileResult{}, err
}
packRow, err := s.store.Packs().GetByPackID(ctx, req.Pack.Manifest.PackID)
if err != nil {
return ReconcileResult{}, err
}
providerRow, err := s.store.Providers().GetByPackIDAndProviderID(ctx, packRow.ID, req.Provider.ProviderID)
if err != nil {
return ReconcileResult{}, err
}
batchRow, err := s.store.ImportBatches().GetLatestByProviderID(ctx, providerRow.ID)
if err != nil {
return ReconcileResult{}, err
}
storedResources, err := s.store.ManagedResources().GetByBatchID(ctx, batchRow.ID)
if err != nil {
return ReconcileResult{}, err
}
batchItems, err := s.store.ImportBatchItems().GetByBatchID(ctx, batchRow.ID)
if err != nil {
return ReconcileResult{}, err
}
accessClosures, err := s.store.AccessClosures().GetByBatchID(ctx, batchRow.ID)
if err != nil {
return ReconcileResult{}, err
}
snapshot, err := s.host.ListManagedResources(ctx, sub2api.ListManagedResourcesRequest{
GroupName: SuggestResourceNames(req.Provider).Group,
ChannelName: SuggestResourceNames(req.Provider).Channel,
PlanName: SuggestResourceNames(req.Provider).Plan,
})
if err != nil {
return ReconcileResult{}, fmt.Errorf("list managed resources: %w", err)
}
missing, extra := diffManagedResources(storedResources, snapshot)
probeFailures, err := s.rerunAccountProbes(ctx, batchItems, req.Provider.SmokeTestModel)
if err != nil {
return ReconcileResult{}, err
}
accessStatus, accessChecked, err := s.rerunAccessClosure(ctx, batchRow.ID, accessClosures, req.AccessProbeAPIKey, req.Provider.SmokeTestModel)
if err != nil {
return ReconcileResult{}, err
}
status := "active"
if missing > 0 || extra > 0 {
status = "drifted"
} else if probeFailures > 0 || (accessChecked && accessStatus == AccessStatusBroken) {
status = "degraded"
}
summary := map[string]any{
"missing_count": missing,
"extra_count": extra,
"host_version": hostVersion,
"probe_failures": probeFailures,
"access_status": accessStatus,
"access_rechecked": accessChecked,
}
summaryJSON, err := json.Marshal(summary)
if err != nil {
return ReconcileResult{}, fmt.Errorf("marshal reconcile summary: %w", err)
}
if _, err := s.store.ReconcileRuns().Create(ctx, sqlite.ReconcileRun{ProviderID: providerRow.ID, Status: status, SummaryJSON: string(summaryJSON)}); err != nil {
return ReconcileResult{}, err
}
return ReconcileResult{BatchID: batchRow.ID, Status: status, MissingCount: missing, ExtraCount: extra, ProbeFailureCount: probeFailures, AccessStatus: accessStatus, Summary: summary}, nil
}
func (s *ReconcileService) rerunAccountProbes(ctx context.Context, items []sqlite.ImportBatchItem, expectedModel string) (int, error) {
if len(items) == 0 {
return 0, nil
}
failures := 0
for _, item := range items {
accountID, err := accountIDFromProbeSummary(item.ProbeSummaryJSON)
if err != nil {
return 0, fmt.Errorf("decode import batch item %d probe summary: %w", item.ID, err)
}
if strings.TrimSpace(accountID) == "" {
return 0, fmt.Errorf("import batch item %d missing account_id in probe summary", item.ID)
}
probe, err := s.host.TestAccount(ctx, accountID)
if err != nil {
return 0, fmt.Errorf("re-test account %s: %w", accountID, err)
}
models, err := s.host.GetAccountModels(ctx, accountID)
if err != nil {
return 0, fmt.Errorf("reload account models %s: %w", accountID, err)
}
smokeModelSeen := hasModel(models, expectedModel)
status := firstNonEmpty(probe.Status, "unknown")
payload, err := json.Marshal(map[string]any{
"account_id": accountID,
"probe_ok": probe.OK,
"probe_status": probe.Status,
"probe_message": probe.Message,
"models": models,
"smoke_model_seen": smokeModelSeen,
"reconcile_rerun": true,
})
if err != nil {
return 0, fmt.Errorf("marshal probe rerun summary for %s: %w", accountID, err)
}
if err := s.store.ImportBatchItems().UpdateResult(ctx, item.ID, status, string(payload)); err != nil {
return 0, err
}
if _, err := s.store.ProbeResults().Create(ctx, sqlite.ProbeResult{BatchItemID: item.ID, ProbeType: "account_smoke_rerun", Status: status, SummaryJSON: string(payload)}); err != nil {
return 0, err
}
if !probe.OK || !smokeModelSeen {
failures++
}
}
return failures, nil
}
func (s *ReconcileService) rerunAccessClosure(ctx context.Context, batchID int64, accessClosures []sqlite.AccessClosureRecord, probeAPIKey, expectedModel string) (string, bool, error) {
if len(accessClosures) == 0 {
return "not_configured", false, nil
}
latest := accessClosures[len(accessClosures)-1]
status := firstNonEmpty(latest.Status, deriveHealthyAccessStatus(latest.ClosureType))
if strings.TrimSpace(probeAPIKey) == "" {
return status, false, nil
}
result, err := s.host.CheckGatewayAccess(ctx, sub2api.GatewayAccessCheckRequest{APIKey: probeAPIKey, ExpectedModel: expectedModel})
if err != nil {
return "", false, fmt.Errorf("re-check gateway access: %w", err)
}
if result.OK && result.HasExpectedModel {
status = deriveHealthyAccessStatus(latest.ClosureType)
} else {
status = AccessStatusBroken
}
payload, err := json.Marshal(map[string]any{
"status_code": result.StatusCode,
"ok": result.OK,
"has_expected_model": result.HasExpectedModel,
"models": result.Models,
"reconcile_rerun": true,
})
if err != nil {
return "", false, fmt.Errorf("marshal access rerun summary: %w", err)
}
if _, err := s.store.AccessClosures().Create(ctx, sqlite.AccessClosureRecord{BatchID: batchID, ClosureType: latest.ClosureType, Status: status, DetailsJSON: string(payload)}); err != nil {
return "", false, err
}
return status, true, nil
}
func deriveHealthyAccessStatus(closureType string) string {
switch strings.TrimSpace(closureType) {
case AccessModeSubscription:
return AccessStatusSubscriptionReady
case AccessModeSelfService:
return AccessStatusSelfServiceReady
default:
return "unknown"
}
}
func accountIDFromProbeSummary(summaryJSON string) (string, error) {
if strings.TrimSpace(summaryJSON) == "" {
return "", nil
}
var payload map[string]any
if err := json.Unmarshal([]byte(summaryJSON), &payload); err != nil {
return "", err
}
accountID, _ := payload["account_id"].(string)
return strings.TrimSpace(accountID), nil
}
func diffManagedResources(stored []sqlite.ManagedResource, snapshot sub2api.ManagedResourceSnapshot) (int, int) {
live := map[string]map[string]struct{}{
"group": make(map[string]struct{}),
"channel": make(map[string]struct{}),
"plan": make(map[string]struct{}),
"account": make(map[string]struct{}),
}
for _, resource := range snapshot.Groups {
live["group"][strings.TrimSpace(resource.ID)] = struct{}{}
}
for _, resource := range snapshot.Channels {
live["channel"][strings.TrimSpace(resource.ID)] = struct{}{}
}
for _, resource := range snapshot.Plans {
live["plan"][strings.TrimSpace(resource.ID)] = struct{}{}
}
for _, resource := range snapshot.Accounts {
live["account"][strings.TrimSpace(resource.ID)] = struct{}{}
}
storedByType := map[string]map[string]struct{}{
"group": make(map[string]struct{}),
"channel": make(map[string]struct{}),
"plan": make(map[string]struct{}),
"account": make(map[string]struct{}),
}
for _, resource := range stored {
storedByType[strings.TrimSpace(resource.ResourceType)][strings.TrimSpace(resource.HostResourceID)] = struct{}{}
}
missing := 0
extra := 0
for resourceType, storedIDs := range storedByType {
for id := range storedIDs {
if _, ok := live[resourceType][id]; !ok {
missing++
}
}
for id := range live[resourceType] {
if _, ok := storedIDs[id]; !ok {
extra++
}
}
}
return missing, extra
}

View File

@@ -0,0 +1,235 @@
package provision
import (
"context"
"testing"
"sub2api-cn-relay-manager/internal/host/sub2api"
"sub2api-cn-relay-manager/internal/store/sqlite"
)
func TestBatchDetailServiceGetReturnsPersistedArtifacts(t *testing.T) {
store := openProvisionTestStore(t)
defer closeProvisionTestStore(t, store)
host := &fakeHostAdapter{
batchAccounts: []sub2api.AccountRef{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
testResults: map[string]sub2api.ProbeResult{
"account_1": {OK: true, Status: "passed"},
"account_2": {OK: true, Status: "passed"},
},
models: map[string][]sub2api.AccountModel{
"account_1": {{ID: "deepseek-chat"}},
"account_2": {{ID: "deepseek-chat"}},
},
gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}},
}
batchID := seedRuntimeImportForReconcile(t, store, host)
providerRow, err := store.Providers().ListByProviderID(context.Background(), sampleProviderManifest().ProviderID)
if err != nil {
t.Fatalf("Providers().ListByProviderID() error = %v", err)
}
if len(providerRow) != 1 {
t.Fatalf("providers = %d, want 1", len(providerRow))
}
if _, err := store.ReconcileRuns().Create(context.Background(), sqlite.ReconcileRun{ProviderID: providerRow[0].ID, Status: "active", SummaryJSON: `{"missing_count":0}`}); err != nil {
t.Fatalf("ReconcileRuns().Create() error = %v", err)
}
result, err := NewBatchDetailService(store).Get(context.Background(), batchID)
if err != nil {
t.Fatalf("Get() error = %v", err)
}
if result.Batch.ID != batchID {
t.Fatalf("Batch.ID = %d, want %d", result.Batch.ID, batchID)
}
if len(result.Items) != 2 {
t.Fatalf("len(Items) = %d, want 2", len(result.Items))
}
if len(result.ManagedResources) != 4 {
t.Fatalf("len(ManagedResources) = %d, want 4", len(result.ManagedResources))
}
if len(result.AccessClosures) != 1 {
t.Fatalf("len(AccessClosures) = %d, want 1", len(result.AccessClosures))
}
if len(result.ReconcileRuns) != 1 {
t.Fatalf("len(ReconcileRuns) = %d, want 1", len(result.ReconcileRuns))
}
}
func TestBatchDetailServiceGetValidatesStore(t *testing.T) {
_, err := (*BatchDetailService)(nil).Get(context.Background(), 1)
if err == nil || err.Error() != "store is required" {
t.Fatalf("nil service Get() error = %v, want store is required", err)
}
}
func TestAccountIDFromProbeSummary(t *testing.T) {
accountID, err := accountIDFromProbeSummary(`{"account_id":" account_1 "}`)
if err != nil {
t.Fatalf("accountIDFromProbeSummary() error = %v", err)
}
if accountID != "account_1" {
t.Fatalf("accountID = %q, want account_1", accountID)
}
if _, err := accountIDFromProbeSummary(`{`); err == nil {
t.Fatal("accountIDFromProbeSummary() error = nil, want JSON decode error")
}
blank, err := accountIDFromProbeSummary("")
if err != nil {
t.Fatalf("accountIDFromProbeSummary(blank) error = %v", err)
}
if blank != "" {
t.Fatalf("blank accountID = %q, want empty", blank)
}
}
func TestReconcileServiceRerunAccessClosureWithoutProbeKeyUsesLatestStatus(t *testing.T) {
store := openProvisionTestStore(t)
defer closeProvisionTestStore(t, store)
status, checked, err := NewReconcileService(store, &fakeHostAdapter{}).rerunAccessClosure(context.Background(), 1, []sqlite.AccessClosureRecord{{ClosureType: AccessModeSubscription, Status: AccessStatusSubscriptionReady}}, "", "deepseek-chat")
if err != nil {
t.Fatalf("rerunAccessClosure() error = %v", err)
}
if checked {
t.Fatal("checked = true, want false without probe key")
}
if status != AccessStatusSubscriptionReady {
t.Fatalf("status = %q, want %q", status, AccessStatusSubscriptionReady)
}
}
func TestReconcileServiceRerunAccessClosureMarksBrokenWhenGatewayCheckFails(t *testing.T) {
store := openProvisionTestStore(t)
defer closeProvisionTestStore(t, store)
hostSeed := &fakeHostAdapter{
batchAccounts: []sub2api.AccountRef{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
testResults: map[string]sub2api.ProbeResult{
"account_1": {OK: true, Status: "passed"},
"account_2": {OK: true, Status: "passed"},
},
models: map[string][]sub2api.AccountModel{
"account_1": {{ID: "deepseek-chat"}},
"account_2": {{ID: "deepseek-chat"}},
},
gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}},
}
batchID := seedRuntimeImportForReconcile(t, store, hostSeed)
host := &fakeHostAdapter{gatewayResult: sub2api.GatewayAccessResult{OK: false, StatusCode: 403, HasExpectedModel: false}}
status, checked, err := NewReconcileService(store, host).rerunAccessClosure(context.Background(), batchID, []sqlite.AccessClosureRecord{{ClosureType: AccessModeSelfService, Status: AccessStatusSelfServiceReady}}, "user-key", "deepseek-chat")
if err != nil {
t.Fatalf("rerunAccessClosure() error = %v", err)
}
if !checked {
t.Fatal("checked = false, want true")
}
if status != AccessStatusBroken {
t.Fatalf("status = %q, want %q", status, AccessStatusBroken)
}
if got := queryCount(t, store.SQLDB(), "access_closure_records"); got != 2 {
t.Fatalf("access_closure_records row count = %d, want 2 after rerun", got)
}
if host.gatewayProbe.ExpectedModel != "deepseek-chat" {
t.Fatalf("ExpectedModel = %q, want deepseek-chat", host.gatewayProbe.ExpectedModel)
}
}
func TestDiffManagedResourcesCountsMissingAndExtra(t *testing.T) {
missing, extra := diffManagedResources(
[]sqlite.ManagedResource{
{ResourceType: "group", HostResourceID: "group_1"},
{ResourceType: "account", HostResourceID: "account_1"},
},
sub2api.ManagedResourceSnapshot{
Groups: []sub2api.NamedResource{{ID: "group_1"}},
Accounts: []sub2api.NamedResource{{ID: "account_2"}},
},
)
if missing != 1 || extra != 1 {
t.Fatalf("diffManagedResources() = (%d, %d), want (1, 1)", missing, extra)
}
}
func TestDeriveProviderStatus(t *testing.T) {
tests := []struct {
name string
batchStatus string
reconcileStatus string
want string
}{
{name: "reconcile wins", batchStatus: BatchStatusSucceeded, reconcileStatus: "degraded", want: "degraded"},
{name: "succeeded batch", batchStatus: BatchStatusSucceeded, reconcileStatus: "not_run", want: ProviderStatusActive},
{name: "failed batch", batchStatus: BatchStatusFailed, want: ProviderStatusFailed},
{name: "running batch", batchStatus: "running", want: "running"},
{name: "unknown fallback", batchStatus: " pending ", want: "pending"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := deriveProviderStatus(tc.batchStatus, tc.reconcileStatus); got != tc.want {
t.Fatalf("deriveProviderStatus(%q, %q) = %q, want %q", tc.batchStatus, tc.reconcileStatus, got, tc.want)
}
})
}
}
func TestBuildPackAndProviderRecord(t *testing.T) {
packRow, err := buildPackRecord(sampleLoadedPack())
if err != nil {
t.Fatalf("buildPackRecord() error = %v", err)
}
if packRow.PackID != "openai-cn-pack" || packRow.TargetHost != "sub2api" {
t.Fatalf("packRow = %#v, want populated pack metadata", packRow)
}
providerRow, err := buildProviderRecord(7, sampleProviderManifest())
if err != nil {
t.Fatalf("buildProviderRecord() error = %v", err)
}
if providerRow.PackID != 7 || providerRow.ProviderID != sampleProviderManifest().ProviderID {
t.Fatalf("providerRow = %#v, want persisted provider metadata", providerRow)
}
if providerRow.DefaultModelsJSON == "" || providerRow.ManifestJSON == "" {
t.Fatalf("providerRow JSON fields = %#v, want serialized JSON", providerRow)
}
}
func TestFirstNonEmptyAndFingerprintKey(t *testing.T) {
if got := firstNonEmpty(" ", "value", "other"); got != "value" {
t.Fatalf("firstNonEmpty() = %q, want value", got)
}
if got := fingerprintKey([]string{" key-1 "}, 0); got == "key-1" || got == "sha256:" || len(got) < 20 {
t.Fatalf("fingerprintKey() = %q, want sha256 fingerprint", got)
}
if got := fingerprintKey(nil, 3); got != "key-4" {
t.Fatalf("fingerprintKey(nil, 3) = %q, want key-4", got)
}
}
func TestProviderStatusServiceGetResourcesRequiresProviderID(t *testing.T) {
store := openProvisionTestStore(t)
defer closeProvisionTestStore(t, store)
_, err := NewProviderStatusService(store).GetResources(context.Background(), ProviderQuery{})
if err == nil || err.Error() != "provider_id is required" {
t.Fatalf("GetResources() error = %v, want provider_id is required", err)
}
}
func TestResourceSlugFallsBackToProvider(t *testing.T) {
if got := resourceSlug(" !!! "); got != "provider" {
t.Fatalf("resourceSlug() = %q, want provider", got)
}
provider := sampleProviderManifest()
provider.ProviderID = " DeepSeek CN / Prod "
if got := SuggestAccountNamePrefix(provider); got != "deepseek-cn-prod-" {
t.Fatalf("SuggestAccountNamePrefix() = %q, want deepseek-cn-prod-", got)
}
resourceNames := SuggestResourceNames(provider)
if resourceNames.Group != "crm-deepseek-cn-prod-group" {
t.Fatalf("SuggestResourceNames() = %#v, want slugged resource names", resourceNames)
}
}

View File

@@ -0,0 +1,343 @@
package provision
import (
"context"
"errors"
"fmt"
"strings"
"sub2api-cn-relay-manager/internal/access"
"sub2api-cn-relay-manager/internal/host/sub2api"
"sub2api-cn-relay-manager/internal/pack"
)
const (
ImportModeStrict = "strict"
ImportModePartial = "partial"
AccessModeSubscription = "subscription"
AccessModeSelfService = "self_service"
BatchStatusSucceeded = "succeeded"
BatchStatusPartial = "partially_succeeded"
BatchStatusFailed = "failed"
ProviderStatusActive = "active"
ProviderStatusDegraded = "degraded"
ProviderStatusFailed = "failed"
AccessStatusSubscriptionReady = "subscription_ready"
AccessStatusSelfServiceReady = "self_service_ready"
AccessStatusBroken = "broken"
)
type AccessRequest struct {
Mode string
ProbeAPIKey string
Subscriptions []SubscriptionTarget
}
type SubscriptionTarget struct {
UserID string
DurationDays int
}
type ImportRequest struct {
Provider pack.ProviderManifest
Mode string
Access AccessRequest
Keys []string
}
type ImportReport struct {
BatchStatus string
ProviderStatus string
AccessStatus string
AcceptedKeys []string
Group sub2api.GroupRef
Channel sub2api.ChannelRef
Plan *sub2api.PlanRef
Accounts []AccountImportResult
Gateway sub2api.GatewayAccessResult
}
type AccountImportResult struct {
Ref sub2api.AccountRef
Probe sub2api.ProbeResult
Models []sub2api.AccountModel
SmokeModelSeen bool
}
type hostAdapter interface {
sub2api.HostAdapter
CheckGatewayAccess(ctx context.Context, req sub2api.GatewayAccessCheckRequest) (sub2api.GatewayAccessResult, error)
}
type ImportService struct {
host hostAdapter
}
func NewImportService(host hostAdapter) *ImportService {
return &ImportService{host: host}
}
func (s *ImportService) Import(ctx context.Context, req ImportRequest) (report ImportReport, err error) {
normalizedKeys, err := normalizeKeys(req.Keys)
if err != nil {
return ImportReport{}, err
}
if err := validateMode(req.Mode); err != nil {
return ImportReport{}, err
}
if err := access.Validate(access.ClosureRequest{
Mode: req.Access.Mode,
ProbeAPIKey: req.Access.ProbeAPIKey,
Subscriptions: toAccessSubscriptionTargets(req.Access.Subscriptions),
}); err != nil {
return ImportReport{}, err
}
report = ImportReport{AcceptedKeys: normalizedKeys}
rollback := newManagedResourceRollback(s.host)
defer func() {
if err == nil || req.Mode != ImportModeStrict {
return
}
if rollbackErr := rollback.Run(ctx); rollbackErr != nil {
err = errors.Join(err, fmt.Errorf("rollback managed resources: %w", rollbackErr))
}
}()
group, err := s.host.CreateGroup(ctx, sub2api.CreateGroupRequest{
Name: req.Provider.GroupTemplate.Name,
RateMultiplier: req.Provider.GroupTemplate.RateMultiplier,
})
if err != nil {
return report, fmt.Errorf("create group: %w", err)
}
report.Group = group
rollback.AddGroup(group.ID)
channel, err := s.host.CreateChannel(ctx, sub2api.CreateChannelRequest{
Name: req.Provider.ChannelTemplate.Name,
GroupIDs: []string{group.ID},
})
if err != nil {
return report, fmt.Errorf("create channel: %w", err)
}
report.Channel = channel
rollback.AddChannel(channel.ID)
if req.Access.Mode == AccessModeSubscription {
plan, err := s.host.CreatePlan(ctx, sub2api.CreatePlanRequest{
GroupID: group.ID,
Name: req.Provider.PlanTemplate.Name,
Price: req.Provider.PlanTemplate.Price,
ValidityDays: req.Provider.PlanTemplate.ValidityDays,
ValidityUnit: req.Provider.PlanTemplate.ValidityUnit,
})
if err != nil {
return report, fmt.Errorf("create plan: %w", err)
}
report.Plan = &plan
rollback.AddPlan(plan.ID)
}
accounts, err := s.host.BatchCreateAccounts(ctx, buildBatchAccountsRequest(req.Provider, group.ID, normalizedKeys))
if err != nil {
return report, fmt.Errorf("batch create accounts: %w", err)
}
rollback.AddAccounts(accounts)
for _, account := range accounts {
probe, err := s.host.TestAccount(ctx, account.ID)
if err != nil {
return failOrDegrade(report, req.Mode, fmt.Errorf("test account %s: %w", account.ID, err))
}
models, err := s.host.GetAccountModels(ctx, account.ID)
if err != nil {
return failOrDegrade(report, req.Mode, fmt.Errorf("get account models %s: %w", account.ID, err))
}
result := AccountImportResult{Ref: account, Probe: probe, Models: models, SmokeModelSeen: hasModel(models, req.Provider.SmokeTestModel)}
report.Accounts = append(report.Accounts, result)
}
failedAccounts := 0
for _, account := range report.Accounts {
if !account.Probe.OK || !account.SmokeModelSeen {
failedAccounts++
}
}
if failedAccounts > 0 && req.Mode == ImportModeStrict {
report.BatchStatus = BatchStatusFailed
report.ProviderStatus = ProviderStatusFailed
report.AccessStatus = AccessStatusBroken
return report, fmt.Errorf("strict import failed: %d account(s) did not pass smoke validation", failedAccounts)
}
closureService := access.NewService(s.host)
gateway, err := closureService.Close(ctx, access.ClosureRequest{
Mode: req.Access.Mode,
ProbeAPIKey: req.Access.ProbeAPIKey,
Subscriptions: toAccessSubscriptionTargets(req.Access.Subscriptions),
GroupID: group.ID,
ExpectedModel: req.Provider.SmokeTestModel,
})
if err != nil {
return failOrDegrade(report, req.Mode, err)
}
report.Gateway = gateway
report.BatchStatus = BatchStatusSucceeded
report.ProviderStatus = ProviderStatusActive
if failedAccounts > 0 || !gateway.OK || !gateway.HasExpectedModel {
report.BatchStatus = BatchStatusPartial
report.ProviderStatus = ProviderStatusDegraded
}
switch req.Access.Mode {
case AccessModeSubscription:
report.AccessStatus = AccessStatusSubscriptionReady
case AccessModeSelfService:
report.AccessStatus = AccessStatusSelfServiceReady
}
if !gateway.OK || !gateway.HasExpectedModel {
report.AccessStatus = AccessStatusBroken
}
return report, nil
}
func validateMode(mode string) error {
switch strings.TrimSpace(mode) {
case ImportModeStrict, ImportModePartial:
return nil
default:
return fmt.Errorf("unsupported import mode %q", mode)
}
}
func toAccessSubscriptionTargets(targets []SubscriptionTarget) []access.SubscriptionTarget {
result := make([]access.SubscriptionTarget, 0, len(targets))
for _, target := range targets {
result = append(result, access.SubscriptionTarget{UserID: target.UserID, DurationDays: target.DurationDays})
}
return result
}
func normalizeKeys(keys []string) ([]string, error) {
seen := map[string]struct{}{}
result := make([]string, 0, len(keys))
for _, key := range keys {
normalized := strings.TrimSpace(strings.TrimPrefix(key, "\ufeff"))
if normalized == "" {
continue
}
if _, ok := seen[normalized]; ok {
continue
}
seen[normalized] = struct{}{}
result = append(result, normalized)
}
if len(result) == 0 {
return nil, fmt.Errorf("at least one api key is required")
}
return result, nil
}
func buildBatchAccountsRequest(provider pack.ProviderManifest, groupID string, keys []string) sub2api.BatchCreateAccountsRequest {
accounts := make([]sub2api.CreateAccountRequest, 0, len(keys))
for index, key := range keys {
accounts = append(accounts, sub2api.CreateAccountRequest{
Name: fmt.Sprintf("%s-%02d", provider.ProviderID, index+1),
Platform: provider.Platform,
Type: provider.AccountType,
GroupIDs: []string{groupID},
Credentials: map[string]any{
"base_url": provider.BaseURL,
"api_key": key,
},
})
}
return sub2api.BatchCreateAccountsRequest{Accounts: accounts}
}
func hasModel(models []sub2api.AccountModel, target string) bool {
for _, model := range models {
if strings.TrimSpace(model.ID) == strings.TrimSpace(target) {
return true
}
}
return false
}
type managedResourceRollback struct {
host hostAdapter
groupID string
channelID string
planID string
accountIDs []string
}
func newManagedResourceRollback(host hostAdapter) *managedResourceRollback {
return &managedResourceRollback{host: host}
}
func (r *managedResourceRollback) AddGroup(groupID string) {
r.groupID = strings.TrimSpace(groupID)
}
func (r *managedResourceRollback) AddChannel(channelID string) {
r.channelID = strings.TrimSpace(channelID)
}
func (r *managedResourceRollback) AddPlan(planID string) {
r.planID = strings.TrimSpace(planID)
}
func (r *managedResourceRollback) AddAccounts(accounts []sub2api.AccountRef) {
for _, account := range accounts {
accountID := strings.TrimSpace(account.ID)
if accountID == "" {
continue
}
r.accountIDs = append(r.accountIDs, accountID)
}
}
func (r *managedResourceRollback) Run(ctx context.Context) error {
if r == nil || r.host == nil {
return nil
}
var errs []error
for index := len(r.accountIDs) - 1; index >= 0; index-- {
if err := r.host.DeleteAccount(ctx, r.accountIDs[index]); err != nil {
errs = append(errs, fmt.Errorf("delete account %s: %w", r.accountIDs[index], err))
}
}
if r.planID != "" {
if err := r.host.DeletePlan(ctx, r.planID); err != nil {
errs = append(errs, fmt.Errorf("delete plan %s: %w", r.planID, err))
}
}
if r.channelID != "" {
if err := r.host.DeleteChannel(ctx, r.channelID); err != nil {
errs = append(errs, fmt.Errorf("delete channel %s: %w", r.channelID, err))
}
}
if r.groupID != "" {
if err := r.host.DeleteGroup(ctx, r.groupID); err != nil {
errs = append(errs, fmt.Errorf("delete group %s: %w", r.groupID, err))
}
}
return errors.Join(errs...)
}
func failOrDegrade(report ImportReport, mode string, err error) (ImportReport, error) {
if mode == ImportModeStrict {
report.BatchStatus = BatchStatusFailed
report.ProviderStatus = ProviderStatusFailed
report.AccessStatus = AccessStatusBroken
return report, err
}
report.BatchStatus = BatchStatusPartial
report.ProviderStatus = ProviderStatusDegraded
report.AccessStatus = AccessStatusBroken
return report, err
}

View File

@@ -0,0 +1,241 @@
package provision
import (
"context"
"errors"
"fmt"
"reflect"
"strings"
"testing"
"sub2api-cn-relay-manager/internal/host/sub2api"
"sub2api-cn-relay-manager/internal/pack"
)
func TestImportServiceImportSubscriptionFlow(t *testing.T) {
host := &fakeHostAdapter{
batchAccounts: []sub2api.AccountRef{{ID: "account_1"}, {ID: "account_2"}},
testResults: map[string]sub2api.ProbeResult{
"account_1": {OK: true, Status: "passed"},
"account_2": {OK: true, Status: "passed"},
},
models: map[string][]sub2api.AccountModel{
"account_1": {{ID: "deepseek-chat"}},
"account_2": {{ID: "deepseek-chat"}},
},
gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}},
}
svc := NewImportService(host)
report, err := svc.Import(context.Background(), ImportRequest{
Provider: sampleProviderManifest(),
Mode: ImportModePartial,
Access: AccessRequest{
Mode: AccessModeSubscription,
ProbeAPIKey: "user-key",
Subscriptions: []SubscriptionTarget{{UserID: "user_1", DurationDays: 30}},
},
Keys: []string{" key-1 ", "key-2", "key-1"},
})
if err != nil {
t.Fatalf("Import() error = %v", err)
}
if report.BatchStatus != BatchStatusSucceeded {
t.Fatalf("BatchStatus = %q, want %q", report.BatchStatus, BatchStatusSucceeded)
}
if report.ProviderStatus != ProviderStatusActive {
t.Fatalf("ProviderStatus = %q, want %q", report.ProviderStatus, ProviderStatusActive)
}
if report.AccessStatus != AccessStatusSubscriptionReady {
t.Fatalf("AccessStatus = %q, want %q", report.AccessStatus, AccessStatusSubscriptionReady)
}
if !reflect.DeepEqual(report.AcceptedKeys, []string{"key-1", "key-2"}) {
t.Fatalf("AcceptedKeys = %#v, want deduped normalized keys", report.AcceptedKeys)
}
if len(host.assignedSubscriptions) != 1 {
t.Fatalf("assigned subscriptions = %d, want 1", len(host.assignedSubscriptions))
}
if host.gatewayProbe.ExpectedModel != "deepseek-chat" {
t.Fatalf("gateway probe model = %q, want %q", host.gatewayProbe.ExpectedModel, "deepseek-chat")
}
}
func TestImportServiceStrictModeFailsWhenAnyAccountProbeFails(t *testing.T) {
host := &fakeHostAdapter{
batchAccounts: []sub2api.AccountRef{{ID: "account_1"}, {ID: "account_2"}},
testResults: map[string]sub2api.ProbeResult{
"account_1": {OK: true, Status: "passed"},
"account_2": {OK: false, Status: "failed", Message: "bad key"},
},
models: map[string][]sub2api.AccountModel{
"account_1": {{ID: "deepseek-chat"}},
"account_2": {{ID: "deepseek-chat"}},
},
}
svc := NewImportService(host)
report, err := svc.Import(context.Background(), ImportRequest{
Provider: sampleProviderManifest(),
Mode: ImportModeStrict,
Access: AccessRequest{Mode: AccessModeSelfService, ProbeAPIKey: "user-key"},
Keys: []string{"key-1", "key-2"},
})
if err == nil {
t.Fatal("Import() error = nil, want strict mode failure")
}
if report.BatchStatus != BatchStatusFailed {
t.Fatalf("BatchStatus = %q, want %q", report.BatchStatus, BatchStatusFailed)
}
if report.ProviderStatus != ProviderStatusFailed {
t.Fatalf("ProviderStatus = %q, want %q", report.ProviderStatus, ProviderStatusFailed)
}
}
func TestImportServiceRejectsUnknownMode(t *testing.T) {
svc := NewImportService(&fakeHostAdapter{})
_, err := svc.Import(context.Background(), ImportRequest{
Provider: sampleProviderManifest(),
Mode: "unknown",
Access: AccessRequest{Mode: AccessModeSelfService},
Keys: []string{"key-1"},
})
if err == nil {
t.Fatal("Import() error = nil, want mode validation error")
}
}
func TestImportServiceStrictModeRollsBackCreatedResources(t *testing.T) {
host := &fakeHostAdapter{
batchAccounts: []sub2api.AccountRef{{ID: "account_1"}, {ID: "account_2"}},
testResults: map[string]sub2api.ProbeResult{
"account_1": {OK: true, Status: "passed"},
"account_2": {OK: false, Status: "failed", Message: "bad key"},
},
models: map[string][]sub2api.AccountModel{
"account_1": {{ID: "deepseek-chat"}},
"account_2": {{ID: "deepseek-chat"}},
},
}
svc := NewImportService(host)
_, err := svc.Import(context.Background(), ImportRequest{
Provider: sampleProviderManifest(),
Mode: ImportModeStrict,
Access: AccessRequest{Mode: AccessModeSelfService, ProbeAPIKey: "user-key"},
Keys: []string{"key-1", "key-2"},
})
if err == nil {
t.Fatal("Import() error = nil, want strict mode failure")
}
want := []string{"account:account_2", "account:account_1", "channel:channel_1", "group:group_1"}
if !reflect.DeepEqual(host.deletedResources, want) {
t.Fatalf("deleted resources = %#v, want %#v", host.deletedResources, want)
}
}
func sampleProviderManifest() pack.ProviderManifest {
return pack.ProviderManifest{
ProviderID: "deepseek",
DisplayName: "DeepSeek OpenAI Compatible",
BaseURL: "https://api.deepseek.com",
Platform: "openai",
AccountType: "api",
DefaultModels: []string{"deepseek-chat", "deepseek-reasoner"},
SmokeTestModel: "deepseek-chat",
GroupTemplate: pack.GroupTemplate{Name: "DeepSeek 默认分组", RateMultiplier: 1},
ChannelTemplate: pack.ChannelTemplate{Name: "DeepSeek 默认渠道", ModelMapping: map[string]string{"deepseek-chat": "deepseek-chat"}},
PlanTemplate: pack.PlanTemplate{Name: "DeepSeek 默认套餐", Price: 19.9, ValidityDays: 30, ValidityUnit: "day"},
}
}
type fakeHostAdapter struct {
batchAccounts []sub2api.AccountRef
testResults map[string]sub2api.ProbeResult
models map[string][]sub2api.AccountModel
gatewayResult sub2api.GatewayAccessResult
batchCreateErr error
assignErr error
gatewayErr error
hostVersion string
assignedSubscriptions []sub2api.AssignSubscriptionRequest
gatewayProbe sub2api.GatewayAccessCheckRequest
deletedResources []string
managedSnapshot sub2api.ManagedResourceSnapshot
}
func (f *fakeHostAdapter) GetHostVersion(context.Context) (string, error) {
if strings.TrimSpace(f.hostVersion) == "" {
return "0.1.126", nil
}
return f.hostVersion, nil
}
func (f *fakeHostAdapter) ProbeCapabilities(context.Context) (sub2api.HostCapabilities, error) {
return sub2api.HostCapabilities{}, nil
}
func (f *fakeHostAdapter) CreateGroup(context.Context, sub2api.CreateGroupRequest) (sub2api.GroupRef, error) {
return sub2api.GroupRef{ID: "group_1", Name: "g"}, nil
}
func (f *fakeHostAdapter) DeleteGroup(_ context.Context, groupID string) error {
f.deletedResources = append(f.deletedResources, "group:"+groupID)
return nil
}
func (f *fakeHostAdapter) CreateChannel(context.Context, sub2api.CreateChannelRequest) (sub2api.ChannelRef, error) {
return sub2api.ChannelRef{ID: "channel_1", Name: "c"}, nil
}
func (f *fakeHostAdapter) DeleteChannel(_ context.Context, channelID string) error {
f.deletedResources = append(f.deletedResources, "channel:"+channelID)
return nil
}
func (f *fakeHostAdapter) CreatePlan(context.Context, sub2api.CreatePlanRequest) (sub2api.PlanRef, error) {
return sub2api.PlanRef{ID: "plan_1", Name: "p"}, nil
}
func (f *fakeHostAdapter) DeletePlan(_ context.Context, planID string) error {
f.deletedResources = append(f.deletedResources, "plan:"+planID)
return nil
}
func (f *fakeHostAdapter) CreateAccount(context.Context, sub2api.CreateAccountRequest) (sub2api.AccountRef, error) {
return sub2api.AccountRef{}, errors.New("unused")
}
func (f *fakeHostAdapter) BatchCreateAccounts(_ context.Context, _ sub2api.BatchCreateAccountsRequest) ([]sub2api.AccountRef, error) {
if f.batchCreateErr != nil {
return nil, f.batchCreateErr
}
return f.batchAccounts, nil
}
func (f *fakeHostAdapter) DeleteAccount(_ context.Context, accountID string) error {
f.deletedResources = append(f.deletedResources, "account:"+accountID)
return nil
}
func (f *fakeHostAdapter) TestAccount(_ context.Context, accountID string) (sub2api.ProbeResult, error) {
result, ok := f.testResults[accountID]
if !ok {
return sub2api.ProbeResult{}, fmt.Errorf("missing test result for %s", accountID)
}
return result, nil
}
func (f *fakeHostAdapter) GetAccountModels(_ context.Context, accountID string) ([]sub2api.AccountModel, error) {
models, ok := f.models[accountID]
if !ok {
return nil, fmt.Errorf("missing models for %s", accountID)
}
return models, nil
}
func (f *fakeHostAdapter) AssignSubscription(_ context.Context, req sub2api.AssignSubscriptionRequest) (sub2api.SubscriptionRef, error) {
if f.assignErr != nil {
return sub2api.SubscriptionRef{}, f.assignErr
}
f.assignedSubscriptions = append(f.assignedSubscriptions, req)
return sub2api.SubscriptionRef{ID: "subscription_1"}, nil
}
func (f *fakeHostAdapter) CheckGatewayAccess(_ context.Context, req sub2api.GatewayAccessCheckRequest) (sub2api.GatewayAccessResult, error) {
f.gatewayProbe = req
if f.gatewayErr != nil {
return sub2api.GatewayAccessResult{}, f.gatewayErr
}
return f.gatewayResult, nil
}
func (f *fakeHostAdapter) ListManagedResources(context.Context, sub2api.ListManagedResourcesRequest) (sub2api.ManagedResourceSnapshot, error) {
return f.managedSnapshot, nil
}

View File

@@ -0,0 +1,40 @@
package provision
import (
"fmt"
"regexp"
"strings"
"sub2api-cn-relay-manager/internal/pack"
)
var nonSlugPattern = regexp.MustCompile(`[^a-z0-9]+`)
type ResourceNames struct {
Group string
Channel string
Plan string
}
func SuggestAccountNamePrefix(provider pack.ProviderManifest) string {
return fmt.Sprintf("%s-", resourceSlug(provider.ProviderID))
}
func SuggestResourceNames(provider pack.ProviderManifest) ResourceNames {
slug := resourceSlug(provider.ProviderID)
return ResourceNames{
Group: fmt.Sprintf("crm-%s-group", slug),
Channel: fmt.Sprintf("crm-%s-channel", slug),
Plan: fmt.Sprintf("crm-%s-plan", slug),
}
}
func resourceSlug(raw string) string {
slug := strings.ToLower(strings.TrimSpace(raw))
slug = nonSlugPattern.ReplaceAllString(slug, "-")
slug = strings.Trim(slug, "-")
if slug == "" {
return "provider"
}
return slug
}

View File

@@ -0,0 +1,178 @@
package provision
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strings"
"sub2api-cn-relay-manager/internal/host/sub2api"
packdef "sub2api-cn-relay-manager/internal/pack"
"sub2api-cn-relay-manager/internal/store/sqlite"
)
type PackInstallRequest struct {
Pack packdef.LoadedPack
}
type PackInstallResult struct {
Pack sqlite.Pack
Providers []sqlite.Provider
HostVersion string
AlreadyInstalled bool
}
type PackInstallService struct {
store *sqlite.DB
host sub2api.HostAdapter
}
func NewPackInstallService(store *sqlite.DB, host sub2api.HostAdapter) *PackInstallService {
return &PackInstallService{store: store, host: host}
}
func (s *PackInstallService) Install(ctx context.Context, req PackInstallRequest) (PackInstallResult, error) {
if s == nil || s.store == nil {
return PackInstallResult{}, fmt.Errorf("store is required")
}
if s.host == nil {
return PackInstallResult{}, fmt.Errorf("host adapter is required")
}
if strings.TrimSpace(req.Pack.Manifest.PackID) == "" {
return PackInstallResult{}, fmt.Errorf("pack manifest is required")
}
hostVersion, err := s.host.GetHostVersion(ctx)
if err != nil {
return PackInstallResult{}, fmt.Errorf("get host version: %w", err)
}
if err := packdef.CheckHostCompatibility(req.Pack.Manifest, hostVersion); err != nil {
return PackInstallResult{}, err
}
result := PackInstallResult{HostVersion: hostVersion}
if err := s.store.WithTx(ctx, func(queries *sqlite.Queries) error {
existing, err := queries.Packs.GetByPackID(ctx, req.Pack.Manifest.PackID)
if err == nil {
if err := validateExistingPack(existing, req.Pack); err != nil {
return err
}
result.AlreadyInstalled = true
} else if !errors.Is(err, sql.ErrNoRows) {
return err
}
packRow, err := buildPackRecord(req.Pack)
if err != nil {
return err
}
if _, err := queries.Packs.Upsert(ctx, packRow); err != nil {
return err
}
persistedPack, err := queries.Packs.GetByPackID(ctx, req.Pack.Manifest.PackID)
if err != nil {
return err
}
result.Pack = persistedPack
providers := make([]sqlite.Provider, 0, len(req.Pack.Providers))
for _, providerManifest := range req.Pack.Providers {
providerRow, err := buildProviderRecord(persistedPack.ID, providerManifest)
if err != nil {
return err
}
if _, err := queries.Providers.Upsert(ctx, providerRow); err != nil {
return err
}
persistedProvider, err := queries.Providers.GetByPackIDAndProviderID(ctx, persistedPack.ID, providerManifest.ProviderID)
if err != nil {
return err
}
providers = append(providers, persistedProvider)
}
result.Providers = providers
return nil
}); err != nil {
return PackInstallResult{}, err
}
return result, nil
}
func validateExistingPack(existing sqlite.Pack, loaded packdef.LoadedPack) error {
if strings.TrimSpace(existing.Version) != strings.TrimSpace(loaded.Manifest.Version) {
return fmt.Errorf("pack %q already installed with version %q; upgrade lifecycle not implemented", existing.PackID, existing.Version)
}
if strings.TrimSpace(existing.Checksum) != strings.TrimSpace(loaded.Checksum) {
return fmt.Errorf("pack %q version %q checksum drift detected", existing.PackID, existing.Version)
}
return nil
}
func buildPackRecord(loaded packdef.LoadedPack) (sqlite.Pack, error) {
manifestJSON, err := json.Marshal(loaded.Manifest)
if err != nil {
return sqlite.Pack{}, fmt.Errorf("marshal pack manifest: %w", err)
}
return sqlite.Pack{
PackID: loaded.Manifest.PackID,
Version: loaded.Manifest.Version,
Checksum: loaded.Checksum,
Vendor: loaded.Manifest.Vendor,
TargetHost: loaded.Manifest.TargetHost,
MinHostVersion: loaded.Manifest.MinHostVersion,
MaxHostVersion: loaded.Manifest.MaxHostVersion,
ManifestJSON: string(manifestJSON),
}, nil
}
func buildProviderRecord(packID int64, provider packdef.ProviderManifest) (sqlite.Provider, error) {
defaultModelsJSON, err := marshalJSONString(provider.DefaultModels)
if err != nil {
return sqlite.Provider{}, fmt.Errorf("marshal provider default models: %w", err)
}
groupTemplateJSON, err := marshalJSONString(provider.GroupTemplate)
if err != nil {
return sqlite.Provider{}, fmt.Errorf("marshal group template: %w", err)
}
channelTemplateJSON, err := marshalJSONString(provider.ChannelTemplate)
if err != nil {
return sqlite.Provider{}, fmt.Errorf("marshal channel template: %w", err)
}
planTemplateJSON, err := marshalJSONString(provider.PlanTemplate)
if err != nil {
return sqlite.Provider{}, fmt.Errorf("marshal plan template: %w", err)
}
importOptionsJSON, err := marshalJSONString(provider.Import)
if err != nil {
return sqlite.Provider{}, fmt.Errorf("marshal import options: %w", err)
}
manifestJSON, err := marshalJSONString(provider)
if err != nil {
return sqlite.Provider{}, fmt.Errorf("marshal provider manifest: %w", err)
}
return sqlite.Provider{
PackID: packID,
ProviderID: provider.ProviderID,
DisplayName: provider.DisplayName,
BaseURL: provider.BaseURL,
Platform: provider.Platform,
AccountType: provider.AccountType,
DefaultModelsJSON: defaultModelsJSON,
SmokeTestModel: provider.SmokeTestModel,
GroupTemplateJSON: groupTemplateJSON,
ChannelTemplateJSON: channelTemplateJSON,
PlanTemplateJSON: planTemplateJSON,
ImportOptionsJSON: importOptionsJSON,
ManifestJSON: manifestJSON,
}, nil
}
func marshalJSONString(value any) (string, error) {
body, err := json.Marshal(value)
if err != nil {
return "", err
}
return string(body), nil
}

View File

@@ -0,0 +1,120 @@
package provision
import (
"context"
"strings"
"testing"
"sub2api-cn-relay-manager/internal/pack"
"sub2api-cn-relay-manager/internal/store/sqlite"
)
func TestPackInstallServiceInstallPersistsPackAndProviders(t *testing.T) {
store := openProvisionTestStore(t)
defer closeProvisionTestStore(t, store)
host := &fakeHostAdapter{}
loaded := sampleLoadedPack()
svc := NewPackInstallService(store, host)
result, err := svc.Install(context.Background(), PackInstallRequest{Pack: loaded})
if err != nil {
t.Fatalf("Install() error = %v", err)
}
if result.HostVersion != "0.1.126" {
t.Fatalf("HostVersion = %q, want 0.1.126", result.HostVersion)
}
if result.AlreadyInstalled {
t.Fatal("AlreadyInstalled = true, want false on first install")
}
if result.Pack.PackID != loaded.Manifest.PackID {
t.Fatalf("Pack.PackID = %q, want %q", result.Pack.PackID, loaded.Manifest.PackID)
}
if len(result.Providers) != 1 || result.Providers[0].ProviderID != loaded.Providers[0].ProviderID {
t.Fatalf("Providers = %#v, want one persisted provider", result.Providers)
}
if got := queryCount(t, store.SQLDB(), "packs"); got != 1 {
t.Fatalf("packs row count = %d, want 1", got)
}
if got := queryCount(t, store.SQLDB(), "providers"); got != 1 {
t.Fatalf("providers row count = %d, want 1", got)
}
repeat, err := svc.Install(context.Background(), PackInstallRequest{Pack: loaded})
if err != nil {
t.Fatalf("second Install() error = %v", err)
}
if !repeat.AlreadyInstalled {
t.Fatal("AlreadyInstalled = false, want true on re-install")
}
if got := queryCount(t, store.SQLDB(), "packs"); got != 1 {
t.Fatalf("packs row count after re-install = %d, want 1", got)
}
if got := queryCount(t, store.SQLDB(), "providers"); got != 1 {
t.Fatalf("providers row count after re-install = %d, want 1", got)
}
}
func TestPackInstallServiceInstallRejectsVersionAndChecksumDrift(t *testing.T) {
store := openProvisionTestStore(t)
defer closeProvisionTestStore(t, store)
svc := NewPackInstallService(store, &fakeHostAdapter{})
loaded := sampleLoadedPack()
if _, err := svc.Install(context.Background(), PackInstallRequest{Pack: loaded}); err != nil {
t.Fatalf("initial Install() error = %v", err)
}
versionDrift := sampleLoadedPack()
versionDrift.Manifest.Version = "2.0.0"
if _, err := svc.Install(context.Background(), PackInstallRequest{Pack: versionDrift}); err == nil || !strings.Contains(err.Error(), "upgrade lifecycle not implemented") {
t.Fatalf("Install() version drift error = %v, want upgrade lifecycle error", err)
}
checksumDrift := sampleLoadedPack()
checksumDrift.Checksum = "checksum-2"
if _, err := svc.Install(context.Background(), PackInstallRequest{Pack: checksumDrift}); err == nil || !strings.Contains(err.Error(), "checksum drift detected") {
t.Fatalf("Install() checksum drift error = %v, want checksum drift error", err)
}
}
func TestPackInstallServiceInstallValidatesDependencies(t *testing.T) {
loaded := sampleLoadedPack()
storeWithoutHost := openProvisionTestStore(t)
defer closeProvisionTestStore(t, storeWithoutHost)
storeWithoutPack := openProvisionTestStore(t)
defer closeProvisionTestStore(t, storeWithoutPack)
if _, err := (*PackInstallService)(nil).Install(context.Background(), PackInstallRequest{Pack: loaded}); err == nil || err.Error() != "store is required" {
t.Fatalf("nil service Install() error = %v, want store is required", err)
}
if _, err := (&PackInstallService{store: storeWithoutHost}).Install(context.Background(), PackInstallRequest{Pack: loaded}); err == nil || err.Error() != "host adapter is required" {
t.Fatalf("missing host Install() error = %v, want host adapter is required", err)
}
if _, err := NewPackInstallService(storeWithoutPack, &fakeHostAdapter{}).Install(context.Background(), PackInstallRequest{}); err == nil || err.Error() != "pack manifest is required" {
t.Fatalf("missing pack Install() error = %v, want pack manifest is required", err)
}
}
func sampleLoadedPack() pack.LoadedPack {
provider := sampleProviderManifest()
return pack.LoadedPack{
Manifest: pack.Manifest{
PackID: "openai-cn-pack",
Version: "1.0.0",
Vendor: "nous",
TargetHost: "sub2api",
MinHostVersion: "0.1.126",
MaxHostVersion: "0.2.x",
},
Providers: []pack.ProviderManifest{provider},
Checksum: "checksum-1",
}
}
func TestValidateExistingPack(t *testing.T) {
existing := sqlite.Pack{PackID: "openai-cn-pack", Version: "1.0.0", Checksum: "checksum-1"}
if err := validateExistingPack(existing, sampleLoadedPack()); err != nil {
t.Fatalf("validateExistingPack() error = %v, want nil", err)
}
}

View File

@@ -0,0 +1,90 @@
package provision
import (
"context"
"fmt"
"sub2api-cn-relay-manager/internal/host/sub2api"
"sub2api-cn-relay-manager/internal/pack"
)
const (
PreviewActionCreate = "create"
PreviewActionReuse = "reuse"
PreviewActionConflict = "conflict"
)
type previewHost interface {
ListManagedResources(ctx context.Context, req sub2api.ListManagedResourcesRequest) (sub2api.ManagedResourceSnapshot, error)
}
type PreviewRequest struct {
Provider pack.ProviderManifest
Mode string
Keys []string
}
type PreviewDecision struct {
Action string
Suggested string
ExistingID string
Reason string
}
type PreviewReport struct {
AcceptedKeys []string
Names ResourceNames
Decisions map[string]PreviewDecision
}
type PreviewService struct {
host previewHost
}
func NewPreviewService(host previewHost) *PreviewService {
return &PreviewService{host: host}
}
func (s *PreviewService) PreviewImport(ctx context.Context, req PreviewRequest) (PreviewReport, error) {
acceptedKeys, err := normalizeKeys(req.Keys)
if err != nil {
return PreviewReport{}, err
}
if err := validateMode(req.Mode); err != nil {
return PreviewReport{}, err
}
if s.host == nil {
return PreviewReport{}, fmt.Errorf("preview host is required")
}
names := SuggestResourceNames(req.Provider)
snapshot, err := s.host.ListManagedResources(ctx, sub2api.ListManagedResourcesRequest{
GroupName: names.Group,
ChannelName: names.Channel,
PlanName: names.Plan,
})
if err != nil {
return PreviewReport{}, fmt.Errorf("list managed resources: %w", err)
}
return PreviewReport{
AcceptedKeys: acceptedKeys,
Names: names,
Decisions: map[string]PreviewDecision{
"group": decideResource(names.Group, snapshot.Groups),
"channel": decideResource(names.Channel, snapshot.Channels),
"plan": decideResource(names.Plan, snapshot.Plans),
},
}, nil
}
func decideResource(suggested string, existing []sub2api.NamedResource) PreviewDecision {
switch len(existing) {
case 0:
return PreviewDecision{Action: PreviewActionCreate, Suggested: suggested}
case 1:
return PreviewDecision{Action: PreviewActionReuse, Suggested: suggested, ExistingID: existing[0].ID, Reason: "matching managed resource already exists"}
default:
return PreviewDecision{Action: PreviewActionConflict, Suggested: suggested, Reason: "multiple managed resources share the suggested name"}
}
}

View File

@@ -0,0 +1,87 @@
package provision
import (
"context"
"reflect"
"testing"
"sub2api-cn-relay-manager/internal/host/sub2api"
)
func TestSuggestResourceNames(t *testing.T) {
provider := sampleProviderManifest()
names := SuggestResourceNames(provider)
want := ResourceNames{
Group: "crm-deepseek-group",
Channel: "crm-deepseek-channel",
Plan: "crm-deepseek-plan",
}
if !reflect.DeepEqual(names, want) {
t.Fatalf("SuggestResourceNames() = %#v, want %#v", names, want)
}
}
func TestPreviewServiceReportsCreateActionsWhenHostHasNoResources(t *testing.T) {
host := &fakePreviewHost{}
svc := NewPreviewService(host)
report, err := svc.PreviewImport(context.Background(), PreviewRequest{
Provider: sampleProviderManifest(),
Mode: ImportModeStrict,
Keys: []string{" key-1 ", "key-2", "key-1"},
})
if err != nil {
t.Fatalf("PreviewImport() error = %v", err)
}
if !reflect.DeepEqual(report.AcceptedKeys, []string{"key-1", "key-2"}) {
t.Fatalf("AcceptedKeys = %#v, want normalized deduped keys", report.AcceptedKeys)
}
if got := report.Decisions["group"]; got.Action != PreviewActionCreate {
t.Fatalf("group action = %q, want %q", got.Action, PreviewActionCreate)
}
if got := report.Decisions["channel"]; got.Action != PreviewActionCreate {
t.Fatalf("channel action = %q, want %q", got.Action, PreviewActionCreate)
}
if got := report.Decisions["plan"]; got.Action != PreviewActionCreate {
t.Fatalf("plan action = %q, want %q", got.Action, PreviewActionCreate)
}
}
func TestPreviewServiceReportsReuseAndConflict(t *testing.T) {
host := &fakePreviewHost{snapshot: sub2api.ManagedResourceSnapshot{
Groups: []sub2api.NamedResource{{ID: "group_1", Name: "crm-deepseek-group"}},
Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "crm-deepseek-channel"}, {ID: "channel_2", Name: "crm-deepseek-channel"}},
Plans: []sub2api.NamedResource{{ID: "plan_1", Name: "crm-deepseek-plan"}},
}}
svc := NewPreviewService(host)
report, err := svc.PreviewImport(context.Background(), PreviewRequest{
Provider: sampleProviderManifest(),
Mode: ImportModePartial,
Keys: []string{"key-1"},
})
if err != nil {
t.Fatalf("PreviewImport() error = %v", err)
}
if got := report.Decisions["group"]; got.Action != PreviewActionReuse || got.ExistingID != "group_1" {
t.Fatalf("group decision = %#v, want reuse group_1", got)
}
if got := report.Decisions["plan"]; got.Action != PreviewActionReuse || got.ExistingID != "plan_1" {
t.Fatalf("plan decision = %#v, want reuse plan_1", got)
}
if got := report.Decisions["channel"]; got.Action != PreviewActionConflict {
t.Fatalf("channel decision = %#v, want conflict", got)
}
}
type fakePreviewHost struct {
snapshot sub2api.ManagedResourceSnapshot
}
func (f *fakePreviewHost) ListManagedResources(context.Context, sub2api.ListManagedResourcesRequest) (sub2api.ManagedResourceSnapshot, error) {
return f.snapshot, nil
}

View File

@@ -0,0 +1,150 @@
package provision
import (
"context"
"encoding/json"
"fmt"
"strings"
"sub2api-cn-relay-manager/internal/store/sqlite"
)
type ProviderQuery struct {
ProviderID string
PackID string
}
type ProviderSnapshot struct {
Host sqlite.Host
Pack sqlite.Pack
Provider sqlite.Provider
Batch sqlite.ImportBatch
ManagedResources []sqlite.ManagedResource
AccessClosures []sqlite.AccessClosureRecord
ReconcileRuns []sqlite.ReconcileRun
ProviderStatus string
LatestAccessStatus string
LatestReconcileStatus string
LatestReconcileSummary map[string]any
}
type ProviderStatusService struct {
store *sqlite.DB
}
func NewProviderStatusService(store *sqlite.DB) *ProviderStatusService {
return &ProviderStatusService{store: store}
}
func (s *ProviderStatusService) GetStatus(ctx context.Context, query ProviderQuery) (ProviderSnapshot, error) {
return s.snapshot(ctx, query)
}
func (s *ProviderStatusService) GetResources(ctx context.Context, query ProviderQuery) (ProviderSnapshot, error) {
return s.snapshot(ctx, query)
}
func (s *ProviderStatusService) snapshot(ctx context.Context, query ProviderQuery) (ProviderSnapshot, error) {
if s == nil || s.store == nil {
return ProviderSnapshot{}, fmt.Errorf("store is required")
}
provider, err := s.resolveProvider(ctx, query)
if err != nil {
return ProviderSnapshot{}, err
}
packRow, err := s.store.Packs().GetByID(ctx, provider.PackID)
if err != nil {
return ProviderSnapshot{}, err
}
batchRow, err := s.store.ImportBatches().GetLatestByProviderID(ctx, provider.ID)
if err != nil {
return ProviderSnapshot{}, err
}
hostRow, err := s.store.Hosts().GetByID(ctx, batchRow.HostID)
if err != nil {
return ProviderSnapshot{}, err
}
managedResources, err := s.store.ManagedResources().GetByBatchID(ctx, batchRow.ID)
if err != nil {
return ProviderSnapshot{}, err
}
accessClosures, err := s.store.AccessClosures().GetByBatchID(ctx, batchRow.ID)
if err != nil {
return ProviderSnapshot{}, err
}
reconcileRuns, err := s.store.ReconcileRuns().GetByProviderID(ctx, provider.ID)
if err != nil {
return ProviderSnapshot{}, err
}
latestAccessStatus := batchRow.AccessStatus
if len(accessClosures) > 0 {
latestAccessStatus = firstNonEmpty(accessClosures[len(accessClosures)-1].Status, latestAccessStatus)
}
latestReconcileStatus := "not_run"
latestReconcileSummary := map[string]any{}
if len(reconcileRuns) > 0 {
latestReconcileStatus = firstNonEmpty(reconcileRuns[0].Status, latestReconcileStatus)
if strings.TrimSpace(reconcileRuns[0].SummaryJSON) != "" {
if err := json.Unmarshal([]byte(reconcileRuns[0].SummaryJSON), &latestReconcileSummary); err != nil {
return ProviderSnapshot{}, fmt.Errorf("decode reconcile summary: %w", err)
}
}
}
providerStatus := deriveProviderStatus(batchRow.BatchStatus, latestReconcileStatus)
return ProviderSnapshot{
Host: hostRow,
Pack: packRow,
Provider: provider,
Batch: batchRow,
ManagedResources: managedResources,
AccessClosures: accessClosures,
ReconcileRuns: reconcileRuns,
ProviderStatus: providerStatus,
LatestAccessStatus: latestAccessStatus,
LatestReconcileStatus: latestReconcileStatus,
LatestReconcileSummary: latestReconcileSummary,
}, nil
}
func (s *ProviderStatusService) resolveProvider(ctx context.Context, query ProviderQuery) (sqlite.Provider, error) {
providerID := strings.TrimSpace(query.ProviderID)
packID := strings.TrimSpace(query.PackID)
if providerID == "" {
return sqlite.Provider{}, fmt.Errorf("provider_id is required")
}
if packID != "" {
packRow, err := s.store.Packs().GetByPackID(ctx, packID)
if err != nil {
return sqlite.Provider{}, err
}
return s.store.Providers().GetByPackIDAndProviderID(ctx, packRow.ID, providerID)
}
providers, err := s.store.Providers().ListByProviderID(ctx, providerID)
if err != nil {
return sqlite.Provider{}, err
}
if len(providers) == 0 {
return sqlite.Provider{}, fmt.Errorf("provider %q not found", providerID)
}
if len(providers) > 1 {
return sqlite.Provider{}, fmt.Errorf("provider %q exists in multiple packs; pack_id is required", providerID)
}
return providers[0], nil
}
func deriveProviderStatus(batchStatus, reconcileStatus string) string {
reconcileStatus = strings.TrimSpace(reconcileStatus)
if reconcileStatus != "" && reconcileStatus != "not_run" {
return reconcileStatus
}
switch strings.TrimSpace(batchStatus) {
case BatchStatusSucceeded:
return ProviderStatusActive
case BatchStatusFailed:
return ProviderStatusFailed
case "running":
return "running"
default:
return firstNonEmpty(batchStatus, "unknown")
}
}

View File

@@ -0,0 +1,98 @@
package provision
import (
"context"
"testing"
"sub2api-cn-relay-manager/internal/store/sqlite"
)
func TestProviderStatusServiceReturnsLatestSnapshot(t *testing.T) {
store := openProvisionTestStore(t)
defer closeProvisionTestStore(t, store)
ctx := context.Background()
hostID, err := store.Hosts().Create(ctx, sqlite.Host{HostID: "host-1", BaseURL: "https://sub2api.example.com", HostVersion: "0.1.126", CapabilityProbeJSON: `{"supports_batch_accounts":true}`})
if err != nil {
t.Fatalf("Hosts().Create() error = %v", err)
}
packID, err := store.Packs().Create(ctx, sqlite.Pack{PackID: "openai-cn-pack", Version: "1.0.0", Checksum: "checksum-1"})
if err != nil {
t.Fatalf("Packs().Create() error = %v", err)
}
providerID, err := store.Providers().Create(ctx, sqlite.Provider{PackID: packID, ProviderID: "deepseek", DisplayName: "DeepSeek", BaseURL: "https://api.deepseek.com", Platform: "openai"})
if err != nil {
t.Fatalf("Providers().Create() error = %v", err)
}
batchID, err := store.ImportBatches().Create(ctx, sqlite.ImportBatch{HostID: hostID, PackID: packID, ProviderID: providerID, Mode: ImportModeStrict, BatchStatus: BatchStatusSucceeded, AccessStatus: AccessStatusSelfServiceReady})
if err != nil {
t.Fatalf("ImportBatches().Create() error = %v", err)
}
if _, err := store.ManagedResources().Create(ctx, sqlite.ManagedResource{BatchID: batchID, ResourceType: "group", HostResourceID: "group-1", ResourceName: "deepseek-group"}); err != nil {
t.Fatalf("ManagedResources().Create(group) error = %v", err)
}
if _, err := store.ManagedResources().Create(ctx, sqlite.ManagedResource{BatchID: batchID, ResourceType: "account", HostResourceID: "account-1", ResourceName: "deepseek-account-1"}); err != nil {
t.Fatalf("ManagedResources().Create(account) error = %v", err)
}
if _, err := store.AccessClosures().Create(ctx, sqlite.AccessClosureRecord{BatchID: batchID, ClosureType: AccessModeSelfService, Status: AccessStatusSelfServiceReady, DetailsJSON: `{"ok":true}`}); err != nil {
t.Fatalf("AccessClosures().Create() error = %v", err)
}
if _, err := store.ReconcileRuns().Create(ctx, sqlite.ReconcileRun{ProviderID: providerID, Status: "drifted", SummaryJSON: `{"missing_count":1}`}); err != nil {
t.Fatalf("ReconcileRuns().Create() error = %v", err)
}
snapshot, err := NewProviderStatusService(store).GetStatus(ctx, ProviderQuery{ProviderID: "deepseek"})
if err != nil {
t.Fatalf("GetStatus() error = %v", err)
}
if snapshot.Host.HostID != "host-1" {
t.Fatalf("Host.HostID = %q, want host-1", snapshot.Host.HostID)
}
if snapshot.Pack.PackID != "openai-cn-pack" {
t.Fatalf("Pack.PackID = %q, want openai-cn-pack", snapshot.Pack.PackID)
}
if snapshot.Provider.ProviderID != "deepseek" {
t.Fatalf("Provider.ProviderID = %q, want deepseek", snapshot.Provider.ProviderID)
}
if snapshot.ProviderStatus != "drifted" {
t.Fatalf("ProviderStatus = %q, want drifted", snapshot.ProviderStatus)
}
if snapshot.LatestAccessStatus != AccessStatusSelfServiceReady {
t.Fatalf("LatestAccessStatus = %q, want %q", snapshot.LatestAccessStatus, AccessStatusSelfServiceReady)
}
if snapshot.LatestReconcileStatus != "drifted" {
t.Fatalf("LatestReconcileStatus = %q, want drifted", snapshot.LatestReconcileStatus)
}
if got := len(snapshot.ManagedResources); got != 2 {
t.Fatalf("len(ManagedResources) = %d, want 2", got)
}
if got, ok := snapshot.LatestReconcileSummary["missing_count"].(float64); !ok || got != 1 {
t.Fatalf("LatestReconcileSummary[missing_count] = %#v, want 1", snapshot.LatestReconcileSummary["missing_count"])
}
}
func TestProviderStatusServiceRequiresPackIDWhenProviderIDIsAmbiguous(t *testing.T) {
store := openProvisionTestStore(t)
defer closeProvisionTestStore(t, store)
ctx := context.Background()
pack1, err := store.Packs().Create(ctx, sqlite.Pack{PackID: "pack-a", Version: "1.0.0", Checksum: "checksum-a"})
if err != nil {
t.Fatalf("Packs().Create(pack-a) error = %v", err)
}
pack2, err := store.Packs().Create(ctx, sqlite.Pack{PackID: "pack-b", Version: "1.0.0", Checksum: "checksum-b"})
if err != nil {
t.Fatalf("Packs().Create(pack-b) error = %v", err)
}
if _, err := store.Providers().Create(ctx, sqlite.Provider{PackID: pack1, ProviderID: "deepseek", DisplayName: "DeepSeek A", BaseURL: "https://a.example.com", Platform: "openai"}); err != nil {
t.Fatalf("Providers().Create(pack-a) error = %v", err)
}
if _, err := store.Providers().Create(ctx, sqlite.Provider{PackID: pack2, ProviderID: "deepseek", DisplayName: "DeepSeek B", BaseURL: "https://b.example.com", Platform: "openai"}); err != nil {
t.Fatalf("Providers().Create(pack-b) error = %v", err)
}
_, err = NewProviderStatusService(store).GetStatus(ctx, ProviderQuery{ProviderID: "deepseek"})
if err == nil || err.Error() != `provider "deepseek" exists in multiple packs; pack_id is required` {
t.Fatalf("GetStatus() error = %v, want ambiguous provider error", err)
}
}

View File

@@ -0,0 +1,187 @@
package provision
import (
"context"
"testing"
"sub2api-cn-relay-manager/internal/host/sub2api"
"sub2api-cn-relay-manager/internal/pack"
"sub2api-cn-relay-manager/internal/store/sqlite"
)
func TestReconcileServiceReturnsActiveAfterProbeRerun(t *testing.T) {
store := openProvisionTestStore(t)
defer closeProvisionTestStore(t, store)
host := &fakeHostAdapter{
batchAccounts: []sub2api.AccountRef{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
testResults: map[string]sub2api.ProbeResult{
"account_1": {OK: true, Status: "passed"},
"account_2": {OK: true, Status: "passed"},
},
models: map[string][]sub2api.AccountModel{
"account_1": {{ID: "deepseek-chat"}},
"account_2": {{ID: "deepseek-chat"}},
},
gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}},
}
batchID := seedRuntimeImportForReconcile(t, store, host)
host.managedSnapshot = sub2api.ManagedResourceSnapshot{
Groups: []sub2api.NamedResource{{ID: "group_1", Name: "g"}},
Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "c"}},
Accounts: []sub2api.NamedResource{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
}
result, err := NewReconcileService(store, host).Reconcile(context.Background(), ReconcileRequest{
HostBaseURL: "https://sub2api.example.com",
AccessProbeAPIKey: "user-key",
Pack: pack.LoadedPack{Manifest: pack.Manifest{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api", MinHostVersion: "0.1.126", MaxHostVersion: "0.2.x"}},
Provider: sampleProviderManifest(),
})
if err != nil {
t.Fatalf("Reconcile() error = %v", err)
}
if result.BatchID != batchID {
t.Fatalf("BatchID = %d, want %d", result.BatchID, batchID)
}
if result.Status != "active" {
t.Fatalf("Status = %q, want active", result.Status)
}
if result.ProbeFailureCount != 0 {
t.Fatalf("ProbeFailureCount = %d, want 0", result.ProbeFailureCount)
}
if result.AccessStatus != AccessStatusSelfServiceReady {
t.Fatalf("AccessStatus = %q, want %q", result.AccessStatus, AccessStatusSelfServiceReady)
}
if got := queryCount(t, store.SQLDB(), "probe_results"); got != 4 {
t.Fatalf("probe_results row count = %d, want 4 after rerun", got)
}
if got := queryCount(t, store.SQLDB(), "access_closure_records"); got != 2 {
t.Fatalf("access_closure_records row count = %d, want 2 after rerun", got)
}
}
func TestReconcileServiceReturnsDegradedWhenProbeRerunFails(t *testing.T) {
store := openProvisionTestStore(t)
defer closeProvisionTestStore(t, store)
host := &fakeHostAdapter{
batchAccounts: []sub2api.AccountRef{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
testResults: map[string]sub2api.ProbeResult{
"account_1": {OK: true, Status: "passed"},
"account_2": {OK: false, Status: "failed", Message: "bad key"},
},
models: map[string][]sub2api.AccountModel{
"account_1": {{ID: "deepseek-chat"}},
"account_2": {{ID: "deepseek-chat"}},
},
gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}},
}
seedRuntimeImportForReconcile(t, store, host)
host.managedSnapshot = sub2api.ManagedResourceSnapshot{
Groups: []sub2api.NamedResource{{ID: "group_1", Name: "g"}},
Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "c"}},
Accounts: []sub2api.NamedResource{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
}
result, err := NewReconcileService(store, host).Reconcile(context.Background(), ReconcileRequest{
HostBaseURL: "https://sub2api.example.com",
AccessProbeAPIKey: "user-key",
Pack: pack.LoadedPack{Manifest: pack.Manifest{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api", MinHostVersion: "0.1.126", MaxHostVersion: "0.2.x"}},
Provider: sampleProviderManifest(),
})
if err != nil {
t.Fatalf("Reconcile() error = %v", err)
}
if result.Status != "degraded" {
t.Fatalf("Status = %q, want degraded", result.Status)
}
if result.ProbeFailureCount != 1 {
t.Fatalf("ProbeFailureCount = %d, want 1", result.ProbeFailureCount)
}
}
func TestReconcileServiceReturnsDriftedWhenManagedResourceMissing(t *testing.T) {
store := openProvisionTestStore(t)
defer closeProvisionTestStore(t, store)
host := &fakeHostAdapter{
batchAccounts: []sub2api.AccountRef{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
testResults: map[string]sub2api.ProbeResult{
"account_1": {OK: true, Status: "passed"},
"account_2": {OK: true, Status: "passed"},
},
models: map[string][]sub2api.AccountModel{
"account_1": {{ID: "deepseek-chat"}},
"account_2": {{ID: "deepseek-chat"}},
},
gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}},
}
seedRuntimeImportForReconcile(t, store, host)
host.managedSnapshot = sub2api.ManagedResourceSnapshot{
Groups: []sub2api.NamedResource{{ID: "group_1", Name: "g"}},
Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "c"}},
Accounts: []sub2api.NamedResource{{ID: "account_1", Name: "deepseek-01"}},
}
result, err := NewReconcileService(store, host).Reconcile(context.Background(), ReconcileRequest{
HostBaseURL: "https://sub2api.example.com",
AccessProbeAPIKey: "user-key",
Pack: pack.LoadedPack{Manifest: pack.Manifest{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api", MinHostVersion: "0.1.126", MaxHostVersion: "0.2.x"}},
Provider: sampleProviderManifest(),
})
if err != nil {
t.Fatalf("Reconcile() error = %v", err)
}
if result.Status != "drifted" {
t.Fatalf("Status = %q, want drifted", result.Status)
}
if result.MissingCount != 1 {
t.Fatalf("MissingCount = %d, want 1", result.MissingCount)
}
}
func TestDeriveHealthyAccessStatus(t *testing.T) {
tests := []struct {
name string
closureType string
want string
}{
{name: "subscription", closureType: AccessModeSubscription, want: AccessStatusSubscriptionReady},
{name: "self-service", closureType: AccessModeSelfService, want: AccessStatusSelfServiceReady},
{name: "unknown", closureType: "other", want: "unknown"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := deriveHealthyAccessStatus(tc.closureType); got != tc.want {
t.Fatalf("deriveHealthyAccessStatus(%q) = %q, want %q", tc.closureType, got, tc.want)
}
})
}
}
func seedRuntimeImportForReconcile(t *testing.T, store *sqlite.DB, host *fakeHostAdapter) int64 {
t.Helper()
result, err := NewRuntimeImportService(store, host).Import(context.Background(), RuntimeImportRequest{
HostID: "host-1",
HostBaseURL: "https://sub2api.example.com",
Pack: pack.LoadedPack{
Manifest: pack.Manifest{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api", MinHostVersion: "0.1.126", MaxHostVersion: "0.2.x"},
Checksum: "checksum-1",
},
Provider: sampleProviderManifest(),
Mode: ImportModePartial,
Keys: []string{"key-1", "key-2"},
Access: AccessRequest{
Mode: AccessModeSelfService,
ProbeAPIKey: "user-key",
},
})
if err != nil {
t.Fatalf("seed RuntimeImportService.Import() error = %v", err)
}
return result.BatchID
}

View File

@@ -0,0 +1,90 @@
package provision
import (
"context"
"errors"
"fmt"
"sub2api-cn-relay-manager/internal/host/sub2api"
"sub2api-cn-relay-manager/internal/pack"
)
type rollbackHost interface {
ListManagedResources(ctx context.Context, req sub2api.ListManagedResourcesRequest) (sub2api.ManagedResourceSnapshot, error)
DeleteAccount(ctx context.Context, accountID string) error
DeletePlan(ctx context.Context, planID string) error
DeleteChannel(ctx context.Context, channelID string) error
DeleteGroup(ctx context.Context, groupID string) error
}
type RollbackRequest struct {
Provider pack.ProviderManifest
}
type RollbackReport struct {
AccountsDeleted int
PlansDeleted int
ChannelsDeleted int
GroupsDeleted int
}
type RollbackService struct {
host rollbackHost
}
func NewRollbackService(host rollbackHost) *RollbackService {
return &RollbackService{host: host}
}
func (s *RollbackService) Rollback(ctx context.Context, req RollbackRequest) (RollbackReport, error) {
if s.host == nil {
return RollbackReport{}, fmt.Errorf("rollback host is required")
}
names := SuggestResourceNames(req.Provider)
snapshot, err := s.host.ListManagedResources(ctx, sub2api.ListManagedResourcesRequest{
GroupName: names.Group,
ChannelName: names.Channel,
PlanName: names.Plan,
AccountNamePrefix: SuggestAccountNamePrefix(req.Provider),
})
if err != nil {
return RollbackReport{}, fmt.Errorf("list managed resources: %w", err)
}
var report RollbackReport
var errs []error
for index := len(snapshot.Accounts) - 1; index >= 0; index-- {
if err := s.host.DeleteAccount(ctx, snapshot.Accounts[index].ID); err != nil {
errs = append(errs, fmt.Errorf("delete account %s: %w", snapshot.Accounts[index].ID, err))
continue
}
report.AccountsDeleted++
}
for index := len(snapshot.Plans) - 1; index >= 0; index-- {
if err := s.host.DeletePlan(ctx, snapshot.Plans[index].ID); err != nil {
errs = append(errs, fmt.Errorf("delete plan %s: %w", snapshot.Plans[index].ID, err))
continue
}
report.PlansDeleted++
}
for index := len(snapshot.Channels) - 1; index >= 0; index-- {
if err := s.host.DeleteChannel(ctx, snapshot.Channels[index].ID); err != nil {
errs = append(errs, fmt.Errorf("delete channel %s: %w", snapshot.Channels[index].ID, err))
continue
}
report.ChannelsDeleted++
}
for index := len(snapshot.Groups) - 1; index >= 0; index-- {
if err := s.host.DeleteGroup(ctx, snapshot.Groups[index].ID); err != nil {
errs = append(errs, fmt.Errorf("delete group %s: %w", snapshot.Groups[index].ID, err))
continue
}
report.GroupsDeleted++
}
if len(errs) > 0 {
return report, errors.Join(errs...)
}
return report, nil
}

View File

@@ -0,0 +1,50 @@
package provision
import (
"context"
"reflect"
"testing"
"sub2api-cn-relay-manager/internal/host/sub2api"
)
func TestRollbackServiceDeletesManagedResourcesInDependencyOrder(t *testing.T) {
host := &fakeHostAdapter{
managedSnapshot: sub2api.ManagedResourceSnapshot{
Groups: []sub2api.NamedResource{{ID: "group_1", Name: "crm-deepseek-group"}},
Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "crm-deepseek-channel"}},
Plans: []sub2api.NamedResource{{ID: "plan_1", Name: "crm-deepseek-plan"}},
Accounts: []sub2api.NamedResource{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
},
}
svc := NewRollbackService(host)
report, err := svc.Rollback(context.Background(), RollbackRequest{Provider: sampleProviderManifest()})
if err != nil {
t.Fatalf("Rollback() error = %v", err)
}
if report.AccountsDeleted != 2 || report.PlansDeleted != 1 || report.ChannelsDeleted != 1 || report.GroupsDeleted != 1 {
t.Fatalf("Rollback() report = %+v, want all managed resources deleted", report)
}
want := []string{"account:account_2", "account:account_1", "plan:plan_1", "channel:channel_1", "group:group_1"}
if !reflect.DeepEqual(host.deletedResources, want) {
t.Fatalf("deleted resources = %#v, want %#v", host.deletedResources, want)
}
}
func TestRollbackServiceReturnsEmptyReportWhenNoManagedResourcesExist(t *testing.T) {
host := &fakeHostAdapter{}
svc := NewRollbackService(host)
report, err := svc.Rollback(context.Background(), RollbackRequest{Provider: sampleProviderManifest()})
if err != nil {
t.Fatalf("Rollback() error = %v", err)
}
if report.AccountsDeleted != 0 || report.PlansDeleted != 0 || report.ChannelsDeleted != 0 || report.GroupsDeleted != 0 {
t.Fatalf("Rollback() report = %+v, want zero deletions", report)
}
if len(host.deletedResources) != 0 {
t.Fatalf("deleted resources = %#v, want none", host.deletedResources)
}
}

View File

@@ -0,0 +1,259 @@
package provision
import (
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"strings"
"sub2api-cn-relay-manager/internal/pack"
"sub2api-cn-relay-manager/internal/store/sqlite"
)
type RuntimeImportRequest struct {
HostID string
HostBaseURL string
Pack pack.LoadedPack
Provider pack.ProviderManifest
Mode string
Access AccessRequest
Keys []string
}
type RuntimeImportResult struct {
BatchID int64
Report ImportReport
}
type RuntimeImportService struct {
store *sqlite.DB
host hostAdapter
}
func NewRuntimeImportService(store *sqlite.DB, host hostAdapter) *RuntimeImportService {
return &RuntimeImportService{store: store, host: host}
}
func (s *RuntimeImportService) Import(ctx context.Context, req RuntimeImportRequest) (RuntimeImportResult, error) {
if s == nil || s.store == nil {
return RuntimeImportResult{}, fmt.Errorf("store is required")
}
if s.host == nil {
return RuntimeImportResult{}, fmt.Errorf("host adapter is required")
}
req.HostBaseURL = strings.TrimSpace(req.HostBaseURL)
if req.HostBaseURL == "" {
return RuntimeImportResult{}, fmt.Errorf("host_base_url is required")
}
if strings.TrimSpace(req.HostID) == "" {
req.HostID = req.HostBaseURL
}
hostVersion, err := s.host.GetHostVersion(ctx)
if err != nil {
return RuntimeImportResult{}, fmt.Errorf("get host version: %w", err)
}
if err := pack.CheckHostCompatibility(req.Pack.Manifest, hostVersion); err != nil {
return RuntimeImportResult{}, err
}
capabilities, err := s.host.ProbeCapabilities(ctx)
if err != nil {
return RuntimeImportResult{}, fmt.Errorf("probe host capabilities: %w", err)
}
capabilityProbeJSON, err := json.Marshal(capabilities)
if err != nil {
return RuntimeImportResult{}, fmt.Errorf("marshal host capabilities: %w", err)
}
hostRow, err := s.ensureHost(ctx, req.HostID, req.HostBaseURL, hostVersion, string(capabilityProbeJSON))
if err != nil {
return RuntimeImportResult{}, err
}
packRow, err := s.ensurePack(ctx, req.Pack)
if err != nil {
return RuntimeImportResult{}, err
}
providerRow, err := s.ensureProvider(ctx, packRow.ID, req.Provider)
if err != nil {
return RuntimeImportResult{}, err
}
batchID, err := s.store.ImportBatches().Create(ctx, sqlite.ImportBatch{
HostID: hostRow.ID,
PackID: packRow.ID,
ProviderID: providerRow.ID,
Mode: req.Mode,
BatchStatus: "running",
AccessStatus: "pending",
})
if err != nil {
return RuntimeImportResult{}, err
}
report, importErr := NewImportService(s.host).Import(ctx, ImportRequest{
Provider: req.Provider,
Mode: req.Mode,
Access: req.Access,
Keys: req.Keys,
})
if report.BatchStatus == "" {
report.BatchStatus = BatchStatusFailed
}
if report.AccessStatus == "" {
report.AccessStatus = AccessStatusBroken
}
if persistErr := s.persistRuntimeArtifacts(ctx, batchID, req.Access.Mode, report, importErr == nil); persistErr != nil {
return RuntimeImportResult{}, persistErr
}
if err := s.store.ImportBatches().UpdateStatus(ctx, batchID, report.BatchStatus, report.AccessStatus); err != nil {
return RuntimeImportResult{}, err
}
if importErr != nil {
return RuntimeImportResult{BatchID: batchID, Report: report}, importErr
}
return RuntimeImportResult{BatchID: batchID, Report: report}, nil
}
func (s *RuntimeImportService) ensureHost(ctx context.Context, hostID, baseURL, hostVersion, capabilityProbeJSON string) (sqlite.Host, error) {
host, err := s.store.Hosts().GetByHostID(ctx, hostID)
if err == nil {
return host, nil
}
if _, createErr := s.store.Hosts().Create(ctx, sqlite.Host{
HostID: hostID,
BaseURL: baseURL,
HostVersion: hostVersion,
CapabilityProbeJSON: capabilityProbeJSON,
}); createErr != nil {
return sqlite.Host{}, createErr
}
return s.store.Hosts().GetByHostID(ctx, hostID)
}
func (s *RuntimeImportService) ensurePack(ctx context.Context, loaded pack.LoadedPack) (sqlite.Pack, error) {
packRow, err := s.store.Packs().GetByPackID(ctx, loaded.Manifest.PackID)
if err == nil {
if err := validateExistingPack(packRow, loaded); err != nil {
return sqlite.Pack{}, err
}
}
packRecord, err := buildPackRecord(loaded)
if err != nil {
return sqlite.Pack{}, err
}
if _, err := s.store.Packs().Upsert(ctx, packRecord); err != nil {
return sqlite.Pack{}, err
}
return s.store.Packs().GetByPackID(ctx, loaded.Manifest.PackID)
}
func (s *RuntimeImportService) ensureProvider(ctx context.Context, packID int64, provider pack.ProviderManifest) (sqlite.Provider, error) {
if _, err := s.store.Providers().GetByPackIDAndProviderID(ctx, packID, provider.ProviderID); err == nil {
// continue into upsert path so metadata stays fresh.
}
providerRecord, err := buildProviderRecord(packID, provider)
if err != nil {
return sqlite.Provider{}, err
}
if _, err := s.store.Providers().Upsert(ctx, providerRecord); err != nil {
return sqlite.Provider{}, err
}
return s.store.Providers().GetByPackIDAndProviderID(ctx, packID, provider.ProviderID)
}
func (s *RuntimeImportService) persistRuntimeArtifacts(ctx context.Context, batchID int64, accessMode string, report ImportReport, includeManagedResources bool) error {
for i, account := range report.Accounts {
payload, err := json.Marshal(map[string]any{
"account_id": account.Ref.ID,
"probe_ok": account.Probe.OK,
"probe_status": account.Probe.Status,
"probe_message": account.Probe.Message,
"models": account.Models,
"smoke_model_seen": account.SmokeModelSeen,
})
if err != nil {
return fmt.Errorf("marshal account probe summary: %w", err)
}
itemID, err := s.store.ImportBatchItems().Create(ctx, sqlite.ImportBatchItem{
BatchID: batchID,
KeyFingerprint: fingerprintKey(report.AcceptedKeys, i),
AccountStatus: firstNonEmpty(account.Probe.Status, "unknown"),
ProbeSummaryJSON: string(payload),
})
if err != nil {
return err
}
if _, err := s.store.ProbeResults().Create(ctx, sqlite.ProbeResult{
BatchItemID: itemID,
ProbeType: "account_smoke",
Status: firstNonEmpty(account.Probe.Status, "unknown"),
SummaryJSON: string(payload),
}); err != nil {
return err
}
}
if includeManagedResources {
if report.Group.ID != "" {
if _, err := s.store.ManagedResources().Create(ctx, sqlite.ManagedResource{BatchID: batchID, ResourceType: "group", HostResourceID: report.Group.ID, ResourceName: firstNonEmpty(report.Group.Name, report.Group.ID)}); err != nil {
return err
}
}
if report.Channel.ID != "" {
if _, err := s.store.ManagedResources().Create(ctx, sqlite.ManagedResource{BatchID: batchID, ResourceType: "channel", HostResourceID: report.Channel.ID, ResourceName: firstNonEmpty(report.Channel.Name, report.Channel.ID)}); err != nil {
return err
}
}
if report.Plan != nil && report.Plan.ID != "" {
if _, err := s.store.ManagedResources().Create(ctx, sqlite.ManagedResource{BatchID: batchID, ResourceType: "plan", HostResourceID: report.Plan.ID, ResourceName: firstNonEmpty(report.Plan.Name, report.Plan.ID)}); err != nil {
return err
}
}
for _, account := range report.Accounts {
if _, err := s.store.ManagedResources().Create(ctx, sqlite.ManagedResource{BatchID: batchID, ResourceType: "account", HostResourceID: account.Ref.ID, ResourceName: firstNonEmpty(account.Ref.Name, account.Ref.ID)}); err != nil {
return err
}
}
}
accessPayload, err := json.Marshal(map[string]any{
"status_code": report.Gateway.StatusCode,
"ok": report.Gateway.OK,
"has_expected_model": report.Gateway.HasExpectedModel,
"models": report.Gateway.Models,
})
if err != nil {
return fmt.Errorf("marshal gateway access summary: %w", err)
}
if _, err := s.store.AccessClosures().Create(ctx, sqlite.AccessClosureRecord{
BatchID: batchID,
ClosureType: firstNonEmpty(strings.TrimSpace(accessMode), "unknown"),
Status: firstNonEmpty(report.AccessStatus, AccessStatusBroken),
DetailsJSON: string(accessPayload),
}); err != nil {
return err
}
return nil
}
func fingerprintKey(keys []string, index int) string {
if index >= 0 && index < len(keys) {
key := strings.TrimSpace(keys[index])
if key != "" {
sum := sha256.Sum256([]byte(key))
return fmt.Sprintf("sha256:%x", sum[:])
}
}
return fmt.Sprintf("key-%d", index+1)
}
func firstNonEmpty(values ...string) string {
for _, value := range values {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed
}
}
return ""
}

View File

@@ -0,0 +1,196 @@
package provision
import (
"context"
"database/sql"
"fmt"
"path/filepath"
"testing"
_ "modernc.org/sqlite"
"sub2api-cn-relay-manager/internal/host/sub2api"
"sub2api-cn-relay-manager/internal/pack"
"sub2api-cn-relay-manager/internal/store/sqlite"
)
func TestRuntimeImportServicePersistsOperationalState(t *testing.T) {
store := openProvisionTestStore(t)
defer closeProvisionTestStore(t, store)
host := &fakeHostAdapter{
batchAccounts: []sub2api.AccountRef{{ID: "account_1"}, {ID: "account_2"}},
testResults: map[string]sub2api.ProbeResult{
"account_1": {OK: true, Status: "passed"},
"account_2": {OK: true, Status: "passed"},
},
models: map[string][]sub2api.AccountModel{
"account_1": {{ID: "deepseek-chat"}},
"account_2": {{ID: "deepseek-chat"}},
},
gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}},
}
svc := NewRuntimeImportService(store, host)
result, err := svc.Import(context.Background(), RuntimeImportRequest{
HostID: "host-1",
HostBaseURL: "https://sub2api.example.com",
Pack: pack.LoadedPack{
Manifest: pack.Manifest{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api", MinHostVersion: "0.1.126", MaxHostVersion: "0.2.x"},
Checksum: "checksum-1",
},
Provider: sampleProviderManifest(),
Mode: ImportModePartial,
Keys: []string{" key-1 ", "key-2", "key-1"},
Access: AccessRequest{
Mode: AccessModeSelfService,
ProbeAPIKey: "user-key",
},
})
if err != nil {
t.Fatalf("RuntimeImportService.Import() error = %v", err)
}
if result.BatchID <= 0 {
t.Fatalf("BatchID = %d, want positive id", result.BatchID)
}
if result.Report.BatchStatus != BatchStatusSucceeded {
t.Fatalf("BatchStatus = %q, want %q", result.Report.BatchStatus, BatchStatusSucceeded)
}
if got := queryCount(t, store.SQLDB(), "hosts"); got != 1 {
t.Fatalf("hosts row count = %d, want 1", got)
}
if got := queryCount(t, store.SQLDB(), "packs"); got != 1 {
t.Fatalf("packs row count = %d, want 1", got)
}
if got := queryCount(t, store.SQLDB(), "providers"); got != 1 {
t.Fatalf("providers row count = %d, want 1", got)
}
if got := queryCount(t, store.SQLDB(), "import_batches"); got != 1 {
t.Fatalf("import_batches row count = %d, want 1", got)
}
if got := queryCount(t, store.SQLDB(), "import_batch_items"); got != 2 {
t.Fatalf("import_batch_items row count = %d, want 2", got)
}
if got := queryCount(t, store.SQLDB(), "managed_resources"); got != 4 {
t.Fatalf("managed_resources row count = %d, want 4", got)
}
if got := queryCount(t, store.SQLDB(), "probe_results"); got != 2 {
t.Fatalf("probe_results row count = %d, want 2", got)
}
if got := queryCount(t, store.SQLDB(), "access_closure_records"); got != 1 {
t.Fatalf("access_closure_records row count = %d, want 1", got)
}
var batchStatus string
var accessStatus string
if err := store.SQLDB().QueryRowContext(context.Background(), "SELECT batch_status, access_status FROM import_batches WHERE id = ?", result.BatchID).Scan(&batchStatus, &accessStatus); err != nil {
t.Fatalf("query import batch state: %v", err)
}
if batchStatus != BatchStatusSucceeded {
t.Fatalf("persisted batch_status = %q, want %q", batchStatus, BatchStatusSucceeded)
}
if accessStatus != AccessStatusSelfServiceReady {
t.Fatalf("persisted access_status = %q, want %q", accessStatus, AccessStatusSelfServiceReady)
}
var fingerprint string
var accountStatus string
if err := store.SQLDB().QueryRowContext(context.Background(), "SELECT key_fingerprint, account_status FROM import_batch_items ORDER BY id LIMIT 1").Scan(&fingerprint, &accountStatus); err != nil {
t.Fatalf("query import batch item: %v", err)
}
if fingerprint == "key-1" || fingerprint == "key-2" || len(fingerprint) < 10 {
t.Fatalf("key_fingerprint = %q, want hashed fingerprint instead of raw key", fingerprint)
}
if accountStatus != "passed" {
t.Fatalf("account_status = %q, want passed", accountStatus)
}
}
func TestRuntimeImportServicePersistsFailedBatchAfterStrictRollback(t *testing.T) {
store := openProvisionTestStore(t)
defer closeProvisionTestStore(t, store)
host := &fakeHostAdapter{
batchAccounts: []sub2api.AccountRef{{ID: "account_1"}, {ID: "account_2"}},
testResults: map[string]sub2api.ProbeResult{
"account_1": {OK: true, Status: "passed"},
"account_2": {OK: false, Status: "failed", Message: "bad key"},
},
models: map[string][]sub2api.AccountModel{
"account_1": {{ID: "deepseek-chat"}},
"account_2": {{ID: "deepseek-chat"}},
},
}
svc := NewRuntimeImportService(store, host)
result, err := svc.Import(context.Background(), RuntimeImportRequest{
HostID: "host-1",
HostBaseURL: "https://sub2api.example.com",
Pack: pack.LoadedPack{
Manifest: pack.Manifest{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api", MinHostVersion: "0.1.126", MaxHostVersion: "0.2.x"},
Checksum: "checksum-1",
},
Provider: sampleProviderManifest(),
Mode: ImportModeStrict,
Keys: []string{"key-1", "key-2"},
Access: AccessRequest{
Mode: AccessModeSelfService,
ProbeAPIKey: "user-key",
},
})
if err == nil {
t.Fatal("RuntimeImportService.Import() error = nil, want strict failure")
}
if result.BatchID <= 0 {
t.Fatalf("BatchID = %d, want positive id", result.BatchID)
}
var batchStatus string
var accessStatus string
if err := store.SQLDB().QueryRowContext(context.Background(), "SELECT batch_status, access_status FROM import_batches WHERE id = ?", result.BatchID).Scan(&batchStatus, &accessStatus); err != nil {
t.Fatalf("query failed import batch state: %v", err)
}
if batchStatus != BatchStatusFailed {
t.Fatalf("persisted batch_status = %q, want %q", batchStatus, BatchStatusFailed)
}
if accessStatus != AccessStatusBroken {
t.Fatalf("persisted access_status = %q, want %q", accessStatus, AccessStatusBroken)
}
if got := queryCount(t, store.SQLDB(), "managed_resources"); got != 0 {
t.Fatalf("managed_resources row count = %d, want 0 after strict rollback", got)
}
if got := queryCount(t, store.SQLDB(), "probe_results"); got != 2 {
t.Fatalf("probe_results row count = %d, want 2", got)
}
if got := queryCount(t, store.SQLDB(), "access_closure_records"); got != 1 {
t.Fatalf("access_closure_records row count = %d, want 1", got)
}
}
func openProvisionTestStore(t *testing.T) *sqlite.DB {
t.Helper()
dbPath := filepath.Join(t.TempDir(), "state.db")
dsn := fmt.Sprintf("file:%s?_busy_timeout=5000&_pragma=foreign_keys(0)", filepath.ToSlash(dbPath))
store, err := sqlite.Open(context.Background(), dsn)
if err != nil {
t.Fatalf("sqlite.Open() error = %v", err)
}
return store
}
func closeProvisionTestStore(t *testing.T, store *sqlite.DB) {
t.Helper()
if err := store.Close(); err != nil {
t.Fatalf("store.Close() error = %v", err)
}
}
func queryCount(t *testing.T, db *sql.DB, table string) int {
t.Helper()
var count int
if err := db.QueryRowContext(context.Background(), "SELECT COUNT(*) FROM "+table).Scan(&count); err != nil {
t.Fatalf("count rows for %s: %v", table, err)
}
return count
}

View File

@@ -0,0 +1,64 @@
CREATE TABLE import_batches (
id INTEGER PRIMARY KEY AUTOINCREMENT,
host_id INTEGER NOT NULL,
pack_id INTEGER NOT NULL,
provider_id INTEGER NOT NULL,
mode TEXT NOT NULL,
batch_status TEXT NOT NULL,
access_status TEXT NOT NULL,
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT fk_import_batches_host FOREIGN KEY (host_id) REFERENCES hosts(id) ON DELETE CASCADE,
CONSTRAINT fk_import_batches_pack FOREIGN KEY (pack_id) REFERENCES packs(id) ON DELETE CASCADE,
CONSTRAINT fk_import_batches_provider FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE
);
CREATE TABLE import_batch_items (
id INTEGER PRIMARY KEY AUTOINCREMENT,
batch_id INTEGER NOT NULL,
key_fingerprint TEXT NOT NULL,
account_status TEXT NOT NULL,
probe_summary_json TEXT NOT NULL DEFAULT '{}',
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT fk_import_batch_items_batch FOREIGN KEY (batch_id) REFERENCES import_batches(id) ON DELETE CASCADE,
CONSTRAINT uq_import_batch_items_batch_fingerprint UNIQUE (batch_id, key_fingerprint)
);
CREATE TABLE managed_resources (
id INTEGER PRIMARY KEY AUTOINCREMENT,
batch_id INTEGER NOT NULL,
resource_type TEXT NOT NULL,
host_resource_id TEXT NOT NULL,
resource_name TEXT NOT NULL,
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT fk_managed_resources_batch FOREIGN KEY (batch_id) REFERENCES import_batches(id) ON DELETE CASCADE,
CONSTRAINT uq_managed_resources_host UNIQUE (resource_type, host_resource_id)
);
CREATE TABLE probe_results (
id INTEGER PRIMARY KEY AUTOINCREMENT,
batch_item_id INTEGER NOT NULL,
probe_type TEXT NOT NULL,
status TEXT NOT NULL,
summary_json TEXT NOT NULL DEFAULT '{}',
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT fk_probe_results_item FOREIGN KEY (batch_item_id) REFERENCES import_batch_items(id) ON DELETE CASCADE
);
CREATE TABLE access_closure_records (
id INTEGER PRIMARY KEY AUTOINCREMENT,
batch_id INTEGER NOT NULL,
closure_type TEXT NOT NULL,
status TEXT NOT NULL,
details_json TEXT NOT NULL DEFAULT '{}',
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT fk_access_closure_records_batch FOREIGN KEY (batch_id) REFERENCES import_batches(id) ON DELETE CASCADE
);
CREATE TABLE reconcile_runs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
provider_id INTEGER NOT NULL,
status TEXT NOT NULL,
summary_json TEXT NOT NULL DEFAULT '{}',
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT fk_reconcile_runs_provider FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE
);

View File

@@ -0,0 +1,14 @@
ALTER TABLE packs ADD COLUMN vendor TEXT NOT NULL DEFAULT '';
ALTER TABLE packs ADD COLUMN target_host TEXT NOT NULL DEFAULT '';
ALTER TABLE packs ADD COLUMN min_host_version TEXT NOT NULL DEFAULT '';
ALTER TABLE packs ADD COLUMN max_host_version TEXT NOT NULL DEFAULT '';
ALTER TABLE packs ADD COLUMN manifest_json TEXT NOT NULL DEFAULT '{}';
ALTER TABLE providers ADD COLUMN account_type TEXT NOT NULL DEFAULT '';
ALTER TABLE providers ADD COLUMN default_models_json TEXT NOT NULL DEFAULT '[]';
ALTER TABLE providers ADD COLUMN smoke_test_model TEXT NOT NULL DEFAULT '';
ALTER TABLE providers ADD COLUMN group_template_json TEXT NOT NULL DEFAULT '{}';
ALTER TABLE providers ADD COLUMN channel_template_json TEXT NOT NULL DEFAULT '{}';
ALTER TABLE providers ADD COLUMN plan_template_json TEXT NOT NULL DEFAULT '{}';
ALTER TABLE providers ADD COLUMN import_options_json TEXT NOT NULL DEFAULT '{}';
ALTER TABLE providers ADD COLUMN manifest_json TEXT NOT NULL DEFAULT '{}';

View File

@@ -0,0 +1,77 @@
package sqlite
import (
"context"
"fmt"
"strings"
)
type AccessClosureRecord struct {
ID int64
BatchID int64
ClosureType string
Status string
DetailsJSON string
}
type AccessClosureRecordsRepo struct {
db execQuerier
}
func newAccessClosureRecordsRepo(db execQuerier) *AccessClosureRecordsRepo {
return &AccessClosureRecordsRepo{db: db}
}
func (r *AccessClosureRecordsRepo) Create(ctx context.Context, record AccessClosureRecord) (int64, error) {
closureType := strings.TrimSpace(record.ClosureType)
status := strings.TrimSpace(record.Status)
detailsJSON := strings.TrimSpace(record.DetailsJSON)
if detailsJSON == "" {
detailsJSON = "{}"
}
switch {
case record.BatchID <= 0:
return 0, fmt.Errorf("batch_id is required")
case closureType == "":
return 0, fmt.Errorf("closure_type is required")
case status == "":
return 0, fmt.Errorf("status is required")
}
result, err := r.db.ExecContext(ctx, `INSERT INTO access_closure_records (batch_id, closure_type, status, details_json) VALUES (?, ?, ?, ?)`, record.BatchID, closureType, status, detailsJSON)
if err != nil {
return 0, fmt.Errorf("insert access closure record: %w", err)
}
id, err := result.LastInsertId()
if err != nil {
return 0, fmt.Errorf("read inserted access closure record id: %w", err)
}
return id, nil
}
func (r *AccessClosureRecordsRepo) GetByBatchID(ctx context.Context, batchID int64) ([]AccessClosureRecord, error) {
if batchID <= 0 {
return nil, fmt.Errorf("batch_id is required")
}
rows, err := r.db.QueryContext(ctx, `SELECT id, batch_id, closure_type, status, details_json FROM access_closure_records WHERE batch_id = ? ORDER BY id`, batchID)
if err != nil {
return nil, fmt.Errorf("query access closure records: %w", err)
}
defer rows.Close()
records := make([]AccessClosureRecord, 0)
for rows.Next() {
var record AccessClosureRecord
if err := rows.Scan(&record.ID, &record.BatchID, &record.ClosureType, &record.Status, &record.DetailsJSON); err != nil {
return nil, fmt.Errorf("scan access closure record: %w", err)
}
records = append(records, record)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate access closure records: %w", err)
}
return records, nil
}

View File

@@ -15,13 +15,20 @@ import (
type execQuerier interface {
ExecContext(context.Context, string, ...any) (sql.Result, error)
QueryContext(context.Context, string, ...any) (*sql.Rows, error)
QueryRowContext(context.Context, string, ...any) *sql.Row
}
type Queries struct {
Hosts *HostsRepo
Packs *PacksRepo
Providers *ProvidersRepo
Hosts *HostsRepo
Packs *PacksRepo
Providers *ProvidersRepo
ImportBatches *ImportBatchesRepo
ImportBatchItems *ImportBatchItemsRepo
ManagedResources *ManagedResourcesRepo
ProbeResults *ProbeResultsRepo
AccessClosures *AccessClosureRecordsRepo
ReconcileRuns *ReconcileRunsRepo
}
type DB struct {
@@ -76,6 +83,30 @@ func (db *DB) Providers() *ProvidersRepo {
return db.queries.Providers
}
func (db *DB) ImportBatches() *ImportBatchesRepo {
return db.queries.ImportBatches
}
func (db *DB) ImportBatchItems() *ImportBatchItemsRepo {
return db.queries.ImportBatchItems
}
func (db *DB) ManagedResources() *ManagedResourcesRepo {
return db.queries.ManagedResources
}
func (db *DB) ProbeResults() *ProbeResultsRepo {
return db.queries.ProbeResults
}
func (db *DB) AccessClosures() *AccessClosureRecordsRepo {
return db.queries.AccessClosures
}
func (db *DB) ReconcileRuns() *ReconcileRunsRepo {
return db.queries.ReconcileRuns
}
func (db *DB) WithTx(ctx context.Context, fn func(*Queries) error) error {
tx, err := db.sqlDB.BeginTx(ctx, nil)
if err != nil {
@@ -101,9 +132,15 @@ func (db *DB) WithTx(ctx context.Context, fn func(*Queries) error) error {
func newQueries(db execQuerier) *Queries {
return &Queries{
Hosts: newHostsRepo(db),
Packs: newPacksRepo(db),
Providers: newProvidersRepo(db),
Hosts: newHostsRepo(db),
Packs: newPacksRepo(db),
Providers: newProvidersRepo(db),
ImportBatches: newImportBatchesRepo(db),
ImportBatchItems: newImportBatchItemsRepo(db),
ManagedResources: newManagedResourcesRepo(db),
ProbeResults: newProbeResultsRepo(db),
AccessClosures: newAccessClosureRecordsRepo(db),
ReconcileRuns: newReconcileRunsRepo(db),
}
}

View File

@@ -0,0 +1,161 @@
package sqlite
import (
"context"
"database/sql"
"errors"
"path/filepath"
"testing"
)
func TestOpenClose(t *testing.T) {
store := openTestDB(t)
if store == nil {
t.Fatal("Open() returned nil")
}
}
func TestOpenInvalidDSN(t *testing.T) {
_, err := Open(context.Background(), "file:/nonexistent/dir/test.db?_pragma=foreign_keys(0)")
if err == nil {
t.Fatal("Open() with invalid dsn error = nil, want error")
}
}
func TestWithTxCommit(t *testing.T) {
store := openTestDB(t)
err := store.WithTx(context.Background(), func(q *Queries) error {
_, err := q.Hosts.Create(context.Background(), Host{
HostID: "tx-host", BaseURL: "https://tx.com", HostVersion: "1.0",
})
return err
})
if err != nil {
t.Fatalf("WithTx() error = %v", err)
}
host, err := store.Hosts().GetByHostID(context.Background(), "tx-host")
if err != nil {
t.Fatalf("GetByHostID() after tx = %v", err)
}
if host.HostID != "tx-host" {
t.Fatalf("host = %+v, want tx-host", host)
}
}
func TestWithTxRollbackOnError(t *testing.T) {
store := openTestDB(t)
err := store.WithTx(context.Background(), func(q *Queries) error {
q.Hosts.Create(context.Background(), Host{
HostID: "rollback-host", BaseURL: "https://r.com", HostVersion: "1.0",
})
return errors.New("rollback")
})
if err == nil {
t.Fatal("WithTx() error = nil, want rollback error")
}
_, err = store.Hosts().GetByHostID(context.Background(), "rollback-host")
if !errors.Is(err, sql.ErrNoRows) {
t.Fatalf("GetByHostID() after rollback error = %v, want sql.ErrNoRows", err)
}
}
func TestTableExists(t *testing.T) {
store := openTestDB(t)
db := store.SQLDB()
found, err := tableExists(context.Background(), db, "hosts")
if err != nil {
t.Fatalf("tableExists('hosts') error = %v", err)
}
if !found {
t.Fatal("tableExists('hosts') = false, want true")
}
found, err = tableExists(context.Background(), db, "nonexistent")
if err != nil {
t.Fatalf("tableExists('nonexistent') error = %v", err)
}
if found {
t.Fatal("tableExists('nonexistent') = true, want false")
}
}
func TestDetectLegacy0001Schema(t *testing.T) {
store := openTestDB(t)
db := store.SQLDB()
tx, err := db.BeginTx(context.Background(), nil)
if err != nil {
t.Fatalf("BeginTx error = %v", err)
}
defer tx.Rollback()
// After migration all three host/packs/providers tables exist,
// so detectLegacy0001Schema reports complete=true.
complete, partial, err := detectLegacy0001Schema(context.Background(), tx)
if err != nil {
t.Fatalf("detectLegacy0001Schema() error = %v", err)
}
if !complete {
t.Fatalf("detectLegacy0001Schema() = (complete=%v, partial=%v), want (true, false)", complete, partial)
}
if partial {
t.Fatal("partial should be false when all 3 tables exist")
}
}
func TestWithForeignKeysEnabled(t *testing.T) {
if got := withForeignKeysEnabled("file:test.db"); got != "file:test.db?_pragma=foreign_keys(1)" {
t.Fatalf("withForeignKeysEnabled no query = %q", got)
}
if got := withForeignKeysEnabled("file:test.db?a=1"); got != "file:test.db?a=1&_pragma=foreign_keys(1)" {
t.Fatalf("withForeignKeysEnabled with query = %q", got)
}
}
func TestSQLDB(t *testing.T) {
store := openTestDB(t)
db := store.SQLDB()
if db == nil {
t.Fatal("SQLDB() returned nil")
}
if err := db.PingContext(context.Background()); err != nil {
t.Fatalf("Ping() error = %v", err)
}
}
func TestMigrationFileNames(t *testing.T) {
names, err := migrationFileNames()
if err != nil {
t.Fatalf("migrationFileNames() error = %v", err)
}
if len(names) == 0 {
t.Fatal("migrationFileNames() returned empty")
}
for _, name := range names {
if filepath.Ext(name) != ".sql" {
t.Fatalf("migrationFileNames() entry %q not .sql file", name)
}
}
}
func TestReadMigration(t *testing.T) {
content, err := readMigration("0001_init.sql")
if err != nil {
t.Fatalf("readMigration('0001_init.sql') error = %v", err)
}
if len(content) == 0 {
t.Fatal("readMigration() returned empty content")
}
}
func TestReadMigrationNotFound(t *testing.T) {
_, err := readMigration("nonexistent.sql")
if err == nil {
t.Fatal("readMigration('nonexistent.sql') error = nil, want error")
}
}

View File

@@ -7,6 +7,7 @@ import (
)
type Host struct {
ID int64
HostID string
BaseURL string
HostVersion string
@@ -21,6 +22,31 @@ func newHostsRepo(db execQuerier) *HostsRepo {
return &HostsRepo{db: db}
}
func (r *HostsRepo) GetByID(ctx context.Context, id int64) (Host, error) {
if id <= 0 {
return Host{}, fmt.Errorf("id is required")
}
var host Host
if err := r.db.QueryRowContext(ctx, `SELECT id, host_id, base_url, host_version, capability_probe_json FROM hosts WHERE id = ?`, id).Scan(&host.ID, &host.HostID, &host.BaseURL, &host.HostVersion, &host.CapabilityProbeJSON); err != nil {
return Host{}, err
}
return host, nil
}
func (r *HostsRepo) GetByHostID(ctx context.Context, hostID string) (Host, error) {
hostID = strings.TrimSpace(hostID)
if hostID == "" {
return Host{}, fmt.Errorf("host_id is required")
}
var host Host
if err := r.db.QueryRowContext(ctx, `SELECT id, host_id, base_url, host_version, capability_probe_json FROM hosts WHERE host_id = ?`, hostID).Scan(&host.ID, &host.HostID, &host.BaseURL, &host.HostVersion, &host.CapabilityProbeJSON); err != nil {
return Host{}, err
}
return host, nil
}
func (r *HostsRepo) Create(ctx context.Context, host Host) (int64, error) {
hostID := strings.TrimSpace(host.HostID)
baseURL := strings.TrimSpace(host.BaseURL)

View File

@@ -0,0 +1,199 @@
package sqlite
import (
"context"
"database/sql"
"errors"
"path/filepath"
"testing"
)
// openTestDB creates a test database with foreign keys disabled.
func openTestDB(t *testing.T) *DB {
t.Helper()
dbPath := filepath.Join(t.TempDir(), "test.db")
dsn := "file:" + filepath.ToSlash(dbPath) + "?_pragma=foreign_keys(0)"
store, err := Open(context.Background(), dsn)
if err != nil {
t.Fatalf("Open() error = %v", err)
}
t.Cleanup(func() { store.Close() })
return store
}
// openTestDBWithFK creates a test database with foreign keys enforced.
func openTestDBWithFK(t *testing.T) *DB {
t.Helper()
dbPath := filepath.Join(t.TempDir(), "test-fk.db")
dsn := "file:" + filepath.ToSlash(dbPath)
store, err := Open(context.Background(), dsn)
if err != nil {
t.Fatalf("Open() error = %v", err)
}
t.Cleanup(func() { store.Close() })
return store
}
func createTestPack(t *testing.T, store *DB) int64 {
t.Helper()
id, err := store.Packs().Create(context.Background(), Pack{
PackID: "pack-" + sanitizeTestName(t.Name()), Version: "1.0.0", Checksum: "chk",
})
if err != nil {
t.Fatalf("createTestPack error = %v", err)
}
return id
}
func createTestHost(t *testing.T, store *DB) int64 {
t.Helper()
id, err := store.Hosts().Create(context.Background(), Host{
HostID: "host-" + sanitizeTestName(t.Name()), BaseURL: "https://h.com", HostVersion: "0.1.0",
})
if err != nil {
t.Fatalf("createTestHost error = %v", err)
}
return id
}
func createTestBatch(t *testing.T, store *DB) int64 {
t.Helper()
hostID := createTestHost(t, store)
packID := createTestPack(t, store)
providerID, err := store.Providers().Create(context.Background(), Provider{
PackID: packID, ProviderID: "test-provider", DisplayName: "Test",
BaseURL: "https://t.com", Platform: "openai",
})
if err != nil {
t.Fatalf("createTestBatch create provider error = %v", err)
}
id, err := store.ImportBatches().Create(context.Background(), ImportBatch{
HostID: hostID, PackID: packID, ProviderID: providerID,
Mode: "partial", BatchStatus: "running", AccessStatus: "pending",
})
if err != nil {
t.Fatalf("createTestBatch error = %v", err)
}
return id
}
func createTestBatchItem(t *testing.T, store *DB, batchID int64) int64 {
t.Helper()
id, err := store.ImportBatchItems().Create(context.Background(), ImportBatchItem{
BatchID: batchID, KeyFingerprint: "sha256:test", AccountStatus: "pending",
})
if err != nil {
t.Fatalf("createTestBatchItem error = %v", err)
}
return id
}
func sanitizeTestName(name string) string {
result := ""
for _, c := range name {
if (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-' {
result += string(c)
}
}
if result == "" {
result = "default"
}
return result
}
// --- Hosts Repo Tests ---
func TestHostsRepoCreateAndGet(t *testing.T) {
store := openTestDB(t)
id, err := store.Hosts().Create(context.Background(), Host{
HostID: "host-1",
BaseURL: "https://sub2api.example.com",
HostVersion: "0.1.126",
CapabilityProbeJSON: `{"groups":true}`,
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
if id <= 0 {
t.Fatalf("Create() id = %d, want positive", id)
}
got, err := store.Hosts().GetByID(context.Background(), id)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if got.HostID != "host-1" || got.BaseURL != "https://sub2api.example.com" {
t.Fatalf("GetByID() = %+v, want host-1", got)
}
got2, err := store.Hosts().GetByHostID(context.Background(), "host-1")
if err != nil {
t.Fatalf("GetByHostID() error = %v", err)
}
if got2.ID != id {
t.Fatalf("GetByHostID() id = %d, want %d", got2.ID, id)
}
}
func TestHostsRepoCreateDefaultsCapabilityProbe(t *testing.T) {
store := openTestDB(t)
id, _ := store.Hosts().Create(context.Background(), Host{
HostID: "host-empty", BaseURL: "https://example.com", HostVersion: "0.1.0",
})
got, _ := store.Hosts().GetByID(context.Background(), id)
if got.CapabilityProbeJSON != "{}" {
t.Fatalf("CapabilityProbeJSON = %q, want {}", got.CapabilityProbeJSON)
}
}
func TestHostsRepoValidationErrors(t *testing.T) {
store := openTestDB(t)
for _, tt := range []struct {
name string
host Host
}{
{"empty host_id", Host{BaseURL: "b", HostVersion: "v"}},
{"empty base_url", Host{HostID: "h", HostVersion: "v"}},
{"empty host_version", Host{HostID: "h", BaseURL: "b"}},
} {
t.Run(tt.name, func(t *testing.T) {
_, err := store.Hosts().Create(context.Background(), tt.host)
if err == nil {
t.Fatal("Create() error = nil, want validation error")
}
})
}
}
func TestHostsRepoGetByIDZeroError(t *testing.T) {
store := openTestDB(t)
_, err := store.Hosts().GetByID(context.Background(), 0)
if err == nil {
t.Fatal("GetByID(0) error = nil, want error")
}
}
func TestHostsRepoGetByIDNotFound(t *testing.T) {
store := openTestDB(t)
_, err := store.Hosts().GetByID(context.Background(), 999)
if !errors.Is(err, sql.ErrNoRows) {
t.Fatalf("GetByID(999) error = %v, want sql.ErrNoRows", err)
}
}
func TestHostsRepoGetByHostIDEmptyError(t *testing.T) {
store := openTestDB(t)
_, err := store.Hosts().GetByHostID(context.Background(), "")
if err == nil {
t.Fatal("GetByHostID('') error = nil, want error")
}
}
func TestHostsRepoGetByHostIDNotFound(t *testing.T) {
store := openTestDB(t)
_, err := store.Hosts().GetByHostID(context.Background(), "nonexistent")
if !errors.Is(err, sql.ErrNoRows) {
t.Fatalf("GetByHostID('nonexistent') error = %v, want sql.ErrNoRows", err)
}
}

View File

@@ -0,0 +1,187 @@
package sqlite
import (
"context"
"fmt"
"strings"
)
type ImportBatch struct {
ID int64
HostID int64
PackID int64
ProviderID int64
Mode string
BatchStatus string
AccessStatus string
}
type ImportBatchItem struct {
ID int64
BatchID int64
KeyFingerprint string
AccountStatus string
ProbeSummaryJSON string
}
type ImportBatchesRepo struct {
db execQuerier
}
type ImportBatchItemsRepo struct {
db execQuerier
}
func newImportBatchesRepo(db execQuerier) *ImportBatchesRepo {
return &ImportBatchesRepo{db: db}
}
func newImportBatchItemsRepo(db execQuerier) *ImportBatchItemsRepo {
return &ImportBatchItemsRepo{db: db}
}
func (r *ImportBatchesRepo) GetByID(ctx context.Context, id int64) (ImportBatch, error) {
if id <= 0 {
return ImportBatch{}, fmt.Errorf("id is required")
}
var batch ImportBatch
if err := r.db.QueryRowContext(ctx, `SELECT id, host_id, pack_id, provider_id, mode, batch_status, access_status FROM import_batches WHERE id = ?`, id).Scan(&batch.ID, &batch.HostID, &batch.PackID, &batch.ProviderID, &batch.Mode, &batch.BatchStatus, &batch.AccessStatus); err != nil {
return ImportBatch{}, err
}
return batch, nil
}
func (r *ImportBatchesRepo) Create(ctx context.Context, batch ImportBatch) (int64, error) {
mode := strings.TrimSpace(batch.Mode)
batchStatus := strings.TrimSpace(batch.BatchStatus)
accessStatus := strings.TrimSpace(batch.AccessStatus)
switch {
case batch.HostID <= 0:
return 0, fmt.Errorf("host_id is required")
case batch.PackID <= 0:
return 0, fmt.Errorf("pack_id is required")
case batch.ProviderID <= 0:
return 0, fmt.Errorf("provider_id is required")
case mode == "":
return 0, fmt.Errorf("mode is required")
case batchStatus == "":
return 0, fmt.Errorf("batch_status is required")
case accessStatus == "":
return 0, fmt.Errorf("access_status is required")
}
result, err := r.db.ExecContext(ctx, `INSERT INTO import_batches (host_id, pack_id, provider_id, mode, batch_status, access_status) VALUES (?, ?, ?, ?, ?, ?)`, batch.HostID, batch.PackID, batch.ProviderID, mode, batchStatus, accessStatus)
if err != nil {
return 0, fmt.Errorf("insert import batch: %w", err)
}
id, err := result.LastInsertId()
if err != nil {
return 0, fmt.Errorf("read inserted import batch id: %w", err)
}
return id, nil
}
func (r *ImportBatchesRepo) UpdateStatus(ctx context.Context, id int64, batchStatus, accessStatus string) error {
if id <= 0 {
return fmt.Errorf("id is required")
}
batchStatus = strings.TrimSpace(batchStatus)
accessStatus = strings.TrimSpace(accessStatus)
if batchStatus == "" {
return fmt.Errorf("batch_status is required")
}
if accessStatus == "" {
return fmt.Errorf("access_status is required")
}
if _, err := r.db.ExecContext(ctx, `UPDATE import_batches SET batch_status = ?, access_status = ? WHERE id = ?`, batchStatus, accessStatus, id); err != nil {
return fmt.Errorf("update import batch %d: %w", id, err)
}
return nil
}
func (r *ImportBatchesRepo) GetLatestByProviderID(ctx context.Context, providerID int64) (ImportBatch, error) {
if providerID <= 0 {
return ImportBatch{}, fmt.Errorf("provider_id is required")
}
var batch ImportBatch
if err := r.db.QueryRowContext(ctx, `SELECT id, host_id, pack_id, provider_id, mode, batch_status, access_status FROM import_batches WHERE provider_id = ? ORDER BY id DESC LIMIT 1`, providerID).Scan(&batch.ID, &batch.HostID, &batch.PackID, &batch.ProviderID, &batch.Mode, &batch.BatchStatus, &batch.AccessStatus); err != nil {
return ImportBatch{}, err
}
return batch, nil
}
func (r *ImportBatchItemsRepo) GetByBatchID(ctx context.Context, batchID int64) ([]ImportBatchItem, error) {
if batchID <= 0 {
return nil, fmt.Errorf("batch_id is required")
}
rows, err := r.db.QueryContext(ctx, `SELECT id, batch_id, key_fingerprint, account_status, probe_summary_json FROM import_batch_items WHERE batch_id = ? ORDER BY id`, batchID)
if err != nil {
return nil, fmt.Errorf("query import batch items: %w", err)
}
defer rows.Close()
items := make([]ImportBatchItem, 0)
for rows.Next() {
var item ImportBatchItem
if err := rows.Scan(&item.ID, &item.BatchID, &item.KeyFingerprint, &item.AccountStatus, &item.ProbeSummaryJSON); err != nil {
return nil, fmt.Errorf("scan import batch item: %w", err)
}
items = append(items, item)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate import batch items: %w", err)
}
return items, nil
}
func (r *ImportBatchItemsRepo) Create(ctx context.Context, item ImportBatchItem) (int64, error) {
keyFingerprint := strings.TrimSpace(item.KeyFingerprint)
accountStatus := strings.TrimSpace(item.AccountStatus)
probeSummaryJSON := strings.TrimSpace(item.ProbeSummaryJSON)
if probeSummaryJSON == "" {
probeSummaryJSON = "{}"
}
switch {
case item.BatchID <= 0:
return 0, fmt.Errorf("batch_id is required")
case keyFingerprint == "":
return 0, fmt.Errorf("key_fingerprint is required")
case accountStatus == "":
return 0, fmt.Errorf("account_status is required")
}
result, err := r.db.ExecContext(ctx, `INSERT INTO import_batch_items (batch_id, key_fingerprint, account_status, probe_summary_json) VALUES (?, ?, ?, ?)`, item.BatchID, keyFingerprint, accountStatus, probeSummaryJSON)
if err != nil {
return 0, fmt.Errorf("insert import batch item: %w", err)
}
id, err := result.LastInsertId()
if err != nil {
return 0, fmt.Errorf("read inserted import batch item id: %w", err)
}
return id, nil
}
func (r *ImportBatchItemsRepo) UpdateResult(ctx context.Context, id int64, accountStatus, probeSummaryJSON string) error {
if id <= 0 {
return fmt.Errorf("id is required")
}
accountStatus = strings.TrimSpace(accountStatus)
probeSummaryJSON = strings.TrimSpace(probeSummaryJSON)
if accountStatus == "" {
return fmt.Errorf("account_status is required")
}
if probeSummaryJSON == "" {
probeSummaryJSON = "{}"
}
if _, err := r.db.ExecContext(ctx, `UPDATE import_batch_items SET account_status = ?, probe_summary_json = ? WHERE id = ?`, accountStatus, probeSummaryJSON, id); err != nil {
return fmt.Errorf("update import batch item %d: %w", id, err)
}
return nil
}

View File

@@ -0,0 +1,429 @@
package sqlite
import (
"context"
"testing"
)
func TestImportBatchesRepoCreateAndGet(t *testing.T) {
store := openTestDB(t)
hostID := createTestHost(t, store)
packID := createTestPack(t, store)
providerID := createTestProviderWithPack(t, store, packID)
id, err := store.ImportBatches().Create(context.Background(), ImportBatch{
HostID: hostID, PackID: packID, ProviderID: providerID,
Mode: "partial", BatchStatus: "running", AccessStatus: "pending",
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
if id <= 0 {
t.Fatalf("Create() id = %d, want positive", id)
}
got, _ := store.ImportBatches().GetByID(context.Background(), id)
if got.Mode != "partial" || got.BatchStatus != "running" {
t.Fatalf("GetByID() = %+v, want running batch", got)
}
}
func TestImportBatchesRepoUpdateStatus(t *testing.T) {
store := openTestDB(t)
hostID := createTestHost(t, store)
packID := createTestPack(t, store)
providerID := createTestProviderWithPack(t, store, packID)
id, _ := store.ImportBatches().Create(context.Background(), ImportBatch{
HostID: hostID, PackID: packID, ProviderID: providerID,
Mode: "partial", BatchStatus: "running", AccessStatus: "pending",
})
err := store.ImportBatches().UpdateStatus(context.Background(), id, "succeeded", "subscription_ready")
if err != nil {
t.Fatalf("UpdateStatus() error = %v", err)
}
got, _ := store.ImportBatches().GetByID(context.Background(), id)
if got.BatchStatus != "succeeded" || got.AccessStatus != "subscription_ready" {
t.Fatalf("status = %+v, want succeeded/subscription_ready", got)
}
}
func TestImportBatchesRepoGetLatestByProviderID(t *testing.T) {
store := openTestDB(t)
hostID := createTestHost(t, store)
packID := createTestPack(t, store)
providerID := createTestProviderWithPack(t, store, packID)
store.ImportBatches().Create(context.Background(), ImportBatch{
HostID: hostID, PackID: packID, ProviderID: providerID,
Mode: "partial", BatchStatus: "running", AccessStatus: "pending",
})
id2, _ := store.ImportBatches().Create(context.Background(), ImportBatch{
HostID: hostID, PackID: packID, ProviderID: providerID,
Mode: "strict", BatchStatus: "succeeded", AccessStatus: "subscription_ready",
})
got, _ := store.ImportBatches().GetLatestByProviderID(context.Background(), providerID)
if got.ID != id2 {
t.Fatalf("latest id = %d, want %d", got.ID, id2)
}
}
func TestImportBatchesRepoGetByIDNotFound(t *testing.T) {
store := openTestDB(t)
_, err := store.ImportBatches().GetByID(context.Background(), 999)
if err == nil {
t.Fatal("GetByID(999) error = nil, want error")
}
}
func TestImportBatchesRepoCreateValidationErrors(t *testing.T) {
store := openTestDB(t)
for _, tt := range []struct {
name string
batch ImportBatch
}{
{"host_id zero", ImportBatch{HostID: 0, PackID: 1, ProviderID: 1, Mode: "m", BatchStatus: "s", AccessStatus: "s"}},
{"pack_id zero", ImportBatch{HostID: 1, PackID: 0, ProviderID: 1, Mode: "m", BatchStatus: "s", AccessStatus: "s"}},
{"provider_id zero", ImportBatch{HostID: 1, PackID: 1, ProviderID: 0, Mode: "m", BatchStatus: "s", AccessStatus: "s"}},
{"empty mode", ImportBatch{HostID: 1, PackID: 1, ProviderID: 1, Mode: "", BatchStatus: "s", AccessStatus: "s"}},
{"empty batch_status", ImportBatch{HostID: 1, PackID: 1, ProviderID: 1, Mode: "m", BatchStatus: "", AccessStatus: "s"}},
{"empty access_status", ImportBatch{HostID: 1, PackID: 1, ProviderID: 1, Mode: "m", BatchStatus: "s", AccessStatus: ""}},
} {
t.Run(tt.name, func(t *testing.T) {
_, err := store.ImportBatches().Create(context.Background(), tt.batch)
if err == nil {
t.Fatal("Create() error = nil")
}
})
}
}
func TestImportBatchesRepoUpdateStatusValidation(t *testing.T) {
store := openTestDB(t)
if err := store.ImportBatches().UpdateStatus(context.Background(), 0, "s", "s"); err == nil {
t.Fatal("UpdateStatus id=0 error = nil")
}
if err := store.ImportBatches().UpdateStatus(context.Background(), 1, "", "s"); err == nil {
t.Fatal("UpdateStatus empty batch_status error = nil")
}
}
// --- ImportBatchItems Repo Tests ---
func TestImportBatchItemsRepoCreateAndGet(t *testing.T) {
store := openTestDB(t)
batchID := createTestBatch(t, store)
id, err := store.ImportBatchItems().Create(context.Background(), ImportBatchItem{
BatchID: batchID,
KeyFingerprint: "sha256:abc",
AccountStatus: "passed",
ProbeSummaryJSON: `{"ok":true}`,
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
if id <= 0 {
t.Fatalf("Create() id = %d, want positive", id)
}
items, _ := store.ImportBatchItems().GetByBatchID(context.Background(), batchID)
if len(items) != 1 || items[0].AccountStatus != "passed" {
t.Fatalf("items = %+v, want 1 with passed status", items)
}
}
func TestImportBatchItemsRepoMultipleItems(t *testing.T) {
store := openTestDB(t)
batchID := createTestBatch(t, store)
for _, status := range []string{"passed", "failed"} {
store.ImportBatchItems().Create(context.Background(), ImportBatchItem{
BatchID: batchID, KeyFingerprint: "sha256:" + status, AccountStatus: status,
})
}
items, _ := store.ImportBatchItems().GetByBatchID(context.Background(), batchID)
if len(items) != 2 {
t.Fatalf("count = %d, want 2", len(items))
}
}
func TestImportBatchItemsRepoUpdateResult(t *testing.T) {
store := openTestDB(t)
batchID := createTestBatch(t, store)
itemID, _ := store.ImportBatchItems().Create(context.Background(), ImportBatchItem{
BatchID: batchID, KeyFingerprint: "sha256:x", AccountStatus: "pending",
})
store.ImportBatchItems().UpdateResult(context.Background(), itemID, "passed", `{"ok":true}`)
items, _ := store.ImportBatchItems().GetByBatchID(context.Background(), batchID)
if items[0].AccountStatus != "passed" {
t.Fatalf("AccountStatus = %q, want passed", items[0].AccountStatus)
}
}
func TestImportBatchItemsRepoGetByBatchIDEmpty(t *testing.T) {
store := openTestDB(t)
items, err := store.ImportBatchItems().GetByBatchID(context.Background(), 999)
if err != nil {
t.Fatalf("GetByBatchID() error = %v, want empty result", err)
}
if len(items) != 0 {
t.Fatalf("count = %d, want 0", len(items))
}
}
func TestImportBatchItemsRepoValidation(t *testing.T) {
store := openTestDB(t)
_, err := store.ImportBatchItems().Create(context.Background(), ImportBatchItem{
BatchID: 0, KeyFingerprint: "k", AccountStatus: "s",
})
if err == nil {
t.Fatal("Create batch_id=0 error = nil")
}
}
// --- Managed Resources Repo Tests ---
func TestManagedResourcesRepoCreateAndGet(t *testing.T) {
store := openTestDB(t)
batchID := createTestBatch(t, store)
id, err := store.ManagedResources().Create(context.Background(), ManagedResource{
BatchID: batchID, ResourceType: "group", HostResourceID: "g_01", ResourceName: "test-group",
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
resources, _ := store.ManagedResources().GetByBatchID(context.Background(), batchID)
if len(resources) != 1 || resources[0].HostResourceID != "g_01" {
t.Fatalf("resources = %+v, want 1 with g_01", resources)
}
_ = id
}
func TestManagedResourcesRepoMultipleResources(t *testing.T) {
store := openTestDB(t)
batchID := createTestBatch(t, store)
for _, r := range []ManagedResource{
{BatchID: batchID, ResourceType: "group", HostResourceID: "g_01", ResourceName: "group-1"},
{BatchID: batchID, ResourceType: "channel", HostResourceID: "c_01", ResourceName: "channel-1"},
{BatchID: batchID, ResourceType: "account", HostResourceID: "a_01", ResourceName: "account-1"},
} {
store.ManagedResources().Create(context.Background(), r)
}
resources, _ := store.ManagedResources().GetByBatchID(context.Background(), batchID)
if len(resources) != 3 {
t.Fatalf("count = %d, want 3", len(resources))
}
}
func TestManagedResourcesRepoGetByBatchIDEmpty(t *testing.T) {
store := openTestDB(t)
resources, _ := store.ManagedResources().GetByBatchID(context.Background(), 999)
if len(resources) != 0 {
t.Fatalf("count = %d, want 0", len(resources))
}
}
func TestManagedResourcesRepoValidationErrors(t *testing.T) {
store := openTestDB(t)
for _, tt := range []struct {
name string
r ManagedResource
}{
{"batch_id zero", ManagedResource{ResourceType: "g", HostResourceID: "h", ResourceName: "n"}},
{"empty resource_type", ManagedResource{BatchID: 1, HostResourceID: "h", ResourceName: "n"}},
{"empty host_resource_id", ManagedResource{BatchID: 1, ResourceType: "g", ResourceName: "n"}},
{"empty resource_name", ManagedResource{BatchID: 1, ResourceType: "g", HostResourceID: "h"}},
} {
t.Run(tt.name, func(t *testing.T) {
_, err := store.ManagedResources().Create(context.Background(), tt.r)
if err == nil {
t.Fatal("Create() error = nil")
}
})
}
}
// --- Probe Results Repo Tests ---
func TestProbeResultsRepoCreateAndGet(t *testing.T) {
store := openTestDB(t)
batchID := createTestBatch(t, store)
itemID := createTestBatchItem(t, store, batchID)
id, err := store.ProbeResults().Create(context.Background(), ProbeResult{
BatchItemID: itemID, ProbeType: "account_smoke", Status: "passed", SummaryJSON: `{"ok":true}`,
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
results, _ := store.ProbeResults().GetByBatchItemID(context.Background(), itemID)
if len(results) != 1 || results[0].ProbeType != "account_smoke" {
t.Fatalf("results = %+v, want 1 with account_smoke", results)
}
_ = id
}
func TestProbeResultsRepoMultipleResults(t *testing.T) {
store := openTestDB(t)
batchID := createTestBatch(t, store)
itemID := createTestBatchItem(t, store, batchID)
for _, p := range []ProbeResult{
{BatchItemID: itemID, ProbeType: "account_smoke", Status: "passed", SummaryJSON: `{"ok":true}`},
{BatchItemID: itemID, ProbeType: "model_list", Status: "passed", SummaryJSON: `{"models":["m1"]}`},
} {
store.ProbeResults().Create(context.Background(), p)
}
results, _ := store.ProbeResults().GetByBatchItemID(context.Background(), itemID)
if len(results) != 2 {
t.Fatalf("count = %d, want 2", len(results))
}
}
func TestProbeResultsRepoGetByBatchItemIDEmpty(t *testing.T) {
store := openTestDB(t)
results, _ := store.ProbeResults().GetByBatchItemID(context.Background(), 999)
if len(results) != 0 {
t.Fatalf("count = %d, want 0", len(results))
}
}
func TestProbeResultsRepoValidationErrors(t *testing.T) {
store := openTestDB(t)
for _, tt := range []struct {
name string
probe ProbeResult
}{
{"batch_item_id zero", ProbeResult{ProbeType: "t", Status: "s"}},
{"empty probe_type", ProbeResult{BatchItemID: 1, Status: "s"}},
{"empty status", ProbeResult{BatchItemID: 1, ProbeType: "t"}},
} {
t.Run(tt.name, func(t *testing.T) {
_, err := store.ProbeResults().Create(context.Background(), tt.probe)
if err == nil {
t.Fatal("Create() error = nil")
}
})
}
}
// --- Access Closures Repo Tests ---
func TestAccessClosureRecordsRepoCreateAndGet(t *testing.T) {
store := openTestDB(t)
batchID := createTestBatch(t, store)
id, err := store.AccessClosures().Create(context.Background(), AccessClosureRecord{
BatchID: batchID, ClosureType: "subscription", Status: "subscription_ready", DetailsJSON: `{"status_code":200}`,
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
records, _ := store.AccessClosures().GetByBatchID(context.Background(), batchID)
if len(records) != 1 || records[0].ClosureType != "subscription" {
t.Fatalf("records = %+v, want 1 subscription", records)
}
_ = id
}
func TestAccessClosureRecordsRepoMultiple(t *testing.T) {
store := openTestDB(t)
batchID := createTestBatch(t, store)
store.AccessClosures().Create(context.Background(), AccessClosureRecord{BatchID: batchID, ClosureType: "subscription", Status: "subscription_ready", DetailsJSON: "{}"})
store.AccessClosures().Create(context.Background(), AccessClosureRecord{BatchID: batchID, ClosureType: "self_service", Status: "self_service_ready", DetailsJSON: "{}"})
records, _ := store.AccessClosures().GetByBatchID(context.Background(), batchID)
if len(records) != 2 {
t.Fatalf("count = %d, want 2", len(records))
}
}
func TestAccessClosureRecordsRepoGetByBatchIDEmpty(t *testing.T) {
store := openTestDB(t)
records, _ := store.AccessClosures().GetByBatchID(context.Background(), 999)
if len(records) != 0 {
t.Fatalf("count = %d, want 0", len(records))
}
}
func TestAccessClosureRecordsRepoValidation(t *testing.T) {
store := openTestDB(t)
_, err := store.AccessClosures().Create(context.Background(), AccessClosureRecord{BatchID: 0, ClosureType: "t", Status: "s"})
if err == nil {
t.Fatal("Create batch_id=0 error = nil")
}
}
// --- Reconcile Runs Repo Tests ---
func createTestProviderWithPack(t *testing.T, store *DB, packID int64) int64 {
t.Helper()
id, err := store.Providers().Create(context.Background(), Provider{
PackID: packID, ProviderID: "test-provider-" + sanitizeTestName(t.Name()), DisplayName: "TP",
BaseURL: "https://tp.com", Platform: "openai",
})
if err != nil {
t.Fatalf("createTestProviderWithPack error = %v", err)
}
return id
}
func createTestProvider(t *testing.T, store *DB) int64 {
t.Helper()
packID := createTestPack(t, store)
return createTestProviderWithPack(t, store, packID)
}
func TestReconcileRunsRepoCreateAndGet(t *testing.T) {
store := openTestDB(t)
providerID := createTestProvider(t, store)
id, err := store.ReconcileRuns().Create(context.Background(), ReconcileRun{
ProviderID: providerID, Status: "active", SummaryJSON: `{"drifted":false}`,
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
runs, _ := store.ReconcileRuns().GetByProviderID(context.Background(), providerID)
if len(runs) != 1 || runs[0].Status != "active" {
t.Fatalf("runs = %+v, want 1 active", runs)
}
_ = id
}
func TestReconcileRunsRepoMultipleRunsOrderedDesc(t *testing.T) {
store := openTestDB(t)
providerID := createTestProvider(t, store)
id1, _ := store.ReconcileRuns().Create(context.Background(), ReconcileRun{ProviderID: providerID, Status: "first", SummaryJSON: "{}"})
id2, _ := store.ReconcileRuns().Create(context.Background(), ReconcileRun{ProviderID: providerID, Status: "second", SummaryJSON: "{}"})
runs, _ := store.ReconcileRuns().GetByProviderID(context.Background(), providerID)
if len(runs) != 2 || runs[0].ID != id2 || runs[1].ID != id1 {
t.Fatalf("order: got %d, %d; want %d, %d (DESC)", runs[0].ID, runs[1].ID, id2, id1)
}
}
func TestReconcileRunsRepoGetByProviderIDEmpty(t *testing.T) {
store := openTestDB(t)
runs, _ := store.ReconcileRuns().GetByProviderID(context.Background(), 999)
if len(runs) != 0 {
t.Fatalf("count = %d, want 0", len(runs))
}
}
func TestReconcileRunsRepoValidation(t *testing.T) {
store := openTestDB(t)
_, err := store.ReconcileRuns().Create(context.Background(), ReconcileRun{ProviderID: 0, Status: "s"})
if err == nil {
t.Fatal("Create provider_id=0 error = nil")
}
}

View File

@@ -0,0 +1,76 @@
package sqlite
import (
"context"
"fmt"
"strings"
)
type ManagedResource struct {
ID int64
BatchID int64
ResourceType string
HostResourceID string
ResourceName string
}
type ManagedResourcesRepo struct {
db execQuerier
}
func newManagedResourcesRepo(db execQuerier) *ManagedResourcesRepo {
return &ManagedResourcesRepo{db: db}
}
func (r *ManagedResourcesRepo) Create(ctx context.Context, resource ManagedResource) (int64, error) {
resourceType := strings.TrimSpace(resource.ResourceType)
hostResourceID := strings.TrimSpace(resource.HostResourceID)
resourceName := strings.TrimSpace(resource.ResourceName)
switch {
case resource.BatchID <= 0:
return 0, fmt.Errorf("batch_id is required")
case resourceType == "":
return 0, fmt.Errorf("resource_type is required")
case hostResourceID == "":
return 0, fmt.Errorf("host_resource_id is required")
case resourceName == "":
return 0, fmt.Errorf("resource_name is required")
}
result, err := r.db.ExecContext(ctx, `INSERT INTO managed_resources (batch_id, resource_type, host_resource_id, resource_name) VALUES (?, ?, ?, ?)`, resource.BatchID, resourceType, hostResourceID, resourceName)
if err != nil {
return 0, fmt.Errorf("insert managed resource %q: %w", hostResourceID, err)
}
id, err := result.LastInsertId()
if err != nil {
return 0, fmt.Errorf("read inserted managed resource id for %q: %w", hostResourceID, err)
}
return id, nil
}
func (r *ManagedResourcesRepo) GetByBatchID(ctx context.Context, batchID int64) ([]ManagedResource, error) {
if batchID <= 0 {
return nil, fmt.Errorf("batch_id is required")
}
rows, err := r.db.QueryContext(ctx, `SELECT id, batch_id, resource_type, host_resource_id, resource_name FROM managed_resources WHERE batch_id = ? ORDER BY id`, batchID)
if err != nil {
return nil, fmt.Errorf("query managed resources: %w", err)
}
defer rows.Close()
resources := make([]ManagedResource, 0)
for rows.Next() {
var resource ManagedResource
if err := rows.Scan(&resource.ID, &resource.BatchID, &resource.ResourceType, &resource.HostResourceID, &resource.ResourceName); err != nil {
return nil, fmt.Errorf("scan managed resource: %w", err)
}
resources = append(resources, resource)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate managed resources: %w", err)
}
return resources, nil
}

View File

@@ -7,9 +7,15 @@ import (
)
type Pack struct {
PackID string
Version string
Checksum string
ID int64
PackID string
Version string
Checksum string
Vendor string
TargetHost string
MinHostVersion string
MaxHostVersion string
ManifestJSON string
}
type PacksRepo struct {
@@ -20,10 +26,59 @@ func newPacksRepo(db execQuerier) *PacksRepo {
return &PacksRepo{db: db}
}
func (r *PacksRepo) GetByID(ctx context.Context, id int64) (Pack, error) {
if id <= 0 {
return Pack{}, fmt.Errorf("id is required")
}
var pack Pack
if err := r.db.QueryRowContext(ctx, `SELECT id, pack_id, version, checksum, vendor, target_host, min_host_version, max_host_version, manifest_json FROM packs WHERE id = ?`, id).Scan(
&pack.ID,
&pack.PackID,
&pack.Version,
&pack.Checksum,
&pack.Vendor,
&pack.TargetHost,
&pack.MinHostVersion,
&pack.MaxHostVersion,
&pack.ManifestJSON,
); err != nil {
return Pack{}, err
}
return pack, nil
}
func (r *PacksRepo) GetByPackID(ctx context.Context, packID string) (Pack, error) {
packID = strings.TrimSpace(packID)
if packID == "" {
return Pack{}, fmt.Errorf("pack_id is required")
}
var pack Pack
if err := r.db.QueryRowContext(ctx, `SELECT id, pack_id, version, checksum, vendor, target_host, min_host_version, max_host_version, manifest_json FROM packs WHERE pack_id = ?`, packID).Scan(
&pack.ID,
&pack.PackID,
&pack.Version,
&pack.Checksum,
&pack.Vendor,
&pack.TargetHost,
&pack.MinHostVersion,
&pack.MaxHostVersion,
&pack.ManifestJSON,
); err != nil {
return Pack{}, err
}
return pack, nil
}
func (r *PacksRepo) Create(ctx context.Context, pack Pack) (int64, error) {
packID := strings.TrimSpace(pack.PackID)
version := strings.TrimSpace(pack.Version)
checksum := strings.TrimSpace(pack.Checksum)
manifestJSON := strings.TrimSpace(pack.ManifestJSON)
if manifestJSON == "" {
manifestJSON = "{}"
}
switch {
case packID == "":
@@ -36,11 +91,16 @@ func (r *PacksRepo) Create(ctx context.Context, pack Pack) (int64, error) {
result, err := r.db.ExecContext(
ctx,
`INSERT INTO packs (pack_id, version, checksum)
VALUES (?, ?, ?)`,
`INSERT INTO packs (pack_id, version, checksum, vendor, target_host, min_host_version, max_host_version, manifest_json)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
packID,
version,
checksum,
strings.TrimSpace(pack.Vendor),
strings.TrimSpace(pack.TargetHost),
strings.TrimSpace(pack.MinHostVersion),
strings.TrimSpace(pack.MaxHostVersion),
manifestJSON,
)
if err != nil {
return 0, fmt.Errorf("insert pack %q: %w", packID, err)
@@ -50,6 +110,62 @@ func (r *PacksRepo) Create(ctx context.Context, pack Pack) (int64, error) {
if err != nil {
return 0, fmt.Errorf("read inserted pack id for %q: %w", packID, err)
}
return id, nil
}
func (r *PacksRepo) Upsert(ctx context.Context, pack Pack) (int64, error) {
packID := strings.TrimSpace(pack.PackID)
version := strings.TrimSpace(pack.Version)
checksum := strings.TrimSpace(pack.Checksum)
manifestJSON := strings.TrimSpace(pack.ManifestJSON)
if manifestJSON == "" {
manifestJSON = "{}"
}
switch {
case packID == "":
return 0, fmt.Errorf("pack_id is required")
case version == "":
return 0, fmt.Errorf("version is required")
case checksum == "":
return 0, fmt.Errorf("checksum is required")
}
result, err := r.db.ExecContext(
ctx,
`INSERT INTO packs (pack_id, version, checksum, vendor, target_host, min_host_version, max_host_version, manifest_json)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(pack_id) DO UPDATE SET
version = excluded.version,
checksum = excluded.checksum,
vendor = excluded.vendor,
target_host = excluded.target_host,
min_host_version = excluded.min_host_version,
max_host_version = excluded.max_host_version,
manifest_json = excluded.manifest_json`,
packID,
version,
checksum,
strings.TrimSpace(pack.Vendor),
strings.TrimSpace(pack.TargetHost),
strings.TrimSpace(pack.MinHostVersion),
strings.TrimSpace(pack.MaxHostVersion),
manifestJSON,
)
if err != nil {
return 0, fmt.Errorf("upsert pack %q: %w", packID, err)
}
id, err := result.LastInsertId()
if err == nil && id > 0 {
return id, nil
}
persisted, getErr := r.GetByPackID(ctx, packID)
if getErr != nil {
if err != nil {
return 0, fmt.Errorf("read upserted pack %q: %w", packID, getErr)
}
return 0, getErr
}
return persisted.ID, nil
}

View File

@@ -0,0 +1,159 @@
package sqlite
import (
"context"
"database/sql"
"errors"
"testing"
)
func TestPacksRepoCreateAndGet(t *testing.T) {
store := openTestDB(t)
id, err := store.Packs().Create(context.Background(), Pack{
PackID: "test-pack",
Version: "1.0.0",
Checksum: "abc123",
Vendor: "test-vendor",
TargetHost: "sub2api",
MinHostVersion: "0.1.0",
MaxHostVersion: "0.2.x",
ManifestJSON: `{"name":"test"}`,
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
if id <= 0 {
t.Fatalf("Create() id = %d, want positive", id)
}
got, err := store.Packs().GetByID(context.Background(), id)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if got.PackID != "test-pack" || got.Version != "1.0.0" {
t.Fatalf("GetByID() = %+v, want pack test-pack", got)
}
got2, err := store.Packs().GetByPackID(context.Background(), "test-pack")
if err != nil {
t.Fatalf("GetByPackID() error = %v", err)
}
if got2.ID != id {
t.Fatalf("GetByPackID() id = %d, want %d", got2.ID, id)
}
}
func TestPacksRepoCreateDefaultsManifestJSON(t *testing.T) {
store := openTestDB(t)
id, err := store.Packs().Create(context.Background(), Pack{
PackID: "no-manifest",
Version: "1.0.0",
Checksum: "chk",
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
got, err := store.Packs().GetByID(context.Background(), id)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if got.ManifestJSON != "{}" {
t.Fatalf("ManifestJSON = %q, want {}", got.ManifestJSON)
}
}
func TestPacksRepoUpsertCreatesNew(t *testing.T) {
store := openTestDB(t)
id, err := store.Packs().Upsert(context.Background(), Pack{
PackID: "upsert-pack",
Version: "1.0.0",
Checksum: "chk1",
})
if err != nil {
t.Fatalf("Upsert() error = %v", err)
}
if id <= 0 {
t.Fatalf("Upsert() id = %d, want positive", id)
}
}
func TestPacksRepoUpsertUpdatesExisting(t *testing.T) {
store := openTestDB(t)
id1, err := store.Packs().Upsert(context.Background(), Pack{
PackID: "upsert-pack",
Version: "1.0.0",
Checksum: "chk1",
})
if err != nil {
t.Fatalf("Upsert() create error = %v", err)
}
id2, err := store.Packs().Upsert(context.Background(), Pack{
PackID: "upsert-pack",
Version: "2.0.0",
Checksum: "chk2",
})
if err != nil {
t.Fatalf("Upsert() update error = %v", err)
}
if id2 != id1 {
t.Fatalf("Upsert() update returned id %d, want original %d", id2, id1)
}
got, err := store.Packs().GetByID(context.Background(), id1)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if got.Version != "2.0.0" {
t.Fatalf("Version after upsert = %q, want 2.0.0", got.Version)
}
}
func TestPacksRepoValidationErrors(t *testing.T) {
store := openTestDB(t)
tests := []struct {
name string
pack Pack
}{
{"empty pack_id", Pack{Version: "v", Checksum: "c"}},
{"empty version", Pack{PackID: "p", Checksum: "c"}},
{"empty checksum", Pack{PackID: "p", Version: "v"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := store.Packs().Create(context.Background(), tt.pack)
if err == nil {
t.Fatal("Create() error = nil, want validation error")
}
})
}
}
func TestPacksRepoGetByIDNotFound(t *testing.T) {
store := openTestDB(t)
_, err := store.Packs().GetByID(context.Background(), 999)
if !errors.Is(err, sql.ErrNoRows) {
t.Fatalf("GetByID(999) error = %v, want sql.ErrNoRows", err)
}
}
func TestPacksRepoGetByPackIDEmptyError(t *testing.T) {
store := openTestDB(t)
_, err := store.Packs().GetByPackID(context.Background(), "")
if err == nil {
t.Fatal("GetByPackID('') error = nil, want error")
}
}
func TestPacksRepoGetByPackIDNotFound(t *testing.T) {
store := openTestDB(t)
_, err := store.Packs().GetByPackID(context.Background(), "nonexistent")
if !errors.Is(err, sql.ErrNoRows) {
t.Fatalf("GetByPackID() error = %v, want sql.ErrNoRows", err)
}
}

View File

@@ -0,0 +1,77 @@
package sqlite
import (
"context"
"fmt"
"strings"
)
type ProbeResult struct {
ID int64
BatchItemID int64
ProbeType string
Status string
SummaryJSON string
}
type ProbeResultsRepo struct {
db execQuerier
}
func newProbeResultsRepo(db execQuerier) *ProbeResultsRepo {
return &ProbeResultsRepo{db: db}
}
func (r *ProbeResultsRepo) Create(ctx context.Context, probe ProbeResult) (int64, error) {
probeType := strings.TrimSpace(probe.ProbeType)
status := strings.TrimSpace(probe.Status)
summaryJSON := strings.TrimSpace(probe.SummaryJSON)
if summaryJSON == "" {
summaryJSON = "{}"
}
switch {
case probe.BatchItemID <= 0:
return 0, fmt.Errorf("batch_item_id is required")
case probeType == "":
return 0, fmt.Errorf("probe_type is required")
case status == "":
return 0, fmt.Errorf("status is required")
}
result, err := r.db.ExecContext(ctx, `INSERT INTO probe_results (batch_item_id, probe_type, status, summary_json) VALUES (?, ?, ?, ?)`, probe.BatchItemID, probeType, status, summaryJSON)
if err != nil {
return 0, fmt.Errorf("insert probe result: %w", err)
}
id, err := result.LastInsertId()
if err != nil {
return 0, fmt.Errorf("read inserted probe result id: %w", err)
}
return id, nil
}
func (r *ProbeResultsRepo) GetByBatchItemID(ctx context.Context, batchItemID int64) ([]ProbeResult, error) {
if batchItemID <= 0 {
return nil, fmt.Errorf("batch_item_id is required")
}
rows, err := r.db.QueryContext(ctx, `SELECT id, batch_item_id, probe_type, status, summary_json FROM probe_results WHERE batch_item_id = ? ORDER BY id`, batchItemID)
if err != nil {
return nil, fmt.Errorf("query probe results: %w", err)
}
defer rows.Close()
probes := make([]ProbeResult, 0)
for rows.Next() {
var probe ProbeResult
if err := rows.Scan(&probe.ID, &probe.BatchItemID, &probe.ProbeType, &probe.Status, &probe.SummaryJSON); err != nil {
return nil, fmt.Errorf("scan probe result: %w", err)
}
probes = append(probes, probe)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate probe results: %w", err)
}
return probes, nil
}

View File

@@ -7,11 +7,20 @@ import (
)
type Provider struct {
PackID int64
ProviderID string
DisplayName string
BaseURL string
Platform string
ID int64
PackID int64
ProviderID string
DisplayName string
BaseURL string
Platform string
AccountType string
DefaultModelsJSON string
SmokeTestModel string
GroupTemplateJSON string
ChannelTemplateJSON string
PlanTemplateJSON string
ImportOptionsJSON string
ManifestJSON string
}
type ProvidersRepo struct {
@@ -22,11 +31,87 @@ func newProvidersRepo(db execQuerier) *ProvidersRepo {
return &ProvidersRepo{db: db}
}
func (r *ProvidersRepo) ListByProviderID(ctx context.Context, providerID string) ([]Provider, error) {
providerID = strings.TrimSpace(providerID)
if providerID == "" {
return nil, fmt.Errorf("provider_id is required")
}
rows, err := r.db.QueryContext(ctx, `SELECT id, pack_id, provider_id, display_name, base_url, platform, account_type, default_models_json, smoke_test_model, group_template_json, channel_template_json, plan_template_json, import_options_json, manifest_json FROM providers WHERE provider_id = ? ORDER BY id`, providerID)
if err != nil {
return nil, fmt.Errorf("query providers by provider_id %q: %w", providerID, err)
}
defer rows.Close()
providers := make([]Provider, 0)
for rows.Next() {
var provider Provider
if err := rows.Scan(
&provider.ID,
&provider.PackID,
&provider.ProviderID,
&provider.DisplayName,
&provider.BaseURL,
&provider.Platform,
&provider.AccountType,
&provider.DefaultModelsJSON,
&provider.SmokeTestModel,
&provider.GroupTemplateJSON,
&provider.ChannelTemplateJSON,
&provider.PlanTemplateJSON,
&provider.ImportOptionsJSON,
&provider.ManifestJSON,
); err != nil {
return nil, fmt.Errorf("scan provider by provider_id %q: %w", providerID, err)
}
providers = append(providers, provider)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate providers by provider_id %q: %w", providerID, err)
}
return providers, nil
}
func (r *ProvidersRepo) GetByPackIDAndProviderID(ctx context.Context, packID int64, providerID string) (Provider, error) {
if packID <= 0 {
return Provider{}, fmt.Errorf("pack_id is required")
}
providerID = strings.TrimSpace(providerID)
if providerID == "" {
return Provider{}, fmt.Errorf("provider_id is required")
}
var provider Provider
if err := r.db.QueryRowContext(ctx, `SELECT id, pack_id, provider_id, display_name, base_url, platform, account_type, default_models_json, smoke_test_model, group_template_json, channel_template_json, plan_template_json, import_options_json, manifest_json FROM providers WHERE pack_id = ? AND provider_id = ?`, packID, providerID).Scan(
&provider.ID,
&provider.PackID,
&provider.ProviderID,
&provider.DisplayName,
&provider.BaseURL,
&provider.Platform,
&provider.AccountType,
&provider.DefaultModelsJSON,
&provider.SmokeTestModel,
&provider.GroupTemplateJSON,
&provider.ChannelTemplateJSON,
&provider.PlanTemplateJSON,
&provider.ImportOptionsJSON,
&provider.ManifestJSON,
); err != nil {
return Provider{}, err
}
return provider, nil
}
func (r *ProvidersRepo) Create(ctx context.Context, provider Provider) (int64, error) {
providerID := strings.TrimSpace(provider.ProviderID)
displayName := strings.TrimSpace(provider.DisplayName)
baseURL := strings.TrimSpace(provider.BaseURL)
platform := strings.TrimSpace(provider.Platform)
manifestJSON := strings.TrimSpace(provider.ManifestJSON)
if manifestJSON == "" {
manifestJSON = "{}"
}
switch {
case provider.PackID <= 0:
@@ -43,13 +128,21 @@ func (r *ProvidersRepo) Create(ctx context.Context, provider Provider) (int64, e
result, err := r.db.ExecContext(
ctx,
`INSERT INTO providers (pack_id, provider_id, display_name, base_url, platform)
VALUES (?, ?, ?, ?, ?)`,
`INSERT INTO providers (pack_id, provider_id, display_name, base_url, platform, account_type, default_models_json, smoke_test_model, group_template_json, channel_template_json, plan_template_json, import_options_json, manifest_json)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
provider.PackID,
providerID,
displayName,
baseURL,
platform,
strings.TrimSpace(provider.AccountType),
defaultJSONArray(provider.DefaultModelsJSON),
strings.TrimSpace(provider.SmokeTestModel),
defaultJSONObject(provider.GroupTemplateJSON),
defaultJSONObject(provider.ChannelTemplateJSON),
defaultJSONObject(provider.PlanTemplateJSON),
defaultJSONObject(provider.ImportOptionsJSON),
manifestJSON,
)
if err != nil {
return 0, fmt.Errorf("insert provider %q: %w", providerID, err)
@@ -59,6 +152,90 @@ func (r *ProvidersRepo) Create(ctx context.Context, provider Provider) (int64, e
if err != nil {
return 0, fmt.Errorf("read inserted provider id for %q: %w", providerID, err)
}
return id, nil
}
func (r *ProvidersRepo) Upsert(ctx context.Context, provider Provider) (int64, error) {
providerID := strings.TrimSpace(provider.ProviderID)
displayName := strings.TrimSpace(provider.DisplayName)
baseURL := strings.TrimSpace(provider.BaseURL)
platform := strings.TrimSpace(provider.Platform)
manifestJSON := strings.TrimSpace(provider.ManifestJSON)
if manifestJSON == "" {
manifestJSON = "{}"
}
switch {
case provider.PackID <= 0:
return 0, fmt.Errorf("pack_id is required")
case providerID == "":
return 0, fmt.Errorf("provider_id is required")
case displayName == "":
return 0, fmt.Errorf("display_name is required")
case baseURL == "":
return 0, fmt.Errorf("base_url is required")
case platform == "":
return 0, fmt.Errorf("platform is required")
}
result, err := r.db.ExecContext(
ctx,
`INSERT INTO providers (pack_id, provider_id, display_name, base_url, platform, account_type, default_models_json, smoke_test_model, group_template_json, channel_template_json, plan_template_json, import_options_json, manifest_json)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(pack_id, provider_id) DO UPDATE SET
display_name = excluded.display_name,
base_url = excluded.base_url,
platform = excluded.platform,
account_type = excluded.account_type,
default_models_json = excluded.default_models_json,
smoke_test_model = excluded.smoke_test_model,
group_template_json = excluded.group_template_json,
channel_template_json = excluded.channel_template_json,
plan_template_json = excluded.plan_template_json,
import_options_json = excluded.import_options_json,
manifest_json = excluded.manifest_json`,
provider.PackID,
providerID,
displayName,
baseURL,
platform,
strings.TrimSpace(provider.AccountType),
defaultJSONArray(provider.DefaultModelsJSON),
strings.TrimSpace(provider.SmokeTestModel),
defaultJSONObject(provider.GroupTemplateJSON),
defaultJSONObject(provider.ChannelTemplateJSON),
defaultJSONObject(provider.PlanTemplateJSON),
defaultJSONObject(provider.ImportOptionsJSON),
manifestJSON,
)
if err != nil {
return 0, fmt.Errorf("upsert provider %q: %w", providerID, err)
}
id, err := result.LastInsertId()
if err == nil && id > 0 {
return id, nil
}
persisted, getErr := r.GetByPackIDAndProviderID(ctx, provider.PackID, providerID)
if getErr != nil {
if err != nil {
return 0, fmt.Errorf("read upserted provider %q: %w", providerID, getErr)
}
return 0, getErr
}
return persisted.ID, nil
}
func defaultJSONObject(value string) string {
if strings.TrimSpace(value) == "" {
return "{}"
}
return value
}
func defaultJSONArray(value string) string {
if strings.TrimSpace(value) == "" {
return "[]"
}
return value
}

View File

@@ -0,0 +1,172 @@
package sqlite
import (
"context"
"database/sql"
"errors"
"testing"
)
func TestProvidersRepoCreateAndGet(t *testing.T) {
store := openTestDB(t)
packID := createTestPack(t, store)
providerID, err := store.Providers().Create(context.Background(), Provider{
PackID: packID,
ProviderID: "deepseek",
DisplayName: "DeepSeek",
BaseURL: "https://api.deepseek.com",
Platform: "openai",
AccountType: "api",
SmokeTestModel: "deepseek-chat",
ManifestJSON: `{"models":["deepseek-chat"]}`,
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
if providerID <= 0 {
t.Fatalf("Create() id = %d, want positive", providerID)
}
got, err := store.Providers().GetByPackIDAndProviderID(context.Background(), packID, "deepseek")
if err != nil {
t.Fatalf("GetByPackIDAndProviderID() error = %v", err)
}
if got.ProviderID != "deepseek" || got.DisplayName != "DeepSeek" {
t.Fatalf("GetByPackIDAndProviderID() = %+v, want deepseek", got)
}
}
func TestProvidersRepoListByProviderID(t *testing.T) {
store := openTestDB(t)
packID1 := createTestPackWithSuffix(t, store, "a")
packID2 := createTestPackWithSuffix(t, store, "b")
store.Providers().Create(context.Background(), Provider{PackID: packID1, ProviderID: "deepseek", DisplayName: "DS1", BaseURL: "https://a.com", Platform: "openai"})
store.Providers().Create(context.Background(), Provider{PackID: packID2, ProviderID: "deepseek", DisplayName: "DS2", BaseURL: "https://b.com", Platform: "openai"})
providers, err := store.Providers().ListByProviderID(context.Background(), "deepseek")
if err != nil {
t.Fatalf("ListByProviderID() error = %v", err)
}
if len(providers) != 2 {
t.Fatalf("ListByProviderID() count = %d, want 2", len(providers))
}
}
func createTestPackWithSuffix(t *testing.T, store *DB, suffix string) int64 {
t.Helper()
id, err := store.Packs().Create(context.Background(), Pack{
PackID: "pack-" + sanitizeTestName(t.Name()) + "-" + suffix, Version: "1.0.0", Checksum: "chk",
})
if err != nil {
t.Fatalf("createTestPackWithSuffix error = %v", err)
}
return id
}
func TestProvidersRepoListByProviderIDEmpty(t *testing.T) {
store := openTestDB(t)
providers, err := store.Providers().ListByProviderID(context.Background(), "nonexistent")
if err != nil {
t.Fatalf("ListByProviderID() error = %v", err)
}
if len(providers) != 0 {
t.Fatalf("ListByProviderID() count = %d, want 0", len(providers))
}
}
func TestProvidersRepoUpsertCreatesNew(t *testing.T) {
store := openTestDB(t)
packID := createTestPack(t, store)
id, err := store.Providers().Upsert(context.Background(), Provider{
PackID: packID, ProviderID: "upsert-p", DisplayName: "P", BaseURL: "https://u.com", Platform: "openai",
})
if err != nil {
t.Fatalf("Upsert() error = %v", err)
}
if id <= 0 {
t.Fatalf("Upsert() id = %d, want positive", id)
}
}
func TestProvidersRepoUpsertUpdatesExisting(t *testing.T) {
store := openTestDB(t)
packID := createTestPack(t, store)
id1, _ := store.Providers().Upsert(context.Background(), Provider{
PackID: packID, ProviderID: "upsert-p", DisplayName: "P1", BaseURL: "https://u1.com", Platform: "openai",
})
id2, _ := store.Providers().Upsert(context.Background(), Provider{
PackID: packID, ProviderID: "upsert-p", DisplayName: "P2", BaseURL: "https://u2.com", Platform: "openai",
})
if id2 != id1 {
t.Fatalf("Upsert update id = %d, want %d", id2, id1)
}
got, _ := store.Providers().GetByPackIDAndProviderID(context.Background(), packID, "upsert-p")
if got.DisplayName != "P2" {
t.Fatalf("DisplayName after upsert = %q, want P2", got.DisplayName)
}
}
func TestProvidersRepoValidationErrors(t *testing.T) {
store := openTestDB(t)
packID := createTestPack(t, store)
tests := []struct {
name string
provider Provider
}{
{"pack_id zero", Provider{ProviderID: "p", DisplayName: "d", BaseURL: "b", Platform: "openai"}},
{"empty provider_id", Provider{PackID: packID, DisplayName: "d", BaseURL: "b", Platform: "openai"}},
{"empty display_name", Provider{PackID: packID, ProviderID: "p", BaseURL: "b", Platform: "openai"}},
{"empty base_url", Provider{PackID: packID, ProviderID: "p", DisplayName: "d", Platform: "openai"}},
{"empty platform", Provider{PackID: packID, ProviderID: "p", DisplayName: "d", BaseURL: "b"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := store.Providers().Create(context.Background(), tt.provider)
if err == nil {
t.Fatal("Create() error = nil, want validation error")
}
})
}
}
func TestProvidersRepoGetByPackIDAndProviderIDNotFound(t *testing.T) {
store := openTestDB(t)
_, err := store.Providers().GetByPackIDAndProviderID(context.Background(), 999, "p")
if !errors.Is(err, sql.ErrNoRows) {
t.Fatalf("GetByPackIDAndProviderID() error = %v, want sql.ErrNoRows", err)
}
}
func TestProvidersRepoGetByPackIDEmptyError(t *testing.T) {
store := openTestDB(t)
_, err := store.Providers().GetByPackIDAndProviderID(context.Background(), 0, "p")
if err == nil {
t.Fatal("GetByPackIDAndProviderID with packID=0 error = nil, want error")
}
}
func TestDefaultJSONObject(t *testing.T) {
if got := defaultJSONObject(""); got != "{}" {
t.Fatalf("defaultJSONObject('') = %q, want {}", got)
}
if got := defaultJSONObject(`{"a":1}`); got != `{"a":1}` {
t.Fatalf("defaultJSONObject() = %q, want input", got)
}
}
func TestDefaultJSONArray(t *testing.T) {
if got := defaultJSONArray(""); got != "[]" {
t.Fatalf("defaultJSONArray('') = %q, want []", got)
}
if got := defaultJSONArray(`["a"]`); got != `["a"]` {
t.Fatalf("defaultJSONArray() = %q, want input", got)
}
}

View File

@@ -0,0 +1,73 @@
package sqlite
import (
"context"
"fmt"
"strings"
)
type ReconcileRun struct {
ID int64
ProviderID int64
Status string
SummaryJSON string
}
type ReconcileRunsRepo struct {
db execQuerier
}
func newReconcileRunsRepo(db execQuerier) *ReconcileRunsRepo {
return &ReconcileRunsRepo{db: db}
}
func (r *ReconcileRunsRepo) Create(ctx context.Context, run ReconcileRun) (int64, error) {
status := strings.TrimSpace(run.Status)
summaryJSON := strings.TrimSpace(run.SummaryJSON)
if summaryJSON == "" {
summaryJSON = "{}"
}
switch {
case run.ProviderID <= 0:
return 0, fmt.Errorf("provider_id is required")
case status == "":
return 0, fmt.Errorf("status is required")
}
result, err := r.db.ExecContext(ctx, `INSERT INTO reconcile_runs (provider_id, status, summary_json) VALUES (?, ?, ?)`, run.ProviderID, status, summaryJSON)
if err != nil {
return 0, fmt.Errorf("insert reconcile run: %w", err)
}
id, err := result.LastInsertId()
if err != nil {
return 0, fmt.Errorf("read inserted reconcile run id: %w", err)
}
return id, nil
}
func (r *ReconcileRunsRepo) GetByProviderID(ctx context.Context, providerID int64) ([]ReconcileRun, error) {
if providerID <= 0 {
return nil, fmt.Errorf("provider_id is required")
}
rows, err := r.db.QueryContext(ctx, `SELECT id, provider_id, status, summary_json FROM reconcile_runs WHERE provider_id = ? ORDER BY id DESC`, providerID)
if err != nil {
return nil, fmt.Errorf("query reconcile runs: %w", err)
}
defer rows.Close()
runs := make([]ReconcileRun, 0)
for rows.Next() {
var run ReconcileRun
if err := rows.Scan(&run.ID, &run.ProviderID, &run.Status, &run.SummaryJSON); err != nil {
return nil, fmt.Errorf("scan reconcile run: %w", err)
}
runs = append(runs, run)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate reconcile runs: %w", err)
}
return runs, nil
}

View File

@@ -4,10 +4,10 @@
它不是宿主原生插件,而是一个可被控制面读取的 `model_pack`,用于描述国产模型 provider 的默认接入模板、默认模型映射、默认套餐和导入约束。
当前目录仅提供协议样例
当前目录现在同时包含
- `pack.json.example`
- `providers/deepseek.json.example`
- 真实可校验包:`pack.json``providers/deepseek.json``checksums.txt`
- 协议样例:`pack.json.example``providers/deepseek.json.example`
后续真实交付时,可以扩展更多 provider

View File

@@ -0,0 +1,2 @@
db931e9a90f6c1040d285c65582c5dae4c85075e85ce6d87e59cd39a6441d6f1 pack.json
fc2259a85de73cd14ea3f0d6ffdf71be79296d50cf9cbee604633d36492fec49 providers/deepseek.json

View File

@@ -0,0 +1,10 @@
{
"pack_id": "openai-cn-pack",
"version": "1.0.0",
"vendor": "YourTeam",
"target_host": "sub2api",
"min_host_version": "0.1.126",
"max_host_version": "0.2.x",
"providers_dir": "providers",
"checksum_file": "checksums.txt"
}

View File

@@ -0,0 +1,31 @@
{
"provider_id": "deepseek",
"display_name": "DeepSeek OpenAI Compatible",
"base_url": "https://api.deepseek.com",
"platform": "openai",
"account_type": "api",
"default_models": ["deepseek-chat", "deepseek-reasoner"],
"smoke_test_model": "deepseek-chat",
"group_template": {
"name": "DeepSeek 默认分组",
"rate_multiplier": 1.0
},
"channel_template": {
"name": "DeepSeek 默认渠道",
"model_mapping": {
"deepseek-chat": "deepseek-chat",
"deepseek-reasoner": "deepseek-reasoner"
}
},
"plan_template": {
"name": "DeepSeek 默认套餐",
"price": 19.9,
"validity_days": 30,
"validity_unit": "day"
},
"import": {
"supports_multi_key": true,
"supports_strict": true,
"supports_partial": true
}
}

View File

@@ -66,9 +66,9 @@ func TestLoadAdminTokenFromEnvReturnsErrorWhenMissing(t *testing.T) {
}
}
func TestBootstrapBuildsServerWithStartupConfigOnly(t *testing.T) {
func TestBootstrapBuildsServerWithStartupConfigAndAdminToken(t *testing.T) {
t.Setenv("SUB2API_CRM_LISTEN_ADDR", ":8181")
t.Setenv("SUB2API_CRM_ADMIN_TOKEN", "")
t.Setenv("SUB2API_CRM_ADMIN_TOKEN", "admin-token")
server, err := app.Bootstrap(context.Background())
if err != nil {

View File

@@ -0,0 +1,64 @@
package integration_test
import (
"os"
"path/filepath"
"strings"
"testing"
)
func TestDistributionArtifactsExistAndReferenceRequiredEnv(t *testing.T) {
root := repoRoot(t)
for _, path := range []string{
filepath.Join(root, "Dockerfile"),
filepath.Join(root, ".env.example"),
filepath.Join(root, "docker-compose.yml"),
filepath.Join(root, "docs", "DEPLOYMENT.md"),
} {
if _, err := os.Stat(path); err != nil {
t.Fatalf("required distribution artifact %q missing: %v", path, err)
}
}
dockerfile := mustReadText(t, filepath.Join(root, "Dockerfile"))
if !strings.Contains(dockerfile, "SUB2API_CRM_ADMIN_TOKEN") {
t.Fatalf("Dockerfile must document runtime dependency on SUB2API_CRM_ADMIN_TOKEN; content=%s", dockerfile)
}
envExample := mustReadText(t, filepath.Join(root, ".env.example"))
for _, key := range []string{"SUB2API_CRM_LISTEN_ADDR", "SUB2API_CRM_SQLITE_DSN", "SUB2API_CRM_ADMIN_TOKEN"} {
if !strings.Contains(envExample, key+"=") {
t.Fatalf(".env.example missing %s; content=%s", key, envExample)
}
}
compose := mustReadText(t, filepath.Join(root, "docker-compose.yml"))
if !strings.Contains(compose, "/healthz") {
t.Fatalf("docker-compose.yml missing healthz probe; content=%s", compose)
}
deployment := mustReadText(t, filepath.Join(root, "docs", "DEPLOYMENT.md"))
for _, needle := range []string{"docker compose up --build -d", "SUB2API_CRM_ADMIN_TOKEN", "/healthz"} {
if !strings.Contains(deployment, needle) {
t.Fatalf("DEPLOYMENT.md missing %q; content=%s", needle, deployment)
}
}
}
func repoRoot(t *testing.T) string {
t.Helper()
root, err := filepath.Abs(filepath.Join("..", ".."))
if err != nil {
t.Fatalf("resolve repo root: %v", err)
}
return root
}
func mustReadText(t *testing.T, path string) string {
t.Helper()
content, err := os.ReadFile(path)
if err != nil {
t.Fatalf("read %s: %v", path, err)
}
return string(content)
}

View File

@@ -193,6 +193,89 @@ func TestSub2APIHostAdapterGetAccountModelsParsesEnvelope(t *testing.T) {
}
}
func TestSub2APIHostAdapterChecksGatewayAccess(t *testing.T) {
server := newSub2APIStubServer(t, sub2APIStubConfig{
requireAPIKey: true,
version: "0.1.126",
})
defer server.Close()
client, err := sub2api.NewClient(server.URL)
if err != nil {
t.Fatalf("NewClient() error = %v", err)
}
result, err := client.CheckGatewayAccess(context.Background(), sub2api.GatewayAccessCheckRequest{
APIKey: "user-api-key",
ExpectedModel: "deepseek-chat",
})
if err != nil {
t.Fatalf("CheckGatewayAccess() error = %v", err)
}
if !result.OK || !result.HasExpectedModel {
t.Fatalf("CheckGatewayAccess() = %+v, want ok with expected model", result)
}
}
func TestSub2APIHostAdapterDeletesManagedResources(t *testing.T) {
server := newSub2APIStubServer(t, sub2APIStubConfig{
requireAPIKey: true,
version: "0.1.126",
})
defer server.Close()
client, err := sub2api.NewClient(server.URL, sub2api.WithAPIKey("api-key"))
if err != nil {
t.Fatalf("NewClient() error = %v", err)
}
if err := client.DeleteAccount(context.Background(), "account_1"); err != nil {
t.Fatalf("DeleteAccount() error = %v", err)
}
if err := client.DeleteChannel(context.Background(), "channel_1"); err != nil {
t.Fatalf("DeleteChannel() error = %v", err)
}
if err := client.DeleteGroup(context.Background(), "group_1"); err != nil {
t.Fatalf("DeleteGroup() error = %v", err)
}
}
func TestSub2APIHostAdapterListManagedResources(t *testing.T) {
server := newSub2APIStubServer(t, sub2APIStubConfig{
requireAPIKey: true,
version: "0.1.126",
})
defer server.Close()
client, err := sub2api.NewClient(server.URL, sub2api.WithAPIKey("api-key"))
if err != nil {
t.Fatalf("NewClient() error = %v", err)
}
snapshot, err := client.ListManagedResources(context.Background(), sub2api.ListManagedResourcesRequest{
GroupName: "crm-deepseek-group",
ChannelName: "crm-deepseek-channel",
PlanName: "crm-deepseek-plan",
AccountNamePrefix: "deepseek-",
})
if err != nil {
t.Fatalf("ListManagedResources() error = %v", err)
}
if len(snapshot.Groups) != 1 || snapshot.Groups[0].ID != "group_1" {
t.Fatalf("Groups = %+v, want one group_1 match", snapshot.Groups)
}
if len(snapshot.Channels) != 2 || snapshot.Channels[0].ID != "channel_1" || snapshot.Channels[1].ID != "channel_2" {
t.Fatalf("Channels = %+v, want two matching channels", snapshot.Channels)
}
if len(snapshot.Plans) != 1 || snapshot.Plans[0].ID != "plan_1" {
t.Fatalf("Plans = %+v, want one plan_1 match", snapshot.Plans)
}
if len(snapshot.Accounts) != 2 || snapshot.Accounts[0].ID != "account_1" || snapshot.Accounts[1].ID != "account_2" {
t.Fatalf("Accounts = %+v, want two deepseek account matches", snapshot.Accounts)
}
}
type sub2APIStubConfig struct {
requireAPIKey bool
version string
@@ -220,52 +303,106 @@ func newSub2APIStubServer(t *testing.T, cfg sub2APIStubConfig) *httptest.Server
if !mustStubAuth(t, w, r, cfg.requireAPIKey) {
return
}
if r.Method != http.MethodPost {
switch r.Method {
case http.MethodGet:
writeJSON(t, w, http.StatusOK, map[string]any{
"data": []map[string]any{
{"id": "group_1", "name": "crm-deepseek-group"},
{"id": "group_2", "name": "other-group"},
},
})
case http.MethodPost:
var payload map[string]any
_ = json.NewDecoder(r.Body).Decode(&payload)
if payload["name"] == "" || payload["rate_multiplier"] == nil {
writeJSON(t, w, http.StatusUnprocessableEntity, map[string]any{"error": "validation failed"})
return
}
writeJSON(t, w, http.StatusCreated, map[string]any{
"data": map[string]any{
"id": "group_1",
"name": payload["name"],
},
})
default:
http.NotFound(w, r)
}
})
mux.HandleFunc("/api/v1/admin/groups/", func(w http.ResponseWriter, r *http.Request) {
if !mustStubAuth(t, w, r, cfg.requireAPIKey) {
return
}
if r.Method != http.MethodDelete {
http.NotFound(w, r)
return
}
var payload map[string]any
_ = json.NewDecoder(r.Body).Decode(&payload)
if payload["name"] == "" || payload["rate_multiplier"] == nil {
writeJSON(t, w, http.StatusUnprocessableEntity, map[string]any{"error": "validation failed"})
return
}
writeJSON(t, w, http.StatusCreated, map[string]any{
"data": map[string]any{
"id": "group_1",
"name": payload["name"],
},
})
w.WriteHeader(http.StatusNoContent)
})
mux.HandleFunc("/api/v1/admin/channels", func(w http.ResponseWriter, r *http.Request) {
if !mustStubAuth(t, w, r, cfg.requireAPIKey) {
return
}
if r.Method != http.MethodPost {
switch r.Method {
case http.MethodGet:
writeJSON(t, w, http.StatusOK, map[string]any{
"data": []map[string]any{
{"id": "channel_1", "name": "crm-deepseek-channel"},
{"id": "channel_2", "name": "crm-deepseek-channel"},
{"id": "channel_3", "name": "other-channel"},
},
})
case http.MethodPost:
writeJSON(t, w, http.StatusUnprocessableEntity, map[string]any{"error": "validation failed"})
default:
http.NotFound(w, r)
}
})
mux.HandleFunc("/api/v1/admin/channels/", func(w http.ResponseWriter, r *http.Request) {
if !mustStubAuth(t, w, r, cfg.requireAPIKey) {
return
}
if r.Method != http.MethodDelete {
http.NotFound(w, r)
return
}
writeJSON(t, w, http.StatusUnprocessableEntity, map[string]any{"error": "validation failed"})
w.WriteHeader(http.StatusNoContent)
})
mux.HandleFunc("/api/v1/admin/payment/plans", func(w http.ResponseWriter, r *http.Request) {
if !mustStubAuth(t, w, r, cfg.requireAPIKey) {
return
}
if r.Method != http.MethodPost {
switch r.Method {
case http.MethodGet:
writeJSON(t, w, http.StatusOK, map[string]any{
"data": []map[string]any{
{"id": "plan_1", "name": "crm-deepseek-plan"},
{"id": "plan_2", "name": "other-plan"},
},
})
case http.MethodPost:
writeJSON(t, w, http.StatusUnprocessableEntity, map[string]any{"error": "validation failed"})
default:
http.NotFound(w, r)
return
}
writeJSON(t, w, http.StatusUnprocessableEntity, map[string]any{"error": "validation failed"})
})
mux.HandleFunc("/api/v1/admin/accounts", func(w http.ResponseWriter, r *http.Request) {
if !mustStubAuth(t, w, r, cfg.requireAPIKey) {
return
}
if r.Method != http.MethodPost {
switch r.Method {
case http.MethodGet:
writeJSON(t, w, http.StatusOK, map[string]any{
"data": []map[string]any{
{"id": "account_1", "name": "deepseek-01"},
{"id": "account_2", "name": "deepseek-02"},
{"id": "account_3", "name": "other-01"},
},
})
case http.MethodPost:
writeJSON(t, w, http.StatusUnprocessableEntity, map[string]any{"error": "validation failed"})
default:
http.NotFound(w, r)
return
}
writeJSON(t, w, http.StatusUnprocessableEntity, map[string]any{"error": "validation failed"})
})
mux.HandleFunc("/api/v1/admin/accounts/batch", func(w http.ResponseWriter, r *http.Request) {
if !mustStubAuth(t, w, r, cfg.requireAPIKey) {
@@ -287,6 +424,14 @@ func newSub2APIStubServer(t *testing.T, cfg sub2APIStubConfig) *httptest.Server
return
}
parts := strings.Split(strings.TrimPrefix(r.URL.Path, "/api/v1/admin/accounts/"), "/")
if len(parts) == 1 {
if r.Method != http.MethodDelete {
http.NotFound(w, r)
return
}
w.WriteHeader(http.StatusNoContent)
return
}
if len(parts) != 2 {
http.NotFound(w, r)
return
@@ -333,6 +478,19 @@ func newSub2APIStubServer(t *testing.T, cfg sub2APIStubConfig) *httptest.Server
},
})
})
mux.HandleFunc("/v1/models", func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("x-api-key"); got != "user-api-key" {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"error":"unauthorized"}`))
return
}
writeJSON(t, w, http.StatusOK, map[string]any{
"data": []map[string]any{
{"id": "deepseek-chat"},
{"id": "deepseek-reasoner"},
},
})
})
return httptest.NewServer(mux)
}

View File

@@ -108,8 +108,8 @@ func TestStoreInitRecordsMigrationLedgerOnce(t *testing.T) {
if err != nil {
t.Fatalf("first sqlite.Open() error = %v", err)
}
if got := countRows(t, store1.SQLDB(), "schema_migrations"); got != 1 {
t.Fatalf("schema_migrations row count after first open = %d, want 1", got)
if got := countRows(t, store1.SQLDB(), "schema_migrations"); got != 3 {
t.Fatalf("schema_migrations row count after first open = %d, want 3", got)
}
if err := store1.Close(); err != nil {
t.Fatalf("first store.Close() error = %v", err)
@@ -121,8 +121,8 @@ func TestStoreInitRecordsMigrationLedgerOnce(t *testing.T) {
}
defer closeTestStore(t, store2)
if got := countRows(t, store2.SQLDB(), "schema_migrations"); got != 1 {
t.Fatalf("schema_migrations row count after second open = %d, want 1", got)
if got := countRows(t, store2.SQLDB(), "schema_migrations"); got != 3 {
t.Fatalf("schema_migrations row count after second open = %d, want 3", got)
}
}
@@ -140,8 +140,8 @@ func TestStoreInitBackfillsLedgerForCompletePreLedgerSchema(t *testing.T) {
}
defer closeTestStore(t, store)
if got := countRows(t, store.SQLDB(), "schema_migrations"); got != 1 {
t.Fatalf("schema_migrations row count after backfill = %d, want 1", got)
if got := countRows(t, store.SQLDB(), "schema_migrations"); got != 3 {
t.Fatalf("schema_migrations row count after backfill = %d, want 3", got)
}
}

View File

@@ -0,0 +1,141 @@
package integration_test
import (
"context"
"testing"
"sub2api-cn-relay-manager/internal/store/sqlite"
)
func TestStoreRuntimeCreatesOperationalTables(t *testing.T) {
store := openTestStore(t)
defer closeTestStore(t, store)
for _, table := range []string{
"hosts",
"packs",
"providers",
"import_batches",
"import_batch_items",
"managed_resources",
"probe_results",
"access_closure_records",
"reconcile_runs",
} {
if !tableExists(t, store.SQLDB(), table) {
t.Fatalf("table %q does not exist after store initialization", table)
}
}
}
func TestStoreRuntimePersistsOperationalRecords(t *testing.T) {
ctx := context.Background()
store := openTestStore(t)
defer closeTestStore(t, store)
hostID, err := store.Hosts().Create(ctx, sqlite.Host{
HostID: "host-1",
BaseURL: "https://sub2api.example.com",
HostVersion: "0.1.126",
CapabilityProbeJSON: `{"supports_batch_accounts":true}`,
})
if err != nil {
t.Fatalf("Hosts().Create() error = %v", err)
}
packID, err := store.Packs().Create(ctx, sqlite.Pack{
PackID: "openai-cn-pack",
Version: "1.0.0",
Checksum: "checksum-1",
})
if err != nil {
t.Fatalf("Packs().Create() error = %v", err)
}
providerID, err := store.Providers().Create(ctx, sqlite.Provider{
PackID: packID,
ProviderID: "deepseek",
DisplayName: "DeepSeek",
BaseURL: "https://api.deepseek.com",
Platform: "openai",
})
if err != nil {
t.Fatalf("Providers().Create() error = %v", err)
}
batchID, err := store.ImportBatches().Create(ctx, sqlite.ImportBatch{
HostID: hostID,
PackID: packID,
ProviderID: providerID,
Mode: "strict",
BatchStatus: "running",
AccessStatus: "not_configured",
})
if err != nil {
t.Fatalf("ImportBatches().Create() error = %v", err)
}
itemID, err := store.ImportBatchItems().Create(ctx, sqlite.ImportBatchItem{
BatchID: batchID,
KeyFingerprint: "fp-1",
AccountStatus: "pending",
ProbeSummaryJSON: `{}`,
})
if err != nil {
t.Fatalf("ImportBatchItems().Create() error = %v", err)
}
if _, err := store.ManagedResources().Create(ctx, sqlite.ManagedResource{
BatchID: batchID,
ResourceType: "group",
HostResourceID: "group-1",
ResourceName: "deepseek-group",
}); err != nil {
t.Fatalf("ManagedResources().Create() error = %v", err)
}
if _, err := store.ProbeResults().Create(ctx, sqlite.ProbeResult{
BatchItemID: itemID,
ProbeType: "models",
Status: "passed",
SummaryJSON: `{"models":["deepseek-chat"]}`,
}); err != nil {
t.Fatalf("ProbeResults().Create() error = %v", err)
}
if _, err := store.AccessClosures().Create(ctx, sqlite.AccessClosureRecord{
BatchID: batchID,
ClosureType: "subscription",
Status: "subscription_ready",
DetailsJSON: `{"api_key_bound":true}`,
}); err != nil {
t.Fatalf("AccessClosures().Create() error = %v", err)
}
if _, err := store.ReconcileRuns().Create(ctx, sqlite.ReconcileRun{
ProviderID: providerID,
Status: "drifted",
SummaryJSON: `{"missing_resources":1}`,
}); err != nil {
t.Fatalf("ReconcileRuns().Create() error = %v", err)
}
if got := countRows(t, store.SQLDB(), "import_batches"); got != 1 {
t.Fatalf("import_batches row count = %d, want 1", got)
}
if got := countRows(t, store.SQLDB(), "import_batch_items"); got != 1 {
t.Fatalf("import_batch_items row count = %d, want 1", got)
}
if got := countRows(t, store.SQLDB(), "managed_resources"); got != 1 {
t.Fatalf("managed_resources row count = %d, want 1", got)
}
if got := countRows(t, store.SQLDB(), "probe_results"); got != 1 {
t.Fatalf("probe_results row count = %d, want 1", got)
}
if got := countRows(t, store.SQLDB(), "access_closure_records"); got != 1 {
t.Fatalf("access_closure_records row count = %d, want 1", got)
}
if got := countRows(t, store.SQLDB(), "reconcile_runs"); got != 1 {
t.Fatalf("reconcile_runs row count = %d, want 1", got)
}
}