go: Make socketConn.Close thread-safe
Client: go
We used to rely on setting the connection inside TSocket/TSSLSocket as
nil after Close is called to mark the connection as closed, but that is
not thread safe and causing TSocket.Close/TSSLSocket.Close cannot be
called concurrently. Use an atomic int to mark closure instead.
diff --git a/lib/go/thrift/socket.go b/lib/go/thrift/socket.go
index eeac4f1..cba7c0f 100644
--- a/lib/go/thrift/socket.go
+++ b/lib/go/thrift/socket.go
@@ -194,15 +194,7 @@
// Closes the socket.
func (p *TSocket) Close() error {
- // Close the socket
- if p.conn != nil {
- err := p.conn.Close()
- if err != nil {
- return err
- }
- p.conn = nil
- }
- return nil
+ return p.conn.Close()
}
//Returns the remote address of the socket.
diff --git a/lib/go/thrift/socket_conn.go b/lib/go/thrift/socket_conn.go
index c1cc30c..5619d96 100644
--- a/lib/go/thrift/socket_conn.go
+++ b/lib/go/thrift/socket_conn.go
@@ -21,6 +21,7 @@
import (
"net"
+ "sync/atomic"
)
// socketConn is a wrapped net.Conn that tries to do connectivity check.
@@ -28,6 +29,7 @@
net.Conn
buffer [1]byte
+ closed int32
}
var _ net.Conn = (*socketConn)(nil)
@@ -64,7 +66,7 @@
// It's the same as the previous implementation of TSocket.IsOpen and
// TSSLSocket.IsOpen before we added connectivity check.
func (sc *socketConn) isValid() bool {
- return sc != nil && sc.Conn != nil
+ return sc != nil && sc.Conn != nil && atomic.LoadInt32(&sc.closed) == 0
}
// IsOpen checks whether the connection is open.
@@ -100,3 +102,12 @@
return sc.Conn.Read(p)
}
+
+func (sc *socketConn) Close() error {
+ if !sc.isValid() {
+ // Already closed
+ return net.ErrClosed
+ }
+ atomic.StoreInt32(&sc.closed, 1)
+ return sc.Conn.Close()
+}
diff --git a/lib/go/thrift/ssl_socket.go b/lib/go/thrift/ssl_socket.go
index bee1097..d7ba415 100644
--- a/lib/go/thrift/ssl_socket.go
+++ b/lib/go/thrift/ssl_socket.go
@@ -220,15 +220,7 @@
// Closes the socket.
func (p *TSSLSocket) Close() error {
- // Close the socket
- if p.conn != nil {
- err := p.conn.Close()
- if err != nil {
- return err
- }
- p.conn = nil
- }
- return nil
+ return p.conn.Close()
}
func (p *TSSLSocket) Read(buf []byte) (int, error) {