diff --git a/snowflake.go b/snowflake.go index 859a648..bdd68bd 100644 --- a/snowflake.go +++ b/snowflake.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "encoding/binary" "errors" + "fmt" "strconv" "sync" "time" @@ -23,6 +24,13 @@ const encodeBase58Map = "123456789abcdefghijkmnopqrstuvwxyzABCDEFGHJKLMNPQRSTUVW var decodeBase58Map [256]byte +// A JSONSyntaxError is returned from UnmarshalJSON if an invalid ID is provided. +type JSONSyntaxError struct{ original []byte } + +func (j JSONSyntaxError) Error() string { + return fmt.Sprintf("invalid snowflake ID %q", string(j.original)) +} + // Create a map for decoding Base58. This speeds up the process tremendously. func init() { @@ -200,6 +208,10 @@ func (f ID) MarshalJSON() ([]byte, error) { // UnmarshalJSON converts a json byte array of a snowflake ID into an ID type. func (f *ID) UnmarshalJSON(b []byte) error { + if len(b) < 3 || b[0] != '"' || b[len(b)-1] != '"' { + return JSONSyntaxError{b} + } + i, err := strconv.ParseInt(string(b[1:len(b)-1]), 10, 64) if err != nil { return err diff --git a/snowflake_test.go b/snowflake_test.go index f2ba51d..d185add 100644 --- a/snowflake_test.go +++ b/snowflake_test.go @@ -2,6 +2,7 @@ package snowflake import ( "bytes" + "reflect" "testing" ) @@ -28,17 +29,26 @@ func TestMarshalsIntBytes(t *testing.T) { } func TestUnmarshalJSON(t *testing.T) { - strID := "\"13587\"" - expected := ID(13587) - - var id ID - err := id.UnmarshalJSON([]byte(strID)) - if err != nil { - t.Error("Unexpected error during UnmarshalJSON") + tt := []struct { + json string + expectedId ID + expectedErr error + }{ + {`"13587"`, 13587, nil}, + {`1`, 0, JSONSyntaxError{[]byte(`1`)}}, + {`"invalid`, 0, JSONSyntaxError{[]byte(`"invalid`)}}, } - if id != expected { - t.Errorf("Got %d, expected %d", id, expected) + for _, tc := range tt { + var id ID + err := id.UnmarshalJSON([]byte(tc.json)) + if !reflect.DeepEqual(err, tc.expectedErr) { + t.Errorf("Expected to get error '%s' decoding JSON, but got '%s'", tc.expectedErr, err) + } + + if id != tc.expectedId { + t.Errorf("Expected to get ID '%s' decoding JSON, but got '%s'", tc.expectedId, id) + } } }