// Package service contains the application services of the user management
// system. This file implements AuthService: registration, password / SMS /
// OAuth login, token refresh and revocation, and social account binding.
package service

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"strings"
	"time"
	"unicode"
	"unicode/utf8"

	"github.com/user-management-system/internal/auth"
	"github.com/user-management-system/internal/cache"
	"github.com/user-management-system/internal/domain"
	"github.com/user-management-system/internal/repository"
	"github.com/user-management-system/internal/security"
)

const (
	// userInfoCachePrefix prefixes the cache key of a cached UserInfo; the
	// decimal user ID is appended.
	userInfoCachePrefix = "auth_user_info:"
	// tokenBlacklistPrefix prefixes the cache key of a revoked token; the
	// token's JTI is appended.
	tokenBlacklistPrefix = "auth_token_blacklist:"
	// defaultUserCacheTTL is how long a cached UserInfo is kept.
	defaultUserCacheTTL = 15 * time.Minute
	// defaultBlacklistTTL is used when a revoked token carries no usable
	// expiry from which a TTL can be derived.
	defaultBlacklistTTL = time.Hour
	// defaultPasswordMinLen is the minimum password length used when no
	// positive length is configured.
	defaultPasswordMinLen = 8
)

// userRepositoryInterface is the user persistence contract consumed by the
// services in this package.
type userRepositoryInterface interface {
	Create(ctx context.Context, user *domain.User) error
	Update(ctx context.Context, user *domain.User) error
	UpdateTOTP(ctx context.Context, user *domain.User) error
	Delete(ctx context.Context, id int64) error
	GetByID(ctx context.Context, id int64) (*domain.User, error)
	GetByUsername(ctx context.Context, username string) (*domain.User, error)
	GetByEmail(ctx context.Context, email string) (*domain.User, error)
	GetByPhone(ctx context.Context, phone string) (*domain.User, error)
	List(ctx context.Context, offset, limit int) ([]*domain.User, int64, error)
	ListByStatus(ctx context.Context, status domain.UserStatus, offset, limit int) ([]*domain.User, int64, error)
	UpdateStatus(ctx context.Context, id int64, status domain.UserStatus) error
	UpdateLastLogin(ctx context.Context, id int64, ip string) error
	ExistsByUsername(ctx context.Context, username string) (bool, error)
	ExistsByEmail(ctx context.Context, email string) (bool, error)
	ExistsByPhone(ctx context.Context, phone string) (bool, error)
	Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.User, int64, error)
}

// userRoleRepositoryInterface is the user-role persistence contract.
type userRoleRepositoryInterface interface {
	BatchCreate(ctx context.Context, userRoles []*domain.UserRole) error
	GetUserIDByRoleID(ctx context.Context, roleID int64) ([]int64, error)
}

// roleRepositoryInterface is the role lookup contract.
type roleRepositoryInterface interface {
	GetDefaultRoles(ctx context.Context) ([]*domain.Role, error)
	GetByCode(ctx context.Context, code string) (*domain.Role, error)
}

// loginLogRepositoryInterface persists login audit records.
type loginLogRepositoryInterface interface {
	Create(ctx context.Context, loginRecord *domain.LoginLog) error
}

// anomalyRecorder records a login attempt and returns any anomaly events that
// the attempt produced.
type anomalyRecorder interface {
	RecordLogin(ctx context.Context, userID int64, ip, location, deviceFingerprint string, success bool) []security.AnomalyEvent
}

// PasswordStrengthInfo describes which character classes a password contains.
// Score is the number of classes present (0-4) and Length is counted in runes.
type PasswordStrengthInfo struct {
	Score      int  `json:"score"`
	Length     int  `json:"length"`
	HasUpper   bool `json:"has_upper"`
	HasLower   bool `json:"has_lower"`
	HasDigit   bool `json:"has_digit"`
	HasSpecial bool `json:"has_special"`
}

// RegisterRequest is the payload accepted by AuthService.Register.
type RegisterRequest struct {
	Username  string `json:"username" binding:"required"`
	Email     string `json:"email"`
	Phone     string `json:"phone"`
	PhoneCode string `json:"phone_code"`
	Password  string `json:"password" binding:"required"`
	Nickname  string `json:"nickname"`
}

// LoginRequest is the payload accepted by AuthService.Login. The account may
// be supplied through any of Account, Username, Email or Phone.
type LoginRequest struct {
	Account       string `json:"account"`
	Username      string `json:"username"`
	Email         string `json:"email"`
	Phone         string `json:"phone"`
	Password      string `json:"password"`
	Remember      bool   `json:"remember"`                 // remember this login
	DeviceID      string `json:"device_id,omitempty"`      // unique device identifier
	DeviceName    string `json:"device_name,omitempty"`    // device name
	DeviceBrowser string `json:"device_browser,omitempty"` // browser
	DeviceOS      string `json:"device_os,omitempty"`      // operating system
}

// GetAccount returns the first non-blank value among Account, Username, Email
// and Phone (in that order) with surrounding whitespace removed. It returns
// "" for a nil receiver or when every candidate is blank.
func (r *LoginRequest) GetAccount() string {
	if r == nil {
		return ""
	}
	for _, candidate := range []string{r.Account, r.Username, r.Email, r.Phone} {
		if trimmed := strings.TrimSpace(candidate); trimmed != "" {
			return trimmed
		}
	}
	return ""
}

// UserInfo is the public projection of a user returned to clients and stored
// in the user-info cache.
type UserInfo struct {
	ID       int64             `json:"id"`
	Username string            `json:"username"`
	Email    string            `json:"email,omitempty"`
	Phone    string            `json:"phone,omitempty"`
	Nickname string            `json:"nickname,omitempty"`
	Avatar   string            `json:"avatar,omitempty"`
	Status   domain.UserStatus `json:"status"`
}

// LoginResponse is returned by every successful login or token refresh.
// ExpiresIn is the access token lifetime in seconds.
type LoginResponse struct {
	AccessToken  string    `json:"access_token"`
	RefreshToken string    `json:"refresh_token"`
	ExpiresIn    int64     `json:"expires_in"`
	User         *UserInfo `json:"user"`
}

// LogoutRequest carries the tokens that Logout should revoke.
type LogoutRequest struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`
}

// AuthService implements registration, login, token management and social
// account binding. Only the dependencies passed to NewAuthService are set by
// the constructor; the rest are optional and injected through the Set*
// methods, and every method checks the dependencies it needs for nil.
type AuthService struct {
	userRepo           userRepositoryInterface
	socialRepo         repository.SocialAccountRepository
	jwtManager         *auth.JWT
	cache              *cache.CacheManager
	passwordMinLength  int
	maxLoginAttempts   int
	loginLockDuration  time.Duration
	userRoleRepo       userRoleRepositoryInterface
	roleRepo           roleRepositoryInterface
	loginLogRepo       loginLogRepositoryInterface
	webhookSvc         *WebhookService
	passwordPolicy     security.PasswordPolicy
	passwordPolicySet  bool
	anomalyDetector    anomalyRecorder
	smsCodeSvc         *SMSCodeService
	emailActivationSvc *EmailActivationService
	emailCodeSvc       *EmailCodeService
	oauthManager       auth.OAuthManager
	deviceService      *DeviceService
}

// NewAuthService builds an AuthService. Non-positive limits are replaced by
// defaults: defaultPasswordMinLen for passwordMinLength, 5 for
// maxLoginAttempts and 15 minutes for loginLockDuration. A fresh OAuth
// manager is created via auth.NewOAuthManager.
func NewAuthService(
	userRepo userRepositoryInterface,
	socialRepo repository.SocialAccountRepository,
	jwtManager *auth.JWT,
	cacheManager *cache.CacheManager,
	passwordMinLength int,
	maxLoginAttempts int,
	loginLockDuration time.Duration,
) *AuthService {
	if passwordMinLength <= 0 {
		passwordMinLength = defaultPasswordMinLen
	}
	if maxLoginAttempts <= 0 {
		maxLoginAttempts = 5
	}
	if loginLockDuration <= 0 {
		loginLockDuration = 15 * time.Minute
	}

	return &AuthService{
		userRepo:          userRepo,
		socialRepo:        socialRepo,
		jwtManager:        jwtManager,
		cache:             cacheManager,
		passwordMinLength: passwordMinLength,
		maxLoginAttempts:  maxLoginAttempts,
		loginLockDuration: loginLockDuration,
		oauthManager:      auth.NewOAuthManager(),
	}
}

// SetWebhookService sets the service used by publishEvent.
func (s *AuthService) SetWebhookService(webhookSvc *WebhookService) {
	s.webhookSvc = webhookSvc
}

// SetRoleRepositories sets the user-role and role repositories.
func (s *AuthService) SetRoleRepositories(userRoleRepo userRoleRepositoryInterface, roleRepo roleRepositoryInterface) {
	s.userRoleRepo = userRoleRepo
	s.roleRepo = roleRepo
}

// SetLoginLogRepository sets the repository used by writeLoginLog.
func (s *AuthService) SetLoginLogRepository(loginLogRepo loginLogRepositoryInterface) {
	s.loginLogRepo = loginLogRepo
}

// SetPasswordPolicy stores the normalized policy and makes validatePassword
// use it instead of the built-in strength check.
func (s *AuthService) SetPasswordPolicy(policy security.PasswordPolicy) {
	s.passwordPolicy = policy.Normalize()
	s.passwordPolicySet = true
}

// SetAnomalyDetector sets the recorder used by recordLoginAnomaly.
func (s *AuthService) SetAnomalyDetector(detector anomalyRecorder) {
	s.anomalyDetector = detector
}

// SetDeviceService sets the device service used for device registration and
// trusted-device checks.
func (s *AuthService) SetDeviceService(svc *DeviceService) {
	s.deviceService = svc
}

// SetSMSCodeService sets the SMS code service used by LoginByCode.
func (s *AuthService) SetSMSCodeService(svc *SMSCodeService) {
	s.smsCodeSvc = svc
}

// sanitizeUsername derives a username from arbitrary text. Letters and digits
// are lower-cased and kept, '.', '-' and '_' are kept, a run of whitespace
// becomes a single '_', and every other rune is dropped. Leading and trailing
// separators are removed and the result is capped at 50 runes. "user" is
// returned when nothing usable remains.
func sanitizeUsername(value string) string {
	const (
		fallback   = "user"
		separators = "._-"
		maxRunes   = 50
	)

	trimmed := strings.TrimSpace(value)
	if trimmed == "" {
		return fallback
	}

	var builder strings.Builder
	lastUnderscore := false
	for _, r := range trimmed {
		switch {
		case unicode.IsLetter(r) || unicode.IsDigit(r):
			builder.WriteRune(unicode.ToLower(r))
			lastUnderscore = false
		case r == '.' || r == '-' || r == '_':
			builder.WriteRune(r)
			lastUnderscore = false
		case unicode.IsSpace(r):
			// Collapse consecutive whitespace and never start with '_'.
			if !lastUnderscore && builder.Len() > 0 {
				builder.WriteByte('_')
				lastUnderscore = true
			}
		}
	}

	result := strings.Trim(builder.String(), separators)
	if result == "" {
		return fallback
	}
	if runes := []rune(result); len(runes) > maxRunes {
		// Cutting at the limit can leave a separator at the end again; the
		// first rune is a letter or digit, so this never empties the result.
		result = strings.TrimRight(string(runes[:maxRunes]), separators)
	}
	return result
}

// generateUniqueUsername returns sanitizeUsername(base) if it is free,
// otherwise the first free "<name>_<n>" for n in 1..1000, where <name> is the
// sanitized base capped at 40 runes. Without a user repository the sanitized
// name is returned unchecked. Repository errors are returned as-is.
func (s *AuthService) generateUniqueUsername(ctx context.Context, base string) (string, error) {
	username := sanitizeUsername(base)
	if s == nil || s.userRepo == nil {
		return username, nil
	}

	exists, err := s.userRepo.ExistsByUsername(ctx, username)
	if err != nil {
		return "", err
	}
	if !exists {
		return username, nil
	}

	if baseRunes := []rune(username); len(baseRunes) > 40 {
		// Avoid candidates such as "name__1" when the cut lands on a separator.
		username = strings.TrimRight(string(baseRunes[:40]), "._-")
	}

	for i := 1; i <= 1000; i++ {
		candidate := fmt.Sprintf("%s_%d", username, i)
		exists, err = s.userRepo.ExistsByUsername(ctx, candidate)
		if err != nil {
			return "", err
		}
		if !exists {
			return candidate, nil
		}
	}
	return "", errors.New("unable to generate unique username")
}

// validatePasswordStrength checks password against a minimum rune length
// (defaultPasswordMinLen when minLength is not positive). In strict mode the
// password must contain upper-case, lower-case and digit characters;
// otherwise it must contain at least two character classes.
func validatePasswordStrength(password string, minLength int, strict bool) error {
	if minLength <= 0 {
		minLength = defaultPasswordMinLen
	}

	info := GetPasswordStrength(password)
	if info.Length < minLength {
		return fmt.Errorf("密码长度不能少于%d位", minLength)
	}

	if strict {
		if !info.HasUpper || !info.HasLower || !info.HasDigit {
			return errors.New("密码必须包含大小写字母和数字")
		}
		return nil
	}

	if info.Score < 2 {
		return errors.New("密码强度不足")
	}
	return nil
}

// GetPasswordStrength reports the rune length of password and which of the
// four character classes (upper, lower, digit, punctuation/symbol) it
// contains. Score is the number of classes present.
func GetPasswordStrength(password string) PasswordStrengthInfo {
	info := PasswordStrengthInfo{
		Length: utf8.RuneCountInString(password),
	}

	for _, r := range password {
		switch {
		case unicode.IsUpper(r):
			info.HasUpper = true
		case unicode.IsLower(r):
			info.HasLower = true
		case unicode.IsDigit(r):
			info.HasDigit = true
		case unicode.IsPunct(r) || unicode.IsSymbol(r):
			info.HasSpecial = true
		}
	}

	for _, present := range []bool{info.HasUpper, info.HasLower, info.HasDigit, info.HasSpecial} {
		if present {
			info.Score++
		}
	}
	return info
}

// validatePassword validates password with the configured password policy if
// one was set, otherwise with the non-strict built-in strength check using
// the configured (or default) minimum length.
func (s *AuthService) validatePassword(password string) error {
	if s != nil && s.passwordPolicySet {
		return s.passwordPolicy.Validate(password)
	}

	minLength := defaultPasswordMinLen
	if s != nil && s.passwordMinLength > 0 {
		minLength = s.passwordMinLength
	}
	return validatePasswordStrength(password, minLength, false)
}

// accessTokenTTLSeconds returns the access token lifetime in seconds, or 0
// when no JWT manager is configured.
func (s *AuthService) accessTokenTTLSeconds() int64 {
	if s == nil || s.jwtManager == nil {
		return 0
	}
	return int64(s.jwtManager.GetAccessTokenExpire().Seconds())
}

// RefreshTokenTTLSeconds returns the refresh token lifetime in seconds, or 0
// when no JWT manager is configured.
func (s *AuthService) RefreshTokenTTLSeconds() int64 {
	if s == nil || s.jwtManager == nil {
		return 0
	}
	return int64(s.jwtManager.GetRefreshTokenExpire().Seconds())
}

// buildUserInfo projects user onto a new UserInfo. It returns nil for a nil
// user.
func (s *AuthService) buildUserInfo(user *domain.User) *UserInfo {
	if user == nil {
		return nil
	}
	return &UserInfo{
		ID:       user.ID,
		Username: user.Username,
		Email:    domain.DerefStr(user.Email),
		Phone:    domain.DerefStr(user.Phone),
		Nickname: user.Nickname,
		Avatar:   user.Avatar,
		Status:   user.Status,
	}
}

// ensureUserActive returns nil only for an existing user whose status is
// domain.UserStatusActive; every other case yields a descriptive error.
func (s *AuthService) ensureUserActive(user *domain.User) error {
	if user == nil {
		return errors.New("用户不存在")
	}

	switch user.Status {
	case domain.UserStatusActive:
		return nil
	case domain.UserStatusInactive:
		return errors.New("账号未激活")
	case domain.UserStatusLocked:
		return errors.New("账号已锁定")
	case domain.UserStatusDisabled:
		return errors.New("账号已禁用")
	default:
		return errors.New("账号状态异常")
	}
}

// blacklistTokenClaims validates token with validate and, when it yields
// claims with a non-blank JTI, stores that JTI in the blacklist until the
// token expires (defaultBlacklistTTL when no future expiry is available).
// It is a no-op returning nil when the cache is missing, the token is blank,
// validate is nil, or validation fails; only a cache write error is returned.
func (s *AuthService) blacklistTokenClaims(ctx context.Context, token string, validate func(string) (*auth.Claims, error)) error {
	if s == nil || s.cache == nil {
		return nil
	}

	token = strings.TrimSpace(token)
	if token == "" || validate == nil {
		return nil
	}

	claims, err := validate(token)
	if err != nil || claims == nil || strings.TrimSpace(claims.JTI) == "" {
		// A token that does not validate needs no blacklist entry.
		return nil
	}

	ttl := defaultBlacklistTTL
	if claims.ExpiresAt != nil {
		if until := time.Until(claims.ExpiresAt.Time); until > 0 {
			ttl = until
		}
	}

	// NOTE(review): the two duration arguments are presumably per-tier TTLs of
	// cache.CacheManager - confirm against its documentation.
	return s.cache.Set(ctx, tokenBlacklistPrefix+claims.JTI, true, ttl, ttl)
}

// recordLoginAnomaly forwards the attempt to the anomaly detector and
// publishes domain.EventAnomalyDetected when the detector returns events. It
// does nothing without a detector or a user ID.
func (s *AuthService) recordLoginAnomaly(ctx context.Context, userID *int64, ip, location, deviceFingerprint string, success bool) {
	if s == nil || s.anomalyDetector == nil || userID == nil {
		return
	}

	events := s.anomalyDetector.RecordLogin(ctx, *userID, ip, location, deviceFingerprint, success)
	if len(events) == 0 {
		return
	}

	s.publishEvent(ctx, domain.EventAnomalyDetected, map[string]interface{}{
		"user_id":  *userID,
		"ip":       ip,
		"location": location,
		"device":   deviceFingerprint,
		"events":   events,
		"success":  success,
	})
}

// publishEvent hands the event to the webhook service on a new goroutine. It
// does nothing when no webhook service is configured.
//
// NOTE(review): ctx is the caller's (usually request-scoped) context. If
// WebhookService.Publish honours cancellation, events may be dropped once the
// request finishes - confirm, and consider a detached context.
func (s *AuthService) publishEvent(ctx context.Context, eventType domain.WebhookEventType, data interface{}) {
	if s == nil || s.webhookSvc == nil {
		return
	}
	go s.webhookSvc.Publish(ctx, eventType, data)
}

// writeLoginLog stores a login audit record on a background goroutine. Status
// is 1 for success and 0 for failure. Persistence errors are only logged. It
// does nothing when no login log repository is configured.
func (s *AuthService) writeLoginLog(
	ctx context.Context,
	userID *int64,
	loginType domain.LoginType,
	ip string,
	success bool,
	failReason string,
) {
	if s == nil || s.loginLogRepo == nil {
		return
	}

	status := 0
	if success {
		status = 1
	}

	// Copy the ID so the background goroutine neither aliases the caller's
	// user struct nor logs a pointer address.
	var (
		recordUserID   *int64
		userIDForLog   interface{} = "<nil>"
		loginTypeValue             = loginType
	)
	if userID != nil {
		id := *userID
		recordUserID = &id
		userIDForLog = id
	}

	loginRecord := &domain.LoginLog{
		UserID:     recordUserID,
		LoginType:  int(loginType),
		IP:         ip,
		Status:     status,
		FailReason: failReason,
	}

	go func() {
		// The write deliberately uses its own context so that it is not tied
		// to the lifetime of ctx.
		if err := s.loginLogRepo.Create(context.Background(), loginRecord); err != nil {
			log.Printf("auth: write login log failed, user_id=%v login_type=%d err=%v", userIDForLog, loginTypeValue, err)
		}
	}()
}

// incrementFailAttempts increments the failed-attempt counter stored under
// key, refreshing its TTL to loginLockDuration, and returns the new count. It
// returns 0 without a cache or key. A cache write error is logged and the
// incremented count is still returned.
//
// NOTE(review): this is a read-modify-write on the cache, so concurrent
// failures for the same key can lose increments.
func (s *AuthService) incrementFailAttempts(ctx context.Context, key string) int {
	if s == nil || s.cache == nil || key == "" {
		return 0
	}

	current := 0
	if value, ok := s.cache.Get(ctx, key); ok {
		current = attemptCount(value)
	}
	current++

	if err := s.cache.Set(ctx, key, current, s.loginLockDuration, s.loginLockDuration); err != nil {
		log.Printf("auth: store login attempts failed, key=%s err=%v", key, err)
	}
	return current
}

// isValidPhoneSimple reports whether phone passes isValidPhone.
func isValidPhoneSimple(phone string) bool {
	return isValidPhone(phone)
}

// buildDeviceFingerprint builds the device fingerprint string: the non-empty
// values of DeviceID, DeviceName, DeviceBrowser and DeviceOS joined with "|".
// It returns "" for a nil request or when all four fields are empty.
func buildDeviceFingerprint(req *LoginRequest) string {
	if req == nil {
		return ""
	}

	fields := [...]string{req.DeviceID, req.DeviceName, req.DeviceBrowser, req.DeviceOS}
	parts := make([]string, 0, len(fields))
	for _, field := range fields {
		if field != "" {
			parts = append(parts, field)
		}
	}
	return strings.Join(parts, "|")
}

// bestEffortRegisterDevice tries to register or update the device record for
// the login. It does nothing without a device service or a device ID.
func (s *AuthService) bestEffortRegisterDevice(ctx context.Context, userID int64, req *LoginRequest) {
	if s == nil || s.deviceService == nil || req == nil || req.DeviceID == "" {
		return
	}

	createReq := &CreateDeviceRequest{
		DeviceID:      req.DeviceID,
		DeviceName:    req.DeviceName,
		DeviceBrowser: req.DeviceBrowser,
		DeviceOS:      req.DeviceOS,
	}
	// Best effort: a device bookkeeping failure must not fail the login.
	_, _ = s.deviceService.CreateDevice(ctx, userID, createReq)
}

// cacheUserInfo stores the UserInfo projection of user in the cache for
// defaultUserCacheTTL. It does nothing without a cache or user.
func (s *AuthService) cacheUserInfo(ctx context.Context, user *domain.User) {
	if s == nil || s.cache == nil || user == nil {
		return
	}

	info := s.buildUserInfo(user)
	if info == nil {
		return
	}
	// Best effort: the cache is only an accelerator, GetUserInfo falls back to
	// the repository on a miss.
	_ = s.cache.Set(ctx, userInfoCachePrefix+fmt.Sprintf("%d", user.ID), info, defaultUserCacheTTL, defaultUserCacheTTL)
}

// userInfoFromCacheValue converts a cached value into a *UserInfo. It accepts
// *UserInfo, UserInfo and map[string]interface{} (re-decoded through JSON)
// and reports false for anything else, including a nil *UserInfo. The result
// is always a copy, so callers cannot mutate the cached value.
func userInfoFromCacheValue(value interface{}) (*UserInfo, bool) {
	switch typed := value.(type) {
	case *UserInfo:
		if typed == nil {
			// A typed nil must be a miss, not a (nil, true) hit.
			return nil, false
		}
		userInfo := *typed
		return &userInfo, true
	case UserInfo:
		userInfo := typed
		return &userInfo, true
	case map[string]interface{}:
		payload, err := json.Marshal(typed)
		if err != nil {
			return nil, false
		}
		var userInfo UserInfo
		if err := json.Unmarshal(payload, &userInfo); err != nil {
			return nil, false
		}
		return &userInfo, true
	default:
		return nil, false
	}
}

// Register creates a new active user. Username, Email and Phone are trimmed
// in place on req. The password is validated, phone registration is verified
// through verifyPhoneRegistration, and username / email / phone must be
// unused. On success default roles are assigned (best effort), the user is
// cached and domain.EventUserRegistered is published.
func (s *AuthService) Register(ctx context.Context, req *RegisterRequest) (*UserInfo, error) {
	if req == nil {
		return nil, errors.New("注册请求不能为空")
	}
	if s == nil || s.userRepo == nil {
		return nil, errors.New("user repository is not configured")
	}

	req.Username = strings.TrimSpace(req.Username)
	req.Email = strings.TrimSpace(req.Email)
	req.Phone = strings.TrimSpace(req.Phone)

	if req.Username == "" {
		return nil, errors.New("用户名不能为空")
	}
	if req.Password == "" {
		return nil, errors.New("密码不能为空")
	}
	if req.Phone != "" && !isValidPhoneSimple(req.Phone) {
		return nil, errors.New("手机号格式不正确")
	}
	if err := s.validatePassword(req.Password); err != nil {
		return nil, err
	}
	if err := s.verifyPhoneRegistration(ctx, req); err != nil {
		return nil, err
	}

	exists, err := s.userRepo.ExistsByUsername(ctx, req.Username)
	if err != nil {
		return nil, err
	}
	if exists {
		return nil, errors.New("用户名已存在")
	}

	if req.Email != "" {
		exists, err = s.userRepo.ExistsByEmail(ctx, req.Email)
		if err != nil {
			return nil, err
		}
		if exists {
			return nil, errors.New("邮箱已存在")
		}
	}

	if req.Phone != "" {
		exists, err = s.userRepo.ExistsByPhone(ctx, req.Phone)
		if err != nil {
			return nil, err
		}
		if exists {
			return nil, errors.New("手机号已存在")
		}
	}

	hashedPassword, err := auth.HashPassword(req.Password)
	if err != nil {
		return nil, err
	}

	nickname := strings.TrimSpace(req.Nickname)
	if nickname == "" {
		nickname = req.Username
	}

	user := &domain.User{
		Username: req.Username,
		Email:    domain.StrPtr(req.Email),
		Phone:    domain.StrPtr(req.Phone),
		Password: hashedPassword,
		Nickname: nickname,
		Status:   domain.UserStatusActive,
	}
	if err := s.userRepo.Create(ctx, user); err != nil {
		return nil, err
	}

	s.bestEffortAssignDefaultRoles(ctx, user.ID, "register")
	s.cacheUserInfo(ctx, user)

	userInfo := s.buildUserInfo(user)
	s.publishEvent(ctx, domain.EventUserRegistered, userInfo)
	return userInfo, nil
}

// Login authenticates an account with a password and returns a token pair.
//
// Failed attempts are counted per attempt key; once the count reaches
// maxLoginAttempts further logins are rejected until the counter expires.
// Unknown accounts and wrong passwords yield the same error. Every outcome is
// written to the login log; failures and successes for known users are also
// reported to the anomaly detector and published as webhook events.
func (s *AuthService) Login(ctx context.Context, req *LoginRequest, ip string) (*LoginResponse, error) {
	if req == nil {
		return nil, errors.New("登录请求不能为空")
	}
	if s == nil || s.userRepo == nil || s.jwtManager == nil {
		return nil, errors.New("auth service is not fully configured")
	}

	account := req.GetAccount()
	if account == "" {
		return nil, errors.New("账号不能为空")
	}
	if strings.TrimSpace(req.Password) == "" {
		return nil, errors.New("密码不能为空")
	}

	// Build the device fingerprint.
	deviceFingerprint := buildDeviceFingerprint(req)

	user, err := s.findUserForLogin(ctx, account)
	if err != nil && !isUserNotFoundError(err) {
		s.writeLoginLog(ctx, nil, domain.LoginTypePassword, ip, false, err.Error())
		return nil, err
	}

	attemptKey := loginAttemptKey(account, user)
	if s.cache != nil {
		if value, ok := s.cache.Get(ctx, attemptKey); ok && attemptCount(value) >= s.maxLoginAttempts {
			lockErr := errors.New("账号已锁定,请稍后再试")
			// Attribute the rejected attempt to the user when it is known.
			var lockedUserID *int64
			if user != nil {
				lockedUserID = &user.ID
			}
			s.writeLoginLog(ctx, lockedUserID, domain.LoginTypePassword, ip, false, lockErr.Error())
			return nil, lockErr
		}
	}

	if user == nil {
		s.incrementFailAttempts(ctx, attemptKey)
		s.writeLoginLog(ctx, nil, domain.LoginTypePassword, ip, false, "用户不存在")
		// Same message as a wrong password so accounts cannot be enumerated.
		return nil, errors.New("账号或密码错误")
	}

	if err := s.ensureUserActive(user); err != nil {
		s.writeLoginLog(ctx, &user.ID, domain.LoginTypePassword, ip, false, err.Error())
		s.recordLoginAnomaly(ctx, &user.ID, ip, "", deviceFingerprint, false)
		return nil, err
	}

	if !auth.VerifyPassword(user.Password, req.Password) {
		failCount := s.incrementFailAttempts(ctx, attemptKey)
		failErr := errors.New("账号或密码错误")
		if failCount >= s.maxLoginAttempts {
			s.publishEvent(ctx, domain.EventUserLocked, map[string]interface{}{
				"user_id":  user.ID,
				"username": user.Username,
				"ip":       ip,
			})
		}
		s.writeLoginLog(ctx, &user.ID, domain.LoginTypePassword, ip, false, failErr.Error())
		s.recordLoginAnomaly(ctx, &user.ID, ip, "", deviceFingerprint, false)
		s.publishEvent(ctx, domain.EventLoginFailed, map[string]interface{}{
			"user_id":  user.ID,
			"username": user.Username,
			"ip":       ip,
		})
		return nil, failErr
	}

	if s.cache != nil {
		// A successful login resets the failure counter (best effort).
		_ = s.cache.Delete(ctx, attemptKey)
	}

	s.bestEffortUpdateLastLogin(ctx, user.ID, ip, "password")
	s.cacheUserInfo(ctx, user)
	s.writeLoginLog(ctx, &user.ID, domain.LoginTypePassword, ip, true, "")
	s.recordLoginAnomaly(ctx, &user.ID, ip, "", deviceFingerprint, true)
	s.bestEffortRegisterDevice(ctx, user.ID, req)
	s.publishEvent(ctx, domain.EventUserLogin, map[string]interface{}{
		"user_id":  user.ID,
		"username": user.Username,
		"ip":       ip,
		"method":   "password",
	})

	return s.generateLoginResponse(ctx, user, req.Remember)
}

// RefreshToken validates refreshToken, rejects it when its JTI is
// blacklisted, and issues a new token pair for the still-active user,
// carrying over the Remember flag from the old claims.
func (s *AuthService) RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) {
	if s == nil || s.jwtManager == nil || s.userRepo == nil {
		return nil, errors.New("auth service is not fully configured")
	}

	claims, err := s.jwtManager.ValidateRefreshToken(strings.TrimSpace(refreshToken))
	if err != nil {
		return nil, err
	}
	if claims == nil {
		// Defensive: never dereference claims that a validator left nil.
		return nil, errors.New("refresh token claims are empty")
	}
	if s.IsTokenBlacklisted(ctx, claims.JTI) {
		return nil, errors.New("refresh token has been revoked")
	}

	user, err := s.userRepo.GetByID(ctx, claims.UserID)
	if err != nil {
		return nil, err
	}
	if err := s.ensureUserActive(user); err != nil {
		return nil, err
	}

	return s.generateLoginResponse(ctx, user, claims.Remember)
}

// GetUserInfo returns the user's public info, served from the cache when a
// usable entry exists and otherwise loaded from the repository and cached.
func (s *AuthService) GetUserInfo(ctx context.Context, userID int64) (*UserInfo, error) {
	if s == nil || s.userRepo == nil {
		return nil, errors.New("user repository is not configured")
	}

	if s.cache != nil {
		cacheKey := userInfoCachePrefix + fmt.Sprintf("%d", userID)
		if value, ok := s.cache.Get(ctx, cacheKey); ok {
			if info, ok := userInfoFromCacheValue(value); ok {
				return info, nil
			}
		}
	}

	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return nil, err
	}
	if user == nil {
		// Never hand callers a (nil, nil) result.
		return nil, errors.New("用户不存在")
	}

	s.cacheUserInfo(ctx, user)
	return s.buildUserInfo(user), nil
}

// Logout blacklists the access and refresh tokens in req and, when username
// is not blank, publishes domain.EventUserLogout. It always returns nil;
// blacklist failures are logged because logout itself is best effort.
func (s *AuthService) Logout(ctx context.Context, username string, req *LogoutRequest) error {
	if s == nil {
		return nil
	}
	if req == nil {
		return nil
	}

	if err := s.blacklistTokenClaims(ctx, req.AccessToken, func(token string) (*auth.Claims, error) {
		if s.jwtManager == nil {
			return nil, nil
		}
		return s.jwtManager.ValidateAccessToken(token)
	}); err != nil {
		log.Printf("auth: blacklist access token failed, err=%v", err)
	}

	if err := s.blacklistTokenClaims(ctx, req.RefreshToken, func(token string) (*auth.Claims, error) {
		if s.jwtManager == nil {
			return nil, nil
		}
		return s.jwtManager.ValidateRefreshToken(token)
	}); err != nil {
		log.Printf("auth: blacklist refresh token failed, err=%v", err)
	}

	if name := strings.TrimSpace(username); name != "" {
		s.publishEvent(ctx, domain.EventUserLogout, map[string]interface{}{
			"username": name,
		})
	}
	return nil
}

// IsTokenBlacklisted reports whether a blacklist entry exists for jti. It
// returns false without a cache or for a blank jti.
func (s *AuthService) IsTokenBlacklisted(ctx context.Context, jti string) bool {
	if s == nil || s.cache == nil {
		return false
	}

	jti = strings.TrimSpace(jti)
	if jti == "" {
		return false
	}

	_, ok := s.cache.Get(ctx, tokenBlacklistPrefix+jti)
	return ok
}

// OAuthLogin returns the provider's authorization URL for the given state.
// The provider name is trimmed and lower-cased.
func (s *AuthService) OAuthLogin(ctx context.Context, provider, state string) (string, error) {
	if s == nil || s.oauthManager == nil {
		return "", errors.New("oauth manager is not configured")
	}
	return s.oauthManager.GetAuthURL(auth.OAuthProvider(strings.ToLower(strings.TrimSpace(provider))), state)
}

// OAuthCallback completes an OAuth login. It exchanges code for a token,
// loads the provider profile and then:
//   - for a known social account, refreshes the stored profile (an update
//     failure is only logged) and logs in the linked user;
//   - otherwise links the profile to the user with the same email, or creates
//     a new active user, and stores a new social account.
//
// The login is recorded and a token pair without "remember" is returned.
func (s *AuthService) OAuthCallback(ctx context.Context, provider, code string) (*LoginResponse, error) {
	if s == nil || s.oauthManager == nil || s.socialRepo == nil || s.userRepo == nil {
		return nil, errors.New("oauth login is not fully configured")
	}

	oauthProvider := auth.OAuthProvider(strings.ToLower(strings.TrimSpace(provider)))
	token, err := s.oauthManager.ExchangeCode(oauthProvider, strings.TrimSpace(code))
	if err != nil {
		return nil, err
	}

	oauthUser, err := s.oauthManager.GetUserInfo(oauthProvider, token)
	if err != nil {
		return nil, err
	}
	if oauthUser == nil {
		return nil, errors.New("oauth user info is empty")
	}
	if strings.TrimSpace(oauthUser.OpenID) == "" {
		// An empty identifier must never be used as a lookup or binding key.
		return nil, errors.New("oauth user open_id is empty")
	}

	socialAccount, err := s.socialRepo.GetByProviderAndOpenID(ctx, string(oauthProvider), oauthUser.OpenID)
	if err != nil {
		return nil, err
	}

	var user *domain.User
	if socialAccount != nil {
		user, err = s.userRepo.GetByID(ctx, socialAccount.UserID)
		if err != nil {
			return nil, err
		}

		socialAccount.UnionID = oauthUser.UnionID
		socialAccount.Nickname = oauthUser.Nickname
		socialAccount.Avatar = oauthUser.Avatar
		socialAccount.Gender = oauthUser.Gender
		socialAccount.Email = oauthUser.Email
		socialAccount.Phone = oauthUser.Phone
		socialAccount.Status = domain.SocialAccountStatusActive
		if oauthUser.Extra != nil {
			socialAccount.Extra = oauthUser.Extra
		}
		// A stale profile copy must not block the login, so only log.
		if err := s.socialRepo.Update(ctx, socialAccount); err != nil {
			log.Printf("auth: update social account failed, provider=%s open_id=%s err=%v", oauthProvider, oauthUser.OpenID, err)
		}
	} else {
		// NOTE(review): linking by the provider-reported email trusts the
		// provider to have verified that address; otherwise this allows
		// account takeover - confirm for every enabled provider.
		if email := strings.TrimSpace(oauthUser.Email); email != "" {
			user, err = s.userRepo.GetByEmail(ctx, email)
			if err != nil {
				if !isUserNotFoundError(err) {
					return nil, err
				}
				user = nil
			}
		}

		if user == nil {
			// Username base: nickname, else email local part, else
			// "<provider>_<open id>".
			baseUsername := oauthUser.Nickname
			if baseUsername == "" && oauthUser.Email != "" {
				baseUsername = strings.Split(strings.TrimSpace(oauthUser.Email), "@")[0]
			}
			if baseUsername == "" {
				baseUsername = string(oauthProvider) + "_" + oauthUser.OpenID
			}

			username, err := s.generateUniqueUsername(ctx, baseUsername)
			if err != nil {
				return nil, err
			}

			user = &domain.User{
				Username: username,
				Email:    domain.StrPtr(strings.TrimSpace(oauthUser.Email)),
				Phone:    domain.StrPtr(strings.TrimSpace(oauthUser.Phone)),
				Nickname: strings.TrimSpace(oauthUser.Nickname),
				Avatar:   strings.TrimSpace(oauthUser.Avatar),
				Status:   domain.UserStatusActive,
			}
			if user.Nickname == "" {
				user.Nickname = user.Username
			}
			if err := s.userRepo.Create(ctx, user); err != nil {
				return nil, err
			}

			s.bestEffortAssignDefaultRoles(ctx, user.ID, "oauth")
			s.publishEvent(ctx, domain.EventUserRegistered, s.buildUserInfo(user))
		}

		socialAccount = &domain.SocialAccount{
			UserID:   user.ID,
			Provider: string(oauthProvider),
			OpenID:   oauthUser.OpenID,
			UnionID:  oauthUser.UnionID,
			Nickname: oauthUser.Nickname,
			Avatar:   oauthUser.Avatar,
			Gender:   oauthUser.Gender,
			Email:    oauthUser.Email,
			Phone:    oauthUser.Phone,
			Status:   domain.SocialAccountStatusActive,
		}
		if oauthUser.Extra != nil {
			socialAccount.Extra = oauthUser.Extra
		}
		if err := s.socialRepo.Create(ctx, socialAccount); err != nil {
			return nil, err
		}
	}

	if err := s.ensureUserActive(user); err != nil {
		return nil, err
	}

	s.bestEffortUpdateLastLogin(ctx, user.ID, "", "oauth")
	s.cacheUserInfo(ctx, user)
	s.writeLoginLog(ctx, &user.ID, domain.LoginTypeOAuth, "", true, "")
	s.recordLoginAnomaly(ctx, &user.ID, "", "", "", true)
	s.publishEvent(ctx, domain.EventUserLogin, map[string]interface{}{
		"user_id":  user.ID,
		"username": user.Username,
		"method":   "oauth",
		"provider": string(oauthProvider),
	})

	return s.generateLoginResponseWithoutRemember(ctx, user)
}

// StartSocialAccountBinding starts binding a social account to an existing
// user. The user must be active and must pass verifySensitiveAction, and the
// provider must not be bound yet. It returns the provider authorization URL
// and the bind state created for it.
func (s *AuthService) StartSocialAccountBinding(
	ctx context.Context,
	userID int64,
	provider string,
	returnTo string,
	currentPassword string,
	totpCode string,
) (string, string, error) {
	if s == nil || s.oauthManager == nil || s.socialRepo == nil || s.userRepo == nil {
		return "", "", errors.New("social account binding is not fully configured")
	}

	normalizedProvider := strings.ToLower(strings.TrimSpace(provider))

	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return "", "", err
	}
	if err := s.ensureUserActive(user); err != nil {
		return "", "", err
	}
	if err := s.verifySensitiveAction(ctx, user, currentPassword, totpCode); err != nil {
		return "", "", err
	}

	accounts, err := s.GetSocialAccounts(ctx, userID)
	if err != nil {
		return "", "", err
	}
	if existing := findSocialAccountByProvider(accounts, normalizedProvider); existing != nil {
		return "", "", auth.ErrOAuthAlreadyBound
	}

	state, err := s.CreateOAuthBindState(ctx, userID, returnTo)
	if err != nil {
		return "", "", err
	}

	authURL, err := s.OAuthLogin(ctx, normalizedProvider, state)
	if err != nil {
		return "", "", err
	}
	return authURL, state, nil
}

// OAuthBindCallback completes a binding flow: it exchanges code, loads the
// provider profile and creates or refreshes the social account of the active
// user identified by userID.
func (s *AuthService) OAuthBindCallback(ctx context.Context, userID int64, provider, code string) (*domain.SocialAccountInfo, error) {
	if s == nil || s.oauthManager == nil || s.socialRepo == nil || s.userRepo == nil {
		return nil, errors.New("social account binding is not fully configured")
	}

	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return nil, err
	}
	if err := s.ensureUserActive(user); err != nil {
		return nil, err
	}

	oauthProvider := auth.OAuthProvider(strings.ToLower(strings.TrimSpace(provider)))
	token, err := s.oauthManager.ExchangeCode(oauthProvider, strings.TrimSpace(code))
	if err != nil {
		return nil, err
	}

	oauthUser, err := s.oauthManager.GetUserInfo(oauthProvider, token)
	if err != nil {
		return nil, err
	}
	if oauthUser == nil {
		return nil, errors.New("oauth user info is empty")
	}

	account, err := s.upsertOAuthSocialAccount(ctx, userID, oauthProvider, oauthUser)
	if err != nil {
		return nil, err
	}
	return account.ToInfo(), nil
}

// upsertOAuthSocialAccount binds the provider profile to userID. It fails
// when the user already has a different account of that provider bound, or
// when the profile is bound to another user (auth.ErrOAuthAlreadyBound). An
// existing binding of the same profile is refreshed; otherwise a new social
// account is created.
func (s *AuthService) upsertOAuthSocialAccount(
	ctx context.Context,
	userID int64,
	provider auth.OAuthProvider,
	oauthUser *auth.OAuthUser,
) (*domain.SocialAccount, error) {
	if s == nil || s.socialRepo == nil || s.userRepo == nil {
		return nil, errors.New("social account binding is not configured")
	}
	if oauthUser == nil {
		return nil, errors.New("oauth user info is empty")
	}

	normalizedProvider := strings.ToLower(strings.TrimSpace(string(provider)))
	openID := strings.TrimSpace(oauthUser.OpenID)
	if openID == "" {
		// An empty identifier must never be used as a binding key.
		return nil, errors.New("oauth user open_id is empty")
	}

	accounts, err := s.GetSocialAccounts(ctx, userID)
	if err != nil {
		return nil, err
	}
	if currentProviderBinding := findSocialAccountByProvider(accounts, normalizedProvider); currentProviderBinding != nil &&
		!strings.EqualFold(strings.TrimSpace(currentProviderBinding.OpenID), openID) {
		return nil, errors.New("provider already bound to current account")
	}

	existing, err := s.socialRepo.GetByProviderAndOpenID(ctx, normalizedProvider, openID)
	if err != nil {
		return nil, err
	}

	if existing != nil {
		if existing.UserID != userID {
			return nil, auth.ErrOAuthAlreadyBound
		}

		existing.UnionID = oauthUser.UnionID
		existing.Nickname = oauthUser.Nickname
		existing.Avatar = oauthUser.Avatar
		existing.Gender = oauthUser.Gender
		existing.Email = oauthUser.Email
		existing.Phone = oauthUser.Phone
		existing.Status = domain.SocialAccountStatusActive
		if oauthUser.Extra != nil {
			existing.Extra = oauthUser.Extra
		}
		if err := s.socialRepo.Update(ctx, existing); err != nil {
			return nil, err
		}
		return existing, nil
	}

	account := &domain.SocialAccount{
		UserID:   userID,
		Provider: normalizedProvider,
		OpenID:   openID,
		UnionID:  oauthUser.UnionID,
		Nickname: oauthUser.Nickname,
		Avatar:   oauthUser.Avatar,
		Gender:   oauthUser.Gender,
		Email:    oauthUser.Email,
		Phone:    oauthUser.Phone,
		Status:   domain.SocialAccountStatusActive,
	}
	if oauthUser.Extra != nil {
		account.Extra = oauthUser.Extra
	}
	if err := s.socialRepo.Create(ctx, account); err != nil {
		return nil, err
	}
	return account, nil
}

// verifySensitiveAction re-authenticates user before a sensitive operation.
// A supplied password takes precedence over a supplied TOTP / recovery code;
// at least one of them is required, and the user must have a password or
// TOTP enabled.
func (s *AuthService) verifySensitiveAction(
	ctx context.Context,
	user *domain.User,
	currentPassword string,
	totpCode string,
) error {
	if user == nil {
		return errors.New("user is required")
	}

	password := strings.TrimSpace(currentPassword)
	code := strings.TrimSpace(totpCode)
	hasPassword := strings.TrimSpace(user.Password) != ""
	hasTOTP := user.TOTPEnabled && strings.TrimSpace(user.TOTPSecret) != ""

	// A user with neither a password nor TOTP enabled may not perform
	// sensitive actions.
	if !hasPassword && !hasTOTP {
		return errors.New("请先设置密码或启用两步验证")
	}

	if password != "" {
		if !hasPassword {
			return errors.New("当前密码不正确")
		}
		// Register and Login hash and verify the password exactly as typed,
		// so try the raw value first. The trimmed value is kept as a fallback
		// for inputs that only carry accidental surrounding whitespace.
		if auth.VerifyPassword(user.Password, currentPassword) {
			return nil
		}
		if password != currentPassword && auth.VerifyPassword(user.Password, password) {
			return nil
		}
		return errors.New("当前密码不正确")
	}

	if code != "" {
		if !hasTOTP {
			return errors.New("TOTP verification is not available")
		}
		return s.verifyTOTPCodeOrRecoveryCode(ctx, user, code)
	}

	return errors.New("password or TOTP verification is required")
}

// verifyTOTPCodeOrRecoveryCode accepts either a valid TOTP code or one of the
// user's recovery codes. A matched recovery code is consumed: it is removed
// from user.TOTPRecoveryCodes and the change is persisted via UpdateTOTP.
func (s *AuthService) verifyTOTPCodeOrRecoveryCode(ctx context.Context, user *domain.User, code string) error {
	if user == nil {
		return errors.New("user is required")
	}
	if !user.TOTPEnabled || strings.TrimSpace(user.TOTPSecret) == "" {
		return errors.New("TOTP verification is not available")
	}

	manager := auth.NewTOTPManager()
	if manager.ValidateCode(user.TOTPSecret, code) {
		return nil
	}

	var hashedCodes []string
	if strings.TrimSpace(user.TOTPRecoveryCodes) != "" {
		if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes); err != nil {
			// Corrupt stored codes cannot match anything; surface the problem
			// in the log instead of hiding it.
			log.Printf("auth: decode TOTP recovery codes failed, user_id=%d err=%v", user.ID, err)
			hashedCodes = nil
		}
	}

	index, matched := auth.VerifyRecoveryCode(code, hashedCodes)
	if !matched || index < 0 || index >= len(hashedCodes) {
		return errors.New("TOTP code or recovery code is invalid")
	}
	if s == nil || s.userRepo == nil {
		// The used code must be persisted as consumed, so refuse otherwise.
		return errors.New("user repository is not configured")
	}

	// Recovery codes are single use: drop the matched one.
	remaining := make([]string, 0, len(hashedCodes)-1)
	remaining = append(remaining, hashedCodes[:index]...)
	remaining = append(remaining, hashedCodes[index+1:]...)

	payload, err := json.Marshal(remaining)
	if err != nil {
		return err
	}
	user.TOTPRecoveryCodes = string(payload)
	return s.userRepo.UpdateTOTP(ctx, user)
}

// VerifyTOTP verifies a TOTP or recovery code for userID, with a
// trusted-device bypass: when deviceID identifies a device that is trusted
// and whose trust has not expired, verification is skipped.
func (s *AuthService) VerifyTOTP(ctx context.Context, userID int64, code, deviceID string) error {
	if s == nil || s.userRepo == nil {
		return errors.New("auth service is not fully configured")
	}

	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return errors.New("用户不存在")
	}

	// Check the device trust state.
	if deviceID != "" && s.deviceService != nil {
		device, err := s.deviceService.GetDeviceByDeviceID(ctx, userID, deviceID)
		if err == nil && device != nil && device.IsTrusted {
			// A nil expiry means the trust does not expire.
			if device.TrustExpiresAt == nil || device.TrustExpiresAt.After(time.Now()) {
				return nil // trusted device: skip TOTP verification
			}
		}
	}

	// Perform the TOTP verification.
	return s.verifyTOTPCodeOrRecoveryCode(ctx, user, code)
}

// findSocialAccountByProvider returns the first non-nil account whose
// provider equals provider, ignoring case and surrounding whitespace, or nil.
func findSocialAccountByProvider(accounts []*domain.SocialAccount, provider string) *domain.SocialAccount {
	normalizedProvider := strings.ToLower(strings.TrimSpace(provider))
	for _, account := range accounts {
		if account == nil {
			continue
		}
		if strings.EqualFold(strings.TrimSpace(account.Provider), normalizedProvider) {
			return account
		}
	}
	return nil
}

// availableLoginMethodCount counts the ways user could still sign in: a
// password, email code login (email code service configured and email set),
// SMS code login (SMS code service configured and phone set), and every
// active social account whose provider is not excludeProvider.
func (s *AuthService) availableLoginMethodCount(
	user *domain.User,
	accounts []*domain.SocialAccount,
	excludeProvider string,
) int {
	if user == nil {
		return 0
	}

	count := 0
	if strings.TrimSpace(user.Password) != "" {
		count++
	}
	if s != nil && s.emailCodeSvc != nil && strings.TrimSpace(domain.DerefStr(user.Email)) != "" {
		count++
	}
	if s != nil && s.smsCodeSvc != nil && strings.TrimSpace(domain.DerefStr(user.Phone)) != "" {
		count++
	}

	normalizedExcludeProvider := strings.ToLower(strings.TrimSpace(excludeProvider))
	for _, account := range accounts {
		if account == nil || account.Status != domain.SocialAccountStatusActive {
			continue
		}
		if strings.EqualFold(strings.TrimSpace(account.Provider), normalizedExcludeProvider) {
			continue
		}
		count++
	}
	return count
}

// generateLoginResponse issues a token pair for user (using the "remember"
// variant when remember is true), caches the user info and builds the
// response.
func (s *AuthService) generateLoginResponse(ctx context.Context, user *domain.User, remember bool) (*LoginResponse, error) {
	if s == nil || s.jwtManager == nil {
		return nil, errors.New("jwt manager is not configured")
	}
	if user == nil {
		return nil, errors.New("user is required")
	}

	var (
		accessToken, refreshToken string
		err                       error
	)
	if remember {
		accessToken, refreshToken, err = s.jwtManager.GenerateTokenPairWithRemember(user.ID, user.Username, remember)
	} else {
		accessToken, refreshToken, err = s.jwtManager.GenerateTokenPair(user.ID, user.Username)
	}
	if err != nil {
		return nil, err
	}

	s.cacheUserInfo(ctx, user)
	return &LoginResponse{
		AccessToken:  accessToken,
		RefreshToken: refreshToken,
		ExpiresIn:    s.accessTokenTTLSeconds(),
		User:         s.buildUserInfo(user),
	}, nil
}

// generateLoginResponseWithoutRemember builds a login response without
// "remember login" support.
func (s *AuthService) generateLoginResponseWithoutRemember(ctx context.Context, user *domain.User) (*LoginResponse, error) {
	return s.generateLoginResponse(ctx, user, false)
}

// BindSocialAccount binds (provider, openID) to the active user userID. It is
// idempotent for an identical existing binding, fails when the user already
// has another account of that provider bound, and returns
// auth.ErrOAuthAlreadyBound when the account belongs to a different user.
func (s *AuthService) BindSocialAccount(ctx context.Context, userID int64, provider, openID string) error {
	if s == nil || s.socialRepo == nil || s.userRepo == nil {
		return errors.New("social account binding is not configured")
	}

	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return err
	}
	if err := s.ensureUserActive(user); err != nil {
		return err
	}

	normalizedProvider := strings.ToLower(strings.TrimSpace(provider))
	normalizedOpenID := strings.TrimSpace(openID)
	if normalizedProvider == "" || normalizedOpenID == "" {
		return errors.New("provider and open_id are required")
	}

	accounts, err := s.GetSocialAccounts(ctx, userID)
	if err != nil {
		return err
	}
	if existingProvider := findSocialAccountByProvider(accounts, normalizedProvider); existingProvider != nil &&
		!strings.EqualFold(strings.TrimSpace(existingProvider.OpenID), normalizedOpenID) {
		return errors.New("provider already bound to current account")
	}

	existing, err := s.socialRepo.GetByProviderAndOpenID(ctx, normalizedProvider, normalizedOpenID)
	if err != nil {
		return err
	}
	if existing != nil {
		if existing.UserID == userID {
			return nil
		}
		return auth.ErrOAuthAlreadyBound
	}

	return s.socialRepo.Create(ctx, &domain.SocialAccount{
		UserID:   userID,
		Provider: normalizedProvider,
		OpenID:   normalizedOpenID,
		Status:   domain.SocialAccountStatusActive,
	})
}

// UnbindSocialAccount removes the user's binding for provider. The user must
// be active, the binding must exist (auth.ErrOAuthNotFound otherwise), the
// caller must pass verifySensitiveAction, and at least one other login method
// must remain afterwards.
func (s *AuthService) UnbindSocialAccount(ctx context.Context, userID int64, provider, currentPassword, totpCode string) error {
	if s == nil || s.socialRepo == nil || s.userRepo == nil {
		return errors.New("social account binding is not configured")
	}

	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return err
	}
	if err := s.ensureUserActive(user); err != nil {
		return err
	}

	accounts, err := s.GetSocialAccounts(ctx, userID)
	if err != nil {
		return err
	}

	normalizedProvider := strings.ToLower(strings.TrimSpace(provider))
	if findSocialAccountByProvider(accounts, normalizedProvider) == nil {
		return auth.ErrOAuthNotFound
	}

	if err := s.verifySensitiveAction(ctx, user, currentPassword, totpCode); err != nil {
		return err
	}

	// Never let a user remove the last way to sign in.
	if s.availableLoginMethodCount(user, accounts, normalizedProvider) == 0 {
		return errors.New("at least one login method must remain after unbinding")
	}

	return s.socialRepo.DeleteByProviderAndUserID(ctx, normalizedProvider, userID)
}

// GetSocialAccounts returns the user's social accounts. The result is never
// nil: an empty slice is returned when no repository is configured or the
// user has no accounts.
func (s *AuthService) GetSocialAccounts(ctx context.Context, userID int64) ([]*domain.SocialAccount, error) {
	if s == nil || s.socialRepo == nil {
		return []*domain.SocialAccount{}, nil
	}

	accounts, err := s.socialRepo.GetByUserID(ctx, userID)
	if err != nil {
		return nil, err
	}
	if accounts == nil {
		return []*domain.SocialAccount{}, nil
	}
	return accounts, nil
}

// GetEnabledOAuthProviders returns the enabled OAuth providers. The result is
// never nil.
func (s *AuthService) GetEnabledOAuthProviders() []auth.OAuthProviderInfo {
	if s == nil || s.oauthManager == nil {
		return []auth.OAuthProviderInfo{}
	}

	providers := s.oauthManager.GetEnabledProviders()
	if providers == nil {
		return []auth.OAuthProviderInfo{}
	}
	return providers
}

// LoginByCode logs a user in with an SMS verification code. The code is
// verified for the "login" purpose before the user is looked up by phone; the
// user must exist and be active. Every outcome is written to the login log,
// and a token pair without "remember" is returned on success.
func (s *AuthService) LoginByCode(ctx context.Context, phone, code, ip string) (*LoginResponse, error) {
	if s == nil || s.smsCodeSvc == nil || s.userRepo == nil {
		return nil, errors.New("sms code login is not configured")
	}

	phone = strings.TrimSpace(phone)
	if phone == "" {
		return nil, errors.New("手机号不能为空")
	}

	if err := s.smsCodeSvc.VerifyCode(ctx, phone, "login", strings.TrimSpace(code)); err != nil {
		s.writeLoginLog(ctx, nil, domain.LoginTypeSMSCode, ip, false, err.Error())
		return nil, err
	}

	user, err := s.userRepo.GetByPhone(ctx, phone)
	if err != nil {
		if isUserNotFoundError(err) {
			s.writeLoginLog(ctx, nil, domain.LoginTypeSMSCode, ip, false, "手机号未注册")
			return nil, errors.New("手机号未注册")
		}
		s.writeLoginLog(ctx, nil, domain.LoginTypeSMSCode, ip, false, err.Error())
		return nil, err
	}

	if err := s.ensureUserActive(user); err != nil {
		s.writeLoginLog(ctx, &user.ID, domain.LoginTypeSMSCode, ip, false, err.Error())
		s.recordLoginAnomaly(ctx, &user.ID, ip, "", "", false)
		return nil, err
	}

	s.bestEffortUpdateLastLogin(ctx, user.ID, ip, "sms_code")
	s.cacheUserInfo(ctx, user)
	s.writeLoginLog(ctx, &user.ID, domain.LoginTypeSMSCode, ip, true, "")
	s.recordLoginAnomaly(ctx, &user.ID, ip, "", "", true)
	s.publishEvent(ctx, domain.EventUserLogin, map[string]interface{}{
		"user_id":  user.ID,
		"username": user.Username,
		"ip":       ip,
		"method":   "sms_code",
	})

	return s.generateLoginResponseWithoutRemember(ctx, user)
}