package openai import ( modelprovider2 "ai-css/library/modelprovider" "ai-css/library/modelprovider/config" "context" "errors" "fmt" "io" "math/rand" "net/http" "strings" "ai-css/library/logger" "github.com/openai/openai-go/v3" "github.com/openai/openai-go/v3/option" "github.com/openai/openai-go/v3/responses" "github.com/openai/openai-go/v3/shared" ) type Provider struct { httpClient *http.Client conf *config.ProviderConfig blackApikey map[string]struct{} } func New(conf *config.ProviderConfig, httpc *http.Client) *Provider { if httpc == nil { httpc = http.DefaultClient } return &Provider{conf: conf, httpClient: httpc, blackApikey: make(map[string]struct{})} } func (p *Provider) Capabilities() modelprovider2.Capability { return modelprovider2.Capability{ Vendor: "openai", SupportsStreaming: true, MaxContextTokens: 128000, } } func (p *Provider) InvokeCompletion(ctx context.Context, req *modelprovider2.ChatRequest) (*modelprovider2.ChatResponse, error) { var respParam = responses.ResponseNewParams{ Model: req.Model, } var msg []responses.ResponseInputItemUnionParam for _, item := range req.Messages { msg = append(msg, modelprovider2.PartsToResponseInputItemUnionParam(item.Role, item.Parts)) } respParam.Input = responses.ResponseNewParamsInputUnion{ OfInputItemList: msg, } logger.Infof("ai chat msg:%v", msg) var opts []option.RequestOption if p.conf != nil { if len(p.conf.GetApiKeys()) > 0 { opts = []option.RequestOption{ option.WithAPIKey(p.conf.GetApiKeys()[0]), } } if p.conf.GetBaseUrl() != "" { opts = append(opts, option.WithBaseURL(p.conf.GetBaseUrl())) } if p.httpClient != nil { opts = append(opts, option.WithHTTPClient(p.httpClient)) } } client := openai.NewClient(opts...) resp, err := client.Responses.New(context.TODO(), respParam) if err != nil { logger.Errorf("error while calling OpenAI response api failed err: %v", err) return nil, err } if resp.Error.Code != "" { logger.Errorf("error while calling OpenAI response api failed err: %v", resp.Error.RawJSON()) return nil, fmt.Errorf("call openai response failed err: %v", resp.Error.RawJSON()) } content := "" for _, item := range resp.Output { if item.Type == "message" { for _, cn := range item.Content { if cn.Type == "output_text" { content = cn.Text } } } } return &modelprovider2.ChatResponse{ ID: resp.ID, Model: resp.Model, Content: content, Meta: modelprovider2.Meta{Vendor: "openai"}, }, nil } func (p *Provider) StreamCompletion(ctx context.Context, req *modelprovider2.ChatRequest, h modelprovider2.StreamChatCallback) (err error) { var ( temp = float32(0.7) store = false inputMessages []OpenAIChatMessage ) for _, msg := range req.Messages { var ( item OpenAIChatMessage isInput bool ) switch msg.Role { case modelprovider2.RoleSystem: item.Role = "system" isInput = true case modelprovider2.RoleAssistant: item.Role = "assistant" case modelprovider2.RoleUser: item.Role = "user" isInput = true } for _, part := range msg.Parts { var data interface{} switch part.Type { case modelprovider2.PartText: data = NewTextPart(isInput, part.Text) case modelprovider2.PartImage: data = NewImagePart(isInput, part.ImageURL) } item.Content = append(item.Content, data) } inputMessages = append(inputMessages, item) } var ( callreq = &OpenAIResponsesRequest{ Model: req.Model, Input: inputMessages, // 聊天内容 Stream: req.IsStream, // 流式很关键 Store: &store, // 不持久化这次对话 } apikeys []string ) if IsGPT4Model(req.Model) { callreq.Temperature = &temp } for _, item := range p.conf.GetApiKeys() { if ok := blackKeyMgr.IsBlack(item); ok { continue } apikeys = append(apikeys, item) } rand.Shuffle(len(apikeys), func(i, j int) { apikeys[i], apikeys[j] = apikeys[j], apikeys[i] }) logger.Debugf("call openai apikeys:%v", apikeys) for _, ak := range apikeys { c := NewOpenaiClient(ak, p.conf.GetBaseUrl(), p.httpClient) _, err = c.callResponses(ctx, callreq, p.WrapStreamCallback(h)) if err != nil { logger.Errorf("do callResponses api failed err:%v", err) if isApikeyInvalid(err) { blackKeyMgr.AddBlackKey(ak) } if errors.Is(err, NetworkError) { break } if !errors.Is(err, io.EOF) { continue } } return } if err != nil { logger.Errorf("call cloud model failed err:%v", err) err = fmt.Errorf("cloud model server internal error") } return } func (p *Provider) WrapStreamCallback(h modelprovider2.StreamChatCallback) func(*ResponsesStreamEvent) error { return func(event *ResponsesStreamEvent) error { switch EventType(event.Type) { case StreamRespondError, StreamRespondFailed: if event.Error != nil { return fmt.Errorf("OpenAI streaming error: %s (%s)", event.Error.Message, event.Error.Code) } return fmt.Errorf("unknown OpenAI streaming error: %v", event) case StreamRespondOutputTextDelta: if event.Delta != "" { if err := h(modelprovider2.StreamEvent{ Kind: modelprovider2.StreamDelta, Text: event.Delta, }); err != nil { return fmt.Errorf("callback execution failed: %w", err) } } case StreamRespondComplete: if err := h(modelprovider2.StreamEvent{ Kind: modelprovider2.StreamEnd, OutputTokens: event.Response.Usage.OutputTokens, }); err != nil { return fmt.Errorf("callback execution failed: %w", err) } } return nil } } func (p *Provider) ListModels(ctx context.Context) (result []modelprovider2.ModelInfo, err error) { var models *ModelsResponse for _, ak := range p.conf.GetApiKeys() { c := NewOpenaiClient(ak, p.conf.GetBaseUrl(), p.httpClient) models, err = c.getModels(ctx) if err != nil { logger.Errorf("call responses api failed err:%v", err) continue } break } if models == nil { return } for _, model := range models.Data { //if !FilterModel(model) { // continue //} result = append(result, modelprovider2.ModelInfo{ RealID: model.ID, Raw: model, Vendor: model.OwnedBy, DisplayName: model.ID, }) } return } func (p *Provider) GetDefaultModel() shared.ResponsesModel { return responses.ChatModelGPT5Mini } func IsGPT4Model(model string) bool { return strings.Contains(model, "gpt-4") } func isApikeyInvalid(err error) bool { logger.Debugf("err:%v,sub:%s,contains:%v", err.Error(), "Incorrect API key provided", strings.Contains(err.Error(), "Incorrect API key provided")) return strings.Contains(err.Error(), "Incorrect API key provided") }