feat: enhance strict compatibility for OpenAI requests
- Implement sanitization for `tool_choice` and removal of `disable_parallel_tool_use` in request payloads. - Introduce logging for tool choice changes in `DoRequestHelper`. - Update `ConvertRequest` to handle tool-call compatibility and maintain structured tool history. - Add `ThoughtSignature` to `Part` struct for better tracking of reasoning content. - Refactor request handling in `getRequestBody` to ensure strict compliance with OpenAI API requirements.
This commit is contained in:
parent
f67f4b8caf
commit
8b87c3d404
@ -1,10 +1,13 @@
|
||||
package adaptor
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/client"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -24,6 +27,29 @@ func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io.
|
||||
return nil, fmt.Errorf("get request url failed: %w", err)
|
||||
}
|
||||
|
||||
if requestBody != nil {
|
||||
raw, readErr := io.ReadAll(requestBody)
|
||||
if readErr != nil {
|
||||
return nil, fmt.Errorf("read request body failed: %w", readErr)
|
||||
}
|
||||
hasDisableParallelBefore := bytes.Contains(raw, []byte("disable_parallel_tool_use"))
|
||||
beforeToolChoice := extractToolChoiceForLog(raw)
|
||||
raw = sanitizeStrictToolChoicePayload(raw)
|
||||
hasDisableParallelAfter := bytes.Contains(raw, []byte("disable_parallel_tool_use"))
|
||||
afterToolChoice := extractToolChoiceForLog(raw)
|
||||
logger.Infof(
|
||||
c.Request.Context(),
|
||||
"[DoRequestHelper] outbound %s %s tool_choice(before=%s, after=%s) disable_parallel(before=%t, after=%t)",
|
||||
c.Request.Method,
|
||||
c.Request.URL.Path,
|
||||
beforeToolChoice,
|
||||
afterToolChoice,
|
||||
hasDisableParallelBefore,
|
||||
hasDisableParallelAfter,
|
||||
)
|
||||
requestBody = bytes.NewBuffer(raw)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new request failed: %w", err)
|
||||
@ -51,3 +77,80 @@ func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
|
||||
_ = c.Request.Body.Close()
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Global outbound guard for strict OpenAI-compatible gateways:
|
||||
// remove disable_parallel_tool_use recursively and normalize object tool_choice.
|
||||
func sanitizeStrictToolChoicePayload(data []byte) []byte {
|
||||
var payload any
|
||||
if err := json.Unmarshal(data, &payload); err != nil {
|
||||
return data
|
||||
}
|
||||
root, ok := payload.(map[string]any)
|
||||
if !ok {
|
||||
return data
|
||||
}
|
||||
root = stripDisableParallelKeyAny(root).(map[string]any)
|
||||
|
||||
hasTools := false
|
||||
if tools, ok := root["tools"].([]any); ok && len(tools) > 0 {
|
||||
hasTools = true
|
||||
}
|
||||
if functions, ok := root["functions"]; ok && functions != nil {
|
||||
hasTools = true
|
||||
}
|
||||
if tc, ok := root["tool_choice"]; ok {
|
||||
switch tc.(type) {
|
||||
case map[string]any, []any:
|
||||
if hasTools {
|
||||
root["tool_choice"] = "auto"
|
||||
} else {
|
||||
delete(root, "tool_choice")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out, err := json.Marshal(root)
|
||||
if err != nil {
|
||||
return data
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// stripDisableParallelKeyAny walks an arbitrary decoded-JSON value and
// deletes every "disable_parallel_tool_use" key it finds, recursing into
// nested objects and arrays. Containers are mutated in place; the
// (possibly mutated) value is returned.
func stripDisableParallelKeyAny(v any) any {
	if obj, ok := v.(map[string]any); ok {
		delete(obj, "disable_parallel_tool_use")
		for key, nested := range obj {
			obj[key] = stripDisableParallelKeyAny(nested)
		}
		return obj
	}
	if arr, ok := v.([]any); ok {
		for idx, nested := range arr {
			arr[idx] = stripDisableParallelKeyAny(nested)
		}
		return arr
	}
	// Scalars (string/number/bool/nil) pass through unchanged.
	return v
}
|
||||
|
||||
// extractToolChoiceForLog renders the tool_choice field of a JSON request
// body as a short string for logging. Sentinels mark the degenerate cases:
// "<invalid-json>" when the body is not a JSON object, "<absent>" when the
// field is missing, and "<unmarshalable>" when a non-string value cannot be
// re-marshaled.
func extractToolChoiceForLog(data []byte) string {
	var body map[string]any
	if json.Unmarshal(data, &body) != nil {
		return "<invalid-json>"
	}
	choice, present := body["tool_choice"]
	if !present {
		return "<absent>"
	}
	if s, isString := choice.(string); isString {
		return s
	}
	encoded, err := json.Marshal(choice)
	if err != nil {
		return "<unmarshalable>"
	}
	return string(encoded)
}
|
||||
|
||||
@ -59,6 +59,7 @@ func sanitizeSchema(v interface{}) interface{} {
|
||||
|
||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
|
||||
messagesForConvert := textRequest.Messages
|
||||
geminiRequest := ChatRequest{
|
||||
Contents: make([]ChatContent, 0, len(textRequest.Messages)),
|
||||
SafetySettings: []ChatSafetySettings{
|
||||
@ -104,11 +105,6 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
|
||||
geminiRequest.GenerationConfig.ResponseModalities = []string{"TEXT", "IMAGE"}
|
||||
}
|
||||
|
||||
// Enable thinking when the client explicitly requests it via enable_thinking=true.
|
||||
// Use thinkingBudget=-1 (dynamic) so Gemini decides the appropriate budget.
|
||||
if textRequest.EnableThinking {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ThinkingBudget: -1}
|
||||
}
|
||||
if textRequest.Tools != nil {
|
||||
functions := make([]model.Function, 0, len(textRequest.Tools))
|
||||
for _, tool := range textRequest.Tools {
|
||||
@ -130,42 +126,87 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Tool-call compatibility mode:
|
||||
// Keep tool rounds on strict structured functionCall/functionResponse flow.
|
||||
// Gemini thinking tool rounds behind gateways are prone to thought_signature
|
||||
// validation failures; disable thinking config for tool flow to avoid fallback
|
||||
// text replay and preserve structured tool history.
|
||||
hasToolFlow := len(geminiRequest.Tools) > 0
|
||||
if !hasToolFlow {
|
||||
for _, msg := range messagesForConvert {
|
||||
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
|
||||
hasToolFlow = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if hasToolFlow {
|
||||
// Remove incompatible historical tool-call turns that cannot satisfy
|
||||
// Gemini thought_signature requirements. Keep structured tool flow for
|
||||
// the remaining valid turns (no text downgrade replay).
|
||||
messagesForConvert = sanitizeGeminiToolHistoryMessages(messagesForConvert)
|
||||
}
|
||||
// Enable thinking only for non-tool turns.
|
||||
if textRequest.EnableThinking && !hasToolFlow {
|
||||
// Use thinkingBudget=-1 (dynamic) so Gemini decides the appropriate budget.
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ThinkingBudget: -1}
|
||||
} else if hasToolFlow {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = nil
|
||||
}
|
||||
// Build a map from tool_call_id → function name for resolving tool result names
|
||||
toolCallIdToName := map[string]string{}
|
||||
for _, message := range textRequest.Messages {
|
||||
lastKnownThoughtSignature := ""
|
||||
for _, message := range messagesForConvert {
|
||||
if message.ReasoningEncryptedContent != "" {
|
||||
lastKnownThoughtSignature = message.ReasoningEncryptedContent
|
||||
}
|
||||
if message.Role == "assistant" {
|
||||
for _, tc := range message.ToolCalls {
|
||||
if tc.Id != "" && tc.Function.Name != "" {
|
||||
toolCallIdToName[tc.Id] = tc.Function.Name
|
||||
}
|
||||
if tc.ExtraContent != nil && tc.ExtraContent.Google != nil && tc.ExtraContent.Google.ThoughtSignature != "" {
|
||||
lastKnownThoughtSignature = tc.ExtraContent.Google.ThoughtSignature
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
shouldAddDummyModelMessage := false
|
||||
for _, message := range textRequest.Messages {
|
||||
for i := 0; i < len(messagesForConvert); i++ {
|
||||
message := messagesForConvert[i]
|
||||
// --- tool result: role=tool → Gemini functionResponse (user role) ---
|
||||
// Gemini requires functionResponse parts to match functionCall parts from the
|
||||
// preceding model turn. For OpenAI-compat payloads with multiple consecutive
|
||||
// role=tool messages, merge them into a single user turn with multiple parts.
|
||||
if message.Role == "tool" {
|
||||
toolName := message.ToolCallId
|
||||
if name, ok := toolCallIdToName[message.ToolCallId]; ok {
|
||||
var responseParts []Part
|
||||
for ; i < len(messagesForConvert) && messagesForConvert[i].Role == "tool"; i++ {
|
||||
toolMessage := messagesForConvert[i]
|
||||
toolName := toolMessage.ToolCallId
|
||||
if name, ok := toolCallIdToName[toolMessage.ToolCallId]; ok {
|
||||
toolName = name
|
||||
} else if message.Name != nil && *message.Name != "" {
|
||||
toolName = *message.Name
|
||||
} else if toolMessage.Name != nil && *toolMessage.Name != "" {
|
||||
toolName = *toolMessage.Name
|
||||
}
|
||||
if toolName == "" {
|
||||
toolName = "unknown_tool"
|
||||
}
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, ChatContent{
|
||||
Role: "user",
|
||||
Parts: []Part{
|
||||
{
|
||||
responseParts = append(responseParts, Part{
|
||||
FunctionResponse: &FunctionResponse{
|
||||
Name: toolName,
|
||||
Response: map[string]any{"content": message.StringContent()},
|
||||
},
|
||||
},
|
||||
Response: map[string]any{"content": toolMessage.StringContent()},
|
||||
},
|
||||
})
|
||||
}
|
||||
if len(responseParts) > 0 {
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, ChatContent{
|
||||
Role: "user",
|
||||
Parts: responseParts,
|
||||
})
|
||||
}
|
||||
i--
|
||||
continue
|
||||
}
|
||||
|
||||
@ -248,11 +289,25 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
|
||||
} else {
|
||||
args = map[string]any{}
|
||||
}
|
||||
thoughtSignature := ""
|
||||
if tc.ExtraContent != nil && tc.ExtraContent.Google != nil {
|
||||
thoughtSignature = tc.ExtraContent.Google.ThoughtSignature
|
||||
}
|
||||
if thoughtSignature == "" {
|
||||
thoughtSignature = message.ReasoningEncryptedContent
|
||||
}
|
||||
if thoughtSignature == "" {
|
||||
thoughtSignature = lastKnownThoughtSignature
|
||||
}
|
||||
if thoughtSignature != "" {
|
||||
lastKnownThoughtSignature = thoughtSignature
|
||||
}
|
||||
fcParts = append(fcParts, Part{
|
||||
FunctionCall: &FunctionCall{
|
||||
FunctionName: tc.Function.Name,
|
||||
Arguments: args,
|
||||
},
|
||||
ThoughtSignature: thoughtSignature,
|
||||
})
|
||||
}
|
||||
content.Role = "model"
|
||||
@ -298,6 +353,76 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
|
||||
return &geminiRequest
|
||||
}
|
||||
|
||||
func sanitizeGeminiToolHistoryMessages(messages []model.Message) []model.Message {
|
||||
if len(messages) == 0 {
|
||||
return messages
|
||||
}
|
||||
sanitized := make([]model.Message, 0, len(messages))
|
||||
droppedToolCallIDs := make(map[string]struct{})
|
||||
lastKnownThoughtSignature := ""
|
||||
|
||||
for _, message := range messages {
|
||||
if message.ReasoningEncryptedContent != "" {
|
||||
lastKnownThoughtSignature = message.ReasoningEncryptedContent
|
||||
}
|
||||
|
||||
if message.Role == "assistant" && len(message.ToolCalls) > 0 {
|
||||
keptToolCalls := make([]model.Tool, 0, len(message.ToolCalls))
|
||||
for _, tc := range message.ToolCalls {
|
||||
thoughtSignature := ""
|
||||
if tc.ExtraContent != nil && tc.ExtraContent.Google != nil {
|
||||
thoughtSignature = strings.TrimSpace(tc.ExtraContent.Google.ThoughtSignature)
|
||||
}
|
||||
if thoughtSignature == "" {
|
||||
thoughtSignature = strings.TrimSpace(message.ReasoningEncryptedContent)
|
||||
}
|
||||
if thoughtSignature == "" {
|
||||
thoughtSignature = strings.TrimSpace(lastKnownThoughtSignature)
|
||||
}
|
||||
if thoughtSignature == "" {
|
||||
if tc.Id != "" {
|
||||
droppedToolCallIDs[tc.Id] = struct{}{}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if tc.ExtraContent == nil {
|
||||
tc.ExtraContent = &model.ToolExtraContent{}
|
||||
}
|
||||
if tc.ExtraContent.Google == nil {
|
||||
tc.ExtraContent.Google = &model.GoogleToolExtraContent{}
|
||||
}
|
||||
tc.ExtraContent.Google.ThoughtSignature = thoughtSignature
|
||||
lastKnownThoughtSignature = thoughtSignature
|
||||
keptToolCalls = append(keptToolCalls, tc)
|
||||
}
|
||||
|
||||
if len(keptToolCalls) == 0 {
|
||||
assistantText := strings.TrimSpace(message.StringContent())
|
||||
if assistantText == "" {
|
||||
continue
|
||||
}
|
||||
message.ToolCalls = nil
|
||||
message.Content = assistantText
|
||||
sanitized = append(sanitized, message)
|
||||
continue
|
||||
}
|
||||
message.ToolCalls = keptToolCalls
|
||||
sanitized = append(sanitized, message)
|
||||
continue
|
||||
}
|
||||
|
||||
if message.Role == "tool" && message.ToolCallId != "" {
|
||||
if _, dropped := droppedToolCallIDs[message.ToolCallId]; dropped {
|
||||
continue
|
||||
}
|
||||
}
|
||||
sanitized = append(sanitized, message)
|
||||
}
|
||||
|
||||
return sanitized
|
||||
}
|
||||
|
||||
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *BatchEmbeddingRequest {
|
||||
inputs := request.ParseInput()
|
||||
requests := make([]EmbeddingRequest, len(inputs))
|
||||
@ -369,15 +494,14 @@ type ChatPromptFeedback struct {
|
||||
|
||||
func getToolCalls(candidate *ChatCandidate) []model.Tool {
|
||||
var toolCalls []model.Tool
|
||||
|
||||
item := candidate.Content.Parts[0]
|
||||
for _, item := range candidate.Content.Parts {
|
||||
if item.FunctionCall == nil {
|
||||
return toolCalls
|
||||
continue
|
||||
}
|
||||
argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
|
||||
if err != nil {
|
||||
logger.FatalLog("getToolCalls failed: " + err.Error())
|
||||
return toolCalls
|
||||
continue
|
||||
}
|
||||
toolCall := model.Tool{
|
||||
Id: fmt.Sprintf("call_%s", random.GetUUID()),
|
||||
@ -387,7 +511,15 @@ func getToolCalls(candidate *ChatCandidate) []model.Tool {
|
||||
Name: item.FunctionCall.FunctionName,
|
||||
},
|
||||
}
|
||||
if item.ThoughtSignature != "" {
|
||||
toolCall.ExtraContent = &model.ToolExtraContent{
|
||||
Google: &model.GoogleToolExtraContent{
|
||||
ThoughtSignature: item.ThoughtSignature,
|
||||
},
|
||||
}
|
||||
}
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
}
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
@ -407,12 +539,13 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
|
||||
FinishReason: constant.StopFinishReason,
|
||||
}
|
||||
if len(candidate.Content.Parts) > 0 {
|
||||
if candidate.Content.Parts[0].FunctionCall != nil {
|
||||
choice.Message.ToolCalls = getToolCalls(&candidate)
|
||||
toolCalls := getToolCalls(&candidate)
|
||||
if len(toolCalls) > 0 {
|
||||
choice.Message.ToolCalls = toolCalls
|
||||
} else {
|
||||
var builder strings.Builder
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if i > 0 {
|
||||
for j, part := range candidate.Content.Parts {
|
||||
if j > 0 {
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
builder.WriteString(part.Text)
|
||||
@ -434,10 +567,35 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatC
|
||||
if len(geminiResponse.Candidates) > 0 {
|
||||
var textBuilder strings.Builder
|
||||
var thinkingBuilder strings.Builder
|
||||
var toolCalls []model.Tool
|
||||
for _, part := range geminiResponse.Candidates[0].Content.Parts {
|
||||
if part.Thought {
|
||||
// Thinking/reasoning content — route to reasoning_content field
|
||||
thinkingBuilder.WriteString(part.Text)
|
||||
} else if part.FunctionCall != nil {
|
||||
argsBytes, err := json.Marshal(part.FunctionCall.Arguments)
|
||||
if err != nil {
|
||||
logger.FatalLog("streamResponseGeminiChat2OpenAI marshal args failed: " + err.Error())
|
||||
argsBytes = []byte("{}")
|
||||
}
|
||||
toolCalls = append(toolCalls, model.Tool{
|
||||
Id: fmt.Sprintf("call_%s", random.GetUUID()),
|
||||
Type: "function",
|
||||
Function: model.Function{
|
||||
Name: part.FunctionCall.FunctionName,
|
||||
Arguments: string(argsBytes),
|
||||
},
|
||||
ExtraContent: func() *model.ToolExtraContent {
|
||||
if part.ThoughtSignature == "" {
|
||||
return nil
|
||||
}
|
||||
return &model.ToolExtraContent{
|
||||
Google: &model.GoogleToolExtraContent{
|
||||
ThoughtSignature: part.ThoughtSignature,
|
||||
},
|
||||
}
|
||||
}(),
|
||||
})
|
||||
} else if part.Text != "" {
|
||||
textBuilder.WriteString(part.Text)
|
||||
} else if part.InlineData != nil && part.InlineData.Data != "" {
|
||||
@ -455,6 +613,11 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatC
|
||||
if thinkingBuilder.Len() > 0 {
|
||||
choice.Delta.ReasoningContent = thinkingBuilder.String()
|
||||
}
|
||||
if len(toolCalls) > 0 {
|
||||
choice.Delta.ToolCalls = toolCalls
|
||||
toolFinish := "tool_calls"
|
||||
choice.FinishReason = &toolFinish
|
||||
}
|
||||
}
|
||||
|
||||
var response openai.ChatCompletionsStreamResponse
|
||||
|
||||
@ -55,6 +55,7 @@ type Part struct {
|
||||
InlineData *InlineData `json:"inlineData,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
|
||||
ThoughtSignature string `json:"thoughtSignature,omitempty"`
|
||||
// Thought marks this part as internal reasoning/thinking content (Gemini thinking models)
|
||||
Thought bool `json:"thought,omitempty"`
|
||||
}
|
||||
|
||||
@ -32,6 +32,8 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sanitized := sanitizeToolChoiceForStrictCompat(textRequest)
|
||||
c.Set("sanitized_text_request_modified", sanitized)
|
||||
if relayMode == relaymode.Moderations && textRequest.Model == "" {
|
||||
textRequest.Model = "text-moderation-latest"
|
||||
}
|
||||
@ -45,6 +47,62 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener
|
||||
return textRequest, nil
|
||||
}
|
||||
|
||||
func sanitizeToolChoiceForStrictCompat(req *relaymodel.GeneralOpenAIRequest) bool {
|
||||
if req == nil {
|
||||
return false
|
||||
}
|
||||
modified := stripDisableParallelToolUse(&req.ToolChoice)
|
||||
// Strict OpenAI-compatible gateways often reject object-form tool_choice.
|
||||
// Coerce object/array forms to "auto" when tools/functions exist.
|
||||
switch req.ToolChoice.(type) {
|
||||
case map[string]any, []any:
|
||||
if len(req.Tools) > 0 || req.Functions != nil {
|
||||
if req.ToolChoice != "auto" {
|
||||
modified = true
|
||||
}
|
||||
req.ToolChoice = "auto"
|
||||
} else {
|
||||
if req.ToolChoice != nil {
|
||||
modified = true
|
||||
}
|
||||
req.ToolChoice = nil
|
||||
}
|
||||
}
|
||||
return modified
|
||||
}
|
||||
|
||||
// stripDisableParallelToolUse recursively deletes every
// "disable_parallel_tool_use" key from the decoded-JSON value held in *v,
// reporting whether anything was removed. A nil pointer or nil value is
// left untouched.
func stripDisableParallelToolUse(v *any) bool {
	if v == nil || *v == nil {
		return false
	}
	changed := false
	if obj, ok := (*v).(map[string]any); ok {
		if _, present := obj["disable_parallel_tool_use"]; present {
			delete(obj, "disable_parallel_tool_use")
			changed = true
		}
		for key := range obj {
			entry := obj[key]
			// Evaluate the recursive call first so it always runs.
			changed = stripDisableParallelToolUse(&entry) || changed
			obj[key] = entry
		}
		*v = obj
		return changed
	}
	if arr, ok := (*v).([]any); ok {
		for i := range arr {
			entry := arr[i]
			changed = stripDisableParallelToolUse(&entry) || changed
			arr[i] = entry
		}
		*v = arr
		return changed
	}
	return false
}
|
||||
|
||||
func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
|
||||
switch relayMode {
|
||||
case relaymode.ChatCompletions:
|
||||
|
||||
@ -2,7 +2,9 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@ -24,7 +26,14 @@ func RelayProxyHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
}
|
||||
adaptor.Init(meta)
|
||||
|
||||
resp, err := adaptor.DoRequest(c, meta, c.Request.Body)
|
||||
requestBody, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "read proxy request body failed: %s", err.Error())
|
||||
return openai.ErrorWrapper(err, "read_request_body_failed", http.StatusBadRequest)
|
||||
}
|
||||
requestBody = sanitizeStrictToolChoiceJSON(requestBody)
|
||||
|
||||
resp, err := adaptor.DoRequest(c, meta, bytes.NewBuffer(requestBody))
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "DoRequest failed: %s", err.Error())
|
||||
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
|
||||
@ -9,15 +9,12 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
"github.com/songquanpeng/one-api/relay/apitype"
|
||||
"github.com/songquanpeng/one-api/relay/billing"
|
||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
@ -98,14 +95,9 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
||||
}
|
||||
|
||||
func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) {
|
||||
if !config.EnforceIncludeUsage &&
|
||||
meta.APIType == apitype.OpenAI &&
|
||||
meta.OriginModelName == meta.ActualModelName &&
|
||||
meta.ChannelType != channeltype.Baichuan &&
|
||||
meta.ForcedSystemPrompt == "" {
|
||||
// no need to convert request for openai
|
||||
return c.Request.Body, nil
|
||||
}
|
||||
// Always convert request through adaptor before sending upstream.
|
||||
// This avoids passthrough edge cases where strict-compat sanitization may be bypassed.
|
||||
// The slight performance cost is acceptable compared to malformed request risk.
|
||||
|
||||
// get request body
|
||||
var requestBody io.Reader
|
||||
@ -119,7 +111,72 @@ func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralO
|
||||
logger.Debugf(c.Request.Context(), "converted request json_marshal_failed: %s\n", err.Error())
|
||||
return nil, err
|
||||
}
|
||||
jsonData = sanitizeStrictToolChoiceJSON(jsonData)
|
||||
logger.Debugf(c.Request.Context(), "converted request: \n%s", string(jsonData))
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
return requestBody, nil
|
||||
}
|
||||
|
||||
// Final outbound guard for strict OpenAI-compatible gateways:
|
||||
// - remove any disable_parallel_tool_use key at any depth
|
||||
// - coerce object/array tool_choice to OpenAI-compatible string form
|
||||
func sanitizeStrictToolChoiceJSON(data []byte) []byte {
|
||||
var payload any
|
||||
if err := json.Unmarshal(data, &payload); err != nil {
|
||||
return data
|
||||
}
|
||||
|
||||
root, ok := payload.(map[string]any)
|
||||
if !ok {
|
||||
return data
|
||||
}
|
||||
|
||||
cleaned := stripDisableParallelToolUseAny(root)
|
||||
root, ok = cleaned.(map[string]any)
|
||||
if !ok {
|
||||
return data
|
||||
}
|
||||
|
||||
hasTools := false
|
||||
if tools, ok := root["tools"].([]any); ok && len(tools) > 0 {
|
||||
hasTools = true
|
||||
}
|
||||
if functions, ok := root["functions"]; ok && functions != nil {
|
||||
hasTools = true
|
||||
}
|
||||
|
||||
if tc, ok := root["tool_choice"]; ok {
|
||||
switch tc.(type) {
|
||||
case map[string]any, []any:
|
||||
if hasTools {
|
||||
root["tool_choice"] = "auto"
|
||||
} else {
|
||||
delete(root, "tool_choice")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out, err := json.Marshal(root)
|
||||
if err != nil {
|
||||
return data
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// stripDisableParallelToolUseAny deletes every "disable_parallel_tool_use"
// key from a decoded-JSON value, descending through nested objects and
// arrays. Containers are mutated in place and the resulting value returned.
func stripDisableParallelToolUseAny(v any) any {
	if obj, ok := v.(map[string]any); ok {
		delete(obj, "disable_parallel_tool_use")
		for key := range obj {
			obj[key] = stripDisableParallelToolUseAny(obj[key])
		}
		return obj
	}
	if arr, ok := v.([]any); ok {
		for i := range arr {
			arr[i] = stripDisableParallelToolUseAny(arr[i])
		}
		return arr
	}
	// Anything else (string/number/bool/nil) is returned as-is.
	return v
}
|
||||
|
||||
@ -4,6 +4,7 @@ type Message struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Content any `json:"content,omitempty"`
|
||||
ReasoningContent any `json:"reasoning_content,omitempty"`
|
||||
ReasoningEncryptedContent string `json:"reasoning_encrypted_content,omitempty"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
ToolCalls []Tool `json:"tool_calls,omitempty"`
|
||||
ToolCallId string `json:"tool_call_id,omitempty"`
|
||||
|
||||
@ -4,6 +4,7 @@ type Tool struct {
|
||||
Id string `json:"id,omitempty"`
|
||||
Type string `json:"type,omitempty"` // when splicing claude tools stream messages, it is empty
|
||||
Function Function `json:"function"`
|
||||
ExtraContent *ToolExtraContent `json:"extra_content,omitempty"`
|
||||
}
|
||||
|
||||
type Function struct {
|
||||
@ -12,3 +13,11 @@ type Function struct {
|
||||
Parameters any `json:"parameters,omitempty"` // request
|
||||
Arguments any `json:"arguments,omitempty"` // response
|
||||
}
|
||||
|
||||
type ToolExtraContent struct {
|
||||
Google *GoogleToolExtraContent `json:"google,omitempty"`
|
||||
}
|
||||
|
||||
type GoogleToolExtraContent struct {
|
||||
ThoughtSignature string `json:"thought_signature,omitempty"`
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user