320 lines
8.5 KiB
Go
320 lines
8.5 KiB
Go
package handler_test
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/user-management-system/internal/api/handler"
|
|
"github.com/user-management-system/internal/api/middleware"
|
|
"github.com/user-management-system/internal/api/router"
|
|
"github.com/user-management-system/internal/auth"
|
|
"github.com/user-management-system/internal/cache"
|
|
"github.com/user-management-system/internal/config"
|
|
"github.com/user-management-system/internal/domain"
|
|
"github.com/user-management-system/internal/repository"
|
|
"github.com/user-management-system/internal/service"
|
|
gormsqlite "gorm.io/driver/sqlite"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/logger"
|
|
)
|
|
|
|
// exportDbCounter is incremented atomically by setupExportTestServer so that
// each call gets its own in-memory SQLite database name and unique
// usernames/emails for the accounts it creates.
var exportDbCounter int64
|
|
|
|
func setupExportTestServer(t *testing.T) (*httptest.Server, string, string, func()) {
|
|
t.Helper()
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
id := atomic.AddInt64(&exportDbCounter, 1)
|
|
dsn := fmt.Sprintf("file:exportdb_%d_%s?mode=memory&cache=shared", id, t.Name())
|
|
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
|
DriverName: "sqlite",
|
|
DSN: dsn,
|
|
}), &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Silent),
|
|
})
|
|
if err != nil {
|
|
t.Skipf("skipping export test (SQLite unavailable): %v", err)
|
|
return nil, "", "", func() {}
|
|
}
|
|
|
|
if err := db.AutoMigrate(
|
|
&domain.User{},
|
|
&domain.Role{},
|
|
&domain.Permission{},
|
|
&domain.UserRole{},
|
|
&domain.RolePermission{},
|
|
); err != nil {
|
|
t.Fatalf("db migration failed: %v", err)
|
|
}
|
|
|
|
seedHandlerAuthzData(t, db)
|
|
|
|
jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
|
|
HS256Secret: "test-export-secret-key",
|
|
AccessTokenExpire: 15 * time.Minute,
|
|
RefreshTokenExpire: 7 * 24 * time.Hour,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("create jwt manager failed: %v", err)
|
|
}
|
|
|
|
l1Cache := cache.NewL1Cache()
|
|
l2Cache := cache.NewRedisCache(false)
|
|
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
|
|
|
|
userRepo := repository.NewUserRepository(db)
|
|
roleRepo := repository.NewRoleRepository(db)
|
|
userRoleRepo := repository.NewUserRoleRepository(db)
|
|
|
|
authSvc := service.NewAuthService(userRepo, nil, jwtManager, cacheManager, 8, 5, 15*time.Minute)
|
|
authSvc.SetRoleRepositories(userRoleRepo, roleRepo)
|
|
|
|
exportSvc := service.NewExportService(userRepo, nil)
|
|
exportHandler := handler.NewExportHandler(exportSvc)
|
|
|
|
rateLimitCfg := config.RateLimitConfig{}
|
|
rateLimitMiddleware := middleware.NewRateLimitMiddleware(rateLimitCfg)
|
|
authMiddleware := middleware.NewAuthMiddleware(
|
|
jwtManager, userRepo, userRoleRepo, l1Cache,
|
|
)
|
|
authMiddleware.SetCacheManager(cacheManager)
|
|
|
|
authHandler := handler.NewAuthHandler(authSvc)
|
|
|
|
r := router.NewRouter(
|
|
authHandler, nil, nil, nil, nil, nil,
|
|
authMiddleware, rateLimitMiddleware, nil,
|
|
nil, nil, nil, nil,
|
|
nil, exportHandler, nil, nil, nil, nil, nil, nil, nil,
|
|
)
|
|
engine := r.Setup()
|
|
server := httptest.NewServer(engine)
|
|
|
|
// Register a regular user
|
|
regBody := map[string]interface{}{
|
|
"username": fmt.Sprintf("exportuser_%d", id),
|
|
"password": "TestPass123!",
|
|
"email": fmt.Sprintf("ex_%d@test.com", id),
|
|
}
|
|
regBytes, _ := json.Marshal(regBody)
|
|
regResp, _ := http.Post(server.URL+"/api/v1/auth/register", "application/json", bytes.NewReader(regBytes))
|
|
io.ReadAll(regResp.Body)
|
|
regResp.Body.Close()
|
|
|
|
// Login as regular user
|
|
loginBody := map[string]interface{}{
|
|
"account": regBody["username"],
|
|
"password": regBody["password"],
|
|
}
|
|
loginBytes, _ := json.Marshal(loginBody)
|
|
loginResp, _ := http.Post(server.URL+"/api/v1/auth/login", "application/json", bytes.NewReader(loginBytes))
|
|
var loginResult struct {
|
|
Data struct {
|
|
AccessToken string `json:"access_token"`
|
|
} `json:"data"`
|
|
}
|
|
json.NewDecoder(loginResp.Body).Decode(&loginResult)
|
|
loginResp.Body.Close()
|
|
userToken := loginResult.Data.AccessToken
|
|
|
|
// Bootstrap admin
|
|
t.Setenv("BOOTSTRAP_SECRET", fmt.Sprintf("export-bootstrap-%d", id))
|
|
adminToken := bootstrapAdmin(server.URL, fmt.Sprintf("export-bootstrap-%d", id), fmt.Sprintf("exportadmin_%d", id), fmt.Sprintf("exa_%d@test.com", id), "AdminPass123!")
|
|
if adminToken == "" {
|
|
t.Fatal("bootstrap admin failed")
|
|
}
|
|
|
|
return server, adminToken, userToken, func() {
|
|
server.Close()
|
|
if sqlDB, err := db.DB(); err == nil {
|
|
sqlDB.Close()
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestExportHandler_ExportUsers(t *testing.T) {
|
|
server, adminToken, userToken, cleanup := setupExportTestServer(t)
|
|
defer cleanup()
|
|
|
|
tests := []struct {
|
|
name string
|
|
query string
|
|
token string
|
|
wantStatus int
|
|
}{
|
|
{
|
|
name: "success_csv",
|
|
query: "format=csv",
|
|
token: adminToken,
|
|
wantStatus: http.StatusOK,
|
|
},
|
|
{
|
|
name: "success_excel",
|
|
query: "format=xlsx",
|
|
token: adminToken,
|
|
wantStatus: http.StatusOK,
|
|
},
|
|
{
|
|
name: "forbidden_regular_user",
|
|
query: "format=csv",
|
|
token: userToken,
|
|
wantStatus: http.StatusForbidden,
|
|
},
|
|
{
|
|
name: "unauthorized",
|
|
query: "format=csv",
|
|
token: "",
|
|
wantStatus: http.StatusUnauthorized,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
url := server.URL + "/api/v1/admin/users/export"
|
|
if tt.query != "" {
|
|
url = url + "?" + tt.query
|
|
}
|
|
resp, body := doGet(url, tt.token)
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != tt.wantStatus {
|
|
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestExportHandler_ImportUsers(t *testing.T) {
|
|
server, adminToken, userToken, cleanup := setupExportTestServer(t)
|
|
defer cleanup()
|
|
|
|
csvData := []byte("\xEF\xBB\xBF用户名,密码,邮箱,手机号,昵称,性别,地区,个人简介\nimportuser1,Password123!,import1@test.com,13800138001,Import1,男,北京,简介1\n")
|
|
|
|
tests := []struct {
|
|
name string
|
|
fileBody []byte
|
|
filename string
|
|
token string
|
|
wantStatus int
|
|
}{
|
|
{
|
|
name: "success_csv",
|
|
fileBody: csvData,
|
|
filename: "users.csv",
|
|
token: adminToken,
|
|
wantStatus: http.StatusOK,
|
|
},
|
|
{
|
|
name: "forbidden_regular_user",
|
|
fileBody: csvData,
|
|
filename: "users.csv",
|
|
token: userToken,
|
|
wantStatus: http.StatusForbidden,
|
|
},
|
|
{
|
|
name: "unauthorized",
|
|
fileBody: csvData,
|
|
filename: "users.csv",
|
|
token: "",
|
|
wantStatus: http.StatusUnauthorized,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
var body bytes.Buffer
|
|
writer := multipart.NewWriter(&body)
|
|
part, err := writer.CreateFormFile("file", tt.filename)
|
|
if err != nil {
|
|
t.Fatalf("create form file failed: %v", err)
|
|
}
|
|
if _, err := part.Write(tt.fileBody); err != nil {
|
|
t.Fatalf("write file body failed: %v", err)
|
|
}
|
|
if err := writer.Close(); err != nil {
|
|
t.Fatalf("close multipart writer failed: %v", err)
|
|
}
|
|
|
|
req, err := http.NewRequest(http.MethodPost, server.URL+"/api/v1/admin/users/import", &body)
|
|
if err != nil {
|
|
t.Fatalf("create request failed: %v", err)
|
|
}
|
|
if tt.token != "" {
|
|
req.Header.Set("Authorization", "Bearer "+tt.token)
|
|
}
|
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
|
|
|
client := &http.Client{}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != tt.wantStatus {
|
|
respBody, _ := io.ReadAll(resp.Body)
|
|
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, string(respBody))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestExportHandler_GetImportTemplate(t *testing.T) {
|
|
server, adminToken, userToken, cleanup := setupExportTestServer(t)
|
|
defer cleanup()
|
|
|
|
tests := []struct {
|
|
name string
|
|
query string
|
|
token string
|
|
wantStatus int
|
|
}{
|
|
{
|
|
name: "success_csv",
|
|
query: "format=csv",
|
|
token: adminToken,
|
|
wantStatus: http.StatusOK,
|
|
},
|
|
{
|
|
name: "success_excel",
|
|
query: "format=xlsx",
|
|
token: adminToken,
|
|
wantStatus: http.StatusOK,
|
|
},
|
|
{
|
|
name: "forbidden_regular_user",
|
|
query: "format=csv",
|
|
token: userToken,
|
|
wantStatus: http.StatusForbidden,
|
|
},
|
|
{
|
|
name: "unauthorized",
|
|
query: "format=csv",
|
|
token: "",
|
|
wantStatus: http.StatusUnauthorized,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
url := server.URL + "/api/v1/admin/users/import/template"
|
|
if tt.query != "" {
|
|
url = url + "?" + tt.query
|
|
}
|
|
resp, body := doGet(url, tt.token)
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != tt.wantStatus {
|
|
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
|
}
|
|
})
|
|
}
|
|
}
|