THRIFT-5214: Connectivity check on Go's TSocket
Client: go
Implement a connectivity check in Go's TSocket and TSSLSocket for
non-Windows systems.
The implementation is inspired by
https://github.blog/2020-05-20-three-bugs-in-the-go-mysql-driver/
diff --git a/lib/go/thrift/ssl_socket.go b/lib/go/thrift/ssl_socket.go
index 45bf38a..661111c 100644
--- a/lib/go/thrift/ssl_socket.go
+++ b/lib/go/thrift/ssl_socket.go
@@ -27,7 +27,7 @@
)
type TSSLSocket struct {
- conn net.Conn
+ conn *socketConn
// hostPort contains host:port (e.g. "asdf.com:12345"). The field is
// only valid if addr is nil.
hostPort string
@@ -62,7 +62,7 @@
// Creates a TSSLSocket from an existing net.Conn
func NewTSSLSocketFromConnTimeout(conn net.Conn, cfg *tls.Config, timeout time.Duration) *TSSLSocket {
- return &TSSLSocket{conn: conn, addr: conn.RemoteAddr(), timeout: timeout, cfg: cfg}
+ return &TSSLSocket{conn: wrapSocketConn(conn), addr: conn.RemoteAddr(), timeout: timeout, cfg: cfg}
}
// Sets the socket timeout
@@ -91,12 +91,18 @@
// If we have a hostname, we need to pass the hostname to tls.Dial for
// certificate hostname checks.
if p.hostPort != "" {
- if p.conn, err = tls.DialWithDialer(&net.Dialer{
- Timeout: p.timeout}, "tcp", p.hostPort, p.cfg); err != nil {
+ if p.conn, err = createSocketConnFromReturn(tls.DialWithDialer(
+ &net.Dialer{
+ Timeout: p.timeout,
+ },
+ "tcp",
+ p.hostPort,
+ p.cfg,
+ )); err != nil {
return NewTTransportException(NOT_OPEN, err.Error())
}
} else {
- if p.IsOpen() {
+ if p.conn.isValid() {
return NewTTransportException(ALREADY_OPEN, "Socket already connected.")
}
if p.addr == nil {
@@ -108,8 +114,14 @@
if len(p.addr.String()) == 0 {
return NewTTransportException(NOT_OPEN, "Cannot open bad address.")
}
- if p.conn, err = tls.DialWithDialer(&net.Dialer{
- Timeout: p.timeout}, p.addr.Network(), p.addr.String(), p.cfg); err != nil {
+ if p.conn, err = createSocketConnFromReturn(tls.DialWithDialer(
+ &net.Dialer{
+ Timeout: p.timeout,
+ },
+ p.addr.Network(),
+ p.addr.String(),
+ p.cfg,
+ )); err != nil {
return NewTTransportException(NOT_OPEN, err.Error())
}
}
@@ -123,10 +135,7 @@
// Returns true if the connection is open
func (p *TSSLSocket) IsOpen() bool {
- if p.conn == nil {
- return false
- }
- return true
+ return p.conn.IsOpen()
}
// Closes the socket.
@@ -143,7 +152,7 @@
}
func (p *TSSLSocket) Read(buf []byte) (int, error) {
- if !p.IsOpen() {
+ if !p.conn.isValid() {
return 0, NewTTransportException(NOT_OPEN, "Connection not open")
}
p.pushDeadline(true, false)
@@ -152,7 +161,7 @@
}
func (p *TSSLSocket) Write(buf []byte) (int, error) {
- if !p.IsOpen() {
+ if !p.conn.isValid() {
return 0, NewTTransportException(NOT_OPEN, "Connection not open")
}
p.pushDeadline(false, true)
@@ -164,7 +173,7 @@
}
func (p *TSSLSocket) Interrupt() error {
- if !p.IsOpen() {
+ if !p.conn.isValid() {
return nil
}
return p.conn.Close()