Files
user-system/internal/api/middleware/gzip_test.go

103 lines
2.6 KiB
Go

package middleware
import (
"bytes"
"compress/gzip"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
)
// TestGzipMiddleware_CompressesLargeJSONResponses verifies that a JSON
// response larger than gzipMinLength is gzip-encoded when the client sends
// "Accept-Encoding: gzip", and that the encoded body decompresses back to
// exactly the payload the handler wrote.
func TestGzipMiddleware_CompressesLargeJSONResponses(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// Comfortably above the threshold so body size is never the reason the
	// middleware would skip compression.
	want := strings.Repeat("a", gzipMinLength+128)

	recorder := httptest.NewRecorder()
	router := gin.New()
	router.Use(GzipMiddleware())
	router.GET("/data", func(c *gin.Context) {
		// Content type is set explicitly because the middleware's decision
		// depends on it (see the "unsupported content type" pass-through case).
		c.Header("Content-Type", "application/json")
		c.String(http.StatusOK, want)
	})

	req := httptest.NewRequest(http.MethodGet, "/data", nil)
	req.Header.Set("Accept-Encoding", "gzip")
	router.ServeHTTP(recorder, req)

	if recorder.Code != http.StatusOK {
		t.Fatalf("status = %d, want %d", recorder.Code, http.StatusOK)
	}
	if got := recorder.Header().Get("Content-Encoding"); got != "gzip" {
		t.Fatalf("Content-Encoding = %q, want gzip", got)
	}

	reader, err := gzip.NewReader(bytes.NewReader(recorder.Body.Bytes()))
	if err != nil {
		t.Fatalf("gzip.NewReader() error = %v", err)
	}
	defer reader.Close()

	// Reading to EOF also makes gzip.Reader validate the stream trailer
	// (CRC-32 and size), so a truncated or corrupt stream surfaces here.
	payload, err := io.ReadAll(reader)
	if err != nil {
		t.Fatalf("ReadAll() error = %v", err)
	}

	// Report truncated content as well as lengths: a length-only message is
	// misleading when the lengths match but the bytes differ.
	if got := string(payload); got != want {
		t.Fatalf("decompressed payload = %.32q (%d bytes), want %.32q (%d bytes)",
			got, len(got), want, len(want))
	}
}
// TestGzipMiddleware_PassesThroughWhenCompressionNotUseful verifies that the
// middleware leaves the response untouched (no Content-Encoding header, body
// byte-for-byte identical) when the client does not accept gzip, when the body
// is below gzipMinLength, or when the content type is not one it compresses.
func TestGzipMiddleware_PassesThroughWhenCompressionNotUseful(t *testing.T) {
	gin.SetMode(gin.TestMode)

	testCases := []struct {
		name           string
		acceptEncoding string // empty means the header is not sent at all
		contentType    string
		body           string
	}{
		{
			name:           "client does not accept gzip",
			acceptEncoding: "",
			contentType:    "application/json",
			body:           strings.Repeat("b", gzipMinLength+64),
		},
		{
			name:           "body below threshold",
			acceptEncoding: "gzip",
			contentType:    "application/json",
			body:           "small-body",
		},
		{
			name:           "unsupported content type",
			acceptEncoding: "gzip",
			contentType:    "image/png",
			body:           strings.Repeat("c", gzipMinLength+64),
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			recorder := httptest.NewRecorder()
			router := gin.New()
			router.Use(GzipMiddleware())
			router.GET("/data", func(c *gin.Context) {
				c.Header("Content-Type", tc.contentType)
				c.String(http.StatusOK, tc.body)
			})

			req := httptest.NewRequest(http.MethodGet, "/data", nil)
			if tc.acceptEncoding != "" {
				req.Header.Set("Accept-Encoding", tc.acceptEncoding)
			}
			router.ServeHTTP(recorder, req)

			if recorder.Code != http.StatusOK {
				t.Fatalf("status = %d, want %d", recorder.Code, http.StatusOK)
			}
			if got := recorder.Header().Get("Content-Encoding"); got != "" {
				t.Fatalf("Content-Encoding = %q, want empty", got)
			}
			// Report truncated content as well as lengths: a length-only
			// message is misleading when lengths match but bytes differ.
			if got := recorder.Body.String(); got != tc.body {
				t.Fatalf("body = %.32q (%d bytes), want %.32q (%d bytes)",
					got, len(got), tc.body, len(tc.body))
			}
		})
	}
}