Fix panic during UnmarshalJSON

This commit is contained in:
Connor Peet 2017-06-16 09:02:15 -07:00
parent 02cc386c18
commit d3bf1ae440
No known key found for this signature in database
GPG Key ID: CF8FD2EA0DBC61BD
2 changed files with 31 additions and 9 deletions

View File

@ -5,6 +5,7 @@ import (
"encoding/base64"
"encoding/binary"
"errors"
"fmt"
"strconv"
"sync"
"time"
@ -23,6 +24,13 @@ const encodeBase58Map = "123456789abcdefghijkmnopqrstuvwxyzABCDEFGHJKLMNPQRSTUVW
var decodeBase58Map [256]byte
// A JSONSyntaxError is returned from UnmarshalJSON if an invalid ID is provided.
type JSONSyntaxError struct{ original []byte }
func (j JSONSyntaxError) Error() string {
return fmt.Sprintf("invalid snowflake ID %q", string(j.original))
}
// Create a map for decoding Base58. This speeds up the process tremendously.
func init() {
@ -200,6 +208,10 @@ func (f ID) MarshalJSON() ([]byte, error) {
// UnmarshalJSON converts a json byte array of a snowflake ID into an ID type.
func (f *ID) UnmarshalJSON(b []byte) error {
if len(b) < 3 || b[0] != '"' || b[len(b)-1] != '"' {
return JSONSyntaxError{b}
}
i, err := strconv.ParseInt(string(b[1:len(b)-1]), 10, 64)
if err != nil {
return err

View File

@ -2,6 +2,7 @@ package snowflake
import (
"bytes"
"reflect"
"testing"
)
@ -28,17 +29,26 @@ func TestMarshalsIntBytes(t *testing.T) {
}
func TestUnmarshalJSON(t *testing.T) {
strID := "\"13587\""
expected := ID(13587)
var id ID
err := id.UnmarshalJSON([]byte(strID))
if err != nil {
t.Error("Unexpected error during UnmarshalJSON")
tt := []struct {
json string
expectedId ID
expectedErr error
}{
{`"13587"`, 13587, nil},
{`1`, 0, JSONSyntaxError{[]byte(`1`)}},
{`"invalid`, 0, JSONSyntaxError{[]byte(`"invalid`)}},
}
if id != expected {
t.Errorf("Got %d, expected %d", id, expected)
for _, tc := range tt {
var id ID
err := id.UnmarshalJSON([]byte(tc.json))
if !reflect.DeepEqual(err, tc.expectedErr) {
t.Errorf("Expected to get error '%s' decoding JSON, but got '%s'", tc.expectedErr, err)
}
if id != tc.expectedId {
t.Errorf("Expected to get ID '%s' decoding JSON, but got '%s'", tc.expectedId, id)
}
}
}