THRIFT-2388 GoLang - Fix data races in simple_server and server_socket
Patch: Chris Bannister
diff --git a/lib/go/thrift/server_socket.go b/lib/go/thrift/server_socket.go
index 1a01095..4c80714 100644
--- a/lib/go/thrift/server_socket.go
+++ b/lib/go/thrift/server_socket.go
@@ -21,6 +21,7 @@
 
 import (
 	"net"
+	"sync"
 	"time"
 )
 
@@ -28,7 +29,10 @@
 	listener      net.Listener
 	addr          net.Addr
 	clientTimeout time.Duration
-	interrupted   bool
+
+	// Protects the interrupted value to make it thread safe.
+	mu          sync.RWMutex
+	interrupted bool
 }
 
 func NewTServerSocket(listenAddr string) (*TServerSocket, error) {
@@ -56,7 +60,11 @@
 }
 
 func (p *TServerSocket) Accept() (TTransport, error) {
-	if p.interrupted {
+	p.mu.RLock()
+	interrupted := p.interrupted
+	p.mu.RUnlock()
+
+	if interrupted {
 		return nil, errTransportInterrupted
 	}
 	if p.listener == nil {
@@ -102,6 +110,9 @@
 }
 
 func (p *TServerSocket) Interrupt() error {
+	p.mu.Lock()
 	p.interrupted = true
+	p.mu.Unlock()
+
 	return nil
 }
diff --git a/lib/go/thrift/simple_server.go b/lib/go/thrift/simple_server.go
index b5cb0e1..521394c 100644
--- a/lib/go/thrift/simple_server.go
+++ b/lib/go/thrift/simple_server.go
@@ -25,7 +25,7 @@
 
 // Simple, non-concurrent server for testing.
 type TSimpleServer struct {
-	stopped bool
+	quit chan struct{}
 
 	processorFactory       TProcessorFactory
 	serverTransport        TServerTransport
@@ -78,12 +78,14 @@
 }
 
 func NewTSimpleServerFactory6(processorFactory TProcessorFactory, serverTransport TServerTransport, inputTransportFactory TTransportFactory, outputTransportFactory TTransportFactory, inputProtocolFactory TProtocolFactory, outputProtocolFactory TProtocolFactory) *TSimpleServer {
-	return &TSimpleServer{processorFactory: processorFactory,
+	return &TSimpleServer{
+		processorFactory:       processorFactory,
 		serverTransport:        serverTransport,
 		inputTransportFactory:  inputTransportFactory,
 		outputTransportFactory: outputTransportFactory,
 		inputProtocolFactory:   inputProtocolFactory,
 		outputProtocolFactory:  outputProtocolFactory,
+		quit: make(chan struct{}, 1),
 	}
 }
 
@@ -112,12 +114,19 @@
 }
 
 func (p *TSimpleServer) Serve() error {
-	p.stopped = false
 	err := p.serverTransport.Listen()
 	if err != nil {
 		return err
 	}
-	for !p.stopped {
+
+loop:
+	for {
+		select {
+		case <-p.quit:
+			break loop
+		default:
+		}
+
 		client, err := p.serverTransport.Accept()
 		if err != nil {
 			log.Println("Accept err: ", err)
@@ -134,7 +143,7 @@
 }
 
 func (p *TSimpleServer) Stop() error {
-	p.stopped = true
+	p.quit <- struct{}{}
 	p.serverTransport.Interrupt()
 	return nil
 }