blob: 602a8945ae90ae1f30c14b0fafe9a4bf0cad80e9 [file] [log] [blame]
Roger Meier0ae05422011-04-12 19:23:10 +00001/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
Bryan Duxburycd9aea12011-02-22 18:12:06 +000019
20#include <errno.h>
21#include <string>
22#include <arpa/inet.h>
Roger Meier34603932011-03-25 12:22:17 +000023#include <sys/types.h>
24#include <sys/socket.h>
Bryan Duxburycd9aea12011-02-22 18:12:06 +000025#include <boost/lexical_cast.hpp>
26#include <boost/shared_array.hpp>
27#include <openssl/err.h>
28#include <openssl/rand.h>
29#include <openssl/ssl.h>
30#include <openssl/x509v3.h>
31#include "concurrency/Mutex.h"
32#include "TSSLSocket.h"
33
34#define OPENSSL_VERSION_NO_THREAD_ID 0x10000000L
35
36using namespace std;
37using namespace boost;
38using namespace apache::thrift::concurrency;
39
40struct CRYPTO_dynlock_value {
41 Mutex mutex;
42};
43
44namespace apache { namespace thrift { namespace transport {
45
46
47static void buildErrors(string& message, int error = 0);
48static bool matchName(const char* host, const char* pattern, int size);
49static char uppercase(char c);
50
51// SSLContext implementation
52SSLContext::SSLContext() {
53 ctx_ = SSL_CTX_new(TLSv1_method());
54 if (ctx_ == NULL) {
55 string errors;
56 buildErrors(errors);
57 throw TSSLException("SSL_CTX_new: " + errors);
58 }
59 SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
60}
61
62SSLContext::~SSLContext() {
63 if (ctx_ != NULL) {
64 SSL_CTX_free(ctx_);
65 ctx_ = NULL;
66 }
67}
68
69SSL* SSLContext::createSSL() {
70 SSL* ssl = SSL_new(ctx_);
71 if (ssl == NULL) {
72 string errors;
73 buildErrors(errors);
74 throw TSSLException("SSL_new: " + errors);
75 }
76 return ssl;
77}
78
79// TSSLSocket implementation
80TSSLSocket::TSSLSocket(shared_ptr<SSLContext> ctx):
81 TSocket(), server_(false), ssl_(NULL), ctx_(ctx) {
82}
83
84TSSLSocket::TSSLSocket(shared_ptr<SSLContext> ctx, int socket):
85 TSocket(socket), server_(false), ssl_(NULL), ctx_(ctx) {
86}
87
88TSSLSocket::TSSLSocket(shared_ptr<SSLContext> ctx, string host, int port):
89 TSocket(host, port), server_(false), ssl_(NULL), ctx_(ctx) {
90}
91
92TSSLSocket::~TSSLSocket() {
93 close();
94}
95
96bool TSSLSocket::isOpen() {
97 if (ssl_ == NULL || !TSocket::isOpen()) {
98 return false;
99 }
100 int shutdown = SSL_get_shutdown(ssl_);
101 bool shutdownReceived = (shutdown & SSL_RECEIVED_SHUTDOWN);
102 bool shutdownSent = (shutdown & SSL_SENT_SHUTDOWN);
103 if (shutdownReceived && shutdownSent) {
104 return false;
105 }
106 return true;
107}
108
109bool TSSLSocket::peek() {
110 if (!isOpen()) {
111 return false;
112 }
113 checkHandshake();
114 int rc;
115 uint8_t byte;
116 rc = SSL_peek(ssl_, &byte, 1);
117 if (rc < 0) {
118 int errno_copy = errno;
119 string errors;
120 buildErrors(errors, errno_copy);
121 throw TSSLException("SSL_peek: " + errors);
122 }
123 if (rc == 0) {
124 ERR_clear_error();
125 }
126 return (rc > 0);
127}
128
129void TSSLSocket::open() {
130 if (isOpen() || server()) {
131 throw TTransportException(TTransportException::BAD_ARGS);
132 }
133 TSocket::open();
134}
135
136void TSSLSocket::close() {
137 if (ssl_ != NULL) {
138 int rc = SSL_shutdown(ssl_);
139 if (rc == 0) {
140 rc = SSL_shutdown(ssl_);
141 }
142 if (rc < 0) {
143 int errno_copy = errno;
144 string errors;
145 buildErrors(errors, errno_copy);
146 GlobalOutput(("SSL_shutdown: " + errors).c_str());
147 }
148 SSL_free(ssl_);
149 ssl_ = NULL;
150 ERR_remove_state(0);
151 }
152 TSocket::close();
153}
154
155uint32_t TSSLSocket::read(uint8_t* buf, uint32_t len) {
156 checkHandshake();
157 int32_t bytes = 0;
158 for (int32_t retries = 0; retries < maxRecvRetries_; retries++){
159 bytes = SSL_read(ssl_, buf, len);
160 if (bytes >= 0)
161 break;
162 int errno_copy = errno;
163 if (SSL_get_error(ssl_, bytes) == SSL_ERROR_SYSCALL) {
164 if (ERR_get_error() == 0 && errno_copy == EINTR) {
165 continue;
166 }
167 }
168 string errors;
169 buildErrors(errors, errno_copy);
170 throw TSSLException("SSL_read: " + errors);
171 }
172 return bytes;
173}
174
175void TSSLSocket::write(const uint8_t* buf, uint32_t len) {
176 checkHandshake();
177 // loop in case SSL_MODE_ENABLE_PARTIAL_WRITE is set in SSL_CTX.
178 uint32_t written = 0;
179 while (written < len) {
180 int32_t bytes = SSL_write(ssl_, &buf[written], len - written);
181 if (bytes <= 0) {
182 int errno_copy = errno;
183 string errors;
184 buildErrors(errors, errno_copy);
185 throw TSSLException("SSL_write: " + errors);
186 }
187 written += bytes;
188 }
189}
190
191void TSSLSocket::flush() {
192 // Don't throw exception if not open. Thrift servers close socket twice.
193 if (ssl_ == NULL) {
194 return;
195 }
196 checkHandshake();
197 BIO* bio = SSL_get_wbio(ssl_);
198 if (bio == NULL) {
199 throw TSSLException("SSL_get_wbio returns NULL");
200 }
201 if (BIO_flush(bio) != 1) {
202 int errno_copy = errno;
203 string errors;
204 buildErrors(errors, errno_copy);
205 throw TSSLException("BIO_flush: " + errors);
206 }
207}
208
209void TSSLSocket::checkHandshake() {
210 if (!TSocket::isOpen()) {
211 throw TTransportException(TTransportException::NOT_OPEN);
212 }
213 if (ssl_ != NULL) {
214 return;
215 }
216 ssl_ = ctx_->createSSL();
217 SSL_set_fd(ssl_, socket_);
218 int rc;
219 if (server()) {
220 rc = SSL_accept(ssl_);
221 } else {
222 rc = SSL_connect(ssl_);
223 }
224 if (rc <= 0) {
225 int errno_copy = errno;
226 string fname(server() ? "SSL_accept" : "SSL_connect");
227 string errors;
228 buildErrors(errors, errno_copy);
229 throw TSSLException(fname + ": " + errors);
230 }
231 authorize();
232}
233
234void TSSLSocket::authorize() {
235 int rc = SSL_get_verify_result(ssl_);
236 if (rc != X509_V_OK) { // verify authentication result
237 throw TSSLException(string("SSL_get_verify_result(), ") +
238 X509_verify_cert_error_string(rc));
239 }
240
241 X509* cert = SSL_get_peer_certificate(ssl_);
242 if (cert == NULL) {
243 // certificate is not present
244 if (SSL_get_verify_mode(ssl_) & SSL_VERIFY_FAIL_IF_NO_PEER_CERT) {
245 throw TSSLException("authorize: required certificate not present");
246 }
247 // certificate was optional: didn't intend to authorize remote
248 if (server() && access_ != NULL) {
249 throw TSSLException("authorize: certificate required for authorization");
250 }
251 return;
252 }
253 // certificate is present
254 if (access_ == NULL) {
255 X509_free(cert);
256 return;
257 }
258 // both certificate and access manager are present
259
260 string host;
261 sockaddr_storage sa = {};
262 socklen_t saLength = sizeof(sa);
263
264 if (getpeername(socket_, (sockaddr*)&sa, &saLength) != 0) {
265 sa.ss_family = AF_UNSPEC;
266 }
267
268 AccessManager::Decision decision = access_->verify(sa);
269
270 if (decision != AccessManager::SKIP) {
271 X509_free(cert);
272 if (decision != AccessManager::ALLOW) {
273 throw TSSLException("authorize: access denied based on remote IP");
274 }
275 return;
276 }
277
278 // extract subjectAlternativeName
279 STACK_OF(GENERAL_NAME)* alternatives = (STACK_OF(GENERAL_NAME)*)
280 X509_get_ext_d2i(cert, NID_subject_alt_name, NULL, NULL);
281 if (alternatives != NULL) {
282 const int count = sk_GENERAL_NAME_num(alternatives);
283 for (int i = 0; decision == AccessManager::SKIP && i < count; i++) {
284 const GENERAL_NAME* name = sk_GENERAL_NAME_value(alternatives, i);
285 if (name == NULL) {
286 continue;
287 }
288 char* data = (char*)ASN1_STRING_data(name->d.ia5);
289 int length = ASN1_STRING_length(name->d.ia5);
290 switch (name->type) {
291 case GEN_DNS:
292 if (host.empty()) {
293 host = (server() ? getPeerHost() : getHost());
294 }
295 decision = access_->verify(host, data, length);
296 break;
297 case GEN_IPADD:
298 decision = access_->verify(sa, data, length);
299 break;
300 }
301 }
302 sk_GENERAL_NAME_pop_free(alternatives, GENERAL_NAME_free);
303 }
304
305 if (decision != AccessManager::SKIP) {
306 X509_free(cert);
307 if (decision != AccessManager::ALLOW) {
308 throw TSSLException("authorize: access denied");
309 }
310 return;
311 }
312
313 // extract commonName
314 X509_NAME* name = X509_get_subject_name(cert);
315 if (name != NULL) {
316 X509_NAME_ENTRY* entry;
317 unsigned char* utf8;
318 int last = -1;
319 while (decision == AccessManager::SKIP) {
320 last = X509_NAME_get_index_by_NID(name, NID_commonName, last);
321 if (last == -1)
322 break;
323 entry = X509_NAME_get_entry(name, last);
324 if (entry == NULL)
325 continue;
326 ASN1_STRING* common = X509_NAME_ENTRY_get_data(entry);
327 int size = ASN1_STRING_to_UTF8(&utf8, common);
328 if (host.empty()) {
329 host = (server() ? getHost() : getHost());
330 }
331 decision = access_->verify(host, (char*)utf8, size);
332 OPENSSL_free(utf8);
333 }
334 }
335 X509_free(cert);
336 if (decision != AccessManager::ALLOW) {
337 throw TSSLException("authorize: cannot authorize peer");
338 }
339}
340
341// TSSLSocketFactory implementation
342bool TSSLSocketFactory::initialized = false;
343uint64_t TSSLSocketFactory::count_ = 0;
344Mutex TSSLSocketFactory::mutex_;
345
346TSSLSocketFactory::TSSLSocketFactory(): server_(false) {
347 Guard guard(mutex_);
348 if (count_ == 0) {
349 initializeOpenSSL();
350 randomize();
351 }
352 count_++;
353 ctx_ = shared_ptr<SSLContext>(new SSLContext);
354}
355
356TSSLSocketFactory::~TSSLSocketFactory() {
357 Guard guard(mutex_);
358 count_--;
359 if (count_ == 0) {
360 cleanupOpenSSL();
361 }
362}
363
364shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket() {
365 shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_));
366 setup(ssl);
367 return ssl;
368}
369
370shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(int socket) {
371 shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, socket));
372 setup(ssl);
373 return ssl;
374}
375
376shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(const string& host,
377 int port) {
378 shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, host, port));
379 setup(ssl);
380 return ssl;
381}
382
383void TSSLSocketFactory::setup(shared_ptr<TSSLSocket> ssl) {
384 ssl->server(server());
385 if (access_ == NULL && !server()) {
386 access_ = shared_ptr<AccessManager>(new DefaultClientAccessManager);
387 }
388 if (access_ != NULL) {
389 ssl->access(access_);
390 }
391}
392
393void TSSLSocketFactory::ciphers(const string& enable) {
394 int rc = SSL_CTX_set_cipher_list(ctx_->get(), enable.c_str());
395 if (ERR_peek_error() != 0) {
396 string errors;
397 buildErrors(errors);
398 throw TSSLException("SSL_CTX_set_cipher_list: " + errors);
399 }
400 if (rc == 0) {
401 throw TSSLException("None of specified ciphers are supported");
402 }
403}
404
405void TSSLSocketFactory::authenticate(bool required) {
406 int mode;
407 if (required) {
408 mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
409 } else {
410 mode = SSL_VERIFY_NONE;
411 }
412 SSL_CTX_set_verify(ctx_->get(), mode, NULL);
413}
414
415void TSSLSocketFactory::loadCertificate(const char* path, const char* format) {
416 if (path == NULL || format == NULL) {
417 throw TTransportException(TTransportException::BAD_ARGS,
418 "loadCertificateChain: either <path> or <format> is NULL");
419 }
420 if (strcmp(format, "PEM") == 0) {
421 if (SSL_CTX_use_certificate_chain_file(ctx_->get(), path) == 0) {
422 int errno_copy = errno;
423 string errors;
424 buildErrors(errors, errno_copy);
425 throw TSSLException("SSL_CTX_use_certificate_chain_file: " + errors);
426 }
427 } else {
428 throw TSSLException("Unsupported certificate format: " + string(format));
429 }
430}
431
432void TSSLSocketFactory::loadPrivateKey(const char* path, const char* format) {
433 if (path == NULL || format == NULL) {
434 throw TTransportException(TTransportException::BAD_ARGS,
435 "loadPrivateKey: either <path> or <format> is NULL");
436 }
437 if (strcmp(format, "PEM") == 0) {
438 if (SSL_CTX_use_PrivateKey_file(ctx_->get(), path, SSL_FILETYPE_PEM) == 0) {
439 int errno_copy = errno;
440 string errors;
441 buildErrors(errors, errno_copy);
442 throw TSSLException("SSL_CTX_use_PrivateKey_file: " + errors);
443 }
444 }
445}
446
447void TSSLSocketFactory::loadTrustedCertificates(const char* path) {
448 if (path == NULL) {
449 throw TTransportException(TTransportException::BAD_ARGS,
450 "loadTrustedCertificates: <path> is NULL");
451 }
452 if (SSL_CTX_load_verify_locations(ctx_->get(), path, NULL) == 0) {
453 int errno_copy = errno;
454 string errors;
455 buildErrors(errors, errno_copy);
456 throw TSSLException("SSL_CTX_load_verify_locations: " + errors);
457 }
458}
459
460void TSSLSocketFactory::randomize() {
461 RAND_poll();
462}
463
464void TSSLSocketFactory::overrideDefaultPasswordCallback() {
465 SSL_CTX_set_default_passwd_cb(ctx_->get(), passwordCallback);
466 SSL_CTX_set_default_passwd_cb_userdata(ctx_->get(), this);
467}
468
469int TSSLSocketFactory::passwordCallback(char* password,
470 int size,
471 int,
472 void* data) {
473 TSSLSocketFactory* factory = (TSSLSocketFactory*)data;
474 string userPassword;
475 factory->getPassword(userPassword, size);
476 int length = userPassword.size();
477 if (length > size) {
478 length = size;
479 }
480 strncpy(password, userPassword.c_str(), length);
481 return length;
482}
483
484static shared_array<Mutex> mutexes;
485
486static void callbackLocking(int mode, int n, const char*, int) {
487 if (mode & CRYPTO_LOCK) {
488 mutexes[n].lock();
489 } else {
490 mutexes[n].unlock();
491 }
492}
493
#if (OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_NO_THREAD_ID)
/**
 * Thread-id callback, compiled only for OpenSSL versions below
 * OPENSSL_VERSION_NO_THREAD_ID: returns pthread_self() as unsigned long.
 */
static unsigned long callbackThreadID() {
  const pthread_t self = pthread_self();
  return (unsigned long) self;
}
#endif
499
500static CRYPTO_dynlock_value* dyn_create(const char*, int) {
501 return new CRYPTO_dynlock_value;
502}
503
504static void dyn_lock(int mode,
505 struct CRYPTO_dynlock_value* lock,
506 const char*, int) {
507 if (lock != NULL) {
508 if (mode & CRYPTO_LOCK) {
509 lock->mutex.lock();
510 } else {
511 lock->mutex.unlock();
512 }
513 }
514}
515
/**
 * Dynamic-lock destruction callback: deletes the given lock object.
 */
static void dyn_destroy(struct CRYPTO_dynlock_value* lock,
                        const char*,
                        int) {
  delete lock;
}
519
520void TSSLSocketFactory::initializeOpenSSL() {
521 if (initialized) {
522 return;
523 }
524 initialized = true;
525 SSL_library_init();
526 SSL_load_error_strings();
527 // static locking
528 mutexes = shared_array<Mutex>(new Mutex[::CRYPTO_num_locks()]);
529 if (mutexes == NULL) {
530 throw TTransportException(TTransportException::INTERNAL_ERROR,
531 "initializeOpenSSL() failed, "
532 "out of memory while creating mutex array");
533 }
534#if (OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_NO_THREAD_ID)
535 CRYPTO_set_id_callback(callbackThreadID);
536#endif
537 CRYPTO_set_locking_callback(callbackLocking);
538 // dynamic locking
539 CRYPTO_set_dynlock_create_callback(dyn_create);
540 CRYPTO_set_dynlock_lock_callback(dyn_lock);
541 CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
542}
543
544void TSSLSocketFactory::cleanupOpenSSL() {
545 if (!initialized) {
546 return;
547 }
548 initialized = false;
549#if (OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_NO_THREAD_ID)
550 CRYPTO_set_id_callback(NULL);
551#endif
552 CRYPTO_set_locking_callback(NULL);
553 CRYPTO_set_dynlock_create_callback(NULL);
554 CRYPTO_set_dynlock_lock_callback(NULL);
555 CRYPTO_set_dynlock_destroy_callback(NULL);
556 CRYPTO_cleanup_all_ex_data();
557 ERR_free_strings();
558 EVP_cleanup();
559 ERR_remove_state(0);
560 mutexes.reset();
561}
562
563// extract error messages from error queue
564void buildErrors(string& errors, int errno_copy) {
565 unsigned long errorCode;
566 char message[256];
567
568 errors.reserve(512);
569 while ((errorCode = ERR_get_error()) != 0) {
570 if (!errors.empty()) {
571 errors += "; ";
572 }
573 const char* reason = ERR_reason_error_string(errorCode);
574 if (reason == NULL) {
575 snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
576 reason = message;
577 }
578 errors += reason;
579 }
580 if (errors.empty()) {
581 if (errno_copy != 0) {
582 errors += TOutput::strerror_s(errno_copy);
583 }
584 }
585 if (errors.empty()) {
586 errors = "error code: " + lexical_cast<string>(errno_copy);
587 }
588}
589
590/**
591 * Default implementation of AccessManager
592 */
593Decision DefaultClientAccessManager::verify(const sockaddr_storage& sa)
594 throw() { return SKIP; }
595
596Decision DefaultClientAccessManager::verify(const string& host,
597 const char* name,
598 int size) throw() {
599 if (host.empty() || name == NULL || size <= 0) {
600 return SKIP;
601 }
602 return (matchName(host.c_str(), name, size) ? ALLOW : SKIP);
603}
604
605Decision DefaultClientAccessManager::verify(const sockaddr_storage& sa,
606 const char* data,
607 int size) throw() {
608 bool match = false;
609 if (sa.ss_family == AF_INET && size == sizeof(in_addr)) {
610 match = (memcmp(&((sockaddr_in*)&sa)->sin_addr, data, size) == 0);
611 } else if (sa.ss_family == AF_INET6 && size == sizeof(in6_addr)) {
612 match = (memcmp(&((sockaddr_in6*)&sa)->sin6_addr, data, size) == 0);
613 }
614 return (match ? ALLOW : SKIP);
615}
616
// This is to work around the Turkish locale issue, i.e.,
// toupper('i') != toupper('I') if locale is "tr_TR"
char uppercase (char c) {
  const bool isLower = ('a' <= c && c <= 'z');
  return isLower ? static_cast<char>(c + ('A' - 'a')) : c;
}

/**
 * Match a name with a pattern. The pattern may include wildcard. A single
 * wildcard "*" can match up to one component in the domain name.
 *
 * Characters are compared case-insensitively for 'a'..'z' only. When a "*"
 * in the pattern does not equal the current host character, the rest of the
 * current host component (up to the next '.' or the end) is skipped.
 *
 * @param  host    Host name, typically the name of the remote host
 * @param  pattern Name retrieved from certificate
 * @param  size    Size of "pattern"
 * @return True, if "host" matches "pattern". False otherwise.
 */
bool matchName(const char* host, const char* pattern, int size) {
  int p = 0;  // index into pattern
  int q = 0;  // index into host
  while (p < size && host[q] != '\0') {
    if (uppercase(pattern[p]) == uppercase(host[q])) {
      ++p;
      ++q;
    } else if (pattern[p] == '*') {
      while (host[q] != '.' && host[q] != '\0') {
        ++q;
      }
      ++p;
    } else {
      break;
    }
  }
  // a match requires both strings to be fully consumed
  return p == size && host[q] == '\0';
}
659
660}}}