feat: enhance strict compatibility for OpenAI requests
Some checks failed
CI / Unit tests (push) Has been cancelled
CI / commit_lint (push) Has been cancelled

- Implement sanitization for `tool_choice` and removal of `disable_parallel_tool_use` in request payloads.
- Introduce logging for tool choice changes in `DoRequestHelper`.
- Update `ConvertRequest` to handle tool-call compatibility and maintain structured tool history.
- Add `ThoughtSignature` to `Part` struct for better tracking of reasoning content.
- Refactor request handling in `getRequestBody` to ensure strict compliance with OpenAI API requirements.
This commit is contained in:
hjjjj 2026-03-31 16:37:53 +08:00
parent f67f4b8caf
commit 8b87c3d404
8 changed files with 464 additions and 63 deletions

View File

@ -1,10 +1,13 @@
package adaptor package adaptor
import ( import (
"bytes"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/client" "github.com/songquanpeng/one-api/common/client"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
"io" "io"
"net/http" "net/http"
@ -24,6 +27,29 @@ func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io.
return nil, fmt.Errorf("get request url failed: %w", err) return nil, fmt.Errorf("get request url failed: %w", err)
} }
if requestBody != nil {
raw, readErr := io.ReadAll(requestBody)
if readErr != nil {
return nil, fmt.Errorf("read request body failed: %w", readErr)
}
hasDisableParallelBefore := bytes.Contains(raw, []byte("disable_parallel_tool_use"))
beforeToolChoice := extractToolChoiceForLog(raw)
raw = sanitizeStrictToolChoicePayload(raw)
hasDisableParallelAfter := bytes.Contains(raw, []byte("disable_parallel_tool_use"))
afterToolChoice := extractToolChoiceForLog(raw)
logger.Infof(
c.Request.Context(),
"[DoRequestHelper] outbound %s %s tool_choice(before=%s, after=%s) disable_parallel(before=%t, after=%t)",
c.Request.Method,
c.Request.URL.Path,
beforeToolChoice,
afterToolChoice,
hasDisableParallelBefore,
hasDisableParallelAfter,
)
requestBody = bytes.NewBuffer(raw)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil { if err != nil {
return nil, fmt.Errorf("new request failed: %w", err) return nil, fmt.Errorf("new request failed: %w", err)
@ -51,3 +77,80 @@ func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
_ = c.Request.Body.Close() _ = c.Request.Body.Close()
return resp, nil return resp, nil
} }
// sanitizeStrictToolChoicePayload is the global outbound guard for strict
// OpenAI-compatible gateways: it removes every "disable_parallel_tool_use"
// key at any depth and normalizes object/array-form "tool_choice" to a
// string form such gateways accept ("auto" when tools/functions exist,
// dropped otherwise). Payloads that are not JSON objects are returned
// unchanged, and payloads that need no sanitization are returned verbatim so
// their original key order, whitespace, and number formatting survive.
func sanitizeStrictToolChoicePayload(data []byte) []byte {
	dec := json.NewDecoder(bytes.NewReader(data))
	// Decode numbers as json.Number so re-marshaling cannot mangle large
	// integers through a float64 round-trip.
	dec.UseNumber()
	var payload any
	if err := dec.Decode(&payload); err != nil {
		return data
	}
	root, ok := payload.(map[string]any)
	if !ok {
		return data
	}
	// Conservative change detection: a raw scan may also match the key
	// inside a string value, which only causes a harmless re-marshal.
	modified := bytes.Contains(data, []byte(`"disable_parallel_tool_use"`))
	stripDisableParallelKeyAny(root)
	hasTools := false
	if tools, ok := root["tools"].([]any); ok && len(tools) > 0 {
		hasTools = true
	}
	if functions, ok := root["functions"]; ok && functions != nil {
		hasTools = true
	}
	if tc, ok := root["tool_choice"]; ok {
		switch tc.(type) {
		case map[string]any, []any:
			// Strict gateways reject the object/array forms; keep "auto"
			// only when the request actually declares tools or functions.
			if hasTools {
				root["tool_choice"] = "auto"
			} else {
				delete(root, "tool_choice")
			}
			modified = true
		}
	}
	if !modified {
		return data
	}
	out, err := json.Marshal(root)
	if err != nil {
		return data
	}
	return out
}

// stripDisableParallelKeyAny walks an unmarshaled JSON value and deletes
// every "disable_parallel_tool_use" key, recursing into objects and arrays.
// The (possibly mutated) value is returned for convenient chaining.
func stripDisableParallelKeyAny(v any) any {
	switch val := v.(type) {
	case map[string]any:
		delete(val, "disable_parallel_tool_use")
		for k, child := range val {
			val[k] = stripDisableParallelKeyAny(child)
		}
		return val
	case []any:
		for i, child := range val {
			val[i] = stripDisableParallelKeyAny(child)
		}
		return val
	default:
		return v
	}
}
// extractToolChoiceForLog renders a request payload's "tool_choice" field as
// a short string for logging. Sentinel values mark payloads that are not
// valid JSON objects, lack the field, or cannot be re-marshaled.
func extractToolChoiceForLog(data []byte) string {
	var body map[string]any
	if json.Unmarshal(data, &body) != nil {
		return "<invalid-json>"
	}
	choice, present := body["tool_choice"]
	if !present {
		return "<absent>"
	}
	if s, isString := choice.(string); isString {
		return s
	}
	encoded, err := json.Marshal(choice)
	if err != nil {
		return "<unmarshalable>"
	}
	return string(encoded)
}

View File

@ -59,6 +59,7 @@ func sanitizeSchema(v interface{}) interface{} {
// Setting safety to the lowest possible values since Gemini is already powerless enough // Setting safety to the lowest possible values since Gemini is already powerless enough
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
messagesForConvert := textRequest.Messages
geminiRequest := ChatRequest{ geminiRequest := ChatRequest{
Contents: make([]ChatContent, 0, len(textRequest.Messages)), Contents: make([]ChatContent, 0, len(textRequest.Messages)),
SafetySettings: []ChatSafetySettings{ SafetySettings: []ChatSafetySettings{
@ -104,11 +105,6 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
geminiRequest.GenerationConfig.ResponseModalities = []string{"TEXT", "IMAGE"} geminiRequest.GenerationConfig.ResponseModalities = []string{"TEXT", "IMAGE"}
} }
// Enable thinking when the client explicitly requests it via enable_thinking=true.
// Use thinkingBudget=-1 (dynamic) so Gemini decides the appropriate budget.
if textRequest.EnableThinking {
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ThinkingBudget: -1}
}
if textRequest.Tools != nil { if textRequest.Tools != nil {
functions := make([]model.Function, 0, len(textRequest.Tools)) functions := make([]model.Function, 0, len(textRequest.Tools))
for _, tool := range textRequest.Tools { for _, tool := range textRequest.Tools {
@ -130,42 +126,87 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
}, },
} }
} }
// Tool-call compatibility mode:
// Keep tool rounds on strict structured functionCall/functionResponse flow.
// Gemini thinking tool rounds behind gateways are prone to thought_signature
// validation failures; disable thinking config for tool flow to avoid fallback
// text replay and preserve structured tool history.
hasToolFlow := len(geminiRequest.Tools) > 0
if !hasToolFlow {
for _, msg := range messagesForConvert {
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
hasToolFlow = true
break
}
}
}
if hasToolFlow {
// Remove incompatible historical tool-call turns that cannot satisfy
// Gemini thought_signature requirements. Keep structured tool flow for
// the remaining valid turns (no text downgrade replay).
messagesForConvert = sanitizeGeminiToolHistoryMessages(messagesForConvert)
}
// Enable thinking only for non-tool turns.
if textRequest.EnableThinking && !hasToolFlow {
// Use thinkingBudget=-1 (dynamic) so Gemini decides the appropriate budget.
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ThinkingBudget: -1}
} else if hasToolFlow {
geminiRequest.GenerationConfig.ThinkingConfig = nil
}
// Build a map from tool_call_id → function name for resolving tool result names // Build a map from tool_call_id → function name for resolving tool result names
toolCallIdToName := map[string]string{} toolCallIdToName := map[string]string{}
for _, message := range textRequest.Messages { lastKnownThoughtSignature := ""
for _, message := range messagesForConvert {
if message.ReasoningEncryptedContent != "" {
lastKnownThoughtSignature = message.ReasoningEncryptedContent
}
if message.Role == "assistant" { if message.Role == "assistant" {
for _, tc := range message.ToolCalls { for _, tc := range message.ToolCalls {
if tc.Id != "" && tc.Function.Name != "" { if tc.Id != "" && tc.Function.Name != "" {
toolCallIdToName[tc.Id] = tc.Function.Name toolCallIdToName[tc.Id] = tc.Function.Name
} }
if tc.ExtraContent != nil && tc.ExtraContent.Google != nil && tc.ExtraContent.Google.ThoughtSignature != "" {
lastKnownThoughtSignature = tc.ExtraContent.Google.ThoughtSignature
}
} }
} }
} }
shouldAddDummyModelMessage := false shouldAddDummyModelMessage := false
for _, message := range textRequest.Messages { for i := 0; i < len(messagesForConvert); i++ {
message := messagesForConvert[i]
// --- tool result: role=tool → Gemini functionResponse (user role) --- // --- tool result: role=tool → Gemini functionResponse (user role) ---
// Gemini requires functionResponse parts to match functionCall parts from the
// preceding model turn. For OpenAI-compat payloads with multiple consecutive
// role=tool messages, merge them into a single user turn with multiple parts.
if message.Role == "tool" { if message.Role == "tool" {
toolName := message.ToolCallId var responseParts []Part
if name, ok := toolCallIdToName[message.ToolCallId]; ok { for ; i < len(messagesForConvert) && messagesForConvert[i].Role == "tool"; i++ {
toolMessage := messagesForConvert[i]
toolName := toolMessage.ToolCallId
if name, ok := toolCallIdToName[toolMessage.ToolCallId]; ok {
toolName = name toolName = name
} else if message.Name != nil && *message.Name != "" { } else if toolMessage.Name != nil && *toolMessage.Name != "" {
toolName = *message.Name toolName = *toolMessage.Name
} }
if toolName == "" { if toolName == "" {
toolName = "unknown_tool" toolName = "unknown_tool"
} }
geminiRequest.Contents = append(geminiRequest.Contents, ChatContent{ responseParts = append(responseParts, Part{
Role: "user",
Parts: []Part{
{
FunctionResponse: &FunctionResponse{ FunctionResponse: &FunctionResponse{
Name: toolName, Name: toolName,
Response: map[string]any{"content": message.StringContent()}, Response: map[string]any{"content": toolMessage.StringContent()},
},
},
}, },
}) })
}
if len(responseParts) > 0 {
geminiRequest.Contents = append(geminiRequest.Contents, ChatContent{
Role: "user",
Parts: responseParts,
})
}
i--
continue continue
} }
@ -248,11 +289,25 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
} else { } else {
args = map[string]any{} args = map[string]any{}
} }
thoughtSignature := ""
if tc.ExtraContent != nil && tc.ExtraContent.Google != nil {
thoughtSignature = tc.ExtraContent.Google.ThoughtSignature
}
if thoughtSignature == "" {
thoughtSignature = message.ReasoningEncryptedContent
}
if thoughtSignature == "" {
thoughtSignature = lastKnownThoughtSignature
}
if thoughtSignature != "" {
lastKnownThoughtSignature = thoughtSignature
}
fcParts = append(fcParts, Part{ fcParts = append(fcParts, Part{
FunctionCall: &FunctionCall{ FunctionCall: &FunctionCall{
FunctionName: tc.Function.Name, FunctionName: tc.Function.Name,
Arguments: args, Arguments: args,
}, },
ThoughtSignature: thoughtSignature,
}) })
} }
content.Role = "model" content.Role = "model"
@ -298,6 +353,76 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
return &geminiRequest return &geminiRequest
} }
// sanitizeGeminiToolHistoryMessages filters historical tool-call turns so the
// remaining history can satisfy Gemini's thought_signature validation.
// Assistant tool calls that cannot be paired with any thought signature are
// dropped, together with their matching role=tool result messages; all other
// messages pass through unchanged.
func sanitizeGeminiToolHistoryMessages(messages []model.Message) []model.Message {
	if len(messages) == 0 {
		return messages
	}
	sanitized := make([]model.Message, 0, len(messages))
	// IDs of dropped tool calls, used below to also drop their role=tool results.
	droppedToolCallIDs := make(map[string]struct{})
	// Most recent signature seen so far; later unsigned calls inherit it.
	lastKnownThoughtSignature := ""
	for _, message := range messages {
		if message.ReasoningEncryptedContent != "" {
			lastKnownThoughtSignature = message.ReasoningEncryptedContent
		}
		if message.Role == "assistant" && len(message.ToolCalls) > 0 {
			keptToolCalls := make([]model.Tool, 0, len(message.ToolCalls))
			for _, tc := range message.ToolCalls {
				// Resolve a signature for this call: per-call extra content
				// first, then the message-level field, then the carry-over.
				thoughtSignature := ""
				if tc.ExtraContent != nil && tc.ExtraContent.Google != nil {
					thoughtSignature = strings.TrimSpace(tc.ExtraContent.Google.ThoughtSignature)
				}
				if thoughtSignature == "" {
					thoughtSignature = strings.TrimSpace(message.ReasoningEncryptedContent)
				}
				if thoughtSignature == "" {
					thoughtSignature = strings.TrimSpace(lastKnownThoughtSignature)
				}
				if thoughtSignature == "" {
					// No usable signature: drop the call and remember its ID
					// so the matching tool result is dropped as well.
					if tc.Id != "" {
						droppedToolCallIDs[tc.Id] = struct{}{}
					}
					continue
				}
				// Persist the resolved signature back onto the tool call so
				// downstream conversion can emit it.
				if tc.ExtraContent == nil {
					tc.ExtraContent = &model.ToolExtraContent{}
				}
				if tc.ExtraContent.Google == nil {
					tc.ExtraContent.Google = &model.GoogleToolExtraContent{}
				}
				tc.ExtraContent.Google.ThoughtSignature = thoughtSignature
				lastKnownThoughtSignature = thoughtSignature
				keptToolCalls = append(keptToolCalls, tc)
			}
			if len(keptToolCalls) == 0 {
				// Every call was dropped: keep the turn only if it carries
				// plain text; otherwise remove it entirely.
				assistantText := strings.TrimSpace(message.StringContent())
				if assistantText == "" {
					continue
				}
				message.ToolCalls = nil
				message.Content = assistantText
				sanitized = append(sanitized, message)
				continue
			}
			message.ToolCalls = keptToolCalls
			sanitized = append(sanitized, message)
			continue
		}
		if message.Role == "tool" && message.ToolCallId != "" {
			// Skip results whose originating call was dropped above.
			if _, dropped := droppedToolCallIDs[message.ToolCallId]; dropped {
				continue
			}
		}
		sanitized = append(sanitized, message)
	}
	return sanitized
}
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *BatchEmbeddingRequest { func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *BatchEmbeddingRequest {
inputs := request.ParseInput() inputs := request.ParseInput()
requests := make([]EmbeddingRequest, len(inputs)) requests := make([]EmbeddingRequest, len(inputs))
@ -369,15 +494,14 @@ type ChatPromptFeedback struct {
func getToolCalls(candidate *ChatCandidate) []model.Tool { func getToolCalls(candidate *ChatCandidate) []model.Tool {
var toolCalls []model.Tool var toolCalls []model.Tool
for _, item := range candidate.Content.Parts {
item := candidate.Content.Parts[0]
if item.FunctionCall == nil { if item.FunctionCall == nil {
return toolCalls continue
} }
argsBytes, err := json.Marshal(item.FunctionCall.Arguments) argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
if err != nil { if err != nil {
logger.FatalLog("getToolCalls failed: " + err.Error()) logger.FatalLog("getToolCalls failed: " + err.Error())
return toolCalls continue
} }
toolCall := model.Tool{ toolCall := model.Tool{
Id: fmt.Sprintf("call_%s", random.GetUUID()), Id: fmt.Sprintf("call_%s", random.GetUUID()),
@ -387,7 +511,15 @@ func getToolCalls(candidate *ChatCandidate) []model.Tool {
Name: item.FunctionCall.FunctionName, Name: item.FunctionCall.FunctionName,
}, },
} }
if item.ThoughtSignature != "" {
toolCall.ExtraContent = &model.ToolExtraContent{
Google: &model.GoogleToolExtraContent{
ThoughtSignature: item.ThoughtSignature,
},
}
}
toolCalls = append(toolCalls, toolCall) toolCalls = append(toolCalls, toolCall)
}
return toolCalls return toolCalls
} }
@ -407,12 +539,13 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
FinishReason: constant.StopFinishReason, FinishReason: constant.StopFinishReason,
} }
if len(candidate.Content.Parts) > 0 { if len(candidate.Content.Parts) > 0 {
if candidate.Content.Parts[0].FunctionCall != nil { toolCalls := getToolCalls(&candidate)
choice.Message.ToolCalls = getToolCalls(&candidate) if len(toolCalls) > 0 {
choice.Message.ToolCalls = toolCalls
} else { } else {
var builder strings.Builder var builder strings.Builder
for _, part := range candidate.Content.Parts { for j, part := range candidate.Content.Parts {
if i > 0 { if j > 0 {
builder.WriteString("\n") builder.WriteString("\n")
} }
builder.WriteString(part.Text) builder.WriteString(part.Text)
@ -434,10 +567,35 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatC
if len(geminiResponse.Candidates) > 0 { if len(geminiResponse.Candidates) > 0 {
var textBuilder strings.Builder var textBuilder strings.Builder
var thinkingBuilder strings.Builder var thinkingBuilder strings.Builder
var toolCalls []model.Tool
for _, part := range geminiResponse.Candidates[0].Content.Parts { for _, part := range geminiResponse.Candidates[0].Content.Parts {
if part.Thought { if part.Thought {
// Thinking/reasoning content — route to reasoning_content field // Thinking/reasoning content — route to reasoning_content field
thinkingBuilder.WriteString(part.Text) thinkingBuilder.WriteString(part.Text)
} else if part.FunctionCall != nil {
argsBytes, err := json.Marshal(part.FunctionCall.Arguments)
if err != nil {
logger.FatalLog("streamResponseGeminiChat2OpenAI marshal args failed: " + err.Error())
argsBytes = []byte("{}")
}
toolCalls = append(toolCalls, model.Tool{
Id: fmt.Sprintf("call_%s", random.GetUUID()),
Type: "function",
Function: model.Function{
Name: part.FunctionCall.FunctionName,
Arguments: string(argsBytes),
},
ExtraContent: func() *model.ToolExtraContent {
if part.ThoughtSignature == "" {
return nil
}
return &model.ToolExtraContent{
Google: &model.GoogleToolExtraContent{
ThoughtSignature: part.ThoughtSignature,
},
}
}(),
})
} else if part.Text != "" { } else if part.Text != "" {
textBuilder.WriteString(part.Text) textBuilder.WriteString(part.Text)
} else if part.InlineData != nil && part.InlineData.Data != "" { } else if part.InlineData != nil && part.InlineData.Data != "" {
@ -455,6 +613,11 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatC
if thinkingBuilder.Len() > 0 { if thinkingBuilder.Len() > 0 {
choice.Delta.ReasoningContent = thinkingBuilder.String() choice.Delta.ReasoningContent = thinkingBuilder.String()
} }
if len(toolCalls) > 0 {
choice.Delta.ToolCalls = toolCalls
toolFinish := "tool_calls"
choice.FinishReason = &toolFinish
}
} }
var response openai.ChatCompletionsStreamResponse var response openai.ChatCompletionsStreamResponse

View File

@ -55,6 +55,7 @@ type Part struct {
InlineData *InlineData `json:"inlineData,omitempty"` InlineData *InlineData `json:"inlineData,omitempty"`
FunctionCall *FunctionCall `json:"functionCall,omitempty"` FunctionCall *FunctionCall `json:"functionCall,omitempty"`
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"` FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
ThoughtSignature string `json:"thoughtSignature,omitempty"`
// Thought marks this part as internal reasoning/thinking content (Gemini thinking models) // Thought marks this part as internal reasoning/thinking content (Gemini thinking models)
Thought bool `json:"thought,omitempty"` Thought bool `json:"thought,omitempty"`
} }

View File

@ -32,6 +32,8 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener
if err != nil { if err != nil {
return nil, err return nil, err
} }
sanitized := sanitizeToolChoiceForStrictCompat(textRequest)
c.Set("sanitized_text_request_modified", sanitized)
if relayMode == relaymode.Moderations && textRequest.Model == "" { if relayMode == relaymode.Moderations && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest" textRequest.Model = "text-moderation-latest"
} }
@ -45,6 +47,62 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener
return textRequest, nil return textRequest, nil
} }
// sanitizeToolChoiceForStrictCompat normalizes tool_choice for strict
// OpenAI-compatible gateways, which often reject the object/array forms.
// It also strips any nested "disable_parallel_tool_use" keys. It reports
// whether the request was modified; a nil request is left untouched.
func sanitizeToolChoiceForStrictCompat(req *relaymodel.GeneralOpenAIRequest) bool {
	if req == nil {
		return false
	}
	modified := stripDisableParallelToolUse(&req.ToolChoice)
	switch req.ToolChoice.(type) {
	case map[string]any, []any:
		// The dynamic type here is an object or array, so it can never
		// already equal "auto" or nil; either coercion is a modification.
		if len(req.Tools) > 0 || req.Functions != nil {
			req.ToolChoice = "auto"
		} else {
			req.ToolChoice = nil
		}
		modified = true
	}
	return modified
}
// stripDisableParallelToolUse recursively removes every
// "disable_parallel_tool_use" key from the JSON-shaped value pointed to by v
// (maps and slices as produced by encoding/json), reporting whether anything
// was removed. A nil pointer or nil value is a no-op.
func stripDisableParallelToolUse(v *any) bool {
	if v == nil {
		return false
	}
	const key = "disable_parallel_tool_use"
	var changed bool
	switch node := (*v).(type) {
	case map[string]any:
		before := len(node)
		delete(node, key)
		changed = len(node) != before
		for k, child := range node {
			childVal := child
			changed = stripDisableParallelToolUse(&childVal) || changed
			node[k] = childVal
		}
	case []any:
		for i := range node {
			changed = stripDisableParallelToolUse(&node[i]) || changed
		}
	}
	return changed
}
func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
switch relayMode { switch relayMode {
case relaymode.ChatCompletions: case relaymode.ChatCompletions:

View File

@ -2,7 +2,9 @@
package controller package controller
import ( import (
"bytes"
"fmt" "fmt"
"io"
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -24,7 +26,14 @@ func RelayProxyHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
} }
adaptor.Init(meta) adaptor.Init(meta)
resp, err := adaptor.DoRequest(c, meta, c.Request.Body) requestBody, err := io.ReadAll(c.Request.Body)
if err != nil {
logger.Errorf(ctx, "read proxy request body failed: %s", err.Error())
return openai.ErrorWrapper(err, "read_request_body_failed", http.StatusBadRequest)
}
requestBody = sanitizeStrictToolChoiceJSON(requestBody)
resp, err := adaptor.DoRequest(c, meta, bytes.NewBuffer(requestBody))
if err != nil { if err != nil {
logger.Errorf(ctx, "DoRequest failed: %s", err.Error()) logger.Errorf(ctx, "DoRequest failed: %s", err.Error())
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)

View File

@ -9,15 +9,12 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay" "github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/apitype"
"github.com/songquanpeng/one-api/relay/billing" "github.com/songquanpeng/one-api/relay/billing"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
) )
@ -98,14 +95,9 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
} }
func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) { func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) {
if !config.EnforceIncludeUsage && // Always convert request through adaptor before sending upstream.
meta.APIType == apitype.OpenAI && // This avoids passthrough edge cases where strict-compat sanitization may be bypassed.
meta.OriginModelName == meta.ActualModelName && // The slight performance cost is acceptable compared to malformed request risk.
meta.ChannelType != channeltype.Baichuan &&
meta.ForcedSystemPrompt == "" {
// no need to convert request for openai
return c.Request.Body, nil
}
// get request body // get request body
var requestBody io.Reader var requestBody io.Reader
@ -119,7 +111,72 @@ func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralO
logger.Debugf(c.Request.Context(), "converted request json_marshal_failed: %s\n", err.Error()) logger.Debugf(c.Request.Context(), "converted request json_marshal_failed: %s\n", err.Error())
return nil, err return nil, err
} }
jsonData = sanitizeStrictToolChoiceJSON(jsonData)
logger.Debugf(c.Request.Context(), "converted request: \n%s", string(jsonData)) logger.Debugf(c.Request.Context(), "converted request: \n%s", string(jsonData))
requestBody = bytes.NewBuffer(jsonData) requestBody = bytes.NewBuffer(jsonData)
return requestBody, nil return requestBody, nil
} }
// sanitizeStrictToolChoiceJSON is the final outbound guard for strict
// OpenAI-compatible gateways:
//   - remove any "disable_parallel_tool_use" key at any depth
//   - coerce object/array tool_choice to an OpenAI-compatible string form
//
// Payloads that are not JSON objects, or that need no changes, are returned
// untouched so their original key order and number formatting survive.
func sanitizeStrictToolChoiceJSON(data []byte) []byte {
	dec := json.NewDecoder(bytes.NewReader(data))
	// Decode numbers as json.Number so re-marshaling cannot mangle large
	// integers through a float64 round-trip.
	dec.UseNumber()
	var payload any
	if err := dec.Decode(&payload); err != nil {
		return data
	}
	root, ok := payload.(map[string]any)
	if !ok {
		return data
	}
	// Conservative change detection: a raw scan may also match the key
	// inside a string value, which only causes a harmless re-marshal.
	modified := bytes.Contains(data, []byte(`"disable_parallel_tool_use"`))
	stripDisableParallelToolUseAny(root)
	hasTools := false
	if tools, ok := root["tools"].([]any); ok && len(tools) > 0 {
		hasTools = true
	}
	if functions, ok := root["functions"]; ok && functions != nil {
		hasTools = true
	}
	if tc, ok := root["tool_choice"]; ok {
		switch tc.(type) {
		case map[string]any, []any:
			// Strict gateways reject the object/array forms; keep "auto"
			// only when the request actually declares tools or functions.
			if hasTools {
				root["tool_choice"] = "auto"
			} else {
				delete(root, "tool_choice")
			}
			modified = true
		}
	}
	if !modified {
		return data
	}
	out, err := json.Marshal(root)
	if err != nil {
		return data
	}
	return out
}

// stripDisableParallelToolUseAny walks an unmarshaled JSON value and removes
// every "disable_parallel_tool_use" key, recursing into objects and arrays.
// The (possibly mutated) value is returned for convenient chaining.
func stripDisableParallelToolUseAny(v any) any {
	switch val := v.(type) {
	case map[string]any:
		delete(val, "disable_parallel_tool_use")
		for k, child := range val {
			val[k] = stripDisableParallelToolUseAny(child)
		}
		return val
	case []any:
		for i, child := range val {
			val[i] = stripDisableParallelToolUseAny(child)
		}
		return val
	default:
		return v
	}
}

View File

@ -4,6 +4,7 @@ type Message struct {
Role string `json:"role,omitempty"` Role string `json:"role,omitempty"`
Content any `json:"content,omitempty"` Content any `json:"content,omitempty"`
ReasoningContent any `json:"reasoning_content,omitempty"` ReasoningContent any `json:"reasoning_content,omitempty"`
ReasoningEncryptedContent string `json:"reasoning_encrypted_content,omitempty"`
Name *string `json:"name,omitempty"` Name *string `json:"name,omitempty"`
ToolCalls []Tool `json:"tool_calls,omitempty"` ToolCalls []Tool `json:"tool_calls,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"` ToolCallId string `json:"tool_call_id,omitempty"`

View File

@ -4,6 +4,7 @@ type Tool struct {
Id string `json:"id,omitempty"` Id string `json:"id,omitempty"`
Type string `json:"type,omitempty"` // when splicing claude tools stream messages, it is empty Type string `json:"type,omitempty"` // when splicing claude tools stream messages, it is empty
Function Function `json:"function"` Function Function `json:"function"`
ExtraContent *ToolExtraContent `json:"extra_content,omitempty"`
} }
type Function struct { type Function struct {
@ -12,3 +13,11 @@ type Function struct {
Parameters any `json:"parameters,omitempty"` // request Parameters any `json:"parameters,omitempty"` // request
Arguments any `json:"arguments,omitempty"` // response Arguments any `json:"arguments,omitempty"` // response
} }
// ToolExtraContent carries provider-specific extension data attached to a
// tool call that has no place in the standard OpenAI tool schema.
type ToolExtraContent struct {
	Google *GoogleToolExtraContent `json:"google,omitempty"`
}

// GoogleToolExtraContent holds Gemini-specific tool-call metadata.
type GoogleToolExtraContent struct {
	// ThoughtSignature is the opaque signature Gemini attaches to tool calls
	// produced by thinking models; it is carried through tool history so
	// replayed turns can pass Gemini's thought_signature validation.
	ThoughtSignature string `json:"thought_signature,omitempty"`
}