Merge pull request #8 from mixer/fix-decode-panic
Fix panic during UnmarshalJSON
This commit is contained in:
commit
1c0147d077
12
snowflake.go
12
snowflake.go
@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -23,6 +24,13 @@ const encodeBase58Map = "123456789abcdefghijkmnopqrstuvwxyzABCDEFGHJKLMNPQRSTUVW
|
|||||||
|
|
||||||
var decodeBase58Map [256]byte
|
var decodeBase58Map [256]byte
|
||||||
|
|
||||||
|
// A JSONSyntaxError is returned from UnmarshalJSON if an invalid ID is provided.
|
||||||
|
type JSONSyntaxError struct{ original []byte }
|
||||||
|
|
||||||
|
func (j JSONSyntaxError) Error() string {
|
||||||
|
return fmt.Sprintf("invalid snowflake ID %q", string(j.original))
|
||||||
|
}
|
||||||
|
|
||||||
// Create a map for decoding Base58. This speeds up the process tremendously.
|
// Create a map for decoding Base58. This speeds up the process tremendously.
|
||||||
func init() {
|
func init() {
|
||||||
|
|
||||||
@ -200,6 +208,10 @@ func (f ID) MarshalJSON() ([]byte, error) {
|
|||||||
|
|
||||||
// UnmarshalJSON converts a json byte array of a snowflake ID into an ID type.
|
// UnmarshalJSON converts a json byte array of a snowflake ID into an ID type.
|
||||||
func (f *ID) UnmarshalJSON(b []byte) error {
|
func (f *ID) UnmarshalJSON(b []byte) error {
|
||||||
|
if len(b) < 3 || b[0] != '"' || b[len(b)-1] != '"' {
|
||||||
|
return JSONSyntaxError{b}
|
||||||
|
}
|
||||||
|
|
||||||
i, err := strconv.ParseInt(string(b[1:len(b)-1]), 10, 64)
|
i, err := strconv.ParseInt(string(b[1:len(b)-1]), 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -2,6 +2,7 @@ package snowflake
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -28,17 +29,26 @@ func TestMarshalsIntBytes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalJSON(t *testing.T) {
|
func TestUnmarshalJSON(t *testing.T) {
|
||||||
strID := "\"13587\""
|
tt := []struct {
|
||||||
expected := ID(13587)
|
json string
|
||||||
|
expectedId ID
|
||||||
var id ID
|
expectedErr error
|
||||||
err := id.UnmarshalJSON([]byte(strID))
|
}{
|
||||||
if err != nil {
|
{`"13587"`, 13587, nil},
|
||||||
t.Error("Unexpected error during UnmarshalJSON")
|
{`1`, 0, JSONSyntaxError{[]byte(`1`)}},
|
||||||
|
{`"invalid`, 0, JSONSyntaxError{[]byte(`"invalid`)}},
|
||||||
}
|
}
|
||||||
|
|
||||||
if id != expected {
|
for _, tc := range tt {
|
||||||
t.Errorf("Got %d, expected %d", id, expected)
|
var id ID
|
||||||
|
err := id.UnmarshalJSON([]byte(tc.json))
|
||||||
|
if !reflect.DeepEqual(err, tc.expectedErr) {
|
||||||
|
t.Errorf("Expected to get error '%s' decoding JSON, but got '%s'", tc.expectedErr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if id != tc.expectedId {
|
||||||
|
t.Errorf("Expected to get ID '%s' decoding JSON, but got '%s'", tc.expectedId, id)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user