Files
lijiaoqiao/supply-api/internal/domain/settlement_test.go
Your Name efa4edcc15 fix: 修复提现唯一性检查问题 (PRD P0)
问题:Withdraw函数没有检查是否已有处理中的提现,可能导致并发提现

修复内容:
1. 添加新错误码 ErrWithdrawAlreadyProcessing (SUP_SET_4093)
2. 在 SettlementStore 接口添加 HasPendingOrProcessingWithdraw 方法
3. 在 Withdraw 函数中添加检查:已有pending/processing状态提现时拒绝新的提现
4. 在 Repository 中实现 HasPendingOrProcessingWithdraw(检查 pending 和 processing 状态)
5. 在所有 mock 实现中添加该方法

修改的文件:
- domain/settlement.go: 接口定义和 Withdraw 逻辑
- domain/invariants.go: 新错误码
- repository/settlement.go: HasPendingOrProcessingWithdraw 实现
- storage/store.go: InMemorySettlementStore 实现
- cmd/supply-api/main.go: DBSettlementStore 和 InMemorySettlementStoreAdapter 实现
- test mocks: 添加 HasPendingOrProcessingWithdraw
2026-04-08 20:26:50 +08:00

494 lines
14 KiB
Go

package domain
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
"lijiaoqiao/supply-api/internal/audit"
)
// mockSettlementStore is a map-backed settlement store used by the tests in
// this file.
type mockSettlementStore struct {
	// settlements holds every created settlement, keyed by its assigned ID.
	settlements map[int64]*Settlement
	// balances holds the withdrawable balance per supplier ID.
	balances map[int64]float64
	// nextID is the ID handed to the next settlement passed to Create.
	nextID int64
}
// newMockSettlementStore returns an empty mock store whose ID counter starts
// at 1.
func newMockSettlementStore() *mockSettlementStore {
	store := new(mockSettlementStore)
	store.settlements = map[int64]*Settlement{}
	store.balances = map[int64]float64{}
	store.nextID = 1
	return store
}
// Create assigns the next sequential ID to s (overwriting any ID the caller
// set), stores the pointer, and always returns nil.
func (m *mockSettlementStore) Create(ctx context.Context, s *Settlement) error {
	assigned := m.nextID
	m.nextID = assigned + 1

	s.ID = assigned
	m.settlements[assigned] = s
	return nil
}
// GetByID returns the stored settlement with the given id, but only when it
// belongs to supplierID. A missing ID or a supplier mismatch both yield a
// "settlement not found" error.
func (m *mockSettlementStore) GetByID(ctx context.Context, supplierID, id int64) (*Settlement, error) {
	found, exists := m.settlements[id]
	if !exists || found.SupplierID != supplierID {
		return nil, errors.New("settlement not found")
	}
	return found, nil
}
// Update stores s under its ID when the version carried by s equals
// expectedVersion; otherwise it returns a "concurrency conflict" error.
// Note that the check uses the version on the incoming record, not the one
// already held in the map.
func (m *mockSettlementStore) Update(ctx context.Context, s *Settlement, expectedVersion int) error {
	if s.Version == expectedVersion {
		m.settlements[s.ID] = s
		return nil
	}
	return errors.New("concurrency conflict")
}
// List returns every stored settlement owned by supplierID. The order follows
// map iteration and is therefore unspecified; the result is nil when nothing
// matches.
func (m *mockSettlementStore) List(ctx context.Context, supplierID int64) ([]*Settlement, error) {
	var owned []*Settlement
	for _, candidate := range m.settlements {
		if candidate.SupplierID != supplierID {
			continue
		}
		owned = append(owned, candidate)
	}
	return owned, nil
}
// GetWithdrawableBalance returns the balance configured for supplierID, or 0
// when none has been set. It never returns an error.
func (m *mockSettlementStore) GetWithdrawableBalance(ctx context.Context, supplierID int64) (float64, error) {
	// Indexing a map with a missing key yields the zero value, which is the
	// desired default balance.
	return m.balances[supplierID], nil
}
// HasPendingOrProcessingWithdraw reports whether supplierID already owns a
// stored settlement whose status is pending or processing.
//
// The answer is derived from the settlements held by the mock rather than
// being hard-coded to false, so tests can exercise the duplicate-withdraw
// guard by seeding the store. It never returns an error.
func (m *mockSettlementStore) HasPendingOrProcessingWithdraw(ctx context.Context, supplierID int64) (bool, error) {
	for _, s := range m.settlements {
		if s.SupplierID != supplierID {
			continue
		}
		switch s.Status {
		case SettlementStatusPending, SettlementStatusProcessing:
			return true, nil
		}
	}
	return false, nil
}
// mockEarningStore is a slice-backed earning store used by the tests in this
// file.
type mockEarningStore struct {
	// records is the full set of earning records the mock can return.
	records []*EarningRecord
}
// newMockEarningStore returns a mock earning store holding an empty, non-nil
// record slice.
func newMockEarningStore() *mockEarningStore {
	return &mockEarningStore{records: []*EarningRecord{}}
}
// ListRecords returns every record owned by supplierID together with the
// number of matches. The date range and pagination arguments are accepted
// but ignored by this mock.
func (m *mockEarningStore) ListRecords(ctx context.Context, supplierID int64, startDate, endDate string, page, pageSize int) ([]*EarningRecord, int, error) {
	var owned []*EarningRecord
	for _, rec := range m.records {
		if rec.SupplierID != supplierID {
			continue
		}
		owned = append(owned, rec)
	}
	total := len(owned)
	return owned, total, nil
}
// GetBillingSummary returns a canned billing summary. Only the period echoes
// the startDate/endDate arguments; all totals are fixed values and supplierID
// is ignored.
func (m *mockEarningStore) GetBillingSummary(ctx context.Context, supplierID int64, startDate, endDate string) (*BillingSummary, error) {
	period := BillingPeriod{Start: startDate, End: endDate}
	totals := BillingTotal{
		TotalRevenue:   1000.00,
		TotalOrders:    100,
		TotalUsage:     5000,
		TotalRequests:  10000,
		AvgSuccessRate: 99.5,
		PlatformFee:    10.00,
		NetEarnings:    990.00,
	}
	return &BillingSummary{Period: period, Summary: totals}, nil
}
// mockAuditStoreForSettlement is a stateless audit store stub: it accepts
// every event and never returns stored data.
type mockAuditStoreForSettlement struct{}

// Emit discards the event and reports success.
func (m *mockAuditStoreForSettlement) Emit(ctx context.Context, event audit.Event) error {
	return nil
}

// Query always returns a nil event slice and no error.
func (m *mockAuditStoreForSettlement) Query(ctx context.Context, filter audit.EventFilter) ([]audit.Event, error) {
	var none []audit.Event
	return none, nil
}

// QueryWithTotal always returns a nil event slice, a total of 0 and no error.
func (m *mockAuditStoreForSettlement) QueryWithTotal(ctx context.Context, filter audit.EventFilter) ([]audit.Event, int64, error) {
	var none []audit.Event
	return none, 0, nil
}

// GetByID always fails with a "not found" error and a zero-value event.
func (m *mockAuditStoreForSettlement) GetByID(ctx context.Context, eventID string) (audit.Event, error) {
	var zero audit.Event
	return zero, errors.New("not found")
}
// TestSettlementConstants 测试结算状态常量
func TestSettlementConstants(t *testing.T) {
assert.Equal(t, SettlementStatus("pending"), SettlementStatusPending)
assert.Equal(t, SettlementStatus("processing"), SettlementStatusProcessing)
assert.Equal(t, SettlementStatus("completed"), SettlementStatusCompleted)
assert.Equal(t, SettlementStatus("failed"), SettlementStatusFailed)
}
// TestPaymentMethodConstants 测试支付方式常量
func TestPaymentMethodConstants(t *testing.T) {
assert.Equal(t, PaymentMethod("bank"), PaymentMethodBank)
assert.Equal(t, PaymentMethod("alipay"), PaymentMethodAlipay)
assert.Equal(t, PaymentMethod("wechat"), PaymentMethodWechat)
}
// TestSettlementStruct checks that the fields populated in a Settlement
// literal read back with the same values.
func TestSettlementStruct(t *testing.T) {
	stamp := time.Now()
	periodEnd := stamp.Add(24 * time.Hour)

	settlement := &Settlement{
		ID:             1,
		SupplierID:     1001,
		SettlementNo:   "SET-2024-001",
		Status:         SettlementStatusPending,
		TotalAmount:    1000.00,
		FeeAmount:      10.00,
		NetAmount:      990.00,
		PaymentMethod:  PaymentMethodBank,
		PaymentAccount: "1234567890",
		PeriodStart:    stamp,
		PeriodEnd:      periodEnd,
		TotalOrders:    100,
		CurrencyCode:   "CNY",
		AmountUnit:     "yuan",
		Version:        1,
		CreatedAt:      stamp,
		UpdatedAt:      stamp,
	}

	// Identity and state.
	assert.Equal(t, int64(1), settlement.ID)
	assert.Equal(t, int64(1001), settlement.SupplierID)
	assert.Equal(t, "SET-2024-001", settlement.SettlementNo)
	assert.Equal(t, SettlementStatusPending, settlement.Status)

	// Amounts.
	assert.Equal(t, 1000.00, settlement.TotalAmount)
	assert.Equal(t, 10.00, settlement.FeeAmount)
	assert.Equal(t, 990.00, settlement.NetAmount)

	// Payment details and bookkeeping.
	assert.Equal(t, PaymentMethodBank, settlement.PaymentMethod)
	assert.Equal(t, "1234567890", settlement.PaymentAccount)
	assert.Equal(t, 100, settlement.TotalOrders)
	assert.Equal(t, "CNY", settlement.CurrencyCode)
	assert.Equal(t, "yuan", settlement.AmountUnit)
	assert.Equal(t, 1, settlement.Version)
}
// TestEarningRecordStruct checks that the fields populated in an
// EarningRecord literal read back with the same values.
func TestEarningRecordStruct(t *testing.T) {
	earnedAt := time.Now()
	record := &EarningRecord{
		ID:           1,
		SupplierID:   1001,
		SettlementID: 10,
		EarningsType: "usage",
		Amount:       500.00,
		Status:       "available",
		Description:  "usage earnings",
		EarnedAt:     earnedAt,
	}

	// Numeric identifiers.
	assert.Equal(t, int64(1), record.ID)
	assert.Equal(t, int64(1001), record.SupplierID)
	assert.Equal(t, int64(10), record.SettlementID)

	// Classification and value.
	assert.Equal(t, "usage", record.EarningsType)
	assert.Equal(t, 500.00, record.Amount)
	assert.Equal(t, "available", record.Status)
}
// TestSettlementStatusTransitions walks a settlement through each status
// value and checks that the assignment is observable on the struct. It does
// not exercise any transition rules.
func TestSettlementStatusTransitions(t *testing.T) {
	settlement := &Settlement{Status: SettlementStatusPending}
	assert.Equal(t, SettlementStatusPending, settlement.Status)

	sequence := []SettlementStatus{
		SettlementStatusProcessing,
		SettlementStatusCompleted,
		SettlementStatusFailed,
	}
	for _, next := range sequence {
		settlement.Status = next
		assert.Equal(t, next, settlement.Status)
	}
}
// TestInvariantErrors 测试结算相关不变量错误
func TestSettlementInvariantErrors(t *testing.T) {
// ERRORS from invariants.go related to settlements
assert.Contains(t, ErrSettlementCannotCancel.Error(), "cannot cancel")
assert.Contains(t, ErrWithdrawExceedsBalance.Error(), "exceeds available balance")
assert.Contains(t, ErrSettlementBalanceMismatch.Error(), "does not match balance")
}
// TestNewSettlementService 测试创建结算服务
func TestNewSettlementService(t *testing.T) {
store := newMockSettlementStore()
earningStore := newMockEarningStore()
auditStore := &mockAuditStoreForSettlement{}
svc := NewSettlementService(store, earningStore, auditStore)
assert.NotNil(t, svc)
}
// TestSettlementService_Withdraw runs a table of withdraw requests for
// supplier 1001 (balance 5000) against a single shared service and store.
//
// Failing cases must produce an error containing errMsg. The success case
// must return a pending settlement with a 1% fee. Assertions that precede a
// dereference are guarded, so a failed expectation is reported as a test
// failure instead of a nil-pointer panic.
func TestSettlementService_Withdraw(t *testing.T) {
	store := newMockSettlementStore()
	earningStore := newMockEarningStore()
	auditStore := &mockAuditStoreForSettlement{}
	svc := NewSettlementService(store, earningStore, auditStore)

	// Seed the withdrawable balance.
	store.balances[1001] = 5000.0

	tests := []struct {
		name    string
		req     *WithdrawRequest
		wantErr bool
		errMsg  string
	}{
		{
			name: "invalid sms code",
			req: &WithdrawRequest{
				Amount:         1000,
				SMSCode:        "000000",
				PaymentMethod:  PaymentMethodBank,
				PaymentAccount: "1234567890",
			},
			wantErr: true,
			errMsg:  "invalid sms code",
		},
		{
			name: "negative amount",
			req: &WithdrawRequest{
				Amount:         -100,
				SMSCode:        "123456",
				PaymentMethod:  PaymentMethodBank,
				PaymentAccount: "1234567890",
			},
			wantErr: true,
			errMsg:  "must be positive",
		},
		{
			name: "zero amount",
			req: &WithdrawRequest{
				Amount:         0,
				SMSCode:        "123456",
				PaymentMethod:  PaymentMethodBank,
				PaymentAccount: "1234567890",
			},
			wantErr: true,
			errMsg:  "must be positive",
		},
		{
			name: "exceeds balance",
			req: &WithdrawRequest{
				Amount:         10000,
				SMSCode:        "123456",
				PaymentMethod:  PaymentMethodBank,
				PaymentAccount: "1234567890",
			},
			wantErr: true,
			errMsg:  "exceeds available balance",
		},
		{
			name: "success",
			req: &WithdrawRequest{
				Amount:         1000,
				SMSCode:        "123456",
				PaymentMethod:  PaymentMethodBank,
				PaymentAccount: "1234567890",
			},
			wantErr: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := svc.Withdraw(context.Background(), 1001, tt.req)

			if tt.wantErr {
				// Only inspect the message when an error was actually returned.
				if assert.Error(t, err) {
					assert.Contains(t, err.Error(), tt.errMsg)
				}
				return
			}

			// Stop before dereferencing result if the call did not succeed.
			if !assert.NoError(t, err) || !assert.NotNil(t, result) {
				return
			}
			assert.Equal(t, int64(1001), result.SupplierID)
			assert.Equal(t, SettlementStatusPending, result.Status)
			assert.Equal(t, 1000.0, result.TotalAmount)
			// Fee and net are computed values; compare with a tolerance
			// rather than exact float equality.
			assert.InDelta(t, 10.0, result.FeeAmount, 1e-9)  // 1% fee
			assert.InDelta(t, 990.0, result.NetAmount, 1e-9) // 99%
		})
	}
}
// TestSettlementService_Cancel checks that cancelling a pending settlement
// succeeds and leaves it in the failed status.
func TestSettlementService_Cancel(t *testing.T) {
	store := newMockSettlementStore()
	earningStore := newMockEarningStore()
	auditStore := &mockAuditStoreForSettlement{}
	svc := NewSettlementService(store, earningStore, auditStore)
	ctx := context.Background()

	// Seed a pending settlement; the fresh mock store assigns it ID 1.
	settlement := &Settlement{
		ID:             1,
		SupplierID:     1001,
		SettlementNo:   "SET-001",
		Status:         SettlementStatusPending,
		TotalAmount:    1000,
		PaymentMethod:  PaymentMethodBank,
		PaymentAccount: "1234567890",
		Version:        1,
	}
	if !assert.NoError(t, store.Create(ctx, settlement)) {
		return
	}

	// Cancelling a pending settlement must succeed.
	canceled, err := svc.Cancel(ctx, 1001, 1)
	assert.NoError(t, err)
	// Guard the dereference so a nil result fails the test instead of panicking.
	if assert.NotNil(t, canceled) {
		assert.Equal(t, SettlementStatusFailed, canceled.Status)
	}
}
// TestSettlementService_Cancel_ProcessingFails checks that cancelling a
// settlement that is already processing is rejected with a "cannot cancel"
// error.
func TestSettlementService_Cancel_ProcessingFails(t *testing.T) {
	store := newMockSettlementStore()
	earningStore := newMockEarningStore()
	auditStore := &mockAuditStoreForSettlement{}
	svc := NewSettlementService(store, earningStore, auditStore)
	ctx := context.Background()

	// Seed a processing settlement; the fresh mock store assigns it ID 1.
	settlement := &Settlement{
		ID:             1,
		SupplierID:     1001,
		SettlementNo:   "SET-001",
		Status:         SettlementStatusProcessing,
		TotalAmount:    1000,
		PaymentMethod:  PaymentMethodBank,
		PaymentAccount: "1234567890",
		Version:        1,
	}
	if !assert.NoError(t, store.Create(ctx, settlement)) {
		return
	}

	// Cancelling a processing settlement must fail.
	_, err := svc.Cancel(ctx, 1001, 1)
	// Only inspect the message when an error was actually returned.
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "cannot cancel")
	}
}
// TestSettlementService_GetByID checks that a settlement created in the store
// can be fetched through the service by supplier ID and settlement ID.
func TestSettlementService_GetByID(t *testing.T) {
	store := newMockSettlementStore()
	earningStore := newMockEarningStore()
	auditStore := &mockAuditStoreForSettlement{}
	svc := NewSettlementService(store, earningStore, auditStore)
	ctx := context.Background()

	// Seed one settlement; Create fills in settlement.ID.
	settlement := &Settlement{
		SupplierID:     1001,
		SettlementNo:   "SET-001",
		Status:         SettlementStatusPending,
		TotalAmount:    1000,
		PaymentMethod:  PaymentMethodBank,
		PaymentAccount: "1234567890",
		Version:        1,
	}
	if !assert.NoError(t, store.Create(ctx, settlement)) {
		return
	}

	found, err := svc.GetByID(ctx, 1001, settlement.ID)
	assert.NoError(t, err)
	// Guard the dereference so a nil result fails the test instead of panicking.
	if assert.NotNil(t, found) {
		assert.Equal(t, settlement.ID, found.ID)
	}
}
// TestSettlementService_GetByID_NotFound 测试获取不存在的结算单
func TestSettlementService_GetByID_NotFound(t *testing.T) {
store := newMockSettlementStore()
earningStore := newMockEarningStore()
auditStore := &mockAuditStoreForSettlement{}
svc := NewSettlementService(store, earningStore, auditStore)
_, err := svc.GetByID(context.Background(), 1001, 9999)
assert.Error(t, err)
}
// TestSettlementService_List checks that the service lists every settlement
// created for a supplier.
func TestSettlementService_List(t *testing.T) {
	store := newMockSettlementStore()
	earningStore := newMockEarningStore()
	auditStore := &mockAuditStoreForSettlement{}
	svc := NewSettlementService(store, earningStore, auditStore)
	ctx := context.Background()

	// Seed three settlements numbered SET-001..SET-003.
	for i := 0; i < 3; i++ {
		settlement := &Settlement{
			SupplierID:     1001,
			SettlementNo:   "SET-00" + string(rune('1'+i)),
			Status:         SettlementStatusPending,
			TotalAmount:    1000 + float64(i)*100,
			PaymentMethod:  PaymentMethodBank,
			PaymentAccount: "1234567890",
			Version:        1,
		}
		// A failed seed would make the length assertion below meaningless.
		if !assert.NoError(t, store.Create(ctx, settlement)) {
			return
		}
	}

	list, err := svc.List(ctx, 1001)
	assert.NoError(t, err)
	assert.Len(t, list, 3)
}
// TestNewEarningService 测试创建收益服务
func TestNewEarningService(t *testing.T) {
earningStore := newMockEarningStore()
svc := NewEarningService(earningStore)
assert.NotNil(t, svc)
}
// TestEarningService_ListRecords 测试列出收益记录
func TestEarningService_ListRecords(t *testing.T) {
earningStore := newMockEarningStore()
svc := NewEarningService(earningStore)
records, total, err := svc.ListRecords(context.Background(), 1001, "2024-01-01", "2024-01-31", 1, 10)
assert.NoError(t, err)
assert.Equal(t, 0, total)
assert.Len(t, records, 0)
}
// TestEarningService_GetBillingSummary checks that the service returns the
// mock's billing summary with the requested period and the canned revenue.
func TestEarningService_GetBillingSummary(t *testing.T) {
	earningStore := newMockEarningStore()
	svc := NewEarningService(earningStore)

	summary, err := svc.GetBillingSummary(context.Background(), 1001, "2024-01-01", "2024-01-31")
	assert.NoError(t, err)
	// Stop before dereferencing so a nil summary fails the test instead of
	// panicking.
	if !assert.NotNil(t, summary) {
		return
	}
	assert.Equal(t, "2024-01-01", summary.Period.Start)
	assert.Equal(t, "2024-01-31", summary.Period.End)
	assert.Equal(t, float64(1000), summary.Summary.TotalRevenue)
}
// TestGenerateSettlementNo checks that the generated settlement number is
// non-empty and as long as a 20060102150405-style timestamp.
func TestGenerateSettlementNo(t *testing.T) {
	// Only the length of this layout is used; it documents the expected shape.
	const timestampLayout = "20060102150405"

	generated := generateSettlementNo()

	assert.NotEmpty(t, generated)
	assert.Equal(t, len(timestampLayout), len(generated))
}