Files
sub2api-cn-relay-manager/internal/probe/capability.go

158 lines
4.8 KiB
Go

package probe
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"
)
// TransportProfile records gateway-wide probe findings: which API surfaces
// responded, how requests are authenticated, how model IDs are written, and
// any advisory codes collected along the way.
type TransportProfile struct {
	// SupportsOpenAIModels is set by ProbeCapabilities to whether at least
	// one model ID was supplied to it.
	SupportsOpenAIModels bool `json:"supports_openai_models"`
	// SupportsOpenAIChatCompletions is true once any per-model
	// /v1/chat/completions probe returned a 2xx status.
	SupportsOpenAIChatCompletions bool `json:"supports_openai_chat_completions"`
	// SupportsOpenAIResponses is true when the /v1/responses probe returned
	// a 2xx status.
	SupportsOpenAIResponses bool `json:"supports_openai_responses"`
	// SupportsAnthropicMessages is not assigned by ProbeCapabilities in this
	// file. NOTE(review): presumably filled in elsewhere — confirm.
	SupportsAnthropicMessages bool `json:"supports_anthropic_messages"`
	// AuthStyle names the authentication scheme; ProbeCapabilities always
	// records "bearer".
	AuthStyle string `json:"auth_style"`
	// ModelIDStyle is the value produced by detectModelIDStyle:
	// "vendor_prefixed" or "canonical".
	ModelIDStyle string `json:"model_id_style"`
	// KnownAdvisories holds advisory codes such as
	// "initial_probe_race_expected"; appendAdvisory keeps it duplicate-free.
	KnownAdvisories []string `json:"known_advisories"`
}
// ModelCapabilityProfile records what was learned about a single model ID
// during probing.
type ModelCapabilityProfile struct {
	// RawModelID is the model ID as supplied by the caller, with surrounding
	// whitespace trimmed.
	RawModelID string `json:"raw_model_id"`
	// NormalizedModelID is the result of NormalizeModelID for the model.
	NormalizedModelID string `json:"normalized_model_id"`
	// CanonicalModelFamily is the result of CanonicalModelFamily for the model.
	CanonicalModelFamily string `json:"canonical_model_family"`
	// SupportsStream is a string-valued flag; ProbeCapabilities records
	// "unknown". Other values are not defined in this file.
	SupportsStream string `json:"supports_stream"`
	// SupportsTools is a string-valued flag; ProbeCapabilities records
	// "unknown". Other values are not defined in this file.
	SupportsTools string `json:"supports_tools"`
	// SupportsReasoningFields is a string-valued flag; ProbeCapabilities
	// records "unknown". Other values are not defined in this file.
	SupportsReasoningFields string `json:"supports_reasoning_fields"`
	// SmokeChatOK is true when the model's /v1/chat/completions probe
	// returned a 2xx status.
	SmokeChatOK bool `json:"smoke_chat_ok"`
}
// CapabilityProfile is the complete result of a probe run: the gateway-wide
// transport findings together with one entry per probed model.
type CapabilityProfile struct {
	// TransportProfile holds the findings that apply to the gateway as a whole.
	TransportProfile TransportProfile `json:"transport_profile"`
	// ModelProfiles holds one entry per model ID passed to ProbeCapabilities,
	// in the order the IDs were given.
	ModelProfiles []ModelCapabilityProfile `json:"model_profiles"`
}
// ProbeCapabilities sends smoke-test requests to the gateway at baseURL and
// summarises which OpenAI-style endpoints answered successfully.
//
// One POST goes to /v1/responses using the first non-blank entry of rawModels
// (or the placeholder "ping" when there is none), followed by one POST to
// /v1/chat/completions for every entry of rawModels. A probe counts as
// successful when its HTTP status is in the 2xx range.
//
// If any probe fails at the request level, probing stops and that error is
// returned together with a nil profile.
func ProbeCapabilities(ctx context.Context, baseURL, apiKey string, rawModels []string) (*CapabilityProfile, error) {
	succeeded := func(status int) bool {
		return status >= http.StatusOK && status < http.StatusMultipleChoices
	}

	result := &CapabilityProfile{
		TransportProfile: TransportProfile{
			SupportsOpenAIModels: len(rawModels) > 0,
			AuthStyle:            "bearer",
			ModelIDStyle:         detectModelIDStyle(rawModels),
			KnownAdvisories:      []string{},
		},
		ModelProfiles: make([]ModelCapabilityProfile, 0, len(rawModels)),
	}
	transport := &result.TransportProfile

	responsesPayload := map[string]any{
		"model": firstNonEmptyModel(rawModels),
		"input": "ping",
	}
	status, err := probeJSONEndpoint(ctx, baseURL, apiKey, "/v1/responses", responsesPayload)
	if err != nil {
		return nil, err
	}
	transport.SupportsOpenAIResponses = succeeded(status)

	for _, id := range rawModels {
		entry := ModelCapabilityProfile{
			RawModelID:              strings.TrimSpace(id),
			NormalizedModelID:       NormalizeModelID(id),
			CanonicalModelFamily:    CanonicalModelFamily(id),
			SupportsStream:          "unknown",
			SupportsTools:           "unknown",
			SupportsReasoningFields: "unknown",
		}

		chatPayload := map[string]any{
			"model": entry.RawModelID,
			"messages": []map[string]string{
				{"role": "user", "content": "ping"},
			},
			"max_tokens":  8,
			"temperature": 0,
		}
		status, err := probeJSONEndpoint(ctx, baseURL, apiKey, "/v1/chat/completions", chatPayload)
		if err != nil {
			return nil, err
		}

		// A 403 can never also be a 2xx, so the two outcomes are exclusive.
		switch {
		case succeeded(status):
			entry.SmokeChatOK = true
			transport.SupportsOpenAIChatCompletions = true
		case status == http.StatusForbidden:
			appendAdvisory(&transport.KnownAdvisories, "initial_probe_race_expected")
		}

		result.ModelProfiles = append(result.ModelProfiles, entry)
	}

	if transport.SupportsOpenAIChatCompletions && !transport.SupportsOpenAIResponses {
		appendAdvisory(&transport.KnownAdvisories, "responses_unsupported_but_chat_ok")
	}
	return result, nil
}
// probeJSONEndpoint POSTs payload, JSON-encoded, to path on the gateway at
// baseURL and returns the HTTP status code of the reply.
//
// The request URL is resolved with joinGatewayPath. When apiKey is non-blank
// after trimming, it is sent as an "Authorization: Bearer" header. A nil
// payload results in an empty request body. Each request is limited to 15
// seconds in addition to any deadline carried by ctx.
//
// A non-nil error means the request could not be built or completed; any HTTP
// status, including 4xx and 5xx, is returned as a value with a nil error.
func probeJSONEndpoint(ctx context.Context, baseURL, apiKey, path string, payload any) (int, error) {
	// Upper bound on how much of the response body is consumed before closing.
	const maxDrainBytes = 64 << 10

	requestURL, err := joinGatewayPath(baseURL, path)
	if err != nil {
		return 0, fmt.Errorf("resolve %s endpoint: %w", path, err)
	}

	var body bytes.Buffer
	if payload != nil {
		if err := json.NewEncoder(&body).Encode(payload); err != nil {
			return 0, fmt.Errorf("encode %s probe payload: %w", path, err)
		}
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, &body)
	if err != nil {
		return 0, fmt.Errorf("build %s request: %w", path, err)
	}
	req.Header.Set("Content-Type", "application/json")
	if token := strings.TrimSpace(apiKey); token != "" {
		req.Header.Set("Authorization", "Bearer "+token)
	}

	resp, err := (&http.Client{Timeout: 15 * time.Second}).Do(req)
	if err != nil {
		return 0, fmt.Errorf("request %s: %w", path, err)
	}
	defer resp.Body.Close()

	// Only the status code is needed, but the transport can reuse a keep-alive
	// connection only if the body was read to EOF before Close. Probes hit the
	// same host once per model, so drain a bounded amount. This is best effort:
	// a read error here does not invalidate the status already received.
	_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxDrainBytes))

	return resp.StatusCode, nil
}
// detectModelIDStyle classifies how the IDs in rawModels are written. It
// returns "vendor_prefixed" as soon as any ID contains a "/", and "canonical"
// otherwise, including when rawModels is empty.
func detectModelIDStyle(rawModels []string) string {
	style := "canonical"
	for i := range rawModels {
		// Trimming whitespace cannot add or remove a slash, so the raw
		// string is inspected directly.
		if strings.IndexByte(rawModels[i], '/') >= 0 {
			style = "vendor_prefixed"
			break
		}
	}
	return style
}
// appendAdvisory adds advisory to the slice behind values unless it is the
// empty string or already present, so each advisory appears at most once and
// in first-seen order.
func appendAdvisory(values *[]string, advisory string) {
	if advisory == "" {
		return
	}
	current := *values
	for i := range current {
		if current[i] == advisory {
			return
		}
	}
	*values = append(current, advisory)
}
// firstNonEmptyModel returns the first entry of rawModels that is non-empty
// after trimming surrounding whitespace, in its trimmed form. When no such
// entry exists it returns the placeholder "ping".
func firstNonEmptyModel(rawModels []string) string {
	for i := range rawModels {
		candidate := strings.TrimSpace(rawModels[i])
		if candidate == "" {
			continue
		}
		return candidate
	}
	return "ping"
}