diff --git a/internal/pkg/gemini/models_test.go b/internal/pkg/gemini/models_test.go index b80047f..e75caab 100644 --- a/internal/pkg/gemini/models_test.go +++ b/internal/pkg/gemini/models_test.go @@ -1,28 +1,49 @@ package gemini -import "testing" +import ( + "strings" + "testing" -func TestDefaultModels_ContainsImageModels(t *testing.T) { - t.Parallel() + "github.com/stretchr/testify/require" +) +func TestDefaultModels(t *testing.T) { models := DefaultModels() - byName := make(map[string]Model, len(models)) - for _, model := range models { - byName[model.Name] = model - } - - required := []string{ - "models/gemini-2.5-flash-image", - "models/gemini-3.1-flash-image", - } - - for _, name := range required { - model, ok := byName[name] - if !ok { - t.Fatalf("expected fallback model %q to exist", name) - } - if len(model.SupportedGenerationMethods) == 0 { - t.Fatalf("expected fallback model %q to advertise generation methods", name) - } + + // Should return 8 models + require.Len(t, models, 8) + + // Each model should have name and methods + for _, m := range models { + require.NotEmpty(t, m.Name) + require.True(t, strings.HasPrefix(m.Name, "models/")) + require.Contains(t, m.SupportedGenerationMethods, "generateContent") + require.Contains(t, m.SupportedGenerationMethods, "streamGenerateContent") + } +} + +func TestFallbackModelsList(t *testing.T) { + resp := FallbackModelsList() + require.Len(t, resp.Models, 8) +} + +func TestFallbackModel(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"", "models/unknown"}, + {"gemini-2.0-flash", "models/gemini-2.0-flash"}, + {"models/gemini-2.5-pro", "models/gemini-2.5-pro"}, + {"custom-model", "models/custom-model"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + model := FallbackModel(tt.input) + require.Equal(t, tt.expected, model.Name) + require.Contains(t, model.SupportedGenerationMethods, "generateContent") + require.Contains(t, model.SupportedGenerationMethods, "streamGenerateContent") + }) } } diff --git a/internal/pkg/geminicli/sanitize_test.go b/internal/pkg/geminicli/sanitize_test.go new file mode 100644 index 0000000..8267646 --- /dev/null +++ b/internal/pkg/geminicli/sanitize_test.go @@ -0,0 +1,103 @@ +package geminicli + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsBase64Char(t *testing.T) { + tests := []struct { + char byte + want bool + }{ + {'A', true}, {'Z', true}, + {'a', true}, {'z', true}, + {'0', true}, {'9', true}, + {'+', true}, {'/', true}, {'=', true}, + {'-', false}, {'_', false}, {' ', false}, + {'.', false}, {'\n', false}, + } + + for _, tt := range tests { + t.Run(string(tt.char), func(t *testing.T) { + got := isBase64Char(tt.char) + require.Equal(t, tt.want, got) + }) + } +} + +func TestTruncateBase64InMessage(t *testing.T) { + tests := []struct { + name string + msg string + want string + }{ + { + name: "no_base64", + msg: "This is a normal message without base64", + want: "This is a normal message without base64", + }, + { + name: "short_base64", + msg: "data:image/png;base64,abc123", + want: "data:image/png;base64,abc123", + }, + { + name: "long_base64_truncated", + msg: "data:image/png;base64," + strings.Repeat("a", 100), + want: "data:image/png;base64," + strings.Repeat("a", 50) + "...[truncated]", + }, + { + name: "multiple_base64", + msg: "start;base64," + strings.Repeat("b", 30) + " middle;base64," + strings.Repeat("c", 60), + want: "start;base64," + strings.Repeat("b", 30) + " middle;base64," + strings.Repeat("c", 50) + "...[truncated]", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := truncateBase64InMessage(tt.msg) + require.Equal(t, tt.want, got) + }) + } +} + +func TestSanitizeBodyForLogs(t *testing.T) { + tests := []struct { + name string + body string + check func(t *testing.T, got string) + }{ + { + name: "short_body_no_change", + body: "Short message", + check: func(t *testing.T, got string) { + require.Equal(t, "Short message", got) + }, + }, + { + name: "body_truncated", + body: strings.Repeat("x", 3000), + check: func(t *testing.T, got string) { + require.LessOrEqual(t, len(got), 2100) // maxLogBodyLen + "...[truncated]" + require.True(t, strings.HasSuffix(got, "...[truncated]")) + }, + }, + { + name: "body_with_base64_truncated", + body: "data:image/png;base64," + strings.Repeat("a", 100), + check: func(t *testing.T, got string) { + require.Contains(t, got, "...[truncated]") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := SanitizeBodyForLogs(tt.body) + tt.check(t, got) + }) + } +} diff --git a/internal/pkg/openai/constants_test.go b/internal/pkg/openai/constants_test.go new file mode 100644 index 0000000..de47599 --- /dev/null +++ b/internal/pkg/openai/constants_test.go @@ -0,0 +1,49 @@ +package openai + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDefaultModelIDs(t *testing.T) { + ids := DefaultModelIDs() + + // Should return same number of IDs as DefaultModels + require.Equal(t, len(DefaultModels), len(ids)) + + // Check all expected IDs are present + expected := []string{ + "gpt-5.4", + "gpt-5.4-mini", + "gpt-5.4-nano", + "gpt-5.3-codex", + "gpt-5.3-codex-spark", + "gpt-5.2", + "gpt-5.2-codex", + "gpt-5.1-codex-max", + "gpt-5.1-codex", + "gpt-5.1", + "gpt-5.1-codex-mini", + "gpt-5", + } + require.Equal(t, expected, ids) +} + +func TestDefaultModels(t *testing.T) { + // Verify DefaultModels is not empty + require.NotEmpty(t, DefaultModels) + + // Verify each model has required fields + for _, m := range DefaultModels { + require.NotEmpty(t, m.ID, "Model ID should not be empty") + require.NotEmpty(t, m.DisplayName, "DisplayName should not be empty") + require.NotEmpty(t, m.Object, "Object should not be empty") + require.NotEmpty(t, m.OwnedBy, "OwnedBy should not be empty") + require.NotZero(t, m.Created, "Created should not be zero") + } +} + +func TestDefaultTestModel(t *testing.T) { + require.Equal(t, "gpt-5.1-codex", DefaultTestModel) +}