- Implement sanitization of `tool_choice` and removal of `disable_parallel_tool_use` in outbound request payloads.
- Introduce logging of tool-choice changes in `DoRequestHelper`.
- Update `ConvertRequest` to handle tool-call compatibility and maintain a structured tool history.
- Add `ThoughtSignature` to the `Part` struct for better tracking of reasoning content.
- Refactor request handling in `getRequestBody` to ensure strict compliance with OpenAI API requirements.
157 lines
3.9 KiB
Go
157 lines
3.9 KiB
Go
package adaptor
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/songquanpeng/one-api/common/client"
|
|
"github.com/songquanpeng/one-api/common/logger"
|
|
"github.com/songquanpeng/one-api/relay/meta"
|
|
"io"
|
|
"net/http"
|
|
)
|
|
|
|
func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) {
|
|
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
|
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
|
if meta.IsStream && c.Request.Header.Get("Accept") == "" {
|
|
req.Header.Set("Accept", "text/event-stream")
|
|
}
|
|
}
|
|
|
|
func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
|
|
fullRequestURL, err := a.GetRequestURL(meta)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get request url failed: %w", err)
|
|
}
|
|
|
|
if requestBody != nil {
|
|
raw, readErr := io.ReadAll(requestBody)
|
|
if readErr != nil {
|
|
return nil, fmt.Errorf("read request body failed: %w", readErr)
|
|
}
|
|
hasDisableParallelBefore := bytes.Contains(raw, []byte("disable_parallel_tool_use"))
|
|
beforeToolChoice := extractToolChoiceForLog(raw)
|
|
raw = sanitizeStrictToolChoicePayload(raw)
|
|
hasDisableParallelAfter := bytes.Contains(raw, []byte("disable_parallel_tool_use"))
|
|
afterToolChoice := extractToolChoiceForLog(raw)
|
|
logger.Infof(
|
|
c.Request.Context(),
|
|
"[DoRequestHelper] outbound %s %s tool_choice(before=%s, after=%s) disable_parallel(before=%t, after=%t)",
|
|
c.Request.Method,
|
|
c.Request.URL.Path,
|
|
beforeToolChoice,
|
|
afterToolChoice,
|
|
hasDisableParallelBefore,
|
|
hasDisableParallelAfter,
|
|
)
|
|
requestBody = bytes.NewBuffer(raw)
|
|
}
|
|
|
|
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("new request failed: %w", err)
|
|
}
|
|
err = a.SetupRequestHeader(c, req, meta)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("setup request header failed: %w", err)
|
|
}
|
|
resp, err := DoRequest(c, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("do request failed: %w", err)
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
|
|
resp, err := client.HTTPClient.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if resp == nil {
|
|
return nil, errors.New("resp is nil")
|
|
}
|
|
_ = req.Body.Close()
|
|
_ = c.Request.Body.Close()
|
|
return resp, nil
|
|
}
|
|
|
|
// sanitizeStrictToolChoicePayload is a global outbound guard for strict
// OpenAI-compatible gateways. For a request body that is a single JSON object
// it:
//
//   - removes every "disable_parallel_tool_use" key, at any nesting depth;
//   - replaces an object or array "tool_choice" with the string "auto" when the
//     body declares tools (a non-empty "tools" array or a non-null "functions"
//     field), and deletes it otherwise.
//
// String and null tool_choice values are kept. Anything that is not exactly
// one JSON object is returned unchanged; this includes invalid JSON, arrays,
// null and trailing data.
//
// The input slice itself is returned when no rewrite can apply: the text
// "disable_parallel_tool_use" does not occur and tool_choice is not an object
// or array. Ordinary requests are therefore forwarded byte-for-byte. When a
// re-encode is needed, numbers are decoded as json.Number, so integers above
// 2^53 (e.g. a "seed") are not rounded through float64.
//
// NOTE(review): turning an object tool_choice into "auto" discards a forced
// tool/function selection. Providers whose APIs require tool_choice to be an
// object would reject the string form. Confirm that only strict
// OpenAI-compatible upstreams reach this path.
func sanitizeStrictToolChoicePayload(data []byte) []byte {
	// If the key text never appears, no key can be stripped. A match inside,
	// say, message content only costs a harmless re-encode below.
	changed := bytes.Contains(data, []byte("disable_parallel_tool_use"))
	if !changed && !bytes.Contains(data, []byte("tool_choice")) {
		return data
	}

	dec := json.NewDecoder(bytes.NewReader(data))
	dec.UseNumber()
	var root map[string]any
	if err := dec.Decode(&root); err != nil || root == nil {
		return data
	}
	// json.Unmarshal rejected trailing data; keep that strictness.
	if _, err := dec.Token(); err != io.EOF {
		return data
	}

	stripDisableParallelKeyAny(root)

	tools, _ := root["tools"].([]any)
	hasTools := len(tools) > 0 || root["functions"] != nil

	switch root["tool_choice"].(type) {
	case map[string]any, []any:
		if hasTools {
			root["tool_choice"] = "auto"
		} else {
			delete(root, "tool_choice")
		}
		changed = true
	}

	if !changed {
		return data
	}
	out, err := json.Marshal(root)
	if err != nil {
		return data
	}
	return out
}

// stripDisableParallelKeyAny deletes every "disable_parallel_tool_use" key from
// v and from all maps and slices nested inside it, mutating them in place.
// Maps and slices are returned as the same (now modified) value; any other
// value is returned unchanged.
func stripDisableParallelKeyAny(v any) any {
	switch val := v.(type) {
	case map[string]any:
		delete(val, "disable_parallel_tool_use")
		// Assigning to keys that already exist while ranging is permitted.
		for k, child := range val {
			val[k] = stripDisableParallelKeyAny(child)
		}
		return val
	case []any:
		for i, child := range val {
			val[i] = stripDisableParallelKeyAny(child)
		}
		return val
	default:
		return v
	}
}
|
|
|
|
// extractToolChoiceForLog renders the top-level "tool_choice" field of a JSON
// request body as a short string for log lines.
//
// Results:
//   - "<invalid-json>" when data does not decode into a JSON object map;
//   - "<absent>" when the field is missing;
//   - the value itself when tool_choice is a JSON string;
//   - the value re-encoded as JSON otherwise ("<unmarshalable>" if that fails).
func extractToolChoiceForLog(data []byte) string {
	var body map[string]any
	if json.Unmarshal(data, &body) != nil {
		return "<invalid-json>"
	}

	choice, present := body["tool_choice"]
	if !present {
		return "<absent>"
	}

	if s, isString := choice.(string); isString {
		return s
	}

	encoded, err := json.Marshal(choice)
	if err != nil {
		return "<unmarshalable>"
	}
	return string(encoded)
}
|