feat: enhance strict compatibility for OpenAI requests
- Implement sanitization for `tool_choice` and removal of `disable_parallel_tool_use` in request payloads. - Introduce logging for tool choice changes in `DoRequestHelper`. - Update `ConvertRequest` to handle tool-call compatibility and maintain structured tool history. - Add `ThoughtSignature` to `Part` struct for better tracking of reasoning content. - Refactor request handling in `getRequestBody` to ensure strict compliance with OpenAI API requirements.
This commit is contained in:
parent
f67f4b8caf
commit
8b87c3d404
@ -1,10 +1,13 @@
|
|||||||
package adaptor
|
package adaptor
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/songquanpeng/one-api/common/client"
|
"github.com/songquanpeng/one-api/common/client"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
"github.com/songquanpeng/one-api/relay/meta"
|
"github.com/songquanpeng/one-api/relay/meta"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -24,6 +27,29 @@ func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io.
|
|||||||
return nil, fmt.Errorf("get request url failed: %w", err)
|
return nil, fmt.Errorf("get request url failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if requestBody != nil {
|
||||||
|
raw, readErr := io.ReadAll(requestBody)
|
||||||
|
if readErr != nil {
|
||||||
|
return nil, fmt.Errorf("read request body failed: %w", readErr)
|
||||||
|
}
|
||||||
|
hasDisableParallelBefore := bytes.Contains(raw, []byte("disable_parallel_tool_use"))
|
||||||
|
beforeToolChoice := extractToolChoiceForLog(raw)
|
||||||
|
raw = sanitizeStrictToolChoicePayload(raw)
|
||||||
|
hasDisableParallelAfter := bytes.Contains(raw, []byte("disable_parallel_tool_use"))
|
||||||
|
afterToolChoice := extractToolChoiceForLog(raw)
|
||||||
|
logger.Infof(
|
||||||
|
c.Request.Context(),
|
||||||
|
"[DoRequestHelper] outbound %s %s tool_choice(before=%s, after=%s) disable_parallel(before=%t, after=%t)",
|
||||||
|
c.Request.Method,
|
||||||
|
c.Request.URL.Path,
|
||||||
|
beforeToolChoice,
|
||||||
|
afterToolChoice,
|
||||||
|
hasDisableParallelBefore,
|
||||||
|
hasDisableParallelAfter,
|
||||||
|
)
|
||||||
|
requestBody = bytes.NewBuffer(raw)
|
||||||
|
}
|
||||||
|
|
||||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("new request failed: %w", err)
|
return nil, fmt.Errorf("new request failed: %w", err)
|
||||||
@ -51,3 +77,80 @@ func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
|
|||||||
_ = c.Request.Body.Close()
|
_ = c.Request.Body.Close()
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Global outbound guard for strict OpenAI-compatible gateways:
|
||||||
|
// remove disable_parallel_tool_use recursively and normalize object tool_choice.
|
||||||
|
func sanitizeStrictToolChoicePayload(data []byte) []byte {
|
||||||
|
var payload any
|
||||||
|
if err := json.Unmarshal(data, &payload); err != nil {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
root, ok := payload.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
root = stripDisableParallelKeyAny(root).(map[string]any)
|
||||||
|
|
||||||
|
hasTools := false
|
||||||
|
if tools, ok := root["tools"].([]any); ok && len(tools) > 0 {
|
||||||
|
hasTools = true
|
||||||
|
}
|
||||||
|
if functions, ok := root["functions"]; ok && functions != nil {
|
||||||
|
hasTools = true
|
||||||
|
}
|
||||||
|
if tc, ok := root["tool_choice"]; ok {
|
||||||
|
switch tc.(type) {
|
||||||
|
case map[string]any, []any:
|
||||||
|
if hasTools {
|
||||||
|
root["tool_choice"] = "auto"
|
||||||
|
} else {
|
||||||
|
delete(root, "tool_choice")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := json.Marshal(root)
|
||||||
|
if err != nil {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func stripDisableParallelKeyAny(v any) any {
|
||||||
|
switch val := v.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
delete(val, "disable_parallel_tool_use")
|
||||||
|
for k, child := range val {
|
||||||
|
val[k] = stripDisableParallelKeyAny(child)
|
||||||
|
}
|
||||||
|
return val
|
||||||
|
case []any:
|
||||||
|
for i, child := range val {
|
||||||
|
val[i] = stripDisableParallelKeyAny(child)
|
||||||
|
}
|
||||||
|
return val
|
||||||
|
default:
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractToolChoiceForLog(data []byte) string {
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal(data, &payload); err != nil {
|
||||||
|
return "<invalid-json>"
|
||||||
|
}
|
||||||
|
tc, ok := payload["tool_choice"]
|
||||||
|
if !ok {
|
||||||
|
return "<absent>"
|
||||||
|
}
|
||||||
|
switch v := tc.(type) {
|
||||||
|
case string:
|
||||||
|
return v
|
||||||
|
default:
|
||||||
|
out, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return "<unmarshalable>"
|
||||||
|
}
|
||||||
|
return string(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -59,6 +59,7 @@ func sanitizeSchema(v interface{}) interface{} {
|
|||||||
|
|
||||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||||
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
|
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
|
||||||
|
messagesForConvert := textRequest.Messages
|
||||||
geminiRequest := ChatRequest{
|
geminiRequest := ChatRequest{
|
||||||
Contents: make([]ChatContent, 0, len(textRequest.Messages)),
|
Contents: make([]ChatContent, 0, len(textRequest.Messages)),
|
||||||
SafetySettings: []ChatSafetySettings{
|
SafetySettings: []ChatSafetySettings{
|
||||||
@ -104,11 +105,6 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
|
|||||||
geminiRequest.GenerationConfig.ResponseModalities = []string{"TEXT", "IMAGE"}
|
geminiRequest.GenerationConfig.ResponseModalities = []string{"TEXT", "IMAGE"}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enable thinking when the client explicitly requests it via enable_thinking=true.
|
|
||||||
// Use thinkingBudget=-1 (dynamic) so Gemini decides the appropriate budget.
|
|
||||||
if textRequest.EnableThinking {
|
|
||||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ThinkingBudget: -1}
|
|
||||||
}
|
|
||||||
if textRequest.Tools != nil {
|
if textRequest.Tools != nil {
|
||||||
functions := make([]model.Function, 0, len(textRequest.Tools))
|
functions := make([]model.Function, 0, len(textRequest.Tools))
|
||||||
for _, tool := range textRequest.Tools {
|
for _, tool := range textRequest.Tools {
|
||||||
@ -130,42 +126,87 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tool-call compatibility mode:
|
||||||
|
// Keep tool rounds on strict structured functionCall/functionResponse flow.
|
||||||
|
// Gemini thinking tool rounds behind gateways are prone to thought_signature
|
||||||
|
// validation failures; disable thinking config for tool flow to avoid fallback
|
||||||
|
// text replay and preserve structured tool history.
|
||||||
|
hasToolFlow := len(geminiRequest.Tools) > 0
|
||||||
|
if !hasToolFlow {
|
||||||
|
for _, msg := range messagesForConvert {
|
||||||
|
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
|
||||||
|
hasToolFlow = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if hasToolFlow {
|
||||||
|
// Remove incompatible historical tool-call turns that cannot satisfy
|
||||||
|
// Gemini thought_signature requirements. Keep structured tool flow for
|
||||||
|
// the remaining valid turns (no text downgrade replay).
|
||||||
|
messagesForConvert = sanitizeGeminiToolHistoryMessages(messagesForConvert)
|
||||||
|
}
|
||||||
|
// Enable thinking only for non-tool turns.
|
||||||
|
if textRequest.EnableThinking && !hasToolFlow {
|
||||||
|
// Use thinkingBudget=-1 (dynamic) so Gemini decides the appropriate budget.
|
||||||
|
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ThinkingBudget: -1}
|
||||||
|
} else if hasToolFlow {
|
||||||
|
geminiRequest.GenerationConfig.ThinkingConfig = nil
|
||||||
|
}
|
||||||
// Build a map from tool_call_id → function name for resolving tool result names
|
// Build a map from tool_call_id → function name for resolving tool result names
|
||||||
toolCallIdToName := map[string]string{}
|
toolCallIdToName := map[string]string{}
|
||||||
for _, message := range textRequest.Messages {
|
lastKnownThoughtSignature := ""
|
||||||
|
for _, message := range messagesForConvert {
|
||||||
|
if message.ReasoningEncryptedContent != "" {
|
||||||
|
lastKnownThoughtSignature = message.ReasoningEncryptedContent
|
||||||
|
}
|
||||||
if message.Role == "assistant" {
|
if message.Role == "assistant" {
|
||||||
for _, tc := range message.ToolCalls {
|
for _, tc := range message.ToolCalls {
|
||||||
if tc.Id != "" && tc.Function.Name != "" {
|
if tc.Id != "" && tc.Function.Name != "" {
|
||||||
toolCallIdToName[tc.Id] = tc.Function.Name
|
toolCallIdToName[tc.Id] = tc.Function.Name
|
||||||
}
|
}
|
||||||
|
if tc.ExtraContent != nil && tc.ExtraContent.Google != nil && tc.ExtraContent.Google.ThoughtSignature != "" {
|
||||||
|
lastKnownThoughtSignature = tc.ExtraContent.Google.ThoughtSignature
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
shouldAddDummyModelMessage := false
|
shouldAddDummyModelMessage := false
|
||||||
for _, message := range textRequest.Messages {
|
for i := 0; i < len(messagesForConvert); i++ {
|
||||||
|
message := messagesForConvert[i]
|
||||||
// --- tool result: role=tool → Gemini functionResponse (user role) ---
|
// --- tool result: role=tool → Gemini functionResponse (user role) ---
|
||||||
|
// Gemini requires functionResponse parts to match functionCall parts from the
|
||||||
|
// preceding model turn. For OpenAI-compat payloads with multiple consecutive
|
||||||
|
// role=tool messages, merge them into a single user turn with multiple parts.
|
||||||
if message.Role == "tool" {
|
if message.Role == "tool" {
|
||||||
toolName := message.ToolCallId
|
var responseParts []Part
|
||||||
if name, ok := toolCallIdToName[message.ToolCallId]; ok {
|
for ; i < len(messagesForConvert) && messagesForConvert[i].Role == "tool"; i++ {
|
||||||
toolName = name
|
toolMessage := messagesForConvert[i]
|
||||||
} else if message.Name != nil && *message.Name != "" {
|
toolName := toolMessage.ToolCallId
|
||||||
toolName = *message.Name
|
if name, ok := toolCallIdToName[toolMessage.ToolCallId]; ok {
|
||||||
}
|
toolName = name
|
||||||
if toolName == "" {
|
} else if toolMessage.Name != nil && *toolMessage.Name != "" {
|
||||||
toolName = "unknown_tool"
|
toolName = *toolMessage.Name
|
||||||
}
|
}
|
||||||
geminiRequest.Contents = append(geminiRequest.Contents, ChatContent{
|
if toolName == "" {
|
||||||
Role: "user",
|
toolName = "unknown_tool"
|
||||||
Parts: []Part{
|
}
|
||||||
{
|
responseParts = append(responseParts, Part{
|
||||||
FunctionResponse: &FunctionResponse{
|
FunctionResponse: &FunctionResponse{
|
||||||
Name: toolName,
|
Name: toolName,
|
||||||
Response: map[string]any{"content": message.StringContent()},
|
Response: map[string]any{"content": toolMessage.StringContent()},
|
||||||
},
|
|
||||||
},
|
},
|
||||||
},
|
})
|
||||||
})
|
}
|
||||||
|
if len(responseParts) > 0 {
|
||||||
|
geminiRequest.Contents = append(geminiRequest.Contents, ChatContent{
|
||||||
|
Role: "user",
|
||||||
|
Parts: responseParts,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
i--
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -248,11 +289,25 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
|
|||||||
} else {
|
} else {
|
||||||
args = map[string]any{}
|
args = map[string]any{}
|
||||||
}
|
}
|
||||||
|
thoughtSignature := ""
|
||||||
|
if tc.ExtraContent != nil && tc.ExtraContent.Google != nil {
|
||||||
|
thoughtSignature = tc.ExtraContent.Google.ThoughtSignature
|
||||||
|
}
|
||||||
|
if thoughtSignature == "" {
|
||||||
|
thoughtSignature = message.ReasoningEncryptedContent
|
||||||
|
}
|
||||||
|
if thoughtSignature == "" {
|
||||||
|
thoughtSignature = lastKnownThoughtSignature
|
||||||
|
}
|
||||||
|
if thoughtSignature != "" {
|
||||||
|
lastKnownThoughtSignature = thoughtSignature
|
||||||
|
}
|
||||||
fcParts = append(fcParts, Part{
|
fcParts = append(fcParts, Part{
|
||||||
FunctionCall: &FunctionCall{
|
FunctionCall: &FunctionCall{
|
||||||
FunctionName: tc.Function.Name,
|
FunctionName: tc.Function.Name,
|
||||||
Arguments: args,
|
Arguments: args,
|
||||||
},
|
},
|
||||||
|
ThoughtSignature: thoughtSignature,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
content.Role = "model"
|
content.Role = "model"
|
||||||
@ -298,6 +353,76 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
|
|||||||
return &geminiRequest
|
return &geminiRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func sanitizeGeminiToolHistoryMessages(messages []model.Message) []model.Message {
|
||||||
|
if len(messages) == 0 {
|
||||||
|
return messages
|
||||||
|
}
|
||||||
|
sanitized := make([]model.Message, 0, len(messages))
|
||||||
|
droppedToolCallIDs := make(map[string]struct{})
|
||||||
|
lastKnownThoughtSignature := ""
|
||||||
|
|
||||||
|
for _, message := range messages {
|
||||||
|
if message.ReasoningEncryptedContent != "" {
|
||||||
|
lastKnownThoughtSignature = message.ReasoningEncryptedContent
|
||||||
|
}
|
||||||
|
|
||||||
|
if message.Role == "assistant" && len(message.ToolCalls) > 0 {
|
||||||
|
keptToolCalls := make([]model.Tool, 0, len(message.ToolCalls))
|
||||||
|
for _, tc := range message.ToolCalls {
|
||||||
|
thoughtSignature := ""
|
||||||
|
if tc.ExtraContent != nil && tc.ExtraContent.Google != nil {
|
||||||
|
thoughtSignature = strings.TrimSpace(tc.ExtraContent.Google.ThoughtSignature)
|
||||||
|
}
|
||||||
|
if thoughtSignature == "" {
|
||||||
|
thoughtSignature = strings.TrimSpace(message.ReasoningEncryptedContent)
|
||||||
|
}
|
||||||
|
if thoughtSignature == "" {
|
||||||
|
thoughtSignature = strings.TrimSpace(lastKnownThoughtSignature)
|
||||||
|
}
|
||||||
|
if thoughtSignature == "" {
|
||||||
|
if tc.Id != "" {
|
||||||
|
droppedToolCallIDs[tc.Id] = struct{}{}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.ExtraContent == nil {
|
||||||
|
tc.ExtraContent = &model.ToolExtraContent{}
|
||||||
|
}
|
||||||
|
if tc.ExtraContent.Google == nil {
|
||||||
|
tc.ExtraContent.Google = &model.GoogleToolExtraContent{}
|
||||||
|
}
|
||||||
|
tc.ExtraContent.Google.ThoughtSignature = thoughtSignature
|
||||||
|
lastKnownThoughtSignature = thoughtSignature
|
||||||
|
keptToolCalls = append(keptToolCalls, tc)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(keptToolCalls) == 0 {
|
||||||
|
assistantText := strings.TrimSpace(message.StringContent())
|
||||||
|
if assistantText == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
message.ToolCalls = nil
|
||||||
|
message.Content = assistantText
|
||||||
|
sanitized = append(sanitized, message)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
message.ToolCalls = keptToolCalls
|
||||||
|
sanitized = append(sanitized, message)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if message.Role == "tool" && message.ToolCallId != "" {
|
||||||
|
if _, dropped := droppedToolCallIDs[message.ToolCallId]; dropped {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sanitized = append(sanitized, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
return sanitized
|
||||||
|
}
|
||||||
|
|
||||||
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *BatchEmbeddingRequest {
|
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *BatchEmbeddingRequest {
|
||||||
inputs := request.ParseInput()
|
inputs := request.ParseInput()
|
||||||
requests := make([]EmbeddingRequest, len(inputs))
|
requests := make([]EmbeddingRequest, len(inputs))
|
||||||
@ -369,25 +494,32 @@ type ChatPromptFeedback struct {
|
|||||||
|
|
||||||
func getToolCalls(candidate *ChatCandidate) []model.Tool {
|
func getToolCalls(candidate *ChatCandidate) []model.Tool {
|
||||||
var toolCalls []model.Tool
|
var toolCalls []model.Tool
|
||||||
|
for _, item := range candidate.Content.Parts {
|
||||||
item := candidate.Content.Parts[0]
|
if item.FunctionCall == nil {
|
||||||
if item.FunctionCall == nil {
|
continue
|
||||||
return toolCalls
|
}
|
||||||
|
argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
|
||||||
|
if err != nil {
|
||||||
|
logger.FatalLog("getToolCalls failed: " + err.Error())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
toolCall := model.Tool{
|
||||||
|
Id: fmt.Sprintf("call_%s", random.GetUUID()),
|
||||||
|
Type: "function",
|
||||||
|
Function: model.Function{
|
||||||
|
Arguments: string(argsBytes),
|
||||||
|
Name: item.FunctionCall.FunctionName,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if item.ThoughtSignature != "" {
|
||||||
|
toolCall.ExtraContent = &model.ToolExtraContent{
|
||||||
|
Google: &model.GoogleToolExtraContent{
|
||||||
|
ThoughtSignature: item.ThoughtSignature,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
toolCalls = append(toolCalls, toolCall)
|
||||||
}
|
}
|
||||||
argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
|
|
||||||
if err != nil {
|
|
||||||
logger.FatalLog("getToolCalls failed: " + err.Error())
|
|
||||||
return toolCalls
|
|
||||||
}
|
|
||||||
toolCall := model.Tool{
|
|
||||||
Id: fmt.Sprintf("call_%s", random.GetUUID()),
|
|
||||||
Type: "function",
|
|
||||||
Function: model.Function{
|
|
||||||
Arguments: string(argsBytes),
|
|
||||||
Name: item.FunctionCall.FunctionName,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
toolCalls = append(toolCalls, toolCall)
|
|
||||||
return toolCalls
|
return toolCalls
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -407,12 +539,13 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
|
|||||||
FinishReason: constant.StopFinishReason,
|
FinishReason: constant.StopFinishReason,
|
||||||
}
|
}
|
||||||
if len(candidate.Content.Parts) > 0 {
|
if len(candidate.Content.Parts) > 0 {
|
||||||
if candidate.Content.Parts[0].FunctionCall != nil {
|
toolCalls := getToolCalls(&candidate)
|
||||||
choice.Message.ToolCalls = getToolCalls(&candidate)
|
if len(toolCalls) > 0 {
|
||||||
|
choice.Message.ToolCalls = toolCalls
|
||||||
} else {
|
} else {
|
||||||
var builder strings.Builder
|
var builder strings.Builder
|
||||||
for _, part := range candidate.Content.Parts {
|
for j, part := range candidate.Content.Parts {
|
||||||
if i > 0 {
|
if j > 0 {
|
||||||
builder.WriteString("\n")
|
builder.WriteString("\n")
|
||||||
}
|
}
|
||||||
builder.WriteString(part.Text)
|
builder.WriteString(part.Text)
|
||||||
@ -434,10 +567,35 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatC
|
|||||||
if len(geminiResponse.Candidates) > 0 {
|
if len(geminiResponse.Candidates) > 0 {
|
||||||
var textBuilder strings.Builder
|
var textBuilder strings.Builder
|
||||||
var thinkingBuilder strings.Builder
|
var thinkingBuilder strings.Builder
|
||||||
|
var toolCalls []model.Tool
|
||||||
for _, part := range geminiResponse.Candidates[0].Content.Parts {
|
for _, part := range geminiResponse.Candidates[0].Content.Parts {
|
||||||
if part.Thought {
|
if part.Thought {
|
||||||
// Thinking/reasoning content — route to reasoning_content field
|
// Thinking/reasoning content — route to reasoning_content field
|
||||||
thinkingBuilder.WriteString(part.Text)
|
thinkingBuilder.WriteString(part.Text)
|
||||||
|
} else if part.FunctionCall != nil {
|
||||||
|
argsBytes, err := json.Marshal(part.FunctionCall.Arguments)
|
||||||
|
if err != nil {
|
||||||
|
logger.FatalLog("streamResponseGeminiChat2OpenAI marshal args failed: " + err.Error())
|
||||||
|
argsBytes = []byte("{}")
|
||||||
|
}
|
||||||
|
toolCalls = append(toolCalls, model.Tool{
|
||||||
|
Id: fmt.Sprintf("call_%s", random.GetUUID()),
|
||||||
|
Type: "function",
|
||||||
|
Function: model.Function{
|
||||||
|
Name: part.FunctionCall.FunctionName,
|
||||||
|
Arguments: string(argsBytes),
|
||||||
|
},
|
||||||
|
ExtraContent: func() *model.ToolExtraContent {
|
||||||
|
if part.ThoughtSignature == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &model.ToolExtraContent{
|
||||||
|
Google: &model.GoogleToolExtraContent{
|
||||||
|
ThoughtSignature: part.ThoughtSignature,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}(),
|
||||||
|
})
|
||||||
} else if part.Text != "" {
|
} else if part.Text != "" {
|
||||||
textBuilder.WriteString(part.Text)
|
textBuilder.WriteString(part.Text)
|
||||||
} else if part.InlineData != nil && part.InlineData.Data != "" {
|
} else if part.InlineData != nil && part.InlineData.Data != "" {
|
||||||
@ -455,6 +613,11 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatC
|
|||||||
if thinkingBuilder.Len() > 0 {
|
if thinkingBuilder.Len() > 0 {
|
||||||
choice.Delta.ReasoningContent = thinkingBuilder.String()
|
choice.Delta.ReasoningContent = thinkingBuilder.String()
|
||||||
}
|
}
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
choice.Delta.ToolCalls = toolCalls
|
||||||
|
toolFinish := "tool_calls"
|
||||||
|
choice.FinishReason = &toolFinish
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var response openai.ChatCompletionsStreamResponse
|
var response openai.ChatCompletionsStreamResponse
|
||||||
|
|||||||
@ -55,6 +55,7 @@ type Part struct {
|
|||||||
InlineData *InlineData `json:"inlineData,omitempty"`
|
InlineData *InlineData `json:"inlineData,omitempty"`
|
||||||
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||||
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
|
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
|
||||||
|
ThoughtSignature string `json:"thoughtSignature,omitempty"`
|
||||||
// Thought marks this part as internal reasoning/thinking content (Gemini thinking models)
|
// Thought marks this part as internal reasoning/thinking content (Gemini thinking models)
|
||||||
Thought bool `json:"thought,omitempty"`
|
Thought bool `json:"thought,omitempty"`
|
||||||
}
|
}
|
||||||
|
|||||||
@ -32,6 +32,8 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
sanitized := sanitizeToolChoiceForStrictCompat(textRequest)
|
||||||
|
c.Set("sanitized_text_request_modified", sanitized)
|
||||||
if relayMode == relaymode.Moderations && textRequest.Model == "" {
|
if relayMode == relaymode.Moderations && textRequest.Model == "" {
|
||||||
textRequest.Model = "text-moderation-latest"
|
textRequest.Model = "text-moderation-latest"
|
||||||
}
|
}
|
||||||
@ -45,6 +47,62 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener
|
|||||||
return textRequest, nil
|
return textRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func sanitizeToolChoiceForStrictCompat(req *relaymodel.GeneralOpenAIRequest) bool {
|
||||||
|
if req == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
modified := stripDisableParallelToolUse(&req.ToolChoice)
|
||||||
|
// Strict OpenAI-compatible gateways often reject object-form tool_choice.
|
||||||
|
// Coerce object/array forms to "auto" when tools/functions exist.
|
||||||
|
switch req.ToolChoice.(type) {
|
||||||
|
case map[string]any, []any:
|
||||||
|
if len(req.Tools) > 0 || req.Functions != nil {
|
||||||
|
if req.ToolChoice != "auto" {
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
|
req.ToolChoice = "auto"
|
||||||
|
} else {
|
||||||
|
if req.ToolChoice != nil {
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
|
req.ToolChoice = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return modified
|
||||||
|
}
|
||||||
|
|
||||||
|
func stripDisableParallelToolUse(v *any) bool {
|
||||||
|
if v == nil || *v == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
modified := false
|
||||||
|
switch val := (*v).(type) {
|
||||||
|
case map[string]any:
|
||||||
|
if _, ok := val["disable_parallel_tool_use"]; ok {
|
||||||
|
delete(val, "disable_parallel_tool_use")
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
|
for k := range val {
|
||||||
|
child := val[k]
|
||||||
|
if stripDisableParallelToolUse(&child) {
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
|
val[k] = child
|
||||||
|
}
|
||||||
|
*v = val
|
||||||
|
case []any:
|
||||||
|
for i := range val {
|
||||||
|
child := val[i]
|
||||||
|
if stripDisableParallelToolUse(&child) {
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
|
val[i] = child
|
||||||
|
}
|
||||||
|
*v = val
|
||||||
|
}
|
||||||
|
return modified
|
||||||
|
}
|
||||||
|
|
||||||
func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
|
func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
|
||||||
switch relayMode {
|
switch relayMode {
|
||||||
case relaymode.ChatCompletions:
|
case relaymode.ChatCompletions:
|
||||||
|
|||||||
@ -2,7 +2,9 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@ -24,7 +26,14 @@ func RelayProxyHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
|||||||
}
|
}
|
||||||
adaptor.Init(meta)
|
adaptor.Init(meta)
|
||||||
|
|
||||||
resp, err := adaptor.DoRequest(c, meta, c.Request.Body)
|
requestBody, err := io.ReadAll(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf(ctx, "read proxy request body failed: %s", err.Error())
|
||||||
|
return openai.ErrorWrapper(err, "read_request_body_failed", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
requestBody = sanitizeStrictToolChoiceJSON(requestBody)
|
||||||
|
|
||||||
|
resp, err := adaptor.DoRequest(c, meta, bytes.NewBuffer(requestBody))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf(ctx, "DoRequest failed: %s", err.Error())
|
logger.Errorf(ctx, "DoRequest failed: %s", err.Error())
|
||||||
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||||
|
|||||||
@ -9,15 +9,12 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
"github.com/songquanpeng/one-api/common/config"
|
|
||||||
"github.com/songquanpeng/one-api/common/logger"
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
"github.com/songquanpeng/one-api/relay"
|
"github.com/songquanpeng/one-api/relay"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor"
|
"github.com/songquanpeng/one-api/relay/adaptor"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||||
"github.com/songquanpeng/one-api/relay/apitype"
|
|
||||||
"github.com/songquanpeng/one-api/relay/billing"
|
"github.com/songquanpeng/one-api/relay/billing"
|
||||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
|
||||||
"github.com/songquanpeng/one-api/relay/meta"
|
"github.com/songquanpeng/one-api/relay/meta"
|
||||||
"github.com/songquanpeng/one-api/relay/model"
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
)
|
)
|
||||||
@ -98,14 +95,9 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) {
|
func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) {
|
||||||
if !config.EnforceIncludeUsage &&
|
// Always convert request through adaptor before sending upstream.
|
||||||
meta.APIType == apitype.OpenAI &&
|
// This avoids passthrough edge cases where strict-compat sanitization may be bypassed.
|
||||||
meta.OriginModelName == meta.ActualModelName &&
|
// The slight performance cost is acceptable compared to malformed request risk.
|
||||||
meta.ChannelType != channeltype.Baichuan &&
|
|
||||||
meta.ForcedSystemPrompt == "" {
|
|
||||||
// no need to convert request for openai
|
|
||||||
return c.Request.Body, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// get request body
|
// get request body
|
||||||
var requestBody io.Reader
|
var requestBody io.Reader
|
||||||
@ -119,7 +111,72 @@ func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralO
|
|||||||
logger.Debugf(c.Request.Context(), "converted request json_marshal_failed: %s\n", err.Error())
|
logger.Debugf(c.Request.Context(), "converted request json_marshal_failed: %s\n", err.Error())
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
jsonData = sanitizeStrictToolChoiceJSON(jsonData)
|
||||||
logger.Debugf(c.Request.Context(), "converted request: \n%s", string(jsonData))
|
logger.Debugf(c.Request.Context(), "converted request: \n%s", string(jsonData))
|
||||||
requestBody = bytes.NewBuffer(jsonData)
|
requestBody = bytes.NewBuffer(jsonData)
|
||||||
return requestBody, nil
|
return requestBody, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Final outbound guard for strict OpenAI-compatible gateways:
|
||||||
|
// - remove any disable_parallel_tool_use key at any depth
|
||||||
|
// - coerce object/array tool_choice to OpenAI-compatible string form
|
||||||
|
func sanitizeStrictToolChoiceJSON(data []byte) []byte {
|
||||||
|
var payload any
|
||||||
|
if err := json.Unmarshal(data, &payload); err != nil {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
root, ok := payload.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
cleaned := stripDisableParallelToolUseAny(root)
|
||||||
|
root, ok = cleaned.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
hasTools := false
|
||||||
|
if tools, ok := root["tools"].([]any); ok && len(tools) > 0 {
|
||||||
|
hasTools = true
|
||||||
|
}
|
||||||
|
if functions, ok := root["functions"]; ok && functions != nil {
|
||||||
|
hasTools = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc, ok := root["tool_choice"]; ok {
|
||||||
|
switch tc.(type) {
|
||||||
|
case map[string]any, []any:
|
||||||
|
if hasTools {
|
||||||
|
root["tool_choice"] = "auto"
|
||||||
|
} else {
|
||||||
|
delete(root, "tool_choice")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := json.Marshal(root)
|
||||||
|
if err != nil {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func stripDisableParallelToolUseAny(v any) any {
|
||||||
|
switch val := v.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
delete(val, "disable_parallel_tool_use")
|
||||||
|
for k, child := range val {
|
||||||
|
val[k] = stripDisableParallelToolUseAny(child)
|
||||||
|
}
|
||||||
|
return val
|
||||||
|
case []any:
|
||||||
|
for i, child := range val {
|
||||||
|
val[i] = stripDisableParallelToolUseAny(child)
|
||||||
|
}
|
||||||
|
return val
|
||||||
|
default:
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -4,6 +4,7 @@ type Message struct {
|
|||||||
Role string `json:"role,omitempty"`
|
Role string `json:"role,omitempty"`
|
||||||
Content any `json:"content,omitempty"`
|
Content any `json:"content,omitempty"`
|
||||||
ReasoningContent any `json:"reasoning_content,omitempty"`
|
ReasoningContent any `json:"reasoning_content,omitempty"`
|
||||||
|
ReasoningEncryptedContent string `json:"reasoning_encrypted_content,omitempty"`
|
||||||
Name *string `json:"name,omitempty"`
|
Name *string `json:"name,omitempty"`
|
||||||
ToolCalls []Tool `json:"tool_calls,omitempty"`
|
ToolCalls []Tool `json:"tool_calls,omitempty"`
|
||||||
ToolCallId string `json:"tool_call_id,omitempty"`
|
ToolCallId string `json:"tool_call_id,omitempty"`
|
||||||
|
|||||||
@ -1,9 +1,10 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
type Tool struct {
|
type Tool struct {
|
||||||
Id string `json:"id,omitempty"`
|
Id string `json:"id,omitempty"`
|
||||||
Type string `json:"type,omitempty"` // when splicing claude tools stream messages, it is empty
|
Type string `json:"type,omitempty"` // when splicing claude tools stream messages, it is empty
|
||||||
Function Function `json:"function"`
|
Function Function `json:"function"`
|
||||||
|
ExtraContent *ToolExtraContent `json:"extra_content,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Function struct {
|
type Function struct {
|
||||||
@ -12,3 +13,11 @@ type Function struct {
|
|||||||
Parameters any `json:"parameters,omitempty"` // request
|
Parameters any `json:"parameters,omitempty"` // request
|
||||||
Arguments any `json:"arguments,omitempty"` // response
|
Arguments any `json:"arguments,omitempty"` // response
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ToolExtraContent struct {
|
||||||
|
Google *GoogleToolExtraContent `json:"google,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GoogleToolExtraContent struct {
|
||||||
|
ThoughtSignature string `json:"thought_signature,omitempty"`
|
||||||
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user