Initial commit
This commit is contained in:
155
config.go
Normal file
155
config.go
Normal file
@ -0,0 +1,155 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type AzureOpenAIConfig struct {
|
||||
Endpoint string `mapstructure:"endpoint"`
|
||||
Keys []string `mapstructure:"keys"`
|
||||
GPT4Model string `mapstructure:"gpt4_model"`
|
||||
GPT35Model string `mapstructure:"gpt35_model"`
|
||||
}
|
||||
|
||||
type OpenAIConfig struct {
|
||||
APIKey string `mapstructure:"api_key"`
|
||||
Organization string `mapstructure:"organization"`
|
||||
}
|
||||
|
||||
type AIGrammarBaseConfig struct {
|
||||
// 在 Go 中,只有首字母大写的字段才能被外部包(如 viper)访问。
|
||||
JwtSecret string `mapstructure:"jwt_secret"`
|
||||
BindAddr string `mapstructure:"bind_addr"`
|
||||
}
|
||||
|
||||
type DataBaseConfig struct {
|
||||
// 在 Go 中,只有首字母大写的字段才能被外部包(如 viper)访问。
|
||||
MysqlConn string `mapstructure:"mysql_conn"`
|
||||
RedisConn string `mapstructure:"redis_conn"`
|
||||
MysqlUser string `mapstructure:"mysql_user"`
|
||||
MysqlPass string `mapstructure:"mysql_pass"`
|
||||
}
|
||||
|
||||
type LoggerConfig struct {
|
||||
// 在 Go 中,只有首字母大写的字段才能被外部包(如 viper)访问。
|
||||
EchoLogFile string `mapstructure:"echo_log_file"`
|
||||
LogFile string `mapstructure:"log_file"`
|
||||
MaxSize int `mapstructure:"max_size"`
|
||||
MaxBackups int `mapstructure:"max_backups"`
|
||||
MaxAge int `mapstructure:"max_age"`
|
||||
Compress bool `mapstructure:"compress"`
|
||||
Level string `mapstructure:"level"`
|
||||
}
|
||||
|
||||
type ConfigManager struct {
|
||||
BaseConfig AIGrammarBaseConfig
|
||||
AzureOpenAI AzureOpenAIConfig
|
||||
OpenAI OpenAIConfig
|
||||
DBConfig DataBaseConfig
|
||||
LogConfig LoggerConfig
|
||||
}
|
||||
|
||||
var once sync.Once
|
||||
var instance *ConfigManager
|
||||
var initError error
|
||||
|
||||
func GetConfigManager() (*ConfigManager, error) {
|
||||
once.Do(func() {
|
||||
instance = &ConfigManager{}
|
||||
initError = instance.initConfig()
|
||||
})
|
||||
return instance, initError
|
||||
}
|
||||
|
||||
func (cm *ConfigManager) initConfig() error {
|
||||
viper.SetConfigName("config")
|
||||
viper.SetConfigType("toml")
|
||||
viper.AddConfigPath(".")
|
||||
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 这里如果有字段是带逗号的字符串,处理会有问题,需要改正。
|
||||
if err := cm.loadConfigSection("base", &cm.BaseConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := cm.loadConfigSection("azure_openai", &cm.AzureOpenAI); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := cm.validateAzureOpenAI(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := cm.loadConfigSection("openai", &cm.OpenAI); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := cm.loadConfigSection("database", &cm.DBConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := cm.loadConfigSection("log", &cm.LogConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cm *ConfigManager) loadConfigSection(key string, config interface{}) error {
|
||||
return viper.UnmarshalKey(key, config, viper.DecodeHook(
|
||||
func(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) {
|
||||
if f == reflect.String && t == reflect.Slice {
|
||||
return strings.Split(data.(string), ","), nil
|
||||
}
|
||||
return data, nil
|
||||
}))
|
||||
}
|
||||
|
||||
// validateAzureOpenAI checks that the AzureOpenAI configuration has all required fields properly set.
|
||||
func (cm *ConfigManager) validateAzureOpenAI() error {
|
||||
if strings.TrimSpace(cm.AzureOpenAI.Endpoint) == "" {
|
||||
return errors.New("AzureOpenAI endpoint cannot be empty")
|
||||
}
|
||||
if strings.TrimSpace(cm.AzureOpenAI.GPT4Model) == "" {
|
||||
return errors.New("AzureOpenAI GPT4Model cannot be empty")
|
||||
}
|
||||
if strings.TrimSpace(cm.AzureOpenAI.GPT35Model) == "" {
|
||||
return errors.New("AzureOpenAI GPT35Model cannot be empty")
|
||||
}
|
||||
if len(cm.AzureOpenAI.Keys) == 0 || strings.TrimSpace(cm.AzureOpenAI.Keys[0]) == "" {
|
||||
return errors.New("AzureOpenAI Keys must contain at least one valid key")
|
||||
}
|
||||
|
||||
if strings.TrimSpace(cm.BaseConfig.JwtSecret) == "" {
|
||||
return errors.New("jwt secret cannot be empty")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cm *ConfigManager) GetAzureConfig() *AzureOpenAIConfig {
|
||||
return &cm.AzureOpenAI
|
||||
}
|
||||
|
||||
func (cm *ConfigManager) GetOpenAIConfig() *OpenAIConfig {
|
||||
return &cm.OpenAI
|
||||
}
|
||||
|
||||
func (cm *ConfigManager) GetBaseConfig() *AIGrammarBaseConfig {
|
||||
return &cm.BaseConfig
|
||||
}
|
||||
|
||||
func (cm *ConfigManager) GetDatabaseConfig() *DataBaseConfig {
|
||||
return &cm.DBConfig
|
||||
}
|
||||
|
||||
func (cm *ConfigManager) GetLogConfig() *LoggerConfig {
|
||||
return &cm.LogConfig
|
||||
}
|
||||
Reference in New Issue
Block a user