THRIFT-5185: Support for using WebSockets as a server transport
Client: cpp
diff --git a/lib/cpp/CMakeLists.txt b/lib/cpp/CMakeLists.txt
index d705bb3..29caad4 100755
--- a/lib/cpp/CMakeLists.txt
+++ b/lib/cpp/CMakeLists.txt
@@ -57,6 +57,8 @@
src/thrift/transport/TServerSocket.cpp
src/thrift/transport/TTransportUtils.cpp
src/thrift/transport/TBufferTransports.cpp
+ src/thrift/transport/TWebSocketServer.h
+ src/thrift/transport/TWebSocketServer.cpp
src/thrift/server/TConnectedClient.cpp
src/thrift/server/TServerFramework.cpp
src/thrift/server/TSimpleServer.cpp
diff --git a/lib/cpp/Makefile.am b/lib/cpp/Makefile.am
index 9b5fb4c..f699f83 100755
--- a/lib/cpp/Makefile.am
+++ b/lib/cpp/Makefile.am
@@ -87,6 +87,7 @@
src/thrift/transport/TNonblockingSSLServerSocket.cpp \
src/thrift/transport/TTransportUtils.cpp \
src/thrift/transport/TBufferTransports.cpp \
+ src/thrift/transport/TWebSocketServer.cpp \
src/thrift/server/TConnectedClient.cpp \
src/thrift/server/TServer.cpp \
src/thrift/server/TServerFramework.cpp \
@@ -140,7 +141,8 @@
src/thrift/TApplicationException.h \
src/thrift/TLogging.h \
src/thrift/TToString.h \
- src/thrift/TBase.h
+ src/thrift/TBase.h \
+ src/thrift/portable_endian.h
include_concurrencydir = $(include_thriftdir)/concurrency
include_concurrency_HEADERS = \
@@ -198,7 +200,8 @@
src/thrift/transport/TTransportUtils.h \
src/thrift/transport/TBufferTransports.h \
src/thrift/transport/TShortReadTransport.h \
- src/thrift/transport/TZlibTransport.h
+ src/thrift/transport/TZlibTransport.h \
+ src/thrift/transport/TWebSocketServer.h
include_serverdir = $(include_thriftdir)/server
include_server_HEADERS = \
diff --git a/lib/cpp/src/thrift/portable_endian.h b/lib/cpp/src/thrift/portable_endian.h
new file mode 100644
index 0000000..e07010e
--- /dev/null
+++ b/lib/cpp/src/thrift/portable_endian.h
@@ -0,0 +1,140 @@
+//
+// endian.h
+//
+// https://gist.github.com/panzi/6856583
+//
+// I, Mathias Panzenböck, place this file hereby into the public domain. Use
+// it at your own risk for whatever you like. In case there are
+// jurisdictions that don't support putting things in the public domain you
+// can also consider it to be "dual licensed" under the BSD, MIT and Apache
+// licenses, if you want to. This code is trivial anyway. Consider it an
+// example on how to get the endian conversion functions on different
+// platforms.
+
+// Makes the BSD/glibc byte order conversion macros (htobe16/htole16/be16toh/
+// le16toh and their 32/64-bit variants) plus __BYTE_ORDER, __BIG_ENDIAN,
+// __LITTLE_ENDIAN and __PDP_ENDIAN available on the platforms handled below;
+// any other platform stops the build with #error.
+
+#ifndef PORTABLE_ENDIAN_H__
+#define PORTABLE_ENDIAN_H__
+
+#if (defined(_WIN16) || defined(_WIN32) || defined(_WIN64)) && !defined(__WINDOWS__)
+
+#  define __WINDOWS__
+
+#endif
+
+#if defined(__linux__) || defined(__CYGWIN__)
+
+#  include <endian.h>
+
+#elif defined(__APPLE__)
+
+#  include <libkern/OSByteOrder.h>
+
+#  define htobe16(x) OSSwapHostToBigInt16(x)
+#  define htole16(x) OSSwapHostToLittleInt16(x)
+#  define be16toh(x) OSSwapBigToHostInt16(x)
+#  define le16toh(x) OSSwapLittleToHostInt16(x)
+
+#  define htobe32(x) OSSwapHostToBigInt32(x)
+#  define htole32(x) OSSwapHostToLittleInt32(x)
+#  define be32toh(x) OSSwapBigToHostInt32(x)
+#  define le32toh(x) OSSwapLittleToHostInt32(x)
+
+#  define htobe64(x) OSSwapHostToBigInt64(x)
+#  define htole64(x) OSSwapHostToLittleInt64(x)
+#  define be64toh(x) OSSwapBigToHostInt64(x)
+#  define le64toh(x) OSSwapLittleToHostInt64(x)
+
+#  define __BYTE_ORDER    BYTE_ORDER
+#  define __BIG_ENDIAN    BIG_ENDIAN
+#  define __LITTLE_ENDIAN LITTLE_ENDIAN
+#  define __PDP_ENDIAN    PDP_ENDIAN
+
+#elif defined(__OpenBSD__)
+
+#  include <sys/endian.h>
+
+// Older OpenBSD releases only spell the decoding direction betoh*/letoh*;
+// map the standard names onto them only when they are missing.
+#  ifndef be16toh
+#    define be16toh(x) betoh16(x)
+#    define le16toh(x) letoh16(x)
+#    define be32toh(x) betoh32(x)
+#    define le32toh(x) letoh32(x)
+#    define be64toh(x) betoh64(x)
+#    define le64toh(x) letoh64(x)
+#  endif
+
+#elif defined(__NetBSD__) || defined(__FreeBSD__) || defined(__DragonFly__)
+
+// <sys/endian.h> already provides htobe*/be*toh natively here; unlike OpenBSD
+// there are no betoh*/letoh* spellings, so nothing may be remapped.
+#  include <sys/endian.h>
+
+#elif defined(__WINDOWS__)
+
+#  include <winsock2.h>
+
+#  if BYTE_ORDER == LITTLE_ENDIAN
+
+#    define htobe16(x) htons(x)
+#    define htole16(x) (x)
+#    define be16toh(x) ntohs(x)
+#    define le16toh(x) (x)
+
+#    define htobe32(x) htonl(x)
+#    define htole32(x) (x)
+#    define be32toh(x) ntohl(x)
+#    define le32toh(x) (x)
+
+#    if defined(__MINGW32__)
+#      define htobe64(x) __builtin_bswap64(x)
+#      define htole64(x) (x)
+#      define be64toh(x) __builtin_bswap64(x)
+#      define le64toh(x) (x)
+#    else
+#      define htobe64(x) htonll(x)
+#      define htole64(x) (x)
+#      define be64toh(x) ntohll(x)
+#      define le64toh(x) (x)
+#    endif
+
+#  elif BYTE_ORDER == BIG_ENDIAN
+
+    /* that would be xbox 360 */
+#    define htobe16(x) (x)
+#    define htole16(x) __builtin_bswap16(x)
+#    define be16toh(x) (x)
+#    define le16toh(x) __builtin_bswap16(x)
+
+#    define htobe32(x) (x)
+#    define htole32(x) __builtin_bswap32(x)
+#    define be32toh(x) (x)
+#    define le32toh(x) __builtin_bswap32(x)
+
+#    define htobe64(x) (x)
+#    define htole64(x) __builtin_bswap64(x)
+#    define be64toh(x) (x)
+#    define le64toh(x) __builtin_bswap64(x)
+
+#  else
+
+#    error byte order not supported
+
+#  endif
+
+#  define __BYTE_ORDER    BYTE_ORDER
+#  define __BIG_ENDIAN    BIG_ENDIAN
+#  define __LITTLE_ENDIAN LITTLE_ENDIAN
+#  define __PDP_ENDIAN    PDP_ENDIAN
+
+#else
+
+#  error platform not supported
+
+#endif
+
+#endif
diff --git a/lib/cpp/src/thrift/transport/THttpServer.cpp b/lib/cpp/src/thrift/transport/THttpServer.cpp
index 94ac681..98518fd 100644
--- a/lib/cpp/src/thrift/transport/THttpServer.cpp
+++ b/lib/cpp/src/thrift/transport/THttpServer.cpp
@@ -124,12 +124,7 @@
writeBuffer_.getBuffer(&buf, &len);
// Construct the HTTP header
- std::ostringstream h;
- h << "HTTP/1.1 200 OK" << CRLF << "Date: " << getTimeRFC1123() << CRLF << "Server: Thrift/"
- << PACKAGE_VERSION << CRLF << "Access-Control-Allow-Origin: *" << CRLF
- << "Content-Type: application/x-thrift" << CRLF << "Content-Length: " << len << CRLF
- << "Connection: Keep-Alive" << CRLF << CRLF;
- string header = h.str();
+ string header = getHeader(len);
// Write the header, then the data, then flush
// cast should be fine, because none of "header" is under attacker control
@@ -142,6 +137,15 @@
readHeaders_ = true;
}
+std::string THttpServer::getHeader(uint32_t len) {
+  std::ostringstream out; // "200 OK" response header announcing a len-byte Thrift body
+  out << "HTTP/1.1 200 OK" << CRLF;
+  out << "Date: " << getTimeRFC1123() << CRLF << "Server: Thrift/" << PACKAGE_VERSION << CRLF;
+  out << "Access-Control-Allow-Origin: *" << CRLF << "Content-Type: application/x-thrift" << CRLF;
+  out << "Content-Length: " << len << CRLF << "Connection: Keep-Alive" << CRLF << CRLF;
+  return out.str();
+}
+
std::string THttpServer::getTimeRFC1123() {
static const char* Days[] = {"Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"};
static const char* Months[]
diff --git a/lib/cpp/src/thrift/transport/THttpServer.h b/lib/cpp/src/thrift/transport/THttpServer.h
index 0e83399..d219691 100644
--- a/lib/cpp/src/thrift/transport/THttpServer.h
+++ b/lib/cpp/src/thrift/transport/THttpServer.h
@@ -35,6 +35,7 @@
void flush() override;
protected:
+ virtual std::string getHeader(uint32_t len);
void readHeaders();
void parseHeader(char* header) override;
bool parseStatusLine(char* status) override;
diff --git a/lib/cpp/src/thrift/transport/TWebSocketServer.cpp b/lib/cpp/src/thrift/transport/TWebSocketServer.cpp
new file mode 100644
index 0000000..9822a7f
--- /dev/null
+++ b/lib/cpp/src/thrift/transport/TWebSocketServer.cpp
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <functional>
+#include <memory>
+#include <string>
+
+#include <openssl/bio.h>
+#include <openssl/evp.h>
+
+#include <thrift/Thrift.h>
+
+using std::string;
+
+namespace apache {
+namespace thrift {
+namespace transport {
+
+/**
+ * Base64-encodes `length` bytes starting at `data` through OpenSSL's base64
+ * filter BIO, without inserting line breaks (BIO_FLAGS_BASE64_NO_NL).
+ *
+ * TWebSocketServer uses it to turn the SHA-1 digest of the client key into
+ * the Sec-WebSocket-Accept handshake value.
+ *
+ * @param data   bytes to encode (only read, despite the non-const pointer).
+ * @param length number of bytes to encode; zero, negative or a null `data`
+ *               yields an empty string.
+ * @return the base64 text.
+ * @throws TException if OpenSSL cannot allocate the BIOs or fails to encode.
+ */
+std::string base64Encode(unsigned char* data, int length) {
+  if (data == nullptr || length <= 0) {
+    return std::string();
+  }
+
+  // Owns the whole BIO chain: BIO_free_all also releases the memory BIO
+  // pushed below the base64 filter.
+  std::unique_ptr<BIO, void (*)(BIO*)> base64(BIO_new(BIO_f_base64()), &BIO_free_all);
+  if (!base64) {
+    throw TException("base64Encode: cannot allocate the OpenSSL base64 BIO");
+  }
+  BIO_set_flags(base64.get(), BIO_FLAGS_BASE64_NO_NL);
+
+  BIO* dest = BIO_new(BIO_s_mem());
+  if (dest == nullptr) {
+    throw TException("base64Encode: cannot allocate the OpenSSL memory BIO");
+  }
+  BIO_push(base64.get(), dest);
+
+  if (BIO_write(base64.get(), data, length) != length || BIO_flush(base64.get()) <= 0) {
+    throw TException("base64Encode: OpenSSL failed to encode the data");
+  }
+
+  char* encoded = nullptr;
+  const long encodedLength = BIO_get_mem_data(dest, &encoded);
+  if (encoded == nullptr || encodedLength <= 0) {
+    return std::string();
+  }
+  return std::string(encoded, static_cast<size_t>(encodedLength));
+}
+} // namespace transport
+} // namespace thrift
+} // namespace apache
diff --git a/lib/cpp/src/thrift/transport/TWebSocketServer.h b/lib/cpp/src/thrift/transport/TWebSocketServer.h
new file mode 100644
index 0000000..8edc286
--- /dev/null
+++ b/lib/cpp/src/thrift/transport/TWebSocketServer.h
@@ -0,0 +1,548 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_TRANSPORT_TWEBSOCKETSERVER_H_
+#define _THRIFT_TRANSPORT_TWEBSOCKETSERVER_H_ 1
+
+#include <thrift/portable_endian.h>
+
+#include <algorithm>
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+#include <iostream>
+#include <memory>
+#include <sstream>
+#include <string>
+
+#include <openssl/sha.h>
+
+#include <thrift/config.h>
+#include <thrift/transport/TSocket.h>
+#include <thrift/transport/THttpServer.h>
+#if defined(_MSC_VER) || defined(__MINGW32__)
+#include <Shlwapi.h>
+#define THRIFT_strncasecmp(str1, str2, len) _strnicmp(str1, str2, len)
+#define THRIFT_strcasestr(haystack, needle) StrStrIA(haystack, needle)
+#else
+#include <strings.h>
+#define THRIFT_strncasecmp(str1, str2, len) strncasecmp(str1, str2, len)
+#define THRIFT_strcasestr(haystack, needle) strcasestr(haystack, needle)
+#endif
+#if defined(__CYGWIN__)
+#include <alloca.h>
+#endif
+
+using std::string;
+
+namespace apache {
+namespace thrift {
+namespace transport {
+
+/**
+ * Base64-encodes `length` bytes starting at `data` (defined in
+ * TWebSocketServer.cpp); used to build the Sec-WebSocket-Accept value.
+ */
+std::string base64Encode(unsigned char* data, int length);
+
+/**
+ * Server side of a WebSocket (RFC 6455) connection, layered on THttpServer.
+ *
+ * The first read performs the opening handshake: THttpServer reads the HTTP
+ * request, the parseStatusLine()/parseHeader() overrides below record what it
+ * contained, and the client gets "101 Switching Protocols" - or "400 Bad
+ * Request" followed by closing the connection when a required header is
+ * missing. After that, reads consume client frames and every flush() sends
+ * the pending output as one unfragmented frame.
+ *
+ * @tparam binary true to send Binary (0x2) data frames, false for Text (0x1).
+ */
+template <bool binary>
+class TWebSocketServer : public THttpServer {
+public:
+  TWebSocketServer(std::shared_ptr<TTransport> transport) : THttpServer(transport) {
+    resetHandshake();
+  }
+
+  ~TWebSocketServer() override = default;
+
+  /**
+   * Reads exactly `len` bytes of WebSocket payload into `buf`, performing the
+   * opening handshake first if it has not completed yet.
+   *
+   * Data frames are appended to readBuffer_ until enough bytes are buffered,
+   * so a value split across frames is reassembled instead of being returned
+   * short, as readAll() callers expect. Control frames are handled in
+   * readFrame().
+   *
+   * @return `len`, or 0 when the handshake was rejected.
+   * @throws TTransportException END_OF_FILE when the client sends Close (or,
+   *         via the underlying readAll(), the connection ends) before `len`
+   *         bytes arrive; CORRUPTED_DATA when the framing rules are violated.
+   */
+  uint32_t readAll_virt(uint8_t* buf, uint32_t len) override {
+    // If we do not have a good handshake, the client will attempt one.
+    if (!handshakeComplete()) {
+      resetHandshake();
+      THttpServer::read(buf, len);
+      // If we did not get everything we expected, the handshake failed
+      // and we need to send a 400 response back.
+      if (!handshakeComplete()) {
+        sendBadRequest();
+        return 0;
+      }
+      // Otherwise, send back the 101 response.
+      THttpServer::flush();
+      // Nothing buffered during the HTTP exchange is WebSocket payload.
+      readBuffer_.resetBuffer();
+    }
+
+    if (len == 0) {
+      return 0;
+    }
+
+    // Keep pulling frames until the request can be satisfied in full. Blocking
+    // here is intended: the caller cannot make progress without these bytes.
+    while (readBuffer_.available_read() < len) {
+      if (!readFrame()) {
+        throw TTransportException(TTransportException::END_OF_FILE,
+                                  "WebSocket connection closed by the client");
+      }
+    }
+    return readBuffer_.read(buf, len);
+  }
+
+  /**
+   * Sends everything written since the last flush as one FIN data frame
+   * (Binary or Text depending on `binary`); server frames are not masked.
+   */
+  void flush() override {
+    uint8_t* buffer;
+    uint32_t length;
+    writeBuffer_.getBuffer(&buffer, &length);
+    writeFrameHeader(binary ? Opcode::Binary : Opcode::Text, length);
+    transport_->write(buffer, length);
+    transport_->flush();
+    writeBuffer_.resetBuffer();
+  }
+
+protected:
+  /**
+   * Builds the "101 Switching Protocols" response completing the handshake;
+   * `len` is ignored because this response has no body.
+   */
+  std::string getHeader(uint32_t len) override {
+    THRIFT_UNUSED_VARIABLE(len);
+    std::ostringstream h;
+    h << "HTTP/1.1 101 Switching Protocols" << CRLF << "Server: Thrift/" << PACKAGE_VERSION << CRLF
+      << "Upgrade: websocket" << CRLF << "Connection: Upgrade" << CRLF
+      << "Sec-WebSocket-Accept: " << acceptKey_ << CRLF << CRLF;
+    return h.str();
+  }
+
+  /**
+   * Records which handshake headers the request carried. Header names must
+   * match in full (case-insensitively); the Sec-WebSocket-Key value is trimmed
+   * of surrounding whitespace before the accept key is derived from it.
+   */
+  void parseHeader(char* header) override {
+    char* colon = strchr(header, ':');
+    if (colon == nullptr) {
+      return;
+    }
+    const auto nameLength = static_cast<size_t>(colon - header);
+    char* value = colon + 1;
+
+    if (isHeader(header, nameLength, "Upgrade")) {
+      if (THRIFT_strcasestr(value, "websocket") != nullptr) {
+        upgrade_ = true;
+      }
+    } else if (isHeader(header, nameLength, "Connection")) {
+      if (THRIFT_strcasestr(value, "Upgrade") != nullptr) {
+        connection_ = true;
+      }
+    } else if (isHeader(header, nameLength, "Sec-WebSocket-Key")) {
+      std::string toHash = trimWhitespace(value);
+      if (!toHash.empty()) {
+        // Sec-WebSocket-Accept = base64(SHA-1(key + fixed GUID)), RFC 6455 4.2.2.
+        toHash += "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
+        unsigned char hash[SHA_DIGEST_LENGTH];
+        SHA1(reinterpret_cast<const unsigned char*>(toHash.c_str()), toHash.length(), hash);
+        acceptKey_ = base64Encode(hash, SHA_DIGEST_LENGTH);
+        secWebSocketKey_ = true;
+      }
+    } else if (isHeader(header, nameLength, "Sec-WebSocket-Version")) {
+      if (THRIFT_strcasestr(value, "13") != nullptr) {
+        secWebSocketVersion_ = true;
+      }
+    }
+  }
+
+  /**
+   * Accepts only a GET request line; any other method, or a line lacking the
+   * method/path/version separators, throws TTransportException.
+   */
+  bool parseStatusLine(char* status) override {
+    char* method = status;
+
+    char* path = strchr(method, ' ');
+    if (path == nullptr) {
+      throw TTransportException(string("Bad Status: ") + status);
+    }
+
+    *path = '\0';
+    while (*(++path) == ' ') {
+    }
+
+    char* http = strchr(path, ' ');
+    if (http == nullptr) {
+      throw TTransportException(string("Bad Status: ") + status);
+    }
+    *http = '\0';
+
+    if (strcmp(method, "GET") == 0) {
+      // GET method ok, looking for content.
+      return true;
+    }
+    throw TTransportException(string("Bad Status (unsupported method): ") + status);
+  }
+
+private:
+  enum class CloseCode : uint16_t {
+    NormalClosure = 1000,
+    GoingAway = 1001,
+    ProtocolError = 1002,
+    UnsupportedDataType = 1003,
+    NoStatusCode = 1005,
+    AbnormalClosure = 1006,
+    InvalidData = 1007,
+    PolicyViolation = 1008,
+    MessageTooBig = 1009,
+    ExtensionExpected = 1010,
+    UnexpectedError = 1011,
+    NotSecure = 1015
+  };
+
+  enum class Opcode : uint8_t {
+    Continuation = 0x0,
+    Text = 0x1,
+    Binary = 0x2,
+    Close = 0x8,
+    Ping = 0x9,
+    Pong = 0xA
+  };
+
+  /** Largest payload a control frame may carry (RFC 6455 section 5.5). */
+  static constexpr uint32_t kMaxControlPayload = 125;
+
+  /**
+   * Sends a Close frame carrying `reason`, then closes the underlying
+   * transport. Every caller throws TTransportException right afterwards.
+   */
+  void failConnection(CloseCode reason) {
+    const auto code = static_cast<uint16_t>(reason);
+    uint8_t payload[2];
+    storeBigEndian(payload, code, 2);
+    writeControlFrame(Opcode::Close, payload, 2);
+    transport_->close();
+  }
+
+  bool handshakeComplete() {
+    return upgrade_ && connection_ && secWebSocketKey_ && secWebSocketVersion_;
+  }
+
+  /** Writes one unmasked control frame carrying `length` payload bytes and flushes it. */
+  void writeControlFrame(Opcode opcode, const uint8_t* payload, uint32_t length) {
+    writeFrameHeader(opcode, length);
+    if (length > 0) {
+      transport_->write(payload, length);
+    }
+    transport_->flush();
+  }
+
+  /**
+   * Reads one frame from the client.
+   *
+   * Data frames (Continuation, Text, Binary) are unmasked and appended to
+   * readBuffer_; their FIN bit is irrelevant because Thrift messages carry
+   * their own boundaries. A Ping is answered with a Pong echoing its payload,
+   * a Pong is ignored, and a Close is answered with a Close frame (best
+   * effort) before the underlying transport is closed.
+   *
+   * @return false once the client has closed the session, true otherwise.
+   * @throws TTransportException from the underlying readAll() if the
+   *         transport ends mid-frame; CORRUPTED_DATA if the frame violates
+   *         RFC 6455, after the connection has been failed with a Close frame.
+   */
+  bool readFrame() {
+    uint8_t header[8];
+    transport_->readAll(header, 2);
+
+    const bool fin = (header[0] & 0x80) != 0;
+    const auto opcode = static_cast<Opcode>(header[0] & 0x0F);
+    // Opcodes 0x8-0xF denote control frames.
+    const bool isControl = (header[0] & 0x08) != 0;
+
+    // RSV1, RSV2, RSV3: no extension is negotiated, so they must be clear.
+    if ((header[0] & 0x70) != 0) {
+      failConnection(CloseCode::ProtocolError);
+      throw TTransportException(TTransportException::CORRUPTED_DATA,
+                                "Reserved bits must be zeroes");
+    }
+
+    if (!isKnownOpcode(opcode)) {
+      failConnection(CloseCode::ProtocolError);
+      throw TTransportException(TTransportException::CORRUPTED_DATA, "Unknown frame opcode");
+    }
+
+    // Mask
+    if ((header[1] & 0x80) == 0) {
+      failConnection(CloseCode::ProtocolError);
+      throw TTransportException(TTransportException::CORRUPTED_DATA,
+                                "Messages from the client must be masked");
+    }
+
+    // Payload length: 7 bits, or a 16/64-bit big-endian extended length.
+    uint64_t payloadLength = header[1] & 0x7F;
+    if (payloadLength == 126) {
+      transport_->readAll(header, 2);
+      payloadLength = loadBigEndian(header, 2);
+    } else if (payloadLength == 127) {
+      transport_->readAll(header, 8);
+      payloadLength = loadBigEndian(header, 8);
+      if ((payloadLength & 0x8000000000000000ULL) != 0) {
+        failConnection(CloseCode::ProtocolError);
+        throw TTransportException(
+            TTransportException::CORRUPTED_DATA,
+            "The most significant bit of the payload length must be zero");
+      }
+    }
+
+    if (isControl && (!fin || payloadLength > kMaxControlPayload)) {
+      failConnection(CloseCode::ProtocolError);
+      throw TTransportException(TTransportException::CORRUPTED_DATA,
+                                "Control frames must be unfragmented and at most 125 bytes");
+    }
+
+    // Lengths are handled as uint32_t from here on.
+    // NOTE(review): below that, a data frame is bounded only by available
+    // memory; consider enforcing a configurable maximum message size.
+    if (payloadLength > UINT32_MAX) {
+      failConnection(CloseCode::MessageTooBig);
+      throw TTransportException(TTransportException::CORRUPTED_DATA,
+                                "Frame payload is too large");
+    }
+    const auto length = static_cast<uint32_t>(payloadLength);
+
+    // The masking key is present whenever the mask bit is set - even for an
+    // empty payload - so it must always be consumed to stay in sync.
+    uint8_t mask[4];
+    transport_->readAll(mask, 4);
+
+    if (!isControl) {
+      // Reclaim buffer space once everything buffered so far was consumed.
+      if (readBuffer_.available_read() == 0) {
+        readBuffer_.resetBuffer();
+      }
+      if (length > 0) {
+        uint8_t* dest = readBuffer_.getWritePtr(length);
+        transport_->readAll(dest, length);
+        unmask(dest, length, mask);
+        readBuffer_.wroteBytes(length);
+      }
+      T_DEBUG("FIN=%d, Opcode=%X, length=%u", fin ? 1 : 0, static_cast<unsigned>(opcode), length);
+      return true;
+    }
+
+    // Control frame: its small payload is kept out of readBuffer_ so it never
+    // mixes with application data.
+    uint8_t payload[kMaxControlPayload];
+    if (length > 0) {
+      transport_->readAll(payload, length);
+      unmask(payload, length, mask);
+    }
+
+    switch (opcode) {
+    case Opcode::Close: {
+      if (length >= 2) {
+        const auto closeCode = static_cast<uint16_t>(loadBigEndian(payload, 2));
+        const std::string closeReason(reinterpret_cast<const char*>(payload + 2), length - 2);
+        THRIFT_UNUSED_VARIABLE(closeCode);
+        THRIFT_UNUSED_VARIABLE(closeReason);
+        T_DEBUG("Connection closed: %d %s", static_cast<int>(closeCode), closeReason.c_str());
+      }
+      try {
+        // Answer with a Close frame echoing the status code, if any.
+        writeControlFrame(Opcode::Close, payload, length >= 2 ? 2 : 0);
+      } catch (const TTransportException&) {
+        // Best effort: the client may already have dropped the connection.
+      }
+      transport_->close();
+      return false;
+    }
+    case Opcode::Ping:
+      writeControlFrame(Opcode::Pong, payload, length);
+      return true;
+    default:
+      // Unsolicited Pong, e.g. a client heartbeat: nothing to do.
+      return true;
+    }
+  }
+
+  /** True for the six opcodes RFC 6455 defines. */
+  static bool isKnownOpcode(Opcode opcode) {
+    switch (opcode) {
+    case Opcode::Continuation:
+    case Opcode::Text:
+    case Opcode::Binary:
+    case Opcode::Close:
+    case Opcode::Ping:
+    case Opcode::Pong:
+      return true;
+    default:
+      return false;
+    }
+  }
+
+  /** Decodes `size` (at most 8) big-endian bytes into an unsigned integer. */
+  static uint64_t loadBigEndian(const uint8_t* bytes, uint32_t size) {
+    uint64_t value = 0;
+    for (uint32_t i = 0; i < size; ++i) {
+      value = (value << 8) | bytes[i];
+    }
+    return value;
+  }
+
+  /** Encodes the low `size` (at most 8) bytes of `value` big-endian at `out`; returns `size`. */
+  static uint32_t storeBigEndian(uint8_t* out, uint64_t value, uint32_t size) {
+    for (uint32_t i = 0; i < size; ++i) {
+      out[size - 1 - i] = static_cast<uint8_t>(value >> (8 * i));
+    }
+    return size;
+  }
+
+  /** XORs `length` bytes of `data` in place with the 4-byte client masking key. */
+  static void unmask(uint8_t* data, uint32_t length, const uint8_t* mask) {
+    for (uint32_t i = 0; i < length; ++i) {
+      data[i] ^= mask[i % 4];
+    }
+  }
+
+  /** True when the `nameLength`-char header name at `name` equals `expected`, ignoring case. */
+  static bool isHeader(const char* name, size_t nameLength, const char* expected) {
+    return nameLength == strlen(expected) && THRIFT_strncasecmp(name, expected, nameLength) == 0;
+  }
+
+  /** Returns `value` without leading and trailing spaces, tabs, CRs and LFs. */
+  static std::string trimWhitespace(const char* value) {
+    const std::string text(value);
+    const auto first = text.find_first_not_of(" \t\r\n");
+    if (first == std::string::npos) {
+      return std::string();
+    }
+    const auto last = text.find_last_not_of(" \t\r\n");
+    return text.substr(first, last - first + 1);
+  }
+
+  void resetHandshake() {
+    connection_ = false;
+    secWebSocketKey_ = false;
+    secWebSocketVersion_ = false;
+    upgrade_ = false;
+  }
+
+  /** Answers a failed handshake with "400 Bad Request" and closes the connection. */
+  void sendBadRequest() {
+    std::ostringstream h;
+    h << "HTTP/1.1 400 Bad Request" << CRLF << "Server: Thrift/" << PACKAGE_VERSION << CRLF << CRLF;
+    std::string header = h.str();
+    transport_->write(reinterpret_cast<const uint8_t*>(header.data()),
+                      static_cast<uint32_t>(header.length()));
+    transport_->flush();
+    transport_->close();
+  }
+
+  /**
+   * Writes a server frame header: FIN set, `opcode`, no mask, and `length`
+   * in the shortest of the 7-bit, 16-bit and 64-bit encodings.
+   */
+  void writeFrameHeader(Opcode opcode, uint32_t length) {
+    uint8_t header[10];
+    uint32_t headerSize = 2;
+    header[0] = static_cast<uint8_t>(static_cast<uint8_t>(opcode) | 0x80);
+    if (length < 126) {
+      header[1] = static_cast<uint8_t>(length);
+    } else if (length < 65536) {
+      header[1] = 126;
+      headerSize += storeBigEndian(header + 2, length, 2);
+    } else {
+      header[1] = 127;
+      headerSize += storeBigEndian(header + 2, length, 8);
+    }
+    transport_->write(header, headerSize);
+  }
+
+  // Add constant here to avoid a linker error on Windows
+  constexpr static const char* CRLF = "\r\n";
+  std::string acceptKey_;
+  bool connection_;
+  bool secWebSocketKey_;
+  bool secWebSocketVersion_;
+  bool upgrade_;
+};
+
+/**
+ * Transport factory wrapping each accepted transport in a TWebSocketServer
+ * that sends Binary frames.
+ */
+class TBinaryWebSocketServerTransportFactory : public TTransportFactory {
+public:
+  TBinaryWebSocketServerTransportFactory() = default;
+
+  ~TBinaryWebSocketServerTransportFactory() override = default;
+
+  /**
+   * Wraps the transport into a binary WebSocket server transport.
+   */
+  std::shared_ptr<TTransport> getTransport(std::shared_ptr<TTransport> trans) override {
+    return std::make_shared<TWebSocketServer<true>>(trans);
+  }
+};
+
+/**
+ * Transport factory wrapping each accepted transport in a TWebSocketServer
+ * that sends Text frames.
+ */
+class TTextWebSocketServerTransportFactory : public TTransportFactory {
+public:
+  TTextWebSocketServerTransportFactory() = default;
+
+  ~TTextWebSocketServerTransportFactory() override = default;
+
+  /**
+   * Wraps the transport into a text WebSocket server transport.
+   */
+  std::shared_ptr<TTransport> getTransport(std::shared_ptr<TTransport> trans) override {
+    return std::make_shared<TWebSocketServer<false>>(trans);
+  }
+};
+} // namespace transport
+} // namespace thrift
+} // namespace apache
+#endif