ai-css/library/modelprovider/providers/openai/api.go

378 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package openai
import (
"github.com/openai/openai-go/v3/responses"
"ai-css/library/modelprovider/errorswrap"
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"ai-css/library/logger"
)
// EventType identifies the kind of server-sent event received from the
// OpenAI Responses streaming API (the "type" field of each SSE payload).
type EventType string

// Stream event types emitted by POST /v1/responses when stream=true.
const (
	StreamRespondError           EventType = "response.error"
	StreamRespondFailed          EventType = "response.failed"
	StreamRespondOutputTextDelta EventType = "response.output_text.delta"
	StreamRespondComplete        EventType = "response.completed"
)

// NetworkError is the sentinel returned for transport failures (endpoint
// unreachable, broken stream, context cancelled mid-stream); compare with
// errors.Is. NOTE(review): Go convention would name this ErrNetwork, but
// renaming would break existing callers.
var NetworkError = errors.New("network unreachable")
// OpenAIResponsesRequest models the POST /v1/responses request body.
// Optional fields use pointers or omitempty so unset values are omitted
// from the serialized JSON; free-form sub-objects use json.RawMessage so
// this struct does not have to track every upstream schema change.
type OpenAIResponsesRequest struct {
	Background         *bool             `json:"background,omitempty"`
	Conversation       json.RawMessage   `json:"conversation,omitempty"` // string or {id: "..."} etc.; RawMessage keeps it flexible
	Include            []string          `json:"include,omitempty"`
	Input              interface{}       `json:"input,omitempty"` // chat flows put []OpenAIChatMessage here; other callers may supply their own shape
	Instructions       string            `json:"instructions,omitempty"`
	MaxOutputTokens    *int              `json:"max_output_tokens,omitempty"`
	MaxToolCalls       *int              `json:"max_tool_calls,omitempty"`
	Metadata           map[string]string `json:"metadata,omitempty"`
	Model              string            `json:"model,omitempty"`
	ParallelToolCalls  *bool             `json:"parallel_tool_calls,omitempty"`
	PreviousResponseID string            `json:"previous_response_id,omitempty"`
	Prompt             json.RawMessage   `json:"prompt,omitempty"` // prompt template reference; shape is not fixed, so RawMessage
	PromptCacheKey     string            `json:"prompt_cache_key,omitempty"`
	Reasoning          json.RawMessage   `json:"reasoning,omitempty"` // e.g. {"effort": "..."}
	Summary            string            `json:"summary,omitempty"`
	SafetyIdentifier   string            `json:"safety_identifier,omitempty"`
	ServiceTier        string            `json:"service_tier,omitempty"`
	Store              *bool             `json:"store,omitempty"`
	Stream             bool              `json:"stream,omitempty"`
	StreamOptions      json.RawMessage   `json:"stream_options,omitempty"` // e.g. {"include_usage": true}
	Temperature        *float32          `json:"temperature,omitempty"`
	Text               json.RawMessage   `json:"text,omitempty"` // structured-output configuration, etc.
	ToolChoice         json.RawMessage   `json:"tool_choice,omitempty"`
	Tools              json.RawMessage   `json:"tools,omitempty"` // tool / function / MCP definitions
	TopLogprobs        *int              `json:"top_logprobs,omitempty"`
	TopP               *float32          `json:"top_p,omitempty"`
	Truncation         string            `json:"truncation,omitempty"`
}
// OpenAIResponsesResponse models the POST /v1/responses (non-streaming)
// response body.
type OpenAIResponsesResponse struct {
	ID                 string             `json:"id"`
	Object             string             `json:"object"`
	CreatedAt          int64              `json:"created_at"`
	Status             string             `json:"status"`
	Error              OpenAIErrorMessage `json:"error,omitempty"`              // may be null or an object; NOTE(review): omitempty has no effect on a struct value when marshaling
	IncompleteDetails  any                `json:"incomplete_details,omitempty"` // may be null or an object
	Instructions       *string            `json:"instructions,omitempty"`
	MaxOutputTokens    *int               `json:"max_output_tokens,omitempty"`
	Model              string             `json:"model"`
	Output             []OutputItem       `json:"output"`
	ParallelToolCalls  bool               `json:"parallel_tool_calls"`
	PreviousResponseID *string            `json:"previous_response_id,omitempty"`
	Reasoning          Reasoning          `json:"reasoning"`
	Store              bool               `json:"store"`
	Temperature        float64            `json:"temperature"`
	Text               TextSpec           `json:"text"`
	ToolChoice         string             `json:"tool_choice"` // "auto" | others
	Tools              []json.RawMessage  `json:"tools"`       // reserved for future extension (function/tool schemas, etc.)
	TopP               float64            `json:"top_p"`
	Truncation         string             `json:"truncation"`
	Usage              Usage              `json:"usage"`
	User               *string            `json:"user,omitempty"`
	Metadata           map[string]any     `json:"metadata"`
}
type OpenAIErrorMessage struct {
Msg string `json:"message"`
Type string `json:"type"`
Param string `json:"model"`
Code string `json:"model_not_found"`
}
// ResponsesStreamEvent is the generic shape of one SSE payload from the
// streaming Responses API; only the fields relevant to Type are populated.
type ResponsesStreamEvent struct {
	Type         string `json:"type"`                   // e.g. "response.output_text.delta"
	Delta        string `json:"delta,omitempty"`        // incremental text (only present on output_text.delta events)
	ItemID       string `json:"item_id,omitempty"`      // remaining fields are decoded on demand
	OutputIndex  int    `json:"output_index,omitempty"` // currently unused
	ContentIndex int    `json:"content_index,omitempty"`
	// Error events: type = "response.error" / "response.failed".
	Error *struct {
		Code    string `json:"code"`
		Message string `json:"message"`
	} `json:"error,omitempty"`
	Response responses.Response `json:"response"`
}
// OutputItem is one element of OpenAIResponsesResponse.Output.
type OutputItem struct {
	Type    string         `json:"type"`   // "message", etc.
	ID      string         `json:"id"`
	Status  string         `json:"status"` // "completed", etc.
	Role    string         `json:"role"`   // "assistant", etc.
	Content []ContentBlock `json:"content"`
}

// ContentBlock is one content fragment inside an OutputItem.
type ContentBlock struct {
	Type        string `json:"type"`                  // "output_text", etc.
	Text        string `json:"text,omitempty"`        // present when type == "output_text"
	Annotations []any  `json:"annotations,omitempty"` // empty or an array
	// More fields (e.g. tool_calls) may appear in the future; keeping a
	// RawMessage fallback would be safer for forward compatibility:
	// Raw json.RawMessage `json:"-"`
}
// Reasoning mirrors the "reasoning" object of a Responses API response.
type Reasoning struct {
	Effort  *string `json:"effort,omitempty"`
	Summary *string `json:"summary,omitempty"`
}

// TextSpec mirrors the "text" object (output format configuration).
type TextSpec struct {
	Format TextFormat `json:"format"`
}

// TextFormat identifies the requested output format.
type TextFormat struct {
	Type string `json:"type"` // "text"
}

// Usage reports token consumption for one response.
type Usage struct {
	InputTokens         int                `json:"input_tokens"`
	InputTokensDetails  InputTokensDetails `json:"input_tokens_details"`
	OutputTokens        int                `json:"output_tokens"`
	OutputTokensDetails OutputTokensDetail `json:"output_tokens_details"`
	TotalTokens         int                `json:"total_tokens"`
}

// InputTokensDetails breaks down input-token usage.
type InputTokensDetails struct {
	CachedTokens int `json:"cached_tokens"`
}

// OutputTokensDetail breaks down output-token usage.
type OutputTokensDetail struct {
	ReasoningTokens int `json:"reasoning_tokens"`
}
// OpenAIChatMessage is one chat turn placed into the request Input.
type OpenAIChatMessage struct {
	Role    string        `json:"role"`    // "system" / "user" / "assistant"
	Content []interface{} `json:"content"` // multimodal turns carry several parts; text-only uses a single text part
}

// OpenAIContentPart is a single content fragment (text only, for now).
type OpenAIContentPart struct {
	Type string `json:"type"`           // "text"
	Text string `json:"text,omitempty"` // text payload
}

// TextInput is a text content part.
type TextInput struct {
	Type string `json:"type"` // "input_text" (see NewTextPart, which can also produce "output_text")
	Text string `json:"text"`
}

// ImageInput is an image content part.
type ImageInput struct {
	Type     string `json:"type"`                // fixed: "input_image"
	ImageURL string `json:"image_url,omitempty"` // URL or base64
	Detail   string `json:"detail,omitempty"`    // high / low / auto
	FileID   string `json:"file_id,omitempty"`   // when the image comes from the Files API
}

// FileInput is a file content part.
type FileInput struct {
	Type     string `json:"type"`                // fixed: "input_file"
	FileID   string `json:"file_id,omitempty"`   // ID returned by the Files API upload
	FileData string `json:"file_data,omitempty"` // base64 file content
	FileURL  string `json:"file_url,omitempty"`  // file URL
	Filename string `json:"filename,omitempty"`  // file name (optional)
}
// Model is a single model object from the models listing.
type Model struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}

// ModelsResponse is the response body of GET /v1/models.
type ModelsResponse struct {
	Object string    `json:"object"` // fixed: "list"
	Data   []Model   `json:"data"`
	Error  RespError `json:"error"` // some OpenAI-compatible providers return errors inline with a 200
}

// RespError is an inline error payload within a models response.
type RespError struct {
	Msg  string `json:"message"`
	Type string `json:"type"`
	Code string `json:"code"`
}
type OpenAIClient struct {
apiKey string
baseURL string
httpClient *http.Client
}
func NewOpenaiClient(apikey, apiUrl string, httpC *http.Client) OpenAIClient {
return OpenAIClient{apikey, apiUrl, httpC}
}
// callResponses calls the OpenAI Responses endpoint (POST /v1/responses)
// and consumes the SSE stream, invoking callback for every decoded event.
//
// The named resp result is currently never populated: the terminal
// "response.completed" event reaches the caller through callback like
// every other event. On normal stream termination the function returns
// (nil, io.EOF) — preserved from the original behavior, since callers
// may test for io.EOF.
//
// Errors: NetworkError for transport/stream failures (including context
// cancellation), a wrapped error for serialization, request construction,
// non-200 status, or a failing callback.
func (o *OpenAIClient) callResponses(
	ctx context.Context, req *OpenAIResponsesRequest, callback func(evt *ResponsesStreamEvent) error,
) (resp *OpenAIResponsesResponse, err error) {
	// 1. Serialize the request body.
	reqBody, err := json.Marshal(req)
	if err != nil {
		err = fmt.Errorf("failed to serialize request: %w", err)
		return
	}
	// 2. Send POST to /v1/responses.
	httpReq, err := http.NewRequestWithContext(
		ctx,
		http.MethodPost,
		o.baseURL+"/v1/responses",
		bytes.NewBuffer(reqBody),
	)
	if err != nil {
		logger.Errorf("new request failed err:%v", err)
		err = fmt.Errorf("failed to create HTTP request: %w", err)
		return
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+o.apiKey)
	httpReq.Header.Set("Accept", "text/event-stream")
	logger.Debugf("openai callResponses req:%s", string(reqBody))
	respond, err := o.httpClient.Do(httpReq)
	if err != nil {
		logger.Errorf("call responses api failed err:%v", err)
		err = NetworkError
		return
	}
	defer respond.Body.Close()
	if respond.StatusCode != http.StatusOK {
		// Surface the raw body in the error text; upstream error payload
		// shapes vary too much to decode reliably. (The original also
		// unmarshaled into an unused local with the error ignored — dead
		// code, removed.)
		body, _ := io.ReadAll(respond.Body)
		err = fmt.Errorf("OpenAI API returned error [%d]: %s", respond.StatusCode, string(body))
		return
	}
	// 3. Parse the SSE stream line by line.
	reader := bufio.NewReader(respond.Body)
	for {
		// Bail out promptly if the caller cancelled. ctx.Err() is never
		// io.EOF, so cancellation always maps to NetworkError.
		select {
		case <-ctx.Done():
			logger.Errorf("lisent stream failed err:%v", ctx.Err())
			err = NetworkError
			return
		default:
		}
		var line []byte
		line, err = reader.ReadBytes('\n')
		if err != nil {
			if errors.Is(err, io.EOF) {
				// Normal end of stream; the named return deliberately
				// carries io.EOF back to the caller.
				return
			}
			logger.Errorf("read body failed err:%v", err)
			err = NetworkError
			return
		}
		line = bytes.TrimSpace(line)
		if len(line) == 0 {
			continue // keep-alive / event separator
		}
		if !bytes.HasPrefix(line, []byte("data: ")) {
			continue // ignore non-data SSE fields (event:, id:, ...)
		}
		data := bytes.TrimPrefix(line, []byte("data: "))
		event := new(ResponsesStreamEvent)
		if err = json.Unmarshal(data, event); err != nil {
			continue // skip undecodable payloads (e.g. "[DONE]" sentinels)
		}
		if err = callback(event); err != nil {
			err = fmt.Errorf("callback execution failed: %w", err)
			return
		}
	}
}
// getModels calls GET /v1/models and returns the parsed model list.
//
// Error mapping: request construction, transport, and body-read failures
// return ErrorProviderApiUrlInvalid; a non-200 status or an inline error
// payload in a 200 response returns ErrorProviderApiKeyInvalid.
func (o *OpenAIClient) getModels(ctx context.Context) (*ModelsResponse, error) {
	req, err := http.NewRequestWithContext(
		ctx,
		http.MethodGet,
		o.baseURL+"/v1/models",
		nil,
	)
	if err != nil {
		logger.Errorf("new request failed err:%v", err)
		return nil, errorswrap.NewError(errorswrap.ErrorProviderApiUrlInvalid)
	}
	req.Header.Set("Authorization", "Bearer "+o.apiKey)
	resp, err := o.httpClient.Do(req)
	if err != nil {
		// Error level (was Info), and never log the whole client: %v on
		// the struct would leak the API key into the logs.
		logger.Errorf("call openai api failed err:%v, baseURL:%s", err, o.baseURL)
		return nil, errorswrap.NewError(errorswrap.ErrorProviderApiUrlInvalid)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		logger.Errorf("status code not ok code:%d", resp.StatusCode)
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			logger.Errorf("read response body failed: %v", err)
			return nil, errorswrap.NewError(errorswrap.ErrorProviderApiUrlInvalid)
		}
		logger.Errorf("status code not ok body:%s", string(body))
		return nil, errorswrap.NewError(errorswrap.ErrorProviderApiKeyInvalid)
	}
	var result ModelsResponse
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("decode models response: %w", err)
	}
	// Some OpenAI-compatible providers return 200 with an error payload.
	if result.Error.Msg != "" {
		return nil, errorswrap.NewError(errorswrap.ErrorProviderApiKeyInvalid)
	}
	return &result, nil
}
// NewTextPart constructs a text content part. isInput selects the part
// type: "input_text" for request content, "output_text" otherwise.
func NewTextPart(isInput bool, text string) TextInput {
	kind := "output_text"
	if isInput {
		kind = "input_text"
	}
	return TextInput{Type: kind, Text: text}
}
// NewImagePart constructs an image content part referencing imageURL
// (a URL or base64 data URL). isInput selects the part type:
// "input_image" for request content, "output_image" otherwise.
func NewImagePart(isInput bool, imageURL string) ImageInput {
	// Parameter renamed from exported-style "ImageURL" to idiomatic
	// lower-camel "imageURL" (Go locals/params use mixedCaps).
	prefix := "output"
	if isInput {
		prefix = "input"
	}
	return ImageInput{
		Type:     fmt.Sprintf("%s_image", prefix),
		ImageURL: imageURL,
	}
}