Files
aigrammar/config.go
2024-08-12 04:10:48 +00:00

156 lines
4.1 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"errors"
"reflect"
"strings"
"sync"
"github.com/spf13/viper"
)
// AzureOpenAIConfig holds the [azure_openai] section of config.toml.
// initConfig validates this section right after decoding it
// (validateAzureOpenAI): Endpoint and both model names must be non-blank and
// Keys must be non-empty with a non-blank first entry.
type AzureOpenAIConfig struct {
	// Endpoint is the Azure OpenAI endpoint (presumably the resource base URL — confirm with the client code).
	Endpoint string `mapstructure:"endpoint"`
	// Keys lists the API keys. It may be written as a TOML array or as one
	// comma-separated string, which loadConfigSection splits into elements.
	Keys []string `mapstructure:"keys"`
	// GPT4Model and GPT35Model name the model/deployment used for GPT-4 and
	// GPT-3.5 requests respectively (NOTE(review): model vs. Azure deployment name — confirm).
	GPT4Model  string `mapstructure:"gpt4_model"`
	GPT35Model string `mapstructure:"gpt35_model"`
}
// OpenAIConfig holds the [openai] section of config.toml. initConfig decodes
// it but performs no validation on these fields.
type OpenAIConfig struct {
	// APIKey is the OpenAI API key.
	APIKey string `mapstructure:"api_key"`
	// Organization is the OpenAI organization identifier (presumably optional — confirm).
	Organization string `mapstructure:"organization"`
}
// AIGrammarBaseConfig holds the [base] section of config.toml.
type AIGrammarBaseConfig struct {
	// Fields must be exported (start with an uppercase letter): in Go only
	// exported fields can be set by other packages such as viper.

	// JwtSecret is the secret used for JWTs; it must be non-blank, which
	// validateAzureOpenAI enforces during initConfig.
	JwtSecret string `mapstructure:"jwt_secret"`
	// BindAddr is the address the server listens on (presumably "host:port" — confirm).
	BindAddr string `mapstructure:"bind_addr"`
}
// DataBaseConfig holds the [database] section of config.toml. initConfig
// decodes it but performs no validation on these fields.
type DataBaseConfig struct {
	// Fields must be exported (start with an uppercase letter): in Go only
	// exported fields can be set by other packages such as viper.

	// MysqlConn and RedisConn are connection strings/addresses for MySQL and
	// Redis (NOTE(review): exact format — DSN vs. host:port — confirm with the DB setup code).
	MysqlConn string `mapstructure:"mysql_conn"`
	RedisConn string `mapstructure:"redis_conn"`
	// MysqlUser and MysqlPass are the MySQL credentials.
	MysqlUser string `mapstructure:"mysql_user"`
	MysqlPass string `mapstructure:"mysql_pass"`
}
// LoggerConfig holds the [log] section of config.toml. initConfig decodes it
// but performs no validation on these fields.
type LoggerConfig struct {
	// Fields must be exported (start with an uppercase letter): in Go only
	// exported fields can be set by other packages such as viper.

	// EchoLogFile and LogFile are log file paths (presumably the Echo HTTP
	// framework's log and the application log respectively — confirm).
	EchoLogFile string `mapstructure:"echo_log_file"`
	LogFile     string `mapstructure:"log_file"`
	// MaxSize, MaxBackups, MaxAge and Compress look like rotating-file
	// settings (e.g. lumberjack: MaxSize in MB, MaxAge in days) —
	// NOTE(review): confirm units against the logger setup code.
	MaxSize    int  `mapstructure:"max_size"`
	MaxBackups int  `mapstructure:"max_backups"`
	MaxAge     int  `mapstructure:"max_age"`
	Compress   bool `mapstructure:"compress"`
	// Level is the log level name (accepted values depend on the logger — confirm).
	Level string `mapstructure:"level"`
}
// ConfigManager aggregates every section of config.toml. Obtain the shared
// instance through GetConfigManager; a zero value built directly is never
// loaded, because initConfig is only invoked from there. The Get*Config
// methods return pointers into this struct, so changes made through them are
// visible to every holder of the manager.
type ConfigManager struct {
	BaseConfig  AIGrammarBaseConfig // [base]
	AzureOpenAI AzureOpenAIConfig   // [azure_openai]
	OpenAI      OpenAIConfig        // [openai]
	DBConfig    DataBaseConfig      // [database]
	LogConfig   LoggerConfig        // [log]
}
// Process-wide state behind GetConfigManager. once ensures initConfig runs at
// most once; its outcome is cached in instance and initError.
var (
	once      sync.Once
	instance  *ConfigManager
	initError error
)

// GetConfigManager returns the process-wide ConfigManager, loading and
// validating config.toml on the first call. Concurrent and later calls block
// until that first load finishes and then share its result.
//
// On failure it returns (nil, err): a partially loaded manager is never handed
// out. Because the load is guarded by sync.Once, a failure is permanent for
// the life of the process; later calls return the same error without retrying.
func GetConfigManager() (*ConfigManager, error) {
	once.Do(func() {
		cm := &ConfigManager{}
		if err := cm.initConfig(); err != nil {
			initError = err
			return
		}
		// Publish the manager only once it is fully loaded and validated.
		instance = cm
	})
	return instance, initError
}
// initConfig reads config.toml from the working directory into the global
// viper instance, then decodes each section into cm in this order: base,
// azure_openai, openai, database, log.
//
// The Azure section is validated right after it is decoded. validateAzureOpenAI
// also checks BaseConfig.JwtSecret, which is why "base" is decoded first.
//
// It stops at the first failure:
//   - Read and decode errors are returned as a *configLoadError that names the
//     failing step. The cause stays reachable via errors.Is / errors.As.
//   - Validation errors are returned unchanged.
//
// On failure cm may be partially populated.
//
// NOTE: list-valued fields supplied as a single string are split on commas by
// loadConfigSection. An element therefore cannot itself contain a comma
// unless it is written as a TOML array (open TODO carried over from the
// original code).
func (cm *ConfigManager) initConfig() error {
	viper.SetConfigName("config")
	viper.SetConfigType("toml")
	viper.AddConfigPath(".")
	if err := viper.ReadInConfig(); err != nil {
		return &configLoadError{step: "reading config file", err: err}
	}

	sections := []struct {
		key      string
		target   interface{}
		validate func() error // optional check run right after decoding
	}{
		{key: "base", target: &cm.BaseConfig},
		{key: "azure_openai", target: &cm.AzureOpenAI, validate: cm.validateAzureOpenAI},
		{key: "openai", target: &cm.OpenAI},
		{key: "database", target: &cm.DBConfig},
		{key: "log", target: &cm.LogConfig},
	}
	for _, s := range sections {
		if err := cm.loadConfigSection(s.key, s.target); err != nil {
			return &configLoadError{step: `loading config section "` + s.key + `"`, err: err}
		}
		if s.validate == nil {
			continue
		}
		if err := s.validate(); err != nil {
			return err
		}
	}
	return nil
}

// configLoadError annotates an initConfig failure with the step that produced
// it while keeping the underlying error available through Unwrap.
type configLoadError struct {
	step string
	err  error
}

// Error renders the failure as "<step>: <cause>".
func (e *configLoadError) Error() string { return e.step + ": " + e.err.Error() }

// Unwrap exposes the underlying cause to errors.Is and errors.As.
func (e *configLoadError) Unwrap() error { return e.err }
// loadConfigSection decodes the viper sub-tree at key into config. config
// must be a pointer to a struct whose fields carry `mapstructure` tags.
//
// The decode hook (splitCommaListHook) lets slice fields such as
// AzureOpenAIConfig.Keys be written either as a TOML array or as a single
// comma-separated string such as "key1, key2".
//
// NOTE(review): supplying viper.DecodeHook replaces viper's default decode
// hooks (e.g. string -> time.Duration). No field here needs those today;
// revisit if a duration field is added.
func (cm *ConfigManager) loadConfigSection(key string, config interface{}) error {
	return viper.UnmarshalKey(key, config, viper.DecodeHook(splitCommaListHook))
}

// splitCommaListHook is a kind-based mapstructure decode hook. It converts a
// string destined for a slice into a []string:
//   - the string is split on commas;
//   - each element is trimmed of surrounding whitespace;
//   - empty elements are dropped.
//
// So "a, b" gives ["a" "b"], "a,,b," gives ["a" "b"], and "" gives an empty
// slice rather than [""]. All other inputs pass through unchanged.
func splitCommaListHook(from, to reflect.Kind, data interface{}) (interface{}, error) {
	if from != reflect.String || to != reflect.Slice {
		return data, nil
	}
	// reflect.Value.String also handles named string types, where a plain
	// data.(string) assertion would panic.
	parts := strings.Split(reflect.ValueOf(data).String(), ",")
	out := make([]string, 0, len(parts))
	for _, p := range parts {
		if p = strings.TrimSpace(p); p != "" {
			out = append(out, p)
		}
	}
	return out, nil
}
// validateAzureOpenAI checks that the Azure OpenAI section is usable:
//   - Endpoint, GPT4Model and GPT35Model must be non-blank.
//   - Keys must contain at least one entry, and no entry may be blank.
//
// It also requires a non-blank BaseConfig.JwtSecret, so it must run after the
// "base" section has been decoded. It returns the first problem found, or nil.
//
// NOTE(review): the JWT check belongs to the base section; consider moving it
// into a dedicated base-config validator.
func (cm *ConfigManager) validateAzureOpenAI() error {
	az := &cm.AzureOpenAI
	if strings.TrimSpace(az.Endpoint) == "" {
		return errors.New("AzureOpenAI endpoint cannot be empty")
	}
	if strings.TrimSpace(az.GPT4Model) == "" {
		return errors.New("AzureOpenAI GPT4Model cannot be empty")
	}
	if strings.TrimSpace(az.GPT35Model) == "" {
		return errors.New("AzureOpenAI GPT35Model cannot be empty")
	}
	if len(az.Keys) == 0 || strings.TrimSpace(az.Keys[0]) == "" {
		return errors.New("AzureOpenAI Keys must contain at least one valid key")
	}
	// Checking only Keys[0] lets blank later entries through (e.g. from a
	// trailing comma in a comma-separated value). A blank key is never usable.
	for _, k := range az.Keys[1:] {
		if strings.TrimSpace(k) == "" {
			return errors.New("AzureOpenAI Keys must not contain empty entries")
		}
	}
	if strings.TrimSpace(cm.BaseConfig.JwtSecret) == "" {
		return errors.New("jwt secret cannot be empty")
	}
	return nil
}
// GetAzureConfig returns a pointer to the [azure_openai] section stored in cm.
// Writes through the pointer modify cm itself.
func (cm *ConfigManager) GetAzureConfig() *AzureOpenAIConfig {
	section := &cm.AzureOpenAI
	return section
}
// GetOpenAIConfig returns a pointer to the [openai] section stored in cm.
// Writes through the pointer modify cm itself.
func (cm *ConfigManager) GetOpenAIConfig() *OpenAIConfig {
	section := &cm.OpenAI
	return section
}
// GetBaseConfig returns a pointer to the [base] section stored in cm.
// Writes through the pointer modify cm itself.
func (cm *ConfigManager) GetBaseConfig() *AIGrammarBaseConfig {
	section := &cm.BaseConfig
	return section
}
// GetDatabaseConfig returns a pointer to the [database] section stored in cm.
// Writes through the pointer modify cm itself.
func (cm *ConfigManager) GetDatabaseConfig() *DataBaseConfig {
	section := &cm.DBConfig
	return section
}
// GetLogConfig returns a pointer to the [log] section stored in cm.
// Writes through the pointer modify cm itself.
func (cm *ConfigManager) GetLogConfig() *LoggerConfig {
	section := &cm.LogConfig
	return section
}