// 已验证的问题状态:
//  1. P0-07 补偿处理器 —— 已集成到 main.go ✅
//  2. P0-09 外键校验器 —— 已集成到 main.go 并被调用 ✅
//  3. 幂等协议 Idempotency-Key —— 已在 idempotency.go 中实现 ✅
//  4. 幂等唯一索引 —— 已在 SQL 中定义 ✅
//
// Gateway 修复:
//  - 修复 cors.go 语法错误(重复的函数定义)
//  - 修复 middleware_test.go 参数不匹配问题
//  - 将 go.mod 降级到 go 1.21,以解决依赖问题
package adapter
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"net/http"
|
||
"net/http/httptest"
|
||
"testing"
|
||
"time"
|
||
)
|
||
|
||
func TestNewOpenAIAdapter(t *testing.T) {
|
||
adapter := NewOpenAIAdapter("https://api.openai.com", "test-key", []string{"gpt-4", "gpt-3.5"})
|
||
|
||
if adapter.baseURL != "https://api.openai.com" {
|
||
t.Errorf("unexpected baseURL: %s", adapter.baseURL)
|
||
}
|
||
if adapter.apiKey != "test-key" {
|
||
t.Errorf("unexpected apiKey: %s", adapter.apiKey)
|
||
}
|
||
if len(adapter.models) != 2 {
|
||
t.Errorf("expected 2 models, got %d", len(adapter.models))
|
||
}
|
||
if adapter.httpClient == nil {
|
||
t.Error("httpClient should not be nil")
|
||
}
|
||
}
|
||
|
||
func TestOpenAIAdapter_ProviderName(t *testing.T) {
|
||
adapter := NewOpenAIAdapter("https://api.openai.com", "test-key", []string{"gpt-4"})
|
||
if adapter.ProviderName() != "openai" {
|
||
t.Errorf("expected openai, got %s", adapter.ProviderName())
|
||
}
|
||
}
|
||
|
||
func TestOpenAIAdapter_SupportedModels(t *testing.T) {
|
||
models := []string{"gpt-4", "gpt-3.5-turbo"}
|
||
adapter := NewOpenAIAdapter("https://api.openai.com", "test-key", models)
|
||
|
||
result := adapter.SupportedModels()
|
||
if len(result) != 2 {
|
||
t.Errorf("expected 2 models, got %d", len(result))
|
||
}
|
||
if result[0] != "gpt-4" {
|
||
t.Errorf("expected gpt-4, got %s", result[0])
|
||
}
|
||
}
|
||
|
||
func TestOpenAIAdapter_GetUsage(t *testing.T) {
|
||
adapter := NewOpenAIAdapter("https://api.openai.com", "test-key", []string{"gpt-4"})
|
||
|
||
resp := &CompletionResponse{
|
||
Usage: Usage{
|
||
PromptTokens: 10,
|
||
CompletionTokens: 5,
|
||
TotalTokens: 15,
|
||
},
|
||
}
|
||
|
||
usage := adapter.GetUsage(resp)
|
||
if usage.PromptTokens != 10 {
|
||
t.Errorf("expected 10, got %d", usage.PromptTokens)
|
||
}
|
||
if usage.TotalTokens != 15 {
|
||
t.Errorf("expected 15, got %d", usage.TotalTokens)
|
||
}
|
||
}
|
||
|
||
func TestOpenAIAdapter_MapError(t *testing.T) {
|
||
adapter := NewOpenAIAdapter("https://api.openai.com", "test-key", []string{"gpt-4"})
|
||
|
||
tests := []struct {
|
||
name string
|
||
errMsg string
|
||
wantCode string
|
||
wantHTTP int
|
||
wantRetryable bool
|
||
}{
|
||
{
|
||
name: "invalid_api_key",
|
||
errMsg: "invalid_api_key",
|
||
wantCode: "PROVIDER_001",
|
||
wantHTTP: 401,
|
||
wantRetryable: false,
|
||
},
|
||
{
|
||
name: "rate_limit",
|
||
errMsg: "rate_limit exceeded",
|
||
wantCode: "PROVIDER_002",
|
||
wantHTTP: 429,
|
||
wantRetryable: true,
|
||
},
|
||
{
|
||
name: "quota",
|
||
errMsg: "quota exceeded",
|
||
wantCode: "PROVIDER_003",
|
||
wantHTTP: 402,
|
||
wantRetryable: false,
|
||
},
|
||
{
|
||
name: "model_not_found",
|
||
errMsg: "model_not_found error",
|
||
wantCode: "PROVIDER_004",
|
||
wantHTTP: 404,
|
||
wantRetryable: false,
|
||
},
|
||
{
|
||
name: "unknown_error",
|
||
errMsg: "some unknown error",
|
||
wantCode: "PROVIDER_005",
|
||
wantHTTP: 502,
|
||
wantRetryable: true,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
provErr := adapter.MapError(&testError{msg: tt.errMsg})
|
||
if provErr.Code != tt.wantCode {
|
||
t.Errorf("expected code %s, got %s", tt.wantCode, provErr.Code)
|
||
}
|
||
if provErr.HTTPStatus != tt.wantHTTP {
|
||
t.Errorf("expected http status %d, got %d", tt.wantHTTP, provErr.HTTPStatus)
|
||
}
|
||
if provErr.Retryable != tt.wantRetryable {
|
||
t.Errorf("expected retryable %v, got %v", tt.wantRetryable, provErr.Retryable)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// testError is a minimal error implementation whose message is fully
// controlled by the test. It is used to feed arbitrary text into MapError.
type testError struct {
	msg string // returned verbatim by Error
}

// Error implements the error interface by returning the stored message.
func (te *testError) Error() string { return te.msg }
|
||
|
||
func TestOpenAIAdapter_ChatCompletion_Success(t *testing.T) {
|
||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
// 验证请求
|
||
if r.Header.Get("Content-Type") != "application/json" {
|
||
t.Error("expected Content-Type application/json")
|
||
}
|
||
if r.Header.Get("Authorization") != "Bearer test-key" {
|
||
t.Error("expected Authorization header")
|
||
}
|
||
|
||
// 返回模拟响应
|
||
resp := map[string]interface{}{
|
||
"id": "chatcmpl-123",
|
||
"object": "chat.completion",
|
||
"created": 1234567890,
|
||
"model": "gpt-4",
|
||
"choices": []map[string]interface{}{
|
||
{
|
||
"message": map[string]string{
|
||
"role": "assistant",
|
||
"content": "Hello!",
|
||
},
|
||
"finish_reason": "stop",
|
||
},
|
||
},
|
||
"usage": map[string]int{
|
||
"prompt_tokens": 10,
|
||
"completion_tokens": 5,
|
||
"total_tokens": 15,
|
||
},
|
||
}
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(resp)
|
||
}))
|
||
defer server.Close()
|
||
|
||
adapter := NewOpenAIAdapter(server.URL, "test-key", []string{"gpt-4"})
|
||
|
||
messages := []Message{
|
||
{Role: "user", Content: "Hi"},
|
||
}
|
||
|
||
resp, err := adapter.ChatCompletion(context.Background(), "gpt-4", messages, CompletionOptions{})
|
||
if err != nil {
|
||
t.Fatalf("unexpected error: %v", err)
|
||
}
|
||
|
||
if resp.ID != "chatcmpl-123" {
|
||
t.Errorf("expected chatcmpl-123, got %s", resp.ID)
|
||
}
|
||
if resp.Choices[0].Message.Content != "Hello!" {
|
||
t.Errorf("expected Hello!, got %s", resp.Choices[0].Message.Content)
|
||
}
|
||
if resp.Usage.TotalTokens != 15 {
|
||
t.Errorf("expected 15, got %d", resp.Usage.TotalTokens)
|
||
}
|
||
}
|
||
|
||
func TestOpenAIAdapter_ChatCompletion_WithOptions(t *testing.T) {
|
||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
var reqBody map[string]interface{}
|
||
json.NewDecoder(r.Body).Decode(&reqBody)
|
||
|
||
// 验证选项被正确传递
|
||
if reqBody["temperature"] != 0.7 {
|
||
t.Errorf("expected temperature 0.7, got %v", reqBody["temperature"])
|
||
}
|
||
if reqBody["max_tokens"] != 100.0 {
|
||
t.Errorf("expected max_tokens 100, got %v", reqBody["max_tokens"])
|
||
}
|
||
if reqBody["top_p"] != 0.9 {
|
||
t.Errorf("expected top_p 0.9, got %v", reqBody["top_p"])
|
||
}
|
||
|
||
resp := map[string]interface{}{
|
||
"id": "chatcmpl-123",
|
||
"object": "chat.completion",
|
||
"created": 1234567890,
|
||
"model": "gpt-4",
|
||
"choices": []map[string]interface{}{
|
||
{
|
||
"message": map[string]string{
|
||
"role": "assistant",
|
||
"content": "Hi",
|
||
},
|
||
"finish_reason": "stop",
|
||
},
|
||
},
|
||
"usage": map[string]int{
|
||
"prompt_tokens": 10,
|
||
"completion_tokens": 5,
|
||
"total_tokens": 15,
|
||
},
|
||
}
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(resp)
|
||
}))
|
||
defer server.Close()
|
||
|
||
adapter := NewOpenAIAdapter(server.URL, "test-key", []string{"gpt-4"})
|
||
|
||
messages := []Message{{Role: "user", Content: "Hi"}}
|
||
options := CompletionOptions{
|
||
Temperature: 0.7,
|
||
MaxTokens: 100,
|
||
TopP: 0.9,
|
||
}
|
||
|
||
_, err := adapter.ChatCompletion(context.Background(), "gpt-4", messages, options)
|
||
if err != nil {
|
||
t.Fatalf("unexpected error: %v", err)
|
||
}
|
||
}
|
||
|
||
func TestOpenAIAdapter_ChatCompletion_HTTPError(t *testing.T) {
|
||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
w.WriteHeader(http.StatusUnauthorized)
|
||
w.Write([]byte(`{"error": {"message": "invalid_api_key"}}`))
|
||
}))
|
||
defer server.Close()
|
||
|
||
adapter := NewOpenAIAdapter(server.URL, "wrong-key", []string{"gpt-4"})
|
||
|
||
_, err := adapter.ChatCompletion(context.Background(), "gpt-4", []Message{{Role: "user", Content: "Hi"}}, CompletionOptions{})
|
||
if err == nil {
|
||
t.Fatal("expected error")
|
||
}
|
||
|
||
provErr, ok := err.(ProviderError)
|
||
if !ok {
|
||
t.Fatalf("expected ProviderError, got %T", err)
|
||
}
|
||
if provErr.Code != "PROVIDER_001" {
|
||
t.Errorf("expected PROVIDER_001, got %s", provErr.Code)
|
||
}
|
||
if provErr.HTTPStatus != 401 {
|
||
t.Errorf("expected 401, got %d", provErr.HTTPStatus)
|
||
}
|
||
}
|
||
|
||
func TestOpenAIAdapter_ChatCompletion_RateLimitError(t *testing.T) {
|
||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
w.WriteHeader(http.StatusTooManyRequests)
|
||
w.Write([]byte(`{"error": {"message": "rate_limit_exceeded"}}`))
|
||
}))
|
||
defer server.Close()
|
||
|
||
adapter := NewOpenAIAdapter(server.URL, "test-key", []string{"gpt-4"})
|
||
|
||
_, err := adapter.ChatCompletion(context.Background(), "gpt-4", []Message{{Role: "user", Content: "Hi"}}, CompletionOptions{})
|
||
if err == nil {
|
||
t.Fatal("expected error")
|
||
}
|
||
|
||
provErr, ok := err.(ProviderError)
|
||
if !ok {
|
||
t.Fatalf("expected ProviderError, got %T", err)
|
||
}
|
||
if provErr.Code != "PROVIDER_002" {
|
||
t.Errorf("expected PROVIDER_002, got %s", provErr.Code)
|
||
}
|
||
if !provErr.Retryable {
|
||
t.Error("expected Retryable to be true")
|
||
}
|
||
}
|
||
|
||
func TestOpenAIAdapter_ChatCompletion_ContextCanceled(t *testing.T) {
|
||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
// 长时间等待,确保context会取消
|
||
select {}
|
||
}))
|
||
defer server.Close()
|
||
|
||
adapter := NewOpenAIAdapter(server.URL, "test-key", []string{"gpt-4"})
|
||
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
cancel() // 立即取消
|
||
|
||
_, err := adapter.ChatCompletion(ctx, "gpt-4", []Message{{Role: "user", Content: "Hi"}}, CompletionOptions{})
|
||
if err == nil {
|
||
t.Fatal("expected error")
|
||
}
|
||
}
|
||
|
||
func TestContains(t *testing.T) {
|
||
tests := []struct {
|
||
s string
|
||
substr string
|
||
want bool
|
||
}{
|
||
{"hello world", "world", true},
|
||
{"hello world", "hello", true},
|
||
{"hello world", "xyz", false},
|
||
{"", "", true},
|
||
{"a", "abc", false},
|
||
{"abc", "abc", true},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
got := contains(tt.s, tt.substr)
|
||
if got != tt.want {
|
||
t.Errorf("contains(%q, %q) = %v, want %v", tt.s, tt.substr, got, tt.want)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestContainsHelper(t *testing.T) {
|
||
tests := []struct {
|
||
s string
|
||
substr string
|
||
want bool
|
||
}{
|
||
{"hello world", "world", true},
|
||
{"hello world", "lo wo", true},
|
||
{"hello world", "xyz", false},
|
||
{"abc", "abc", true},
|
||
{"abc", "abcd", false},
|
||
{"ab", "abc", false},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
got := containsHelper(tt.s, tt.substr)
|
||
if got != tt.want {
|
||
t.Errorf("containsHelper(%q, %q) = %v, want %v", tt.s, tt.substr, got, tt.want)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestOpenAIAdapter_HealthCheck_Success(t *testing.T) {
|
||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
if r.URL.Path != "/v1/models" {
|
||
t.Errorf("expected /v1/models, got %s", r.URL.Path)
|
||
}
|
||
w.WriteHeader(http.StatusOK)
|
||
}))
|
||
defer server.Close()
|
||
|
||
adapter := NewOpenAIAdapter(server.URL, "test-key", []string{"gpt-4"})
|
||
|
||
healthy := adapter.HealthCheck(context.Background())
|
||
if !healthy {
|
||
t.Error("expected health check to pass")
|
||
}
|
||
}
|
||
|
||
func TestOpenAIAdapter_HealthCheck_Failure(t *testing.T) {
|
||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
w.WriteHeader(http.StatusUnauthorized)
|
||
}))
|
||
defer server.Close()
|
||
|
||
adapter := NewOpenAIAdapter(server.URL, "wrong-key", []string{"gpt-4"})
|
||
|
||
healthy := adapter.HealthCheck(context.Background())
|
||
if healthy {
|
||
t.Error("expected health check to fail")
|
||
}
|
||
}
|
||
|
||
func TestOpenAIAdapter_HealthCheck_ContextTimeout(t *testing.T) {
|
||
// 使用一个会延迟响应的服务器
|
||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
// 延迟关闭连接
|
||
time.Sleep(10 * time.Second)
|
||
}))
|
||
defer server.Close()
|
||
|
||
adapter := NewOpenAIAdapter(server.URL, "test-key", []string{"gpt-4"})
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||
defer cancel()
|
||
|
||
healthy := adapter.HealthCheck(ctx)
|
||
if healthy {
|
||
t.Error("expected health check to fail due to timeout")
|
||
}
|
||
}
|
||
|
||
func TestOpenAIAdapter_ChatCompletionStream_Success(t *testing.T) {
|
||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
// 验证请求头
|
||
if r.Header.Get("Content-Type") != "application/json" {
|
||
t.Error("expected Content-Type application/json")
|
||
}
|
||
if r.Header.Get("Authorization") != "Bearer test-key" {
|
||
t.Error("expected Authorization header")
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "text/event-stream")
|
||
// 发送SSE格式的流式响应
|
||
fmt.Fprintf(w, "data: {\"id\":\"chatcmpl-1\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"delta\":{\"role\":\"assistant\",\"content\":\"Hello\"},\"finish_reason\":\"stop\"}]}\n\n")
|
||
fmt.Fprint(w, "data: [DONE]\n\n")
|
||
}))
|
||
defer server.Close()
|
||
|
||
adapter := NewOpenAIAdapter(server.URL, "test-key", []string{"gpt-4"})
|
||
|
||
messages := []Message{{Role: "user", Content: "Hi"}}
|
||
ch, err := adapter.ChatCompletionStream(context.Background(), "gpt-4", messages, CompletionOptions{})
|
||
if err != nil {
|
||
t.Fatalf("unexpected error: %v", err)
|
||
}
|
||
|
||
chunks := 0
|
||
for chunk := range ch {
|
||
chunks++
|
||
if chunk.ID != "chatcmpl-1" {
|
||
t.Errorf("expected chatcmpl-1, got %s", chunk.ID)
|
||
}
|
||
}
|
||
|
||
if chunks != 1 {
|
||
t.Errorf("expected 1 chunk, got %d", chunks)
|
||
}
|
||
}
|
||
|
||
func TestOpenAIAdapter_ChatCompletionStream_HTTPError(t *testing.T) {
|
||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
w.WriteHeader(http.StatusInternalServerError)
|
||
w.Write([]byte(`{"error": {"message": "server error"}}`))
|
||
}))
|
||
defer server.Close()
|
||
|
||
adapter := NewOpenAIAdapter(server.URL, "test-key", []string{"gpt-4"})
|
||
|
||
_, err := adapter.ChatCompletionStream(context.Background(), "gpt-4", []Message{{Role: "user", Content: "Hi"}}, CompletionOptions{})
|
||
if err == nil {
|
||
t.Fatal("expected error")
|
||
}
|
||
|
||
provErr, ok := err.(ProviderError)
|
||
if !ok {
|
||
t.Fatalf("expected ProviderError, got %T", err)
|
||
}
|
||
// MapError returns 502 for unknown errors
|
||
if provErr.HTTPStatus != 502 {
|
||
t.Errorf("expected 502, got %d", provErr.HTTPStatus)
|
||
}
|
||
}
|
||
|
||
func TestOpenAIAdapter_ChatCompletionStream_ContextCanceled(t *testing.T) {
|
||
// 这个测试验证当context在请求发送前就被取消时会发生错误
|
||
// 由于context已被取消,http.NewRequestWithContext会立即返回错误
|
||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
t.Error("server should not be called when context is already canceled")
|
||
}))
|
||
defer server.Close()
|
||
|
||
adapter := NewOpenAIAdapter(server.URL, "test-key", []string{"gpt-4"})
|
||
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
cancel() // 立即取消
|
||
|
||
ch, err := adapter.ChatCompletionStream(ctx, "gpt-4", []Message{{Role: "user", Content: "Hi"}}, CompletionOptions{})
|
||
// 当context已取消时,http.NewRequestWithContext会返回错误
|
||
if err == nil {
|
||
t.Fatal("expected error for canceled context")
|
||
}
|
||
|
||
// ch可能是nil也可能有值,取决于错误发生的时机
|
||
if ch != nil {
|
||
for range ch {
|
||
// 不应该收到任何数据
|
||
}
|
||
}
|
||
}
|