Like-api/controller/model.go
hjjjj c734c541b2
Some checks failed
CI / Unit tests (push) Has been cancelled
CI / commit_lint (push) Has been cancelled
feat: 添加视频URL支持及Zenmux集成
refactor: 重构Gemini适配器以支持多模态输入
fix: 修复React Hooks依赖警告
style: 清理未使用的导入和代码
docs: 更新用户界面文本和提示
perf: 优化图像和视频URL处理性能
test: 添加数据迁移工具和测试
build: 更新依赖项和.gitignore
chore: 同步Zenmux模型和价格比例
2026-03-12 17:53:27 +08:00

734 lines
21 KiB
Go

package controller
import (
"encoding/json"
"fmt"
"net/http"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/model"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
relay "github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/apitype"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
// OpenAIModelPermission is one element of the "permission" array attached to
// each model entry in the OpenAI model-list response shape.
// https://platform.openai.com/docs/api-reference/models/list
type OpenAIModelPermission struct {
	Id                 string  `json:"id"`
	Object             string  `json:"object"`
	Created            int     `json:"created"`
	AllowCreateEngine  bool    `json:"allow_create_engine"`
	AllowSampling      bool    `json:"allow_sampling"`
	AllowLogprobs      bool    `json:"allow_logprobs"`
	AllowSearchIndices bool    `json:"allow_search_indices"`
	AllowView          bool    `json:"allow_view"`
	AllowFineTuning    bool    `json:"allow_fine_tuning"`
	Organization       string  `json:"organization"`
	Group              *string `json:"group"` // a nil pointer encodes as JSON null
	IsBlocking         bool    `json:"is_blocking"`
}
// OpenAIModels is a single model entry (despite the plural name) in the OpenAI
// model-list shape, as emitted by ListAllModels, ListModels and RetrieveModel.
type OpenAIModels struct {
	Id      string `json:"id"`
	Object  string `json:"object"` // always "model" for entries built in this file
	Created int    `json:"created"`
	// OwnedBy is the adaptor/channel name for built-in models, or "custom"
	// for names that ListModels could not find among the built-ins.
	OwnedBy    string                  `json:"owned_by"`
	Permission []OpenAIModelPermission `json:"permission"`
	Root       string                  `json:"root"`
	Parent     *string                 `json:"parent"` // nil encodes as JSON null
}
// Built-in model registries, filled once by init and only read afterwards by
// the handlers in this file.
var (
	models           []OpenAIModels          // every built-in model, in registration order
	modelsMap        map[string]OpenAIModels // models keyed by id; a later duplicate id wins
	channelId2Models map[int][]string        // channel type -> model names reported by its adaptor
)
// init builds the static model registries used by the handlers in this file:
//
//   - models: one entry per model name reported by every API-type adaptor
//     (except AIProxyLibrary) and by every OpenAI-compatible channel (except
//     Azure), with OwnedBy set to the adaptor/channel name;
//   - modelsMap: models indexed by id (a later duplicate id overwrites an
//     earlier one);
//   - channelId2Models: channel type -> model names reported by the adaptor
//     for that channel type.
func init() {
	// Fixed creation timestamp reported for every built-in model.
	const created = 1626777600

	// Every built-in model shares this single permissive permission entry.
	permission := []OpenAIModelPermission{{
		Id:                 "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
		Object:             "model_permission",
		Created:            created,
		AllowCreateEngine:  true,
		AllowSampling:      true,
		AllowLogprobs:      true,
		AllowSearchIndices: false,
		AllowView:          true,
		AllowFineTuning:    false,
		Organization:       "*",
		Group:              nil,
		IsBlocking:         false,
	}}

	addModel := func(name, owner string) {
		models = append(models, OpenAIModels{
			Id:         name,
			Object:     "model",
			Created:    created,
			OwnedBy:    owner,
			Permission: permission,
			Root:       name,
			Parent:     nil,
		})
	}

	// https://platform.openai.com/docs/models/model-endpoint-compatibility
	for i := 0; i < apitype.Dummy; i++ {
		if i == apitype.AIProxyLibrary {
			continue
		}
		adaptor := relay.GetAdaptor(i)
		channelName := adaptor.GetChannelName()
		for _, modelName := range adaptor.GetModelList() {
			addModel(modelName, channelName)
		}
	}
	for _, channelType := range openai.CompatibleChannels {
		if channelType == channeltype.Azure {
			continue
		}
		channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType)
		for _, modelName := range channelModelList {
			addModel(modelName, channelName)
		}
	}

	modelsMap = make(map[string]OpenAIModels, len(models))
	for _, m := range models {
		modelsMap[m.Id] = m
	}

	channelId2Models = make(map[int][]string)
	for i := 1; i < channeltype.Dummy; i++ {
		adaptor := relay.GetAdaptor(channeltype.ToAPIType(i))
		// Named channelMeta so it does not shadow the imported meta package.
		channelMeta := &meta.Meta{ChannelType: i}
		adaptor.Init(channelMeta)
		channelId2Models[i] = adaptor.GetModelList()
	}
}
// DashboardListModels returns, in the dashboard's success/message/data
// envelope, the channel-type -> model-names table built by init.
func DashboardListModels(c *gin.Context) {
	response := gin.H{
		"data":    channelId2Models,
		"message": "",
		"success": true,
	}
	c.JSON(http.StatusOK, response)
}
// ListAllModels returns every built-in model in the OpenAI list envelope,
// without filtering by the caller's permissions.
func ListAllModels(c *gin.Context) {
	body := gin.H{"object": "list", "data": models}
	c.JSON(http.StatusOK, body)
}
// ListModels serves the OpenAI-compatible model list restricted to the models
// the caller may use.
//
// The allowed names come from the token's model restriction
// (ctxkey.AvailableModels, comma-separated) when present, otherwise from the
// models of the user's group. Lookup errors are ignored (best effort); the
// caller then simply sees fewer models.
//
// Output order is deterministic:
//   - built-in models first, in registry order, each id at most once;
//   - then allowed names with no built-in entry, in the order they were
//     configured, with OwnedBy "custom" and no permission list.
//
// Empty names (e.g. from a trailing comma) are skipped.
func ListModels(c *gin.Context) {
	ctx := c.Request.Context()
	var availableModels []string
	if c.GetString(ctxkey.AvailableModels) != "" {
		availableModels = strings.Split(c.GetString(ctxkey.AvailableModels), ",")
	} else {
		userId := c.GetInt(ctxkey.Id)
		userGroup, _ := model.CacheGetUserGroup(userId)
		availableModels, _ = model.CacheGetGroupModels(ctx, userGroup)
	}

	// pending holds allowed names that have not been emitted yet. Deleting a
	// name on emit both de-duplicates the output and leaves exactly the
	// custom (non built-in) names behind for the second pass.
	pending := make(map[string]bool, len(availableModels))
	for _, name := range availableModels {
		if name != "" {
			pending[name] = true
		}
	}

	availableOpenAIModels := make([]OpenAIModels, 0, len(pending))
	for _, m := range models {
		if pending[m.Id] {
			delete(pending, m.Id)
			availableOpenAIModels = append(availableOpenAIModels, m)
		}
	}

	// Walk the configured slice rather than the map so custom entries come
	// out in a stable order instead of Go's randomized map iteration order.
	for _, name := range availableModels {
		if !pending[name] {
			continue
		}
		delete(pending, name)
		availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{
			Id:      name,
			Object:  "model",
			Created: 1626777600,
			OwnedBy: "custom",
			Root:    name,
			Parent:  nil,
		})
	}

	c.JSON(200, gin.H{
		"object": "list",
		"data":   availableOpenAIModels,
	})
}
// RetrieveModel returns the built-in model whose id equals the :model path
// parameter.
//
// An unknown id is answered with HTTP 404 and an OpenAI-style error object
// (type "invalid_request_error", code "model_not_found"), matching OpenAI's
// own behaviour. A 200 status would tell HTTP-level clients the lookup
// succeeded even though the body is an error.
func RetrieveModel(c *gin.Context) {
	modelId := c.Param("model")
	if m, ok := modelsMap[modelId]; ok {
		c.JSON(http.StatusOK, m)
		return
	}
	c.JSON(http.StatusNotFound, gin.H{
		"error": relaymodel.Error{
			Message: fmt.Sprintf("The model '%s' does not exist", modelId),
			Type:    "invalid_request_error",
			Param:   "model",
			Code:    "model_not_found",
		},
	})
}
// GetUserAvailableModels returns the model names enabled for the calling
// user's group, in the dashboard's success/message/data envelope. Lookup
// failures are reported as success=false with the error text.
func GetUserAvailableModels(c *gin.Context) {
	respondErr := func(err error) {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
	}

	userGroup, err := model.CacheGetUserGroup(c.GetInt(ctxkey.Id))
	if err != nil {
		respondErr(err)
		return
	}
	groupModels, err := model.CacheGetGroupModels(c.Request.Context(), userGroup)
	if err != nil {
		respondErr(err)
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    groupModels,
	})
}
// --- Model Catalog (powered by Zenmux) ---
// ModelCatalogItem is the enriched model record returned to aiscri-xiong clients
// by GetModelCatalog. Items are produced from raw Zenmux listings by zenmuxToItem.
type ModelCatalogItem struct {
	Id               string   `json:"id"`       // Zenmux slug; compared against the caller's allowed model names
	Name             string   `json:"name"`
	Provider         string   `json:"provider"` // Zenmux "author"
	Description      string   `json:"description"`
	InputModalities  []string `json:"input_modalities"`
	OutputModalities []string `json:"output_modalities"`
	ContextLength    int      `json:"context_length"`
	MaxOutputTokens  int      `json:"max_output_tokens"` // Zenmux max_completion_tokens
	// InputPrice/OutputPrice are parsed from Zenmux pricing_prompt /
	// pricing_completion (0 when missing or unparsable). SyncZenmuxRatios
	// treats them as USD per 1M tokens — NOTE(review): confirm against Zenmux.
	InputPrice  float64 `json:"input_price"`
	OutputPrice float64 `json:"output_price"`
	// PricingDiscount is passed through verbatim from Zenmux, e.g. "0.8"
	// (presumably 20% off). Nothing in this file applies it to
	// InputPrice/OutputPrice or to the billing ratios.
	PricingDiscount string   `json:"pricing_discount"`
	Tags            []string `json:"tags"` // derived by zenmuxToItem: "function-call", "vision", "reasoning"
	// SupportsReasoning: 0=none, 1=always-on, 2=toggleable
	SupportsReasoning int    `json:"supports_reasoning"`
	SuitableApi       string `json:"suitable_api"` // comma-separated protocol names, passed through from Zenmux
}
// zenmuxModel is the raw response shape from Zenmux listByFilter API.
// List-valued fields arrive as comma-separated strings and are split by
// splitCSV; prices arrive as strings and are converted by parsePrice.
type zenmuxModel struct {
	Slug                string `json:"slug"` // entries with an empty slug are dropped by fetchZenmuxCatalog
	Name                string `json:"name"`
	Author              string `json:"author"`
	Description         string `json:"description"`
	InputModalities     string `json:"input_modalities"`  // comma-separated
	OutputModalities    string `json:"output_modalities"` // comma-separated
	ContextLength       int    `json:"context_length"`
	MaxCompletionTokens int    `json:"max_completion_tokens"`
	PricingPrompt       string `json:"pricing_prompt"`     // numeric string, see parsePrice
	PricingCompletion   string `json:"pricing_completion"` // numeric string, see parsePrice
	PricingDiscount     string `json:"pricing_discount"`
	SuitableApi         string `json:"suitable_api"` // comma-separated
	SupportsReasoning   int    `json:"supports_reasoning"`
	SupportedParameters string `json:"supported_parameters"` // comma-separated; "tools" yields the "function-call" tag
}
// zenmuxListResponse is the envelope of the Zenmux listByFilter response;
// fetchZenmuxCatalog treats Success == false as an error.
type zenmuxListResponse struct {
	Success bool          `json:"success"`
	Data    []zenmuxModel `json:"data"`
}
// Zenmux catalog cache. fetchZenmuxCatalog serves catalogCache while it is
// younger than catalogCacheTTL; catalogMu guards catalogCache and
// catalogCacheTime.
var (
	catalogMu        sync.RWMutex
	catalogCache     []ModelCatalogItem
	catalogCacheTime time.Time
	catalogCacheTTL  = 5 * time.Minute

	// zenmuxCToken is the public frontend token embedded in the Zenmux website.
	// Override via env var ZENMUX_CTOKEN if it ever changes.
	zenmuxCToken = func() string {
		const builtin = "2uF44yawHUs41edv2fRcV_eE"
		override := os.Getenv("ZENMUX_CTOKEN")
		if override == "" {
			return builtin
		}
		return override
	}()
)
// splitCSV splits a comma-separated list, trimming whitespace around each item
// and dropping items that are empty after trimming. An empty input yields nil;
// any other input yields a non-nil (possibly empty) slice.
func splitCSV(s string) []string {
	if len(s) == 0 {
		return nil
	}
	fields := strings.Split(s, ",")
	result := make([]string, 0, len(fields))
	for i := range fields {
		field := strings.TrimSpace(fields[i])
		if field == "" {
			continue
		}
		result = append(result, field)
	}
	return result
}
// parsePrice converts a Zenmux price string into a float64.
//
// Blank, malformed, out-of-range and non-finite ("NaN", "Inf") inputs all
// yield 0, which callers treat as "no pricing". strconv.ParseFloat is used
// because it rejects trailing garbage such as "1,000" instead of silently
// reading a numeric prefix. Rejecting NaN/Inf matters because such a value
// would later make json.Marshal of the ratio tables fail.
func parsePrice(s string) float64 {
	f, err := strconv.ParseFloat(strings.TrimSpace(s), 64)
	// ParseFloat reports overflow via err, so only literal "NaN"/"Inf"
	// spellings get past it: f != f catches NaN, maxFinite bounds ±Inf.
	const maxFinite = 1.7976931348623157e308 // math.MaxFloat64
	if err != nil || f != f || f > maxFinite || f < -maxFinite {
		return 0
	}
	return f
}
// zenmuxToItem converts one raw Zenmux listing into the catalog shape served
// to clients, deriving capability tags in this order:
//   - "function-call" when supported_parameters contains "tools"
//   - "vision" when input_modalities contains "image"
//   - "reasoning" when supports_reasoning > 0
//
// Tags is always non-nil, so it serializes as [] rather than null.
func zenmuxToItem(z zenmuxModel) ModelCatalogItem {
	has := func(values []string, target string) bool {
		for _, v := range values {
			if v == target {
				return true
			}
		}
		return false
	}

	item := ModelCatalogItem{
		Id:                z.Slug,
		Name:              z.Name,
		Provider:          z.Author,
		Description:       z.Description,
		InputModalities:   splitCSV(z.InputModalities),
		OutputModalities:  splitCSV(z.OutputModalities),
		ContextLength:     z.ContextLength,
		MaxOutputTokens:   z.MaxCompletionTokens,
		InputPrice:        parsePrice(z.PricingPrompt),
		OutputPrice:       parsePrice(z.PricingCompletion),
		PricingDiscount:   z.PricingDiscount,
		Tags:              make([]string, 0, 3),
		SupportsReasoning: z.SupportsReasoning,
		SuitableApi:       z.SuitableApi,
	}

	if has(splitCSV(z.SupportedParameters), "tools") {
		item.Tags = append(item.Tags, "function-call")
	}
	if has(item.InputModalities, "image") {
		item.Tags = append(item.Tags, "vision")
	}
	if z.SupportsReasoning > 0 {
		item.Tags = append(item.Tags, "reasoning")
	}
	return item
}
// fetchZenmuxCatalog returns the Zenmux model catalog converted to
// ModelCatalogItem. It serves the cached copy while that copy is younger than
// catalogCacheTTL, and otherwise fetches and re-caches it. Entries without a
// slug are dropped.
//
// The returned slice is shared with the cache and with other callers; treat it
// as read-only. Concurrent callers that all find the cache stale each fetch
// upstream, and whichever finishes last overwrites the cache.
func fetchZenmuxCatalog() ([]ModelCatalogItem, error) {
	catalogMu.RLock()
	if catalogCache != nil && time.Since(catalogCacheTime) < catalogCacheTTL {
		defer catalogMu.RUnlock()
		return catalogCache, nil
	}
	catalogMu.RUnlock()

	endpoint := "https://zenmux.ai/api/frontend/model/listByFilter?ctoken=" + zenmuxCToken + "&sort=newest&context_length=&keyword="
	client := &http.Client{Timeout: 15 * time.Second}
	resp, err := client.Get(endpoint)
	if err != nil {
		return nil, fmt.Errorf("zenmux catalog fetch failed: %w", err)
	}
	defer resp.Body.Close()

	// Reject non-2xx responses up front. Otherwise an HTML error page or a
	// rate-limit body surfaces as a confusing JSON parse error, or as an
	// unexplained success=false.
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return nil, fmt.Errorf("zenmux catalog fetch failed: unexpected HTTP status %s", resp.Status)
	}

	var parsed zenmuxListResponse
	if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
		return nil, fmt.Errorf("zenmux catalog parse failed: %w", err)
	}
	if !parsed.Success {
		return nil, fmt.Errorf("zenmux catalog returned success=false")
	}

	items := make([]ModelCatalogItem, 0, len(parsed.Data))
	for _, z := range parsed.Data {
		if z.Slug != "" {
			items = append(items, zenmuxToItem(z))
		}
	}

	catalogMu.Lock()
	catalogCache = items
	catalogCacheTime = time.Now()
	catalogMu.Unlock()
	return items, nil
}
// GetTokenQuota reports quota details (remaining, used, unlimited flag, expiry
// and name) for the relay token identified by ctxkey.TokenId. A failed token
// lookup is reported as success=false with the error text.
// GET /api/token-quota — requires TokenAuth middleware.
func GetTokenQuota(c *gin.Context) {
	token, err := model.GetTokenById(c.GetInt(ctxkey.TokenId))
	if err != nil {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
		return
	}

	quota := gin.H{
		"name":            token.Name,
		"remain_quota":    token.RemainQuota,
		"used_quota":      token.UsedQuota,
		"unlimited_quota": token.UnlimitedQuota,
		"expired_time":    token.ExpiredTime,
	}
	c.JSON(http.StatusOK, gin.H{"success": true, "data": quota})
}
// GetModelCatalog returns Zenmux catalog entries (with capability metadata)
// whose id is among the models the caller may use. The allowed names come from
// the token's model restriction (ctxkey.AvailableModels) when present,
// otherwise from the user's group. If the catalog cannot be fetched, the
// response still reports success=true, with an explanatory message and an
// empty data list.
func GetModelCatalog(c *gin.Context) {
	ctx := c.Request.Context()
	fail := func(message string) {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": message})
	}

	var availableModels []string
	if tokenModels := c.GetString(ctxkey.AvailableModels); tokenModels != "" {
		availableModels = strings.Split(tokenModels, ",")
	} else {
		userGroup, err := model.CacheGetUserGroup(c.GetInt(ctxkey.Id))
		if err != nil {
			fail(err.Error())
			return
		}
		groupModels, err := model.CacheGetGroupModels(ctx, userGroup)
		if err != nil {
			fail(err.Error())
			return
		}
		availableModels = groupModels
	}

	allowed := make(map[string]bool, len(availableModels))
	for _, name := range availableModels {
		allowed[name] = true
	}

	catalog, err := fetchZenmuxCatalog()
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"message": "model catalog unavailable: " + err.Error(),
			"data":    []ModelCatalogItem{},
		})
		return
	}

	filtered := make([]ModelCatalogItem, 0, len(catalog))
	for i := range catalog {
		if allowed[catalog[i].Id] {
			filtered = append(filtered, catalog[i])
		}
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    filtered,
	})
}
// zenmuxChannelDef describes one Zenmux protocol → one-api channel mapping,
// consumed by SetupZenmuxChannels.
type zenmuxChannelDef struct {
	Name      string   // channel name; existing channels are matched by this name
	Type      int      // one-api channel type
	Protocols []string // suitable_api values to include (compared case-insensitively)
}
// allZenmuxChannelDefs is the canonical channel mapping for Zenmux, iterated
// by SetupZenmuxChannels.
//
// Zenmux exposes a single OpenAI-compatible ingress at /api/v1 for ALL text models
// (chat.completions, responses, messages, gemini, generate).
// suitable_api is upstream metadata only — Zenmux always accepts OpenAI format on input.
//
// Imagen and Veo use https://zenmux.ai/api/vertex-ai (Vertex AI SDK) and are
// not handled here (one-api's relay framework doesn't support that format yet).
var allZenmuxChannelDefs = []zenmuxChannelDef{
	{
		Name: "Zenmux",
		// NOTE(review): 50 is meant to be channeltype.OpenAICompatible; prefer the
		// named constant so a renumbered enum cannot silently change this.
		Type:      50, // OpenAICompatible
		Protocols: []string{"chat.completions", "responses", "messages", "gemini", "generate"},
	},
}
// GetZenmuxProtocols reports every distinct suitable_api value in the Zenmux
// catalog, mapped to the number of models that list it.
// GET /api/zenmux/protocols — admin only, for debugging protocol → channel mapping.
func GetZenmuxProtocols(c *gin.Context) {
	catalog, err := fetchZenmuxCatalog()
	if err != nil {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
		return
	}

	counts := make(map[string]int)
	for i := range catalog {
		// splitCSV already trims whitespace and drops empty entries.
		for _, proto := range splitCSV(catalog[i].SuitableApi) {
			counts[proto]++
		}
	}
	c.JSON(http.StatusOK, gin.H{"success": true, "data": counts})
}
// SetupZenmuxChannels creates or refreshes one channel per entry in
// allZenmuxChannelDefs and fills its model list from the Zenmux catalog.
//
// POST /api/zenmux/setup — body: {"key":"<zenmux-api-key>","base_url":"https://zenmux.ai/api/v1"}
// "key" is required; "base_url" is optional and defaults to https://zenmux.ai/api/v1.
//
// Channels are matched by name:
//   - If a channel with the definition's name exists, only its model list is
//     replaced. Key, base URL and other settings are left untouched. The
//     action is reported as "updated".
//   - Otherwise a new enabled channel in the "default" group is created with
//     the given key and base URL. The action is reported as "created".
//
// Re-running the endpoint converges to the same state. Any database error
// aborts the request with success=false, so a failed lookup can never be
// mistaken for "not found" and create a duplicate channel.
func SetupZenmuxChannels(c *gin.Context) {
	var req struct {
		Key     string `json:"key"`
		BaseURL string `json:"base_url"`
	}
	if err := c.ShouldBindJSON(&req); err != nil || req.Key == "" {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": "key is required"})
		return
	}
	baseURL := req.BaseURL
	if baseURL == "" {
		baseURL = "https://zenmux.ai/api/v1"
	}
	catalog, err := fetchZenmuxCatalog()
	if err != nil {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": "fetch zenmux catalog failed: " + err.Error()})
		return
	}

	type result struct {
		Name   string `json:"name"`
		Models int    `json:"models"`
		Action string `json:"action"` // "created" | "updated"
	}
	results := make([]result, 0, len(allZenmuxChannelDefs))

	for _, def := range allZenmuxChannelDefs {
		// Build a lowercased protocol set for case-insensitive lookup.
		wantProtos := make(map[string]bool, len(def.Protocols))
		for _, p := range def.Protocols {
			wantProtos[strings.ToLower(p)] = true
		}
		// Keep every catalog model that advertises at least one wanted protocol.
		// splitCSV already trims each entry.
		slugs := make([]string, 0)
		for _, item := range catalog {
			for _, api := range splitCSV(item.SuitableApi) {
				if wantProtos[strings.ToLower(api)] {
					slugs = append(slugs, item.Id)
					break
				}
			}
		}
		modelList := strings.Join(slugs, ",")

		// Check whether a channel with this name already exists.
		var existing []model.Channel
		if err := model.DB.Where("name = ?", def.Name).Find(&existing).Error; err != nil {
			c.JSON(http.StatusOK, gin.H{"success": false, "message": fmt.Sprintf("look up channel %q failed: %s", def.Name, err.Error())})
			return
		}
		if len(existing) > 0 {
			// Refresh only the model list of the first match.
			ch := existing[0]
			ch.Models = modelList
			if err := ch.Update(); err != nil {
				c.JSON(http.StatusOK, gin.H{"success": false, "message": fmt.Sprintf("update channel %q failed: %s", def.Name, err.Error())})
				return
			}
			results = append(results, result{Name: def.Name, Models: len(slugs), Action: "updated"})
			continue
		}

		// Create a new channel.
		priority := int64(0)
		weight := uint(1)
		ch := model.Channel{
			Type:     def.Type,
			Name:     def.Name,
			Key:      req.Key,
			BaseURL:  &baseURL,
			Models:   modelList,
			Status:   1,
			Group:    "default",
			Priority: &priority,
			Weight:   &weight,
		}
		if err := ch.Insert(); err != nil {
			c.JSON(http.StatusOK, gin.H{"success": false, "message": fmt.Sprintf("create channel %q failed: %s", def.Name, err.Error())})
			return
		}
		results = append(results, result{Name: def.Name, Models: len(slugs), Action: "created"})
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": fmt.Sprintf("setup complete: %d channels processed", len(results)),
		"data":    results,
	})
}
// SyncZenmuxModels fetches all models from Zenmux and replaces the channel's
// model list with their slugs.
// POST /api/channel/:id/sync-zenmux
// Optional query param: ?protocol=chat.completions
// Accepts comma-separated values to match multiple protocols, e.g.
// ?protocol=google.gemini,google.imagen,google.video
// Protocol matching is case-insensitive.
//
// The channel is resolved before the catalog is fetched, so a bad id fails
// without a network call. If no catalog model matches, the request fails and
// the channel is left unchanged. Otherwise a typo in ?protocol= would silently
// wipe the channel's whole model list.
func SyncZenmuxModels(c *gin.Context) {
	channelId, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": "invalid channel id"})
		return
	}
	ch, err := model.GetChannelById(channelId, false)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel not found"})
		return
	}

	// Build a set of requested protocols (lowercased) for O(1) lookup.
	// splitCSV already trims entries and drops empty ones.
	protocolParam := strings.TrimSpace(c.Query("protocol"))
	wantProtocols := make(map[string]bool)
	for _, p := range splitCSV(protocolParam) {
		wantProtocols[strings.ToLower(p)] = true
	}
	matchesProtocol := func(suitableApi string) bool {
		for _, api := range splitCSV(suitableApi) {
			if wantProtocols[strings.ToLower(api)] {
				return true
			}
		}
		return false
	}

	catalog, err := fetchZenmuxCatalog()
	if err != nil {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
		return
	}

	slugs := make([]string, 0, len(catalog))
	for _, item := range catalog {
		if len(wantProtocols) > 0 && !matchesProtocol(item.SuitableApi) {
			continue
		}
		slugs = append(slugs, item.Id)
	}
	if len(slugs) == 0 {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": fmt.Sprintf("no Zenmux models matched (protocol=%q); channel model list left unchanged", protocolParam)})
		return
	}

	ch.Models = strings.Join(slugs, ",")
	if err := ch.Update(); err != nil {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
		return
	}

	msg := fmt.Sprintf("synced %d models to channel", len(slugs))
	if protocolParam != "" {
		msg = fmt.Sprintf("synced %d models (protocol=%s) to channel", len(slugs), protocolParam)
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": msg,
		"data":    len(slugs),
	})
}
// SyncZenmuxRatios fetches the Zenmux model catalog and merges its prices into
// one-api's ModelRatio and CompletionRatio tables, both in memory and
// persisted as options, so every priced Zenmux model is billed from upstream
// prices.
//
// Ratio formula (Zenmux prices are treated as USD per 1M tokens):
//
//	modelRatio      = input_price * billingratio.MILLI_USD
//	completionRatio = output_price / input_price (relative to input)
//
// A one-api model ratio of 1 means $0.002 per 1K tokens, i.e. $2 per 1M
// tokens. A raw $/1M price therefore has to be scaled by MILLI_USD (0.5 in
// upstream one-api). Storing the raw price would bill every model at twice its
// upstream price.
//
// Existing ratios for other models are preserved. Models with no positive
// input price are skipped. A non-positive output price leaves any existing
// completion ratio for that model untouched.
//
// POST /api/zenmux/sync-ratios — admin only.
func SyncZenmuxRatios(c *gin.Context) {
	fail := func(message string) {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": message})
	}

	catalog, err := fetchZenmuxCatalog()
	if err != nil {
		fail("fetch catalog failed: " + err.Error())
		return
	}

	// Load current ratios so we only ADD/UPDATE Zenmux entries. A decode
	// failure must abort: continuing with an empty map would overwrite, and so
	// wipe, every existing ratio.
	var modelRatioMap map[string]float64
	if err = json.Unmarshal([]byte(billingratio.ModelRatio2JSONString()), &modelRatioMap); err != nil {
		fail("load model ratio failed: " + err.Error())
		return
	}
	var completionRatioMap map[string]float64
	if err = json.Unmarshal([]byte(billingratio.CompletionRatio2JSONString()), &completionRatioMap); err != nil {
		fail("load completion ratio failed: " + err.Error())
		return
	}
	// A JSON "null" decodes to a nil map; make both writable.
	if modelRatioMap == nil {
		modelRatioMap = make(map[string]float64)
	}
	if completionRatioMap == nil {
		completionRatioMap = make(map[string]float64)
	}

	updated, skipped := 0, 0
	for _, item := range catalog {
		// !(x > 0) also skips NaN, which would otherwise break json.Marshal.
		if !(item.InputPrice > 0) {
			skipped++
			continue
		}
		modelRatioMap[item.Id] = item.InputPrice * billingratio.MILLI_USD
		if item.OutputPrice > 0 {
			completionRatioMap[item.Id] = item.OutputPrice / item.InputPrice
		}
		updated++
	}

	// Encode both tables before touching any state, so a marshal failure
	// cannot leave ModelRatio updated while CompletionRatio is not.
	newModelJSON, err := json.Marshal(modelRatioMap)
	if err != nil {
		fail("marshal model ratio failed: " + err.Error())
		return
	}
	newCompletionJSON, err := json.Marshal(completionRatioMap)
	if err != nil {
		fail("marshal completion ratio failed: " + err.Error())
		return
	}

	if err = billingratio.UpdateModelRatioByJSONString(string(newModelJSON)); err != nil {
		fail("update model ratio failed: " + err.Error())
		return
	}
	if err = model.UpdateOption("ModelRatio", string(newModelJSON)); err != nil {
		fail("save model ratio failed: " + err.Error())
		return
	}
	if err = billingratio.UpdateCompletionRatioByJSONString(string(newCompletionJSON)); err != nil {
		fail("update completion ratio failed: " + err.Error())
		return
	}
	if err = model.UpdateOption("CompletionRatio", string(newCompletionJSON)); err != nil {
		fail("save completion ratio failed: " + err.Error())
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": fmt.Sprintf("synced ratios for %d models (%d skipped — no pricing)", updated, skipped),
	})
}