THRIFT-5294: Fix panic in go TSimpleJSONProtocol
Client: go
In go library's TSimpleJSONProtocol and TJSONProtocol implementations,
we use slices as stacks for context info, but didn't do proper boundary
check when peeking/popping, result in it might panic with using -1 as
slice index in certain cases of calling Write*End without matching
Write*Begin before.
Refactor the code to properly implement the stack, and return a
TProtocolException instead on those cases.
Also add unit tests for all protocols. The unit tests shown that
TCompactProtocol.[Read|Write]StructEnd would also panic with unmatched
Begin calls, so fix them as well.
diff --git a/lib/go/thrift/compact_protocol.go b/lib/go/thrift/compact_protocol.go
index 8510f1f..a016195 100644
--- a/lib/go/thrift/compact_protocol.go
+++ b/lib/go/thrift/compact_protocol.go
@@ -22,6 +22,7 @@
import (
"context"
"encoding/binary"
+ "errors"
"fmt"
"io"
"math"
@@ -158,6 +159,9 @@
// this as an opportunity to pop the last field from the current struct off
// of the field stack.
func (p *TCompactProtocol) WriteStructEnd(ctx context.Context) error {
+ if len(p.lastField) <= 0 {
+ return NewTProtocolExceptionWithType(INVALID_DATA, errors.New("WriteStructEnd called without matching WriteStructBegin call before"))
+ }
p.lastFieldId = p.lastField[len(p.lastField)-1]
p.lastField = p.lastField[:len(p.lastField)-1]
return nil
@@ -386,6 +390,9 @@
// this struct from the field stack.
func (p *TCompactProtocol) ReadStructEnd(ctx context.Context) error {
// consume the last field we read off the wire.
+ if len(p.lastField) <= 0 {
+ return NewTProtocolExceptionWithType(INVALID_DATA, errors.New("ReadStructEnd called without matching ReadStructBegin call before"))
+ }
p.lastFieldId = p.lastField[len(p.lastField)-1]
p.lastField = p.lastField[:len(p.lastField)-1]
return nil
diff --git a/lib/go/thrift/json_protocol.go b/lib/go/thrift/json_protocol.go
index 9a9328d..edc49cc 100644
--- a/lib/go/thrift/json_protocol.go
+++ b/lib/go/thrift/json_protocol.go
@@ -41,8 +41,8 @@
// Constructor
func NewTJSONProtocol(t TTransport) *TJSONProtocol {
v := &TJSONProtocol{TSimpleJSONProtocol: NewTSimpleJSONProtocol(t)}
- v.parseContextStack = append(v.parseContextStack, int(_CONTEXT_IN_TOPLEVEL))
- v.dumpContext = append(v.dumpContext, int(_CONTEXT_IN_TOPLEVEL))
+ v.parseContextStack.push(_CONTEXT_IN_TOPLEVEL)
+ v.dumpContext.push(_CONTEXT_IN_TOPLEVEL)
return v
}
diff --git a/lib/go/thrift/json_protocol_test.go b/lib/go/thrift/json_protocol_test.go
index 333d383..39e52d1 100644
--- a/lib/go/thrift/json_protocol_test.go
+++ b/lib/go/thrift/json_protocol_test.go
@@ -648,3 +648,7 @@
}
trans.Close()
}
+
+func TestTJSONProtocolUnmatchedBeginEnd(t *testing.T) {
+ UnmatchedBeginEndProtocolTest(t, NewTJSONProtocolFactory())
+}
diff --git a/lib/go/thrift/protocol_test.go b/lib/go/thrift/protocol_test.go
index c1c67e8..caac78e 100644
--- a/lib/go/thrift/protocol_test.go
+++ b/lib/go/thrift/protocol_test.go
@@ -217,6 +217,10 @@
ReadWriteByte(t, p, trans)
trans.Close()
}
+
+ t.Run("UnmatchedBeginEnd", func(t *testing.T) {
+ UnmatchedBeginEndProtocolTest(t, protocolFactory)
+ })
}
func ReadWriteBool(t testing.TB, p TProtocol, trans TTransport) {
@@ -515,3 +519,88 @@
}
}
}
+
+func UnmatchedBeginEndProtocolTest(t *testing.T, protocolFactory TProtocolFactory) {
+ // NOTE: not all protocol implementations do strict state check to
+ // return an error on unmatched Begin/End calls.
+ // This test is only meant to make sure that those unmatched Begin/End
+ // calls won't cause panic. There's no real "test" here.
+ trans := NewTMemoryBuffer()
+ t.Run("Read", func(t *testing.T) {
+ t.Run("Message", func(t *testing.T) {
+ trans.Reset()
+ p := protocolFactory.GetProtocol(trans)
+ p.ReadMessageEnd(context.Background())
+ p.ReadMessageEnd(context.Background())
+ })
+ t.Run("Struct", func(t *testing.T) {
+ trans.Reset()
+ p := protocolFactory.GetProtocol(trans)
+ p.ReadStructEnd(context.Background())
+ p.ReadStructEnd(context.Background())
+ })
+ t.Run("Field", func(t *testing.T) {
+ trans.Reset()
+ p := protocolFactory.GetProtocol(trans)
+ p.ReadFieldEnd(context.Background())
+ p.ReadFieldEnd(context.Background())
+ })
+ t.Run("Map", func(t *testing.T) {
+ trans.Reset()
+ p := protocolFactory.GetProtocol(trans)
+ p.ReadMapEnd(context.Background())
+ p.ReadMapEnd(context.Background())
+ })
+ t.Run("List", func(t *testing.T) {
+ trans.Reset()
+ p := protocolFactory.GetProtocol(trans)
+ p.ReadListEnd(context.Background())
+ p.ReadListEnd(context.Background())
+ })
+ t.Run("Set", func(t *testing.T) {
+ trans.Reset()
+ p := protocolFactory.GetProtocol(trans)
+ p.ReadSetEnd(context.Background())
+ p.ReadSetEnd(context.Background())
+ })
+ })
+ t.Run("Write", func(t *testing.T) {
+ t.Run("Message", func(t *testing.T) {
+ trans.Reset()
+ p := protocolFactory.GetProtocol(trans)
+ p.WriteMessageEnd(context.Background())
+ p.WriteMessageEnd(context.Background())
+ })
+ t.Run("Struct", func(t *testing.T) {
+ trans.Reset()
+ p := protocolFactory.GetProtocol(trans)
+ p.WriteStructEnd(context.Background())
+ p.WriteStructEnd(context.Background())
+ })
+ t.Run("Field", func(t *testing.T) {
+ trans.Reset()
+ p := protocolFactory.GetProtocol(trans)
+ p.WriteFieldEnd(context.Background())
+ p.WriteFieldEnd(context.Background())
+ })
+ t.Run("Map", func(t *testing.T) {
+ trans.Reset()
+ p := protocolFactory.GetProtocol(trans)
+ p.WriteMapEnd(context.Background())
+ p.WriteMapEnd(context.Background())
+ })
+ t.Run("List", func(t *testing.T) {
+ trans.Reset()
+ p := protocolFactory.GetProtocol(trans)
+ p.WriteListEnd(context.Background())
+ p.WriteListEnd(context.Background())
+ })
+ t.Run("Set", func(t *testing.T) {
+ trans.Reset()
+ p := protocolFactory.GetProtocol(trans)
+ p.WriteSetEnd(context.Background())
+ p.WriteSetEnd(context.Background())
+ })
+ })
+ trans.Close()
+}
diff --git a/lib/go/thrift/simple_json_protocol.go b/lib/go/thrift/simple_json_protocol.go
index d101b99..e94b44b 100644
--- a/lib/go/thrift/simple_json_protocol.go
+++ b/lib/go/thrift/simple_json_protocol.go
@@ -25,6 +25,7 @@
"context"
"encoding/base64"
"encoding/json"
+ "errors"
"fmt"
"io"
"math"
@@ -34,12 +35,13 @@
type _ParseContext int
const (
- _CONTEXT_IN_TOPLEVEL _ParseContext = 1
- _CONTEXT_IN_LIST_FIRST _ParseContext = 2
- _CONTEXT_IN_LIST _ParseContext = 3
- _CONTEXT_IN_OBJECT_FIRST _ParseContext = 4
- _CONTEXT_IN_OBJECT_NEXT_KEY _ParseContext = 5
- _CONTEXT_IN_OBJECT_NEXT_VALUE _ParseContext = 6
+ _CONTEXT_INVALID _ParseContext = iota
+ _CONTEXT_IN_TOPLEVEL // 1
+ _CONTEXT_IN_LIST_FIRST // 2
+ _CONTEXT_IN_LIST // 3
+ _CONTEXT_IN_OBJECT_FIRST // 4
+ _CONTEXT_IN_OBJECT_NEXT_KEY // 5
+ _CONTEXT_IN_OBJECT_NEXT_VALUE // 6
)
func (p _ParseContext) String() string {
@@ -60,6 +62,32 @@
return "UNKNOWN-PARSE-CONTEXT"
}
+type jsonContextStack []_ParseContext
+
+func (s *jsonContextStack) push(v _ParseContext) {
+ *s = append(*s, v)
+}
+
+func (s jsonContextStack) peek() (v _ParseContext, ok bool) {
+ l := len(s)
+ if l <= 0 {
+ return
+ }
+ return s[l-1], true
+}
+
+func (s *jsonContextStack) pop() (v _ParseContext, ok bool) {
+ l := len(*s)
+ if l <= 0 {
+ return
+ }
+ v = (*s)[l-1]
+ *s = (*s)[0 : l-1]
+ return v, true
+}
+
+var errEmptyJSONContextStack = NewTProtocolExceptionWithType(INVALID_DATA, errors.New("Unexpected empty json protocol context stack"))
+
// Simple JSON protocol implementation for thrift.
//
// This protocol produces/consumes a simple output format
@@ -69,8 +97,8 @@
type TSimpleJSONProtocol struct {
trans TTransport
- parseContextStack []int
- dumpContext []int
+ parseContextStack jsonContextStack
+ dumpContext jsonContextStack
writer *bufio.Writer
reader *bufio.Reader
@@ -82,8 +110,8 @@
writer: bufio.NewWriter(t),
reader: bufio.NewReader(t),
}
- v.parseContextStack = append(v.parseContextStack, int(_CONTEXT_IN_TOPLEVEL))
- v.dumpContext = append(v.dumpContext, int(_CONTEXT_IN_TOPLEVEL))
+ v.parseContextStack.push(_CONTEXT_IN_TOPLEVEL)
+ v.dumpContext.push(_CONTEXT_IN_TOPLEVEL)
return v
}
@@ -549,41 +577,41 @@
}
func (p *TSimpleJSONProtocol) OutputPreValue() error {
- cxt := _ParseContext(p.dumpContext[len(p.dumpContext)-1])
+ cxt, ok := p.dumpContext.peek()
+ if !ok {
+ return errEmptyJSONContextStack
+ }
switch cxt {
case _CONTEXT_IN_LIST, _CONTEXT_IN_OBJECT_NEXT_KEY:
if _, e := p.write(JSON_COMMA); e != nil {
return NewTProtocolException(e)
}
- break
case _CONTEXT_IN_OBJECT_NEXT_VALUE:
if _, e := p.write(JSON_COLON); e != nil {
return NewTProtocolException(e)
}
- break
}
return nil
}
func (p *TSimpleJSONProtocol) OutputPostValue() error {
- cxt := _ParseContext(p.dumpContext[len(p.dumpContext)-1])
+ cxt, ok := p.dumpContext.peek()
+ if !ok {
+ return errEmptyJSONContextStack
+ }
switch cxt {
case _CONTEXT_IN_LIST_FIRST:
- p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
- p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_LIST))
- break
+ p.dumpContext.pop()
+ p.dumpContext.push(_CONTEXT_IN_LIST)
case _CONTEXT_IN_OBJECT_FIRST:
- p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
- p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_NEXT_VALUE))
- break
+ p.dumpContext.pop()
+ p.dumpContext.push(_CONTEXT_IN_OBJECT_NEXT_VALUE)
case _CONTEXT_IN_OBJECT_NEXT_KEY:
- p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
- p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_NEXT_VALUE))
- break
+ p.dumpContext.pop()
+ p.dumpContext.push(_CONTEXT_IN_OBJECT_NEXT_VALUE)
case _CONTEXT_IN_OBJECT_NEXT_VALUE:
- p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
- p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_NEXT_KEY))
- break
+ p.dumpContext.pop()
+ p.dumpContext.push(_CONTEXT_IN_OBJECT_NEXT_KEY)
}
return nil
}
@@ -598,10 +626,13 @@
} else {
v = string(JSON_FALSE)
}
- switch _ParseContext(p.dumpContext[len(p.dumpContext)-1]) {
+ cxt, ok := p.dumpContext.peek()
+ if !ok {
+ return errEmptyJSONContextStack
+ }
+ switch cxt {
case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY:
v = jsonQuote(v)
- default:
}
if e := p.OutputStringData(v); e != nil {
return e
@@ -631,11 +662,14 @@
} else if math.IsInf(value, -1) {
v = string(JSON_QUOTE) + JSON_NEGATIVE_INFINITY + string(JSON_QUOTE)
} else {
+ cxt, ok := p.dumpContext.peek()
+ if !ok {
+ return errEmptyJSONContextStack
+ }
v = strconv.FormatFloat(value, 'g', -1, 64)
- switch _ParseContext(p.dumpContext[len(p.dumpContext)-1]) {
+ switch cxt {
case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY:
v = string(JSON_QUOTE) + v + string(JSON_QUOTE)
- default:
}
}
if e := p.OutputStringData(v); e != nil {
@@ -648,11 +682,14 @@
if e := p.OutputPreValue(); e != nil {
return e
}
+ cxt, ok := p.dumpContext.peek()
+ if !ok {
+ return errEmptyJSONContextStack
+ }
v := strconv.FormatInt(value, 10)
- switch _ParseContext(p.dumpContext[len(p.dumpContext)-1]) {
+ switch cxt {
case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY:
v = jsonQuote(v)
- default:
}
if e := p.OutputStringData(v); e != nil {
return e
@@ -682,7 +719,7 @@
if _, e := p.write(JSON_LBRACE); e != nil {
return NewTProtocolException(e)
}
- p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_FIRST))
+ p.dumpContext.push(_CONTEXT_IN_OBJECT_FIRST)
return nil
}
@@ -690,7 +727,10 @@
if _, e := p.write(JSON_RBRACE); e != nil {
return NewTProtocolException(e)
}
- p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
+ _, ok := p.dumpContext.pop()
+ if !ok {
+ return errEmptyJSONContextStack
+ }
if e := p.OutputPostValue(); e != nil {
return e
}
@@ -704,7 +744,7 @@
if _, e := p.write(JSON_LBRACKET); e != nil {
return NewTProtocolException(e)
}
- p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_LIST_FIRST))
+ p.dumpContext.push(_CONTEXT_IN_LIST_FIRST)
return nil
}
@@ -712,7 +752,10 @@
if _, e := p.write(JSON_RBRACKET); e != nil {
return NewTProtocolException(e)
}
- p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
+ _, ok := p.dumpContext.pop()
+ if !ok {
+ return errEmptyJSONContextStack
+ }
if e := p.OutputPostValue(); e != nil {
return e
}
@@ -736,7 +779,10 @@
if e := p.readNonSignificantWhitespace(); e != nil {
return NewTProtocolException(e)
}
- cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1])
+ cxt, ok := p.parseContextStack.peek()
+ if !ok {
+ return errEmptyJSONContextStack
+ }
b, _ := p.reader.Peek(1)
switch cxt {
case _CONTEXT_IN_LIST:
@@ -755,7 +801,6 @@
return NewTProtocolExceptionWithType(INVALID_DATA, e)
}
}
- break
case _CONTEXT_IN_OBJECT_NEXT_KEY:
if len(b) > 0 {
switch b[0] {
@@ -772,7 +817,6 @@
return NewTProtocolExceptionWithType(INVALID_DATA, e)
}
}
- break
case _CONTEXT_IN_OBJECT_NEXT_VALUE:
if len(b) > 0 {
switch b[0] {
@@ -787,7 +831,6 @@
return NewTProtocolExceptionWithType(INVALID_DATA, e)
}
}
- break
}
return nil
}
@@ -796,20 +839,20 @@
if e := p.readNonSignificantWhitespace(); e != nil {
return NewTProtocolException(e)
}
- cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1])
+ cxt, ok := p.parseContextStack.peek()
+ if !ok {
+ return errEmptyJSONContextStack
+ }
switch cxt {
case _CONTEXT_IN_LIST_FIRST:
- p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1]
- p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_LIST))
- break
+ p.parseContextStack.pop()
+ p.parseContextStack.push(_CONTEXT_IN_LIST)
case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY:
- p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1]
- p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_OBJECT_NEXT_VALUE))
- break
+ p.parseContextStack.pop()
+ p.parseContextStack.push(_CONTEXT_IN_OBJECT_NEXT_VALUE)
case _CONTEXT_IN_OBJECT_NEXT_VALUE:
- p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1]
- p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_OBJECT_NEXT_KEY))
- break
+ p.parseContextStack.pop()
+ p.parseContextStack.push(_CONTEXT_IN_OBJECT_NEXT_KEY)
}
return nil
}
@@ -962,7 +1005,7 @@
}
if len(b) > 0 && b[0] == JSON_LBRACE[0] {
p.reader.ReadByte()
- p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_OBJECT_FIRST))
+ p.parseContextStack.push(_CONTEXT_IN_OBJECT_FIRST)
return false, nil
} else if p.safePeekContains(JSON_NULL) {
return true, nil
@@ -975,7 +1018,7 @@
if isNull, err := p.readIfNull(); isNull || err != nil {
return err
}
- cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1])
+ cxt, _ := p.parseContextStack.peek()
if (cxt != _CONTEXT_IN_OBJECT_FIRST) && (cxt != _CONTEXT_IN_OBJECT_NEXT_KEY) {
e := fmt.Errorf("Expected to be in the Object Context, but not in Object Context (%d)", cxt)
return NewTProtocolExceptionWithType(INVALID_DATA, e)
@@ -993,7 +1036,7 @@
break
}
}
- p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1]
+ p.parseContextStack.pop()
return p.ParsePostValue()
}
@@ -1007,7 +1050,7 @@
return false, err
}
if len(b) >= 1 && b[0] == JSON_LBRACKET[0] {
- p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_LIST_FIRST))
+ p.parseContextStack.push(_CONTEXT_IN_LIST_FIRST)
p.reader.ReadByte()
isNull = false
} else if p.safePeekContains(JSON_NULL) {
@@ -1036,7 +1079,7 @@
if isNull, err := p.readIfNull(); isNull || err != nil {
return err
}
- cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1])
+ cxt, _ := p.parseContextStack.peek()
if cxt != _CONTEXT_IN_LIST {
e := fmt.Errorf("Expected to be in the List Context, but not in List Context (%d)", cxt)
return NewTProtocolExceptionWithType(INVALID_DATA, e)
@@ -1054,8 +1097,10 @@
break
}
}
- p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1]
- if _ParseContext(p.parseContextStack[len(p.parseContextStack)-1]) == _CONTEXT_IN_TOPLEVEL {
+ p.parseContextStack.pop()
+ if cxt, ok := p.parseContextStack.peek(); !ok {
+ return errEmptyJSONContextStack
+ } else if cxt == _CONTEXT_IN_TOPLEVEL {
return nil
}
return p.ParsePostValue()
@@ -1308,8 +1353,8 @@
// Reset the context stack to its initial state.
func (p *TSimpleJSONProtocol) resetContextStack() {
- p.parseContextStack = []int{int(_CONTEXT_IN_TOPLEVEL)}
- p.dumpContext = []int{int(_CONTEXT_IN_TOPLEVEL)}
+ p.parseContextStack = jsonContextStack{_CONTEXT_IN_TOPLEVEL}
+ p.dumpContext = jsonContextStack{_CONTEXT_IN_TOPLEVEL}
}
func (p *TSimpleJSONProtocol) write(b []byte) (int, error) {
diff --git a/lib/go/thrift/simple_json_protocol_test.go b/lib/go/thrift/simple_json_protocol_test.go
index 986fff2..89753c6 100644
--- a/lib/go/thrift/simple_json_protocol_test.go
+++ b/lib/go/thrift/simple_json_protocol_test.go
@@ -736,3 +736,58 @@
t.Fatalf("Should not match at test 3")
}
}
+
+func TestJSONContextStack(t *testing.T) {
+ var stack jsonContextStack
+ t.Run("empty-peek", func(t *testing.T) {
+ v, ok := stack.peek()
+ if ok {
+ t.Error("peek() on empty should return ok: false")
+ }
+ expected := _CONTEXT_INVALID
+ if v != expected {
+ t.Errorf("Expected value from peek() to be %v(%d), got %v(%d)", expected, expected, v, v)
+ }
+ })
+ t.Run("empty-pop", func(t *testing.T) {
+ v, ok := stack.pop()
+ if ok {
+ t.Error("pop() on empty should return ok: false")
+ }
+ expected := _CONTEXT_INVALID
+ if v != expected {
+ t.Errorf("Expected value from pop() to be %v(%d), got %v(%d)", expected, expected, v, v)
+ }
+ })
+ t.Run("push-peek-pop", func(t *testing.T) {
+ expected := _CONTEXT_INVALID
+ stack.push(expected)
+ if len(stack) != 1 {
+ t.Errorf("Expected stack to be as size 1 after push, got %#v", stack)
+ }
+ v, ok := stack.peek()
+ if !ok {
+ t.Error("peek() on non-empty should return ok: true")
+ }
+ if v != expected {
+ t.Errorf("Expected value from peek() to be %v(%d), got %v(%d)", expected, expected, v, v)
+ }
+ if len(stack) != 1 {
+ t.Errorf("Expected peek() to be read-only, got %#v", stack)
+ }
+ v, ok = stack.pop()
+ if !ok {
+ t.Error("pop() on non-empty should return ok: true")
+ }
+ if v != expected {
+ t.Errorf("Expected value from pop() to be %v(%d), got %v(%d)", expected, expected, v, v)
+ }
+ if len(stack) != 0 {
+ t.Errorf("Expected pop() to empty the stack, got %#v", stack)
+ }
+ })
+}
+
+func TestTSimpleJSONProtocolUnmatchedBeginEnd(t *testing.T) {
+ UnmatchedBeginEndProtocolTest(t, NewTSimpleJSONProtocolFactory())
+}