chore: 删除未使用的孤立包
清理以下未导入的包: - internal/response (未使用的响应结构体) - pkg/response (未使用的响应封装) - internal/model (TLSFingerprintProfile, ErrorPassthroughRule) - internal/models (SocialAccount, domain已有) - internal/pkg/response (未使用的响应封装) - internal/security/ratelimit (已迁移到middleware) 验证: go build ./... && go test ./... 通过
This commit is contained in:
@@ -1,75 +0,0 @@
|
||||
// Package model 定义服务层使用的数据模型。
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
// ErrorPassthroughRule 全局错误透传规则
|
||||
// 用于控制上游错误如何返回给客户端
|
||||
type ErrorPassthroughRule struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"` // 规则名称
|
||||
Enabled bool `json:"enabled"` // 是否启用
|
||||
Priority int `json:"priority"` // 优先级(数字越小优先级越高)
|
||||
ErrorCodes []int `json:"error_codes"` // 匹配的错误码列表(OR关系)
|
||||
Keywords []string `json:"keywords"` // 匹配的关键词列表(OR关系)
|
||||
MatchMode string `json:"match_mode"` // "any"(任一条件) 或 "all"(所有条件)
|
||||
Platforms []string `json:"platforms"` // 适用平台列表
|
||||
PassthroughCode bool `json:"passthrough_code"` // 是否透传原始状态码
|
||||
ResponseCode *int `json:"response_code"` // 自定义状态码(passthrough_code=false 时使用)
|
||||
PassthroughBody bool `json:"passthrough_body"` // 是否透传原始错误信息
|
||||
CustomMessage *string `json:"custom_message"` // 自定义错误信息(passthrough_body=false 时使用)
|
||||
SkipMonitoring bool `json:"skip_monitoring"` // 是否跳过运维监控记录
|
||||
Description *string `json:"description"` // 规则描述
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// MatchModeAny 表示任一条件匹配即可
|
||||
const MatchModeAny = "any"
|
||||
|
||||
// MatchModeAll 表示所有条件都必须匹配
|
||||
const MatchModeAll = "all"
|
||||
|
||||
// 支持的平台常量
|
||||
const (
|
||||
PlatformAnthropic = "anthropic"
|
||||
PlatformOpenAI = "openai"
|
||||
PlatformGemini = "gemini"
|
||||
PlatformAntigravity = "antigravity"
|
||||
)
|
||||
|
||||
// AllPlatforms 返回所有支持的平台列表
|
||||
func AllPlatforms() []string {
|
||||
return []string{PlatformAnthropic, PlatformOpenAI, PlatformGemini, PlatformAntigravity}
|
||||
}
|
||||
|
||||
// Validate 验证规则配置的有效性
|
||||
func (r *ErrorPassthroughRule) Validate() error {
|
||||
if r.Name == "" {
|
||||
return &ValidationError{Field: "name", Message: "name is required"}
|
||||
}
|
||||
if r.MatchMode != MatchModeAny && r.MatchMode != MatchModeAll {
|
||||
return &ValidationError{Field: "match_mode", Message: "match_mode must be 'any' or 'all'"}
|
||||
}
|
||||
// 至少需要配置一个匹配条件(错误码或关键词)
|
||||
if len(r.ErrorCodes) == 0 && len(r.Keywords) == 0 {
|
||||
return &ValidationError{Field: "conditions", Message: "at least one error_code or keyword is required"}
|
||||
}
|
||||
if !r.PassthroughCode && (r.ResponseCode == nil || *r.ResponseCode <= 0) {
|
||||
return &ValidationError{Field: "response_code", Message: "response_code is required when passthrough_code is false"}
|
||||
}
|
||||
if !r.PassthroughBody && (r.CustomMessage == nil || *r.CustomMessage == "") {
|
||||
return &ValidationError{Field: "custom_message", Message: "custom_message is required when passthrough_body is false"}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidationError 表示验证错误
|
||||
type ValidationError struct {
|
||||
Field string
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *ValidationError) Error() string {
|
||||
return e.Field + ": " + e.Message
|
||||
}
|
||||
@@ -1,54 +0,0 @@
|
||||
// Package model 定义服务层使用的数据模型。
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/pkg/tlsfingerprint"
|
||||
)
|
||||
|
||||
// TLSFingerprintProfile TLS 指纹配置模板
|
||||
// 包含完整的 ClientHello 参数,用于模拟特定客户端的 TLS 握手特征
|
||||
type TLSFingerprintProfile struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description *string `json:"description"`
|
||||
EnableGREASE bool `json:"enable_grease"`
|
||||
CipherSuites []uint16 `json:"cipher_suites"`
|
||||
Curves []uint16 `json:"curves"`
|
||||
PointFormats []uint16 `json:"point_formats"`
|
||||
SignatureAlgorithms []uint16 `json:"signature_algorithms"`
|
||||
ALPNProtocols []string `json:"alpn_protocols"`
|
||||
SupportedVersions []uint16 `json:"supported_versions"`
|
||||
KeyShareGroups []uint16 `json:"key_share_groups"`
|
||||
PSKModes []uint16 `json:"psk_modes"`
|
||||
Extensions []uint16 `json:"extensions"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// Validate 验证模板配置的有效性
|
||||
func (p *TLSFingerprintProfile) Validate() error {
|
||||
if p.Name == "" {
|
||||
return &ValidationError{Field: "name", Message: "name is required"}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToTLSProfile 将领域模型转换为运行时使用的 tlsfingerprint.Profile
|
||||
// 空切片字段会在 dialer 中 fallback 到内置默认值
|
||||
func (p *TLSFingerprintProfile) ToTLSProfile() *tlsfingerprint.Profile {
|
||||
return &tlsfingerprint.Profile{
|
||||
Name: p.Name,
|
||||
EnableGREASE: p.EnableGREASE,
|
||||
CipherSuites: p.CipherSuites,
|
||||
Curves: p.Curves,
|
||||
PointFormats: p.PointFormats,
|
||||
SignatureAlgorithms: p.SignatureAlgorithms,
|
||||
ALPNProtocols: p.ALPNProtocols,
|
||||
SupportedVersions: p.SupportedVersions,
|
||||
KeyShareGroups: p.KeyShareGroups,
|
||||
PSKModes: p.PSKModes,
|
||||
Extensions: p.Extensions,
|
||||
}
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SocialAccount 社交账号绑定模型
|
||||
type SocialAccount struct {
|
||||
ID uint64 `json:"id" db:"id"`
|
||||
UserID uint64 `json:"user_id" db:"user_id"`
|
||||
Provider string `json:"provider" db:"provider"` // wechat, qq, weibo, google, facebook, twitter
|
||||
ProviderUserID string `json:"provider_user_id" db:"provider_user_id"`
|
||||
ProviderUsername string `json:"provider_username" db:"provider_username"`
|
||||
AccessToken string `json:"-" db:"access_token"` // 不返回给前端
|
||||
RefreshToken string `json:"-" db:"refresh_token"`
|
||||
ExpiresAt *time.Time `json:"expires_at" db:"expires_at"`
|
||||
RawData JSON `json:"-" db:"raw_data"`
|
||||
IsPrimary bool `json:"is_primary" db:"is_primary"`
|
||||
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
|
||||
}
|
||||
|
||||
// SocialAccountInfo 返回给前端的社交账号信息(不含敏感信息)
|
||||
type SocialAccountInfo struct {
|
||||
ID uint64 `json:"id"`
|
||||
Provider string `json:"provider"`
|
||||
ProviderUserID string `json:"provider_user_id"`
|
||||
ProviderUsername string `json:"provider_username"`
|
||||
IsPrimary bool `json:"is_primary"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// ToInfo 转换为安全信息
|
||||
func (sa *SocialAccount) ToInfo() *SocialAccountInfo {
|
||||
return &SocialAccountInfo{
|
||||
ID: sa.ID,
|
||||
Provider: sa.Provider,
|
||||
ProviderUserID: sa.ProviderUserID,
|
||||
ProviderUsername: sa.ProviderUsername,
|
||||
IsPrimary: sa.IsPrimary,
|
||||
CreatedAt: sa.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// JSON 自定义JSON类型,用于存储RawData
|
||||
type JSON struct {
|
||||
Data interface{}
|
||||
}
|
||||
|
||||
// Scan 实现 sql.Scanner 接口
|
||||
func (j *JSON) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
j.Data = nil
|
||||
return nil
|
||||
}
|
||||
bytes, ok := value.([]byte)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(bytes, &j.Data)
|
||||
}
|
||||
|
||||
// Value 实现 driver.Valuer 接口
|
||||
func (j JSON) Value() (interface{}, error) {
|
||||
if j.Data == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return json.Marshal(j.Data)
|
||||
}
|
||||
@@ -1,203 +0,0 @@
|
||||
// Package response provides standardized HTTP response helpers.
|
||||
package response
|
||||
|
||||
import (
|
||||
"log"
|
||||
"math"
|
||||
"net/http"
|
||||
|
||||
infraerrors "github.com/user-management-system/internal/pkg/errors"
|
||||
"github.com/user-management-system/internal/util/logredact"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Response 标准API响应格式
|
||||
type Response struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
Data any `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// PaginatedData 分页数据格式(匹配前端期望)
|
||||
type PaginatedData struct {
|
||||
Items any `json:"items"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
Pages int `json:"pages"`
|
||||
}
|
||||
|
||||
// Success 返回成功响应
|
||||
func Success(c *gin.Context, data any) {
|
||||
c.JSON(http.StatusOK, Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
// Created 返回创建成功响应
|
||||
func Created(c *gin.Context, data any) {
|
||||
c.JSON(http.StatusCreated, Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
// Accepted 返回异步接受响应 (HTTP 202)
|
||||
func Accepted(c *gin.Context, data any) {
|
||||
c.JSON(http.StatusAccepted, Response{
|
||||
Code: 0,
|
||||
Message: "accepted",
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
// Error 返回错误响应
|
||||
func Error(c *gin.Context, statusCode int, message string) {
|
||||
c.JSON(statusCode, Response{
|
||||
Code: statusCode,
|
||||
Message: message,
|
||||
Reason: "",
|
||||
Metadata: nil,
|
||||
})
|
||||
}
|
||||
|
||||
// ErrorWithDetails returns an error response compatible with the existing envelope while
|
||||
// optionally providing structured error fields (reason/metadata).
|
||||
func ErrorWithDetails(c *gin.Context, statusCode int, message, reason string, metadata map[string]string) {
|
||||
c.JSON(statusCode, Response{
|
||||
Code: statusCode,
|
||||
Message: message,
|
||||
Reason: reason,
|
||||
Metadata: metadata,
|
||||
})
|
||||
}
|
||||
|
||||
// ErrorFrom converts an ApplicationError (or any error) into the envelope-compatible error response.
|
||||
// It returns true if an error was written.
|
||||
func ErrorFrom(c *gin.Context, err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
statusCode, status := infraerrors.ToHTTP(err)
|
||||
|
||||
// Log internal errors with full details for debugging
|
||||
if statusCode >= 500 && c.Request != nil {
|
||||
log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, logredact.RedactText(err.Error()))
|
||||
}
|
||||
|
||||
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
|
||||
return true
|
||||
}
|
||||
|
||||
// BadRequest 返回400错误
|
||||
func BadRequest(c *gin.Context, message string) {
|
||||
Error(c, http.StatusBadRequest, message)
|
||||
}
|
||||
|
||||
// Unauthorized 返回401错误
|
||||
func Unauthorized(c *gin.Context, message string) {
|
||||
Error(c, http.StatusUnauthorized, message)
|
||||
}
|
||||
|
||||
// Forbidden 返回403错误
|
||||
func Forbidden(c *gin.Context, message string) {
|
||||
Error(c, http.StatusForbidden, message)
|
||||
}
|
||||
|
||||
// NotFound 返回404错误
|
||||
func NotFound(c *gin.Context, message string) {
|
||||
Error(c, http.StatusNotFound, message)
|
||||
}
|
||||
|
||||
// InternalError 返回500错误
|
||||
func InternalError(c *gin.Context, message string) {
|
||||
Error(c, http.StatusInternalServerError, message)
|
||||
}
|
||||
|
||||
// Paginated 返回分页数据
|
||||
func Paginated(c *gin.Context, items any, total int64, page, pageSize int) {
|
||||
pages := int(math.Ceil(float64(total) / float64(pageSize)))
|
||||
if pages < 1 {
|
||||
pages = 1
|
||||
}
|
||||
|
||||
Success(c, PaginatedData{
|
||||
Items: items,
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Pages: pages,
|
||||
})
|
||||
}
|
||||
|
||||
// PaginationResult 分页结果(与pagination.PaginationResult兼容)
|
||||
type PaginationResult struct {
|
||||
Total int64
|
||||
Page int
|
||||
PageSize int
|
||||
Pages int
|
||||
}
|
||||
|
||||
// PaginatedWithResult 使用PaginationResult返回分页数据
|
||||
func PaginatedWithResult(c *gin.Context, items any, pagination *PaginationResult) {
|
||||
if pagination == nil {
|
||||
Success(c, PaginatedData{
|
||||
Items: items,
|
||||
Total: 0,
|
||||
Page: 1,
|
||||
PageSize: 20,
|
||||
Pages: 1,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
Success(c, PaginatedData{
|
||||
Items: items,
|
||||
Total: pagination.Total,
|
||||
Page: pagination.Page,
|
||||
PageSize: pagination.PageSize,
|
||||
Pages: pagination.Pages,
|
||||
})
|
||||
}
|
||||
|
||||
// ParsePagination 解析分页参数
|
||||
func ParsePagination(c *gin.Context) (page, pageSize int) {
|
||||
page = 1
|
||||
pageSize = 20
|
||||
|
||||
if p := c.Query("page"); p != "" {
|
||||
if val, err := parseInt(p); err == nil && val > 0 {
|
||||
page = val
|
||||
}
|
||||
}
|
||||
|
||||
// 支持 page_size 和 limit 两种参数名
|
||||
if ps := c.Query("page_size"); ps != "" {
|
||||
if val, err := parseInt(ps); err == nil && val > 0 && val <= 1000 {
|
||||
pageSize = val
|
||||
}
|
||||
} else if l := c.Query("limit"); l != "" {
|
||||
if val, err := parseInt(l); err == nil && val > 0 && val <= 1000 {
|
||||
pageSize = val
|
||||
}
|
||||
}
|
||||
|
||||
return page, pageSize
|
||||
}
|
||||
|
||||
func parseInt(s string) (int, error) {
|
||||
var result int
|
||||
for _, c := range s {
|
||||
if c < '0' || c > '9' {
|
||||
return 0, nil
|
||||
}
|
||||
result = result*10 + int(c-'0')
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
@@ -1,788 +0,0 @@
|
||||
//go:build unit
|
||||
|
||||
package response
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
errors2 "github.com/user-management-system/internal/pkg/errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------- 辅助函数 ----------
|
||||
|
||||
// parseResponseBody 从 httptest.ResponseRecorder 中解析 JSON 响应体
|
||||
func parseResponseBody(t *testing.T, w *httptest.ResponseRecorder) Response {
|
||||
t.Helper()
|
||||
var got Response
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
|
||||
return got
|
||||
}
|
||||
|
||||
// parsePaginatedBody 从响应体中解析分页数据(Data 字段是 PaginatedData)
|
||||
func parsePaginatedBody(t *testing.T, w *httptest.ResponseRecorder) (Response, PaginatedData) {
|
||||
t.Helper()
|
||||
// 先用 raw json 解析,因为 Data 是 any 类型
|
||||
var raw struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &raw))
|
||||
|
||||
var pd PaginatedData
|
||||
require.NoError(t, json.Unmarshal(raw.Data, &pd))
|
||||
|
||||
return Response{Code: raw.Code, Message: raw.Message, Reason: raw.Reason}, pd
|
||||
}
|
||||
|
||||
// newContextWithQuery 创建一个带有 URL query 参数的 gin.Context 用于测试 ParsePagination
|
||||
func newContextWithQuery(query string) (*httptest.ResponseRecorder, *gin.Context) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/?"+query, nil)
|
||||
return w, c
|
||||
}
|
||||
|
||||
// ---------- 现有测试 ----------
|
||||
|
||||
func TestErrorWithDetails(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
message string
|
||||
reason string
|
||||
metadata map[string]string
|
||||
want Response
|
||||
}{
|
||||
{
|
||||
name: "plain_error",
|
||||
statusCode: http.StatusBadRequest,
|
||||
message: "invalid request",
|
||||
want: Response{
|
||||
Code: http.StatusBadRequest,
|
||||
Message: "invalid request",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "structured_error",
|
||||
statusCode: http.StatusForbidden,
|
||||
message: "no access",
|
||||
reason: "FORBIDDEN",
|
||||
metadata: map[string]string{"k": "v"},
|
||||
want: Response{
|
||||
Code: http.StatusForbidden,
|
||||
Message: "no access",
|
||||
Reason: "FORBIDDEN",
|
||||
Metadata: map[string]string{"k": "v"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
ErrorWithDetails(c, tt.statusCode, tt.message, tt.reason, tt.metadata)
|
||||
|
||||
require.Equal(t, tt.statusCode, w.Code)
|
||||
|
||||
var got Response
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorFrom(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantWritten bool
|
||||
wantHTTPCode int
|
||||
wantBody Response
|
||||
}{
|
||||
{
|
||||
name: "nil_error",
|
||||
err: nil,
|
||||
wantWritten: false,
|
||||
},
|
||||
{
|
||||
name: "application_error",
|
||||
err: errors2.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusForbidden,
|
||||
wantBody: Response{
|
||||
Code: http.StatusForbidden,
|
||||
Message: "no access",
|
||||
Reason: "FORBIDDEN",
|
||||
Metadata: map[string]string{"scope": "admin"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bad_request_error",
|
||||
err: errors2.BadRequest("INVALID_REQUEST", "invalid request"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusBadRequest,
|
||||
wantBody: Response{
|
||||
Code: http.StatusBadRequest,
|
||||
Message: "invalid request",
|
||||
Reason: "INVALID_REQUEST",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unauthorized_error",
|
||||
err: errors2.Unauthorized("UNAUTHORIZED", "unauthorized"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusUnauthorized,
|
||||
wantBody: Response{
|
||||
Code: http.StatusUnauthorized,
|
||||
Message: "unauthorized",
|
||||
Reason: "UNAUTHORIZED",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "not_found_error",
|
||||
err: errors2.NotFound("NOT_FOUND", "not found"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusNotFound,
|
||||
wantBody: Response{
|
||||
Code: http.StatusNotFound,
|
||||
Message: "not found",
|
||||
Reason: "NOT_FOUND",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "conflict_error",
|
||||
err: errors2.Conflict("CONFLICT", "conflict"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusConflict,
|
||||
wantBody: Response{
|
||||
Code: http.StatusConflict,
|
||||
Message: "conflict",
|
||||
Reason: "CONFLICT",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unknown_error_defaults_to_500",
|
||||
err: errors.New("boom"),
|
||||
wantWritten: true,
|
||||
wantHTTPCode: http.StatusInternalServerError,
|
||||
wantBody: Response{
|
||||
Code: http.StatusInternalServerError,
|
||||
Message: errors2.UnknownMessage,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
written := ErrorFrom(c, tt.err)
|
||||
require.Equal(t, tt.wantWritten, written)
|
||||
|
||||
if !tt.wantWritten {
|
||||
require.Equal(t, 200, w.Code)
|
||||
require.Empty(t, w.Body.String())
|
||||
return
|
||||
}
|
||||
|
||||
require.Equal(t, tt.wantHTTPCode, w.Code)
|
||||
var got Response
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
|
||||
require.Equal(t, tt.wantBody, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- 新增测试 ----------
|
||||
|
||||
func TestSuccess(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data any
|
||||
wantCode int
|
||||
wantBody Response
|
||||
}{
|
||||
{
|
||||
name: "返回字符串数据",
|
||||
data: "hello",
|
||||
wantCode: http.StatusOK,
|
||||
wantBody: Response{Code: 0, Message: "success", Data: "hello"},
|
||||
},
|
||||
{
|
||||
name: "返回nil数据",
|
||||
data: nil,
|
||||
wantCode: http.StatusOK,
|
||||
wantBody: Response{Code: 0, Message: "success"},
|
||||
},
|
||||
{
|
||||
name: "返回map数据",
|
||||
data: map[string]string{"key": "value"},
|
||||
wantCode: http.StatusOK,
|
||||
wantBody: Response{Code: 0, Message: "success"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
Success(c, tt.data)
|
||||
|
||||
require.Equal(t, tt.wantCode, w.Code)
|
||||
|
||||
// 只验证 code 和 message,data 字段类型在 JSON 反序列化时会变成 map/slice
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, 0, got.Code)
|
||||
require.Equal(t, "success", got.Message)
|
||||
|
||||
if tt.data == nil {
|
||||
require.Nil(t, got.Data)
|
||||
} else {
|
||||
require.NotNil(t, got.Data)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreated(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data any
|
||||
wantCode int
|
||||
}{
|
||||
{
|
||||
name: "创建成功_返回数据",
|
||||
data: map[string]int{"id": 42},
|
||||
wantCode: http.StatusCreated,
|
||||
},
|
||||
{
|
||||
name: "创建成功_nil数据",
|
||||
data: nil,
|
||||
wantCode: http.StatusCreated,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
Created(c, tt.data)
|
||||
|
||||
require.Equal(t, tt.wantCode, w.Code)
|
||||
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, 0, got.Code)
|
||||
require.Equal(t, "success", got.Message)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
message string
|
||||
}{
|
||||
{
|
||||
name: "400错误",
|
||||
statusCode: http.StatusBadRequest,
|
||||
message: "bad request",
|
||||
},
|
||||
{
|
||||
name: "500错误",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
message: "internal error",
|
||||
},
|
||||
{
|
||||
name: "自定义状态码",
|
||||
statusCode: 418,
|
||||
message: "I'm a teapot",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
Error(c, tt.statusCode, tt.message)
|
||||
|
||||
require.Equal(t, tt.statusCode, w.Code)
|
||||
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, tt.statusCode, got.Code)
|
||||
require.Equal(t, tt.message, got.Message)
|
||||
require.Empty(t, got.Reason)
|
||||
require.Nil(t, got.Metadata)
|
||||
require.Nil(t, got.Data)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBadRequest(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
BadRequest(c, "参数无效")
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, w.Code)
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, http.StatusBadRequest, got.Code)
|
||||
require.Equal(t, "参数无效", got.Message)
|
||||
}
|
||||
|
||||
func TestUnauthorized(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
Unauthorized(c, "未登录")
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, http.StatusUnauthorized, got.Code)
|
||||
require.Equal(t, "未登录", got.Message)
|
||||
}
|
||||
|
||||
func TestForbidden(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
Forbidden(c, "无权限")
|
||||
|
||||
require.Equal(t, http.StatusForbidden, w.Code)
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, http.StatusForbidden, got.Code)
|
||||
require.Equal(t, "无权限", got.Message)
|
||||
}
|
||||
|
||||
func TestNotFound(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
NotFound(c, "资源不存在")
|
||||
|
||||
require.Equal(t, http.StatusNotFound, w.Code)
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, http.StatusNotFound, got.Code)
|
||||
require.Equal(t, "资源不存在", got.Message)
|
||||
}
|
||||
|
||||
func TestInternalError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
InternalError(c, "服务器内部错误")
|
||||
|
||||
require.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, http.StatusInternalServerError, got.Code)
|
||||
require.Equal(t, "服务器内部错误", got.Message)
|
||||
}
|
||||
|
||||
func TestPaginated(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
items any
|
||||
total int64
|
||||
page int
|
||||
pageSize int
|
||||
wantPages int
|
||||
wantTotal int64
|
||||
wantPage int
|
||||
wantPageSize int
|
||||
}{
|
||||
{
|
||||
name: "标准分页_多页",
|
||||
items: []string{"a", "b"},
|
||||
total: 25,
|
||||
page: 1,
|
||||
pageSize: 10,
|
||||
wantPages: 3,
|
||||
wantTotal: 25,
|
||||
wantPage: 1,
|
||||
wantPageSize: 10,
|
||||
},
|
||||
{
|
||||
name: "总数刚好整除",
|
||||
items: []string{"a"},
|
||||
total: 20,
|
||||
page: 2,
|
||||
pageSize: 10,
|
||||
wantPages: 2,
|
||||
wantTotal: 20,
|
||||
wantPage: 2,
|
||||
wantPageSize: 10,
|
||||
},
|
||||
{
|
||||
name: "总数为0_pages至少为1",
|
||||
items: []string{},
|
||||
total: 0,
|
||||
page: 1,
|
||||
pageSize: 10,
|
||||
wantPages: 1,
|
||||
wantTotal: 0,
|
||||
wantPage: 1,
|
||||
wantPageSize: 10,
|
||||
},
|
||||
{
|
||||
name: "单页数据",
|
||||
items: []int{1, 2, 3},
|
||||
total: 3,
|
||||
page: 1,
|
||||
pageSize: 20,
|
||||
wantPages: 1,
|
||||
wantTotal: 3,
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "总数为1",
|
||||
items: []string{"only"},
|
||||
total: 1,
|
||||
page: 1,
|
||||
pageSize: 10,
|
||||
wantPages: 1,
|
||||
wantTotal: 1,
|
||||
wantPage: 1,
|
||||
wantPageSize: 10,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
Paginated(c, tt.items, tt.total, tt.page, tt.pageSize)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
resp, pd := parsePaginatedBody(t, w)
|
||||
require.Equal(t, 0, resp.Code)
|
||||
require.Equal(t, "success", resp.Message)
|
||||
require.Equal(t, tt.wantTotal, pd.Total)
|
||||
require.Equal(t, tt.wantPage, pd.Page)
|
||||
require.Equal(t, tt.wantPageSize, pd.PageSize)
|
||||
require.Equal(t, tt.wantPages, pd.Pages)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPaginatedWithResult(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
items any
|
||||
pagination *PaginationResult
|
||||
wantTotal int64
|
||||
wantPage int
|
||||
wantPageSize int
|
||||
wantPages int
|
||||
}{
|
||||
{
|
||||
name: "正常分页结果",
|
||||
items: []string{"a", "b"},
|
||||
pagination: &PaginationResult{
|
||||
Total: 50,
|
||||
Page: 3,
|
||||
PageSize: 10,
|
||||
Pages: 5,
|
||||
},
|
||||
wantTotal: 50,
|
||||
wantPage: 3,
|
||||
wantPageSize: 10,
|
||||
wantPages: 5,
|
||||
},
|
||||
{
|
||||
name: "pagination为nil_使用默认值",
|
||||
items: []string{},
|
||||
pagination: nil,
|
||||
wantTotal: 0,
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
wantPages: 1,
|
||||
},
|
||||
{
|
||||
name: "单页结果",
|
||||
items: []int{1},
|
||||
pagination: &PaginationResult{
|
||||
Total: 1,
|
||||
Page: 1,
|
||||
PageSize: 20,
|
||||
Pages: 1,
|
||||
},
|
||||
wantTotal: 1,
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
wantPages: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
PaginatedWithResult(c, tt.items, tt.pagination)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
resp, pd := parsePaginatedBody(t, w)
|
||||
require.Equal(t, 0, resp.Code)
|
||||
require.Equal(t, "success", resp.Message)
|
||||
require.Equal(t, tt.wantTotal, pd.Total)
|
||||
require.Equal(t, tt.wantPage, pd.Page)
|
||||
require.Equal(t, tt.wantPageSize, pd.PageSize)
|
||||
require.Equal(t, tt.wantPages, pd.Pages)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePagination(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
wantPage int
|
||||
wantPageSize int
|
||||
}{
|
||||
{
|
||||
name: "无参数_使用默认值",
|
||||
query: "",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "仅指定page",
|
||||
query: "page=3",
|
||||
wantPage: 3,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "仅指定page_size",
|
||||
query: "page_size=50",
|
||||
wantPage: 1,
|
||||
wantPageSize: 50,
|
||||
},
|
||||
{
|
||||
name: "同时指定page和page_size",
|
||||
query: "page=2&page_size=30",
|
||||
wantPage: 2,
|
||||
wantPageSize: 30,
|
||||
},
|
||||
{
|
||||
name: "使用limit代替page_size",
|
||||
query: "limit=15",
|
||||
wantPage: 1,
|
||||
wantPageSize: 15,
|
||||
},
|
||||
{
|
||||
name: "page_size优先于limit",
|
||||
query: "page_size=25&limit=50",
|
||||
wantPage: 1,
|
||||
wantPageSize: 25,
|
||||
},
|
||||
{
|
||||
name: "page为0_使用默认值",
|
||||
query: "page=0",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "page_size超过1000_使用默认值",
|
||||
query: "page_size=1001",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "page_size恰好1000_有效",
|
||||
query: "page_size=1000",
|
||||
wantPage: 1,
|
||||
wantPageSize: 1000,
|
||||
},
|
||||
{
|
||||
name: "page为非数字_使用默认值",
|
||||
query: "page=abc",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "page_size为非数字_使用默认值",
|
||||
query: "page_size=xyz",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "limit为非数字_使用默认值",
|
||||
query: "limit=abc",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "page_size为0_使用默认值",
|
||||
query: "page_size=0",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "limit为0_使用默认值",
|
||||
query: "limit=0",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "大页码",
|
||||
query: "page=999&page_size=100",
|
||||
wantPage: 999,
|
||||
wantPageSize: 100,
|
||||
},
|
||||
{
|
||||
name: "page_size为1_最小有效值",
|
||||
query: "page_size=1",
|
||||
wantPage: 1,
|
||||
wantPageSize: 1,
|
||||
},
|
||||
{
|
||||
name: "混合数字和字母的page",
|
||||
query: "page=12a",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "limit超过1000_使用默认值",
|
||||
query: "limit=2000",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, c := newContextWithQuery(tt.query)
|
||||
|
||||
page, pageSize := ParsePagination(c)
|
||||
|
||||
require.Equal(t, tt.wantPage, page, "page 不符合预期")
|
||||
require.Equal(t, tt.wantPageSize, pageSize, "pageSize 不符合预期")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_parseInt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantVal int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "正常数字",
|
||||
input: "123",
|
||||
wantVal: 123,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "零",
|
||||
input: "0",
|
||||
wantVal: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "单个数字",
|
||||
input: "5",
|
||||
wantVal: 5,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "大数字",
|
||||
input: "99999",
|
||||
wantVal: 99999,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "包含字母_返回0",
|
||||
input: "abc",
|
||||
wantVal: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "数字开头接字母_返回0",
|
||||
input: "12a",
|
||||
wantVal: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "包含负号_返回0",
|
||||
input: "-1",
|
||||
wantVal: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "包含小数点_返回0",
|
||||
input: "1.5",
|
||||
wantVal: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "包含空格_返回0",
|
||||
input: "1 2",
|
||||
wantVal: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "空字符串",
|
||||
input: "",
|
||||
wantVal: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
val, err := parseInt(tt.input)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
require.Equal(t, tt.wantVal, val)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
package response
|
||||
|
||||
// Response 统一响应结构
|
||||
type Response struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// Success 成功响应
|
||||
func Success(data interface{}) *Response {
|
||||
return &Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
// Error 错误响应
|
||||
func Error(message string) *Response {
|
||||
return &Response{
|
||||
Code: -1,
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorWithCode 带错误码的错误响应
|
||||
func ErrorWithCode(code int, message string) *Response {
|
||||
return &Response{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
// WithData 带扩展数据的成功响应
|
||||
func WithData(data interface{}, extra map[string]interface{}) *Response {
|
||||
payload, ok := data.(map[string]interface{})
|
||||
if !ok {
|
||||
payload = map[string]interface{}{
|
||||
"items": data,
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range extra {
|
||||
payload[k] = v
|
||||
}
|
||||
|
||||
resp := Success(payload)
|
||||
return resp
|
||||
}
|
||||
@@ -1,34 +0,0 @@
|
||||
package response
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestWithDataWrapsSlicesAndMergesExtra(t *testing.T) {
|
||||
resp := WithData([]string{"a", "b"}, map[string]interface{}{
|
||||
"total": 2,
|
||||
"page": 1,
|
||||
})
|
||||
|
||||
data, ok := resp.Data.(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected map payload, got %T", resp.Data)
|
||||
}
|
||||
if data["total"] != 2 {
|
||||
t.Fatalf("expected total=2, got %v", data["total"])
|
||||
}
|
||||
items, ok := data["items"].([]string)
|
||||
if !ok || len(items) != 2 {
|
||||
t.Fatalf("expected items slice to be preserved, got %#v", data["items"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithDataPreservesMapPayload(t *testing.T) {
|
||||
resp := WithData(map[string]interface{}{"user": "alice"}, map[string]interface{}{"page": 1})
|
||||
|
||||
data := resp.Data.(map[string]interface{})
|
||||
if data["user"] != "alice" {
|
||||
t.Fatalf("expected user=alice, got %v", data["user"])
|
||||
}
|
||||
if data["page"] != 1 {
|
||||
t.Fatalf("expected page=1, got %v", data["page"])
|
||||
}
|
||||
}
|
||||
@@ -1,184 +0,0 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RateLimitAlgorithm 限流算法类型
|
||||
type RateLimitAlgorithm string
|
||||
|
||||
const (
|
||||
AlgorithmTokenBucket RateLimitAlgorithm = "token_bucket"
|
||||
AlgorithmLeakyBucket RateLimitAlgorithm = "leaky_bucket"
|
||||
AlgorithmSlidingWindow RateLimitAlgorithm = "sliding_window"
|
||||
AlgorithmFixedWindow RateLimitAlgorithm = "fixed_window"
|
||||
)
|
||||
|
||||
// TokenBucket 令牌桶算法
|
||||
type TokenBucket struct {
|
||||
capacity int64
|
||||
tokens int64
|
||||
rate int64 // 每秒产生的令牌数
|
||||
lastRefill time.Time
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewTokenBucket 创建令牌桶
|
||||
func NewTokenBucket(capacity, rate int64) *TokenBucket {
|
||||
return &TokenBucket{
|
||||
capacity: capacity,
|
||||
tokens: capacity,
|
||||
rate: rate,
|
||||
lastRefill: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Allow 检查是否允许访问
|
||||
func (tb *TokenBucket) Allow() bool {
|
||||
tb.mu.Lock()
|
||||
defer tb.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(tb.lastRefill).Seconds()
|
||||
|
||||
// 计算需要补充的令牌数
|
||||
refillTokens := int64(elapsed * float64(tb.rate))
|
||||
tb.tokens += refillTokens
|
||||
if tb.tokens > tb.capacity {
|
||||
tb.tokens = tb.capacity
|
||||
}
|
||||
tb.lastRefill = now
|
||||
|
||||
// 检查是否有足够的令牌
|
||||
if tb.tokens > 0 {
|
||||
tb.tokens--
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// LeakyBucket 漏桶算法
|
||||
type LeakyBucket struct {
|
||||
capacity int64
|
||||
water int64
|
||||
rate int64 // 每秒漏出的水量
|
||||
lastLeak time.Time
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewLeakyBucket 创建漏桶
|
||||
func NewLeakyBucket(capacity, rate int64) *LeakyBucket {
|
||||
return &LeakyBucket{
|
||||
capacity: capacity,
|
||||
water: 0,
|
||||
rate: rate,
|
||||
lastLeak: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Allow 检查是否允许访问
|
||||
func (lb *LeakyBucket) Allow() bool {
|
||||
lb.mu.Lock()
|
||||
defer lb.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(lb.lastLeak).Seconds()
|
||||
|
||||
// 计算漏出的水量
|
||||
leakWater := int64(elapsed * float64(lb.rate))
|
||||
lb.water -= leakWater
|
||||
if lb.water < 0 {
|
||||
lb.water = 0
|
||||
}
|
||||
lb.lastLeak = now
|
||||
|
||||
// 检查桶是否已满
|
||||
if lb.water < lb.capacity {
|
||||
lb.water++
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// SlidingWindow 滑动窗口算法
|
||||
type SlidingWindow struct {
|
||||
window time.Duration
|
||||
capacity int64
|
||||
requests []time.Time
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewSlidingWindow 创建滑动窗口
|
||||
func NewSlidingWindow(window time.Duration, capacity int64) *SlidingWindow {
|
||||
return &SlidingWindow{
|
||||
window: window,
|
||||
capacity: capacity,
|
||||
requests: make([]time.Time, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// Allow 检查是否允许访问
|
||||
func (sw *SlidingWindow) Allow() bool {
|
||||
sw.mu.Lock()
|
||||
defer sw.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// 移除窗口外的请求
|
||||
validRequests := make([]time.Time, 0)
|
||||
for _, req := range sw.requests {
|
||||
if now.Sub(req) < sw.window {
|
||||
validRequests = append(validRequests, req)
|
||||
}
|
||||
}
|
||||
sw.requests = validRequests
|
||||
|
||||
// 检查是否超过容量
|
||||
if int64(len(sw.requests)) < sw.capacity {
|
||||
sw.requests = append(sw.requests, now)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// RateLimiter 限流器
|
||||
type RateLimiter struct {
|
||||
algorithm RateLimitAlgorithm
|
||||
limiter interface{}
|
||||
}
|
||||
|
||||
// NewRateLimiter 创建限流器
|
||||
func NewRateLimiter(algorithm RateLimitAlgorithm, capacity, rate int64, window time.Duration) *RateLimiter {
|
||||
limiter := &RateLimiter{algorithm: algorithm}
|
||||
|
||||
switch algorithm {
|
||||
case AlgorithmTokenBucket:
|
||||
limiter.limiter = NewTokenBucket(capacity, rate)
|
||||
case AlgorithmLeakyBucket:
|
||||
limiter.limiter = NewLeakyBucket(capacity, rate)
|
||||
case AlgorithmSlidingWindow:
|
||||
limiter.limiter = NewSlidingWindow(window, capacity)
|
||||
default:
|
||||
limiter.limiter = NewSlidingWindow(window, capacity)
|
||||
}
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// Allow 检查是否允许访问
|
||||
func (rl *RateLimiter) Allow() bool {
|
||||
switch rl.algorithm {
|
||||
case AlgorithmTokenBucket:
|
||||
return rl.limiter.(*TokenBucket).Allow()
|
||||
case AlgorithmLeakyBucket:
|
||||
return rl.limiter.(*LeakyBucket).Allow()
|
||||
case AlgorithmSlidingWindow:
|
||||
return rl.limiter.(*SlidingWindow).Allow()
|
||||
default:
|
||||
return rl.limiter.(*SlidingWindow).Allow()
|
||||
}
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
package response
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Response 统一响应结构
|
||||
type Response struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// Success 成功响应
|
||||
func Success(c *gin.Context, data interface{}) {
|
||||
c.JSON(http.StatusOK, Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
// Error 错误响应
|
||||
func Error(c *gin.Context, httpStatus int, message string, err error) {
|
||||
if err != nil {
|
||||
// 在开发环境下返回详细错误信息
|
||||
if gin.Mode() == gin.DebugMode {
|
||||
c.JSON(httpStatus, Response{
|
||||
Code: httpStatus,
|
||||
Message: message,
|
||||
Data: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
c.JSON(httpStatus, Response{
|
||||
Code: httpStatus,
|
||||
Message: message,
|
||||
})
|
||||
}
|
||||
|
||||
// ErrorWithCode 错误响应(带自定义错误码)
|
||||
func ErrorWithCode(c *gin.Context, code int, message string) {
|
||||
c.JSON(http.StatusOK, Response{
|
||||
Code: code,
|
||||
Message: message,
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user