fix: P0-02 prevent login attempt counter race condition
Add atomic Increment method to cache layers: - L2Cache interface: add Increment method signature - RedisCache: implement using Redis INCRBY - L1Cache: implement with mutex-protected counter - CacheManager: add Increment that updates both L1 and L2 Update incrementFailAttempts to use atomic Increment instead of Get-Increment-Set pattern, preventing TOCTOU race.
This commit is contained in:
13
internal/cache/cache_manager.go
vendored
13
internal/cache/cache_manager.go
vendored
@@ -106,3 +106,16 @@ func (cm *CacheManager) GetL1() *L1Cache {
|
||||
func (cm *CacheManager) GetL2() L2Cache {
|
||||
return cm.l2
|
||||
}
|
||||
|
||||
// Increment 原子递增(同时更新L1和L2)
|
||||
func (cm *CacheManager) Increment(ctx context.Context, key string, delta int64, ttl time.Duration) (int64, error) {
|
||||
// 先更新L1
|
||||
cm.l1.Increment(key, delta, ttl)
|
||||
|
||||
// 再更新L2
|
||||
if cm.l2 != nil {
|
||||
return cm.l2.Increment(ctx, key, delta, ttl)
|
||||
}
|
||||
|
||||
return cm.l1.Increment(key, 0, 0), nil
|
||||
}
|
||||
|
||||
41
internal/cache/l1.go
vendored
41
internal/cache/l1.go
vendored
@@ -169,3 +169,44 @@ func (c *L1Cache) Cleanup() {
|
||||
c.removeFromAccessOrder(key)
|
||||
}
|
||||
}
|
||||
|
||||
// Increment 原子递增(用于登录失败计数器等原子操作场景)
|
||||
func (c *L1Cache) Increment(key string, delta int64, ttl time.Duration) int64 {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
var expiration int64
|
||||
if ttl > 0 {
|
||||
expiration = time.Now().Add(ttl).UnixNano()
|
||||
}
|
||||
|
||||
current := int64(0)
|
||||
if item, ok := c.items[key]; ok {
|
||||
if item.Expired() {
|
||||
delete(c.items, key)
|
||||
c.removeFromAccessOrder(key)
|
||||
} else {
|
||||
if v, ok := item.Value.(int64); ok {
|
||||
current = v
|
||||
} else if v, ok := item.Value.(int); ok {
|
||||
current = int64(v)
|
||||
} else if v, ok := item.Value.(float64); ok {
|
||||
current = int64(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
newVal := current + delta
|
||||
c.items[key] = &CacheItem{
|
||||
Value: newVal,
|
||||
Expiration: expiration,
|
||||
}
|
||||
|
||||
if _, exists := c.items[key]; !exists {
|
||||
c.accessOrder = append(c.accessOrder, key)
|
||||
} else {
|
||||
c.updateAccessOrder(key)
|
||||
}
|
||||
|
||||
return newVal
|
||||
}
|
||||
|
||||
15
internal/cache/l2.go
vendored
15
internal/cache/l2.go
vendored
@@ -17,6 +17,7 @@ type L2Cache interface {
|
||||
Delete(ctx context.Context, key string) error
|
||||
Exists(ctx context.Context, key string) (bool, error)
|
||||
Clear(ctx context.Context) error
|
||||
Increment(ctx context.Context, key string, delta int64, ttl time.Duration) (int64, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
@@ -127,6 +128,20 @@ func (c *RedisCache) Close() error {
|
||||
return c.client.Close()
|
||||
}
|
||||
|
||||
func (c *RedisCache) Increment(ctx context.Context, key string, delta int64, ttl time.Duration) (int64, error) {
|
||||
if !c.enabled || c.client == nil {
|
||||
return 0, errors.New("redis is not enabled")
|
||||
}
|
||||
result, err := c.client.IncrBy(ctx, key, delta).Result()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if ttl > 0 {
|
||||
c.client.Expire(ctx, key, ttl)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func decodeRedisValue(raw string) (interface{}, error) {
|
||||
decoder := json.NewDecoder(strings.NewReader(raw))
|
||||
decoder.UseNumber()
|
||||
|
||||
@@ -494,17 +494,23 @@ func (s *AuthService) incrementFailAttempts(ctx context.Context, key string) int
|
||||
return 0
|
||||
}
|
||||
|
||||
current := 0
|
||||
if value, ok := s.cache.Get(ctx, key); ok {
|
||||
current = attemptCount(value)
|
||||
}
|
||||
current++
|
||||
|
||||
if err := s.cache.Set(ctx, key, current, s.loginLockDuration, s.loginLockDuration); err != nil {
|
||||
log.Printf("auth: store login attempts failed, key=%s err=%v", key, err)
|
||||
// 使用原子递增,避免竞态条件
|
||||
newVal, err := s.cache.Increment(ctx, key, 1, s.loginLockDuration)
|
||||
if err != nil {
|
||||
log.Printf("auth: increment login attempts failed, key=%s err=%v", key, err)
|
||||
// 回退到原来的非原子方式
|
||||
current := 0
|
||||
if value, ok := s.cache.Get(ctx, key); ok {
|
||||
current = attemptCount(value)
|
||||
}
|
||||
current++
|
||||
if setErr := s.cache.Set(ctx, key, current, s.loginLockDuration, s.loginLockDuration); setErr != nil {
|
||||
log.Printf("auth: store login attempts failed, key=%s err=%v", key, setErr)
|
||||
}
|
||||
return current
|
||||
}
|
||||
|
||||
return current
|
||||
return int(newVal)
|
||||
}
|
||||
|
||||
func isValidPhoneSimple(phone string) bool {
|
||||
|
||||
Reference in New Issue
Block a user