fix(security): add SQL injection defense for CREATE DATABASE

Add quoteIdentifier() function to safely quote PostgreSQL identifiers
following PostgreSQL's quoting rules (wrap in double quotes, escape
internal quotes by doubling).

This provides defense-in-depth for the CREATE DATABASE statement,
complementing the existing validateDBName() input validation.

Changes:
- Add quoteIdentifier() function with proper escaping
- Use quoted identifier in CREATE DATABASE statement
- Add comprehensive unit tests for quoteIdentifier()
This commit is contained in:
User
2026-04-16 20:28:36 +08:00
parent c9992af876
commit db307b0d0f
2 changed files with 67 additions and 4 deletions

View File

@@ -160,6 +160,16 @@ func NeedsSetup() bool {
return true
}
// quoteIdentifier safely quotes a PostgreSQL identifier (table name, database name, etc.)
// to prevent SQL injection. It follows PostgreSQL's quoting rules:
// - Wrap in double quotes
// - Escape internal double quotes by doubling them
func quoteIdentifier(name string) string {
// Escape any existing double quotes by doubling them
escaped := strings.ReplaceAll(name, `"`, `""`)
return `"` + escaped + `"`
}
// TestDatabaseConnection tests the database connection and creates database if not exists
func TestDatabaseConnection(cfg *DatabaseConfig) error {
// First, connect to the default 'postgres' database to check/create target database
@@ -198,10 +208,11 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error {
// Create database if not exists
if !exists {
// 注意:数据库名不能参数化,依赖前置输入校验保障安全
// Note: Database names cannot be parameterized, but we've already validated cfg.DBName
// in the handler using validateDBName() which only allows [a-zA-Z][a-zA-Z0-9_]*
_, err := db.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", cfg.DBName))
// 使用 quoteIdentifier 对数据库名进行安全引用,防止 SQL 注入
// 虽然前置校验 validateDBName() 已限制为 [a-zA-Z][a-zA-Z0-9_]*
// 但此处增加防御深度,确保即使校验被绕过也能安全执行。
quotedDBName := quoteIdentifier(cfg.DBName)
_, err := db.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", quotedDBName))
if err != nil {
return fmt.Errorf("failed to create database '%s': %w", cfg.DBName, err)
}

View File

@@ -87,3 +87,55 @@ func TestWriteConfigFileKeepsDefaultUserConcurrency(t *testing.T) {
t.Fatalf("config missing default user concurrency, got:\n%s", string(data))
}
}
func TestQuoteIdentifier(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
expected string
}{
{
name: "simple name",
input: "mydb",
expected: `"mydb"`,
},
{
name: "name with underscore",
input: "my_db_123",
expected: `"my_db_123"`,
},
{
name: "name with double quote (injection attempt)",
input: `my"; DROP TABLE users; --`,
expected: `"my""; DROP TABLE users; --"`,
},
{
name: "name with multiple double quotes",
input: `my"db"test`,
expected: `"my""db""test"`,
},
{
name: "empty name",
input: "",
expected: `""`,
},
{
name: "name starting with number",
input: "123db",
expected: `"123db"`,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := quoteIdentifier(tc.input)
if got != tc.expected {
t.Fatalf("quoteIdentifier(%q) = %q, want %q", tc.input, got, tc.expected)
}
})
}
}