blob: 2111de8b0cb4b6fbacce8e62579088cc5b1ce63d [file] [log] [blame]
Divya Thaluru808d1432017-08-06 16:36:36 -07001/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20#define BOOST_TEST_MODULE TNonblockingSSLServerTest
21#include <boost/test/unit_test.hpp>
Divya Thaluru808d1432017-08-06 16:36:36 -070022#include <boost/filesystem.hpp>
23#include <boost/format.hpp>
24
25#include "thrift/server/TNonblockingServer.h"
26#include "thrift/transport/TSSLSocket.h"
27#include "thrift/transport/TNonblockingSSLServerSocket.h"
28
29#include "gen-cpp/ParentService.h"
30
31#include <event.h>
32
33using namespace apache::thrift;
34using apache::thrift::concurrency::Guard;
35using apache::thrift::concurrency::Monitor;
36using apache::thrift::concurrency::Mutex;
37using apache::thrift::server::TServerEventHandler;
38using apache::thrift::transport::TSSLSocketFactory;
39using apache::thrift::transport::TSSLSocket;
40
41struct Handler : public test::ParentServiceIf {
42 void addString(const std::string& s) { strings_.push_back(s); }
43 void getStrings(std::vector<std::string>& _return) { _return = strings_; }
44 std::vector<std::string> strings_;
45
46 // dummy overrides not used in this test
47 int32_t incrementGeneration() { return 0; }
48 int32_t getGeneration() { return 0; }
49 void getDataWait(std::string&, const int32_t) {}
50 void onewayWait() {}
51 void exceptionWait(const std::string&) {}
52 void unexpectedExceptionWait(const std::string&) {}
53};
54
55boost::filesystem::path keyDir;
56boost::filesystem::path certFile(const std::string& filename)
57{
58 return keyDir / filename;
59}
60
61struct GlobalFixtureSSL
62{
63 GlobalFixtureSSL()
64 {
65 using namespace boost::unit_test::framework;
66 for (int i = 0; i < master_test_suite().argc; ++i)
67 {
68 BOOST_TEST_MESSAGE(boost::format("argv[%1%] = \"%2%\"") % i % master_test_suite().argv[i]);
69 }
70
71#ifdef __linux__
72 // OpenSSL calls send() without MSG_NOSIGPIPE so writing to a socket that has
73 // disconnected can cause a SIGPIPE signal...
74 signal(SIGPIPE, SIG_IGN);
75#endif
76
77 TSSLSocketFactory::setManualOpenSSLInitialization(true);
78 apache::thrift::transport::initializeOpenSSL();
79
80 keyDir = boost::filesystem::current_path().parent_path().parent_path().parent_path() / "test" / "keys";
81 if (!boost::filesystem::exists(certFile("server.crt")))
82 {
83 keyDir = boost::filesystem::path(master_test_suite().argv[master_test_suite().argc - 1]);
84 if (!boost::filesystem::exists(certFile("server.crt")))
85 {
86 throw std::invalid_argument("The last argument to this test must be the directory containing the test certificate(s).");
87 }
88 }
89 }
90
91 virtual ~GlobalFixtureSSL()
92 {
93 apache::thrift::transport::cleanupOpenSSL();
94#ifdef __linux__
95 signal(SIGPIPE, SIG_DFL);
96#endif
97 }
98};
99
// Register GlobalFixtureSSL as a Boost.Test global fixture (set up once for
// the whole test module rather than per test case).
// The trailing semicolon is only written for Boost >= 1.59 (105900).
// NOTE(review): presumably the older macro expansion already ends with ';',
// so an extra one would be a stray empty declaration — TODO confirm.
#if (BOOST_VERSION >= 105900)
BOOST_GLOBAL_FIXTURE(GlobalFixtureSSL);
#else
BOOST_GLOBAL_FIXTURE(GlobalFixtureSSL)
#endif
105
cyy316723a2019-01-05 16:35:14 +0800106std::shared_ptr<TSSLSocketFactory> createServerSocketFactory() {
107 std::shared_ptr<TSSLSocketFactory> pServerSocketFactory;
Divya Thaluru808d1432017-08-06 16:36:36 -0700108
109 pServerSocketFactory.reset(new TSSLSocketFactory());
110 pServerSocketFactory->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
111 pServerSocketFactory->loadCertificate(certFile("server.crt").string().c_str());
112 pServerSocketFactory->loadPrivateKey(certFile("server.key").string().c_str());
113 pServerSocketFactory->server(true);
114 return pServerSocketFactory;
115}
116
cyy316723a2019-01-05 16:35:14 +0800117std::shared_ptr<TSSLSocketFactory> createClientSocketFactory() {
118 std::shared_ptr<TSSLSocketFactory> pClientSocketFactory;
Divya Thaluru808d1432017-08-06 16:36:36 -0700119
120 pClientSocketFactory.reset(new TSSLSocketFactory());
121 pClientSocketFactory->authenticate(true);
122 pClientSocketFactory->loadCertificate(certFile("client.crt").string().c_str());
123 pClientSocketFactory->loadPrivateKey(certFile("client.key").string().c_str());
124 pClientSocketFactory->loadTrustedCertificates(certFile("CA.pem").string().c_str());
125 return pClientSocketFactory;
126}
127
128class Fixture {
129private:
130 struct ListenEventHandler : public TServerEventHandler {
131 public:
132 ListenEventHandler(Mutex* mutex) : listenMonitor_(mutex), ready_(false) {}
133
134 void preServe() /* override */ {
135 Guard g(listenMonitor_.mutex());
136 ready_ = true;
137 listenMonitor_.notify();
138 }
139
140 Monitor listenMonitor_;
141 bool ready_;
142 };
143
144 struct Runner : public apache::thrift::concurrency::Runnable {
145 int port;
cyy316723a2019-01-05 16:35:14 +0800146 std::shared_ptr<event_base> userEventBase;
147 std::shared_ptr<TProcessor> processor;
148 std::shared_ptr<server::TNonblockingServer> server;
149 std::shared_ptr<ListenEventHandler> listenHandler;
150 std::shared_ptr<TSSLSocketFactory> pServerSocketFactory;
151 std::shared_ptr<transport::TNonblockingSSLServerSocket> socket;
Divya Thaluru808d1432017-08-06 16:36:36 -0700152 Mutex mutex_;
153
154 Runner() {
155 listenHandler.reset(new ListenEventHandler(&mutex_));
156 }
157
158 virtual void run() {
159 // When binding to explicit port, allow retrying to workaround bind failures on ports in use
160 int retryCount = port ? 10 : 0;
161 pServerSocketFactory = createServerSocketFactory();
162 startServer(retryCount);
163 }
164
165 void readyBarrier() {
166 // block until server is listening and ready to accept connections
167 Guard g(mutex_);
168 while (!listenHandler->ready_) {
169 listenHandler->listenMonitor_.wait();
170 }
171 }
172 private:
173 void startServer(int retry_count) {
174 try {
175 socket.reset(new transport::TNonblockingSSLServerSocket(port, pServerSocketFactory));
176 server.reset(new server::TNonblockingServer(processor, socket));
177 server->setServerEventHandler(listenHandler);
178 server->setNumIOThreads(1);
179 if (userEventBase) {
180 server->registerEvents(userEventBase.get());
181 }
182 server->serve();
183 } catch (const transport::TTransportException&) {
184 if (retry_count > 0) {
185 ++port;
186 startServer(retry_count - 1);
187 } else {
188 throw;
189 }
190 }
191 }
192 };
193
194 struct EventDeleter {
195 void operator()(event_base* p) { event_base_free(p); }
196 };
197
198protected:
cyy316723a2019-01-05 16:35:14 +0800199 Fixture() : processor(new test::ParentServiceProcessor(std::make_shared<Handler>())) {}
Divya Thaluru808d1432017-08-06 16:36:36 -0700200
201 ~Fixture() {
202 if (server) {
203 server->stop();
204 }
205 if (thread) {
206 thread->join();
207 }
208 }
209
210 void setEventBase(event_base* user_event_base) {
211 userEventBase_.reset(user_event_base, EventDeleter());
212 }
213
214 int startServer(int port) {
cyy316723a2019-01-05 16:35:14 +0800215 std::shared_ptr<Runner> runner(new Runner);
Divya Thaluru808d1432017-08-06 16:36:36 -0700216 runner->port = port;
217 runner->processor = processor;
218 runner->userEventBase = userEventBase_;
219
cyy316723a2019-01-05 16:35:14 +0800220 std::unique_ptr<apache::thrift::concurrency::ThreadFactory> threadFactory(
cyyca8af9b2019-01-11 22:13:12 +0800221 new apache::thrift::concurrency::ThreadFactory(false));
Divya Thaluru808d1432017-08-06 16:36:36 -0700222 thread = threadFactory->newThread(runner);
223 thread->start();
224 runner->readyBarrier();
225
226 server = runner->server;
227 return runner->port;
228 }
229
230 bool canCommunicate(int serverPort) {
cyy316723a2019-01-05 16:35:14 +0800231 std::shared_ptr<TSSLSocketFactory> pClientSocketFactory = createClientSocketFactory();
232 std::shared_ptr<TSSLSocket> socket = pClientSocketFactory->createSocket("localhost", serverPort);
Divya Thaluru808d1432017-08-06 16:36:36 -0700233 socket->open();
cyy316723a2019-01-05 16:35:14 +0800234 test::ParentServiceClient client(std::make_shared<protocol::TBinaryProtocol>(
235 std::make_shared<transport::TFramedTransport>(socket)));
Divya Thaluru808d1432017-08-06 16:36:36 -0700236 client.addString("foo");
237 std::vector<std::string> strings;
238 client.getStrings(strings);
239 return strings.size() == 1 && !(strings[0].compare("foo"));
240 }
241
242private:
cyy316723a2019-01-05 16:35:14 +0800243 std::shared_ptr<event_base> userEventBase_;
244 std::shared_ptr<test::ParentServiceProcessor> processor;
Divya Thaluru808d1432017-08-06 16:36:36 -0700245protected:
cyy316723a2019-01-05 16:35:14 +0800246 std::shared_ptr<server::TNonblockingServer> server;
Divya Thaluru808d1432017-08-06 16:36:36 -0700247private:
cyy316723a2019-01-05 16:35:14 +0800248 std::shared_ptr<apache::thrift::concurrency::Thread> thread;
Divya Thaluru808d1432017-08-06 16:36:36 -0700249
250};
251
BOOST_AUTO_TEST_SUITE(TNonblockingSSLServerTest)

// Explicit port: the Runner retries on the next port when binding throws, so
// the port actually used may be greater than the requested 12345 (hence the
// lower-bound check). The server's reported listen port must equal it, and an
// SSL client must be able to complete a call.
BOOST_FIXTURE_TEST_CASE(get_specified_port, Fixture) {
  int specified_port = startServer(12345);
  BOOST_REQUIRE_GE(specified_port, 12345);
  BOOST_REQUIRE_EQUAL(server->getListenPort(), specified_port);
  BOOST_CHECK(canCommunicate(specified_port));

  // Explicit stop; ~Fixture() also stops the server and joins its thread.
  server->stop();
}
262
// Port 0: no bind retries are made, so startServer() hands back the requested
// 0 unchanged. getListenPort() must report the non-zero port the socket
// actually bound to, and that port is used for the SSL round trip.
BOOST_FIXTURE_TEST_CASE(get_assigned_port, Fixture) {
  int specified_port = startServer(0);
  BOOST_REQUIRE_EQUAL(specified_port, 0);
  int assigned_port = server->getListenPort();
  BOOST_REQUIRE_NE(assigned_port, 0);
  BOOST_CHECK(canCommunicate(assigned_port));

  // Explicit stop; ~Fixture() also stops the server and joins its thread.
  server->stop();
}
272
// Caller-supplied event_base: ownership passes to the fixture via
// setEventBase(), and the Runner passes it to server->registerEvents() before
// serve(). The server is stopped by ~Fixture().
BOOST_FIXTURE_TEST_CASE(provide_event_base, Fixture) {
  event_base* eb = event_base_new();
  setEventBase(eb);
  startServer(0);

  // assert that the server works
  BOOST_CHECK(canCommunicate(server->getListenPort()));
  // NOTE(review): the version guard presumably exists because
  // event_base_get_num_events is unavailable in older libevent — TODO confirm.
#if LIBEVENT_VERSION_NUMBER > 0x02010400
  // also assert that the event_base is actually used when it's easy
  BOOST_CHECK_GT(event_base_get_num_events(eb, EVENT_BASE_COUNT_ADDED), 0);
#endif
}

BOOST_AUTO_TEST_SUITE_END()