THRIFT-5214: Connectivity check on go's TSocket

Client: go

Implement a connectivity check on Go's TSocket and TSSLSocket for
non-Windows systems.

The implementation is inspired by
https://github.blog/2020-05-20-three-bugs-in-the-go-mysql-driver/
diff --git a/lib/go/thrift/socket.go b/lib/go/thrift/socket.go
index 558818a..7c765f5 100644
--- a/lib/go/thrift/socket.go
+++ b/lib/go/thrift/socket.go
@@ -26,7 +26,7 @@
 )
 
 type TSocket struct {
-	conn           net.Conn
+	conn           *socketConn
 	addr           net.Addr
 	connectTimeout time.Duration
 	socketTimeout  time.Duration
@@ -58,7 +58,7 @@
 
 // Creates a TSocket from an existing net.Conn
 func NewTSocketFromConnTimeout(conn net.Conn, connTimeout time.Duration) *TSocket {
-	return &TSocket{conn: conn, addr: conn.RemoteAddr(), connectTimeout: connTimeout, socketTimeout: connTimeout}
+	return &TSocket{conn: wrapSocketConn(conn), addr: conn.RemoteAddr(), connectTimeout: connTimeout, socketTimeout: connTimeout}
 }
 
 // Sets the connect timeout
@@ -89,7 +89,7 @@
 
 // Connects the socket, creating a new socket object if necessary.
 func (p *TSocket) Open() error {
-	if p.IsOpen() {
+	if p.conn.isValid() {
 		return NewTTransportException(ALREADY_OPEN, "Socket already connected.")
 	}
 	if p.addr == nil {
@@ -102,7 +102,11 @@
 		return NewTTransportException(NOT_OPEN, "Cannot open bad address.")
 	}
 	var err error
-	if p.conn, err = net.DialTimeout(p.addr.Network(), p.addr.String(), p.connectTimeout); err != nil {
+	if p.conn, err = createSocketConnFromReturn(net.DialTimeout(
+		p.addr.Network(),
+		p.addr.String(),
+		p.connectTimeout,
+	)); err != nil {
 		return NewTTransportException(NOT_OPEN, err.Error())
 	}
 	return nil
@@ -115,10 +119,7 @@
 
 // Returns true if the connection is open
 func (p *TSocket) IsOpen() bool {
-	if p.conn == nil {
-		return false
-	}
-	return true
+	return p.conn.IsOpen()
 }
 
 // Closes the socket.
@@ -140,7 +141,7 @@
 }
 
 func (p *TSocket) Read(buf []byte) (int, error) {
-	if !p.IsOpen() {
+	if !p.conn.isValid() {
 		return 0, NewTTransportException(NOT_OPEN, "Connection not open")
 	}
 	p.pushDeadline(true, false)
@@ -149,7 +150,7 @@
 }
 
 func (p *TSocket) Write(buf []byte) (int, error) {
-	if !p.IsOpen() {
+	if !p.conn.isValid() {
 		return 0, NewTTransportException(NOT_OPEN, "Connection not open")
 	}
 	p.pushDeadline(false, true)
@@ -161,7 +162,7 @@
 }
 
 func (p *TSocket) Interrupt() error {
-	if !p.IsOpen() {
+	if !p.conn.isValid() {
 		return nil
 	}
 	return p.conn.Close()
diff --git a/lib/go/thrift/socket_conn.go b/lib/go/thrift/socket_conn.go
new file mode 100644
index 0000000..b0f7b3e
--- /dev/null
+++ b/lib/go/thrift/socket_conn.go
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+	"bytes"
+	"io"
+	"net"
+)
+
+// socketConn is a wrapped net.Conn that tries to do a connectivity check.
+type socketConn struct {
+	net.Conn
+
+	buf bytes.Buffer
+}
+
+var _ net.Conn = (*socketConn)(nil)
+
+// createSocketConnFromReturn is syntactic sugar to help create socketConn from
+// return values of functions like net.Dial, tls.Dial, net.Listener.Accept, etc.
+func createSocketConnFromReturn(conn net.Conn, err error) (*socketConn, error) {
+	if err != nil {
+		return nil, err
+	}
+	return &socketConn{
+		Conn: conn,
+	}, nil
+}
+
+// wrapSocketConn wraps an existing net.Conn into *socketConn.
+func wrapSocketConn(conn net.Conn) *socketConn {
+	// In case conn is already wrapped,
+	// return it as-is and avoid double wrapping.
+	if sc, ok := conn.(*socketConn); ok {
+		return sc
+	}
+
+	return &socketConn{
+		Conn: conn,
+	}
+}
+
+// isValid checks whether there's a valid connection.
+//
+// It's nil safe, and returns false if sc itself is nil, or if the underlying
+// connection is nil.
+//
+// It's the same as the previous implementation of TSocket.IsOpen and
+// TSSLSocket.IsOpen before we added the connectivity check.
+func (sc *socketConn) isValid() bool {
+	return sc != nil && sc.Conn != nil
+}
+
+// IsOpen checks whether the connection is open.
+//
+// It's nil safe, and returns false if sc itself is nil, or if the underlying
+// connection is nil.
+//
+// Otherwise, it tries to do a connectivity check and returns the result.
+func (sc *socketConn) IsOpen() bool {
+	if !sc.isValid() {
+		return false
+	}
+	return sc.checkConn() == nil
+}
+
+// Read implements io.Reader.
+//
+// On Windows, it behaves the same as the underlying net.Conn.Read.
+//
+// On non-Windows, it treats len(p) == 0 as a connectivity check instead of a
+// readability check. That is, instead of blocking until there's something to
+// read (readability check), or always returning (0, nil) (the default
+// behavior of Go's stdlib implementation on non-Windows), it never blocks, and
+// returns an error if the connection is lost.
+func (sc *socketConn) Read(p []byte) (n int, err error) {
+	if len(p) == 0 {
+		return 0, sc.read0()
+	}
+
+	n, err = sc.buf.Read(p)
+	if err != nil && err != io.EOF {
+		return
+	}
+	if n == len(p) {
+		return n, nil
+	}
+	// Continue reading from the wire.
+	var newRead int
+	newRead, err = sc.Conn.Read(p[n:])
+	n += newRead
+	return
+}
diff --git a/lib/go/thrift/socket_conn_test.go b/lib/go/thrift/socket_conn_test.go
new file mode 100644
index 0000000..ab92462
--- /dev/null
+++ b/lib/go/thrift/socket_conn_test.go
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+	"io"
+	"net"
+	"strings"
+	"testing"
+	"time"
+)
+
+type serverSocketConnCallback func(testing.TB, *socketConn)
+
+func serverSocketConn(tb testing.TB, f serverSocketConnCallback) (net.Listener, error) {
+	tb.Helper()
+
+	ln, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		return nil, err
+	}
+	go func() {
+		for {
+			sc, err := createSocketConnFromReturn(ln.Accept())
+			if err != nil {
+				// This is usually caused by Listener being
+				// closed, not really an error.
+				return
+			}
+			go f(tb, sc)
+		}
+	}()
+	return ln, nil
+}
+
+func writeFully(tb testing.TB, w io.Writer, s string) bool {
+	tb.Helper()
+
+	n, err := io.Copy(w, strings.NewReader(s))
+	if err != nil {
+		tb.Errorf("Failed to write %q: %v", s, err)
+		return false
+	}
+	if int(n) < len(s) {
+		tb.Errorf("Only wrote %d out of %q", n, s)
+		return false
+	}
+	return true
+}
+
+func TestSocketConn(t *testing.T) {
+	const (
+		interval = time.Millisecond * 10
+		first    = "hello"
+		second   = "world"
+	)
+
+	ln, err := serverSocketConn(
+		t,
+		func(tb testing.TB, sc *socketConn) {
+			defer sc.Close()
+
+			if !writeFully(tb, sc, first) {
+				return
+			}
+			time.Sleep(interval)
+			writeFully(tb, sc, second)
+		},
+	)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer ln.Close()
+
+	sc, err := createSocketConnFromReturn(net.Dial("tcp", ln.Addr().String()))
+	if err != nil {
+		t.Fatal(err)
+	}
+	buf := make([]byte, 1024)
+
+	n, err := sc.Read(buf)
+	if err != nil {
+		t.Fatal(err)
+	}
+	read := string(buf[:n])
+	if read != first {
+		t.Errorf("Expected read %q, got %q", first, read)
+	}
+
+	n, err = sc.Read(buf)
+	if err != nil {
+		t.Fatal(err)
+	}
+	read = string(buf[:n])
+	if read != second {
+		t.Errorf("Expected read %q, got %q", second, read)
+	}
+}
+
+func TestSocketConnNilSafe(t *testing.T) {
+	sc := (*socketConn)(nil)
+	if sc.isValid() {
+		t.Error("Expected false for nil.isValid(), got true")
+	}
+	if sc.IsOpen() {
+		t.Error("Expected false for nil.IsOpen(), got true")
+	}
+}
diff --git a/lib/go/thrift/socket_unix_conn.go b/lib/go/thrift/socket_unix_conn.go
new file mode 100644
index 0000000..f18e0e6
--- /dev/null
+++ b/lib/go/thrift/socket_unix_conn.go
@@ -0,0 +1,73 @@
+// +build !windows
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+	"io"
+	"syscall"
+)
+
+func (sc *socketConn) read0() error {
+	return sc.checkConn()
+}
+
+func (sc *socketConn) checkConn() error {
+	syscallConn, ok := sc.Conn.(syscall.Conn)
+	if !ok {
+		// No way to check, return nil
+		return nil
+	}
+	rc, err := syscallConn.SyscallConn()
+	if err != nil {
+		return err
+	}
+
+	var n int
+	var buf [1]byte
+
+	if readErr := rc.Read(func(fd uintptr) bool {
+		n, err = syscall.Read(int(fd), buf[:])
+		return true
+	}); readErr != nil {
+		return readErr
+	}
+
+	if err == syscall.EAGAIN || err == syscall.EWOULDBLOCK {
+		// This means the connection is still open but we don't have
+		// anything to read right now.
+		return nil
+	}
+
+	if n > 0 {
+		// We got 1 byte,
+		// put it into sc's buf for the next real read to use.
+		sc.buf.Write(buf[:])
+		return nil
+	}
+
+	if err != nil {
+		return err
+	}
+
+	// At this point, it means the other side already closed the connection.
+	return io.EOF
+}
diff --git a/lib/go/thrift/socket_unix_conn_test.go b/lib/go/thrift/socket_unix_conn_test.go
new file mode 100644
index 0000000..3563a25
--- /dev/null
+++ b/lib/go/thrift/socket_unix_conn_test.go
@@ -0,0 +1,105 @@
+// +build !windows
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+	"io"
+	"net"
+	"testing"
+	"time"
+)
+
+func TestSocketConnUnix(t *testing.T) {
+	const (
+		interval = time.Millisecond * 10
+		first    = "hello"
+		second   = "world"
+	)
+
+	ln, err := serverSocketConn(
+		t,
+		func(tb testing.TB, sc *socketConn) {
+			defer sc.Close()
+
+			time.Sleep(interval)
+			if !writeFully(tb, sc, first) {
+				return
+			}
+			time.Sleep(interval)
+			writeFully(tb, sc, second)
+		},
+	)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer ln.Close()
+
+	sc, err := createSocketConnFromReturn(net.Dial("tcp", ln.Addr().String()))
+	if err != nil {
+		t.Fatal(err)
+	}
+	buf := make([]byte, 1024)
+
+	if !sc.IsOpen() {
+		t.Error("Expected sc to report open, got false")
+	}
+	n, err := sc.Read(buf)
+	if err != nil {
+		t.Fatal(err)
+	}
+	read := string(buf[:n])
+	if read != first {
+		t.Errorf("Expected read %q, got %q", first, read)
+	}
+
+	if !sc.IsOpen() {
+		t.Error("Expected sc to report open, got false")
+	}
+	// Do the connection check twice more after the server has already written
+	// new data, to make sure we correctly buffered the read bytes.
+	time.Sleep(interval * 10)
+	if !sc.IsOpen() {
+		t.Error("Expected sc to report open, got false")
+	}
+	if !sc.IsOpen() {
+		t.Error("Expected sc to report open, got false")
+	}
+	if sc.buf.Len() == 0 {
+		t.Error("Expected sc to buffer read bytes, got empty buffer")
+	}
+	n, err = sc.Read(buf)
+	if err != nil {
+		t.Fatal(err)
+	}
+	read = string(buf[:n])
+	if read != second {
+		t.Errorf("Expected read %q, got %q", second, read)
+	}
+
+	// Now it's supposed to be closed on the server side
+	if err := sc.read0(); err != io.EOF {
+		t.Errorf("Expected to get EOF on read0, got %v", err)
+	}
+	if sc.IsOpen() {
+		t.Error("Expected sc to report not open, got true")
+	}
+}
diff --git a/lib/go/thrift/socket_windows_conn.go b/lib/go/thrift/socket_windows_conn.go
new file mode 100644
index 0000000..679838c
--- /dev/null
+++ b/lib/go/thrift/socket_windows_conn.go
@@ -0,0 +1,34 @@
+// +build windows
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+func (sc *socketConn) read0() error {
+	// On Windows, we fall back to the default behavior of reading 0 bytes.
+	var p []byte
+	_, err := sc.Conn.Read(p)
+	return err
+}
+
+func (sc *socketConn) checkConn() error {
+	// On Windows, we always return nil for this check.
+	return nil
+}
diff --git a/lib/go/thrift/ssl_socket.go b/lib/go/thrift/ssl_socket.go
index 45bf38a..661111c 100644
--- a/lib/go/thrift/ssl_socket.go
+++ b/lib/go/thrift/ssl_socket.go
@@ -27,7 +27,7 @@
 )
 
 type TSSLSocket struct {
-	conn net.Conn
+	conn *socketConn
 	// hostPort contains host:port (e.g. "asdf.com:12345"). The field is
 	// only valid if addr is nil.
 	hostPort string
@@ -62,7 +62,7 @@
 
 // Creates a TSSLSocket from an existing net.Conn
 func NewTSSLSocketFromConnTimeout(conn net.Conn, cfg *tls.Config, timeout time.Duration) *TSSLSocket {
-	return &TSSLSocket{conn: conn, addr: conn.RemoteAddr(), timeout: timeout, cfg: cfg}
+	return &TSSLSocket{conn: wrapSocketConn(conn), addr: conn.RemoteAddr(), timeout: timeout, cfg: cfg}
 }
 
 // Sets the socket timeout
@@ -91,12 +91,18 @@
 	// If we have a hostname, we need to pass the hostname to tls.Dial for
 	// certificate hostname checks.
 	if p.hostPort != "" {
-		if p.conn, err = tls.DialWithDialer(&net.Dialer{
-			Timeout: p.timeout}, "tcp", p.hostPort, p.cfg); err != nil {
+		if p.conn, err = createSocketConnFromReturn(tls.DialWithDialer(
+			&net.Dialer{
+				Timeout: p.timeout,
+			},
+			"tcp",
+			p.hostPort,
+			p.cfg,
+		)); err != nil {
 			return NewTTransportException(NOT_OPEN, err.Error())
 		}
 	} else {
-		if p.IsOpen() {
+		if p.conn.isValid() {
 			return NewTTransportException(ALREADY_OPEN, "Socket already connected.")
 		}
 		if p.addr == nil {
@@ -108,8 +114,14 @@
 		if len(p.addr.String()) == 0 {
 			return NewTTransportException(NOT_OPEN, "Cannot open bad address.")
 		}
-		if p.conn, err = tls.DialWithDialer(&net.Dialer{
-			Timeout: p.timeout}, p.addr.Network(), p.addr.String(), p.cfg); err != nil {
+		if p.conn, err = createSocketConnFromReturn(tls.DialWithDialer(
+			&net.Dialer{
+				Timeout: p.timeout,
+			},
+			p.addr.Network(),
+			p.addr.String(),
+			p.cfg,
+		)); err != nil {
 			return NewTTransportException(NOT_OPEN, err.Error())
 		}
 	}
@@ -123,10 +135,7 @@
 
 // Returns true if the connection is open
 func (p *TSSLSocket) IsOpen() bool {
-	if p.conn == nil {
-		return false
-	}
-	return true
+	return p.conn.IsOpen()
 }
 
 // Closes the socket.
@@ -143,7 +152,7 @@
 }
 
 func (p *TSSLSocket) Read(buf []byte) (int, error) {
-	if !p.IsOpen() {
+	if !p.conn.isValid() {
 		return 0, NewTTransportException(NOT_OPEN, "Connection not open")
 	}
 	p.pushDeadline(true, false)
@@ -152,7 +161,7 @@
 }
 
 func (p *TSSLSocket) Write(buf []byte) (int, error) {
-	if !p.IsOpen() {
+	if !p.conn.isValid() {
 		return 0, NewTTransportException(NOT_OPEN, "Connection not open")
 	}
 	p.pushDeadline(false, true)
@@ -164,7 +173,7 @@
 }
 
 func (p *TSSLSocket) Interrupt() error {
-	if !p.IsOpen() {
+	if !p.conn.isValid() {
 		return nil
 	}
 	return p.conn.Close()