blob: 6b241b2113ebab0ff9611bb0395960823e99e80d [file] [log] [blame]
Roger Meier0ae05422011-04-12 19:23:10 +00001/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
Bryan Duxburycd9aea12011-02-22 18:12:06 +000019
20#ifndef _THRIFT_TRANSPORT_TSSLSOCKET_H_
21#define _THRIFT_TRANSPORT_TSSLSOCKET_H_ 1
22
23#include <string>
24#include <boost/shared_ptr.hpp>
25#include <openssl/ssl.h>
26#include "concurrency/Mutex.h"
27#include "TSocket.h"
28
29namespace apache { namespace thrift { namespace transport {
30
31class AccessManager;
32class SSLContext;
33
34/**
35 * OpenSSL implementation for SSL socket interface.
Bryan Duxburycd9aea12011-02-22 18:12:06 +000036 */
37class TSSLSocket: public TSocket {
38 public:
39 ~TSSLSocket();
40 /**
41 * TTransport interface.
42 */
43 bool isOpen();
44 bool peek();
45 void open();
46 void close();
47 uint32_t read(uint8_t* buf, uint32_t len);
48 void write(const uint8_t* buf, uint32_t len);
49 void flush();
50 /**
51 * Set whether to use client or server side SSL handshake protocol.
52 *
53 * @param flag Use server side handshake protocol if true.
54 */
55 void server(bool flag) { server_ = flag; }
56 /**
57 * Determine whether the SSL socket is server or client mode.
58 */
59 bool server() const { return server_; }
60 /**
61 * Set AccessManager.
62 *
63 * @param manager Instance of AccessManager
64 */
65 virtual void access(boost::shared_ptr<AccessManager> manager) {
66 access_ = manager;
67 }
68protected:
69 /**
70 * Constructor.
71 */
72 TSSLSocket(boost::shared_ptr<SSLContext> ctx);
73 /**
74 * Constructor, create an instance of TSSLSocket given an existing socket.
75 *
76 * @param socket An existing socket
77 */
78 TSSLSocket(boost::shared_ptr<SSLContext> ctx, int socket);
79 /**
80 * Constructor.
81 *
82 * @param host Remote host name
83 * @param port Remote port number
84 */
85 TSSLSocket(boost::shared_ptr<SSLContext> ctx,
86 std::string host,
87 int port);
88 /**
89 * Authorize peer access after SSL handshake completes.
90 */
91 virtual void authorize();
92 /**
93 * Initiate SSL handshake if not already initiated.
94 */
95 void checkHandshake();
96
97 bool server_;
98 SSL* ssl_;
99 boost::shared_ptr<SSLContext> ctx_;
100 boost::shared_ptr<AccessManager> access_;
101 friend class TSSLSocketFactory;
102};
103
104/**
105 * SSL socket factory. SSL sockets should be created via SSL factory.
106 */
107class TSSLSocketFactory {
108 public:
109 /**
110 * Constructor/Destructor
111 */
112 TSSLSocketFactory();
113 virtual ~TSSLSocketFactory();
114 /**
115 * Create an instance of TSSLSocket with a fresh new socket.
116 */
117 virtual boost::shared_ptr<TSSLSocket> createSocket();
118 /**
119 * Create an instance of TSSLSocket with the given socket.
120 *
121 * @param socket An existing socket.
122 */
123 virtual boost::shared_ptr<TSSLSocket> createSocket(int socket);
124 /**
125 * Create an instance of TSSLSocket.
126 *
127 * @param host Remote host to be connected to
128 * @param port Remote port to be connected to
129 */
130 virtual boost::shared_ptr<TSSLSocket> createSocket(const std::string& host,
131 int port);
132 /**
133 * Set ciphers to be used in SSL handshake process.
134 *
135 * @param ciphers A list of ciphers
136 */
137 virtual void ciphers(const std::string& enable);
138 /**
139 * Enable/Disable authentication.
140 *
141 * @param required Require peer to present valid certificate if true
142 */
143 virtual void authenticate(bool required);
144 /**
145 * Load server certificate.
146 *
147 * @param path Path to the certificate file
148 * @param format Certificate file format
149 */
150 virtual void loadCertificate(const char* path, const char* format = "PEM");
151 /**
152 * Load private key.
153 *
154 * @param path Path to the private key file
155 * @param format Private key file format
156 */
157 virtual void loadPrivateKey(const char* path, const char* format = "PEM");
158 /**
159 * Load trusted certificates from specified file.
160 *
161 * @param path Path to trusted certificate file
162 */
163 virtual void loadTrustedCertificates(const char* path);
164 /**
165 * Default randomize method.
166 */
167 virtual void randomize();
168 /**
169 * Override default OpenSSL password callback with getPassword().
170 */
171 void overrideDefaultPasswordCallback();
172 /**
173 * Set/Unset server mode.
174 *
175 * @param flag Server mode if true
176 */
177 virtual void server(bool flag) { server_ = flag; }
178 /**
179 * Determine whether the socket is in server or client mode.
180 *
181 * @return true, if server mode, or, false, if client mode
182 */
183 virtual bool server() const { return server_; }
184 /**
185 * Set AccessManager.
186 *
187 * @param manager The AccessManager instance
188 */
189 virtual void access(boost::shared_ptr<AccessManager> manager) {
190 access_ = manager;
191 }
192 protected:
193 boost::shared_ptr<SSLContext> ctx_;
194
195 static void initializeOpenSSL();
196 static void cleanupOpenSSL();
197 /**
198 * Override this method for custom password callback. It may be called
199 * multiple times at any time during a session as necessary.
200 *
201 * @param password Pass collected password to OpenSSL
202 * @param size Maximum length of password including NULL character
203 */
204 virtual void getPassword(std::string& password, int size) { }
205 private:
206 bool server_;
207 boost::shared_ptr<AccessManager> access_;
208 static bool initialized;
209 static concurrency::Mutex mutex_;
210 static uint64_t count_;
211 void setup(boost::shared_ptr<TSSLSocket> ssl);
212 static int passwordCallback(char* password, int size, int, void* data);
213};
214
215/**
216 * SSL exception.
217 */
218class TSSLException: public TTransportException {
219 public:
220 TSSLException(const std::string& message):
221 TTransportException(TTransportException::INTERNAL_ERROR, message) {}
222
223 virtual const char* what() const throw() {
224 if (message_.empty()) {
225 return "TSSLException";
226 } else {
227 return message_.c_str();
228 }
229 }
230};
231
232/**
233 * Wrap OpenSSL SSL_CTX into a class.
234 */
235class SSLContext {
236 public:
237 SSLContext();
238 virtual ~SSLContext();
239 SSL* createSSL();
240 SSL_CTX* get() { return ctx_; }
241 private:
242 SSL_CTX* ctx_;
243};
244
/**
 * Callback interface for access control. It's meant to verify the remote host.
 * It's constructed when application starts and set to TSSLSocketFactory
 * instance. It's passed onto all TSSLSocket instances created by this factory
 * object.
 *
 * Every default implementation below returns DENY; subclasses override the
 * checks they care about.
 */
class AccessManager {
 public:
  enum Decision {
    DENY = -1,    // deny access
    SKIP = 0,     // cannot make decision, move on to next (if any)
    ALLOW = 1     // allow access (no trailing comma: ill-formed in C++03)
  };
  /**
   * Destructor
   */
  virtual ~AccessManager() {}
  /**
   * Determine whether the peer should be granted access or not. It's called
   * once after the SSL handshake completes successfully, before peer certificate
   * is examined.
   *
   * If a valid decision (ALLOW or DENY) is returned, the peer certificate is
   * not to be verified.
   *
   * @param sa Peer IP address
   * @return ALLOW to grant access, DENY to refuse it, or SKIP to defer to the
   *         certificate checks. Default: DENY.
   */
  virtual Decision verify(const sockaddr_storage& sa) throw() { return DENY; }
  /**
   * Determine whether the peer should be granted access or not. It's called
   * every time a DNS subjectAltName/common name is extracted from peer's
   * certificate.
   *
   * @param host Client mode: host name returned by TSocket::getHost()
   *             Server mode: host name returned by TSocket::getPeerHost()
   * @param name SubjectAltName or common name extracted from peer certificate
   * @param size Length of name
   * @return ALLOW to grant access, DENY to refuse it, or SKIP when this name
   *         is not decisive. Default: DENY.
   *
   * Note: The "name" parameter may be UTF8 encoded.
   */
  virtual Decision verify(const std::string& host, const char* name, int size)
    throw() { return DENY; }
  /**
   * Determine whether the peer should be granted access or not. It's called
   * every time an IP subjectAltName is extracted from peer's certificate.
   *
   * @param sa   Peer IP address retrieved from the underlying socket
   * @param data IP address extracted from certificate
   * @param size Length of the IP address
   * @return ALLOW to grant access, DENY to refuse it, or SKIP when this
   *         address is not decisive. Default: DENY.
   */
  virtual Decision verify(const sockaddr_storage& sa, const char* data, int size)
    throw() { return DENY; }
};
301
302typedef AccessManager::Decision Decision;
303
304class DefaultClientAccessManager: public AccessManager {
305 public:
306 // AccessManager interface
307 Decision verify(const sockaddr_storage& sa) throw();
308 Decision verify(const std::string& host, const char* name, int size) throw();
309 Decision verify(const sockaddr_storage& sa, const char* data, int size) throw();
310};
311
312
313}}}
314
315#endif