Merge pull request #8 from mixer/fix-decode-panic
Fix panic during UnmarshalJSON
This commit is contained in:
		
							
								
								
									
										12
									
								
								snowflake.go
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								snowflake.go
									
									
									
									
									
								
							| @@ -5,6 +5,7 @@ import ( | |||||||
| 	"encoding/base64" | 	"encoding/base64" | ||||||
| 	"encoding/binary" | 	"encoding/binary" | ||||||
| 	"errors" | 	"errors" | ||||||
|  | 	"fmt" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| @@ -23,6 +24,13 @@ const encodeBase58Map = "123456789abcdefghijkmnopqrstuvwxyzABCDEFGHJKLMNPQRSTUVW | |||||||
|  |  | ||||||
| var decodeBase58Map [256]byte | var decodeBase58Map [256]byte | ||||||
|  |  | ||||||
|  | // A JSONSyntaxError is returned from UnmarshalJSON if an invalid ID is provided. | ||||||
|  | type JSONSyntaxError struct{ original []byte } | ||||||
|  |  | ||||||
|  | func (j JSONSyntaxError) Error() string { | ||||||
|  | 	return fmt.Sprintf("invalid snowflake ID %q", string(j.original)) | ||||||
|  | } | ||||||
|  |  | ||||||
| // Create a map for decoding Base58.  This speeds up the process tremendously. | // Create a map for decoding Base58.  This speeds up the process tremendously. | ||||||
| func init() { | func init() { | ||||||
|  |  | ||||||
| @@ -200,6 +208,10 @@ func (f ID) MarshalJSON() ([]byte, error) { | |||||||
|  |  | ||||||
| // UnmarshalJSON converts a json byte array of a snowflake ID into an ID type. | // UnmarshalJSON converts a json byte array of a snowflake ID into an ID type. | ||||||
| func (f *ID) UnmarshalJSON(b []byte) error { | func (f *ID) UnmarshalJSON(b []byte) error { | ||||||
|  | 	if len(b) < 3 || b[0] != '"' || b[len(b)-1] != '"' { | ||||||
|  | 		return JSONSyntaxError{b} | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	i, err := strconv.ParseInt(string(b[1:len(b)-1]), 10, 64) | 	i, err := strconv.ParseInt(string(b[1:len(b)-1]), 10, 64) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
|   | |||||||
| @@ -2,6 +2,7 @@ package snowflake | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
|  | 	"reflect" | ||||||
| 	"testing" | 	"testing" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -28,17 +29,26 @@ func TestMarshalsIntBytes(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestUnmarshalJSON(t *testing.T) { | func TestUnmarshalJSON(t *testing.T) { | ||||||
| 	strID := "\"13587\"" | 	tt := []struct { | ||||||
| 	expected := ID(13587) | 		json        string | ||||||
|  | 		expectedId  ID | ||||||
| 	var id ID | 		expectedErr error | ||||||
| 	err := id.UnmarshalJSON([]byte(strID)) | 	}{ | ||||||
| 	if err != nil { | 		{`"13587"`, 13587, nil}, | ||||||
| 		t.Error("Unexpected error during UnmarshalJSON") | 		{`1`, 0, JSONSyntaxError{[]byte(`1`)}}, | ||||||
|  | 		{`"invalid`, 0, JSONSyntaxError{[]byte(`"invalid`)}}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if id != expected { | 	for _, tc := range tt { | ||||||
| 		t.Errorf("Got %d, expected %d", id, expected) | 		var id ID | ||||||
|  | 		err := id.UnmarshalJSON([]byte(tc.json)) | ||||||
|  | 		if !reflect.DeepEqual(err, tc.expectedErr) { | ||||||
|  | 			t.Errorf("Expected to get error '%s' decoding JSON, but got '%s'", tc.expectedErr, err) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if id != tc.expectedId { | ||||||
|  | 			t.Errorf("Expected to get ID '%s' decoding JSON, but got '%s'", tc.expectedId, id) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user