feat(probe): add capability profile and smoke completion routing
This commit is contained in:
157
internal/probe/capability.go
Normal file
157
internal/probe/capability.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package probe
|
||||
|
||||
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"
)
|
||||
|
||||
// TransportProfile records gateway-wide findings from a capability probe:
// which API surfaces answered successfully, how requests were authenticated,
// what the model IDs look like, and any advisories raised along the way.
type TransportProfile struct {
	// SupportsOpenAIModels is set by ProbeCapabilities when at least one
	// model ID was supplied to it.
	SupportsOpenAIModels bool `json:"supports_openai_models"`
	// SupportsOpenAIChatCompletions is true once any chat-completions probe
	// returned a 2xx status.
	SupportsOpenAIChatCompletions bool `json:"supports_openai_chat_completions"`
	// SupportsOpenAIResponses is true when the /v1/responses probe returned
	// a 2xx status.
	SupportsOpenAIResponses bool `json:"supports_openai_responses"`
	// SupportsAnthropicMessages is not populated by ProbeCapabilities.
	SupportsAnthropicMessages bool `json:"supports_anthropic_messages"`
	// AuthStyle names the authentication scheme; ProbeCapabilities records
	// "bearer".
	AuthStyle string `json:"auth_style"`
	// ModelIDStyle is "vendor_prefixed" or "canonical" as reported by
	// detectModelIDStyle.
	ModelIDStyle string `json:"model_id_style"`
	// KnownAdvisories is a de-duplicated list of advisory codes.
	KnownAdvisories []string `json:"known_advisories"`
}
|
||||
|
||||
// ModelCapabilityProfile records what a capability probe learned about a
// single model ID.
type ModelCapabilityProfile struct {
	// RawModelID is the model ID as supplied, with surrounding whitespace
	// trimmed by ProbeCapabilities.
	RawModelID string `json:"raw_model_id"`
	// NormalizedModelID is the result of NormalizeModelID for the raw ID.
	NormalizedModelID string `json:"normalized_model_id"`
	// CanonicalModelFamily is the result of CanonicalModelFamily for the
	// raw ID.
	CanonicalModelFamily string `json:"canonical_model_family"`
	// SupportsStream, SupportsTools and SupportsReasoningFields are string
	// valued; ProbeCapabilities initialises each of them to "unknown".
	SupportsStream          string `json:"supports_stream"`
	SupportsTools           string `json:"supports_tools"`
	SupportsReasoningFields string `json:"supports_reasoning_fields"`
	// SmokeChatOK is true when the chat-completions probe for this model
	// returned a 2xx status.
	SmokeChatOK bool `json:"smoke_chat_ok"`
}
|
||||
|
||||
type CapabilityProfile struct {
|
||||
TransportProfile TransportProfile `json:"transport_profile"`
|
||||
ModelProfiles []ModelCapabilityProfile `json:"model_profiles"`
|
||||
}
|
||||
|
||||
func ProbeCapabilities(ctx context.Context, baseURL, apiKey string, rawModels []string) (*CapabilityProfile, error) {
|
||||
profile := &CapabilityProfile{
|
||||
TransportProfile: TransportProfile{
|
||||
SupportsOpenAIModels: len(rawModels) > 0,
|
||||
AuthStyle: "bearer",
|
||||
ModelIDStyle: detectModelIDStyle(rawModels),
|
||||
KnownAdvisories: []string{},
|
||||
},
|
||||
ModelProfiles: make([]ModelCapabilityProfile, 0, len(rawModels)),
|
||||
}
|
||||
|
||||
responsesStatus, err := probeJSONEndpoint(ctx, baseURL, apiKey, "/v1/responses", map[string]any{
|
||||
"model": firstNonEmptyModel(rawModels),
|
||||
"input": "ping",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
profile.TransportProfile.SupportsOpenAIResponses = responsesStatus >= http.StatusOK && responsesStatus < http.StatusMultipleChoices
|
||||
|
||||
for _, rawModel := range rawModels {
|
||||
modelProfile := ModelCapabilityProfile{
|
||||
RawModelID: strings.TrimSpace(rawModel),
|
||||
NormalizedModelID: NormalizeModelID(rawModel),
|
||||
CanonicalModelFamily: CanonicalModelFamily(rawModel),
|
||||
SupportsStream: "unknown",
|
||||
SupportsTools: "unknown",
|
||||
SupportsReasoningFields: "unknown",
|
||||
}
|
||||
|
||||
chatStatus, err := probeJSONEndpoint(ctx, baseURL, apiKey, "/v1/chat/completions", map[string]any{
|
||||
"model": modelProfile.RawModelID,
|
||||
"messages": []map[string]string{
|
||||
{"role": "user", "content": "ping"},
|
||||
},
|
||||
"max_tokens": 8,
|
||||
"temperature": 0,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modelProfile.SmokeChatOK = chatStatus >= http.StatusOK && chatStatus < http.StatusMultipleChoices
|
||||
if modelProfile.SmokeChatOK {
|
||||
profile.TransportProfile.SupportsOpenAIChatCompletions = true
|
||||
}
|
||||
if chatStatus == http.StatusForbidden {
|
||||
appendAdvisory(&profile.TransportProfile.KnownAdvisories, "initial_probe_race_expected")
|
||||
}
|
||||
|
||||
profile.ModelProfiles = append(profile.ModelProfiles, modelProfile)
|
||||
}
|
||||
|
||||
if !profile.TransportProfile.SupportsOpenAIResponses && profile.TransportProfile.SupportsOpenAIChatCompletions {
|
||||
appendAdvisory(&profile.TransportProfile.KnownAdvisories, "responses_unsupported_but_chat_ok")
|
||||
}
|
||||
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
func probeJSONEndpoint(ctx context.Context, baseURL, apiKey, path string, payload any) (int, error) {
|
||||
requestURL, err := joinGatewayPath(baseURL, path)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("resolve %s endpoint: %w", path, err)
|
||||
}
|
||||
|
||||
var body bytes.Buffer
|
||||
if payload != nil {
|
||||
if err := json.NewEncoder(&body).Encode(payload); err != nil {
|
||||
return 0, fmt.Errorf("encode %s probe payload: %w", path, err)
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, &body)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("build %s request: %w", path, err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if token := strings.TrimSpace(apiKey); token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
|
||||
resp, err := (&http.Client{Timeout: 15 * time.Second}).Do(req)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("request %s: %w", path, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return resp.StatusCode, nil
|
||||
}
|
||||
|
||||
// detectModelIDStyle classifies the naming convention of rawModels. It
// reports "vendor_prefixed" as soon as any ID contains a "/" and "canonical"
// otherwise, including for an empty list.
func detectModelIDStyle(rawModels []string) string {
	style := "canonical"
	for i := 0; i < len(rawModels); i++ {
		id := strings.TrimSpace(rawModels[i])
		if strings.ContainsRune(id, '/') {
			style = "vendor_prefixed"
			break
		}
	}
	return style
}
|
||||
|
||||
// appendAdvisory adds advisory to the slice behind values unless it is empty
// or already present, so the list stays free of duplicates while keeping
// first-seen order.
func appendAdvisory(values *[]string, advisory string) {
	if advisory == "" {
		return
	}

	current := *values
	for i := range current {
		if current[i] == advisory {
			return
		}
	}
	*values = append(current, advisory)
}
|
||||
|
||||
// firstNonEmptyModel returns the first entry of rawModels that is non-empty
// after trimming whitespace, in its trimmed form. When there is no such
// entry it returns the placeholder "ping".
func firstNonEmptyModel(rawModels []string) string {
	for i := range rawModels {
		id := strings.TrimSpace(rawModels[i])
		if id == "" {
			continue
		}
		return id
	}
	return "ping"
}
|
||||
119
internal/probe/capability_test.go
Normal file
119
internal/probe/capability_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package probe
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestProbeCapabilities(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("responses unsupported but chat works", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var responseCalls int
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/v1/responses":
|
||||
responseCalls++
|
||||
http.Error(w, `{"error":"unsupported"}`, http.StatusForbidden)
|
||||
case "/v1/chat/completions":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"id":"chatcmpl-1","choices":[{"message":{"content":"pong"}}]}`))
|
||||
default:
|
||||
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
profile, err := ProbeCapabilities(context.Background(), server.URL, "sk-test", []string{"kimi-k2.6"})
|
||||
if err != nil {
|
||||
t.Fatalf("ProbeCapabilities() error = %v", err)
|
||||
}
|
||||
if !profile.TransportProfile.SupportsOpenAIChatCompletions {
|
||||
t.Fatal("SupportsOpenAIChatCompletions = false, want true")
|
||||
}
|
||||
if profile.TransportProfile.SupportsOpenAIResponses {
|
||||
t.Fatal("SupportsOpenAIResponses = true, want false")
|
||||
}
|
||||
if responseCalls == 0 {
|
||||
t.Fatal("responses endpoint was not probed")
|
||||
}
|
||||
if !containsString(profile.TransportProfile.KnownAdvisories, "responses_unsupported_but_chat_ok") {
|
||||
t.Fatalf("KnownAdvisories = %#v, want responses advisory", profile.TransportProfile.KnownAdvisories)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("records per model capability profile", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/v1/responses":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"id":"resp_1"}`))
|
||||
case "/v1/chat/completions":
|
||||
body := make([]byte, r.ContentLength)
|
||||
_, _ = r.Body.Read(body)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"id":"chatcmpl-1","choices":[{"message":{"content":"ok"}}]}`))
|
||||
default:
|
||||
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
profile, err := ProbeCapabilities(context.Background(), server.URL, "sk-test", []string{"deepseek-ai/DeepSeek-V4-Pro", "kimi-k2.6"})
|
||||
if err != nil {
|
||||
t.Fatalf("ProbeCapabilities() error = %v", err)
|
||||
}
|
||||
if len(profile.ModelProfiles) != 2 {
|
||||
t.Fatalf("len(ModelProfiles) = %d, want 2", len(profile.ModelProfiles))
|
||||
}
|
||||
if profile.ModelProfiles[0].NormalizedModelID != "deepseek-v4-pro" {
|
||||
t.Fatalf("NormalizedModelID = %q, want %q", profile.ModelProfiles[0].NormalizedModelID, "deepseek-v4-pro")
|
||||
}
|
||||
if profile.ModelProfiles[0].CanonicalModelFamily != "deepseek-v4-pro" {
|
||||
t.Fatalf("CanonicalModelFamily = %q, want %q", profile.ModelProfiles[0].CanonicalModelFamily, "deepseek-v4-pro")
|
||||
}
|
||||
if !profile.ModelProfiles[0].SmokeChatOK {
|
||||
t.Fatal("SmokeChatOK = false, want true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("records initial probe advisory on transient auth race", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/v1/responses":
|
||||
http.Error(w, `{"error":"forbidden"}`, http.StatusForbidden)
|
||||
case "/v1/chat/completions":
|
||||
http.Error(w, `{"error":"warmup"}`, http.StatusForbidden)
|
||||
default:
|
||||
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
profile, err := ProbeCapabilities(context.Background(), server.URL, "sk-test", []string{"kimi-k2.6"})
|
||||
if err != nil {
|
||||
t.Fatalf("ProbeCapabilities() error = %v", err)
|
||||
}
|
||||
if !containsString(profile.TransportProfile.KnownAdvisories, "initial_probe_race_expected") {
|
||||
t.Fatalf("KnownAdvisories = %#v, want initial probe advisory", profile.TransportProfile.KnownAdvisories)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// containsString reports whether any element of values equals want once the
// element's surrounding whitespace is trimmed. want itself is compared as
// given.
func containsString(values []string, want string) bool {
	for i := 0; i < len(values); i++ {
		if strings.TrimSpace(values[i]) != want {
			continue
		}
		return true
	}
	return false
}
|
||||
123
internal/probe/completion.go
Normal file
123
internal/probe/completion.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package probe
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CompletionResult describes the outcome of a single smoke completion
// request.
type CompletionResult struct {
	// Model is the trimmed model ID the request was sent for.
	Model string
	// HTTPStatus is the status code of the gateway's reply.
	HTTPStatus int
	// LatencyMs is the elapsed time, in milliseconds, between sending the
	// request and receiving the response.
	LatencyMs int64
	// Classification names the API surface that was used:
	// "chat_completions" or "responses".
	Classification string
	// Error is empty for a 2xx reply and "unexpected_status_<code>"
	// otherwise.
	Error string
}
|
||||
|
||||
func ResolveSmokeModel(requested []string, rawModels []string, profile *CapabilityProfile) (string, []string, error) {
|
||||
recommended := RecommendModels(requested, rawModels)
|
||||
for _, candidate := range recommended {
|
||||
if profileAllowsSmoke(profile, candidate) {
|
||||
return candidate, recommended, nil
|
||||
}
|
||||
}
|
||||
|
||||
for _, rawModel := range rawModels {
|
||||
if strings.TrimSpace(rawModel) == "" {
|
||||
continue
|
||||
}
|
||||
if profileAllowsSmoke(profile, rawModel) {
|
||||
return rawModel, recommended, nil
|
||||
}
|
||||
}
|
||||
|
||||
if len(rawModels) > 0 && strings.TrimSpace(rawModels[0]) != "" {
|
||||
return rawModels[0], recommended, nil
|
||||
}
|
||||
|
||||
return "", recommended, fmt.Errorf("no smoke model available")
|
||||
}
|
||||
|
||||
func SmokeCompletion(ctx context.Context, baseURL, apiKey, model string, profile *CapabilityProfile) (*CompletionResult, error) {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return nil, fmt.Errorf("model is required")
|
||||
}
|
||||
|
||||
path := "/v1/chat/completions"
|
||||
classification := "chat_completions"
|
||||
payload := map[string]any{
|
||||
"model": model,
|
||||
"messages": []map[string]string{
|
||||
{"role": "user", "content": "ping"},
|
||||
},
|
||||
"max_tokens": 8,
|
||||
"temperature": 0,
|
||||
}
|
||||
|
||||
if profile != nil && profile.TransportProfile.SupportsOpenAIResponses {
|
||||
path = "/v1/responses"
|
||||
classification = "responses"
|
||||
payload = map[string]any{
|
||||
"model": model,
|
||||
"input": "ping",
|
||||
}
|
||||
}
|
||||
|
||||
requestURL, err := joinGatewayPath(baseURL, path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve smoke endpoint: %w", err)
|
||||
}
|
||||
|
||||
var body bytes.Buffer
|
||||
if err := json.NewEncoder(&body).Encode(payload); err != nil {
|
||||
return nil, fmt.Errorf("encode smoke payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, &body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build smoke request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if token := strings.TrimSpace(apiKey); token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
|
||||
startedAt := time.Now()
|
||||
resp, err := (&http.Client{Timeout: 15 * time.Second}).Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request smoke completion: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
result := &CompletionResult{
|
||||
Model: model,
|
||||
HTTPStatus: resp.StatusCode,
|
||||
LatencyMs: time.Since(startedAt).Milliseconds(),
|
||||
Classification: classification,
|
||||
}
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
result.Error = fmt.Sprintf("unexpected_status_%d", resp.StatusCode)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func profileAllowsSmoke(profile *CapabilityProfile, rawModel string) bool {
|
||||
if profile == nil || len(profile.ModelProfiles) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
targetRaw := strings.TrimSpace(rawModel)
|
||||
targetCanonical := CanonicalModelFamily(rawModel)
|
||||
for _, modelProfile := range profile.ModelProfiles {
|
||||
if strings.TrimSpace(modelProfile.RawModelID) == targetRaw || modelProfile.CanonicalModelFamily == targetCanonical {
|
||||
return modelProfile.SmokeChatOK
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
94
internal/probe/completion_test.go
Normal file
94
internal/probe/completion_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package probe
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestResolveSmokeModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("uses requested alias when matched", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
profile := &CapabilityProfile{
|
||||
ModelProfiles: []ModelCapabilityProfile{
|
||||
{RawModelID: "kimi-k2.6", CanonicalModelFamily: "kimi-2.6", SmokeChatOK: true},
|
||||
},
|
||||
}
|
||||
|
||||
model, recommended, err := ResolveSmokeModel([]string{"kimi 2.6"}, []string{"kimi-k2.6"}, profile)
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveSmokeModel() error = %v", err)
|
||||
}
|
||||
if model != "kimi-k2.6" {
|
||||
t.Fatalf("ResolveSmokeModel() model = %q, want %q", model, "kimi-k2.6")
|
||||
}
|
||||
if len(recommended) != 1 || recommended[0] != "kimi-k2.6" {
|
||||
t.Fatalf("recommended = %#v, want discovered alias", recommended)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back to discovered model with smoke support", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
profile := &CapabilityProfile{
|
||||
ModelProfiles: []ModelCapabilityProfile{
|
||||
{RawModelID: "deepseek-ai/DeepSeek-V4-Pro", CanonicalModelFamily: "deepseek-v4-pro", SmokeChatOK: true},
|
||||
},
|
||||
}
|
||||
|
||||
model, recommended, err := ResolveSmokeModel([]string{"unknown"}, []string{"deepseek-ai/DeepSeek-V4-Pro"}, profile)
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveSmokeModel() error = %v", err)
|
||||
}
|
||||
if model != "deepseek-ai/DeepSeek-V4-Pro" {
|
||||
t.Fatalf("ResolveSmokeModel() model = %q, want discovered model", model)
|
||||
}
|
||||
if len(recommended) != 0 {
|
||||
t.Fatalf("recommended = %#v, want empty for unknown request", recommended)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSmokeCompletion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/chat/completions" {
|
||||
t.Fatalf("path = %q, want chat completions fallback", r.URL.Path)
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
t.Fatalf("decode request body: %v", err)
|
||||
}
|
||||
if payload["model"] != "kimi-k2.6" {
|
||||
t.Fatalf("payload model = %v, want kimi-k2.6", payload["model"])
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"id":"chatcmpl-1","choices":[{"message":{"content":"pong"}}]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
profile := &CapabilityProfile{
|
||||
TransportProfile: TransportProfile{
|
||||
SupportsOpenAIChatCompletions: true,
|
||||
SupportsOpenAIResponses: false,
|
||||
KnownAdvisories: []string{"responses_unsupported_but_chat_ok"},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := SmokeCompletion(context.Background(), server.URL, "sk-test", "kimi-k2.6", profile)
|
||||
if err != nil {
|
||||
t.Fatalf("SmokeCompletion() error = %v", err)
|
||||
}
|
||||
if result.HTTPStatus != http.StatusOK {
|
||||
t.Fatalf("HTTPStatus = %d, want %d", result.HTTPStatus, http.StatusOK)
|
||||
}
|
||||
if result.Classification != "chat_completions" {
|
||||
t.Fatalf("Classification = %q, want %q", result.Classification, "chat_completions")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user