blob: 7865be04add629b615b52645594462331be115f0 [file] [log] [blame]
Roger Meier0ae05422011-04-12 19:23:10 +00001/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
Bryan Duxburycd9aea12011-02-22 18:12:06 +000019
Roger Meier2fa9c312011-09-05 19:15:53 +000020#ifdef HAVE_CONFIG_H
21#include <config.h>
22#endif
Bryan Duxburycd9aea12011-02-22 18:12:06 +000023#include <errno.h>
24#include <string>
Roger Meier2fa9c312011-09-05 19:15:53 +000025#ifdef HAVE_ARPA_INET_H
Bryan Duxburycd9aea12011-02-22 18:12:06 +000026#include <arpa/inet.h>
Roger Meier2fa9c312011-09-05 19:15:53 +000027#endif
Roger Meier34603932011-03-25 12:22:17 +000028#include <sys/types.h>
Roger Meier2fa9c312011-09-05 19:15:53 +000029#ifdef HAVE_SYS_SOCKET_H
Roger Meier34603932011-03-25 12:22:17 +000030#include <sys/socket.h>
Roger Meier2fa9c312011-09-05 19:15:53 +000031#endif
Bryan Duxburycd9aea12011-02-22 18:12:06 +000032#include <boost/lexical_cast.hpp>
33#include <boost/shared_array.hpp>
34#include <openssl/err.h>
35#include <openssl/rand.h>
36#include <openssl/ssl.h>
37#include <openssl/x509v3.h>
38#include "concurrency/Mutex.h"
39#include "TSSLSocket.h"
40
41#define OPENSSL_VERSION_NO_THREAD_ID 0x10000000L
42
43using namespace std;
44using namespace boost;
45using namespace apache::thrift::concurrency;
46
47struct CRYPTO_dynlock_value {
48 Mutex mutex;
49};
50
51namespace apache { namespace thrift { namespace transport {
52
53
54static void buildErrors(string& message, int error = 0);
55static bool matchName(const char* host, const char* pattern, int size);
56static char uppercase(char c);
57
58// SSLContext implementation
59SSLContext::SSLContext() {
60 ctx_ = SSL_CTX_new(TLSv1_method());
61 if (ctx_ == NULL) {
62 string errors;
63 buildErrors(errors);
64 throw TSSLException("SSL_CTX_new: " + errors);
65 }
66 SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
67}
68
69SSLContext::~SSLContext() {
70 if (ctx_ != NULL) {
71 SSL_CTX_free(ctx_);
72 ctx_ = NULL;
73 }
74}
75
76SSL* SSLContext::createSSL() {
77 SSL* ssl = SSL_new(ctx_);
78 if (ssl == NULL) {
79 string errors;
80 buildErrors(errors);
81 throw TSSLException("SSL_new: " + errors);
82 }
83 return ssl;
84}
85
86// TSSLSocket implementation
87TSSLSocket::TSSLSocket(shared_ptr<SSLContext> ctx):
88 TSocket(), server_(false), ssl_(NULL), ctx_(ctx) {
89}
90
91TSSLSocket::TSSLSocket(shared_ptr<SSLContext> ctx, int socket):
92 TSocket(socket), server_(false), ssl_(NULL), ctx_(ctx) {
93}
94
95TSSLSocket::TSSLSocket(shared_ptr<SSLContext> ctx, string host, int port):
96 TSocket(host, port), server_(false), ssl_(NULL), ctx_(ctx) {
97}
98
99TSSLSocket::~TSSLSocket() {
100 close();
101}
102
103bool TSSLSocket::isOpen() {
104 if (ssl_ == NULL || !TSocket::isOpen()) {
105 return false;
106 }
107 int shutdown = SSL_get_shutdown(ssl_);
108 bool shutdownReceived = (shutdown & SSL_RECEIVED_SHUTDOWN);
109 bool shutdownSent = (shutdown & SSL_SENT_SHUTDOWN);
110 if (shutdownReceived && shutdownSent) {
111 return false;
112 }
113 return true;
114}
115
116bool TSSLSocket::peek() {
117 if (!isOpen()) {
118 return false;
119 }
120 checkHandshake();
121 int rc;
122 uint8_t byte;
123 rc = SSL_peek(ssl_, &byte, 1);
124 if (rc < 0) {
125 int errno_copy = errno;
126 string errors;
127 buildErrors(errors, errno_copy);
128 throw TSSLException("SSL_peek: " + errors);
129 }
130 if (rc == 0) {
131 ERR_clear_error();
132 }
133 return (rc > 0);
134}
135
136void TSSLSocket::open() {
137 if (isOpen() || server()) {
138 throw TTransportException(TTransportException::BAD_ARGS);
139 }
140 TSocket::open();
141}
142
143void TSSLSocket::close() {
144 if (ssl_ != NULL) {
145 int rc = SSL_shutdown(ssl_);
146 if (rc == 0) {
147 rc = SSL_shutdown(ssl_);
148 }
149 if (rc < 0) {
150 int errno_copy = errno;
151 string errors;
152 buildErrors(errors, errno_copy);
153 GlobalOutput(("SSL_shutdown: " + errors).c_str());
154 }
155 SSL_free(ssl_);
156 ssl_ = NULL;
157 ERR_remove_state(0);
158 }
159 TSocket::close();
160}
161
162uint32_t TSSLSocket::read(uint8_t* buf, uint32_t len) {
163 checkHandshake();
164 int32_t bytes = 0;
165 for (int32_t retries = 0; retries < maxRecvRetries_; retries++){
166 bytes = SSL_read(ssl_, buf, len);
167 if (bytes >= 0)
168 break;
169 int errno_copy = errno;
170 if (SSL_get_error(ssl_, bytes) == SSL_ERROR_SYSCALL) {
171 if (ERR_get_error() == 0 && errno_copy == EINTR) {
172 continue;
173 }
174 }
175 string errors;
176 buildErrors(errors, errno_copy);
177 throw TSSLException("SSL_read: " + errors);
178 }
179 return bytes;
180}
181
182void TSSLSocket::write(const uint8_t* buf, uint32_t len) {
183 checkHandshake();
184 // loop in case SSL_MODE_ENABLE_PARTIAL_WRITE is set in SSL_CTX.
185 uint32_t written = 0;
186 while (written < len) {
187 int32_t bytes = SSL_write(ssl_, &buf[written], len - written);
188 if (bytes <= 0) {
189 int errno_copy = errno;
190 string errors;
191 buildErrors(errors, errno_copy);
192 throw TSSLException("SSL_write: " + errors);
193 }
194 written += bytes;
195 }
196}
197
198void TSSLSocket::flush() {
199 // Don't throw exception if not open. Thrift servers close socket twice.
200 if (ssl_ == NULL) {
201 return;
202 }
203 checkHandshake();
204 BIO* bio = SSL_get_wbio(ssl_);
205 if (bio == NULL) {
206 throw TSSLException("SSL_get_wbio returns NULL");
207 }
208 if (BIO_flush(bio) != 1) {
209 int errno_copy = errno;
210 string errors;
211 buildErrors(errors, errno_copy);
212 throw TSSLException("BIO_flush: " + errors);
213 }
214}
215
216void TSSLSocket::checkHandshake() {
217 if (!TSocket::isOpen()) {
218 throw TTransportException(TTransportException::NOT_OPEN);
219 }
220 if (ssl_ != NULL) {
221 return;
222 }
223 ssl_ = ctx_->createSSL();
224 SSL_set_fd(ssl_, socket_);
225 int rc;
226 if (server()) {
227 rc = SSL_accept(ssl_);
228 } else {
229 rc = SSL_connect(ssl_);
230 }
231 if (rc <= 0) {
232 int errno_copy = errno;
233 string fname(server() ? "SSL_accept" : "SSL_connect");
234 string errors;
235 buildErrors(errors, errno_copy);
236 throw TSSLException(fname + ": " + errors);
237 }
238 authorize();
239}
240
241void TSSLSocket::authorize() {
242 int rc = SSL_get_verify_result(ssl_);
243 if (rc != X509_V_OK) { // verify authentication result
244 throw TSSLException(string("SSL_get_verify_result(), ") +
245 X509_verify_cert_error_string(rc));
246 }
247
248 X509* cert = SSL_get_peer_certificate(ssl_);
249 if (cert == NULL) {
250 // certificate is not present
251 if (SSL_get_verify_mode(ssl_) & SSL_VERIFY_FAIL_IF_NO_PEER_CERT) {
252 throw TSSLException("authorize: required certificate not present");
253 }
254 // certificate was optional: didn't intend to authorize remote
255 if (server() && access_ != NULL) {
256 throw TSSLException("authorize: certificate required for authorization");
257 }
258 return;
259 }
260 // certificate is present
261 if (access_ == NULL) {
262 X509_free(cert);
263 return;
264 }
265 // both certificate and access manager are present
266
267 string host;
Roger Meiera8cef6e2011-07-17 18:55:59 +0000268 sockaddr_storage sa;
Bryan Duxburycd9aea12011-02-22 18:12:06 +0000269 socklen_t saLength = sizeof(sa);
270
271 if (getpeername(socket_, (sockaddr*)&sa, &saLength) != 0) {
272 sa.ss_family = AF_UNSPEC;
273 }
274
275 AccessManager::Decision decision = access_->verify(sa);
276
277 if (decision != AccessManager::SKIP) {
278 X509_free(cert);
279 if (decision != AccessManager::ALLOW) {
280 throw TSSLException("authorize: access denied based on remote IP");
281 }
282 return;
283 }
284
285 // extract subjectAlternativeName
286 STACK_OF(GENERAL_NAME)* alternatives = (STACK_OF(GENERAL_NAME)*)
287 X509_get_ext_d2i(cert, NID_subject_alt_name, NULL, NULL);
288 if (alternatives != NULL) {
289 const int count = sk_GENERAL_NAME_num(alternatives);
290 for (int i = 0; decision == AccessManager::SKIP && i < count; i++) {
291 const GENERAL_NAME* name = sk_GENERAL_NAME_value(alternatives, i);
292 if (name == NULL) {
293 continue;
294 }
295 char* data = (char*)ASN1_STRING_data(name->d.ia5);
296 int length = ASN1_STRING_length(name->d.ia5);
297 switch (name->type) {
298 case GEN_DNS:
299 if (host.empty()) {
300 host = (server() ? getPeerHost() : getHost());
301 }
302 decision = access_->verify(host, data, length);
303 break;
304 case GEN_IPADD:
305 decision = access_->verify(sa, data, length);
306 break;
307 }
308 }
309 sk_GENERAL_NAME_pop_free(alternatives, GENERAL_NAME_free);
310 }
311
312 if (decision != AccessManager::SKIP) {
313 X509_free(cert);
314 if (decision != AccessManager::ALLOW) {
315 throw TSSLException("authorize: access denied");
316 }
317 return;
318 }
319
320 // extract commonName
321 X509_NAME* name = X509_get_subject_name(cert);
322 if (name != NULL) {
323 X509_NAME_ENTRY* entry;
324 unsigned char* utf8;
325 int last = -1;
326 while (decision == AccessManager::SKIP) {
327 last = X509_NAME_get_index_by_NID(name, NID_commonName, last);
328 if (last == -1)
329 break;
330 entry = X509_NAME_get_entry(name, last);
331 if (entry == NULL)
332 continue;
333 ASN1_STRING* common = X509_NAME_ENTRY_get_data(entry);
334 int size = ASN1_STRING_to_UTF8(&utf8, common);
335 if (host.empty()) {
336 host = (server() ? getHost() : getHost());
337 }
338 decision = access_->verify(host, (char*)utf8, size);
339 OPENSSL_free(utf8);
340 }
341 }
342 X509_free(cert);
343 if (decision != AccessManager::ALLOW) {
344 throw TSSLException("authorize: cannot authorize peer");
345 }
346}
347
348// TSSLSocketFactory implementation
349bool TSSLSocketFactory::initialized = false;
350uint64_t TSSLSocketFactory::count_ = 0;
351Mutex TSSLSocketFactory::mutex_;
352
353TSSLSocketFactory::TSSLSocketFactory(): server_(false) {
354 Guard guard(mutex_);
355 if (count_ == 0) {
356 initializeOpenSSL();
357 randomize();
358 }
359 count_++;
360 ctx_ = shared_ptr<SSLContext>(new SSLContext);
361}
362
363TSSLSocketFactory::~TSSLSocketFactory() {
364 Guard guard(mutex_);
365 count_--;
366 if (count_ == 0) {
367 cleanupOpenSSL();
368 }
369}
370
371shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket() {
372 shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_));
373 setup(ssl);
374 return ssl;
375}
376
377shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(int socket) {
378 shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, socket));
379 setup(ssl);
380 return ssl;
381}
382
383shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(const string& host,
384 int port) {
385 shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, host, port));
386 setup(ssl);
387 return ssl;
388}
389
390void TSSLSocketFactory::setup(shared_ptr<TSSLSocket> ssl) {
391 ssl->server(server());
392 if (access_ == NULL && !server()) {
393 access_ = shared_ptr<AccessManager>(new DefaultClientAccessManager);
394 }
395 if (access_ != NULL) {
396 ssl->access(access_);
397 }
398}
399
400void TSSLSocketFactory::ciphers(const string& enable) {
401 int rc = SSL_CTX_set_cipher_list(ctx_->get(), enable.c_str());
402 if (ERR_peek_error() != 0) {
403 string errors;
404 buildErrors(errors);
405 throw TSSLException("SSL_CTX_set_cipher_list: " + errors);
406 }
407 if (rc == 0) {
408 throw TSSLException("None of specified ciphers are supported");
409 }
410}
411
412void TSSLSocketFactory::authenticate(bool required) {
413 int mode;
414 if (required) {
415 mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
416 } else {
417 mode = SSL_VERIFY_NONE;
418 }
419 SSL_CTX_set_verify(ctx_->get(), mode, NULL);
420}
421
422void TSSLSocketFactory::loadCertificate(const char* path, const char* format) {
423 if (path == NULL || format == NULL) {
424 throw TTransportException(TTransportException::BAD_ARGS,
425 "loadCertificateChain: either <path> or <format> is NULL");
426 }
427 if (strcmp(format, "PEM") == 0) {
428 if (SSL_CTX_use_certificate_chain_file(ctx_->get(), path) == 0) {
429 int errno_copy = errno;
430 string errors;
431 buildErrors(errors, errno_copy);
432 throw TSSLException("SSL_CTX_use_certificate_chain_file: " + errors);
433 }
434 } else {
435 throw TSSLException("Unsupported certificate format: " + string(format));
436 }
437}
438
439void TSSLSocketFactory::loadPrivateKey(const char* path, const char* format) {
440 if (path == NULL || format == NULL) {
441 throw TTransportException(TTransportException::BAD_ARGS,
442 "loadPrivateKey: either <path> or <format> is NULL");
443 }
444 if (strcmp(format, "PEM") == 0) {
445 if (SSL_CTX_use_PrivateKey_file(ctx_->get(), path, SSL_FILETYPE_PEM) == 0) {
446 int errno_copy = errno;
447 string errors;
448 buildErrors(errors, errno_copy);
449 throw TSSLException("SSL_CTX_use_PrivateKey_file: " + errors);
450 }
451 }
452}
453
454void TSSLSocketFactory::loadTrustedCertificates(const char* path) {
455 if (path == NULL) {
456 throw TTransportException(TTransportException::BAD_ARGS,
457 "loadTrustedCertificates: <path> is NULL");
458 }
459 if (SSL_CTX_load_verify_locations(ctx_->get(), path, NULL) == 0) {
460 int errno_copy = errno;
461 string errors;
462 buildErrors(errors, errno_copy);
463 throw TSSLException("SSL_CTX_load_verify_locations: " + errors);
464 }
465}
466
467void TSSLSocketFactory::randomize() {
468 RAND_poll();
469}
470
471void TSSLSocketFactory::overrideDefaultPasswordCallback() {
472 SSL_CTX_set_default_passwd_cb(ctx_->get(), passwordCallback);
473 SSL_CTX_set_default_passwd_cb_userdata(ctx_->get(), this);
474}
475
476int TSSLSocketFactory::passwordCallback(char* password,
477 int size,
478 int,
479 void* data) {
480 TSSLSocketFactory* factory = (TSSLSocketFactory*)data;
481 string userPassword;
482 factory->getPassword(userPassword, size);
483 int length = userPassword.size();
484 if (length > size) {
485 length = size;
486 }
487 strncpy(password, userPassword.c_str(), length);
488 return length;
489}
490
491static shared_array<Mutex> mutexes;
492
493static void callbackLocking(int mode, int n, const char*, int) {
494 if (mode & CRYPTO_LOCK) {
495 mutexes[n].lock();
496 } else {
497 mutexes[n].unlock();
498 }
499}
500
#if (OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_NO_THREAD_ID)
/**
 * OpenSSL thread-id callback; only compiled for OpenSSL versions below
 * OPENSSL_VERSION_NO_THREAD_ID.
 */
static unsigned long callbackThreadID() {
  const unsigned long self = (unsigned long) pthread_self();
  return self;
}
#endif
506
507static CRYPTO_dynlock_value* dyn_create(const char*, int) {
508 return new CRYPTO_dynlock_value;
509}
510
511static void dyn_lock(int mode,
512 struct CRYPTO_dynlock_value* lock,
513 const char*, int) {
514 if (lock != NULL) {
515 if (mode & CRYPTO_LOCK) {
516 lock->mutex.lock();
517 } else {
518 lock->mutex.unlock();
519 }
520 }
521}
522
/**
 * OpenSSL dynamic-lock callback: release a lock made by dyn_create().
 */
static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
  // Deleting a NULL pointer is a no-op, so no check is needed.
  delete lock;
}
526
527void TSSLSocketFactory::initializeOpenSSL() {
528 if (initialized) {
529 return;
530 }
531 initialized = true;
532 SSL_library_init();
533 SSL_load_error_strings();
534 // static locking
535 mutexes = shared_array<Mutex>(new Mutex[::CRYPTO_num_locks()]);
536 if (mutexes == NULL) {
537 throw TTransportException(TTransportException::INTERNAL_ERROR,
538 "initializeOpenSSL() failed, "
539 "out of memory while creating mutex array");
540 }
541#if (OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_NO_THREAD_ID)
542 CRYPTO_set_id_callback(callbackThreadID);
543#endif
544 CRYPTO_set_locking_callback(callbackLocking);
545 // dynamic locking
546 CRYPTO_set_dynlock_create_callback(dyn_create);
547 CRYPTO_set_dynlock_lock_callback(dyn_lock);
548 CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
549}
550
551void TSSLSocketFactory::cleanupOpenSSL() {
552 if (!initialized) {
553 return;
554 }
555 initialized = false;
556#if (OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_NO_THREAD_ID)
557 CRYPTO_set_id_callback(NULL);
558#endif
559 CRYPTO_set_locking_callback(NULL);
560 CRYPTO_set_dynlock_create_callback(NULL);
561 CRYPTO_set_dynlock_lock_callback(NULL);
562 CRYPTO_set_dynlock_destroy_callback(NULL);
563 CRYPTO_cleanup_all_ex_data();
564 ERR_free_strings();
565 EVP_cleanup();
566 ERR_remove_state(0);
567 mutexes.reset();
568}
569
570// extract error messages from error queue
571void buildErrors(string& errors, int errno_copy) {
572 unsigned long errorCode;
573 char message[256];
574
575 errors.reserve(512);
576 while ((errorCode = ERR_get_error()) != 0) {
577 if (!errors.empty()) {
578 errors += "; ";
579 }
580 const char* reason = ERR_reason_error_string(errorCode);
581 if (reason == NULL) {
582 snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
583 reason = message;
584 }
585 errors += reason;
586 }
587 if (errors.empty()) {
588 if (errno_copy != 0) {
589 errors += TOutput::strerror_s(errno_copy);
590 }
591 }
592 if (errors.empty()) {
593 errors = "error code: " + lexical_cast<string>(errno_copy);
594 }
595}
596
597/**
598 * Default implementation of AccessManager
599 */
600Decision DefaultClientAccessManager::verify(const sockaddr_storage& sa)
Roger Meiera8cef6e2011-07-17 18:55:59 +0000601 throw() {
602 (void) sa;
603 return SKIP;
604}
Bryan Duxburycd9aea12011-02-22 18:12:06 +0000605
606Decision DefaultClientAccessManager::verify(const string& host,
607 const char* name,
608 int size) throw() {
609 if (host.empty() || name == NULL || size <= 0) {
610 return SKIP;
611 }
612 return (matchName(host.c_str(), name, size) ? ALLOW : SKIP);
613}
614
615Decision DefaultClientAccessManager::verify(const sockaddr_storage& sa,
616 const char* data,
617 int size) throw() {
618 bool match = false;
619 if (sa.ss_family == AF_INET && size == sizeof(in_addr)) {
620 match = (memcmp(&((sockaddr_in*)&sa)->sin_addr, data, size) == 0);
621 } else if (sa.ss_family == AF_INET6 && size == sizeof(in6_addr)) {
622 match = (memcmp(&((sockaddr_in6*)&sa)->sin6_addr, data, size) == 0);
623 }
624 return (match ? ALLOW : SKIP);
625}
626
627/**
628 * Match a name with a pattern. The pattern may include wildcard. A single
629 * wildcard "*" can match up to one component in the domain name.
630 *
631 * @param host Host name, typically the name of the remote host
632 * @param pattern Name retrieved from certificate
633 * @param size Size of "pattern"
634 * @return True, if "host" matches "pattern". False otherwise.
635 */
636bool matchName(const char* host, const char* pattern, int size) {
637 bool match = false;
638 int i = 0, j = 0;
639 while (i < size && host[j] != '\0') {
640 if (uppercase(pattern[i]) == uppercase(host[j])) {
641 i++;
642 j++;
643 continue;
644 }
645 if (pattern[i] == '*') {
646 while (host[j] != '.' && host[j] != '\0') {
647 j++;
648 }
649 i++;
650 continue;
651 }
652 break;
653 }
654 if (i == size && host[j] == '\0') {
655 match = true;
656 }
657 return match;
658
659}
660
// Locale-independent upper-casing of ASCII letters. This works around the
// Turkish locale issue, i.e., toupper('i') != toupper('I') if the locale
// is "tr_TR". Characters outside 'a'..'z' are returned unchanged.
char uppercase (char c) {
  const bool isAsciiLower = (c >= 'a' && c <= 'z');
  return isAsciiLower ? static_cast<char>(c - 'a' + 'A') : c;
}
669
670}}}