THRIFT-5495: close client when shutdown server in go lib
Client: go
diff --git a/lib/go/README.md b/lib/go/README.md
index 75d7174..b2cf1df 100644
--- a/lib/go/README.md
+++ b/lib/go/README.md
@@ -132,3 +132,27 @@
excessive cpu overhead.
This feature is also only enabled on non-oneway endpoints.
+
+A note about server stop implementations
+========================================
+
+[TSimpleServer.Stop](https://pkg.go.dev/github.com/apache/thrift/lib/go/thrift#TSimpleServer.Stop) waits for all client connections to be closed, which
+only happens after each client's last received request has been handled.
+As a result, the time spent by Stop may sometimes be too long:
+* When the socket timeout is not set, the server might hang until all active
+  clients finish handling their last received request.
+* When the socket timeout is too long (e.g. one hour), the server will
+  hang for that duration until all active clients finish handling their
+  last received request.
+
+To prevent Stop from hanging for too long, you can set
+thrift.ServerStopTimeout in your main or init function:
+
+    thrift.ServerStopTimeout = <max_duration_to_stop>
+
+If it's set to <=0, the feature will be disabled (by default), and the server
+will wait for all the client connections to be closed gracefully, without
+a time limit and without forcibly closing any of them. Otherwise, Stop will wait for all the client
+connections to be closed gracefully until thrift.ServerStopTimeout is
+reached, and client connections that are not closed after thrift.ServerStopTimeout
+will be closed abruptly, which may cause some client errors.
\ No newline at end of file
diff --git a/lib/go/thrift/simple_server.go b/lib/go/thrift/simple_server.go
index 02863ec..1cfc375 100644
--- a/lib/go/thrift/simple_server.go
+++ b/lib/go/thrift/simple_server.go
@@ -20,6 +20,7 @@
package thrift
import (
+ "context"
"errors"
"fmt"
"io"
@@ -48,15 +49,26 @@
// If it's changed to <=0, the feature will be disabled.
var ServerConnectivityCheckInterval = time.Millisecond * 5
+// ServerStopTimeout defines the max duration TSimpleServer.Stop waits for all
+// client connections to be closed gracefully, so that Stop does not hang for too long.
+//
+// It's defined as a variable instead of constant, so that thrift server
+// implementations can change its value to control the behavior.
+//
+// If it's set to <=0, the feature will be disabled (by default), and the server will wait
+// for all the client connections to be closed gracefully.
+var ServerStopTimeout = time.Duration(0)
+
/*
* This is not a typical TSimpleServer as it is not blocked after accept a socket.
* It is more like a TThreadedServer that can handle different connections in different goroutines.
* This will work if golang user implements a conn-pool like thing in client side.
*/
type TSimpleServer struct {
- closed int32
- wg sync.WaitGroup
- mu sync.Mutex
+ closed int32
+ wg sync.WaitGroup
+ mu sync.Mutex
+ stopChan chan struct{}
processorFactory TProcessorFactory
serverTransport TServerTransport
@@ -121,6 +133,7 @@
outputTransportFactory: outputTransportFactory,
inputProtocolFactory: inputProtocolFactory,
outputProtocolFactory: outputProtocolFactory,
+ stopChan: make(chan struct{}),
}
}
@@ -192,13 +205,27 @@
return 0, err
}
if client != nil {
- p.wg.Add(1)
+ ctx, cancel := context.WithCancel(context.Background())
+ p.wg.Add(2)
+
go func() {
defer p.wg.Done()
+ defer cancel()
if err := p.processRequests(client); err != nil {
p.logger(fmt.Sprintf("error processing request: %v", err))
}
}()
+
+ go func() {
+ defer p.wg.Done()
+ select {
+ case <-ctx.Done():
+ // client exited, do nothing
+ case <-p.stopChan:
+ // TSimpleServer.Stop closed stopChan, close the client connection
+ client.Close()
+ }
+ }()
}
return 0, nil
}
@@ -229,12 +256,31 @@
func (p *TSimpleServer) Stop() error {
p.mu.Lock()
defer p.mu.Unlock()
+ // A repeated call is a no-op: once closed is set, Stop returns nil without waiting again.
if atomic.LoadInt32(&p.closed) != 0 {
return nil
}
atomic.StoreInt32(&p.closed, 1)
p.serverTransport.Interrupt()
- p.wg.Wait()
+ // Run p.wg.Wait in its own goroutine; ctx is cancelled as soon as that wait returns.
+ ctx, cancel := context.WithCancel(context.Background())
+ go func() {
+ defer cancel()
+ p.wg.Wait()
+ }()
+ // With a positive ServerStopTimeout, wait for the timer or for ctx, whichever comes first.
+ if ServerStopTimeout > 0 {
+ timer := time.NewTimer(ServerStopTimeout)
+ select {
+ case <-timer.C:
+ case <-ctx.Done():
+ }
+ close(p.stopChan) // closed on either select outcome; lets goroutines blocked on stopChan close their client
+ timer.Stop()
+ }
+ // Block until p.wg.Wait has returned, then replace stopChan with a fresh, open channel.
+ <-ctx.Done()
+ p.stopChan = make(chan struct{})
return nil
}
diff --git a/lib/go/thrift/simple_server_test.go b/lib/go/thrift/simple_server_test.go
index 58149a8..b92d50f 100644
--- a/lib/go/thrift/simple_server_test.go
+++ b/lib/go/thrift/simple_server_test.go
@@ -20,11 +20,17 @@
package thrift
import (
- "testing"
+ "context"
"errors"
+ "net"
"runtime"
+ "sync"
+ "testing"
+ "time"
)
+const networkWaitDuration = 10 * time.Millisecond // sleep used by the tests after starting Serve and after dialing
+
type mockServerTransport struct {
ListenFunc func() error
AcceptFunc func() (TTransport, error)
@@ -154,3 +160,130 @@
runtime.Gosched()
serv.Stop()
}
+ // TestNoHangDuringStopFromClientNoDataSendDuringAcceptLoop connects a client that never sends data, sets ServerStopTimeout to 50ms, and checks that Stop returns without error after at least that long.
+func TestNoHangDuringStopFromClientNoDataSendDuringAcceptLoop(t *testing.T) {
+ ln, err := net.Listen("tcp", "localhost:0")
+
+ if err != nil {
+ t.Fatalf("Failed to listen: %v", err)
+ }
+ // The processor only tries to read a message header from the connection.
+ proc := &mockProcessor{
+ ProcessFunc: func(in, out TProtocol) (bool, TException) {
+ in.ReadMessageBegin(context.Background())
+ return false, nil
+ },
+ }
+ // The mock transport accepts connections from ln; Interrupt closes ln.
+ trans := &mockServerTransport{
+ ListenFunc: func() error {
+ return nil
+ },
+ AcceptFunc: func() (TTransport, error) {
+ conn, err := ln.Accept()
+ if err != nil {
+ return nil, err
+ }
+
+ return NewTSocketFromConnConf(conn, nil), nil
+ },
+ CloseFunc: func() error {
+ return nil
+ },
+ InterruptFunc: func() error {
+ return ln.Close()
+ },
+ }
+
+ serv := NewTSimpleServer2(proc, trans)
+ go serv.Serve()
+ time.Sleep(networkWaitDuration)
+ // Connect a raw TCP client; the test never writes anything to netConn.
+ netConn, err := net.Dial("tcp", ln.Addr().String())
+ if err != nil || netConn == nil {
+ t.Fatal("error when dial server")
+ }
+ time.Sleep(networkWaitDuration)
+ // Override the package-level ServerStopTimeout for this test; Cleanup restores the previous value.
+ serverStopTimeout := 50 * time.Millisecond
+ backupServerStopTimeout := ServerStopTimeout
+ t.Cleanup(func() {
+ ServerStopTimeout = backupServerStopTimeout
+ })
+ ServerStopTimeout = serverStopTimeout
+
+ st := time.Now()
+ err = serv.Stop()
+ if err != nil {
+ t.Errorf("error when stop server:%v", err)
+ }
+ // Stop is expected to take at least ServerStopTimeout (presumably the nil-config socket read blocks until the client is force-closed; confirm).
+ if elapsed := time.Since(st); elapsed < serverStopTimeout {
+ t.Errorf("stop cost less time than server stop timeout, server stop timeout:%v,cost time:%v", ServerStopTimeout, elapsed)
+ }
+}
+ // TestStopTimeoutWithSocketTimeout uses a 5ms SocketTimeout and a 1s ServerStopTimeout and checks that Stop returns without error in at most half of ServerStopTimeout.
+func TestStopTimeoutWithSocketTimeout(t *testing.T) {
+ ln, err := net.Listen("tcp", "localhost:0")
+
+ if err != nil {
+ t.Fatalf("Failed to listen: %v", err)
+ }
+ // The processor only tries to read a message header from the connection.
+ proc := &mockProcessor{
+ ProcessFunc: func(in, out TProtocol) (bool, TException) {
+ in.ReadMessageBegin(context.Background())
+ return false, nil
+ },
+ }
+ // wg lets the test wait until AcceptFunc has returned an accepted connection.
+ conf := &TConfiguration{SocketTimeout: 5 * time.Millisecond}
+ wg := &sync.WaitGroup{}
+ trans := &mockServerTransport{
+ ListenFunc: func() error {
+ return nil
+ },
+ AcceptFunc: func() (TTransport, error) {
+ conn, err := ln.Accept()
+ if err != nil {
+ return nil, err
+ }
+ defer wg.Done()
+ return NewTSocketFromConnConf(conn, conf), nil
+ },
+ CloseFunc: func() error {
+ return nil
+ },
+ InterruptFunc: func() error {
+ return ln.Close()
+ },
+ }
+
+ serv := NewTSimpleServer2(proc, trans)
+ go serv.Serve()
+ time.Sleep(networkWaitDuration)
+ // Dial one client and wait for the matching wg.Done from AcceptFunc.
+ wg.Add(1)
+ netConn, err := net.Dial("tcp", ln.Addr().String())
+ if err != nil || netConn == nil {
+ t.Fatal("error when dial server")
+ }
+ wg.Wait()
+ // Override the package-level ServerStopTimeout for this test; Cleanup restores the previous value.
+ expectedStopTimeout := time.Second
+ backupServerStopTimeout := ServerStopTimeout
+ t.Cleanup(func() {
+ ServerStopTimeout = backupServerStopTimeout
+ })
+ ServerStopTimeout = expectedStopTimeout
+
+ st := time.Now()
+ err = serv.Stop()
+ if elapsed := time.Since(st); elapsed > expectedStopTimeout/2 {
+ t.Errorf("stop cost more time than socket timeout, socket timeout:%v,server stop timeout:%v,cost time:%v", conf.SocketTimeout, ServerStopTimeout, elapsed)
+ }
+
+ if err != nil {
+ t.Fatalf("error when stop server:%v", err)
+ }
+}
diff --git a/test/go/src/common/clientserver_test.go b/test/go/src/common/clientserver_test.go
index 609086b..64b326a 100644
--- a/test/go/src/common/clientserver_test.go
+++ b/test/go/src/common/clientserver_test.go
@@ -75,7 +75,7 @@
t.Errorf("Unable to start server: %v", err)
return
}
- go server.AcceptLoop()
+ go server.Serve()
defer server.Stop()
client, trans, err := StartClient(unit.host, unit.port, unit.domain_socket, unit.transport, unit.protocol, unit.ssl)
if err != nil {