THRIFT-2817 Smarter buffer peeking for json protocols
Client: Go
Patch: Chi Vinh Le <cvl@chinet.info>
diff --git a/lib/go/thrift/json_protocol.go b/lib/go/thrift/json_protocol.go
index 17fe530..78348ec 100644
--- a/lib/go/thrift/json_protocol.go
+++ b/lib/go/thrift/json_protocol.go
@@ -160,7 +160,7 @@
 func (p *TJSONProtocol) WriteBool(b bool) error {
 	if b {
 		return p.WriteI32(1)
-    }
+	}
 	return p.WriteI32(0)
 }
 
@@ -196,18 +196,23 @@
 	if e := p.OutputPreValue(); e != nil {
 		return e
 	}
-	p.writer.Write(JSON_QUOTE_BYTES)
+	if _, e := p.writer.Write(JSON_QUOTE_BYTES); e != nil {
+		return NewTProtocolException(e)
+	}
 	writer := base64.NewEncoder(base64.StdEncoding, p.writer)
 	if _, e := writer.Write(v); e != nil {
 		return NewTProtocolException(e)
 	}
-	writer.Close()
-	p.writer.Write(JSON_QUOTE_BYTES)
+	if e := writer.Close(); e != nil {
+		return NewTProtocolException(e)
+	}
+	if _, e := p.writer.Write(JSON_QUOTE_BYTES); e != nil {
+		return NewTProtocolException(e)
+	}
 	return p.OutputPostValue()
 }
 
 // Reading methods.
-
 func (p *TJSONProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) {
 	if isNull, err := p.ParseListBegin(); isNull || err != nil {
 		return name, typeId, seqId, err
@@ -250,9 +255,6 @@
 }
 
 func (p *TJSONProtocol) ReadFieldBegin() (string, TType, int16, error) {
-	if p.reader.Buffered() < 1 {
-		return "", STOP, -1, nil
-	}
 	b, _ := p.reader.Peek(1)
 	if len(b) < 1 || b[0] == JSON_RBRACE[0] || b[0] == JSON_RBRACKET[0] {
 		return "", STOP, -1, nil
@@ -328,7 +330,7 @@
 }
 
 func (p *TJSONProtocol) ReadBool() (bool, error) {
-	value, err := p.ReadI32(); 
+	value, err := p.ReadI32()
 	return (value != 0), err
 }
 
@@ -362,21 +364,26 @@
 	if err := p.ParsePreValue(); err != nil {
 		return v, err
 	}
-	b, _ := p.reader.Peek(len(JSON_NULL))
-	if len(b) > 0 && b[0] == JSON_QUOTE {
+	f, _ := p.reader.Peek(1)
+	if len(f) > 0 && f[0] == JSON_QUOTE {
 		p.reader.ReadByte()
 		value, err := p.ParseStringBody()
 		v = value
 		if err != nil {
 			return v, err
 		}
-	} else if len(b) >= len(JSON_NULL) && string(b[0:len(JSON_NULL)]) == string(JSON_NULL) {
-		_, err := p.reader.Read(b[0:len(JSON_NULL)])
+	} else if len(f) >= 0 && f[0] == JSON_NULL[0] {
+		b := make([]byte, len(JSON_NULL))
+		_, err := p.reader.Read(b)
 		if err != nil {
 			return v, NewTProtocolException(err)
 		}
+		if string(b) != string(JSON_NULL) {
+			e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(b))
+			return v, NewTProtocolExceptionWithType(INVALID_DATA, e)
+		}
 	} else {
-		e := fmt.Errorf("Expected a JSON string, found %s", string(b))
+		e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(f))
 		return v, NewTProtocolExceptionWithType(INVALID_DATA, e)
 	}
 	return v, p.ParsePostValue()
@@ -387,23 +394,29 @@
 	if err := p.ParsePreValue(); err != nil {
 		return nil, err
 	}
-	b, _ := p.reader.Peek(len(JSON_NULL))
-	if len(b) > 0 && b[0] == JSON_QUOTE {
+	f, _ := p.reader.Peek(1)
+	if len(f) > 0 && f[0] == JSON_QUOTE {
 		p.reader.ReadByte()
 		value, err := p.ParseBase64EncodedBody()
 		v = value
 		if err != nil {
 			return v, err
 		}
-	} else if len(b) >= len(JSON_NULL) && string(b[0:len(JSON_NULL)]) == string(JSON_NULL) {
-		_, err := p.reader.Read(b[0:len(JSON_NULL)])
+	} else if len(f) >= 0 && f[0] == JSON_NULL[0] {
+		b := make([]byte, len(JSON_NULL))
+		_, err := p.reader.Read(b)
 		if err != nil {
 			return v, NewTProtocolException(err)
 		}
+		if string(b) != string(JSON_NULL) {
+			e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(b))
+			return v, NewTProtocolExceptionWithType(INVALID_DATA, e)
+		}
 	} else {
-		e := fmt.Errorf("Expected a JSON string, found %s", string(b))
+		e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(f))
 		return v, NewTProtocolExceptionWithType(INVALID_DATA, e)
 	}
+
 	return v, p.ParsePostValue()
 }
 
diff --git a/lib/go/thrift/simple_json_protocol.go b/lib/go/thrift/simple_json_protocol.go
index 71598ac..ab4e983 100644
--- a/lib/go/thrift/simple_json_protocol.go
+++ b/lib/go/thrift/simple_json_protocol.go
@@ -69,7 +69,7 @@
 	trans TTransport
 
 	parseContextStack []int
-	dumpContext []int
+	dumpContext       []int
 
 	writer *bufio.Writer
 	reader *bufio.Reader
@@ -286,7 +286,6 @@
 }
 
 // Reading methods.
-
 func (p *TSimpleJSONProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) {
 	if isNull, err := p.ParseListBegin(); isNull || err != nil {
 		return name, typeId, seqId, err
@@ -330,9 +329,9 @@
 		case JSON_QUOTE:
 			p.reader.ReadByte()
 			name, err := p.ParseStringBody()
-            // simplejson is not meant to be read back into thrift 
-            // - see http://wiki.apache.org/thrift/ThriftUsageJava
-            // - use JSON instead
+			// simplejson is not meant to be read back into thrift
+			// - see http://wiki.apache.org/thrift/ThriftUsageJava
+			// - use JSON instead
 			if err != nil {
 				return name, STOP, 0, err
 			}
@@ -411,15 +410,20 @@
 
 func (p *TSimpleJSONProtocol) ReadBool() (bool, error) {
 	var value bool
+
 	if err := p.ParsePreValue(); err != nil {
 		return value, err
 	}
-	b, _ := p.reader.Peek(len(JSON_TRUE))
-	if len(b) > 0 {
-		switch b[0] {
+	f, _ := p.reader.Peek(1)
+	if len(f) > 0 {
+		switch f[0] {
 		case JSON_TRUE[0]:
+			b := make([]byte, len(JSON_TRUE))
+			_, err := p.reader.Read(b)
+			if err != nil {
+				return false, NewTProtocolException(err)
+			}
 			if string(b) == string(JSON_TRUE) {
-				p.reader.Read(b[0:len(JSON_TRUE)])
 				value = true
 			} else {
 				e := fmt.Errorf("Expected \"true\" but found: %s", string(b))
@@ -427,8 +431,12 @@
 			}
 			break
 		case JSON_FALSE[0]:
-			if string(b) == string(JSON_FALSE[:len(b)]) {
-				p.reader.Read(b[0:len(JSON_FALSE)])
+			b := make([]byte, len(JSON_FALSE))
+			_, err := p.reader.Read(b)
+			if err != nil {
+				return false, NewTProtocolException(err)
+			}
+			if string(b) == string(JSON_FALSE) {
 				value = false
 			} else {
 				e := fmt.Errorf("Expected \"false\" but found: %s", string(b))
@@ -436,15 +444,19 @@
 			}
 			break
 		case JSON_NULL[0]:
+			b := make([]byte, len(JSON_NULL))
+			_, err := p.reader.Read(b)
+			if err != nil {
+				return false, NewTProtocolException(err)
+			}
 			if string(b) == string(JSON_NULL) {
-				p.reader.Read(b[0:len(JSON_NULL)])
 				value = false
 			} else {
 				e := fmt.Errorf("Expected \"null\" but found: %s", string(b))
 				return value, NewTProtocolExceptionWithType(INVALID_DATA, e)
 			}
 		default:
-			e := fmt.Errorf("Expected \"true\", \"false\", or \"null\" but found: %s", string(b))
+			e := fmt.Errorf("Expected \"true\", \"false\", or \"null\" but found: %s", string(f))
 			return value, NewTProtocolExceptionWithType(INVALID_DATA, e)
 		}
 	}
@@ -481,22 +493,26 @@
 	if err := p.ParsePreValue(); err != nil {
 		return v, err
 	}
-	var b []byte
-	b, _ = p.reader.Peek(len(JSON_NULL))
-	if len(b) > 0 && b[0] == JSON_QUOTE {
+	f, _ := p.reader.Peek(1)
+	if len(f) > 0 && f[0] == JSON_QUOTE {
 		p.reader.ReadByte()
 		value, err := p.ParseStringBody()
 		v = value
 		if err != nil {
 			return v, err
 		}
-	} else if len(b) >= len(JSON_NULL) && string(b[0:len(JSON_NULL)]) == string(JSON_NULL) {
-		_, err := p.reader.Read(b[0:len(JSON_NULL)])
+	} else if len(f) >= 0 && f[0] == JSON_NULL[0] {
+		b := make([]byte, len(JSON_NULL))
+		_, err := p.reader.Read(b)
 		if err != nil {
 			return v, NewTProtocolException(err)
 		}
+		if string(b) != string(JSON_NULL) {
+			e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(b))
+			return v, NewTProtocolExceptionWithType(INVALID_DATA, e)
+		}
 	} else {
-		e := fmt.Errorf("Expected a JSON string, found %s", string(b))
+		e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(f))
 		return v, NewTProtocolExceptionWithType(INVALID_DATA, e)
 	}
 	return v, p.ParsePostValue()
@@ -507,23 +523,29 @@
 	if err := p.ParsePreValue(); err != nil {
 		return nil, err
 	}
-	b, _ := p.reader.Peek(len(JSON_NULL))
-	if len(b) > 0 && b[0] == JSON_QUOTE {
+	f, _ := p.reader.Peek(1)
+	if len(f) > 0 && f[0] == JSON_QUOTE {
 		p.reader.ReadByte()
 		value, err := p.ParseBase64EncodedBody()
 		v = value
 		if err != nil {
 			return v, err
 		}
-	} else if len(b) >= len(JSON_NULL) && string(b[0:len(JSON_NULL)]) == string(JSON_NULL) {
-		_, err := p.reader.Read(b[0:len(JSON_NULL)])
+	} else if len(f) >= 0 && f[0] == JSON_NULL[0] {
+		b := make([]byte, len(JSON_NULL))
+		_, err := p.reader.Read(b)
 		if err != nil {
 			return v, NewTProtocolException(err)
 		}
+		if string(b) != string(JSON_NULL) {
+			e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(b))
+			return v, NewTProtocolExceptionWithType(INVALID_DATA, e)
+		}
 	} else {
-		e := fmt.Errorf("Expected a JSON string, found %s", string(b))
+		e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(f))
 		return v, NewTProtocolExceptionWithType(INVALID_DATA, e)
 	}
+
 	return v, p.ParsePostValue()
 }
 
@@ -898,9 +920,8 @@
 	}
 	var value int64
 	var isnull bool
-	b, _ := p.reader.Peek(len(JSON_NULL))
-	if len(b) >= len(JSON_NULL) && string(b) == string(JSON_NULL) {
-		p.reader.Read(b[0:len(JSON_NULL)])
+	if p.safePeekContains(JSON_NULL) {
+		p.reader.Read(make([]byte, len(JSON_NULL)))
 		isnull = true
 	} else {
 		num, err := p.readNumeric()
@@ -921,9 +942,8 @@
 	}
 	var value float64
 	var isnull bool
-	b, _ := p.reader.Peek(len(JSON_NULL))
-	if len(b) >= len(JSON_NULL) && string(b) == string(JSON_NULL) {
-		p.reader.Read(b[0:len(JSON_NULL)])
+	if p.safePeekContains(JSON_NULL) {
+		p.reader.Read(make([]byte, len(JSON_NULL)))
 		isnull = true
 	} else {
 		num, err := p.readNumeric()
@@ -943,12 +963,15 @@
 		return false, err
 	}
 	var b []byte
-	b, _ = p.reader.Peek(len(JSON_NULL))
+	b, err := p.reader.Peek(1)
+	if err != nil {
+		return false, err
+	}
 	if len(b) > 0 && b[0] == JSON_LBRACE[0] {
 		p.reader.ReadByte()
 		p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_OBJECT_FIRST))
 		return false, nil
-	} else if len(b) >= len(JSON_NULL) && string(b[0:len(JSON_NULL)]) == string(JSON_NULL) {
+	} else if p.safePeekContains(JSON_NULL) {
 		return true, nil
 	}
 	e := fmt.Errorf("Expected '{' or null, but found '%s'", string(b))
@@ -986,7 +1009,7 @@
 		return false, e
 	}
 	var b []byte
-	b, err = p.reader.Peek(len(JSON_NULL))
+	b, err = p.reader.Peek(1)
 	if err != nil {
 		return false, err
 	}
@@ -994,7 +1017,7 @@
 		p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_LIST_FIRST))
 		p.reader.ReadByte()
 		isNull = false
-	} else if len(b) >= len(JSON_NULL) && string(b) == string(JSON_NULL) {
+	} else if p.safePeekContains(JSON_NULL) {
 		isNull = true
 	} else {
 		err = fmt.Errorf("Expected \"null\" or \"[\", received %q", b)
@@ -1038,6 +1061,9 @@
 		}
 	}
 	p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1]
+	if _ParseContext(p.parseContextStack[len(p.parseContextStack)-1]) == _CONTEXT_IN_TOPLEVEL {
+		return nil
+	}
 	return p.ParsePostValue()
 }
 
@@ -1046,7 +1072,7 @@
 	if e != nil {
 		return nil, VOID, NewTProtocolException(e)
 	}
-	b, e := p.reader.Peek(10)
+	b, e := p.reader.Peek(1)
 	if len(b) > 0 {
 		c := b[0]
 		switch c {
@@ -1135,9 +1161,8 @@
 			break
 		}
 	}
-	b, _ := p.reader.Peek(len(JSON_NULL))
-	if string(b) == string(JSON_NULL) {
-		p.reader.Read(b[0:len(JSON_NULL)])
+	if p.safePeekContains(JSON_NULL) {
+		p.reader.Read(make([]byte, len(JSON_NULL)))
 		return true, nil
 	}
 	return false, nil
@@ -1275,3 +1300,14 @@
 	}
 	return NewNumericFromJSONString(buf.String(), false), nil
 }
+
+// Safely peeks into the buffer, reading only what is necessary
+func (p *TSimpleJSONProtocol) safePeekContains(b []byte) bool {
+	for i := 0; i < len(b); i++ {
+		a, _ := p.reader.Peek(i + 1)
+		if len(a) == 0 || a[i] != b[i] {
+			return false
+		}
+	}
+	return true
+}
diff --git a/lib/go/thrift/simple_json_protocol_test.go b/lib/go/thrift/simple_json_protocol_test.go
index 87a5c64..1abff75 100644
--- a/lib/go/thrift/simple_json_protocol_test.go
+++ b/lib/go/thrift/simple_json_protocol_test.go
@@ -221,6 +221,27 @@
 	}
 }
 
+func TestReadSimpleJSONProtocolI32Null(t *testing.T) {
+	thetype := "int32"
+	value := "null"
+
+	trans := NewTMemoryBuffer()
+	p := NewTSimpleJSONProtocol(trans)
+	trans.WriteString(value)
+	trans.Flush()
+	s := trans.String()
+	v, e := p.ReadI32()
+
+	if e != nil {
+		t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
+	}
+	if v != 0 {
+		t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v)
+	}
+	trans.Reset()
+	trans.Close()
+}
+
 func TestWriteSimpleJSONProtocolI64(t *testing.T) {
 	thetype := "int64"
 	trans := NewTMemoryBuffer()
@@ -268,6 +289,27 @@
 	}
 }
 
+func TestReadSimpleJSONProtocolI64Null(t *testing.T) {
+	thetype := "int32"
+	value := "null"
+
+	trans := NewTMemoryBuffer()
+	p := NewTSimpleJSONProtocol(trans)
+	trans.WriteString(value)
+	trans.Flush()
+	s := trans.String()
+	v, e := p.ReadI64()
+
+	if e != nil {
+		t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
+	}
+	if v != 0 {
+		t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v)
+	}
+	trans.Reset()
+	trans.Close()
+}
+
 func TestWriteSimpleJSONProtocolDouble(t *testing.T) {
 	thetype := "double"
 	trans := NewTMemoryBuffer()
@@ -391,6 +433,25 @@
 		trans.Close()
 	}
 }
+func TestReadSimpleJSONProtocolStringNull(t *testing.T) {
+	thetype := "string"
+	value := "null"
+
+	trans := NewTMemoryBuffer()
+	p := NewTSimpleJSONProtocol(trans)
+	trans.WriteString(value)
+	trans.Flush()
+	s := trans.String()
+	v, e := p.ReadString()
+	if e != nil {
+		t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
+	}
+	if v != "" {
+		t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v)
+	}
+	trans.Reset()
+	trans.Close()
+}
 
 func TestWriteSimpleJSONProtocolBinary(t *testing.T) {
 	thetype := "binary"
@@ -448,6 +509,28 @@
 	trans.Close()
 }
 
+func TestReadSimpleJSONProtocolBinaryNull(t *testing.T) {
+	thetype := "binary"
+	value := "null"
+
+	trans := NewTMemoryBuffer()
+	p := NewTSimpleJSONProtocol(trans)
+	trans.WriteString(value)
+	trans.Flush()
+	s := trans.String()
+	b, e := p.ReadBinary()
+	v := string(b)
+
+	if e != nil {
+		t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
+	}
+	if v != "" {
+		t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v)
+	}
+	trans.Reset()
+	trans.Close()
+}
+
 func TestWriteSimpleJSONProtocolList(t *testing.T) {
 	thetype := "list"
 	trans := NewTMemoryBuffer()