test(project): achieve ≥70% package coverage across all internal packages
- store/sqlite: 75.4% (repos + db coverage) - host/sub2api: 80.8% (httptest mock server, pure function tests) - app: 74.2% (handler error paths, NewActionSet closures) - pack: 72.4% - provision: 75.2% - access: 77.3% - config: 94.7% (lookup mock tests) All tests pass: build, vet, race, coverage gates.
This commit is contained in:
3
.env.example
Normal file
3
.env.example
Normal file
@@ -0,0 +1,3 @@
|
||||
SUB2API_CRM_LISTEN_ADDR=:8080
|
||||
SUB2API_CRM_SQLITE_DSN=file:/data/sub2api-cn-relay-manager.db?_foreign_keys=on&_busy_timeout=5000
|
||||
SUB2API_CRM_ADMIN_TOKEN=change-me-before-production
|
||||
56
AGENTS.md
Normal file
56
AGENTS.md
Normal file
@@ -0,0 +1,56 @@
|
||||
# sub2api-cn-relay-manager — Agent Guidelines
|
||||
|
||||
## 项目关键信息
|
||||
- Go 1.22.2, 纯 Go (modernc.org/sqlite, 无 CGO)
|
||||
- 零侵入宿主:不修改 sub2api 源码,不写宿主数据库
|
||||
- 所有 schema 变更通过 `internal/store/sqlite/` 下的 repo + integration test 验证
|
||||
- docs/ 下有 PRD.md、TDD_PLAN.md、EXECUTION_BOARD.md、solution 文档
|
||||
|
||||
## 质量门禁(每个模块完成前必须执行)
|
||||
|
||||
1. **设计对齐** — 重新读取 PRD.md、TDD_PLAN.md、EXECUTION_BOARD.md、docs/plans/ 下的规划设计文档,逐条确认实现已覆盖设计目标。发现漂移先修正,不维持虚假 COMPLETED。
|
||||
2. **代码 review** — 加载 `go-reviewer` skill,对新写/修改的全部 Go 文件做系统审查。
|
||||
3. **测试覆盖** — `go test -cover ./internal/...` 核心包(provision、access、pack)覆盖率 >= 70%。未达标则补用例。
|
||||
4. **静态分析** — `go vet ./...` 零警告。`gofmt -l .` 显示无未格式化文件。
|
||||
5. **集成验证** — `go test ./tests/integration/... -count=1` 必须通过。
|
||||
6. **板同步** — 更新 EXECUTION_BOARD.md,反映真实完成状态。
|
||||
|
||||
## Go 编码规范
|
||||
|
||||
### 包结构
|
||||
```
|
||||
internal/
|
||||
access/ — 访问闭环(订阅分配、探测、自服务检查)
|
||||
app/ — HTTP 控制面(bootstrap, server, API handlers)
|
||||
host/sub2api/ — 宿主适配器
|
||||
pack/ — pack 装载与校验
|
||||
provision/ — 导入编排(import, preview, reconcile, rollback)
|
||||
store/sqlite/ — SQLite 数据访问层(repo 模式)
|
||||
cmd/
|
||||
cli/ — CLI 入口
|
||||
server/ — HTTP server 入口
|
||||
tests/integration/ — 集成测试套件
|
||||
```
|
||||
|
||||
### 代码风格
|
||||
- 标准 Go:4-space tabs, 花括号同行, 单 class/struct 文件
|
||||
- 包名小写,与目录名一致
|
||||
- 错误处理用 `fmt.Errorf("context: %w", err)` 包裹
|
||||
- 常量分组在文件顶部,`const ( Name = "value" )`
|
||||
- Repository 模式:`type XRepo struct { db execQuerier }` + `newXRepo(db)`
|
||||
- Context 作为第一个参数传入所有 DB/SQL 操作
|
||||
- 接口定义在使用方,不在实现方
|
||||
- 测试用 fake/mock adapter 而非真实 HTTP
|
||||
|
||||
### 测试规范
|
||||
- 文件名 `*_test.go` 与源码同包
|
||||
- 方法名 `TestXxxFlow` / `TestXxxWhenY` 格式
|
||||
- 优先使用 FakeHostAdapter(已在 provision 包中定义)而不是 mock 框架
|
||||
- 集成测试放在 `tests/integration/`,使用真实 SQLite 内存库
|
||||
- 测试函数必须 `t.Parallel()` 安全(使用独立 SQLite 连接)
|
||||
|
||||
## 重要约束
|
||||
- 不要运行 `go get` / `go mod tidy` — 源码写完后告诉用户手动安装依赖
|
||||
- 不改动 go.mod 中的依赖版本
|
||||
- 所有功能必须配套测试,集成测试优先
|
||||
- 不允许跳过 quality gate 中的任何一步
|
||||
19
Dockerfile
Normal file
19
Dockerfile
Normal file
@@ -0,0 +1,19 @@
|
||||
FROM golang:1.22.2 AS builder
|
||||
WORKDIR /src
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
COPY . .
|
||||
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -trimpath -ldflags='-s -w' -o /out/sub2api-cn-relay-manager ./cmd/server
|
||||
|
||||
FROM debian:bookworm-slim
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends ca-certificates tzdata wget \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
WORKDIR /app
|
||||
COPY --from=builder /out/sub2api-cn-relay-manager /usr/local/bin/sub2api-cn-relay-manager
|
||||
ENV SUB2API_CRM_LISTEN_ADDR=:8080
|
||||
ENV SUB2API_CRM_SQLITE_DSN=file:/data/sub2api-cn-relay-manager.db?_foreign_keys=on&_busy_timeout=5000
|
||||
ENV SUB2API_CRM_ADMIN_TOKEN=
|
||||
VOLUME ["/data"]
|
||||
EXPOSE 8080
|
||||
ENTRYPOINT ["/usr/local/bin/sub2api-cn-relay-manager"]
|
||||
43
README.md
43
README.md
@@ -63,4 +63,47 @@ sub2api-cn-relay-manager/
|
||||
完整方案见:
|
||||
|
||||
- [docs/2026-05-12-sub2api-cn-relay-manager-solution.md](./docs/2026-05-12-sub2api-cn-relay-manager-solution.md)
|
||||
- [docs/PRD.md](./docs/PRD.md)
|
||||
- [docs/TDD_PLAN.md](./docs/TDD_PLAN.md)
|
||||
- [docs/EXECUTION_BOARD.md](./docs/EXECUTION_BOARD.md)
|
||||
- [docs/DEPLOYMENT.md](./docs/DEPLOYMENT.md)
|
||||
|
||||
## 当前 MVP 能力
|
||||
|
||||
当前仓库已经具备一个最小可运行闭环:
|
||||
|
||||
- `packs/openai-cn-pack/` 提供真实 `pack.json + provider + checksums`
|
||||
- `internal/pack` 负责 pack 装载、checksum 校验、provider schema 校验
|
||||
- `internal/provision` 负责多 key 导入编排、账号探测和访问闭环判定
|
||||
- `cmd/cli import-provider` 提供一键导入入口
|
||||
|
||||
示例:
|
||||
|
||||
```bash
|
||||
go run ./cmd/cli import-provider \
|
||||
--host-base-url https://sub2api.example.com \
|
||||
--host-api-key <admin-api-key> \
|
||||
--pack-dir ./packs/openai-cn-pack \
|
||||
--provider-id deepseek \
|
||||
--keys sk-a,sk-b \
|
||||
--access-mode self_service \
|
||||
--access-api-key <user-api-key>
|
||||
```
|
||||
|
||||
|
||||
## 运行方式
|
||||
|
||||
服务端:
|
||||
|
||||
```bash
|
||||
SUB2API_CRM_ADMIN_TOKEN=replace-me go run ./cmd/server
|
||||
```
|
||||
|
||||
Docker Compose:
|
||||
|
||||
```bash
|
||||
cp .env.example .env
|
||||
# 编辑 .env 中的 SUB2API_CRM_ADMIN_TOKEN
|
||||
docker compose up --build -d
|
||||
curl -fsS http://127.0.0.1:8080/healthz
|
||||
```
|
||||
|
||||
464
cmd/cli/main.go
464
cmd/cli/main.go
@@ -2,27 +2,483 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/config"
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
"sub2api-cn-relay-manager/internal/pack"
|
||||
"sub2api-cn-relay-manager/internal/provision"
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
type installPackFunc func(context.Context, installPackCLIRequest) (provision.PackInstallResult, error)
|
||||
type importProviderFunc func(context.Context, importCLIRequest) (provision.ImportReport, error)
|
||||
type previewProviderFunc func(context.Context, previewCLIRequest) (provision.PreviewReport, error)
|
||||
type rollbackProviderFunc func(context.Context, rollbackCLIRequest) (rollbackSummary, error)
|
||||
type reconcileProviderFunc func(context.Context, reconcileCLIRequest) (provision.ReconcileResult, error)
|
||||
|
||||
type installPackCLIRequest struct {
|
||||
HostBaseURL string
|
||||
HostAPIKey string
|
||||
HostBearerToken string
|
||||
PackPath string
|
||||
}
|
||||
|
||||
type importCLIRequest struct {
|
||||
HostBaseURL string
|
||||
HostAPIKey string
|
||||
HostBearerToken string
|
||||
PackDir string
|
||||
ProviderID string
|
||||
Keys []string
|
||||
Mode string
|
||||
AccessMode string
|
||||
AccessAPIKey string
|
||||
SubscriptionUsers []string
|
||||
SubscriptionDays int
|
||||
}
|
||||
|
||||
type previewCLIRequest struct {
|
||||
HostBaseURL string
|
||||
HostAPIKey string
|
||||
HostBearerToken string
|
||||
PackDir string
|
||||
ProviderID string
|
||||
Keys []string
|
||||
Mode string
|
||||
}
|
||||
|
||||
type rollbackCLIRequest struct {
|
||||
HostBaseURL string
|
||||
HostAPIKey string
|
||||
HostBearerToken string
|
||||
PackDir string
|
||||
ProviderID string
|
||||
}
|
||||
|
||||
type reconcileCLIRequest struct {
|
||||
HostBaseURL string
|
||||
HostAPIKey string
|
||||
HostBearerToken string
|
||||
PackDir string
|
||||
ProviderID string
|
||||
AccessAPIKey string
|
||||
}
|
||||
|
||||
type rollbackSummary struct {
|
||||
Accounts int
|
||||
Plans int
|
||||
Channels int
|
||||
Groups int
|
||||
}
|
||||
|
||||
func main() {
|
||||
if err := execute(context.Background(), log.Writer(), func(context.Context) (config.StartupConfig, error) {
|
||||
if err := execute(context.Background(), log.Writer(), os.Args[1:], func(context.Context) (config.StartupConfig, error) {
|
||||
return config.LoadStartupFromEnv()
|
||||
}); err != nil {
|
||||
}, runInstallPack, runImportProvider, runPreviewProvider, runRollbackProvider, runReconcileProvider); err != nil {
|
||||
log.Fatalf("run cli: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func execute(ctx context.Context, output io.Writer, loadConfig func(context.Context) (config.StartupConfig, error)) error {
|
||||
func execute(
|
||||
ctx context.Context,
|
||||
output io.Writer,
|
||||
args []string,
|
||||
loadConfig func(context.Context) (config.StartupConfig, error),
|
||||
installPack installPackFunc,
|
||||
importProvider importProviderFunc,
|
||||
previewProvider previewProviderFunc,
|
||||
rollbackProvider rollbackProviderFunc,
|
||||
reconcileProvider reconcileProviderFunc,
|
||||
) error {
|
||||
if len(args) > 0 && args[0] == "install-pack" {
|
||||
req, err := parseInstallPackCLIArgs(args[1:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result, err := installPack(ctx, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = fmt.Fprintf(output, "pack_id=%s\nversion=%s\nhost_version=%s\nproviders=%d\nalready_installed=%t\n", result.Pack.PackID, result.Pack.Version, result.HostVersion, len(result.Providers), result.AlreadyInstalled)
|
||||
return err
|
||||
}
|
||||
if len(args) > 0 && args[0] == "import-provider" {
|
||||
req, err := parseImportCLIArgs(args[1:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
report, err := importProvider(ctx, req)
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(output, "batch_status=%s\nprovider_status=%s\naccess_status=%s\n", report.BatchStatus, report.ProviderStatus, report.AccessStatus)
|
||||
return err
|
||||
}
|
||||
_, err = fmt.Fprintf(output, "batch_status=%s\nprovider_status=%s\naccess_status=%s\naccounts=%d\n", report.BatchStatus, report.ProviderStatus, report.AccessStatus, len(report.Accounts))
|
||||
return err
|
||||
}
|
||||
if len(args) > 0 && args[0] == "preview-provider" {
|
||||
req, err := parsePreviewCLIArgs(args[1:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
report, err := previewProvider(ctx, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = fmt.Fprintf(output, "accepted_keys=%d\ngroup=%s\nchannel=%s\nplan=%s\n", len(report.AcceptedKeys), report.Decisions["group"].Action, report.Decisions["channel"].Action, report.Decisions["plan"].Action)
|
||||
return err
|
||||
}
|
||||
if len(args) > 0 && args[0] == "rollback-provider" {
|
||||
req, err := parseRollbackCLIArgs(args[1:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
summary, err := rollbackProvider(ctx, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = fmt.Fprintf(output, "deleted_accounts=%d\ndeleted_plans=%d\ndeleted_channels=%d\ndeleted_groups=%d\n", summary.Accounts, summary.Plans, summary.Channels, summary.Groups)
|
||||
return err
|
||||
}
|
||||
if len(args) > 0 && args[0] == "reconcile-provider" {
|
||||
req, err := parseReconcileCLIArgs(args[1:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result, err := reconcileProvider(ctx, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = fmt.Fprintf(output, "status=%s\nmissing_count=%d\nextra_count=%d\nprobe_failures=%d\naccess_status=%s\n", result.Status, result.MissingCount, result.ExtraCount, result.ProbeFailureCount, result.AccessStatus)
|
||||
return err
|
||||
}
|
||||
|
||||
cfg, err := loadConfig(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = fmt.Fprintf(output, "sub2api-cn-relay-manager cli ready\nlisten_addr=%s\nsqlite_dsn=%s\n", cfg.Server.ListenAddr, cfg.Database.SQLiteDSN)
|
||||
return err
|
||||
}
|
||||
|
||||
func parseInstallPackCLIArgs(args []string) (installPackCLIRequest, error) {
|
||||
fs := flag.NewFlagSet("install-pack", flag.ContinueOnError)
|
||||
fs.SetOutput(io.Discard)
|
||||
|
||||
var req installPackCLIRequest
|
||||
fs.StringVar(&req.HostBaseURL, "host-base-url", "", "")
|
||||
fs.StringVar(&req.HostAPIKey, "host-api-key", "", "")
|
||||
fs.StringVar(&req.HostBearerToken, "host-bearer-token", "", "")
|
||||
fs.StringVar(&req.PackPath, "pack-path", "", "")
|
||||
if err := fs.Parse(args); err != nil {
|
||||
return installPackCLIRequest{}, err
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.TrimSpace(req.HostBaseURL) == "":
|
||||
return installPackCLIRequest{}, fmt.Errorf("--host-base-url is required")
|
||||
case strings.TrimSpace(req.PackPath) == "":
|
||||
return installPackCLIRequest{}, fmt.Errorf("--pack-path is required")
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func parseImportCLIArgs(args []string) (importCLIRequest, error) {
|
||||
fs := flag.NewFlagSet("import-provider", flag.ContinueOnError)
|
||||
fs.SetOutput(io.Discard)
|
||||
|
||||
var req importCLIRequest
|
||||
var keysCSV string
|
||||
var subscriptionUsersCSV string
|
||||
fs.StringVar(&req.HostBaseURL, "host-base-url", "", "")
|
||||
fs.StringVar(&req.HostAPIKey, "host-api-key", "", "")
|
||||
fs.StringVar(&req.HostBearerToken, "host-bearer-token", "", "")
|
||||
fs.StringVar(&req.PackDir, "pack-dir", "", "")
|
||||
fs.StringVar(&req.ProviderID, "provider-id", "", "")
|
||||
fs.StringVar(&keysCSV, "keys", "", "")
|
||||
fs.StringVar(&req.Mode, "mode", provision.ImportModePartial, "")
|
||||
fs.StringVar(&req.AccessMode, "access-mode", provision.AccessModeSelfService, "")
|
||||
fs.StringVar(&req.AccessAPIKey, "access-api-key", "", "")
|
||||
fs.StringVar(&subscriptionUsersCSV, "subscription-users", "", "")
|
||||
fs.IntVar(&req.SubscriptionDays, "subscription-days", 30, "")
|
||||
if err := fs.Parse(args); err != nil {
|
||||
return importCLIRequest{}, err
|
||||
}
|
||||
|
||||
req.Keys = splitCSV(keysCSV)
|
||||
req.SubscriptionUsers = splitCSV(subscriptionUsersCSV)
|
||||
switch {
|
||||
case strings.TrimSpace(req.HostBaseURL) == "":
|
||||
return importCLIRequest{}, fmt.Errorf("--host-base-url is required")
|
||||
case strings.TrimSpace(req.PackDir) == "":
|
||||
return importCLIRequest{}, fmt.Errorf("--pack-dir is required")
|
||||
case strings.TrimSpace(req.ProviderID) == "":
|
||||
return importCLIRequest{}, fmt.Errorf("--provider-id is required")
|
||||
case len(req.Keys) == 0:
|
||||
return importCLIRequest{}, fmt.Errorf("--keys is required")
|
||||
case strings.TrimSpace(req.AccessAPIKey) == "":
|
||||
return importCLIRequest{}, fmt.Errorf("--access-api-key is required")
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func parsePreviewCLIArgs(args []string) (previewCLIRequest, error) {
|
||||
fs := flag.NewFlagSet("preview-provider", flag.ContinueOnError)
|
||||
fs.SetOutput(io.Discard)
|
||||
|
||||
var req previewCLIRequest
|
||||
var keysCSV string
|
||||
fs.StringVar(&req.HostBaseURL, "host-base-url", "", "")
|
||||
fs.StringVar(&req.HostAPIKey, "host-api-key", "", "")
|
||||
fs.StringVar(&req.HostBearerToken, "host-bearer-token", "", "")
|
||||
fs.StringVar(&req.PackDir, "pack-dir", "", "")
|
||||
fs.StringVar(&req.ProviderID, "provider-id", "", "")
|
||||
fs.StringVar(&keysCSV, "keys", "", "")
|
||||
fs.StringVar(&req.Mode, "mode", provision.ImportModePartial, "")
|
||||
if err := fs.Parse(args); err != nil {
|
||||
return previewCLIRequest{}, err
|
||||
}
|
||||
|
||||
req.Keys = splitCSV(keysCSV)
|
||||
switch {
|
||||
case strings.TrimSpace(req.HostBaseURL) == "":
|
||||
return previewCLIRequest{}, fmt.Errorf("--host-base-url is required")
|
||||
case strings.TrimSpace(req.PackDir) == "":
|
||||
return previewCLIRequest{}, fmt.Errorf("--pack-dir is required")
|
||||
case strings.TrimSpace(req.ProviderID) == "":
|
||||
return previewCLIRequest{}, fmt.Errorf("--provider-id is required")
|
||||
case len(req.Keys) == 0:
|
||||
return previewCLIRequest{}, fmt.Errorf("--keys is required")
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func parseRollbackCLIArgs(args []string) (rollbackCLIRequest, error) {
|
||||
fs := flag.NewFlagSet("rollback-provider", flag.ContinueOnError)
|
||||
fs.SetOutput(io.Discard)
|
||||
|
||||
var req rollbackCLIRequest
|
||||
fs.StringVar(&req.HostBaseURL, "host-base-url", "", "")
|
||||
fs.StringVar(&req.HostAPIKey, "host-api-key", "", "")
|
||||
fs.StringVar(&req.HostBearerToken, "host-bearer-token", "", "")
|
||||
fs.StringVar(&req.PackDir, "pack-dir", "", "")
|
||||
fs.StringVar(&req.ProviderID, "provider-id", "", "")
|
||||
if err := fs.Parse(args); err != nil {
|
||||
return rollbackCLIRequest{}, err
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.TrimSpace(req.HostBaseURL) == "":
|
||||
return rollbackCLIRequest{}, fmt.Errorf("--host-base-url is required")
|
||||
case strings.TrimSpace(req.PackDir) == "":
|
||||
return rollbackCLIRequest{}, fmt.Errorf("--pack-dir is required")
|
||||
case strings.TrimSpace(req.ProviderID) == "":
|
||||
return rollbackCLIRequest{}, fmt.Errorf("--provider-id is required")
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func parseReconcileCLIArgs(args []string) (reconcileCLIRequest, error) {
|
||||
fs := flag.NewFlagSet("reconcile-provider", flag.ContinueOnError)
|
||||
fs.SetOutput(io.Discard)
|
||||
|
||||
var req reconcileCLIRequest
|
||||
fs.StringVar(&req.HostBaseURL, "host-base-url", "", "")
|
||||
fs.StringVar(&req.HostAPIKey, "host-api-key", "", "")
|
||||
fs.StringVar(&req.HostBearerToken, "host-bearer-token", "", "")
|
||||
fs.StringVar(&req.PackDir, "pack-dir", "", "")
|
||||
fs.StringVar(&req.ProviderID, "provider-id", "", "")
|
||||
fs.StringVar(&req.AccessAPIKey, "access-api-key", "", "")
|
||||
if err := fs.Parse(args); err != nil {
|
||||
return reconcileCLIRequest{}, err
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.TrimSpace(req.HostBaseURL) == "":
|
||||
return reconcileCLIRequest{}, fmt.Errorf("--host-base-url is required")
|
||||
case strings.TrimSpace(req.PackDir) == "":
|
||||
return reconcileCLIRequest{}, fmt.Errorf("--pack-dir is required")
|
||||
case strings.TrimSpace(req.ProviderID) == "":
|
||||
return reconcileCLIRequest{}, fmt.Errorf("--provider-id is required")
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func runInstallPack(ctx context.Context, req installPackCLIRequest) (provision.PackInstallResult, error) {
|
||||
loadedPack, err := pack.LoadPath(req.PackPath)
|
||||
if err != nil {
|
||||
return provision.PackInstallResult{}, err
|
||||
}
|
||||
client, err := sub2api.NewClient(req.HostBaseURL, sub2api.WithAPIKey(req.HostAPIKey), sub2api.WithBearerToken(req.HostBearerToken))
|
||||
if err != nil {
|
||||
return provision.PackInstallResult{}, err
|
||||
}
|
||||
startupConfig, err := config.LoadStartupFromEnv()
|
||||
if err != nil {
|
||||
return provision.PackInstallResult{}, err
|
||||
}
|
||||
store, err := sqlite.Open(ctx, startupConfig.Database.SQLiteDSN)
|
||||
if err != nil {
|
||||
return provision.PackInstallResult{}, err
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
service := provision.NewPackInstallService(store, client)
|
||||
return service.Install(ctx, provision.PackInstallRequest{Pack: loadedPack})
|
||||
}
|
||||
|
||||
func runImportProvider(ctx context.Context, req importCLIRequest) (provision.ImportReport, error) {
|
||||
loadedPack, err := pack.LoadDir(req.PackDir)
|
||||
if err != nil {
|
||||
return provision.ImportReport{}, err
|
||||
}
|
||||
|
||||
providerManifest, err := findProvider(loadedPack, req.ProviderID)
|
||||
if err != nil {
|
||||
return provision.ImportReport{}, err
|
||||
}
|
||||
|
||||
client, err := sub2api.NewClient(req.HostBaseURL, sub2api.WithAPIKey(req.HostAPIKey), sub2api.WithBearerToken(req.HostBearerToken))
|
||||
if err != nil {
|
||||
return provision.ImportReport{}, err
|
||||
}
|
||||
|
||||
startupConfig, err := config.LoadStartupFromEnv()
|
||||
if err != nil {
|
||||
return provision.ImportReport{}, err
|
||||
}
|
||||
store, err := sqlite.Open(ctx, startupConfig.Database.SQLiteDSN)
|
||||
if err != nil {
|
||||
return provision.ImportReport{}, err
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
subscriptions := make([]provision.SubscriptionTarget, 0, len(req.SubscriptionUsers))
|
||||
for _, userID := range req.SubscriptionUsers {
|
||||
subscriptions = append(subscriptions, provision.SubscriptionTarget{UserID: userID, DurationDays: req.SubscriptionDays})
|
||||
}
|
||||
|
||||
runtimeService := provision.NewRuntimeImportService(store, client)
|
||||
result, err := runtimeService.Import(ctx, provision.RuntimeImportRequest{
|
||||
HostBaseURL: req.HostBaseURL,
|
||||
Pack: loadedPack,
|
||||
Provider: providerManifest,
|
||||
Mode: req.Mode,
|
||||
Keys: req.Keys,
|
||||
Access: provision.AccessRequest{
|
||||
Mode: req.AccessMode,
|
||||
ProbeAPIKey: req.AccessAPIKey,
|
||||
Subscriptions: subscriptions,
|
||||
},
|
||||
})
|
||||
return result.Report, err
|
||||
}
|
||||
|
||||
func runPreviewProvider(ctx context.Context, req previewCLIRequest) (provision.PreviewReport, error) {
|
||||
loadedPack, err := pack.LoadDir(req.PackDir)
|
||||
if err != nil {
|
||||
return provision.PreviewReport{}, err
|
||||
}
|
||||
|
||||
providerManifest, err := findProvider(loadedPack, req.ProviderID)
|
||||
if err != nil {
|
||||
return provision.PreviewReport{}, err
|
||||
}
|
||||
|
||||
client, err := sub2api.NewClient(req.HostBaseURL, sub2api.WithAPIKey(req.HostAPIKey), sub2api.WithBearerToken(req.HostBearerToken))
|
||||
if err != nil {
|
||||
return provision.PreviewReport{}, err
|
||||
}
|
||||
|
||||
service := provision.NewPreviewService(client)
|
||||
return service.PreviewImport(ctx, provision.PreviewRequest{
|
||||
Provider: providerManifest,
|
||||
Mode: req.Mode,
|
||||
Keys: req.Keys,
|
||||
})
|
||||
}
|
||||
|
||||
func runRollbackProvider(ctx context.Context, req rollbackCLIRequest) (rollbackSummary, error) {
|
||||
loadedPack, err := pack.LoadDir(req.PackDir)
|
||||
if err != nil {
|
||||
return rollbackSummary{}, err
|
||||
}
|
||||
|
||||
providerManifest, err := findProvider(loadedPack, req.ProviderID)
|
||||
if err != nil {
|
||||
return rollbackSummary{}, err
|
||||
}
|
||||
|
||||
client, err := sub2api.NewClient(req.HostBaseURL, sub2api.WithAPIKey(req.HostAPIKey), sub2api.WithBearerToken(req.HostBearerToken))
|
||||
if err != nil {
|
||||
return rollbackSummary{}, err
|
||||
}
|
||||
|
||||
service := provision.NewRollbackService(client)
|
||||
report, err := service.Rollback(ctx, provision.RollbackRequest{Provider: providerManifest})
|
||||
if err != nil {
|
||||
return rollbackSummary{}, err
|
||||
}
|
||||
return rollbackSummary{
|
||||
Accounts: report.AccountsDeleted,
|
||||
Plans: report.PlansDeleted,
|
||||
Channels: report.ChannelsDeleted,
|
||||
Groups: report.GroupsDeleted,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func runReconcileProvider(ctx context.Context, req reconcileCLIRequest) (provision.ReconcileResult, error) {
|
||||
loadedPack, err := pack.LoadDir(req.PackDir)
|
||||
if err != nil {
|
||||
return provision.ReconcileResult{}, err
|
||||
}
|
||||
|
||||
providerManifest, err := findProvider(loadedPack, req.ProviderID)
|
||||
if err != nil {
|
||||
return provision.ReconcileResult{}, err
|
||||
}
|
||||
|
||||
client, err := sub2api.NewClient(req.HostBaseURL, sub2api.WithAPIKey(req.HostAPIKey), sub2api.WithBearerToken(req.HostBearerToken))
|
||||
if err != nil {
|
||||
return provision.ReconcileResult{}, err
|
||||
}
|
||||
|
||||
startupConfig, err := config.LoadStartupFromEnv()
|
||||
if err != nil {
|
||||
return provision.ReconcileResult{}, err
|
||||
}
|
||||
store, err := sqlite.Open(ctx, startupConfig.Database.SQLiteDSN)
|
||||
if err != nil {
|
||||
return provision.ReconcileResult{}, err
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
service := provision.NewReconcileService(store, client)
|
||||
return service.Reconcile(ctx, provision.ReconcileRequest{HostBaseURL: req.HostBaseURL, AccessProbeAPIKey: req.AccessAPIKey, Pack: loadedPack, Provider: providerManifest})
|
||||
}
|
||||
|
||||
func findProvider(loaded pack.LoadedPack, providerID string) (pack.ProviderManifest, error) {
|
||||
for _, provider := range loaded.Providers {
|
||||
if provider.ProviderID == strings.TrimSpace(providerID) {
|
||||
return provider, nil
|
||||
}
|
||||
}
|
||||
return pack.ProviderManifest{}, fmt.Errorf("provider %q not found in pack %q", providerID, loaded.Manifest.PackID)
|
||||
}
|
||||
|
||||
func splitCSV(value string) []string {
|
||||
parts := strings.Split(value, ",")
|
||||
result := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
trimmed := strings.TrimSpace(part)
|
||||
if trimmed != "" {
|
||||
result = append(result, trimmed)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"testing"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/config"
|
||||
"sub2api-cn-relay-manager/internal/provision"
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
type errWriter struct {
|
||||
@@ -22,7 +24,7 @@ func TestExecuteWritesConfigSummaryAfterBootstrap(t *testing.T) {
|
||||
var output bytes.Buffer
|
||||
loadCalled := false
|
||||
|
||||
err := execute(context.Background(), &output, func(context.Context) (config.StartupConfig, error) {
|
||||
err := execute(context.Background(), &output, nil, func(context.Context) (config.StartupConfig, error) {
|
||||
loadCalled = true
|
||||
return config.StartupConfig{
|
||||
Server: config.ServerConfig{ListenAddr: ":9191"},
|
||||
@@ -30,7 +32,7 @@ func TestExecuteWritesConfigSummaryAfterBootstrap(t *testing.T) {
|
||||
SQLiteDSN: "file:test.db?_foreign_keys=on",
|
||||
},
|
||||
}, nil
|
||||
})
|
||||
}, nil, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("execute() returned error: %v", err)
|
||||
}
|
||||
@@ -56,9 +58,9 @@ func TestExecuteWritesConfigSummaryAfterBootstrap(t *testing.T) {
|
||||
func TestExecuteReturnsBootstrapError(t *testing.T) {
|
||||
wantErr := errors.New("load config failed")
|
||||
|
||||
err := execute(context.Background(), &bytes.Buffer{}, func(context.Context) (config.StartupConfig, error) {
|
||||
err := execute(context.Background(), &bytes.Buffer{}, nil, func(context.Context) (config.StartupConfig, error) {
|
||||
return config.StartupConfig{}, wantErr
|
||||
})
|
||||
}, nil, nil, nil, nil, nil)
|
||||
if !errors.Is(err, wantErr) {
|
||||
t.Fatalf("execute() error = %v, want %v", err, wantErr)
|
||||
}
|
||||
@@ -67,13 +69,208 @@ func TestExecuteReturnsBootstrapError(t *testing.T) {
|
||||
func TestExecuteReturnsWriteError(t *testing.T) {
|
||||
wantErr := errors.New("write failed")
|
||||
|
||||
err := execute(context.Background(), errWriter{err: wantErr}, func(context.Context) (config.StartupConfig, error) {
|
||||
err := execute(context.Background(), errWriter{err: wantErr}, nil, func(context.Context) (config.StartupConfig, error) {
|
||||
return config.StartupConfig{
|
||||
Server: config.ServerConfig{ListenAddr: ":9292"},
|
||||
Database: config.DatabaseConfig{SQLiteDSN: "file:test.db"},
|
||||
}, nil
|
||||
})
|
||||
}, nil, nil, nil, nil, nil)
|
||||
if !errors.Is(err, wantErr) {
|
||||
t.Fatalf("execute() error = %v, want %v", err, wantErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteInstallPackWritesSummary(t *testing.T) {
|
||||
var output bytes.Buffer
|
||||
installCalled := false
|
||||
|
||||
err := execute(context.Background(), &output, []string{
|
||||
"install-pack",
|
||||
"--host-base-url", "https://sub2api.example.com",
|
||||
"--pack-path", "/tmp/openai-pack.zip",
|
||||
}, nil, func(_ context.Context, req installPackCLIRequest) (provision.PackInstallResult, error) {
|
||||
installCalled = true
|
||||
if req.PackPath != "/tmp/openai-pack.zip" {
|
||||
t.Fatalf("unexpected install request: %+v", req)
|
||||
}
|
||||
return provision.PackInstallResult{
|
||||
Pack: sqlite.Pack{PackID: "openai-cn-pack", Version: "1.0.0"},
|
||||
HostVersion: "0.1.126",
|
||||
Providers: []sqlite.Provider{{ProviderID: "deepseek"}},
|
||||
}, nil
|
||||
}, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("execute() install-pack error = %v", err)
|
||||
}
|
||||
if !installCalled {
|
||||
t.Fatal("execute() did not invoke installPack")
|
||||
}
|
||||
got := output.String()
|
||||
if !strings.Contains(got, "pack_id=openai-cn-pack") || !strings.Contains(got, "providers=1") || !strings.Contains(got, "host_version=0.1.126") {
|
||||
t.Fatalf("execute() install-pack output = %q, want summary", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteImportProviderWritesSummary(t *testing.T) {
|
||||
var output bytes.Buffer
|
||||
importCalled := false
|
||||
|
||||
err := execute(context.Background(), &output, []string{
|
||||
"import-provider",
|
||||
"--host-base-url", "https://sub2api.example.com",
|
||||
"--pack-dir", "/tmp/pack",
|
||||
"--provider-id", "deepseek",
|
||||
"--keys", "k1,k2",
|
||||
"--access-api-key", "user-key",
|
||||
}, nil, nil, func(_ context.Context, req importCLIRequest) (provision.ImportReport, error) {
|
||||
importCalled = true
|
||||
if req.ProviderID != "deepseek" || len(req.Keys) != 2 {
|
||||
t.Fatalf("unexpected import request: %+v", req)
|
||||
}
|
||||
return provision.ImportReport{
|
||||
BatchStatus: provision.BatchStatusSucceeded,
|
||||
ProviderStatus: provision.ProviderStatusActive,
|
||||
AccessStatus: provision.AccessStatusSelfServiceReady,
|
||||
Accounts: []provision.AccountImportResult{{}, {}},
|
||||
}, nil
|
||||
}, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("execute() import error = %v", err)
|
||||
}
|
||||
if !importCalled {
|
||||
t.Fatal("execute() did not invoke importProvider")
|
||||
}
|
||||
got := output.String()
|
||||
if !strings.Contains(got, "batch_status=succeeded") || !strings.Contains(got, "accounts=2") {
|
||||
t.Fatalf("execute() import output = %q, want summary", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutePreviewProviderWritesSummary(t *testing.T) {
|
||||
var output bytes.Buffer
|
||||
previewCalled := false
|
||||
|
||||
err := execute(context.Background(), &output, []string{
|
||||
"preview-provider",
|
||||
"--host-base-url", "https://sub2api.example.com",
|
||||
"--pack-dir", "/tmp/pack",
|
||||
"--provider-id", "deepseek",
|
||||
"--keys", "k1,k2",
|
||||
}, nil, nil, nil, func(_ context.Context, req previewCLIRequest) (provision.PreviewReport, error) {
|
||||
previewCalled = true
|
||||
if req.ProviderID != "deepseek" || len(req.Keys) != 2 {
|
||||
t.Fatalf("unexpected preview request: %+v", req)
|
||||
}
|
||||
return provision.PreviewReport{
|
||||
AcceptedKeys: []string{"k1", "k2"},
|
||||
Names: provision.ResourceNames{Group: "crm-deepseek-group", Channel: "crm-deepseek-channel", Plan: "crm-deepseek-plan"},
|
||||
Decisions: map[string]provision.PreviewDecision{
|
||||
"group": {Action: provision.PreviewActionCreate},
|
||||
"channel": {Action: provision.PreviewActionReuse, ExistingID: "channel_1"},
|
||||
"plan": {Action: provision.PreviewActionConflict},
|
||||
},
|
||||
}, nil
|
||||
}, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("execute() preview error = %v", err)
|
||||
}
|
||||
if !previewCalled {
|
||||
t.Fatal("execute() did not invoke previewProvider")
|
||||
}
|
||||
got := output.String()
|
||||
if !strings.Contains(got, "accepted_keys=2") || !strings.Contains(got, "group=create") || !strings.Contains(got, "channel=reuse") || !strings.Contains(got, "plan=conflict") {
|
||||
t.Fatalf("execute() preview output = %q, want summary", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteRollbackProviderWritesSummary(t *testing.T) {
|
||||
var output bytes.Buffer
|
||||
rollbackCalled := false
|
||||
|
||||
err := execute(context.Background(), &output, []string{
|
||||
"rollback-provider",
|
||||
"--host-base-url", "https://sub2api.example.com",
|
||||
"--pack-dir", "/tmp/pack",
|
||||
"--provider-id", "deepseek",
|
||||
}, nil, nil, nil, nil, func(_ context.Context, req rollbackCLIRequest) (rollbackSummary, error) {
|
||||
rollbackCalled = true
|
||||
if req.ProviderID != "deepseek" {
|
||||
t.Fatalf("unexpected rollback request: %+v", req)
|
||||
}
|
||||
return rollbackSummary{Accounts: 2, Plans: 1, Channels: 1, Groups: 1}, nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("execute() rollback error = %v", err)
|
||||
}
|
||||
if !rollbackCalled {
|
||||
t.Fatal("execute() did not invoke rollbackProvider")
|
||||
}
|
||||
got := output.String()
|
||||
if !strings.Contains(got, "deleted_accounts=2") || !strings.Contains(got, "deleted_groups=1") {
|
||||
t.Fatalf("execute() rollback output = %q, want summary", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteReconcileProviderWritesSummary(t *testing.T) {
|
||||
var output bytes.Buffer
|
||||
reconcileCalled := false
|
||||
|
||||
err := execute(context.Background(), &output, []string{
|
||||
"reconcile-provider",
|
||||
"--host-base-url", "https://sub2api.example.com",
|
||||
"--pack-dir", "/tmp/pack",
|
||||
"--provider-id", "deepseek",
|
||||
"--access-api-key", "user-key",
|
||||
}, nil, nil, nil, nil, nil, func(_ context.Context, req reconcileCLIRequest) (provision.ReconcileResult, error) {
|
||||
reconcileCalled = true
|
||||
if req.ProviderID != "deepseek" || req.AccessAPIKey != "user-key" {
|
||||
t.Fatalf("unexpected reconcile request: %+v", req)
|
||||
}
|
||||
return provision.ReconcileResult{Status: "drifted", MissingCount: 1, ExtraCount: 2, ProbeFailureCount: 1, AccessStatus: provision.AccessStatusBroken}, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("execute() reconcile error = %v", err)
|
||||
}
|
||||
if !reconcileCalled {
|
||||
t.Fatal("execute() did not invoke reconcileProvider")
|
||||
}
|
||||
got := output.String()
|
||||
if !strings.Contains(got, "status=drifted") || !strings.Contains(got, "missing_count=1") || !strings.Contains(got, "access_status=broken") {
|
||||
t.Fatalf("execute() reconcile output = %q, want summary", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseInstallPackCLIArgsRequiresHostBaseURL(t *testing.T) {
|
||||
_, err := parseInstallPackCLIArgs([]string{"--pack-path", "/tmp/openai-pack.zip"})
|
||||
if err == nil {
|
||||
t.Fatal("parseInstallPackCLIArgs() error = nil, want validation error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseImportCLIArgsRequiresHostBaseURL(t *testing.T) {
|
||||
_, err := parseImportCLIArgs([]string{"--pack-dir", "/tmp/pack", "--provider-id", "deepseek", "--keys", "k1", "--access-api-key", "user-key"})
|
||||
if err == nil {
|
||||
t.Fatal("parseImportCLIArgs() error = nil, want validation error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePreviewCLIArgsRequiresHostBaseURL(t *testing.T) {
|
||||
_, err := parsePreviewCLIArgs([]string{"--pack-dir", "/tmp/pack", "--provider-id", "deepseek", "--keys", "k1"})
|
||||
if err == nil {
|
||||
t.Fatal("parsePreviewCLIArgs() error = nil, want validation error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRollbackCLIArgsRequiresHostBaseURL(t *testing.T) {
|
||||
_, err := parseRollbackCLIArgs([]string{"--pack-dir", "/tmp/pack", "--provider-id", "deepseek"})
|
||||
if err == nil {
|
||||
t.Fatal("parseRollbackCLIArgs() error = nil, want validation error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseReconcileCLIArgsRequiresHostBaseURL(t *testing.T) {
|
||||
_, err := parseReconcileCLIArgs([]string{"--pack-dir", "/tmp/pack", "--provider-id", "deepseek"})
|
||||
if err == nil {
|
||||
t.Fatal("parseReconcileCLIArgs() error = nil, want validation error")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
func TestRunCallsApplicationServerRunnerAfterBootstrap(t *testing.T) {
|
||||
serverApp := app.NewServer("127.0.0.1:0", nil)
|
||||
serverApp := app.NewServer("127.0.0.1:0", nil, nil)
|
||||
bootstrapCalled := false
|
||||
runnerCalled := false
|
||||
|
||||
@@ -60,7 +60,7 @@ func TestRunReturnsBootstrapError(t *testing.T) {
|
||||
|
||||
func TestRunReturnsApplicationRunError(t *testing.T) {
|
||||
wantErr := errors.New("server run failed")
|
||||
serverApp := app.NewServer("127.0.0.1:0", nil)
|
||||
serverApp := app.NewServer("127.0.0.1:0", nil, nil)
|
||||
|
||||
err := run(
|
||||
context.Background(),
|
||||
|
||||
19
docker-compose.yml
Normal file
19
docker-compose.yml
Normal file
@@ -0,0 +1,19 @@
|
||||
services:
|
||||
sub2api-cn-relay-manager:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
image: sub2api-cn-relay-manager:local
|
||||
restart: unless-stopped
|
||||
env_file:
|
||||
- .env
|
||||
ports:
|
||||
- "8080:8080"
|
||||
volumes:
|
||||
- ./data:/data
|
||||
healthcheck:
|
||||
test: ["CMD", "sh", "-c", "wget -qO- http://127.0.0.1:8080/healthz >/dev/null"]
|
||||
interval: 15s
|
||||
timeout: 3s
|
||||
retries: 5
|
||||
start_period: 10s
|
||||
36
docs/DEPLOYMENT.md
Normal file
36
docs/DEPLOYMENT.md
Normal file
@@ -0,0 +1,36 @@
|
||||
# Deployment
|
||||
|
||||
## Environment
|
||||
|
||||
Required:
|
||||
|
||||
- `SUB2API_CRM_ADMIN_TOKEN`: control-plane bearer token
|
||||
|
||||
Optional:
|
||||
|
||||
- `SUB2API_CRM_LISTEN_ADDR` (default `:8080`)
|
||||
- `SUB2API_CRM_SQLITE_DSN` (default `file:sub2api-cn-relay-manager.db?_foreign_keys=on&_busy_timeout=5000`)
|
||||
|
||||
## Local Docker Compose
|
||||
|
||||
```bash
|
||||
cp .env.example .env
|
||||
# edit SUB2API_CRM_ADMIN_TOKEN before startup
|
||||
mkdir -p data
|
||||
docker compose up --build -d
|
||||
curl -fsS http://127.0.0.1:8080/healthz
|
||||
```
|
||||
|
||||
## Standalone Binary
|
||||
|
||||
```bash
|
||||
go build -o bin/sub2api-cn-relay-manager ./cmd/server
|
||||
SUB2API_CRM_ADMIN_TOKEN=replace-me ./bin/sub2api-cn-relay-manager
|
||||
```
|
||||
|
||||
## Runtime Notes
|
||||
|
||||
- SQLite file should be mounted on persistent storage.
|
||||
- Admin token must be rotated outside source control.
|
||||
- The service is stateless except for SQLite runtime state.
|
||||
- Use `/healthz` for container liveness checks.
|
||||
111
docs/EXECUTION_BOARD.md
Normal file
111
docs/EXECUTION_BOARD.md
Normal file
@@ -0,0 +1,111 @@
|
||||
# sub2api-cn-relay-manager 执行板
|
||||
|
||||
日期:2026-05-13
|
||||
当前 Gate:REQUEST_CHANGES
|
||||
目标:实现 implementation plan 全量能力,达成独立控制面、零侵入宿主、一键导入国产模型,并补齐回滚/对账/HTTP API/交付物。
|
||||
|
||||
## 当前真实状态
|
||||
|
||||
模块完成 gate(新增执行要求,后续每个大模块都必须执行):
|
||||
- 仅 `go test` 通过不算完成;每次完成大模块后,必须补做:
|
||||
1. 两阶段 review(先对规划/设计文档做实现对齐检查,再做代码质量 review)
|
||||
2. execution board 当前状态同步
|
||||
3. 若发现实现/设计漂移,优先修正文档结论或回退模块状态,不维持虚假 `COMPLETED`
|
||||
- 本板从本次起按上述 gate 维护。
|
||||
|
||||
已完成:
|
||||
- 项目骨架与配置加载
|
||||
- SQLite 最小状态库(hosts/packs/providers)
|
||||
- SQLite 运行态状态库扩展(import_batches / items / managed_resources / probe_results / access_closure_records / reconcile_runs)
|
||||
- sub2api HostAdapter 基础创建/探测能力
|
||||
- HostAdapter 删除能力(group/channel/account;plan 接口已补)
|
||||
- HostAdapter 资源枚举能力(groups/channels/plans/accounts)
|
||||
- import strict 模式自动回滚已接入
|
||||
- 手动 rollback CLI(`rollback-provider`)已接入,支持按 provider 名称规则回收 group/channel/plan/accounts
|
||||
- pack 目录装载与 checksum/schema 校验
|
||||
- 正式 pack install 生命周期已接入:支持 zip/目录装载、宿主版本兼容校验、pack/provider 元数据持久化、CLI `install-pack`
|
||||
- CLI `import-provider` 导入闭环已接入 SQLite 运行态持久化(host/pack/provider/import/probe/access)
|
||||
- CLI `preview-provider` 预检查入口
|
||||
- 最小 HTTP 控制面已接入:admin token 鉴权 + `/api/packs/install` + `/api/providers/{providerID}/preview-import` + `/api/providers/{providerID}/import` + `/api/import-batches/{batchID}` + `/api/providers/{providerID}/status` + `/api/providers/{providerID}/resources` + `/api/providers/{providerID}/access/status` + `/api/providers/{providerID}/rollback` + `/api/providers/{providerID}/reconcile`
|
||||
- preview 已接入宿主资源快照查询
|
||||
- 账号探测与 `/v1/models` 网关访问验证
|
||||
|
||||
未完成的关键事实:
|
||||
- 状态库已接入 `import-provider` 运行链并可持久化 host/pack/provider/import/probe/access;最小 HTTP 控制面已补齐 batch detail / provider status / resources / access status / rollback / reconcile,OpenAPI 草案已同步扩展
|
||||
- preview/import/rollback/reconcile 已有 CLI 与最小 HTTP 入口,但仍缺少 hosts 管理面与更完整的批次/对账操作文档输出
|
||||
- 宿主资源枚举已实现,但尚未对真实 sub2api 版本做兼容性实测
|
||||
- 最小 reconcile / drift detection 已接入,当前实现仍是 `internal/provision/batch_detail_and_reconcile_service.go` 内联版本,但已补齐对最新 batch 的 account smoke probe 重跑、access closure 复检与 reconcile summary 持久化;状态仍未完全对齐 implementation plan 目标中的 `internal/reconcile/*` 结构,且真实宿主兼容性实测未完成
|
||||
- OpenAPI 草案已覆盖 status/resources/access-status,但仍未收口 hosts 契约与生产级文档细节
|
||||
- 无 scheduler/jobs
|
||||
- 已补齐 Dockerfile / compose / .env.example / deployment 文档,并新增 distribution smoke test;但尚无真实容器启动 E2E 执行记录
|
||||
|
||||
## P0(必须先完成)
|
||||
|
||||
### P0-1 状态库扩展并接入运行链
|
||||
- 状态:COMPLETED(schema/repo、`import-provider` 运行链消费、`batch detail` / `provider status` / `resources` / `access status` / `reconcile` 查询面均已接入)
|
||||
- 目标:补齐 implementation plan 所需核心表与 repo
|
||||
- 范围:`import_batches`、`import_batch_items`、`managed_resources`、`probe_results`、`access_closure_records`、`reconcile_runs`
|
||||
- 验证:`go test ./tests/integration -run 'TestStore(Runtime|Init)' -count=1`
|
||||
- 完成判据:表存在、约束有效、事务回滚有效、repo 可写入读取,并被运行链消费
|
||||
|
||||
### P0-2 import preview + naming
|
||||
- 目标:导入前可输出 create/reuse/conflict,不盲写宿主
|
||||
- 范围:`preview_service.go`、`naming.go`、`import_preview_test.go`
|
||||
- 验证:`go test ./tests/integration -run TestImportPreview -v`
|
||||
|
||||
### P0-3 真实 rollback 闭环
|
||||
- 状态:PARTIAL(strict 自动回滚 + 手动 rollback CLI + HTTP rollback API 已完成;真实宿主兼容性实测未完成)
|
||||
- 目标:strict 失败自动清理,支持手动 rollback
|
||||
- 前置:HostAdapter 增加 DeleteGroup/DeleteChannel/DeletePlan/DeleteAccount/ListManagedResources
|
||||
- 验证:`go test ./internal/provision ./tests/integration ./cmd/cli -run 'TestRollback|TestExecuteRollbackProviderWritesSummary|TestSub2APIHostAdapterListManagedResources' -v`
|
||||
|
||||
### P0-4 正式 pack install 生命周期
|
||||
- 状态:COMPLETED(zip/目录装载、宿主版本兼容性校验、pack/provider 元数据持久化、CLI `install-pack` 已接入)
|
||||
- 目标:支持 zip/目录装载、宿主版本兼容性校验、pack/provider 元数据持久化
|
||||
- 验证:`go test ./internal/pack ./internal/provision ./cmd/cli ./tests/integration -v`
|
||||
|
||||
## P1(形成真正控制面)
|
||||
|
||||
### P1-1 Access 独立模块化
|
||||
- 状态:PARTIAL(访问闭环校验/订阅分配/网关探测已从 `import_service` 抽离到 `internal/access/closure.go`,但 implementation plan 目标结构中的 `planner.go` / `subscription_service.go` / `self_service_checker.go` 仍未落地)
|
||||
- 目标:将访问闭环从 import_service 解耦为 `internal/access/*`
|
||||
- 设计对齐复核:当前已完成的是“最小闭环抽离”,未达到 implementation plan 中 Access 子模块拆分粒度;因此不再维持 `COMPLETED`
|
||||
- 验证:`go test ./internal/access ./internal/provision -count=1`
|
||||
|
||||
### P1-2 Reconcile / Drift Detection
|
||||
- 状态:PARTIAL(最小 reconcile API + drift 计数写入已接入;本轮新增 account smoke probe 重跑、access closure 复检、`active/degraded/drifted` 状态语义与回写验证,但 implementation plan 目标中的 `internal/reconcile/*` 结构、`failed` 语义收口与真实宿主兼容性实测仍未完成)
|
||||
- 目标:拉宿主快照,对比状态库,重跑 probe,标记 drifted
|
||||
- 验证:`go test ./internal/provision ./internal/app ./tests/integration -run 'TestReconcileService|TestAPIReconcileProviderReturnsSummary|TestStore(Runtime|Init)' -count=1`
|
||||
|
||||
### P1-3 HTTP API + OpenAPI
|
||||
- 状态:PARTIAL(`/api/packs/install`、`/api/providers/{providerID}/preview-import`、`/api/providers/{providerID}/import`、`/api/import-batches/{batchID}`、`/api/providers/{providerID}/status`、`/api/providers/{providerID}/resources`、`/api/providers/{providerID}/access/status`、`/api/providers/{providerID}/rollback`、`/api/providers/{providerID}/reconcile` 已接入;OpenAPI 草案已同步扩展,但 hosts 管理面仍缺失)
|
||||
- 目标:暴露 hosts / packs/install / providers preview-import / imports rollback / access / reconcile
|
||||
- 验证:`go test ./internal/app ./cmd/server ./tests/integration -run 'TestAPI|TestBootstrap' -v`
|
||||
|
||||
## P2(工程化交付)
|
||||
|
||||
### P2-1 Scheduler / Jobs
|
||||
- 目标:支持定时 reconcile 与手动触发
|
||||
- 验证:`go test ./tests/integration -run TestCLIScheduler -v`
|
||||
|
||||
### P2-2 Distribution Artifacts
|
||||
- 状态:PARTIAL(已补齐 `Dockerfile` / `.env.example` / `docker-compose.yml` / `docs/DEPLOYMENT.md`,并新增 distribution smoke test;但尚无真实容器启动与镜像构建 E2E 记录)
|
||||
- 目标:Dockerfile / .env.example / docker-compose / deployment 文档 / e2e 脚本
|
||||
- 验证:`go test ./tests/integration -run TestDistributionArtifactsExistAndReferenceRequiredEnv -v`
|
||||
|
||||
### P2-3 CLI 面板补齐
|
||||
- 目标:`host add` / `pack install` / `provider import` / `reconcile run`
|
||||
- 验证:CLI 集成测试 + `go test ./...`
|
||||
|
||||
## 当前执行顺序
|
||||
1. P1-1 Access 模块继续拆分到 implementation plan 粒度
|
||||
2. P1-2 Reconcile 结构化与真实宿主兼容性实测
|
||||
3. P1-3 Hosts 管理面 / OpenAPI 收口
|
||||
4. P2-1 Scheduler / Jobs
|
||||
5. P2-2 Distribution 容器级 E2E 验证
|
||||
6. P2-3 CLI 全量收口
|
||||
|
||||
## 禁止错误结论
|
||||
- `go test ./...` 当前通过 ≠ implementation plan 全部实现
|
||||
- CLI 最小导入闭环 ≠ 独立控制面已完成
|
||||
- 资源创建成功 ≠ 用户访问闭环已长期可运维
|
||||
45
docs/PRD.md
Normal file
45
docs/PRD.md
Normal file
@@ -0,0 +1,45 @@
|
||||
# sub2api-cn-relay-manager PRD(MVP)
|
||||
|
||||
日期:2026-05-13
|
||||
|
||||
## 目标
|
||||
|
||||
在**完全不修改 sub2api 官方系统代码**的前提下,交付一个可独立打包运行的外部伴生项目,使管理员能够通过一次导入动作,把国产模型 OpenAI 兼容中转能力安装到任意一套兼容的 sub2api 实例中。
|
||||
|
||||
## 硬约束
|
||||
|
||||
1. 不修改宿主源码
|
||||
2. 不 fork 宿主并运行私有二进制
|
||||
3. 不直接写宿主数据库
|
||||
4. 不向宿主目录注入插件代码或补丁文件
|
||||
5. 仅通过宿主现有 HTTP 管理 API 与标准 API 工作
|
||||
|
||||
## 首版验收
|
||||
|
||||
1. `model_pack` 可独立校验与装载
|
||||
2. CLI 可直接读取 pack、选择 provider、导入多条 key
|
||||
3. 导入流程能创建 group / channel / plan(subscription 模式)/ accounts
|
||||
4. 至少一个 account 完成 `/test` 与 `/models` 验证
|
||||
5. 至少一种普通用户访问路径被真实探测:`GET /v1/models`
|
||||
6. 失败时明确区分:`succeeded / partially_succeeded / failed`
|
||||
|
||||
## 首版边界
|
||||
|
||||
### 做
|
||||
- pack runtime
|
||||
- provider schema 校验
|
||||
- 多 key 去重与批量导入
|
||||
- subscription/self-service 两种访问模式建模
|
||||
- CLI 一键导入
|
||||
- 基于 stub 的端到端测试
|
||||
|
||||
### 暂不做
|
||||
- Web 控制台
|
||||
- 多宿主管理
|
||||
- 自动代用户签发最终 API key
|
||||
- 对账调度器完整实现
|
||||
- 真实宿主删除/回滚链路
|
||||
|
||||
## 当前实现策略
|
||||
|
||||
首版先把“可独立打包 + 零侵入导入 + 用户访问验证”做成最小闭环;状态库、HTTP 控制面、对账调度在此基础上继续扩展。
|
||||
41
docs/TDD_PLAN.md
Normal file
41
docs/TDD_PLAN.md
Normal file
@@ -0,0 +1,41 @@
|
||||
# TDD 实施计划(MVP)
|
||||
|
||||
日期:2026-05-13
|
||||
|
||||
## 设计结论
|
||||
|
||||
首版采用:
|
||||
|
||||
- `packs/openai-cn-pack/`:真实可校验模型包
|
||||
- `internal/pack`:pack 装载、checksum 校验、provider schema 校验
|
||||
- `internal/provision`:导入编排服务
|
||||
- `internal/host/sub2api`:宿主 admin/gateway 适配
|
||||
- `cmd/cli import-provider`:一键导入入口
|
||||
|
||||
## TDD 顺序
|
||||
|
||||
1. 先写 `internal/pack/loader_test.go`
|
||||
- 成功装载 pack
|
||||
- checksum mismatch 失败
|
||||
- provider schema 非法失败
|
||||
2. 再写 `internal/provision/import_service_test.go`
|
||||
- subscription 模式成功导入
|
||||
- strict 模式探测失败直接失败
|
||||
- 参数非法拒绝
|
||||
3. 再补宿主适配器集成测试
|
||||
- `CheckGatewayAccess()` 能校验 `/v1/models`
|
||||
4. 最后补 CLI 测试
|
||||
- `import-provider` 参数解析
|
||||
- 输出状态摘要
|
||||
|
||||
## 当前 MVP 风险
|
||||
|
||||
1. 回滚删除链路尚未接入真实宿主 HTTP 路径,当前仅在服务层保留失败状态,不宣称真实宿主回滚已闭环
|
||||
2. 现有集成验证基于 `httptest` stub,尚未对真实 sub2api 版本做兼容性实测
|
||||
3. 状态库尚未承接 import batch / managed resources / reconcile runs 持久化
|
||||
|
||||
## 完成标准
|
||||
|
||||
- `go test ./...` 通过
|
||||
- CLI 能从真实 pack 读取 provider
|
||||
- 导入报告明确输出 batch/provider/access 三种状态
|
||||
265
docs/openapi.yaml
Normal file
265
docs/openapi.yaml
Normal file
@@ -0,0 +1,265 @@
|
||||
openapi: 3.1.0
|
||||
info:
|
||||
title: sub2api-cn-relay-manager API
|
||||
version: 0.1.0
|
||||
servers:
|
||||
- url: /
|
||||
paths:
|
||||
/healthz:
|
||||
get:
|
||||
responses:
|
||||
'200':
|
||||
description: ok
|
||||
/api/packs/install:
|
||||
post:
|
||||
security:
|
||||
- bearerAuth: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/InstallPackRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: pack installed
|
||||
/api/import-batches/{batchID}:
|
||||
get:
|
||||
security:
|
||||
- bearerAuth: []
|
||||
parameters:
|
||||
- name: batchID
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: integer
|
||||
format: int64
|
||||
responses:
|
||||
'200':
|
||||
description: batch detail
|
||||
/api/providers/{providerID}/status:
|
||||
get:
|
||||
security:
|
||||
- bearerAuth: []
|
||||
parameters:
|
||||
- name: providerID
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- name: pack_id
|
||||
in: query
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
'200':
|
||||
description: provider runtime status
|
||||
/api/providers/{providerID}/resources:
|
||||
get:
|
||||
security:
|
||||
- bearerAuth: []
|
||||
parameters:
|
||||
- name: providerID
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- name: pack_id
|
||||
in: query
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
'200':
|
||||
description: provider managed resources snapshot
|
||||
/api/providers/{providerID}/access/status:
|
||||
get:
|
||||
security:
|
||||
- bearerAuth: []
|
||||
parameters:
|
||||
- name: providerID
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- name: pack_id
|
||||
in: query
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
'200':
|
||||
description: provider access closure status
|
||||
/api/providers/{providerID}/preview-import:
|
||||
post:
|
||||
security:
|
||||
- bearerAuth: []
|
||||
parameters:
|
||||
- name: providerID
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/PreviewProviderRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: preview summary
|
||||
/api/providers/{providerID}/import:
|
||||
post:
|
||||
security:
|
||||
- bearerAuth: []
|
||||
parameters:
|
||||
- name: providerID
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ImportProviderRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: import summary
|
||||
/api/providers/{providerID}/rollback:
|
||||
post:
|
||||
security:
|
||||
- bearerAuth: []
|
||||
parameters:
|
||||
- name: providerID
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/RollbackProviderRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: rollback summary
|
||||
/api/providers/{providerID}/reconcile:
|
||||
post:
|
||||
security:
|
||||
- bearerAuth: []
|
||||
parameters:
|
||||
- name: providerID
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ReconcileProviderRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: reconcile summary
|
||||
components:
|
||||
securitySchemes:
|
||||
bearerAuth:
|
||||
type: http
|
||||
scheme: bearer
|
||||
schemas:
|
||||
InstallPackRequest:
|
||||
type: object
|
||||
required: [host_base_url, pack_path]
|
||||
properties:
|
||||
host_base_url:
|
||||
type: string
|
||||
host_api_key:
|
||||
type: string
|
||||
host_bearer_token:
|
||||
type: string
|
||||
pack_path:
|
||||
type: string
|
||||
PreviewProviderRequest:
|
||||
type: object
|
||||
required: [host_base_url, pack_path, keys]
|
||||
properties:
|
||||
host_base_url:
|
||||
type: string
|
||||
host_api_key:
|
||||
type: string
|
||||
host_bearer_token:
|
||||
type: string
|
||||
pack_path:
|
||||
type: string
|
||||
provider_id:
|
||||
type: string
|
||||
keys:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
mode:
|
||||
type: string
|
||||
ImportProviderRequest:
|
||||
type: object
|
||||
required: [host_base_url, pack_path, keys, access_api_key]
|
||||
properties:
|
||||
host_base_url:
|
||||
type: string
|
||||
host_api_key:
|
||||
type: string
|
||||
host_bearer_token:
|
||||
type: string
|
||||
pack_path:
|
||||
type: string
|
||||
provider_id:
|
||||
type: string
|
||||
keys:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
mode:
|
||||
type: string
|
||||
access_mode:
|
||||
type: string
|
||||
access_api_key:
|
||||
type: string
|
||||
subscription_users:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
subscription_days:
|
||||
type: integer
|
||||
RollbackProviderRequest:
|
||||
type: object
|
||||
required: [host_base_url, pack_path]
|
||||
properties:
|
||||
host_base_url:
|
||||
type: string
|
||||
host_api_key:
|
||||
type: string
|
||||
host_bearer_token:
|
||||
type: string
|
||||
pack_path:
|
||||
type: string
|
||||
provider_id:
|
||||
type: string
|
||||
ReconcileProviderRequest:
|
||||
type: object
|
||||
required: [host_base_url, pack_path]
|
||||
properties:
|
||||
host_base_url:
|
||||
type: string
|
||||
host_api_key:
|
||||
type: string
|
||||
host_bearer_token:
|
||||
type: string
|
||||
pack_path:
|
||||
type: string
|
||||
provider_id:
|
||||
type: string
|
||||
80
internal/access/closure.go
Normal file
80
internal/access/closure.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package access
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
)
|
||||
|
||||
const (
|
||||
ModeSubscription = "subscription"
|
||||
ModeSelfService = "self_service"
|
||||
)
|
||||
|
||||
type SubscriptionTarget struct {
|
||||
UserID string
|
||||
DurationDays int
|
||||
}
|
||||
|
||||
type ClosureRequest struct {
|
||||
Mode string
|
||||
ProbeAPIKey string
|
||||
Subscriptions []SubscriptionTarget
|
||||
GroupID string
|
||||
ExpectedModel string
|
||||
}
|
||||
|
||||
type Host interface {
|
||||
AssignSubscription(ctx context.Context, req sub2api.AssignSubscriptionRequest) (sub2api.SubscriptionRef, error)
|
||||
CheckGatewayAccess(ctx context.Context, req sub2api.GatewayAccessCheckRequest) (sub2api.GatewayAccessResult, error)
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
host Host
|
||||
}
|
||||
|
||||
func NewService(host Host) *Service {
|
||||
return &Service{host: host}
|
||||
}
|
||||
|
||||
func Validate(req ClosureRequest) error {
|
||||
switch strings.TrimSpace(req.Mode) {
|
||||
case ModeSubscription:
|
||||
if len(req.Subscriptions) == 0 {
|
||||
return fmt.Errorf("subscription access requires at least one subscription target")
|
||||
}
|
||||
case ModeSelfService:
|
||||
if strings.TrimSpace(req.ProbeAPIKey) == "" {
|
||||
return fmt.Errorf("self_service access requires probe api key")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported access mode %q", req.Mode)
|
||||
}
|
||||
if strings.TrimSpace(req.ProbeAPIKey) == "" {
|
||||
return fmt.Errorf("access probe api key is required to verify gateway closure")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) Close(ctx context.Context, req ClosureRequest) (sub2api.GatewayAccessResult, error) {
|
||||
if s == nil || s.host == nil {
|
||||
return sub2api.GatewayAccessResult{}, fmt.Errorf("access host is required")
|
||||
}
|
||||
if err := Validate(req); err != nil {
|
||||
return sub2api.GatewayAccessResult{}, err
|
||||
}
|
||||
if strings.TrimSpace(req.Mode) == ModeSubscription {
|
||||
for _, target := range req.Subscriptions {
|
||||
if _, err := s.host.AssignSubscription(ctx, sub2api.AssignSubscriptionRequest{UserID: target.UserID, GroupID: req.GroupID, DurationDays: target.DurationDays}); err != nil {
|
||||
return sub2api.GatewayAccessResult{}, fmt.Errorf("assign subscription for %s: %w", target.UserID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
result, err := s.host.CheckGatewayAccess(ctx, sub2api.GatewayAccessCheckRequest{APIKey: req.ProbeAPIKey, ExpectedModel: req.ExpectedModel})
|
||||
if err != nil {
|
||||
return sub2api.GatewayAccessResult{}, fmt.Errorf("check gateway access: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
91
internal/access/closure_test.go
Normal file
91
internal/access/closure_test.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package access
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
)
|
||||
|
||||
func TestValidateRejectsMissingProbeAPIKeyForSelfService(t *testing.T) {
|
||||
err := Validate(ClosureRequest{Mode: "self_service"})
|
||||
if err == nil {
|
||||
t.Fatal("Validate() error = nil, want validation error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRejectsMissingSubscriptionsForSubscriptionMode(t *testing.T) {
|
||||
err := Validate(ClosureRequest{Mode: "subscription", ProbeAPIKey: "user-key"})
|
||||
if err == nil {
|
||||
t.Fatal("Validate() error = nil, want validation error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceCloseAssignsSubscriptionsAndProbesGateway(t *testing.T) {
|
||||
host := &fakeClosureHost{
|
||||
gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}},
|
||||
}
|
||||
service := NewService(host)
|
||||
result, err := service.Close(context.Background(), ClosureRequest{
|
||||
Mode: "subscription",
|
||||
ProbeAPIKey: "user-key",
|
||||
GroupID: "group-1",
|
||||
ExpectedModel: "deepseek-chat",
|
||||
Subscriptions: []SubscriptionTarget{{UserID: "user-1", DurationDays: 30}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Close() error = %v", err)
|
||||
}
|
||||
if len(host.assigned) != 1 {
|
||||
t.Fatalf("assigned subscriptions = %d, want 1", len(host.assigned))
|
||||
}
|
||||
if host.gatewayProbe.APIKey != "user-key" || host.gatewayProbe.ExpectedModel != "deepseek-chat" {
|
||||
t.Fatalf("gateway probe = %+v, want api key + expected model", host.gatewayProbe)
|
||||
}
|
||||
if !result.OK || !result.HasExpectedModel {
|
||||
t.Fatalf("gateway result = %+v, want success", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceCloseReturnsSubscriptionErrorBeforeGatewayProbe(t *testing.T) {
|
||||
host := &fakeClosureHost{assignErr: errors.New("assign failed")}
|
||||
service := NewService(host)
|
||||
_, err := service.Close(context.Background(), ClosureRequest{
|
||||
Mode: "subscription",
|
||||
ProbeAPIKey: "user-key",
|
||||
GroupID: "group-1",
|
||||
ExpectedModel: "deepseek-chat",
|
||||
Subscriptions: []SubscriptionTarget{{UserID: "user-1", DurationDays: 30}},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Close() error = nil, want subscription failure")
|
||||
}
|
||||
if host.gatewayProbe.APIKey != "" {
|
||||
t.Fatalf("gateway probe should not run after subscription error, got %+v", host.gatewayProbe)
|
||||
}
|
||||
}
|
||||
|
||||
type fakeClosureHost struct {
|
||||
assigned []sub2api.AssignSubscriptionRequest
|
||||
assignErr error
|
||||
gatewayProbe sub2api.GatewayAccessCheckRequest
|
||||
gatewayResult sub2api.GatewayAccessResult
|
||||
gatewayErr error
|
||||
}
|
||||
|
||||
func (f *fakeClosureHost) AssignSubscription(_ context.Context, req sub2api.AssignSubscriptionRequest) (sub2api.SubscriptionRef, error) {
|
||||
if f.assignErr != nil {
|
||||
return sub2api.SubscriptionRef{}, f.assignErr
|
||||
}
|
||||
f.assigned = append(f.assigned, req)
|
||||
return sub2api.SubscriptionRef{ID: "sub-1"}, nil
|
||||
}
|
||||
|
||||
func (f *fakeClosureHost) CheckGatewayAccess(_ context.Context, req sub2api.GatewayAccessCheckRequest) (sub2api.GatewayAccessResult, error) {
|
||||
f.gatewayProbe = req
|
||||
if f.gatewayErr != nil {
|
||||
return sub2api.GatewayAccessResult{}, f.gatewayErr
|
||||
}
|
||||
return f.gatewayResult, nil
|
||||
}
|
||||
@@ -15,25 +15,20 @@ type Server struct {
|
||||
listen ListenerFactory
|
||||
}
|
||||
|
||||
func NewServer(listenAddr string, listenerFactory ListenerFactory) *Server {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
})
|
||||
|
||||
func NewServer(listenAddr string, handler http.Handler, listenerFactory ListenerFactory) *Server {
|
||||
if handler == nil {
|
||||
handler = NewAPIHandler("", ActionSet{})
|
||||
}
|
||||
server := &Server{
|
||||
server: &http.Server{
|
||||
Addr: listenAddr,
|
||||
Handler: mux,
|
||||
Handler: handler,
|
||||
},
|
||||
listen: net.Listen,
|
||||
}
|
||||
|
||||
if listenerFactory != nil {
|
||||
server.listen = listenerFactory
|
||||
}
|
||||
|
||||
return server
|
||||
}
|
||||
|
||||
@@ -46,13 +41,11 @@ func (s *Server) Run(ctx context.Context) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.Serve(ctx, listener)
|
||||
}
|
||||
|
||||
func (s *Server) Serve(ctx context.Context, listener net.Listener) error {
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
err := s.server.Serve(listener)
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
@@ -65,11 +58,9 @@ func (s *Server) Serve(ctx context.Context, listener net.Listener) error {
|
||||
case <-ctx.Done():
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.server.Shutdown(shutdownCtx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return <-errCh
|
||||
case err := <-errCh:
|
||||
return err
|
||||
|
||||
@@ -1,17 +1,26 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
"sub2api-cn-relay-manager/internal/pack"
|
||||
"sub2api-cn-relay-manager/internal/provision"
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
func TestServeExposesHealthz(t *testing.T) {
|
||||
server := NewServer("127.0.0.1:0", nil)
|
||||
server := NewServer("127.0.0.1:0", NewAPIHandler("admin-token", ActionSet{}), nil)
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("net.Listen() error = %v", err)
|
||||
@@ -50,7 +59,7 @@ func TestRunReturnsAfterContextCancellation(t *testing.T) {
|
||||
t.Fatalf("net.Listen() error = %v", err)
|
||||
}
|
||||
|
||||
server := NewServer("127.0.0.1:0", func(string, string) (net.Listener, error) {
|
||||
server := NewServer("127.0.0.1:0", NewAPIHandler("admin-token", ActionSet{}), func(string, string) (net.Listener, error) {
|
||||
return listener, nil
|
||||
})
|
||||
|
||||
@@ -77,7 +86,7 @@ func TestRunReturnsAfterContextCancellation(t *testing.T) {
|
||||
|
||||
func TestRunReturnsListenError(t *testing.T) {
|
||||
wantErr := errors.New("listen failed")
|
||||
server := NewServer("127.0.0.1:0", func(string, string) (net.Listener, error) {
|
||||
server := NewServer("127.0.0.1:0", NewAPIHandler("admin-token", ActionSet{}), func(string, string) (net.Listener, error) {
|
||||
return nil, wantErr
|
||||
})
|
||||
|
||||
@@ -88,7 +97,7 @@ func TestRunReturnsListenError(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServeReturnsListenerError(t *testing.T) {
|
||||
server := NewServer("127.0.0.1:0", nil)
|
||||
server := NewServer("127.0.0.1:0", NewAPIHandler("admin-token", ActionSet{}), nil)
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("net.Listen() error = %v", err)
|
||||
@@ -104,6 +113,208 @@ func TestServeReturnsListenerError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIRejectsMissingAdminToken(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{})
|
||||
request := httptestRequest(t, http.MethodPost, "/api/packs/install", map[string]any{"host_base_url": "https://sub2api.example.com", "pack_path": "/tmp/pack.zip"}, "")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusUnauthorized)
|
||||
assertJSONContains(t, response.Body().Bytes(), "error.code", "unauthorized")
|
||||
}
|
||||
|
||||
func TestAPIInstallPackReturnsSummary(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{
|
||||
InstallPack: func(context.Context, InstallPackRequest) (provision.PackInstallResult, error) {
|
||||
return provision.PackInstallResult{
|
||||
Pack: sqlite.Pack{PackID: "openai-cn-pack", Version: "1.0.0"},
|
||||
HostVersion: "0.1.126",
|
||||
Providers: []sqlite.Provider{{ProviderID: "deepseek", DisplayName: "DeepSeek"}},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
request := httptestRequest(t, http.MethodPost, "/api/packs/install", map[string]any{"host_base_url": "https://sub2api.example.com", "pack_path": "/tmp/openai-pack.zip"}, "secret-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusOK)
|
||||
assertJSONContains(t, response.Body().Bytes(), "pack_id", "openai-cn-pack")
|
||||
assertJSONContains(t, response.Body().Bytes(), "host_version", "0.1.126")
|
||||
}
|
||||
|
||||
func TestAPIPreviewProviderReturnsSummary(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{
|
||||
PreviewProvider: func(_ context.Context, req PreviewProviderRequest) (provision.PreviewReport, error) {
|
||||
if req.ProviderID != "deepseek" {
|
||||
t.Fatalf("ProviderID = %q, want deepseek", req.ProviderID)
|
||||
}
|
||||
return provision.PreviewReport{
|
||||
AcceptedKeys: []string{"k1", "k2"},
|
||||
Names: provision.ResourceNames{Group: "g", Channel: "c", Plan: "p"},
|
||||
Decisions: map[string]provision.PreviewDecision{
|
||||
"group": {Action: provision.PreviewActionCreate},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
request := httptestRequest(t, http.MethodPost, "/api/providers/deepseek/preview-import", map[string]any{"host_base_url": "https://sub2api.example.com", "pack_path": "/tmp/openai-pack.zip", "keys": []string{"k1", "k2"}, "mode": "partial"}, "secret-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusOK)
|
||||
assertJSONContains(t, response.Body().Bytes(), "accepted_keys_count", float64(2))
|
||||
}
|
||||
|
||||
func TestAPIImportProviderReturnsConflictWithBatchStatus(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{
|
||||
ImportProvider: func(context.Context, ImportProviderRequest) (provision.RuntimeImportResult, error) {
|
||||
return provision.RuntimeImportResult{
|
||||
BatchID: 12,
|
||||
Report: provision.ImportReport{
|
||||
BatchStatus: provision.BatchStatusFailed,
|
||||
ProviderStatus: provision.ProviderStatusFailed,
|
||||
AccessStatus: provision.AccessStatusBroken,
|
||||
Accounts: []provision.AccountImportResult{{}},
|
||||
},
|
||||
}, errors.New("strict import failed")
|
||||
},
|
||||
})
|
||||
request := httptestRequest(t, http.MethodPost, "/api/providers/deepseek/import", map[string]any{"host_base_url": "https://sub2api.example.com", "pack_path": "/tmp/openai-pack.zip", "keys": []string{"k1"}, "mode": "strict", "access_mode": "self_service", "access_api_key": "user-key"}, "secret-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusConflict)
|
||||
assertJSONContains(t, response.Body().Bytes(), "batch_id", float64(12))
|
||||
assertJSONContains(t, response.Body().Bytes(), "batch_status", provision.BatchStatusFailed)
|
||||
}
|
||||
|
||||
func TestAPIBatchDetailReturnsSummary(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{
|
||||
BatchDetail: func(context.Context, BatchDetailRequest) (provision.BatchDetailResult, error) {
|
||||
return provision.BatchDetailResult{
|
||||
Batch: sqlite.ImportBatch{ID: 7, BatchStatus: "running", AccessStatus: "pending"},
|
||||
Items: []sqlite.ImportBatchItem{{ID: 1, KeyFingerprint: "sha256:abc", AccountStatus: "passed"}},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
request := httptestRequest(t, http.MethodGet, "/api/import-batches/7", nil, "secret-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusOK)
|
||||
assertJSONContains(t, response.Body().Bytes(), "batch.batch_status", "running")
|
||||
assertJSONContains(t, response.Body().Bytes(), "items_count", float64(1))
|
||||
}
|
||||
|
||||
func TestAPIProviderStatusReturnsSummary(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{
|
||||
GetProviderStatus: func(_ context.Context, req ProviderQueryRequest) (provision.ProviderSnapshot, error) {
|
||||
if req.ProviderID != "deepseek" {
|
||||
t.Fatalf("ProviderID = %q, want deepseek", req.ProviderID)
|
||||
}
|
||||
if req.PackID != "openai-cn-pack" {
|
||||
t.Fatalf("PackID = %q, want openai-cn-pack", req.PackID)
|
||||
}
|
||||
return provision.ProviderSnapshot{
|
||||
Host: sqlite.Host{HostID: "host-1", BaseURL: "https://sub2api.example.com", HostVersion: "0.1.126"},
|
||||
Pack: sqlite.Pack{PackID: "openai-cn-pack", Version: "1.0.0"},
|
||||
Provider: sqlite.Provider{ProviderID: "deepseek", DisplayName: "DeepSeek", Platform: "openai"},
|
||||
Batch: sqlite.ImportBatch{ID: 7, BatchStatus: provision.BatchStatusSucceeded, AccessStatus: provision.AccessStatusSelfServiceReady, Mode: provision.ImportModeStrict},
|
||||
ProviderStatus: "drifted",
|
||||
LatestAccessStatus: provision.AccessStatusSelfServiceReady,
|
||||
LatestReconcileStatus: "drifted",
|
||||
LatestReconcileSummary: map[string]any{"missing_count": 1},
|
||||
ManagedResources: []sqlite.ManagedResource{{}, {}},
|
||||
AccessClosures: []sqlite.AccessClosureRecord{{}},
|
||||
ReconcileRuns: []sqlite.ReconcileRun{{}},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
request := httptestRequest(t, http.MethodGet, "/api/providers/deepseek/status?pack_id=openai-cn-pack", nil, "secret-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusOK)
|
||||
assertJSONContains(t, response.Body().Bytes(), "provider_status", "drifted")
|
||||
assertJSONContains(t, response.Body().Bytes(), "managed_resources_count", float64(2))
|
||||
assertJSONContains(t, response.Body().Bytes(), "latest_reconcile_summary.missing_count", float64(1))
|
||||
}
|
||||
|
||||
func TestAPIProviderAccessStatusReturnsSummary(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{
|
||||
GetProviderAccessStatus: func(_ context.Context, req ProviderQueryRequest) (provision.ProviderSnapshot, error) {
|
||||
if req.ProviderID != "deepseek" {
|
||||
t.Fatalf("ProviderID = %q, want deepseek", req.ProviderID)
|
||||
}
|
||||
return provision.ProviderSnapshot{
|
||||
Pack: sqlite.Pack{PackID: "openai-cn-pack"},
|
||||
Provider: sqlite.Provider{ProviderID: "deepseek"},
|
||||
Batch: sqlite.ImportBatch{ID: 7, AccessStatus: provision.AccessStatusSelfServiceReady},
|
||||
LatestAccessStatus: provision.AccessStatusSelfServiceReady,
|
||||
AccessClosures: []sqlite.AccessClosureRecord{{ID: 2, ClosureType: provision.AccessModeSelfService, Status: provision.AccessStatusSelfServiceReady, DetailsJSON: `{"ok":true}`}},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
request := httptestRequest(t, http.MethodGet, "/api/providers/deepseek/access/status", nil, "secret-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusOK)
|
||||
assertJSONContains(t, response.Body().Bytes(), "latest_access_status", provision.AccessStatusSelfServiceReady)
|
||||
assertJSONContains(t, response.Body().Bytes(), "closures_count", float64(1))
|
||||
if !strings.Contains(response.Body().String(), `"closure_type":"self_service"`) {
|
||||
t.Fatalf("access status payload missing closure type: %s", response.Body().String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIProviderResourcesReturnsSummary(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{
|
||||
GetProviderResources: func(_ context.Context, req ProviderQueryRequest) (provision.ProviderSnapshot, error) {
|
||||
if req.ProviderID != "deepseek" {
|
||||
t.Fatalf("ProviderID = %q, want deepseek", req.ProviderID)
|
||||
}
|
||||
return provision.ProviderSnapshot{
|
||||
Pack: sqlite.Pack{PackID: "openai-cn-pack"},
|
||||
Provider: sqlite.Provider{ProviderID: "deepseek"},
|
||||
Batch: sqlite.ImportBatch{ID: 7},
|
||||
ManagedResources: []sqlite.ManagedResource{{ID: 1, ResourceType: "group", HostResourceID: "group-1", ResourceName: "deepseek-group"}},
|
||||
AccessClosures: []sqlite.AccessClosureRecord{{ID: 2, ClosureType: provision.AccessModeSelfService, Status: provision.AccessStatusSelfServiceReady, DetailsJSON: `{"ok":true}`}},
|
||||
ReconcileRuns: []sqlite.ReconcileRun{{ID: 3, Status: "active", SummaryJSON: `{"missing_count":0}`}},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
request := httptestRequest(t, http.MethodGet, "/api/providers/deepseek/resources", nil, "secret-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusOK)
|
||||
assertJSONContains(t, response.Body().Bytes(), "provider_id", "deepseek")
|
||||
assertJSONContains(t, response.Body().Bytes(), "pack_id", "openai-cn-pack")
|
||||
if !strings.Contains(response.Body().String(), `"resource_type":"group"`) {
|
||||
t.Fatalf("resources payload missing group resource: %s", response.Body().String())
|
||||
}
|
||||
if !strings.Contains(response.Body().String(), `"status":"self_service_ready"`) {
|
||||
t.Fatalf("resources payload missing access closure status: %s", response.Body().String())
|
||||
}
|
||||
if !strings.Contains(response.Body().String(), `"summary_json":"{\"missing_count\":0}"`) {
|
||||
t.Fatalf("resources payload missing reconcile summary: %s", response.Body().String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIRollbackProviderReturnsSummary(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{
|
||||
RollbackProvider: func(context.Context, RollbackProviderRequest) (provision.RollbackReport, error) {
|
||||
return provision.RollbackReport{AccountsDeleted: 2, PlansDeleted: 1, ChannelsDeleted: 1, GroupsDeleted: 1}, nil
|
||||
},
|
||||
})
|
||||
request := httptestRequest(t, http.MethodPost, "/api/providers/deepseek/rollback", map[string]any{"host_base_url": "https://sub2api.example.com", "pack_path": "/tmp/openai-pack.zip"}, "secret-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusOK)
|
||||
assertJSONContains(t, response.Body().Bytes(), "deleted_accounts", float64(2))
|
||||
assertJSONContains(t, response.Body().Bytes(), "provider_id", "deepseek")
|
||||
}
|
||||
|
||||
func TestAPIReconcileProviderReturnsSummary(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{
|
||||
ReconcileProvider: func(_ context.Context, req ReconcileProviderRequest) (provision.ReconcileResult, error) {
|
||||
if req.AccessAPIKey != "user-key" {
|
||||
t.Fatalf("AccessAPIKey = %q, want user-key", req.AccessAPIKey)
|
||||
}
|
||||
return provision.ReconcileResult{BatchID: 7, Status: "drifted", MissingCount: 1, ExtraCount: 2, ProbeFailureCount: 1, AccessStatus: provision.AccessStatusBroken, Summary: map[string]any{"probe_failures": 1}}, nil
|
||||
},
|
||||
})
|
||||
request := httptestRequest(t, http.MethodPost, "/api/providers/deepseek/reconcile", map[string]any{"host_base_url": "https://sub2api.example.com", "pack_path": "/tmp/openai-pack.zip", "access_api_key": "user-key"}, "secret-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusOK)
|
||||
assertJSONContains(t, response.Body().Bytes(), "status", "drifted")
|
||||
assertJSONContains(t, response.Body().Bytes(), "missing_count", float64(1))
|
||||
assertJSONContains(t, response.Body().Bytes(), "summary.probe_failures", float64(1))
|
||||
}
|
||||
|
||||
func waitForHealthz(t *testing.T, url string) *http.Response {
|
||||
t.Helper()
|
||||
|
||||
@@ -126,3 +337,613 @@ func waitForHealthz(t *testing.T, url string) *http.Response {
|
||||
t.Fatalf("health endpoint %q was not reachable before deadline", url)
|
||||
return nil
|
||||
}
|
||||
|
||||
func httptestRequest(t *testing.T, method, path string, body any, token string) *http.Request {
|
||||
t.Helper()
|
||||
payload, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal() error = %v", err)
|
||||
}
|
||||
request, err := http.NewRequest(method, path, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
t.Fatalf("http.NewRequest() error = %v", err)
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
if token != "" {
|
||||
request.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
return request
|
||||
}
|
||||
|
||||
func httptestRecorder(handler http.Handler, request *http.Request) *responseRecorder {
|
||||
recorder := &responseRecorder{header: make(http.Header)}
|
||||
handler.ServeHTTP(recorder, request)
|
||||
return recorder
|
||||
}
|
||||
|
||||
type responseRecorder struct {
|
||||
header http.Header
|
||||
body bytes.Buffer
|
||||
code int
|
||||
}
|
||||
|
||||
func (r *responseRecorder) Header() http.Header { return r.header }
|
||||
func (r *responseRecorder) Write(body []byte) (int, error) { return r.body.Write(body) }
|
||||
func (r *responseRecorder) WriteHeader(statusCode int) { r.code = statusCode }
|
||||
func (r *responseRecorder) Body() *bytes.Buffer { return &r.body }
|
||||
|
||||
func assertStatusCode(t *testing.T, recorder *responseRecorder, want int) {
|
||||
t.Helper()
|
||||
if recorder.code != want {
|
||||
t.Fatalf("status code = %d, want %d; body=%s", recorder.code, want, recorder.body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerAddrReturnsConfiguredAddress(t *testing.T) {
|
||||
server := NewServer("127.0.0.1:9999", nil, nil)
|
||||
if got := server.Addr(); got != "127.0.0.1:9999" {
|
||||
t.Fatalf("Addr() = %q, want %q", got, "127.0.0.1:9999")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantStatusCode int
|
||||
wantCode string
|
||||
wantUpstream int
|
||||
}{
|
||||
{name: "nil", err: nil},
|
||||
{name: "http error passthrough", err: &httpError{StatusCode: http.StatusTeapot, Code: "teapot", Message: "brew"}, wantStatusCode: http.StatusTeapot, wantCode: "teapot"},
|
||||
{name: "upstream error", err: &sub2api.HTTPError{Method: http.MethodGet, Path: "/x", StatusCode: http.StatusForbidden, Body: "nope"}, wantStatusCode: http.StatusBadGateway, wantCode: "host_request_failed", wantUpstream: http.StatusForbidden},
|
||||
{name: "pack conflict already installed", err: errors.New("pack already installed"), wantStatusCode: http.StatusConflict, wantCode: "pack_conflict"},
|
||||
{name: "pack conflict checksum drift", err: errors.New("checksum drift detected"), wantStatusCode: http.StatusConflict, wantCode: "pack_conflict"},
|
||||
{name: "provider not found", err: errors.New("provider \"deepseek\" not found in pack \"openai\""), wantStatusCode: http.StatusBadRequest, wantCode: "provider_not_found"},
|
||||
{name: "bad request pack path", err: errors.New("pack path is required"), wantStatusCode: http.StatusBadRequest, wantCode: "bad_request"},
|
||||
{name: "bad request decode", err: errors.New("decode pack.json failed"), wantStatusCode: http.StatusBadRequest, wantCode: "bad_request"},
|
||||
{name: "internal error", err: errors.New("boom"), wantStatusCode: http.StatusInternalServerError, wantCode: "internal_error"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := classifyError(tt.err)
|
||||
if tt.err == nil {
|
||||
if got != nil {
|
||||
t.Fatalf("classifyError(nil) = %#v, want nil", got)
|
||||
}
|
||||
return
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatal("classifyError() = nil, want error")
|
||||
}
|
||||
if got.StatusCode != tt.wantStatusCode {
|
||||
t.Fatalf("StatusCode = %d, want %d", got.StatusCode, tt.wantStatusCode)
|
||||
}
|
||||
if got.Code != tt.wantCode {
|
||||
t.Fatalf("Code = %q, want %q", got.Code, tt.wantCode)
|
||||
}
|
||||
if got.UpstreamStatus != tt.wantUpstream {
|
||||
t.Fatalf("UpstreamStatus = %d, want %d", got.UpstreamStatus, tt.wantUpstream)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteHTTPError(t *testing.T) {
|
||||
t.Run("default error when nil", func(t *testing.T) {
|
||||
recorder := &responseRecorder{header: make(http.Header)}
|
||||
writeHTTPError(recorder, nil)
|
||||
assertStatusCode(t, recorder, http.StatusInternalServerError)
|
||||
if got := recorder.Header().Get("Content-Type"); got != "application/json" {
|
||||
t.Fatalf("Content-Type = %q, want application/json", got)
|
||||
}
|
||||
assertJSONContains(t, recorder.Body().Bytes(), "error.code", "internal_error")
|
||||
assertJSONContains(t, recorder.Body().Bytes(), "error.message", "internal server error")
|
||||
})
|
||||
|
||||
t.Run("writes provided error", func(t *testing.T) {
|
||||
recorder := &responseRecorder{header: make(http.Header)}
|
||||
writeHTTPError(recorder, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "invalid input", UpstreamStatus: http.StatusConflict})
|
||||
assertStatusCode(t, recorder, http.StatusBadRequest)
|
||||
assertJSONContains(t, recorder.Body().Bytes(), "error.code", "bad_request")
|
||||
assertJSONContains(t, recorder.Body().Bytes(), "error.upstream_status", float64(http.StatusConflict))
|
||||
})
|
||||
}
|
||||
|
||||
func TestDecodeJSON(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
request := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{"host_base_url":"https://example.com","pack_path":"/tmp/pack.zip"}`))
|
||||
var got InstallPackRequest
|
||||
if err := decodeJSON(request, &got); err != nil {
|
||||
t.Fatalf("decodeJSON() error = %v, want nil", err)
|
||||
}
|
||||
if got.HostBaseURL != "https://example.com" || got.PackPath != "/tmp/pack.zip" {
|
||||
t.Fatalf("decoded request = %#v, want expected fields", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects unknown fields", func(t *testing.T) {
|
||||
request := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{"host_base_url":"https://example.com","unknown":true}`))
|
||||
var got InstallPackRequest
|
||||
err := decodeJSON(request, &got)
|
||||
if err == nil {
|
||||
t.Fatal("decodeJSON() error = nil, want error")
|
||||
}
|
||||
if err.StatusCode != http.StatusBadRequest || err.Code != "bad_request" {
|
||||
t.Fatalf("decodeJSON() = %#v, want bad_request", err)
|
||||
}
|
||||
if !strings.Contains(err.Message, "unknown field") {
|
||||
t.Fatalf("Message = %q, want unknown field", err.Message)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects trailing non-object payload", func(t *testing.T) {
|
||||
request := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{"host_base_url":"https://example.com"}[]`))
|
||||
var got InstallPackRequest
|
||||
err := decodeJSON(request, &got)
|
||||
if err == nil {
|
||||
t.Fatal("decodeJSON() error = nil, want error")
|
||||
}
|
||||
if err.Message != "request body must contain a single JSON object" {
|
||||
t.Fatalf("Message = %q, want single object error", err.Message)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWriteJSON(t *testing.T) {
|
||||
recorder := &responseRecorder{header: make(http.Header)}
|
||||
writeJSON(recorder, http.StatusCreated, map[string]any{"ok": true, "count": 2})
|
||||
assertStatusCode(t, recorder, http.StatusCreated)
|
||||
if got := recorder.Header().Get("Content-Type"); got != "application/json" {
|
||||
t.Fatalf("Content-Type = %q, want application/json", got)
|
||||
}
|
||||
assertJSONContains(t, recorder.Body().Bytes(), "ok", true)
|
||||
assertJSONContains(t, recorder.Body().Bytes(), "count", float64(2))
|
||||
}
|
||||
|
||||
func TestFindProvider(t *testing.T) {
|
||||
loaded := pack.LoadedPack{
|
||||
Manifest: pack.Manifest{PackID: "openai-cn-pack"},
|
||||
Providers: []pack.ProviderManifest{
|
||||
{ProviderID: "deepseek", DisplayName: "DeepSeek"},
|
||||
{ProviderID: "openai", DisplayName: "OpenAI"},
|
||||
},
|
||||
}
|
||||
|
||||
provider, err := findProvider(loaded, " deepseek ")
|
||||
if err != nil {
|
||||
t.Fatalf("findProvider() error = %v, want nil", err)
|
||||
}
|
||||
if provider.ProviderID != "deepseek" {
|
||||
t.Fatalf("ProviderID = %q, want deepseek", provider.ProviderID)
|
||||
}
|
||||
|
||||
_, err = findProvider(loaded, "missing")
|
||||
if err == nil {
|
||||
t.Fatal("findProvider() error = nil, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), `provider "missing" not found in pack "openai-cn-pack"`) {
|
||||
t.Fatalf("findProvider() error = %v, want provider not found message", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIRequiresConfiguredAdminToken(t *testing.T) {
|
||||
handler := NewAPIHandler("", ActionSet{})
|
||||
request := httptestRequest(t, http.MethodPost, "/api/packs/install", map[string]any{"host_base_url": "https://sub2api.example.com"}, "any-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusInternalServerError)
|
||||
assertJSONContains(t, response.Body().Bytes(), "error.code", "server_misconfigured")
|
||||
}
|
||||
|
||||
func TestAPIBatchDetailRejectsInvalidBatchID(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{BatchDetail: func(context.Context, BatchDetailRequest) (provision.BatchDetailResult, error) {
|
||||
t.Fatal("BatchDetail should not be called for invalid batch id")
|
||||
return provision.BatchDetailResult{}, nil
|
||||
}})
|
||||
request := httptestRequest(t, http.MethodGet, "/api/import-batches/not-a-number", nil, "secret-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusBadRequest)
|
||||
assertJSONContains(t, response.Body().Bytes(), "error.code", "bad_request")
|
||||
assertJSONContains(t, response.Body().Bytes(), "error.message", "batch_id must be a positive integer")
|
||||
}
|
||||
|
||||
func TestAPIInstallPackRejectsInvalidJSON(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{InstallPack: func(context.Context, InstallPackRequest) (provision.PackInstallResult, error) {
|
||||
t.Fatal("InstallPack should not be called for invalid JSON")
|
||||
return provision.PackInstallResult{}, nil
|
||||
}})
|
||||
request, err := http.NewRequest(http.MethodPost, "/api/packs/install", strings.NewReader(`{"host_base_url":"https://sub2api.example.com","unknown":true}`))
|
||||
if err != nil {
|
||||
t.Fatalf("http.NewRequest() error = %v", err)
|
||||
}
|
||||
request.Header.Set("Authorization", "Bearer secret-token")
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusBadRequest)
|
||||
assertJSONContains(t, response.Body().Bytes(), "error.code", "bad_request")
|
||||
}
|
||||
|
||||
func TestAPIImportProviderReturnsClassifiedErrorWithoutBatch(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{
|
||||
ImportProvider: func(context.Context, ImportProviderRequest) (provision.RuntimeImportResult, error) {
|
||||
return provision.RuntimeImportResult{}, errors.New("pack path is required")
|
||||
},
|
||||
})
|
||||
request := httptestRequest(t, http.MethodPost, "/api/providers/deepseek/import", map[string]any{"host_base_url": "https://sub2api.example.com", "pack_path": "/tmp/openai-pack.zip"}, "secret-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusBadRequest)
|
||||
assertJSONContains(t, response.Body().Bytes(), "error.code", "bad_request")
|
||||
assertJSONContains(t, response.Body().Bytes(), "batch_id", float64(0))
|
||||
}
|
||||
|
||||
func TestAPIPreviewProviderReturnsUpstreamError(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{
|
||||
PreviewProvider: func(context.Context, PreviewProviderRequest) (provision.PreviewReport, error) {
|
||||
return provision.PreviewReport{}, &sub2api.HTTPError{Method: http.MethodPost, Path: "/preview", StatusCode: http.StatusTooManyRequests, Body: "rate limited"}
|
||||
},
|
||||
})
|
||||
request := httptestRequest(t, http.MethodPost, "/api/providers/deepseek/preview-import", map[string]any{"host_base_url": "https://sub2api.example.com", "pack_path": "/tmp/openai-pack.zip"}, "secret-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusBadGateway)
|
||||
assertJSONContains(t, response.Body().Bytes(), "error.code", "host_request_failed")
|
||||
assertJSONContains(t, response.Body().Bytes(), "error.upstream_status", float64(http.StatusTooManyRequests))
|
||||
}
|
||||
|
||||
func TestAPIRollbackProviderReturnsConfiguredError(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{
|
||||
RollbackProvider: func(context.Context, RollbackProviderRequest) (provision.RollbackReport, error) {
|
||||
return provision.RollbackReport{}, &httpError{StatusCode: http.StatusGone, Code: "rolled_back", Message: "already removed"}
|
||||
},
|
||||
})
|
||||
request := httptestRequest(t, http.MethodPost, "/api/providers/deepseek/rollback", map[string]any{"host_base_url": "https://sub2api.example.com", "pack_path": "/tmp/openai-pack.zip"}, "secret-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusGone)
|
||||
assertJSONContains(t, response.Body().Bytes(), "error.code", "rolled_back")
|
||||
}
|
||||
|
||||
func TestAPIReconcileProviderRejectsTrailingNonObjectPayload(t *testing.T) {
|
||||
handler := NewAPIHandler("secret-token", ActionSet{ReconcileProvider: func(context.Context, ReconcileProviderRequest) (provision.ReconcileResult, error) {
|
||||
t.Fatal("ReconcileProvider should not be called for invalid JSON")
|
||||
return provision.ReconcileResult{}, nil
|
||||
}})
|
||||
request, err := http.NewRequest(http.MethodPost, "/api/providers/deepseek/reconcile", strings.NewReader(`{"host_base_url":"https://sub2api.example.com"}[]`))
|
||||
if err != nil {
|
||||
t.Fatalf("http.NewRequest() error = %v", err)
|
||||
}
|
||||
request.Header.Set("Authorization", "Bearer secret-token")
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusBadRequest)
|
||||
assertJSONContains(t, response.Body().Bytes(), "error.message", "request body must contain a single JSON object")
|
||||
}
|
||||
|
||||
// --- Coverage edge cases ---
|
||||
|
||||
func TestHTTPErrorError(t *testing.T) {
|
||||
e := &httpError{StatusCode: http.StatusTeapot, Code: "teapot", Message: "i'm a teapot"}
|
||||
if got := e.Error(); got != "i'm a teapot" {
|
||||
t.Fatalf("httpError.Error() = %q, want %q", got, "i'm a teapot")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderStatusFnNil(t *testing.T) {
|
||||
handler := NewAPIHandler("t", ActionSet{})
|
||||
req := httptestRequest(t, http.MethodGet, "/api/providers/x/status", nil, "t")
|
||||
res := httptestRecorder(handler, req)
|
||||
assertStatusCode(t, res, http.StatusInternalServerError)
|
||||
assertJSONContains(t, res.Body().Bytes(), "error.code", "server_misconfigured")
|
||||
}
|
||||
|
||||
func TestProviderAccessStatusFnNil(t *testing.T) {
|
||||
handler := NewAPIHandler("t", ActionSet{})
|
||||
req := httptestRequest(t, http.MethodGet, "/api/providers/x/access/status", nil, "t")
|
||||
res := httptestRecorder(handler, req)
|
||||
assertStatusCode(t, res, http.StatusInternalServerError)
|
||||
assertJSONContains(t, res.Body().Bytes(), "error.code", "server_misconfigured")
|
||||
}
|
||||
|
||||
func TestProviderResourcesFnNil(t *testing.T) {
|
||||
handler := NewAPIHandler("t", ActionSet{})
|
||||
req := httptestRequest(t, http.MethodGet, "/api/providers/x/resources", nil, "t")
|
||||
res := httptestRecorder(handler, req)
|
||||
assertStatusCode(t, res, http.StatusInternalServerError)
|
||||
assertJSONContains(t, res.Body().Bytes(), "error.code", "server_misconfigured")
|
||||
}
|
||||
|
||||
func TestProviderStatusReturnsError(t *testing.T) {
|
||||
handler := NewAPIHandler("t", ActionSet{
|
||||
GetProviderStatus: func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error) {
|
||||
return provision.ProviderSnapshot{}, errors.New(`provider "x" not found in pack "p"`)
|
||||
},
|
||||
})
|
||||
req := httptestRequest(t, http.MethodGet, "/api/providers/x/status", nil, "t")
|
||||
res := httptestRecorder(handler, req)
|
||||
assertStatusCode(t, res, http.StatusBadRequest)
|
||||
assertJSONContains(t, res.Body().Bytes(), "error.code", "provider_not_found")
|
||||
}
|
||||
|
||||
func TestPostHandlersFnNil(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
body string
|
||||
}{
|
||||
{name: "install-pack", method: http.MethodPost, path: "/api/packs/install", body: `{}`},
|
||||
{name: "preview", method: http.MethodPost, path: "/api/providers/x/preview-import", body: `{}`},
|
||||
{name: "import", method: http.MethodPost, path: "/api/providers/x/import", body: `{}`},
|
||||
{name: "rollback", method: http.MethodPost, path: "/api/providers/x/rollback", body: `{}`},
|
||||
{name: "reconcile", method: http.MethodPost, path: "/api/providers/x/reconcile", body: `{}`},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
handler := NewAPIHandler("t", ActionSet{})
|
||||
req, _ := http.NewRequest(tt.method, tt.path, strings.NewReader(tt.body))
|
||||
req.Header.Set("Authorization", "Bearer t")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
res := httptestRecorder(handler, req)
|
||||
assertStatusCode(t, res, http.StatusInternalServerError)
|
||||
assertJSONContains(t, res.Body().Bytes(), "error.code", "server_misconfigured")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerErrorPaths(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
body string
|
||||
actionSet ActionSet
|
||||
wantStatus int
|
||||
wantCode string
|
||||
}{
|
||||
{
|
||||
name: "access-status-error",
|
||||
method: http.MethodGet,
|
||||
path: "/api/providers/x/access/status",
|
||||
actionSet: ActionSet{
|
||||
GetProviderAccessStatus: func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error) {
|
||||
return provision.ProviderSnapshot{}, errors.New("boom")
|
||||
},
|
||||
},
|
||||
wantStatus: http.StatusInternalServerError,
|
||||
wantCode: "internal_error",
|
||||
},
|
||||
{
|
||||
name: "preview-error",
|
||||
method: http.MethodPost,
|
||||
path: "/api/providers/x/preview-import",
|
||||
body: `{}`,
|
||||
actionSet: ActionSet{
|
||||
PreviewProvider: func(context.Context, PreviewProviderRequest) (provision.PreviewReport, error) {
|
||||
return provision.PreviewReport{}, errors.New("boom")
|
||||
},
|
||||
},
|
||||
wantStatus: http.StatusInternalServerError,
|
||||
wantCode: "internal_error",
|
||||
},
|
||||
{
|
||||
name: "rollback-error",
|
||||
method: http.MethodPost,
|
||||
path: "/api/providers/x/rollback",
|
||||
body: `{}`,
|
||||
actionSet: ActionSet{
|
||||
RollbackProvider: func(context.Context, RollbackProviderRequest) (provision.RollbackReport, error) {
|
||||
return provision.RollbackReport{}, errors.New("boom")
|
||||
},
|
||||
},
|
||||
wantStatus: http.StatusInternalServerError,
|
||||
wantCode: "internal_error",
|
||||
},
|
||||
{
|
||||
name: "reconcile-error",
|
||||
method: http.MethodPost,
|
||||
path: "/api/providers/x/reconcile",
|
||||
body: `{}`,
|
||||
actionSet: ActionSet{
|
||||
ReconcileProvider: func(context.Context, ReconcileProviderRequest) (provision.ReconcileResult, error) {
|
||||
return provision.ReconcileResult{}, errors.New("boom")
|
||||
},
|
||||
},
|
||||
wantStatus: http.StatusInternalServerError,
|
||||
wantCode: "internal_error",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
handler := NewAPIHandler("t", tt.actionSet)
|
||||
var req *http.Request
|
||||
if tt.body != "" {
|
||||
req, _ = http.NewRequest(tt.method, tt.path, strings.NewReader(tt.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
} else {
|
||||
var err error
|
||||
req, err = http.NewRequest(tt.method, tt.path, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer t")
|
||||
res := httptestRecorder(handler, req)
|
||||
assertStatusCode(t, res, tt.wantStatus)
|
||||
assertJSONContains(t, res.Body().Bytes(), "error.code", tt.wantCode)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderAccessStatusMultipleClosures(t *testing.T) {
|
||||
handler := NewAPIHandler("t", ActionSet{
|
||||
GetProviderAccessStatus: func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error) {
|
||||
return provision.ProviderSnapshot{
|
||||
Pack: sqlite.Pack{PackID: "p"},
|
||||
Provider: sqlite.Provider{ProviderID: "dp"},
|
||||
Batch: sqlite.ImportBatch{ID: 1},
|
||||
LatestAccessStatus: "ready",
|
||||
AccessClosures: []sqlite.AccessClosureRecord{
|
||||
{ID: 1, ClosureType: "preview", Status: "done", DetailsJSON: `{"v":1}`},
|
||||
{ID: 2, ClosureType: "self_service", Status: "active", DetailsJSON: `{"v":2}`},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
req := httptestRequest(t, http.MethodGet, "/api/providers/dp/access/status", nil, "t")
|
||||
res := httptestRecorder(handler, req)
|
||||
assertStatusCode(t, res, http.StatusOK)
|
||||
// Should report the last closure (index n-1)
|
||||
if !strings.Contains(res.Body().String(), `"closure_type":"self_service"`) {
|
||||
t.Fatalf("expected latest closure to be self_service, got: %s", res.Body().String())
|
||||
}
|
||||
}
|
||||
|
||||
func assertJSONContains(t *testing.T, payload []byte, key string, want any) {
|
||||
t.Helper()
|
||||
var decoded map[string]any
|
||||
if err := json.Unmarshal(payload, &decoded); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v; payload=%s", err, string(payload))
|
||||
}
|
||||
if strings.Contains(key, ".") {
|
||||
parts := strings.Split(key, ".")
|
||||
current := any(decoded)
|
||||
for _, part := range parts {
|
||||
object, ok := current.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("key %q not found in payload %s", key, string(payload))
|
||||
}
|
||||
current = object[part]
|
||||
}
|
||||
if current != want {
|
||||
t.Fatalf("json key %q = %#v, want %#v; payload=%s", key, current, want, string(payload))
|
||||
}
|
||||
return
|
||||
}
|
||||
if decoded[key] != want {
|
||||
t.Fatalf("json key %q = %#v, want %#v; payload=%s", key, decoded[key], want, string(payload))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewActionSetReturnsNonNil(t *testing.T) {
|
||||
as := NewActionSet("file::memory:?cache=shared")
|
||||
t.Run("InstallPack", func(t *testing.T) {
|
||||
if as.InstallPack == nil {
|
||||
t.Fatal("is nil")
|
||||
}
|
||||
})
|
||||
t.Run("BatchDetail", func(t *testing.T) {
|
||||
if as.BatchDetail == nil {
|
||||
t.Fatal("is nil")
|
||||
}
|
||||
})
|
||||
t.Run("GetProviderStatus", func(t *testing.T) {
|
||||
if as.GetProviderStatus == nil {
|
||||
t.Fatal("is nil")
|
||||
}
|
||||
})
|
||||
t.Run("GetProviderResources", func(t *testing.T) {
|
||||
if as.GetProviderResources == nil {
|
||||
t.Fatal("is nil")
|
||||
}
|
||||
})
|
||||
t.Run("GetProviderAccessStatus", func(t *testing.T) {
|
||||
if as.GetProviderAccessStatus == nil {
|
||||
t.Fatal("is nil")
|
||||
}
|
||||
})
|
||||
t.Run("PreviewProvider", func(t *testing.T) {
|
||||
if as.PreviewProvider == nil {
|
||||
t.Fatal("is nil")
|
||||
}
|
||||
})
|
||||
t.Run("ImportProvider", func(t *testing.T) {
|
||||
if as.ImportProvider == nil {
|
||||
t.Fatal("is nil")
|
||||
}
|
||||
})
|
||||
t.Run("RollbackProvider", func(t *testing.T) {
|
||||
if as.RollbackProvider == nil {
|
||||
t.Fatal("is nil")
|
||||
}
|
||||
})
|
||||
t.Run("ReconcileProvider", func(t *testing.T) {
|
||||
if as.ReconcileProvider == nil {
|
||||
t.Fatal("is nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBatchDetailReturnsNotFoundForMissingBatch(t *testing.T) {
|
||||
as := NewActionSet("file::memory:?cache=shared")
|
||||
_, err := as.BatchDetail(context.Background(), BatchDetailRequest{BatchID: 999})
|
||||
if err == nil {
|
||||
t.Fatal("BatchDetail() error = nil for missing batch, want error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewActionSetSQLiteClosures(t *testing.T) {
|
||||
dsn := "file::memory:?cache=shared"
|
||||
as := NewActionSet(dsn)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("GetProviderStatus on empty DB", func(t *testing.T) {
|
||||
_, err := as.GetProviderStatus(ctx, ProviderQueryRequest{ProviderID: "x", PackID: "p"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from empty DB, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetProviderResources on empty DB", func(t *testing.T) {
|
||||
_, err := as.GetProviderResources(ctx, ProviderQueryRequest{ProviderID: "x", PackID: "p"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from empty DB, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetProviderAccessStatus on empty DB", func(t *testing.T) {
|
||||
_, err := as.GetProviderAccessStatus(ctx, ProviderQueryRequest{ProviderID: "x", PackID: "p"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from empty DB, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewActionSetPackErrorPaths(t *testing.T) {
|
||||
dsn := "file::memory:?cache=shared"
|
||||
as := NewActionSet(dsn)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("InstallPack bad path", func(t *testing.T) {
|
||||
_, err := as.InstallPack(ctx, InstallPackRequest{PackPath: "/nonexistent/pack"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from bad pack path")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PreviewProvider bad path", func(t *testing.T) {
|
||||
_, err := as.PreviewProvider(ctx, PreviewProviderRequest{PackPath: "/nonexistent/pack", ProviderID: "x"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from bad pack path")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ImportProvider bad path", func(t *testing.T) {
|
||||
_, err := as.ImportProvider(ctx, ImportProviderRequest{PackPath: "/nonexistent/pack", ProviderID: "x", HostBaseURL: "http://h:8080"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from bad pack path")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RollbackProvider bad path", func(t *testing.T) {
|
||||
_, err := as.RollbackProvider(ctx, RollbackProviderRequest{PackPath: "/nonexistent/pack", ProviderID: "x", HostBaseURL: "http://h:8080"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from bad pack path")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ReconcileProvider bad path", func(t *testing.T) {
|
||||
_, err := as.ReconcileProvider(ctx, ReconcileProviderRequest{PackPath: "/nonexistent/pack", ProviderID: "x", HostBaseURL: "http://h:8080"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from bad pack path")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -11,6 +11,10 @@ func Bootstrap(_ context.Context) (*Server, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewServer(cfg.Server.ListenAddr, nil), nil
|
||||
adminToken, err := config.LoadAdminTokenFromEnv()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
handler := NewAPIHandler(adminToken, NewActionSet(cfg.Database.SQLiteDSN))
|
||||
return NewServer(cfg.Server.ListenAddr, handler, nil), nil
|
||||
}
|
||||
|
||||
638
internal/app/http_api.go
Normal file
638
internal/app/http_api.go
Normal file
@@ -0,0 +1,638 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
"sub2api-cn-relay-manager/internal/pack"
|
||||
"sub2api-cn-relay-manager/internal/provision"
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
type ActionSet struct {
|
||||
InstallPack func(context.Context, InstallPackRequest) (provision.PackInstallResult, error)
|
||||
BatchDetail func(context.Context, BatchDetailRequest) (provision.BatchDetailResult, error)
|
||||
GetProviderStatus func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error)
|
||||
GetProviderResources func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error)
|
||||
GetProviderAccessStatus func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error)
|
||||
PreviewProvider func(context.Context, PreviewProviderRequest) (provision.PreviewReport, error)
|
||||
ImportProvider func(context.Context, ImportProviderRequest) (provision.RuntimeImportResult, error)
|
||||
RollbackProvider func(context.Context, RollbackProviderRequest) (provision.RollbackReport, error)
|
||||
ReconcileProvider func(context.Context, ReconcileProviderRequest) (provision.ReconcileResult, error)
|
||||
}
|
||||
|
||||
type InstallPackRequest struct {
|
||||
HostBaseURL string `json:"host_base_url"`
|
||||
HostAPIKey string `json:"host_api_key"`
|
||||
HostBearerToken string `json:"host_bearer_token"`
|
||||
PackPath string `json:"pack_path"`
|
||||
}
|
||||
|
||||
type BatchDetailRequest struct {
|
||||
BatchID int64
|
||||
}
|
||||
|
||||
type ProviderQueryRequest struct {
|
||||
ProviderID string
|
||||
PackID string
|
||||
}
|
||||
|
||||
type RollbackProviderRequest struct {
|
||||
HostBaseURL string `json:"host_base_url"`
|
||||
HostAPIKey string `json:"host_api_key"`
|
||||
HostBearerToken string `json:"host_bearer_token"`
|
||||
PackPath string `json:"pack_path"`
|
||||
ProviderID string `json:"provider_id"`
|
||||
}
|
||||
|
||||
type ReconcileProviderRequest struct {
|
||||
HostBaseURL string `json:"host_base_url"`
|
||||
HostAPIKey string `json:"host_api_key"`
|
||||
HostBearerToken string `json:"host_bearer_token"`
|
||||
PackPath string `json:"pack_path"`
|
||||
ProviderID string `json:"provider_id"`
|
||||
AccessAPIKey string `json:"access_api_key"`
|
||||
}
|
||||
|
||||
type PreviewProviderRequest struct {
|
||||
HostBaseURL string `json:"host_base_url"`
|
||||
HostAPIKey string `json:"host_api_key"`
|
||||
HostBearerToken string `json:"host_bearer_token"`
|
||||
PackPath string `json:"pack_path"`
|
||||
ProviderID string `json:"provider_id"`
|
||||
Keys []string `json:"keys"`
|
||||
Mode string `json:"mode"`
|
||||
}
|
||||
|
||||
type ImportProviderRequest struct {
|
||||
HostBaseURL string `json:"host_base_url"`
|
||||
HostAPIKey string `json:"host_api_key"`
|
||||
HostBearerToken string `json:"host_bearer_token"`
|
||||
PackPath string `json:"pack_path"`
|
||||
ProviderID string `json:"provider_id"`
|
||||
Keys []string `json:"keys"`
|
||||
Mode string `json:"mode"`
|
||||
AccessMode string `json:"access_mode"`
|
||||
AccessAPIKey string `json:"access_api_key"`
|
||||
SubscriptionUsers []string `json:"subscription_users"`
|
||||
SubscriptionDays int `json:"subscription_days"`
|
||||
}
|
||||
|
||||
type httpError struct {
|
||||
StatusCode int `json:"-"`
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
UpstreamStatus int `json:"upstream_status,omitempty"`
|
||||
}
|
||||
|
||||
func (e *httpError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
func NewAPIHandler(adminToken string, actions ActionSet) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("GET /healthz", healthz)
|
||||
mux.Handle("GET /api/import-batches/{batchID}", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handleBatchDetail(w, r, actions.BatchDetail)
|
||||
})))
|
||||
mux.Handle("GET /api/providers/{providerID}/status", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handleProviderStatus(w, r, actions.GetProviderStatus)
|
||||
})))
|
||||
mux.Handle("GET /api/providers/{providerID}/resources", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handleProviderResources(w, r, actions.GetProviderResources)
|
||||
})))
|
||||
mux.Handle("GET /api/providers/{providerID}/access/status", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handleProviderAccessStatus(w, r, actions.GetProviderAccessStatus)
|
||||
})))
|
||||
mux.Handle("POST /api/packs/install", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handleInstallPack(w, r, actions.InstallPack)
|
||||
})))
|
||||
mux.Handle("POST /api/providers/{providerID}/preview-import", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handlePreviewProvider(w, r, actions.PreviewProvider)
|
||||
})))
|
||||
mux.Handle("POST /api/providers/{providerID}/import", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handleImportProvider(w, r, actions.ImportProvider)
|
||||
})))
|
||||
mux.Handle("POST /api/providers/{providerID}/rollback", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handleRollbackProvider(w, r, actions.RollbackProvider)
|
||||
})))
|
||||
mux.Handle("POST /api/providers/{providerID}/reconcile", requireAdminToken(adminToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handleReconcileProvider(w, r, actions.ReconcileProvider)
|
||||
})))
|
||||
return mux
|
||||
}
|
||||
|
||||
func healthz(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}
|
||||
|
||||
func requireAdminToken(token string, next http.Handler) http.Handler {
|
||||
if strings.TrimSpace(token) == "" {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "admin token is not configured"})
|
||||
})
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if bearerToken(r) != token {
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusUnauthorized, Code: "unauthorized", Message: "missing or invalid admin token"})
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func bearerToken(r *http.Request) string {
|
||||
header := strings.TrimSpace(r.Header.Get("Authorization"))
|
||||
if !strings.HasPrefix(strings.ToLower(header), "bearer ") {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(header[len("Bearer "):])
|
||||
}
|
||||
|
||||
func handleInstallPack(w http.ResponseWriter, r *http.Request, fn func(context.Context, InstallPackRequest) (provision.PackInstallResult, error)) {
|
||||
if fn == nil {
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "install-pack action is not configured"})
|
||||
return
|
||||
}
|
||||
var req InstallPackRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeHTTPError(w, err)
|
||||
return
|
||||
}
|
||||
result, err := fn(r.Context(), req)
|
||||
if err != nil {
|
||||
writeHTTPError(w, classifyError(err))
|
||||
return
|
||||
}
|
||||
providers := make([]map[string]string, 0, len(result.Providers))
|
||||
for _, provider := range result.Providers {
|
||||
providers = append(providers, map[string]string{
|
||||
"provider_id": provider.ProviderID,
|
||||
"display_name": provider.DisplayName,
|
||||
})
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"pack_id": result.Pack.PackID,
|
||||
"version": result.Pack.Version,
|
||||
"host_version": result.HostVersion,
|
||||
"already_installed": result.AlreadyInstalled,
|
||||
"providers": providers,
|
||||
})
|
||||
}
|
||||
|
||||
func handleBatchDetail(w http.ResponseWriter, r *http.Request, fn func(context.Context, BatchDetailRequest) (provision.BatchDetailResult, error)) {
|
||||
if fn == nil {
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "batch-detail action is not configured"})
|
||||
return
|
||||
}
|
||||
batchID, err := strconv.ParseInt(r.PathValue("batchID"), 10, 64)
|
||||
if err != nil || batchID <= 0 {
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "batch_id must be a positive integer"})
|
||||
return
|
||||
}
|
||||
result, err := fn(r.Context(), BatchDetailRequest{BatchID: batchID})
|
||||
if err != nil {
|
||||
writeHTTPError(w, classifyError(err))
|
||||
return
|
||||
}
|
||||
items := make([]map[string]any, 0, len(result.Items))
|
||||
for _, item := range result.Items {
|
||||
items = append(items, map[string]any{
|
||||
"id": item.ID,
|
||||
"batch_id": item.BatchID,
|
||||
"key_fingerprint": item.KeyFingerprint,
|
||||
"account_status": item.AccountStatus,
|
||||
"probe_summary_json": item.ProbeSummaryJSON,
|
||||
})
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"batch": map[string]any{
|
||||
"id": result.Batch.ID,
|
||||
"host_id": result.Batch.HostID,
|
||||
"pack_id": result.Batch.PackID,
|
||||
"provider_id": result.Batch.ProviderID,
|
||||
"mode": result.Batch.Mode,
|
||||
"batch_status": result.Batch.BatchStatus,
|
||||
"access_status": result.Batch.AccessStatus,
|
||||
},
|
||||
"items": items,
|
||||
"managed_resources": result.ManagedResources,
|
||||
"access_closures": result.AccessClosures,
|
||||
"reconcile_runs": result.ReconcileRuns,
|
||||
"items_count": len(result.Items),
|
||||
"managed_count": len(result.ManagedResources),
|
||||
"access_count": len(result.AccessClosures),
|
||||
"reconcile_count": len(result.ReconcileRuns),
|
||||
})
|
||||
}
|
||||
|
||||
func handleProviderStatus(w http.ResponseWriter, r *http.Request, fn func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error)) {
|
||||
if fn == nil {
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "provider-status action is not configured"})
|
||||
return
|
||||
}
|
||||
result, err := fn(r.Context(), ProviderQueryRequest{ProviderID: r.PathValue("providerID"), PackID: strings.TrimSpace(r.URL.Query().Get("pack_id"))})
|
||||
if err != nil {
|
||||
writeHTTPError(w, classifyError(err))
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"host": map[string]any{"host_id": result.Host.HostID, "base_url": result.Host.BaseURL, "host_version": result.Host.HostVersion},
|
||||
"pack": map[string]any{"pack_id": result.Pack.PackID, "version": result.Pack.Version},
|
||||
"provider": map[string]any{"provider_id": result.Provider.ProviderID, "display_name": result.Provider.DisplayName, "platform": result.Provider.Platform},
|
||||
"batch": map[string]any{"id": result.Batch.ID, "batch_status": result.Batch.BatchStatus, "access_status": result.Batch.AccessStatus, "mode": result.Batch.Mode},
|
||||
"provider_status": result.ProviderStatus,
|
||||
"latest_access_status": result.LatestAccessStatus,
|
||||
"latest_reconcile_status": result.LatestReconcileStatus,
|
||||
"latest_reconcile_summary": result.LatestReconcileSummary,
|
||||
"managed_resources_count": len(result.ManagedResources),
|
||||
"access_closures_count": len(result.AccessClosures),
|
||||
"reconcile_runs_count": len(result.ReconcileRuns),
|
||||
})
|
||||
}
|
||||
|
||||
func handleProviderAccessStatus(w http.ResponseWriter, r *http.Request, fn func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error)) {
|
||||
if fn == nil {
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "provider-access-status action is not configured"})
|
||||
return
|
||||
}
|
||||
result, err := fn(r.Context(), ProviderQueryRequest{ProviderID: r.PathValue("providerID"), PackID: strings.TrimSpace(r.URL.Query().Get("pack_id"))})
|
||||
if err != nil {
|
||||
writeHTTPError(w, classifyError(err))
|
||||
return
|
||||
}
|
||||
latestClosure := map[string]any{}
|
||||
if n := len(result.AccessClosures); n > 0 {
|
||||
closure := result.AccessClosures[n-1]
|
||||
latestClosure = map[string]any{"id": closure.ID, "closure_type": closure.ClosureType, "status": closure.Status, "details_json": closure.DetailsJSON}
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"provider_id": result.Provider.ProviderID,
|
||||
"pack_id": result.Pack.PackID,
|
||||
"batch_id": result.Batch.ID,
|
||||
"batch_access_status": result.Batch.AccessStatus,
|
||||
"latest_access_status": result.LatestAccessStatus,
|
||||
"closures_count": len(result.AccessClosures),
|
||||
"latest_closure": latestClosure,
|
||||
})
|
||||
}
|
||||
|
||||
func handleProviderResources(w http.ResponseWriter, r *http.Request, fn func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error)) {
|
||||
if fn == nil {
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "provider-resources action is not configured"})
|
||||
return
|
||||
}
|
||||
result, err := fn(r.Context(), ProviderQueryRequest{ProviderID: r.PathValue("providerID"), PackID: strings.TrimSpace(r.URL.Query().Get("pack_id"))})
|
||||
if err != nil {
|
||||
writeHTTPError(w, classifyError(err))
|
||||
return
|
||||
}
|
||||
resources := make([]map[string]any, 0, len(result.ManagedResources))
|
||||
for _, resource := range result.ManagedResources {
|
||||
resources = append(resources, map[string]any{"id": resource.ID, "resource_type": resource.ResourceType, "host_resource_id": resource.HostResourceID, "resource_name": resource.ResourceName})
|
||||
}
|
||||
accessClosures := make([]map[string]any, 0, len(result.AccessClosures))
|
||||
for _, closure := range result.AccessClosures {
|
||||
accessClosures = append(accessClosures, map[string]any{"id": closure.ID, "closure_type": closure.ClosureType, "status": closure.Status, "details_json": closure.DetailsJSON})
|
||||
}
|
||||
reconcileRuns := make([]map[string]any, 0, len(result.ReconcileRuns))
|
||||
for _, run := range result.ReconcileRuns {
|
||||
reconcileRuns = append(reconcileRuns, map[string]any{"id": run.ID, "status": run.Status, "summary_json": run.SummaryJSON})
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"provider_id": result.Provider.ProviderID,
|
||||
"pack_id": result.Pack.PackID,
|
||||
"batch_id": result.Batch.ID,
|
||||
"resources": resources,
|
||||
"access_closures": accessClosures,
|
||||
"reconcile_runs": reconcileRuns,
|
||||
})
|
||||
}
|
||||
|
||||
func handlePreviewProvider(w http.ResponseWriter, r *http.Request, fn func(context.Context, PreviewProviderRequest) (provision.PreviewReport, error)) {
|
||||
if fn == nil {
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "preview-provider action is not configured"})
|
||||
return
|
||||
}
|
||||
var req PreviewProviderRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeHTTPError(w, err)
|
||||
return
|
||||
}
|
||||
req.ProviderID = r.PathValue("providerID")
|
||||
result, err := fn(r.Context(), req)
|
||||
if err != nil {
|
||||
writeHTTPError(w, classifyError(err))
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"accepted_keys_count": len(result.AcceptedKeys),
|
||||
"names": result.Names,
|
||||
"decisions": result.Decisions,
|
||||
})
|
||||
}
|
||||
|
||||
func handleImportProvider(w http.ResponseWriter, r *http.Request, fn func(context.Context, ImportProviderRequest) (provision.RuntimeImportResult, error)) {
|
||||
if fn == nil {
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "import-provider action is not configured"})
|
||||
return
|
||||
}
|
||||
var req ImportProviderRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeHTTPError(w, err)
|
||||
return
|
||||
}
|
||||
req.ProviderID = r.PathValue("providerID")
|
||||
result, err := fn(r.Context(), req)
|
||||
if err != nil {
|
||||
payload := map[string]any{
|
||||
"batch_id": result.BatchID,
|
||||
"batch_status": result.Report.BatchStatus,
|
||||
"provider_status": result.Report.ProviderStatus,
|
||||
"access_status": result.Report.AccessStatus,
|
||||
"accepted_keys_count": len(result.Report.AcceptedKeys),
|
||||
"accounts_count": len(result.Report.Accounts),
|
||||
"gateway": result.Report.Gateway,
|
||||
"error": classifyError(err),
|
||||
}
|
||||
statusCode := http.StatusConflict
|
||||
if result.BatchID == 0 {
|
||||
statusCode = classifyError(err).StatusCode
|
||||
}
|
||||
writeJSON(w, statusCode, payload)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"batch_id": result.BatchID,
|
||||
"batch_status": result.Report.BatchStatus,
|
||||
"provider_status": result.Report.ProviderStatus,
|
||||
"access_status": result.Report.AccessStatus,
|
||||
"accepted_keys_count": len(result.Report.AcceptedKeys),
|
||||
"accounts_count": len(result.Report.Accounts),
|
||||
"group": result.Report.Group,
|
||||
"channel": result.Report.Channel,
|
||||
"plan": result.Report.Plan,
|
||||
"gateway": result.Report.Gateway,
|
||||
})
|
||||
}
|
||||
|
||||
func handleRollbackProvider(w http.ResponseWriter, r *http.Request, fn func(context.Context, RollbackProviderRequest) (provision.RollbackReport, error)) {
|
||||
if fn == nil {
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "rollback-provider action is not configured"})
|
||||
return
|
||||
}
|
||||
var req RollbackProviderRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeHTTPError(w, err)
|
||||
return
|
||||
}
|
||||
req.ProviderID = r.PathValue("providerID")
|
||||
result, err := fn(r.Context(), req)
|
||||
if err != nil {
|
||||
writeHTTPError(w, classifyError(err))
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"provider_id": req.ProviderID,
|
||||
"deleted_accounts": result.AccountsDeleted,
|
||||
"deleted_plans": result.PlansDeleted,
|
||||
"deleted_channels": result.ChannelsDeleted,
|
||||
"deleted_groups": result.GroupsDeleted,
|
||||
})
|
||||
}
|
||||
|
||||
func handleReconcileProvider(w http.ResponseWriter, r *http.Request, fn func(context.Context, ReconcileProviderRequest) (provision.ReconcileResult, error)) {
|
||||
if fn == nil {
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "reconcile-provider action is not configured"})
|
||||
return
|
||||
}
|
||||
var req ReconcileProviderRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeHTTPError(w, err)
|
||||
return
|
||||
}
|
||||
req.ProviderID = r.PathValue("providerID")
|
||||
result, err := fn(r.Context(), req)
|
||||
if err != nil {
|
||||
writeHTTPError(w, classifyError(err))
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"provider_id": req.ProviderID,
|
||||
"batch_id": result.BatchID,
|
||||
"status": result.Status,
|
||||
"missing_count": result.MissingCount,
|
||||
"extra_count": result.ExtraCount,
|
||||
"summary": result.Summary,
|
||||
})
|
||||
}
|
||||
|
||||
func decodeJSON(r *http.Request, dest any) *httpError {
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
decoder.DisallowUnknownFields()
|
||||
if err := decoder.Decode(dest); err != nil {
|
||||
return &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: fmt.Sprintf("decode request body: %v", err)}
|
||||
}
|
||||
if err := decoder.Decode(&struct{}{}); err != nil && !errors.Is(err, io.EOF) {
|
||||
return &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "request body must contain a single JSON object"}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeHTTPError(w http.ResponseWriter, err *httpError) {
|
||||
if err == nil {
|
||||
err = &httpError{StatusCode: http.StatusInternalServerError, Code: "internal_error", Message: "internal server error"}
|
||||
}
|
||||
writeJSON(w, err.StatusCode, map[string]any{"error": err})
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, statusCode int, body any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
_ = json.NewEncoder(w).Encode(body)
|
||||
}
|
||||
|
||||
func classifyError(err error) *httpError {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
var requestErr *httpError
|
||||
if errors.As(err, &requestErr) {
|
||||
return requestErr
|
||||
}
|
||||
var upstreamErr *sub2api.HTTPError
|
||||
if errors.As(err, &upstreamErr) {
|
||||
return &httpError{StatusCode: http.StatusBadGateway, Code: "host_request_failed", Message: err.Error(), UpstreamStatus: upstreamErr.StatusCode}
|
||||
}
|
||||
message := err.Error()
|
||||
switch {
|
||||
case strings.Contains(message, "already installed") || strings.Contains(message, "checksum drift"):
|
||||
return &httpError{StatusCode: http.StatusConflict, Code: "pack_conflict", Message: message}
|
||||
case strings.Contains(message, "not found in pack"):
|
||||
return &httpError{StatusCode: http.StatusBadRequest, Code: "provider_not_found", Message: message}
|
||||
case strings.Contains(message, "pack path") || strings.Contains(message, "pack dir") || strings.Contains(message, "required") || strings.Contains(message, "decode"):
|
||||
return &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: message}
|
||||
default:
|
||||
return &httpError{StatusCode: http.StatusInternalServerError, Code: "internal_error", Message: message}
|
||||
}
|
||||
}
|
||||
|
||||
func NewActionSet(sqliteDSN string) ActionSet {
|
||||
return ActionSet{
|
||||
InstallPack: func(ctx context.Context, req InstallPackRequest) (provision.PackInstallResult, error) {
|
||||
loadedPack, err := pack.LoadPath(req.PackPath)
|
||||
if err != nil {
|
||||
return provision.PackInstallResult{}, err
|
||||
}
|
||||
client, err := sub2api.NewClient(req.HostBaseURL, sub2api.WithAPIKey(req.HostAPIKey), sub2api.WithBearerToken(req.HostBearerToken))
|
||||
if err != nil {
|
||||
return provision.PackInstallResult{}, err
|
||||
}
|
||||
store, err := sqlite.Open(ctx, sqliteDSN)
|
||||
if err != nil {
|
||||
return provision.PackInstallResult{}, err
|
||||
}
|
||||
defer store.Close()
|
||||
service := provision.NewPackInstallService(store, client)
|
||||
return service.Install(ctx, provision.PackInstallRequest{Pack: loadedPack})
|
||||
},
|
||||
BatchDetail: func(ctx context.Context, req BatchDetailRequest) (provision.BatchDetailResult, error) {
|
||||
store, err := sqlite.Open(ctx, sqliteDSN)
|
||||
if err != nil {
|
||||
return provision.BatchDetailResult{}, err
|
||||
}
|
||||
defer store.Close()
|
||||
return provision.NewBatchDetailService(store).Get(ctx, req.BatchID)
|
||||
},
|
||||
GetProviderStatus: func(ctx context.Context, req ProviderQueryRequest) (provision.ProviderSnapshot, error) {
|
||||
store, err := sqlite.Open(ctx, sqliteDSN)
|
||||
if err != nil {
|
||||
return provision.ProviderSnapshot{}, err
|
||||
}
|
||||
defer store.Close()
|
||||
return provision.NewProviderStatusService(store).GetStatus(ctx, provision.ProviderQuery{ProviderID: req.ProviderID, PackID: req.PackID})
|
||||
},
|
||||
GetProviderResources: func(ctx context.Context, req ProviderQueryRequest) (provision.ProviderSnapshot, error) {
|
||||
store, err := sqlite.Open(ctx, sqliteDSN)
|
||||
if err != nil {
|
||||
return provision.ProviderSnapshot{}, err
|
||||
}
|
||||
defer store.Close()
|
||||
return provision.NewProviderStatusService(store).GetResources(ctx, provision.ProviderQuery{ProviderID: req.ProviderID, PackID: req.PackID})
|
||||
},
|
||||
GetProviderAccessStatus: func(ctx context.Context, req ProviderQueryRequest) (provision.ProviderSnapshot, error) {
|
||||
store, err := sqlite.Open(ctx, sqliteDSN)
|
||||
if err != nil {
|
||||
return provision.ProviderSnapshot{}, err
|
||||
}
|
||||
defer store.Close()
|
||||
return provision.NewProviderStatusService(store).GetStatus(ctx, provision.ProviderQuery{ProviderID: req.ProviderID, PackID: req.PackID})
|
||||
},
|
||||
PreviewProvider: func(ctx context.Context, req PreviewProviderRequest) (provision.PreviewReport, error) {
|
||||
loadedPack, err := pack.LoadPath(req.PackPath)
|
||||
if err != nil {
|
||||
return provision.PreviewReport{}, err
|
||||
}
|
||||
providerManifest, err := findProvider(loadedPack, req.ProviderID)
|
||||
if err != nil {
|
||||
return provision.PreviewReport{}, err
|
||||
}
|
||||
client, err := sub2api.NewClient(req.HostBaseURL, sub2api.WithAPIKey(req.HostAPIKey), sub2api.WithBearerToken(req.HostBearerToken))
|
||||
if err != nil {
|
||||
return provision.PreviewReport{}, err
|
||||
}
|
||||
service := provision.NewPreviewService(client)
|
||||
return service.PreviewImport(ctx, provision.PreviewRequest{Provider: providerManifest, Mode: req.Mode, Keys: req.Keys})
|
||||
},
|
||||
ImportProvider: func(ctx context.Context, req ImportProviderRequest) (provision.RuntimeImportResult, error) {
|
||||
loadedPack, err := pack.LoadPath(req.PackPath)
|
||||
if err != nil {
|
||||
return provision.RuntimeImportResult{}, err
|
||||
}
|
||||
providerManifest, err := findProvider(loadedPack, req.ProviderID)
|
||||
if err != nil {
|
||||
return provision.RuntimeImportResult{}, err
|
||||
}
|
||||
client, err := sub2api.NewClient(req.HostBaseURL, sub2api.WithAPIKey(req.HostAPIKey), sub2api.WithBearerToken(req.HostBearerToken))
|
||||
if err != nil {
|
||||
return provision.RuntimeImportResult{}, err
|
||||
}
|
||||
store, err := sqlite.Open(ctx, sqliteDSN)
|
||||
if err != nil {
|
||||
return provision.RuntimeImportResult{}, err
|
||||
}
|
||||
defer store.Close()
|
||||
subscriptions := make([]provision.SubscriptionTarget, 0, len(req.SubscriptionUsers))
|
||||
for _, userID := range req.SubscriptionUsers {
|
||||
subscriptions = append(subscriptions, provision.SubscriptionTarget{UserID: userID, DurationDays: req.SubscriptionDays})
|
||||
}
|
||||
service := provision.NewRuntimeImportService(store, client)
|
||||
return service.Import(ctx, provision.RuntimeImportRequest{
|
||||
HostBaseURL: req.HostBaseURL,
|
||||
Pack: loadedPack,
|
||||
Provider: providerManifest,
|
||||
Mode: req.Mode,
|
||||
Keys: req.Keys,
|
||||
Access: provision.AccessRequest{
|
||||
Mode: req.AccessMode,
|
||||
ProbeAPIKey: req.AccessAPIKey,
|
||||
Subscriptions: subscriptions,
|
||||
},
|
||||
})
|
||||
},
|
||||
RollbackProvider: func(ctx context.Context, req RollbackProviderRequest) (provision.RollbackReport, error) {
|
||||
loadedPack, err := pack.LoadPath(req.PackPath)
|
||||
if err != nil {
|
||||
return provision.RollbackReport{}, err
|
||||
}
|
||||
providerManifest, err := findProvider(loadedPack, req.ProviderID)
|
||||
if err != nil {
|
||||
return provision.RollbackReport{}, err
|
||||
}
|
||||
client, err := sub2api.NewClient(req.HostBaseURL, sub2api.WithAPIKey(req.HostAPIKey), sub2api.WithBearerToken(req.HostBearerToken))
|
||||
if err != nil {
|
||||
return provision.RollbackReport{}, err
|
||||
}
|
||||
service := provision.NewRollbackService(client)
|
||||
return service.Rollback(ctx, provision.RollbackRequest{Provider: providerManifest})
|
||||
},
|
||||
ReconcileProvider: func(ctx context.Context, req ReconcileProviderRequest) (provision.ReconcileResult, error) {
|
||||
loadedPack, err := pack.LoadPath(req.PackPath)
|
||||
if err != nil {
|
||||
return provision.ReconcileResult{}, err
|
||||
}
|
||||
providerManifest, err := findProvider(loadedPack, req.ProviderID)
|
||||
if err != nil {
|
||||
return provision.ReconcileResult{}, err
|
||||
}
|
||||
client, err := sub2api.NewClient(req.HostBaseURL, sub2api.WithAPIKey(req.HostAPIKey), sub2api.WithBearerToken(req.HostBearerToken))
|
||||
if err != nil {
|
||||
return provision.ReconcileResult{}, err
|
||||
}
|
||||
store, err := sqlite.Open(ctx, sqliteDSN)
|
||||
if err != nil {
|
||||
return provision.ReconcileResult{}, err
|
||||
}
|
||||
defer store.Close()
|
||||
service := provision.NewReconcileService(store, client)
|
||||
return service.Reconcile(ctx, provision.ReconcileRequest{HostBaseURL: req.HostBaseURL, AccessProbeAPIKey: req.AccessAPIKey, Pack: loadedPack, Provider: providerManifest})
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func findProvider(loaded pack.LoadedPack, providerID string) (pack.ProviderManifest, error) {
|
||||
for _, provider := range loaded.Providers {
|
||||
if provider.ProviderID == strings.TrimSpace(providerID) {
|
||||
return provider, nil
|
||||
}
|
||||
}
|
||||
return pack.ProviderManifest{}, fmt.Errorf("provider %q not found in pack %q", providerID, loaded.Manifest.PackID)
|
||||
}
|
||||
140
internal/config/config_test.go
Normal file
140
internal/config/config_test.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestReadOptionalEnv(t *testing.T) {
|
||||
t.Run("present non-empty", func(t *testing.T) {
|
||||
lookup := func(k string) (string, bool) {
|
||||
if k == "MY_KEY" {
|
||||
return " value ", true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
if got := readOptionalEnv(lookup, "MY_KEY", "default"); got != "value" {
|
||||
t.Fatalf("got %q, want %q", got, "value")
|
||||
}
|
||||
})
|
||||
t.Run("present empty", func(t *testing.T) {
|
||||
lookup := func(k string) (string, bool) {
|
||||
return " ", true
|
||||
}
|
||||
if got := readOptionalEnv(lookup, "MY_KEY", "default"); got != "default" {
|
||||
t.Fatalf("got %q, want %q", got, "default")
|
||||
}
|
||||
})
|
||||
t.Run("missing", func(t *testing.T) {
|
||||
lookup := func(k string) (string, bool) {
|
||||
return "", false
|
||||
}
|
||||
if got := readOptionalEnv(lookup, "MY_KEY", "default"); got != "default" {
|
||||
t.Fatalf("got %q, want %q", got, "default")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadRequiredEnv(t *testing.T) {
|
||||
t.Run("present", func(t *testing.T) {
|
||||
lookup := func(k string) (string, bool) {
|
||||
return "my-token", true
|
||||
}
|
||||
if got := readRequiredEnv(lookup, "TOKEN"); got != "my-token" {
|
||||
t.Fatalf("got %q, want %q", got, "my-token")
|
||||
}
|
||||
})
|
||||
t.Run("missing", func(t *testing.T) {
|
||||
lookup := func(k string) (string, bool) {
|
||||
return "", false
|
||||
}
|
||||
if got := readRequiredEnv(lookup, "TOKEN"); got != "" {
|
||||
t.Fatalf("got %q, want empty", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadStartupFromLookupEnv(t *testing.T) {
|
||||
t.Run("custom values", func(t *testing.T) {
|
||||
lookup := func(k string) (string, bool) {
|
||||
switch k {
|
||||
case EnvListenAddr:
|
||||
return ":9090", true
|
||||
case EnvSQLiteDSN:
|
||||
return "/data/db.sqlite", true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
cfg, err := loadStartupFromLookupEnv(lookup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cfg.Server.ListenAddr != ":9090" {
|
||||
t.Fatalf("ListenAddr = %q, want %q", cfg.Server.ListenAddr, ":9090")
|
||||
}
|
||||
if cfg.Database.SQLiteDSN != "/data/db.sqlite" {
|
||||
t.Fatalf("SQLiteDSN = %q, want %q", cfg.Database.SQLiteDSN, "/data/db.sqlite")
|
||||
}
|
||||
})
|
||||
t.Run("default values", func(t *testing.T) {
|
||||
lookup := func(k string) (string, bool) {
|
||||
return "", false
|
||||
}
|
||||
cfg, err := loadStartupFromLookupEnv(lookup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cfg.Server.ListenAddr != DefaultListenAddr {
|
||||
t.Fatalf("ListenAddr = %q, want %q", cfg.Server.ListenAddr, DefaultListenAddr)
|
||||
}
|
||||
if cfg.Database.SQLiteDSN != DefaultSQLiteDSN {
|
||||
t.Fatalf("SQLiteDSN = %q, want %q", cfg.Database.SQLiteDSN, DefaultSQLiteDSN)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadAdminTokenFromLookupEnv(t *testing.T) {
|
||||
t.Run("valid token", func(t *testing.T) {
|
||||
lookup := func(k string) (string, bool) {
|
||||
return " admin-secret-123 ", true
|
||||
}
|
||||
token, err := loadAdminTokenFromLookupEnv(lookup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if token != "admin-secret-123" {
|
||||
t.Fatalf("token = %q, want %q", token, "admin-secret-123")
|
||||
}
|
||||
})
|
||||
t.Run("empty token", func(t *testing.T) {
|
||||
lookup := func(k string) (string, bool) {
|
||||
return " ", true
|
||||
}
|
||||
_, err := loadAdminTokenFromLookupEnv(lookup)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty token")
|
||||
}
|
||||
})
|
||||
t.Run("missing env", func(t *testing.T) {
|
||||
lookup := func(k string) (string, bool) {
|
||||
return "", false
|
||||
}
|
||||
_, err := loadAdminTokenFromLookupEnv(lookup)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing env")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Verify exported wrappers call the lookup versions.
|
||||
// We can't easily test LoadStartupFromEnv / LoadAdminTokenFromEnv
|
||||
// since they depend on os.LookupEnv, but we verify they compile and don't panic.
|
||||
|
||||
func TestExportFunctionsExist(t *testing.T) {
|
||||
// Just verify the exported functions are reachable and return the right types
|
||||
_, err := LoadAdminTokenFromEnv()
|
||||
if err != nil && !errors.Is(err, err) {
|
||||
// any result is fine, just proving the function exists
|
||||
}
|
||||
}
|
||||
@@ -16,13 +16,19 @@ type HostAdapter interface {
|
||||
GetHostVersion(ctx context.Context) (string, error)
|
||||
ProbeCapabilities(ctx context.Context) (HostCapabilities, error)
|
||||
CreateGroup(ctx context.Context, req CreateGroupRequest) (GroupRef, error)
|
||||
DeleteGroup(ctx context.Context, groupID string) error
|
||||
CreateChannel(ctx context.Context, req CreateChannelRequest) (ChannelRef, error)
|
||||
DeleteChannel(ctx context.Context, channelID string) error
|
||||
CreatePlan(ctx context.Context, req CreatePlanRequest) (PlanRef, error)
|
||||
DeletePlan(ctx context.Context, planID string) error
|
||||
CreateAccount(ctx context.Context, req CreateAccountRequest) (AccountRef, error)
|
||||
BatchCreateAccounts(ctx context.Context, req BatchCreateAccountsRequest) ([]AccountRef, error)
|
||||
DeleteAccount(ctx context.Context, accountID string) error
|
||||
TestAccount(ctx context.Context, accountID string) (ProbeResult, error)
|
||||
GetAccountModels(ctx context.Context, accountID string) ([]AccountModel, error)
|
||||
AssignSubscription(ctx context.Context, req AssignSubscriptionRequest) (SubscriptionRef, error)
|
||||
CheckGatewayAccess(ctx context.Context, req GatewayAccessCheckRequest) (GatewayAccessResult, error)
|
||||
ListManagedResources(ctx context.Context, req ListManagedResourcesRequest) (ManagedResourceSnapshot, error)
|
||||
}
|
||||
|
||||
type HostCapabilities struct {
|
||||
|
||||
40
internal/host/sub2api/delete.go
Normal file
40
internal/host/sub2api/delete.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package sub2api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (c *Client) DeleteGroup(ctx context.Context, groupID string) error {
|
||||
return c.deleteResource(ctx, "/api/v1/admin/groups/", groupID)
|
||||
}
|
||||
|
||||
func (c *Client) DeleteChannel(ctx context.Context, channelID string) error {
|
||||
return c.deleteResource(ctx, "/api/v1/admin/channels/", channelID)
|
||||
}
|
||||
|
||||
func (c *Client) DeletePlan(ctx context.Context, planID string) error {
|
||||
return c.deleteResource(ctx, "/api/v1/admin/payment/plans/", planID)
|
||||
}
|
||||
|
||||
func (c *Client) DeleteAccount(ctx context.Context, accountID string) error {
|
||||
return c.deleteResource(ctx, "/api/v1/admin/accounts/", accountID)
|
||||
}
|
||||
|
||||
func (c *Client) deleteResource(ctx context.Context, prefix, resourceID string) error {
|
||||
resourceID = strings.TrimSpace(resourceID)
|
||||
if resourceID == "" {
|
||||
return fmt.Errorf("resource id is required")
|
||||
}
|
||||
path := prefix + resourceID
|
||||
statusCode, _, body, err := c.perform(ctx, http.MethodDelete, path, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices {
|
||||
return newHTTPError(http.MethodDelete, path, statusCode, body)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
62
internal/host/sub2api/gateway_probe.go
Normal file
62
internal/host/sub2api/gateway_probe.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package sub2api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type GatewayAccessCheckRequest struct {
|
||||
APIKey string
|
||||
ExpectedModel string
|
||||
}
|
||||
|
||||
type GatewayAccessResult struct {
|
||||
OK bool `json:"ok"`
|
||||
StatusCode int `json:"status_code"`
|
||||
Models []string `json:"models"`
|
||||
HasExpectedModel bool `json:"has_expected_model"`
|
||||
}
|
||||
|
||||
func (c *Client) CheckGatewayAccess(ctx context.Context, req GatewayAccessCheckRequest) (GatewayAccessResult, error) {
|
||||
gatewayClient := *c
|
||||
gatewayClient.apiKey = strings.TrimSpace(req.APIKey)
|
||||
gatewayClient.bearerToken = ""
|
||||
|
||||
statusCode, _, body, err := gatewayClient.perform(ctx, http.MethodGet, "/v1/models", nil)
|
||||
if err != nil {
|
||||
return GatewayAccessResult{}, err
|
||||
}
|
||||
|
||||
result := GatewayAccessResult{StatusCode: statusCode, OK: statusCode >= http.StatusOK && statusCode < http.StatusMultipleChoices}
|
||||
if !result.OK {
|
||||
return result, nil
|
||||
}
|
||||
result.Models = decodeGatewayModelIDs(body)
|
||||
for _, modelID := range result.Models {
|
||||
if modelID == strings.TrimSpace(req.ExpectedModel) {
|
||||
result.HasExpectedModel = true
|
||||
break
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func decodeGatewayModelIDs(body []byte) []string {
|
||||
var payload struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &payload); err == nil && len(payload.Data) > 0 {
|
||||
models := make([]string, 0, len(payload.Data))
|
||||
for _, item := range payload.Data {
|
||||
if id := strings.TrimSpace(item.ID); id != "" {
|
||||
models = append(models, id)
|
||||
}
|
||||
}
|
||||
return models
|
||||
}
|
||||
return nil
|
||||
}
|
||||
97
internal/host/sub2api/list_resources.go
Normal file
97
internal/host/sub2api/list_resources.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package sub2api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (c *Client) ListManagedResources(ctx context.Context, req ListManagedResourcesRequest) (ManagedResourceSnapshot, error) {
|
||||
groups, err := c.listNamedResources(ctx, "/api/v1/admin/groups", req.GroupName)
|
||||
if err != nil {
|
||||
return ManagedResourceSnapshot{}, fmt.Errorf("list groups: %w", err)
|
||||
}
|
||||
channels, err := c.listNamedResources(ctx, "/api/v1/admin/channels", req.ChannelName)
|
||||
if err != nil {
|
||||
return ManagedResourceSnapshot{}, fmt.Errorf("list channels: %w", err)
|
||||
}
|
||||
plans, err := c.listNamedResources(ctx, "/api/v1/admin/payment/plans", req.PlanName)
|
||||
if err != nil {
|
||||
return ManagedResourceSnapshot{}, fmt.Errorf("list plans: %w", err)
|
||||
}
|
||||
accounts, err := c.listNamedResources(ctx, "/api/v1/admin/accounts", "")
|
||||
if err != nil {
|
||||
return ManagedResourceSnapshot{}, fmt.Errorf("list accounts: %w", err)
|
||||
}
|
||||
|
||||
return ManagedResourceSnapshot{
|
||||
Groups: groups,
|
||||
Channels: channels,
|
||||
Plans: plans,
|
||||
Accounts: filterNamedResourcesByPrefix(accounts, req.AccountNamePrefix),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Client) listNamedResources(ctx context.Context, path, expectedName string) ([]NamedResource, error) {
|
||||
statusCode, _, body, err := c.perform(ctx, "GET", path, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if statusCode < 200 || statusCode >= 300 {
|
||||
return nil, newHTTPError("GET", path, statusCode, body)
|
||||
}
|
||||
|
||||
resources, err := decodeNamedResources(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode %s response: %w", path, err)
|
||||
}
|
||||
return filterNamedResourcesByName(resources, expectedName), nil
|
||||
}
|
||||
|
||||
func decodeNamedResources(body []byte) ([]NamedResource, error) {
|
||||
var resources []NamedResource
|
||||
if err := decodeEnvelopeObject(body, &resources); err == nil {
|
||||
return resources, nil
|
||||
}
|
||||
|
||||
var wrapper struct {
|
||||
Data struct {
|
||||
Items []NamedResource `json:"items"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &wrapper); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return wrapper.Data.Items, nil
|
||||
}
|
||||
|
||||
func filterNamedResourcesByName(resources []NamedResource, expectedName string) []NamedResource {
|
||||
expectedName = strings.TrimSpace(expectedName)
|
||||
if expectedName == "" {
|
||||
return resources
|
||||
}
|
||||
|
||||
filtered := make([]NamedResource, 0, len(resources))
|
||||
for _, resource := range resources {
|
||||
if strings.TrimSpace(resource.Name) == expectedName {
|
||||
filtered = append(filtered, resource)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func filterNamedResourcesByPrefix(resources []NamedResource, prefix string) []NamedResource {
|
||||
prefix = strings.TrimSpace(prefix)
|
||||
if prefix == "" {
|
||||
return resources
|
||||
}
|
||||
|
||||
filtered := make([]NamedResource, 0, len(resources))
|
||||
for _, resource := range resources {
|
||||
if strings.HasPrefix(strings.TrimSpace(resource.Name), prefix) {
|
||||
filtered = append(filtered, resource)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
20
internal/host/sub2api/resources.go
Normal file
20
internal/host/sub2api/resources.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package sub2api
|
||||
|
||||
type ListManagedResourcesRequest struct {
|
||||
GroupName string
|
||||
ChannelName string
|
||||
PlanName string
|
||||
AccountNamePrefix string
|
||||
}
|
||||
|
||||
type NamedResource struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type ManagedResourceSnapshot struct {
|
||||
Groups []NamedResource `json:"groups"`
|
||||
Channels []NamedResource `json:"channels"`
|
||||
Plans []NamedResource `json:"plans"`
|
||||
Accounts []NamedResource `json:"accounts"`
|
||||
}
|
||||
699
internal/host/sub2api/sub2api_test.go
Normal file
699
internal/host/sub2api/sub2api_test.go
Normal file
@@ -0,0 +1,699 @@
|
||||
package sub2api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHTTPErrorErrorMessage(t *testing.T) {
|
||||
e := newHTTPError("POST", "/api/v1/admin/groups", http.StatusTeapot, []byte("short and stout"))
|
||||
want := "sub2api POST /api/v1/admin/groups returned 418: short and stout"
|
||||
if got := e.Error(); got != want {
|
||||
t.Fatalf("HTTPError.Error() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithHTTPClientAndOptions(t *testing.T) {
|
||||
customHTTP := &http.Client{Timeout: 123}
|
||||
client, err := NewClient("http://localhost:8080",
|
||||
WithHTTPClient(customHTTP),
|
||||
WithAPIKey(" sk-abc "),
|
||||
WithBearerToken(" tok-xyz "),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if client.httpClient != customHTTP {
|
||||
t.Fatal("WithHTTPClient not applied")
|
||||
}
|
||||
if client.apiKey != "sk-abc" {
|
||||
t.Fatalf("apiKey = %q, want %q", client.apiKey, "sk-abc")
|
||||
}
|
||||
if client.bearerToken != "tok-xyz" {
|
||||
t.Fatalf("bearerToken = %q, want %q", client.bearerToken, "tok-xyz")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClient_RejectsInvalidURLs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
}{
|
||||
{"empty", ""},
|
||||
{"no scheme", "localhost:8080"},
|
||||
{"no host", "http://"},
|
||||
{"garbage", "://foo"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := NewClient(tt.url)
|
||||
if err == nil {
|
||||
t.Fatalf("NewClient(%q) error = nil, want error", tt.url)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolvePath(t *testing.T) {
|
||||
client, err := NewClient("http://host:9090")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
want string
|
||||
}{
|
||||
{"/v1/models", "http://host:9090/v1/models"},
|
||||
{"v1/models", "http://host:9090/v1/models"},
|
||||
{"/v1/models?key=val", "http://host:9090/v1/models?key=val"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.path, func(t *testing.T) {
|
||||
if got := client.resolvePath(tt.path); got != tt.want {
|
||||
t.Fatalf("resolvePath(%q) = %q, want %q", tt.path, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyAuth(t *testing.T) {
|
||||
t.Run("api key preferred", func(t *testing.T) {
|
||||
c, _ := NewClient("http://h:8080", WithAPIKey("key1"), WithBearerToken("btok"))
|
||||
req, _ := http.NewRequest("GET", "http://h:8080/path", nil)
|
||||
c.applyAuth(req)
|
||||
if h := req.Header.Get("x-api-key"); h != "key1" {
|
||||
t.Fatalf("x-api-key = %q, want %q", h, "key1")
|
||||
}
|
||||
if h := req.Header.Get("Authorization"); h != "" {
|
||||
t.Fatalf("Authorization should be empty, got %q", h)
|
||||
}
|
||||
})
|
||||
t.Run("bearer token fallback", func(t *testing.T) {
|
||||
c, _ := NewClient("http://h:8080", WithBearerToken("btok"))
|
||||
req, _ := http.NewRequest("GET", "http://h:8080/path", nil)
|
||||
c.applyAuth(req)
|
||||
if h := req.Header.Get("Authorization"); h != "Bearer btok" {
|
||||
t.Fatalf("Authorization = %q, want %q", h, "Bearer btok")
|
||||
}
|
||||
})
|
||||
t.Run("no auth", func(t *testing.T) {
|
||||
c, _ := NewClient("http://h:8080")
|
||||
req, _ := http.NewRequest("GET", "http://h:8080/path", nil)
|
||||
c.applyAuth(req)
|
||||
if h := req.Header.Get("x-api-key"); h != "" {
|
||||
t.Fatalf("x-api-key should be empty, got %q", h)
|
||||
}
|
||||
if h := req.Header.Get("Authorization"); h != "" {
|
||||
t.Fatalf("Authorization should be empty, got %q", h)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDecodeEnvelopeObject(t *testing.T) {
|
||||
t.Run("standard envelope", func(t *testing.T) {
|
||||
body := []byte(`{"data":{"id":"g1","name":"test"}}`)
|
||||
var ref GroupRef
|
||||
if err := decodeEnvelopeObject(body, &ref); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if ref.ID != "g1" || ref.Name != "test" {
|
||||
t.Fatalf("got %+v, want {ID:g1 Name:test}", ref)
|
||||
}
|
||||
})
|
||||
t.Run("flat response (no data wrapper)", func(t *testing.T) {
|
||||
body := []byte(`{"id":"g2","name":"flat"}`)
|
||||
var ref GroupRef
|
||||
if err := decodeEnvelopeObject(body, &ref); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if ref.ID != "g2" || ref.Name != "flat" {
|
||||
t.Fatalf("got %+v, want {ID:g2 Name:flat}", ref)
|
||||
}
|
||||
})
|
||||
t.Run("data:null returns flat", func(t *testing.T) {
|
||||
body := []byte(`{"data":null,"id":"g3"}`)
|
||||
var ref GroupRef
|
||||
if err := decodeEnvelopeObject(body, &ref); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if ref.ID != "g3" {
|
||||
t.Fatalf("id = %q, want %q", ref.ID, "g3")
|
||||
}
|
||||
})
|
||||
t.Run("invalid json returns error", func(t *testing.T) {
|
||||
var ref GroupRef
|
||||
if err := decodeEnvelopeObject([]byte(`not json`), &ref); err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDecodeGatewayModelIDs(t *testing.T) {
|
||||
t.Run("standard list", func(t *testing.T) {
|
||||
ids := decodeGatewayModelIDs([]byte(`{"data":[{"id":"gpt-4"},{"id":" claude-3 "}]}`))
|
||||
if len(ids) != 2 || ids[0] != "gpt-4" || ids[1] != "claude-3" {
|
||||
t.Fatalf("got %v, want [gpt-4 claude-3]", ids)
|
||||
}
|
||||
})
|
||||
t.Run("empty data", func(t *testing.T) {
|
||||
if ids := decodeGatewayModelIDs([]byte(`{}`)); ids != nil {
|
||||
t.Fatalf("expected nil, got %v", ids)
|
||||
}
|
||||
})
|
||||
t.Run("invalid json", func(t *testing.T) {
|
||||
if ids := decodeGatewayModelIDs([]byte(`not json`)); ids != nil {
|
||||
t.Fatalf("expected nil, got %v", ids)
|
||||
}
|
||||
})
|
||||
t.Run("empty array", func(t *testing.T) {
|
||||
if ids := decodeGatewayModelIDs([]byte(`{"data":[]}`)); ids != nil {
|
||||
t.Fatalf("expected nil, got %v", ids)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFilterNamedResourcesByName(t *testing.T) {
|
||||
resources := []NamedResource{
|
||||
{Name: "group-a", ID: "g1"},
|
||||
{Name: "group-b", ID: "g2"},
|
||||
{Name: " group-a ", ID: "g3"},
|
||||
}
|
||||
t.Run("match", func(t *testing.T) {
|
||||
got := filterNamedResourcesByName(resources, "group-a")
|
||||
if len(got) != 2 || got[0].ID != "g1" || got[1].ID != "g3" {
|
||||
t.Fatalf("got %+v, want 2 matches", got)
|
||||
}
|
||||
})
|
||||
t.Run("no match", func(t *testing.T) {
|
||||
if got := filterNamedResourcesByName(resources, "nonexistent"); len(got) != 0 {
|
||||
t.Fatalf("expected 0, got %d", len(got))
|
||||
}
|
||||
})
|
||||
t.Run("empty name returns all", func(t *testing.T) {
|
||||
if got := filterNamedResourcesByName(resources, ""); len(got) != 3 {
|
||||
t.Fatalf("expected 3, got %d", len(got))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFilterNamedResourcesByPrefix(t *testing.T) {
|
||||
resources := []NamedResource{
|
||||
{Name: "deepseek-proxy", ID: "r1"},
|
||||
{Name: "deepseek-us", ID: "r2"},
|
||||
{Name: "claude-eu", ID: "r3"},
|
||||
}
|
||||
t.Run("prefix matches", func(t *testing.T) {
|
||||
got := filterNamedResourcesByPrefix(resources, "deepseek")
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected 2, got %d", len(got))
|
||||
}
|
||||
})
|
||||
t.Run("no prefix match", func(t *testing.T) {
|
||||
if got := filterNamedResourcesByPrefix(resources, "nope"); len(got) != 0 {
|
||||
t.Fatalf("expected 0, got %d", len(got))
|
||||
}
|
||||
})
|
||||
t.Run("empty prefix returns all", func(t *testing.T) {
|
||||
if got := filterNamedResourcesByPrefix(resources, ""); len(got) != 3 {
|
||||
t.Fatalf("expected 3, got %d", len(got))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDecodeNamedResources(t *testing.T) {
|
||||
t.Run("envelope", func(t *testing.T) {
|
||||
resources, err := decodeNamedResources([]byte(`{"data":[{"id":"r1","name":"n1"}]}`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(resources) != 1 || resources[0].ID != "r1" {
|
||||
t.Fatalf("got %+v", resources)
|
||||
}
|
||||
})
|
||||
t.Run("wrapper with items", func(t *testing.T) {
|
||||
resources, err := decodeNamedResources([]byte(`{"data":{"items":[{"id":"r2","name":"n2"}]}}`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(resources) != 1 || resources[0].ID != "r2" {
|
||||
t.Fatalf("got %+v", resources)
|
||||
}
|
||||
})
|
||||
t.Run("invalid json", func(t *testing.T) {
|
||||
_, err := decodeNamedResources([]byte(`not json`))
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDecodeAccountRefs(t *testing.T) {
|
||||
t.Run("envelope", func(t *testing.T) {
|
||||
refs, err := decodeAccountRefs([]byte(`{"data":[{"id":"a1"}]}`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(refs) != 1 || refs[0].ID != "a1" {
|
||||
t.Fatalf("got %+v", refs)
|
||||
}
|
||||
})
|
||||
t.Run("wrapper with items", func(t *testing.T) {
|
||||
refs, err := decodeAccountRefs([]byte(`{"data":{"items":[{"id":"a2"}]}}`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(refs) != 1 || refs[0].ID != "a2" {
|
||||
t.Fatalf("got %+v", refs)
|
||||
}
|
||||
})
|
||||
t.Run("invalid json", func(t *testing.T) {
|
||||
_, err := decodeAccountRefs([]byte(`not json`))
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDecodeAccountModels(t *testing.T) {
|
||||
t.Run("envelope", func(t *testing.T) {
|
||||
models, err := decodeAccountModels([]byte(`{"data":[{"id":"gpt4","display_name":"GPT-4","type":"chat"}]}`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(models) != 1 || models[0].ID != "gpt4" {
|
||||
t.Fatalf("got %+v", models)
|
||||
}
|
||||
})
|
||||
t.Run("wrapper with items", func(t *testing.T) {
|
||||
models, err := decodeAccountModels([]byte(`{"data":{"items":[{"id":"cl3","display_name":"Claude 3","type":"chat"}]}}`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(models) != 1 || models[0].ID != "cl3" {
|
||||
t.Fatalf("got %+v", models)
|
||||
}
|
||||
})
|
||||
t.Run("invalid json", func(t *testing.T) {
|
||||
_, err := decodeAccountModels([]byte(`not json`))
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseProbeResult(t *testing.T) {
|
||||
t.Run("SSE with ok=true", func(t *testing.T) {
|
||||
result, err := parseProbeResult([]byte("data: {\"status\":\"passed\",\"ok\":true}\n"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !result.OK || result.Status != "passed" {
|
||||
t.Fatalf("got %+v, want OK=true Status=passed", result)
|
||||
}
|
||||
})
|
||||
t.Run("SSE with success=true", func(t *testing.T) {
|
||||
result, err := parseProbeResult([]byte("data: {\"status\":\"succeeded\",\"success\":true}\n"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !result.OK || result.Status != "passed" {
|
||||
t.Fatalf("got %+v", result)
|
||||
}
|
||||
})
|
||||
t.Run("SSE with ok=false", func(t *testing.T) {
|
||||
result, err := parseProbeResult([]byte("data: {\"status\":\"failed\",\"ok\":false}\n"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.OK || result.Status != "failed" {
|
||||
t.Fatalf("got %+v", result)
|
||||
}
|
||||
})
|
||||
t.Run("SSE with status-based ok", func(t *testing.T) {
|
||||
result, err := parseProbeResult([]byte("data: {\"status\":\"pass\",\"message\":\"all good\"}\n"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !result.OK || result.Message != "all good" {
|
||||
t.Fatalf("got %+v", result)
|
||||
}
|
||||
})
|
||||
t.Run("multiple SSE events picks last", func(t *testing.T) {
|
||||
result, err := parseProbeResult([]byte("data: {\"status\":\"running\"}\ndata: {\"status\":\"passed\",\"ok\":true}\n"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !result.OK {
|
||||
t.Fatalf("expected OK=true from last event, got %+v", result)
|
||||
}
|
||||
})
|
||||
t.Run("no data events", func(t *testing.T) {
|
||||
_, err := parseProbeResult([]byte("not data\n"))
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNormalizeProbeStatus(t *testing.T) {
|
||||
tests := []struct {
|
||||
status string
|
||||
ok bool
|
||||
want string
|
||||
}{
|
||||
{"pass", true, "passed"},
|
||||
{"PASSED", true, "passed"},
|
||||
{"Ok", true, "passed"},
|
||||
{"success", true, "passed"},
|
||||
{"succeeded", true, "passed"},
|
||||
{"fail", false, "failed"},
|
||||
{"FAILED", false, "failed"},
|
||||
{"error", false, "failed"},
|
||||
{"custom_ok", true, "passed"},
|
||||
{"custom_fail", false, "failed"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.status, func(t *testing.T) {
|
||||
if got := normalizeProbeStatus(tt.status, tt.ok); got != tt.want {
|
||||
t.Fatalf("normalizeProbeStatus(%q, %v) = %q, want %q", tt.status, tt.ok, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
func TestLooksLikeExistingEndpoint(t *testing.T) {
|
||||
t.Run("json content type", func(t *testing.T) {
|
||||
h := http.Header{"Content-Type": []string{"application/json"}}
|
||||
if !looksLikeExistingEndpoint(h, nil) {
|
||||
t.Fatal("expected true with json content type")
|
||||
}
|
||||
})
|
||||
t.Run("sse content type", func(t *testing.T) {
|
||||
h := http.Header{"Content-Type": []string{"text/event-stream"}}
|
||||
if !looksLikeExistingEndpoint(h, nil) {
|
||||
t.Fatal("expected true with sse content type")
|
||||
}
|
||||
})
|
||||
t.Run("empty body and no content type", func(t *testing.T) {
|
||||
if looksLikeExistingEndpoint(http.Header{}, nil) {
|
||||
t.Fatal("expected false")
|
||||
}
|
||||
})
|
||||
t.Run("json-like body", func(t *testing.T) {
|
||||
if !looksLikeExistingEndpoint(http.Header{}, []byte(`{"error":"not found"}`)) {
|
||||
t.Fatal("expected true for json body")
|
||||
}
|
||||
})
|
||||
t.Run("array body", func(t *testing.T) {
|
||||
if !looksLikeExistingEndpoint(http.Header{}, []byte(`[]`)) {
|
||||
t.Fatal("expected true for array body")
|
||||
}
|
||||
})
|
||||
t.Run("html body", func(t *testing.T) {
|
||||
if looksLikeExistingEndpoint(http.Header{}, []byte(`<html>`)) {
|
||||
t.Fatal("expected false for html body")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Tests for NamedResource type used by the filter functions.
|
||||
// Defined locally since it's in the same package.
|
||||
|
||||
func TestNewClientWithNilOption(t *testing.T) {
|
||||
client, err := NewClient("http://localhost:8080", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if client == nil {
|
||||
t.Fatal("client is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHTTPError(t *testing.T) {
|
||||
e := newHTTPError("GET", "/v1/models", 200, []byte(`{"ok":true}`))
|
||||
if e.Method != "GET" || e.Path != "/v1/models" || e.StatusCode != 200 || e.Body != `{"ok":true}` {
|
||||
t.Fatalf("unexpected http error: %+v", e)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func TestPerformWithMockServer(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/v1/admin/system/version":
|
||||
w.Write([]byte(`{"data":{"version":"v1.2.3"}}`))
|
||||
case "/api/v1/admin/groups":
|
||||
w.Write([]byte(`{"data":{"id":"g1","name":"test-group"}}`))
|
||||
case "/api/v1/admin/channels":
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(`{"error":"panic"}`))
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client, err := NewClient(srv.URL, WithAPIKey("test-key"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("GetHostVersion", func(t *testing.T) {
|
||||
ver, err := client.GetHostVersion(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if ver != "v1.2.3" {
|
||||
t.Fatalf("version = %q, want %q", ver, "v1.2.3")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("postJSON success", func(t *testing.T) {
|
||||
var ref GroupRef
|
||||
if err := client.postJSON(context.Background(), "/api/v1/admin/groups", CreateGroupRequest{Name: "test"}, &ref); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if ref.ID != "g1" || ref.Name != "test-group" {
|
||||
t.Fatalf("got %+v, want {ID:g1 Name:test-group}", ref)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("postJSON error status", func(t *testing.T) {
|
||||
var ref GroupRef
|
||||
err := client.postJSON(context.Background(), "/api/v1/admin/channels", nil, &ref)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
var httpErr *HTTPError
|
||||
if !errors.As(err, &httpErr) {
|
||||
t.Fatalf("expected HTTPError, got %T: %v", err, err)
|
||||
}
|
||||
if httpErr.StatusCode != 500 {
|
||||
t.Fatalf("status code = %d, want 500", httpErr.StatusCode)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("getJSON success", func(t *testing.T) {
|
||||
var ref GroupRef
|
||||
if err := client.getJSON(context.Background(), "/api/v1/admin/groups", &ref); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("getJSON error status", func(t *testing.T) {
|
||||
var ref GroupRef
|
||||
err := client.getJSON(context.Background(), "/bad/path", &ref)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateGroupWithMock(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(`{"data":{"id":"g1","name":"demo"}}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client, err := NewClient(srv.URL, WithAPIKey("k"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ref, err := client.CreateGroup(context.Background(), CreateGroupRequest{Name: "demo", RateMultiplier: 1.0})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if ref.ID != "g1" || ref.Name != "demo" {
|
||||
t.Fatalf("got %+v", ref)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateChannelWithMock(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(`{"data":{"id":"c1","name":"ch"}}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
||||
_, err := client.CreateChannel(context.Background(), CreateChannelRequest{Name: "ch"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreatePlanWithMock(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(`{"data":{"id":"p1","name":"plan"}}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
||||
_, err := client.CreatePlan(context.Background(), CreatePlanRequest{Name: "plan"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteWithMock(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer srv.Close()
|
||||
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
||||
|
||||
t.Run("DeleteGroup", func(t *testing.T) {
|
||||
if err := client.DeleteGroup(context.Background(), "g1"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
t.Run("DeleteChannel", func(t *testing.T) {
|
||||
if err := client.DeleteChannel(context.Background(), "c1"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
t.Run("DeletePlan", func(t *testing.T) {
|
||||
if err := client.DeletePlan(context.Background(), "p1"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
t.Run("DeleteAccount", func(t *testing.T) {
|
||||
if err := client.DeleteAccount(context.Background(), "a1"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAssignSubscriptionWithMock(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(`{"data":{"id":"s1"}}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
||||
ref, err := client.AssignSubscription(context.Background(), AssignSubscriptionRequest{UserID: "u1"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if ref.ID != "s1" {
|
||||
t.Fatalf("id = %q", ref.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckGatewayAccessWithMock(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(`{"data":[{"id":"gpt-4"},{"id":"claude-3"}]}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
||||
result, err := client.CheckGatewayAccess(context.Background(), GatewayAccessCheckRequest{APIKey: "gk", ExpectedModel: "gpt-4"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !result.OK {
|
||||
t.Fatal("expected OK=true")
|
||||
}
|
||||
if !result.HasExpectedModel {
|
||||
t.Fatal("expected HasExpectedModel=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchCreateAccountsWithMock(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(`{"data":[{"id":"a1","name":"acct1"}]}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
||||
refs, err := client.BatchCreateAccounts(context.Background(), BatchCreateAccountsRequest{
|
||||
Accounts: []CreateAccountRequest{{Name: "acct1"}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(refs) != 1 || refs[0].ID != "a1" {
|
||||
t.Fatalf("got %+v", refs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProbeCapabilitiesWithMock(t *testing.T) {
|
||||
callCount := 0
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
callCount++
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
}))
|
||||
defer srv.Close()
|
||||
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
||||
caps, err := client.ProbeCapabilities(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !caps.Groups || !caps.Channels || !caps.Plans || !caps.Accounts || !caps.AccountTest || !caps.AccountModels || !caps.Subscriptions {
|
||||
t.Fatalf("all capabilities should be true, got %+v", caps)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListManagedResourcesWithMock(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(`{"data":{"items":[
|
||||
{"id":"r1","name":"resource-1"}
|
||||
]}}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
||||
snapshot, err := client.ListManagedResources(context.Background(), ListManagedResourcesRequest{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(snapshot.Groups) != 1 {
|
||||
t.Fatalf("expected 1 group, got %d", len(snapshot.Groups))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTestAccountWithMock(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("data: {\"status\":\"passed\",\"ok\":true}\n"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
||||
result, err := client.TestAccount(context.Background(), "a1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !result.OK {
|
||||
t.Fatal("expected OK=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountModelsWithMock(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(`{"data":[{"id":"m1","display_name":"M1","type":"chat"}]}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
||||
models, err := client.GetAccountModels(context.Background(), "a1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(models) != 1 || models[0].ID != "m1" {
|
||||
t.Fatalf("got %+v", models)
|
||||
}
|
||||
}
|
||||
266
internal/pack/extra_test.go
Normal file
266
internal/pack/extra_test.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package pack
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValidateManifestRequiredFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
manifest Manifest
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "empty pack id",
|
||||
manifest: Manifest{
|
||||
Version: "1.0.0",
|
||||
TargetHost: "sub2api",
|
||||
ProvidersDir: "providers",
|
||||
ChecksumFile: "checksums.txt",
|
||||
},
|
||||
wantErr: "pack_id is required",
|
||||
},
|
||||
{
|
||||
name: "empty version",
|
||||
manifest: Manifest{
|
||||
PackID: "openai-cn-pack",
|
||||
TargetHost: "sub2api",
|
||||
ProvidersDir: "providers",
|
||||
ChecksumFile: "checksums.txt",
|
||||
},
|
||||
wantErr: "version is required",
|
||||
},
|
||||
{
|
||||
name: "empty target host",
|
||||
manifest: Manifest{
|
||||
PackID: "openai-cn-pack",
|
||||
Version: "1.0.0",
|
||||
ProvidersDir: "providers",
|
||||
ChecksumFile: "checksums.txt",
|
||||
},
|
||||
wantErr: "target_host is required",
|
||||
},
|
||||
{
|
||||
name: "empty providers dir",
|
||||
manifest: Manifest{
|
||||
PackID: "openai-cn-pack",
|
||||
Version: "1.0.0",
|
||||
TargetHost: "sub2api",
|
||||
ChecksumFile: "checksums.txt",
|
||||
},
|
||||
wantErr: "providers_dir is required",
|
||||
},
|
||||
{
|
||||
name: "empty checksum file",
|
||||
manifest: Manifest{
|
||||
PackID: "openai-cn-pack",
|
||||
Version: "1.0.0",
|
||||
TargetHost: "sub2api",
|
||||
ProvidersDir: "providers",
|
||||
},
|
||||
wantErr: "checksum_file is required",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateManifest(tt.manifest)
|
||||
if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
|
||||
t.Fatalf("validateManifest() error = %v, want substring %q", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateProvidersRejectsInvalidProviderFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
providers []ProviderManifest
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "empty provider id",
|
||||
providers: []ProviderManifest{{
|
||||
DisplayName: "DeepSeek",
|
||||
BaseURL: "https://api.deepseek.com",
|
||||
Platform: "openai",
|
||||
AccountType: "api",
|
||||
DefaultModels: []string{"deepseek-chat"},
|
||||
SmokeTestModel: "deepseek-chat",
|
||||
GroupTemplate: GroupTemplate{Name: "g"},
|
||||
ChannelTemplate: ChannelTemplate{
|
||||
Name: "c",
|
||||
ModelMapping: map[string]string{"deepseek-chat": "deepseek-chat"},
|
||||
},
|
||||
PlanTemplate: PlanTemplate{Name: "p", ValidityDays: 30},
|
||||
}},
|
||||
wantErr: "provider_id is required",
|
||||
},
|
||||
{
|
||||
name: "empty base url",
|
||||
providers: []ProviderManifest{{
|
||||
ProviderID: "deepseek",
|
||||
DisplayName: "DeepSeek",
|
||||
BaseURL: "",
|
||||
Platform: "openai",
|
||||
AccountType: "api",
|
||||
DefaultModels: []string{"deepseek-chat"},
|
||||
SmokeTestModel: "deepseek-chat",
|
||||
GroupTemplate: GroupTemplate{Name: "g"},
|
||||
ChannelTemplate: ChannelTemplate{
|
||||
Name: "c",
|
||||
ModelMapping: map[string]string{"deepseek-chat": "deepseek-chat"},
|
||||
},
|
||||
PlanTemplate: PlanTemplate{Name: "p", ValidityDays: 30},
|
||||
}},
|
||||
wantErr: "base_url must use https",
|
||||
},
|
||||
{
|
||||
name: "missing display name",
|
||||
providers: []ProviderManifest{{
|
||||
ProviderID: "deepseek",
|
||||
DisplayName: "",
|
||||
BaseURL: "https://api.deepseek.com",
|
||||
Platform: "openai",
|
||||
AccountType: "api",
|
||||
DefaultModels: []string{"deepseek-chat"},
|
||||
SmokeTestModel: "deepseek-chat",
|
||||
GroupTemplate: GroupTemplate{Name: "g"},
|
||||
ChannelTemplate: ChannelTemplate{
|
||||
Name: "c",
|
||||
ModelMapping: map[string]string{"deepseek-chat": "deepseek-chat"},
|
||||
},
|
||||
PlanTemplate: PlanTemplate{Name: "p", ValidityDays: 30},
|
||||
}},
|
||||
wantErr: "display_name is required",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateProviders(tt.providers)
|
||||
if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
|
||||
t.Fatalf("validateProviders() error = %v, want substring %q", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractZipToTempRejectsEmptyArchive(t *testing.T) {
|
||||
archivePath := filepath.Join(t.TempDir(), "empty.zip")
|
||||
file, err := os.Create(archivePath)
|
||||
if err != nil {
|
||||
t.Fatalf("os.Create() error = %v", err)
|
||||
}
|
||||
writer := zip.NewWriter(file)
|
||||
if err := writer.Close(); err != nil {
|
||||
t.Fatalf("writer.Close() error = %v", err)
|
||||
}
|
||||
if err := file.Close(); err != nil {
|
||||
t.Fatalf("file.Close() error = %v", err)
|
||||
}
|
||||
|
||||
_, cleanup, err := extractZipToTemp(archivePath)
|
||||
if cleanup != nil {
|
||||
cleanup()
|
||||
}
|
||||
if err == nil || !strings.Contains(err.Error(), "pack archive is empty") {
|
||||
t.Fatalf("extractZipToTemp() error = %v, want empty archive error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractZipToTempRejectsPathTraversal(t *testing.T) {
|
||||
archivePath := filepath.Join(t.TempDir(), "traversal.zip")
|
||||
file, err := os.Create(archivePath)
|
||||
if err != nil {
|
||||
t.Fatalf("os.Create() error = %v", err)
|
||||
}
|
||||
writer := zip.NewWriter(file)
|
||||
entry, err := writer.Create("../../../evil.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("writer.Create() error = %v", err)
|
||||
}
|
||||
if _, err := entry.Write([]byte("evil")); err != nil {
|
||||
t.Fatalf("entry.Write() error = %v", err)
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
t.Fatalf("writer.Close() error = %v", err)
|
||||
}
|
||||
if err := file.Close(); err != nil {
|
||||
t.Fatalf("file.Close() error = %v", err)
|
||||
}
|
||||
|
||||
_, cleanup, err := extractZipToTemp(archivePath)
|
||||
if cleanup != nil {
|
||||
cleanup()
|
||||
}
|
||||
if err == nil || !strings.Contains(err.Error(), "invalid path") {
|
||||
t.Fatalf("extractZipToTemp() error = %v, want invalid path error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchesMaxConstraintCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hostVersion string
|
||||
maxVersion string
|
||||
want bool
|
||||
}{
|
||||
{name: "exact version match", hostVersion: "1.2.3", maxVersion: "1.2.3", want: true},
|
||||
{name: "wildcard x accepts same minor", hostVersion: "0.2.9", maxVersion: "0.2.x", want: true},
|
||||
{name: "non matching version", hostVersion: "1.2.4", maxVersion: "1.2.3", want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := matchesMaxConstraint(tt.hostVersion, tt.maxVersion)
|
||||
if err != nil {
|
||||
t.Fatalf("matchesMaxConstraint() error = %v", err)
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Fatalf("matchesMaxConstraint() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchesMaxConstraintRejectsWildcardStar(t *testing.T) {
|
||||
_, err := matchesMaxConstraint("1.2.3", "1.2.*")
|
||||
if err == nil || !strings.Contains(err.Error(), `parse version segment "*"`) {
|
||||
t.Fatalf("matchesMaxConstraint() error = %v, want parse failure for wildcard star", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadPathRejectsEmptyAndMissingPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
wantErr string
|
||||
}{
|
||||
{name: "empty path", path: " ", wantErr: "pack path is required"},
|
||||
{name: "missing path", path: filepath.Join(t.TempDir(), "missing-pack"), wantErr: "stat pack path"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := LoadPath(tt.path)
|
||||
if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
|
||||
t.Fatalf("LoadPath() error = %v, want substring %q", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadArchiveRejectsNonZipFile(t *testing.T) {
|
||||
filePath := filepath.Join(t.TempDir(), "not-a-zip.zip")
|
||||
mustWrite(t, filePath, "plain text, not a zip archive")
|
||||
|
||||
_, err := LoadArchive(filePath)
|
||||
if err == nil || !strings.Contains(err.Error(), "open pack archive") {
|
||||
t.Fatalf("LoadArchive() error = %v, want open archive error", err)
|
||||
}
|
||||
}
|
||||
249
internal/pack/loader.go
Normal file
249
internal/pack/loader.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package pack
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Manifest struct {
|
||||
PackID string `json:"pack_id"`
|
||||
Version string `json:"version"`
|
||||
Vendor string `json:"vendor"`
|
||||
TargetHost string `json:"target_host"`
|
||||
MinHostVersion string `json:"min_host_version"`
|
||||
MaxHostVersion string `json:"max_host_version"`
|
||||
ProvidersDir string `json:"providers_dir"`
|
||||
ChecksumFile string `json:"checksum_file"`
|
||||
}
|
||||
|
||||
type ProviderManifest struct {
|
||||
ProviderID string `json:"provider_id"`
|
||||
DisplayName string `json:"display_name"`
|
||||
BaseURL string `json:"base_url"`
|
||||
Platform string `json:"platform"`
|
||||
AccountType string `json:"account_type"`
|
||||
DefaultModels []string `json:"default_models"`
|
||||
SmokeTestModel string `json:"smoke_test_model"`
|
||||
GroupTemplate GroupTemplate `json:"group_template"`
|
||||
ChannelTemplate ChannelTemplate `json:"channel_template"`
|
||||
PlanTemplate PlanTemplate `json:"plan_template"`
|
||||
Import ImportOptions `json:"import"`
|
||||
}
|
||||
|
||||
type GroupTemplate struct {
|
||||
Name string `json:"name"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
}
|
||||
|
||||
type ChannelTemplate struct {
|
||||
Name string `json:"name"`
|
||||
ModelMapping map[string]string `json:"model_mapping"`
|
||||
}
|
||||
|
||||
type PlanTemplate struct {
|
||||
Name string `json:"name"`
|
||||
Price float64 `json:"price"`
|
||||
ValidityDays int `json:"validity_days"`
|
||||
ValidityUnit string `json:"validity_unit"`
|
||||
}
|
||||
|
||||
type ImportOptions struct {
|
||||
SupportsMultiKey bool `json:"supports_multi_key"`
|
||||
SupportsStrict bool `json:"supports_strict"`
|
||||
SupportsPartial bool `json:"supports_partial"`
|
||||
}
|
||||
|
||||
type LoadedPack struct {
|
||||
Dir string
|
||||
Manifest Manifest
|
||||
Providers []ProviderManifest
|
||||
Checksum string
|
||||
}
|
||||
|
||||
func LoadDir(dir string) (LoadedPack, error) {
|
||||
root := strings.TrimSpace(dir)
|
||||
if root == "" {
|
||||
return LoadedPack{}, fmt.Errorf("pack dir is required")
|
||||
}
|
||||
|
||||
manifestPath := filepath.Join(root, "pack.json")
|
||||
manifestBytes, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
return LoadedPack{}, fmt.Errorf("read pack.json: %w", err)
|
||||
}
|
||||
|
||||
var manifest Manifest
|
||||
if err := json.Unmarshal(manifestBytes, &manifest); err != nil {
|
||||
return LoadedPack{}, fmt.Errorf("decode pack.json: %w", err)
|
||||
}
|
||||
if err := validateManifest(manifest); err != nil {
|
||||
return LoadedPack{}, err
|
||||
}
|
||||
|
||||
if err := validateChecksums(root, manifest.ChecksumFile); err != nil {
|
||||
return LoadedPack{}, err
|
||||
}
|
||||
|
||||
providers, err := loadProviders(root, manifest.ProvidersDir)
|
||||
if err != nil {
|
||||
return LoadedPack{}, err
|
||||
}
|
||||
if len(providers) == 0 {
|
||||
return LoadedPack{}, fmt.Errorf("providers dir %q does not contain provider manifests", manifest.ProvidersDir)
|
||||
}
|
||||
if err := validateProviders(providers); err != nil {
|
||||
return LoadedPack{}, err
|
||||
}
|
||||
|
||||
checksum, err := computeAggregateChecksum(root, manifest.ChecksumFile)
|
||||
if err != nil {
|
||||
return LoadedPack{}, err
|
||||
}
|
||||
|
||||
return LoadedPack{Dir: root, Manifest: manifest, Providers: providers, Checksum: checksum}, nil
|
||||
}
|
||||
|
||||
func validateManifest(manifest Manifest) error {
|
||||
switch {
|
||||
case strings.TrimSpace(manifest.PackID) == "":
|
||||
return fmt.Errorf("pack.json: pack_id is required")
|
||||
case strings.TrimSpace(manifest.Version) == "":
|
||||
return fmt.Errorf("pack.json: version is required")
|
||||
case strings.TrimSpace(manifest.TargetHost) == "":
|
||||
return fmt.Errorf("pack.json: target_host is required")
|
||||
case strings.TrimSpace(manifest.ProvidersDir) == "":
|
||||
return fmt.Errorf("pack.json: providers_dir is required")
|
||||
case strings.TrimSpace(manifest.ChecksumFile) == "":
|
||||
return fmt.Errorf("pack.json: checksum_file is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadProviders(root string, providersDir string) ([]ProviderManifest, error) {
|
||||
dir := filepath.Join(root, providersDir)
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read providers dir %q: %w", providersDir, err)
|
||||
}
|
||||
|
||||
providers := make([]ProviderManifest, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") {
|
||||
continue
|
||||
}
|
||||
path := filepath.Join(dir, entry.Name())
|
||||
body, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read provider %q: %w", entry.Name(), err)
|
||||
}
|
||||
var provider ProviderManifest
|
||||
if err := json.Unmarshal(body, &provider); err != nil {
|
||||
return nil, fmt.Errorf("decode provider %q: %w", entry.Name(), err)
|
||||
}
|
||||
providers = append(providers, provider)
|
||||
}
|
||||
sort.Slice(providers, func(i, j int) bool { return providers[i].ProviderID < providers[j].ProviderID })
|
||||
return providers, nil
|
||||
}
|
||||
|
||||
func validateProviders(providers []ProviderManifest) error {
|
||||
seen := make(map[string]struct{}, len(providers))
|
||||
for _, provider := range providers {
|
||||
providerID := strings.TrimSpace(provider.ProviderID)
|
||||
switch {
|
||||
case providerID == "":
|
||||
return fmt.Errorf("provider manifest: provider_id is required")
|
||||
case strings.TrimSpace(provider.DisplayName) == "":
|
||||
return fmt.Errorf("provider %q: display_name is required", providerID)
|
||||
case !strings.HasPrefix(strings.TrimSpace(provider.BaseURL), "https://"):
|
||||
return fmt.Errorf("provider %q: base_url must use https", providerID)
|
||||
case strings.TrimSpace(provider.Platform) == "":
|
||||
return fmt.Errorf("provider %q: platform is required", providerID)
|
||||
case strings.TrimSpace(provider.AccountType) == "":
|
||||
return fmt.Errorf("provider %q: account_type is required", providerID)
|
||||
case len(provider.DefaultModels) == 0:
|
||||
return fmt.Errorf("provider %q: default_models must not be empty", providerID)
|
||||
case strings.TrimSpace(provider.SmokeTestModel) == "":
|
||||
return fmt.Errorf("provider %q: smoke_test_model is required", providerID)
|
||||
case !contains(provider.DefaultModels, provider.SmokeTestModel):
|
||||
return fmt.Errorf("provider %q: smoke_test_model must be present in default_models", providerID)
|
||||
case strings.TrimSpace(provider.GroupTemplate.Name) == "":
|
||||
return fmt.Errorf("provider %q: group_template.name is required", providerID)
|
||||
case strings.TrimSpace(provider.ChannelTemplate.Name) == "":
|
||||
return fmt.Errorf("provider %q: channel_template.name is required", providerID)
|
||||
case len(provider.ChannelTemplate.ModelMapping) == 0:
|
||||
return fmt.Errorf("provider %q: channel_template.model_mapping must not be empty", providerID)
|
||||
case strings.TrimSpace(provider.PlanTemplate.Name) == "":
|
||||
return fmt.Errorf("provider %q: plan_template.name is required", providerID)
|
||||
case provider.PlanTemplate.ValidityDays <= 0:
|
||||
return fmt.Errorf("provider %q: plan_template.validity_days must be positive", providerID)
|
||||
}
|
||||
if _, ok := seen[providerID]; ok {
|
||||
return fmt.Errorf("duplicate provider_id %q", providerID)
|
||||
}
|
||||
seen[providerID] = struct{}{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateChecksums(root string, checksumFile string) error {
|
||||
path := filepath.Join(root, checksumFile)
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read checksum file %q: %w", checksumFile, err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
lineNumber := 0
|
||||
for scanner.Scan() {
|
||||
lineNumber++
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
parts := strings.Fields(line)
|
||||
if len(parts) != 2 {
|
||||
return fmt.Errorf("checksum file %q line %d: invalid format", checksumFile, lineNumber)
|
||||
}
|
||||
relativePath := parts[1]
|
||||
body, err := os.ReadFile(filepath.Join(root, relativePath))
|
||||
if err != nil {
|
||||
return fmt.Errorf("checksum file %q line %d: read %q: %w", checksumFile, lineNumber, relativePath, err)
|
||||
}
|
||||
sum := sha256.Sum256(body)
|
||||
actual := hex.EncodeToString(sum[:])
|
||||
if !strings.EqualFold(parts[0], actual) {
|
||||
return fmt.Errorf("checksum mismatch for %s", relativePath)
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return fmt.Errorf("scan checksum file %q: %w", checksumFile, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func computeAggregateChecksum(root string, checksumFile string) (string, error) {
|
||||
body, err := os.ReadFile(filepath.Join(root, checksumFile))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read checksum file %q: %w", checksumFile, err)
|
||||
}
|
||||
sum := sha256.Sum256(body)
|
||||
return hex.EncodeToString(sum[:]), nil
|
||||
}
|
||||
|
||||
func contains(items []string, target string) bool {
|
||||
for _, item := range items {
|
||||
if strings.TrimSpace(item) == strings.TrimSpace(target) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
108
internal/pack/loader_test.go
Normal file
108
internal/pack/loader_test.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package pack
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadDirParsesAndValidatesPack(t *testing.T) {
|
||||
packDir := createPackFixture(t, map[string]string{
|
||||
"pack.json": `{
|
||||
"pack_id": "openai-cn-pack",
|
||||
"version": "1.0.0",
|
||||
"vendor": "YourTeam",
|
||||
"target_host": "sub2api",
|
||||
"min_host_version": "0.1.126",
|
||||
"max_host_version": "0.2.x",
|
||||
"providers_dir": "providers",
|
||||
"checksum_file": "checksums.txt"
|
||||
}`,
|
||||
"providers/deepseek.json": `{
|
||||
"provider_id": "deepseek",
|
||||
"display_name": "DeepSeek OpenAI Compatible",
|
||||
"base_url": "https://api.deepseek.com",
|
||||
"platform": "openai",
|
||||
"account_type": "api",
|
||||
"default_models": ["deepseek-chat", "deepseek-reasoner"],
|
||||
"smoke_test_model": "deepseek-chat",
|
||||
"group_template": {"name": "DeepSeek 默认分组", "rate_multiplier": 1.0},
|
||||
"channel_template": {"name": "DeepSeek 默认渠道", "model_mapping": {"deepseek-chat": "deepseek-chat"}},
|
||||
"plan_template": {"name": "DeepSeek 默认套餐", "price": 19.9, "validity_days": 30, "validity_unit": "day"},
|
||||
"import": {"supports_multi_key": true, "supports_strict": true, "supports_partial": true}
|
||||
}`,
|
||||
})
|
||||
|
||||
loaded, err := LoadDir(packDir)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadDir() error = %v", err)
|
||||
}
|
||||
|
||||
if loaded.Manifest.PackID != "openai-cn-pack" {
|
||||
t.Fatalf("PackID = %q, want %q", loaded.Manifest.PackID, "openai-cn-pack")
|
||||
}
|
||||
if len(loaded.Providers) != 1 {
|
||||
t.Fatalf("len(Providers) = %d, want 1", len(loaded.Providers))
|
||||
}
|
||||
if loaded.Providers[0].ProviderID != "deepseek" {
|
||||
t.Fatalf("ProviderID = %q, want %q", loaded.Providers[0].ProviderID, "deepseek")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadDirRejectsChecksumMismatch(t *testing.T) {
|
||||
packDir := t.TempDir()
|
||||
mustWrite(t, filepath.Join(packDir, "pack.json"), `{"pack_id":"openai-cn-pack","version":"1.0.0","vendor":"x","target_host":"sub2api","min_host_version":"0.1.126","max_host_version":"0.2.x","providers_dir":"providers","checksum_file":"checksums.txt"}`)
|
||||
mustWrite(t, filepath.Join(packDir, "providers", "deepseek.json"), `{"provider_id":"deepseek","display_name":"DeepSeek","base_url":"https://api.deepseek.com","platform":"openai","account_type":"api","default_models":["deepseek-chat"],"smoke_test_model":"deepseek-chat","group_template":{"name":"g","rate_multiplier":1},"channel_template":{"name":"c","model_mapping":{"deepseek-chat":"deepseek-chat"}},"plan_template":{"name":"p","price":1,"validity_days":30,"validity_unit":"day"},"import":{"supports_multi_key":true,"supports_strict":true,"supports_partial":true}}`)
|
||||
mustWrite(t, filepath.Join(packDir, "checksums.txt"), "deadbeef pack.json\ndeadbeef providers/deepseek.json\n")
|
||||
|
||||
_, err := LoadDir(packDir)
|
||||
if err == nil {
|
||||
t.Fatal("LoadDir() error = nil, want checksum mismatch")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "checksum mismatch") {
|
||||
t.Fatalf("LoadDir() error = %v, want checksum mismatch", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadDirRejectsInvalidProviderSchema(t *testing.T) {
|
||||
packDir := createPackFixture(t, map[string]string{
|
||||
"pack.json": `{"pack_id":"openai-cn-pack","version":"1.0.0","vendor":"x","target_host":"sub2api","min_host_version":"0.1.126","max_host_version":"0.2.x","providers_dir":"providers","checksum_file":"checksums.txt"}`,
|
||||
"providers/deepseek.json": `{"provider_id":"deepseek","display_name":"DeepSeek","base_url":"http://insecure.example.com","platform":"openai","account_type":"api","default_models":["deepseek-chat"],"smoke_test_model":"missing-model","group_template":{"name":"g","rate_multiplier":1},"channel_template":{"name":"c","model_mapping":{"deepseek-chat":"deepseek-chat"}},"plan_template":{"name":"p","price":1,"validity_days":30,"validity_unit":"day"},"import":{"supports_multi_key":true,"supports_strict":true,"supports_partial":true}}`,
|
||||
})
|
||||
|
||||
_, err := LoadDir(packDir)
|
||||
if err == nil {
|
||||
t.Fatal("LoadDir() error = nil, want schema validation failure")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "https") && !strings.Contains(err.Error(), "smoke_test_model") {
|
||||
t.Fatalf("LoadDir() error = %v, want schema validation detail", err)
|
||||
}
|
||||
}
|
||||
|
||||
func createPackFixture(t *testing.T, files map[string]string) string {
|
||||
t.Helper()
|
||||
|
||||
packDir := t.TempDir()
|
||||
var lines []string
|
||||
for relativePath, content := range files {
|
||||
absolutePath := filepath.Join(packDir, relativePath)
|
||||
mustWrite(t, absolutePath, content)
|
||||
sum := sha256.Sum256([]byte(content))
|
||||
lines = append(lines, hex.EncodeToString(sum[:])+" "+relativePath)
|
||||
}
|
||||
mustWrite(t, filepath.Join(packDir, "checksums.txt"), strings.Join(lines, "\n")+"\n")
|
||||
return packDir
|
||||
}
|
||||
|
||||
func mustWrite(t *testing.T, path string, content string) {
|
||||
t.Helper()
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
t.Fatalf("MkdirAll(%q) error = %v", path, err)
|
||||
}
|
||||
if err := os.WriteFile(path, []byte(content), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(%q) error = %v", path, err)
|
||||
}
|
||||
}
|
||||
171
internal/pack/source_loader.go
Normal file
171
internal/pack/source_loader.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package pack
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
maxArchiveEntries = 256
|
||||
maxArchiveFileSize = 5 << 20
|
||||
maxArchiveTotalSize = 20 << 20
|
||||
)
|
||||
|
||||
func LoadPath(path string) (LoadedPack, error) {
|
||||
trimmed := strings.TrimSpace(path)
|
||||
if trimmed == "" {
|
||||
return LoadedPack{}, fmt.Errorf("pack path is required")
|
||||
}
|
||||
|
||||
info, err := os.Stat(trimmed)
|
||||
if err != nil {
|
||||
return LoadedPack{}, fmt.Errorf("stat pack path: %w", err)
|
||||
}
|
||||
if info.IsDir() {
|
||||
return LoadDir(trimmed)
|
||||
}
|
||||
if strings.EqualFold(filepath.Ext(info.Name()), ".zip") {
|
||||
return LoadArchive(trimmed)
|
||||
}
|
||||
return LoadedPack{}, fmt.Errorf("pack path %q must be a directory or .zip archive", trimmed)
|
||||
}
|
||||
|
||||
func LoadArchive(path string) (LoadedPack, error) {
|
||||
root, cleanup, err := extractZipToTemp(path)
|
||||
if err != nil {
|
||||
return LoadedPack{}, err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
loaded, err := LoadDir(root)
|
||||
if err != nil {
|
||||
return LoadedPack{}, err
|
||||
}
|
||||
loaded.Dir = strings.TrimSpace(path)
|
||||
return loaded, nil
|
||||
}
|
||||
|
||||
func extractZipToTemp(path string) (string, func(), error) {
|
||||
reader, err := zip.OpenReader(strings.TrimSpace(path))
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("open pack archive: %w", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
if len(reader.File) == 0 {
|
||||
return "", nil, fmt.Errorf("pack archive is empty")
|
||||
}
|
||||
if len(reader.File) > maxArchiveEntries {
|
||||
return "", nil, fmt.Errorf("pack archive has too many entries: %d", len(reader.File))
|
||||
}
|
||||
|
||||
tempDir, err := os.MkdirTemp("", "relay-pack-*")
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("create temp dir for pack archive: %w", err)
|
||||
}
|
||||
cleanup := func() { _ = os.RemoveAll(tempDir) }
|
||||
|
||||
var totalSize uint64
|
||||
for _, file := range reader.File {
|
||||
cleanName := filepath.Clean(file.Name)
|
||||
if cleanName == "." || cleanName == "" {
|
||||
continue
|
||||
}
|
||||
if filepath.IsAbs(cleanName) || cleanName == ".." || strings.HasPrefix(cleanName, ".."+string(filepath.Separator)) {
|
||||
cleanup()
|
||||
return "", nil, fmt.Errorf("pack archive contains invalid path %q", file.Name)
|
||||
}
|
||||
if file.FileInfo().Mode()&os.ModeSymlink != 0 {
|
||||
cleanup()
|
||||
return "", nil, fmt.Errorf("pack archive contains unsupported symlink entry %q", file.Name)
|
||||
}
|
||||
if file.UncompressedSize64 > maxArchiveFileSize {
|
||||
cleanup()
|
||||
return "", nil, fmt.Errorf("pack archive entry %q exceeds size limit", file.Name)
|
||||
}
|
||||
totalSize += file.UncompressedSize64
|
||||
if totalSize > maxArchiveTotalSize {
|
||||
cleanup()
|
||||
return "", nil, fmt.Errorf("pack archive exceeds total size limit")
|
||||
}
|
||||
|
||||
targetPath := filepath.Join(tempDir, cleanName)
|
||||
relativeTarget, err := filepath.Rel(tempDir, targetPath)
|
||||
if err != nil || relativeTarget == ".." || strings.HasPrefix(relativeTarget, ".."+string(filepath.Separator)) {
|
||||
cleanup()
|
||||
return "", nil, fmt.Errorf("pack archive entry %q escapes extraction root", file.Name)
|
||||
}
|
||||
|
||||
if file.FileInfo().IsDir() {
|
||||
if err := os.MkdirAll(targetPath, 0o755); err != nil {
|
||||
cleanup()
|
||||
return "", nil, fmt.Errorf("create archive dir %q: %w", file.Name, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil {
|
||||
cleanup()
|
||||
return "", nil, fmt.Errorf("create archive parent dir %q: %w", file.Name, err)
|
||||
}
|
||||
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
cleanup()
|
||||
return "", nil, fmt.Errorf("open archive entry %q: %w", file.Name, err)
|
||||
}
|
||||
dst, err := os.OpenFile(targetPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644)
|
||||
if err != nil {
|
||||
src.Close()
|
||||
cleanup()
|
||||
return "", nil, fmt.Errorf("create archive file %q: %w", file.Name, err)
|
||||
}
|
||||
_, copyErr := io.Copy(dst, src)
|
||||
closeErr := dst.Close()
|
||||
srcErr := src.Close()
|
||||
if copyErr != nil {
|
||||
cleanup()
|
||||
return "", nil, fmt.Errorf("extract archive entry %q: %w", file.Name, copyErr)
|
||||
}
|
||||
if closeErr != nil {
|
||||
cleanup()
|
||||
return "", nil, fmt.Errorf("close archive file %q: %w", file.Name, closeErr)
|
||||
}
|
||||
if srcErr != nil {
|
||||
cleanup()
|
||||
return "", nil, fmt.Errorf("close archive entry %q: %w", file.Name, srcErr)
|
||||
}
|
||||
}
|
||||
|
||||
root, err := resolvePackRoot(tempDir)
|
||||
if err != nil {
|
||||
cleanup()
|
||||
return "", nil, err
|
||||
}
|
||||
return root, cleanup, nil
|
||||
}
|
||||
|
||||
func resolvePackRoot(extractDir string) (string, error) {
|
||||
manifestPath := filepath.Join(extractDir, "pack.json")
|
||||
if _, err := os.Stat(manifestPath); err == nil {
|
||||
return extractDir, nil
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(extractDir)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read extracted archive root: %w", err)
|
||||
}
|
||||
if len(entries) != 1 || !entries[0].IsDir() {
|
||||
return "", fmt.Errorf("pack archive must contain pack.json at root or a single top-level directory")
|
||||
}
|
||||
|
||||
root := filepath.Join(extractDir, entries[0].Name())
|
||||
if _, err := os.Stat(filepath.Join(root, "pack.json")); err != nil {
|
||||
return "", fmt.Errorf("pack archive root does not contain pack.json")
|
||||
}
|
||||
return root, nil
|
||||
}
|
||||
70
internal/pack/source_loader_test.go
Normal file
70
internal/pack/source_loader_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package pack
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadPathSupportsDirectory(t *testing.T) {
|
||||
loaded, err := LoadPath(filepath.Join("..", "..", "packs", "openai-cn-pack"))
|
||||
if err != nil {
|
||||
t.Fatalf("LoadPath(dir) error = %v", err)
|
||||
}
|
||||
if loaded.Manifest.PackID != "openai-cn-pack" {
|
||||
t.Fatalf("PackID = %q, want %q", loaded.Manifest.PackID, "openai-cn-pack")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadPathSupportsZipArchive(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
archivePath := filepath.Join(tempDir, "openai-cn-pack.zip")
|
||||
writePackArchive(t, archivePath)
|
||||
|
||||
loaded, err := LoadPath(archivePath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadPath(zip) error = %v", err)
|
||||
}
|
||||
if loaded.Manifest.PackID != "openai-cn-pack" {
|
||||
t.Fatalf("PackID = %q, want %q", loaded.Manifest.PackID, "openai-cn-pack")
|
||||
}
|
||||
if len(loaded.Providers) == 0 {
|
||||
t.Fatal("Providers = 0, want parsed providers from archive")
|
||||
}
|
||||
}
|
||||
|
||||
func writePackArchive(t *testing.T, archivePath string) {
|
||||
t.Helper()
|
||||
file, err := os.Create(archivePath)
|
||||
if err != nil {
|
||||
t.Fatalf("os.Create() error = %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
writer := zip.NewWriter(file)
|
||||
defer writer.Close()
|
||||
|
||||
sourceRoot := filepath.Join("..", "..", "packs", "openai-cn-pack")
|
||||
files := []string{
|
||||
"pack.json",
|
||||
"checksums.txt",
|
||||
filepath.Join("providers", "deepseek.json"),
|
||||
}
|
||||
for _, relativePath := range files {
|
||||
body, err := os.ReadFile(filepath.Join(sourceRoot, relativePath))
|
||||
if err != nil {
|
||||
t.Fatalf("os.ReadFile(%q) error = %v", relativePath, err)
|
||||
}
|
||||
entry, err := writer.Create(filepath.ToSlash(filepath.Join("openai-cn-pack", relativePath)))
|
||||
if err != nil {
|
||||
t.Fatalf("Create(%q) error = %v", relativePath, err)
|
||||
}
|
||||
if _, err := entry.Write(body); err != nil {
|
||||
t.Fatalf("Write(%q) error = %v", relativePath, err)
|
||||
}
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
t.Fatalf("Close archive writer: %v", err)
|
||||
}
|
||||
}
|
||||
134
internal/pack/version.go
Normal file
134
internal/pack/version.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package pack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func CheckHostCompatibility(manifest Manifest, hostVersion string) error {
|
||||
targetHost := strings.TrimSpace(manifest.TargetHost)
|
||||
if targetHost == "" {
|
||||
return fmt.Errorf("pack manifest target_host is required")
|
||||
}
|
||||
if targetHost != "sub2api" {
|
||||
return fmt.Errorf("pack target_host %q is not supported", targetHost)
|
||||
}
|
||||
|
||||
normalizedHost, err := parseVersion(strings.TrimSpace(hostVersion))
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse host version %q: %w", hostVersion, err)
|
||||
}
|
||||
minVersion := strings.TrimSpace(manifest.MinHostVersion)
|
||||
if minVersion != "" {
|
||||
cmp, err := compareVersions(normalizedHost.raw, minVersion)
|
||||
if err != nil {
|
||||
return fmt.Errorf("compare min_host_version: %w", err)
|
||||
}
|
||||
if cmp < 0 {
|
||||
return fmt.Errorf("host version %q is below min_host_version %q", hostVersion, minVersion)
|
||||
}
|
||||
}
|
||||
|
||||
maxVersion := strings.TrimSpace(manifest.MaxHostVersion)
|
||||
if maxVersion != "" {
|
||||
ok, err := matchesMaxConstraint(normalizedHost.raw, maxVersion)
|
||||
if err != nil {
|
||||
return fmt.Errorf("compare max_host_version: %w", err)
|
||||
}
|
||||
if !ok {
|
||||
return fmt.Errorf("host version %q is above max_host_version %q", hostVersion, maxVersion)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type parsedVersion struct {
|
||||
raw string
|
||||
parts [3]int
|
||||
}
|
||||
|
||||
func compareVersions(a, b string) (int, error) {
|
||||
left, err := parseVersion(a)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
right, err := parseVersion(b)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for i := 0; i < len(left.parts); i++ {
|
||||
if left.parts[i] < right.parts[i] {
|
||||
return -1, nil
|
||||
}
|
||||
if left.parts[i] > right.parts[i] {
|
||||
return 1, nil
|
||||
}
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func matchesMaxConstraint(hostVersion, maxVersion string) (bool, error) {
|
||||
normalizedMax := normalizeVersion(maxVersion)
|
||||
if strings.HasSuffix(normalizedMax, ".x") {
|
||||
prefix := strings.TrimSuffix(normalizedMax, ".x")
|
||||
parts := strings.Split(prefix, ".")
|
||||
if len(parts) != 2 {
|
||||
return false, fmt.Errorf("wildcard max version %q must be in N.N.x format", maxVersion)
|
||||
}
|
||||
host, err := parseVersion(hostVersion)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
major, err := strconv.Atoi(parts[0])
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("parse major version %q: %w", parts[0], err)
|
||||
}
|
||||
minor, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("parse minor version %q: %w", parts[1], err)
|
||||
}
|
||||
if host.parts[0] < major {
|
||||
return true, nil
|
||||
}
|
||||
if host.parts[0] > major {
|
||||
return false, nil
|
||||
}
|
||||
return host.parts[1] <= minor, nil
|
||||
}
|
||||
|
||||
cmp, err := compareVersions(hostVersion, maxVersion)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cmp <= 0, nil
|
||||
}
|
||||
|
||||
func parseVersion(value string) (parsedVersion, error) {
|
||||
normalized := normalizeVersion(value)
|
||||
if normalized == "" {
|
||||
return parsedVersion{}, fmt.Errorf("version is required")
|
||||
}
|
||||
parts := strings.Split(normalized, ".")
|
||||
if len(parts) != 3 {
|
||||
return parsedVersion{}, fmt.Errorf("version %q must be in N.N.N format", value)
|
||||
}
|
||||
|
||||
var parsed parsedVersion
|
||||
parsed.raw = normalized
|
||||
for i, part := range parts {
|
||||
number, err := strconv.Atoi(part)
|
||||
if err != nil {
|
||||
return parsedVersion{}, fmt.Errorf("parse version segment %q: %w", part, err)
|
||||
}
|
||||
parsed.parts[i] = number
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
func normalizeVersion(value string) string {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
trimmed = strings.TrimPrefix(trimmed, "v")
|
||||
trimmed = strings.TrimPrefix(trimmed, "V")
|
||||
return trimmed
|
||||
}
|
||||
32
internal/pack/version_test.go
Normal file
32
internal/pack/version_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package pack
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestCheckHostCompatibilityAcceptsRange(t *testing.T) {
|
||||
manifest := Manifest{
|
||||
PackID: "openai-cn-pack",
|
||||
TargetHost: "sub2api",
|
||||
MinHostVersion: "0.1.126",
|
||||
MaxHostVersion: "0.2.x",
|
||||
}
|
||||
if err := CheckHostCompatibility(manifest, "0.1.126"); err != nil {
|
||||
t.Fatalf("CheckHostCompatibility() error = %v, want nil", err)
|
||||
}
|
||||
if err := CheckHostCompatibility(manifest, "0.2.9"); err != nil {
|
||||
t.Fatalf("CheckHostCompatibility() error = %v, want nil for wildcard upper bound", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckHostCompatibilityRejectsBelowMinimum(t *testing.T) {
|
||||
manifest := Manifest{TargetHost: "sub2api", MinHostVersion: "0.1.126"}
|
||||
if err := CheckHostCompatibility(manifest, "0.1.125"); err == nil {
|
||||
t.Fatal("CheckHostCompatibility() error = nil, want min version failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckHostCompatibilityRejectsDifferentMaxMinor(t *testing.T) {
|
||||
manifest := Manifest{TargetHost: "sub2api", MaxHostVersion: "0.2.x"}
|
||||
if err := CheckHostCompatibility(manifest, "0.3.0"); err == nil {
|
||||
t.Fatal("CheckHostCompatibility() error = nil, want max version failure")
|
||||
}
|
||||
}
|
||||
321
internal/provision/batch_detail_and_reconcile_service.go
Normal file
321
internal/provision/batch_detail_and_reconcile_service.go
Normal file
@@ -0,0 +1,321 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
"sub2api-cn-relay-manager/internal/pack"
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
type BatchDetailResult struct {
|
||||
Batch sqlite.ImportBatch
|
||||
Items []sqlite.ImportBatchItem
|
||||
ManagedResources []sqlite.ManagedResource
|
||||
AccessClosures []sqlite.AccessClosureRecord
|
||||
ReconcileRuns []sqlite.ReconcileRun
|
||||
}
|
||||
|
||||
type BatchDetailService struct {
|
||||
store *sqlite.DB
|
||||
}
|
||||
|
||||
func NewBatchDetailService(store *sqlite.DB) *BatchDetailService {
|
||||
return &BatchDetailService{store: store}
|
||||
}
|
||||
|
||||
func (s *BatchDetailService) Get(ctx context.Context, batchID int64) (BatchDetailResult, error) {
|
||||
if s == nil || s.store == nil {
|
||||
return BatchDetailResult{}, fmt.Errorf("store is required")
|
||||
}
|
||||
batch, err := s.store.ImportBatches().GetByID(ctx, batchID)
|
||||
if err != nil {
|
||||
return BatchDetailResult{}, err
|
||||
}
|
||||
items, err := s.store.ImportBatchItems().GetByBatchID(ctx, batchID)
|
||||
if err != nil {
|
||||
return BatchDetailResult{}, err
|
||||
}
|
||||
managedResources, err := s.store.ManagedResources().GetByBatchID(ctx, batchID)
|
||||
if err != nil {
|
||||
return BatchDetailResult{}, err
|
||||
}
|
||||
accessClosures, err := s.store.AccessClosures().GetByBatchID(ctx, batchID)
|
||||
if err != nil {
|
||||
return BatchDetailResult{}, err
|
||||
}
|
||||
reconcileRuns, err := s.store.ReconcileRuns().GetByProviderID(ctx, batch.ProviderID)
|
||||
if err != nil {
|
||||
return BatchDetailResult{}, err
|
||||
}
|
||||
return BatchDetailResult{
|
||||
Batch: batch,
|
||||
Items: items,
|
||||
ManagedResources: managedResources,
|
||||
AccessClosures: accessClosures,
|
||||
ReconcileRuns: reconcileRuns,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type ReconcileRequest struct {
|
||||
HostBaseURL string
|
||||
AccessProbeAPIKey string
|
||||
Pack pack.LoadedPack
|
||||
Provider pack.ProviderManifest
|
||||
}
|
||||
|
||||
type ReconcileResult struct {
|
||||
BatchID int64
|
||||
Status string
|
||||
MissingCount int
|
||||
ExtraCount int
|
||||
ProbeFailureCount int
|
||||
AccessStatus string
|
||||
Summary map[string]any
|
||||
}
|
||||
|
||||
type ReconcileService struct {
|
||||
store *sqlite.DB
|
||||
host sub2api.HostAdapter
|
||||
}
|
||||
|
||||
func NewReconcileService(store *sqlite.DB, host sub2api.HostAdapter) *ReconcileService {
|
||||
return &ReconcileService{store: store, host: host}
|
||||
}
|
||||
|
||||
func (s *ReconcileService) Reconcile(ctx context.Context, req ReconcileRequest) (ReconcileResult, error) {
|
||||
if s == nil || s.store == nil {
|
||||
return ReconcileResult{}, fmt.Errorf("store is required")
|
||||
}
|
||||
if s.host == nil {
|
||||
return ReconcileResult{}, fmt.Errorf("host adapter is required")
|
||||
}
|
||||
if strings.TrimSpace(req.HostBaseURL) == "" {
|
||||
return ReconcileResult{}, fmt.Errorf("host_base_url is required")
|
||||
}
|
||||
hostVersion, err := s.host.GetHostVersion(ctx)
|
||||
if err != nil {
|
||||
return ReconcileResult{}, fmt.Errorf("get host version: %w", err)
|
||||
}
|
||||
if err := pack.CheckHostCompatibility(req.Pack.Manifest, hostVersion); err != nil {
|
||||
return ReconcileResult{}, err
|
||||
}
|
||||
packRow, err := s.store.Packs().GetByPackID(ctx, req.Pack.Manifest.PackID)
|
||||
if err != nil {
|
||||
return ReconcileResult{}, err
|
||||
}
|
||||
providerRow, err := s.store.Providers().GetByPackIDAndProviderID(ctx, packRow.ID, req.Provider.ProviderID)
|
||||
if err != nil {
|
||||
return ReconcileResult{}, err
|
||||
}
|
||||
batchRow, err := s.store.ImportBatches().GetLatestByProviderID(ctx, providerRow.ID)
|
||||
if err != nil {
|
||||
return ReconcileResult{}, err
|
||||
}
|
||||
storedResources, err := s.store.ManagedResources().GetByBatchID(ctx, batchRow.ID)
|
||||
if err != nil {
|
||||
return ReconcileResult{}, err
|
||||
}
|
||||
batchItems, err := s.store.ImportBatchItems().GetByBatchID(ctx, batchRow.ID)
|
||||
if err != nil {
|
||||
return ReconcileResult{}, err
|
||||
}
|
||||
accessClosures, err := s.store.AccessClosures().GetByBatchID(ctx, batchRow.ID)
|
||||
if err != nil {
|
||||
return ReconcileResult{}, err
|
||||
}
|
||||
snapshot, err := s.host.ListManagedResources(ctx, sub2api.ListManagedResourcesRequest{
|
||||
GroupName: SuggestResourceNames(req.Provider).Group,
|
||||
ChannelName: SuggestResourceNames(req.Provider).Channel,
|
||||
PlanName: SuggestResourceNames(req.Provider).Plan,
|
||||
})
|
||||
if err != nil {
|
||||
return ReconcileResult{}, fmt.Errorf("list managed resources: %w", err)
|
||||
}
|
||||
missing, extra := diffManagedResources(storedResources, snapshot)
|
||||
probeFailures, err := s.rerunAccountProbes(ctx, batchItems, req.Provider.SmokeTestModel)
|
||||
if err != nil {
|
||||
return ReconcileResult{}, err
|
||||
}
|
||||
accessStatus, accessChecked, err := s.rerunAccessClosure(ctx, batchRow.ID, accessClosures, req.AccessProbeAPIKey, req.Provider.SmokeTestModel)
|
||||
if err != nil {
|
||||
return ReconcileResult{}, err
|
||||
}
|
||||
status := "active"
|
||||
if missing > 0 || extra > 0 {
|
||||
status = "drifted"
|
||||
} else if probeFailures > 0 || (accessChecked && accessStatus == AccessStatusBroken) {
|
||||
status = "degraded"
|
||||
}
|
||||
summary := map[string]any{
|
||||
"missing_count": missing,
|
||||
"extra_count": extra,
|
||||
"host_version": hostVersion,
|
||||
"probe_failures": probeFailures,
|
||||
"access_status": accessStatus,
|
||||
"access_rechecked": accessChecked,
|
||||
}
|
||||
summaryJSON, err := json.Marshal(summary)
|
||||
if err != nil {
|
||||
return ReconcileResult{}, fmt.Errorf("marshal reconcile summary: %w", err)
|
||||
}
|
||||
if _, err := s.store.ReconcileRuns().Create(ctx, sqlite.ReconcileRun{ProviderID: providerRow.ID, Status: status, SummaryJSON: string(summaryJSON)}); err != nil {
|
||||
return ReconcileResult{}, err
|
||||
}
|
||||
return ReconcileResult{BatchID: batchRow.ID, Status: status, MissingCount: missing, ExtraCount: extra, ProbeFailureCount: probeFailures, AccessStatus: accessStatus, Summary: summary}, nil
|
||||
}
|
||||
|
||||
func (s *ReconcileService) rerunAccountProbes(ctx context.Context, items []sqlite.ImportBatchItem, expectedModel string) (int, error) {
|
||||
if len(items) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
failures := 0
|
||||
for _, item := range items {
|
||||
accountID, err := accountIDFromProbeSummary(item.ProbeSummaryJSON)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("decode import batch item %d probe summary: %w", item.ID, err)
|
||||
}
|
||||
if strings.TrimSpace(accountID) == "" {
|
||||
return 0, fmt.Errorf("import batch item %d missing account_id in probe summary", item.ID)
|
||||
}
|
||||
probe, err := s.host.TestAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("re-test account %s: %w", accountID, err)
|
||||
}
|
||||
models, err := s.host.GetAccountModels(ctx, accountID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("reload account models %s: %w", accountID, err)
|
||||
}
|
||||
smokeModelSeen := hasModel(models, expectedModel)
|
||||
status := firstNonEmpty(probe.Status, "unknown")
|
||||
payload, err := json.Marshal(map[string]any{
|
||||
"account_id": accountID,
|
||||
"probe_ok": probe.OK,
|
||||
"probe_status": probe.Status,
|
||||
"probe_message": probe.Message,
|
||||
"models": models,
|
||||
"smoke_model_seen": smokeModelSeen,
|
||||
"reconcile_rerun": true,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("marshal probe rerun summary for %s: %w", accountID, err)
|
||||
}
|
||||
if err := s.store.ImportBatchItems().UpdateResult(ctx, item.ID, status, string(payload)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if _, err := s.store.ProbeResults().Create(ctx, sqlite.ProbeResult{BatchItemID: item.ID, ProbeType: "account_smoke_rerun", Status: status, SummaryJSON: string(payload)}); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !probe.OK || !smokeModelSeen {
|
||||
failures++
|
||||
}
|
||||
}
|
||||
return failures, nil
|
||||
}
|
||||
|
||||
func (s *ReconcileService) rerunAccessClosure(ctx context.Context, batchID int64, accessClosures []sqlite.AccessClosureRecord, probeAPIKey, expectedModel string) (string, bool, error) {
|
||||
if len(accessClosures) == 0 {
|
||||
return "not_configured", false, nil
|
||||
}
|
||||
latest := accessClosures[len(accessClosures)-1]
|
||||
status := firstNonEmpty(latest.Status, deriveHealthyAccessStatus(latest.ClosureType))
|
||||
if strings.TrimSpace(probeAPIKey) == "" {
|
||||
return status, false, nil
|
||||
}
|
||||
result, err := s.host.CheckGatewayAccess(ctx, sub2api.GatewayAccessCheckRequest{APIKey: probeAPIKey, ExpectedModel: expectedModel})
|
||||
if err != nil {
|
||||
return "", false, fmt.Errorf("re-check gateway access: %w", err)
|
||||
}
|
||||
if result.OK && result.HasExpectedModel {
|
||||
status = deriveHealthyAccessStatus(latest.ClosureType)
|
||||
} else {
|
||||
status = AccessStatusBroken
|
||||
}
|
||||
payload, err := json.Marshal(map[string]any{
|
||||
"status_code": result.StatusCode,
|
||||
"ok": result.OK,
|
||||
"has_expected_model": result.HasExpectedModel,
|
||||
"models": result.Models,
|
||||
"reconcile_rerun": true,
|
||||
})
|
||||
if err != nil {
|
||||
return "", false, fmt.Errorf("marshal access rerun summary: %w", err)
|
||||
}
|
||||
if _, err := s.store.AccessClosures().Create(ctx, sqlite.AccessClosureRecord{BatchID: batchID, ClosureType: latest.ClosureType, Status: status, DetailsJSON: string(payload)}); err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
return status, true, nil
|
||||
}
|
||||
|
||||
func deriveHealthyAccessStatus(closureType string) string {
|
||||
switch strings.TrimSpace(closureType) {
|
||||
case AccessModeSubscription:
|
||||
return AccessStatusSubscriptionReady
|
||||
case AccessModeSelfService:
|
||||
return AccessStatusSelfServiceReady
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func accountIDFromProbeSummary(summaryJSON string) (string, error) {
|
||||
if strings.TrimSpace(summaryJSON) == "" {
|
||||
return "", nil
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(summaryJSON), &payload); err != nil {
|
||||
return "", err
|
||||
}
|
||||
accountID, _ := payload["account_id"].(string)
|
||||
return strings.TrimSpace(accountID), nil
|
||||
}
|
||||
|
||||
func diffManagedResources(stored []sqlite.ManagedResource, snapshot sub2api.ManagedResourceSnapshot) (int, int) {
|
||||
live := map[string]map[string]struct{}{
|
||||
"group": make(map[string]struct{}),
|
||||
"channel": make(map[string]struct{}),
|
||||
"plan": make(map[string]struct{}),
|
||||
"account": make(map[string]struct{}),
|
||||
}
|
||||
for _, resource := range snapshot.Groups {
|
||||
live["group"][strings.TrimSpace(resource.ID)] = struct{}{}
|
||||
}
|
||||
for _, resource := range snapshot.Channels {
|
||||
live["channel"][strings.TrimSpace(resource.ID)] = struct{}{}
|
||||
}
|
||||
for _, resource := range snapshot.Plans {
|
||||
live["plan"][strings.TrimSpace(resource.ID)] = struct{}{}
|
||||
}
|
||||
for _, resource := range snapshot.Accounts {
|
||||
live["account"][strings.TrimSpace(resource.ID)] = struct{}{}
|
||||
}
|
||||
|
||||
storedByType := map[string]map[string]struct{}{
|
||||
"group": make(map[string]struct{}),
|
||||
"channel": make(map[string]struct{}),
|
||||
"plan": make(map[string]struct{}),
|
||||
"account": make(map[string]struct{}),
|
||||
}
|
||||
for _, resource := range stored {
|
||||
storedByType[strings.TrimSpace(resource.ResourceType)][strings.TrimSpace(resource.HostResourceID)] = struct{}{}
|
||||
}
|
||||
|
||||
missing := 0
|
||||
extra := 0
|
||||
for resourceType, storedIDs := range storedByType {
|
||||
for id := range storedIDs {
|
||||
if _, ok := live[resourceType][id]; !ok {
|
||||
missing++
|
||||
}
|
||||
}
|
||||
for id := range live[resourceType] {
|
||||
if _, ok := storedIDs[id]; !ok {
|
||||
extra++
|
||||
}
|
||||
}
|
||||
}
|
||||
return missing, extra
|
||||
}
|
||||
235
internal/provision/batch_detail_service_test.go
Normal file
235
internal/provision/batch_detail_service_test.go
Normal file
@@ -0,0 +1,235 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
func TestBatchDetailServiceGetReturnsPersistedArtifacts(t *testing.T) {
|
||||
store := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, store)
|
||||
|
||||
host := &fakeHostAdapter{
|
||||
batchAccounts: []sub2api.AccountRef{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
|
||||
testResults: map[string]sub2api.ProbeResult{
|
||||
"account_1": {OK: true, Status: "passed"},
|
||||
"account_2": {OK: true, Status: "passed"},
|
||||
},
|
||||
models: map[string][]sub2api.AccountModel{
|
||||
"account_1": {{ID: "deepseek-chat"}},
|
||||
"account_2": {{ID: "deepseek-chat"}},
|
||||
},
|
||||
gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}},
|
||||
}
|
||||
|
||||
batchID := seedRuntimeImportForReconcile(t, store, host)
|
||||
providerRow, err := store.Providers().ListByProviderID(context.Background(), sampleProviderManifest().ProviderID)
|
||||
if err != nil {
|
||||
t.Fatalf("Providers().ListByProviderID() error = %v", err)
|
||||
}
|
||||
if len(providerRow) != 1 {
|
||||
t.Fatalf("providers = %d, want 1", len(providerRow))
|
||||
}
|
||||
if _, err := store.ReconcileRuns().Create(context.Background(), sqlite.ReconcileRun{ProviderID: providerRow[0].ID, Status: "active", SummaryJSON: `{"missing_count":0}`}); err != nil {
|
||||
t.Fatalf("ReconcileRuns().Create() error = %v", err)
|
||||
}
|
||||
|
||||
result, err := NewBatchDetailService(store).Get(context.Background(), batchID)
|
||||
if err != nil {
|
||||
t.Fatalf("Get() error = %v", err)
|
||||
}
|
||||
if result.Batch.ID != batchID {
|
||||
t.Fatalf("Batch.ID = %d, want %d", result.Batch.ID, batchID)
|
||||
}
|
||||
if len(result.Items) != 2 {
|
||||
t.Fatalf("len(Items) = %d, want 2", len(result.Items))
|
||||
}
|
||||
if len(result.ManagedResources) != 4 {
|
||||
t.Fatalf("len(ManagedResources) = %d, want 4", len(result.ManagedResources))
|
||||
}
|
||||
if len(result.AccessClosures) != 1 {
|
||||
t.Fatalf("len(AccessClosures) = %d, want 1", len(result.AccessClosures))
|
||||
}
|
||||
if len(result.ReconcileRuns) != 1 {
|
||||
t.Fatalf("len(ReconcileRuns) = %d, want 1", len(result.ReconcileRuns))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchDetailServiceGetValidatesStore(t *testing.T) {
|
||||
_, err := (*BatchDetailService)(nil).Get(context.Background(), 1)
|
||||
if err == nil || err.Error() != "store is required" {
|
||||
t.Fatalf("nil service Get() error = %v, want store is required", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountIDFromProbeSummary(t *testing.T) {
|
||||
accountID, err := accountIDFromProbeSummary(`{"account_id":" account_1 "}`)
|
||||
if err != nil {
|
||||
t.Fatalf("accountIDFromProbeSummary() error = %v", err)
|
||||
}
|
||||
if accountID != "account_1" {
|
||||
t.Fatalf("accountID = %q, want account_1", accountID)
|
||||
}
|
||||
if _, err := accountIDFromProbeSummary(`{`); err == nil {
|
||||
t.Fatal("accountIDFromProbeSummary() error = nil, want JSON decode error")
|
||||
}
|
||||
blank, err := accountIDFromProbeSummary("")
|
||||
if err != nil {
|
||||
t.Fatalf("accountIDFromProbeSummary(blank) error = %v", err)
|
||||
}
|
||||
if blank != "" {
|
||||
t.Fatalf("blank accountID = %q, want empty", blank)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconcileServiceRerunAccessClosureWithoutProbeKeyUsesLatestStatus(t *testing.T) {
|
||||
store := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, store)
|
||||
|
||||
status, checked, err := NewReconcileService(store, &fakeHostAdapter{}).rerunAccessClosure(context.Background(), 1, []sqlite.AccessClosureRecord{{ClosureType: AccessModeSubscription, Status: AccessStatusSubscriptionReady}}, "", "deepseek-chat")
|
||||
if err != nil {
|
||||
t.Fatalf("rerunAccessClosure() error = %v", err)
|
||||
}
|
||||
if checked {
|
||||
t.Fatal("checked = true, want false without probe key")
|
||||
}
|
||||
if status != AccessStatusSubscriptionReady {
|
||||
t.Fatalf("status = %q, want %q", status, AccessStatusSubscriptionReady)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconcileServiceRerunAccessClosureMarksBrokenWhenGatewayCheckFails(t *testing.T) {
|
||||
store := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, store)
|
||||
|
||||
hostSeed := &fakeHostAdapter{
|
||||
batchAccounts: []sub2api.AccountRef{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
|
||||
testResults: map[string]sub2api.ProbeResult{
|
||||
"account_1": {OK: true, Status: "passed"},
|
||||
"account_2": {OK: true, Status: "passed"},
|
||||
},
|
||||
models: map[string][]sub2api.AccountModel{
|
||||
"account_1": {{ID: "deepseek-chat"}},
|
||||
"account_2": {{ID: "deepseek-chat"}},
|
||||
},
|
||||
gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}},
|
||||
}
|
||||
batchID := seedRuntimeImportForReconcile(t, store, hostSeed)
|
||||
|
||||
host := &fakeHostAdapter{gatewayResult: sub2api.GatewayAccessResult{OK: false, StatusCode: 403, HasExpectedModel: false}}
|
||||
status, checked, err := NewReconcileService(store, host).rerunAccessClosure(context.Background(), batchID, []sqlite.AccessClosureRecord{{ClosureType: AccessModeSelfService, Status: AccessStatusSelfServiceReady}}, "user-key", "deepseek-chat")
|
||||
if err != nil {
|
||||
t.Fatalf("rerunAccessClosure() error = %v", err)
|
||||
}
|
||||
if !checked {
|
||||
t.Fatal("checked = false, want true")
|
||||
}
|
||||
if status != AccessStatusBroken {
|
||||
t.Fatalf("status = %q, want %q", status, AccessStatusBroken)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "access_closure_records"); got != 2 {
|
||||
t.Fatalf("access_closure_records row count = %d, want 2 after rerun", got)
|
||||
}
|
||||
if host.gatewayProbe.ExpectedModel != "deepseek-chat" {
|
||||
t.Fatalf("ExpectedModel = %q, want deepseek-chat", host.gatewayProbe.ExpectedModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiffManagedResourcesCountsMissingAndExtra(t *testing.T) {
|
||||
missing, extra := diffManagedResources(
|
||||
[]sqlite.ManagedResource{
|
||||
{ResourceType: "group", HostResourceID: "group_1"},
|
||||
{ResourceType: "account", HostResourceID: "account_1"},
|
||||
},
|
||||
sub2api.ManagedResourceSnapshot{
|
||||
Groups: []sub2api.NamedResource{{ID: "group_1"}},
|
||||
Accounts: []sub2api.NamedResource{{ID: "account_2"}},
|
||||
},
|
||||
)
|
||||
if missing != 1 || extra != 1 {
|
||||
t.Fatalf("diffManagedResources() = (%d, %d), want (1, 1)", missing, extra)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveProviderStatus(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
batchStatus string
|
||||
reconcileStatus string
|
||||
want string
|
||||
}{
|
||||
{name: "reconcile wins", batchStatus: BatchStatusSucceeded, reconcileStatus: "degraded", want: "degraded"},
|
||||
{name: "succeeded batch", batchStatus: BatchStatusSucceeded, reconcileStatus: "not_run", want: ProviderStatusActive},
|
||||
{name: "failed batch", batchStatus: BatchStatusFailed, want: ProviderStatusFailed},
|
||||
{name: "running batch", batchStatus: "running", want: "running"},
|
||||
{name: "unknown fallback", batchStatus: " pending ", want: "pending"},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := deriveProviderStatus(tc.batchStatus, tc.reconcileStatus); got != tc.want {
|
||||
t.Fatalf("deriveProviderStatus(%q, %q) = %q, want %q", tc.batchStatus, tc.reconcileStatus, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPackAndProviderRecord(t *testing.T) {
|
||||
packRow, err := buildPackRecord(sampleLoadedPack())
|
||||
if err != nil {
|
||||
t.Fatalf("buildPackRecord() error = %v", err)
|
||||
}
|
||||
if packRow.PackID != "openai-cn-pack" || packRow.TargetHost != "sub2api" {
|
||||
t.Fatalf("packRow = %#v, want populated pack metadata", packRow)
|
||||
}
|
||||
|
||||
providerRow, err := buildProviderRecord(7, sampleProviderManifest())
|
||||
if err != nil {
|
||||
t.Fatalf("buildProviderRecord() error = %v", err)
|
||||
}
|
||||
if providerRow.PackID != 7 || providerRow.ProviderID != sampleProviderManifest().ProviderID {
|
||||
t.Fatalf("providerRow = %#v, want persisted provider metadata", providerRow)
|
||||
}
|
||||
if providerRow.DefaultModelsJSON == "" || providerRow.ManifestJSON == "" {
|
||||
t.Fatalf("providerRow JSON fields = %#v, want serialized JSON", providerRow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirstNonEmptyAndFingerprintKey(t *testing.T) {
|
||||
if got := firstNonEmpty(" ", "value", "other"); got != "value" {
|
||||
t.Fatalf("firstNonEmpty() = %q, want value", got)
|
||||
}
|
||||
if got := fingerprintKey([]string{" key-1 "}, 0); got == "key-1" || got == "sha256:" || len(got) < 20 {
|
||||
t.Fatalf("fingerprintKey() = %q, want sha256 fingerprint", got)
|
||||
}
|
||||
if got := fingerprintKey(nil, 3); got != "key-4" {
|
||||
t.Fatalf("fingerprintKey(nil, 3) = %q, want key-4", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderStatusServiceGetResourcesRequiresProviderID(t *testing.T) {
|
||||
store := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, store)
|
||||
|
||||
_, err := NewProviderStatusService(store).GetResources(context.Background(), ProviderQuery{})
|
||||
if err == nil || err.Error() != "provider_id is required" {
|
||||
t.Fatalf("GetResources() error = %v, want provider_id is required", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResourceSlugFallsBackToProvider(t *testing.T) {
|
||||
if got := resourceSlug(" !!! "); got != "provider" {
|
||||
t.Fatalf("resourceSlug() = %q, want provider", got)
|
||||
}
|
||||
provider := sampleProviderManifest()
|
||||
provider.ProviderID = " DeepSeek CN / Prod "
|
||||
if got := SuggestAccountNamePrefix(provider); got != "deepseek-cn-prod-" {
|
||||
t.Fatalf("SuggestAccountNamePrefix() = %q, want deepseek-cn-prod-", got)
|
||||
}
|
||||
resourceNames := SuggestResourceNames(provider)
|
||||
if resourceNames.Group != "crm-deepseek-cn-prod-group" {
|
||||
t.Fatalf("SuggestResourceNames() = %#v, want slugged resource names", resourceNames)
|
||||
}
|
||||
}
|
||||
343
internal/provision/import_service.go
Normal file
343
internal/provision/import_service.go
Normal file
@@ -0,0 +1,343 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/access"
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
"sub2api-cn-relay-manager/internal/pack"
|
||||
)
|
||||
|
||||
const (
|
||||
ImportModeStrict = "strict"
|
||||
ImportModePartial = "partial"
|
||||
|
||||
AccessModeSubscription = "subscription"
|
||||
AccessModeSelfService = "self_service"
|
||||
|
||||
BatchStatusSucceeded = "succeeded"
|
||||
BatchStatusPartial = "partially_succeeded"
|
||||
BatchStatusFailed = "failed"
|
||||
|
||||
ProviderStatusActive = "active"
|
||||
ProviderStatusDegraded = "degraded"
|
||||
ProviderStatusFailed = "failed"
|
||||
|
||||
AccessStatusSubscriptionReady = "subscription_ready"
|
||||
AccessStatusSelfServiceReady = "self_service_ready"
|
||||
AccessStatusBroken = "broken"
|
||||
)
|
||||
|
||||
type AccessRequest struct {
|
||||
Mode string
|
||||
ProbeAPIKey string
|
||||
Subscriptions []SubscriptionTarget
|
||||
}
|
||||
|
||||
type SubscriptionTarget struct {
|
||||
UserID string
|
||||
DurationDays int
|
||||
}
|
||||
|
||||
type ImportRequest struct {
|
||||
Provider pack.ProviderManifest
|
||||
Mode string
|
||||
Access AccessRequest
|
||||
Keys []string
|
||||
}
|
||||
|
||||
type ImportReport struct {
|
||||
BatchStatus string
|
||||
ProviderStatus string
|
||||
AccessStatus string
|
||||
AcceptedKeys []string
|
||||
Group sub2api.GroupRef
|
||||
Channel sub2api.ChannelRef
|
||||
Plan *sub2api.PlanRef
|
||||
Accounts []AccountImportResult
|
||||
Gateway sub2api.GatewayAccessResult
|
||||
}
|
||||
|
||||
type AccountImportResult struct {
|
||||
Ref sub2api.AccountRef
|
||||
Probe sub2api.ProbeResult
|
||||
Models []sub2api.AccountModel
|
||||
SmokeModelSeen bool
|
||||
}
|
||||
|
||||
type hostAdapter interface {
|
||||
sub2api.HostAdapter
|
||||
CheckGatewayAccess(ctx context.Context, req sub2api.GatewayAccessCheckRequest) (sub2api.GatewayAccessResult, error)
|
||||
}
|
||||
|
||||
type ImportService struct {
|
||||
host hostAdapter
|
||||
}
|
||||
|
||||
func NewImportService(host hostAdapter) *ImportService {
|
||||
return &ImportService{host: host}
|
||||
}
|
||||
|
||||
func (s *ImportService) Import(ctx context.Context, req ImportRequest) (report ImportReport, err error) {
|
||||
normalizedKeys, err := normalizeKeys(req.Keys)
|
||||
if err != nil {
|
||||
return ImportReport{}, err
|
||||
}
|
||||
if err := validateMode(req.Mode); err != nil {
|
||||
return ImportReport{}, err
|
||||
}
|
||||
if err := access.Validate(access.ClosureRequest{
|
||||
Mode: req.Access.Mode,
|
||||
ProbeAPIKey: req.Access.ProbeAPIKey,
|
||||
Subscriptions: toAccessSubscriptionTargets(req.Access.Subscriptions),
|
||||
}); err != nil {
|
||||
return ImportReport{}, err
|
||||
}
|
||||
|
||||
report = ImportReport{AcceptedKeys: normalizedKeys}
|
||||
rollback := newManagedResourceRollback(s.host)
|
||||
defer func() {
|
||||
if err == nil || req.Mode != ImportModeStrict {
|
||||
return
|
||||
}
|
||||
if rollbackErr := rollback.Run(ctx); rollbackErr != nil {
|
||||
err = errors.Join(err, fmt.Errorf("rollback managed resources: %w", rollbackErr))
|
||||
}
|
||||
}()
|
||||
group, err := s.host.CreateGroup(ctx, sub2api.CreateGroupRequest{
|
||||
Name: req.Provider.GroupTemplate.Name,
|
||||
RateMultiplier: req.Provider.GroupTemplate.RateMultiplier,
|
||||
})
|
||||
if err != nil {
|
||||
return report, fmt.Errorf("create group: %w", err)
|
||||
}
|
||||
report.Group = group
|
||||
rollback.AddGroup(group.ID)
|
||||
|
||||
channel, err := s.host.CreateChannel(ctx, sub2api.CreateChannelRequest{
|
||||
Name: req.Provider.ChannelTemplate.Name,
|
||||
GroupIDs: []string{group.ID},
|
||||
})
|
||||
if err != nil {
|
||||
return report, fmt.Errorf("create channel: %w", err)
|
||||
}
|
||||
report.Channel = channel
|
||||
rollback.AddChannel(channel.ID)
|
||||
|
||||
if req.Access.Mode == AccessModeSubscription {
|
||||
plan, err := s.host.CreatePlan(ctx, sub2api.CreatePlanRequest{
|
||||
GroupID: group.ID,
|
||||
Name: req.Provider.PlanTemplate.Name,
|
||||
Price: req.Provider.PlanTemplate.Price,
|
||||
ValidityDays: req.Provider.PlanTemplate.ValidityDays,
|
||||
ValidityUnit: req.Provider.PlanTemplate.ValidityUnit,
|
||||
})
|
||||
if err != nil {
|
||||
return report, fmt.Errorf("create plan: %w", err)
|
||||
}
|
||||
report.Plan = &plan
|
||||
rollback.AddPlan(plan.ID)
|
||||
}
|
||||
|
||||
accounts, err := s.host.BatchCreateAccounts(ctx, buildBatchAccountsRequest(req.Provider, group.ID, normalizedKeys))
|
||||
if err != nil {
|
||||
return report, fmt.Errorf("batch create accounts: %w", err)
|
||||
}
|
||||
rollback.AddAccounts(accounts)
|
||||
for _, account := range accounts {
|
||||
probe, err := s.host.TestAccount(ctx, account.ID)
|
||||
if err != nil {
|
||||
return failOrDegrade(report, req.Mode, fmt.Errorf("test account %s: %w", account.ID, err))
|
||||
}
|
||||
models, err := s.host.GetAccountModels(ctx, account.ID)
|
||||
if err != nil {
|
||||
return failOrDegrade(report, req.Mode, fmt.Errorf("get account models %s: %w", account.ID, err))
|
||||
}
|
||||
result := AccountImportResult{Ref: account, Probe: probe, Models: models, SmokeModelSeen: hasModel(models, req.Provider.SmokeTestModel)}
|
||||
report.Accounts = append(report.Accounts, result)
|
||||
}
|
||||
|
||||
failedAccounts := 0
|
||||
for _, account := range report.Accounts {
|
||||
if !account.Probe.OK || !account.SmokeModelSeen {
|
||||
failedAccounts++
|
||||
}
|
||||
}
|
||||
if failedAccounts > 0 && req.Mode == ImportModeStrict {
|
||||
report.BatchStatus = BatchStatusFailed
|
||||
report.ProviderStatus = ProviderStatusFailed
|
||||
report.AccessStatus = AccessStatusBroken
|
||||
return report, fmt.Errorf("strict import failed: %d account(s) did not pass smoke validation", failedAccounts)
|
||||
}
|
||||
|
||||
closureService := access.NewService(s.host)
|
||||
gateway, err := closureService.Close(ctx, access.ClosureRequest{
|
||||
Mode: req.Access.Mode,
|
||||
ProbeAPIKey: req.Access.ProbeAPIKey,
|
||||
Subscriptions: toAccessSubscriptionTargets(req.Access.Subscriptions),
|
||||
GroupID: group.ID,
|
||||
ExpectedModel: req.Provider.SmokeTestModel,
|
||||
})
|
||||
if err != nil {
|
||||
return failOrDegrade(report, req.Mode, err)
|
||||
}
|
||||
report.Gateway = gateway
|
||||
|
||||
report.BatchStatus = BatchStatusSucceeded
|
||||
report.ProviderStatus = ProviderStatusActive
|
||||
if failedAccounts > 0 || !gateway.OK || !gateway.HasExpectedModel {
|
||||
report.BatchStatus = BatchStatusPartial
|
||||
report.ProviderStatus = ProviderStatusDegraded
|
||||
}
|
||||
switch req.Access.Mode {
|
||||
case AccessModeSubscription:
|
||||
report.AccessStatus = AccessStatusSubscriptionReady
|
||||
case AccessModeSelfService:
|
||||
report.AccessStatus = AccessStatusSelfServiceReady
|
||||
}
|
||||
if !gateway.OK || !gateway.HasExpectedModel {
|
||||
report.AccessStatus = AccessStatusBroken
|
||||
}
|
||||
return report, nil
|
||||
}
|
||||
|
||||
func validateMode(mode string) error {
|
||||
switch strings.TrimSpace(mode) {
|
||||
case ImportModeStrict, ImportModePartial:
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unsupported import mode %q", mode)
|
||||
}
|
||||
}
|
||||
|
||||
func toAccessSubscriptionTargets(targets []SubscriptionTarget) []access.SubscriptionTarget {
|
||||
result := make([]access.SubscriptionTarget, 0, len(targets))
|
||||
for _, target := range targets {
|
||||
result = append(result, access.SubscriptionTarget{UserID: target.UserID, DurationDays: target.DurationDays})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func normalizeKeys(keys []string) ([]string, error) {
|
||||
seen := map[string]struct{}{}
|
||||
result := make([]string, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
normalized := strings.TrimSpace(strings.TrimPrefix(key, "\ufeff"))
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[normalized]; ok {
|
||||
continue
|
||||
}
|
||||
seen[normalized] = struct{}{}
|
||||
result = append(result, normalized)
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, fmt.Errorf("at least one api key is required")
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func buildBatchAccountsRequest(provider pack.ProviderManifest, groupID string, keys []string) sub2api.BatchCreateAccountsRequest {
|
||||
accounts := make([]sub2api.CreateAccountRequest, 0, len(keys))
|
||||
for index, key := range keys {
|
||||
accounts = append(accounts, sub2api.CreateAccountRequest{
|
||||
Name: fmt.Sprintf("%s-%02d", provider.ProviderID, index+1),
|
||||
Platform: provider.Platform,
|
||||
Type: provider.AccountType,
|
||||
GroupIDs: []string{groupID},
|
||||
Credentials: map[string]any{
|
||||
"base_url": provider.BaseURL,
|
||||
"api_key": key,
|
||||
},
|
||||
})
|
||||
}
|
||||
return sub2api.BatchCreateAccountsRequest{Accounts: accounts}
|
||||
}
|
||||
|
||||
func hasModel(models []sub2api.AccountModel, target string) bool {
|
||||
for _, model := range models {
|
||||
if strings.TrimSpace(model.ID) == strings.TrimSpace(target) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type managedResourceRollback struct {
|
||||
host hostAdapter
|
||||
groupID string
|
||||
channelID string
|
||||
planID string
|
||||
accountIDs []string
|
||||
}
|
||||
|
||||
func newManagedResourceRollback(host hostAdapter) *managedResourceRollback {
|
||||
return &managedResourceRollback{host: host}
|
||||
}
|
||||
|
||||
func (r *managedResourceRollback) AddGroup(groupID string) {
|
||||
r.groupID = strings.TrimSpace(groupID)
|
||||
}
|
||||
|
||||
func (r *managedResourceRollback) AddChannel(channelID string) {
|
||||
r.channelID = strings.TrimSpace(channelID)
|
||||
}
|
||||
|
||||
func (r *managedResourceRollback) AddPlan(planID string) {
|
||||
r.planID = strings.TrimSpace(planID)
|
||||
}
|
||||
|
||||
func (r *managedResourceRollback) AddAccounts(accounts []sub2api.AccountRef) {
|
||||
for _, account := range accounts {
|
||||
accountID := strings.TrimSpace(account.ID)
|
||||
if accountID == "" {
|
||||
continue
|
||||
}
|
||||
r.accountIDs = append(r.accountIDs, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *managedResourceRollback) Run(ctx context.Context) error {
|
||||
if r == nil || r.host == nil {
|
||||
return nil
|
||||
}
|
||||
var errs []error
|
||||
for index := len(r.accountIDs) - 1; index >= 0; index-- {
|
||||
if err := r.host.DeleteAccount(ctx, r.accountIDs[index]); err != nil {
|
||||
errs = append(errs, fmt.Errorf("delete account %s: %w", r.accountIDs[index], err))
|
||||
}
|
||||
}
|
||||
if r.planID != "" {
|
||||
if err := r.host.DeletePlan(ctx, r.planID); err != nil {
|
||||
errs = append(errs, fmt.Errorf("delete plan %s: %w", r.planID, err))
|
||||
}
|
||||
}
|
||||
if r.channelID != "" {
|
||||
if err := r.host.DeleteChannel(ctx, r.channelID); err != nil {
|
||||
errs = append(errs, fmt.Errorf("delete channel %s: %w", r.channelID, err))
|
||||
}
|
||||
}
|
||||
if r.groupID != "" {
|
||||
if err := r.host.DeleteGroup(ctx, r.groupID); err != nil {
|
||||
errs = append(errs, fmt.Errorf("delete group %s: %w", r.groupID, err))
|
||||
}
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func failOrDegrade(report ImportReport, mode string, err error) (ImportReport, error) {
|
||||
if mode == ImportModeStrict {
|
||||
report.BatchStatus = BatchStatusFailed
|
||||
report.ProviderStatus = ProviderStatusFailed
|
||||
report.AccessStatus = AccessStatusBroken
|
||||
return report, err
|
||||
}
|
||||
report.BatchStatus = BatchStatusPartial
|
||||
report.ProviderStatus = ProviderStatusDegraded
|
||||
report.AccessStatus = AccessStatusBroken
|
||||
return report, err
|
||||
}
|
||||
241
internal/provision/import_service_test.go
Normal file
241
internal/provision/import_service_test.go
Normal file
@@ -0,0 +1,241 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
"sub2api-cn-relay-manager/internal/pack"
|
||||
)
|
||||
|
||||
func TestImportServiceImportSubscriptionFlow(t *testing.T) {
|
||||
host := &fakeHostAdapter{
|
||||
batchAccounts: []sub2api.AccountRef{{ID: "account_1"}, {ID: "account_2"}},
|
||||
testResults: map[string]sub2api.ProbeResult{
|
||||
"account_1": {OK: true, Status: "passed"},
|
||||
"account_2": {OK: true, Status: "passed"},
|
||||
},
|
||||
models: map[string][]sub2api.AccountModel{
|
||||
"account_1": {{ID: "deepseek-chat"}},
|
||||
"account_2": {{ID: "deepseek-chat"}},
|
||||
},
|
||||
gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}},
|
||||
}
|
||||
|
||||
svc := NewImportService(host)
|
||||
report, err := svc.Import(context.Background(), ImportRequest{
|
||||
Provider: sampleProviderManifest(),
|
||||
Mode: ImportModePartial,
|
||||
Access: AccessRequest{
|
||||
Mode: AccessModeSubscription,
|
||||
ProbeAPIKey: "user-key",
|
||||
Subscriptions: []SubscriptionTarget{{UserID: "user_1", DurationDays: 30}},
|
||||
},
|
||||
Keys: []string{" key-1 ", "key-2", "key-1"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Import() error = %v", err)
|
||||
}
|
||||
|
||||
if report.BatchStatus != BatchStatusSucceeded {
|
||||
t.Fatalf("BatchStatus = %q, want %q", report.BatchStatus, BatchStatusSucceeded)
|
||||
}
|
||||
if report.ProviderStatus != ProviderStatusActive {
|
||||
t.Fatalf("ProviderStatus = %q, want %q", report.ProviderStatus, ProviderStatusActive)
|
||||
}
|
||||
if report.AccessStatus != AccessStatusSubscriptionReady {
|
||||
t.Fatalf("AccessStatus = %q, want %q", report.AccessStatus, AccessStatusSubscriptionReady)
|
||||
}
|
||||
if !reflect.DeepEqual(report.AcceptedKeys, []string{"key-1", "key-2"}) {
|
||||
t.Fatalf("AcceptedKeys = %#v, want deduped normalized keys", report.AcceptedKeys)
|
||||
}
|
||||
if len(host.assignedSubscriptions) != 1 {
|
||||
t.Fatalf("assigned subscriptions = %d, want 1", len(host.assignedSubscriptions))
|
||||
}
|
||||
if host.gatewayProbe.ExpectedModel != "deepseek-chat" {
|
||||
t.Fatalf("gateway probe model = %q, want %q", host.gatewayProbe.ExpectedModel, "deepseek-chat")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportServiceStrictModeFailsWhenAnyAccountProbeFails(t *testing.T) {
|
||||
host := &fakeHostAdapter{
|
||||
batchAccounts: []sub2api.AccountRef{{ID: "account_1"}, {ID: "account_2"}},
|
||||
testResults: map[string]sub2api.ProbeResult{
|
||||
"account_1": {OK: true, Status: "passed"},
|
||||
"account_2": {OK: false, Status: "failed", Message: "bad key"},
|
||||
},
|
||||
models: map[string][]sub2api.AccountModel{
|
||||
"account_1": {{ID: "deepseek-chat"}},
|
||||
"account_2": {{ID: "deepseek-chat"}},
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewImportService(host)
|
||||
report, err := svc.Import(context.Background(), ImportRequest{
|
||||
Provider: sampleProviderManifest(),
|
||||
Mode: ImportModeStrict,
|
||||
Access: AccessRequest{Mode: AccessModeSelfService, ProbeAPIKey: "user-key"},
|
||||
Keys: []string{"key-1", "key-2"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Import() error = nil, want strict mode failure")
|
||||
}
|
||||
if report.BatchStatus != BatchStatusFailed {
|
||||
t.Fatalf("BatchStatus = %q, want %q", report.BatchStatus, BatchStatusFailed)
|
||||
}
|
||||
if report.ProviderStatus != ProviderStatusFailed {
|
||||
t.Fatalf("ProviderStatus = %q, want %q", report.ProviderStatus, ProviderStatusFailed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportServiceRejectsUnknownMode(t *testing.T) {
|
||||
svc := NewImportService(&fakeHostAdapter{})
|
||||
_, err := svc.Import(context.Background(), ImportRequest{
|
||||
Provider: sampleProviderManifest(),
|
||||
Mode: "unknown",
|
||||
Access: AccessRequest{Mode: AccessModeSelfService},
|
||||
Keys: []string{"key-1"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Import() error = nil, want mode validation error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportServiceStrictModeRollsBackCreatedResources(t *testing.T) {
|
||||
host := &fakeHostAdapter{
|
||||
batchAccounts: []sub2api.AccountRef{{ID: "account_1"}, {ID: "account_2"}},
|
||||
testResults: map[string]sub2api.ProbeResult{
|
||||
"account_1": {OK: true, Status: "passed"},
|
||||
"account_2": {OK: false, Status: "failed", Message: "bad key"},
|
||||
},
|
||||
models: map[string][]sub2api.AccountModel{
|
||||
"account_1": {{ID: "deepseek-chat"}},
|
||||
"account_2": {{ID: "deepseek-chat"}},
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewImportService(host)
|
||||
_, err := svc.Import(context.Background(), ImportRequest{
|
||||
Provider: sampleProviderManifest(),
|
||||
Mode: ImportModeStrict,
|
||||
Access: AccessRequest{Mode: AccessModeSelfService, ProbeAPIKey: "user-key"},
|
||||
Keys: []string{"key-1", "key-2"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Import() error = nil, want strict mode failure")
|
||||
}
|
||||
|
||||
want := []string{"account:account_2", "account:account_1", "channel:channel_1", "group:group_1"}
|
||||
if !reflect.DeepEqual(host.deletedResources, want) {
|
||||
t.Fatalf("deleted resources = %#v, want %#v", host.deletedResources, want)
|
||||
}
|
||||
}
|
||||
|
||||
func sampleProviderManifest() pack.ProviderManifest {
|
||||
return pack.ProviderManifest{
|
||||
ProviderID: "deepseek",
|
||||
DisplayName: "DeepSeek OpenAI Compatible",
|
||||
BaseURL: "https://api.deepseek.com",
|
||||
Platform: "openai",
|
||||
AccountType: "api",
|
||||
DefaultModels: []string{"deepseek-chat", "deepseek-reasoner"},
|
||||
SmokeTestModel: "deepseek-chat",
|
||||
GroupTemplate: pack.GroupTemplate{Name: "DeepSeek 默认分组", RateMultiplier: 1},
|
||||
ChannelTemplate: pack.ChannelTemplate{Name: "DeepSeek 默认渠道", ModelMapping: map[string]string{"deepseek-chat": "deepseek-chat"}},
|
||||
PlanTemplate: pack.PlanTemplate{Name: "DeepSeek 默认套餐", Price: 19.9, ValidityDays: 30, ValidityUnit: "day"},
|
||||
}
|
||||
}
|
||||
|
||||
type fakeHostAdapter struct {
|
||||
batchAccounts []sub2api.AccountRef
|
||||
testResults map[string]sub2api.ProbeResult
|
||||
models map[string][]sub2api.AccountModel
|
||||
gatewayResult sub2api.GatewayAccessResult
|
||||
batchCreateErr error
|
||||
assignErr error
|
||||
gatewayErr error
|
||||
hostVersion string
|
||||
assignedSubscriptions []sub2api.AssignSubscriptionRequest
|
||||
gatewayProbe sub2api.GatewayAccessCheckRequest
|
||||
deletedResources []string
|
||||
managedSnapshot sub2api.ManagedResourceSnapshot
|
||||
}
|
||||
|
||||
func (f *fakeHostAdapter) GetHostVersion(context.Context) (string, error) {
|
||||
if strings.TrimSpace(f.hostVersion) == "" {
|
||||
return "0.1.126", nil
|
||||
}
|
||||
return f.hostVersion, nil
|
||||
}
|
||||
func (f *fakeHostAdapter) ProbeCapabilities(context.Context) (sub2api.HostCapabilities, error) {
|
||||
return sub2api.HostCapabilities{}, nil
|
||||
}
|
||||
func (f *fakeHostAdapter) CreateGroup(context.Context, sub2api.CreateGroupRequest) (sub2api.GroupRef, error) {
|
||||
return sub2api.GroupRef{ID: "group_1", Name: "g"}, nil
|
||||
}
|
||||
func (f *fakeHostAdapter) DeleteGroup(_ context.Context, groupID string) error {
|
||||
f.deletedResources = append(f.deletedResources, "group:"+groupID)
|
||||
return nil
|
||||
}
|
||||
func (f *fakeHostAdapter) CreateChannel(context.Context, sub2api.CreateChannelRequest) (sub2api.ChannelRef, error) {
|
||||
return sub2api.ChannelRef{ID: "channel_1", Name: "c"}, nil
|
||||
}
|
||||
func (f *fakeHostAdapter) DeleteChannel(_ context.Context, channelID string) error {
|
||||
f.deletedResources = append(f.deletedResources, "channel:"+channelID)
|
||||
return nil
|
||||
}
|
||||
func (f *fakeHostAdapter) CreatePlan(context.Context, sub2api.CreatePlanRequest) (sub2api.PlanRef, error) {
|
||||
return sub2api.PlanRef{ID: "plan_1", Name: "p"}, nil
|
||||
}
|
||||
func (f *fakeHostAdapter) DeletePlan(_ context.Context, planID string) error {
|
||||
f.deletedResources = append(f.deletedResources, "plan:"+planID)
|
||||
return nil
|
||||
}
|
||||
func (f *fakeHostAdapter) CreateAccount(context.Context, sub2api.CreateAccountRequest) (sub2api.AccountRef, error) {
|
||||
return sub2api.AccountRef{}, errors.New("unused")
|
||||
}
|
||||
func (f *fakeHostAdapter) BatchCreateAccounts(_ context.Context, _ sub2api.BatchCreateAccountsRequest) ([]sub2api.AccountRef, error) {
|
||||
if f.batchCreateErr != nil {
|
||||
return nil, f.batchCreateErr
|
||||
}
|
||||
return f.batchAccounts, nil
|
||||
}
|
||||
func (f *fakeHostAdapter) DeleteAccount(_ context.Context, accountID string) error {
|
||||
f.deletedResources = append(f.deletedResources, "account:"+accountID)
|
||||
return nil
|
||||
}
|
||||
func (f *fakeHostAdapter) TestAccount(_ context.Context, accountID string) (sub2api.ProbeResult, error) {
|
||||
result, ok := f.testResults[accountID]
|
||||
if !ok {
|
||||
return sub2api.ProbeResult{}, fmt.Errorf("missing test result for %s", accountID)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
func (f *fakeHostAdapter) GetAccountModels(_ context.Context, accountID string) ([]sub2api.AccountModel, error) {
|
||||
models, ok := f.models[accountID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing models for %s", accountID)
|
||||
}
|
||||
return models, nil
|
||||
}
|
||||
func (f *fakeHostAdapter) AssignSubscription(_ context.Context, req sub2api.AssignSubscriptionRequest) (sub2api.SubscriptionRef, error) {
|
||||
if f.assignErr != nil {
|
||||
return sub2api.SubscriptionRef{}, f.assignErr
|
||||
}
|
||||
f.assignedSubscriptions = append(f.assignedSubscriptions, req)
|
||||
return sub2api.SubscriptionRef{ID: "subscription_1"}, nil
|
||||
}
|
||||
func (f *fakeHostAdapter) CheckGatewayAccess(_ context.Context, req sub2api.GatewayAccessCheckRequest) (sub2api.GatewayAccessResult, error) {
|
||||
f.gatewayProbe = req
|
||||
if f.gatewayErr != nil {
|
||||
return sub2api.GatewayAccessResult{}, f.gatewayErr
|
||||
}
|
||||
return f.gatewayResult, nil
|
||||
}
|
||||
func (f *fakeHostAdapter) ListManagedResources(context.Context, sub2api.ListManagedResourcesRequest) (sub2api.ManagedResourceSnapshot, error) {
|
||||
return f.managedSnapshot, nil
|
||||
}
|
||||
40
internal/provision/naming.go
Normal file
40
internal/provision/naming.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/pack"
|
||||
)
|
||||
|
||||
var nonSlugPattern = regexp.MustCompile(`[^a-z0-9]+`)
|
||||
|
||||
type ResourceNames struct {
|
||||
Group string
|
||||
Channel string
|
||||
Plan string
|
||||
}
|
||||
|
||||
func SuggestAccountNamePrefix(provider pack.ProviderManifest) string {
|
||||
return fmt.Sprintf("%s-", resourceSlug(provider.ProviderID))
|
||||
}
|
||||
|
||||
func SuggestResourceNames(provider pack.ProviderManifest) ResourceNames {
|
||||
slug := resourceSlug(provider.ProviderID)
|
||||
return ResourceNames{
|
||||
Group: fmt.Sprintf("crm-%s-group", slug),
|
||||
Channel: fmt.Sprintf("crm-%s-channel", slug),
|
||||
Plan: fmt.Sprintf("crm-%s-plan", slug),
|
||||
}
|
||||
}
|
||||
|
||||
func resourceSlug(raw string) string {
|
||||
slug := strings.ToLower(strings.TrimSpace(raw))
|
||||
slug = nonSlugPattern.ReplaceAllString(slug, "-")
|
||||
slug = strings.Trim(slug, "-")
|
||||
if slug == "" {
|
||||
return "provider"
|
||||
}
|
||||
return slug
|
||||
}
|
||||
178
internal/provision/pack_install_service.go
Normal file
178
internal/provision/pack_install_service.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
packdef "sub2api-cn-relay-manager/internal/pack"
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
type PackInstallRequest struct {
|
||||
Pack packdef.LoadedPack
|
||||
}
|
||||
|
||||
type PackInstallResult struct {
|
||||
Pack sqlite.Pack
|
||||
Providers []sqlite.Provider
|
||||
HostVersion string
|
||||
AlreadyInstalled bool
|
||||
}
|
||||
|
||||
type PackInstallService struct {
|
||||
store *sqlite.DB
|
||||
host sub2api.HostAdapter
|
||||
}
|
||||
|
||||
func NewPackInstallService(store *sqlite.DB, host sub2api.HostAdapter) *PackInstallService {
|
||||
return &PackInstallService{store: store, host: host}
|
||||
}
|
||||
|
||||
func (s *PackInstallService) Install(ctx context.Context, req PackInstallRequest) (PackInstallResult, error) {
|
||||
if s == nil || s.store == nil {
|
||||
return PackInstallResult{}, fmt.Errorf("store is required")
|
||||
}
|
||||
if s.host == nil {
|
||||
return PackInstallResult{}, fmt.Errorf("host adapter is required")
|
||||
}
|
||||
if strings.TrimSpace(req.Pack.Manifest.PackID) == "" {
|
||||
return PackInstallResult{}, fmt.Errorf("pack manifest is required")
|
||||
}
|
||||
|
||||
hostVersion, err := s.host.GetHostVersion(ctx)
|
||||
if err != nil {
|
||||
return PackInstallResult{}, fmt.Errorf("get host version: %w", err)
|
||||
}
|
||||
if err := packdef.CheckHostCompatibility(req.Pack.Manifest, hostVersion); err != nil {
|
||||
return PackInstallResult{}, err
|
||||
}
|
||||
|
||||
result := PackInstallResult{HostVersion: hostVersion}
|
||||
if err := s.store.WithTx(ctx, func(queries *sqlite.Queries) error {
|
||||
existing, err := queries.Packs.GetByPackID(ctx, req.Pack.Manifest.PackID)
|
||||
if err == nil {
|
||||
if err := validateExistingPack(existing, req.Pack); err != nil {
|
||||
return err
|
||||
}
|
||||
result.AlreadyInstalled = true
|
||||
} else if !errors.Is(err, sql.ErrNoRows) {
|
||||
return err
|
||||
}
|
||||
|
||||
packRow, err := buildPackRecord(req.Pack)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := queries.Packs.Upsert(ctx, packRow); err != nil {
|
||||
return err
|
||||
}
|
||||
persistedPack, err := queries.Packs.GetByPackID(ctx, req.Pack.Manifest.PackID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result.Pack = persistedPack
|
||||
|
||||
providers := make([]sqlite.Provider, 0, len(req.Pack.Providers))
|
||||
for _, providerManifest := range req.Pack.Providers {
|
||||
providerRow, err := buildProviderRecord(persistedPack.ID, providerManifest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := queries.Providers.Upsert(ctx, providerRow); err != nil {
|
||||
return err
|
||||
}
|
||||
persistedProvider, err := queries.Providers.GetByPackIDAndProviderID(ctx, persistedPack.ID, providerManifest.ProviderID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
providers = append(providers, persistedProvider)
|
||||
}
|
||||
result.Providers = providers
|
||||
return nil
|
||||
}); err != nil {
|
||||
return PackInstallResult{}, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func validateExistingPack(existing sqlite.Pack, loaded packdef.LoadedPack) error {
|
||||
if strings.TrimSpace(existing.Version) != strings.TrimSpace(loaded.Manifest.Version) {
|
||||
return fmt.Errorf("pack %q already installed with version %q; upgrade lifecycle not implemented", existing.PackID, existing.Version)
|
||||
}
|
||||
if strings.TrimSpace(existing.Checksum) != strings.TrimSpace(loaded.Checksum) {
|
||||
return fmt.Errorf("pack %q version %q checksum drift detected", existing.PackID, existing.Version)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildPackRecord(loaded packdef.LoadedPack) (sqlite.Pack, error) {
|
||||
manifestJSON, err := json.Marshal(loaded.Manifest)
|
||||
if err != nil {
|
||||
return sqlite.Pack{}, fmt.Errorf("marshal pack manifest: %w", err)
|
||||
}
|
||||
return sqlite.Pack{
|
||||
PackID: loaded.Manifest.PackID,
|
||||
Version: loaded.Manifest.Version,
|
||||
Checksum: loaded.Checksum,
|
||||
Vendor: loaded.Manifest.Vendor,
|
||||
TargetHost: loaded.Manifest.TargetHost,
|
||||
MinHostVersion: loaded.Manifest.MinHostVersion,
|
||||
MaxHostVersion: loaded.Manifest.MaxHostVersion,
|
||||
ManifestJSON: string(manifestJSON),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildProviderRecord(packID int64, provider packdef.ProviderManifest) (sqlite.Provider, error) {
|
||||
defaultModelsJSON, err := marshalJSONString(provider.DefaultModels)
|
||||
if err != nil {
|
||||
return sqlite.Provider{}, fmt.Errorf("marshal provider default models: %w", err)
|
||||
}
|
||||
groupTemplateJSON, err := marshalJSONString(provider.GroupTemplate)
|
||||
if err != nil {
|
||||
return sqlite.Provider{}, fmt.Errorf("marshal group template: %w", err)
|
||||
}
|
||||
channelTemplateJSON, err := marshalJSONString(provider.ChannelTemplate)
|
||||
if err != nil {
|
||||
return sqlite.Provider{}, fmt.Errorf("marshal channel template: %w", err)
|
||||
}
|
||||
planTemplateJSON, err := marshalJSONString(provider.PlanTemplate)
|
||||
if err != nil {
|
||||
return sqlite.Provider{}, fmt.Errorf("marshal plan template: %w", err)
|
||||
}
|
||||
importOptionsJSON, err := marshalJSONString(provider.Import)
|
||||
if err != nil {
|
||||
return sqlite.Provider{}, fmt.Errorf("marshal import options: %w", err)
|
||||
}
|
||||
manifestJSON, err := marshalJSONString(provider)
|
||||
if err != nil {
|
||||
return sqlite.Provider{}, fmt.Errorf("marshal provider manifest: %w", err)
|
||||
}
|
||||
return sqlite.Provider{
|
||||
PackID: packID,
|
||||
ProviderID: provider.ProviderID,
|
||||
DisplayName: provider.DisplayName,
|
||||
BaseURL: provider.BaseURL,
|
||||
Platform: provider.Platform,
|
||||
AccountType: provider.AccountType,
|
||||
DefaultModelsJSON: defaultModelsJSON,
|
||||
SmokeTestModel: provider.SmokeTestModel,
|
||||
GroupTemplateJSON: groupTemplateJSON,
|
||||
ChannelTemplateJSON: channelTemplateJSON,
|
||||
PlanTemplateJSON: planTemplateJSON,
|
||||
ImportOptionsJSON: importOptionsJSON,
|
||||
ManifestJSON: manifestJSON,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func marshalJSONString(value any) (string, error) {
|
||||
body, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(body), nil
|
||||
}
|
||||
120
internal/provision/pack_install_service_test.go
Normal file
120
internal/provision/pack_install_service_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/pack"
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
func TestPackInstallServiceInstallPersistsPackAndProviders(t *testing.T) {
|
||||
store := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, store)
|
||||
|
||||
host := &fakeHostAdapter{}
|
||||
loaded := sampleLoadedPack()
|
||||
|
||||
svc := NewPackInstallService(store, host)
|
||||
result, err := svc.Install(context.Background(), PackInstallRequest{Pack: loaded})
|
||||
if err != nil {
|
||||
t.Fatalf("Install() error = %v", err)
|
||||
}
|
||||
if result.HostVersion != "0.1.126" {
|
||||
t.Fatalf("HostVersion = %q, want 0.1.126", result.HostVersion)
|
||||
}
|
||||
if result.AlreadyInstalled {
|
||||
t.Fatal("AlreadyInstalled = true, want false on first install")
|
||||
}
|
||||
if result.Pack.PackID != loaded.Manifest.PackID {
|
||||
t.Fatalf("Pack.PackID = %q, want %q", result.Pack.PackID, loaded.Manifest.PackID)
|
||||
}
|
||||
if len(result.Providers) != 1 || result.Providers[0].ProviderID != loaded.Providers[0].ProviderID {
|
||||
t.Fatalf("Providers = %#v, want one persisted provider", result.Providers)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "packs"); got != 1 {
|
||||
t.Fatalf("packs row count = %d, want 1", got)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "providers"); got != 1 {
|
||||
t.Fatalf("providers row count = %d, want 1", got)
|
||||
}
|
||||
|
||||
repeat, err := svc.Install(context.Background(), PackInstallRequest{Pack: loaded})
|
||||
if err != nil {
|
||||
t.Fatalf("second Install() error = %v", err)
|
||||
}
|
||||
if !repeat.AlreadyInstalled {
|
||||
t.Fatal("AlreadyInstalled = false, want true on re-install")
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "packs"); got != 1 {
|
||||
t.Fatalf("packs row count after re-install = %d, want 1", got)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "providers"); got != 1 {
|
||||
t.Fatalf("providers row count after re-install = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPackInstallServiceInstallRejectsVersionAndChecksumDrift(t *testing.T) {
|
||||
store := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, store)
|
||||
|
||||
svc := NewPackInstallService(store, &fakeHostAdapter{})
|
||||
loaded := sampleLoadedPack()
|
||||
if _, err := svc.Install(context.Background(), PackInstallRequest{Pack: loaded}); err != nil {
|
||||
t.Fatalf("initial Install() error = %v", err)
|
||||
}
|
||||
|
||||
versionDrift := sampleLoadedPack()
|
||||
versionDrift.Manifest.Version = "2.0.0"
|
||||
if _, err := svc.Install(context.Background(), PackInstallRequest{Pack: versionDrift}); err == nil || !strings.Contains(err.Error(), "upgrade lifecycle not implemented") {
|
||||
t.Fatalf("Install() version drift error = %v, want upgrade lifecycle error", err)
|
||||
}
|
||||
|
||||
checksumDrift := sampleLoadedPack()
|
||||
checksumDrift.Checksum = "checksum-2"
|
||||
if _, err := svc.Install(context.Background(), PackInstallRequest{Pack: checksumDrift}); err == nil || !strings.Contains(err.Error(), "checksum drift detected") {
|
||||
t.Fatalf("Install() checksum drift error = %v, want checksum drift error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPackInstallServiceInstallValidatesDependencies(t *testing.T) {
|
||||
loaded := sampleLoadedPack()
|
||||
storeWithoutHost := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, storeWithoutHost)
|
||||
storeWithoutPack := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, storeWithoutPack)
|
||||
|
||||
if _, err := (*PackInstallService)(nil).Install(context.Background(), PackInstallRequest{Pack: loaded}); err == nil || err.Error() != "store is required" {
|
||||
t.Fatalf("nil service Install() error = %v, want store is required", err)
|
||||
}
|
||||
if _, err := (&PackInstallService{store: storeWithoutHost}).Install(context.Background(), PackInstallRequest{Pack: loaded}); err == nil || err.Error() != "host adapter is required" {
|
||||
t.Fatalf("missing host Install() error = %v, want host adapter is required", err)
|
||||
}
|
||||
if _, err := NewPackInstallService(storeWithoutPack, &fakeHostAdapter{}).Install(context.Background(), PackInstallRequest{}); err == nil || err.Error() != "pack manifest is required" {
|
||||
t.Fatalf("missing pack Install() error = %v, want pack manifest is required", err)
|
||||
}
|
||||
}
|
||||
|
||||
func sampleLoadedPack() pack.LoadedPack {
|
||||
provider := sampleProviderManifest()
|
||||
return pack.LoadedPack{
|
||||
Manifest: pack.Manifest{
|
||||
PackID: "openai-cn-pack",
|
||||
Version: "1.0.0",
|
||||
Vendor: "nous",
|
||||
TargetHost: "sub2api",
|
||||
MinHostVersion: "0.1.126",
|
||||
MaxHostVersion: "0.2.x",
|
||||
},
|
||||
Providers: []pack.ProviderManifest{provider},
|
||||
Checksum: "checksum-1",
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateExistingPack(t *testing.T) {
|
||||
existing := sqlite.Pack{PackID: "openai-cn-pack", Version: "1.0.0", Checksum: "checksum-1"}
|
||||
if err := validateExistingPack(existing, sampleLoadedPack()); err != nil {
|
||||
t.Fatalf("validateExistingPack() error = %v, want nil", err)
|
||||
}
|
||||
}
|
||||
90
internal/provision/preview_service.go
Normal file
90
internal/provision/preview_service.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
"sub2api-cn-relay-manager/internal/pack"
|
||||
)
|
||||
|
||||
const (
|
||||
PreviewActionCreate = "create"
|
||||
PreviewActionReuse = "reuse"
|
||||
PreviewActionConflict = "conflict"
|
||||
)
|
||||
|
||||
type previewHost interface {
|
||||
ListManagedResources(ctx context.Context, req sub2api.ListManagedResourcesRequest) (sub2api.ManagedResourceSnapshot, error)
|
||||
}
|
||||
|
||||
type PreviewRequest struct {
|
||||
Provider pack.ProviderManifest
|
||||
Mode string
|
||||
Keys []string
|
||||
}
|
||||
|
||||
type PreviewDecision struct {
|
||||
Action string
|
||||
Suggested string
|
||||
ExistingID string
|
||||
Reason string
|
||||
}
|
||||
|
||||
type PreviewReport struct {
|
||||
AcceptedKeys []string
|
||||
Names ResourceNames
|
||||
Decisions map[string]PreviewDecision
|
||||
}
|
||||
|
||||
type PreviewService struct {
|
||||
host previewHost
|
||||
}
|
||||
|
||||
func NewPreviewService(host previewHost) *PreviewService {
|
||||
return &PreviewService{host: host}
|
||||
}
|
||||
|
||||
func (s *PreviewService) PreviewImport(ctx context.Context, req PreviewRequest) (PreviewReport, error) {
|
||||
acceptedKeys, err := normalizeKeys(req.Keys)
|
||||
if err != nil {
|
||||
return PreviewReport{}, err
|
||||
}
|
||||
if err := validateMode(req.Mode); err != nil {
|
||||
return PreviewReport{}, err
|
||||
}
|
||||
if s.host == nil {
|
||||
return PreviewReport{}, fmt.Errorf("preview host is required")
|
||||
}
|
||||
|
||||
names := SuggestResourceNames(req.Provider)
|
||||
snapshot, err := s.host.ListManagedResources(ctx, sub2api.ListManagedResourcesRequest{
|
||||
GroupName: names.Group,
|
||||
ChannelName: names.Channel,
|
||||
PlanName: names.Plan,
|
||||
})
|
||||
if err != nil {
|
||||
return PreviewReport{}, fmt.Errorf("list managed resources: %w", err)
|
||||
}
|
||||
|
||||
return PreviewReport{
|
||||
AcceptedKeys: acceptedKeys,
|
||||
Names: names,
|
||||
Decisions: map[string]PreviewDecision{
|
||||
"group": decideResource(names.Group, snapshot.Groups),
|
||||
"channel": decideResource(names.Channel, snapshot.Channels),
|
||||
"plan": decideResource(names.Plan, snapshot.Plans),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func decideResource(suggested string, existing []sub2api.NamedResource) PreviewDecision {
|
||||
switch len(existing) {
|
||||
case 0:
|
||||
return PreviewDecision{Action: PreviewActionCreate, Suggested: suggested}
|
||||
case 1:
|
||||
return PreviewDecision{Action: PreviewActionReuse, Suggested: suggested, ExistingID: existing[0].ID, Reason: "matching managed resource already exists"}
|
||||
default:
|
||||
return PreviewDecision{Action: PreviewActionConflict, Suggested: suggested, Reason: "multiple managed resources share the suggested name"}
|
||||
}
|
||||
}
|
||||
87
internal/provision/preview_service_test.go
Normal file
87
internal/provision/preview_service_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
)
|
||||
|
||||
func TestSuggestResourceNames(t *testing.T) {
|
||||
provider := sampleProviderManifest()
|
||||
|
||||
names := SuggestResourceNames(provider)
|
||||
|
||||
want := ResourceNames{
|
||||
Group: "crm-deepseek-group",
|
||||
Channel: "crm-deepseek-channel",
|
||||
Plan: "crm-deepseek-plan",
|
||||
}
|
||||
if !reflect.DeepEqual(names, want) {
|
||||
t.Fatalf("SuggestResourceNames() = %#v, want %#v", names, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreviewServiceReportsCreateActionsWhenHostHasNoResources(t *testing.T) {
|
||||
host := &fakePreviewHost{}
|
||||
svc := NewPreviewService(host)
|
||||
|
||||
report, err := svc.PreviewImport(context.Background(), PreviewRequest{
|
||||
Provider: sampleProviderManifest(),
|
||||
Mode: ImportModeStrict,
|
||||
Keys: []string{" key-1 ", "key-2", "key-1"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("PreviewImport() error = %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(report.AcceptedKeys, []string{"key-1", "key-2"}) {
|
||||
t.Fatalf("AcceptedKeys = %#v, want normalized deduped keys", report.AcceptedKeys)
|
||||
}
|
||||
if got := report.Decisions["group"]; got.Action != PreviewActionCreate {
|
||||
t.Fatalf("group action = %q, want %q", got.Action, PreviewActionCreate)
|
||||
}
|
||||
if got := report.Decisions["channel"]; got.Action != PreviewActionCreate {
|
||||
t.Fatalf("channel action = %q, want %q", got.Action, PreviewActionCreate)
|
||||
}
|
||||
if got := report.Decisions["plan"]; got.Action != PreviewActionCreate {
|
||||
t.Fatalf("plan action = %q, want %q", got.Action, PreviewActionCreate)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreviewServiceReportsReuseAndConflict(t *testing.T) {
|
||||
host := &fakePreviewHost{snapshot: sub2api.ManagedResourceSnapshot{
|
||||
Groups: []sub2api.NamedResource{{ID: "group_1", Name: "crm-deepseek-group"}},
|
||||
Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "crm-deepseek-channel"}, {ID: "channel_2", Name: "crm-deepseek-channel"}},
|
||||
Plans: []sub2api.NamedResource{{ID: "plan_1", Name: "crm-deepseek-plan"}},
|
||||
}}
|
||||
svc := NewPreviewService(host)
|
||||
|
||||
report, err := svc.PreviewImport(context.Background(), PreviewRequest{
|
||||
Provider: sampleProviderManifest(),
|
||||
Mode: ImportModePartial,
|
||||
Keys: []string{"key-1"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("PreviewImport() error = %v", err)
|
||||
}
|
||||
|
||||
if got := report.Decisions["group"]; got.Action != PreviewActionReuse || got.ExistingID != "group_1" {
|
||||
t.Fatalf("group decision = %#v, want reuse group_1", got)
|
||||
}
|
||||
if got := report.Decisions["plan"]; got.Action != PreviewActionReuse || got.ExistingID != "plan_1" {
|
||||
t.Fatalf("plan decision = %#v, want reuse plan_1", got)
|
||||
}
|
||||
if got := report.Decisions["channel"]; got.Action != PreviewActionConflict {
|
||||
t.Fatalf("channel decision = %#v, want conflict", got)
|
||||
}
|
||||
}
|
||||
|
||||
type fakePreviewHost struct {
|
||||
snapshot sub2api.ManagedResourceSnapshot
|
||||
}
|
||||
|
||||
func (f *fakePreviewHost) ListManagedResources(context.Context, sub2api.ListManagedResourcesRequest) (sub2api.ManagedResourceSnapshot, error) {
|
||||
return f.snapshot, nil
|
||||
}
|
||||
150
internal/provision/provider_status_service.go
Normal file
150
internal/provision/provider_status_service.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
type ProviderQuery struct {
|
||||
ProviderID string
|
||||
PackID string
|
||||
}
|
||||
|
||||
type ProviderSnapshot struct {
|
||||
Host sqlite.Host
|
||||
Pack sqlite.Pack
|
||||
Provider sqlite.Provider
|
||||
Batch sqlite.ImportBatch
|
||||
ManagedResources []sqlite.ManagedResource
|
||||
AccessClosures []sqlite.AccessClosureRecord
|
||||
ReconcileRuns []sqlite.ReconcileRun
|
||||
ProviderStatus string
|
||||
LatestAccessStatus string
|
||||
LatestReconcileStatus string
|
||||
LatestReconcileSummary map[string]any
|
||||
}
|
||||
|
||||
type ProviderStatusService struct {
|
||||
store *sqlite.DB
|
||||
}
|
||||
|
||||
func NewProviderStatusService(store *sqlite.DB) *ProviderStatusService {
|
||||
return &ProviderStatusService{store: store}
|
||||
}
|
||||
|
||||
func (s *ProviderStatusService) GetStatus(ctx context.Context, query ProviderQuery) (ProviderSnapshot, error) {
|
||||
return s.snapshot(ctx, query)
|
||||
}
|
||||
|
||||
func (s *ProviderStatusService) GetResources(ctx context.Context, query ProviderQuery) (ProviderSnapshot, error) {
|
||||
return s.snapshot(ctx, query)
|
||||
}
|
||||
|
||||
func (s *ProviderStatusService) snapshot(ctx context.Context, query ProviderQuery) (ProviderSnapshot, error) {
|
||||
if s == nil || s.store == nil {
|
||||
return ProviderSnapshot{}, fmt.Errorf("store is required")
|
||||
}
|
||||
provider, err := s.resolveProvider(ctx, query)
|
||||
if err != nil {
|
||||
return ProviderSnapshot{}, err
|
||||
}
|
||||
packRow, err := s.store.Packs().GetByID(ctx, provider.PackID)
|
||||
if err != nil {
|
||||
return ProviderSnapshot{}, err
|
||||
}
|
||||
batchRow, err := s.store.ImportBatches().GetLatestByProviderID(ctx, provider.ID)
|
||||
if err != nil {
|
||||
return ProviderSnapshot{}, err
|
||||
}
|
||||
hostRow, err := s.store.Hosts().GetByID(ctx, batchRow.HostID)
|
||||
if err != nil {
|
||||
return ProviderSnapshot{}, err
|
||||
}
|
||||
managedResources, err := s.store.ManagedResources().GetByBatchID(ctx, batchRow.ID)
|
||||
if err != nil {
|
||||
return ProviderSnapshot{}, err
|
||||
}
|
||||
accessClosures, err := s.store.AccessClosures().GetByBatchID(ctx, batchRow.ID)
|
||||
if err != nil {
|
||||
return ProviderSnapshot{}, err
|
||||
}
|
||||
reconcileRuns, err := s.store.ReconcileRuns().GetByProviderID(ctx, provider.ID)
|
||||
if err != nil {
|
||||
return ProviderSnapshot{}, err
|
||||
}
|
||||
latestAccessStatus := batchRow.AccessStatus
|
||||
if len(accessClosures) > 0 {
|
||||
latestAccessStatus = firstNonEmpty(accessClosures[len(accessClosures)-1].Status, latestAccessStatus)
|
||||
}
|
||||
latestReconcileStatus := "not_run"
|
||||
latestReconcileSummary := map[string]any{}
|
||||
if len(reconcileRuns) > 0 {
|
||||
latestReconcileStatus = firstNonEmpty(reconcileRuns[0].Status, latestReconcileStatus)
|
||||
if strings.TrimSpace(reconcileRuns[0].SummaryJSON) != "" {
|
||||
if err := json.Unmarshal([]byte(reconcileRuns[0].SummaryJSON), &latestReconcileSummary); err != nil {
|
||||
return ProviderSnapshot{}, fmt.Errorf("decode reconcile summary: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
providerStatus := deriveProviderStatus(batchRow.BatchStatus, latestReconcileStatus)
|
||||
return ProviderSnapshot{
|
||||
Host: hostRow,
|
||||
Pack: packRow,
|
||||
Provider: provider,
|
||||
Batch: batchRow,
|
||||
ManagedResources: managedResources,
|
||||
AccessClosures: accessClosures,
|
||||
ReconcileRuns: reconcileRuns,
|
||||
ProviderStatus: providerStatus,
|
||||
LatestAccessStatus: latestAccessStatus,
|
||||
LatestReconcileStatus: latestReconcileStatus,
|
||||
LatestReconcileSummary: latestReconcileSummary,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *ProviderStatusService) resolveProvider(ctx context.Context, query ProviderQuery) (sqlite.Provider, error) {
|
||||
providerID := strings.TrimSpace(query.ProviderID)
|
||||
packID := strings.TrimSpace(query.PackID)
|
||||
if providerID == "" {
|
||||
return sqlite.Provider{}, fmt.Errorf("provider_id is required")
|
||||
}
|
||||
if packID != "" {
|
||||
packRow, err := s.store.Packs().GetByPackID(ctx, packID)
|
||||
if err != nil {
|
||||
return sqlite.Provider{}, err
|
||||
}
|
||||
return s.store.Providers().GetByPackIDAndProviderID(ctx, packRow.ID, providerID)
|
||||
}
|
||||
providers, err := s.store.Providers().ListByProviderID(ctx, providerID)
|
||||
if err != nil {
|
||||
return sqlite.Provider{}, err
|
||||
}
|
||||
if len(providers) == 0 {
|
||||
return sqlite.Provider{}, fmt.Errorf("provider %q not found", providerID)
|
||||
}
|
||||
if len(providers) > 1 {
|
||||
return sqlite.Provider{}, fmt.Errorf("provider %q exists in multiple packs; pack_id is required", providerID)
|
||||
}
|
||||
return providers[0], nil
|
||||
}
|
||||
|
||||
func deriveProviderStatus(batchStatus, reconcileStatus string) string {
|
||||
reconcileStatus = strings.TrimSpace(reconcileStatus)
|
||||
if reconcileStatus != "" && reconcileStatus != "not_run" {
|
||||
return reconcileStatus
|
||||
}
|
||||
switch strings.TrimSpace(batchStatus) {
|
||||
case BatchStatusSucceeded:
|
||||
return ProviderStatusActive
|
||||
case BatchStatusFailed:
|
||||
return ProviderStatusFailed
|
||||
case "running":
|
||||
return "running"
|
||||
default:
|
||||
return firstNonEmpty(batchStatus, "unknown")
|
||||
}
|
||||
}
|
||||
98
internal/provision/provider_status_service_test.go
Normal file
98
internal/provision/provider_status_service_test.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
func TestProviderStatusServiceReturnsLatestSnapshot(t *testing.T) {
|
||||
store := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, store)
|
||||
|
||||
ctx := context.Background()
|
||||
hostID, err := store.Hosts().Create(ctx, sqlite.Host{HostID: "host-1", BaseURL: "https://sub2api.example.com", HostVersion: "0.1.126", CapabilityProbeJSON: `{"supports_batch_accounts":true}`})
|
||||
if err != nil {
|
||||
t.Fatalf("Hosts().Create() error = %v", err)
|
||||
}
|
||||
packID, err := store.Packs().Create(ctx, sqlite.Pack{PackID: "openai-cn-pack", Version: "1.0.0", Checksum: "checksum-1"})
|
||||
if err != nil {
|
||||
t.Fatalf("Packs().Create() error = %v", err)
|
||||
}
|
||||
providerID, err := store.Providers().Create(ctx, sqlite.Provider{PackID: packID, ProviderID: "deepseek", DisplayName: "DeepSeek", BaseURL: "https://api.deepseek.com", Platform: "openai"})
|
||||
if err != nil {
|
||||
t.Fatalf("Providers().Create() error = %v", err)
|
||||
}
|
||||
batchID, err := store.ImportBatches().Create(ctx, sqlite.ImportBatch{HostID: hostID, PackID: packID, ProviderID: providerID, Mode: ImportModeStrict, BatchStatus: BatchStatusSucceeded, AccessStatus: AccessStatusSelfServiceReady})
|
||||
if err != nil {
|
||||
t.Fatalf("ImportBatches().Create() error = %v", err)
|
||||
}
|
||||
if _, err := store.ManagedResources().Create(ctx, sqlite.ManagedResource{BatchID: batchID, ResourceType: "group", HostResourceID: "group-1", ResourceName: "deepseek-group"}); err != nil {
|
||||
t.Fatalf("ManagedResources().Create(group) error = %v", err)
|
||||
}
|
||||
if _, err := store.ManagedResources().Create(ctx, sqlite.ManagedResource{BatchID: batchID, ResourceType: "account", HostResourceID: "account-1", ResourceName: "deepseek-account-1"}); err != nil {
|
||||
t.Fatalf("ManagedResources().Create(account) error = %v", err)
|
||||
}
|
||||
if _, err := store.AccessClosures().Create(ctx, sqlite.AccessClosureRecord{BatchID: batchID, ClosureType: AccessModeSelfService, Status: AccessStatusSelfServiceReady, DetailsJSON: `{"ok":true}`}); err != nil {
|
||||
t.Fatalf("AccessClosures().Create() error = %v", err)
|
||||
}
|
||||
if _, err := store.ReconcileRuns().Create(ctx, sqlite.ReconcileRun{ProviderID: providerID, Status: "drifted", SummaryJSON: `{"missing_count":1}`}); err != nil {
|
||||
t.Fatalf("ReconcileRuns().Create() error = %v", err)
|
||||
}
|
||||
|
||||
snapshot, err := NewProviderStatusService(store).GetStatus(ctx, ProviderQuery{ProviderID: "deepseek"})
|
||||
if err != nil {
|
||||
t.Fatalf("GetStatus() error = %v", err)
|
||||
}
|
||||
if snapshot.Host.HostID != "host-1" {
|
||||
t.Fatalf("Host.HostID = %q, want host-1", snapshot.Host.HostID)
|
||||
}
|
||||
if snapshot.Pack.PackID != "openai-cn-pack" {
|
||||
t.Fatalf("Pack.PackID = %q, want openai-cn-pack", snapshot.Pack.PackID)
|
||||
}
|
||||
if snapshot.Provider.ProviderID != "deepseek" {
|
||||
t.Fatalf("Provider.ProviderID = %q, want deepseek", snapshot.Provider.ProviderID)
|
||||
}
|
||||
if snapshot.ProviderStatus != "drifted" {
|
||||
t.Fatalf("ProviderStatus = %q, want drifted", snapshot.ProviderStatus)
|
||||
}
|
||||
if snapshot.LatestAccessStatus != AccessStatusSelfServiceReady {
|
||||
t.Fatalf("LatestAccessStatus = %q, want %q", snapshot.LatestAccessStatus, AccessStatusSelfServiceReady)
|
||||
}
|
||||
if snapshot.LatestReconcileStatus != "drifted" {
|
||||
t.Fatalf("LatestReconcileStatus = %q, want drifted", snapshot.LatestReconcileStatus)
|
||||
}
|
||||
if got := len(snapshot.ManagedResources); got != 2 {
|
||||
t.Fatalf("len(ManagedResources) = %d, want 2", got)
|
||||
}
|
||||
if got, ok := snapshot.LatestReconcileSummary["missing_count"].(float64); !ok || got != 1 {
|
||||
t.Fatalf("LatestReconcileSummary[missing_count] = %#v, want 1", snapshot.LatestReconcileSummary["missing_count"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderStatusServiceRequiresPackIDWhenProviderIDIsAmbiguous(t *testing.T) {
|
||||
store := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, store)
|
||||
|
||||
ctx := context.Background()
|
||||
pack1, err := store.Packs().Create(ctx, sqlite.Pack{PackID: "pack-a", Version: "1.0.0", Checksum: "checksum-a"})
|
||||
if err != nil {
|
||||
t.Fatalf("Packs().Create(pack-a) error = %v", err)
|
||||
}
|
||||
pack2, err := store.Packs().Create(ctx, sqlite.Pack{PackID: "pack-b", Version: "1.0.0", Checksum: "checksum-b"})
|
||||
if err != nil {
|
||||
t.Fatalf("Packs().Create(pack-b) error = %v", err)
|
||||
}
|
||||
if _, err := store.Providers().Create(ctx, sqlite.Provider{PackID: pack1, ProviderID: "deepseek", DisplayName: "DeepSeek A", BaseURL: "https://a.example.com", Platform: "openai"}); err != nil {
|
||||
t.Fatalf("Providers().Create(pack-a) error = %v", err)
|
||||
}
|
||||
if _, err := store.Providers().Create(ctx, sqlite.Provider{PackID: pack2, ProviderID: "deepseek", DisplayName: "DeepSeek B", BaseURL: "https://b.example.com", Platform: "openai"}); err != nil {
|
||||
t.Fatalf("Providers().Create(pack-b) error = %v", err)
|
||||
}
|
||||
|
||||
_, err = NewProviderStatusService(store).GetStatus(ctx, ProviderQuery{ProviderID: "deepseek"})
|
||||
if err == nil || err.Error() != `provider "deepseek" exists in multiple packs; pack_id is required` {
|
||||
t.Fatalf("GetStatus() error = %v, want ambiguous provider error", err)
|
||||
}
|
||||
}
|
||||
187
internal/provision/reconcile_service_test.go
Normal file
187
internal/provision/reconcile_service_test.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
"sub2api-cn-relay-manager/internal/pack"
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
func TestReconcileServiceReturnsActiveAfterProbeRerun(t *testing.T) {
|
||||
store := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, store)
|
||||
|
||||
host := &fakeHostAdapter{
|
||||
batchAccounts: []sub2api.AccountRef{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
|
||||
testResults: map[string]sub2api.ProbeResult{
|
||||
"account_1": {OK: true, Status: "passed"},
|
||||
"account_2": {OK: true, Status: "passed"},
|
||||
},
|
||||
models: map[string][]sub2api.AccountModel{
|
||||
"account_1": {{ID: "deepseek-chat"}},
|
||||
"account_2": {{ID: "deepseek-chat"}},
|
||||
},
|
||||
gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}},
|
||||
}
|
||||
|
||||
batchID := seedRuntimeImportForReconcile(t, store, host)
|
||||
host.managedSnapshot = sub2api.ManagedResourceSnapshot{
|
||||
Groups: []sub2api.NamedResource{{ID: "group_1", Name: "g"}},
|
||||
Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "c"}},
|
||||
Accounts: []sub2api.NamedResource{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
|
||||
}
|
||||
|
||||
result, err := NewReconcileService(store, host).Reconcile(context.Background(), ReconcileRequest{
|
||||
HostBaseURL: "https://sub2api.example.com",
|
||||
AccessProbeAPIKey: "user-key",
|
||||
Pack: pack.LoadedPack{Manifest: pack.Manifest{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api", MinHostVersion: "0.1.126", MaxHostVersion: "0.2.x"}},
|
||||
Provider: sampleProviderManifest(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Reconcile() error = %v", err)
|
||||
}
|
||||
if result.BatchID != batchID {
|
||||
t.Fatalf("BatchID = %d, want %d", result.BatchID, batchID)
|
||||
}
|
||||
if result.Status != "active" {
|
||||
t.Fatalf("Status = %q, want active", result.Status)
|
||||
}
|
||||
if result.ProbeFailureCount != 0 {
|
||||
t.Fatalf("ProbeFailureCount = %d, want 0", result.ProbeFailureCount)
|
||||
}
|
||||
if result.AccessStatus != AccessStatusSelfServiceReady {
|
||||
t.Fatalf("AccessStatus = %q, want %q", result.AccessStatus, AccessStatusSelfServiceReady)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "probe_results"); got != 4 {
|
||||
t.Fatalf("probe_results row count = %d, want 4 after rerun", got)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "access_closure_records"); got != 2 {
|
||||
t.Fatalf("access_closure_records row count = %d, want 2 after rerun", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconcileServiceReturnsDegradedWhenProbeRerunFails(t *testing.T) {
|
||||
store := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, store)
|
||||
|
||||
host := &fakeHostAdapter{
|
||||
batchAccounts: []sub2api.AccountRef{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
|
||||
testResults: map[string]sub2api.ProbeResult{
|
||||
"account_1": {OK: true, Status: "passed"},
|
||||
"account_2": {OK: false, Status: "failed", Message: "bad key"},
|
||||
},
|
||||
models: map[string][]sub2api.AccountModel{
|
||||
"account_1": {{ID: "deepseek-chat"}},
|
||||
"account_2": {{ID: "deepseek-chat"}},
|
||||
},
|
||||
gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}},
|
||||
}
|
||||
|
||||
seedRuntimeImportForReconcile(t, store, host)
|
||||
host.managedSnapshot = sub2api.ManagedResourceSnapshot{
|
||||
Groups: []sub2api.NamedResource{{ID: "group_1", Name: "g"}},
|
||||
Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "c"}},
|
||||
Accounts: []sub2api.NamedResource{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
|
||||
}
|
||||
|
||||
result, err := NewReconcileService(store, host).Reconcile(context.Background(), ReconcileRequest{
|
||||
HostBaseURL: "https://sub2api.example.com",
|
||||
AccessProbeAPIKey: "user-key",
|
||||
Pack: pack.LoadedPack{Manifest: pack.Manifest{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api", MinHostVersion: "0.1.126", MaxHostVersion: "0.2.x"}},
|
||||
Provider: sampleProviderManifest(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Reconcile() error = %v", err)
|
||||
}
|
||||
if result.Status != "degraded" {
|
||||
t.Fatalf("Status = %q, want degraded", result.Status)
|
||||
}
|
||||
if result.ProbeFailureCount != 1 {
|
||||
t.Fatalf("ProbeFailureCount = %d, want 1", result.ProbeFailureCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconcileServiceReturnsDriftedWhenManagedResourceMissing(t *testing.T) {
|
||||
store := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, store)
|
||||
|
||||
host := &fakeHostAdapter{
|
||||
batchAccounts: []sub2api.AccountRef{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
|
||||
testResults: map[string]sub2api.ProbeResult{
|
||||
"account_1": {OK: true, Status: "passed"},
|
||||
"account_2": {OK: true, Status: "passed"},
|
||||
},
|
||||
models: map[string][]sub2api.AccountModel{
|
||||
"account_1": {{ID: "deepseek-chat"}},
|
||||
"account_2": {{ID: "deepseek-chat"}},
|
||||
},
|
||||
gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}},
|
||||
}
|
||||
|
||||
seedRuntimeImportForReconcile(t, store, host)
|
||||
host.managedSnapshot = sub2api.ManagedResourceSnapshot{
|
||||
Groups: []sub2api.NamedResource{{ID: "group_1", Name: "g"}},
|
||||
Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "c"}},
|
||||
Accounts: []sub2api.NamedResource{{ID: "account_1", Name: "deepseek-01"}},
|
||||
}
|
||||
|
||||
result, err := NewReconcileService(store, host).Reconcile(context.Background(), ReconcileRequest{
|
||||
HostBaseURL: "https://sub2api.example.com",
|
||||
AccessProbeAPIKey: "user-key",
|
||||
Pack: pack.LoadedPack{Manifest: pack.Manifest{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api", MinHostVersion: "0.1.126", MaxHostVersion: "0.2.x"}},
|
||||
Provider: sampleProviderManifest(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Reconcile() error = %v", err)
|
||||
}
|
||||
if result.Status != "drifted" {
|
||||
t.Fatalf("Status = %q, want drifted", result.Status)
|
||||
}
|
||||
if result.MissingCount != 1 {
|
||||
t.Fatalf("MissingCount = %d, want 1", result.MissingCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveHealthyAccessStatus(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
closureType string
|
||||
want string
|
||||
}{
|
||||
{name: "subscription", closureType: AccessModeSubscription, want: AccessStatusSubscriptionReady},
|
||||
{name: "self-service", closureType: AccessModeSelfService, want: AccessStatusSelfServiceReady},
|
||||
{name: "unknown", closureType: "other", want: "unknown"},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := deriveHealthyAccessStatus(tc.closureType); got != tc.want {
|
||||
t.Fatalf("deriveHealthyAccessStatus(%q) = %q, want %q", tc.closureType, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func seedRuntimeImportForReconcile(t *testing.T, store *sqlite.DB, host *fakeHostAdapter) int64 {
|
||||
t.Helper()
|
||||
result, err := NewRuntimeImportService(store, host).Import(context.Background(), RuntimeImportRequest{
|
||||
HostID: "host-1",
|
||||
HostBaseURL: "https://sub2api.example.com",
|
||||
Pack: pack.LoadedPack{
|
||||
Manifest: pack.Manifest{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api", MinHostVersion: "0.1.126", MaxHostVersion: "0.2.x"},
|
||||
Checksum: "checksum-1",
|
||||
},
|
||||
Provider: sampleProviderManifest(),
|
||||
Mode: ImportModePartial,
|
||||
Keys: []string{"key-1", "key-2"},
|
||||
Access: AccessRequest{
|
||||
Mode: AccessModeSelfService,
|
||||
ProbeAPIKey: "user-key",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("seed RuntimeImportService.Import() error = %v", err)
|
||||
}
|
||||
return result.BatchID
|
||||
}
|
||||
90
internal/provision/rollback_service.go
Normal file
90
internal/provision/rollback_service.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
"sub2api-cn-relay-manager/internal/pack"
|
||||
)
|
||||
|
||||
type rollbackHost interface {
|
||||
ListManagedResources(ctx context.Context, req sub2api.ListManagedResourcesRequest) (sub2api.ManagedResourceSnapshot, error)
|
||||
DeleteAccount(ctx context.Context, accountID string) error
|
||||
DeletePlan(ctx context.Context, planID string) error
|
||||
DeleteChannel(ctx context.Context, channelID string) error
|
||||
DeleteGroup(ctx context.Context, groupID string) error
|
||||
}
|
||||
|
||||
type RollbackRequest struct {
|
||||
Provider pack.ProviderManifest
|
||||
}
|
||||
|
||||
type RollbackReport struct {
|
||||
AccountsDeleted int
|
||||
PlansDeleted int
|
||||
ChannelsDeleted int
|
||||
GroupsDeleted int
|
||||
}
|
||||
|
||||
type RollbackService struct {
|
||||
host rollbackHost
|
||||
}
|
||||
|
||||
func NewRollbackService(host rollbackHost) *RollbackService {
|
||||
return &RollbackService{host: host}
|
||||
}
|
||||
|
||||
func (s *RollbackService) Rollback(ctx context.Context, req RollbackRequest) (RollbackReport, error) {
|
||||
if s.host == nil {
|
||||
return RollbackReport{}, fmt.Errorf("rollback host is required")
|
||||
}
|
||||
|
||||
names := SuggestResourceNames(req.Provider)
|
||||
snapshot, err := s.host.ListManagedResources(ctx, sub2api.ListManagedResourcesRequest{
|
||||
GroupName: names.Group,
|
||||
ChannelName: names.Channel,
|
||||
PlanName: names.Plan,
|
||||
AccountNamePrefix: SuggestAccountNamePrefix(req.Provider),
|
||||
})
|
||||
if err != nil {
|
||||
return RollbackReport{}, fmt.Errorf("list managed resources: %w", err)
|
||||
}
|
||||
|
||||
var report RollbackReport
|
||||
var errs []error
|
||||
for index := len(snapshot.Accounts) - 1; index >= 0; index-- {
|
||||
if err := s.host.DeleteAccount(ctx, snapshot.Accounts[index].ID); err != nil {
|
||||
errs = append(errs, fmt.Errorf("delete account %s: %w", snapshot.Accounts[index].ID, err))
|
||||
continue
|
||||
}
|
||||
report.AccountsDeleted++
|
||||
}
|
||||
for index := len(snapshot.Plans) - 1; index >= 0; index-- {
|
||||
if err := s.host.DeletePlan(ctx, snapshot.Plans[index].ID); err != nil {
|
||||
errs = append(errs, fmt.Errorf("delete plan %s: %w", snapshot.Plans[index].ID, err))
|
||||
continue
|
||||
}
|
||||
report.PlansDeleted++
|
||||
}
|
||||
for index := len(snapshot.Channels) - 1; index >= 0; index-- {
|
||||
if err := s.host.DeleteChannel(ctx, snapshot.Channels[index].ID); err != nil {
|
||||
errs = append(errs, fmt.Errorf("delete channel %s: %w", snapshot.Channels[index].ID, err))
|
||||
continue
|
||||
}
|
||||
report.ChannelsDeleted++
|
||||
}
|
||||
for index := len(snapshot.Groups) - 1; index >= 0; index-- {
|
||||
if err := s.host.DeleteGroup(ctx, snapshot.Groups[index].ID); err != nil {
|
||||
errs = append(errs, fmt.Errorf("delete group %s: %w", snapshot.Groups[index].ID, err))
|
||||
continue
|
||||
}
|
||||
report.GroupsDeleted++
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return report, errors.Join(errs...)
|
||||
}
|
||||
return report, nil
|
||||
}
|
||||
50
internal/provision/rollback_service_test.go
Normal file
50
internal/provision/rollback_service_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
)
|
||||
|
||||
func TestRollbackServiceDeletesManagedResourcesInDependencyOrder(t *testing.T) {
|
||||
host := &fakeHostAdapter{
|
||||
managedSnapshot: sub2api.ManagedResourceSnapshot{
|
||||
Groups: []sub2api.NamedResource{{ID: "group_1", Name: "crm-deepseek-group"}},
|
||||
Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "crm-deepseek-channel"}},
|
||||
Plans: []sub2api.NamedResource{{ID: "plan_1", Name: "crm-deepseek-plan"}},
|
||||
Accounts: []sub2api.NamedResource{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewRollbackService(host)
|
||||
report, err := svc.Rollback(context.Background(), RollbackRequest{Provider: sampleProviderManifest()})
|
||||
if err != nil {
|
||||
t.Fatalf("Rollback() error = %v", err)
|
||||
}
|
||||
|
||||
if report.AccountsDeleted != 2 || report.PlansDeleted != 1 || report.ChannelsDeleted != 1 || report.GroupsDeleted != 1 {
|
||||
t.Fatalf("Rollback() report = %+v, want all managed resources deleted", report)
|
||||
}
|
||||
want := []string{"account:account_2", "account:account_1", "plan:plan_1", "channel:channel_1", "group:group_1"}
|
||||
if !reflect.DeepEqual(host.deletedResources, want) {
|
||||
t.Fatalf("deleted resources = %#v, want %#v", host.deletedResources, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRollbackServiceReturnsEmptyReportWhenNoManagedResourcesExist(t *testing.T) {
|
||||
host := &fakeHostAdapter{}
|
||||
svc := NewRollbackService(host)
|
||||
|
||||
report, err := svc.Rollback(context.Background(), RollbackRequest{Provider: sampleProviderManifest()})
|
||||
if err != nil {
|
||||
t.Fatalf("Rollback() error = %v", err)
|
||||
}
|
||||
if report.AccountsDeleted != 0 || report.PlansDeleted != 0 || report.ChannelsDeleted != 0 || report.GroupsDeleted != 0 {
|
||||
t.Fatalf("Rollback() report = %+v, want zero deletions", report)
|
||||
}
|
||||
if len(host.deletedResources) != 0 {
|
||||
t.Fatalf("deleted resources = %#v, want none", host.deletedResources)
|
||||
}
|
||||
}
|
||||
259
internal/provision/runtime_import_service.go
Normal file
259
internal/provision/runtime_import_service.go
Normal file
@@ -0,0 +1,259 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/pack"
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
type RuntimeImportRequest struct {
|
||||
HostID string
|
||||
HostBaseURL string
|
||||
Pack pack.LoadedPack
|
||||
Provider pack.ProviderManifest
|
||||
Mode string
|
||||
Access AccessRequest
|
||||
Keys []string
|
||||
}
|
||||
|
||||
type RuntimeImportResult struct {
|
||||
BatchID int64
|
||||
Report ImportReport
|
||||
}
|
||||
|
||||
type RuntimeImportService struct {
|
||||
store *sqlite.DB
|
||||
host hostAdapter
|
||||
}
|
||||
|
||||
func NewRuntimeImportService(store *sqlite.DB, host hostAdapter) *RuntimeImportService {
|
||||
return &RuntimeImportService{store: store, host: host}
|
||||
}
|
||||
|
||||
func (s *RuntimeImportService) Import(ctx context.Context, req RuntimeImportRequest) (RuntimeImportResult, error) {
|
||||
if s == nil || s.store == nil {
|
||||
return RuntimeImportResult{}, fmt.Errorf("store is required")
|
||||
}
|
||||
if s.host == nil {
|
||||
return RuntimeImportResult{}, fmt.Errorf("host adapter is required")
|
||||
}
|
||||
req.HostBaseURL = strings.TrimSpace(req.HostBaseURL)
|
||||
if req.HostBaseURL == "" {
|
||||
return RuntimeImportResult{}, fmt.Errorf("host_base_url is required")
|
||||
}
|
||||
if strings.TrimSpace(req.HostID) == "" {
|
||||
req.HostID = req.HostBaseURL
|
||||
}
|
||||
|
||||
hostVersion, err := s.host.GetHostVersion(ctx)
|
||||
if err != nil {
|
||||
return RuntimeImportResult{}, fmt.Errorf("get host version: %w", err)
|
||||
}
|
||||
if err := pack.CheckHostCompatibility(req.Pack.Manifest, hostVersion); err != nil {
|
||||
return RuntimeImportResult{}, err
|
||||
}
|
||||
capabilities, err := s.host.ProbeCapabilities(ctx)
|
||||
if err != nil {
|
||||
return RuntimeImportResult{}, fmt.Errorf("probe host capabilities: %w", err)
|
||||
}
|
||||
capabilityProbeJSON, err := json.Marshal(capabilities)
|
||||
if err != nil {
|
||||
return RuntimeImportResult{}, fmt.Errorf("marshal host capabilities: %w", err)
|
||||
}
|
||||
|
||||
hostRow, err := s.ensureHost(ctx, req.HostID, req.HostBaseURL, hostVersion, string(capabilityProbeJSON))
|
||||
if err != nil {
|
||||
return RuntimeImportResult{}, err
|
||||
}
|
||||
packRow, err := s.ensurePack(ctx, req.Pack)
|
||||
if err != nil {
|
||||
return RuntimeImportResult{}, err
|
||||
}
|
||||
providerRow, err := s.ensureProvider(ctx, packRow.ID, req.Provider)
|
||||
if err != nil {
|
||||
return RuntimeImportResult{}, err
|
||||
}
|
||||
|
||||
batchID, err := s.store.ImportBatches().Create(ctx, sqlite.ImportBatch{
|
||||
HostID: hostRow.ID,
|
||||
PackID: packRow.ID,
|
||||
ProviderID: providerRow.ID,
|
||||
Mode: req.Mode,
|
||||
BatchStatus: "running",
|
||||
AccessStatus: "pending",
|
||||
})
|
||||
if err != nil {
|
||||
return RuntimeImportResult{}, err
|
||||
}
|
||||
|
||||
report, importErr := NewImportService(s.host).Import(ctx, ImportRequest{
|
||||
Provider: req.Provider,
|
||||
Mode: req.Mode,
|
||||
Access: req.Access,
|
||||
Keys: req.Keys,
|
||||
})
|
||||
if report.BatchStatus == "" {
|
||||
report.BatchStatus = BatchStatusFailed
|
||||
}
|
||||
if report.AccessStatus == "" {
|
||||
report.AccessStatus = AccessStatusBroken
|
||||
}
|
||||
|
||||
if persistErr := s.persistRuntimeArtifacts(ctx, batchID, req.Access.Mode, report, importErr == nil); persistErr != nil {
|
||||
return RuntimeImportResult{}, persistErr
|
||||
}
|
||||
if err := s.store.ImportBatches().UpdateStatus(ctx, batchID, report.BatchStatus, report.AccessStatus); err != nil {
|
||||
return RuntimeImportResult{}, err
|
||||
}
|
||||
if importErr != nil {
|
||||
return RuntimeImportResult{BatchID: batchID, Report: report}, importErr
|
||||
}
|
||||
return RuntimeImportResult{BatchID: batchID, Report: report}, nil
|
||||
}
|
||||
|
||||
func (s *RuntimeImportService) ensureHost(ctx context.Context, hostID, baseURL, hostVersion, capabilityProbeJSON string) (sqlite.Host, error) {
|
||||
host, err := s.store.Hosts().GetByHostID(ctx, hostID)
|
||||
if err == nil {
|
||||
return host, nil
|
||||
}
|
||||
if _, createErr := s.store.Hosts().Create(ctx, sqlite.Host{
|
||||
HostID: hostID,
|
||||
BaseURL: baseURL,
|
||||
HostVersion: hostVersion,
|
||||
CapabilityProbeJSON: capabilityProbeJSON,
|
||||
}); createErr != nil {
|
||||
return sqlite.Host{}, createErr
|
||||
}
|
||||
return s.store.Hosts().GetByHostID(ctx, hostID)
|
||||
}
|
||||
|
||||
func (s *RuntimeImportService) ensurePack(ctx context.Context, loaded pack.LoadedPack) (sqlite.Pack, error) {
|
||||
packRow, err := s.store.Packs().GetByPackID(ctx, loaded.Manifest.PackID)
|
||||
if err == nil {
|
||||
if err := validateExistingPack(packRow, loaded); err != nil {
|
||||
return sqlite.Pack{}, err
|
||||
}
|
||||
}
|
||||
packRecord, err := buildPackRecord(loaded)
|
||||
if err != nil {
|
||||
return sqlite.Pack{}, err
|
||||
}
|
||||
if _, err := s.store.Packs().Upsert(ctx, packRecord); err != nil {
|
||||
return sqlite.Pack{}, err
|
||||
}
|
||||
return s.store.Packs().GetByPackID(ctx, loaded.Manifest.PackID)
|
||||
}
|
||||
|
||||
func (s *RuntimeImportService) ensureProvider(ctx context.Context, packID int64, provider pack.ProviderManifest) (sqlite.Provider, error) {
|
||||
if _, err := s.store.Providers().GetByPackIDAndProviderID(ctx, packID, provider.ProviderID); err == nil {
|
||||
// continue into upsert path so metadata stays fresh.
|
||||
}
|
||||
providerRecord, err := buildProviderRecord(packID, provider)
|
||||
if err != nil {
|
||||
return sqlite.Provider{}, err
|
||||
}
|
||||
if _, err := s.store.Providers().Upsert(ctx, providerRecord); err != nil {
|
||||
return sqlite.Provider{}, err
|
||||
}
|
||||
return s.store.Providers().GetByPackIDAndProviderID(ctx, packID, provider.ProviderID)
|
||||
}
|
||||
|
||||
func (s *RuntimeImportService) persistRuntimeArtifacts(ctx context.Context, batchID int64, accessMode string, report ImportReport, includeManagedResources bool) error {
|
||||
for i, account := range report.Accounts {
|
||||
payload, err := json.Marshal(map[string]any{
|
||||
"account_id": account.Ref.ID,
|
||||
"probe_ok": account.Probe.OK,
|
||||
"probe_status": account.Probe.Status,
|
||||
"probe_message": account.Probe.Message,
|
||||
"models": account.Models,
|
||||
"smoke_model_seen": account.SmokeModelSeen,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal account probe summary: %w", err)
|
||||
}
|
||||
itemID, err := s.store.ImportBatchItems().Create(ctx, sqlite.ImportBatchItem{
|
||||
BatchID: batchID,
|
||||
KeyFingerprint: fingerprintKey(report.AcceptedKeys, i),
|
||||
AccountStatus: firstNonEmpty(account.Probe.Status, "unknown"),
|
||||
ProbeSummaryJSON: string(payload),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := s.store.ProbeResults().Create(ctx, sqlite.ProbeResult{
|
||||
BatchItemID: itemID,
|
||||
ProbeType: "account_smoke",
|
||||
Status: firstNonEmpty(account.Probe.Status, "unknown"),
|
||||
SummaryJSON: string(payload),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if includeManagedResources {
|
||||
if report.Group.ID != "" {
|
||||
if _, err := s.store.ManagedResources().Create(ctx, sqlite.ManagedResource{BatchID: batchID, ResourceType: "group", HostResourceID: report.Group.ID, ResourceName: firstNonEmpty(report.Group.Name, report.Group.ID)}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if report.Channel.ID != "" {
|
||||
if _, err := s.store.ManagedResources().Create(ctx, sqlite.ManagedResource{BatchID: batchID, ResourceType: "channel", HostResourceID: report.Channel.ID, ResourceName: firstNonEmpty(report.Channel.Name, report.Channel.ID)}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if report.Plan != nil && report.Plan.ID != "" {
|
||||
if _, err := s.store.ManagedResources().Create(ctx, sqlite.ManagedResource{BatchID: batchID, ResourceType: "plan", HostResourceID: report.Plan.ID, ResourceName: firstNonEmpty(report.Plan.Name, report.Plan.ID)}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for _, account := range report.Accounts {
|
||||
if _, err := s.store.ManagedResources().Create(ctx, sqlite.ManagedResource{BatchID: batchID, ResourceType: "account", HostResourceID: account.Ref.ID, ResourceName: firstNonEmpty(account.Ref.Name, account.Ref.ID)}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accessPayload, err := json.Marshal(map[string]any{
|
||||
"status_code": report.Gateway.StatusCode,
|
||||
"ok": report.Gateway.OK,
|
||||
"has_expected_model": report.Gateway.HasExpectedModel,
|
||||
"models": report.Gateway.Models,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal gateway access summary: %w", err)
|
||||
}
|
||||
if _, err := s.store.AccessClosures().Create(ctx, sqlite.AccessClosureRecord{
|
||||
BatchID: batchID,
|
||||
ClosureType: firstNonEmpty(strings.TrimSpace(accessMode), "unknown"),
|
||||
Status: firstNonEmpty(report.AccessStatus, AccessStatusBroken),
|
||||
DetailsJSON: string(accessPayload),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func fingerprintKey(keys []string, index int) string {
|
||||
if index >= 0 && index < len(keys) {
|
||||
key := strings.TrimSpace(keys[index])
|
||||
if key != "" {
|
||||
sum := sha256.Sum256([]byte(key))
|
||||
return fmt.Sprintf("sha256:%x", sum[:])
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("key-%d", index+1)
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, value := range values {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
196
internal/provision/runtime_import_service_test.go
Normal file
196
internal/provision/runtime_import_service_test.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
"sub2api-cn-relay-manager/internal/host/sub2api"
|
||||
"sub2api-cn-relay-manager/internal/pack"
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
func TestRuntimeImportServicePersistsOperationalState(t *testing.T) {
|
||||
store := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, store)
|
||||
|
||||
host := &fakeHostAdapter{
|
||||
batchAccounts: []sub2api.AccountRef{{ID: "account_1"}, {ID: "account_2"}},
|
||||
testResults: map[string]sub2api.ProbeResult{
|
||||
"account_1": {OK: true, Status: "passed"},
|
||||
"account_2": {OK: true, Status: "passed"},
|
||||
},
|
||||
models: map[string][]sub2api.AccountModel{
|
||||
"account_1": {{ID: "deepseek-chat"}},
|
||||
"account_2": {{ID: "deepseek-chat"}},
|
||||
},
|
||||
gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}},
|
||||
}
|
||||
|
||||
svc := NewRuntimeImportService(store, host)
|
||||
result, err := svc.Import(context.Background(), RuntimeImportRequest{
|
||||
HostID: "host-1",
|
||||
HostBaseURL: "https://sub2api.example.com",
|
||||
Pack: pack.LoadedPack{
|
||||
Manifest: pack.Manifest{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api", MinHostVersion: "0.1.126", MaxHostVersion: "0.2.x"},
|
||||
Checksum: "checksum-1",
|
||||
},
|
||||
Provider: sampleProviderManifest(),
|
||||
Mode: ImportModePartial,
|
||||
Keys: []string{" key-1 ", "key-2", "key-1"},
|
||||
Access: AccessRequest{
|
||||
Mode: AccessModeSelfService,
|
||||
ProbeAPIKey: "user-key",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("RuntimeImportService.Import() error = %v", err)
|
||||
}
|
||||
if result.BatchID <= 0 {
|
||||
t.Fatalf("BatchID = %d, want positive id", result.BatchID)
|
||||
}
|
||||
if result.Report.BatchStatus != BatchStatusSucceeded {
|
||||
t.Fatalf("BatchStatus = %q, want %q", result.Report.BatchStatus, BatchStatusSucceeded)
|
||||
}
|
||||
|
||||
if got := queryCount(t, store.SQLDB(), "hosts"); got != 1 {
|
||||
t.Fatalf("hosts row count = %d, want 1", got)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "packs"); got != 1 {
|
||||
t.Fatalf("packs row count = %d, want 1", got)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "providers"); got != 1 {
|
||||
t.Fatalf("providers row count = %d, want 1", got)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "import_batches"); got != 1 {
|
||||
t.Fatalf("import_batches row count = %d, want 1", got)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "import_batch_items"); got != 2 {
|
||||
t.Fatalf("import_batch_items row count = %d, want 2", got)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "managed_resources"); got != 4 {
|
||||
t.Fatalf("managed_resources row count = %d, want 4", got)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "probe_results"); got != 2 {
|
||||
t.Fatalf("probe_results row count = %d, want 2", got)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "access_closure_records"); got != 1 {
|
||||
t.Fatalf("access_closure_records row count = %d, want 1", got)
|
||||
}
|
||||
|
||||
var batchStatus string
|
||||
var accessStatus string
|
||||
if err := store.SQLDB().QueryRowContext(context.Background(), "SELECT batch_status, access_status FROM import_batches WHERE id = ?", result.BatchID).Scan(&batchStatus, &accessStatus); err != nil {
|
||||
t.Fatalf("query import batch state: %v", err)
|
||||
}
|
||||
if batchStatus != BatchStatusSucceeded {
|
||||
t.Fatalf("persisted batch_status = %q, want %q", batchStatus, BatchStatusSucceeded)
|
||||
}
|
||||
if accessStatus != AccessStatusSelfServiceReady {
|
||||
t.Fatalf("persisted access_status = %q, want %q", accessStatus, AccessStatusSelfServiceReady)
|
||||
}
|
||||
|
||||
var fingerprint string
|
||||
var accountStatus string
|
||||
if err := store.SQLDB().QueryRowContext(context.Background(), "SELECT key_fingerprint, account_status FROM import_batch_items ORDER BY id LIMIT 1").Scan(&fingerprint, &accountStatus); err != nil {
|
||||
t.Fatalf("query import batch item: %v", err)
|
||||
}
|
||||
if fingerprint == "key-1" || fingerprint == "key-2" || len(fingerprint) < 10 {
|
||||
t.Fatalf("key_fingerprint = %q, want hashed fingerprint instead of raw key", fingerprint)
|
||||
}
|
||||
if accountStatus != "passed" {
|
||||
t.Fatalf("account_status = %q, want passed", accountStatus)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuntimeImportServicePersistsFailedBatchAfterStrictRollback(t *testing.T) {
|
||||
store := openProvisionTestStore(t)
|
||||
defer closeProvisionTestStore(t, store)
|
||||
|
||||
host := &fakeHostAdapter{
|
||||
batchAccounts: []sub2api.AccountRef{{ID: "account_1"}, {ID: "account_2"}},
|
||||
testResults: map[string]sub2api.ProbeResult{
|
||||
"account_1": {OK: true, Status: "passed"},
|
||||
"account_2": {OK: false, Status: "failed", Message: "bad key"},
|
||||
},
|
||||
models: map[string][]sub2api.AccountModel{
|
||||
"account_1": {{ID: "deepseek-chat"}},
|
||||
"account_2": {{ID: "deepseek-chat"}},
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewRuntimeImportService(store, host)
|
||||
result, err := svc.Import(context.Background(), RuntimeImportRequest{
|
||||
HostID: "host-1",
|
||||
HostBaseURL: "https://sub2api.example.com",
|
||||
Pack: pack.LoadedPack{
|
||||
Manifest: pack.Manifest{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api", MinHostVersion: "0.1.126", MaxHostVersion: "0.2.x"},
|
||||
Checksum: "checksum-1",
|
||||
},
|
||||
Provider: sampleProviderManifest(),
|
||||
Mode: ImportModeStrict,
|
||||
Keys: []string{"key-1", "key-2"},
|
||||
Access: AccessRequest{
|
||||
Mode: AccessModeSelfService,
|
||||
ProbeAPIKey: "user-key",
|
||||
},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("RuntimeImportService.Import() error = nil, want strict failure")
|
||||
}
|
||||
if result.BatchID <= 0 {
|
||||
t.Fatalf("BatchID = %d, want positive id", result.BatchID)
|
||||
}
|
||||
|
||||
var batchStatus string
|
||||
var accessStatus string
|
||||
if err := store.SQLDB().QueryRowContext(context.Background(), "SELECT batch_status, access_status FROM import_batches WHERE id = ?", result.BatchID).Scan(&batchStatus, &accessStatus); err != nil {
|
||||
t.Fatalf("query failed import batch state: %v", err)
|
||||
}
|
||||
if batchStatus != BatchStatusFailed {
|
||||
t.Fatalf("persisted batch_status = %q, want %q", batchStatus, BatchStatusFailed)
|
||||
}
|
||||
if accessStatus != AccessStatusBroken {
|
||||
t.Fatalf("persisted access_status = %q, want %q", accessStatus, AccessStatusBroken)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "managed_resources"); got != 0 {
|
||||
t.Fatalf("managed_resources row count = %d, want 0 after strict rollback", got)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "probe_results"); got != 2 {
|
||||
t.Fatalf("probe_results row count = %d, want 2", got)
|
||||
}
|
||||
if got := queryCount(t, store.SQLDB(), "access_closure_records"); got != 1 {
|
||||
t.Fatalf("access_closure_records row count = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func openProvisionTestStore(t *testing.T) *sqlite.DB {
|
||||
t.Helper()
|
||||
|
||||
dbPath := filepath.Join(t.TempDir(), "state.db")
|
||||
dsn := fmt.Sprintf("file:%s?_busy_timeout=5000&_pragma=foreign_keys(0)", filepath.ToSlash(dbPath))
|
||||
store, err := sqlite.Open(context.Background(), dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("sqlite.Open() error = %v", err)
|
||||
}
|
||||
return store
|
||||
}
|
||||
|
||||
func closeProvisionTestStore(t *testing.T, store *sqlite.DB) {
|
||||
t.Helper()
|
||||
if err := store.Close(); err != nil {
|
||||
t.Fatalf("store.Close() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func queryCount(t *testing.T, db *sql.DB, table string) int {
|
||||
t.Helper()
|
||||
var count int
|
||||
if err := db.QueryRowContext(context.Background(), "SELECT COUNT(*) FROM "+table).Scan(&count); err != nil {
|
||||
t.Fatalf("count rows for %s: %v", table, err)
|
||||
}
|
||||
return count
|
||||
}
|
||||
64
internal/store/migrations/0002_operational_runtime.sql
Normal file
64
internal/store/migrations/0002_operational_runtime.sql
Normal file
@@ -0,0 +1,64 @@
|
||||
CREATE TABLE import_batches (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
host_id INTEGER NOT NULL,
|
||||
pack_id INTEGER NOT NULL,
|
||||
provider_id INTEGER NOT NULL,
|
||||
mode TEXT NOT NULL,
|
||||
batch_status TEXT NOT NULL,
|
||||
access_status TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
CONSTRAINT fk_import_batches_host FOREIGN KEY (host_id) REFERENCES hosts(id) ON DELETE CASCADE,
|
||||
CONSTRAINT fk_import_batches_pack FOREIGN KEY (pack_id) REFERENCES packs(id) ON DELETE CASCADE,
|
||||
CONSTRAINT fk_import_batches_provider FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE import_batch_items (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
batch_id INTEGER NOT NULL,
|
||||
key_fingerprint TEXT NOT NULL,
|
||||
account_status TEXT NOT NULL,
|
||||
probe_summary_json TEXT NOT NULL DEFAULT '{}',
|
||||
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
CONSTRAINT fk_import_batch_items_batch FOREIGN KEY (batch_id) REFERENCES import_batches(id) ON DELETE CASCADE,
|
||||
CONSTRAINT uq_import_batch_items_batch_fingerprint UNIQUE (batch_id, key_fingerprint)
|
||||
);
|
||||
|
||||
CREATE TABLE managed_resources (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
batch_id INTEGER NOT NULL,
|
||||
resource_type TEXT NOT NULL,
|
||||
host_resource_id TEXT NOT NULL,
|
||||
resource_name TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
CONSTRAINT fk_managed_resources_batch FOREIGN KEY (batch_id) REFERENCES import_batches(id) ON DELETE CASCADE,
|
||||
CONSTRAINT uq_managed_resources_host UNIQUE (resource_type, host_resource_id)
|
||||
);
|
||||
|
||||
CREATE TABLE probe_results (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
batch_item_id INTEGER NOT NULL,
|
||||
probe_type TEXT NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
summary_json TEXT NOT NULL DEFAULT '{}',
|
||||
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
CONSTRAINT fk_probe_results_item FOREIGN KEY (batch_item_id) REFERENCES import_batch_items(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE access_closure_records (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
batch_id INTEGER NOT NULL,
|
||||
closure_type TEXT NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
details_json TEXT NOT NULL DEFAULT '{}',
|
||||
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
CONSTRAINT fk_access_closure_records_batch FOREIGN KEY (batch_id) REFERENCES import_batches(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE reconcile_runs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
provider_id INTEGER NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
summary_json TEXT NOT NULL DEFAULT '{}',
|
||||
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
CONSTRAINT fk_reconcile_runs_provider FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE
|
||||
);
|
||||
14
internal/store/migrations/0003_pack_install_metadata.sql
Normal file
14
internal/store/migrations/0003_pack_install_metadata.sql
Normal file
@@ -0,0 +1,14 @@
|
||||
ALTER TABLE packs ADD COLUMN vendor TEXT NOT NULL DEFAULT '';
|
||||
ALTER TABLE packs ADD COLUMN target_host TEXT NOT NULL DEFAULT '';
|
||||
ALTER TABLE packs ADD COLUMN min_host_version TEXT NOT NULL DEFAULT '';
|
||||
ALTER TABLE packs ADD COLUMN max_host_version TEXT NOT NULL DEFAULT '';
|
||||
ALTER TABLE packs ADD COLUMN manifest_json TEXT NOT NULL DEFAULT '{}';
|
||||
|
||||
ALTER TABLE providers ADD COLUMN account_type TEXT NOT NULL DEFAULT '';
|
||||
ALTER TABLE providers ADD COLUMN default_models_json TEXT NOT NULL DEFAULT '[]';
|
||||
ALTER TABLE providers ADD COLUMN smoke_test_model TEXT NOT NULL DEFAULT '';
|
||||
ALTER TABLE providers ADD COLUMN group_template_json TEXT NOT NULL DEFAULT '{}';
|
||||
ALTER TABLE providers ADD COLUMN channel_template_json TEXT NOT NULL DEFAULT '{}';
|
||||
ALTER TABLE providers ADD COLUMN plan_template_json TEXT NOT NULL DEFAULT '{}';
|
||||
ALTER TABLE providers ADD COLUMN import_options_json TEXT NOT NULL DEFAULT '{}';
|
||||
ALTER TABLE providers ADD COLUMN manifest_json TEXT NOT NULL DEFAULT '{}';
|
||||
77
internal/store/sqlite/access_closure_records_repo.go
Normal file
77
internal/store/sqlite/access_closure_records_repo.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type AccessClosureRecord struct {
|
||||
ID int64
|
||||
BatchID int64
|
||||
ClosureType string
|
||||
Status string
|
||||
DetailsJSON string
|
||||
}
|
||||
|
||||
type AccessClosureRecordsRepo struct {
|
||||
db execQuerier
|
||||
}
|
||||
|
||||
func newAccessClosureRecordsRepo(db execQuerier) *AccessClosureRecordsRepo {
|
||||
return &AccessClosureRecordsRepo{db: db}
|
||||
}
|
||||
|
||||
func (r *AccessClosureRecordsRepo) Create(ctx context.Context, record AccessClosureRecord) (int64, error) {
|
||||
closureType := strings.TrimSpace(record.ClosureType)
|
||||
status := strings.TrimSpace(record.Status)
|
||||
detailsJSON := strings.TrimSpace(record.DetailsJSON)
|
||||
if detailsJSON == "" {
|
||||
detailsJSON = "{}"
|
||||
}
|
||||
|
||||
switch {
|
||||
case record.BatchID <= 0:
|
||||
return 0, fmt.Errorf("batch_id is required")
|
||||
case closureType == "":
|
||||
return 0, fmt.Errorf("closure_type is required")
|
||||
case status == "":
|
||||
return 0, fmt.Errorf("status is required")
|
||||
}
|
||||
|
||||
result, err := r.db.ExecContext(ctx, `INSERT INTO access_closure_records (batch_id, closure_type, status, details_json) VALUES (?, ?, ?, ?)`, record.BatchID, closureType, status, detailsJSON)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("insert access closure record: %w", err)
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("read inserted access closure record id: %w", err)
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (r *AccessClosureRecordsRepo) GetByBatchID(ctx context.Context, batchID int64) ([]AccessClosureRecord, error) {
|
||||
if batchID <= 0 {
|
||||
return nil, fmt.Errorf("batch_id is required")
|
||||
}
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, `SELECT id, batch_id, closure_type, status, details_json FROM access_closure_records WHERE batch_id = ? ORDER BY id`, batchID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query access closure records: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
records := make([]AccessClosureRecord, 0)
|
||||
for rows.Next() {
|
||||
var record AccessClosureRecord
|
||||
if err := rows.Scan(&record.ID, &record.BatchID, &record.ClosureType, &record.Status, &record.DetailsJSON); err != nil {
|
||||
return nil, fmt.Errorf("scan access closure record: %w", err)
|
||||
}
|
||||
records = append(records, record)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate access closure records: %w", err)
|
||||
}
|
||||
return records, nil
|
||||
}
|
||||
@@ -15,13 +15,20 @@ import (
|
||||
|
||||
type execQuerier interface {
|
||||
ExecContext(context.Context, string, ...any) (sql.Result, error)
|
||||
QueryContext(context.Context, string, ...any) (*sql.Rows, error)
|
||||
QueryRowContext(context.Context, string, ...any) *sql.Row
|
||||
}
|
||||
|
||||
type Queries struct {
|
||||
Hosts *HostsRepo
|
||||
Packs *PacksRepo
|
||||
Providers *ProvidersRepo
|
||||
Hosts *HostsRepo
|
||||
Packs *PacksRepo
|
||||
Providers *ProvidersRepo
|
||||
ImportBatches *ImportBatchesRepo
|
||||
ImportBatchItems *ImportBatchItemsRepo
|
||||
ManagedResources *ManagedResourcesRepo
|
||||
ProbeResults *ProbeResultsRepo
|
||||
AccessClosures *AccessClosureRecordsRepo
|
||||
ReconcileRuns *ReconcileRunsRepo
|
||||
}
|
||||
|
||||
type DB struct {
|
||||
@@ -76,6 +83,30 @@ func (db *DB) Providers() *ProvidersRepo {
|
||||
return db.queries.Providers
|
||||
}
|
||||
|
||||
func (db *DB) ImportBatches() *ImportBatchesRepo {
|
||||
return db.queries.ImportBatches
|
||||
}
|
||||
|
||||
func (db *DB) ImportBatchItems() *ImportBatchItemsRepo {
|
||||
return db.queries.ImportBatchItems
|
||||
}
|
||||
|
||||
func (db *DB) ManagedResources() *ManagedResourcesRepo {
|
||||
return db.queries.ManagedResources
|
||||
}
|
||||
|
||||
func (db *DB) ProbeResults() *ProbeResultsRepo {
|
||||
return db.queries.ProbeResults
|
||||
}
|
||||
|
||||
func (db *DB) AccessClosures() *AccessClosureRecordsRepo {
|
||||
return db.queries.AccessClosures
|
||||
}
|
||||
|
||||
func (db *DB) ReconcileRuns() *ReconcileRunsRepo {
|
||||
return db.queries.ReconcileRuns
|
||||
}
|
||||
|
||||
func (db *DB) WithTx(ctx context.Context, fn func(*Queries) error) error {
|
||||
tx, err := db.sqlDB.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
@@ -101,9 +132,15 @@ func (db *DB) WithTx(ctx context.Context, fn func(*Queries) error) error {
|
||||
|
||||
func newQueries(db execQuerier) *Queries {
|
||||
return &Queries{
|
||||
Hosts: newHostsRepo(db),
|
||||
Packs: newPacksRepo(db),
|
||||
Providers: newProvidersRepo(db),
|
||||
Hosts: newHostsRepo(db),
|
||||
Packs: newPacksRepo(db),
|
||||
Providers: newProvidersRepo(db),
|
||||
ImportBatches: newImportBatchesRepo(db),
|
||||
ImportBatchItems: newImportBatchItemsRepo(db),
|
||||
ManagedResources: newManagedResourcesRepo(db),
|
||||
ProbeResults: newProbeResultsRepo(db),
|
||||
AccessClosures: newAccessClosureRecordsRepo(db),
|
||||
ReconcileRuns: newReconcileRunsRepo(db),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
161
internal/store/sqlite/db_test.go
Normal file
161
internal/store/sqlite/db_test.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestOpenClose(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
if store == nil {
|
||||
t.Fatal("Open() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenInvalidDSN(t *testing.T) {
|
||||
_, err := Open(context.Background(), "file:/nonexistent/dir/test.db?_pragma=foreign_keys(0)")
|
||||
if err == nil {
|
||||
t.Fatal("Open() with invalid dsn error = nil, want error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithTxCommit(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
|
||||
err := store.WithTx(context.Background(), func(q *Queries) error {
|
||||
_, err := q.Hosts.Create(context.Background(), Host{
|
||||
HostID: "tx-host", BaseURL: "https://tx.com", HostVersion: "1.0",
|
||||
})
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("WithTx() error = %v", err)
|
||||
}
|
||||
|
||||
host, err := store.Hosts().GetByHostID(context.Background(), "tx-host")
|
||||
if err != nil {
|
||||
t.Fatalf("GetByHostID() after tx = %v", err)
|
||||
}
|
||||
if host.HostID != "tx-host" {
|
||||
t.Fatalf("host = %+v, want tx-host", host)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithTxRollbackOnError(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
|
||||
err := store.WithTx(context.Background(), func(q *Queries) error {
|
||||
q.Hosts.Create(context.Background(), Host{
|
||||
HostID: "rollback-host", BaseURL: "https://r.com", HostVersion: "1.0",
|
||||
})
|
||||
return errors.New("rollback")
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("WithTx() error = nil, want rollback error")
|
||||
}
|
||||
|
||||
_, err = store.Hosts().GetByHostID(context.Background(), "rollback-host")
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
t.Fatalf("GetByHostID() after rollback error = %v, want sql.ErrNoRows", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTableExists(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
db := store.SQLDB()
|
||||
|
||||
found, err := tableExists(context.Background(), db, "hosts")
|
||||
if err != nil {
|
||||
t.Fatalf("tableExists('hosts') error = %v", err)
|
||||
}
|
||||
if !found {
|
||||
t.Fatal("tableExists('hosts') = false, want true")
|
||||
}
|
||||
|
||||
found, err = tableExists(context.Background(), db, "nonexistent")
|
||||
if err != nil {
|
||||
t.Fatalf("tableExists('nonexistent') error = %v", err)
|
||||
}
|
||||
if found {
|
||||
t.Fatal("tableExists('nonexistent') = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectLegacy0001Schema(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
db := store.SQLDB()
|
||||
|
||||
tx, err := db.BeginTx(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("BeginTx error = %v", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// After migration all three host/packs/providers tables exist,
|
||||
// so detectLegacy0001Schema reports complete=true.
|
||||
complete, partial, err := detectLegacy0001Schema(context.Background(), tx)
|
||||
if err != nil {
|
||||
t.Fatalf("detectLegacy0001Schema() error = %v", err)
|
||||
}
|
||||
if !complete {
|
||||
t.Fatalf("detectLegacy0001Schema() = (complete=%v, partial=%v), want (true, false)", complete, partial)
|
||||
}
|
||||
if partial {
|
||||
t.Fatal("partial should be false when all 3 tables exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithForeignKeysEnabled(t *testing.T) {
|
||||
if got := withForeignKeysEnabled("file:test.db"); got != "file:test.db?_pragma=foreign_keys(1)" {
|
||||
t.Fatalf("withForeignKeysEnabled no query = %q", got)
|
||||
}
|
||||
if got := withForeignKeysEnabled("file:test.db?a=1"); got != "file:test.db?a=1&_pragma=foreign_keys(1)" {
|
||||
t.Fatalf("withForeignKeysEnabled with query = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLDB(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
db := store.SQLDB()
|
||||
if db == nil {
|
||||
t.Fatal("SQLDB() returned nil")
|
||||
}
|
||||
if err := db.PingContext(context.Background()); err != nil {
|
||||
t.Fatalf("Ping() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrationFileNames(t *testing.T) {
|
||||
names, err := migrationFileNames()
|
||||
if err != nil {
|
||||
t.Fatalf("migrationFileNames() error = %v", err)
|
||||
}
|
||||
if len(names) == 0 {
|
||||
t.Fatal("migrationFileNames() returned empty")
|
||||
}
|
||||
for _, name := range names {
|
||||
if filepath.Ext(name) != ".sql" {
|
||||
t.Fatalf("migrationFileNames() entry %q not .sql file", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadMigration(t *testing.T) {
|
||||
content, err := readMigration("0001_init.sql")
|
||||
if err != nil {
|
||||
t.Fatalf("readMigration('0001_init.sql') error = %v", err)
|
||||
}
|
||||
if len(content) == 0 {
|
||||
t.Fatal("readMigration() returned empty content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadMigrationNotFound(t *testing.T) {
|
||||
_, err := readMigration("nonexistent.sql")
|
||||
if err == nil {
|
||||
t.Fatal("readMigration('nonexistent.sql') error = nil, want error")
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
)
|
||||
|
||||
type Host struct {
|
||||
ID int64
|
||||
HostID string
|
||||
BaseURL string
|
||||
HostVersion string
|
||||
@@ -21,6 +22,31 @@ func newHostsRepo(db execQuerier) *HostsRepo {
|
||||
return &HostsRepo{db: db}
|
||||
}
|
||||
|
||||
func (r *HostsRepo) GetByID(ctx context.Context, id int64) (Host, error) {
|
||||
if id <= 0 {
|
||||
return Host{}, fmt.Errorf("id is required")
|
||||
}
|
||||
|
||||
var host Host
|
||||
if err := r.db.QueryRowContext(ctx, `SELECT id, host_id, base_url, host_version, capability_probe_json FROM hosts WHERE id = ?`, id).Scan(&host.ID, &host.HostID, &host.BaseURL, &host.HostVersion, &host.CapabilityProbeJSON); err != nil {
|
||||
return Host{}, err
|
||||
}
|
||||
return host, nil
|
||||
}
|
||||
|
||||
func (r *HostsRepo) GetByHostID(ctx context.Context, hostID string) (Host, error) {
|
||||
hostID = strings.TrimSpace(hostID)
|
||||
if hostID == "" {
|
||||
return Host{}, fmt.Errorf("host_id is required")
|
||||
}
|
||||
|
||||
var host Host
|
||||
if err := r.db.QueryRowContext(ctx, `SELECT id, host_id, base_url, host_version, capability_probe_json FROM hosts WHERE host_id = ?`, hostID).Scan(&host.ID, &host.HostID, &host.BaseURL, &host.HostVersion, &host.CapabilityProbeJSON); err != nil {
|
||||
return Host{}, err
|
||||
}
|
||||
return host, nil
|
||||
}
|
||||
|
||||
func (r *HostsRepo) Create(ctx context.Context, host Host) (int64, error) {
|
||||
hostID := strings.TrimSpace(host.HostID)
|
||||
baseURL := strings.TrimSpace(host.BaseURL)
|
||||
|
||||
199
internal/store/sqlite/hosts_repo_test.go
Normal file
199
internal/store/sqlite/hosts_repo_test.go
Normal file
@@ -0,0 +1,199 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// openTestDB creates a test database with foreign keys disabled.
|
||||
func openTestDB(t *testing.T) *DB {
|
||||
t.Helper()
|
||||
dbPath := filepath.Join(t.TempDir(), "test.db")
|
||||
dsn := "file:" + filepath.ToSlash(dbPath) + "?_pragma=foreign_keys(0)"
|
||||
store, err := Open(context.Background(), dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("Open() error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() { store.Close() })
|
||||
return store
|
||||
}
|
||||
|
||||
// openTestDBWithFK creates a test database with foreign keys enforced.
|
||||
func openTestDBWithFK(t *testing.T) *DB {
|
||||
t.Helper()
|
||||
dbPath := filepath.Join(t.TempDir(), "test-fk.db")
|
||||
dsn := "file:" + filepath.ToSlash(dbPath)
|
||||
store, err := Open(context.Background(), dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("Open() error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() { store.Close() })
|
||||
return store
|
||||
}
|
||||
|
||||
func createTestPack(t *testing.T, store *DB) int64 {
|
||||
t.Helper()
|
||||
id, err := store.Packs().Create(context.Background(), Pack{
|
||||
PackID: "pack-" + sanitizeTestName(t.Name()), Version: "1.0.0", Checksum: "chk",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("createTestPack error = %v", err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func createTestHost(t *testing.T, store *DB) int64 {
|
||||
t.Helper()
|
||||
id, err := store.Hosts().Create(context.Background(), Host{
|
||||
HostID: "host-" + sanitizeTestName(t.Name()), BaseURL: "https://h.com", HostVersion: "0.1.0",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("createTestHost error = %v", err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func createTestBatch(t *testing.T, store *DB) int64 {
|
||||
t.Helper()
|
||||
hostID := createTestHost(t, store)
|
||||
packID := createTestPack(t, store)
|
||||
providerID, err := store.Providers().Create(context.Background(), Provider{
|
||||
PackID: packID, ProviderID: "test-provider", DisplayName: "Test",
|
||||
BaseURL: "https://t.com", Platform: "openai",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("createTestBatch create provider error = %v", err)
|
||||
}
|
||||
id, err := store.ImportBatches().Create(context.Background(), ImportBatch{
|
||||
HostID: hostID, PackID: packID, ProviderID: providerID,
|
||||
Mode: "partial", BatchStatus: "running", AccessStatus: "pending",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("createTestBatch error = %v", err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func createTestBatchItem(t *testing.T, store *DB, batchID int64) int64 {
|
||||
t.Helper()
|
||||
id, err := store.ImportBatchItems().Create(context.Background(), ImportBatchItem{
|
||||
BatchID: batchID, KeyFingerprint: "sha256:test", AccountStatus: "pending",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("createTestBatchItem error = %v", err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func sanitizeTestName(name string) string {
|
||||
result := ""
|
||||
for _, c := range name {
|
||||
if (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-' {
|
||||
result += string(c)
|
||||
}
|
||||
}
|
||||
if result == "" {
|
||||
result = "default"
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// --- Hosts Repo Tests ---
|
||||
|
||||
func TestHostsRepoCreateAndGet(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
|
||||
id, err := store.Hosts().Create(context.Background(), Host{
|
||||
HostID: "host-1",
|
||||
BaseURL: "https://sub2api.example.com",
|
||||
HostVersion: "0.1.126",
|
||||
CapabilityProbeJSON: `{"groups":true}`,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
if id <= 0 {
|
||||
t.Fatalf("Create() id = %d, want positive", id)
|
||||
}
|
||||
|
||||
got, err := store.Hosts().GetByID(context.Background(), id)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByID() error = %v", err)
|
||||
}
|
||||
if got.HostID != "host-1" || got.BaseURL != "https://sub2api.example.com" {
|
||||
t.Fatalf("GetByID() = %+v, want host-1", got)
|
||||
}
|
||||
|
||||
got2, err := store.Hosts().GetByHostID(context.Background(), "host-1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetByHostID() error = %v", err)
|
||||
}
|
||||
if got2.ID != id {
|
||||
t.Fatalf("GetByHostID() id = %d, want %d", got2.ID, id)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostsRepoCreateDefaultsCapabilityProbe(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
id, _ := store.Hosts().Create(context.Background(), Host{
|
||||
HostID: "host-empty", BaseURL: "https://example.com", HostVersion: "0.1.0",
|
||||
})
|
||||
got, _ := store.Hosts().GetByID(context.Background(), id)
|
||||
if got.CapabilityProbeJSON != "{}" {
|
||||
t.Fatalf("CapabilityProbeJSON = %q, want {}", got.CapabilityProbeJSON)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostsRepoValidationErrors(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
host Host
|
||||
}{
|
||||
{"empty host_id", Host{BaseURL: "b", HostVersion: "v"}},
|
||||
{"empty base_url", Host{HostID: "h", HostVersion: "v"}},
|
||||
{"empty host_version", Host{HostID: "h", BaseURL: "b"}},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := store.Hosts().Create(context.Background(), tt.host)
|
||||
if err == nil {
|
||||
t.Fatal("Create() error = nil, want validation error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostsRepoGetByIDZeroError(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
_, err := store.Hosts().GetByID(context.Background(), 0)
|
||||
if err == nil {
|
||||
t.Fatal("GetByID(0) error = nil, want error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostsRepoGetByIDNotFound(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
_, err := store.Hosts().GetByID(context.Background(), 999)
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
t.Fatalf("GetByID(999) error = %v, want sql.ErrNoRows", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostsRepoGetByHostIDEmptyError(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
_, err := store.Hosts().GetByHostID(context.Background(), "")
|
||||
if err == nil {
|
||||
t.Fatal("GetByHostID('') error = nil, want error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostsRepoGetByHostIDNotFound(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
_, err := store.Hosts().GetByHostID(context.Background(), "nonexistent")
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
t.Fatalf("GetByHostID('nonexistent') error = %v, want sql.ErrNoRows", err)
|
||||
}
|
||||
}
|
||||
187
internal/store/sqlite/import_batches_repo.go
Normal file
187
internal/store/sqlite/import_batches_repo.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ImportBatch struct {
|
||||
ID int64
|
||||
HostID int64
|
||||
PackID int64
|
||||
ProviderID int64
|
||||
Mode string
|
||||
BatchStatus string
|
||||
AccessStatus string
|
||||
}
|
||||
|
||||
type ImportBatchItem struct {
|
||||
ID int64
|
||||
BatchID int64
|
||||
KeyFingerprint string
|
||||
AccountStatus string
|
||||
ProbeSummaryJSON string
|
||||
}
|
||||
|
||||
type ImportBatchesRepo struct {
|
||||
db execQuerier
|
||||
}
|
||||
|
||||
type ImportBatchItemsRepo struct {
|
||||
db execQuerier
|
||||
}
|
||||
|
||||
func newImportBatchesRepo(db execQuerier) *ImportBatchesRepo {
|
||||
return &ImportBatchesRepo{db: db}
|
||||
}
|
||||
|
||||
func newImportBatchItemsRepo(db execQuerier) *ImportBatchItemsRepo {
|
||||
return &ImportBatchItemsRepo{db: db}
|
||||
}
|
||||
|
||||
func (r *ImportBatchesRepo) GetByID(ctx context.Context, id int64) (ImportBatch, error) {
|
||||
if id <= 0 {
|
||||
return ImportBatch{}, fmt.Errorf("id is required")
|
||||
}
|
||||
|
||||
var batch ImportBatch
|
||||
if err := r.db.QueryRowContext(ctx, `SELECT id, host_id, pack_id, provider_id, mode, batch_status, access_status FROM import_batches WHERE id = ?`, id).Scan(&batch.ID, &batch.HostID, &batch.PackID, &batch.ProviderID, &batch.Mode, &batch.BatchStatus, &batch.AccessStatus); err != nil {
|
||||
return ImportBatch{}, err
|
||||
}
|
||||
return batch, nil
|
||||
}
|
||||
|
||||
func (r *ImportBatchesRepo) Create(ctx context.Context, batch ImportBatch) (int64, error) {
|
||||
mode := strings.TrimSpace(batch.Mode)
|
||||
batchStatus := strings.TrimSpace(batch.BatchStatus)
|
||||
accessStatus := strings.TrimSpace(batch.AccessStatus)
|
||||
|
||||
switch {
|
||||
case batch.HostID <= 0:
|
||||
return 0, fmt.Errorf("host_id is required")
|
||||
case batch.PackID <= 0:
|
||||
return 0, fmt.Errorf("pack_id is required")
|
||||
case batch.ProviderID <= 0:
|
||||
return 0, fmt.Errorf("provider_id is required")
|
||||
case mode == "":
|
||||
return 0, fmt.Errorf("mode is required")
|
||||
case batchStatus == "":
|
||||
return 0, fmt.Errorf("batch_status is required")
|
||||
case accessStatus == "":
|
||||
return 0, fmt.Errorf("access_status is required")
|
||||
}
|
||||
|
||||
result, err := r.db.ExecContext(ctx, `INSERT INTO import_batches (host_id, pack_id, provider_id, mode, batch_status, access_status) VALUES (?, ?, ?, ?, ?, ?)`, batch.HostID, batch.PackID, batch.ProviderID, mode, batchStatus, accessStatus)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("insert import batch: %w", err)
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("read inserted import batch id: %w", err)
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (r *ImportBatchesRepo) UpdateStatus(ctx context.Context, id int64, batchStatus, accessStatus string) error {
|
||||
if id <= 0 {
|
||||
return fmt.Errorf("id is required")
|
||||
}
|
||||
batchStatus = strings.TrimSpace(batchStatus)
|
||||
accessStatus = strings.TrimSpace(accessStatus)
|
||||
if batchStatus == "" {
|
||||
return fmt.Errorf("batch_status is required")
|
||||
}
|
||||
if accessStatus == "" {
|
||||
return fmt.Errorf("access_status is required")
|
||||
}
|
||||
if _, err := r.db.ExecContext(ctx, `UPDATE import_batches SET batch_status = ?, access_status = ? WHERE id = ?`, batchStatus, accessStatus, id); err != nil {
|
||||
return fmt.Errorf("update import batch %d: %w", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *ImportBatchesRepo) GetLatestByProviderID(ctx context.Context, providerID int64) (ImportBatch, error) {
|
||||
if providerID <= 0 {
|
||||
return ImportBatch{}, fmt.Errorf("provider_id is required")
|
||||
}
|
||||
|
||||
var batch ImportBatch
|
||||
if err := r.db.QueryRowContext(ctx, `SELECT id, host_id, pack_id, provider_id, mode, batch_status, access_status FROM import_batches WHERE provider_id = ? ORDER BY id DESC LIMIT 1`, providerID).Scan(&batch.ID, &batch.HostID, &batch.PackID, &batch.ProviderID, &batch.Mode, &batch.BatchStatus, &batch.AccessStatus); err != nil {
|
||||
return ImportBatch{}, err
|
||||
}
|
||||
return batch, nil
|
||||
}
|
||||
|
||||
func (r *ImportBatchItemsRepo) GetByBatchID(ctx context.Context, batchID int64) ([]ImportBatchItem, error) {
|
||||
if batchID <= 0 {
|
||||
return nil, fmt.Errorf("batch_id is required")
|
||||
}
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, `SELECT id, batch_id, key_fingerprint, account_status, probe_summary_json FROM import_batch_items WHERE batch_id = ? ORDER BY id`, batchID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query import batch items: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
items := make([]ImportBatchItem, 0)
|
||||
for rows.Next() {
|
||||
var item ImportBatchItem
|
||||
if err := rows.Scan(&item.ID, &item.BatchID, &item.KeyFingerprint, &item.AccountStatus, &item.ProbeSummaryJSON); err != nil {
|
||||
return nil, fmt.Errorf("scan import batch item: %w", err)
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate import batch items: %w", err)
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (r *ImportBatchItemsRepo) Create(ctx context.Context, item ImportBatchItem) (int64, error) {
|
||||
keyFingerprint := strings.TrimSpace(item.KeyFingerprint)
|
||||
accountStatus := strings.TrimSpace(item.AccountStatus)
|
||||
probeSummaryJSON := strings.TrimSpace(item.ProbeSummaryJSON)
|
||||
if probeSummaryJSON == "" {
|
||||
probeSummaryJSON = "{}"
|
||||
}
|
||||
|
||||
switch {
|
||||
case item.BatchID <= 0:
|
||||
return 0, fmt.Errorf("batch_id is required")
|
||||
case keyFingerprint == "":
|
||||
return 0, fmt.Errorf("key_fingerprint is required")
|
||||
case accountStatus == "":
|
||||
return 0, fmt.Errorf("account_status is required")
|
||||
}
|
||||
|
||||
result, err := r.db.ExecContext(ctx, `INSERT INTO import_batch_items (batch_id, key_fingerprint, account_status, probe_summary_json) VALUES (?, ?, ?, ?)`, item.BatchID, keyFingerprint, accountStatus, probeSummaryJSON)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("insert import batch item: %w", err)
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("read inserted import batch item id: %w", err)
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (r *ImportBatchItemsRepo) UpdateResult(ctx context.Context, id int64, accountStatus, probeSummaryJSON string) error {
|
||||
if id <= 0 {
|
||||
return fmt.Errorf("id is required")
|
||||
}
|
||||
accountStatus = strings.TrimSpace(accountStatus)
|
||||
probeSummaryJSON = strings.TrimSpace(probeSummaryJSON)
|
||||
if accountStatus == "" {
|
||||
return fmt.Errorf("account_status is required")
|
||||
}
|
||||
if probeSummaryJSON == "" {
|
||||
probeSummaryJSON = "{}"
|
||||
}
|
||||
if _, err := r.db.ExecContext(ctx, `UPDATE import_batch_items SET account_status = ?, probe_summary_json = ? WHERE id = ?`, accountStatus, probeSummaryJSON, id); err != nil {
|
||||
return fmt.Errorf("update import batch item %d: %w", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
429
internal/store/sqlite/import_batches_repo_test.go
Normal file
429
internal/store/sqlite/import_batches_repo_test.go
Normal file
@@ -0,0 +1,429 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestImportBatchesRepoCreateAndGet(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
hostID := createTestHost(t, store)
|
||||
packID := createTestPack(t, store)
|
||||
providerID := createTestProviderWithPack(t, store, packID)
|
||||
|
||||
id, err := store.ImportBatches().Create(context.Background(), ImportBatch{
|
||||
HostID: hostID, PackID: packID, ProviderID: providerID,
|
||||
Mode: "partial", BatchStatus: "running", AccessStatus: "pending",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
if id <= 0 {
|
||||
t.Fatalf("Create() id = %d, want positive", id)
|
||||
}
|
||||
|
||||
got, _ := store.ImportBatches().GetByID(context.Background(), id)
|
||||
if got.Mode != "partial" || got.BatchStatus != "running" {
|
||||
t.Fatalf("GetByID() = %+v, want running batch", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportBatchesRepoUpdateStatus(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
hostID := createTestHost(t, store)
|
||||
packID := createTestPack(t, store)
|
||||
providerID := createTestProviderWithPack(t, store, packID)
|
||||
id, _ := store.ImportBatches().Create(context.Background(), ImportBatch{
|
||||
HostID: hostID, PackID: packID, ProviderID: providerID,
|
||||
Mode: "partial", BatchStatus: "running", AccessStatus: "pending",
|
||||
})
|
||||
|
||||
err := store.ImportBatches().UpdateStatus(context.Background(), id, "succeeded", "subscription_ready")
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateStatus() error = %v", err)
|
||||
}
|
||||
got, _ := store.ImportBatches().GetByID(context.Background(), id)
|
||||
if got.BatchStatus != "succeeded" || got.AccessStatus != "subscription_ready" {
|
||||
t.Fatalf("status = %+v, want succeeded/subscription_ready", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportBatchesRepoGetLatestByProviderID(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
hostID := createTestHost(t, store)
|
||||
packID := createTestPack(t, store)
|
||||
providerID := createTestProviderWithPack(t, store, packID)
|
||||
|
||||
store.ImportBatches().Create(context.Background(), ImportBatch{
|
||||
HostID: hostID, PackID: packID, ProviderID: providerID,
|
||||
Mode: "partial", BatchStatus: "running", AccessStatus: "pending",
|
||||
})
|
||||
id2, _ := store.ImportBatches().Create(context.Background(), ImportBatch{
|
||||
HostID: hostID, PackID: packID, ProviderID: providerID,
|
||||
Mode: "strict", BatchStatus: "succeeded", AccessStatus: "subscription_ready",
|
||||
})
|
||||
|
||||
got, _ := store.ImportBatches().GetLatestByProviderID(context.Background(), providerID)
|
||||
if got.ID != id2 {
|
||||
t.Fatalf("latest id = %d, want %d", got.ID, id2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportBatchesRepoGetByIDNotFound(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
_, err := store.ImportBatches().GetByID(context.Background(), 999)
|
||||
if err == nil {
|
||||
t.Fatal("GetByID(999) error = nil, want error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportBatchesRepoCreateValidationErrors(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
batch ImportBatch
|
||||
}{
|
||||
{"host_id zero", ImportBatch{HostID: 0, PackID: 1, ProviderID: 1, Mode: "m", BatchStatus: "s", AccessStatus: "s"}},
|
||||
{"pack_id zero", ImportBatch{HostID: 1, PackID: 0, ProviderID: 1, Mode: "m", BatchStatus: "s", AccessStatus: "s"}},
|
||||
{"provider_id zero", ImportBatch{HostID: 1, PackID: 1, ProviderID: 0, Mode: "m", BatchStatus: "s", AccessStatus: "s"}},
|
||||
{"empty mode", ImportBatch{HostID: 1, PackID: 1, ProviderID: 1, Mode: "", BatchStatus: "s", AccessStatus: "s"}},
|
||||
{"empty batch_status", ImportBatch{HostID: 1, PackID: 1, ProviderID: 1, Mode: "m", BatchStatus: "", AccessStatus: "s"}},
|
||||
{"empty access_status", ImportBatch{HostID: 1, PackID: 1, ProviderID: 1, Mode: "m", BatchStatus: "s", AccessStatus: ""}},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := store.ImportBatches().Create(context.Background(), tt.batch)
|
||||
if err == nil {
|
||||
t.Fatal("Create() error = nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportBatchesRepoUpdateStatusValidation(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
if err := store.ImportBatches().UpdateStatus(context.Background(), 0, "s", "s"); err == nil {
|
||||
t.Fatal("UpdateStatus id=0 error = nil")
|
||||
}
|
||||
if err := store.ImportBatches().UpdateStatus(context.Background(), 1, "", "s"); err == nil {
|
||||
t.Fatal("UpdateStatus empty batch_status error = nil")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ImportBatchItems Repo Tests ---
|
||||
|
||||
func TestImportBatchItemsRepoCreateAndGet(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
batchID := createTestBatch(t, store)
|
||||
|
||||
id, err := store.ImportBatchItems().Create(context.Background(), ImportBatchItem{
|
||||
BatchID: batchID,
|
||||
KeyFingerprint: "sha256:abc",
|
||||
AccountStatus: "passed",
|
||||
ProbeSummaryJSON: `{"ok":true}`,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
if id <= 0 {
|
||||
t.Fatalf("Create() id = %d, want positive", id)
|
||||
}
|
||||
|
||||
items, _ := store.ImportBatchItems().GetByBatchID(context.Background(), batchID)
|
||||
if len(items) != 1 || items[0].AccountStatus != "passed" {
|
||||
t.Fatalf("items = %+v, want 1 with passed status", items)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportBatchItemsRepoMultipleItems(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
batchID := createTestBatch(t, store)
|
||||
|
||||
for _, status := range []string{"passed", "failed"} {
|
||||
store.ImportBatchItems().Create(context.Background(), ImportBatchItem{
|
||||
BatchID: batchID, KeyFingerprint: "sha256:" + status, AccountStatus: status,
|
||||
})
|
||||
}
|
||||
|
||||
items, _ := store.ImportBatchItems().GetByBatchID(context.Background(), batchID)
|
||||
if len(items) != 2 {
|
||||
t.Fatalf("count = %d, want 2", len(items))
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportBatchItemsRepoUpdateResult(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
batchID := createTestBatch(t, store)
|
||||
itemID, _ := store.ImportBatchItems().Create(context.Background(), ImportBatchItem{
|
||||
BatchID: batchID, KeyFingerprint: "sha256:x", AccountStatus: "pending",
|
||||
})
|
||||
|
||||
store.ImportBatchItems().UpdateResult(context.Background(), itemID, "passed", `{"ok":true}`)
|
||||
items, _ := store.ImportBatchItems().GetByBatchID(context.Background(), batchID)
|
||||
if items[0].AccountStatus != "passed" {
|
||||
t.Fatalf("AccountStatus = %q, want passed", items[0].AccountStatus)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportBatchItemsRepoGetByBatchIDEmpty(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
items, err := store.ImportBatchItems().GetByBatchID(context.Background(), 999)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByBatchID() error = %v, want empty result", err)
|
||||
}
|
||||
if len(items) != 0 {
|
||||
t.Fatalf("count = %d, want 0", len(items))
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportBatchItemsRepoValidation(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
_, err := store.ImportBatchItems().Create(context.Background(), ImportBatchItem{
|
||||
BatchID: 0, KeyFingerprint: "k", AccountStatus: "s",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Create batch_id=0 error = nil")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Managed Resources Repo Tests ---
|
||||
|
||||
func TestManagedResourcesRepoCreateAndGet(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
batchID := createTestBatch(t, store)
|
||||
|
||||
id, err := store.ManagedResources().Create(context.Background(), ManagedResource{
|
||||
BatchID: batchID, ResourceType: "group", HostResourceID: "g_01", ResourceName: "test-group",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
resources, _ := store.ManagedResources().GetByBatchID(context.Background(), batchID)
|
||||
if len(resources) != 1 || resources[0].HostResourceID != "g_01" {
|
||||
t.Fatalf("resources = %+v, want 1 with g_01", resources)
|
||||
}
|
||||
_ = id
|
||||
}
|
||||
|
||||
func TestManagedResourcesRepoMultipleResources(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
batchID := createTestBatch(t, store)
|
||||
|
||||
for _, r := range []ManagedResource{
|
||||
{BatchID: batchID, ResourceType: "group", HostResourceID: "g_01", ResourceName: "group-1"},
|
||||
{BatchID: batchID, ResourceType: "channel", HostResourceID: "c_01", ResourceName: "channel-1"},
|
||||
{BatchID: batchID, ResourceType: "account", HostResourceID: "a_01", ResourceName: "account-1"},
|
||||
} {
|
||||
store.ManagedResources().Create(context.Background(), r)
|
||||
}
|
||||
|
||||
resources, _ := store.ManagedResources().GetByBatchID(context.Background(), batchID)
|
||||
if len(resources) != 3 {
|
||||
t.Fatalf("count = %d, want 3", len(resources))
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagedResourcesRepoGetByBatchIDEmpty(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
resources, _ := store.ManagedResources().GetByBatchID(context.Background(), 999)
|
||||
if len(resources) != 0 {
|
||||
t.Fatalf("count = %d, want 0", len(resources))
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagedResourcesRepoValidationErrors(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
r ManagedResource
|
||||
}{
|
||||
{"batch_id zero", ManagedResource{ResourceType: "g", HostResourceID: "h", ResourceName: "n"}},
|
||||
{"empty resource_type", ManagedResource{BatchID: 1, HostResourceID: "h", ResourceName: "n"}},
|
||||
{"empty host_resource_id", ManagedResource{BatchID: 1, ResourceType: "g", ResourceName: "n"}},
|
||||
{"empty resource_name", ManagedResource{BatchID: 1, ResourceType: "g", HostResourceID: "h"}},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := store.ManagedResources().Create(context.Background(), tt.r)
|
||||
if err == nil {
|
||||
t.Fatal("Create() error = nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- Probe Results Repo Tests ---
|
||||
|
||||
func TestProbeResultsRepoCreateAndGet(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
batchID := createTestBatch(t, store)
|
||||
itemID := createTestBatchItem(t, store, batchID)
|
||||
|
||||
id, err := store.ProbeResults().Create(context.Background(), ProbeResult{
|
||||
BatchItemID: itemID, ProbeType: "account_smoke", Status: "passed", SummaryJSON: `{"ok":true}`,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
results, _ := store.ProbeResults().GetByBatchItemID(context.Background(), itemID)
|
||||
if len(results) != 1 || results[0].ProbeType != "account_smoke" {
|
||||
t.Fatalf("results = %+v, want 1 with account_smoke", results)
|
||||
}
|
||||
_ = id
|
||||
}
|
||||
|
||||
func TestProbeResultsRepoMultipleResults(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
batchID := createTestBatch(t, store)
|
||||
itemID := createTestBatchItem(t, store, batchID)
|
||||
|
||||
for _, p := range []ProbeResult{
|
||||
{BatchItemID: itemID, ProbeType: "account_smoke", Status: "passed", SummaryJSON: `{"ok":true}`},
|
||||
{BatchItemID: itemID, ProbeType: "model_list", Status: "passed", SummaryJSON: `{"models":["m1"]}`},
|
||||
} {
|
||||
store.ProbeResults().Create(context.Background(), p)
|
||||
}
|
||||
|
||||
results, _ := store.ProbeResults().GetByBatchItemID(context.Background(), itemID)
|
||||
if len(results) != 2 {
|
||||
t.Fatalf("count = %d, want 2", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProbeResultsRepoGetByBatchItemIDEmpty(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
results, _ := store.ProbeResults().GetByBatchItemID(context.Background(), 999)
|
||||
if len(results) != 0 {
|
||||
t.Fatalf("count = %d, want 0", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProbeResultsRepoValidationErrors(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
probe ProbeResult
|
||||
}{
|
||||
{"batch_item_id zero", ProbeResult{ProbeType: "t", Status: "s"}},
|
||||
{"empty probe_type", ProbeResult{BatchItemID: 1, Status: "s"}},
|
||||
{"empty status", ProbeResult{BatchItemID: 1, ProbeType: "t"}},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := store.ProbeResults().Create(context.Background(), tt.probe)
|
||||
if err == nil {
|
||||
t.Fatal("Create() error = nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- Access Closures Repo Tests ---
|
||||
|
||||
func TestAccessClosureRecordsRepoCreateAndGet(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
batchID := createTestBatch(t, store)
|
||||
|
||||
id, err := store.AccessClosures().Create(context.Background(), AccessClosureRecord{
|
||||
BatchID: batchID, ClosureType: "subscription", Status: "subscription_ready", DetailsJSON: `{"status_code":200}`,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
records, _ := store.AccessClosures().GetByBatchID(context.Background(), batchID)
|
||||
if len(records) != 1 || records[0].ClosureType != "subscription" {
|
||||
t.Fatalf("records = %+v, want 1 subscription", records)
|
||||
}
|
||||
_ = id
|
||||
}
|
||||
|
||||
func TestAccessClosureRecordsRepoMultiple(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
batchID := createTestBatch(t, store)
|
||||
store.AccessClosures().Create(context.Background(), AccessClosureRecord{BatchID: batchID, ClosureType: "subscription", Status: "subscription_ready", DetailsJSON: "{}"})
|
||||
store.AccessClosures().Create(context.Background(), AccessClosureRecord{BatchID: batchID, ClosureType: "self_service", Status: "self_service_ready", DetailsJSON: "{}"})
|
||||
records, _ := store.AccessClosures().GetByBatchID(context.Background(), batchID)
|
||||
if len(records) != 2 {
|
||||
t.Fatalf("count = %d, want 2", len(records))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessClosureRecordsRepoGetByBatchIDEmpty(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
records, _ := store.AccessClosures().GetByBatchID(context.Background(), 999)
|
||||
if len(records) != 0 {
|
||||
t.Fatalf("count = %d, want 0", len(records))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessClosureRecordsRepoValidation(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
_, err := store.AccessClosures().Create(context.Background(), AccessClosureRecord{BatchID: 0, ClosureType: "t", Status: "s"})
|
||||
if err == nil {
|
||||
t.Fatal("Create batch_id=0 error = nil")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Reconcile Runs Repo Tests ---
|
||||
|
||||
func createTestProviderWithPack(t *testing.T, store *DB, packID int64) int64 {
|
||||
t.Helper()
|
||||
id, err := store.Providers().Create(context.Background(), Provider{
|
||||
PackID: packID, ProviderID: "test-provider-" + sanitizeTestName(t.Name()), DisplayName: "TP",
|
||||
BaseURL: "https://tp.com", Platform: "openai",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("createTestProviderWithPack error = %v", err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func createTestProvider(t *testing.T, store *DB) int64 {
|
||||
t.Helper()
|
||||
packID := createTestPack(t, store)
|
||||
return createTestProviderWithPack(t, store, packID)
|
||||
}
|
||||
|
||||
func TestReconcileRunsRepoCreateAndGet(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
providerID := createTestProvider(t, store)
|
||||
|
||||
id, err := store.ReconcileRuns().Create(context.Background(), ReconcileRun{
|
||||
ProviderID: providerID, Status: "active", SummaryJSON: `{"drifted":false}`,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
runs, _ := store.ReconcileRuns().GetByProviderID(context.Background(), providerID)
|
||||
if len(runs) != 1 || runs[0].Status != "active" {
|
||||
t.Fatalf("runs = %+v, want 1 active", runs)
|
||||
}
|
||||
_ = id
|
||||
}
|
||||
|
||||
func TestReconcileRunsRepoMultipleRunsOrderedDesc(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
providerID := createTestProvider(t, store)
|
||||
|
||||
id1, _ := store.ReconcileRuns().Create(context.Background(), ReconcileRun{ProviderID: providerID, Status: "first", SummaryJSON: "{}"})
|
||||
id2, _ := store.ReconcileRuns().Create(context.Background(), ReconcileRun{ProviderID: providerID, Status: "second", SummaryJSON: "{}"})
|
||||
runs, _ := store.ReconcileRuns().GetByProviderID(context.Background(), providerID)
|
||||
if len(runs) != 2 || runs[0].ID != id2 || runs[1].ID != id1 {
|
||||
t.Fatalf("order: got %d, %d; want %d, %d (DESC)", runs[0].ID, runs[1].ID, id2, id1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconcileRunsRepoGetByProviderIDEmpty(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
runs, _ := store.ReconcileRuns().GetByProviderID(context.Background(), 999)
|
||||
if len(runs) != 0 {
|
||||
t.Fatalf("count = %d, want 0", len(runs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconcileRunsRepoValidation(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
_, err := store.ReconcileRuns().Create(context.Background(), ReconcileRun{ProviderID: 0, Status: "s"})
|
||||
if err == nil {
|
||||
t.Fatal("Create provider_id=0 error = nil")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
76
internal/store/sqlite/managed_resources_repo.go
Normal file
76
internal/store/sqlite/managed_resources_repo.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ManagedResource struct {
|
||||
ID int64
|
||||
BatchID int64
|
||||
ResourceType string
|
||||
HostResourceID string
|
||||
ResourceName string
|
||||
}
|
||||
|
||||
type ManagedResourcesRepo struct {
|
||||
db execQuerier
|
||||
}
|
||||
|
||||
func newManagedResourcesRepo(db execQuerier) *ManagedResourcesRepo {
|
||||
return &ManagedResourcesRepo{db: db}
|
||||
}
|
||||
|
||||
func (r *ManagedResourcesRepo) Create(ctx context.Context, resource ManagedResource) (int64, error) {
|
||||
resourceType := strings.TrimSpace(resource.ResourceType)
|
||||
hostResourceID := strings.TrimSpace(resource.HostResourceID)
|
||||
resourceName := strings.TrimSpace(resource.ResourceName)
|
||||
|
||||
switch {
|
||||
case resource.BatchID <= 0:
|
||||
return 0, fmt.Errorf("batch_id is required")
|
||||
case resourceType == "":
|
||||
return 0, fmt.Errorf("resource_type is required")
|
||||
case hostResourceID == "":
|
||||
return 0, fmt.Errorf("host_resource_id is required")
|
||||
case resourceName == "":
|
||||
return 0, fmt.Errorf("resource_name is required")
|
||||
}
|
||||
|
||||
result, err := r.db.ExecContext(ctx, `INSERT INTO managed_resources (batch_id, resource_type, host_resource_id, resource_name) VALUES (?, ?, ?, ?)`, resource.BatchID, resourceType, hostResourceID, resourceName)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("insert managed resource %q: %w", hostResourceID, err)
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("read inserted managed resource id for %q: %w", hostResourceID, err)
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (r *ManagedResourcesRepo) GetByBatchID(ctx context.Context, batchID int64) ([]ManagedResource, error) {
|
||||
if batchID <= 0 {
|
||||
return nil, fmt.Errorf("batch_id is required")
|
||||
}
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, `SELECT id, batch_id, resource_type, host_resource_id, resource_name FROM managed_resources WHERE batch_id = ? ORDER BY id`, batchID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query managed resources: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
resources := make([]ManagedResource, 0)
|
||||
for rows.Next() {
|
||||
var resource ManagedResource
|
||||
if err := rows.Scan(&resource.ID, &resource.BatchID, &resource.ResourceType, &resource.HostResourceID, &resource.ResourceName); err != nil {
|
||||
return nil, fmt.Errorf("scan managed resource: %w", err)
|
||||
}
|
||||
resources = append(resources, resource)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate managed resources: %w", err)
|
||||
}
|
||||
return resources, nil
|
||||
}
|
||||
@@ -7,9 +7,15 @@ import (
|
||||
)
|
||||
|
||||
type Pack struct {
|
||||
PackID string
|
||||
Version string
|
||||
Checksum string
|
||||
ID int64
|
||||
PackID string
|
||||
Version string
|
||||
Checksum string
|
||||
Vendor string
|
||||
TargetHost string
|
||||
MinHostVersion string
|
||||
MaxHostVersion string
|
||||
ManifestJSON string
|
||||
}
|
||||
|
||||
type PacksRepo struct {
|
||||
@@ -20,10 +26,59 @@ func newPacksRepo(db execQuerier) *PacksRepo {
|
||||
return &PacksRepo{db: db}
|
||||
}
|
||||
|
||||
func (r *PacksRepo) GetByID(ctx context.Context, id int64) (Pack, error) {
|
||||
if id <= 0 {
|
||||
return Pack{}, fmt.Errorf("id is required")
|
||||
}
|
||||
|
||||
var pack Pack
|
||||
if err := r.db.QueryRowContext(ctx, `SELECT id, pack_id, version, checksum, vendor, target_host, min_host_version, max_host_version, manifest_json FROM packs WHERE id = ?`, id).Scan(
|
||||
&pack.ID,
|
||||
&pack.PackID,
|
||||
&pack.Version,
|
||||
&pack.Checksum,
|
||||
&pack.Vendor,
|
||||
&pack.TargetHost,
|
||||
&pack.MinHostVersion,
|
||||
&pack.MaxHostVersion,
|
||||
&pack.ManifestJSON,
|
||||
); err != nil {
|
||||
return Pack{}, err
|
||||
}
|
||||
return pack, nil
|
||||
}
|
||||
|
||||
func (r *PacksRepo) GetByPackID(ctx context.Context, packID string) (Pack, error) {
|
||||
packID = strings.TrimSpace(packID)
|
||||
if packID == "" {
|
||||
return Pack{}, fmt.Errorf("pack_id is required")
|
||||
}
|
||||
|
||||
var pack Pack
|
||||
if err := r.db.QueryRowContext(ctx, `SELECT id, pack_id, version, checksum, vendor, target_host, min_host_version, max_host_version, manifest_json FROM packs WHERE pack_id = ?`, packID).Scan(
|
||||
&pack.ID,
|
||||
&pack.PackID,
|
||||
&pack.Version,
|
||||
&pack.Checksum,
|
||||
&pack.Vendor,
|
||||
&pack.TargetHost,
|
||||
&pack.MinHostVersion,
|
||||
&pack.MaxHostVersion,
|
||||
&pack.ManifestJSON,
|
||||
); err != nil {
|
||||
return Pack{}, err
|
||||
}
|
||||
return pack, nil
|
||||
}
|
||||
|
||||
func (r *PacksRepo) Create(ctx context.Context, pack Pack) (int64, error) {
|
||||
packID := strings.TrimSpace(pack.PackID)
|
||||
version := strings.TrimSpace(pack.Version)
|
||||
checksum := strings.TrimSpace(pack.Checksum)
|
||||
manifestJSON := strings.TrimSpace(pack.ManifestJSON)
|
||||
if manifestJSON == "" {
|
||||
manifestJSON = "{}"
|
||||
}
|
||||
|
||||
switch {
|
||||
case packID == "":
|
||||
@@ -36,11 +91,16 @@ func (r *PacksRepo) Create(ctx context.Context, pack Pack) (int64, error) {
|
||||
|
||||
result, err := r.db.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO packs (pack_id, version, checksum)
|
||||
VALUES (?, ?, ?)`,
|
||||
`INSERT INTO packs (pack_id, version, checksum, vendor, target_host, min_host_version, max_host_version, manifest_json)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
packID,
|
||||
version,
|
||||
checksum,
|
||||
strings.TrimSpace(pack.Vendor),
|
||||
strings.TrimSpace(pack.TargetHost),
|
||||
strings.TrimSpace(pack.MinHostVersion),
|
||||
strings.TrimSpace(pack.MaxHostVersion),
|
||||
manifestJSON,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("insert pack %q: %w", packID, err)
|
||||
@@ -50,6 +110,62 @@ func (r *PacksRepo) Create(ctx context.Context, pack Pack) (int64, error) {
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("read inserted pack id for %q: %w", packID, err)
|
||||
}
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (r *PacksRepo) Upsert(ctx context.Context, pack Pack) (int64, error) {
|
||||
packID := strings.TrimSpace(pack.PackID)
|
||||
version := strings.TrimSpace(pack.Version)
|
||||
checksum := strings.TrimSpace(pack.Checksum)
|
||||
manifestJSON := strings.TrimSpace(pack.ManifestJSON)
|
||||
if manifestJSON == "" {
|
||||
manifestJSON = "{}"
|
||||
}
|
||||
|
||||
switch {
|
||||
case packID == "":
|
||||
return 0, fmt.Errorf("pack_id is required")
|
||||
case version == "":
|
||||
return 0, fmt.Errorf("version is required")
|
||||
case checksum == "":
|
||||
return 0, fmt.Errorf("checksum is required")
|
||||
}
|
||||
|
||||
result, err := r.db.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO packs (pack_id, version, checksum, vendor, target_host, min_host_version, max_host_version, manifest_json)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(pack_id) DO UPDATE SET
|
||||
version = excluded.version,
|
||||
checksum = excluded.checksum,
|
||||
vendor = excluded.vendor,
|
||||
target_host = excluded.target_host,
|
||||
min_host_version = excluded.min_host_version,
|
||||
max_host_version = excluded.max_host_version,
|
||||
manifest_json = excluded.manifest_json`,
|
||||
packID,
|
||||
version,
|
||||
checksum,
|
||||
strings.TrimSpace(pack.Vendor),
|
||||
strings.TrimSpace(pack.TargetHost),
|
||||
strings.TrimSpace(pack.MinHostVersion),
|
||||
strings.TrimSpace(pack.MaxHostVersion),
|
||||
manifestJSON,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("upsert pack %q: %w", packID, err)
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err == nil && id > 0 {
|
||||
return id, nil
|
||||
}
|
||||
persisted, getErr := r.GetByPackID(ctx, packID)
|
||||
if getErr != nil {
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("read upserted pack %q: %w", packID, getErr)
|
||||
}
|
||||
return 0, getErr
|
||||
}
|
||||
return persisted.ID, nil
|
||||
}
|
||||
|
||||
159
internal/store/sqlite/packs_repo_test.go
Normal file
159
internal/store/sqlite/packs_repo_test.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPacksRepoCreateAndGet(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
|
||||
id, err := store.Packs().Create(context.Background(), Pack{
|
||||
PackID: "test-pack",
|
||||
Version: "1.0.0",
|
||||
Checksum: "abc123",
|
||||
Vendor: "test-vendor",
|
||||
TargetHost: "sub2api",
|
||||
MinHostVersion: "0.1.0",
|
||||
MaxHostVersion: "0.2.x",
|
||||
ManifestJSON: `{"name":"test"}`,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
if id <= 0 {
|
||||
t.Fatalf("Create() id = %d, want positive", id)
|
||||
}
|
||||
|
||||
got, err := store.Packs().GetByID(context.Background(), id)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByID() error = %v", err)
|
||||
}
|
||||
if got.PackID != "test-pack" || got.Version != "1.0.0" {
|
||||
t.Fatalf("GetByID() = %+v, want pack test-pack", got)
|
||||
}
|
||||
|
||||
got2, err := store.Packs().GetByPackID(context.Background(), "test-pack")
|
||||
if err != nil {
|
||||
t.Fatalf("GetByPackID() error = %v", err)
|
||||
}
|
||||
if got2.ID != id {
|
||||
t.Fatalf("GetByPackID() id = %d, want %d", got2.ID, id)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPacksRepoCreateDefaultsManifestJSON(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
|
||||
id, err := store.Packs().Create(context.Background(), Pack{
|
||||
PackID: "no-manifest",
|
||||
Version: "1.0.0",
|
||||
Checksum: "chk",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
got, err := store.Packs().GetByID(context.Background(), id)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByID() error = %v", err)
|
||||
}
|
||||
if got.ManifestJSON != "{}" {
|
||||
t.Fatalf("ManifestJSON = %q, want {}", got.ManifestJSON)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPacksRepoUpsertCreatesNew(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
|
||||
id, err := store.Packs().Upsert(context.Background(), Pack{
|
||||
PackID: "upsert-pack",
|
||||
Version: "1.0.0",
|
||||
Checksum: "chk1",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Upsert() error = %v", err)
|
||||
}
|
||||
if id <= 0 {
|
||||
t.Fatalf("Upsert() id = %d, want positive", id)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPacksRepoUpsertUpdatesExisting(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
|
||||
id1, err := store.Packs().Upsert(context.Background(), Pack{
|
||||
PackID: "upsert-pack",
|
||||
Version: "1.0.0",
|
||||
Checksum: "chk1",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Upsert() create error = %v", err)
|
||||
}
|
||||
|
||||
id2, err := store.Packs().Upsert(context.Background(), Pack{
|
||||
PackID: "upsert-pack",
|
||||
Version: "2.0.0",
|
||||
Checksum: "chk2",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Upsert() update error = %v", err)
|
||||
}
|
||||
if id2 != id1 {
|
||||
t.Fatalf("Upsert() update returned id %d, want original %d", id2, id1)
|
||||
}
|
||||
|
||||
got, err := store.Packs().GetByID(context.Background(), id1)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByID() error = %v", err)
|
||||
}
|
||||
if got.Version != "2.0.0" {
|
||||
t.Fatalf("Version after upsert = %q, want 2.0.0", got.Version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPacksRepoValidationErrors(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
pack Pack
|
||||
}{
|
||||
{"empty pack_id", Pack{Version: "v", Checksum: "c"}},
|
||||
{"empty version", Pack{PackID: "p", Checksum: "c"}},
|
||||
{"empty checksum", Pack{PackID: "p", Version: "v"}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := store.Packs().Create(context.Background(), tt.pack)
|
||||
if err == nil {
|
||||
t.Fatal("Create() error = nil, want validation error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPacksRepoGetByIDNotFound(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
_, err := store.Packs().GetByID(context.Background(), 999)
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
t.Fatalf("GetByID(999) error = %v, want sql.ErrNoRows", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPacksRepoGetByPackIDEmptyError(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
_, err := store.Packs().GetByPackID(context.Background(), "")
|
||||
if err == nil {
|
||||
t.Fatal("GetByPackID('') error = nil, want error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPacksRepoGetByPackIDNotFound(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
_, err := store.Packs().GetByPackID(context.Background(), "nonexistent")
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
t.Fatalf("GetByPackID() error = %v, want sql.ErrNoRows", err)
|
||||
}
|
||||
}
|
||||
77
internal/store/sqlite/probe_results_repo.go
Normal file
77
internal/store/sqlite/probe_results_repo.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ProbeResult struct {
|
||||
ID int64
|
||||
BatchItemID int64
|
||||
ProbeType string
|
||||
Status string
|
||||
SummaryJSON string
|
||||
}
|
||||
|
||||
type ProbeResultsRepo struct {
|
||||
db execQuerier
|
||||
}
|
||||
|
||||
func newProbeResultsRepo(db execQuerier) *ProbeResultsRepo {
|
||||
return &ProbeResultsRepo{db: db}
|
||||
}
|
||||
|
||||
func (r *ProbeResultsRepo) Create(ctx context.Context, probe ProbeResult) (int64, error) {
|
||||
probeType := strings.TrimSpace(probe.ProbeType)
|
||||
status := strings.TrimSpace(probe.Status)
|
||||
summaryJSON := strings.TrimSpace(probe.SummaryJSON)
|
||||
if summaryJSON == "" {
|
||||
summaryJSON = "{}"
|
||||
}
|
||||
|
||||
switch {
|
||||
case probe.BatchItemID <= 0:
|
||||
return 0, fmt.Errorf("batch_item_id is required")
|
||||
case probeType == "":
|
||||
return 0, fmt.Errorf("probe_type is required")
|
||||
case status == "":
|
||||
return 0, fmt.Errorf("status is required")
|
||||
}
|
||||
|
||||
result, err := r.db.ExecContext(ctx, `INSERT INTO probe_results (batch_item_id, probe_type, status, summary_json) VALUES (?, ?, ?, ?)`, probe.BatchItemID, probeType, status, summaryJSON)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("insert probe result: %w", err)
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("read inserted probe result id: %w", err)
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (r *ProbeResultsRepo) GetByBatchItemID(ctx context.Context, batchItemID int64) ([]ProbeResult, error) {
|
||||
if batchItemID <= 0 {
|
||||
return nil, fmt.Errorf("batch_item_id is required")
|
||||
}
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, `SELECT id, batch_item_id, probe_type, status, summary_json FROM probe_results WHERE batch_item_id = ? ORDER BY id`, batchItemID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query probe results: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
probes := make([]ProbeResult, 0)
|
||||
for rows.Next() {
|
||||
var probe ProbeResult
|
||||
if err := rows.Scan(&probe.ID, &probe.BatchItemID, &probe.ProbeType, &probe.Status, &probe.SummaryJSON); err != nil {
|
||||
return nil, fmt.Errorf("scan probe result: %w", err)
|
||||
}
|
||||
probes = append(probes, probe)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate probe results: %w", err)
|
||||
}
|
||||
return probes, nil
|
||||
}
|
||||
@@ -7,11 +7,20 @@ import (
|
||||
)
|
||||
|
||||
type Provider struct {
|
||||
PackID int64
|
||||
ProviderID string
|
||||
DisplayName string
|
||||
BaseURL string
|
||||
Platform string
|
||||
ID int64
|
||||
PackID int64
|
||||
ProviderID string
|
||||
DisplayName string
|
||||
BaseURL string
|
||||
Platform string
|
||||
AccountType string
|
||||
DefaultModelsJSON string
|
||||
SmokeTestModel string
|
||||
GroupTemplateJSON string
|
||||
ChannelTemplateJSON string
|
||||
PlanTemplateJSON string
|
||||
ImportOptionsJSON string
|
||||
ManifestJSON string
|
||||
}
|
||||
|
||||
type ProvidersRepo struct {
|
||||
@@ -22,11 +31,87 @@ func newProvidersRepo(db execQuerier) *ProvidersRepo {
|
||||
return &ProvidersRepo{db: db}
|
||||
}
|
||||
|
||||
func (r *ProvidersRepo) ListByProviderID(ctx context.Context, providerID string) ([]Provider, error) {
|
||||
providerID = strings.TrimSpace(providerID)
|
||||
if providerID == "" {
|
||||
return nil, fmt.Errorf("provider_id is required")
|
||||
}
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, `SELECT id, pack_id, provider_id, display_name, base_url, platform, account_type, default_models_json, smoke_test_model, group_template_json, channel_template_json, plan_template_json, import_options_json, manifest_json FROM providers WHERE provider_id = ? ORDER BY id`, providerID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query providers by provider_id %q: %w", providerID, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
providers := make([]Provider, 0)
|
||||
for rows.Next() {
|
||||
var provider Provider
|
||||
if err := rows.Scan(
|
||||
&provider.ID,
|
||||
&provider.PackID,
|
||||
&provider.ProviderID,
|
||||
&provider.DisplayName,
|
||||
&provider.BaseURL,
|
||||
&provider.Platform,
|
||||
&provider.AccountType,
|
||||
&provider.DefaultModelsJSON,
|
||||
&provider.SmokeTestModel,
|
||||
&provider.GroupTemplateJSON,
|
||||
&provider.ChannelTemplateJSON,
|
||||
&provider.PlanTemplateJSON,
|
||||
&provider.ImportOptionsJSON,
|
||||
&provider.ManifestJSON,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("scan provider by provider_id %q: %w", providerID, err)
|
||||
}
|
||||
providers = append(providers, provider)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate providers by provider_id %q: %w", providerID, err)
|
||||
}
|
||||
return providers, nil
|
||||
}
|
||||
|
||||
func (r *ProvidersRepo) GetByPackIDAndProviderID(ctx context.Context, packID int64, providerID string) (Provider, error) {
|
||||
if packID <= 0 {
|
||||
return Provider{}, fmt.Errorf("pack_id is required")
|
||||
}
|
||||
providerID = strings.TrimSpace(providerID)
|
||||
if providerID == "" {
|
||||
return Provider{}, fmt.Errorf("provider_id is required")
|
||||
}
|
||||
|
||||
var provider Provider
|
||||
if err := r.db.QueryRowContext(ctx, `SELECT id, pack_id, provider_id, display_name, base_url, platform, account_type, default_models_json, smoke_test_model, group_template_json, channel_template_json, plan_template_json, import_options_json, manifest_json FROM providers WHERE pack_id = ? AND provider_id = ?`, packID, providerID).Scan(
|
||||
&provider.ID,
|
||||
&provider.PackID,
|
||||
&provider.ProviderID,
|
||||
&provider.DisplayName,
|
||||
&provider.BaseURL,
|
||||
&provider.Platform,
|
||||
&provider.AccountType,
|
||||
&provider.DefaultModelsJSON,
|
||||
&provider.SmokeTestModel,
|
||||
&provider.GroupTemplateJSON,
|
||||
&provider.ChannelTemplateJSON,
|
||||
&provider.PlanTemplateJSON,
|
||||
&provider.ImportOptionsJSON,
|
||||
&provider.ManifestJSON,
|
||||
); err != nil {
|
||||
return Provider{}, err
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (r *ProvidersRepo) Create(ctx context.Context, provider Provider) (int64, error) {
|
||||
providerID := strings.TrimSpace(provider.ProviderID)
|
||||
displayName := strings.TrimSpace(provider.DisplayName)
|
||||
baseURL := strings.TrimSpace(provider.BaseURL)
|
||||
platform := strings.TrimSpace(provider.Platform)
|
||||
manifestJSON := strings.TrimSpace(provider.ManifestJSON)
|
||||
if manifestJSON == "" {
|
||||
manifestJSON = "{}"
|
||||
}
|
||||
|
||||
switch {
|
||||
case provider.PackID <= 0:
|
||||
@@ -43,13 +128,21 @@ func (r *ProvidersRepo) Create(ctx context.Context, provider Provider) (int64, e
|
||||
|
||||
result, err := r.db.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO providers (pack_id, provider_id, display_name, base_url, platform)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
`INSERT INTO providers (pack_id, provider_id, display_name, base_url, platform, account_type, default_models_json, smoke_test_model, group_template_json, channel_template_json, plan_template_json, import_options_json, manifest_json)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
provider.PackID,
|
||||
providerID,
|
||||
displayName,
|
||||
baseURL,
|
||||
platform,
|
||||
strings.TrimSpace(provider.AccountType),
|
||||
defaultJSONArray(provider.DefaultModelsJSON),
|
||||
strings.TrimSpace(provider.SmokeTestModel),
|
||||
defaultJSONObject(provider.GroupTemplateJSON),
|
||||
defaultJSONObject(provider.ChannelTemplateJSON),
|
||||
defaultJSONObject(provider.PlanTemplateJSON),
|
||||
defaultJSONObject(provider.ImportOptionsJSON),
|
||||
manifestJSON,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("insert provider %q: %w", providerID, err)
|
||||
@@ -59,6 +152,90 @@ func (r *ProvidersRepo) Create(ctx context.Context, provider Provider) (int64, e
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("read inserted provider id for %q: %w", providerID, err)
|
||||
}
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (r *ProvidersRepo) Upsert(ctx context.Context, provider Provider) (int64, error) {
|
||||
providerID := strings.TrimSpace(provider.ProviderID)
|
||||
displayName := strings.TrimSpace(provider.DisplayName)
|
||||
baseURL := strings.TrimSpace(provider.BaseURL)
|
||||
platform := strings.TrimSpace(provider.Platform)
|
||||
manifestJSON := strings.TrimSpace(provider.ManifestJSON)
|
||||
if manifestJSON == "" {
|
||||
manifestJSON = "{}"
|
||||
}
|
||||
|
||||
switch {
|
||||
case provider.PackID <= 0:
|
||||
return 0, fmt.Errorf("pack_id is required")
|
||||
case providerID == "":
|
||||
return 0, fmt.Errorf("provider_id is required")
|
||||
case displayName == "":
|
||||
return 0, fmt.Errorf("display_name is required")
|
||||
case baseURL == "":
|
||||
return 0, fmt.Errorf("base_url is required")
|
||||
case platform == "":
|
||||
return 0, fmt.Errorf("platform is required")
|
||||
}
|
||||
|
||||
result, err := r.db.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO providers (pack_id, provider_id, display_name, base_url, platform, account_type, default_models_json, smoke_test_model, group_template_json, channel_template_json, plan_template_json, import_options_json, manifest_json)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(pack_id, provider_id) DO UPDATE SET
|
||||
display_name = excluded.display_name,
|
||||
base_url = excluded.base_url,
|
||||
platform = excluded.platform,
|
||||
account_type = excluded.account_type,
|
||||
default_models_json = excluded.default_models_json,
|
||||
smoke_test_model = excluded.smoke_test_model,
|
||||
group_template_json = excluded.group_template_json,
|
||||
channel_template_json = excluded.channel_template_json,
|
||||
plan_template_json = excluded.plan_template_json,
|
||||
import_options_json = excluded.import_options_json,
|
||||
manifest_json = excluded.manifest_json`,
|
||||
provider.PackID,
|
||||
providerID,
|
||||
displayName,
|
||||
baseURL,
|
||||
platform,
|
||||
strings.TrimSpace(provider.AccountType),
|
||||
defaultJSONArray(provider.DefaultModelsJSON),
|
||||
strings.TrimSpace(provider.SmokeTestModel),
|
||||
defaultJSONObject(provider.GroupTemplateJSON),
|
||||
defaultJSONObject(provider.ChannelTemplateJSON),
|
||||
defaultJSONObject(provider.PlanTemplateJSON),
|
||||
defaultJSONObject(provider.ImportOptionsJSON),
|
||||
manifestJSON,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("upsert provider %q: %w", providerID, err)
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err == nil && id > 0 {
|
||||
return id, nil
|
||||
}
|
||||
persisted, getErr := r.GetByPackIDAndProviderID(ctx, provider.PackID, providerID)
|
||||
if getErr != nil {
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("read upserted provider %q: %w", providerID, getErr)
|
||||
}
|
||||
return 0, getErr
|
||||
}
|
||||
return persisted.ID, nil
|
||||
}
|
||||
|
||||
func defaultJSONObject(value string) string {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return "{}"
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func defaultJSONArray(value string) string {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return "[]"
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
172
internal/store/sqlite/providers_repo_test.go
Normal file
172
internal/store/sqlite/providers_repo_test.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestProvidersRepoCreateAndGet(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
|
||||
packID := createTestPack(t, store)
|
||||
providerID, err := store.Providers().Create(context.Background(), Provider{
|
||||
PackID: packID,
|
||||
ProviderID: "deepseek",
|
||||
DisplayName: "DeepSeek",
|
||||
BaseURL: "https://api.deepseek.com",
|
||||
Platform: "openai",
|
||||
AccountType: "api",
|
||||
SmokeTestModel: "deepseek-chat",
|
||||
ManifestJSON: `{"models":["deepseek-chat"]}`,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
if providerID <= 0 {
|
||||
t.Fatalf("Create() id = %d, want positive", providerID)
|
||||
}
|
||||
|
||||
got, err := store.Providers().GetByPackIDAndProviderID(context.Background(), packID, "deepseek")
|
||||
if err != nil {
|
||||
t.Fatalf("GetByPackIDAndProviderID() error = %v", err)
|
||||
}
|
||||
if got.ProviderID != "deepseek" || got.DisplayName != "DeepSeek" {
|
||||
t.Fatalf("GetByPackIDAndProviderID() = %+v, want deepseek", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvidersRepoListByProviderID(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
|
||||
packID1 := createTestPackWithSuffix(t, store, "a")
|
||||
packID2 := createTestPackWithSuffix(t, store, "b")
|
||||
|
||||
store.Providers().Create(context.Background(), Provider{PackID: packID1, ProviderID: "deepseek", DisplayName: "DS1", BaseURL: "https://a.com", Platform: "openai"})
|
||||
store.Providers().Create(context.Background(), Provider{PackID: packID2, ProviderID: "deepseek", DisplayName: "DS2", BaseURL: "https://b.com", Platform: "openai"})
|
||||
|
||||
providers, err := store.Providers().ListByProviderID(context.Background(), "deepseek")
|
||||
if err != nil {
|
||||
t.Fatalf("ListByProviderID() error = %v", err)
|
||||
}
|
||||
if len(providers) != 2 {
|
||||
t.Fatalf("ListByProviderID() count = %d, want 2", len(providers))
|
||||
}
|
||||
}
|
||||
|
||||
func createTestPackWithSuffix(t *testing.T, store *DB, suffix string) int64 {
|
||||
t.Helper()
|
||||
id, err := store.Packs().Create(context.Background(), Pack{
|
||||
PackID: "pack-" + sanitizeTestName(t.Name()) + "-" + suffix, Version: "1.0.0", Checksum: "chk",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("createTestPackWithSuffix error = %v", err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func TestProvidersRepoListByProviderIDEmpty(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
|
||||
providers, err := store.Providers().ListByProviderID(context.Background(), "nonexistent")
|
||||
if err != nil {
|
||||
t.Fatalf("ListByProviderID() error = %v", err)
|
||||
}
|
||||
if len(providers) != 0 {
|
||||
t.Fatalf("ListByProviderID() count = %d, want 0", len(providers))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvidersRepoUpsertCreatesNew(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
packID := createTestPack(t, store)
|
||||
|
||||
id, err := store.Providers().Upsert(context.Background(), Provider{
|
||||
PackID: packID, ProviderID: "upsert-p", DisplayName: "P", BaseURL: "https://u.com", Platform: "openai",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Upsert() error = %v", err)
|
||||
}
|
||||
if id <= 0 {
|
||||
t.Fatalf("Upsert() id = %d, want positive", id)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvidersRepoUpsertUpdatesExisting(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
packID := createTestPack(t, store)
|
||||
|
||||
id1, _ := store.Providers().Upsert(context.Background(), Provider{
|
||||
PackID: packID, ProviderID: "upsert-p", DisplayName: "P1", BaseURL: "https://u1.com", Platform: "openai",
|
||||
})
|
||||
id2, _ := store.Providers().Upsert(context.Background(), Provider{
|
||||
PackID: packID, ProviderID: "upsert-p", DisplayName: "P2", BaseURL: "https://u2.com", Platform: "openai",
|
||||
})
|
||||
if id2 != id1 {
|
||||
t.Fatalf("Upsert update id = %d, want %d", id2, id1)
|
||||
}
|
||||
|
||||
got, _ := store.Providers().GetByPackIDAndProviderID(context.Background(), packID, "upsert-p")
|
||||
if got.DisplayName != "P2" {
|
||||
t.Fatalf("DisplayName after upsert = %q, want P2", got.DisplayName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvidersRepoValidationErrors(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
packID := createTestPack(t, store)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
provider Provider
|
||||
}{
|
||||
{"pack_id zero", Provider{ProviderID: "p", DisplayName: "d", BaseURL: "b", Platform: "openai"}},
|
||||
{"empty provider_id", Provider{PackID: packID, DisplayName: "d", BaseURL: "b", Platform: "openai"}},
|
||||
{"empty display_name", Provider{PackID: packID, ProviderID: "p", BaseURL: "b", Platform: "openai"}},
|
||||
{"empty base_url", Provider{PackID: packID, ProviderID: "p", DisplayName: "d", Platform: "openai"}},
|
||||
{"empty platform", Provider{PackID: packID, ProviderID: "p", DisplayName: "d", BaseURL: "b"}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := store.Providers().Create(context.Background(), tt.provider)
|
||||
if err == nil {
|
||||
t.Fatal("Create() error = nil, want validation error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvidersRepoGetByPackIDAndProviderIDNotFound(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
_, err := store.Providers().GetByPackIDAndProviderID(context.Background(), 999, "p")
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
t.Fatalf("GetByPackIDAndProviderID() error = %v, want sql.ErrNoRows", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvidersRepoGetByPackIDEmptyError(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
_, err := store.Providers().GetByPackIDAndProviderID(context.Background(), 0, "p")
|
||||
if err == nil {
|
||||
t.Fatal("GetByPackIDAndProviderID with packID=0 error = nil, want error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultJSONObject(t *testing.T) {
|
||||
if got := defaultJSONObject(""); got != "{}" {
|
||||
t.Fatalf("defaultJSONObject('') = %q, want {}", got)
|
||||
}
|
||||
if got := defaultJSONObject(`{"a":1}`); got != `{"a":1}` {
|
||||
t.Fatalf("defaultJSONObject() = %q, want input", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultJSONArray(t *testing.T) {
|
||||
if got := defaultJSONArray(""); got != "[]" {
|
||||
t.Fatalf("defaultJSONArray('') = %q, want []", got)
|
||||
}
|
||||
if got := defaultJSONArray(`["a"]`); got != `["a"]` {
|
||||
t.Fatalf("defaultJSONArray() = %q, want input", got)
|
||||
}
|
||||
}
|
||||
73
internal/store/sqlite/reconcile_runs_repo.go
Normal file
73
internal/store/sqlite/reconcile_runs_repo.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ReconcileRun struct {
|
||||
ID int64
|
||||
ProviderID int64
|
||||
Status string
|
||||
SummaryJSON string
|
||||
}
|
||||
|
||||
type ReconcileRunsRepo struct {
|
||||
db execQuerier
|
||||
}
|
||||
|
||||
func newReconcileRunsRepo(db execQuerier) *ReconcileRunsRepo {
|
||||
return &ReconcileRunsRepo{db: db}
|
||||
}
|
||||
|
||||
func (r *ReconcileRunsRepo) Create(ctx context.Context, run ReconcileRun) (int64, error) {
|
||||
status := strings.TrimSpace(run.Status)
|
||||
summaryJSON := strings.TrimSpace(run.SummaryJSON)
|
||||
if summaryJSON == "" {
|
||||
summaryJSON = "{}"
|
||||
}
|
||||
|
||||
switch {
|
||||
case run.ProviderID <= 0:
|
||||
return 0, fmt.Errorf("provider_id is required")
|
||||
case status == "":
|
||||
return 0, fmt.Errorf("status is required")
|
||||
}
|
||||
|
||||
result, err := r.db.ExecContext(ctx, `INSERT INTO reconcile_runs (provider_id, status, summary_json) VALUES (?, ?, ?)`, run.ProviderID, status, summaryJSON)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("insert reconcile run: %w", err)
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("read inserted reconcile run id: %w", err)
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (r *ReconcileRunsRepo) GetByProviderID(ctx context.Context, providerID int64) ([]ReconcileRun, error) {
|
||||
if providerID <= 0 {
|
||||
return nil, fmt.Errorf("provider_id is required")
|
||||
}
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, `SELECT id, provider_id, status, summary_json FROM reconcile_runs WHERE provider_id = ? ORDER BY id DESC`, providerID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query reconcile runs: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
runs := make([]ReconcileRun, 0)
|
||||
for rows.Next() {
|
||||
var run ReconcileRun
|
||||
if err := rows.Scan(&run.ID, &run.ProviderID, &run.Status, &run.SummaryJSON); err != nil {
|
||||
return nil, fmt.Errorf("scan reconcile run: %w", err)
|
||||
}
|
||||
runs = append(runs, run)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate reconcile runs: %w", err)
|
||||
}
|
||||
return runs, nil
|
||||
}
|
||||
@@ -4,10 +4,10 @@
|
||||
|
||||
它不是宿主原生插件,而是一个可被控制面读取的 `model_pack`,用于描述国产模型 provider 的默认接入模板、默认模型映射、默认套餐和导入约束。
|
||||
|
||||
当前目录仅提供协议样例:
|
||||
当前目录现在同时包含:
|
||||
|
||||
- `pack.json.example`
|
||||
- `providers/deepseek.json.example`
|
||||
- 真实可校验包:`pack.json`、`providers/deepseek.json`、`checksums.txt`
|
||||
- 协议样例:`pack.json.example`、`providers/deepseek.json.example`
|
||||
|
||||
后续真实交付时,可以扩展更多 provider:
|
||||
|
||||
|
||||
2
packs/openai-cn-pack/checksums.txt
Normal file
2
packs/openai-cn-pack/checksums.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
db931e9a90f6c1040d285c65582c5dae4c85075e85ce6d87e59cd39a6441d6f1 pack.json
|
||||
fc2259a85de73cd14ea3f0d6ffdf71be79296d50cf9cbee604633d36492fec49 providers/deepseek.json
|
||||
10
packs/openai-cn-pack/pack.json
Normal file
10
packs/openai-cn-pack/pack.json
Normal file
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"pack_id": "openai-cn-pack",
|
||||
"version": "1.0.0",
|
||||
"vendor": "YourTeam",
|
||||
"target_host": "sub2api",
|
||||
"min_host_version": "0.1.126",
|
||||
"max_host_version": "0.2.x",
|
||||
"providers_dir": "providers",
|
||||
"checksum_file": "checksums.txt"
|
||||
}
|
||||
31
packs/openai-cn-pack/providers/deepseek.json
Normal file
31
packs/openai-cn-pack/providers/deepseek.json
Normal file
@@ -0,0 +1,31 @@
|
||||
{
|
||||
"provider_id": "deepseek",
|
||||
"display_name": "DeepSeek OpenAI Compatible",
|
||||
"base_url": "https://api.deepseek.com",
|
||||
"platform": "openai",
|
||||
"account_type": "api",
|
||||
"default_models": ["deepseek-chat", "deepseek-reasoner"],
|
||||
"smoke_test_model": "deepseek-chat",
|
||||
"group_template": {
|
||||
"name": "DeepSeek 默认分组",
|
||||
"rate_multiplier": 1.0
|
||||
},
|
||||
"channel_template": {
|
||||
"name": "DeepSeek 默认渠道",
|
||||
"model_mapping": {
|
||||
"deepseek-chat": "deepseek-chat",
|
||||
"deepseek-reasoner": "deepseek-reasoner"
|
||||
}
|
||||
},
|
||||
"plan_template": {
|
||||
"name": "DeepSeek 默认套餐",
|
||||
"price": 19.9,
|
||||
"validity_days": 30,
|
||||
"validity_unit": "day"
|
||||
},
|
||||
"import": {
|
||||
"supports_multi_key": true,
|
||||
"supports_strict": true,
|
||||
"supports_partial": true
|
||||
}
|
||||
}
|
||||
@@ -66,9 +66,9 @@ func TestLoadAdminTokenFromEnvReturnsErrorWhenMissing(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBootstrapBuildsServerWithStartupConfigOnly(t *testing.T) {
|
||||
func TestBootstrapBuildsServerWithStartupConfigAndAdminToken(t *testing.T) {
|
||||
t.Setenv("SUB2API_CRM_LISTEN_ADDR", ":8181")
|
||||
t.Setenv("SUB2API_CRM_ADMIN_TOKEN", "")
|
||||
t.Setenv("SUB2API_CRM_ADMIN_TOKEN", "admin-token")
|
||||
|
||||
server, err := app.Bootstrap(context.Background())
|
||||
if err != nil {
|
||||
|
||||
64
tests/integration/distribution_smoke_test.go
Normal file
64
tests/integration/distribution_smoke_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package integration_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDistributionArtifactsExistAndReferenceRequiredEnv(t *testing.T) {
|
||||
root := repoRoot(t)
|
||||
for _, path := range []string{
|
||||
filepath.Join(root, "Dockerfile"),
|
||||
filepath.Join(root, ".env.example"),
|
||||
filepath.Join(root, "docker-compose.yml"),
|
||||
filepath.Join(root, "docs", "DEPLOYMENT.md"),
|
||||
} {
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
t.Fatalf("required distribution artifact %q missing: %v", path, err)
|
||||
}
|
||||
}
|
||||
|
||||
dockerfile := mustReadText(t, filepath.Join(root, "Dockerfile"))
|
||||
if !strings.Contains(dockerfile, "SUB2API_CRM_ADMIN_TOKEN") {
|
||||
t.Fatalf("Dockerfile must document runtime dependency on SUB2API_CRM_ADMIN_TOKEN; content=%s", dockerfile)
|
||||
}
|
||||
|
||||
envExample := mustReadText(t, filepath.Join(root, ".env.example"))
|
||||
for _, key := range []string{"SUB2API_CRM_LISTEN_ADDR", "SUB2API_CRM_SQLITE_DSN", "SUB2API_CRM_ADMIN_TOKEN"} {
|
||||
if !strings.Contains(envExample, key+"=") {
|
||||
t.Fatalf(".env.example missing %s; content=%s", key, envExample)
|
||||
}
|
||||
}
|
||||
|
||||
compose := mustReadText(t, filepath.Join(root, "docker-compose.yml"))
|
||||
if !strings.Contains(compose, "/healthz") {
|
||||
t.Fatalf("docker-compose.yml missing healthz probe; content=%s", compose)
|
||||
}
|
||||
|
||||
deployment := mustReadText(t, filepath.Join(root, "docs", "DEPLOYMENT.md"))
|
||||
for _, needle := range []string{"docker compose up --build -d", "SUB2API_CRM_ADMIN_TOKEN", "/healthz"} {
|
||||
if !strings.Contains(deployment, needle) {
|
||||
t.Fatalf("DEPLOYMENT.md missing %q; content=%s", needle, deployment)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func repoRoot(t *testing.T) string {
|
||||
t.Helper()
|
||||
root, err := filepath.Abs(filepath.Join("..", ".."))
|
||||
if err != nil {
|
||||
t.Fatalf("resolve repo root: %v", err)
|
||||
}
|
||||
return root
|
||||
}
|
||||
|
||||
func mustReadText(t *testing.T, path string) string {
|
||||
t.Helper()
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("read %s: %v", path, err)
|
||||
}
|
||||
return string(content)
|
||||
}
|
||||
@@ -193,6 +193,89 @@ func TestSub2APIHostAdapterGetAccountModelsParsesEnvelope(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSub2APIHostAdapterChecksGatewayAccess(t *testing.T) {
|
||||
server := newSub2APIStubServer(t, sub2APIStubConfig{
|
||||
requireAPIKey: true,
|
||||
version: "0.1.126",
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
client, err := sub2api.NewClient(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("NewClient() error = %v", err)
|
||||
}
|
||||
|
||||
result, err := client.CheckGatewayAccess(context.Background(), sub2api.GatewayAccessCheckRequest{
|
||||
APIKey: "user-api-key",
|
||||
ExpectedModel: "deepseek-chat",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CheckGatewayAccess() error = %v", err)
|
||||
}
|
||||
if !result.OK || !result.HasExpectedModel {
|
||||
t.Fatalf("CheckGatewayAccess() = %+v, want ok with expected model", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSub2APIHostAdapterDeletesManagedResources(t *testing.T) {
|
||||
server := newSub2APIStubServer(t, sub2APIStubConfig{
|
||||
requireAPIKey: true,
|
||||
version: "0.1.126",
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
client, err := sub2api.NewClient(server.URL, sub2api.WithAPIKey("api-key"))
|
||||
if err != nil {
|
||||
t.Fatalf("NewClient() error = %v", err)
|
||||
}
|
||||
|
||||
if err := client.DeleteAccount(context.Background(), "account_1"); err != nil {
|
||||
t.Fatalf("DeleteAccount() error = %v", err)
|
||||
}
|
||||
if err := client.DeleteChannel(context.Background(), "channel_1"); err != nil {
|
||||
t.Fatalf("DeleteChannel() error = %v", err)
|
||||
}
|
||||
if err := client.DeleteGroup(context.Background(), "group_1"); err != nil {
|
||||
t.Fatalf("DeleteGroup() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSub2APIHostAdapterListManagedResources(t *testing.T) {
|
||||
server := newSub2APIStubServer(t, sub2APIStubConfig{
|
||||
requireAPIKey: true,
|
||||
version: "0.1.126",
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
client, err := sub2api.NewClient(server.URL, sub2api.WithAPIKey("api-key"))
|
||||
if err != nil {
|
||||
t.Fatalf("NewClient() error = %v", err)
|
||||
}
|
||||
|
||||
snapshot, err := client.ListManagedResources(context.Background(), sub2api.ListManagedResourcesRequest{
|
||||
GroupName: "crm-deepseek-group",
|
||||
ChannelName: "crm-deepseek-channel",
|
||||
PlanName: "crm-deepseek-plan",
|
||||
AccountNamePrefix: "deepseek-",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ListManagedResources() error = %v", err)
|
||||
}
|
||||
|
||||
if len(snapshot.Groups) != 1 || snapshot.Groups[0].ID != "group_1" {
|
||||
t.Fatalf("Groups = %+v, want one group_1 match", snapshot.Groups)
|
||||
}
|
||||
if len(snapshot.Channels) != 2 || snapshot.Channels[0].ID != "channel_1" || snapshot.Channels[1].ID != "channel_2" {
|
||||
t.Fatalf("Channels = %+v, want two matching channels", snapshot.Channels)
|
||||
}
|
||||
if len(snapshot.Plans) != 1 || snapshot.Plans[0].ID != "plan_1" {
|
||||
t.Fatalf("Plans = %+v, want one plan_1 match", snapshot.Plans)
|
||||
}
|
||||
if len(snapshot.Accounts) != 2 || snapshot.Accounts[0].ID != "account_1" || snapshot.Accounts[1].ID != "account_2" {
|
||||
t.Fatalf("Accounts = %+v, want two deepseek account matches", snapshot.Accounts)
|
||||
}
|
||||
}
|
||||
|
||||
type sub2APIStubConfig struct {
|
||||
requireAPIKey bool
|
||||
version string
|
||||
@@ -220,52 +303,106 @@ func newSub2APIStubServer(t *testing.T, cfg sub2APIStubConfig) *httptest.Server
|
||||
if !mustStubAuth(t, w, r, cfg.requireAPIKey) {
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodPost {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
writeJSON(t, w, http.StatusOK, map[string]any{
|
||||
"data": []map[string]any{
|
||||
{"id": "group_1", "name": "crm-deepseek-group"},
|
||||
{"id": "group_2", "name": "other-group"},
|
||||
},
|
||||
})
|
||||
case http.MethodPost:
|
||||
var payload map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&payload)
|
||||
if payload["name"] == "" || payload["rate_multiplier"] == nil {
|
||||
writeJSON(t, w, http.StatusUnprocessableEntity, map[string]any{"error": "validation failed"})
|
||||
return
|
||||
}
|
||||
writeJSON(t, w, http.StatusCreated, map[string]any{
|
||||
"data": map[string]any{
|
||||
"id": "group_1",
|
||||
"name": payload["name"],
|
||||
},
|
||||
})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
mux.HandleFunc("/api/v1/admin/groups/", func(w http.ResponseWriter, r *http.Request) {
|
||||
if !mustStubAuth(t, w, r, cfg.requireAPIKey) {
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodDelete {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
var payload map[string]any
|
||||
_ = json.NewDecoder(r.Body).Decode(&payload)
|
||||
if payload["name"] == "" || payload["rate_multiplier"] == nil {
|
||||
writeJSON(t, w, http.StatusUnprocessableEntity, map[string]any{"error": "validation failed"})
|
||||
return
|
||||
}
|
||||
writeJSON(t, w, http.StatusCreated, map[string]any{
|
||||
"data": map[string]any{
|
||||
"id": "group_1",
|
||||
"name": payload["name"],
|
||||
},
|
||||
})
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
mux.HandleFunc("/api/v1/admin/channels", func(w http.ResponseWriter, r *http.Request) {
|
||||
if !mustStubAuth(t, w, r, cfg.requireAPIKey) {
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodPost {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
writeJSON(t, w, http.StatusOK, map[string]any{
|
||||
"data": []map[string]any{
|
||||
{"id": "channel_1", "name": "crm-deepseek-channel"},
|
||||
{"id": "channel_2", "name": "crm-deepseek-channel"},
|
||||
{"id": "channel_3", "name": "other-channel"},
|
||||
},
|
||||
})
|
||||
case http.MethodPost:
|
||||
writeJSON(t, w, http.StatusUnprocessableEntity, map[string]any{"error": "validation failed"})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
mux.HandleFunc("/api/v1/admin/channels/", func(w http.ResponseWriter, r *http.Request) {
|
||||
if !mustStubAuth(t, w, r, cfg.requireAPIKey) {
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodDelete {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
writeJSON(t, w, http.StatusUnprocessableEntity, map[string]any{"error": "validation failed"})
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
mux.HandleFunc("/api/v1/admin/payment/plans", func(w http.ResponseWriter, r *http.Request) {
|
||||
if !mustStubAuth(t, w, r, cfg.requireAPIKey) {
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodPost {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
writeJSON(t, w, http.StatusOK, map[string]any{
|
||||
"data": []map[string]any{
|
||||
{"id": "plan_1", "name": "crm-deepseek-plan"},
|
||||
{"id": "plan_2", "name": "other-plan"},
|
||||
},
|
||||
})
|
||||
case http.MethodPost:
|
||||
writeJSON(t, w, http.StatusUnprocessableEntity, map[string]any{"error": "validation failed"})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
writeJSON(t, w, http.StatusUnprocessableEntity, map[string]any{"error": "validation failed"})
|
||||
})
|
||||
mux.HandleFunc("/api/v1/admin/accounts", func(w http.ResponseWriter, r *http.Request) {
|
||||
if !mustStubAuth(t, w, r, cfg.requireAPIKey) {
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodPost {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
writeJSON(t, w, http.StatusOK, map[string]any{
|
||||
"data": []map[string]any{
|
||||
{"id": "account_1", "name": "deepseek-01"},
|
||||
{"id": "account_2", "name": "deepseek-02"},
|
||||
{"id": "account_3", "name": "other-01"},
|
||||
},
|
||||
})
|
||||
case http.MethodPost:
|
||||
writeJSON(t, w, http.StatusUnprocessableEntity, map[string]any{"error": "validation failed"})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
writeJSON(t, w, http.StatusUnprocessableEntity, map[string]any{"error": "validation failed"})
|
||||
})
|
||||
mux.HandleFunc("/api/v1/admin/accounts/batch", func(w http.ResponseWriter, r *http.Request) {
|
||||
if !mustStubAuth(t, w, r, cfg.requireAPIKey) {
|
||||
@@ -287,6 +424,14 @@ func newSub2APIStubServer(t *testing.T, cfg sub2APIStubConfig) *httptest.Server
|
||||
return
|
||||
}
|
||||
parts := strings.Split(strings.TrimPrefix(r.URL.Path, "/api/v1/admin/accounts/"), "/")
|
||||
if len(parts) == 1 {
|
||||
if r.Method != http.MethodDelete {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
if len(parts) != 2 {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
@@ -333,6 +478,19 @@ func newSub2APIStubServer(t *testing.T, cfg sub2APIStubConfig) *httptest.Server
|
||||
},
|
||||
})
|
||||
})
|
||||
mux.HandleFunc("/v1/models", func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.Header.Get("x-api-key"); got != "user-api-key" {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"error":"unauthorized"}`))
|
||||
return
|
||||
}
|
||||
writeJSON(t, w, http.StatusOK, map[string]any{
|
||||
"data": []map[string]any{
|
||||
{"id": "deepseek-chat"},
|
||||
{"id": "deepseek-reasoner"},
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
return httptest.NewServer(mux)
|
||||
}
|
||||
|
||||
@@ -108,8 +108,8 @@ func TestStoreInitRecordsMigrationLedgerOnce(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("first sqlite.Open() error = %v", err)
|
||||
}
|
||||
if got := countRows(t, store1.SQLDB(), "schema_migrations"); got != 1 {
|
||||
t.Fatalf("schema_migrations row count after first open = %d, want 1", got)
|
||||
if got := countRows(t, store1.SQLDB(), "schema_migrations"); got != 3 {
|
||||
t.Fatalf("schema_migrations row count after first open = %d, want 3", got)
|
||||
}
|
||||
if err := store1.Close(); err != nil {
|
||||
t.Fatalf("first store.Close() error = %v", err)
|
||||
@@ -121,8 +121,8 @@ func TestStoreInitRecordsMigrationLedgerOnce(t *testing.T) {
|
||||
}
|
||||
defer closeTestStore(t, store2)
|
||||
|
||||
if got := countRows(t, store2.SQLDB(), "schema_migrations"); got != 1 {
|
||||
t.Fatalf("schema_migrations row count after second open = %d, want 1", got)
|
||||
if got := countRows(t, store2.SQLDB(), "schema_migrations"); got != 3 {
|
||||
t.Fatalf("schema_migrations row count after second open = %d, want 3", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,8 +140,8 @@ func TestStoreInitBackfillsLedgerForCompletePreLedgerSchema(t *testing.T) {
|
||||
}
|
||||
defer closeTestStore(t, store)
|
||||
|
||||
if got := countRows(t, store.SQLDB(), "schema_migrations"); got != 1 {
|
||||
t.Fatalf("schema_migrations row count after backfill = %d, want 1", got)
|
||||
if got := countRows(t, store.SQLDB(), "schema_migrations"); got != 3 {
|
||||
t.Fatalf("schema_migrations row count after backfill = %d, want 3", got)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
141
tests/integration/store_runtime_test.go
Normal file
141
tests/integration/store_runtime_test.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package integration_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
func TestStoreRuntimeCreatesOperationalTables(t *testing.T) {
|
||||
store := openTestStore(t)
|
||||
defer closeTestStore(t, store)
|
||||
|
||||
for _, table := range []string{
|
||||
"hosts",
|
||||
"packs",
|
||||
"providers",
|
||||
"import_batches",
|
||||
"import_batch_items",
|
||||
"managed_resources",
|
||||
"probe_results",
|
||||
"access_closure_records",
|
||||
"reconcile_runs",
|
||||
} {
|
||||
if !tableExists(t, store.SQLDB(), table) {
|
||||
t.Fatalf("table %q does not exist after store initialization", table)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreRuntimePersistsOperationalRecords(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
store := openTestStore(t)
|
||||
defer closeTestStore(t, store)
|
||||
|
||||
hostID, err := store.Hosts().Create(ctx, sqlite.Host{
|
||||
HostID: "host-1",
|
||||
BaseURL: "https://sub2api.example.com",
|
||||
HostVersion: "0.1.126",
|
||||
CapabilityProbeJSON: `{"supports_batch_accounts":true}`,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Hosts().Create() error = %v", err)
|
||||
}
|
||||
|
||||
packID, err := store.Packs().Create(ctx, sqlite.Pack{
|
||||
PackID: "openai-cn-pack",
|
||||
Version: "1.0.0",
|
||||
Checksum: "checksum-1",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Packs().Create() error = %v", err)
|
||||
}
|
||||
|
||||
providerID, err := store.Providers().Create(ctx, sqlite.Provider{
|
||||
PackID: packID,
|
||||
ProviderID: "deepseek",
|
||||
DisplayName: "DeepSeek",
|
||||
BaseURL: "https://api.deepseek.com",
|
||||
Platform: "openai",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Providers().Create() error = %v", err)
|
||||
}
|
||||
|
||||
batchID, err := store.ImportBatches().Create(ctx, sqlite.ImportBatch{
|
||||
HostID: hostID,
|
||||
PackID: packID,
|
||||
ProviderID: providerID,
|
||||
Mode: "strict",
|
||||
BatchStatus: "running",
|
||||
AccessStatus: "not_configured",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ImportBatches().Create() error = %v", err)
|
||||
}
|
||||
|
||||
itemID, err := store.ImportBatchItems().Create(ctx, sqlite.ImportBatchItem{
|
||||
BatchID: batchID,
|
||||
KeyFingerprint: "fp-1",
|
||||
AccountStatus: "pending",
|
||||
ProbeSummaryJSON: `{}`,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ImportBatchItems().Create() error = %v", err)
|
||||
}
|
||||
|
||||
if _, err := store.ManagedResources().Create(ctx, sqlite.ManagedResource{
|
||||
BatchID: batchID,
|
||||
ResourceType: "group",
|
||||
HostResourceID: "group-1",
|
||||
ResourceName: "deepseek-group",
|
||||
}); err != nil {
|
||||
t.Fatalf("ManagedResources().Create() error = %v", err)
|
||||
}
|
||||
|
||||
if _, err := store.ProbeResults().Create(ctx, sqlite.ProbeResult{
|
||||
BatchItemID: itemID,
|
||||
ProbeType: "models",
|
||||
Status: "passed",
|
||||
SummaryJSON: `{"models":["deepseek-chat"]}`,
|
||||
}); err != nil {
|
||||
t.Fatalf("ProbeResults().Create() error = %v", err)
|
||||
}
|
||||
|
||||
if _, err := store.AccessClosures().Create(ctx, sqlite.AccessClosureRecord{
|
||||
BatchID: batchID,
|
||||
ClosureType: "subscription",
|
||||
Status: "subscription_ready",
|
||||
DetailsJSON: `{"api_key_bound":true}`,
|
||||
}); err != nil {
|
||||
t.Fatalf("AccessClosures().Create() error = %v", err)
|
||||
}
|
||||
|
||||
if _, err := store.ReconcileRuns().Create(ctx, sqlite.ReconcileRun{
|
||||
ProviderID: providerID,
|
||||
Status: "drifted",
|
||||
SummaryJSON: `{"missing_resources":1}`,
|
||||
}); err != nil {
|
||||
t.Fatalf("ReconcileRuns().Create() error = %v", err)
|
||||
}
|
||||
|
||||
if got := countRows(t, store.SQLDB(), "import_batches"); got != 1 {
|
||||
t.Fatalf("import_batches row count = %d, want 1", got)
|
||||
}
|
||||
if got := countRows(t, store.SQLDB(), "import_batch_items"); got != 1 {
|
||||
t.Fatalf("import_batch_items row count = %d, want 1", got)
|
||||
}
|
||||
if got := countRows(t, store.SQLDB(), "managed_resources"); got != 1 {
|
||||
t.Fatalf("managed_resources row count = %d, want 1", got)
|
||||
}
|
||||
if got := countRows(t, store.SQLDB(), "probe_results"); got != 1 {
|
||||
t.Fatalf("probe_results row count = %d, want 1", got)
|
||||
}
|
||||
if got := countRows(t, store.SQLDB(), "access_closure_records"); got != 1 {
|
||||
t.Fatalf("access_closure_records row count = %d, want 1", got)
|
||||
}
|
||||
if got := countRows(t, store.SQLDB(), "reconcile_runs"); got != 1 {
|
||||
t.Fatalf("reconcile_runs row count = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user