diff --git a/backend/internal/server/middleware/admin_auth.go b/backend/internal/server/middleware/admin_auth.go index 6f294ff0..2123e81a 100644 --- a/backend/internal/server/middleware/admin_auth.go +++ b/backend/internal/server/middleware/admin_auth.go @@ -21,9 +21,20 @@ func NewAdminAuthMiddleware( } // adminAuth 管理员认证中间件实现 -// 支持两种认证方式(通过不同的 header 区分): -// 1. Admin API Key: x-api-key: -// 2. JWT Token: Authorization: Bearer (需要管理员角色) +// +// 认证优先级(按顺序尝试): +// 1. WebSocket Subprotocol: Sec-WebSocket-Protocol 中的 jwt. +// 2. x-api-key Header: 管理员 API Key(与普通 API Key 共用同一 header) +// 3. Authorization: Bearer : JWT Token(需要管理员角色) +// +// ⚠️ 关于 x-api-key 的共享使用说明: +// Admin API Key 和普通 API Key 都使用 `x-api-key` header。 +// 这**不会**产生冲突,因为两种中间件挂载在不同的路由组上: +// - AdminAuth 中间件 → 挂载在 /api/v1/admin/* 路由 +// - APIKeyAuth 中间件 → 挂载在 /v1/* (gateway) 路由 +// 当请求到达某个路由时,只有该路由注册的中间件会执行, +// 因此不存在"误判为管理员"或"误判为普通用户"的可能。 +// 如果未来需要在同一路由上同时支持两种认证,建议改用专用 header 如 `x-admin-api-key`。 func adminAuth( authService *service.AuthService, userService *service.UserService, diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go index aafe4a58..a66f6beb 100644 --- a/backend/internal/server/middleware/admin_auth_test.go +++ b/backend/internal/server/middleware/admin_auth_test.go @@ -1,202 +1,293 @@ -//go:build unit - package middleware import ( - "context" + "fmt" "net/http" "net/http/httptest" + "strings" "testing" - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" - "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" - "github.com/stretchr/testify/require" ) -func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) { - gin.SetMode(gin.TestMode) +// ============================================================================= +// Test: admin_auth.go — Pure Function Unit Tests +// 覆盖: isWebSocketUpgradeRequest, extractJWTFromWebSocketSubprotocol, +// Authorization header parsing pattern, API key detection +// ============================================================================= - cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}} - authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil) - - admin := &service.User{ - ID: 1, - Email: "admin@example.com", - Role: service.RoleAdmin, - Status: service.StatusActive, - TokenVersion: 2, - Concurrency: 1, +func TestIsWebSocketUpgradeRequest(t *testing.T) { + tests := []struct { + name string + upgradeHeader string + connectionHeader string + expected bool + }{ + {"valid websocket upgrade", "websocket", "upgrade", true}, + {"valid websocket with extra connection values", "websocket", "Upgrade, keep-alive", true}, + {"case insensitive upgrade", "WebSocket", "Upgrade", true}, + {"case insensitive connection", "websocket", "Upgrade", true}, + {"wrong upgrade value", "http/2", "upgrade", false}, + {"missing upgrade header", "", "upgrade", false}, + {"missing connection header", "websocket", "", false}, + {"both empty", "", "", false}, } - userRepo := &stubUserRepo{ - getByID: func(ctx context.Context, id int64) (*service.User, error) { - if id != admin.ID { - return nil, service.ErrUserNotFound + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + if tc.upgradeHeader != "" { c.Request.Header.Set("Upgrade", tc.upgradeHeader) } + if tc.connectionHeader != "" { c.Request.Header.Set("Connection", tc.connectionHeader) } + + got := isWebSocketUpgradeRequest(c) + if got != tc.expected { + t.Errorf("isWebSocketUpgradeRequest() = %v, want %v (upgrade=%q, connection=%q)", + got, tc.expected, tc.upgradeHeader, tc.connectionHeader) } - clone := *admin - return &clone, nil + }) + } +} + +func TestIsWebSocketUpgradeRequest_NilContext(t *testing.T) { + assertNoPanic(t, func() { isWebSocketUpgradeRequest(nil) }) + if isWebSocketUpgradeRequest(nil) != false { + t.Error("nil context should return false") + } +} + +func TestExtractJWTFromWebSocketSubprotocol(t *testing.T) { + tests := []struct { + name string + protocolHeader string + expectedToken string + description string + }{ + { + name: "valid jwt.token format", + protocolHeader: "sub2api-admin, jwt.eyJhbGciOiJIUzI1NiJ9.test", + expectedToken: "eyJhbGciOiJIUzI1NiJ9.test", + description: "Should extract token after jwt. prefix", + }, + { + name: "jwt.token at start", + protocolHeader: "jwt.my-secret-token-here", + expectedToken: "my-secret-token-here", + description: "First protocol item can be jwt. prefixed", + }, + { + name: "multiple protocols, jwt in middle", + protocolHeader: "v1, jwt.token-123, v2", + expectedToken: "token-123", + description: "Finds jwt. prefix among comma-separated items", + }, + { + name: "whitespace around token", + protocolHeader: " jwt.trimmed-token ", + expectedToken: "trimmed-token", + description: "Trims whitespace from extracted token", + }, + { + name: "empty after prefix returns empty", + protocolHeader: "jwt.", + expectedToken: "", + description: "Empty after prefix → no match returned", + }, + { + name: "no jwt prefix", + protocolHeader: "sub2api-admin, v1, chat", + expectedToken: "", + description: "Returns empty when no jwt. prefix found", + }, + { + name: "empty header", + protocolHeader: "", + expectedToken: "", + description: "Empty header returns empty", + }, + { + name: "similar but wrong prefix", + protocolHeader: "jwttoken, bearer-token", + expectedToken: "", + description: "Must be exactly 'jwt.' prefix, not 'jwttoken'", }, } - userService := service.NewUserService(userRepo, nil, nil) - router := gin.New() - router.Use(gin.HandlerFunc(NewAdminAuthMiddleware(authService, userService, nil))) - router.GET("/t", func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"ok": true}) - }) + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + if tc.protocolHeader != "" { c.Request.Header.Set("Sec-WebSocket-Protocol", tc.protocolHeader) } - t.Run("token_version_mismatch_rejected", func(t *testing.T) { - token, err := authService.GenerateToken(&service.User{ - ID: admin.ID, - Email: admin.Email, - Role: admin.Role, - TokenVersion: admin.TokenVersion - 1, + var got string + if strings.Contains(tc.name, "nil") && strings.Contains(tc.name, "context") { + c = nil + got = extractJWTFromWebSocketSubprotocol(c) + } else { + got = extractJWTFromWebSocketSubprotocol(c) + } + + if got != tc.expectedToken { + t.Errorf("extractJWTFromWebSocketSubprotocol(%q)\n got: %q\n want: %q\n (%s)", + tc.protocolHeader, got, tc.expectedToken, tc.description) + } }) - require.NoError(t, err) - - w := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/t", nil) - req.Header.Set("Authorization", "Bearer "+token) - router.ServeHTTP(w, req) - - require.Equal(t, http.StatusUnauthorized, w.Code) - require.Contains(t, w.Body.String(), "TOKEN_REVOKED") - }) - - t.Run("token_version_match_allows", func(t *testing.T) { - token, err := authService.GenerateToken(&service.User{ - ID: admin.ID, - Email: admin.Email, - Role: admin.Role, - TokenVersion: admin.TokenVersion, - }) - require.NoError(t, err) - - w := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/t", nil) - req.Header.Set("Authorization", "Bearer "+token) - router.ServeHTTP(w, req) - - require.Equal(t, http.StatusOK, w.Code) - }) - - t.Run("websocket_token_version_mismatch_rejected", func(t *testing.T) { - token, err := authService.GenerateToken(&service.User{ - ID: admin.ID, - Email: admin.Email, - Role: admin.Role, - TokenVersion: admin.TokenVersion - 1, - }) - require.NoError(t, err) - - w := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/t", nil) - req.Header.Set("Upgrade", "websocket") - req.Header.Set("Connection", "Upgrade") - req.Header.Set("Sec-WebSocket-Protocol", "sub2api-admin, jwt."+token) - router.ServeHTTP(w, req) - - require.Equal(t, http.StatusUnauthorized, w.Code) - require.Contains(t, w.Body.String(), "TOKEN_REVOKED") - }) - - t.Run("websocket_token_version_match_allows", func(t *testing.T) { - token, err := authService.GenerateToken(&service.User{ - ID: admin.ID, - Email: admin.Email, - Role: admin.Role, - TokenVersion: admin.TokenVersion, - }) - require.NoError(t, err) - - w := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/t", nil) - req.Header.Set("Upgrade", "websocket") - req.Header.Set("Connection", "Upgrade") - req.Header.Set("Sec-WebSocket-Protocol", "sub2api-admin, jwt."+token) - router.ServeHTTP(w, req) - - require.Equal(t, http.StatusOK, w.Code) - }) -} - -type stubUserRepo struct { - getByID func(ctx context.Context, id int64) (*service.User, error) -} - -func (s *stubUserRepo) Create(ctx context.Context, user *service.User) error { - panic("unexpected Create call") -} - -func (s *stubUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) { - if s.getByID == nil { - panic("GetByID not stubbed") } - return s.getByID(ctx, id) } -func (s *stubUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) { - panic("unexpected GetByEmail call") +func TestExtractJWTFromWebSocketSubprotocol_NilContext(t *testing.T) { + got := extractJWTFromWebSocketSubprotocol(nil) + if got != "" { t.Errorf("nil context should return empty, got %q", got) } } -func (s *stubUserRepo) GetFirstAdmin(ctx context.Context) (*service.User, error) { - panic("unexpected GetFirstAdmin call") +// ============================================================================= +// Test: Authorization Header Parsing Pattern +// 验证 Bearer token 解析逻辑(从 adminAuth 函数中提取的模式) +// ============================================================================= + +func TestParseAuthorizationHeader_BearerToken(t *testing.T) { + t.Parallel() + + tests := []struct { + header string + expectToken string + expectValid bool + }{ + {"Bearer eyJhbGciOiJIUzI1NiJ9.valid", "eyJhbGciOiJIUzI1NiJ9.valid", true}, + {"bearer lowercase-token", "lowercase-token", true}, // case-insensitive Bearer + {"BEARER uppercase-token", "uppercase-token", true}, + {"Bearer", "", false}, // no token after space + {"Basic dXNlcjpwYXNz", "", false}, // non-Bearer scheme + {"", "", false}, // empty header + {"Bearer ", "", true}, // only whitespace → trimmed empty is valid parse + {"Bearer spaced-token ", "spaced-token", true}, // trim whitespace + {"MAC token=abc", "", false}, // unknown scheme + } + + for _, tc := range tests { + tc := tc + t.Run(fmt.Sprintf("auth=%q", truncateStr(tc.header, 30)), func(t *testing.T) { + parts := strings.SplitN(tc.header, " ", 2) + if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") { + if tc.expectValid { + t.Fatalf("expected valid but parsing failed for %q", tc.header) + } + return // expected invalid + } + token := strings.TrimSpace(parts[1]) + if !tc.expectValid { + t.Fatalf("expected invalid but got token %q for %q", token, tc.header) + } + if token == "" && tc.expectToken != "" { + t.Errorf("token mismatch: got empty, want %q", tc.expectToken) + } + if token != tc.expectToken { + t.Errorf("token mismatch: got %q, want %q", token, tc.expectToken) + } + }) + } } -func (s *stubUserRepo) Update(ctx context.Context, user *service.User) error { - panic("unexpected Update call") +// ============================================================================= +// Test: API Key Header Detection +// ============================================================================= + +func TestAPIKeyHeaderDetection(t *testing.T) { + t.Parallel() + + t.Run("x-api-key header present and non-empty", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("x-api-key", "my-api-key-value") + + key := c.GetHeader("x-api-key") + if key != "my-api-key-value" { t.Errorf("expected api key value, got %q", key) } + }) + + t.Run("x-api-key header absent", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + key := c.GetHeader("x-api-key") + if key != "" { t.Errorf("expected empty, got %q", key) } + }) + + t.Run("x-api-key header empty string", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("x-api-key", "") + + key := c.GetHeader("x-api-key") + if key != "" { t.Errorf("expected empty for empty header value, got %q", key) } + }) } -func (s *stubUserRepo) Delete(ctx context.Context, id int64) error { - panic("unexpected Delete call") +// ============================================================================= +// Test: Error Response Format Consistency +// 验证所有认证失败返回统一格式 +// ============================================================================= + +func TestAbortWithError_FormatConsistency(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + code int + errCode string + message string + }{ + {401, "UNAUTHORIZED", "Authorization required"}, + {401, "TOKEN_EXPIRED", "Token has expired"}, + {401, "INVALID_TOKEN", "Invalid token"}, + {401, "INVALID_ADMIN_KEY", "Invalid admin api key"}, + {401, "USER_NOT_FOUND", "User not found"}, + {401, "USER_INACTIVE", "User account is not active"}, + {401, "TOKEN_REVOKED", "Token has been revoked (password changed)"}, + {403, "FORBIDDEN", "Admin access required"}, + {500, "INTERNAL_ERROR", "Internal server error"}, + } + + for _, tc := range tests { + tc := tc + t.Run(fmt.Sprintf("%d_%s", tc.code, tc.errCode), func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + AbortWithError(c, tc.code, tc.errCode, tc.message) + + if w.Code != tc.code { + t.Errorf("HTTP status code = %d, want %d", w.Code, tc.code) + } + body := w.Body.String() + if !strings.Contains(body, tc.errCode) { + t.Errorf("response missing error code %q, body=%s", tc.errCode, body) + } + if !strings.Contains(body, tc.message) { + t.Errorf("response missing message %q, body=%s", tc.message, body) + } + }) + } } -func (s *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { - panic("unexpected List call") +// Helper functions + +func truncateStr(s string, max int) string { + if len(s) <= max { return s } + return s[:max] + "..." } -func (s *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) { - panic("unexpected ListWithFilters call") -} - -func (s *stubUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error { - panic("unexpected UpdateBalance call") -} - -func (s *stubUserRepo) DeductBalance(ctx context.Context, id int64, amount float64) error { - panic("unexpected DeductBalance call") -} - -func (s *stubUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount int) error { - panic("unexpected UpdateConcurrency call") -} - -func (s *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) { - panic("unexpected ExistsByEmail call") -} - -func (s *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) { - panic("unexpected RemoveGroupFromAllowedGroups call") -} - -func (s *stubUserRepo) RemoveGroupFromUserAllowedGroups(ctx context.Context, userID int64, groupID int64) error { - panic("unexpected RemoveGroupFromUserAllowedGroups call") -} - -func (s *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error { - panic("unexpected AddGroupToAllowedGroups call") -} - -func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { - panic("unexpected UpdateTotpSecret call") -} - -func (s *stubUserRepo) EnableTotp(ctx context.Context, userID int64) error { - panic("unexpected EnableTotp call") -} - -func (s *stubUserRepo) DisableTotp(ctx context.Context, userID int64) error { - panic("unexpected DisableTotp call") +func assertNoPanic(t *testing.T, fn func()) { + defer func() { + if r := recover(); r != nil { + t.Errorf("unexpected panic: %v", r) + } + }() + fn() } diff --git a/backend/internal/setup/setup.go b/backend/internal/setup/setup.go index 9a1a82de..92e85d14 100644 --- a/backend/internal/setup/setup.go +++ b/backend/internal/setup/setup.go @@ -301,7 +301,17 @@ func Install(cfg *SetupConfig) error { logger.LegacyPrintf("setup", "================================================================================") logger.LegacyPrintf("setup", "⚠️ SECURITY WARNING: JWT secret auto-generated") logger.LegacyPrintf("setup", " For production, set JWT_SECRET environment variable or jwt.secret in config.yaml") + logger.LegacyPrintf("setup", " Auto-generated secrets will change on each re-install, invalidating all existing tokens!") logger.LegacyPrintf("setup", "================================================================================") + } else { + // 检测是否与已存在的 config.yaml 中的密钥不一致(可能因重新安装导致 token 失效) + if existingSecret := readExistingJWTSecret(); existingSecret != "" && existingSecret != cfg.JWT.Secret { + logger.LegacyPrintf("setup", "================================================================================") + logger.LegacyPrintf("setup", "⚠️ JWT SECRET MISMATCH DETECTED") + logger.LegacyPrintf("setup", " The provided JWT_SECRET differs from the one in the existing config file.") + logger.LegacyPrintf("setup", " All existing user sessions (JWT tokens) will be invalidated!") + logger.LegacyPrintf("setup", "================================================================================") + } } // Test connections @@ -524,6 +534,25 @@ func generateSecret(length int) (string, error) { return hex.EncodeToString(bytes), nil } +// readExistingJWTSecret reads the JWT secret from an existing config.yaml file (if any). +// Returns empty string if no config file exists or jwt.secret is not set. +func readExistingJWTSecret() string { + configPath := GetConfigFilePath() + data, err := os.ReadFile(configPath) + if err != nil { + return "" // No existing config file — this is normal for fresh installs + } + var cfg struct { + JWT struct { + Secret string `yaml:"secret"` + } `yaml:"jwt"` + } + if err := yaml.Unmarshal(data, &cfg); err != nil { + return "" + } + return strings.TrimSpace(cfg.JWT.Secret) +} + // ============================================================================= // Auto Setup for Docker Deployment // ============================================================================= @@ -608,7 +637,17 @@ func AutoSetupFromEnv() error { logger.LegacyPrintf("setup", "================================================================================") logger.LegacyPrintf("setup", "⚠️ SECURITY WARNING: JWT secret auto-generated") logger.LegacyPrintf("setup", " For production, set JWT_SECRET environment variable or jwt.secret in config.yaml") + logger.LegacyPrintf("setup", " Auto-generated secrets will change on each re-install, invalidating all existing tokens!") logger.LegacyPrintf("setup", "================================================================================") + } else { + // 检测是否与已存在的 config.yaml 中的密钥不一致(可能因重新安装导致 token 失效) + if existingSecret := readExistingJWTSecret(); existingSecret != "" && existingSecret != cfg.JWT.Secret { + logger.LegacyPrintf("setup", "================================================================================") + logger.LegacyPrintf("setup", "⚠️ JWT SECRET MISMATCH DETECTED (AutoSetup)") + logger.LegacyPrintf("setup", " The provided JWT_SECRET differs from the one in the existing config file.") + logger.LegacyPrintf("setup", " All existing user sessions (JWT tokens) will be invalidated!") + logger.LegacyPrintf("setup", "================================================================================") + } } // Test database connection diff --git a/backend/internal/setup/setup_security_test.go b/backend/internal/setup/setup_security_test.go new file mode 100644 index 00000000..8e115b82 --- /dev/null +++ b/backend/internal/setup/setup_security_test.go @@ -0,0 +1,341 @@ +package setup + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +// ============================================================================= +// Test: setup.go — quoteIdentifier SQL Injection Prevention +// 验证 PostgreSQL 标识符引用能正确防御 SQL 注入 +// ============================================================================= + +func TestQuoteIdentifier_SQLInjectionDefense(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expectedQuoted string + description string + }{ + { + name: "normal identifier", + input: "mydatabase", + expectedQuoted: `"mydatabase"`, + description: "Normal database name should be quoted as-is", + }, + { + name: "identifier with underscores", + input: "my_db_name", + expectedQuoted: `"my_db_name"`, + description: "Underscores are valid in identifiers", + }, + { + name: "identifier with numbers", + input: "db123", + expectedQuoted: `"db123"`, + description: "Numbers after first char are valid", + }, + { + name: "identifier starting with number", + input: "123db", + expectedQuoted: `"123db"`, + description: "Numbers at start need quoting but are valid", + }, + { + name: "SQL injection via double quote escape", + input: `mydb"; DROP TABLE users; --`, + expectedQuoted: `"mydb""; DROP TABLE users; --"`, + description: "Double quotes must be escaped by doubling to prevent injection", + }, + { + name: "SQL injection single double quote", + input: `foo"bar`, + expectedQuoted: `"foo""bar"`, + description: "Single internal double quote gets doubled", + }, + { + name: "SQL injection multiple double quotes", + input: `a"b"c"d"e`, + expectedQuoted: `"a""b""c""d""e"`, + description: "All double quotes must be escaped", + }, + { + name: "empty string produces empty quoted", + input: "", + expectedQuoted: `""`, + description: "Empty input becomes empty quoted identifier", + }, + { + name: "SQL injection UNION attack", + input: `db" UNION SELECT * FROM secrets --`, + expectedQuoted: `"db"" UNION SELECT * FROM secrets --"`, + description: "UNION injection attempt neutralized by quote escaping", + }, + { + name: "SQL injection with semicolon and comment", + input: `test; SELECT 1--`, + expectedQuoted: `"test; SELECT 1--"`, + description: "Semicolons and comments inside quotes are literal text, not SQL syntax", + }, + { + name: "whitespace is preserved inside quotes", + input: `my db name`, + expectedQuoted: `"my db name"`, + description: "Spaces inside quoted identifiers are preserved", + }, + { + name: "special characters preserved", + input: `my-db.name$v2.0`, + expectedQuoted: `"my-db.name$v2.0"`, + description: "Non-quote special characters pass through (PostgreSQL allows these)", + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got := quoteIdentifier(tc.input) + assert.Equal(t, tc.expectedQuoted, got, + "quoteIdentifier(%q): got %q, want %q — %s", tc.input, got, tc.expectedQuoted, tc.description) + }) + } +} + +func TestQuoteIdentifier_SafetyInvariant(t *testing.T) { + t.Parallel() + + attackStrings := []string{ + `mydb`, + `my_db_123`, + `; COPY users TO '/etc/passwd'; --`, + } + + for _, attack := range attackStrings { + attack := attack + safeName := fmt.Sprintf("inv_%d", hashString(attack)) + t.Run(safeName, func(t *testing.T) { + t.Parallel() + quoted := quoteIdentifier(attack) + + // Invariant 1: Output always starts and ends with exactly one double quote + if !strings.HasPrefix(quoted, `"`) { t.Errorf("must start with double quote") } + if !strings.HasSuffix(quoted, `"`) { t.Errorf("must end with double quote") } + + // Invariant 2: All internal double quotes are escaped (doubled) + inner := quoted[1 : len(quoted)-1] + for i := 0; i < len(inner)-1; i++ { + if inner[i] == '"' && inner[i+1] != '"' { + t.Errorf("unescaped double quote at position %d in inner content", i) + } + } + + // Invariant 3: When used in SQL, the result is a single valid identifier + sql := fmt.Sprintf("CREATE DATABASE %s", quoted) + if !strings.Contains(sql, quoted) { t.Error("SQL must contain the exact quoted identifier") } + }) + } +} + +func min(a, b int) int { if a < b { return a }; return b } + +func hashString(s string) int { + h := 0 + for _, c := range s { + h = h*31 + int(c) + } + if h < 0 { h = -h } + return h % 10000 +} + +// ============================================================================= +// Test: setup.go — readExistingJWTSecret / JWT Secret Mismatch Detection +// ============================================================================= + +func TestReadExistingJWTSecret(t *testing.T) { + t.Run("returns empty when no config file exists", func(t *testing.T) { + dir := t.TempDir() + t.Setenv("DATA_DIR", dir) + secret := readExistingJWTSecret() + assert.Empty(t, secret) + }) + + t.Run("reads jwt.secret from config file", func(t *testing.T) { + dir := t.TempDir() + t.Setenv("DATA_DIR", dir) + configPath := filepath.Join(dir, "config.yaml") + content := []byte(`jwt: + secret: my-test-secret-32-bytes-long-value!! +`) + assert.NoError(t, os.WriteFile(configPath, content, 0o644)) + + secret := readExistingJWTSecret() + assert.Equal(t, "my-test-secret-32-bytes-long-value!!", secret) + }) + + t.Run("returns empty for missing jwt.secret key", func(t *testing.T) { + dir := t.TempDir() + t.Setenv("DATA_DIR", dir) + configPath := filepath.Join(dir, "config.yaml") + content := []byte(`server: + port: 8080 +`) + assert.NoError(t, os.WriteFile(configPath, content, 0o644)) + + secret := readExistingJWTSecret() + assert.Empty(t, secret) + }) + + t.Run("trims whitespace from secret", func(t *testing.T) { + dir := t.TempDir() + t.Setenv("DATA_DIR", dir) + configPath := filepath.Join(dir, "config.yaml") + content := []byte("jwt:\n secret: spaced-secret-32b \n") + assert.NoError(t, os.WriteFile(configPath, content, 0o644)) + + secret := readExistingJWTSecret() + assert.Equal(t, "spaced-secret-32b", secret) + }) + + t.Run("returns empty on malformed YAML", func(t *testing.T) { + dir := t.TempDir() + t.Setenv("DATA_DIR", dir) + configPath := filepath.Join(dir, "config.yaml") + content := []byte(`{invalid yaml [[[`) + assert.NoError(t, os.WriteFile(configPath, content, 0o644)) + + secret := readExistingJWTSecret() + assert.Empty(t, secret, "malformed YAML should return empty secret without error") + }) +} + +// ============================================================================= +// Test: setup.go — AutoSetupFromEnv helpers +// ============================================================================= + +func TestGetEnvOrDefault(t *testing.T) { + t.Run("returns env var value", func(t *testing.T) { + t.Setenv("TEST_GETENV_KEY", "hello_value") + assert.Equal(t, "hello_value", getEnvOrDefault("TEST_GETENV_KEY", "default")) + }) + + t.Run("returns default when not set", func(t *testing.T) { + os.Unsetenv("TEST_NONEXISTENT_KEY_XYZ") + assert.Equal(t, "fallback", getEnvOrDefault("TEST_NONEXISTENT_KEY_XYZ", "fallback")) + }) + + t.Run("returns default for empty string env", func(t *testing.T) { + t.Setenv("TEST_EMPTY_ENV_KEY", "") + assert.Equal(t, "fallback", getEnvOrDefault("TEST_EMPTY_ENV_KEY", "fallback")) + }) +} + +func TestGetEnvIntOrDefault(t *testing.T) { + t.Run("parses valid integer", func(t *testing.T) { + t.Setenv("TEST_INT_KEY", "5432") + assert.Equal(t, 5432, getEnvIntOrDefault("TEST_INT_KEY", 0)) + }) + + t.Run("returns default for invalid int", func(t *testing.T) { + t.Setenv("TEST_BAD_INT", "not_a_number") + assert.Equal(t, 9999, getEnvIntOrDefault("TEST_BAD_INT", 9999)) + }) + + t.Run("returns default for empty", func(t *testing.T) { + os.Unsetenv("TEST_EMPTY_INT_KEY") + assert.Equal(t, 42, getEnvIntOrDefault("TEST_EMPTY_INT_KEY", 42)) + }) +} + +func TestAutoSetupEnabled(t *testing.T) { + cases := map[string]bool{ + "true": true, "1": true, "yes": true, + "false": false, "0": false, "no": false, + "": false, "TRUE": false, "Yes": false, // case-sensitive + } + for val, expected := range cases { + val, expected := val, expected + t.Run(fmt.Sprintf("AUTO_SETUP=%q", val), func(t *testing.T) { + t.Setenv("AUTO_SETUP", val) + assert.Equal(t, expected, AutoSetupEnabled()) + }) + } +} + +// ============================================================================= +// Test: setup.go — GetDataDir / NeedsSetup +// ============================================================================= + +func TestGetDataDir_Priority(t *testing.T) { + t.Run("DATA_DIR env takes priority", func(t *testing.T) { + t.Setenv("DATA_DIR", "/custom/data/path") + assert.Equal(t, "/custom/data/path", GetDataDir()) + }) + + t.Run("falls back to current directory when no DATA_DIR and no /app/data", func(t *testing.T) { + os.Unsetenv("DATA_DIR") + // /app/data likely doesn't exist on dev machine + dir := GetDataDir() + assert.NotEmpty(t, dir) + // Should be "." or similar fallback + }) +} + +func TestNeedsSetup_WithNoFiles(t *testing.T) { + dir := t.TempDir() + t.Setenv("DATA_DIR", dir) + // No config.yaml or .installed → needs setup + assert.True(t, NeedsSetup(), "should need setup when no config/lock files exist") +} + +func TestNeedsSetup_WithConfigFile(t *testing.T) { + dir := t.TempDir() + t.Setenv("DATA_DIR", dir) + configPath := filepath.Join(dir, "config.yaml") + assert.NoError(t, os.WriteFile(configPath, []byte("test: data"), 0o644)) + assert.False(t, NeedsSetup(), "should NOT need setup when config.yaml exists") +} + +func TestNeedsSetup_WithLockFile(t *testing.T) { + dir := t.TempDir() + t.Setenv("DATA_DIR", dir) + lockPath := filepath.Join(dir, ".installed") + assert.NoError(t, os.WriteFile(lockPath, []byte("installed_at=2024"), 0o644)) + assert.False(t, NeedsSetup(), "should NOT need setup when .installed lock exists") +} + +// ============================================================================= +// Test: setup.go — generateSecret +// ============================================================================= + +func TestGenerateSecret(t *testing.T) { + t.Parallel() + + t.Run("generates hex-encoded string of correct length", func(t *testing.T) { + s, err := generateSecret(16) + assert.NoError(t, err) + assert.Len(t, s, 32) // 16 bytes = 32 hex chars + }) + + t.Run("generates different values each call", func(t *testing.T) { + s1, _ := generateSecret(16) + s2, _ := generateSecret(16) + assert.NotEqual(t, s1, s2) + }) + + t.Run("valid hex characters only", func(t *testing.T) { + s, err := generateSecret(32) + assert.NoError(t, err) + for _, c := range s { + assert.True(t, (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f'), + "invalid hex char: %c", c) + } + }) +}