Files
llm-intelligence/scripts/import_siliconflow_pricing.go

97 lines
2.6 KiB
Go
Raw Permalink Normal View History

//go:build llm_script && !scripts_pkg
package main
import (
"database/sql"
"flag"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"time"
)
// siliconFlowPricingImportConfig carries the options for a single SiliconFlow
// pricing import run, as assembled from command-line flags in main.
type siliconFlowPricingImportConfig struct {
	// URL is the pricing page address; Fixture is an optional sample file.
	// Both are handed to fetchSubscriptionPage together.
	URL, Fixture string

	// DryRun parses and prints a summary without writing to the database.
	DryRun bool

	// Timeout is used as the http.Client timeout for page requests
	// (zero means no timeout, per net/http).
	Timeout time.Duration
}
// main is the CLI entry point for importing SiliconFlow pricing.
//
// It first calls loadSubscriptionImportEnv (presumably loads the environment
// that subscriptionImportDB relies on — confirm), then parses flags:
//
//	-url      SiliconFlow pricing page (default defaultSiliconFlowPricingURL)
//	-fixture  SiliconFlow pricing sample file
//	-dry-run  parse and print a summary only; no database is opened
//	-timeout  HTTP request timeout in seconds; 0 disables the timeout
//
// Exit status: 0 on success, 1 if opening the database or the import fails,
// 2 if -timeout is negative.
func main() {
	loadSubscriptionImportEnv()

	var url string
	var fixture string
	var dryRun bool
	var timeoutSeconds int
	flag.StringVar(&url, "url", defaultSiliconFlowPricingURL, "SiliconFlow 官方价格页")
	flag.StringVar(&fixture, "fixture", "", "SiliconFlow 价格样例文件")
	flag.BoolVar(&dryRun, "dry-run", false, "仅解析并打印摘要,不写入数据库")
	flag.IntVar(&timeoutSeconds, "timeout", 20, "请求超时(秒)")
	flag.Parse()

	// A negative duration would silently disable the http.Client timeout;
	// treat it as a usage error instead (exit 2, as the flag package does).
	if timeoutSeconds < 0 {
		fmt.Fprintf(os.Stderr, "import_siliconflow_pricing: -timeout must be >= 0, got %d\n", timeoutSeconds)
		os.Exit(2)
	}

	cfg := siliconFlowPricingImportConfig{
		URL:     url,
		Fixture: fixture,
		DryRun:  dryRun,
		Timeout: time.Duration(timeoutSeconds) * time.Second,
	}

	// The database is only needed when results are actually written.
	var db *sql.DB
	if !cfg.DryRun {
		var err error
		db, err = subscriptionImportDB()
		if err != nil {
			fmt.Fprintf(os.Stderr, "open db: %v\n", err)
			os.Exit(1)
		}
	}

	runErr := runSiliconFlowPricingImport(cfg, db, os.Stdout)

	// os.Exit does not run deferred calls, so close the database explicitly
	// before deciding the exit status rather than relying on defer.
	if db != nil {
		if err := db.Close(); err != nil {
			fmt.Fprintf(os.Stderr, "close db: %v\n", err)
		}
	}
	if runErr != nil {
		fmt.Fprintf(os.Stderr, "import_siliconflow_pricing: %v\n", runErr)
		os.Exit(1)
	}
}
// runSiliconFlowPricingImport loads SiliconFlow pricing records and either
// reports them (dry run) or upserts them into db.
//
// Records are obtained via fetchSubscriptionPage(cfg.URL, cfg.Fixture, ...)
// and parsed with parseSiliconFlowPricingCatalog. When cfg.Fixture is empty
// and either the fetch or the parse fails, the bundled sample at the relative
// path scripts/testdata/siliconflow_pricing_sample.txt is used instead
// (presumably resolved against the working directory — confirm). When
// cfg.Fixture is set, fetch and parse errors are returned, never masked.
//
// NOTE(review): the fallback is silent — outside dry-run it writes the bundled
// sample prices into the database. Confirm that is intended.
//
// Parsed records are deduplicated with dedupeOfficialPricingRecords and must
// be non-empty. A one-line summary, including the first record's
// OperatorName, is written to out. Outside dry-run it also reports the total
// row count of region_pricing after the upsert. db may be nil only when
// cfg.DryRun is true.
func runSiliconFlowPricingImport(cfg siliconFlowPricingImportConfig, db *sql.DB, out io.Writer) error {
	const importSource = "siliconflow-pricing-import"

	// Fail before doing any network work if we could not write the result.
	if !cfg.DryRun && db == nil {
		return fmt.Errorf("db is required when dry-run=false")
	}

	samplePath := filepath.Join("scripts", "testdata", "siliconflow_pricing_sample.txt")
	client := &http.Client{Timeout: cfg.Timeout}

	raw, err := fetchSubscriptionPage(cfg.URL, cfg.Fixture, client)
	usedSample := false
	if err != nil {
		if cfg.Fixture != "" {
			// An explicitly requested fixture must not degrade to other data.
			return fmt.Errorf("fetch pricing page (url=%q fixture=%q): %w", cfg.URL, cfg.Fixture, err)
		}
		raw, err = fetchSubscriptionPage(cfg.URL, samplePath, client)
		if err != nil {
			return fmt.Errorf("load bundled sample %q: %w", samplePath, err)
		}
		usedSample = true
	}

	records, err := parseSiliconFlowPricingCatalog(raw)
	if err != nil && cfg.Fixture == "" && !usedSample {
		// The page was fetched but could not be parsed; retry with the sample.
		raw, err = fetchSubscriptionPage(cfg.URL, samplePath, client)
		if err != nil {
			return fmt.Errorf("load bundled sample %q: %w", samplePath, err)
		}
		records, err = parseSiliconFlowPricingCatalog(raw)
	}
	if err != nil {
		return fmt.Errorf("parse siliconflow pricing: %w", err)
	}

	records = dedupeOfficialPricingRecords(records)
	if len(records) == 0 {
		// Guards the records[0] access below, which would otherwise panic.
		return fmt.Errorf("no siliconflow pricing records parsed")
	}
	operator := records[0].OperatorName

	if cfg.DryRun {
		_, err = fmt.Fprintf(out, "source=%s models=%d operator=%s dry_run=true\n", importSource, len(records), operator)
		return err
	}

	if err := upsertOfficialPricingRecords(db, records, importSource); err != nil {
		return fmt.Errorf("upsert pricing records: %w", err)
	}
	var tableRows int
	if err := db.QueryRow(`SELECT COUNT(*) FROM region_pricing`).Scan(&tableRows); err != nil {
		return fmt.Errorf("count region_pricing: %w", err)
	}
	_, err = fmt.Fprintf(out, "source=%s models=%d operator=%s table_rows=%d dry_run=false\n", importSource, len(records), operator, tableRows)
	return err
}