blob: 5e18e9436e390d6f78d4581ca58b17e9a24031ab [file] [log] [blame]
Bryan Duxburycd9aea12011-02-22 18:12:06 +00001// Copyright (c) 2009- Facebook
2// Distributed under the Thrift Software License
3//
4// See accompanying file LICENSE or visit the Thrift site at:
5// http://developers.facebook.com/thrift/
6
7#include <errno.h>
8#include <string>
9#include <arpa/inet.h>
10#include <boost/lexical_cast.hpp>
11#include <boost/shared_array.hpp>
12#include <openssl/err.h>
13#include <openssl/rand.h>
14#include <openssl/ssl.h>
15#include <openssl/x509v3.h>
16#include "concurrency/Mutex.h"
17#include "TSSLSocket.h"
18
19#define OPENSSL_VERSION_NO_THREAD_ID 0x10000000L
20
21using namespace std;
22using namespace boost;
23using namespace apache::thrift::concurrency;
24
25struct CRYPTO_dynlock_value {
26 Mutex mutex;
27};
28
29namespace apache { namespace thrift { namespace transport {
30
31
32static void buildErrors(string& message, int error = 0);
33static bool matchName(const char* host, const char* pattern, int size);
34static char uppercase(char c);
35
36// SSLContext implementation
37SSLContext::SSLContext() {
38 ctx_ = SSL_CTX_new(TLSv1_method());
39 if (ctx_ == NULL) {
40 string errors;
41 buildErrors(errors);
42 throw TSSLException("SSL_CTX_new: " + errors);
43 }
44 SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
45}
46
47SSLContext::~SSLContext() {
48 if (ctx_ != NULL) {
49 SSL_CTX_free(ctx_);
50 ctx_ = NULL;
51 }
52}
53
54SSL* SSLContext::createSSL() {
55 SSL* ssl = SSL_new(ctx_);
56 if (ssl == NULL) {
57 string errors;
58 buildErrors(errors);
59 throw TSSLException("SSL_new: " + errors);
60 }
61 return ssl;
62}
63
64// TSSLSocket implementation
65TSSLSocket::TSSLSocket(shared_ptr<SSLContext> ctx):
66 TSocket(), server_(false), ssl_(NULL), ctx_(ctx) {
67}
68
69TSSLSocket::TSSLSocket(shared_ptr<SSLContext> ctx, int socket):
70 TSocket(socket), server_(false), ssl_(NULL), ctx_(ctx) {
71}
72
73TSSLSocket::TSSLSocket(shared_ptr<SSLContext> ctx, string host, int port):
74 TSocket(host, port), server_(false), ssl_(NULL), ctx_(ctx) {
75}
76
77TSSLSocket::~TSSLSocket() {
78 close();
79}
80
81bool TSSLSocket::isOpen() {
82 if (ssl_ == NULL || !TSocket::isOpen()) {
83 return false;
84 }
85 int shutdown = SSL_get_shutdown(ssl_);
86 bool shutdownReceived = (shutdown & SSL_RECEIVED_SHUTDOWN);
87 bool shutdownSent = (shutdown & SSL_SENT_SHUTDOWN);
88 if (shutdownReceived && shutdownSent) {
89 return false;
90 }
91 return true;
92}
93
94bool TSSLSocket::peek() {
95 if (!isOpen()) {
96 return false;
97 }
98 checkHandshake();
99 int rc;
100 uint8_t byte;
101 rc = SSL_peek(ssl_, &byte, 1);
102 if (rc < 0) {
103 int errno_copy = errno;
104 string errors;
105 buildErrors(errors, errno_copy);
106 throw TSSLException("SSL_peek: " + errors);
107 }
108 if (rc == 0) {
109 ERR_clear_error();
110 }
111 return (rc > 0);
112}
113
114void TSSLSocket::open() {
115 if (isOpen() || server()) {
116 throw TTransportException(TTransportException::BAD_ARGS);
117 }
118 TSocket::open();
119}
120
121void TSSLSocket::close() {
122 if (ssl_ != NULL) {
123 int rc = SSL_shutdown(ssl_);
124 if (rc == 0) {
125 rc = SSL_shutdown(ssl_);
126 }
127 if (rc < 0) {
128 int errno_copy = errno;
129 string errors;
130 buildErrors(errors, errno_copy);
131 GlobalOutput(("SSL_shutdown: " + errors).c_str());
132 }
133 SSL_free(ssl_);
134 ssl_ = NULL;
135 ERR_remove_state(0);
136 }
137 TSocket::close();
138}
139
140uint32_t TSSLSocket::read(uint8_t* buf, uint32_t len) {
141 checkHandshake();
142 int32_t bytes = 0;
143 for (int32_t retries = 0; retries < maxRecvRetries_; retries++){
144 bytes = SSL_read(ssl_, buf, len);
145 if (bytes >= 0)
146 break;
147 int errno_copy = errno;
148 if (SSL_get_error(ssl_, bytes) == SSL_ERROR_SYSCALL) {
149 if (ERR_get_error() == 0 && errno_copy == EINTR) {
150 continue;
151 }
152 }
153 string errors;
154 buildErrors(errors, errno_copy);
155 throw TSSLException("SSL_read: " + errors);
156 }
157 return bytes;
158}
159
160void TSSLSocket::write(const uint8_t* buf, uint32_t len) {
161 checkHandshake();
162 // loop in case SSL_MODE_ENABLE_PARTIAL_WRITE is set in SSL_CTX.
163 uint32_t written = 0;
164 while (written < len) {
165 int32_t bytes = SSL_write(ssl_, &buf[written], len - written);
166 if (bytes <= 0) {
167 int errno_copy = errno;
168 string errors;
169 buildErrors(errors, errno_copy);
170 throw TSSLException("SSL_write: " + errors);
171 }
172 written += bytes;
173 }
174}
175
176void TSSLSocket::flush() {
177 // Don't throw exception if not open. Thrift servers close socket twice.
178 if (ssl_ == NULL) {
179 return;
180 }
181 checkHandshake();
182 BIO* bio = SSL_get_wbio(ssl_);
183 if (bio == NULL) {
184 throw TSSLException("SSL_get_wbio returns NULL");
185 }
186 if (BIO_flush(bio) != 1) {
187 int errno_copy = errno;
188 string errors;
189 buildErrors(errors, errno_copy);
190 throw TSSLException("BIO_flush: " + errors);
191 }
192}
193
194void TSSLSocket::checkHandshake() {
195 if (!TSocket::isOpen()) {
196 throw TTransportException(TTransportException::NOT_OPEN);
197 }
198 if (ssl_ != NULL) {
199 return;
200 }
201 ssl_ = ctx_->createSSL();
202 SSL_set_fd(ssl_, socket_);
203 int rc;
204 if (server()) {
205 rc = SSL_accept(ssl_);
206 } else {
207 rc = SSL_connect(ssl_);
208 }
209 if (rc <= 0) {
210 int errno_copy = errno;
211 string fname(server() ? "SSL_accept" : "SSL_connect");
212 string errors;
213 buildErrors(errors, errno_copy);
214 throw TSSLException(fname + ": " + errors);
215 }
216 authorize();
217}
218
219void TSSLSocket::authorize() {
220 int rc = SSL_get_verify_result(ssl_);
221 if (rc != X509_V_OK) { // verify authentication result
222 throw TSSLException(string("SSL_get_verify_result(), ") +
223 X509_verify_cert_error_string(rc));
224 }
225
226 X509* cert = SSL_get_peer_certificate(ssl_);
227 if (cert == NULL) {
228 // certificate is not present
229 if (SSL_get_verify_mode(ssl_) & SSL_VERIFY_FAIL_IF_NO_PEER_CERT) {
230 throw TSSLException("authorize: required certificate not present");
231 }
232 // certificate was optional: didn't intend to authorize remote
233 if (server() && access_ != NULL) {
234 throw TSSLException("authorize: certificate required for authorization");
235 }
236 return;
237 }
238 // certificate is present
239 if (access_ == NULL) {
240 X509_free(cert);
241 return;
242 }
243 // both certificate and access manager are present
244
245 string host;
246 sockaddr_storage sa = {};
247 socklen_t saLength = sizeof(sa);
248
249 if (getpeername(socket_, (sockaddr*)&sa, &saLength) != 0) {
250 sa.ss_family = AF_UNSPEC;
251 }
252
253 AccessManager::Decision decision = access_->verify(sa);
254
255 if (decision != AccessManager::SKIP) {
256 X509_free(cert);
257 if (decision != AccessManager::ALLOW) {
258 throw TSSLException("authorize: access denied based on remote IP");
259 }
260 return;
261 }
262
263 // extract subjectAlternativeName
264 STACK_OF(GENERAL_NAME)* alternatives = (STACK_OF(GENERAL_NAME)*)
265 X509_get_ext_d2i(cert, NID_subject_alt_name, NULL, NULL);
266 if (alternatives != NULL) {
267 const int count = sk_GENERAL_NAME_num(alternatives);
268 for (int i = 0; decision == AccessManager::SKIP && i < count; i++) {
269 const GENERAL_NAME* name = sk_GENERAL_NAME_value(alternatives, i);
270 if (name == NULL) {
271 continue;
272 }
273 char* data = (char*)ASN1_STRING_data(name->d.ia5);
274 int length = ASN1_STRING_length(name->d.ia5);
275 switch (name->type) {
276 case GEN_DNS:
277 if (host.empty()) {
278 host = (server() ? getPeerHost() : getHost());
279 }
280 decision = access_->verify(host, data, length);
281 break;
282 case GEN_IPADD:
283 decision = access_->verify(sa, data, length);
284 break;
285 }
286 }
287 sk_GENERAL_NAME_pop_free(alternatives, GENERAL_NAME_free);
288 }
289
290 if (decision != AccessManager::SKIP) {
291 X509_free(cert);
292 if (decision != AccessManager::ALLOW) {
293 throw TSSLException("authorize: access denied");
294 }
295 return;
296 }
297
298 // extract commonName
299 X509_NAME* name = X509_get_subject_name(cert);
300 if (name != NULL) {
301 X509_NAME_ENTRY* entry;
302 unsigned char* utf8;
303 int last = -1;
304 while (decision == AccessManager::SKIP) {
305 last = X509_NAME_get_index_by_NID(name, NID_commonName, last);
306 if (last == -1)
307 break;
308 entry = X509_NAME_get_entry(name, last);
309 if (entry == NULL)
310 continue;
311 ASN1_STRING* common = X509_NAME_ENTRY_get_data(entry);
312 int size = ASN1_STRING_to_UTF8(&utf8, common);
313 if (host.empty()) {
314 host = (server() ? getHost() : getHost());
315 }
316 decision = access_->verify(host, (char*)utf8, size);
317 OPENSSL_free(utf8);
318 }
319 }
320 X509_free(cert);
321 if (decision != AccessManager::ALLOW) {
322 throw TSSLException("authorize: cannot authorize peer");
323 }
324}
325
326// TSSLSocketFactory implementation
327bool TSSLSocketFactory::initialized = false;
328uint64_t TSSLSocketFactory::count_ = 0;
329Mutex TSSLSocketFactory::mutex_;
330
331TSSLSocketFactory::TSSLSocketFactory(): server_(false) {
332 Guard guard(mutex_);
333 if (count_ == 0) {
334 initializeOpenSSL();
335 randomize();
336 }
337 count_++;
338 ctx_ = shared_ptr<SSLContext>(new SSLContext);
339}
340
341TSSLSocketFactory::~TSSLSocketFactory() {
342 Guard guard(mutex_);
343 count_--;
344 if (count_ == 0) {
345 cleanupOpenSSL();
346 }
347}
348
349shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket() {
350 shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_));
351 setup(ssl);
352 return ssl;
353}
354
355shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(int socket) {
356 shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, socket));
357 setup(ssl);
358 return ssl;
359}
360
361shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(const string& host,
362 int port) {
363 shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, host, port));
364 setup(ssl);
365 return ssl;
366}
367
368void TSSLSocketFactory::setup(shared_ptr<TSSLSocket> ssl) {
369 ssl->server(server());
370 if (access_ == NULL && !server()) {
371 access_ = shared_ptr<AccessManager>(new DefaultClientAccessManager);
372 }
373 if (access_ != NULL) {
374 ssl->access(access_);
375 }
376}
377
378void TSSLSocketFactory::ciphers(const string& enable) {
379 int rc = SSL_CTX_set_cipher_list(ctx_->get(), enable.c_str());
380 if (ERR_peek_error() != 0) {
381 string errors;
382 buildErrors(errors);
383 throw TSSLException("SSL_CTX_set_cipher_list: " + errors);
384 }
385 if (rc == 0) {
386 throw TSSLException("None of specified ciphers are supported");
387 }
388}
389
390void TSSLSocketFactory::authenticate(bool required) {
391 int mode;
392 if (required) {
393 mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
394 } else {
395 mode = SSL_VERIFY_NONE;
396 }
397 SSL_CTX_set_verify(ctx_->get(), mode, NULL);
398}
399
400void TSSLSocketFactory::loadCertificate(const char* path, const char* format) {
401 if (path == NULL || format == NULL) {
402 throw TTransportException(TTransportException::BAD_ARGS,
403 "loadCertificateChain: either <path> or <format> is NULL");
404 }
405 if (strcmp(format, "PEM") == 0) {
406 if (SSL_CTX_use_certificate_chain_file(ctx_->get(), path) == 0) {
407 int errno_copy = errno;
408 string errors;
409 buildErrors(errors, errno_copy);
410 throw TSSLException("SSL_CTX_use_certificate_chain_file: " + errors);
411 }
412 } else {
413 throw TSSLException("Unsupported certificate format: " + string(format));
414 }
415}
416
417void TSSLSocketFactory::loadPrivateKey(const char* path, const char* format) {
418 if (path == NULL || format == NULL) {
419 throw TTransportException(TTransportException::BAD_ARGS,
420 "loadPrivateKey: either <path> or <format> is NULL");
421 }
422 if (strcmp(format, "PEM") == 0) {
423 if (SSL_CTX_use_PrivateKey_file(ctx_->get(), path, SSL_FILETYPE_PEM) == 0) {
424 int errno_copy = errno;
425 string errors;
426 buildErrors(errors, errno_copy);
427 throw TSSLException("SSL_CTX_use_PrivateKey_file: " + errors);
428 }
429 }
430}
431
432void TSSLSocketFactory::loadTrustedCertificates(const char* path) {
433 if (path == NULL) {
434 throw TTransportException(TTransportException::BAD_ARGS,
435 "loadTrustedCertificates: <path> is NULL");
436 }
437 if (SSL_CTX_load_verify_locations(ctx_->get(), path, NULL) == 0) {
438 int errno_copy = errno;
439 string errors;
440 buildErrors(errors, errno_copy);
441 throw TSSLException("SSL_CTX_load_verify_locations: " + errors);
442 }
443}
444
445void TSSLSocketFactory::randomize() {
446 RAND_poll();
447}
448
449void TSSLSocketFactory::overrideDefaultPasswordCallback() {
450 SSL_CTX_set_default_passwd_cb(ctx_->get(), passwordCallback);
451 SSL_CTX_set_default_passwd_cb_userdata(ctx_->get(), this);
452}
453
454int TSSLSocketFactory::passwordCallback(char* password,
455 int size,
456 int,
457 void* data) {
458 TSSLSocketFactory* factory = (TSSLSocketFactory*)data;
459 string userPassword;
460 factory->getPassword(userPassword, size);
461 int length = userPassword.size();
462 if (length > size) {
463 length = size;
464 }
465 strncpy(password, userPassword.c_str(), length);
466 return length;
467}
468
469static shared_array<Mutex> mutexes;
470
471static void callbackLocking(int mode, int n, const char*, int) {
472 if (mode & CRYPTO_LOCK) {
473 mutexes[n].lock();
474 } else {
475 mutexes[n].unlock();
476 }
477}
478
#if (OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_NO_THREAD_ID)
// Thread-id callback. It is compiled and installed (see initializeOpenSSL())
// only for OpenSSL versions older than OPENSSL_VERSION_NO_THREAD_ID.
static unsigned long callbackThreadID() {
  const pthread_t self = pthread_self();
  return (unsigned long) self;
}
#endif
484
485static CRYPTO_dynlock_value* dyn_create(const char*, int) {
486 return new CRYPTO_dynlock_value;
487}
488
489static void dyn_lock(int mode,
490 struct CRYPTO_dynlock_value* lock,
491 const char*, int) {
492 if (lock != NULL) {
493 if (mode & CRYPTO_LOCK) {
494 lock->mutex.lock();
495 } else {
496 lock->mutex.unlock();
497 }
498 }
499}
500
501static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
502 delete lock;
503}
504
505void TSSLSocketFactory::initializeOpenSSL() {
506 if (initialized) {
507 return;
508 }
509 initialized = true;
510 SSL_library_init();
511 SSL_load_error_strings();
512 // static locking
513 mutexes = shared_array<Mutex>(new Mutex[::CRYPTO_num_locks()]);
514 if (mutexes == NULL) {
515 throw TTransportException(TTransportException::INTERNAL_ERROR,
516 "initializeOpenSSL() failed, "
517 "out of memory while creating mutex array");
518 }
519#if (OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_NO_THREAD_ID)
520 CRYPTO_set_id_callback(callbackThreadID);
521#endif
522 CRYPTO_set_locking_callback(callbackLocking);
523 // dynamic locking
524 CRYPTO_set_dynlock_create_callback(dyn_create);
525 CRYPTO_set_dynlock_lock_callback(dyn_lock);
526 CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
527}
528
529void TSSLSocketFactory::cleanupOpenSSL() {
530 if (!initialized) {
531 return;
532 }
533 initialized = false;
534#if (OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_NO_THREAD_ID)
535 CRYPTO_set_id_callback(NULL);
536#endif
537 CRYPTO_set_locking_callback(NULL);
538 CRYPTO_set_dynlock_create_callback(NULL);
539 CRYPTO_set_dynlock_lock_callback(NULL);
540 CRYPTO_set_dynlock_destroy_callback(NULL);
541 CRYPTO_cleanup_all_ex_data();
542 ERR_free_strings();
543 EVP_cleanup();
544 ERR_remove_state(0);
545 mutexes.reset();
546}
547
548// extract error messages from error queue
549void buildErrors(string& errors, int errno_copy) {
550 unsigned long errorCode;
551 char message[256];
552
553 errors.reserve(512);
554 while ((errorCode = ERR_get_error()) != 0) {
555 if (!errors.empty()) {
556 errors += "; ";
557 }
558 const char* reason = ERR_reason_error_string(errorCode);
559 if (reason == NULL) {
560 snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
561 reason = message;
562 }
563 errors += reason;
564 }
565 if (errors.empty()) {
566 if (errno_copy != 0) {
567 errors += TOutput::strerror_s(errno_copy);
568 }
569 }
570 if (errors.empty()) {
571 errors = "error code: " + lexical_cast<string>(errno_copy);
572 }
573}
574
575/**
576 * Default implementation of AccessManager
577 */
578Decision DefaultClientAccessManager::verify(const sockaddr_storage& sa)
579 throw() { return SKIP; }
580
581Decision DefaultClientAccessManager::verify(const string& host,
582 const char* name,
583 int size) throw() {
584 if (host.empty() || name == NULL || size <= 0) {
585 return SKIP;
586 }
587 return (matchName(host.c_str(), name, size) ? ALLOW : SKIP);
588}
589
590Decision DefaultClientAccessManager::verify(const sockaddr_storage& sa,
591 const char* data,
592 int size) throw() {
593 bool match = false;
594 if (sa.ss_family == AF_INET && size == sizeof(in_addr)) {
595 match = (memcmp(&((sockaddr_in*)&sa)->sin_addr, data, size) == 0);
596 } else if (sa.ss_family == AF_INET6 && size == sizeof(in6_addr)) {
597 match = (memcmp(&((sockaddr_in6*)&sa)->sin6_addr, data, size) == 0);
598 }
599 return (match ? ALLOW : SKIP);
600}
601
// Locale-independent ASCII upper-casing. toupper() is avoided because under
// the Turkish locale ("tr_TR") toupper('i') != toupper('I').
char uppercase(char c) {
  const bool isLower = (c >= 'a' && c <= 'z');
  return isLower ? static_cast<char>(c - 'a' + 'A') : c;
}

/**
 * Case-insensitively matches a host name against a certificate name pattern.
 * A '*' in the pattern consumes host characters up to (not including) the
 * next '.', so a single wildcard matches at most one domain-name component.
 *
 * @param host    NUL-terminated host name, typically the remote host
 * @param pattern name retrieved from the certificate (need not be terminated)
 * @param size    number of characters of `pattern` to use
 * @return true if all of `pattern` and all of `host` were consumed.
 */
bool matchName(const char* host, const char* pattern, int size) {
  int p = 0;           // position in pattern
  const char* h = host;  // position in host
  while (p < size && *h != '\0') {
    if (uppercase(pattern[p]) == uppercase(*h)) {
      ++p;
      ++h;
    } else if (pattern[p] == '*') {
      while (*h != '\0' && *h != '.') {
        ++h;
      }
      ++p;
    } else {
      break;
    }
  }
  return p == size && *h == '\0';
}
644
645}}}