// ai-css/library/modelprovider/providers/openai/sdk.go
// 2026-02-12 08:50:11 +00:00 · 260 lines · 6.5 KiB · Go · executable file
package openai
import (
modelprovider2 "ai-css/library/modelprovider"
"ai-css/library/modelprovider/config"
"context"
"errors"
"fmt"
"io"
"math/rand"
"net/http"
"strings"
"ai-css/library/logger"
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/option"
"github.com/openai/openai-go/v3/responses"
"github.com/openai/openai-go/v3/shared"
)
// Provider implements the OpenAI model provider. It wraps the official
// openai-go SDK for blocking completions and a project-local streaming
// client (NewOpenaiClient) for streamed completions, drawing API keys and
// the base URL from the supplied provider configuration.
type Provider struct {
	// httpClient is used for all outbound calls; defaults to
	// http.DefaultClient when New is given nil.
	httpClient *http.Client
	// conf supplies API keys and the base URL; may be nil.
	conf *config.ProviderConfig
	// blackApikey is never read or written in this file — key blacklisting
	// below goes through the package-level blackKeyMgr instead.
	// NOTE(review): possibly dead state; confirm before removing.
	blackApikey map[string]struct{}
}
// New constructs a Provider backed by conf. A nil httpc falls back to
// http.DefaultClient so the zero configuration still works.
func New(conf *config.ProviderConfig, httpc *http.Client) *Provider {
	client := httpc
	if client == nil {
		client = http.DefaultClient
	}
	return &Provider{
		conf:        conf,
		httpClient:  client,
		blackApikey: map[string]struct{}{},
	}
}
// Capabilities reports the static feature set advertised for the OpenAI
// provider: streaming support and a 128k-token context window.
func (p *Provider) Capabilities() modelprovider2.Capability {
	caps := modelprovider2.Capability{
		Vendor:            "openai",
		SupportsStreaming: true,
		MaxContextTokens:  128000,
	}
	return caps
}
// InvokeCompletion performs a blocking (non-streaming) call to the OpenAI
// Responses API and returns the final assistant text.
//
// The request messages are converted to the SDK's input-item union format,
// then the configured API key, base URL, and HTTP client are applied when
// present. On an API-level error payload (resp.Error.Code set) the raw error
// JSON is returned wrapped in an error. The content kept is the LAST
// "output_text" item found in the output, preserving the original behavior.
func (p *Provider) InvokeCompletion(ctx context.Context, req *modelprovider2.ChatRequest) (*modelprovider2.ChatResponse, error) {
	respParam := responses.ResponseNewParams{
		Model: req.Model,
	}
	var msg []responses.ResponseInputItemUnionParam
	for _, item := range req.Messages {
		msg = append(msg, modelprovider2.PartsToResponseInputItemUnionParam(item.Role, item.Parts))
	}
	respParam.Input = responses.ResponseNewParamsInputUnion{
		OfInputItemList: msg,
	}
	logger.Infof("ai chat msg:%v", msg)

	var opts []option.RequestOption
	if p.conf != nil {
		// Only the first configured key is used on this blocking path; key
		// rotation and blacklisting are handled in StreamCompletion.
		if keys := p.conf.GetApiKeys(); len(keys) > 0 {
			opts = append(opts, option.WithAPIKey(keys[0]))
		}
		if base := p.conf.GetBaseUrl(); base != "" {
			opts = append(opts, option.WithBaseURL(base))
		}
		if p.httpClient != nil {
			opts = append(opts, option.WithHTTPClient(p.httpClient))
		}
	}
	client := openai.NewClient(opts...)

	// BUG FIX: pass the caller's ctx (was context.TODO()) so cancellation
	// and deadlines propagate to the API call.
	resp, err := client.Responses.New(ctx, respParam)
	if err != nil {
		logger.Errorf("error while calling OpenAI response api failed err: %v", err)
		return nil, err
	}
	if resp.Error.Code != "" {
		logger.Errorf("error while calling OpenAI response api failed err: %v", resp.Error.RawJSON())
		return nil, fmt.Errorf("call openai response failed err: %v", resp.Error.RawJSON())
	}

	// Keep the last output_text seen across all message outputs.
	content := ""
	for _, item := range resp.Output {
		if item.Type == "message" {
			for _, cn := range item.Content {
				if cn.Type == "output_text" {
					content = cn.Text
				}
			}
		}
	}
	return &modelprovider2.ChatResponse{
		ID:      resp.ID,
		Model:   resp.Model,
		Content: content,
		Meta:    modelprovider2.Meta{Vendor: "openai"},
	}, nil
}
// StreamCompletion streams a chat completion through the OpenAI Responses
// API, invoking h for every stream event via WrapStreamCallback.
//
// Flow:
//  1. Convert req.Messages to OpenAIChatMessage parts; system/user messages
//     are marked as "input" parts, assistant messages as "output" parts.
//  2. Build the request with Store=false (do not persist the conversation
//     server-side); Temperature is set only for GPT-4 family models.
//  3. Drop blacklisted API keys, shuffle the remainder, and try each key
//     until a call succeeds.
//
// Per-attempt error handling:
//   - "Incorrect API key" errors blacklist the key via blackKeyMgr.
//   - NetworkError aborts the key loop (falls through to the generic error
//     mapping below).
//   - io.EOF is returned as-is from inside the loop — presumably treated as
//     normal end-of-stream; TODO confirm callers expect io.EOF, not nil.
//   - Any other error moves on to the next key.
func (p *Provider) StreamCompletion(ctx context.Context, req *modelprovider2.ChatRequest, h modelprovider2.StreamChatCallback) (err error) {
	var (
		temp          = float32(0.7)
		store         = false
		inputMessages []OpenAIChatMessage
	)
	for _, msg := range req.Messages {
		var (
			item    OpenAIChatMessage
			isInput bool
		)
		// Map provider-neutral roles to OpenAI role strings. Roles outside
		// the three cases leave item.Role empty.
		switch msg.Role {
		case modelprovider2.RoleSystem:
			item.Role = "system"
			isInput = true
		case modelprovider2.RoleAssistant:
			item.Role = "assistant"
		case modelprovider2.RoleUser:
			item.Role = "user"
			isInput = true
		}
		for _, part := range msg.Parts {
			var data interface{}
			switch part.Type {
			case modelprovider2.PartText:
				data = NewTextPart(isInput, part.Text)
			case modelprovider2.PartImage:
				data = NewImagePart(isInput, part.ImageURL)
			}
			item.Content = append(item.Content, data)
		}
		inputMessages = append(inputMessages, item)
	}
	var (
		callreq = &OpenAIResponsesRequest{
			Model:  req.Model,
			Input:  inputMessages, // chat content
			Stream: req.IsStream,  // streaming flag — essential for this path
			Store:  &store,        // do not persist this conversation
		}
		apikeys []string
	)
	// Temperature is only attached for GPT-4 family models.
	if IsGPT4Model(req.Model) {
		callreq.Temperature = &temp
	}
	// Skip keys already known to be invalid.
	for _, item := range p.conf.GetApiKeys() {
		if ok := blackKeyMgr.IsBlack(item); ok {
			continue
		}
		apikeys = append(apikeys, item)
	}
	// Randomize key order to spread load across the key pool.
	rand.Shuffle(len(apikeys), func(i, j int) {
		apikeys[i], apikeys[j] = apikeys[j], apikeys[i]
	})
	logger.Debugf("call openai apikeys:%v", apikeys)
	for _, ak := range apikeys {
		c := NewOpenaiClient(ak, p.conf.GetBaseUrl(), p.httpClient)
		_, err = c.callResponses(ctx, callreq, p.WrapStreamCallback(h))
		if err != nil {
			logger.Errorf("do callResponses api failed err:%v", err)
			if isApikeyInvalid(err) {
				blackKeyMgr.AddBlackKey(ak)
			}
			if errors.Is(err, NetworkError) {
				// Network-level failure: retrying other keys will not help.
				break
			}
			if !errors.Is(err, io.EOF) {
				// Non-EOF error: try the next key.
				continue
			}
		}
		// Success (or io.EOF end-of-stream): stop here.
		return
	}
	if err != nil {
		logger.Errorf("call cloud model failed err:%v", err)
		err = fmt.Errorf("cloud model server internal error")
	}
	return
}
// WrapStreamCallback adapts the provider-neutral StreamChatCallback h into
// the event-typed callback consumed by the low-level responses stream
// client: error/failed events become Go errors, text deltas become
// StreamDelta events, and completion becomes a StreamEnd event carrying the
// output token count.
func (p *Provider) WrapStreamCallback(h modelprovider2.StreamChatCallback) func(*ResponsesStreamEvent) error {
	return func(event *ResponsesStreamEvent) error {
		switch EventType(event.Type) {
		case StreamRespondError, StreamRespondFailed:
			if event.Error == nil {
				return fmt.Errorf("unknown OpenAI streaming error: %v", event)
			}
			return fmt.Errorf("OpenAI streaming error: %s (%s)", event.Error.Message, event.Error.Code)
		case StreamRespondOutputTextDelta:
			if event.Delta == "" {
				return nil
			}
			delta := modelprovider2.StreamEvent{
				Kind: modelprovider2.StreamDelta,
				Text: event.Delta,
			}
			if err := h(delta); err != nil {
				return fmt.Errorf("callback execution failed: %w", err)
			}
		case StreamRespondComplete:
			end := modelprovider2.StreamEvent{
				Kind:         modelprovider2.StreamEnd,
				OutputTokens: event.Response.Usage.OutputTokens,
			}
			if err := h(end); err != nil {
				return fmt.Errorf("callback execution failed: %w", err)
			}
		}
		return nil
	}
}
// ListModels fetches the model catalog from the OpenAI models endpoint,
// trying each configured API key in order until one succeeds. When every
// key fails, the last error is returned with a nil result.
func (p *Provider) ListModels(ctx context.Context) (result []modelprovider2.ModelInfo, err error) {
	var models *ModelsResponse
	for _, key := range p.conf.GetApiKeys() {
		client := NewOpenaiClient(key, p.conf.GetBaseUrl(), p.httpClient)
		if models, err = client.getModels(ctx); err == nil {
			break
		}
		logger.Errorf("call responses api failed err:%v", err)
	}
	if models == nil {
		return
	}
	for _, m := range models.Data {
		//if !FilterModel(m) {
		//	continue
		//}
		info := modelprovider2.ModelInfo{
			RealID:      m.ID,
			Raw:         m,
			Vendor:      m.OwnedBy,
			DisplayName: m.ID,
		}
		result = append(result, info)
	}
	return
}
// GetDefaultModel returns the model used when a caller does not specify one.
func (p *Provider) GetDefaultModel() shared.ResponsesModel {
	var fallback shared.ResponsesModel = responses.ChatModelGPT5Mini
	return fallback
}
// IsGPT4Model reports whether the model identifier belongs to the GPT-4
// family, i.e. contains the substring "gpt-4".
func IsGPT4Model(model string) bool {
	const gpt4Marker = "gpt-4"
	return strings.Contains(model, gpt4Marker)
}
func isApikeyInvalid(err error) bool {
logger.Debugf("err:%v,sub:%s,contains:%v", err.Error(), "Incorrect API key provided", strings.Contains(err.Error(), "Incorrect API key provided"))
return strings.Contains(err.Error(), "Incorrect API key provided")
}