THRIFT-5152: introduce connect timeout and socket timeout
Client: Go
Patch: Qian Lv
This closes #2071
diff --git a/lib/go/test/tests/multiplexed_protocol_test.go b/lib/go/test/tests/multiplexed_protocol_test.go
index 61ac628..4fb6f4f 100644
--- a/lib/go/test/tests/multiplexed_protocol_test.go
+++ b/lib/go/test/tests/multiplexed_protocol_test.go
@@ -50,7 +50,7 @@
}
func createTransport(addr net.Addr) (thrift.TTransport, error) {
- socket := thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT)
+ socket := thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT, TIMEOUT)
transport := thrift.NewTFramedTransport(socket)
err := transport.Open()
if err != nil {
diff --git a/lib/go/test/tests/one_way_test.go b/lib/go/test/tests/one_way_test.go
index 48d0bbe..010e3bb 100644
--- a/lib/go/test/tests/one_way_test.go
+++ b/lib/go/test/tests/one_way_test.go
@@ -65,7 +65,7 @@
}
func TestInitOnewayClient(t *testing.T) {
- transport := thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT)
+ transport := thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT, TIMEOUT)
protocol := thrift.NewTBinaryProtocolTransport(transport)
client = onewaytest.NewOneWayClient(thrift.NewTStandardClient(protocol, protocol))
err := transport.Open()
diff --git a/lib/go/test/tests/protocols_test.go b/lib/go/test/tests/protocols_test.go
index cffd9c3..9030e9d 100644
--- a/lib/go/test/tests/protocols_test.go
+++ b/lib/go/test/tests/protocols_test.go
@@ -41,7 +41,7 @@
go server.Serve()
// client
- var transport thrift.TTransport = thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT)
+ var transport thrift.TTransport = thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT, TIMEOUT)
transport, err = transportFactory.GetTransport(transport)
if err != nil {
t.Fatal(err)
diff --git a/lib/go/thrift/socket.go b/lib/go/thrift/socket.go
index 88b98f5..558818a 100644
--- a/lib/go/thrift/socket.go
+++ b/lib/go/thrift/socket.go
@@ -26,9 +26,10 @@
)
type TSocket struct {
- conn net.Conn
- addr net.Addr
- timeout time.Duration
+ conn net.Conn
+ addr net.Addr
+ connectTimeout time.Duration
+ socketTimeout time.Duration
}
// NewTSocket creates a net.Conn-backed TTransport, given a host and port
@@ -36,40 +37,46 @@
// Example:
// trans, err := thrift.NewTSocket("localhost:9090")
func NewTSocket(hostPort string) (*TSocket, error) {
- return NewTSocketTimeout(hostPort, 0)
+ return NewTSocketTimeout(hostPort, 0, 0)
}
// NewTSocketTimeout creates a net.Conn-backed TTransport, given a host and port
// it also accepts a timeout as a time.Duration
-func NewTSocketTimeout(hostPort string, timeout time.Duration) (*TSocket, error) {
+func NewTSocketTimeout(hostPort string, connTimeout time.Duration, soTimeout time.Duration) (*TSocket, error) {
//conn, err := net.DialTimeout(network, address, timeout)
addr, err := net.ResolveTCPAddr("tcp", hostPort)
if err != nil {
return nil, err
}
- return NewTSocketFromAddrTimeout(addr, timeout), nil
+ return NewTSocketFromAddrTimeout(addr, connTimeout, soTimeout), nil
}
// Creates a TSocket from a net.Addr
-func NewTSocketFromAddrTimeout(addr net.Addr, timeout time.Duration) *TSocket {
- return &TSocket{addr: addr, timeout: timeout}
+func NewTSocketFromAddrTimeout(addr net.Addr, connTimeout time.Duration, soTimeout time.Duration) *TSocket {
+ return &TSocket{addr: addr, connectTimeout: connTimeout, socketTimeout: soTimeout}
}
// Creates a TSocket from an existing net.Conn
-func NewTSocketFromConnTimeout(conn net.Conn, timeout time.Duration) *TSocket {
- return &TSocket{conn: conn, addr: conn.RemoteAddr(), timeout: timeout}
+func NewTSocketFromConnTimeout(conn net.Conn, connTimeout time.Duration) *TSocket {
+ return &TSocket{conn: conn, addr: conn.RemoteAddr(), connectTimeout: connTimeout, socketTimeout: connTimeout}
+}
+
+// Sets the connect timeout
+func (p *TSocket) SetConnTimeout(timeout time.Duration) error {
+ p.connectTimeout = timeout
+ return nil
}
// Sets the socket timeout
-func (p *TSocket) SetTimeout(timeout time.Duration) error {
- p.timeout = timeout
+func (p *TSocket) SetSocketTimeout(timeout time.Duration) error {
+ p.socketTimeout = timeout
return nil
}
func (p *TSocket) pushDeadline(read, write bool) {
var t time.Time
- if p.timeout > 0 {
- t = time.Now().Add(time.Duration(p.timeout))
+ if p.socketTimeout > 0 {
+ t = time.Now().Add(time.Duration(p.socketTimeout))
}
if read && write {
p.conn.SetDeadline(t)
@@ -95,7 +102,7 @@
return NewTTransportException(NOT_OPEN, "Cannot open bad address.")
}
var err error
- if p.conn, err = net.DialTimeout(p.addr.Network(), p.addr.String(), p.timeout); err != nil {
+ if p.conn, err = net.DialTimeout(p.addr.Network(), p.addr.String(), p.connectTimeout); err != nil {
return NewTTransportException(NOT_OPEN, err.Error())
}
return nil