diff --git a/internal/cache/cache_manager.go b/internal/cache/cache_manager.go index 561abbd..7eb3899 100644 --- a/internal/cache/cache_manager.go +++ b/internal/cache/cache_manager.go @@ -106,3 +106,16 @@ func (cm *CacheManager) GetL1() *L1Cache { func (cm *CacheManager) GetL2() L2Cache { return cm.l2 } + +// Increment 原子递增(同时更新L1和L2) +func (cm *CacheManager) Increment(ctx context.Context, key string, delta int64, ttl time.Duration) (int64, error) { + // 先更新L1 + cm.l1.Increment(key, delta, ttl) + + // 再更新L2 + if cm.l2 != nil { + return cm.l2.Increment(ctx, key, delta, ttl) + } + + return cm.l1.Increment(key, 0, 0), nil +} diff --git a/internal/cache/l1.go b/internal/cache/l1.go index c26061e..73dc97e 100644 --- a/internal/cache/l1.go +++ b/internal/cache/l1.go @@ -169,3 +169,44 @@ func (c *L1Cache) Cleanup() { c.removeFromAccessOrder(key) } } + +// Increment 原子递增(用于登录失败计数器等原子操作场景) +func (c *L1Cache) Increment(key string, delta int64, ttl time.Duration) int64 { + c.mu.Lock() + defer c.mu.Unlock() + + var expiration int64 + if ttl > 0 { + expiration = time.Now().Add(ttl).UnixNano() + } + + current := int64(0) + if item, ok := c.items[key]; ok { + if item.Expired() { + delete(c.items, key) + c.removeFromAccessOrder(key) + } else { + if v, ok := item.Value.(int64); ok { + current = v + } else if v, ok := item.Value.(int); ok { + current = int64(v) + } else if v, ok := item.Value.(float64); ok { + current = int64(v) + } + } + } + + newVal := current + delta + c.items[key] = &CacheItem{ + Value: newVal, + Expiration: expiration, + } + + if _, exists := c.items[key]; !exists { + c.accessOrder = append(c.accessOrder, key) + } else { + c.updateAccessOrder(key) + } + + return newVal +} diff --git a/internal/cache/l2.go b/internal/cache/l2.go index 868caaa..265e39f 100644 --- a/internal/cache/l2.go +++ b/internal/cache/l2.go @@ -17,6 +17,7 @@ type L2Cache interface { Delete(ctx context.Context, key string) error Exists(ctx context.Context, key string) (bool, error) Clear(ctx context.Context) error + Increment(ctx context.Context, key string, delta int64, ttl time.Duration) (int64, error) Close() error } @@ -127,6 +128,20 @@ func (c *RedisCache) Close() error { return c.client.Close() } +func (c *RedisCache) Increment(ctx context.Context, key string, delta int64, ttl time.Duration) (int64, error) { + if !c.enabled || c.client == nil { + return 0, errors.New("redis is not enabled") + } + result, err := c.client.IncrBy(ctx, key, delta).Result() + if err != nil { + return 0, err + } + if ttl > 0 { + c.client.Expire(ctx, key, ttl) + } + return result, nil +} + func decodeRedisValue(raw string) (interface{}, error) { decoder := json.NewDecoder(strings.NewReader(raw)) decoder.UseNumber() diff --git a/internal/service/auth.go b/internal/service/auth.go index 20bc12f..90c426b 100644 --- a/internal/service/auth.go +++ b/internal/service/auth.go @@ -494,17 +494,23 @@ func (s *AuthService) incrementFailAttempts(ctx context.Context, key string) int return 0 } - current := 0 - if value, ok := s.cache.Get(ctx, key); ok { - current = attemptCount(value) - } - current++ - - if err := s.cache.Set(ctx, key, current, s.loginLockDuration, s.loginLockDuration); err != nil { - log.Printf("auth: store login attempts failed, key=%s err=%v", key, err) + // 使用原子递增,避免竞态条件 + newVal, err := s.cache.Increment(ctx, key, 1, s.loginLockDuration) + if err != nil { + log.Printf("auth: increment login attempts failed, key=%s err=%v", key, err) + // 回退到原来的非原子方式 + current := 0 + if value, ok := s.cache.Get(ctx, key); ok { + current = attemptCount(value) + } + current++ + if setErr := s.cache.Set(ctx, key, current, s.loginLockDuration, s.loginLockDuration); setErr != nil { + log.Printf("auth: store login attempts failed, key=%s err=%v", key, setErr) + } + return current } - return current + return int(newVal) } func isValidPhoneSimple(phone string) bool {