hjjjj 8b87c3d404
Some checks failed
CI / Unit tests (push) Has been cancelled
CI / commit_lint (push) Has been cancelled
feat: enhance strict compatibility for OpenAI requests
- Implement sanitization for `tool_choice` and removal of `disable_parallel_tool_use` in request payloads.
- Introduce logging for tool choice changes in `DoRequestHelper`.
- Update `ConvertRequest` to handle tool-call compatibility and maintain structured tool history.
- Add `ThoughtSignature` to `Part` struct for better tracking of reasoning content.
- Refactor request handling in `getRequestBody` to ensure strict compliance with OpenAI API requirements.
2026-03-31 16:37:53 +08:00

157 lines
3.9 KiB
Go

package adaptor
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/client"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/meta"
"io"
"net/http"
)
// SetupCommonRequestHeader copies the Content-Type and Accept headers of the
// inbound request c.Request onto the outbound request req.
//
// When meta.IsStream is set and the inbound request carries no Accept header,
// Accept is set to "text/event-stream" instead.
func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) {
	inbound := c.Request.Header
	req.Header.Set("Content-Type", inbound.Get("Content-Type"))

	// Resolve the final Accept value first, then write it once.
	accept := inbound.Get("Accept")
	if meta.IsStream && accept == "" {
		accept = "text/event-stream"
	}
	req.Header.Set("Accept", accept)
}
// DoRequestHelper builds the outbound request for adaptor a and sends it.
//
// The steps run in this order:
//  1. resolve the target URL with a.GetRequestURL;
//  2. if requestBody is non-nil, read it fully, pass it through
//     sanitizeStrictToolChoicePayload, and log the tool_choice value and the
//     presence of "disable_parallel_tool_use" before and after sanitizing;
//  3. create the request with the inbound request's method;
//  4. let the adaptor populate headers via a.SetupRequestHeader;
//  5. send it with DoRequest.
//
// Every failure is returned wrapped with the name of the step that failed.
func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
	targetURL, err := a.GetRequestURL(meta)
	if err != nil {
		return nil, fmt.Errorf("get request url failed: %w", err)
	}

	if requestBody != nil {
		payload, readErr := io.ReadAll(requestBody)
		if readErr != nil {
			return nil, fmt.Errorf("read request body failed: %w", readErr)
		}

		// Snapshot the state before sanitizing so the log line can show
		// what the guard actually changed.
		marker := []byte("disable_parallel_tool_use")
		markerBefore := bytes.Contains(payload, marker)
		choiceBefore := extractToolChoiceForLog(payload)

		payload = sanitizeStrictToolChoicePayload(payload)

		markerAfter := bytes.Contains(payload, marker)
		choiceAfter := extractToolChoiceForLog(payload)

		logger.Infof(
			c.Request.Context(),
			"[DoRequestHelper] outbound %s %s tool_choice(before=%s, after=%s) disable_parallel(before=%t, after=%t)",
			c.Request.Method,
			c.Request.URL.Path,
			choiceBefore,
			choiceAfter,
			markerBefore,
			markerAfter,
		)

		// The original reader has been drained; send the sanitized bytes.
		requestBody = bytes.NewBuffer(payload)
	}

	upstreamReq, err := http.NewRequest(c.Request.Method, targetURL, requestBody)
	if err != nil {
		return nil, fmt.Errorf("new request failed: %w", err)
	}

	if err = a.SetupRequestHeader(c, upstreamReq, meta); err != nil {
		return nil, fmt.Errorf("setup request header failed: %w", err)
	}

	upstreamResp, err := DoRequest(c, upstreamReq)
	if err != nil {
		return nil, fmt.Errorf("do request failed: %w", err)
	}
	return upstreamResp, nil
}
// DoRequest sends req through the shared client.HTTPClient and returns the
// upstream response.
//
// It returns the client's error unchanged when the round trip fails, and a
// "resp is nil" error if the client reports success without a response. On
// success the outbound request body and the inbound request body are closed
// when present; the caller remains responsible for closing resp.Body.
func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
	resp, err := client.HTTPClient.Do(req)
	if err != nil {
		return nil, err
	}
	if resp == nil {
		return nil, errors.New("resp is nil")
	}
	// A request created with a nil body reader has a nil Body, so closing it
	// unconditionally would panic for body-less requests.
	if req.Body != nil {
		// Best-effort cleanup: the response has already been received, so a
		// close error here is not actionable.
		_ = req.Body.Close()
	}
	if c.Request.Body != nil {
		// Best-effort cleanup, same reasoning as above.
		_ = c.Request.Body.Close()
	}
	return resp, nil
}
// sanitizeStrictToolChoicePayload is the global outbound guard for strict
// OpenAI-compatible gateways. Given a JSON object payload it:
//
//   - removes every "disable_parallel_tool_use" key, at any nesting depth;
//   - normalizes an object- or array-valued "tool_choice": it becomes "auto"
//     when the payload declares tools (a non-empty "tools" array or a non-null
//     "functions" value) and is removed otherwise. String and null values are
//     left untouched.
//
// The returned bytes are the re-encoded object. data is returned unchanged
// when it is not a single valid JSON object (invalid JSON, trailing content,
// or a non-object top-level value) or when re-encoding fails.
func sanitizeStrictToolChoicePayload(data []byte) []byte {
	// Decode numbers as json.Number rather than float64 so that numeric
	// literals are re-encoded verbatim. With the default float64 decoding,
	// integers above 2^53 would be silently altered by the round trip.
	dec := json.NewDecoder(bytes.NewReader(data))
	dec.UseNumber()

	var payload any
	if err := dec.Decode(&payload); err != nil {
		return data
	}
	// A Decoder stops after the first JSON value. Reject anything but JSON
	// whitespace after it so malformed bodies are still passed through
	// untouched, as json.Unmarshal would have done.
	if len(bytes.Trim(data[dec.InputOffset():], " \t\r\n")) != 0 {
		return data
	}

	root, ok := payload.(map[string]any)
	if !ok {
		return data
	}
	root, ok = stripDisableParallelKeyAny(root).(map[string]any)
	if !ok {
		return data
	}

	hasTools := false
	if tools, ok := root["tools"].([]any); ok && len(tools) > 0 {
		hasTools = true
	}
	if functions, ok := root["functions"]; ok && functions != nil {
		hasTools = true
	}

	if tc, ok := root["tool_choice"]; ok {
		switch tc.(type) {
		case map[string]any, []any:
			// Structured tool_choice values are not accepted by strict
			// gateways: fall back to "auto", or drop the field entirely when
			// there are no tools to choose from.
			if hasTools {
				root["tool_choice"] = "auto"
			} else {
				delete(root, "tool_choice")
			}
		}
	}

	out, err := json.Marshal(root)
	if err != nil {
		return data
	}
	return out
}
// stripDisableParallelKeyAny walks a decoded JSON value and deletes every
// "disable_parallel_tool_use" key from every object it contains, at any depth.
//
// Objects (map[string]any) and arrays ([]any) are modified in place and the
// same container is returned; any other value is returned as is.
func stripDisableParallelKeyAny(v any) any {
	if obj, isObject := v.(map[string]any); isObject {
		delete(obj, "disable_parallel_tool_use")
		for key := range obj {
			obj[key] = stripDisableParallelKeyAny(obj[key])
		}
		return obj
	}

	if list, isArray := v.([]any); isArray {
		for idx := range list {
			list[idx] = stripDisableParallelKeyAny(list[idx])
		}
		return list
	}

	// Scalars and null carry no keys to strip.
	return v
}
// extractToolChoiceForLog renders the top-level "tool_choice" field of a JSON
// object payload as a string suitable for a log line.
//
// It returns:
//   - "<invalid-json>" when data cannot be decoded as a JSON object;
//   - "<absent>" when the object has no "tool_choice" key;
//   - the value itself when it is a JSON string;
//   - the JSON encoding of the value otherwise, or "<unmarshalable>" if that
//     encoding fails.
func extractToolChoiceForLog(data []byte) string {
	var body map[string]any
	if json.Unmarshal(data, &body) != nil {
		return "<invalid-json>"
	}

	choice, present := body["tool_choice"]
	if !present {
		return "<absent>"
	}

	// Plain strings are logged without surrounding JSON quotes.
	if text, isString := choice.(string); isString {
		return text
	}

	encoded, err := json.Marshal(choice)
	if err != nil {
		return "<unmarshalable>"
	}
	return string(encoded)
}