378 lines
12 KiB
Go
Executable File
378 lines
12 KiB
Go
Executable File
package openai
|
||
|
||
import (
|
||
"github.com/openai/openai-go/v3/responses"
|
||
|
||
"ai-css/library/modelprovider/errorswrap"
|
||
"bufio"
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
|
||
"ai-css/library/logger"
|
||
)
|
||
|
||
// EventType identifies the "type" field of a server-sent event emitted by
// the OpenAI Responses streaming API.
type EventType string

// Stream event types this package reacts to.
const (
	StreamRespondError           EventType = "response.error"            // server-side error event
	StreamRespondFailed          EventType = "response.failed"           // response ended in failure
	StreamRespondOutputTextDelta EventType = "response.output_text.delta" // incremental output text
	StreamRespondComplete        EventType = "response.completed"        // response finished successfully
)
|
||
|
||
// NetworkError is the sentinel error returned when the OpenAI endpoint cannot
// be reached or the stream read fails. Compare with errors.Is.
// NOTE(review): Go convention would name this ErrNetwork; kept as-is to avoid
// breaking existing callers.
var NetworkError = errors.New("network unreachable")
|
||
|
||
// OpenAIResponsesRequest models POST /v1/responses request body.
// Fields whose upstream schema is open-ended are kept as json.RawMessage so
// callers can pass arbitrary shapes without this package chasing the API.
type OpenAIResponsesRequest struct {
	Background         *bool             `json:"background,omitempty"`
	Conversation       json.RawMessage   `json:"conversation,omitempty"` // string or {id: "..."} etc.; RawMessage keeps it flexible
	Include            []string          `json:"include,omitempty"`
	Input              interface{}       `json:"input,omitempty"` // for chat we put []OpenAIChatMessage here; other scenarios may customize
	Instructions       string            `json:"instructions,omitempty"`
	MaxOutputTokens    *int              `json:"max_output_tokens,omitempty"`
	MaxToolCalls       *int              `json:"max_tool_calls,omitempty"`
	Metadata           map[string]string `json:"metadata,omitempty"`
	Model              string            `json:"model,omitempty"`
	ParallelToolCalls  *bool             `json:"parallel_tool_calls,omitempty"`
	PreviousResponseID string            `json:"previous_response_id,omitempty"`
	Prompt             json.RawMessage   `json:"prompt,omitempty"` // prompt template reference; shape is not fixed, so RawMessage
	PromptCacheKey     string            `json:"prompt_cache_key,omitempty"`
	Reasoning          json.RawMessage   `json:"reasoning,omitempty"` // e.g. {effort: "..."}
	Summary            string            `json:"summary,omitempty"`
	SafetyIdentifier   string            `json:"safety_identifier,omitempty"`
	ServiceTier        string            `json:"service_tier,omitempty"`
	Store              *bool             `json:"store,omitempty"`
	Stream             bool              `json:"stream,omitempty"`
	StreamOptions      json.RawMessage   `json:"stream_options,omitempty"` // e.g. {"include_usage": true}
	Temperature        *float32          `json:"temperature,omitempty"`
	Text               json.RawMessage   `json:"text,omitempty"` // structured-output configuration etc.
	ToolChoice         json.RawMessage   `json:"tool_choice,omitempty"`
	Tools              json.RawMessage   `json:"tools,omitempty"` // tool / function / MCP definitions
	TopLogprobs        *int              `json:"top_logprobs,omitempty"`
	TopP               *float32          `json:"top_p,omitempty"`
	Truncation         string            `json:"truncation,omitempty"`
}
|
||
|
||
// OpenAIResponsesResponse models the POST /v1/responses response body.
type OpenAIResponsesResponse struct {
	ID                 string             `json:"id"`
	Object             string             `json:"object"`
	CreatedAt          int64              `json:"created_at"`
	Status             string             `json:"status"`
	Error              OpenAIErrorMessage `json:"error,omitempty"`              // may be null or an object
	IncompleteDetails  any                `json:"incomplete_details,omitempty"` // may be null or an object
	Instructions       *string            `json:"instructions,omitempty"`
	MaxOutputTokens    *int               `json:"max_output_tokens,omitempty"`
	Model              string             `json:"model"`
	Output             []OutputItem       `json:"output"`
	ParallelToolCalls  bool               `json:"parallel_tool_calls"`
	PreviousResponseID *string            `json:"previous_response_id,omitempty"`
	Reasoning          Reasoning          `json:"reasoning"`
	Store              bool               `json:"store"`
	Temperature        float64            `json:"temperature"`
	Text               TextSpec           `json:"text"`
	ToolChoice         string             `json:"tool_choice"` // "auto" | other values
	Tools              []json.RawMessage  `json:"tools"`       // reserved for future use (function/tool schemas etc.)
	TopP               float64            `json:"top_p"`
	Truncation         string             `json:"truncation"`
	Usage              Usage              `json:"usage"`
	User               *string            `json:"user,omitempty"`
	Metadata           map[string]any     `json:"metadata"`
}
|
||
|
||
type OpenAIErrorMessage struct {
|
||
Msg string `json:"message"`
|
||
Type string `json:"type"`
|
||
Param string `json:"model"`
|
||
Code string `json:"model_not_found"`
|
||
}
|
||
|
||
// ResponsesStreamEvent is the generic shape of a single SSE stream event.
type ResponsesStreamEvent struct {
	Type         string `json:"type"`            // e.g. "response.output_text.delta"
	Delta        string `json:"delta,omitempty"` // incremental text (only present on output_text.delta events)
	ItemID       string `json:"item_id,omitempty"`      // other fields can be used as needed
	OutputIndex  int    `json:"output_index,omitempty"` // not used for now
	ContentIndex int    `json:"content_index,omitempty"`
	// Error events: type = "response.error" / "response.failed"
	Error *struct {
		Code    string `json:"code"`
		Message string `json:"message"`
	} `json:"error,omitempty"`
	// Response carries the full response object on terminal events
	// (decoded via the official openai-go SDK type).
	Response responses.Response `json:"response"`
}
|
||
|
||
// OutputItem is one entry of a response's "output" array.
type OutputItem struct {
	Type    string         `json:"type"`   // "message" etc.
	ID      string         `json:"id"`
	Status  string         `json:"status"` // "completed" etc.
	Role    string         `json:"role"`   // "assistant" etc.
	Content []ContentBlock `json:"content"`
}
|
||
|
||
// ContentBlock is one content fragment inside an OutputItem.
type ContentBlock struct {
	Type        string `json:"type"`                  // "output_text" etc.
	Text        string `json:"text,omitempty"`        // present when type == "output_text"
	Annotations []any  `json:"annotations,omitempty"` // empty or an array
	// Future payloads may carry extra fields (e.g. tool_calls); a RawMessage
	// catch-all would be safer for forward compatibility:
	// Raw json.RawMessage `json:"-"`
}
|
||
|
||
// Reasoning mirrors the response's "reasoning" object.
type Reasoning struct {
	Effort  *string `json:"effort,omitempty"`
	Summary *string `json:"summary,omitempty"`
}
|
||
|
||
// TextSpec mirrors the response's "text" configuration object.
type TextSpec struct {
	Format TextFormat `json:"format"`
}
|
||
|
||
// TextFormat is the output text format descriptor.
type TextFormat struct {
	Type string `json:"type"` // "text"
}
|
||
|
||
// Usage reports token consumption for a response.
type Usage struct {
	InputTokens         int                `json:"input_tokens"`
	InputTokensDetails  InputTokensDetails `json:"input_tokens_details"`
	OutputTokens        int                `json:"output_tokens"`
	OutputTokensDetails OutputTokensDetail `json:"output_tokens_details"`
	TotalTokens         int                `json:"total_tokens"`
}
|
||
|
||
// InputTokensDetails breaks down input-token usage.
type InputTokensDetails struct {
	CachedTokens int `json:"cached_tokens"` // tokens served from the prompt cache
}
|
||
|
||
// OutputTokensDetail breaks down output-token usage.
type OutputTokensDetail struct {
	ReasoningTokens int `json:"reasoning_tokens"` // tokens spent on reasoning
}
|
||
|
||
// OpenAIChatMessage is one chat message in the request input.
type OpenAIChatMessage struct {
	Role    string        `json:"role"`    // "system" / "user" / "assistant"
	Content []interface{} `json:"content"` // multimodal input uses multiple parts; here only text parts are used
}
|
||
|
||
// OpenAIContentPart is a single content fragment (text only for now).
type OpenAIContentPart struct {
	Type string `json:"type"`           // "text"
	Text string `json:"text,omitempty"` // text content
}
|
||
|
||
// TextInput is a text input/output part.
type TextInput struct {
	Type string `json:"type"` // "input_text" or "output_text" (see NewTextPart)
	Text string `json:"text"`
}
|
||
|
||
// ImageInput is an image input part.
type ImageInput struct {
	Type     string `json:"type"`                // fixed as "input_image"
	ImageURL string `json:"image_url,omitempty"` // URL or Base64 data
	Detail   string `json:"detail,omitempty"`    // high / low / auto
	FileID   string `json:"file_id,omitempty"`   // when the image comes from the Files API
}
|
||
|
||
// FileInput is a file input part.
type FileInput struct {
	Type     string `json:"type"`                // fixed as "input_file"
	FileID   string `json:"file_id,omitempty"`   // ID returned by the Files API upload
	FileData string `json:"file_data,omitempty"` // Base64-encoded file content
	FileURL  string `json:"file_url,omitempty"`  // file URL
	Filename string `json:"filename,omitempty"`  // file name (optional)
}
|
||
|
||
// Model represents a single model object returned by /v1/models.
type Model struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}
|
||
|
||
// ModelsResponse represents the /v1/models response body.
type ModelsResponse struct {
	Object string    `json:"object"` // fixed as "list"
	Data   []Model   `json:"data"`
	Error  RespError `json:"error"` // non-empty when the provider returns an error payload
}
|
||
|
||
// RespError is the error object of a /v1/models response.
type RespError struct {
	Msg  string `json:"message"`
	Type string `json:"type"`
	Code string `json:"code"`
}
|
||
|
||
// OpenAIClient is a thin HTTP client for an OpenAI-compatible endpoint.
type OpenAIClient struct {
	apiKey     string       // bearer token sent in the Authorization header
	baseURL    string       // endpoint base, e.g. "https://api.openai.com"
	httpClient *http.Client // shared client; timeout policy is the caller's responsibility
}
|
||
|
||
func NewOpenaiClient(apikey, apiUrl string, httpC *http.Client) OpenAIClient {
|
||
return OpenAIClient{apikey, apiUrl, httpC}
|
||
}
|
||
|
||
// callResponses 调用openAI Responses 接口
|
||
func (o *OpenAIClient) callResponses(
|
||
ctx context.Context, req *OpenAIResponsesRequest, callback func(evt *ResponsesStreamEvent) error,
|
||
) (resp *OpenAIResponsesResponse, err error) {
|
||
|
||
reqBody, err := json.Marshal(req)
|
||
if err != nil {
|
||
err = fmt.Errorf("failed to serialize request: %w", err)
|
||
return
|
||
}
|
||
// 2. Send POST to /v1/responses
|
||
httpReq, err := http.NewRequestWithContext(
|
||
ctx,
|
||
http.MethodPost,
|
||
o.baseURL+"/v1/responses",
|
||
bytes.NewBuffer(reqBody),
|
||
)
|
||
if err != nil {
|
||
logger.Errorf("new request failed err:%v", err)
|
||
err = fmt.Errorf("failed to create HTTP request: %w", err)
|
||
return
|
||
}
|
||
|
||
httpReq.Header.Set("Content-Type", "application/json")
|
||
httpReq.Header.Set("Authorization", "Bearer "+o.apiKey)
|
||
httpReq.Header.Set("Accept", "text/event-stream")
|
||
|
||
logger.Debugf("openai callResponses req:%s", string(reqBody))
|
||
|
||
respond, err := o.httpClient.Do(httpReq)
|
||
if err != nil {
|
||
logger.Errorf("call responses api failed err:%v", err)
|
||
err = NetworkError
|
||
return
|
||
}
|
||
defer respond.Body.Close()
|
||
|
||
if respond.StatusCode != http.StatusOK {
|
||
body, _ := io.ReadAll(respond.Body)
|
||
var respondData *OpenAIResponsesResponse
|
||
json.Unmarshal(body, &respondData)
|
||
err = fmt.Errorf("OpenAI API returned error [%d]: %s", respond.StatusCode, string(body))
|
||
return
|
||
}
|
||
|
||
// 3. Parse SSE stream
|
||
reader := bufio.NewReader(respond.Body)
|
||
|
||
for {
|
||
select {
|
||
case <-ctx.Done():
|
||
err = ctx.Err()
|
||
logger.Errorf("lisent stream failed err:%v", err)
|
||
if err == io.EOF {
|
||
return
|
||
}
|
||
err = NetworkError
|
||
return
|
||
default:
|
||
}
|
||
|
||
var line []byte
|
||
line, err = reader.ReadBytes('\n')
|
||
if err != nil {
|
||
if err == io.EOF {
|
||
return
|
||
}
|
||
logger.Errorf("read body failed err:%v", err)
|
||
err = NetworkError
|
||
return
|
||
}
|
||
|
||
line = bytes.TrimSpace(line)
|
||
if len(line) == 0 {
|
||
continue
|
||
}
|
||
|
||
if !bytes.HasPrefix(line, []byte("data: ")) {
|
||
continue
|
||
}
|
||
|
||
data := bytes.TrimPrefix(line, []byte("data: "))
|
||
|
||
var event = new(ResponsesStreamEvent)
|
||
if err = json.Unmarshal(data, event); err != nil {
|
||
continue
|
||
}
|
||
if err = callback(event); err != nil {
|
||
err = fmt.Errorf("callback execution failed: %w", err)
|
||
return
|
||
}
|
||
}
|
||
|
||
}
|
||
|
||
func (o *OpenAIClient) getModels(ctx context.Context) (*ModelsResponse, error) {
|
||
req, err := http.NewRequestWithContext(
|
||
ctx,
|
||
http.MethodGet,
|
||
o.baseURL+"/v1/models",
|
||
nil,
|
||
)
|
||
if err != nil {
|
||
logger.Errorf("new request failed err:%v", err)
|
||
return nil, errorswrap.NewError(errorswrap.ErrorProviderApiUrlInvalid)
|
||
}
|
||
req.Header.Set("Authorization", "Bearer "+o.apiKey)
|
||
|
||
resp, err := o.httpClient.Do(req)
|
||
if err != nil {
|
||
logger.Infof("call openai api failed err:%v,openAIclient:%v", err, o)
|
||
return nil, errorswrap.NewError(errorswrap.ErrorProviderApiUrlInvalid)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
logger.Errorf("status code not ok code:%d", resp.StatusCode)
|
||
var body []byte
|
||
body, err = io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
logger.Errorf("read response body failed: %v", err)
|
||
return nil, errorswrap.NewError(errorswrap.ErrorProviderApiUrlInvalid)
|
||
}
|
||
logger.Errorf("status code not ok body:%s", string(body))
|
||
return nil, errorswrap.NewError(errorswrap.ErrorProviderApiKeyInvalid)
|
||
}
|
||
|
||
var result ModelsResponse
|
||
if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||
return nil, err
|
||
}
|
||
if result.Error.Msg != "" {
|
||
return nil, errorswrap.NewError(errorswrap.ErrorProviderApiKeyInvalid)
|
||
}
|
||
return &result, nil
|
||
}
|
||
|
||
func NewTextPart(isInput bool, text string) TextInput {
|
||
var prefix = "output"
|
||
if isInput {
|
||
prefix = "input"
|
||
}
|
||
return TextInput{
|
||
Type: fmt.Sprintf("%s_text", prefix),
|
||
Text: text,
|
||
}
|
||
}
|
||
|
||
func NewImagePart(isInput bool, ImageURL string) ImageInput {
|
||
var prefix = "output"
|
||
if isInput {
|
||
prefix = "input"
|
||
}
|
||
return ImageInput{
|
||
Type: fmt.Sprintf("%s_image", prefix),
|
||
ImageURL: ImageURL,
|
||
}
|
||
}
|