THRIFT-4243 Fix Go TSimpleServer race on wait in Stop() method
Client: Go
Patch: Zachary Wasserman <zachwass2000@gmail.com>
This closes #1302
diff --git a/lib/go/thrift/simple_server.go b/lib/go/thrift/simple_server.go
index 3381e5b..72541b6 100644
--- a/lib/go/thrift/simple_server.go
+++ b/lib/go/thrift/simple_server.go
@@ -23,6 +23,7 @@
"log"
"runtime/debug"
"sync"
+ "sync/atomic"
)
/*
@@ -31,8 +32,9 @@
* This will work if golang user implements a conn-pool like thing in client side.
*/
type TSimpleServer struct {
- quit chan struct{}
- once sync.Once
+ closed int32
+ wg sync.WaitGroup
+ mu sync.Mutex
processorFactory TProcessorFactory
serverTransport TServerTransport
@@ -40,7 +42,6 @@
outputTransportFactory TTransportFactory
inputProtocolFactory TProtocolFactory
outputProtocolFactory TProtocolFactory
- sync.WaitGroup
}
func NewTSimpleServer2(processor TProcessor, serverTransport TServerTransport) *TSimpleServer {
@@ -93,7 +94,6 @@
outputTransportFactory: outputTransportFactory,
inputProtocolFactory: inputProtocolFactory,
outputProtocolFactory: outputProtocolFactory,
- quit: make(chan struct{}, 1),
}
}
@@ -128,22 +128,23 @@
func (p *TSimpleServer) AcceptLoop() error {
for {
client, err := p.serverTransport.Accept()
+ p.mu.Lock()
+ if atomic.LoadInt32(&p.closed) != 0 {
+ return nil
+ }
if err != nil {
- select {
- case <-p.quit:
- return nil
- default:
- }
return err
}
if client != nil {
- p.Add(1)
+ p.wg.Add(1)
go func() {
+ defer p.wg.Done()
if err := p.processRequests(client); err != nil {
log.Println("error processing request:", err)
}
}()
}
+ p.mu.Unlock()
}
}
@@ -157,18 +158,18 @@
}
func (p *TSimpleServer) Stop() error {
- q := func() {
- close(p.quit)
- p.serverTransport.Interrupt()
- p.Wait()
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if atomic.LoadInt32(&p.closed) != 0 {
+ return nil
}
- p.once.Do(q)
+ atomic.StoreInt32(&p.closed, 1)
+ p.serverTransport.Interrupt()
+ p.wg.Wait()
return nil
}
func (p *TSimpleServer) processRequests(client TTransport) error {
- defer p.Done()
-
processor := p.processorFactory.GetProcessor(client)
inputTransport, err := p.inputTransportFactory.GetTransport(client)
if err != nil {
@@ -193,10 +194,8 @@
defer outputTransport.Close()
}
for {
- select {
- case <-p.quit:
+ if atomic.LoadInt32(&p.closed) != 0 {
return nil
- default:
}
ok, err := processor.Process(inputProtocol, outputProtocol)
diff --git a/lib/go/thrift/simple_server_test.go b/lib/go/thrift/simple_server_test.go
index 068f3cc..8763a3b 100644
--- a/lib/go/thrift/simple_server_test.go
+++ b/lib/go/thrift/simple_server_test.go
@@ -21,6 +21,7 @@
import (
"testing"
+ "time"
)
type mockProcessor struct {
@@ -54,6 +55,14 @@
return m.InterruptFunc()
}
+type mockTTransport struct {
+ TTransport
+}
+
+func (m *mockTTransport) Close() error {
+ return nil
+}
+
func TestMultipleStop(t *testing.T) {
proc := &mockProcessor{
ProcessFunc: func(in, out TProtocol) (bool, TException) {
@@ -96,3 +105,31 @@
t.Error("second server transport should have been interrupted")
}
}
+
+func TestWaitRace(t *testing.T) {
+ proc := &mockProcessor{
+ ProcessFunc: func(in, out TProtocol) (bool, TException) {
+ return false, nil
+ },
+ }
+
+ trans := &mockServerTransport{
+ ListenFunc: func() error {
+ return nil
+ },
+ AcceptFunc: func() (TTransport, error) {
+ return &mockTTransport{}, nil
+ },
+ CloseFunc: func() error {
+ return nil
+ },
+ InterruptFunc: func() error {
+ return nil
+ },
+ }
+
+ serv := NewTSimpleServer2(proc, trans)
+ go serv.Serve()
+ time.Sleep(1)
+ serv.Stop()
+}