THRIFT-5495: close client when shutdown server in go lib
Client: go
diff --git a/lib/go/README.md b/lib/go/README.md
index 75d7174..b2cf1df 100644
--- a/lib/go/README.md
+++ b/lib/go/README.md
@@ -132,3 +132,27 @@
 excessive cpu overhead.
 
 This feature is also only enabled on non-oneway endpoints.
+
+A note about server stop implementations
+========================================
+
+[TSimpleServer.Stop](https://pkg.go.dev/github.com/apache/thrift/lib/go/thrift#TSimpleServer.Stop) will wait for all client connections to be closed after 
+the last received request to be handled, as the time spent by Stop
+ may sometimes be too long:
+* When socket timeout is not set, server might be hanged before all active
+  clients to finish handling the last received request.
+* When the socket timeout is too long (e.g one hour), server will
+  hang for that duration before all active clients to finish handling the
+  last received request.
+
+To prevent Stop from hanging for too long, you can set 
+thrift.ServerStopTimeout in your main or init function:
+
+    thrift.ServerStopTimeout = <max_duration_to_stop>
+
+If it's set to <=0, the feature will be disabled (by default), and server 
+will wait for all the client connections to be closed gracefully with 
+zero err time. Otherwise, the stop will wait for all the client 
+connections to be closed gracefully util thrift.ServerStopTimeout is 
+reached, and client connections that are not closed after thrift.ServerStopTimeout 
+will be closed abruptly which may cause some client errors.
\ No newline at end of file
diff --git a/lib/go/thrift/simple_server.go b/lib/go/thrift/simple_server.go
index 02863ec..1cfc375 100644
--- a/lib/go/thrift/simple_server.go
+++ b/lib/go/thrift/simple_server.go
@@ -20,6 +20,7 @@
 package thrift
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"io"
@@ -48,15 +49,26 @@
 // If it's changed to <=0, the feature will be disabled.
 var ServerConnectivityCheckInterval = time.Millisecond * 5
 
+// ServerStopTimeout defines max stop wait duration used by
+// server stop to avoid hanging too long to wait for all client connections to be closed gracefully.
+//
+// It's defined as a variable instead of constant, so that thrift server
+// implementations can change its value to control the behavior.
+//
+// If it's set to <=0, the feature will be disabled(by default), and the server will wait for
+// for all the client connections to be closed gracefully.
+var ServerStopTimeout = time.Duration(0)
+
 /*
  * This is not a typical TSimpleServer as it is not blocked after accept a socket.
  * It is more like a TThreadedServer that can handle different connections in different goroutines.
  * This will work if golang user implements a conn-pool like thing in client side.
  */
 type TSimpleServer struct {
-	closed int32
-	wg     sync.WaitGroup
-	mu     sync.Mutex
+	closed   int32
+	wg       sync.WaitGroup
+	mu       sync.Mutex
+	stopChan chan struct{}
 
 	processorFactory       TProcessorFactory
 	serverTransport        TServerTransport
@@ -121,6 +133,7 @@
 		outputTransportFactory: outputTransportFactory,
 		inputProtocolFactory:   inputProtocolFactory,
 		outputProtocolFactory:  outputProtocolFactory,
+		stopChan:               make(chan struct{}),
 	}
 }
 
@@ -192,13 +205,27 @@
 		return 0, err
 	}
 	if client != nil {
-		p.wg.Add(1)
+		ctx, cancel := context.WithCancel(context.Background())
+		p.wg.Add(2)
+
 		go func() {
 			defer p.wg.Done()
+			defer cancel()
 			if err := p.processRequests(client); err != nil {
 				p.logger(fmt.Sprintf("error processing request: %v", err))
 			}
 		}()
+
+		go func() {
+			defer p.wg.Done()
+			select {
+			case <-ctx.Done():
+				// client exited, do nothing
+			case <-p.stopChan:
+				// TSimpleServer.Close called, close the client connection
+				client.Close()
+			}
+		}()
 	}
 	return 0, nil
 }
@@ -229,12 +256,31 @@
 func (p *TSimpleServer) Stop() error {
 	p.mu.Lock()
 	defer p.mu.Unlock()
+
 	if atomic.LoadInt32(&p.closed) != 0 {
 		return nil
 	}
 	atomic.StoreInt32(&p.closed, 1)
 	p.serverTransport.Interrupt()
-	p.wg.Wait()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	go func() {
+		defer cancel()
+		p.wg.Wait()
+	}()
+
+	if ServerStopTimeout > 0 {
+		timer := time.NewTimer(ServerStopTimeout)
+		select {
+		case <-timer.C:
+		case <-ctx.Done():
+		}
+		close(p.stopChan)
+		timer.Stop()
+	}
+
+	<-ctx.Done()
+	p.stopChan = make(chan struct{})
 	return nil
 }
 
diff --git a/lib/go/thrift/simple_server_test.go b/lib/go/thrift/simple_server_test.go
index 58149a8..b92d50f 100644
--- a/lib/go/thrift/simple_server_test.go
+++ b/lib/go/thrift/simple_server_test.go
@@ -20,11 +20,17 @@
 package thrift
 
 import (
-	"testing"
+	"context"
 	"errors"
+	"net"
 	"runtime"
+	"sync"
+	"testing"
+	"time"
 )
 
+const networkWaitDuration = 10 * time.Millisecond
+
 type mockServerTransport struct {
 	ListenFunc    func() error
 	AcceptFunc    func() (TTransport, error)
@@ -154,3 +160,130 @@
 	runtime.Gosched()
 	serv.Stop()
 }
+
+func TestNoHangDuringStopFromClientNoDataSendDuringAcceptLoop(t *testing.T) {
+	ln, err := net.Listen("tcp", "localhost:0")
+
+	if err != nil {
+		t.Fatalf("Failed to listen: %v", err)
+	}
+
+	proc := &mockProcessor{
+		ProcessFunc: func(in, out TProtocol) (bool, TException) {
+			in.ReadMessageBegin(context.Background())
+			return false, nil
+		},
+	}
+
+	trans := &mockServerTransport{
+		ListenFunc: func() error {
+			return nil
+		},
+		AcceptFunc: func() (TTransport, error) {
+			conn, err := ln.Accept()
+			if err != nil {
+				return nil, err
+			}
+
+			return NewTSocketFromConnConf(conn, nil), nil
+		},
+		CloseFunc: func() error {
+			return nil
+		},
+		InterruptFunc: func() error {
+			return ln.Close()
+		},
+	}
+
+	serv := NewTSimpleServer2(proc, trans)
+	go serv.Serve()
+	time.Sleep(networkWaitDuration)
+
+	netConn, err := net.Dial("tcp", ln.Addr().String())
+	if err != nil || netConn == nil {
+		t.Fatal("error when dial server")
+	}
+	time.Sleep(networkWaitDuration)
+
+	serverStopTimeout := 50 * time.Millisecond
+	backupServerStopTimeout := ServerStopTimeout
+	t.Cleanup(func() {
+		ServerStopTimeout = backupServerStopTimeout
+	})
+	ServerStopTimeout = serverStopTimeout
+
+	st := time.Now()
+	err = serv.Stop()
+	if err != nil {
+		t.Errorf("error when stop server:%v", err)
+	}
+
+	if elapsed := time.Since(st); elapsed < serverStopTimeout {
+		t.Errorf("stop cost less time than server stop timeout, server stop timeout:%v,cost time:%v", ServerStopTimeout, elapsed)
+	}
+}
+
+func TestStopTimeoutWithSocketTimeout(t *testing.T) {
+	ln, err := net.Listen("tcp", "localhost:0")
+
+	if err != nil {
+		t.Fatalf("Failed to listen: %v", err)
+	}
+
+	proc := &mockProcessor{
+		ProcessFunc: func(in, out TProtocol) (bool, TException) {
+			in.ReadMessageBegin(context.Background())
+			return false, nil
+		},
+	}
+
+	conf := &TConfiguration{SocketTimeout: 5 * time.Millisecond}
+	wg := &sync.WaitGroup{}
+	trans := &mockServerTransport{
+		ListenFunc: func() error {
+			return nil
+		},
+		AcceptFunc: func() (TTransport, error) {
+			conn, err := ln.Accept()
+			if err != nil {
+				return nil, err
+			}
+			defer wg.Done()
+			return NewTSocketFromConnConf(conn, conf), nil
+		},
+		CloseFunc: func() error {
+			return nil
+		},
+		InterruptFunc: func() error {
+			return ln.Close()
+		},
+	}
+
+	serv := NewTSimpleServer2(proc, trans)
+	go serv.Serve()
+	time.Sleep(networkWaitDuration)
+
+	wg.Add(1)
+	netConn, err := net.Dial("tcp", ln.Addr().String())
+	if err != nil || netConn == nil {
+		t.Fatal("error when dial server")
+	}
+	wg.Wait()
+
+	expectedStopTimeout := time.Second
+	backupServerStopTimeout := ServerStopTimeout
+	t.Cleanup(func() {
+		ServerStopTimeout = backupServerStopTimeout
+	})
+	ServerStopTimeout = expectedStopTimeout
+
+	st := time.Now()
+	err = serv.Stop()
+	if elapsed := time.Since(st); elapsed > expectedStopTimeout/2 {
+		t.Errorf("stop cost more time than socket timeout, socket timeout:%v,server stop timeout:%v,cost time:%v", conf.SocketTimeout, ServerStopTimeout, elapsed)
+	}
+
+	if err != nil {
+		t.Fatalf("error when stop server:%v", err)
+	}
+}
diff --git a/test/go/src/common/clientserver_test.go b/test/go/src/common/clientserver_test.go
index 609086b..64b326a 100644
--- a/test/go/src/common/clientserver_test.go
+++ b/test/go/src/common/clientserver_test.go
@@ -75,7 +75,7 @@
 		t.Errorf("Unable to start server: %v", err)
 		return
 	}
-	go server.AcceptLoop()
+	go server.Serve()
 	defer server.Stop()
 	client, trans, err := StartClient(unit.host, unit.port, unit.domain_socket, unit.transport, unit.protocol, unit.ssl)
 	if err != nil {