THRIFT-5069: Make TDeserializer resource pool friendly
Client: go
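This change makes TDeserializer friendlier to reuse from a resource
pool: Read and ReadString now reset the transport before writing the
incoming payload, and TDeserializer.Transport is now typed
*TMemoryBuffer instead of TTransport. See
https://issues.apache.org/jira/browse/THRIFT-5069 for more context.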
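Also add TSerializerPool and TDeserializerPool, which are thread-safe
versions of TSerializer and TDeserializer. Benchmark results show that
the pooled versions are both faster and use less memory than creating a
new plain TSerializer/TDeserializer for every call ("plain" below), and
perform close to a single shared instance ("baseline" below):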
$ go test -bench Serializer -benchmem
goos: darwin
goarch: amd64
BenchmarkSerializer/baseline-8 577558 1930 ns/op 512 B/op 6 allocs/op
BenchmarkSerializer/plain-8 452712 2638 ns/op 2976 B/op 16 allocs/op
BenchmarkSerializer/pool-8 591698 2032 ns/op 512 B/op 6 allocs/op
PASS
diff --git a/CHANGES.md b/CHANGES.md
index e179a63..1dddab9 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -7,10 +7,16 @@
- [THRIFT-4990](https://issues.apache.org/jira/browse/THRIFT-4990) - Upgrade to .NET Core 3.1 (LTS)
- [THRIFT-4981](https://issues.apache.org/jira/browse/THRIFT-4981) - Remove deprecated netcore bindings from the code base
- [THRIFT-5006](https://issues.apache.org/jira/browse/THRIFT-5006) - Implement DEFAULT_MAX_LENGTH at TFramedTransport
+- [THRIFT-5069](https://issues.apache.org/jira/browse/THRIFT-5069) - In the Go library, TDeserializer.Transport is now typed \*TMemoryBuffer instead of TTransport
### Java
- [THRIFT-5022](https://issues.apache.org/jira/browse/THRIFT-5022) - TIOStreamTransport.isOpen returns true for one-sided transports (see THRIFT-2530).
+
+### Go
+
+- [THRIFT-5069](https://issues.apache.org/jira/browse/THRIFT-5069) - Add TSerializerPool and TDeserializerPool, which are thread-safe versions of TSerializer and TDeserializer.
+
## 0.13.0
### New Languages
diff --git a/lib/go/thrift/deserializer.go b/lib/go/thrift/deserializer.go
index 91a0983..2ab8214 100644
--- a/lib/go/thrift/deserializer.go
+++ b/lib/go/thrift/deserializer.go
@@ -19,14 +19,17 @@
package thrift
+import (
+ "sync"
+)
+
type TDeserializer struct {
- Transport TTransport
+ Transport *TMemoryBuffer
Protocol TProtocol
}
func NewTDeserializer() *TDeserializer {
- var transport TTransport
- transport = NewTMemoryBufferLen(1024)
+ transport := NewTMemoryBufferLen(1024)
protocol := NewTBinaryProtocolFactoryDefault().GetProtocol(transport)
@@ -36,6 +39,8 @@
}
func (t *TDeserializer) ReadString(msg TStruct, s string) (err error) {
+ t.Transport.Reset()
+
err = nil
if _, err = t.Transport.Write([]byte(s)); err != nil {
return
@@ -47,6 +52,8 @@
}
func (t *TDeserializer) Read(msg TStruct, b []byte) (err error) {
+ t.Transport.Reset()
+
err = nil
if _, err = t.Transport.Write(b); err != nil {
return
@@ -56,3 +63,36 @@
}
return
}
+
+// TDeserializerPool is the thread-safe version of TDeserializer,
+// it uses resource pool of TDeserializer under the hood.
+//
+// It must be initialized with NewTDeserializerPool.
+type TDeserializerPool struct {
+ pool sync.Pool
+}
+
+// NewTDeserializerPool creates a new TDeserializerPool.
+//
+// NewTDeserializer can be used as the arg here.
+func NewTDeserializerPool(f func() *TDeserializer) *TDeserializerPool {
+ return &TDeserializerPool{
+ pool: sync.Pool{
+ New: func() interface{} {
+ return f()
+ },
+ },
+ }
+}
+
+func (t *TDeserializerPool) ReadString(msg TStruct, s string) error {
+ d := t.pool.Get().(*TDeserializer)
+ defer t.pool.Put(d)
+ return d.ReadString(msg, s)
+}
+
+func (t *TDeserializerPool) Read(msg TStruct, b []byte) error {
+ d := t.pool.Get().(*TDeserializer)
+ defer t.pool.Put(d)
+ return d.Read(msg, b)
+}
diff --git a/lib/go/thrift/serializer.go b/lib/go/thrift/serializer.go
index 1ff4d37..d85d204 100644
--- a/lib/go/thrift/serializer.go
+++ b/lib/go/thrift/serializer.go
@@ -21,6 +21,7 @@
import (
"context"
+ "sync"
)
type TSerializer struct {
@@ -77,3 +78,36 @@
b = append(b, t.Transport.Bytes()...)
return
}
+
+// TSerializerPool is the thread-safe version of TSerializer, it uses resource
+// pool of TSerializer under the hood.
+//
+// It must be initialized with NewTSerializerPool.
+type TSerializerPool struct {
+ pool sync.Pool
+}
+
+// NewTSerializerPool creates a new TSerializerPool.
+//
+// NewTSerializer can be used as the arg here.
+func NewTSerializerPool(f func() *TSerializer) *TSerializerPool {
+ return &TSerializerPool{
+ pool: sync.Pool{
+ New: func() interface{} {
+ return f()
+ },
+ },
+ }
+}
+
+func (t *TSerializerPool) WriteString(ctx context.Context, msg TStruct) (string, error) {
+ s := t.pool.Get().(*TSerializer)
+ defer t.pool.Put(s)
+ return s.WriteString(ctx, msg)
+}
+
+func (t *TSerializerPool) Write(ctx context.Context, msg TStruct) ([]byte, error) {
+ s := t.pool.Get().(*TSerializer)
+ defer t.pool.Put(s)
+ return s.Write(ctx, msg)
+}
diff --git a/lib/go/thrift/serializer_test.go b/lib/go/thrift/serializer_test.go
index 32227ef..52ebdca 100644
--- a/lib/go/thrift/serializer_test.go
+++ b/lib/go/thrift/serializer_test.go
@@ -23,122 +23,193 @@
"context"
"errors"
"fmt"
+ "sync"
+ "sync/atomic"
"testing"
+ "testing/quick"
)
type ProtocolFactory interface {
GetProtocol(t TTransport) TProtocol
}
-func compareStructs(m, m1 MyTestStruct) (bool, error) {
+func compareStructs(m, m1 MyTestStruct) error {
switch {
case m.On != m1.On:
- return false, errors.New("Boolean not equal")
+ return errors.New("Boolean not equal")
case m.B != m1.B:
- return false, errors.New("Byte not equal")
+ return errors.New("Byte not equal")
case m.Int16 != m1.Int16:
- return false, errors.New("Int16 not equal")
+ return errors.New("Int16 not equal")
case m.Int32 != m1.Int32:
- return false, errors.New("Int32 not equal")
+ return errors.New("Int32 not equal")
case m.Int64 != m1.Int64:
- return false, errors.New("Int64 not equal")
+ return errors.New("Int64 not equal")
case m.D != m1.D:
- return false, errors.New("Double not equal")
+ return errors.New("Double not equal")
case m.St != m1.St:
- return false, errors.New("String not equal")
+ return errors.New("String not equal")
case len(m.Bin) != len(m1.Bin):
- return false, errors.New("Binary size not equal")
+ return errors.New("Binary size not equal")
case len(m.Bin) == len(m1.Bin):
for i := range m.Bin {
if m.Bin[i] != m1.Bin[i] {
- return false, errors.New("Binary not equal")
+ return errors.New("Binary not equal")
}
}
case len(m.StringMap) != len(m1.StringMap):
- return false, errors.New("StringMap size not equal")
+ return errors.New("StringMap size not equal")
case len(m.StringList) != len(m1.StringList):
- return false, errors.New("StringList size not equal")
+ return errors.New("StringList size not equal")
case len(m.StringSet) != len(m1.StringSet):
- return false, errors.New("StringSet size not equal")
+ return errors.New("StringSet size not equal")
case m.E != m1.E:
- return false, errors.New("MyTestEnum not equal")
+ return errors.New("MyTestEnum not equal")
default:
- return true, nil
+ return nil
}
- return true, nil
+ return nil
}
-func ProtocolTest1(test *testing.T, pf ProtocolFactory) (bool, error) {
- t := NewTSerializer()
- t.Protocol = pf.GetProtocol(t.Transport)
- var m = MyTestStruct{}
- m.On = true
- m.B = int8(0)
- m.Int16 = 1
- m.Int32 = 2
- m.Int64 = 3
- m.D = 4.1
- m.St = "Test"
- m.Bin = make([]byte, 10)
- m.StringMap = make(map[string]string, 5)
- m.StringList = make([]string, 5)
- m.StringSet = make(map[string]struct{}, 5)
- m.E = 2
-
- s, err := t.WriteString(context.Background(), &m)
- if err != nil {
- return false, errors.New(fmt.Sprintf("Unable to Serialize struct\n\t %s", err))
- }
-
- t1 := NewTDeserializer()
- t1.Protocol = pf.GetProtocol(t1.Transport)
- var m1 = MyTestStruct{}
- if err = t1.ReadString(&m1, s); err != nil {
- return false, errors.New(fmt.Sprintf("Unable to Deserialize struct\n\t %s", err))
-
- }
-
- return compareStructs(m, m1)
-
+type serializer interface {
+ WriteString(context.Context, TStruct) (string, error)
}
-func ProtocolTest2(test *testing.T, pf ProtocolFactory) (bool, error) {
+type deserializer interface {
+ ReadString(TStruct, string) error
+}
+
+func plainSerializer(pf ProtocolFactory) serializer {
t := NewTSerializer()
t.Protocol = pf.GetProtocol(t.Transport)
- var m = MyTestStruct{}
- m.On = false
- m.B = int8(0)
- m.Int16 = 1
- m.Int32 = 2
- m.Int64 = 3
- m.D = 4.1
- m.St = "Test"
- m.Bin = make([]byte, 10)
- m.StringMap = make(map[string]string, 5)
- m.StringList = make([]string, 5)
- m.StringSet = make(map[string]struct{}, 5)
- m.E = 2
+ return t
+}
- s, err := t.WriteString(context.Background(), &m)
- if err != nil {
- return false, errors.New(fmt.Sprintf("Unable to Serialize struct\n\t %s", err))
+func poolSerializer(pf ProtocolFactory) serializer {
+ return NewTSerializerPool(
+ func() *TSerializer {
+ return plainSerializer(pf).(*TSerializer)
+ },
+ )
+}
+func plainDeserializer(pf ProtocolFactory) deserializer {
+ d := NewTDeserializer()
+ d.Protocol = pf.GetProtocol(d.Transport)
+ return d
+}
+
+func poolDeserializer(pf ProtocolFactory) deserializer {
+ return NewTDeserializerPool(
+ func() *TDeserializer {
+ return plainDeserializer(pf).(*TDeserializer)
+ },
+ )
+}
+
+type constructors struct {
+ Label string
+ Serializer func(pf ProtocolFactory) serializer
+ Deserializer func(pf ProtocolFactory) deserializer
+}
+
+var implementations = []constructors{
+ {
+ Label: "plain",
+ Serializer: plainSerializer,
+ Deserializer: plainDeserializer,
+ },
+ {
+ Label: "pool",
+ Serializer: poolSerializer,
+ Deserializer: poolDeserializer,
+ },
+}
+
+func ProtocolTest1(t *testing.T, pf ProtocolFactory) {
+ for _, impl := range implementations {
+ t.Run(
+ impl.Label,
+ func(test *testing.T) {
+ t := impl.Serializer(pf)
+ var m = MyTestStruct{}
+ m.On = true
+ m.B = int8(0)
+ m.Int16 = 1
+ m.Int32 = 2
+ m.Int64 = 3
+ m.D = 4.1
+ m.St = "Test"
+ m.Bin = make([]byte, 10)
+ m.StringMap = make(map[string]string, 5)
+ m.StringList = make([]string, 5)
+ m.StringSet = make(map[string]struct{}, 5)
+ m.E = 2
+
+ s, err := t.WriteString(context.Background(), &m)
+ if err != nil {
+ test.Fatalf("Unable to Serialize struct: %v", err)
+
+ }
+
+ t1 := impl.Deserializer(pf)
+ var m1 MyTestStruct
+ if err = t1.ReadString(&m1, s); err != nil {
+ test.Fatalf("Unable to Deserialize struct: %v", err)
+
+ }
+
+ if err := compareStructs(m, m1); err != nil {
+ test.Error(err)
+ }
+ },
+ )
}
+}
- t1 := NewTDeserializer()
- t1.Protocol = pf.GetProtocol(t1.Transport)
- var m1 = MyTestStruct{}
- if err = t1.ReadString(&m1, s); err != nil {
- return false, errors.New(fmt.Sprintf("Unable to Deserialize struct\n\t %s", err))
+func ProtocolTest2(t *testing.T, pf ProtocolFactory) {
+ for _, impl := range implementations {
+ t.Run(
+ impl.Label,
+ func(test *testing.T) {
+ t := impl.Serializer(pf)
+ var m = MyTestStruct{}
+ m.On = false
+ m.B = int8(0)
+ m.Int16 = 1
+ m.Int32 = 2
+ m.Int64 = 3
+ m.D = 4.1
+ m.St = "Test"
+ m.Bin = make([]byte, 10)
+ m.StringMap = make(map[string]string, 5)
+ m.StringList = make([]string, 5)
+ m.StringSet = make(map[string]struct{}, 5)
+ m.E = 2
+ s, err := t.WriteString(context.Background(), &m)
+ if err != nil {
+ test.Fatalf("Unable to Serialize struct: %v", err)
+
+ }
+
+ t1 := impl.Deserializer(pf)
+ var m1 MyTestStruct
+ if err = t1.ReadString(&m1, s); err != nil {
+ test.Fatalf("Unable to Deserialize struct: %v", err)
+
+ }
+
+ if err := compareStructs(m, m1); err != nil {
+ test.Error(err)
+ }
+ },
+ )
}
-
- return compareStructs(m, m1)
-
}
func TestSerializer(t *testing.T) {
@@ -150,21 +221,123 @@
//protocol_factories["SimpleJSON"] = NewTSimpleJSONProtocolFactory() - write only, can't be read back by design
protocol_factories["JSON"] = NewTJSONProtocolFactory()
- var tests map[string]func(*testing.T, ProtocolFactory) (bool, error)
- tests = make(map[string]func(*testing.T, ProtocolFactory) (bool, error))
+ tests := make(map[string]func(*testing.T, ProtocolFactory))
tests["Test 1"] = ProtocolTest1
tests["Test 2"] = ProtocolTest2
//tests["Test 3"] = ProtocolTest3 // Example of how to add additional tests
for name, pf := range protocol_factories {
-
- for test, f := range tests {
-
- if s, err := f(t, pf); !s || err != nil {
- t.Errorf("%s Failed for %s protocol\n\t %s", test, name, err)
- }
-
- }
+ t.Run(
+ name,
+ func(t *testing.T) {
+ for label, f := range tests {
+ t.Run(
+ label,
+ func(t *testing.T) {
+ f(t, pf)
+ },
+ )
+ }
+ },
+ )
}
}
+
+func TestSerializerPoolAsync(t *testing.T) {
+ var wg sync.WaitGroup
+ var counter int64
+ s := NewTSerializerPool(NewTSerializer)
+ d := NewTDeserializerPool(NewTDeserializer)
+ f := func(i int64) bool {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ t.Run(
+ fmt.Sprintf("#%d-%d", atomic.AddInt64(&counter, 1), i),
+ func(t *testing.T) {
+ m := MyTestStruct{
+ Int64: i,
+ }
+ str, err := s.WriteString(context.Background(), &m)
+ if err != nil {
+ t.Fatal("serialize:", err)
+ }
+ var m1 MyTestStruct
+ if err = d.ReadString(&m1, str); err != nil {
+ t.Fatal("deserialize:", err)
+
+ }
+
+ if err := compareStructs(m, m1); err != nil {
+ t.Error(err)
+ }
+ },
+ )
+ }()
+ return true
+ }
+ quick.Check(f, nil)
+ wg.Wait()
+}
+
+func BenchmarkSerializer(b *testing.B) {
+ sharedSerializer := NewTSerializer()
+ poolSerializer := NewTSerializerPool(NewTSerializer)
+ sharedDeserializer := NewTDeserializer()
+ poolDeserializer := NewTDeserializerPool(NewTDeserializer)
+
+ cases := []struct {
+ Label string
+ Serializer func() serializer
+ Deserializer func() deserializer
+ }{
+ {
+ // Baseline uses shared plain serializer/deserializer
+ Label: "baseline",
+ Serializer: func() serializer {
+ return sharedSerializer
+ },
+ Deserializer: func() deserializer {
+ return sharedDeserializer
+ },
+ },
+ {
+ // Plain creates new serializer/deserializer on every run,
+ // as that's how it's used in real world
+ Label: "plain",
+ Serializer: func() serializer {
+ return NewTSerializer()
+ },
+ Deserializer: func() deserializer {
+ return NewTDeserializer()
+ },
+ },
+ {
+ // Pool uses the shared pool serializer/deserializer
+ Label: "pool",
+ Serializer: func() serializer {
+ return poolSerializer
+ },
+ Deserializer: func() deserializer {
+ return poolDeserializer
+ },
+ },
+ }
+
+ for _, c := range cases {
+ b.Run(
+ c.Label,
+ func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ s := c.Serializer()
+ m := MyTestStruct{}
+ str, _ := s.WriteString(context.Background(), &m)
+ var m1 MyTestStruct
+ d := c.Deserializer()
+ d.ReadString(&m1, str)
+ }
+ },
+ )
+ }
+}