feat: 添加视频URL支持及Zenmux集成
refactor: 重构Gemini适配器以支持多模态输入 fix: 修复React Hooks依赖警告 style: 清理未使用的导入和代码 docs: 更新用户界面文本和提示 perf: 优化图像和视频URL处理性能 test: 添加数据迁移工具和测试 build: 更新依赖项和.gitignore chore: 同步Zenmux模型和价格比例
This commit is contained in:
parent
885ad0507b
commit
c734c541b2
1
.gitignore
vendored
1
.gitignore
vendored
@ -13,3 +13,4 @@ cmd.md
|
||||
/one-api
|
||||
temp
|
||||
.DS_Store
|
||||
.claude
|
||||
@ -16,8 +16,8 @@ import (
|
||||
_ "golang.org/x/image/webp"
|
||||
)
|
||||
|
||||
// Regex to match data URL pattern
|
||||
var dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`)
|
||||
// Regex to match data URL pattern — supports image, video, and audio MIME types
|
||||
var dataURLPattern = regexp.MustCompile(`data:([^;]+);base64,(.*)`)
|
||||
|
||||
func IsImageUrl(url string) (bool, error) {
|
||||
resp, err := client.UserContentRequestHTTPClient.Head(url)
|
||||
@ -51,8 +51,8 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
||||
// Check if the URL is a data URL
|
||||
matches := dataURLPattern.FindStringSubmatch(url)
|
||||
if len(matches) == 3 {
|
||||
// URL is a data URL
|
||||
mimeType = "image/" + matches[1]
|
||||
// URL is a data URL — matches[1] is the full MIME type (e.g. "video/mp4", "audio/webm", "image/png")
|
||||
mimeType = matches[1]
|
||||
data = matches[2]
|
||||
return
|
||||
}
|
||||
|
||||
@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@ -12,6 +13,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
relay "github.com/songquanpeng/one-api/relay"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
"github.com/songquanpeng/one-api/relay/apitype"
|
||||
@ -224,12 +226,15 @@ type ModelCatalogItem struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Provider string `json:"provider"`
|
||||
Description string `json:"description"`
|
||||
InputModalities []string `json:"input_modalities"`
|
||||
OutputModalities []string `json:"output_modalities"`
|
||||
ContextLength int `json:"context_length"`
|
||||
MaxOutputTokens int `json:"max_output_tokens"`
|
||||
InputPrice float64 `json:"input_price"`
|
||||
OutputPrice float64 `json:"output_price"`
|
||||
// PricingDiscount: e.g. "0.8" means 20% off; applied to both input and output prices
|
||||
PricingDiscount string `json:"pricing_discount"`
|
||||
Tags []string `json:"tags"`
|
||||
// SupportsReasoning: 0=none, 1=always-on, 2=toggleable
|
||||
SupportsReasoning int `json:"supports_reasoning"`
|
||||
@ -241,12 +246,14 @@ type zenmuxModel struct {
|
||||
Slug string `json:"slug"`
|
||||
Name string `json:"name"`
|
||||
Author string `json:"author"`
|
||||
Description string `json:"description"`
|
||||
InputModalities string `json:"input_modalities"` // comma-separated
|
||||
OutputModalities string `json:"output_modalities"` // comma-separated
|
||||
ContextLength int `json:"context_length"`
|
||||
MaxCompletionTokens int `json:"max_completion_tokens"`
|
||||
PricingPrompt string `json:"pricing_prompt"`
|
||||
PricingCompletion string `json:"pricing_completion"`
|
||||
PricingDiscount string `json:"pricing_discount"`
|
||||
SuitableApi string `json:"suitable_api"` // comma-separated
|
||||
SupportsReasoning int `json:"supports_reasoning"`
|
||||
SupportedParameters string `json:"supported_parameters"` // comma-separated
|
||||
@ -324,12 +331,14 @@ func zenmuxToItem(z zenmuxModel) ModelCatalogItem {
|
||||
Id: z.Slug,
|
||||
Name: z.Name,
|
||||
Provider: z.Author,
|
||||
Description: z.Description,
|
||||
InputModalities: inMod,
|
||||
OutputModalities: outMod,
|
||||
ContextLength: z.ContextLength,
|
||||
MaxOutputTokens: z.MaxCompletionTokens,
|
||||
InputPrice: parsePrice(z.PricingPrompt),
|
||||
OutputPrice: parsePrice(z.PricingCompletion),
|
||||
PricingDiscount: z.PricingDiscount,
|
||||
Tags: tags,
|
||||
SupportsReasoning: z.SupportsReasoning,
|
||||
SuitableApi: z.SuitableApi,
|
||||
@ -375,25 +384,50 @@ func fetchZenmuxCatalog() ([]ModelCatalogItem, error) {
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// GetTokenQuota returns quota info for the authenticated relay token.
|
||||
// GET /api/token-quota — requires TokenAuth middleware.
|
||||
func GetTokenQuota(c *gin.Context) {
|
||||
tokenId := c.GetInt(ctxkey.TokenId)
|
||||
token, err := model.GetTokenById(tokenId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": gin.H{
|
||||
"remain_quota": token.RemainQuota,
|
||||
"used_quota": token.UsedQuota,
|
||||
"unlimited_quota": token.UnlimitedQuota,
|
||||
"expired_time": token.ExpiredTime,
|
||||
"name": token.Name,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// GetModelCatalog returns models with capability metadata, filtered to only those
|
||||
// available to the requesting user's group.
|
||||
func GetModelCatalog(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
id := c.GetInt(ctxkey.Id)
|
||||
|
||||
var availableModels []string
|
||||
if c.GetString(ctxkey.AvailableModels) != "" {
|
||||
availableModels = strings.Split(c.GetString(ctxkey.AvailableModels), ",")
|
||||
} else {
|
||||
id := c.GetInt(ctxkey.Id)
|
||||
userGroup, err := model.CacheGetUserGroup(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
availableModels, err := model.CacheGetGroupModels(ctx, userGroup)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
var groupErr error
|
||||
availableModels, groupErr = model.CacheGetGroupModels(ctx, userGroup)
|
||||
if groupErr != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": groupErr.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Build a set for O(1) lookup
|
||||
available := make(map[string]bool, len(availableModels))
|
||||
for _, m := range availableModels {
|
||||
available[m] = true
|
||||
@ -401,7 +435,6 @@ func GetModelCatalog(c *gin.Context) {
|
||||
|
||||
catalog, err := fetchZenmuxCatalog()
|
||||
if err != nil {
|
||||
// Return empty catalog with warning rather than failing hard
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "model catalog unavailable: " + err.Error(),
|
||||
@ -410,7 +443,6 @@ func GetModelCatalog(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Filter to models actually available to this user
|
||||
result := make([]ModelCatalogItem, 0, len(catalog))
|
||||
for _, item := range catalog {
|
||||
if available[item.Id] {
|
||||
@ -425,3 +457,277 @@ func GetModelCatalog(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// zenmuxChannelDef describes one Zenmux protocol → one-api channel mapping.
|
||||
type zenmuxChannelDef struct {
|
||||
Name string
|
||||
Type int
|
||||
Protocols []string // suitable_api values to include
|
||||
}
|
||||
|
||||
// allZenmuxChannelDefs is the canonical channel mapping for Zenmux.
|
||||
//
|
||||
// Zenmux exposes a single OpenAI-compatible ingress at /api/v1 for ALL text models
|
||||
// (chat.completions, responses, messages, gemini, generate).
|
||||
// suitable_api is upstream metadata only — Zenmux always accepts OpenAI format on input.
|
||||
//
|
||||
// Imagen and Veo use https://zenmux.ai/api/vertex-ai (Vertex AI SDK) and are
|
||||
// not handled here (one-api's relay framework doesn't support that format yet).
|
||||
var allZenmuxChannelDefs = []zenmuxChannelDef{
|
||||
{
|
||||
Name: "Zenmux",
|
||||
Type: 50, // OpenAICompatible
|
||||
Protocols: []string{"chat.completions", "responses", "messages", "gemini", "generate"},
|
||||
},
|
||||
}
|
||||
|
||||
// GetZenmuxProtocols returns all distinct suitable_api values found in the Zenmux catalog.
|
||||
// GET /api/zenmux/protocols — admin only, for debugging protocol → channel mapping.
|
||||
func GetZenmuxProtocols(c *gin.Context) {
|
||||
catalog, err := fetchZenmuxCatalog()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
seen := make(map[string]int)
|
||||
for _, item := range catalog {
|
||||
for _, api := range splitCSV(item.SuitableApi) {
|
||||
api = strings.TrimSpace(api)
|
||||
if api != "" {
|
||||
seen[api]++
|
||||
}
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "data": seen})
|
||||
}
|
||||
|
||||
// SetupZenmuxChannels creates all Zenmux protocol channels and syncs their model lists.
|
||||
// POST /api/zenmux/setup — body: {"key":"<zenmux-api-key>","base_url":"https://zenmux.ai"}
|
||||
// Skips channels that already exist (matched by name). Idempotent.
|
||||
func SetupZenmuxChannels(c *gin.Context) {
|
||||
var req struct {
|
||||
Key string `json:"key"`
|
||||
BaseURL string `json:"base_url"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil || req.Key == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "key is required"})
|
||||
return
|
||||
}
|
||||
baseURL := req.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = "https://zenmux.ai/api/v1"
|
||||
}
|
||||
|
||||
catalog, err := fetchZenmuxCatalog()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "fetch zenmux catalog failed: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
type result struct {
|
||||
Name string `json:"name"`
|
||||
Models int `json:"models"`
|
||||
Action string `json:"action"` // "created" | "skipped"
|
||||
}
|
||||
results := make([]result, 0, len(allZenmuxChannelDefs))
|
||||
|
||||
for _, def := range allZenmuxChannelDefs {
|
||||
// Build protocol set for fast lookup.
|
||||
wantProtos := make(map[string]bool, len(def.Protocols))
|
||||
for _, p := range def.Protocols {
|
||||
wantProtos[strings.ToLower(p)] = true
|
||||
}
|
||||
|
||||
// Filter catalog to matching protocols.
|
||||
slugs := make([]string, 0)
|
||||
for _, item := range catalog {
|
||||
for _, api := range splitCSV(item.SuitableApi) {
|
||||
if wantProtos[strings.ToLower(strings.TrimSpace(api))] {
|
||||
slugs = append(slugs, item.Id)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
modelList := strings.Join(slugs, ",")
|
||||
|
||||
// Check if a channel with this name already exists.
|
||||
var existing []model.Channel
|
||||
model.DB.Where("name = ?", def.Name).Find(&existing)
|
||||
if len(existing) > 0 {
|
||||
// Update model list on the first match.
|
||||
ch := existing[0]
|
||||
ch.Models = modelList
|
||||
_ = ch.Update()
|
||||
results = append(results, result{Name: def.Name, Models: len(slugs), Action: "updated"})
|
||||
continue
|
||||
}
|
||||
|
||||
// Create new channel.
|
||||
priority := int64(0)
|
||||
weight := uint(1)
|
||||
ch := model.Channel{
|
||||
Type: def.Type,
|
||||
Name: def.Name,
|
||||
Key: req.Key,
|
||||
BaseURL: &baseURL,
|
||||
Models: modelList,
|
||||
Status: 1,
|
||||
Group: "default",
|
||||
Priority: &priority,
|
||||
Weight: &weight,
|
||||
}
|
||||
if err := ch.Insert(); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": fmt.Sprintf("create channel %q failed: %s", def.Name, err.Error())})
|
||||
return
|
||||
}
|
||||
results = append(results, result{Name: def.Name, Models: len(slugs), Action: "created"})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": fmt.Sprintf("setup complete: %d channels processed", len(results)),
|
||||
"data": results,
|
||||
})
|
||||
}
|
||||
|
||||
// SyncZenmuxModels fetches all models from Zenmux and updates the channel's model list.
|
||||
// POST /api/channel/:id/sync-zenmux
|
||||
// Optional query param: ?protocol=chat.completions
|
||||
// Accepts comma-separated values to match multiple protocols, e.g.
|
||||
// ?protocol=google.gemini,google.imagen,google.video
|
||||
func SyncZenmuxModels(c *gin.Context) {
|
||||
channelId, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "invalid channel id"})
|
||||
return
|
||||
}
|
||||
|
||||
// Build a set of requested protocols (lowercased) for O(1) lookup.
|
||||
protocolParam := strings.TrimSpace(c.Query("protocol"))
|
||||
wantProtocols := make(map[string]bool)
|
||||
for _, p := range splitCSV(protocolParam) {
|
||||
if p != "" {
|
||||
wantProtocols[strings.ToLower(p)] = true
|
||||
}
|
||||
}
|
||||
|
||||
catalog, err := fetchZenmuxCatalog()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
slugs := make([]string, 0, len(catalog))
|
||||
for _, item := range catalog {
|
||||
if len(wantProtocols) > 0 {
|
||||
matched := false
|
||||
for _, api := range splitCSV(item.SuitableApi) {
|
||||
if wantProtocols[strings.ToLower(strings.TrimSpace(api))] {
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !matched {
|
||||
continue
|
||||
}
|
||||
}
|
||||
slugs = append(slugs, item.Id)
|
||||
}
|
||||
modelList := strings.Join(slugs, ",")
|
||||
|
||||
ch, err := model.GetChannelById(channelId, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel not found"})
|
||||
return
|
||||
}
|
||||
ch.Models = modelList
|
||||
if err := ch.Update(); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("synced %d models to channel", len(slugs))
|
||||
if protocolParam != "" {
|
||||
msg = fmt.Sprintf("synced %d models (protocol=%s) to channel", len(slugs), protocolParam)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": msg,
|
||||
"data": len(slugs),
|
||||
})
|
||||
}
|
||||
|
||||
// SyncZenmuxRatios fetches the Zenmux model catalog and updates one-api's
|
||||
// ModelRatio and CompletionRatio maps so every Zenmux model has correct billing.
|
||||
//
|
||||
// Ratio formula:
|
||||
//
|
||||
// modelRatio = input_price ($/1M tokens, raw)
|
||||
// completionRatio = output_price / input_price (relative to input)
|
||||
//
|
||||
// POST /api/zenmux/sync-ratios — admin only.
|
||||
func SyncZenmuxRatios(c *gin.Context) {
|
||||
catalog, err := fetchZenmuxCatalog()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "fetch catalog failed: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Load current ratios so we only ADD/UPDATE Zenmux entries, not wipe custom ones.
|
||||
var modelRatioMap map[string]float64
|
||||
var completionRatioMap map[string]float64
|
||||
_ = json.Unmarshal([]byte(billingratio.ModelRatio2JSONString()), &modelRatioMap)
|
||||
_ = json.Unmarshal([]byte(billingratio.CompletionRatio2JSONString()), &completionRatioMap)
|
||||
if modelRatioMap == nil {
|
||||
modelRatioMap = make(map[string]float64)
|
||||
}
|
||||
if completionRatioMap == nil {
|
||||
completionRatioMap = make(map[string]float64)
|
||||
}
|
||||
|
||||
updated, skipped := 0, 0
|
||||
for _, item := range catalog {
|
||||
if item.InputPrice <= 0 {
|
||||
skipped++
|
||||
continue
|
||||
}
|
||||
modelRatioMap[item.Id] = item.InputPrice
|
||||
if item.OutputPrice > 0 {
|
||||
completionRatioMap[item.Id] = item.OutputPrice / item.InputPrice
|
||||
}
|
||||
updated++
|
||||
}
|
||||
|
||||
newModelJSON, err := json.Marshal(modelRatioMap)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "marshal model ratio failed: " + err.Error()})
|
||||
return
|
||||
}
|
||||
if err = billingratio.UpdateModelRatioByJSONString(string(newModelJSON)); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "update model ratio failed: " + err.Error()})
|
||||
return
|
||||
}
|
||||
if err = model.UpdateOption("ModelRatio", string(newModelJSON)); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "save model ratio failed: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
newCompletionJSON, err := json.Marshal(completionRatioMap)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "marshal completion ratio failed: " + err.Error()})
|
||||
return
|
||||
}
|
||||
if err = billingratio.UpdateCompletionRatioByJSONString(string(newCompletionJSON)); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "update completion ratio failed: " + err.Error()})
|
||||
return
|
||||
}
|
||||
if err = model.UpdateOption("CompletionRatio", string(newCompletionJSON)); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "save completion ratio failed: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": fmt.Sprintf("synced ratios for %d models (%d skipped — no pricing)", updated, skipped),
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
@ -25,25 +27,29 @@ func (a *Adaptor) Init(meta *meta.Meta) {
|
||||
|
||||
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
||||
defaultVersion := config.GeminiVersion
|
||||
if strings.Contains(meta.ActualModelName, "gemini-2.0") ||
|
||||
strings.Contains(meta.ActualModelName, "gemini-1.5") {
|
||||
modelLower := strings.ToLower(meta.ActualModelName)
|
||||
if strings.Contains(modelLower, "gemini-1.5") ||
|
||||
strings.Contains(modelLower, "gemini-2.") ||
|
||||
strings.Contains(modelLower, "gemini-3.") {
|
||||
defaultVersion = "v1beta"
|
||||
}
|
||||
|
||||
version := helper.AssignOrDefault(meta.Config.APIVersion, defaultVersion)
|
||||
action := ""
|
||||
|
||||
modelName := meta.ActualModelName
|
||||
|
||||
switch meta.Mode {
|
||||
case relaymode.ImagesGenerations:
|
||||
return fmt.Sprintf("%s/v1beta/models/%s:predict", meta.BaseURL, modelName), nil
|
||||
case relaymode.Embeddings:
|
||||
action = "batchEmbedContents"
|
||||
default:
|
||||
action = "generateContent"
|
||||
return fmt.Sprintf("%s/%s/models/%s:batchEmbedContents", meta.BaseURL, version, modelName), nil
|
||||
}
|
||||
|
||||
action := "generateContent"
|
||||
if meta.IsStream {
|
||||
action = "streamGenerateContent?alt=sse"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil
|
||||
return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, modelName, action), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
|
||||
@ -70,7 +76,31 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error)
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return request, nil
|
||||
n := request.N
|
||||
if n <= 0 {
|
||||
n = 1
|
||||
}
|
||||
return ImagenRequest{
|
||||
Instances: []ImagenInstance{{Prompt: request.Prompt}},
|
||||
Parameters: ImagenParameters{
|
||||
SampleCount: n,
|
||||
AspectRatio: sizeToAspectRatio(request.Size),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// sizeToAspectRatio converts OpenAI size strings to Imagen aspect ratios.
|
||||
func sizeToAspectRatio(size string) string {
|
||||
switch size {
|
||||
case "1792x1024":
|
||||
return "16:9"
|
||||
case "1024x1792":
|
||||
return "9:16"
|
||||
case "1024x1024", "":
|
||||
return "1:1"
|
||||
default:
|
||||
return "1:1"
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
|
||||
@ -79,13 +109,16 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Read
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||
if meta.IsStream {
|
||||
var responseText string
|
||||
err, responseText = StreamHandler(c, resp)
|
||||
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
|
||||
err, usage = StreamHandler(c, resp)
|
||||
if usage != nil {
|
||||
usage.PromptTokens = meta.PromptTokens
|
||||
}
|
||||
} else {
|
||||
switch meta.Mode {
|
||||
case relaymode.Embeddings:
|
||||
err, usage = EmbeddingHandler(c, resp)
|
||||
case relaymode.ImagesGenerations:
|
||||
err = ImagenHandler(c, resp)
|
||||
default:
|
||||
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
|
||||
}
|
||||
@ -93,6 +126,48 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
|
||||
return
|
||||
}
|
||||
|
||||
// ImagenHandler converts a Google Imagen predict response to OpenAI image response format.
|
||||
func ImagenHandler(c *gin.Context, resp *http.Response) *model.ErrorWithStatusCode {
|
||||
responseBody, readErr := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if readErr != nil {
|
||||
return openai.ErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var imagenResp ImagenResponse
|
||||
if err := json.Unmarshal(responseBody, &imagenResp); err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if imagenResp.Error != nil {
|
||||
return openai.ErrorWrapper(
|
||||
fmt.Errorf("%s", imagenResp.Error.Message),
|
||||
"imagen_error",
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
|
||||
// Convert to OpenAI image response format
|
||||
data := make([]openai.ImageData, 0, len(imagenResp.Predictions))
|
||||
for _, p := range imagenResp.Predictions {
|
||||
data = append(data, openai.ImageData{B64Json: p.BytesBase64Encoded})
|
||||
}
|
||||
openaiResp := openai.ImageResponse{
|
||||
Created: time.Now().Unix(),
|
||||
Data: data,
|
||||
}
|
||||
|
||||
jsonBytes, err := json.Marshal(openaiResp)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_response_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
c.Writer.Write(jsonBytes)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
@ -100,3 +175,4 @@ func (a *Adaptor) GetModelList() []string {
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return "google gemini"
|
||||
}
|
||||
|
||||
|
||||
@ -34,6 +34,29 @@ var mimeTypeMap = map[string]string{
|
||||
"text": "text/plain",
|
||||
}
|
||||
|
||||
// sanitizeSchema recursively removes JSON Schema keywords unsupported by Gemini
|
||||
// (e.g. "const", "$schema", "additionalProperties") from a schema map.
|
||||
func sanitizeSchema(v interface{}) interface{} {
|
||||
switch val := v.(type) {
|
||||
case map[string]interface{}:
|
||||
// Only remove fields Gemini explicitly rejects; leave others intact
|
||||
unsupported := []string{"const", "$schema", "additionalProperties"}
|
||||
for _, key := range unsupported {
|
||||
delete(val, key)
|
||||
}
|
||||
for k, child := range val {
|
||||
val[k] = sanitizeSchema(child)
|
||||
}
|
||||
return val
|
||||
case []interface{}:
|
||||
for i, item := range val {
|
||||
val[i] = sanitizeSchema(item)
|
||||
}
|
||||
return val
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
|
||||
geminiRequest := ChatRequest{
|
||||
@ -75,10 +98,25 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
|
||||
geminiRequest.GenerationConfig.ResponseMimeType = mimeTypeMap["json_object"]
|
||||
}
|
||||
}
|
||||
// For models that support image generation (e.g. gemini-2.5-flash-image),
|
||||
// request both TEXT and IMAGE modalities so the model returns inline images.
|
||||
if strings.Contains(strings.ToLower(textRequest.Model), "image") {
|
||||
geminiRequest.GenerationConfig.ResponseModalities = []string{"TEXT", "IMAGE"}
|
||||
}
|
||||
|
||||
// Enable thinking when the client explicitly requests it via enable_thinking=true.
|
||||
// Use thinkingBudget=-1 (dynamic) so Gemini decides the appropriate budget.
|
||||
if textRequest.EnableThinking {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ThinkingBudget: -1}
|
||||
}
|
||||
if textRequest.Tools != nil {
|
||||
functions := make([]model.Function, 0, len(textRequest.Tools))
|
||||
for _, tool := range textRequest.Tools {
|
||||
functions = append(functions, tool.Function)
|
||||
fn := tool.Function
|
||||
if fn.Parameters != nil {
|
||||
fn.Parameters = sanitizeSchema(fn.Parameters)
|
||||
}
|
||||
functions = append(functions, fn)
|
||||
}
|
||||
geminiRequest.Tools = []ChatTools{
|
||||
{
|
||||
@ -92,8 +130,45 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
|
||||
},
|
||||
}
|
||||
}
|
||||
// Build a map from tool_call_id → function name for resolving tool result names
|
||||
toolCallIdToName := map[string]string{}
|
||||
for _, message := range textRequest.Messages {
|
||||
if message.Role == "assistant" {
|
||||
for _, tc := range message.ToolCalls {
|
||||
if tc.Id != "" && tc.Function.Name != "" {
|
||||
toolCallIdToName[tc.Id] = tc.Function.Name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
shouldAddDummyModelMessage := false
|
||||
for _, message := range textRequest.Messages {
|
||||
// --- tool result: role=tool → Gemini functionResponse (user role) ---
|
||||
if message.Role == "tool" {
|
||||
toolName := message.ToolCallId
|
||||
if name, ok := toolCallIdToName[message.ToolCallId]; ok {
|
||||
toolName = name
|
||||
} else if message.Name != nil && *message.Name != "" {
|
||||
toolName = *message.Name
|
||||
}
|
||||
if toolName == "" {
|
||||
toolName = "unknown_tool"
|
||||
}
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, ChatContent{
|
||||
Role: "user",
|
||||
Parts: []Part{
|
||||
{
|
||||
FunctionResponse: &FunctionResponse{
|
||||
Name: toolName,
|
||||
Response: map[string]any{"content": message.StringContent()},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
content := ChatContent{
|
||||
Role: message.Role,
|
||||
Parts: []Part{
|
||||
@ -111,11 +186,24 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
|
||||
Text: part.Text,
|
||||
})
|
||||
} else if part.Type == model.ContentTypeImageURL {
|
||||
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
|
||||
// Only count images toward the image limit; video/audio have no such limit
|
||||
isImage := strings.HasPrefix(mimeType, "image/")
|
||||
if isImage {
|
||||
imageNum += 1
|
||||
if imageNum > VisionMaxImageNum {
|
||||
continue
|
||||
}
|
||||
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
|
||||
}
|
||||
parts = append(parts, Part{
|
||||
InlineData: &InlineData{
|
||||
MimeType: mimeType,
|
||||
Data: data,
|
||||
},
|
||||
})
|
||||
} else if part.Type == model.ContentTypeVideoURL {
|
||||
mimeType, data, _ := image.GetImageFromUrl(part.VideoURL.Url)
|
||||
if data != "" {
|
||||
parts = append(parts, Part{
|
||||
InlineData: &InlineData{
|
||||
MimeType: mimeType,
|
||||
@ -123,7 +211,56 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
|
||||
},
|
||||
})
|
||||
}
|
||||
} else if part.Type == model.ContentTypeInputAudio {
|
||||
// input_audio: { data: "base64...", format: "mp3" }
|
||||
// Convert directly to Gemini inlineData — bypasses Zenmux fileUri conversion
|
||||
// that occurs when audio is embedded in image_url.
|
||||
if part.InputAudio != nil && part.InputAudio.Data != "" {
|
||||
mimeType := "audio/" + part.InputAudio.Format
|
||||
if part.InputAudio.Format == "" {
|
||||
mimeType = "audio/webm"
|
||||
}
|
||||
parts = append(parts, Part{
|
||||
InlineData: &InlineData{
|
||||
MimeType: mimeType,
|
||||
Data: part.InputAudio.Data,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- assistant with tool_calls → Gemini functionCall parts ---
|
||||
if message.Role == "assistant" && len(message.ToolCalls) > 0 {
|
||||
var fcParts []Part
|
||||
// Include any text content first
|
||||
for _, p := range parts {
|
||||
if p.Text != "" {
|
||||
fcParts = append(fcParts, p)
|
||||
}
|
||||
}
|
||||
for _, tc := range message.ToolCalls {
|
||||
var args any
|
||||
if argStr, ok := tc.Function.Arguments.(string); ok && argStr != "" {
|
||||
if err := json.Unmarshal([]byte(argStr), &args); err != nil {
|
||||
args = map[string]any{}
|
||||
}
|
||||
} else {
|
||||
args = map[string]any{}
|
||||
}
|
||||
fcParts = append(fcParts, Part{
|
||||
FunctionCall: &FunctionCall{
|
||||
FunctionName: tc.Function.Name,
|
||||
Arguments: args,
|
||||
},
|
||||
})
|
||||
}
|
||||
content.Role = "model"
|
||||
content.Parts = fcParts
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, content)
|
||||
continue
|
||||
}
|
||||
|
||||
content.Parts = parts
|
||||
|
||||
// there's no assistant role in gemini and API shall vomit if Role is not user or model
|
||||
@ -184,9 +321,16 @@ func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *BatchEmbedding
|
||||
}
|
||||
}
|
||||
|
||||
type UsageMetadata struct {
|
||||
PromptTokenCount int `json:"promptTokenCount"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||
TotalTokenCount int `json:"totalTokenCount"`
|
||||
}
|
||||
|
||||
type ChatResponse struct {
|
||||
Candidates []ChatCandidate `json:"candidates"`
|
||||
PromptFeedback ChatPromptFeedback `json:"promptFeedback"`
|
||||
UsageMetadata *UsageMetadata `json:"usageMetadata"`
|
||||
}
|
||||
|
||||
func (g *ChatResponse) GetResponseText() string {
|
||||
@ -278,8 +422,33 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
|
||||
|
||||
func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = geminiResponse.GetResponseText()
|
||||
//choice.FinishReason = &constant.StopFinishReason
|
||||
|
||||
if len(geminiResponse.Candidates) > 0 {
|
||||
var textBuilder strings.Builder
|
||||
var thinkingBuilder strings.Builder
|
||||
for _, part := range geminiResponse.Candidates[0].Content.Parts {
|
||||
if part.Thought {
|
||||
// Thinking/reasoning content — route to reasoning_content field
|
||||
thinkingBuilder.WriteString(part.Text)
|
||||
} else if part.Text != "" {
|
||||
textBuilder.WriteString(part.Text)
|
||||
} else if part.InlineData != nil && part.InlineData.Data != "" {
|
||||
// Inline image — embed as markdown data-URI so it passes through the SSE pipeline
|
||||
mimeType := part.InlineData.MimeType
|
||||
if mimeType == "" {
|
||||
mimeType = "image/png"
|
||||
}
|
||||
textBuilder.WriteString(fmt.Sprintf("", mimeType, part.InlineData.Data))
|
||||
}
|
||||
}
|
||||
if textBuilder.Len() > 0 {
|
||||
choice.Delta.Content = textBuilder.String()
|
||||
}
|
||||
if thinkingBuilder.Len() > 0 {
|
||||
choice.Delta.ReasoningContent = thinkingBuilder.String()
|
||||
}
|
||||
}
|
||||
|
||||
var response openai.ChatCompletionsStreamResponse
|
||||
response.Id = fmt.Sprintf("chatcmpl-%s", random.GetUUID())
|
||||
response.Created = helper.GetTimestamp()
|
||||
@ -306,9 +475,14 @@ func embeddingResponseGemini2OpenAI(response *EmbeddingResponse) *openai.Embeddi
|
||||
return &openAIEmbeddingResponse
|
||||
}
|
||||
|
||||
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
|
||||
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
var usage *model.Usage
|
||||
responseText := ""
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
// Default bufio.Scanner buffer is 64KB which is too small for inline image data (base64).
|
||||
// Allocate 20MB to handle large image payloads from Gemini image-generation models.
|
||||
const maxScanTokenSize = 20 * 1024 * 1024
|
||||
scanner.Buffer(make([]byte, maxScanTokenSize), maxScanTokenSize)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
common.SetEventStreamHeaders(c)
|
||||
@ -329,12 +503,31 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract usageMetadata from the last chunk that carries it.
|
||||
// This includes image/video/audio generation costs that cannot be
|
||||
// estimated from text tokenisation alone.
|
||||
if geminiResponse.UsageMetadata != nil {
|
||||
usage = &model.Usage{
|
||||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
|
||||
}
|
||||
}
|
||||
|
||||
response := streamResponseGeminiChat2OpenAI(&geminiResponse)
|
||||
if response == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
responseText += response.Choices[0].Delta.StringContent()
|
||||
// Accumulate text for fallback token estimation (used only when
|
||||
// usageMetadata is absent from the stream).
|
||||
if len(geminiResponse.Candidates) > 0 {
|
||||
for _, part := range geminiResponse.Candidates[0].Content.Parts {
|
||||
if part.InlineData == nil {
|
||||
responseText += part.Text
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = render.ObjectData(c, response)
|
||||
if err != nil {
|
||||
@ -350,10 +543,15 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
|
||||
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
return nil, responseText
|
||||
// If upstream provided usageMetadata, use it (accurate, includes image costs).
|
||||
// Otherwise fall back to local tiktoken estimation on text-only content.
|
||||
if usage != nil {
|
||||
return nil, usage
|
||||
}
|
||||
return nil, openai.ResponseText2Usage(responseText, "", 0)
|
||||
}
|
||||
|
||||
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
|
||||
|
||||
@ -45,10 +45,18 @@ type FunctionCall struct {
|
||||
Arguments any `json:"args"`
|
||||
}
|
||||
|
||||
type FunctionResponse struct {
|
||||
Name string `json:"name"`
|
||||
Response any `json:"response"`
|
||||
}
|
||||
|
||||
type Part struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
InlineData *InlineData `json:"inlineData,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
|
||||
// Thought marks this part as internal reasoning/thinking content (Gemini thinking models)
|
||||
Thought bool `json:"thought,omitempty"`
|
||||
}
|
||||
|
||||
type ChatContent struct {
|
||||
@ -65,13 +73,49 @@ type ChatTools struct {
|
||||
FunctionDeclarations any `json:"function_declarations,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiThinkingConfig struct {
|
||||
ThinkingBudget int `json:"thinkingBudget"` // -1 = dynamic, 0 = disabled
|
||||
}
|
||||
|
||||
type ChatGenerationConfig struct {
|
||||
ResponseMimeType string `json:"responseMimeType,omitempty"`
|
||||
ResponseSchema any `json:"responseSchema,omitempty"`
|
||||
ResponseModalities []string `json:"responseModalities,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"topP,omitempty"`
|
||||
TopK float64 `json:"topK,omitempty"`
|
||||
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||
}
|
||||
|
||||
// --- Google Imagen ---
|
||||
|
||||
// ImagenRequest is the request body for Google Imagen (predict endpoint).
|
||||
// POST /v1beta/models/{model}:predict
|
||||
type ImagenRequest struct {
|
||||
Instances []ImagenInstance `json:"instances"`
|
||||
Parameters ImagenParameters `json:"parameters"`
|
||||
}
|
||||
|
||||
type ImagenInstance struct {
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
|
||||
type ImagenParameters struct {
|
||||
SampleCount int `json:"sampleCount"`
|
||||
AspectRatio string `json:"aspectRatio,omitempty"`
|
||||
}
|
||||
|
||||
// ImagenResponse is the response from Google Imagen.
|
||||
type ImagenResponse struct {
|
||||
Predictions []ImagenPrediction `json:"predictions"`
|
||||
Error *Error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type ImagenPrediction struct {
|
||||
BytesBase64Encoded string `json:"bytesBase64Encoded"`
|
||||
MimeType string `json:"mimeType"`
|
||||
}
|
||||
|
||||
|
||||
@ -146,6 +146,11 @@ const (
|
||||
// https://platform.openai.com/docs/guides/vision/calculating-costs
|
||||
// https://github.com/openai/openai-cookbook/blob/05e3f9be4c7a2ae7ecf029a7c32065b024730ebe/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
func countImageTokens(url string, detail string, model string) (_ int, err error) {
|
||||
// Skip token counting for non-image data URLs (video, audio, etc.)
|
||||
// These cannot be decoded as images and will cause errors.
|
||||
if strings.HasPrefix(url, "data:") && !strings.HasPrefix(url, "data:image/") {
|
||||
return 0, nil
|
||||
}
|
||||
var fetchSize = true
|
||||
var width, height int
|
||||
// Reference: https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding
|
||||
|
||||
@ -7,7 +7,6 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/gemini"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
@ -40,9 +39,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||
if meta.IsStream {
|
||||
var responseText string
|
||||
err, responseText = gemini.StreamHandler(c, resp)
|
||||
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
|
||||
err, usage = gemini.StreamHandler(c, resp)
|
||||
if usage != nil {
|
||||
usage.PromptTokens = meta.PromptTokens
|
||||
}
|
||||
} else {
|
||||
switch meta.Mode {
|
||||
case relaymode.Embeddings:
|
||||
|
||||
@ -39,6 +39,20 @@ type innerAIAdapter interface {
|
||||
|
||||
func GetAdaptor(model string) innerAIAdapter {
|
||||
adaptorType := modelMapping[model]
|
||||
// Handle "google/" prefixed model names (e.g. "google/gemini-2.5-flash-image")
|
||||
// by stripping the prefix and looking up the bare model name.
|
||||
if adaptorType == 0 {
|
||||
bare := model
|
||||
if idx := len("google/"); len(model) > idx && model[:idx] == "google/" {
|
||||
bare = model[idx:]
|
||||
}
|
||||
adaptorType = modelMapping[bare]
|
||||
}
|
||||
// If still not found, default to Gemini adaptor for any unrecognized model
|
||||
// (Zenmux proxies all Gemini variants through the same generateContent endpoint).
|
||||
if adaptorType == 0 {
|
||||
adaptorType = VerterAIGemini
|
||||
}
|
||||
switch adaptorType {
|
||||
case VerterAIClaude:
|
||||
return &claude.Adaptor{}
|
||||
|
||||
@ -156,7 +156,8 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
case channeltype.Zhipu,
|
||||
channeltype.Ali,
|
||||
channeltype.Replicate,
|
||||
channeltype.Baidu:
|
||||
channeltype.Baidu,
|
||||
channeltype.Gemini:
|
||||
finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
|
||||
|
||||
@ -3,5 +3,6 @@ package model
|
||||
const (
|
||||
ContentTypeText = "text"
|
||||
ContentTypeImageURL = "image_url"
|
||||
ContentTypeVideoURL = "video_url"
|
||||
ContentTypeInputAudio = "input_audio"
|
||||
)
|
||||
|
||||
@ -66,6 +66,7 @@ type GeneralOpenAIRequest struct {
|
||||
// Others
|
||||
Instruction string `json:"instruction,omitempty"`
|
||||
NumCtx int `json:"num_ctx,omitempty"`
|
||||
EnableThinking bool `json:"enable_thinking,omitempty"`
|
||||
}
|
||||
|
||||
func (r GeneralOpenAIRequest) ParseInput() []string {
|
||||
|
||||
@ -72,6 +72,26 @@ func (m Message) ParseContent() []MessageContent {
|
||||
},
|
||||
})
|
||||
}
|
||||
case ContentTypeVideoURL:
|
||||
if subObj, ok := contentMap["video_url"].(map[string]any); ok {
|
||||
if url, ok := subObj["url"].(string); ok {
|
||||
contentList = append(contentList, MessageContent{
|
||||
Type: ContentTypeVideoURL,
|
||||
VideoURL: &VideoURL{Url: url},
|
||||
})
|
||||
}
|
||||
}
|
||||
case ContentTypeInputAudio:
|
||||
if subObj, ok := contentMap["input_audio"].(map[string]any); ok {
|
||||
data, _ := subObj["data"].(string)
|
||||
format, _ := subObj["format"].(string)
|
||||
if data != "" {
|
||||
contentList = append(contentList, MessageContent{
|
||||
Type: ContentTypeInputAudio,
|
||||
InputAudio: &InputAudio{Data: data, Format: format},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return contentList
|
||||
@ -84,8 +104,19 @@ type ImageURL struct {
|
||||
Detail string `json:"detail,omitempty"`
|
||||
}
|
||||
|
||||
type VideoURL struct {
|
||||
Url string `json:"url,omitempty"`
|
||||
}
|
||||
|
||||
type InputAudio struct {
|
||||
Data string `json:"data,omitempty"` // base64-encoded audio (no data: prefix)
|
||||
Format string `json:"format,omitempty"` // e.g. "mp3", "wav", "webm", "ogg"
|
||||
}
|
||||
|
||||
type MessageContent struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Text string `json:"text"`
|
||||
ImageURL *ImageURL `json:"image_url,omitempty"`
|
||||
VideoURL *VideoURL `json:"video_url,omitempty"`
|
||||
InputAudio *InputAudio `json:"input_audio,omitempty"`
|
||||
}
|
||||
|
||||
@ -16,6 +16,12 @@ func SetApiRouter(router *gin.Engine) {
|
||||
{
|
||||
apiRouter.GET("/status", controller.GetStatus)
|
||||
apiRouter.GET("/models", middleware.UserAuth(), controller.DashboardListModels)
|
||||
apiRouter.GET("/model-catalog", middleware.TokenAuth(), controller.GetModelCatalog)
|
||||
apiRouter.GET("/token-quota", middleware.TokenAuth(), controller.GetTokenQuota)
|
||||
apiRouter.POST("/channel/:id/sync-zenmux", middleware.AdminAuth(), controller.SyncZenmuxModels)
|
||||
apiRouter.POST("/zenmux/setup", middleware.AdminAuth(), controller.SetupZenmuxChannels)
|
||||
apiRouter.GET("/zenmux/protocols", middleware.AdminAuth(), controller.GetZenmuxProtocols)
|
||||
apiRouter.POST("/zenmux/sync-ratios", middleware.AdminAuth(), controller.SyncZenmuxRatios)
|
||||
apiRouter.GET("/notice", controller.GetNotice)
|
||||
apiRouter.GET("/about", controller.GetAbout)
|
||||
apiRouter.GET("/home_page_content", controller.GetHomePageContent)
|
||||
|
||||
130
tools/migrate/main.go
Normal file
130
tools/migrate/main.go
Normal file
@ -0,0 +1,130 @@
|
||||
// migrate/main.go — SQLite → MySQL one-api data migration tool
|
||||
// Usage: go run tools/migrate/main.go
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
const (
|
||||
sqlitePath = "one-api.db"
|
||||
mysqlDSN = "root:123456@tcp(localhost:3306)/oneapi?charset=utf8mb4&parseTime=True&loc=Local"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if _, err := os.Stat(sqlitePath); os.IsNotExist(err) {
|
||||
log.Fatalf("SQLite file not found: %s", sqlitePath)
|
||||
}
|
||||
|
||||
sqlite, err := sql.Open("sqlite3", sqlitePath)
|
||||
if err != nil {
|
||||
log.Fatalf("open sqlite: %v", err)
|
||||
}
|
||||
defer sqlite.Close()
|
||||
|
||||
mysql, err := sql.Open("mysql", mysqlDSN)
|
||||
if err != nil {
|
||||
log.Fatalf("open mysql: %v", err)
|
||||
}
|
||||
defer mysql.Close()
|
||||
|
||||
if err = mysql.Ping(); err != nil {
|
||||
log.Fatalf("mysql ping failed: %v\nCheck if MySQL is running and DSN is correct.", err)
|
||||
}
|
||||
|
||||
fmt.Println("Connected to both databases. Starting migration...")
|
||||
|
||||
tables := []string{"users", "channels", "tokens", "options", "redemptions", "logs", "abilities"}
|
||||
for _, table := range tables {
|
||||
if err := migrateTable(sqlite, mysql, table); err != nil {
|
||||
fmt.Printf("[WARN] table %s: %v\n", table, err)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("\nMigration complete!")
|
||||
}
|
||||
|
||||
func migrateTable(src, dst *sql.DB, table string) error {
|
||||
// Check table exists in SQLite
|
||||
var count int
|
||||
err := src.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='%s'", table)).Scan(&count)
|
||||
if err != nil || count == 0 {
|
||||
fmt.Printf("[SKIP] table %s not found in SQLite\n", table)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get row count
|
||||
var total int
|
||||
_ = src.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM `%s`", table)).Scan(&total)
|
||||
if total == 0 {
|
||||
fmt.Printf("[SKIP] table %s is empty\n", table)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read all rows
|
||||
rows, err := src.Query(fmt.Sprintf("SELECT * FROM `%s`", table))
|
||||
if err != nil {
|
||||
return fmt.Errorf("select from sqlite: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
cols, err := rows.Columns()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Build INSERT statement with placeholders
|
||||
placeholders := ""
|
||||
colNames := ""
|
||||
for i, col := range cols {
|
||||
if i > 0 {
|
||||
placeholders += ","
|
||||
colNames += ","
|
||||
}
|
||||
placeholders += "?"
|
||||
colNames += fmt.Sprintf("`%s`", col)
|
||||
}
|
||||
insertSQL := fmt.Sprintf("INSERT IGNORE INTO `%s` (%s) VALUES (%s)", table, colNames, placeholders)
|
||||
|
||||
stmt, err := dst.Prepare(insertSQL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("prepare insert for %s: %w", table, err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
inserted := 0
|
||||
vals := make([]interface{}, len(cols))
|
||||
valPtrs := make([]interface{}, len(cols))
|
||||
for i := range vals {
|
||||
valPtrs[i] = &vals[i]
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(valPtrs...); err != nil {
|
||||
return fmt.Errorf("scan: %w", err)
|
||||
}
|
||||
// Convert []byte to string for MySQL compatibility
|
||||
args := make([]interface{}, len(vals))
|
||||
for i, v := range vals {
|
||||
if b, ok := v.([]byte); ok {
|
||||
args[i] = string(b)
|
||||
} else {
|
||||
args[i] = v
|
||||
}
|
||||
}
|
||||
if _, err := stmt.Exec(args...); err != nil {
|
||||
fmt.Printf("[WARN] insert row in %s: %v\n", table, err)
|
||||
continue
|
||||
}
|
||||
inserted++
|
||||
}
|
||||
|
||||
fmt.Printf("[OK] %s: %d/%d rows migrated\n", table, inserted, total)
|
||||
return nil
|
||||
}
|
||||
@ -5,6 +5,36 @@ export const CHANNEL_OPTIONS = {
|
||||
value: 1,
|
||||
color: 'success'
|
||||
},
|
||||
50: {
|
||||
key: 50,
|
||||
text: 'OpenAI 兼容',
|
||||
value: 50,
|
||||
color: 'warning'
|
||||
},
|
||||
51: {
|
||||
key: 51,
|
||||
text: 'Gemini (OpenAI)',
|
||||
value: 51,
|
||||
color: 'warning'
|
||||
},
|
||||
47: {
|
||||
key: 47,
|
||||
text: '百度文心千帆 V2',
|
||||
value: 47,
|
||||
color: 'primary'
|
||||
},
|
||||
48: {
|
||||
key: 48,
|
||||
text: '讯飞星火认知 V2',
|
||||
value: 48,
|
||||
color: 'primary'
|
||||
},
|
||||
49: {
|
||||
key: 49,
|
||||
text: '阿里云百炼',
|
||||
value: 49,
|
||||
color: 'primary'
|
||||
},
|
||||
14: {
|
||||
key: 14,
|
||||
text: 'Anthropic Claude',
|
||||
@ -185,7 +215,7 @@ export const CHANNEL_OPTIONS = {
|
||||
value: 45,
|
||||
color: 'primary'
|
||||
},
|
||||
45: {
|
||||
46: {
|
||||
key: 46,
|
||||
text: 'Replicate',
|
||||
value: 46,
|
||||
|
||||
@ -49,7 +49,7 @@ const GitHubOAuth = () => {
|
||||
let code = searchParams.get('code');
|
||||
let state = searchParams.get('state');
|
||||
sendCode(code, state, 0).then();
|
||||
}, []);
|
||||
}, []); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
return (
|
||||
<AuthWrapper>
|
||||
|
||||
@ -49,7 +49,7 @@ const LarkOAuth = () => {
|
||||
let code = searchParams.get('code');
|
||||
let state = searchParams.get('state');
|
||||
sendCode(code, state, 0).then();
|
||||
}, []);
|
||||
}, []); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
return (
|
||||
<AuthWrapper>
|
||||
|
||||
@ -49,7 +49,7 @@ const OidcOAuth = () => {
|
||||
let code = searchParams.get('code');
|
||||
let state = searchParams.get('state');
|
||||
sendCode(code, state, 0).then();
|
||||
}, []);
|
||||
}, []); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
return (
|
||||
<AuthWrapper>
|
||||
|
||||
@ -92,7 +92,7 @@ const RegisterForm = ({ ...others }) => {
|
||||
setTurnstileEnabled(true);
|
||||
setTurnstileSiteKey(siteInfo.turnstile_site_key);
|
||||
}
|
||||
}, [siteInfo]);
|
||||
}, [siteInfo]); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
return (
|
||||
<>
|
||||
|
||||
@ -37,7 +37,7 @@ const ResetPasswordForm = () => {
|
||||
token,
|
||||
email
|
||||
});
|
||||
}, []);
|
||||
}, [searchParams]); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
return (
|
||||
<Stack spacing={3} padding={'24px'} justifyContent={'center'} alignItems={'center'}>
|
||||
|
||||
@ -104,7 +104,7 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
|
||||
initChannel(typeValue);
|
||||
let localModels = getChannelModels(typeValue);
|
||||
setBasicModels(localModels);
|
||||
if (localModels.length > 0 && Array.isArray(values['models']) && values['models'].length == 0) {
|
||||
if (localModels.length > 0 && Array.isArray(values['models']) && values['models'].length === 0) {
|
||||
setFieldValue('models', initialModel(localModels));
|
||||
}
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ import PropTypes from 'prop-types';
|
||||
import { Tooltip, Stack, Container } from '@mui/material';
|
||||
import Label from 'ui-component/Label';
|
||||
import { styled } from '@mui/material/styles';
|
||||
import { showSuccess, copy } from 'utils/common';
|
||||
import { copy } from 'utils/common';
|
||||
|
||||
const TooltipContainer = styled(Container)({
|
||||
maxHeight: '250px',
|
||||
|
||||
@ -23,13 +23,13 @@ const ResponseTimeLabel = ({ test_time, response_time, handle_action }) => {
|
||||
<>
|
||||
点击测速
|
||||
<br />
|
||||
{test_time != 0 ? '上次测速时间:' + timestamp2string(test_time) : '未测试'}
|
||||
{test_time !== 0 ? '上次测速时间:' + timestamp2string(test_time) : '未测试'}
|
||||
</>
|
||||
);
|
||||
|
||||
return (
|
||||
<Tooltip title={title} placement="top" onClick={handle_action}>
|
||||
<Label color={color}> {response_time == 0 ? '未测试' : time} </Label>
|
||||
<Label color={color}> {response_time === 0 ? '未测试' : time} </Label>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
|
||||
@ -110,6 +110,8 @@ export default function ChannelPage() {
|
||||
case 'test':
|
||||
res = await API.get(url + `test/${id}`);
|
||||
break;
|
||||
default:
|
||||
return;
|
||||
}
|
||||
const { success, message } = res.data;
|
||||
if (success) {
|
||||
|
||||
@ -90,19 +90,19 @@ const Dashboard = () => {
|
||||
<Typography variant="h4">余额:</Typography>
|
||||
</Grid>
|
||||
<Grid item xs={8}>
|
||||
<Typography variant="h3"> {users?.quota ? '$' + calculateQuota(users.quota) : '未知'}</Typography>
|
||||
<Typography variant="h3"> {users?.quota != null ? '$' + calculateQuota(users.quota) : '未知'}</Typography>
|
||||
</Grid>
|
||||
<Grid item xs={4}>
|
||||
<Typography variant="h4">已使用:</Typography>
|
||||
</Grid>
|
||||
<Grid item xs={8}>
|
||||
<Typography variant="h3"> {users?.used_quota ? '$' + calculateQuota(users.used_quota) : '未知'}</Typography>
|
||||
<Typography variant="h3"> {users?.used_quota != null ? '$' + calculateQuota(users.used_quota) : '未知'}</Typography>
|
||||
</Grid>
|
||||
<Grid item xs={4}>
|
||||
<Typography variant="h4">调用次数:</Typography>
|
||||
</Grid>
|
||||
<Grid item xs={8}>
|
||||
<Typography variant="h3"> {users?.request_count || '未知'}</Typography>
|
||||
<Typography variant="h3"> {users?.request_count != null ? users.request_count : '未知'}</Typography>
|
||||
</Grid>
|
||||
</Grid>
|
||||
</UserCard>
|
||||
|
||||
@ -156,7 +156,7 @@ export default function Token() {
|
||||
</Stack>
|
||||
<Stack mb={2}>
|
||||
<Alert severity="info">
|
||||
将 OpenAI API 基础地址 https://api.openai.com 替换为 <b>{siteInfo.server_address}</b>,复制下面的密钥即可使用
|
||||
在API_KEY处复制下面的密钥即可使用
|
||||
</Alert>
|
||||
</Stack>
|
||||
<Card>
|
||||
|
||||
@ -46,10 +46,10 @@ const InviteCard = () => {
|
||||
>
|
||||
<Stack justifyContent="center" alignItems={'center'} spacing={3}>
|
||||
<Typography variant="h3" sx={{ color: theme.palette.primary.dark }}>
|
||||
邀请奖励
|
||||
邀请好友
|
||||
</Typography>
|
||||
<Typography variant="body" sx={{ color: theme.palette.primary.dark }}>
|
||||
分享您的邀请链接,邀请好友注册,即可获得奖励!
|
||||
分享您的邀请链接,邀请好友注册!
|
||||
</Typography>
|
||||
|
||||
<OutlinedInput
|
||||
|
||||
@ -8,7 +8,7 @@ const Topup = () => {
|
||||
<Grid container spacing={2}>
|
||||
<Grid xs={12}>
|
||||
<Alert severity="warning">
|
||||
充值记录以及邀请记录请在日志中查询。充值记录请在日志中选择类型【充值】查询;邀请记录请在日志中选择【系统】查询{' '}
|
||||
充值记录请在日志中选择类型【充值】查询{' '}
|
||||
</Alert>
|
||||
</Grid>
|
||||
<Grid xs={12} md={6} lg={8}>
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
"version": "0.1.0",
|
||||
"private": true,
|
||||
"dependencies": {
|
||||
"ajv": "^8.18.0",
|
||||
"axios": "^0.27.2",
|
||||
"history": "^5.3.0",
|
||||
"i18next": "^24.2.2",
|
||||
|
||||
@ -445,7 +445,7 @@
|
||||
"personal": {
|
||||
"general": {
|
||||
"title": "通用设置",
|
||||
"system_token_notice": "注意,此处生成的令牌用于系统管理,而非用于请求 OpenAI 相关的服务,请知悉。",
|
||||
"system_token_notice": "注意,此处生成的令牌用于系统管理,无需生成使用,请知悉。",
|
||||
"buttons": {
|
||||
"update_profile": "更新个人信息",
|
||||
"generate_token": "生成系统访问令牌",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user