blob: c8a3b25b3733e0302cc4f36a4076d869c419f476 [file] [log] [blame]
// Copyright (c) 2009- Facebook
// Distributed under the Thrift Software License
//
// See accompanying file LICENSE or visit the Thrift site at:
// http://developers.facebook.com/thrift/
6
#include <errno.h>
#include <arpa/inet.h>
#include <sys/types.h>
#include <sys/socket.h>

#include <new>
#include <string>

#include <boost/lexical_cast.hpp>
#include <boost/shared_array.hpp>
#include <openssl/err.h>
#include <openssl/rand.h>
#include <openssl/ssl.h>
#include <openssl/x509v3.h>

#include "concurrency/Mutex.h"
#include "TSSLSocket.h"

#define OPENSSL_VERSION_NO_THREAD_ID 0x10000000L

using namespace std;
using namespace boost;
using namespace apache::thrift::concurrency;
26
27struct CRYPTO_dynlock_value {
28 Mutex mutex;
29};
30
31namespace apache { namespace thrift { namespace transport {
32
33
// File-local helpers, defined near the end of this file.
// buildErrors: drain the OpenSSL error queue into a readable message,
//              falling back to errno text when the queue is empty.
// matchName:   case-insensitive host vs. certificate-name match with '*'.
// uppercase:   locale-independent ASCII uppercase.
static void buildErrors(std::string& errors, int errnoCopy = 0);
static bool matchName(const char* host, const char* pattern, int size);
static char uppercase(char c);
37
38// SSLContext implementation
39SSLContext::SSLContext() {
40 ctx_ = SSL_CTX_new(TLSv1_method());
41 if (ctx_ == NULL) {
42 string errors;
43 buildErrors(errors);
44 throw TSSLException("SSL_CTX_new: " + errors);
45 }
46 SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
47}
48
49SSLContext::~SSLContext() {
50 if (ctx_ != NULL) {
51 SSL_CTX_free(ctx_);
52 ctx_ = NULL;
53 }
54}
55
56SSL* SSLContext::createSSL() {
57 SSL* ssl = SSL_new(ctx_);
58 if (ssl == NULL) {
59 string errors;
60 buildErrors(errors);
61 throw TSSLException("SSL_new: " + errors);
62 }
63 return ssl;
64}
65
66// TSSLSocket implementation
67TSSLSocket::TSSLSocket(shared_ptr<SSLContext> ctx):
68 TSocket(), server_(false), ssl_(NULL), ctx_(ctx) {
69}
70
71TSSLSocket::TSSLSocket(shared_ptr<SSLContext> ctx, int socket):
72 TSocket(socket), server_(false), ssl_(NULL), ctx_(ctx) {
73}
74
75TSSLSocket::TSSLSocket(shared_ptr<SSLContext> ctx, string host, int port):
76 TSocket(host, port), server_(false), ssl_(NULL), ctx_(ctx) {
77}
78
79TSSLSocket::~TSSLSocket() {
80 close();
81}
82
83bool TSSLSocket::isOpen() {
84 if (ssl_ == NULL || !TSocket::isOpen()) {
85 return false;
86 }
87 int shutdown = SSL_get_shutdown(ssl_);
88 bool shutdownReceived = (shutdown & SSL_RECEIVED_SHUTDOWN);
89 bool shutdownSent = (shutdown & SSL_SENT_SHUTDOWN);
90 if (shutdownReceived && shutdownSent) {
91 return false;
92 }
93 return true;
94}
95
96bool TSSLSocket::peek() {
97 if (!isOpen()) {
98 return false;
99 }
100 checkHandshake();
101 int rc;
102 uint8_t byte;
103 rc = SSL_peek(ssl_, &byte, 1);
104 if (rc < 0) {
105 int errno_copy = errno;
106 string errors;
107 buildErrors(errors, errno_copy);
108 throw TSSLException("SSL_peek: " + errors);
109 }
110 if (rc == 0) {
111 ERR_clear_error();
112 }
113 return (rc > 0);
114}
115
116void TSSLSocket::open() {
117 if (isOpen() || server()) {
118 throw TTransportException(TTransportException::BAD_ARGS);
119 }
120 TSocket::open();
121}
122
123void TSSLSocket::close() {
124 if (ssl_ != NULL) {
125 int rc = SSL_shutdown(ssl_);
126 if (rc == 0) {
127 rc = SSL_shutdown(ssl_);
128 }
129 if (rc < 0) {
130 int errno_copy = errno;
131 string errors;
132 buildErrors(errors, errno_copy);
133 GlobalOutput(("SSL_shutdown: " + errors).c_str());
134 }
135 SSL_free(ssl_);
136 ssl_ = NULL;
137 ERR_remove_state(0);
138 }
139 TSocket::close();
140}
141
142uint32_t TSSLSocket::read(uint8_t* buf, uint32_t len) {
143 checkHandshake();
144 int32_t bytes = 0;
145 for (int32_t retries = 0; retries < maxRecvRetries_; retries++){
146 bytes = SSL_read(ssl_, buf, len);
147 if (bytes >= 0)
148 break;
149 int errno_copy = errno;
150 if (SSL_get_error(ssl_, bytes) == SSL_ERROR_SYSCALL) {
151 if (ERR_get_error() == 0 && errno_copy == EINTR) {
152 continue;
153 }
154 }
155 string errors;
156 buildErrors(errors, errno_copy);
157 throw TSSLException("SSL_read: " + errors);
158 }
159 return bytes;
160}
161
162void TSSLSocket::write(const uint8_t* buf, uint32_t len) {
163 checkHandshake();
164 // loop in case SSL_MODE_ENABLE_PARTIAL_WRITE is set in SSL_CTX.
165 uint32_t written = 0;
166 while (written < len) {
167 int32_t bytes = SSL_write(ssl_, &buf[written], len - written);
168 if (bytes <= 0) {
169 int errno_copy = errno;
170 string errors;
171 buildErrors(errors, errno_copy);
172 throw TSSLException("SSL_write: " + errors);
173 }
174 written += bytes;
175 }
176}
177
178void TSSLSocket::flush() {
179 // Don't throw exception if not open. Thrift servers close socket twice.
180 if (ssl_ == NULL) {
181 return;
182 }
183 checkHandshake();
184 BIO* bio = SSL_get_wbio(ssl_);
185 if (bio == NULL) {
186 throw TSSLException("SSL_get_wbio returns NULL");
187 }
188 if (BIO_flush(bio) != 1) {
189 int errno_copy = errno;
190 string errors;
191 buildErrors(errors, errno_copy);
192 throw TSSLException("BIO_flush: " + errors);
193 }
194}
195
196void TSSLSocket::checkHandshake() {
197 if (!TSocket::isOpen()) {
198 throw TTransportException(TTransportException::NOT_OPEN);
199 }
200 if (ssl_ != NULL) {
201 return;
202 }
203 ssl_ = ctx_->createSSL();
204 SSL_set_fd(ssl_, socket_);
205 int rc;
206 if (server()) {
207 rc = SSL_accept(ssl_);
208 } else {
209 rc = SSL_connect(ssl_);
210 }
211 if (rc <= 0) {
212 int errno_copy = errno;
213 string fname(server() ? "SSL_accept" : "SSL_connect");
214 string errors;
215 buildErrors(errors, errno_copy);
216 throw TSSLException(fname + ": " + errors);
217 }
218 authorize();
219}
220
221void TSSLSocket::authorize() {
222 int rc = SSL_get_verify_result(ssl_);
223 if (rc != X509_V_OK) { // verify authentication result
224 throw TSSLException(string("SSL_get_verify_result(), ") +
225 X509_verify_cert_error_string(rc));
226 }
227
228 X509* cert = SSL_get_peer_certificate(ssl_);
229 if (cert == NULL) {
230 // certificate is not present
231 if (SSL_get_verify_mode(ssl_) & SSL_VERIFY_FAIL_IF_NO_PEER_CERT) {
232 throw TSSLException("authorize: required certificate not present");
233 }
234 // certificate was optional: didn't intend to authorize remote
235 if (server() && access_ != NULL) {
236 throw TSSLException("authorize: certificate required for authorization");
237 }
238 return;
239 }
240 // certificate is present
241 if (access_ == NULL) {
242 X509_free(cert);
243 return;
244 }
245 // both certificate and access manager are present
246
247 string host;
248 sockaddr_storage sa = {};
249 socklen_t saLength = sizeof(sa);
250
251 if (getpeername(socket_, (sockaddr*)&sa, &saLength) != 0) {
252 sa.ss_family = AF_UNSPEC;
253 }
254
255 AccessManager::Decision decision = access_->verify(sa);
256
257 if (decision != AccessManager::SKIP) {
258 X509_free(cert);
259 if (decision != AccessManager::ALLOW) {
260 throw TSSLException("authorize: access denied based on remote IP");
261 }
262 return;
263 }
264
265 // extract subjectAlternativeName
266 STACK_OF(GENERAL_NAME)* alternatives = (STACK_OF(GENERAL_NAME)*)
267 X509_get_ext_d2i(cert, NID_subject_alt_name, NULL, NULL);
268 if (alternatives != NULL) {
269 const int count = sk_GENERAL_NAME_num(alternatives);
270 for (int i = 0; decision == AccessManager::SKIP && i < count; i++) {
271 const GENERAL_NAME* name = sk_GENERAL_NAME_value(alternatives, i);
272 if (name == NULL) {
273 continue;
274 }
275 char* data = (char*)ASN1_STRING_data(name->d.ia5);
276 int length = ASN1_STRING_length(name->d.ia5);
277 switch (name->type) {
278 case GEN_DNS:
279 if (host.empty()) {
280 host = (server() ? getPeerHost() : getHost());
281 }
282 decision = access_->verify(host, data, length);
283 break;
284 case GEN_IPADD:
285 decision = access_->verify(sa, data, length);
286 break;
287 }
288 }
289 sk_GENERAL_NAME_pop_free(alternatives, GENERAL_NAME_free);
290 }
291
292 if (decision != AccessManager::SKIP) {
293 X509_free(cert);
294 if (decision != AccessManager::ALLOW) {
295 throw TSSLException("authorize: access denied");
296 }
297 return;
298 }
299
300 // extract commonName
301 X509_NAME* name = X509_get_subject_name(cert);
302 if (name != NULL) {
303 X509_NAME_ENTRY* entry;
304 unsigned char* utf8;
305 int last = -1;
306 while (decision == AccessManager::SKIP) {
307 last = X509_NAME_get_index_by_NID(name, NID_commonName, last);
308 if (last == -1)
309 break;
310 entry = X509_NAME_get_entry(name, last);
311 if (entry == NULL)
312 continue;
313 ASN1_STRING* common = X509_NAME_ENTRY_get_data(entry);
314 int size = ASN1_STRING_to_UTF8(&utf8, common);
315 if (host.empty()) {
316 host = (server() ? getHost() : getHost());
317 }
318 decision = access_->verify(host, (char*)utf8, size);
319 OPENSSL_free(utf8);
320 }
321 }
322 X509_free(cert);
323 if (decision != AccessManager::ALLOW) {
324 throw TSSLException("authorize: cannot authorize peer");
325 }
326}
327
328// TSSLSocketFactory implementation
329bool TSSLSocketFactory::initialized = false;
330uint64_t TSSLSocketFactory::count_ = 0;
331Mutex TSSLSocketFactory::mutex_;
332
333TSSLSocketFactory::TSSLSocketFactory(): server_(false) {
334 Guard guard(mutex_);
335 if (count_ == 0) {
336 initializeOpenSSL();
337 randomize();
338 }
339 count_++;
340 ctx_ = shared_ptr<SSLContext>(new SSLContext);
341}
342
343TSSLSocketFactory::~TSSLSocketFactory() {
344 Guard guard(mutex_);
345 count_--;
346 if (count_ == 0) {
347 cleanupOpenSSL();
348 }
349}
350
351shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket() {
352 shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_));
353 setup(ssl);
354 return ssl;
355}
356
357shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(int socket) {
358 shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, socket));
359 setup(ssl);
360 return ssl;
361}
362
363shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(const string& host,
364 int port) {
365 shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, host, port));
366 setup(ssl);
367 return ssl;
368}
369
370void TSSLSocketFactory::setup(shared_ptr<TSSLSocket> ssl) {
371 ssl->server(server());
372 if (access_ == NULL && !server()) {
373 access_ = shared_ptr<AccessManager>(new DefaultClientAccessManager);
374 }
375 if (access_ != NULL) {
376 ssl->access(access_);
377 }
378}
379
380void TSSLSocketFactory::ciphers(const string& enable) {
381 int rc = SSL_CTX_set_cipher_list(ctx_->get(), enable.c_str());
382 if (ERR_peek_error() != 0) {
383 string errors;
384 buildErrors(errors);
385 throw TSSLException("SSL_CTX_set_cipher_list: " + errors);
386 }
387 if (rc == 0) {
388 throw TSSLException("None of specified ciphers are supported");
389 }
390}
391
392void TSSLSocketFactory::authenticate(bool required) {
393 int mode;
394 if (required) {
395 mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
396 } else {
397 mode = SSL_VERIFY_NONE;
398 }
399 SSL_CTX_set_verify(ctx_->get(), mode, NULL);
400}
401
402void TSSLSocketFactory::loadCertificate(const char* path, const char* format) {
403 if (path == NULL || format == NULL) {
404 throw TTransportException(TTransportException::BAD_ARGS,
405 "loadCertificateChain: either <path> or <format> is NULL");
406 }
407 if (strcmp(format, "PEM") == 0) {
408 if (SSL_CTX_use_certificate_chain_file(ctx_->get(), path) == 0) {
409 int errno_copy = errno;
410 string errors;
411 buildErrors(errors, errno_copy);
412 throw TSSLException("SSL_CTX_use_certificate_chain_file: " + errors);
413 }
414 } else {
415 throw TSSLException("Unsupported certificate format: " + string(format));
416 }
417}
418
419void TSSLSocketFactory::loadPrivateKey(const char* path, const char* format) {
420 if (path == NULL || format == NULL) {
421 throw TTransportException(TTransportException::BAD_ARGS,
422 "loadPrivateKey: either <path> or <format> is NULL");
423 }
424 if (strcmp(format, "PEM") == 0) {
425 if (SSL_CTX_use_PrivateKey_file(ctx_->get(), path, SSL_FILETYPE_PEM) == 0) {
426 int errno_copy = errno;
427 string errors;
428 buildErrors(errors, errno_copy);
429 throw TSSLException("SSL_CTX_use_PrivateKey_file: " + errors);
430 }
431 }
432}
433
434void TSSLSocketFactory::loadTrustedCertificates(const char* path) {
435 if (path == NULL) {
436 throw TTransportException(TTransportException::BAD_ARGS,
437 "loadTrustedCertificates: <path> is NULL");
438 }
439 if (SSL_CTX_load_verify_locations(ctx_->get(), path, NULL) == 0) {
440 int errno_copy = errno;
441 string errors;
442 buildErrors(errors, errno_copy);
443 throw TSSLException("SSL_CTX_load_verify_locations: " + errors);
444 }
445}
446
447void TSSLSocketFactory::randomize() {
448 RAND_poll();
449}
450
451void TSSLSocketFactory::overrideDefaultPasswordCallback() {
452 SSL_CTX_set_default_passwd_cb(ctx_->get(), passwordCallback);
453 SSL_CTX_set_default_passwd_cb_userdata(ctx_->get(), this);
454}
455
456int TSSLSocketFactory::passwordCallback(char* password,
457 int size,
458 int,
459 void* data) {
460 TSSLSocketFactory* factory = (TSSLSocketFactory*)data;
461 string userPassword;
462 factory->getPassword(userPassword, size);
463 int length = userPassword.size();
464 if (length > size) {
465 length = size;
466 }
467 strncpy(password, userPassword.c_str(), length);
468 return length;
469}
470
471static shared_array<Mutex> mutexes;
472
473static void callbackLocking(int mode, int n, const char*, int) {
474 if (mode & CRYPTO_LOCK) {
475 mutexes[n].lock();
476 } else {
477 mutexes[n].unlock();
478 }
479}
480
#if (OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_NO_THREAD_ID)
// Thread-id hook, compiled only for OpenSSL versions below
// OPENSSL_VERSION_NO_THREAD_ID. A C-style cast is kept on purpose: pthread_t
// may be an integer or a pointer type depending on the platform.
static unsigned long callbackThreadID() {
  const unsigned long id = (unsigned long) pthread_self();
  return id;
}
#endif
486
487static CRYPTO_dynlock_value* dyn_create(const char*, int) {
488 return new CRYPTO_dynlock_value;
489}
490
491static void dyn_lock(int mode,
492 struct CRYPTO_dynlock_value* lock,
493 const char*, int) {
494 if (lock != NULL) {
495 if (mode & CRYPTO_LOCK) {
496 lock->mutex.lock();
497 } else {
498 lock->mutex.unlock();
499 }
500 }
501}
502
// CRYPTO_set_dynlock_destroy_callback() hook: free a lock made by dyn_create().
// Deleting a NULL pointer is a no-op.
static void dyn_destroy(struct CRYPTO_dynlock_value* lock,
                        const char*, int) {
  delete lock;
}
506
507void TSSLSocketFactory::initializeOpenSSL() {
508 if (initialized) {
509 return;
510 }
511 initialized = true;
512 SSL_library_init();
513 SSL_load_error_strings();
514 // static locking
515 mutexes = shared_array<Mutex>(new Mutex[::CRYPTO_num_locks()]);
516 if (mutexes == NULL) {
517 throw TTransportException(TTransportException::INTERNAL_ERROR,
518 "initializeOpenSSL() failed, "
519 "out of memory while creating mutex array");
520 }
521#if (OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_NO_THREAD_ID)
522 CRYPTO_set_id_callback(callbackThreadID);
523#endif
524 CRYPTO_set_locking_callback(callbackLocking);
525 // dynamic locking
526 CRYPTO_set_dynlock_create_callback(dyn_create);
527 CRYPTO_set_dynlock_lock_callback(dyn_lock);
528 CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
529}
530
531void TSSLSocketFactory::cleanupOpenSSL() {
532 if (!initialized) {
533 return;
534 }
535 initialized = false;
536#if (OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_NO_THREAD_ID)
537 CRYPTO_set_id_callback(NULL);
538#endif
539 CRYPTO_set_locking_callback(NULL);
540 CRYPTO_set_dynlock_create_callback(NULL);
541 CRYPTO_set_dynlock_lock_callback(NULL);
542 CRYPTO_set_dynlock_destroy_callback(NULL);
543 CRYPTO_cleanup_all_ex_data();
544 ERR_free_strings();
545 EVP_cleanup();
546 ERR_remove_state(0);
547 mutexes.reset();
548}
549
550// extract error messages from error queue
551void buildErrors(string& errors, int errno_copy) {
552 unsigned long errorCode;
553 char message[256];
554
555 errors.reserve(512);
556 while ((errorCode = ERR_get_error()) != 0) {
557 if (!errors.empty()) {
558 errors += "; ";
559 }
560 const char* reason = ERR_reason_error_string(errorCode);
561 if (reason == NULL) {
562 snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
563 reason = message;
564 }
565 errors += reason;
566 }
567 if (errors.empty()) {
568 if (errno_copy != 0) {
569 errors += TOutput::strerror_s(errno_copy);
570 }
571 }
572 if (errors.empty()) {
573 errors = "error code: " + lexical_cast<string>(errno_copy);
574 }
575}
576
577/**
578 * Default implementation of AccessManager
579 */
580Decision DefaultClientAccessManager::verify(const sockaddr_storage& sa)
581 throw() { return SKIP; }
582
583Decision DefaultClientAccessManager::verify(const string& host,
584 const char* name,
585 int size) throw() {
586 if (host.empty() || name == NULL || size <= 0) {
587 return SKIP;
588 }
589 return (matchName(host.c_str(), name, size) ? ALLOW : SKIP);
590}
591
592Decision DefaultClientAccessManager::verify(const sockaddr_storage& sa,
593 const char* data,
594 int size) throw() {
595 bool match = false;
596 if (sa.ss_family == AF_INET && size == sizeof(in_addr)) {
597 match = (memcmp(&((sockaddr_in*)&sa)->sin_addr, data, size) == 0);
598 } else if (sa.ss_family == AF_INET6 && size == sizeof(in6_addr)) {
599 match = (memcmp(&((sockaddr_in6*)&sa)->sin6_addr, data, size) == 0);
600 }
601 return (match ? ALLOW : SKIP);
602}
603
// This is to work around the Turkish locale issue, i.e.,
// toupper('i') != toupper('I') if locale is "tr_TR".
// Uppercases ASCII 'a'..'z' only and leaves every other char unchanged,
// independent of the current locale.
char uppercase(char c) {
  return ('a' <= c && c <= 'z') ? static_cast<char>(c - 'a' + 'A') : c;
}

/**
 * Case-insensitively matches a host name against a certificate name.
 *
 * A '*' in pattern that does not literally match the host character consumes
 * host characters up to (not including) the next '.' or the end. It can
 * therefore cover at most one domain component. There is no backtracking.
 *
 * @param host    NUL-terminated host name, typically the remote host
 * @param pattern name retrieved from the certificate (need not be
 *                NUL-terminated)
 * @param size    number of bytes of pattern to use
 * @return true if all size pattern bytes and all of host were consumed
 *         together, false otherwise.
 */
bool matchName(const char* host, const char* pattern, int size) {
  int p = 0;             // position in pattern
  const char* h = host;  // position in host
  while (p < size && *h != '\0') {
    if (uppercase(pattern[p]) == uppercase(*h)) {
      ++p;
      ++h;
    } else if (pattern[p] == '*') {
      while (*h != '.' && *h != '\0') {
        ++h;
      }
      ++p;
    } else {
      // A mismatch with host characters left can never complete a match.
      return false;
    }
  }
  return p == size && *h == '\0';
}
646
647}}}