Files
user-system/internal/api/handler/auth_handler_unit_test.go

298 lines
8.3 KiB
Go
Raw Normal View History

package handler
import (
"bytes"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/gin-gonic/gin"
)
// TestAuthHandler_SupportFlags checks the password-reset capability flag:
// a nil *AuthHandler and a zero-value handler both report it as unsupported,
// and SetPasswordResetEnabled(true) turns it on.
func TestAuthHandler_SupportFlags(t *testing.T) {
	var missing *AuthHandler
	if missing.SupportsPasswordReset() {
		t.Fatal("nil handler should not support password reset")
	}

	h := &AuthHandler{}
	if h.SupportsPasswordReset() {
		t.Fatal("password reset should be disabled by default")
	}

	h.SetPasswordResetEnabled(true)
	if enabled := h.SupportsPasswordReset(); !enabled {
		t.Fatal("password reset flag should be enabled")
	}
}
// TestGetUserIDFromContext checks that getUserIDFromContext reports failure
// when "user_id" is absent or stored as a non-int64 value, and returns the
// stored id when it is an int64.
func TestGetUserIDFromContext(t *testing.T) {
	gin.SetMode(gin.TestMode)
	ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
	ctx.Request = httptest.NewRequest(http.MethodGet, "/userinfo", nil)

	// No "user_id" key has been set yet.
	if _, found := getUserIDFromContext(ctx); found {
		t.Fatal("expected missing user_id to return false")
	}

	// A string-typed id must be rejected rather than converted.
	ctx.Set("user_id", "1")
	if _, found := getUserIDFromContext(ctx); found {
		t.Fatal("expected non-int64 user_id to return false")
	}

	ctx.Set("user_id", int64(42))
	id, found := getUserIDFromContext(ctx)
	if !found || id != 42 {
		t.Fatalf("getUserIDFromContext() = (%d, %v), want (42, true)", id, found)
	}
}
// TestRequestUsesHTTPS checks HTTPS detection for a nil context, a plain HTTP
// request, and a request carrying the X-Forwarded-Proto: https header.
func TestRequestUsesHTTPS(t *testing.T) {
	gin.SetMode(gin.TestMode)

	if requestUsesHTTPS(nil) {
		t.Fatal("nil context should not use https")
	}

	ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
	req := httptest.NewRequest(http.MethodGet, "/auth", nil)
	ctx.Request = req

	if requestUsesHTTPS(ctx) {
		t.Fatal("plain http request should not use https")
	}

	// This is the header a TLS-terminating proxy typically forwards.
	req.Header.Set("X-Forwarded-Proto", "https")
	if !requestUsesHTTPS(ctx) {
		t.Fatal("forwarded https request should be detected")
	}
}
// TestSessionCookies_SetAndClear checks the session-cookie helpers:
//   - setSessionCookies with an empty refresh token writes no Set-Cookie headers;
//   - with a non-empty token it writes at least two cookies, one of which
//     carries the refresh token under refreshTokenCookieName;
//   - clearSessionCookies writes at least two cookies, including one for
//     refreshTokenCookieName.
//
// Every Set-Cookie value is searched rather than only the first two, so the
// test does not depend on the order or number of cookies the helpers emit.
func TestSessionCookies_SetAndClear(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// hasCookie reports whether any Set-Cookie value contains fragment.
	hasCookie := func(values []string, fragment string) bool {
		for _, v := range values {
			if strings.Contains(v, fragment) {
				return true
			}
		}
		return false
	}

	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodGet, "/auth", nil)

	setSessionCookies(c, nil, "")
	if len(recorder.Header().Values("Set-Cookie")) != 0 {
		t.Fatal("empty refresh token should not set cookies")
	}

	setSessionCookies(c, nil, "refresh-token")
	setCookies := recorder.Header().Values("Set-Cookie")
	if len(setCookies) < 2 {
		t.Fatalf("expected session cookies to be set, got %d", len(setCookies))
	}
	if !hasCookie(setCookies, refreshTokenCookieName+"=refresh-token") {
		t.Fatalf("expected refresh token cookie, got %#v", setCookies)
	}

	// Use a fresh recorder so only the headers written by clearSessionCookies
	// are inspected.
	recorder = httptest.NewRecorder()
	c, _ = gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodGet, "/auth", nil)

	clearSessionCookies(c)
	setCookies = recorder.Header().Values("Set-Cookie")
	if len(setCookies) < 2 {
		t.Fatalf("expected clearing cookies to emit expired cookies, got %d", len(setCookies))
	}
	if !hasCookie(setCookies, refreshTokenCookieName+"=") {
		t.Fatalf("expected clearing to emit a refresh token cookie, got %#v", setCookies)
	}
}
// TestClassifyErrorMessage checks the HTTP status classifyErrorMessage assigns
// to representative error messages. The messages include a Chinese one
// ("验证码错误", i.e. "verification code error") and an unrecognized message
// that should fall back to 500.
func TestClassifyErrorMessage(t *testing.T) {
	tests := []struct {
		name   string
		input  string
		status int
	}{
		{"not found", "user not found", http.StatusNotFound},
		{"duplicate", "already exists", http.StatusConflict},
		{"verification code", "验证码错误", http.StatusUnauthorized},
		{"unauthorized", "invalid token", http.StatusUnauthorized},
		{"forbidden", "permission denied", http.StatusForbidden},
		{"bad request", "invalid payload", http.StatusBadRequest},
		{"rate limit", "too many attempts", http.StatusTooManyRequests},
		{"fallback", "unexpected boom", http.StatusInternalServerError},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := classifyErrorMessage(tt.input)
			if got != tt.status {
				t.Fatalf("classifyErrorMessage(%q) = %d, want %d", tt.input, got, tt.status)
			}
		})
	}
}
// TestAuthHandler_OAuthFallbackEndpoints checks that, on a zero-value
// AuthHandler, each OAuth endpoint still responds with HTTP 200. The login,
// callback and exchange endpoints are given the "provider" route parameter
// "github"; the provider-listing endpoint gets no route parameters.
func TestAuthHandler_OAuthFallbackEndpoints(t *testing.T) {
	gin.SetMode(gin.TestMode)
	h := &AuthHandler{}

	endpoints := []struct {
		name     string
		provider string // empty means no "provider" route param is set
		invoke   func(*gin.Context)
	}{
		{name: "oauth login", provider: "github", invoke: func(c *gin.Context) { h.OAuthLogin(c) }},
		{name: "oauth callback", provider: "github", invoke: func(c *gin.Context) { h.OAuthCallback(c) }},
		{name: "oauth exchange", provider: "github", invoke: func(c *gin.Context) { h.OAuthExchange(c) }},
		{name: "oauth providers", invoke: func(c *gin.Context) { h.GetEnabledOAuthProviders(c) }},
	}

	for _, ep := range endpoints {
		t.Run(ep.name, func(t *testing.T) {
			rec := httptest.NewRecorder()
			ctx, _ := gin.CreateTestContext(rec)
			ctx.Request = httptest.NewRequest(http.MethodGet, "/auth", nil)
			if ep.provider != "" {
				ctx.Params = gin.Params{{Key: "provider", Value: ep.provider}}
			}

			ep.invoke(ctx)

			if rec.Code != http.StatusOK {
				t.Fatalf("expected 200, got %d", rec.Code)
			}
		})
	}
}
// TestAuthHandler_RefreshToken_InvalidJSON checks that RefreshToken answers
// HTTP 400 when the request body is truncated JSON ("{").
func TestAuthHandler_RefreshToken_InvalidJSON(t *testing.T) {
	gin.SetMode(gin.TestMode)

	rec := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(rec)
	req := httptest.NewRequest(http.MethodPost, "/auth/refresh", bytes.NewBufferString("{"))
	req.Header.Set("Content-Type", "application/json")
	ctx.Request = req

	(&AuthHandler{}).RefreshToken(ctx)

	if got := rec.Code; got != http.StatusBadRequest {
		t.Fatalf("expected 400, got %d", got)
	}
}
// TestAuthHandler_ActivateEmail_MissingToken checks that ActivateEmail answers
// HTTP 400 when the JSON body is an empty object, i.e. no token is supplied.
func TestAuthHandler_ActivateEmail_MissingToken(t *testing.T) {
	gin.SetMode(gin.TestMode)
	handler := &AuthHandler{}

	w := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(w)
	ctx.Request = httptest.NewRequest(
		http.MethodPost,
		"/auth/activate-email",
		bytes.NewBufferString(`{}`),
	)
	ctx.Request.Header.Set("Content-Type", "application/json")

	handler.ActivateEmail(ctx)

	if status := w.Code; status != http.StatusBadRequest {
		t.Fatalf("expected 400, got %d", status)
	}
}
// TestAuthHandler_ResendActivationEmail_InvalidEmail checks that
// ResendActivationEmail answers HTTP 400 for a malformed email address.
func TestAuthHandler_ResendActivationEmail_InvalidEmail(t *testing.T) {
	gin.SetMode(gin.TestMode)

	const (
		path    = "/auth/resend-activation-email"
		payload = `{"email":"bad-email"}`
	)

	rec := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(rec)
	req := httptest.NewRequest(http.MethodPost, path, bytes.NewBufferString(payload))
	req.Header.Set("Content-Type", "application/json")
	ctx.Request = req

	(&AuthHandler{}).ResendActivationEmail(ctx)

	if rec.Code != http.StatusBadRequest {
		t.Fatalf("expected 400, got %d", rec.Code)
	}
}
// TestAuthHandler_SendEmailCode_InvalidEmail checks that SendEmailCode answers
// HTTP 400 for a malformed email address.
func TestAuthHandler_SendEmailCode_InvalidEmail(t *testing.T) {
	gin.SetMode(gin.TestMode)
	handler := &AuthHandler{}

	w := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(w)
	body := bytes.NewBufferString(`{"email":"bad-email"}`)
	ctx.Request = httptest.NewRequest(http.MethodPost, "/auth/send-email-code", body)
	ctx.Request.Header.Set("Content-Type", "application/json")

	handler.SendEmailCode(ctx)

	if status := w.Code; status != http.StatusBadRequest {
		t.Fatalf("expected 400, got %d", status)
	}
}
// TestAuthHandler_LoginByEmailCode_InvalidPayload checks that LoginByEmailCode
// answers HTTP 400 for a payload that contains only a malformed email (no
// other fields).
func TestAuthHandler_LoginByEmailCode_InvalidPayload(t *testing.T) {
	gin.SetMode(gin.TestMode)

	rec := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(rec)
	req := httptest.NewRequest(
		http.MethodPost,
		"/auth/login-by-email-code",
		bytes.NewBufferString(`{"email":"bad-email"}`),
	)
	req.Header.Set("Content-Type", "application/json")
	ctx.Request = req

	(&AuthHandler{}).LoginByEmailCode(ctx)

	if got := rec.Code; got != http.StatusBadRequest {
		t.Fatalf("expected 400, got %d", got)
	}
}
// TestAuthHandler_BootstrapAdmin_HeaderFailures checks that BootstrapAdmin
// answers HTTP 401 when BOOTSTRAP_SECRET is set but the X-Bootstrap-Secret
// header is either missing or does not match it.
//
// The BOOTSTRAP_SECRET environment variable is restored to exactly its
// pre-test state on cleanup. If it was unset, it is unset again rather than
// left behind as an empty string, which would leak into later tests.
func TestAuthHandler_BootstrapAdmin_HeaderFailures(t *testing.T) {
	gin.SetMode(gin.TestMode)
	h := &AuthHandler{}

	original, hadOriginal := os.LookupEnv("BOOTSTRAP_SECRET")
	if err := os.Setenv("BOOTSTRAP_SECRET", "expected-secret"); err != nil {
		t.Fatalf("set env failed: %v", err)
	}
	t.Cleanup(func() {
		var err error
		if hadOriginal {
			err = os.Setenv("BOOTSTRAP_SECRET", original)
		} else {
			err = os.Unsetenv("BOOTSTRAP_SECRET")
		}
		if err != nil {
			t.Errorf("restore env failed: %v", err)
		}
	})

	testCases := []struct {
		name   string
		secret string // empty means the header is not sent at all
		want   int
	}{
		{name: "missing header", secret: "", want: http.StatusUnauthorized},
		{name: "wrong header", secret: "wrong-secret", want: http.StatusUnauthorized},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			recorder := httptest.NewRecorder()
			c, _ := gin.CreateTestContext(recorder)
			c.Request = httptest.NewRequest(http.MethodPost, "/auth/bootstrap-admin", bytes.NewBufferString(`{"username":"admin","email":"admin@example.com","password":"AdminPass123!"}`))
			c.Request.Header.Set("Content-Type", "application/json")
			if tc.secret != "" {
				c.Request.Header.Set("X-Bootstrap-Secret", tc.secret)
			}
			h.BootstrapAdmin(c)
			if recorder.Code != tc.want {
				t.Fatalf("expected %d, got %d", tc.want, recorder.Code)
			}
		})
	}
}