THRIFT-2817 Smarter buffer peeking for json protocols
Client: Go
Patch: Chi Vinh Le <cvl@chinet.info>
diff --git a/lib/go/thrift/json_protocol.go b/lib/go/thrift/json_protocol.go
index 17fe530..78348ec 100644
--- a/lib/go/thrift/json_protocol.go
+++ b/lib/go/thrift/json_protocol.go
@@ -160,7 +160,7 @@
func (p *TJSONProtocol) WriteBool(b bool) error {
if b {
return p.WriteI32(1)
- }
+ }
return p.WriteI32(0)
}
@@ -196,18 +196,23 @@
if e := p.OutputPreValue(); e != nil {
return e
}
- p.writer.Write(JSON_QUOTE_BYTES)
+ if _, e := p.writer.Write(JSON_QUOTE_BYTES); e != nil {
+ return NewTProtocolException(e)
+ }
writer := base64.NewEncoder(base64.StdEncoding, p.writer)
if _, e := writer.Write(v); e != nil {
return NewTProtocolException(e)
}
- writer.Close()
- p.writer.Write(JSON_QUOTE_BYTES)
+ if e := writer.Close(); e != nil {
+ return NewTProtocolException(e)
+ }
+ if _, e := p.writer.Write(JSON_QUOTE_BYTES); e != nil {
+ return NewTProtocolException(e)
+ }
return p.OutputPostValue()
}
// Reading methods.
-
func (p *TJSONProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) {
if isNull, err := p.ParseListBegin(); isNull || err != nil {
return name, typeId, seqId, err
@@ -250,9 +255,6 @@
}
func (p *TJSONProtocol) ReadFieldBegin() (string, TType, int16, error) {
- if p.reader.Buffered() < 1 {
- return "", STOP, -1, nil
- }
b, _ := p.reader.Peek(1)
if len(b) < 1 || b[0] == JSON_RBRACE[0] || b[0] == JSON_RBRACKET[0] {
return "", STOP, -1, nil
@@ -328,7 +330,7 @@
}
func (p *TJSONProtocol) ReadBool() (bool, error) {
- value, err := p.ReadI32();
+ value, err := p.ReadI32()
return (value != 0), err
}
@@ -362,21 +364,26 @@
if err := p.ParsePreValue(); err != nil {
return v, err
}
- b, _ := p.reader.Peek(len(JSON_NULL))
- if len(b) > 0 && b[0] == JSON_QUOTE {
+ f, _ := p.reader.Peek(1)
+ if len(f) > 0 && f[0] == JSON_QUOTE {
p.reader.ReadByte()
value, err := p.ParseStringBody()
v = value
if err != nil {
return v, err
}
- } else if len(b) >= len(JSON_NULL) && string(b[0:len(JSON_NULL)]) == string(JSON_NULL) {
- _, err := p.reader.Read(b[0:len(JSON_NULL)])
+ } else if len(f) >= 0 && f[0] == JSON_NULL[0] {
+ b := make([]byte, len(JSON_NULL))
+ _, err := p.reader.Read(b)
if err != nil {
return v, NewTProtocolException(err)
}
+ if string(b) != string(JSON_NULL) {
+ e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(b))
+ return v, NewTProtocolExceptionWithType(INVALID_DATA, e)
+ }
} else {
- e := fmt.Errorf("Expected a JSON string, found %s", string(b))
+ e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(f))
return v, NewTProtocolExceptionWithType(INVALID_DATA, e)
}
return v, p.ParsePostValue()
@@ -387,23 +394,29 @@
if err := p.ParsePreValue(); err != nil {
return nil, err
}
- b, _ := p.reader.Peek(len(JSON_NULL))
- if len(b) > 0 && b[0] == JSON_QUOTE {
+ f, _ := p.reader.Peek(1)
+ if len(f) > 0 && f[0] == JSON_QUOTE {
p.reader.ReadByte()
value, err := p.ParseBase64EncodedBody()
v = value
if err != nil {
return v, err
}
- } else if len(b) >= len(JSON_NULL) && string(b[0:len(JSON_NULL)]) == string(JSON_NULL) {
- _, err := p.reader.Read(b[0:len(JSON_NULL)])
+ } else if len(f) >= 0 && f[0] == JSON_NULL[0] {
+ b := make([]byte, len(JSON_NULL))
+ _, err := p.reader.Read(b)
if err != nil {
return v, NewTProtocolException(err)
}
+ if string(b) != string(JSON_NULL) {
+ e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(b))
+ return v, NewTProtocolExceptionWithType(INVALID_DATA, e)
+ }
} else {
- e := fmt.Errorf("Expected a JSON string, found %s", string(b))
+ e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(f))
return v, NewTProtocolExceptionWithType(INVALID_DATA, e)
}
+
return v, p.ParsePostValue()
}
diff --git a/lib/go/thrift/simple_json_protocol.go b/lib/go/thrift/simple_json_protocol.go
index 71598ac..ab4e983 100644
--- a/lib/go/thrift/simple_json_protocol.go
+++ b/lib/go/thrift/simple_json_protocol.go
@@ -69,7 +69,7 @@
trans TTransport
parseContextStack []int
- dumpContext []int
+ dumpContext []int
writer *bufio.Writer
reader *bufio.Reader
@@ -286,7 +286,6 @@
}
// Reading methods.
-
func (p *TSimpleJSONProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) {
if isNull, err := p.ParseListBegin(); isNull || err != nil {
return name, typeId, seqId, err
@@ -330,9 +329,9 @@
case JSON_QUOTE:
p.reader.ReadByte()
name, err := p.ParseStringBody()
- // simplejson is not meant to be read back into thrift
- // - see http://wiki.apache.org/thrift/ThriftUsageJava
- // - use JSON instead
+ // simplejson is not meant to be read back into thrift
+ // - see http://wiki.apache.org/thrift/ThriftUsageJava
+ // - use JSON instead
if err != nil {
return name, STOP, 0, err
}
@@ -411,15 +410,20 @@
func (p *TSimpleJSONProtocol) ReadBool() (bool, error) {
var value bool
+
if err := p.ParsePreValue(); err != nil {
return value, err
}
- b, _ := p.reader.Peek(len(JSON_TRUE))
- if len(b) > 0 {
- switch b[0] {
+ f, _ := p.reader.Peek(1)
+ if len(f) > 0 {
+ switch f[0] {
case JSON_TRUE[0]:
+ b := make([]byte, len(JSON_TRUE))
+ _, err := p.reader.Read(b)
+ if err != nil {
+ return false, NewTProtocolException(err)
+ }
if string(b) == string(JSON_TRUE) {
- p.reader.Read(b[0:len(JSON_TRUE)])
value = true
} else {
e := fmt.Errorf("Expected \"true\" but found: %s", string(b))
@@ -427,8 +431,12 @@
}
break
case JSON_FALSE[0]:
- if string(b) == string(JSON_FALSE[:len(b)]) {
- p.reader.Read(b[0:len(JSON_FALSE)])
+ b := make([]byte, len(JSON_FALSE))
+ _, err := p.reader.Read(b)
+ if err != nil {
+ return false, NewTProtocolException(err)
+ }
+ if string(b) == string(JSON_FALSE) {
value = false
} else {
e := fmt.Errorf("Expected \"false\" but found: %s", string(b))
@@ -436,15 +444,19 @@
}
break
case JSON_NULL[0]:
+ b := make([]byte, len(JSON_NULL))
+ _, err := p.reader.Read(b)
+ if err != nil {
+ return false, NewTProtocolException(err)
+ }
if string(b) == string(JSON_NULL) {
- p.reader.Read(b[0:len(JSON_NULL)])
value = false
} else {
e := fmt.Errorf("Expected \"null\" but found: %s", string(b))
return value, NewTProtocolExceptionWithType(INVALID_DATA, e)
}
default:
- e := fmt.Errorf("Expected \"true\", \"false\", or \"null\" but found: %s", string(b))
+ e := fmt.Errorf("Expected \"true\", \"false\", or \"null\" but found: %s", string(f))
return value, NewTProtocolExceptionWithType(INVALID_DATA, e)
}
}
@@ -481,22 +493,26 @@
if err := p.ParsePreValue(); err != nil {
return v, err
}
- var b []byte
- b, _ = p.reader.Peek(len(JSON_NULL))
- if len(b) > 0 && b[0] == JSON_QUOTE {
+ f, _ := p.reader.Peek(1)
+ if len(f) > 0 && f[0] == JSON_QUOTE {
p.reader.ReadByte()
value, err := p.ParseStringBody()
v = value
if err != nil {
return v, err
}
- } else if len(b) >= len(JSON_NULL) && string(b[0:len(JSON_NULL)]) == string(JSON_NULL) {
- _, err := p.reader.Read(b[0:len(JSON_NULL)])
+ } else if len(f) >= 0 && f[0] == JSON_NULL[0] {
+ b := make([]byte, len(JSON_NULL))
+ _, err := p.reader.Read(b)
if err != nil {
return v, NewTProtocolException(err)
}
+ if string(b) != string(JSON_NULL) {
+ e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(b))
+ return v, NewTProtocolExceptionWithType(INVALID_DATA, e)
+ }
} else {
- e := fmt.Errorf("Expected a JSON string, found %s", string(b))
+ e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(f))
return v, NewTProtocolExceptionWithType(INVALID_DATA, e)
}
return v, p.ParsePostValue()
@@ -507,23 +523,29 @@
if err := p.ParsePreValue(); err != nil {
return nil, err
}
- b, _ := p.reader.Peek(len(JSON_NULL))
- if len(b) > 0 && b[0] == JSON_QUOTE {
+ f, _ := p.reader.Peek(1)
+ if len(f) > 0 && f[0] == JSON_QUOTE {
p.reader.ReadByte()
value, err := p.ParseBase64EncodedBody()
v = value
if err != nil {
return v, err
}
- } else if len(b) >= len(JSON_NULL) && string(b[0:len(JSON_NULL)]) == string(JSON_NULL) {
- _, err := p.reader.Read(b[0:len(JSON_NULL)])
+ } else if len(f) >= 0 && f[0] == JSON_NULL[0] {
+ b := make([]byte, len(JSON_NULL))
+ _, err := p.reader.Read(b)
if err != nil {
return v, NewTProtocolException(err)
}
+ if string(b) != string(JSON_NULL) {
+ e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(b))
+ return v, NewTProtocolExceptionWithType(INVALID_DATA, e)
+ }
} else {
- e := fmt.Errorf("Expected a JSON string, found %s", string(b))
+ e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(f))
return v, NewTProtocolExceptionWithType(INVALID_DATA, e)
}
+
return v, p.ParsePostValue()
}
@@ -898,9 +920,8 @@
}
var value int64
var isnull bool
- b, _ := p.reader.Peek(len(JSON_NULL))
- if len(b) >= len(JSON_NULL) && string(b) == string(JSON_NULL) {
- p.reader.Read(b[0:len(JSON_NULL)])
+ if p.safePeekContains(JSON_NULL) {
+ p.reader.Read(make([]byte, len(JSON_NULL)))
isnull = true
} else {
num, err := p.readNumeric()
@@ -921,9 +942,8 @@
}
var value float64
var isnull bool
- b, _ := p.reader.Peek(len(JSON_NULL))
- if len(b) >= len(JSON_NULL) && string(b) == string(JSON_NULL) {
- p.reader.Read(b[0:len(JSON_NULL)])
+ if p.safePeekContains(JSON_NULL) {
+ p.reader.Read(make([]byte, len(JSON_NULL)))
isnull = true
} else {
num, err := p.readNumeric()
@@ -943,12 +963,15 @@
return false, err
}
var b []byte
- b, _ = p.reader.Peek(len(JSON_NULL))
+ b, err := p.reader.Peek(1)
+ if err != nil {
+ return false, err
+ }
if len(b) > 0 && b[0] == JSON_LBRACE[0] {
p.reader.ReadByte()
p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_OBJECT_FIRST))
return false, nil
- } else if len(b) >= len(JSON_NULL) && string(b[0:len(JSON_NULL)]) == string(JSON_NULL) {
+ } else if p.safePeekContains(JSON_NULL) {
return true, nil
}
e := fmt.Errorf("Expected '{' or null, but found '%s'", string(b))
@@ -986,7 +1009,7 @@
return false, e
}
var b []byte
- b, err = p.reader.Peek(len(JSON_NULL))
+ b, err = p.reader.Peek(1)
if err != nil {
return false, err
}
@@ -994,7 +1017,7 @@
p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_LIST_FIRST))
p.reader.ReadByte()
isNull = false
- } else if len(b) >= len(JSON_NULL) && string(b) == string(JSON_NULL) {
+ } else if p.safePeekContains(JSON_NULL) {
isNull = true
} else {
err = fmt.Errorf("Expected \"null\" or \"[\", received %q", b)
@@ -1038,6 +1061,9 @@
}
}
p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1]
+ if _ParseContext(p.parseContextStack[len(p.parseContextStack)-1]) == _CONTEXT_IN_TOPLEVEL {
+ return nil
+ }
return p.ParsePostValue()
}
@@ -1046,7 +1072,7 @@
if e != nil {
return nil, VOID, NewTProtocolException(e)
}
- b, e := p.reader.Peek(10)
+ b, e := p.reader.Peek(1)
if len(b) > 0 {
c := b[0]
switch c {
@@ -1135,9 +1161,8 @@
break
}
}
- b, _ := p.reader.Peek(len(JSON_NULL))
- if string(b) == string(JSON_NULL) {
- p.reader.Read(b[0:len(JSON_NULL)])
+ if p.safePeekContains(JSON_NULL) {
+ p.reader.Read(make([]byte, len(JSON_NULL)))
return true, nil
}
return false, nil
@@ -1275,3 +1300,14 @@
}
return NewNumericFromJSONString(buf.String(), false), nil
}
+
+// Safely peeks into the buffer, reading only what is necessary
+func (p *TSimpleJSONProtocol) safePeekContains(b []byte) bool {
+ for i := 0; i < len(b); i++ {
+ a, _ := p.reader.Peek(i + 1)
+ if len(a) == 0 || a[i] != b[i] {
+ return false
+ }
+ }
+ return true
+}
diff --git a/lib/go/thrift/simple_json_protocol_test.go b/lib/go/thrift/simple_json_protocol_test.go
index 87a5c64..1abff75 100644
--- a/lib/go/thrift/simple_json_protocol_test.go
+++ b/lib/go/thrift/simple_json_protocol_test.go
@@ -221,6 +221,27 @@
}
}
+func TestReadSimpleJSONProtocolI32Null(t *testing.T) {
+ thetype := "int32"
+ value := "null"
+
+ trans := NewTMemoryBuffer()
+ p := NewTSimpleJSONProtocol(trans)
+ trans.WriteString(value)
+ trans.Flush()
+ s := trans.String()
+ v, e := p.ReadI32()
+
+ if e != nil {
+ t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
+ }
+ if v != 0 {
+ t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v)
+ }
+ trans.Reset()
+ trans.Close()
+}
+
func TestWriteSimpleJSONProtocolI64(t *testing.T) {
thetype := "int64"
trans := NewTMemoryBuffer()
@@ -268,6 +289,27 @@
}
}
+func TestReadSimpleJSONProtocolI64Null(t *testing.T) {
+ thetype := "int32"
+ value := "null"
+
+ trans := NewTMemoryBuffer()
+ p := NewTSimpleJSONProtocol(trans)
+ trans.WriteString(value)
+ trans.Flush()
+ s := trans.String()
+ v, e := p.ReadI64()
+
+ if e != nil {
+ t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
+ }
+ if v != 0 {
+ t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v)
+ }
+ trans.Reset()
+ trans.Close()
+}
+
func TestWriteSimpleJSONProtocolDouble(t *testing.T) {
thetype := "double"
trans := NewTMemoryBuffer()
@@ -391,6 +433,25 @@
trans.Close()
}
}
+func TestReadSimpleJSONProtocolStringNull(t *testing.T) {
+ thetype := "string"
+ value := "null"
+
+ trans := NewTMemoryBuffer()
+ p := NewTSimpleJSONProtocol(trans)
+ trans.WriteString(value)
+ trans.Flush()
+ s := trans.String()
+ v, e := p.ReadString()
+ if e != nil {
+ t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
+ }
+ if v != "" {
+ t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v)
+ }
+ trans.Reset()
+ trans.Close()
+}
func TestWriteSimpleJSONProtocolBinary(t *testing.T) {
thetype := "binary"
@@ -448,6 +509,28 @@
trans.Close()
}
+func TestReadSimpleJSONProtocolBinaryNull(t *testing.T) {
+ thetype := "binary"
+ value := "null"
+
+ trans := NewTMemoryBuffer()
+ p := NewTSimpleJSONProtocol(trans)
+ trans.WriteString(value)
+ trans.Flush()
+ s := trans.String()
+ b, e := p.ReadBinary()
+ v := string(b)
+
+ if e != nil {
+ t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
+ }
+ if v != "" {
+ t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v)
+ }
+ trans.Reset()
+ trans.Close()
+}
+
func TestWriteSimpleJSONProtocolList(t *testing.T) {
thetype := "list"
trans := NewTMemoryBuffer()