THRIFT-2031: Make SO_KEEPALIVE configurable for C++ lib
Client: cpp
Patch: Ben Craig
diff --git a/lib/cpp/src/thrift/transport/TServerSocket.cpp b/lib/cpp/src/thrift/transport/TServerSocket.cpp
index 1df719d..1c5b9de 100755
--- a/lib/cpp/src/thrift/transport/TServerSocket.cpp
+++ b/lib/cpp/src/thrift/transport/TServerSocket.cpp
@@ -86,7 +86,9 @@
   tcpSendBuffer_(0),
   tcpRecvBuffer_(0),
   intSock1_(THRIFT_INVALID_SOCKET),
-  intSock2_(THRIFT_INVALID_SOCKET) {}
+  intSock2_(THRIFT_INVALID_SOCKET),
+  keepAlive_(false)
+{}
 
 TServerSocket::TServerSocket(int port, int sendTimeout, int recvTimeout) :
   port_(port),
@@ -100,7 +102,9 @@
   tcpSendBuffer_(0),
   tcpRecvBuffer_(0),
   intSock1_(THRIFT_INVALID_SOCKET),
-  intSock2_(THRIFT_INVALID_SOCKET) {}
+  intSock2_(THRIFT_INVALID_SOCKET),
+  keepAlive_(false)
+{}
 
 TServerSocket::TServerSocket(string path) :
   port_(0),
@@ -115,7 +119,9 @@
   tcpSendBuffer_(0),
   tcpRecvBuffer_(0),
   intSock1_(THRIFT_INVALID_SOCKET),
-  intSock2_(THRIFT_INVALID_SOCKET) {}
+  intSock2_(THRIFT_INVALID_SOCKET),
+  keepAlive_(false)
+{}
 
 TServerSocket::~TServerSocket() {
   close();
@@ -480,6 +486,9 @@
   if (recvTimeout_ > 0) {
     client->setRecvTimeout(recvTimeout_);
   }
+  if (keepAlive_) {
+    client->setKeepAlive(keepAlive_);
+  }
   client->setCachedAddress((sockaddr*) &clientAddress, size);
 
   return client;
diff --git a/lib/cpp/src/thrift/transport/TServerSocket.h b/lib/cpp/src/thrift/transport/TServerSocket.h
index e7b7a82..c30d3e3 100644
--- a/lib/cpp/src/thrift/transport/TServerSocket.h
+++ b/lib/cpp/src/thrift/transport/TServerSocket.h
@@ -52,6 +52,8 @@
   void setRetryLimit(int retryLimit);
   void setRetryDelay(int retryDelay);
 
+  void setKeepAlive(bool keepAlive) {keepAlive_ = keepAlive;}
+
   void setTcpSendBuffer(int tcpSendBuffer);
   void setTcpRecvBuffer(int tcpRecvBuffer);
 
@@ -77,6 +79,7 @@
   int retryDelay_;
   int tcpSendBuffer_;
   int tcpRecvBuffer_;
+  bool keepAlive_;
 
   THRIFT_SOCKET intSock1_;
   THRIFT_SOCKET intSock2_;
diff --git a/lib/cpp/src/thrift/transport/TSocket.cpp b/lib/cpp/src/thrift/transport/TSocket.cpp
index 1ea98bd..381daa2 100755
--- a/lib/cpp/src/thrift/transport/TSocket.cpp
+++ b/lib/cpp/src/thrift/transport/TSocket.cpp
@@ -83,6 +83,7 @@
   connTimeout_(0),
   sendTimeout_(0),
   recvTimeout_(0),
+  keepAlive_(false),
   lingerOn_(1),
   lingerVal_(0),
   noDelay_(1),
@@ -99,6 +100,7 @@
   connTimeout_(0),
   sendTimeout_(0),
   recvTimeout_(0),
+  keepAlive_(false),
   lingerOn_(1),
   lingerVal_(0),
   noDelay_(1),
@@ -116,6 +118,7 @@
   connTimeout_(0),
   sendTimeout_(0),
   recvTimeout_(0),
+  keepAlive_(false),
   lingerOn_(1),
   lingerVal_(0),
   noDelay_(1),
@@ -133,6 +136,7 @@
   connTimeout_(0),
   sendTimeout_(0),
   recvTimeout_(0),
+  keepAlive_(false),
   lingerOn_(1),
   lingerVal_(0),
   noDelay_(1),
@@ -203,6 +207,10 @@
     setRecvTimeout(recvTimeout_);
   }
 
+  if(keepAlive_) {
+    setKeepAlive(keepAlive_);
+  }
+
   // Linger
   setLinger(lingerOn_, lingerVal_);
 
@@ -677,6 +685,22 @@
   }
 }
 
+void TSocket::setKeepAlive(bool keepAlive) {
+  keepAlive_ = keepAlive;
+
+  if (socket_ == -1) {
+    return;
+  }
+
+  int value = keepAlive_;
+  int ret = setsockopt(socket_, SOL_SOCKET, SO_KEEPALIVE, const_cast_sockopt(&value), sizeof(value));
+
+  if (ret == -1) {
+    int errno_copy = THRIFT_GET_SOCKET_ERROR;  // Copy THRIFT_GET_SOCKET_ERROR because we're allocating memory.
+    GlobalOutput.perror("TSocket::setKeepAlive() setsockopt() " + getSocketInfo(), errno_copy);
+  }
+}
+
 void TSocket::setMaxRecvRetries(int maxRecvRetries) {
   maxRecvRetries_ = maxRecvRetries;
 }
diff --git a/lib/cpp/src/thrift/transport/TSocket.h b/lib/cpp/src/thrift/transport/TSocket.h
index fd5b961..38d8c7f 100644
--- a/lib/cpp/src/thrift/transport/TSocket.h
+++ b/lib/cpp/src/thrift/transport/TSocket.h
@@ -179,6 +179,11 @@
   void setMaxRecvRetries(int maxRecvRetries);
 
   /**
+   * Set SO_KEEPALIVE
+   */
+  void setKeepAlive(bool keepAlive);
+
+  /**
    * Get socket information formated as a string <Host: x Port: x>
    */
   std::string getSocketInfo();
@@ -274,6 +279,9 @@
   /** Recv timeout in ms */
   int recvTimeout_;
 
+  /** Keep alive on */
+  bool keepAlive_;
+
   /** Linger on */
   bool lingerOn_;