diff --git a/internal/api/handler/auth_handler.go b/internal/api/handler/auth_handler.go index 7047ffb..fcda6bd 100644 --- a/internal/api/handler/auth_handler.go +++ b/internal/api/handler/auth_handler.go @@ -784,13 +784,17 @@ func classifyErrorMessage(msg string) int { return http.StatusNotFound case contains(lower, "already exists", "已存在", "已注册", "duplicate"): return http.StatusConflict + case contains(lower, "验证码错误", "验证码或恢复码错误", "verification code", "recovery code"): + return http.StatusUnauthorized case contains(lower, "unauthorized", "invalid token", "token", "令牌", "未认证"): return http.StatusUnauthorized case contains(lower, "forbidden", "permission", "权限", "禁止"): return http.StatusForbidden + case contains(lower, "2fa 已", "2fa 未", "请先初始化 2fa", "已启用", "未启用"): + return http.StatusBadRequest case contains(lower, "invalid", "required", "must", "cannot be empty", "不能为空", "格式", "参数", "密码不正确", "incorrect", "wrong", "too short", "too long", - "已失效", "expired", "验证码不正确", "不能与"): + "已失效", "expired", "验证码不正确", "不能与", "不能删除自己", "不能删除最后一个管理员"): return http.StatusBadRequest case contains(lower, "locked", "too many", "账号已被锁定", "rate limit"): return http.StatusTooManyRequests diff --git a/internal/api/handler/handler_test.go b/internal/api/handler/handler_test.go index 4e7efec..fe7e8f4 100644 --- a/internal/api/handler/handler_test.go +++ b/internal/api/handler/handler_test.go @@ -2,6 +2,7 @@ package handler_test import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -71,7 +72,7 @@ func seedHandlerAuthzData(t *testing.T, db *gorm.DB) { } } -func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) { +func setupHandlerTestServerWithCacheAndOptions(t *testing.T, enableEmailActivation bool) (*httptest.Server, func(), *cache.CacheManager) { t.Helper() gin.SetMode(gin.TestMode) @@ -85,7 +86,7 @@ func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) { }) if err != nil { t.Skipf("skipping handler test (SQLite unavailable): %v", err) - return nil, func() {} + return nil, func() {}, nil } if err := db.AutoMigrate( @@ -133,6 +134,12 @@ func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) { authSvc.SetRoleRepositories(userRoleRepo, roleRepo) smsCodeSvc := service.NewSMSCodeService(&service.MockSMSProvider{}, cacheManager, service.DefaultSMSCodeConfig()) authSvc.SetSMSCodeService(smsCodeSvc) + emailCodeSvc := service.NewEmailCodeService(&service.MockEmailProvider{}, cacheManager, service.EmailCodeConfig{}) + authSvc.SetEmailCodeService(emailCodeSvc) + if enableEmailActivation { + emailActivationSvc := service.NewEmailActivationService(&service.MockEmailProvider{}, cacheManager, "http://localhost:3000", "TestSite") + authSvc.SetEmailActivationService(emailActivationSvc) + } userSvc := service.NewUserService(userRepo, userRoleRepo, roleRepo, passwordHistoryRepo) roleSvc := service.NewRoleService(roleRepo, rolePermissionRepo) permSvc := service.NewPermissionService(permissionRepo) @@ -166,12 +173,13 @@ func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) { totpHandler := handler.NewTOTPHandler(authSvc, totpSvc) pwdResetHandler := handler.NewPasswordResetHandler(pwdResetSvc) themeHandler := handler.NewThemeHandler(themeSvc) + smsHandler := handler.NewSMSHandler(authSvc, smsCodeSvc) r := router.NewRouter( authHandler, userHandler, roleHandler, permHandler, deviceHandler, logHandler, authMiddleware, rateLimitMiddleware, opLogMiddleware, pwdResetHandler, captchaHandler, totpHandler, nil, - nil, nil, nil, nil, nil, themeHandler, nil, nil, nil, avatarH, + nil, nil, nil, smsHandler, nil, themeHandler, nil, nil, nil, avatarH, ) engine := r.Setup() @@ -181,7 +189,20 @@ func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) { if sqlDB, _ := db.DB(); sqlDB != nil { sqlDB.Close() } - } + }, cacheManager +} + +func setupHandlerTestServerWithCache(t *testing.T) (*httptest.Server, func(), *cache.CacheManager) { + return setupHandlerTestServerWithCacheAndOptions(t, false) +} + +func setupHandlerTestServerWithActivation(t *testing.T) (*httptest.Server, func(), *cache.CacheManager) { + return setupHandlerTestServerWithCacheAndOptions(t, true) +} + +func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) { + server, cleanup, _ := setupHandlerTestServerWithCache(t) + return server, cleanup } func doRequest(method, url string, token string, body interface{}) (*http.Response, string) { @@ -195,8 +216,15 @@ func doRequest(method, url string, token string, body interface{}) (*http.Respon req.Header.Set("Authorization", "Bearer "+token) } req.Header.Set("Content-Type", "application/json") - client := &http.Client{} + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } resp, _ := client.Do(req) + if resp == nil { + return &http.Response{StatusCode: 0}, "" + } bodyBytes, _ := io.ReadAll(resp.Body) resp.Body.Close() return resp, string(bodyBytes) @@ -533,6 +561,36 @@ func TestAuthHandler_BootstrapAdmin_MissingSecret(t *testing.T) { } } +func TestAuthHandler_BootstrapAdmin_InvalidSecret(t *testing.T) { + t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret") + + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + payload, _ := json.Marshal(map[string]interface{}{ + "username": "admin", + "email": "admin@example.com", + "password": "AdminPass123!", + }) + + req, err := http.NewRequest(http.MethodPost, server.URL+"/api/v1/auth/bootstrap-admin", bytes.NewReader(payload)) + if err != nil { + t.Fatalf("create request failed: %v", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Bootstrap-Secret", "wrong-secret") + + resp, err := (&http.Client{}).Do(req) + if err != nil { + t.Fatalf("bootstrap request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode) + } +} + func TestAuthHandler_GetAuthCapabilities(t *testing.T) { server, cleanup := setupHandlerTestServer(t) defer cleanup() @@ -659,6 +717,29 @@ func TestUserHandler_CreateUser_Unauthorized(t *testing.T) { } } +func TestUserHandler_CreateUser_AdminSuccess(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret") + token := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "createuseradmin", "createuseradmin@test.com", "AdminPass123!") + if token == "" { + t.Fatal("bootstrap admin should return access token") + } + + resp, body := doPost(server.URL+"/api/v1/users", token, map[string]interface{}{ + "username": "created-by-admin", + "email": "created-by-admin@test.com", + "password": "CreatedPass123!", + "nickname": "Created User", + }) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusCreated { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusCreated, resp.StatusCode, body) + } +} + func TestUserHandler_ListUsers_Success(t *testing.T) { server, cleanup := setupHandlerTestServer(t) defer cleanup() @@ -695,6 +776,21 @@ func TestUserHandler_GetUser_Success(t *testing.T) { } } +func TestUserHandler_GetUser_InvalidID(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + registerUser(server.URL, "getinvalid", "getinvalid@test.com", "AdminPass123!") + token := getToken(server.URL, "getinvalid", "AdminPass123!") + + resp, body := doGet(server.URL+"/api/v1/users/not-a-number", token) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body) + } +} + func TestUserHandler_UpdateUser_Success(t *testing.T) { server, cleanup := setupHandlerTestServer(t) defer cleanup() @@ -812,6 +908,66 @@ func TestUserHandler_UpdateUser_ProfileFieldsPersisted(t *testing.T) { } } +func TestUserHandler_UpdateUser_InvalidBirthday(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + registerUser(server.URL, "birthdayuser", "birthdayuser@test.com", "UserPass123!") + token := getToken(server.URL, "birthdayuser", "UserPass123!") + + resp, body := doPut(server.URL+"/api/v1/users/1", token, map[string]interface{}{ + "birthday": "2026-99-99", + }) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body) + } +} + +func TestUserHandler_UpdateUser_ForbiddenForOtherUser(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + registerUser(server.URL, "updateactor", "updateactor@test.com", "ActorPass123!") + registerUser(server.URL, "updatetarget", "updatetarget@test.com", "TargetPass123!") + token := getToken(server.URL, "updateactor", "ActorPass123!") + + resp, body := doPut(server.URL+"/api/v1/users/2", token, map[string]interface{}{ + "nickname": "Not Allowed", + }) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusForbidden { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body) + } +} + +func TestUserHandler_UpdateUser_InvalidJSON(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + registerUser(server.URL, "updatejson", "updatejson@test.com", "UserPass123!") + token := getToken(server.URL, "updatejson", "UserPass123!") + + req, err := http.NewRequest(http.MethodPut, server.URL+"/api/v1/users/1", bytes.NewBufferString("{")) + if err != nil { + t.Fatalf("create request failed: %v", err) + } + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + + resp, err := (&http.Client{}).Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode) + } +} + func TestUserHandler_UpdatePassword_NonAdminCannotUpdateAnotherUser(t *testing.T) { server, cleanup := setupHandlerTestServer(t) defer cleanup() @@ -897,6 +1053,49 @@ func TestUserHandler_UpdatePassword_AdminCanResetAnotherUser(t *testing.T) { } } +func TestUserHandler_UpdatePassword_InvalidID(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + registerUser(server.URL, "passwordinvalid", "passwordinvalid@test.com", "UserPass123!") + token := getToken(server.URL, "passwordinvalid", "UserPass123!") + + resp, body := doPut(server.URL+"/api/v1/users/not-a-number/password", token, map[string]interface{}{ + "old_password": "UserPass123!", + "new_password": "ChangedPass123!", + }) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body) + } +} + +func TestUserHandler_UpdatePassword_InvalidJSON(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + registerUser(server.URL, "passwordjson", "passwordjson@test.com", "UserPass123!") + token := getToken(server.URL, "passwordjson", "UserPass123!") + + req, err := http.NewRequest(http.MethodPut, server.URL+"/api/v1/users/1/password", bytes.NewBufferString("{")) + if err != nil { + t.Fatalf("create request failed: %v", err) + } + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + + resp, err := (&http.Client{}).Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode) + } +} + func TestUserHandler_DeleteUser_NonAdmin_Forbidden(t *testing.T) { server, cleanup := setupHandlerTestServer(t) defer cleanup() @@ -913,6 +1112,26 @@ func TestUserHandler_DeleteUser_NonAdmin_Forbidden(t *testing.T) { } } +func TestUserHandler_DeleteUser_AdminSuccess(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret") + token := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "deleteuseradmin", "deleteuseradmin@test.com", "AdminPass123!") + registerUser(server.URL, "delete-target", "delete-target@test.com", "TargetPass123!") + + if token == "" { + t.Fatal("bootstrap admin should return access token") + } + + resp, body := doDelete(server.URL+"/api/v1/users/2", token) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body) + } +} + func TestUserHandler_SearchUsers_Success(t *testing.T) { server, cleanup := setupHandlerTestServer(t) defer cleanup() @@ -946,6 +1165,26 @@ func TestUserHandler_UpdateUserStatus_RequiresAdmin(t *testing.T) { } } +func TestUserHandler_UpdateUserStatus_InvalidStatus(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret") + token := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "statusadmin", "statusadmin@test.com", "AdminPass123!") + if token == "" { + t.Fatal("bootstrap admin should return access token") + } + + resp, body := doPut(server.URL+"/api/v1/users/1/status", token, map[string]interface{}{ + "status": "mystery", + }) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body) + } +} + func TestUserHandler_GetUserRoles_Success(t *testing.T) { server, cleanup := setupHandlerTestServer(t) defer cleanup() @@ -981,6 +1220,37 @@ func TestUserHandler_GetUserRoles_AdminCanViewAnotherUser(t *testing.T) { } } +func TestUserHandler_GetUserRoles_ForbiddenForOtherUser(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + registerUser(server.URL, "rolesactor", "rolesactor@test.com", "ActorPass123!") + registerUser(server.URL, "rolestarget", "rolestarget@test.com", "TargetPass123!") + token := getToken(server.URL, "rolesactor", "ActorPass123!") + + resp, body := doGet(server.URL+"/api/v1/users/2/roles", token) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusForbidden { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body) + } +} + +func TestUserHandler_GetUserRoles_InvalidID(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + registerUser(server.URL, "rolesinvalid", "rolesinvalid@test.com", "UserPass123!") + token := getToken(server.URL, "rolesinvalid", "UserPass123!") + + resp, body := doGet(server.URL+"/api/v1/users/not-a-number/roles", token) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body) + } +} + func TestUserHandler_AssignRoles_RequiresAdmin(t *testing.T) { server, cleanup := setupHandlerTestServer(t) defer cleanup() @@ -999,6 +1269,54 @@ func TestUserHandler_AssignRoles_RequiresAdmin(t *testing.T) { } } +func TestUserHandler_AssignRoles_InvalidID(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret") + token := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "assignadmin", "assignadmin@test.com", "AdminPass123!") + if token == "" { + t.Fatal("bootstrap admin should return access token") + } + + resp, body := doPut(server.URL+"/api/v1/users/not-a-number/roles", token, map[string]interface{}{ + "role_ids": []int64{1}, + }) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body) + } +} + +func TestUserHandler_AssignRoles_InvalidJSON(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret") + token := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "assignjsonadmin", "assignjsonadmin@test.com", "AdminPass123!") + if token == "" { + t.Fatal("bootstrap admin should return access token") + } + + req, err := http.NewRequest(http.MethodPut, server.URL+"/api/v1/users/1/roles", bytes.NewBufferString("{")) + if err != nil { + t.Fatalf("create request failed: %v", err) + } + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + + resp, err := (&http.Client{}).Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode) + } +} + func TestUserHandler_BatchUpdateStatus_RequiresAdmin(t *testing.T) { server, cleanup := setupHandlerTestServer(t) defer cleanup() @@ -1018,6 +1336,57 @@ func TestUserHandler_BatchUpdateStatus_RequiresAdmin(t *testing.T) { } } +func TestUserHandler_BatchUpdateStatus_InvalidJSON(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret") + token := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "batchstatusadmin", "batchstatusadmin@test.com", "AdminPass123!") + if token == "" { + t.Fatal("bootstrap admin should return access token") + } + + req, err := http.NewRequest(http.MethodPut, server.URL+"/api/v1/users/batch/status", bytes.NewBufferString("{")) + if err != nil { + t.Fatalf("create request failed: %v", err) + } + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + + resp, err := (&http.Client{}).Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode) + } +} + +func TestUserHandler_BatchUpdateStatus_Success(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret") + token := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "batchstatusadmin2", "batchstatusadmin2@test.com", "AdminPass123!") + registerUser(server.URL, "batchstatus-target", "batchstatus-target@test.com", "TargetPass123!") + + if token == "" { + t.Fatal("bootstrap admin should return access token") + } + + resp, body := doPut(server.URL+"/api/v1/users/batch/status", token, map[string]interface{}{ + "ids": []int64{2}, + "status": 1, + }) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body) + } +} + func TestUserHandler_BatchDelete_RequiresAdmin(t *testing.T) { server, cleanup := setupHandlerTestServer(t) defer cleanup() @@ -1050,6 +1419,628 @@ func TestUserHandler_BatchDelete_EmptyIDs_RequiresAdmin(t *testing.T) { } } +func TestUserHandler_BatchDelete_InvalidJSON(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret") + token := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "batchdeleteadmin", "batchdeleteadmin@test.com", "AdminPass123!") + if token == "" { + t.Fatal("bootstrap admin should return access token") + } + + req, err := http.NewRequest(http.MethodDelete, server.URL+"/api/v1/users/batch", bytes.NewBufferString("{")) + if err != nil { + t.Fatalf("create request failed: %v", err) + } + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + + resp, err := (&http.Client{}).Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode) + } +} + +func TestUserHandler_BatchDelete_Success(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret") + token := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "batchdeleteadmin2", "batchdeleteadmin2@test.com", "AdminPass123!") + registerUser(server.URL, "batchdelete-target", "batchdelete-target@test.com", "TargetPass123!") + + if token == "" { + t.Fatal("bootstrap admin should return access token") + } + + reqBody, err := json.Marshal(map[string]interface{}{ + "ids": []int64{2}, + }) + if err != nil { + t.Fatalf("marshal request failed: %v", err) + } + + req, err := http.NewRequest(http.MethodDelete, server.URL+"/api/v1/users/batch", bytes.NewReader(reqBody)) + if err != nil { + t.Fatalf("create request failed: %v", err) + } + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + + resp, err := (&http.Client{}).Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read response failed: %v", err) + } + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, string(bodyBytes)) + } +} + +func TestUserHandler_ListAdmins_Success(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret") + token := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "listadmins", "listadmins@test.com", "AdminPass123!") + if token == "" { + t.Fatal("bootstrap admin should return access token") + } + + resp, body := doGet(server.URL+"/api/v1/admin/admins", token) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body) + } +} + +func TestUserHandler_CreateAdmin_Success(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret") + token := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "createadminroot", "createadminroot@test.com", "AdminPass123!") + if token == "" { + t.Fatal("bootstrap admin should return access token") + } + + resp, body := doPost(server.URL+"/api/v1/admin/admins", token, map[string]interface{}{ + "username": "secondadmin", + "password": "SecondAdmin123!", + "email": "secondadmin@test.com", + "nickname": "Second Admin", + }) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusCreated { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusCreated, resp.StatusCode, body) + } +} + +func TestUserHandler_DeleteAdmin_InvalidID(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret") + token := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "deleteadminroot", "deleteadminroot@test.com", "AdminPass123!") + if token == "" { + t.Fatal("bootstrap admin should return access token") + } + + resp, body := doDelete(server.URL+"/api/v1/admin/admins/not-a-number", token) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body) + } +} + +func TestUserHandler_DeleteAdmin_CannotDeleteSelf(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret") + token := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "selfdeleteadmin", "selfdeleteadmin@test.com", "AdminPass123!") + if token == "" { + t.Fatal("bootstrap admin should return access token") + } + + resp, body := doDelete(server.URL+"/api/v1/admin/admins/1", token) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body) + } +} + +func TestUserHandler_DeleteAdmin_CannotDeleteLastAdmin(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret") + rootToken := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "lastadminroot", "lastadminroot@test.com", "AdminPass123!") + if rootToken == "" { + t.Fatal("bootstrap admin should return access token") + } + + resp, body := doPost(server.URL+"/api/v1/admin/admins", rootToken, map[string]interface{}{ + "username": "secondlastadmin", + "password": "SecondAdmin123!", + "email": "secondlastadmin@test.com", + }) + defer resp.Body.Close() + if resp.StatusCode != http.StatusCreated { + t.Fatalf("expected create admin status %d, got %d, body: %s", http.StatusCreated, resp.StatusCode, body) + } + + respDelete, deleteBody := doDelete(server.URL+"/api/v1/admin/admins/2", rootToken) + defer respDelete.Body.Close() + if respDelete.StatusCode != http.StatusOK { + t.Fatalf("expected first delete status %d, got %d, body: %s", http.StatusOK, respDelete.StatusCode, deleteBody) + } + + respLast, lastBody := doDelete(server.URL+"/api/v1/admin/admins/1", rootToken) + defer respLast.Body.Close() + if respLast.StatusCode != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusBadRequest, respLast.StatusCode, lastBody) + } +} + +func TestUserHandler_DeleteAdmin_Success(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret") + rootToken := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "deleteadminsuccess", "deleteadminsuccess@test.com", "AdminPass123!") + if rootToken == "" { + t.Fatal("bootstrap admin should return access token") + } + + resp, body := doPost(server.URL+"/api/v1/admin/admins", rootToken, map[string]interface{}{ + "username": "deleteadmin-target", + "password": "DeleteAdmin123!", + "email": "deleteadmin-target@test.com", + }) + defer resp.Body.Close() + if resp.StatusCode != http.StatusCreated { + t.Fatalf("expected create admin status %d, got %d, body: %s", http.StatusCreated, resp.StatusCode, body) + } + + deleteResp, deleteBody := doDelete(server.URL+"/api/v1/admin/admins/2", rootToken) + defer deleteResp.Body.Close() + if deleteResp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, deleteResp.StatusCode, deleteBody) + } +} + +func TestLogHandler_GetMyLoginLogs_Success(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + registerUser(server.URL, "myloginlogs", "myloginlogs@test.com", "UserPass123!") + token := getToken(server.URL, "myloginlogs", "UserPass123!") + + resp, body := doGet(server.URL+"/api/v1/logs/login/me?page=1&page_size=10", token) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body) + } +} + +func TestLogHandler_GetMyOperationLogs_Success(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + registerUser(server.URL, "myoplogs", "myoplogs@test.com", "UserPass123!") + token := getToken(server.URL, "myoplogs", "UserPass123!") + + resp, body := doGet(server.URL+"/api/v1/logs/operation/me?page=1&page_size=10", token) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body) + } +} + +func TestLogHandler_GetLoginLogs_Success(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret") + token := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "loginlogadmin", "loginlogadmin@test.com", "AdminPass123!") + if token == "" { + t.Fatal("bootstrap admin should return access token") + } + + resp, body := doGet(server.URL+"/api/v1/logs/login?page=1&page_size=10", token) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body) + } +} + +func TestLogHandler_GetOperationLogs_Success(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret") + token := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "oplogadmin", "oplogadmin@test.com", "AdminPass123!") + if token == "" { + t.Fatal("bootstrap admin should return access token") + } + + resp, body := doGet(server.URL+"/api/v1/logs/operation?page=1&page_size=10", token) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body) + } +} + +func TestLogHandler_GetLoginLogs_InvalidCursor(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret") + token := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "loginqueryadmin", "loginqueryadmin@test.com", "AdminPass123!") + if token == "" { + t.Fatal("bootstrap admin should return access token") + } + + resp, body := doGet(server.URL+"/api/v1/logs/login?cursor=bad-cursor", token) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body) + } +} + +func TestLogHandler_GetOperationLogs_InvalidCursor(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret") + token := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "opqueryadmin", "opqueryadmin@test.com", "AdminPass123!") + if token == "" { + t.Fatal("bootstrap admin should return access token") + } + + resp, body := doGet(server.URL+"/api/v1/logs/operation?cursor=bad-cursor", token) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body) + } +} + +func TestLogHandler_GetLoginLogs_CursorMode_Success(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret") + token := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "logcursoradmin", "logcursoradmin@test.com", "AdminPass123!") + if token == "" { + t.Fatal("bootstrap admin should return access token") + } + + resp, body := doGet(server.URL+"/api/v1/logs/login?size=5", token) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body) + } +} + +func TestLogHandler_GetOperationLogs_CursorMode_Success(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret") + token := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "opcursoradmin", "opcursoradmin@test.com", "AdminPass123!") + if token == "" { + t.Fatal("bootstrap admin should return access token") + } + + resp, body := doGet(server.URL+"/api/v1/logs/operation?size=5", token) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body) + } +} + +func TestLogHandler_ExportLoginLogs_Success(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret") + token := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "exportlogadmin", "exportlogadmin@test.com", "AdminPass123!") + if token == "" { + t.Fatal("bootstrap admin should return access token") + } + + req, err := http.NewRequest(http.MethodGet, server.URL+"/api/v1/logs/login/export", nil) + if err != nil { + t.Fatalf("create request failed: %v", err) + } + req.Header.Set("Authorization", "Bearer "+token) + + resp, err := (&http.Client{}).Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, string(bodyBytes)) + } + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read response failed: %v", err) + } + if len(bodyBytes) == 0 { + t.Fatal("expected non-empty export body") + } +} + +func TestSMSHandler_SendCode_InvalidPayload(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + resp, body := doPost(server.URL+"/api/v1/auth/send-code", "", map[string]interface{}{ + "phone": "", + "purpose": "login", + }) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body) + } +} + +func TestSMSHandler_SendCode_Success(t *testing.T) { + server, cleanup, cacheManager := setupHandlerTestServerWithCache(t) + defer cleanup() + + resp, body := doPost(server.URL+"/api/v1/auth/send-code", "", map[string]interface{}{ + "phone": "13800138000", + "purpose": "login", + }) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body) + } + + if _, ok := cacheManager.Get(context.Background(), "sms_code:login:13800138000"); !ok { + t.Fatal("expected SMS code to be stored in cache") + } +} + +func TestSMSHandler_LoginByCode_InvalidPayload(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + resp, body := doPost(server.URL+"/api/v1/auth/login/code", "", map[string]interface{}{ + "phone": "", + "code": "", + }) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body) + } +} + +func TestSMSHandler_LoginByCode_Success(t *testing.T) { + server, cleanup, cacheManager := setupHandlerTestServerWithCache(t) + defer cleanup() + + phone := "13800138001" + registerUser(server.URL, "smsloginuser", "smsloginuser@test.com", "UserPass123!") + + resp, body := doPut(server.URL+"/api/v1/users/1", getToken(server.URL, "smsloginuser", "UserPass123!"), map[string]interface{}{ + "phone": phone, + }) + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected phone update status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body) + } + + sendResp, sendBody := doPost(server.URL+"/api/v1/auth/send-code", "", map[string]interface{}{ + "phone": phone, + "purpose": "login", + }) + defer sendResp.Body.Close() + if sendResp.StatusCode != http.StatusOK { + t.Fatalf("expected send code status %d, got %d, body: %s", http.StatusOK, sendResp.StatusCode, sendBody) + } + + codeValue, ok := cacheManager.Get(context.Background(), "sms_code:login:"+phone) + if !ok { + t.Fatal("expected SMS login code in cache") + } + code, ok := codeValue.(string) + if !ok || code == "" { + t.Fatalf("expected cached SMS login code string, got %#v", codeValue) + } + + loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login/code", "", map[string]interface{}{ + "phone": phone, + "code": code, + }) + defer loginResp.Body.Close() + + if loginResp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, loginResp.StatusCode, loginBody) + } +} + +func TestAuthHandler_SendEmailBindCode_NotConfigured(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + registerUser(server.URL, "emailbinduser", "emailbinduser@test.com", "UserPass123!") + token := getToken(server.URL, "emailbinduser", "UserPass123!") + + resp, body := doPost(server.URL+"/api/v1/users/me/bind-email/code", token, map[string]interface{}{ + "email": "bind@example.com", + }) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body) + } +} + +func TestAuthHandler_BindEmail_NotConfigured(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + registerUser(server.URL, "bindemailuser", "bindemailuser@test.com", "UserPass123!") + token := getToken(server.URL, "bindemailuser", "UserPass123!") + + resp, body := doPost(server.URL+"/api/v1/users/me/bind-email", token, map[string]interface{}{ + "email": "bind@example.com", + "code": "123456", + }) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body) + } +} + +func TestAuthHandler_UnbindEmail_NotConfigured(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + registerUser(server.URL, "unbindemailuser", "unbindemailuser@test.com", "UserPass123!") + token := getToken(server.URL, "unbindemailuser", "UserPass123!") + + resp, body := doDelete(server.URL+"/api/v1/users/me/bind-email", token) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body) + } +} + +func TestAuthHandler_SendPhoneBindCode_NotConfigured(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + registerUser(server.URL, "phonebinduser", "phonebinduser@test.com", "UserPass123!") + token := getToken(server.URL, "phonebinduser", "UserPass123!") + + resp, body := doPost(server.URL+"/api/v1/users/me/bind-phone/code", token, map[string]interface{}{ + "phone": "13800138009", + }) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body) + } +} + +func TestAuthHandler_BindPhone_NotConfigured(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + registerUser(server.URL, "bindphoneuser", "bindphoneuser@test.com", "UserPass123!") + token := getToken(server.URL, "bindphoneuser", "UserPass123!") + + resp, body := doPost(server.URL+"/api/v1/users/me/bind-phone", token, map[string]interface{}{ + "phone": "13800138009", + "code": "123456", + }) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body) + } +} + +func TestAuthHandler_UnbindPhone_NotConfigured(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + registerUser(server.URL, "unbindphoneuser", "unbindphoneuser@test.com", "UserPass123!") + token := getToken(server.URL, "unbindphoneuser", "UserPass123!") + + resp, body := doDelete(server.URL+"/api/v1/users/me/bind-phone", token) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body) + } +} + +func TestAuthHandler_GetSocialAccounts_Success(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + registerUser(server.URL, "sociallistuser", "sociallistuser@test.com", "UserPass123!") + token := getToken(server.URL, "sociallistuser", "UserPass123!") + + resp, body := doGet(server.URL+"/api/v1/users/me/social-accounts", token) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body) + } +} + +func TestAuthHandler_BindSocialAccount_NotConfigured(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + registerUser(server.URL, "socialbinduser", "socialbinduser@test.com", "UserPass123!") + token := getToken(server.URL, "socialbinduser", "UserPass123!") + + resp, body := doPost(server.URL+"/api/v1/users/me/bind-social", token, map[string]interface{}{ + "provider": "github", + "code": "oauth-code", + }) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body) + } +} + +func TestAuthHandler_UnbindSocialAccount_NotConfigured(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + registerUser(server.URL, "socialunbinduser", "socialunbinduser@test.com", "UserPass123!") + token := getToken(server.URL, "socialunbinduser", "UserPass123!") + + resp, body := doDelete(server.URL+"/api/v1/users/me/bind-social/github", token) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body) + } +} + // ============================================================================= // Device Handler Tests // ============================================================================= @@ -2251,6 +3242,201 @@ func TestAuthHandler_RefreshToken_MissingToken(t *testing.T) { } } +func TestAuthHandler_RefreshToken_InvalidJSON(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + req, err := http.NewRequest(http.MethodPost, server.URL+"/api/v1/auth/refresh", bytes.NewBufferString("{")) + if err != nil { + t.Fatalf("create refresh request failed: %v", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := (&http.Client{}).Do(req) + if err != nil { + t.Fatalf("refresh request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode) + } +} + +func TestAuthHandler_RefreshToken_EmptyJSONBodyFallsBackToCookie(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + registerUser(server.URL, "refreshfallbackuser", "refreshfallback@example.com", "Password123!") + loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{ + "account": "refreshfallbackuser", + "password": "Password123!", + }) + defer loginResp.Body.Close() + + if loginResp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, loginResp.StatusCode, loginBody) + } + + refreshCookie := getCookie(loginResp, "ums_refresh_token") + if refreshCookie == nil || refreshCookie.Value == "" { + t.Fatalf("login response missing refresh cookie, cookies=%v", loginResp.Cookies()) + } + + req, err := http.NewRequest(http.MethodPost, server.URL+"/api/v1/auth/refresh", bytes.NewBufferString(`{}`)) + if err != nil { + t.Fatalf("create refresh request failed: %v", err) + } + req.Header.Set("Content-Type", "application/json") + req.AddCookie(refreshCookie) + req.AddCookie(&http.Cookie{Name: "ums_session_present", Value: "1"}) + + resp, err := (&http.Client{}).Do(req) + if err != nil { + t.Fatalf("refresh request failed: %v", err) + } + defer resp.Body.Close() + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read refresh response failed: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, string(bodyBytes)) + } +} + +func TestAuthHandler_SendEmailCode_Success(t *testing.T) { + server, cleanup := setupHandlerTestServer(t) + defer cleanup() + + registerUser(server.URL, "emailcodeuser", "emailcode@example.com", "Password123!") + + resp, body := doPost(server.URL+"/api/v1/auth/send-email-code", "", map[string]interface{}{ + "email": "emailcode@example.com", + }) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body) + } +} + +func TestAuthHandler_LoginByEmailCode_Success(t *testing.T) { + server, cleanup, cacheManager := setupHandlerTestServerWithCache(t) + defer cleanup() + + registerUser(server.URL, "emailloginuser", "emaillogin@example.com", "Password123!") + + sendResp, sendBody := doPost(server.URL+"/api/v1/auth/send-email-code", "", map[string]interface{}{ + "email": "emaillogin@example.com", + }) + defer sendResp.Body.Close() + if sendResp.StatusCode != http.StatusOK { + t.Fatalf("expected send status %d, got %d, body: %s", http.StatusOK, sendResp.StatusCode, sendBody) + } + + codeValue, ok := cacheManager.Get(context.Background(), "email_code:login:emaillogin@example.com") + if !ok { + t.Fatal("expected email login code to be stored in cache") + } + code, ok := codeValue.(string) + if !ok || code == "" { + t.Fatalf("expected cached email login code string, got %#v", codeValue) + } + + loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login-by-email-code", "", map[string]interface{}{ + "email": "emaillogin@example.com", + "code": code, + }) + defer loginResp.Body.Close() + + if loginResp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, loginResp.StatusCode, loginBody) + } +} + +func TestAuthHandler_ActivateEmail_InvalidToken(t *testing.T) { + server, cleanup, _ := setupHandlerTestServerWithActivation(t) + defer cleanup() + + resp, body := doPost(server.URL+"/api/v1/auth/activate-email", "", map[string]interface{}{ + "token": "invalid-token", + }) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusUnauthorized, resp.StatusCode, body) + } +} + +func TestAuthHandler_ActivateEmail_Success(t *testing.T) { + server, cleanup, cacheManager := setupHandlerTestServerWithActivation(t) + defer cleanup() + + resp, body := doPost(server.URL+"/api/v1/auth/register", "", map[string]interface{}{ + "username": "inactiveemailuser", + "email": "inactiveemailuser@example.com", + "password": "Password123!", + }) + defer resp.Body.Close() + if resp.StatusCode != http.StatusCreated { + t.Fatalf("expected register status %d, got %d, body: %s", http.StatusCreated, resp.StatusCode, body) + } + + const token = "known-activation-token" + if err := cacheManager.Set(context.Background(), "email_activation:"+token, int64(1), time.Hour, time.Hour); err != nil { + t.Fatalf("seed activation token failed: %v", err) + } + + activateResp, activateBody := doPost(server.URL+"/api/v1/auth/activate-email", "", map[string]interface{}{ + "token": token, + }) + defer activateResp.Body.Close() + + if activateResp.StatusCode != http.StatusOK { + t.Fatalf("expected activate status %d, got %d, body: %s", http.StatusOK, activateResp.StatusCode, activateBody) + } +} + +func TestAuthHandler_ResendActivationEmail_SuccessForUnknownEmail(t *testing.T) { + server, cleanup, _ := setupHandlerTestServerWithActivation(t) + defer cleanup() + + resp, body := doPost(server.URL+"/api/v1/auth/resend-activation", "", map[string]interface{}{ + "email": "unknown@example.com", + }) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body) + } +} + +func TestAuthHandler_ResendActivationEmail_SuccessForInactiveUser(t *testing.T) { + server, cleanup, _ := setupHandlerTestServerWithActivation(t) + defer cleanup() + + registerResp, registerBody := doPost(server.URL+"/api/v1/auth/register", "", map[string]interface{}{ + "username": "resendinactiveuser", + "email": "resendinactiveuser@example.com", + "password": "Password123!", + }) + defer registerResp.Body.Close() + if registerResp.StatusCode != http.StatusCreated { + t.Fatalf("expected register status %d, got %d, body: %s", http.StatusCreated, registerResp.StatusCode, registerBody) + } + + resp, body := doPost(server.URL+"/api/v1/auth/resend-activation", "", map[string]interface{}{ + "email": "resendinactiveuser@example.com", + }) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body) + } +} + // ============================================================================= // Avatar Handler Tests // ============================================================================= diff --git a/internal/api/handler/theme_handler_test.go b/internal/api/handler/theme_handler_test.go index b1f73ac..b7652e9 100644 --- a/internal/api/handler/theme_handler_test.go +++ b/internal/api/handler/theme_handler_test.go @@ -17,10 +17,6 @@ import ( "gorm.io/gorm/logger" ) -// ============================================================================= -// Theme Handler Tests - TDD approach -// ============================================================================= - func setupThemeTestEnv(t *testing.T) (*handler.ThemeHandler, *gorm.DB) { t.Helper() gin.SetMode(gin.TestMode) @@ -45,10 +41,22 @@ func setupThemeTestEnv(t *testing.T) (*handler.ThemeHandler, *gorm.DB) { return handler.NewThemeHandler(themeSvc), db } +func createThemeForTest(t *testing.T, h *handler.ThemeHandler, body string) { + t.Helper() + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("POST", "/api/v1/themes", bytes.NewReader([]byte(body))) + c.Request.Header.Set("Content-Type", "application/json") + h.CreateTheme(c) + if w.Code != http.StatusCreated { + t.Fatalf("create theme failed: %d %s", w.Code, w.Body.String()) + } +} + func TestThemeHandler_CreateTheme(t *testing.T) { h, _ := setupThemeTestEnv(t) - t.Run("创建主题成功", func(t *testing.T) { + t.Run("create success", func(t *testing.T) { w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) body := `{"name":"test-theme","primary_color":"#1976d2"}` @@ -58,20 +66,19 @@ func TestThemeHandler_CreateTheme(t *testing.T) { h.CreateTheme(c) if w.Code != http.StatusCreated { - t.Errorf("期望状态码 %d, 得到 %d", http.StatusCreated, w.Code) + t.Fatalf("expected status %d, got %d", http.StatusCreated, w.Code) } var resp map[string]interface{} if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("解析响应失败: %v", err) + t.Fatalf("decode response failed: %v", err) } - if resp["code"].(float64) != 0 { - t.Errorf("期望 code=0, 得到 %v", resp["code"]) + t.Fatalf("expected code=0, got %v", resp["code"]) } }) - t.Run("创建主题失败-缺少名称", func(t *testing.T) { + t.Run("create missing name", func(t *testing.T) { w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) body := `{"primary_color":"#1976d2"}` @@ -81,31 +88,30 @@ func TestThemeHandler_CreateTheme(t *testing.T) { h.CreateTheme(c) if w.Code != http.StatusBadRequest { - t.Errorf("期望状态码 %d, 得到 %d", http.StatusBadRequest, w.Code) + t.Fatalf("expected status %d, got %d", http.StatusBadRequest, w.Code) } }) } func TestThemeHandler_ListThemes(t *testing.T) { h, _ := setupThemeTestEnv(t) + createThemeForTest(t, h, `{"name":"list-theme","primary_color":"#1976d2"}`) - t.Run("获取主题列表", func(t *testing.T) { - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - c.Request = httptest.NewRequest("GET", "/api/v1/themes", nil) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("GET", "/api/v1/themes", nil) - h.ListThemes(c) + h.ListThemes(c) - if w.Code != http.StatusOK { - t.Errorf("期望状态码 %d, 得到 %d", http.StatusOK, w.Code) - } - }) + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } } func TestThemeHandler_GetTheme(t *testing.T) { h, _ := setupThemeTestEnv(t) - t.Run("获取主题失败-无效ID", func(t *testing.T) { + t.Run("get invalid id", func(t *testing.T) { w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Params = gin.Params{{Key: "id", Value: "invalid"}} @@ -114,7 +120,70 @@ func TestThemeHandler_GetTheme(t *testing.T) { h.GetTheme(c) if w.Code != http.StatusBadRequest { - t.Errorf("期望状态码 %d, 得到 %d", http.StatusBadRequest, w.Code) + t.Fatalf("expected status %d, got %d", http.StatusBadRequest, w.Code) + } + }) + + t.Run("get success", func(t *testing.T) { + createThemeForTest(t, h, `{"name":"get-theme","primary_color":"#1976d2"}`) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "1"}} + c.Request = httptest.NewRequest("GET", "/api/v1/themes/1", nil) + + h.GetTheme(c) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d, body=%s", http.StatusOK, w.Code, w.Body.String()) + } + }) +} + +func TestThemeHandler_UpdateTheme(t *testing.T) { + h, _ := setupThemeTestEnv(t) + createThemeForTest(t, h, `{"name":"theme-update","primary_color":"#111111"}`) + + t.Run("update success", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "1"}} + body := `{"primary_color":"#222222","enabled":true}` + c.Request = httptest.NewRequest("PUT", "/api/v1/themes/1", bytes.NewReader([]byte(body))) + c.Request.Header.Set("Content-Type", "application/json") + + h.UpdateTheme(c) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d, body=%s", http.StatusOK, w.Code, w.Body.String()) + } + }) + + t.Run("update invalid id", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "invalid"}} + c.Request = httptest.NewRequest("PUT", "/api/v1/themes/invalid", bytes.NewReader([]byte(`{}`))) + c.Request.Header.Set("Content-Type", "application/json") + + h.UpdateTheme(c) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d", http.StatusBadRequest, w.Code) + } + }) + + t.Run("update invalid json", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "1"}} + c.Request = httptest.NewRequest("PUT", "/api/v1/themes/1", bytes.NewReader([]byte(`{"primary_color":`))) + c.Request.Header.Set("Content-Type", "application/json") + + h.UpdateTheme(c) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d", http.StatusBadRequest, w.Code) } }) } @@ -122,7 +191,7 @@ func TestThemeHandler_GetTheme(t *testing.T) { func TestThemeHandler_DeleteTheme(t *testing.T) { h, _ := setupThemeTestEnv(t) - t.Run("删除主题失败-无效ID", func(t *testing.T) { + t.Run("delete invalid id", func(t *testing.T) { w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Params = gin.Params{{Key: "id", Value: "invalid"}} @@ -131,7 +200,90 @@ func TestThemeHandler_DeleteTheme(t *testing.T) { h.DeleteTheme(c) if w.Code != http.StatusBadRequest { - t.Errorf("期望状态码 %d, 得到 %d", http.StatusBadRequest, w.Code) + t.Fatalf("expected status %d, got %d", http.StatusBadRequest, w.Code) + } + }) + + t.Run("delete success", func(t *testing.T) { + createThemeForTest(t, h, `{"name":"theme-delete","primary_color":"#1976d2"}`) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "1"}} + c.Request = httptest.NewRequest("DELETE", "/api/v1/themes/1", nil) + + h.DeleteTheme(c) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d, body=%s", http.StatusOK, w.Code, w.Body.String()) + } + }) +} + +func TestThemeHandler_DefaultAndActiveFlows(t *testing.T) { + h, _ := setupThemeTestEnv(t) + createThemeForTest(t, h, `{"name":"default-theme","primary_color":"#111111","is_default":true}`) + createThemeForTest(t, h, `{"name":"other-theme","primary_color":"#222222"}`) + + t.Run("list all themes", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("GET", "/api/v1/themes/all", nil) + + h.ListAllThemes(c) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + }) + + t.Run("get default theme", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("GET", "/api/v1/themes/default", nil) + + h.GetDefaultTheme(c) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) + } + }) + + t.Run("set default invalid id", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "bad"}} + c.Request = httptest.NewRequest("PUT", "/api/v1/themes/bad/default", nil) + + h.SetDefaultTheme(c) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d", http.StatusBadRequest, w.Code) + } + }) + + t.Run("set default success", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "2"}} + c.Request = httptest.NewRequest("PUT", "/api/v1/themes/2/default", nil) + + h.SetDefaultTheme(c) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d, body=%s", http.StatusOK, w.Code, w.Body.String()) + } + }) + + t.Run("get active theme", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("GET", "/api/v1/themes/active", nil) + + h.GetActiveTheme(c) + + if w.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) } }) } diff --git a/internal/api/middleware/auth_bootstrap_test.go b/internal/api/middleware/auth_bootstrap_test.go index 65fe55a..87ab47f 100644 --- a/internal/api/middleware/auth_bootstrap_test.go +++ b/internal/api/middleware/auth_bootstrap_test.go @@ -2,6 +2,7 @@ package middleware import ( "context" + "errors" "net/http" "net/http/httptest" "testing" @@ -19,6 +20,68 @@ import ( _ "modernc.org/sqlite" ) +type authStubUserRepo struct { + user *domain.User + err error +} + +func (s authStubUserRepo) GetByID(_ context.Context, _ int64) (*domain.User, error) { + return s.user, s.err +} + +type authStubUserRoleRepo struct { + roles []*domain.Role + perms []*domain.Permission + err error +} + +func (s authStubUserRoleRepo) GetUserRolesAndPermissions(_ context.Context, _ int64) ([]*domain.Role, []*domain.Permission, error) { + return s.roles, s.perms, s.err +} + +func newTestJWT(t *testing.T) *auth.JWT { + t.Helper() + + jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{ + HS256Secret: "test-middleware-secret-at-least-32-chars", + AccessTokenExpire: 15 * time.Minute, + RefreshTokenExpire: 7 * 24 * time.Hour, + }) + if err != nil { + t.Fatalf("create jwt manager failed: %v", err) + } + + return jwtManager +} + +func newAuthMiddlewareForTest(t *testing.T, user *domain.User, roles []*domain.Role, perms []*domain.Permission) (*AuthMiddleware, *auth.JWT, *cache.L1Cache) { + t.Helper() + + jwtManager := newTestJWT(t) + l1Cache := cache.NewL1Cache() + middleware := NewAuthMiddleware(jwtManager, authStubUserRepo{user: user}, authStubUserRoleRepo{roles: roles, perms: perms}, l1Cache) + return middleware, jwtManager, l1Cache +} + +func performMiddlewareRequest(t *testing.T, middleware gin.HandlerFunc, authHeader string) *httptest.ResponseRecorder { + t.Helper() + + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + router := gin.New() + router.Use(middleware) + router.GET("/protected", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"code": 0}) + }) + + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + if authHeader != "" { + req.Header.Set("Authorization", authHeader) + } + router.ServeHTTP(recorder, req) + return recorder +} + func TestAuthMiddleware_AcceptsBootstrapAdminTokenImmediately(t *testing.T) { t.Helper() gin.SetMode(gin.TestMode) @@ -101,3 +164,269 @@ func TestAuthMiddleware_AcceptsBootstrapAdminTokenImmediately(t *testing.T) { t.Fatalf("expected bootstrap token to pass auth middleware immediately, got %d body: %s", recorder.Code, recorder.Body.String()) } } + +func TestAuthMiddleware_RequiredRejectsMissingToken(t *testing.T) { + middleware, _, _ := newAuthMiddlewareForTest(t, nil, nil, nil) + + recorder := performMiddlewareRequest(t, middleware.Required(), "") + + if recorder.Code != http.StatusUnauthorized { + t.Fatalf("expected 401 for missing token, got %d", recorder.Code) + } +} + +func TestAuthMiddleware_RequiredRejectsInvalidToken(t *testing.T) { + middleware, _, _ := newAuthMiddlewareForTest(t, nil, nil, nil) + + recorder := performMiddlewareRequest(t, middleware.Required(), "Bearer not-a-jwt") + + if recorder.Code != http.StatusUnauthorized { + t.Fatalf("expected 401 for invalid token, got %d", recorder.Code) + } +} + +func TestAuthMiddleware_RequiredRejectsBlacklistedToken(t *testing.T) { + user := &domain.User{ID: 7, Username: "alice", Status: domain.UserStatusActive} + middleware, jwtManager, l1Cache := newAuthMiddlewareForTest(t, user, nil, nil) + + token, err := jwtManager.GenerateAccessToken(user.ID, user.Username, 0) + if err != nil { + t.Fatalf("generate access token failed: %v", err) + } + claims, err := jwtManager.ValidateAccessToken(token) + if err != nil { + t.Fatalf("validate access token failed: %v", err) + } + l1Cache.Set("jwt_blacklist:"+claims.JTI, true, time.Minute) + + recorder := performMiddlewareRequest(t, middleware.Required(), "Bearer "+token) + + if recorder.Code != http.StatusUnauthorized { + t.Fatalf("expected 401 for blacklisted token, got %d", recorder.Code) + } +} + +func TestAuthMiddleware_RequiredRejectsInactiveUser(t *testing.T) { + user := &domain.User{ID: 8, Username: "disabled", Status: domain.UserStatusDisabled} + middleware, jwtManager, _ := newAuthMiddlewareForTest(t, user, nil, nil) + + token, err := jwtManager.GenerateAccessToken(user.ID, user.Username, 0) + if err != nil { + t.Fatalf("generate access token failed: %v", err) + } + + recorder := performMiddlewareRequest(t, middleware.Required(), "Bearer "+token) + + if recorder.Code != http.StatusUnauthorized { + t.Fatalf("expected 401 for inactive user, got %d", recorder.Code) + } +} + +func TestAuthMiddleware_RequiredInjectsIdentityAndAuthorizations(t *testing.T) { + gin.SetMode(gin.TestMode) + user := &domain.User{ID: 9, Username: "admin", Status: domain.UserStatusActive} + roles := []*domain.Role{{Code: "admin"}, {Code: "auditor"}} + perms := []*domain.Permission{{Code: "users:read"}, {Code: "users:write"}} + middleware, jwtManager, _ := newAuthMiddlewareForTest(t, user, roles, perms) + + token, err := jwtManager.GenerateAccessToken(user.ID, user.Username, 0) + if err != nil { + t.Fatalf("generate access token failed: %v", err) + } + + recorder := httptest.NewRecorder() + router := gin.New() + router.Use(middleware.Required()) + router.GET("/protected", func(c *gin.Context) { + if got := c.GetInt64("user_id"); got != user.ID { + t.Fatalf("user_id = %d, want %d", got, user.ID) + } + if got := c.GetString("username"); got != user.Username { + t.Fatalf("username = %q, want %q", got, user.Username) + } + roleCodes := GetRoleCodes(c) + if len(roleCodes) != 2 || roleCodes[0] != "admin" || roleCodes[1] != "auditor" { + t.Fatalf("unexpected role codes: %#v", roleCodes) + } + permCodes := GetPermissionCodes(c) + if len(permCodes) != 2 || permCodes[0] != "users:read" || permCodes[1] != "users:write" { + t.Fatalf("unexpected permission codes: %#v", permCodes) + } + c.JSON(http.StatusOK, gin.H{"code": 0}) + }) + + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusOK { + t.Fatalf("expected 200 for valid token, got %d body: %s", recorder.Code, recorder.Body.String()) + } +} + +func TestAuthMiddleware_OptionalAllowsAnonymousRequest(t *testing.T) { + middleware, _, _ := newAuthMiddlewareForTest(t, nil, nil, nil) + + recorder := performMiddlewareRequest(t, middleware.Optional(), "") + + if recorder.Code != http.StatusOK { + t.Fatalf("expected optional middleware to allow anonymous request, got %d", recorder.Code) + } +} + +func TestAuthMiddleware_OptionalInjectsIdentityForValidToken(t *testing.T) { + gin.SetMode(gin.TestMode) + user := &domain.User{ID: 21, Username: "optional-user", Status: domain.UserStatusActive} + roles := []*domain.Role{{Code: "viewer"}} + perms := []*domain.Permission{{Code: "users:read"}} + middleware, jwtManager, _ := newAuthMiddlewareForTest(t, user, roles, perms) + + token, err := jwtManager.GenerateAccessToken(user.ID, user.Username, 0) + if err != nil { + t.Fatalf("generate access token failed: %v", err) + } + + recorder := httptest.NewRecorder() + router := gin.New() + router.Use(middleware.Optional()) + router.GET("/optional", func(c *gin.Context) { + if got := c.GetInt64("user_id"); got != user.ID { + t.Fatalf("user_id = %d, want %d", got, user.ID) + } + if got := c.GetString("username"); got != user.Username { + t.Fatalf("username = %q, want %q", got, user.Username) + } + if got := GetRoleCodes(c); len(got) != 1 || got[0] != "viewer" { + t.Fatalf("role_codes = %#v, want [viewer]", got) + } + if got := GetPermissionCodes(c); len(got) != 1 || got[0] != "users:read" { + t.Fatalf("permission_codes = %#v, want [users:read]", got) + } + c.Status(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodGet, "/optional", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusOK { + t.Fatalf("expected valid optional auth request to pass, got %d", recorder.Code) + } +} + +func TestAuthMiddleware_ExtractTokenCases(t *testing.T) { + gin.SetMode(gin.TestMode) + middleware, _, _ := newAuthMiddlewareForTest(t, nil, nil, nil) + + testCases := []struct { + name string + header string + want string + }{ + {name: "missing header", header: "", want: ""}, + {name: "valid bearer", header: "Bearer abc.def", want: "abc.def"}, + {name: "lowercase bearer rejected", header: "bearer abc", want: ""}, + {name: "missing token value", header: "Bearer", want: ""}, + {name: "wrong scheme", header: "Basic abc", want: ""}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/protected", nil) + if tc.header != "" { + c.Request.Header.Set("Authorization", tc.header) + } + + if got := middleware.extractToken(c); got != tc.want { + t.Fatalf("extractToken() = %q, want %q", got, tc.want) + } + }) + } +} + +func TestAuthMiddleware_ValidateUserStateAndCacheInvalidation(t *testing.T) { + user := &domain.User{ + ID: 11, + Username: "cached-user", + Status: domain.UserStatusActive, + PasswordChangedAt: time.Unix(200, 0), + } + middleware, _, l1Cache := newAuthMiddlewareForTest(t, user, nil, nil) + + if got := middleware.validateUserState(context.Background(), user.ID, 150); got == "" { + t.Fatal("expected password-changed denial for stale token") + } + if _, ok := l1Cache.Get("user_state:11"); !ok { + t.Fatal("expected user state to be cached") + } + + middleware.InvalidateUserStateCache(user.ID) + if _, ok := l1Cache.Get("user_state:11"); ok { + t.Fatal("expected user state cache to be cleared") + } +} + +func TestAuthMiddleware_LoadUserRolesAndPermsCachesAndInvalidates(t *testing.T) { + user := &domain.User{ID: 12, Username: "role-user", Status: domain.UserStatusActive} + roles := []*domain.Role{{Code: "admin"}} + perms := []*domain.Permission{{Code: "users:read"}} + middleware, _, l1Cache := newAuthMiddlewareForTest(t, user, roles, perms) + + roleCodes, permCodes := middleware.loadUserRolesAndPerms(context.Background(), user.ID) + if len(roleCodes) != 1 || roleCodes[0] != "admin" { + t.Fatalf("unexpected role codes: %#v", roleCodes) + } + if len(permCodes) != 1 || permCodes[0] != "users:read" { + t.Fatalf("unexpected permission codes: %#v", permCodes) + } + if _, ok := l1Cache.Get("user_perms:12"); !ok { + t.Fatal("expected user permissions to be cached") + } + + middleware.InvalidateUserPermCache(user.ID) + if _, ok := l1Cache.Get("user_perms:12"); ok { + t.Fatal("expected user permission cache to be cleared") + } +} + +func TestAuthMiddleware_AddToBlacklistAndUserHelpers(t *testing.T) { + activeUser := &domain.User{ID: 13, Username: "active", Status: domain.UserStatusActive} + middleware, _, l1Cache := newAuthMiddlewareForTest(t, activeUser, nil, nil) + + middleware.AddToBlacklist("jti-1", time.Minute) + if _, ok := l1Cache.Get("jwt_blacklist:jti-1"); !ok { + t.Fatal("expected blacklist entry in cache") + } + + if !middleware.isUserActive(context.Background(), activeUser.ID) { + t.Fatal("expected active user to be active") + } + if middleware.isPasswordChangedSinceTokenIssued(context.Background(), activeUser.ID, 0) { + t.Fatal("expected zero token pce to skip password change check") + } + + changedUser := &domain.User{ + ID: 14, + Username: "changed", + Status: domain.UserStatusActive, + PasswordChangedAt: time.Unix(300, 0), + } + changedMiddleware, _, _ := newAuthMiddlewareForTest(t, changedUser, nil, nil) + if !changedMiddleware.isPasswordChangedSinceTokenIssued(context.Background(), changedUser.ID, 200) { + t.Fatal("expected password-changed helper to return true") + } +} + +func TestAuthMiddleware_UserHelpersHandleRepoFailures(t *testing.T) { + middleware, _, _ := newAuthMiddlewareForTest(t, nil, nil, nil) + middleware.userRepo = authStubUserRepo{err: errors.New("db down")} + + if middleware.isUserActive(context.Background(), 99) { + t.Fatal("expected repo failure to mark user inactive") + } + if got := middleware.validateUserState(context.Background(), 99, 0); got == "" { + t.Fatal("expected validateUserState to deny on repo failure") + } +} diff --git a/internal/api/middleware/ratelimit_test.go b/internal/api/middleware/ratelimit_test.go index 17d43c4..44dc1ea 100644 --- a/internal/api/middleware/ratelimit_test.go +++ b/internal/api/middleware/ratelimit_test.go @@ -6,6 +6,7 @@ import ( "net/http/httptest" "strconv" "testing" + "time" "github.com/gin-gonic/gin" @@ -138,3 +139,155 @@ func TestRateLimitMiddleware_Refresh_ScopesBudgetByRefreshTokenBody(t *testing.T t.Fatalf("request for refresh-token-b body after exhausting refresh-token-a budget returned %d, want %d", differentToken.Code, http.StatusOK) } } + +func TestExtractRefreshToken_PreservesRequestBody(t *testing.T) { + gin.SetMode(gin.TestMode) + + body := bytes.NewBufferString(`{"refresh_token":"refresh-token-a"}`) + req := httptest.NewRequest(http.MethodPost, "/auth/refresh", body) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = req + + if got := extractRefreshToken(c); got != "refresh-token-a" { + t.Fatalf("extractRefreshToken() = %q, want refresh-token-a", got) + } + + readBack := new(bytes.Buffer) + if _, err := readBack.ReadFrom(c.Request.Body); err != nil { + t.Fatalf("re-read body failed: %v", err) + } + if got := readBack.String(); got != `{"refresh_token":"refresh-token-a"}` { + t.Fatalf("request body after extraction = %q, want original JSON", got) + } +} + +func TestRateLimitMiddleware_CleanupRemovesExpiredLimiters(t *testing.T) { + middleware := NewRateLimitMiddleware(config.RateLimitConfig{}) + limiter := middleware.getOrCreateLimiter("login:ip:127.0.0.1", time.Millisecond, 1) + limiter.requests = []int64{time.Now().Add(-time.Second).UnixMilli()} + + middleware.Cleanup() + + if _, exists := middleware.limiters["login:ip:127.0.0.1"]; exists { + t.Fatal("expected expired limiter to be removed") + } +} + +func TestRateLimitMiddleware_ResolveLimiterKeyPrefersUserIDForAPI(t *testing.T) { + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/users/1", nil) + c.Params = gin.Params{{Key: "id", Value: "1"}} + c.Set("user_id", int64(99)) + + middleware := NewRateLimitMiddleware(config.RateLimitConfig{}) + key := middleware.resolveLimiterKey(c, "api") + + if key != "api:GET:/users/1:user:99" { + t.Fatalf("resolveLimiterKey() = %q, want api:GET:/users/1:user:99", key) + } +} + +func TestSlidingWindowLimiter_EnforcesCapacityWithinWindow(t *testing.T) { + limiter := NewSlidingWindowLimiter(time.Second, 2) + + if !limiter.Allow() { + t.Fatal("expected first request to pass") + } + if !limiter.Allow() { + t.Fatal("expected second request to pass") + } + if limiter.Allow() { + t.Fatal("expected third request to be rejected") + } +} + +func TestRateLimitMiddleware_StartCleanupStopsSafely(t *testing.T) { + middleware := NewRateLimitMiddleware(config.RateLimitConfig{}) + middleware.cleanupInt = 10 * time.Millisecond + stop := middleware.StartCleanup() + time.Sleep(25 * time.Millisecond) + stop() +} + +func TestRateLimitMiddleware_ResolveLimiterKeyRefreshFallsBackToIP(t *testing.T) { + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/auth/refresh", bytes.NewBufferString(`{}`)) + c.Request.RemoteAddr = "127.0.0.1:12345" + + middleware := NewRateLimitMiddleware(config.RateLimitConfig{}) + key := middleware.resolveLimiterKey(c, "refresh") + + if key != "refresh:ip:127.0.0.1" { + t.Fatalf("resolveLimiterKey() = %q, want refresh:ip:127.0.0.1", key) + } +} + +func TestFingerprintValue_IsDeterministic(t *testing.T) { + first := fingerprintValue("refresh-token-a") + second := fingerprintValue("refresh-token-a") + third := fingerprintValue("refresh-token-b") + + if first != second { + t.Fatalf("expected same input fingerprint to match: %q vs %q", first, second) + } + if first == third { + t.Fatalf("expected different inputs to produce different fingerprints: %q vs %q", first, third) + } +} + +func TestRateLimitMiddleware_RegisterAndLoginLimiters(t *testing.T) { + gin.SetMode(gin.TestMode) + + middleware := NewRateLimitMiddleware(config.RateLimitConfig{}) + router := gin.New() + router.POST("/register", middleware.Register(), func(c *gin.Context) { + c.Status(http.StatusOK) + }) + router.POST("/login", middleware.Login(), func(c *gin.Context) { + c.Status(http.StatusOK) + }) + + for i := 0; i < 10; i++ { + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/register", nil) + req.RemoteAddr = "127.0.0.1:12345" + router.ServeHTTP(recorder, req) + if recorder.Code != http.StatusOK { + t.Fatalf("register request %d returned %d, want %d", i+1, recorder.Code, http.StatusOK) + } + } + + registerOverflow := httptest.NewRecorder() + registerReq := httptest.NewRequest(http.MethodPost, "/register", nil) + registerReq.RemoteAddr = "127.0.0.1:12345" + router.ServeHTTP(registerOverflow, registerReq) + if registerOverflow.Code != http.StatusTooManyRequests { + t.Fatalf("register overflow returned %d, want %d", registerOverflow.Code, http.StatusTooManyRequests) + } + + for i := 0; i < 5; i++ { + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/login", nil) + req.RemoteAddr = "127.0.0.1:54321" + router.ServeHTTP(recorder, req) + if recorder.Code != http.StatusOK { + t.Fatalf("login request %d returned %d, want %d", i+1, recorder.Code, http.StatusOK) + } + } + + loginOverflow := httptest.NewRecorder() + loginReq := httptest.NewRequest(http.MethodPost, "/login", nil) + loginReq.RemoteAddr = "127.0.0.1:54321" + router.ServeHTTP(loginOverflow, loginReq) + if loginOverflow.Code != http.StatusTooManyRequests { + t.Fatalf("login overflow returned %d, want %d", loginOverflow.Code, http.StatusTooManyRequests) + } +} diff --git a/internal/api/middleware/runtime_test.go b/internal/api/middleware/runtime_test.go index 79e65bf..16ae957 100644 --- a/internal/api/middleware/runtime_test.go +++ b/internal/api/middleware/runtime_test.go @@ -1,15 +1,21 @@ package middleware import ( + "bytes" + "encoding/json" "errors" + "log" "net/http" "net/http/httptest" "strings" "testing" + "time" "github.com/gin-gonic/gin" "github.com/user-management-system/internal/config" + apierrors "github.com/user-management-system/internal/pkg/errors" + "github.com/user-management-system/internal/security" ) func TestCORS_UsesConfiguredOrigins(t *testing.T) { @@ -44,6 +50,31 @@ func TestCORS_UsesConfiguredOrigins(t *testing.T) { } } +func TestCORS_RejectsDisallowedOrigin(t *testing.T) { + gin.SetMode(gin.TestMode) + SetCORSConfig(config.CORSConfig{ + AllowedOrigins: []string{"https://app.example.com"}, + AllowCredentials: false, + }) + t.Cleanup(func() { + SetCORSConfig(config.CORSConfig{ + AllowedOrigins: []string{"*"}, + AllowCredentials: true, + }) + }) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil) + c.Request.Header.Set("Origin", "https://evil.example.com") + + CORS()(c) + + if recorder.Code != http.StatusForbidden { + t.Fatalf("expected 403, got %d", recorder.Code) + } +} + func TestSanitizeQuery_MasksSensitiveValues(t *testing.T) { raw := "token=abc123&foo=bar&access_token=xyz&secret=s1" sanitized := sanitizeQuery(raw) @@ -180,6 +211,23 @@ func TestTraceID_ExtractsExistingTraceID(t *testing.T) { } } +func TestTraceID_GetTraceIDHandlesMissingAndPresentValue(t *testing.T) { + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil) + + if got := GetTraceID(c); got != "" { + t.Fatalf("GetTraceID() = %q, want empty string", got) + } + + c.Set(TraceIDKey, "trace-123") + if got := GetTraceID(c); got != "trace-123" { + t.Fatalf("GetTraceID() = %q, want trace-123", got) + } +} + // ---------- Error handling middleware ---------- func TestErrorHandler_HandlesErrors(t *testing.T) { @@ -198,6 +246,35 @@ func TestErrorHandler_HandlesErrors(t *testing.T) { } } +func TestErrorHandler_ApplicationErrorPreservesStatusAndReason(t *testing.T) { + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + router := gin.New() + router.Use(ErrorHandler()) + router.GET("/users", func(c *gin.Context) { + _ = c.Error(apierrors.Forbidden("FORBIDDEN", "denied")) + }) + + req := httptest.NewRequest(http.MethodGet, "/users", nil) + router.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusForbidden { + t.Fatalf("expected status 403, got %d", recorder.Code) + } + + var body map[string]any + if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil { + t.Fatalf("unmarshal body failed: %v", err) + } + if got := body["reason"]; got != "FORBIDDEN" { + t.Fatalf("reason = %#v, want FORBIDDEN", got) + } + if got := body["message"]; got != "denied" { + t.Fatalf("message = %#v, want denied", got) + } +} + func TestRecover_HandlesPanic(t *testing.T) { gin.SetMode(gin.TestMode) @@ -216,3 +293,277 @@ func TestRecover_HandlesPanic(t *testing.T) { t.Fatalf("expected status 500 after panic, got %d", recorder.Code) } } + +func TestRecover_ReturnsInternalServerErrorPayload(t *testing.T) { + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + router := gin.New() + router.Use(Recover()) + router.GET("/panic", func(c *gin.Context) { + panic("boom") + }) + + req := httptest.NewRequest(http.MethodGet, "/panic", nil) + router.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusInternalServerError { + t.Fatalf("expected status 500 after panic, got %d", recorder.Code) + } + + var body map[string]any + if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil { + t.Fatalf("unmarshal body failed: %v", err) + } + if got := body["code"]; got != float64(http.StatusInternalServerError) { + t.Fatalf("code = %#v, want %d", got, http.StatusInternalServerError) + } +} + +func TestLogger_WritesSanitizedQueryAndErrorContext(t *testing.T) { + gin.SetMode(gin.TestMode) + + var buf bytes.Buffer + originalWriter := log.Writer() + log.SetOutput(&buf) + t.Cleanup(func() { + log.SetOutput(originalWriter) + }) + + recorder := httptest.NewRecorder() + router := gin.New() + router.Use(TraceID()) + router.Use(Logger()) + router.GET("/users", func(c *gin.Context) { + c.Set("user_id", int64(7)) + _ = c.Error(errors.New("boom")) + c.Status(http.StatusAccepted) + }) + + req := httptest.NewRequest(http.MethodGet, "/users?token=secret&name=alice", nil) + req.RemoteAddr = "203.0.113.5:1234" + req.Header.Set("User-Agent", "logger-test") + router.ServeHTTP(recorder, req) + + deadline := time.Now().Add(time.Second) + for time.Now().Before(deadline) && !strings.Contains(buf.String(), "[Query] /users?name=alice&token=%2A%2A%2A") { + time.Sleep(10 * time.Millisecond) + } + + logOutput := buf.String() + if !strings.Contains(logOutput, "[API]") { + t.Fatalf("expected API log entry, got %q", logOutput) + } + if !strings.Contains(logOutput, "user_id: 7") { + t.Fatalf("expected user id in logs, got %q", logOutput) + } + if !strings.Contains(logOutput, "[Error]") || !strings.Contains(logOutput, "boom") { + t.Fatalf("expected error log entry, got %q", logOutput) + } + if strings.Contains(logOutput, "token=secret") { + t.Fatalf("expected sanitized query string, got %q", logOutput) + } +} + +func TestLogger_DropsMalformedQueryString(t *testing.T) { + gin.SetMode(gin.TestMode) + + var buf bytes.Buffer + originalWriter := log.Writer() + log.SetOutput(&buf) + t.Cleanup(func() { + log.SetOutput(originalWriter) + }) + + recorder := httptest.NewRecorder() + router := gin.New() + router.Use(Logger()) + router.GET("/users", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodGet, "/users?bad=%zz", nil) + router.ServeHTTP(recorder, req) + + time.Sleep(25 * time.Millisecond) + if strings.Contains(buf.String(), "[Query]") { + t.Fatalf("expected malformed query to be skipped, got %q", buf.String()) + } +} + +func TestResponseWrapper_SkipsSSEAndBinaryResponses(t *testing.T) { + gin.SetMode(gin.TestMode) + + testCases := []struct { + name string + path string + contentType string + }{ + {name: "sse", path: "/stream", contentType: "text/event-stream"}, + {name: "binary", path: "/download", contentType: "application/octet-stream"}, + {name: "swagger", path: "/swagger/index.html", contentType: ""}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + router := gin.New() + router.Use(ResponseWrapper()) + router.GET(tc.path, func(c *gin.Context) { + c.Header("Content-Type", "application/json") + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + req := httptest.NewRequest(http.MethodGet, tc.path, nil) + if tc.contentType != "" { + req.Header.Set("Content-Type", tc.contentType) + } + router.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", recorder.Code) + } + if got := recorder.Body.String(); got != `{"ok":true}` { + t.Fatalf("body = %s, want raw payload", got) + } + }) + } +} + +func TestResponseWrapper_BufferMethodsTrackStatusAndBody(t *testing.T) { + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + wrapper := &responseWrapper{ + ResponseWriter: c.Writer, + body: bytes.NewBuffer(nil), + statusCode: http.StatusOK, + } + + if _, err := wrapper.Write([]byte("abc")); err != nil { + t.Fatalf("Write() error = %v", err) + } + if _, err := wrapper.WriteString("def"); err != nil { + t.Fatalf("WriteString() error = %v", err) + } + wrapper.WriteHeader(http.StatusAccepted) + + if got := wrapper.body.String(); got != "abcdef" { + t.Fatalf("buffered body = %q, want abcdef", got) + } + if wrapper.statusCode != http.StatusAccepted { + t.Fatalf("statusCode = %d, want %d", wrapper.statusCode, http.StatusAccepted) + } +} + +func TestIPFilter_RealIPAndInternalOnly(t *testing.T) { + gin.SetMode(gin.TestMode) + + filter := security.NewIPFilter() + middleware := NewIPFilterMiddleware(filter, IPFilterConfig{ + TrustProxy: true, + TrustedProxies: []string{"10.0.0.2"}, + }) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/protected", nil) + c.Request.RemoteAddr = "10.0.0.2:8080" + c.Request.Header.Set("X-Forwarded-For", "198.51.100.10, 10.0.0.2") + + if got := middleware.realIP(c); got != "198.51.100.10" { + t.Fatalf("realIP() = %q, want 198.51.100.10", got) + } + if !middleware.isTrustedProxy("10.0.0.2") { + t.Fatal("expected trusted proxy match") + } + if middleware.isTrustedProxy("10.0.0.3") { + t.Fatal("unexpected trusted proxy match") + } + + if !isPrivateIP("127.0.0.1") { + t.Fatal("expected loopback to be private") + } + if isPrivateIP("198.51.100.10") { + t.Fatal("expected public address to be non-private") + } + + allowed := httptest.NewRecorder() + allowedRouter := gin.New() + allowedRouter.Use(InternalOnly()) + allowedRouter.GET("/metrics", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + allowedReq := httptest.NewRequest(http.MethodGet, "/metrics", nil) + allowedReq.RemoteAddr = "127.0.0.1:12345" + allowedRouter.ServeHTTP(allowed, allowedReq) + if allowed.Code != http.StatusOK { + t.Fatalf("expected private IP to pass, got %d", allowed.Code) + } + + blocked := httptest.NewRecorder() + blockedRouter := gin.New() + blockedRouter.Use(InternalOnly()) + blockedRouter.GET("/metrics", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + blockedReq := httptest.NewRequest(http.MethodGet, "/metrics", nil) + blockedReq.RemoteAddr = "198.51.100.10:12345" + blockedRouter.ServeHTTP(blocked, blockedReq) + if blocked.Code != http.StatusForbidden { + t.Fatalf("expected public IP to be rejected, got %d", blocked.Code) + } +} + +func TestIPFilter_FilterAndFallbacks(t *testing.T) { + gin.SetMode(gin.TestMode) + + filter := security.NewIPFilter() + if err := filter.AddToBlacklist("198.51.100.10", "manual", time.Minute); err != nil { + t.Fatalf("AddToBlacklist() error = %v", err) + } + middleware := NewIPFilterMiddleware(filter, IPFilterConfig{}) + if middleware.GetFilter() != filter { + t.Fatal("expected GetFilter() to expose the original filter") + } + + blockedRecorder := httptest.NewRecorder() + blockedRouter := gin.New() + blockedRouter.Use(middleware.Filter()) + blockedRouter.GET("/protected", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + blockedReq := httptest.NewRequest(http.MethodGet, "/protected", nil) + blockedReq.RemoteAddr = "198.51.100.10:12345" + blockedRouter.ServeHTTP(blockedRecorder, blockedReq) + if blockedRecorder.Code != http.StatusForbidden { + t.Fatalf("expected blocked IP to be rejected, got %d", blockedRecorder.Code) + } + + allowedRecorder := httptest.NewRecorder() + allowedRouter := gin.New() + allowedRouter.Use(middleware.Filter()) + allowedRouter.GET("/protected", func(c *gin.Context) { + if got := c.GetString("client_ip"); got != "127.0.0.1" { + t.Fatalf("client_ip = %q, want 127.0.0.1", got) + } + c.Status(http.StatusOK) + }) + allowedReq := httptest.NewRequest(http.MethodGet, "/protected", nil) + allowedReq.RemoteAddr = "127.0.0.1:54321" + allowedRouter.ServeHTTP(allowedRecorder, allowedReq) + if allowedRecorder.Code != http.StatusOK { + t.Fatalf("expected allowed IP to pass, got %d", allowedRecorder.Code) + } + + trustedProxyMiddleware := NewIPFilterMiddleware(filter, IPFilterConfig{ + TrustProxy: true, + }) + proxyRecorder := httptest.NewRecorder() + proxyCtx, _ := gin.CreateTestContext(proxyRecorder) + proxyCtx.Request = httptest.NewRequest(http.MethodGet, "/protected", nil) + proxyCtx.Request.RemoteAddr = "10.0.0.2:8080" + proxyCtx.Request.Header.Set("X-Real-IP", "203.0.113.9") + if got := trustedProxyMiddleware.realIP(proxyCtx); got != "203.0.113.9" { + t.Fatalf("realIP() X-Real-IP fallback = %q, want 203.0.113.9", got) + } +} diff --git a/internal/pkg/errors/errors_test.go b/internal/pkg/errors/errors_test.go index 25e6290..b5c70fc 100644 --- a/internal/pkg/errors/errors_test.go +++ b/internal/pkg/errors/errors_test.go @@ -1,5 +1,3 @@ -//go:build unit - package errors import ( diff --git a/internal/pkg/ip/ip_test.go b/internal/pkg/ip/ip_test.go index 403b2d5..5ce26b2 100644 --- a/internal/pkg/ip/ip_test.go +++ b/internal/pkg/ip/ip_test.go @@ -1,5 +1,3 @@ -//go:build unit - package ip import (