Files
llm-intelligence/scripts/fetch_openrouter.go

627 lines
17 KiB
Go
Raw Normal View History

//go:build llm_script
// fetch_openrouter.go - OpenRouter 模型数据采集器 v2.0
// Sprint 2 增强版:指数退避重试 + 批量插入 + ProviderMapper + audit_log + 价格变动检测 + slog
package main
import (
	"bufio"
	"context"
	"database/sql"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"log/slog"
	"math"
	"net/http"
	"os"
	"strconv"
	"strings"
	"time"

	"llm-intelligence/internal/collectors"
	"llm-intelligence/internal/retry"

	_ "github.com/lib/pq"
)
// Config holds the settings of one collector run, as produced by parseArgs.
type Config struct {
	// APIKey is sent as a Bearer token. When it is empty, mock data is used,
	// or the run fails if StrictReal is set.
	APIKey string
	// APIURL is the models endpoint that is queried.
	APIURL string
	// OutPath is the JSON file the summary is written to.
	OutPath string
	// MaxRetries, TimeoutSec and BatchSize are, respectively: the retry budget
	// for a failed fetch, the per-request HTTP timeout in seconds, and the
	// number of models written per database transaction.
	MaxRetries, TimeoutSec, BatchSize int
	// DBConn is the PostgreSQL connection string; empty disables database writes.
	DBConn string
	// StrictReal turns a missing API key or a failed database write into an
	// error instead of a fallback.
	StrictReal bool
}
// ModelInfo is the normalized description of one model. It is filled by
// parseModels, written to the JSON output file and marshaled into
// models.raw_payload. Its shape is intended to stay compatible with the
// collectors package.
type ModelInfo struct {
	ID            string   `json:"id"`
	Name          string   `json:"name,omitempty"`
	Created       int64    `json:"created,omitempty"` // source-reported creation time (presumably Unix seconds — confirm)
	Description   string   `json:"description,omitempty"`
	ContextLength int      `json:"context_length,omitempty"`
	Capabilities  []string `json:"capabilities,omitempty"`
	// Pricing is always encoded. encoding/json never omits struct-typed
	// fields, so an omitempty option here would have no effect.
	Pricing ModelPricing `json:"pricing"`
}
// ModelPricing is a model's input/output price pair. Both fields use
// omitempty, so a zero price does not appear in the JSON encoding.
//
// NOTE(review): the unit is presumably USD per million tokens, matching the
// region_pricing *_per_mtok columns and the mock data; confirm for live data.
type ModelPricing struct {
	// Input is the price of prompt (input) tokens.
	Input float64 `json:"input,omitempty"`
	// Output is the price of completion (output) tokens.
	Output float64 `json:"output,omitempty"`
}
var (
	// collectorVersion is logged at startup and stored in models.collector_version.
	collectorVersion = "v2.0"

	// logger is the package-wide structured logger; init configures it.
	logger *slog.Logger
)
// init installs the package logger: JSON records on stderr, Info level and above.
func init() {
	opts := &slog.HandlerOptions{Level: slog.LevelInfo}
	handler := slog.NewJSONHandler(os.Stderr, opts)
	logger = slog.New(handler)
}
// main parses the flags and performs one collection run. When a database is
// configured it also records run statistics. The process exits with status 1
// if the run failed.
func main() {
	cfg := parseArgs()
	startedAt := time.Now()
	logger.Info("采集器启动", "collector", "openrouter", "version", collectorVersion, "batch_size", cfg.BatchSize)

	runErr := run(cfg)
	if runErr != nil {
		logger.Error("采集失败", "error", runErr, "duration", time.Since(startedAt))
	}
	elapsed := time.Since(startedAt)

	// Statistics are best effort: a failure is logged but never changes the exit status.
	if cfg.DBConn != "" {
		if statsErr := recordCollectorStats(cfg.DBConn, runErr, elapsed); statsErr != nil {
			logger.Warn("采集统计写入失败", "error", statsErr)
		}
	}

	if runErr != nil {
		os.Exit(1)
	}
	logger.Info("采集完成", "collector", "openrouter", "duration_ms", elapsed.Milliseconds())
}
// parseArgs loads the project .env files and then parses the command-line
// flags into a Config. The env files are loaded first so that the -db default
// can come from DATABASE_URL defined in one of them.
func parseArgs() Config {
	loadProjectEnv()

	var cfg Config
	flag.StringVar(&cfg.APIKey, "api-key", "", "OpenRouter API Key")
	flag.StringVar(&cfg.APIURL, "api-url", "https://openrouter.ai/api/v1/models", "API 地址")
	flag.StringVar(&cfg.OutPath, "out", "models.json", "输出文件路径")
	flag.IntVar(&cfg.MaxRetries, "retry", 3, "最大重试次数")
	flag.IntVar(&cfg.TimeoutSec, "timeout", 30, "请求超时(秒)")
	flag.IntVar(&cfg.BatchSize, "batch", 100, "批量插入批次大小")
	flag.StringVar(&cfg.DBConn, "db", os.Getenv("DATABASE_URL"), "PostgreSQL 连接字符串")
	flag.BoolVar(&cfg.StrictReal, "strict-real", false, "严格真实模式:缺少 API Key 或数据库写入失败时返回错误")
	flag.Parse()
	return cfg
}
// loadProjectEnv loads .env.local first and .env second. loadEnvFile never
// overrides a variable that is already set, so the real environment wins over
// .env.local, and .env.local wins over .env.
func loadProjectEnv() {
	loadEnvFile(".env.local")
	loadEnvFile(".env")
}
// loadEnvFile reads KEY=VALUE lines from the dotenv-style file at path. Each
// pair is exported with os.Setenv unless the variable is already set, so
// existing values always win.
//
// Supported syntax:
//   - blank lines and lines starting with '#' are ignored;
//   - an optional leading "export " (as written for shells) is stripped;
//   - whitespace around the key and the value is trimmed;
//   - a value wrapped in one matching pair of '"' or '\'' quotes is unwrapped.
//     Unmatched quote characters are kept, so a value such as abc' is not
//     silently truncated.
//
// The file is optional. A missing or unreadable file is ignored, and so is a
// read error part-way through; in that case the pairs read so far are kept.
func loadEnvFile(path string) {
	f, err := os.Open(path)
	if err != nil {
		return
	}
	defer f.Close()

	scanner := bufio.NewScanner(f)
	// Allow long values (certificates, JSON blobs) beyond the 64 KiB default line limit.
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		line = strings.TrimPrefix(line, "export ")
		key, value, ok := strings.Cut(line, "=")
		if !ok {
			continue
		}
		key = strings.TrimSpace(key)
		if key == "" {
			continue
		}
		if _, exists := os.LookupEnv(key); exists {
			continue
		}
		_ = os.Setenv(key, unquoteEnvValue(strings.TrimSpace(value)))
	}
	// scanner.Err() is deliberately not checked. The env file is optional, and
	// a truncated read only means fewer defaults were loaded.
}

// unquoteEnvValue strips one matching pair of surrounding '"' or '\'' quotes from v.
func unquoteEnvValue(v string) string {
	if len(v) >= 2 {
		if q := v[0]; (q == '"' || q == '\'') && v[len(v)-1] == q {
			return v[1 : len(v)-1]
		}
	}
	return v
}
// run fetches the model list, optionally writes it to PostgreSQL and always
// finishes by writing the JSON summary. A database failure is fatal only when
// cfg.StrictReal is set; otherwise the run degrades to JSON-only output.
func run(cfg Config) error {
	models, err := fetchModels(cfg)
	if err != nil {
		return err
	}
	logger.Info("API 数据获取完成", "records", len(models))

	if cfg.DBConn == "" {
		return summarize(cfg.OutPath, models)
	}

	dbErr := summarizeDB(cfg.DBConn, models, cfg.BatchSize)
	switch {
	case dbErr == nil:
		logger.Info("PostgreSQL 写入完成", "records", len(models))
	case cfg.StrictReal:
		logger.Error("PostgreSQL 写入失败", "error", dbErr)
		return fmt.Errorf("PostgreSQL 写入失败: %w", dbErr)
	default:
		logger.Error("PostgreSQL 写入失败", "error", dbErr)
		logger.Warn("降级为仅写入 JSON")
	}
	return summarize(cfg.OutPath, models)
}
// fetchModels fetches the model list from cfg.APIURL and parses it with parseModels.
//
// Without an API key it returns a small built-in mock list of two models. If
// cfg.StrictReal is set, a missing key is an error instead.
//
// Failed attempts are retried with exponential backoff (1s base, 30s cap,
// ×2, jitter), up to cfg.MaxRetries retries; retry.IsRetryable decides which
// errors are retried. The returned error reports how many attempts were
// actually made and wraps the last failure.
func fetchModels(cfg Config) ([]ModelInfo, error) {
	if cfg.APIKey == "" {
		if cfg.StrictReal {
			return nil, fmt.Errorf("严格真实模式下必须提供 API Key")
		}
		logger.Warn("未提供 API Key使用模拟数据")
		return []ModelInfo{
			{ID: "openai/gpt-4o", ContextLength: 128000, Pricing: ModelPricing{Input: 2.5, Output: 10.0}},
			{ID: "anthropic/claude-3.5-sonnet:free", ContextLength: 200000, Pricing: ModelPricing{}},
		}, nil
	}

	strategy := retry.Strategy{
		MaxRetries: cfg.MaxRetries,
		BaseDelay:  1 * time.Second,
		MaxDelay:   30 * time.Second,
		Multiplier: 2.0,
		Jitter:     true,
		Retryable:  retry.IsRetryable,
	}
	ctx := context.Background()
	// One client for all attempts lets the transport reuse keep-alive connections.
	client := &http.Client{Timeout: time.Duration(cfg.TimeoutSec) * time.Second}

	var (
		models   []ModelInfo
		lastErr  error
		attempts int
	)
	err := retry.Do(ctx, strategy, func() error {
		attempts++
		parsed, fetchErr := fetchModelsOnce(ctx, client, cfg)
		if fetchErr != nil {
			// Record every failure so the final error always carries a cause.
			lastErr = fetchErr
			return fetchErr
		}
		models = parsed
		return nil
	})
	if err != nil {
		cause := lastErr
		if cause == nil {
			// retry.Do failed without any attempt failing; report its own error.
			cause = err
		}
		return nil, fmt.Errorf("采集失败(%d 次尝试): %w", attempts, cause)
	}
	return models, nil
}

// fetchModelsOnce performs a single authenticated GET of cfg.APIURL and parses
// the body. A non-200 status is returned as an error that includes the status
// code and response body.
func fetchModelsOnce(ctx context.Context, client *http.Client, cfg Config) ([]ModelInfo, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, cfg.APIURL, nil)
	if err != nil {
		return nil, fmt.Errorf("构造请求失败: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+cfg.APIKey)
	req.Header.Set("Accept", "application/json")

	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("请求失败: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body) // best effort: the body only enriches the error message
		return nil, fmt.Errorf("非 200 响应: %d %s", resp.StatusCode, string(body))
	}
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("读取响应失败: %w", err)
	}
	models, err := parseModels(body)
	if err != nil {
		return nil, fmt.Errorf("JSON 解析失败: %w", err)
	}
	return models, nil
}
// parseModels decodes a response body of the form {"data": [ {...}, ... ]}
// into ModelInfo values.
//
// Entries that are not JSON objects, or that lack a non-empty string "id",
// are skipped. Input and output prices come from pricing.input/pricing.prompt
// and pricing.output/pricing.completion. Non-string capability entries are
// dropped. Other fields are copied only when they have the expected JSON type.
func parseModels(raw []byte) ([]ModelInfo, error) {
	var envelope struct {
		Data json.RawMessage `json:"data"`
	}
	if err := json.Unmarshal(raw, &envelope); err != nil {
		return nil, fmt.Errorf("解析 data 字段失败: %w", err)
	}
	var entries []any
	if err := json.Unmarshal(envelope.Data, &entries); err != nil {
		return nil, fmt.Errorf("解析模型数组失败: %w", err)
	}

	models := make([]ModelInfo, 0, len(entries))
	for _, entry := range entries {
		obj, isObject := entry.(map[string]any)
		if !isObject {
			continue
		}
		id := getString(obj, "id")
		if id == "" {
			continue
		}

		info := ModelInfo{
			ID:            id,
			Name:          getString(obj, "name"),
			Description:   getString(obj, "description"),
			ContextLength: getInt(obj, "context_length"),
			Created:       getInt64(obj, "created"),
		}
		if pricing, ok := obj["pricing"].(map[string]any); ok {
			info.Pricing = ModelPricing{
				Input:  getPrice(pricing, "input", "prompt"),
				Output: getPrice(pricing, "output", "completion"),
			}
		}
		if caps, ok := obj["capabilities"].([]any); ok {
			for _, c := range caps {
				if name, isString := c.(string); isString {
					info.Capabilities = append(info.Capabilities, name)
				}
			}
		}
		models = append(models, info)
	}
	return models, nil
}
// getString returns m[key] when it holds a string, and "" otherwise (including a missing key).
func getString(m map[string]any, key string) string {
	s, _ := m[key].(string)
	return s
}
// getInt returns m[key] truncated to int when it holds a float64 (the type
// encoding/json uses for numbers decoded into any), and 0 otherwise.
func getInt(m map[string]any, key string) int {
	f, isNumber := m[key].(float64)
	if !isNumber {
		return 0
	}
	return int(f)
}
// getInt64 returns m[key] truncated to int64 when it holds a float64, and 0 otherwise.
func getInt64(m map[string]any, key string) int64 {
	switch v := m[key].(type) {
	case float64:
		return int64(v)
	default:
		return 0
	}
}
// getPrice returns the price held by the first of keys whose value in m is a
// usable number, or 0 when none is.
//
// A value may be a JSON number (decoded as float64) or a numeric string. The
// OpenRouter /api/v1/models endpoint publishes pricing.prompt and
// pricing.completion as decimal strings such as "0.0000025"; those were
// previously ignored, so every live price read as 0. Strings that do not parse,
// or that parse to NaN or ±Inf, are skipped so such values never reach JSON
// encoding or the database.
//
// NOTE(review): OpenRouter quotes these prices in USD per token, while the
// region_pricing columns are named *_per_mtok. Confirm whether a ×1e6
// conversion is expected before relying on stored magnitudes.
func getPrice(m map[string]any, keys ...string) float64 {
	for _, k := range keys {
		switch v := m[k].(type) {
		case float64:
			return v
		case string:
			f, err := strconv.ParseFloat(strings.TrimSpace(v), 64)
			if err == nil && !math.IsNaN(f) && !math.IsInf(f, 0) {
				return f
			}
		}
	}
	return 0
}
// summarize writes the JSON report for models to outPath and returns
// writeJSON's error unchanged.
func summarize(outPath string, models []ModelInfo) error {
	if err := writeJSON(outPath, models); err != nil {
		return err
	}
	return nil
}
// summarizeDB 将采集结果写入 PostgreSQL批量插入 + ProviderMapper + 价格变动检测 + audit_log
func summarizeDB(connStr string, models []ModelInfo, batchSize int) error {
db, err := sql.Open("postgres", connStr)
if err != nil {
return fmt.Errorf("连接数据库失败: %w", err)
}
defer db.Close()
if err := db.Ping(); err != nil {
return fmt.Errorf("ping 数据库失败: %w", err)
}
batchID := fmt.Sprintf("batch-%d", time.Now().Unix())
now := time.Now()
effectiveDate := now.Format("2006-01-02")
// 获取默认 operatorOpenRouter
var operatorID int64
err = db.QueryRow("SELECT id FROM operator WHERE name = 'OpenRouter' LIMIT 1").Scan(&operatorID)
if err != nil {
logger.Warn("未找到 OpenRouter operator使用 NULL", "error", err)
operatorID = 0
}
// 获取上次价格数据(用于变动检测)
lastPrices := make(map[int64]ModelPricing)
rows, err := db.Query(`
SELECT model_id, input_price_per_mtok, output_price_per_mtok
FROM region_pricing
WHERE operator_id = $1 AND effective_date = (
SELECT MAX(effective_date) FROM region_pricing WHERE operator_id = $1
)
`, operatorID)
if err == nil {
for rows.Next() {
var mid int64
var p ModelPricing
if err := rows.Scan(&mid, &p.Input, &p.Output); err == nil {
lastPrices[mid] = p
}
}
rows.Close()
}
insertedModels := 0
insertedPrices := 0
priceChanges := 0
// 批量处理
for i := 0; i < len(models); i += batchSize {
end := i + batchSize
if end > len(models) {
end = len(models)
}
batch := models[i:end]
tx, err := db.Begin()
if err != nil {
return fmt.Errorf("开启事务失败: %w", err)
}
for _, m := range batch {
// 使用 ProviderMapper 映射厂商
mapping, err := collectors.MapOpenRouterID(m.ID)
if err != nil {
logger.Warn("Provider 映射失败", "id", m.ID, "error", err)
mapping = collectors.ModelMapping{
Provider: collectors.ProviderInfo{ID: "unknown", Name: "Unknown"},
ModelName: m.Name,
RawID: m.ID,
IsFree: false,
}
}
// 查找或创建 provider_id
var providerID int64
err = tx.QueryRow("SELECT id FROM model_provider WHERE name = $1 LIMIT 1", mapping.Provider.Name).Scan(&providerID)
if err != nil {
// 未知厂商,插入
err = tx.QueryRow(`
INSERT INTO model_provider (name, name_cn, country, status)
VALUES ($1, $2, $3, 'active')
ON CONFLICT (name) DO UPDATE SET name = EXCLUDED.name
RETURNING id
`, mapping.Provider.Name, mapping.Provider.NameCN, mapping.Provider.Country).Scan(&providerID)
if err != nil {
logger.Warn("创建 provider 失败", "name", mapping.Provider.Name, "error", err)
providerID = 0
}
}
isFree := mapping.IsFree || (m.Pricing.Input == 0 && m.Pricing.Output == 0)
// upsert models 表(带新字段)
var modelID int64
err = tx.QueryRow(`
INSERT INTO models (
source, external_id, name, description, context_length,
capabilities, created_at_source, is_free, status,
raw_payload, provider_id, version, modality,
data_confidence, retrieved_at, batch_id, collector_version,
source_url, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $19)
ON CONFLICT (external_id) DO UPDATE SET
name = EXCLUDED.name,
description = EXCLUDED.description,
context_length = EXCLUDED.context_length,
capabilities = EXCLUDED.capabilities,
created_at_source = EXCLUDED.created_at_source,
is_free = EXCLUDED.is_free,
status = EXCLUDED.status,
raw_payload = EXCLUDED.raw_payload,
provider_id = EXCLUDED.provider_id,
data_confidence = 'official',
retrieved_at = EXCLUDED.retrieved_at,
batch_id = EXCLUDED.batch_id,
collector_version = EXCLUDED.collector_version,
updated_at = EXCLUDED.updated_at
RETURNING id
`,
"openrouter", m.ID, m.Name, m.Description, m.ContextLength,
jsonCapabilities(m.Capabilities), m.Created, isFree, "active",
rawPayload(m), providerID, "", "text",
"official", now, batchID, collectorVersion,
"https://openrouter.ai/api/v1/models", now).Scan(&modelID)
if err != nil {
_ = tx.Rollback()
return fmt.Errorf("写入 models 失败 (%s): %w", m.ID, err)
}
insertedModels++
// 写入 audit_log
_, _ = tx.Exec(`
INSERT INTO audit_log (table_name, record_id, field_name, old_value, new_value, operation, operator, batch_id, source_url)
VALUES ('models', $1, 'external_id', NULL, $2, 'INSERT', 'fetch_openrouter', $3, $4)
`, modelID, m.ID, batchID, "https://openrouter.ai/api/v1/models")
// upsert region_pricing 表(替代 model_prices
sourceType := "reseller"
freeQuota := ""
freeLimitations := "[]"
rateLimit := "{}"
if isFree {
sourceType = "free_tier"
freeQuota = "Imported free-tier pricing entry"
freeLimitations = `["See source_url for current quota and policy"]`
}
var pricingID int64
err = tx.QueryRow(`
INSERT INTO region_pricing (
model_id, operator_id, region, currency,
input_price_per_mtok, output_price_per_mtok,
is_free, effective_date, source_url, source_type,
free_quota, free_limitations, rate_limit,
created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $14)
ON CONFLICT (model_id, operator_id, region, currency, effective_date) DO UPDATE SET
input_price_per_mtok = EXCLUDED.input_price_per_mtok,
output_price_per_mtok = EXCLUDED.output_price_per_mtok,
is_free = EXCLUDED.is_free,
source_type = EXCLUDED.source_type,
free_quota = EXCLUDED.free_quota,
free_limitations = EXCLUDED.free_limitations,
rate_limit = EXCLUDED.rate_limit,
updated_at = EXCLUDED.updated_at
RETURNING id
`, modelID, operatorID, "global", "USD", m.Pricing.Input, m.Pricing.Output,
isFree, effectiveDate, "https://openrouter.ai/api/v1/models", sourceType,
freeQuota, freeLimitations, rateLimit, now).Scan(&pricingID)
if err != nil {
_ = tx.Rollback()
return fmt.Errorf("写入 region_pricing 失败 (%s): %w", m.ID, err)
}
insertedPrices++
// 价格变动检测(>5%
if lastPrice, ok := lastPrices[modelID]; ok {
inputChange := calcChangePercent(lastPrice.Input, m.Pricing.Input)
outputChange := calcChangePercent(lastPrice.Output, m.Pricing.Output)
if abs(inputChange) > 5 || abs(outputChange) > 5 {
_, _ = tx.Exec(`
INSERT INTO pricing_history (
model_id, region, currency,
old_input_price, new_input_price,
old_output_price, new_output_price,
change_percent, changed_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
`, modelID, "global", "USD",
lastPrice.Input, m.Pricing.Input,
lastPrice.Output, m.Pricing.Output,
max(abs(inputChange), abs(outputChange)), now)
priceChanges++
}
}
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("提交事务失败: %w", err)
}
logger.Info("批次完成", "batch", i/batchSize+1, "records", len(batch))
}
logger.Info("PostgreSQL 写入完成",
"models", insertedModels,
"prices", insertedPrices,
"price_changes", priceChanges,
"batch_id", batchID)
return nil
}
// calcChangePercent returns the relative change from prev to cur in percent.
// When prev is 0 there is no base to divide by: the result is 0 if cur is also
// 0, and 100 otherwise.
func calcChangePercent(prev, cur float64) float64 {
	switch {
	case prev != 0:
		return (cur - prev) / prev * 100
	case cur == 0:
		return 0
	default:
		return 100
	}
}
// abs returns the magnitude of v (negative values are negated; others pass through).
func abs(v float64) float64 {
	if v < 0 {
		v = -v
	}
	return v
}
// max returns a when a > b, and b otherwise. Note that it shadows the
// predeclared max for this package.
func max(a, b float64) float64 {
	larger := b
	if a > b {
		larger = a
	}
	return larger
}
// jsonCapabilities encodes caps as a JSON array. It returns "[]" (never
// "null") when caps is nil or empty.
func jsonCapabilities(caps []string) []byte {
	if len(caps) > 0 {
		encoded, _ := json.Marshal(caps) // a []string always marshals successfully
		return encoded
	}
	return []byte("[]")
}
// rawPayload returns the JSON encoding of the normalized ModelInfo (not the
// original API object) for storage in models.raw_payload. It returns nil if
// encoding fails.
func rawPayload(m ModelInfo) []byte {
	payload, err := json.Marshal(m)
	if err != nil {
		return nil
	}
	return payload
}
// writeJSON prints a one-line summary to stdout. It then writes the models,
// their counts and a generation timestamp as indented JSON to outPath.
//
// A model counts as free when its ID ends in ":free" and is longer than that
// suffix. It counts as paid when it is not free and has a positive input or
// output price. Models matching neither are counted only in total.
//
// The file's Close error is also returned, because data can still fail to
// reach the file when it is closed.
func writeJSON(outPath string, models []ModelInfo) error {
	const freeSuffix = ":free"

	total := len(models)
	var freeCnt, paidCnt int
	for _, m := range models {
		switch {
		case len(m.ID) > len(freeSuffix) && strings.HasSuffix(m.ID, freeSuffix):
			freeCnt++
		case m.Pricing.Input > 0 || m.Pricing.Output > 0:
			paidCnt++
		}
	}
	fmt.Printf("采集完成: 共 %d 模型(免费 %d / 付费 %d)\n", total, freeCnt, paidCnt)

	out, err := os.Create(outPath)
	if err != nil {
		return fmt.Errorf("创建输出文件失败: %w", err)
	}

	enc := json.NewEncoder(out)
	enc.SetIndent("", " ")
	if err := enc.Encode(map[string]any{
		"generated_at": time.Now().Format(time.RFC3339),
		"total":        total,
		"free":         freeCnt,
		"paid":         paidCnt,
		"models":       models,
	}); err != nil {
		_ = out.Close() // the encode error is the one worth reporting
		return fmt.Errorf("写入 JSON 失败: %w", err)
	}
	if err := out.Close(); err != nil {
		return fmt.Errorf("关闭输出文件失败: %w", err)
	}

	fmt.Printf("结果已写入: %s\n", outPath)
	return nil
}
// recordCollectorStats inserts one collector_stats row for this run. The row
// holds source 'openrouter', whether the run succeeded, its duration in
// milliseconds, and the error text (empty on success).
//
// NOTE(review): the batch_id written here is derived from the current Unix
// time, independently of the batch_id summarizeDB uses for models and
// audit_log, so the two match only if both are generated within the same second.
func recordCollectorStats(connStr string, runErr error, duration time.Duration) error {
	db, openErr := sql.Open("postgres", connStr)
	if openErr != nil {
		return openErr
	}
	defer db.Close()

	succeeded := runErr == nil
	var errMsg string
	if !succeeded {
		errMsg = runErr.Error()
	}

	_, execErr := db.Exec(`
INSERT INTO collector_stats (source, batch_id, success, duration_ms, error_message, created_at)
VALUES ('openrouter', $1, $2, $3, $4, $5)
`, fmt.Sprintf("batch-%d", time.Now().Unix()), succeeded, int(duration.Milliseconds()), errMsg, time.Now())
	return execErr
}