156 lines
4.1 KiB
Go
156 lines
4.1 KiB
Go
package main
|
||
|
||
import (
	"errors"
	"fmt"
	"reflect"
	"strings"
	"sync"

	"github.com/spf13/viper"
)
|
||
|
||
// AzureOpenAIConfig mirrors the "azure_openai" section of the config file.
//
// validateAzureOpenAI requires Endpoint, GPT4Model, GPT35Model and Keys to be
// set before the configuration is accepted. Keys may be written either as a
// TOML array or as a single comma-separated string (loadConfigSection splits
// such a string into a slice).
type AzureOpenAIConfig struct {
	Endpoint   string   `mapstructure:"endpoint"`
	Keys       []string `mapstructure:"keys"`
	GPT4Model  string   `mapstructure:"gpt4_model"`
	GPT35Model string   `mapstructure:"gpt35_model"`
}
|
||
|
||
// OpenAIConfig mirrors the "openai" section of the config file.
// initConfig loads this section but performs no validation on it.
type OpenAIConfig struct {
	APIKey       string `mapstructure:"api_key"`
	Organization string `mapstructure:"organization"`
}
|
||
|
||
// AIGrammarBaseConfig mirrors the "base" section of the config file.
//
// Fields must be exported (capitalized) so that viper, which lives in another
// package, can populate them through reflection. validateAzureOpenAI rejects
// a configuration whose JwtSecret is empty or whitespace-only.
type AIGrammarBaseConfig struct {
	JwtSecret string `mapstructure:"jwt_secret"`
	BindAddr  string `mapstructure:"bind_addr"`
}
|
||
|
||
// DataBaseConfig mirrors the "database" section of the config file.
//
// Fields must be exported (capitalized) so that viper, which lives in another
// package, can populate them through reflection. initConfig loads this
// section but performs no validation on it.
//
// NOTE(review): MysqlPass presumably holds a credential — avoid logging this
// struct verbatim (e.g. with %+v).
type DataBaseConfig struct {
	MysqlConn string `mapstructure:"mysql_conn"`
	RedisConn string `mapstructure:"redis_conn"`
	MysqlUser string `mapstructure:"mysql_user"`
	MysqlPass string `mapstructure:"mysql_pass"`
}
|
||
|
||
// LoggerConfig mirrors the "log" section of the config file.
//
// Fields must be exported (capitalized) so that viper, which lives in another
// package, can populate them through reflection. initConfig loads this
// section but performs no validation on it.
//
// NOTE(review): MaxSize/MaxBackups/MaxAge/Compress look like log-rotation
// settings (e.g. lumberjack: size in MB, age in days) — confirm units against
// the code that builds the logger.
type LoggerConfig struct {
	EchoLogFile string `mapstructure:"echo_log_file"`
	LogFile     string `mapstructure:"log_file"`
	MaxSize     int    `mapstructure:"max_size"`
	MaxBackups  int    `mapstructure:"max_backups"`
	MaxAge      int    `mapstructure:"max_age"`
	Compress    bool   `mapstructure:"compress"`
	Level       string `mapstructure:"level"`
}
|
||
|
||
type ConfigManager struct {
|
||
BaseConfig AIGrammarBaseConfig
|
||
AzureOpenAI AzureOpenAIConfig
|
||
OpenAI OpenAIConfig
|
||
DBConfig DataBaseConfig
|
||
LogConfig LoggerConfig
|
||
}
|
||
|
||
var once sync.Once
|
||
var instance *ConfigManager
|
||
var initError error
|
||
|
||
func GetConfigManager() (*ConfigManager, error) {
|
||
once.Do(func() {
|
||
instance = &ConfigManager{}
|
||
initError = instance.initConfig()
|
||
})
|
||
return instance, initError
|
||
}
|
||
|
||
func (cm *ConfigManager) initConfig() error {
|
||
viper.SetConfigName("config")
|
||
viper.SetConfigType("toml")
|
||
viper.AddConfigPath(".")
|
||
|
||
if err := viper.ReadInConfig(); err != nil {
|
||
return err
|
||
}
|
||
|
||
// 这里如果有字段是带逗号的字符串,处理会有问题,需要改正。
|
||
if err := cm.loadConfigSection("base", &cm.BaseConfig); err != nil {
|
||
return err
|
||
}
|
||
|
||
if err := cm.loadConfigSection("azure_openai", &cm.AzureOpenAI); err != nil {
|
||
return err
|
||
}
|
||
if err := cm.validateAzureOpenAI(); err != nil {
|
||
return err
|
||
}
|
||
|
||
if err := cm.loadConfigSection("openai", &cm.OpenAI); err != nil {
|
||
return err
|
||
}
|
||
|
||
if err := cm.loadConfigSection("database", &cm.DBConfig); err != nil {
|
||
return err
|
||
}
|
||
|
||
if err := cm.loadConfigSection("log", &cm.LogConfig); err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (cm *ConfigManager) loadConfigSection(key string, config interface{}) error {
|
||
return viper.UnmarshalKey(key, config, viper.DecodeHook(
|
||
func(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) {
|
||
if f == reflect.String && t == reflect.Slice {
|
||
return strings.Split(data.(string), ","), nil
|
||
}
|
||
return data, nil
|
||
}))
|
||
}
|
||
|
||
// validateAzureOpenAI checks that the AzureOpenAI configuration has all required fields properly set.
|
||
func (cm *ConfigManager) validateAzureOpenAI() error {
|
||
if strings.TrimSpace(cm.AzureOpenAI.Endpoint) == "" {
|
||
return errors.New("AzureOpenAI endpoint cannot be empty")
|
||
}
|
||
if strings.TrimSpace(cm.AzureOpenAI.GPT4Model) == "" {
|
||
return errors.New("AzureOpenAI GPT4Model cannot be empty")
|
||
}
|
||
if strings.TrimSpace(cm.AzureOpenAI.GPT35Model) == "" {
|
||
return errors.New("AzureOpenAI GPT35Model cannot be empty")
|
||
}
|
||
if len(cm.AzureOpenAI.Keys) == 0 || strings.TrimSpace(cm.AzureOpenAI.Keys[0]) == "" {
|
||
return errors.New("AzureOpenAI Keys must contain at least one valid key")
|
||
}
|
||
|
||
if strings.TrimSpace(cm.BaseConfig.JwtSecret) == "" {
|
||
return errors.New("jwt secret cannot be empty")
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (cm *ConfigManager) GetAzureConfig() *AzureOpenAIConfig {
|
||
return &cm.AzureOpenAI
|
||
}
|
||
|
||
func (cm *ConfigManager) GetOpenAIConfig() *OpenAIConfig {
|
||
return &cm.OpenAI
|
||
}
|
||
|
||
func (cm *ConfigManager) GetBaseConfig() *AIGrammarBaseConfig {
|
||
return &cm.BaseConfig
|
||
}
|
||
|
||
func (cm *ConfigManager) GetDatabaseConfig() *DataBaseConfig {
|
||
return &cm.DBConfig
|
||
}
|
||
|
||
func (cm *ConfigManager) GetLogConfig() *LoggerConfig {
|
||
return &cm.LogConfig
|
||
}
|