diff --git a/env.go b/env.go index f5348ef..b5712b7 100644 --- a/env.go +++ b/env.go @@ -12,17 +12,37 @@ import ( var ( defaultSep = "_" ) +var env *Env +type Env struct { + ignorePrefix bool +} + +func init() { + env = &Env{false} +} func upper(v string) string { return strings.ToUpper(v) } +func IgnorePrefix() { + env.ignorePrefix = true +} + func Fill(v interface{}) error { - if reflect.ValueOf(v).Kind() != reflect.Ptr { + return env.Fill(v) +} + +func (e *Env)Fill(v interface{}) error { + ind := reflect.Indirect(reflect.ValueOf(v)) + if reflect.ValueOf(v).Kind() != reflect.Ptr || ind.Kind() != reflect.Struct{ return fmt.Errorf("only the pointer to a struct is supported") } - ind := reflect.Indirect(reflect.ValueOf(v)) + prefix := upper(ind.Type().Name()) + if e.ignorePrefix { + prefix = "" + } err := fill(prefix, ind) if err != nil { return err @@ -30,11 +50,14 @@ func Fill(v interface{}) error { return nil } -func combine(p, n string, v string, ok bool) string { +func combine(p, n string, sep string, ok bool) string { + if p == "" { + return n + } if !ok { return p + defaultSep + n } - return p + v + n + return p + sep + n } func parseBool(v string) (bool, error) { @@ -71,7 +94,6 @@ func fill(pf string, ind reflect.Value) error { } func parse(prefix string, f reflect.Value, sf reflect.StructField) error { - //log.Print("parse:", prefix, f.String(), f.Type().String(), f.Kind().String()) df := sf.Tag.Get("default") isRequire, err := parseBool(sf.Tag.Get("require")) if err != nil { @@ -93,51 +115,51 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error { case reflect.Int: iv, err := strconv.ParseInt(ev, 10, 32) if err != nil { - return err + return fmt.Errorf("%s:%s", prefix, err) } f.SetInt(iv) case reflect.Int64: if f.Type().String() == "time.Duration" { t, err := time.ParseDuration(ev) if err != nil { - return err + return fmt.Errorf("%s:%s", prefix, err) } f.Set(reflect.ValueOf(t)) } else { iv, err := strconv.ParseInt(ev, 10, 64) if err != nil { - return err + return fmt.Errorf("%s:%s", prefix, err) } f.SetInt(iv) } case reflect.Uint: uiv, err := strconv.ParseUint(ev, 10, 32) if err != nil { - return err + return fmt.Errorf("%s:%s", prefix, err) } f.SetUint(uiv) case reflect.Uint64: uiv, err := strconv.ParseUint(ev, 10, 64) if err != nil { - return err + return fmt.Errorf("%s:%s", prefix, err) } f.SetUint(uiv) case reflect.Float32: f32, err := strconv.ParseFloat(ev, 32) if err != nil { - return err + return fmt.Errorf("%s:%s", prefix, err) } f.SetFloat(f32) case reflect.Float64: f64, err := strconv.ParseFloat(ev, 64) if err != nil { - return err + return fmt.Errorf("%s:%s", prefix, err) } f.SetFloat(f64) case reflect.Bool: b, err := parseBool(ev) if err != nil { - return err + return fmt.Errorf("%s:%s", prefix, err) } f.SetBool(b) case reflect.Slice: @@ -155,7 +177,7 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error { for i, v := range vals { val, err := strconv.ParseInt(v, 10, 32) if err != nil { - return err + return fmt.Errorf("%s:%s", prefix, err) } t[i] = int(val) } @@ -164,7 +186,7 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error { for i, v := range vals { val, err := strconv.ParseInt(v, 10, 64) if err != nil { - return err + return fmt.Errorf("%s:%s", prefix, err) } t[i] = val } @@ -173,7 +195,7 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error { for i, v := range vals { val, err := strconv.ParseUint(v, 10, 32) if err != nil { - return err + return fmt.Errorf("%s:%s", prefix, err) } t[i] = uint(val) } @@ -182,7 +204,7 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error { for i, v := range vals { val, err := strconv.ParseUint(v, 10, 64) if err != nil { - return err + return fmt.Errorf("%s:%s", prefix, err) } t[i] = val } @@ -191,7 +213,7 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error { for i, v := range vals { val, err := strconv.ParseFloat(v, 32) if err != nil { - return err + return fmt.Errorf("%s:%s", prefix, err) } t[i] = float32(val) } @@ -200,7 +222,7 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error { for i, v := range vals { val, err := strconv.ParseFloat(v, 64) if err != nil { - return err + return fmt.Errorf("%s:%s", prefix, err) } t[i] = val } @@ -209,7 +231,7 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error { for i, v := range vals { val, err := parseBool(v) if err != nil { - return err + return fmt.Errorf("%s:%s", prefix, err) } t[i] = val } diff --git a/env_test.go b/env_test.go new file mode 100644 index 0000000..08022af --- /dev/null +++ b/env_test.go @@ -0,0 +1,103 @@ +package env + +import ( + "testing" + "os" + "time" + "github.com/stretchr/testify/assert" +) +type config struct { + App string + Port int `default:"8000"` + IsDebug bool `env:"DEBUG"` + Hosts []string `slice_sep:","` + Timeout time.Duration + + Redis struct { + Version string `sep:""` // no sep between `CONFIG` and `REDIS` + Host string + Port int + } + + MySQL struct { + Version string `default:"5.7"` + Host string + Port int + } +} +func TestGeneralEnv(t *testing.T) { + os.Setenv("CONFIG_APP", "ENV APP") + // os.Setenv("CONFIG_PORT", "default") // default value + os.Setenv("CONFIG_DEBUG", "1") + os.Setenv("CONFIG_HOSTS", "192.168.0.1,127.0.0.1") + os.Setenv("CONFIG_TIMEOUT", "5s") + + os.Setenv("CONFIG_REDISVERSION", "3.2") + os.Setenv("CONFIG_REDIS_HOST", "rdb") + os.Setenv("CONFIG_REDIS_PORT", "6379") + + os.Setenv("CONFIG_MYSQL_HOST", "mysqldb") + os.Setenv("CONFIG_MYSQL_PORT", "3306") + + defer os.Clearenv() + + cfg := new(config) + err := Fill(cfg) + if err != nil { + t.Error(err) + } + + assert := assert.New(t) + + assert.Equal(cfg.App, "ENV APP") + assert.Equal(cfg.Port, 8000) + assert.Equal(cfg.IsDebug, true) + assert.Equal(cfg.Hosts, []string{"192.168.0.1", "127.0.0.1"}) + assert.Equal(cfg.Timeout, 5*time.Second) + assert.Equal(cfg.Redis.Version, "3.2") + assert.Equal(cfg.Redis.Host, "rdb") + assert.Equal(cfg.MySQL.Version, "5.7") + assert.Equal(cfg.MySQL.Host, "mysqldb") + assert.Equal(cfg.MySQL.Port, 3306) + +} + + + +func TestNoPrefixEnv(t *testing.T) { + os.Setenv("APP", "ENV_APP") + // os.Setenv("PORT", "default") // default value + os.Setenv("DEBUG", "1") + os.Setenv("HOSTS", "192.168.1.1,127.0.0.1") + os.Setenv("TIMEOUT", "5s") + + os.Setenv("REDISVERSION", "3.2") + os.Setenv("REDIS_HOST", "rdb") + os.Setenv("REDIS_PORT", "6379") + + os.Setenv("MYSQL_HOST", "mysqldb") + os.Setenv("MYSQL_PORT", "3306") + + defer os.Clearenv() + + cfg := new(config) + IgnorePrefix() + err := Fill(cfg) + if err != nil { + t.Error(err) + } + + assert := assert.New(t) + + assert.Equal(cfg.App, "ENV_APP") + assert.Equal(cfg.Port, 8000) + assert.Equal(cfg.IsDebug, true) + assert.Equal(cfg.Hosts, []string{"192.168.1.1", "127.0.0.1"}) + assert.Equal(cfg.Timeout, 5*time.Second) + assert.Equal(cfg.Redis.Version, "3.2") + assert.Equal(cfg.Redis.Host, "rdb") + assert.Equal(cfg.MySQL.Version, "5.7") + assert.Equal(cfg.MySQL.Host, "mysqldb") + assert.Equal(cfg.MySQL.Port, 3306) + +} \ No newline at end of file