diff --git a/docs/EXECUTION_BOARD.md b/docs/EXECUTION_BOARD.md index aa158f4a..dde4f878 100644 --- a/docs/EXECUTION_BOARD.md +++ b/docs/EXECUTION_BOARD.md @@ -1,7 +1,7 @@ # sub2api-cn-relay-manager 执行板 -日期:2026-05-19 -当前 Gate:CONDITIONAL_APPROVED(代码门禁通过;2026-05-18 fresh redeploy 验证确认 self_service / subscription 访问链路可打通;2026-05-19 current-code remote43 追踪后发现 DeepSeek/MiniMax 的 channel 创建请求漏传 model_mapping / restrict_models / billing_model_source,已补代码与测试,但真实宿主 access gate 仍需重新验收) +日期:2026-05-20 +当前 Gate:BLOCKED(代码门禁仍通过,但 2026-05-20 current-code CRM(18092) + remote43 fresh host(18097) 真实宿主复验失败:DeepSeek batch=22、MiniMax batch=23 均仅到 `partially_succeeded/access_status=broken`;宿主普通用户 `/v1/models` 仍暴露 GPT 系默认模型,gateway closure 未通过,不能宣称可上线) 目标:实现独立控制面、零侵入宿主、可导入国产模型并具备可运维的导入/回滚/访问闭环。 ## 本轮已完成 @@ -36,7 +36,11 @@ 9. current-code remote43 导入链路已补齐 tunnel-aware 验证能力 - `scripts/import_remote43_provider.sh` 新增 `CRM_HOST_BASE`,允许把“operator 访问 host 地址”和“CRM 进程访问 host 地址”分离 - latest artifact:`/home/long/artifacts/real-host-acceptance/20260519_195827_remote43_deepseek_key_import` - - 结论:import / batch detail / managed resources 已真实落库;本轮定位到 channel 创建缺少 model_mapping / restrict_models / billing_model_source,已补齐实现与测试,待重新跑真实宿主验收 + - 结论:import / batch detail / managed resources 已真实落库;前一轮定位到 channel 创建缺少 model_mapping / restrict_models / billing_model_source,已补齐实现与测试 +10. current-code remote43 access gate 根因修正已落地 + - subscription access 改为宿主侧闭环:CRM 不再依赖外部预先给定的宿主普通用户 key,而是按 `subscription_users` selector 在宿主创建/查找托管普通用户、登录创建托管 key、回写 allowed_groups / balance、再执行订阅分配 + - account 创建请求现在同步写入 `credentials.model_mapping`,修正 `/v1/models` 读取 account model whitelist 时回退到 GPT 默认集合的问题 + - 新增/更新测试覆盖:`internal/access`、`internal/provision`、`internal/host/sub2api` ## 已验证门禁 @@ -76,35 +80,44 @@ - `latest_access_status=broken` - `access preview available=false` - `reconcile status=drifted`,其中 `probe_failures=1` -4. 根因归类 - - `09-models.headers.txt` / `10-models.body.json` 显示普通用户实际看到的是 GPT 系模型,而非预期的 `deepseek-v4-pro` - - 因此本轮 FAIL 应归类为“上游 key/模型能力不匹配或普通用户绑定命中了错误 group”,不是 current-code CRM bootstrap / import 主链路故障。 +4. 当前修正 + - 旧 artifact 中 `09-models.headers.txt` / `10-models.body.json` 暴露 GPT 系模型,根因已重新归类为:CRM 写了 channel model_mapping,但 account `credentials.model_mapping` 未同步,导致宿主 `/v1/models` 从 account 视图回退到默认模型集 + - 同时,旧脚本/调用路径把外部 `subscription_users` / `access_api_key` 直接当宿主用户和宿主 key 使用,无法形成“宿主普通用户创建/查找 + key + 订阅分配”的真正闭环;该问题现已改为宿主托管闭环 + - 代码侧阻断点已修复;下一步只剩 DeepSeek / MiniMax 真实 key 复验 ## 剩余项(含当前外部门禁) -1. DeepSeek / MiniMax real-host access gate 仍需复验(外部门禁) - - 真实宿主曾出现普通用户 `/v1/models` 暴露 GPT 系模型的漂移;本轮已补齐 channel 侧 model_mapping / restrict_models / billing_model_source 传参 - - 53hk 中转 key 当前未验证可用,不能当作主结论 - - 在 current-code remote43 路径上,这一项仍需重新跑真实验收 -2. 结构债务仍存在 +1. current-code real-host access gate 失败,需先修复再谈上线 + - DeepSeek:artifact `artifacts/real-host-acceptance/20260520_123726_remote43_deepseek_key_import/03-import.body.json` 显示 `batch_id=22`、`batch_status=partially_succeeded`、`access_status=broken` + - MiniMax:current-code CRM(18092) 对 remote43 fresh host(18097) 手工复验得到 `batch_id=23`、`batch_status=partially_succeeded`、`access_status=broken` + - 两条链路的 `probe_summary_json` / gateway probe 都显示宿主普通用户 `/v1/models` 返回 GPT-5.x / GPT Image 默认集合,未暴露 DeepSeek / MiniMax 目标模型 + - 2026-05-20 复核补充:fresh host 上 `groups/channels/account_groups` 已按期望落库,channel 也已具备 `model_mapping + restrict_models + billing_model_source=channel_mapped`;但 `accounts.credentials` 真实仅持久化 `api_key/base_url`,`GET /api/v1/admin/accounts/{id}/models` 仍返回 GPT 默认模型集,`POST /api/v1/admin/accounts/{id}/test` 也会默认拿 `gpt-5.4` 探测并报 `model_not_found`。当前根因已重新归类为“宿主 account 模型暴露契约仍未被 current-code 对齐”,不能再把问题简化成 `channel` 参数缺失或“只差同步 `credentials.model_mapping`”。 + - pack contract 漂移已发现并修复:`packs/openai-cn-pack/providers/deepseek.json` 之前出现 `default_models/smoke_test_model` 与 `channel_template.model_mapping` 不一致;`internal/pack` 现已新增校验,要求 `smoke_test_model` 必须出现在 `channel_template.model_mapping`,且 `default_models` 必须被 `channel_template.model_mapping` 全量覆盖,避免类似漂移再次混入真实宿主验收。 + - 2026-05-20 21:50 补充:已修复 current-code `channel` 创建/纠偏时 `model_pricing` 丢失的问题。CRM `http://127.0.0.1:18100` 对 `remote43-fresh18097-deepseek-1779280533` 复跑 `POST /api/providers/deepseek/import` 返回 `batch_id=4`、`access_status=subscription_ready`;宿主 `GET /api/v1/admin/channels/4` 已可见 `model_pricing=[{platform:"openai", models:["deepseek-v4-pro","deepseek-v4-flash"], billing_mode:"token", intervals:[]}]`,说明“已存在 channel 可 PUT 纠偏”已生效。当前 remaining gate 不再是 channel pricing 缺失,而是更高层的 provider/account 行为问题。 +2. 真实宿主脚本存在环境绑定缺陷 + - `scripts/import_remote43_provider.sh` 仍把 Postgres/Redis 容器名硬编码到 `sub2api-relaymgr-pg` / `sub2api-relaymgr-redis` + - 当目标切到 fresh host(18097) 时,脚本会把 subscription user/key prep 误打到旧 relaymgr 宿主,导致 user id 错宿主、出现 `assign subscription for 10 ... 500` +3. 结构债务仍存在 - access / reconcile 尚未完全按 implementation plan 物理拆分 - 无内置 scheduler/jobs -3. 运营前置动作需要 runbook 化执行 - - 真实宿主初始化不会自动创建普通用户;验收或上线前必须显式创建普通用户并留存可复用凭据 +4. 运营前置动作需要 runbook 化执行 + - 真实宿主初始化不会自动创建普通用户;当前 CRM subscription 闭环声称可按 selector 自动托管宿主普通用户/key,但本轮 remote43 真实宿主复验未通过,不能把该能力当作已验收事实 - `self_service` 需要普通用户 key 绑定目标标准 group,且通常还需要可用余额 - `subscription` 需要 subscription 类型 group + 普通用户订阅分配 + key/group 绑定 -4. 标准多阶段 Dockerfile 在受限网络环境下仍不稳 +5. 标准多阶段 Dockerfile 在受限网络环境下仍不稳 - 当前推荐 `scripts/build_local_image.sh` + `Dockerfile.local` -5. 真实宿主验收工具已补自动化闭环 - - `scripts/real_host_acceptance.sh` 支持 `AFTER_IMPORT_HOOK_COMMAND`,可把宿主侧 access 前置动作收敛进同一条 artifact 链 - - `scripts/import_remote43_provider.sh` 已内置 remote43 subscription 的“补余额 + key/group 绑定 + subscription upsert + 定向 Redis 缓存失效 + host state 落盘” +6. 真实宿主验收工具需补 host 级参数化 + - `scripts/real_host_acceptance.sh` 的 `AFTER_IMPORT_HOOK_COMMAND` 仍有价值,但 remote43/fresh-host 变体还缺“目标 Postgres/Redis 容器名、目标 host env 文件、目标 forward 端口”的显式参数化 + - 否则 artifact 会混入旧宿主状态,误导 gate 判断 ## 当前最短上线路径 -1. 按 `docs/REAL_HOST_ACCEPTANCE_RUNBOOK.md` 准备真实宿主普通用户与凭据 -2. 按目标模式完成必要的 key/group/billing(or subscription) 绑定 -3. 使用 `scripts/build_local_image.sh` 与 `scripts/real_host_acceptance.sh` 复跑并归档现场 artifact -4. 若现场前置满足,本项目按 PRD 首版范围可直接上线 +1. 先修 current-code 在真实宿主上的两个阻断点: + - 查清并修复为什么宿主 `accounts.credentials` 未持久化 `model_mapping` + - 给 remote43 验收脚本补目标 host 级参数化,避免 Postgres/Redis/host env 误指向旧 relaymgr +2. 用 fresh host 重新跑 DeepSeek / MiniMax subscription 验收,要求 `/v1/models` 暴露目标模型且 `/v1/chat/completions` 返回 200 +3. 复跑 `provider status` / `access status` / `access preview` / `batch detail`,确认 `batch_status=succeeded`、`access_status=ready` +4. 若现场前置满足,再重新评估是否恢复 CONDITIONAL_APPROVED / APPROVED ## 禁止错误结论 diff --git a/internal/access/closure.go b/internal/access/closure.go index 3e41ef9f..982c99cd 100644 --- a/internal/access/closure.go +++ b/internal/access/closure.go @@ -27,6 +27,7 @@ type ClosureRequest struct { } type Host interface { + EnsureSubscriptionAccess(ctx context.Context, req sub2api.EnsureSubscriptionAccessRequest) (sub2api.SubscriptionAccessRef, error) AssignSubscription(ctx context.Context, req sub2api.AssignSubscriptionRequest) (sub2api.SubscriptionRef, error) CheckGatewayAccess(ctx context.Context, req sub2api.GatewayAccessCheckRequest) (sub2api.GatewayAccessResult, error) } @@ -52,7 +53,7 @@ func Validate(req ClosureRequest) error { default: return fmt.Errorf("unsupported access mode %q", req.Mode) } - if strings.TrimSpace(req.ProbeAPIKey) == "" { + if strings.TrimSpace(req.Mode) != ModeSubscription && strings.TrimSpace(req.ProbeAPIKey) == "" { return fmt.Errorf("access probe api key is required to verify gateway closure") } return nil @@ -65,14 +66,29 @@ func (s *Service) Close(ctx context.Context, req ClosureRequest) (sub2api.Gatewa if err := Validate(req); err != nil { return sub2api.GatewayAccessResult{}, err } + probeAPIKey := strings.TrimSpace(req.ProbeAPIKey) if strings.TrimSpace(req.Mode) == ModeSubscription { for _, target := range req.Subscriptions { - if _, err := s.host.AssignSubscription(ctx, sub2api.AssignSubscriptionRequest{UserID: target.UserID, GroupID: req.GroupID, DurationDays: target.DurationDays}); err != nil { + resolvedTarget := target.UserID + accessRef, err := s.host.EnsureSubscriptionAccess(ctx, sub2api.EnsureSubscriptionAccessRequest{UserSelector: target.UserID, GroupID: req.GroupID}) + if err != nil { + return sub2api.GatewayAccessResult{}, fmt.Errorf("ensure subscription access for %s: %w", target.UserID, err) + } + if strings.TrimSpace(accessRef.UserID) != "" { + resolvedTarget = accessRef.UserID + } + if strings.TrimSpace(accessRef.APIKey) != "" { + probeAPIKey = strings.TrimSpace(accessRef.APIKey) + } + if _, err := s.host.AssignSubscription(ctx, sub2api.AssignSubscriptionRequest{UserID: resolvedTarget, GroupID: req.GroupID, DurationDays: target.DurationDays}); err != nil { return sub2api.GatewayAccessResult{}, fmt.Errorf("assign subscription for %s: %w", target.UserID, err) } } } - result, err := s.host.CheckGatewayAccess(ctx, sub2api.GatewayAccessCheckRequest{APIKey: req.ProbeAPIKey, ExpectedModel: req.ExpectedModel}) + if probeAPIKey == "" { + return sub2api.GatewayAccessResult{}, fmt.Errorf("access probe api key is required to verify gateway closure") + } + result, err := s.host.CheckGatewayAccess(ctx, sub2api.GatewayAccessCheckRequest{APIKey: probeAPIKey, ExpectedModel: req.ExpectedModel}) if err != nil { return sub2api.GatewayAccessResult{}, fmt.Errorf("check gateway access: %w", err) } diff --git a/internal/access/closure_test.go b/internal/access/closure_test.go index 6be1fb42..8585f38c 100644 --- a/internal/access/closure_test.go +++ b/internal/access/closure_test.go @@ -22,14 +22,28 @@ func TestValidateRejectsMissingSubscriptionsForSubscriptionMode(t *testing.T) { } } +func TestValidateAllowsManagedSubscriptionProbeWithoutExplicitAPIKey(t *testing.T) { + err := Validate(ClosureRequest{ + Mode: "subscription", + GroupID: "group-1", + ExpectedModel: "deepseek-chat", + Subscriptions: []SubscriptionTarget{{UserID: "crm-user-42", DurationDays: 30}}, + }) + if err != nil { + t.Fatalf("Validate() error = %v, want nil for managed subscription probe", err) + } +} + func TestServiceCloseAssignsSubscriptionsAndProbesGateway(t *testing.T) { host := &fakeClosureHost{ gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}}, + managedAccess: map[string]sub2api.SubscriptionAccessRef{ + "user-1": {UserID: "host-user-1", APIKey: "managed-user-key"}, + }, } service := NewService(host) result, err := service.Close(context.Background(), ClosureRequest{ Mode: "subscription", - ProbeAPIKey: "user-key", GroupID: "group-1", ExpectedModel: "deepseek-chat", Subscriptions: []SubscriptionTarget{{UserID: "user-1", DurationDays: 30}}, @@ -40,7 +54,10 @@ func TestServiceCloseAssignsSubscriptionsAndProbesGateway(t *testing.T) { if len(host.assigned) != 1 { t.Fatalf("assigned subscriptions = %d, want 1", len(host.assigned)) } - if host.gatewayProbe.APIKey != "user-key" || host.gatewayProbe.ExpectedModel != "deepseek-chat" { + if host.assigned[0].UserID != "host-user-1" { + t.Fatalf("assigned subscription user = %q, want host-user-1", host.assigned[0].UserID) + } + if host.gatewayProbe.APIKey != "managed-user-key" || host.gatewayProbe.ExpectedModel != "deepseek-chat" { t.Fatalf("gateway probe = %+v, want api key + expected model", host.gatewayProbe) } if !result.OK || !result.HasExpectedModel { @@ -68,12 +85,20 @@ func TestServiceCloseReturnsSubscriptionErrorBeforeGatewayProbe(t *testing.T) { type fakeClosureHost struct { assigned []sub2api.AssignSubscriptionRequest + managedAccess map[string]sub2api.SubscriptionAccessRef assignErr error gatewayProbe sub2api.GatewayAccessCheckRequest gatewayResult sub2api.GatewayAccessResult gatewayErr error } +func (f *fakeClosureHost) EnsureSubscriptionAccess(_ context.Context, req sub2api.EnsureSubscriptionAccessRequest) (sub2api.SubscriptionAccessRef, error) { + if ref, ok := f.managedAccess[req.UserSelector]; ok { + return ref, nil + } + return sub2api.SubscriptionAccessRef{}, errors.New("missing managed access") +} + func (f *fakeClosureHost) AssignSubscription(_ context.Context, req sub2api.AssignSubscriptionRequest) (sub2api.SubscriptionRef, error) { if f.assignErr != nil { return sub2api.SubscriptionRef{}, f.assignErr diff --git a/internal/app/app_test.go b/internal/app/app_test.go index 639da5bc..7d5c78fe 100644 --- a/internal/app/app_test.go +++ b/internal/app/app_test.go @@ -5,10 +5,12 @@ import ( "context" "encoding/json" "errors" + "fmt" "io" "net" "net/http" "net/http/httptest" + "path/filepath" "strings" "testing" "time" @@ -874,6 +876,47 @@ func TestHandlerErrorPaths(t *testing.T) { } } +func TestResolveLatestAccessStatusAggregatesAcrossModeBatches(t *testing.T) { + store := openAppTestStore(t) + defer closeAppTestStore(t, store) + + ctx := context.Background() + hostID, err := store.Hosts().Create(ctx, sqlite.Host{HostID: "host-1", BaseURL: "https://sub2api.example.com", HostVersion: "0.1.126", AuthType: "apikey", AuthToken: "token"}) + if err != nil { + t.Fatalf("Hosts().Create() error = %v", err) + } + packID, err := store.Packs().Create(ctx, sqlite.Pack{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api", Checksum: "checksum-1"}) + if err != nil { + t.Fatalf("Packs().Create() error = %v", err) + } + providerID, err := store.Providers().Create(ctx, sqlite.Provider{PackID: packID, ProviderID: "deepseek", DisplayName: "DeepSeek", BaseURL: "https://api.deepseek.com", Platform: "openai"}) + if err != nil { + t.Fatalf("Providers().Create() error = %v", err) + } + batchSubscription, err := store.ImportBatches().Create(ctx, sqlite.ImportBatch{HostID: hostID, PackID: packID, ProviderID: providerID, Mode: provision.ImportModePartial, BatchStatus: provision.BatchStatusSucceeded, AccessStatus: provision.AccessStatusSubscriptionReady}) + if err != nil { + t.Fatalf("ImportBatches().Create(subscription) error = %v", err) + } + if _, err := store.AccessClosures().Create(ctx, sqlite.AccessClosureRecord{BatchID: batchSubscription, ClosureType: provision.AccessModeSubscription, Status: provision.AccessStatusSubscriptionReady, DetailsJSON: "{}"}); err != nil { + t.Fatalf("AccessClosures().Create(subscription) error = %v", err) + } + batchSelf, err := store.ImportBatches().Create(ctx, sqlite.ImportBatch{HostID: hostID, PackID: packID, ProviderID: providerID, Mode: provision.ImportModePartial, BatchStatus: provision.BatchStatusSucceeded, AccessStatus: provision.AccessStatusSelfServiceReady}) + if err != nil { + t.Fatalf("ImportBatches().Create(self_service) error = %v", err) + } + if _, err := store.AccessClosures().Create(ctx, sqlite.AccessClosureRecord{BatchID: batchSelf, ClosureType: provision.AccessModeSelfService, Status: provision.AccessStatusSelfServiceReady, DetailsJSON: "{}"}); err != nil { + t.Fatalf("AccessClosures().Create(self_service) error = %v", err) + } + + got, err := resolveLatestAccessStatus(ctx, store, sqlite.Provider{ID: providerID, ProviderID: "deepseek"}, "host-1") + if err != nil { + t.Fatalf("resolveLatestAccessStatus() error = %v", err) + } + if got != provision.AccessStatusFullyReady { + t.Fatalf("resolveLatestAccessStatus() = %q, want %q", got, provision.AccessStatusFullyReady) + } +} + func TestProviderAccessStatusMultipleClosures(t *testing.T) { handler := NewAPIHandler("t", ActionSet{ GetProviderAccessStatus: func(context.Context, ProviderQueryRequest) (provision.ProviderSnapshot, error) { @@ -926,6 +969,24 @@ func TestHostSupportStatusRequiresPlansCapability(t *testing.T) { } } +func openAppTestStore(t *testing.T) *sqlite.DB { + t.Helper() + dbPath := filepath.Join(t.TempDir(), "state.db") + dsn := fmt.Sprintf("file:%s?_busy_timeout=5000&_pragma=foreign_keys(0)", filepath.ToSlash(dbPath)) + store, err := sqlite.Open(context.Background(), dsn) + if err != nil { + t.Fatalf("sqlite.Open() error = %v", err) + } + return store +} + +func closeAppTestStore(t *testing.T, store *sqlite.DB) { + t.Helper() + if err := store.Close(); err != nil { + t.Fatalf("store.Close() error = %v", err) + } +} + func assertJSONContains(t *testing.T, payload []byte, key string, want any) { t.Helper() var decoded map[string]any diff --git a/internal/app/http_api.go b/internal/app/http_api.go index 84f1c083..27c41606 100644 --- a/internal/app/http_api.go +++ b/internal/app/http_api.go @@ -1367,37 +1367,10 @@ func NewActionSet(sqliteDSN string) ActionSet { return AccessPreviewResult{}, fmt.Errorf("provider %q exists in multiple packs; pack_id is required", req.ProviderID) } providerRow := providers[0] - if strings.TrimSpace(req.HostID) != "" { - hostRow, err := store.Hosts().GetByHostID(ctx, req.HostID) - if err != nil { - return AccessPreviewResult{}, err - } - batch, err := store.ImportBatches().GetLatestByProviderIDAndHostID(ctx, providerRow.ID, hostRow.ID) - if err != nil { - return AccessPreviewResult{}, fmt.Errorf("find batch for provider: %w", err) - } - latestStatus := batch.AccessStatus - closures, err := store.AccessClosures().GetByBatchID(ctx, batch.ID) - if err == nil && len(closures) > 0 { - latestStatus = closures[len(closures)-1].Status - } - available := accessStatusSupportsMode(latestStatus, req.Mode) - message := fmt.Sprintf("latest access status: %s", latestStatus) - if !available { - message = fmt.Sprintf("access status %s does not satisfy mode %s", latestStatus, req.Mode) - } - return AccessPreviewResult{ProviderID: req.ProviderID, Mode: req.Mode, Available: available, Message: message}, nil - } - batch, err := store.ImportBatches().GetLatestByProviderID(ctx, providerRow.ID) + latestStatus, err := resolveLatestAccessStatus(ctx, store, providerRow, req.HostID) if err != nil { return AccessPreviewResult{}, fmt.Errorf("find batch for provider: %w", err) } - - latestStatus := batch.AccessStatus - closures, err := store.AccessClosures().GetByBatchID(ctx, batch.ID) - if err == nil && len(closures) > 0 { - latestStatus = closures[len(closures)-1].Status - } available := accessStatusSupportsMode(latestStatus, req.Mode) message := fmt.Sprintf("latest access status: %s", latestStatus) if !available { @@ -1440,6 +1413,45 @@ func resolveProvidersForQuery(ctx context.Context, store *sqlite.DB, req Provide return store.Providers().ListByProviderID(ctx, providerID) } +func resolveLatestAccessStatus(ctx context.Context, store *sqlite.DB, providerRow sqlite.Provider, hostID string) (string, error) { + if store == nil { + return "", fmt.Errorf("store is required") + } + if strings.TrimSpace(hostID) != "" { + hostRow, err := store.Hosts().GetByHostID(ctx, hostID) + if err != nil { + return "", err + } + batches, err := store.ImportBatches().ListByProviderIDAndHostID(ctx, providerRow.ID, hostRow.ID) + if err != nil { + return "", err + } + modeStatuses, err := provision.LatestModeAccessStatuses(ctx, store, batches) + if err != nil { + return "", err + } + return provision.AggregateAccessStatus(modeStatuses), nil + } + batches, err := store.ImportBatches().ListByProviderID(ctx, providerRow.ID) + if err != nil { + return "", err + } + if len(batches) == 0 { + return "", fmt.Errorf("latest import batch not found for provider") + } + hostIDValue := batches[0].HostID + for _, batch := range batches[1:] { + if batch.HostID != hostIDValue { + return "", fmt.Errorf("provider exists on multiple hosts; host_id is required") + } + } + modeStatuses, err := provision.LatestModeAccessStatuses(ctx, store, batches) + if err != nil { + return "", err + } + return provision.AggregateAccessStatus(modeStatuses), nil +} + func resolveManagedHost(ctx context.Context, store *sqlite.DB, hostID, baseURL string, auth CreateHostAuth) (sqlite.Host, *sub2api.Client, error) { if store == nil { return sqlite.Host{}, nil, fmt.Errorf("store is required") diff --git a/internal/host/sub2api/channels.go b/internal/host/sub2api/channels.go index ca382a22..c7ac664c 100644 --- a/internal/host/sub2api/channels.go +++ b/internal/host/sub2api/channels.go @@ -1,6 +1,10 @@ package sub2api -import "context" +import ( + "context" + "fmt" + "net/http" +) func (c *Client) CreateChannel(ctx context.Context, req CreateChannelRequest) (ChannelRef, error) { var ref ChannelRef @@ -9,3 +13,15 @@ func (c *Client) CreateChannel(ctx context.Context, req CreateChannelRequest) (C } return ref, nil } + +func (c *Client) UpdateChannel(ctx context.Context, channelID string, req CreateChannelRequest) error { + path := fmt.Sprintf("/api/v1/admin/channels/%s", channelID) + statusCode, _, body, err := c.perform(ctx, http.MethodPut, path, req) + if err != nil { + return err + } + if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices { + return newHTTPError(http.MethodPut, path, statusCode, body) + } + return nil +} diff --git a/internal/host/sub2api/client.go b/internal/host/sub2api/client.go index b3752b02..28b8ee9b 100644 --- a/internal/host/sub2api/client.go +++ b/internal/host/sub2api/client.go @@ -18,6 +18,7 @@ type HostAdapter interface { CreateGroup(ctx context.Context, req CreateGroupRequest) (GroupRef, error) DeleteGroup(ctx context.Context, groupID string) error CreateChannel(ctx context.Context, req CreateChannelRequest) (ChannelRef, error) + UpdateChannel(ctx context.Context, channelID string, req CreateChannelRequest) error DeleteChannel(ctx context.Context, channelID string) error CreatePlan(ctx context.Context, req CreatePlanRequest) (PlanRef, error) DeletePlan(ctx context.Context, planID string) error @@ -26,6 +27,7 @@ type HostAdapter interface { DeleteAccount(ctx context.Context, accountID string) error TestAccount(ctx context.Context, accountID string) (ProbeResult, error) GetAccountModels(ctx context.Context, accountID string) ([]AccountModel, error) + EnsureSubscriptionAccess(ctx context.Context, req EnsureSubscriptionAccessRequest) (SubscriptionAccessRef, error) AssignSubscription(ctx context.Context, req AssignSubscriptionRequest) (SubscriptionRef, error) CheckGatewayAccess(ctx context.Context, req GatewayAccessCheckRequest) (GatewayAccessResult, error) ListManagedResources(ctx context.Context, req ListManagedResourcesRequest) (ManagedResourceSnapshot, error) @@ -54,11 +56,38 @@ type GroupRef struct { } type CreateChannelRequest struct { - Name string `json:"name"` - GroupIDs []string `json:"group_ids"` - ModelMapping map[string]string `json:"model_mapping,omitempty"` - RestrictModels bool `json:"restrict_models,omitempty"` - BillingModelSource string `json:"billing_model_source,omitempty"` + Name string `json:"name"` + GroupIDs []string `json:"group_ids"` + ModelMapping map[string]string `json:"model_mapping,omitempty"` + ModelPricing []ChannelModelPricing `json:"model_pricing,omitempty"` + Platform string `json:"-"` + RestrictModels bool `json:"restrict_models,omitempty"` + BillingModelSource string `json:"billing_model_source,omitempty"` +} + +type ChannelModelPricing struct { + Platform string `json:"platform,omitempty"` + Models []string `json:"models,omitempty"` + BillingMode string `json:"billing_mode,omitempty"` + InputPrice *float64 `json:"input_price,omitempty"` + OutputPrice *float64 `json:"output_price,omitempty"` + CacheWritePrice *float64 `json:"cache_write_price,omitempty"` + CacheReadPrice *float64 `json:"cache_read_price,omitempty"` + ImageOutputPrice *float64 `json:"image_output_price,omitempty"` + PerRequestPrice *float64 `json:"per_request_price,omitempty"` + Intervals []ChannelPricingTier `json:"intervals,omitempty"` +} + +type ChannelPricingTier struct { + MinTokens int `json:"min_tokens,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + TierLabel string `json:"tier_label,omitempty"` + InputPrice *float64 `json:"input_price,omitempty"` + OutputPrice *float64 `json:"output_price,omitempty"` + CacheWritePrice *float64 `json:"cache_write_price,omitempty"` + CacheReadPrice *float64 `json:"cache_read_price,omitempty"` + PerRequestPrice *float64 `json:"per_request_price,omitempty"` + SortOrder int `json:"sort_order,omitempty"` } type ChannelRef struct { @@ -116,6 +145,16 @@ type AssignSubscriptionRequest struct { DurationDays int `json:"validity_days,omitempty"` } +type EnsureSubscriptionAccessRequest struct { + UserSelector string + GroupID string +} + +type SubscriptionAccessRef struct { + UserID string + APIKey string +} + type SubscriptionRef struct { ID string `json:"id"` } diff --git a/internal/host/sub2api/flexible_id.go b/internal/host/sub2api/flexible_id.go index ca2f5b8f..3633c196 100644 --- a/internal/host/sub2api/flexible_id.go +++ b/internal/host/sub2api/flexible_id.go @@ -48,12 +48,41 @@ func flexibleIDSliceValues(raw []string) []any { } func (r CreateChannelRequest) MarshalJSON() ([]byte, error) { + modelMapping := map[string]map[string]string{} + platform := strings.TrimSpace(r.Platform) + if platform == "" { + platform = "openai" + } + if len(r.ModelMapping) > 0 { + inner := make(map[string]string, len(r.ModelMapping)) + for key, value := range r.ModelMapping { + inner[key] = value + } + modelMapping[platform] = inner + } + modelPricing := make([]ChannelModelPricing, 0, len(r.ModelPricing)) + for _, entry := range r.ModelPricing { + pricing := entry + if strings.TrimSpace(pricing.Platform) == "" { + pricing.Platform = platform + } + modelPricing = append(modelPricing, pricing) + } + return json.Marshal(struct { - Name string `json:"name"` - GroupIDs []any `json:"group_ids"` + Name string `json:"name"` + GroupIDs []any `json:"group_ids"` + ModelMapping map[string]map[string]string `json:"model_mapping,omitempty"` + ModelPricing []ChannelModelPricing `json:"model_pricing,omitempty"` + RestrictModels bool `json:"restrict_models,omitempty"` + BillingModelSource string `json:"billing_model_source,omitempty"` }{ - Name: r.Name, - GroupIDs: flexibleIDSliceValues(r.GroupIDs), + Name: r.Name, + GroupIDs: flexibleIDSliceValues(r.GroupIDs), + ModelMapping: modelMapping, + ModelPricing: modelPricing, + RestrictModels: r.RestrictModels, + BillingModelSource: r.BillingModelSource, }) } diff --git a/internal/host/sub2api/sub2api_test.go b/internal/host/sub2api/sub2api_test.go index 84660d4b..f2e55ced 100644 --- a/internal/host/sub2api/sub2api_test.go +++ b/internal/host/sub2api/sub2api_test.go @@ -6,6 +6,7 @@ import ( "errors" "net/http" "net/http/httptest" + "strings" "testing" ) @@ -591,8 +592,23 @@ func TestCreateGroupWithMock(t *testing.T) { func TestCreateChannelWithMock(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var req struct { - Name string `json:"name"` - GroupIDs []int64 `json:"group_ids"` + Name string `json:"name"` + GroupIDs []int64 `json:"group_ids"` + ModelMapping map[string]map[string]string `json:"model_mapping"` + ModelPricing []struct { + Platform string `json:"platform"` + Models []string `json:"models"` + BillingMode string `json:"billing_mode"` + InputPrice *float64 `json:"input_price"` + OutputPrice *float64 `json:"output_price"` + CacheWritePrice *float64 `json:"cache_write_price"` + CacheReadPrice *float64 `json:"cache_read_price"` + ImageOutputPrice *float64 `json:"image_output_price"` + PerRequestPrice *float64 `json:"per_request_price"` + Intervals []any `json:"intervals"` + } `json:"model_pricing"` + RestrictModels bool `json:"restrict_models"` + BillingModelSource string `json:"billing_model_source"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { t.Fatalf("decode request: %v", err) @@ -603,11 +619,36 @@ func TestCreateChannelWithMock(t *testing.T) { if len(req.GroupIDs) != 1 || req.GroupIDs[0] != 101 { t.Fatalf("group_ids = %v, want [101]", req.GroupIDs) } + if req.ModelMapping["openai"]["deepseek-v4-pro"] != "deepseek-v4-pro" { + t.Fatalf("model_mapping = %+v, want openai/deepseek-v4-pro passthrough", req.ModelMapping) + } + if len(req.ModelPricing) != 1 { + t.Fatalf("model_pricing len = %d, want 1", len(req.ModelPricing)) + } + if req.ModelPricing[0].Platform != "openai" || req.ModelPricing[0].BillingMode != "token" { + t.Fatalf("model_pricing[0] = %+v, want openai/token entry", req.ModelPricing[0]) + } + if len(req.ModelPricing[0].Models) != 1 || req.ModelPricing[0].Models[0] != "deepseek-v4-pro" { + t.Fatalf("model_pricing[0].models = %v, want [deepseek-v4-pro]", req.ModelPricing[0].Models) + } + if !req.RestrictModels { + t.Fatal("restrict_models = false, want true") + } + if req.BillingModelSource != "channel_mapped" { + t.Fatalf("billing_model_source = %q, want channel_mapped", req.BillingModelSource) + } w.Write([]byte(`{"data":{"id":201,"name":"ch"}}`)) })) defer srv.Close() client, _ := NewClient(srv.URL, WithAPIKey("k")) - ref, err := client.CreateChannel(context.Background(), CreateChannelRequest{Name: "ch", GroupIDs: []string{"101"}}) + ref, err := client.CreateChannel(context.Background(), CreateChannelRequest{ + Name: "ch", + GroupIDs: []string{"101"}, + ModelMapping: map[string]string{"deepseek-v4-pro": "deepseek-v4-pro"}, + ModelPricing: []ChannelModelPricing{{Platform: "openai", Models: []string{"deepseek-v4-pro"}, BillingMode: "token"}}, + RestrictModels: true, + BillingModelSource: "channel_mapped", + }) if err != nil { t.Fatal(err) } @@ -699,6 +740,66 @@ func TestAssignSubscriptionWithMock(t *testing.T) { } } +func TestEnsureSubscriptionAccessWithMock(t *testing.T) { + var calls []string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls = append(calls, r.Method+" "+r.URL.Path) + switch { + case r.Method == http.MethodGet && strings.HasPrefix(r.URL.RequestURI(), "/api/v1/admin/users?"): + w.Write([]byte(`{"data":{"items":[]}}`)) + case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/users": + w.Write([]byte(`{"data":{"id":84,"email":"relay-sub-user-1@sub2api.local"}}`)) + case r.Method == http.MethodPut && r.URL.Path == "/api/v1/admin/users/84": + w.Write([]byte(`{"data":{"id":84}}`)) + case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/users/84/balance": + w.Write([]byte(`{"data":{"id":84}}`)) + case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/subscriptions/assign": + var req struct { + UserID int64 `json:"user_id"` + GroupID int64 `json:"group_id"` + DurationDays int `json:"validity_days"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("decode assign subscription request: %v", err) + } + if req.UserID != 84 || req.GroupID != 101 || req.DurationDays != 30 { + t.Fatalf("unexpected assign subscription request: %+v", req) + } + w.Write([]byte(`{"data":{"id":401}}`)) + case r.Method == http.MethodPost && r.URL.Path == "/api/v1/auth/login": + w.Write([]byte(`{"data":{"access_token":"user-jwt"}}`)) + case r.Method == http.MethodPost && r.URL.Path == "/api/v1/keys": + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("decode managed key request: %v", err) + } + if _, ok := req["group_id"]; ok { + t.Fatalf("managed key request unexpectedly carried group_id: %+v", req) + } + w.Write([]byte(`{"data":{"id":501,"key":"sk-relay-key","name":"managed-key"}}`)) + case r.Method == http.MethodPut && r.URL.Path == "/api/v1/admin/api-keys/501": + w.Write([]byte(`{"data":{"api_key":{"id":501}}}`)) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + client, _ := NewClient(srv.URL, WithBearerToken("admin-token")) + ref, err := client.EnsureSubscriptionAccess(context.Background(), EnsureSubscriptionAccessRequest{UserSelector: "crm-user-1", GroupID: "101"}) + if err != nil { + t.Fatal(err) + } + if ref.UserID != "84" { + t.Fatalf("user id = %q, want 84", ref.UserID) + } + if !strings.HasPrefix(ref.APIKey, "sk-relay-") { + t.Fatalf("api key = %q, want managed sk-relay-* key", ref.APIKey) + } + if len(calls) < 7 { + t.Fatalf("calls = %v, want managed subscription setup sequence", calls) + } +} + func TestCheckGatewayAccessWithMock(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(`{"data":[{"id":"gpt-4"},{"id":"claude-3"}]}`)) @@ -741,12 +842,19 @@ func TestBatchCreateAccountsWithMock(t *testing.T) { if len(acct.GroupIDs) != 1 || acct.GroupIDs[0] != 101 { t.Fatalf("group_ids = %v, want [101]", acct.GroupIDs) } + rawMapping, ok := acct.Credentials["model_mapping"].(map[string]any) + if !ok { + t.Fatalf("credentials = %+v, want model_mapping map", acct.Credentials) + } + if got, _ := rawMapping["deepseek-v4-pro"].(string); got != "deepseek-v4-pro" { + t.Fatalf("model_mapping = %+v, want deepseek-v4-pro passthrough", rawMapping) + } w.Write([]byte(`{"data":[{"id":601,"name":"acct1"}]}`)) })) defer srv.Close() client, _ := NewClient(srv.URL, WithAPIKey("k")) refs, err := client.BatchCreateAccounts(context.Background(), BatchCreateAccountsRequest{ - Accounts: []CreateAccountRequest{{Name: "acct1", Platform: "openai", Type: "apikey", GroupIDs: []string{"101"}, Credentials: map[string]any{"api_key": "sk-test", "base_url": "https://api.example.com"}}}, + Accounts: []CreateAccountRequest{{Name: "acct1", Platform: "openai", Type: "apikey", GroupIDs: []string{"101"}, Credentials: map[string]any{"api_key": "sk-test", "base_url": "https://api.example.com", "model_mapping": map[string]string{"deepseek-v4-pro": "deepseek-v4-pro"}}}}, }) if err != nil { t.Fatal(err) diff --git a/internal/host/sub2api/subscription_access.go b/internal/host/sub2api/subscription_access.go new file mode 100644 index 00000000..fd570f2c --- /dev/null +++ b/internal/host/sub2api/subscription_access.go @@ -0,0 +1,320 @@ +package sub2api + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strconv" + "strings" +) + +const ( + managedSubscriptionBalance = 10.0 + managedSubscriptionValidityDays = 30 +) + +type adminUserRecord struct { + ID int64 `json:"id"` + Email string `json:"email"` +} + +type adminAPIKeyRecord struct { + ID int64 `json:"id"` + Key string `json:"key"` + Name string `json:"name"` + Group *struct { + ID int64 `json:"id"` + } `json:"group,omitempty"` + GroupID *int64 `json:"group_id,omitempty"` +} + +type authTokenPair struct { + AccessToken string `json:"access_token"` +} + +func (c *Client) EnsureSubscriptionAccess(ctx context.Context, req EnsureSubscriptionAccessRequest) (SubscriptionAccessRef, error) { + if c == nil { + return SubscriptionAccessRef{}, fmt.Errorf("client is required") + } + selector := strings.TrimSpace(req.UserSelector) + groupID := strings.TrimSpace(req.GroupID) + if selector == "" { + return SubscriptionAccessRef{}, fmt.Errorf("user selector is required") + } + if groupID == "" { + return SubscriptionAccessRef{}, fmt.Errorf("group id is required") + } + groupInt, err := strconv.ParseInt(groupID, 10, 64) + if err != nil { + return SubscriptionAccessRef{}, fmt.Errorf("parse group id %q: %w", groupID, err) + } + + identity := buildManagedSubscriptionIdentity(selector, groupID) + user, err := c.findManagedSubscriptionUser(ctx, identity.Email) + if err != nil { + return SubscriptionAccessRef{}, err + } + if user == nil { + user, err = c.createManagedSubscriptionUser(ctx, identity, groupInt) + if err != nil { + return SubscriptionAccessRef{}, err + } + } + if err := c.updateManagedSubscriptionUser(ctx, user.ID, groupInt); err != nil { + return SubscriptionAccessRef{}, err + } + if err := c.setManagedSubscriptionBalance(ctx, user.ID); err != nil { + return SubscriptionAccessRef{}, err + } + if err := c.ensureManagedSubscriptionAssignment(ctx, user.ID, groupID); err != nil { + return SubscriptionAccessRef{}, err + } + + userClient, err := c.loginAsManagedSubscriptionUser(ctx, identity.Email, identity.Password) + if err != nil { + return SubscriptionAccessRef{}, err + } + keyRecord, err := c.ensureManagedSubscriptionAPIKey(ctx, userClient, user.ID, identity) + if err != nil { + return SubscriptionAccessRef{}, err + } + if err := c.bindManagedSubscriptionAPIKey(ctx, keyRecord.ID, groupInt); err != nil { + return SubscriptionAccessRef{}, err + } + return SubscriptionAccessRef{UserID: strconv.FormatInt(user.ID, 10), APIKey: identity.CustomKey}, nil +} + +type managedSubscriptionIdentity struct { + Email string + Username string + Password string + CustomKey string + KeyName string +} + +func buildManagedSubscriptionIdentity(selector, groupID string) managedSubscriptionIdentity { + normalizedSelector := strings.TrimSpace(selector) + seedMaterial := strings.ToLower(normalizedSelector) + "|" + strings.TrimSpace(groupID) + sum := sha256.Sum256([]byte(seedMaterial)) + hash := hex.EncodeToString(sum[:]) + prefix := sanitizeManagedSubscriptionPrefix(normalizedSelector) + if prefix == "" { + prefix = "relay-sub" + } + prefix = truncateManagedSubscriptionToken(prefix, 24) + shortHash := hash[:16] + keyHash := hash[:32] + username := truncateManagedSubscriptionToken(prefix+"-"+shortHash[:8], 32) + return managedSubscriptionIdentity{ + Email: fmt.Sprintf("%s-%s@sub2api.local", prefix, shortHash), + Username: username, + Password: "RelayPwd!" + hash[:12], + CustomKey: "sk-relay-" + keyHash, + KeyName: truncateManagedSubscriptionToken(username+"-key", 48), + } +} + +func sanitizeManagedSubscriptionPrefix(value string) string { + value = strings.ToLower(strings.TrimSpace(value)) + var b strings.Builder + lastDash := false + for _, r := range value { + switch { + case r >= 'a' && r <= 'z', r >= '0' && r <= '9': + b.WriteRune(r) + lastDash = false + case !lastDash: + b.WriteByte('-') + lastDash = true + } + } + return strings.Trim(b.String(), "-") +} + +func truncateManagedSubscriptionToken(value string, max int) string { + if len(value) <= max { + return value + } + return strings.Trim(value[:max], "-") +} + +func (c *Client) findManagedSubscriptionUser(ctx context.Context, email string) (*adminUserRecord, error) { + statusCode, _, body, err := c.perform(ctx, http.MethodGet, "/api/v1/admin/users?search="+url.QueryEscape(email)+"&page=1&page_size=20&sort_by=created_at&sort_order=desc", nil) + if err != nil { + return nil, fmt.Errorf("list admin users: %w", err) + } + if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices { + return nil, newHTTPError(http.MethodGet, "/api/v1/admin/users", statusCode, body) + } + var envelope struct { + Data struct { + Items []adminUserRecord `json:"items"` + } `json:"data"` + } + if err := json.Unmarshal(body, &envelope); err != nil { + return nil, fmt.Errorf("decode admin users response: %w", err) + } + for _, item := range envelope.Data.Items { + if strings.EqualFold(strings.TrimSpace(item.Email), email) { + user := item + return &user, nil + } + } + return nil, nil +} + +func (c *Client) createManagedSubscriptionUser(ctx context.Context, identity managedSubscriptionIdentity, groupID int64) (*adminUserRecord, error) { + payload := map[string]any{ + "email": identity.Email, + "password": identity.Password, + "username": identity.Username, + "notes": "managed by sub2api-cn-relay-manager", + "balance": managedSubscriptionBalance, + "concurrency": 5, + "allowed_groups": []int64{groupID}, + } + statusCode, _, body, err := c.perform(ctx, http.MethodPost, "/api/v1/admin/users", payload) + if err != nil { + return nil, fmt.Errorf("create admin user: %w", err) + } + if statusCode == http.StatusConflict { + return c.findManagedSubscriptionUser(ctx, identity.Email) + } + if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices { + return nil, newHTTPError(http.MethodPost, "/api/v1/admin/users", statusCode, body) + } + var user adminUserRecord + if err := decodeEnvelopeObject(body, &user); err != nil { + return nil, fmt.Errorf("decode created admin user: %w", err) + } + return &user, nil +} + +func (c *Client) updateManagedSubscriptionUser(ctx context.Context, userID, groupID int64) error { + payload := map[string]any{"allowed_groups": []int64{groupID}} + statusCode, _, body, err := c.perform(ctx, http.MethodPut, fmt.Sprintf("/api/v1/admin/users/%d", userID), payload) + if err != nil { + return fmt.Errorf("update admin user groups: %w", err) + } + if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices { + return newHTTPError(http.MethodPut, fmt.Sprintf("/api/v1/admin/users/%d", userID), statusCode, body) + } + return nil +} + +func (c *Client) setManagedSubscriptionBalance(ctx context.Context, userID int64) error { + payload := map[string]any{"balance": managedSubscriptionBalance, "operation": "set", "notes": "managed by sub2api-cn-relay-manager"} + statusCode, _, body, err := c.perform(ctx, http.MethodPost, fmt.Sprintf("/api/v1/admin/users/%d/balance", userID), payload) + if err != nil { + return fmt.Errorf("set admin user balance: %w", err) + } + if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices { + return newHTTPError(http.MethodPost, fmt.Sprintf("/api/v1/admin/users/%d/balance", userID), statusCode, body) + } + return nil +} + +func (c *Client) ensureManagedSubscriptionAssignment(ctx context.Context, userID int64, groupID string) error { + _, err := c.AssignSubscription(ctx, AssignSubscriptionRequest{ + UserID: strconv.FormatInt(userID, 10), + GroupID: groupID, + DurationDays: managedSubscriptionValidityDays, + }) + if err != nil { + return fmt.Errorf("assign managed subscription: %w", err) + } + return nil +} + +func (c *Client) loginAsManagedSubscriptionUser(ctx context.Context, email, password string) (*Client, error) { + anon := c.cloneWithAuth("", "") + payload := map[string]any{"email": email, "password": password, "turnstile_token": ""} + statusCode, _, body, err := anon.perform(ctx, http.MethodPost, "/api/v1/auth/login", payload) + if err != nil { + return nil, fmt.Errorf("login managed subscription user: %w", err) + } + if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices { + return nil, newHTTPError(http.MethodPost, "/api/v1/auth/login", statusCode, body) + } + var tokenPair authTokenPair + if err := decodeEnvelopeObject(body, &tokenPair); err != nil { + return nil, fmt.Errorf("decode managed user login response: %w", err) + } + if strings.TrimSpace(tokenPair.AccessToken) == "" { + return nil, fmt.Errorf("managed user login returned empty access token") + } + return c.cloneWithAuth("", tokenPair.AccessToken), nil +} + +func (c *Client) ensureManagedSubscriptionAPIKey(ctx context.Context, userClient *Client, userID int64, identity managedSubscriptionIdentity) (*adminAPIKeyRecord, error) { + payload := map[string]any{ + "name": identity.KeyName, + "custom_key": identity.CustomKey, + } + statusCode, _, body, err := userClient.perform(ctx, http.MethodPost, "/api/v1/keys", payload) + if err != nil { + return nil, fmt.Errorf("create managed api key: %w", err) + } + if statusCode >= http.StatusOK && statusCode < http.StatusMultipleChoices { + var key adminAPIKeyRecord + if err := decodeEnvelopeObject(body, &key); err != nil { + return nil, fmt.Errorf("decode created api key: %w", err) + } + return &key, nil + } + if statusCode != http.StatusConflict && statusCode != http.StatusBadRequest { + return nil, newHTTPError(http.MethodPost, "/api/v1/keys", statusCode, body) + } + return c.findManagedSubscriptionAPIKey(ctx, userID, identity) +} + +func (c *Client) findManagedSubscriptionAPIKey(ctx context.Context, userID int64, identity managedSubscriptionIdentity) (*adminAPIKeyRecord, error) { + statusCode, _, body, err := c.perform(ctx, http.MethodGet, fmt.Sprintf("/api/v1/admin/users/%d/api-keys?page=1&page_size=100&sort_by=created_at&sort_order=desc", userID), nil) + if err != nil { + return nil, fmt.Errorf("list managed api keys: %w", err) + } + if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices { + return nil, newHTTPError(http.MethodGet, fmt.Sprintf("/api/v1/admin/users/%d/api-keys", userID), statusCode, body) + } + var envelope struct { + Data struct { + Items []adminAPIKeyRecord `json:"items"` + } `json:"data"` + } + if err := json.Unmarshal(body, &envelope); err != nil { + return nil, fmt.Errorf("decode admin api keys response: %w", err) + } + for _, item := range envelope.Data.Items { + if strings.TrimSpace(item.Key) == identity.CustomKey || strings.TrimSpace(item.Name) == identity.KeyName { + key := item + return &key, nil + } + } + return nil, fmt.Errorf("managed api key %q not found for user %d", identity.KeyName, userID) +} + +func (c *Client) bindManagedSubscriptionAPIKey(ctx context.Context, keyID, groupID int64) error { + payload := map[string]any{"group_id": groupID} + statusCode, _, body, err := c.perform(ctx, http.MethodPut, fmt.Sprintf("/api/v1/admin/api-keys/%d", keyID), payload) + if err != nil { + return fmt.Errorf("bind managed api key group: %w", err) + } + if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices { + return newHTTPError(http.MethodPut, fmt.Sprintf("/api/v1/admin/api-keys/%d", keyID), statusCode, body) + } + return nil +} + +func (c *Client) cloneWithAuth(apiKey, bearerToken string) *Client { + if c == nil { + return nil + } + clone := *c + clone.apiKey = strings.TrimSpace(apiKey) + clone.bearerToken = strings.TrimSpace(bearerToken) + return &clone +} diff --git a/internal/pack/loader.go b/internal/pack/loader.go index d1728680..87a540c6 100644 --- a/internal/pack/loader.go +++ b/internal/pack/loader.go @@ -157,6 +157,7 @@ func validateProviders(providers []ProviderManifest) error { seen := make(map[string]struct{}, len(providers)) for _, provider := range providers { providerID := strings.TrimSpace(provider.ProviderID) + missingDefaultModel := firstMissingDefaultModel(provider.DefaultModels, provider.ChannelTemplate.ModelMapping) switch { case providerID == "": return fmt.Errorf("provider manifest: provider_id is required") @@ -180,6 +181,10 @@ func validateProviders(providers []ProviderManifest) error { return fmt.Errorf("provider %q: channel_template.name is required", providerID) case len(provider.ChannelTemplate.ModelMapping) == 0: return fmt.Errorf("provider %q: channel_template.model_mapping must not be empty", providerID) + case !containsProviderModel(provider.ChannelTemplate.ModelMapping, provider.SmokeTestModel): + return fmt.Errorf("provider %q: channel_template.model_mapping must include smoke_test_model %q", providerID, provider.SmokeTestModel) + case missingDefaultModel != "": + return fmt.Errorf("provider %q: channel_template.model_mapping must cover default_models, missing %q", providerID, missingDefaultModel) case strings.TrimSpace(provider.PlanTemplate.Name) == "": return fmt.Errorf("provider %q: plan_template.name is required", providerID) case provider.PlanTemplate.ValidityDays <= 0: @@ -247,3 +252,29 @@ func contains(items []string, target string) bool { } return false } + +func containsProviderModel(modelMapping map[string]string, target string) bool { + trimmedTarget := strings.TrimSpace(target) + if trimmedTarget == "" { + return false + } + for sourceModel, mappedModel := range modelMapping { + if strings.TrimSpace(sourceModel) == trimmedTarget || strings.TrimSpace(mappedModel) == trimmedTarget { + return true + } + } + return false +} + +func firstMissingDefaultModel(defaultModels []string, modelMapping map[string]string) string { + for _, model := range defaultModels { + trimmedModel := strings.TrimSpace(model) + if trimmedModel == "" { + continue + } + if !containsProviderModel(modelMapping, trimmedModel) { + return trimmedModel + } + } + return "" +} diff --git a/internal/pack/loader_test.go b/internal/pack/loader_test.go index 12f20dc7..e0ba9ef7 100644 --- a/internal/pack/loader_test.go +++ b/internal/pack/loader_test.go @@ -30,7 +30,7 @@ func TestLoadDirParsesAndValidatesPack(t *testing.T) { "default_models": ["deepseek-chat", "deepseek-reasoner"], "smoke_test_model": "deepseek-chat", "group_template": {"name": "DeepSeek 默认分组", "rate_multiplier": 1.0}, - "channel_template": {"name": "DeepSeek 默认渠道", "model_mapping": {"deepseek-chat": "deepseek-chat"}}, + "channel_template": {"name": "DeepSeek 默认渠道", "model_mapping": {"deepseek-chat": "deepseek-chat", "deepseek-reasoner": "deepseek-reasoner"}}, "plan_template": {"name": "DeepSeek 默认套餐", "price": 19.9, "validity_days": 30, "validity_unit": "day"}, "import": {"supports_multi_key": true, "supports_strict": true, "supports_partial": true} }`, @@ -82,6 +82,36 @@ func TestLoadDirRejectsInvalidProviderSchema(t *testing.T) { } } +func TestLoadDirRejectsSmokeTestModelMissingFromChannelMapping(t *testing.T) { + packDir := createPackFixture(t, map[string]string{ + "pack.json": `{"pack_id":"openai-cn-pack","version":"1.0.0","vendor":"x","target_host":"sub2api","min_host_version":"0.1.126","max_host_version":"0.2.x","providers_dir":"providers","checksum_file":"checksums.txt"}`, + "providers/deepseek.json": `{"provider_id":"deepseek","display_name":"DeepSeek","base_url":"https://api.deepseek.com","platform":"openai","account_type":"apikey","default_models":["deepseek-v4-pro","deepseek-v4-flash"],"smoke_test_model":"deepseek-v4-pro","group_template":{"name":"g","rate_multiplier":1},"channel_template":{"name":"c","model_mapping":{"deepseek-chat":"deepseek-chat","deepseek-reasoner":"deepseek-reasoner"}},"plan_template":{"name":"p","price":1,"validity_days":30,"validity_unit":"day"},"import":{"supports_multi_key":true,"supports_strict":true,"supports_partial":true}}`, + }) + + _, err := LoadDir(packDir) + if err == nil { + t.Fatal("LoadDir() error = nil, want smoke_test_model channel mapping validation failure") + } + if !strings.Contains(err.Error(), "channel_template.model_mapping") || !strings.Contains(err.Error(), "smoke_test_model") { + t.Fatalf("LoadDir() error = %v, want smoke_test_model channel mapping detail", err) + } +} + +func TestLoadDirRejectsDefaultModelsMissingFromChannelMapping(t *testing.T) { + packDir := createPackFixture(t, map[string]string{ + "pack.json": `{"pack_id":"openai-cn-pack","version":"1.0.0","vendor":"x","target_host":"sub2api","min_host_version":"0.1.126","max_host_version":"0.2.x","providers_dir":"providers","checksum_file":"checksums.txt"}`, + "providers/minimax.json": `{"provider_id":"minimax","display_name":"MiniMax","base_url":"https://api.minimax.example.com","platform":"openai","account_type":"apikey","default_models":["MiniMax-M2.5-highspeed","MiniMax-M2.7-highspeed"],"smoke_test_model":"MiniMax-M2.7-highspeed","group_template":{"name":"g","rate_multiplier":1},"channel_template":{"name":"c","model_mapping":{"MiniMax-M2.7-highspeed":"MiniMax-M2.7-highspeed"}},"plan_template":{"name":"p","price":1,"validity_days":30,"validity_unit":"day"},"import":{"supports_multi_key":true,"supports_strict":true,"supports_partial":true}}`, + }) + + _, err := LoadDir(packDir) + if err == nil { + t.Fatal("LoadDir() error = nil, want default_models channel mapping validation failure") + } + if !strings.Contains(err.Error(), "default_models") || !strings.Contains(err.Error(), "channel_template.model_mapping") { + t.Fatalf("LoadDir() error = %v, want default_models channel mapping detail", err) + } +} + func createPackFixture(t *testing.T, files map[string]string) string { t.Helper() diff --git a/internal/provision/access_status_aggregation.go b/internal/provision/access_status_aggregation.go new file mode 100644 index 00000000..f6fa4b06 --- /dev/null +++ b/internal/provision/access_status_aggregation.go @@ -0,0 +1,128 @@ +package provision + +import ( + "context" + "strings" + + "sub2api-cn-relay-manager/internal/pack" + "sub2api-cn-relay-manager/internal/store/sqlite" +) + +type ModeAccessStatuses struct { + Subscription string + SelfService string +} + +func SuggestResourceNamesForMode(provider pack.ProviderManifest, accessMode string) ResourceNames { + base := SuggestResourceNames(provider) + suffix := accessModeResourceSuffix(accessMode) + if suffix == "" { + return base + } + return ResourceNames{ + Group: appendResourceNameSuffix(base.Group, suffix), + Channel: appendResourceNameSuffix(base.Channel, suffix), + Plan: appendResourceNameSuffix(base.Plan, suffix), + } +} + +func accessModeResourceSuffix(accessMode string) string { + switch strings.TrimSpace(accessMode) { + case AccessModeSubscription: + return "subscription" + case AccessModeSelfService: + return "self-service" + default: + return "" + } +} + +func appendResourceNameSuffix(name, suffix string) string { + name = strings.TrimSpace(name) + suffix = strings.TrimSpace(suffix) + if name == "" || suffix == "" { + return name + } + if strings.HasSuffix(name, "-"+suffix) { + return name + } + return name + "-" + suffix +} + +func LatestModeAccessStatuses(ctx context.Context, store *sqlite.DB, batches []sqlite.ImportBatch) (ModeAccessStatuses, error) { + var statuses ModeAccessStatuses + for _, batch := range batches { + if statuses.Subscription != "" && statuses.SelfService != "" { + break + } + closures, err := store.AccessClosures().GetByBatchID(ctx, batch.ID) + if err != nil { + return ModeAccessStatuses{}, err + } + batchStatuses := modeAccessStatusesForBatch(batch, closures) + if statuses.Subscription == "" && strings.TrimSpace(batchStatuses.Subscription) != "" { + statuses.Subscription = strings.TrimSpace(batchStatuses.Subscription) + } + if statuses.SelfService == "" && strings.TrimSpace(batchStatuses.SelfService) != "" { + statuses.SelfService = strings.TrimSpace(batchStatuses.SelfService) + } + } + return statuses, nil +} + +func modeAccessStatusesForBatch(batch sqlite.ImportBatch, closures []sqlite.AccessClosureRecord) ModeAccessStatuses { + statuses := ModeAccessStatuses{} + for _, closure := range closures { + status := strings.TrimSpace(closure.Status) + switch strings.TrimSpace(closure.ClosureType) { + case AccessModeSubscription: + statuses.Subscription = status + case AccessModeSelfService: + statuses.SelfService = status + } + } + if statuses.Subscription == "" && statuses.SelfService == "" { + return seedModeAccessStatuses(batch.AccessStatus) + } + return statuses +} + +func seedModeAccessStatuses(accessStatus string) ModeAccessStatuses { + switch strings.TrimSpace(accessStatus) { + case AccessStatusFullyReady: + return ModeAccessStatuses{Subscription: AccessStatusSubscriptionReady, SelfService: AccessStatusSelfServiceReady} + case AccessStatusSubscriptionReady: + return ModeAccessStatuses{Subscription: AccessStatusSubscriptionReady} + case AccessStatusSelfServiceReady: + return ModeAccessStatuses{SelfService: AccessStatusSelfServiceReady} + default: + return ModeAccessStatuses{} + } +} + +func AggregateAccessStatus(statuses ModeAccessStatuses) string { + subscriptionReady := isReadyAccessStatus(statuses.Subscription, AccessModeSubscription) + selfServiceReady := isReadyAccessStatus(statuses.SelfService, AccessModeSelfService) + switch { + case subscriptionReady && selfServiceReady: + return AccessStatusFullyReady + case subscriptionReady: + return AccessStatusSubscriptionReady + case selfServiceReady: + return AccessStatusSelfServiceReady + default: + return AccessStatusBroken + } +} + +func isReadyAccessStatus(status, mode string) bool { + status = strings.TrimSpace(status) + switch mode { + case AccessModeSubscription: + return status == AccessStatusSubscriptionReady || status == AccessStatusFullyReady + case AccessModeSelfService: + return status == AccessStatusSelfServiceReady || status == AccessStatusFullyReady + default: + return status != "" && status != AccessStatusBroken + } +} diff --git a/internal/provision/batch_detail_and_reconcile_service.go b/internal/provision/batch_detail_and_reconcile_service.go index ee930778..69a9c086 100644 --- a/internal/provision/batch_detail_and_reconcile_service.go +++ b/internal/provision/batch_detail_and_reconcile_service.go @@ -278,7 +278,7 @@ func accessClosureType(accessClosures []sqlite.AccessClosureRecord) string { } func buildManagedResourceListRequest(provider pack.ProviderManifest, accessMode string) sub2api.ListManagedResourcesRequest { - names := SuggestResourceNames(provider) + names := SuggestResourceNamesForMode(provider, accessMode) req := sub2api.ListManagedResourcesRequest{ GroupName: names.Group, ChannelName: names.Channel, diff --git a/internal/provision/batch_detail_service_test.go b/internal/provision/batch_detail_service_test.go index 0107035e..9b96707e 100644 --- a/internal/provision/batch_detail_service_test.go +++ b/internal/provision/batch_detail_service_test.go @@ -215,10 +215,11 @@ func TestDeriveProviderStatus(t *testing.T) { tests := []struct { name string batchStatus string + accessStatus string reconcileStatus string want string }{ - {name: "reconcile wins", batchStatus: BatchStatusSucceeded, reconcileStatus: "degraded", want: "degraded"}, + {name: "recovered success beats stale reconcile", batchStatus: BatchStatusSucceeded, accessStatus: AccessStatusSelfServiceReady, reconcileStatus: "degraded", want: ProviderStatusActive}, {name: "succeeded batch", batchStatus: BatchStatusSucceeded, reconcileStatus: "not_run", want: ProviderStatusActive}, {name: "failed batch", batchStatus: BatchStatusFailed, want: ProviderStatusFailed}, {name: "running batch", batchStatus: "running", want: "running"}, @@ -226,13 +227,60 @@ func TestDeriveProviderStatus(t *testing.T) { } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - if got := deriveProviderStatus(tc.batchStatus, tc.reconcileStatus); got != tc.want { - t.Fatalf("deriveProviderStatus(%q, %q) = %q, want %q", tc.batchStatus, tc.reconcileStatus, got, tc.want) + if got := deriveProviderStatus(tc.batchStatus, tc.accessStatus, tc.reconcileStatus); got != tc.want { + t.Fatalf("deriveProviderStatus(%q, %q, %q) = %q, want %q", tc.batchStatus, tc.accessStatus, tc.reconcileStatus, got, tc.want) } }) } } +func TestProviderStatusServiceAggregatesLatestAccessModesAcrossBatches(t *testing.T) { + store := openProvisionTestStore(t) + defer closeProvisionTestStore(t, store) + + ctx := context.Background() + hostID := seedProvisionHost(t, store, "host-1", "https://sub2api.example.com") + packID, err := store.Packs().Create(ctx, sqlite.Pack{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api", Checksum: "checksum-1"}) + if err != nil { + t.Fatalf("Packs().Create() error = %v", err) + } + providerID, err := store.Providers().Create(ctx, sqlite.Provider{PackID: packID, ProviderID: "deepseek", DisplayName: "DeepSeek", BaseURL: "https://api.deepseek.com", Platform: "openai"}) + if err != nil { + t.Fatalf("Providers().Create() error = %v", err) + } + batchSubscription, err := store.ImportBatches().Create(ctx, sqlite.ImportBatch{HostID: hostID, PackID: packID, ProviderID: providerID, Mode: ImportModePartial, BatchStatus: BatchStatusSucceeded, AccessStatus: AccessStatusSubscriptionReady}) + if err != nil { + t.Fatalf("ImportBatches().Create(subscription) error = %v", err) + } + if _, err := store.AccessClosures().Create(ctx, sqlite.AccessClosureRecord{BatchID: batchSubscription, ClosureType: AccessModeSubscription, Status: AccessStatusSubscriptionReady, DetailsJSON: "{}"}); err != nil { + t.Fatalf("AccessClosures().Create(subscription) error = %v", err) + } + batchSelfService, err := store.ImportBatches().Create(ctx, sqlite.ImportBatch{HostID: hostID, PackID: packID, ProviderID: providerID, Mode: ImportModePartial, BatchStatus: BatchStatusSucceeded, AccessStatus: AccessStatusSelfServiceReady}) + if err != nil { + t.Fatalf("ImportBatches().Create(self_service) error = %v", err) + } + if _, err := store.AccessClosures().Create(ctx, sqlite.AccessClosureRecord{BatchID: batchSelfService, ClosureType: AccessModeSelfService, Status: AccessStatusSelfServiceReady, DetailsJSON: "{}"}); err != nil { + t.Fatalf("AccessClosures().Create(self_service) error = %v", err) + } + if _, err := store.ReconcileRuns().Create(ctx, sqlite.ReconcileRun{BatchID: batchSelfService, HostID: hostID, ProviderID: providerID, Status: "drifted", SummaryJSON: `{"missing_count":1}`}); err != nil { + t.Fatalf("ReconcileRuns().Create() error = %v", err) + } + + snapshot, err := NewProviderStatusService(store).GetStatus(ctx, ProviderQuery{ProviderID: "deepseek", PackID: "openai-cn-pack", HostID: "host-1"}) + if err != nil { + t.Fatalf("GetStatus() error = %v", err) + } + if snapshot.LatestAccessStatus != AccessStatusFullyReady { + t.Fatalf("LatestAccessStatus = %q, want %q", snapshot.LatestAccessStatus, AccessStatusFullyReady) + } + if snapshot.ProviderStatus != ProviderStatusActive { + t.Fatalf("ProviderStatus = %q, want %q", snapshot.ProviderStatus, ProviderStatusActive) + } + if snapshot.LatestReconcileStatus != "drifted" { + t.Fatalf("LatestReconcileStatus = %q, want drifted", snapshot.LatestReconcileStatus) + } +} + func TestBuildPackAndProviderRecord(t *testing.T) { packRow, err := buildPackRecord(sampleLoadedPack()) if err != nil { diff --git a/internal/provision/import_service.go b/internal/provision/import_service.go index 31516b91..251411ef 100644 --- a/internal/provision/import_service.go +++ b/internal/provision/import_service.go @@ -199,7 +199,7 @@ func (s *ImportService) Import(ctx context.Context, req ImportRequest) (report I } func (s *ImportService) ensureManagedResources(ctx context.Context, provider pack.ProviderManifest, accessMode string) (resolvedManagedResources, error) { - names := SuggestResourceNames(provider) + names := SuggestResourceNamesForMode(provider, accessMode) snapshot, err := s.host.ListManagedResources(ctx, sub2api.ListManagedResourcesRequest{ GroupName: names.Group, ChannelName: names.Channel, @@ -210,14 +210,14 @@ func (s *ImportService) ensureManagedResources(ctx context.Context, provider pac } result := resolvedManagedResources{} - group, created, err := ensureGroup(ctx, s.host, snapshot.Groups, provider, accessMode) + group, created, err := ensureGroup(ctx, s.host, snapshot.Groups, provider, accessMode, names.Group) if err != nil { return resolvedManagedResources{}, fmt.Errorf("ensure group: %w", err) } result.Group = group result.CreatedGroup = created - channel, created, err := ensureChannel(ctx, s.host, snapshot.Channels, provider, group.ID) + channel, created, err := ensureChannel(ctx, s.host, snapshot.Channels, provider, group.ID, names.Channel) if err != nil { return resolvedManagedResources{}, fmt.Errorf("ensure channel: %w", err) } @@ -225,7 +225,7 @@ func (s *ImportService) ensureManagedResources(ctx context.Context, provider pac result.CreatedChannel = created if accessMode == AccessModeSubscription { - plan, created, err := ensurePlan(ctx, s.host, snapshot.Plans, provider, group.ID) + plan, created, err := ensurePlan(ctx, s.host, snapshot.Plans, provider, group.ID, names.Plan) if err != nil { return resolvedManagedResources{}, fmt.Errorf("ensure plan: %w", err) } @@ -236,10 +236,10 @@ func (s *ImportService) ensureManagedResources(ctx context.Context, provider pac return result, nil } -func ensureGroup(ctx context.Context, host hostAdapter, existing []sub2api.NamedResource, provider pack.ProviderManifest, accessMode string) (sub2api.GroupRef, bool, error) { +func ensureGroup(ctx context.Context, host hostAdapter, existing []sub2api.NamedResource, provider pack.ProviderManifest, accessMode, groupName string) (sub2api.GroupRef, bool, error) { switch len(existing) { case 0: - groupReq := sub2api.CreateGroupRequest{Name: provider.GroupTemplate.Name, Platform: provider.Platform, RateMultiplier: provider.GroupTemplate.RateMultiplier} + groupReq := sub2api.CreateGroupRequest{Name: groupName, Platform: provider.Platform, RateMultiplier: provider.GroupTemplate.RateMultiplier} if accessMode == AccessModeSubscription { groupReq.SubscriptionType = "subscription" } @@ -248,38 +248,52 @@ func ensureGroup(ctx context.Context, host hostAdapter, existing []sub2api.Named case 1: return sub2api.GroupRef{ID: existing[0].ID, Name: existing[0].Name}, false, nil default: - return sub2api.GroupRef{}, false, fmt.Errorf("multiple groups already exist for %q", provider.GroupTemplate.Name) + return sub2api.GroupRef{}, false, fmt.Errorf("multiple groups already exist for %q", groupName) } } -func ensureChannel(ctx context.Context, host hostAdapter, existing []sub2api.NamedResource, provider pack.ProviderManifest, groupID string) (sub2api.ChannelRef, bool, error) { +func ensureChannel(ctx context.Context, host hostAdapter, existing []sub2api.NamedResource, provider pack.ProviderManifest, groupID, channelName string) (sub2api.ChannelRef, bool, error) { + channelReq := buildChannelRequest(provider, groupID, channelName) switch len(existing) { case 0: - channelReq := sub2api.CreateChannelRequest{ - Name: provider.ChannelTemplate.Name, - GroupIDs: []string{groupID}, - ModelMapping: provider.ChannelTemplate.ModelMapping, - RestrictModels: true, - BillingModelSource: "channel_mapped", - } channel, err := host.CreateChannel(ctx, channelReq) return channel, true, err case 1: + if err := host.UpdateChannel(ctx, existing[0].ID, channelReq); err != nil { + return sub2api.ChannelRef{}, false, err + } return sub2api.ChannelRef{ID: existing[0].ID, Name: existing[0].Name}, false, nil default: - return sub2api.ChannelRef{}, false, fmt.Errorf("multiple channels already exist for %q", provider.ChannelTemplate.Name) + return sub2api.ChannelRef{}, false, fmt.Errorf("multiple channels already exist for %q", channelName) } } -func ensurePlan(ctx context.Context, host hostAdapter, existing []sub2api.NamedResource, provider pack.ProviderManifest, groupID string) (sub2api.PlanRef, bool, error) { +func buildChannelRequest(provider pack.ProviderManifest, groupID, channelName string) sub2api.CreateChannelRequest { + return sub2api.CreateChannelRequest{ + Name: channelName, + GroupIDs: []string{groupID}, + ModelMapping: provider.ChannelTemplate.ModelMapping, + ModelPricing: []sub2api.ChannelModelPricing{{ + Platform: provider.Platform, + Models: append([]string(nil), provider.DefaultModels...), + BillingMode: "token", + Intervals: []sub2api.ChannelPricingTier{}, + }}, + Platform: provider.Platform, + RestrictModels: true, + BillingModelSource: "channel_mapped", + } +} + +func ensurePlan(ctx context.Context, host hostAdapter, existing []sub2api.NamedResource, provider pack.ProviderManifest, groupID, planName string) (sub2api.PlanRef, bool, error) { switch len(existing) { case 0: - plan, err := host.CreatePlan(ctx, sub2api.CreatePlanRequest{GroupID: groupID, Name: provider.PlanTemplate.Name, Price: provider.PlanTemplate.Price, ValidityDays: provider.PlanTemplate.ValidityDays, ValidityUnit: provider.PlanTemplate.ValidityUnit}) + plan, err := host.CreatePlan(ctx, sub2api.CreatePlanRequest{GroupID: groupID, Name: planName, Price: provider.PlanTemplate.Price, ValidityDays: provider.PlanTemplate.ValidityDays, ValidityUnit: provider.PlanTemplate.ValidityUnit}) return plan, true, err case 1: return sub2api.PlanRef{ID: existing[0].ID, Name: existing[0].Name}, false, nil default: - return sub2api.PlanRef{}, false, fmt.Errorf("multiple plans already exist for %q", provider.PlanTemplate.Name) + return sub2api.PlanRef{}, false, fmt.Errorf("multiple plans already exist for %q", planName) } } @@ -329,8 +343,9 @@ func buildBatchAccountsRequest(provider pack.ProviderManifest, groupID string, k Type: provider.AccountType, GroupIDs: []string{groupID}, Credentials: map[string]any{ - "base_url": provider.BaseURL, - "api_key": key, + "base_url": provider.BaseURL, + "api_key": key, + "model_mapping": provider.ChannelTemplate.ModelMapping, }, }) } diff --git a/internal/provision/import_service_test.go b/internal/provision/import_service_test.go index a1082e3d..4a1ad3dc 100644 --- a/internal/provision/import_service_test.go +++ b/internal/provision/import_service_test.go @@ -152,7 +152,7 @@ func TestImportReusesExistingGroup(t *testing.T) { }, gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}}, managedSnapshot: sub2api.ManagedResourceSnapshot{ - Groups: []sub2api.NamedResource{{ID: "group_existing", Name: "DeepSeek 默认分组"}}, + Groups: []sub2api.NamedResource{{ID: "group_existing", Name: "DeepSeek 默认分组-self-service"}}, }, } @@ -198,8 +198,8 @@ func TestImportCreatesChannelWithManifestModelMapping(t *testing.T) { if err != nil { t.Fatalf("Import() error = %v", err) } - if host.createChannelReq.Name != "DeepSeek 默认渠道" { - t.Fatalf("CreateChannel().Name = %q, want DeepSeek 默认渠道", host.createChannelReq.Name) + if host.createChannelReq.Name != "DeepSeek 默认渠道-self-service" { + t.Fatalf("CreateChannel().Name = %q, want DeepSeek 默认渠道-self-service", host.createChannelReq.Name) } if len(host.createChannelReq.GroupIDs) != 1 || host.createChannelReq.GroupIDs[0] != "group_1" { t.Fatalf("CreateChannel().GroupIDs = %v, want [group_1]", host.createChannelReq.GroupIDs) @@ -213,6 +213,31 @@ func TestImportCreatesChannelWithManifestModelMapping(t *testing.T) { if host.createChannelReq.BillingModelSource != "channel_mapped" { t.Fatalf("CreateChannel().BillingModelSource = %q, want channel_mapped", host.createChannelReq.BillingModelSource) } + if len(host.createChannelReq.ModelPricing) != 1 { + t.Fatalf("CreateChannel().ModelPricing len = %d, want 1", len(host.createChannelReq.ModelPricing)) + } + if len(host.createChannelReq.ModelPricing[0].Models) != 2 { + t.Fatalf("CreateChannel().ModelPricing[0].Models = %v, want default model coverage", host.createChannelReq.ModelPricing[0].Models) + } + if host.createChannelReq.ModelPricing[0].BillingMode != "token" { + t.Fatalf("CreateChannel().ModelPricing[0].BillingMode = %q, want token", host.createChannelReq.ModelPricing[0].BillingMode) + } + if len(host.batchCreateReq.Accounts) != 1 { + t.Fatalf("BatchCreateAccounts().Accounts len = %d, want 1", len(host.batchCreateReq.Accounts)) + } + credentials := host.batchCreateReq.Accounts[0].Credentials + switch rawMapping := credentials["model_mapping"].(type) { + case map[string]string: + if got := rawMapping["deepseek-chat"]; got != "deepseek-chat" { + t.Fatalf("BatchCreateAccounts().Credentials.model_mapping = %+v, want deepseek-chat passthrough", rawMapping) + } + case map[string]any: + if got, _ := rawMapping["deepseek-chat"].(string); got != "deepseek-chat" { + t.Fatalf("BatchCreateAccounts().Credentials.model_mapping = %+v, want deepseek-chat passthrough", rawMapping) + } + default: + t.Fatalf("BatchCreateAccounts().Credentials = %+v, want model_mapping map", credentials) + } } func sampleProviderManifest() pack.ProviderManifest { @@ -230,8 +255,48 @@ func sampleProviderManifest() pack.ProviderManifest { } } +func TestImportReconcilesExistingChannelConfiguration(t *testing.T) { + host := &fakeHostAdapter{ + batchAccounts: []sub2api.AccountRef{{ID: "account_1", Name: "deepseek-01"}}, + testResults: map[string]sub2api.ProbeResult{ + "account_1": {OK: true, Status: "ready"}, + }, + models: map[string][]sub2api.AccountModel{ + "account_1": {{ID: "deepseek-chat"}}, + }, + gatewayResult: sub2api.GatewayAccessResult{OK: true, StatusCode: 200, HasExpectedModel: true, Models: []string{"deepseek-chat"}}, + managedSnapshot: sub2api.ManagedResourceSnapshot{ + Groups: []sub2api.NamedResource{{ID: "group_existing", Name: "DeepSeek 默认分组-self-service"}}, + Channels: []sub2api.NamedResource{{ID: "channel_existing", Name: "DeepSeek 默认渠道-self-service"}}, + }, + } + + _, err := NewImportService(host).Import(context.Background(), ImportRequest{ + Provider: sampleProviderManifest(), + Mode: ImportModePartial, + Access: AccessRequest{Mode: AccessModeSelfService, ProbeAPIKey: "user-key"}, + Keys: []string{"key-1"}, + }) + if err != nil { + t.Fatalf("Import() error = %v", err) + } + if host.createChannelCalls != 0 { + t.Fatalf("CreateChannel() calls = %d, want 0 when channel already exists", host.createChannelCalls) + } + if host.updateChannelCalls != 1 { + t.Fatalf("UpdateChannel() calls = %d, want 1", host.updateChannelCalls) + } + if host.updateChannelID != "channel_existing" { + t.Fatalf("UpdateChannel() id = %q, want channel_existing", host.updateChannelID) + } + if len(host.updateChannelReq.ModelPricing) != 1 { + t.Fatalf("UpdateChannel().ModelPricing len = %d, want 1", len(host.updateChannelReq.ModelPricing)) + } +} + type fakeHostAdapter struct { batchAccounts []sub2api.AccountRef + batchCreateReq sub2api.BatchCreateAccountsRequest testResults map[string]sub2api.ProbeResult models map[string][]sub2api.AccountModel gatewayResult sub2api.GatewayAccessResult @@ -246,9 +311,12 @@ type fakeHostAdapter struct { listManagedReq sub2api.ListManagedResourcesRequest createGroupCalls int createChannelCalls int + updateChannelCalls int createPlanCalls int createGroupReq sub2api.CreateGroupRequest createChannelReq sub2api.CreateChannelRequest + updateChannelID string + updateChannelReq sub2api.CreateChannelRequest } func (f *fakeHostAdapter) GetHostVersion(context.Context) (string, error) { @@ -274,6 +342,12 @@ func (f *fakeHostAdapter) CreateChannel(_ context.Context, req sub2api.CreateCha f.createChannelReq = req return sub2api.ChannelRef{ID: "channel_1", Name: "c"}, nil } +func (f *fakeHostAdapter) UpdateChannel(_ context.Context, channelID string, req sub2api.CreateChannelRequest) error { + f.updateChannelCalls++ + f.updateChannelID = channelID + f.updateChannelReq = req + return nil +} func (f *fakeHostAdapter) DeleteChannel(_ context.Context, channelID string) error { f.deletedResources = append(f.deletedResources, "channel:"+channelID) return nil @@ -289,7 +363,8 @@ func (f *fakeHostAdapter) DeletePlan(_ context.Context, planID string) error { func (f *fakeHostAdapter) CreateAccount(context.Context, sub2api.CreateAccountRequest) (sub2api.AccountRef, error) { return sub2api.AccountRef{}, errors.New("unused") } -func (f *fakeHostAdapter) BatchCreateAccounts(_ context.Context, _ sub2api.BatchCreateAccountsRequest) ([]sub2api.AccountRef, error) { +func (f *fakeHostAdapter) BatchCreateAccounts(_ context.Context, req sub2api.BatchCreateAccountsRequest) ([]sub2api.AccountRef, error) { + f.batchCreateReq = req if f.batchCreateErr != nil { return nil, f.batchCreateErr } @@ -313,6 +388,9 @@ func (f *fakeHostAdapter) GetAccountModels(_ context.Context, accountID string) } return models, nil } +func (f *fakeHostAdapter) EnsureSubscriptionAccess(_ context.Context, req sub2api.EnsureSubscriptionAccessRequest) (sub2api.SubscriptionAccessRef, error) { + return sub2api.SubscriptionAccessRef{UserID: req.UserSelector, APIKey: "managed-subscription-key"}, nil +} func (f *fakeHostAdapter) AssignSubscription(_ context.Context, req sub2api.AssignSubscriptionRequest) (sub2api.SubscriptionRef, error) { if f.assignErr != nil { return sub2api.SubscriptionRef{}, f.assignErr diff --git a/internal/provision/preview_service.go b/internal/provision/preview_service.go index b5b7e879..e4544884 100644 --- a/internal/provision/preview_service.go +++ b/internal/provision/preview_service.go @@ -57,7 +57,7 @@ func (s *PreviewService) PreviewImport(ctx context.Context, req PreviewRequest) return PreviewReport{}, fmt.Errorf("preview host is required") } - names := SuggestResourceNames(req.Provider) + names := SuggestResourceNamesForMode(req.Provider, req.Mode) snapshot, err := s.host.ListManagedResources(ctx, sub2api.ListManagedResourcesRequest{ GroupName: names.Group, ChannelName: names.Channel, diff --git a/internal/provision/preview_service_test.go b/internal/provision/preview_service_test.go index d693158c..6bb40775 100644 --- a/internal/provision/preview_service_test.go +++ b/internal/provision/preview_service_test.go @@ -23,6 +23,23 @@ func TestSuggestResourceNames(t *testing.T) { } } +func TestSuggestResourceNamesIncludesAccessModeSuffix(t *testing.T) { + provider := sampleProviderManifest() + provider.GroupTemplate.Name = "" + provider.ChannelTemplate.Name = "" + provider.PlanTemplate.Name = "" + + names := SuggestResourceNamesForMode(provider, AccessModeSubscription) + want := ResourceNames{ + Group: "crm-deepseek-group-subscription", + Channel: "crm-deepseek-channel-subscription", + Plan: "crm-deepseek-plan-subscription", + } + if !reflect.DeepEqual(names, want) { + t.Fatalf("SuggestResourceNamesForMode() = %#v, want %#v", names, want) + } +} + func TestPreviewServiceReportsCreateActionsWhenHostHasNoResources(t *testing.T) { host := &fakePreviewHost{} svc := NewPreviewService(host) diff --git a/internal/provision/provider_status_service.go b/internal/provision/provider_status_service.go index 6ff255d9..4772efef 100644 --- a/internal/provision/provider_status_service.go +++ b/internal/provision/provider_status_service.go @@ -69,13 +69,18 @@ func (s *ProviderStatusService) snapshot(ctx context.Context, query ProviderQuer if err != nil { return ProviderSnapshot{}, err } - reconcileRuns, err := s.store.ReconcileRuns().GetByBatchID(ctx, batchRow.ID) + batches, err := s.store.ImportBatches().ListByProviderIDAndHostID(ctx, provider.ID, hostRow.ID) if err != nil { return ProviderSnapshot{}, err } - latestAccessStatus := batchRow.AccessStatus - if len(accessClosures) > 0 { - latestAccessStatus = firstNonEmpty(accessClosures[len(accessClosures)-1].Status, latestAccessStatus) + modeStatuses, err := LatestModeAccessStatuses(ctx, s.store, batches) + if err != nil { + return ProviderSnapshot{}, err + } + latestAccessStatus := AggregateAccessStatus(modeStatuses) + reconcileRuns, err := s.store.ReconcileRuns().GetByBatchID(ctx, batchRow.ID) + if err != nil { + return ProviderSnapshot{}, err } latestReconcileStatus := "not_run" latestReconcileSummary := map[string]any{} @@ -87,7 +92,7 @@ func (s *ProviderStatusService) snapshot(ctx context.Context, query ProviderQuer } } } - providerStatus := deriveProviderStatus(batchRow.BatchStatus, latestReconcileStatus) + providerStatus := deriveProviderStatus(batchRow.BatchStatus, latestAccessStatus, latestReconcileStatus) return ProviderSnapshot{ Host: hostRow, Pack: packRow, @@ -162,8 +167,12 @@ func (s *ProviderStatusService) resolveHostAndBatch(ctx context.Context, provide return hostRow, batches[0], nil } -func deriveProviderStatus(batchStatus, reconcileStatus string) string { +func deriveProviderStatus(batchStatus, accessStatus, reconcileStatus string) string { reconcileStatus = strings.TrimSpace(reconcileStatus) + accessStatus = strings.TrimSpace(accessStatus) + if strings.TrimSpace(batchStatus) == BatchStatusSucceeded && accessStatus != "" && accessStatus != AccessStatusBroken { + return ProviderStatusActive + } if reconcileStatus != "" && reconcileStatus != "not_run" { return reconcileStatus } diff --git a/internal/provision/provider_status_service_test.go b/internal/provision/provider_status_service_test.go index da717941..283077be 100644 --- a/internal/provision/provider_status_service_test.go +++ b/internal/provision/provider_status_service_test.go @@ -54,8 +54,8 @@ func TestProviderStatusServiceReturnsLatestSnapshot(t *testing.T) { if snapshot.Provider.ProviderID != "deepseek" { t.Fatalf("Provider.ProviderID = %q, want deepseek", snapshot.Provider.ProviderID) } - if snapshot.ProviderStatus != "drifted" { - t.Fatalf("ProviderStatus = %q, want drifted", snapshot.ProviderStatus) + if snapshot.ProviderStatus != ProviderStatusActive { + t.Fatalf("ProviderStatus = %q, want %q", snapshot.ProviderStatus, ProviderStatusActive) } if snapshot.LatestAccessStatus != AccessStatusSelfServiceReady { t.Fatalf("LatestAccessStatus = %q, want %q", snapshot.LatestAccessStatus, AccessStatusSelfServiceReady) diff --git a/internal/provision/reconcile_service_test.go b/internal/provision/reconcile_service_test.go index c27a5f9f..87b939eb 100644 --- a/internal/provision/reconcile_service_test.go +++ b/internal/provision/reconcile_service_test.go @@ -28,8 +28,8 @@ func TestReconcileServiceReturnsActiveAfterProbeRerun(t *testing.T) { batchID := seedRuntimeImportForReconcile(t, store, host) host.managedSnapshot = sub2api.ManagedResourceSnapshot{ - Groups: []sub2api.NamedResource{{ID: "group_1", Name: "DeepSeek 默认分组"}}, - Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "DeepSeek 默认渠道"}}, + Groups: []sub2api.NamedResource{{ID: "group_1", Name: "DeepSeek 默认分组-self-service"}}, + Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "DeepSeek 默认渠道-self-service"}}, Accounts: []sub2api.NamedResource{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}}, } @@ -82,8 +82,8 @@ func TestReconcileServiceReturnsDegradedWhenProbeRerunFails(t *testing.T) { seedRuntimeImportForReconcile(t, store, host) host.managedSnapshot = sub2api.ManagedResourceSnapshot{ - Groups: []sub2api.NamedResource{{ID: "group_1", Name: "DeepSeek 默认分组"}}, - Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "DeepSeek 默认渠道"}}, + Groups: []sub2api.NamedResource{{ID: "group_1", Name: "DeepSeek 默认分组-self-service"}}, + Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "DeepSeek 默认渠道-self-service"}}, Accounts: []sub2api.NamedResource{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}}, } @@ -124,8 +124,8 @@ func TestReconcileServiceReturnsDriftedWhenManagedResourceMissing(t *testing.T) seedRuntimeImportForReconcile(t, store, host) host.managedSnapshot = sub2api.ManagedResourceSnapshot{ - Groups: []sub2api.NamedResource{{ID: "group_1", Name: "DeepSeek 默认分组"}}, - Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "DeepSeek 默认渠道"}}, + Groups: []sub2api.NamedResource{{ID: "group_1", Name: "DeepSeek 默认分组-self-service"}}, + Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "DeepSeek 默认渠道-self-service"}}, Accounts: []sub2api.NamedResource{{ID: "account_1", Name: "deepseek-01"}}, } @@ -166,8 +166,8 @@ func TestReconcileServiceIgnoresSubscriptionPlanForSelfServiceBatch(t *testing.T seedRuntimeImportForReconcile(t, store, host) host.managedSnapshot = sub2api.ManagedResourceSnapshot{ - Groups: []sub2api.NamedResource{{ID: "group_1", Name: "DeepSeek 默认分组"}}, - Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "DeepSeek 默认渠道"}}, + Groups: []sub2api.NamedResource{{ID: "group_1", Name: "DeepSeek 默认分组-self-service"}}, + Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "DeepSeek 默认渠道-self-service"}}, Plans: []sub2api.NamedResource{{ID: "plan_1", Name: "DeepSeek 默认套餐"}}, Accounts: []sub2api.NamedResource{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}}, } @@ -212,8 +212,8 @@ func TestReconcileServicePassesAccountNamePrefixToManagedResourceSnapshot(t *tes seedRuntimeImportForReconcile(t, store, host) host.managedSnapshot = sub2api.ManagedResourceSnapshot{ - Groups: []sub2api.NamedResource{{ID: "group_1", Name: "DeepSeek 默认分组"}}, - Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "DeepSeek 默认渠道"}}, + Groups: []sub2api.NamedResource{{ID: "group_1", Name: "DeepSeek 默认分组-self-service"}}, + Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "DeepSeek 默认渠道-self-service"}}, Accounts: []sub2api.NamedResource{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}}, } diff --git a/packs/openai-cn-pack/checksums.txt b/packs/openai-cn-pack/checksums.txt index 76404a6b..a1a156dd 100644 --- a/packs/openai-cn-pack/checksums.txt +++ b/packs/openai-cn-pack/checksums.txt @@ -1,4 +1,4 @@ 3e3326e40d51a3753adc6fde0aa8859dc5d2076726a692aae45e36f7b27c89d6 pack.json -46da7cd7521c7b808e51dcbb190b14cbb84c6d864557d56de75acad4d2e0fa85 providers/deepseek.json +5d33003fb6fefaf2dbb8445be7c18cc817d266215bc996ae91b30fbed5a98e7b providers/deepseek.json 5dcc402daddacce6dcaceb1501020342f1b1121fbffe9097ede4d5aae072f84e providers/minimax.json fa486a449407f38de8b180ff301568deccef5177ca0436158b1d5b0e6d9328b2 providers/openai-zhongzhuan.json diff --git a/packs/openai-cn-pack/providers/deepseek.json b/packs/openai-cn-pack/providers/deepseek.json index b7b7d65b..3579a770 100644 --- a/packs/openai-cn-pack/providers/deepseek.json +++ b/packs/openai-cn-pack/providers/deepseek.json @@ -13,8 +13,8 @@ "channel_template": { "name": "DeepSeek 默认渠道", "model_mapping": { - "deepseek-chat": "deepseek-chat", - "deepseek-reasoner": "deepseek-reasoner" + "deepseek-v4-pro": "deepseek-v4-pro", + "deepseek-v4-flash": "deepseek-v4-flash" } }, "plan_template": { diff --git a/scripts/import_remote43_provider.sh b/scripts/import_remote43_provider.sh index 62c1a7db..1b2435ba 100755 --- a/scripts/import_remote43_provider.sh +++ b/scripts/import_remote43_provider.sh @@ -15,6 +15,10 @@ REMOTE="${REMOTE:-ubuntu@43.155.133.187}" CRM_BASE="${CRM_BASE:-http://127.0.0.1:18088}" HOST_BASE="${HOST_BASE:-http://127.0.0.1:18087}" CRM_HOST_BASE="${CRM_HOST_BASE:-$HOST_BASE}" +HOST_NAME="${HOST_NAME:-remote43-current-host}" +REMOTE_HOST_ENV_FILE="${REMOTE_HOST_ENV_FILE:-/home/ubuntu/sub2api-host-validation-fresh-deepseek-20260519_115244/.env}" +REMOTE_PG_CONTAINER="${REMOTE_PG_CONTAINER:-sub2api-relaymgr-pg}" +REMOTE_REDIS_CONTAINER="${REMOTE_REDIS_CONTAINER:-sub2api-relaymgr-redis}" PACK_PATH="${PACK_PATH:-/home/ubuntu/sub2api-cn-relay-manager/packs/openai-cn-pack}" ROOT="${ROOT:-$ROOT_DIR/artifacts/real-host-acceptance}" ART="${ART:-$ROOT/$(date +%Y%m%d_%H%M%S)_remote43_${provider_id}_key_import}" @@ -22,6 +26,8 @@ MIN_BALANCE="${MIN_BALANCE:-10}" SUBSCRIPTION_DAYS="${SUBSCRIPTION_DAYS:-30}" SUBSCRIPTION_NOTES="${SUBSCRIPTION_NOTES:-hermes remote subscription validation}" mkdir -p "$ART" +REMOTE_PG_CONTAINER_Q="$(printf '%q' "$REMOTE_PG_CONTAINER")" +REMOTE_REDIS_CONTAINER_Q="$(printf '%q' "$REMOTE_REDIS_CONTAINER")" if [[ -n "$key_file" ]]; then upstream_key="$(tr -d '\r\n' < "$key_file")" @@ -40,18 +46,67 @@ ssh_cmd() { ssh -i "$KEY" -o StrictHostKeyChecking=no "$REMOTE" "$cmd" } +crm_curl_json() { + local method="$1" + local path="$2" + local payload="${3:-}" + if [[ -n "$payload" ]]; then + curl -fsS -X "$method" \ + -H "Authorization: Bearer $crm_token" \ + -H 'Content-Type: application/json' \ + "${CRM_BASE}${path}" \ + -d "$payload" + else + curl -fsS -X "$method" \ + -H "Authorization: Bearer $crm_token" \ + "${CRM_BASE}${path}" + fi +} + +fetch_remote_host_bearer_token() { + ssh_cmd "python3 - <<'PY' +from pathlib import Path +import json, subprocess, sys + +env_path = Path(${REMOTE_HOST_ENV_FILE@Q}) +host_base = ${HOST_BASE@Q} +vals = {} +for line in env_path.read_text().splitlines(): + if '=' not in line: + continue + key, value = line.split('=', 1) + vals[key] = value + +payload = json.dumps({ + 'email': vals['ADMIN_EMAIL'], + 'password': vals['ADMIN_PASSWORD'], + 'turnstile_token': '', +}, ensure_ascii=False) +res = subprocess.run([ + 'curl', '-fsS', '-H', 'Content-Type: application/json', '-X', 'POST', + host_base.rstrip('/') + '/api/v1/auth/login', '-d', payload, +], text=True, capture_output=True) +obj = json.loads(res.stdout) +token = (obj.get('data') or {}).get('access_token', '') +if not token: + print(res.stdout, file=sys.stderr) + raise SystemExit('missing access_token from remote host login') +print(token) +PY" +} + remote_pg_exec() { local sql="$1" local encoded encoded="$(printf '%s' "$sql" | base64 -w0)" - ssh_cmd "printf '%s' '$encoded' | base64 -d | sudo -n docker exec -i sub2api-relaymgr-pg psql -U sub2api -d sub2api" + ssh_cmd "printf '%s' '$encoded' | base64 -d | sudo -n docker exec -i $REMOTE_PG_CONTAINER_Q psql -U sub2api -d sub2api" } remote_pg_query() { local sql="$1" local encoded encoded="$(printf '%s' "$sql" | base64 -w0)" - ssh_cmd "printf '%s' '$encoded' | base64 -d | sudo -n docker exec -i sub2api-relaymgr-pg psql -U sub2api -d sub2api -At -F $'\t'" + ssh_cmd "printf '%s' '$encoded' | base64 -d | sudo -n docker exec -i $REMOTE_PG_CONTAINER_Q psql -U sub2api -d sub2api -At -F $'\t'" } remote_fetch_group_state() { @@ -59,11 +114,12 @@ remote_fetch_group_state() { local user_id="$2" local api_key="$3" local output_path="$4" - local encoded - encoded="$(python3 - "$group_id" "$user_id" "$api_key" <<'PY' -import json, sys + local sql + sql="$(python3 - "$group_id" "$user_id" "$api_key" <<'PY' +import sys group_id, user_id, api_key = sys.argv[1:4] +api_key_literal = "'" + api_key.replace("'", "''") + "'" query = f""" WITH group_row AS ( SELECT row_to_json(g) AS data FROM groups g WHERE g.id = {group_id} @@ -74,7 +130,7 @@ subscription_row AS ( ORDER BY s.id DESC LIMIT 1 ), key_row AS ( - SELECT row_to_json(k) AS data FROM api_keys k WHERE k.key = {json.dumps(api_key)} + SELECT row_to_json(k) AS data FROM api_keys k WHERE k.key = {api_key_literal} ) SELECT json_build_object( 'group_id', {group_id}, @@ -86,7 +142,7 @@ SELECT json_build_object( print(query) PY )" - ssh_cmd "printf '%s' '$encoded' | base64 -d | sudo -n docker exec -i sub2api-relaymgr-pg psql -U sub2api -d sub2api -At -F ''" > "$output_path" + remote_pg_query "$sql" > "$output_path" } python3 - "$ART/00-local-key-source.json" "$key_source" "$provider_id" "$upstream_key" <<'PY' @@ -100,11 +156,17 @@ pathlib.Path(path).write_text(json.dumps({ }, ensure_ascii=False, indent=2), encoding='utf-8') PY -crm_token="$(ssh_cmd "grep ^SUB2API_CRM_ADMIN_TOKEN= /home/ubuntu/sub2api-cn-relay-manager/.env.remote | cut -d= -f2-")" -crm_token="${crm_token##*$'\n'}" -admin_key="$(ssh_cmd "sudo -n docker exec sub2api-relaymgr-pg psql -U sub2api -d sub2api -Atc \"select value from settings where key='admin_api_key';\"")" -admin_key="${admin_key##*$'\n'}" -admin_uid="$(ssh_cmd "sudo -n docker exec sub2api-relaymgr-pg psql -U sub2api -d sub2api -Atc \"select id from users where role='admin' order by id asc limit 1;\"")" +crm_token="${CRM_ADMIN_TOKEN:-}" +if [[ -z "$crm_token" ]]; then + crm_token="$(ssh_cmd "grep ^SUB2API_CRM_ADMIN_TOKEN= /home/ubuntu/sub2api-cn-relay-manager/.env.remote | cut -d= -f2-")" + crm_token="${crm_token##*$'\n'}" +fi +host_bearer_token="${HOST_BEARER_TOKEN:-}" +if [[ -z "$host_bearer_token" ]]; then + host_bearer_token="$(fetch_remote_host_bearer_token)" + host_bearer_token="${host_bearer_token##*$'\n'}" +fi +admin_uid="$(ssh_cmd "sudo -n docker exec $REMOTE_PG_CONTAINER_Q psql -U sub2api -d sub2api -Atc \"select id from users where role='admin' order by id asc limit 1;\"")" admin_uid="${admin_uid##*$'\n'}" sub_uid="$(remote_pg_query "select id from users where email like 'relay-sub-%@sub2api.local' and not exists (select 1 from user_subscriptions s where s.user_id=users.id and s.deleted_at is null) order by id desc limit 1;")" sub_uid="${sub_uid##*$'\n'}" @@ -208,12 +270,36 @@ pathlib.Path(path).write_text(json.dumps({ }, ensure_ascii=False, indent=2), encoding='utf-8') PY -payload="$(python3 - "$CRM_HOST_BASE" "$admin_key" "$PACK_PATH" "$provider_id" "$upstream_key" "$sub_key" "$sub_uid" "$SUBSCRIPTION_DAYS" <<'PY' +create_host_payload="$(python3 - "$HOST_NAME" "$CRM_HOST_BASE" "$host_bearer_token" <<'PY' import json, sys -host_base, admin_key, pack_path, provider_id, upstream_key, sub_key, sub_uid, subscription_days = sys.argv[1:9] +name, base_url, bearer_token = sys.argv[1:4] +print(json.dumps({ + 'name': name, + 'base_url': base_url, + 'auth': {'type': 'bearer', 'token': bearer_token}, +}, ensure_ascii=False)) +PY +)" +hosts_payload="$(crm_curl_json GET "/api/hosts")" +existing_host_json="$(printf '%s' "$hosts_payload" | python3 -c 'import json, sys +base_url = sys.argv[1] +payload = json.load(sys.stdin) +for host in payload.get("hosts", []): + if host.get("base_url") == base_url: + print(json.dumps(host, ensure_ascii=False)) + break' "$CRM_HOST_BASE")" +if [[ -n "$existing_host_json" ]]; then + printf '%s\n' "$existing_host_json" > "$ART/01a-create-host.json" +else + crm_curl_json POST "/api/hosts" "$create_host_payload" > "$ART/01a-create-host.json" +fi + +payload="$(python3 - "$CRM_HOST_BASE" "$host_bearer_token" "$PACK_PATH" "$provider_id" "$upstream_key" "$sub_key" "$sub_uid" "$SUBSCRIPTION_DAYS" <<'PY' +import json, sys +host_base, host_bearer_token, pack_path, provider_id, upstream_key, sub_key, sub_uid, subscription_days = sys.argv[1:9] print(json.dumps({ 'host_base_url': host_base, - 'host_api_key': admin_key, + 'host_bearer_token': host_bearer_token, 'pack_path': pack_path, 'provider_id': provider_id, 'keys': [upstream_key], @@ -226,9 +312,11 @@ print(json.dumps({ PY )" -ssh_cmd "curl -sS -D /tmp/import_headers.txt -o /tmp/import_body.json -X POST -H 'Authorization: Bearer $crm_token' -H 'Content-Type: application/json' $CRM_BASE/api/providers/$provider_id/import -d $(printf %q "$payload")" -ssh_cmd "cat /tmp/import_headers.txt" > "$ART/02-import.headers.txt" -ssh_cmd "cat /tmp/import_body.json" > "$ART/03-import.body.json" +curl -sS -D "$ART/02-import.headers.txt" -o "$ART/03-import.body.json" -X POST \ + -H "Authorization: Bearer $crm_token" \ + -H 'Content-Type: application/json' \ + "$CRM_BASE/api/providers/$provider_id/import" \ + -d "$payload" batch_id="$(python3 - "$ART/03-import.body.json" <<'PY' import json, sys, pathlib @@ -237,7 +325,7 @@ print(obj['batch_id']) PY )" -ssh_cmd "curl -sS -H 'Authorization: Bearer $crm_token' $CRM_BASE/api/import-batches/$batch_id" > "$ART/04-batch-detail-initial.json" +crm_curl_json GET "/api/import-batches/$batch_id" > "$ART/04-batch-detail-initial.json" subscription_group_id="$(python3 - "$ART/03-import.body.json" "$ART/04-batch-detail-initial.json" <<'PY' import json, pathlib, sys import_obj = json.loads(pathlib.Path(sys.argv[1]).read_text()) @@ -270,7 +358,7 @@ remote_pg_exec "$prep_sql" > "$ART/06-subscription-access-prep.psql.txt" printf 'auth_cache_key=%s\n' "$auth_cache_key" printf 'balance_cache_key=%s\n' "$balance_cache_key" printf 'subscription_cache_key=%s\n' "$subscription_cache_key" - ssh_cmd "sudo -n docker exec sub2api-relaymgr-redis redis-cli DEL $auth_cache_key $balance_cache_key $subscription_cache_key" + ssh_cmd "sudo -n docker exec $REMOTE_REDIS_CONTAINER_Q redis-cli DEL $auth_cache_key $balance_cache_key $subscription_cache_key" } > "$ART/07-redis-targeted-invalidation.txt" remote_fetch_group_state "$subscription_group_id" "$sub_uid" "$sub_key" "$ART/08-subscription-group-state.json" @@ -299,26 +387,27 @@ print(json.dumps({ }, ensure_ascii=False)) PY )" -ssh_cmd "curl -sS -D /tmp/models_headers.txt -o /tmp/models_body.json -H 'Authorization: Bearer *** $HOST_BASE/v1/models" +ssh_cmd "curl -sS -D /tmp/models_headers.txt -o /tmp/models_body.json -H 'Authorization: Bearer $sub_key' $HOST_BASE/v1/models" ssh_cmd "cat /tmp/models_headers.txt" > "$ART/09-models.headers.txt" ssh_cmd "cat /tmp/models_body.json" > "$ART/10-models.body.json" -ssh_cmd "curl -sS -D /tmp/chat_headers.txt -o /tmp/chat_body.json -H 'Authorization: Bearer *** -H 'Content-Type: application/json' $HOST_BASE/v1/chat/completions -d $(printf %q "$probe_payload")" +ssh_cmd "curl -sS -D /tmp/chat_headers.txt -o /tmp/chat_body.json -H 'Authorization: Bearer $sub_key' -H 'Content-Type: application/json' $HOST_BASE/v1/chat/completions -d $(printf %q "$probe_payload")" ssh_cmd "cat /tmp/chat_headers.txt" > "$ART/11-chat.headers.txt" ssh_cmd "cat /tmp/chat_body.json" > "$ART/12-chat.body.json" -ssh_cmd "curl -sS -H 'Authorization: Bearer *** $CRM_BASE/api/providers/$provider_id/status" > "$ART/13-provider-status.json" -ssh_cmd "curl -sS -H 'Authorization: Bearer *** $CRM_BASE/api/providers/$provider_id/access/status" > "$ART/14-access-status.json" +crm_curl_json GET "/api/providers/$provider_id/status" > "$ART/13-provider-status.json" +crm_curl_json GET "/api/providers/$provider_id/access/status" > "$ART/14-access-status.json" preview_payload="$(python3 - "$provider_id" <<'PY' import json, sys print(json.dumps({'provider_id': sys.argv[1], 'mode': 'subscription'}, ensure_ascii=False)) PY )" -ssh_cmd "curl -sS -X POST -H 'Authorization: Bearer *** -H 'Content-Type: application/json' $CRM_BASE/api/providers/$provider_id/access/preview -d $(printf %q "$preview_payload")" > "$ART/15-access-preview.json" -ssh_cmd "curl -sS -H 'Authorization: Bearer *** $CRM_BASE/api/import-batches/$batch_id" > "$ART/16-batch-detail-final.json" +crm_curl_json POST "/api/providers/$provider_id/access/preview" "$preview_payload" > "$ART/15-access-preview.json" +crm_curl_json GET "/api/import-batches/$batch_id" > "$ART/16-batch-detail-final.json" python3 - "$ART" "$provider_id" "$batch_id" "$subscription_group_id" "$model_name" <<'PY' import json, pathlib, sys + art=pathlib.Path(sys.argv[1]) provider_id=sys.argv[2] batch_id=int(sys.argv[3]) diff --git a/scripts/real_host_acceptance.sh b/scripts/real_host_acceptance.sh index 89629963..798bd6d9 100755 --- a/scripts/real_host_acceptance.sh +++ b/scripts/real_host_acceptance.sh @@ -150,10 +150,13 @@ PY )" if RESP_EXISTING_HOST="$(curl_json GET "/api/hosts/$HOST_NAME" 2>/dev/null)"; then - RESP_CREATE_HOST="$RESP_EXISTING_HOST" -else - RESP_CREATE_HOST="$(curl_json POST /api/hosts "$CREATE_HOST_PAYLOAD")" + EXISTING_BASE_URL="$(printf '%s' "$RESP_EXISTING_HOST" | json_get base_url || true)" + if [[ -n "$EXISTING_BASE_URL" && "$EXISTING_BASE_URL" != "$HOST_BASE_URL" ]]; then + echo "existing host $HOST_NAME points to $EXISTING_BASE_URL, expected $HOST_BASE_URL" >&2 + exit 1 + fi fi +RESP_CREATE_HOST="$(curl_json POST /api/hosts "$CREATE_HOST_PAYLOAD")" save_json 01-create-host "$RESP_CREATE_HOST" HOST_ID="$(printf '%s' "$RESP_CREATE_HOST" | json_get host_id || true)" HOST_ID="${HOST_ID:-$HOST_NAME}" diff --git a/scripts/test_real_host_scripts.sh b/scripts/test_real_host_scripts.sh index 8995a629..6ca4c6ac 100644 --- a/scripts/test_real_host_scripts.sh +++ b/scripts/test_real_host_scripts.sh @@ -59,6 +59,10 @@ run_test_real_host_acceptance_after_import_hook() { set -euo pipefail url="" for arg in "$@"; do + if [[ "$arg" == *'***'* ]]; then + echo "unexpected redacted auth placeholder in curl args: $*" >&2 + exit 1 + fi if [[ "$arg" == http://* || "$arg" == https://* ]]; then url="$arg" fi @@ -68,6 +72,9 @@ done exit 1 } case "$url" in + */api/hosts) + printf '%s\n' '{"host_id":"test-host"}' + ;; */api/hosts/test-host) printf '%s\n' '{"host_id":"test-host"}' ;; @@ -118,6 +125,8 @@ EOF PACK_PATH="/tmp/openai-pack" \ PROVIDER_ID="deepseek" \ HOST_API_KEY="host-key" \ + REMOTE_PG_CONTAINER="fresh-pg" \ + REMOTE_REDIS_CONTAINER="fresh-redis" \ MODE="partial" \ ACCESS_MODE="subscription" \ ACCESS_API_KEY="user-key" \ @@ -143,13 +152,96 @@ run_test_import_remote43_provider_subscription_prep() { psql_sql="$artifact_dir/prep.sql" mkdir -p "$fakebin" + cat > "$fakebin/curl" <<'EOF' +#!/usr/bin/env bash +set -euo pipefail +headers_file="" +body_file="" +url="" +prev="" +for arg in "$@"; do + if [[ "$arg" == *'***'* ]]; then + echo "unexpected redacted auth placeholder in curl args: $*" >&2 + exit 1 + fi + case "$prev" in + -D) + headers_file="$arg" + prev="" + continue + ;; + -o) + body_file="$arg" + prev="" + continue + ;; + esac + case "$arg" in + -D|-o) + prev="$arg" + continue + ;; + http://*|https://*) + url="$arg" + ;; + esac +done + +write_headers() { + [[ -n "$headers_file" ]] && printf '%s\n' 'HTTP/1.1 200 OK' > "$headers_file" +} + +write_body() { + local body="$1" + if [[ -n "$body_file" ]]; then + printf '%s\n' "$body" > "$body_file" + else + printf '%s\n' "$body" + fi +} + +case "$url" in + */api/hosts) + write_body '{"host_id":"remote43-current-host"}' + ;; + */api/providers/deepseek/import) + write_headers + write_body '{"batch_id":123,"batch_status":"partially_succeeded","access_status":"broken","provider_status":"ready","accepted_keys_count":1,"group":{"id":"7","name":"DeepSeek 默认分组"}}' + ;; + */api/import-batches/123) + write_body '{"managed_resources":[{"ResourceType":"group","HostResourceID":"7","ResourceName":"DeepSeek 默认分组"}]}' + ;; + */api/providers/deepseek/status) + write_body '{"status":"ready"}' + ;; + */api/providers/deepseek/access/status) + write_body '{"latest_access_status":"subscription_ready"}' + ;; + */api/providers/deepseek/access/preview) + write_body '{"available":true}' + ;; + *) + echo "unexpected curl url: $url" >&2 + exit 1 + ;; +esac +EOF + chmod +x "$fakebin/curl" + cat > "$fakebin/ssh" <<'EOF' #!/usr/bin/env bash set -euo pipefail log_dir="${FAKE_REMOTE_LOG_DIR:?missing FAKE_REMOTE_LOG_DIR}" cmd="${*: -1}" printf '%s\n' "$cmd" >> "$log_dir/ssh-log.txt" +if [[ "$cmd" == *'***'* ]]; then + echo "unexpected redacted auth placeholder in ssh command: $cmd" >&2 + exit 1 +fi case "$cmd" in + *"/api/v1/auth/login"*) + printf '%s\n' 'host-bearer-token' + ;; *"grep ^SUB2API_CRM_ADMIN_TOKEN="*) printf '%s\n' 'crm-token' ;; @@ -210,10 +302,10 @@ case "$cmd" in *"/api/providers/deepseek/reconcile"*) printf '%s\n' '{"status":"in_sync"}' ;; - *"sudo -n docker exec -i sub2api-relaymgr-pg psql -U sub2api -d sub2api -At -F ''"*) + *"sudo -n docker exec -i fresh-pg psql -U sub2api -d sub2api -At -F ''"*) printf '%s\n' '{"group_id":7,"subscription":{"status":"active"},"key":{"group_id":7}}' ;; - *"sudo -n docker exec -i sub2api-relaymgr-pg psql -U sub2api -d sub2api"*) + *"sudo -n docker exec -i fresh-pg psql -U sub2api -d sub2api"*) CMD="$cmd" LOG_DIR="$log_dir" python3 - <<'PY' import base64, os, re, pathlib, sys cmd = os.environ['CMD'] @@ -222,18 +314,24 @@ match = re.search(r"printf '%s' '([^']+)' \| base64 -d", cmd) if not match: raise SystemExit(f'failed to extract base64 payload from: {cmd}') sql = base64.b64decode(match.group(1)).decode() -log_dir.joinpath('prep.sql').write_text(sql, encoding='utf-8') if "select id from users where email like 'relay-sub-%@sub2api.local' and not exists" in sql: print('') elif "select k.key from users u join api_keys k on k.user_id=u.id" in sql and "not exists" in sql: print('') +elif "UPDATE users" in sql and "INSERT INTO user_subscriptions" in sql: + log_dir.joinpath('prep.sql').write_text(sql, encoding='utf-8') + print('') elif "INSERT INTO users" in sql and "INSERT INTO api_keys" in sql: + log_dir.joinpath('create-user.sql').write_text(sql, encoding='utf-8') print('84\tuser-key-fresh') +elif "SELECT json_build_object(" in sql: + log_dir.joinpath('group-state.sql').write_text(sql, encoding='utf-8') + print('{"group_id":7,"subscription":{"status":"active"},"key":{"group_id":7}}') else: print('') PY ;; - *"sudo -n docker exec sub2api-relaymgr-redis redis-cli DEL apikey:auth:"*" billing:balance:"*" billing:sub:"*":7"*) + *"sudo -n docker exec fresh-redis redis-cli DEL apikey:auth:"*" billing:balance:"*" billing:sub:"*":7"*) printf '%s\n' '3' ;; *) @@ -254,6 +352,8 @@ EOF ROOT="$artifact_dir/root" \ ART="$artifact_dir/run" \ PACK_PATH="/tmp/openai-pack" \ + REMOTE_PG_CONTAINER="fresh-pg" \ + REMOTE_REDIS_CONTAINER="fresh-redis" \ UPSTREAM_KEY="upstream-test-key" \ SUBSCRIPTION_DAYS=30 \ MIN_BALANCE=10 \ @@ -274,7 +374,10 @@ EOF assert_contains "$invalidation_log" "auth_cache_key=apikey:auth:" assert_contains "$invalidation_log" "balance_cache_key=billing:balance:84" assert_contains "$invalidation_log" "subscription_cache_key=billing:sub:84:7" - local models_body chat_body + local subscription_state models_body chat_body + subscription_state="$(cat "$artifact_dir/run/08-subscription-group-state.json")" + assert_contains "$subscription_state" '"group_id":7' + assert_contains "$subscription_state" '"status":"active"' models_body="$(cat "$artifact_dir/run/10-models.body.json")" chat_body="$(cat "$artifact_dir/run/12-chat.body.json")" assert_contains "$models_body" '"id":"gpt-4"'