package httpapi

import (
	"context"
	"encoding/json"
	"errors"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"lijiaoqiao/supply-api/internal/audit"
	"lijiaoqiao/supply-api/internal/domain"
	"lijiaoqiao/supply-api/internal/middleware"
)

// ==================== Mock Implementations ====================

// mockAccountService is a configurable test double for the account service.
// When one of the *Err fields is non-nil the matching method returns it
// instead of the canned result.
type mockAccountService struct {
	verifyResult *domain.VerifyResult
	verifyErr    error
	account      *domain.Account
	createErr    error
	activateErr  error
	suspendErr   error
	deleteErr    error

	// lastVerifySupplierID records the supplierID of the most recent Verify call.
	lastVerifySupplierID int64
}

// Verify records supplierID, then returns verifyErr if set, else verifyResult.
func (m *mockAccountService) Verify(ctx context.Context, supplierID int64, provider domain.Provider, accountType domain.AccountType, credential string) (*domain.VerifyResult, error) {
	m.lastVerifySupplierID = supplierID
	if m.verifyErr != nil {
		return nil, m.verifyErr
	}
	return m.verifyResult, nil
}

// Create returns createErr if set, else the canned account.
func (m *mockAccountService) Create(ctx context.Context, req *domain.CreateAccountRequest) (*domain.Account, error) {
	if m.createErr != nil {
		return nil, m.createErr
	}
	return m.account, nil
}

// Activate returns activateErr if set, else the canned account.
func (m *mockAccountService) Activate(ctx context.Context, supplierID, accountID int64) (*domain.Account, error) {
	if m.activateErr != nil {
		return nil, m.activateErr
	}
	return m.account, nil
}

// Suspend returns suspendErr if set, else the canned account.
func (m *mockAccountService) Suspend(ctx context.Context, supplierID, accountID int64) (*domain.Account, error) {
	if m.suspendErr != nil {
		return nil, m.suspendErr
	}
	return m.account, nil
}

// Delete returns deleteErr (nil means success).
func (m *mockAccountService) Delete(ctx context.Context, supplierID, accountID int64) error {
	return m.deleteErr
}

// GetByID always returns the canned account with a nil error.
func (m *mockAccountService) GetByID(ctx context.Context, supplierID, accountID int64) (*domain.Account, error) {
	return m.account, nil
}

// mockPackageService is a configurable test double for the package service.
type mockPackageService struct {
	pkg            *domain.Package
	createDraftErr error
	publishErr     error
	pauseErr       error
	unlistErr      error
	cloneErr       error
	batchResp      *domain.BatchUpdatePriceResponse
	batchErr       error
}

// CreateDraft returns createDraftErr if set, else the canned package.
func (m *mockPackageService) CreateDraft(ctx context.Context, supplierID int64, req *domain.CreatePackageDraftRequest) (*domain.Package, error) {
	if m.createDraftErr != nil {
		return nil, m.createDraftErr
	}
	return m.pkg, nil
}

// Publish returns publishErr if set, else the canned package.
func (m *mockPackageService) Publish(ctx context.Context, supplierID, packageID int64) (*domain.Package, error) {
	if m.publishErr != nil {
		return nil, m.publishErr
	}
	return m.pkg, nil
}

// Pause returns pauseErr if set, else the canned package.
func (m *mockPackageService) Pause(ctx context.Context, supplierID, packageID int64) (*domain.Package, error) {
	if m.pauseErr != nil {
		return nil, m.pauseErr
	}
	return m.pkg, nil
}

// Unlist returns unlistErr if set, else the canned package.
func (m *mockPackageService) Unlist(ctx context.Context, supplierID, packageID int64) (*domain.Package, error) {
	if m.unlistErr != nil {
		return nil, m.unlistErr
	}
	return m.pkg, nil
}

// Clone returns cloneErr if set, else the canned package.
func (m *mockPackageService) Clone(ctx context.Context, supplierID, packageID int64) (*domain.Package, error) {
	if m.cloneErr != nil {
		return nil, m.cloneErr
	}
	return m.pkg, nil
}

// BatchUpdatePrice returns batchErr if set, else batchResp.
func (m *mockPackageService) BatchUpdatePrice(ctx context.Context, supplierID int64, req *domain.BatchUpdatePriceRequest) (*domain.BatchUpdatePriceResponse, error) {
	if m.batchErr != nil {
		return nil, m.batchErr
	}
	return m.batchResp, nil
}

// GetByID always returns the canned package with a nil error.
func (m *mockPackageService) GetByID(ctx context.Context, supplierID, packageID int64) (*domain.Package, error) {
	return m.pkg, nil
}

// mockSettlementService is a configurable test double for the settlement service.
type mockSettlementService struct {
	settlement  *domain.Settlement
	withdrawErr error
	cancelErr   error
	getErr      error
}

// Withdraw returns withdrawErr if set, else the canned settlement.
func (m *mockSettlementService) Withdraw(ctx context.Context, supplierID int64, req *domain.WithdrawRequest) (*domain.Settlement, error) {
	if m.withdrawErr != nil {
		return nil, m.withdrawErr
	}
	return m.settlement, nil
}

// Cancel returns cancelErr if set, else the canned settlement.
func (m *mockSettlementService) Cancel(ctx context.Context, supplierID, settlementID int64) (*domain.Settlement, error) {
	if m.cancelErr != nil {
		return nil, m.cancelErr
	}
	return m.settlement, nil
}

// GetByID returns getErr if set, else the canned settlement.
func (m *mockSettlementService) GetByID(ctx context.Context, supplierID, settlementID int64) (*domain.Settlement, error) {
	if m.getErr != nil {
		return nil, m.getErr
	}
	return m.settlement, nil
}

// List returns a one-element slice holding the canned settlement, or nil when
// no settlement is configured.
func (m *mockSettlementService) List(ctx context.Context, supplierID int64) ([]*domain.Settlement, error) {
	if m.settlement != nil {
		return []*domain.Settlement{m.settlement}, nil
	}
	return nil, nil
}

// GetBillingSummary is a stub that always returns (nil, nil).
func (m *mockSettlementService) GetBillingSummary(ctx context.Context, supplierID int64, startDate, endDate string) (*domain.BillingSummary, error) {
	return nil, nil
}

// mockEarningService is a configurable test double for the earning service.
type mockEarningService struct {
	records        []*domain.EarningRecord
	total          int
	billingSummary *domain.BillingSummary
	listErr        error
	billingErr     error
}

// ListRecords returns listErr if set, else the canned records and total.
func (m *mockEarningService) ListRecords(ctx context.Context, supplierID int64, startDate, endDate string, page, pageSize int) ([]*domain.EarningRecord, int, error) {
	if m.listErr != nil {
		return nil, 0, m.listErr
	}
	return m.records, m.total, nil
}

// GetBillingSummary returns billingErr if set, else the canned summary.
func (m *mockEarningService) GetBillingSummary(ctx context.Context, supplierID int64, startDate, endDate string) (*domain.BillingSummary, error) {
	if m.billingErr != nil {
		return nil, m.billingErr
	}
	return m.billingSummary, nil
}

// mockAuditStore is a configurable test double for the audit store. The single
// err field is shared by every method.
type mockAuditStore struct {
	events []audit.Event
	event  audit.Event
	err    error
}

// Emit discards the event and returns err.
func (m *mockAuditStore) Emit(ctx context.Context, event audit.Event) error {
	return m.err
}

// Query returns err if set, else the canned events.
func (m *mockAuditStore) Query(ctx context.Context, filter audit.EventFilter) ([]audit.Event, error) {
	if m.err != nil {
		return nil, m.err
	}
	return m.events, nil
}

// QueryWithTotal returns err if set, else the canned events and their count.
func (m *mockAuditStore) QueryWithTotal(ctx context.Context, filter audit.EventFilter) ([]audit.Event, int64, error) {
	if m.err != nil {
		return nil, 0, m.err
	}
	return m.events, int64(len(m.events)), nil
}

// GetByID returns err (with a zero Event) if set, else the canned event.
func (m *mockAuditStore) GetByID(ctx context.Context, eventID string) (audit.Event, error) {
	if m.err != nil {
		return audit.Event{}, m.err
	}
	return m.event, nil
}

// ==================== Test Helpers ====================

// newTestAPI builds a SupplyAPI wired to fresh mocks, with an idempotency
// middleware constructed with Enabled: false.
func newTestAPI() (*SupplyAPI, *mockAccountService, *mockPackageService, *mockSettlementService, *mockEarningService, *mockAuditStore) {
	return newTestAPIWithIdempotency(middleware.NewIdempotencyMiddleware(nil, middleware.IdempotencyConfig{
		Enabled: false,
	}))
}

// newTestAPIWithoutIdempotencyForTest builds a SupplyAPI wired to fresh mocks
// and a nil idempotency middleware.
func newTestAPIWithoutIdempotencyForTest() (*SupplyAPI, *mockAccountService, *mockPackageService, *mockSettlementService, *mockEarningService, *mockAuditStore) {
	return newTestAPIWithIdempotency(nil)
}

// newTestAPIWithIdempotency builds a SupplyAPI (default supplier ID 100) backed
// by mocks pre-populated with one account, package, settlement, earning record
// and audit event, and returns the API together with every mock so tests can
// inject errors. It panics if the constructor fails, since no test can proceed
// without the API.
func newTestAPIWithIdempotency(idempotencyMw *middleware.IdempotencyMiddleware) (*SupplyAPI, *mockAccountService, *mockPackageService, *mockSettlementService, *mockEarningService, *mockAuditStore) {
	// One timestamp for all fixtures keeps CreatedAt/UpdatedAt consistent.
	now := time.Now()

	accountSvc := &mockAccountService{
		account: &domain.Account{
			ID:          1,
			SupplierID:  100,
			Provider:    domain.ProviderOpenAI,
			AccountType: domain.AccountTypeAPIKey,
			Status:      domain.AccountStatusActive,
			CreatedAt:   now,
			UpdatedAt:   now,
		},
		verifyResult: &domain.VerifyResult{
			VerifyStatus:   "pass",
			AvailableQuota: 1000,
			RiskScore:      0,
		},
	}

	packageSvc := &mockPackageService{
		pkg: &domain.Package{
			ID:             1,
			SupplierID:     100,
			Model:          "gpt-4",
			Status:         domain.PackageStatusActive,
			TotalQuota:     10000,
			AvailableQuota: 8000,
			CreatedAt:      now,
			UpdatedAt:      now,
		},
	}

	settlementSvc := &mockSettlementService{
		settlement: &domain.Settlement{
			ID:          1,
			SupplierID:  100,
			Status:      domain.SettlementStatusPending,
			TotalAmount: 1000,
			NetAmount:   950,
			CreatedAt:   now,
			UpdatedAt:   now,
		},
	}

	earningSvc := &mockEarningService{
		records: []*domain.EarningRecord{
			{
				ID:     1,
				Amount: 100,
				Status: "available",
			},
		},
		total:          1,
		billingSummary: &domain.BillingSummary{},
	}

	auditSvc := &mockAuditStore{
		events: []audit.Event{
			{
				EventID:    "evt_123",
				TenantID:   100,
				ObjectType: "supply_account",
				ObjectID:   1,
				Action:     "create",
				CreatedAt:  now,
			},
		},
		event: audit.Event{
			EventID:    "evt_123",
			TenantID:   100,
			ObjectType: "supply_account",
			ObjectID:   1,
			Action:     "create",
			CreatedAt:  now,
		},
	}

	api, err := NewSupplyAPI(
		accountSvc,
		packageSvc,
		settlementSvc,
		earningSvc,
		idempotencyMw,
		auditSvc,
		nil, // fkValidator
		100, // supplierID
		"https://statements.example.com",
		time.Now,
	)
	if err != nil {
		panic("expected api constructor to succeed: " + err.Error())
	}

	return api, accountSvc, packageSvc, settlementSvc, earningSvc, auditSvc
}

// TestNewSupplyAPI_ReturnsErrorWhenAccountServiceMissing expects the
// constructor to return an error and a nil API when the account service is nil.
func TestNewSupplyAPI_ReturnsErrorWhenAccountServiceMissing(t *testing.T) {
	api, err := NewSupplyAPI(
		nil,
		&mockPackageService{},
		&mockSettlementService{},
		&mockEarningService{},
		nil,
		&mockAuditStore{},
		nil,
		100,
		"https://statements.example.com",
		time.Now,
	)
	if err == nil {
		t.Fatal("expected error")
	}
	if api != nil {
		t.Fatal("expected nil api")
	}
}

// TestNewSupplyAPI_DefaultsClockWhenNil expects a non-nil clock on the API
// when the constructor is given a nil clock function.
func TestNewSupplyAPI_DefaultsClockWhenNil(t *testing.T) {
	api, err := NewSupplyAPI(
		&mockAccountService{},
		&mockPackageService{},
		&mockSettlementService{},
		&mockEarningService{},
		nil,
		&mockAuditStore{},
		nil,
		100,
		"https://statements.example.com",
		nil,
	)
	if err != nil {
		t.Fatalf("expected no error, got %v", err)
	}
	if api.now == nil {
		t.Fatal("expected default clock")
	}
}

// ==================== Account Handler Tests ====================

// TestSupplyAPI_VerifyAccount_Success expects 200 and the request ID echoed
// back in the response body.
func TestSupplyAPI_VerifyAccount_Success(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	body := `{"provider":"openai","account_type":"resource","credential_input":"sk-test123"}`
	req := httptest.NewRequest("POST", "/api/v1/supply/accounts/verify", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("X-Request-Id", "test-req-001")
	w := httptest.NewRecorder()

	api.handleVerifyAccount(w, req)

	if w.Code != http.StatusOK {
		t.Errorf("expected status 200, got %d", w.Code)
	}

	var resp map[string]any
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	if resp["request_id"] != "test-req-001" {
		t.Errorf("expected request_id test-req-001, got %v", resp["request_id"])
	}
}

// TestSupplyAPI_VerifyAccount_MethodNotAllowed expects 405 for a GET request.
func TestSupplyAPI_VerifyAccount_MethodNotAllowed(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("GET", "/api/v1/supply/accounts/verify", nil)
	w := httptest.NewRecorder()

	api.handleVerifyAccount(w, req)

	if w.Code != http.StatusMethodNotAllowed {
		t.Errorf("expected status 405, got %d", w.Code)
	}
}

// TestSupplyAPI_VerifyAccount_InvalidJSON expects 400 for a malformed body.
func TestSupplyAPI_VerifyAccount_InvalidJSON(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	body := `{invalid json}`
	req := httptest.NewRequest("POST", "/api/v1/supply/accounts/verify", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()

	api.handleVerifyAccount(w, req)

	if w.Code != http.StatusBadRequest {
		t.Errorf("expected status 400, got %d", w.Code)
	}
}

// TestSupplyAPI_VerifyAccount_VerifyFailed expects 422 when the service
// returns a verification error.
func TestSupplyAPI_VerifyAccount_VerifyFailed(t *testing.T) {
	api, accountSvc, _, _, _, _ := newTestAPI()
	accountSvc.verifyErr = errors.New("SUP_ACC_4001: verification failed")

	body := `{"provider":"openai","account_type":"resource","credential_input":"invalid"}`
	req := httptest.NewRequest("POST", "/api/v1/supply/accounts/verify", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()

	api.handleVerifyAccount(w, req)

	if w.Code != http.StatusUnprocessableEntity {
		t.Errorf("expected status 422, got %d", w.Code)
	}
}

// TestSupplyAPI_VerifyAccount_UsesTenantIDFromContext expects the tenant ID
// stored in the request context (200) to be passed to the service instead of
// the API's default supplier ID (100).
func TestSupplyAPI_VerifyAccount_UsesTenantIDFromContext(t *testing.T) {
	api, accountSvc, _, _, _, _ := newTestAPI()

	body := `{"provider":"openai","account_type":"resource","credential_input":"sk-test123"}`
	req := httptest.NewRequest("POST", "/api/v1/supply/accounts/verify", strings.NewReader(body))
	req = req.WithContext(middleware.WithTenantID(req.Context(), 200))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()

	api.handleVerifyAccount(w, req)

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}
	if accountSvc.lastVerifySupplierID != 200 {
		t.Fatalf("expected tenant supplier ID 200, got %d", accountSvc.lastVerifySupplierID)
	}
}

// TestSupplyAPI_VerifyAccount_RejectsMissingTenantContextWithoutDefaultSupplier
// expects 401 when neither a tenant in the context nor a default supplier ID
// is available.
func TestSupplyAPI_VerifyAccount_RejectsMissingTenantContextWithoutDefaultSupplier(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()
	api.supplierID = 0

	body := `{"provider":"openai","account_type":"resource","credential_input":"sk-test123"}`
	req := httptest.NewRequest("POST", "/api/v1/supply/accounts/verify", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()

	api.handleVerifyAccount(w, req)

	if w.Code != http.StatusUnauthorized {
		t.Fatalf("expected status 401, got %d body=%s", w.Code, w.Body.String())
	}
}

// TestSupplyAPI_CreateAccount_Success expects 201 for a valid create request.
func TestSupplyAPI_CreateAccount_Success(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	body := `{"provider":"openai","account_type":"resource","credential_input":"sk-test","account_alias":"test","risk_ack":true}`
	req := httptest.NewRequest("POST", "/api/v1/supply/accounts", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()

	api.handleCreateAccount(w, req)

	if w.Code != http.StatusCreated {
		t.Errorf("expected status 201, got %d", w.Code)
	}
}

// TestHandleCreateAccount_RequiresIdempotencyMiddleware expects 503 when the
// API was built with a nil idempotency middleware.
func TestHandleCreateAccount_RequiresIdempotencyMiddleware(t *testing.T) {
	api, _, _, _, _, _ := newTestAPIWithoutIdempotencyForTest()

	body := `{"provider":"openai","account_type":"resource","credential_input":"sk-test","account_alias":"test","risk_ack":true}`
	req := httptest.NewRequest("POST", "/api/v1/supply/accounts", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()

	api.handleCreateAccount(w, req)

	if w.Code != http.StatusServiceUnavailable {
		t.Fatalf("expected 503 when idempotency middleware is missing, got=%d body=%s", w.Code, w.Body.String())
	}
}

// TestSupplyAPI_CreateAccount_MethodNotAllowed expects 405 for a GET request.
func TestSupplyAPI_CreateAccount_MethodNotAllowed(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("GET", "/api/v1/supply/accounts", nil)
	w := httptest.NewRecorder()

	api.handleCreateAccount(w, req)

	if w.Code != http.StatusMethodNotAllowed {
		t.Errorf("expected status 405, got %d", w.Code)
	}
}

// TestSupplyAPI_ActivateAccount_Success expects 200 for POST .../activate.
func TestSupplyAPI_ActivateAccount_Success(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("POST", "/api/v1/supply/accounts/1/activate", nil)
	w := httptest.NewRecorder()

	api.handleAccountActions(w, req)

	if w.Code != http.StatusOK {
		t.Errorf("expected status 200, got %d", w.Code)
	}
}

// TestSupplyAPI_ActivateAccount_NotFound expects 404 when activation fails
// with a "not found" error.
func TestSupplyAPI_ActivateAccount_NotFound(t *testing.T) {
	api, accountSvc, _, _, _, _ := newTestAPI()
	accountSvc.activateErr = errors.New("account not found")

	req := httptest.NewRequest("POST", "/api/v1/supply/accounts/1/activate", nil)
	w := httptest.NewRecorder()

	api.handleAccountActions(w, req)

	if w.Code != http.StatusNotFound {
		t.Errorf("expected status 404, got %d", w.Code)
	}
}

// TestSupplyAPI_SuspendAccount_Success expects 200 for POST .../suspend.
func TestSupplyAPI_SuspendAccount_Success(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("POST", "/api/v1/supply/accounts/1/suspend", nil)
	w := httptest.NewRecorder()

	api.handleAccountActions(w, req)

	if w.Code != http.StatusOK {
		t.Errorf("expected status 200, got %d", w.Code)
	}
}

// TestSupplyAPI_SuspendAccount_Conflict expects 409 when the service reports a
// state conflict.
func TestSupplyAPI_SuspendAccount_Conflict(t *testing.T) {
	api, accountSvc, _, _, _, _ := newTestAPI()
	accountSvc.suspendErr = errors.New("SUP_ACC_4091: account state conflict")

	req := httptest.NewRequest("POST", "/api/v1/supply/accounts/1/suspend", nil)
	w := httptest.NewRecorder()

	api.handleAccountActions(w, req)

	if w.Code != http.StatusConflict {
		t.Errorf("expected status 409, got %d", w.Code)
	}
}

// TestSupplyAPI_SuspendAccount_WrongMethod expects 405 for GET .../suspend.
func TestSupplyAPI_SuspendAccount_WrongMethod(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("GET", "/api/v1/supply/accounts/1/suspend", nil)
	w := httptest.NewRecorder()

	api.handleAccountActions(w, req)

	if w.Code != http.StatusMethodNotAllowed {
		t.Errorf("expected status 405, got %d", w.Code)
	}
}

// TestSupplyAPI_DeleteAccount_Success expects 204 for DELETE .../delete.
func TestSupplyAPI_DeleteAccount_Success(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("DELETE", "/api/v1/supply/accounts/1/delete", nil)
	w := httptest.NewRecorder()

	api.handleAccountActions(w, req)

	if w.Code != http.StatusNoContent {
		t.Errorf("expected status 204, got %d", w.Code)
	}
}

// TestSupplyAPI_DeleteAccount_Conflict expects 409 when deletion is refused.
func TestSupplyAPI_DeleteAccount_Conflict(t *testing.T) {
	api, accountSvc, _, _, _, _ := newTestAPI()
	accountSvc.deleteErr = errors.New("SUP_ACC_4092: cannot delete account with active packages")

	req := httptest.NewRequest("DELETE", "/api/v1/supply/accounts/1/delete", nil)
	w := httptest.NewRecorder()

	api.handleAccountActions(w, req)

	if w.Code != http.StatusConflict {
		t.Errorf("expected status 409, got %d", w.Code)
	}
}

// TestSupplyAPI_DeleteAccount_WrongMethod expects 405 for POST .../delete.
func TestSupplyAPI_DeleteAccount_WrongMethod(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("POST", "/api/v1/supply/accounts/1/delete", nil)
	w := httptest.NewRecorder()

	api.handleAccountActions(w, req)

	if w.Code != http.StatusMethodNotAllowed {
		t.Errorf("expected status 405, got %d", w.Code)
	}
}

// TestSupplyAPI_AccountAuditLogs_Success expects 200 with a one-element data
// array, then 500 once the audit store is made to fail.
func TestSupplyAPI_AccountAuditLogs_Success(t *testing.T) {
	api, _, _, _, _, auditSvc := newTestAPI()

	req := httptest.NewRequest("GET", "/api/v1/supply/accounts/1/audit-logs", nil)
	w := httptest.NewRecorder()

	api.handleAccountActions(w, req)

	if w.Code != http.StatusOK {
		t.Errorf("expected status 200, got %d", w.Code)
	}

	var resp map[string]any
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	data, ok := resp["data"].([]any)
	if !ok {
		t.Fatal("expected data array in response")
	}
	if len(data) != 1 {
		t.Errorf("expected 1 event, got %d", len(data))
	}

	// Same route again, now with a failing audit store.
	auditSvc.err = errors.New("query failed")
	req = httptest.NewRequest("GET", "/api/v1/supply/accounts/1/audit-logs", nil)
	w = httptest.NewRecorder()

	api.handleAccountActions(w, req)

	if w.Code != http.StatusInternalServerError {
		t.Errorf("expected status 500, got %d", w.Code)
	}
}

// TestSupplyAPI_AccountActions_InvalidID expects 400 for a non-numeric ID.
func TestSupplyAPI_AccountActions_InvalidID(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("POST", "/api/v1/supply/accounts/invalid/activate", nil)
	w := httptest.NewRecorder()

	api.handleAccountActions(w, req)

	if w.Code != http.StatusBadRequest {
		t.Errorf("expected status 400, got %d", w.Code)
	}
}

// TestSupplyAPI_AccountActions_UnknownRoute expects 404 for an unknown action.
func TestSupplyAPI_AccountActions_UnknownRoute(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("POST", "/api/v1/supply/accounts/1/unknown", nil)
	w := httptest.NewRecorder()

	api.handleAccountActions(w, req)

	if w.Code != http.StatusNotFound {
		t.Errorf("expected status 404, got %d", w.Code)
	}
}

// ==================== Package Handler Tests ====================

// TestSupplyAPI_CreatePackageDraft_Success expects 201 for a valid draft.
func TestSupplyAPI_CreatePackageDraft_Success(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	body := `{"supply_account_id":1,"model":"gpt-4","total_quota":10000,"price_per_1m_input":0.1,"price_per_1m_output":0.2,"valid_days":30,"max_concurrent":10,"rate_limit_rpm":1000}`
	req := httptest.NewRequest("POST", "/api/v1/supply/packages/draft", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()

	api.handleCreatePackageDraft(w, req)

	if w.Code != http.StatusCreated {
		t.Errorf("expected status 201, got %d", w.Code)
	}
}

// TestSupplyAPI_CreatePackageDraft_MethodNotAllowed expects 405 for GET.
func TestSupplyAPI_CreatePackageDraft_MethodNotAllowed(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("GET", "/api/v1/supply/packages/draft", nil)
	w := httptest.NewRecorder()

	api.handleCreatePackageDraft(w, req)

	if w.Code != http.StatusMethodNotAllowed {
		t.Errorf("expected status 405, got %d", w.Code)
	}
}

// TestSupplyAPI_PublishPackage_Success expects 200 for POST .../publish.
func TestSupplyAPI_PublishPackage_Success(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("POST", "/api/v1/supply/packages/1/publish", nil)
	w := httptest.NewRecorder()

	api.handlePackageActions(w, req)

	if w.Code != http.StatusOK {
		t.Errorf("expected status 200, got %d", w.Code)
	}
}

// TestSupplyAPI_PublishPackage_NotFound expects 404 on a "not found" error.
func TestSupplyAPI_PublishPackage_NotFound(t *testing.T) {
	api, _, packageSvc, _, _, _ := newTestAPI()
	packageSvc.publishErr = errors.New("package not found")

	req := httptest.NewRequest("POST", "/api/v1/supply/packages/1/publish", nil)
	w := httptest.NewRecorder()

	api.handlePackageActions(w, req)

	if w.Code != http.StatusNotFound {
		t.Errorf("expected status 404, got %d", w.Code)
	}
}

// TestSupplyAPI_PausePackage_Success expects 200 for POST .../pause.
func TestSupplyAPI_PausePackage_Success(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("POST", "/api/v1/supply/packages/1/pause", nil)
	w := httptest.NewRecorder()

	api.handlePackageActions(w, req)

	if w.Code != http.StatusOK {
		t.Errorf("expected status 200, got %d", w.Code)
	}
}

// TestSupplyAPI_PausePackage_Conflict expects 409 when pausing is refused.
func TestSupplyAPI_PausePackage_Conflict(t *testing.T) {
	api, _, packageSvc, _, _, _ := newTestAPI()
	packageSvc.pauseErr = errors.New("SUP_PKG_4092: cannot pause active package")

	req := httptest.NewRequest("POST", "/api/v1/supply/packages/1/pause", nil)
	w := httptest.NewRecorder()

	api.handlePackageActions(w, req)

	if w.Code != http.StatusConflict {
		t.Errorf("expected status 409, got %d", w.Code)
	}
}

// TestSupplyAPI_PausePackage_WrongMethod expects 405 for GET .../pause.
func TestSupplyAPI_PausePackage_WrongMethod(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("GET", "/api/v1/supply/packages/1/pause", nil)
	w := httptest.NewRecorder()

	api.handlePackageActions(w, req)

	if w.Code != http.StatusMethodNotAllowed {
		t.Errorf("expected status 405, got %d", w.Code)
	}
}

// TestSupplyAPI_UnlistPackage_Success expects 200 for POST .../unlist.
func TestSupplyAPI_UnlistPackage_Success(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("POST", "/api/v1/supply/packages/1/unlist", nil)
	w := httptest.NewRecorder()

	api.handlePackageActions(w, req)

	if w.Code != http.StatusOK {
		t.Errorf("expected status 200, got %d", w.Code)
	}
}

// TestSupplyAPI_UnlistPackage_Conflict expects 409 when unlisting is refused.
func TestSupplyAPI_UnlistPackage_Conflict(t *testing.T) {
	api, _, packageSvc, _, _, _ := newTestAPI()
	packageSvc.unlistErr = errors.New("SUP_PKG_4093: cannot unlist package")

	req := httptest.NewRequest("POST", "/api/v1/supply/packages/1/unlist", nil)
	w := httptest.NewRecorder()

	api.handlePackageActions(w, req)

	if w.Code != http.StatusConflict {
		t.Errorf("expected status 409, got %d", w.Code)
	}
}

// TestSupplyAPI_UnlistPackage_WrongMethod expects 405 for GET .../unlist.
func TestSupplyAPI_UnlistPackage_WrongMethod(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("GET", "/api/v1/supply/packages/1/unlist", nil)
	w := httptest.NewRecorder()

	api.handlePackageActions(w, req)

	if w.Code != http.StatusMethodNotAllowed {
		t.Errorf("expected status 405, got %d", w.Code)
	}
}

// TestSupplyAPI_ClonePackage_WrongMethod expects 405 for GET .../clone.
func TestSupplyAPI_ClonePackage_WrongMethod(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("GET", "/api/v1/supply/packages/1/clone", nil)
	w := httptest.NewRecorder()

	api.handlePackageActions(w, req)

	if w.Code != http.StatusMethodNotAllowed {
		t.Errorf("expected status 405, got %d", w.Code)
	}
}

// TestSupplyAPI_ClonePackage_NotFound expects 404 on a "not found" error.
func TestSupplyAPI_ClonePackage_NotFound(t *testing.T) {
	api, _, packageSvc, _, _, _ := newTestAPI()
	packageSvc.cloneErr = errors.New("package not found")

	req := httptest.NewRequest("POST", "/api/v1/supply/packages/1/clone", nil)
	w := httptest.NewRecorder()

	api.handlePackageActions(w, req)

	if w.Code != http.StatusNotFound {
		t.Errorf("expected status 404, got %d", w.Code)
	}
}

// TestSupplyAPI_PublishPackage_WrongMethod expects 405 for GET .../publish.
func TestSupplyAPI_PublishPackage_WrongMethod(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("GET", "/api/v1/supply/packages/1/publish", nil)
	w := httptest.NewRecorder()

	api.handlePackageActions(w, req)

	if w.Code != http.StatusMethodNotAllowed {
		t.Errorf("expected status 405, got %d", w.Code)
	}
}

// TestSupplyAPI_ClonePackage_Success expects 201 for POST .../clone.
func TestSupplyAPI_ClonePackage_Success(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("POST", "/api/v1/supply/packages/1/clone", nil)
	w := httptest.NewRecorder()

	api.handlePackageActions(w, req)

	if w.Code != http.StatusCreated {
		t.Errorf("expected status 201, got %d", w.Code)
	}
}

// TestSupplyAPI_BatchUpdatePrice_Success expects 200 for a valid batch.
func TestSupplyAPI_BatchUpdatePrice_Success(t *testing.T) {
	// Configure the mock through the handle returned by the helper, as the
	// other tests do, rather than type-asserting the API's private field.
	api, _, packageSvc, _, _, _ := newTestAPI()
	packageSvc.batchResp = &domain.BatchUpdatePriceResponse{
		Total:        2,
		SuccessCount: 2,
		FailedCount:  0,
	}

	body := `{"items":[{"package_id":1,"price_per_1m_input":0.15,"price_per_1m_output":0.25},{"package_id":2,"price_per_1m_input":0.12,"price_per_1m_output":0.22}]}`
	req := httptest.NewRequest("POST", "/api/v1/supply/packages/batch-price", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()

	api.handleBatchUpdatePrice(w, req)

	if w.Code != http.StatusOK {
		t.Errorf("expected status 200, got %d", w.Code)
	}
}

// TestSupplyAPI_BatchUpdatePrice_MethodNotAllowed expects 405 for GET.
func TestSupplyAPI_BatchUpdatePrice_MethodNotAllowed(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("GET", "/api/v1/supply/packages/batch-price", nil)
	w := httptest.NewRecorder()

	api.handleBatchUpdatePrice(w, req)

	if w.Code != http.StatusMethodNotAllowed {
		t.Errorf("expected status 405, got %d", w.Code)
	}
}

// TestSupplyAPI_BatchUpdatePrice_InvalidJSON expects 400 for a malformed body.
func TestSupplyAPI_BatchUpdatePrice_InvalidJSON(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	body := `{invalid}`
	req := httptest.NewRequest("POST", "/api/v1/supply/packages/batch-price", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()

	api.handleBatchUpdatePrice(w, req)

	if w.Code != http.StatusBadRequest {
		t.Errorf("expected status 400, got %d", w.Code)
	}
}

// TestSupplyAPI_BatchUpdatePrice_BatchFailed expects 422 when the service
// returns an error.
func TestSupplyAPI_BatchUpdatePrice_BatchFailed(t *testing.T) {
	api, _, packageSvc, _, _, _ := newTestAPI()
	packageSvc.batchErr = errors.New("batch update failed")

	body := `{"items":[{"package_id":1,"price_per_1m_input":0.15,"price_per_1m_output":0.25}]}`
	req := httptest.NewRequest("POST", "/api/v1/supply/packages/batch-price", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()

	api.handleBatchUpdatePrice(w, req)

	if w.Code != http.StatusUnprocessableEntity {
		t.Errorf("expected status 422, got %d", w.Code)
	}
}

// ==================== Billing Handler Tests ====================

// TestSupplyAPI_GetBilling_Success expects 200 for a dated billing query.
func TestSupplyAPI_GetBilling_Success(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("GET", "/api/v1/supply/billing?start_date=2024-01-01&end_date=2024-01-31", nil)
	w := httptest.NewRecorder()

	api.handleGetBilling(w, req)

	if w.Code != http.StatusOK {
		t.Errorf("expected status 200, got %d", w.Code)
	}
}

// TestSupplyAPI_GetBilling_MethodNotAllowed expects 405 for POST.
func TestSupplyAPI_GetBilling_MethodNotAllowed(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("POST", "/api/v1/supply/billing", nil)
	w := httptest.NewRecorder()

	api.handleGetBilling(w, req)

	if w.Code != http.StatusMethodNotAllowed {
		t.Errorf("expected status 405, got %d", w.Code)
	}
}

// TestSupplyAPI_GetBilling_QueryFailed expects 500 when the earning service's
// billing summary lookup fails.
func TestSupplyAPI_GetBilling_QueryFailed(t *testing.T) {
	api, _, _, _, earningSvc, _ := newTestAPI()
	earningSvc.billingErr = errors.New("query failed")

	req := httptest.NewRequest("GET", "/api/v1/supply/billing", nil)
	w := httptest.NewRecorder()

	api.handleGetBilling(w, req)

	if w.Code != http.StatusInternalServerError {
		t.Errorf("expected status 500, got %d", w.Code)
	}
}

// ==================== Settlement Handler Tests ====================

// TestSupplyAPI_Withdraw_Success expects 201 for a valid withdraw request.
func TestSupplyAPI_Withdraw_Success(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	body := `{"withdraw_amount":1000,"payment_method":"bank","payment_account":"1234567890","sms_code":"123456"}`
	req := httptest.NewRequest("POST", "/api/v1/supply/settlements/withdraw", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()

	api.handleWithdraw(w, req)

	if w.Code != http.StatusCreated {
		t.Errorf("expected status 201, got %d", w.Code)
	}
}

// TestHandleWithdraw_RequiresIdempotencyMiddleware expects 503 when the API
// was built with a nil idempotency middleware.
func TestHandleWithdraw_RequiresIdempotencyMiddleware(t *testing.T) {
	api, _, _, _, _, _ := newTestAPIWithoutIdempotencyForTest()

	body := `{"withdraw_amount":1000,"payment_method":"bank","payment_account":"1234567890","sms_code":"123456"}`
	req := httptest.NewRequest("POST", "/api/v1/supply/settlements/withdraw", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()

	api.handleWithdraw(w, req)

	if w.Code != http.StatusServiceUnavailable {
		t.Fatalf("expected 503 when idempotency middleware is missing, got=%d body=%s", w.Code, w.Body.String())
	}
}

// TestSupplyAPI_Withdraw_MethodNotAllowed expects 405 for GET.
func TestSupplyAPI_Withdraw_MethodNotAllowed(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("GET", "/api/v1/supply/settlements/withdraw", nil)
	w := httptest.NewRecorder()

	api.handleWithdraw(w, req)

	if w.Code != http.StatusMethodNotAllowed {
		t.Errorf("expected status 405, got %d", w.Code)
	}
}

// TestSupplyAPI_Withdraw_InvalidJSON expects 400 for a malformed body.
func TestSupplyAPI_Withdraw_InvalidJSON(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	body := `{invalid}`
	req := httptest.NewRequest("POST", "/api/v1/supply/settlements/withdraw", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()

	api.handleWithdraw(w, req)

	if w.Code != http.StatusBadRequest {
		t.Errorf("expected status 400, got %d", w.Code)
	}
}

// TestSupplyAPI_Withdraw_WithdrawFailed expects 409 when the service rejects
// the withdrawal.
func TestSupplyAPI_Withdraw_WithdrawFailed(t *testing.T) {
	api, _, _, settlementSvc, _, _ := newTestAPI()
	settlementSvc.withdrawErr = errors.New("SUP_SET_4001: insufficient balance")

	body := `{"withdraw_amount":1000000,"payment_method":"bank","payment_account":"1234567890","sms_code":"123456"}`
	req := httptest.NewRequest("POST", "/api/v1/supply/settlements/withdraw", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()

	api.handleWithdraw(w, req)

	if w.Code != http.StatusConflict {
		t.Errorf("expected status 409, got %d", w.Code)
	}
}

// TestSupplyAPI_CancelSettlement_Success expects 200 for POST .../cancel.
func TestSupplyAPI_CancelSettlement_Success(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("POST", "/api/v1/supply/settlements/1/cancel", nil)
	w := httptest.NewRecorder()

	api.handleSettlementActions(w, req)

	if w.Code != http.StatusOK {
		t.Errorf("expected status 200, got %d", w.Code)
	}
}

// TestSupplyAPI_CancelSettlement_NotFound expects 404 on a "not found" error.
func TestSupplyAPI_CancelSettlement_NotFound(t *testing.T) {
	api, _, _, settlementSvc, _, _ := newTestAPI()
	settlementSvc.cancelErr = errors.New("settlement not found")

	req := httptest.NewRequest("POST", "/api/v1/supply/settlements/1/cancel", nil)
	w := httptest.NewRecorder()

	api.handleSettlementActions(w, req)

	if w.Code != http.StatusNotFound {
		t.Errorf("expected status 404, got %d", w.Code)
	}
}

// TestSupplyAPI_GetStatement_Success expects 200 and a data object carrying
// file_name, download_url and expires_at.
func TestSupplyAPI_GetStatement_Success(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("GET", "/api/v1/supply/settlements/1/statement", nil)
	w := httptest.NewRecorder()

	api.handleSettlementActions(w, req)

	if w.Code != http.StatusOK {
		t.Errorf("expected status 200, got %d", w.Code)
	}

	var resp map[string]any
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	data, ok := resp["data"].(map[string]any)
	if !ok {
		t.Fatal("expected data in response")
	}
	if data["file_name"] == nil {
		t.Error("expected file_name in data")
	}
	if data["download_url"] == nil {
		t.Error("expected download_url in data")
	}
	if data["expires_at"] == nil {
		t.Error("expected expires_at in data")
	}
}

// TestSupplyAPI_GetStatement_NotFound expects 404 when the settlement lookup
// fails with a "not found" error.
func TestSupplyAPI_GetStatement_NotFound(t *testing.T) {
	api, _, _, settlementSvc, _, _ := newTestAPI()
	settlementSvc.getErr = errors.New("settlement not found")

	req := httptest.NewRequest("GET", "/api/v1/supply/settlements/1/statement", nil)
	w := httptest.NewRecorder()

	api.handleSettlementActions(w, req)

	if w.Code != http.StatusNotFound {
		t.Errorf("expected status 404, got %d", w.Code)
	}
}

// TestSupplyAPI_SettlementActions_InvalidID expects 400 for a non-numeric ID.
func TestSupplyAPI_SettlementActions_InvalidID(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("POST", "/api/v1/supply/settlements/invalid/cancel", nil)
	w := httptest.NewRecorder()

	api.handleSettlementActions(w, req)

	if w.Code != http.StatusBadRequest {
		t.Errorf("expected status 400, got %d", w.Code)
	}
}

// TestSupplyAPI_SettlementActions_UnknownAction expects 404 for an unknown
// action segment.
func TestSupplyAPI_SettlementActions_UnknownAction(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("POST", "/api/v1/supply/settlements/1/unknown", nil)
	w := httptest.NewRecorder()

	api.handleSettlementActions(w, req)

	if w.Code != http.StatusNotFound {
		t.Errorf("expected status 404, got %d", w.Code)
	}
}

// ==================== Earning Handler Tests ====================

// TestSupplyAPI_GetEarningRecords_Success expects 200, a one-element data
// array and pagination.total == 1.
func TestSupplyAPI_GetEarningRecords_Success(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("GET", "/api/v1/supply/earnings/records?start_date=2024-01-01&end_date=2024-01-31&page=1&page_size=20", nil)
	w := httptest.NewRecorder()

	api.handleGetEarningRecords(w, req)

	if w.Code != http.StatusOK {
		t.Errorf("expected status 200, got %d", w.Code)
	}

	var resp map[string]any
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	data, ok := resp["data"].([]any)
	if !ok {
		t.Fatal("expected data array in response")
	}
	if len(data) != 1 {
		t.Errorf("expected 1 record, got %d", len(data))
	}
	pagination, ok := resp["pagination"].(map[string]any)
	if !ok {
		t.Fatal("expected pagination in response")
	}
	// JSON numbers decode into float64 when the target is map[string]any.
	if pagination["total"] != float64(1) {
		t.Errorf("expected total 1, got %v", pagination["total"])
	}
}

// TestSupplyAPI_GetEarningRecords_MethodNotAllowed expects 405 for POST.
func TestSupplyAPI_GetEarningRecords_MethodNotAllowed(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("POST", "/api/v1/supply/earnings/records", nil)
	w := httptest.NewRecorder()

	api.handleGetEarningRecords(w, req)

	if w.Code != http.StatusMethodNotAllowed {
		t.Errorf("expected status 405, got %d", w.Code)
	}
}

// TestSupplyAPI_GetEarningRecords_QueryFailed expects 500 when listing fails.
func TestSupplyAPI_GetEarningRecords_QueryFailed(t *testing.T) {
	api, _, _, _, earningSvc, _ := newTestAPI()
	earningSvc.listErr = errors.New("query failed")

	req := httptest.NewRequest("GET", "/api/v1/supply/earnings/records", nil)
	w := httptest.NewRecorder()

	api.handleGetEarningRecords(w, req)

	if w.Code != http.StatusInternalServerError {
		t.Errorf("expected status 500, got %d", w.Code)
	}
}

// ==================== Audit Event Handler Tests ====================

// TestSupplyAPI_GetAuditEvent_Success expects 200 and the requested event_id
// in the data object.
func TestSupplyAPI_GetAuditEvent_Success(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("GET", "/api/v1/audit/events/evt_123", nil)
	w := httptest.NewRecorder()

	api.handleAuditEvent(w, req)

	if w.Code != http.StatusOK {
		t.Errorf("expected status 200, got %d", w.Code)
	}

	var resp map[string]any
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	data, ok := resp["data"].(map[string]any)
	if !ok {
		t.Fatal("expected data in response")
	}
	if data["event_id"] != "evt_123" {
		t.Errorf("expected event_id evt_123, got %v", data["event_id"])
	}
}

// TestSupplyAPI_GetAuditEvent_NotFound expects 404 when the store errors.
func TestSupplyAPI_GetAuditEvent_NotFound(t *testing.T) {
	api, _, _, _, _, auditSvc := newTestAPI()
	auditSvc.err = errors.New("not found")

	req := httptest.NewRequest("GET", "/api/v1/audit/events/evt_999", nil)
	w := httptest.NewRecorder()

	api.handleAuditEvent(w, req)

	if w.Code != http.StatusNotFound {
		t.Errorf("expected status 404, got %d", w.Code)
	}
}

// TestSupplyAPI_GetAuditEvent_MissingID expects 400 when the path has no ID.
func TestSupplyAPI_GetAuditEvent_MissingID(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("GET", "/api/v1/audit/events/", nil)
	w := httptest.NewRecorder()

	api.handleAuditEvent(w, req)

	if w.Code != http.StatusBadRequest {
		t.Errorf("expected status 400, got %d", w.Code)
	}
}

// TestSupplyAPI_GetAuditEvent_MethodNotAllowed expects 405 for POST.
func TestSupplyAPI_GetAuditEvent_MethodNotAllowed(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()

	req := httptest.NewRequest("POST", "/api/v1/audit/events/evt_123", nil)
	w := httptest.NewRecorder()

	api.handleAuditEvent(w, req)

	if w.Code != http.StatusMethodNotAllowed {
		t.Errorf("expected status 405, got %d", w.Code)
	}
}

// ==================== Helper Function Tests ====================

// TestGetRequestID checks that getRequestID reads the request-ID header under
// either spelling and yields "" when the header is absent.
func TestGetRequestID(t *testing.T) {
	req := httptest.NewRequest("GET", "/", nil)
	req.Header.Set("X-Request-Id", "req-123")
	id := getRequestID(req)
	if id != "req-123" {
		t.Errorf("expected req-123, got %s", id)
	}

	req = httptest.NewRequest("GET", "/", nil)
	req.Header.Set("X-Request-ID", "req-456")
	id = getRequestID(req)
	if id != "req-456" {
		t.Errorf("expected req-456, got %s", id)
	}

	req = httptest.NewRequest("GET", "/", nil)
	id = getRequestID(req)
	if id != "" {
		t.Errorf("expected empty string, got %s", id)
	}
}

// TestGetQueryInt checks parsed values and the fallback to the default for
// both a missing and a non-numeric query parameter.
func TestGetQueryInt(t *testing.T) {
	// "invalid" carries a non-numeric value so the last assertion exercises
	// the parse-failure path rather than repeating the missing-parameter case.
	req := httptest.NewRequest("GET", "/?page=5&page_size=100&invalid=abc", nil)

	if getQueryInt(req, "page", 1) != 5 {
		t.Error("expected page 5")
	}
	if getQueryInt(req, "page_size", 20) != 100 {
		t.Error("expected page_size 100")
	}
	if getQueryInt(req, "missing", 10) != 10 {
		t.Error("expected default 10 for missing param")
	}
	if getQueryInt(req, "invalid", 1) != 1 {
		t.Error("expected default 1 for invalid value")
	}
}

// TestWriteJSON checks the status code and Content-Type set by writeJSON.
func TestWriteJSON(t *testing.T) {
	w := httptest.NewRecorder()
	writeJSON(w, http.StatusOK, map[string]any{"key": "value"})

	if w.Code != http.StatusOK {
		t.Errorf("expected status 200, got %d", w.Code)
	}
	if w.Header().Get("Content-Type") != "application/json" {
		t.Error("expected Content-Type application/json")
	}
}

// TestWriteError checks the status code and the error envelope (code and
// message) produced by writeError.
func TestWriteError(t *testing.T) {
	w := httptest.NewRecorder()
	writeError(w, http.StatusBadRequest, "TEST_ERROR", "test message")

	if w.Code != http.StatusBadRequest {
		t.Errorf("expected status 400, got %d", w.Code)
	}

	var resp map[string]any
	// Fail here on a decode error instead of on the less specific
	// "expected error object" check below.
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	errObj, ok := resp["error"].(map[string]any)
	if !ok {
		t.Fatal("expected error object in response")
	}
	if errObj["code"] != "TEST_ERROR" {
		t.Errorf("expected code TEST_ERROR, got %v", errObj["code"])
	}
	if errObj["message"] != "test message" {
		t.Errorf("expected message 'test message', got %v", errObj["message"])
	}
}

// ==================== Integration Tests ====================

// TestSupplyAPI_Register verifies that registering the routes on a fresh mux
// completes without panicking.
func TestSupplyAPI_Register(t *testing.T) {
	api, _, _, _, _, _ := newTestAPI()
	mux := http.NewServeMux()
	api.Register(mux)
}

func TestSupplyAPI_EndToEnd_Withdraw(t *testing.T) { api, _, _, settlementSvc, _, _ := newTestAPI() settlementSvc.settlement = &domain.Settlement{ ID: 1, SupplierID: 100, SettlementNo: "SET_20240101_001", Status: domain.SettlementStatusPending, TotalAmount: 1000, NetAmount: 950, CreatedAt: time.Now(), UpdatedAt: time.Now(), } body := `{"withdraw_amount":500,"payment_method":"bank","payment_account":"1234567890","sms_code":"123456"}` req := httptest.NewRequest("POST", "/api/v1/supply/settlements/withdraw", strings.NewReader(body)) req.Header.Set("Content-Type", "application/json") req.Header.Set("X-Request-Id", "test-req-001") w := httptest.NewRecorder() api.handleWithdraw(w, req) if w.Code != http.StatusCreated { t.Errorf("expected status 201, got %d. 
Body: %s", w.Code, w.Body.String()) } var resp map[string]any if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { t.Fatalf("failed to unmarshal response: %v", err) } if resp["request_id"] != "test-req-001" { t.Errorf("expected request_id test-req-001, got %v", resp["request_id"]) } data, ok := resp["data"].(map[string]any) if !ok { t.Fatal("expected data in response") } if data["settlement_id"] != float64(1) { t.Errorf("expected settlement_id 1, got %v", data["settlement_id"]) } if data["status"] != "pending" { t.Errorf("expected status pending, got %v", data["status"]) } } func TestSupplyAPI_WithdrawDisabled_ReturnsServiceUnavailable(t *testing.T) { api, _, _, _, _, _ := newTestAPI() api.withdrawEnabled = false body := `{"withdraw_amount":500,"payment_method":"bank","payment_account":"1234567890","sms_code":"123456"}` req := httptest.NewRequest("POST", "/api/v1/supply/settlements/withdraw", strings.NewReader(body)) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() api.handleWithdraw(w, req) if w.Code != http.StatusServiceUnavailable { t.Fatalf("expected status 503, got %d body=%s", w.Code, w.Body.String()) } } func TestSupplyAPI_EndToEnd_BillingSummary(t *testing.T) { api, _, _, _, earningSvc, _ := newTestAPI() earningSvc.billingSummary = &domain.BillingSummary{ Period: domain.BillingPeriod{ Start: "2024-01-01", End: "2024-01-31", }, Summary: domain.BillingTotal{ TotalRevenue: 10000, TotalOrders: 100, TotalUsage: 1000000, TotalRequests: 5000000, AvgSuccessRate: 99.5, }, } req := httptest.NewRequest("GET", "/api/v1/supply/billing?start_date=2024-01-01&end_date=2024-01-31", nil) w := httptest.NewRecorder() api.handleGetBilling(w, req) if w.Code != http.StatusOK { t.Errorf("expected status 200, got %d", w.Code) } var resp map[string]any if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { t.Fatalf("failed to unmarshal response: %v", err) } data, ok := resp["data"].(map[string]any) if !ok { t.Fatal("expected data in 
response") } summary, ok := data["summary"].(map[string]any) if !ok { t.Fatal("expected summary in data") } if summary["total_revenue"] != float64(10000) { t.Errorf("expected total_revenue 10000, got %v", summary["total_revenue"]) } }