295 lines
8.2 KiB
Go
295 lines
8.2 KiB
Go
package cache
|
|
|
|
import (
	"context"
	"encoding/json"
	"fmt"
	"strconv"
	"testing"
	"time"

	"github.com/redis/go-redis/v9"
)
|
|
|
|
type fakeRedisClient struct {
|
|
values map[string]string
|
|
expirations map[string]time.Duration
|
|
published []publishedMessage
|
|
}
|
|
|
|
// publishedMessage captures one Publish call made against the fake client.
type publishedMessage struct {
	channel string // channel name passed to Publish
	payload string // message body after conversion to a string
}
|
|
|
|
type fakeRedisPipeline struct {
|
|
client *fakeRedisClient
|
|
cmds []redis.Cmder
|
|
}
|
|
|
|
// newFakeRedisClient returns an empty fake whose value and TTL maps are
// allocated and ready for writes. The published log starts out nil.
func newFakeRedisClient() *fakeRedisClient {
	c := new(fakeRedisClient)
	c.values = map[string]string{}
	c.expirations = map[string]time.Duration{}
	return c
}
|
|
|
|
func (c *fakeRedisClient) Close() error {
|
|
return nil
|
|
}
|
|
|
|
func (c *fakeRedisClient) Ping(ctx context.Context) *redis.StatusCmd {
|
|
return redis.NewStatusResult("PONG", nil)
|
|
}
|
|
|
|
// Get returns the value stored under key. A key that holds no value yields
// an empty string with redis.Nil, which is go-redis's "key not found" error.
// Recorded TTLs are not consulted.
func (c *fakeRedisClient) Get(ctx context.Context, key string) *redis.StringCmd {
	if stored, found := c.values[key]; found {
		return redis.NewStringResult(stored, nil)
	}
	return redis.NewStringResult("", redis.Nil)
}
|
|
|
|
// Set unconditionally stores value under key, converting it with
// stringifyRedisValue. It records expiration as the key's TTL and always
// replies "OK".
func (c *fakeRedisClient) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *redis.StatusCmd {
	encoded := stringifyRedisValue(value)
	c.expirations[key] = expiration
	c.values[key] = encoded
	return redis.NewStatusResult("OK", nil)
}
|
|
|
|
func (c *fakeRedisClient) Del(ctx context.Context, keys ...string) *redis.IntCmd {
|
|
var removed int64
|
|
for _, key := range keys {
|
|
if _, ok := c.values[key]; ok {
|
|
delete(c.values, key)
|
|
delete(c.expirations, key)
|
|
removed++
|
|
}
|
|
}
|
|
return redis.NewIntResult(removed, nil)
|
|
}
|
|
|
|
// Publish appends the message to the client's published log, converting it
// with stringifyRedisValue. It always reports exactly one receiver.
func (c *fakeRedisClient) Publish(ctx context.Context, channel string, message interface{}) *redis.IntCmd {
	msg := publishedMessage{channel: channel}
	msg.payload = stringifyRedisValue(message)
	c.published = append(c.published, msg)
	return redis.NewIntResult(1, nil)
}
|
|
|
|
func (c *fakeRedisClient) Subscribe(ctx context.Context, channels ...string) *redis.PubSub {
|
|
panic("Subscribe should not be called in cache unit tests")
|
|
}
|
|
|
|
func (c *fakeRedisClient) Pipeline() redisPipeline {
|
|
return &fakeRedisPipeline{client: c}
|
|
}
|
|
|
|
// SetNX stores value (converted with stringifyRedisValue) and expiration
// under key only when the key holds no value yet. It returns true if it
// wrote and false if the key was already taken; an existing key, including
// its TTL, is left untouched.
func (c *fakeRedisClient) SetNX(ctx context.Context, key string, value interface{}, expiration time.Duration) *redis.BoolCmd {
	_, taken := c.values[key]
	if !taken {
		c.values[key] = stringifyRedisValue(value)
		c.expirations[key] = expiration
	}
	return redis.NewBoolResult(!taken, nil)
}
|
|
|
|
// Incr mimics Redis INCR against the owning client.
//
// A missing key counts as 0. An existing value must parse as a base-10
// int64, and the incremented result is written back immediately; unlike a
// real pipeline, the fake does not defer work until Exec. Every command is
// recorded for Exec, including a failed one, because a real pipeline returns
// failed commands from Exec too.
//
// Errors carry the replies Redis gives for a non-integer value and for
// increment overflow.
func (p *fakeRedisPipeline) Incr(ctx context.Context, key string) *redis.IntCmd {
	record := func(cmd *redis.IntCmd) *redis.IntCmd {
		p.cmds = append(p.cmds, cmd)
		return cmd
	}

	var current int64
	if raw, ok := p.client.values[key]; ok {
		// ParseInt rejects trailing garbage such as "12abc", which
		// fmt.Sscanf("%d") would silently accept.
		parsed, err := strconv.ParseInt(raw, 10, 64)
		if err != nil {
			return record(redis.NewIntResult(0, fmt.Errorf("ERR value is not an integer or out of range")))
		}
		current = parsed
	}

	next := current + 1
	if next < current {
		// Signed overflow wrapped around; Redis refuses rather than wrapping.
		return record(redis.NewIntResult(0, fmt.Errorf("ERR increment or decrement would overflow")))
	}

	p.client.values[key] = strconv.FormatInt(next, 10)
	return record(redis.NewIntResult(next, nil))
}
|
|
|
|
// Expire mimics Redis EXPIRE against the owning client.
//
// The TTL is recorded, and true is returned, only when the key currently
// holds a value. For a missing key nothing is recorded and the result is
// false, matching Redis's 0 reply. Non-positive TTLs are recorded as given
// rather than deleting the key. The command is recorded for Exec.
func (p *fakeRedisPipeline) Expire(ctx context.Context, key string, expiration time.Duration) *redis.BoolCmd {
	_, exists := p.client.values[key]
	if exists {
		p.client.expirations[key] = expiration
	}
	cmd := redis.NewBoolResult(exists, nil)
	p.cmds = append(p.cmds, cmd)
	return cmd
}
|
|
|
|
// Exec hands back every command issued on the pipeline, in call order. Each
// command was already applied when it was issued.
//
// Like go-redis Pipeline.Exec, it returns the error of the first failed
// command, if any, alongside the full command list. It also clears the
// queue, so a later Exec only sees commands issued after this call.
func (p *fakeRedisPipeline) Exec(ctx context.Context) ([]redis.Cmder, error) {
	cmds := p.cmds
	p.cmds = nil
	for _, cmd := range cmds {
		if err := cmd.Err(); err != nil {
			return cmds, err
		}
	}
	return cmds, nil
}
|
|
|
|
// stringifyRedisValue renders a command argument the way go-redis encodes
// it on the wire, so values stored or published through the fake match what
// a real server would hold:
//   - strings and byte slices are kept verbatim
//   - nil becomes ""
//   - booleans become "1" or "0"
//   - time.Time is formatted as RFC 3339 with nanoseconds
//   - time.Duration becomes its integer nanosecond count
//   - floats use plain decimal notation (no exponent)
//
// Any other type falls back to fmt.Sprint, which agrees with go-redis for
// integer kinds.
func stringifyRedisValue(value interface{}) string {
	switch v := value.(type) {
	case nil:
		return ""
	case string:
		return v
	case []byte:
		return string(v)
	case bool:
		if v {
			return "1"
		}
		return "0"
	case time.Time:
		return v.Format(time.RFC3339Nano)
	case time.Duration:
		// Checked before the default case: fmt.Sprint would render "5s".
		return strconv.FormatInt(v.Nanoseconds(), 10)
	case float32:
		return strconv.FormatFloat(float64(v), 'f', -1, 64)
	case float64:
		return strconv.FormatFloat(v, 'f', -1, 64)
	default:
		return fmt.Sprint(v)
	}
}
|
|
|
|
// TestRedisCacheTokenStatusAndSessionRoundTrip writes a token status and a
// session through the cache and reads both back. It checks every populated
// field and the TTL the cache passed down to Redis.
func TestRedisCacheTokenStatusAndSessionRoundTrip(t *testing.T) {
	ctx := context.Background()
	client := newFakeRedisClient()
	cache := newRedisCacheWithClient(client)

	status := &TokenStatus{
		TokenID:   "tok-1",
		SubjectID: "sub-1",
		Role:      "owner",
		Status:    "active",
		ExpiresAt: 1710000000,
	}
	if err := cache.SetTokenStatus(ctx, status, 5*time.Minute); err != nil {
		t.Fatalf("SetTokenStatus() error = %v", err)
	}

	gotStatus, err := cache.GetTokenStatus(ctx, "tok-1")
	if err != nil {
		t.Fatalf("GetTokenStatus() error = %v", err)
	}
	// Compare every field populated above, not only the identifiers, so a
	// field that is dropped or mis-tagged during serialization is caught.
	if gotStatus == nil ||
		gotStatus.TokenID != status.TokenID ||
		gotStatus.SubjectID != status.SubjectID ||
		gotStatus.Role != status.Role ||
		gotStatus.Status != status.Status ||
		gotStatus.ExpiresAt != status.ExpiresAt {
		t.Fatalf("unexpected token status: %#v", gotStatus)
	}
	if client.expirations["token:status:tok-1"] != 5*time.Minute {
		t.Fatalf("token status ttl = %s, want %s", client.expirations["token:status:tok-1"], 5*time.Minute)
	}

	session := &SessionData{
		UserID:    7,
		TenantID:  9,
		Role:      "viewer",
		CreatedAt: 1710000001,
	}
	if err := cache.SetSession(ctx, "sess-1", session, time.Hour); err != nil {
		t.Fatalf("SetSession() error = %v", err)
	}

	gotSession, err := cache.GetSession(ctx, "sess-1")
	if err != nil {
		t.Fatalf("GetSession() error = %v", err)
	}
	if gotSession == nil ||
		gotSession.UserID != session.UserID ||
		gotSession.TenantID != session.TenantID ||
		gotSession.Role != session.Role ||
		gotSession.CreatedAt != session.CreatedAt {
		t.Fatalf("unexpected session: %#v", gotSession)
	}
	if client.expirations["session:sess-1"] != time.Hour {
		t.Fatalf("session ttl = %s, want %s", client.expirations["session:sess-1"], time.Hour)
	}
}
|
|
|
|
// TestRedisCacheMissingValueSemantics checks that lookups against an empty
// store report absence through zero values, not errors: a nil token status,
// an empty idempotency value, and a zero rate-limit count.
func TestRedisCacheMissingValueSemantics(t *testing.T) {
	cache := newRedisCacheWithClient(newFakeRedisClient())
	ctx := context.Background()

	status, err := cache.GetTokenStatus(ctx, "missing-token")
	switch {
	case err != nil:
		t.Fatalf("GetTokenStatus() error = %v", err)
	case status != nil:
		t.Fatalf("expected nil token status, got %#v", status)
	}

	idem, err := cache.GetIdempotency(ctx, "missing-key")
	switch {
	case err != nil:
		t.Fatalf("GetIdempotency() error = %v", err)
	case idem != "":
		t.Fatalf("expected empty idempotency value, got %q", idem)
	}

	limitKey := &RateLimitKey{TenantID: 1, Route: "/v1/chat", LimitType: "rpm"}
	hits, err := cache.GetRateLimit(ctx, limitKey, time.Minute)
	switch {
	case err != nil:
		t.Fatalf("GetRateLimit() error = %v", err)
	case hits != 0:
		t.Fatalf("rate limit count = %d, want 0", hits)
	}
}
|
|
|
|
func TestRedisCacheRateLimitAndLockSemantics(t *testing.T) {
|
|
ctx := context.Background()
|
|
client := newFakeRedisClient()
|
|
cache := newRedisCacheWithClient(client)
|
|
key := &RateLimitKey{TenantID: 42, Route: "/v1/chat", LimitType: "rpm"}
|
|
|
|
count, err := cache.IncrRateLimit(ctx, key, time.Minute)
|
|
if err != nil {
|
|
t.Fatalf("IncrRateLimit() error = %v", err)
|
|
}
|
|
if count != 1 {
|
|
t.Fatalf("first count = %d, want 1", count)
|
|
}
|
|
|
|
allowed, count, err := cache.CheckRateLimit(ctx, key, 2, time.Minute)
|
|
if err != nil {
|
|
t.Fatalf("CheckRateLimit() error = %v", err)
|
|
}
|
|
if !allowed || count != 2 {
|
|
t.Fatalf("second check = (%v, %d), want (true, 2)", allowed, count)
|
|
}
|
|
|
|
allowed, count, err = cache.CheckRateLimit(ctx, key, 2, time.Minute)
|
|
if err != nil {
|
|
t.Fatalf("CheckRateLimit() error = %v", err)
|
|
}
|
|
if allowed || count != 3 {
|
|
t.Fatalf("third check = (%v, %d), want (false, 3)", allowed, count)
|
|
}
|
|
|
|
if client.expirations["ratelimit:42:/v1/chat:rpm"] != time.Minute {
|
|
t.Fatalf("rate limit ttl = %s, want %s", client.expirations["ratelimit:42:/v1/chat:rpm"], time.Minute)
|
|
}
|
|
|
|
acquired, err := cache.AcquireLock(ctx, "sync-job", 30*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("AcquireLock() error = %v", err)
|
|
}
|
|
if !acquired {
|
|
t.Fatal("first AcquireLock() should succeed")
|
|
}
|
|
|
|
acquired, err = cache.AcquireLock(ctx, "sync-job", 30*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("AcquireLock() error = %v", err)
|
|
}
|
|
if acquired {
|
|
t.Fatal("second AcquireLock() should fail while lock exists")
|
|
}
|
|
|
|
if err := cache.ReleaseLock(ctx, "sync-job"); err != nil {
|
|
t.Fatalf("ReleaseLock() error = %v", err)
|
|
}
|
|
if _, exists := client.values["lock:sync-job"]; exists {
|
|
t.Fatal("lock key should be deleted after ReleaseLock()")
|
|
}
|
|
}
|
|
|
|
// TestRedisCachePublishTokenRevokedMarshalsEvent checks that publishing a
// revocation produces exactly one message on the "token:revoked" channel.
// It also checks that the message is JSON that decodes back to the same
// event.
func TestRedisCachePublishTokenRevokedMarshalsEvent(t *testing.T) {
	const wantChannel = "token:revoked"

	ctx := context.Background()
	client := newFakeRedisClient()
	cache := newRedisCacheWithClient(client)

	want := &TokenRevokedCacheEvent{
		TokenID:   "tok-9",
		RevokedAt: time.Unix(1710001111, 0).UTC(),
		Reason:    "manual revoke",
	}
	if err := cache.PublishTokenRevoked(ctx, want); err != nil {
		t.Fatalf("PublishTokenRevoked() error = %v", err)
	}

	if n := len(client.published); n != 1 {
		t.Fatalf("published message count = %d, want 1", n)
	}
	msg := client.published[0]
	if msg.channel != wantChannel {
		t.Fatalf("published channel = %q, want %q", msg.channel, wantChannel)
	}

	var decoded TokenRevokedCacheEvent
	if err := json.Unmarshal([]byte(msg.payload), &decoded); err != nil {
		t.Fatalf("Unmarshal() error = %v", err)
	}
	// time.Time must be compared with Equal, not ==, after a JSON round trip.
	sameEvent := decoded.TokenID == want.TokenID &&
		decoded.Reason == want.Reason &&
		decoded.RevokedAt.Equal(want.RevokedAt)
	if !sameEvent {
		t.Fatalf("unexpected event payload: %#v", decoded)
	}
}
|