add test
This commit is contained in:
parent
3d591d9186
commit
d90934c137
62
env.go
62
env.go
@ -12,17 +12,37 @@ import (
|
|||||||
var (
|
var (
|
||||||
defaultSep = "_"
|
defaultSep = "_"
|
||||||
)
|
)
|
||||||
|
var env *Env
|
||||||
|
type Env struct {
|
||||||
|
ignorePrefix bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
env = &Env{false}
|
||||||
|
}
|
||||||
|
|
||||||
func upper(v string) string {
|
func upper(v string) string {
|
||||||
return strings.ToUpper(v)
|
return strings.ToUpper(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func IgnorePrefix() {
|
||||||
|
env.ignorePrefix = true
|
||||||
|
}
|
||||||
|
|
||||||
func Fill(v interface{}) error {
|
func Fill(v interface{}) error {
|
||||||
if reflect.ValueOf(v).Kind() != reflect.Ptr {
|
return env.Fill(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Env)Fill(v interface{}) error {
|
||||||
|
ind := reflect.Indirect(reflect.ValueOf(v))
|
||||||
|
if reflect.ValueOf(v).Kind() != reflect.Ptr || ind.Kind() != reflect.Struct{
|
||||||
return fmt.Errorf("only the pointer to a struct is supported")
|
return fmt.Errorf("only the pointer to a struct is supported")
|
||||||
}
|
}
|
||||||
ind := reflect.Indirect(reflect.ValueOf(v))
|
|
||||||
prefix := upper(ind.Type().Name())
|
prefix := upper(ind.Type().Name())
|
||||||
|
if e.ignorePrefix {
|
||||||
|
prefix = ""
|
||||||
|
}
|
||||||
err := fill(prefix, ind)
|
err := fill(prefix, ind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -30,11 +50,14 @@ func Fill(v interface{}) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func combine(p, n string, v string, ok bool) string {
|
func combine(p, n string, sep string, ok bool) string {
|
||||||
|
if p == "" {
|
||||||
|
return n
|
||||||
|
}
|
||||||
if !ok {
|
if !ok {
|
||||||
return p + defaultSep + n
|
return p + defaultSep + n
|
||||||
}
|
}
|
||||||
return p + v + n
|
return p + sep + n
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseBool(v string) (bool, error) {
|
func parseBool(v string) (bool, error) {
|
||||||
@ -71,7 +94,6 @@ func fill(pf string, ind reflect.Value) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func parse(prefix string, f reflect.Value, sf reflect.StructField) error {
|
func parse(prefix string, f reflect.Value, sf reflect.StructField) error {
|
||||||
//log.Print("parse:", prefix, f.String(), f.Type().String(), f.Kind().String())
|
|
||||||
df := sf.Tag.Get("default")
|
df := sf.Tag.Get("default")
|
||||||
isRequire, err := parseBool(sf.Tag.Get("require"))
|
isRequire, err := parseBool(sf.Tag.Get("require"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -93,51 +115,51 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error {
|
|||||||
case reflect.Int:
|
case reflect.Int:
|
||||||
iv, err := strconv.ParseInt(ev, 10, 32)
|
iv, err := strconv.ParseInt(ev, 10, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("%s:%s", prefix, err)
|
||||||
}
|
}
|
||||||
f.SetInt(iv)
|
f.SetInt(iv)
|
||||||
case reflect.Int64:
|
case reflect.Int64:
|
||||||
if f.Type().String() == "time.Duration" {
|
if f.Type().String() == "time.Duration" {
|
||||||
t, err := time.ParseDuration(ev)
|
t, err := time.ParseDuration(ev)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("%s:%s", prefix, err)
|
||||||
}
|
}
|
||||||
f.Set(reflect.ValueOf(t))
|
f.Set(reflect.ValueOf(t))
|
||||||
} else {
|
} else {
|
||||||
iv, err := strconv.ParseInt(ev, 10, 64)
|
iv, err := strconv.ParseInt(ev, 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("%s:%s", prefix, err)
|
||||||
}
|
}
|
||||||
f.SetInt(iv)
|
f.SetInt(iv)
|
||||||
}
|
}
|
||||||
case reflect.Uint:
|
case reflect.Uint:
|
||||||
uiv, err := strconv.ParseUint(ev, 10, 32)
|
uiv, err := strconv.ParseUint(ev, 10, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("%s:%s", prefix, err)
|
||||||
}
|
}
|
||||||
f.SetUint(uiv)
|
f.SetUint(uiv)
|
||||||
case reflect.Uint64:
|
case reflect.Uint64:
|
||||||
uiv, err := strconv.ParseUint(ev, 10, 64)
|
uiv, err := strconv.ParseUint(ev, 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("%s:%s", prefix, err)
|
||||||
}
|
}
|
||||||
f.SetUint(uiv)
|
f.SetUint(uiv)
|
||||||
case reflect.Float32:
|
case reflect.Float32:
|
||||||
f32, err := strconv.ParseFloat(ev, 32)
|
f32, err := strconv.ParseFloat(ev, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("%s:%s", prefix, err)
|
||||||
}
|
}
|
||||||
f.SetFloat(f32)
|
f.SetFloat(f32)
|
||||||
case reflect.Float64:
|
case reflect.Float64:
|
||||||
f64, err := strconv.ParseFloat(ev, 64)
|
f64, err := strconv.ParseFloat(ev, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("%s:%s", prefix, err)
|
||||||
}
|
}
|
||||||
f.SetFloat(f64)
|
f.SetFloat(f64)
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
b, err := parseBool(ev)
|
b, err := parseBool(ev)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("%s:%s", prefix, err)
|
||||||
}
|
}
|
||||||
f.SetBool(b)
|
f.SetBool(b)
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
@ -155,7 +177,7 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error {
|
|||||||
for i, v := range vals {
|
for i, v := range vals {
|
||||||
val, err := strconv.ParseInt(v, 10, 32)
|
val, err := strconv.ParseInt(v, 10, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("%s:%s", prefix, err)
|
||||||
}
|
}
|
||||||
t[i] = int(val)
|
t[i] = int(val)
|
||||||
}
|
}
|
||||||
@ -164,7 +186,7 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error {
|
|||||||
for i, v := range vals {
|
for i, v := range vals {
|
||||||
val, err := strconv.ParseInt(v, 10, 64)
|
val, err := strconv.ParseInt(v, 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("%s:%s", prefix, err)
|
||||||
}
|
}
|
||||||
t[i] = val
|
t[i] = val
|
||||||
}
|
}
|
||||||
@ -173,7 +195,7 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error {
|
|||||||
for i, v := range vals {
|
for i, v := range vals {
|
||||||
val, err := strconv.ParseUint(v, 10, 32)
|
val, err := strconv.ParseUint(v, 10, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("%s:%s", prefix, err)
|
||||||
}
|
}
|
||||||
t[i] = uint(val)
|
t[i] = uint(val)
|
||||||
}
|
}
|
||||||
@ -182,7 +204,7 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error {
|
|||||||
for i, v := range vals {
|
for i, v := range vals {
|
||||||
val, err := strconv.ParseUint(v, 10, 64)
|
val, err := strconv.ParseUint(v, 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("%s:%s", prefix, err)
|
||||||
}
|
}
|
||||||
t[i] = val
|
t[i] = val
|
||||||
}
|
}
|
||||||
@ -191,7 +213,7 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error {
|
|||||||
for i, v := range vals {
|
for i, v := range vals {
|
||||||
val, err := strconv.ParseFloat(v, 32)
|
val, err := strconv.ParseFloat(v, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("%s:%s", prefix, err)
|
||||||
}
|
}
|
||||||
t[i] = float32(val)
|
t[i] = float32(val)
|
||||||
}
|
}
|
||||||
@ -200,7 +222,7 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error {
|
|||||||
for i, v := range vals {
|
for i, v := range vals {
|
||||||
val, err := strconv.ParseFloat(v, 64)
|
val, err := strconv.ParseFloat(v, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("%s:%s", prefix, err)
|
||||||
}
|
}
|
||||||
t[i] = val
|
t[i] = val
|
||||||
}
|
}
|
||||||
@ -209,7 +231,7 @@ func parse(prefix string, f reflect.Value, sf reflect.StructField) error {
|
|||||||
for i, v := range vals {
|
for i, v := range vals {
|
||||||
val, err := parseBool(v)
|
val, err := parseBool(v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("%s:%s", prefix, err)
|
||||||
}
|
}
|
||||||
t[i] = val
|
t[i] = val
|
||||||
}
|
}
|
||||||
|
103
env_test.go
Normal file
103
env_test.go
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
package env
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
type config struct {
|
||||||
|
App string
|
||||||
|
Port int `default:"8000"`
|
||||||
|
IsDebug bool `env:"DEBUG"`
|
||||||
|
Hosts []string `slice_sep:","`
|
||||||
|
Timeout time.Duration
|
||||||
|
|
||||||
|
Redis struct {
|
||||||
|
Version string `sep:""` // no sep between `CONFIG` and `REDIS`
|
||||||
|
Host string
|
||||||
|
Port int
|
||||||
|
}
|
||||||
|
|
||||||
|
MySQL struct {
|
||||||
|
Version string `default:"5.7"`
|
||||||
|
Host string
|
||||||
|
Port int
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func TestGeneralEnv(t *testing.T) {
|
||||||
|
os.Setenv("CONFIG_APP", "ENV APP")
|
||||||
|
// os.Setenv("CONFIG_PORT", "default") // default value
|
||||||
|
os.Setenv("CONFIG_DEBUG", "1")
|
||||||
|
os.Setenv("CONFIG_HOSTS", "192.168.0.1,127.0.0.1")
|
||||||
|
os.Setenv("CONFIG_TIMEOUT", "5s")
|
||||||
|
|
||||||
|
os.Setenv("CONFIG_REDISVERSION", "3.2")
|
||||||
|
os.Setenv("CONFIG_REDIS_HOST", "rdb")
|
||||||
|
os.Setenv("CONFIG_REDIS_PORT", "6379")
|
||||||
|
|
||||||
|
os.Setenv("CONFIG_MYSQL_HOST", "mysqldb")
|
||||||
|
os.Setenv("CONFIG_MYSQL_PORT", "3306")
|
||||||
|
|
||||||
|
defer os.Clearenv()
|
||||||
|
|
||||||
|
cfg := new(config)
|
||||||
|
err := Fill(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert := assert.New(t)
|
||||||
|
|
||||||
|
assert.Equal(cfg.App, "ENV APP")
|
||||||
|
assert.Equal(cfg.Port, 8000)
|
||||||
|
assert.Equal(cfg.IsDebug, true)
|
||||||
|
assert.Equal(cfg.Hosts, []string{"192.168.0.1", "127.0.0.1"})
|
||||||
|
assert.Equal(cfg.Timeout, 5*time.Second)
|
||||||
|
assert.Equal(cfg.Redis.Version, "3.2")
|
||||||
|
assert.Equal(cfg.Redis.Host, "rdb")
|
||||||
|
assert.Equal(cfg.MySQL.Version, "5.7")
|
||||||
|
assert.Equal(cfg.MySQL.Host, "mysqldb")
|
||||||
|
assert.Equal(cfg.MySQL.Port, 3306)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
func TestNoPrefixEnv(t *testing.T) {
|
||||||
|
os.Setenv("APP", "ENV_APP")
|
||||||
|
// os.Setenv("PORT", "default") // default value
|
||||||
|
os.Setenv("DEBUG", "1")
|
||||||
|
os.Setenv("HOSTS", "192.168.1.1,127.0.0.1")
|
||||||
|
os.Setenv("TIMEOUT", "5s")
|
||||||
|
|
||||||
|
os.Setenv("REDISVERSION", "3.2")
|
||||||
|
os.Setenv("REDIS_HOST", "rdb")
|
||||||
|
os.Setenv("REDIS_PORT", "6379")
|
||||||
|
|
||||||
|
os.Setenv("MYSQL_HOST", "mysqldb")
|
||||||
|
os.Setenv("MYSQL_PORT", "3306")
|
||||||
|
|
||||||
|
defer os.Clearenv()
|
||||||
|
|
||||||
|
cfg := new(config)
|
||||||
|
IgnorePrefix()
|
||||||
|
err := Fill(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert := assert.New(t)
|
||||||
|
|
||||||
|
assert.Equal(cfg.App, "ENV_APP")
|
||||||
|
assert.Equal(cfg.Port, 8000)
|
||||||
|
assert.Equal(cfg.IsDebug, true)
|
||||||
|
assert.Equal(cfg.Hosts, []string{"192.168.1.1", "127.0.0.1"})
|
||||||
|
assert.Equal(cfg.Timeout, 5*time.Second)
|
||||||
|
assert.Equal(cfg.Redis.Version, "3.2")
|
||||||
|
assert.Equal(cfg.Redis.Host, "rdb")
|
||||||
|
assert.Equal(cfg.MySQL.Version, "5.7")
|
||||||
|
assert.Equal(cfg.MySQL.Host, "mysqldb")
|
||||||
|
assert.Equal(cfg.MySQL.Port, 3306)
|
||||||
|
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user