244 lines
7.2 KiB
Go
244 lines
7.2 KiB
Go
|
|
package repository
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"errors"
|
||
|
|
"fmt"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"github.com/jackc/pgx/v5"
|
||
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
||
|
|
"lijiaoqiao/supply-api/internal/domain"
|
||
|
|
)
|
||
|
|
|
||
|
|
// SettlementRepository 结算仓储
|
||
|
|
type SettlementRepository struct {
|
||
|
|
pool *pgxpool.Pool
|
||
|
|
}
|
||
|
|
|
||
|
|
// NewSettlementRepository 创建结算仓储
|
||
|
|
func NewSettlementRepository(pool *pgxpool.Pool) *SettlementRepository {
|
||
|
|
return &SettlementRepository{pool: pool}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Create 创建结算单
|
||
|
|
func (r *SettlementRepository) Create(ctx context.Context, s *domain.Settlement, requestID, idempotencyKey, traceID string) error {
|
||
|
|
query := `
|
||
|
|
INSERT INTO supply_settlements (
|
||
|
|
settlement_no, user_id, total_amount, fee_amount, net_amount,
|
||
|
|
status, payment_method, payment_account,
|
||
|
|
period_start, period_end, total_orders, total_usage_records,
|
||
|
|
currency_code, amount_unit, version,
|
||
|
|
request_id, idempotency_key, audit_trace_id
|
||
|
|
) VALUES (
|
||
|
|
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18
|
||
|
|
)
|
||
|
|
RETURNING id, created_at, updated_at
|
||
|
|
`
|
||
|
|
|
||
|
|
err := r.pool.QueryRow(ctx, query,
|
||
|
|
s.SettlementNo, s.SupplierID, s.TotalAmount, s.FeeAmount, s.NetAmount,
|
||
|
|
s.Status, s.PaymentMethod, s.PaymentAccount,
|
||
|
|
s.PeriodStart, s.PeriodEnd, s.TotalOrders, s.TotalUsageRecords,
|
||
|
|
"USD", "minor", 0,
|
||
|
|
requestID, idempotencyKey, traceID,
|
||
|
|
).Scan(&s.ID, &s.CreatedAt, &s.UpdatedAt)
|
||
|
|
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("failed to create settlement: %w", err)
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// GetByID 获取结算单
|
||
|
|
func (r *SettlementRepository) GetByID(ctx context.Context, supplierID, id int64) (*domain.Settlement, error) {
|
||
|
|
query := `
|
||
|
|
SELECT id, settlement_no, user_id, total_amount, fee_amount, net_amount,
|
||
|
|
status, payment_method, payment_account,
|
||
|
|
period_start, period_end, total_orders, total_usage_records,
|
||
|
|
payment_transaction_id, paid_at,
|
||
|
|
version, created_at, updated_at
|
||
|
|
FROM supply_settlements
|
||
|
|
WHERE id = $1 AND user_id = $2
|
||
|
|
`
|
||
|
|
|
||
|
|
s := &domain.Settlement{}
|
||
|
|
var paidAt pgx.NullTime
|
||
|
|
err := r.pool.QueryRow(ctx, query, id, supplierID).Scan(
|
||
|
|
&s.ID, &s.SettlementNo, &s.SupplierID, &s.TotalAmount, &s.FeeAmount, &s.NetAmount,
|
||
|
|
&s.Status, &s.PaymentMethod, &s.PaymentAccount,
|
||
|
|
&s.PeriodStart, &s.PeriodEnd, &s.TotalOrders, &s.TotalUsageRecords,
|
||
|
|
&s.PaymentTransactionID, &paidAt,
|
||
|
|
&s.Version, &s.CreatedAt, &s.UpdatedAt,
|
||
|
|
)
|
||
|
|
|
||
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
||
|
|
return nil, ErrNotFound
|
||
|
|
}
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("failed to get settlement: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if paidAt.Valid {
|
||
|
|
s.PaidAt = &paidAt.Time
|
||
|
|
}
|
||
|
|
|
||
|
|
return s, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// Update 更新结算单(乐观锁)
|
||
|
|
func (r *SettlementRepository) Update(ctx context.Context, s *domain.Settlement, expectedVersion int) error {
|
||
|
|
query := `
|
||
|
|
UPDATE supply_settlements SET
|
||
|
|
status = $1, payment_method = $2, payment_account = $3,
|
||
|
|
payment_transaction_id = $4, paid_at = $5,
|
||
|
|
total_orders = $6, total_usage_records = $7,
|
||
|
|
version = $8, updated_at = $9
|
||
|
|
WHERE id = $10 AND user_id = $11 AND version = $12
|
||
|
|
`
|
||
|
|
|
||
|
|
s.UpdatedAt = time.Now()
|
||
|
|
newVersion := expectedVersion + 1
|
||
|
|
|
||
|
|
cmdTag, err := r.pool.Exec(ctx, query,
|
||
|
|
s.Status, s.PaymentMethod, s.PaymentAccount,
|
||
|
|
s.PaymentTransactionID, s.PaidAt,
|
||
|
|
s.TotalOrders, s.TotalUsageRecords,
|
||
|
|
newVersion, s.UpdatedAt,
|
||
|
|
s.ID, s.SupplierID, expectedVersion,
|
||
|
|
)
|
||
|
|
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("failed to update settlement: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if cmdTag.RowsAffected() == 0 {
|
||
|
|
return ErrConcurrencyConflict
|
||
|
|
}
|
||
|
|
|
||
|
|
s.Version = newVersion
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// GetForUpdate 获取结算单并加行锁
|
||
|
|
func (r *SettlementRepository) GetForUpdate(ctx context.Context, tx pgxpool.Tx, supplierID, id int64) (*domain.Settlement, error) {
|
||
|
|
query := `
|
||
|
|
SELECT id, settlement_no, user_id, total_amount, fee_amount, net_amount,
|
||
|
|
status, payment_method, payment_account, version,
|
||
|
|
created_at, updated_at
|
||
|
|
FROM supply_settlements
|
||
|
|
WHERE id = $1 AND user_id = $2
|
||
|
|
FOR UPDATE
|
||
|
|
`
|
||
|
|
|
||
|
|
s := &domain.Settlement{}
|
||
|
|
err := tx.QueryRow(ctx, query, id, supplierID).Scan(
|
||
|
|
&s.ID, &s.SettlementNo, &s.SupplierID, &s.TotalAmount, &s.FeeAmount, &s.NetAmount,
|
||
|
|
&s.Status, &s.PaymentMethod, &s.PaymentAccount, &s.Version,
|
||
|
|
&s.CreatedAt, &s.UpdatedAt,
|
||
|
|
)
|
||
|
|
|
||
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
||
|
|
return nil, ErrNotFound
|
||
|
|
}
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("failed to get settlement for update: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return s, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// GetProcessing 获取处理中的结算单(用于单一性约束)
|
||
|
|
func (r *SettlementRepository) GetProcessing(ctx context.Context, tx pgxpool.Tx, supplierID int64) (*domain.Settlement, error) {
|
||
|
|
query := `
|
||
|
|
SELECT id, settlement_no, user_id, total_amount, fee_amount, net_amount,
|
||
|
|
status, payment_method, payment_account, version,
|
||
|
|
created_at, updated_at
|
||
|
|
FROM supply_settlements
|
||
|
|
WHERE user_id = $1 AND status = 'processing'
|
||
|
|
FOR UPDATE SKIP LOCKED
|
||
|
|
LIMIT 1
|
||
|
|
`
|
||
|
|
|
||
|
|
s := &domain.Settlement{}
|
||
|
|
err := tx.QueryRow(ctx, query, supplierID).Scan(
|
||
|
|
&s.ID, &s.SettlementNo, &s.SupplierID, &s.TotalAmount, &s.FeeAmount, &s.NetAmount,
|
||
|
|
&s.Status, &s.PaymentMethod, &s.PaymentAccount, &s.Version,
|
||
|
|
&s.CreatedAt, &s.UpdatedAt,
|
||
|
|
)
|
||
|
|
|
||
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
||
|
|
return nil, nil // 没有处理中的单据
|
||
|
|
}
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("failed to get processing settlement: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return s, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// List 列出结算单
|
||
|
|
func (r *SettlementRepository) List(ctx context.Context, supplierID int64) ([]*domain.Settlement, error) {
|
||
|
|
query := `
|
||
|
|
SELECT id, settlement_no, user_id, total_amount, fee_amount, net_amount,
|
||
|
|
status, payment_method,
|
||
|
|
period_start, period_end, total_orders,
|
||
|
|
version, created_at, updated_at
|
||
|
|
FROM supply_settlements
|
||
|
|
WHERE user_id = $1
|
||
|
|
ORDER BY created_at DESC
|
||
|
|
`
|
||
|
|
|
||
|
|
rows, err := r.pool.Query(ctx, query, supplierID)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("failed to list settlements: %w", err)
|
||
|
|
}
|
||
|
|
defer rows.Close()
|
||
|
|
|
||
|
|
var settlements []*domain.Settlement
|
||
|
|
for rows.Next() {
|
||
|
|
s := &domain.Settlement{}
|
||
|
|
err := rows.Scan(
|
||
|
|
&s.ID, &s.SettlementNo, &s.SupplierID, &s.TotalAmount, &s.FeeAmount, &s.NetAmount,
|
||
|
|
&s.Status, &s.PaymentMethod,
|
||
|
|
&s.PeriodStart, &s.PeriodEnd, &s.TotalOrders,
|
||
|
|
&s.Version, &s.CreatedAt, &s.UpdatedAt,
|
||
|
|
)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("failed to scan settlement: %w", err)
|
||
|
|
}
|
||
|
|
settlements = append(settlements, s)
|
||
|
|
}
|
||
|
|
|
||
|
|
return settlements, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// CreateInTx 在事务中创建结算单
|
||
|
|
func (r *SettlementRepository) CreateInTx(ctx context.Context, tx pgxpool.Tx, s *domain.Settlement, requestID, idempotencyKey, traceID string) error {
|
||
|
|
query := `
|
||
|
|
INSERT INTO supply_settlements (
|
||
|
|
settlement_no, user_id, total_amount, fee_amount, net_amount,
|
||
|
|
status, payment_method, payment_account,
|
||
|
|
period_start, period_end, total_orders, total_usage_records,
|
||
|
|
currency_code, amount_unit, version,
|
||
|
|
request_id, idempotency_key, audit_trace_id
|
||
|
|
) VALUES (
|
||
|
|
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18
|
||
|
|
)
|
||
|
|
RETURNING id, created_at, updated_at
|
||
|
|
`
|
||
|
|
|
||
|
|
err := tx.QueryRow(ctx, query,
|
||
|
|
s.SettlementNo, s.SupplierID, s.TotalAmount, s.FeeAmount, s.NetAmount,
|
||
|
|
s.Status, s.PaymentMethod, s.PaymentAccount,
|
||
|
|
s.PeriodStart, s.PeriodEnd, s.TotalOrders, s.TotalUsageRecords,
|
||
|
|
"USD", "minor", 0,
|
||
|
|
requestID, idempotencyKey, traceID,
|
||
|
|
).Scan(&s.ID, &s.CreatedAt, &s.UpdatedAt)
|
||
|
|
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("failed to create settlement in tx: %w", err)
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|