fix(provision): reconcile channel pricing and hosted access
This commit is contained in:
@@ -1,6 +1,10 @@
|
||||
package sub2api
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func (c *Client) CreateChannel(ctx context.Context, req CreateChannelRequest) (ChannelRef, error) {
|
||||
var ref ChannelRef
|
||||
@@ -9,3 +13,15 @@ func (c *Client) CreateChannel(ctx context.Context, req CreateChannelRequest) (C
|
||||
}
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
func (c *Client) UpdateChannel(ctx context.Context, channelID string, req CreateChannelRequest) error {
|
||||
path := fmt.Sprintf("/api/v1/admin/channels/%s", channelID)
|
||||
statusCode, _, body, err := c.perform(ctx, http.MethodPut, path, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices {
|
||||
return newHTTPError(http.MethodPut, path, statusCode, body)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ type HostAdapter interface {
|
||||
CreateGroup(ctx context.Context, req CreateGroupRequest) (GroupRef, error)
|
||||
DeleteGroup(ctx context.Context, groupID string) error
|
||||
CreateChannel(ctx context.Context, req CreateChannelRequest) (ChannelRef, error)
|
||||
UpdateChannel(ctx context.Context, channelID string, req CreateChannelRequest) error
|
||||
DeleteChannel(ctx context.Context, channelID string) error
|
||||
CreatePlan(ctx context.Context, req CreatePlanRequest) (PlanRef, error)
|
||||
DeletePlan(ctx context.Context, planID string) error
|
||||
@@ -26,6 +27,7 @@ type HostAdapter interface {
|
||||
DeleteAccount(ctx context.Context, accountID string) error
|
||||
TestAccount(ctx context.Context, accountID string) (ProbeResult, error)
|
||||
GetAccountModels(ctx context.Context, accountID string) ([]AccountModel, error)
|
||||
EnsureSubscriptionAccess(ctx context.Context, req EnsureSubscriptionAccessRequest) (SubscriptionAccessRef, error)
|
||||
AssignSubscription(ctx context.Context, req AssignSubscriptionRequest) (SubscriptionRef, error)
|
||||
CheckGatewayAccess(ctx context.Context, req GatewayAccessCheckRequest) (GatewayAccessResult, error)
|
||||
ListManagedResources(ctx context.Context, req ListManagedResourcesRequest) (ManagedResourceSnapshot, error)
|
||||
@@ -54,11 +56,38 @@ type GroupRef struct {
|
||||
}
|
||||
|
||||
type CreateChannelRequest struct {
|
||||
Name string `json:"name"`
|
||||
GroupIDs []string `json:"group_ids"`
|
||||
ModelMapping map[string]string `json:"model_mapping,omitempty"`
|
||||
RestrictModels bool `json:"restrict_models,omitempty"`
|
||||
BillingModelSource string `json:"billing_model_source,omitempty"`
|
||||
Name string `json:"name"`
|
||||
GroupIDs []string `json:"group_ids"`
|
||||
ModelMapping map[string]string `json:"model_mapping,omitempty"`
|
||||
ModelPricing []ChannelModelPricing `json:"model_pricing,omitempty"`
|
||||
Platform string `json:"-"`
|
||||
RestrictModels bool `json:"restrict_models,omitempty"`
|
||||
BillingModelSource string `json:"billing_model_source,omitempty"`
|
||||
}
|
||||
|
||||
type ChannelModelPricing struct {
|
||||
Platform string `json:"platform,omitempty"`
|
||||
Models []string `json:"models,omitempty"`
|
||||
BillingMode string `json:"billing_mode,omitempty"`
|
||||
InputPrice *float64 `json:"input_price,omitempty"`
|
||||
OutputPrice *float64 `json:"output_price,omitempty"`
|
||||
CacheWritePrice *float64 `json:"cache_write_price,omitempty"`
|
||||
CacheReadPrice *float64 `json:"cache_read_price,omitempty"`
|
||||
ImageOutputPrice *float64 `json:"image_output_price,omitempty"`
|
||||
PerRequestPrice *float64 `json:"per_request_price,omitempty"`
|
||||
Intervals []ChannelPricingTier `json:"intervals,omitempty"`
|
||||
}
|
||||
|
||||
type ChannelPricingTier struct {
|
||||
MinTokens int `json:"min_tokens,omitempty"`
|
||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||
TierLabel string `json:"tier_label,omitempty"`
|
||||
InputPrice *float64 `json:"input_price,omitempty"`
|
||||
OutputPrice *float64 `json:"output_price,omitempty"`
|
||||
CacheWritePrice *float64 `json:"cache_write_price,omitempty"`
|
||||
CacheReadPrice *float64 `json:"cache_read_price,omitempty"`
|
||||
PerRequestPrice *float64 `json:"per_request_price,omitempty"`
|
||||
SortOrder int `json:"sort_order,omitempty"`
|
||||
}
|
||||
|
||||
type ChannelRef struct {
|
||||
@@ -116,6 +145,16 @@ type AssignSubscriptionRequest struct {
|
||||
DurationDays int `json:"validity_days,omitempty"`
|
||||
}
|
||||
|
||||
type EnsureSubscriptionAccessRequest struct {
|
||||
UserSelector string
|
||||
GroupID string
|
||||
}
|
||||
|
||||
type SubscriptionAccessRef struct {
|
||||
UserID string
|
||||
APIKey string
|
||||
}
|
||||
|
||||
type SubscriptionRef struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
@@ -48,12 +48,41 @@ func flexibleIDSliceValues(raw []string) []any {
|
||||
}
|
||||
|
||||
func (r CreateChannelRequest) MarshalJSON() ([]byte, error) {
|
||||
modelMapping := map[string]map[string]string{}
|
||||
platform := strings.TrimSpace(r.Platform)
|
||||
if platform == "" {
|
||||
platform = "openai"
|
||||
}
|
||||
if len(r.ModelMapping) > 0 {
|
||||
inner := make(map[string]string, len(r.ModelMapping))
|
||||
for key, value := range r.ModelMapping {
|
||||
inner[key] = value
|
||||
}
|
||||
modelMapping[platform] = inner
|
||||
}
|
||||
modelPricing := make([]ChannelModelPricing, 0, len(r.ModelPricing))
|
||||
for _, entry := range r.ModelPricing {
|
||||
pricing := entry
|
||||
if strings.TrimSpace(pricing.Platform) == "" {
|
||||
pricing.Platform = platform
|
||||
}
|
||||
modelPricing = append(modelPricing, pricing)
|
||||
}
|
||||
|
||||
return json.Marshal(struct {
|
||||
Name string `json:"name"`
|
||||
GroupIDs []any `json:"group_ids"`
|
||||
Name string `json:"name"`
|
||||
GroupIDs []any `json:"group_ids"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping,omitempty"`
|
||||
ModelPricing []ChannelModelPricing `json:"model_pricing,omitempty"`
|
||||
RestrictModels bool `json:"restrict_models,omitempty"`
|
||||
BillingModelSource string `json:"billing_model_source,omitempty"`
|
||||
}{
|
||||
Name: r.Name,
|
||||
GroupIDs: flexibleIDSliceValues(r.GroupIDs),
|
||||
Name: r.Name,
|
||||
GroupIDs: flexibleIDSliceValues(r.GroupIDs),
|
||||
ModelMapping: modelMapping,
|
||||
ModelPricing: modelPricing,
|
||||
RestrictModels: r.RestrictModels,
|
||||
BillingModelSource: r.BillingModelSource,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -591,8 +592,23 @@ func TestCreateGroupWithMock(t *testing.T) {
|
||||
func TestCreateChannelWithMock(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
Name string `json:"name"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||
ModelPricing []struct {
|
||||
Platform string `json:"platform"`
|
||||
Models []string `json:"models"`
|
||||
BillingMode string `json:"billing_mode"`
|
||||
InputPrice *float64 `json:"input_price"`
|
||||
OutputPrice *float64 `json:"output_price"`
|
||||
CacheWritePrice *float64 `json:"cache_write_price"`
|
||||
CacheReadPrice *float64 `json:"cache_read_price"`
|
||||
ImageOutputPrice *float64 `json:"image_output_price"`
|
||||
PerRequestPrice *float64 `json:"per_request_price"`
|
||||
Intervals []any `json:"intervals"`
|
||||
} `json:"model_pricing"`
|
||||
RestrictModels bool `json:"restrict_models"`
|
||||
BillingModelSource string `json:"billing_model_source"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("decode request: %v", err)
|
||||
@@ -603,11 +619,36 @@ func TestCreateChannelWithMock(t *testing.T) {
|
||||
if len(req.GroupIDs) != 1 || req.GroupIDs[0] != 101 {
|
||||
t.Fatalf("group_ids = %v, want [101]", req.GroupIDs)
|
||||
}
|
||||
if req.ModelMapping["openai"]["deepseek-v4-pro"] != "deepseek-v4-pro" {
|
||||
t.Fatalf("model_mapping = %+v, want openai/deepseek-v4-pro passthrough", req.ModelMapping)
|
||||
}
|
||||
if len(req.ModelPricing) != 1 {
|
||||
t.Fatalf("model_pricing len = %d, want 1", len(req.ModelPricing))
|
||||
}
|
||||
if req.ModelPricing[0].Platform != "openai" || req.ModelPricing[0].BillingMode != "token" {
|
||||
t.Fatalf("model_pricing[0] = %+v, want openai/token entry", req.ModelPricing[0])
|
||||
}
|
||||
if len(req.ModelPricing[0].Models) != 1 || req.ModelPricing[0].Models[0] != "deepseek-v4-pro" {
|
||||
t.Fatalf("model_pricing[0].models = %v, want [deepseek-v4-pro]", req.ModelPricing[0].Models)
|
||||
}
|
||||
if !req.RestrictModels {
|
||||
t.Fatal("restrict_models = false, want true")
|
||||
}
|
||||
if req.BillingModelSource != "channel_mapped" {
|
||||
t.Fatalf("billing_model_source = %q, want channel_mapped", req.BillingModelSource)
|
||||
}
|
||||
w.Write([]byte(`{"data":{"id":201,"name":"ch"}}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
||||
ref, err := client.CreateChannel(context.Background(), CreateChannelRequest{Name: "ch", GroupIDs: []string{"101"}})
|
||||
ref, err := client.CreateChannel(context.Background(), CreateChannelRequest{
|
||||
Name: "ch",
|
||||
GroupIDs: []string{"101"},
|
||||
ModelMapping: map[string]string{"deepseek-v4-pro": "deepseek-v4-pro"},
|
||||
ModelPricing: []ChannelModelPricing{{Platform: "openai", Models: []string{"deepseek-v4-pro"}, BillingMode: "token"}},
|
||||
RestrictModels: true,
|
||||
BillingModelSource: "channel_mapped",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -699,6 +740,66 @@ func TestAssignSubscriptionWithMock(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureSubscriptionAccessWithMock(t *testing.T) {
|
||||
var calls []string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
calls = append(calls, r.Method+" "+r.URL.Path)
|
||||
switch {
|
||||
case r.Method == http.MethodGet && strings.HasPrefix(r.URL.RequestURI(), "/api/v1/admin/users?"):
|
||||
w.Write([]byte(`{"data":{"items":[]}}`))
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/users":
|
||||
w.Write([]byte(`{"data":{"id":84,"email":"relay-sub-user-1@sub2api.local"}}`))
|
||||
case r.Method == http.MethodPut && r.URL.Path == "/api/v1/admin/users/84":
|
||||
w.Write([]byte(`{"data":{"id":84}}`))
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/users/84/balance":
|
||||
w.Write([]byte(`{"data":{"id":84}}`))
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/subscriptions/assign":
|
||||
var req struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
DurationDays int `json:"validity_days"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("decode assign subscription request: %v", err)
|
||||
}
|
||||
if req.UserID != 84 || req.GroupID != 101 || req.DurationDays != 30 {
|
||||
t.Fatalf("unexpected assign subscription request: %+v", req)
|
||||
}
|
||||
w.Write([]byte(`{"data":{"id":401}}`))
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/auth/login":
|
||||
w.Write([]byte(`{"data":{"access_token":"user-jwt"}}`))
|
||||
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/keys":
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("decode managed key request: %v", err)
|
||||
}
|
||||
if _, ok := req["group_id"]; ok {
|
||||
t.Fatalf("managed key request unexpectedly carried group_id: %+v", req)
|
||||
}
|
||||
w.Write([]byte(`{"data":{"id":501,"key":"sk-relay-key","name":"managed-key"}}`))
|
||||
case r.Method == http.MethodPut && r.URL.Path == "/api/v1/admin/api-keys/501":
|
||||
w.Write([]byte(`{"data":{"api_key":{"id":501}}}`))
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
client, _ := NewClient(srv.URL, WithBearerToken("admin-token"))
|
||||
ref, err := client.EnsureSubscriptionAccess(context.Background(), EnsureSubscriptionAccessRequest{UserSelector: "crm-user-1", GroupID: "101"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if ref.UserID != "84" {
|
||||
t.Fatalf("user id = %q, want 84", ref.UserID)
|
||||
}
|
||||
if !strings.HasPrefix(ref.APIKey, "sk-relay-") {
|
||||
t.Fatalf("api key = %q, want managed sk-relay-* key", ref.APIKey)
|
||||
}
|
||||
if len(calls) < 7 {
|
||||
t.Fatalf("calls = %v, want managed subscription setup sequence", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckGatewayAccessWithMock(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(`{"data":[{"id":"gpt-4"},{"id":"claude-3"}]}`))
|
||||
@@ -741,12 +842,19 @@ func TestBatchCreateAccountsWithMock(t *testing.T) {
|
||||
if len(acct.GroupIDs) != 1 || acct.GroupIDs[0] != 101 {
|
||||
t.Fatalf("group_ids = %v, want [101]", acct.GroupIDs)
|
||||
}
|
||||
rawMapping, ok := acct.Credentials["model_mapping"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("credentials = %+v, want model_mapping map", acct.Credentials)
|
||||
}
|
||||
if got, _ := rawMapping["deepseek-v4-pro"].(string); got != "deepseek-v4-pro" {
|
||||
t.Fatalf("model_mapping = %+v, want deepseek-v4-pro passthrough", rawMapping)
|
||||
}
|
||||
w.Write([]byte(`{"data":[{"id":601,"name":"acct1"}]}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
client, _ := NewClient(srv.URL, WithAPIKey("k"))
|
||||
refs, err := client.BatchCreateAccounts(context.Background(), BatchCreateAccountsRequest{
|
||||
Accounts: []CreateAccountRequest{{Name: "acct1", Platform: "openai", Type: "apikey", GroupIDs: []string{"101"}, Credentials: map[string]any{"api_key": "sk-test", "base_url": "https://api.example.com"}}},
|
||||
Accounts: []CreateAccountRequest{{Name: "acct1", Platform: "openai", Type: "apikey", GroupIDs: []string{"101"}, Credentials: map[string]any{"api_key": "sk-test", "base_url": "https://api.example.com", "model_mapping": map[string]string{"deepseek-v4-pro": "deepseek-v4-pro"}}}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
320
internal/host/sub2api/subscription_access.go
Normal file
320
internal/host/sub2api/subscription_access.go
Normal file
@@ -0,0 +1,320 @@
|
||||
package sub2api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
managedSubscriptionBalance = 10.0
|
||||
managedSubscriptionValidityDays = 30
|
||||
)
|
||||
|
||||
type adminUserRecord struct {
|
||||
ID int64 `json:"id"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
type adminAPIKeyRecord struct {
|
||||
ID int64 `json:"id"`
|
||||
Key string `json:"key"`
|
||||
Name string `json:"name"`
|
||||
Group *struct {
|
||||
ID int64 `json:"id"`
|
||||
} `json:"group,omitempty"`
|
||||
GroupID *int64 `json:"group_id,omitempty"`
|
||||
}
|
||||
|
||||
type authTokenPair struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
}
|
||||
|
||||
func (c *Client) EnsureSubscriptionAccess(ctx context.Context, req EnsureSubscriptionAccessRequest) (SubscriptionAccessRef, error) {
|
||||
if c == nil {
|
||||
return SubscriptionAccessRef{}, fmt.Errorf("client is required")
|
||||
}
|
||||
selector := strings.TrimSpace(req.UserSelector)
|
||||
groupID := strings.TrimSpace(req.GroupID)
|
||||
if selector == "" {
|
||||
return SubscriptionAccessRef{}, fmt.Errorf("user selector is required")
|
||||
}
|
||||
if groupID == "" {
|
||||
return SubscriptionAccessRef{}, fmt.Errorf("group id is required")
|
||||
}
|
||||
groupInt, err := strconv.ParseInt(groupID, 10, 64)
|
||||
if err != nil {
|
||||
return SubscriptionAccessRef{}, fmt.Errorf("parse group id %q: %w", groupID, err)
|
||||
}
|
||||
|
||||
identity := buildManagedSubscriptionIdentity(selector, groupID)
|
||||
user, err := c.findManagedSubscriptionUser(ctx, identity.Email)
|
||||
if err != nil {
|
||||
return SubscriptionAccessRef{}, err
|
||||
}
|
||||
if user == nil {
|
||||
user, err = c.createManagedSubscriptionUser(ctx, identity, groupInt)
|
||||
if err != nil {
|
||||
return SubscriptionAccessRef{}, err
|
||||
}
|
||||
}
|
||||
if err := c.updateManagedSubscriptionUser(ctx, user.ID, groupInt); err != nil {
|
||||
return SubscriptionAccessRef{}, err
|
||||
}
|
||||
if err := c.setManagedSubscriptionBalance(ctx, user.ID); err != nil {
|
||||
return SubscriptionAccessRef{}, err
|
||||
}
|
||||
if err := c.ensureManagedSubscriptionAssignment(ctx, user.ID, groupID); err != nil {
|
||||
return SubscriptionAccessRef{}, err
|
||||
}
|
||||
|
||||
userClient, err := c.loginAsManagedSubscriptionUser(ctx, identity.Email, identity.Password)
|
||||
if err != nil {
|
||||
return SubscriptionAccessRef{}, err
|
||||
}
|
||||
keyRecord, err := c.ensureManagedSubscriptionAPIKey(ctx, userClient, user.ID, identity)
|
||||
if err != nil {
|
||||
return SubscriptionAccessRef{}, err
|
||||
}
|
||||
if err := c.bindManagedSubscriptionAPIKey(ctx, keyRecord.ID, groupInt); err != nil {
|
||||
return SubscriptionAccessRef{}, err
|
||||
}
|
||||
return SubscriptionAccessRef{UserID: strconv.FormatInt(user.ID, 10), APIKey: identity.CustomKey}, nil
|
||||
}
|
||||
|
||||
type managedSubscriptionIdentity struct {
|
||||
Email string
|
||||
Username string
|
||||
Password string
|
||||
CustomKey string
|
||||
KeyName string
|
||||
}
|
||||
|
||||
func buildManagedSubscriptionIdentity(selector, groupID string) managedSubscriptionIdentity {
|
||||
normalizedSelector := strings.TrimSpace(selector)
|
||||
seedMaterial := strings.ToLower(normalizedSelector) + "|" + strings.TrimSpace(groupID)
|
||||
sum := sha256.Sum256([]byte(seedMaterial))
|
||||
hash := hex.EncodeToString(sum[:])
|
||||
prefix := sanitizeManagedSubscriptionPrefix(normalizedSelector)
|
||||
if prefix == "" {
|
||||
prefix = "relay-sub"
|
||||
}
|
||||
prefix = truncateManagedSubscriptionToken(prefix, 24)
|
||||
shortHash := hash[:16]
|
||||
keyHash := hash[:32]
|
||||
username := truncateManagedSubscriptionToken(prefix+"-"+shortHash[:8], 32)
|
||||
return managedSubscriptionIdentity{
|
||||
Email: fmt.Sprintf("%s-%s@sub2api.local", prefix, shortHash),
|
||||
Username: username,
|
||||
Password: "RelayPwd!" + hash[:12],
|
||||
CustomKey: "sk-relay-" + keyHash,
|
||||
KeyName: truncateManagedSubscriptionToken(username+"-key", 48),
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeManagedSubscriptionPrefix(value string) string {
|
||||
value = strings.ToLower(strings.TrimSpace(value))
|
||||
var b strings.Builder
|
||||
lastDash := false
|
||||
for _, r := range value {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z', r >= '0' && r <= '9':
|
||||
b.WriteRune(r)
|
||||
lastDash = false
|
||||
case !lastDash:
|
||||
b.WriteByte('-')
|
||||
lastDash = true
|
||||
}
|
||||
}
|
||||
return strings.Trim(b.String(), "-")
|
||||
}
|
||||
|
||||
func truncateManagedSubscriptionToken(value string, max int) string {
|
||||
if len(value) <= max {
|
||||
return value
|
||||
}
|
||||
return strings.Trim(value[:max], "-")
|
||||
}
|
||||
|
||||
func (c *Client) findManagedSubscriptionUser(ctx context.Context, email string) (*adminUserRecord, error) {
|
||||
statusCode, _, body, err := c.perform(ctx, http.MethodGet, "/api/v1/admin/users?search="+url.QueryEscape(email)+"&page=1&page_size=20&sort_by=created_at&sort_order=desc", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list admin users: %w", err)
|
||||
}
|
||||
if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices {
|
||||
return nil, newHTTPError(http.MethodGet, "/api/v1/admin/users", statusCode, body)
|
||||
}
|
||||
var envelope struct {
|
||||
Data struct {
|
||||
Items []adminUserRecord `json:"items"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &envelope); err != nil {
|
||||
return nil, fmt.Errorf("decode admin users response: %w", err)
|
||||
}
|
||||
for _, item := range envelope.Data.Items {
|
||||
if strings.EqualFold(strings.TrimSpace(item.Email), email) {
|
||||
user := item
|
||||
return &user, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *Client) createManagedSubscriptionUser(ctx context.Context, identity managedSubscriptionIdentity, groupID int64) (*adminUserRecord, error) {
|
||||
payload := map[string]any{
|
||||
"email": identity.Email,
|
||||
"password": identity.Password,
|
||||
"username": identity.Username,
|
||||
"notes": "managed by sub2api-cn-relay-manager",
|
||||
"balance": managedSubscriptionBalance,
|
||||
"concurrency": 5,
|
||||
"allowed_groups": []int64{groupID},
|
||||
}
|
||||
statusCode, _, body, err := c.perform(ctx, http.MethodPost, "/api/v1/admin/users", payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create admin user: %w", err)
|
||||
}
|
||||
if statusCode == http.StatusConflict {
|
||||
return c.findManagedSubscriptionUser(ctx, identity.Email)
|
||||
}
|
||||
if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices {
|
||||
return nil, newHTTPError(http.MethodPost, "/api/v1/admin/users", statusCode, body)
|
||||
}
|
||||
var user adminUserRecord
|
||||
if err := decodeEnvelopeObject(body, &user); err != nil {
|
||||
return nil, fmt.Errorf("decode created admin user: %w", err)
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (c *Client) updateManagedSubscriptionUser(ctx context.Context, userID, groupID int64) error {
|
||||
payload := map[string]any{"allowed_groups": []int64{groupID}}
|
||||
statusCode, _, body, err := c.perform(ctx, http.MethodPut, fmt.Sprintf("/api/v1/admin/users/%d", userID), payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update admin user groups: %w", err)
|
||||
}
|
||||
if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices {
|
||||
return newHTTPError(http.MethodPut, fmt.Sprintf("/api/v1/admin/users/%d", userID), statusCode, body)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) setManagedSubscriptionBalance(ctx context.Context, userID int64) error {
|
||||
payload := map[string]any{"balance": managedSubscriptionBalance, "operation": "set", "notes": "managed by sub2api-cn-relay-manager"}
|
||||
statusCode, _, body, err := c.perform(ctx, http.MethodPost, fmt.Sprintf("/api/v1/admin/users/%d/balance", userID), payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set admin user balance: %w", err)
|
||||
}
|
||||
if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices {
|
||||
return newHTTPError(http.MethodPost, fmt.Sprintf("/api/v1/admin/users/%d/balance", userID), statusCode, body)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) ensureManagedSubscriptionAssignment(ctx context.Context, userID int64, groupID string) error {
|
||||
_, err := c.AssignSubscription(ctx, AssignSubscriptionRequest{
|
||||
UserID: strconv.FormatInt(userID, 10),
|
||||
GroupID: groupID,
|
||||
DurationDays: managedSubscriptionValidityDays,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("assign managed subscription: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) loginAsManagedSubscriptionUser(ctx context.Context, email, password string) (*Client, error) {
|
||||
anon := c.cloneWithAuth("", "")
|
||||
payload := map[string]any{"email": email, "password": password, "turnstile_token": ""}
|
||||
statusCode, _, body, err := anon.perform(ctx, http.MethodPost, "/api/v1/auth/login", payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login managed subscription user: %w", err)
|
||||
}
|
||||
if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices {
|
||||
return nil, newHTTPError(http.MethodPost, "/api/v1/auth/login", statusCode, body)
|
||||
}
|
||||
var tokenPair authTokenPair
|
||||
if err := decodeEnvelopeObject(body, &tokenPair); err != nil {
|
||||
return nil, fmt.Errorf("decode managed user login response: %w", err)
|
||||
}
|
||||
if strings.TrimSpace(tokenPair.AccessToken) == "" {
|
||||
return nil, fmt.Errorf("managed user login returned empty access token")
|
||||
}
|
||||
return c.cloneWithAuth("", tokenPair.AccessToken), nil
|
||||
}
|
||||
|
||||
func (c *Client) ensureManagedSubscriptionAPIKey(ctx context.Context, userClient *Client, userID int64, identity managedSubscriptionIdentity) (*adminAPIKeyRecord, error) {
|
||||
payload := map[string]any{
|
||||
"name": identity.KeyName,
|
||||
"custom_key": identity.CustomKey,
|
||||
}
|
||||
statusCode, _, body, err := userClient.perform(ctx, http.MethodPost, "/api/v1/keys", payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create managed api key: %w", err)
|
||||
}
|
||||
if statusCode >= http.StatusOK && statusCode < http.StatusMultipleChoices {
|
||||
var key adminAPIKeyRecord
|
||||
if err := decodeEnvelopeObject(body, &key); err != nil {
|
||||
return nil, fmt.Errorf("decode created api key: %w", err)
|
||||
}
|
||||
return &key, nil
|
||||
}
|
||||
if statusCode != http.StatusConflict && statusCode != http.StatusBadRequest {
|
||||
return nil, newHTTPError(http.MethodPost, "/api/v1/keys", statusCode, body)
|
||||
}
|
||||
return c.findManagedSubscriptionAPIKey(ctx, userID, identity)
|
||||
}
|
||||
|
||||
func (c *Client) findManagedSubscriptionAPIKey(ctx context.Context, userID int64, identity managedSubscriptionIdentity) (*adminAPIKeyRecord, error) {
|
||||
statusCode, _, body, err := c.perform(ctx, http.MethodGet, fmt.Sprintf("/api/v1/admin/users/%d/api-keys?page=1&page_size=100&sort_by=created_at&sort_order=desc", userID), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list managed api keys: %w", err)
|
||||
}
|
||||
if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices {
|
||||
return nil, newHTTPError(http.MethodGet, fmt.Sprintf("/api/v1/admin/users/%d/api-keys", userID), statusCode, body)
|
||||
}
|
||||
var envelope struct {
|
||||
Data struct {
|
||||
Items []adminAPIKeyRecord `json:"items"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &envelope); err != nil {
|
||||
return nil, fmt.Errorf("decode admin api keys response: %w", err)
|
||||
}
|
||||
for _, item := range envelope.Data.Items {
|
||||
if strings.TrimSpace(item.Key) == identity.CustomKey || strings.TrimSpace(item.Name) == identity.KeyName {
|
||||
key := item
|
||||
return &key, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("managed api key %q not found for user %d", identity.KeyName, userID)
|
||||
}
|
||||
|
||||
func (c *Client) bindManagedSubscriptionAPIKey(ctx context.Context, keyID, groupID int64) error {
|
||||
payload := map[string]any{"group_id": groupID}
|
||||
statusCode, _, body, err := c.perform(ctx, http.MethodPut, fmt.Sprintf("/api/v1/admin/api-keys/%d", keyID), payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bind managed api key group: %w", err)
|
||||
}
|
||||
if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices {
|
||||
return newHTTPError(http.MethodPut, fmt.Sprintf("/api/v1/admin/api-keys/%d", keyID), statusCode, body)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) cloneWithAuth(apiKey, bearerToken string) *Client {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
clone := *c
|
||||
clone.apiKey = strings.TrimSpace(apiKey)
|
||||
clone.bearerToken = strings.TrimSpace(bearerToken)
|
||||
return &clone
|
||||
}
|
||||
Reference in New Issue
Block a user