blob: 58e39345df0b12d396af584ea442c3a0ecc33122 [file] [log] [blame]
Bryan Duxburycd9aea12011-02-22 18:12:06 +00001// Copyright (c) 2009- Facebook
2// Distributed under the Thrift Software License
3//
4// See accompanying file LICENSE or visit the Thrift site at:
5// http://developers.facebook.com/thrift/
6
7#ifndef _THRIFT_TRANSPORT_TSSLSOCKET_H_
8#define _THRIFT_TRANSPORT_TSSLSOCKET_H_ 1
9
10#include <string>
11#include <boost/shared_ptr.hpp>
12#include <openssl/ssl.h>
13#include "concurrency/Mutex.h"
14#include "TSocket.h"
15
16namespace apache { namespace thrift { namespace transport {
17
18class AccessManager;
19class SSLContext;
20
21/**
22 * OpenSSL implementation for SSL socket interface.
23 *
24 * @author Ping Li <pingli@facebook.com>
25 */
26class TSSLSocket: public TSocket {
27 public:
28 ~TSSLSocket();
29 /**
30 * TTransport interface.
31 */
32 bool isOpen();
33 bool peek();
34 void open();
35 void close();
36 uint32_t read(uint8_t* buf, uint32_t len);
37 void write(const uint8_t* buf, uint32_t len);
38 void flush();
39 /**
40 * Set whether to use client or server side SSL handshake protocol.
41 *
42 * @param flag Use server side handshake protocol if true.
43 */
44 void server(bool flag) { server_ = flag; }
45 /**
46 * Determine whether the SSL socket is server or client mode.
47 */
48 bool server() const { return server_; }
49 /**
50 * Set AccessManager.
51 *
52 * @param manager Instance of AccessManager
53 */
54 virtual void access(boost::shared_ptr<AccessManager> manager) {
55 access_ = manager;
56 }
57protected:
58 /**
59 * Constructor.
60 */
61 TSSLSocket(boost::shared_ptr<SSLContext> ctx);
62 /**
63 * Constructor, create an instance of TSSLSocket given an existing socket.
64 *
65 * @param socket An existing socket
66 */
67 TSSLSocket(boost::shared_ptr<SSLContext> ctx, int socket);
68 /**
69 * Constructor.
70 *
71 * @param host Remote host name
72 * @param port Remote port number
73 */
74 TSSLSocket(boost::shared_ptr<SSLContext> ctx,
75 std::string host,
76 int port);
77 /**
78 * Authorize peer access after SSL handshake completes.
79 */
80 virtual void authorize();
81 /**
82 * Initiate SSL handshake if not already initiated.
83 */
84 void checkHandshake();
85
86 bool server_;
87 SSL* ssl_;
88 boost::shared_ptr<SSLContext> ctx_;
89 boost::shared_ptr<AccessManager> access_;
90 friend class TSSLSocketFactory;
91};
92
93/**
94 * SSL socket factory. SSL sockets should be created via SSL factory.
95 */
96class TSSLSocketFactory {
97 public:
98 /**
99 * Constructor/Destructor
100 */
101 TSSLSocketFactory();
102 virtual ~TSSLSocketFactory();
103 /**
104 * Create an instance of TSSLSocket with a fresh new socket.
105 */
106 virtual boost::shared_ptr<TSSLSocket> createSocket();
107 /**
108 * Create an instance of TSSLSocket with the given socket.
109 *
110 * @param socket An existing socket.
111 */
112 virtual boost::shared_ptr<TSSLSocket> createSocket(int socket);
113 /**
114 * Create an instance of TSSLSocket.
115 *
116 * @param host Remote host to be connected to
117 * @param port Remote port to be connected to
118 */
119 virtual boost::shared_ptr<TSSLSocket> createSocket(const std::string& host,
120 int port);
121 /**
122 * Set ciphers to be used in SSL handshake process.
123 *
124 * @param ciphers A list of ciphers
125 */
126 virtual void ciphers(const std::string& enable);
127 /**
128 * Enable/Disable authentication.
129 *
130 * @param required Require peer to present valid certificate if true
131 */
132 virtual void authenticate(bool required);
133 /**
134 * Load server certificate.
135 *
136 * @param path Path to the certificate file
137 * @param format Certificate file format
138 */
139 virtual void loadCertificate(const char* path, const char* format = "PEM");
140 /**
141 * Load private key.
142 *
143 * @param path Path to the private key file
144 * @param format Private key file format
145 */
146 virtual void loadPrivateKey(const char* path, const char* format = "PEM");
147 /**
148 * Load trusted certificates from specified file.
149 *
150 * @param path Path to trusted certificate file
151 */
152 virtual void loadTrustedCertificates(const char* path);
153 /**
154 * Default randomize method.
155 */
156 virtual void randomize();
157 /**
158 * Override default OpenSSL password callback with getPassword().
159 */
160 void overrideDefaultPasswordCallback();
161 /**
162 * Set/Unset server mode.
163 *
164 * @param flag Server mode if true
165 */
166 virtual void server(bool flag) { server_ = flag; }
167 /**
168 * Determine whether the socket is in server or client mode.
169 *
170 * @return true, if server mode, or, false, if client mode
171 */
172 virtual bool server() const { return server_; }
173 /**
174 * Set AccessManager.
175 *
176 * @param manager The AccessManager instance
177 */
178 virtual void access(boost::shared_ptr<AccessManager> manager) {
179 access_ = manager;
180 }
181 protected:
182 boost::shared_ptr<SSLContext> ctx_;
183
184 static void initializeOpenSSL();
185 static void cleanupOpenSSL();
186 /**
187 * Override this method for custom password callback. It may be called
188 * multiple times at any time during a session as necessary.
189 *
190 * @param password Pass collected password to OpenSSL
191 * @param size Maximum length of password including NULL character
192 */
193 virtual void getPassword(std::string& password, int size) { }
194 private:
195 bool server_;
196 boost::shared_ptr<AccessManager> access_;
197 static bool initialized;
198 static concurrency::Mutex mutex_;
199 static uint64_t count_;
200 void setup(boost::shared_ptr<TSSLSocket> ssl);
201 static int passwordCallback(char* password, int size, int, void* data);
202};
203
204/**
205 * SSL exception.
206 */
207class TSSLException: public TTransportException {
208 public:
209 TSSLException(const std::string& message):
210 TTransportException(TTransportException::INTERNAL_ERROR, message) {}
211
212 virtual const char* what() const throw() {
213 if (message_.empty()) {
214 return "TSSLException";
215 } else {
216 return message_.c_str();
217 }
218 }
219};
220
221/**
222 * Wrap OpenSSL SSL_CTX into a class.
223 */
224class SSLContext {
225 public:
226 SSLContext();
227 virtual ~SSLContext();
228 SSL* createSSL();
229 SSL_CTX* get() { return ctx_; }
230 private:
231 SSL_CTX* ctx_;
232};
233
/**
 * Callback interface for access control. It's meant to verify the remote host.
 * It's constructed when application starts and set to TSSLSocketFactory
 * instance. It's passed onto all TSSLSocket instances created by this factory
 * object.
 *
 * Every verify() overload in this base class returns DENY; subclasses
 * override the checks they want to perform.
 */
class AccessManager {
 public:
  // No trailing comma after the last enumerator: that is ill-formed in C++03.
  enum Decision {
    DENY = -1,   // deny access
    SKIP = 0,    // cannot make decision, move on to next (if any)
    ALLOW = 1    // allow access
  };
  /**
   * Destructor
   */
  virtual ~AccessManager() {}
  /**
   * Determine whether the peer should be granted access or not. It's called
   * once after the SSL handshake completes successfully, before peer certificate
   * is examined.
   *
   * If a valid decision (ALLOW or DENY) is returned, the peer certificate is
   * not to be verified.
   *
   * @param sa Peer IP address
   * @return ALLOW or DENY to decide, SKIP to defer; this default returns DENY
   */
  virtual Decision verify(const sockaddr_storage& /* sa */) throw() {
    return DENY;
  }
  /**
   * Determine whether the peer should be granted access or not. It's called
   * every time a DNS subjectAltName/common name is extracted from peer's
   * certificate.
   *
   * @param host Client mode: host name returned by TSocket::getHost()
   *             Server mode: host name returned by TSocket::getPeerHost()
   * @param name SubjectAltName or common name extracted from peer certificate
   * @param size Length of name
   * @return ALLOW or DENY to decide, SKIP to move on to the next name (if
   *         any); this default returns DENY
   *
   * Note: The "name" parameter may be UTF8 encoded.
   */
  virtual Decision verify(const std::string& /* host */,
                          const char* /* name */,
                          int /* size */) throw() {
    return DENY;
  }
  /**
   * Determine whether the peer should be granted access or not. It's called
   * every time an IP subjectAltName is extracted from peer's certificate.
   *
   * @param sa Peer IP address retrieved from the underlying socket
   * @param data IP address extracted from certificate
   * @param size Length of the IP address
   * @return ALLOW or DENY to decide, SKIP to move on to the next address (if
   *         any); this default returns DENY
   */
  virtual Decision verify(const sockaddr_storage& /* sa */,
                          const char* /* data */,
                          int /* size */) throw() {
    return DENY;
  }
};
290
291typedef AccessManager::Decision Decision;
292
293class DefaultClientAccessManager: public AccessManager {
294 public:
295 // AccessManager interface
296 Decision verify(const sockaddr_storage& sa) throw();
297 Decision verify(const std::string& host, const char* name, int size) throw();
298 Decision verify(const sockaddr_storage& sa, const char* data, int size) throw();
299};
300
301
302}}}
303
304#endif