198 lines
5.6 KiB
Go
198 lines
5.6 KiB
Go
|
|
package handler
|
||
|
|
|
||
|
|
import (
|
||
|
|
"crypto/rand"
|
||
|
|
"encoding/json"
|
||
|
|
"fmt"
|
||
|
|
"net/http"
|
||
|
|
"strconv"
|
||
|
|
"strings"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"github.com/gin-gonic/gin"
|
||
|
|
)
|
||
|
|
|
||
|
|
// InterceptType names the kind of request that can be answered with a canned
// response instead of being forwarded; see detectInterceptType.
type InterceptType int

// Recognized intercept kinds. The zero value, InterceptTypeNone, means no
// intercept was detected.
const (
	// InterceptTypeNone: the request matched none of the intercept patterns.
	InterceptTypeNone InterceptType = iota
	// InterceptTypeWarmup: a "Warmup" message or a conversation-title prompt.
	InterceptTypeWarmup
	// InterceptTypeSuggestionMode: the last user message starts with "[SUGGESTION MODE:".
	InterceptTypeSuggestionMode
	// InterceptTypeMaxTokensOneHaiku: a non-streaming max_tokens=1 request to a Haiku model.
	InterceptTypeMaxTokensOneHaiku
)
|
||
|
|
|
||
|
|
// isHaikuModel reports whether the model name contains "haiku", compared
// case-insensitively, anywhere in the string.
func isHaikuModel(model string) bool {
	const family = "haiku"
	normalized := strings.ToLower(model)
	return strings.Contains(normalized, family)
}
|
||
|
|
|
||
|
|
func isMaxTokensOneHaikuRequest(model string, maxTokens int, isStream bool) bool {
|
||
|
|
return maxTokens == 1 && isHaikuModel(model) && !isStream
|
||
|
|
}
|
||
|
|
|
||
|
|
func detectInterceptType(body []byte, model string, maxTokens int, isStream bool, isClaudeCodeClient bool) InterceptType {
|
||
|
|
if isClaudeCodeClient && isMaxTokensOneHaikuRequest(model, maxTokens, isStream) {
|
||
|
|
return InterceptTypeMaxTokensOneHaiku
|
||
|
|
}
|
||
|
|
|
||
|
|
bodyStr := string(body)
|
||
|
|
hasSuggestionMode := strings.Contains(bodyStr, "[SUGGESTION MODE:")
|
||
|
|
hasWarmupKeyword := strings.Contains(bodyStr, "title") || strings.Contains(bodyStr, "Warmup")
|
||
|
|
|
||
|
|
if !hasSuggestionMode && !hasWarmupKeyword {
|
||
|
|
return InterceptTypeNone
|
||
|
|
}
|
||
|
|
|
||
|
|
var req struct {
|
||
|
|
Messages []struct {
|
||
|
|
Role string `json:"role"`
|
||
|
|
Content []struct {
|
||
|
|
Type string `json:"type"`
|
||
|
|
Text string `json:"text"`
|
||
|
|
} `json:"content"`
|
||
|
|
} `json:"messages"`
|
||
|
|
System []struct {
|
||
|
|
Text string `json:"text"`
|
||
|
|
} `json:"system"`
|
||
|
|
}
|
||
|
|
if err := json.Unmarshal(body, &req); err != nil {
|
||
|
|
return InterceptTypeNone
|
||
|
|
}
|
||
|
|
|
||
|
|
if hasSuggestionMode && len(req.Messages) > 0 {
|
||
|
|
lastMsg := req.Messages[len(req.Messages)-1]
|
||
|
|
if lastMsg.Role == "user" && len(lastMsg.Content) > 0 &&
|
||
|
|
lastMsg.Content[0].Type == "text" &&
|
||
|
|
strings.HasPrefix(lastMsg.Content[0].Text, "[SUGGESTION MODE:") {
|
||
|
|
return InterceptTypeSuggestionMode
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
if hasWarmupKeyword {
|
||
|
|
for _, msg := range req.Messages {
|
||
|
|
for _, content := range msg.Content {
|
||
|
|
if content.Type == "text" {
|
||
|
|
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
|
||
|
|
content.Text == "Warmup" {
|
||
|
|
return InterceptTypeWarmup
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
for _, sys := range req.System {
|
||
|
|
if strings.Contains(sys.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
|
||
|
|
return InterceptTypeWarmup
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
return InterceptTypeNone
|
||
|
|
}
|
||
|
|
|
||
|
|
func sendMockInterceptStream(c *gin.Context, model string, interceptType InterceptType) {
|
||
|
|
c.Header("Content-Type", "text/event-stream")
|
||
|
|
c.Header("Cache-Control", "no-cache")
|
||
|
|
c.Header("Connection", "keep-alive")
|
||
|
|
c.Header("X-Accel-Buffering", "no")
|
||
|
|
|
||
|
|
var msgID string
|
||
|
|
var outputTokens int
|
||
|
|
var textDeltas []string
|
||
|
|
|
||
|
|
switch interceptType {
|
||
|
|
case InterceptTypeSuggestionMode:
|
||
|
|
msgID = "msg_mock_suggestion"
|
||
|
|
outputTokens = 1
|
||
|
|
textDeltas = []string{""}
|
||
|
|
default:
|
||
|
|
msgID = "msg_mock_warmup"
|
||
|
|
outputTokens = 2
|
||
|
|
textDeltas = []string{"New", " Conversation"}
|
||
|
|
}
|
||
|
|
|
||
|
|
messageStartJSON := `{"type":"message_start","message":{"id":` + strconv.Quote(msgID) + `,"type":"message","role":"assistant","model":` + strconv.Quote(model) + `,"content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":0}}}`
|
||
|
|
|
||
|
|
events := []string{
|
||
|
|
`event: message_start` + "\n" + `data: ` + string(messageStartJSON),
|
||
|
|
`event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`,
|
||
|
|
}
|
||
|
|
|
||
|
|
for _, text := range textDeltas {
|
||
|
|
deltaJSON := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":` + strconv.Quote(text) + `}}`
|
||
|
|
events = append(events, `event: content_block_delta`+"\n"+`data: `+string(deltaJSON))
|
||
|
|
}
|
||
|
|
|
||
|
|
messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":10,"output_tokens":` + strconv.Itoa(outputTokens) + `}}`
|
||
|
|
|
||
|
|
events = append(events,
|
||
|
|
`event: content_block_stop`+"\n"+`data: {"index":0,"type":"content_block_stop"}`,
|
||
|
|
`event: message_delta`+"\n"+`data: `+string(messageDeltaJSON),
|
||
|
|
`event: message_stop`+"\n"+`data: {"type":"message_stop"}`,
|
||
|
|
)
|
||
|
|
|
||
|
|
for _, event := range events {
|
||
|
|
_, _ = c.Writer.WriteString(event + "\n\n")
|
||
|
|
c.Writer.Flush()
|
||
|
|
time.Sleep(20 * time.Millisecond)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// generateRealisticMsgID returns a message ID of the form "msg_bdrk_" followed
// by 24 characters drawn uniformly from [0-9A-Za-z] using crypto/rand.
//
// If the random source fails, it falls back to "msg_bdrk_<unix-nanos>".
func generateRealisticMsgID() string {
	const (
		charset = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
		idLen   = 24
		// limit is the largest multiple of len(charset) that fits in a byte
		// (248). Bytes at or above it are discarded so every character is
		// equally likely; a plain byte%62 would favour the first 256%62 = 8
		// characters of charset.
		limit = 256 - 256%len(charset)
	)

	id := make([]byte, 0, idLen)
	buf := make([]byte, idLen)
	for len(id) < idLen {
		if _, err := rand.Read(buf); err != nil {
			return fmt.Sprintf("msg_bdrk_%d", time.Now().UnixNano())
		}
		for _, v := range buf {
			if int(v) >= limit {
				continue // rejection sampling: drop biased values
			}
			id = append(id, charset[int(v)%len(charset)])
			if len(id) == idLen {
				break
			}
		}
	}
	return "msg_bdrk_" + string(id)
}
|
||
|
|
|
||
|
|
func sendMockInterceptResponse(c *gin.Context, model string, interceptType InterceptType) {
|
||
|
|
var msgID, text, stopReason string
|
||
|
|
var outputTokens int
|
||
|
|
|
||
|
|
switch interceptType {
|
||
|
|
case InterceptTypeSuggestionMode:
|
||
|
|
msgID = "msg_mock_suggestion"
|
||
|
|
text = ""
|
||
|
|
outputTokens = 1
|
||
|
|
stopReason = "end_turn"
|
||
|
|
case InterceptTypeMaxTokensOneHaiku:
|
||
|
|
msgID = generateRealisticMsgID()
|
||
|
|
text = "#"
|
||
|
|
outputTokens = 1
|
||
|
|
stopReason = "max_tokens"
|
||
|
|
default:
|
||
|
|
msgID = "msg_mock_warmup"
|
||
|
|
text = "New Conversation"
|
||
|
|
outputTokens = 2
|
||
|
|
stopReason = "end_turn"
|
||
|
|
}
|
||
|
|
|
||
|
|
response := gin.H{
|
||
|
|
"model": model,
|
||
|
|
"id": msgID,
|
||
|
|
"type": "message",
|
||
|
|
"role": "assistant",
|
||
|
|
"content": []gin.H{{"type": "text", "text": text}},
|
||
|
|
"stop_reason": stopReason,
|
||
|
|
"stop_sequence": nil,
|
||
|
|
"usage": gin.H{
|
||
|
|
"input_tokens": 10,
|
||
|
|
"cache_creation_input_tokens": 0,
|
||
|
|
"cache_read_input_tokens": 0,
|
||
|
|
"cache_creation": gin.H{
|
||
|
|
"ephemeral_5m_input_tokens": 0,
|
||
|
|
"ephemeral_1h_input_tokens": 0,
|
||
|
|
},
|
||
|
|
"output_tokens": outputTokens,
|
||
|
|
"total_tokens": 10 + outputTokens,
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
c.JSON(http.StatusOK, response)
|
||
|
|
}
|