Rev 2 of Thrift, the Pillar successor
Summary: End-to-end communications and serialization in C++ is working
Reviewed By: aditya
Test Plan: See the new top-level test/ folder. It vaguely resembles a unit test, though it could be more automated.
Revert Plan: Revertible
Notes: Still a LOT of optimization work to be done on the generated C++ code, which should be using dynamic memory in a number of places. Next major task is writing the PHP/Java/Python generators.
git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@664712 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/cpp/Makefile b/lib/cpp/Makefile
new file mode 100644
index 0000000..2045dba
--- /dev/null
+++ b/lib/cpp/Makefile
@@ -0,0 +1,42 @@
+# Makefile for Thrift C++ library.
+#
+# Author:
+# Mark Slee <mcslee@facebook.com>
+
+target: libthrift
+
+# Tools
+LD = g++
+LDFL = -shared -Wall -I. -fPIC -Wl,-soname=libthrift.so
+
+# Source files
+SRCS = client/TSimpleClient.cc \
+       protocol/TBinaryProtocol.cc \
+       server/TSimpleServer.cc \
+       transport/TSocket.cc \
+       transport/TServerSocket.cc
+
+# Headers, listed so that editing any of them also triggers a relink
+HDRS = Thrift.h TDispatcher.h \
+       client/TClient.h client/TSimpleClient.h \
+       protocol/TProtocol.h protocol/TBinaryProtocol.h \
+       server/TServer.h server/TSimpleServer.h \
+       transport/TTransport.h transport/TServerTransport.h \
+       transport/TSocket.h transport/TServerSocket.h
+
+# These targets name actions rather than files
+.PHONY: target libthrift clean install
+
+# Linked library. libthrift is an alias for the real file, so make only
+# relinks when a source or header is newer than libthrift.so
+libthrift: libthrift.so
+
+libthrift.so: $(SRCS) $(HDRS)
+	$(LD) -o libthrift.so $(LDFL) $(SRCS)
+
+clean:
+	rm -f libthrift.so
+
+# Install
+install: libthrift
+	sudo install libthrift.so /usr/local/lib
diff --git a/lib/cpp/TDispatcher.h b/lib/cpp/TDispatcher.h
new file mode 100644
index 0000000..f8ff847
--- /dev/null
+++ b/lib/cpp/TDispatcher.h
@@ -0,0 +1,29 @@
+#ifndef T_DISPATCHER_H
+#define T_DISPATCHER_H
+
+#include <string>
+
+/**
+ * Generic request/response interface: a dispatcher takes one input buffer
+ * and produces one output buffer. The same interface serves both sides of
+ * a connection, e.g. a client that ships the input over the network and
+ * hands back the reply, or a server-side handler that turns a request into
+ * a response.
+ *
+ * @author Mark Slee <mcslee@facebook.com>
+ */
+class TDispatcher {
+ public:
+  virtual ~TDispatcher() {}
+
+  /**
+   * Processes one input buffer and returns the resulting output buffer.
+   */
+  virtual std::string dispatch(const std::string& s) = 0;
+
+ protected:
+  // Abstract interface: only subclasses may construct it.
+  TDispatcher() {}
+};
+
+#endif // T_DISPATCHER_H
diff --git a/lib/cpp/Thrift.h b/lib/cpp/Thrift.h
new file mode 100644
index 0000000..04fbaa1
--- /dev/null
+++ b/lib/cpp/Thrift.h
@@ -0,0 +1,10 @@
+#ifndef THRIFT_H
+#define THRIFT_H
+
+#include <sys/types.h>
+#include <string>
+#include <map>
+#include <list>
+#include <set>
+
+#endif
diff --git a/lib/cpp/client/TClient.h b/lib/cpp/client/TClient.h
new file mode 100644
index 0000000..73dd093
--- /dev/null
+++ b/lib/cpp/client/TClient.h
@@ -0,0 +1,25 @@
+#ifndef T_CLIENT_H
+#define T_CLIENT_H
+
+#include "TDispatcher.h"
+
+/**
+ * Abstract client: a TDispatcher that also exposes open() and close() so
+ * callers can manage the underlying connection explicitly.
+ */
+class TClient : public TDispatcher {
+ public:
+  virtual ~TClient() {}
+
+  /** Opens the connection; returns false on failure. */
+  virtual bool open() = 0;
+
+  /** Closes the connection. */
+  virtual void close() = 0;
+
+ protected:
+  TClient() {}
+};
+
+#endif // T_CLIENT_H
+
diff --git a/lib/cpp/client/TSimpleClient.cc b/lib/cpp/client/TSimpleClient.cc
new file mode 100644
index 0000000..9069c91
--- /dev/null
+++ b/lib/cpp/client/TSimpleClient.cc
@@ -0,0 +1,57 @@
+#include "TSimpleClient.h"
+
+#include <stdint.h>
+#include <string.h>
+
+using std::string;
+
+TSimpleClient::TSimpleClient(TTransport* transport) :
+  transport_(transport) {}
+
+bool TSimpleClient::open() {
+  return transport_->open();
+}
+
+void TSimpleClient::close() {
+  transport_->close();
+}
+
+/**
+ * Sends one request and waits for its response. The wire format in both
+ * directions is a 4-byte size header (host byte order) followed by that many
+ * bytes of payload.
+ *
+ * @param s Request payload
+ * @return The response payload, or "" if the response could not be read or
+ *         the server signalled an error with a negative size
+ */
+std::string TSimpleClient::dispatch(const string& s) {
+  // Write size header
+  int32_t size = s.size();
+  transport_->write(string((const char*)&size, 4));
+
+  // Write data payload
+  transport_->write(s);
+
+  // Read response size. Bail out before touching the header bytes if the
+  // transport did not deliver all four of them.
+  string response;
+  if (transport_->read(response, 4) < 4 || response.size() < 4) {
+    return "";
+  }
+  // memcpy rather than a pointer cast: string data carries no alignment
+  // guarantee for int32_t
+  memcpy(&size, response.data(), 4);
+
+  // Read response data
+  if (size < 0) {
+    // TODO(mcslee): Handle exception
+    return "";
+  }
+  if (transport_->read(response, size) < size) {
+    // Short read (e.g. the peer closed the connection)
+    return "";
+  }
+  return response;
+}
+
diff --git a/lib/cpp/client/TSimpleClient.h b/lib/cpp/client/TSimpleClient.h
new file mode 100644
index 0000000..249afe5
--- /dev/null
+++ b/lib/cpp/client/TSimpleClient.h
@@ -0,0 +1,29 @@
+#ifndef T_SIMPLE_CLIENT_H
+#define T_SIMPLE_CLIENT_H
+
+#include "client/TClient.h"
+#include "transport/TTransport.h"
+
+/**
+ * Client that frames each request with a 4-byte size header, sends it over
+ * a single TTransport, and reads back a size-framed response.
+ *
+ * The transport is borrowed, not owned: it is never deleted here and must
+ * outlive this client.
+ */
+class TSimpleClient : public TClient {
+ public:
+  TSimpleClient(TTransport* transport);
+  ~TSimpleClient() {}
+
+  bool open();
+  void close();
+  std::string dispatch(const std::string& in);
+
+ protected:
+  // Non-owning; see class comment.
+  TTransport* transport_;
+};
+
+#endif // T_SIMPLE_CLIENT_H
+
diff --git a/lib/cpp/protocol/TBinaryProtocol.cc b/lib/cpp/protocol/TBinaryProtocol.cc
new file mode 100644
index 0000000..4c10bab
--- /dev/null
+++ b/lib/cpp/protocol/TBinaryProtocol.cc
@@ -0,0 +1,155 @@
+#include "protocol/TBinaryProtocol.h"
+
+#include <string.h>
+
+using namespace std;
+
+namespace {
+
+/**
+ * Copies sizeof(T) raw bytes out of buf and advances buf past them. Returns
+ * 0 and leaves buf untouched if fewer than sizeof(T) bytes remain. memcpy is
+ * used instead of dereferencing a cast pointer because buf.data carries no
+ * alignment guarantee.
+ */
+template <typename T>
+T readRaw(TBuf& buf) {
+  T result = 0;
+  if (buf.len < sizeof(T)) {
+    return result;
+  }
+  memcpy(&result, buf.data, sizeof(T));
+  buf.data += sizeof(T);
+  buf.len -= sizeof(T);
+  return result;
+}
+
+/** Returns the raw in-memory (host byte order) bytes of a fixed-width value. */
+template <typename T>
+string writeRaw(const T& value) {
+  return string((const char*)&value, sizeof(T));
+}
+
+} // namespace
+
+string TBinaryProtocol::readFunction(TBuf& buf) const {
+  // Let readString increment the buffer position
+  return readString(buf);
+}
+
+string TBinaryProtocol::writeFunction(const string& name,
+                                      const string& args) const {
+  return writeString(name) + args;
+}
+
+/**
+ * Reads a struct written by writeStruct(): a 4 byte total size, then a
+ * sequence of fields, each an 8 byte header (4 byte fid + 4 byte length)
+ * followed by that many bytes of data. The returned TBufs point into buf's
+ * memory; they do not copy it.
+ */
+map<uint32_t, TBuf> TBinaryProtocol::readStruct(TBuf& buf) const {
+  map<uint32_t, TBuf> fieldMap;
+
+  if (buf.len < 4) {
+    return fieldMap;
+  }
+  uint32_t total_size = readU32(buf);
+  if (buf.len < total_size) {
+    // Data looks corrupt, we don't have that much, we will try to read what
+    // we can but be sure not to go over
+    total_size = buf.len;
+  }
+
+  // Invariant: total_size <= buf.len. Use >= 8 so that a trailing field with
+  // an empty body (header only) is not silently dropped.
+  while (total_size >= 8 && buf.len >= 8) {
+    uint32_t fid = readU32(buf);
+    uint32_t flen = readU32(buf);
+    total_size -= 8;
+    if (flen > total_size) {
+      // flen corrupt, there isn't that much data left in this struct.
+      // Checking before subtracting keeps total_size from wrapping around.
+      break;
+    }
+    fieldMap.insert(make_pair(fid, TBuf(buf.data, flen)));
+    buf.data += flen;
+    buf.len -= flen;
+    total_size -= flen;
+  }
+
+  return fieldMap;
+}
+
+string TBinaryProtocol::writeStruct(const map<uint32_t,string>& s) const {
+  string result;
+  map<uint32_t,string>::const_iterator s_iter;
+  for (s_iter = s.begin(); s_iter != s.end(); ++s_iter) {
+    result += writeU32(s_iter->first);
+    result += writeU32(s_iter->second.size());
+    result += s_iter->second;
+  }
+  return writeU32(result.size()) + result;
+}
+
+string TBinaryProtocol::readString(TBuf& buf) const {
+  uint32_t len = readU32(buf);
+  if (len > buf.len) {
+    // Length prefix is corrupt. Never read past the end of the buffer; like
+    // readStruct, salvage whatever bytes remain instead.
+    len = buf.len;
+  }
+  if (len == 0) {
+    return "";
+  }
+  string result((const char*)buf.data, len);
+  buf.data += len;
+  buf.len -= len;
+  return result;
+}
+
+uint8_t TBinaryProtocol::readByte(TBuf& buf) const {
+  return readRaw<uint8_t>(buf);
+}
+
+uint32_t TBinaryProtocol::readU32(TBuf& buf) const {
+  return readRaw<uint32_t>(buf);
+}
+
+int32_t TBinaryProtocol::readI32(TBuf& buf) const {
+  return readRaw<int32_t>(buf);
+}
+
+uint64_t TBinaryProtocol::readU64(TBuf& buf) const {
+  return readRaw<uint64_t>(buf);
+}
+
+int64_t TBinaryProtocol::readI64(TBuf& buf) const {
+  return readRaw<int64_t>(buf);
+}
+
+// All writers emit a value's raw host-order bytes; strings get a 4 byte
+// length prefix.
+string TBinaryProtocol::writeString(const string& str) const {
+  return writeU32(str.size()) + str;
+}
+
+string TBinaryProtocol::writeByte(const uint8_t byte) const {
+  return writeRaw(byte);
+}
+
+string TBinaryProtocol::writeU32(const uint32_t u32) const {
+  return writeRaw(u32);
+}
+
+string TBinaryProtocol::writeI32(int32_t i32) const {
+  return writeRaw(i32);
+}
+
+string TBinaryProtocol::writeU64(uint64_t u64) const {
+  return writeRaw(u64);
+}
+
+string TBinaryProtocol::writeI64(int64_t i64) const {
+  return writeRaw(i64);
+}
diff --git a/lib/cpp/protocol/TBinaryProtocol.h b/lib/cpp/protocol/TBinaryProtocol.h
new file mode 100644
index 0000000..976c383
--- /dev/null
+++ b/lib/cpp/protocol/TBinaryProtocol.h
@@ -0,0 +1,47 @@
+#ifndef T_BINARY_PROTOCOL_H
+#define T_BINARY_PROTOCOL_H
+
+#include "protocol/TProtocol.h"
+
+/**
+ * Default binary protocol for thrift. Values are serialized as their raw
+ * in-memory bytes with no byte-order conversion, and strings carry a 4-byte
+ * length prefix.
+ *
+ * The class has no data members, so (as TProtocol requires) a single
+ * instance may be shared by multiple threads.
+ *
+ * @author Mark Slee <mcslee@facebook.com>
+ */
+class TBinaryProtocol : public TProtocol {
+ public:
+  TBinaryProtocol() {}
+  ~TBinaryProtocol() {}
+
+  // Function calls
+  std::string readFunction(TBuf& buf) const;
+  std::string writeFunction(const std::string& name,
+                            const std::string& args) const;
+
+  // Structs
+  std::map<uint32_t, TBuf> readStruct(TBuf& buf) const;
+  std::string writeStruct(const std::map<uint32_t, std::string>& s) const;
+
+  // Base type readers; each advances buf past whatever bytes it consumed
+  std::string readString(TBuf& buf) const;
+  uint8_t     readByte  (TBuf& buf) const;
+  uint32_t    readU32   (TBuf& buf) const;
+  int32_t     readI32   (TBuf& buf) const;
+  uint64_t    readU64   (TBuf& buf) const;
+  int64_t     readI64   (TBuf& buf) const;
+
+  // Base type writers; each returns the encoded bytes
+  std::string writeString(const std::string& str) const;
+  std::string writeByte  (const uint8_t byte) const;
+  std::string writeU32   (const uint32_t u32) const;
+  std::string writeI32   (const int32_t i32) const;
+  std::string writeU64   (const uint64_t u64) const;
+  std::string writeI64   (const int64_t i64) const;
+};
+
+#endif // T_BINARY_PROTOCOL_H
diff --git a/lib/cpp/protocol/TProtocol.h b/lib/cpp/protocol/TProtocol.h
new file mode 100644
index 0000000..1f2e0c8
--- /dev/null
+++ b/lib/cpp/protocol/TProtocol.h
@@ -0,0 +1,94 @@
+#ifndef T_PROTOCOL_H
+#define T_PROTOCOL_H
+
+#include <stdint.h>
+#include <sys/types.h>
+#include <string>
+#include <map>
+
+/**
+ * Wrapper around raw data that allows us to track the length of a data
+ * buffer. It is the responsibility of a robust TProtocol implementation
+ * to ensure that any reads that are done from data do NOT overrun the
+ * memory address at data+len. It is also a convention that TBuf objects
+ * do NOT own the memory pointed to by data. They are merely views of
+ * buffers that have been allocated elsewhere. Therefore, the user should
+ * never allocate memory for a TBuf to take ownership of, nor should they
+ * free the data pointed to by a TBuf; the underlying buffer must outlive
+ * every TBuf that refers to it.
+ *
+ * Defined ahead of TProtocol so that TBuf is already a complete type
+ * wherever the interface below (e.g. readStruct()'s map of TBufs) is used.
+ */
+struct TBuf {
+  TBuf(const TBuf& that) : data(that.data), len(that.len) {}
+  TBuf(const uint8_t* d, uint32_t l) : data(d), len(l) {}
+  const uint8_t* data;
+  uint32_t len;
+};
+
+/**
+ * Abstract class for a thrift protocol driver. These are all the methods that
+ * a protocol must implement. Essentially, there must be some way of reading
+ * and writing all the base types, plus a mechanism for writing out structs
+ * with indexed fields. Also notice that all methods are strictly const. This
+ * is by design. Protocol implementations may NOT keep state, because the
+ * same TProtocol object may be used simultaneously by multiple threads. This
+ * theoretically introduces some limitations into the possible protocol
+ * formats, but with the benefit of performance, clarity, and simplicity.
+ *
+ * @author Mark Slee <mcslee@facebook.com>
+ */
+class TProtocol {
+ public:
+  virtual ~TProtocol() {}
+
+  /**
+   * Function call serialization.
+   */
+
+  virtual std::string
+    readFunction(TBuf& buf) const = 0;
+  virtual std::string
+    writeFunction(const std::string& name, const std::string& args) const = 0;
+
+  /**
+   * Struct serialization.
+   */
+
+  virtual std::map<uint32_t, TBuf>
+    readStruct(TBuf& buf) const = 0;
+  virtual std::string
+    writeStruct(const std::map<uint32_t,std::string>& s) const = 0;
+
+  /**
+   * Basic data type deserialization. Note that these read methods do not
+   * take a const reference to the TBuf object. They SHOULD change the TBuf
+   * object so that it reflects the buffer AFTER the basic data type has
+   * been consumed such that data may continue being read serially from the
+   * buffer.
+   */
+
+  virtual std::string readString (TBuf& buf) const = 0;
+  virtual uint8_t     readByte   (TBuf& buf) const = 0;
+  virtual uint32_t    readU32    (TBuf& buf) const = 0;
+  virtual int32_t     readI32    (TBuf& buf) const = 0;
+  virtual uint64_t    readU64    (TBuf& buf) const = 0;
+  virtual int64_t     readI64    (TBuf& buf) const = 0;
+
+  /**
+   * Basic data type serialization. Each method returns the encoded bytes.
+   */
+
+  virtual std::string writeString (const std::string& str) const = 0;
+  virtual std::string writeByte   (const uint8_t byte) const = 0;
+  virtual std::string writeU32    (const uint32_t u32) const = 0;
+  virtual std::string writeI32    (const int32_t i32) const = 0;
+  virtual std::string writeU64    (const uint64_t u64) const = 0;
+  virtual std::string writeI64    (const int64_t i64) const = 0;
+
+ protected:
+  TProtocol() {}
+};
+
+#endif
diff --git a/lib/cpp/server/TServer.h b/lib/cpp/server/TServer.h
new file mode 100644
index 0000000..9c4cc59
--- /dev/null
+++ b/lib/cpp/server/TServer.h
@@ -0,0 +1,46 @@
+#ifndef T_SERVER_H
+#define T_SERVER_H
+
+#include "TDispatcher.h"
+
+class TServerOptions;
+
+/**
+ * Abstract thrift server. Subclasses implement run(), typically receiving
+ * request payloads and passing each one to dispatcher_.
+ *
+ * Neither the dispatcher nor the options object is owned by the server:
+ * both are kept as plain pointers and never deleted here.
+ *
+ * @author Mark Slee <mcslee@facebook.com>
+ */
+class TServer {
+ public:
+  virtual ~TServer() {}
+
+  /** Serves requests; when (or whether) this returns is up to the subclass. */
+  virtual void run() = 0;
+
+ protected:
+  TServer(TDispatcher* dispatcher, TServerOptions* options) :
+    dispatcher_(dispatcher),
+    options_(options) {}
+
+  // Turns each request payload into a response payload; not owned
+  TDispatcher* dispatcher_;
+
+  // Generic server settings; not owned
+  TServerOptions* options_;
+};
+
+/**
+ * Holder for options shared by every server type. Currently empty.
+ */
+class TServerOptions {
+ public:
+  // TODO(mcslee): Fill in getters/setters here
+ protected:
+  // TODO(mcslee): Fill data members in here
+};
+
+#endif // T_SERVER_H
diff --git a/lib/cpp/server/TSimpleServer.cc b/lib/cpp/server/TSimpleServer.cc
new file mode 100644
index 0000000..16f5006
--- /dev/null
+++ b/lib/cpp/server/TSimpleServer.cc
@@ -0,0 +1,73 @@
+#include "server/TSimpleServer.h"
+
+#include <stdint.h>
+#include <stdio.h>
+#include <string.h>
+#include <string>
+
+using namespace std;
+
+/**
+ * Single-threaded accept loop. For each accepted client, repeatedly reads a
+ * size-framed request (4-byte host-order size header + payload), passes the
+ * payload to the dispatcher, and writes back a size-framed response, until
+ * the client disconnects or sends a malformed frame. Returns when listen()
+ * fails or accept() returns NULL.
+ */
+void TSimpleServer::run() {
+  TTransport* client;
+
+  // Start the server listening
+  if (serverTransport_->listen() == false) {
+    // TODO(mcslee): Log error here
+    fprintf(stderr, "TSimpleServer::run(): Call to listen failed\n");
+    return;
+  }
+
+  // Fetch client from server
+  while (true) {
+    if ((client = serverTransport_->accept()) == NULL) {
+      break;
+    }
+
+    while (true) {
+      // Read header from client. Anything short of 4 bytes means the client
+      // went away or the frame is truncated; never decode a partial header.
+      string in;
+      if (client->read(in, 4) < 4 || in.size() < 4) {
+        break;
+      }
+
+      // memcpy instead of a pointer cast: string data is not guaranteed to
+      // be suitably aligned for int32_t
+      int32_t size;
+      memcpy(&size, in.data(), 4);
+
+      // A negative size would become a huge unsigned length in read()
+      // below, so treat it as a protocol error and drop the client
+      if (size < 0) {
+        break;
+      }
+
+      // Read payload from client
+      if (client->read(in, size) < size) {
+        break;
+      }
+
+      // Pass payload to dispatcher
+      // TODO(mcslee): Wrap this in try/catch and return exceptions
+      string out = dispatcher_->dispatch(in);
+
+      // Write size of response packet, then the response payload
+      size = out.size();
+      client->write(string((const char*)&size, 4));
+      client->write(out);
+    }
+
+    // Clean up that client
+    client->close();
+    delete client;
+  }
+
+  // TODO(mcslee): Is this a timeout case or the real thing?
+}
diff --git a/lib/cpp/server/TSimpleServer.h b/lib/cpp/server/TSimpleServer.h
new file mode 100644
index 0000000..47ab69e
--- /dev/null
+++ b/lib/cpp/server/TSimpleServer.h
@@ -0,0 +1,31 @@
+#ifndef T_SIMPLE_SERVER_H
+#define T_SIMPLE_SERVER_H
+
+#include "server/TServer.h"
+#include "transport/TServerTransport.h"
+
+/**
+ * The most basic server: single-threaded, it accepts one connection, serves
+ * requests on it until that connection ends, then goes back to accept the
+ * next one. Handy as a reference when writing other TServer subclasses.
+ *
+ * @author Mark Slee <mcslee@facebook.com>
+ */
+class TSimpleServer : public TServer {
+ public:
+  TSimpleServer(TDispatcher* dispatcher,
+                TServerOptions* options,
+                TServerTransport* serverTransport)
+    : TServer(dispatcher, options),
+      serverTransport_(serverTransport) {}
+
+  ~TSimpleServer() {}
+
+  void run();
+
+ protected:
+  // Source of client connections; not owned (never deleted here)
+  TServerTransport* serverTransport_;
+};
+
+#endif // T_SIMPLE_SERVER_H
diff --git a/lib/cpp/transport/TServerSocket.cc b/lib/cpp/transport/TServerSocket.cc
new file mode 100644
index 0000000..178de81
--- /dev/null
+++ b/lib/cpp/transport/TServerSocket.cc
@@ -0,0 +1,113 @@
+#include <sys/socket.h>
+#include <netinet/in.h>
+#include <errno.h>
+#include <stdio.h>
+#include <string.h>
+#include <unistd.h>
+
+#include "transport/TSocket.h"
+#include "transport/TServerSocket.h"
+
+TServerSocket::TServerSocket(int port) :
+  port_(port), serverSocket_(0), acceptBacklog_(1024) {}
+
+TServerSocket::~TServerSocket() {
+  close();
+}
+
+/**
+ * Creates a TCP socket, sets SO_REUSEADDR and SO_LINGER(off), binds it to
+ * port_ on all interfaces and starts listening. On any failure the error is
+ * reported with perror, the socket is closed, and false is returned.
+ */
+bool TServerSocket::listen() {
+  // Don't leak the old descriptor if listen() is called again
+  close();
+
+  serverSocket_ = socket(AF_INET, SOCK_STREAM, 0);
+  if (serverSocket_ == -1) {
+    perror("TServerSocket::listen() socket");
+    close();
+    return false;
+  }
+
+  // Set reuseaddress to prevent 2MSL delay on accept
+  int one = 1;
+  if (-1 == setsockopt(serverSocket_, SOL_SOCKET, SO_REUSEADDR,
+                       &one, sizeof(one))) {
+    perror("TServerSocket::listen() SO_REUSEADDR");
+    close();
+    return false;
+  }
+
+  // Turn linger off, don't want to block on calls to close
+  struct linger ling = {0, 0};
+  if (-1 == setsockopt(serverSocket_, SOL_SOCKET, SO_LINGER,
+                       &ling, sizeof(ling))) {
+    perror("TServerSocket::listen() SO_LINGER");
+    close();
+    return false;
+  }
+
+  // Bind to a port on all interfaces
+  struct sockaddr_in addr;
+  memset(&addr, 0, sizeof(addr));
+  addr.sin_family = AF_INET;
+  addr.sin_port = htons(port_);
+  addr.sin_addr.s_addr = htonl(INADDR_ANY);
+  if (-1 == bind(serverSocket_, (struct sockaddr *)&addr, sizeof(addr))) {
+    char errbuf[1024];
+    snprintf(errbuf, sizeof(errbuf), "TServerSocket::listen() BIND %d", port_);
+    perror(errbuf);
+    close();
+    return false;
+  }
+
+  // Call listen
+  if (-1 == ::listen(serverSocket_, acceptBacklog_)) {
+    perror("TServerSocket::listen() LISTEN");
+    close();
+    return false;
+  }
+
+  // The socket is now listening!
+  return true;
+}
+
+/**
+ * Waits for a client to connect and returns a new TSocket for it; the
+ * caller owns (and must delete) the result. Returns NULL if the server
+ * socket is not listening or accept() fails.
+ */
+TTransport* TServerSocket::accept() {
+  if (serverSocket_ <= 0) {
+    // TODO(mcslee): Log error with common logging tool
+    return NULL;
+  }
+
+  struct sockaddr_in clientAddress;
+  int clientSocket;
+  do {
+    socklen_t size = sizeof(clientAddress);
+    clientSocket = ::accept(serverSocket_,
+                            (struct sockaddr *) &clientAddress,
+                            &size);
+    // A signal interrupting the wait is not a real failure, and a NULL
+    // return makes TSimpleServer stop serving, so just wait again
+  } while (clientSocket == -1 && errno == EINTR);
+
+  if (clientSocket <= 0) {
+    perror("TServerSocket::accept()");
+    return NULL;
+  }
+
+  return new TSocket(clientSocket);
+}
+
+void TServerSocket::close() {
+  if (serverSocket_ > 0) {
+    shutdown(serverSocket_, SHUT_RDWR);
+    ::close(serverSocket_);
+  }
+  serverSocket_ = 0;
+}
diff --git a/lib/cpp/transport/TServerSocket.h b/lib/cpp/transport/TServerSocket.h
new file mode 100644
index 0000000..8ded4e2
--- /dev/null
+++ b/lib/cpp/transport/TServerSocket.h
@@ -0,0 +1,36 @@
+#ifndef T_SERVER_SOCKET_H
+#define T_SERVER_SOCKET_H
+
+#include "transport/TServerTransport.h"
+
+class TSocket;
+
+/**
+ * Server socket implementation of TServerTransport. Wrapper around the unix
+ * socket listen and accept calls.
+ *
+ * Owns the listening descriptor and closes it on destruction, so instances
+ * are non-copyable: two copies would both close the same descriptor.
+ *
+ * @author Mark Slee <mcslee@facebook.com>
+ */
+class TServerSocket : public TServerTransport {
+ public:
+  TServerSocket(int port);
+  ~TServerSocket();
+
+  bool listen();
+  TTransport* accept();
+  void close();
+
+ private:
+  // Non-copyable: declared but intentionally never defined
+  TServerSocket(const TServerSocket&);
+  TServerSocket& operator=(const TServerSocket&);
+
+  int port_;
+  int serverSocket_;
+  int acceptBacklog_;
+};
+
+#endif
diff --git a/lib/cpp/transport/TServerTransport.h b/lib/cpp/transport/TServerTransport.h
new file mode 100644
index 0000000..4d063fc
--- /dev/null
+++ b/lib/cpp/transport/TServerTransport.h
@@ -0,0 +1,32 @@
+#ifndef T_SERVER_TRANSPORT_H
+#define T_SERVER_TRANSPORT_H
+
+#include "TTransport.h"
+
+/**
+ * Abstract source of client connections. A server calls listen() once and
+ * then accept() repeatedly to obtain a TTransport for each client.
+ *
+ * @author Mark Slee <mcslee@facebook.com>
+ */
+class TServerTransport {
+ public:
+  virtual ~TServerTransport() {}
+
+  /** Starts listening for connections; returns false on failure. */
+  virtual bool listen() = 0;
+
+  /**
+   * Waits for the next client and returns a transport for it, or NULL on
+   * failure. The caller takes ownership of the returned object.
+   */
+  virtual TTransport* accept() = 0;
+
+  /** Stops listening. */
+  virtual void close() = 0;
+
+ protected:
+  TServerTransport() {}
+};
+
+#endif // T_SERVER_TRANSPORT_H
diff --git a/lib/cpp/transport/TSocket.cc b/lib/cpp/transport/TSocket.cc
new file mode 100644
index 0000000..1dfe431
--- /dev/null
+++ b/lib/cpp/transport/TSocket.cc
@@ -0,0 +1,223 @@
+#include <sys/socket.h>
+#include <arpa/inet.h>
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+#include <netdb.h>
+#include <unistd.h>
+#include <errno.h>
+#include <pthread.h>
+#include <stdio.h>
+#include <string.h>
+#include <vector>
+
+#include "transport/TSocket.h"
+
+using namespace std;
+
+// Mutex to protect syscalls to netdb. gethostbyname() returns a pointer to
+// static storage, so a lookup and the copy out of its result must not
+// interleave with another thread's lookup.
+pthread_mutex_t g_netdb_mutex = PTHREAD_MUTEX_INITIALIZER;
+
+// TODO(mcslee): Make this an option to the socket class
+#define MAX_RECV_RETRIES 20
+
+TSocket::TSocket(string host, int port) :
+  options_(NULL), host_(host), port_(port), socket_(0) {}
+
+// Wraps an already-connected descriptor (used by TServerSocket::accept)
+TSocket::TSocket(int socket) :
+  options_(NULL), port_(0), socket_(socket) {}
+
+TSocket::~TSocket() {
+  close();
+}
+
+/**
+ * Resolves host_ and connects a new TCP socket to host_:port_. On failure
+ * the socket is closed and false is returned.
+ */
+bool TSocket::open() {
+  // Don't leak the old descriptor if open() is called on an open socket
+  close();
+
+  // Create socket
+  socket_ = socket(AF_INET, SOCK_STREAM, 0);
+  if (socket_ == -1) {
+    socket_ = 0;
+    return false;
+  }
+
+  // Lookup the host. Zero the whole address first so sin_zero is clean.
+  struct sockaddr_in addr;
+  memset(&addr, 0, sizeof(addr));
+  addr.sin_family = AF_INET;
+  addr.sin_port = htons(port_);
+
+  {
+    pthread_mutex_lock(&g_netdb_mutex);
+    struct hostent *host_entry = gethostbyname(host_.c_str());
+
+    // Only an IPv4 result fits in sockaddr_in; reject anything else rather
+    // than letting the memcpy below overflow sin_addr
+    if (host_entry == NULL ||
+        host_entry->h_addrtype != AF_INET ||
+        host_entry->h_length != (int) sizeof(addr.sin_addr)) {
+      pthread_mutex_unlock(&g_netdb_mutex);
+      close();
+      return false;
+    }
+
+    memcpy(&addr.sin_addr.s_addr,
+           host_entry->h_addr_list[0],
+           host_entry->h_length);
+    pthread_mutex_unlock(&g_netdb_mutex);
+  }
+
+  // Connect the socket
+  int ret = connect(socket_, (struct sockaddr *)&addr, sizeof(addr));
+
+  // Connect failed
+  if (ret < 0) {
+    perror("TSocket::open() connect");
+    close();
+    return false;
+  }
+
+  return true;
+}
+
+void TSocket::close() {
+  if (socket_ > 0) {
+    shutdown(socket_, SHUT_RDWR);
+    ::close(socket_);
+  }
+  socket_ = 0;
+}
+
+/**
+ * Reads exactly len bytes into s, looping over partial recv() results.
+ * EAGAIN and EINTR are retried, up to MAX_RECV_RETRIES times in total.
+ *
+ * @return len on success, with s holding the data; 0 with s empty if the
+ *         peer closed the connection or an unrecoverable error occurred
+ */
+int TSocket::read(string& s, uint32_t len) {
+  s = "";
+  if (len == 0) {
+    return 0;
+  }
+
+  // Heap buffer: len may come straight off the wire (a size header), and a
+  // stack array that large could overflow the stack
+  vector<char> buff(len);
+
+  uint32_t have = 0;
+  uint32_t retries = 0;
+
+  while (have < len) {
+    // Read from the socket
+    ssize_t got = recv(socket_, &buff[have], len - have, 0);
+
+    // Check for error on read
+    if (got < 0) {
+      // If temporarily out of resources, sleep a bit and try again
+      if (errno == EAGAIN && retries++ < MAX_RECV_RETRIES) {
+        usleep(50);
+        continue;
+      }
+
+      // If interrupted, try again
+      if (errno == EINTR && retries++ < MAX_RECV_RETRIES) {
+        continue;
+      }
+
+      // Out of retries, or a hard error such as ECONNRESET (peer
+      // disconnected with no linger time)
+      perror("TSocket::read()");
+      return 0;
+    }
+
+    // Check for empty read: the peer closed the connection
+    if (got == 0) {
+      return 0;
+    }
+
+    // Update the count
+    have += (uint32_t) got;
+  }
+
+  // Pack data into string
+  s.assign(&buff[0], have);
+  return have;
+}
+
+/**
+ * Writes all of s, looping over partial send() results and retrying sends
+ * interrupted by a signal. Gives up silently on any other failure.
+ */
+void TSocket::write(const string& s) {
+  // Where available, have send() report a closed peer as EPIPE instead of
+  // raising SIGPIPE, whose default action terminates the process
+#ifdef MSG_NOSIGNAL
+  const int flags = MSG_NOSIGNAL;
+#else
+  const int flags = 0;
+#endif
+
+  uint32_t sent = 0;
+
+  while (sent < s.size()) {
+    ssize_t b = send(socket_, s.data() + sent, s.size() - sent, flags);
+
+    // Fail on a send error
+    if (b < 0) {
+      if (errno == EINTR) {
+        continue;
+      }
+      // TODO(mcslee): Make the function return how many bytes it wrote or
+      // throw an exception
+      return;
+    }
+
+    // Fail on blocked send
+    if (b == 0) {
+      // TODO(mcslee): Make the function return how many bytes it wrote or
+      // throw an exception
+      return;
+    }
+
+    sent += b;
+  }
+}
+
+/**
+ * Sets SO_LINGER on the socket. On failure the error is reported, the
+ * socket is closed, and false is returned.
+ */
+bool TSocket::setLinger(bool on, int linger) {
+  struct linger ling = {(on ? 1 : 0), linger};
+  if (-1 == setsockopt(socket_, SOL_SOCKET, SO_LINGER, &ling, sizeof(ling))) {
+    // Report before close(): close() can overwrite errno
+    perror("TSocket::setLinger()");
+    close();
+    return false;
+  }
+  return true;
+}
+
+/**
+ * Enables or disables TCP_NODELAY. On failure the error is reported, the
+ * socket is closed, and false is returned.
+ */
+bool TSocket::setNoDelay(bool noDelay) {
+  // Set socket to NODELAY
+  int val = (noDelay ? 1 : 0);
+  if (-1 == setsockopt(socket_, IPPROTO_TCP, TCP_NODELAY, &val, sizeof(val))) {
+    // Report before close(): close() can overwrite errno
+    perror("TSocket::setNoDelay()");
+    close();
+    return false;
+  }
+  return true;
+}
diff --git a/lib/cpp/transport/TSocket.h b/lib/cpp/transport/TSocket.h
new file mode 100644
index 0000000..1da74c6
--- /dev/null
+++ b/lib/cpp/transport/TSocket.h
@@ -0,0 +1,48 @@
+#ifndef T_SOCKET_H
+#define T_SOCKET_H
+
+#include <string>
+
+#include "transport/TTransport.h"
+#include "transport/TServerSocket.h"
+
+class TSocketOptions;
+
+/**
+ * TCP Socket implementation of the TTransport interface.
+ *
+ * A TSocket owns its file descriptor and closes it on destruction, so it
+ * is non-copyable: two copies would both close the same descriptor.
+ *
+ * @author Mark Slee <mcslee@facebook.com>
+ */
+class TSocket : public TTransport {
+  // accept() wraps each accepted descriptor via the private TSocket(int)
+  friend TTransport* TServerSocket::accept();
+
+ public:
+  TSocket(std::string host, int port);
+  ~TSocket();
+
+  bool open();
+  void close();
+  int read (std::string &s, uint32_t size);
+  void write(const std::string& s);
+
+  bool setLinger(bool on, int linger);
+  bool setNoDelay(bool noDelay);
+
+ private:
+  TSocket(int socket);
+
+  // Non-copyable: declared but intentionally never defined
+  TSocket(const TSocket&);
+  TSocket& operator=(const TSocket&);
+
+  TSocketOptions *options_;
+  std::string host_;
+  int port_;
+  int socket_;
+};
+
+#endif
diff --git a/lib/cpp/transport/TTransport.h b/lib/cpp/transport/TTransport.h
new file mode 100644
index 0000000..a1f43d4
--- /dev/null
+++ b/lib/cpp/transport/TTransport.h
@@ -0,0 +1,35 @@
+#ifndef T_TRANSPORT_H
+#define T_TRANSPORT_H
+
+#include <stdint.h>
+#include <string>
+
+/**
+ * Generic interface for a method of transporting data.
+ *
+ * @author Mark Slee <mcslee@facebook.com>
+ */
+class TTransport {
+ public:
+  virtual ~TTransport() {}
+
+  /** Establishes the connection; returns false on failure. */
+  virtual bool open() = 0;
+
+  /** Tears down the connection. */
+  virtual void close() = 0;
+
+  /**
+   * Reads size bytes into s, replacing its previous contents, and returns
+   * the number of bytes read. Callers treat a result below size as failure.
+   */
+  virtual int read (std::string& s, uint32_t size) = 0;
+
+  /** Writes s to the transport. */
+  virtual void write(const std::string& s) = 0;
+
+ protected:
+  TTransport() {}
+};
+
+#endif