THRIFT-1025 C++ ServerSocket should inherit from Socket with the necessary Ctor to listen on connections from a specific host (similar to perl library)
Patch: Jim King <jim.king@simplivity.com>
This closes PR: #417
diff --git a/lib/cpp/src/thrift/transport/TSSLServerSocket.cpp b/lib/cpp/src/thrift/transport/TSSLServerSocket.cpp
index cf686e0..421af6a 100644
--- a/lib/cpp/src/thrift/transport/TSSLServerSocket.cpp
+++ b/lib/cpp/src/thrift/transport/TSSLServerSocket.cpp
@@ -27,11 +27,17 @@
/**
* SSL server socket implementation.
*/
-TSSLServerSocket::TSSLServerSocket(THRIFT_SOCKET port, boost::shared_ptr<TSSLSocketFactory> factory)
+TSSLServerSocket::TSSLServerSocket(int port, boost::shared_ptr<TSSLSocketFactory> factory)
: TServerSocket(port), factory_(factory) {
factory_->server(true);
}
+TSSLServerSocket::TSSLServerSocket(const std::string& address, int port,
+ boost::shared_ptr<TSSLSocketFactory> factory)
+ : TServerSocket(address, port), factory_(factory) {
+ factory_->server(true);
+}
+
TSSLServerSocket::TSSLServerSocket(int port,
int sendTimeout,
int recvTimeout,
diff --git a/lib/cpp/src/thrift/transport/TSSLServerSocket.h b/lib/cpp/src/thrift/transport/TSSLServerSocket.h
index bb52b04..7d2dfcc 100644
--- a/lib/cpp/src/thrift/transport/TSSLServerSocket.h
+++ b/lib/cpp/src/thrift/transport/TSSLServerSocket.h
@@ -35,14 +35,25 @@
class TSSLServerSocket : public TServerSocket {
public:
/**
- * Constructor.
+ * Constructor. Binds to all interfaces.
*
* @param port Listening port
* @param factory SSL socket factory implementation
*/
- TSSLServerSocket(THRIFT_SOCKET port, boost::shared_ptr<TSSLSocketFactory> factory);
+ TSSLServerSocket(int port, boost::shared_ptr<TSSLSocketFactory> factory);
+
/**
- * Constructor.
+ * Constructor. Binds to the specified address.
+ *
+ * @param address Address to bind to
+ * @param port Listening port
+ * @param factory SSL socket factory implementation
+ */
+ TSSLServerSocket(const std::string& address, int port,
+ boost::shared_ptr<TSSLSocketFactory> factory);
+
+ /**
+ * Constructor. Binds to all interfaces.
*
* @param port Listening port
* @param sendTimeout Socket send timeout
diff --git a/lib/cpp/src/thrift/transport/TServerSocket.cpp b/lib/cpp/src/thrift/transport/TServerSocket.cpp
index e228dab..fccbcfa 100644
--- a/lib/cpp/src/thrift/transport/TServerSocket.cpp
+++ b/lib/cpp/src/thrift/transport/TServerSocket.cpp
@@ -108,7 +108,24 @@
intSock2_(THRIFT_INVALID_SOCKET) {
}
-TServerSocket::TServerSocket(string path)
+TServerSocket::TServerSocket(const string& address, int port)
+ : port_(port),
+ address_(address),
+ serverSocket_(THRIFT_INVALID_SOCKET),
+ acceptBacklog_(DEFAULT_BACKLOG),
+ sendTimeout_(0),
+ recvTimeout_(0),
+ accTimeout_(-1),
+ retryLimit_(0),
+ retryDelay_(0),
+ tcpSendBuffer_(0),
+ tcpRecvBuffer_(0),
+ keepAlive_(false),
+ intSock1_(THRIFT_INVALID_SOCKET),
+ intSock2_(THRIFT_INVALID_SOCKET) {
+}
+
+TServerSocket::TServerSocket(const string& path)
: port_(0),
path_(path),
serverSocket_(THRIFT_INVALID_SOCKET),
@@ -184,8 +201,8 @@
hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG;
sprintf(port, "%d", port_);
- // Wildcard address
- error = getaddrinfo(NULL, port, &hints, &res0);
+ // If address is not specified use wildcard address (NULL)
+ error = getaddrinfo(address_.empty() ? NULL : &address_[0], port, &hints, &res0);
if (error) {
GlobalOutput.printf("getaddrinfo %d: %s", error, THRIFT_GAI_STRERROR(error));
close();
diff --git a/lib/cpp/src/thrift/transport/TServerSocket.h b/lib/cpp/src/thrift/transport/TServerSocket.h
index 1533937..49711e8 100644
--- a/lib/cpp/src/thrift/transport/TServerSocket.h
+++ b/lib/cpp/src/thrift/transport/TServerSocket.h
@@ -42,11 +42,38 @@
const static int DEFAULT_BACKLOG = 1024;
+ /**
+ * Constructor.
+ *
+ * @param port Port number to bind to
+ */
TServerSocket(int port);
- TServerSocket(int port, int sendTimeout, int recvTimeout);
- TServerSocket(std::string path);
- ~TServerSocket();
+ /**
+ * Constructor.
+ *
+ * @param port Port number to bind to
+ * @param sendTimeout Socket send timeout
+ * @param recvTimeout Socket receive timeout
+ */
+ TServerSocket(int port, int sendTimeout, int recvTimeout);
+
+ /**
+ * Constructor.
+ *
+ * @param address Address to bind to
+ * @param port Port number to bind to
+ */
+ TServerSocket(const std::string& address, int port);
+
+ /**
+ * Constructor used for unix sockets.
+ *
+ * @param path Pathname for unix socket.
+ */
+ TServerSocket(const std::string& path);
+
+ virtual ~TServerSocket();
void setSendTimeout(int sendTimeout);
void setRecvTimeout(int recvTimeout);
@@ -85,6 +112,7 @@
private:
int port_;
+ std::string address_;
std::string path_;
THRIFT_SOCKET serverSocket_;
int acceptBacklog_;
diff --git a/lib/cpp/test/Makefile.am b/lib/cpp/test/Makefile.am
index 43c5975..46ff911 100755
--- a/lib/cpp/test/Makefile.am
+++ b/lib/cpp/test/Makefile.am
@@ -69,6 +69,7 @@
Benchmark_LDADD = libtestgencpp.la
check_PROGRAMS = \
+ UnitTests \
TFDTransportTest \
TPipedTransportTest \
DebugProtoTest \
@@ -80,7 +81,6 @@
TransportTest \
ZlibTest \
TFileTransportTest \
- UnitTests \
link_test \
OpenSSLManualInitTest \
EnumTest
@@ -106,7 +106,8 @@
TBufferBaseTest.cpp \
Base64Test.cpp \
ToStringTest.cpp \
- TypedefTest.cpp
+ TypedefTest.cpp \
+ TServerSocketTest.cpp
if !WITH_BOOSTTHREADS
UnitTests_SOURCES += \
diff --git a/lib/cpp/test/TServerSocketTest.cpp b/lib/cpp/test/TServerSocketTest.cpp
new file mode 100644
index 0000000..ebfd03f
--- /dev/null
+++ b/lib/cpp/test/TServerSocketTest.cpp
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <boost/test/auto_unit_test.hpp>
+#include <thrift/transport/TSocket.h>
+#include <thrift/transport/TServerSocket.h>
+#include "TestPortFixture.h"
+
+using apache::thrift::transport::TServerSocket;
+using apache::thrift::transport::TSocket;
+using apache::thrift::transport::TTransport;
+using apache::thrift::transport::TTransportException;
+
+BOOST_FIXTURE_TEST_SUITE ( TServerSocketTest, TestPortFixture )
+
+class TestTServerSocket : public TServerSocket
+{
+ public:
+ TestTServerSocket(const std::string& address, int port) : TServerSocket(address, port) { }
+ using TServerSocket::acceptImpl;
+};
+
+BOOST_AUTO_TEST_CASE( test_bind_to_address )
+{
+ TestTServerSocket sock1("localhost", m_serverPort);
+ sock1.listen();
+ TSocket clientSock("localhost", m_serverPort);
+ clientSock.open();
+ boost::shared_ptr<TTransport> accepted = sock1.acceptImpl();
+ accepted->close();
+ sock1.close();
+
+ TServerSocket sock2("this.is.truly.an.unrecognizable.address.", m_serverPort);
+ BOOST_CHECK_THROW(sock2.listen(), TTransportException);
+ sock2.close();
+}
+
+BOOST_AUTO_TEST_CASE( test_close_before_listen )
+{
+ TServerSocket sock1("localhost", m_serverPort);
+ sock1.close();
+}
+
+BOOST_AUTO_TEST_CASE( test_get_port )
+{
+ TServerSocket sock1("localHost", 888);
+ BOOST_CHECK_EQUAL(888, sock1.getPort());
+}
+
+BOOST_AUTO_TEST_SUITE_END()
+
diff --git a/lib/cpp/test/TestPortFixture.h b/lib/cpp/test/TestPortFixture.h
new file mode 100644
index 0000000..5b27e5e
--- /dev/null
+++ b/lib/cpp/test/TestPortFixture.h
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#pragma once
+
+#include <cstdlib>
+
+class TestPortFixture
+{
+ public:
+ TestPortFixture()
+ {
+ const char *spEnv = std::getenv("THRIFT_TEST_PORT");
+ m_serverPort = (spEnv) ? atoi(spEnv) : 9090;
+ }
+
+ protected:
+ int m_serverPort;
+};
+