Files
sub2api-cn-relay-manager/internal/host/sub2api/client.go

362 lines
11 KiB
Go

package sub2api
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// HostAdapter is the set of operations this package performs against a sub2api
// host. Every method takes a context.Context as its first argument; methods are
// grouped below by the host resource they operate on.
type HostAdapter interface {
	// Host metadata.
	GetHostVersion(ctx context.Context) (string, error)
	ProbeCapabilities(ctx context.Context) (HostCapabilities, error)

	// Groups.
	CreateGroup(ctx context.Context, req CreateGroupRequest) (GroupRef, error)
	DeleteGroup(ctx context.Context, groupID string) error

	// Channels.
	CreateChannel(ctx context.Context, req CreateChannelRequest) (ChannelRef, error)
	UpdateChannel(ctx context.Context, channelID string, req CreateChannelRequest) error
	DeleteChannel(ctx context.Context, channelID string) error

	// Plans.
	CreatePlan(ctx context.Context, req CreatePlanRequest) (PlanRef, error)
	DeletePlan(ctx context.Context, planID string) error

	// Accounts.
	CreateAccount(ctx context.Context, req CreateAccountRequest) (AccountRef, error)
	BatchCreateAccounts(ctx context.Context, req BatchCreateAccountsRequest) ([]AccountRef, error)
	DeleteAccount(ctx context.Context, accountID string) error
	TestAccount(ctx context.Context, accountID, modelID string) (ProbeResult, error)
	GetAccountModels(ctx context.Context, accountID string) ([]AccountModel, error)
	DisableOpenAIResponsesAPI(ctx context.Context, accountIDs []string) error

	// Subscriptions.
	EnsureSubscriptionAccess(ctx context.Context, req EnsureSubscriptionAccessRequest) (SubscriptionAccessRef, error)
	AssignSubscription(ctx context.Context, req AssignSubscriptionRequest) (SubscriptionRef, error)

	// Gateway checks.
	CheckGatewayAccess(ctx context.Context, req GatewayAccessCheckRequest) (GatewayAccessResult, error)
	CheckGatewayCompletion(ctx context.Context, req GatewayCompletionCheckRequest) (GatewayCompletionResult, error)

	// Inventory of resources managed on the host.
	ListManagedResources(ctx context.Context, req ListManagedResourcesRequest) (ManagedResourceSnapshot, error)
}
// Capability and group payload types.
type (
	// HostCapabilities is the feature set returned by
	// HostAdapter.ProbeCapabilities. Every flag is always serialized, under a
	// snake_case JSON key.
	HostCapabilities struct {
		Groups        bool `json:"groups"`
		Channels      bool `json:"channels"`
		Plans         bool `json:"plans"`
		Accounts      bool `json:"accounts"`
		AccountTest   bool `json:"account_test"`
		AccountModels bool `json:"account_models"`
		Subscriptions bool `json:"subscriptions"`
	}

	// CreateGroupRequest is the payload for HostAdapter.CreateGroup.
	// Platform and SubscriptionType are dropped from JSON when empty;
	// RateMultiplier is always sent, even when zero.
	CreateGroupRequest struct {
		Name             string  `json:"name"`
		Platform         string  `json:"platform,omitempty"`
		RateMultiplier   float64 `json:"rate_multiplier"`
		SubscriptionType string  `json:"subscription_type,omitempty"`
	}

	// GroupRef identifies a group on the host by ID and name.
	GroupRef struct {
		ID   string `json:"id"`
		Name string `json:"name"`
	}
)
// Channel payload types.
type (
	// CreateChannelRequest is the payload for HostAdapter.CreateChannel and
	// HostAdapter.UpdateChannel. Platform is tagged `json:"-"` and is
	// therefore never serialized; it is available to the caller-side code
	// only.
	CreateChannelRequest struct {
		Name               string                `json:"name"`
		GroupIDs           []string              `json:"group_ids"`
		ModelMapping       map[string]string     `json:"model_mapping,omitempty"`
		ModelPricing       []ChannelModelPricing `json:"model_pricing,omitempty"`
		Platform           string                `json:"-"`
		RestrictModels     bool                  `json:"restrict_models,omitempty"`
		BillingModelSource string                `json:"billing_model_source,omitempty"`
	}

	// ChannelModelPricing is one pricing entry attached to a channel. Prices
	// are pointers so that "unset" (nil, omitted from JSON) is distinct from an
	// explicit zero price (non-nil, serialized).
	ChannelModelPricing struct {
		Platform         string               `json:"platform,omitempty"`
		Models           []string             `json:"models,omitempty"`
		BillingMode      string               `json:"billing_mode,omitempty"`
		InputPrice       *float64             `json:"input_price,omitempty"`
		OutputPrice      *float64             `json:"output_price,omitempty"`
		CacheWritePrice  *float64             `json:"cache_write_price,omitempty"`
		CacheReadPrice   *float64             `json:"cache_read_price,omitempty"`
		ImageOutputPrice *float64             `json:"image_output_price,omitempty"`
		PerRequestPrice  *float64             `json:"per_request_price,omitempty"`
		Intervals        []ChannelPricingTier `json:"intervals,omitempty"`
	}

	// ChannelPricingTier is one interval inside ChannelModelPricing.Intervals.
	// A nil MaxTokens is omitted from JSON; the intended meaning (presumably
	// an open-ended upper bound) should be confirmed against the host API.
	ChannelPricingTier struct {
		MinTokens       int      `json:"min_tokens,omitempty"`
		MaxTokens       *int     `json:"max_tokens,omitempty"`
		TierLabel       string   `json:"tier_label,omitempty"`
		InputPrice      *float64 `json:"input_price,omitempty"`
		OutputPrice     *float64 `json:"output_price,omitempty"`
		CacheWritePrice *float64 `json:"cache_write_price,omitempty"`
		CacheReadPrice  *float64 `json:"cache_read_price,omitempty"`
		PerRequestPrice *float64 `json:"per_request_price,omitempty"`
		SortOrder       int      `json:"sort_order,omitempty"`
	}

	// ChannelRef identifies a channel on the host by ID and name.
	ChannelRef struct {
		ID   string `json:"id"`
		Name string `json:"name"`
	}
)
// Plan payload types.
type (
	// CreatePlanRequest is the payload for HostAdapter.CreatePlan. All fields
	// are always serialized (no omitempty).
	CreatePlanRequest struct {
		GroupID      string  `json:"group_id"`
		Name         string  `json:"name"`
		Price        float64 `json:"price"`
		ValidityDays int     `json:"validity_days"`
		ValidityUnit string  `json:"validity_unit"`
	}

	// PlanRef identifies a plan on the host by ID and name.
	PlanRef struct {
		ID   string `json:"id"`
		Name string `json:"name"`
	}
)
// Account payload types.
type (
	// CreateAccountRequest is the payload for HostAdapter.CreateAccount and
	// each element of BatchCreateAccountsRequest. The shape of Credentials
	// depends on the host platform/type — not visible here.
	CreateAccountRequest struct {
		Name        string         `json:"name"`
		Platform    string         `json:"platform"`
		Type        string         `json:"type"`
		Credentials map[string]any `json:"credentials"`
		GroupIDs    []string       `json:"group_ids"`
	}

	// BatchCreateAccountsRequest is the payload for
	// HostAdapter.BatchCreateAccounts.
	BatchCreateAccountsRequest struct {
		Accounts []CreateAccountRequest `json:"accounts"`
	}

	// AccountRef identifies an account on the host; only ID is always
	// serialized.
	AccountRef struct {
		ID       string `json:"id"`
		Name     string `json:"name,omitempty"`
		Platform string `json:"platform,omitempty"`
		Type     string `json:"type,omitempty"`
	}

	// ProbeResult is the outcome returned by HostAdapter.TestAccount.
	ProbeResult struct {
		OK      bool   `json:"ok"`
		Status  string `json:"status"`
		Message string `json:"message,omitempty"`
	}

	// AccountModel is one entry returned by HostAdapter.GetAccountModels.
	AccountModel struct {
		ID          string `json:"id"`
		DisplayName string `json:"display_name"`
		Type        string `json:"type"`
	}
)
// Subscription and gateway payload types.
type (
	// AssignSubscriptionRequest is the payload for
	// HostAdapter.AssignSubscription. Note that DurationDays is serialized
	// under the JSON key "validity_days" and dropped when zero.
	AssignSubscriptionRequest struct {
		UserID       string `json:"user_id"`
		GroupID      string `json:"group_id"`
		DurationDays int    `json:"validity_days,omitempty"`
	}

	// EnsureSubscriptionAccessRequest is the input to
	// HostAdapter.EnsureSubscriptionAccess. It carries no JSON tags.
	EnsureSubscriptionAccessRequest struct {
		UserSelector string
		GroupID      string
	}

	// SubscriptionAccessRef is the result of
	// HostAdapter.EnsureSubscriptionAccess. It carries no JSON tags.
	SubscriptionAccessRef struct {
		UserID string
		APIKey string
	}

	// SubscriptionRef identifies a subscription on the host.
	SubscriptionRef struct {
		ID string `json:"id"`
	}

	// GatewayCompletionCheckRequest is the input to
	// HostAdapter.CheckGatewayCompletion. It carries no JSON tags.
	GatewayCompletionCheckRequest struct {
		APIKey    string
		Model     string
		Prompt    string
		MaxTokens int
	}

	// GatewayCompletionResult is the outcome of
	// HostAdapter.CheckGatewayCompletion.
	GatewayCompletionResult struct {
		OK          bool   `json:"ok"`
		StatusCode  int    `json:"status_code"`
		ContentType string `json:"content_type,omitempty"`
		BodyPreview string `json:"body_preview,omitempty"`
	}
)
// Client is the HTTP implementation of the sub2api host API. Build one with
// NewClient: the zero value has neither a base URL nor an HTTP client and will
// panic when used.
type Client struct {
	baseURL     *url.URL
	httpClient  *http.Client
	apiKey      string
	bearerToken string
}

// Option customizes a Client while NewClient builds it.
type Option func(*Client)

// WithHTTPClient replaces the default HTTP client (15s timeout).
// A nil httpClient is ignored and the default is kept. Storing nil would make
// every later request panic inside (*http.Client).Do.
func WithHTTPClient(httpClient *http.Client) Option {
	return func(client *Client) {
		if httpClient == nil {
			return
		}
		client.httpClient = httpClient
	}
}

// WithAPIKey sets the key sent in the x-api-key header. Surrounding whitespace
// is trimmed. When non-empty it takes precedence over a bearer token.
func WithAPIKey(apiKey string) Option {
	return func(client *Client) {
		client.apiKey = strings.TrimSpace(apiKey)
	}
}

// WithBearerToken sets the token sent as "Authorization: Bearer <token>".
// Surrounding whitespace is trimmed. It is only sent when no API key is
// configured.
func WithBearerToken(token string) Option {
	return func(client *Client) {
		client.bearerToken = strings.TrimSpace(token)
	}
}

// NewClient returns a Client rooted at baseURL. baseURL is trimmed and must be
// an absolute URL with scheme and host. It must not carry a query string or
// fragment, because request URLs are built by appending API paths to it.
// Nil options are skipped.
func NewClient(baseURL string, opts ...Option) (*Client, error) {
	parsedURL, err := url.Parse(strings.TrimSpace(baseURL))
	if err != nil {
		return nil, fmt.Errorf("parse base url: %w", err)
	}
	if parsedURL.Scheme == "" || parsedURL.Host == "" {
		return nil, fmt.Errorf("base url must include scheme and host")
	}
	// Paths are appended to the base URL's string form. A query ("?..." or a
	// bare trailing "?") or a fragment would therefore end up in front of the
	// API path and produce a wrong request URL.
	if parsedURL.RawQuery != "" || parsedURL.ForceQuery || parsedURL.Fragment != "" {
		return nil, fmt.Errorf("base url must not include a query or fragment")
	}
	client := &Client{
		baseURL: parsedURL,
		httpClient: &http.Client{
			Timeout: 15 * time.Second,
		},
	}
	for _, opt := range opts {
		if opt != nil {
			opt(client)
		}
	}
	return client, nil
}
// HTTPError reports a non-2xx response from the sub2api host. Body holds the
// response body exactly as received, untruncated, for callers using errors.As.
type HTTPError struct {
	Method     string
	Path       string
	StatusCode int
	Body       string
}

// Error formats the failure as "sub2api METHOD PATH returned CODE: BODY".
// BODY is the whitespace-trimmed response body. The ": BODY" suffix is
// omitted when the body is empty. A long body is cut on a rune boundary so
// that, for example, a large HTML error page does not flood logs.
func (e *HTTPError) Error() string {
	const maxBodyRunes = 512

	msg := fmt.Sprintf("sub2api %s %s returned %d", e.Method, e.Path, e.StatusCode)
	body := strings.TrimSpace(e.Body)
	if body == "" {
		return msg
	}
	// Byte length bounds rune count from above, so the rune conversion is only
	// paid for bodies that might actually need truncating.
	if len(body) > maxBodyRunes {
		if runes := []rune(body); len(runes) > maxBodyRunes {
			body = string(runes[:maxBodyRunes]) + "...(truncated)"
		}
	}
	return msg + ": " + body
}
// GetHostVersion fetches the host's version from
// GET /api/v1/admin/system/version. The body is decoded with
// decodeEnvelopeObject, so the version may sit under "data".
// Non-2xx responses are returned as *HTTPError. The version is returned with
// surrounding whitespace trimmed, and an empty/blank version is an error.
func (c *Client) GetHostVersion(ctx context.Context) (string, error) {
	const path = "/api/v1/admin/system/version"

	statusCode, _, body, err := c.perform(ctx, http.MethodGet, path, nil)
	if err != nil {
		return "", err
	}
	if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices {
		return "", newHTTPError(http.MethodGet, path, statusCode, body)
	}
	var payload struct {
		Version string `json:"version"`
	}
	if err := decodeEnvelopeObject(body, &payload); err != nil {
		return "", fmt.Errorf("decode host version: %w", err)
	}
	// Validate and return the same trimmed value, so a version such as
	// " 1.2.3\n" is not handed to callers with stray whitespace.
	version := strings.TrimSpace(payload.Version)
	if version == "" {
		return "", fmt.Errorf("decode host version: missing data.version")
	}
	return version, nil
}
// perform sends a single request to the host and returns three things: the
// status code, a copy of the response headers, and the fully read response
// body.
//
// A non-nil requestBody is JSON-encoded and sent with
// Content-Type: application/json. Authentication headers come from applyAuth.
// The status code is not inspected here, so non-2xx responses are returned
// without error and callers decide how to treat them. The error is non-nil
// only when one of these steps fails: marshalling, request construction,
// transport, or reading the body.
func (c *Client) perform(ctx context.Context, method, path string, requestBody any) (int, http.Header, []byte, error) {
	hasBody := requestBody != nil

	// Keep this a nil io.Reader (not a typed nil) when there is no body.
	var payloadReader io.Reader
	if hasBody {
		encoded, marshalErr := json.Marshal(requestBody)
		if marshalErr != nil {
			return 0, nil, nil, fmt.Errorf("marshal %s %s request: %w", method, path, marshalErr)
		}
		payloadReader = bytes.NewReader(encoded)
	}

	target := c.resolvePath(path)
	req, buildErr := http.NewRequestWithContext(ctx, method, target, payloadReader)
	if buildErr != nil {
		return 0, nil, nil, fmt.Errorf("build %s %s request: %w", method, path, buildErr)
	}
	req.Header.Set("Accept", "application/json, text/event-stream")
	if hasBody {
		req.Header.Set("Content-Type", "application/json")
	}
	c.applyAuth(req)

	resp, doErr := c.httpClient.Do(req)
	if doErr != nil {
		return 0, nil, nil, fmt.Errorf("perform %s %s request: %w", method, path, doErr)
	}
	defer resp.Body.Close()

	raw, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return 0, nil, nil, fmt.Errorf("read %s %s response: %w", method, path, readErr)
	}
	return resp.StatusCode, resp.Header.Clone(), raw, nil
}
// postJSON POSTs requestBody as JSON to path.
// A non-2xx response is returned as *HTTPError. On success, the body is
// decoded into dest via decodeEnvelopeObject, which unwraps a non-null "data"
// member when present. A nil dest skips decoding entirely.
func (c *Client) postJSON(ctx context.Context, path string, requestBody any, dest any) error {
	status, _, respBody, err := c.perform(ctx, http.MethodPost, path, requestBody)
	switch {
	case err != nil:
		return err
	case status < http.StatusOK || status >= http.StatusMultipleChoices:
		return newHTTPError(http.MethodPost, path, status, respBody)
	case dest == nil:
		return nil
	}

	if decodeErr := decodeEnvelopeObject(respBody, dest); decodeErr != nil {
		return fmt.Errorf("decode %s response: %w", path, decodeErr)
	}
	return nil
}
// getJSON GETs path from the host.
// A non-2xx response is returned as *HTTPError. On success, the body is
// decoded into dest via decodeEnvelopeObject, which unwraps a non-null "data"
// member when present.
//
// A nil dest skips decoding, matching postJSON. Previously a nil dest was
// passed to json.Unmarshal, which always fails with InvalidUnmarshalError even
// though the request itself succeeded.
func (c *Client) getJSON(ctx context.Context, path string, dest any) error {
	statusCode, _, body, err := c.perform(ctx, http.MethodGet, path, nil)
	if err != nil {
		return err
	}
	if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices {
		return newHTTPError(http.MethodGet, path, statusCode, body)
	}
	if dest == nil {
		return nil
	}
	if err := decodeEnvelopeObject(body, dest); err != nil {
		return fmt.Errorf("decode %s response: %w", path, err)
	}
	return nil
}
// resolvePath appends path to the client's base URL with exactly one "/"
// between them. Every trailing slash on the base and every leading slash on
// path is stripped first. The join is plain string concatenation, not
// URL reference resolution.
func (c *Client) resolvePath(path string) string {
	relative := strings.TrimLeft(path, "/")
	return strings.TrimRight(c.baseURL.String(), "/") + "/" + relative
}
// applyAuth sets at most one credential header on req. A non-empty API key is
// sent as x-api-key and wins over any bearer token. Otherwise a non-empty
// bearer token is sent as "Authorization: Bearer <token>". With neither
// configured, req is left untouched.
func (c *Client) applyAuth(req *http.Request) {
	switch {
	case c.apiKey != "":
		req.Header.Set("x-api-key", c.apiKey)
	case c.bearerToken != "":
		req.Header.Set("Authorization", "Bearer "+c.bearerToken)
	}
}
// newHTTPError builds the *HTTPError describing a non-2xx response. The raw
// response bytes are copied into the string Body field.
func newHTTPError(method, path string, statusCode int, body []byte) *HTTPError {
	httpErr := &HTTPError{
		Method:     method,
		Path:       path,
		StatusCode: statusCode,
	}
	httpErr.Body = string(body)
	return httpErr
}
// decodeEnvelopeObject unmarshals a host response into dest.
//
// If body is a JSON object with a "data" member that is present and not null,
// only that member is decoded into dest. In every other case the whole body is
// decoded into dest: invalid JSON, a non-object top-level value, a missing
// "data" member, or "data": null. Errors are returned unwrapped from
// encoding/json.
func decodeEnvelopeObject(body []byte, dest any) error {
	var wrapper struct {
		Data json.RawMessage `json:"data"`
	}
	if json.Unmarshal(body, &wrapper) != nil {
		return json.Unmarshal(body, dest)
	}

	inner := bytes.TrimSpace(wrapper.Data)
	if len(inner) == 0 || string(inner) == "null" {
		return json.Unmarshal(body, dest)
	}
	return json.Unmarshal(wrapper.Data, dest)
}