package openai

import (
	"context"
	"errors"
	"fmt"
	"io"
	"math/rand"
	"net/http"
	"strings"

	modelprovider2 "ai-css/library/modelprovider"
	"ai-css/library/modelprovider/config"

	"ai-css/library/logger"
)

// Provider is the OpenAI implementation of a model provider. Requests are
// sent through httpClient using the API keys and base URL taken from conf.
type Provider struct {
	httpClient *http.Client
	conf       *config.ProviderConfig
	// NOTE(review): blackApikey is initialised in New but not read by any
	// method in this file; key blacklisting goes through the package-level
	// blackKeyMgr instead. Confirm whether the field can be removed.
	blackApikey map[string]struct{}
}

// New creates a Provider for conf. A nil httpc falls back to
// http.DefaultClient.
func New(conf *config.ProviderConfig, httpc *http.Client) *Provider {
	if httpc == nil {
		httpc = http.DefaultClient
	}
	return &Provider{conf: conf, httpClient: httpc, blackApikey: make(map[string]struct{})}
}

// Capabilities reports the static capabilities advertised for this vendor.
func (p *Provider) Capabilities() modelprovider2.Capability {
	return modelprovider2.Capability{
		Vendor:            "openai",
		SupportsStreaming: true,
		MaxContextTokens:  128000,
	}
}

// InvokeCompletion is currently a mock: it ignores ctx and the request
// messages and returns a fixed response that echoes req.Model.
func (p *Provider) InvokeCompletion(ctx context.Context, req *modelprovider2.ChatRequest) (*modelprovider2.ChatResponse, error) {
	// TODO: map req onto the OpenAI Responses/Chat API, issue the HTTP call
	// and parse the response.
	return &modelprovider2.ChatResponse{
		ID:      "mock-openai-id",
		Model:   req.Model,
		Content: "hello from openai (mock)",
		Meta:    modelprovider2.Meta{Vendor: "openai"},
	}, nil
}

// StreamCompletion sends req to the OpenAI Responses API and forwards the
// streamed events to h (via WrapStreamCallback).
//
// API keys not blacklisted by blackKeyMgr are tried in random order. A key
// whose error message reports an invalid key is added to the blacklist.
// Failover to the next key stops on a NetworkError, on context cancellation,
// or once every key has failed; all of those outcomes (and having no usable
// key at all) are reported as a generic error. An io.EOF error is returned
// unchanged without trying another key.
func (p *Provider) StreamCompletion(ctx context.Context, req *modelprovider2.ChatRequest, h modelprovider2.StreamChatCallback) (err error) {
	var (
		temp  = float32(0.7)
		store = false
	)

	var inputMessages []OpenAIChatMessage
	for _, msg := range req.Messages {
		var (
			item OpenAIChatMessage
			// System and user parts are built with isInput=true,
			// assistant parts with false.
			isInput bool
		)
		switch msg.Role {
		case modelprovider2.RoleSystem:
			item.Role = "system"
			isInput = true
		case modelprovider2.RoleAssistant:
			item.Role = "assistant"
		case modelprovider2.RoleUser:
			item.Role = "user"
			isInput = true
		}
		for _, part := range msg.Parts {
			var data interface{}
			switch part.Type {
			case modelprovider2.PartText:
				data = NewTextPart(isInput, part.Text)
			case modelprovider2.PartImage:
				data = NewImagePart(isInput, part.ImageURL)
			default:
				// Skip part types this provider cannot map instead of
				// appending a nil entry to the message content.
				continue
			}
			item.Content = append(item.Content, data)
		}
		inputMessages = append(inputMessages, item)
	}

	callreq := &OpenAIResponsesRequest{
		Model:  req.Model,
		Input:  inputMessages, // chat content
		Stream: req.IsStream,  // streaming flag is essential here
		Store:  &store,        // do not persist this conversation
	}
	if IsGPT4Model(req.Model) {
		callreq.Temperature = &temp
	}

	var apikeys []string
	for _, ak := range p.conf.GetApiKeys() {
		if !blackKeyMgr.IsBlack(ak) {
			apikeys = append(apikeys, ak)
		}
	}
	if len(apikeys) == 0 {
		// Without a usable key the failover loop below would not run and nil
		// would be returned, reporting success although nothing was streamed.
		logger.Errorf("call cloud model failed: no usable openai apikey")
		return fmt.Errorf("cloud model server internal error")
	}
	rand.Shuffle(len(apikeys), func(i, j int) {
		apikeys[i], apikeys[j] = apikeys[j], apikeys[i]
	})
	// Log only the count: the keys themselves are secrets.
	logger.Debugf("call openai apikeys count:%d", len(apikeys))

	callback := p.WrapStreamCallback(h)
	for _, ak := range apikeys {
		c := NewOpenaiClient(ak, p.conf.GetBaseUrl(), p.httpClient)
		_, err = c.callResponses(ctx, callreq, callback)
		if err == nil {
			return nil
		}
		logger.Errorf("do callResponses api failed err:%v", err)
		if isApikeyInvalid(err) {
			blackKeyMgr.AddBlackKey(ak)
		}
		if errors.Is(err, NetworkError) {
			break
		}
		if errors.Is(err, io.EOF) {
			// NOTE(review): io.EOF is handed back unchanged and not retried;
			// presumably the stream had already started. Confirm against
			// callResponses.
			return err
		}
		if ctx.Err() != nil {
			// The caller gave up; every remaining key would fail the same way.
			break
		}
	}

	// Every path reaching here left a non-nil err from the last attempt.
	logger.Errorf("call cloud model failed err:%v", err)
	return fmt.Errorf("cloud model server internal error")
}

// WrapStreamCallback adapts h to the per-event callback passed to
// callResponses. Error and failed events become errors, non-empty text
// deltas are forwarded as StreamDelta, and the completion event is forwarded
// as StreamEnd with the reported output token count. Other event types are
// ignored. Errors returned by h are wrapped with %w.
func (p *Provider) WrapStreamCallback(h modelprovider2.StreamChatCallback) func(*ResponsesStreamEvent) error {
	return func(event *ResponsesStreamEvent) error {
		switch EventType(event.Type) {
		case StreamRespondError, StreamRespondFailed:
			if event.Error != nil {
				return fmt.Errorf("OpenAI streaming error: %s (%s)", event.Error.Message, event.Error.Code)
			}
			return fmt.Errorf("unknown OpenAI streaming error: %v", event)
		case StreamRespondOutputTextDelta:
			if event.Delta == "" {
				return nil
			}
			if err := h(modelprovider2.StreamEvent{
				Kind: modelprovider2.StreamDelta,
				Text: event.Delta,
			}); err != nil {
				return fmt.Errorf("callback execution failed: %w", err)
			}
		case StreamRespondComplete:
			// NOTE(review): assumes event.Response and its Usage are populated
			// on completion events; if Response is a pointer and missing,
			// this dereference panics.
			if err := h(modelprovider2.StreamEvent{
				Kind:         modelprovider2.StreamEnd,
				OutputTokens: event.Response.Usage.OutputTokens,
			}); err != nil {
				return fmt.Errorf("callback execution failed: %w", err)
			}
		}
		return nil
	}
}

// ListModels returns the models reported for the first API key whose models
// request succeeds. If every key fails, result is nil and err is the last
// failure; with no configured keys both are nil. Key attempts stop early if
// ctx is cancelled.
func (p *Provider) ListModels(ctx context.Context) (result []modelprovider2.ModelInfo, err error) {
	var models *ModelsResponse
	for _, ak := range p.conf.GetApiKeys() {
		c := NewOpenaiClient(ak, p.conf.GetBaseUrl(), p.httpClient)
		models, err = c.getModels(ctx)
		if err == nil {
			break
		}
		logger.Errorf("call models api failed err:%v", err)
		if ctx.Err() != nil {
			break
		}
	}
	if models == nil {
		return nil, err
	}

	for _, model := range models.Data {
		//if !FilterModel(model) {
		//	continue
		//}
		result = append(result, modelprovider2.ModelInfo{
			RealID:      model.ID,
			Raw:         model,
			Vendor:      model.OwnedBy,
			DisplayName: model.ID,
		})
	}
	return result, err
}

// GetDefaultModel returns the model ID used when none is specified.
func (p *Provider) GetDefaultModel() string {
	return "gpt-4o"
}

// IsGPT4Model reports whether model contains the substring "gpt-4"; this also
// matches variants such as "gpt-4o".
func IsGPT4Model(model string) bool {
	return strings.Contains(model, "gpt-4")
}

// isApikeyInvalid reports whether err carries OpenAI's "Incorrect API key
// provided" message. A nil error is never an invalid-key error.
func isApikeyInvalid(err error) bool {
	if err == nil {
		return false
	}
	const marker = "Incorrect API key provided"
	invalid := strings.Contains(err.Error(), marker)
	logger.Debugf("err:%v,sub:%s,contains:%v", err.Error(), marker, invalid)
	return invalid
}