THRIFT-2174 Deserializing JSON fails in specific cases
Patch: Jens Geyer
diff --git a/lib/go/thrift/json_protocol.go b/lib/go/thrift/json_protocol.go
index 5e8453a..957d8ed 100644
--- a/lib/go/thrift/json_protocol.go
+++ b/lib/go/thrift/json_protocol.go
@@ -100,7 +100,11 @@
if e := p.OutputObjectBegin(); e != nil {
return e
}
- if e := p.WriteString(p.TypeIdToString(typeId)); e != nil {
+ s, e1 := p.TypeIdToString(typeId)
+ if e1 != nil {
+ return e1
+ }
+ if e := p.WriteString(s); e != nil {
return e
}
return nil
@@ -116,10 +120,18 @@
if e := p.OutputListBegin(); e != nil {
return e
}
- if e := p.WriteString(p.TypeIdToString(keyType)); e != nil {
+ s, e1 := p.TypeIdToString(keyType)
+ if e1 != nil {
+ return e1
+ }
+ if e := p.WriteString(s); e != nil {
return e
}
- if e := p.WriteString(p.TypeIdToString(valueType)); e != nil {
+ s, e1 = p.TypeIdToString(valueType)
+ if e1 != nil {
+ return e1
+ }
+ if e := p.WriteString(s); e != nil {
return e
}
return p.WriteI64(int64(size))
@@ -250,7 +262,10 @@
return "", STOP, fieldId, err
}
sType, err := p.ReadString()
- fType := p.StringToTypeId(sType)
+ if err != nil {
+ return "", STOP, fieldId, err
+ }
+ fType, err := p.StringToTypeId(sType)
return "", fType, fieldId, err
}
@@ -265,14 +280,20 @@
// read keyType
sKeyType, e := p.ReadString()
- keyType = p.StringToTypeId(sKeyType)
+ if e != nil {
+ return keyType, valueType, size, e
+ }
+ keyType, e = p.StringToTypeId(sKeyType)
if e != nil {
return keyType, valueType, size, e
}
// read valueType
sValueType, e := p.ReadString()
- valueType = p.StringToTypeId(sValueType)
+ if e != nil {
+ return keyType, valueType, size, e
+ }
+ valueType, e = p.StringToTypeId(sValueType)
if e != nil {
return keyType, valueType, size, e
}
@@ -436,7 +457,11 @@
if e := p.OutputListBegin(); e != nil {
return e
}
- if e := p.WriteString(p.TypeIdToString(elemType)); e != nil {
+ s, e1 := p.TypeIdToString(elemType)
+ if e1 != nil {
+ return e1
+ }
+ if e := p.WriteString(s); e != nil {
return e
}
if e := p.WriteI64(int64(size)); e != nil {
@@ -445,13 +470,15 @@
return nil
}
-
func (p *TJSONProtocol) ParseElemListBegin() (elemType TType, size int, e error) {
if isNull, e := p.ParseListBegin(); isNull || e != nil {
return VOID, 0, e
}
sElemType, err := p.ReadString()
- elemType = p.StringToTypeId(sElemType)
+ if err != nil {
+ return VOID, size, err
+ }
+ elemType, err = p.StringToTypeId(sElemType)
if err != nil {
return elemType, size, err
}
@@ -465,7 +492,10 @@
return VOID, 0, e
}
sElemType, err := p.ReadString()
- elemType = p.StringToTypeId(sElemType)
+ if err != nil {
+ return VOID, size, err
+ }
+ elemType, err = p.StringToTypeId(sElemType)
if err != nil {
return elemType, size, err
}
@@ -478,7 +508,11 @@
if e := p.OutputListBegin(); e != nil {
return e
}
- if e := p.OutputString(p.TypeIdToString(elemType)); e != nil {
+ s, e1 := p.TypeIdToString(elemType)
+ if e1 != nil {
+ return e1
+ }
+ if e := p.OutputString(s); e != nil {
return e
}
if e := p.OutputI64(int64(size)); e != nil {
@@ -487,70 +521,62 @@
return nil
}
-func (p *TJSONProtocol) TypeIdToString(fieldType TType) string {
+func (p *TJSONProtocol) TypeIdToString(fieldType TType) (string, error) {
switch byte(fieldType) {
- case STOP:
- return "stp"
- case VOID:
- return "v"
case BOOL:
- return "tf"
+ return "tf", nil
case BYTE:
- return "i8"
- case DOUBLE:
- return "dbl"
+ return "i8", nil
case I16:
- return "i16"
+ return "i16", nil
case I32:
- return "i32"
+ return "i32", nil
case I64:
- return "i64"
+ return "i64", nil
+ case DOUBLE:
+ return "dbl", nil
case STRING:
- return "str"
+ return "str", nil
case STRUCT:
- return "rec"
+ return "rec", nil
case MAP:
- return "map"
+ return "map", nil
case SET:
- return "set"
+ return "set", nil
case LIST:
- return "lst"
- case UTF16:
- return "str"
+ return "lst", nil
}
- return ""
+
+ e := fmt.Errorf("Unknown fieldType: %d", int(fieldType))
+ return "", NewTProtocolExceptionWithType(INVALID_DATA, e)
}
-func (p *TJSONProtocol) StringToTypeId(fieldType string) TType {
+func (p *TJSONProtocol) StringToTypeId(fieldType string) (TType, error) {
switch fieldType {
- case "stp":
- return TType(STOP)
- case "v":
- return TType(VOID)
case "tf":
- return TType(BOOL)
+ return TType(BOOL), nil
case "i8":
- return TType(BYTE)
- case "dbl":
- return TType(DOUBLE)
- case "16":
- return TType(I16)
+ return TType(BYTE), nil
+ case "i16":
+ return TType(I16), nil
case "i32":
- return TType(I32)
+ return TType(I32), nil
case "i64":
- return TType(I64)
+ return TType(I64), nil
+ case "dbl":
+ return TType(DOUBLE), nil
case "str":
- return TType(STRING)
+ return TType(STRING), nil
case "rec":
- return TType(STRUCT)
+ return TType(STRUCT), nil
case "map":
- return TType(MAP)
+ return TType(MAP), nil
case "set":
- return TType(SET)
+ return TType(SET), nil
case "lst":
- return TType(LIST)
- case "u16":
- return TType(UTF16)
+ return TType(LIST), nil
}
- return TType(STOP)
+
+ e := fmt.Errorf("Unknown type identifier: %s", fieldType)
+ return TType(STOP), NewTProtocolExceptionWithType(INVALID_DATA, e)
}
diff --git a/lib/go/thrift/simple_json_protocol.go b/lib/go/thrift/simple_json_protocol.go
index 9d0f68f..3755a2d 100644
--- a/lib/go/thrift/simple_json_protocol.go
+++ b/lib/go/thrift/simple_json_protocol.go
@@ -322,9 +322,6 @@
if err := p.ParsePreValue(); err != nil {
return "", STOP, 0, err
}
- if p.reader.Buffered() < 1 {
- return "", STOP, 0, nil
- }
b, _ := p.reader.Peek(1)
if len(b) > 0 {
switch b[0] {
@@ -482,11 +479,7 @@
return v, err
}
var b []byte
- if p.reader.Buffered() >= len(JSON_NULL) {
- b, _ = p.reader.Peek(len(JSON_NULL))
- } else {
- b, _ = p.reader.Peek(1)
- }
+ b, _ = p.reader.Peek(len(JSON_NULL))
if len(b) > 0 && b[0] == JSON_QUOTE {
p.reader.ReadByte()
value, err := p.ParseStringBody()
@@ -732,9 +725,6 @@
return NewTProtocolException(e)
}
cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1])
- if p.reader.Buffered() < 1 {
- return nil
- }
b, _ := p.reader.Peek(1)
switch cxt {
case _CONTEXT_IN_LIST:
@@ -813,7 +803,7 @@
}
func (p *TSimpleJSONProtocol) readNonSignificantWhitespace() error {
- for p.reader.Buffered() > 0 {
+ for {
b, _ := p.reader.Peek(1)
if len(b) < 1 {
return nil
@@ -950,11 +940,7 @@
return false, err
}
var b []byte
- if p.reader.Buffered() >= len(JSON_NULL) {
- b, _ = p.reader.Peek(len(JSON_NULL))
- } else if p.reader.Buffered() >= 1 {
- b, _ = p.reader.Peek(1)
- }
+ b, _ = p.reader.Peek(len(JSON_NULL))
if len(b) > 0 && b[0] == JSON_LBRACE[0] {
p.reader.ReadByte()
p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_OBJECT_FIRST))
@@ -997,11 +983,7 @@
return false, e
}
var b []byte
- if p.reader.Buffered() >= len(JSON_NULL) {
- b, err = p.reader.Peek(len(JSON_NULL))
- } else {
- b, err = p.reader.Peek(1)
- }
+ b, err = p.reader.Peek(len(JSON_NULL))
if err != nil {
return false, err
}
@@ -1134,7 +1116,7 @@
func (p *TSimpleJSONProtocol) readIfNull() (bool, error) {
cont := true
- for p.reader.Buffered() > 0 && cont {
+ for cont {
b, _ := p.reader.Peek(1)
if len(b) < 1 {
return false, nil
@@ -1150,9 +1132,6 @@
break
}
}
- if p.reader.Buffered() == 0 {
- return false, nil
- }
b, _ := p.reader.Peek(len(JSON_NULL))
if string(b) == string(JSON_NULL) {
p.reader.Read(b[0:len(JSON_NULL)])
@@ -1162,9 +1141,6 @@
}
func (p *TSimpleJSONProtocol) readQuoteIfNext() {
- if p.reader.Buffered() < 1 {
- return
- }
b, _ := p.reader.Peek(1)
if len(b) > 0 && b[0] == JSON_QUOTE {
p.reader.ReadByte()