THRIFT-1025 C++ ServerSocket should inherit from Socket with the necessary Ctor to listen on connections from a specific host (similar to perl library)

Patch: Jim King <jim.king@simplivity.com>
This closes PR: #417
diff --git a/lib/cpp/src/thrift/transport/TSSLServerSocket.cpp b/lib/cpp/src/thrift/transport/TSSLServerSocket.cpp
index cf686e0..421af6a 100644
--- a/lib/cpp/src/thrift/transport/TSSLServerSocket.cpp
+++ b/lib/cpp/src/thrift/transport/TSSLServerSocket.cpp
@@ -27,11 +27,17 @@
 /**
  * SSL server socket implementation.
  */
-TSSLServerSocket::TSSLServerSocket(THRIFT_SOCKET port, boost::shared_ptr<TSSLSocketFactory> factory)
+TSSLServerSocket::TSSLServerSocket(int port, boost::shared_ptr<TSSLSocketFactory> factory)
   : TServerSocket(port), factory_(factory) {
   factory_->server(true);
 }
 
+TSSLServerSocket::TSSLServerSocket(const std::string& address, int port,
+                                   boost::shared_ptr<TSSLSocketFactory> factory)
+  : TServerSocket(address, port), factory_(factory) {
+  factory_->server(true);
+}
+
 TSSLServerSocket::TSSLServerSocket(int port,
                                    int sendTimeout,
                                    int recvTimeout,
diff --git a/lib/cpp/src/thrift/transport/TSSLServerSocket.h b/lib/cpp/src/thrift/transport/TSSLServerSocket.h
index bb52b04..7d2dfcc 100644
--- a/lib/cpp/src/thrift/transport/TSSLServerSocket.h
+++ b/lib/cpp/src/thrift/transport/TSSLServerSocket.h
@@ -35,14 +35,25 @@
 class TSSLServerSocket : public TServerSocket {
 public:
   /**
-   * Constructor.
+   * Constructor.  Binds to all interfaces.
    *
    * @param port    Listening port
    * @param factory SSL socket factory implementation
    */
-  TSSLServerSocket(THRIFT_SOCKET port, boost::shared_ptr<TSSLSocketFactory> factory);
+  TSSLServerSocket(int port, boost::shared_ptr<TSSLSocketFactory> factory);
+
   /**
-   * Constructor.
+   * Constructor.  Binds to the specified address.
+   *
+   * @param address Address to bind to
+   * @param port    Listening port
+   * @param factory SSL socket factory implementation
+   */
+  TSSLServerSocket(const std::string& address, int port,
+                   boost::shared_ptr<TSSLSocketFactory> factory);
+
+  /**
+   * Constructor.  Binds to all interfaces.
    *
    * @param port        Listening port
    * @param sendTimeout Socket send timeout
diff --git a/lib/cpp/src/thrift/transport/TServerSocket.cpp b/lib/cpp/src/thrift/transport/TServerSocket.cpp
index e228dab..fccbcfa 100644
--- a/lib/cpp/src/thrift/transport/TServerSocket.cpp
+++ b/lib/cpp/src/thrift/transport/TServerSocket.cpp
@@ -108,7 +108,24 @@
     intSock2_(THRIFT_INVALID_SOCKET) {
 }
 
-TServerSocket::TServerSocket(string path)
+TServerSocket::TServerSocket(const string& address, int port)
+  : port_(port),
+    address_(address),
+    serverSocket_(THRIFT_INVALID_SOCKET),
+    acceptBacklog_(DEFAULT_BACKLOG),
+    sendTimeout_(0),
+    recvTimeout_(0),
+    accTimeout_(-1),
+    retryLimit_(0),
+    retryDelay_(0),
+    tcpSendBuffer_(0),
+    tcpRecvBuffer_(0),
+    keepAlive_(false),
+    intSock1_(THRIFT_INVALID_SOCKET),
+    intSock2_(THRIFT_INVALID_SOCKET) {
+}
+
+TServerSocket::TServerSocket(const string& path)
   : port_(0),
     path_(path),
     serverSocket_(THRIFT_INVALID_SOCKET),
@@ -184,8 +201,8 @@
   hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG;
   sprintf(port, "%d", port_);
 
-  // Wildcard address
-  error = getaddrinfo(NULL, port, &hints, &res0);
+  // If address is not specified use wildcard address (NULL)
+  error = getaddrinfo(address_.empty() ? NULL : &address_[0], port, &hints, &res0);
   if (error) {
     GlobalOutput.printf("getaddrinfo %d: %s", error, THRIFT_GAI_STRERROR(error));
     close();
diff --git a/lib/cpp/src/thrift/transport/TServerSocket.h b/lib/cpp/src/thrift/transport/TServerSocket.h
index 1533937..49711e8 100644
--- a/lib/cpp/src/thrift/transport/TServerSocket.h
+++ b/lib/cpp/src/thrift/transport/TServerSocket.h
@@ -42,11 +42,38 @@
 
   const static int DEFAULT_BACKLOG = 1024;
 
+  /**
+   * Constructor.
+   *
+   * @param port    Port number to bind to
+   */
   TServerSocket(int port);
-  TServerSocket(int port, int sendTimeout, int recvTimeout);
-  TServerSocket(std::string path);
 
-  ~TServerSocket();
+  /**
+   * Constructor.
+   *
+   * @param port        Port number to bind to
+   * @param sendTimeout Socket send timeout
+   * @param recvTimeout Socket receive timeout
+   */
+  TServerSocket(int port, int sendTimeout, int recvTimeout);
+
+  /**
+   * Constructor.
+   *
+   * @param address Address to bind to
+   * @param port    Port number to bind to
+   */
+  TServerSocket(const std::string& address, int port);
+
+  /**
+   * Constructor used for unix sockets.
+   *
+   * @param path Pathname for unix socket.
+   */
+  TServerSocket(const std::string& path);
+
+  virtual ~TServerSocket();
 
   void setSendTimeout(int sendTimeout);
   void setRecvTimeout(int recvTimeout);
@@ -85,6 +112,7 @@
 
 private:
   int port_;
+  std::string address_;
   std::string path_;
   THRIFT_SOCKET serverSocket_;
   int acceptBacklog_;
diff --git a/lib/cpp/test/Makefile.am b/lib/cpp/test/Makefile.am
index 43c5975..46ff911 100755
--- a/lib/cpp/test/Makefile.am
+++ b/lib/cpp/test/Makefile.am
@@ -69,6 +69,7 @@
 Benchmark_LDADD = libtestgencpp.la
 
 check_PROGRAMS = \
+	UnitTests \
 	TFDTransportTest \
 	TPipedTransportTest \
 	DebugProtoTest \
@@ -80,7 +81,6 @@
 	TransportTest \
 	ZlibTest \
 	TFileTransportTest \
-	UnitTests \
 	link_test \
 	OpenSSLManualInitTest \
 	EnumTest
@@ -106,7 +106,8 @@
 	TBufferBaseTest.cpp \
 	Base64Test.cpp \
 	ToStringTest.cpp \
-	TypedefTest.cpp
+	TypedefTest.cpp \
+        TServerSocketTest.cpp
 
 if !WITH_BOOSTTHREADS
 UnitTests_SOURCES += \
diff --git a/lib/cpp/test/TServerSocketTest.cpp b/lib/cpp/test/TServerSocketTest.cpp
new file mode 100644
index 0000000..ebfd03f
--- /dev/null
+++ b/lib/cpp/test/TServerSocketTest.cpp
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <boost/test/auto_unit_test.hpp>
+#include <thrift/transport/TSocket.h>
+#include <thrift/transport/TServerSocket.h>
+#include "TestPortFixture.h"
+
+using apache::thrift::transport::TServerSocket;
+using apache::thrift::transport::TSocket;
+using apache::thrift::transport::TTransport;
+using apache::thrift::transport::TTransportException;
+
+BOOST_FIXTURE_TEST_SUITE ( TServerSocketTest, TestPortFixture )
+
+class TestTServerSocket : public TServerSocket
+{
+  public:
+    TestTServerSocket(const std::string& address, int port) : TServerSocket(address, port) { }
+    using TServerSocket::acceptImpl;
+};
+
+BOOST_AUTO_TEST_CASE( test_bind_to_address )
+{
+    TestTServerSocket sock1("localhost", m_serverPort);
+    sock1.listen();
+    TSocket clientSock("localhost", m_serverPort);
+    clientSock.open();
+    boost::shared_ptr<TTransport> accepted = sock1.acceptImpl();
+    accepted->close();
+    sock1.close();
+
+    TServerSocket sock2("this.is.truly.an.unrecognizable.address.", m_serverPort);
+    BOOST_CHECK_THROW(sock2.listen(), TTransportException);
+    sock2.close();
+}
+
+BOOST_AUTO_TEST_CASE( test_close_before_listen )
+{
+    TServerSocket sock1("localhost", m_serverPort);
+    sock1.close();
+}
+
+BOOST_AUTO_TEST_CASE( test_get_port )
+{
+    TServerSocket sock1("localHost", 888);
+    BOOST_CHECK_EQUAL(888, sock1.getPort());
+}
+
+BOOST_AUTO_TEST_SUITE_END()
+
diff --git a/lib/cpp/test/TestPortFixture.h b/lib/cpp/test/TestPortFixture.h
new file mode 100644
index 0000000..5b27e5e
--- /dev/null
+++ b/lib/cpp/test/TestPortFixture.h
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#pragma once
+
+#include <cstdlib>
+
+class TestPortFixture
+{
+  public:
+    TestPortFixture()
+    {
+        const char *spEnv = std::getenv("THRIFT_TEST_PORT");
+        m_serverPort = (spEnv) ? atoi(spEnv) : 9090;
+    }
+
+  protected:
+    int m_serverPort;
+};
+