150 lines
4.5 KiB
Go
150 lines
4.5 KiB
Go
|
|
package auth
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"bytes"
|
|||
|
|
"crypto/hmac"
|
|||
|
|
"crypto/rand"
|
|||
|
|
"crypto/sha256"
|
|||
|
|
"crypto/subtle"
|
|||
|
|
"encoding/base32"
|
|||
|
|
"encoding/base64"
|
|||
|
|
"encoding/hex"
|
|||
|
|
"fmt"
|
|||
|
|
"image/png"
|
|||
|
|
"strings"
|
|||
|
|
"time"
|
|||
|
|
|
|||
|
|
"github.com/pquerna/otp"
|
|||
|
|
"github.com/pquerna/otp/totp"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
const (
|
|||
|
|
// TOTPIssuer 应用名称(显示在 Authenticator App 中)
|
|||
|
|
TOTPIssuer = "UserManagementSystem"
|
|||
|
|
// TOTPPeriod TOTP 时间步长(秒)
|
|||
|
|
TOTPPeriod = 30
|
|||
|
|
// TOTPDigits TOTP 位数
|
|||
|
|
TOTPDigits = 6
|
|||
|
|
// TOTPAlgorithm TOTP 算法(使用 SHA256 更安全)
|
|||
|
|
TOTPAlgorithm = otp.AlgorithmSHA256
|
|||
|
|
// RecoveryCodeCount 恢复码数量
|
|||
|
|
RecoveryCodeCount = 8
|
|||
|
|
// RecoveryCodeLength 每个恢复码的字节长度(生成后编码为 hex 字符串)
|
|||
|
|
RecoveryCodeLength = 5
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// TOTPManager TOTP 管理器
|
|||
|
|
type TOTPManager struct{}
|
|||
|
|
|
|||
|
|
// NewTOTPManager 创建 TOTP 管理器
|
|||
|
|
func NewTOTPManager() *TOTPManager {
|
|||
|
|
return &TOTPManager{}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// TOTPSetup TOTP 初始化结果
|
|||
|
|
type TOTPSetup struct {
|
|||
|
|
Secret string `json:"secret"` // Base32 密钥(用户备用)
|
|||
|
|
QRCodeBase64 string `json:"qr_code_base64"` // Base64 编码的 PNG 二维码图片
|
|||
|
|
RecoveryCodes []string `json:"recovery_codes"` // 一次性恢复码列表
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GenerateSecret 为指定用户生成 TOTP 密钥及二维码
|
|||
|
|
func (m *TOTPManager) GenerateSecret(username string) (*TOTPSetup, error) {
|
|||
|
|
key, err := totp.Generate(totp.GenerateOpts{
|
|||
|
|
Issuer: TOTPIssuer,
|
|||
|
|
AccountName: username,
|
|||
|
|
Period: TOTPPeriod,
|
|||
|
|
Digits: otp.DigitsSix,
|
|||
|
|
Algorithm: TOTPAlgorithm,
|
|||
|
|
})
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("generate totp key failed: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 生成二维码图片
|
|||
|
|
img, err := key.Image(200, 200)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("generate qr image failed: %w", err)
|
|||
|
|
}
|
|||
|
|
var buf bytes.Buffer
|
|||
|
|
if err := png.Encode(&buf, img); err != nil {
|
|||
|
|
return nil, fmt.Errorf("encode qr image failed: %w", err)
|
|||
|
|
}
|
|||
|
|
qrBase64 := base64.StdEncoding.EncodeToString(buf.Bytes())
|
|||
|
|
|
|||
|
|
// 生成恢复码
|
|||
|
|
codes, err := generateRecoveryCodes(RecoveryCodeCount)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("generate recovery codes failed: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return &TOTPSetup{
|
|||
|
|
Secret: key.Secret(),
|
|||
|
|
QRCodeBase64: qrBase64,
|
|||
|
|
RecoveryCodes: codes,
|
|||
|
|
}, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ValidateCode 验证用户输入的 TOTP 码(允许 ±1 个时间窗口的时钟偏差)
|
|||
|
|
func (m *TOTPManager) ValidateCode(secret, code string) bool {
|
|||
|
|
// 注意:pquerna/otp 库的 ValidateCustom 与 GenerateCode 存在算法不匹配 bug(GenerateCode 固定用 SHA1)
|
|||
|
|
// 因此使用 totp.Validate() 代替,它内部正确处理算法检测
|
|||
|
|
return totp.Validate(strings.TrimSpace(code), secret)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GenerateCurrentCode 生成当前时间的 TOTP 码(用于测试)
|
|||
|
|
func (m *TOTPManager) GenerateCurrentCode(secret string) (string, error) {
|
|||
|
|
return totp.GenerateCode(secret, time.Now().UTC())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ValidateRecoveryCode 验证恢复码(传入哈希后的已存储恢复码列表,返回匹配索引)
|
|||
|
|
// 注意:调用方负责在验证后将该恢复码标记为已使用
|
|||
|
|
// 使用恒定时间比较防止时序攻击
|
|||
|
|
func ValidateRecoveryCode(inputCode string, storedCodes []string) (int, bool) {
|
|||
|
|
normalized := strings.ToUpper(strings.ReplaceAll(strings.TrimSpace(inputCode), "-", ""))
|
|||
|
|
for i, stored := range storedCodes {
|
|||
|
|
storedNormalized := strings.ToUpper(strings.ReplaceAll(stored, "-", ""))
|
|||
|
|
// 使用恒定时间比较防止时序攻击
|
|||
|
|
if subtle.ConstantTimeCompare([]byte(normalized), []byte(storedNormalized)) == 1 {
|
|||
|
|
return i, true
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
return -1, false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// HashRecoveryCode 使用 SHA256 哈希恢复码(用于存储)
|
|||
|
|
func HashRecoveryCode(code string) (string, error) {
|
|||
|
|
h := sha256.Sum256([]byte(code))
|
|||
|
|
return hex.EncodeToString(h[:]), nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// VerifyRecoveryCode 验证恢复码(自动哈希后比较)
|
|||
|
|
func VerifyRecoveryCode(inputCode string, hashedCodes []string) (int, bool) {
|
|||
|
|
hashedInput, err := HashRecoveryCode(inputCode)
|
|||
|
|
if err != nil {
|
|||
|
|
return -1, false
|
|||
|
|
}
|
|||
|
|
for i, hashed := range hashedCodes {
|
|||
|
|
if hmac.Equal([]byte(hashedInput), []byte(hashed)) {
|
|||
|
|
return i, true
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
return -1, false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// generateRecoveryCodes 生成 N 个随机恢复码(格式:XXXXX-XXXXX)
|
|||
|
|
func generateRecoveryCodes(count int) ([]string, error) {
|
|||
|
|
codes := make([]string, count)
|
|||
|
|
for i := 0; i < count; i++ {
|
|||
|
|
b := make([]byte, RecoveryCodeLength*2)
|
|||
|
|
if _, err := rand.Read(b); err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
encoded := base32.StdEncoding.EncodeToString(b)
|
|||
|
|
// 格式化为 XXXXX-XXXXX
|
|||
|
|
part := strings.ToUpper(encoded[:10])
|
|||
|
|
codes[i] = part[:5] + "-" + part[5:]
|
|||
|
|
}
|
|||
|
|
return codes, nil
|
|||
|
|
}
|