THRIFT-2388 GoLang - Fix data races in simple_server and server_socket
Patch: Chris Bannister
diff --git a/lib/go/thrift/server_socket.go b/lib/go/thrift/server_socket.go
index 1a01095..4c80714 100644
--- a/lib/go/thrift/server_socket.go
+++ b/lib/go/thrift/server_socket.go
@@ -21,6 +21,7 @@
import (
"net"
+ "sync"
"time"
)
@@ -28,7 +29,10 @@
listener net.Listener
addr net.Addr
clientTimeout time.Duration
- interrupted bool
+
+ // Protects the interrupted value to make it thread safe.
+ mu sync.RWMutex
+ interrupted bool
}
func NewTServerSocket(listenAddr string) (*TServerSocket, error) {
@@ -56,7 +60,11 @@
}
func (p *TServerSocket) Accept() (TTransport, error) {
- if p.interrupted {
+ p.mu.RLock()
+ interrupted := p.interrupted
+ p.mu.RUnlock()
+
+ if interrupted {
return nil, errTransportInterrupted
}
if p.listener == nil {
@@ -102,6 +110,9 @@
}
func (p *TServerSocket) Interrupt() error {
+ p.mu.Lock()
p.interrupted = true
+ p.mu.Unlock()
+
return nil
}
diff --git a/lib/go/thrift/simple_server.go b/lib/go/thrift/simple_server.go
index b5cb0e1..521394c 100644
--- a/lib/go/thrift/simple_server.go
+++ b/lib/go/thrift/simple_server.go
@@ -25,7 +25,7 @@
// Simple, non-concurrent server for testing.
type TSimpleServer struct {
- stopped bool
+ quit chan struct{}
processorFactory TProcessorFactory
serverTransport TServerTransport
@@ -78,12 +78,14 @@
}
func NewTSimpleServerFactory6(processorFactory TProcessorFactory, serverTransport TServerTransport, inputTransportFactory TTransportFactory, outputTransportFactory TTransportFactory, inputProtocolFactory TProtocolFactory, outputProtocolFactory TProtocolFactory) *TSimpleServer {
- return &TSimpleServer{processorFactory: processorFactory,
+ return &TSimpleServer{
+ processorFactory: processorFactory,
serverTransport: serverTransport,
inputTransportFactory: inputTransportFactory,
outputTransportFactory: outputTransportFactory,
inputProtocolFactory: inputProtocolFactory,
outputProtocolFactory: outputProtocolFactory,
+ quit: make(chan struct{}, 1),
}
}
@@ -112,12 +114,19 @@
}
func (p *TSimpleServer) Serve() error {
- p.stopped = false
err := p.serverTransport.Listen()
if err != nil {
return err
}
- for !p.stopped {
+
+loop:
+ for {
+ select {
+ case <-p.quit:
+ break loop
+ default:
+ }
+
client, err := p.serverTransport.Accept()
if err != nil {
log.Println("Accept err: ", err)
@@ -134,7 +143,7 @@
}
func (p *TSimpleServer) Stop() error {
- p.stopped = true
+ p.quit <- struct{}{}
p.serverTransport.Interrupt()
return nil
}