- IsAdminBootstrapRequired: userRepo.GetByID 循环 → GetByIDs 批量 - AssignRoles: roleRepo.GetByID 循环 → GetByIDs 批量 - 在 userRepositoryInterface 补充 GetByIDs 方法签名
209 lines
5.5 KiB
Go
209 lines
5.5 KiB
Go
package auth
|
||
|
||
import (
|
||
"strings"
|
||
"testing"
|
||
)
|
||
|
||
func TestTOTPManager_GenerateAndValidate(t *testing.T) {
|
||
m := NewTOTPManager()
|
||
|
||
// 生成密钥
|
||
setup, err := m.GenerateSecret("testuser@example.com")
|
||
if err != nil {
|
||
t.Fatalf("GenerateSecret 失败: %v", err)
|
||
}
|
||
|
||
if setup.Secret == "" {
|
||
t.Fatal("生成的 Secret 不应为空")
|
||
}
|
||
if setup.QRCodeBase64 == "" {
|
||
t.Fatal("QRCode Base64 不应为空")
|
||
}
|
||
if len(setup.RecoveryCodes) != RecoveryCodeCount {
|
||
t.Fatalf("恢复码数量期望 %d,实际 %d", RecoveryCodeCount, len(setup.RecoveryCodes))
|
||
}
|
||
t.Logf("生成 Secret: %s", setup.Secret)
|
||
t.Logf("恢复码示例: %s", setup.RecoveryCodes[0])
|
||
|
||
// 用生成的密钥生成当前 TOTP 码,再验证
|
||
code, err := m.GenerateCurrentCode(setup.Secret)
|
||
if err != nil {
|
||
t.Fatalf("GenerateCurrentCode 失败: %v", err)
|
||
}
|
||
if !m.ValidateCode(setup.Secret, code) {
|
||
t.Fatalf("有效 TOTP 码应该通过验证,code=%s", code)
|
||
}
|
||
t.Logf("TOTP 验证通过,code=%s", code)
|
||
}
|
||
|
||
func TestTOTPManager_InvalidCode(t *testing.T) {
|
||
m := NewTOTPManager()
|
||
setup, err := m.GenerateSecret("user")
|
||
if err != nil {
|
||
t.Fatalf("GenerateSecret 失败: %v", err)
|
||
}
|
||
|
||
// 错误的验证码
|
||
if m.ValidateCode(setup.Secret, "000000") {
|
||
// 偶尔可能恰好正确,跳过而不是 fatal
|
||
t.Skip("000000 碰巧是有效码,跳过测试")
|
||
}
|
||
t.Log("无效验证码正确拒绝")
|
||
}
|
||
|
||
func TestTOTPManager_RecoveryCodeFormat(t *testing.T) {
|
||
m := NewTOTPManager()
|
||
setup, err := m.GenerateSecret("user2")
|
||
if err != nil {
|
||
t.Fatalf("GenerateSecret 失败: %v", err)
|
||
}
|
||
|
||
for i, code := range setup.RecoveryCodes {
|
||
parts := strings.Split(code, "-")
|
||
if len(parts) != 2 {
|
||
t.Errorf("恢复码 [%d] 格式错误(期望 XXXXX-XXXXX): %s", i, code)
|
||
}
|
||
if len(parts[0]) != 5 || len(parts[1]) != 5 {
|
||
t.Errorf("恢复码 [%d] 各部分长度应为 5: %s", i, code)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestValidateRecoveryCode(t *testing.T) {
|
||
codes := []string{"ABCDE-FGHIJ", "KLMNO-PQRST", "UVWXY-ZABCD"}
|
||
|
||
// 正确匹配
|
||
idx, ok := ValidateRecoveryCode("ABCDE-FGHIJ", codes)
|
||
if !ok || idx != 0 {
|
||
t.Fatalf("有效恢复码应该匹配,idx=%d ok=%v", idx, ok)
|
||
}
|
||
|
||
// 大小写不敏感
|
||
idx2, ok2 := ValidateRecoveryCode("klmno-pqrst", codes)
|
||
if !ok2 || idx2 != 1 {
|
||
t.Fatalf("大小写不敏感匹配失败,idx=%d ok=%v", idx2, ok2)
|
||
}
|
||
|
||
// 去除空格
|
||
idx3, ok3 := ValidateRecoveryCode(" UVWXY-ZABCD ", codes)
|
||
if !ok3 || idx3 != 2 {
|
||
t.Fatalf("去除空格匹配失败,idx=%d ok=%v", idx3, ok3)
|
||
}
|
||
|
||
// 不匹配
|
||
_, ok4 := ValidateRecoveryCode("XXXXX-YYYYY", codes)
|
||
if ok4 {
|
||
t.Fatal("无效恢复码不应该匹配")
|
||
}
|
||
|
||
t.Log("恢复码验证全部通过")
|
||
}
|
||
|
||
func TestHashRecoveryCode(t *testing.T) {
|
||
code := "ABCDE-FGHIJ"
|
||
|
||
hashed, err := HashRecoveryCode(code)
|
||
if err != nil {
|
||
t.Fatalf("HashRecoveryCode failed: %v", err)
|
||
}
|
||
|
||
if hashed == "" {
|
||
t.Fatal("HashRecoveryCode should return non-empty hash")
|
||
}
|
||
|
||
// Same code should verify against its own hash (bcrypt uses random salt, so hashes differ)
|
||
_, ok := VerifyRecoveryCode(code, []string{hashed})
|
||
if !ok {
|
||
t.Error("Same code should verify against its own hash")
|
||
}
|
||
|
||
// Different codes should NOT verify
|
||
hashed3, err := HashRecoveryCode("DIFFERENT-CODE")
|
||
if err != nil {
|
||
t.Fatalf("HashRecoveryCode for different code failed: %v", err)
|
||
}
|
||
|
||
_, ok2 := VerifyRecoveryCode(code, []string{hashed3})
|
||
if ok2 {
|
||
t.Error("Different codes should NOT verify against each other's hash")
|
||
}
|
||
|
||
// bcrypt hash format check
|
||
if !strings.HasPrefix(hashed, "$2a$") {
|
||
t.Errorf("Hash should be bcrypt format, got: %s", hashed)
|
||
}
|
||
|
||
t.Logf("Hashed code (bcrypt): %s", hashed)
|
||
}
|
||
|
||
func TestVerifyRecoveryCode(t *testing.T) {
|
||
// Generate hashed codes
|
||
codes := []string{"ABCDE-FGHIJ", "KLMNO-PQRST", "UVWXY-ZABCD"}
|
||
hashedCodes := make([]string, len(codes))
|
||
for i, code := range codes {
|
||
hashed, err := HashRecoveryCode(code)
|
||
if err != nil {
|
||
t.Fatalf("HashRecoveryCode failed: %v", err)
|
||
}
|
||
hashedCodes[i] = hashed
|
||
}
|
||
|
||
// Test valid code (exact match)
|
||
idx, ok := VerifyRecoveryCode("ABCDE-FGHIJ", hashedCodes)
|
||
if !ok || idx != 0 {
|
||
t.Fatalf("Valid recovery code should match, idx=%d ok=%v", idx, ok)
|
||
}
|
||
|
||
// Test second code
|
||
idx2, ok2 := VerifyRecoveryCode("KLMNO-PQRST", hashedCodes)
|
||
if !ok2 || idx2 != 1 {
|
||
t.Fatalf("Second code match failed, idx=%d ok=%v", idx2, ok2)
|
||
}
|
||
|
||
// Test third code
|
||
idx3, ok3 := VerifyRecoveryCode("UVWXY-ZABCD", hashedCodes)
|
||
if !ok3 || idx3 != 2 {
|
||
t.Fatalf("Third code match failed, idx=%d ok=%v", idx3, ok3)
|
||
}
|
||
|
||
// Test invalid code
|
||
_, ok4 := VerifyRecoveryCode("XXXXX-YYYYY", hashedCodes)
|
||
if ok4 {
|
||
t.Fatal("Invalid recovery code should not match")
|
||
}
|
||
|
||
// Test empty hashed codes list
|
||
_, ok5 := VerifyRecoveryCode("ABCDE-FGHIJ", []string{})
|
||
if ok5 {
|
||
t.Fatal("Should not match against empty list")
|
||
}
|
||
|
||
t.Log("VerifyRecoveryCode tests passed")
|
||
}
|
||
|
||
func TestVerifyRecoveryCode_TimingSafety(t *testing.T) {
|
||
// Test that the function always iterates through all codes
|
||
// regardless of where the match is found (timing attack prevention)
|
||
codes := []string{"CODE1-AAAAA", "CODE2-BBBBB", "CODE3-CCCCC"}
|
||
hashedCodes := make([]string, len(codes))
|
||
for i, code := range codes {
|
||
hashed, _ := HashRecoveryCode(code)
|
||
hashedCodes[i] = hashed
|
||
}
|
||
|
||
// Test matching first code
|
||
idx1, ok1 := VerifyRecoveryCode("CODE1-AAAAA", hashedCodes)
|
||
if !ok1 || idx1 != 0 {
|
||
t.Errorf("First code match failed, idx=%d ok=%v", idx1, ok1)
|
||
}
|
||
|
||
// Test matching last code
|
||
idx3, ok3 := VerifyRecoveryCode("CODE3-CCCCC", hashedCodes)
|
||
if !ok3 || idx3 != 2 {
|
||
t.Errorf("Last code match failed, idx=%d ok=%v", idx3, ok3)
|
||
}
|
||
|
||
t.Log("Timing safety test passed")
|
||
}
|