add test
This commit is contained in:
parent
3d591d9186
commit
d90934c137
62
env.go
62
env.go
@ -12,17 +12,37 @@ import (
|
||||
var (
|
||||
defaultSep = "_"
|
||||
)
|
||||
var env *Env
|
||||
type Env struct {
|
||||
ignorePrefix bool
|
||||
}
|
||||
|
||||
func init() {
|
||||
env = &Env{false}
|
||||
}
|
||||
|
||||
func upper(v string) string {
|
||||
return strings.ToUpper(v)
|
||||
}
|
||||
|
||||
func IgnorePrefix() {
|
||||
env.ignorePrefix = true
|
||||
}
|
||||
|
||||
func Fill(v interface{}) error {
|
||||
if reflect.ValueOf(v).Kind() != reflect.Ptr {
|
||||
return env.Fill(v)
|
||||
}
|
||||
|
||||
func (e *Env)Fill(v interface{}) error {
|
||||
ind := reflect.Indirect(reflect.ValueOf(v))
|
||||
if reflect.ValueOf(v).Kind() != reflect.Ptr || ind.Kind() != reflect.Struct{
|
||||
return fmt.Errorf("only the pointer to a struct is supported")
|
||||
}
|
||||
ind := reflect.Indirect(reflect.ValueOf(v))
|
||||
|
||||
prefix := upper(ind.Type().Name())
|
||||
if e.ignorePrefix {
|
||||
prefix = ""
|
||||
}
|
||||
err := fill(prefix, ind)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -30,11 +50,14 @@ func Fill(v interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func combine(p, n string, v string, ok bool) string {
|
||||
func combine(p, n string, sep string, ok bool) string {
|
||||
if p == "" {
|
||||
return n
|
||||
}
|
||||
if !ok {
|
||||
return p + defaultSep + n
|
||||
}
|
||||
return p + v + n
|
||||
return p + sep + n
|
||||
}
|
||||
|
||||
func parseBool(v string) (bool, error) {
|
||||
@ -71,7 +94,6 @@ func fill(pf string, ind reflect.Value) error {
|
||||
}
|
||||
|
||||
func parse(prefix string, f reflect.Value, sf reflect.StructField) error {
|
||||
//log.Print("parse:", prefix, f.String(), f.Type().String(), f.Kind().String())
|
||||
df := sf.Tag.Get("default")
|
||||
isRequire, err := parseBool(sf.Tag.Get("require"))
|
||||
if err != nil {
|
||||
@ -93,51 +115,51 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error {
|
||||
case reflect.Int:
|
||||
iv, err := strconv.ParseInt(ev, 10, 32)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("%s:%s", prefix, err)
|
||||
}
|
||||
f.SetInt(iv)
|
||||
case reflect.Int64:
|
||||
if f.Type().String() == "time.Duration" {
|
||||
t, err := time.ParseDuration(ev)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("%s:%s", prefix, err)
|
||||
}
|
||||
f.Set(reflect.ValueOf(t))
|
||||
} else {
|
||||
iv, err := strconv.ParseInt(ev, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("%s:%s", prefix, err)
|
||||
}
|
||||
f.SetInt(iv)
|
||||
}
|
||||
case reflect.Uint:
|
||||
uiv, err := strconv.ParseUint(ev, 10, 32)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("%s:%s", prefix, err)
|
||||
}
|
||||
f.SetUint(uiv)
|
||||
case reflect.Uint64:
|
||||
uiv, err := strconv.ParseUint(ev, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("%s:%s", prefix, err)
|
||||
}
|
||||
f.SetUint(uiv)
|
||||
case reflect.Float32:
|
||||
f32, err := strconv.ParseFloat(ev, 32)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("%s:%s", prefix, err)
|
||||
}
|
||||
f.SetFloat(f32)
|
||||
case reflect.Float64:
|
||||
f64, err := strconv.ParseFloat(ev, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("%s:%s", prefix, err)
|
||||
}
|
||||
f.SetFloat(f64)
|
||||
case reflect.Bool:
|
||||
b, err := parseBool(ev)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("%s:%s", prefix, err)
|
||||
}
|
||||
f.SetBool(b)
|
||||
case reflect.Slice:
|
||||
@ -155,7 +177,7 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error {
|
||||
for i, v := range vals {
|
||||
val, err := strconv.ParseInt(v, 10, 32)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("%s:%s", prefix, err)
|
||||
}
|
||||
t[i] = int(val)
|
||||
}
|
||||
@ -164,7 +186,7 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error {
|
||||
for i, v := range vals {
|
||||
val, err := strconv.ParseInt(v, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("%s:%s", prefix, err)
|
||||
}
|
||||
t[i] = val
|
||||
}
|
||||
@ -173,7 +195,7 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error {
|
||||
for i, v := range vals {
|
||||
val, err := strconv.ParseUint(v, 10, 32)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("%s:%s", prefix, err)
|
||||
}
|
||||
t[i] = uint(val)
|
||||
}
|
||||
@ -182,7 +204,7 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error {
|
||||
for i, v := range vals {
|
||||
val, err := strconv.ParseUint(v, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("%s:%s", prefix, err)
|
||||
}
|
||||
t[i] = val
|
||||
}
|
||||
@ -191,7 +213,7 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error {
|
||||
for i, v := range vals {
|
||||
val, err := strconv.ParseFloat(v, 32)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("%s:%s", prefix, err)
|
||||
}
|
||||
t[i] = float32(val)
|
||||
}
|
||||
@ -200,7 +222,7 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error {
|
||||
for i, v := range vals {
|
||||
val, err := strconv.ParseFloat(v, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("%s:%s", prefix, err)
|
||||
}
|
||||
t[i] = val
|
||||
}
|
||||
@ -209,7 +231,7 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error {
|
||||
for i, v := range vals {
|
||||
val, err := parseBool(v)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("%s:%s", prefix, err)
|
||||
}
|
||||
t[i] = val
|
||||
}
|
||||
|
103
env_test.go
Normal file
103
env_test.go
Normal file
@ -0,0 +1,103 @@
|
||||
package env
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"os"
|
||||
"time"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
type config struct {
|
||||
App string
|
||||
Port int `default:"8000"`
|
||||
IsDebug bool `env:"DEBUG"`
|
||||
Hosts []string `slice_sep:","`
|
||||
Timeout time.Duration
|
||||
|
||||
Redis struct {
|
||||
Version string `sep:""` // no sep between `CONFIG` and `REDIS`
|
||||
Host string
|
||||
Port int
|
||||
}
|
||||
|
||||
MySQL struct {
|
||||
Version string `default:"5.7"`
|
||||
Host string
|
||||
Port int
|
||||
}
|
||||
}
|
||||
func TestGeneralEnv(t *testing.T) {
|
||||
os.Setenv("CONFIG_APP", "ENV APP")
|
||||
// os.Setenv("CONFIG_PORT", "default") // default value
|
||||
os.Setenv("CONFIG_DEBUG", "1")
|
||||
os.Setenv("CONFIG_HOSTS", "192.168.0.1,127.0.0.1")
|
||||
os.Setenv("CONFIG_TIMEOUT", "5s")
|
||||
|
||||
os.Setenv("CONFIG_REDISVERSION", "3.2")
|
||||
os.Setenv("CONFIG_REDIS_HOST", "rdb")
|
||||
os.Setenv("CONFIG_REDIS_PORT", "6379")
|
||||
|
||||
os.Setenv("CONFIG_MYSQL_HOST", "mysqldb")
|
||||
os.Setenv("CONFIG_MYSQL_PORT", "3306")
|
||||
|
||||
defer os.Clearenv()
|
||||
|
||||
cfg := new(config)
|
||||
err := Fill(cfg)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
assert := assert.New(t)
|
||||
|
||||
assert.Equal(cfg.App, "ENV APP")
|
||||
assert.Equal(cfg.Port, 8000)
|
||||
assert.Equal(cfg.IsDebug, true)
|
||||
assert.Equal(cfg.Hosts, []string{"192.168.0.1", "127.0.0.1"})
|
||||
assert.Equal(cfg.Timeout, 5*time.Second)
|
||||
assert.Equal(cfg.Redis.Version, "3.2")
|
||||
assert.Equal(cfg.Redis.Host, "rdb")
|
||||
assert.Equal(cfg.MySQL.Version, "5.7")
|
||||
assert.Equal(cfg.MySQL.Host, "mysqldb")
|
||||
assert.Equal(cfg.MySQL.Port, 3306)
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
func TestNoPrefixEnv(t *testing.T) {
|
||||
os.Setenv("APP", "ENV_APP")
|
||||
// os.Setenv("PORT", "default") // default value
|
||||
os.Setenv("DEBUG", "1")
|
||||
os.Setenv("HOSTS", "192.168.1.1,127.0.0.1")
|
||||
os.Setenv("TIMEOUT", "5s")
|
||||
|
||||
os.Setenv("REDISVERSION", "3.2")
|
||||
os.Setenv("REDIS_HOST", "rdb")
|
||||
os.Setenv("REDIS_PORT", "6379")
|
||||
|
||||
os.Setenv("MYSQL_HOST", "mysqldb")
|
||||
os.Setenv("MYSQL_PORT", "3306")
|
||||
|
||||
defer os.Clearenv()
|
||||
|
||||
cfg := new(config)
|
||||
IgnorePrefix()
|
||||
err := Fill(cfg)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
assert := assert.New(t)
|
||||
|
||||
assert.Equal(cfg.App, "ENV_APP")
|
||||
assert.Equal(cfg.Port, 8000)
|
||||
assert.Equal(cfg.IsDebug, true)
|
||||
assert.Equal(cfg.Hosts, []string{"192.168.1.1", "127.0.0.1"})
|
||||
assert.Equal(cfg.Timeout, 5*time.Second)
|
||||
assert.Equal(cfg.Redis.Version, "3.2")
|
||||
assert.Equal(cfg.Redis.Host, "rdb")
|
||||
assert.Equal(cfg.MySQL.Version, "5.7")
|
||||
assert.Equal(cfg.MySQL.Host, "mysqldb")
|
||||
assert.Equal(cfg.MySQL.Port, 3306)
|
||||
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user