Files
lijiaoqiao/supply-api/internal/storage/store.go
Your Name efa4edcc15 fix: 修复提现唯一性检查问题 (PRD P0)
问题:Withdraw函数没有检查是否已有处理中的提现,可能导致并发提现

修复内容:
1. 添加新错误码 ErrWithdrawAlreadyProcessing (SUP_SET_4093)
2. 在 SettlementStore 接口添加 HasPendingOrProcessingWithdraw 方法
3. 在 Withdraw 函数中添加检查:已有pending/processing状态提现时拒绝新的提现
4. 在 Repository 中实现 HasPendingOrProcessingWithdraw(检查 pending 和 processing 状态)
5. 在所有 mock 实现中添加该方法

修改的文件:
- domain/settlement.go: 接口定义和 Withdraw 逻辑
- domain/invariants.go: 新错误码
- repository/settlement.go: HasPendingOrProcessingWithdraw 实现
- storage/store.go: InMemorySettlementStore 实现
- cmd/supply-api/main.go: DBSettlementStore 和 InMemorySettlementStoreAdapter 实现
- test mocks: 添加 HasPendingOrProcessingWithdraw
2026-04-08 20:26:50 +08:00

379 lines
8.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package storage
import (
	"context"
	"errors"
	"sort"
	"sync"
	"time"

	"lijiaoqiao/supply-api/internal/domain"
	"lijiaoqiao/supply-api/internal/repository"
)
// Sentinel errors returned by the in-memory stores in this file.
var (
	// ErrNotFound reports that the requested record does not exist, or that it
	// exists but is owned by a different supplier than the one asking for it.
	ErrNotFound = errors.New("resource not found")
)
// InMemoryAccountStore is an in-memory account store keyed by account ID and
// guarded by a sync.RWMutex.
//
// The store owns its data: Create and Update keep a shallow copy of the
// argument, and GetByID/List hand out shallow copies. Callers therefore cannot
// modify stored state outside the lock by mutating a value they hold.
// NOTE(review): reference-typed fields inside domain.Account (if any) are still
// shared by a shallow copy — confirm the struct is flat.
type InMemoryAccountStore struct {
	mu       sync.RWMutex
	accounts map[int64]*domain.Account
	nextID   int64
}

// NewInMemoryAccountStore returns an empty store whose first assigned ID is 1.
func NewInMemoryAccountStore() *InMemoryAccountStore {
	return &InMemoryAccountStore{
		accounts: make(map[int64]*domain.Account),
		nextID:   1,
	}
}

// Create assigns the next ID and sets CreatedAt/UpdatedAt on account, then
// stores a copy of it. The caller's account is updated in place so it reflects
// the assigned ID and timestamps. It always returns nil.
func (s *InMemoryAccountStore) Create(ctx context.Context, account *domain.Account) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	// One timestamp so a freshly created record has CreatedAt == UpdatedAt.
	now := time.Now()
	account.ID = s.nextID
	s.nextID++
	account.CreatedAt = now
	account.UpdatedAt = now

	stored := *account
	s.accounts[account.ID] = &stored
	return nil
}

// GetByID returns a copy of the account with the given id. It returns
// ErrNotFound when no such account exists or it belongs to another supplier.
func (s *InMemoryAccountStore) GetByID(ctx context.Context, supplierID, id int64) (*domain.Account, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	account, ok := s.accounts[id]
	if !ok || account.SupplierID != supplierID {
		return nil, ErrNotFound
	}
	cp := *account
	return &cp, nil
}

// Update replaces the stored account that has account.ID with a copy of
// account, refreshing UpdatedAt (also on the caller's value). It returns
// ErrNotFound when no account with that ID exists for the same supplier.
func (s *InMemoryAccountStore) Update(ctx context.Context, account *domain.Account) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	existing, ok := s.accounts[account.ID]
	if !ok || existing.SupplierID != account.SupplierID {
		return ErrNotFound
	}
	account.UpdatedAt = time.Now()

	stored := *account
	s.accounts[account.ID] = &stored
	return nil
}

// List returns copies of every account owned by supplierID, ordered by ID
// ascending. The returned slice is never nil.
func (s *InMemoryAccountStore) List(ctx context.Context, supplierID int64) ([]*domain.Account, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	result := make([]*domain.Account, 0)
	for _, account := range s.accounts {
		if account.SupplierID == supplierID {
			cp := *account
			result = append(result, &cp)
		}
	}
	// Map iteration order is random; sort so callers get a stable order.
	sort.Slice(result, func(i, j int) bool { return result[i].ID < result[j].ID })
	return result, nil
}
// InMemoryPackageStore is an in-memory package store keyed by package ID and
// guarded by a sync.RWMutex.
//
// The store owns its data: Create and Update keep a shallow copy of the
// argument, and GetByID/List hand out shallow copies. Callers therefore cannot
// modify stored state outside the lock by mutating a value they hold.
// NOTE(review): reference-typed fields inside domain.Package (if any) are still
// shared by a shallow copy — confirm the struct is flat.
type InMemoryPackageStore struct {
	mu       sync.RWMutex
	packages map[int64]*domain.Package
	nextID   int64
}

// NewInMemoryPackageStore returns an empty store whose first assigned ID is 1.
func NewInMemoryPackageStore() *InMemoryPackageStore {
	return &InMemoryPackageStore{
		packages: make(map[int64]*domain.Package),
		nextID:   1,
	}
}

// Create assigns the next ID and sets CreatedAt/UpdatedAt on pkg, then stores
// a copy of it. The caller's pkg is updated in place so it reflects the
// assigned ID and timestamps. It always returns nil.
func (s *InMemoryPackageStore) Create(ctx context.Context, pkg *domain.Package) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	// One timestamp so a freshly created record has CreatedAt == UpdatedAt.
	now := time.Now()
	pkg.ID = s.nextID
	s.nextID++
	pkg.CreatedAt = now
	pkg.UpdatedAt = now

	stored := *pkg
	s.packages[pkg.ID] = &stored
	return nil
}

// GetByID returns a copy of the package with the given id. It returns
// ErrNotFound when no such package exists or it belongs to another supplier.
func (s *InMemoryPackageStore) GetByID(ctx context.Context, supplierID, id int64) (*domain.Package, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	pkg, ok := s.packages[id]
	if !ok || pkg.SupplierID != supplierID {
		return nil, ErrNotFound
	}
	cp := *pkg
	return &cp, nil
}

// Update replaces the stored package that has pkg.ID with a copy of pkg,
// refreshing UpdatedAt (also on the caller's value). It returns ErrNotFound
// when no package with that ID exists for the same supplier.
func (s *InMemoryPackageStore) Update(ctx context.Context, pkg *domain.Package) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	existing, ok := s.packages[pkg.ID]
	if !ok || existing.SupplierID != pkg.SupplierID {
		return ErrNotFound
	}
	pkg.UpdatedAt = time.Now()

	stored := *pkg
	s.packages[pkg.ID] = &stored
	return nil
}

// List returns copies of every package owned by supplierID, ordered by ID
// ascending. The returned slice is never nil.
func (s *InMemoryPackageStore) List(ctx context.Context, supplierID int64) ([]*domain.Package, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	result := make([]*domain.Package, 0)
	for _, pkg := range s.packages {
		if pkg.SupplierID == supplierID {
			cp := *pkg
			result = append(result, &cp)
		}
	}
	// Map iteration order is random; sort so callers get a stable order.
	sort.Slice(result, func(i, j int) bool { return result[i].ID < result[j].ID })
	return result, nil
}
// InMemorySettlementStore is an in-memory settlement store keyed by
// settlement ID and guarded by a sync.RWMutex.
//
// The store owns its data: Create and Update keep a shallow copy of the
// argument, and GetByID/List hand out shallow copies. This matters for the
// optimistic lock in Update. If two callers shared the stored pointer, one
// caller's in-place edits (including Version) would be visible to the other
// without going through Update. That allows lost updates and unsynchronized
// field writes.
// NOTE(review): reference-typed fields inside domain.Settlement (if any) are
// still shared by a shallow copy — confirm the struct is flat.
type InMemorySettlementStore struct {
	mu          sync.RWMutex
	settlements map[int64]*domain.Settlement
	nextID      int64
}

// NewInMemorySettlementStore returns an empty store whose first assigned ID is 1.
func NewInMemorySettlementStore() *InMemorySettlementStore {
	return &InMemorySettlementStore{
		settlements: make(map[int64]*domain.Settlement),
		nextID:      1,
	}
}

// Create assigns the next ID and sets CreatedAt/UpdatedAt on settlement, then
// stores a copy of it. The caller's settlement is updated in place so it
// reflects the assigned ID and timestamps. Version is stored as given. It
// always returns nil.
func (s *InMemorySettlementStore) Create(ctx context.Context, settlement *domain.Settlement) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	// One timestamp so a freshly created record has CreatedAt == UpdatedAt.
	now := time.Now()
	settlement.ID = s.nextID
	s.nextID++
	settlement.CreatedAt = now
	settlement.UpdatedAt = now

	stored := *settlement
	s.settlements[settlement.ID] = &stored
	return nil
}

// GetByID returns a copy of the settlement with the given id. It returns
// ErrNotFound when no such settlement exists or it belongs to another supplier.
func (s *InMemorySettlementStore) GetByID(ctx context.Context, supplierID, id int64) (*domain.Settlement, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	settlement, ok := s.settlements[id]
	if !ok || settlement.SupplierID != supplierID {
		return nil, ErrNotFound
	}
	cp := *settlement
	return &cp, nil
}

// Update replaces the stored settlement with a copy of settlement, but only if
// the stored Version equals expectedVersion (P1-005 optimistic lock).
//
// On success it sets settlement.Version to expectedVersion+1, refreshes
// UpdatedAt on the caller's value, and stores a copy. It returns ErrNotFound
// when no settlement with that ID exists for the same supplier. It returns
// repository.ErrConcurrencyConflict when the version does not match, and in
// that case nothing is modified.
func (s *InMemorySettlementStore) Update(ctx context.Context, settlement *domain.Settlement, expectedVersion int) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	existing, ok := s.settlements[settlement.ID]
	if !ok || existing.SupplierID != settlement.SupplierID {
		return ErrNotFound
	}
	// P1-005: optimistic lock check.
	if existing.Version != expectedVersion {
		return repository.ErrConcurrencyConflict
	}
	settlement.Version = expectedVersion + 1
	settlement.UpdatedAt = time.Now()

	stored := *settlement
	s.settlements[settlement.ID] = &stored
	return nil
}

// List returns copies of every settlement owned by supplierID, ordered by ID
// ascending. The returned slice is never nil.
func (s *InMemorySettlementStore) List(ctx context.Context, supplierID int64) ([]*domain.Settlement, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	result := make([]*domain.Settlement, 0)
	for _, settlement := range s.settlements {
		if settlement.SupplierID == supplierID {
			cp := *settlement
			result = append(result, &cp)
		}
	}
	// Map iteration order is random; sort so callers get a stable order.
	sort.Slice(result, func(i, j int) bool { return result[i].ID < result[j].ID })
	return result, nil
}

// GetWithdrawableBalance is a stub. It always reports 10000.0, whatever the
// supplier, and never fails.
func (s *InMemorySettlementStore) GetWithdrawableBalance(ctx context.Context, supplierID int64) (float64, error) {
	return 10000.0, nil
}

// HasPendingOrProcessingWithdraw reports whether supplierID owns any settlement
// whose status is SettlementStatusPending or SettlementStatusProcessing.
//
// NOTE(review): every settlement of the supplier is considered, not only
// withdrawals. Confirm that this store holds only withdraw settlements.
// NOTE(review): this check and a later Create each take the lock separately.
// Two concurrent withdrawals can therefore both observe false and both create a
// record. Fully closing that race needs an atomic check-and-create.
func (s *InMemorySettlementStore) HasPendingOrProcessingWithdraw(ctx context.Context, supplierID int64) (bool, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	for _, settlement := range s.settlements {
		if settlement.SupplierID != supplierID {
			continue
		}
		switch settlement.Status {
		case domain.SettlementStatusPending, domain.SettlementStatusProcessing:
			return true, nil
		}
	}
	return false, nil
}
// InMemoryEarningStore is an in-memory earnings store keyed by record ID and
// guarded by a sync.RWMutex.
// NOTE(review): nothing in this file writes to records or uses nextID, so
// ListRecords currently always returns an empty page.
type InMemoryEarningStore struct {
	mu      sync.RWMutex
	records map[int64]*domain.EarningRecord
	nextID  int64
}

// NewInMemoryEarningStore returns an empty store.
func NewInMemoryEarningStore() *InMemoryEarningStore {
	return &InMemoryEarningStore{
		records: make(map[int64]*domain.EarningRecord),
		nextID:  1,
	}
}

// ListRecords returns one page of supplierID's earning records, ordered by
// record ID ascending, together with the supplier's total record count.
//
// Paging is 1-based. A page below 1 is treated as page 1. A pageSize below 1,
// or a page past the end, yields an empty (non-nil) slice with the real total.
// NOTE(review): startDate and endDate are accepted but not applied as filters.
func (s *InMemoryEarningStore) ListRecords(ctx context.Context, supplierID int64, startDate, endDate string, page, pageSize int) ([]*domain.EarningRecord, int, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	// Collect matching keys and sort them. Paginating over random map order
	// would make consecutive pages overlap or skip records.
	ids := make([]int64, 0, len(s.records))
	for id, record := range s.records {
		if record.SupplierID == supplierID {
			ids = append(ids, id)
		}
	}
	sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] })

	total := len(ids)
	if page < 1 {
		page = 1
	}
	// Guard every path that could produce a negative or overflowing index.
	// (page-1) > total/pageSize implies (page-1)*pageSize > total, so this also
	// keeps the multiplication below from overflowing.
	if pageSize < 1 || page-1 > total/pageSize {
		return []*domain.EarningRecord{}, total, nil
	}
	start := (page - 1) * pageSize
	if start >= total {
		return []*domain.EarningRecord{}, total, nil
	}
	end := total
	if pageSize < total-start {
		end = start + pageSize
	}

	result := make([]*domain.EarningRecord, 0, end-start)
	for _, id := range ids[start:end] {
		result = append(result, s.records[id])
	}
	return result, total, nil
}

// GetBillingSummary is a stub. It returns fixed summary figures and echoes
// back only the requested period. supplierID is ignored and no error is ever
// returned.
func (s *InMemoryEarningStore) GetBillingSummary(ctx context.Context, supplierID int64, startDate, endDate string) (*domain.BillingSummary, error) {
	return &domain.BillingSummary{
		Period: domain.BillingPeriod{
			Start: startDate,
			End:   endDate,
		},
		Summary: domain.BillingTotal{
			TotalRevenue:   10000.0,
			TotalOrders:    100,
			TotalUsage:     1000000,
			TotalRequests:  50000,
			AvgSuccessRate: 99.5,
			PlatformFee:    100.0,
			NetEarnings:    9900.0,
		},
	}, nil
}
// idempotencyCleanupInterval is the number of writes between opportunistic
// sweeps of expired idempotency records.
const idempotencyCleanupInterval = 100

// InMemoryIdempotencyStore is an in-memory map of idempotency keys to records,
// guarded by a sync.RWMutex. Expired records are swept every
// idempotencyCleanupInterval writes, or on demand via CleanExpired.
type InMemoryIdempotencyStore struct {
	mu             sync.RWMutex
	records        map[string]*IdempotencyRecord
	cleanupCounter int64 // writes since the last opportunistic cleanup
}

// IdempotencyRecord is the state stored for one idempotency key.
type IdempotencyRecord struct {
	Key       string
	Status    string // processing, succeeded, failed
	Response  interface{}
	CreatedAt time.Time
	ExpiresAt time.Time
}

// NewInMemoryIdempotencyStore returns an empty store.
func NewInMemoryIdempotencyStore() *InMemoryIdempotencyStore {
	return &InMemoryIdempotencyStore{
		records: make(map[string]*IdempotencyRecord),
	}
}

// Get returns a copy of the record for key if one exists and has not expired.
// A record counts as live only while ExpiresAt is strictly after the current
// time. The copy keeps callers from mutating stored state outside the lock.
// NOTE(review): Response is copied by value. If it holds a pointer, map or
// slice, that data is still shared.
func (s *InMemoryIdempotencyStore) Get(key string) (*IdempotencyRecord, bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	record, ok := s.records[key]
	if !ok || !record.ExpiresAt.After(time.Now()) {
		return nil, false
	}
	cp := *record
	return &cp, true
}

// SetProcessing records key as "processing" for ttl, replacing any existing
// record for that key.
func (s *InMemoryIdempotencyStore) SetProcessing(key string, ttl time.Duration) {
	s.put(key, "processing", nil, ttl)
}

// SetSuccess records key as "succeeded" with response for ttl, replacing any
// existing record for that key.
func (s *InMemoryIdempotencyStore) SetSuccess(key string, response interface{}, ttl time.Duration) {
	s.put(key, "succeeded", response, ttl)
}

// put writes a record under the lock and then runs opportunistic cleanup.
// CreatedAt and ExpiresAt come from a single timestamp, so ExpiresAt is exactly
// CreatedAt + ttl.
func (s *InMemoryIdempotencyStore) put(key, status string, response interface{}, ttl time.Duration) {
	s.mu.Lock()
	defer s.mu.Unlock()

	now := time.Now()
	s.records[key] = &IdempotencyRecord{
		Key:       key,
		Status:    status,
		Response:  response,
		CreatedAt: now,
		ExpiresAt: now.Add(ttl),
	}
	s.triggerCleanupLocked()
}

// triggerCleanupLocked counts one write and sweeps expired records once every
// idempotencyCleanupInterval writes. The caller must hold s.mu for writing.
func (s *InMemoryIdempotencyStore) triggerCleanupLocked() {
	s.cleanupCounter++
	if s.cleanupCounter >= idempotencyCleanupInterval {
		s.cleanupCounter = 0
		s.cleanupExpiredLocked()
	}
}

// cleanupExpiredLocked deletes every record that Get would treat as expired.
// The caller must hold s.mu for writing.
func (s *InMemoryIdempotencyStore) cleanupExpiredLocked() {
	now := time.Now()
	for key, record := range s.records {
		// Same boundary as Get: live only while ExpiresAt is after now.
		if !record.ExpiresAt.After(now) {
			delete(s.records, key)
		}
	}
}

// CleanExpired removes all expired records immediately. It can be called
// periodically by an external scheduler.
func (s *InMemoryIdempotencyStore) CleanExpired() {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.cleanupExpiredLocked()
}

// Len returns the number of stored records, including expired ones that have
// not yet been swept. It is intended for monitoring.
func (s *InMemoryIdempotencyStore) Len() int {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return len(s.records)
}