485 lines
12 KiB
Go
485 lines
12 KiB
Go
|
|
package service
|
|||
|
|
|
|||
|
|
import (
	"bytes"
	"context"
	"crypto/hmac"
	cryptorand "crypto/rand"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"net"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"

	"github.com/user-management-system/internal/domain"
	"github.com/user-management-system/internal/repository"
	"gorm.io/gorm"
)
|
|||
|
|
|
|||
|
|
// WebhookService is the webhook service: it manages webhook registrations and
// asynchronously delivers published events to them through a buffered queue
// drained by worker goroutines.
type WebhookService struct {
	// db is the database handle; the repository below is built from it.
	db *gorm.DB
	// repo persists webhooks and their delivery records.
	repo *repository.WebhookRepository
	// queue holds pending deliveries; its capacity is config.QueueSize.
	queue chan *deliveryTask
	// workers is the number of delivery goroutines started by startWorkers.
	workers int
	// config is the effective configuration after defaults were applied.
	config WebhookServiceConfig
	// wg tracks the running worker goroutines.
	// NOTE(review): nothing visible in this file waits on wg or closes queue —
	// confirm shutdown is handled elsewhere.
	wg sync.WaitGroup
	// once ensures startWorkers launches the workers at most once.
	once sync.Once
}
|
|||
|
|
|
|||
|
|
// WebhookServiceConfig holds the tunables for WebhookService.
// NewWebhookService replaces zero/empty fields (every field except Enabled)
// with the values from defaultWebhookServiceConfig.
type WebhookServiceConfig struct {
	// Enabled gates Publish: when false, published events are ignored.
	Enabled bool
	// SecretHeader is the request header that carries the "sha256=<hex>"
	// HMAC signature of the payload (only sent when a webhook has a secret).
	SecretHeader string
	// TimeoutSec is the default per-request timeout in seconds. It is copied
	// onto newly created webhooks and used when a webhook has no timeout.
	TimeoutSec int
	// MaxRetries is copied onto newly created webhooks; the retry logic
	// itself reads the per-webhook value rather than this field.
	MaxRetries int
	// RetryBackoff selects the retry delay: "fixed" waits 2s; any other
	// value uses exponential backoff of 2^attempt seconds.
	RetryBackoff string
	// WorkerCount is the number of delivery goroutines.
	WorkerCount int
	// QueueSize is the capacity of the delivery queue; new tasks and retries
	// are dropped when it is full.
	QueueSize int
}
|
|||
|
|
|
|||
|
|
// deliveryTask is a single pending delivery of an event payload to one webhook.
type deliveryTask struct {
	// webhook is the target endpoint together with its per-webhook settings.
	webhook *domain.Webhook
	// eventType is sent in the X-Webhook-Event header and stored with the
	// delivery record.
	eventType domain.WebhookEventType
	// payload is the JSON-encoded WebhookEvent. Publish shares one slice
	// across all tasks of the same event, so it must not be mutated.
	payload []byte
	// attempt is the 1-based attempt number; handleFailure increments it
	// before re-enqueueing a retry.
	attempt int
}
|
|||
|
|
|
|||
|
|
// WebhookEvent is the JSON envelope POSTed to webhook endpoints for a
// published event.
type WebhookEvent struct {
	// EventID is a random identifier of the form "evt_" followed by 16 hex
	// characters; every delivery (including retries) of one Publish call
	// carries the same ID.
	EventID string `json:"event_id"`
	// EventType is the type of the published event.
	EventType domain.WebhookEventType `json:"event_type"`
	// Timestamp is when Publish built the event, in UTC.
	Timestamp time.Time `json:"timestamp"`
	// Data is the caller-supplied event body, marshaled as-is by encoding/json.
	Data interface{} `json:"data"`
}
|
|||
|
|
|
|||
|
|
// NewWebhookService 创建 Webhook 服务
|
|||
|
|
func NewWebhookService(db *gorm.DB, cfgs ...WebhookServiceConfig) *WebhookService {
|
|||
|
|
cfg := defaultWebhookServiceConfig()
|
|||
|
|
if len(cfgs) > 0 {
|
|||
|
|
cfg = cfgs[0]
|
|||
|
|
}
|
|||
|
|
if cfg.WorkerCount <= 0 {
|
|||
|
|
cfg.WorkerCount = defaultWebhookServiceConfig().WorkerCount
|
|||
|
|
}
|
|||
|
|
if cfg.QueueSize <= 0 {
|
|||
|
|
cfg.QueueSize = defaultWebhookServiceConfig().QueueSize
|
|||
|
|
}
|
|||
|
|
if cfg.SecretHeader == "" {
|
|||
|
|
cfg.SecretHeader = defaultWebhookServiceConfig().SecretHeader
|
|||
|
|
}
|
|||
|
|
if cfg.TimeoutSec <= 0 {
|
|||
|
|
cfg.TimeoutSec = defaultWebhookServiceConfig().TimeoutSec
|
|||
|
|
}
|
|||
|
|
if cfg.MaxRetries <= 0 {
|
|||
|
|
cfg.MaxRetries = defaultWebhookServiceConfig().MaxRetries
|
|||
|
|
}
|
|||
|
|
if cfg.RetryBackoff == "" {
|
|||
|
|
cfg.RetryBackoff = defaultWebhookServiceConfig().RetryBackoff
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
svc := &WebhookService{
|
|||
|
|
db: db,
|
|||
|
|
repo: repository.NewWebhookRepository(db),
|
|||
|
|
queue: make(chan *deliveryTask, cfg.QueueSize),
|
|||
|
|
workers: cfg.WorkerCount,
|
|||
|
|
config: cfg,
|
|||
|
|
}
|
|||
|
|
svc.startWorkers()
|
|||
|
|
return svc
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func defaultWebhookServiceConfig() WebhookServiceConfig {
|
|||
|
|
return WebhookServiceConfig{
|
|||
|
|
Enabled: true,
|
|||
|
|
SecretHeader: "X-Webhook-Signature",
|
|||
|
|
TimeoutSec: 10,
|
|||
|
|
MaxRetries: 3,
|
|||
|
|
RetryBackoff: "exponential",
|
|||
|
|
WorkerCount: 4,
|
|||
|
|
QueueSize: 1000,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// startWorkers 启动后台投递 worker
|
|||
|
|
func (s *WebhookService) startWorkers() {
|
|||
|
|
s.once.Do(func() {
|
|||
|
|
for i := 0; i < s.workers; i++ {
|
|||
|
|
s.wg.Add(1)
|
|||
|
|
go func() {
|
|||
|
|
defer s.wg.Done()
|
|||
|
|
for task := range s.queue {
|
|||
|
|
s.deliver(task)
|
|||
|
|
}
|
|||
|
|
}()
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Publish 发布事件:找到订阅该事件的所有 Webhook,异步投递
|
|||
|
|
func (s *WebhookService) Publish(ctx context.Context, eventType domain.WebhookEventType, data interface{}) {
|
|||
|
|
if !s.config.Enabled {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
// 查询所有活跃 Webhook
|
|||
|
|
webhooks, err := s.repo.ListActive(ctx)
|
|||
|
|
if err != nil {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 构建事件载荷
|
|||
|
|
eventID, err := generateEventID()
|
|||
|
|
if err != nil {
|
|||
|
|
slog.Error("generate event ID failed", "error", err)
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
event := &WebhookEvent{
|
|||
|
|
EventID: eventID,
|
|||
|
|
EventType: eventType,
|
|||
|
|
Timestamp: time.Now().UTC(),
|
|||
|
|
Data: data,
|
|||
|
|
}
|
|||
|
|
payloadBytes, err := json.Marshal(event)
|
|||
|
|
if err != nil {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
for i := range webhooks {
|
|||
|
|
wh := webhooks[i]
|
|||
|
|
// 检查是否订阅了该事件类型
|
|||
|
|
if !webhookSubscribesTo(wh, eventType) {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
task := &deliveryTask{
|
|||
|
|
webhook: wh,
|
|||
|
|
eventType: eventType,
|
|||
|
|
payload: payloadBytes,
|
|||
|
|
attempt: 1,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 非阻塞投递到队列
|
|||
|
|
select {
|
|||
|
|
case s.queue <- task:
|
|||
|
|
default:
|
|||
|
|
// 队列满时记录但不阻塞
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// deliver 执行单次 HTTP 投递
|
|||
|
|
func (s *WebhookService) deliver(task *deliveryTask) {
|
|||
|
|
wh := task.webhook
|
|||
|
|
|
|||
|
|
// NEW-SEC-01 修复:检查 URL 安全性
|
|||
|
|
if !isSafeURL(wh.URL) {
|
|||
|
|
s.recordDelivery(task, 0, "", "webhook URL 不安全: 可能存在 SSRF 风险", false)
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
timeout := time.Duration(wh.TimeoutSec) * time.Second
|
|||
|
|
if timeout <= 0 {
|
|||
|
|
timeout = time.Duration(s.config.TimeoutSec) * time.Second
|
|||
|
|
}
|
|||
|
|
if timeout <= 0 {
|
|||
|
|
timeout = 10 * time.Second
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
client := &http.Client{Timeout: timeout}
|
|||
|
|
|
|||
|
|
req, err := http.NewRequest("POST", wh.URL, bytes.NewReader(task.payload))
|
|||
|
|
if err != nil {
|
|||
|
|
s.recordDelivery(task, 0, "", err.Error(), false)
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
req.Header.Set("Content-Type", "application/json")
|
|||
|
|
req.Header.Set("User-Agent", "UserManagementSystem-Webhook/1.0")
|
|||
|
|
req.Header.Set("X-Webhook-Event", string(task.eventType))
|
|||
|
|
req.Header.Set("X-Webhook-Attempt", fmt.Sprintf("%d", task.attempt))
|
|||
|
|
|
|||
|
|
// HMAC 签名
|
|||
|
|
if wh.Secret != "" {
|
|||
|
|
sig := computeHMAC(task.payload, wh.Secret)
|
|||
|
|
req.Header.Set(s.config.SecretHeader, "sha256="+sig)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 使用带超时的 context 避免请求无限等待
|
|||
|
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|||
|
|
defer cancel()
|
|||
|
|
resp, err := client.Do(req.WithContext(ctx))
|
|||
|
|
if err != nil {
|
|||
|
|
s.handleFailure(task, 0, "", err.Error())
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
defer resp.Body.Close()
|
|||
|
|
|
|||
|
|
var respBuf bytes.Buffer
|
|||
|
|
respBuf.ReadFrom(resp.Body)
|
|||
|
|
success := resp.StatusCode >= 200 && resp.StatusCode < 300
|
|||
|
|
|
|||
|
|
if !success {
|
|||
|
|
s.handleFailure(task, resp.StatusCode, respBuf.String(), "非 2xx 响应")
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
s.recordDelivery(task, resp.StatusCode, respBuf.String(), "", true)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// handleFailure 处理投递失败(重试逻辑)
|
|||
|
|
func (s *WebhookService) handleFailure(task *deliveryTask, statusCode int, body, errMsg string) {
|
|||
|
|
s.recordDelivery(task, statusCode, body, errMsg, false)
|
|||
|
|
|
|||
|
|
// 指数退避重试
|
|||
|
|
if task.attempt < task.webhook.MaxRetries {
|
|||
|
|
backoff := time.Second
|
|||
|
|
if s.config.RetryBackoff == "fixed" {
|
|||
|
|
backoff = 2 * time.Second
|
|||
|
|
} else {
|
|||
|
|
backoff = time.Duration(1<<uint(task.attempt)) * time.Second
|
|||
|
|
}
|
|||
|
|
time.AfterFunc(backoff, func() {
|
|||
|
|
task.attempt++
|
|||
|
|
select {
|
|||
|
|
case s.queue <- task:
|
|||
|
|
default:
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// recordDelivery 记录投递日志
|
|||
|
|
func (s *WebhookService) recordDelivery(task *deliveryTask, statusCode int, body, errMsg string, success bool) {
|
|||
|
|
now := time.Now()
|
|||
|
|
delivery := &domain.WebhookDelivery{
|
|||
|
|
WebhookID: task.webhook.ID,
|
|||
|
|
EventType: task.eventType,
|
|||
|
|
Payload: string(task.payload),
|
|||
|
|
StatusCode: statusCode,
|
|||
|
|
ResponseBody: body,
|
|||
|
|
Attempt: task.attempt,
|
|||
|
|
Success: success,
|
|||
|
|
Error: errMsg,
|
|||
|
|
}
|
|||
|
|
if success {
|
|||
|
|
delivery.DeliveredAt = &now
|
|||
|
|
}
|
|||
|
|
_ = s.repo.CreateDelivery(context.Background(), delivery)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// CreateWebhook 创建 Webhook
|
|||
|
|
func (s *WebhookService) CreateWebhook(ctx context.Context, req *CreateWebhookRequest, createdBy int64) (*domain.Webhook, error) {
|
|||
|
|
eventsJSON, err := json.Marshal(req.Events)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("序列化事件列表失败")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
secret := req.Secret
|
|||
|
|
if secret == "" {
|
|||
|
|
generatedSecret, err := generateWebhookSecret()
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("generate webhook secret failed: %w", err)
|
|||
|
|
}
|
|||
|
|
secret = generatedSecret
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
wh := &domain.Webhook{
|
|||
|
|
Name: req.Name,
|
|||
|
|
URL: req.URL,
|
|||
|
|
Secret: secret,
|
|||
|
|
Events: string(eventsJSON),
|
|||
|
|
Status: domain.WebhookStatusActive,
|
|||
|
|
MaxRetries: s.config.MaxRetries,
|
|||
|
|
TimeoutSec: s.config.TimeoutSec,
|
|||
|
|
CreatedBy: createdBy,
|
|||
|
|
}
|
|||
|
|
if err := s.repo.Create(ctx, wh); err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
return wh, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// UpdateWebhook 更新 Webhook
|
|||
|
|
func (s *WebhookService) UpdateWebhook(ctx context.Context, id int64, req *UpdateWebhookRequest) error {
|
|||
|
|
updates := map[string]interface{}{}
|
|||
|
|
if req.Name != "" {
|
|||
|
|
updates["name"] = req.Name
|
|||
|
|
}
|
|||
|
|
if req.URL != "" {
|
|||
|
|
updates["url"] = req.URL
|
|||
|
|
}
|
|||
|
|
if len(req.Events) > 0 {
|
|||
|
|
b, _ := json.Marshal(req.Events)
|
|||
|
|
updates["events"] = string(b)
|
|||
|
|
}
|
|||
|
|
if req.Status != nil {
|
|||
|
|
updates["status"] = *req.Status
|
|||
|
|
}
|
|||
|
|
return s.repo.Update(ctx, id, updates)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// DeleteWebhook 删除 Webhook
|
|||
|
|
func (s *WebhookService) DeleteWebhook(ctx context.Context, id int64) error {
|
|||
|
|
return s.repo.Delete(ctx, id)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (s *WebhookService) GetWebhook(ctx context.Context, id int64) (*domain.Webhook, error) {
|
|||
|
|
return s.repo.GetByID(ctx, id)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ListWebhooks 获取 Webhook 列表(不分页)
|
|||
|
|
func (s *WebhookService) ListWebhooks(ctx context.Context, createdBy int64) ([]*domain.Webhook, error) {
|
|||
|
|
return s.repo.ListByCreator(ctx, createdBy)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ListWebhooksPaginated 获取 Webhook 列表(分页)
|
|||
|
|
func (s *WebhookService) ListWebhooksPaginated(ctx context.Context, createdBy int64, offset, limit int) ([]*domain.Webhook, int64, error) {
|
|||
|
|
return s.repo.ListByCreatorPaginated(ctx, createdBy, offset, limit)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetWebhookDeliveries 获取投递记录
|
|||
|
|
func (s *WebhookService) GetWebhookDeliveries(ctx context.Context, webhookID int64, limit int) ([]*domain.WebhookDelivery, error) {
|
|||
|
|
return s.repo.ListDeliveries(ctx, webhookID, limit)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ---- Request types ----
|
|||
|
|
|
|||
|
|
// CreateWebhookRequest is the request body for creating a webhook.
// Per its binding tags, Name and URL are required (URL must be a URL) and
// Events must contain at least one entry. An empty Secret makes
// CreateWebhook generate one.
type CreateWebhookRequest struct {
	Name   string                    `json:"name" binding:"required"`
	URL    string                    `json:"url" binding:"required,url"`
	Secret string                    `json:"secret"`
	Events []domain.WebhookEventType `json:"events" binding:"required,min=1"`
}
|
|||
|
|
|
|||
|
|
// UpdateWebhookRequest is the request body for a partial webhook update.
// Zero-valued fields (empty Name or URL, empty Events, nil Status) are left
// unchanged by UpdateWebhook.
type UpdateWebhookRequest struct {
	Name   string                    `json:"name"`
	URL    string                    `json:"url"`
	Events []domain.WebhookEventType `json:"events"`
	Status *domain.WebhookStatus     `json:"status"`
}
|
|||
|
|
|
|||
|
|
// ---- Helper functions ----
|
|||
|
|
|
|||
|
|
// webhookSubscribesTo 检查 Webhook 是否订阅了指定事件类型
|
|||
|
|
func webhookSubscribesTo(w *domain.Webhook, eventType domain.WebhookEventType) bool {
|
|||
|
|
var events []domain.WebhookEventType
|
|||
|
|
if err := json.Unmarshal([]byte(w.Events), &events); err != nil {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
for _, e := range events {
|
|||
|
|
if e == eventType || e == "*" {
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Note: Go does not allow methods to be declared on a type from another
// package, so subscription matching is the free function webhookSubscribesTo
// rather than a SubscribesTo method on domain.Webhook.
|
|||
|
|
|
|||
|
|
// isSafeURL reports whether rawURL is acceptable as a webhook target, i.e.
// it does not obviously point at internal infrastructure (SSRF protection,
// NEW-SEC-01).
//
// The URL must use http or https and have a non-empty host. It is rejected if
// the host is:
//   - localhost or any *.localhost name;
//   - an IP literal for which isPrivateIP is true (private, loopback,
//     link-local, unspecified, shared/CGNAT);
//   - a name under .internal, .local, .corp, .lan or .intranet;
//   - a well-known cloud metadata endpoint.
//
// Hostnames are compared case-insensitively with any trailing dot removed,
// because "LOCALHOST." names the same host as "localhost".
//
// NOTE(review): only the literal host is inspected. A public DNS name that
// resolves to a private address passes this check; catching that requires
// validating the resolved IP at dial time.
func isSafeURL(rawURL string) bool {
	u, err := url.Parse(rawURL)
	if err != nil {
		return false
	}
	// url.Parse lower-cases the scheme; this also rejects an empty scheme.
	if u.Scheme != "http" && u.Scheme != "https" {
		return false
	}

	host := strings.TrimSuffix(strings.ToLower(u.Hostname()), ".")
	if host == "" {
		return false
	}

	if host == "localhost" || strings.HasSuffix(host, ".localhost") {
		return false
	}

	// IP literals. Drop an IPv6 zone ("fe80::1%eth0"), which net.ParseIP
	// would otherwise reject, letting the literal slip through unchecked.
	ipText, _, _ := strings.Cut(host, "%")
	if ip := net.ParseIP(ipText); ip != nil && isPrivateIP(ip) {
		return false
	}

	// Internal-looking domain suffixes.
	for _, suffix := range []string{".internal", ".local", ".corp", ".lan", ".intranet"} {
		if strings.HasSuffix(host, suffix) {
			return false
		}
	}

	// Well-known cloud metadata services.
	switch host {
	case "metadata.google.internal", // GCP
		"169.254.169.254",         // AWS / Azure / GCP
		"metadata.azure.internal", // Azure
		"100.100.100.200":         // Alibaba Cloud
		return false
	}

	return true
}

// isPrivateIP reports whether ip lies in a range that webhook deliveries must
// not reach. The ranges are:
//   - RFC 1918 / RFC 4193 private space;
//   - loopback;
//   - link-local, which contains the 169.254.169.254 metadata address;
//   - the unspecified address (0.0.0.0 / ::);
//   - RFC 6598 shared address space 100.64.0.0/10.
//
// IPv4-mapped IPv6 addresses are classified by their IPv4 form. A nil or
// empty ip is not private.
func isPrivateIP(ip net.IP) bool {
	if len(ip) == 0 {
		return false
	}
	if ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast() ||
		ip.IsLinkLocalMulticast() || ip.IsUnspecified() {
		return true
	}
	// 100.64.0.0/10 (carrier-grade NAT); also hosts some cloud metadata
	// endpoints such as 100.100.100.200.
	if v4 := ip.To4(); v4 != nil && v4[0] == 100 && v4[1]&0xC0 == 0x40 {
		return true
	}
	return false
}
|
|||
|
|
|
|||
|
|
// computeHMAC returns the lowercase hex encoding of HMAC-SHA256 over payload,
// keyed with secret.
func computeHMAC(payload []byte, secret string) string {
	h := hmac.New(sha256.New, []byte(secret))
	_, _ = h.Write(payload) // hash.Hash.Write never returns an error
	digest := h.Sum(nil)
	return hex.EncodeToString(digest)
}
|
|||
|
|
|
|||
|
|
// generateEventID returns a random event identifier: "evt_" followed by 16
// lowercase hex characters (8 bytes from crypto/rand). It fails only if the
// system randomness source fails.
func generateEventID() (string, error) {
	var raw [8]byte
	if _, err := cryptorand.Read(raw[:]); err != nil {
		return "", fmt.Errorf("generate event ID failed: %w", err)
	}
	const prefix = "evt_"
	return prefix + hex.EncodeToString(raw[:]), nil
}
|
|||
|
|
|
|||
|
|
// generateWebhookSecret returns a random webhook signing secret: 24 bytes
// from crypto/rand, encoded as 48 lowercase hex characters. It fails only if
// the system randomness source fails.
func generateWebhookSecret() (string, error) {
	b := make([]byte, 24)
	if _, err := cryptorand.Read(b); err != nil {
		return "", fmt.Errorf("generate webhook secret failed: %w", err)
	}
	// hex.EncodeToString already emits lowercase digits, so no case
	// normalization is needed.
	return hex.EncodeToString(b), nil
}
|