298 lines
8.3 KiB
Go
298 lines
8.3 KiB
Go
|
|
package handler
|
||
|
|
|
||
|
|
import (
|
||
|
|
"bytes"
|
||
|
|
"net/http"
|
||
|
|
"net/http/httptest"
|
||
|
|
"os"
|
||
|
|
"strings"
|
||
|
|
"testing"
|
||
|
|
|
||
|
|
"github.com/gin-gonic/gin"
|
||
|
|
)
|
||
|
|
|
||
|
|
func TestAuthHandler_SupportFlags(t *testing.T) {
|
||
|
|
var nilHandler *AuthHandler
|
||
|
|
if nilHandler.SupportsPasswordReset() {
|
||
|
|
t.Fatal("nil handler should not support password reset")
|
||
|
|
}
|
||
|
|
|
||
|
|
handler := &AuthHandler{}
|
||
|
|
if handler.SupportsPasswordReset() {
|
||
|
|
t.Fatal("password reset should be disabled by default")
|
||
|
|
}
|
||
|
|
|
||
|
|
handler.SetPasswordResetEnabled(true)
|
||
|
|
if !handler.SupportsPasswordReset() {
|
||
|
|
t.Fatal("password reset flag should be enabled")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestGetUserIDFromContext(t *testing.T) {
|
||
|
|
gin.SetMode(gin.TestMode)
|
||
|
|
recorder := httptest.NewRecorder()
|
||
|
|
c, _ := gin.CreateTestContext(recorder)
|
||
|
|
c.Request = httptest.NewRequest(http.MethodGet, "/userinfo", nil)
|
||
|
|
|
||
|
|
if _, ok := getUserIDFromContext(c); ok {
|
||
|
|
t.Fatal("expected missing user_id to return false")
|
||
|
|
}
|
||
|
|
|
||
|
|
c.Set("user_id", "1")
|
||
|
|
if _, ok := getUserIDFromContext(c); ok {
|
||
|
|
t.Fatal("expected non-int64 user_id to return false")
|
||
|
|
}
|
||
|
|
|
||
|
|
c.Set("user_id", int64(42))
|
||
|
|
if got, ok := getUserIDFromContext(c); !ok || got != 42 {
|
||
|
|
t.Fatalf("getUserIDFromContext() = (%d, %v), want (42, true)", got, ok)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestRequestUsesHTTPS(t *testing.T) {
|
||
|
|
gin.SetMode(gin.TestMode)
|
||
|
|
|
||
|
|
if requestUsesHTTPS(nil) {
|
||
|
|
t.Fatal("nil context should not use https")
|
||
|
|
}
|
||
|
|
|
||
|
|
recorder := httptest.NewRecorder()
|
||
|
|
c, _ := gin.CreateTestContext(recorder)
|
||
|
|
c.Request = httptest.NewRequest(http.MethodGet, "/auth", nil)
|
||
|
|
if requestUsesHTTPS(c) {
|
||
|
|
t.Fatal("plain http request should not use https")
|
||
|
|
}
|
||
|
|
|
||
|
|
c.Request.Header.Set("X-Forwarded-Proto", "https")
|
||
|
|
if !requestUsesHTTPS(c) {
|
||
|
|
t.Fatal("forwarded https request should be detected")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestSessionCookies_SetAndClear(t *testing.T) {
|
||
|
|
gin.SetMode(gin.TestMode)
|
||
|
|
|
||
|
|
recorder := httptest.NewRecorder()
|
||
|
|
c, _ := gin.CreateTestContext(recorder)
|
||
|
|
c.Request = httptest.NewRequest(http.MethodGet, "/auth", nil)
|
||
|
|
|
||
|
|
setSessionCookies(c, nil, "")
|
||
|
|
if len(recorder.Header().Values("Set-Cookie")) != 0 {
|
||
|
|
t.Fatal("empty refresh token should not set cookies")
|
||
|
|
}
|
||
|
|
|
||
|
|
setSessionCookies(c, nil, "refresh-token")
|
||
|
|
setCookies := recorder.Header().Values("Set-Cookie")
|
||
|
|
if len(setCookies) < 2 {
|
||
|
|
t.Fatalf("expected session cookies to be set, got %d", len(setCookies))
|
||
|
|
}
|
||
|
|
if !strings.Contains(setCookies[0], refreshTokenCookieName+"=refresh-token") &&
|
||
|
|
!strings.Contains(setCookies[1], refreshTokenCookieName+"=refresh-token") {
|
||
|
|
t.Fatalf("expected refresh token cookie, got %#v", setCookies)
|
||
|
|
}
|
||
|
|
|
||
|
|
recorder = httptest.NewRecorder()
|
||
|
|
c, _ = gin.CreateTestContext(recorder)
|
||
|
|
c.Request = httptest.NewRequest(http.MethodGet, "/auth", nil)
|
||
|
|
clearSessionCookies(c)
|
||
|
|
setCookies = recorder.Header().Values("Set-Cookie")
|
||
|
|
if len(setCookies) < 2 {
|
||
|
|
t.Fatalf("expected clearing cookies to emit expired cookies, got %d", len(setCookies))
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestClassifyErrorMessage(t *testing.T) {
|
||
|
|
testCases := []struct {
|
||
|
|
name string
|
||
|
|
msg string
|
||
|
|
want int
|
||
|
|
}{
|
||
|
|
{name: "not found", msg: "user not found", want: http.StatusNotFound},
|
||
|
|
{name: "duplicate", msg: "already exists", want: http.StatusConflict},
|
||
|
|
{name: "verification code", msg: "验证码错误", want: http.StatusUnauthorized},
|
||
|
|
{name: "unauthorized", msg: "invalid token", want: http.StatusUnauthorized},
|
||
|
|
{name: "forbidden", msg: "permission denied", want: http.StatusForbidden},
|
||
|
|
{name: "bad request", msg: "invalid payload", want: http.StatusBadRequest},
|
||
|
|
{name: "rate limit", msg: "too many attempts", want: http.StatusTooManyRequests},
|
||
|
|
{name: "fallback", msg: "unexpected boom", want: http.StatusInternalServerError},
|
||
|
|
}
|
||
|
|
|
||
|
|
for _, tc := range testCases {
|
||
|
|
t.Run(tc.name, func(t *testing.T) {
|
||
|
|
if got := classifyErrorMessage(tc.msg); got != tc.want {
|
||
|
|
t.Fatalf("classifyErrorMessage(%q) = %d, want %d", tc.msg, got, tc.want)
|
||
|
|
}
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestAuthHandler_OAuthFallbackEndpoints(t *testing.T) {
|
||
|
|
gin.SetMode(gin.TestMode)
|
||
|
|
h := &AuthHandler{}
|
||
|
|
|
||
|
|
testCases := []struct {
|
||
|
|
name string
|
||
|
|
run func(*gin.Context)
|
||
|
|
}{
|
||
|
|
{
|
||
|
|
name: "oauth login",
|
||
|
|
run: func(c *gin.Context) {
|
||
|
|
c.Params = gin.Params{{Key: "provider", Value: "github"}}
|
||
|
|
h.OAuthLogin(c)
|
||
|
|
},
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "oauth callback",
|
||
|
|
run: func(c *gin.Context) {
|
||
|
|
c.Params = gin.Params{{Key: "provider", Value: "github"}}
|
||
|
|
h.OAuthCallback(c)
|
||
|
|
},
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "oauth exchange",
|
||
|
|
run: func(c *gin.Context) {
|
||
|
|
c.Params = gin.Params{{Key: "provider", Value: "github"}}
|
||
|
|
h.OAuthExchange(c)
|
||
|
|
},
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "oauth providers",
|
||
|
|
run: func(c *gin.Context) {
|
||
|
|
h.GetEnabledOAuthProviders(c)
|
||
|
|
},
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
for _, tc := range testCases {
|
||
|
|
t.Run(tc.name, func(t *testing.T) {
|
||
|
|
recorder := httptest.NewRecorder()
|
||
|
|
c, _ := gin.CreateTestContext(recorder)
|
||
|
|
c.Request = httptest.NewRequest(http.MethodGet, "/auth", nil)
|
||
|
|
tc.run(c)
|
||
|
|
if recorder.Code != http.StatusOK {
|
||
|
|
t.Fatalf("expected 200, got %d", recorder.Code)
|
||
|
|
}
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestAuthHandler_RefreshToken_InvalidJSON(t *testing.T) {
|
||
|
|
gin.SetMode(gin.TestMode)
|
||
|
|
h := &AuthHandler{}
|
||
|
|
|
||
|
|
recorder := httptest.NewRecorder()
|
||
|
|
c, _ := gin.CreateTestContext(recorder)
|
||
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/auth/refresh", bytes.NewBufferString("{"))
|
||
|
|
c.Request.Header.Set("Content-Type", "application/json")
|
||
|
|
|
||
|
|
h.RefreshToken(c)
|
||
|
|
|
||
|
|
if recorder.Code != http.StatusBadRequest {
|
||
|
|
t.Fatalf("expected 400, got %d", recorder.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestAuthHandler_ActivateEmail_MissingToken(t *testing.T) {
|
||
|
|
gin.SetMode(gin.TestMode)
|
||
|
|
h := &AuthHandler{}
|
||
|
|
|
||
|
|
recorder := httptest.NewRecorder()
|
||
|
|
c, _ := gin.CreateTestContext(recorder)
|
||
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/auth/activate-email", bytes.NewBufferString(`{}`))
|
||
|
|
c.Request.Header.Set("Content-Type", "application/json")
|
||
|
|
|
||
|
|
h.ActivateEmail(c)
|
||
|
|
|
||
|
|
if recorder.Code != http.StatusBadRequest {
|
||
|
|
t.Fatalf("expected 400, got %d", recorder.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestAuthHandler_ResendActivationEmail_InvalidEmail(t *testing.T) {
|
||
|
|
gin.SetMode(gin.TestMode)
|
||
|
|
h := &AuthHandler{}
|
||
|
|
|
||
|
|
recorder := httptest.NewRecorder()
|
||
|
|
c, _ := gin.CreateTestContext(recorder)
|
||
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/auth/resend-activation-email", bytes.NewBufferString(`{"email":"bad-email"}`))
|
||
|
|
c.Request.Header.Set("Content-Type", "application/json")
|
||
|
|
|
||
|
|
h.ResendActivationEmail(c)
|
||
|
|
|
||
|
|
if recorder.Code != http.StatusBadRequest {
|
||
|
|
t.Fatalf("expected 400, got %d", recorder.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestAuthHandler_SendEmailCode_InvalidEmail(t *testing.T) {
|
||
|
|
gin.SetMode(gin.TestMode)
|
||
|
|
h := &AuthHandler{}
|
||
|
|
|
||
|
|
recorder := httptest.NewRecorder()
|
||
|
|
c, _ := gin.CreateTestContext(recorder)
|
||
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/auth/send-email-code", bytes.NewBufferString(`{"email":"bad-email"}`))
|
||
|
|
c.Request.Header.Set("Content-Type", "application/json")
|
||
|
|
|
||
|
|
h.SendEmailCode(c)
|
||
|
|
|
||
|
|
if recorder.Code != http.StatusBadRequest {
|
||
|
|
t.Fatalf("expected 400, got %d", recorder.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestAuthHandler_LoginByEmailCode_InvalidPayload(t *testing.T) {
|
||
|
|
gin.SetMode(gin.TestMode)
|
||
|
|
h := &AuthHandler{}
|
||
|
|
|
||
|
|
recorder := httptest.NewRecorder()
|
||
|
|
c, _ := gin.CreateTestContext(recorder)
|
||
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/auth/login-by-email-code", bytes.NewBufferString(`{"email":"bad-email"}`))
|
||
|
|
c.Request.Header.Set("Content-Type", "application/json")
|
||
|
|
|
||
|
|
h.LoginByEmailCode(c)
|
||
|
|
|
||
|
|
if recorder.Code != http.StatusBadRequest {
|
||
|
|
t.Fatalf("expected 400, got %d", recorder.Code)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestAuthHandler_BootstrapAdmin_HeaderFailures(t *testing.T) {
|
||
|
|
gin.SetMode(gin.TestMode)
|
||
|
|
h := &AuthHandler{}
|
||
|
|
|
||
|
|
original := os.Getenv("BOOTSTRAP_SECRET")
|
||
|
|
if err := os.Setenv("BOOTSTRAP_SECRET", "expected-secret"); err != nil {
|
||
|
|
t.Fatalf("set env failed: %v", err)
|
||
|
|
}
|
||
|
|
t.Cleanup(func() {
|
||
|
|
_ = os.Setenv("BOOTSTRAP_SECRET", original)
|
||
|
|
})
|
||
|
|
|
||
|
|
testCases := []struct {
|
||
|
|
name string
|
||
|
|
secret string
|
||
|
|
want int
|
||
|
|
}{
|
||
|
|
{name: "missing header", secret: "", want: http.StatusUnauthorized},
|
||
|
|
{name: "wrong header", secret: "wrong-secret", want: http.StatusUnauthorized},
|
||
|
|
}
|
||
|
|
|
||
|
|
for _, tc := range testCases {
|
||
|
|
t.Run(tc.name, func(t *testing.T) {
|
||
|
|
recorder := httptest.NewRecorder()
|
||
|
|
c, _ := gin.CreateTestContext(recorder)
|
||
|
|
c.Request = httptest.NewRequest(http.MethodPost, "/auth/bootstrap-admin", bytes.NewBufferString(`{"username":"admin","email":"admin@example.com","password":"AdminPass123!"}`))
|
||
|
|
c.Request.Header.Set("Content-Type", "application/json")
|
||
|
|
if tc.secret != "" {
|
||
|
|
c.Request.Header.Set("X-Bootstrap-Secret", tc.secret)
|
||
|
|
}
|
||
|
|
|
||
|
|
h.BootstrapAdmin(c)
|
||
|
|
|
||
|
|
if recorder.Code != tc.want {
|
||
|
|
t.Fatalf("expected %d, got %d", tc.want, recorder.Code)
|
||
|
|
}
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|