hjjjj 8b87c3d404
Some checks failed
CI / Unit tests (push) Has been cancelled
CI / commit_lint (push) Has been cancelled
feat: enhance strict compatibility for OpenAI requests
- Implement sanitization for `tool_choice` and removal of `disable_parallel_tool_use` in request payloads.
- Introduce logging for tool choice changes in `DoRequestHelper`.
- Update `ConvertRequest` to handle tool-call compatibility and maintain structured tool history.
- Add `ThoughtSignature` to `Part` struct for better tracking of reasoning content.
- Refactor request handling in `getRequestBody` to ensure strict compliance with OpenAI API requirements.
2026-03-31 16:37:53 +08:00

836 lines
26 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package gemini
import (
"bufio"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/songquanpeng/one-api/common/render"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/gin-gonic/gin"
)
// https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn
const (
	// VisionMaxImageNum caps how many inline images are forwarded per request;
	// extra images are silently dropped in ConvertRequest.
	VisionMaxImageNum = 16
)

// mimeTypeMap maps OpenAI response_format types to Gemini response MIME types.
var mimeTypeMap = map[string]string{
	"json_object": "application/json",
	"text":        "text/plain",
}
// sanitizeSchema recursively removes JSON Schema keywords unsupported by Gemini
// (e.g. "const", "$schema", "additionalProperties") from a schema map.
// Maps and slices are modified in place; any other value is returned as-is.
func sanitizeSchema(v interface{}) interface{} {
	if obj, ok := v.(map[string]interface{}); ok {
		// Only remove fields Gemini explicitly rejects; leave others intact.
		delete(obj, "const")
		delete(obj, "$schema")
		delete(obj, "additionalProperties")
		for key, nested := range obj {
			obj[key] = sanitizeSchema(nested)
		}
		return obj
	}
	if list, ok := v.([]interface{}); ok {
		for idx := range list {
			list[idx] = sanitizeSchema(list[idx])
		}
		return list
	}
	return v
}
// Setting safety to the lowest possible values since Gemini is already powerless enough
//
// ConvertRequest converts an OpenAI-style chat completion request into a Gemini
// generateContent request: safety settings, generation config, response format,
// tool declarations, and the full message history (including structured
// functionCall/functionResponse tool rounds).
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
	messagesForConvert := textRequest.Messages
	geminiRequest := ChatRequest{
		Contents: make([]ChatContent, 0, len(textRequest.Messages)),
		SafetySettings: []ChatSafetySettings{
			{
				Category:  "HARM_CATEGORY_HARASSMENT",
				Threshold: config.GeminiSafetySetting,
			},
			{
				Category:  "HARM_CATEGORY_HATE_SPEECH",
				Threshold: config.GeminiSafetySetting,
			},
			{
				Category:  "HARM_CATEGORY_SEXUALLY_EXPLICIT",
				Threshold: config.GeminiSafetySetting,
			},
			{
				Category:  "HARM_CATEGORY_DANGEROUS_CONTENT",
				Threshold: config.GeminiSafetySetting,
			},
			{
				Category:  "HARM_CATEGORY_CIVIC_INTEGRITY",
				Threshold: config.GeminiSafetySetting,
			},
		},
		GenerationConfig: ChatGenerationConfig{
			Temperature:     textRequest.Temperature,
			TopP:            textRequest.TopP,
			MaxOutputTokens: textRequest.MaxTokens,
		},
	}
	// Map OpenAI response_format onto Gemini's response MIME type / schema.
	if textRequest.ResponseFormat != nil {
		if mimeType, ok := mimeTypeMap[textRequest.ResponseFormat.Type]; ok {
			geminiRequest.GenerationConfig.ResponseMimeType = mimeType
		}
		if textRequest.ResponseFormat.JsonSchema != nil {
			// A JSON schema implies structured output; force the JSON MIME type.
			geminiRequest.GenerationConfig.ResponseSchema = textRequest.ResponseFormat.JsonSchema.Schema
			geminiRequest.GenerationConfig.ResponseMimeType = mimeTypeMap["json_object"]
		}
	}
	// For models that support image generation (e.g. gemini-2.5-flash-image),
	// request both TEXT and IMAGE modalities so the model returns inline images.
	if strings.Contains(strings.ToLower(textRequest.Model), "image") {
		geminiRequest.GenerationConfig.ResponseModalities = []string{"TEXT", "IMAGE"}
	}
	// Tools take precedence over the legacy functions field.
	if textRequest.Tools != nil {
		functions := make([]model.Function, 0, len(textRequest.Tools))
		for _, tool := range textRequest.Tools {
			fn := tool.Function
			if fn.Parameters != nil {
				// Strip JSON Schema keywords Gemini rejects (see sanitizeSchema).
				fn.Parameters = sanitizeSchema(fn.Parameters)
			}
			functions = append(functions, fn)
		}
		geminiRequest.Tools = []ChatTools{
			{
				FunctionDeclarations: functions,
			},
		}
	} else if textRequest.Functions != nil {
		geminiRequest.Tools = []ChatTools{
			{
				FunctionDeclarations: textRequest.Functions,
			},
		}
	}
	// Tool-call compatibility mode:
	// Keep tool rounds on strict structured functionCall/functionResponse flow.
	// Gemini thinking tool rounds behind gateways are prone to thought_signature
	// validation failures; disable thinking config for tool flow to avoid fallback
	// text replay and preserve structured tool history.
	hasToolFlow := len(geminiRequest.Tools) > 0
	if !hasToolFlow {
		// No tools declared on this request; still treat it as a tool flow if
		// the history already contains assistant tool calls.
		for _, msg := range messagesForConvert {
			if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
				hasToolFlow = true
				break
			}
		}
	}
	if hasToolFlow {
		// Remove incompatible historical tool-call turns that cannot satisfy
		// Gemini thought_signature requirements. Keep structured tool flow for
		// the remaining valid turns (no text downgrade replay).
		messagesForConvert = sanitizeGeminiToolHistoryMessages(messagesForConvert)
	}
	// Enable thinking only for non-tool turns.
	if textRequest.EnableThinking && !hasToolFlow {
		// Use thinkingBudget=-1 (dynamic) so Gemini decides the appropriate budget.
		geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ThinkingBudget: -1}
	} else if hasToolFlow {
		geminiRequest.GenerationConfig.ThinkingConfig = nil
	}
	// Build a map from tool_call_id → function name for resolving tool result names
	toolCallIdToName := map[string]string{}
	// Track the most recent thought signature seen anywhere in the history so
	// later tool calls without one can inherit it.
	lastKnownThoughtSignature := ""
	for _, message := range messagesForConvert {
		if message.ReasoningEncryptedContent != "" {
			lastKnownThoughtSignature = message.ReasoningEncryptedContent
		}
		if message.Role == "assistant" {
			for _, tc := range message.ToolCalls {
				if tc.Id != "" && tc.Function.Name != "" {
					toolCallIdToName[tc.Id] = tc.Function.Name
				}
				if tc.ExtraContent != nil && tc.ExtraContent.Google != nil && tc.ExtraContent.Google.ThoughtSignature != "" {
					lastKnownThoughtSignature = tc.ExtraContent.Google.ThoughtSignature
				}
			}
		}
	}
	shouldAddDummyModelMessage := false
	for i := 0; i < len(messagesForConvert); i++ {
		message := messagesForConvert[i]
		// --- tool result: role=tool → Gemini functionResponse (user role) ---
		// Gemini requires functionResponse parts to match functionCall parts from the
		// preceding model turn. For OpenAI-compat payloads with multiple consecutive
		// role=tool messages, merge them into a single user turn with multiple parts.
		if message.Role == "tool" {
			var responseParts []Part
			for ; i < len(messagesForConvert) && messagesForConvert[i].Role == "tool"; i++ {
				toolMessage := messagesForConvert[i]
				// Resolve function name: mapped call id → explicit name → raw id.
				toolName := toolMessage.ToolCallId
				if name, ok := toolCallIdToName[toolMessage.ToolCallId]; ok {
					toolName = name
				} else if toolMessage.Name != nil && *toolMessage.Name != "" {
					toolName = *toolMessage.Name
				}
				if toolName == "" {
					toolName = "unknown_tool"
				}
				responseParts = append(responseParts, Part{
					FunctionResponse: &FunctionResponse{
						Name:     toolName,
						Response: map[string]any{"content": toolMessage.StringContent()},
					},
				})
			}
			if len(responseParts) > 0 {
				geminiRequest.Contents = append(geminiRequest.Contents, ChatContent{
					Role:  "user",
					Parts: responseParts,
				})
			}
			// Step back one: the inner loop stopped on the first non-tool message
			// and the outer loop's i++ would otherwise skip it.
			i--
			continue
		}
		content := ChatContent{
			Role: message.Role,
			Parts: []Part{
				{
					Text: message.StringContent(),
				},
			},
		}
		openaiContent := message.ParseContent()
		var parts []Part
		imageNum := 0
		for _, part := range openaiContent {
			if part.Type == model.ContentTypeText {
				parts = append(parts, Part{
					Text: part.Text,
				})
			} else if part.Type == model.ContentTypeImageURL {
				mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
				// Only count images toward the image limit; video/audio have no such limit
				isImage := strings.HasPrefix(mimeType, "image/")
				if isImage {
					imageNum += 1
					if imageNum > VisionMaxImageNum {
						// Drop images beyond the per-request limit.
						continue
					}
				}
				parts = append(parts, Part{
					InlineData: &InlineData{
						MimeType: mimeType,
						Data:     data,
					},
				})
			} else if part.Type == model.ContentTypeVideoURL {
				mimeType, data, _ := image.GetImageFromUrl(part.VideoURL.Url)
				if data != "" {
					parts = append(parts, Part{
						InlineData: &InlineData{
							MimeType: mimeType,
							Data:     data,
						},
					})
				}
			} else if part.Type == model.ContentTypeInputAudio {
				// input_audio: { data: "base64...", format: "mp3" }
				// Convert directly to Gemini inlineData — bypasses Zenmux fileUri conversion
				// that occurs when audio is embedded in image_url.
				if part.InputAudio != nil && part.InputAudio.Data != "" {
					mimeType := "audio/" + part.InputAudio.Format
					if part.InputAudio.Format == "" {
						mimeType = "audio/webm"
					}
					parts = append(parts, Part{
						InlineData: &InlineData{
							MimeType: mimeType,
							Data:     part.InputAudio.Data,
						},
					})
				}
			}
		}
		// --- assistant with tool_calls → Gemini functionCall parts ---
		if message.Role == "assistant" && len(message.ToolCalls) > 0 {
			var fcParts []Part
			// Include any text content first
			for _, p := range parts {
				if p.Text != "" {
					fcParts = append(fcParts, p)
				}
			}
			for _, tc := range message.ToolCalls {
				// Arguments arrive as a JSON string; fall back to {} on bad input.
				var args any
				if argStr, ok := tc.Function.Arguments.(string); ok && argStr != "" {
					if err := json.Unmarshal([]byte(argStr), &args); err != nil {
						args = map[string]any{}
					}
				} else {
					args = map[string]any{}
				}
				// Thought signature fallback chain: per-call → per-message → last seen.
				thoughtSignature := ""
				if tc.ExtraContent != nil && tc.ExtraContent.Google != nil {
					thoughtSignature = tc.ExtraContent.Google.ThoughtSignature
				}
				if thoughtSignature == "" {
					thoughtSignature = message.ReasoningEncryptedContent
				}
				if thoughtSignature == "" {
					thoughtSignature = lastKnownThoughtSignature
				}
				if thoughtSignature != "" {
					lastKnownThoughtSignature = thoughtSignature
				}
				fcParts = append(fcParts, Part{
					FunctionCall: &FunctionCall{
						FunctionName: tc.Function.Name,
						Arguments:    args,
					},
					ThoughtSignature: thoughtSignature,
				})
			}
			content.Role = "model"
			content.Parts = fcParts
			geminiRequest.Contents = append(geminiRequest.Contents, content)
			continue
		}
		content.Parts = parts
		// there's no assistant role in gemini and API shall vomit if Role is not user or model
		if content.Role == "assistant" {
			content.Role = "model"
		}
		// Converting system prompt to prompt from user for the same reason
		if content.Role == "system" {
			shouldAddDummyModelMessage = true
			if IsModelSupportSystemInstruction(textRequest.Model) {
				geminiRequest.SystemInstruction = &content
				geminiRequest.SystemInstruction.Role = ""
				continue
			} else {
				content.Role = "user"
			}
		}
		geminiRequest.Contents = append(geminiRequest.Contents, content)
		// If a system message is the last message, we need to add a dummy model message to make gemini happy
		if shouldAddDummyModelMessage {
			geminiRequest.Contents = append(geminiRequest.Contents, ChatContent{
				Role: "model",
				Parts: []Part{
					{
						Text: "Okay",
					},
				},
			})
			shouldAddDummyModelMessage = false
		}
	}
	return &geminiRequest
}
// sanitizeGeminiToolHistoryMessages drops historical assistant tool calls that
// cannot satisfy Gemini's thought_signature requirement — together with their
// matching role=tool result messages — and propagates the last known signature
// onto tool calls that lack one. Messages outside tool rounds pass through
// unchanged. The input slice itself is not modified; a new slice is returned.
func sanitizeGeminiToolHistoryMessages(messages []model.Message) []model.Message {
	if len(messages) == 0 {
		return messages
	}
	sanitized := make([]model.Message, 0, len(messages))
	// Call IDs whose assistant turn was dropped; their tool results must go too.
	droppedToolCallIDs := make(map[string]struct{})
	lastKnownThoughtSignature := ""
	for _, message := range messages {
		if message.ReasoningEncryptedContent != "" {
			lastKnownThoughtSignature = message.ReasoningEncryptedContent
		}
		if message.Role == "assistant" && len(message.ToolCalls) > 0 {
			keptToolCalls := make([]model.Tool, 0, len(message.ToolCalls))
			for _, tc := range message.ToolCalls {
				// Signature fallback chain: per-call → per-message → last seen.
				thoughtSignature := ""
				if tc.ExtraContent != nil && tc.ExtraContent.Google != nil {
					thoughtSignature = strings.TrimSpace(tc.ExtraContent.Google.ThoughtSignature)
				}
				if thoughtSignature == "" {
					thoughtSignature = strings.TrimSpace(message.ReasoningEncryptedContent)
				}
				if thoughtSignature == "" {
					thoughtSignature = strings.TrimSpace(lastKnownThoughtSignature)
				}
				if thoughtSignature == "" {
					// No signature anywhere: this call cannot be replayed; drop it
					// and remember its ID so the matching tool result is dropped.
					if tc.Id != "" {
						droppedToolCallIDs[tc.Id] = struct{}{}
					}
					continue
				}
				// Write the resolved signature back onto the call (tc is a copy,
				// so only the kept slice sees the mutation).
				if tc.ExtraContent == nil {
					tc.ExtraContent = &model.ToolExtraContent{}
				}
				if tc.ExtraContent.Google == nil {
					tc.ExtraContent.Google = &model.GoogleToolExtraContent{}
				}
				tc.ExtraContent.Google.ThoughtSignature = thoughtSignature
				lastKnownThoughtSignature = thoughtSignature
				keptToolCalls = append(keptToolCalls, tc)
			}
			if len(keptToolCalls) == 0 {
				// Every call was dropped: keep the turn only if it carried text.
				assistantText := strings.TrimSpace(message.StringContent())
				if assistantText == "" {
					continue
				}
				message.ToolCalls = nil
				message.Content = assistantText
				sanitized = append(sanitized, message)
				continue
			}
			message.ToolCalls = keptToolCalls
			sanitized = append(sanitized, message)
			continue
		}
		// Drop tool results whose originating call was dropped above.
		if message.Role == "tool" && message.ToolCallId != "" {
			if _, dropped := droppedToolCallIDs[message.ToolCallId]; dropped {
				continue
			}
		}
		sanitized = append(sanitized, message)
	}
	return sanitized
}
// ConvertEmbeddingRequest converts an OpenAI-style embedding request into the
// Gemini batch embedding payload, emitting one EmbeddingRequest per input.
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *BatchEmbeddingRequest {
	inputs := request.ParseInput()
	requests := make([]EmbeddingRequest, len(inputs))
	// Gemini expects the fully-qualified resource name ("models/<name>").
	// Named modelName (not "model") to avoid shadowing the imported model package.
	modelName := fmt.Sprintf("models/%s", request.Model)
	for i, input := range inputs {
		requests[i] = EmbeddingRequest{
			Model: modelName,
			Content: ChatContent{
				Parts: []Part{
					{
						Text: input,
					},
				},
			},
		}
	}
	return &BatchEmbeddingRequest{
		Requests: requests,
	}
}
// TokensDetail breaks down candidate token usage by modality (e.g. "IMAGE");
// used by StreamHandler to bill image output tokens at a higher rate.
type TokensDetail struct {
	Modality   string `json:"modality"`
	TokenCount int    `json:"tokenCount"`
}

// UsageMetadata mirrors the usageMetadata block Gemini attaches to responses.
type UsageMetadata struct {
	PromptTokenCount        int            `json:"promptTokenCount"`
	CandidatesTokenCount    int            `json:"candidatesTokenCount"`
	TotalTokenCount         int            `json:"totalTokenCount"`
	ThoughtsTokenCount      int            `json:"thoughtsTokenCount"`
	CandidatesTokensDetails []TokensDetail `json:"candidatesTokensDetails"`
}

// ChatResponse is the top-level Gemini generateContent response shape.
type ChatResponse struct {
	Id             string             `json:"id,omitempty"` // set by some proxies; used for metadata fetches
	Candidates     []ChatCandidate    `json:"candidates"`
	PromptFeedback ChatPromptFeedback `json:"promptFeedback"`
	UsageMetadata  *UsageMetadata     `json:"usageMetadata"`
}
// GetResponseText returns the text of the first part of the first candidate,
// or "" when the receiver is nil or no candidate/part is available.
func (g *ChatResponse) GetResponseText() string {
	if g == nil || len(g.Candidates) == 0 {
		return ""
	}
	firstParts := g.Candidates[0].Content.Parts
	if len(firstParts) == 0 {
		return ""
	}
	return firstParts[0].Text
}
// ChatCandidate is one generated candidate within a Gemini response.
type ChatCandidate struct {
	Content       ChatContent        `json:"content"`
	FinishReason  string             `json:"finishReason"`
	Index         int64              `json:"index"`
	SafetyRatings []ChatSafetyRating `json:"safetyRatings"`
}

// ChatSafetyRating reports the assessed probability of harm for one category.
type ChatSafetyRating struct {
	Category    string `json:"category"`
	Probability string `json:"probability"`
}

// ChatPromptFeedback carries safety ratings attached to the prompt itself.
type ChatPromptFeedback struct {
	SafetyRatings []ChatSafetyRating `json:"safetyRatings"`
}
// getToolCalls converts the functionCall parts of a Gemini candidate into
// OpenAI-style tool calls, preserving Google thought signatures when present.
// Parts that are not function calls are ignored.
func getToolCalls(candidate *ChatCandidate) []model.Tool {
	var toolCalls []model.Tool
	for _, item := range candidate.Content.Parts {
		if item.FunctionCall == nil {
			continue
		}
		argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
		if err != nil {
			// Skip this call on a malformed argument payload instead of taking
			// down the process: the previous FatalLog was fatal-by-name, which
			// also made the following continue unreachable dead code.
			logger.SysError("getToolCalls failed: " + err.Error())
			continue
		}
		toolCall := model.Tool{
			Id:   fmt.Sprintf("call_%s", random.GetUUID()),
			Type: "function",
			Function: model.Function{
				Arguments: string(argsBytes),
				Name:      item.FunctionCall.FunctionName,
			},
		}
		if item.ThoughtSignature != "" {
			toolCall.ExtraContent = &model.ToolExtraContent{
				Google: &model.GoogleToolExtraContent{
					ThoughtSignature: item.ThoughtSignature,
				},
			}
		}
		toolCalls = append(toolCalls, toolCall)
	}
	return toolCalls
}
// responseGeminiChat2OpenAI maps a non-streaming Gemini response onto the
// OpenAI chat-completion response shape: tool calls take precedence; otherwise
// multi-part text is joined with newlines; an empty candidate surfaces
// Gemini's own finish reason.
func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
	converted := openai.TextResponse{
		Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
		Object:  "chat.completion",
		Created: helper.GetTimestamp(),
		Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
	}
	for idx, candidate := range response.Candidates {
		choice := openai.TextResponseChoice{
			Index:        idx,
			Message:      model.Message{Role: "assistant"},
			FinishReason: constant.StopFinishReason,
		}
		if len(candidate.Content.Parts) == 0 {
			// No parts at all (e.g. blocked output): pass Gemini's finish reason through.
			choice.Message.Content = ""
			choice.FinishReason = candidate.FinishReason
			converted.Choices = append(converted.Choices, choice)
			continue
		}
		if calls := getToolCalls(&candidate); len(calls) > 0 {
			choice.Message.ToolCalls = calls
		} else {
			texts := make([]string, 0, len(candidate.Content.Parts))
			for _, part := range candidate.Content.Parts {
				texts = append(texts, part.Text)
			}
			choice.Message.Content = strings.Join(texts, "\n")
		}
		converted.Choices = append(converted.Choices, choice)
	}
	return &converted
}
// streamResponseGeminiChat2OpenAI converts one Gemini stream chunk into an
// OpenAI chat.completion.chunk: thought parts go to reasoning_content,
// functionCall parts become tool_calls (carrying Google thought signatures),
// text accumulates into the content delta, and inline images are embedded as
// markdown data URIs. Only the first candidate is considered.
func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
	var choice openai.ChatCompletionsStreamResponseChoice
	if len(geminiResponse.Candidates) > 0 {
		var textBuilder strings.Builder
		var thinkingBuilder strings.Builder
		var toolCalls []model.Tool
		for _, part := range geminiResponse.Candidates[0].Content.Parts {
			switch {
			case part.Thought:
				// Thinking/reasoning content — route to reasoning_content field
				thinkingBuilder.WriteString(part.Text)
			case part.FunctionCall != nil:
				argsBytes, err := json.Marshal(part.FunctionCall.Arguments)
				if err != nil {
					// Log and fall back to empty arguments instead of crashing:
					// the previous FatalLog was fatal-by-name, which made this
					// "{}" fallback unreachable and killed the gateway on one
					// malformed tool argument payload.
					logger.SysError("streamResponseGeminiChat2OpenAI marshal args failed: " + err.Error())
					argsBytes = []byte("{}")
				}
				var extra *model.ToolExtraContent
				if part.ThoughtSignature != "" {
					extra = &model.ToolExtraContent{
						Google: &model.GoogleToolExtraContent{
							ThoughtSignature: part.ThoughtSignature,
						},
					}
				}
				toolCalls = append(toolCalls, model.Tool{
					Id:   fmt.Sprintf("call_%s", random.GetUUID()),
					Type: "function",
					Function: model.Function{
						Name:      part.FunctionCall.FunctionName,
						Arguments: string(argsBytes),
					},
					ExtraContent: extra,
				})
			case part.Text != "":
				textBuilder.WriteString(part.Text)
			case part.InlineData != nil && part.InlineData.Data != "":
				// Inline image — embed as markdown data-URI so it passes through the SSE pipeline
				mimeType := part.InlineData.MimeType
				if mimeType == "" {
					mimeType = "image/png"
				}
				textBuilder.WriteString(fmt.Sprintf("![generated](data:%s;base64,%s)", mimeType, part.InlineData.Data))
			}
		}
		if textBuilder.Len() > 0 {
			choice.Delta.Content = textBuilder.String()
		}
		if thinkingBuilder.Len() > 0 {
			choice.Delta.ReasoningContent = thinkingBuilder.String()
		}
		if len(toolCalls) > 0 {
			choice.Delta.ToolCalls = toolCalls
			toolFinish := "tool_calls"
			choice.FinishReason = &toolFinish
		}
	}
	var response openai.ChatCompletionsStreamResponse
	response.Id = fmt.Sprintf("chatcmpl-%s", random.GetUUID())
	response.Created = helper.GetTimestamp()
	response.Object = "chat.completion.chunk"
	response.Model = "gemini"
	response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
	return &response
}
// embeddingResponseGemini2OpenAI maps a Gemini batch embedding response onto
// the OpenAI embedding response shape. Token usage is not reported by this
// endpoint, so Usage is left at zero.
func embeddingResponseGemini2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
	openAIEmbeddingResponse := openai.EmbeddingResponse{
		Object: "list",
		Data:   make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)),
		Model:  "gemini-embedding",
		Usage:  model.Usage{TotalTokens: 0},
	}
	for i, item := range response.Embeddings {
		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
			Object: "embedding",
			// Each item carries its own position; the original hard-coded 0 for
			// every entry, which misorders multi-input embedding responses.
			Index:     i,
			Embedding: item.Values,
		})
	}
	return &openAIEmbeddingResponse
}
// StreamHandler processes a Gemini SSE stream and returns
// (error, usage, generationId).
// generationId is captured from the response's "id" field if present (set by
// some proxies); callers can use it for post-response metadata fetches.
// Usage comes from the last usageMetadata chunk when available; otherwise it
// is estimated from the accumulated response text.
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage, string) {
	var usage *model.Usage
	var generationId string
	responseText := ""
	outputImageCount := 0
	scanner := bufio.NewScanner(resp.Body)
	// Default bufio.Scanner buffer is 64KB which is too small for inline image data (base64).
	// Allocate 20MB to handle large image payloads from Gemini image-generation models.
	const maxScanTokenSize = 20 * 1024 * 1024
	scanner.Buffer(make([]byte, maxScanTokenSize), maxScanTokenSize)
	scanner.Split(bufio.ScanLines)
	common.SetEventStreamHeaders(c)
	for scanner.Scan() {
		data := scanner.Text()
		data = strings.TrimSpace(data)
		// Only SSE "data:" lines carry payloads; skip keep-alives and comments.
		if !strings.HasPrefix(data, "data: ") {
			continue
		}
		data = strings.TrimPrefix(data, "data: ")
		// NOTE(review): strips a trailing quote from the payload — presumably
		// some upstream proxies append one; confirm against real traffic.
		data = strings.TrimSuffix(data, "\"")
		var geminiResponse ChatResponse
		err := json.Unmarshal([]byte(data), &geminiResponse)
		if err != nil {
			// Tolerate malformed chunks; keep streaming the rest.
			logger.SysError("error unmarshalling stream response: " + err.Error())
			continue
		}
		// Remember the first non-empty id for post-response metadata fetches.
		if generationId == "" && geminiResponse.Id != "" {
			generationId = geminiResponse.Id
		}
		// Extract usageMetadata from the last chunk that carries it.
		if geminiResponse.UsageMetadata != nil {
			meta := geminiResponse.UsageMetadata
			// Image output tokens are priced at $60/M vs text completion $3/M (20× ratio).
			// Separate image and text tokens from candidatesTokensDetails so we can
			// represent image tokens as equivalent text completion tokens for billing.
			imageOutputTokens := 0
			textOutputTokens := 0
			if len(meta.CandidatesTokensDetails) > 0 {
				for _, d := range meta.CandidatesTokensDetails {
					if d.Modality == "IMAGE" {
						imageOutputTokens += d.TokenCount
					} else {
						textOutputTokens += d.TokenCount
					}
				}
			} else {
				// No modality breakdown: treat all candidate tokens as text.
				textOutputTokens = meta.CandidatesTokenCount
			}
			// ThoughtsTokenCount billed at text completion rate; image tokens at 20× that rate.
			const imageToTextRatio = 20
			completionTokens := textOutputTokens + imageOutputTokens*imageToTextRatio + meta.ThoughtsTokenCount
			usage = &model.Usage{
				PromptTokens:     meta.PromptTokenCount,
				CompletionTokens: completionTokens,
				TotalTokens:      meta.TotalTokenCount,
			}
		}
		response := streamResponseGeminiChat2OpenAI(&geminiResponse)
		if response == nil {
			continue
		}
		// Accumulate text for fallback token estimation (used only when
		// usageMetadata is absent from the stream). Also count output images.
		if len(geminiResponse.Candidates) > 0 {
			for _, part := range geminiResponse.Candidates[0].Content.Parts {
				if part.InlineData == nil {
					responseText += part.Text
				} else if strings.HasPrefix(part.InlineData.MimeType, "image/") {
					outputImageCount++
				}
			}
		}
		err = render.ObjectData(c, response)
		if err != nil {
			logger.SysError(err.Error())
		}
	}
	if err := scanner.Err(); err != nil {
		logger.SysError("error reading stream: " + err.Error())
	}
	render.Done(c)
	err := resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, ""
	}
	// If upstream provided usageMetadata, use it (image tokens already adjusted above).
	if usage != nil {
		return nil, usage, generationId
	}
	// Fallback: estimate usage from the accumulated response text.
	return nil, openai.ResponseText2Usage(responseText, "", 0), generationId
}
// Handler converts a non-streaming Gemini response into an OpenAI-style
// chat completion response and writes it to the client, returning the
// token usage computed from the prompt size and response text.
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	var geminiResponse ChatResponse
	err = json.Unmarshal(responseBody, &geminiResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	if len(geminiResponse.Candidates) == 0 {
		return &model.ErrorWithStatusCode{
			Error: model.Error{
				Message: "No candidates returned",
				Type:    "server_error",
				Param:   "",
				Code:    500,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
	fullTextResponse.Model = modelName
	completionTokens := openai.CountTokenText(geminiResponse.GetResponseText(), modelName)
	usage := model.Usage{
		PromptTokens:     promptTokens,
		CompletionTokens: completionTokens,
		TotalTokens:      promptTokens + completionTokens,
	}
	fullTextResponse.Usage = usage
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	// The header is already committed, so a write failure cannot change the
	// response; log it instead of silently discarding the error (the original
	// assigned err without ever checking it).
	if _, err = c.Writer.Write(jsonResponse); err != nil {
		logger.SysError("write response body failed: " + err.Error())
	}
	return nil, &usage
}
// EmbeddingHandler converts a Gemini batch embedding response into an
// OpenAI-style embedding response and writes it to the client, surfacing
// Gemini-reported errors with the upstream status code.
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	var geminiEmbeddingResponse EmbeddingResponse
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	err = json.Unmarshal(responseBody, &geminiEmbeddingResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	if geminiEmbeddingResponse.Error != nil {
		return &model.ErrorWithStatusCode{
			Error: model.Error{
				Message: geminiEmbeddingResponse.Error.Message,
				Type:    "gemini_error",
				Param:   "",
				Code:    geminiEmbeddingResponse.Error.Code,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	fullTextResponse := embeddingResponseGemini2OpenAI(&geminiEmbeddingResponse)
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	// The header is already committed, so a write failure cannot change the
	// response; log it instead of silently discarding the error (the original
	// assigned err without ever checking it).
	if _, err = c.Writer.Write(jsonResponse); err != nil {
		logger.SysError("write response body failed: " + err.Error())
	}
	return nil, &fullTextResponse.Usage
}