feat: backend core - auth, user, role, permission, device, webhook, monitoring, cache, repository, service, middleware, API handlers
This commit is contained in:
95
internal/security/encryption.go
Normal file
95
internal/security/encryption.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Encryption 加密工具
|
||||
type Encryption struct {
|
||||
key []byte
|
||||
}
|
||||
|
||||
// NewEncryption 创建加密工具(密钥长度必须是16, 24或32字节)
|
||||
func NewEncryption(key string) (*Encryption, error) {
|
||||
if len(key) != 16 && len(key) != 24 && len(key) != 32 {
|
||||
return nil, errors.New("key length must be 16, 24 or 32 bytes")
|
||||
}
|
||||
return &Encryption{key: []byte(key)}, nil
|
||||
}
|
||||
|
||||
// Encrypt 使用AES-256-GCM加密
|
||||
func (e *Encryption) Encrypt(plaintext string) (string, error) {
|
||||
block, err := aes.NewCipher(e.key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// Decrypt 使用AES-256-GCM解密
|
||||
func (e *Encryption) Decrypt(ciphertext string) (string, error) {
|
||||
data, err := base64.StdEncoding.DecodeString(ciphertext)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(e.key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(data) < nonceSize {
|
||||
return "", errors.New("ciphertext too short")
|
||||
}
|
||||
|
||||
nonce, cipherData := data[:nonceSize], data[nonceSize:]
|
||||
plaintext, err := gcm.Open(nil, nonce, cipherData, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
|
||||
// MaskEmail 邮箱脱敏
|
||||
func MaskEmail(email string) string {
|
||||
if email == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
prefix := email[:3]
|
||||
suffix := email[strings.Index(email, "@"):]
|
||||
return prefix + "***" + suffix
|
||||
}
|
||||
|
||||
// MaskPhone 手机号脱敏
|
||||
func MaskPhone(phone string) string {
|
||||
if len(phone) != 11 {
|
||||
return phone
|
||||
}
|
||||
return phone[:3] + "****" + phone[7:]
|
||||
}
|
||||
373
internal/security/ip_filter.go
Normal file
373
internal/security/ip_filter.go
Normal file
@@ -0,0 +1,373 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// IPRule IP 规则
|
||||
type IPRule struct {
|
||||
IP string // CIDR 或精确 IP
|
||||
Reason string // 封禁原因
|
||||
ExpireAt time.Time // 过期时间(零值表示永久)
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// isExpired 是否已过期
|
||||
func (r *IPRule) isExpired() bool {
|
||||
return !r.ExpireAt.IsZero() && time.Now().After(r.ExpireAt)
|
||||
}
|
||||
|
||||
// IPFilter IP 黑白名单过滤器
|
||||
type IPFilter struct {
|
||||
mu sync.RWMutex
|
||||
blacklist map[string]*IPRule // key: IP/CIDR
|
||||
whitelist map[string]*IPRule // key: IP/CIDR
|
||||
}
|
||||
|
||||
// NewIPFilter 创建 IP 过滤器
|
||||
func NewIPFilter() *IPFilter {
|
||||
return &IPFilter{
|
||||
blacklist: make(map[string]*IPRule),
|
||||
whitelist: make(map[string]*IPRule),
|
||||
}
|
||||
}
|
||||
|
||||
// AddToBlacklist 将 IP/CIDR 加入黑名单
|
||||
// duration 为 0 表示永久封禁
|
||||
func (f *IPFilter) AddToBlacklist(ip, reason string, duration time.Duration) error {
|
||||
if err := validateIPOrCIDR(ip); err != nil {
|
||||
return err
|
||||
}
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
rule := &IPRule{
|
||||
IP: ip,
|
||||
Reason: reason,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if duration > 0 {
|
||||
rule.ExpireAt = time.Now().Add(duration)
|
||||
}
|
||||
f.blacklist[ip] = rule
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveFromBlacklist 从黑名单移除
|
||||
func (f *IPFilter) RemoveFromBlacklist(ip string) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
delete(f.blacklist, ip)
|
||||
}
|
||||
|
||||
// AddToWhitelist 将 IP/CIDR 加入白名单
|
||||
func (f *IPFilter) AddToWhitelist(ip, reason string) error {
|
||||
if err := validateIPOrCIDR(ip); err != nil {
|
||||
return err
|
||||
}
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.whitelist[ip] = &IPRule{
|
||||
IP: ip,
|
||||
Reason: reason,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveFromWhitelist 从白名单移除
|
||||
func (f *IPFilter) RemoveFromWhitelist(ip string) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
delete(f.whitelist, ip)
|
||||
}
|
||||
|
||||
// IsBlocked 检查 IP 是否被封禁
|
||||
// 白名单优先:白名单中的 IP 永远不被封禁
|
||||
func (f *IPFilter) IsBlocked(ip string) (bool, string) {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
|
||||
// 白名单检查(优先)
|
||||
if f.matchesAnyRule(ip, f.whitelist) {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// 黑名单检查
|
||||
for _, rule := range f.blacklist {
|
||||
if rule.isExpired() {
|
||||
continue
|
||||
}
|
||||
if matchIP(ip, rule.IP) {
|
||||
return true, rule.Reason
|
||||
}
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// CleanExpired 清理过期规则
|
||||
func (f *IPFilter) CleanExpired() {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
for k, rule := range f.blacklist {
|
||||
if rule.isExpired() {
|
||||
delete(f.blacklist, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ListBlacklist 列出黑名单(不含已过期的)
|
||||
func (f *IPFilter) ListBlacklist() []*IPRule {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
var result []*IPRule
|
||||
for _, rule := range f.blacklist {
|
||||
if !rule.isExpired() {
|
||||
result = append(result, rule)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ListWhitelist 列出白名单
|
||||
func (f *IPFilter) ListWhitelist() []*IPRule {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
var result []*IPRule
|
||||
for _, rule := range f.whitelist {
|
||||
result = append(result, rule)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// matchesAnyRule 检查 IP 是否匹配任意规则集
|
||||
func (f *IPFilter) matchesAnyRule(ip string, rules map[string]*IPRule) bool {
|
||||
for _, rule := range rules {
|
||||
if matchIP(ip, rule.IP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// matchIP 检查 ip 是否匹配 target(精确 IP 或 CIDR)
|
||||
func matchIP(ip, target string) bool {
|
||||
if ip == target {
|
||||
return true
|
||||
}
|
||||
// 尝试 CIDR 匹配
|
||||
_, network, err := net.ParseCIDR(target)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
parsed := net.ParseIP(ip)
|
||||
return parsed != nil && network.Contains(parsed)
|
||||
}
|
||||
|
||||
// validateIPOrCIDR 验证 IP 或 CIDR 格式
|
||||
func validateIPOrCIDR(s string) error {
|
||||
if net.ParseIP(s) != nil {
|
||||
return nil
|
||||
}
|
||||
if _, _, err := net.ParseCIDR(s); err == nil {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("无效的 IP 或 CIDR 格式: %s", s)
|
||||
}
|
||||
|
||||
// ---- 异常登录检测 ----
|
||||
|
||||
// AnomalyEvent 异常登录事件类型
|
||||
type AnomalyEvent string
|
||||
|
||||
const (
|
||||
AnomalyBruteForce AnomalyEvent = "brute_force" // 暴力破解(短时间大量失败)
|
||||
AnomalyNewLocation AnomalyEvent = "new_location" // 新地区登录
|
||||
AnomalyMultipleIP AnomalyEvent = "multiple_ip" // 短时间内多个 IP 登录
|
||||
AnomalyOffHours AnomalyEvent = "off_hours" // 非工作时间登录(可配置)
|
||||
AnomalyNewDevice AnomalyEvent = "new_device" // 新设备登录
|
||||
AnomalySuspicious AnomalyEvent = "suspicious" // 可疑活动(综合判断)
|
||||
)
|
||||
|
||||
// LoginRecord 登录记录
|
||||
type LoginRecord struct {
|
||||
UserID int64
|
||||
IP string
|
||||
Location string // 登录地区
|
||||
DeviceFingerprint string // 设备指纹
|
||||
Success bool
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
// AnomalyDetector 异常登录检测器
|
||||
type AnomalyDetector struct {
|
||||
mu sync.Mutex
|
||||
records map[int64][]LoginRecord // userID -> 最近登录记录
|
||||
maxRecords int // 每用户保留的最大记录数
|
||||
windowSize time.Duration // 检测时间窗口
|
||||
maxFailures int // 窗口内最大失败次数(触发暴力破解告警)
|
||||
maxIPs int // 窗口内最大不同 IP 数(触发多 IP 告警)
|
||||
ipFilter *IPFilter // 用于自动封禁
|
||||
autoBlockDur time.Duration // 自动封禁时长
|
||||
knownLocationsLimit int // 常用地区数量阈值
|
||||
knownDevicesLimit int // 已知设备数量阈值
|
||||
}
|
||||
|
||||
// AnomalyDetectorConfig 检测器配置
|
||||
type AnomalyDetectorConfig struct {
|
||||
MaxRecordsPerUser int
|
||||
Window time.Duration
|
||||
MaxFailures int
|
||||
MaxDistinctIPs int
|
||||
AutoBlockDuration time.Duration
|
||||
// 跨区域检测配置
|
||||
KnownLocationsLimit int // 常用地区数量阈值,超过则不再告警新地区(默认 5)
|
||||
// 新设备检测配置
|
||||
KnownDevicesLimit int // 已知设备数量阈值,超过则不再告警新设备(默认 10)
|
||||
}
|
||||
|
||||
// DefaultAnomalyConfig 默认配置
|
||||
var DefaultAnomalyConfig = AnomalyDetectorConfig{
|
||||
MaxRecordsPerUser: 100,
|
||||
Window: 15 * time.Minute,
|
||||
MaxFailures: 10,
|
||||
MaxDistinctIPs: 5,
|
||||
AutoBlockDuration: 30 * time.Minute,
|
||||
KnownLocationsLimit: 5,
|
||||
KnownDevicesLimit: 10,
|
||||
}
|
||||
|
||||
// NewAnomalyDetector 创建异常登录检测器
|
||||
func NewAnomalyDetector(cfg AnomalyDetectorConfig, ipFilter *IPFilter) *AnomalyDetector {
|
||||
if cfg.KnownLocationsLimit <= 0 {
|
||||
cfg.KnownLocationsLimit = 5
|
||||
}
|
||||
if cfg.KnownDevicesLimit <= 0 {
|
||||
cfg.KnownDevicesLimit = 10
|
||||
}
|
||||
return &AnomalyDetector{
|
||||
records: make(map[int64][]LoginRecord),
|
||||
maxRecords: cfg.MaxRecordsPerUser,
|
||||
windowSize: cfg.Window,
|
||||
maxFailures: cfg.MaxFailures,
|
||||
maxIPs: cfg.MaxDistinctIPs,
|
||||
ipFilter: ipFilter,
|
||||
autoBlockDur: cfg.AutoBlockDuration,
|
||||
knownLocationsLimit: cfg.KnownLocationsLimit,
|
||||
knownDevicesLimit: cfg.KnownDevicesLimit,
|
||||
}
|
||||
}
|
||||
|
||||
// RecordLogin 记录登录事件,返回检测到的异常列表
|
||||
// location: 登录地区信息(如"广东省广州市")
|
||||
// deviceFingerprint: 设备指纹(如浏览器的UserAgent+屏幕分辨率+时区等组合hash)
|
||||
func (d *AnomalyDetector) RecordLogin(_ context.Context, userID int64, ip, location, deviceFingerprint string, success bool) []AnomalyEvent {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
record := LoginRecord{
|
||||
UserID: userID,
|
||||
IP: ip,
|
||||
Location: location,
|
||||
DeviceFingerprint: deviceFingerprint,
|
||||
Success: success,
|
||||
Timestamp: now,
|
||||
}
|
||||
|
||||
// 追加记录,保留最新的 maxRecords 条
|
||||
records := append(d.records[userID], record)
|
||||
if len(records) > d.maxRecords {
|
||||
records = records[len(records)-d.maxRecords:]
|
||||
}
|
||||
d.records[userID] = records
|
||||
|
||||
// 检测异常
|
||||
return d.detect(userID, records, now)
|
||||
}
|
||||
|
||||
// detect 在持有锁的情况下检测异常
|
||||
func (d *AnomalyDetector) detect(userID int64, records []LoginRecord, now time.Time) []AnomalyEvent {
|
||||
windowStart := now.Add(-d.windowSize)
|
||||
|
||||
var failures int
|
||||
ipSet := make(map[string]struct{})
|
||||
locationSet := make(map[string]struct{}) // 历史地区集合
|
||||
deviceSet := make(map[string]struct{}) // 历史设备集合
|
||||
var currentLocation string
|
||||
var currentDeviceFingerprint string
|
||||
|
||||
for _, r := range records {
|
||||
if r.Timestamp.Before(windowStart) {
|
||||
continue
|
||||
}
|
||||
if !r.Success {
|
||||
failures++
|
||||
}
|
||||
ipSet[r.IP] = struct{}{}
|
||||
// 记录当前登录的 location 和 deviceFingerprint(最后一个在窗口内的记录)
|
||||
currentLocation = r.Location
|
||||
currentDeviceFingerprint = r.DeviceFingerprint
|
||||
if r.Location != "" {
|
||||
locationSet[r.Location] = struct{}{}
|
||||
}
|
||||
if r.DeviceFingerprint != "" {
|
||||
deviceSet[r.DeviceFingerprint] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
var events []AnomalyEvent
|
||||
|
||||
// 暴力破解检测
|
||||
if failures >= d.maxFailures {
|
||||
events = append(events, AnomalyBruteForce)
|
||||
// 自动封禁
|
||||
if d.ipFilter != nil && len(ipSet) == 1 {
|
||||
for ip := range ipSet {
|
||||
_ = d.ipFilter.AddToBlacklist(ip,
|
||||
fmt.Sprintf("自动封禁:用户 %d 暴力破解检测", userID),
|
||||
d.autoBlockDur,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 多 IP 登录检测
|
||||
if len(ipSet) >= d.maxIPs {
|
||||
events = append(events, AnomalyMultipleIP)
|
||||
}
|
||||
|
||||
// 新地区登录检测
|
||||
// 如果当前登录地区与历史记录都不相同,且历史地区数量在阈值内,则告警
|
||||
if currentLocation != "" && len(locationSet) > 0 {
|
||||
if _, seen := locationSet[currentLocation]; !seen && len(locationSet) <= d.knownLocationsLimit {
|
||||
events = append(events, AnomalyNewLocation)
|
||||
}
|
||||
}
|
||||
|
||||
// 新设备登录检测
|
||||
// 如果当前设备指纹与历史记录都不相同,且历史设备数量在阈值内,则告警
|
||||
if currentDeviceFingerprint != "" && len(deviceSet) > 0 {
|
||||
if _, seen := deviceSet[currentDeviceFingerprint]; !seen && len(deviceSet) <= d.knownDevicesLimit {
|
||||
events = append(events, AnomalyNewDevice)
|
||||
}
|
||||
}
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
// GetRecentLogins 获取用户最近的登录记录
|
||||
func (d *AnomalyDetector) GetRecentLogins(userID int64, limit int) []LoginRecord {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
records := d.records[userID]
|
||||
if len(records) <= limit {
|
||||
return records
|
||||
}
|
||||
return records[len(records)-limit:]
|
||||
}
|
||||
234
internal/security/ip_filter_test.go
Normal file
234
internal/security/ip_filter_test.go
Normal file
@@ -0,0 +1,234 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---- IPFilter 测试 ----
|
||||
|
||||
func TestIPFilter_BlacklistBasic(t *testing.T) {
|
||||
f := NewIPFilter()
|
||||
|
||||
// 未加入黑名单时,IP 应该通过
|
||||
blocked, reason := f.IsBlocked("192.168.1.1")
|
||||
if blocked {
|
||||
t.Fatalf("未加入黑名单时不应被封禁,reason=%s", reason)
|
||||
}
|
||||
|
||||
// 加入黑名单
|
||||
if err := f.AddToBlacklist("192.168.1.1", "测试封禁", 0); err != nil {
|
||||
t.Fatalf("AddToBlacklist 失败: %v", err)
|
||||
}
|
||||
|
||||
blocked, reason = f.IsBlocked("192.168.1.1")
|
||||
if !blocked {
|
||||
t.Fatal("加入黑名单后应该被封禁")
|
||||
}
|
||||
if reason == "" {
|
||||
t.Fatal("封禁原因不应为空")
|
||||
}
|
||||
t.Logf("正确封禁,reason=%s", reason)
|
||||
}
|
||||
|
||||
func TestIPFilter_BlacklistExpiry(t *testing.T) {
|
||||
f := NewIPFilter()
|
||||
|
||||
// 加入 50ms 后过期的黑名单
|
||||
if err := f.AddToBlacklist("10.0.0.1", "临时封禁", 50*time.Millisecond); err != nil {
|
||||
t.Fatalf("AddToBlacklist 失败: %v", err)
|
||||
}
|
||||
|
||||
blocked, _ := f.IsBlocked("10.0.0.1")
|
||||
if !blocked {
|
||||
t.Fatal("封禁期间应该被拦截")
|
||||
}
|
||||
|
||||
// 等待过期
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
blocked, _ = f.IsBlocked("10.0.0.1")
|
||||
if blocked {
|
||||
t.Fatal("过期后不应该再被封禁")
|
||||
}
|
||||
t.Log("过期解封正常")
|
||||
}
|
||||
|
||||
func TestIPFilter_WhitelistOverridesBlacklist(t *testing.T) {
|
||||
f := NewIPFilter()
|
||||
|
||||
// 同时加入黑名单和白名单
|
||||
_ = f.AddToBlacklist("172.16.0.1", "黑名单", 0)
|
||||
_ = f.AddToWhitelist("172.16.0.1", "白名单优先")
|
||||
|
||||
blocked, _ := f.IsBlocked("172.16.0.1")
|
||||
if blocked {
|
||||
t.Fatal("白名单应优先于黑名单")
|
||||
}
|
||||
t.Log("白名单优先级验证通过")
|
||||
}
|
||||
|
||||
func TestIPFilter_CIDRMatch(t *testing.T) {
|
||||
f := NewIPFilter()
|
||||
|
||||
// 封禁整个 /24 段
|
||||
if err := f.AddToBlacklist("10.10.10.0/24", "封禁 C 段", 0); err != nil {
|
||||
t.Fatalf("CIDR 黑名单失败: %v", err)
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
ip string
|
||||
blocked bool
|
||||
}{
|
||||
{"10.10.10.1", true},
|
||||
{"10.10.10.254", true},
|
||||
{"10.10.11.1", false},
|
||||
{"192.168.1.1", false},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
blocked, _ := f.IsBlocked(tc.ip)
|
||||
if blocked != tc.blocked {
|
||||
t.Errorf("IP %s: 期望 blocked=%v,实际=%v", tc.ip, tc.blocked, blocked)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPFilter_InvalidIP(t *testing.T) {
|
||||
f := NewIPFilter()
|
||||
err := f.AddToBlacklist("not-an-ip", "invalid", 0)
|
||||
if err == nil {
|
||||
t.Fatal("无效 IP 应返回错误")
|
||||
}
|
||||
t.Logf("无效 IP 错误: %v", err)
|
||||
}
|
||||
|
||||
func TestIPFilter_RemoveFromBlacklist(t *testing.T) {
|
||||
f := NewIPFilter()
|
||||
_ = f.AddToBlacklist("1.2.3.4", "test", 0)
|
||||
f.RemoveFromBlacklist("1.2.3.4")
|
||||
|
||||
blocked, _ := f.IsBlocked("1.2.3.4")
|
||||
if blocked {
|
||||
t.Fatal("移除黑名单后不应被封禁")
|
||||
}
|
||||
}
|
||||
|
||||
// ---- AnomalyDetector 测试 ----
|
||||
|
||||
func TestAnomalyDetector_BruteForce(t *testing.T) {
|
||||
ipFilter := NewIPFilter()
|
||||
cfg := AnomalyDetectorConfig{
|
||||
MaxRecordsPerUser: 50,
|
||||
Window: time.Minute,
|
||||
MaxFailures: 5,
|
||||
MaxDistinctIPs: 10,
|
||||
AutoBlockDuration: time.Minute,
|
||||
}
|
||||
detector := NewAnomalyDetector(cfg, ipFilter)
|
||||
ctx := context.Background()
|
||||
|
||||
const userID = int64(42)
|
||||
const ip = "6.6.6.6"
|
||||
|
||||
// 正常失败,未达阈值
|
||||
for i := 0; i < 4; i++ {
|
||||
events := detector.RecordLogin(ctx, userID, ip, "", "", false)
|
||||
if len(events) > 0 {
|
||||
t.Fatalf("第 %d 次失败不应触发告警", i+1)
|
||||
}
|
||||
}
|
||||
|
||||
// 第 5 次失败触发暴力破解告警
|
||||
events := detector.RecordLogin(ctx, userID, ip, "", "", false)
|
||||
hasBruteForce := false
|
||||
for _, e := range events {
|
||||
if e == AnomalyBruteForce {
|
||||
hasBruteForce = true
|
||||
}
|
||||
}
|
||||
if !hasBruteForce {
|
||||
t.Fatalf("第 5 次失败应触发 brute_force 告警,实际 events=%v", events)
|
||||
}
|
||||
t.Log("暴力破解检测正常触发")
|
||||
|
||||
// 验证 IP 被自动封禁
|
||||
blocked, _ := ipFilter.IsBlocked(ip)
|
||||
if !blocked {
|
||||
t.Fatal("暴力破解后该 IP 应被自动封禁")
|
||||
}
|
||||
t.Log("IP 自动封禁验证通过")
|
||||
}
|
||||
|
||||
func TestAnomalyDetector_MultipleIPs(t *testing.T) {
|
||||
ipFilter := NewIPFilter()
|
||||
cfg := AnomalyDetectorConfig{
|
||||
MaxRecordsPerUser: 50,
|
||||
Window: time.Minute,
|
||||
MaxFailures: 100,
|
||||
MaxDistinctIPs: 3,
|
||||
AutoBlockDuration: time.Minute,
|
||||
}
|
||||
detector := NewAnomalyDetector(cfg, ipFilter)
|
||||
ctx := context.Background()
|
||||
|
||||
const userID = int64(99)
|
||||
ips := []string{"1.1.1.1", "2.2.2.2", "3.3.3.3"}
|
||||
|
||||
for _, ip := range ips {
|
||||
detector.RecordLogin(ctx, userID, ip, "", "", true)
|
||||
}
|
||||
|
||||
// 第 3 个不同 IP 时触发 multiple_ip 告警
|
||||
events := detector.RecordLogin(ctx, userID, "4.4.4.4", "", "", true)
|
||||
hasMultiIP := false
|
||||
for _, e := range events {
|
||||
if e == AnomalyMultipleIP {
|
||||
hasMultiIP = true
|
||||
}
|
||||
}
|
||||
if !hasMultiIP {
|
||||
t.Fatalf("4 个不同 IP 应触发 multiple_ip 告警,实际 events=%v", events)
|
||||
}
|
||||
t.Log("多 IP 检测正常触发")
|
||||
}
|
||||
|
||||
func TestAnomalyDetector_GetRecentLogins(t *testing.T) {
|
||||
detector := NewAnomalyDetector(DefaultAnomalyConfig, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
const userID = int64(1)
|
||||
for i := 0; i < 5; i++ {
|
||||
detector.RecordLogin(ctx, userID, "8.8.8.8", "", "", true)
|
||||
}
|
||||
|
||||
recent := detector.GetRecentLogins(userID, 3)
|
||||
if len(recent) != 3 {
|
||||
t.Fatalf("期望获取 3 条记录,实际 %d", len(recent))
|
||||
}
|
||||
}
|
||||
|
||||
// ---- 现有 ratelimit/validator/encryption 补充测试 ----
|
||||
|
||||
func TestValidateIPOrCIDR(t *testing.T) {
|
||||
cases := []struct {
|
||||
input string
|
||||
wantErr bool
|
||||
}{
|
||||
{"192.168.1.1", false},
|
||||
{"10.0.0.0/8", false},
|
||||
{"2001:db8::1", false},
|
||||
{"not-ip", true},
|
||||
{"999.999.999.999", true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
err := validateIPOrCIDR(tc.input)
|
||||
if tc.wantErr && err == nil {
|
||||
t.Errorf("输入 %q 期望出错,但没有错误", tc.input)
|
||||
}
|
||||
if !tc.wantErr && err != nil {
|
||||
t.Errorf("输入 %q 不期望出错,但得到: %v", tc.input, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
60
internal/security/password_policy.go
Normal file
60
internal/security/password_policy.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// PasswordPolicy defines the runtime password rules enforced by services.
|
||||
type PasswordPolicy struct {
|
||||
MinLength int
|
||||
RequireSpecial bool
|
||||
RequireNumber bool
|
||||
}
|
||||
|
||||
// Normalize fills in safe defaults for unset policy fields.
|
||||
func (p PasswordPolicy) Normalize() PasswordPolicy {
|
||||
if p.MinLength <= 0 {
|
||||
p.MinLength = 8
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// Validate checks whether the password satisfies the configured policy.
|
||||
func (p PasswordPolicy) Validate(password string) error {
|
||||
p = p.Normalize()
|
||||
|
||||
if utf8.RuneCountInString(password) < p.MinLength {
|
||||
return fmt.Errorf("密码长度不能少于%d位", p.MinLength)
|
||||
}
|
||||
|
||||
var hasUpper, hasLower, hasNumber, hasSpecial bool
|
||||
for _, ch := range password {
|
||||
switch {
|
||||
case unicode.IsUpper(ch):
|
||||
hasUpper = true
|
||||
case unicode.IsLower(ch):
|
||||
hasLower = true
|
||||
case unicode.IsDigit(ch):
|
||||
hasNumber = true
|
||||
case unicode.IsPunct(ch) || unicode.IsSymbol(ch):
|
||||
hasSpecial = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasUpper {
|
||||
return fmt.Errorf("密码必须包含大写字母")
|
||||
}
|
||||
if !hasLower {
|
||||
return fmt.Errorf("密码必须包含小写字母")
|
||||
}
|
||||
if p.RequireNumber && !hasNumber {
|
||||
return fmt.Errorf("密码必须包含数字")
|
||||
}
|
||||
if p.RequireSpecial && !hasSpecial {
|
||||
return fmt.Errorf("密码必须包含特殊字符")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
184
internal/security/ratelimit.go
Normal file
184
internal/security/ratelimit.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RateLimitAlgorithm 限流算法类型
|
||||
type RateLimitAlgorithm string
|
||||
|
||||
const (
|
||||
AlgorithmTokenBucket RateLimitAlgorithm = "token_bucket"
|
||||
AlgorithmLeakyBucket RateLimitAlgorithm = "leaky_bucket"
|
||||
AlgorithmSlidingWindow RateLimitAlgorithm = "sliding_window"
|
||||
AlgorithmFixedWindow RateLimitAlgorithm = "fixed_window"
|
||||
)
|
||||
|
||||
// TokenBucket 令牌桶算法
|
||||
type TokenBucket struct {
|
||||
capacity int64
|
||||
tokens int64
|
||||
rate int64 // 每秒产生的令牌数
|
||||
lastRefill time.Time
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewTokenBucket 创建令牌桶
|
||||
func NewTokenBucket(capacity, rate int64) *TokenBucket {
|
||||
return &TokenBucket{
|
||||
capacity: capacity,
|
||||
tokens: capacity,
|
||||
rate: rate,
|
||||
lastRefill: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Allow 检查是否允许访问
|
||||
func (tb *TokenBucket) Allow() bool {
|
||||
tb.mu.Lock()
|
||||
defer tb.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(tb.lastRefill).Seconds()
|
||||
|
||||
// 计算需要补充的令牌数
|
||||
refillTokens := int64(elapsed * float64(tb.rate))
|
||||
tb.tokens += refillTokens
|
||||
if tb.tokens > tb.capacity {
|
||||
tb.tokens = tb.capacity
|
||||
}
|
||||
tb.lastRefill = now
|
||||
|
||||
// 检查是否有足够的令牌
|
||||
if tb.tokens > 0 {
|
||||
tb.tokens--
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// LeakyBucket 漏桶算法
|
||||
type LeakyBucket struct {
|
||||
capacity int64
|
||||
water int64
|
||||
rate int64 // 每秒漏出的水量
|
||||
lastLeak time.Time
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewLeakyBucket 创建漏桶
|
||||
func NewLeakyBucket(capacity, rate int64) *LeakyBucket {
|
||||
return &LeakyBucket{
|
||||
capacity: capacity,
|
||||
water: 0,
|
||||
rate: rate,
|
||||
lastLeak: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Allow 检查是否允许访问
|
||||
func (lb *LeakyBucket) Allow() bool {
|
||||
lb.mu.Lock()
|
||||
defer lb.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(lb.lastLeak).Seconds()
|
||||
|
||||
// 计算漏出的水量
|
||||
leakWater := int64(elapsed * float64(lb.rate))
|
||||
lb.water -= leakWater
|
||||
if lb.water < 0 {
|
||||
lb.water = 0
|
||||
}
|
||||
lb.lastLeak = now
|
||||
|
||||
// 检查桶是否已满
|
||||
if lb.water < lb.capacity {
|
||||
lb.water++
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// SlidingWindow 滑动窗口算法
|
||||
type SlidingWindow struct {
|
||||
window time.Duration
|
||||
capacity int64
|
||||
requests []time.Time
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewSlidingWindow 创建滑动窗口
|
||||
func NewSlidingWindow(window time.Duration, capacity int64) *SlidingWindow {
|
||||
return &SlidingWindow{
|
||||
window: window,
|
||||
capacity: capacity,
|
||||
requests: make([]time.Time, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// Allow 检查是否允许访问
|
||||
func (sw *SlidingWindow) Allow() bool {
|
||||
sw.mu.Lock()
|
||||
defer sw.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// 移除窗口外的请求
|
||||
validRequests := make([]time.Time, 0)
|
||||
for _, req := range sw.requests {
|
||||
if now.Sub(req) < sw.window {
|
||||
validRequests = append(validRequests, req)
|
||||
}
|
||||
}
|
||||
sw.requests = validRequests
|
||||
|
||||
// 检查是否超过容量
|
||||
if int64(len(sw.requests)) < sw.capacity {
|
||||
sw.requests = append(sw.requests, now)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// RateLimiter 限流器
|
||||
type RateLimiter struct {
|
||||
algorithm RateLimitAlgorithm
|
||||
limiter interface{}
|
||||
}
|
||||
|
||||
// NewRateLimiter 创建限流器
|
||||
func NewRateLimiter(algorithm RateLimitAlgorithm, capacity, rate int64, window time.Duration) *RateLimiter {
|
||||
limiter := &RateLimiter{algorithm: algorithm}
|
||||
|
||||
switch algorithm {
|
||||
case AlgorithmTokenBucket:
|
||||
limiter.limiter = NewTokenBucket(capacity, rate)
|
||||
case AlgorithmLeakyBucket:
|
||||
limiter.limiter = NewLeakyBucket(capacity, rate)
|
||||
case AlgorithmSlidingWindow:
|
||||
limiter.limiter = NewSlidingWindow(window, capacity)
|
||||
default:
|
||||
limiter.limiter = NewSlidingWindow(window, capacity)
|
||||
}
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// Allow 检查是否允许访问
|
||||
func (rl *RateLimiter) Allow() bool {
|
||||
switch rl.algorithm {
|
||||
case AlgorithmTokenBucket:
|
||||
return rl.limiter.(*TokenBucket).Allow()
|
||||
case AlgorithmLeakyBucket:
|
||||
return rl.limiter.(*LeakyBucket).Allow()
|
||||
case AlgorithmSlidingWindow:
|
||||
return rl.limiter.(*SlidingWindow).Allow()
|
||||
default:
|
||||
return rl.limiter.(*SlidingWindow).Allow()
|
||||
}
|
||||
}
|
||||
185
internal/security/validator.go
Normal file
185
internal/security/validator.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Validator groups lightweight validation and sanitization helpers.
|
||||
type Validator struct {
|
||||
passwordMinLength int
|
||||
passwordRequireSpecial bool
|
||||
passwordRequireNumber bool
|
||||
}
|
||||
|
||||
// NewValidator creates a validator with the configured password rules.
|
||||
func NewValidator(minLength int, requireSpecial, requireNumber bool) *Validator {
|
||||
return &Validator{
|
||||
passwordMinLength: minLength,
|
||||
passwordRequireSpecial: requireSpecial,
|
||||
passwordRequireNumber: requireNumber,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateEmail validates email format.
|
||||
func (v *Validator) ValidateEmail(email string) bool {
|
||||
if email == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
pattern := `^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`
|
||||
matched, _ := regexp.MatchString(pattern, email)
|
||||
return matched
|
||||
}
|
||||
|
||||
// ValidatePhone validates mainland China mobile numbers.
|
||||
func (v *Validator) ValidatePhone(phone string) bool {
|
||||
if phone == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
pattern := `^1[3-9]\d{9}$`
|
||||
matched, _ := regexp.MatchString(pattern, phone)
|
||||
return matched
|
||||
}
|
||||
|
||||
// ValidateUsername validates usernames.
|
||||
func (v *Validator) ValidateUsername(username string) bool {
|
||||
if username == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
pattern := `^[a-zA-Z][a-zA-Z0-9_]{3,19}$`
|
||||
matched, _ := regexp.MatchString(pattern, username)
|
||||
return matched
|
||||
}
|
||||
|
||||
// ValidatePassword validates passwords using the shared runtime policy.
|
||||
func (v *Validator) ValidatePassword(password string) bool {
|
||||
policy := PasswordPolicy{
|
||||
MinLength: v.passwordMinLength,
|
||||
RequireSpecial: v.passwordRequireSpecial,
|
||||
RequireNumber: v.passwordRequireNumber,
|
||||
}
|
||||
|
||||
return policy.Validate(password) == nil
|
||||
}
|
||||
|
||||
// SanitizeSQL removes obviously dangerous SQL injection patterns using regex.
|
||||
// This is a defense-in-depth measure; parameterized queries should always be used.
|
||||
func (v *Validator) SanitizeSQL(input string) string {
|
||||
// Escape SQL special characters by doubling them (SQL standard approach)
|
||||
// Order matters: escape backslash first to avoid double-escaping
|
||||
replacer := strings.NewReplacer(
|
||||
`\`, `\\`,
|
||||
`'`, `''`,
|
||||
`"`, `""`,
|
||||
)
|
||||
|
||||
// Remove common SQL injection patterns that could bypass quoting
|
||||
dangerousPatterns := []string{
|
||||
`;[\s]*--`, // SQL comment
|
||||
`/\*.*?\*/`, // Block comment (non-greedy)
|
||||
`\bxp_\w+`, // Extended stored procedures
|
||||
`\bexec[\s\(]`, // EXEC statements
|
||||
`\bsp_\w+`, // System stored procedures
|
||||
`\bwaitfor[\s]+delay`, // Time-based blind SQL injection
|
||||
`\bunion[\s]+select`, // UNION injection
|
||||
`\bdrop[\s]+table`, // DROP TABLE
|
||||
`\binsert[\s]+into`, // INSERT
|
||||
`\bupdate[\s]+\w+[\s]+set`, // UPDATE
|
||||
`\bdelete[\s]+from`, // DELETE
|
||||
}
|
||||
|
||||
result := replacer.Replace(input)
|
||||
|
||||
// Apply pattern removal
|
||||
for _, pattern := range dangerousPatterns {
|
||||
re := regexp.MustCompile(`(?i)` + pattern) // Case-insensitive
|
||||
result = re.ReplaceAllString(result, "")
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// SanitizeXSS removes obviously dangerous XSS patterns using regex.
|
||||
// This is a defense-in-depth measure; output encoding should always be used.
|
||||
func (v *Validator) SanitizeXSS(input string) string {
|
||||
// Remove dangerous tags and attributes using pattern matching
|
||||
dangerousPatterns := []struct {
|
||||
pattern string
|
||||
replaceAll bool
|
||||
}{
|
||||
{`(?i)<script[^>]*>.*?</script>`, true}, // Script tags
|
||||
{`(?i)</script>`, false}, // Closing script
|
||||
{`(?i)<iframe[^>]*>.*?</iframe>`, true}, // Iframe injection
|
||||
{`(?i)<object[^>]*>.*?</object>`, true}, // Object injection
|
||||
{`(?i)<embed[^>]*>.*?</embed>`, true}, // Embed injection
|
||||
{`(?i)<applet[^>]*>.*?</applet>`, true}, // Applet injection
|
||||
{`(?i)javascript\s*:`, false}, // JavaScript protocol
|
||||
{`(?i)vbscript\s*:`, false}, // VBScript protocol
|
||||
{`(?i)data\s*:`, false}, // Data URL protocol
|
||||
{`(?i)on\w+\s*=`, false}, // Event handlers
|
||||
{`(?i)<style[^>]*>.*?</style>`, true}, // Style injection
|
||||
}
|
||||
|
||||
result := input
|
||||
|
||||
for _, p := range dangerousPatterns {
|
||||
re := regexp.MustCompile(p.pattern)
|
||||
if p.replaceAll {
|
||||
result = re.ReplaceAllString(result, "")
|
||||
} else {
|
||||
result = re.ReplaceAllString(result, "")
|
||||
}
|
||||
}
|
||||
|
||||
// Encode < and > to prevent tag construction
|
||||
result = strings.ReplaceAll(result, "<", "<")
|
||||
result = strings.ReplaceAll(result, ">", ">")
|
||||
|
||||
// Restore entities if they were part of legitimate content
|
||||
result = strings.ReplaceAll(result, "<", "<")
|
||||
result = strings.ReplaceAll(result, ">", ">")
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ValidateURL validates a basic HTTP/HTTPS URL.
|
||||
func (v *Validator) ValidateURL(url string) bool {
|
||||
if url == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
pattern := `^https?://[a-zA-Z0-9\-._~:/?#[\]@!$&'()*+,;=]+$`
|
||||
matched, _ := regexp.MatchString(pattern, url)
|
||||
return matched
|
||||
}
|
||||
|
||||
// ValidateIP validates IPv4 or IPv6 addresses using net.ParseIP.
|
||||
// Supports all valid formats including compressed IPv6 (::1, fe80::1, etc.)
|
||||
func (v *Validator) ValidateIP(ip string) bool {
|
||||
if ip == "" {
|
||||
return false
|
||||
}
|
||||
return net.ParseIP(ip) != nil
|
||||
}
|
||||
|
||||
// ValidateIPv4 validates IPv4 addresses only.
|
||||
func (v *Validator) ValidateIPv4(ip string) bool {
|
||||
if ip == "" {
|
||||
return false
|
||||
}
|
||||
parsed := net.ParseIP(ip)
|
||||
return parsed != nil && parsed.To4() != nil
|
||||
}
|
||||
|
||||
// ValidateIPv6 validates IPv6 addresses only.
|
||||
func (v *Validator) ValidateIPv6(ip string) bool {
|
||||
if ip == "" {
|
||||
return false
|
||||
}
|
||||
parsed := net.ParseIP(ip)
|
||||
return parsed != nil && parsed.To4() == nil
|
||||
}
|
||||
Reference in New Issue
Block a user