Files
llm-intelligence/scripts/fetch_openrouter_test.go

133 lines
3.4 KiB
Go
Raw Normal View History

//go:build llm_script
package main
import (
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
)
// TestParseModels checks that parseModels decodes the fixture
// testdata/openrouter_models_sample.json into exactly three models:
//   - models[0]: id, name, context_length, capabilities, and pricing given as
//     input/output;
//   - models[1]: pricing given via the prompt/completion aliases, which must
//     populate Pricing.Input / Pricing.Output as a fallback;
//   - models[2]: an empty pricing object, which must yield zero prices.
func TestParseModels(t *testing.T) {
	// Read the sample from a fixture file instead of inlining the JSON.
	samplePath := filepath.Join("testdata", "openrouter_models_sample.json")
	raw, err := os.ReadFile(samplePath)
	if err != nil {
		t.Fatalf("读取样例文件失败: %v", err)
	}
	models, err := parseModels(raw)
	if err != nil {
		t.Fatalf("parseModels 失败: %v", err)
	}
	if len(models) != 3 {
		t.Fatalf("期望 3 条,实际 %d", len(models))
	}

	// checkPrice compares a parsed price with the expected value using a small
	// tolerance: exact == on floats is brittle if parseModels ever derives a
	// price arithmetically (e.g. rescaling a per-token price). %v prints the
	// full value, whereas %f would round 0.09999999999999999 to 0.100000 and
	// make a near-miss look like a match in the failure output. Callers pass
	// float64(...) so the helper does not depend on the exact float type of
	// the Pricing fields.
	const priceTol = 1e-6
	checkPrice := func(label string, got, want float64) {
		t.Helper()
		if d := got - want; d > priceTol || d < -priceTol {
			t.Errorf("%s: got %v, want %v", label, got, want)
		}
	}

	// models[0]: every field populated.
	m := models[0]
	if m.ID != "openai/gpt-4o" {
		t.Errorf("ID 错误: %q", m.ID)
	}
	if m.Name != "GPT-4o" {
		t.Errorf("Name 错误: %q", m.Name)
	}
	if m.ContextLength != 128000 {
		t.Errorf("ContextLength 错误: %d", m.ContextLength)
	}
	if len(m.Capabilities) != 3 {
		t.Errorf("Capabilities 长度错误: %d", len(m.Capabilities))
	}
	checkPrice("Pricing.Input 错误", float64(m.Pricing.Input), 2.5)
	checkPrice("Pricing.Output 错误", float64(m.Pricing.Output), 10.0)

	// models[1]: pricing falls back to the prompt/completion aliases.
	m2 := models[1]
	checkPrice("Input 回退 prompt 失败", float64(m2.Pricing.Input), 0.1)
	checkPrice("Output 回退 completion 失败", float64(m2.Pricing.Output), 0.3)

	// models[2]: an empty pricing object must yield zero prices.
	m3 := models[2]
	checkPrice("空 pricing 未返回 0 (input)", float64(m3.Pricing.Input), 0)
	checkPrice("空 pricing 未返回 0 (output)", float64(m3.Pricing.Output), 0)
}
// Test 2: run 无 API Key 时写入临时文件JSON 含 total 和 models 字段
func TestRunNoAPIKey(t *testing.T) {
tmpDir := t.TempDir()
outPath := filepath.Join(tmpDir, "models.json")
cfg := Config{OutPath: outPath}
err := run(cfg)
if err != nil {
t.Fatalf("run 失败: %v", err)
}
data, err := os.ReadFile(outPath)
if err != nil {
t.Fatalf("读取输出文件失败: %v", err)
}
var result map[string]any
if err := json.Unmarshal(data, &result); err != nil {
t.Fatalf("JSON 解析失败: %v", err)
}
if _, ok := result["total"]; !ok {
t.Error("JSON 缺少 total 字段")
}
if _, ok := result["models"]; !ok {
t.Error("JSON 缺少 models 字段")
}
models, ok := result["models"].([]any)
if !ok {
t.Fatal("models 字段类型错误")
}
if len(models) == 0 {
t.Error("models 为空")
}
}
func TestFetchModelsFailsInStrictRealModeWithoutAPIKey(t *testing.T) {
_, err := fetchModels(Config{StrictReal: true})
if err == nil {
t.Fatal("strict real mode should fail without API key")
}
}
func TestRunFailsInStrictRealModeWhenDBWriteFails(t *testing.T) {
tmpDir := t.TempDir()
outPath := filepath.Join(tmpDir, "models.json")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"data":[{"id":"openai/gpt-4o","name":"GPT-4o","context_length":128000,"pricing":{"input":2.5,"output":10.0}}]}`))
}))
defer server.Close()
err := run(Config{
APIKey: "test-key",
APIURL: server.URL,
OutPath: outPath,
DBConn: "postgres://invalid@127.0.0.1:1/invalid?sslmode=disable",
BatchSize: 10,
TimeoutSec: 1,
StrictReal: true,
})
if err == nil {
t.Fatal("strict real mode should fail when database write fails")
}
}