diff --git a/controller/model.go b/controller/model.go
index d12fb82..89a161f 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -475,7 +475,7 @@ type zenmuxChannelDef struct {
var allZenmuxChannelDefs = []zenmuxChannelDef{
{
Name: "Zenmux",
- Type: 50, // OpenAICompatible
+ Type: channeltype.OpenAICompatible,
Protocols: []string{"chat.completions", "responses", "messages", "gemini", "generate"},
},
}
diff --git a/controller/relay.go b/controller/relay.go
index 038123b..a0448f7 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -26,7 +26,7 @@ import (
func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
var err *model.ErrorWithStatusCode
switch relayMode {
- case relaymode.ImagesGenerations:
+ case relaymode.ImagesGenerations, relaymode.ImagesEdits:
err = controller.RelayImageHelper(c, relayMode)
case relaymode.AudioSpeech:
fallthrough
diff --git a/model/channel.go b/model/channel.go
index 4b0f4b0..6d7e373 100644
--- a/model/channel.go
+++ b/model/channel.go
@@ -50,6 +50,11 @@ type ChannelConfig struct {
Plugin string `json:"plugin,omitempty"`
VertexAIProjectID string `json:"vertex_ai_project_id,omitempty"`
VertexAIADC string `json:"vertex_ai_adc,omitempty"`
+ // UsageAPIURL is an optional URL template for fetching accurate per-request
+ // token usage after a generation completes. Use {id} as a placeholder for
+ // the generation/request ID returned in the upstream response.
+ // Example: https://zenmux.ai/api/v1/generation?ctoken=xxx&id={id}
+ UsageAPIURL string `json:"usage_api_url,omitempty"`
}
func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) {
diff --git a/relay/adaptor/common.go b/relay/adaptor/common.go
index 8953d7a..31b006e 100644
--- a/relay/adaptor/common.go
+++ b/relay/adaptor/common.go
@@ -23,6 +23,7 @@ func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io.
if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err)
}
+
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
diff --git a/relay/adaptor/gemini/adaptor.go b/relay/adaptor/gemini/adaptor.go
index 41e2134..1180d25 100644
--- a/relay/adaptor/gemini/adaptor.go
+++ b/relay/adaptor/gemini/adaptor.go
@@ -39,7 +39,22 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
modelName := meta.ActualModelName
switch meta.Mode {
+ case relaymode.ImagesEdits:
+ // Image editing uses :predict (same as generation) with referenceImages in instances[]
+ if strings.Contains(modelName, "/") {
+ return fmt.Sprintf("%s/v1/%s:predict", meta.BaseURL, publisherModelPath(modelName)), nil
+ }
+ return fmt.Sprintf("%s/v1beta/models/%s:predict", meta.BaseURL, modelName), nil
case relaymode.ImagesGenerations:
+ // Imagen 3+ models use :generateImages endpoint
+ if isImagen3Model(modelName) {
+ return fmt.Sprintf("%s/v1/models/%s:generateImages", meta.BaseURL, publisherModelPath(modelName)), nil
+ }
+ // Publisher models (e.g. klingai/kling-v2, volcengine/doubao-*) use v1 :predict
+ if strings.Contains(modelName, "/") {
+ return fmt.Sprintf("%s/v1/%s:predict", meta.BaseURL, publisherModelPath(modelName)), nil
+ }
+ // Legacy imagegeneration models (no publisher prefix) use v1beta :predict
return fmt.Sprintf("%s/v1beta/models/%s:predict", meta.BaseURL, modelName), nil
case relaymode.Embeddings:
return fmt.Sprintf("%s/%s/models/%s:batchEmbedContents", meta.BaseURL, version, modelName), nil
@@ -49,7 +64,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
if meta.IsStream {
action = "streamGenerateContent?alt=sse"
}
- return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, modelName, action), nil
+ return fmt.Sprintf("%s/%s/%s:%s", meta.BaseURL, version, publisherModelPath(modelName), action), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
@@ -80,6 +95,41 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error)
if n <= 0 {
n = 1
}
+ // If an image is provided, convert to :editImage format (image-to-image)
+ if request.Image != nil {
+ imageStr, ok := request.Image.(string)
+ if !ok {
+ return nil, errors.New("image field must be a base64 string or data URL")
+ }
+ // Strip data URL prefix if present (e.g. "data:image/png;base64,")
+ base64Data := imageStr
+ if idx := strings.Index(imageStr, ";base64,"); idx != -1 {
+			base64Data = imageStr[idx+len(";base64,"):]
+ }
+ return EditImageRequest{
+ Instances: []EditImageInstance{
+ {
+ Prompt: request.Prompt,
+ ReferenceImages: []ReferenceImageItem{
+ {
+ ReferenceType: "REFERENCE_TYPE_RAW",
+ ReferenceId: 1,
+ ReferenceImage: ReferenceImageData{
+ BytesBase64Encoded: base64Data,
+ },
+ },
+ },
+ },
+ },
+ }, nil
+ }
+ if isImagen3Model(request.Model) {
+ return GenerateImagesRequest{
+ Prompt: request.Prompt,
+ NumberOfImages: n,
+ AspectRatio: sizeToAspectRatio(request.Size),
+ }, nil
+ }
return ImagenRequest{
Instances: []ImagenInstance{{Prompt: request.Prompt}},
Parameters: ImagenParameters{
@@ -89,7 +139,28 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error)
}, nil
}
-// sizeToAspectRatio converts OpenAI size strings to Imagen aspect ratios.
+// publisherModelPath converts publisher/model names to the correct Vertex AI URL path segment.
+// Google's own models (gemini/, google/) keep the raw /models/publisher/model format.
+// Third-party publisher models use /publishers/{pub}/models/{model} format.
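+// Examples:
+//   "gemini-1.5-pro"   → "models/gemini-1.5-pro"
+//   "google/imagen-3"  → "models/google/imagen-3"
+//   "klingai/kling-v2" → "publishers/klingai/models/kling-v2"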
+func publisherModelPath(modelName string) string {
+ if idx := strings.Index(modelName, "/"); idx != -1 {
+ publisher := modelName[:idx]
+ model := modelName[idx+1:]
+ if publisher == "gemini" || publisher == "google" {
+ return fmt.Sprintf("models/%s", modelName)
+ }
+ return fmt.Sprintf("publishers/%s/models/%s", publisher, model)
+ }
+ return fmt.Sprintf("models/%s", modelName)
+}
+
+// isImagen3Model returns true for image-generation-only models that use the :generateImages endpoint.
+func isImagen3Model(modelName string) bool {
+ lower := strings.ToLower(modelName)
+ return strings.HasPrefix(lower, "imagen-3") ||
+ strings.HasPrefix(lower, "imagen-4") ||
+ strings.Contains(lower, "imagegeneration@00")
+}
+
+// sizeToAspectRatio converts OpenAI size strings to Imagen aspect ratios.
func sizeToAspectRatio(size string) string {
switch size {
case "1792x1024":
@@ -109,7 +180,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Read
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
- err, usage = StreamHandler(c, resp)
+ err, usage, meta.GenerationId = StreamHandler(c, resp)
if usage != nil {
usage.PromptTokens = meta.PromptTokens
}
@@ -117,7 +188,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
switch meta.Mode {
case relaymode.Embeddings:
err, usage = EmbeddingHandler(c, resp)
- case relaymode.ImagesGenerations:
+ case relaymode.ImagesGenerations, relaymode.ImagesEdits:
err = ImagenHandler(c, resp)
default:
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
@@ -126,7 +197,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
return
}
-// ImagenHandler converts a Google Imagen predict response to OpenAI image response format.
+// ImagenHandler converts a Google Imagen predict/generateImages response to OpenAI image response format.
+// Handles both Vertex AI format (predictions[].bytesBase64Encoded) and OpenAI format (data[].b64_json).
func ImagenHandler(c *gin.Context, resp *http.Response) *model.ErrorWithStatusCode {
responseBody, readErr := io.ReadAll(resp.Body)
resp.Body.Close()
@@ -134,24 +206,79 @@ func ImagenHandler(c *gin.Context, resp *http.Response) *model.ErrorWithStatusCo
return openai.ErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
}
- var imagenResp ImagenResponse
- if err := json.Unmarshal(responseBody, &imagenResp); err != nil {
- return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
- }
+ fmt.Printf("[ImagenHandler] raw response body: %s\n", string(responseBody))
- if imagenResp.Error != nil {
+ // Safety check: empty response
+ if len(responseBody) == 0 || string(responseBody) == "{}" || string(responseBody) == "null" {
return openai.ErrorWrapper(
- fmt.Errorf("%s", imagenResp.Error.Message),
- "imagen_error",
- http.StatusInternalServerError,
+ fmt.Errorf("upstream returned empty response body"),
+ "empty_response_body",
+ http.StatusBadGateway,
)
}
- // Convert to OpenAI image response format
- data := make([]openai.ImageData, 0, len(imagenResp.Predictions))
- for _, p := range imagenResp.Predictions {
- data = append(data, openai.ImageData{B64Json: p.BytesBase64Encoded})
+ data := make([]openai.ImageData, 0)
+
+ // Format 1: :generateImages / :editImage → generatedImages[].image.imageBytes
+ var generateResp GenerateImagesResponse
+ if err := json.Unmarshal(responseBody, &generateResp); err == nil {
+ if generateResp.Error != nil {
+ return openai.ErrorWrapper(
+ fmt.Errorf("%s", generateResp.Error.Message),
+ "imagen_error",
+ http.StatusInternalServerError,
+ )
+ }
+ for _, img := range generateResp.GeneratedImages {
+ if img.Image.ImageBytes != "" {
+ data = append(data, openai.ImageData{B64Json: img.Image.ImageBytes})
+ }
+ }
}
+
+ // Format 2: legacy :predict → predictions[].bytesBase64Encoded or predictions[].gcsUri
+ if len(data) == 0 {
+ var imagenResp ImagenResponse
+ if err := json.Unmarshal(responseBody, &imagenResp); err == nil {
+ if imagenResp.Error != nil {
+ return openai.ErrorWrapper(
+ fmt.Errorf("%s", imagenResp.Error.Message),
+ "imagen_error",
+ http.StatusInternalServerError,
+ )
+ }
+ for _, p := range imagenResp.Predictions {
+ if p.BytesBase64Encoded != "" {
+ data = append(data, openai.ImageData{B64Json: p.BytesBase64Encoded})
+ } else if p.GcsUri != "" {
+ data = append(data, openai.ImageData{Url: p.GcsUri})
+ }
+ }
+ }
+ }
+
+ // Format 3: OpenAI format data[].b64_json — used by models like openai/gpt-image-*
+ if len(data) == 0 {
+ var openaiImgResp struct {
+ Data []struct {
+ B64Json string `json:"b64_json"`
+ URL string `json:"url"`
+ } `json:"data"`
+ }
+ if err := json.Unmarshal(responseBody, &openaiImgResp); err == nil {
+ for _, item := range openaiImgResp.Data {
+ data = append(data, openai.ImageData{B64Json: item.B64Json, Url: item.URL})
+ }
+ }
+ }
+	if len(data) == 0 {
+		// Avoid dumping multi-megabyte base64 bodies into the error message.
+		preview := string(responseBody)
+		if len(preview) > 512 {
+			preview = preview[:512] + "...(truncated)"
+		}
+		return openai.ErrorWrapper(
+			fmt.Errorf("upstream returned no image data (response: %s)", preview),
+			"no_image_data",
+			http.StatusBadGateway,
+		)
+	}
+
openaiResp := openai.ImageResponse{
Created: time.Now().Unix(),
Data: data,
diff --git a/relay/adaptor/gemini/main.go b/relay/adaptor/gemini/main.go
index 8724807..b32b5eb 100644
--- a/relay/adaptor/gemini/main.go
+++ b/relay/adaptor/gemini/main.go
@@ -321,13 +321,21 @@ func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *BatchEmbedding
}
}
+type TokensDetail struct {
+ Modality string `json:"modality"`
+ TokenCount int `json:"tokenCount"`
+}
+
type UsageMetadata struct {
- PromptTokenCount int `json:"promptTokenCount"`
- CandidatesTokenCount int `json:"candidatesTokenCount"`
- TotalTokenCount int `json:"totalTokenCount"`
+ PromptTokenCount int `json:"promptTokenCount"`
+ CandidatesTokenCount int `json:"candidatesTokenCount"`
+ TotalTokenCount int `json:"totalTokenCount"`
+ ThoughtsTokenCount int `json:"thoughtsTokenCount"`
+ CandidatesTokensDetails []TokensDetail `json:"candidatesTokensDetails"`
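+	// e.g. [{"modality":"IMAGE","tokenCount":1290},{"modality":"TEXT","tokenCount":42}]
+	// ("TEXT" is an assumed example modality; StreamHandler only special-cases "IMAGE")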
}
type ChatResponse struct {
+ Id string `json:"id,omitempty"` // set by some proxies; used for metadata fetches
Candidates []ChatCandidate `json:"candidates"`
PromptFeedback ChatPromptFeedback `json:"promptFeedback"`
UsageMetadata *UsageMetadata `json:"usageMetadata"`
@@ -475,9 +483,15 @@ func embeddingResponseGemini2OpenAI(response *EmbeddingResponse) *openai.Embeddi
return &openAIEmbeddingResponse
}
-func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+// StreamHandler processes a Gemini SSE stream and returns
+// (error, usage, generationId).
+// generationId is captured from the response's "id" field if present (set by
+// some proxies); callers can use it for post-response metadata fetches.
+func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage, string) {
var usage *model.Usage
+ var generationId string
responseText := ""
+ outputImageCount := 0
scanner := bufio.NewScanner(resp.Body)
// Default bufio.Scanner buffer is 64KB which is too small for inline image data (base64).
// Allocate 20MB to handle large image payloads from Gemini image-generation models.
@@ -503,14 +517,36 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
continue
}
+ if generationId == "" && geminiResponse.Id != "" {
+ generationId = geminiResponse.Id
+ }
+
// Extract usageMetadata from the last chunk that carries it.
- // This includes image/video/audio generation costs that cannot be
- // estimated from text tokenisation alone.
if geminiResponse.UsageMetadata != nil {
+ meta := geminiResponse.UsageMetadata
+ // Image output tokens are priced at $60/M vs text completion $3/M (20× ratio).
+ // Separate image and text tokens from candidatesTokensDetails so we can
+ // represent image tokens as equivalent text completion tokens for billing.
+ imageOutputTokens := 0
+ textOutputTokens := 0
+ if len(meta.CandidatesTokensDetails) > 0 {
+ for _, d := range meta.CandidatesTokensDetails {
+ if d.Modality == "IMAGE" {
+ imageOutputTokens += d.TokenCount
+ } else {
+ textOutputTokens += d.TokenCount
+ }
+ }
+ } else {
+ textOutputTokens = meta.CandidatesTokenCount
+ }
+ // ThoughtsTokenCount billed at text completion rate; image tokens at 20× that rate.
+ const imageToTextRatio = 20
+ completionTokens := textOutputTokens + imageOutputTokens*imageToTextRatio + meta.ThoughtsTokenCount
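+			// Worked example: 50 text + 100 image + 10 thought tokens
+			// → 50 + 100*20 + 10 = 2060 billable completion tokens.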
usage = &model.Usage{
- PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
- CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
- TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
+ PromptTokens: meta.PromptTokenCount,
+ CompletionTokens: completionTokens,
+				// Keep TotalTokens consistent with the adjusted CompletionTokens.
+				TotalTokens:      meta.PromptTokenCount + completionTokens,
}
}
@@ -520,11 +556,13 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
}
// Accumulate text for fallback token estimation (used only when
- // usageMetadata is absent from the stream).
+ // usageMetadata is absent from the stream). Also count output images.
if len(geminiResponse.Candidates) > 0 {
for _, part := range geminiResponse.Candidates[0].Content.Parts {
if part.InlineData == nil {
responseText += part.Text
+ } else if strings.HasPrefix(part.InlineData.MimeType, "image/") {
+ outputImageCount++
}
}
}
@@ -543,15 +581,14 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
err := resp.Body.Close()
if err != nil {
- return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, ""
}
- // If upstream provided usageMetadata, use it (accurate, includes image costs).
- // Otherwise fall back to local tiktoken estimation on text-only content.
+ // If upstream provided usageMetadata, use it (image tokens already adjusted above).
if usage != nil {
- return nil, usage
+ return nil, usage, generationId
}
- return nil, openai.ResponseText2Usage(responseText, "", 0)
+	// No usageMetadata arrived: estimate text tokens locally and approximate each
+	// generated image with a fixed 2500-token estimate (assumption, matching the
+	// fixed media estimates used on the OpenAI side) so images are not billed as zero.
+	fallback := openai.ResponseText2Usage(responseText, "", 0)
+	fallback.CompletionTokens += outputImageCount * 2500
+	fallback.TotalTokens = fallback.PromptTokens + fallback.CompletionTokens
+	return nil, fallback, generationId
}
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
diff --git a/relay/adaptor/gemini/model.go b/relay/adaptor/gemini/model.go
index bff08d6..95f2871 100644
--- a/relay/adaptor/gemini/model.go
+++ b/relay/adaptor/gemini/model.go
@@ -92,11 +92,11 @@ type ChatGenerationConfig struct {
// --- Google Imagen ---
-// ImagenRequest is the request body for Google Imagen (predict endpoint).
+// ImagenRequest is the request body for legacy Imagen (predict endpoint).
// POST /v1beta/models/{model}:predict
type ImagenRequest struct {
- Instances []ImagenInstance `json:"instances"`
- Parameters ImagenParameters `json:"parameters"`
+ Instances []ImagenInstance `json:"instances"`
+ Parameters ImagenParameters `json:"parameters"`
}
type ImagenInstance struct {
@@ -108,14 +108,61 @@ type ImagenParameters struct {
AspectRatio string `json:"aspectRatio,omitempty"`
}
-// ImagenResponse is the response from Google Imagen.
+// GenerateImagesRequest is the request body for Imagen 3+ (generateImages endpoint).
+// POST /v1/models/{model}:generateImages
+type GenerateImagesRequest struct {
+ Prompt string `json:"prompt"`
+ NumberOfImages int `json:"number_of_images"`
+ AspectRatio string `json:"aspectRatio,omitempty"`
+}
+
+// EditImageRequest is the request body for Vertex AI image editing via :predict.
+// Same endpoint as generation, uses instances[] with referenceImages. No parameters field.
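+// Example body:
+//   {"instances":[{"prompt":"...","referenceImages":[{"referenceType":"REFERENCE_TYPE_RAW",
+//     "referenceId":1,"referenceImage":{"bytesBase64Encoded":"<base64>"}}]}]}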
+type EditImageRequest struct {
+ Instances []EditImageInstance `json:"instances"`
+}
+
+type EditImageInstance struct {
+ Prompt string `json:"prompt"`
+ ReferenceImages []ReferenceImageItem `json:"referenceImages"`
+}
+
+type ReferenceImageItem struct {
+ ReferenceType string `json:"referenceType"`
+ ReferenceId int `json:"referenceId"`
+ ReferenceImage ReferenceImageData `json:"referenceImage"`
+}
+
+type ReferenceImageData struct {
+ BytesBase64Encoded string `json:"bytesBase64Encoded"`
+}
+
+// ImagenResponse is the response from legacy Imagen :predict endpoint.
type ImagenResponse struct {
Predictions []ImagenPrediction `json:"predictions"`
Error *Error `json:"error,omitempty"`
}
+// GenerateImagesResponse is the response from :generateImages and :editImage endpoints.
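+// Example body: {"generatedImages":[{"image":{"imageBytes":"<base64>"},"mimeType":"image/png"}]}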
+type GenerateImagesResponse struct {
+ GeneratedImages []GeneratedImageItem `json:"generatedImages"`
+ Error *Error `json:"error,omitempty"`
+}
+
+type GeneratedImageItem struct {
+ Image GeneratedImageData `json:"image"`
+ MimeType string `json:"mimeType"`
+}
+
+type GeneratedImageData struct {
+ // ImageBytes is base64-encoded image data returned by :generateImages / :editImage
+ ImageBytes string `json:"imageBytes"`
+}
+
type ImagenPrediction struct {
BytesBase64Encoded string `json:"bytesBase64Encoded"`
MimeType string `json:"mimeType"`
+ GcsUri string `json:"gcsUri"`
+ Prompt string `json:"prompt"`
}
diff --git a/relay/adaptor/openai/adaptor.go b/relay/adaptor/openai/adaptor.go
index 8faf90a..85f7e72 100644
--- a/relay/adaptor/openai/adaptor.go
+++ b/relay/adaptor/openai/adaptor.go
@@ -109,7 +109,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Read
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
- err, responseText, usage = StreamHandler(c, resp, meta.Mode)
+ err, responseText, usage, meta.GenerationId = StreamHandler(c, resp, meta.Mode)
if usage == nil || usage.TotalTokens == 0 {
usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
}
diff --git a/relay/adaptor/openai/helper.go b/relay/adaptor/openai/helper.go
index c6d5bd7..c7e7d75 100644
--- a/relay/adaptor/openai/helper.go
+++ b/relay/adaptor/openai/helper.go
@@ -1,17 +1,77 @@
package openai
import (
+ "encoding/json"
"fmt"
+ "net/http"
"strings"
+ "time"
+ "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/model"
)
+var metadataHTTPClient = &http.Client{Timeout: 8 * time.Second}
+
+// usageMetadataResponse is a generic structure for upstream metadata endpoints
+// that return nativeTokens (e.g. Zenmux /api/v1/generation).
+type usageMetadataResponse struct {
+ NativeTokens struct {
+ PromptTokenCount int `json:"promptTokenCount"`
+ CandidatesTokenCount int `json:"candidatesTokenCount"`
+ TotalTokenCount int `json:"totalTokenCount"`
+ // ThoughtsTokenCount is billed at the completion rate, included in CompletionTokens.
+ ThoughtsTokenCount int `json:"thoughtsTokenCount"`
+ } `json:"nativeTokens"`
+}
+
+// FetchUsageFromMetadataURL fetches accurate token usage from an upstream metadata
+// endpoint. urlTemplate must contain {id} which is replaced with generationId.
+// Returns a nil usage and an error if the fetch fails or the response contains
+// no usable token data.
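+// Illustrative call (URL shape from the ChannelConfig.UsageAPIURL docs; the ID is a placeholder):
+//
+//	usage, err := FetchUsageFromMetadataURL("https://zenmux.ai/api/v1/generation?ctoken=xxx&id={id}", "gen-abc123")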
+func FetchUsageFromMetadataURL(urlTemplate, generationId string) (*model.Usage, error) {
+ url := strings.ReplaceAll(urlTemplate, "{id}", generationId)
+ resp, err := metadataHTTPClient.Get(url)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("metadata API returned status %d", resp.StatusCode)
+ }
+ var meta usageMetadataResponse
+ if err := json.NewDecoder(resp.Body).Decode(&meta); err != nil {
+ return nil, err
+ }
+ if meta.NativeTokens.TotalTokenCount == 0 {
+ return nil, fmt.Errorf("metadata API returned zero total tokens")
+ }
+ prompt := meta.NativeTokens.PromptTokenCount
+ completion := meta.NativeTokens.CandidatesTokenCount + meta.NativeTokens.ThoughtsTokenCount
+ logger.SysLog(fmt.Sprintf("usage from metadata API (id=%s): prompt=%d completion=%d", generationId, prompt, completion))
+ return &model.Usage{
+ PromptTokens: prompt,
+ CompletionTokens: completion,
+ TotalTokens: prompt + completion,
+ }, nil
+}
+
+// countOutputMediaTokens returns a fixed token estimate for any embedded
+// media data URIs found in text (image/video/audio), consistent with the
+// fixed estimates used for input media in CountTokenMessages.
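+// For example, text embedding two "data:image/..." URIs and one "data:audio/..."
+// URI is counted as 2*2500 + 1500 = 6500 tokens.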
+func countOutputMediaTokens(text string) int {
+ tokens := 0
+ tokens += strings.Count(text, "data:image/") * 2500
+ tokens += strings.Count(text, "data:video/") * 10000
+ tokens += strings.Count(text, "data:audio/") * 1500
+ return tokens
+}
+
func ResponseText2Usage(responseText string, modelName string, promptTokens int) *model.Usage {
usage := &model.Usage{}
usage.PromptTokens = promptTokens
- usage.CompletionTokens = CountTokenText(responseText, modelName)
+ usage.CompletionTokens = CountTokenText(stripBase64Payloads(responseText), modelName) +
+ countOutputMediaTokens(responseText)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return usage
}
diff --git a/relay/adaptor/openai/main.go b/relay/adaptor/openai/main.go
index 9708073..2023025 100644
--- a/relay/adaptor/openai/main.go
+++ b/relay/adaptor/openai/main.go
@@ -24,11 +24,16 @@ const (
dataPrefixLength = len(dataPrefix)
)
-func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) {
+// StreamHandler processes an OpenAI-compatible SSE stream and returns
+// (error, responseText, usage, generationId).
+// generationId is the "id" field captured from the first stream chunk; it can
+// be used by callers that need to fetch per-request metadata from an upstream API.
+func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage, string) {
responseText := ""
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
var usage *model.Usage
+ var generationId string
common.SetEventStreamHeaders(c)
@@ -63,6 +68,9 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
for _, choice := range streamResponse.Choices {
responseText += conv.AsString(choice.Delta.Content)
}
+ if generationId == "" && streamResponse.Id != "" {
+ generationId = streamResponse.Id
+ }
if streamResponse.Usage != nil {
usage = streamResponse.Usage
}
@@ -90,10 +98,10 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
err := resp.Body.Close()
if err != nil {
- return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil
+ return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil, ""
}
- return nil, responseText, usage
+ return nil, responseText, usage, generationId
}
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
diff --git a/relay/adaptor/openai/token.go b/relay/adaptor/openai/token.go
index 1d3a276..1bccccf 100644
--- a/relay/adaptor/openai/token.go
+++ b/relay/adaptor/openai/token.go
@@ -73,6 +73,33 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
return len(tokenEncoder.Encode(text, nil, nil))
}
+// stripBase64Payloads removes the base64 payload from data URIs so that huge
+// image/video/audio blobs do not inflate the local token estimate.
+// Example: "data:image/jpeg;base64,/9j/4AAQ..." → "data:image/jpeg;base64,"
+func stripBase64Payloads(s string) string {
+	const marker = ";base64,"
+	out := s
+	searchFrom := 0
+	for {
+		idx := strings.Index(out[searchFrom:], marker)
+		if idx < 0 {
+			break
+		}
+		payloadStart := searchFrom + idx + len(marker)
+		end := payloadStart
+		for end < len(out) {
+			c := out[end]
+			if (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') ||
+				(c >= '0' && c <= '9') || c == '+' || c == '/' || c == '=' {
+				end++
+			} else {
+				break
+			}
+		}
+		out = out[:payloadStart] + out[end:]
+		// Resume the search after this marker; re-scanning from the start
+		// would find the same (now payload-less) marker forever.
+		searchFrom = payloadStart
+	}
+	return out
+}
+
+
func CountTokenMessages(messages []model.Message, model string) int {
tokenEncoder := getTokenEncoder(model)
// Reference:
@@ -94,7 +121,7 @@ func CountTokenMessages(messages []model.Message, model string) int {
tokenNum += tokensPerMessage
switch v := message.Content.(type) {
case string:
- tokenNum += getTokenNum(tokenEncoder, v)
+ tokenNum += getTokenNum(tokenEncoder, stripBase64Payloads(v))
case []any:
for _, it := range v {
m := it.(map[string]any)
@@ -102,7 +129,7 @@ func CountTokenMessages(messages []model.Message, model string) int {
case "text":
if textValue, ok := m["text"]; ok {
if textString, ok := textValue.(string); ok {
- tokenNum += getTokenNum(tokenEncoder, textString)
+ tokenNum += getTokenNum(tokenEncoder, stripBase64Payloads(textString))
}
}
case "image_url":
@@ -114,12 +141,24 @@ func CountTokenMessages(messages []model.Message, model string) int {
detail = imageUrl["detail"].(string)
}
imageTokens, err := countImageTokens(url, detail, model)
- if err != nil {
- logger.SysError("error counting image tokens: " + err.Error())
- } else {
- tokenNum += imageTokens
+ if err != nil || imageTokens == 0 {
+ // Fallback for base64 images (payload stripped, dimensions
+ // unavailable) or any other failure: use a conservative
+ // fixed estimate (~768x768 high-detail image).
+ if err != nil {
+ logger.SysError("error counting image tokens, using fallback: " + err.Error())
+ }
+ imageTokens = 765
}
+ tokenNum += imageTokens
}
+ case "video_url":
+ // Videos cannot be tokenised locally; conservative fixed estimate.
+ // Final billing always uses the upstream's real usage.
+ tokenNum += 2000
+ case "input_audio", "audio":
+ // Audio cannot be tokenised locally; conservative fixed estimate.
+ tokenNum += 500
}
}
}
diff --git a/relay/adaptor/vertexai/gemini/adapter.go b/relay/adaptor/vertexai/gemini/adapter.go
index e4a8591..59db1a1 100644
--- a/relay/adaptor/vertexai/gemini/adapter.go
+++ b/relay/adaptor/vertexai/gemini/adapter.go
@@ -39,7 +39,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
- err, usage = gemini.StreamHandler(c, resp)
+ err, usage, meta.GenerationId = gemini.StreamHandler(c, resp)
if usage != nil {
usage.PromptTokens = meta.PromptTokens
}
diff --git a/relay/adaptor/zhipu/adaptor.go b/relay/adaptor/zhipu/adaptor.go
index 660bd37..f07cc6a 100644
--- a/relay/adaptor/zhipu/adaptor.go
+++ b/relay/adaptor/zhipu/adaptor.go
@@ -98,7 +98,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Read
func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
- err, _, usage = openai.StreamHandler(c, resp, meta.Mode)
+ err, _, usage, _ = openai.StreamHandler(c, resp, meta.Mode)
} else {
err, usage = openai.Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
diff --git a/relay/billing/ratio/image.go b/relay/billing/ratio/image.go
index c8c42a1..37c146d 100644
--- a/relay/billing/ratio/image.go
+++ b/relay/billing/ratio/image.go
@@ -11,6 +11,12 @@ var ImageSizeRatios = map[string]map[string]float64{
"1024x1792": 2,
"1792x1024": 2,
},
+	// gpt-image-1.5: base price is medium-quality 1024x1024 ($0.034); other sizes scale by ratio
+ "openai/gpt-image-1.5": {
+ "1024x1024": 1.0, // medium $0.034
+ "1024x1536": 1.47, // medium $0.050
+ "1536x1024": 1.47, // medium $0.050
+ },
"ali-stable-diffusion-xl": {
"512x1024": 1,
"1024x768": 1,
diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go
index e8b3b61..219d7f4 100644
--- a/relay/billing/ratio/model.go
+++ b/relay/billing/ratio/model.go
@@ -87,6 +87,14 @@ var ModelRatio = map[string]float64{
"text-moderation-latest": 0.1,
"dall-e-2": 0.02 * USD, // $0.016 - $0.020 / image
"dall-e-3": 0.04 * USD, // $0.040 - $0.120 / image
+ // Zenmux image generation models (calibrated for 1M quota/dollar)
+ "openai/gpt-image-1.5": 34, // $0.034 / image (medium 1024x1024 base)
+ "qwen/qwen-image-2.0": 28.9, // $0.0289 / image
+ "qwen/qwen-image-2.0-pro": 73, // $0.073 / image
+ "volcengine/doubao-seedream-5.0-lite": 32, // $0.032 / image
+ "z-ai/glm-image": 15, // $0.015 / image
+ "tencent/hunyuan-image3": 29, // $0.029 / image
+ "klingai/kling-v2": 14, // $0.014 / image
// https://docs.anthropic.com/en/docs/about-claude/models
"claude-instant-1.2": 0.8 / 1000 * USD,
"claude-2.0": 8.0 / 1000 * USD,
diff --git a/relay/controller/text.go b/relay/controller/text.go
index f912498..0a01a90 100644
--- a/relay/controller/text.go
+++ b/relay/controller/text.go
@@ -82,6 +82,16 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
return respErr
}
+ // If the channel has a UsageAPIURL configured and the adaptor captured a
+ // generation ID, fetch accurate multimodal token counts now (works for any
+ // channel type that sets meta.GenerationId during DoResponse).
+ if meta.Config.UsageAPIURL != "" && meta.GenerationId != "" {
+ if fetchedUsage, fetchErr := openai.FetchUsageFromMetadataURL(meta.Config.UsageAPIURL, meta.GenerationId); fetchErr == nil {
+ usage = fetchedUsage
+ } else {
+ logger.Warnf(ctx, "failed to fetch usage from metadata URL: %s", fetchErr.Error())
+ }
+ }
// post-consume quota
go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio, systemPromptReset)
return nil
diff --git a/relay/meta/relay_meta.go b/relay/meta/relay_meta.go
index 8c74ef8..501409f 100644
--- a/relay/meta/relay_meta.go
+++ b/relay/meta/relay_meta.go
@@ -35,6 +35,9 @@ type Meta struct {
PromptTokens int // only for DoResponse
ForcedSystemPrompt string
StartTime time.Time
+ // GenerationId is the upstream generation/request ID captured during DoResponse.
+ // Adaptors should set this when available; it is used for post-response metadata fetches.
+ GenerationId string
}
func GetByContext(c *gin.Context) *Meta {
diff --git a/relay/model/image.go b/relay/model/image.go
index bab8425..d7b8acf 100644
--- a/relay/model/image.go
+++ b/relay/model/image.go
@@ -3,6 +3,7 @@ package model
type ImageRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt" binding:"required"`
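+	// Image optionally carries a source image for image-to-image requests
+	// (/v1/images/edits); adaptors accept a base64 string or data URL.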
+ Image any `json:"image,omitempty"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
Quality string `json:"quality,omitempty"`
diff --git a/relay/relaymode/define.go b/relay/relaymode/define.go
index aa77120..ef74092 100644
--- a/relay/relaymode/define.go
+++ b/relay/relaymode/define.go
@@ -7,6 +7,7 @@ const (
Embeddings
Moderations
ImagesGenerations
+ ImagesEdits
Edits
AudioSpeech
AudioTranscription
diff --git a/relay/relaymode/helper.go b/relay/relaymode/helper.go
index 2cde5b8..c17d5b0 100644
--- a/relay/relaymode/helper.go
+++ b/relay/relaymode/helper.go
@@ -16,6 +16,8 @@ func GetByPath(path string) int {
relayMode = Moderations
} else if strings.HasPrefix(path, "/v1/images/generations") {
relayMode = ImagesGenerations
+ } else if strings.HasPrefix(path, "/v1/images/edits") {
+ relayMode = ImagesEdits
} else if strings.HasPrefix(path, "/v1/edits") {
relayMode = Edits
} else if strings.HasPrefix(path, "/v1/audio/speech") {
diff --git a/router/relay.go b/router/relay.go
index 8f3c730..554a64f 100644
--- a/router/relay.go
+++ b/router/relay.go
@@ -25,7 +25,7 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router.POST("/chat/completions", controller.Relay)
relayV1Router.POST("/edits", controller.Relay)
relayV1Router.POST("/images/generations", controller.Relay)
- relayV1Router.POST("/images/edits", controller.RelayNotImplemented)
+ relayV1Router.POST("/images/edits", controller.Relay)
relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
relayV1Router.POST("/embeddings", controller.Relay)
relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
diff --git a/web/berry/src/views/Channel/component/EditModal.js b/web/berry/src/views/Channel/component/EditModal.js
index c2520a9..5576a8d 100644
--- a/web/berry/src/views/Channel/component/EditModal.js
+++ b/web/berry/src/views/Channel/component/EditModal.js
@@ -44,7 +44,7 @@ const validationSchema = Yup.object().shape({
is: (is_edit, type) => !is_edit && type !== 33,
then: Yup.string().required('密钥 不能为空')
}),
- other: Yup.string(),
+ other: Yup.string().nullable(),
models: Yup.array().min(1, '模型 不能为空'),
groups: Yup.array().min(1, '用户组 不能为空'),
base_url: Yup.string().when('type', {
diff --git a/web/berry/src/views/Channel/component/TableRow.js b/web/berry/src/views/Channel/component/TableRow.js
index 525f918..b83308d 100644
--- a/web/berry/src/views/Channel/component/TableRow.js
+++ b/web/berry/src/views/Channel/component/TableRow.js
@@ -19,6 +19,9 @@ import {
DialogTitle,
Tooltip,
Button,
+ Select,
+ FormControl,
+ InputLabel,
} from "@mui/material";
import Label from "ui-component/Label";
@@ -28,7 +31,7 @@ import ResponseTimeLabel from "./ResponseTimeLabel";
import GroupLabel from "./GroupLabel";
import NameLabel from "./NameLabel";
-import { IconDotsVertical, IconEdit, IconTrash } from "@tabler/icons-react";
+import { IconDotsVertical, IconEdit, IconTrash, IconBrandSpeedtest } from "@tabler/icons-react";
export default function ChannelTableRow({
item,
@@ -36,9 +39,12 @@ export default function ChannelTableRow({
handleOpenModal,
setModalChannelId,
}) {
+ const modelOptions = item.models ? item.models.split(',').filter(Boolean) : [];
const [open, setOpen] = useState(null);
const [openDelete, setOpenDelete] = useState(false);
+ const [openTest, setOpenTest] = useState(false);
const [statusSwitch, setStatusSwitch] = useState(item.status);
+ const [testModel, setTestModel] = useState(modelOptions[0] || '');
const [priorityValve, setPriority] = useState(item.priority);
const [responseTimeData, setResponseTimeData] = useState({
test_time: item.test_time,
@@ -87,7 +93,8 @@ export default function ChannelTableRow({
};
const handleResponseTime = async () => {
- const { success, time } = await manageChannel(item.id, "test", "");
+ setOpenTest(false);
+ const { success, time } = await manageChannel(item.id, "test", testModel);
if (success) {
setResponseTimeData({
test_time: Date.now() / 1000,
@@ -223,6 +230,15 @@ export default function ChannelTableRow({
编辑
+