Files
lijiaoqiao/supply-api/internal/repository/settlement.go
Your Name efa4edcc15 fix: 修复提现唯一性检查问题 (PRD P0)
问题:Withdraw函数没有检查是否已有处理中的提现,可能导致并发提现

修复内容:
1. 添加新错误码 ErrWithdrawAlreadyProcessing (SUP_SET_4093)
2. 在 SettlementStore 接口添加 HasPendingOrProcessingWithdraw 方法
3. 在 Withdraw 函数中添加检查:已有pending/processing状态提现时拒绝新的提现
4. 在 Repository 中实现 HasPendingOrProcessingWithdraw(检查 pending 和 processing 状态)
5. 在所有 mock 实现中添加该方法

修改的文件:
- domain/settlement.go: 接口定义和 Withdraw 逻辑
- domain/invariants.go: 新错误码
- repository/settlement.go: HasPendingOrProcessingWithdraw 实现
- storage/store.go: InMemorySettlementStore 实现
- cmd/supply-api/main.go: DBSettlementStore 和 InMemorySettlementStoreAdapter 实现
- test mocks: 添加 HasPendingOrProcessingWithdraw
2026-04-08 20:26:50 +08:00

292 lines
8.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package repository
import (
"context"
"errors"
"fmt"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"lijiaoqiao/supply-api/internal/domain"
)
// SettlementRepository provides persistence operations for settlement
// records stored in the supply_settlements table.
type SettlementRepository struct {
	// pool is the pgx connection pool used by the methods that do not
	// receive an explicit transaction.
	pool *pgxpool.Pool
}
// NewSettlementRepository returns a SettlementRepository that issues its
// non-transactional queries through the given connection pool.
func NewSettlementRepository(pool *pgxpool.Pool) *SettlementRepository {
	repo := new(SettlementRepository)
	repo.pool = pool
	return repo
}
// Create inserts s into supply_settlements using the connection pool.
//
// The user_id column is populated from s.SupplierID. currency_code,
// amount_unit and version are always written as "USD", "minor" and 0.
// requestID, idempotencyKey and traceID are stored in the request_id,
// idempotency_key and audit_trace_id columns.
//
// On success the generated id, created_at and updated_at values are written
// back into s. Any database error is returned wrapped.
func (r *SettlementRepository) Create(ctx context.Context, s *domain.Settlement, requestID, idempotencyKey, traceID string) error {
	const insertSQL = `
INSERT INTO supply_settlements (
settlement_no, user_id, total_amount, fee_amount, net_amount,
status, payment_method, payment_account,
period_start, period_end, total_orders, total_usage_records,
currency_code, amount_unit, version,
request_id, idempotency_key, audit_trace_id
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18
)
RETURNING id, created_at, updated_at
`
	// Positional arguments, in the same order as the column list above.
	args := []any{
		s.SettlementNo, s.SupplierID, s.TotalAmount, s.FeeAmount, s.NetAmount,
		s.Status, s.PaymentMethod, s.PaymentAccount,
		s.PeriodStart, s.PeriodEnd, s.TotalOrders, s.TotalUsageRecords,
		"USD", "minor", 0,
		requestID, idempotencyKey, traceID,
	}

	row := r.pool.QueryRow(ctx, insertSQL, args...)
	if err := row.Scan(&s.ID, &s.CreatedAt, &s.UpdatedAt); err != nil {
		return fmt.Errorf("failed to create settlement: %w", err)
	}
	return nil
}
// GetByID loads the settlement with the given id that belongs to supplierID
// (matched against the user_id column).
//
// It returns ErrNotFound when no row matches both id and supplierID, and a
// wrapped error for any other query or scan failure.
func (r *SettlementRepository) GetByID(ctx context.Context, supplierID, id int64) (*domain.Settlement, error) {
	const selectSQL = `
SELECT id, settlement_no, user_id, total_amount, fee_amount, net_amount,
status, payment_method, payment_account,
period_start, period_end, total_orders, total_usage_records,
payment_transaction_id, paid_at,
version, created_at, updated_at
FROM supply_settlements
WHERE id = $1 AND user_id = $2
`
	var (
		settlement = &domain.Settlement{}
		// paid_at is scanned through a pointer so a NULL column yields nil.
		paidAt *time.Time
	)

	err := r.pool.QueryRow(ctx, selectSQL, id, supplierID).Scan(
		&settlement.ID, &settlement.SettlementNo, &settlement.SupplierID,
		&settlement.TotalAmount, &settlement.FeeAmount, &settlement.NetAmount,
		&settlement.Status, &settlement.PaymentMethod, &settlement.PaymentAccount,
		&settlement.PeriodStart, &settlement.PeriodEnd,
		&settlement.TotalOrders, &settlement.TotalUsageRecords,
		&settlement.PaymentTransactionID, &paidAt,
		&settlement.Version, &settlement.CreatedAt, &settlement.UpdatedAt,
	)
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, fmt.Errorf("failed to get settlement: %w", err)
	}

	// Only set PaidAt when the column actually held a value.
	if paidAt != nil {
		settlement.PaidAt = paidAt
	}
	return settlement, nil
}
// Update persists the mutable fields of s using optimistic locking.
//
// The row is matched on id, user_id (s.SupplierID) and expectedVersion; when
// it matches, the stored version becomes expectedVersion+1 and updated_at is
// set to the current time.
//
// It returns ErrConcurrencyConflict when no row was affected. Note that this
// covers both a version mismatch and a row that does not exist for the given
// id/supplier, since the WHERE clause cannot tell the two apart. Any other
// database error is returned wrapped.
//
// s.UpdatedAt and s.Version are only modified after the write succeeded, so
// on error or conflict the caller's entity still reflects the state it had
// before the call.
func (r *SettlementRepository) Update(ctx context.Context, s *domain.Settlement, expectedVersion int) error {
	const updateSQL = `
UPDATE supply_settlements SET
status = $1, payment_method = $2, payment_account = $3,
payment_transaction_id = $4, paid_at = $5,
total_orders = $6, total_usage_records = $7,
version = $8, updated_at = $9
WHERE id = $10 AND user_id = $11 AND version = $12
`
	updatedAt := time.Now()
	newVersion := expectedVersion + 1

	cmdTag, err := r.pool.Exec(ctx, updateSQL,
		s.Status, s.PaymentMethod, s.PaymentAccount,
		s.PaymentTransactionID, s.PaidAt,
		s.TotalOrders, s.TotalUsageRecords,
		newVersion, updatedAt,
		s.ID, s.SupplierID, expectedVersion,
	)
	if err != nil {
		return fmt.Errorf("failed to update settlement: %w", err)
	}
	if cmdTag.RowsAffected() == 0 {
		return ErrConcurrencyConflict
	}

	// Mirror the persisted values into the entity only once they are durable.
	s.UpdatedAt = updatedAt
	s.Version = newVersion
	return nil
}
// GetForUpdate loads the settlement identified by id and supplierID inside
// tx and takes a row lock on it (SELECT ... FOR UPDATE, pessimistic locking).
//
// Note: under high concurrency, prefer GetForUpdateNoWait or the optimistic
// locking offered by Update.
// P1-005: a NOWAIT variant and optimistic-lock support have been added.
//
// Only the columns listed in the query are populated on the result. It
// returns ErrNotFound when no row matches, and a wrapped error otherwise.
func (r *SettlementRepository) GetForUpdate(ctx context.Context, tx pgxpool.Tx, supplierID, id int64) (*domain.Settlement, error) {
	const lockSQL = `
SELECT id, settlement_no, user_id, total_amount, fee_amount, net_amount,
status, payment_method, payment_account, version,
created_at, updated_at
FROM supply_settlements
WHERE id = $1 AND user_id = $2
FOR UPDATE
`
	locked := new(domain.Settlement)
	dest := []any{
		&locked.ID, &locked.SettlementNo, &locked.SupplierID,
		&locked.TotalAmount, &locked.FeeAmount, &locked.NetAmount,
		&locked.Status, &locked.PaymentMethod, &locked.PaymentAccount, &locked.Version,
		&locked.CreatedAt, &locked.UpdatedAt,
	}

	if err := tx.QueryRow(ctx, lockSQL, id, supplierID).Scan(dest...); err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return nil, ErrNotFound
		}
		return nil, fmt.Errorf("failed to get settlement for update: %w", err)
	}
	return locked, nil
}
// GetForUpdateNoWait loads the settlement identified by id and supplierID
// inside tx and takes a row lock on it without waiting
// (SELECT ... FOR UPDATE NOWAIT).
//
// P1-005: NOWAIT variant. If the lock cannot be acquired the query fails
// immediately instead of blocking, which suits high-concurrency callers.
//
// It returns ErrNotFound when no row matches. Every other failure, including
// the lock being unavailable, is returned as a wrapped error; this method
// does not map the lock failure to a dedicated sentinel.
func (r *SettlementRepository) GetForUpdateNoWait(ctx context.Context, tx pgxpool.Tx, supplierID, id int64) (*domain.Settlement, error) {
	const lockNoWaitSQL = `
SELECT id, settlement_no, user_id, total_amount, fee_amount, net_amount,
status, payment_method, payment_account, version,
created_at, updated_at
FROM supply_settlements
WHERE id = $1 AND user_id = $2
FOR UPDATE NOWAIT
`
	result := &domain.Settlement{}
	row := tx.QueryRow(ctx, lockNoWaitSQL, id, supplierID)
	err := row.Scan(
		&result.ID, &result.SettlementNo, &result.SupplierID,
		&result.TotalAmount, &result.FeeAmount, &result.NetAmount,
		&result.Status, &result.PaymentMethod, &result.PaymentAccount, &result.Version,
		&result.CreatedAt, &result.UpdatedAt,
	)
	switch {
	case err == nil:
		return result, nil
	case errors.Is(err, pgx.ErrNoRows):
		return nil, ErrNotFound
	default:
		// With NOWAIT an unavailable lock surfaces here as an immediate
		// error rather than a wait.
		return nil, fmt.Errorf("failed to get settlement for update (nowait): %w", err)
	}
}
// GetProcessing returns one settlement of supplierID whose status is
// 'processing', locking the returned row inside tx. It is intended to
// support the "at most one processing settlement" constraint.
//
// It returns (nil, nil) when the query yields no row, and a wrapped error on
// any other failure.
//
// NOTE(review): the query uses FOR UPDATE SKIP LOCKED, so a processing row
// that is currently locked by another transaction is skipped and reported as
// "none". Confirm that this is acceptable for a uniqueness check.
func (r *SettlementRepository) GetProcessing(ctx context.Context, tx pgxpool.Tx, supplierID int64) (*domain.Settlement, error) {
	const processingSQL = `
SELECT id, settlement_no, user_id, total_amount, fee_amount, net_amount,
status, payment_method, payment_account, version,
created_at, updated_at
FROM supply_settlements
WHERE user_id = $1 AND status = 'processing'
FOR UPDATE SKIP LOCKED
LIMIT 1
`
	current := new(domain.Settlement)
	scanErr := tx.QueryRow(ctx, processingSQL, supplierID).Scan(
		&current.ID, &current.SettlementNo, &current.SupplierID,
		&current.TotalAmount, &current.FeeAmount, &current.NetAmount,
		&current.Status, &current.PaymentMethod, &current.PaymentAccount, &current.Version,
		&current.CreatedAt, &current.UpdatedAt,
	)
	if scanErr != nil {
		if errors.Is(scanErr, pgx.ErrNoRows) {
			// No (unlocked) processing settlement exists for this supplier.
			return nil, nil
		}
		return nil, fmt.Errorf("failed to get processing settlement: %w", scanErr)
	}
	return current, nil
}
// HasPendingOrProcessingWithdraw reports whether supplierID has at least one
// row in supply_settlements whose status is 'pending' or 'processing'.
//
// A query failure is returned as (false, wrapped error).
//
// NOTE(review): the check runs on the pool, outside any transaction and
// without taking locks, so on its own it cannot stop two concurrent requests
// from both observing false. Confirm that callers (or a database constraint)
// close that window.
func (r *SettlementRepository) HasPendingOrProcessingWithdraw(ctx context.Context, supplierID int64) (bool, error) {
	const existsSQL = `
SELECT EXISTS(
SELECT 1 FROM supply_settlements
WHERE user_id = $1 AND status IN ('pending', 'processing')
)
`
	var found bool
	if err := r.pool.QueryRow(ctx, existsSQL, supplierID).Scan(&found); err != nil {
		return false, fmt.Errorf("failed to check pending/processing settlement: %w", err)
	}
	return found, nil
}
// List returns every settlement of supplierID (matched against user_id),
// newest first by created_at.
//
// Only the columns named in the query are populated; fields such as
// PaymentAccount, TotalUsageRecords, PaymentTransactionID and PaidAt are left
// at their zero values. The returned slice is non-nil even when empty.
//
// Errors from starting the query, scanning a row, or iterating the result
// set are returned wrapped; no partial result is returned on error.
func (r *SettlementRepository) List(ctx context.Context, supplierID int64) ([]*domain.Settlement, error) {
	const listSQL = `
SELECT id, settlement_no, user_id, total_amount, fee_amount, net_amount,
status, payment_method,
period_start, period_end, total_orders,
version, created_at, updated_at
FROM supply_settlements
WHERE user_id = $1
ORDER BY created_at DESC
`
	rows, err := r.pool.Query(ctx, listSQL, supplierID)
	if err != nil {
		return nil, fmt.Errorf("failed to list settlements: %w", err)
	}
	defer rows.Close()

	settlements := make([]*domain.Settlement, 0)
	for rows.Next() {
		s := &domain.Settlement{}
		if err := rows.Scan(
			&s.ID, &s.SettlementNo, &s.SupplierID, &s.TotalAmount, &s.FeeAmount, &s.NetAmount,
			&s.Status, &s.PaymentMethod,
			&s.PeriodStart, &s.PeriodEnd, &s.TotalOrders,
			&s.Version, &s.CreatedAt, &s.UpdatedAt,
		); err != nil {
			return nil, fmt.Errorf("failed to scan settlement: %w", err)
		}
		settlements = append(settlements, s)
	}
	// rows.Next returns false both at the end of the result set and when
	// reading fails mid-stream; without this check a failure would be
	// returned as a silently truncated list.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate settlements: %w", err)
	}
	return settlements, nil
}
// CreateInTx inserts s into supply_settlements using the supplied
// transaction instead of the pool.
//
// The user_id column is populated from s.SupplierID. currency_code,
// amount_unit and version are always written as "USD", "minor" and 0.
// requestID, idempotencyKey and traceID are stored in the request_id,
// idempotency_key and audit_trace_id columns.
//
// On success the generated id, created_at and updated_at values are written
// back into s. Any database error is returned wrapped. Committing or rolling
// back tx is left to the caller.
func (r *SettlementRepository) CreateInTx(ctx context.Context, tx pgxpool.Tx, s *domain.Settlement, requestID, idempotencyKey, traceID string) error {
	const insertSQL = `
INSERT INTO supply_settlements (
settlement_no, user_id, total_amount, fee_amount, net_amount,
status, payment_method, payment_account,
period_start, period_end, total_orders, total_usage_records,
currency_code, amount_unit, version,
request_id, idempotency_key, audit_trace_id
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18
)
RETURNING id, created_at, updated_at
`
	inserted := tx.QueryRow(ctx, insertSQL,
		// identity and amounts
		s.SettlementNo, s.SupplierID, s.TotalAmount, s.FeeAmount, s.NetAmount,
		// state and payment details
		s.Status, s.PaymentMethod, s.PaymentAccount,
		// settlement period and counters
		s.PeriodStart, s.PeriodEnd, s.TotalOrders, s.TotalUsageRecords,
		// fixed currency, unit and initial version
		"USD", "minor", 0,
		// request tracing metadata
		requestID, idempotencyKey, traceID,
	)

	scanErr := inserted.Scan(&s.ID, &s.CreatedAt, &s.UpdatedAt)
	if scanErr == nil {
		return nil
	}
	return fmt.Errorf("failed to create settlement in tx: %w", scanErr)
}