232 lines
6.3 KiB
Go
232 lines
6.3 KiB
Go
|
|
package cache
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"encoding/json"
|
||
|
|
"fmt"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"github.com/redis/go-redis/v9"
|
||
|
|
"lijiaoqiao/supply-api/internal/config"
|
||
|
|
)
|
||
|
|
|
||
|
|
// RedisCache Redis缓存客户端
|
||
|
|
type RedisCache struct {
|
||
|
|
client *redis.Client
|
||
|
|
}
|
||
|
|
|
||
|
|
// NewRedisCache 创建Redis缓存客户端
|
||
|
|
func NewRedisCache(cfg config.RedisConfig) (*RedisCache, error) {
|
||
|
|
client := redis.NewClient(&redis.Options{
|
||
|
|
Addr: cfg.Addr(),
|
||
|
|
Password: cfg.Password,
|
||
|
|
DB: cfg.DB,
|
||
|
|
PoolSize: cfg.PoolSize,
|
||
|
|
})
|
||
|
|
|
||
|
|
// 验证连接
|
||
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
|
|
defer cancel()
|
||
|
|
|
||
|
|
if err := client.Ping(ctx).Err(); err != nil {
|
||
|
|
return nil, fmt.Errorf("failed to connect to redis: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return &RedisCache{client: client}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// Close 关闭连接
|
||
|
|
func (r *RedisCache) Close() error {
|
||
|
|
return r.client.Close()
|
||
|
|
}
|
||
|
|
|
||
|
|
// HealthCheck 健康检查
|
||
|
|
func (r *RedisCache) HealthCheck(ctx context.Context) error {
|
||
|
|
return r.client.Ping(ctx).Err()
|
||
|
|
}
|
||
|
|
|
||
|
|
// ==================== Token状态缓存 ====================
|
||
|
|
|
||
|
|
// TokenStatus is the record cached as JSON under "token:status:<TokenID>".
// The revocation fields are omitted from the JSON when they hold zero values.
type TokenStatus struct {
	TokenID       string `json:"token_id"`
	SubjectID     string `json:"subject_id"`
	Role          string `json:"role"`
	Status        string `json:"status"`     // one of: active, revoked, expired
	ExpiresAt     int64  `json:"expires_at"` // NOTE(review): presumably a Unix timestamp — confirm units
	RevokedAt     int64  `json:"revoked_at,omitempty"`
	RevokedReason string `json:"revoked_reason,omitempty"`
}
|
||
|
|
|
||
|
|
// GetTokenStatus 获取Token状态
|
||
|
|
func (r *RedisCache) GetTokenStatus(ctx context.Context, tokenID string) (*TokenStatus, error) {
|
||
|
|
key := fmt.Sprintf("token:status:%s", tokenID)
|
||
|
|
data, err := r.client.Get(ctx, key).Bytes()
|
||
|
|
if err == redis.Nil {
|
||
|
|
return nil, nil
|
||
|
|
}
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("failed to get token status: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
var status TokenStatus
|
||
|
|
if err := json.Unmarshal(data, &status); err != nil {
|
||
|
|
return nil, fmt.Errorf("failed to unmarshal token status: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return &status, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// SetTokenStatus 设置Token状态
|
||
|
|
func (r *RedisCache) SetTokenStatus(ctx context.Context, status *TokenStatus, ttl time.Duration) error {
|
||
|
|
key := fmt.Sprintf("token:status:%s", status.TokenID)
|
||
|
|
data, err := json.Marshal(status)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("failed to marshal token status: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return r.client.Set(ctx, key, data, ttl).Err()
|
||
|
|
}
|
||
|
|
|
||
|
|
// InvalidateToken 使Token失效
|
||
|
|
func (r *RedisCache) InvalidateToken(ctx context.Context, tokenID string) error {
|
||
|
|
key := fmt.Sprintf("token:status:%s", tokenID)
|
||
|
|
return r.client.Del(ctx, key).Err()
|
||
|
|
}
|
||
|
|
|
||
|
|
// ==================== 限流 ====================
|
||
|
|
|
||
|
|
// RateLimitKey identifies one rate-limit counter. Its fields are joined into the
// Redis key "ratelimit:<TenantID>:<Route>:<LimitType>".
type RateLimitKey struct {
	TenantID  int64
	Route     string
	LimitType string // one of: rpm, rpd, concurrent
}
|
||
|
|
|
||
|
|
// GetRateLimit 获取限流计数
|
||
|
|
func (r *RedisCache) GetRateLimit(ctx context.Context, key *RateLimitKey, window time.Duration) (int64, error) {
|
||
|
|
redisKey := fmt.Sprintf("ratelimit:%d:%s:%s", key.TenantID, key.Route, key.LimitType)
|
||
|
|
|
||
|
|
count, err := r.client.Get(ctx, redisKey).Int64()
|
||
|
|
if err == redis.Nil {
|
||
|
|
return 0, nil
|
||
|
|
}
|
||
|
|
if err != nil {
|
||
|
|
return 0, fmt.Errorf("failed to get rate limit: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return count, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// IncrRateLimit 增加限流计数
|
||
|
|
func (r *RedisCache) IncrRateLimit(ctx context.Context, key *RateLimitKey, window time.Duration) (int64, error) {
|
||
|
|
redisKey := fmt.Sprintf("ratelimit:%d:%s:%s", key.TenantID, key.Route, key.LimitType)
|
||
|
|
|
||
|
|
pipe := r.client.Pipeline()
|
||
|
|
incrCmd := pipe.Incr(ctx, redisKey)
|
||
|
|
pipe.Expire(ctx, redisKey, window)
|
||
|
|
|
||
|
|
_, err := pipe.Exec(ctx)
|
||
|
|
if err != nil {
|
||
|
|
return 0, fmt.Errorf("failed to increment rate limit: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return incrCmd.Val(), nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// CheckRateLimit 检查限流
|
||
|
|
func (r *RedisCache) CheckRateLimit(ctx context.Context, key *RateLimitKey, limit int64, window time.Duration) (bool, int64, error) {
|
||
|
|
count, err := r.IncrRateLimit(ctx, key, window)
|
||
|
|
if err != nil {
|
||
|
|
return false, 0, err
|
||
|
|
}
|
||
|
|
|
||
|
|
return count <= limit, count, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// ==================== 分布式锁 ====================
|
||
|
|
|
||
|
|
// AcquireLock 获取分布式锁
|
||
|
|
func (r *RedisCache) AcquireLock(ctx context.Context, lockKey string, ttl time.Duration) (bool, error) {
|
||
|
|
redisKey := fmt.Sprintf("lock:%s", lockKey)
|
||
|
|
|
||
|
|
ok, err := r.client.SetNX(ctx, redisKey, "1", ttl).Result()
|
||
|
|
if err != nil {
|
||
|
|
return false, fmt.Errorf("failed to acquire lock: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return ok, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// ReleaseLock 释放分布式锁
|
||
|
|
func (r *RedisCache) ReleaseLock(ctx context.Context, lockKey string) error {
|
||
|
|
redisKey := fmt.Sprintf("lock:%s", lockKey)
|
||
|
|
return r.client.Del(ctx, redisKey).Err()
|
||
|
|
}
|
||
|
|
|
||
|
|
// ==================== 幂等缓存 ====================
|
||
|
|
|
||
|
|
// IdempotencyCache 幂等缓存(短期)
|
||
|
|
func (r *RedisCache) GetIdempotency(ctx context.Context, key string) (string, error) {
|
||
|
|
redisKey := fmt.Sprintf("idempotency:%s", key)
|
||
|
|
val, err := r.client.Get(ctx, redisKey).Result()
|
||
|
|
if err == redis.Nil {
|
||
|
|
return "", nil
|
||
|
|
}
|
||
|
|
if err != nil {
|
||
|
|
return "", fmt.Errorf("failed to get idempotency: %w", err)
|
||
|
|
}
|
||
|
|
return val, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (r *RedisCache) SetIdempotency(ctx context.Context, key, value string, ttl time.Duration) error {
|
||
|
|
redisKey := fmt.Sprintf("idempotency:%s", key)
|
||
|
|
return r.client.Set(ctx, redisKey, value, ttl).Err()
|
||
|
|
}
|
||
|
|
|
||
|
|
// ==================== Session缓存 ====================
|
||
|
|
|
||
|
|
// SessionData is the payload cached as JSON under "session:<sessionID>".
type SessionData struct {
	UserID    int64  `json:"user_id"`
	TenantID  int64  `json:"tenant_id"`
	Role      string `json:"role"`
	CreatedAt int64  `json:"created_at"` // NOTE(review): presumably a Unix timestamp — confirm units
}
|
||
|
|
|
||
|
|
// GetSession 获取Session
|
||
|
|
func (r *RedisCache) GetSession(ctx context.Context, sessionID string) (*SessionData, error) {
|
||
|
|
key := fmt.Sprintf("session:%s", sessionID)
|
||
|
|
data, err := r.client.Get(ctx, key).Bytes()
|
||
|
|
if err == redis.Nil {
|
||
|
|
return nil, nil
|
||
|
|
}
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("failed to get session: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
var session SessionData
|
||
|
|
if err := json.Unmarshal(data, &session); err != nil {
|
||
|
|
return nil, fmt.Errorf("failed to unmarshal session: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return &session, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// SetSession 设置Session
|
||
|
|
func (r *RedisCache) SetSession(ctx context.Context, sessionID string, session *SessionData, ttl time.Duration) error {
|
||
|
|
key := fmt.Sprintf("session:%s", sessionID)
|
||
|
|
data, err := json.Marshal(session)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("failed to marshal session: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return r.client.Set(ctx, key, data, ttl).Err()
|
||
|
|
}
|
||
|
|
|
||
|
|
// DeleteSession 删除Session
|
||
|
|
func (r *RedisCache) DeleteSession(ctx context.Context, sessionID string) error {
|
||
|
|
key := fmt.Sprintf("session:%s", sessionID)
|
||
|
|
return r.client.Del(ctx, key).Err()
|
||
|
|
}
|