149 lines
4.2 KiB
Go
149 lines
4.2 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
|
|
"github.com/user-management-system/internal/auth"
|
|
)
|
|
|
|
// TOTPService manages 2FA setup, enable/disable, and verification.
|
|
type TOTPService struct {
|
|
userRepo userRepositoryInterface
|
|
totpManager *auth.TOTPManager
|
|
}
|
|
|
|
func NewTOTPService(userRepo userRepositoryInterface) *TOTPService {
|
|
return &TOTPService{
|
|
userRepo: userRepo,
|
|
totpManager: auth.NewTOTPManager(),
|
|
}
|
|
}
|
|
|
|
// SetupTOTPResponse is the result of SetupTOTP. It carries the plaintext
// values the user needs to enroll an authenticator; only hashes of the
// recovery codes are persisted server-side.
type SetupTOTPResponse struct {
	// Secret is the shared TOTP secret.
	Secret string `json:"secret"`
	// QRCodeBase64 is the base64-encoded QR code produced by the TOTP manager.
	QRCodeBase64 string `json:"qr_code_base64"`
	// RecoveryCodes are the single-use recovery codes, in plaintext.
	RecoveryCodes []string `json:"recovery_codes"`
}
|
|
|
|
// SetupTOTP starts 2FA enrollment for userID. It generates a new TOTP secret
// and recovery codes, persists the secret together with hashes of the
// recovery codes, and returns the plaintext secret, QR code and recovery codes
// so they can be shown to the user. 2FA is not enabled here; EnableTOTP must
// confirm a valid code first.
//
// It fails if the user cannot be loaded, if 2FA is already enabled, or if
// generating, hashing, encoding or persisting the secret fails. Calling it
// again before EnableTOTP replaces the previously stored secret and codes.
func (s *TOTPService) SetupTOTP(ctx context.Context, userID int64) (*SetupTOTPResponse, error) {
	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		// NOTE(review): the repository error is discarded, so storage failures
		// are reported to the caller as "user not found".
		return nil, fmt.Errorf("\u7528\u6237\u4e0d\u5b58\u5728")
	}
	if user.TOTPEnabled {
		return nil, errors.New("2FA \u5df2\u7ecf\u542f\u7528\uff0c\u5982\u9700\u91cd\u7f6e\u8bf7\u5148\u7981\u7528")
	}

	setup, err := s.totpManager.GenerateSecret(user.Username)
	if err != nil {
		return nil, fmt.Errorf("\u751f\u6210 TOTP \u5bc6\u94a5\u5931\u8d25: %w", err)
	}

	// Only hashes of the recovery codes are stored (SEC-03); the plaintext
	// codes are returned to the caller below. A hashing failure aborts setup
	// rather than persisting whatever value HashRecoveryCode returned
	// alongside its error.
	hashedCodes := make([]string, len(setup.RecoveryCodes))
	for i, code := range setup.RecoveryCodes {
		if hashedCodes[i], err = auth.HashRecoveryCode(code); err != nil {
			return nil, fmt.Errorf("\u54c8\u5e0c\u6062\u590d\u7801\u5931\u8d25: %w", err)
		}
	}
	codesJSON, err := json.Marshal(hashedCodes)
	if err != nil {
		return nil, fmt.Errorf("\u5e8f\u5217\u5316\u6062\u590d\u7801\u5931\u8d25: %w", err)
	}

	// Mutate the user only once every fallible step has succeeded, then
	// persist the secret and hashed recovery codes before activation.
	user.TOTPSecret = setup.Secret
	user.TOTPRecoveryCodes = string(codesJSON)
	if err := s.userRepo.UpdateTOTP(ctx, user); err != nil {
		return nil, fmt.Errorf("\u4fdd\u5b58 TOTP \u5bc6\u94a5\u5931\u8d25: %w", err)
	}

	return &SetupTOTPResponse{
		Secret:        setup.Secret,
		QRCodeBase64:  setup.QRCodeBase64,
		RecoveryCodes: setup.RecoveryCodes,
	}, nil
}
|
|
|
|
// EnableTOTP activates 2FA for userID. A secret must already have been stored
// by SetupTOTP, 2FA must not already be enabled, and code must validate
// against the stored secret. On success TOTPEnabled is set and persisted via
// UpdateTOTP, whose error (if any) is returned unchanged.
func (s *TOTPService) EnableTOTP(ctx context.Context, userID int64, code string) error {
	u, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return fmt.Errorf("\u7528\u6237\u4e0d\u5b58\u5728")
	}

	// Guards are checked in order; the code is only validated once the
	// preconditions hold.
	switch {
	case u.TOTPSecret == "":
		return errors.New("\u8bf7\u5148\u521d\u59cb\u5316 2FA\uff0c\u83b7\u53d6\u4e8c\u7ef4\u7801\u540e\u518d\u6fc0\u6d3b")
	case u.TOTPEnabled:
		return errors.New("2FA \u5df2\u542f\u7528")
	case !s.totpManager.ValidateCode(u.TOTPSecret, code):
		return errors.New("\u9a8c\u8bc1\u7801\u9519\u8bef\u6216\u5df2\u8fc7\u671f")
	}

	u.TOTPEnabled = true
	return s.userRepo.UpdateTOTP(ctx, u)
}
|
|
|
|
// DisableTOTP turns off 2FA for userID. The caller must supply either a valid
// TOTP code or one of the stored recovery codes. On success the enabled flag,
// secret and recovery codes are all cleared and persisted via UpdateTOTP,
// whose error (if any) is returned unchanged.
func (s *TOTPService) DisableTOTP(ctx context.Context, userID int64, code string) error {
	u, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return fmt.Errorf("\u7528\u6237\u4e0d\u5b58\u5728")
	}
	if !u.TOTPEnabled {
		return errors.New("2FA \u672a\u542f\u7528")
	}

	if ok := s.totpManager.ValidateCode(u.TOTPSecret, code); !ok {
		// Fall back to the recovery codes, stored as a JSON array of hashes.
		// Decoding errors are ignored (best effort).
		var recovery []string
		if raw := u.TOTPRecoveryCodes; raw != "" {
			_ = json.Unmarshal([]byte(raw), &recovery)
		}
		if _, found := auth.VerifyRecoveryCode(code, recovery); !found {
			return errors.New("\u9a8c\u8bc1\u7801\u6216\u6062\u590d\u7801\u9519\u8bef")
		}
	}

	// Clear all 2FA state.
	u.TOTPEnabled = false
	u.TOTPSecret = ""
	u.TOTPRecoveryCodes = ""
	return s.userRepo.UpdateTOTP(ctx, u)
}
|
|
|
|
// VerifyTOTP checks a second-factor code for userID. It returns nil when the
// user does not have 2FA enabled, when code is a valid TOTP code, or when code
// matches one of the user's stored recovery codes.
//
// A matched recovery code is single-use: it is removed and the remaining codes
// are persisted before success is reported. If that update fails, verification
// fails too, so a recovery code is never accepted without being consumed.
func (s *TOTPService) VerifyTOTP(ctx context.Context, userID int64, code string) error {
	user, err := s.userRepo.GetByID(ctx, userID)
	if err != nil {
		return fmt.Errorf("\u7528\u6237\u4e0d\u5b58\u5728")
	}
	if !user.TOTPEnabled {
		return nil
	}

	if s.totpManager.ValidateCode(user.TOTPSecret, code) {
		return nil
	}

	// Best effort: stored codes that fail to decode are treated as absent.
	var storedCodes []string
	if user.TOTPRecoveryCodes != "" {
		_ = json.Unmarshal([]byte(user.TOTPRecoveryCodes), &storedCodes)
	}
	// NOTE(review): SetupTOTP stores hashed recovery codes and DisableTOTP
	// checks them with auth.VerifyRecoveryCode, but this path uses
	// auth.ValidateRecoveryCode. Confirm ValidateRecoveryCode compares against
	// hashes; otherwise recovery codes would never match here.
	idx, matched := auth.ValidateRecoveryCode(code, storedCodes)
	if !matched {
		return errors.New("\u9a8c\u8bc1\u7801\u9519\u8bef\u6216\u5df2\u8fc7\u671f")
	}

	// Consume the recovery code. The write must succeed before access is
	// granted; previously a failed update was ignored, leaving the code
	// reusable.
	storedCodes = append(storedCodes[:idx], storedCodes[idx+1:]...)
	codesJSON, err := json.Marshal(storedCodes)
	if err != nil {
		return fmt.Errorf("\u5e8f\u5217\u5316\u6062\u590d\u7801\u5931\u8d25: %w", err)
	}
	user.TOTPRecoveryCodes = string(codesJSON)
	if err := s.userRepo.UpdateTOTP(ctx, user); err != nil {
		return fmt.Errorf("\u66f4\u65b0\u6062\u590d\u7801\u5931\u8d25: %w", err)
	}
	return nil
}
|
|
|
|
// GetTOTPStatus reports whether 2FA is currently enabled for userID. Any error
// from loading the user is reported as "user not found" (the underlying error
// is not wrapped).
func (s *TOTPService) GetTOTPStatus(ctx context.Context, userID int64) (bool, error) {
	u, lookupErr := s.userRepo.GetByID(ctx, userID)
	if lookupErr == nil {
		return u.TOTPEnabled, nil
	}
	return false, fmt.Errorf("\u7528\u6237\u4e0d\u5b58\u5728")
}
|