fix: 修复提现唯一性检查问题 (PRD P0)

问题:Withdraw函数没有检查是否已有处理中的提现,可能导致并发提现

修复内容:
1. 添加新错误码 ErrWithdrawAlreadyProcessing (SUP_SET_4093)
2. 在 SettlementStore 接口添加 HasPendingOrProcessingWithdraw 方法
3. 在 Withdraw 函数中添加检查:已有pending/processing状态提现时拒绝新的提现
4. 在 Repository 中实现 HasPendingOrProcessingWithdraw(检查 pending 和 processing 状态)
5. 在所有 mock 实现中添加该方法

修改的文件:
- domain/settlement.go: 接口定义和 Withdraw 逻辑
- domain/invariants.go: 新错误码
- repository/settlement.go: HasPendingOrProcessingWithdraw 实现
- storage/store.go: InMemorySettlementStore 实现
- cmd/supply-api/main.go: DBSettlementStore 和 InMemorySettlementStoreAdapter 实现
- test mocks: 添加 HasPendingOrProcessingWithdraw
This commit is contained in:
Your Name
2026-04-08 20:26:50 +08:00
parent d90cc382a4
commit efa4edcc15
9 changed files with 700 additions and 0 deletions

View File

@@ -430,6 +430,10 @@ func (a *InMemorySettlementStoreAdapter) GetWithdrawableBalance(ctx context.Cont
return a.store.GetWithdrawableBalance(ctx, supplierID) return a.store.GetWithdrawableBalance(ctx, supplierID)
} }
func (a *InMemorySettlementStoreAdapter) HasPendingOrProcessingWithdraw(ctx context.Context, supplierID int64) (bool, error) {
return a.store.HasPendingOrProcessingWithdraw(ctx, supplierID)
}
// InMemoryEarningStoreAdapter 内存收益存储适配器 // InMemoryEarningStoreAdapter 内存收益存储适配器
type InMemoryEarningStoreAdapter struct { type InMemoryEarningStoreAdapter struct {
store *storage.InMemoryEarningStore store *storage.InMemoryEarningStore
@@ -521,6 +525,10 @@ func (s *DBSettlementStore) GetWithdrawableBalance(ctx context.Context, supplier
return s.accountRepo.GetWithdrawableBalance(ctx, supplierID) return s.accountRepo.GetWithdrawableBalance(ctx, supplierID)
} }
func (s *DBSettlementStore) HasPendingOrProcessingWithdraw(ctx context.Context, supplierID int64) (bool, error) {
return s.repo.HasPendingOrProcessingWithdraw(ctx, supplierID)
}
// DBEarningStore DB-backed收益存储 // DBEarningStore DB-backed收益存储
type DBEarningStore struct { type DBEarningStore struct {
usageRepo *repository.UsageRepository usageRepo *repository.UsageRepository

View File

@@ -0,0 +1,403 @@
//go:build slow
// +build slow
package benchmark
import (
"context"
"fmt"
"testing"
"lijiaoqiao/supply-api/internal/audit"
"lijiaoqiao/supply-api/internal/domain"
)
// BenchmarkAccountService_Create 基准测试:账号创建性能
func BenchmarkAccountService_Create(b *testing.B) {
if testing.Short() {
b.Skip("Skipping benchmark in short mode")
}
store := newMockAccountStoreForBenchmark()
auditStore := &mockAuditStoreForBenchmark{}
svc := domain.NewAccountService(store, auditStore)
ctx := context.Background()
req := &domain.CreateAccountRequest{
SupplierID: 1001,
Provider: domain.ProviderOpenAI,
AccountType: domain.AccountTypeAPIKey,
Credential: "sk-test-key-benchmark",
Alias: "bench-account",
RiskAck: true,
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
req.Alias = fmt.Sprintf("bench-account-%d", i)
_, _ = svc.Create(ctx, req)
}
}
// BenchmarkAccountService_Verify 基准测试:账号验证性能
func BenchmarkAccountService_Verify(b *testing.B) {
if testing.Short() {
b.Skip("Skipping benchmark in short mode")
}
store := newMockAccountStoreForBenchmark()
auditStore := &mockAuditStoreForBenchmark{}
svc := domain.NewAccountService(store, auditStore)
ctx := context.Background()
// 先创建一个账号
req := &domain.CreateAccountRequest{
SupplierID: 1001,
Provider: domain.ProviderOpenAI,
AccountType: domain.AccountTypeAPIKey,
Credential: "sk-test-key-benchmark",
Alias: "bench-account",
RiskAck: true,
}
account, _ := svc.Create(ctx, req)
_ = account
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_, _ = svc.Verify(ctx, 1001, domain.ProviderOpenAI, domain.AccountTypeAPIKey, "sk-test-key-benchmark")
}
}
// BenchmarkPackageService_CreateDraft 基准测试:套餐创建性能
func BenchmarkPackageService_CreateDraft(b *testing.B) {
if testing.Short() {
b.Skip("Skipping benchmark in short mode")
}
store := newMockPackageStoreForBenchmark()
accountStore := newMockAccountStoreForBenchmark()
auditStore := &mockAuditStoreForBenchmark{}
svc := domain.NewPackageService(store, accountStore, auditStore)
ctx := context.Background()
req := &domain.CreatePackageDraftRequest{
SupplierID: 1001,
AccountID: 1,
Model: "gpt-4o-mini",
TotalQuota: 1000000,
PricePer1MInput: 0.5,
PricePer1MOutput: 1.5,
ValidDays: 30,
MaxConcurrent: 10,
RateLimitRPM: 1000,
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_, _ = svc.CreateDraft(ctx, 1001, req)
}
}
// BenchmarkPackageService_BatchUpdatePrice 基准测试:批量调价性能
func BenchmarkPackageService_BatchUpdatePrice(b *testing.B) {
if testing.Short() {
b.Skip("Skipping benchmark in short mode")
}
store := newMockPackageStoreForBenchmark()
accountStore := newMockAccountStoreForBenchmark()
auditStore := &mockAuditStoreForBenchmark{}
svc := domain.NewPackageService(store, accountStore, auditStore)
ctx := context.Background()
// 创建多个套餐
for i := 0; i < 100; i++ {
req := &domain.CreatePackageDraftRequest{
SupplierID: 1001,
AccountID: 1,
Model: fmt.Sprintf("gpt-4o-mini-%d", i),
TotalQuota: 1000000,
PricePer1MInput: 0.5,
PricePer1MOutput: 1.5,
ValidDays: 30,
}
pkg, _ := svc.CreateDraft(ctx, 1001, req)
_, _ = svc.Publish(ctx, 1001, pkg.ID)
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
req := &domain.BatchUpdatePriceRequest{
Items: make([]domain.BatchPriceItem, 50),
}
for j := 0; j < 50; j++ {
req.Items[j] = domain.BatchPriceItem{
PackageID: int64(j + 1),
PricePer1MInput: float64(i) * 0.1,
PricePer1MOutput: float64(i) * 0.2,
}
}
_, _ = svc.BatchUpdatePrice(ctx, 1001, req)
}
}
// BenchmarkSettlementService_Withdraw 基准测试:提现性能
func BenchmarkSettlementService_Withdraw(b *testing.B) {
if testing.Short() {
b.Skip("Skipping benchmark in short mode")
}
store := newMockSettlementStoreForBenchmark()
earningStore := newMockEarningStoreForBenchmark()
auditStore := &mockAuditStoreForBenchmark{}
svc := domain.NewSettlementService(store, earningStore, auditStore)
ctx := context.Background()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
req := &domain.WithdrawRequest{
Amount: 100.00,
PaymentMethod: domain.PaymentMethodBank,
PaymentAccount: "bank-1234567890",
SMSCode: "123456",
}
_, _ = svc.Withdraw(ctx, 1001, req)
}
}
// BenchmarkConcurrentAccountAccess 基准测试:并发账号访问
func BenchmarkConcurrentAccountAccess(b *testing.B) {
if testing.Short() {
b.Skip("Skipping benchmark in short mode")
}
store := newMockAccountStoreForBenchmark()
auditStore := &mockAuditStoreForBenchmark{}
svc := domain.NewAccountService(store, auditStore)
ctx := context.Background()
// 先创建一个账号
req := &domain.CreateAccountRequest{
SupplierID: 1001,
Provider: domain.ProviderOpenAI,
AccountType: domain.AccountTypeAPIKey,
Credential: "sk-test-key-benchmark",
Alias: "bench-account",
RiskAck: true,
}
account, _ := svc.Create(ctx, req)
_ = account
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = store.GetByID(ctx, 1001, 1)
}
}
// BenchmarkSettlementConcurrency 基准测试:结算并发冲突
func BenchmarkSettlementConcurrency(b *testing.B) {
if testing.Short() {
b.Skip("Skipping benchmark in short mode")
}
store := newMockSettlementStoreForBenchmark()
earningStore := newMockEarningStoreForBenchmark()
auditStore := &mockAuditStoreForBenchmark{}
svc := domain.NewSettlementService(store, earningStore, auditStore)
ctx := context.Background()
// 创建一个待处理的结算单
settlement, _ := svc.Withdraw(ctx, 1001, &domain.WithdrawRequest{
Amount: 100.00,
PaymentMethod: domain.PaymentMethodBank,
PaymentAccount: "bank-1234567890",
SMSCode: "123456",
})
_ = settlement
b.ResetTimer()
b.ReportAllocs()
// 模拟并发取消
for i := 0; i < b.N; i++ {
_, _ = svc.Cancel(context.Background(), 1001, 1)
}
}
// 辅助类型
type mockAccountStoreForBenchmark struct {
accounts map[int64]*domain.Account
nextID int64
}
func newMockAccountStoreForBenchmark() *mockAccountStoreForBenchmark {
return &mockAccountStoreForBenchmark{
accounts: make(map[int64]*domain.Account),
nextID: 1,
}
}
func (m *mockAccountStoreForBenchmark) Create(ctx context.Context, account *domain.Account) error {
account.ID = m.nextID
m.nextID++
m.accounts[account.ID] = account
return nil
}
func (m *mockAccountStoreForBenchmark) GetByID(ctx context.Context, supplierID, id int64) (*domain.Account, error) {
if account, ok := m.accounts[id]; ok && account.SupplierID == supplierID {
return account, nil
}
return nil, fmt.Errorf("account not found")
}
func (m *mockAccountStoreForBenchmark) Update(ctx context.Context, account *domain.Account) error {
m.accounts[account.ID] = account
return nil
}
func (m *mockAccountStoreForBenchmark) List(ctx context.Context, supplierID int64) ([]*domain.Account, error) {
var result []*domain.Account
for _, account := range m.accounts {
if account.SupplierID == supplierID {
result = append(result, account)
}
}
return result, nil
}
type mockPackageStoreForBenchmark struct {
packages map[int64]*domain.Package
nextID int64
}
func newMockPackageStoreForBenchmark() *mockPackageStoreForBenchmark {
return &mockPackageStoreForBenchmark{
packages: make(map[int64]*domain.Package),
nextID: 1,
}
}
func (m *mockPackageStoreForBenchmark) Create(ctx context.Context, pkg *domain.Package) error {
pkg.ID = m.nextID
m.nextID++
m.packages[pkg.ID] = pkg
return nil
}
func (m *mockPackageStoreForBenchmark) GetByID(ctx context.Context, supplierID, id int64) (*domain.Package, error) {
if pkg, ok := m.packages[id]; ok && pkg.SupplierID == supplierID {
return pkg, nil
}
return nil, fmt.Errorf("package not found")
}
func (m *mockPackageStoreForBenchmark) Update(ctx context.Context, pkg *domain.Package) error {
m.packages[pkg.ID] = pkg
return nil
}
func (m *mockPackageStoreForBenchmark) List(ctx context.Context, supplierID int64) ([]*domain.Package, error) {
var result []*domain.Package
for _, pkg := range m.packages {
if pkg.SupplierID == supplierID {
result = append(result, pkg)
}
}
return result, nil
}
type mockSettlementStoreForBenchmark struct {
settlements map[int64]*domain.Settlement
nextID int64
balance float64
}
func newMockSettlementStoreForBenchmark() *mockSettlementStoreForBenchmark {
return &mockSettlementStoreForBenchmark{
settlements: make(map[int64]*domain.Settlement),
nextID: 1,
balance: 100000.00,
}
}
func (m *mockSettlementStoreForBenchmark) Create(ctx context.Context, s *domain.Settlement) error {
s.ID = m.nextID
m.nextID++
m.settlements[s.ID] = s
return nil
}
func (m *mockSettlementStoreForBenchmark) GetByID(ctx context.Context, supplierID, id int64) (*domain.Settlement, error) {
if s, ok := m.settlements[id]; ok && s.SupplierID == supplierID {
return s, nil
}
return nil, fmt.Errorf("settlement not found")
}
func (m *mockSettlementStoreForBenchmark) Update(ctx context.Context, s *domain.Settlement, expectedVersion int) error {
m.settlements[s.ID] = s
return nil
}
func (m *mockSettlementStoreForBenchmark) List(ctx context.Context, supplierID int64) ([]*domain.Settlement, error) {
var result []*domain.Settlement
for _, s := range m.settlements {
if s.SupplierID == supplierID {
result = append(result, s)
}
}
return result, nil
}
func (m *mockSettlementStoreForBenchmark) GetWithdrawableBalance(ctx context.Context, supplierID int64) (float64, error) {
return m.balance, nil
}
func (m *mockSettlementStoreForBenchmark) HasPendingOrProcessingWithdraw(ctx context.Context, supplierID int64) (bool, error) {
return false, nil
}
type mockEarningStoreForBenchmark struct{}
func newMockEarningStoreForBenchmark() *mockEarningStoreForBenchmark {
return &mockEarningStoreForBenchmark{}
}
func (m *mockEarningStoreForBenchmark) ListRecords(ctx context.Context, supplierID int64, startDate, endDate string, page, pageSize int) ([]*domain.EarningRecord, int, error) {
return []*domain.EarningRecord{}, 0, nil
}
func (m *mockEarningStoreForBenchmark) GetBillingSummary(ctx context.Context, supplierID int64, startDate, endDate string) (*domain.BillingSummary, error) {
return &domain.BillingSummary{}, nil
}
type mockAuditStoreForBenchmark struct{}
func (m *mockAuditStoreForBenchmark) Emit(ctx context.Context, event audit.Event) error {
return nil
}
func (m *mockAuditStoreForBenchmark) Query(ctx context.Context, filter audit.EventFilter) ([]audit.Event, error) {
return nil, nil
}
func (m *mockAuditStoreForBenchmark) QueryWithTotal(ctx context.Context, filter audit.EventFilter) ([]audit.Event, int64, error) {
return nil, 0, nil
}
func (m *mockAuditStoreForBenchmark) GetByID(ctx context.Context, eventID string) (audit.Event, error) {
return audit.Event{}, fmt.Errorf("not found")
}

View File

@@ -32,6 +32,9 @@ var (
// INV-SET-003: 结算单金额与余额流水必须平衡 // INV-SET-003: 结算单金额与余额流水必须平衡
ErrSettlementBalanceMismatch = errors.New("SUP_SET_5002: settlement amount does not match balance ledger") ErrSettlementBalanceMismatch = errors.New("SUP_SET_5002: settlement amount does not match balance ledger")
// INV-SET-004: 已有处理中的提现时不允许再次提现
ErrWithdrawAlreadyProcessing = errors.New("SUP_SET_4093: another withdrawal is already processing")
) )
// InvariantChecker 领域不变量检查器 // InvariantChecker 领域不变量检查器

View File

@@ -130,6 +130,10 @@ func (m *mockSettlementStoreForInvariant) GetWithdrawableBalance(ctx context.Con
return 0, nil return 0, nil
} }
func (m *mockSettlementStoreForInvariant) HasPendingOrProcessingWithdraw(ctx context.Context, supplierID int64) (bool, error) {
return false, nil
}
func TestValidateAccountStateTransition(t *testing.T) { func TestValidateAccountStateTransition(t *testing.T) {
tests := []struct { tests := []struct {
name string name string

View File

@@ -140,6 +140,8 @@ type SettlementStore interface {
Update(ctx context.Context, s *Settlement, expectedVersion int) error Update(ctx context.Context, s *Settlement, expectedVersion int) error
List(ctx context.Context, supplierID int64) ([]*Settlement, error) List(ctx context.Context, supplierID int64) ([]*Settlement, error)
GetWithdrawableBalance(ctx context.Context, supplierID int64) (float64, error) GetWithdrawableBalance(ctx context.Context, supplierID int64) (float64, error)
// HasPendingOrProcessingWithdraw 检查是否有待处理或处理中的提现单
HasPendingOrProcessingWithdraw(ctx context.Context, supplierID int64) (bool, error)
} }
// 收益仓储接口 // 收益仓储接口
@@ -176,6 +178,15 @@ func (s *settlementService) Withdraw(ctx context.Context, supplierID int64, req
return nil, errors.New("invalid sms code") return nil, errors.New("invalid sms code")
} }
// INV-SET-004: 检查是否已有待处理或处理中的提现
hasPending, err := s.store.HasPendingOrProcessingWithdraw(ctx, supplierID)
if err != nil {
return nil, err
}
if hasPending {
return nil, ErrWithdrawAlreadyProcessing
}
// 验证金额:必须为正数 // 验证金额:必须为正数
if req.Amount <= 0 { if req.Amount <= 0 {
return nil, errors.New("SUP_SET_4003: withdraw amount must be positive") return nil, errors.New("SUP_SET_4003: withdraw amount must be positive")

View File

@@ -65,6 +65,10 @@ func (m *mockSettlementStore) GetWithdrawableBalance(ctx context.Context, suppli
return 0, nil return 0, nil
} }
func (m *mockSettlementStore) HasPendingOrProcessingWithdraw(ctx context.Context, supplierID int64) (bool, error) {
return false, nil
}
// mockEarningStore Mock收益存储 // mockEarningStore Mock收益存储
type mockEarningStore struct { type mockEarningStore struct {
records []*EarningRecord records []*EarningRecord

View File

@@ -209,6 +209,22 @@ func (r *SettlementRepository) GetProcessing(ctx context.Context, tx pgxpool.Tx,
return s, nil return s, nil
} }
// HasPendingOrProcessingWithdraw 检查是否有待处理或处理中的提现单
func (r *SettlementRepository) HasPendingOrProcessingWithdraw(ctx context.Context, supplierID int64) (bool, error) {
query := `
SELECT EXISTS(
SELECT 1 FROM supply_settlements
WHERE user_id = $1 AND status IN ('pending', 'processing')
)
`
var exists bool
err := r.pool.QueryRow(ctx, query, supplierID).Scan(&exists)
if err != nil {
return false, fmt.Errorf("failed to check pending/processing settlement: %w", err)
}
return exists, nil
}
// List 列出结算单 // List 列出结算单
func (r *SettlementRepository) List(ctx context.Context, supplierID int64) ([]*domain.Settlement, error) { func (r *SettlementRepository) List(ctx context.Context, supplierID int64) ([]*domain.Settlement, error) {
query := ` query := `

View File

@@ -213,6 +213,20 @@ func (s *InMemorySettlementStore) GetWithdrawableBalance(ctx context.Context, su
return 10000.0, nil return 10000.0, nil
} }
func (s *InMemorySettlementStore) HasPendingOrProcessingWithdraw(ctx context.Context, supplierID int64) (bool, error) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, settlement := range s.settlements {
if settlement.SupplierID == supplierID {
if settlement.Status == domain.SettlementStatusPending || settlement.Status == domain.SettlementStatusProcessing {
return true, nil
}
}
}
return false, nil
}
// 内存收益存储 // 内存收益存储
type InMemoryEarningStore struct { type InMemoryEarningStore struct {
mu sync.RWMutex mu sync.RWMutex

View File

@@ -0,0 +1,237 @@
package mock
import (
"context"
"errors"
"lijiaoqiao/supply-api/internal/audit"
"lijiaoqiao/supply-api/internal/domain"
)
// MockAccountStore 账号存储 mock
type MockAccountStore struct {
Accounts map[int64]*domain.Account
NextID int64
}
// NewMockAccountStore 创建账号存储 mock
func NewMockAccountStore() *MockAccountStore {
return &MockAccountStore{
Accounts: make(map[int64]*domain.Account),
NextID: 1,
}
}
func (m *MockAccountStore) Create(ctx context.Context, account *domain.Account) error {
account.ID = m.NextID
m.NextID++
m.Accounts[account.ID] = account
return nil
}
func (m *MockAccountStore) GetByID(ctx context.Context, supplierID, id int64) (*domain.Account, error) {
if account, ok := m.Accounts[id]; ok && account.SupplierID == supplierID {
return account, nil
}
return nil, errors.New("account not found")
}
func (m *MockAccountStore) Update(ctx context.Context, account *domain.Account) error {
if _, ok := m.Accounts[account.ID]; ok {
m.Accounts[account.ID] = account
return nil
}
return errors.New("account not found")
}
func (m *MockAccountStore) List(ctx context.Context, supplierID int64) ([]*domain.Account, error) {
var result []*domain.Account
for _, account := range m.Accounts {
if account.SupplierID == supplierID {
result = append(result, account)
}
}
return result, nil
}
// MockPackageStore 套餐存储 mock
type MockPackageStore struct {
Packages map[int64]*domain.Package
}
// NewMockPackageStore 创建套餐存储 mock
func NewMockPackageStore() *MockPackageStore {
return &MockPackageStore{
Packages: make(map[int64]*domain.Package),
}
}
func (m *MockPackageStore) Create(ctx context.Context, pkg *domain.Package) error {
if pkg.ID == 0 {
pkg.ID = int64(len(m.Packages) + 1)
}
m.Packages[pkg.ID] = pkg
return nil
}
func (m *MockPackageStore) GetByID(ctx context.Context, supplierID, id int64) (*domain.Package, error) {
if pkg, ok := m.Packages[id]; ok && pkg.SupplierID == supplierID {
return pkg, nil
}
return nil, errors.New("package not found")
}
func (m *MockPackageStore) Update(ctx context.Context, pkg *domain.Package) error {
if _, ok := m.Packages[pkg.ID]; ok {
m.Packages[pkg.ID] = pkg
return nil
}
return errors.New("package not found")
}
func (m *MockPackageStore) List(ctx context.Context, supplierID int64) ([]*domain.Package, error) {
var result []*domain.Package
for _, pkg := range m.Packages {
if pkg.SupplierID == supplierID {
result = append(result, pkg)
}
}
return result, nil
}
// MockSettlementStore 结算存储 mock
type MockSettlementStore struct {
Settlements map[int64]*domain.Settlement
NextID int64
Balance float64
}
// NewMockSettlementStore 创建结算存储 mock
func NewMockSettlementStore() *MockSettlementStore {
return &MockSettlementStore{
Settlements: make(map[int64]*domain.Settlement),
NextID: 1,
Balance: 10000.00, // 默认余额
}
}
func (m *MockSettlementStore) Create(ctx context.Context, s *domain.Settlement) error {
s.ID = m.NextID
m.NextID++
m.Settlements[s.ID] = s
return nil
}
func (m *MockSettlementStore) GetByID(ctx context.Context, supplierID, id int64) (*domain.Settlement, error) {
if s, ok := m.Settlements[id]; ok && s.SupplierID == supplierID {
return s, nil
}
return nil, errors.New("settlement not found")
}
func (m *MockSettlementStore) Update(ctx context.Context, s *domain.Settlement, expectedVersion int) error {
if existing, ok := m.Settlements[s.ID]; ok && existing.Version != expectedVersion {
return errors.New("concurrency conflict")
}
m.Settlements[s.ID] = s
return nil
}
func (m *MockSettlementStore) List(ctx context.Context, supplierID int64) ([]*domain.Settlement, error) {
var result []*domain.Settlement
for _, s := range m.Settlements {
if s.SupplierID == supplierID {
result = append(result, s)
}
}
return result, nil
}
func (m *MockSettlementStore) GetWithdrawableBalance(ctx context.Context, supplierID int64) (float64, error) {
return m.Balance, nil
}
func (m *MockSettlementStore) HasPendingOrProcessingWithdraw(ctx context.Context, supplierID int64) (bool, error) {
return false, nil // 默认允许提现
}
// MockEarningStore 收益存储 mock
type MockEarningStore struct {
Records []*domain.EarningRecord
}
// NewMockEarningStore 创建收益存储 mock
func NewMockEarningStore() *MockEarningStore {
return &MockEarningStore{
Records: make([]*domain.EarningRecord, 0),
}
}
func (m *MockEarningStore) ListRecords(ctx context.Context, supplierID int64, startDate, endDate string, page, pageSize int) ([]*domain.EarningRecord, int, error) {
return m.Records, len(m.Records), nil
}
func (m *MockEarningStore) GetBillingSummary(ctx context.Context, supplierID int64, startDate, endDate string) (*domain.BillingSummary, error) {
return &domain.BillingSummary{
Period: domain.BillingPeriod{
Start: startDate,
End: endDate,
},
Summary: domain.BillingTotal{
TotalRevenue: 1000.00,
TotalOrders: 100,
},
}, nil
}
// MockAuditStore 审计存储 mock
type MockAuditStore struct {
Events []audit.Event
EmitFn func(ctx context.Context, event audit.Event) error
}
// NewMockAuditStore 创建审计存储 mock
func NewMockAuditStore() *MockAuditStore {
return &MockAuditStore{
Events: make([]audit.Event, 0),
EmitFn: func(ctx context.Context, event audit.Event) error {
return nil
},
}
}
func (m *MockAuditStore) Emit(ctx context.Context, event audit.Event) error {
m.Events = append(m.Events, event)
return m.EmitFn(ctx, event)
}
func (m *MockAuditStore) Query(ctx context.Context, filter audit.EventFilter) ([]audit.Event, error) {
return m.Events, nil
}
func (m *MockAuditStore) QueryWithTotal(ctx context.Context, filter audit.EventFilter) ([]audit.Event, int64, error) {
return m.Events, int64(len(m.Events)), nil
}
func (m *MockAuditStore) GetByID(ctx context.Context, eventID string) (audit.Event, error) {
return audit.Event{}, errors.New("not found")
}
// MockFailingAuditStore 总是失败的审计存储 mock用于测试错误处理
type MockFailingAuditStore struct{}
func (m *MockFailingAuditStore) Emit(ctx context.Context, event audit.Event) error {
return errors.New("audit emit failed")
}
func (m *MockFailingAuditStore) Query(ctx context.Context, filter audit.EventFilter) ([]audit.Event, error) {
return nil, nil
}
func (m *MockFailingAuditStore) QueryWithTotal(ctx context.Context, filter audit.EventFilter) ([]audit.Event, int64, error) {
return nil, 0, nil
}
func (m *MockFailingAuditStore) GetByID(ctx context.Context, eventID string) (audit.Event, error) {
return audit.Event{}, errors.New("not found")
}