From 5730b2a798cf0fb2b7ebaeed25a052a3da645bea Mon Sep 17 00:00:00 2001
From: hjjjj <1311711287@qq.com>
Date: Tue, 17 Mar 2026 18:28:54 +0800
Subject: [PATCH] feat: support image editing and optimize multimodal billing
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

refactor: rework the Gemini adaptor to support image editing and generation
feat(relay): add an image-edit relay mode
feat(controller): implement UsageAPIURL for fetching real token usage
feat(web): add model selection to channel testing
perf(token): optimize multimodal token counting
fix(web): fix the log pagination component display
docs: document UsageAPIURL in the channel configuration
style: clean up debug logs and comments
feat(gemini): support Imagen 3+ image generation models
feat(openai): capture generation IDs and fetch usage metadata
---
 controller/model.go                           |   2 +-
 controller/relay.go                           |   2 +-
 model/channel.go                              |   5 +
 relay/adaptor/common.go                       |   1 +
 relay/adaptor/gemini/adaptor.go               | 161 ++++++++++++++++--
 relay/adaptor/gemini/main.go                  |  67 ++++++--
 relay/adaptor/gemini/model.go                 |  55 +++++-
 relay/adaptor/openai/adaptor.go               |   2 +-
 relay/adaptor/openai/helper.go                |  62 ++++++-
 relay/adaptor/openai/main.go                  |  14 +-
 relay/adaptor/openai/token.go                 |  51 +++++-
 relay/adaptor/vertexai/gemini/adapter.go      |   2 +-
 relay/adaptor/zhipu/adaptor.go                |   2 +-
 relay/billing/ratio/image.go                  |   6 +
 relay/billing/ratio/model.go                  |   8 +
 relay/controller/text.go                      |  10 ++
 relay/meta/relay_meta.go                      |   3 +
 relay/model/image.go                          |   1 +
 relay/relaymode/define.go                     |   1 +
 relay/relaymode/helper.go                     |   2 +
 router/relay.go                               |   2 +-
 .../src/views/Channel/component/EditModal.js  |   2 +-
 .../src/views/Channel/component/TableRow.js   |  42 ++++-
 web/berry/src/views/Channel/index.js          |   2 +-
 web/berry/src/views/Channel/type/Config.js    |  22 ++-
 web/berry/src/views/Log/index.js              |  20 ++-
 26 files changed, 480 insertions(+), 67 deletions(-)

diff --git a/controller/model.go b/controller/model.go
index d12fb82..89a161f 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -475,7 +475,7 @@ type zenmuxChannelDef struct {
 var allZenmuxChannelDefs = []zenmuxChannelDef{
 	{
 		Name: "Zenmux",
-		Type: 50, // OpenAICompatible
+		Type: channeltype.OpenAICompatible,
 		Protocols: []string{"chat.completions", "responses", "messages", "gemini", "generate"},
 	},
 }
diff --git a/controller/relay.go b/controller/relay.go
index 038123b..a0448f7 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -26,7 +26,7 @@ import (
 func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
 	var err *model.ErrorWithStatusCode
 	switch relayMode {
-	case relaymode.ImagesGenerations:
+	case relaymode.ImagesGenerations, relaymode.ImagesEdits:
 		err = controller.RelayImageHelper(c, relayMode)
 	case relaymode.AudioSpeech:
 		fallthrough
diff --git a/model/channel.go b/model/channel.go
index 4b0f4b0..6d7e373 100644
--- a/model/channel.go
+++ b/model/channel.go
@@ -50,6 +50,13 @@ type ChannelConfig struct {
 	Plugin string `json:"plugin,omitempty"`
 	VertexAIProjectID string `json:"vertex_ai_project_id,omitempty"`
 	VertexAIADC string `json:"vertex_ai_adc,omitempty"`
+	// UsageAPIURL is an optional URL template for fetching accurate per-request
+	// token usage after a generation completes. Use {id} as a placeholder for
+	// the generation/request ID returned in the upstream response.
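+	// The endpoint is expected to return JSON containing a nativeTokens object
+	// (see usageMetadataResponse in relay/adaptor/openai/helper.go).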
+	// Example: https://zenmux.ai/api/v1/generation?ctoken=xxx&id={id}
+	UsageAPIURL string `json:"usage_api_url,omitempty"`
 }
 
 func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) {
diff --git a/relay/adaptor/common.go b/relay/adaptor/common.go
index 8953d7a..31b006e 100644
--- a/relay/adaptor/common.go
+++ b/relay/adaptor/common.go
@@ -23,6 +23,7 @@ func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io.
 	if err != nil {
 		return nil, fmt.Errorf("get request url failed: %w", err)
 	}
+
 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 	if err != nil {
 		return nil, fmt.Errorf("new request failed: %w", err)
diff --git a/relay/adaptor/gemini/adaptor.go b/relay/adaptor/gemini/adaptor.go
index 41e2134..1180d25 100644
--- a/relay/adaptor/gemini/adaptor.go
+++ b/relay/adaptor/gemini/adaptor.go
@@ -39,7 +39,24 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 	modelName := meta.ActualModelName
 
 	switch meta.Mode {
+	case relaymode.ImagesEdits:
+		// Image editing uses :predict (same as generation) with referenceImages in instances[]
+		if strings.Contains(modelName, "/") {
+			return fmt.Sprintf("%s/v1/%s:predict", meta.BaseURL, publisherModelPath(modelName)), nil
+		}
+		return fmt.Sprintf("%s/v1beta/models/%s:predict", meta.BaseURL, modelName), nil
 	case relaymode.ImagesGenerations:
+		// Imagen 3+ models use the :generateImages endpoint. These model names never
+		// carry a publisher prefix, so the raw name is used here (publisherModelPath
+		// would prepend a second "models/" segment).
+		if isImagen3Model(modelName) {
+			return fmt.Sprintf("%s/v1/models/%s:generateImages", meta.BaseURL, modelName), nil
+		}
+		// Publisher models (e.g. klingai/kling-v2, volcengine/doubao-*) use v1 :predict
+		if strings.Contains(modelName, "/") {
+			return fmt.Sprintf("%s/v1/%s:predict", meta.BaseURL, publisherModelPath(modelName)), nil
+		}
+		// Legacy imagegeneration models (no publisher prefix) use v1beta :predict
 		return fmt.Sprintf("%s/v1beta/models/%s:predict", meta.BaseURL, modelName), nil
 	case relaymode.Embeddings:
 		return fmt.Sprintf("%s/%s/models/%s:batchEmbedContents", meta.BaseURL, version, modelName), nil
@@ -49,7 +64,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 	if meta.IsStream {
 		action = "streamGenerateContent?alt=sse"
 	}
-	return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, modelName, action), nil
+	return fmt.Sprintf("%s/%s/%s:%s", meta.BaseURL, version, publisherModelPath(modelName), action), nil
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
@@ -80,6 +95,41 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error)
 	if n <= 0 {
 		n = 1
 	}
+	// If an image is provided, convert to the referenceImages :predict format (image-to-image)
+	if request.Image != nil {
+		imageStr, ok := request.Image.(string)
+		if !ok {
+			return nil, errors.New("image field must be a base64 string or data URL")
+		}
"data:image/png;base64,") + base64Data := imageStr + if idx := strings.Index(imageStr, ";base64,"); idx != -1 { + base64Data = imageStr[idx+8:] + } + return EditImageRequest{ + Instances: []EditImageInstance{ + { + Prompt: request.Prompt, + ReferenceImages: []ReferenceImageItem{ + { + ReferenceType: "REFERENCE_TYPE_RAW", + ReferenceId: 1, + ReferenceImage: ReferenceImageData{ + BytesBase64Encoded: base64Data, + }, + }, + }, + }, + }, + }, nil + } + if isImagen3Model(request.Model) { + return GenerateImagesRequest{ + Prompt: request.Prompt, + NumberOfImages: n, + AspectRatio: sizeToAspectRatio(request.Size), + }, nil + } return ImagenRequest{ Instances: []ImagenInstance{{Prompt: request.Prompt}}, Parameters: ImagenParameters{ @@ -89,7 +139,28 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) }, nil } -// sizeToAspectRatio converts OpenAI size strings to Imagen aspect ratios. +// publisherModelPath converts publisher/model names to the correct Vertex AI URL path segment. +// Google's own models (gemini/, google/) keep the raw /models/publisher/model format. +// Third-party publisher models use /publishers/{pub}/models/{model} format. +func publisherModelPath(modelName string) string { + if idx := strings.Index(modelName, "/"); idx != -1 { + publisher := modelName[:idx] + model := modelName[idx+1:] + if publisher == "gemini" || publisher == "google" { + return fmt.Sprintf("models/%s", modelName) + } + return fmt.Sprintf("publishers/%s/models/%s", publisher, model) + } + return fmt.Sprintf("models/%s", modelName) +} + +// isImagen3Model returns true for image-generation-only models that use the :generateImages endpoint. +func isImagen3Model(modelName string) bool { + lower := strings.ToLower(modelName) + return strings.HasPrefix(lower, "imagen-3") || + strings.HasPrefix(lower, "imagen-4") || + strings.Contains(lower, "imagegeneration@00") +} func sizeToAspectRatio(size string) string { switch size { case "1792x1024": @@ -109,7 +180,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Read func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { - err, usage = StreamHandler(c, resp) + err, usage, meta.GenerationId = StreamHandler(c, resp) if usage != nil { usage.PromptTokens = meta.PromptTokens } @@ -117,7 +188,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met switch meta.Mode { case relaymode.Embeddings: err, usage = EmbeddingHandler(c, resp) - case relaymode.ImagesGenerations: + case relaymode.ImagesGenerations, relaymode.ImagesEdits: err = ImagenHandler(c, resp) default: err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) @@ -126,7 +197,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met return } -// ImagenHandler converts a Google Imagen predict response to OpenAI image response format. +// ImagenHandler converts a Google Imagen predict/generateImages response to OpenAI image response format. +// Handles both Vertex AI format (predictions[].bytesBase64Encoded) and OpenAI format (data[].b64_json). 
 func ImagenHandler(c *gin.Context, resp *http.Response) *model.ErrorWithStatusCode {
 	responseBody, readErr := io.ReadAll(resp.Body)
 	resp.Body.Close()
@@ -134,24 +206,77 @@ func ImagenHandler(c *gin.Context, resp *http.Response) *model.ErrorWithStatusCo
 	if readErr != nil {
 		return openai.ErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
 	}
 
-	var imagenResp ImagenResponse
-	if err := json.Unmarshal(responseBody, &imagenResp); err != nil {
-		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
-	}
-
-	if imagenResp.Error != nil {
+	// Safety check: empty response
+	if len(responseBody) == 0 || string(responseBody) == "{}" || string(responseBody) == "null" {
 		return openai.ErrorWrapper(
-			fmt.Errorf("%s", imagenResp.Error.Message),
-			"imagen_error",
-			http.StatusInternalServerError,
+			fmt.Errorf("upstream returned empty response body"),
+			"empty_response_body",
+			http.StatusBadGateway,
 		)
 	}
 
-	// Convert to OpenAI image response format
-	data := make([]openai.ImageData, 0, len(imagenResp.Predictions))
-	for _, p := range imagenResp.Predictions {
-		data = append(data, openai.ImageData{B64Json: p.BytesBase64Encoded})
+	data := make([]openai.ImageData, 0)
+
+	// Format 1: :generateImages / :editImage → generatedImages[].image.imageBytes
+	var generateResp GenerateImagesResponse
+	if err := json.Unmarshal(responseBody, &generateResp); err == nil {
+		if generateResp.Error != nil {
+			return openai.ErrorWrapper(
+				fmt.Errorf("%s", generateResp.Error.Message),
+				"imagen_error",
+				http.StatusInternalServerError,
+			)
+		}
+		for _, img := range generateResp.GeneratedImages {
+			if img.Image.ImageBytes != "" {
+				data = append(data, openai.ImageData{B64Json: img.Image.ImageBytes})
+			}
+		}
 	}
+
+	// Format 2: legacy :predict → predictions[].bytesBase64Encoded or predictions[].gcsUri
+	if len(data) == 0 {
+		var imagenResp ImagenResponse
+		if err := json.Unmarshal(responseBody, &imagenResp); err == nil {
+			if imagenResp.Error != nil {
+				return openai.ErrorWrapper(
+					fmt.Errorf("%s", imagenResp.Error.Message),
+					"imagen_error",
+					http.StatusInternalServerError,
+				)
+			}
+			for _, p := range imagenResp.Predictions {
+				if p.BytesBase64Encoded != "" {
+					data = append(data, openai.ImageData{B64Json: p.BytesBase64Encoded})
+				} else if p.GcsUri != "" {
+					data = append(data, openai.ImageData{Url: p.GcsUri})
+				}
+			}
+		}
+	}
+
+	// Format 3: OpenAI-style data[].b64_json, used by models like openai/gpt-image-*
+	if len(data) == 0 {
+		var openaiImgResp struct {
+			Data []struct {
+				B64Json string `json:"b64_json"`
+				URL     string `json:"url"`
+			} `json:"data"`
+		}
+		if err := json.Unmarshal(responseBody, &openaiImgResp); err == nil {
+			for _, item := range openaiImgResp.Data {
+				data = append(data, openai.ImageData{B64Json: item.B64Json, Url: item.URL})
+			}
+		}
+	}
+
+	if len(data) == 0 {
+		return openai.ErrorWrapper(
+			fmt.Errorf("upstream returned no image data (response: %s)", string(responseBody)),
+			"no_image_data",
+			http.StatusBadGateway,
+		)
+	}
+
 	openaiResp := openai.ImageResponse{
 		Created: time.Now().Unix(),
 		Data:    data,
diff --git a/relay/adaptor/gemini/main.go b/relay/adaptor/gemini/main.go
index 8724807..b32b5eb 100644
--- a/relay/adaptor/gemini/main.go
+++ b/relay/adaptor/gemini/main.go
@@ -321,13 +321,23 @@ func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *BatchEmbedding
 	}
 }
 
+type TokensDetail struct {
+	Modality   string `json:"modality"`
+	TokenCount int    `json:"tokenCount"`
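+	// Modality follows Gemini's response modality enum (e.g. "TEXT", "IMAGE");
+	// only "IMAGE" is special-cased for billing in StreamHandler below.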
+} + type UsageMetadata struct { - PromptTokenCount int `json:"promptTokenCount"` - CandidatesTokenCount int `json:"candidatesTokenCount"` - TotalTokenCount int `json:"totalTokenCount"` + PromptTokenCount int `json:"promptTokenCount"` + CandidatesTokenCount int `json:"candidatesTokenCount"` + TotalTokenCount int `json:"totalTokenCount"` + ThoughtsTokenCount int `json:"thoughtsTokenCount"` + CandidatesTokensDetails []TokensDetail `json:"candidatesTokensDetails"` } type ChatResponse struct { + Id string `json:"id,omitempty"` // set by some proxies; used for metadata fetches Candidates []ChatCandidate `json:"candidates"` PromptFeedback ChatPromptFeedback `json:"promptFeedback"` UsageMetadata *UsageMetadata `json:"usageMetadata"` @@ -475,9 +483,15 @@ func embeddingResponseGemini2OpenAI(response *EmbeddingResponse) *openai.Embeddi return &openAIEmbeddingResponse } -func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { +// StreamHandler processes a Gemini SSE stream and returns +// (error, usage, generationId). +// generationId is captured from the response's "id" field if present (set by +// some proxies); callers can use it for post-response metadata fetches. +func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage, string) { var usage *model.Usage + var generationId string responseText := "" + outputImageCount := 0 scanner := bufio.NewScanner(resp.Body) // Default bufio.Scanner buffer is 64KB which is too small for inline image data (base64). // Allocate 20MB to handle large image payloads from Gemini image-generation models. @@ -503,14 +517,36 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC continue } + if generationId == "" && geminiResponse.Id != "" { + generationId = geminiResponse.Id + } + // Extract usageMetadata from the last chunk that carries it. - // This includes image/video/audio generation costs that cannot be - // estimated from text tokenisation alone. if geminiResponse.UsageMetadata != nil { + meta := geminiResponse.UsageMetadata + // Image output tokens are priced at $60/M vs text completion $3/M (20× ratio). + // Separate image and text tokens from candidatesTokensDetails so we can + // represent image tokens as equivalent text completion tokens for billing. + imageOutputTokens := 0 + textOutputTokens := 0 + if len(meta.CandidatesTokensDetails) > 0 { + for _, d := range meta.CandidatesTokensDetails { + if d.Modality == "IMAGE" { + imageOutputTokens += d.TokenCount + } else { + textOutputTokens += d.TokenCount + } + } + } else { + textOutputTokens = meta.CandidatesTokenCount + } + // ThoughtsTokenCount billed at text completion rate; image tokens at 20× that rate. + const imageToTextRatio = 20 + completionTokens := textOutputTokens + imageOutputTokens*imageToTextRatio + meta.ThoughtsTokenCount usage = &model.Usage{ - PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount, - CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount, - TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount, + PromptTokens: meta.PromptTokenCount, + CompletionTokens: completionTokens, + TotalTokens: meta.TotalTokenCount, } } @@ -520,11 +556,13 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC } // Accumulate text for fallback token estimation (used only when - // usageMetadata is absent from the stream). + // usageMetadata is absent from the stream). Also count output images. 
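+		// outputImageCount feeds the fixed per-image estimate used by the
+		// fallback at the end of this function.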
 		if len(geminiResponse.Candidates) > 0 {
 			for _, part := range geminiResponse.Candidates[0].Content.Parts {
 				if part.InlineData == nil {
 					responseText += part.Text
+				} else if strings.HasPrefix(part.InlineData.MimeType, "image/") {
+					outputImageCount++
 				}
 			}
 		}
@@ -543,15 +581,20 @@
 
 	err := resp.Body.Close()
 	if err != nil {
-		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, ""
 	}
 
-	// If upstream provided usageMetadata, use it (accurate, includes image costs).
-	// Otherwise fall back to local tiktoken estimation on text-only content.
+	// If upstream provided usageMetadata, use it (image tokens already adjusted above).
 	if usage != nil {
-		return nil, usage
+		return nil, usage, generationId
 	}
-	return nil, openai.ResponseText2Usage(responseText, "", 0)
+	// Fallback: local tiktoken estimation on the text content, plus a fixed
+	// per-image estimate (2500 tokens, matching countOutputMediaTokens in the
+	// openai adaptor) for any inline images counted above.
+	fallback := openai.ResponseText2Usage(responseText, "", 0)
+	fallback.CompletionTokens += outputImageCount * 2500
+	fallback.TotalTokens = fallback.PromptTokens + fallback.CompletionTokens
+	return nil, fallback, generationId
 }
 
 func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
diff --git a/relay/adaptor/gemini/model.go b/relay/adaptor/gemini/model.go
index bff08d6..95f2871 100644
--- a/relay/adaptor/gemini/model.go
+++ b/relay/adaptor/gemini/model.go
@@ -92,11 +92,11 @@ type ChatGenerationConfig struct {
 
 // --- Google Imagen ---
 
-// ImagenRequest is the request body for Google Imagen (predict endpoint).
+// ImagenRequest is the request body for legacy Imagen (predict endpoint).
 // POST /v1beta/models/{model}:predict
 type ImagenRequest struct {
 	Instances  []ImagenInstance `json:"instances"`
 	Parameters ImagenParameters `json:"parameters"`
 }
 
 type ImagenInstance struct {
@@ -108,14 +108,62 @@ type ImagenParameters struct {
 	AspectRatio string `json:"aspectRatio,omitempty"`
 }
 
-// ImagenResponse is the response from Google Imagen.
+// GenerateImagesRequest is the request body for Imagen 3+ (generateImages endpoint).
+// POST /v1/models/{model}:generateImages
+type GenerateImagesRequest struct {
+	Prompt         string `json:"prompt"`
+	NumberOfImages int    `json:"number_of_images"`
+	AspectRatio    string `json:"aspectRatio,omitempty"`
+}
+
+// EditImageRequest is the request body for Vertex AI image editing via :predict.
+// Same endpoint as generation, uses instances[] with referenceImages. No parameters field.
+type EditImageRequest struct {
+	Instances []EditImageInstance `json:"instances"`
+}
+
+type EditImageInstance struct {
+	Prompt          string               `json:"prompt"`
+	ReferenceImages []ReferenceImageItem `json:"referenceImages"`
+}
+
+type ReferenceImageItem struct {
+	ReferenceType  string             `json:"referenceType"`
+	ReferenceId    int                `json:"referenceId"`
+	ReferenceImage ReferenceImageData `json:"referenceImage"`
+}
+
+type ReferenceImageData struct {
+	BytesBase64Encoded string `json:"bytesBase64Encoded"`
+}
+
+// ImagenResponse is the response from legacy Imagen :predict endpoint.
 type ImagenResponse struct {
 	Predictions []ImagenPrediction `json:"predictions"`
 	Error       *Error             `json:"error,omitempty"`
 }
 
+// GenerateImagesResponse is the response from :generateImages and :editImage endpoints.
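+// Example body: {"generatedImages":[{"image":{"imageBytes":"<base64>"},"mimeType":"image/png"}]}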
+type GenerateImagesResponse struct { + GeneratedImages []GeneratedImageItem `json:"generatedImages"` + Error *Error `json:"error,omitempty"` +} + +type GeneratedImageItem struct { + Image GeneratedImageData `json:"image"` + MimeType string `json:"mimeType"` +} + +type GeneratedImageData struct { + // ImageBytes is base64-encoded image data returned by :generateImages / :editImage + ImageBytes string `json:"imageBytes"` +} + type ImagenPrediction struct { BytesBase64Encoded string `json:"bytesBase64Encoded"` MimeType string `json:"mimeType"` + GcsUri string `json:"gcsUri"` + Prompt string `json:"prompt"` } diff --git a/relay/adaptor/openai/adaptor.go b/relay/adaptor/openai/adaptor.go index 8faf90a..85f7e72 100644 --- a/relay/adaptor/openai/adaptor.go +++ b/relay/adaptor/openai/adaptor.go @@ -109,7 +109,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Read func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { var responseText string - err, responseText, usage = StreamHandler(c, resp, meta.Mode) + err, responseText, usage, meta.GenerationId = StreamHandler(c, resp, meta.Mode) if usage == nil || usage.TotalTokens == 0 { usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) } diff --git a/relay/adaptor/openai/helper.go b/relay/adaptor/openai/helper.go index c6d5bd7..c7e7d75 100644 --- a/relay/adaptor/openai/helper.go +++ b/relay/adaptor/openai/helper.go @@ -1,17 +1,77 @@ package openai import ( + "encoding/json" "fmt" + "net/http" "strings" + "time" + "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/model" ) +var metadataHTTPClient = &http.Client{Timeout: 8 * time.Second} + +// usageMetadataResponse is a generic structure for upstream metadata endpoints +// that return nativeTokens (e.g. Zenmux /api/v1/generation). +type usageMetadataResponse struct { + NativeTokens struct { + PromptTokenCount int `json:"promptTokenCount"` + CandidatesTokenCount int `json:"candidatesTokenCount"` + TotalTokenCount int `json:"totalTokenCount"` + // ThoughtsTokenCount is billed at the completion rate, included in CompletionTokens. + ThoughtsTokenCount int `json:"thoughtsTokenCount"` + } `json:"nativeTokens"` +} + +// FetchUsageFromMetadataURL fetches accurate token usage from an upstream metadata +// endpoint. urlTemplate must contain {id} which is replaced with generationId. +// Returns nil if the fetch fails or the response contains no usable token data. 
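+//
+// Example (the URL template comes from the channel config; the ID is illustrative):
+//
+//	usage, err := FetchUsageFromMetadataURL("https://zenmux.ai/api/v1/generation?ctoken=xxx&id={id}", "gen-123")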
+func FetchUsageFromMetadataURL(urlTemplate, generationId string) (*model.Usage, error) { + url := strings.ReplaceAll(urlTemplate, "{id}", generationId) + resp, err := metadataHTTPClient.Get(url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("metadata API returned status %d", resp.StatusCode) + } + var meta usageMetadataResponse + if err := json.NewDecoder(resp.Body).Decode(&meta); err != nil { + return nil, err + } + if meta.NativeTokens.TotalTokenCount == 0 { + return nil, fmt.Errorf("metadata API returned zero total tokens") + } + prompt := meta.NativeTokens.PromptTokenCount + completion := meta.NativeTokens.CandidatesTokenCount + meta.NativeTokens.ThoughtsTokenCount + logger.SysLog(fmt.Sprintf("usage from metadata API (id=%s): prompt=%d completion=%d", generationId, prompt, completion)) + return &model.Usage{ + PromptTokens: prompt, + CompletionTokens: completion, + TotalTokens: prompt + completion, + }, nil +} + +// countOutputMediaTokens returns a fixed token estimate for any embedded +// media data URIs found in text (image/video/audio), consistent with the +// fixed estimates used for input media in CountTokenMessages. +func countOutputMediaTokens(text string) int { + tokens := 0 + tokens += strings.Count(text, "data:image/") * 2500 + tokens += strings.Count(text, "data:video/") * 10000 + tokens += strings.Count(text, "data:audio/") * 1500 + return tokens +} + func ResponseText2Usage(responseText string, modelName string, promptTokens int) *model.Usage { usage := &model.Usage{} usage.PromptTokens = promptTokens - usage.CompletionTokens = CountTokenText(responseText, modelName) + usage.CompletionTokens = CountTokenText(stripBase64Payloads(responseText), modelName) + + countOutputMediaTokens(responseText) usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens return usage } diff --git a/relay/adaptor/openai/main.go b/relay/adaptor/openai/main.go index 9708073..2023025 100644 --- a/relay/adaptor/openai/main.go +++ b/relay/adaptor/openai/main.go @@ -24,11 +24,16 @@ const ( dataPrefixLength = len(dataPrefix) ) -func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) { +// StreamHandler processes an OpenAI-compatible SSE stream and returns +// (error, responseText, usage, generationId). +// generationId is the "id" field captured from the first stream chunk; it can +// be used by callers that need to fetch per-request metadata from an upstream API. 
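+//
+// Typical call, as wired up in the adaptor's DoResponse:
+//
+//	err, responseText, usage, meta.GenerationId = StreamHandler(c, resp, meta.Mode)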
+func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage, string) {
 	responseText := ""
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(bufio.ScanLines)
 	var usage *model.Usage
+	var generationId string
 
 	common.SetEventStreamHeaders(c)
 
@@ -63,6 +68,9 @@
 			for _, choice := range streamResponse.Choices {
 				responseText += conv.AsString(choice.Delta.Content)
 			}
+			if generationId == "" && streamResponse.Id != "" {
+				generationId = streamResponse.Id
+			}
 			if streamResponse.Usage != nil {
 				usage = streamResponse.Usage
 			}
@@ -90,10 +98,10 @@
 
 	err := resp.Body.Close()
 	if err != nil {
-		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil
+		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil, ""
 	}
 
-	return nil, responseText, usage
+	return nil, responseText, usage, generationId
 }
 
 func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
diff --git a/relay/adaptor/openai/token.go b/relay/adaptor/openai/token.go
index 1d3a276..1bccccf 100644
--- a/relay/adaptor/openai/token.go
+++ b/relay/adaptor/openai/token.go
@@ -73,6 +73,37 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
 	return len(tokenEncoder.Encode(text, nil, nil))
 }
 
+// stripBase64Payloads removes the base64 payload from data URIs so that huge
+// image/video/audio blobs do not inflate the local token estimate.
+// Example: "data:image/jpeg;base64,/9j/4AAQ..." → "data:image/jpeg;base64,"
+func stripBase64Payloads(s string) string {
+	const marker = ";base64,"
+	out := s
+	searchFrom := 0
+	for {
+		idx := strings.Index(out[searchFrom:], marker)
+		if idx < 0 {
+			break
+		}
+		payloadStart := searchFrom + idx + len(marker)
+		end := payloadStart
+		for end < len(out) {
+			c := out[end]
+			if (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') ||
+				(c >= '0' && c <= '9') || c == '+' || c == '/' || c == '=' {
+				end++
+			} else {
+				break
+			}
+		}
+		out = out[:payloadStart] + out[end:]
+		// Resume the search after this marker: retrying from the start would
+		// refind the already-stripped data URI and loop forever.
+		searchFrom = payloadStart
+	}
+	return out
+}
+
 func CountTokenMessages(messages []model.Message, model string) int {
 	tokenEncoder := getTokenEncoder(model)
 	// Reference:
@@ -94,7 +125,7 @@ func CountTokenMessages(messages []model.Message, model string) int {
 		tokenNum += tokensPerMessage
 		switch v := message.Content.(type) {
 		case string:
-			tokenNum += getTokenNum(tokenEncoder, v)
+			tokenNum += getTokenNum(tokenEncoder, stripBase64Payloads(v))
 		case []any:
 			for _, it := range v {
 				m := it.(map[string]any)
@@ -102,7 +133,7 @@ func CountTokenMessages(messages []model.Message, model string) int {
 				case "text":
 					if textValue, ok := m["text"]; ok {
 						if textString, ok := textValue.(string); ok {
-							tokenNum += getTokenNum(tokenEncoder, textString)
+							tokenNum += getTokenNum(tokenEncoder, stripBase64Payloads(textString))
 						}
 					}
 				case "image_url":
@@ -114,12 +145,26 @@ func CountTokenMessages(messages []model.Message, model string) int {
 						detail = imageUrl["detail"].(string)
 					}
 					imageTokens, err := countImageTokens(url, detail, model)
-					if err != nil {
-						logger.SysError("error counting image tokens: " + err.Error())
-					} else {
-						tokenNum += imageTokens
+					if err != nil || imageTokens == 0 {
+						// Fallback for base64 images (payload stripped, dimensions
+						// unavailable) or any other failure: use a conservative
+						// fixed estimate (~768x768 high-detail image).
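+						// 765 = 85 base tokens + 4 tiles × 170 tokens: a 768x768
+						// image covers four 512px tiles in OpenAI's vision pricing.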
+ if err != nil { + logger.SysError("error counting image tokens, using fallback: " + err.Error()) + } + imageTokens = 765 } + tokenNum += imageTokens } + case "video_url": + // Videos cannot be tokenised locally; conservative fixed estimate. + // Final billing always uses the upstream's real usage. + tokenNum += 2000 + case "input_audio", "audio": + // Audio cannot be tokenised locally; conservative fixed estimate. + tokenNum += 500 } } } diff --git a/relay/adaptor/vertexai/gemini/adapter.go b/relay/adaptor/vertexai/gemini/adapter.go index e4a8591..59db1a1 100644 --- a/relay/adaptor/vertexai/gemini/adapter.go +++ b/relay/adaptor/vertexai/gemini/adapter.go @@ -39,7 +39,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { - err, usage = gemini.StreamHandler(c, resp) + err, usage, meta.GenerationId = gemini.StreamHandler(c, resp) if usage != nil { usage.PromptTokens = meta.PromptTokens } diff --git a/relay/adaptor/zhipu/adaptor.go b/relay/adaptor/zhipu/adaptor.go index 660bd37..f07cc6a 100644 --- a/relay/adaptor/zhipu/adaptor.go +++ b/relay/adaptor/zhipu/adaptor.go @@ -98,7 +98,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Read func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { - err, _, usage = openai.StreamHandler(c, resp, meta.Mode) + err, _, usage, _ = openai.StreamHandler(c, resp, meta.Mode) } else { err, usage = openai.Handler(c, resp, meta.PromptTokens, meta.ActualModelName) } diff --git a/relay/billing/ratio/image.go b/relay/billing/ratio/image.go index c8c42a1..37c146d 100644 --- a/relay/billing/ratio/image.go +++ b/relay/billing/ratio/image.go @@ -11,6 +11,12 @@ var ImageSizeRatios = map[string]map[string]float64{ "1024x1792": 2, "1792x1024": 2, }, + // gpt-image-1.5: base=medium 1024x1024 ($0.034), size ratio for other sizes + "openai/gpt-image-1.5": { + "1024x1024": 1.0, // medium $0.034 + "1024x1536": 1.47, // medium $0.050 + "1536x1024": 1.47, // medium $0.050 + }, "ali-stable-diffusion-xl": { "512x1024": 1, "1024x768": 1, diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index e8b3b61..219d7f4 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -87,6 +87,14 @@ var ModelRatio = map[string]float64{ "text-moderation-latest": 0.1, "dall-e-2": 0.02 * USD, // $0.016 - $0.020 / image "dall-e-3": 0.04 * USD, // $0.040 - $0.120 / image + // Zenmux image generation models (calibrated for 1M quota/dollar) + "openai/gpt-image-1.5": 34, // $0.034 / image (medium 1024x1024 base) + "qwen/qwen-image-2.0": 28.9, // $0.0289 / image + "qwen/qwen-image-2.0-pro": 73, // $0.073 / image + "volcengine/doubao-seedream-5.0-lite": 32, // $0.032 / image + "z-ai/glm-image": 15, // $0.015 / image + "tencent/hunyuan-image3": 29, // $0.029 / image + "klingai/kling-v2": 14, // $0.014 / image // https://docs.anthropic.com/en/docs/about-claude/models "claude-instant-1.2": 0.8 / 1000 * USD, "claude-2.0": 8.0 / 1000 * USD, diff --git a/relay/controller/text.go b/relay/controller/text.go index f912498..0a01a90 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -82,6 +82,16 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) return 
respErr } + // If the channel has a UsageAPIURL configured and the adaptor captured a + // generation ID, fetch accurate multimodal token counts now (works for any + // channel type that sets meta.GenerationId during DoResponse). + if meta.Config.UsageAPIURL != "" && meta.GenerationId != "" { + if fetchedUsage, fetchErr := openai.FetchUsageFromMetadataURL(meta.Config.UsageAPIURL, meta.GenerationId); fetchErr == nil { + usage = fetchedUsage + } else { + logger.Warnf(ctx, "failed to fetch usage from metadata URL: %s", fetchErr.Error()) + } + } // post-consume quota go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio, systemPromptReset) return nil diff --git a/relay/meta/relay_meta.go b/relay/meta/relay_meta.go index 8c74ef8..501409f 100644 --- a/relay/meta/relay_meta.go +++ b/relay/meta/relay_meta.go @@ -35,6 +35,9 @@ type Meta struct { PromptTokens int // only for DoResponse ForcedSystemPrompt string StartTime time.Time + // GenerationId is the upstream generation/request ID captured during DoResponse. + // Adaptors should set this when available; it is used for post-response metadata fetches. + GenerationId string } func GetByContext(c *gin.Context) *Meta { diff --git a/relay/model/image.go b/relay/model/image.go index bab8425..d7b8acf 100644 --- a/relay/model/image.go +++ b/relay/model/image.go @@ -3,6 +3,7 @@ package model type ImageRequest struct { Model string `json:"model"` Prompt string `json:"prompt" binding:"required"` + Image any `json:"image,omitempty"` N int `json:"n,omitempty"` Size string `json:"size,omitempty"` Quality string `json:"quality,omitempty"` diff --git a/relay/relaymode/define.go b/relay/relaymode/define.go index aa77120..ef74092 100644 --- a/relay/relaymode/define.go +++ b/relay/relaymode/define.go @@ -7,6 +7,7 @@ const ( Embeddings Moderations ImagesGenerations + ImagesEdits Edits AudioSpeech AudioTranscription diff --git a/relay/relaymode/helper.go b/relay/relaymode/helper.go index 2cde5b8..c17d5b0 100644 --- a/relay/relaymode/helper.go +++ b/relay/relaymode/helper.go @@ -16,6 +16,8 @@ func GetByPath(path string) int { relayMode = Moderations } else if strings.HasPrefix(path, "/v1/images/generations") { relayMode = ImagesGenerations + } else if strings.HasPrefix(path, "/v1/images/edits") { + relayMode = ImagesEdits } else if strings.HasPrefix(path, "/v1/edits") { relayMode = Edits } else if strings.HasPrefix(path, "/v1/audio/speech") { diff --git a/router/relay.go b/router/relay.go index 8f3c730..554a64f 100644 --- a/router/relay.go +++ b/router/relay.go @@ -25,7 +25,7 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.POST("/chat/completions", controller.Relay) relayV1Router.POST("/edits", controller.Relay) relayV1Router.POST("/images/generations", controller.Relay) - relayV1Router.POST("/images/edits", controller.RelayNotImplemented) + relayV1Router.POST("/images/edits", controller.Relay) relayV1Router.POST("/images/variations", controller.RelayNotImplemented) relayV1Router.POST("/embeddings", controller.Relay) relayV1Router.POST("/engines/:model/embeddings", controller.Relay) diff --git a/web/berry/src/views/Channel/component/EditModal.js b/web/berry/src/views/Channel/component/EditModal.js index c2520a9..5576a8d 100644 --- a/web/berry/src/views/Channel/component/EditModal.js +++ b/web/berry/src/views/Channel/component/EditModal.js @@ -44,7 +44,7 @@ const validationSchema = Yup.object().shape({ is: (is_edit, type) => !is_edit && type !== 33, then: Yup.string().required('密钥 不能为空') }), - other: 
Yup.string(),
+  other: Yup.string().nullable(),
   models: Yup.array().min(1, '模型 不能为空'),
   groups: Yup.array().min(1, '用户组 不能为空'),
   base_url: Yup.string().when('type', {
diff --git a/web/berry/src/views/Channel/component/TableRow.js b/web/berry/src/views/Channel/component/TableRow.js
index 525f918..b83308d 100644
--- a/web/berry/src/views/Channel/component/TableRow.js
+++ b/web/berry/src/views/Channel/component/TableRow.js
@@ -19,6 +19,9 @@ import {
   DialogTitle,
   Tooltip,
   Button,
+  Select,
+  FormControl,
+  InputLabel,
 } from "@mui/material";
 
 import Label from "ui-component/Label";
@@ -28,7 +31,7 @@ import ResponseTimeLabel from "./ResponseTimeLabel";
 import GroupLabel from "./GroupLabel";
 import NameLabel from "./NameLabel";
 
-import { IconDotsVertical, IconEdit, IconTrash } from "@tabler/icons-react";
+import { IconDotsVertical, IconEdit, IconTrash, IconBrandSpeedtest } from "@tabler/icons-react";
 
 export default function ChannelTableRow({
   item,
@@ -36,9 +39,12 @@ export default function ChannelTableRow({
   handleOpenModal,
   setModalChannelId,
 }) {
+  const modelOptions = item.models ? item.models.split(',').filter(Boolean) : [];
   const [open, setOpen] = useState(null);
   const [openDelete, setOpenDelete] = useState(false);
+  const [openTest, setOpenTest] = useState(false);
   const [statusSwitch, setStatusSwitch] = useState(item.status);
+  const [testModel, setTestModel] = useState(modelOptions[0] || '');
   const [priorityValve, setPriority] = useState(item.priority);
   const [responseTimeData, setResponseTimeData] = useState({
     test_time: item.test_time,
@@ -87,7 +93,8 @@ export default function ChannelTableRow({
   };
 
   const handleResponseTime = async () => {
-    const { success, time } = await manageChannel(item.id, "test", "");
+    setOpenTest(false);
+    const { success, time } = await manageChannel(item.id, "test", testModel);
     if (success) {
       setResponseTimeData({
         test_time: Date.now() / 1000,
@@ -223,6 +230,15 @@ export default function ChannelTableRow({
           <IconEdit style={{ marginRight: "16px" }} />
          编辑
         </MenuItem>
+        <MenuItem
+          onClick={() => {
+            handleCloseMenu();
+            setOpenTest(true);
+          }}
+        >
+          <IconBrandSpeedtest style={{ marginRight: "16px" }} />
+          测试
+        </MenuItem>
         <MenuItem onClick={handleDeleteOpen} sx={{ color: "error.main" }}>
           <IconTrash style={{ marginRight: "16px" }} />
          删除
@@ -241,6 +257,28 @@ export default function ChannelTableRow({
         </DialogActions>
       </Dialog>
+
+      <Dialog open={openTest} onClose={() => setOpenTest(false)} maxWidth="xs" fullWidth>
+        <DialogTitle>测试渠道:{item.name}</DialogTitle>
+        <DialogContent>
+          <FormControl fullWidth sx={{ mt: 1 }}>
+            <InputLabel>选择测试模型</InputLabel>
+            <Select
+              label="选择测试模型"
+              value={testModel}
+              onChange={(e) => setTestModel(e.target.value)}
+            >
+              {modelOptions.map((model) => (
+                <MenuItem key={model} value={model}>
+                  {model}
+                </MenuItem>
+              ))}
+            </Select>
+          </FormControl>
+        </DialogContent>
+        <DialogActions>
+          <Button onClick={() => setOpenTest(false)}>取消</Button>
+          <Button variant="contained" onClick={handleResponseTime}>开始测试</Button>
+        </DialogActions>
+      </Dialog>
     </>
   );
 }
diff --git a/web/berry/src/views/Channel/index.js b/web/berry/src/views/Channel/index.js
index 1a8b766..bf6df24 100644
--- a/web/berry/src/views/Channel/index.js
+++ b/web/berry/src/views/Channel/index.js
@@ -108,7 +108,7 @@ export default function ChannelPage() {
         });
         break;
       case 'test':
-        res = await API.get(url + `test/${id}`);
+        res = await API.get(url + `test/${id}` + (value ? `?model=${encodeURIComponent(value)}` : ''));
         break;
       default:
         return;
diff --git a/web/berry/src/views/Channel/type/Config.js b/web/berry/src/views/Channel/type/Config.js
index 67b9073..e8f4136 100644
--- a/web/berry/src/views/Channel/type/Config.js
+++ b/web/berry/src/views/Channel/type/Config.js
@@ -123,13 +123,19 @@ const typeConfig = {
   },
   24: {
     inputLabel: {
-      other: '版本号'
+      other: '版本号',
+      config: {
+        usage_api_url: 'Usage API URL'
+      }
     },
     input: {
       models: ['gemini-pro']
     },
     prompt: {
-      other: '请输入版本号,例如:v1'
+      other: '请输入版本号,例如:v1',
+      config: {
+        usage_api_url: '可选,填写后将在每次请求结束后从该接口获取真实的多模态 Token 用量,使用 {id} 作为 generation ID 占位符,例如:https://zenmux.ai/api/v1/generation?ctoken=xxx&id={id}'
+      }
     },
     modelGroup: 'google gemini'
   },
@@ -228,6 +234,18 @@ const typeConfig = {
   45: {
     modelGroup: 'xai'
   },
+  50: {
+    inputLabel: {
+      config: {
+        usage_api_url: 'Usage API URL'
+      }
+    },
+    prompt: {
+      config: {
+        usage_api_url: '可选,填写后将在每次请求结束后从该接口获取真实的多模态 Token 用量,使用 {id} 作为 generation ID 占位符,例如:https://zenmux.ai/api/v1/generation?ctoken=xxx&id={id}'
+      }
+    }
+  }
 };
 
 export { defaultConfig, typeConfig };
diff --git a/web/berry/src/views/Log/index.js b/web/berry/src/views/Log/index.js
index f8cef0e..57a24e0 100644
--- a/web/berry/src/views/Log/index.js
+++ b/web/berry/src/views/Log/index.js
@@ -5,7 +5,7 @@ import Table from '@mui/material/Table';
 import TableBody from '@mui/material/TableBody';
 import TableContainer from '@mui/material/TableContainer';
 import PerfectScrollbar from 'react-perfect-scrollbar';
-import TablePagination from '@mui/material/TablePagination';
+import Pagination from '@mui/material/Pagination';
 import LinearProgress from '@mui/material/LinearProgress';
 import ButtonGroup from '@mui/material/ButtonGroup';
 import Toolbar from '@mui/material/Toolbar';
@@ -143,14 +143,16 @@ export default function Log() {
-        <TablePagination … />
+        <Pagination
+          …
+          onChange={(e, page) => onPaginationChange(e, page - 1)}
+          color="primary"
+          showFirstButton
+          showLastButton
+        />
   );