package openai

import (
	"context"
	"errors"
	"fmt"
	"io"
	"math/rand"
	"net/http"
	"strings"

	"ai-css/library/logger"
	modelprovider2 "ai-css/library/modelprovider"
	"ai-css/library/modelprovider/config"
)

// Provider implements the model-provider interface for OpenAI-compatible
// endpoints. It rotates across the configured API keys and skips keys that
// have been blacklisted after upstream rejection.
type Provider struct {
	httpClient  *http.Client
	conf        *config.ProviderConfig
	blackApikey map[string]struct{}
}

// New builds a Provider from conf. A nil httpc falls back to
// http.DefaultClient.
func New(conf *config.ProviderConfig, httpc *http.Client) *Provider {
	if httpc == nil {
		httpc = http.DefaultClient
	}
	return &Provider{conf: conf, httpClient: httpc, blackApikey: make(map[string]struct{})}
}

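// Construction sketch (hypothetical wiring: config.MustLoad and the timeout
// are illustrative assumptions, not part of this package):
//
//	conf := config.MustLoad("openai") // hypothetical loader
//	p := New(conf, &http.Client{Timeout: 60 * time.Second})
//	fmt.Println(p.Capabilities().Vendor) // "openai"
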
// Capabilities reports the static capability set advertised for this vendor.
func (p *Provider) Capabilities() modelprovider2.Capability {
	return modelprovider2.Capability{
		Vendor:            "openai",
		SupportsStreaming: true,
		MaxContextTokens:  128000,
	}
}

// InvokeCompletion performs a non-streaming completion. It is currently a
// stub that returns a mock response.
func (p *Provider) InvokeCompletion(ctx context.Context, req *modelprovider2.ChatRequest) (*modelprovider2.ChatResponse, error) {
	// TODO: map req onto the OpenAI Responses/Chat API, issue the HTTP
	// request, and parse the response.
	return &modelprovider2.ChatResponse{
		ID:      "mock-openai-id",
		Model:   req.Model,
		Content: "hello from openai (mock)",
		Meta:    modelprovider2.Meta{Vendor: "openai"},
	}, nil
}

// StreamCompletion converts the request into the OpenAI Responses format and
// then tries each non-blacklisted API key in random order until a call
// succeeds or the keys are exhausted.
func (p *Provider) StreamCompletion(ctx context.Context, req *modelprovider2.ChatRequest, h modelprovider2.StreamChatCallback) (err error) {
	var (
		temp  = float32(0.7)
		store = false

		inputMessages []OpenAIChatMessage
	)

	// Map the generic chat history onto OpenAI roles. System and user
	// messages become input parts; assistant messages become output parts.
	for _, msg := range req.Messages {
		var (
			item    OpenAIChatMessage
			isInput bool
		)
		switch msg.Role {
		case modelprovider2.RoleSystem:
			item.Role = "system"
			isInput = true
		case modelprovider2.RoleAssistant:
			item.Role = "assistant"
		case modelprovider2.RoleUser:
			item.Role = "user"
			isInput = true
		}
		for _, part := range msg.Parts {
			var data interface{}
			switch part.Type {
			case modelprovider2.PartText:
				data = NewTextPart(isInput, part.Text)
			case modelprovider2.PartImage:
				data = NewImagePart(isInput, part.ImageURL)
			}
			if data == nil {
				// Skip unsupported part types rather than appending nil.
				continue
			}
			item.Content = append(item.Content, data)
		}
		inputMessages = append(inputMessages, item)
	}

	var (
		callreq = &OpenAIResponsesRequest{
			Model:  req.Model,
			Input:  inputMessages, // chat message history
			Stream: req.IsStream,  // must be set so the server streams deltas
			Store:  &store,        // do not persist this conversation server-side
		}

		apikeys []string
	)

	if IsGPT4Model(req.Model) {
		callreq.Temperature = &temp
	}

	// Drop keys that were previously rejected by the upstream service.
	for _, item := range p.conf.GetApiKeys() {
		if ok := blackKeyMgr.IsBlack(item); ok {
			continue
		}
		apikeys = append(apikeys, item)
	}

	// Randomize key order so load spreads across keys between calls.
	rand.Shuffle(len(apikeys), func(i, j int) {
		apikeys[i], apikeys[j] = apikeys[j], apikeys[i]
	})

	// Log only the key count; the keys themselves are credentials.
	logger.Debugf("call openai with %d candidate apikeys", len(apikeys))

	// Try keys in turn. io.EOF marks a normally terminated stream, so it
	// falls through to the return; other errors advance to the next key,
	// except network errors, which abort the rotation entirely.
	for _, ak := range apikeys {
		c := NewOpenaiClient(ak, p.conf.GetBaseUrl(), p.httpClient)
		_, err = c.callResponses(ctx, callreq, p.WrapStreamCallback(h))
		if err != nil {
			logger.Errorf("do callResponses api failed err:%v", err)
			if isApikeyInvalid(err) {
				blackKeyMgr.AddBlackKey(ak)
			}
			if errors.Is(err, NetworkError) {
				break
			}
			if !errors.Is(err, io.EOF) {
				continue
			}
		}
		return
	}
	if err != nil {
		logger.Errorf("call cloud model failed err:%v", err)
		err = fmt.Errorf("cloud model server internal error")
	}
	return
}

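// collectStreamExample is an illustrative helper, not part of the provider
// API. It sketches how a caller can drive StreamCompletion and accumulate the
// streamed text; the ChatRequest and StreamEvent fields used here are exactly
// the ones this file already reads and writes.
func collectStreamExample(ctx context.Context, p *Provider, req *modelprovider2.ChatRequest) (string, error) {
	var sb strings.Builder
	req.IsStream = true // deltas only arrive when the Responses call streams
	err := p.StreamCompletion(ctx, req, func(ev modelprovider2.StreamEvent) error {
		switch ev.Kind {
		case modelprovider2.StreamDelta:
			sb.WriteString(ev.Text)
		case modelprovider2.StreamEnd:
			logger.Debugf("stream finished, output tokens:%v", ev.OutputTokens)
		}
		return nil
	})
	return sb.String(), err
}
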
// WrapStreamCallback adapts raw Responses stream events into the generic
// modelprovider stream events consumed by h.
func (p *Provider) WrapStreamCallback(h modelprovider2.StreamChatCallback) func(*ResponsesStreamEvent) error {
	return func(event *ResponsesStreamEvent) error {
		switch EventType(event.Type) {
		case StreamRespondError, StreamRespondFailed:
			if event.Error != nil {
				return fmt.Errorf("OpenAI streaming error: %s (%s)", event.Error.Message, event.Error.Code)
			}
			return fmt.Errorf("unknown OpenAI streaming error: %v", event)
		case StreamRespondOutputTextDelta:
			if event.Delta != "" {
				if err := h(modelprovider2.StreamEvent{
					Kind: modelprovider2.StreamDelta,
					Text: event.Delta,
				}); err != nil {
					return fmt.Errorf("callback execution failed: %w", err)
				}
			}
		case StreamRespondComplete:
			if err := h(modelprovider2.StreamEvent{
				Kind:         modelprovider2.StreamEnd,
				OutputTokens: event.Response.Usage.OutputTokens,
			}); err != nil {
				return fmt.Errorf("callback execution failed: %w", err)
			}
		}
		return nil
	}
}

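// Event mapping performed above (all other Responses event types are ignored
// and the stream simply continues):
//
//	StreamRespondError / StreamRespondFailed -> terminal error returned to the caller
//	StreamRespondOutputTextDelta             -> StreamDelta carrying one text chunk
//	StreamRespondComplete                    -> StreamEnd carrying the output token usage
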
// ListModels fetches the upstream model list, trying each configured API key
// until one request succeeds.
func (p *Provider) ListModels(ctx context.Context) (result []modelprovider2.ModelInfo, err error) {
	var models *ModelsResponse
	for _, ak := range p.conf.GetApiKeys() {
		c := NewOpenaiClient(ak, p.conf.GetBaseUrl(), p.httpClient)
		models, err = c.getModels(ctx)
		if err != nil {
			logger.Errorf("call models api failed err:%v", err)
			continue
		}
		break
	}
	if models == nil {
		return
	}
	for _, model := range models.Data {
		//if !FilterModel(model) {
		//	continue
		//}
		result = append(result, modelprovider2.ModelInfo{
			RealID:      model.ID,
			Raw:         model,
			Vendor:      model.OwnedBy,
			DisplayName: model.ID,
		})
	}

	return
}

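// Listing sketch (hypothetical caller; RealID and Vendor are the fields
// populated above):
//
//	models, err := p.ListModels(ctx)
//	if err != nil {
//		return err
//	}
//	for _, m := range models {
//		fmt.Println(m.RealID, m.Vendor)
//	}
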
// GetDefaultModel returns the model used when the caller does not pick one.
func (p *Provider) GetDefaultModel() string {
	return "gpt-4o"
}

// IsGPT4Model reports whether the model belongs to the GPT-4 family. The
// substring match also covers variants such as gpt-4o and gpt-4-turbo.
func IsGPT4Model(model string) bool {
	return strings.Contains(model, "gpt-4")
}

// isApikeyInvalid reports whether the upstream error indicates a rejected
// API key, matching OpenAI's standard "Incorrect API key provided" message.
func isApikeyInvalid(err error) bool {
	invalid := strings.Contains(err.Error(), "Incorrect API key provided")
	logger.Debugf("check apikey invalid err:%v invalid:%v", err, invalid)
	return invalid
}

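// For reference, the upstream 401 body this matches typically looks like the
// following (shape recalled from OpenAI's standard invalid-key response;
// illustrative only, not captured from this service):
//
//	{"error": {"message": "Incorrect API key provided: sk-***. ...",
//	           "type": "invalid_request_error", "code": "invalid_api_key"}}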