Files
user-system/internal/service/webhook.go

485 lines
12 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
	"bytes"
	"context"
	"crypto/hmac"
	cryptorand "crypto/rand"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"net"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"

	"github.com/user-management-system/internal/domain"
	"github.com/user-management-system/internal/repository"
	"gorm.io/gorm"
)
type (
	// WebhookService publishes events to registered webhooks. Deliveries are
	// queued on an internal channel and executed by background worker goroutines.
	WebhookService struct {
		db      *gorm.DB
		repo    *repository.WebhookRepository
		queue   chan *deliveryTask // pending deliveries, drained by the workers
		workers int                // number of worker goroutines startWorkers launches
		config  WebhookServiceConfig
		wg      sync.WaitGroup // incremented once per running worker
		once    sync.Once      // makes startWorkers run its body at most once
	}

	// WebhookServiceConfig holds the tunables of WebhookService.
	WebhookServiceConfig struct {
		Enabled      bool   // when false, Publish returns without doing anything
		SecretHeader string // request header that carries the "sha256=<hex>" signature
		TimeoutSec   int    // default HTTP timeout, in seconds
		MaxRetries   int    // copied into newly created webhooks
		RetryBackoff string // "fixed" selects a constant delay; any other value is exponential
		WorkerCount  int    // number of delivery workers
		QueueSize    int    // capacity of the delivery queue
	}

	// deliveryTask is one pending delivery of a serialized event to one webhook.
	deliveryTask struct {
		webhook   *domain.Webhook
		eventType domain.WebhookEventType
		payload   []byte // JSON-encoded WebhookEvent
		attempt   int    // 1 for the first try, incremented before each retry
	}

	// WebhookEvent is the JSON envelope POSTed to webhook receivers.
	WebhookEvent struct {
		EventID   string                  `json:"event_id"`
		EventType domain.WebhookEventType `json:"event_type"`
		Timestamp time.Time               `json:"timestamp"`
		Data      any                     `json:"data"`
	}
)
// NewWebhookService builds a WebhookService backed by db and immediately
// starts its delivery workers.
//
// An optional WebhookServiceConfig may be passed; only the first one is used.
// In it, every numeric field <= 0 and every empty string field is replaced by
// the matching value from defaultWebhookServiceConfig. Enabled is used exactly
// as given. Without a config argument the defaults are used unchanged.
func NewWebhookService(db *gorm.DB, cfgs ...WebhookServiceConfig) *WebhookService {
	defaults := defaultWebhookServiceConfig()

	cfg := defaults
	if len(cfgs) != 0 {
		cfg = cfgs[0]
	}

	// Backfill anything the caller left unset or invalid.
	if cfg.WorkerCount <= 0 {
		cfg.WorkerCount = defaults.WorkerCount
	}
	if cfg.QueueSize <= 0 {
		cfg.QueueSize = defaults.QueueSize
	}
	if cfg.SecretHeader == "" {
		cfg.SecretHeader = defaults.SecretHeader
	}
	if cfg.TimeoutSec <= 0 {
		cfg.TimeoutSec = defaults.TimeoutSec
	}
	if cfg.MaxRetries <= 0 {
		cfg.MaxRetries = defaults.MaxRetries
	}
	if cfg.RetryBackoff == "" {
		cfg.RetryBackoff = defaults.RetryBackoff
	}

	repo := repository.NewWebhookRepository(db)
	svc := &WebhookService{
		db:      db,
		repo:    repo,
		queue:   make(chan *deliveryTask, cfg.QueueSize),
		workers: cfg.WorkerCount,
		config:  cfg,
	}
	svc.startWorkers()
	return svc
}
// defaultWebhookServiceConfig returns the built-in settings: enabled, signature
// header "X-Webhook-Signature", 10 second timeout, MaxRetries 3, exponential
// backoff, 4 workers and a queue of 1000 tasks.
func defaultWebhookServiceConfig() WebhookServiceConfig {
	var cfg WebhookServiceConfig
	cfg.Enabled = true
	cfg.SecretHeader = "X-Webhook-Signature"
	cfg.TimeoutSec = 10
	cfg.MaxRetries = 3
	cfg.RetryBackoff = "exponential"
	cfg.WorkerCount = 4
	cfg.QueueSize = 1000
	return cfg
}
// startWorkers launches s.workers goroutines, each of which takes tasks from
// s.queue and delivers them one at a time. The sync.Once makes repeated calls
// no-ops. A worker returns (and marks s.wg done) only when s.queue is closed.
func (s *WebhookService) startWorkers() {
	s.once.Do(func() {
		runWorker := func() {
			defer s.wg.Done()
			for {
				task, ok := <-s.queue
				if !ok {
					return
				}
				s.deliver(task)
			}
		}
		for remaining := s.workers; remaining > 0; remaining-- {
			s.wg.Add(1)
			go runWorker()
		}
	})
}
// Publish fans an event out to every active webhook subscribed to eventType
// (or to the "*" wildcard) for asynchronous delivery.
//
// The data is wrapped in a WebhookEvent envelope with a random ID and a UTC
// timestamp, serialized once, and enqueued for the background workers.
// Publish never blocks on delivery: when the queue is full the task is dropped
// and a warning is logged. Failures to load webhooks, generate the event ID or
// serialize the payload are logged and abort the publish.
func (s *WebhookService) Publish(ctx context.Context, eventType domain.WebhookEventType, data interface{}) {
	if !s.config.Enabled {
		return
	}

	webhooks, err := s.repo.ListActive(ctx)
	if err != nil {
		slog.Error("list active webhooks failed", "event_type", eventType, "error", err)
		return
	}

	eventID, err := generateEventID()
	if err != nil {
		slog.Error("generate event ID failed", "error", err)
		return
	}
	event := &WebhookEvent{
		EventID:   eventID,
		EventType: eventType,
		Timestamp: time.Now().UTC(),
		Data:      data,
	}
	payloadBytes, err := json.Marshal(event)
	if err != nil {
		slog.Error("marshal webhook event failed",
			"event_id", eventID, "event_type", eventType, "error", err)
		return
	}

	for _, wh := range webhooks {
		if !webhookSubscribesTo(wh, eventType) {
			continue
		}
		task := &deliveryTask{
			webhook:   wh,
			eventType: eventType,
			payload:   payloadBytes,
			attempt:   1,
		}
		// Non-blocking enqueue so a backlog never stalls the publisher; the
		// dropped task is logged so the loss is visible.
		select {
		case s.queue <- task:
		default:
			slog.Warn("webhook delivery queue full, dropping task",
				"webhook_id", wh.ID, "event_id", eventID, "event_type", eventType)
		}
	}
}
// maxWebhookResponseBodyBytes caps how much of a receiver's response body is
// read into memory and stored with the delivery record (64 KiB).
const maxWebhookResponseBodyBytes = 64 << 10

// deliver performs one HTTP POST of task.payload to the task's webhook and
// records the outcome.
//
// The URL is checked with isSafeURL before sending (NEW-SEC-01). The check is
// repeated for every redirect hop, so a public endpoint cannot bounce the
// request to an internal address. The timeout is the webhook's TimeoutSec,
// falling back to the service configuration and then to 10 seconds. When the
// webhook has a secret, the payload's HMAC-SHA256 is sent as "sha256=<hex>" in
// the configured signature header.
//
// Transport errors and non-2xx responses are passed to handleFailure, which
// may schedule a retry. Only 2xx responses are recorded as successful.
func (s *WebhookService) deliver(task *deliveryTask) {
	wh := task.webhook

	if !isSafeURL(wh.URL) {
		s.recordDelivery(task, 0, "", "webhook URL 不安全: 可能存在 SSRF 风险", false)
		return
	}

	timeout := time.Duration(wh.TimeoutSec) * time.Second
	if timeout <= 0 {
		timeout = time.Duration(s.config.TimeoutSec) * time.Second
	}
	if timeout <= 0 {
		timeout = 10 * time.Second
	}

	// Bound the whole exchange so a stalled receiver cannot pin a worker.
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, wh.URL, bytes.NewReader(task.payload))
	if err != nil {
		s.recordDelivery(task, 0, "", err.Error(), false)
		return
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("User-Agent", "UserManagementSystem-Webhook/1.0")
	req.Header.Set("X-Webhook-Event", string(task.eventType))
	req.Header.Set("X-Webhook-Attempt", fmt.Sprintf("%d", task.attempt))
	if wh.Secret != "" {
		sig := computeHMAC(task.payload, wh.Secret)
		req.Header.Set(s.config.SecretHeader, "sha256="+sig)
	}

	client := &http.Client{
		Timeout: timeout,
		// The default policy follows up to 10 redirects without re-validating
		// the target, which would let a "safe" URL redirect into the intranet
		// or a cloud metadata endpoint. Keep the same hop limit but re-apply
		// the SSRF check on every hop.
		CheckRedirect: func(next *http.Request, via []*http.Request) error {
			if len(via) >= 10 {
				return fmt.Errorf("stopped after 10 redirects")
			}
			if !isSafeURL(next.URL.String()) {
				return fmt.Errorf("redirect to unsafe webhook URL blocked")
			}
			return nil
		},
	}
	resp, err := client.Do(req)
	if err != nil {
		s.handleFailure(task, 0, "", err.Error())
		return
	}
	defer resp.Body.Close()

	// Read at most maxWebhookResponseBodyBytes so an oversized or malicious
	// response cannot exhaust memory or bloat the delivery log. Read errors
	// are ignored on purpose: a partial body is still useful for diagnostics.
	var respBuf bytes.Buffer
	_, _ = respBuf.ReadFrom(io.LimitReader(resp.Body, maxWebhookResponseBodyBytes))
	// Truncation can split a multi-byte character; drop invalid sequences so
	// the stored string stays valid UTF-8.
	respBody := strings.ToValidUTF8(respBuf.String(), "")

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		s.handleFailure(task, resp.StatusCode, respBody, "非 2xx 响应")
		return
	}
	s.recordDelivery(task, resp.StatusCode, respBody, "", true)
}
// maxWebhookRetryBackoff caps the exponential retry delay. This keeps the
// delay bounded and prevents the doubling from overflowing time.Duration when
// a webhook allows many attempts.
const maxWebhookRetryBackoff = time.Hour

// handleFailure records a failed delivery attempt and, if attempts remain,
// schedules a retry.
//
// task.attempt starts at 1, so MaxRetries is effectively the total number of
// attempts. A webhook whose MaxRetries is <= 0 falls back to the service
// configuration. The retry is re-enqueued without blocking; if the queue is
// full, the retry is dropped and logged.
func (s *WebhookService) handleFailure(task *deliveryTask, statusCode int, body, errMsg string) {
	s.recordDelivery(task, statusCode, body, errMsg, false)

	maxAttempts := task.webhook.MaxRetries
	if maxAttempts <= 0 {
		maxAttempts = s.config.MaxRetries
	}
	if task.attempt >= maxAttempts {
		return
	}

	backoff := webhookRetryBackoff(s.config.RetryBackoff, task.attempt)
	time.AfterFunc(backoff, func() {
		task.attempt++
		select {
		case s.queue <- task:
		default:
			slog.Warn("webhook delivery queue full, dropping retry",
				"webhook_id", task.webhook.ID, "event_type", task.eventType, "attempt", task.attempt)
		}
	})
}

// webhookRetryBackoff returns the delay before retrying after the given
// attempt. "fixed" always yields 2s. Any other strategy yields 2^attempt
// seconds, capped at maxWebhookRetryBackoff.
func webhookRetryBackoff(strategy string, attempt int) time.Duration {
	if strategy == "fixed" {
		return 2 * time.Second
	}
	backoff := time.Second
	// Double step by step and stop at the cap, so the value can never overflow.
	for i := 0; i < attempt && backoff < maxWebhookRetryBackoff; i++ {
		backoff *= 2
	}
	if backoff > maxWebhookRetryBackoff {
		backoff = maxWebhookRetryBackoff
	}
	return backoff
}
// recordDelivery persists one delivery attempt for task, including the HTTP
// status, response body, error text and attempt number. DeliveredAt is set
// only on success.
//
// Persistence is best effort: a storage error is logged but does not affect
// delivery or retries.
func (s *WebhookService) recordDelivery(task *deliveryTask, statusCode int, body, errMsg string, success bool) {
	delivery := &domain.WebhookDelivery{
		WebhookID:    task.webhook.ID,
		EventType:    task.eventType,
		Payload:      string(task.payload),
		StatusCode:   statusCode,
		ResponseBody: body,
		Attempt:      task.attempt,
		Success:      success,
		Error:        errMsg,
	}
	if success {
		now := time.Now()
		delivery.DeliveredAt = &now
	}
	if err := s.repo.CreateDelivery(context.Background(), delivery); err != nil {
		slog.Error("record webhook delivery failed",
			"webhook_id", task.webhook.ID,
			"event_type", task.eventType,
			"attempt", task.attempt,
			"success", success,
			"error", err)
	}
}
// CreateWebhook stores a new active webhook owned by createdBy.
//
// The subscribed event types are stored as a JSON array. When req.Secret is
// empty, a random signing secret is generated. MaxRetries and TimeoutSec are
// copied from the service configuration. Serialization, secret-generation and
// repository errors are returned; underlying causes are wrapped.
func (s *WebhookService) CreateWebhook(ctx context.Context, req *CreateWebhookRequest, createdBy int64) (*domain.Webhook, error) {
	eventsJSON, err := json.Marshal(req.Events)
	if err != nil {
		// Keep the cause so callers can see why serialization failed.
		return nil, fmt.Errorf("序列化事件列表失败: %w", err)
	}

	secret := req.Secret
	if secret == "" {
		generatedSecret, err := generateWebhookSecret()
		if err != nil {
			return nil, fmt.Errorf("generate webhook secret failed: %w", err)
		}
		secret = generatedSecret
	}

	wh := &domain.Webhook{
		Name:       req.Name,
		URL:        req.URL,
		Secret:     secret,
		Events:     string(eventsJSON),
		Status:     domain.WebhookStatusActive,
		MaxRetries: s.config.MaxRetries,
		TimeoutSec: s.config.TimeoutSec,
		CreatedBy:  createdBy,
	}
	if err := s.repo.Create(ctx, wh); err != nil {
		return nil, err
	}
	return wh, nil
}
// UpdateWebhook applies a partial update to the webhook with the given id.
//
// Only fields that are set in req are written: a non-empty Name or URL, a
// non-empty Events list (stored as a JSON array), and a non-nil Status. The
// signing secret is not touched. A serialization error is returned wrapped;
// otherwise the repository's result is returned.
func (s *WebhookService) UpdateWebhook(ctx context.Context, id int64, req *UpdateWebhookRequest) error {
	updates := map[string]interface{}{}
	if req.Name != "" {
		updates["name"] = req.Name
	}
	if req.URL != "" {
		updates["url"] = req.URL
	}
	if len(req.Events) > 0 {
		eventsJSON, err := json.Marshal(req.Events)
		if err != nil {
			// Same handling as CreateWebhook: never persist a half-built value.
			return fmt.Errorf("序列化事件列表失败: %w", err)
		}
		updates["events"] = string(eventsJSON)
	}
	if req.Status != nil {
		updates["status"] = *req.Status
	}
	return s.repo.Update(ctx, id, updates)
}
// DeleteWebhook removes the webhook with the given id, returning the
// repository's error unchanged.
func (s *WebhookService) DeleteWebhook(ctx context.Context, id int64) error {
	if err := s.repo.Delete(ctx, id); err != nil {
		return err
	}
	return nil
}
// GetWebhook loads a single webhook by id. Both results are passed through
// from the repository unchanged.
func (s *WebhookService) GetWebhook(ctx context.Context, id int64) (*domain.Webhook, error) {
	webhook, err := s.repo.GetByID(ctx, id)
	return webhook, err
}
// ListWebhooks returns, without pagination, the webhooks the repository
// associates with creator createdBy.
func (s *WebhookService) ListWebhooks(ctx context.Context, createdBy int64) ([]*domain.Webhook, error) {
	list, err := s.repo.ListByCreator(ctx, createdBy)
	return list, err
}
// ListWebhooksPaginated returns one page (offset/limit) of the webhooks for
// creator createdBy, together with the int64 count reported by the
// repository (presumably the total number of matches; verify in repository).
func (s *WebhookService) ListWebhooksPaginated(ctx context.Context, createdBy int64, offset, limit int) ([]*domain.Webhook, int64, error) {
	page, count, err := s.repo.ListByCreatorPaginated(ctx, createdBy, offset, limit)
	return page, count, err
}
// GetWebhookDeliveries returns up to limit delivery records for webhookID, as
// provided by the repository.
func (s *WebhookService) GetWebhookDeliveries(ctx context.Context, webhookID int64, limit int) ([]*domain.WebhookDelivery, error) {
	deliveries, err := s.repo.ListDeliveries(ctx, webhookID, limit)
	return deliveries, err
}
// ---- Request/Response 结构 ----
type (
	// CreateWebhookRequest is the input to CreateWebhook. Its binding tags
	// declare that name, a URL-formatted url and at least one event are required.
	CreateWebhookRequest struct {
		Name   string                    `json:"name" binding:"required"`
		URL    string                    `json:"url" binding:"required,url"`
		Secret string                    `json:"secret"` // optional signing secret
		Events []domain.WebhookEventType `json:"events" binding:"required,min=1"`
	}

	// UpdateWebhookRequest is the input to UpdateWebhook. Every field is
	// optional; empty strings, an empty Events list and a nil Status mean
	// "leave unchanged".
	UpdateWebhookRequest struct {
		Name   string                    `json:"name"`
		URL    string                    `json:"url"`
		Events []domain.WebhookEventType `json:"events"`
		Status *domain.WebhookStatus     `json:"status"`
	}
)
// ---- 辅助函数 ----
// webhookSubscribesTo reports whether w's subscription list contains eventType
// or the wildcard "*". w.Events is expected to hold a JSON array of event
// types; if it cannot be decoded, the webhook is treated as subscribing to
// nothing.
func webhookSubscribesTo(w *domain.Webhook, eventType domain.WebhookEventType) bool {
	var subscribed []domain.WebhookEventType
	if json.Unmarshal([]byte(w.Events), &subscribed) != nil {
		return false
	}
	for i := range subscribed {
		switch subscribed[i] {
		case eventType, "*":
			return true
		}
	}
	return false
}
// ssrfExtraBlockedNets lists IPv4 ranges that isSafeURL rejects in addition to
// isPrivateIP and the net.IP link-local/unspecified/multicast predicates.
var ssrfExtraBlockedNets = func() []*net.IPNet {
	cidrs := []string{
		"0.0.0.0/8",     // "this network"; 0.0.0.0 reaches the local host on many systems
		"100.64.0.0/10", // shared/CGNAT space; includes 100.100.100.200 (Alibaba Cloud metadata)
	}
	nets := make([]*net.IPNet, 0, len(cidrs))
	for _, cidr := range cidrs {
		_, n, err := net.ParseCIDR(cidr)
		if err != nil {
			panic(fmt.Sprintf("isSafeURL: invalid CIDR %q: %v", cidr, err))
		}
		nets = append(nets, n)
	}
	return nets
}()

// isSafeURL reports whether rawURL is acceptable as a webhook target. It is a
// best-effort SSRF guard (NEW-SEC-01) that accepts only http/https URLs with a
// non-empty host.
//
// It rejects the following:
//   - "localhost" and "*.localhost";
//   - literal IPs that are private or loopback (isPrivateIP), link-local,
//     unspecified, multicast, or inside ssrfExtraBlockedNets;
//   - internal-looking domain suffixes;
//   - well-known cloud metadata hosts.
//
// Host names are compared case-insensitively and without a trailing root dot.
// Only the literal host is inspected: a public name that resolves to an
// internal address is not detected here.
func isSafeURL(rawURL string) bool {
	u, err := url.Parse(rawURL)
	if err != nil || u.Scheme == "" {
		return false
	}
	// Only http/https are allowed (url.Parse lowercases the scheme).
	if u.Scheme != "http" && u.Scheme != "https" {
		return false
	}

	// DNS names are case-insensitive and may end with the root dot, so
	// "LOCALHOST." must be treated like "localhost".
	host := strings.TrimSuffix(strings.ToLower(u.Hostname()), ".")
	if host == "" {
		return false
	}

	if host == "localhost" || strings.HasSuffix(host, ".localhost") ||
		host == "127.0.0.1" || host == "::1" {
		return false
	}

	// IPv6 literals may carry a zone ("fe80::1%eth0"), which net.ParseIP
	// rejects; strip it so the address itself is still checked.
	ipText, _, _ := strings.Cut(host, "%")
	if ip := net.ParseIP(ipText); ip != nil {
		if isPrivateIP(ip) ||
			ip.IsUnspecified() ||
			ip.IsLinkLocalUnicast() || // 169.254.0.0/16, fe80::/10 (cloud metadata lives here)
			ip.IsMulticast() {
			return false
		}
		for _, n := range ssrfExtraBlockedNets {
			if n.Contains(ip) {
				return false
			}
		}
	}

	// Internal-looking domain suffixes.
	for _, suffix := range []string{".internal", ".local", ".corp", ".lan", ".intranet"} {
		if strings.HasSuffix(host, suffix) {
			return false
		}
	}

	// Well-known metadata endpoints, kept explicitly even where the checks
	// above already cover them.
	blockedHosts := []string{
		"metadata.google.internal", // GCP metadata service
		"169.254.169.254",          // AWS/Azure/GCP metadata service
		"metadata.azure.internal",  // Azure metadata service
		"100.100.100.200",          // Alibaba Cloud metadata service
	}
	for _, blocked := range blockedHosts {
		if host == blocked {
			return false
		}
	}
	return true
}
// webhookPrivateIPNets holds the CIDR ranges checked by isPrivateIP. The
// ranges are parsed once at package initialization instead of on every call.
var webhookPrivateIPNets = func() []*net.IPNet {
	cidrs := []string{
		"10.0.0.0/8",
		"172.16.0.0/12",
		"192.168.0.0/16",
		"127.0.0.0/8",
		"::1/128",
		"fc00::/7",
	}
	nets := make([]*net.IPNet, 0, len(cidrs))
	for _, cidr := range cidrs {
		_, n, err := net.ParseCIDR(cidr)
		if err != nil {
			// The list is constant; a parse failure is a programming error.
			panic(fmt.Sprintf("isPrivateIP: invalid CIDR %q: %v", cidr, err))
		}
		nets = append(nets, n)
	}
	return nets
}()

// isPrivateIP reports whether ip lies in an RFC 1918 private IPv4 range, the
// IPv4 loopback range, IPv6 loopback (::1) or IPv6 unique-local space
// (fc00::/7). IPv4-mapped IPv6 addresses are matched against the IPv4 ranges.
func isPrivateIP(ip net.IP) bool {
	for _, n := range webhookPrivateIPNets {
		if n.Contains(ip) {
			return true
		}
	}
	return false
}
// computeHMAC returns the lowercase hex encoding of HMAC-SHA256(secret, payload).
func computeHMAC(payload []byte, secret string) string {
	signer := hmac.New(sha256.New, []byte(secret))
	_, _ = signer.Write(payload) // hash.Hash.Write never returns an error
	digest := signer.Sum(nil)
	return hex.EncodeToString(digest)
}
// generateEventID returns "evt_" followed by 16 hex characters (8 bytes read
// from crypto/rand). An error from the random source is wrapped and returned.
func generateEventID() (string, error) {
	var raw [8]byte
	if _, err := cryptorand.Read(raw[:]); err != nil {
		return "", fmt.Errorf("generate event ID failed: %w", err)
	}
	const prefix = "evt_"
	return prefix + hex.EncodeToString(raw[:]), nil
}
// generateWebhookSecret returns a random signing secret: 24 bytes from
// crypto/rand, encoded as 48 lowercase hex characters. An error from the
// random source is wrapped and returned.
func generateWebhookSecret() (string, error) {
	var raw [24]byte
	if _, err := cryptorand.Read(raw[:]); err != nil {
		return "", fmt.Errorf("generate webhook secret failed: %w", err)
	}
	// hex.EncodeToString already emits lowercase digits; no case folding needed.
	return hex.EncodeToString(raw[:]), nil
}