ai-css/library/modelprovider/providers/openai/sdk.go

206 lines
5.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package openai
import (
modelprovider2 "ai-css/library/modelprovider"
"ai-css/library/modelprovider/config"
"context"
"errors"
"fmt"
"io"
"math/rand"
"net/http"
"strings"
"ai-css/library/logger"
)
// Provider adapts an OpenAI-compatible HTTP endpoint to the modelprovider
// abstractions (capabilities, completion, streaming, model listing).
type Provider struct {
	// httpClient is the client handed to every NewOpenaiClient call.
	httpClient *http.Client
	// conf supplies the API keys and the base URL of the endpoint.
	conf *config.ProviderConfig
	// blackApikey is allocated by New.
	// NOTE(review): it is not read anywhere in this file; key blacklisting
	// goes through the package-level blackKeyMgr instead. Confirm whether it
	// is still needed.
	blackApikey map[string]struct{}
}
// New returns a Provider configured by conf that performs its HTTP calls
// through httpc. A nil httpc falls back to http.DefaultClient.
//
// NOTE(review): http.DefaultClient has no timeout, so in that case request
// lifetime is bounded only by the ctx passed to each call.
func New(conf *config.ProviderConfig, httpc *http.Client) *Provider {
	client := httpc
	if client == nil {
		client = http.DefaultClient
	}
	p := &Provider{
		httpClient:  client,
		conf:        conf,
		blackApikey: make(map[string]struct{}),
	}
	return p
}
// Capabilities reports static metadata about this provider. The vendor is
// "openai", streaming is supported, and the context window is advertised as
// 128000 tokens.
func (p *Provider) Capabilities() modelprovider2.Capability {
	var caps modelprovider2.Capability
	caps.Vendor = "openai"
	caps.SupportsStreaming = true
	caps.MaxContextTokens = 128000
	return caps
}
// InvokeCompletion performs a non-streaming chat completion.
//
// It is currently a mock. It never contacts OpenAI and returns a canned
// response whose Model echoes req.Model. A nil req is rejected with an error
// instead of panicking on the req.Model dereference.
//
// TODO: map req onto the OpenAI Responses/Chat API, issue the HTTP request,
// and parse the returned payload.
func (p *Provider) InvokeCompletion(ctx context.Context, req *modelprovider2.ChatRequest) (*modelprovider2.ChatResponse, error) {
	if req == nil {
		return nil, errors.New("openai: nil chat request")
	}
	return &modelprovider2.ChatResponse{
		ID:      "mock-openai-id",
		Model:   req.Model,
		Content: "hello from openai (mock)",
		Meta:    modelprovider2.Meta{Vendor: "openai"},
	}, nil
}
// StreamCompletion sends req to the OpenAI Responses API and forwards the
// resulting stream events to h (translated by WrapStreamCallback).
//
// Usable API keys are the configured keys that blackKeyMgr has not
// blacklisted. They are tried in random order. After a failed attempt:
//   - a key rejected with "Incorrect API key provided" is blacklisted;
//   - a NetworkError stops the retry loop;
//   - an error matching io.EOF is returned to the caller unchanged;
//   - a cancelled/expired ctx, or a failure after at least one event was
//     already handed to h, also stops the loop. Retrying with another key
//     would replay output h has already received and spend another call;
//   - otherwise the next key is tried.
//
// Every other failure is logged and reported to the caller as the generic
// "cloud model server internal error". This includes having no usable key
// at all, which previously returned nil without emitting anything.
func (p *Provider) StreamCompletion(ctx context.Context, req *modelprovider2.ChatRequest, h modelprovider2.StreamChatCallback) (err error) {
	var (
		temp  = float32(0.7)
		store = false
	)
	callreq := &OpenAIResponsesRequest{
		Model:  req.Model,
		Input:  buildResponsesInput(req), // conversation content
		Stream: req.IsStream,             // streaming is essential here
		Store:  &store,                   // do not persist this conversation
	}
	// Only the gpt-4 family gets an explicit temperature.
	// NOTE(review): presumably other model families reject or ignore it; confirm.
	if IsGPT4Model(req.Model) {
		callreq.Temperature = &temp
	}

	apikeys := p.usableAPIKeys()
	// Log only the count: the keys themselves are credentials.
	logger.Debugf("call openai with %d usable apikeys", len(apikeys))
	if len(apikeys) == 0 {
		err = errors.New("no usable openai api key")
	}

	// delivered becomes true once h has been invoked. After that, retrying
	// with another key would re-stream the answer from the beginning.
	delivered := false
	onEvent := p.WrapStreamCallback(func(ev modelprovider2.StreamEvent) error {
		delivered = true
		return h(ev)
	})

	for _, ak := range apikeys {
		c := NewOpenaiClient(ak, p.conf.GetBaseUrl(), p.httpClient)
		_, err = c.callResponses(ctx, callreq, onEvent)
		if err == nil {
			return nil
		}
		logger.Errorf("do callResponses api failed err:%v", err)
		if isApikeyInvalid(err) {
			blackKeyMgr.AddBlackKey(ak)
		}
		if errors.Is(err, NetworkError) {
			break
		}
		if errors.Is(err, io.EOF) {
			return err
		}
		if delivered || ctx.Err() != nil {
			break
		}
	}

	if err != nil {
		logger.Errorf("call cloud model failed err:%v", err)
		err = errors.New("cloud model server internal error")
	}
	return err
}

// buildResponsesInput converts the provider-neutral messages in req into the
// message list sent as the Responses API input.
//
// System and user content is built with isInput=true and assistant content
// with isInput=false (see NewTextPart / NewImagePart). Parts whose type is
// neither text nor image are skipped. Previously they were appended as nil
// entries, which serialise as JSON null.
// NOTE(review): messages with any other role are still sent with an empty
// Role; confirm whether such roles can occur.
func buildResponsesInput(req *modelprovider2.ChatRequest) []OpenAIChatMessage {
	var out []OpenAIChatMessage
	for _, msg := range req.Messages {
		var (
			item    OpenAIChatMessage
			isInput bool
		)
		switch msg.Role {
		case modelprovider2.RoleSystem:
			item.Role = "system"
			isInput = true
		case modelprovider2.RoleAssistant:
			item.Role = "assistant"
		case modelprovider2.RoleUser:
			item.Role = "user"
			isInput = true
		}
		for _, part := range msg.Parts {
			var data interface{}
			switch part.Type {
			case modelprovider2.PartText:
				data = NewTextPart(isInput, part.Text)
			case modelprovider2.PartImage:
				data = NewImagePart(isInput, part.ImageURL)
			default:
				continue
			}
			item.Content = append(item.Content, data)
		}
		out = append(out, item)
	}
	return out
}

// usableAPIKeys returns the configured API keys that blackKeyMgr does not
// report as blacklisted, shuffled into random order.
func (p *Provider) usableAPIKeys() []string {
	var keys []string
	for _, k := range p.conf.GetApiKeys() {
		if blackKeyMgr.IsBlack(k) {
			continue
		}
		keys = append(keys, k)
	}
	rand.Shuffle(len(keys), func(i, j int) {
		keys[i], keys[j] = keys[j], keys[i]
	})
	return keys
}
// WrapStreamCallback adapts h to the per-event callback shape used by the
// Responses streaming client.
//
// Event handling:
//   - StreamRespondError / StreamRespondFailed: an error is returned, built
//     from event.Error (or from the whole event when Error is nil). h is not
//     called.
//   - StreamRespondOutputTextDelta: a non-empty Delta is forwarded to h as a
//     StreamDelta event. Empty deltas are dropped.
//   - StreamRespondComplete: h receives a StreamEnd event carrying
//     event.Response.Usage.OutputTokens.
//   - any other event type is ignored.
//
// An error returned by h is wrapped as "callback execution failed: ...".
func (p *Provider) WrapStreamCallback(h modelprovider2.StreamChatCallback) func(*ResponsesStreamEvent) error {
	emit := func(ev modelprovider2.StreamEvent) error {
		if cbErr := h(ev); cbErr != nil {
			return fmt.Errorf("callback execution failed: %w", cbErr)
		}
		return nil
	}

	return func(event *ResponsesStreamEvent) error {
		switch EventType(event.Type) {
		case StreamRespondError, StreamRespondFailed:
			if apiErr := event.Error; apiErr != nil {
				return fmt.Errorf("OpenAI streaming error: %s (%s)", apiErr.Message, apiErr.Code)
			}
			return fmt.Errorf("unknown OpenAI streaming error: %v", event)
		case StreamRespondOutputTextDelta:
			if event.Delta == "" {
				return nil
			}
			return emit(modelprovider2.StreamEvent{
				Kind: modelprovider2.StreamDelta,
				Text: event.Delta,
			})
		case StreamRespondComplete:
			return emit(modelprovider2.StreamEvent{
				Kind:         modelprovider2.StreamEnd,
				OutputTokens: event.Response.Usage.OutputTokens,
			})
		default:
			return nil
		}
	}
}
// ListModels returns the models exposed by the configured OpenAI endpoint.
//
// The configured API keys are tried in order until one lists the models
// successfully. On success every listed model is returned with no filtering.
// RealID and DisplayName are both the model ID, Vendor is the model's OwnedBy
// value, and Raw is the original entry.
//
// Errors:
//   - no API key configured: an error is returned. Previously this returned
//     (nil, nil), which is indistinguishable from "no models".
//   - every key fails: the last failure is returned, with a nil result.
//   - ctx cancelled/expired: the loop stops after the failing attempt instead
//     of trying the remaining keys.
func (p *Provider) ListModels(ctx context.Context) (result []modelprovider2.ModelInfo, err error) {
	apikeys := p.conf.GetApiKeys()
	if len(apikeys) == 0 {
		return nil, errors.New("openai: no api key configured")
	}

	var models *ModelsResponse
	for _, ak := range apikeys {
		c := NewOpenaiClient(ak, p.conf.GetBaseUrl(), p.httpClient)
		models, err = c.getModels(ctx)
		if err == nil {
			break
		}
		logger.Errorf("call models api failed err:%v", err)
		if ctx.Err() != nil {
			// The caller gave up; the remaining keys cannot succeed either.
			break
		}
	}
	if err != nil || models == nil {
		return nil, err
	}

	for _, model := range models.Data {
		// Model filtering is currently disabled:
		//if !FilterModel(model) {
		//	continue
		//}
		result = append(result, modelprovider2.ModelInfo{
			RealID:      model.ID,
			Raw:         model,
			Vendor:      model.OwnedBy,
			DisplayName: model.ID,
		})
	}
	return result, nil
}
// GetDefaultModel returns the provider's default model ID, "gpt-4o".
func (p *Provider) GetDefaultModel() string {
	const defaultModel = "gpt-4o"
	return defaultModel
}
func IsGPT4Model(model string) bool {
return strings.Contains(model, "gpt-4")
}
func isApikeyInvalid(err error) bool {
logger.Debugf("err:%v,sub:%s,contains:%v", err.Error(), "Incorrect API key provided", strings.Contains(err.Error(), "Incorrect API key provided"))
return strings.Contains(err.Error(), "Incorrect API key provided")
}