Files
user-system/internal/api/middleware/operation_log_test.go

166 lines
4.8 KiB
Go

package middleware
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/repository"
gormsqlite "gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
_ "modernc.org/sqlite"
)
// newOperationLogRepositoryForTest returns an OperationLogRepository backed by
// an in-memory SQLite database with the operation_logs table migrated and empty.
//
// Each test gets its own named database derived from t.Name(). Tests therefore
// never observe each other's rows, including rows written asynchronously after
// an earlier test returned, and they may safely run in parallel.
// cache=shared makes every connection in gorm's pool see the same in-memory
// database. The database is closed via t.Cleanup, which releases it.
func newOperationLogRepositoryForTest(t *testing.T) *repository.OperationLogRepository {
	t.Helper()

	// Subtest names may contain characters that are significant in a URI.
	dbName := strings.NewReplacer("/", "_", "?", "_", "#", "_", " ", "_").Replace(t.Name())
	dsn := "file:operation_log_test_" + dbName + "?mode=memory&cache=shared"

	db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
		DriverName: "sqlite",
		DSN:        dsn,
	}), &gorm.Config{
		Logger: logger.Default.LogMode(logger.Silent),
	})
	if err != nil {
		t.Fatalf("open sqlite failed: %v", err)
	}

	sqlDB, err := db.DB()
	if err != nil {
		t.Fatalf("get underlying sql.DB failed: %v", err)
	}
	t.Cleanup(func() {
		if err := sqlDB.Close(); err != nil {
			t.Errorf("close sqlite failed: %v", err)
		}
	})

	if err := db.AutoMigrate(&domain.OperationLog{}); err != nil {
		t.Fatalf("migrate failed: %v", err)
	}
	// Defensive: the per-test database should already be empty.
	if err := db.Exec("DELETE FROM operation_logs").Error; err != nil {
		t.Fatalf("cleanup operation_logs failed: %v", err)
	}
	return repository.NewOperationLogRepository(db)
}
func waitForOperationLogs(t *testing.T, repo *repository.OperationLogRepository, want int) []*domain.OperationLog {
t.Helper()
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
logs, _, err := repo.List(context.Background(), 0, 20)
if err != nil {
t.Fatalf("list operation logs failed: %v", err)
}
if len(logs) >= want {
return logs
}
time.Sleep(25 * time.Millisecond)
}
logs, _, err := repo.List(context.Background(), 0, 20)
if err != nil {
t.Fatalf("list operation logs failed: %v", err)
}
t.Fatalf("timed out waiting for %d operation logs, got %d", want, len(logs))
return nil
}
// TestOperationLogMiddleware_SkipsReadOnlyMethods verifies that a GET request
// passes through the operation-log middleware with its status intact and
// leaves no row in the operation log repository.
func TestOperationLogMiddleware_SkipsReadOnlyMethods(t *testing.T) {
	gin.SetMode(gin.TestMode)
	repo := newOperationLogRepositoryForTest(t)

	engine := gin.New()
	engine.Use(NewOperationLogMiddleware(repo).Record())
	engine.GET("/logs", func(c *gin.Context) { c.Status(http.StatusOK) })

	w := httptest.NewRecorder()
	engine.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/logs", nil))
	if got := w.Code; got != http.StatusOK {
		t.Fatalf("expected 200, got %d", got)
	}

	// An absence cannot be polled for. Writes appear to be asynchronous (the
	// other tests poll for them), so allow a fixed window for any stray write
	// to land before inspecting the table.
	time.Sleep(100 * time.Millisecond)

	logs, _, err := repo.List(context.Background(), 0, 20)
	switch {
	case err != nil:
		t.Fatalf("list operation logs failed: %v", err)
	case len(logs) != 0:
		t.Fatalf("expected no logs for GET request, got %d", len(logs))
	}
}
// TestOperationLogMiddleware_RecordsAdminMutationAndSanitizesParams sends a
// POST with admin identity in the gin context and checks the persisted entry:
// - the user ID;
// - the role-prefixed operation type ("admin:CREATE");
// - the response status;
// - that sensitive body values (password, token) are not stored verbatim.
func TestOperationLogMiddleware_RecordsAdminMutationAndSanitizesParams(t *testing.T) {
	gin.SetMode(gin.TestMode)
	repo := newOperationLogRepositoryForTest(t)

	router := gin.New()
	// Stand-in for upstream auth: populate the identity keys this test expects
	// the operation-log middleware to read.
	router.Use(func(c *gin.Context) {
		c.Set("user_id", int64(42))
		c.Set(ContextKeyRoleCodes, []string{"admin"})
		c.Next()
	})
	router.Use(NewOperationLogMiddleware(repo).Record())
	router.POST("/users", func(c *gin.Context) {
		c.Status(http.StatusCreated)
	})

	body := `{"username":"alice","password":"super-secret","token":"abc"}`
	req := httptest.NewRequest(http.MethodPost, "/users", strings.NewReader(body))
	req.RemoteAddr = "203.0.113.10:8080"
	req.Header.Set("User-Agent", "middleware-test")
	recorder := httptest.NewRecorder()
	router.ServeHTTP(recorder, req)
	if recorder.Code != http.StatusCreated {
		t.Fatalf("expected 201, got %d", recorder.Code)
	}

	entry := waitForOperationLogs(t, repo, 1)[0]

	// Dereference before printing: formatting the pointer itself only shows an
	// address, which hides the actual wrong value.
	if entry.UserID == nil {
		t.Fatal("user_id = nil, want 42")
	}
	if *entry.UserID != 42 {
		t.Fatalf("user_id = %d, want 42", *entry.UserID)
	}
	if entry.OperationType != "admin:CREATE" {
		t.Fatalf("operation_type = %q, want admin:CREATE", entry.OperationType)
	}
	if entry.ResponseStatus != http.StatusCreated {
		t.Fatalf("response_status = %d, want %d", entry.ResponseStatus, http.StatusCreated)
	}
	for _, secret := range []string{"super-secret", "abc"} {
		if strings.Contains(entry.RequestParams, secret) {
			t.Fatalf("expected sanitized params, got %s", entry.RequestParams)
		}
	}
}
// TestOperationLogMiddleware_MethodToTypeAndSanitizeFallbacks exercises the
// middleware's pure helpers:
// - methodToType maps PATCH→UPDATE and DELETE→DELETE, and anything else
//   (e.g. GET) to OTHER;
// - sanitizeParams masks a JSON "password" field as "***", still yields valid
//   JSON, and returns non-JSON input unchanged.
func TestOperationLogMiddleware_MethodToTypeAndSanitizeFallbacks(t *testing.T) {
	methodCases := []struct {
		method string
		want   string
	}{
		{http.MethodPatch, "UPDATE"},
		{http.MethodDelete, "DELETE"},
		{http.MethodGet, "OTHER"},
	}
	for _, tc := range methodCases {
		if got := methodToType(tc.method); got != tc.want {
			t.Fatalf("methodToType(%s) = %q, want %s", tc.method, got, tc.want)
		}
	}

	masked := sanitizeParams([]byte(`{"password":"secret","name":"alice"}`))
	if strings.Contains(masked, "secret") {
		t.Fatalf("expected password to be masked, got %s", masked)
	}

	if passthrough := sanitizeParams([]byte("not-json")); passthrough != "not-json" {
		t.Fatalf("sanitizeParams(non-json) = %q, want not-json", passthrough)
	}

	var fields map[string]string
	if err := json.Unmarshal([]byte(masked), &fields); err != nil {
		t.Fatalf("unmarshal sanitized params failed: %v", err)
	}
	if got := fields["password"]; got != "***" {
		t.Fatalf("password = %q, want ***", got)
	}
}