chore: initial public snapshot for github upload

This commit is contained in:
Your Name
2026-03-26 20:06:14 +08:00
commit 0e5ecd930e
3497 changed files with 1586236 additions and 0 deletions

View File

@@ -0,0 +1,606 @@
package admin
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
"time"
"log/slog"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
const (
dataType = "sub2api-data"
legacyDataType = "sub2api-bundle"
dataVersion = 1
dataPageCap = 1000
)
type DataPayload struct {
Type string `json:"type,omitempty"`
Version int `json:"version,omitempty"`
ExportedAt string `json:"exported_at"`
Proxies []DataProxy `json:"proxies"`
Accounts []DataAccount `json:"accounts"`
}
type DataProxy struct {
ProxyKey string `json:"proxy_key"`
Name string `json:"name"`
Protocol string `json:"protocol"`
Host string `json:"host"`
Port int `json:"port"`
Username string `json:"username,omitempty"`
Password string `json:"password,omitempty"`
Status string `json:"status"`
}
type DataAccount struct {
Name string `json:"name"`
Notes *string `json:"notes,omitempty"`
Platform string `json:"platform"`
Type string `json:"type"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra,omitempty"`
ProxyKey *string `json:"proxy_key,omitempty"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
RateMultiplier *float64 `json:"rate_multiplier,omitempty"`
ExpiresAt *int64 `json:"expires_at,omitempty"`
AutoPauseOnExpired *bool `json:"auto_pause_on_expired,omitempty"`
}
type DataImportRequest struct {
Data DataPayload `json:"data"`
SkipDefaultGroupBind *bool `json:"skip_default_group_bind"`
}
type DataImportResult struct {
ProxyCreated int `json:"proxy_created"`
ProxyReused int `json:"proxy_reused"`
ProxyFailed int `json:"proxy_failed"`
AccountCreated int `json:"account_created"`
AccountFailed int `json:"account_failed"`
Errors []DataImportError `json:"errors,omitempty"`
}
type DataImportError struct {
Kind string `json:"kind"`
Name string `json:"name,omitempty"`
ProxyKey string `json:"proxy_key,omitempty"`
Message string `json:"message"`
}
func buildProxyKey(protocol, host string, port int, username, password string) string {
return fmt.Sprintf("%s|%s|%d|%s|%s", strings.TrimSpace(protocol), strings.TrimSpace(host), port, strings.TrimSpace(username), strings.TrimSpace(password))
}
func (h *AccountHandler) ExportData(c *gin.Context) {
ctx := c.Request.Context()
selectedIDs, err := parseAccountIDs(c)
if err != nil {
response.BadRequest(c, err.Error())
return
}
accounts, err := h.resolveExportAccounts(ctx, selectedIDs, c)
if err != nil {
response.ErrorFrom(c, err)
return
}
includeProxies, err := parseIncludeProxies(c)
if err != nil {
response.BadRequest(c, err.Error())
return
}
var proxies []service.Proxy
if includeProxies {
proxies, err = h.resolveExportProxies(ctx, accounts)
if err != nil {
response.ErrorFrom(c, err)
return
}
} else {
proxies = []service.Proxy{}
}
proxyKeyByID := make(map[int64]string, len(proxies))
dataProxies := make([]DataProxy, 0, len(proxies))
for i := range proxies {
p := proxies[i]
key := buildProxyKey(p.Protocol, p.Host, p.Port, p.Username, p.Password)
proxyKeyByID[p.ID] = key
dataProxies = append(dataProxies, DataProxy{
ProxyKey: key,
Name: p.Name,
Protocol: p.Protocol,
Host: p.Host,
Port: p.Port,
Username: p.Username,
Password: p.Password,
Status: p.Status,
})
}
dataAccounts := make([]DataAccount, 0, len(accounts))
for i := range accounts {
acc := accounts[i]
var proxyKey *string
if acc.ProxyID != nil {
if key, ok := proxyKeyByID[*acc.ProxyID]; ok {
proxyKey = &key
}
}
var expiresAt *int64
if acc.ExpiresAt != nil {
v := acc.ExpiresAt.Unix()
expiresAt = &v
}
dataAccounts = append(dataAccounts, DataAccount{
Name: acc.Name,
Notes: acc.Notes,
Platform: acc.Platform,
Type: acc.Type,
Credentials: acc.Credentials,
Extra: acc.Extra,
ProxyKey: proxyKey,
Concurrency: acc.Concurrency,
Priority: acc.Priority,
RateMultiplier: acc.RateMultiplier,
ExpiresAt: expiresAt,
AutoPauseOnExpired: &acc.AutoPauseOnExpired,
})
}
payload := DataPayload{
ExportedAt: time.Now().UTC().Format(time.RFC3339),
Proxies: dataProxies,
Accounts: dataAccounts,
}
response.Success(c, payload)
}
func (h *AccountHandler) ImportData(c *gin.Context) {
var req DataImportRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := validateDataHeader(req.Data); err != nil {
response.BadRequest(c, err.Error())
return
}
executeAdminIdempotentJSON(c, "admin.accounts.import_data", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
return h.importData(ctx, req)
})
}
func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest) (DataImportResult, error) {
skipDefaultGroupBind := true
if req.SkipDefaultGroupBind != nil {
skipDefaultGroupBind = *req.SkipDefaultGroupBind
}
dataPayload := req.Data
result := DataImportResult{}
existingProxies, err := h.listAllProxies(ctx)
if err != nil {
return result, err
}
proxyKeyToID := make(map[string]int64, len(existingProxies))
for i := range existingProxies {
p := existingProxies[i]
key := buildProxyKey(p.Protocol, p.Host, p.Port, p.Username, p.Password)
proxyKeyToID[key] = p.ID
}
for i := range dataPayload.Proxies {
item := dataPayload.Proxies[i]
key := item.ProxyKey
if key == "" {
key = buildProxyKey(item.Protocol, item.Host, item.Port, item.Username, item.Password)
}
if err := validateDataProxy(item); err != nil {
result.ProxyFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "proxy",
Name: item.Name,
ProxyKey: key,
Message: err.Error(),
})
continue
}
normalizedStatus := normalizeProxyStatus(item.Status)
if existingID, ok := proxyKeyToID[key]; ok {
proxyKeyToID[key] = existingID
result.ProxyReused++
if normalizedStatus != "" {
if proxy, getErr := h.adminService.GetProxy(ctx, existingID); getErr == nil && proxy != nil && proxy.Status != normalizedStatus {
_, _ = h.adminService.UpdateProxy(ctx, existingID, &service.UpdateProxyInput{
Status: normalizedStatus,
})
}
}
continue
}
created, createErr := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{
Name: defaultProxyName(item.Name),
Protocol: item.Protocol,
Host: item.Host,
Port: item.Port,
Username: item.Username,
Password: item.Password,
})
if createErr != nil {
result.ProxyFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "proxy",
Name: item.Name,
ProxyKey: key,
Message: createErr.Error(),
})
continue
}
proxyKeyToID[key] = created.ID
result.ProxyCreated++
if normalizedStatus != "" && normalizedStatus != created.Status {
_, _ = h.adminService.UpdateProxy(ctx, created.ID, &service.UpdateProxyInput{
Status: normalizedStatus,
})
}
}
for i := range dataPayload.Accounts {
item := dataPayload.Accounts[i]
if err := validateDataAccount(item); err != nil {
result.AccountFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "account",
Name: item.Name,
Message: err.Error(),
})
continue
}
var proxyID *int64
if item.ProxyKey != nil && *item.ProxyKey != "" {
if id, ok := proxyKeyToID[*item.ProxyKey]; ok {
proxyID = &id
} else {
result.AccountFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "account",
Name: item.Name,
ProxyKey: *item.ProxyKey,
Message: "proxy_key not found",
})
continue
}
}
enrichCredentialsFromIDToken(&item)
accountInput := &service.CreateAccountInput{
Name: item.Name,
Notes: item.Notes,
Platform: item.Platform,
Type: item.Type,
Credentials: item.Credentials,
Extra: item.Extra,
ProxyID: proxyID,
Concurrency: item.Concurrency,
Priority: item.Priority,
RateMultiplier: item.RateMultiplier,
GroupIDs: nil,
ExpiresAt: item.ExpiresAt,
AutoPauseOnExpired: item.AutoPauseOnExpired,
SkipDefaultGroupBind: skipDefaultGroupBind,
}
if _, err := h.adminService.CreateAccount(ctx, accountInput); err != nil {
result.AccountFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "account",
Name: item.Name,
Message: err.Error(),
})
continue
}
result.AccountCreated++
}
return result, nil
}
func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, error) {
page := 1
pageSize := dataPageCap
var out []service.Proxy
for {
items, total, err := h.adminService.ListProxies(ctx, page, pageSize, "", "", "")
if err != nil {
return nil, err
}
out = append(out, items...)
if len(out) >= int(total) || len(items) == 0 {
break
}
page++
}
return out, nil
}
func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, accountType, status, search string) ([]service.Account, error) {
page := 1
pageSize := dataPageCap
var out []service.Account
for {
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0)
if err != nil {
return nil, err
}
out = append(out, items...)
if len(out) >= int(total) || len(items) == 0 {
break
}
page++
}
return out, nil
}
func (h *AccountHandler) resolveExportAccounts(ctx context.Context, ids []int64, c *gin.Context) ([]service.Account, error) {
if len(ids) > 0 {
accounts, err := h.adminService.GetAccountsByIDs(ctx, ids)
if err != nil {
return nil, err
}
out := make([]service.Account, 0, len(accounts))
for _, acc := range accounts {
if acc == nil {
continue
}
out = append(out, *acc)
}
return out, nil
}
platform := c.Query("platform")
accountType := c.Query("type")
status := c.Query("status")
search := strings.TrimSpace(c.Query("search"))
if len(search) > 100 {
search = search[:100]
}
return h.listAccountsFiltered(ctx, platform, accountType, status, search)
}
func (h *AccountHandler) resolveExportProxies(ctx context.Context, accounts []service.Account) ([]service.Proxy, error) {
if len(accounts) == 0 {
return []service.Proxy{}, nil
}
seen := make(map[int64]struct{})
ids := make([]int64, 0)
for i := range accounts {
if accounts[i].ProxyID == nil {
continue
}
id := *accounts[i].ProxyID
if id <= 0 {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
ids = append(ids, id)
}
if len(ids) == 0 {
return []service.Proxy{}, nil
}
return h.adminService.GetProxiesByIDs(ctx, ids)
}
func parseAccountIDs(c *gin.Context) ([]int64, error) {
values := c.QueryArray("ids")
if len(values) == 0 {
raw := strings.TrimSpace(c.Query("ids"))
if raw != "" {
values = []string{raw}
}
}
if len(values) == 0 {
return nil, nil
}
ids := make([]int64, 0, len(values))
for _, item := range values {
for _, part := range strings.Split(item, ",") {
part = strings.TrimSpace(part)
if part == "" {
continue
}
id, err := strconv.ParseInt(part, 10, 64)
if err != nil || id <= 0 {
return nil, fmt.Errorf("invalid account id: %s", part)
}
ids = append(ids, id)
}
}
return ids, nil
}
func parseIncludeProxies(c *gin.Context) (bool, error) {
raw := strings.TrimSpace(strings.ToLower(c.Query("include_proxies")))
if raw == "" {
return true, nil
}
switch raw {
case "1", "true", "yes", "on":
return true, nil
case "0", "false", "no", "off":
return false, nil
default:
return true, fmt.Errorf("invalid include_proxies value: %s", raw)
}
}
func validateDataHeader(payload DataPayload) error {
if payload.Type != "" && payload.Type != dataType && payload.Type != legacyDataType {
return fmt.Errorf("unsupported data type: %s", payload.Type)
}
if payload.Version != 0 && payload.Version != dataVersion {
return fmt.Errorf("unsupported data version: %d", payload.Version)
}
if payload.Proxies == nil {
return errors.New("proxies is required")
}
if payload.Accounts == nil {
return errors.New("accounts is required")
}
return nil
}
func validateDataProxy(item DataProxy) error {
if strings.TrimSpace(item.Protocol) == "" {
return errors.New("proxy protocol is required")
}
if strings.TrimSpace(item.Host) == "" {
return errors.New("proxy host is required")
}
if item.Port <= 0 || item.Port > 65535 {
return errors.New("proxy port is invalid")
}
switch item.Protocol {
case "http", "https", "socks5", "socks5h":
default:
return fmt.Errorf("proxy protocol is invalid: %s", item.Protocol)
}
if item.Status != "" {
normalizedStatus := normalizeProxyStatus(item.Status)
if normalizedStatus != service.StatusActive && normalizedStatus != "inactive" {
return fmt.Errorf("proxy status is invalid: %s", item.Status)
}
}
return nil
}
func validateDataAccount(item DataAccount) error {
if strings.TrimSpace(item.Name) == "" {
return errors.New("account name is required")
}
if strings.TrimSpace(item.Platform) == "" {
return errors.New("account platform is required")
}
if strings.TrimSpace(item.Type) == "" {
return errors.New("account type is required")
}
if len(item.Credentials) == 0 {
return errors.New("account credentials is required")
}
switch item.Type {
case service.AccountTypeOAuth, service.AccountTypeSetupToken, service.AccountTypeAPIKey, service.AccountTypeUpstream:
default:
return fmt.Errorf("account type is invalid: %s", item.Type)
}
if item.RateMultiplier != nil && *item.RateMultiplier < 0 {
return errors.New("rate_multiplier must be >= 0")
}
if item.Concurrency < 0 {
return errors.New("concurrency must be >= 0")
}
if item.Priority < 0 {
return errors.New("priority must be >= 0")
}
return nil
}
func defaultProxyName(name string) string {
if strings.TrimSpace(name) == "" {
return "imported-proxy"
}
return name
}
// enrichCredentialsFromIDToken performs best-effort extraction of user info fields
// (email, plan_type, chatgpt_account_id, etc.) from id_token in credentials.
// Only applies to OpenAI/Sora OAuth accounts. Skips expired token errors silently.
// Existing credential values are never overwritten — only missing fields are filled.
func enrichCredentialsFromIDToken(item *DataAccount) {
if item.Credentials == nil {
return
}
// Only enrich OpenAI/Sora OAuth accounts
platform := strings.ToLower(strings.TrimSpace(item.Platform))
if platform != service.PlatformOpenAI && platform != service.PlatformSora {
return
}
if strings.ToLower(strings.TrimSpace(item.Type)) != service.AccountTypeOAuth {
return
}
idToken, _ := item.Credentials["id_token"].(string)
if strings.TrimSpace(idToken) == "" {
return
}
// DecodeIDToken skips expiry validation — safe for imported data
claims, err := openai.DecodeIDToken(idToken)
if err != nil {
slog.Debug("import_enrich_id_token_decode_failed", "account", item.Name, "error", err)
return
}
userInfo := claims.GetUserInfo()
if userInfo == nil {
return
}
// Fill missing fields only (never overwrite existing values)
setIfMissing := func(key, value string) {
if value == "" {
return
}
if existing, _ := item.Credentials[key].(string); existing == "" {
item.Credentials[key] = value
}
}
setIfMissing("email", userInfo.Email)
setIfMissing("plan_type", userInfo.PlanType)
setIfMissing("chatgpt_account_id", userInfo.ChatGPTAccountID)
setIfMissing("chatgpt_user_id", userInfo.ChatGPTUserID)
setIfMissing("organization_id", userInfo.OrganizationID)
}
func normalizeProxyStatus(status string) string {
normalized := strings.TrimSpace(strings.ToLower(status))
switch normalized {
case "":
return ""
case service.StatusActive:
return service.StatusActive
case "inactive", service.StatusDisabled:
return "inactive"
default:
return normalized
}
}

View File

@@ -0,0 +1,232 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type dataResponse struct {
Code int `json:"code"`
Data dataPayload `json:"data"`
}
type dataPayload struct {
Type string `json:"type"`
Version int `json:"version"`
Proxies []dataProxy `json:"proxies"`
Accounts []dataAccount `json:"accounts"`
}
type dataProxy struct {
ProxyKey string `json:"proxy_key"`
Name string `json:"name"`
Protocol string `json:"protocol"`
Host string `json:"host"`
Port int `json:"port"`
Username string `json:"username"`
Password string `json:"password"`
Status string `json:"status"`
}
type dataAccount struct {
Name string `json:"name"`
Platform string `json:"platform"`
Type string `json:"type"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyKey *string `json:"proxy_key"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
}
func setupAccountDataRouter() (*gin.Engine, *stubAdminService) {
gin.SetMode(gin.TestMode)
router := gin.New()
adminSvc := newStubAdminService()
h := NewAccountHandler(
adminSvc,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
)
router.GET("/api/v1/admin/accounts/data", h.ExportData)
router.POST("/api/v1/admin/accounts/data", h.ImportData)
return router, adminSvc
}
func TestExportDataIncludesSecrets(t *testing.T) {
router, adminSvc := setupAccountDataRouter()
proxyID := int64(11)
adminSvc.proxies = []service.Proxy{
{
ID: proxyID,
Name: "proxy",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Username: "user",
Password: "pass",
Status: service.StatusActive,
},
{
ID: 12,
Name: "orphan",
Protocol: "https",
Host: "10.0.0.1",
Port: 443,
Username: "o",
Password: "p",
Status: service.StatusActive,
},
}
adminSvc.accounts = []service.Account{
{
ID: 21,
Name: "account",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeOAuth,
Credentials: map[string]any{"token": "secret"},
Extra: map[string]any{"note": "x"},
ProxyID: &proxyID,
Concurrency: 3,
Priority: 50,
Status: service.StatusDisabled,
},
}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/data", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp dataResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Empty(t, resp.Data.Type)
require.Equal(t, 0, resp.Data.Version)
require.Len(t, resp.Data.Proxies, 1)
require.Equal(t, "pass", resp.Data.Proxies[0].Password)
require.Len(t, resp.Data.Accounts, 1)
require.Equal(t, "secret", resp.Data.Accounts[0].Credentials["token"])
}
func TestExportDataWithoutProxies(t *testing.T) {
router, adminSvc := setupAccountDataRouter()
proxyID := int64(11)
adminSvc.proxies = []service.Proxy{
{
ID: proxyID,
Name: "proxy",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Username: "user",
Password: "pass",
Status: service.StatusActive,
},
}
adminSvc.accounts = []service.Account{
{
ID: 21,
Name: "account",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeOAuth,
Credentials: map[string]any{"token": "secret"},
ProxyID: &proxyID,
Concurrency: 3,
Priority: 50,
Status: service.StatusDisabled,
},
}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/data?include_proxies=false", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp dataResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Len(t, resp.Data.Proxies, 0)
require.Len(t, resp.Data.Accounts, 1)
require.Nil(t, resp.Data.Accounts[0].ProxyKey)
}
func TestImportDataReusesProxyAndSkipsDefaultGroup(t *testing.T) {
router, adminSvc := setupAccountDataRouter()
adminSvc.proxies = []service.Proxy{
{
ID: 1,
Name: "proxy",
Protocol: "socks5",
Host: "1.2.3.4",
Port: 1080,
Username: "u",
Password: "p",
Status: service.StatusActive,
},
}
dataPayload := map[string]any{
"data": map[string]any{
"type": dataType,
"version": dataVersion,
"proxies": []map[string]any{
{
"proxy_key": "socks5|1.2.3.4|1080|u|p",
"name": "proxy",
"protocol": "socks5",
"host": "1.2.3.4",
"port": 1080,
"username": "u",
"password": "p",
"status": "active",
},
},
"accounts": []map[string]any{
{
"name": "acc",
"platform": service.PlatformOpenAI,
"type": service.AccountTypeOAuth,
"credentials": map[string]any{"token": "x"},
"proxy_key": "socks5|1.2.3.4|1080|u|p",
"concurrency": 3,
"priority": 50,
},
},
},
"skip_default_group_bind": true,
}
body, _ := json.Marshal(dataPayload)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/data", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Len(t, adminSvc.createdProxies, 0)
require.Len(t, adminSvc.createdAccounts, 1)
require.True(t, adminSvc.createdAccounts[0].SkipDefaultGroupBind)
}

View File

@@ -0,0 +1,105 @@
package admin
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type availableModelsAdminService struct {
*stubAdminService
account service.Account
}
func (s *availableModelsAdminService) GetAccount(_ context.Context, id int64) (*service.Account, error) {
if s.account.ID == id {
acc := s.account
return &acc, nil
}
return s.stubAdminService.GetAccount(context.Background(), id)
}
func setupAvailableModelsRouter(adminSvc service.AdminService) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.GET("/api/v1/admin/accounts/:id/models", handler.GetAvailableModels)
return router
}
func TestAccountHandlerGetAvailableModels_OpenAIOAuthUsesExplicitModelMapping(t *testing.T) {
svc := &availableModelsAdminService{
stubAdminService: newStubAdminService(),
account: service.Account{
ID: 42,
Name: "openai-oauth",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5": "gpt-5.1",
},
},
},
}
router := setupAvailableModelsRouter(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/42/models", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Len(t, resp.Data, 1)
require.Equal(t, "gpt-5", resp.Data[0].ID)
}
func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefaults(t *testing.T) {
svc := &availableModelsAdminService{
stubAdminService: newStubAdminService(),
account: service.Account{
ID: 43,
Name: "openai-oauth-passthrough",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5": "gpt-5.1",
},
},
Extra: map[string]any{
"openai_passthrough": true,
},
},
}
router := setupAvailableModelsRouter(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/accounts/43/models", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.NotEmpty(t, resp.Data)
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
}

View File

@@ -0,0 +1,198 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func setupAccountMixedChannelRouter(adminSvc *stubAdminService) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.POST("/api/v1/admin/accounts/check-mixed-channel", accountHandler.CheckMixedChannel)
router.POST("/api/v1/admin/accounts", accountHandler.Create)
router.PUT("/api/v1/admin/accounts/:id", accountHandler.Update)
router.POST("/api/v1/admin/accounts/bulk-update", accountHandler.BulkUpdate)
return router
}
func TestAccountHandlerCheckMixedChannelNoRisk(t *testing.T) {
adminSvc := newStubAdminService()
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"platform": "antigravity",
"group_ids": []int64{27},
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/check-mixed-channel", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, float64(0), resp["code"])
data, ok := resp["data"].(map[string]any)
require.True(t, ok)
require.Equal(t, false, data["has_risk"])
require.Equal(t, int64(0), adminSvc.lastMixedCheck.accountID)
require.Equal(t, "antigravity", adminSvc.lastMixedCheck.platform)
require.Equal(t, []int64{27}, adminSvc.lastMixedCheck.groupIDs)
}
func TestAccountHandlerCheckMixedChannelWithRisk(t *testing.T) {
adminSvc := newStubAdminService()
adminSvc.checkMixedErr = &service.MixedChannelError{
GroupID: 27,
GroupName: "claude-max",
CurrentPlatform: "Antigravity",
OtherPlatform: "Anthropic",
}
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"platform": "antigravity",
"group_ids": []int64{27},
"account_id": 99,
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/check-mixed-channel", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, float64(0), resp["code"])
data, ok := resp["data"].(map[string]any)
require.True(t, ok)
require.Equal(t, true, data["has_risk"])
require.Equal(t, "mixed_channel_warning", data["error"])
details, ok := data["details"].(map[string]any)
require.True(t, ok)
require.Equal(t, float64(27), details["group_id"])
require.Equal(t, "claude-max", details["group_name"])
require.Equal(t, "Antigravity", details["current_platform"])
require.Equal(t, "Anthropic", details["other_platform"])
require.Equal(t, int64(99), adminSvc.lastMixedCheck.accountID)
}
func TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse(t *testing.T) {
adminSvc := newStubAdminService()
adminSvc.createAccountErr = &service.MixedChannelError{
GroupID: 27,
GroupName: "claude-max",
CurrentPlatform: "Antigravity",
OtherPlatform: "Anthropic",
}
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"name": "ag-oauth-1",
"platform": "antigravity",
"type": "oauth",
"credentials": map[string]any{"refresh_token": "rt"},
"group_ids": []int64{27},
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusConflict, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, "mixed_channel_warning", resp["error"])
require.Contains(t, resp["message"], "mixed_channel_warning")
_, hasDetails := resp["details"]
_, hasRequireConfirmation := resp["require_confirmation"]
require.False(t, hasDetails)
require.False(t, hasRequireConfirmation)
}
func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T) {
adminSvc := newStubAdminService()
adminSvc.updateAccountErr = &service.MixedChannelError{
GroupID: 27,
GroupName: "claude-max",
CurrentPlatform: "Antigravity",
OtherPlatform: "Anthropic",
}
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"group_ids": []int64{27},
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/accounts/3", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusConflict, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, "mixed_channel_warning", resp["error"])
require.Contains(t, resp["message"], "mixed_channel_warning")
_, hasDetails := resp["details"]
_, hasRequireConfirmation := resp["require_confirmation"]
require.False(t, hasDetails)
require.False(t, hasRequireConfirmation)
}
func TestAccountHandlerBulkUpdateMixedChannelConflict(t *testing.T) {
adminSvc := newStubAdminService()
adminSvc.bulkUpdateAccountErr = &service.MixedChannelError{
GroupID: 27,
GroupName: "claude-max",
CurrentPlatform: "Antigravity",
OtherPlatform: "Anthropic",
}
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1, 2, 3},
"group_ids": []int64{27},
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusConflict, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, "mixed_channel_warning", resp["error"])
require.Contains(t, resp["message"], "claude-max")
}
func TestAccountHandlerBulkUpdateMixedChannelConfirmSkips(t *testing.T) {
adminSvc := newStubAdminService()
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1, 2},
"group_ids": []int64{27},
"confirm_mixed_channel_risk": true,
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, float64(0), resp["code"])
data, ok := resp["data"].(map[string]any)
require.True(t, ok)
require.Equal(t, float64(2), data["success"])
require.Equal(t, float64(0), data["failed"])
}

View File

@@ -0,0 +1,67 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestAccountHandler_Create_AnthropicAPIKeyPassthroughExtraForwarded(t *testing.T) {
gin.SetMode(gin.TestMode)
adminSvc := newStubAdminService()
handler := NewAccountHandler(
adminSvc,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
)
router := gin.New()
router.POST("/api/v1/admin/accounts", handler.Create)
body := map[string]any{
"name": "anthropic-key-1",
"platform": "anthropic",
"type": "apikey",
"credentials": map[string]any{
"api_key": "sk-ant-xxx",
"base_url": "https://api.anthropic.com",
},
"extra": map[string]any{
"anthropic_passthrough": true,
},
"concurrency": 1,
"priority": 1,
}
raw, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts", bytes.NewReader(raw))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Len(t, adminSvc.createdAccounts, 1)
created := adminSvc.createdAccounts[0]
require.Equal(t, "anthropic", created.Platform)
require.Equal(t, "apikey", created.Type)
require.NotNil(t, created.Extra)
require.Equal(t, true, created.Extra["anthropic_passthrough"])
}

View File

@@ -0,0 +1,25 @@
package admin
import (
"strconv"
"strings"
"time"
)
var accountTodayStatsBatchCache = newSnapshotCache(30 * time.Second)
func buildAccountTodayStatsBatchCacheKey(accountIDs []int64) string {
if len(accountIDs) == 0 {
return "accounts_today_stats_empty"
}
var b strings.Builder
b.Grow(len(accountIDs) * 6)
_, _ = b.WriteString("accounts_today_stats:")
for i, id := range accountIDs {
if i > 0 {
_ = b.WriteByte(',')
}
_, _ = b.WriteString(strconv.FormatInt(id, 10))
}
return b.String()
}

View File

@@ -0,0 +1,268 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func setupAdminRouter() (*gin.Engine, *stubAdminService) {
gin.SetMode(gin.TestMode)
router := gin.New()
adminSvc := newStubAdminService()
userHandler := NewUserHandler(adminSvc, nil)
groupHandler := NewGroupHandler(adminSvc)
proxyHandler := NewProxyHandler(adminSvc)
redeemHandler := NewRedeemHandler(adminSvc, nil)
router.GET("/api/v1/admin/users", userHandler.List)
router.GET("/api/v1/admin/users/:id", userHandler.GetByID)
router.POST("/api/v1/admin/users", userHandler.Create)
router.PUT("/api/v1/admin/users/:id", userHandler.Update)
router.DELETE("/api/v1/admin/users/:id", userHandler.Delete)
router.POST("/api/v1/admin/users/:id/balance", userHandler.UpdateBalance)
router.GET("/api/v1/admin/users/:id/api-keys", userHandler.GetUserAPIKeys)
router.GET("/api/v1/admin/users/:id/usage", userHandler.GetUserUsage)
router.GET("/api/v1/admin/groups", groupHandler.List)
router.GET("/api/v1/admin/groups/all", groupHandler.GetAll)
router.GET("/api/v1/admin/groups/:id", groupHandler.GetByID)
router.POST("/api/v1/admin/groups", groupHandler.Create)
router.PUT("/api/v1/admin/groups/:id", groupHandler.Update)
router.DELETE("/api/v1/admin/groups/:id", groupHandler.Delete)
router.GET("/api/v1/admin/groups/:id/stats", groupHandler.GetStats)
router.GET("/api/v1/admin/groups/:id/api-keys", groupHandler.GetGroupAPIKeys)
router.GET("/api/v1/admin/proxies", proxyHandler.List)
router.GET("/api/v1/admin/proxies/all", proxyHandler.GetAll)
router.GET("/api/v1/admin/proxies/:id", proxyHandler.GetByID)
router.POST("/api/v1/admin/proxies", proxyHandler.Create)
router.PUT("/api/v1/admin/proxies/:id", proxyHandler.Update)
router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete)
router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete)
router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test)
router.POST("/api/v1/admin/proxies/:id/quality-check", proxyHandler.CheckQuality)
router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats)
router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts)
router.GET("/api/v1/admin/redeem-codes", redeemHandler.List)
router.GET("/api/v1/admin/redeem-codes/:id", redeemHandler.GetByID)
router.POST("/api/v1/admin/redeem-codes", redeemHandler.Generate)
router.DELETE("/api/v1/admin/redeem-codes/:id", redeemHandler.Delete)
router.POST("/api/v1/admin/redeem-codes/batch-delete", redeemHandler.BatchDelete)
router.POST("/api/v1/admin/redeem-codes/:id/expire", redeemHandler.Expire)
router.GET("/api/v1/admin/redeem-codes/:id/stats", redeemHandler.GetStats)
return router, adminSvc
}
func TestUserHandlerEndpoints(t *testing.T) {
router, _ := setupAdminRouter()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/users?page=1&page_size=20", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
createBody := map[string]any{"email": "new@example.com", "password": "pass123", "balance": 1, "concurrency": 2}
body, _ := json.Marshal(createBody)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
updateBody := map[string]any{"email": "updated@example.com"}
body, _ = json.Marshal(updateBody)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/users/1", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/users/1", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/1/balance", bytes.NewBufferString(`{"balance":1,"operation":"add"}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1/api-keys", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1/usage?period=today", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestGroupHandlerEndpoints(t *testing.T) {
router, _ := setupAdminRouter()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/all", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
body, _ := json.Marshal(map[string]any{"name": "new", "platform": "anthropic", "subscription_type": "standard"})
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/groups", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
body, _ = json.Marshal(map[string]any{"name": "update"})
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/groups/2", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/groups/2", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2/stats", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2/api-keys", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestProxyHandlerEndpoints(t *testing.T) {
router, _ := setupAdminRouter()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/all", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
body, _ := json.Marshal(map[string]any{"name": "proxy", "protocol": "http", "host": "localhost", "port": 8080})
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
body, _ = json.Marshal(map[string]any{"name": "proxy2"})
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/proxies/4", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/proxies/4", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/batch-delete", bytes.NewBufferString(`{"ids":[1,2]}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/4/test", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/4/quality-check", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/accounts", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestRedeemHandlerEndpoints(t *testing.T) {
router, _ := setupAdminRouter()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/5", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
body, _ := json.Marshal(map[string]any{"count": 1, "type": "balance", "value": 10})
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/redeem-codes/5", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/batch-delete", bytes.NewBufferString(`{"ids":[1,2]}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/5/expire", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/5/stats", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}

View File

@@ -0,0 +1,224 @@
package admin
import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/netip"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestParseTimeRange(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
req := httptest.NewRequest(http.MethodGet, "/?start_date=2024-01-01&end_date=2024-01-02&timezone=UTC", nil)
c.Request = req
start, end := parseTimeRange(c)
require.Equal(t, time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), start)
require.Equal(t, time.Date(2024, 1, 3, 0, 0, 0, 0, time.UTC), end)
req = httptest.NewRequest(http.MethodGet, "/?start_date=bad&timezone=UTC", nil)
c.Request = req
start, end = parseTimeRange(c)
require.False(t, start.IsZero())
require.False(t, end.IsZero())
}
func TestParseOpsViewParam(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/?view=excluded", nil)
require.Equal(t, opsListViewExcluded, parseOpsViewParam(c))
c2, _ := gin.CreateTestContext(w)
c2.Request = httptest.NewRequest(http.MethodGet, "/?view=all", nil)
require.Equal(t, opsListViewAll, parseOpsViewParam(c2))
c3, _ := gin.CreateTestContext(w)
c3.Request = httptest.NewRequest(http.MethodGet, "/?view=unknown", nil)
require.Equal(t, opsListViewErrors, parseOpsViewParam(c3))
require.Equal(t, "", parseOpsViewParam(nil))
}
func TestParseOpsDuration(t *testing.T) {
dur, ok := parseOpsDuration("1h")
require.True(t, ok)
require.Equal(t, time.Hour, dur)
_, ok = parseOpsDuration("invalid")
require.False(t, ok)
}
func TestParseOpsOpenAITokenStatsDuration(t *testing.T) {
tests := []struct {
input string
want time.Duration
ok bool
}{
{input: "30m", want: 30 * time.Minute, ok: true},
{input: "1h", want: time.Hour, ok: true},
{input: "1d", want: 24 * time.Hour, ok: true},
{input: "15d", want: 15 * 24 * time.Hour, ok: true},
{input: "30d", want: 30 * 24 * time.Hour, ok: true},
{input: "7d", want: 0, ok: false},
}
for _, tt := range tests {
got, ok := parseOpsOpenAITokenStatsDuration(tt.input)
require.Equal(t, tt.ok, ok, "input=%s", tt.input)
require.Equal(t, tt.want, got, "input=%s", tt.input)
}
}
func TestParseOpsOpenAITokenStatsFilter_Defaults(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
before := time.Now().UTC()
filter, err := parseOpsOpenAITokenStatsFilter(c)
after := time.Now().UTC()
require.NoError(t, err)
require.NotNil(t, filter)
require.Equal(t, "30d", filter.TimeRange)
require.Equal(t, 1, filter.Page)
require.Equal(t, 20, filter.PageSize)
require.Equal(t, 0, filter.TopN)
require.Nil(t, filter.GroupID)
require.Equal(t, "", filter.Platform)
require.True(t, filter.StartTime.Before(filter.EndTime))
require.WithinDuration(t, before.Add(-30*24*time.Hour), filter.StartTime, 2*time.Second)
require.WithinDuration(t, after, filter.EndTime, 2*time.Second)
}
func TestParseOpsOpenAITokenStatsFilter_WithTopN(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(
http.MethodGet,
"/?time_range=1h&platform=openai&group_id=12&top_n=50",
nil,
)
filter, err := parseOpsOpenAITokenStatsFilter(c)
require.NoError(t, err)
require.Equal(t, "1h", filter.TimeRange)
require.Equal(t, "openai", filter.Platform)
require.NotNil(t, filter.GroupID)
require.Equal(t, int64(12), *filter.GroupID)
require.Equal(t, 50, filter.TopN)
require.Equal(t, 0, filter.Page)
require.Equal(t, 0, filter.PageSize)
}
func TestParseOpsOpenAITokenStatsFilter_InvalidParams(t *testing.T) {
tests := []string{
"/?time_range=7d",
"/?group_id=0",
"/?group_id=abc",
"/?top_n=0",
"/?top_n=101",
"/?top_n=10&page=1",
"/?top_n=10&page_size=20",
"/?page=0",
"/?page_size=0",
"/?page_size=101",
}
gin.SetMode(gin.TestMode)
for _, rawURL := range tests {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, rawURL, nil)
_, err := parseOpsOpenAITokenStatsFilter(c)
require.Error(t, err, "url=%s", rawURL)
}
}
func TestParseOpsTimeRange(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
now := time.Now().UTC()
startStr := now.Add(-time.Hour).Format(time.RFC3339)
endStr := now.Format(time.RFC3339)
c.Request = httptest.NewRequest(http.MethodGet, "/?start_time="+startStr+"&end_time="+endStr, nil)
start, end, err := parseOpsTimeRange(c, "1h")
require.NoError(t, err)
require.True(t, start.Before(end))
c2, _ := gin.CreateTestContext(w)
c2.Request = httptest.NewRequest(http.MethodGet, "/?start_time=bad", nil)
_, _, err = parseOpsTimeRange(c2, "1h")
require.Error(t, err)
}
func TestParseOpsRealtimeWindow(t *testing.T) {
dur, label, ok := parseOpsRealtimeWindow("5m")
require.True(t, ok)
require.Equal(t, 5*time.Minute, dur)
require.Equal(t, "5min", label)
_, _, ok = parseOpsRealtimeWindow("invalid")
require.False(t, ok)
}
func TestPickThroughputBucketSeconds(t *testing.T) {
require.Equal(t, 60, pickThroughputBucketSeconds(30*time.Minute))
require.Equal(t, 300, pickThroughputBucketSeconds(6*time.Hour))
require.Equal(t, 3600, pickThroughputBucketSeconds(48*time.Hour))
}
func TestParseOpsQueryMode(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/?mode=raw", nil)
require.Equal(t, service.ParseOpsQueryMode("raw"), parseOpsQueryMode(c))
require.Equal(t, service.OpsQueryMode(""), parseOpsQueryMode(nil))
}
func TestOpsAlertRuleValidation(t *testing.T) {
raw := map[string]json.RawMessage{
"name": json.RawMessage(`"High error rate"`),
"metric_type": json.RawMessage(`"error_rate"`),
"operator": json.RawMessage(`">"`),
"threshold": json.RawMessage(`90`),
}
validated, err := validateOpsAlertRulePayload(raw)
require.NoError(t, err)
require.Equal(t, "High error rate", validated.Name)
_, err = validateOpsAlertRulePayload(map[string]json.RawMessage{})
require.Error(t, err)
require.True(t, isPercentOrRateMetric("error_rate"))
require.False(t, isPercentOrRateMetric("concurrency_queue_depth"))
}
func TestOpsWSHelpers(t *testing.T) {
prefixes, invalid := parseTrustedProxyList("10.0.0.0/8,invalid")
require.Len(t, prefixes, 1)
require.Len(t, invalid, 1)
host := hostWithoutPort("example.com:443")
require.Equal(t, "example.com", host)
addr := netip.MustParseAddr("10.0.0.1")
require.True(t, isAddrInTrustedProxies(addr, prefixes))
require.False(t, isAddrInTrustedProxies(netip.MustParseAddr("192.168.0.1"), prefixes))
}

View File

@@ -0,0 +1,449 @@
package admin
import (
"context"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type stubAdminService struct {
users []service.User
apiKeys []service.APIKey
groups []service.Group
accounts []service.Account
proxies []service.Proxy
proxyCounts []service.ProxyWithAccountCount
redeems []service.RedeemCode
createdAccounts []*service.CreateAccountInput
createdProxies []*service.CreateProxyInput
updatedProxyIDs []int64
updatedProxies []*service.UpdateProxyInput
testedProxyIDs []int64
createAccountErr error
updateAccountErr error
bulkUpdateAccountErr error
checkMixedErr error
lastMixedCheck struct {
accountID int64
platform string
groupIDs []int64
}
mu sync.Mutex
}
func newStubAdminService() *stubAdminService {
now := time.Now().UTC()
user := service.User{
ID: 1,
Email: "user@example.com",
Role: service.RoleUser,
Status: service.StatusActive,
CreatedAt: now,
UpdatedAt: now,
}
apiKey := service.APIKey{
ID: 10,
UserID: user.ID,
Key: "sk-test",
Name: "test",
Status: service.StatusActive,
CreatedAt: now,
UpdatedAt: now,
}
group := service.Group{
ID: 2,
Name: "group",
Platform: service.PlatformAnthropic,
Status: service.StatusActive,
CreatedAt: now,
UpdatedAt: now,
}
account := service.Account{
ID: 3,
Name: "account",
Platform: service.PlatformAnthropic,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
CreatedAt: now,
UpdatedAt: now,
}
proxy := service.Proxy{
ID: 4,
Name: "proxy",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Status: service.StatusActive,
CreatedAt: now,
UpdatedAt: now,
}
redeem := service.RedeemCode{
ID: 5,
Code: "R-TEST",
Type: service.RedeemTypeBalance,
Value: 10,
Status: service.StatusUnused,
CreatedAt: now,
}
return &stubAdminService{
users: []service.User{user},
apiKeys: []service.APIKey{apiKey},
groups: []service.Group{group},
accounts: []service.Account{account},
proxies: []service.Proxy{proxy},
proxyCounts: []service.ProxyWithAccountCount{{Proxy: proxy, AccountCount: 1}},
redeems: []service.RedeemCode{redeem},
}
}
func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, filters service.UserListFilters) ([]service.User, int64, error) {
return s.users, int64(len(s.users)), nil
}
func (s *stubAdminService) GetUser(ctx context.Context, id int64) (*service.User, error) {
for i := range s.users {
if s.users[i].ID == id {
return &s.users[i], nil
}
}
user := service.User{ID: id, Email: "user@example.com", Status: service.StatusActive}
return &user, nil
}
func (s *stubAdminService) CreateUser(ctx context.Context, input *service.CreateUserInput) (*service.User, error) {
user := service.User{ID: 100, Email: input.Email, Status: service.StatusActive}
return &user, nil
}
func (s *stubAdminService) UpdateUser(ctx context.Context, id int64, input *service.UpdateUserInput) (*service.User, error) {
user := service.User{ID: id, Email: "updated@example.com", Status: service.StatusActive}
return &user, nil
}
func (s *stubAdminService) DeleteUser(ctx context.Context, id int64) error {
return nil
}
func (s *stubAdminService) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*service.User, error) {
user := service.User{ID: userID, Balance: balance, Status: service.StatusActive}
return &user, nil
}
func (s *stubAdminService) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]service.APIKey, int64, error) {
return s.apiKeys, int64(len(s.apiKeys)), nil
}
func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) {
return map[string]any{"user_id": userID}, nil
}
func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]service.Group, int64, error) {
return s.groups, int64(len(s.groups)), nil
}
func (s *stubAdminService) GetAllGroups(ctx context.Context) ([]service.Group, error) {
return s.groups, nil
}
func (s *stubAdminService) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
return s.groups, nil
}
func (s *stubAdminService) GetGroup(ctx context.Context, id int64) (*service.Group, error) {
group := service.Group{ID: id, Name: "group", Status: service.StatusActive}
return &group, nil
}
func (s *stubAdminService) CreateGroup(ctx context.Context, input *service.CreateGroupInput) (*service.Group, error) {
group := service.Group{ID: 200, Name: input.Name, Status: service.StatusActive}
return &group, nil
}
func (s *stubAdminService) UpdateGroup(ctx context.Context, id int64, input *service.UpdateGroupInput) (*service.Group, error) {
group := service.Group{ID: id, Name: input.Name, Status: service.StatusActive}
return &group, nil
}
func (s *stubAdminService) DeleteGroup(ctx context.Context, id int64) error {
return nil
}
func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]service.APIKey, int64, error) {
return s.apiKeys, int64(len(s.apiKeys)), nil
}
func (s *stubAdminService) GetGroupRateMultipliers(_ context.Context, _ int64) ([]service.UserGroupRateEntry, error) {
return nil, nil
}
func (s *stubAdminService) ClearGroupRateMultipliers(_ context.Context, _ int64) error {
return nil
}
func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int64, _ []service.GroupRateMultiplierInput) error {
return nil
}
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
return s.accounts, int64(len(s.accounts)), nil
}
func (s *stubAdminService) GetAccount(ctx context.Context, id int64) (*service.Account, error) {
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
out := make([]*service.Account, 0, len(ids))
for _, id := range ids {
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
out = append(out, &account)
}
return out, nil
}
func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.CreateAccountInput) (*service.Account, error) {
s.mu.Lock()
s.createdAccounts = append(s.createdAccounts, input)
s.mu.Unlock()
if s.createAccountErr != nil {
return nil, s.createAccountErr
}
account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
if s.updateAccountErr != nil {
return nil, s.updateAccountErr
}
account := service.Account{ID: id, Name: input.Name, Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) DeleteAccount(ctx context.Context, id int64) error {
return nil
}
func (s *stubAdminService) RefreshAccountCredentials(ctx context.Context, id int64) (*service.Account, error) {
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) ClearAccountError(ctx context.Context, id int64) (*service.Account, error) {
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) SetAccountError(ctx context.Context, id int64, errorMsg string) error {
return nil
}
func (s *stubAdminService) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*service.Account, error) {
account := service.Account{ID: id, Name: "account", Status: service.StatusActive, Schedulable: schedulable}
return &account, nil
}
func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *service.BulkUpdateAccountsInput) (*service.BulkUpdateAccountsResult, error) {
if s.bulkUpdateAccountErr != nil {
return nil, s.bulkUpdateAccountErr
}
return &service.BulkUpdateAccountsResult{Success: len(input.AccountIDs), Failed: 0, SuccessIDs: input.AccountIDs}, nil
}
func (s *stubAdminService) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
s.lastMixedCheck.accountID = currentAccountID
s.lastMixedCheck.platform = currentAccountPlatform
s.lastMixedCheck.groupIDs = append([]int64(nil), groupIDs...)
return s.checkMixedErr
}
func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) {
search = strings.TrimSpace(strings.ToLower(search))
filtered := make([]service.Proxy, 0, len(s.proxies))
for _, proxy := range s.proxies {
if protocol != "" && proxy.Protocol != protocol {
continue
}
if status != "" && proxy.Status != status {
continue
}
if search != "" {
name := strings.ToLower(proxy.Name)
host := strings.ToLower(proxy.Host)
if !strings.Contains(name, search) && !strings.Contains(host, search) {
continue
}
}
filtered = append(filtered, proxy)
}
return filtered, int64(len(filtered)), nil
}
func (s *stubAdminService) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.ProxyWithAccountCount, int64, error) {
return s.proxyCounts, int64(len(s.proxyCounts)), nil
}
func (s *stubAdminService) GetAllProxies(ctx context.Context) ([]service.Proxy, error) {
return s.proxies, nil
}
func (s *stubAdminService) GetAllProxiesWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) {
return s.proxyCounts, nil
}
func (s *stubAdminService) GetProxy(ctx context.Context, id int64) (*service.Proxy, error) {
for i := range s.proxies {
proxy := s.proxies[i]
if proxy.ID == id {
return &proxy, nil
}
}
proxy := service.Proxy{ID: id, Name: "proxy", Status: service.StatusActive}
return &proxy, nil
}
func (s *stubAdminService) GetProxiesByIDs(ctx context.Context, ids []int64) ([]service.Proxy, error) {
if len(ids) == 0 {
return []service.Proxy{}, nil
}
out := make([]service.Proxy, 0, len(ids))
seen := make(map[int64]struct{}, len(ids))
for _, id := range ids {
seen[id] = struct{}{}
}
for i := range s.proxies {
proxy := s.proxies[i]
if _, ok := seen[proxy.ID]; ok {
out = append(out, proxy)
}
}
return out, nil
}
func (s *stubAdminService) CreateProxy(ctx context.Context, input *service.CreateProxyInput) (*service.Proxy, error) {
s.mu.Lock()
s.createdProxies = append(s.createdProxies, input)
s.mu.Unlock()
proxy := service.Proxy{ID: 400, Name: input.Name, Status: service.StatusActive}
return &proxy, nil
}
func (s *stubAdminService) UpdateProxy(ctx context.Context, id int64, input *service.UpdateProxyInput) (*service.Proxy, error) {
s.mu.Lock()
s.updatedProxyIDs = append(s.updatedProxyIDs, id)
s.updatedProxies = append(s.updatedProxies, input)
s.mu.Unlock()
proxy := service.Proxy{ID: id, Name: input.Name, Status: service.StatusActive}
return &proxy, nil
}
func (s *stubAdminService) DeleteProxy(ctx context.Context, id int64) error {
return nil
}
func (s *stubAdminService) BatchDeleteProxies(ctx context.Context, ids []int64) (*service.ProxyBatchDeleteResult, error) {
return &service.ProxyBatchDeleteResult{DeletedIDs: ids}, nil
}
func (s *stubAdminService) GetProxyAccounts(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) {
return []service.ProxyAccountSummary{{ID: 1, Name: "account"}}, nil
}
func (s *stubAdminService) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
return false, nil
}
func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.ProxyTestResult, error) {
s.mu.Lock()
s.testedProxyIDs = append(s.testedProxyIDs, id)
s.mu.Unlock()
return &service.ProxyTestResult{Success: true, Message: "ok"}, nil
}
func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*service.ProxyQualityCheckResult, error) {
return &service.ProxyQualityCheckResult{
ProxyID: id,
Score: 95,
Grade: "A",
Summary: "通过 5 项,告警 0 项,失败 0 项,挑战 0 项",
PassedCount: 5,
WarnCount: 0,
FailedCount: 0,
ChallengeCount: 0,
CheckedAt: time.Now().Unix(),
Items: []service.ProxyQualityCheckItem{
{Target: "base_connectivity", Status: "pass", Message: "ok"},
{Target: "openai", Status: "pass", HTTPStatus: 401},
{Target: "anthropic", Status: "pass", HTTPStatus: 401},
{Target: "gemini", Status: "pass", HTTPStatus: 200},
{Target: "sora", Status: "pass", HTTPStatus: 401},
},
}, nil
}
func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) {
return s.redeems, int64(len(s.redeems)), nil
}
func (s *stubAdminService) GetRedeemCode(ctx context.Context, id int64) (*service.RedeemCode, error) {
code := service.RedeemCode{ID: id, Code: "R-TEST", Status: service.StatusUnused}
return &code, nil
}
func (s *stubAdminService) GenerateRedeemCodes(ctx context.Context, input *service.GenerateRedeemCodesInput) ([]service.RedeemCode, error) {
return s.redeems, nil
}
func (s *stubAdminService) DeleteRedeemCode(ctx context.Context, id int64) error {
return nil
}
func (s *stubAdminService) BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) {
return int64(len(ids)), nil
}
func (s *stubAdminService) ExpireRedeemCode(ctx context.Context, id int64) (*service.RedeemCode, error) {
code := service.RedeemCode{ID: id, Code: "R-TEST", Status: service.StatusUsed}
return &code, nil
}
func (s *stubAdminService) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]service.RedeemCode, int64, float64, error) {
return s.redeems, int64(len(s.redeems)), 100.0, nil
}
func (s *stubAdminService) UpdateGroupSortOrders(ctx context.Context, updates []service.GroupSortOrderUpdate) error {
return nil
}
func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*service.AdminUpdateAPIKeyGroupIDResult, error) {
for i := range s.apiKeys {
if s.apiKeys[i].ID == keyID {
k := s.apiKeys[i]
if groupID != nil {
if *groupID == 0 {
k.GroupID = nil
} else {
gid := *groupID
k.GroupID = &gid
}
}
return &service.AdminUpdateAPIKeyGroupIDResult{APIKey: &k}, nil
}
}
return nil, service.ErrAPIKeyNotFound
}
func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) error {
return nil
}
func (s *stubAdminService) EnsureOpenAIPrivacy(ctx context.Context, account *service.Account) string {
return ""
}
// Ensure stub implements interface.
var _ service.AdminService = (*stubAdminService)(nil)

View File

@@ -0,0 +1,250 @@
package admin
import (
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// AnnouncementHandler handles admin announcement management
type AnnouncementHandler struct {
announcementService *service.AnnouncementService
}
// NewAnnouncementHandler creates a new admin announcement handler
func NewAnnouncementHandler(announcementService *service.AnnouncementService) *AnnouncementHandler {
return &AnnouncementHandler{
announcementService: announcementService,
}
}
type CreateAnnouncementRequest struct {
Title string `json:"title" binding:"required"`
Content string `json:"content" binding:"required"`
Status string `json:"status" binding:"omitempty,oneof=draft active archived"`
NotifyMode string `json:"notify_mode" binding:"omitempty,oneof=silent popup"`
Targeting service.AnnouncementTargeting `json:"targeting"`
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0/empty = immediate
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0/empty = never
}
type UpdateAnnouncementRequest struct {
Title *string `json:"title"`
Content *string `json:"content"`
Status *string `json:"status" binding:"omitempty,oneof=draft active archived"`
NotifyMode *string `json:"notify_mode" binding:"omitempty,oneof=silent popup"`
Targeting *service.AnnouncementTargeting `json:"targeting"`
StartsAt *int64 `json:"starts_at"` // Unix seconds, 0 = clear
EndsAt *int64 `json:"ends_at"` // Unix seconds, 0 = clear
}
// List handles listing announcements with filters
// GET /api/v1/admin/announcements
func (h *AnnouncementHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
status := strings.TrimSpace(c.Query("status"))
search := strings.TrimSpace(c.Query("search"))
if len(search) > 200 {
search = search[:200]
}
params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
}
items, paginationResult, err := h.announcementService.List(
c.Request.Context(),
params,
service.AnnouncementListFilters{Status: status, Search: search},
)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.Announcement, 0, len(items))
for i := range items {
out = append(out, *dto.AnnouncementFromService(&items[i]))
}
response.Paginated(c, out, paginationResult.Total, page, pageSize)
}
// GetByID handles getting an announcement by ID
// GET /api/v1/admin/announcements/:id
func (h *AnnouncementHandler) GetByID(c *gin.Context) {
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || announcementID <= 0 {
response.BadRequest(c, "Invalid announcement ID")
return
}
item, err := h.announcementService.GetByID(c.Request.Context(), announcementID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.AnnouncementFromService(item))
}
// Create handles creating a new announcement
// POST /api/v1/admin/announcements
func (h *AnnouncementHandler) Create(c *gin.Context) {
var req CreateAnnouncementRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not found in context")
return
}
input := &service.CreateAnnouncementInput{
Title: req.Title,
Content: req.Content,
Status: req.Status,
NotifyMode: req.NotifyMode,
Targeting: req.Targeting,
ActorID: &subject.UserID,
}
if req.StartsAt != nil && *req.StartsAt > 0 {
t := time.Unix(*req.StartsAt, 0)
input.StartsAt = &t
}
if req.EndsAt != nil && *req.EndsAt > 0 {
t := time.Unix(*req.EndsAt, 0)
input.EndsAt = &t
}
created, err := h.announcementService.Create(c.Request.Context(), input)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.AnnouncementFromService(created))
}
// Update handles updating an announcement
// PUT /api/v1/admin/announcements/:id
func (h *AnnouncementHandler) Update(c *gin.Context) {
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || announcementID <= 0 {
response.BadRequest(c, "Invalid announcement ID")
return
}
var req UpdateAnnouncementRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not found in context")
return
}
input := &service.UpdateAnnouncementInput{
Title: req.Title,
Content: req.Content,
Status: req.Status,
NotifyMode: req.NotifyMode,
Targeting: req.Targeting,
ActorID: &subject.UserID,
}
if req.StartsAt != nil {
if *req.StartsAt == 0 {
var cleared *time.Time = nil
input.StartsAt = &cleared
} else {
t := time.Unix(*req.StartsAt, 0)
ptr := &t
input.StartsAt = &ptr
}
}
if req.EndsAt != nil {
if *req.EndsAt == 0 {
var cleared *time.Time = nil
input.EndsAt = &cleared
} else {
t := time.Unix(*req.EndsAt, 0)
ptr := &t
input.EndsAt = &ptr
}
}
updated, err := h.announcementService.Update(c.Request.Context(), announcementID, input)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.AnnouncementFromService(updated))
}
// Delete handles deleting an announcement
// DELETE /api/v1/admin/announcements/:id
func (h *AnnouncementHandler) Delete(c *gin.Context) {
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || announcementID <= 0 {
response.BadRequest(c, "Invalid announcement ID")
return
}
if err := h.announcementService.Delete(c.Request.Context(), announcementID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Announcement deleted successfully"})
}
// ListReadStatus handles listing users read status for an announcement
// GET /api/v1/admin/announcements/:id/read-status
func (h *AnnouncementHandler) ListReadStatus(c *gin.Context) {
announcementID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || announcementID <= 0 {
response.BadRequest(c, "Invalid announcement ID")
return
}
page, pageSize := response.ParsePagination(c)
params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
}
search := strings.TrimSpace(c.Query("search"))
if len(search) > 200 {
search = search[:200]
}
items, paginationResult, err := h.announcementService.ListUserReadStatus(
c.Request.Context(),
announcementID,
params,
search,
)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, items, paginationResult.Total, page, pageSize)
}

View File

@@ -0,0 +1,91 @@
package admin
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type AntigravityOAuthHandler struct {
antigravityOAuthService *service.AntigravityOAuthService
}
func NewAntigravityOAuthHandler(antigravityOAuthService *service.AntigravityOAuthService) *AntigravityOAuthHandler {
return &AntigravityOAuthHandler{antigravityOAuthService: antigravityOAuthService}
}
type AntigravityGenerateAuthURLRequest struct {
ProxyID *int64 `json:"proxy_id"`
}
// GenerateAuthURL generates Google OAuth authorization URL
// POST /api/v1/admin/antigravity/oauth/auth-url
func (h *AntigravityOAuthHandler) GenerateAuthURL(c *gin.Context) {
var req AntigravityGenerateAuthURLRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "请求无效: "+err.Error())
return
}
result, err := h.antigravityOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID)
if err != nil {
response.InternalError(c, "生成授权链接失败: "+err.Error())
return
}
response.Success(c, result)
}
type AntigravityExchangeCodeRequest struct {
SessionID string `json:"session_id" binding:"required"`
State string `json:"state" binding:"required"`
Code string `json:"code" binding:"required"`
ProxyID *int64 `json:"proxy_id"`
}
// ExchangeCode 用 authorization code 交换 token
// POST /api/v1/admin/antigravity/oauth/exchange-code
func (h *AntigravityOAuthHandler) ExchangeCode(c *gin.Context) {
var req AntigravityExchangeCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "请求无效: "+err.Error())
return
}
tokenInfo, err := h.antigravityOAuthService.ExchangeCode(c.Request.Context(), &service.AntigravityExchangeCodeInput{
SessionID: req.SessionID,
State: req.State,
Code: req.Code,
ProxyID: req.ProxyID,
})
if err != nil {
response.BadRequest(c, "Token 交换失败: "+err.Error())
return
}
response.Success(c, tokenInfo)
}
// AntigravityRefreshTokenRequest represents the request for validating Antigravity refresh token
type AntigravityRefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"`
ProxyID *int64 `json:"proxy_id"`
}
// RefreshToken validates an Antigravity refresh token and returns full token info
// POST /api/v1/admin/antigravity/oauth/refresh-token
func (h *AntigravityOAuthHandler) RefreshToken(c *gin.Context) {
var req AntigravityRefreshTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "请求无效: "+err.Error())
return
}
tokenInfo, err := h.antigravityOAuthService.ValidateRefreshToken(c.Request.Context(), req.RefreshToken, req.ProxyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, tokenInfo)
}

View File

@@ -0,0 +1,63 @@
package admin
import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// AdminAPIKeyHandler handles admin API key management
type AdminAPIKeyHandler struct {
adminService service.AdminService
}
// NewAdminAPIKeyHandler creates a new admin API key handler
func NewAdminAPIKeyHandler(adminService service.AdminService) *AdminAPIKeyHandler {
return &AdminAPIKeyHandler{
adminService: adminService,
}
}
// AdminUpdateAPIKeyGroupRequest represents the request to update an API key's group
type AdminUpdateAPIKeyGroupRequest struct {
GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组
}
// UpdateGroup handles updating an API key's group binding
// PUT /api/v1/admin/api-keys/:id
func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) {
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid API key ID")
return
}
var req AdminUpdateAPIKeyGroupRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
result, err := h.adminService.AdminUpdateAPIKeyGroupID(c.Request.Context(), keyID, req.GroupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
resp := struct {
APIKey *dto.APIKey `json:"api_key"`
AutoGrantedGroupAccess bool `json:"auto_granted_group_access"`
GrantedGroupID *int64 `json:"granted_group_id,omitempty"`
GrantedGroupName string `json:"granted_group_name,omitempty"`
}{
APIKey: dto.APIKeyFromService(result.APIKey),
AutoGrantedGroupAccess: result.AutoGrantedGroupAccess,
GrantedGroupID: result.GrantedGroupID,
GrantedGroupName: result.GrantedGroupName,
}
response.Success(c, resp)
}

View File

@@ -0,0 +1,202 @@
package admin
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func setupAPIKeyHandler(adminSvc service.AdminService) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
h := NewAdminAPIKeyHandler(adminSvc)
router.PUT("/api/v1/admin/api-keys/:id", h.UpdateGroup)
return router
}
func TestAdminAPIKeyHandler_UpdateGroup_InvalidID(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
body := `{"group_id": 2}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/abc", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "Invalid API key ID")
}
func TestAdminAPIKeyHandler_UpdateGroup_InvalidJSON(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{bad json`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "Invalid request")
}
func TestAdminAPIKeyHandler_UpdateGroup_KeyNotFound(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
body := `{"group_id": 2}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/999", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
// ErrAPIKeyNotFound maps to 404
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestAdminAPIKeyHandler_UpdateGroup_BindGroup(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
body := `{"group_id": 2}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Code int `json:"code"`
Data json.RawMessage `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
var data struct {
APIKey struct {
ID int64 `json:"id"`
GroupID *int64 `json:"group_id"`
} `json:"api_key"`
AutoGrantedGroupAccess bool `json:"auto_granted_group_access"`
}
require.NoError(t, json.Unmarshal(resp.Data, &data))
require.Equal(t, int64(10), data.APIKey.ID)
require.NotNil(t, data.APIKey.GroupID)
require.Equal(t, int64(2), *data.APIKey.GroupID)
}
func TestAdminAPIKeyHandler_UpdateGroup_Unbind(t *testing.T) {
svc := newStubAdminService()
gid := int64(2)
svc.apiKeys[0].GroupID = &gid
router := setupAPIKeyHandler(svc)
body := `{"group_id": 0}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Data struct {
APIKey struct {
GroupID *int64 `json:"group_id"`
} `json:"api_key"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Nil(t, resp.Data.APIKey.GroupID)
}
func TestAdminAPIKeyHandler_UpdateGroup_ServiceError(t *testing.T) {
svc := &failingUpdateGroupService{
stubAdminService: newStubAdminService(),
err: errors.New("internal failure"),
}
router := setupAPIKeyHandler(svc)
body := `{"group_id": 2}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
// H2: empty body → group_id is nil → no-op, returns original key
func TestAdminAPIKeyHandler_UpdateGroup_EmptyBody_NoChange(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Code int `json:"code"`
Data struct {
APIKey struct {
ID int64 `json:"id"`
} `json:"api_key"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Equal(t, int64(10), resp.Data.APIKey.ID)
}
// M2: service returns GROUP_NOT_ACTIVE → handler maps to 400
func TestAdminAPIKeyHandler_UpdateGroup_GroupNotActive(t *testing.T) {
svc := &failingUpdateGroupService{
stubAdminService: newStubAdminService(),
err: infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active"),
}
router := setupAPIKeyHandler(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"group_id": 5}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "GROUP_NOT_ACTIVE")
}
// M2: service returns INVALID_GROUP_ID → handler maps to 400
func TestAdminAPIKeyHandler_UpdateGroup_NegativeGroupID(t *testing.T) {
svc := &failingUpdateGroupService{
stubAdminService: newStubAdminService(),
err: infraerrors.BadRequest("INVALID_GROUP_ID", "group_id must be non-negative"),
}
router := setupAPIKeyHandler(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"group_id": -5}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "INVALID_GROUP_ID")
}
// failingUpdateGroupService overrides AdminUpdateAPIKeyGroupID to return an error.
type failingUpdateGroupService struct {
*stubAdminService
err error
}
func (f *failingUpdateGroupService) AdminUpdateAPIKeyGroupID(_ context.Context, _ int64, _ *int64) (*service.AdminUpdateAPIKeyGroupIDResult, error) {
return nil, f.err
}

View File

@@ -0,0 +1,204 @@
package admin
import (
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type BackupHandler struct {
backupService *service.BackupService
userService *service.UserService
}
func NewBackupHandler(backupService *service.BackupService, userService *service.UserService) *BackupHandler {
return &BackupHandler{
backupService: backupService,
userService: userService,
}
}
// ─── S3 配置 ───
func (h *BackupHandler) GetS3Config(c *gin.Context) {
cfg, err := h.backupService.GetS3Config(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
func (h *BackupHandler) UpdateS3Config(c *gin.Context) {
var req service.BackupS3Config
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
cfg, err := h.backupService.UpdateS3Config(c.Request.Context(), req)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
func (h *BackupHandler) TestS3Connection(c *gin.Context) {
var req service.BackupS3Config
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
err := h.backupService.TestS3Connection(c.Request.Context(), req)
if err != nil {
response.Success(c, gin.H{"ok": false, "message": err.Error()})
return
}
response.Success(c, gin.H{"ok": true, "message": "connection successful"})
}
// ─── 定时备份 ───
func (h *BackupHandler) GetSchedule(c *gin.Context) {
cfg, err := h.backupService.GetSchedule(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
func (h *BackupHandler) UpdateSchedule(c *gin.Context) {
var req service.BackupScheduleConfig
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
cfg, err := h.backupService.UpdateSchedule(c.Request.Context(), req)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
// ─── 备份操作 ───
type CreateBackupRequest struct {
ExpireDays *int `json:"expire_days"` // nil=使用默认值140=永不过期
}
func (h *BackupHandler) CreateBackup(c *gin.Context) {
var req CreateBackupRequest
_ = c.ShouldBindJSON(&req) // 允许空 body
expireDays := 14 // 默认14天过期
if req.ExpireDays != nil {
expireDays = *req.ExpireDays
}
record, err := h.backupService.CreateBackup(c.Request.Context(), "manual", expireDays)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, record)
}
func (h *BackupHandler) ListBackups(c *gin.Context) {
records, err := h.backupService.ListBackups(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
if records == nil {
records = []service.BackupRecord{}
}
response.Success(c, gin.H{"items": records})
}
func (h *BackupHandler) GetBackup(c *gin.Context) {
backupID := c.Param("id")
if backupID == "" {
response.BadRequest(c, "backup ID is required")
return
}
record, err := h.backupService.GetBackupRecord(c.Request.Context(), backupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, record)
}
func (h *BackupHandler) DeleteBackup(c *gin.Context) {
backupID := c.Param("id")
if backupID == "" {
response.BadRequest(c, "backup ID is required")
return
}
if err := h.backupService.DeleteBackup(c.Request.Context(), backupID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"deleted": true})
}
func (h *BackupHandler) GetDownloadURL(c *gin.Context) {
backupID := c.Param("id")
if backupID == "" {
response.BadRequest(c, "backup ID is required")
return
}
url, err := h.backupService.GetBackupDownloadURL(c.Request.Context(), backupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"url": url})
}
// ─── 恢复操作(需要重新输入管理员密码) ───
type RestoreBackupRequest struct {
Password string `json:"password" binding:"required"`
}
func (h *BackupHandler) RestoreBackup(c *gin.Context) {
backupID := c.Param("id")
if backupID == "" {
response.BadRequest(c, "backup ID is required")
return
}
var req RestoreBackupRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "password is required for restore operation")
return
}
// 从上下文获取当前管理员用户 ID
sub, ok := middleware.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "unauthorized")
return
}
// 获取管理员用户并验证密码
user, err := h.userService.GetByID(c.Request.Context(), sub.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
if !user.CheckPassword(req.Password) {
response.BadRequest(c, "incorrect admin password")
return
}
if err := h.backupService.RestoreBackup(c.Request.Context(), backupID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"restored": true})
}

View File

@@ -0,0 +1,208 @@
//go:build unit
package admin
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// failingAdminService 嵌入 stubAdminService可配置 UpdateAccount 在指定 ID 时失败。
type failingAdminService struct {
*stubAdminService
failOnAccountID int64
updateCallCount atomic.Int64
}
func (f *failingAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
f.updateCallCount.Add(1)
if id == f.failOnAccountID {
return nil, errors.New("database error")
}
return f.stubAdminService.UpdateAccount(ctx, id, input)
}
func setupAccountHandlerWithService(adminSvc service.AdminService) (*gin.Engine, *AccountHandler) {
gin.SetMode(gin.TestMode)
router := gin.New()
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.POST("/api/v1/admin/accounts/batch-update-credentials", handler.BatchUpdateCredentials)
return router, handler
}
func TestBatchUpdateCredentials_AllSuccess(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "account_uuid",
Value: "test-uuid",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code, "全部成功时应返回 200")
require.Equal(t, int64(3), svc.updateCallCount.Load(), "应调用 3 次 UpdateAccount")
}
func TestBatchUpdateCredentials_PartialFailure(t *testing.T) {
// 让第 2 个账号ID=2更新时失败
svc := &failingAdminService{
stubAdminService: newStubAdminService(),
failOnAccountID: 2,
}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "org_uuid",
Value: "test-org",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
// 实现采用"部分成功"模式:总是返回 200 + 成功/失败明细
require.Equal(t, http.StatusOK, w.Code, "批量更新返回 200 + 成功/失败明细")
var resp map[string]any
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
data := resp["data"].(map[string]any)
require.Equal(t, float64(2), data["success"], "应有 2 个成功")
require.Equal(t, float64(1), data["failed"], "应有 1 个失败")
// 所有 3 个账号都会被尝试更新(非 fail-fast
require.Equal(t, int64(3), svc.updateCallCount.Load(),
"应调用 3 次 UpdateAccount逐个尝试失败后继续")
}
func TestBatchUpdateCredentials_FirstAccountNotFound(t *testing.T) {
// GetAccount 在 stubAdminService 中总是成功的,需要创建一个 GetAccount 会失败的 stub
svc := &getAccountFailingService{
stubAdminService: newStubAdminService(),
failOnAccountID: 1,
}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "account_uuid",
Value: "test",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusNotFound, w.Code, "第一阶段验证失败应返回 404")
}
// getAccountFailingService 模拟 GetAccount 在特定 ID 时返回 not found。
type getAccountFailingService struct {
*stubAdminService
failOnAccountID int64
}
func (f *getAccountFailingService) GetAccount(ctx context.Context, id int64) (*service.Account, error) {
if id == f.failOnAccountID {
return nil, errors.New("not found")
}
return f.stubAdminService.GetAccount(ctx, id)
}
func TestBatchUpdateCredentials_InterceptWarmupRequests_NonBool(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// intercept_warmup_requests 传入非 bool 类型string应返回 400
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "intercept_warmup_requests",
"value": "not-a-bool",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code,
"intercept_warmup_requests 传入非 bool 值应返回 400")
}
func TestBatchUpdateCredentials_InterceptWarmupRequests_ValidBool(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "intercept_warmup_requests",
"value": true,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code,
"intercept_warmup_requests 传入合法 bool 值应返回 200")
}
func TestBatchUpdateCredentials_AccountUUID_NonString(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// account_uuid 传入非 string 类型number应返回 400
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "account_uuid",
"value": 12345,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code,
"account_uuid 传入非 string 值应返回 400")
}
func TestBatchUpdateCredentials_AccountUUID_NullValue(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// account_uuid 传入 null设置为空应正常通过
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "account_uuid",
"value": nil,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code,
"account_uuid 传入 null 应返回 200")
}

View File

@@ -0,0 +1,606 @@
package admin
import (
"encoding/json"
"errors"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// DashboardHandler handles admin dashboard statistics
type DashboardHandler struct {
dashboardService *service.DashboardService
aggregationService *service.DashboardAggregationService
startTime time.Time // Server start time for uptime calculation
}
// NewDashboardHandler creates a new admin dashboard handler
func NewDashboardHandler(dashboardService *service.DashboardService, aggregationService *service.DashboardAggregationService) *DashboardHandler {
return &DashboardHandler{
dashboardService: dashboardService,
aggregationService: aggregationService,
startTime: time.Now(),
}
}
// parseTimeRange parses start_date, end_date query parameters
// Uses user's timezone if provided, otherwise falls back to server timezone
func parseTimeRange(c *gin.Context) (time.Time, time.Time) {
userTZ := c.Query("timezone") // Get user's timezone from request
now := timezone.NowInUserLocation(userTZ)
startDate := c.Query("start_date")
endDate := c.Query("end_date")
var startTime, endTime time.Time
if startDate != "" {
if t, err := timezone.ParseInUserLocation("2006-01-02", startDate, userTZ); err == nil {
startTime = t
} else {
startTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, -7), userTZ)
}
} else {
startTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, -7), userTZ)
}
if endDate != "" {
if t, err := timezone.ParseInUserLocation("2006-01-02", endDate, userTZ); err == nil {
endTime = t.Add(24 * time.Hour) // Include the end date
} else {
endTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, 1), userTZ)
}
} else {
endTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, 1), userTZ)
}
return startTime, endTime
}
// GetStats handles getting dashboard statistics
// GET /api/v1/admin/dashboard/stats
func (h *DashboardHandler) GetStats(c *gin.Context) {
stats, err := h.dashboardService.GetDashboardStats(c.Request.Context())
if err != nil {
response.Error(c, 500, "Failed to get dashboard statistics")
return
}
// Calculate uptime in seconds
uptime := int64(time.Since(h.startTime).Seconds())
response.Success(c, gin.H{
// 用户统计
"total_users": stats.TotalUsers,
"today_new_users": stats.TodayNewUsers,
"active_users": stats.ActiveUsers,
// API Key 统计
"total_api_keys": stats.TotalAPIKeys,
"active_api_keys": stats.ActiveAPIKeys,
// 账户统计
"total_accounts": stats.TotalAccounts,
"normal_accounts": stats.NormalAccounts,
"error_accounts": stats.ErrorAccounts,
"ratelimit_accounts": stats.RateLimitAccounts,
"overload_accounts": stats.OverloadAccounts,
// 累计 Token 使用统计
"total_requests": stats.TotalRequests,
"total_input_tokens": stats.TotalInputTokens,
"total_output_tokens": stats.TotalOutputTokens,
"total_cache_creation_tokens": stats.TotalCacheCreationTokens,
"total_cache_read_tokens": stats.TotalCacheReadTokens,
"total_tokens": stats.TotalTokens,
"total_cost": stats.TotalCost, // 标准计费
"total_actual_cost": stats.TotalActualCost, // 实际扣除
// 今日 Token 使用统计
"today_requests": stats.TodayRequests,
"today_input_tokens": stats.TodayInputTokens,
"today_output_tokens": stats.TodayOutputTokens,
"today_cache_creation_tokens": stats.TodayCacheCreationTokens,
"today_cache_read_tokens": stats.TodayCacheReadTokens,
"today_tokens": stats.TodayTokens,
"today_cost": stats.TodayCost, // 今日标准计费
"today_actual_cost": stats.TodayActualCost, // 今日实际扣除
// 系统运行统计
"average_duration_ms": stats.AverageDurationMs,
"uptime": uptime,
// 性能指标
"rpm": stats.Rpm,
"tpm": stats.Tpm,
// 预聚合新鲜度
"hourly_active_users": stats.HourlyActiveUsers,
"stats_updated_at": stats.StatsUpdatedAt,
"stats_stale": stats.StatsStale,
})
}
type DashboardAggregationBackfillRequest struct {
Start string `json:"start"`
End string `json:"end"`
}
// BackfillAggregation handles triggering aggregation backfill
// POST /api/v1/admin/dashboard/aggregation/backfill
func (h *DashboardHandler) BackfillAggregation(c *gin.Context) {
if h.aggregationService == nil {
response.InternalError(c, "Aggregation service not available")
return
}
var req DashboardAggregationBackfillRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
start, err := time.Parse(time.RFC3339, req.Start)
if err != nil {
response.BadRequest(c, "Invalid start time")
return
}
end, err := time.Parse(time.RFC3339, req.End)
if err != nil {
response.BadRequest(c, "Invalid end time")
return
}
if err := h.aggregationService.TriggerBackfill(start, end); err != nil {
if errors.Is(err, service.ErrDashboardBackfillDisabled) {
response.Forbidden(c, "Backfill is disabled")
return
}
if errors.Is(err, service.ErrDashboardBackfillTooLarge) {
response.BadRequest(c, "Backfill range too large")
return
}
response.InternalError(c, "Failed to trigger backfill")
return
}
response.Success(c, gin.H{
"status": "accepted",
})
}
// GetRealtimeMetrics handles getting real-time system metrics
// GET /api/v1/admin/dashboard/realtime
func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
// Return mock data for now
response.Success(c, gin.H{
"active_requests": 0,
"requests_per_minute": 0,
"average_response_time": 0,
"error_rate": 0.0,
})
}
// GetUsageTrend handles getting usage trend data
// GET /api/v1/admin/dashboard/trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, request_type, stream, billing_type
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
// Parse optional filter params
var userID, apiKeyID, accountID, groupID int64
var model string
var requestType *int16
var stream *bool
var billingType *int8
if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
userID = id
}
}
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
if id, err := strconv.ParseInt(apiKeyIDStr, 10, 64); err == nil {
apiKeyID = id
}
}
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
if id, err := strconv.ParseInt(accountIDStr, 10, 64); err == nil {
accountID = id
}
}
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
groupID = id
}
}
if modelStr := c.Query("model"); modelStr != "" {
model = modelStr
}
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
response.BadRequest(c, err.Error())
return
}
value := int16(parsed)
requestType = &value
} else if streamStr := c.Query("stream"); streamStr != "" {
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
stream = &streamVal
} else {
response.BadRequest(c, "Invalid stream value, use true or false")
return
}
}
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil {
bt := int8(v)
billingType = &bt
} else {
response.BadRequest(c, "Invalid billing_type")
return
}
}
trend, hit, err := h.getUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get usage trend")
return
}
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
response.Success(c, gin.H{
"trend": trend,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
"granularity": granularity,
})
}
// GetModelStats handles getting model usage statistics
// GET /api/v1/admin/dashboard/models
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, request_type, stream, billing_type
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
// Parse optional filter params
var userID, apiKeyID, accountID, groupID int64
var requestType *int16
var stream *bool
var billingType *int8
if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
userID = id
}
}
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
if id, err := strconv.ParseInt(apiKeyIDStr, 10, 64); err == nil {
apiKeyID = id
}
}
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
if id, err := strconv.ParseInt(accountIDStr, 10, 64); err == nil {
accountID = id
}
}
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
groupID = id
}
}
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
response.BadRequest(c, err.Error())
return
}
value := int16(parsed)
requestType = &value
} else if streamStr := c.Query("stream"); streamStr != "" {
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
stream = &streamVal
} else {
response.BadRequest(c, "Invalid stream value, use true or false")
return
}
}
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil {
bt := int8(v)
billingType = &bt
} else {
response.BadRequest(c, "Invalid billing_type")
return
}
}
stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return
}
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
response.Success(c, gin.H{
"models": stats,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
})
}
// GetGroupStats handles getting group usage statistics
// GET /api/v1/admin/dashboard/groups
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, request_type, stream, billing_type
func (h *DashboardHandler) GetGroupStats(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
var userID, apiKeyID, accountID, groupID int64
var requestType *int16
var stream *bool
var billingType *int8
if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
userID = id
}
}
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
if id, err := strconv.ParseInt(apiKeyIDStr, 10, 64); err == nil {
apiKeyID = id
}
}
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
if id, err := strconv.ParseInt(accountIDStr, 10, 64); err == nil {
accountID = id
}
}
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
groupID = id
}
}
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
response.BadRequest(c, err.Error())
return
}
value := int16(parsed)
requestType = &value
} else if streamStr := c.Query("stream"); streamStr != "" {
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
stream = &streamVal
} else {
response.BadRequest(c, "Invalid stream value, use true or false")
return
}
}
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil {
bt := int8(v)
billingType = &bt
} else {
response.BadRequest(c, "Invalid billing_type")
return
}
}
stats, hit, err := h.getGroupStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get group statistics")
return
}
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
response.Success(c, gin.H{
"groups": stats,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
})
}
// GetAPIKeyUsageTrend handles getting API key usage trend data
// GET /api/v1/admin/dashboard/api-keys-trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), limit (default 5)
func (h *DashboardHandler) GetAPIKeyUsageTrend(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
limitStr := c.DefaultQuery("limit", "5")
limit, err := strconv.Atoi(limitStr)
if err != nil || limit <= 0 {
limit = 5
}
trend, hit, err := h.getAPIKeyUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit)
if err != nil {
response.Error(c, 500, "Failed to get API key usage trend")
return
}
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
response.Success(c, gin.H{
"trend": trend,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
"granularity": granularity,
})
}
// GetUserUsageTrend handles getting user usage trend data
// GET /api/v1/admin/dashboard/users-trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), limit (default 12)
func (h *DashboardHandler) GetUserUsageTrend(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
limitStr := c.DefaultQuery("limit", "12")
limit, err := strconv.Atoi(limitStr)
if err != nil || limit <= 0 {
limit = 12
}
trend, hit, err := h.getUserUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit)
if err != nil {
response.Error(c, 500, "Failed to get user usage trend")
return
}
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
response.Success(c, gin.H{
"trend": trend,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
"granularity": granularity,
})
}
// BatchUsersUsageRequest represents the request body for batch user usage stats
type BatchUsersUsageRequest struct {
UserIDs []int64 `json:"user_ids" binding:"required"`
}
var dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second)
var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second)
func parseRankingLimit(raw string) int {
limit, err := strconv.Atoi(strings.TrimSpace(raw))
if err != nil || limit <= 0 {
return 12
}
if limit > 50 {
return 50
}
return limit
}
// GetUserSpendingRanking handles getting user spending ranking data.
// GET /api/v1/admin/dashboard/users-ranking
func (h *DashboardHandler) GetUserSpendingRanking(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
limit := parseRankingLimit(c.DefaultQuery("limit", "12"))
keyRaw, _ := json.Marshal(struct {
Start string `json:"start"`
End string `json:"end"`
Limit int `json:"limit"`
}{
Start: startTime.UTC().Format(time.RFC3339),
End: endTime.UTC().Format(time.RFC3339),
Limit: limit,
})
cacheKey := string(keyRaw)
if cached, ok := dashboardUsersRankingCache.Get(cacheKey); ok {
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
ranking, err := h.dashboardService.GetUserSpendingRanking(c.Request.Context(), startTime, endTime, limit)
if err != nil {
response.Error(c, 500, "Failed to get user spending ranking")
return
}
payload := gin.H{
"ranking": ranking.Ranking,
"total_actual_cost": ranking.TotalActualCost,
"total_requests": ranking.TotalRequests,
"total_tokens": ranking.TotalTokens,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
}
dashboardUsersRankingCache.Set(cacheKey, payload)
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, payload)
}
// GetBatchUsersUsage handles getting usage stats for multiple users
// POST /api/v1/admin/dashboard/users-usage
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
var req BatchUsersUsageRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
userIDs := normalizeInt64IDList(req.UserIDs)
if len(userIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
keyRaw, _ := json.Marshal(struct {
UserIDs []int64 `json:"user_ids"`
}{
UserIDs: userIDs,
})
cacheKey := string(keyRaw)
if cached, ok := dashboardBatchUsersUsageCache.Get(cacheKey); ok {
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), userIDs, time.Time{}, time.Time{})
if err != nil {
response.Error(c, 500, "Failed to get user usage stats")
return
}
payload := gin.H{"stats": stats}
dashboardBatchUsersUsageCache.Set(cacheKey, payload)
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, payload)
}
// BatchAPIKeysUsageRequest represents the request body for batch api key usage stats
type BatchAPIKeysUsageRequest struct {
APIKeyIDs []int64 `json:"api_key_ids" binding:"required"`
}
// GetBatchAPIKeysUsage handles getting usage stats for multiple API keys
// POST /api/v1/admin/dashboard/api-keys-usage
func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
var req BatchAPIKeysUsageRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
apiKeyIDs := normalizeInt64IDList(req.APIKeyIDs)
if len(apiKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
keyRaw, _ := json.Marshal(struct {
APIKeyIDs []int64 `json:"api_key_ids"`
}{
APIKeyIDs: apiKeyIDs,
})
cacheKey := string(keyRaw)
if cached, ok := dashboardBatchAPIKeysUsageCache.Get(cacheKey); ok {
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), apiKeyIDs, time.Time{}, time.Time{})
if err != nil {
response.Error(c, 500, "Failed to get API key usage stats")
return
}
payload := gin.H{"stats": stats}
dashboardBatchAPIKeysUsageCache.Set(cacheKey, payload)
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, payload)
}

View File

@@ -0,0 +1,118 @@
package admin
import (
"context"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type dashboardUsageRepoCacheProbe struct {
service.UsageLogRepository
trendCalls atomic.Int32
usersTrendCalls atomic.Int32
}
func (r *dashboardUsageRepoCacheProbe) GetUsageTrendWithFilters(
ctx context.Context,
startTime, endTime time.Time,
granularity string,
userID, apiKeyID, accountID, groupID int64,
model string,
requestType *int16,
stream *bool,
billingType *int8,
) ([]usagestats.TrendDataPoint, error) {
r.trendCalls.Add(1)
return []usagestats.TrendDataPoint{{
Date: "2026-03-11",
Requests: 1,
TotalTokens: 2,
Cost: 3,
ActualCost: 4,
}}, nil
}
func (r *dashboardUsageRepoCacheProbe) GetUserUsageTrend(
ctx context.Context,
startTime, endTime time.Time,
granularity string,
limit int,
) ([]usagestats.UserUsageTrendPoint, error) {
r.usersTrendCalls.Add(1)
return []usagestats.UserUsageTrendPoint{{
Date: "2026-03-11",
UserID: 1,
Email: "cache@test.dev",
Requests: 2,
Tokens: 20,
Cost: 2,
ActualCost: 1,
}}, nil
}
func resetDashboardReadCachesForTest() {
dashboardTrendCache = newSnapshotCache(30 * time.Second)
dashboardUsersTrendCache = newSnapshotCache(30 * time.Second)
dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second)
dashboardModelStatsCache = newSnapshotCache(30 * time.Second)
dashboardGroupStatsCache = newSnapshotCache(30 * time.Second)
dashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
}
func TestDashboardHandler_GetUsageTrend_UsesCache(t *testing.T) {
t.Cleanup(resetDashboardReadCachesForTest)
resetDashboardReadCachesForTest()
gin.SetMode(gin.TestMode)
repo := &dashboardUsageRepoCacheProbe{}
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
handler := NewDashboardHandler(dashboardSvc, nil)
router := gin.New()
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil)
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
require.Equal(t, http.StatusOK, rec1.Code)
require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache"))
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
require.Equal(t, http.StatusOK, rec2.Code)
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
require.Equal(t, int32(1), repo.trendCalls.Load())
}
func TestDashboardHandler_GetUserUsageTrend_UsesCache(t *testing.T) {
t.Cleanup(resetDashboardReadCachesForTest)
resetDashboardReadCachesForTest()
gin.SetMode(gin.TestMode)
repo := &dashboardUsageRepoCacheProbe{}
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
handler := NewDashboardHandler(dashboardSvc, nil)
router := gin.New()
router.GET("/admin/dashboard/users-trend", handler.GetUserUsageTrend)
req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil)
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
require.Equal(t, http.StatusOK, rec1.Code)
require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache"))
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
require.Equal(t, http.StatusOK, rec2.Code)
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
require.Equal(t, int32(1), repo.usersTrendCalls.Load())
}

View File

@@ -0,0 +1,179 @@
package admin
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type dashboardUsageRepoCapture struct {
service.UsageLogRepository
trendRequestType *int16
trendStream *bool
modelRequestType *int16
modelStream *bool
rankingLimit int
ranking []usagestats.UserSpendingRankingItem
rankingTotal float64
}
func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
ctx context.Context,
startTime, endTime time.Time,
granularity string,
userID, apiKeyID, accountID, groupID int64,
model string,
requestType *int16,
stream *bool,
billingType *int8,
) ([]usagestats.TrendDataPoint, error) {
s.trendRequestType = requestType
s.trendStream = stream
return []usagestats.TrendDataPoint{}, nil
}
func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters(
ctx context.Context,
startTime, endTime time.Time,
userID, apiKeyID, accountID, groupID int64,
requestType *int16,
stream *bool,
billingType *int8,
) ([]usagestats.ModelStat, error) {
s.modelRequestType = requestType
s.modelStream = stream
return []usagestats.ModelStat{}, nil
}
func (s *dashboardUsageRepoCapture) GetUserSpendingRanking(
ctx context.Context,
startTime, endTime time.Time,
limit int,
) (*usagestats.UserSpendingRankingResponse, error) {
s.rankingLimit = limit
return &usagestats.UserSpendingRankingResponse{
Ranking: s.ranking,
TotalActualCost: s.rankingTotal,
TotalRequests: 44,
TotalTokens: 1234,
}, nil
}
func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
gin.SetMode(gin.TestMode)
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
handler := NewDashboardHandler(dashboardSvc, nil)
router := gin.New()
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
router.GET("/admin/dashboard/models", handler.GetModelStats)
router.GET("/admin/dashboard/users-ranking", handler.GetUserSpendingRanking)
return router
}
func TestDashboardTrendRequestTypePriority(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=ws_v2&stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, repo.trendRequestType)
require.Equal(t, int16(service.RequestTypeWSV2), *repo.trendRequestType)
require.Nil(t, repo.trendStream)
}
func TestDashboardTrendInvalidRequestType(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardTrendInvalidStream(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardModelStatsRequestTypePriority(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=sync&stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, repo.modelRequestType)
require.Equal(t, int16(service.RequestTypeSync), *repo.modelRequestType)
require.Nil(t, repo.modelStream)
}
func TestDashboardModelStatsInvalidRequestType(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardModelStatsInvalidStream(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardUsersRankingLimitAndCache(t *testing.T) {
dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute)
repo := &dashboardUsageRepoCapture{
ranking: []usagestats.UserSpendingRankingItem{
{UserID: 7, Email: "rank@example.com", ActualCost: 10.5, Requests: 3, Tokens: 300},
},
rankingTotal: 88.8,
}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, 50, repo.rankingLimit)
require.Contains(t, rec.Body.String(), "\"total_actual_cost\":88.8")
require.Contains(t, rec.Body.String(), "\"total_requests\":44")
require.Contains(t, rec.Body.String(), "\"total_tokens\":1234")
require.Equal(t, "miss", rec.Header().Get("X-Snapshot-Cache"))
req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-ranking?limit=100&start_date=2025-01-01&end_date=2025-01-02", nil)
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
require.Equal(t, http.StatusOK, rec2.Code)
require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache"))
}

View File

@@ -0,0 +1,200 @@
package admin
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
)
var (
dashboardTrendCache = newSnapshotCache(30 * time.Second)
dashboardModelStatsCache = newSnapshotCache(30 * time.Second)
dashboardGroupStatsCache = newSnapshotCache(30 * time.Second)
dashboardUsersTrendCache = newSnapshotCache(30 * time.Second)
dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second)
)
type dashboardTrendCacheKey struct {
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
Granularity string `json:"granularity"`
UserID int64 `json:"user_id"`
APIKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
GroupID int64 `json:"group_id"`
Model string `json:"model"`
RequestType *int16 `json:"request_type"`
Stream *bool `json:"stream"`
BillingType *int8 `json:"billing_type"`
}
type dashboardModelGroupCacheKey struct {
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
UserID int64 `json:"user_id"`
APIKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
GroupID int64 `json:"group_id"`
RequestType *int16 `json:"request_type"`
Stream *bool `json:"stream"`
BillingType *int8 `json:"billing_type"`
}
type dashboardEntityTrendCacheKey struct {
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
Granularity string `json:"granularity"`
Limit int `json:"limit"`
}
func cacheStatusValue(hit bool) string {
if hit {
return "hit"
}
return "miss"
}
func mustMarshalDashboardCacheKey(value any) string {
raw, err := json.Marshal(value)
if err != nil {
return ""
}
return string(raw)
}
func snapshotPayloadAs[T any](payload any) (T, error) {
typed, ok := payload.(T)
if !ok {
var zero T
return zero, fmt.Errorf("unexpected cache payload type %T", payload)
}
return typed, nil
}
func (h *DashboardHandler) getUsageTrendCached(
ctx context.Context,
startTime, endTime time.Time,
granularity string,
userID, apiKeyID, accountID, groupID int64,
model string,
requestType *int16,
stream *bool,
billingType *int8,
) ([]usagestats.TrendDataPoint, bool, error) {
key := mustMarshalDashboardCacheKey(dashboardTrendCacheKey{
StartTime: startTime.UTC().Format(time.RFC3339),
EndTime: endTime.UTC().Format(time.RFC3339),
Granularity: granularity,
UserID: userID,
APIKeyID: apiKeyID,
AccountID: accountID,
GroupID: groupID,
Model: model,
RequestType: requestType,
Stream: stream,
BillingType: billingType,
})
entry, hit, err := dashboardTrendCache.GetOrLoad(key, func() (any, error) {
return h.dashboardService.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
})
if err != nil {
return nil, hit, err
}
trend, err := snapshotPayloadAs[[]usagestats.TrendDataPoint](entry.Payload)
return trend, hit, err
}
func (h *DashboardHandler) getModelStatsCached(
ctx context.Context,
startTime, endTime time.Time,
userID, apiKeyID, accountID, groupID int64,
requestType *int16,
stream *bool,
billingType *int8,
) ([]usagestats.ModelStat, bool, error) {
key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{
StartTime: startTime.UTC().Format(time.RFC3339),
EndTime: endTime.UTC().Format(time.RFC3339),
UserID: userID,
APIKeyID: apiKeyID,
AccountID: accountID,
GroupID: groupID,
RequestType: requestType,
Stream: stream,
BillingType: billingType,
})
entry, hit, err := dashboardModelStatsCache.GetOrLoad(key, func() (any, error) {
return h.dashboardService.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
})
if err != nil {
return nil, hit, err
}
stats, err := snapshotPayloadAs[[]usagestats.ModelStat](entry.Payload)
return stats, hit, err
}
func (h *DashboardHandler) getGroupStatsCached(
ctx context.Context,
startTime, endTime time.Time,
userID, apiKeyID, accountID, groupID int64,
requestType *int16,
stream *bool,
billingType *int8,
) ([]usagestats.GroupStat, bool, error) {
key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{
StartTime: startTime.UTC().Format(time.RFC3339),
EndTime: endTime.UTC().Format(time.RFC3339),
UserID: userID,
APIKeyID: apiKeyID,
AccountID: accountID,
GroupID: groupID,
RequestType: requestType,
Stream: stream,
BillingType: billingType,
})
entry, hit, err := dashboardGroupStatsCache.GetOrLoad(key, func() (any, error) {
return h.dashboardService.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
})
if err != nil {
return nil, hit, err
}
stats, err := snapshotPayloadAs[[]usagestats.GroupStat](entry.Payload)
return stats, hit, err
}
func (h *DashboardHandler) getAPIKeyUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, bool, error) {
key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{
StartTime: startTime.UTC().Format(time.RFC3339),
EndTime: endTime.UTC().Format(time.RFC3339),
Granularity: granularity,
Limit: limit,
})
entry, hit, err := dashboardAPIKeysTrendCache.GetOrLoad(key, func() (any, error) {
return h.dashboardService.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
})
if err != nil {
return nil, hit, err
}
trend, err := snapshotPayloadAs[[]usagestats.APIKeyUsageTrendPoint](entry.Payload)
return trend, hit, err
}
func (h *DashboardHandler) getUserUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, bool, error) {
key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{
StartTime: startTime.UTC().Format(time.RFC3339),
EndTime: endTime.UTC().Format(time.RFC3339),
Granularity: granularity,
Limit: limit,
})
entry, hit, err := dashboardUsersTrendCache.GetOrLoad(key, func() (any, error) {
return h.dashboardService.GetUserUsageTrend(ctx, startTime, endTime, granularity, limit)
})
if err != nil {
return nil, hit, err
}
trend, err := snapshotPayloadAs[[]usagestats.UserUsageTrendPoint](entry.Payload)
return trend, hit, err
}

View File

@@ -0,0 +1,302 @@
package admin
import (
"context"
"encoding/json"
"errors"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
var dashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
type dashboardSnapshotV2Stats struct {
usagestats.DashboardStats
Uptime int64 `json:"uptime"`
}
type dashboardSnapshotV2Response struct {
GeneratedAt string `json:"generated_at"`
StartDate string `json:"start_date"`
EndDate string `json:"end_date"`
Granularity string `json:"granularity"`
Stats *dashboardSnapshotV2Stats `json:"stats,omitempty"`
Trend []usagestats.TrendDataPoint `json:"trend,omitempty"`
Models []usagestats.ModelStat `json:"models,omitempty"`
Groups []usagestats.GroupStat `json:"groups,omitempty"`
UsersTrend []usagestats.UserUsageTrendPoint `json:"users_trend,omitempty"`
}
type dashboardSnapshotV2Filters struct {
UserID int64
APIKeyID int64
AccountID int64
GroupID int64
Model string
RequestType *int16
Stream *bool
BillingType *int8
}
type dashboardSnapshotV2CacheKey struct {
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
Granularity string `json:"granularity"`
UserID int64 `json:"user_id"`
APIKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
GroupID int64 `json:"group_id"`
Model string `json:"model"`
RequestType *int16 `json:"request_type"`
Stream *bool `json:"stream"`
BillingType *int8 `json:"billing_type"`
IncludeStats bool `json:"include_stats"`
IncludeTrend bool `json:"include_trend"`
IncludeModels bool `json:"include_models"`
IncludeGroups bool `json:"include_groups"`
IncludeUsersTrend bool `json:"include_users_trend"`
UsersTrendLimit int `json:"users_trend_limit"`
}
func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := strings.TrimSpace(c.DefaultQuery("granularity", "day"))
if granularity != "hour" {
granularity = "day"
}
includeStats := parseBoolQueryWithDefault(c.Query("include_stats"), true)
includeTrend := parseBoolQueryWithDefault(c.Query("include_trend"), true)
includeModels := parseBoolQueryWithDefault(c.Query("include_model_stats"), true)
includeGroups := parseBoolQueryWithDefault(c.Query("include_group_stats"), false)
includeUsersTrend := parseBoolQueryWithDefault(c.Query("include_users_trend"), false)
usersTrendLimit := 12
if raw := strings.TrimSpace(c.Query("users_trend_limit")); raw != "" {
if parsed, err := strconv.Atoi(raw); err == nil && parsed > 0 && parsed <= 50 {
usersTrendLimit = parsed
}
}
filters, err := parseDashboardSnapshotV2Filters(c)
if err != nil {
response.BadRequest(c, err.Error())
return
}
keyRaw, _ := json.Marshal(dashboardSnapshotV2CacheKey{
StartTime: startTime.UTC().Format(time.RFC3339),
EndTime: endTime.UTC().Format(time.RFC3339),
Granularity: granularity,
UserID: filters.UserID,
APIKeyID: filters.APIKeyID,
AccountID: filters.AccountID,
GroupID: filters.GroupID,
Model: filters.Model,
RequestType: filters.RequestType,
Stream: filters.Stream,
BillingType: filters.BillingType,
IncludeStats: includeStats,
IncludeTrend: includeTrend,
IncludeModels: includeModels,
IncludeGroups: includeGroups,
IncludeUsersTrend: includeUsersTrend,
UsersTrendLimit: usersTrendLimit,
})
cacheKey := string(keyRaw)
cached, hit, err := dashboardSnapshotV2Cache.GetOrLoad(cacheKey, func() (any, error) {
return h.buildSnapshotV2Response(
c.Request.Context(),
startTime,
endTime,
granularity,
filters,
includeStats,
includeTrend,
includeModels,
includeGroups,
includeUsersTrend,
usersTrendLimit,
)
})
if err != nil {
response.Error(c, 500, err.Error())
return
}
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
c.Status(http.StatusNotModified)
return
}
}
c.Header("X-Snapshot-Cache", cacheStatusValue(hit))
response.Success(c, cached.Payload)
}
func (h *DashboardHandler) buildSnapshotV2Response(
ctx context.Context,
startTime, endTime time.Time,
granularity string,
filters *dashboardSnapshotV2Filters,
includeStats, includeTrend, includeModels, includeGroups, includeUsersTrend bool,
usersTrendLimit int,
) (*dashboardSnapshotV2Response, error) {
resp := &dashboardSnapshotV2Response{
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
StartDate: startTime.Format("2006-01-02"),
EndDate: endTime.Add(-24 * time.Hour).Format("2006-01-02"),
Granularity: granularity,
}
if includeStats {
stats, err := h.dashboardService.GetDashboardStats(ctx)
if err != nil {
return nil, errors.New("failed to get dashboard statistics")
}
resp.Stats = &dashboardSnapshotV2Stats{
DashboardStats: *stats,
Uptime: int64(time.Since(h.startTime).Seconds()),
}
}
if includeTrend {
trend, _, err := h.getUsageTrendCached(
ctx,
startTime,
endTime,
granularity,
filters.UserID,
filters.APIKeyID,
filters.AccountID,
filters.GroupID,
filters.Model,
filters.RequestType,
filters.Stream,
filters.BillingType,
)
if err != nil {
return nil, errors.New("failed to get usage trend")
}
resp.Trend = trend
}
if includeModels {
models, _, err := h.getModelStatsCached(
ctx,
startTime,
endTime,
filters.UserID,
filters.APIKeyID,
filters.AccountID,
filters.GroupID,
filters.RequestType,
filters.Stream,
filters.BillingType,
)
if err != nil {
return nil, errors.New("failed to get model statistics")
}
resp.Models = models
}
if includeGroups {
groups, _, err := h.getGroupStatsCached(
ctx,
startTime,
endTime,
filters.UserID,
filters.APIKeyID,
filters.AccountID,
filters.GroupID,
filters.RequestType,
filters.Stream,
filters.BillingType,
)
if err != nil {
return nil, errors.New("failed to get group statistics")
}
resp.Groups = groups
}
if includeUsersTrend {
usersTrend, _, err := h.getUserUsageTrendCached(ctx, startTime, endTime, granularity, usersTrendLimit)
if err != nil {
return nil, errors.New("failed to get user usage trend")
}
resp.UsersTrend = usersTrend
}
return resp, nil
}
func parseDashboardSnapshotV2Filters(c *gin.Context) (*dashboardSnapshotV2Filters, error) {
filters := &dashboardSnapshotV2Filters{
Model: strings.TrimSpace(c.Query("model")),
}
if userIDStr := strings.TrimSpace(c.Query("user_id")); userIDStr != "" {
id, err := strconv.ParseInt(userIDStr, 10, 64)
if err != nil {
return nil, err
}
filters.UserID = id
}
if apiKeyIDStr := strings.TrimSpace(c.Query("api_key_id")); apiKeyIDStr != "" {
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
if err != nil {
return nil, err
}
filters.APIKeyID = id
}
if accountIDStr := strings.TrimSpace(c.Query("account_id")); accountIDStr != "" {
id, err := strconv.ParseInt(accountIDStr, 10, 64)
if err != nil {
return nil, err
}
filters.AccountID = id
}
if groupIDStr := strings.TrimSpace(c.Query("group_id")); groupIDStr != "" {
id, err := strconv.ParseInt(groupIDStr, 10, 64)
if err != nil {
return nil, err
}
filters.GroupID = id
}
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
return nil, err
}
value := int16(parsed)
filters.RequestType = &value
} else if streamStr := strings.TrimSpace(c.Query("stream")); streamStr != "" {
streamVal, err := strconv.ParseBool(streamStr)
if err != nil {
return nil, err
}
filters.Stream = &streamVal
}
if billingTypeStr := strings.TrimSpace(c.Query("billing_type")); billingTypeStr != "" {
v, err := strconv.ParseInt(billingTypeStr, 10, 8)
if err != nil {
return nil, err
}
bt := int8(v)
filters.BillingType = &bt
}
return filters, nil
}

View File

@@ -0,0 +1,545 @@
package admin
import (
"context"
"strconv"
"strings"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type DataManagementHandler struct {
dataManagementService dataManagementService
}
func NewDataManagementHandler(dataManagementService *service.DataManagementService) *DataManagementHandler {
return &DataManagementHandler{dataManagementService: dataManagementService}
}
type dataManagementService interface {
GetConfig(ctx context.Context) (service.DataManagementConfig, error)
UpdateConfig(ctx context.Context, cfg service.DataManagementConfig) (service.DataManagementConfig, error)
ValidateS3(ctx context.Context, cfg service.DataManagementS3Config) (service.DataManagementTestS3Result, error)
CreateBackupJob(ctx context.Context, input service.DataManagementCreateBackupJobInput) (service.DataManagementBackupJob, error)
ListSourceProfiles(ctx context.Context, sourceType string) ([]service.DataManagementSourceProfile, error)
CreateSourceProfile(ctx context.Context, input service.DataManagementCreateSourceProfileInput) (service.DataManagementSourceProfile, error)
UpdateSourceProfile(ctx context.Context, input service.DataManagementUpdateSourceProfileInput) (service.DataManagementSourceProfile, error)
DeleteSourceProfile(ctx context.Context, sourceType, profileID string) error
SetActiveSourceProfile(ctx context.Context, sourceType, profileID string) (service.DataManagementSourceProfile, error)
ListS3Profiles(ctx context.Context) ([]service.DataManagementS3Profile, error)
CreateS3Profile(ctx context.Context, input service.DataManagementCreateS3ProfileInput) (service.DataManagementS3Profile, error)
UpdateS3Profile(ctx context.Context, input service.DataManagementUpdateS3ProfileInput) (service.DataManagementS3Profile, error)
DeleteS3Profile(ctx context.Context, profileID string) error
SetActiveS3Profile(ctx context.Context, profileID string) (service.DataManagementS3Profile, error)
ListBackupJobs(ctx context.Context, input service.DataManagementListBackupJobsInput) (service.DataManagementListBackupJobsResult, error)
GetBackupJob(ctx context.Context, jobID string) (service.DataManagementBackupJob, error)
EnsureAgentEnabled(ctx context.Context) error
GetAgentHealth(ctx context.Context) service.DataManagementAgentHealth
}
type TestS3ConnectionRequest struct {
Endpoint string `json:"endpoint"`
Region string `json:"region" binding:"required"`
Bucket string `json:"bucket" binding:"required"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
UseSSL bool `json:"use_ssl"`
}
type CreateBackupJobRequest struct {
BackupType string `json:"backup_type" binding:"required,oneof=postgres redis full"`
UploadToS3 bool `json:"upload_to_s3"`
S3ProfileID string `json:"s3_profile_id"`
PostgresID string `json:"postgres_profile_id"`
RedisID string `json:"redis_profile_id"`
IdempotencyKey string `json:"idempotency_key"`
}
type CreateSourceProfileRequest struct {
ProfileID string `json:"profile_id" binding:"required"`
Name string `json:"name" binding:"required"`
Config service.DataManagementSourceConfig `json:"config" binding:"required"`
SetActive bool `json:"set_active"`
}
type UpdateSourceProfileRequest struct {
Name string `json:"name" binding:"required"`
Config service.DataManagementSourceConfig `json:"config" binding:"required"`
}
type CreateS3ProfileRequest struct {
ProfileID string `json:"profile_id" binding:"required"`
Name string `json:"name" binding:"required"`
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
UseSSL bool `json:"use_ssl"`
SetActive bool `json:"set_active"`
}
type UpdateS3ProfileRequest struct {
Name string `json:"name" binding:"required"`
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
UseSSL bool `json:"use_ssl"`
}
func (h *DataManagementHandler) GetAgentHealth(c *gin.Context) {
health := h.getAgentHealth(c)
payload := gin.H{
"enabled": health.Enabled,
"reason": health.Reason,
"socket_path": health.SocketPath,
}
if health.Agent != nil {
payload["agent"] = gin.H{
"status": health.Agent.Status,
"version": health.Agent.Version,
"uptime_seconds": health.Agent.UptimeSeconds,
}
}
response.Success(c, payload)
}
func (h *DataManagementHandler) GetConfig(c *gin.Context) {
if !h.requireAgentEnabled(c) {
return
}
cfg, err := h.dataManagementService.GetConfig(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
func (h *DataManagementHandler) UpdateConfig(c *gin.Context) {
var req service.DataManagementConfig
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
cfg, err := h.dataManagementService.UpdateConfig(c.Request.Context(), req)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
func (h *DataManagementHandler) TestS3(c *gin.Context) {
var req TestS3ConnectionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
result, err := h.dataManagementService.ValidateS3(c.Request.Context(), service.DataManagementS3Config{
Enabled: true,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
UseSSL: req.UseSSL,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"ok": result.OK, "message": result.Message})
}
func (h *DataManagementHandler) CreateBackupJob(c *gin.Context) {
var req CreateBackupJobRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
req.IdempotencyKey = normalizeBackupIdempotencyKey(c.GetHeader("X-Idempotency-Key"), req.IdempotencyKey)
if !h.requireAgentEnabled(c) {
return
}
triggeredBy := "admin:unknown"
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
triggeredBy = "admin:" + strconv.FormatInt(subject.UserID, 10)
}
job, err := h.dataManagementService.CreateBackupJob(c.Request.Context(), service.DataManagementCreateBackupJobInput{
BackupType: req.BackupType,
UploadToS3: req.UploadToS3,
S3ProfileID: req.S3ProfileID,
PostgresID: req.PostgresID,
RedisID: req.RedisID,
TriggeredBy: triggeredBy,
IdempotencyKey: req.IdempotencyKey,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"job_id": job.JobID, "status": job.Status})
}
func (h *DataManagementHandler) ListSourceProfiles(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType == "" {
response.BadRequest(c, "Invalid source_type")
return
}
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
if !h.requireAgentEnabled(c) {
return
}
items, err := h.dataManagementService.ListSourceProfiles(c.Request.Context(), sourceType)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"items": items})
}
func (h *DataManagementHandler) CreateSourceProfile(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
var req CreateSourceProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.CreateSourceProfile(c.Request.Context(), service.DataManagementCreateSourceProfileInput{
SourceType: sourceType,
ProfileID: req.ProfileID,
Name: req.Name,
Config: req.Config,
SetActive: req.SetActive,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) UpdateSourceProfile(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
var req UpdateSourceProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.UpdateSourceProfile(c.Request.Context(), service.DataManagementUpdateSourceProfileInput{
SourceType: sourceType,
ProfileID: profileID,
Name: req.Name,
Config: req.Config,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) DeleteSourceProfile(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
if err := h.dataManagementService.DeleteSourceProfile(c.Request.Context(), sourceType, profileID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"deleted": true})
}
func (h *DataManagementHandler) SetActiveSourceProfile(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.SetActiveSourceProfile(c.Request.Context(), sourceType, profileID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) ListS3Profiles(c *gin.Context) {
if !h.requireAgentEnabled(c) {
return
}
items, err := h.dataManagementService.ListS3Profiles(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"items": items})
}
func (h *DataManagementHandler) CreateS3Profile(c *gin.Context) {
var req CreateS3ProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.CreateS3Profile(c.Request.Context(), service.DataManagementCreateS3ProfileInput{
ProfileID: req.ProfileID,
Name: req.Name,
SetActive: req.SetActive,
S3: service.DataManagementS3Config{
Enabled: req.Enabled,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
UseSSL: req.UseSSL,
},
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) UpdateS3Profile(c *gin.Context) {
var req UpdateS3ProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.UpdateS3Profile(c.Request.Context(), service.DataManagementUpdateS3ProfileInput{
ProfileID: profileID,
Name: req.Name,
S3: service.DataManagementS3Config{
Enabled: req.Enabled,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
UseSSL: req.UseSSL,
},
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) DeleteS3Profile(c *gin.Context) {
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
if err := h.dataManagementService.DeleteS3Profile(c.Request.Context(), profileID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"deleted": true})
}
func (h *DataManagementHandler) SetActiveS3Profile(c *gin.Context) {
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.SetActiveS3Profile(c.Request.Context(), profileID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) ListBackupJobs(c *gin.Context) {
if !h.requireAgentEnabled(c) {
return
}
pageSize := int32(20)
if raw := strings.TrimSpace(c.Query("page_size")); raw != "" {
v, err := strconv.Atoi(raw)
if err != nil || v <= 0 {
response.BadRequest(c, "Invalid page_size")
return
}
pageSize = int32(v)
}
result, err := h.dataManagementService.ListBackupJobs(c.Request.Context(), service.DataManagementListBackupJobsInput{
PageSize: pageSize,
PageToken: c.Query("page_token"),
Status: c.Query("status"),
BackupType: c.Query("backup_type"),
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
func (h *DataManagementHandler) GetBackupJob(c *gin.Context) {
jobID := strings.TrimSpace(c.Param("job_id"))
if jobID == "" {
response.BadRequest(c, "Invalid backup job ID")
return
}
if !h.requireAgentEnabled(c) {
return
}
job, err := h.dataManagementService.GetBackupJob(c.Request.Context(), jobID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, job)
}
func (h *DataManagementHandler) requireAgentEnabled(c *gin.Context) bool {
if h.dataManagementService == nil {
err := infraerrors.ServiceUnavailable(
service.DataManagementAgentUnavailableReason,
"data management agent service is not configured",
).WithMetadata(map[string]string{"socket_path": service.DefaultDataManagementAgentSocketPath})
response.ErrorFrom(c, err)
return false
}
if err := h.dataManagementService.EnsureAgentEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return false
}
return true
}
func (h *DataManagementHandler) getAgentHealth(c *gin.Context) service.DataManagementAgentHealth {
if h.dataManagementService == nil {
return service.DataManagementAgentHealth{
Enabled: false,
Reason: service.DataManagementAgentUnavailableReason,
SocketPath: service.DefaultDataManagementAgentSocketPath,
}
}
return h.dataManagementService.GetAgentHealth(c.Request.Context())
}
func normalizeBackupIdempotencyKey(headerValue, bodyValue string) string {
headerKey := strings.TrimSpace(headerValue)
if headerKey != "" {
return headerKey
}
return strings.TrimSpace(bodyValue)
}

View File

@@ -0,0 +1,78 @@
package admin
import (
"encoding/json"
"net/http"
"net/http/httptest"
"path/filepath"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type apiEnvelope struct {
Code int `json:"code"`
Message string `json:"message"`
Reason string `json:"reason"`
Data json.RawMessage `json:"data"`
}
func TestDataManagementHandler_AgentHealthAlways200(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond)
h := NewDataManagementHandler(svc)
r := gin.New()
r.GET("/api/v1/admin/data-management/agent/health", h.GetAgentHealth)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/agent/health", nil)
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var envelope apiEnvelope
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope))
require.Equal(t, 0, envelope.Code)
var data struct {
Enabled bool `json:"enabled"`
Reason string `json:"reason"`
SocketPath string `json:"socket_path"`
}
require.NoError(t, json.Unmarshal(envelope.Data, &data))
require.False(t, data.Enabled)
require.Equal(t, service.DataManagementDeprecatedReason, data.Reason)
require.Equal(t, svc.SocketPath(), data.SocketPath)
}
func TestDataManagementHandler_NonHealthRouteReturns503WhenDisabled(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond)
h := NewDataManagementHandler(svc)
r := gin.New()
r.GET("/api/v1/admin/data-management/config", h.GetConfig)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/config", nil)
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
var envelope apiEnvelope
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope))
require.Equal(t, http.StatusServiceUnavailable, envelope.Code)
require.Equal(t, service.DataManagementDeprecatedReason, envelope.Reason)
}
func TestNormalizeBackupIdempotencyKey(t *testing.T) {
require.Equal(t, "from-header", normalizeBackupIdempotencyKey("from-header", "from-body"))
require.Equal(t, "from-body", normalizeBackupIdempotencyKey(" ", " from-body "))
require.Equal(t, "", normalizeBackupIdempotencyKey("", ""))
}

View File

@@ -0,0 +1,282 @@
package admin
import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ErrorPassthroughHandler 处理错误透传规则的 HTTP 请求
type ErrorPassthroughHandler struct {
service *service.ErrorPassthroughService
}
// NewErrorPassthroughHandler 创建错误透传规则处理器
func NewErrorPassthroughHandler(service *service.ErrorPassthroughService) *ErrorPassthroughHandler {
return &ErrorPassthroughHandler{service: service}
}
// CreateErrorPassthroughRuleRequest 创建规则请求
type CreateErrorPassthroughRuleRequest struct {
Name string `json:"name" binding:"required"`
Enabled *bool `json:"enabled"`
Priority int `json:"priority"`
ErrorCodes []int `json:"error_codes"`
Keywords []string `json:"keywords"`
MatchMode string `json:"match_mode"`
Platforms []string `json:"platforms"`
PassthroughCode *bool `json:"passthrough_code"`
ResponseCode *int `json:"response_code"`
PassthroughBody *bool `json:"passthrough_body"`
CustomMessage *string `json:"custom_message"`
SkipMonitoring *bool `json:"skip_monitoring"`
Description *string `json:"description"`
}
// UpdateErrorPassthroughRuleRequest 更新规则请求(部分更新,所有字段可选)
type UpdateErrorPassthroughRuleRequest struct {
Name *string `json:"name"`
Enabled *bool `json:"enabled"`
Priority *int `json:"priority"`
ErrorCodes []int `json:"error_codes"`
Keywords []string `json:"keywords"`
MatchMode *string `json:"match_mode"`
Platforms []string `json:"platforms"`
PassthroughCode *bool `json:"passthrough_code"`
ResponseCode *int `json:"response_code"`
PassthroughBody *bool `json:"passthrough_body"`
CustomMessage *string `json:"custom_message"`
SkipMonitoring *bool `json:"skip_monitoring"`
Description *string `json:"description"`
}
// List 获取所有规则
// GET /api/v1/admin/error-passthrough-rules
func (h *ErrorPassthroughHandler) List(c *gin.Context) {
rules, err := h.service.List(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, rules)
}
// GetByID 根据 ID 获取规则
// GET /api/v1/admin/error-passthrough-rules/:id
func (h *ErrorPassthroughHandler) GetByID(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid rule ID")
return
}
rule, err := h.service.GetByID(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
if rule == nil {
response.NotFound(c, "Rule not found")
return
}
response.Success(c, rule)
}
// Create 创建规则
// POST /api/v1/admin/error-passthrough-rules
func (h *ErrorPassthroughHandler) Create(c *gin.Context) {
var req CreateErrorPassthroughRuleRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
rule := &model.ErrorPassthroughRule{
Name: req.Name,
Priority: req.Priority,
ErrorCodes: req.ErrorCodes,
Keywords: req.Keywords,
Platforms: req.Platforms,
}
// 设置默认值
if req.Enabled != nil {
rule.Enabled = *req.Enabled
} else {
rule.Enabled = true
}
if req.MatchMode != "" {
rule.MatchMode = req.MatchMode
} else {
rule.MatchMode = model.MatchModeAny
}
if req.PassthroughCode != nil {
rule.PassthroughCode = *req.PassthroughCode
} else {
rule.PassthroughCode = true
}
if req.PassthroughBody != nil {
rule.PassthroughBody = *req.PassthroughBody
} else {
rule.PassthroughBody = true
}
if req.SkipMonitoring != nil {
rule.SkipMonitoring = *req.SkipMonitoring
}
rule.ResponseCode = req.ResponseCode
rule.CustomMessage = req.CustomMessage
rule.Description = req.Description
// 确保切片不为 nil
if rule.ErrorCodes == nil {
rule.ErrorCodes = []int{}
}
if rule.Keywords == nil {
rule.Keywords = []string{}
}
if rule.Platforms == nil {
rule.Platforms = []string{}
}
created, err := h.service.Create(c.Request.Context(), rule)
if err != nil {
if _, ok := err.(*model.ValidationError); ok {
response.BadRequest(c, err.Error())
return
}
response.ErrorFrom(c, err)
return
}
response.Success(c, created)
}
// Update 更新规则(支持部分更新)
// PUT /api/v1/admin/error-passthrough-rules/:id
func (h *ErrorPassthroughHandler) Update(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid rule ID")
return
}
var req UpdateErrorPassthroughRuleRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// 先获取现有规则
existing, err := h.service.GetByID(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
if existing == nil {
response.NotFound(c, "Rule not found")
return
}
// 部分更新:只更新请求中提供的字段
rule := &model.ErrorPassthroughRule{
ID: id,
Name: existing.Name,
Enabled: existing.Enabled,
Priority: existing.Priority,
ErrorCodes: existing.ErrorCodes,
Keywords: existing.Keywords,
MatchMode: existing.MatchMode,
Platforms: existing.Platforms,
PassthroughCode: existing.PassthroughCode,
ResponseCode: existing.ResponseCode,
PassthroughBody: existing.PassthroughBody,
CustomMessage: existing.CustomMessage,
SkipMonitoring: existing.SkipMonitoring,
Description: existing.Description,
}
// 应用请求中提供的更新
if req.Name != nil {
rule.Name = *req.Name
}
if req.Enabled != nil {
rule.Enabled = *req.Enabled
}
if req.Priority != nil {
rule.Priority = *req.Priority
}
if req.ErrorCodes != nil {
rule.ErrorCodes = req.ErrorCodes
}
if req.Keywords != nil {
rule.Keywords = req.Keywords
}
if req.MatchMode != nil {
rule.MatchMode = *req.MatchMode
}
if req.Platforms != nil {
rule.Platforms = req.Platforms
}
if req.PassthroughCode != nil {
rule.PassthroughCode = *req.PassthroughCode
}
if req.ResponseCode != nil {
rule.ResponseCode = req.ResponseCode
}
if req.PassthroughBody != nil {
rule.PassthroughBody = *req.PassthroughBody
}
if req.CustomMessage != nil {
rule.CustomMessage = req.CustomMessage
}
if req.Description != nil {
rule.Description = req.Description
}
if req.SkipMonitoring != nil {
rule.SkipMonitoring = *req.SkipMonitoring
}
// 确保切片不为 nil
if rule.ErrorCodes == nil {
rule.ErrorCodes = []int{}
}
if rule.Keywords == nil {
rule.Keywords = []string{}
}
if rule.Platforms == nil {
rule.Platforms = []string{}
}
updated, err := h.service.Update(c.Request.Context(), rule)
if err != nil {
if _, ok := err.(*model.ValidationError); ok {
response.BadRequest(c, err.Error())
return
}
response.ErrorFrom(c, err)
return
}
response.Success(c, updated)
}
// Delete 删除规则
// DELETE /api/v1/admin/error-passthrough-rules/:id
func (h *ErrorPassthroughHandler) Delete(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid rule ID")
return
}
if err := h.service.Delete(c.Request.Context(), id); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Rule deleted successfully"})
}

View File

@@ -0,0 +1,146 @@
package admin
import (
"fmt"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type GeminiOAuthHandler struct {
geminiOAuthService *service.GeminiOAuthService
}
func NewGeminiOAuthHandler(geminiOAuthService *service.GeminiOAuthService) *GeminiOAuthHandler {
return &GeminiOAuthHandler{geminiOAuthService: geminiOAuthService}
}
// GetCapabilities returns the Gemini OAuth configuration capabilities.
// GET /api/v1/admin/gemini/oauth/capabilities
func (h *GeminiOAuthHandler) GetCapabilities(c *gin.Context) {
cfg := h.geminiOAuthService.GetOAuthConfig()
response.Success(c, cfg)
}
type GeminiGenerateAuthURLRequest struct {
ProxyID *int64 `json:"proxy_id"`
ProjectID string `json:"project_id"`
// OAuth 类型: "code_assist" (需要 project_id) 或 "ai_studio" (不需要 project_id)
// 默认为 "code_assist" 以保持向后兼容
OAuthType string `json:"oauth_type"`
// TierID is a user-selected tier to be used when auto detection is unavailable or fails.
TierID string `json:"tier_id"`
}
// GenerateAuthURL generates Google OAuth authorization URL for Gemini.
// POST /api/v1/admin/gemini/oauth/auth-url
func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) {
var req GeminiGenerateAuthURLRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// 默认使用 code_assist 以保持向后兼容
oauthType := strings.TrimSpace(req.OAuthType)
if oauthType == "" {
oauthType = "code_assist"
}
if oauthType != "code_assist" && oauthType != "google_one" && oauthType != "ai_studio" {
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist', 'google_one', or 'ai_studio'")
return
}
// Always pass the "hosted" callback URI; the OAuth service may override it depending on
// oauth_type and whether the built-in Gemini CLI OAuth client is used.
redirectURI := deriveGeminiRedirectURI(c)
result, err := h.geminiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, redirectURI, req.ProjectID, oauthType, req.TierID)
if err != nil {
msg := err.Error()
// Treat missing/invalid OAuth client configuration as a user/config error.
if strings.Contains(msg, "OAuth client not configured") ||
strings.Contains(msg, "requires your own OAuth Client") ||
strings.Contains(msg, "requires a custom OAuth Client") ||
strings.Contains(msg, "GEMINI_CLI_OAUTH_CLIENT_SECRET_MISSING") ||
strings.Contains(msg, "built-in Gemini CLI OAuth client_secret is not configured") {
response.BadRequest(c, "Failed to generate auth URL: "+msg)
return
}
response.InternalError(c, "Failed to generate auth URL: "+msg)
return
}
response.Success(c, result)
}
type GeminiExchangeCodeRequest struct {
SessionID string `json:"session_id" binding:"required"`
State string `json:"state" binding:"required"`
Code string `json:"code" binding:"required"`
ProxyID *int64 `json:"proxy_id"`
// OAuth 类型: "code_assist" 或 "ai_studio",需要与 GenerateAuthURL 时的类型一致
OAuthType string `json:"oauth_type"`
// TierID is a user-selected tier to be used when auto detection is unavailable or fails.
// This field is optional; when omitted, the server uses the tier stored in the OAuth session.
TierID string `json:"tier_id"`
}
// ExchangeCode exchanges authorization code for tokens.
// POST /api/v1/admin/gemini/oauth/exchange-code
func (h *GeminiOAuthHandler) ExchangeCode(c *gin.Context) {
var req GeminiExchangeCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// 默认使用 code_assist 以保持向后兼容
oauthType := strings.TrimSpace(req.OAuthType)
if oauthType == "" {
oauthType = "code_assist"
}
if oauthType != "code_assist" && oauthType != "google_one" && oauthType != "ai_studio" {
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist', 'google_one', or 'ai_studio'")
return
}
tokenInfo, err := h.geminiOAuthService.ExchangeCode(c.Request.Context(), &service.GeminiExchangeCodeInput{
SessionID: req.SessionID,
State: req.State,
Code: req.Code,
ProxyID: req.ProxyID,
OAuthType: oauthType,
TierID: req.TierID,
})
if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error())
return
}
response.Success(c, tokenInfo)
}
func deriveGeminiRedirectURI(c *gin.Context) string {
origin := strings.TrimSpace(c.GetHeader("Origin"))
if origin != "" {
return strings.TrimRight(origin, "/") + "/auth/callback"
}
scheme := "http"
if c.Request.TLS != nil {
scheme = "https"
}
if xfProto := strings.TrimSpace(c.GetHeader("X-Forwarded-Proto")); xfProto != "" {
scheme = strings.TrimSpace(strings.Split(xfProto, ",")[0])
}
host := strings.TrimSpace(c.Request.Host)
if xfHost := strings.TrimSpace(c.GetHeader("X-Forwarded-Host")); xfHost != "" {
host = strings.TrimSpace(strings.Split(xfHost, ",")[0])
}
return fmt.Sprintf("%s://%s/auth/callback", scheme, host)
}

View File

@@ -0,0 +1,487 @@
package admin
import (
"bytes"
"encoding/json"
"fmt"
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// GroupHandler handles admin group management
type GroupHandler struct {
adminService service.AdminService
}
type optionalLimitField struct {
set bool
value *float64
}
func (f *optionalLimitField) UnmarshalJSON(data []byte) error {
f.set = true
trimmed := bytes.TrimSpace(data)
if bytes.Equal(trimmed, []byte("null")) {
f.value = nil
return nil
}
var number float64
if err := json.Unmarshal(trimmed, &number); err == nil {
f.value = &number
return nil
}
var text string
if err := json.Unmarshal(trimmed, &text); err == nil {
text = strings.TrimSpace(text)
if text == "" {
f.value = nil
return nil
}
number, err = strconv.ParseFloat(text, 64)
if err != nil {
return fmt.Errorf("invalid numeric limit value %q: %w", text, err)
}
f.value = &number
return nil
}
return fmt.Errorf("invalid limit value: %s", string(trimmed))
}
func (f optionalLimitField) ToServiceInput() *float64 {
if !f.set {
return nil
}
if f.value != nil {
return f.value
}
zero := 0.0
return &zero
}
// NewGroupHandler creates a new admin group handler
func NewGroupHandler(adminService service.AdminService) *GroupHandler {
return &GroupHandler{
adminService: adminService,
}
}
// CreateGroupRequest represents create group request
type CreateGroupRequest struct {
Name string `json:"name" binding:"required"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
DailyLimitUSD optionalLimitField `json:"daily_limit_usd"`
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
// 图片生成计费配置antigravity 和 gemini 平台使用,负数表示清除配置)
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
MCPXMLInject *bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"`
// Sora 存储配额
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
DefaultMappedModel string `json:"default_mapped_model"`
// 从指定分组复制账号(创建后自动绑定)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
// UpdateGroupRequest represents update group request
type UpdateGroupRequest struct {
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier *float64 `json:"rate_multiplier"`
IsExclusive *bool `json:"is_exclusive"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
DailyLimitUSD optionalLimitField `json:"daily_limit_usd"`
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
// 图片生成计费配置antigravity 和 gemini 平台使用,负数表示清除配置)
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
ClaudeCodeOnly *bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64 `json:"model_routing"`
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
MCPXMLInject *bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string `json:"supported_model_scopes"`
// Sora 存储配额
SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch *bool `json:"allow_messages_dispatch"`
DefaultMappedModel *string `json:"default_mapped_model"`
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
// List handles listing all groups with pagination
// GET /api/v1/admin/groups
func (h *GroupHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
platform := c.Query("platform")
status := c.Query("status")
search := c.Query("search")
// 标准化和验证 search 参数
search = strings.TrimSpace(search)
if len(search) > 100 {
search = search[:100]
}
isExclusiveStr := c.Query("is_exclusive")
var isExclusive *bool
if isExclusiveStr != "" {
val := isExclusiveStr == "true"
isExclusive = &val
}
groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, search, isExclusive)
if err != nil {
response.ErrorFrom(c, err)
return
}
outGroups := make([]dto.AdminGroup, 0, len(groups))
for i := range groups {
outGroups = append(outGroups, *dto.GroupFromServiceAdmin(&groups[i]))
}
response.Paginated(c, outGroups, total, page, pageSize)
}
// GetAll handles getting all active groups without pagination
// GET /api/v1/admin/groups/all
func (h *GroupHandler) GetAll(c *gin.Context) {
platform := c.Query("platform")
var groups []service.Group
var err error
if platform != "" {
groups, err = h.adminService.GetAllGroupsByPlatform(c.Request.Context(), platform)
} else {
groups, err = h.adminService.GetAllGroups(c.Request.Context())
}
if err != nil {
response.ErrorFrom(c, err)
return
}
outGroups := make([]dto.AdminGroup, 0, len(groups))
for i := range groups {
outGroups = append(outGroups, *dto.GroupFromServiceAdmin(&groups[i]))
}
response.Success(c, outGroups)
}
// GetByID handles getting a group by ID
// GET /api/v1/admin/groups/:id
func (h *GroupHandler) GetByID(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
group, err := h.adminService.GetGroup(c.Request.Context(), groupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.GroupFromServiceAdmin(group))
}
// Create handles creating a new group
// POST /api/v1/admin/groups
func (h *GroupHandler) Create(c *gin.Context) {
var req CreateGroupRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{
Name: req.Name,
Description: req.Description,
Platform: req.Platform,
RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive,
SubscriptionType: req.SubscriptionType,
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
SoraImagePrice360: req.SoraImagePrice360,
SoraImagePrice540: req.SoraImagePrice540,
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
ModelRouting: req.ModelRouting,
ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject,
SupportedModelScopes: req.SupportedModelScopes,
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
AllowMessagesDispatch: req.AllowMessagesDispatch,
DefaultMappedModel: req.DefaultMappedModel,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.GroupFromServiceAdmin(group))
}
// Update handles updating a group
// PUT /api/v1/admin/groups/:id
func (h *GroupHandler) Update(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
var req UpdateGroupRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{
Name: req.Name,
Description: req.Description,
Platform: req.Platform,
RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive,
Status: req.Status,
SubscriptionType: req.SubscriptionType,
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
SoraImagePrice360: req.SoraImagePrice360,
SoraImagePrice540: req.SoraImagePrice540,
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
ModelRouting: req.ModelRouting,
ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject,
SupportedModelScopes: req.SupportedModelScopes,
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
AllowMessagesDispatch: req.AllowMessagesDispatch,
DefaultMappedModel: req.DefaultMappedModel,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.GroupFromServiceAdmin(group))
}
// Delete handles deleting a group
// DELETE /api/v1/admin/groups/:id
func (h *GroupHandler) Delete(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
err = h.adminService.DeleteGroup(c.Request.Context(), groupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Group deleted successfully"})
}
// GetStats handles getting group statistics
// GET /api/v1/admin/groups/:id/stats
func (h *GroupHandler) GetStats(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
// Return mock data for now
response.Success(c, gin.H{
"total_api_keys": 0,
"active_api_keys": 0,
"total_requests": 0,
"total_cost": 0.0,
})
_ = groupID // TODO: implement actual stats
}
// GetGroupAPIKeys handles getting API keys in a group
// GET /api/v1/admin/groups/:id/api-keys
func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
page, pageSize := response.ParsePagination(c)
keys, total, err := h.adminService.GetGroupAPIKeys(c.Request.Context(), groupID, page, pageSize)
if err != nil {
response.ErrorFrom(c, err)
return
}
outKeys := make([]dto.APIKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *dto.APIKeyFromService(&keys[i]))
}
response.Paginated(c, outKeys, total, page, pageSize)
}
// GetGroupRateMultipliers handles getting rate multipliers for users in a group
// GET /api/v1/admin/groups/:id/rate-multipliers
func (h *GroupHandler) GetGroupRateMultipliers(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
entries, err := h.adminService.GetGroupRateMultipliers(c.Request.Context(), groupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
if entries == nil {
entries = []service.UserGroupRateEntry{}
}
response.Success(c, entries)
}
// ClearGroupRateMultipliers handles clearing all rate multipliers for a group
// DELETE /api/v1/admin/groups/:id/rate-multipliers
func (h *GroupHandler) ClearGroupRateMultipliers(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
if err := h.adminService.ClearGroupRateMultipliers(c.Request.Context(), groupID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Rate multipliers cleared successfully"})
}
// BatchSetGroupRateMultipliersRequest represents batch set rate multipliers request
type BatchSetGroupRateMultipliersRequest struct {
Entries []service.GroupRateMultiplierInput `json:"entries" binding:"required"`
}
// BatchSetGroupRateMultipliers handles batch setting rate multipliers for a group
// PUT /api/v1/admin/groups/:id/rate-multipliers
func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
var req BatchSetGroupRateMultipliersRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.adminService.BatchSetGroupRateMultipliers(c.Request.Context(), groupID, req.Entries); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Rate multipliers updated successfully"})
}
// UpdateSortOrderRequest represents the request to update group sort orders
type UpdateSortOrderRequest struct {
Updates []struct {
ID int64 `json:"id" binding:"required"`
SortOrder int `json:"sort_order"`
} `json:"updates" binding:"required,min=1"`
}
// UpdateSortOrder handles updating group sort orders
// PUT /api/v1/admin/groups/sort-order
func (h *GroupHandler) UpdateSortOrder(c *gin.Context) {
var req UpdateSortOrderRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
updates := make([]service.GroupSortOrderUpdate, 0, len(req.Updates))
for _, u := range req.Updates {
updates = append(updates, service.GroupSortOrderUpdate{
ID: u.ID,
SortOrder: u.SortOrder,
})
}
if err := h.adminService.UpdateGroupSortOrders(c.Request.Context(), updates); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Sort order updated successfully"})
}

View File

@@ -0,0 +1,25 @@
package admin
import "sort"
func normalizeInt64IDList(ids []int64) []int64 {
if len(ids) == 0 {
return nil
}
out := make([]int64, 0, len(ids))
seen := make(map[int64]struct{}, len(ids))
for _, id := range ids {
if id <= 0 {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
out = append(out, id)
}
sort.Slice(out, func(i, j int) bool { return out[i] < out[j] })
return out
}

View File

@@ -0,0 +1,57 @@
//go:build unit
package admin
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestNormalizeInt64IDList(t *testing.T) {
tests := []struct {
name string
in []int64
want []int64
}{
{"nil input", nil, nil},
{"empty input", []int64{}, nil},
{"single element", []int64{5}, []int64{5}},
{"already sorted unique", []int64{1, 2, 3}, []int64{1, 2, 3}},
{"duplicates removed", []int64{3, 1, 3, 2, 1}, []int64{1, 2, 3}},
{"zero filtered", []int64{0, 1, 2}, []int64{1, 2}},
{"negative filtered", []int64{-5, -1, 3}, []int64{3}},
{"all invalid", []int64{0, -1, -2}, []int64{}},
{"sorted output", []int64{9, 3, 7, 1}, []int64{1, 3, 7, 9}},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := normalizeInt64IDList(tc.in)
if tc.want == nil {
require.Nil(t, got)
} else {
require.Equal(t, tc.want, got)
}
})
}
}
func TestBuildAccountTodayStatsBatchCacheKey(t *testing.T) {
tests := []struct {
name string
ids []int64
want string
}{
{"empty", nil, "accounts_today_stats_empty"},
{"single", []int64{42}, "accounts_today_stats:42"},
{"multiple", []int64{1, 2, 3}, "accounts_today_stats:1,2,3"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := buildAccountTodayStatsBatchCacheKey(tc.ids)
require.Equal(t, tc.want, got)
})
}
}

View File

@@ -0,0 +1,115 @@
package admin
import (
"context"
"strconv"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type idempotencyStoreUnavailableMode int
const (
idempotencyStoreUnavailableFailClose idempotencyStoreUnavailableMode = iota
idempotencyStoreUnavailableFailOpen
)
func executeAdminIdempotent(
c *gin.Context,
scope string,
payload any,
ttl time.Duration,
execute func(context.Context) (any, error),
) (*service.IdempotencyExecuteResult, error) {
coordinator := service.DefaultIdempotencyCoordinator()
if coordinator == nil {
data, err := execute(c.Request.Context())
if err != nil {
return nil, err
}
return &service.IdempotencyExecuteResult{Data: data}, nil
}
actorScope := "admin:0"
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
actorScope = "admin:" + strconv.FormatInt(subject.UserID, 10)
}
return coordinator.Execute(c.Request.Context(), service.IdempotencyExecuteOptions{
Scope: scope,
ActorScope: actorScope,
Method: c.Request.Method,
Route: c.FullPath(),
IdempotencyKey: c.GetHeader("Idempotency-Key"),
Payload: payload,
RequireKey: true,
TTL: ttl,
}, execute)
}
func executeAdminIdempotentJSON(
c *gin.Context,
scope string,
payload any,
ttl time.Duration,
execute func(context.Context) (any, error),
) {
executeAdminIdempotentJSONWithMode(c, scope, payload, ttl, idempotencyStoreUnavailableFailClose, execute)
}
func executeAdminIdempotentJSONFailOpenOnStoreUnavailable(
c *gin.Context,
scope string,
payload any,
ttl time.Duration,
execute func(context.Context) (any, error),
) {
executeAdminIdempotentJSONWithMode(c, scope, payload, ttl, idempotencyStoreUnavailableFailOpen, execute)
}
func executeAdminIdempotentJSONWithMode(
c *gin.Context,
scope string,
payload any,
ttl time.Duration,
mode idempotencyStoreUnavailableMode,
execute func(context.Context) (any, error),
) {
result, err := executeAdminIdempotent(c, scope, payload, ttl, execute)
if err != nil {
if infraerrors.Code(err) == infraerrors.Code(service.ErrIdempotencyStoreUnavail) {
strategy := "fail_close"
if mode == idempotencyStoreUnavailableFailOpen {
strategy = "fail_open"
}
service.RecordIdempotencyStoreUnavailable(c.FullPath(), scope, "handler_"+strategy)
logger.LegacyPrintf("handler.idempotency", "[Idempotency] store unavailable: method=%s route=%s scope=%s strategy=%s", c.Request.Method, c.FullPath(), scope, strategy)
if mode == idempotencyStoreUnavailableFailOpen {
data, fallbackErr := execute(c.Request.Context())
if fallbackErr != nil {
response.ErrorFrom(c, fallbackErr)
return
}
c.Header("X-Idempotency-Degraded", "store-unavailable")
response.Success(c, data)
return
}
}
if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
response.ErrorFrom(c, err)
return
}
if result != nil && result.Replayed {
c.Header("X-Idempotency-Replayed", "true")
}
response.Success(c, result.Data)
}

View File

@@ -0,0 +1,285 @@
package admin
import (
"bytes"
"context"
"errors"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type storeUnavailableRepoStub struct{}
func (storeUnavailableRepoStub) CreateProcessing(context.Context, *service.IdempotencyRecord) (bool, error) {
return false, errors.New("store unavailable")
}
func (storeUnavailableRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*service.IdempotencyRecord, error) {
return nil, errors.New("store unavailable")
}
func (storeUnavailableRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
return false, errors.New("store unavailable")
}
func (storeUnavailableRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
return false, errors.New("store unavailable")
}
func (storeUnavailableRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
return errors.New("store unavailable")
}
func (storeUnavailableRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
return errors.New("store unavailable")
}
func (storeUnavailableRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) {
return 0, errors.New("store unavailable")
}
func TestExecuteAdminIdempotentJSONFailCloseOnStoreUnavailable(t *testing.T) {
gin.SetMode(gin.TestMode)
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(storeUnavailableRepoStub{}, service.DefaultIdempotencyConfig()))
t.Cleanup(func() {
service.SetDefaultIdempotencyCoordinator(nil)
})
var executed int
router := gin.New()
router.POST("/idempotent", func(c *gin.Context) {
executeAdminIdempotentJSON(c, "admin.test.high", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
executed++
return gin.H{"ok": true}, nil
})
})
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Idempotency-Key", "test-key-1")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
require.Equal(t, 0, executed, "fail-close should block business execution when idempotency store is unavailable")
}
func TestExecuteAdminIdempotentJSONFailOpenOnStoreUnavailable(t *testing.T) {
gin.SetMode(gin.TestMode)
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(storeUnavailableRepoStub{}, service.DefaultIdempotencyConfig()))
t.Cleanup(func() {
service.SetDefaultIdempotencyCoordinator(nil)
})
var executed int
router := gin.New()
router.POST("/idempotent", func(c *gin.Context) {
executeAdminIdempotentJSONFailOpenOnStoreUnavailable(c, "admin.test.medium", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
executed++
return gin.H{"ok": true}, nil
})
})
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Idempotency-Key", "test-key-2")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "store-unavailable", rec.Header().Get("X-Idempotency-Degraded"))
require.Equal(t, 1, executed, "fail-open strategy should allow semantic idempotent path to continue")
}
type memoryIdempotencyRepoStub struct {
mu sync.Mutex
nextID int64
data map[string]*service.IdempotencyRecord
}
func newMemoryIdempotencyRepoStub() *memoryIdempotencyRepoStub {
return &memoryIdempotencyRepoStub{
nextID: 1,
data: make(map[string]*service.IdempotencyRecord),
}
}
func (r *memoryIdempotencyRepoStub) key(scope, keyHash string) string {
return scope + "|" + keyHash
}
func (r *memoryIdempotencyRepoStub) clone(in *service.IdempotencyRecord) *service.IdempotencyRecord {
if in == nil {
return nil
}
out := *in
if in.LockedUntil != nil {
v := *in.LockedUntil
out.LockedUntil = &v
}
if in.ResponseBody != nil {
v := *in.ResponseBody
out.ResponseBody = &v
}
if in.ResponseStatus != nil {
v := *in.ResponseStatus
out.ResponseStatus = &v
}
if in.ErrorReason != nil {
v := *in.ErrorReason
out.ErrorReason = &v
}
return &out
}
func (r *memoryIdempotencyRepoStub) CreateProcessing(_ context.Context, record *service.IdempotencyRecord) (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
k := r.key(record.Scope, record.IdempotencyKeyHash)
if _, ok := r.data[k]; ok {
return false, nil
}
cp := r.clone(record)
cp.ID = r.nextID
r.nextID++
r.data[k] = cp
record.ID = cp.ID
return true, nil
}
func (r *memoryIdempotencyRepoStub) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) {
r.mu.Lock()
defer r.mu.Unlock()
return r.clone(r.data[r.key(scope, keyHash)]), nil
}
func (r *memoryIdempotencyRepoStub) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
if rec.Status != fromStatus {
return false, nil
}
if rec.LockedUntil != nil && rec.LockedUntil.After(now) {
return false, nil
}
rec.Status = service.IdempotencyStatusProcessing
rec.LockedUntil = &newLockedUntil
rec.ExpiresAt = newExpiresAt
rec.ErrorReason = nil
return true, nil
}
return false, nil
}
func (r *memoryIdempotencyRepoStub) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
if rec.Status != service.IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint {
return false, nil
}
rec.LockedUntil = &newLockedUntil
rec.ExpiresAt = newExpiresAt
return true, nil
}
return false, nil
}
func (r *memoryIdempotencyRepoStub) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
rec.Status = service.IdempotencyStatusSucceeded
rec.LockedUntil = nil
rec.ExpiresAt = expiresAt
rec.ResponseStatus = &responseStatus
rec.ResponseBody = &responseBody
rec.ErrorReason = nil
return nil
}
return nil
}
func (r *memoryIdempotencyRepoStub) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
rec.Status = service.IdempotencyStatusFailedRetryable
rec.LockedUntil = &lockedUntil
rec.ExpiresAt = expiresAt
rec.ErrorReason = &errorReason
return nil
}
return nil
}
func (r *memoryIdempotencyRepoStub) DeleteExpired(_ context.Context, _ time.Time, _ int) (int64, error) {
return 0, nil
}
func TestExecuteAdminIdempotentJSONConcurrentRetryOnlyOneSideEffect(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := newMemoryIdempotencyRepoStub()
cfg := service.DefaultIdempotencyConfig()
cfg.ProcessingTimeout = 2 * time.Second
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(repo, cfg))
t.Cleanup(func() {
service.SetDefaultIdempotencyCoordinator(nil)
})
var executed atomic.Int32
router := gin.New()
router.POST("/idempotent", func(c *gin.Context) {
executeAdminIdempotentJSON(c, "admin.test.concurrent", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
executed.Add(1)
time.Sleep(120 * time.Millisecond)
return gin.H{"ok": true}, nil
})
})
call := func() (int, http.Header) {
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Idempotency-Key", "same-key")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
return rec.Code, rec.Header()
}
var status1, status2 int
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
status1, _ = call()
}()
go func() {
defer wg.Done()
status2, _ = call()
}()
wg.Wait()
require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status1)
require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status2)
require.Equal(t, int32(1), executed.Load(), "same idempotency key should execute side-effect only once")
status3, headers3 := call()
require.Equal(t, http.StatusOK, status3)
require.Equal(t, "true", headers3.Get("X-Idempotency-Replayed"))
require.Equal(t, int32(1), executed.Load())
}

View File

@@ -0,0 +1,304 @@
package admin
import (
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// OpenAIOAuthHandler handles OpenAI OAuth-related operations
type OpenAIOAuthHandler struct {
openaiOAuthService *service.OpenAIOAuthService
adminService service.AdminService
}
func oauthPlatformFromPath(c *gin.Context) string {
if strings.Contains(c.FullPath(), "/admin/sora/") {
return service.PlatformSora
}
return service.PlatformOpenAI
}
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
return &OpenAIOAuthHandler{
openaiOAuthService: openaiOAuthService,
adminService: adminService,
}
}
// OpenAIGenerateAuthURLRequest represents the request for generating OpenAI auth URL
type OpenAIGenerateAuthURLRequest struct {
ProxyID *int64 `json:"proxy_id"`
RedirectURI string `json:"redirect_uri"`
}
// GenerateAuthURL generates OpenAI OAuth authorization URL
// POST /api/v1/admin/openai/generate-auth-url
func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
var req OpenAIGenerateAuthURLRequest
if err := c.ShouldBindJSON(&req); err != nil {
// Allow empty body
req = OpenAIGenerateAuthURLRequest{}
}
result, err := h.openaiOAuthService.GenerateAuthURL(
c.Request.Context(),
req.ProxyID,
req.RedirectURI,
oauthPlatformFromPath(c),
)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// OpenAIExchangeCodeRequest represents the request for exchanging OpenAI auth code
type OpenAIExchangeCodeRequest struct {
SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"`
State string `json:"state" binding:"required"`
RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"`
}
// ExchangeCode exchanges OpenAI authorization code for tokens
// POST /api/v1/admin/openai/exchange-code
func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
var req OpenAIExchangeCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID,
Code: req.Code,
State: req.State,
RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, tokenInfo)
}
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
type OpenAIRefreshTokenRequest struct {
RefreshToken string `json:"refresh_token"`
RT string `json:"rt"`
ClientID string `json:"client_id"`
ProxyID *int64 `json:"proxy_id"`
}
// RefreshToken refreshes an OpenAI OAuth token
// POST /api/v1/admin/openai/refresh-token
// POST /api/v1/admin/sora/rt2at
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
var req OpenAIRefreshTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
refreshToken := strings.TrimSpace(req.RefreshToken)
if refreshToken == "" {
refreshToken = strings.TrimSpace(req.RT)
}
if refreshToken == "" {
response.BadRequest(c, "refresh_token is required")
return
}
var proxyURL string
if req.ProxyID != nil {
proxy, err := h.adminService.GetProxy(c.Request.Context(), *req.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
// 未指定 client_id 时,根据请求路径平台自动设置默认值,避免 repository 层盲猜
clientID := strings.TrimSpace(req.ClientID)
if clientID == "" {
platform := oauthPlatformFromPath(c)
clientID, _ = openai.OAuthClientConfigByPlatform(platform)
}
tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, clientID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, tokenInfo)
}
// ExchangeSoraSessionToken exchanges Sora session token to access token
// POST /api/v1/admin/sora/st2at
func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) {
var req struct {
SessionToken string `json:"session_token"`
ST string `json:"st"`
ProxyID *int64 `json:"proxy_id"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
sessionToken := strings.TrimSpace(req.SessionToken)
if sessionToken == "" {
sessionToken = strings.TrimSpace(req.ST)
}
if sessionToken == "" {
response.BadRequest(c, "session_token is required")
return
}
tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, tokenInfo)
}
// RefreshAccountToken refreshes token for a specific OpenAI/Sora account
// POST /api/v1/admin/openai/accounts/:id/refresh
// POST /api/v1/admin/sora/accounts/:id/refresh
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
// Get account
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.ErrorFrom(c, err)
return
}
platform := oauthPlatformFromPath(c)
if account.Platform != platform {
response.BadRequest(c, "Account platform does not match OAuth endpoint")
return
}
// Only refresh OAuth-based accounts
if !account.IsOAuth() {
response.BadRequest(c, "Cannot refresh non-OAuth account credentials")
return
}
// Use OpenAI OAuth service to refresh token
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Build new credentials from token info
newCredentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
// Preserve non-token settings from existing credentials
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Credentials: newCredentials,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.AccountFromService(updatedAccount))
}
// CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info
// POST /api/v1/admin/openai/create-from-oauth
// POST /api/v1/admin/sora/create-from-oauth
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
var req struct {
SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"`
State string `json:"state" binding:"required"`
RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"`
Name string `json:"name"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Exchange code for tokens
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID,
Code: req.Code,
State: req.State,
RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
// Build credentials from token info
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
platform := oauthPlatformFromPath(c)
// Use email as default name if not provided
name := req.Name
if name == "" && tokenInfo.Email != "" {
name = tokenInfo.Email
}
if name == "" {
if platform == service.PlatformSora {
name = "Sora OAuth Account"
} else {
name = "OpenAI OAuth Account"
}
}
// Create account
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
Name: name,
Platform: platform,
Type: "oauth",
Credentials: credentials,
Extra: nil,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
GroupIDs: req.GroupIDs,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.AccountFromService(account))
}

View File

@@ -0,0 +1,612 @@
package admin
import (
"encoding/json"
"fmt"
"math"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
)
var validOpsAlertMetricTypes = []string{
"success_rate",
"error_rate",
"upstream_error_rate",
"cpu_usage_percent",
"memory_usage_percent",
"concurrency_queue_depth",
"group_available_accounts",
"group_available_ratio",
"group_rate_limit_ratio",
"account_rate_limited_count",
"account_error_count",
"account_error_ratio",
"overload_account_count",
}
var validOpsAlertMetricTypeSet = func() map[string]struct{} {
set := make(map[string]struct{}, len(validOpsAlertMetricTypes))
for _, v := range validOpsAlertMetricTypes {
set[v] = struct{}{}
}
return set
}()
var validOpsAlertOperators = []string{">", "<", ">=", "<=", "==", "!="}
var validOpsAlertOperatorSet = func() map[string]struct{} {
set := make(map[string]struct{}, len(validOpsAlertOperators))
for _, v := range validOpsAlertOperators {
set[v] = struct{}{}
}
return set
}()
var validOpsAlertSeverities = []string{"P0", "P1", "P2", "P3"}
var validOpsAlertSeveritySet = func() map[string]struct{} {
set := make(map[string]struct{}, len(validOpsAlertSeverities))
for _, v := range validOpsAlertSeverities {
set[v] = struct{}{}
}
return set
}()
type opsAlertRuleValidatedInput struct {
Name string
MetricType string
Operator string
Threshold float64
Severity string
WindowMinutes int
SustainedMinutes int
CooldownMinutes int
Enabled bool
NotifyEmail bool
WindowProvided bool
SustainedProvided bool
CooldownProvided bool
SeverityProvided bool
EnabledProvided bool
NotifyProvided bool
}
func isPercentOrRateMetric(metricType string) bool {
switch metricType {
case "success_rate",
"error_rate",
"upstream_error_rate",
"cpu_usage_percent",
"memory_usage_percent",
"group_available_ratio",
"group_rate_limit_ratio",
"account_error_ratio":
return true
default:
return false
}
}
func validateOpsAlertRulePayload(raw map[string]json.RawMessage) (*opsAlertRuleValidatedInput, error) {
if raw == nil {
return nil, fmt.Errorf("invalid request body")
}
requiredFields := []string{"name", "metric_type", "operator", "threshold"}
for _, field := range requiredFields {
if _, ok := raw[field]; !ok {
return nil, fmt.Errorf("%s is required", field)
}
}
var name string
if err := json.Unmarshal(raw["name"], &name); err != nil || strings.TrimSpace(name) == "" {
return nil, fmt.Errorf("name is required")
}
name = strings.TrimSpace(name)
var metricType string
if err := json.Unmarshal(raw["metric_type"], &metricType); err != nil || strings.TrimSpace(metricType) == "" {
return nil, fmt.Errorf("metric_type is required")
}
metricType = strings.TrimSpace(metricType)
if _, ok := validOpsAlertMetricTypeSet[metricType]; !ok {
return nil, fmt.Errorf("metric_type must be one of: %s", strings.Join(validOpsAlertMetricTypes, ", "))
}
var operator string
if err := json.Unmarshal(raw["operator"], &operator); err != nil || strings.TrimSpace(operator) == "" {
return nil, fmt.Errorf("operator is required")
}
operator = strings.TrimSpace(operator)
if _, ok := validOpsAlertOperatorSet[operator]; !ok {
return nil, fmt.Errorf("operator must be one of: %s", strings.Join(validOpsAlertOperators, ", "))
}
var threshold float64
if err := json.Unmarshal(raw["threshold"], &threshold); err != nil {
return nil, fmt.Errorf("threshold must be a number")
}
if math.IsNaN(threshold) || math.IsInf(threshold, 0) {
return nil, fmt.Errorf("threshold must be a finite number")
}
if isPercentOrRateMetric(metricType) {
if threshold < 0 || threshold > 100 {
return nil, fmt.Errorf("threshold must be between 0 and 100 for metric_type %s", metricType)
}
} else if threshold < 0 {
return nil, fmt.Errorf("threshold must be >= 0")
}
validated := &opsAlertRuleValidatedInput{
Name: name,
MetricType: metricType,
Operator: operator,
Threshold: threshold,
}
if v, ok := raw["severity"]; ok {
validated.SeverityProvided = true
var sev string
if err := json.Unmarshal(v, &sev); err != nil {
return nil, fmt.Errorf("severity must be a string")
}
sev = strings.ToUpper(strings.TrimSpace(sev))
if sev != "" {
if _, ok := validOpsAlertSeveritySet[sev]; !ok {
return nil, fmt.Errorf("severity must be one of: %s", strings.Join(validOpsAlertSeverities, ", "))
}
validated.Severity = sev
}
}
if validated.Severity == "" {
validated.Severity = "P2"
}
if v, ok := raw["enabled"]; ok {
validated.EnabledProvided = true
if err := json.Unmarshal(v, &validated.Enabled); err != nil {
return nil, fmt.Errorf("enabled must be a boolean")
}
} else {
validated.Enabled = true
}
if v, ok := raw["notify_email"]; ok {
validated.NotifyProvided = true
if err := json.Unmarshal(v, &validated.NotifyEmail); err != nil {
return nil, fmt.Errorf("notify_email must be a boolean")
}
} else {
validated.NotifyEmail = true
}
if v, ok := raw["window_minutes"]; ok {
validated.WindowProvided = true
if err := json.Unmarshal(v, &validated.WindowMinutes); err != nil {
return nil, fmt.Errorf("window_minutes must be an integer")
}
switch validated.WindowMinutes {
case 1, 5, 60:
default:
return nil, fmt.Errorf("window_minutes must be one of: 1, 5, 60")
}
} else {
validated.WindowMinutes = 1
}
if v, ok := raw["sustained_minutes"]; ok {
validated.SustainedProvided = true
if err := json.Unmarshal(v, &validated.SustainedMinutes); err != nil {
return nil, fmt.Errorf("sustained_minutes must be an integer")
}
if validated.SustainedMinutes < 1 || validated.SustainedMinutes > 1440 {
return nil, fmt.Errorf("sustained_minutes must be between 1 and 1440")
}
} else {
validated.SustainedMinutes = 1
}
if v, ok := raw["cooldown_minutes"]; ok {
validated.CooldownProvided = true
if err := json.Unmarshal(v, &validated.CooldownMinutes); err != nil {
return nil, fmt.Errorf("cooldown_minutes must be an integer")
}
if validated.CooldownMinutes < 0 || validated.CooldownMinutes > 1440 {
return nil, fmt.Errorf("cooldown_minutes must be between 0 and 1440")
}
} else {
validated.CooldownMinutes = 0
}
return validated, nil
}
// ListAlertRules returns all ops alert rules.
// GET /api/v1/admin/ops/alert-rules
func (h *OpsHandler) ListAlertRules(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
rules, err := h.opsService.ListAlertRules(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, rules)
}
// CreateAlertRule creates an ops alert rule.
// POST /api/v1/admin/ops/alert-rules
func (h *OpsHandler) CreateAlertRule(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
var raw map[string]json.RawMessage
if err := c.ShouldBindBodyWith(&raw, binding.JSON); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
validated, err := validateOpsAlertRulePayload(raw)
if err != nil {
response.BadRequest(c, err.Error())
return
}
var rule service.OpsAlertRule
if err := c.ShouldBindBodyWith(&rule, binding.JSON); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
rule.Name = validated.Name
rule.MetricType = validated.MetricType
rule.Operator = validated.Operator
rule.Threshold = validated.Threshold
rule.WindowMinutes = validated.WindowMinutes
rule.SustainedMinutes = validated.SustainedMinutes
rule.CooldownMinutes = validated.CooldownMinutes
rule.Severity = validated.Severity
rule.Enabled = validated.Enabled
rule.NotifyEmail = validated.NotifyEmail
created, err := h.opsService.CreateAlertRule(c.Request.Context(), &rule)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, created)
}
// UpdateAlertRule updates an existing ops alert rule.
// PUT /api/v1/admin/ops/alert-rules/:id
func (h *OpsHandler) UpdateAlertRule(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid rule ID")
return
}
var raw map[string]json.RawMessage
if err := c.ShouldBindBodyWith(&raw, binding.JSON); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
validated, err := validateOpsAlertRulePayload(raw)
if err != nil {
response.BadRequest(c, err.Error())
return
}
var rule service.OpsAlertRule
if err := c.ShouldBindBodyWith(&rule, binding.JSON); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
rule.ID = id
rule.Name = validated.Name
rule.MetricType = validated.MetricType
rule.Operator = validated.Operator
rule.Threshold = validated.Threshold
rule.WindowMinutes = validated.WindowMinutes
rule.SustainedMinutes = validated.SustainedMinutes
rule.CooldownMinutes = validated.CooldownMinutes
rule.Severity = validated.Severity
rule.Enabled = validated.Enabled
rule.NotifyEmail = validated.NotifyEmail
updated, err := h.opsService.UpdateAlertRule(c.Request.Context(), &rule)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, updated)
}
// DeleteAlertRule deletes an ops alert rule.
// DELETE /api/v1/admin/ops/alert-rules/:id
func (h *OpsHandler) DeleteAlertRule(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid rule ID")
return
}
if err := h.opsService.DeleteAlertRule(c.Request.Context(), id); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"deleted": true})
}
// GetAlertEvent returns a single ops alert event.
// GET /api/v1/admin/ops/alert-events/:id
func (h *OpsHandler) GetAlertEvent(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid event ID")
return
}
ev, err := h.opsService.GetAlertEventByID(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, ev)
}
// UpdateAlertEventStatus updates an ops alert event status.
// PUT /api/v1/admin/ops/alert-events/:id/status
func (h *OpsHandler) UpdateAlertEventStatus(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid event ID")
return
}
var payload struct {
Status string `json:"status"`
}
if err := c.ShouldBindJSON(&payload); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
payload.Status = strings.TrimSpace(payload.Status)
if payload.Status == "" {
response.BadRequest(c, "Invalid status")
return
}
if payload.Status != service.OpsAlertStatusResolved && payload.Status != service.OpsAlertStatusManualResolved {
response.BadRequest(c, "Invalid status")
return
}
var resolvedAt *time.Time
if payload.Status == service.OpsAlertStatusResolved || payload.Status == service.OpsAlertStatusManualResolved {
now := time.Now().UTC()
resolvedAt = &now
}
if err := h.opsService.UpdateAlertEventStatus(c.Request.Context(), id, payload.Status, resolvedAt); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"updated": true})
}
// ListAlertEvents lists recent ops alert events.
// GET /api/v1/admin/ops/alert-events
// CreateAlertSilence creates a scoped silence for ops alerts.
// POST /api/v1/admin/ops/alert-silences
func (h *OpsHandler) CreateAlertSilence(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
var payload struct {
RuleID int64 `json:"rule_id"`
Platform string `json:"platform"`
GroupID *int64 `json:"group_id"`
Region *string `json:"region"`
Until string `json:"until"`
Reason string `json:"reason"`
}
if err := c.ShouldBindJSON(&payload); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
until, err := time.Parse(time.RFC3339, strings.TrimSpace(payload.Until))
if err != nil {
response.BadRequest(c, "Invalid until")
return
}
createdBy := (*int64)(nil)
if subject, ok := middleware.GetAuthSubjectFromContext(c); ok {
uid := subject.UserID
createdBy = &uid
}
silence := &service.OpsAlertSilence{
RuleID: payload.RuleID,
Platform: strings.TrimSpace(payload.Platform),
GroupID: payload.GroupID,
Region: payload.Region,
Until: until,
Reason: strings.TrimSpace(payload.Reason),
CreatedBy: createdBy,
}
created, err := h.opsService.CreateAlertSilence(c.Request.Context(), silence)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, created)
}
func (h *OpsHandler) ListAlertEvents(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
limit := 20
if raw := strings.TrimSpace(c.Query("limit")); raw != "" {
n, err := strconv.Atoi(raw)
if err != nil || n <= 0 {
response.BadRequest(c, "Invalid limit")
return
}
limit = n
}
filter := &service.OpsAlertEventFilter{
Limit: limit,
Status: strings.TrimSpace(c.Query("status")),
Severity: strings.TrimSpace(c.Query("severity")),
}
if v := strings.TrimSpace(c.Query("email_sent")); v != "" {
vv := strings.ToLower(v)
switch vv {
case "true", "1":
b := true
filter.EmailSent = &b
case "false", "0":
b := false
filter.EmailSent = &b
default:
response.BadRequest(c, "Invalid email_sent")
return
}
}
// Cursor pagination: both params must be provided together.
rawTS := strings.TrimSpace(c.Query("before_fired_at"))
rawID := strings.TrimSpace(c.Query("before_id"))
if (rawTS == "") != (rawID == "") {
response.BadRequest(c, "before_fired_at and before_id must be provided together")
return
}
if rawTS != "" {
ts, err := time.Parse(time.RFC3339Nano, rawTS)
if err != nil {
if t2, err2 := time.Parse(time.RFC3339, rawTS); err2 == nil {
ts = t2
} else {
response.BadRequest(c, "Invalid before_fired_at")
return
}
}
filter.BeforeFiredAt = &ts
}
if rawID != "" {
id, err := strconv.ParseInt(rawID, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid before_id")
return
}
filter.BeforeID = &id
}
// Optional global filter support (platform/group/time range).
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
filter.Platform = platform
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
if startTime, endTime, err := parseOpsTimeRange(c, "24h"); err == nil {
// Only apply when explicitly provided to avoid surprising default narrowing.
if strings.TrimSpace(c.Query("start_time")) != "" || strings.TrimSpace(c.Query("end_time")) != "" || strings.TrimSpace(c.Query("time_range")) != "" {
filter.StartTime = &startTime
filter.EndTime = &endTime
}
} else {
response.BadRequest(c, err.Error())
return
}
events, err := h.opsService.ListAlertEvents(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, events)
}

View File

@@ -0,0 +1,353 @@
package admin
import (
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// GetDashboardOverview returns vNext ops dashboard overview (raw path).
// GET /api/v1/admin/ops/dashboard/overview
func (h *OpsHandler) GetDashboardOverview(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
startTime, endTime, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsDashboardFilter{
StartTime: startTime,
EndTime: endTime,
Platform: strings.TrimSpace(c.Query("platform")),
QueryMode: parseOpsQueryMode(c),
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
data, err := h.opsService.GetDashboardOverview(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, data)
}
// GetDashboardThroughputTrend returns throughput time series (raw path).
// GET /api/v1/admin/ops/dashboard/throughput-trend
func (h *OpsHandler) GetDashboardThroughputTrend(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
startTime, endTime, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsDashboardFilter{
StartTime: startTime,
EndTime: endTime,
Platform: strings.TrimSpace(c.Query("platform")),
QueryMode: parseOpsQueryMode(c),
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
bucketSeconds := pickThroughputBucketSeconds(endTime.Sub(startTime))
data, err := h.opsService.GetThroughputTrend(c.Request.Context(), filter, bucketSeconds)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, data)
}
// GetDashboardLatencyHistogram returns the latency distribution histogram (success requests).
// GET /api/v1/admin/ops/dashboard/latency-histogram
func (h *OpsHandler) GetDashboardLatencyHistogram(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
startTime, endTime, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsDashboardFilter{
StartTime: startTime,
EndTime: endTime,
Platform: strings.TrimSpace(c.Query("platform")),
QueryMode: parseOpsQueryMode(c),
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
data, err := h.opsService.GetLatencyHistogram(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, data)
}
// GetDashboardErrorTrend returns error counts time series (raw path).
// GET /api/v1/admin/ops/dashboard/error-trend
func (h *OpsHandler) GetDashboardErrorTrend(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
startTime, endTime, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsDashboardFilter{
StartTime: startTime,
EndTime: endTime,
Platform: strings.TrimSpace(c.Query("platform")),
QueryMode: parseOpsQueryMode(c),
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
bucketSeconds := pickThroughputBucketSeconds(endTime.Sub(startTime))
data, err := h.opsService.GetErrorTrend(c.Request.Context(), filter, bucketSeconds)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, data)
}
// GetDashboardErrorDistribution returns error distribution by status code (raw path).
// GET /api/v1/admin/ops/dashboard/error-distribution
func (h *OpsHandler) GetDashboardErrorDistribution(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
startTime, endTime, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsDashboardFilter{
StartTime: startTime,
EndTime: endTime,
Platform: strings.TrimSpace(c.Query("platform")),
QueryMode: parseOpsQueryMode(c),
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
data, err := h.opsService.GetErrorDistribution(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, data)
}
// GetDashboardOpenAITokenStats returns OpenAI token efficiency stats grouped by model.
// GET /api/v1/admin/ops/dashboard/openai-token-stats
func (h *OpsHandler) GetDashboardOpenAITokenStats(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
filter, err := parseOpsOpenAITokenStatsFilter(c)
if err != nil {
response.BadRequest(c, err.Error())
return
}
data, err := h.opsService.GetOpenAITokenStats(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, data)
}
func parseOpsOpenAITokenStatsFilter(c *gin.Context) (*service.OpsOpenAITokenStatsFilter, error) {
if c == nil {
return nil, fmt.Errorf("invalid request")
}
timeRange := strings.TrimSpace(c.Query("time_range"))
if timeRange == "" {
timeRange = "30d"
}
dur, ok := parseOpsOpenAITokenStatsDuration(timeRange)
if !ok {
return nil, fmt.Errorf("invalid time_range")
}
end := time.Now().UTC()
start := end.Add(-dur)
filter := &service.OpsOpenAITokenStatsFilter{
TimeRange: timeRange,
StartTime: start,
EndTime: end,
Platform: strings.TrimSpace(c.Query("platform")),
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
return nil, fmt.Errorf("invalid group_id")
}
filter.GroupID = &id
}
topNRaw := strings.TrimSpace(c.Query("top_n"))
pageRaw := strings.TrimSpace(c.Query("page"))
pageSizeRaw := strings.TrimSpace(c.Query("page_size"))
if topNRaw != "" && (pageRaw != "" || pageSizeRaw != "") {
return nil, fmt.Errorf("invalid query: top_n cannot be used with page/page_size")
}
if topNRaw != "" {
topN, err := strconv.Atoi(topNRaw)
if err != nil || topN < 1 || topN > 100 {
return nil, fmt.Errorf("invalid top_n")
}
filter.TopN = topN
return filter, nil
}
filter.Page = 1
filter.PageSize = 20
if pageRaw != "" {
page, err := strconv.Atoi(pageRaw)
if err != nil || page < 1 {
return nil, fmt.Errorf("invalid page")
}
filter.Page = page
}
if pageSizeRaw != "" {
pageSize, err := strconv.Atoi(pageSizeRaw)
if err != nil || pageSize < 1 || pageSize > 100 {
return nil, fmt.Errorf("invalid page_size")
}
filter.PageSize = pageSize
}
return filter, nil
}
func parseOpsOpenAITokenStatsDuration(v string) (time.Duration, bool) {
switch strings.TrimSpace(v) {
case "30m":
return 30 * time.Minute, true
case "1h":
return time.Hour, true
case "1d":
return 24 * time.Hour, true
case "15d":
return 15 * 24 * time.Hour, true
case "30d":
return 30 * 24 * time.Hour, true
default:
return 0, false
}
}
func pickThroughputBucketSeconds(window time.Duration) int {
// Keep buckets predictable and avoid huge responses.
switch {
case window <= 2*time.Hour:
return 60
case window <= 24*time.Hour:
return 300
default:
return 3600
}
}
func parseOpsQueryMode(c *gin.Context) service.OpsQueryMode {
if c == nil {
return ""
}
raw := strings.TrimSpace(c.Query("mode"))
if raw == "" {
// Empty means "use server default" (DB setting ops_query_mode_default).
return ""
}
return service.ParseOpsQueryMode(raw)
}

View File

@@ -0,0 +1,925 @@
package admin
import (
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type OpsHandler struct {
opsService *service.OpsService
}
// GetErrorLogByID returns ops error log detail.
// GET /api/v1/admin/ops/errors/:id
func (h *OpsHandler) GetErrorLogByID(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
detail, err := h.opsService.GetErrorLogByID(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, detail)
}
const (
opsListViewErrors = "errors"
opsListViewExcluded = "excluded"
opsListViewAll = "all"
)
func parseOpsViewParam(c *gin.Context) string {
if c == nil {
return ""
}
v := strings.ToLower(strings.TrimSpace(c.Query("view")))
switch v {
case "", opsListViewErrors:
return opsListViewErrors
case opsListViewExcluded:
return opsListViewExcluded
case opsListViewAll:
return opsListViewAll
default:
return opsListViewErrors
}
}
func NewOpsHandler(opsService *service.OpsService) *OpsHandler {
return &OpsHandler{opsService: opsService}
}
// GetErrorLogs lists ops error logs.
// GET /api/v1/admin/ops/errors
func (h *OpsHandler) GetErrorLogs(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
page, pageSize := response.ParsePagination(c)
// Ops list can be larger than standard admin tables.
if pageSize > 500 {
pageSize = 500
}
startTime, endTime, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
if !startTime.IsZero() {
filter.StartTime = &startTime
}
if !endTime.IsZero() {
filter.EndTime = &endTime
}
filter.View = parseOpsViewParam(c)
filter.Phase = strings.TrimSpace(c.Query("phase"))
filter.Owner = strings.TrimSpace(c.Query("error_owner"))
filter.Source = strings.TrimSpace(c.Query("error_source"))
filter.Query = strings.TrimSpace(c.Query("q"))
filter.UserQuery = strings.TrimSpace(c.Query("user_query"))
// Force request errors: client-visible status >= 400.
// buildOpsErrorLogsWhere already applies this for non-upstream phase.
if strings.EqualFold(strings.TrimSpace(filter.Phase), "upstream") {
filter.Phase = ""
}
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
filter.Platform = platform
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
if v := strings.TrimSpace(c.Query("account_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid account_id")
return
}
filter.AccountID = &id
}
if v := strings.TrimSpace(c.Query("resolved")); v != "" {
switch strings.ToLower(v) {
case "1", "true", "yes":
b := true
filter.Resolved = &b
case "0", "false", "no":
b := false
filter.Resolved = &b
default:
response.BadRequest(c, "Invalid resolved")
return
}
}
if statusCodesStr := strings.TrimSpace(c.Query("status_codes")); statusCodesStr != "" {
parts := strings.Split(statusCodesStr, ",")
out := make([]int, 0, len(parts))
for _, part := range parts {
p := strings.TrimSpace(part)
if p == "" {
continue
}
n, err := strconv.Atoi(p)
if err != nil || n < 0 {
response.BadRequest(c, "Invalid status_codes")
return
}
out = append(out, n)
}
filter.StatusCodes = out
}
result, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
}
// ListRequestErrors lists client-visible request errors.
// GET /api/v1/admin/ops/request-errors
func (h *OpsHandler) ListRequestErrors(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
page, pageSize := response.ParsePagination(c)
if pageSize > 500 {
pageSize = 500
}
startTime, endTime, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
if !startTime.IsZero() {
filter.StartTime = &startTime
}
if !endTime.IsZero() {
filter.EndTime = &endTime
}
filter.View = parseOpsViewParam(c)
filter.Phase = strings.TrimSpace(c.Query("phase"))
filter.Owner = strings.TrimSpace(c.Query("error_owner"))
filter.Source = strings.TrimSpace(c.Query("error_source"))
filter.Query = strings.TrimSpace(c.Query("q"))
filter.UserQuery = strings.TrimSpace(c.Query("user_query"))
// Force request errors: client-visible status >= 400.
// buildOpsErrorLogsWhere already applies this for non-upstream phase.
if strings.EqualFold(strings.TrimSpace(filter.Phase), "upstream") {
filter.Phase = ""
}
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
filter.Platform = platform
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
if v := strings.TrimSpace(c.Query("account_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid account_id")
return
}
filter.AccountID = &id
}
if v := strings.TrimSpace(c.Query("resolved")); v != "" {
switch strings.ToLower(v) {
case "1", "true", "yes":
b := true
filter.Resolved = &b
case "0", "false", "no":
b := false
filter.Resolved = &b
default:
response.BadRequest(c, "Invalid resolved")
return
}
}
if statusCodesStr := strings.TrimSpace(c.Query("status_codes")); statusCodesStr != "" {
parts := strings.Split(statusCodesStr, ",")
out := make([]int, 0, len(parts))
for _, part := range parts {
p := strings.TrimSpace(part)
if p == "" {
continue
}
n, err := strconv.Atoi(p)
if err != nil || n < 0 {
response.BadRequest(c, "Invalid status_codes")
return
}
out = append(out, n)
}
filter.StatusCodes = out
}
result, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
}
// GetRequestError returns request error detail.
// GET /api/v1/admin/ops/request-errors/:id
func (h *OpsHandler) GetRequestError(c *gin.Context) {
// same storage; just proxy to existing detail
h.GetErrorLogByID(c)
}
// ListRequestErrorUpstreamErrors lists upstream error logs correlated to a request error.
// GET /api/v1/admin/ops/request-errors/:id/upstream-errors
func (h *OpsHandler) ListRequestErrorUpstreamErrors(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
// Load request error to get correlation keys.
detail, err := h.opsService.GetErrorLogByID(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Correlate by request_id/client_request_id.
requestID := strings.TrimSpace(detail.RequestID)
clientRequestID := strings.TrimSpace(detail.ClientRequestID)
if requestID == "" && clientRequestID == "" {
response.Paginated(c, []*service.OpsErrorLog{}, 0, 1, 10)
return
}
page, pageSize := response.ParsePagination(c)
if pageSize > 500 {
pageSize = 500
}
// Keep correlation window wide enough so linked upstream errors
// are discoverable even when UI defaults to 1h elsewhere.
startTime, endTime, err := parseOpsTimeRange(c, "30d")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
if !startTime.IsZero() {
filter.StartTime = &startTime
}
if !endTime.IsZero() {
filter.EndTime = &endTime
}
filter.View = "all"
filter.Phase = "upstream"
filter.Owner = "provider"
filter.Source = strings.TrimSpace(c.Query("error_source"))
filter.Query = strings.TrimSpace(c.Query("q"))
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
filter.Platform = platform
}
// Prefer exact match on request_id; if missing, fall back to client_request_id.
if requestID != "" {
filter.RequestID = requestID
} else {
filter.ClientRequestID = clientRequestID
}
result, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
// If client asks for details, expand each upstream error log to include upstream response fields.
includeDetail := strings.TrimSpace(c.Query("include_detail"))
if includeDetail == "1" || strings.EqualFold(includeDetail, "true") || strings.EqualFold(includeDetail, "yes") {
details := make([]*service.OpsErrorLogDetail, 0, len(result.Errors))
for _, item := range result.Errors {
if item == nil {
continue
}
d, err := h.opsService.GetErrorLogByID(c.Request.Context(), item.ID)
if err != nil || d == nil {
continue
}
details = append(details, d)
}
response.Paginated(c, details, int64(result.Total), result.Page, result.PageSize)
return
}
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
}
// RetryRequestErrorClient retries the client request based on stored request body.
// POST /api/v1/admin/ops/request-errors/:id/retry-client
func (h *OpsHandler) RetryRequestErrorClient(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, service.OpsRetryModeClient, nil)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// RetryRequestErrorUpstreamEvent retries a specific upstream attempt using captured upstream_request_body.
// POST /api/v1/admin/ops/request-errors/:id/upstream-errors/:idx/retry
func (h *OpsHandler) RetryRequestErrorUpstreamEvent(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
idxStr := strings.TrimSpace(c.Param("idx"))
idx, err := strconv.Atoi(idxStr)
if err != nil || idx < 0 {
response.BadRequest(c, "Invalid upstream idx")
return
}
result, err := h.opsService.RetryUpstreamEvent(c.Request.Context(), subject.UserID, id, idx)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// ResolveRequestError toggles resolved status.
// PUT /api/v1/admin/ops/request-errors/:id/resolve
func (h *OpsHandler) ResolveRequestError(c *gin.Context) {
h.UpdateErrorResolution(c)
}
// ListUpstreamErrors lists independent upstream errors.
// GET /api/v1/admin/ops/upstream-errors
func (h *OpsHandler) ListUpstreamErrors(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
page, pageSize := response.ParsePagination(c)
if pageSize > 500 {
pageSize = 500
}
startTime, endTime, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
if !startTime.IsZero() {
filter.StartTime = &startTime
}
if !endTime.IsZero() {
filter.EndTime = &endTime
}
filter.View = parseOpsViewParam(c)
filter.Phase = "upstream"
filter.Owner = "provider"
filter.Source = strings.TrimSpace(c.Query("error_source"))
filter.Query = strings.TrimSpace(c.Query("q"))
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
filter.Platform = platform
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
if v := strings.TrimSpace(c.Query("account_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid account_id")
return
}
filter.AccountID = &id
}
if v := strings.TrimSpace(c.Query("resolved")); v != "" {
switch strings.ToLower(v) {
case "1", "true", "yes":
b := true
filter.Resolved = &b
case "0", "false", "no":
b := false
filter.Resolved = &b
default:
response.BadRequest(c, "Invalid resolved")
return
}
}
if statusCodesStr := strings.TrimSpace(c.Query("status_codes")); statusCodesStr != "" {
parts := strings.Split(statusCodesStr, ",")
out := make([]int, 0, len(parts))
for _, part := range parts {
p := strings.TrimSpace(part)
if p == "" {
continue
}
n, err := strconv.Atoi(p)
if err != nil || n < 0 {
response.BadRequest(c, "Invalid status_codes")
return
}
out = append(out, n)
}
filter.StatusCodes = out
}
result, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
}
// GetUpstreamError returns upstream error detail.
// GET /api/v1/admin/ops/upstream-errors/:id
func (h *OpsHandler) GetUpstreamError(c *gin.Context) {
h.GetErrorLogByID(c)
}
// RetryUpstreamError retries upstream error using the original account_id.
// POST /api/v1/admin/ops/upstream-errors/:id/retry
func (h *OpsHandler) RetryUpstreamError(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, service.OpsRetryModeUpstream, nil)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// ResolveUpstreamError toggles resolved status.
// PUT /api/v1/admin/ops/upstream-errors/:id/resolve
func (h *OpsHandler) ResolveUpstreamError(c *gin.Context) {
h.UpdateErrorResolution(c)
}
// ==================== Existing endpoints ====================
// ListRequestDetails returns a request-level list (success + error) for drill-down.
// GET /api/v1/admin/ops/requests
func (h *OpsHandler) ListRequestDetails(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
page, pageSize := response.ParsePagination(c)
if pageSize > 100 {
pageSize = 100
}
startTime, endTime, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsRequestDetailFilter{
Page: page,
PageSize: pageSize,
StartTime: &startTime,
EndTime: &endTime,
}
filter.Kind = strings.TrimSpace(c.Query("kind"))
filter.Platform = strings.TrimSpace(c.Query("platform"))
filter.Model = strings.TrimSpace(c.Query("model"))
filter.RequestID = strings.TrimSpace(c.Query("request_id"))
filter.Query = strings.TrimSpace(c.Query("q"))
filter.Sort = strings.TrimSpace(c.Query("sort"))
if v := strings.TrimSpace(c.Query("user_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid user_id")
return
}
filter.UserID = &id
}
if v := strings.TrimSpace(c.Query("api_key_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid api_key_id")
return
}
filter.APIKeyID = &id
}
if v := strings.TrimSpace(c.Query("account_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid account_id")
return
}
filter.AccountID = &id
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
if v := strings.TrimSpace(c.Query("min_duration_ms")); v != "" {
parsed, err := strconv.Atoi(v)
if err != nil || parsed < 0 {
response.BadRequest(c, "Invalid min_duration_ms")
return
}
filter.MinDurationMs = &parsed
}
if v := strings.TrimSpace(c.Query("max_duration_ms")); v != "" {
parsed, err := strconv.Atoi(v)
if err != nil || parsed < 0 {
response.BadRequest(c, "Invalid max_duration_ms")
return
}
filter.MaxDurationMs = &parsed
}
out, err := h.opsService.ListRequestDetails(c.Request.Context(), filter)
if err != nil {
// Invalid sort/kind/platform etc should be a bad request; keep it simple.
if strings.Contains(strings.ToLower(err.Error()), "invalid") {
response.BadRequest(c, err.Error())
return
}
response.Error(c, http.StatusInternalServerError, "Failed to list request details")
return
}
response.Paginated(c, out.Items, out.Total, out.Page, out.PageSize)
}
type opsRetryRequest struct {
Mode string `json:"mode"`
PinnedAccountID *int64 `json:"pinned_account_id"`
Force bool `json:"force"`
}
type opsResolveRequest struct {
Resolved bool `json:"resolved"`
}
// RetryErrorRequest retries a failed request using stored request_body.
// POST /api/v1/admin/ops/errors/:id/retry
func (h *OpsHandler) RetryErrorRequest(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
req := opsRetryRequest{Mode: service.OpsRetryModeClient}
if err := c.ShouldBindJSON(&req); err != nil && !errors.Is(err, io.EOF) {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if strings.TrimSpace(req.Mode) == "" {
req.Mode = service.OpsRetryModeClient
}
// Force flag is currently a UI-level acknowledgement. Server may still enforce safety constraints.
_ = req.Force
// Legacy endpoint safety: only allow retrying the client request here.
// Upstream retries must go through the split endpoints.
if strings.EqualFold(strings.TrimSpace(req.Mode), service.OpsRetryModeUpstream) {
response.BadRequest(c, "upstream retry is not supported on this endpoint")
return
}
result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, req.Mode, req.PinnedAccountID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// ListRetryAttempts lists retry attempts for an error log.
// GET /api/v1/admin/ops/errors/:id/retries
func (h *OpsHandler) ListRetryAttempts(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
limit := 50
if v := strings.TrimSpace(c.Query("limit")); v != "" {
n, err := strconv.Atoi(v)
if err != nil || n <= 0 {
response.BadRequest(c, "Invalid limit")
return
}
limit = n
}
items, err := h.opsService.ListRetryAttemptsByErrorID(c.Request.Context(), id, limit)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, items)
}
// UpdateErrorResolution allows manual resolve/unresolve.
// PUT /api/v1/admin/ops/errors/:id/resolve
func (h *OpsHandler) UpdateErrorResolution(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
var req opsResolveRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
uid := subject.UserID
if err := h.opsService.UpdateErrorResolution(c.Request.Context(), id, req.Resolved, &uid, nil); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"ok": true})
}
func parseOpsTimeRange(c *gin.Context, defaultRange string) (time.Time, time.Time, error) {
startStr := strings.TrimSpace(c.Query("start_time"))
endStr := strings.TrimSpace(c.Query("end_time"))
parseTS := func(s string) (time.Time, error) {
if s == "" {
return time.Time{}, nil
}
if t, err := time.Parse(time.RFC3339Nano, s); err == nil {
return t, nil
}
return time.Parse(time.RFC3339, s)
}
start, err := parseTS(startStr)
if err != nil {
return time.Time{}, time.Time{}, err
}
end, err := parseTS(endStr)
if err != nil {
return time.Time{}, time.Time{}, err
}
// start/end explicitly provided (even partially)
if startStr != "" || endStr != "" {
if end.IsZero() {
end = time.Now()
}
if start.IsZero() {
dur, _ := parseOpsDuration(defaultRange)
start = end.Add(-dur)
}
if start.After(end) {
return time.Time{}, time.Time{}, fmt.Errorf("invalid time range: start_time must be <= end_time")
}
if end.Sub(start) > 30*24*time.Hour {
return time.Time{}, time.Time{}, fmt.Errorf("invalid time range: max window is 30 days")
}
return start, end, nil
}
// time_range fallback
tr := strings.TrimSpace(c.Query("time_range"))
if tr == "" {
tr = defaultRange
}
dur, ok := parseOpsDuration(tr)
if !ok {
dur, _ = parseOpsDuration(defaultRange)
}
end = time.Now()
start = end.Add(-dur)
if end.Sub(start) > 30*24*time.Hour {
return time.Time{}, time.Time{}, fmt.Errorf("invalid time range: max window is 30 days")
}
return start, end, nil
}
func parseOpsDuration(v string) (time.Duration, bool) {
switch strings.TrimSpace(v) {
case "5m":
return 5 * time.Minute, true
case "30m":
return 30 * time.Minute, true
case "1h":
return time.Hour, true
case "6h":
return 6 * time.Hour, true
case "24h":
return 24 * time.Hour, true
case "7d":
return 7 * 24 * time.Hour, true
case "30d":
return 30 * 24 * time.Hour, true
default:
return 0, false
}
}

View File

@@ -0,0 +1,250 @@
package admin
import (
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// GetConcurrencyStats returns real-time concurrency usage aggregated by platform/group/account.
// GET /api/v1/admin/ops/concurrency
func (h *OpsHandler) GetConcurrencyStats(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
if !h.opsService.IsRealtimeMonitoringEnabled(c.Request.Context()) {
response.Success(c, gin.H{
"enabled": false,
"platform": map[string]*service.PlatformConcurrencyInfo{},
"group": map[int64]*service.GroupConcurrencyInfo{},
"account": map[int64]*service.AccountConcurrencyInfo{},
"timestamp": time.Now().UTC(),
})
return
}
platformFilter := strings.TrimSpace(c.Query("platform"))
var groupID *int64
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
groupID = &id
}
platform, group, account, collectedAt, err := h.opsService.GetConcurrencyStats(c.Request.Context(), platformFilter, groupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
payload := gin.H{
"enabled": true,
"platform": platform,
"group": group,
"account": account,
}
if collectedAt != nil {
payload["timestamp"] = collectedAt.UTC()
}
response.Success(c, payload)
}
// GetUserConcurrencyStats returns real-time concurrency usage for all active users.
// GET /api/v1/admin/ops/user-concurrency
func (h *OpsHandler) GetUserConcurrencyStats(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
if !h.opsService.IsRealtimeMonitoringEnabled(c.Request.Context()) {
response.Success(c, gin.H{
"enabled": false,
"user": map[int64]*service.UserConcurrencyInfo{},
"timestamp": time.Now().UTC(),
})
return
}
users, collectedAt, err := h.opsService.GetUserConcurrencyStats(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
payload := gin.H{
"enabled": true,
"user": users,
}
if collectedAt != nil {
payload["timestamp"] = collectedAt.UTC()
}
response.Success(c, payload)
}
// GetAccountAvailability returns account availability statistics.
// GET /api/v1/admin/ops/account-availability
//
// Query params:
// - platform: optional
// - group_id: optional
func (h *OpsHandler) GetAccountAvailability(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
if !h.opsService.IsRealtimeMonitoringEnabled(c.Request.Context()) {
response.Success(c, gin.H{
"enabled": false,
"platform": map[string]*service.PlatformAvailability{},
"group": map[int64]*service.GroupAvailability{},
"account": map[int64]*service.AccountAvailability{},
"timestamp": time.Now().UTC(),
})
return
}
platform := strings.TrimSpace(c.Query("platform"))
var groupID *int64
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
groupID = &id
}
platformStats, groupStats, accountStats, collectedAt, err := h.opsService.GetAccountAvailabilityStats(c.Request.Context(), platform, groupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
payload := gin.H{
"enabled": true,
"platform": platformStats,
"group": groupStats,
"account": accountStats,
}
if collectedAt != nil {
payload["timestamp"] = collectedAt.UTC()
}
response.Success(c, payload)
}
func parseOpsRealtimeWindow(v string) (time.Duration, string, bool) {
switch strings.ToLower(strings.TrimSpace(v)) {
case "", "1min", "1m":
return 1 * time.Minute, "1min", true
case "5min", "5m":
return 5 * time.Minute, "5min", true
case "30min", "30m":
return 30 * time.Minute, "30min", true
case "1h", "60m", "60min":
return 1 * time.Hour, "1h", true
default:
return 0, "", false
}
}
// GetRealtimeTrafficSummary returns QPS/TPS current/peak/avg for the selected window.
// GET /api/v1/admin/ops/realtime-traffic
//
// Query params:
// - window: 1min|5min|30min|1h (default: 1min)
// - platform: optional
// - group_id: optional
func (h *OpsHandler) GetRealtimeTrafficSummary(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
windowDur, windowLabel, ok := parseOpsRealtimeWindow(c.Query("window"))
if !ok {
response.BadRequest(c, "Invalid window")
return
}
platform := strings.TrimSpace(c.Query("platform"))
var groupID *int64
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
groupID = &id
}
endTime := time.Now().UTC()
startTime := endTime.Add(-windowDur)
if !h.opsService.IsRealtimeMonitoringEnabled(c.Request.Context()) {
disabledSummary := &service.OpsRealtimeTrafficSummary{
Window: windowLabel,
StartTime: startTime,
EndTime: endTime,
Platform: platform,
GroupID: groupID,
QPS: service.OpsRateSummary{},
TPS: service.OpsRateSummary{},
}
response.Success(c, gin.H{
"enabled": false,
"summary": disabledSummary,
"timestamp": endTime,
})
return
}
filter := &service.OpsDashboardFilter{
StartTime: startTime,
EndTime: endTime,
Platform: platform,
GroupID: groupID,
QueryMode: service.OpsQueryModeRaw,
}
summary, err := h.opsService.GetRealtimeTrafficSummary(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
if summary != nil {
summary.Window = windowLabel
}
response.Success(c, gin.H{
"enabled": true,
"summary": summary,
"timestamp": endTime,
})
}

View File

@@ -0,0 +1,173 @@
package admin
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type testSettingRepo struct {
values map[string]string
}
func newTestSettingRepo() *testSettingRepo {
return &testSettingRepo{values: map[string]string{}}
}
func (s *testSettingRepo) Get(ctx context.Context, key string) (*service.Setting, error) {
v, err := s.GetValue(ctx, key)
if err != nil {
return nil, err
}
return &service.Setting{Key: key, Value: v}, nil
}
func (s *testSettingRepo) GetValue(ctx context.Context, key string) (string, error) {
v, ok := s.values[key]
if !ok {
return "", service.ErrSettingNotFound
}
return v, nil
}
func (s *testSettingRepo) Set(ctx context.Context, key, value string) error {
s.values[key] = value
return nil
}
func (s *testSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, k := range keys {
if v, ok := s.values[k]; ok {
out[k] = v
}
}
return out, nil
}
func (s *testSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error {
for k, v := range settings {
s.values[k] = v
}
return nil
}
func (s *testSettingRepo) GetAll(ctx context.Context) (map[string]string, error) {
out := make(map[string]string, len(s.values))
for k, v := range s.values {
out[k] = v
}
return out, nil
}
func (s *testSettingRepo) Delete(ctx context.Context, key string) error {
delete(s.values, key)
return nil
}
func newOpsRuntimeRouter(handler *OpsHandler, withUser bool) *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()
if withUser {
r.Use(func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 7})
c.Next()
})
}
r.GET("/runtime/logging", handler.GetRuntimeLogConfig)
r.PUT("/runtime/logging", handler.UpdateRuntimeLogConfig)
r.POST("/runtime/logging/reset", handler.ResetRuntimeLogConfig)
return r
}
func newRuntimeOpsService(t *testing.T) *service.OpsService {
t.Helper()
if err := logger.Init(logger.InitOptions{
Level: "info",
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Output: logger.OutputOptions{
ToStdout: false,
ToFile: false,
},
}); err != nil {
t.Fatalf("init logger: %v", err)
}
settingRepo := newTestSettingRepo()
cfg := &config.Config{
Ops: config.OpsConfig{Enabled: true},
Log: config.LogConfig{
Level: "info",
Caller: true,
StacktraceLevel: "error",
Sampling: config.LogSamplingConfig{
Enabled: false,
Initial: 100,
Thereafter: 100,
},
},
}
return service.NewOpsService(nil, settingRepo, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
}
func TestOpsRuntimeLoggingHandler_GetConfig(t *testing.T) {
h := NewOpsHandler(newRuntimeOpsService(t))
r := newOpsRuntimeRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/runtime/logging", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d, want 200", w.Code)
}
}
func TestOpsRuntimeLoggingHandler_UpdateUnauthorized(t *testing.T) {
h := NewOpsHandler(newRuntimeOpsService(t))
r := newOpsRuntimeRouter(h, false)
body := `{"level":"debug","enable_sampling":false,"sampling_initial":100,"sampling_thereafter":100,"caller":true,"stacktrace_level":"error","retention_days":30}`
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/runtime/logging", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", w.Code)
}
}
func TestOpsRuntimeLoggingHandler_UpdateAndResetSuccess(t *testing.T) {
h := NewOpsHandler(newRuntimeOpsService(t))
r := newOpsRuntimeRouter(h, true)
payload := map[string]any{
"level": "debug",
"enable_sampling": false,
"sampling_initial": 100,
"sampling_thereafter": 100,
"caller": true,
"stacktrace_level": "error",
"retention_days": 30,
}
raw, _ := json.Marshal(payload)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/runtime/logging", bytes.NewReader(raw))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("update status=%d, want 200, body=%s", w.Code, w.Body.String())
}
w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/runtime/logging/reset", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("reset status=%d, want 200, body=%s", w.Code, w.Body.String())
}
}

View File

@@ -0,0 +1,273 @@
package admin
import (
"net/http"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// GetEmailNotificationConfig returns Ops email notification config (DB-backed).
// GET /api/v1/admin/ops/email-notification/config
func (h *OpsHandler) GetEmailNotificationConfig(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
cfg, err := h.opsService.GetEmailNotificationConfig(c.Request.Context())
if err != nil {
response.Error(c, http.StatusInternalServerError, "Failed to get email notification config")
return
}
response.Success(c, cfg)
}
// UpdateEmailNotificationConfig updates Ops email notification config (DB-backed).
// PUT /api/v1/admin/ops/email-notification/config
func (h *OpsHandler) UpdateEmailNotificationConfig(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
var req service.OpsEmailNotificationConfigUpdateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
updated, err := h.opsService.UpdateEmailNotificationConfig(c.Request.Context(), &req)
if err != nil {
// Most failures here are validation errors from request payload; treat as 400.
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, updated)
}
// GetAlertRuntimeSettings returns Ops alert evaluator runtime settings (DB-backed).
// GET /api/v1/admin/ops/runtime/alert
func (h *OpsHandler) GetAlertRuntimeSettings(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
cfg, err := h.opsService.GetOpsAlertRuntimeSettings(c.Request.Context())
if err != nil {
response.Error(c, http.StatusInternalServerError, "Failed to get alert runtime settings")
return
}
response.Success(c, cfg)
}
// UpdateAlertRuntimeSettings updates Ops alert evaluator runtime settings (DB-backed).
// PUT /api/v1/admin/ops/runtime/alert
func (h *OpsHandler) UpdateAlertRuntimeSettings(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
var req service.OpsAlertRuntimeSettings
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
updated, err := h.opsService.UpdateOpsAlertRuntimeSettings(c.Request.Context(), &req)
if err != nil {
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, updated)
}
// GetRuntimeLogConfig returns runtime log config (DB-backed).
// GET /api/v1/admin/ops/runtime/logging
func (h *OpsHandler) GetRuntimeLogConfig(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
cfg, err := h.opsService.GetRuntimeLogConfig(c.Request.Context())
if err != nil {
response.Error(c, http.StatusInternalServerError, "Failed to get runtime log config")
return
}
response.Success(c, cfg)
}
// UpdateRuntimeLogConfig updates runtime log config and applies changes immediately.
// PUT /api/v1/admin/ops/runtime/logging
func (h *OpsHandler) UpdateRuntimeLogConfig(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
var req service.OpsRuntimeLogConfig
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
updated, err := h.opsService.UpdateRuntimeLogConfig(c.Request.Context(), &req, subject.UserID)
if err != nil {
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, updated)
}
// ResetRuntimeLogConfig removes runtime override and falls back to env/yaml baseline.
// POST /api/v1/admin/ops/runtime/logging/reset
func (h *OpsHandler) ResetRuntimeLogConfig(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
updated, err := h.opsService.ResetRuntimeLogConfig(c.Request.Context(), subject.UserID)
if err != nil {
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, updated)
}
// GetAdvancedSettings returns Ops advanced settings (DB-backed).
// GET /api/v1/admin/ops/advanced-settings
func (h *OpsHandler) GetAdvancedSettings(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
cfg, err := h.opsService.GetOpsAdvancedSettings(c.Request.Context())
if err != nil {
response.Error(c, http.StatusInternalServerError, "Failed to get advanced settings")
return
}
response.Success(c, cfg)
}
// UpdateAdvancedSettings updates Ops advanced settings (DB-backed).
// PUT /api/v1/admin/ops/advanced-settings
func (h *OpsHandler) UpdateAdvancedSettings(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
var req service.OpsAdvancedSettings
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
updated, err := h.opsService.UpdateOpsAdvancedSettings(c.Request.Context(), &req)
if err != nil {
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, updated)
}
// GetMetricThresholds returns Ops metric thresholds (DB-backed).
// GET /api/v1/admin/ops/settings/metric-thresholds
func (h *OpsHandler) GetMetricThresholds(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
cfg, err := h.opsService.GetMetricThresholds(c.Request.Context())
if err != nil {
response.Error(c, http.StatusInternalServerError, "Failed to get metric thresholds")
return
}
response.Success(c, cfg)
}
// UpdateMetricThresholds updates Ops metric thresholds (DB-backed).
// PUT /api/v1/admin/ops/settings/metric-thresholds
func (h *OpsHandler) UpdateMetricThresholds(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
var req service.OpsMetricThresholds
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
updated, err := h.opsService.UpdateMetricThresholds(c.Request.Context(), &req)
if err != nil {
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, updated)
}

View File

@@ -0,0 +1,145 @@
package admin
import (
"encoding/json"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"golang.org/x/sync/errgroup"
)
var opsDashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
type opsDashboardSnapshotV2Response struct {
GeneratedAt string `json:"generated_at"`
Overview *service.OpsDashboardOverview `json:"overview"`
ThroughputTrend *service.OpsThroughputTrendResponse `json:"throughput_trend"`
ErrorTrend *service.OpsErrorTrendResponse `json:"error_trend"`
}
type opsDashboardSnapshotV2CacheKey struct {
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
Platform string `json:"platform"`
GroupID *int64 `json:"group_id"`
QueryMode service.OpsQueryMode `json:"mode"`
BucketSecond int `json:"bucket_second"`
}
// GetDashboardSnapshotV2 returns ops dashboard core snapshot in one request.
// GET /api/v1/admin/ops/dashboard/snapshot-v2
func (h *OpsHandler) GetDashboardSnapshotV2(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
startTime, endTime, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsDashboardFilter{
StartTime: startTime,
EndTime: endTime,
Platform: strings.TrimSpace(c.Query("platform")),
QueryMode: parseOpsQueryMode(c),
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
bucketSeconds := pickThroughputBucketSeconds(endTime.Sub(startTime))
keyRaw, _ := json.Marshal(opsDashboardSnapshotV2CacheKey{
StartTime: startTime.UTC().Format(time.RFC3339),
EndTime: endTime.UTC().Format(time.RFC3339),
Platform: filter.Platform,
GroupID: filter.GroupID,
QueryMode: filter.QueryMode,
BucketSecond: bucketSeconds,
})
cacheKey := string(keyRaw)
if cached, ok := opsDashboardSnapshotV2Cache.Get(cacheKey); ok {
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
c.Status(http.StatusNotModified)
return
}
}
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
var (
overview *service.OpsDashboardOverview
trend *service.OpsThroughputTrendResponse
errTrend *service.OpsErrorTrendResponse
)
g, gctx := errgroup.WithContext(c.Request.Context())
g.Go(func() error {
f := *filter
result, err := h.opsService.GetDashboardOverview(gctx, &f)
if err != nil {
return err
}
overview = result
return nil
})
g.Go(func() error {
f := *filter
result, err := h.opsService.GetThroughputTrend(gctx, &f, bucketSeconds)
if err != nil {
return err
}
trend = result
return nil
})
g.Go(func() error {
f := *filter
result, err := h.opsService.GetErrorTrend(gctx, &f, bucketSeconds)
if err != nil {
return err
}
errTrend = result
return nil
})
if err := g.Wait(); err != nil {
response.ErrorFrom(c, err)
return
}
resp := &opsDashboardSnapshotV2Response{
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
Overview: overview,
ThroughputTrend: trend,
ErrorTrend: errTrend,
}
cached := opsDashboardSnapshotV2Cache.Set(cacheKey, resp)
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
}
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, resp)
}

View File

@@ -0,0 +1,174 @@
package admin
import (
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type opsSystemLogCleanupRequest struct {
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
Level string `json:"level"`
Component string `json:"component"`
RequestID string `json:"request_id"`
ClientRequestID string `json:"client_request_id"`
UserID *int64 `json:"user_id"`
AccountID *int64 `json:"account_id"`
Platform string `json:"platform"`
Model string `json:"model"`
Query string `json:"q"`
}
// ListSystemLogs returns indexed system logs.
// GET /api/v1/admin/ops/system-logs
func (h *OpsHandler) ListSystemLogs(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
page, pageSize := response.ParsePagination(c)
if pageSize > 200 {
pageSize = 200
}
start, end, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsSystemLogFilter{
Page: page,
PageSize: pageSize,
StartTime: &start,
EndTime: &end,
Level: strings.TrimSpace(c.Query("level")),
Component: strings.TrimSpace(c.Query("component")),
RequestID: strings.TrimSpace(c.Query("request_id")),
ClientRequestID: strings.TrimSpace(c.Query("client_request_id")),
Platform: strings.TrimSpace(c.Query("platform")),
Model: strings.TrimSpace(c.Query("model")),
Query: strings.TrimSpace(c.Query("q")),
}
if v := strings.TrimSpace(c.Query("user_id")); v != "" {
id, parseErr := strconv.ParseInt(v, 10, 64)
if parseErr != nil || id <= 0 {
response.BadRequest(c, "Invalid user_id")
return
}
filter.UserID = &id
}
if v := strings.TrimSpace(c.Query("account_id")); v != "" {
id, parseErr := strconv.ParseInt(v, 10, 64)
if parseErr != nil || id <= 0 {
response.BadRequest(c, "Invalid account_id")
return
}
filter.AccountID = &id
}
result, err := h.opsService.ListSystemLogs(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, result.Logs, int64(result.Total), result.Page, result.PageSize)
}
// CleanupSystemLogs deletes indexed system logs by filter.
// POST /api/v1/admin/ops/system-logs/cleanup
func (h *OpsHandler) CleanupSystemLogs(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
var req opsSystemLogCleanupRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
parseTS := func(raw string) (*time.Time, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil, nil
}
if t, err := time.Parse(time.RFC3339Nano, raw); err == nil {
return &t, nil
}
t, err := time.Parse(time.RFC3339, raw)
if err != nil {
return nil, err
}
return &t, nil
}
start, err := parseTS(req.StartTime)
if err != nil {
response.BadRequest(c, "Invalid start_time")
return
}
end, err := parseTS(req.EndTime)
if err != nil {
response.BadRequest(c, "Invalid end_time")
return
}
filter := &service.OpsSystemLogCleanupFilter{
StartTime: start,
EndTime: end,
Level: strings.TrimSpace(req.Level),
Component: strings.TrimSpace(req.Component),
RequestID: strings.TrimSpace(req.RequestID),
ClientRequestID: strings.TrimSpace(req.ClientRequestID),
UserID: req.UserID,
AccountID: req.AccountID,
Platform: strings.TrimSpace(req.Platform),
Model: strings.TrimSpace(req.Model),
Query: strings.TrimSpace(req.Query),
}
deleted, err := h.opsService.CleanupSystemLogs(c.Request.Context(), filter, subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"deleted": deleted})
}
// GetSystemLogIngestionHealth returns sink health metrics.
// GET /api/v1/admin/ops/system-logs/health
func (h *OpsHandler) GetSystemLogIngestionHealth(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, h.opsService.GetSystemLogSinkHealth())
}

View File

@@ -0,0 +1,233 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type responseEnvelope struct {
Code int `json:"code"`
Message string `json:"message"`
Data json.RawMessage `json:"data"`
}
func newOpsSystemLogTestRouter(handler *OpsHandler, withUser bool) *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()
if withUser {
r.Use(func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 99})
c.Next()
})
}
r.GET("/logs", handler.ListSystemLogs)
r.POST("/logs/cleanup", handler.CleanupSystemLogs)
r.GET("/logs/health", handler.GetSystemLogIngestionHealth)
return r
}
func TestOpsSystemLogHandler_ListUnavailable(t *testing.T) {
h := NewOpsHandler(nil)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("status=%d, want 503", w.Code)
}
}
func TestOpsSystemLogHandler_ListInvalidUserID(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs?user_id=abc", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_ListInvalidAccountID(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs?account_id=-1", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_ListMonitoringDisabled(t *testing.T) {
svc := service.NewOpsService(nil, nil, &config.Config{
Ops: config.OpsConfig{Enabled: false},
}, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("status=%d, want 404", w.Code)
}
}
func TestOpsSystemLogHandler_ListSuccess(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs?time_range=30m&page=1&page_size=20", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d, want 200", w.Code)
}
var resp responseEnvelope
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("unmarshal response: %v", err)
}
if resp.Code != 0 {
t.Fatalf("unexpected response code: %+v", resp)
}
}
func TestOpsSystemLogHandler_CleanupUnauthorized(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupInvalidPayload(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{bad-json`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupInvalidTime(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"start_time":"bad","request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupInvalidEndTime(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"end_time":"bad","request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupServiceUnavailable(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("status=%d, want 503", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupMonitoringDisabled(t *testing.T) {
svc := service.NewOpsService(nil, nil, &config.Config{
Ops: config.OpsConfig{Enabled: false},
}, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("status=%d, want 404", w.Code)
}
}
func TestOpsSystemLogHandler_Health(t *testing.T) {
sink := service.NewOpsSystemLogSink(nil)
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, sink)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs/health", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d, want 200", w.Code)
}
}
func TestOpsSystemLogHandler_HealthUnavailableAndMonitoringDisabled(t *testing.T) {
h := NewOpsHandler(nil)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs/health", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("status=%d, want 503", w.Code)
}
svc := service.NewOpsService(nil, nil, &config.Config{
Ops: config.OpsConfig{Enabled: false},
}, nil, nil, nil, nil, nil, nil, nil, nil)
h = NewOpsHandler(svc)
r = newOpsSystemLogTestRouter(h, false)
w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/logs/health", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("status=%d, want 404", w.Code)
}
}

View File

@@ -0,0 +1,761 @@
package admin
import (
"context"
"encoding/json"
"math"
"net"
"net/http"
"net/netip"
"net/url"
"os"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
type OpsWSProxyConfig struct {
TrustProxy bool
TrustedProxies []netip.Prefix
OriginPolicy string
}
const (
envOpsWSTrustProxy = "OPS_WS_TRUST_PROXY"
envOpsWSTrustedProxies = "OPS_WS_TRUSTED_PROXIES"
envOpsWSOriginPolicy = "OPS_WS_ORIGIN_POLICY"
envOpsWSMaxConns = "OPS_WS_MAX_CONNS"
envOpsWSMaxConnsPerIP = "OPS_WS_MAX_CONNS_PER_IP"
)
const (
OriginPolicyStrict = "strict"
OriginPolicyPermissive = "permissive"
)
var opsWSProxyConfig = loadOpsWSProxyConfigFromEnv()
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return isAllowedOpsWSOrigin(r)
},
// Subprotocol negotiation:
// - The frontend passes ["sub2api-admin", "jwt.<token>"].
// - We always select "sub2api-admin" so the token is never echoed back in the handshake response.
Subprotocols: []string{"sub2api-admin"},
}
const (
qpsWSPushInterval = 2 * time.Second
qpsWSRefreshInterval = 5 * time.Second
qpsWSRequestCountWindow = 1 * time.Minute
defaultMaxWSConns = 100
defaultMaxWSConnsPerIP = 20
)
var wsConnCount atomic.Int32
var wsConnCountByIPMu sync.Mutex
var wsConnCountByIP = make(map[string]int32)
const qpsWSIdleStopDelay = 30 * time.Second
const (
opsWSCloseRealtimeDisabled = 4001
)
var qpsWSIdleStopMu sync.Mutex
var qpsWSIdleStopTimer *time.Timer
func cancelQPSWSIdleStop() {
qpsWSIdleStopMu.Lock()
if qpsWSIdleStopTimer != nil {
qpsWSIdleStopTimer.Stop()
qpsWSIdleStopTimer = nil
}
qpsWSIdleStopMu.Unlock()
}
func scheduleQPSWSIdleStop() {
qpsWSIdleStopMu.Lock()
if qpsWSIdleStopTimer != nil {
qpsWSIdleStopMu.Unlock()
return
}
qpsWSIdleStopTimer = time.AfterFunc(qpsWSIdleStopDelay, func() {
// Only stop if truly idle at fire time.
if wsConnCount.Load() == 0 {
qpsWSCache.Stop()
}
qpsWSIdleStopMu.Lock()
qpsWSIdleStopTimer = nil
qpsWSIdleStopMu.Unlock()
})
qpsWSIdleStopMu.Unlock()
}
type opsWSRuntimeLimits struct {
MaxConns int32
MaxConnsPerIP int32
}
var opsWSLimits = loadOpsWSRuntimeLimitsFromEnv()
const (
qpsWSWriteTimeout = 10 * time.Second
qpsWSPongWait = 60 * time.Second
qpsWSPingInterval = 30 * time.Second
// We don't expect clients to send application messages; we only read to process control frames (Pong/Close).
qpsWSMaxReadBytes = 1024
)
type opsWSQPSCache struct {
refreshInterval time.Duration
requestCountWindow time.Duration
lastUpdatedUnixNano atomic.Int64
payload atomic.Value // []byte
opsService *service.OpsService
cancel context.CancelFunc
done chan struct{}
mu sync.Mutex
running bool
}
var qpsWSCache = &opsWSQPSCache{
refreshInterval: qpsWSRefreshInterval,
requestCountWindow: qpsWSRequestCountWindow,
}
func (c *opsWSQPSCache) start(opsService *service.OpsService) {
if c == nil || opsService == nil {
return
}
for {
c.mu.Lock()
if c.running {
c.mu.Unlock()
return
}
// If a previous refresh loop is currently stopping, wait for it to fully exit.
done := c.done
if done != nil {
c.mu.Unlock()
<-done
c.mu.Lock()
if c.done == done && !c.running {
c.done = nil
}
c.mu.Unlock()
continue
}
c.opsService = opsService
ctx, cancel := context.WithCancel(context.Background())
c.cancel = cancel
c.done = make(chan struct{})
done = c.done
c.running = true
c.mu.Unlock()
go func() {
defer close(done)
c.refreshLoop(ctx)
}()
return
}
}
// Stop stops the background refresh loop.
// It is safe to call multiple times.
func (c *opsWSQPSCache) Stop() {
if c == nil {
return
}
c.mu.Lock()
if !c.running {
done := c.done
c.mu.Unlock()
if done != nil {
<-done
}
return
}
cancel := c.cancel
c.cancel = nil
c.running = false
c.opsService = nil
done := c.done
c.mu.Unlock()
if cancel != nil {
cancel()
}
if done != nil {
<-done
}
c.mu.Lock()
if c.done == done && !c.running {
c.done = nil
}
c.mu.Unlock()
}
func (c *opsWSQPSCache) refreshLoop(ctx context.Context) {
ticker := time.NewTicker(c.refreshInterval)
defer ticker.Stop()
c.refresh(ctx)
for {
select {
case <-ticker.C:
c.refresh(ctx)
case <-ctx.Done():
return
}
}
}
func (c *opsWSQPSCache) refresh(parentCtx context.Context) {
if c == nil {
return
}
c.mu.Lock()
opsService := c.opsService
c.mu.Unlock()
if opsService == nil {
return
}
if parentCtx == nil {
parentCtx = context.Background()
}
ctx, cancel := context.WithTimeout(parentCtx, 10*time.Second)
defer cancel()
now := time.Now().UTC()
stats, err := opsService.GetWindowStats(ctx, now.Add(-c.requestCountWindow), now)
if err != nil || stats == nil {
if err != nil {
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] refresh: get window stats failed: %v", err)
}
return
}
requestCount := stats.SuccessCount + stats.ErrorCountTotal
qps := 0.0
tps := 0.0
if c.requestCountWindow > 0 {
seconds := c.requestCountWindow.Seconds()
qps = roundTo1DP(float64(requestCount) / seconds)
tps = roundTo1DP(float64(stats.TokenConsumed) / seconds)
}
payload := gin.H{
"type": "qps_update",
"timestamp": now.Format(time.RFC3339),
"data": gin.H{
"qps": qps,
"tps": tps,
"request_count": requestCount,
},
}
msg, err := json.Marshal(payload)
if err != nil {
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] refresh: marshal payload failed: %v", err)
return
}
c.payload.Store(msg)
c.lastUpdatedUnixNano.Store(now.UnixNano())
}
func roundTo1DP(v float64) float64 {
return math.Round(v*10) / 10
}
func (c *opsWSQPSCache) getPayload() []byte {
if c == nil {
return nil
}
if cached, ok := c.payload.Load().([]byte); ok && cached != nil {
return cached
}
return nil
}
func closeWS(conn *websocket.Conn, code int, reason string) {
if conn == nil {
return
}
msg := websocket.FormatCloseMessage(code, reason)
_ = conn.WriteControl(websocket.CloseMessage, msg, time.Now().Add(qpsWSWriteTimeout))
_ = conn.Close()
}
// QPSWSHandler handles realtime QPS push via WebSocket.
// GET /api/v1/admin/ops/ws/qps
func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
clientIP := requestClientIP(c.Request)
if h == nil || h.opsService == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "ops service not initialized"})
return
}
// If realtime monitoring is disabled, prefer a successful WS upgrade followed by a clean close
// with a deterministic close code. This prevents clients from spinning on 404/1006 reconnect loops.
if !h.opsService.IsRealtimeMonitoringEnabled(c.Request.Context()) {
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "ops realtime monitoring is disabled"})
return
}
closeWS(conn, opsWSCloseRealtimeDisabled, "realtime_disabled")
return
}
cancelQPSWSIdleStop()
// Lazily start the background refresh loop so unit tests that never hit the
// websocket route don't spawn goroutines that depend on DB/Redis stubs.
qpsWSCache.start(h.opsService)
// Reserve a global slot before upgrading the connection to keep the limit strict.
if !tryAcquireOpsWSTotalSlot(opsWSLimits.MaxConns) {
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns)
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
return
}
defer func() {
if wsConnCount.Add(-1) == 0 {
scheduleQPSWSIdleStop()
}
}()
if opsWSLimits.MaxConnsPerIP > 0 && clientIP != "" {
if !tryAcquireOpsWSIPSlot(clientIP, opsWSLimits.MaxConnsPerIP) {
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] per-ip connection limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP)
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
return
}
defer releaseOpsWSIPSlot(clientIP)
}
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] upgrade failed: %v", err)
return
}
defer func() {
_ = conn.Close()
}()
handleQPSWebSocket(c.Request.Context(), conn)
}
func tryAcquireOpsWSTotalSlot(limit int32) bool {
if limit <= 0 {
return true
}
for {
current := wsConnCount.Load()
if current >= limit {
return false
}
if wsConnCount.CompareAndSwap(current, current+1) {
return true
}
}
}
func tryAcquireOpsWSIPSlot(clientIP string, limit int32) bool {
if strings.TrimSpace(clientIP) == "" || limit <= 0 {
return true
}
wsConnCountByIPMu.Lock()
defer wsConnCountByIPMu.Unlock()
current := wsConnCountByIP[clientIP]
if current >= limit {
return false
}
wsConnCountByIP[clientIP] = current + 1
return true
}
func releaseOpsWSIPSlot(clientIP string) {
if strings.TrimSpace(clientIP) == "" {
return
}
wsConnCountByIPMu.Lock()
defer wsConnCountByIPMu.Unlock()
current, ok := wsConnCountByIP[clientIP]
if !ok {
return
}
if current <= 1 {
delete(wsConnCountByIP, clientIP)
return
}
wsConnCountByIP[clientIP] = current - 1
}
func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
if conn == nil {
return
}
ctx, cancel := context.WithCancel(parentCtx)
defer cancel()
var closeOnce sync.Once
closeConn := func() {
closeOnce.Do(func() {
_ = conn.Close()
})
}
closeFrameCh := make(chan []byte, 1)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
defer cancel()
conn.SetReadLimit(qpsWSMaxReadBytes)
if err := conn.SetReadDeadline(time.Now().Add(qpsWSPongWait)); err != nil {
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] set read deadline failed: %v", err)
return
}
conn.SetPongHandler(func(string) error {
return conn.SetReadDeadline(time.Now().Add(qpsWSPongWait))
})
conn.SetCloseHandler(func(code int, text string) error {
select {
case closeFrameCh <- websocket.FormatCloseMessage(code, text):
default:
}
cancel()
return nil
})
for {
_, _, err := conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] read failed: %v", err)
}
return
}
}
}()
// Push QPS data every 2 seconds (values are globally cached and refreshed at most once per qpsWSRefreshInterval).
pushTicker := time.NewTicker(qpsWSPushInterval)
defer pushTicker.Stop()
// Heartbeat ping every 30 seconds.
pingTicker := time.NewTicker(qpsWSPingInterval)
defer pingTicker.Stop()
writeWithTimeout := func(messageType int, data []byte) error {
if err := conn.SetWriteDeadline(time.Now().Add(qpsWSWriteTimeout)); err != nil {
return err
}
return conn.WriteMessage(messageType, data)
}
sendClose := func(closeFrame []byte) {
if closeFrame == nil {
closeFrame = websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
}
_ = writeWithTimeout(websocket.CloseMessage, closeFrame)
}
for {
select {
case <-pushTicker.C:
msg := qpsWSCache.getPayload()
if msg == nil {
continue
}
if err := writeWithTimeout(websocket.TextMessage, msg); err != nil {
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] write failed: %v", err)
cancel()
closeConn()
wg.Wait()
return
}
case <-pingTicker.C:
if err := writeWithTimeout(websocket.PingMessage, nil); err != nil {
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] ping failed: %v", err)
cancel()
closeConn()
wg.Wait()
return
}
case closeFrame := <-closeFrameCh:
sendClose(closeFrame)
closeConn()
wg.Wait()
return
case <-ctx.Done():
var closeFrame []byte
select {
case closeFrame = <-closeFrameCh:
default:
}
sendClose(closeFrame)
closeConn()
wg.Wait()
return
}
}
}
func isAllowedOpsWSOrigin(r *http.Request) bool {
if r == nil {
return false
}
origin := strings.TrimSpace(r.Header.Get("Origin"))
if origin == "" {
switch strings.ToLower(strings.TrimSpace(opsWSProxyConfig.OriginPolicy)) {
case OriginPolicyStrict:
return false
case OriginPolicyPermissive, "":
return true
default:
return true
}
}
parsed, err := url.Parse(origin)
if err != nil || parsed.Hostname() == "" {
return false
}
originHost := strings.ToLower(parsed.Hostname())
trustProxyHeaders := shouldTrustOpsWSProxyHeaders(r)
reqHost := hostWithoutPort(r.Host)
if trustProxyHeaders {
xfHost := strings.TrimSpace(r.Header.Get("X-Forwarded-Host"))
if xfHost != "" {
xfHost = strings.TrimSpace(strings.Split(xfHost, ",")[0])
if xfHost != "" {
reqHost = hostWithoutPort(xfHost)
}
}
}
reqHost = strings.ToLower(reqHost)
if reqHost == "" {
return false
}
return originHost == reqHost
}
func shouldTrustOpsWSProxyHeaders(r *http.Request) bool {
if r == nil {
return false
}
if !opsWSProxyConfig.TrustProxy {
return false
}
peerIP, ok := requestPeerIP(r)
if !ok {
return false
}
return isAddrInTrustedProxies(peerIP, opsWSProxyConfig.TrustedProxies)
}
func requestPeerIP(r *http.Request) (netip.Addr, bool) {
if r == nil {
return netip.Addr{}, false
}
host, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr))
if err != nil {
host = strings.TrimSpace(r.RemoteAddr)
}
host = strings.TrimPrefix(host, "[")
host = strings.TrimSuffix(host, "]")
if host == "" {
return netip.Addr{}, false
}
addr, err := netip.ParseAddr(host)
if err != nil {
return netip.Addr{}, false
}
return addr.Unmap(), true
}
func requestClientIP(r *http.Request) string {
if r == nil {
return ""
}
trustProxyHeaders := shouldTrustOpsWSProxyHeaders(r)
if trustProxyHeaders {
xff := strings.TrimSpace(r.Header.Get("X-Forwarded-For"))
if xff != "" {
// Use the left-most entry (original client). If multiple proxies add values, they are comma-separated.
xff = strings.TrimSpace(strings.Split(xff, ",")[0])
xff = strings.TrimPrefix(xff, "[")
xff = strings.TrimSuffix(xff, "]")
if addr, err := netip.ParseAddr(xff); err == nil && addr.IsValid() {
return addr.Unmap().String()
}
}
}
if peer, ok := requestPeerIP(r); ok && peer.IsValid() {
return peer.String()
}
return ""
}
func isAddrInTrustedProxies(addr netip.Addr, trusted []netip.Prefix) bool {
if !addr.IsValid() {
return false
}
for _, p := range trusted {
if p.Contains(addr) {
return true
}
}
return false
}
func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig {
cfg := OpsWSProxyConfig{
TrustProxy: true,
TrustedProxies: defaultTrustedProxies(),
OriginPolicy: OriginPolicyPermissive,
}
if v := strings.TrimSpace(os.Getenv(envOpsWSTrustProxy)); v != "" {
if parsed, err := strconv.ParseBool(v); err == nil {
cfg.TrustProxy = parsed
} else {
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy)
}
}
if raw := strings.TrimSpace(os.Getenv(envOpsWSTrustedProxies)); raw != "" {
prefixes, invalid := parseTrustedProxyList(raw)
if len(invalid) > 0 {
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", "))
}
cfg.TrustedProxies = prefixes
}
if v := strings.TrimSpace(os.Getenv(envOpsWSOriginPolicy)); v != "" {
normalized := strings.ToLower(v)
switch normalized {
case OriginPolicyStrict, OriginPolicyPermissive:
cfg.OriginPolicy = normalized
default:
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy)
}
}
return cfg
}
func loadOpsWSRuntimeLimitsFromEnv() opsWSRuntimeLimits {
cfg := opsWSRuntimeLimits{
MaxConns: defaultMaxWSConns,
MaxConnsPerIP: defaultMaxWSConnsPerIP,
}
if v := strings.TrimSpace(os.Getenv(envOpsWSMaxConns)); v != "" {
if parsed, err := strconv.Atoi(v); err == nil && parsed > 0 {
cfg.MaxConns = int32(parsed)
} else {
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected int>0); using default=%d", envOpsWSMaxConns, v, cfg.MaxConns)
}
}
if v := strings.TrimSpace(os.Getenv(envOpsWSMaxConnsPerIP)); v != "" {
if parsed, err := strconv.Atoi(v); err == nil && parsed >= 0 {
cfg.MaxConnsPerIP = int32(parsed)
} else {
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected int>=0); using default=%d", envOpsWSMaxConnsPerIP, v, cfg.MaxConnsPerIP)
}
}
return cfg
}
func defaultTrustedProxies() []netip.Prefix {
prefixes, _ := parseTrustedProxyList("127.0.0.0/8,::1/128")
return prefixes
}
func parseTrustedProxyList(raw string) (prefixes []netip.Prefix, invalid []string) {
for _, token := range strings.Split(raw, ",") {
item := strings.TrimSpace(token)
if item == "" {
continue
}
var (
p netip.Prefix
err error
)
if strings.Contains(item, "/") {
p, err = netip.ParsePrefix(item)
} else {
var addr netip.Addr
addr, err = netip.ParseAddr(item)
if err == nil {
addr = addr.Unmap()
bits := 128
if addr.Is4() {
bits = 32
}
p = netip.PrefixFrom(addr, bits)
}
}
if err != nil || !p.IsValid() {
invalid = append(invalid, item)
continue
}
prefixes = append(prefixes, p.Masked())
}
return prefixes, invalid
}
func hostWithoutPort(hostport string) string {
hostport = strings.TrimSpace(hostport)
if hostport == "" {
return ""
}
if host, _, err := net.SplitHostPort(hostport); err == nil {
return host
}
if strings.HasPrefix(hostport, "[") && strings.HasSuffix(hostport, "]") {
return strings.Trim(hostport, "[]")
}
parts := strings.Split(hostport, ":")
return parts[0]
}

View File

@@ -0,0 +1,209 @@
package admin
import (
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// PromoHandler handles admin promo code management
type PromoHandler struct {
promoService *service.PromoService
}
// NewPromoHandler creates a new admin promo handler
func NewPromoHandler(promoService *service.PromoService) *PromoHandler {
return &PromoHandler{
promoService: promoService,
}
}
// CreatePromoCodeRequest represents create promo code request
type CreatePromoCodeRequest struct {
Code string `json:"code"` // 可选,为空则自动生成
BonusAmount float64 `json:"bonus_amount" binding:"required,min=0"` // 赠送余额
MaxUses int `json:"max_uses" binding:"min=0"` // 最大使用次数0=无限
ExpiresAt *int64 `json:"expires_at"` // 过期时间戳(秒)
Notes string `json:"notes"` // 备注
}
// UpdatePromoCodeRequest represents update promo code request
type UpdatePromoCodeRequest struct {
Code *string `json:"code"`
BonusAmount *float64 `json:"bonus_amount" binding:"omitempty,min=0"`
MaxUses *int `json:"max_uses" binding:"omitempty,min=0"`
Status *string `json:"status" binding:"omitempty,oneof=active disabled"`
ExpiresAt *int64 `json:"expires_at"`
Notes *string `json:"notes"`
}
// List handles listing all promo codes with pagination
// GET /api/v1/admin/promo-codes
func (h *PromoHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
status := c.Query("status")
search := strings.TrimSpace(c.Query("search"))
if len(search) > 100 {
search = search[:100]
}
params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
}
codes, paginationResult, err := h.promoService.List(c.Request.Context(), params, status, search)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.PromoCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.PromoCodeFromService(&codes[i]))
}
response.Paginated(c, out, paginationResult.Total, page, pageSize)
}
// GetByID handles getting a promo code by ID
// GET /api/v1/admin/promo-codes/:id
func (h *PromoHandler) GetByID(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid promo code ID")
return
}
code, err := h.promoService.GetByID(c.Request.Context(), codeID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.PromoCodeFromService(code))
}
// Create handles creating a new promo code
// POST /api/v1/admin/promo-codes
func (h *PromoHandler) Create(c *gin.Context) {
var req CreatePromoCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
input := &service.CreatePromoCodeInput{
Code: req.Code,
BonusAmount: req.BonusAmount,
MaxUses: req.MaxUses,
Notes: req.Notes,
}
if req.ExpiresAt != nil {
t := time.Unix(*req.ExpiresAt, 0)
input.ExpiresAt = &t
}
code, err := h.promoService.Create(c.Request.Context(), input)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.PromoCodeFromService(code))
}
// Update handles updating a promo code
// PUT /api/v1/admin/promo-codes/:id
func (h *PromoHandler) Update(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid promo code ID")
return
}
var req UpdatePromoCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
input := &service.UpdatePromoCodeInput{
Code: req.Code,
BonusAmount: req.BonusAmount,
MaxUses: req.MaxUses,
Status: req.Status,
Notes: req.Notes,
}
if req.ExpiresAt != nil {
if *req.ExpiresAt == 0 {
// 0 表示清除过期时间
input.ExpiresAt = nil
} else {
t := time.Unix(*req.ExpiresAt, 0)
input.ExpiresAt = &t
}
}
code, err := h.promoService.Update(c.Request.Context(), codeID, input)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.PromoCodeFromService(code))
}
// Delete handles deleting a promo code
// DELETE /api/v1/admin/promo-codes/:id
func (h *PromoHandler) Delete(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid promo code ID")
return
}
err = h.promoService.Delete(c.Request.Context(), codeID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Promo code deleted successfully"})
}
// GetUsages handles getting usage records for a promo code
// GET /api/v1/admin/promo-codes/:id/usages
func (h *PromoHandler) GetUsages(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid promo code ID")
return
}
page, pageSize := response.ParsePagination(c)
params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
}
usages, paginationResult, err := h.promoService.ListUsages(c.Request.Context(), codeID, params)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.PromoCodeUsage, 0, len(usages))
for i := range usages {
out = append(out, *dto.PromoCodeUsageFromService(&usages[i]))
}
response.Paginated(c, out, paginationResult.Total, page, pageSize)
}

View File

@@ -0,0 +1,239 @@
package admin
import (
"context"
"fmt"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ExportData exports proxy-only data for migration.
func (h *ProxyHandler) ExportData(c *gin.Context) {
ctx := c.Request.Context()
selectedIDs, err := parseProxyIDs(c)
if err != nil {
response.BadRequest(c, err.Error())
return
}
var proxies []service.Proxy
if len(selectedIDs) > 0 {
proxies, err = h.getProxiesByIDs(ctx, selectedIDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
} else {
protocol := c.Query("protocol")
status := c.Query("status")
search := strings.TrimSpace(c.Query("search"))
if len(search) > 100 {
search = search[:100]
}
proxies, err = h.listProxiesFiltered(ctx, protocol, status, search)
if err != nil {
response.ErrorFrom(c, err)
return
}
}
dataProxies := make([]DataProxy, 0, len(proxies))
for i := range proxies {
p := proxies[i]
key := buildProxyKey(p.Protocol, p.Host, p.Port, p.Username, p.Password)
dataProxies = append(dataProxies, DataProxy{
ProxyKey: key,
Name: p.Name,
Protocol: p.Protocol,
Host: p.Host,
Port: p.Port,
Username: p.Username,
Password: p.Password,
Status: p.Status,
})
}
payload := DataPayload{
ExportedAt: time.Now().UTC().Format(time.RFC3339),
Proxies: dataProxies,
Accounts: []DataAccount{},
}
response.Success(c, payload)
}
// ImportData imports proxy-only data for migration.
func (h *ProxyHandler) ImportData(c *gin.Context) {
type ProxyImportRequest struct {
Data DataPayload `json:"data"`
}
var req ProxyImportRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := validateDataHeader(req.Data); err != nil {
response.BadRequest(c, err.Error())
return
}
ctx := c.Request.Context()
result := DataImportResult{}
existingProxies, err := h.listProxiesFiltered(ctx, "", "", "")
if err != nil {
response.ErrorFrom(c, err)
return
}
proxyByKey := make(map[string]service.Proxy, len(existingProxies))
for i := range existingProxies {
p := existingProxies[i]
key := buildProxyKey(p.Protocol, p.Host, p.Port, p.Username, p.Password)
proxyByKey[key] = p
}
latencyProbeIDs := make([]int64, 0, len(req.Data.Proxies))
for i := range req.Data.Proxies {
item := req.Data.Proxies[i]
key := item.ProxyKey
if key == "" {
key = buildProxyKey(item.Protocol, item.Host, item.Port, item.Username, item.Password)
}
if err := validateDataProxy(item); err != nil {
result.ProxyFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "proxy",
Name: item.Name,
ProxyKey: key,
Message: err.Error(),
})
continue
}
normalizedStatus := normalizeProxyStatus(item.Status)
if existing, ok := proxyByKey[key]; ok {
result.ProxyReused++
if normalizedStatus != "" && normalizedStatus != existing.Status {
if _, err := h.adminService.UpdateProxy(ctx, existing.ID, &service.UpdateProxyInput{Status: normalizedStatus}); err != nil {
result.Errors = append(result.Errors, DataImportError{
Kind: "proxy",
Name: item.Name,
ProxyKey: key,
Message: "update status failed: " + err.Error(),
})
}
}
latencyProbeIDs = append(latencyProbeIDs, existing.ID)
continue
}
created, err := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{
Name: defaultProxyName(item.Name),
Protocol: item.Protocol,
Host: item.Host,
Port: item.Port,
Username: item.Username,
Password: item.Password,
})
if err != nil {
result.ProxyFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "proxy",
Name: item.Name,
ProxyKey: key,
Message: err.Error(),
})
continue
}
result.ProxyCreated++
proxyByKey[key] = *created
if normalizedStatus != "" && normalizedStatus != created.Status {
if _, err := h.adminService.UpdateProxy(ctx, created.ID, &service.UpdateProxyInput{Status: normalizedStatus}); err != nil {
result.Errors = append(result.Errors, DataImportError{
Kind: "proxy",
Name: item.Name,
ProxyKey: key,
Message: "update status failed: " + err.Error(),
})
}
}
// CreateProxy already triggers a latency probe, avoid double probing here.
}
if len(latencyProbeIDs) > 0 {
ids := append([]int64(nil), latencyProbeIDs...)
go func() {
for _, id := range ids {
_, _ = h.adminService.TestProxy(context.Background(), id)
}
}()
}
response.Success(c, result)
}
func (h *ProxyHandler) getProxiesByIDs(ctx context.Context, ids []int64) ([]service.Proxy, error) {
if len(ids) == 0 {
return []service.Proxy{}, nil
}
return h.adminService.GetProxiesByIDs(ctx, ids)
}
func parseProxyIDs(c *gin.Context) ([]int64, error) {
values := c.QueryArray("ids")
if len(values) == 0 {
raw := strings.TrimSpace(c.Query("ids"))
if raw != "" {
values = []string{raw}
}
}
if len(values) == 0 {
return nil, nil
}
ids := make([]int64, 0, len(values))
for _, item := range values {
for _, part := range strings.Split(item, ",") {
part = strings.TrimSpace(part)
if part == "" {
continue
}
id, err := strconv.ParseInt(part, 10, 64)
if err != nil || id <= 0 {
return nil, fmt.Errorf("invalid proxy id: %s", part)
}
ids = append(ids, id)
}
}
return ids, nil
}
func (h *ProxyHandler) listProxiesFiltered(ctx context.Context, protocol, status, search string) ([]service.Proxy, error) {
page := 1
pageSize := dataPageCap
var out []service.Proxy
for {
items, total, err := h.adminService.ListProxies(ctx, page, pageSize, protocol, status, search)
if err != nil {
return nil, err
}
out = append(out, items...)
if len(out) >= int(total) || len(items) == 0 {
break
}
page++
}
return out, nil
}

View File

@@ -0,0 +1,188 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type proxyDataResponse struct {
Code int `json:"code"`
Data DataPayload `json:"data"`
}
type proxyImportResponse struct {
Code int `json:"code"`
Data DataImportResult `json:"data"`
}
func setupProxyDataRouter() (*gin.Engine, *stubAdminService) {
gin.SetMode(gin.TestMode)
router := gin.New()
adminSvc := newStubAdminService()
h := NewProxyHandler(adminSvc)
router.GET("/api/v1/admin/proxies/data", h.ExportData)
router.POST("/api/v1/admin/proxies/data", h.ImportData)
return router, adminSvc
}
func TestProxyExportDataRespectsFilters(t *testing.T) {
router, adminSvc := setupProxyDataRouter()
adminSvc.proxies = []service.Proxy{
{
ID: 1,
Name: "proxy-a",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Username: "user",
Password: "pass",
Status: service.StatusActive,
},
{
ID: 2,
Name: "proxy-b",
Protocol: "https",
Host: "10.0.0.2",
Port: 443,
Username: "u",
Password: "p",
Status: service.StatusDisabled,
},
}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/data?protocol=https", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp proxyDataResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Empty(t, resp.Data.Type)
require.Equal(t, 0, resp.Data.Version)
require.Len(t, resp.Data.Proxies, 1)
require.Len(t, resp.Data.Accounts, 0)
require.Equal(t, "https", resp.Data.Proxies[0].Protocol)
}
func TestProxyExportDataWithSelectedIDs(t *testing.T) {
router, adminSvc := setupProxyDataRouter()
adminSvc.proxies = []service.Proxy{
{
ID: 1,
Name: "proxy-a",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Username: "user",
Password: "pass",
Status: service.StatusActive,
},
{
ID: 2,
Name: "proxy-b",
Protocol: "https",
Host: "10.0.0.2",
Port: 443,
Username: "u",
Password: "p",
Status: service.StatusDisabled,
},
}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/data?ids=2", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp proxyDataResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Len(t, resp.Data.Proxies, 1)
require.Equal(t, "https", resp.Data.Proxies[0].Protocol)
require.Equal(t, "10.0.0.2", resp.Data.Proxies[0].Host)
}
func TestProxyImportDataReusesAndTriggersLatencyProbe(t *testing.T) {
router, adminSvc := setupProxyDataRouter()
adminSvc.proxies = []service.Proxy{
{
ID: 1,
Name: "proxy-a",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Username: "user",
Password: "pass",
Status: service.StatusActive,
},
}
payload := map[string]any{
"data": map[string]any{
"type": dataType,
"version": dataVersion,
"proxies": []map[string]any{
{
"proxy_key": "http|127.0.0.1|8080|user|pass",
"name": "proxy-a",
"protocol": "http",
"host": "127.0.0.1",
"port": 8080,
"username": "user",
"password": "pass",
"status": "inactive",
},
{
"proxy_key": "https|10.0.0.2|443|u|p",
"name": "proxy-b",
"protocol": "https",
"host": "10.0.0.2",
"port": 443,
"username": "u",
"password": "p",
"status": "active",
},
},
"accounts": []map[string]any{},
},
}
body, _ := json.Marshal(payload)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/data", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp proxyImportResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Equal(t, 1, resp.Data.ProxyCreated)
require.Equal(t, 1, resp.Data.ProxyReused)
require.Equal(t, 0, resp.Data.ProxyFailed)
adminSvc.mu.Lock()
updatedIDs := append([]int64(nil), adminSvc.updatedProxyIDs...)
adminSvc.mu.Unlock()
require.Contains(t, updatedIDs, int64(1))
require.Eventually(t, func() bool {
adminSvc.mu.Lock()
defer adminSvc.mu.Unlock()
return len(adminSvc.testedProxyIDs) == 1
}, time.Second, 10*time.Millisecond)
}

View File

@@ -0,0 +1,367 @@
package admin
import (
"context"
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ProxyHandler handles admin proxy management
type ProxyHandler struct {
adminService service.AdminService
}
// NewProxyHandler creates a new admin proxy handler
func NewProxyHandler(adminService service.AdminService) *ProxyHandler {
return &ProxyHandler{
adminService: adminService,
}
}
// CreateProxyRequest represents create proxy request
type CreateProxyRequest struct {
Name string `json:"name" binding:"required"`
Protocol string `json:"protocol" binding:"required,oneof=http https socks5 socks5h"`
Host string `json:"host" binding:"required"`
Port int `json:"port" binding:"required,min=1,max=65535"`
Username string `json:"username"`
Password string `json:"password"`
}
// UpdateProxyRequest represents update proxy request
type UpdateProxyRequest struct {
Name string `json:"name"`
Protocol string `json:"protocol" binding:"omitempty,oneof=http https socks5 socks5h"`
Host string `json:"host"`
Port int `json:"port" binding:"omitempty,min=1,max=65535"`
Username string `json:"username"`
Password string `json:"password"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
}
// List handles listing all proxies with pagination
// GET /api/v1/admin/proxies
func (h *ProxyHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
protocol := c.Query("protocol")
status := c.Query("status")
search := c.Query("search")
// 标准化和验证 search 参数
search = strings.TrimSpace(search)
if len(search) > 100 {
search = search[:100]
}
proxies, total, err := h.adminService.ListProxiesWithAccountCount(c.Request.Context(), page, pageSize, protocol, status, search)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.AdminProxyWithAccountCount, 0, len(proxies))
for i := range proxies {
out = append(out, *dto.ProxyWithAccountCountFromServiceAdmin(&proxies[i]))
}
response.Paginated(c, out, total, page, pageSize)
}
// GetAll handles getting all active proxies without pagination
// GET /api/v1/admin/proxies/all
// Optional query param: with_count=true to include account count per proxy
func (h *ProxyHandler) GetAll(c *gin.Context) {
withCount := c.Query("with_count") == "true"
if withCount {
proxies, err := h.adminService.GetAllProxiesWithAccountCount(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.AdminProxyWithAccountCount, 0, len(proxies))
for i := range proxies {
out = append(out, *dto.ProxyWithAccountCountFromServiceAdmin(&proxies[i]))
}
response.Success(c, out)
return
}
proxies, err := h.adminService.GetAllProxies(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.AdminProxy, 0, len(proxies))
for i := range proxies {
out = append(out, *dto.ProxyFromServiceAdmin(&proxies[i]))
}
response.Success(c, out)
}
// GetByID handles getting a proxy by ID
// GET /api/v1/admin/proxies/:id
func (h *ProxyHandler) GetByID(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid proxy ID")
return
}
proxy, err := h.adminService.GetProxy(c.Request.Context(), proxyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.ProxyFromServiceAdmin(proxy))
}
// Create handles creating a new proxy
// POST /api/v1/admin/proxies
func (h *ProxyHandler) Create(c *gin.Context) {
var req CreateProxyRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
executeAdminIdempotentJSON(c, "admin.proxies.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
proxy, err := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{
Name: strings.TrimSpace(req.Name),
Protocol: strings.TrimSpace(req.Protocol),
Host: strings.TrimSpace(req.Host),
Port: req.Port,
Username: strings.TrimSpace(req.Username),
Password: strings.TrimSpace(req.Password),
})
if err != nil {
return nil, err
}
return dto.ProxyFromServiceAdmin(proxy), nil
})
}
// Update handles updating a proxy
// PUT /api/v1/admin/proxies/:id
func (h *ProxyHandler) Update(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid proxy ID")
return
}
var req UpdateProxyRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
proxy, err := h.adminService.UpdateProxy(c.Request.Context(), proxyID, &service.UpdateProxyInput{
Name: strings.TrimSpace(req.Name),
Protocol: strings.TrimSpace(req.Protocol),
Host: strings.TrimSpace(req.Host),
Port: req.Port,
Username: strings.TrimSpace(req.Username),
Password: strings.TrimSpace(req.Password),
Status: strings.TrimSpace(req.Status),
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.ProxyFromServiceAdmin(proxy))
}
// Delete handles deleting a proxy
// DELETE /api/v1/admin/proxies/:id
func (h *ProxyHandler) Delete(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid proxy ID")
return
}
err = h.adminService.DeleteProxy(c.Request.Context(), proxyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Proxy deleted successfully"})
}
// BatchDelete handles batch deleting proxies
// POST /api/v1/admin/proxies/batch-delete
func (h *ProxyHandler) BatchDelete(c *gin.Context) {
type BatchDeleteRequest struct {
IDs []int64 `json:"ids" binding:"required,min=1"`
}
var req BatchDeleteRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
result, err := h.adminService.BatchDeleteProxies(c.Request.Context(), req.IDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// Test handles testing proxy connectivity
// POST /api/v1/admin/proxies/:id/test
func (h *ProxyHandler) Test(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid proxy ID")
return
}
result, err := h.adminService.TestProxy(c.Request.Context(), proxyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// CheckQuality handles checking proxy quality across common AI targets.
// POST /api/v1/admin/proxies/:id/quality-check
func (h *ProxyHandler) CheckQuality(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid proxy ID")
return
}
result, err := h.adminService.CheckProxyQuality(c.Request.Context(), proxyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// GetStats handles getting proxy statistics
// GET /api/v1/admin/proxies/:id/stats
func (h *ProxyHandler) GetStats(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid proxy ID")
return
}
// Return mock data for now
_ = proxyID
response.Success(c, gin.H{
"total_accounts": 0,
"active_accounts": 0,
"total_requests": 0,
"success_rate": 100.0,
"average_latency": 0,
})
}
// GetProxyAccounts handles getting accounts using a proxy
// GET /api/v1/admin/proxies/:id/accounts
func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid proxy ID")
return
}
accounts, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.ProxyAccountSummary, 0, len(accounts))
for i := range accounts {
out = append(out, *dto.ProxyAccountSummaryFromService(&accounts[i]))
}
response.Success(c, out)
}
// BatchCreateProxyItem represents a single proxy in batch create request
type BatchCreateProxyItem struct {
Protocol string `json:"protocol" binding:"required,oneof=http https socks5 socks5h"`
Host string `json:"host" binding:"required"`
Port int `json:"port" binding:"required,min=1,max=65535"`
Username string `json:"username"`
Password string `json:"password"`
}
// BatchCreateRequest represents batch create proxies request
type BatchCreateRequest struct {
Proxies []BatchCreateProxyItem `json:"proxies" binding:"required,min=1"`
}
// BatchCreate handles batch creating proxies
// POST /api/v1/admin/proxies/batch
func (h *ProxyHandler) BatchCreate(c *gin.Context) {
var req BatchCreateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
created := 0
skipped := 0
for _, item := range req.Proxies {
// Trim all string fields
host := strings.TrimSpace(item.Host)
protocol := strings.TrimSpace(item.Protocol)
username := strings.TrimSpace(item.Username)
password := strings.TrimSpace(item.Password)
// Check for duplicates (same host, port, username, password)
exists, err := h.adminService.CheckProxyExists(c.Request.Context(), host, item.Port, username, password)
if err != nil {
response.ErrorFrom(c, err)
return
}
if exists {
skipped++
continue
}
// Create proxy with default name
_, err = h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
Name: "default",
Protocol: protocol,
Host: host,
Port: item.Port,
Username: username,
Password: password,
})
if err != nil {
// If creation fails due to duplicate, count as skipped
skipped++
continue
}
created++
}
response.Success(c, gin.H{
"created": created,
"skipped": skipped,
})
}

View File

@@ -0,0 +1,360 @@
package admin
import (
"bytes"
"context"
"encoding/csv"
"errors"
"fmt"
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// RedeemHandler handles admin redeem code management
type RedeemHandler struct {
adminService service.AdminService
redeemService *service.RedeemService
}
// NewRedeemHandler creates a new admin redeem handler
func NewRedeemHandler(adminService service.AdminService, redeemService *service.RedeemService) *RedeemHandler {
return &RedeemHandler{
adminService: adminService,
redeemService: redeemService,
}
}
// GenerateRedeemCodesRequest represents generate redeem codes request
type GenerateRedeemCodesRequest struct {
Count int `json:"count" binding:"required,min=1,max=100"`
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
Value float64 `json:"value" binding:"min=0"`
GroupID *int64 `json:"group_id"` // 订阅类型必填
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // 订阅类型使用默认30天最大100年
}
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
// Type 为 omitempty 而非 required 是为了向后兼容旧版调用方(不传 type 时默认 balance
type CreateAndRedeemCodeRequest struct {
Code string `json:"code" binding:"required,min=3,max=128"`
Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance向后兼容
Value float64 `json:"value" binding:"required,gt=0"`
UserID int64 `json:"user_id" binding:"required,gt=0"`
GroupID *int64 `json:"group_id"` // subscription 类型必填
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // subscription 类型必填,>0
Notes string `json:"notes"`
}
// List handles listing all redeem codes with pagination
// GET /api/v1/admin/redeem-codes
func (h *RedeemHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
codeType := c.Query("type")
status := c.Query("status")
search := c.Query("search")
// 标准化和验证 search 参数
search = strings.TrimSpace(search)
if len(search) > 100 {
search = search[:100]
}
codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.AdminRedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
}
response.Paginated(c, out, total, page, pageSize)
}
// GetByID handles getting a redeem code by ID
// GET /api/v1/admin/redeem-codes/:id
func (h *RedeemHandler) GetByID(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid redeem code ID")
return
}
code, err := h.adminService.GetRedeemCode(c.Request.Context(), codeID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.RedeemCodeFromServiceAdmin(code))
}
// Generate handles generating new redeem codes
// POST /api/v1/admin/redeem-codes/generate
func (h *RedeemHandler) Generate(c *gin.Context) {
var req GenerateRedeemCodesRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
executeAdminIdempotentJSON(c, "admin.redeem_codes.generate", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
codes, execErr := h.adminService.GenerateRedeemCodes(ctx, &service.GenerateRedeemCodesInput{
Count: req.Count,
Type: req.Type,
Value: req.Value,
GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
})
if execErr != nil {
return nil, execErr
}
out := make([]dto.AdminRedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
}
return out, nil
})
}
// CreateAndRedeem creates a fixed redeem code and redeems it for a target user in one step.
// POST /api/v1/admin/redeem-codes/create-and-redeem
func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
if h.redeemService == nil {
response.InternalError(c, "redeem service not configured")
return
}
var req CreateAndRedeemCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
req.Code = strings.TrimSpace(req.Code)
// 向后兼容:旧版调用方(如 Sub2ApiPay不传 type 字段,默认当作 balance 充值处理。
// 请勿删除此默认值逻辑,否则会导致旧版调用方 400 报错。
if req.Type == "" {
req.Type = "balance"
}
if req.Type == "subscription" {
if req.GroupID == nil {
response.BadRequest(c, "group_id is required for subscription type")
return
}
if req.ValidityDays <= 0 {
response.BadRequest(c, "validity_days must be greater than 0 for subscription type")
return
}
}
executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
existing, err := h.redeemService.GetByCode(ctx, req.Code)
if err == nil {
return h.resolveCreateAndRedeemExisting(ctx, existing, req.UserID)
}
if !errors.Is(err, service.ErrRedeemCodeNotFound) {
return nil, err
}
createErr := h.redeemService.CreateCode(ctx, &service.RedeemCode{
Code: req.Code,
Type: req.Type,
Value: req.Value,
Status: service.StatusUnused,
Notes: req.Notes,
GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
})
if createErr != nil {
// Unique code race: if code now exists, use idempotent semantics by used_by.
existingAfterCreateErr, getErr := h.redeemService.GetByCode(ctx, req.Code)
if getErr == nil {
return h.resolveCreateAndRedeemExisting(ctx, existingAfterCreateErr, req.UserID)
}
return nil, createErr
}
redeemed, redeemErr := h.redeemService.Redeem(ctx, req.UserID, req.Code)
if redeemErr != nil {
return nil, redeemErr
}
return gin.H{"redeem_code": dto.RedeemCodeFromServiceAdmin(redeemed)}, nil
})
}
func (h *RedeemHandler) resolveCreateAndRedeemExisting(ctx context.Context, existing *service.RedeemCode, userID int64) (any, error) {
if existing == nil {
return nil, infraerrors.Conflict("REDEEM_CODE_CONFLICT", "redeem code conflict")
}
// If previous run created the code but crashed before redeem, redeem it now.
if existing.CanUse() {
redeemed, err := h.redeemService.Redeem(ctx, userID, existing.Code)
if err == nil {
return gin.H{"redeem_code": dto.RedeemCodeFromServiceAdmin(redeemed)}, nil
}
if !errors.Is(err, service.ErrRedeemCodeUsed) {
return nil, err
}
latest, getErr := h.redeemService.GetByCode(ctx, existing.Code)
if getErr == nil {
existing = latest
}
}
if existing.UsedBy != nil && *existing.UsedBy == userID {
return gin.H{"redeem_code": dto.RedeemCodeFromServiceAdmin(existing)}, nil
}
return nil, infraerrors.Conflict("REDEEM_CODE_CONFLICT", "redeem code already used by another user")
}
// Delete handles deleting a redeem code
// DELETE /api/v1/admin/redeem-codes/:id
func (h *RedeemHandler) Delete(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid redeem code ID")
return
}
err = h.adminService.DeleteRedeemCode(c.Request.Context(), codeID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Redeem code deleted successfully"})
}
// BatchDelete handles batch deleting redeem codes
// POST /api/v1/admin/redeem-codes/batch-delete
func (h *RedeemHandler) BatchDelete(c *gin.Context) {
var req struct {
IDs []int64 `json:"ids" binding:"required,min=1"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
deleted, err := h.adminService.BatchDeleteRedeemCodes(c.Request.Context(), req.IDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{
"deleted": deleted,
"message": "Redeem codes deleted successfully",
})
}
// Expire handles expiring a redeem code
// POST /api/v1/admin/redeem-codes/:id/expire
func (h *RedeemHandler) Expire(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid redeem code ID")
return
}
code, err := h.adminService.ExpireRedeemCode(c.Request.Context(), codeID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.RedeemCodeFromServiceAdmin(code))
}
// GetStats handles getting redeem code statistics
// GET /api/v1/admin/redeem-codes/stats
func (h *RedeemHandler) GetStats(c *gin.Context) {
// Return mock data for now
response.Success(c, gin.H{
"total_codes": 0,
"active_codes": 0,
"used_codes": 0,
"expired_codes": 0,
"total_value_distributed": 0.0,
"by_type": gin.H{
"balance": 0,
"concurrency": 0,
"trial": 0,
},
})
}
// Export handles exporting redeem codes to CSV
// GET /api/v1/admin/redeem-codes/export
func (h *RedeemHandler) Export(c *gin.Context) {
codeType := c.Query("type")
status := c.Query("status")
// Get all codes without pagination (use large page size)
codes, _, err := h.adminService.ListRedeemCodes(c.Request.Context(), 1, 10000, codeType, status, "")
if err != nil {
response.ErrorFrom(c, err)
return
}
// Create CSV buffer
var buf bytes.Buffer
writer := csv.NewWriter(&buf)
// Write header
if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_by_email", "used_at", "created_at"}); err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
return
}
// Write data rows
for _, code := range codes {
usedBy := ""
if code.UsedBy != nil {
usedBy = fmt.Sprintf("%d", *code.UsedBy)
}
usedByEmail := ""
if code.User != nil {
usedByEmail = code.User.Email
}
usedAt := ""
if code.UsedAt != nil {
usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
}
if err := writer.Write([]string{
fmt.Sprintf("%d", code.ID),
code.Code,
code.Type,
fmt.Sprintf("%.2f", code.Value),
code.Status,
usedBy,
usedByEmail,
usedAt,
code.CreatedAt.Format("2006-01-02 15:04:05"),
}); err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
return
}
}
writer.Flush()
if err := writer.Error(); err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
return
}
c.Header("Content-Type", "text/csv")
c.Header("Content-Disposition", "attachment; filename=redeem_codes.csv")
c.Data(200, "text/csv", buf.Bytes())
}

View File

@@ -0,0 +1,135 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newCreateAndRedeemHandler creates a RedeemHandler with a non-nil (but minimal)
// RedeemService so that CreateAndRedeem's nil guard passes and we can test the
// parameter-validation layer that runs before any service call.
func newCreateAndRedeemHandler() *RedeemHandler {
return &RedeemHandler{
adminService: newStubAdminService(),
redeemService: &service.RedeemService{}, // non-nil to pass nil guard
}
}
// postCreateAndRedeemValidation calls CreateAndRedeem and returns the response
// status code. For cases that pass validation and proceed into the service layer,
// a panic may occur (because RedeemService internals are nil); this is expected
// and treated as "validation passed" (returns 0 to indicate panic).
func postCreateAndRedeemValidation(t *testing.T, handler *RedeemHandler, body any) (code int) {
t.Helper()
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
jsonBytes, err := json.Marshal(body)
require.NoError(t, err)
c.Request, _ = http.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/create-and-redeem", bytes.NewReader(jsonBytes))
c.Request.Header.Set("Content-Type", "application/json")
defer func() {
if r := recover(); r != nil {
// Panic means we passed validation and entered service layer (expected for minimal stub).
code = 0
}
}()
handler.CreateAndRedeem(c)
return w.Code
}
func TestCreateAndRedeem_TypeDefaultsToBalance(t *testing.T) {
// 不传 type 字段时应默认 balance不触发 subscription 校验。
// 验证通过后进入 service 层会 panic返回 0说明默认值生效。
h := newCreateAndRedeemHandler()
code := postCreateAndRedeemValidation(t, h, map[string]any{
"code": "test-balance-default",
"value": 10.0,
"user_id": 1,
})
assert.NotEqual(t, http.StatusBadRequest, code,
"omitting type should default to balance and pass validation")
}
func TestCreateAndRedeem_SubscriptionRequiresGroupID(t *testing.T) {
h := newCreateAndRedeemHandler()
code := postCreateAndRedeemValidation(t, h, map[string]any{
"code": "test-sub-no-group",
"type": "subscription",
"value": 29.9,
"user_id": 1,
"validity_days": 30,
// group_id 缺失
})
assert.Equal(t, http.StatusBadRequest, code)
}
func TestCreateAndRedeem_SubscriptionRequiresPositiveValidityDays(t *testing.T) {
groupID := int64(5)
h := newCreateAndRedeemHandler()
cases := []struct {
name string
validityDays int
}{
{"zero", 0},
{"negative", -1},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
code := postCreateAndRedeemValidation(t, h, map[string]any{
"code": "test-sub-bad-days-" + tc.name,
"type": "subscription",
"value": 29.9,
"user_id": 1,
"group_id": groupID,
"validity_days": tc.validityDays,
})
assert.Equal(t, http.StatusBadRequest, code)
})
}
}
func TestCreateAndRedeem_SubscriptionValidParamsPassValidation(t *testing.T) {
groupID := int64(5)
h := newCreateAndRedeemHandler()
code := postCreateAndRedeemValidation(t, h, map[string]any{
"code": "test-sub-valid",
"type": "subscription",
"value": 29.9,
"user_id": 1,
"group_id": groupID,
"validity_days": 31,
})
assert.NotEqual(t, http.StatusBadRequest, code,
"valid subscription params should pass validation")
}
func TestCreateAndRedeem_BalanceIgnoresSubscriptionFields(t *testing.T) {
h := newCreateAndRedeemHandler()
// balance 类型不传 group_id 和 validity_days不应报 400
code := postCreateAndRedeemValidation(t, h, map[string]any{
"code": "test-balance-no-extras",
"type": "balance",
"value": 50.0,
"user_id": 1,
})
assert.NotEqual(t, http.StatusBadRequest, code,
"balance type should not require group_id or validity_days")
}

View File

@@ -0,0 +1,163 @@
package admin
import (
"net/http"
"strconv"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ScheduledTestHandler handles admin scheduled-test-plan management.
type ScheduledTestHandler struct {
scheduledTestSvc *service.ScheduledTestService
}
// NewScheduledTestHandler creates a new ScheduledTestHandler.
func NewScheduledTestHandler(scheduledTestSvc *service.ScheduledTestService) *ScheduledTestHandler {
return &ScheduledTestHandler{scheduledTestSvc: scheduledTestSvc}
}
type createScheduledTestPlanRequest struct {
AccountID int64 `json:"account_id" binding:"required"`
ModelID string `json:"model_id"`
CronExpression string `json:"cron_expression" binding:"required"`
Enabled *bool `json:"enabled"`
MaxResults int `json:"max_results"`
AutoRecover *bool `json:"auto_recover"`
}
type updateScheduledTestPlanRequest struct {
ModelID string `json:"model_id"`
CronExpression string `json:"cron_expression"`
Enabled *bool `json:"enabled"`
MaxResults int `json:"max_results"`
AutoRecover *bool `json:"auto_recover"`
}
// ListByAccount GET /admin/accounts/:id/scheduled-test-plans
func (h *ScheduledTestHandler) ListByAccount(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "invalid account id")
return
}
plans, err := h.scheduledTestSvc.ListPlansByAccount(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, err.Error())
return
}
c.JSON(http.StatusOK, plans)
}
// Create POST /admin/scheduled-test-plans
func (h *ScheduledTestHandler) Create(c *gin.Context) {
var req createScheduledTestPlanRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
plan := &service.ScheduledTestPlan{
AccountID: req.AccountID,
ModelID: req.ModelID,
CronExpression: req.CronExpression,
Enabled: true,
MaxResults: req.MaxResults,
}
if req.Enabled != nil {
plan.Enabled = *req.Enabled
}
if req.AutoRecover != nil {
plan.AutoRecover = *req.AutoRecover
}
created, err := h.scheduledTestSvc.CreatePlan(c.Request.Context(), plan)
if err != nil {
response.BadRequest(c, err.Error())
return
}
c.JSON(http.StatusOK, created)
}
// Update PUT /admin/scheduled-test-plans/:id
func (h *ScheduledTestHandler) Update(c *gin.Context) {
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "invalid plan id")
return
}
existing, err := h.scheduledTestSvc.GetPlan(c.Request.Context(), planID)
if err != nil {
response.NotFound(c, "plan not found")
return
}
var req updateScheduledTestPlanRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
if req.ModelID != "" {
existing.ModelID = req.ModelID
}
if req.CronExpression != "" {
existing.CronExpression = req.CronExpression
}
if req.Enabled != nil {
existing.Enabled = *req.Enabled
}
if req.MaxResults > 0 {
existing.MaxResults = req.MaxResults
}
if req.AutoRecover != nil {
existing.AutoRecover = *req.AutoRecover
}
updated, err := h.scheduledTestSvc.UpdatePlan(c.Request.Context(), existing)
if err != nil {
response.BadRequest(c, err.Error())
return
}
c.JSON(http.StatusOK, updated)
}
// Delete DELETE /admin/scheduled-test-plans/:id
func (h *ScheduledTestHandler) Delete(c *gin.Context) {
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "invalid plan id")
return
}
if err := h.scheduledTestSvc.DeletePlan(c.Request.Context(), planID); err != nil {
response.InternalError(c, err.Error())
return
}
c.JSON(http.StatusOK, gin.H{"message": "deleted"})
}
// ListResults GET /admin/scheduled-test-plans/:id/results
func (h *ScheduledTestHandler) ListResults(c *gin.Context) {
planID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "invalid plan id")
return
}
limit := 50
if l, err := strconv.Atoi(c.Query("limit")); err == nil && l > 0 {
limit = l
}
results, err := h.scheduledTestSvc.ListResults(c.Request.Context(), planID, limit)
if err != nil {
response.InternalError(c, err.Error())
return
}
c.JSON(http.StatusOK, results)
}

View File

@@ -0,0 +1,97 @@
//go:build unit
package admin
import (
"testing"
"github.com/stretchr/testify/require"
)
// truncateSearchByRune 模拟 user_handler.go 中的 search 截断逻辑
func truncateSearchByRune(search string, maxRunes int) string {
if runes := []rune(search); len(runes) > maxRunes {
return string(runes[:maxRunes])
}
return search
}
func TestTruncateSearchByRune(t *testing.T) {
tests := []struct {
name string
input string
maxRunes int
wantLen int // 期望的 rune 长度
}{
{
name: "纯中文超长",
input: string(make([]rune, 150)),
maxRunes: 100,
wantLen: 100,
},
{
name: "纯 ASCII 超长",
input: string(make([]byte, 150)),
maxRunes: 100,
wantLen: 100,
},
{
name: "空字符串",
input: "",
maxRunes: 100,
wantLen: 0,
},
{
name: "恰好 100 个字符",
input: string(make([]rune, 100)),
maxRunes: 100,
wantLen: 100,
},
{
name: "不足 100 字符不截断",
input: "hello世界",
maxRunes: 100,
wantLen: 7,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := truncateSearchByRune(tc.input, tc.maxRunes)
require.Equal(t, tc.wantLen, len([]rune(result)))
})
}
}
func TestTruncateSearchByRune_PreservesMultibyte(t *testing.T) {
// 101 个中文字符,截断到 100 个后应该仍然是有效 UTF-8
input := ""
for i := 0; i < 101; i++ {
input += "中"
}
result := truncateSearchByRune(input, 100)
require.Equal(t, 100, len([]rune(result)))
// 验证截断结果是有效的 UTF-8每个中文字符 3 字节)
require.Equal(t, 300, len(result))
}
func TestTruncateSearchByRune_MixedASCIIAndMultibyte(t *testing.T) {
// 50 个 ASCII + 51 个中文 = 101 个 rune
input := ""
for i := 0; i < 50; i++ {
input += "a"
}
for i := 0; i < 51; i++ {
input += "中"
}
result := truncateSearchByRune(input, 100)
runes := []rune(result)
require.Equal(t, 100, len(runes))
// 前 50 个应该是 'a',后 50 个应该是 '中'
require.Equal(t, 'a', runes[0])
require.Equal(t, 'a', runes[49])
require.Equal(t, '中', runes[50])
require.Equal(t, '中', runes[99])
}

View File

@@ -0,0 +1,138 @@
package admin
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"strings"
"sync"
"time"
"golang.org/x/sync/singleflight"
)
type snapshotCacheEntry struct {
ETag string
Payload any
ExpiresAt time.Time
}
type snapshotCache struct {
mu sync.RWMutex
ttl time.Duration
items map[string]snapshotCacheEntry
sf singleflight.Group
}
type snapshotCacheLoadResult struct {
Entry snapshotCacheEntry
Hit bool
}
func newSnapshotCache(ttl time.Duration) *snapshotCache {
if ttl <= 0 {
ttl = 30 * time.Second
}
return &snapshotCache{
ttl: ttl,
items: make(map[string]snapshotCacheEntry),
}
}
func (c *snapshotCache) Get(key string) (snapshotCacheEntry, bool) {
if c == nil || key == "" {
return snapshotCacheEntry{}, false
}
now := time.Now()
c.mu.RLock()
entry, ok := c.items[key]
c.mu.RUnlock()
if !ok {
return snapshotCacheEntry{}, false
}
if now.After(entry.ExpiresAt) {
c.mu.Lock()
delete(c.items, key)
c.mu.Unlock()
return snapshotCacheEntry{}, false
}
return entry, true
}
func (c *snapshotCache) Set(key string, payload any) snapshotCacheEntry {
if c == nil {
return snapshotCacheEntry{}
}
entry := snapshotCacheEntry{
ETag: buildETagFromAny(payload),
Payload: payload,
ExpiresAt: time.Now().Add(c.ttl),
}
if key == "" {
return entry
}
c.mu.Lock()
c.items[key] = entry
c.mu.Unlock()
return entry
}
func (c *snapshotCache) GetOrLoad(key string, load func() (any, error)) (snapshotCacheEntry, bool, error) {
if load == nil {
return snapshotCacheEntry{}, false, nil
}
if entry, ok := c.Get(key); ok {
return entry, true, nil
}
if c == nil || key == "" {
payload, err := load()
if err != nil {
return snapshotCacheEntry{}, false, err
}
return c.Set(key, payload), false, nil
}
value, err, _ := c.sf.Do(key, func() (any, error) {
if entry, ok := c.Get(key); ok {
return snapshotCacheLoadResult{Entry: entry, Hit: true}, nil
}
payload, err := load()
if err != nil {
return nil, err
}
return snapshotCacheLoadResult{Entry: c.Set(key, payload), Hit: false}, nil
})
if err != nil {
return snapshotCacheEntry{}, false, err
}
result, ok := value.(snapshotCacheLoadResult)
if !ok {
return snapshotCacheEntry{}, false, nil
}
return result.Entry, result.Hit, nil
}
func buildETagFromAny(payload any) string {
raw, err := json.Marshal(payload)
if err != nil {
return ""
}
sum := sha256.Sum256(raw)
return "\"" + hex.EncodeToString(sum[:]) + "\""
}
func parseBoolQueryWithDefault(raw string, def bool) bool {
value := strings.TrimSpace(strings.ToLower(raw))
if value == "" {
return def
}
switch value {
case "1", "true", "yes", "on":
return true
case "0", "false", "no", "off":
return false
default:
return def
}
}

View File

@@ -0,0 +1,185 @@
//go:build unit
package admin
import (
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestSnapshotCache_SetAndGet(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
entry := c.Set("key1", map[string]string{"hello": "world"})
require.NotEmpty(t, entry.ETag)
require.NotNil(t, entry.Payload)
got, ok := c.Get("key1")
require.True(t, ok)
require.Equal(t, entry.ETag, got.ETag)
}
func TestSnapshotCache_Expiration(t *testing.T) {
c := newSnapshotCache(1 * time.Millisecond)
c.Set("key1", "value")
time.Sleep(5 * time.Millisecond)
_, ok := c.Get("key1")
require.False(t, ok, "expired entry should not be returned")
}
func TestSnapshotCache_GetEmptyKey(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
_, ok := c.Get("")
require.False(t, ok)
}
func TestSnapshotCache_GetMiss(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
_, ok := c.Get("nonexistent")
require.False(t, ok)
}
func TestSnapshotCache_NilReceiver(t *testing.T) {
var c *snapshotCache
_, ok := c.Get("key")
require.False(t, ok)
entry := c.Set("key", "value")
require.Empty(t, entry.ETag)
}
func TestSnapshotCache_SetEmptyKey(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
// Set with empty key should return entry but not store it
entry := c.Set("", "value")
require.NotEmpty(t, entry.ETag)
_, ok := c.Get("")
require.False(t, ok)
}
func TestSnapshotCache_DefaultTTL(t *testing.T) {
c := newSnapshotCache(0)
require.Equal(t, 30*time.Second, c.ttl)
c2 := newSnapshotCache(-1 * time.Second)
require.Equal(t, 30*time.Second, c2.ttl)
}
func TestSnapshotCache_ETagDeterministic(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
payload := map[string]int{"a": 1, "b": 2}
entry1 := c.Set("k1", payload)
entry2 := c.Set("k2", payload)
require.Equal(t, entry1.ETag, entry2.ETag, "same payload should produce same ETag")
}
func TestSnapshotCache_ETagFormat(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
entry := c.Set("k", "test")
// ETag should be quoted hex string: "abcdef..."
require.True(t, len(entry.ETag) > 2)
require.Equal(t, byte('"'), entry.ETag[0])
require.Equal(t, byte('"'), entry.ETag[len(entry.ETag)-1])
}
func TestBuildETagFromAny_UnmarshalablePayload(t *testing.T) {
// channels are not JSON-serializable
etag := buildETagFromAny(make(chan int))
require.Empty(t, etag)
}
func TestSnapshotCache_GetOrLoad_MissThenHit(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
var loads atomic.Int32
entry, hit, err := c.GetOrLoad("key1", func() (any, error) {
loads.Add(1)
return map[string]string{"hello": "world"}, nil
})
require.NoError(t, err)
require.False(t, hit)
require.NotEmpty(t, entry.ETag)
require.Equal(t, int32(1), loads.Load())
entry2, hit, err := c.GetOrLoad("key1", func() (any, error) {
loads.Add(1)
return map[string]string{"unexpected": "value"}, nil
})
require.NoError(t, err)
require.True(t, hit)
require.Equal(t, entry.ETag, entry2.ETag)
require.Equal(t, int32(1), loads.Load())
}
func TestSnapshotCache_GetOrLoad_ConcurrentSingleflight(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
var loads atomic.Int32
start := make(chan struct{})
const callers = 8
errCh := make(chan error, callers)
var wg sync.WaitGroup
wg.Add(callers)
for range callers {
go func() {
defer wg.Done()
<-start
_, _, err := c.GetOrLoad("shared", func() (any, error) {
loads.Add(1)
time.Sleep(20 * time.Millisecond)
return "value", nil
})
errCh <- err
}()
}
close(start)
wg.Wait()
close(errCh)
for err := range errCh {
require.NoError(t, err)
}
require.Equal(t, int32(1), loads.Load())
}
func TestParseBoolQueryWithDefault(t *testing.T) {
tests := []struct {
name string
raw string
def bool
want bool
}{
{"empty returns default true", "", true, true},
{"empty returns default false", "", false, false},
{"1", "1", false, true},
{"true", "true", false, true},
{"TRUE", "TRUE", false, true},
{"yes", "yes", false, true},
{"on", "on", false, true},
{"0", "0", true, false},
{"false", "false", true, false},
{"FALSE", "FALSE", true, false},
{"no", "no", true, false},
{"off", "off", true, false},
{"whitespace trimmed", " true ", false, true},
{"unknown returns default true", "maybe", true, true},
{"unknown returns default false", "maybe", false, false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := parseBoolQueryWithDefault(tc.raw, tc.def)
require.Equal(t, tc.want, got)
})
}
}

View File

@@ -0,0 +1,322 @@
package admin
import (
"context"
"strconv"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// toResponsePagination converts pagination.PaginationResult to response.PaginationResult
func toResponsePagination(p *pagination.PaginationResult) *response.PaginationResult {
if p == nil {
return nil
}
return &response.PaginationResult{
Total: p.Total,
Page: p.Page,
PageSize: p.PageSize,
Pages: p.Pages,
}
}
// SubscriptionHandler handles admin subscription management
type SubscriptionHandler struct {
subscriptionService *service.SubscriptionService
}
// NewSubscriptionHandler creates a new admin subscription handler
func NewSubscriptionHandler(subscriptionService *service.SubscriptionService) *SubscriptionHandler {
return &SubscriptionHandler{
subscriptionService: subscriptionService,
}
}
// AssignSubscriptionRequest represents assign subscription request
type AssignSubscriptionRequest struct {
UserID int64 `json:"user_id" binding:"required"`
GroupID int64 `json:"group_id" binding:"required"`
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // max 100 years
Notes string `json:"notes"`
}
// BulkAssignSubscriptionRequest represents bulk assign subscription request
type BulkAssignSubscriptionRequest struct {
UserIDs []int64 `json:"user_ids" binding:"required,min=1"`
GroupID int64 `json:"group_id" binding:"required"`
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // max 100 years
Notes string `json:"notes"`
}
// AdjustSubscriptionRequest represents adjust subscription request (extend or shorten)
type AdjustSubscriptionRequest struct {
Days int `json:"days" binding:"required,min=-36500,max=36500"` // negative to shorten, positive to extend
}
// List handles listing all subscriptions with pagination and filters
// GET /api/v1/admin/subscriptions
func (h *SubscriptionHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
// Parse optional filters
var userID, groupID *int64
if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
userID = &id
}
}
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
groupID = &id
}
}
status := c.Query("status")
// Parse sorting parameters
sortBy := c.DefaultQuery("sort_by", "created_at")
sortOrder := c.DefaultQuery("sort_order", "desc")
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
}
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
}
// GetByID handles getting a subscription by ID
// GET /api/v1/admin/subscriptions/:id
func (h *SubscriptionHandler) GetByID(c *gin.Context) {
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid subscription ID")
return
}
subscription, err := h.subscriptionService.GetByID(c.Request.Context(), subscriptionID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
}
// GetProgress handles getting subscription usage progress
// GET /api/v1/admin/subscriptions/:id/progress
func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid subscription ID")
return
}
progress, err := h.subscriptionService.GetSubscriptionProgress(c.Request.Context(), subscriptionID)
if err != nil {
response.NotFound(c, "Subscription not found")
return
}
response.Success(c, progress)
}
// Assign handles assigning a subscription to a user
// POST /api/v1/admin/subscriptions/assign
func (h *SubscriptionHandler) Assign(c *gin.Context) {
var req AssignSubscriptionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Get admin user ID from context
adminID := getAdminIDFromContext(c)
subscription, err := h.subscriptionService.AssignSubscription(c.Request.Context(), &service.AssignSubscriptionInput{
UserID: req.UserID,
GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
AssignedBy: adminID,
Notes: req.Notes,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
}
// BulkAssign handles bulk assigning subscriptions to multiple users
// POST /api/v1/admin/subscriptions/bulk-assign
func (h *SubscriptionHandler) BulkAssign(c *gin.Context) {
var req BulkAssignSubscriptionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Get admin user ID from context
adminID := getAdminIDFromContext(c)
result, err := h.subscriptionService.BulkAssignSubscription(c.Request.Context(), &service.BulkAssignSubscriptionInput{
UserIDs: req.UserIDs,
GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
AssignedBy: adminID,
Notes: req.Notes,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.BulkAssignResultFromService(result))
}
// Extend handles adjusting a subscription (extend or shorten)
// POST /api/v1/admin/subscriptions/:id/extend
func (h *SubscriptionHandler) Extend(c *gin.Context) {
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid subscription ID")
return
}
var req AdjustSubscriptionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
idempotencyPayload := struct {
SubscriptionID int64 `json:"subscription_id"`
Body AdjustSubscriptionRequest `json:"body"`
}{
SubscriptionID: subscriptionID,
Body: req,
}
executeAdminIdempotentJSON(c, "admin.subscriptions.extend", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
subscription, execErr := h.subscriptionService.ExtendSubscription(ctx, subscriptionID, req.Days)
if execErr != nil {
return nil, execErr
}
return dto.UserSubscriptionFromServiceAdmin(subscription), nil
})
}
// ResetSubscriptionQuotaRequest represents the reset quota request
type ResetSubscriptionQuotaRequest struct {
Daily bool `json:"daily"`
Weekly bool `json:"weekly"`
Monthly bool `json:"monthly"`
}
// ResetQuota resets daily, weekly, and/or monthly usage for a subscription.
// POST /api/v1/admin/subscriptions/:id/reset-quota
func (h *SubscriptionHandler) ResetQuota(c *gin.Context) {
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid subscription ID")
return
}
var req ResetSubscriptionQuotaRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !req.Daily && !req.Weekly && !req.Monthly {
response.BadRequest(c, "At least one of 'daily', 'weekly', or 'monthly' must be true")
return
}
sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly, req.Monthly)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.UserSubscriptionFromServiceAdmin(sub))
}
// Revoke handles revoking a subscription
// DELETE /api/v1/admin/subscriptions/:id
func (h *SubscriptionHandler) Revoke(c *gin.Context) {
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid subscription ID")
return
}
err = h.subscriptionService.RevokeSubscription(c.Request.Context(), subscriptionID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Subscription revoked successfully"})
}
// ListByGroup handles listing subscriptions for a specific group
// GET /api/v1/admin/groups/:id/subscriptions
func (h *SubscriptionHandler) ListByGroup(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
page, pageSize := response.ParsePagination(c)
subscriptions, pagination, err := h.subscriptionService.ListGroupSubscriptions(c.Request.Context(), groupID, page, pageSize)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
}
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
}
// ListByUser handles listing subscriptions for a specific user
// GET /api/v1/admin/users/:id/subscriptions
func (h *SubscriptionHandler) ListByUser(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
}
response.Success(c, out)
}
// Helper function to get admin ID from context
func getAdminIDFromContext(c *gin.Context) int64 {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
return 0
}
return subject.UserID
}

View File

@@ -0,0 +1,177 @@
package admin
import (
"context"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/sysutil"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// SystemHandler handles system-related operations
type SystemHandler struct {
updateSvc *service.UpdateService
lockSvc *service.SystemOperationLockService
}
// NewSystemHandler creates a new SystemHandler
func NewSystemHandler(updateSvc *service.UpdateService, lockSvc *service.SystemOperationLockService) *SystemHandler {
return &SystemHandler{
updateSvc: updateSvc,
lockSvc: lockSvc,
}
}
// GetVersion returns the current version
// GET /api/v1/admin/system/version
func (h *SystemHandler) GetVersion(c *gin.Context) {
info, _ := h.updateSvc.CheckUpdate(c.Request.Context(), false)
response.Success(c, gin.H{
"version": info.CurrentVersion,
})
}
// CheckUpdates checks for available updates
// GET /api/v1/admin/system/check-updates
func (h *SystemHandler) CheckUpdates(c *gin.Context) {
force := c.Query("force") == "true"
info, err := h.updateSvc.CheckUpdate(c.Request.Context(), force)
if err != nil {
response.Error(c, http.StatusInternalServerError, err.Error())
return
}
response.Success(c, info)
}
// PerformUpdate downloads and applies the update
// POST /api/v1/admin/system/update
func (h *SystemHandler) PerformUpdate(c *gin.Context) {
operationID := buildSystemOperationID(c, "update")
payload := gin.H{"operation_id": operationID}
executeAdminIdempotentJSON(c, "admin.system.update", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
lock, release, err := h.acquireSystemLock(ctx, operationID)
if err != nil {
return nil, err
}
var releaseReason string
succeeded := false
defer func() {
release(releaseReason, succeeded)
}()
if err := h.updateSvc.PerformUpdate(ctx); err != nil {
releaseReason = "SYSTEM_UPDATE_FAILED"
return nil, err
}
succeeded = true
return gin.H{
"message": "Update completed. Please restart the service.",
"need_restart": true,
"operation_id": lock.OperationID(),
}, nil
})
}
// Rollback restores the previous version
// POST /api/v1/admin/system/rollback
func (h *SystemHandler) Rollback(c *gin.Context) {
operationID := buildSystemOperationID(c, "rollback")
payload := gin.H{"operation_id": operationID}
executeAdminIdempotentJSON(c, "admin.system.rollback", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
lock, release, err := h.acquireSystemLock(ctx, operationID)
if err != nil {
return nil, err
}
var releaseReason string
succeeded := false
defer func() {
release(releaseReason, succeeded)
}()
if err := h.updateSvc.Rollback(); err != nil {
releaseReason = "SYSTEM_ROLLBACK_FAILED"
return nil, err
}
succeeded = true
return gin.H{
"message": "Rollback completed. Please restart the service.",
"need_restart": true,
"operation_id": lock.OperationID(),
}, nil
})
}
// RestartService restarts the systemd service
// POST /api/v1/admin/system/restart
func (h *SystemHandler) RestartService(c *gin.Context) {
operationID := buildSystemOperationID(c, "restart")
payload := gin.H{"operation_id": operationID}
executeAdminIdempotentJSON(c, "admin.system.restart", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
lock, release, err := h.acquireSystemLock(ctx, operationID)
if err != nil {
return nil, err
}
succeeded := false
defer func() {
release("", succeeded)
}()
// Schedule service restart in background after sending response
// This ensures the client receives the success response before the service restarts
go func() {
// Wait a moment to ensure the response is sent
time.Sleep(500 * time.Millisecond)
sysutil.RestartServiceAsync()
}()
succeeded = true
return gin.H{
"message": "Service restart initiated",
"operation_id": lock.OperationID(),
}, nil
})
}
func (h *SystemHandler) acquireSystemLock(
ctx context.Context,
operationID string,
) (*service.SystemOperationLock, func(string, bool), error) {
if h.lockSvc == nil {
return nil, nil, service.ErrIdempotencyStoreUnavail
}
lock, err := h.lockSvc.Acquire(ctx, operationID)
if err != nil {
return nil, nil, err
}
release := func(reason string, succeeded bool) {
releaseCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_ = h.lockSvc.Release(releaseCtx, lock, succeeded, reason)
}
return lock, release, nil
}
func buildSystemOperationID(c *gin.Context, operation string) string {
key := strings.TrimSpace(c.GetHeader("Idempotency-Key"))
if key == "" {
return "sysop-" + operation + "-" + strconv.FormatInt(time.Now().UnixNano(), 36)
}
actorScope := "admin:0"
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
actorScope = "admin:" + strconv.FormatInt(subject.UserID, 10)
}
seed := operation + "|" + actorScope + "|" + c.FullPath() + "|" + key
hash := service.HashIdempotencyKey(seed)
if len(hash) > 24 {
hash = hash[:24]
}
return "sysop-" + hash
}

View File

@@ -0,0 +1,463 @@
package admin
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type cleanupRepoStub struct {
mu sync.Mutex
created []*service.UsageCleanupTask
listTasks []service.UsageCleanupTask
listResult *pagination.PaginationResult
listErr error
statusByID map[int64]string
}
func (s *cleanupRepoStub) CreateTask(ctx context.Context, task *service.UsageCleanupTask) error {
if task == nil {
return nil
}
s.mu.Lock()
defer s.mu.Unlock()
if task.ID == 0 {
task.ID = int64(len(s.created) + 1)
}
if task.CreatedAt.IsZero() {
task.CreatedAt = time.Now().UTC()
}
task.UpdatedAt = task.CreatedAt
clone := *task
s.created = append(s.created, &clone)
return nil
}
func (s *cleanupRepoStub) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) {
s.mu.Lock()
defer s.mu.Unlock()
return s.listTasks, s.listResult, s.listErr
}
func (s *cleanupRepoStub) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*service.UsageCleanupTask, error) {
return nil, nil
}
func (s *cleanupRepoStub) GetTaskStatus(ctx context.Context, taskID int64) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.statusByID == nil {
return "", sql.ErrNoRows
}
status, ok := s.statusByID[taskID]
if !ok {
return "", sql.ErrNoRows
}
return status, nil
}
func (s *cleanupRepoStub) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error {
return nil
}
func (s *cleanupRepoStub) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.statusByID == nil {
s.statusByID = map[int64]string{}
}
status := s.statusByID[taskID]
if status != service.UsageCleanupStatusPending && status != service.UsageCleanupStatusRunning {
return false, nil
}
s.statusByID[taskID] = service.UsageCleanupStatusCanceled
return true, nil
}
func (s *cleanupRepoStub) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error {
return nil
}
func (s *cleanupRepoStub) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
return nil
}
func (s *cleanupRepoStub) DeleteUsageLogsBatch(ctx context.Context, filters service.UsageCleanupFilters, limit int) (int64, error) {
return 0, nil
}
var _ service.UsageCleanupRepository = (*cleanupRepoStub)(nil)
func setupCleanupRouter(cleanupService *service.UsageCleanupService, userID int64) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
if userID > 0 {
router.Use(func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: userID})
c.Next()
})
}
handler := NewUsageHandler(nil, nil, nil, cleanupService)
router.POST("/api/v1/admin/usage/cleanup-tasks", handler.CreateCleanupTask)
router.GET("/api/v1/admin/usage/cleanup-tasks", handler.ListCleanupTasks)
router.POST("/api/v1/admin/usage/cleanup-tasks/:id/cancel", handler.CancelCleanupTask)
return router
}
func TestUsageHandlerCreateCleanupTaskUnauthorized(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 0)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString(`{}`))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusUnauthorized, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskUnavailable(t *testing.T) {
router := setupCleanupRouter(nil, 1)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString(`{}`))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusServiceUnavailable, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskBindError(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 88)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString("{bad-json"))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusBadRequest, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskMissingRange(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 88)
payload := map[string]any{
"start_date": "2024-01-01",
"timezone": "UTC",
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusBadRequest, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskInvalidDate(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 88)
payload := map[string]any{
"start_date": "2024-13-01",
"end_date": "2024-01-02",
"timezone": "UTC",
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusBadRequest, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskInvalidEndDate(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 88)
payload := map[string]any{
"start_date": "2024-01-01",
"end_date": "2024-02-40",
"timezone": "UTC",
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusBadRequest, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskInvalidRequestType(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 88)
payload := map[string]any{
"start_date": "2024-01-01",
"end_date": "2024-01-02",
"timezone": "UTC",
"request_type": "invalid",
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusBadRequest, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskRequestTypePriority(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 99)
payload := map[string]any{
"start_date": "2024-01-01",
"end_date": "2024-01-02",
"timezone": "UTC",
"request_type": "ws_v2",
"stream": false,
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.created, 1)
created := repo.created[0]
require.NotNil(t, created.Filters.RequestType)
require.Equal(t, int16(service.RequestTypeWSV2), *created.Filters.RequestType)
require.Nil(t, created.Filters.Stream)
}
func TestUsageHandlerCreateCleanupTaskWithLegacyStream(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 99)
payload := map[string]any{
"start_date": "2024-01-01",
"end_date": "2024-01-02",
"timezone": "UTC",
"stream": true,
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.created, 1)
created := repo.created[0]
require.Nil(t, created.Filters.RequestType)
require.NotNil(t, created.Filters.Stream)
require.True(t, *created.Filters.Stream)
}
func TestUsageHandlerCreateCleanupTaskSuccess(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 99)
payload := map[string]any{
"start_date": " 2024-01-01 ",
"end_date": "2024-01-02",
"timezone": "UTC",
"model": "gpt-4",
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
var resp response.Response
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.created, 1)
created := repo.created[0]
require.Equal(t, int64(99), created.CreatedBy)
require.NotNil(t, created.Filters.Model)
require.Equal(t, "gpt-4", *created.Filters.Model)
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC).Add(24*time.Hour - time.Nanosecond)
require.True(t, created.Filters.StartTime.Equal(start))
require.True(t, created.Filters.EndTime.Equal(end))
}
func TestUsageHandlerListCleanupTasksUnavailable(t *testing.T) {
router := setupCleanupRouter(nil, 0)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusServiceUnavailable, recorder.Code)
}
func TestUsageHandlerListCleanupTasksSuccess(t *testing.T) {
repo := &cleanupRepoStub{}
repo.listTasks = []service.UsageCleanupTask{
{
ID: 7,
Status: service.UsageCleanupStatusSucceeded,
CreatedBy: 4,
},
}
repo.listResult = &pagination.PaginationResult{Total: 1, Page: 1, PageSize: 20, Pages: 1}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 1)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
var resp struct {
Code int `json:"code"`
Data struct {
Items []dto.UsageCleanupTask `json:"items"`
Total int64 `json:"total"`
Page int `json:"page"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Len(t, resp.Data.Items, 1)
require.Equal(t, int64(7), resp.Data.Items[0].ID)
require.Equal(t, int64(1), resp.Data.Total)
require.Equal(t, 1, resp.Data.Page)
}
func TestUsageHandlerListCleanupTasksError(t *testing.T) {
repo := &cleanupRepoStub{listErr: errors.New("boom")}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 1)
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusInternalServerError, recorder.Code)
}
func TestUsageHandlerCancelCleanupTaskUnauthorized(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 0)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/1/cancel", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusUnauthorized, rec.Code)
}
func TestUsageHandlerCancelCleanupTaskNotFound(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 1)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/999/cancel", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestUsageHandlerCancelCleanupTaskConflict(t *testing.T) {
repo := &cleanupRepoStub{statusByID: map[int64]string{2: service.UsageCleanupStatusSucceeded}}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 1)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/2/cancel", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusConflict, rec.Code)
}
func TestUsageHandlerCancelCleanupTaskSuccess(t *testing.T) {
repo := &cleanupRepoStub{statusByID: map[int64]string{3: service.UsageCleanupStatusPending}}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 1)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/3/cancel", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}

View File

@@ -0,0 +1,585 @@
package admin
import (
"context"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// UsageHandler handles admin usage-related requests
type UsageHandler struct {
usageService *service.UsageService
apiKeyService *service.APIKeyService
adminService service.AdminService
cleanupService *service.UsageCleanupService
}
// NewUsageHandler creates a new admin usage handler
func NewUsageHandler(
usageService *service.UsageService,
apiKeyService *service.APIKeyService,
adminService service.AdminService,
cleanupService *service.UsageCleanupService,
) *UsageHandler {
return &UsageHandler{
usageService: usageService,
apiKeyService: apiKeyService,
adminService: adminService,
cleanupService: cleanupService,
}
}
// CreateUsageCleanupTaskRequest represents cleanup task creation request
type CreateUsageCleanupTaskRequest struct {
StartDate string `json:"start_date"`
EndDate string `json:"end_date"`
UserID *int64 `json:"user_id"`
APIKeyID *int64 `json:"api_key_id"`
AccountID *int64 `json:"account_id"`
GroupID *int64 `json:"group_id"`
Model *string `json:"model"`
RequestType *string `json:"request_type"`
Stream *bool `json:"stream"`
BillingType *int8 `json:"billing_type"`
Timezone string `json:"timezone"`
}
// List handles listing all usage records with filters
// GET /api/v1/admin/usage
func (h *UsageHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
exactTotal := false
if exactTotalRaw := strings.TrimSpace(c.Query("exact_total")); exactTotalRaw != "" {
parsed, err := strconv.ParseBool(exactTotalRaw)
if err != nil {
response.BadRequest(c, "Invalid exact_total value, use true or false")
return
}
exactTotal = parsed
}
// Parse filters
var userID, apiKeyID, accountID, groupID int64
if userIDStr := c.Query("user_id"); userIDStr != "" {
id, err := strconv.ParseInt(userIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user_id")
return
}
userID = id
}
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid api_key_id")
return
}
apiKeyID = id
}
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
id, err := strconv.ParseInt(accountIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account_id")
return
}
accountID = id
}
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
id, err := strconv.ParseInt(groupIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group_id")
return
}
groupID = id
}
model := c.Query("model")
var requestType *int16
var stream *bool
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
response.BadRequest(c, err.Error())
return
}
value := int16(parsed)
requestType = &value
} else if streamStr := c.Query("stream"); streamStr != "" {
val, err := strconv.ParseBool(streamStr)
if err != nil {
response.BadRequest(c, "Invalid stream value, use true or false")
return
}
stream = &val
}
var billingType *int8
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
val, err := strconv.ParseInt(billingTypeStr, 10, 8)
if err != nil {
response.BadRequest(c, "Invalid billing_type")
return
}
bt := int8(val)
billingType = &bt
}
// Parse date range
var startTime, endTime *time.Time
userTZ := c.Query("timezone") // Get user's timezone from request
if startDateStr := c.Query("start_date"); startDateStr != "" {
t, err := timezone.ParseInUserLocation("2006-01-02", startDateStr, userTZ)
if err != nil {
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
return
}
startTime = &t
}
if endDateStr := c.Query("end_date"); endDateStr != "" {
t, err := timezone.ParseInUserLocation("2006-01-02", endDateStr, userTZ)
if err != nil {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
}
// Use half-open range [start, end), move to next calendar day start (DST-safe).
t = t.AddDate(0, 0, 1)
endTime = &t
}
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
filters := usagestats.UsageLogFilters{
UserID: userID,
APIKeyID: apiKeyID,
AccountID: accountID,
GroupID: groupID,
Model: model,
RequestType: requestType,
Stream: stream,
BillingType: billingType,
StartTime: startTime,
EndTime: endTime,
ExactTotal: exactTotal,
}
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.AdminUsageLog, 0, len(records))
for i := range records {
out = append(out, *dto.UsageLogFromServiceAdmin(&records[i]))
}
response.Paginated(c, out, result.Total, page, pageSize)
}
// Stats handles getting usage statistics with filters
// GET /api/v1/admin/usage/stats
func (h *UsageHandler) Stats(c *gin.Context) {
// Parse filters - same as List endpoint
var userID, apiKeyID, accountID, groupID int64
if userIDStr := c.Query("user_id"); userIDStr != "" {
id, err := strconv.ParseInt(userIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user_id")
return
}
userID = id
}
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid api_key_id")
return
}
apiKeyID = id
}
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
id, err := strconv.ParseInt(accountIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account_id")
return
}
accountID = id
}
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
id, err := strconv.ParseInt(groupIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group_id")
return
}
groupID = id
}
model := c.Query("model")
var requestType *int16
var stream *bool
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
response.BadRequest(c, err.Error())
return
}
value := int16(parsed)
requestType = &value
} else if streamStr := c.Query("stream"); streamStr != "" {
val, err := strconv.ParseBool(streamStr)
if err != nil {
response.BadRequest(c, "Invalid stream value, use true or false")
return
}
stream = &val
}
var billingType *int8
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
val, err := strconv.ParseInt(billingTypeStr, 10, 8)
if err != nil {
response.BadRequest(c, "Invalid billing_type")
return
}
bt := int8(val)
billingType = &bt
}
// Parse date range
userTZ := c.Query("timezone")
now := timezone.NowInUserLocation(userTZ)
var startTime, endTime time.Time
startDateStr := c.Query("start_date")
endDateStr := c.Query("end_date")
if startDateStr != "" && endDateStr != "" {
var err error
startTime, err = timezone.ParseInUserLocation("2006-01-02", startDateStr, userTZ)
if err != nil {
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
return
}
endTime, err = timezone.ParseInUserLocation("2006-01-02", endDateStr, userTZ)
if err != nil {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
}
// 与 SQL 条件 created_at < end 对齐,使用次日 00:00 作为上边界DST-safe
endTime = endTime.AddDate(0, 0, 1)
} else {
period := c.DefaultQuery("period", "today")
switch period {
case "today":
startTime = timezone.StartOfDayInUserLocation(now, userTZ)
case "week":
startTime = now.AddDate(0, 0, -7)
case "month":
startTime = now.AddDate(0, -1, 0)
default:
startTime = timezone.StartOfDayInUserLocation(now, userTZ)
}
endTime = now
}
// Build filters and call GetStatsWithFilters
filters := usagestats.UsageLogFilters{
UserID: userID,
APIKeyID: apiKeyID,
AccountID: accountID,
GroupID: groupID,
Model: model,
RequestType: requestType,
Stream: stream,
BillingType: billingType,
StartTime: &startTime,
EndTime: &endTime,
}
stats, err := h.usageService.GetStatsWithFilters(c.Request.Context(), filters)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, stats)
}
// SearchUsers handles searching users by email keyword
// GET /api/v1/admin/usage/search-users
func (h *UsageHandler) SearchUsers(c *gin.Context) {
keyword := c.Query("q")
if keyword == "" {
response.Success(c, []any{})
return
}
// Limit to 30 results
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, service.UserListFilters{Search: keyword})
if err != nil {
response.ErrorFrom(c, err)
return
}
// Return simplified user list (only id and email)
type SimpleUser struct {
ID int64 `json:"id"`
Email string `json:"email"`
}
result := make([]SimpleUser, len(users))
for i, u := range users {
result[i] = SimpleUser{
ID: u.ID,
Email: u.Email,
}
}
response.Success(c, result)
}
// SearchAPIKeys handles searching API keys by user
// GET /api/v1/admin/usage/search-api-keys
func (h *UsageHandler) SearchAPIKeys(c *gin.Context) {
userIDStr := c.Query("user_id")
keyword := c.Query("q")
var userID int64
if userIDStr != "" {
id, err := strconv.ParseInt(userIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user_id")
return
}
userID = id
}
keys, err := h.apiKeyService.SearchAPIKeys(c.Request.Context(), userID, keyword, 30)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Return simplified API key list (only id and name)
type SimpleAPIKey struct {
ID int64 `json:"id"`
Name string `json:"name"`
UserID int64 `json:"user_id"`
}
result := make([]SimpleAPIKey, len(keys))
for i, k := range keys {
result[i] = SimpleAPIKey{
ID: k.ID,
Name: k.Name,
UserID: k.UserID,
}
}
response.Success(c, result)
}
// ListCleanupTasks handles listing usage cleanup tasks
// GET /api/v1/admin/usage/cleanup-tasks
func (h *UsageHandler) ListCleanupTasks(c *gin.Context) {
if h.cleanupService == nil {
response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
return
}
operator := int64(0)
if subject, ok := middleware.GetAuthSubjectFromContext(c); ok {
operator = subject.UserID
}
page, pageSize := response.ParsePagination(c)
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize)
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
tasks, result, err := h.cleanupService.ListTasks(c.Request.Context(), params)
if err != nil {
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err)
response.ErrorFrom(c, err)
return
}
out := make([]dto.UsageCleanupTask, 0, len(tasks))
for i := range tasks {
out = append(out, *dto.UsageCleanupTaskFromService(&tasks[i]))
}
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize)
response.Paginated(c, out, result.Total, page, pageSize)
}
// CreateCleanupTask handles creating a usage cleanup task
// POST /api/v1/admin/usage/cleanup-tasks
func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
if h.cleanupService == nil {
response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Unauthorized(c, "Unauthorized")
return
}
var req CreateUsageCleanupTaskRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
req.StartDate = strings.TrimSpace(req.StartDate)
req.EndDate = strings.TrimSpace(req.EndDate)
if req.StartDate == "" || req.EndDate == "" {
response.BadRequest(c, "start_date and end_date are required")
return
}
startTime, err := timezone.ParseInUserLocation("2006-01-02", req.StartDate, req.Timezone)
if err != nil {
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
return
}
endTime, err := timezone.ParseInUserLocation("2006-01-02", req.EndDate, req.Timezone)
if err != nil {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
}
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
var requestType *int16
stream := req.Stream
if req.RequestType != nil {
parsed, err := service.ParseUsageRequestType(*req.RequestType)
if err != nil {
response.BadRequest(c, err.Error())
return
}
value := int16(parsed)
requestType = &value
stream = nil
}
filters := service.UsageCleanupFilters{
StartTime: startTime,
EndTime: endTime,
UserID: req.UserID,
APIKeyID: req.APIKeyID,
AccountID: req.AccountID,
GroupID: req.GroupID,
Model: req.Model,
RequestType: requestType,
Stream: stream,
BillingType: req.BillingType,
}
var userID any
if filters.UserID != nil {
userID = *filters.UserID
}
var apiKeyID any
if filters.APIKeyID != nil {
apiKeyID = *filters.APIKeyID
}
var accountID any
if filters.AccountID != nil {
accountID = *filters.AccountID
}
var groupID any
if filters.GroupID != nil {
groupID = *filters.GroupID
}
var model any
if filters.Model != nil {
model = *filters.Model
}
var streamValue any
if filters.Stream != nil {
streamValue = *filters.Stream
}
var requestTypeName any
if filters.RequestType != nil {
requestTypeName = service.RequestTypeFromInt16(*filters.RequestType).String()
}
var billingType any
if filters.BillingType != nil {
billingType = *filters.BillingType
}
idempotencyPayload := struct {
OperatorID int64 `json:"operator_id"`
Body CreateUsageCleanupTaskRequest `json:"body"`
}{
OperatorID: subject.UserID,
Body: req,
}
executeAdminIdempotentJSON(c, "admin.usage.cleanup_tasks.create", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v request_type=%v stream=%v billing_type=%v tz=%q",
subject.UserID,
filters.StartTime.Format(time.RFC3339),
filters.EndTime.Format(time.RFC3339),
userID,
apiKeyID,
accountID,
groupID,
model,
requestTypeName,
streamValue,
billingType,
req.Timezone,
)
task, err := h.cleanupService.CreateTask(ctx, filters, subject.UserID)
if err != nil {
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
return nil, err
}
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
return dto.UsageCleanupTaskFromService(task), nil
})
}
// CancelCleanupTask handles canceling a usage cleanup task
// POST /api/v1/admin/usage/cleanup-tasks/:id/cancel
func (h *UsageHandler) CancelCleanupTask(c *gin.Context) {
if h.cleanupService == nil {
response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Unauthorized(c, "Unauthorized")
return
}
idStr := strings.TrimSpace(c.Param("id"))
taskID, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || taskID <= 0 {
response.BadRequest(c, "Invalid task id")
return
}
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID)
if err := h.cleanupService.CancelTask(c.Request.Context(), taskID, subject.UserID); err != nil {
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err)
response.ErrorFrom(c, err)
return
}
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID)
response.Success(c, gin.H{"id": taskID, "status": service.UsageCleanupStatusCanceled})
}

View File

@@ -0,0 +1,140 @@
package admin
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type adminUsageRepoCapture struct {
service.UsageLogRepository
listFilters usagestats.UsageLogFilters
statsFilters usagestats.UsageLogFilters
}
func (s *adminUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
s.listFilters = filters
return []service.UsageLog{}, &pagination.PaginationResult{
Total: 0,
Page: params.Page,
PageSize: params.PageSize,
Pages: 0,
}, nil
}
func (s *adminUsageRepoCapture) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
s.statsFilters = filters
return &usagestats.UsageStats{}, nil
}
func newAdminUsageRequestTypeTestRouter(repo *adminUsageRepoCapture) *gin.Engine {
gin.SetMode(gin.TestMode)
usageSvc := service.NewUsageService(repo, nil, nil, nil)
handler := NewUsageHandler(usageSvc, nil, nil, nil)
router := gin.New()
router.GET("/admin/usage", handler.List)
router.GET("/admin/usage/stats", handler.Stats)
return router
}
func TestAdminUsageListRequestTypePriority(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage?request_type=ws_v2&stream=false", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, repo.listFilters.RequestType)
require.Equal(t, int16(service.RequestTypeWSV2), *repo.listFilters.RequestType)
require.Nil(t, repo.listFilters.Stream)
}
func TestAdminUsageListInvalidRequestType(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage?request_type=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestAdminUsageListInvalidStream(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage?stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestAdminUsageListExactTotalTrue(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage?exact_total=true", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.True(t, repo.listFilters.ExactTotal)
}
func TestAdminUsageListInvalidExactTotal(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage?exact_total=oops", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestAdminUsageStatsRequestTypePriority(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?request_type=stream&stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, repo.statsFilters.RequestType)
require.Equal(t, int16(service.RequestTypeStream), *repo.statsFilters.RequestType)
require.Nil(t, repo.statsFilters.Stream)
}
func TestAdminUsageStatsInvalidRequestType(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?request_type=oops", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestAdminUsageStatsInvalidStream(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?stream=oops", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}

View File

@@ -0,0 +1,362 @@
package admin
import (
"encoding/json"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// UserAttributeHandler handles user attribute management
type UserAttributeHandler struct {
attrService *service.UserAttributeService
}
// NewUserAttributeHandler creates a new handler
func NewUserAttributeHandler(attrService *service.UserAttributeService) *UserAttributeHandler {
return &UserAttributeHandler{attrService: attrService}
}
// --- Request/Response DTOs ---
// CreateAttributeDefinitionRequest represents create attribute definition request
type CreateAttributeDefinitionRequest struct {
Key string `json:"key" binding:"required,min=1,max=100"`
Name string `json:"name" binding:"required,min=1,max=255"`
Description string `json:"description"`
Type string `json:"type" binding:"required"`
Options []service.UserAttributeOption `json:"options"`
Required bool `json:"required"`
Validation service.UserAttributeValidation `json:"validation"`
Placeholder string `json:"placeholder"`
Enabled bool `json:"enabled"`
}
// UpdateAttributeDefinitionRequest represents update attribute definition request
type UpdateAttributeDefinitionRequest struct {
Name *string `json:"name"`
Description *string `json:"description"`
Type *string `json:"type"`
Options *[]service.UserAttributeOption `json:"options"`
Required *bool `json:"required"`
Validation *service.UserAttributeValidation `json:"validation"`
Placeholder *string `json:"placeholder"`
Enabled *bool `json:"enabled"`
}
// ReorderRequest represents reorder attribute definitions request
type ReorderRequest struct {
IDs []int64 `json:"ids" binding:"required"`
}
// UpdateUserAttributesRequest represents update user attributes request
type UpdateUserAttributesRequest struct {
Values map[int64]string `json:"values" binding:"required"`
}
// BatchGetUserAttributesRequest represents batch get user attributes request
type BatchGetUserAttributesRequest struct {
UserIDs []int64 `json:"user_ids" binding:"required"`
}
// BatchUserAttributesResponse represents batch user attributes response
type BatchUserAttributesResponse struct {
// Map of userID -> map of attributeID -> value
Attributes map[int64]map[int64]string `json:"attributes"`
}
var userAttributesBatchCache = newSnapshotCache(30 * time.Second)
// AttributeDefinitionResponse represents attribute definition response
type AttributeDefinitionResponse struct {
ID int64 `json:"id"`
Key string `json:"key"`
Name string `json:"name"`
Description string `json:"description"`
Type string `json:"type"`
Options []service.UserAttributeOption `json:"options"`
Required bool `json:"required"`
Validation service.UserAttributeValidation `json:"validation"`
Placeholder string `json:"placeholder"`
DisplayOrder int `json:"display_order"`
Enabled bool `json:"enabled"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
// AttributeValueResponse represents attribute value response
type AttributeValueResponse struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
AttributeID int64 `json:"attribute_id"`
Value string `json:"value"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
// --- Helpers ---
func defToResponse(def *service.UserAttributeDefinition) *AttributeDefinitionResponse {
return &AttributeDefinitionResponse{
ID: def.ID,
Key: def.Key,
Name: def.Name,
Description: def.Description,
Type: string(def.Type),
Options: def.Options,
Required: def.Required,
Validation: def.Validation,
Placeholder: def.Placeholder,
DisplayOrder: def.DisplayOrder,
Enabled: def.Enabled,
CreatedAt: def.CreatedAt.Format("2006-01-02T15:04:05Z07:00"),
UpdatedAt: def.UpdatedAt.Format("2006-01-02T15:04:05Z07:00"),
}
}
func valueToResponse(val *service.UserAttributeValue) *AttributeValueResponse {
return &AttributeValueResponse{
ID: val.ID,
UserID: val.UserID,
AttributeID: val.AttributeID,
Value: val.Value,
CreatedAt: val.CreatedAt.Format("2006-01-02T15:04:05Z07:00"),
UpdatedAt: val.UpdatedAt.Format("2006-01-02T15:04:05Z07:00"),
}
}
// --- Handlers ---
// ListDefinitions lists all attribute definitions
// GET /admin/user-attributes
func (h *UserAttributeHandler) ListDefinitions(c *gin.Context) {
enabledOnly := c.Query("enabled") == "true"
defs, err := h.attrService.ListDefinitions(c.Request.Context(), enabledOnly)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]*AttributeDefinitionResponse, 0, len(defs))
for i := range defs {
out = append(out, defToResponse(&defs[i]))
}
response.Success(c, out)
}
// CreateDefinition creates a new attribute definition
// POST /admin/user-attributes
func (h *UserAttributeHandler) CreateDefinition(c *gin.Context) {
var req CreateAttributeDefinitionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
def, err := h.attrService.CreateDefinition(c.Request.Context(), service.CreateAttributeDefinitionInput{
Key: req.Key,
Name: req.Name,
Description: req.Description,
Type: service.UserAttributeType(req.Type),
Options: req.Options,
Required: req.Required,
Validation: req.Validation,
Placeholder: req.Placeholder,
Enabled: req.Enabled,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, defToResponse(def))
}
// UpdateDefinition updates an attribute definition
// PUT /admin/user-attributes/:id
func (h *UserAttributeHandler) UpdateDefinition(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid attribute ID")
return
}
var req UpdateAttributeDefinitionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
input := service.UpdateAttributeDefinitionInput{
Name: req.Name,
Description: req.Description,
Options: req.Options,
Required: req.Required,
Validation: req.Validation,
Placeholder: req.Placeholder,
Enabled: req.Enabled,
}
if req.Type != nil {
t := service.UserAttributeType(*req.Type)
input.Type = &t
}
def, err := h.attrService.UpdateDefinition(c.Request.Context(), id, input)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, defToResponse(def))
}
// DeleteDefinition deletes an attribute definition
// DELETE /admin/user-attributes/:id
func (h *UserAttributeHandler) DeleteDefinition(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid attribute ID")
return
}
if err := h.attrService.DeleteDefinition(c.Request.Context(), id); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Attribute definition deleted successfully"})
}
// ReorderDefinitions reorders attribute definitions
// PUT /admin/user-attributes/reorder
func (h *UserAttributeHandler) ReorderDefinitions(c *gin.Context) {
var req ReorderRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Convert IDs array to orders map (position in array = display_order)
orders := make(map[int64]int, len(req.IDs))
for i, id := range req.IDs {
orders[id] = i
}
if err := h.attrService.ReorderDefinitions(c.Request.Context(), orders); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Reorder successful"})
}
// GetUserAttributes gets a user's attribute values
// GET /admin/users/:id/attributes
func (h *UserAttributeHandler) GetUserAttributes(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
values, err := h.attrService.GetUserAttributes(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]*AttributeValueResponse, 0, len(values))
for i := range values {
out = append(out, valueToResponse(&values[i]))
}
response.Success(c, out)
}
// UpdateUserAttributes updates a user's attribute values
// PUT /admin/users/:id/attributes
func (h *UserAttributeHandler) UpdateUserAttributes(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
var req UpdateUserAttributesRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
inputs := make([]service.UpdateUserAttributeInput, 0, len(req.Values))
for attrID, value := range req.Values {
inputs = append(inputs, service.UpdateUserAttributeInput{
AttributeID: attrID,
Value: value,
})
}
if err := h.attrService.UpdateUserAttributes(c.Request.Context(), userID, inputs); err != nil {
response.ErrorFrom(c, err)
return
}
// Return updated values
values, err := h.attrService.GetUserAttributes(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]*AttributeValueResponse, 0, len(values))
for i := range values {
out = append(out, valueToResponse(&values[i]))
}
response.Success(c, out)
}
// GetBatchUserAttributes gets attribute values for multiple users
// POST /admin/user-attributes/batch
func (h *UserAttributeHandler) GetBatchUserAttributes(c *gin.Context) {
var req BatchGetUserAttributesRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
userIDs := normalizeInt64IDList(req.UserIDs)
if len(userIDs) == 0 {
response.Success(c, BatchUserAttributesResponse{Attributes: map[int64]map[int64]string{}})
return
}
keyRaw, _ := json.Marshal(struct {
UserIDs []int64 `json:"user_ids"`
}{
UserIDs: userIDs,
})
cacheKey := string(keyRaw)
if cached, ok := userAttributesBatchCache.Get(cacheKey); ok {
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
attrs, err := h.attrService.GetBatchUserAttributes(c.Request.Context(), userIDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
payload := BatchUserAttributesResponse{Attributes: attrs}
userAttributesBatchCache.Set(cacheKey, payload)
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, payload)
}

View File

@@ -0,0 +1,368 @@
package admin
import (
"context"
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// UserWithConcurrency wraps AdminUser with current concurrency info
type UserWithConcurrency struct {
dto.AdminUser
CurrentConcurrency int `json:"current_concurrency"`
}
// UserHandler handles admin user management
type UserHandler struct {
adminService service.AdminService
concurrencyService *service.ConcurrencyService
}
// NewUserHandler creates a new admin user handler
func NewUserHandler(adminService service.AdminService, concurrencyService *service.ConcurrencyService) *UserHandler {
return &UserHandler{
adminService: adminService,
concurrencyService: concurrencyService,
}
}
// CreateUserRequest represents admin create user request
type CreateUserRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
Username string `json:"username"`
Notes string `json:"notes"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
AllowedGroups []int64 `json:"allowed_groups"`
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
}
// UpdateUserRequest represents admin update user request
// 使用指针类型来区分"未提供"和"设置为0"
type UpdateUserRequest struct {
Email string `json:"email" binding:"omitempty,email"`
Password string `json:"password" binding:"omitempty,min=6"`
Username *string `json:"username"`
Notes *string `json:"notes"`
Balance *float64 `json:"balance"`
Concurrency *int `json:"concurrency"`
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
AllowedGroups *[]int64 `json:"allowed_groups"`
// GroupRates 用户专属分组倍率配置
// map[groupID]*ratenil 表示删除该分组的专属倍率
GroupRates map[int64]*float64 `json:"group_rates"`
SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
}
// UpdateBalanceRequest represents balance update request
type UpdateBalanceRequest struct {
Balance float64 `json:"balance" binding:"required,gt=0"`
Operation string `json:"operation" binding:"required,oneof=set add subtract"`
Notes string `json:"notes"`
}
// List handles listing all users with pagination
// GET /api/v1/admin/users
// Query params:
// - status: filter by user status
// - role: filter by user role
// - search: search in email, username
// - attr[{id}]: filter by custom attribute value, e.g. attr[1]=company
func (h *UserHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
search := c.Query("search")
// 标准化和验证 search 参数
search = strings.TrimSpace(search)
if runes := []rune(search); len(runes) > 100 {
search = string(runes[:100])
}
filters := service.UserListFilters{
Status: c.Query("status"),
Role: c.Query("role"),
Search: search,
Attributes: parseAttributeFilters(c),
}
if raw, ok := c.GetQuery("include_subscriptions"); ok {
includeSubscriptions := parseBoolQueryWithDefault(raw, true)
filters.IncludeSubscriptions = &includeSubscriptions
}
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Batch get current concurrency (nil map if unavailable)
var loadInfo map[int64]*service.UserLoadInfo
if len(users) > 0 && h.concurrencyService != nil {
usersConcurrency := make([]service.UserWithConcurrency, len(users))
for i := range users {
usersConcurrency[i] = service.UserWithConcurrency{
ID: users[i].ID,
MaxConcurrency: users[i].Concurrency,
}
}
loadInfo, _ = h.concurrencyService.GetUsersLoadBatch(c.Request.Context(), usersConcurrency)
}
// Build response with concurrency info
out := make([]UserWithConcurrency, len(users))
for i := range users {
out[i] = UserWithConcurrency{
AdminUser: *dto.UserFromServiceAdmin(&users[i]),
}
if info := loadInfo[users[i].ID]; info != nil {
out[i].CurrentConcurrency = info.CurrentConcurrency
}
}
response.Paginated(c, out, total, page, pageSize)
}
// parseAttributeFilters extracts attribute filters from query params
// Format: attr[{attributeID}]=value, e.g. attr[1]=company&attr[2]=developer
func parseAttributeFilters(c *gin.Context) map[int64]string {
result := make(map[int64]string)
// Get all query params and look for attr[*] pattern
for key, values := range c.Request.URL.Query() {
if len(values) == 0 || values[0] == "" {
continue
}
// Check if key matches pattern attr[{id}]
if len(key) > 5 && key[:5] == "attr[" && key[len(key)-1] == ']' {
idStr := key[5 : len(key)-1]
id, err := strconv.ParseInt(idStr, 10, 64)
if err == nil && id > 0 {
result[id] = values[0]
}
}
}
return result
}
// GetByID handles getting a user by ID
// GET /api/v1/admin/users/:id
func (h *UserHandler) GetByID(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
user, err := h.adminService.GetUser(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.UserFromServiceAdmin(user))
}
// Create handles creating a new user
// POST /api/v1/admin/users
func (h *UserHandler) Create(c *gin.Context) {
var req CreateUserRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{
Email: req.Email,
Password: req.Password,
Username: req.Username,
Notes: req.Notes,
Balance: req.Balance,
Concurrency: req.Concurrency,
AllowedGroups: req.AllowedGroups,
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.UserFromServiceAdmin(user))
}
// Update handles updating a user
// PUT /api/v1/admin/users/:id
func (h *UserHandler) Update(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
var req UpdateUserRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// 使用指针类型直接传递nil 表示未提供该字段
user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{
Email: req.Email,
Password: req.Password,
Username: req.Username,
Notes: req.Notes,
Balance: req.Balance,
Concurrency: req.Concurrency,
Status: req.Status,
AllowedGroups: req.AllowedGroups,
GroupRates: req.GroupRates,
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.UserFromServiceAdmin(user))
}
// Delete handles deleting a user
// DELETE /api/v1/admin/users/:id
func (h *UserHandler) Delete(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
err = h.adminService.DeleteUser(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "User deleted successfully"})
}
// UpdateBalance handles updating user balance
// POST /api/v1/admin/users/:id/balance
func (h *UserHandler) UpdateBalance(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
var req UpdateBalanceRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
idempotencyPayload := struct {
UserID int64 `json:"user_id"`
Body UpdateBalanceRequest `json:"body"`
}{
UserID: userID,
Body: req,
}
executeAdminIdempotentJSON(c, "admin.users.balance.update", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
user, execErr := h.adminService.UpdateUserBalance(ctx, userID, req.Balance, req.Operation, req.Notes)
if execErr != nil {
return nil, execErr
}
return dto.UserFromServiceAdmin(user), nil
})
}
// GetUserAPIKeys handles getting user's API keys
// GET /api/v1/admin/users/:id/api-keys
func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
page, pageSize := response.ParsePagination(c)
keys, total, err := h.adminService.GetUserAPIKeys(c.Request.Context(), userID, page, pageSize)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.APIKey, 0, len(keys))
for i := range keys {
out = append(out, *dto.APIKeyFromService(&keys[i]))
}
response.Paginated(c, out, total, page, pageSize)
}
// GetUserUsage handles getting user's usage statistics
// GET /api/v1/admin/users/:id/usage
func (h *UserHandler) GetUserUsage(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
period := c.DefaultQuery("period", "month")
stats, err := h.adminService.GetUserUsageStats(c.Request.Context(), userID, period)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, stats)
}
// GetBalanceHistory handles getting user's balance/concurrency change history
// GET /api/v1/admin/users/:id/balance-history
// Query params:
// - type: filter by record type (balance, admin_balance, concurrency, admin_concurrency, subscription)
func (h *UserHandler) GetBalanceHistory(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
page, pageSize := response.ParsePagination(c)
codeType := c.Query("type")
codes, total, totalRecharged, err := h.adminService.GetUserBalanceHistory(c.Request.Context(), userID, page, pageSize, codeType)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Convert to admin DTO (includes notes field for admin visibility)
out := make([]dto.AdminRedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
}
// Custom response with total_recharged alongside pagination
pages := int((total + int64(pageSize) - 1) / int64(pageSize))
if pages < 1 {
pages = 1
}
response.Success(c, gin.H{
"items": out,
"total": total,
"page": page,
"page_size": pageSize,
"pages": pages,
"total_recharged": totalRecharged,
})
}