diff --git a/backend/internal/setup/setup.go b/backend/internal/setup/setup.go index 9256d245..925b2138 100644 --- a/backend/internal/setup/setup.go +++ b/backend/internal/setup/setup.go @@ -160,6 +160,16 @@ func NeedsSetup() bool { return true } +// quoteIdentifier safely quotes a PostgreSQL identifier (table name, database name, etc.) +// to prevent SQL injection. It follows PostgreSQL's quoting rules: +// - Wrap in double quotes +// - Escape internal double quotes by doubling them +func quoteIdentifier(name string) string { + // Escape any existing double quotes by doubling them + escaped := strings.ReplaceAll(name, `"`, `""`) + return `"` + escaped + `"` +} + // TestDatabaseConnection tests the database connection and creates database if not exists func TestDatabaseConnection(cfg *DatabaseConfig) error { // First, connect to the default 'postgres' database to check/create target database @@ -198,10 +208,11 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error { // Create database if not exists if !exists { - // 注意:数据库名不能参数化,依赖前置输入校验保障安全。 - // Note: Database names cannot be parameterized, but we've already validated cfg.DBName - // in the handler using validateDBName() which only allows [a-zA-Z][a-zA-Z0-9_]* - _, err := db.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", cfg.DBName)) + // 使用 quoteIdentifier 对数据库名进行安全引用,防止 SQL 注入。 + // 虽然前置校验 validateDBName() 已限制为 [a-zA-Z][a-zA-Z0-9_]*, + // 但此处增加防御深度,确保即使校验被绕过也能安全执行。 + quotedDBName := quoteIdentifier(cfg.DBName) + _, err := db.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", quotedDBName)) if err != nil { return fmt.Errorf("failed to create database '%s': %w", cfg.DBName, err) } diff --git a/backend/internal/setup/setup_test.go b/backend/internal/setup/setup_test.go index a01dd00c..f7eaeb06 100644 --- a/backend/internal/setup/setup_test.go +++ b/backend/internal/setup/setup_test.go @@ -87,3 +87,55 @@ func TestWriteConfigFileKeepsDefaultUserConcurrency(t *testing.T) { t.Fatalf("config missing default user concurrency, got:\n%s", string(data)) } } + +func TestQuoteIdentifier(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected string + }{ + { + name: "simple name", + input: "mydb", + expected: `"mydb"`, + }, + { + name: "name with underscore", + input: "my_db_123", + expected: `"my_db_123"`, + }, + { + name: "name with double quote (injection attempt)", + input: `my"; DROP TABLE users; --`, + expected: `"my""; DROP TABLE users; --"`, + }, + { + name: "name with multiple double quotes", + input: `my"db"test`, + expected: `"my""db""test"`, + }, + { + name: "empty name", + input: "", + expected: `""`, + }, + { + name: "name starting with number", + input: "123db", + expected: `"123db"`, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got := quoteIdentifier(tc.input) + if got != tc.expected { + t.Fatalf("quoteIdentifier(%q) = %q, want %q", tc.input, got, tc.expected) + } + }) + } +}