150 lines
4.5 KiB
Go
150 lines
4.5 KiB
Go
package auth
|
||
|
||
import (
|
||
"bytes"
|
||
"crypto/hmac"
|
||
"crypto/rand"
|
||
"crypto/sha256"
|
||
"crypto/subtle"
|
||
"encoding/base32"
|
||
"encoding/base64"
|
||
"encoding/hex"
|
||
"fmt"
|
||
"image/png"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/pquerna/otp"
|
||
"github.com/pquerna/otp/totp"
|
||
)
|
||
|
||
const (
|
||
// TOTPIssuer 应用名称(显示在 Authenticator App 中)
|
||
TOTPIssuer = "UserManagementSystem"
|
||
// TOTPPeriod TOTP 时间步长(秒)
|
||
TOTPPeriod = 30
|
||
// TOTPDigits TOTP 位数
|
||
TOTPDigits = 6
|
||
// TOTPAlgorithm TOTP 算法(使用 SHA256 更安全)
|
||
TOTPAlgorithm = otp.AlgorithmSHA256
|
||
// RecoveryCodeCount 恢复码数量
|
||
RecoveryCodeCount = 8
|
||
// RecoveryCodeLength 每个恢复码的字节长度(生成后编码为 hex 字符串)
|
||
RecoveryCodeLength = 5
|
||
)
|
||
|
||
// TOTPManager TOTP 管理器
|
||
type TOTPManager struct{}
|
||
|
||
// NewTOTPManager 创建 TOTP 管理器
|
||
func NewTOTPManager() *TOTPManager {
|
||
return &TOTPManager{}
|
||
}
|
||
|
||
// TOTPSetup TOTP 初始化结果
|
||
type TOTPSetup struct {
|
||
Secret string `json:"secret"` // Base32 密钥(用户备用)
|
||
QRCodeBase64 string `json:"qr_code_base64"` // Base64 编码的 PNG 二维码图片
|
||
RecoveryCodes []string `json:"recovery_codes"` // 一次性恢复码列表
|
||
}
|
||
|
||
// GenerateSecret 为指定用户生成 TOTP 密钥及二维码
|
||
func (m *TOTPManager) GenerateSecret(username string) (*TOTPSetup, error) {
|
||
key, err := totp.Generate(totp.GenerateOpts{
|
||
Issuer: TOTPIssuer,
|
||
AccountName: username,
|
||
Period: TOTPPeriod,
|
||
Digits: otp.DigitsSix,
|
||
Algorithm: TOTPAlgorithm,
|
||
})
|
||
if err != nil {
|
||
return nil, fmt.Errorf("generate totp key failed: %w", err)
|
||
}
|
||
|
||
// 生成二维码图片
|
||
img, err := key.Image(200, 200)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("generate qr image failed: %w", err)
|
||
}
|
||
var buf bytes.Buffer
|
||
if err := png.Encode(&buf, img); err != nil {
|
||
return nil, fmt.Errorf("encode qr image failed: %w", err)
|
||
}
|
||
qrBase64 := base64.StdEncoding.EncodeToString(buf.Bytes())
|
||
|
||
// 生成恢复码
|
||
codes, err := generateRecoveryCodes(RecoveryCodeCount)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("generate recovery codes failed: %w", err)
|
||
}
|
||
|
||
return &TOTPSetup{
|
||
Secret: key.Secret(),
|
||
QRCodeBase64: qrBase64,
|
||
RecoveryCodes: codes,
|
||
}, nil
|
||
}
|
||
|
||
// ValidateCode 验证用户输入的 TOTP 码(允许 ±1 个时间窗口的时钟偏差)
|
||
func (m *TOTPManager) ValidateCode(secret, code string) bool {
|
||
// 注意:pquerna/otp 库的 ValidateCustom 与 GenerateCode 存在算法不匹配 bug(GenerateCode 固定用 SHA1)
|
||
// 因此使用 totp.Validate() 代替,它内部正确处理算法检测
|
||
return totp.Validate(strings.TrimSpace(code), secret)
|
||
}
|
||
|
||
// GenerateCurrentCode 生成当前时间的 TOTP 码(用于测试)
|
||
func (m *TOTPManager) GenerateCurrentCode(secret string) (string, error) {
|
||
return totp.GenerateCode(secret, time.Now().UTC())
|
||
}
|
||
|
||
// ValidateRecoveryCode 验证恢复码(传入哈希后的已存储恢复码列表,返回匹配索引)
|
||
// 注意:调用方负责在验证后将该恢复码标记为已使用
|
||
// 使用恒定时间比较防止时序攻击
|
||
func ValidateRecoveryCode(inputCode string, storedCodes []string) (int, bool) {
|
||
normalized := strings.ToUpper(strings.ReplaceAll(strings.TrimSpace(inputCode), "-", ""))
|
||
for i, stored := range storedCodes {
|
||
storedNormalized := strings.ToUpper(strings.ReplaceAll(stored, "-", ""))
|
||
// 使用恒定时间比较防止时序攻击
|
||
if subtle.ConstantTimeCompare([]byte(normalized), []byte(storedNormalized)) == 1 {
|
||
return i, true
|
||
}
|
||
}
|
||
return -1, false
|
||
}
|
||
|
||
// HashRecoveryCode 使用 SHA256 哈希恢复码(用于存储)
|
||
func HashRecoveryCode(code string) (string, error) {
|
||
h := sha256.Sum256([]byte(code))
|
||
return hex.EncodeToString(h[:]), nil
|
||
}
|
||
|
||
// VerifyRecoveryCode 验证恢复码(自动哈希后比较)
|
||
func VerifyRecoveryCode(inputCode string, hashedCodes []string) (int, bool) {
|
||
hashedInput, err := HashRecoveryCode(inputCode)
|
||
if err != nil {
|
||
return -1, false
|
||
}
|
||
for i, hashed := range hashedCodes {
|
||
if hmac.Equal([]byte(hashedInput), []byte(hashed)) {
|
||
return i, true
|
||
}
|
||
}
|
||
return -1, false
|
||
}
|
||
|
||
// generateRecoveryCodes 生成 N 个随机恢复码(格式:XXXXX-XXXXX)
|
||
func generateRecoveryCodes(count int) ([]string, error) {
|
||
codes := make([]string, count)
|
||
for i := 0; i < count; i++ {
|
||
b := make([]byte, RecoveryCodeLength*2)
|
||
if _, err := rand.Read(b); err != nil {
|
||
return nil, err
|
||
}
|
||
encoded := base32.StdEncoding.EncodeToString(b)
|
||
// 格式化为 XXXXX-XXXXX
|
||
part := strings.ToUpper(encoded[:10])
|
||
codes[i] = part[:5] + "-" + part[5:]
|
||
}
|
||
return codes, nil
|
||
}
|