Files
user-system/internal/service/password_reset.go

308 lines
8.8 KiB
Go
Raw Normal View History

package service
import (
	"context"
	cryptorand "crypto/rand"
	"crypto/subtle"
	"encoding/hex"
	"errors"
	"fmt"
	"log"
	"mime"
	"net/smtp"
	"strings"
	"time"

	"github.com/user-management-system/internal/auth"
	"github.com/user-management-system/internal/cache"
	"github.com/user-management-system/internal/domain"
	"github.com/user-management-system/internal/repository"
	"github.com/user-management-system/internal/security"
)
// PasswordResetConfig controls reset-token issuance, SMTP delivery of the
// reset e-mail, and the password policy applied to the new password.
type PasswordResetConfig struct {
// TokenTTL is how long an e-mail reset token or an SMS code is cached; it is
// also quoted to the user in the reset e-mail.
TokenTTL time.Duration
// SMTPHost is the mail server host. An empty value disables e-mail delivery.
SMTPHost string
// SMTPPort is the mail server port, joined with SMTPHost as "host:port".
SMTPPort int
// SMTPUser and SMTPPass enable SMTP PLAIN authentication when either one is
// non-empty.
SMTPUser string
SMTPPass string
// FromEmail is used both as the envelope sender and as the From header.
FromEmail string
// SiteURL is the base of the reset link: "<SiteURL>/reset-password?token=<token>".
SiteURL string
// The Password* fields are copied into a security.PasswordPolicy, which is
// normalized before validation. NOTE(review): how Normalize treats zero
// values is defined in the security package — confirm there.
PasswordMinLen int
PasswordRequireSpecial bool
PasswordRequireNumber bool
}
// DefaultPasswordResetConfig returns a newly allocated configuration with
// development defaults:
//   - 15-minute reset tokens;
//   - SMTP port 587 with no host, so e-mail delivery stays disabled;
//   - a localhost site URL;
//   - an 8-character minimum password length.
//
// Every field not assigned below keeps its zero value (empty SMTP
// credentials, no special-character or digit requirement).
func DefaultPasswordResetConfig() *PasswordResetConfig {
	cfg := new(PasswordResetConfig)
	cfg.TokenTTL = 15 * time.Minute
	cfg.SMTPPort = 587
	cfg.FromEmail = "noreply@example.com"
	cfg.SiteURL = "http://localhost:8080"
	cfg.PasswordMinLen = 8
	return cfg
}
// PasswordResetService implements password recovery through an e-mailed reset
// link and through an SMS verification code. Pending tokens and codes live in
// the cache under keys prefixed with "pwd_reset"; the new password hash is
// persisted through userRepo.
type PasswordResetService struct {
// userRepo looks users up (by e-mail, phone or ID) and saves the new hash.
userRepo userRepositoryInterface
// cache holds pending reset tokens, SMS codes and the user IDs they map to.
cache *cache.CacheManager
// config is never nil when the service is built with NewPasswordResetService.
config *PasswordResetConfig
// passwordHistoryRepo is optional (attached via WithPasswordHistoryRepo).
// When nil, the password-reuse check and history recording are skipped.
passwordHistoryRepo *repository.PasswordHistoryRepository
}
// NewPasswordResetService builds a PasswordResetService from its dependencies.
// A nil config is replaced by DefaultPasswordResetConfig(). The optional
// password-history repository is not set here; attach it afterwards with
// WithPasswordHistoryRepo.
func NewPasswordResetService(
	userRepo userRepositoryInterface,
	cache *cache.CacheManager,
	config *PasswordResetConfig,
) *PasswordResetService {
	effective := config
	if effective == nil {
		effective = DefaultPasswordResetConfig()
	}
	return &PasswordResetService{userRepo: userRepo, cache: cache, config: effective}
}
// WithPasswordHistoryRepo attaches the password-history repository. Password
// resets use it to reject recently used passwords and to record the new hash.
// It modifies the receiver in place and returns it, so the call can be chained
// after NewPasswordResetService.
func (s *PasswordResetService) WithPasswordHistoryRepo(repo *repository.PasswordHistoryRepository) *PasswordResetService {
	svc := s
	svc.passwordHistoryRepo = repo
	return svc
}
// ForgotPassword starts an e-mail based password reset.
//
// When GetByEmail finds a user, it does the following:
//   - generates 32 random bytes and hex-encodes them into a reset token;
//   - caches the user's ID under "pwd_reset:<token>" for config.TokenTTL;
//   - mails the link in a background goroutine.
//
// Any lookup error is reported as success (nil) so the endpoint does not
// reveal whether an address is registered. Only token generation or caching
// failures are returned.
func (s *PasswordResetService) ForgotPassword(ctx context.Context, email string) error {
	user, lookupErr := s.userRepo.GetByEmail(ctx, email)
	if lookupErr != nil {
		// Deliberately silent: do not disclose account existence.
		return nil
	}

	raw := make([]byte, 32)
	if _, err := cryptorand.Read(raw); err != nil {
		return fmt.Errorf("生成重置Token失败: %w", err)
	}
	token := hex.EncodeToString(raw)

	expiry := s.config.TokenTTL
	if err := s.cache.Set(ctx, "pwd_reset:"+token, user.ID, expiry, expiry); err != nil {
		return fmt.Errorf("缓存重置Token失败: %w", err)
	}

	// Fire-and-forget delivery; sendResetEmail only logs failures.
	go s.sendResetEmail(domain.DerefStr(user.Email), user.Username, token)
	return nil
}
// ResetPassword completes an e-mail based reset. The steps are:
//  1. Resolve token, as issued by ForgotPassword, to a cached user ID.
//  2. Load that user.
//  3. Apply newPassword through doResetPassword, which runs the policy and
//     history checks.
//  4. Delete the token so it cannot be used again.
//
// The token is deleted only after the password update succeeded. A failure to
// delete it is returned as an error even though the password has already been
// changed.
func (s *PasswordResetService) ResetPassword(ctx context.Context, token, newPassword string) error {
	if len(token) == 0 || len(newPassword) == 0 {
		return errors.New("参数不完整")
	}

	key := "pwd_reset:" + token
	cached, found := s.cache.Get(ctx, key)
	if !found {
		return errors.New("重置链接已失效或不存在,请重新申请")
	}
	id, valid := int64Value(cached)
	if !valid {
		return errors.New("重置Token数据异常")
	}

	target, lookupErr := s.userRepo.GetByID(ctx, id)
	if lookupErr != nil {
		return errors.New("用户不存在")
	}
	if resetErr := s.doResetPassword(ctx, target, newPassword); resetErr != nil {
		return resetErr
	}

	if delErr := s.cache.Delete(ctx, key); delErr != nil {
		return fmt.Errorf("清理重置Token失败: %w", delErr)
	}
	return nil
}
// ValidateResetToken reports whether a reset token is currently cached under
// "pwd_reset:<token>". That means it was issued by ForgotPassword and has not
// been consumed by ResetPassword or evicted by its TTL. An empty token is an
// error; otherwise the returned error is always nil.
func (s *PasswordResetService) ValidateResetToken(ctx context.Context, token string) (bool, error) {
	if len(token) == 0 {
		return false, errors.New("token不能为空")
	}
	key := "pwd_reset:" + token
	if _, exists := s.cache.Get(ctx, key); exists {
		return true, nil
	}
	return false, nil
}
// sendResetEmail mails the password-reset link for token to the given address.
//
// It is best-effort; ForgotPassword runs it in its own goroutine.
//   - It does nothing when SMTPHost is empty.
//   - Delivery errors are logged, not returned.
//   - The link has the form "<SiteURL>/reset-password?token=<token>".
//
// The message is hand-built plain UTF-8 text, so two precautions apply:
//   - The non-ASCII subject is emitted as an RFC 2047 encoded word, because
//     header fields may not carry raw 8-bit bytes. A MIME-Version header
//     accompanies the Content-Type.
//   - An empty recipient, or one containing CR/LF, is rejected. Otherwise the
//     address could inject extra header lines into the message.
func (s *PasswordResetService) sendResetEmail(email, username, token string) {
	if s.config.SMTPHost == "" {
		return
	}
	if email == "" || strings.ContainsAny(email, "\r\n") {
		log.Printf("password-reset-email: skipped, invalid recipient %q", email)
		return
	}

	resetURL := fmt.Sprintf("%s/reset-password?token=%s", s.config.SiteURL, token)
	subject := "密码重置请求"
	body := fmt.Sprintf(`您好 %s
您收到此邮件是因为有人请求重置账户密码
请点击以下链接重置密码链接将在 %s 后失效
%s
如果不是您本人操作请忽略此邮件您的密码不会被修改
用户管理系统团队`, username, s.config.TokenTTL.String(), resetURL)

	// PLAIN auth only when credentials were configured.
	var authInfo smtp.Auth
	if s.config.SMTPUser != "" || s.config.SMTPPass != "" {
		authInfo = smtp.PlainAuth("", s.config.SMTPUser, s.config.SMTPPass, s.config.SMTPHost)
	}

	// The body contains bare "\n" line breaks. net/smtp's DATA writer converts
	// them to CRLF on the wire.
	msg := fmt.Sprintf(
		"From: %s\r\nTo: %s\r\nSubject: %s\r\nMIME-Version: 1.0\r\nContent-Type: text/plain; charset=UTF-8\r\nContent-Transfer-Encoding: 8bit\r\n\r\n%s",
		s.config.FromEmail,
		email,
		mime.BEncoding.Encode("UTF-8", subject),
		body,
	)

	addr := fmt.Sprintf("%s:%d", s.config.SMTPHost, s.config.SMTPPort)
	if err := smtp.SendMail(addr, authInfo, s.config.FromEmail, []string{email}, []byte(msg)); err != nil {
		log.Printf("password-reset-email: send failed to=%s err=%v", email, err)
	}
}
// ForgotPasswordByPhoneRequest is the request payload for starting an SMS-based
// password reset. ForgotPasswordByPhone itself takes the phone number as a
// plain string.
type ForgotPasswordByPhoneRequest struct {
// Phone is the account's phone number. The `binding:"required"` tag presumably
// drives request validation in the HTTP layer — TODO confirm the binder used.
Phone string `json:"phone" binding:"required"`
}
// ForgotPasswordByPhone starts an SMS-based password reset (step 1 of 2).
//
// For the user registered with phone, it obtains a verification code from
// generateSMSCode, which is expected to be a 6-digit numeric code. It then
// caches two entries for config.TokenTTL:
//   - "pwd_reset_sms:<phone>" -> user ID
//   - "pwd_reset_sms_code:<phone>" -> code
//
// The code is returned to the caller; this service does not send the SMS
// itself. When no user matches, ("", nil) is returned so callers cannot
// enumerate registered numbers.
func (s *PasswordResetService) ForgotPasswordByPhone(ctx context.Context, phone string) (string, error) {
	user, err := s.userRepo.GetByPhone(ctx, phone)
	if err != nil {
		// Unknown number: stay silent to prevent user enumeration.
		return "", nil
	}

	code, err := generateSMSCode()
	if err != nil {
		return "", fmt.Errorf("生成验证码失败: %w", err)
	}

	ttl := s.config.TokenTTL
	pending := []struct {
		key   string
		value interface{}
	}{
		{key: "pwd_reset_sms:" + phone, value: user.ID},
		{key: "pwd_reset_sms_code:" + phone, value: code},
	}
	for _, p := range pending {
		if err := s.cache.Set(ctx, p.key, p.value, ttl, ttl); err != nil {
			return "", fmt.Errorf("缓存验证码失败: %w", err)
		}
	}
	return code, nil
}
// ResetPasswordByPhoneRequest is the payload for completing an SMS-based
// password reset (see ResetPasswordByPhone). All fields carry
// `binding:"required"`. ResetPasswordByPhone additionally rejects empty values
// itself.
type ResetPasswordByPhoneRequest struct {
// Phone identifies the pending reset; it is part of the cache keys.
Phone string `json:"phone" binding:"required"`
// Code is the verification code returned by ForgotPasswordByPhone.
Code string `json:"code" binding:"required"`
// NewPassword is validated against the configured password policy.
NewPassword string `json:"new_password" binding:"required"`
}
// ResetPasswordByPhone completes an SMS-based password reset (step 2 of 2).
//
// It does the following, in order:
//  1. Compares req.Code in constant time with the code cached by
//     ForgotPasswordByPhone.
//  2. Resolves the cached user ID and loads that user.
//  3. Applies req.NewPassword through doResetPassword.
//  4. Removes the cached entries so the code is single-use. A cleanup failure
//     is returned as an error, as in ResetPassword.
//
// Wrong codes are counted per phone number under
// "pwd_reset_sms_attempts:<phone>". After maxAttempts wrong submissions, the
// cached code is discarded and a new one must be requested. Without such a
// limit, a short numeric code could be brute-forced within its TTL. The
// counter is read-modify-write, so under concurrent requests it is only
// approximate.
func (s *PasswordResetService) ResetPasswordByPhone(ctx context.Context, req *ResetPasswordByPhoneRequest) error {
	// maxAttempts is the number of wrong codes tolerated per issued code.
	const maxAttempts = 5

	if req == nil || req.Phone == "" || req.Code == "" || req.NewPassword == "" {
		return errors.New("参数不完整")
	}

	codeKey := fmt.Sprintf("pwd_reset_sms_code:%s", req.Phone)
	cacheKey := fmt.Sprintf("pwd_reset_sms:%s", req.Phone)
	attemptsKey := fmt.Sprintf("pwd_reset_sms_attempts:%s", req.Phone)

	storedCode, ok := s.cache.Get(ctx, codeKey)
	if !ok {
		return errors.New("验证码已失效,请重新获取")
	}
	code, ok := storedCode.(string)
	if !ok || subtle.ConstantTimeCompare([]byte(code), []byte(req.Code)) != 1 {
		var failures int64
		prev, counted := s.cache.Get(ctx, attemptsKey)
		if counted {
			failures, _ = int64Value(prev)
		}
		failures++

		if failures >= maxAttempts {
			// Burn the code so further guesses land in the "expired" branch
			// above; the user has to request a fresh code.
			stale := []string{codeKey, cacheKey}
			if counted {
				stale = append(stale, attemptsKey)
			}
			for _, k := range stale {
				if err := s.cache.Delete(ctx, k); err != nil {
					log.Printf("password-reset-sms: cache cleanup after lockout failed: %v", err)
				}
			}
			return errors.New("验证码错误次数过多,请重新获取")
		}

		ttl := s.config.TokenTTL
		if err := s.cache.Set(ctx, attemptsKey, failures, ttl, ttl); err != nil {
			log.Printf("password-reset-sms: recording failed attempt failed: %v", err)
		}
		return errors.New("验证码不正确")
	}

	// Resolve the user the code was issued for.
	val, ok := s.cache.Get(ctx, cacheKey)
	if !ok {
		return errors.New("验证码已失效,请重新获取")
	}
	userID, ok := int64Value(val)
	if !ok {
		return errors.New("验证码数据异常")
	}
	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return errors.New("用户不存在")
	}

	if err := s.doResetPassword(ctx, user, req.NewPassword); err != nil {
		return err
	}

	// Make the code single-use: drop the code, the user mapping and any
	// failure counter. Every delete is attempted; the first error is reported.
	cleanup := []string{codeKey, cacheKey}
	if _, counted := s.cache.Get(ctx, attemptsKey); counted {
		cleanup = append(cleanup, attemptsKey)
	}
	var cleanupErr error
	for _, k := range cleanup {
		if err := s.cache.Delete(ctx, k); err != nil && cleanupErr == nil {
			cleanupErr = err
		}
	}
	if cleanupErr != nil {
		return fmt.Errorf("清理验证码失败: %w", cleanupErr)
	}
	return nil
}
// doResetPassword validates newPassword and stores its hash on user. It
// mutates user in place. The steps are:
//  1. Validate newPassword against a security.PasswordPolicy built from config
//     and normalized before use.
//  2. If a password-history repository is attached, reject newPassword when it
//     matches any of the last passwordHistoryLimit stored hashes. A failed
//     history lookup is logged and the check is skipped (fail-open), so an
//     unavailable history store does not block resets.
//  3. Hash the password, assign it to user.Password and persist it with
//     userRepo.Update.
//  4. Record the new hash and prune old history entries in a background
//     goroutine with its own 5s timeout. Failures there are logged only.
func (s *PasswordResetService) doResetPassword(ctx context.Context, user *domain.User, newPassword string) error {
	policy := security.PasswordPolicy{
		MinLength:      s.config.PasswordMinLen,
		RequireSpecial: s.config.PasswordRequireSpecial,
		RequireNumber:  s.config.PasswordRequireNumber,
	}.Normalize()
	if err := policy.Validate(newPassword); err != nil {
		return err
	}

	// Reject reuse of recent passwords.
	if s.passwordHistoryRepo != nil {
		histories, err := s.passwordHistoryRepo.GetByUserID(ctx, user.ID, passwordHistoryLimit)
		if err != nil {
			log.Printf("password-reset: loading password history for user %d failed, skipping reuse check: %v", user.ID, err)
		} else {
			for _, h := range histories {
				if auth.VerifyPassword(h.PasswordHash, newPassword) {
					return fmt.Errorf("新密码不能与最近%d次密码相同", passwordHistoryLimit)
				}
			}
		}
	}

	hashedPassword, err := auth.HashPassword(newPassword)
	if err != nil {
		return fmt.Errorf("密码加密失败: %w", err)
	}
	user.Password = hashedPassword
	if err := s.userRepo.Update(ctx, user); err != nil {
		return fmt.Errorf("更新密码失败: %w", err)
	}

	// Record the new hash in the history.
	if s.passwordHistoryRepo != nil {
		// Capture plain values so the goroutine does not read *user after return.
		historyRepo := s.passwordHistoryRepo
		userID := user.ID
		go func() {
			// Use a detached context with its own timeout: the request context
			// may already be cancelled, and the DB write must not hang forever.
			bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			if err := historyRepo.Create(bgCtx, &domain.PasswordHistory{
				UserID:       userID,
				PasswordHash: hashedPassword,
			}); err != nil {
				log.Printf("password-reset: recording password history for user %d failed: %v", userID, err)
			}
			if err := historyRepo.DeleteOldRecords(bgCtx, userID, passwordHistoryLimit); err != nil {
				log.Printf("password-reset: pruning password history for user %d failed: %v", userID, err)
			}
		}()
	}
	return nil
}