feat: backend core - auth, user, role, permission, device, webhook, monitoring, cache, repository, service, middleware, API handlers

This commit is contained in:
2026-04-02 11:19:50 +08:00
parent e59a77bc49
commit dcc1f186f8
298 changed files with 62603 additions and 0 deletions

View File

@@ -0,0 +1,95 @@
package security
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"errors"
"io"
"strings"
)
// Encryption 加密工具
type Encryption struct {
key []byte
}
// NewEncryption 创建加密工具密钥长度必须是16, 24或32字节
func NewEncryption(key string) (*Encryption, error) {
if len(key) != 16 && len(key) != 24 && len(key) != 32 {
return nil, errors.New("key length must be 16, 24 or 32 bytes")
}
return &Encryption{key: []byte(key)}, nil
}
// Encrypt 使用AES-256-GCM加密
func (e *Encryption) Encrypt(plaintext string) (string, error) {
block, err := aes.NewCipher(e.key)
if err != nil {
return "", err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", err
}
nonce := make([]byte, gcm.NonceSize())
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
return "", err
}
ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// Decrypt 使用AES-256-GCM解密
func (e *Encryption) Decrypt(ciphertext string) (string, error) {
data, err := base64.StdEncoding.DecodeString(ciphertext)
if err != nil {
return "", err
}
block, err := aes.NewCipher(e.key)
if err != nil {
return "", err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", err
}
nonceSize := gcm.NonceSize()
if len(data) < nonceSize {
return "", errors.New("ciphertext too short")
}
nonce, cipherData := data[:nonceSize], data[nonceSize:]
plaintext, err := gcm.Open(nil, nonce, cipherData, nil)
if err != nil {
return "", err
}
return string(plaintext), nil
}
// MaskEmail 邮箱脱敏
func MaskEmail(email string) string {
if email == "" {
return ""
}
prefix := email[:3]
suffix := email[strings.Index(email, "@"):]
return prefix + "***" + suffix
}
// MaskPhone 手机号脱敏
func MaskPhone(phone string) string {
if len(phone) != 11 {
return phone
}
return phone[:3] + "****" + phone[7:]
}

View File

@@ -0,0 +1,373 @@
package security
import (
"context"
"fmt"
"net"
"sync"
"time"
)
// IPRule IP 规则
type IPRule struct {
IP string // CIDR 或精确 IP
Reason string // 封禁原因
ExpireAt time.Time // 过期时间(零值表示永久)
CreatedAt time.Time
}
// isExpired 是否已过期
func (r *IPRule) isExpired() bool {
return !r.ExpireAt.IsZero() && time.Now().After(r.ExpireAt)
}
// IPFilter IP 黑白名单过滤器
type IPFilter struct {
mu sync.RWMutex
blacklist map[string]*IPRule // key: IP/CIDR
whitelist map[string]*IPRule // key: IP/CIDR
}
// NewIPFilter 创建 IP 过滤器
func NewIPFilter() *IPFilter {
return &IPFilter{
blacklist: make(map[string]*IPRule),
whitelist: make(map[string]*IPRule),
}
}
// AddToBlacklist 将 IP/CIDR 加入黑名单
// duration 为 0 表示永久封禁
func (f *IPFilter) AddToBlacklist(ip, reason string, duration time.Duration) error {
if err := validateIPOrCIDR(ip); err != nil {
return err
}
f.mu.Lock()
defer f.mu.Unlock()
rule := &IPRule{
IP: ip,
Reason: reason,
CreatedAt: time.Now(),
}
if duration > 0 {
rule.ExpireAt = time.Now().Add(duration)
}
f.blacklist[ip] = rule
return nil
}
// RemoveFromBlacklist 从黑名单移除
func (f *IPFilter) RemoveFromBlacklist(ip string) {
f.mu.Lock()
defer f.mu.Unlock()
delete(f.blacklist, ip)
}
// AddToWhitelist 将 IP/CIDR 加入白名单
func (f *IPFilter) AddToWhitelist(ip, reason string) error {
if err := validateIPOrCIDR(ip); err != nil {
return err
}
f.mu.Lock()
defer f.mu.Unlock()
f.whitelist[ip] = &IPRule{
IP: ip,
Reason: reason,
CreatedAt: time.Now(),
}
return nil
}
// RemoveFromWhitelist 从白名单移除
func (f *IPFilter) RemoveFromWhitelist(ip string) {
f.mu.Lock()
defer f.mu.Unlock()
delete(f.whitelist, ip)
}
// IsBlocked 检查 IP 是否被封禁
// 白名单优先:白名单中的 IP 永远不被封禁
func (f *IPFilter) IsBlocked(ip string) (bool, string) {
f.mu.RLock()
defer f.mu.RUnlock()
// 白名单检查(优先)
if f.matchesAnyRule(ip, f.whitelist) {
return false, ""
}
// 黑名单检查
for _, rule := range f.blacklist {
if rule.isExpired() {
continue
}
if matchIP(ip, rule.IP) {
return true, rule.Reason
}
}
return false, ""
}
// CleanExpired 清理过期规则
func (f *IPFilter) CleanExpired() {
f.mu.Lock()
defer f.mu.Unlock()
for k, rule := range f.blacklist {
if rule.isExpired() {
delete(f.blacklist, k)
}
}
}
// ListBlacklist 列出黑名单(不含已过期的)
func (f *IPFilter) ListBlacklist() []*IPRule {
f.mu.RLock()
defer f.mu.RUnlock()
var result []*IPRule
for _, rule := range f.blacklist {
if !rule.isExpired() {
result = append(result, rule)
}
}
return result
}
// ListWhitelist 列出白名单
func (f *IPFilter) ListWhitelist() []*IPRule {
f.mu.RLock()
defer f.mu.RUnlock()
var result []*IPRule
for _, rule := range f.whitelist {
result = append(result, rule)
}
return result
}
// matchesAnyRule 检查 IP 是否匹配任意规则集
func (f *IPFilter) matchesAnyRule(ip string, rules map[string]*IPRule) bool {
for _, rule := range rules {
if matchIP(ip, rule.IP) {
return true
}
}
return false
}
// matchIP 检查 ip 是否匹配 target精确 IP 或 CIDR
func matchIP(ip, target string) bool {
if ip == target {
return true
}
// 尝试 CIDR 匹配
_, network, err := net.ParseCIDR(target)
if err != nil {
return false
}
parsed := net.ParseIP(ip)
return parsed != nil && network.Contains(parsed)
}
// validateIPOrCIDR 验证 IP 或 CIDR 格式
func validateIPOrCIDR(s string) error {
if net.ParseIP(s) != nil {
return nil
}
if _, _, err := net.ParseCIDR(s); err == nil {
return nil
}
return fmt.Errorf("无效的 IP 或 CIDR 格式: %s", s)
}
// ---- 异常登录检测 ----
// AnomalyEvent 异常登录事件类型
type AnomalyEvent string
const (
AnomalyBruteForce AnomalyEvent = "brute_force" // 暴力破解(短时间大量失败)
AnomalyNewLocation AnomalyEvent = "new_location" // 新地区登录
AnomalyMultipleIP AnomalyEvent = "multiple_ip" // 短时间内多个 IP 登录
AnomalyOffHours AnomalyEvent = "off_hours" // 非工作时间登录(可配置)
AnomalyNewDevice AnomalyEvent = "new_device" // 新设备登录
AnomalySuspicious AnomalyEvent = "suspicious" // 可疑活动(综合判断)
)
// LoginRecord 登录记录
type LoginRecord struct {
UserID int64
IP string
Location string // 登录地区
DeviceFingerprint string // 设备指纹
Success bool
Timestamp time.Time
}
// AnomalyDetector 异常登录检测器
type AnomalyDetector struct {
mu sync.Mutex
records map[int64][]LoginRecord // userID -> 最近登录记录
maxRecords int // 每用户保留的最大记录数
windowSize time.Duration // 检测时间窗口
maxFailures int // 窗口内最大失败次数(触发暴力破解告警)
maxIPs int // 窗口内最大不同 IP 数(触发多 IP 告警)
ipFilter *IPFilter // 用于自动封禁
autoBlockDur time.Duration // 自动封禁时长
knownLocationsLimit int // 常用地区数量阈值
knownDevicesLimit int // 已知设备数量阈值
}
// AnomalyDetectorConfig 检测器配置
type AnomalyDetectorConfig struct {
MaxRecordsPerUser int
Window time.Duration
MaxFailures int
MaxDistinctIPs int
AutoBlockDuration time.Duration
// 跨区域检测配置
KnownLocationsLimit int // 常用地区数量阈值,超过则不再告警新地区(默认 5
// 新设备检测配置
KnownDevicesLimit int // 已知设备数量阈值,超过则不再告警新设备(默认 10
}
// DefaultAnomalyConfig 默认配置
var DefaultAnomalyConfig = AnomalyDetectorConfig{
MaxRecordsPerUser: 100,
Window: 15 * time.Minute,
MaxFailures: 10,
MaxDistinctIPs: 5,
AutoBlockDuration: 30 * time.Minute,
KnownLocationsLimit: 5,
KnownDevicesLimit: 10,
}
// NewAnomalyDetector 创建异常登录检测器
func NewAnomalyDetector(cfg AnomalyDetectorConfig, ipFilter *IPFilter) *AnomalyDetector {
if cfg.KnownLocationsLimit <= 0 {
cfg.KnownLocationsLimit = 5
}
if cfg.KnownDevicesLimit <= 0 {
cfg.KnownDevicesLimit = 10
}
return &AnomalyDetector{
records: make(map[int64][]LoginRecord),
maxRecords: cfg.MaxRecordsPerUser,
windowSize: cfg.Window,
maxFailures: cfg.MaxFailures,
maxIPs: cfg.MaxDistinctIPs,
ipFilter: ipFilter,
autoBlockDur: cfg.AutoBlockDuration,
knownLocationsLimit: cfg.KnownLocationsLimit,
knownDevicesLimit: cfg.KnownDevicesLimit,
}
}
// RecordLogin 记录登录事件,返回检测到的异常列表
// location: 登录地区信息(如"广东省广州市"
// deviceFingerprint: 设备指纹如浏览器的UserAgent+屏幕分辨率+时区等组合hash
func (d *AnomalyDetector) RecordLogin(_ context.Context, userID int64, ip, location, deviceFingerprint string, success bool) []AnomalyEvent {
d.mu.Lock()
defer d.mu.Unlock()
now := time.Now()
record := LoginRecord{
UserID: userID,
IP: ip,
Location: location,
DeviceFingerprint: deviceFingerprint,
Success: success,
Timestamp: now,
}
// 追加记录,保留最新的 maxRecords 条
records := append(d.records[userID], record)
if len(records) > d.maxRecords {
records = records[len(records)-d.maxRecords:]
}
d.records[userID] = records
// 检测异常
return d.detect(userID, records, now)
}
// detect 在持有锁的情况下检测异常
func (d *AnomalyDetector) detect(userID int64, records []LoginRecord, now time.Time) []AnomalyEvent {
windowStart := now.Add(-d.windowSize)
var failures int
ipSet := make(map[string]struct{})
locationSet := make(map[string]struct{}) // 历史地区集合
deviceSet := make(map[string]struct{}) // 历史设备集合
var currentLocation string
var currentDeviceFingerprint string
for _, r := range records {
if r.Timestamp.Before(windowStart) {
continue
}
if !r.Success {
failures++
}
ipSet[r.IP] = struct{}{}
// 记录当前登录的 location 和 deviceFingerprint最后一个在窗口内的记录
currentLocation = r.Location
currentDeviceFingerprint = r.DeviceFingerprint
if r.Location != "" {
locationSet[r.Location] = struct{}{}
}
if r.DeviceFingerprint != "" {
deviceSet[r.DeviceFingerprint] = struct{}{}
}
}
var events []AnomalyEvent
// 暴力破解检测
if failures >= d.maxFailures {
events = append(events, AnomalyBruteForce)
// 自动封禁
if d.ipFilter != nil && len(ipSet) == 1 {
for ip := range ipSet {
_ = d.ipFilter.AddToBlacklist(ip,
fmt.Sprintf("自动封禁:用户 %d 暴力破解检测", userID),
d.autoBlockDur,
)
}
}
}
// 多 IP 登录检测
if len(ipSet) >= d.maxIPs {
events = append(events, AnomalyMultipleIP)
}
// 新地区登录检测
// 如果当前登录地区与历史记录都不相同,且历史地区数量在阈值内,则告警
if currentLocation != "" && len(locationSet) > 0 {
if _, seen := locationSet[currentLocation]; !seen && len(locationSet) <= d.knownLocationsLimit {
events = append(events, AnomalyNewLocation)
}
}
// 新设备登录检测
// 如果当前设备指纹与历史记录都不相同,且历史设备数量在阈值内,则告警
if currentDeviceFingerprint != "" && len(deviceSet) > 0 {
if _, seen := deviceSet[currentDeviceFingerprint]; !seen && len(deviceSet) <= d.knownDevicesLimit {
events = append(events, AnomalyNewDevice)
}
}
return events
}
// GetRecentLogins 获取用户最近的登录记录
func (d *AnomalyDetector) GetRecentLogins(userID int64, limit int) []LoginRecord {
d.mu.Lock()
defer d.mu.Unlock()
records := d.records[userID]
if len(records) <= limit {
return records
}
return records[len(records)-limit:]
}

View File

@@ -0,0 +1,234 @@
package security
import (
"context"
"testing"
"time"
)
// ---- IPFilter 测试 ----
func TestIPFilter_BlacklistBasic(t *testing.T) {
f := NewIPFilter()
// 未加入黑名单时IP 应该通过
blocked, reason := f.IsBlocked("192.168.1.1")
if blocked {
t.Fatalf("未加入黑名单时不应被封禁reason=%s", reason)
}
// 加入黑名单
if err := f.AddToBlacklist("192.168.1.1", "测试封禁", 0); err != nil {
t.Fatalf("AddToBlacklist 失败: %v", err)
}
blocked, reason = f.IsBlocked("192.168.1.1")
if !blocked {
t.Fatal("加入黑名单后应该被封禁")
}
if reason == "" {
t.Fatal("封禁原因不应为空")
}
t.Logf("正确封禁reason=%s", reason)
}
func TestIPFilter_BlacklistExpiry(t *testing.T) {
f := NewIPFilter()
// 加入 50ms 后过期的黑名单
if err := f.AddToBlacklist("10.0.0.1", "临时封禁", 50*time.Millisecond); err != nil {
t.Fatalf("AddToBlacklist 失败: %v", err)
}
blocked, _ := f.IsBlocked("10.0.0.1")
if !blocked {
t.Fatal("封禁期间应该被拦截")
}
// 等待过期
time.Sleep(100 * time.Millisecond)
blocked, _ = f.IsBlocked("10.0.0.1")
if blocked {
t.Fatal("过期后不应该再被封禁")
}
t.Log("过期解封正常")
}
func TestIPFilter_WhitelistOverridesBlacklist(t *testing.T) {
f := NewIPFilter()
// 同时加入黑名单和白名单
_ = f.AddToBlacklist("172.16.0.1", "黑名单", 0)
_ = f.AddToWhitelist("172.16.0.1", "白名单优先")
blocked, _ := f.IsBlocked("172.16.0.1")
if blocked {
t.Fatal("白名单应优先于黑名单")
}
t.Log("白名单优先级验证通过")
}
func TestIPFilter_CIDRMatch(t *testing.T) {
f := NewIPFilter()
// 封禁整个 /24 段
if err := f.AddToBlacklist("10.10.10.0/24", "封禁 C 段", 0); err != nil {
t.Fatalf("CIDR 黑名单失败: %v", err)
}
cases := []struct {
ip string
blocked bool
}{
{"10.10.10.1", true},
{"10.10.10.254", true},
{"10.10.11.1", false},
{"192.168.1.1", false},
}
for _, tc := range cases {
blocked, _ := f.IsBlocked(tc.ip)
if blocked != tc.blocked {
t.Errorf("IP %s: 期望 blocked=%v实际=%v", tc.ip, tc.blocked, blocked)
}
}
}
func TestIPFilter_InvalidIP(t *testing.T) {
f := NewIPFilter()
err := f.AddToBlacklist("not-an-ip", "invalid", 0)
if err == nil {
t.Fatal("无效 IP 应返回错误")
}
t.Logf("无效 IP 错误: %v", err)
}
func TestIPFilter_RemoveFromBlacklist(t *testing.T) {
f := NewIPFilter()
_ = f.AddToBlacklist("1.2.3.4", "test", 0)
f.RemoveFromBlacklist("1.2.3.4")
blocked, _ := f.IsBlocked("1.2.3.4")
if blocked {
t.Fatal("移除黑名单后不应被封禁")
}
}
// ---- AnomalyDetector 测试 ----
func TestAnomalyDetector_BruteForce(t *testing.T) {
ipFilter := NewIPFilter()
cfg := AnomalyDetectorConfig{
MaxRecordsPerUser: 50,
Window: time.Minute,
MaxFailures: 5,
MaxDistinctIPs: 10,
AutoBlockDuration: time.Minute,
}
detector := NewAnomalyDetector(cfg, ipFilter)
ctx := context.Background()
const userID = int64(42)
const ip = "6.6.6.6"
// 正常失败,未达阈值
for i := 0; i < 4; i++ {
events := detector.RecordLogin(ctx, userID, ip, "", "", false)
if len(events) > 0 {
t.Fatalf("第 %d 次失败不应触发告警", i+1)
}
}
// 第 5 次失败触发暴力破解告警
events := detector.RecordLogin(ctx, userID, ip, "", "", false)
hasBruteForce := false
for _, e := range events {
if e == AnomalyBruteForce {
hasBruteForce = true
}
}
if !hasBruteForce {
t.Fatalf("第 5 次失败应触发 brute_force 告警,实际 events=%v", events)
}
t.Log("暴力破解检测正常触发")
// 验证 IP 被自动封禁
blocked, _ := ipFilter.IsBlocked(ip)
if !blocked {
t.Fatal("暴力破解后该 IP 应被自动封禁")
}
t.Log("IP 自动封禁验证通过")
}
func TestAnomalyDetector_MultipleIPs(t *testing.T) {
ipFilter := NewIPFilter()
cfg := AnomalyDetectorConfig{
MaxRecordsPerUser: 50,
Window: time.Minute,
MaxFailures: 100,
MaxDistinctIPs: 3,
AutoBlockDuration: time.Minute,
}
detector := NewAnomalyDetector(cfg, ipFilter)
ctx := context.Background()
const userID = int64(99)
ips := []string{"1.1.1.1", "2.2.2.2", "3.3.3.3"}
for _, ip := range ips {
detector.RecordLogin(ctx, userID, ip, "", "", true)
}
// 第 3 个不同 IP 时触发 multiple_ip 告警
events := detector.RecordLogin(ctx, userID, "4.4.4.4", "", "", true)
hasMultiIP := false
for _, e := range events {
if e == AnomalyMultipleIP {
hasMultiIP = true
}
}
if !hasMultiIP {
t.Fatalf("4 个不同 IP 应触发 multiple_ip 告警,实际 events=%v", events)
}
t.Log("多 IP 检测正常触发")
}
func TestAnomalyDetector_GetRecentLogins(t *testing.T) {
detector := NewAnomalyDetector(DefaultAnomalyConfig, nil)
ctx := context.Background()
const userID = int64(1)
for i := 0; i < 5; i++ {
detector.RecordLogin(ctx, userID, "8.8.8.8", "", "", true)
}
recent := detector.GetRecentLogins(userID, 3)
if len(recent) != 3 {
t.Fatalf("期望获取 3 条记录,实际 %d", len(recent))
}
}
// ---- 现有 ratelimit/validator/encryption 补充测试 ----
func TestValidateIPOrCIDR(t *testing.T) {
cases := []struct {
input string
wantErr bool
}{
{"192.168.1.1", false},
{"10.0.0.0/8", false},
{"2001:db8::1", false},
{"not-ip", true},
{"999.999.999.999", true},
}
for _, tc := range cases {
err := validateIPOrCIDR(tc.input)
if tc.wantErr && err == nil {
t.Errorf("输入 %q 期望出错,但没有错误", tc.input)
}
if !tc.wantErr && err != nil {
t.Errorf("输入 %q 不期望出错,但得到: %v", tc.input, err)
}
}
}

View File

@@ -0,0 +1,60 @@
package security
import (
"fmt"
"unicode"
"unicode/utf8"
)
// PasswordPolicy defines the runtime password rules enforced by services.
type PasswordPolicy struct {
MinLength int
RequireSpecial bool
RequireNumber bool
}
// Normalize fills in safe defaults for unset policy fields.
func (p PasswordPolicy) Normalize() PasswordPolicy {
if p.MinLength <= 0 {
p.MinLength = 8
}
return p
}
// Validate checks whether the password satisfies the configured policy.
func (p PasswordPolicy) Validate(password string) error {
p = p.Normalize()
if utf8.RuneCountInString(password) < p.MinLength {
return fmt.Errorf("密码长度不能少于%d位", p.MinLength)
}
var hasUpper, hasLower, hasNumber, hasSpecial bool
for _, ch := range password {
switch {
case unicode.IsUpper(ch):
hasUpper = true
case unicode.IsLower(ch):
hasLower = true
case unicode.IsDigit(ch):
hasNumber = true
case unicode.IsPunct(ch) || unicode.IsSymbol(ch):
hasSpecial = true
}
}
if !hasUpper {
return fmt.Errorf("密码必须包含大写字母")
}
if !hasLower {
return fmt.Errorf("密码必须包含小写字母")
}
if p.RequireNumber && !hasNumber {
return fmt.Errorf("密码必须包含数字")
}
if p.RequireSpecial && !hasSpecial {
return fmt.Errorf("密码必须包含特殊字符")
}
return nil
}

View File

@@ -0,0 +1,184 @@
package security
import (
"sync"
"time"
)
// RateLimitAlgorithm 限流算法类型
type RateLimitAlgorithm string
const (
AlgorithmTokenBucket RateLimitAlgorithm = "token_bucket"
AlgorithmLeakyBucket RateLimitAlgorithm = "leaky_bucket"
AlgorithmSlidingWindow RateLimitAlgorithm = "sliding_window"
AlgorithmFixedWindow RateLimitAlgorithm = "fixed_window"
)
// TokenBucket 令牌桶算法
type TokenBucket struct {
capacity int64
tokens int64
rate int64 // 每秒产生的令牌数
lastRefill time.Time
mu sync.Mutex
}
// NewTokenBucket 创建令牌桶
func NewTokenBucket(capacity, rate int64) *TokenBucket {
return &TokenBucket{
capacity: capacity,
tokens: capacity,
rate: rate,
lastRefill: time.Now(),
}
}
// Allow 检查是否允许访问
func (tb *TokenBucket) Allow() bool {
tb.mu.Lock()
defer tb.mu.Unlock()
now := time.Now()
elapsed := now.Sub(tb.lastRefill).Seconds()
// 计算需要补充的令牌数
refillTokens := int64(elapsed * float64(tb.rate))
tb.tokens += refillTokens
if tb.tokens > tb.capacity {
tb.tokens = tb.capacity
}
tb.lastRefill = now
// 检查是否有足够的令牌
if tb.tokens > 0 {
tb.tokens--
return true
}
return false
}
// LeakyBucket 漏桶算法
type LeakyBucket struct {
capacity int64
water int64
rate int64 // 每秒漏出的水量
lastLeak time.Time
mu sync.Mutex
}
// NewLeakyBucket 创建漏桶
func NewLeakyBucket(capacity, rate int64) *LeakyBucket {
return &LeakyBucket{
capacity: capacity,
water: 0,
rate: rate,
lastLeak: time.Now(),
}
}
// Allow 检查是否允许访问
func (lb *LeakyBucket) Allow() bool {
lb.mu.Lock()
defer lb.mu.Unlock()
now := time.Now()
elapsed := now.Sub(lb.lastLeak).Seconds()
// 计算漏出的水量
leakWater := int64(elapsed * float64(lb.rate))
lb.water -= leakWater
if lb.water < 0 {
lb.water = 0
}
lb.lastLeak = now
// 检查桶是否已满
if lb.water < lb.capacity {
lb.water++
return true
}
return false
}
// SlidingWindow 滑动窗口算法
type SlidingWindow struct {
window time.Duration
capacity int64
requests []time.Time
mu sync.Mutex
}
// NewSlidingWindow 创建滑动窗口
func NewSlidingWindow(window time.Duration, capacity int64) *SlidingWindow {
return &SlidingWindow{
window: window,
capacity: capacity,
requests: make([]time.Time, 0),
}
}
// Allow 检查是否允许访问
func (sw *SlidingWindow) Allow() bool {
sw.mu.Lock()
defer sw.mu.Unlock()
now := time.Now()
// 移除窗口外的请求
validRequests := make([]time.Time, 0)
for _, req := range sw.requests {
if now.Sub(req) < sw.window {
validRequests = append(validRequests, req)
}
}
sw.requests = validRequests
// 检查是否超过容量
if int64(len(sw.requests)) < sw.capacity {
sw.requests = append(sw.requests, now)
return true
}
return false
}
// RateLimiter 限流器
type RateLimiter struct {
algorithm RateLimitAlgorithm
limiter interface{}
}
// NewRateLimiter 创建限流器
func NewRateLimiter(algorithm RateLimitAlgorithm, capacity, rate int64, window time.Duration) *RateLimiter {
limiter := &RateLimiter{algorithm: algorithm}
switch algorithm {
case AlgorithmTokenBucket:
limiter.limiter = NewTokenBucket(capacity, rate)
case AlgorithmLeakyBucket:
limiter.limiter = NewLeakyBucket(capacity, rate)
case AlgorithmSlidingWindow:
limiter.limiter = NewSlidingWindow(window, capacity)
default:
limiter.limiter = NewSlidingWindow(window, capacity)
}
return limiter
}
// Allow 检查是否允许访问
func (rl *RateLimiter) Allow() bool {
switch rl.algorithm {
case AlgorithmTokenBucket:
return rl.limiter.(*TokenBucket).Allow()
case AlgorithmLeakyBucket:
return rl.limiter.(*LeakyBucket).Allow()
case AlgorithmSlidingWindow:
return rl.limiter.(*SlidingWindow).Allow()
default:
return rl.limiter.(*SlidingWindow).Allow()
}
}

View File

@@ -0,0 +1,185 @@
package security
import (
"net"
"regexp"
"strings"
)
// Validator groups lightweight validation and sanitization helpers.
type Validator struct {
passwordMinLength int
passwordRequireSpecial bool
passwordRequireNumber bool
}
// NewValidator creates a validator with the configured password rules.
func NewValidator(minLength int, requireSpecial, requireNumber bool) *Validator {
return &Validator{
passwordMinLength: minLength,
passwordRequireSpecial: requireSpecial,
passwordRequireNumber: requireNumber,
}
}
// ValidateEmail validates email format.
func (v *Validator) ValidateEmail(email string) bool {
if email == "" {
return false
}
pattern := `^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`
matched, _ := regexp.MatchString(pattern, email)
return matched
}
// ValidatePhone validates mainland China mobile numbers.
func (v *Validator) ValidatePhone(phone string) bool {
if phone == "" {
return false
}
pattern := `^1[3-9]\d{9}$`
matched, _ := regexp.MatchString(pattern, phone)
return matched
}
// ValidateUsername validates usernames.
func (v *Validator) ValidateUsername(username string) bool {
if username == "" {
return false
}
pattern := `^[a-zA-Z][a-zA-Z0-9_]{3,19}$`
matched, _ := regexp.MatchString(pattern, username)
return matched
}
// ValidatePassword validates passwords using the shared runtime policy.
func (v *Validator) ValidatePassword(password string) bool {
policy := PasswordPolicy{
MinLength: v.passwordMinLength,
RequireSpecial: v.passwordRequireSpecial,
RequireNumber: v.passwordRequireNumber,
}
return policy.Validate(password) == nil
}
// SanitizeSQL removes obviously dangerous SQL injection patterns using regex.
// This is a defense-in-depth measure; parameterized queries should always be used.
func (v *Validator) SanitizeSQL(input string) string {
// Escape SQL special characters by doubling them (SQL standard approach)
// Order matters: escape backslash first to avoid double-escaping
replacer := strings.NewReplacer(
`\`, `\\`,
`'`, `''`,
`"`, `""`,
)
// Remove common SQL injection patterns that could bypass quoting
dangerousPatterns := []string{
`;[\s]*--`, // SQL comment
`/\*.*?\*/`, // Block comment (non-greedy)
`\bxp_\w+`, // Extended stored procedures
`\bexec[\s\(]`, // EXEC statements
`\bsp_\w+`, // System stored procedures
`\bwaitfor[\s]+delay`, // Time-based blind SQL injection
`\bunion[\s]+select`, // UNION injection
`\bdrop[\s]+table`, // DROP TABLE
`\binsert[\s]+into`, // INSERT
`\bupdate[\s]+\w+[\s]+set`, // UPDATE
`\bdelete[\s]+from`, // DELETE
}
result := replacer.Replace(input)
// Apply pattern removal
for _, pattern := range dangerousPatterns {
re := regexp.MustCompile(`(?i)` + pattern) // Case-insensitive
result = re.ReplaceAllString(result, "")
}
return result
}
// SanitizeXSS removes obviously dangerous XSS patterns using regex.
// This is a defense-in-depth measure; output encoding should always be used.
func (v *Validator) SanitizeXSS(input string) string {
// Remove dangerous tags and attributes using pattern matching
dangerousPatterns := []struct {
pattern string
replaceAll bool
}{
{`(?i)<script[^>]*>.*?</script>`, true}, // Script tags
{`(?i)</script>`, false}, // Closing script
{`(?i)<iframe[^>]*>.*?</iframe>`, true}, // Iframe injection
{`(?i)<object[^>]*>.*?</object>`, true}, // Object injection
{`(?i)<embed[^>]*>.*?</embed>`, true}, // Embed injection
{`(?i)<applet[^>]*>.*?</applet>`, true}, // Applet injection
{`(?i)javascript\s*:`, false}, // JavaScript protocol
{`(?i)vbscript\s*:`, false}, // VBScript protocol
{`(?i)data\s*:`, false}, // Data URL protocol
{`(?i)on\w+\s*=`, false}, // Event handlers
{`(?i)<style[^>]*>.*?</style>`, true}, // Style injection
}
result := input
for _, p := range dangerousPatterns {
re := regexp.MustCompile(p.pattern)
if p.replaceAll {
result = re.ReplaceAllString(result, "")
} else {
result = re.ReplaceAllString(result, "")
}
}
// Encode < and > to prevent tag construction
result = strings.ReplaceAll(result, "<", "&lt;")
result = strings.ReplaceAll(result, ">", "&gt;")
// Restore entities if they were part of legitimate content
result = strings.ReplaceAll(result, "&lt;", "<")
result = strings.ReplaceAll(result, "&gt;", ">")
return result
}
// ValidateURL validates a basic HTTP/HTTPS URL.
func (v *Validator) ValidateURL(url string) bool {
if url == "" {
return false
}
pattern := `^https?://[a-zA-Z0-9\-._~:/?#[\]@!$&'()*+,;=]+$`
matched, _ := regexp.MatchString(pattern, url)
return matched
}
// ValidateIP validates IPv4 or IPv6 addresses using net.ParseIP.
// Supports all valid formats including compressed IPv6 (::1, fe80::1, etc.)
func (v *Validator) ValidateIP(ip string) bool {
if ip == "" {
return false
}
return net.ParseIP(ip) != nil
}
// ValidateIPv4 validates IPv4 addresses only.
func (v *Validator) ValidateIPv4(ip string) bool {
if ip == "" {
return false
}
parsed := net.ParseIP(ip)
return parsed != nil && parsed.To4() != nil
}
// ValidateIPv6 validates IPv6 addresses only.
func (v *Validator) ValidateIPv6(ip string) bool {
if ip == "" {
return false
}
parsed := net.ParseIP(ip)
return parsed != nil && parsed.To4() == nil
}