251 lines
7.4 KiB
Go
251 lines
7.4 KiB
Go
|
|
package repository
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"errors"
|
||
|
|
"fmt"
|
||
|
|
"net/netip"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"github.com/jackc/pgx/v5"
|
||
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
||
|
|
"lijiaoqiao/supply-api/internal/domain"
|
||
|
|
)
|
||
|
|
|
||
|
|
// PackageRepository is the supply-package repository (套餐仓储). It reads and
// writes rows of the supply_packages table.
type PackageRepository struct {
	// pool is the connection pool used by the methods that do not receive an
	// explicit transaction (Create, GetByID, Update, List).
	pool *pgxpool.Pool
}
|
||
|
|
|
||
|
|
// NewPackageRepository 创建套餐仓储
|
||
|
|
func NewPackageRepository(pool *pgxpool.Pool) *PackageRepository {
|
||
|
|
return &PackageRepository{pool: pool}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Create 创建套餐
|
||
|
|
func (r *PackageRepository) Create(ctx context.Context, pkg *domain.Package, requestID, traceID string) error {
|
||
|
|
query := `
|
||
|
|
INSERT INTO supply_packages (
|
||
|
|
supply_account_id, user_id, platform, model,
|
||
|
|
total_quota, available_quota, sold_quota, reserved_quota,
|
||
|
|
price_per_1m_input, price_per_1m_output, min_purchase,
|
||
|
|
start_at, end_at, valid_days,
|
||
|
|
status, max_concurrent, rate_limit_rpm,
|
||
|
|
total_orders, total_revenue, rating, rating_count,
|
||
|
|
quota_unit, price_unit, currency_code, version,
|
||
|
|
created_ip, updated_ip, audit_trace_id,
|
||
|
|
request_id
|
||
|
|
) VALUES (
|
||
|
|
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21,
|
||
|
|
$22, $23, $24, $25, $26, $27, $28
|
||
|
|
)
|
||
|
|
RETURNING id, created_at, updated_at
|
||
|
|
`
|
||
|
|
|
||
|
|
var startAt, endAt *time.Time
|
||
|
|
if !pkg.StartAt.IsZero() {
|
||
|
|
startAt = &pkg.StartAt
|
||
|
|
}
|
||
|
|
if !pkg.EndAt.IsZero() {
|
||
|
|
endAt = &pkg.EndAt
|
||
|
|
}
|
||
|
|
|
||
|
|
err := r.pool.QueryRow(ctx, query,
|
||
|
|
pkg.SupplierID, pkg.SupplierID, pkg.Platform, pkg.Model,
|
||
|
|
pkg.TotalQuota, pkg.AvailableQuota, pkg.SoldQuota, pkg.ReservedQuota,
|
||
|
|
pkg.PricePer1MInput, pkg.PricePer1MOutput, pkg.MinPurchase,
|
||
|
|
startAt, endAt, pkg.ValidDays,
|
||
|
|
pkg.Status, pkg.MaxConcurrent, pkg.RateLimitRPM,
|
||
|
|
pkg.TotalOrders, pkg.TotalRevenue, pkg.Rating, pkg.RatingCount,
|
||
|
|
"token", "per_1m_tokens", "USD", 0,
|
||
|
|
nil, nil, traceID,
|
||
|
|
requestID,
|
||
|
|
).Scan(&pkg.ID, &pkg.CreatedAt, &pkg.UpdatedAt)
|
||
|
|
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("failed to create package: %w", err)
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// GetByID 获取套餐
|
||
|
|
func (r *PackageRepository) GetByID(ctx context.Context, supplierID, id int64) (*domain.Package, error) {
|
||
|
|
query := `
|
||
|
|
SELECT id, supply_account_id, user_id, platform, model,
|
||
|
|
total_quota, available_quota, sold_quota, reserved_quota,
|
||
|
|
price_per_1m_input, price_per_1m_output, min_purchase,
|
||
|
|
start_at, end_at, valid_days,
|
||
|
|
status, max_concurrent, rate_limit_rpm,
|
||
|
|
total_orders, total_revenue, rating, rating_count,
|
||
|
|
quota_unit, price_unit, currency_code, version,
|
||
|
|
created_at, updated_at
|
||
|
|
FROM supply_packages
|
||
|
|
WHERE id = $1 AND user_id = $2
|
||
|
|
`
|
||
|
|
|
||
|
|
pkg := &domain.Package{}
|
||
|
|
var startAt, endAt pgx.NullTime
|
||
|
|
err := r.pool.QueryRow(ctx, query, id, supplierID).Scan(
|
||
|
|
&pkg.ID, &pkg.SupplierID, &pkg.SupplierID, &pkg.Platform, &pkg.Model,
|
||
|
|
&pkg.TotalQuota, &pkg.AvailableQuota, &pkg.SoldQuota, &pkg.ReservedQuota,
|
||
|
|
&pkg.PricePer1MInput, &pkg.PricePer1MOutput, &pkg.MinPurchase,
|
||
|
|
&startAt, &endAt, &pkg.ValidDays,
|
||
|
|
&pkg.Status, &pkg.MaxConcurrent, &pkg.RateLimitRPM,
|
||
|
|
&pkg.TotalOrders, &pkg.TotalRevenue, &pkg.Rating, &pkg.RatingCount,
|
||
|
|
&pkg.QuotaUnit, &pkg.PriceUnit, &pkg.CurrencyCode, &pkg.Version,
|
||
|
|
&pkg.CreatedAt, &pkg.UpdatedAt,
|
||
|
|
)
|
||
|
|
|
||
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
||
|
|
return nil, ErrNotFound
|
||
|
|
}
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("failed to get package: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if startAt.Valid {
|
||
|
|
pkg.StartAt = startAt.Time
|
||
|
|
}
|
||
|
|
if endAt.Valid {
|
||
|
|
pkg.EndAt = endAt.Time
|
||
|
|
}
|
||
|
|
|
||
|
|
return pkg, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// Update 更新套餐(乐观锁)
|
||
|
|
func (r *PackageRepository) Update(ctx context.Context, pkg *domain.Package, expectedVersion int) error {
|
||
|
|
query := `
|
||
|
|
UPDATE supply_packages SET
|
||
|
|
platform = $1, model = $2,
|
||
|
|
total_quota = $3, available_quota = $4, sold_quota = $5, reserved_quota = $6,
|
||
|
|
price_per_1m_input = $7, price_per_1m_output = $8,
|
||
|
|
start_at = $9, end_at = $10, valid_days = $11,
|
||
|
|
status = $12, max_concurrent = $13, rate_limit_rpm = $14,
|
||
|
|
total_orders = $15, total_revenue = $16,
|
||
|
|
rating = $17, rating_count = $18,
|
||
|
|
version = $19, updated_at = $20
|
||
|
|
WHERE id = $21 AND user_id = $22 AND version = $23
|
||
|
|
`
|
||
|
|
|
||
|
|
pkg.UpdatedAt = time.Now()
|
||
|
|
newVersion := expectedVersion + 1
|
||
|
|
|
||
|
|
cmdTag, err := r.pool.Exec(ctx, query,
|
||
|
|
pkg.Platform, pkg.Model,
|
||
|
|
pkg.TotalQuota, pkg.AvailableQuota, pkg.SoldQuota, pkg.ReservedQuota,
|
||
|
|
pkg.PricePer1MInput, pkg.PricePer1MOutput,
|
||
|
|
pkg.StartAt, pkg.EndAt, pkg.ValidDays,
|
||
|
|
pkg.Status, pkg.MaxConcurrent, pkg.RateLimitRPM,
|
||
|
|
pkg.TotalOrders, pkg.TotalRevenue,
|
||
|
|
pkg.Rating, pkg.RatingCount,
|
||
|
|
newVersion, pkg.UpdatedAt,
|
||
|
|
pkg.ID, pkg.SupplierID, expectedVersion,
|
||
|
|
)
|
||
|
|
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("failed to update package: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if cmdTag.RowsAffected() == 0 {
|
||
|
|
return ErrConcurrencyConflict
|
||
|
|
}
|
||
|
|
|
||
|
|
pkg.Version = newVersion
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// GetForUpdate 获取套餐并加行锁
|
||
|
|
func (r *PackageRepository) GetForUpdate(ctx context.Context, tx pgxpool.Tx, supplierID, id int64) (*domain.Package, error) {
|
||
|
|
query := `
|
||
|
|
SELECT id, supply_account_id, user_id, platform, model,
|
||
|
|
total_quota, available_quota, sold_quota, reserved_quota,
|
||
|
|
price_per_1m_input, price_per_1m_output,
|
||
|
|
status, version,
|
||
|
|
created_at, updated_at
|
||
|
|
FROM supply_packages
|
||
|
|
WHERE id = $1 AND user_id = $2
|
||
|
|
FOR UPDATE
|
||
|
|
`
|
||
|
|
|
||
|
|
pkg := &domain.Package{}
|
||
|
|
err := tx.QueryRow(ctx, query, id, supplierID).Scan(
|
||
|
|
&pkg.ID, &pkg.SupplierID, &pkg.SupplierID, &pkg.Platform, &pkg.Model,
|
||
|
|
&pkg.TotalQuota, &pkg.AvailableQuota, &pkg.SoldQuota, &pkg.ReservedQuota,
|
||
|
|
&pkg.PricePer1MInput, &pkg.PricePer1MOutput,
|
||
|
|
&pkg.Status, &pkg.Version,
|
||
|
|
&pkg.CreatedAt, &pkg.UpdatedAt,
|
||
|
|
)
|
||
|
|
|
||
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
||
|
|
return nil, ErrNotFound
|
||
|
|
}
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("failed to get package for update: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return pkg, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// List 列出套餐
|
||
|
|
func (r *PackageRepository) List(ctx context.Context, supplierID int64) ([]*domain.Package, error) {
|
||
|
|
query := `
|
||
|
|
SELECT id, supply_account_id, user_id, platform, model,
|
||
|
|
total_quota, available_quota, sold_quota,
|
||
|
|
price_per_1m_input, price_per_1m_output,
|
||
|
|
status, max_concurrent, rate_limit_rpm,
|
||
|
|
valid_days, total_orders, total_revenue,
|
||
|
|
version, created_at, updated_at
|
||
|
|
FROM supply_packages
|
||
|
|
WHERE user_id = $1
|
||
|
|
ORDER BY created_at DESC
|
||
|
|
`
|
||
|
|
|
||
|
|
rows, err := r.pool.Query(ctx, query, supplierID)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("failed to list packages: %w", err)
|
||
|
|
}
|
||
|
|
defer rows.Close()
|
||
|
|
|
||
|
|
var packages []*domain.Package
|
||
|
|
for rows.Next() {
|
||
|
|
pkg := &domain.Package{}
|
||
|
|
err := rows.Scan(
|
||
|
|
&pkg.ID, &pkg.SupplierID, &pkg.SupplierID, &pkg.Platform, &pkg.Model,
|
||
|
|
&pkg.TotalQuota, &pkg.AvailableQuota, &pkg.SoldQuota,
|
||
|
|
&pkg.PricePer1MInput, &pkg.PricePer1MOutput,
|
||
|
|
&pkg.Status, &pkg.MaxConcurrent, &pkg.RateLimitRPM,
|
||
|
|
&pkg.ValidDays, &pkg.TotalOrders, &pkg.TotalRevenue,
|
||
|
|
&pkg.Version, &pkg.CreatedAt, &pkg.UpdatedAt,
|
||
|
|
)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("failed to scan package: %w", err)
|
||
|
|
}
|
||
|
|
packages = append(packages, pkg)
|
||
|
|
}
|
||
|
|
|
||
|
|
return packages, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// UpdateQuota 扣减配额
|
||
|
|
func (r *PackageRepository) UpdateQuota(ctx context.Context, tx pgxpool.Tx, packageID, supplierID int64, usedQuota float64) error {
|
||
|
|
query := `
|
||
|
|
UPDATE supply_packages SET
|
||
|
|
available_quota = available_quota - $1,
|
||
|
|
sold_quota = sold_quota + $1,
|
||
|
|
updated_at = $2
|
||
|
|
WHERE id = $3 AND user_id = $4 AND available_quota >= $1
|
||
|
|
RETURNING id
|
||
|
|
`
|
||
|
|
|
||
|
|
var id int64
|
||
|
|
err := tx.QueryRow(ctx, query, usedQuota, time.Now(), packageID, supplierID).Scan(&id)
|
||
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
||
|
|
return errors.New("insufficient quota or package not found")
|
||
|
|
}
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("failed to update quota: %w", err)
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|