Add a generic sync.Pool wrapper to go library
Since we dropped support of Go 1.18-, use generic to avoid dealing with
type assertions with interface{}/any.
While I'm here, also remove the usages of ioutil, as that's officially
marked as deprecated in Go 1.19.
Client: go
diff --git a/lib/go/thrift/buf_pool.go b/lib/go/thrift/buf_pool.go
deleted file mode 100644
index 9708ea0..0000000
--- a/lib/go/thrift/buf_pool.go
+++ /dev/null
@@ -1,52 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package thrift
-
-import (
- "bytes"
- "sync"
-)
-
-var bufPool = sync.Pool{
- New: func() interface{} {
- return new(bytes.Buffer)
- },
-}
-
-// getBufFromPool gets a buffer out of the pool and guarantees that it's reset
-// before return.
-func getBufFromPool() *bytes.Buffer {
- buf := bufPool.Get().(*bytes.Buffer)
- buf.Reset()
- return buf
-}
-
-// returnBufToPool returns a buffer to the pool, and sets it to nil to avoid
-// accidental usage after it's returned.
-//
-// You usually want to use it this way:
-//
-// buf := getBufFromPool()
-// defer returnBufToPool(&buf)
-// // use buf
-func returnBufToPool(buf **bytes.Buffer) {
- bufPool.Put(*buf)
- *buf = nil
-}
diff --git a/lib/go/thrift/deserializer.go b/lib/go/thrift/deserializer.go
index 2f2468b..0c68d6b 100644
--- a/lib/go/thrift/deserializer.go
+++ b/lib/go/thrift/deserializer.go
@@ -21,7 +21,6 @@
import (
"context"
- "sync"
)
type TDeserializer struct {
@@ -81,7 +80,7 @@
// It must be initialized with either NewTDeserializerPool or
// NewTDeserializerPoolSizeFactory.
type TDeserializerPool struct {
- pool sync.Pool
+ pool *pool[TDeserializer]
}
// NewTDeserializerPool creates a new TDeserializerPool.
@@ -89,11 +88,7 @@
// NewTDeserializer can be used as the arg here.
func NewTDeserializerPool(f func() *TDeserializer) *TDeserializerPool {
return &TDeserializerPool{
- pool: sync.Pool{
- New: func() interface{} {
- return f()
- },
- },
+ pool: newPool(f, nil),
}
}
@@ -104,28 +99,26 @@
// larger than that. It just dictates the initial size.
func NewTDeserializerPoolSizeFactory(size int, factory TProtocolFactory) *TDeserializerPool {
return &TDeserializerPool{
- pool: sync.Pool{
- New: func() interface{} {
- transport := NewTMemoryBufferLen(size)
- protocol := factory.GetProtocol(transport)
+ pool: newPool(func() *TDeserializer {
+ transport := NewTMemoryBufferLen(size)
+ protocol := factory.GetProtocol(transport)
- return &TDeserializer{
- Transport: transport,
- Protocol: protocol,
- }
- },
- },
+ return &TDeserializer{
+ Transport: transport,
+ Protocol: protocol,
+ }
+ }, nil),
}
}
func (t *TDeserializerPool) ReadString(ctx context.Context, msg TStruct, s string) error {
- d := t.pool.Get().(*TDeserializer)
- defer t.pool.Put(d)
+ d := t.pool.get()
+ defer t.pool.put(&d)
return d.ReadString(ctx, msg, s)
}
func (t *TDeserializerPool) Read(ctx context.Context, msg TStruct, b []byte) error {
- d := t.pool.Get().(*TDeserializer)
- defer t.pool.Put(d)
+ d := t.pool.get()
+ defer t.pool.put(&d)
return d.Read(ctx, msg, b)
}
diff --git a/lib/go/thrift/framed_transport.go b/lib/go/thrift/framed_transport.go
index c8bd35e..e3c323a 100644
--- a/lib/go/thrift/framed_transport.go
+++ b/lib/go/thrift/framed_transport.go
@@ -133,7 +133,7 @@
// Make sure we return the read buffer back to pool
// after we finished reading from it.
if p.readBuf != nil && p.readBuf.Len() == 0 {
- returnBufToPool(&p.readBuf)
+ bufPool.put(&p.readBuf)
}
}()
@@ -175,7 +175,7 @@
func (p *TFramedTransport) ensureWriteBufferBeforeWrite() {
if p.writeBuf == nil {
- p.writeBuf = getBufFromPool()
+ p.writeBuf = bufPool.get()
}
}
@@ -196,7 +196,7 @@
}
func (p *TFramedTransport) Flush(ctx context.Context) error {
- defer returnBufToPool(&p.writeBuf)
+ defer bufPool.put(&p.writeBuf)
size := p.writeBuf.Len()
buf := p.buffer[:4]
binary.BigEndian.PutUint32(buf, uint32(size))
@@ -215,9 +215,9 @@
func (p *TFramedTransport) readFrame() error {
if p.readBuf != nil {
- returnBufToPool(&p.readBuf)
+ bufPool.put(&p.readBuf)
}
- p.readBuf = getBufFromPool()
+ p.readBuf = bufPool.get()
buf := p.buffer[:4]
if _, err := io.ReadFull(p.reader, buf); err != nil {
diff --git a/lib/go/thrift/header_transport.go b/lib/go/thrift/header_transport.go
index 5ec0454..3aea5a9 100644
--- a/lib/go/thrift/header_transport.go
+++ b/lib/go/thrift/header_transport.go
@@ -370,7 +370,7 @@
// Read the frame fully into frameBuffer.
if t.frameBuffer == nil {
- t.frameBuffer = getBufFromPool()
+ t.frameBuffer = bufPool.get()
}
_, err = io.CopyN(t.frameBuffer, t.reader, int64(frameSize))
if err != nil {
@@ -407,7 +407,7 @@
// It closes frameReader, and also resets frame related states.
func (t *THeaderTransport) endOfFrame() error {
defer func() {
- returnBufToPool(&t.frameBuffer)
+ bufPool.put(&t.frameBuffer)
t.frameReader = nil
}()
return t.frameReader.Close()
@@ -572,7 +572,7 @@
// You need to call Flush to actually write them to the transport.
func (t *THeaderTransport) Write(p []byte) (int, error) {
if t.writeBuffer == nil {
- t.writeBuffer = getBufFromPool()
+ t.writeBuffer = bufPool.get()
}
return t.writeBuffer.Write(p)
}
@@ -583,7 +583,7 @@
return nil
}
- defer returnBufToPool(&t.writeBuffer)
+ defer bufPool.put(&t.writeBuffer)
switch t.clientType {
default:
@@ -633,8 +633,8 @@
}
}
- payload := getBufFromPool()
- defer returnBufToPool(&payload)
+ payload := bufPool.get()
+ defer bufPool.put(&payload)
meta := headerMeta{
MagicFlags: THeaderHeaderMagic + t.Flags&THeaderFlagsMask,
SequenceID: t.SequenceID,
diff --git a/lib/go/thrift/http_client.go b/lib/go/thrift/http_client.go
index ce62c96..a0f2066 100644
--- a/lib/go/thrift/http_client.go
+++ b/lib/go/thrift/http_client.go
@@ -24,7 +24,6 @@
"context"
"errors"
"io"
- "io/ioutil"
"net/http"
"net/url"
"strconv"
@@ -136,7 +135,7 @@
// reused. Errors are being ignored here because if the connection is invalid
// and this fails for some reason, the Close() method will do any remaining
// cleanup.
- io.Copy(ioutil.Discard, p.response.Body)
+ io.Copy(io.Discard, p.response.Body)
err = p.response.Body.Close()
}
diff --git a/lib/go/thrift/http_transport.go b/lib/go/thrift/http_transport.go
index bc69227..c84aba9 100644
--- a/lib/go/thrift/http_transport.go
+++ b/lib/go/thrift/http_transport.go
@@ -24,7 +24,6 @@
"io"
"net/http"
"strings"
- "sync"
)
// NewThriftHandlerFunc is a function that create a ready to use Apache Thrift Handler function
@@ -41,11 +40,9 @@
// gz transparently compresses the HTTP response if the client supports it.
func gz(handler http.HandlerFunc) http.HandlerFunc {
- sp := &sync.Pool{
- New: func() interface{} {
- return gzip.NewWriter(nil)
- },
- }
+ sp := newPool(func() *gzip.Writer {
+ return gzip.NewWriter(nil)
+ }, nil)
return func(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
@@ -53,11 +50,11 @@
return
}
w.Header().Set("Content-Encoding", "gzip")
- gz := sp.Get().(*gzip.Writer)
+ gz := sp.get()
gz.Reset(w)
defer func() {
- _ = gz.Close()
- sp.Put(gz)
+ gz.Close()
+ sp.put(&gz)
}()
gzw := gzipResponseWriter{Writer: gz, ResponseWriter: w}
handler(gzw, r)
diff --git a/lib/go/thrift/pool.go b/lib/go/thrift/pool.go
new file mode 100644
index 0000000..1d623d4
--- /dev/null
+++ b/lib/go/thrift/pool.go
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+ "bytes"
+ "sync"
+)
+
+// pool is a generic sync.Pool wrapper with bells and whistles.
+type pool[T any] struct {
+ pool sync.Pool
+ reset func(*T)
+}
+
+// newPool creates a new pool.
+//
+// Both generate and reset are optional.
+// Default generate is just new(T),
+// When reset is nil we don't do any additional resetting when calling get.
+func newPool[T any](generate func() *T, reset func(*T)) *pool[T] {
+ if generate == nil {
+ generate = func() *T {
+ return new(T)
+ }
+ }
+ return &pool[T]{
+ pool: sync.Pool{
+ New: func() interface{} {
+ return generate()
+ },
+ },
+ reset: reset,
+ }
+}
+
+func (p *pool[T]) get() *T {
+ r := p.pool.Get().(*T)
+ if p.reset != nil {
+ p.reset(r)
+ }
+ return r
+}
+
+func (p *pool[T]) put(r **T) {
+ p.pool.Put(*r)
+ *r = nil
+}
+
+var bufPool = newPool(nil, func(buf *bytes.Buffer) {
+ buf.Reset()
+})
diff --git a/lib/go/thrift/pool_test.go b/lib/go/thrift/pool_test.go
new file mode 100644
index 0000000..c717e1d
--- /dev/null
+++ b/lib/go/thrift/pool_test.go
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+ "testing"
+ "testing/quick"
+)
+
+type poolTest int
+
+func TestPoolReset(t *testing.T) {
+ p := newPool(nil, func(elem *poolTest) {
+ *elem = 0
+ })
+ f := func(i int) (passed bool) {
+ pt := p.get()
+ defer func() {
+ p.put(&pt)
+ if pt != nil {
+ t.Errorf("Expected pt to be nil after put, got %#v", pt)
+ passed = false
+ }
+ }()
+ if *pt != 0 {
+ t.Errorf("Expected *pt to be reset to 0 after get, got %d", *pt)
+ }
+ *pt = poolTest(i)
+ return !t.Failed()
+ }
+ if err := quick.Check(f, nil); err != nil {
+ t.Error(err)
+ }
+}
diff --git a/lib/go/thrift/protocol_test.go b/lib/go/thrift/protocol_test.go
index caac78e..d66dc65 100644
--- a/lib/go/thrift/protocol_test.go
+++ b/lib/go/thrift/protocol_test.go
@@ -22,7 +22,7 @@
import (
"bytes"
"context"
- "io/ioutil"
+ "io"
"math"
"net"
"net/http"
@@ -60,7 +60,7 @@
type HTTPHeaderEchoServer struct{}
func (p *HTTPEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
- buf, err := ioutil.ReadAll(req.Body)
+ buf, err := io.ReadAll(req.Body)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write(buf)
@@ -71,7 +71,7 @@
}
func (p *HTTPHeaderEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
- buf, err := ioutil.ReadAll(req.Body)
+ buf, err := io.ReadAll(req.Body)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write(buf)
diff --git a/lib/go/thrift/serializer.go b/lib/go/thrift/serializer.go
index f4d9201..53a674e 100644
--- a/lib/go/thrift/serializer.go
+++ b/lib/go/thrift/serializer.go
@@ -21,7 +21,6 @@
import (
"context"
- "sync"
)
type TSerializer struct {
@@ -92,7 +91,7 @@
// It must be initialized with either NewTSerializerPool or
// NewTSerializerPoolSizeFactory.
type TSerializerPool struct {
- pool sync.Pool
+ pool *pool[TSerializer]
}
// NewTSerializerPool creates a new TSerializerPool.
@@ -100,11 +99,7 @@
// NewTSerializer can be used as the arg here.
func NewTSerializerPool(f func() *TSerializer) *TSerializerPool {
return &TSerializerPool{
- pool: sync.Pool{
- New: func() interface{} {
- return f()
- },
- },
+ pool: newPool(f, nil),
}
}
@@ -115,28 +110,26 @@
// larger than that. It just dictates the initial size.
func NewTSerializerPoolSizeFactory(size int, factory TProtocolFactory) *TSerializerPool {
return &TSerializerPool{
- pool: sync.Pool{
- New: func() interface{} {
- transport := NewTMemoryBufferLen(size)
- protocol := factory.GetProtocol(transport)
+ pool: newPool(func() *TSerializer {
+ transport := NewTMemoryBufferLen(size)
+ protocol := factory.GetProtocol(transport)
- return &TSerializer{
- Transport: transport,
- Protocol: protocol,
- }
- },
- },
+ return &TSerializer{
+ Transport: transport,
+ Protocol: protocol,
+ }
+ }, nil),
}
}
func (t *TSerializerPool) WriteString(ctx context.Context, msg TStruct) (string, error) {
- s := t.pool.Get().(*TSerializer)
- defer t.pool.Put(s)
+ s := t.pool.get()
+ defer t.pool.put(&s)
return s.WriteString(ctx, msg)
}
func (t *TSerializerPool) Write(ctx context.Context, msg TStruct) ([]byte, error) {
- s := t.pool.Get().(*TSerializer)
- defer t.pool.Put(s)
+ s := t.pool.get()
+ defer t.pool.put(&s)
return s.Write(ctx, msg)
}
diff --git a/lib/go/thrift/simple_server_test.go b/lib/go/thrift/simple_server_test.go
index e0cf151..f3a59ee 100644
--- a/lib/go/thrift/simple_server_test.go
+++ b/lib/go/thrift/simple_server_test.go
@@ -201,11 +201,11 @@
netConn, err := net.Dial("tcp", ln.Addr().String())
if err != nil || netConn == nil {
- t.Fatal("error when dial server")
+ t.Fatalf("error when dial server: %v", err)
}
time.Sleep(networkWaitDuration)
- serverStopTimeout := 50 * time.Millisecond
+ const serverStopTimeout = 50 * time.Millisecond
backupServerStopTimeout := ServerStopTimeout
t.Cleanup(func() {
ServerStopTimeout = backupServerStopTimeout
@@ -213,13 +213,12 @@
ServerStopTimeout = serverStopTimeout
st := time.Now()
- err = serv.Stop()
- if err != nil {
+ if err := serv.Stop(); err != nil {
t.Errorf("error when stop server:%v", err)
}
if elapsed := time.Since(st); elapsed < serverStopTimeout {
- t.Errorf("stop cost less time than server stop timeout, server stop timeout:%v,cost time:%v", ServerStopTimeout, elapsed)
+ t.Errorf("stop cost less time than server stop timeout, server stop timeout:%v,cost time:%v", serverStopTimeout, elapsed)
}
}