260 lines
6.5 KiB
Go
Executable File
260 lines
6.5 KiB
Go
Executable File
package openai
|
|
|
|
import (
|
|
modelprovider2 "ai-css/library/modelprovider"
|
|
"ai-css/library/modelprovider/config"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"math/rand"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"ai-css/library/logger"
|
|
|
|
"github.com/openai/openai-go/v3"
|
|
"github.com/openai/openai-go/v3/option"
|
|
"github.com/openai/openai-go/v3/responses"
|
|
"github.com/openai/openai-go/v3/shared"
|
|
)
|
|
|
|
type Provider struct {
|
|
httpClient *http.Client
|
|
conf *config.ProviderConfig
|
|
blackApikey map[string]struct{}
|
|
}
|
|
|
|
func New(conf *config.ProviderConfig, httpc *http.Client) *Provider {
|
|
if httpc == nil {
|
|
httpc = http.DefaultClient
|
|
}
|
|
return &Provider{conf: conf, httpClient: httpc, blackApikey: make(map[string]struct{})}
|
|
}
|
|
|
|
func (p *Provider) Capabilities() modelprovider2.Capability {
|
|
return modelprovider2.Capability{
|
|
Vendor: "openai",
|
|
SupportsStreaming: true,
|
|
MaxContextTokens: 128000,
|
|
}
|
|
}
|
|
|
|
func (p *Provider) InvokeCompletion(ctx context.Context, req *modelprovider2.ChatRequest) (*modelprovider2.ChatResponse, error) {
|
|
var respParam = responses.ResponseNewParams{
|
|
Model: req.Model,
|
|
}
|
|
var msg []responses.ResponseInputItemUnionParam
|
|
for _, item := range req.Messages {
|
|
msg = append(msg, modelprovider2.PartsToResponseInputItemUnionParam(item.Role, item.Parts))
|
|
}
|
|
respParam.Input = responses.ResponseNewParamsInputUnion{
|
|
OfInputItemList: msg,
|
|
}
|
|
logger.Infof("ai chat msg:%v", msg)
|
|
|
|
var opts []option.RequestOption
|
|
|
|
if p.conf != nil {
|
|
if len(p.conf.GetApiKeys()) > 0 {
|
|
opts = []option.RequestOption{
|
|
option.WithAPIKey(p.conf.GetApiKeys()[0]),
|
|
}
|
|
}
|
|
if p.conf.GetBaseUrl() != "" {
|
|
opts = append(opts, option.WithBaseURL(p.conf.GetBaseUrl()))
|
|
}
|
|
if p.httpClient != nil {
|
|
opts = append(opts, option.WithHTTPClient(p.httpClient))
|
|
}
|
|
}
|
|
|
|
client := openai.NewClient(opts...)
|
|
|
|
resp, err := client.Responses.New(context.TODO(), respParam)
|
|
if err != nil {
|
|
logger.Errorf("error while calling OpenAI response api failed err: %v", err)
|
|
return nil, err
|
|
}
|
|
if resp.Error.Code != "" {
|
|
logger.Errorf("error while calling OpenAI response api failed err: %v", resp.Error.RawJSON())
|
|
return nil, fmt.Errorf("call openai response failed err: %v", resp.Error.RawJSON())
|
|
}
|
|
|
|
content := ""
|
|
for _, item := range resp.Output {
|
|
if item.Type == "message" {
|
|
for _, cn := range item.Content {
|
|
if cn.Type == "output_text" {
|
|
content = cn.Text
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return &modelprovider2.ChatResponse{
|
|
ID: resp.ID,
|
|
Model: resp.Model,
|
|
Content: content,
|
|
Meta: modelprovider2.Meta{Vendor: "openai"},
|
|
}, nil
|
|
}
|
|
|
|
func (p *Provider) StreamCompletion(ctx context.Context, req *modelprovider2.ChatRequest, h modelprovider2.StreamChatCallback) (err error) {
|
|
var (
|
|
temp = float32(0.7)
|
|
store = false
|
|
|
|
inputMessages []OpenAIChatMessage
|
|
)
|
|
|
|
for _, msg := range req.Messages {
|
|
var (
|
|
item OpenAIChatMessage
|
|
isInput bool
|
|
)
|
|
switch msg.Role {
|
|
case modelprovider2.RoleSystem:
|
|
item.Role = "system"
|
|
isInput = true
|
|
case modelprovider2.RoleAssistant:
|
|
item.Role = "assistant"
|
|
case modelprovider2.RoleUser:
|
|
item.Role = "user"
|
|
isInput = true
|
|
}
|
|
for _, part := range msg.Parts {
|
|
var data interface{}
|
|
switch part.Type {
|
|
case modelprovider2.PartText:
|
|
data = NewTextPart(isInput, part.Text)
|
|
case modelprovider2.PartImage:
|
|
data = NewImagePart(isInput, part.ImageURL)
|
|
}
|
|
item.Content = append(item.Content, data)
|
|
}
|
|
inputMessages = append(inputMessages, item)
|
|
}
|
|
|
|
var (
|
|
callreq = &OpenAIResponsesRequest{
|
|
Model: req.Model,
|
|
Input: inputMessages, // 聊天内容
|
|
Stream: req.IsStream, // 流式很关键
|
|
Store: &store, // 不持久化这次对话
|
|
}
|
|
|
|
apikeys []string
|
|
)
|
|
|
|
if IsGPT4Model(req.Model) {
|
|
callreq.Temperature = &temp
|
|
}
|
|
|
|
for _, item := range p.conf.GetApiKeys() {
|
|
if ok := blackKeyMgr.IsBlack(item); ok {
|
|
continue
|
|
}
|
|
apikeys = append(apikeys, item)
|
|
}
|
|
|
|
rand.Shuffle(len(apikeys), func(i, j int) {
|
|
apikeys[i], apikeys[j] = apikeys[j], apikeys[i]
|
|
})
|
|
|
|
logger.Debugf("call openai apikeys:%v", apikeys)
|
|
|
|
for _, ak := range apikeys {
|
|
c := NewOpenaiClient(ak, p.conf.GetBaseUrl(), p.httpClient)
|
|
_, err = c.callResponses(ctx, callreq, p.WrapStreamCallback(h))
|
|
if err != nil {
|
|
logger.Errorf("do callResponses api failed err:%v", err)
|
|
if isApikeyInvalid(err) {
|
|
blackKeyMgr.AddBlackKey(ak)
|
|
}
|
|
if errors.Is(err, NetworkError) {
|
|
break
|
|
}
|
|
if !errors.Is(err, io.EOF) {
|
|
continue
|
|
}
|
|
}
|
|
return
|
|
}
|
|
if err != nil {
|
|
logger.Errorf("call cloud model failed err:%v", err)
|
|
err = fmt.Errorf("cloud model server internal error")
|
|
}
|
|
return
|
|
}
|
|
|
|
func (p *Provider) WrapStreamCallback(h modelprovider2.StreamChatCallback) func(*ResponsesStreamEvent) error {
|
|
return func(event *ResponsesStreamEvent) error {
|
|
switch EventType(event.Type) {
|
|
case StreamRespondError, StreamRespondFailed:
|
|
if event.Error != nil {
|
|
return fmt.Errorf("OpenAI streaming error: %s (%s)", event.Error.Message, event.Error.Code)
|
|
}
|
|
return fmt.Errorf("unknown OpenAI streaming error: %v", event)
|
|
case StreamRespondOutputTextDelta:
|
|
if event.Delta != "" {
|
|
if err := h(modelprovider2.StreamEvent{
|
|
Kind: modelprovider2.StreamDelta,
|
|
Text: event.Delta,
|
|
}); err != nil {
|
|
return fmt.Errorf("callback execution failed: %w", err)
|
|
}
|
|
}
|
|
case StreamRespondComplete:
|
|
if err := h(modelprovider2.StreamEvent{
|
|
Kind: modelprovider2.StreamEnd,
|
|
OutputTokens: event.Response.Usage.OutputTokens,
|
|
}); err != nil {
|
|
return fmt.Errorf("callback execution failed: %w", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (p *Provider) ListModels(ctx context.Context) (result []modelprovider2.ModelInfo, err error) {
|
|
var models *ModelsResponse
|
|
for _, ak := range p.conf.GetApiKeys() {
|
|
c := NewOpenaiClient(ak, p.conf.GetBaseUrl(), p.httpClient)
|
|
models, err = c.getModels(ctx)
|
|
if err != nil {
|
|
logger.Errorf("call responses api failed err:%v", err)
|
|
continue
|
|
}
|
|
break
|
|
}
|
|
if models == nil {
|
|
return
|
|
}
|
|
for _, model := range models.Data {
|
|
//if !FilterModel(model) {
|
|
// continue
|
|
//}
|
|
result = append(result, modelprovider2.ModelInfo{
|
|
RealID: model.ID,
|
|
Raw: model,
|
|
Vendor: model.OwnedBy,
|
|
DisplayName: model.ID,
|
|
})
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (p *Provider) GetDefaultModel() shared.ResponsesModel {
|
|
return responses.ChatModelGPT5Mini
|
|
}
|
|
|
|
// IsGPT4Model reports whether model contains the substring "gpt-4". The match
// is case-sensitive and also covers variants such as "gpt-4o" and "gpt-4.1".
func IsGPT4Model(model string) bool {
	const family = "gpt-4"
	return strings.Index(model, family) != -1
}
|
|
|
|
func isApikeyInvalid(err error) bool {
|
|
logger.Debugf("err:%v,sub:%s,contains:%v", err.Error(), "Incorrect API key provided", strings.Contains(err.Error(), "Incorrect API key provided"))
|
|
return strings.Contains(err.Error(), "Incorrect API key provided")
|
|
}
|