ai-css/library/modelprovider/bootstrap/build.go

118 lines
2.9 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package bootstrap
import (
"ai-css/library/modelprovider"
"ai-css/library/modelprovider/config"
"ai-css/library/modelprovider/consts"
"ai-css/library/modelprovider/providers"
"context"
"fmt"
"log"
)
// AIManager bundles the configuration manager and the provider registry that
// are used together to resolve and construct model providers and clients.
type AIManager struct {
	// CfgMgr supplies per-provider configuration; it is loaded in Init and
	// queried by provider name when no ProviderConfig option is given.
	CfgMgr *config.Manager
	// Registry maps provider names to provider constructors (its Providers
	// field is indexed by consts.ProviderName in createProvider).
	Registry *providers.Registry
}
// DefaultAIManager is the package-wide manager, built during package
// initialization from a zero-value config.Manager.
var DefaultAIManager *AIManager

// init builds DefaultAIManager. Any failure while loading configuration is
// fatal: log.Fatalf terminates the process.
// NOTE(review): exiting from a library package's init means merely importing
// this package can kill the program; consider lazy initialization instead.
func init() {
	mgr, err := Init(context.TODO(), &config.Manager{})
	if err != nil {
		log.Fatalf("init ai manager failed err:%v", err)
	}
	DefaultAIManager = mgr
}
// Init loads provider configurations through cfgMgr and returns an AIManager
// that pairs cfgMgr with a freshly built provider registry.
//
// It returns an error (and a nil *AIManager) if cfgMgr is nil or if loading
// the configurations fails. The load error is wrapped with %w, so callers can
// still match it with errors.Is / errors.As.
func Init(ctx context.Context, cfgMgr *config.Manager) (*AIManager, error) {
	if cfgMgr == nil {
		// Fail at construction time rather than with a nil-pointer panic on
		// LoadConfigs or on a later config lookup.
		return nil, fmt.Errorf("init ai manager: nil config manager")
	}
	if err := cfgMgr.LoadConfigs(ctx); err != nil {
		return nil, fmt.Errorf("load provider configs: %w", err)
	}
	return &AIManager{
		CfgMgr:   cfgMgr,
		Registry: providers.BuildRegistry(),
	}, nil
}
// NewClient resolves a provider (by providerName, or one supplied directly via
// opts) and wraps it in a modelprovider.Client.
//
// The model is the DefaultModel option when it is non-empty; otherwise the
// provider's own GetDefaultModel() is used.
// NOTE(review): a previous comment described the priority as
// opts > config > provider, but config is not read here directly —
// presumably the provider's default already reflects its config; confirm.
func (a *AIManager) NewClient(providerName consts.ProviderName, opts ...ClientOption) (*modelprovider.Client, error) {
	p, resolved, err := a.resolveProvider(providerName, opts...)
	if err != nil {
		return nil, err
	}
	if model := resolved.DefaultModel; model != "" {
		// An explicitly requested model wins.
		return modelprovider.NewClient(p, model), nil
	}
	// Fall back to whatever the provider itself considers its default.
	return modelprovider.NewClient(p, p.GetDefaultModel()), nil
}
// NewProvider resolves a provider with the same rules NewClient uses, but
// returns the bare provider and discards the resolved options.
func (a *AIManager) NewProvider(providerName consts.ProviderName, opts ...ClientOption) (modelprovider.Provider, error) {
	p, _, err := a.resolveProvider(providerName, opts...)
	if err != nil {
		return nil, err
	}
	return p, nil
}
// resolveProvider applies opts on top of default Options (whose ProviderName is
// providerName). It returns the provider together with the final options.
//
// Resolution order:
//  1. If an option supplied a Provider instance, it is returned unchanged.
//  2. Otherwise ProviderName must be non-empty.
//  3. The provider config comes from the ProviderConfig option if set,
//     otherwise from the config manager.
//  4. The provider is constructed through the registry.
//
// On error, both the provider and the options are nil.
func (a *AIManager) resolveProvider(providerName consts.ProviderName, opts ...ClientOption) (modelprovider.Provider, *Options, error) {
	o := &Options{
		ProviderName: providerName,
	}
	for _, opt := range opts {
		// Skip nil options (e.g. from conditional option builders) instead of
		// panicking on a nil function call.
		if opt == nil {
			continue
		}
		opt(o)
	}

	// Step 1: a caller-supplied provider short-circuits config lookup and creation.
	if o.Provider != nil {
		return o.Provider, o, nil
	}

	// Step 2: without an explicit provider, a name is required to look one up.
	if o.ProviderName == "" {
		// %q makes the (empty) offending value visible in the message.
		return nil, nil, fmt.Errorf("invalid provider name: %q", o.ProviderName)
	}

	// Step 3: resolve the config (ProviderConfig option > config manager).
	conf, err := a.resolveProviderConfig(o)
	if err != nil {
		return nil, nil, fmt.Errorf("resolve provider config failed: %w", err)
	}

	// Step 4: build the provider through the registry lookup.
	provider, err := a.createProvider(o.ProviderName, conf)
	if err != nil {
		return nil, nil, fmt.Errorf("create provider failed: %w", err)
	}
	return provider, o, nil
}
// resolveProviderConfig returns the provider config to use for o.
//
// It uses the ProviderConfig option when set; otherwise it looks up
// o.ProviderName in the config manager. In the lookup case, the returned
// pointer refers to a local (shallow) copy of the manager's value.
//
// It returns an error if no config manager is available or no config exists
// for the provider.
func (a *AIManager) resolveProviderConfig(o *Options) (*config.ProviderConfig, error) {
	if o.ProviderConfig != nil {
		return o.ProviderConfig, nil
	}
	// AIManager's fields are exported, so a manager built without Init may
	// have no CfgMgr; report that instead of panicking on the lookup.
	if a.CfgMgr == nil {
		return nil, fmt.Errorf("config manager not initialized, cannot look up provider: %s", o.ProviderName)
	}
	cfg, ok := a.CfgMgr.GetConfigByProviderName(o.ProviderName)
	if !ok {
		return nil, fmt.Errorf("config not found for provider: %s", o.ProviderName)
	}
	return &cfg, nil
}
// createProvider looks up the constructor registered for providerName and
// invokes it with conf.
//
// It returns an error if:
//   - the registry is missing;
//   - no constructor is registered for the name;
//   - the constructor fails;
//   - the constructor returns a nil provider without an error.
func (a *AIManager) createProvider(providerName consts.ProviderName, conf *config.ProviderConfig) (modelprovider.Provider, error) {
	// AIManager's fields are exported, so a manager built without Init may
	// have no Registry; report that instead of panicking on field access.
	if a.Registry == nil {
		return nil, fmt.Errorf("provider registry not initialized, cannot create provider: %s", providerName)
	}
	creator := a.Registry.Providers[providerName]
	if creator == nil {
		return nil, fmt.Errorf("provider not supported: %s", providerName)
	}
	provider, err := creator(conf)
	if err != nil {
		return nil, fmt.Errorf("create provider instance failed: %w", err)
	}
	// A nil provider would otherwise surface later as a nil-pointer panic
	// (NewClient calls provider.GetDefaultModel()). Note: a typed nil pointer
	// wrapped in the interface is not caught by this check.
	if provider == nil {
		return nil, fmt.Errorf("provider creator returned nil provider: %s", providerName)
	}
	return provider, nil
}