362 lines
11 KiB
Go
362 lines
11 KiB
Go
package sub2api
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
type HostAdapter interface {
|
|
GetHostVersion(ctx context.Context) (string, error)
|
|
ProbeCapabilities(ctx context.Context) (HostCapabilities, error)
|
|
CreateGroup(ctx context.Context, req CreateGroupRequest) (GroupRef, error)
|
|
DeleteGroup(ctx context.Context, groupID string) error
|
|
CreateChannel(ctx context.Context, req CreateChannelRequest) (ChannelRef, error)
|
|
UpdateChannel(ctx context.Context, channelID string, req CreateChannelRequest) error
|
|
DeleteChannel(ctx context.Context, channelID string) error
|
|
CreatePlan(ctx context.Context, req CreatePlanRequest) (PlanRef, error)
|
|
DeletePlan(ctx context.Context, planID string) error
|
|
CreateAccount(ctx context.Context, req CreateAccountRequest) (AccountRef, error)
|
|
BatchCreateAccounts(ctx context.Context, req BatchCreateAccountsRequest) ([]AccountRef, error)
|
|
DeleteAccount(ctx context.Context, accountID string) error
|
|
TestAccount(ctx context.Context, accountID, modelID string) (ProbeResult, error)
|
|
GetAccountModels(ctx context.Context, accountID string) ([]AccountModel, error)
|
|
EnsureSubscriptionAccess(ctx context.Context, req EnsureSubscriptionAccessRequest) (SubscriptionAccessRef, error)
|
|
AssignSubscription(ctx context.Context, req AssignSubscriptionRequest) (SubscriptionRef, error)
|
|
CheckGatewayAccess(ctx context.Context, req GatewayAccessCheckRequest) (GatewayAccessResult, error)
|
|
CheckGatewayCompletion(ctx context.Context, req GatewayCompletionCheckRequest) (GatewayCompletionResult, error)
|
|
DisableOpenAIResponsesAPI(ctx context.Context, accountIDs []string) error
|
|
ListManagedResources(ctx context.Context, req ListManagedResourcesRequest) (ManagedResourceSnapshot, error)
|
|
}
|
|
|
|
// HostCapabilities reports which feature areas a sub2api host supports; it is
// the result type of HostAdapter.ProbeCapabilities. All flags are always
// serialized (no omitempty), so false is sent explicitly.
//
// NOTE(review): each flag presumably gates the HostAdapter methods with the
// matching name (e.g. Groups for CreateGroup/DeleteGroup, AccountTest for
// TestAccount, AccountModels for GetAccountModels) — confirm against the
// ProbeCapabilities implementation.
type HostCapabilities struct {
	Groups        bool `json:"groups"`
	Channels      bool `json:"channels"`
	Plans         bool `json:"plans"`
	Accounts      bool `json:"accounts"`
	AccountTest   bool `json:"account_test"`
	AccountModels bool `json:"account_models"`
	Subscriptions bool `json:"subscriptions"`
}
|
|
|
|
// CreateGroupRequest describes a group to create via HostAdapter.CreateGroup.
// The json tags define its wire field names.
type CreateGroupRequest struct {
	Name string `json:"name"`
	// Platform is omitted from the JSON when empty.
	Platform string `json:"platform,omitempty"`
	// RateMultiplier has no omitempty, so a zero value is sent explicitly as 0.
	// NOTE(review): presumably scales usage charges for the group — confirm
	// with the host API.
	RateMultiplier float64 `json:"rate_multiplier"`
	// SubscriptionType is omitted from the JSON when empty.
	SubscriptionType string `json:"subscription_type,omitempty"`
}
|
|
|
|
// GroupRef identifies a group on the host, as returned by
// HostAdapter.CreateGroup.
//
// NOTE(review): ID is presumably the value later passed as a group ID
// (DeleteGroup, CreateChannelRequest.GroupIDs, CreatePlanRequest.GroupID, ...).
type GroupRef struct {
	ID   string `json:"id"`
	Name string `json:"name"`
}
|
|
|
|
// CreateChannelRequest is the payload for HostAdapter.CreateChannel; the same
// shape is also passed to HostAdapter.UpdateChannel.
type CreateChannelRequest struct {
	Name string `json:"name"`
	// GroupIDs has no omitempty: a nil slice is serialized as null and an
	// empty slice as [].
	GroupIDs []string `json:"group_ids"`
	// ModelMapping is omitted when nil or empty.
	// NOTE(review): presumably maps a requested model name to the model name
	// used upstream — confirm with the host API.
	ModelMapping map[string]string `json:"model_mapping,omitempty"`
	// ModelPricing is omitted when nil or empty.
	ModelPricing []ChannelModelPricing `json:"model_pricing,omitempty"`
	// Platform is never serialized (json:"-"); it is only visible to Go code
	// that builds or consumes the request.
	Platform string `json:"-"`
	// RestrictModels is omitted when false.
	RestrictModels bool `json:"restrict_models,omitempty"`
	// BillingModelSource is omitted when empty.
	BillingModelSource string `json:"billing_model_source,omitempty"`
}
|
|
|
|
// ChannelModelPricing is one pricing rule in CreateChannelRequest.ModelPricing.
//
// Prices are pointers so an unset price (nil) is omitted from the JSON while
// an explicit 0 is still sent. Empty strings and slices are omitted too.
//
// NOTE(review): price units (per token, per million tokens, ...) and the
// accepted BillingMode values are defined by the host API, not visible here.
type ChannelModelPricing struct {
	Platform         string   `json:"platform,omitempty"`
	Models           []string `json:"models,omitempty"`
	BillingMode      string   `json:"billing_mode,omitempty"`
	InputPrice       *float64 `json:"input_price,omitempty"`
	OutputPrice      *float64 `json:"output_price,omitempty"`
	CacheWritePrice  *float64 `json:"cache_write_price,omitempty"`
	CacheReadPrice   *float64 `json:"cache_read_price,omitempty"`
	ImageOutputPrice *float64 `json:"image_output_price,omitempty"`
	PerRequestPrice  *float64 `json:"per_request_price,omitempty"`
	// Intervals holds optional tiered prices; omitted when nil or empty.
	Intervals []ChannelPricingTier `json:"intervals,omitempty"`
}
|
|
|
|
// ChannelPricingTier is one entry of ChannelModelPricing.Intervals, bounded by
// MinTokens/MaxTokens. Zero-valued ints and strings and nil prices are
// omitted from the JSON.
//
// NOTE(review): a nil MaxTokens presumably means the tier has no upper bound,
// and tiers are presumably evaluated by SortOrder — confirm with the host API.
type ChannelPricingTier struct {
	MinTokens       int      `json:"min_tokens,omitempty"`
	MaxTokens       *int     `json:"max_tokens,omitempty"`
	TierLabel       string   `json:"tier_label,omitempty"`
	InputPrice      *float64 `json:"input_price,omitempty"`
	OutputPrice     *float64 `json:"output_price,omitempty"`
	CacheWritePrice *float64 `json:"cache_write_price,omitempty"`
	CacheReadPrice  *float64 `json:"cache_read_price,omitempty"`
	PerRequestPrice *float64 `json:"per_request_price,omitempty"`
	SortOrder       int      `json:"sort_order,omitempty"`
}
|
|
|
|
// ChannelRef identifies a channel on the host, as returned by
// HostAdapter.CreateChannel.
type ChannelRef struct {
	ID   string `json:"id"`
	Name string `json:"name"`
}
|
|
|
|
// CreatePlanRequest is the payload for HostAdapter.CreatePlan. No field uses
// omitempty, so zero values are sent explicitly.
//
// NOTE(review): the accepted ValidityUnit values, how they interact with
// ValidityDays, and the currency/units of Price are not visible here —
// confirm with the host API.
type CreatePlanRequest struct {
	GroupID      string  `json:"group_id"`
	Name         string  `json:"name"`
	Price        float64 `json:"price"`
	ValidityDays int     `json:"validity_days"`
	ValidityUnit string  `json:"validity_unit"`
}
|
|
|
|
// PlanRef identifies a plan on the host, as returned by
// HostAdapter.CreatePlan.
type PlanRef struct {
	ID   string `json:"id"`
	Name string `json:"name"`
}
|
|
|
|
// CreateAccountRequest describes one account to create, either directly via
// HostAdapter.CreateAccount or as an element of BatchCreateAccountsRequest.
// No field uses omitempty; a nil Credentials map or GroupIDs slice is sent as
// null.
//
// NOTE(review): the expected Credentials keys presumably depend on
// Platform/Type, and the map likely carries secrets — avoid logging it.
type CreateAccountRequest struct {
	Name        string         `json:"name"`
	Platform    string         `json:"platform"`
	Type        string         `json:"type"`
	Credentials map[string]any `json:"credentials"`
	GroupIDs    []string       `json:"group_ids"`
}
|
|
|
|
// BatchCreateAccountsRequest is the payload for
// HostAdapter.BatchCreateAccounts.
type BatchCreateAccountsRequest struct {
	Accounts []CreateAccountRequest `json:"accounts"`
}
|
|
|
|
// AccountRef identifies an account on the host, as returned by
// HostAdapter.CreateAccount and HostAdapter.BatchCreateAccounts. Only ID is
// always serialized; the other fields are omitted when empty.
type AccountRef struct {
	ID       string `json:"id"`
	Name     string `json:"name,omitempty"`
	Platform string `json:"platform,omitempty"`
	Type     string `json:"type,omitempty"`
}
|
|
|
|
// ProbeResult is the outcome of HostAdapter.TestAccount. Message is omitted
// from the JSON when empty.
//
// NOTE(review): the set of Status values is defined by the implementation and
// is not visible here.
type ProbeResult struct {
	OK      bool   `json:"ok"`
	Status  string `json:"status"`
	Message string `json:"message,omitempty"`
}
|
|
|
|
// AccountModel is one model reported for an account by
// HostAdapter.GetAccountModels.
type AccountModel struct {
	ID          string `json:"id"`
	DisplayName string `json:"display_name"`
	Type        string `json:"type"`
}
|
|
|
|
// AssignSubscriptionRequest is the payload for HostAdapter.AssignSubscription.
type AssignSubscriptionRequest struct {
	UserID  string `json:"user_id"`
	GroupID string `json:"group_id"`
	// DurationDays is serialized under the key "validity_days" (not
	// "duration_days") and omitted when 0.
	// NOTE(review): omission presumably lets the host apply its default
	// validity — confirm.
	DurationDays int `json:"validity_days,omitempty"`
}
|
|
|
|
// EnsureSubscriptionAccessRequest is the input to
// HostAdapter.EnsureSubscriptionAccess. It has no json tags, so it is
// presumably consumed by Go code rather than sent verbatim to the host.
//
// NOTE(review): the accepted UserSelector formats (user ID, email, ...) are
// not visible here — confirm with the implementation.
type EnsureSubscriptionAccessRequest struct {
	UserSelector string
	GroupID      string
}
|
|
|
|
// SubscriptionAccessRef is the result of HostAdapter.EnsureSubscriptionAccess.
//
// NOTE(review): APIKey is presumably a usable credential for UserID — avoid
// logging it.
type SubscriptionAccessRef struct {
	UserID string
	APIKey string
}
|
|
|
|
// SubscriptionRef identifies a subscription, as returned by
// HostAdapter.AssignSubscription.
type SubscriptionRef struct {
	ID string `json:"id"`
}
|
|
|
|
// GatewayCompletionCheckRequest is the input to
// HostAdapter.CheckGatewayCompletion. It has no json tags.
//
// NOTE(review): APIKey is presumably the key used to authenticate against the
// gateway, and MaxTokens the completion token limit — confirm with the
// implementation.
type GatewayCompletionCheckRequest struct {
	APIKey    string
	Model     string
	Prompt    string
	MaxTokens int
}
|
|
|
|
// GatewayCompletionResult is the outcome of HostAdapter.CheckGatewayCompletion.
// ContentType and BodyPreview are omitted from the JSON when empty.
//
// NOTE(review): BodyPreview is presumably a truncated copy of the gateway's
// response body — confirm with the implementation.
type GatewayCompletionResult struct {
	OK          bool   `json:"ok"`
	StatusCode  int    `json:"status_code"`
	ContentType string `json:"content_type,omitempty"`
	BodyPreview string `json:"body_preview,omitempty"`
}
|
|
|
|
// Client talks to a sub2api host over HTTP. Construct it with NewClient: the
// zero value has neither a base URL nor an HTTP client and cannot send
// requests.
type Client struct {
	// baseURL is the host URL (with scheme and host, as checked by NewClient)
	// that request paths are appended to.
	baseURL *url.URL
	// httpClient sends every request; NewClient defaults it to a client with a
	// 15-second timeout.
	httpClient *http.Client
	// apiKey, when non-empty, is sent in the x-api-key header and takes
	// precedence over bearerToken.
	apiKey string
	// bearerToken is sent as "Authorization: Bearer <token>" only when apiKey
	// is empty.
	bearerToken string
}
|
|
|
|
// Option customizes a Client while NewClient builds it. Options are applied in
// the order given, and nil options are skipped.
type Option func(client *Client)
|
|
|
|
func WithHTTPClient(httpClient *http.Client) Option {
|
|
return func(client *Client) {
|
|
client.httpClient = httpClient
|
|
}
|
|
}
|
|
|
|
func WithAPIKey(apiKey string) Option {
|
|
return func(client *Client) {
|
|
client.apiKey = strings.TrimSpace(apiKey)
|
|
}
|
|
}
|
|
|
|
func WithBearerToken(token string) Option {
|
|
return func(client *Client) {
|
|
client.bearerToken = strings.TrimSpace(token)
|
|
}
|
|
}
|
|
|
|
func NewClient(baseURL string, opts ...Option) (*Client, error) {
|
|
parsedURL, err := url.Parse(strings.TrimSpace(baseURL))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse base url: %w", err)
|
|
}
|
|
|
|
if parsedURL.Scheme == "" || parsedURL.Host == "" {
|
|
return nil, fmt.Errorf("base url must include scheme and host")
|
|
}
|
|
|
|
client := &Client{
|
|
baseURL: parsedURL,
|
|
httpClient: &http.Client{
|
|
Timeout: 15 * time.Second,
|
|
},
|
|
}
|
|
for _, opt := range opts {
|
|
if opt != nil {
|
|
opt(client)
|
|
}
|
|
}
|
|
|
|
return client, nil
|
|
}
|
|
|
|
// HTTPError is returned when the host answers a request with a non-2xx
// status. It records the request method, the request path as passed to the
// client (relative to the base URL), the status code, and the raw response
// body. Inspect it with errors.As.
type HTTPError struct {
	Method     string
	Path       string
	StatusCode int
	// Body is the complete response body as text, untrimmed.
	Body string
}
|
|
|
|
func (e *HTTPError) Error() string {
|
|
return fmt.Sprintf("sub2api %s %s returned %d: %s", e.Method, e.Path, e.StatusCode, strings.TrimSpace(e.Body))
|
|
}
|
|
|
|
func (c *Client) GetHostVersion(ctx context.Context) (string, error) {
|
|
statusCode, _, body, err := c.perform(ctx, http.MethodGet, "/api/v1/admin/system/version", nil)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices {
|
|
return "", newHTTPError(http.MethodGet, "/api/v1/admin/system/version", statusCode, body)
|
|
}
|
|
|
|
var payload struct {
|
|
Version string `json:"version"`
|
|
}
|
|
if err := decodeEnvelopeObject(body, &payload); err != nil {
|
|
return "", fmt.Errorf("decode host version: %w", err)
|
|
}
|
|
if strings.TrimSpace(payload.Version) == "" {
|
|
return "", fmt.Errorf("decode host version: missing data.version")
|
|
}
|
|
|
|
return payload.Version, nil
|
|
}
|
|
|
|
func (c *Client) perform(ctx context.Context, method, path string, requestBody any) (int, http.Header, []byte, error) {
|
|
var bodyReader io.Reader
|
|
if requestBody != nil {
|
|
payload, err := json.Marshal(requestBody)
|
|
if err != nil {
|
|
return 0, nil, nil, fmt.Errorf("marshal %s %s request: %w", method, path, err)
|
|
}
|
|
bodyReader = bytes.NewReader(payload)
|
|
}
|
|
|
|
requestURL := c.resolvePath(path)
|
|
req, err := http.NewRequestWithContext(ctx, method, requestURL, bodyReader)
|
|
if err != nil {
|
|
return 0, nil, nil, fmt.Errorf("build %s %s request: %w", method, path, err)
|
|
}
|
|
if requestBody != nil {
|
|
req.Header.Set("Content-Type", "application/json")
|
|
}
|
|
req.Header.Set("Accept", "application/json, text/event-stream")
|
|
c.applyAuth(req)
|
|
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return 0, nil, nil, fmt.Errorf("perform %s %s request: %w", method, path, err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return 0, nil, nil, fmt.Errorf("read %s %s response: %w", method, path, err)
|
|
}
|
|
|
|
return resp.StatusCode, resp.Header.Clone(), body, nil
|
|
}
|
|
|
|
func (c *Client) postJSON(ctx context.Context, path string, requestBody any, dest any) error {
|
|
statusCode, _, body, err := c.perform(ctx, http.MethodPost, path, requestBody)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices {
|
|
return newHTTPError(http.MethodPost, path, statusCode, body)
|
|
}
|
|
if dest == nil {
|
|
return nil
|
|
}
|
|
if err := decodeEnvelopeObject(body, dest); err != nil {
|
|
return fmt.Errorf("decode %s response: %w", path, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *Client) getJSON(ctx context.Context, path string, dest any) error {
|
|
statusCode, _, body, err := c.perform(ctx, http.MethodGet, path, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices {
|
|
return newHTTPError(http.MethodGet, path, statusCode, body)
|
|
}
|
|
if err := decodeEnvelopeObject(body, dest); err != nil {
|
|
return fmt.Errorf("decode %s response: %w", path, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *Client) resolvePath(path string) string {
|
|
base := strings.TrimRight(c.baseURL.String(), "/")
|
|
return base + "/" + strings.TrimLeft(path, "/")
|
|
}
|
|
|
|
func (c *Client) applyAuth(req *http.Request) {
|
|
if c.apiKey != "" {
|
|
req.Header.Set("x-api-key", c.apiKey)
|
|
return
|
|
}
|
|
if c.bearerToken != "" {
|
|
req.Header.Set("Authorization", "Bearer "+c.bearerToken)
|
|
}
|
|
}
|
|
|
|
func newHTTPError(method, path string, statusCode int, body []byte) *HTTPError {
|
|
return &HTTPError{
|
|
Method: method,
|
|
Path: path,
|
|
StatusCode: statusCode,
|
|
Body: string(body),
|
|
}
|
|
}
|
|
|
|
// decodeEnvelopeObject unmarshals a host response body into dest.
//
// If the body parses as a JSON object whose "data" member is present and not
// null, only that member is decoded into dest. In every other case the whole
// body is decoded into dest. That covers a body that does not parse as an
// object (e.g. an array) and an object with a missing, empty, or null "data".
//
// NOTE(review): any top-level "data" key is treated as an envelope, even for
// a payload that legitimately has such a field.
func decodeEnvelopeObject(body []byte, dest any) error {
	var wrapper struct {
		Data json.RawMessage `json:"data"`
	}
	if json.Unmarshal(body, &wrapper) != nil {
		return json.Unmarshal(body, dest)
	}

	inner := bytes.TrimSpace(wrapper.Data)
	if len(inner) == 0 || string(inner) == "null" {
		return json.Unmarshal(body, dest)
	}
	return json.Unmarshal(wrapper.Data, dest)
}
|