166 lines
4.8 KiB
Go
166 lines
4.8 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/user-management-system/internal/domain"
|
|
"github.com/user-management-system/internal/repository"
|
|
gormsqlite "gorm.io/driver/sqlite"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/logger"
|
|
_ "modernc.org/sqlite"
|
|
)
|
|
|
|
func newOperationLogRepositoryForTest(t *testing.T) *repository.OperationLogRepository {
|
|
t.Helper()
|
|
|
|
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
|
DriverName: "sqlite",
|
|
DSN: "file:operation_log_test?mode=memory&cache=shared",
|
|
}), &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Silent),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("open sqlite failed: %v", err)
|
|
}
|
|
|
|
if err := db.AutoMigrate(&domain.OperationLog{}); err != nil {
|
|
t.Fatalf("migrate failed: %v", err)
|
|
}
|
|
|
|
if err := db.Exec("DELETE FROM operation_logs").Error; err != nil {
|
|
t.Fatalf("cleanup operation_logs failed: %v", err)
|
|
}
|
|
|
|
return repository.NewOperationLogRepository(db)
|
|
}
|
|
|
|
func waitForOperationLogs(t *testing.T, repo *repository.OperationLogRepository, want int) []*domain.OperationLog {
|
|
t.Helper()
|
|
|
|
deadline := time.Now().Add(2 * time.Second)
|
|
for time.Now().Before(deadline) {
|
|
logs, _, err := repo.List(context.Background(), 0, 20)
|
|
if err != nil {
|
|
t.Fatalf("list operation logs failed: %v", err)
|
|
}
|
|
if len(logs) >= want {
|
|
return logs
|
|
}
|
|
time.Sleep(25 * time.Millisecond)
|
|
}
|
|
|
|
logs, _, err := repo.List(context.Background(), 0, 20)
|
|
if err != nil {
|
|
t.Fatalf("list operation logs failed: %v", err)
|
|
}
|
|
t.Fatalf("timed out waiting for %d operation logs, got %d", want, len(logs))
|
|
return nil
|
|
}
|
|
|
|
func TestOperationLogMiddleware_SkipsReadOnlyMethods(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
repo := newOperationLogRepositoryForTest(t)
|
|
router := gin.New()
|
|
router.Use(NewOperationLogMiddleware(repo).Record())
|
|
router.GET("/logs", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
|
|
recorder := httptest.NewRecorder()
|
|
router.ServeHTTP(recorder, req)
|
|
|
|
if recorder.Code != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d", recorder.Code)
|
|
}
|
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
logs, _, err := repo.List(context.Background(), 0, 20)
|
|
if err != nil {
|
|
t.Fatalf("list operation logs failed: %v", err)
|
|
}
|
|
if len(logs) != 0 {
|
|
t.Fatalf("expected no logs for GET request, got %d", len(logs))
|
|
}
|
|
}
|
|
|
|
func TestOperationLogMiddleware_RecordsAdminMutationAndSanitizesParams(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
repo := newOperationLogRepositoryForTest(t)
|
|
router := gin.New()
|
|
router.Use(func(c *gin.Context) {
|
|
c.Set("user_id", int64(42))
|
|
c.Set(ContextKeyRoleCodes, []string{"admin"})
|
|
c.Next()
|
|
})
|
|
router.Use(NewOperationLogMiddleware(repo).Record())
|
|
router.POST("/users", func(c *gin.Context) {
|
|
c.Status(http.StatusCreated)
|
|
})
|
|
|
|
body := `{"username":"alice","password":"super-secret","token":"abc"}`
|
|
req := httptest.NewRequest(http.MethodPost, "/users", strings.NewReader(body))
|
|
req.RemoteAddr = "203.0.113.10:8080"
|
|
req.Header.Set("User-Agent", "middleware-test")
|
|
recorder := httptest.NewRecorder()
|
|
router.ServeHTTP(recorder, req)
|
|
|
|
if recorder.Code != http.StatusCreated {
|
|
t.Fatalf("expected 201, got %d", recorder.Code)
|
|
}
|
|
|
|
logs := waitForOperationLogs(t, repo, 1)
|
|
entry := logs[0]
|
|
if entry.UserID == nil || *entry.UserID != 42 {
|
|
t.Fatalf("user_id = %#v, want 42", entry.UserID)
|
|
}
|
|
if entry.OperationType != "admin:CREATE" {
|
|
t.Fatalf("operation_type = %q, want admin:CREATE", entry.OperationType)
|
|
}
|
|
if entry.ResponseStatus != http.StatusCreated {
|
|
t.Fatalf("response_status = %d, want %d", entry.ResponseStatus, http.StatusCreated)
|
|
}
|
|
if strings.Contains(entry.RequestParams, "super-secret") || strings.Contains(entry.RequestParams, "abc") {
|
|
t.Fatalf("expected sanitized params, got %s", entry.RequestParams)
|
|
}
|
|
}
|
|
|
|
func TestOperationLogMiddleware_MethodToTypeAndSanitizeFallbacks(t *testing.T) {
|
|
if got := methodToType(http.MethodPatch); got != "UPDATE" {
|
|
t.Fatalf("methodToType(PATCH) = %q, want UPDATE", got)
|
|
}
|
|
if got := methodToType(http.MethodDelete); got != "DELETE" {
|
|
t.Fatalf("methodToType(DELETE) = %q, want DELETE", got)
|
|
}
|
|
if got := methodToType(http.MethodGet); got != "OTHER" {
|
|
t.Fatalf("methodToType(GET) = %q, want OTHER", got)
|
|
}
|
|
|
|
raw := []byte(`{"password":"secret","name":"alice"}`)
|
|
sanitized := sanitizeParams(raw)
|
|
if strings.Contains(sanitized, "secret") {
|
|
t.Fatalf("expected password to be masked, got %s", sanitized)
|
|
}
|
|
|
|
plain := sanitizeParams([]byte("not-json"))
|
|
if plain != "not-json" {
|
|
t.Fatalf("sanitizeParams(non-json) = %q, want not-json", plain)
|
|
}
|
|
|
|
var payload map[string]string
|
|
if err := json.Unmarshal([]byte(sanitized), &payload); err != nil {
|
|
t.Fatalf("unmarshal sanitized params failed: %v", err)
|
|
}
|
|
if payload["password"] != "***" {
|
|
t.Fatalf("password = %q, want ***", payload["password"])
|
|
}
|
|
}
|