hjjjj c734c541b2
Some checks failed
CI / Unit tests (push) Has been cancelled
CI / commit_lint (push) Has been cancelled
feat: 添加视频URL支持及Zenmux集成
refactor: 重构Gemini适配器以支持多模态输入
fix: 修复React Hooks依赖警告
style: 清理未使用的导入和代码
docs: 更新用户界面文本和提示
perf: 优化图像和视频URL处理性能
test: 添加数据迁移工具和测试
build: 更新依赖项和.gitignore
chore: 同步Zenmux模型和价格比例
2026-03-12 17:53:27 +08:00

240 lines
7.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package controller
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
func getImageRequest(c *gin.Context, _ int) (*relaymodel.ImageRequest, error) {
imageRequest := &relaymodel.ImageRequest{}
err := common.UnmarshalBodyReusable(c, imageRequest)
if err != nil {
return nil, err
}
if imageRequest.N == 0 {
imageRequest.N = 1
}
if imageRequest.Size == "" {
imageRequest.Size = "1024x1024"
}
if imageRequest.Model == "" {
imageRequest.Model = "dall-e-2"
}
return imageRequest, nil
}
func isValidImageSize(model string, size string) bool {
if model == "cogview-3" || billingratio.ImageSizeRatios[model] == nil {
return true
}
_, ok := billingratio.ImageSizeRatios[model][size]
return ok
}
func isValidImagePromptLength(model string, promptLength int) bool {
maxPromptLength, ok := billingratio.ImagePromptLengthLimitations[model]
return !ok || promptLength <= maxPromptLength
}
func isWithinRange(element string, value int) bool {
amounts, ok := billingratio.ImageGenerationAmounts[element]
return !ok || (value >= amounts[0] && value <= amounts[1])
}
func getImageSizeRatio(model string, size string) float64 {
if ratio, ok := billingratio.ImageSizeRatios[model][size]; ok {
return ratio
}
return 1
}
func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *meta.Meta) *relaymodel.ErrorWithStatusCode {
// check prompt length
if imageRequest.Prompt == "" {
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
}
// model validation
if !isValidImageSize(imageRequest.Model, imageRequest.Size) {
return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
}
if !isValidImagePromptLength(imageRequest.Model, len(imageRequest.Prompt)) {
return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
}
// Number of generated images validation
if !isWithinRange(imageRequest.Model, imageRequest.N) {
return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
}
return nil
}
func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
if imageRequest == nil {
return 0, errors.New("imageRequest is nil")
}
imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size)
if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
if imageRequest.Size == "1024x1024" {
imageCostRatio *= 2
} else {
imageCostRatio *= 1.5
}
}
return imageCostRatio, nil
}
func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
ctx := c.Request.Context()
meta := meta.GetByContext(c)
imageRequest, err := getImageRequest(c, meta.Mode)
if err != nil {
logger.Errorf(ctx, "getImageRequest failed: %s", err.Error())
return openai.ErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
}
// map model name
var isModelMapped bool
meta.OriginModelName = imageRequest.Model
imageRequest.Model, isModelMapped = getMappedModelName(imageRequest.Model, meta.ModelMapping)
meta.ActualModelName = imageRequest.Model
// model validation
bizErr := validateImageRequest(imageRequest, meta)
if bizErr != nil {
return bizErr
}
imageCostRatio, err := getImageCostRatio(imageRequest)
if err != nil {
return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError)
}
imageModel := imageRequest.Model
// Convert the original image model
imageRequest.Model, _ = getMappedModelName(imageRequest.Model, billingratio.ImageOriginModelName)
c.Set("response_format", imageRequest.ResponseFormat)
var requestBody io.Reader
if isModelMapped || meta.ChannelType == channeltype.Azure { // make Azure channel request body
jsonStr, err := json.Marshal(imageRequest)
if err != nil {
return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
adaptor := relay.GetAdaptor(meta.APIType)
if adaptor == nil {
return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
}
adaptor.Init(meta)
// these adaptors need to convert the request
switch meta.ChannelType {
case channeltype.Zhipu,
channeltype.Ali,
channeltype.Replicate,
channeltype.Baidu,
channeltype.Gemini:
finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
if err != nil {
return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
}
jsonStr, err := json.Marshal(finalRequest)
if err != nil {
return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
}
modelRatio := billingratio.GetModelRatio(imageModel, meta.ChannelType)
groupRatio := billingratio.GetGroupRatio(meta.Group)
ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
var quota int64
switch meta.ChannelType {
case channeltype.Replicate:
// replicate always return 1 image
quota = int64(ratio * imageCostRatio * 1000)
default:
quota = int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
}
if userQuota-quota < 0 {
return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
// do request
resp, err := adaptor.DoRequest(c, meta, requestBody)
if err != nil {
logger.Errorf(ctx, "DoRequest failed: %s", err.Error())
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
defer func(ctx context.Context) {
if resp != nil &&
resp.StatusCode != http.StatusCreated && // replicate returns 201
resp.StatusCode != http.StatusOK {
return
}
err := model.PostConsumeTokenQuota(meta.TokenId, quota)
if err != nil {
logger.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(ctx, meta.UserId)
if err != nil {
logger.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString(ctxkey.TokenName)
logContent := fmt.Sprintf("倍率:%.2f × %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, &model.Log{
UserId: meta.UserId,
ChannelId: meta.ChannelId,
PromptTokens: 0,
CompletionTokens: 0,
ModelName: imageRequest.Model,
TokenName: tokenName,
Quota: int(quota),
Content: logContent,
})
model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
channelId := c.GetInt(ctxkey.ChannelId)
model.UpdateChannelUsedQuota(channelId, quota)
}
}(c.Request.Context())
// do response
_, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
return respErr
}
return nil
}