package handler import ( "context" "crypto/subtle" "errors" "net/http" "os" "strings" "time" "github.com/gin-gonic/gin" apierrors "github.com/user-management-system/internal/pkg/errors" "github.com/user-management-system/internal/service" ) // newBackgroundCtx 创建用于后台 goroutine 的带超时独立 context(与请求 context 无关) func newBackgroundCtx(timeoutSec int) (context.Context, context.CancelFunc) { return context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second) } // AuthHandler handles authentication requests type AuthHandler struct { authService *service.AuthService } // NewAuthHandler creates a new AuthHandler func NewAuthHandler(authService *service.AuthService) *AuthHandler { return &AuthHandler{authService: authService} } func (h *AuthHandler) Register(c *gin.Context) { var req struct { Username string `json:"username" binding:"required"` Email string `json:"email"` Phone string `json:"phone"` Password string `json:"password" binding:"required"` Nickname string `json:"nickname"` } if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } registerReq := &service.RegisterRequest{ Username: req.Username, Email: req.Email, Phone: req.Phone, Password: req.Password, Nickname: req.Nickname, } userInfo, err := h.authService.Register(c.Request.Context(), registerReq) if err != nil { handleError(c, err) return } c.JSON(http.StatusCreated, userInfo) } func (h *AuthHandler) Login(c *gin.Context) { var req struct { Account string `json:"account"` Username string `json:"username"` Email string `json:"email"` Phone string `json:"phone"` Password string `json:"password"` DeviceID string `json:"device_id"` DeviceName string `json:"device_name"` DeviceBrowser string `json:"device_browser"` DeviceOS string `json:"device_os"` } if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } loginReq := &service.LoginRequest{ Account: req.Account, Username: req.Username, Email: req.Email, Phone: req.Phone, Password: req.Password, DeviceID: req.DeviceID, DeviceName: req.DeviceName, DeviceBrowser: req.DeviceBrowser, DeviceOS: req.DeviceOS, } clientIP := c.ClientIP() resp, err := h.authService.Login(c.Request.Context(), loginReq, clientIP) if err != nil { handleError(c, err) return } c.JSON(http.StatusOK, resp) } func (h *AuthHandler) Logout(c *gin.Context) { var req struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` } // 允许 body 为空(仅凭 Authorization header 里的 access_token 注销也可以) _ = c.ShouldBindJSON(&req) // 如果 body 里没有 access_token,则从 Authorization header 中取 if req.AccessToken == "" { if bearer := c.GetHeader("Authorization"); len(bearer) > 7 { req.AccessToken = bearer[7:] // 去掉 "Bearer " } } username, _ := c.Get("username") usernameStr, _ := username.(string) logoutReq := &service.LogoutRequest{ AccessToken: req.AccessToken, RefreshToken: req.RefreshToken, } _ = h.authService.Logout(c.Request.Context(), usernameStr, logoutReq) c.JSON(http.StatusOK, gin.H{"message": "logged out"}) } func (h *AuthHandler) RefreshToken(c *gin.Context) { var req struct { RefreshToken string `json:"refresh_token" binding:"required"` } if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } resp, err := h.authService.RefreshToken(c.Request.Context(), req.RefreshToken) if err != nil { handleError(c, err) return } c.JSON(http.StatusOK, resp) } func (h *AuthHandler) GetUserInfo(c *gin.Context) { userID, ok := getUserIDFromContext(c) if !ok { c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"}) return } userInfo, err := h.authService.GetUserInfo(c.Request.Context(), userID) if err != nil { handleError(c, err) return } c.JSON(http.StatusOK, userInfo) } func (h *AuthHandler) GetCSRFToken(c *gin.Context) { // 系统使用 JWT Bearer Token 认证,Bearer Token 不会被浏览器自动携带(非 cookie) // 因此不存在传统意义上的 CSRF 风险,此端点返回空 token 作为兼容响应 c.JSON(http.StatusOK, gin.H{ "csrf_token": "", "note": "JWT Bearer Token authentication; CSRF protection not required", }) } func (h *AuthHandler) GetAuthCapabilities(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ "register": true, "login": true, "oauth_login": false, "totp": true, }) } func (h *AuthHandler) OAuthLogin(c *gin.Context) { provider := c.Param("provider") c.JSON(http.StatusOK, gin.H{"provider": provider, "message": "OAuth not configured"}) } func (h *AuthHandler) OAuthCallback(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"error": "OAuth not configured"}) } func (h *AuthHandler) OAuthExchange(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"error": "OAuth not configured"}) } func (h *AuthHandler) GetEnabledOAuthProviders(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"providers": []string{}}) } func (h *AuthHandler) ActivateEmail(c *gin.Context) { token := c.Query("token") if token == "" { c.JSON(http.StatusBadRequest, gin.H{"error": "token is required"}) return } if err := h.authService.ActivateEmail(c.Request.Context(), token); err != nil { handleError(c, err) return } c.JSON(http.StatusOK, gin.H{"message": "email activated successfully"}) } func (h *AuthHandler) ResendActivationEmail(c *gin.Context) { var req struct { Email string `json:"email" binding:"required,email"` } if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } if err := h.authService.ResendActivationEmail(c.Request.Context(), req.Email); err != nil { handleError(c, err) return } // 防枚举:无论邮箱是否存在,统一返回成功 c.JSON(http.StatusOK, gin.H{"message": "activation email sent if address is registered"}) } func (h *AuthHandler) SendEmailCode(c *gin.Context) { var req struct { Email string `json:"email" binding:"required,email"` } if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } // SendEmailLoginCode 内部会忽略未注册邮箱(防枚举),始终返回 ok if err := h.authService.SendEmailLoginCode(c.Request.Context(), req.Email); err != nil { handleError(c, err) return } c.JSON(http.StatusOK, gin.H{"message": "验证码已发送"}) } func (h *AuthHandler) LoginByEmailCode(c *gin.Context) { var req struct { Email string `json:"email" binding:"required,email"` Code string `json:"code" binding:"required"` DeviceID string `json:"device_id"` DeviceName string `json:"device_name"` DeviceBrowser string `json:"device_browser"` DeviceOS string `json:"device_os"` } if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } clientIP := c.ClientIP() resp, err := h.authService.LoginByEmailCode(c.Request.Context(), req.Email, req.Code, clientIP) if err != nil { handleError(c, err) return } // 异步注册设备(不阻塞主流程) // 注意:必须用 context.WithTimeout(context.Background()) 而非 c.Request.Context() // gin 在 c.JSON 返回后会回收 context,goroutine 中引用会得到已取消的 context if req.DeviceID != "" && resp != nil && resp.User != nil { loginReq := &service.LoginRequest{ DeviceID: req.DeviceID, DeviceName: req.DeviceName, DeviceBrowser: req.DeviceBrowser, DeviceOS: req.DeviceOS, } userID := resp.User.ID go func() { devCtx, cancel := newBackgroundCtx(5) defer cancel() h.authService.BestEffortRegisterDevicePublic(devCtx, userID, loginReq) }() } c.JSON(http.StatusOK, resp) } func (h *AuthHandler) BootstrapAdmin(c *gin.Context) { // P0 修复:BootstrapAdmin 端点需要 bootstrap secret 验证 bootstrapSecret := os.Getenv("BOOTSTRAP_SECRET") if bootstrapSecret == "" { c.JSON(http.StatusForbidden, gin.H{"error": "引导初始化未授权"}) return } providedSecret := c.GetHeader("X-Bootstrap-Secret") if providedSecret == "" { c.JSON(http.StatusUnauthorized, gin.H{"error": "缺少引导密钥"}) return } // 使用恒定时间比较防止时序攻击 if subtle.ConstantTimeCompare([]byte(providedSecret), []byte(bootstrapSecret)) != 1 { c.JSON(http.StatusUnauthorized, gin.H{"error": "引导密钥无效"}) return } var req struct { Username string `json:"username" binding:"required"` Email string `json:"email" binding:"required"` Password string `json:"password" binding:"required"` } if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } bootstrapReq := &service.BootstrapAdminRequest{ Username: req.Username, Email: req.Email, Password: req.Password, } clientIP := c.ClientIP() resp, err := h.authService.BootstrapAdmin(c.Request.Context(), bootstrapReq, clientIP) if err != nil { handleError(c, err) return } c.JSON(http.StatusCreated, resp) } func (h *AuthHandler) SendEmailBindCode(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "email bind not configured"}) } func (h *AuthHandler) BindEmail(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "email bind not configured"}) } func (h *AuthHandler) UnbindEmail(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "email unbind not configured"}) } func (h *AuthHandler) SendPhoneBindCode(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "phone bind not configured"}) } func (h *AuthHandler) BindPhone(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "phone bind not configured"}) } func (h *AuthHandler) UnbindPhone(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "phone unbind not configured"}) } func (h *AuthHandler) GetSocialAccounts(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"accounts": []interface{}{}}) } func (h *AuthHandler) BindSocialAccount(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "social binding not configured"}) } func (h *AuthHandler) UnbindSocialAccount(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "social unbinding not configured"}) } func (h *AuthHandler) SupportsEmailCodeLogin() bool { return h.authService.HasEmailCodeService() } func getUserIDFromContext(c *gin.Context) (int64, bool) { userID, exists := c.Get("user_id") if !exists { return 0, false } id, ok := userID.(int64) return id, ok } // handleError 将 error 转换为对应的 HTTP 响应。 // 优先识别 ApplicationError,其次通过关键词推断业务错误类型,兜底返回 500。 func handleError(c *gin.Context, err error) { if err == nil { return } // 优先尝试 ApplicationError(内置 HTTP 状态码) var appErr *apierrors.ApplicationError if errors.As(err, &appErr) { c.JSON(int(appErr.Code), gin.H{"error": appErr.Message}) return } // 对普通 errors.New 按关键词推断语义,但只返回通用错误信息给客户端 msg := err.Error() code := classifyErrorMessage(msg) c.JSON(code, gin.H{"error": "服务器内部错误"}) } // classifyErrorMessage 通过错误信息关键词推断 HTTP 状态码,避免业务错误被 500 吞掉 func classifyErrorMessage(msg string) int { lower := strings.ToLower(msg) switch { case contains(lower, "not found", "不存在", "找不到"): return http.StatusNotFound case contains(lower, "already exists", "已存在", "已注册", "duplicate"): return http.StatusConflict case contains(lower, "unauthorized", "invalid token", "token", "令牌", "未认证"): return http.StatusUnauthorized case contains(lower, "forbidden", "permission", "权限", "禁止"): return http.StatusForbidden case contains(lower, "invalid", "required", "must", "cannot be empty", "不能为空", "格式", "参数", "密码不正确", "incorrect", "wrong", "too short", "too long", "已失效", "expired", "验证码不正确", "不能与"): return http.StatusBadRequest case contains(lower, "locked", "too many", "账号已被锁定", "rate limit"): return http.StatusTooManyRequests default: return http.StatusInternalServerError } } // contains 检查 s 是否包含 keywords 中的任意一个 func contains(s string, keywords ...string) bool { for _, kw := range keywords { if strings.Contains(s, kw) { return true } } return false }