This commit is contained in:
timest 2018-07-11 11:53:49 +08:00
parent 3d591d9186
commit d90934c137
2 changed files with 145 additions and 20 deletions

62
env.go
View File

@ -12,17 +12,37 @@ import (
var (
defaultSep = "_"
)
var env *Env
type Env struct {
ignorePrefix bool
}
func init() {
env = &Env{false}
}
func upper(v string) string {
return strings.ToUpper(v)
}
func IgnorePrefix() {
env.ignorePrefix = true
}
func Fill(v interface{}) error {
if reflect.ValueOf(v).Kind() != reflect.Ptr {
return env.Fill(v)
}
func (e *Env)Fill(v interface{}) error {
ind := reflect.Indirect(reflect.ValueOf(v))
if reflect.ValueOf(v).Kind() != reflect.Ptr || ind.Kind() != reflect.Struct{
return fmt.Errorf("only the pointer to a struct is supported")
}
ind := reflect.Indirect(reflect.ValueOf(v))
prefix := upper(ind.Type().Name())
if e.ignorePrefix {
prefix = ""
}
err := fill(prefix, ind)
if err != nil {
return err
@ -30,11 +50,14 @@ func Fill(v interface{}) error {
return nil
}
func combine(p, n string, v string, ok bool) string {
func combine(p, n string, sep string, ok bool) string {
if p == "" {
return n
}
if !ok {
return p + defaultSep + n
}
return p + v + n
return p + sep + n
}
func parseBool(v string) (bool, error) {
@ -71,7 +94,6 @@ func fill(pf string, ind reflect.Value) error {
}
func parse(prefix string, f reflect.Value, sf reflect.StructField) error {
//log.Print("parse:", prefix, f.String(), f.Type().String(), f.Kind().String())
df := sf.Tag.Get("default")
isRequire, err := parseBool(sf.Tag.Get("require"))
if err != nil {
@ -93,51 +115,51 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error {
case reflect.Int:
iv, err := strconv.ParseInt(ev, 10, 32)
if err != nil {
return err
return fmt.Errorf("%s:%s", prefix, err)
}
f.SetInt(iv)
case reflect.Int64:
if f.Type().String() == "time.Duration" {
t, err := time.ParseDuration(ev)
if err != nil {
return err
return fmt.Errorf("%s:%s", prefix, err)
}
f.Set(reflect.ValueOf(t))
} else {
iv, err := strconv.ParseInt(ev, 10, 64)
if err != nil {
return err
return fmt.Errorf("%s:%s", prefix, err)
}
f.SetInt(iv)
}
case reflect.Uint:
uiv, err := strconv.ParseUint(ev, 10, 32)
if err != nil {
return err
return fmt.Errorf("%s:%s", prefix, err)
}
f.SetUint(uiv)
case reflect.Uint64:
uiv, err := strconv.ParseUint(ev, 10, 64)
if err != nil {
return err
return fmt.Errorf("%s:%s", prefix, err)
}
f.SetUint(uiv)
case reflect.Float32:
f32, err := strconv.ParseFloat(ev, 32)
if err != nil {
return err
return fmt.Errorf("%s:%s", prefix, err)
}
f.SetFloat(f32)
case reflect.Float64:
f64, err := strconv.ParseFloat(ev, 64)
if err != nil {
return err
return fmt.Errorf("%s:%s", prefix, err)
}
f.SetFloat(f64)
case reflect.Bool:
b, err := parseBool(ev)
if err != nil {
return err
return fmt.Errorf("%s:%s", prefix, err)
}
f.SetBool(b)
case reflect.Slice:
@ -155,7 +177,7 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error {
for i, v := range vals {
val, err := strconv.ParseInt(v, 10, 32)
if err != nil {
return err
return fmt.Errorf("%s:%s", prefix, err)
}
t[i] = int(val)
}
@ -164,7 +186,7 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error {
for i, v := range vals {
val, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return err
return fmt.Errorf("%s:%s", prefix, err)
}
t[i] = val
}
@ -173,7 +195,7 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error {
for i, v := range vals {
val, err := strconv.ParseUint(v, 10, 32)
if err != nil {
return err
return fmt.Errorf("%s:%s", prefix, err)
}
t[i] = uint(val)
}
@ -182,7 +204,7 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error {
for i, v := range vals {
val, err := strconv.ParseUint(v, 10, 64)
if err != nil {
return err
return fmt.Errorf("%s:%s", prefix, err)
}
t[i] = val
}
@ -191,7 +213,7 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error {
for i, v := range vals {
val, err := strconv.ParseFloat(v, 32)
if err != nil {
return err
return fmt.Errorf("%s:%s", prefix, err)
}
t[i] = float32(val)
}
@ -200,7 +222,7 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error {
for i, v := range vals {
val, err := strconv.ParseFloat(v, 64)
if err != nil {
return err
return fmt.Errorf("%s:%s", prefix, err)
}
t[i] = val
}
@ -209,7 +231,7 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error {
for i, v := range vals {
val, err := parseBool(v)
if err != nil {
return err
return fmt.Errorf("%s:%s", prefix, err)
}
t[i] = val
}

103
env_test.go Normal file
View File

@ -0,0 +1,103 @@
package env
import (
"testing"
"os"
"time"
"github.com/stretchr/testify/assert"
)
type config struct {
App string
Port int `default:"8000"`
IsDebug bool `env:"DEBUG"`
Hosts []string `slice_sep:","`
Timeout time.Duration
Redis struct {
Version string `sep:""` // no sep between `CONFIG` and `REDIS`
Host string
Port int
}
MySQL struct {
Version string `default:"5.7"`
Host string
Port int
}
}
func TestGeneralEnv(t *testing.T) {
os.Setenv("CONFIG_APP", "ENV APP")
// os.Setenv("CONFIG_PORT", "default") // default value
os.Setenv("CONFIG_DEBUG", "1")
os.Setenv("CONFIG_HOSTS", "192.168.0.1,127.0.0.1")
os.Setenv("CONFIG_TIMEOUT", "5s")
os.Setenv("CONFIG_REDISVERSION", "3.2")
os.Setenv("CONFIG_REDIS_HOST", "rdb")
os.Setenv("CONFIG_REDIS_PORT", "6379")
os.Setenv("CONFIG_MYSQL_HOST", "mysqldb")
os.Setenv("CONFIG_MYSQL_PORT", "3306")
defer os.Clearenv()
cfg := new(config)
err := Fill(cfg)
if err != nil {
t.Error(err)
}
assert := assert.New(t)
assert.Equal(cfg.App, "ENV APP")
assert.Equal(cfg.Port, 8000)
assert.Equal(cfg.IsDebug, true)
assert.Equal(cfg.Hosts, []string{"192.168.0.1", "127.0.0.1"})
assert.Equal(cfg.Timeout, 5*time.Second)
assert.Equal(cfg.Redis.Version, "3.2")
assert.Equal(cfg.Redis.Host, "rdb")
assert.Equal(cfg.MySQL.Version, "5.7")
assert.Equal(cfg.MySQL.Host, "mysqldb")
assert.Equal(cfg.MySQL.Port, 3306)
}
func TestNoPrefixEnv(t *testing.T) {
os.Setenv("APP", "ENV_APP")
// os.Setenv("PORT", "default") // default value
os.Setenv("DEBUG", "1")
os.Setenv("HOSTS", "192.168.1.1,127.0.0.1")
os.Setenv("TIMEOUT", "5s")
os.Setenv("REDISVERSION", "3.2")
os.Setenv("REDIS_HOST", "rdb")
os.Setenv("REDIS_PORT", "6379")
os.Setenv("MYSQL_HOST", "mysqldb")
os.Setenv("MYSQL_PORT", "3306")
defer os.Clearenv()
cfg := new(config)
IgnorePrefix()
err := Fill(cfg)
if err != nil {
t.Error(err)
}
assert := assert.New(t)
assert.Equal(cfg.App, "ENV_APP")
assert.Equal(cfg.Port, 8000)
assert.Equal(cfg.IsDebug, true)
assert.Equal(cfg.Hosts, []string{"192.168.1.1", "127.0.0.1"})
assert.Equal(cfg.Timeout, 5*time.Second)
assert.Equal(cfg.Redis.Version, "3.2")
assert.Equal(cfg.Redis.Host, "rdb")
assert.Equal(cfg.MySQL.Version, "5.7")
assert.Equal(cfg.MySQL.Host, "mysqldb")
assert.Equal(cfg.MySQL.Port, 3306)
}