118 lines
2.9 KiB
Go
118 lines
2.9 KiB
Go
package bootstrap
|
||
|
||
import (
|
||
"ai-css/library/modelprovider"
|
||
"ai-css/library/modelprovider/config"
|
||
"ai-css/library/modelprovider/consts"
|
||
"ai-css/library/modelprovider/providers"
|
||
"context"
|
||
"fmt"
|
||
"log"
|
||
)
|
||
|
||
type AIManager struct {
|
||
CfgMgr *config.Manager
|
||
Registry *providers.Registry
|
||
}
|
||
|
||
var DefaultAIManager *AIManager
|
||
|
||
func init() {
|
||
var err error
|
||
DefaultAIManager, err = Init(context.TODO(), &config.Manager{})
|
||
if err != nil {
|
||
log.Fatalf("init ai manager failed err:%v", err)
|
||
}
|
||
}
|
||
|
||
func Init(ctx context.Context, cfgMgr *config.Manager) (*AIManager, error) {
|
||
if err := cfgMgr.LoadConfigs(ctx); err != nil {
|
||
return nil, err
|
||
}
|
||
return &AIManager{
|
||
CfgMgr: cfgMgr,
|
||
Registry: providers.BuildRegistry(),
|
||
}, nil
|
||
}
|
||
|
||
func (a *AIManager) NewClient(providerName consts.ProviderName, opts ...ClientOption) (*modelprovider.Client, error) {
|
||
provider, finalOpts, err := a.resolveProvider(providerName, opts...)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// model 优先级:opts > config > provider
|
||
model := finalOpts.DefaultModel
|
||
if model == "" {
|
||
model = provider.GetDefaultModel()
|
||
}
|
||
|
||
return modelprovider.NewClient(provider, model), nil
|
||
}
|
||
|
||
func (a *AIManager) NewProvider(providerName consts.ProviderName, opts ...ClientOption) (modelprovider.Provider, error) {
|
||
provider, _, err := a.resolveProvider(providerName, opts...)
|
||
return provider, err
|
||
}
|
||
|
||
func (a *AIManager) resolveProvider(providerName consts.ProviderName, opts ...ClientOption) (modelprovider.Provider, *Options, error) {
|
||
// 初始化 options
|
||
o := &Options{
|
||
ProviderName: providerName,
|
||
}
|
||
for _, opt := range opts {
|
||
opt(o)
|
||
}
|
||
|
||
// Step 1: 如果直接传 Provider,则直接返回
|
||
if o.Provider != nil {
|
||
return o.Provider, o, nil
|
||
}
|
||
|
||
// Step 2: 校验 ProviderName
|
||
if o.ProviderName == "" {
|
||
return nil, nil, fmt.Errorf("invalid provider name: %s", o.ProviderName)
|
||
}
|
||
|
||
// Step 3: 解析 ProviderConfig(Option > DB)
|
||
conf, err := a.resolveProviderConfig(o)
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("resolve provider config failed: %w", err)
|
||
}
|
||
|
||
// Step 4: 实际创建 provider(registry lookup)
|
||
provider, err := a.createProvider(o.ProviderName, conf)
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("create provider failed: %w", err)
|
||
}
|
||
|
||
return provider, o, nil
|
||
}
|
||
|
||
func (a *AIManager) resolveProviderConfig(o *Options) (*config.ProviderConfig, error) {
|
||
if o.ProviderConfig != nil {
|
||
return o.ProviderConfig, nil
|
||
}
|
||
|
||
cfg, ok := a.CfgMgr.GetConfigByProviderName(o.ProviderName)
|
||
if !ok {
|
||
return nil, fmt.Errorf("config not found for provider: %s", o.ProviderName)
|
||
}
|
||
|
||
return &cfg, nil
|
||
}
|
||
|
||
func (a *AIManager) createProvider(providerName consts.ProviderName, conf *config.ProviderConfig) (modelprovider.Provider, error) {
|
||
creator := a.Registry.Providers[providerName]
|
||
if creator == nil {
|
||
return nil, fmt.Errorf("provider not supported: %s", providerName)
|
||
}
|
||
|
||
provider, err := creator(conf)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("create provider instance failed: %w", err)
|
||
}
|
||
|
||
return provider, nil
|
||
}
|