feat(probe): add capability profile and smoke completion routing

This commit is contained in:
phamnazage-jpg
2026-05-22 14:31:41 +08:00
parent 2bc7554cf8
commit 6420efbef1
4 changed files with 493 additions and 0 deletions

View File

@@ -0,0 +1,157 @@
package probe
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
)
// TransportProfile describes the gateway-level capabilities recorded by
// ProbeCapabilities. It is serialized to JSON with snake_case keys.
type TransportProfile struct {
// SupportsOpenAIModels is set by ProbeCapabilities to true when at least one
// raw model ID was supplied to the probe.
SupportsOpenAIModels bool `json:"supports_openai_models"`
// SupportsOpenAIChatCompletions is set once any per-model
// /v1/chat/completions probe returns a 2xx status.
SupportsOpenAIChatCompletions bool `json:"supports_openai_chat_completions"`
// SupportsOpenAIResponses is set when the /v1/responses probe returns a 2xx
// status.
SupportsOpenAIResponses bool `json:"supports_openai_responses"`
// SupportsAnthropicMessages is not populated by ProbeCapabilities in this
// file. NOTE(review): presumably reserved for an Anthropic-style probe —
// confirm against other callers.
SupportsAnthropicMessages bool `json:"supports_anthropic_messages"`
// AuthStyle names the authentication scheme; ProbeCapabilities always
// records "bearer".
AuthStyle string `json:"auth_style"`
// ModelIDStyle is the value returned by detectModelIDStyle:
// "vendor_prefixed" or "canonical".
ModelIDStyle string `json:"model_id_style"`
// KnownAdvisories holds advisory codes collected during probing, for example
// "responses_unsupported_but_chat_ok". ProbeCapabilities initializes it to an
// empty (non-nil) slice.
KnownAdvisories []string `json:"known_advisories"`
}
// ModelCapabilityProfile records what ProbeCapabilities learned about a single
// model ID.
type ModelCapabilityProfile struct {
// RawModelID is the model ID as supplied to the probe, with surrounding
// whitespace trimmed.
RawModelID string `json:"raw_model_id"`
// NormalizedModelID is the result of NormalizeModelID for the raw ID.
NormalizedModelID string `json:"normalized_model_id"`
// CanonicalModelFamily is the result of CanonicalModelFamily for the raw ID.
CanonicalModelFamily string `json:"canonical_model_family"`
// SupportsStream, SupportsTools and SupportsReasoningFields are strings, not
// booleans; ProbeCapabilities initializes each of them to "unknown" and does
// not probe them further.
SupportsStream string `json:"supports_stream"`
SupportsTools string `json:"supports_tools"`
SupportsReasoningFields string `json:"supports_reasoning_fields"`
// SmokeChatOK is true when the /v1/chat/completions probe for this model
// returned a 2xx status.
SmokeChatOK bool `json:"smoke_chat_ok"`
}
// CapabilityProfile is the result of ProbeCapabilities: one gateway-level
// transport profile plus one entry per probed model.
type CapabilityProfile struct {
// TransportProfile holds the gateway-wide findings.
TransportProfile TransportProfile `json:"transport_profile"`
// ModelProfiles holds the per-model findings in the order the models were
// probed.
ModelProfiles []ModelCapabilityProfile `json:"model_profiles"`
}
// ProbeCapabilities probes an OpenAI-compatible gateway and summarizes what it
// supports.
//
// It sends one POST to /v1/responses (using the first non-blank model ID, or
// "ping" when none is supplied) and one POST to /v1/chat/completions for every
// non-blank entry in rawModels. Blank or whitespace-only entries are skipped
// so that no request is sent with an empty model ID and no empty profile is
// recorded.
//
// HTTP error statuses are recorded in the profile rather than returned as
// errors. A request that cannot be built or sent aborts the probe and returns
// a nil profile together with the error.
func ProbeCapabilities(ctx context.Context, baseURL, apiKey string, rawModels []string) (*CapabilityProfile, error) {
	profile := &CapabilityProfile{
		TransportProfile: TransportProfile{
			SupportsOpenAIModels: len(rawModels) > 0,
			AuthStyle:            "bearer",
			ModelIDStyle:         detectModelIDStyle(rawModels),
			KnownAdvisories:      []string{},
		},
		ModelProfiles: make([]ModelCapabilityProfile, 0, len(rawModels)),
	}
	transport := &profile.TransportProfile

	isSuccess := func(status int) bool {
		return status >= http.StatusOK && status < http.StatusMultipleChoices
	}

	responsesStatus, err := probeJSONEndpoint(ctx, baseURL, apiKey, "/v1/responses", map[string]any{
		"model": firstNonEmptyModel(rawModels),
		"input": "ping",
	})
	if err != nil {
		return nil, err
	}
	transport.SupportsOpenAIResponses = isSuccess(responsesStatus)

	for _, rawModel := range rawModels {
		rawID := strings.TrimSpace(rawModel)
		if rawID == "" {
			// A blank entry cannot be probed meaningfully; the other helpers
			// in this package skip such entries as well.
			continue
		}

		modelProfile := ModelCapabilityProfile{
			RawModelID:              rawID,
			NormalizedModelID:       NormalizeModelID(rawModel),
			CanonicalModelFamily:    CanonicalModelFamily(rawModel),
			SupportsStream:          "unknown",
			SupportsTools:           "unknown",
			SupportsReasoningFields: "unknown",
		}

		chatStatus, err := probeJSONEndpoint(ctx, baseURL, apiKey, "/v1/chat/completions", map[string]any{
			"model": modelProfile.RawModelID,
			"messages": []map[string]string{
				{"role": "user", "content": "ping"},
			},
			"max_tokens":  8,
			"temperature": 0,
		})
		if err != nil {
			return nil, err
		}

		modelProfile.SmokeChatOK = isSuccess(chatStatus)
		if modelProfile.SmokeChatOK {
			transport.SupportsOpenAIChatCompletions = true
		}
		// A 403 on the chat probe is recorded as an advisory instead of being
		// treated as a failure of the probe itself.
		if chatStatus == http.StatusForbidden {
			appendAdvisory(&transport.KnownAdvisories, "initial_probe_race_expected")
		}

		profile.ModelProfiles = append(profile.ModelProfiles, modelProfile)
	}

	if !transport.SupportsOpenAIResponses && transport.SupportsOpenAIChatCompletions {
		appendAdvisory(&transport.KnownAdvisories, "responses_unsupported_but_chat_ok")
	}
	return profile, nil
}
// probeJSONEndpoint POSTs payload as JSON to path on the gateway at baseURL
// and reports the HTTP status code of the reply.
//
// A nil payload results in an empty request body. A bearer Authorization
// header is attached only when apiKey is non-blank. The request is bounded by
// both ctx and a 15 second client timeout. Non-2xx statuses are returned as
// values, not errors; an error is returned only when the URL cannot be
// resolved, the payload cannot be encoded, or the request cannot be built or
// sent.
func probeJSONEndpoint(ctx context.Context, baseURL, apiKey, path string, payload any) (int, error) {
	endpoint, err := joinGatewayPath(baseURL, path)
	if err != nil {
		return 0, fmt.Errorf("resolve %s endpoint: %w", path, err)
	}

	encoded := new(bytes.Buffer)
	if payload != nil {
		encodeErr := json.NewEncoder(encoded).Encode(payload)
		if encodeErr != nil {
			return 0, fmt.Errorf("encode %s probe payload: %w", path, encodeErr)
		}
	}

	request, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, encoded)
	if err != nil {
		return 0, fmt.Errorf("build %s request: %w", path, err)
	}
	if bearer := strings.TrimSpace(apiKey); bearer != "" {
		request.Header.Set("Authorization", "Bearer "+bearer)
	}
	request.Header.Set("Content-Type", "application/json")

	client := http.Client{Timeout: 15 * time.Second}
	response, err := client.Do(request)
	if err != nil {
		return 0, fmt.Errorf("request %s: %w", path, err)
	}
	defer response.Body.Close()

	return response.StatusCode, nil
}
// detectModelIDStyle reports "vendor_prefixed" when any model ID contains a
// slash (as in "vendor/model") and "canonical" otherwise, including for an
// empty list.
func detectModelIDStyle(rawModels []string) string {
	style := "canonical"
	for i := range rawModels {
		// Trimming whitespace cannot add or remove a slash, so the raw value
		// is inspected directly.
		if strings.ContainsRune(rawModels[i], '/') {
			style = "vendor_prefixed"
			break
		}
	}
	return style
}
// appendAdvisory adds advisory to *values unless it is empty or already
// present, so the list stays free of duplicates and keeps insertion order.
func appendAdvisory(values *[]string, advisory string) {
	if advisory == "" {
		return
	}
	current := *values
	for i := 0; i < len(current); i++ {
		if current[i] == advisory {
			return
		}
	}
	*values = append(current, advisory)
}
// firstNonEmptyModel returns the first model ID that is non-empty after
// trimming whitespace, in trimmed form. When there is none it returns the
// placeholder "ping".
func firstNonEmptyModel(rawModels []string) string {
	for i := range rawModels {
		candidate := strings.TrimSpace(rawModels[i])
		if candidate == "" {
			continue
		}
		return candidate
	}
	return "ping"
}

View File

@@ -0,0 +1,119 @@
package probe
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// TestProbeCapabilities exercises ProbeCapabilities against stub gateways.
//
// The stub handlers run on the HTTP server's goroutines, so they report
// problems with t.Errorf (safe from any goroutine) rather than t.Fatalf, which
// must only be called from the goroutine running the test.
func TestProbeCapabilities(t *testing.T) {
	t.Parallel()

	t.Run("responses unsupported but chat works", func(t *testing.T) {
		t.Parallel()

		// Buffered signal channel instead of a shared counter: the handler
		// and the test run on different goroutines.
		responsesProbed := make(chan struct{}, 1)
		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			switch r.URL.Path {
			case "/v1/responses":
				select {
				case responsesProbed <- struct{}{}:
				default:
				}
				http.Error(w, `{"error":"unsupported"}`, http.StatusForbidden)
			case "/v1/chat/completions":
				w.Header().Set("Content-Type", "application/json")
				_, _ = w.Write([]byte(`{"id":"chatcmpl-1","choices":[{"message":{"content":"pong"}}]}`))
			default:
				t.Errorf("unexpected path %q", r.URL.Path)
				http.NotFound(w, r)
			}
		}))
		defer server.Close()

		profile, err := ProbeCapabilities(context.Background(), server.URL, "sk-test", []string{"kimi-k2.6"})
		if err != nil {
			t.Fatalf("ProbeCapabilities() error = %v", err)
		}
		if !profile.TransportProfile.SupportsOpenAIChatCompletions {
			t.Fatal("SupportsOpenAIChatCompletions = false, want true")
		}
		if profile.TransportProfile.SupportsOpenAIResponses {
			t.Fatal("SupportsOpenAIResponses = true, want false")
		}
		if len(responsesProbed) == 0 {
			t.Fatal("responses endpoint was not probed")
		}
		if !containsString(profile.TransportProfile.KnownAdvisories, "responses_unsupported_but_chat_ok") {
			t.Fatalf("KnownAdvisories = %#v, want responses advisory", profile.TransportProfile.KnownAdvisories)
		}
	})

	t.Run("records per model capability profile", func(t *testing.T) {
		t.Parallel()

		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			switch r.URL.Path {
			case "/v1/responses":
				w.WriteHeader(http.StatusOK)
				_, _ = w.Write([]byte(`{"id":"resp_1"}`))
			case "/v1/chat/completions":
				// The request body is not inspected here; the server discards
				// any unread body once the handler returns.
				w.Header().Set("Content-Type", "application/json")
				_, _ = w.Write([]byte(`{"id":"chatcmpl-1","choices":[{"message":{"content":"ok"}}]}`))
			default:
				t.Errorf("unexpected path %q", r.URL.Path)
				http.NotFound(w, r)
			}
		}))
		defer server.Close()

		profile, err := ProbeCapabilities(context.Background(), server.URL, "sk-test", []string{"deepseek-ai/DeepSeek-V4-Pro", "kimi-k2.6"})
		if err != nil {
			t.Fatalf("ProbeCapabilities() error = %v", err)
		}
		if len(profile.ModelProfiles) != 2 {
			t.Fatalf("len(ModelProfiles) = %d, want 2", len(profile.ModelProfiles))
		}
		if profile.ModelProfiles[0].NormalizedModelID != "deepseek-v4-pro" {
			t.Fatalf("NormalizedModelID = %q, want %q", profile.ModelProfiles[0].NormalizedModelID, "deepseek-v4-pro")
		}
		if profile.ModelProfiles[0].CanonicalModelFamily != "deepseek-v4-pro" {
			t.Fatalf("CanonicalModelFamily = %q, want %q", profile.ModelProfiles[0].CanonicalModelFamily, "deepseek-v4-pro")
		}
		if !profile.ModelProfiles[0].SmokeChatOK {
			t.Fatal("SmokeChatOK = false, want true")
		}
	})

	t.Run("records initial probe advisory on transient auth race", func(t *testing.T) {
		t.Parallel()

		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			switch r.URL.Path {
			case "/v1/responses":
				http.Error(w, `{"error":"forbidden"}`, http.StatusForbidden)
			case "/v1/chat/completions":
				http.Error(w, `{"error":"warmup"}`, http.StatusForbidden)
			default:
				t.Errorf("unexpected path %q", r.URL.Path)
				http.NotFound(w, r)
			}
		}))
		defer server.Close()

		profile, err := ProbeCapabilities(context.Background(), server.URL, "sk-test", []string{"kimi-k2.6"})
		if err != nil {
			t.Fatalf("ProbeCapabilities() error = %v", err)
		}
		if !containsString(profile.TransportProfile.KnownAdvisories, "initial_probe_race_expected") {
			t.Fatalf("KnownAdvisories = %#v, want initial probe advisory", profile.TransportProfile.KnownAdvisories)
		}
	})
}
// containsString reports whether any element of values equals want once the
// element's surrounding whitespace is trimmed. want itself is compared as
// given.
func containsString(values []string, want string) bool {
	for i := 0; i < len(values); i++ {
		if trimmed := strings.TrimSpace(values[i]); trimmed == want {
			return true
		}
	}
	return false
}

View File

@@ -0,0 +1,123 @@
package probe
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
)
// CompletionResult is the outcome of a single SmokeCompletion request.
type CompletionResult struct {
// Model is the whitespace-trimmed model ID that was sent.
Model string
// HTTPStatus is the status code of the gateway's response.
HTTPStatus int
// LatencyMs is the elapsed time in milliseconds, measured from just before
// the request is sent until the result is assembled after the client call
// returns.
LatencyMs int64
// Classification names the endpoint that was used: "responses" or
// "chat_completions".
Classification string
// Error is "unexpected_status_<code>" for a non-2xx response and empty
// otherwise.
Error string
}
// ResolveSmokeModel chooses the model ID to use for a smoke completion.
//
// Selection order:
//  1. the first model recommended by RecommendModels(requested, rawModels)
//     that the profile allows;
//  2. the first non-blank raw model that the profile allows;
//  3. as a last resort, the first non-blank raw model even though the profile
//     does not vouch for it.
//
// The second return value is always the recommendation list. An error is
// returned only when rawModels holds no non-blank entry and no recommendation
// is allowed.
func ResolveSmokeModel(requested []string, rawModels []string, profile *CapabilityProfile) (string, []string, error) {
	recommended := RecommendModels(requested, rawModels)
	for _, candidate := range recommended {
		if profileAllowsSmoke(profile, candidate) {
			return candidate, recommended, nil
		}
	}

	// Track the first usable raw model while scanning, so a leading blank
	// entry does not hide the models that follow it.
	fallback := ""
	for _, rawModel := range rawModels {
		if strings.TrimSpace(rawModel) == "" {
			continue
		}
		if profileAllowsSmoke(profile, rawModel) {
			return rawModel, recommended, nil
		}
		if fallback == "" {
			fallback = rawModel
		}
	}
	if fallback != "" {
		return fallback, recommended, nil
	}
	return "", recommended, fmt.Errorf("no smoke model available")
}
// SmokeCompletion sends one minimal "ping" request for model and reports the
// outcome.
//
// The request goes to /v1/responses when profile is non-nil and marks the
// responses API as supported; otherwise it goes to /v1/chat/completions. A
// non-2xx reply is not an error: it is returned in the result with Error set
// to "unexpected_status_<code>". An error is returned when model is blank or
// when the request cannot be resolved, encoded, built or sent.
func SmokeCompletion(ctx context.Context, baseURL, apiKey, model string, profile *CapabilityProfile) (*CompletionResult, error) {
	model = strings.TrimSpace(model)
	if model == "" {
		return nil, fmt.Errorf("model is required")
	}

	var (
		path           string
		classification string
		payload        map[string]any
	)
	if profile != nil && profile.TransportProfile.SupportsOpenAIResponses {
		path = "/v1/responses"
		classification = "responses"
		payload = map[string]any{
			"model": model,
			"input": "ping",
		}
	} else {
		path = "/v1/chat/completions"
		classification = "chat_completions"
		payload = map[string]any{
			"model": model,
			"messages": []map[string]string{
				{"role": "user", "content": "ping"},
			},
			"max_tokens":  8,
			"temperature": 0,
		}
	}

	endpoint, err := joinGatewayPath(baseURL, path)
	if err != nil {
		return nil, fmt.Errorf("resolve smoke endpoint: %w", err)
	}

	encoded := new(bytes.Buffer)
	if encodeErr := json.NewEncoder(encoded).Encode(payload); encodeErr != nil {
		return nil, fmt.Errorf("encode smoke payload: %w", encodeErr)
	}

	request, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, encoded)
	if err != nil {
		return nil, fmt.Errorf("build smoke request: %w", err)
	}
	if bearer := strings.TrimSpace(apiKey); bearer != "" {
		request.Header.Set("Authorization", "Bearer "+bearer)
	}
	request.Header.Set("Content-Type", "application/json")

	client := http.Client{Timeout: 15 * time.Second}
	begin := time.Now()
	response, err := client.Do(request)
	if err != nil {
		return nil, fmt.Errorf("request smoke completion: %w", err)
	}
	defer response.Body.Close()

	status := response.StatusCode
	outcome := &CompletionResult{
		Model:          model,
		HTTPStatus:     status,
		LatencyMs:      time.Since(begin).Milliseconds(),
		Classification: classification,
	}
	succeeded := status >= http.StatusOK && status < http.StatusMultipleChoices
	if !succeeded {
		outcome.Error = fmt.Sprintf("unexpected_status_%d", status)
	}
	return outcome, nil
}
// profileAllowsSmoke reports whether the capability profile permits using
// rawModel for a smoke completion.
//
// With no profile, or a profile without model entries, every model is allowed.
// Otherwise an entry whose raw ID equals rawModel (ignoring surrounding
// whitespace) decides the answer. Only when no such entry exists does the
// first entry of the same canonical family decide. A model with no matching
// entry is not allowed.
func profileAllowsSmoke(profile *CapabilityProfile, rawModel string) bool {
	if profile == nil || len(profile.ModelProfiles) == 0 {
		return true
	}

	// An exact ID match is the most specific evidence, so it takes precedence
	// over a family-level match that happens to appear earlier in the list.
	targetRaw := strings.TrimSpace(rawModel)
	for i := range profile.ModelProfiles {
		if strings.TrimSpace(profile.ModelProfiles[i].RawModelID) == targetRaw {
			return profile.ModelProfiles[i].SmokeChatOK
		}
	}

	targetCanonical := CanonicalModelFamily(rawModel)
	for i := range profile.ModelProfiles {
		if profile.ModelProfiles[i].CanonicalModelFamily == targetCanonical {
			return profile.ModelProfiles[i].SmokeChatOK
		}
	}
	return false
}

View File

@@ -0,0 +1,94 @@
package probe
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
// TestResolveSmokeModel checks which model ResolveSmokeModel selects and which
// recommendations it reports alongside it.
func TestResolveSmokeModel(t *testing.T) {
	t.Parallel()

	t.Run("uses requested alias when matched", func(t *testing.T) {
		t.Parallel()

		requested := []string{"kimi 2.6"}
		discovered := []string{"kimi-k2.6"}
		caps := &CapabilityProfile{
			ModelProfiles: []ModelCapabilityProfile{
				{RawModelID: "kimi-k2.6", CanonicalModelFamily: "kimi-2.6", SmokeChatOK: true},
			},
		}

		picked, suggestions, err := ResolveSmokeModel(requested, discovered, caps)
		if err != nil {
			t.Fatalf("ResolveSmokeModel() error = %v", err)
		}
		if picked != "kimi-k2.6" {
			t.Fatalf("ResolveSmokeModel() model = %q, want %q", picked, "kimi-k2.6")
		}
		if len(suggestions) != 1 || suggestions[0] != "kimi-k2.6" {
			t.Fatalf("recommended = %#v, want discovered alias", suggestions)
		}
	})

	t.Run("falls back to discovered model with smoke support", func(t *testing.T) {
		t.Parallel()

		requested := []string{"unknown"}
		discovered := []string{"deepseek-ai/DeepSeek-V4-Pro"}
		caps := &CapabilityProfile{
			ModelProfiles: []ModelCapabilityProfile{
				{RawModelID: "deepseek-ai/DeepSeek-V4-Pro", CanonicalModelFamily: "deepseek-v4-pro", SmokeChatOK: true},
			},
		}

		picked, suggestions, err := ResolveSmokeModel(requested, discovered, caps)
		if err != nil {
			t.Fatalf("ResolveSmokeModel() error = %v", err)
		}
		if picked != "deepseek-ai/DeepSeek-V4-Pro" {
			t.Fatalf("ResolveSmokeModel() model = %q, want discovered model", picked)
		}
		if len(suggestions) != 0 {
			t.Fatalf("recommended = %#v, want empty for unknown request", suggestions)
		}
	})
}
// TestSmokeCompletion verifies that SmokeCompletion falls back to the chat
// completions endpoint when the profile says the responses API is unsupported.
//
// The stub handler runs on the HTTP server's goroutine, so it reports problems
// with t.Errorf and an HTTP error reply instead of t.Fatalf, which must only
// be called from the goroutine running the test.
func TestSmokeCompletion(t *testing.T) {
	t.Parallel()

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/v1/chat/completions" {
			t.Errorf("path = %q, want chat completions fallback", r.URL.Path)
			http.NotFound(w, r)
			return
		}
		var payload map[string]any
		if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
			t.Errorf("decode request body: %v", err)
			http.Error(w, "bad request", http.StatusBadRequest)
			return
		}
		if payload["model"] != "kimi-k2.6" {
			t.Errorf("payload model = %v, want kimi-k2.6", payload["model"])
			http.Error(w, "bad request", http.StatusBadRequest)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"id":"chatcmpl-1","choices":[{"message":{"content":"pong"}}]}`))
	}))
	defer server.Close()

	profile := &CapabilityProfile{
		TransportProfile: TransportProfile{
			SupportsOpenAIChatCompletions: true,
			SupportsOpenAIResponses:       false,
			KnownAdvisories:               []string{"responses_unsupported_but_chat_ok"},
		},
	}

	result, err := SmokeCompletion(context.Background(), server.URL, "sk-test", "kimi-k2.6", profile)
	if err != nil {
		t.Fatalf("SmokeCompletion() error = %v", err)
	}
	if result.HTTPStatus != http.StatusOK {
		t.Fatalf("HTTPStatus = %d, want %d", result.HTTPStatus, http.StatusOK)
	}
	if result.Classification != "chat_completions" {
		t.Fatalf("Classification = %q, want %q", result.Classification, "chat_completions")
	}
}