- Add comprehensive Validator tests (email, phone, username, password)
- Add URL and IP validation tests (IPv4/IPv6)
- Add SQL injection sanitization tests
- Add XSS sanitization tests
- Security package coverage: 34.9% -> 69.4%
- Overall coverage: 53.5% -> 54.1%
292 lines
7.4 KiB
Go
292 lines
7.4 KiB
Go
package security
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
// TestNewValidator 测试 Validator 创建
|
|
func TestNewValidator(t *testing.T) {
|
|
v := NewValidator(8, true, true)
|
|
assert.NotNil(t, v)
|
|
assert.Equal(t, 8, v.passwordMinLength)
|
|
assert.True(t, v.passwordRequireSpecial)
|
|
assert.True(t, v.passwordRequireNumber)
|
|
|
|
v2 := NewValidator(6, false, false)
|
|
assert.Equal(t, 6, v2.passwordMinLength)
|
|
assert.False(t, v2.passwordRequireSpecial)
|
|
assert.False(t, v2.passwordRequireNumber)
|
|
}
|
|
|
|
// TestValidator_ValidateEmail 测试邮箱验证
|
|
func TestValidator_ValidateEmail(t *testing.T) {
|
|
v := NewValidator(8, true, true)
|
|
|
|
tests := []struct {
|
|
name string
|
|
email string
|
|
expected bool
|
|
}{
|
|
{"empty", "", false},
|
|
{"invalid", "invalid", false},
|
|
{"no at", "test.example.com", false},
|
|
{"no domain", "test@", false},
|
|
{"no user", "@example.com", false},
|
|
{"valid simple", "test@example.com", true},
|
|
{"valid with dot", "test.user@example.com", true},
|
|
{"valid with plus", "test+tag@example.com", true},
|
|
{"valid subdomain", "test@mail.example.com", true},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := v.ValidateEmail(tt.email)
|
|
assert.Equal(t, tt.expected, got)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestValidator_ValidatePhone 测试手机号验证
|
|
func TestValidator_ValidatePhone(t *testing.T) {
|
|
v := NewValidator(8, true, true)
|
|
|
|
tests := []struct {
|
|
name string
|
|
phone string
|
|
expected bool
|
|
}{
|
|
{"empty", "", false},
|
|
{"invalid format", "12345678901", false},
|
|
{"too short", "1380013800", false},
|
|
{"too long", "138001380001", false},
|
|
{"invalid prefix 1", "12800138000", false},
|
|
{"invalid prefix 2", "10800138000", false},
|
|
{"valid 13x", "13800138000", true},
|
|
{"valid 15x", "15800138000", true},
|
|
{"valid 18x", "18800138000", true},
|
|
{"valid 19x", "19800138000", true},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := v.ValidatePhone(tt.phone)
|
|
assert.Equal(t, tt.expected, got)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestValidator_ValidateUsername 测试用户名验证
|
|
func TestValidator_ValidateUsername(t *testing.T) {
|
|
v := NewValidator(8, true, true)
|
|
|
|
tests := []struct {
|
|
name string
|
|
username string
|
|
expected bool
|
|
}{
|
|
{"empty", "", false},
|
|
{"too short", "abc", false},
|
|
{"starts with number", "1abc", false},
|
|
{"starts with underscore", "_abc", false},
|
|
{"contains special", "abc@123", false},
|
|
{"valid lowercase", "abc123", true},
|
|
{"valid uppercase", "Abc123", true},
|
|
{"valid with underscore", "abc_123", true},
|
|
{"valid max length", "abcd1234abcd1234abcd", true},
|
|
{"too long", "abcd1234abcd1234abcd1", false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := v.ValidateUsername(tt.username)
|
|
assert.Equal(t, tt.expected, got)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestValidator_ValidatePassword 测试密码验证
|
|
func TestValidator_ValidatePassword(t *testing.T) {
|
|
v := NewValidator(8, true, true)
|
|
|
|
tests := []struct {
|
|
name string
|
|
password string
|
|
expected bool
|
|
}{
|
|
{"too short", "Abc1!", false},
|
|
{"no number", "Abcdefgh!", false},
|
|
{"no special", "Abcdefgh1", false},
|
|
{"no uppercase", "abcdefgh1!", false},
|
|
{"no lowercase", "ABCDEFGH1!", false},
|
|
{"valid complex", "Abcdef1!", true},
|
|
{"valid longer", "Abcdefgh123!", true},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := v.ValidatePassword(tt.password)
|
|
assert.Equal(t, tt.expected, got)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestValidator_ValidateURL 测试 URL 验证
|
|
func TestValidator_ValidateURL(t *testing.T) {
|
|
v := NewValidator(8, true, true)
|
|
|
|
tests := []struct {
|
|
name string
|
|
url string
|
|
expected bool
|
|
}{
|
|
{"empty", "", false},
|
|
{"no scheme", "example.com", false},
|
|
{"http", "http://example.com", true},
|
|
{"https", "https://example.com", true},
|
|
{"with path", "https://example.com/path", true},
|
|
{"with query", "https://example.com?foo=bar", true},
|
|
{"with fragment", "https://example.com#section", true},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := v.ValidateURL(tt.url)
|
|
assert.Equal(t, tt.expected, got)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestValidator_ValidateIP 测试 IP 验证
|
|
func TestValidator_ValidateIP(t *testing.T) {
|
|
v := NewValidator(8, true, true)
|
|
|
|
tests := []struct {
|
|
name string
|
|
ip string
|
|
expected bool
|
|
}{
|
|
{"empty", "", false},
|
|
{"invalid", "not-an-ip", false},
|
|
{"IPv4 valid", "192.168.1.1", true},
|
|
{"IPv4 invalid", "192.168.1.256", false},
|
|
{"IPv6 valid", "::1", true},
|
|
{"IPv6 valid full", "2001:0db8:85a3:0000:0000:8a2e:0370:7334", true},
|
|
{"IPv6 compressed", "fe80::1", true},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := v.ValidateIP(tt.ip)
|
|
assert.Equal(t, tt.expected, got)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestValidator_ValidateIPv4 测试 IPv4 验证
|
|
func TestValidator_ValidateIPv4(t *testing.T) {
|
|
v := NewValidator(8, true, true)
|
|
|
|
tests := []struct {
|
|
name string
|
|
ip string
|
|
expected bool
|
|
}{
|
|
{"empty", "", false},
|
|
{"IPv4 valid", "192.168.1.1", true},
|
|
{"IPv4 invalid", "192.168.1.256", false},
|
|
{"IPv6 localhost", "::1", false}, // IPv6 should fail IPv4 validation
|
|
{"IPv6 full", "2001:0db8:85a3::8a2e:0370:7334", false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := v.ValidateIPv4(tt.ip)
|
|
assert.Equal(t, tt.expected, got)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestValidator_ValidateIPv6 测试 IPv6 验证
|
|
func TestValidator_ValidateIPv6(t *testing.T) {
|
|
v := NewValidator(8, true, true)
|
|
|
|
tests := []struct {
|
|
name string
|
|
ip string
|
|
expected bool
|
|
}{
|
|
{"empty", "", false},
|
|
{"IPv4 valid", "192.168.1.1", false}, // IPv4 should fail IPv6 validation
|
|
{"IPv6 localhost", "::1", true},
|
|
{"IPv6 full", "2001:0db8:85a3:0000:0000:8a2e:0370:7334", true},
|
|
{"IPv6 compressed", "fe80::1", true},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := v.ValidateIPv6(tt.ip)
|
|
assert.Equal(t, tt.expected, got)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestValidator_SanitizeSQL 测试 SQL 净化
|
|
func TestValidator_SanitizeSQL(t *testing.T) {
|
|
v := NewValidator(8, true, true)
|
|
|
|
tests := []struct {
|
|
name string
|
|
input string
|
|
expected string
|
|
}{
|
|
{"empty", "", ""},
|
|
{"normal text", "hello world", "hello world"},
|
|
{"quote escape", "'test'", "''test''"},
|
|
{"backslash escape", "\\test", "\\test"},
|
|
{"remove comment", "select; -- comment", "select "},
|
|
{"remove block comment", "select /* comment */ from", "select from"},
|
|
{"remove union", "select union select", "select "},
|
|
{"remove drop", "drop table users", ""},
|
|
{"remove insert", "insert into users", ""},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := v.SanitizeSQL(tt.input)
|
|
// 检查输出不包含危险模式
|
|
assert.NotContains(t, got, "--")
|
|
assert.NotContains(t, got, "/*")
|
|
assert.NotContains(t, got, "*/")
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestValidator_SanitizeXSS 测试 XSS 净化
|
|
func TestValidator_SanitizeXSS(t *testing.T) {
|
|
v := NewValidator(8, true, true)
|
|
|
|
tests := []struct {
|
|
name string
|
|
input string
|
|
checkNot string
|
|
}{
|
|
{"empty", "", ""},
|
|
{"normal text", "hello world", ""},
|
|
{"remove script", "<script>alert('xss')</script>", "script"},
|
|
{"remove iframe", "<iframe src='evil.com'></iframe>", "iframe"},
|
|
{"remove javascript", "javascript:alert(1)", "javascript:"},
|
|
{"remove event handler", "<img onerror='alert(1)'>", "onerror"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := v.SanitizeXSS(tt.input)
|
|
if tt.checkNot != "" {
|
|
assert.NotContains(t, got, tt.checkNot)
|
|
}
|
|
})
|
|
}
|
|
}
|