- Add new test files for auth, service, and handler modules - Improve test organization and coverage - Refactor code for better maintainability - Add captcha, settings, stats, and theme handler tests - Add auth module tests (CAS, OAuth, password, SSO, state) - Add service layer tests for auth, export, permissions, roles - All Go tests pass (exit code 0) - All frontend tests pass (325 tests in 59 files)
1092 lines
33 KiB
Go
1092 lines
33 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/user-management-system/internal/auth"
|
|
"github.com/user-management-system/internal/cache"
|
|
"github.com/user-management-system/internal/domain"
|
|
"github.com/user-management-system/internal/security"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// =============================================================================
|
|
// Auth Runtime Helper Functions Tests
|
|
// =============================================================================
|
|
|
|
func TestIsUserNotFoundError(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
err error
|
|
expected bool
|
|
}{
|
|
{
|
|
name: "nil error",
|
|
err: nil,
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "gorm record not found",
|
|
err: gorm.ErrRecordNotFound,
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "wrapped gorm record not found",
|
|
err: errors.Join(gorm.ErrRecordNotFound, errors.New("additional context")),
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "other error",
|
|
err: errors.New("some other error"),
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "generic error",
|
|
err: errors.New("something went wrong"),
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "error containing user not found",
|
|
err: errors.New("user not found"),
|
|
expected: true, // contains "user not found" in lowercase
|
|
},
|
|
{
|
|
name: "error containing record not found",
|
|
err: errors.New("record not found"),
|
|
expected: true, // contains "record not found"
|
|
},
|
|
{
|
|
name: "error containing not found",
|
|
err: errors.New("entity not found"),
|
|
expected: true, // contains "not found"
|
|
},
|
|
{
|
|
name: "error containing 用户不存在",
|
|
err: errors.New("用户不存在"),
|
|
expected: true, // contains Chinese "用户不存在"
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := isUserNotFoundError(tt.err)
|
|
if result != tt.expected {
|
|
t.Errorf("isUserNotFoundError(%v) = %v, want %v", tt.err, result, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// OAuth State Tests
|
|
// =============================================================================
|
|
|
|
func TestAuthService_CreateOAuthState(t *testing.T) {
|
|
l1Cache := cache.NewL1Cache()
|
|
l2Cache := cache.NewRedisCache(false)
|
|
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
|
|
svc := &AuthService{cache: cacheManager}
|
|
ctx := context.Background()
|
|
|
|
t.Run("CreateOAuthState success", func(t *testing.T) {
|
|
state, err := svc.CreateOAuthState(ctx, "http://localhost/callback")
|
|
if err != nil {
|
|
t.Fatalf("CreateOAuthState failed: %v", err)
|
|
}
|
|
if state == "" {
|
|
t.Error("Expected non-empty state")
|
|
}
|
|
})
|
|
|
|
t.Run("CreateOAuthState with empty return URL", func(t *testing.T) {
|
|
state, err := svc.CreateOAuthState(ctx, "")
|
|
if err != nil {
|
|
t.Fatalf("CreateOAuthState failed: %v", err)
|
|
}
|
|
if state == "" {
|
|
t.Error("Expected non-empty state")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestAuthService_CreateOAuthBindState(t *testing.T) {
|
|
l1Cache := cache.NewL1Cache()
|
|
l2Cache := cache.NewRedisCache(false)
|
|
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
|
|
svc := &AuthService{cache: cacheManager}
|
|
ctx := context.Background()
|
|
|
|
t.Run("CreateOAuthBindState success", func(t *testing.T) {
|
|
state, err := svc.CreateOAuthBindState(ctx, 1, "http://localhost/callback")
|
|
if err != nil {
|
|
t.Fatalf("CreateOAuthBindState failed: %v", err)
|
|
}
|
|
if state == "" {
|
|
t.Error("Expected non-empty state")
|
|
}
|
|
})
|
|
|
|
t.Run("CreateOAuthBindState with invalid user ID", func(t *testing.T) {
|
|
_, err := svc.CreateOAuthBindState(ctx, 0, "http://localhost/callback")
|
|
if err == nil {
|
|
t.Error("Expected error for invalid user ID")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestAuthService_ConsumeOAuthState(t *testing.T) {
|
|
l1Cache := cache.NewL1Cache()
|
|
l2Cache := cache.NewRedisCache(false)
|
|
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
|
|
svc := &AuthService{cache: cacheManager}
|
|
ctx := context.Background()
|
|
|
|
t.Run("ConsumeOAuthState invalid state", func(t *testing.T) {
|
|
_, err := svc.ConsumeOAuthState(ctx, "invalid_state")
|
|
if err == nil {
|
|
t.Error("Expected error for invalid state")
|
|
}
|
|
})
|
|
|
|
t.Run("ConsumeOAuthState valid state", func(t *testing.T) {
|
|
state, _ := svc.CreateOAuthState(ctx, "http://localhost/callback")
|
|
returnTo, err := svc.ConsumeOAuthState(ctx, state)
|
|
if err != nil {
|
|
t.Fatalf("ConsumeOAuthState failed: %v", err)
|
|
}
|
|
if returnTo != "http://localhost/callback" {
|
|
t.Errorf("Expected return URL 'http://localhost/callback', got %s", returnTo)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestAuthService_ConsumeOAuthStatePayload(t *testing.T) {
|
|
l1Cache := cache.NewL1Cache()
|
|
l2Cache := cache.NewRedisCache(false)
|
|
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
|
|
svc := &AuthService{cache: cacheManager}
|
|
ctx := context.Background()
|
|
|
|
t.Run("ConsumeOAuthStatePayload with bind purpose", func(t *testing.T) {
|
|
state, _ := svc.CreateOAuthBindState(ctx, 123, "http://localhost/callback")
|
|
payload, err := svc.ConsumeOAuthStatePayload(ctx, state)
|
|
if err != nil {
|
|
t.Fatalf("ConsumeOAuthStatePayload failed: %v", err)
|
|
}
|
|
if payload.Purpose != OAuthStatePurposeBind {
|
|
t.Errorf("Expected purpose 'bind', got %s", payload.Purpose)
|
|
}
|
|
if payload.UserID != 123 {
|
|
t.Errorf("Expected user ID 123, got %d", payload.UserID)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestAuthService_CreateOAuthHandoff(t *testing.T) {
|
|
l1Cache := cache.NewL1Cache()
|
|
l2Cache := cache.NewRedisCache(false)
|
|
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
|
|
svc := &AuthService{cache: cacheManager}
|
|
ctx := context.Background()
|
|
|
|
t.Run("CreateOAuthHandoff success", func(t *testing.T) {
|
|
loginResp := &LoginResponse{
|
|
AccessToken: "test_token",
|
|
RefreshToken: "test_refresh",
|
|
}
|
|
code, err := svc.CreateOAuthHandoff(ctx, loginResp)
|
|
if err != nil {
|
|
t.Fatalf("CreateOAuthHandoff failed: %v", err)
|
|
}
|
|
if code == "" {
|
|
t.Error("Expected non-empty code")
|
|
}
|
|
})
|
|
|
|
t.Run("CreateOAuthHandoff with nil response", func(t *testing.T) {
|
|
_, err := svc.CreateOAuthHandoff(ctx, nil)
|
|
if err == nil {
|
|
t.Error("Expected error for nil response")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestAuthService_ConsumeOAuthHandoff(t *testing.T) {
|
|
l1Cache := cache.NewL1Cache()
|
|
l2Cache := cache.NewRedisCache(false)
|
|
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
|
|
svc := &AuthService{cache: cacheManager}
|
|
ctx := context.Background()
|
|
|
|
t.Run("ConsumeOAuthHandoff invalid code", func(t *testing.T) {
|
|
_, err := svc.ConsumeOAuthHandoff(ctx, "invalid_code")
|
|
if err == nil {
|
|
t.Error("Expected error for invalid code")
|
|
}
|
|
})
|
|
|
|
t.Run("ConsumeOAuthHandoff valid code", func(t *testing.T) {
|
|
loginResp := &LoginResponse{
|
|
AccessToken: "test_token",
|
|
RefreshToken: "test_refresh",
|
|
}
|
|
code, _ := svc.CreateOAuthHandoff(ctx, loginResp)
|
|
resp, err := svc.ConsumeOAuthHandoff(ctx, code)
|
|
if err != nil {
|
|
t.Fatalf("ConsumeOAuthHandoff failed: %v", err)
|
|
}
|
|
if resp.AccessToken != "test_token" {
|
|
t.Errorf("Expected access token 'test_token', got %s", resp.AccessToken)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestAuthService_OAuthStateNilService(t *testing.T) {
|
|
var nilSvc *AuthService
|
|
ctx := context.Background()
|
|
|
|
t.Run("CreateOAuthState nil service", func(t *testing.T) {
|
|
_, err := nilSvc.CreateOAuthState(ctx, "http://localhost")
|
|
if err == nil {
|
|
t.Error("Expected error for nil service")
|
|
}
|
|
})
|
|
|
|
t.Run("ConsumeOAuthState nil service", func(t *testing.T) {
|
|
_, err := nilSvc.ConsumeOAuthState(ctx, "state")
|
|
if err == nil {
|
|
t.Error("Expected error for nil service")
|
|
}
|
|
})
|
|
|
|
t.Run("CreateOAuthHandoff nil service", func(t *testing.T) {
|
|
_, err := nilSvc.CreateOAuthHandoff(ctx, &LoginResponse{})
|
|
if err == nil {
|
|
t.Error("Expected error for nil service")
|
|
}
|
|
})
|
|
|
|
t.Run("ConsumeOAuthHandoff nil service", func(t *testing.T) {
|
|
_, err := nilSvc.ConsumeOAuthHandoff(ctx, "code")
|
|
if err == nil {
|
|
t.Error("Expected error for nil service")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestGenerateOAuthEphemeralCode(t *testing.T) {
|
|
code, err := generateOAuthEphemeralCode()
|
|
if err != nil {
|
|
t.Fatalf("generateOAuthEphemeralCode failed: %v", err)
|
|
}
|
|
if code == "" {
|
|
t.Error("Expected non-empty code")
|
|
}
|
|
// Should generate different codes
|
|
code2, _ := generateOAuthEphemeralCode()
|
|
if code == code2 {
|
|
t.Error("Expected different codes on each call")
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// Password Policy Tests
|
|
// =============================================================================
|
|
|
|
func TestAuthService_SetPasswordPolicy(t *testing.T) {
|
|
l1Cache := cache.NewL1Cache()
|
|
l2Cache := cache.NewRedisCache(false)
|
|
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
|
|
svc := &AuthService{cache: cacheManager}
|
|
|
|
t.Run("SetPasswordPolicy success", func(t *testing.T) {
|
|
policy := security.PasswordPolicy{
|
|
MinLength: 12,
|
|
RequireSpecial: true,
|
|
RequireNumber: true,
|
|
}
|
|
svc.SetPasswordPolicy(policy)
|
|
// Verify policy is set
|
|
if !svc.passwordPolicySet {
|
|
t.Error("Expected passwordPolicySet to be true")
|
|
}
|
|
if svc.passwordPolicy.MinLength != 12 {
|
|
t.Errorf("Expected MinLength 12, got %d", svc.passwordPolicy.MinLength)
|
|
}
|
|
})
|
|
|
|
t.Run("SetPasswordPolicy with defaults", func(t *testing.T) {
|
|
svc2 := &AuthService{cache: cacheManager}
|
|
policy := security.PasswordPolicy{} // Empty policy
|
|
svc2.SetPasswordPolicy(policy)
|
|
// Should normalize to default min length 8
|
|
if svc2.passwordPolicy.MinLength != 8 {
|
|
t.Errorf("Expected normalized MinLength 8, got %d", svc2.passwordPolicy.MinLength)
|
|
}
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// Social Account Helper Tests
|
|
// =============================================================================
|
|
|
|
func TestFindSocialAccountByProvider(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
accounts []*domain.SocialAccount
|
|
provider string
|
|
expectNil bool
|
|
}{
|
|
{
|
|
name: "nil accounts",
|
|
accounts: nil,
|
|
provider: "github",
|
|
expectNil: true,
|
|
},
|
|
{
|
|
name: "empty accounts",
|
|
accounts: []*domain.SocialAccount{},
|
|
provider: "github",
|
|
expectNil: true,
|
|
},
|
|
{
|
|
name: "found matching provider",
|
|
accounts: []*domain.SocialAccount{
|
|
{Provider: "github", OpenID: "123"},
|
|
{Provider: "google", OpenID: "456"},
|
|
},
|
|
provider: "github",
|
|
expectNil: false,
|
|
},
|
|
{
|
|
name: "case insensitive match",
|
|
accounts: []*domain.SocialAccount{
|
|
{Provider: "GitHub", OpenID: "123"},
|
|
},
|
|
provider: "github",
|
|
expectNil: false,
|
|
},
|
|
{
|
|
name: "provider not found",
|
|
accounts: []*domain.SocialAccount{
|
|
{Provider: "google", OpenID: "456"},
|
|
},
|
|
provider: "github",
|
|
expectNil: true,
|
|
},
|
|
{
|
|
name: "nil account in list",
|
|
accounts: []*domain.SocialAccount{
|
|
nil,
|
|
{Provider: "github", OpenID: "123"},
|
|
},
|
|
provider: "github",
|
|
expectNil: false,
|
|
},
|
|
{
|
|
name: "empty provider",
|
|
accounts: []*domain.SocialAccount{
|
|
{Provider: "github", OpenID: "123"},
|
|
},
|
|
provider: "",
|
|
expectNil: true,
|
|
},
|
|
{
|
|
name: "provider with spaces",
|
|
accounts: []*domain.SocialAccount{
|
|
{Provider: " github ", OpenID: "123"},
|
|
},
|
|
provider: "github",
|
|
expectNil: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := findSocialAccountByProvider(tt.accounts, tt.provider)
|
|
if (result == nil) != tt.expectNil {
|
|
t.Errorf("findSocialAccountByProvider() nil = %v, expectNil = %v", result == nil, tt.expectNil)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// Available Login Method Count Tests
|
|
// =============================================================================
|
|
|
|
func TestAuthService_AvailableLoginMethodCount(t *testing.T) {
|
|
l1Cache := cache.NewL1Cache()
|
|
l2Cache := cache.NewRedisCache(false)
|
|
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
|
|
|
|
t.Run("nil user", func(t *testing.T) {
|
|
svc := &AuthService{cache: cacheManager}
|
|
count := svc.availableLoginMethodCount(nil, nil, "")
|
|
if count != 0 {
|
|
t.Errorf("Expected 0 for nil user, got %d", count)
|
|
}
|
|
})
|
|
|
|
t.Run("password only", func(t *testing.T) {
|
|
svc := &AuthService{cache: cacheManager}
|
|
user := &domain.User{Password: "hashed_password"}
|
|
count := svc.availableLoginMethodCount(user, nil, "")
|
|
if count != 1 {
|
|
t.Errorf("Expected 1 for password only, got %d", count)
|
|
}
|
|
})
|
|
|
|
t.Run("password and email with email service", func(t *testing.T) {
|
|
email := "test@example.com"
|
|
svc := &AuthService{
|
|
cache: cacheManager,
|
|
emailCodeSvc: &EmailCodeService{},
|
|
}
|
|
user := &domain.User{Password: "hashed_password", Email: &email}
|
|
count := svc.availableLoginMethodCount(user, nil, "")
|
|
if count != 2 {
|
|
t.Errorf("Expected 2 for password and email, got %d", count)
|
|
}
|
|
})
|
|
|
|
t.Run("password and phone with sms service", func(t *testing.T) {
|
|
phone := "13800138000"
|
|
svc := &AuthService{
|
|
cache: cacheManager,
|
|
smsCodeSvc: &SMSCodeService{},
|
|
}
|
|
user := &domain.User{Password: "hashed_password", Phone: &phone}
|
|
count := svc.availableLoginMethodCount(user, nil, "")
|
|
if count != 2 {
|
|
t.Errorf("Expected 2 for password and phone, got %d", count)
|
|
}
|
|
})
|
|
|
|
t.Run("all methods", func(t *testing.T) {
|
|
email := "test@example.com"
|
|
phone := "13800138000"
|
|
svc := &AuthService{
|
|
cache: cacheManager,
|
|
emailCodeSvc: &EmailCodeService{},
|
|
smsCodeSvc: &SMSCodeService{},
|
|
}
|
|
user := &domain.User{Password: "hashed_password", Email: &email, Phone: &phone}
|
|
accounts := []*domain.SocialAccount{
|
|
{Provider: "github", Status: domain.SocialAccountStatusActive},
|
|
}
|
|
count := svc.availableLoginMethodCount(user, accounts, "")
|
|
if count != 4 {
|
|
t.Errorf("Expected 4 for all methods, got %d", count)
|
|
}
|
|
})
|
|
|
|
t.Run("exclude social provider", func(t *testing.T) {
|
|
email := "test@example.com"
|
|
svc := &AuthService{
|
|
cache: cacheManager,
|
|
emailCodeSvc: &EmailCodeService{},
|
|
}
|
|
user := &domain.User{Password: "hashed_password", Email: &email}
|
|
accounts := []*domain.SocialAccount{
|
|
{Provider: "github", Status: domain.SocialAccountStatusActive},
|
|
{Provider: "google", Status: domain.SocialAccountStatusActive},
|
|
}
|
|
count := svc.availableLoginMethodCount(user, accounts, "github")
|
|
// password + email + google (github excluded)
|
|
if count != 3 {
|
|
t.Errorf("Expected 3 with github excluded, got %d", count)
|
|
}
|
|
})
|
|
|
|
t.Run("inactive social accounts not counted", func(t *testing.T) {
|
|
svc := &AuthService{cache: cacheManager}
|
|
user := &domain.User{Password: "hashed_password"}
|
|
accounts := []*domain.SocialAccount{
|
|
{Provider: "github", Status: domain.SocialAccountStatusActive},
|
|
{Provider: "google", Status: 0}, // inactive
|
|
nil, // nil account
|
|
}
|
|
count := svc.availableLoginMethodCount(user, accounts, "")
|
|
// password + github only
|
|
if count != 2 {
|
|
t.Errorf("Expected 2 with inactive filtered, got %d", count)
|
|
}
|
|
})
|
|
|
|
t.Run("empty password not counted", func(t *testing.T) {
|
|
svc := &AuthService{cache: cacheManager}
|
|
user := &domain.User{Password: " "}
|
|
count := svc.availableLoginMethodCount(user, nil, "")
|
|
if count != 0 {
|
|
t.Errorf("Expected 0 for empty password, got %d", count)
|
|
}
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// Generate Unique Username Tests
|
|
// =============================================================================
|
|
|
|
func TestGenerateUniqueUsername(t *testing.T) {
|
|
t.Run("nil service returns sanitized username", func(t *testing.T) {
|
|
var nilSvc *AuthService
|
|
username, err := nilSvc.generateUniqueUsername(context.Background(), "Test User")
|
|
if err != nil {
|
|
t.Fatalf("Expected no error, got: %v", err)
|
|
}
|
|
if username != "test_user" {
|
|
t.Errorf("Expected 'test_user', got %q", username)
|
|
}
|
|
})
|
|
|
|
t.Run("service with nil userRepo returns sanitized username", func(t *testing.T) {
|
|
svc := &AuthService{}
|
|
username, err := svc.generateUniqueUsername(context.Background(), "John Doe")
|
|
if err != nil {
|
|
t.Fatalf("Expected no error, got: %v", err)
|
|
}
|
|
if username != "john_doe" {
|
|
t.Errorf("Expected 'john_doe', got %q", username)
|
|
}
|
|
})
|
|
|
|
t.Run("long username is truncated", func(t *testing.T) {
|
|
svc := &AuthService{}
|
|
longName := "this_is_a_very_long_username_that_should_be_truncated_to_forty_characters"
|
|
username, err := svc.generateUniqueUsername(context.Background(), longName)
|
|
if err != nil {
|
|
t.Fatalf("Expected no error, got: %v", err)
|
|
}
|
|
if len(username) > 50 {
|
|
t.Errorf("Username should be max 50 chars, got %d", len(username))
|
|
}
|
|
})
|
|
|
|
t.Run("empty base returns user", func(t *testing.T) {
|
|
svc := &AuthService{}
|
|
username, err := svc.generateUniqueUsername(context.Background(), "")
|
|
if err != nil {
|
|
t.Fatalf("Expected no error, got: %v", err)
|
|
}
|
|
if username != "user" {
|
|
t.Errorf("Expected 'user', got %q", username)
|
|
}
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// Set Login Log Repository Tests
|
|
// =============================================================================
|
|
|
|
func TestAuthService_SetLoginLogRepository(t *testing.T) {
|
|
svc := &AuthService{}
|
|
// Should not panic with nil
|
|
svc.SetLoginLogRepository(nil)
|
|
}
|
|
|
|
// =============================================================================
|
|
// Set Anomaly Detector Tests
|
|
// =============================================================================
|
|
|
|
func TestAuthService_SetAnomalyDetector(t *testing.T) {
|
|
svc := &AuthService{}
|
|
// Should not panic with nil
|
|
svc.SetAnomalyDetector(nil)
|
|
}
|
|
|
|
// =============================================================================
|
|
// Set Device Service Tests
|
|
// =============================================================================
|
|
|
|
func TestAuthService_SetDeviceService(t *testing.T) {
|
|
svc := &AuthService{}
|
|
// Should not panic with nil
|
|
svc.SetDeviceService(nil)
|
|
}
|
|
|
|
// =============================================================================
|
|
// Set SMS Code Service Tests
|
|
// =============================================================================
|
|
|
|
func TestAuthService_SetSMSCodeService(t *testing.T) {
|
|
svc := &AuthService{}
|
|
// Should not panic with nil
|
|
svc.SetSMSCodeService(nil)
|
|
}
|
|
|
|
// =============================================================================
|
|
// Available Login Method Count After Contact Removal Tests
|
|
// =============================================================================
|
|
|
|
func TestAuthService_AvailableLoginMethodCountAfterContactRemoval(t *testing.T) {
|
|
l1Cache := cache.NewL1Cache()
|
|
l2Cache := cache.NewRedisCache(false)
|
|
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
|
|
|
|
t.Run("nil user", func(t *testing.T) {
|
|
svc := &AuthService{cache: cacheManager}
|
|
count := svc.availableLoginMethodCountAfterContactRemoval(nil, nil, false, false)
|
|
if count != 0 {
|
|
t.Errorf("Expected 0 for nil user, got %d", count)
|
|
}
|
|
})
|
|
|
|
t.Run("password only no removal", func(t *testing.T) {
|
|
svc := &AuthService{cache: cacheManager}
|
|
user := &domain.User{Password: "hashed_password"}
|
|
count := svc.availableLoginMethodCountAfterContactRemoval(user, nil, false, false)
|
|
if count != 1 {
|
|
t.Errorf("Expected 1 for password only, got %d", count)
|
|
}
|
|
})
|
|
|
|
t.Run("password and email with email service", func(t *testing.T) {
|
|
email := "test@example.com"
|
|
svc := &AuthService{
|
|
cache: cacheManager,
|
|
emailCodeSvc: &EmailCodeService{},
|
|
}
|
|
user := &domain.User{Password: "hashed_password", Email: &email}
|
|
count := svc.availableLoginMethodCountAfterContactRemoval(user, nil, false, false)
|
|
if count != 2 {
|
|
t.Errorf("Expected 2 for password and email, got %d", count)
|
|
}
|
|
})
|
|
|
|
t.Run("remove email", func(t *testing.T) {
|
|
email := "test@example.com"
|
|
svc := &AuthService{
|
|
cache: cacheManager,
|
|
emailCodeSvc: &EmailCodeService{},
|
|
}
|
|
user := &domain.User{Password: "hashed_password", Email: &email}
|
|
count := svc.availableLoginMethodCountAfterContactRemoval(user, nil, true, false)
|
|
if count != 1 {
|
|
t.Errorf("Expected 1 after email removal, got %d", count)
|
|
}
|
|
})
|
|
|
|
t.Run("remove phone", func(t *testing.T) {
|
|
phone := "13800138000"
|
|
svc := &AuthService{
|
|
cache: cacheManager,
|
|
smsCodeSvc: &SMSCodeService{},
|
|
}
|
|
user := &domain.User{Password: "hashed_password", Phone: &phone}
|
|
count := svc.availableLoginMethodCountAfterContactRemoval(user, nil, false, true)
|
|
if count != 1 {
|
|
t.Errorf("Expected 1 after phone removal, got %d", count)
|
|
}
|
|
})
|
|
|
|
t.Run("social accounts counted", func(t *testing.T) {
|
|
svc := &AuthService{cache: cacheManager}
|
|
user := &domain.User{Password: "hashed_password"}
|
|
accounts := []*domain.SocialAccount{
|
|
{Provider: "github", Status: domain.SocialAccountStatusActive},
|
|
{Provider: "google", Status: domain.SocialAccountStatusActive},
|
|
}
|
|
count := svc.availableLoginMethodCountAfterContactRemoval(user, accounts, false, false)
|
|
if count != 3 {
|
|
t.Errorf("Expected 3 with social accounts, got %d", count)
|
|
}
|
|
})
|
|
|
|
t.Run("inactive social accounts not counted", func(t *testing.T) {
|
|
svc := &AuthService{cache: cacheManager}
|
|
user := &domain.User{Password: "hashed_password"}
|
|
accounts := []*domain.SocialAccount{
|
|
{Provider: "github", Status: domain.SocialAccountStatusActive},
|
|
{Provider: "google", Status: 0}, // inactive
|
|
nil,
|
|
}
|
|
count := svc.availableLoginMethodCountAfterContactRemoval(user, accounts, false, false)
|
|
if count != 2 {
|
|
t.Errorf("Expected 2 with inactive filtered, got %d", count)
|
|
}
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// Register OAuth Provider Tests
|
|
// =============================================================================
|
|
|
|
func TestAuthService_RegisterOAuthProvider(t *testing.T) {
|
|
t.Run("nil config does nothing", func(t *testing.T) {
|
|
svc := &AuthService{}
|
|
// Should not panic with nil config
|
|
svc.RegisterOAuthProvider("github", nil)
|
|
})
|
|
|
|
t.Run("nil oauth manager", func(t *testing.T) {
|
|
svc := &AuthService{}
|
|
cfg := &auth.OAuthConfig{ClientID: "test"}
|
|
// Should not panic with nil oauthManager
|
|
svc.RegisterOAuthProvider("github", cfg)
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// Best Effort Register Device Public Tests
|
|
// =============================================================================
|
|
|
|
func TestAuthService_BestEffortRegisterDevicePublic(t *testing.T) {
|
|
t.Run("nil service does not panic", func(t *testing.T) {
|
|
var nilSvc *AuthService
|
|
// Should not panic
|
|
nilSvc.BestEffortRegisterDevicePublic(context.Background(), 1, nil)
|
|
})
|
|
|
|
t.Run("nil device service does not panic", func(t *testing.T) {
|
|
svc := &AuthService{}
|
|
svc.BestEffortRegisterDevicePublic(context.Background(), 1, &LoginRequest{})
|
|
// Should not panic
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// Int Value and Int64 Value Tests
|
|
// =============================================================================
|
|
|
|
func TestIntValue(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input interface{}
|
|
expected int
|
|
wantOk bool
|
|
}{
|
|
{"int value", 42, 42, true},
|
|
{"int64 value", int64(100), 100, true},
|
|
{"float64 value", float64(99.0), 99, true},
|
|
{"float64 with decimal", float64(99.5), 99, true},
|
|
{"string value", "42", 0, false},
|
|
{"nil value", nil, 0, false},
|
|
{"negative int", -5, -5, true},
|
|
{"zero value", 0, 0, true},
|
|
{"large int64", int64(9999999999), 9999999999, true},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result, ok := intValue(tt.input)
|
|
if result != tt.expected || ok != tt.wantOk {
|
|
t.Errorf("intValue(%v) = (%d, %v), want (%d, %v)", tt.input, result, ok, tt.expected, tt.wantOk)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestInt64Value(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input interface{}
|
|
expected int64
|
|
wantOk bool
|
|
}{
|
|
{"int value", 42, 42, true},
|
|
{"int64 value", int64(100), 100, true},
|
|
{"float64 value", float64(99.0), 99, true},
|
|
{"float64 with decimal", float64(99.5), 99, true},
|
|
{"string value", "42", 0, false},
|
|
{"nil value", nil, 0, false},
|
|
{"negative int64", int64(-5), -5, true},
|
|
{"zero value", 0, 0, true},
|
|
{"large int64", int64(9999999999), 9999999999, true},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result, ok := int64Value(tt.input)
|
|
if result != tt.expected || ok != tt.wantOk {
|
|
t.Errorf("int64Value(%v) = (%d, %v), want (%d, %v)", tt.input, result, ok, tt.expected, tt.wantOk)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// Best Effort Update Last Login Tests
|
|
// =============================================================================
|
|
|
|
func TestBestEffortUpdateLastLogin(t *testing.T) {
|
|
t.Run("nil service does not panic", func(t *testing.T) {
|
|
var nilSvc *AuthService
|
|
// Should not panic
|
|
nilSvc.bestEffortUpdateLastLogin(context.Background(), 1, "127.0.0.1", "password")
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// Best Effort Assign Default Roles Tests
|
|
// =============================================================================
|
|
|
|
func TestBestEffortAssignDefaultRoles(t *testing.T) {
|
|
t.Run("nil service does not panic", func(t *testing.T) {
|
|
var nilSvc *AuthService
|
|
nilSvc.bestEffortAssignDefaultRoles(context.Background(), 1, "register")
|
|
})
|
|
|
|
t.Run("service without repos does not panic", func(t *testing.T) {
|
|
svc := &AuthService{}
|
|
svc.bestEffortAssignDefaultRoles(context.Background(), 1, "register")
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// Create OAuth State Payload Tests
|
|
// =============================================================================
|
|
|
|
func TestCreateOAuthStatePayload(t *testing.T) {
|
|
t.Run("nil service returns error", func(t *testing.T) {
|
|
var nilSvc *AuthService
|
|
_, err := nilSvc.createOAuthStatePayload(context.Background(), &OAuthStatePayload{Purpose: OAuthStatePurposeLogin})
|
|
if err == nil {
|
|
t.Error("Expected error for nil service")
|
|
}
|
|
})
|
|
|
|
t.Run("service without cache returns error", func(t *testing.T) {
|
|
svc := &AuthService{}
|
|
_, err := svc.createOAuthStatePayload(context.Background(), &OAuthStatePayload{Purpose: OAuthStatePurposeLogin})
|
|
if err == nil {
|
|
t.Error("Expected error when cache not configured")
|
|
}
|
|
})
|
|
|
|
t.Run("nil payload returns error", func(t *testing.T) {
|
|
l1Cache := cache.NewL1Cache()
|
|
l2Cache := cache.NewRedisCache(false)
|
|
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
|
|
svc := &AuthService{cache: cacheManager}
|
|
|
|
_, err := svc.createOAuthStatePayload(context.Background(), nil)
|
|
if err == nil {
|
|
t.Error("Expected error for nil payload")
|
|
}
|
|
})
|
|
|
|
t.Run("create state payload with cache", func(t *testing.T) {
|
|
l1Cache := cache.NewL1Cache()
|
|
l2Cache := cache.NewRedisCache(false)
|
|
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
|
|
svc := &AuthService{cache: cacheManager}
|
|
|
|
state, err := svc.createOAuthStatePayload(context.Background(), &OAuthStatePayload{
|
|
Purpose: OAuthStatePurposeLogin,
|
|
ReturnTo: "http://localhost/callback",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("createOAuthStatePayload failed: %v", err)
|
|
}
|
|
if state == "" {
|
|
t.Error("Expected non-empty state")
|
|
}
|
|
})
|
|
|
|
t.Run("create state payload with default purpose", func(t *testing.T) {
|
|
l1Cache := cache.NewL1Cache()
|
|
l2Cache := cache.NewRedisCache(false)
|
|
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
|
|
svc := &AuthService{cache: cacheManager}
|
|
|
|
state, err := svc.createOAuthStatePayload(context.Background(), &OAuthStatePayload{
|
|
ReturnTo: "http://localhost/callback",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("createOAuthStatePayload failed: %v", err)
|
|
}
|
|
if state == "" {
|
|
t.Error("Expected non-empty state")
|
|
}
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// Verify Phone Registration Tests
|
|
// =============================================================================
|
|
|
|
func TestVerifyPhoneRegistration(t *testing.T) {
|
|
t.Run("nil service returns nil for empty phone", func(t *testing.T) {
|
|
var nilSvc *AuthService
|
|
err := nilSvc.verifyPhoneRegistration(context.Background(), &RegisterRequest{Phone: ""})
|
|
if err != nil {
|
|
t.Errorf("Expected nil error for empty phone, got: %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("nil request returns nil", func(t *testing.T) {
|
|
svc := &AuthService{}
|
|
err := svc.verifyPhoneRegistration(context.Background(), nil)
|
|
if err != nil {
|
|
t.Errorf("Expected nil error for nil request, got: %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("service without SMS returns error", func(t *testing.T) {
|
|
svc := &AuthService{}
|
|
err := svc.verifyPhoneRegistration(context.Background(), &RegisterRequest{Phone: "13800138000", PhoneCode: "123456"})
|
|
if err == nil {
|
|
t.Error("Expected error when SMS service not configured")
|
|
}
|
|
})
|
|
|
|
t.Run("empty phone code returns error", func(t *testing.T) {
|
|
svc := &AuthService{smsCodeSvc: &SMSCodeService{}}
|
|
err := svc.verifyPhoneRegistration(context.Background(), &RegisterRequest{Phone: "13800138000", PhoneCode: ""})
|
|
if err == nil {
|
|
t.Error("Expected error when phone code is empty")
|
|
}
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// Consume OAuth State Payload Tests
|
|
// =============================================================================
|
|
|
|
func TestConsumeOAuthStatePayload(t *testing.T) {
|
|
t.Run("nil service returns error", func(t *testing.T) {
|
|
var nilSvc *AuthService
|
|
_, err := nilSvc.ConsumeOAuthStatePayload(context.Background(), "state123")
|
|
if err == nil {
|
|
t.Error("Expected error for nil service")
|
|
}
|
|
})
|
|
|
|
t.Run("service without cache returns error", func(t *testing.T) {
|
|
svc := &AuthService{}
|
|
_, err := svc.ConsumeOAuthStatePayload(context.Background(), "state123")
|
|
if err == nil {
|
|
t.Error("Expected error when cache not configured")
|
|
}
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// Consume OAuth Handoff Tests
|
|
// =============================================================================
|
|
|
|
func TestConsumeOAuthHandoff(t *testing.T) {
|
|
t.Run("nil service returns error", func(t *testing.T) {
|
|
var nilSvc *AuthService
|
|
_, err := nilSvc.ConsumeOAuthHandoff(context.Background(), "code123")
|
|
if err == nil {
|
|
t.Error("Expected error for nil service")
|
|
}
|
|
})
|
|
|
|
t.Run("service without cache returns error", func(t *testing.T) {
|
|
svc := &AuthService{}
|
|
_, err := svc.ConsumeOAuthHandoff(context.Background(), "code123")
|
|
if err == nil {
|
|
t.Error("Expected error when cache not configured")
|
|
}
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// Consume OAuth Handoff With Cache Tests
|
|
// =============================================================================
|
|
|
|
func TestConsumeOAuthHandoff_WithCache(t *testing.T) {
|
|
l1Cache := cache.NewL1Cache()
|
|
l2Cache := cache.NewRedisCache(false)
|
|
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
|
|
svc := &AuthService{cache: cacheManager}
|
|
ctx := context.Background()
|
|
|
|
t.Run("consume non-existent handoff", func(t *testing.T) {
|
|
_, err := svc.ConsumeOAuthHandoff(ctx, "nonexistent_code")
|
|
if err == nil {
|
|
t.Error("Expected error for non-existent handoff")
|
|
}
|
|
})
|
|
|
|
t.Run("consume handoff with pointer response", func(t *testing.T) {
|
|
resp := &LoginResponse{
|
|
AccessToken: "test_access_token",
|
|
RefreshToken: "test_refresh_token",
|
|
}
|
|
cacheManager.Set(ctx, "oauth_handoff:test_code_1", resp, time.Minute, time.Minute)
|
|
|
|
result, err := svc.ConsumeOAuthHandoff(ctx, "test_code_1")
|
|
if err != nil {
|
|
t.Fatalf("ConsumeOAuthHandoff failed: %v", err)
|
|
}
|
|
if result.AccessToken != "test_access_token" {
|
|
t.Errorf("Expected access token, got %s", result.AccessToken)
|
|
}
|
|
})
|
|
|
|
t.Run("consume handoff with value response", func(t *testing.T) {
|
|
resp := LoginResponse{
|
|
AccessToken: "value_access_token",
|
|
RefreshToken: "value_refresh_token",
|
|
}
|
|
cacheManager.Set(ctx, "oauth_handoff:test_code_2", resp, time.Minute, time.Minute)
|
|
|
|
result, err := svc.ConsumeOAuthHandoff(ctx, "test_code_2")
|
|
if err != nil {
|
|
t.Fatalf("ConsumeOAuthHandoff failed: %v", err)
|
|
}
|
|
if result.AccessToken != "value_access_token" {
|
|
t.Errorf("Expected access token, got %s", result.AccessToken)
|
|
}
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// Consume OAuth State Payload With Cache Tests
|
|
// =============================================================================
|
|
|
|
func TestConsumeOAuthStatePayload_WithCache(t *testing.T) {
|
|
l1Cache := cache.NewL1Cache()
|
|
l2Cache := cache.NewRedisCache(false)
|
|
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
|
|
svc := &AuthService{cache: cacheManager}
|
|
ctx := context.Background()
|
|
|
|
t.Run("consume non-existent state", func(t *testing.T) {
|
|
_, err := svc.ConsumeOAuthStatePayload(ctx, "nonexistent_state")
|
|
if err == nil {
|
|
t.Error("Expected error for non-existent state")
|
|
}
|
|
})
|
|
|
|
t.Run("consume state with pointer payload", func(t *testing.T) {
|
|
payload := &OAuthStatePayload{
|
|
Purpose: OAuthStatePurposeLogin,
|
|
ReturnTo: "http://localhost/callback",
|
|
}
|
|
cacheManager.Set(ctx, "oauth_state:test_state_1", payload, time.Minute*10, time.Minute*10)
|
|
|
|
result, err := svc.ConsumeOAuthStatePayload(ctx, "test_state_1")
|
|
if err != nil {
|
|
t.Fatalf("ConsumeOAuthStatePayload failed: %v", err)
|
|
}
|
|
if result.Purpose != OAuthStatePurposeLogin {
|
|
t.Errorf("Expected purpose %s, got %s", OAuthStatePurposeLogin, result.Purpose)
|
|
}
|
|
})
|
|
|
|
t.Run("consume state with value payload", func(t *testing.T) {
|
|
payload := OAuthStatePayload{
|
|
Purpose: OAuthStatePurposeBind,
|
|
ReturnTo: "http://localhost/bind",
|
|
UserID: 123,
|
|
}
|
|
cacheManager.Set(ctx, "oauth_state:test_state_2", payload, time.Minute*10, time.Minute*10)
|
|
|
|
result, err := svc.ConsumeOAuthStatePayload(ctx, "test_state_2")
|
|
if err != nil {
|
|
t.Fatalf("ConsumeOAuthStatePayload failed: %v", err)
|
|
}
|
|
if result.UserID != 123 {
|
|
t.Errorf("Expected UserID 123, got %d", result.UserID)
|
|
}
|
|
})
|
|
}
|