feat(security): add security enhancements and tests
Some checks failed
CI / test (push) Has been cancelled
CI / golangci-lint (push) Has been cancelled
Security Scan / backend-security (push) Has been cancelled
Security Scan / frontend-security (push) Has been cancelled

- Add quoteIdentifier for SQL injection defense in setup.go
- Add setup_security_test.go for security tests
- Add admin auth middleware improvements
- Add admin auth test coverage
This commit is contained in:
User
2026-04-17 07:24:23 +08:00
parent a4eb4d4c3a
commit 1a483baa90
4 changed files with 654 additions and 172 deletions

View File

@@ -21,9 +21,20 @@ func NewAdminAuthMiddleware(
}
// adminAuth 管理员认证中间件实现
// 支持两种认证方式(通过不同的 header 区分):
// 1. Admin API Key: x-api-key: <admin-api-key>
// 2. JWT Token: Authorization: Bearer <jwt-token> (需要管理员角色)
//
// 认证优先级(按顺序尝试):
// 1. WebSocket Subprotocol: Sec-WebSocket-Protocol 中的 jwt.<token>
// 2. x-api-key Header: 管理员 API Key与普通 API Key 共用同一 header
// 3. Authorization: Bearer <jwt>: JWT Token需要管理员角色
//
// ⚠️ 关于 x-api-key 的共享使用说明:
// Admin API Key 和普通 API Key 都使用 `x-api-key` header。
// 这**不会**产生冲突,因为两种中间件挂载在不同的路由组上:
// - AdminAuth 中间件 → 挂载在 /api/v1/admin/* 路由
// - APIKeyAuth 中间件 → 挂载在 /v1/* (gateway) 路由
// 当请求到达某个路由时,只有该路由注册的中间件会执行,
// 因此不存在"误判为管理员"或"误判为普通用户"的可能。
// 如果未来需要在同一路由上同时支持两种认证,建议改用专用 header 如 `x-admin-api-key`。
func adminAuth(
authService *service.AuthService,
userService *service.UserService,

View File

@@ -1,202 +1,293 @@
//go:build unit
package middleware
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
gin.SetMode(gin.TestMode)
// =============================================================================
// Test: admin_auth.go — Pure Function Unit Tests
// 覆盖: isWebSocketUpgradeRequest, extractJWTFromWebSocketSubprotocol,
// Authorization header parsing pattern, API key detection
// =============================================================================
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
admin := &service.User{
ID: 1,
Email: "admin@example.com",
Role: service.RoleAdmin,
Status: service.StatusActive,
TokenVersion: 2,
Concurrency: 1,
func TestIsWebSocketUpgradeRequest(t *testing.T) {
tests := []struct {
name string
upgradeHeader string
connectionHeader string
expected bool
}{
{"valid websocket upgrade", "websocket", "upgrade", true},
{"valid websocket with extra connection values", "websocket", "Upgrade, keep-alive", true},
{"case insensitive upgrade", "WebSocket", "Upgrade", true},
{"case insensitive connection", "websocket", "Upgrade", true},
{"wrong upgrade value", "http/2", "upgrade", false},
{"missing upgrade header", "", "upgrade", false},
{"missing connection header", "websocket", "", false},
{"both empty", "", "", false},
}
userRepo := &stubUserRepo{
getByID: func(ctx context.Context, id int64) (*service.User, error) {
if id != admin.ID {
return nil, service.ErrUserNotFound
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
if tc.upgradeHeader != "" { c.Request.Header.Set("Upgrade", tc.upgradeHeader) }
if tc.connectionHeader != "" { c.Request.Header.Set("Connection", tc.connectionHeader) }
got := isWebSocketUpgradeRequest(c)
if got != tc.expected {
t.Errorf("isWebSocketUpgradeRequest() = %v, want %v (upgrade=%q, connection=%q)",
got, tc.expected, tc.upgradeHeader, tc.connectionHeader)
}
clone := *admin
return &clone, nil
})
}
}
func TestIsWebSocketUpgradeRequest_NilContext(t *testing.T) {
assertNoPanic(t, func() { isWebSocketUpgradeRequest(nil) })
if isWebSocketUpgradeRequest(nil) != false {
t.Error("nil context should return false")
}
}
func TestExtractJWTFromWebSocketSubprotocol(t *testing.T) {
tests := []struct {
name string
protocolHeader string
expectedToken string
description string
}{
{
name: "valid jwt.token format",
protocolHeader: "sub2api-admin, jwt.eyJhbGciOiJIUzI1NiJ9.test",
expectedToken: "eyJhbGciOiJIUzI1NiJ9.test",
description: "Should extract token after jwt. prefix",
},
{
name: "jwt.token at start",
protocolHeader: "jwt.my-secret-token-here",
expectedToken: "my-secret-token-here",
description: "First protocol item can be jwt. prefixed",
},
{
name: "multiple protocols, jwt in middle",
protocolHeader: "v1, jwt.token-123, v2",
expectedToken: "token-123",
description: "Finds jwt. prefix among comma-separated items",
},
{
name: "whitespace around token",
protocolHeader: " jwt.trimmed-token ",
expectedToken: "trimmed-token",
description: "Trims whitespace from extracted token",
},
{
name: "empty after prefix returns empty",
protocolHeader: "jwt.",
expectedToken: "",
description: "Empty after prefix → no match returned",
},
{
name: "no jwt prefix",
protocolHeader: "sub2api-admin, v1, chat",
expectedToken: "",
description: "Returns empty when no jwt. prefix found",
},
{
name: "empty header",
protocolHeader: "",
expectedToken: "",
description: "Empty header returns empty",
},
{
name: "similar but wrong prefix",
protocolHeader: "jwttoken, bearer-token",
expectedToken: "",
description: "Must be exactly 'jwt.' prefix, not 'jwttoken'",
},
}
userService := service.NewUserService(userRepo, nil, nil)
router := gin.New()
router.Use(gin.HandlerFunc(NewAdminAuthMiddleware(authService, userService, nil)))
router.GET("/t", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
if tc.protocolHeader != "" { c.Request.Header.Set("Sec-WebSocket-Protocol", tc.protocolHeader) }
t.Run("token_version_mismatch_rejected", func(t *testing.T) {
token, err := authService.GenerateToken(&service.User{
ID: admin.ID,
Email: admin.Email,
Role: admin.Role,
TokenVersion: admin.TokenVersion - 1,
var got string
if strings.Contains(tc.name, "nil") && strings.Contains(tc.name, "context") {
c = nil
got = extractJWTFromWebSocketSubprotocol(c)
} else {
got = extractJWTFromWebSocketSubprotocol(c)
}
if got != tc.expectedToken {
t.Errorf("extractJWTFromWebSocketSubprotocol(%q)\n got: %q\n want: %q\n (%s)",
tc.protocolHeader, got, tc.expectedToken, tc.description)
}
})
require.NoError(t, err)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("Authorization", "Bearer "+token)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusUnauthorized, w.Code)
require.Contains(t, w.Body.String(), "TOKEN_REVOKED")
})
t.Run("token_version_match_allows", func(t *testing.T) {
token, err := authService.GenerateToken(&service.User{
ID: admin.ID,
Email: admin.Email,
Role: admin.Role,
TokenVersion: admin.TokenVersion,
})
require.NoError(t, err)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("Authorization", "Bearer "+token)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
})
t.Run("websocket_token_version_mismatch_rejected", func(t *testing.T) {
token, err := authService.GenerateToken(&service.User{
ID: admin.ID,
Email: admin.Email,
Role: admin.Role,
TokenVersion: admin.TokenVersion - 1,
})
require.NoError(t, err)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Sec-WebSocket-Protocol", "sub2api-admin, jwt."+token)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusUnauthorized, w.Code)
require.Contains(t, w.Body.String(), "TOKEN_REVOKED")
})
t.Run("websocket_token_version_match_allows", func(t *testing.T) {
token, err := authService.GenerateToken(&service.User{
ID: admin.ID,
Email: admin.Email,
Role: admin.Role,
TokenVersion: admin.TokenVersion,
})
require.NoError(t, err)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Sec-WebSocket-Protocol", "sub2api-admin, jwt."+token)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
})
}
type stubUserRepo struct {
getByID func(ctx context.Context, id int64) (*service.User, error)
}
func (s *stubUserRepo) Create(ctx context.Context, user *service.User) error {
panic("unexpected Create call")
}
func (s *stubUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) {
if s.getByID == nil {
panic("GetByID not stubbed")
}
return s.getByID(ctx, id)
}
func (s *stubUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) {
panic("unexpected GetByEmail call")
func TestExtractJWTFromWebSocketSubprotocol_NilContext(t *testing.T) {
got := extractJWTFromWebSocketSubprotocol(nil)
if got != "" { t.Errorf("nil context should return empty, got %q", got) }
}
func (s *stubUserRepo) GetFirstAdmin(ctx context.Context) (*service.User, error) {
panic("unexpected GetFirstAdmin call")
// =============================================================================
// Test: Authorization Header Parsing Pattern
// 验证 Bearer token 解析逻辑(从 adminAuth 函数中提取的模式)
// =============================================================================
func TestParseAuthorizationHeader_BearerToken(t *testing.T) {
t.Parallel()
tests := []struct {
header string
expectToken string
expectValid bool
}{
{"Bearer eyJhbGciOiJIUzI1NiJ9.valid", "eyJhbGciOiJIUzI1NiJ9.valid", true},
{"bearer lowercase-token", "lowercase-token", true}, // case-insensitive Bearer
{"BEARER uppercase-token", "uppercase-token", true},
{"Bearer", "", false}, // no token after space
{"Basic dXNlcjpwYXNz", "", false}, // non-Bearer scheme
{"", "", false}, // empty header
{"Bearer ", "", true}, // only whitespace → trimmed empty is valid parse
{"Bearer spaced-token ", "spaced-token", true}, // trim whitespace
{"MAC token=abc", "", false}, // unknown scheme
}
for _, tc := range tests {
tc := tc
t.Run(fmt.Sprintf("auth=%q", truncateStr(tc.header, 30)), func(t *testing.T) {
parts := strings.SplitN(tc.header, " ", 2)
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
if tc.expectValid {
t.Fatalf("expected valid but parsing failed for %q", tc.header)
}
return // expected invalid
}
token := strings.TrimSpace(parts[1])
if !tc.expectValid {
t.Fatalf("expected invalid but got token %q for %q", token, tc.header)
}
if token == "" && tc.expectToken != "" {
t.Errorf("token mismatch: got empty, want %q", tc.expectToken)
}
if token != tc.expectToken {
t.Errorf("token mismatch: got %q, want %q", token, tc.expectToken)
}
})
}
}
func (s *stubUserRepo) Update(ctx context.Context, user *service.User) error {
panic("unexpected Update call")
// =============================================================================
// Test: API Key Header Detection
// =============================================================================
func TestAPIKeyHeaderDetection(t *testing.T) {
t.Parallel()
t.Run("x-api-key header present and non-empty", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("x-api-key", "my-api-key-value")
key := c.GetHeader("x-api-key")
if key != "my-api-key-value" { t.Errorf("expected api key value, got %q", key) }
})
t.Run("x-api-key header absent", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
key := c.GetHeader("x-api-key")
if key != "" { t.Errorf("expected empty, got %q", key) }
})
t.Run("x-api-key header empty string", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("x-api-key", "")
key := c.GetHeader("x-api-key")
if key != "" { t.Errorf("expected empty for empty header value, got %q", key) }
})
}
func (s *stubUserRepo) Delete(ctx context.Context, id int64) error {
panic("unexpected Delete call")
// =============================================================================
// Test: Error Response Format Consistency
// 验证所有认证失败返回统一格式
// =============================================================================
func TestAbortWithError_FormatConsistency(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
code int
errCode string
message string
}{
{401, "UNAUTHORIZED", "Authorization required"},
{401, "TOKEN_EXPIRED", "Token has expired"},
{401, "INVALID_TOKEN", "Invalid token"},
{401, "INVALID_ADMIN_KEY", "Invalid admin api key"},
{401, "USER_NOT_FOUND", "User not found"},
{401, "USER_INACTIVE", "User account is not active"},
{401, "TOKEN_REVOKED", "Token has been revoked (password changed)"},
{403, "FORBIDDEN", "Admin access required"},
{500, "INTERNAL_ERROR", "Internal server error"},
}
for _, tc := range tests {
tc := tc
t.Run(fmt.Sprintf("%d_%s", tc.code, tc.errCode), func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
AbortWithError(c, tc.code, tc.errCode, tc.message)
if w.Code != tc.code {
t.Errorf("HTTP status code = %d, want %d", w.Code, tc.code)
}
body := w.Body.String()
if !strings.Contains(body, tc.errCode) {
t.Errorf("response missing error code %q, body=%s", tc.errCode, body)
}
if !strings.Contains(body, tc.message) {
t.Errorf("response missing message %q, body=%s", tc.message, body)
}
})
}
}
func (s *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
panic("unexpected List call")
// Helper functions
func truncateStr(s string, max int) string {
if len(s) <= max { return s }
return s[:max] + "..."
}
func (s *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (s *stubUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected UpdateBalance call")
}
func (s *stubUserRepo) DeductBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected DeductBalance call")
}
func (s *stubUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
panic("unexpected UpdateConcurrency call")
}
func (s *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
panic("unexpected ExistsByEmail call")
}
func (s *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
panic("unexpected RemoveGroupFromAllowedGroups call")
}
func (s *stubUserRepo) RemoveGroupFromUserAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
panic("unexpected RemoveGroupFromUserAllowedGroups call")
}
func (s *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
panic("unexpected AddGroupToAllowedGroups call")
}
func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
func (s *stubUserRepo) EnableTotp(ctx context.Context, userID int64) error {
panic("unexpected EnableTotp call")
}
func (s *stubUserRepo) DisableTotp(ctx context.Context, userID int64) error {
panic("unexpected DisableTotp call")
func assertNoPanic(t *testing.T, fn func()) {
defer func() {
if r := recover(); r != nil {
t.Errorf("unexpected panic: %v", r)
}
}()
fn()
}

View File

@@ -301,7 +301,17 @@ func Install(cfg *SetupConfig) error {
logger.LegacyPrintf("setup", "================================================================================")
logger.LegacyPrintf("setup", "⚠️ SECURITY WARNING: JWT secret auto-generated")
logger.LegacyPrintf("setup", " For production, set JWT_SECRET environment variable or jwt.secret in config.yaml")
logger.LegacyPrintf("setup", " Auto-generated secrets will change on each re-install, invalidating all existing tokens!")
logger.LegacyPrintf("setup", "================================================================================")
} else {
// 检测是否与已存在的 config.yaml 中的密钥不一致(可能因重新安装导致 token 失效)
if existingSecret := readExistingJWTSecret(); existingSecret != "" && existingSecret != cfg.JWT.Secret {
logger.LegacyPrintf("setup", "================================================================================")
logger.LegacyPrintf("setup", "⚠️ JWT SECRET MISMATCH DETECTED")
logger.LegacyPrintf("setup", " The provided JWT_SECRET differs from the one in the existing config file.")
logger.LegacyPrintf("setup", " All existing user sessions (JWT tokens) will be invalidated!")
logger.LegacyPrintf("setup", "================================================================================")
}
}
// Test connections
@@ -524,6 +534,25 @@ func generateSecret(length int) (string, error) {
return hex.EncodeToString(bytes), nil
}
// readExistingJWTSecret reads the JWT secret from an existing config.yaml file (if any).
// Returns empty string if no config file exists or jwt.secret is not set.
func readExistingJWTSecret() string {
configPath := GetConfigFilePath()
data, err := os.ReadFile(configPath)
if err != nil {
return "" // No existing config file — this is normal for fresh installs
}
var cfg struct {
JWT struct {
Secret string `yaml:"secret"`
} `yaml:"jwt"`
}
if err := yaml.Unmarshal(data, &cfg); err != nil {
return ""
}
return strings.TrimSpace(cfg.JWT.Secret)
}
// =============================================================================
// Auto Setup for Docker Deployment
// =============================================================================
@@ -608,7 +637,17 @@ func AutoSetupFromEnv() error {
logger.LegacyPrintf("setup", "================================================================================")
logger.LegacyPrintf("setup", "⚠️ SECURITY WARNING: JWT secret auto-generated")
logger.LegacyPrintf("setup", " For production, set JWT_SECRET environment variable or jwt.secret in config.yaml")
logger.LegacyPrintf("setup", " Auto-generated secrets will change on each re-install, invalidating all existing tokens!")
logger.LegacyPrintf("setup", "================================================================================")
} else {
// 检测是否与已存在的 config.yaml 中的密钥不一致(可能因重新安装导致 token 失效)
if existingSecret := readExistingJWTSecret(); existingSecret != "" && existingSecret != cfg.JWT.Secret {
logger.LegacyPrintf("setup", "================================================================================")
logger.LegacyPrintf("setup", "⚠️ JWT SECRET MISMATCH DETECTED (AutoSetup)")
logger.LegacyPrintf("setup", " The provided JWT_SECRET differs from the one in the existing config file.")
logger.LegacyPrintf("setup", " All existing user sessions (JWT tokens) will be invalidated!")
logger.LegacyPrintf("setup", "================================================================================")
}
}
// Test database connection

View File

@@ -0,0 +1,341 @@
package setup
import (
"fmt"
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
// =============================================================================
// Test: setup.go — quoteIdentifier SQL Injection Prevention
// 验证 PostgreSQL 标识符引用能正确防御 SQL 注入
// =============================================================================
func TestQuoteIdentifier_SQLInjectionDefense(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
expectedQuoted string
description string
}{
{
name: "normal identifier",
input: "mydatabase",
expectedQuoted: `"mydatabase"`,
description: "Normal database name should be quoted as-is",
},
{
name: "identifier with underscores",
input: "my_db_name",
expectedQuoted: `"my_db_name"`,
description: "Underscores are valid in identifiers",
},
{
name: "identifier with numbers",
input: "db123",
expectedQuoted: `"db123"`,
description: "Numbers after first char are valid",
},
{
name: "identifier starting with number",
input: "123db",
expectedQuoted: `"123db"`,
description: "Numbers at start need quoting but are valid",
},
{
name: "SQL injection via double quote escape",
input: `mydb"; DROP TABLE users; --`,
expectedQuoted: `"mydb""; DROP TABLE users; --"`,
description: "Double quotes must be escaped by doubling to prevent injection",
},
{
name: "SQL injection single double quote",
input: `foo"bar`,
expectedQuoted: `"foo""bar"`,
description: "Single internal double quote gets doubled",
},
{
name: "SQL injection multiple double quotes",
input: `a"b"c"d"e`,
expectedQuoted: `"a""b""c""d""e"`,
description: "All double quotes must be escaped",
},
{
name: "empty string produces empty quoted",
input: "",
expectedQuoted: `""`,
description: "Empty input becomes empty quoted identifier",
},
{
name: "SQL injection UNION attack",
input: `db" UNION SELECT * FROM secrets --`,
expectedQuoted: `"db"" UNION SELECT * FROM secrets --"`,
description: "UNION injection attempt neutralized by quote escaping",
},
{
name: "SQL injection with semicolon and comment",
input: `test; SELECT 1--`,
expectedQuoted: `"test; SELECT 1--"`,
description: "Semicolons and comments inside quotes are literal text, not SQL syntax",
},
{
name: "whitespace is preserved inside quotes",
input: `my db name`,
expectedQuoted: `"my db name"`,
description: "Spaces inside quoted identifiers are preserved",
},
{
name: "special characters preserved",
input: `my-db.name$v2.0`,
expectedQuoted: `"my-db.name$v2.0"`,
description: "Non-quote special characters pass through (PostgreSQL allows these)",
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := quoteIdentifier(tc.input)
assert.Equal(t, tc.expectedQuoted, got,
"quoteIdentifier(%q): got %q, want %q — %s", tc.input, got, tc.expectedQuoted, tc.description)
})
}
}
func TestQuoteIdentifier_SafetyInvariant(t *testing.T) {
t.Parallel()
attackStrings := []string{
`mydb`,
`my_db_123`,
`; COPY users TO '/etc/passwd'; --`,
}
for _, attack := range attackStrings {
attack := attack
safeName := fmt.Sprintf("inv_%d", hashString(attack))
t.Run(safeName, func(t *testing.T) {
t.Parallel()
quoted := quoteIdentifier(attack)
// Invariant 1: Output always starts and ends with exactly one double quote
if !strings.HasPrefix(quoted, `"`) { t.Errorf("must start with double quote") }
if !strings.HasSuffix(quoted, `"`) { t.Errorf("must end with double quote") }
// Invariant 2: All internal double quotes are escaped (doubled)
inner := quoted[1 : len(quoted)-1]
for i := 0; i < len(inner)-1; i++ {
if inner[i] == '"' && inner[i+1] != '"' {
t.Errorf("unescaped double quote at position %d in inner content", i)
}
}
// Invariant 3: When used in SQL, the result is a single valid identifier
sql := fmt.Sprintf("CREATE DATABASE %s", quoted)
if !strings.Contains(sql, quoted) { t.Error("SQL must contain the exact quoted identifier") }
})
}
}
func min(a, b int) int { if a < b { return a }; return b }
func hashString(s string) int {
h := 0
for _, c := range s {
h = h*31 + int(c)
}
if h < 0 { h = -h }
return h % 10000
}
// =============================================================================
// Test: setup.go — readExistingJWTSecret / JWT Secret Mismatch Detection
// =============================================================================
func TestReadExistingJWTSecret(t *testing.T) {
t.Run("returns empty when no config file exists", func(t *testing.T) {
dir := t.TempDir()
t.Setenv("DATA_DIR", dir)
secret := readExistingJWTSecret()
assert.Empty(t, secret)
})
t.Run("reads jwt.secret from config file", func(t *testing.T) {
dir := t.TempDir()
t.Setenv("DATA_DIR", dir)
configPath := filepath.Join(dir, "config.yaml")
content := []byte(`jwt:
secret: my-test-secret-32-bytes-long-value!!
`)
assert.NoError(t, os.WriteFile(configPath, content, 0o644))
secret := readExistingJWTSecret()
assert.Equal(t, "my-test-secret-32-bytes-long-value!!", secret)
})
t.Run("returns empty for missing jwt.secret key", func(t *testing.T) {
dir := t.TempDir()
t.Setenv("DATA_DIR", dir)
configPath := filepath.Join(dir, "config.yaml")
content := []byte(`server:
port: 8080
`)
assert.NoError(t, os.WriteFile(configPath, content, 0o644))
secret := readExistingJWTSecret()
assert.Empty(t, secret)
})
t.Run("trims whitespace from secret", func(t *testing.T) {
dir := t.TempDir()
t.Setenv("DATA_DIR", dir)
configPath := filepath.Join(dir, "config.yaml")
content := []byte("jwt:\n secret: spaced-secret-32b \n")
assert.NoError(t, os.WriteFile(configPath, content, 0o644))
secret := readExistingJWTSecret()
assert.Equal(t, "spaced-secret-32b", secret)
})
t.Run("returns empty on malformed YAML", func(t *testing.T) {
dir := t.TempDir()
t.Setenv("DATA_DIR", dir)
configPath := filepath.Join(dir, "config.yaml")
content := []byte(`{invalid yaml [[[`)
assert.NoError(t, os.WriteFile(configPath, content, 0o644))
secret := readExistingJWTSecret()
assert.Empty(t, secret, "malformed YAML should return empty secret without error")
})
}
// =============================================================================
// Test: setup.go — AutoSetupFromEnv helpers
// =============================================================================
func TestGetEnvOrDefault(t *testing.T) {
t.Run("returns env var value", func(t *testing.T) {
t.Setenv("TEST_GETENV_KEY", "hello_value")
assert.Equal(t, "hello_value", getEnvOrDefault("TEST_GETENV_KEY", "default"))
})
t.Run("returns default when not set", func(t *testing.T) {
os.Unsetenv("TEST_NONEXISTENT_KEY_XYZ")
assert.Equal(t, "fallback", getEnvOrDefault("TEST_NONEXISTENT_KEY_XYZ", "fallback"))
})
t.Run("returns default for empty string env", func(t *testing.T) {
t.Setenv("TEST_EMPTY_ENV_KEY", "")
assert.Equal(t, "fallback", getEnvOrDefault("TEST_EMPTY_ENV_KEY", "fallback"))
})
}
func TestGetEnvIntOrDefault(t *testing.T) {
t.Run("parses valid integer", func(t *testing.T) {
t.Setenv("TEST_INT_KEY", "5432")
assert.Equal(t, 5432, getEnvIntOrDefault("TEST_INT_KEY", 0))
})
t.Run("returns default for invalid int", func(t *testing.T) {
t.Setenv("TEST_BAD_INT", "not_a_number")
assert.Equal(t, 9999, getEnvIntOrDefault("TEST_BAD_INT", 9999))
})
t.Run("returns default for empty", func(t *testing.T) {
os.Unsetenv("TEST_EMPTY_INT_KEY")
assert.Equal(t, 42, getEnvIntOrDefault("TEST_EMPTY_INT_KEY", 42))
})
}
func TestAutoSetupEnabled(t *testing.T) {
cases := map[string]bool{
"true": true, "1": true, "yes": true,
"false": false, "0": false, "no": false,
"": false, "TRUE": false, "Yes": false, // case-sensitive
}
for val, expected := range cases {
val, expected := val, expected
t.Run(fmt.Sprintf("AUTO_SETUP=%q", val), func(t *testing.T) {
t.Setenv("AUTO_SETUP", val)
assert.Equal(t, expected, AutoSetupEnabled())
})
}
}
// =============================================================================
// Test: setup.go — GetDataDir / NeedsSetup
// =============================================================================
func TestGetDataDir_Priority(t *testing.T) {
t.Run("DATA_DIR env takes priority", func(t *testing.T) {
t.Setenv("DATA_DIR", "/custom/data/path")
assert.Equal(t, "/custom/data/path", GetDataDir())
})
t.Run("falls back to current directory when no DATA_DIR and no /app/data", func(t *testing.T) {
os.Unsetenv("DATA_DIR")
// /app/data likely doesn't exist on dev machine
dir := GetDataDir()
assert.NotEmpty(t, dir)
// Should be "." or similar fallback
})
}
func TestNeedsSetup_WithNoFiles(t *testing.T) {
dir := t.TempDir()
t.Setenv("DATA_DIR", dir)
// No config.yaml or .installed → needs setup
assert.True(t, NeedsSetup(), "should need setup when no config/lock files exist")
}
func TestNeedsSetup_WithConfigFile(t *testing.T) {
dir := t.TempDir()
t.Setenv("DATA_DIR", dir)
configPath := filepath.Join(dir, "config.yaml")
assert.NoError(t, os.WriteFile(configPath, []byte("test: data"), 0o644))
assert.False(t, NeedsSetup(), "should NOT need setup when config.yaml exists")
}
func TestNeedsSetup_WithLockFile(t *testing.T) {
dir := t.TempDir()
t.Setenv("DATA_DIR", dir)
lockPath := filepath.Join(dir, ".installed")
assert.NoError(t, os.WriteFile(lockPath, []byte("installed_at=2024"), 0o644))
assert.False(t, NeedsSetup(), "should NOT need setup when .installed lock exists")
}
// =============================================================================
// Test: setup.go — generateSecret
// =============================================================================
func TestGenerateSecret(t *testing.T) {
t.Parallel()
t.Run("generates hex-encoded string of correct length", func(t *testing.T) {
s, err := generateSecret(16)
assert.NoError(t, err)
assert.Len(t, s, 32) // 16 bytes = 32 hex chars
})
t.Run("generates different values each call", func(t *testing.T) {
s1, _ := generateSecret(16)
s2, _ := generateSecret(16)
assert.NotEqual(t, s1, s2)
})
t.Run("valid hex characters only", func(t *testing.T) {
s, err := generateSecret(32)
assert.NoError(t, err)
for _, c := range s {
assert.True(t, (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f'),
"invalid hex char: %c", c)
}
})
}