Thrift now a TLP - INFRA-3116
git-svn-id: https://svn.apache.org/repos/asf/thrift/branches/0.1.x@1028168 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/cpp/src/TLogging.h b/lib/cpp/src/TLogging.h
new file mode 100644
index 0000000..2df82dd
--- /dev/null
+++ b/lib/cpp/src/TLogging.h
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_TLOGGING_H_
+#define _THRIFT_TLOGGING_H_ 1
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+/**
+ * Contains utility macros for debugging and logging.
+ *
+ */
+
+#ifndef HAVE_CLOCK_GETTIME
+#include <time.h>
+#else
+#include <sys/time.h>
+#endif
+
+#ifdef HAVE_STDINT_H
+#include <stdint.h>
+#endif
+
+/**
+ * Compile-time switch for the T_DEBUG / T_DEBUG_T macros below.
+ *
+ * T_GLOBAL_DEBUGGING_LEVEL = 0: T_DEBUG and T_DEBUG_T expand to nothing
+ * T_GLOBAL_DEBUGGING_LEVEL > 0: T_DEBUG and T_DEBUG_T print to stderr
+ *
+ * T_DEBUG_L, T_ERROR and T_ERROR_ABORT are not affected by this setting.
+ * The default is only applied when the macro is not already defined, so it
+ * can be overridden at build time (e.g. -DT_GLOBAL_DEBUGGING_LEVEL=1).
+ */
+#ifndef T_GLOBAL_DEBUGGING_LEVEL
+#define T_GLOBAL_DEBUGGING_LEVEL 0
+#endif
+
+
+/**
+ * Compile-time switch for the T_LOG_OPER macro below.
+ *
+ * T_GLOBAL_LOGGING_LEVEL = 0: T_LOG_OPER expands to nothing
+ * T_GLOBAL_LOGGING_LEVEL > 0: T_LOG_OPER prints to stderr
+ *
+ * Like the debugging level, it can be overridden at build time
+ * (e.g. -DT_GLOBAL_LOGGING_LEVEL=0).
+ */
+#ifndef T_GLOBAL_LOGGING_LEVEL
+#define T_GLOBAL_LOGGING_LEVEL 1
+#endif
+
+
+/**
+ * Standard wrapper around fprintf that prefixes the file name and line
+ * number to the line. Compiled in only when T_GLOBAL_DEBUGGING_LEVEL > 0;
+ * otherwise T_DEBUG expands to nothing.
+ *
+ * Notes:
+ *  - format_string is stringized (#format_string). A string-literal argument
+ *    is therefore printed with its surrounding quotes, and escape sequences
+ *    inside it (e.g. \n) are printed literally. A non-literal argument would
+ *    be printed by its spelling, not its value.
+ *  - The expansion is a bare if-statement. Wrap calls in braces when they are
+ *    the body of an if/else.
+ *  - This header does not include <stdio.h>; fprintf/stderr must already be
+ *    in scope at the call site.
+ *
+ * @param format_string printf style format string (a literal)
+ * @param ... arguments for format_string (GNU ##__VA_ARGS__ allows none)
+ */
+#if T_GLOBAL_DEBUGGING_LEVEL > 0
+ #define T_DEBUG(format_string,...) \
+ if (T_GLOBAL_DEBUGGING_LEVEL > 0) { \
+ fprintf(stderr,"[%s,%d] " #format_string " \n", __FILE__, __LINE__,##__VA_ARGS__); \
+ }
+#else
+ #define T_DEBUG(format_string,...)
+#endif
+
+
+/**
+ * Analogous to T_DEBUG, but also prints the current time (as formatted by
+ * ctime_r(), with its trailing newline cut off at index 24) after the file
+ * name and line number. Compiled in only when T_GLOBAL_DEBUGGING_LEVEL > 0;
+ * otherwise it expands to nothing.
+ *
+ * The expansion is a braced block that declares locals named `now` and
+ * `dbgtime`. An argument expression that uses those names refers to the
+ * macro's locals, not the caller's variables.
+ *
+ * @param string format_string input: printf style format string (a literal;
+ *        it is stringized, so it is printed with its surrounding quotes)
+ */
+#if T_GLOBAL_DEBUGGING_LEVEL > 0
+ #define T_DEBUG_T(format_string,...) \
+ { \
+ if (T_GLOBAL_DEBUGGING_LEVEL > 0) { \
+ time_t now; \
+ char dbgtime[26] ; \
+ time(&now); \
+ ctime_r(&now, dbgtime); \
+ dbgtime[24] = '\0'; \
+ fprintf(stderr,"[%s,%d] [%s] " #format_string " \n", __FILE__, __LINE__,dbgtime,##__VA_ARGS__); \
+ } \
+ }
+#else
+ #define T_DEBUG_T(format_string,...)
+#endif
+
+
+/**
+ * Analogous to T_DEBUG, but whether the message is printed is decided by the
+ * `level` argument at run time (printed when level > 0). Unlike T_DEBUG, this
+ * macro is always defined, whatever the value of T_GLOBAL_DEBUGGING_LEVEL.
+ *
+ * The expansion is a bare if-statement. Wrap calls in braces when they are
+ * the body of an if/else. format_string is stringized, as in T_DEBUG.
+ *
+ * @param int level: specified debug level (message printed when > 0)
+ * @param string format_string input: printf style format string (a literal)
+ */
+#define T_DEBUG_L(level, format_string,...) \
+ if ((level) > 0) { \
+ fprintf(stderr,"[%s,%d] " #format_string " \n", __FILE__, __LINE__,##__VA_ARGS__); \
+ }
+
+
+/**
+ * Explicit error logging. Unconditionally prints the file name, line number,
+ * current time (ctime_r() text without its trailing newline) and "ERROR: "
+ * followed by the message to stderr. It does not depend on
+ * T_GLOBAL_DEBUGGING_LEVEL or T_GLOBAL_LOGGING_LEVEL.
+ *
+ * The expansion is a braced block. `if (c) T_ERROR(...); else ...` does not
+ * compile, so wrap such calls in braces. The block declares locals `now` and
+ * `dbgtime`, which shadow caller variables of the same name used in the
+ * arguments. format_string is stringized (printed with its quotes).
+ *
+ * @param string format_string input: printf style format string (a literal)
+ */
+#define T_ERROR(format_string,...) \
+ { \
+ time_t now; \
+ char dbgtime[26] ; \
+ time(&now); \
+ ctime_r(&now, dbgtime); \
+ dbgtime[24] = '\0'; \
+ fprintf(stderr,"[%s,%d] [%s] ERROR: " #format_string " \n", __FILE__, __LINE__,dbgtime,##__VA_ARGS__); \
+ }
+
+
+/**
+ * Analogous to T_ERROR (the message is prefixed with
+ * "ERROR: Going to abort "), then terminates the process.
+ *
+ * WARNING: despite its name, the macro calls exit(1), not abort(). atexit
+ * handlers and destructors of static objects run, and no SIGABRT or core
+ * dump is produced. exit() is declared in <stdlib.h>, which this header does
+ * not include.
+ *
+ * @param string format_string input: printf style format string (a literal)
+ */
+#define T_ERROR_ABORT(format_string,...) \
+ { \
+ time_t now; \
+ char dbgtime[26] ; \
+ time(&now); \
+ ctime_r(&now, dbgtime); \
+ dbgtime[24] = '\0'; \
+ fprintf(stderr,"[%s,%d] [%s] ERROR: Going to abort " #format_string " \n", __FILE__, __LINE__,dbgtime,##__VA_ARGS__); \
+ exit(1); \
+ }
+
+
+/**
+ * Operational log line: prints "[<time>] <message>" to stderr. Unlike the
+ * debug/error macros, it has no file/line prefix. Compiled in only when
+ * T_GLOBAL_LOGGING_LEVEL > 0; otherwise it expands to nothing.
+ *
+ * Like T_DEBUG_T, the expansion is a braced block declaring locals `now` and
+ * `dbgtime`, and format_string is stringized (printed with its quotes).
+ *
+ * @param string format_string input: printf style format string (a literal)
+ */
+#if T_GLOBAL_LOGGING_LEVEL > 0
+ #define T_LOG_OPER(format_string,...) \
+ { \
+ if (T_GLOBAL_LOGGING_LEVEL > 0) { \
+ time_t now; \
+ char dbgtime[26] ; \
+ time(&now); \
+ ctime_r(&now, dbgtime); \
+ dbgtime[24] = '\0'; \
+ fprintf(stderr,"[%s] " #format_string " \n", dbgtime,##__VA_ARGS__); \
+ } \
+ }
+#else
+ #define T_LOG_OPER(format_string,...)
+#endif
+
+#endif // #ifndef _THRIFT_TLOGGING_H_
diff --git a/lib/cpp/src/TProcessor.h b/lib/cpp/src/TProcessor.h
new file mode 100644
index 0000000..f2d5279
--- /dev/null
+++ b/lib/cpp/src/TProcessor.h
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_TPROCESSOR_H_
+#define _THRIFT_TPROCESSOR_H_ 1
+
+#include <string>
+#include <protocol/TProtocol.h>
+#include <boost/shared_ptr.hpp>
+
+namespace apache { namespace thrift {
+
+/**
+ * Abstract base for objects that act on a pair of protocol streams: one that
+ * is read from (input) and one that is written to (output). The contract is
+ * deliberately loose. Typical implementations answer requests that arrive on
+ * the input stream, or relay data from one stream onto the other.
+ *
+ * Only subclasses can be instantiated (the constructor is protected).
+ * Note that a subclass overriding process(in, out) hides the one-argument
+ * overload below unless it adds `using TProcessor::process;`.
+ */
+class TProcessor {
+ protected:
+  TProcessor() {}
+
+ public:
+  virtual ~TProcessor() {}
+
+  /**
+   * Consume data from `in` and write any response to `out`.
+   * The meaning of the returned flag is defined by the implementation.
+   */
+  virtual bool process(boost::shared_ptr<protocol::TProtocol> in,
+                       boost::shared_ptr<protocol::TProtocol> out) = 0;
+
+  /**
+   * Convenience overload for a single bidirectional protocol: `channel`
+   * serves as both the input and the output.
+   */
+  bool process(boost::shared_ptr<protocol::TProtocol> channel) {
+    return process(channel, channel);
+  }
+};
+
+}} // apache::thrift
+
+#endif // #ifndef _THRIFT_TPROCESSOR_H_
diff --git a/lib/cpp/src/TReflectionLocal.h b/lib/cpp/src/TReflectionLocal.h
new file mode 100644
index 0000000..e83e475
--- /dev/null
+++ b/lib/cpp/src/TReflectionLocal.h
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_TREFLECTIONLOCAL_H_
+#define _THRIFT_TREFLECTIONLOCAL_H_ 1
+
+#include <stdint.h>
+#include <cstring>
+#include <protocol/TProtocol.h>
+
+/**
+ * Local Reflection is a blanket term referring to the the structure
+ * and generation of this particular representation of Thrift types.
+ * (It is called local because it cannot be serialized by Thrift).
+ *
+ */
+
+namespace apache { namespace thrift { namespace reflection { namespace local {
+
+// Lets the structures below name protocol::TType unqualified.
+using apache::thrift::protocol::TType;
+
+// We include this many bytes of the structure's fingerprint when serializing
+// a top-level structure. Long enough to make collisions unlikely, short
+// enough to not significantly affect the amount of memory used.
+// This is also the length of TypeSpec::fp_prefix.
+const int FP_PREFIX_LEN = 4;
+
+/**
+ * Per-field metadata of a struct type. An array of these is referenced by
+ * TypeSpec::tstruct.metas, parallel to TypeSpec::tstruct.specs.
+ *
+ * Plain aggregate: member order matters to brace-initializers of the form
+ * {tag, is_optional}.
+ */
+struct FieldMeta {
+ // Presumably the Thrift field id of the field - confirm with generators.
+ int16_t tag;
+ // Whether the field is declared optional.
+ bool is_optional;
+};
+
+/**
+ * Description of one Thrift type. It holds the TType, the first
+ * FP_PREFIX_LEN bytes of the type's fingerprint (all zero for types built
+ * without one), and, depending on which constructor was used, either the
+ * struct field tables or the container subtypes.
+ */
+struct TypeSpec {
+  TType ttype;
+  uint8_t fp_prefix[FP_PREFIX_LEN];
+
+  // Use an anonymous union here so we can fit two TypeSpecs in one cache line.
+  union {
+    struct {
+      // Use parallel arrays here for denser packing (of the arrays).
+      FieldMeta* metas;
+      TypeSpec** specs;
+    } tstruct;
+    struct {
+      TypeSpec *subtype1;
+      TypeSpec *subtype2;
+    } tcontainer;
+  };
+
+  // Static initialization of unions isn't really possible,
+  // so take the plunge and use constructors.
+  // Hopefully they'll be evaluated at compile time.
+
+  /**
+   * Type with neither field tables nor subtypes. The fingerprint prefix and
+   * the union's pointers are zeroed so that no member is left indeterminate.
+   * Both union alternatives are pairs of pointers; the tcontainer view is
+   * the one that is written.
+   */
+  TypeSpec(TType ttype) : ttype(ttype) {
+    std::memset(fp_prefix, 0, FP_PREFIX_LEN);
+    tcontainer.subtype1 = NULL;
+    tcontainer.subtype2 = NULL;
+  }
+
+  /**
+   * Struct type.
+   *
+   * @param fingerprint at least FP_PREFIX_LEN bytes; only that prefix is copied
+   * @param metas       per-field metadata, parallel to `specs`
+   * @param specs       per-field type descriptions, parallel to `metas`
+   */
+  TypeSpec(TType ttype,
+           const uint8_t* fingerprint,
+           FieldMeta* metas,
+           TypeSpec** specs) :
+    ttype(ttype)
+  {
+    std::memcpy(fp_prefix, fingerprint, FP_PREFIX_LEN);
+    tstruct.metas = metas;
+    tstruct.specs = specs;
+  }
+
+  /**
+   * Container type with up to two element types. The fingerprint prefix is
+   * zeroed. Presumably subtype2 is only meaningful for maps - confirm
+   * against TDenseProtocol.
+   */
+  TypeSpec(TType ttype, TypeSpec* subtype1, TypeSpec* subtype2) :
+    ttype(ttype)
+  {
+    std::memset(fp_prefix, 0, FP_PREFIX_LEN);
+    tcontainer.subtype1 = subtype1;
+    tcontainer.subtype2 = subtype2;
+  }
+
+};
+
+}}}} // apache::thrift::reflection::local
+
+#endif // #ifndef _THRIFT_TREFLECTIONLOCAL_H_
diff --git a/lib/cpp/src/Thrift.cpp b/lib/cpp/src/Thrift.cpp
new file mode 100644
index 0000000..ed99205
--- /dev/null
+++ b/lib/cpp/src/Thrift.cpp
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <Thrift.h>
+#include <cstring>
+#include <boost/lexical_cast.hpp>
+#include <protocol/TProtocol.h>
+#include <stdarg.h>
+#include <stdio.h>
+
+namespace apache { namespace thrift {
+
+// The single shared TOutput instance declared `extern` in Thrift.h. It writes
+// through TOutput::errorTimeWrapper (the default installed by TOutput's
+// constructor) until setOutputFunction() replaces it.
+TOutput GlobalOutput;
+
+/**
+ * printf-style front end for the configured output function.
+ *
+ * The message is formatted into a 256-byte stack buffer first; a heap buffer
+ * is used only when the formatted text does not fit.
+ *
+ * @param message printf format string
+ * @param ...     arguments consumed by the format string
+ */
+void TOutput::printf(const char *message, ...) {
+  // Try to reduce heap usage, even if printf is called rarely.
+  static const int STACK_BUF_SIZE = 256;
+  char stack_buf[STACK_BUF_SIZE];
+  va_list ap;
+
+  va_start(ap, message);
+  int need = vsnprintf(stack_buf, STACK_BUF_SIZE, message, ap);
+  va_end(ap);
+
+  if (need < 0) {
+    // vsnprintf reported a formatting error. The contents of stack_buf are
+    // unspecified (possibly unterminated), so emit the raw format string
+    // instead of reading an undefined buffer.
+    f_(message);
+    return;
+  }
+
+  if (need < STACK_BUF_SIZE) {
+    f_(stack_buf);
+    return;
+  }
+
+  // Size in size_t so that need == INT_MAX cannot overflow the +1.
+  size_t heap_size = static_cast<size_t>(need) + 1;
+  char *heap_buf = static_cast<char*>(malloc(heap_size * sizeof(char)));
+  if (heap_buf == NULL) {
+    // Malloc failed. We might as well print the (truncated) stack buffer.
+    f_(stack_buf);
+    return;
+  }
+
+  va_start(ap, message);
+  int rval = vsnprintf(heap_buf, heap_size, message, ap);
+  va_end(ap);
+  // TODO(shigin): inform user
+  if (rval >= 0) {
+    f_(heap_buf);
+  }
+  free(heap_buf);
+}
+
+/**
+ * Emits `message` immediately followed by the description of errno_copy
+ * (no separator is inserted) through the configured output function.
+ */
+void TOutput::perror(const char *message, int errno_copy) {
+  const std::string description = strerror_s(errno_copy);
+  std::string line(message);
+  line += description;
+  f_(line.c_str());
+}
+
+/**
+ * Thread-safe strerror(): returns the system message for errno_copy as a
+ * std::string.
+ *
+ * Handles the configurations detected at build time:
+ *  - no strerror_r: returns "errno = <n>";
+ *  - GNU strerror_r (STRERROR_R_CHAR_P): uses the returned char*, which may
+ *    point to a static string rather than into b_errbuf;
+ *  - XSI strerror_r: returns 0 on success. On failure, older glibc returns -1
+ *    and sets errno, while newer versions return the error number itself, so
+ *    any non-zero result is treated as failure.
+ */
+std::string TOutput::strerror_s(int errno_copy) {
+#ifndef HAVE_STRERROR_R
+  return "errno = " + boost::lexical_cast<std::string>(errno_copy);
+#else // HAVE_STRERROR_R
+
+  char b_errbuf[1024] = { '\0' };
+#ifdef STRERROR_R_CHAR_P
+  char *b_error = strerror_r(errno_copy, b_errbuf, sizeof(b_errbuf));
+#else
+  char *b_error = b_errbuf;
+  int rv = strerror_r(errno_copy, b_errbuf, sizeof(b_errbuf));
+  if (rv != 0) {
+    // strerror_r failed (e.g. unknown error number or buffer too small).
+    return "XSI-compliant strerror_r() failed with errno = " +
+      boost::lexical_cast<std::string>(errno_copy);
+  }
+#endif
+  // Copy into a std::string before b_errbuf goes out of scope.
+  return std::string(b_error);
+
+#endif // HAVE_STRERROR_R
+}
+
+/**
+ * Deserializes this exception from `iprot`.
+ *
+ * Field 1 (string) populates the message and field 2 (i32) the type code.
+ * Any other field, or a known field with an unexpected wire type, is
+ * skipped.
+ *
+ * @return sum of the values returned by the protocol calls (presumably the
+ *         number of bytes consumed)
+ */
+uint32_t TApplicationException::read(apache::thrift::protocol::TProtocol* iprot) {
+  namespace proto = apache::thrift::protocol;
+
+  uint32_t bytesRead = 0;
+  std::string fieldName;
+  proto::TType fieldType;
+  int16_t fieldId;
+
+  bytesRead += iprot->readStructBegin(fieldName);
+
+  for (;;) {
+    bytesRead += iprot->readFieldBegin(fieldName, fieldType, fieldId);
+    if (fieldType == proto::T_STOP) {
+      break;
+    }
+
+    if (fieldId == 1 && fieldType == proto::T_STRING) {
+      bytesRead += iprot->readString(message_);
+    } else if (fieldId == 2 && fieldType == proto::T_I32) {
+      int32_t rawType;
+      bytesRead += iprot->readI32(rawType);
+      type_ = static_cast<TApplicationExceptionType>(rawType);
+    } else {
+      bytesRead += iprot->skip(fieldType);
+    }
+
+    bytesRead += iprot->readFieldEnd();
+  }
+
+  bytesRead += iprot->readStructEnd();
+  return bytesRead;
+}
+
+/**
+ * Serializes this exception to `oprot` as a struct named
+ * "TApplicationException" with field 1 "message" (string) and field 2
+ * "type" (i32).
+ *
+ * @return sum of the values returned by the protocol write calls
+ */
+uint32_t TApplicationException::write(apache::thrift::protocol::TProtocol* oprot) const {
+  namespace proto = apache::thrift::protocol;
+
+  uint32_t written = oprot->writeStructBegin("TApplicationException");
+
+  written += oprot->writeFieldBegin("message", proto::T_STRING, 1);
+  written += oprot->writeString(message_);
+  written += oprot->writeFieldEnd();
+
+  written += oprot->writeFieldBegin("type", proto::T_I32, 2);
+  written += oprot->writeI32(type_);
+  written += oprot->writeFieldEnd();
+
+  written += oprot->writeFieldStop();
+  written += oprot->writeStructEnd();
+  return written;
+}
+
+}} // apache::thrift
diff --git a/lib/cpp/src/Thrift.h b/lib/cpp/src/Thrift.h
new file mode 100644
index 0000000..26d2b0f
--- /dev/null
+++ b/lib/cpp/src/Thrift.h
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_THRIFT_H_
+#define _THRIFT_THRIFT_H_ 1
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+#include <stdio.h>
+
+#include <netinet/in.h>
+#ifdef HAVE_INTTYPES_H
+#include <inttypes.h>
+#endif
+#include <string>
+#include <map>
+#include <list>
+#include <set>
+#include <vector>
+#include <exception>
+
+#include "TLogging.h"
+
+namespace apache { namespace thrift {
+
+/**
+ * Pluggable sink for diagnostic messages.
+ *
+ * Every message goes to a plain function pointer. It defaults to
+ * errorTimeWrapper (timestamped line on stderr) and can be replaced with
+ * setOutputFunction().
+ */
+class TOutput {
+ public:
+  TOutput() : f_(&errorTimeWrapper) {}
+
+  /** Replace the function that receives every message. */
+  inline void setOutputFunction(void (*function)(const char *)){
+    f_ = function;
+  }
+
+  /** Forward `message` unchanged to the current output function. */
+  inline void operator()(const char *message){
+    f_(message);
+  }
+
+  // It is important to have a const char* overload here instead of
+  // just the string version, otherwise errno could be corrupted
+  // if there is some problem allocating memory when constructing
+  // the string.
+  void perror(const char *message, int errno_copy);
+  inline void perror(const std::string &message, int errno_copy) {
+    perror(message.c_str(), errno_copy);
+  }
+
+  void printf(const char *message, ...);
+
+  /** Default output function: writes "Thrift: <time> <msg>" to stderr. */
+  inline static void errorTimeWrapper(const char* msg) {
+    time_t now;
+    // ctime_r() stores 26 chars ("Www Mmm dd hh:mm:ss yyyy\n\0"), so the
+    // buffer must hold at least 26 bytes.
+    char dbgtime[26];
+    time(&now);
+    ctime_r(&now, dbgtime);
+    dbgtime[24] = 0; // drop ctime's trailing '\n'
+    fprintf(stderr, "Thrift: %s %s\n", dbgtime, msg);
+  }
+
+  /** Just like strerror_r but returns a C++ string object. */
+  static std::string strerror_s(int errno_copy);
+
+ private:
+  void (*f_)(const char *);
+};
+
+// Shared TOutput instance (defined in Thrift.cpp).
+extern TOutput GlobalOutput;
+
+// Forward declaration used by TApplicationException::read/write below. This
+// header does not include the protocol headers.
+namespace protocol {
+ class TProtocol;
+}
+
+/**
+ * Base exception of this library (TApplicationException derives from it).
+ *
+ * Carries an optional message. what() returns that message, or the fixed
+ * text "Default TException." when no message was supplied.
+ */
+class TException : public std::exception {
+ public:
+  TException() {}
+
+  TException(const std::string& message)
+    : message_(message) {}
+
+  virtual ~TException() throw() {}
+
+  virtual const char* what() const throw() {
+    return message_.empty() ? "Default TException." : message_.c_str();
+  }
+
+ protected:
+  std::string message_;
+};
+
+/**
+ * Exception carrying a TApplicationExceptionType error code plus an optional
+ * message. read()/write() (de)serialize it through a TProtocol.
+ */
+class TApplicationException : public TException {
+ public:
+
+  /**
+   * Error codes for the various types of exceptions.
+   */
+  enum TApplicationExceptionType
+  { UNKNOWN = 0
+  , UNKNOWN_METHOD = 1
+  , INVALID_MESSAGE_TYPE = 2
+  , WRONG_METHOD_NAME = 3
+  , BAD_SEQUENCE_ID = 4
+  , MISSING_RESULT = 5
+  };
+
+  TApplicationException() :
+    TException(),
+    type_(UNKNOWN) {}
+
+  TApplicationException(TApplicationExceptionType type) :
+    TException(),
+    type_(type) {}
+
+  TApplicationException(const std::string& message) :
+    TException(message),
+    type_(UNKNOWN) {}
+
+  TApplicationException(TApplicationExceptionType type,
+                        const std::string& message) :
+    TException(message),
+    type_(type) {}
+
+  virtual ~TApplicationException() throw() {}
+
+  /**
+   * Returns an error code that provides information about the type of error
+   * that has occurred.
+   *
+   * Declared const so it can be called through a const reference, e.g. in
+   * `catch (const TApplicationException& e)`.
+   *
+   * @return Error code
+   */
+  TApplicationExceptionType getType() const {
+    return type_;
+  }
+
+  /**
+   * Returns the explicit message if one was given, otherwise a fixed
+   * description of type_.
+   */
+  virtual const char* what() const throw() {
+    if (message_.empty()) {
+      switch (type_) {
+        case UNKNOWN              : return "TApplicationException: Unknown application exception";
+        case UNKNOWN_METHOD       : return "TApplicationException: Unknown method";
+        case INVALID_MESSAGE_TYPE : return "TApplicationException: Invalid message type";
+        case WRONG_METHOD_NAME    : return "TApplicationException: Wrong method name";
+        case BAD_SEQUENCE_ID      : return "TApplicationException: Bad sequence identifier";
+        case MISSING_RESULT       : return "TApplicationException: Missing result";
+        default                   : return "TApplicationException: (Invalid exception type)";
+      }
+    } else {
+      return message_.c_str();
+    }
+  }
+
+  /** Deserialize message and type from `iprot` (implemented in Thrift.cpp). */
+  uint32_t read(protocol::TProtocol* iprot);
+  /** Serialize message and type to `oprot` (implemented in Thrift.cpp). */
+  uint32_t write(protocol::TProtocol* oprot) const;
+
+ protected:
+  /**
+   * Error code
+   */
+  TApplicationExceptionType type_;
+
+};
+
+
+// Forward declare this structure used by TDenseProtocol
+// (its definition lives in TReflectionLocal.h).
+namespace reflection { namespace local {
+struct TypeSpec;
+}}
+
+
+}} // apache::thrift
+
+#endif // #ifndef _THRIFT_THRIFT_H_
diff --git a/lib/cpp/src/concurrency/Exception.h b/lib/cpp/src/concurrency/Exception.h
new file mode 100644
index 0000000..ec46629
--- /dev/null
+++ b/lib/cpp/src/concurrency/Exception.h
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_CONCURRENCY_EXCEPTION_H_
+#define _THRIFT_CONCURRENCY_EXCEPTION_H_ 1
+
+#include <exception>
+#include <Thrift.h>
+
+namespace apache { namespace thrift { namespace concurrency {
+
+/**
+ * Marker exception with no extra state. It has only the implicit default
+ * constructor, so what() returns TException's default text.
+ */
+class NoSuchTaskException
+  : public apache::thrift::TException {
+};
+
+/**
+ * Marker exception with no extra state. It has only the implicit default
+ * constructor, so what() returns TException's default text.
+ */
+class UncancellableTaskException
+  : public apache::thrift::TException {
+};
+
+/**
+ * Marker exception with no extra state. It has only the implicit default
+ * constructor, so what() returns TException's default text.
+ */
+class InvalidArgumentException
+  : public apache::thrift::TException {
+};
+
+/**
+ * Marker exception with no extra state. It has only the implicit default
+ * constructor, so what() returns TException's default text.
+ */
+class IllegalStateException
+  : public apache::thrift::TException {
+};
+
+/**
+ * Signals that a timed wait expired (Monitor's timed wait throws it on
+ * ETIMEDOUT). A default-constructed instance reports "TimedOutException"
+ * from what().
+ */
+class TimedOutException : public apache::thrift::TException {
+ public:
+  TimedOutException()
+    : TException("TimedOutException") {}
+
+  TimedOutException(const std::string& message)
+    : TException(message) {}
+};
+
+/**
+ * Exception whose default-constructed instances report
+ * "TooManyPendingTasksException" from what(). A custom message can be
+ * supplied instead.
+ */
+class TooManyPendingTasksException : public apache::thrift::TException {
+ public:
+  TooManyPendingTasksException()
+    : TException("TooManyPendingTasksException") {}
+
+  TooManyPendingTasksException(const std::string& message)
+    : TException(message) {}
+};
+
+/**
+ * Signals that an OS-level resource could not be obtained (for example,
+ * Monitor's implementation throws it when pthread mutex/condition
+ * initialization fails).
+ *
+ * Like TimedOutException and TooManyPendingTasksException, a
+ * default-constructed instance names itself in what()
+ * ("SystemResourceException") rather than reporting the generic
+ * TException default text.
+ */
+class SystemResourceException : public apache::thrift::TException {
+ public:
+  SystemResourceException()
+    : TException("SystemResourceException") {}
+
+  SystemResourceException(const std::string& message)
+    : TException(message) {}
+};
+
+}}} // apache::thrift::concurrency
+
+#endif // #ifndef _THRIFT_CONCURRENCY_EXCEPTION_H_
diff --git a/lib/cpp/src/concurrency/FunctionRunner.h b/lib/cpp/src/concurrency/FunctionRunner.h
new file mode 100644
index 0000000..2216927
--- /dev/null
+++ b/lib/cpp/src/concurrency/FunctionRunner.h
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_CONCURRENCY_FUNCTION_RUNNER_H
+#define _THRIFT_CONCURRENCY_FUNCTION_RUNNER_H 1
+
+#include <tr1/functional>
+#include "thrift/lib/cpp/concurrency/Thread.h"
+
+namespace apache { namespace thrift { namespace concurrency {
+
+/**
+ * Runnable adapter that executes an arbitrary callback when run() is called.
+ *
+ * Two kinds of callback are accepted: a generic void() function object, or a
+ * pthread_create-style `void* (*)(void*)` function together with its
+ * argument (the returned void* is discarded).
+ *
+ * Example use:
+ *   void* my_thread_main(void* arg);
+ *   shared_ptr<ThreadFactory> factory = ...;
+ *   shared_ptr<Thread> thread =
+ *     factory->newThread(shared_ptr<FunctionRunner>(
+ *       new FunctionRunner(my_thread_main, some_argument)));
+ *   thread->start();
+ */
+class FunctionRunner : public Runnable {
+ public:
+  /** Signature of the start routine expected by pthread_create(). */
+  typedef void* (*PthreadFuncPtr)(void *arg);
+
+  /** Fully generic void(void) callback, for custom bindings. */
+  typedef std::tr1::function<void()> VoidFunc;
+
+  /** On run(), call entry(argument); the void* result is ignored. */
+  FunctionRunner(PthreadFuncPtr entry, void* argument)
+    : func_(std::tr1::bind(entry, argument)) {
+  }
+
+  /** On run(), call a copy of `callback`. */
+  FunctionRunner(const VoidFunc& callback)
+    : func_(callback) {
+  }
+
+  /** Invoke the stored callback. */
+  void run() {
+    func_();
+  }
+
+ private:
+  VoidFunc func_;
+};
+
+}}} // apache::thrift::concurrency
+
+#endif // #ifndef _THRIFT_CONCURRENCY_FUNCTION_RUNNER_H
diff --git a/lib/cpp/src/concurrency/Monitor.cpp b/lib/cpp/src/concurrency/Monitor.cpp
new file mode 100644
index 0000000..2055caa
--- /dev/null
+++ b/lib/cpp/src/concurrency/Monitor.cpp
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "Monitor.h"
+#include "Exception.h"
+#include "Util.h"
+
+#include <assert.h>
+#include <errno.h>
+
+#include <iostream>
+
+#include <pthread.h>
+
+namespace apache { namespace thrift { namespace concurrency {
+
+/**
+ * Monitor implementation using the POSIX pthread library: one mutex plus one
+ * condition variable.
+ *
+ * @version $Id:$
+ */
+class Monitor::Impl {
+
+ public:
+
+  /**
+   * Initializes the mutex and the condition variable.
+   *
+   * @throws SystemResourceException if either pthread initializer fails;
+   *         whatever was already initialized is destroyed first.
+   */
+  Impl() :
+    mutexInitialized_(false),
+    condInitialized_(false) {
+
+    if (pthread_mutex_init(&pthread_mutex_, NULL) == 0) {
+      mutexInitialized_ = true;
+
+      if (pthread_cond_init(&pthread_cond_, NULL) == 0) {
+        condInitialized_ = true;
+      }
+    }
+
+    if (!mutexInitialized_ || !condInitialized_) {
+      cleanup();
+      throw SystemResourceException();
+    }
+  }
+
+  ~Impl() { cleanup(); }
+
+  void lock() const { pthread_mutex_lock(&pthread_mutex_); }
+
+  void unlock() const { pthread_mutex_unlock(&pthread_mutex_); }
+
+  /**
+   * Waits on the condition variable; the caller must hold the mutex.
+   *
+   * @param timeout 0 waits with no deadline. A positive value is a relative
+   *        timeout in the units of Util::currentTime() (presumably
+   *        milliseconds - confirm in Util.h).
+   * @throws TimedOutException if the timed wait reports ETIMEDOUT.
+   *
+   * As with any condition variable, the wait may also return without a
+   * notification (spurious wakeup), so callers should re-check their
+   * predicate.
+   */
+  void wait(int64_t timeout) const {
+
+    // XXX Need to assert that caller owns mutex
+    assert(timeout >= 0LL);
+    if (timeout == 0LL) {
+      int iret = pthread_cond_wait(&pthread_cond_, &pthread_mutex_);
+      assert(iret == 0);
+      (void)iret; // only read by assert(); avoids unused warnings with NDEBUG
+    } else {
+      struct timespec abstime;
+      int64_t now = Util::currentTime();
+      Util::toTimespec(abstime, now + timeout);
+      int result = pthread_cond_timedwait(&pthread_cond_,
+                                          &pthread_mutex_,
+                                          &abstime);
+      if (result == ETIMEDOUT) {
+        // pthread_cond_timedwait has been observed to return early on
+        // various platforms, so comment out this assert.
+        //assert(Util::currentTime() >= (now + timeout));
+        throw TimedOutException();
+      }
+    }
+  }
+
+  /** Wake one waiter. */
+  void notify() {
+    // XXX Need to assert that caller owns mutex
+    int iret = pthread_cond_signal(&pthread_cond_);
+    assert(iret == 0);
+    (void)iret; // only read by assert()
+  }
+
+  /** Wake all waiters. */
+  void notifyAll() {
+    // XXX Need to assert that caller owns mutex
+    int iret = pthread_cond_broadcast(&pthread_cond_);
+    assert(iret == 0);
+    (void)iret; // only read by assert()
+  }
+
+ private:
+
+  /** Destroys whichever primitives were successfully initialized. */
+  void cleanup() {
+    if (mutexInitialized_) {
+      mutexInitialized_ = false;
+      int iret = pthread_mutex_destroy(&pthread_mutex_);
+      assert(iret == 0);
+      (void)iret; // only read by assert()
+    }
+
+    if (condInitialized_) {
+      condInitialized_ = false;
+      int iret = pthread_cond_destroy(&pthread_cond_);
+      assert(iret == 0);
+      (void)iret; // only read by assert()
+    }
+  }
+
+  mutable pthread_mutex_t pthread_mutex_;
+  mutable bool mutexInitialized_;
+  mutable pthread_cond_t pthread_cond_;
+  mutable bool condInitialized_;
+};
+
+// Every Monitor operation forwards to the pthread-based Monitor::Impl owned
+// through impl_.
+
+Monitor::Monitor()
+  : impl_(new Monitor::Impl()) {
+}
+
+Monitor::~Monitor() {
+  delete impl_;
+}
+
+void Monitor::lock() const {
+  impl_->lock();
+}
+
+void Monitor::unlock() const {
+  impl_->unlock();
+}
+
+void Monitor::wait(int64_t timeout) const {
+  impl_->wait(timeout);
+}
+
+void Monitor::notify() const {
+  impl_->notify();
+}
+
+void Monitor::notifyAll() const {
+  impl_->notifyAll();
+}
+
+}}} // apache::thrift::concurrency
diff --git a/lib/cpp/src/concurrency/Monitor.h b/lib/cpp/src/concurrency/Monitor.h
new file mode 100644
index 0000000..234bf32
--- /dev/null
+++ b/lib/cpp/src/concurrency/Monitor.h
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_CONCURRENCY_MONITOR_H_
+#define _THRIFT_CONCURRENCY_MONITOR_H_ 1
+
+#include "Exception.h"
+
+namespace apache { namespace thrift { namespace concurrency {
+
+/**
+ * A monitor is a combination mutex and condition-event. Waiting and
+ * notifying condition events requires that the caller own the mutex. Mutex
+ * lock and unlock operations can be performed independently of condition
+ * events. This is more or less analogous to java.lang.Object multi-thread
+ * operations
+ *
+ * Note that all methods are const. Monitors implement logical constness, not
+ * bit constness. This allows const methods to call monitor methods without
+ * needing to cast away constness or change to non-const signatures.
+ *
+ * Monitors cannot be copied: each instance owns (and deletes) its
+ * implementation object.
+ *
+ * @version $Id:$
+ */
+class Monitor {
+
+ public:
+
+  /** @throws SystemResourceException if the pthread primitives cannot be created. */
+  Monitor();
+
+  virtual ~Monitor();
+
+  virtual void lock() const;
+
+  virtual void unlock() const;
+
+  /**
+   * Waits for a notification; the caller must hold the lock.
+   *
+   * @param timeout 0 (the default) waits with no deadline; a positive value
+   *        bounds the wait.
+   * @throws TimedOutException if a bounded wait times out.
+   */
+  virtual void wait(int64_t timeout=0LL) const;
+
+  virtual void notify() const;
+
+  virtual void notifyAll() const;
+
+ private:
+
+  // Declared but intentionally never defined. The destructor deletes impl_,
+  // so a compiler-generated copy would delete the same Impl twice.
+  Monitor(const Monitor&);
+  Monitor& operator=(const Monitor&);
+
+  class Impl;
+
+  Impl* impl_;
+};
+
+/**
+ * Scope guard for a Monitor: the constructor locks it and the destructor
+ * unlocks it, so the lock is held exactly for the guard's lifetime.
+ */
+class Synchronized {
+ private:
+  const Monitor& monitor_;
+
+ public:
+  Synchronized(const Monitor& value)
+    : monitor_(value) {
+    monitor_.lock();
+  }
+
+  ~Synchronized() {
+    monitor_.unlock();
+  }
+};
+
+
+}}} // apache::thrift::concurrency
+
+#endif // #ifndef _THRIFT_CONCURRENCY_MONITOR_H_
diff --git a/lib/cpp/src/concurrency/Mutex.cpp b/lib/cpp/src/concurrency/Mutex.cpp
new file mode 100644
index 0000000..045dbdf
--- /dev/null
+++ b/lib/cpp/src/concurrency/Mutex.cpp
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "Mutex.h"
+
+#include <assert.h>
+#include <pthread.h>
+
+using boost::shared_ptr;
+
+namespace apache { namespace thrift { namespace concurrency {
+
+/**
+ * POSIX-mutex backed implementation behind Mutex.
+ *
+ * The caller-supplied Initializer is responsible for setting up the embedded
+ * pthread_mutex_t; this class only destroys it again.
+ *
+ * @version $Id:$
+ */
+class Mutex::impl {
+ public:
+  impl(Initializer init) : initialized_(false) {
+    // Let the initialization strategy prepare the raw mutex, then remember
+    // that it has to be destroyed.
+    init(&pthread_mutex_);
+    initialized_ = true;
+  }
+
+  ~impl() {
+    if (!initialized_) {
+      return;
+    }
+    initialized_ = false;
+    const int rc = pthread_mutex_destroy(&pthread_mutex_);
+    assert(rc == 0);
+  }
+
+  void lock() const {
+    pthread_mutex_lock(&pthread_mutex_);
+  }
+
+  // True when the mutex was acquired without blocking.
+  bool trylock() const {
+    const int rc = pthread_mutex_trylock(&pthread_mutex_);
+    return rc == 0;
+  }
+
+  void unlock() const {
+    pthread_mutex_unlock(&pthread_mutex_);
+  }
+
+ private:
+  mutable pthread_mutex_t pthread_mutex_;
+  mutable bool initialized_;
+};
+
+// Builds the shared implementation, letting `init` set up the pthread mutex.
+Mutex::Mutex(Initializer init)
+  : impl_(new Mutex::impl(init)) {
+}
+
+// The locking operations below forward directly to the implementation.
+
+void Mutex::lock() const {
+  impl_->lock();
+}
+
+bool Mutex::trylock() const {
+  return impl_->trylock();
+}
+
+void Mutex::unlock() const {
+  impl_->unlock();
+}
+
+// Initializes the pthread_mutex_t pointed to by `arg` with default attributes.
+void Mutex::DEFAULT_INITIALIZER(void* arg) {
+  pthread_mutex_t* const raw = static_cast<pthread_mutex_t*>(arg);
+  const int rc = pthread_mutex_init(raw, NULL);
+  assert(rc == 0);
+}
+
+/**
+ * Initializes *mutex with mutex type `kind` (a value accepted by
+ * pthread_mutexattr_settype). Errors are only checked with assert(), so they
+ * go unnoticed when NDEBUG is defined.
+ */
+static void init_with_kind(pthread_mutex_t* mutex, int kind) {
+  pthread_mutexattr_t attr;
+  int status = pthread_mutexattr_init(&attr);
+  assert(status == 0);
+
+  // Apparently, this can fail. Should we really be aborting?
+  status = pthread_mutexattr_settype(&attr, kind);
+  assert(status == 0);
+
+  status = pthread_mutex_init(mutex, &attr);
+  assert(status == 0);
+
+  // The attribute object is not needed once the mutex has been created.
+  status = pthread_mutexattr_destroy(&attr);
+  assert(status == 0);
+}
+
+#ifdef PTHREAD_ADAPTIVE_MUTEX_INITIALIZER_NP
+/**
+ * Initializes the mutex at `arg` with the non-portable
+ * PTHREAD_MUTEX_ADAPTIVE_NP type. Only compiled where
+ * PTHREAD_ADAPTIVE_MUTEX_INITIALIZER_NP is defined.
+ *
+ * From mysql source: mysys/my_thr_init.c
+ * Set mutex type to "fast" a.k.a "adaptive"
+ *
+ * In this case the thread may steal the mutex from some other thread
+ * that is waiting for the same mutex. This will save us some
+ * context switches but may cause a thread to 'starve forever' while
+ * waiting for the mutex (not likely if the code within the mutex is
+ * short).
+ */
+void Mutex::ADAPTIVE_INITIALIZER(void* arg) {
+  pthread_mutex_t* const raw = static_cast<pthread_mutex_t*>(arg);
+  init_with_kind(raw, PTHREAD_MUTEX_ADAPTIVE_NP);
+}
+#endif
+
+/**
+ * Initializes the mutex at `arg` as a recursive mutex, so the owning thread
+ * may lock it again without deadlocking.
+ *
+ * Uses the POSIX PTHREAD_MUTEX_RECURSIVE type instead of the non-portable
+ * PTHREAD_MUTEX_RECURSIVE_NP alias so that it can be defined on every
+ * platform. Mutex.h declares RECURSIVE_INITIALIZER unconditionally, so a
+ * conditional definition turns any use of it into a link error wherever
+ * PTHREAD_RECURSIVE_MUTEX_INITIALIZER_NP is missing.
+ */
+void Mutex::RECURSIVE_INITIALIZER(void* arg) {
+  init_with_kind(static_cast<pthread_mutex_t*>(arg), PTHREAD_MUTEX_RECURSIVE);
+}
+
+
+/**
+ * Implementation of ReadWriteMutex class using a POSIX rw lock.
+ *
+ * @version $Id:$
+ */
+class ReadWriteMutex::impl {
+public:
+  impl() : initialized_(false) {
+    int ret = pthread_rwlock_init(&rw_lock_, NULL);
+    assert(ret == 0);
+    initialized_ = true;
+  }
+
+  ~impl() {
+    if(initialized_) {
+      initialized_ = false;
+      int ret = pthread_rwlock_destroy(&rw_lock_);
+      assert(ret == 0);
+    }
+  }
+
+  // Block until the shared (read) lock is held.
+  void acquireRead() const { pthread_rwlock_rdlock(&rw_lock_); }
+
+  // Block until the exclusive (write) lock is held.
+  void acquireWrite() const { pthread_rwlock_wrlock(&rw_lock_); }
+
+  // pthread_rwlock_try*lock return 0 on success and an error number (such as
+  // EBUSY) on failure. Converting the raw result to bool would report
+  // "acquired" exactly when the lock was NOT acquired, so compare with 0.
+  bool attemptRead() const { return pthread_rwlock_tryrdlock(&rw_lock_) == 0; }
+
+  bool attemptWrite() const { return pthread_rwlock_trywrlock(&rw_lock_) == 0; }
+
+  // Releases whichever lock (read or write) the calling thread holds.
+  void release() const { pthread_rwlock_unlock(&rw_lock_); }
+
+private:
+  mutable pthread_rwlock_t rw_lock_;
+  mutable bool initialized_;
+};
+
+// Creates the rw lock; copies of this object share the same implementation.
+ReadWriteMutex::ReadWriteMutex()
+  : impl_(new ReadWriteMutex::impl()) {
+}
+
+// Everything below is a thin forward to ReadWriteMutex::impl.
+
+void ReadWriteMutex::acquireRead() const {
+  impl_->acquireRead();
+}
+
+void ReadWriteMutex::acquireWrite() const {
+  impl_->acquireWrite();
+}
+
+bool ReadWriteMutex::attemptRead() const {
+  return impl_->attemptRead();
+}
+
+bool ReadWriteMutex::attemptWrite() const {
+  return impl_->attemptWrite();
+}
+
+void ReadWriteMutex::release() const {
+  impl_->release();
+}
+
+}}} // apache::thrift::concurrency
+
diff --git a/lib/cpp/src/concurrency/Mutex.h b/lib/cpp/src/concurrency/Mutex.h
new file mode 100644
index 0000000..884412b
--- /dev/null
+++ b/lib/cpp/src/concurrency/Mutex.h
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_CONCURRENCY_MUTEX_H_
+#define _THRIFT_CONCURRENCY_MUTEX_H_ 1
+
+#include <boost/shared_ptr.hpp>
+
+namespace apache { namespace thrift { namespace concurrency {
+
+/**
+ * A simple mutex class.
+ *
+ * The underlying lock lives in a private implementation held through a
+ * boost::shared_ptr, so copies of a Mutex refer to the same lock.
+ *
+ * @version $Id:$
+ */
+class Mutex {
+ public:
+  // Callback that initializes the raw lock object passed in as void*.
+  typedef void (*Initializer)(void*);
+
+  // Creates the mutex, running `init` on the underlying lock object.
+  Mutex(Initializer init = DEFAULT_INITIALIZER);
+  virtual ~Mutex() {}
+  virtual void lock() const;
+  // Returns true if the lock was acquired without blocking.
+  virtual bool trylock() const;
+  virtual void unlock() const;
+
+  // Ready-made initializers for the constructor.
+  // NOTE(review): Mutex.cpp defines ADAPTIVE_INITIALIZER only when the
+  // platform provides PTHREAD_ADAPTIVE_MUTEX_INITIALIZER_NP; referencing it
+  // elsewhere fails at link time.
+  static void DEFAULT_INITIALIZER(void*);
+  static void ADAPTIVE_INITIALIZER(void*);
+  static void RECURSIVE_INITIALIZER(void*);
+
+ private:
+
+  class impl;
+  boost::shared_ptr<impl> impl_;
+};
+
+/**
+ * Reader/writer lock: many concurrent readers or one exclusive writer.
+ * Like Mutex, the lock is held through a boost::shared_ptr, so copies refer
+ * to the same lock.
+ */
+class ReadWriteMutex {
+public:
+  ReadWriteMutex();
+  virtual ~ReadWriteMutex() {}
+
+  // these get the lock and block until it is done successfully
+  virtual void acquireRead() const;
+  virtual void acquireWrite() const;
+
+  // these attempt to get the lock without blocking: true if it was acquired,
+  // false (immediately) if it was not
+  virtual bool attemptRead() const;
+  virtual bool attemptWrite() const;
+
+  // this releases both read and write locks
+  virtual void release() const;
+
+private:
+
+  class impl;
+  boost::shared_ptr<impl> impl_;
+};
+
+/**
+ * Scoped lock for a Mutex: locks on construction and unlocks on destruction.
+ * Non-copyable, because a copy would unlock the mutex a second time when it
+ * is destroyed.
+ */
+class Guard {
+ public:
+  Guard(const Mutex& value) : mutex_(value) {
+    mutex_.lock();
+  }
+  ~Guard() {
+    mutex_.unlock();
+  }
+
+ private:
+  // Non-copyable: declared but intentionally not defined.
+  Guard(const Guard&);
+  Guard& operator=(const Guard&);
+
+  const Mutex& mutex_;
+};
+
+/**
+ * Scoped lock for a ReadWriteMutex. Takes the write lock when `write` is
+ * true, otherwise the read lock, and releases it on destruction.
+ * Non-copyable, because a copy would release the lock a second time.
+ */
+class RWGuard {
+ public:
+  RWGuard(const ReadWriteMutex& value, bool write = false) : rw_mutex_(value) {
+    if (write) {
+      rw_mutex_.acquireWrite();
+    } else {
+      rw_mutex_.acquireRead();
+    }
+  }
+  ~RWGuard() {
+    rw_mutex_.release();
+  }
+
+ private:
+  // Non-copyable: declared but intentionally not defined.
+  RWGuard(const RWGuard&);
+  RWGuard& operator=(const RWGuard&);
+
+  const ReadWriteMutex& rw_mutex_;
+};
+
+
+// A little hack to prevent someone from trying to do "Guard(m);", which does
+// not produce a guard that holds the lock for the rest of the scope. These
+// function-like macros rewrite that spelling into a call to an undeclared
+// function, so it fails to compile. A named guard such as "Guard g(m);" is
+// unaffected, because the macro only expands when the name is immediately
+// followed by '('.
+// Sorry for polluting the global namespace, but I think it's worth it.
+#define Guard(m) incorrect_use_of_Guard(m)
+#define RWGuard(m) incorrect_use_of_RWGuard(m)
+
+
+}}} // apache::thrift::concurrency
+
+#endif // #ifndef _THRIFT_CONCURRENCY_MUTEX_H_
diff --git a/lib/cpp/src/concurrency/PosixThreadFactory.cpp b/lib/cpp/src/concurrency/PosixThreadFactory.cpp
new file mode 100644
index 0000000..e48dce3
--- /dev/null
+++ b/lib/cpp/src/concurrency/PosixThreadFactory.cpp
@@ -0,0 +1,308 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "PosixThreadFactory.h"
+#include "Exception.h"
+
+#if GOOGLE_PERFTOOLS_REGISTER_THREAD
+# include <google/profiler.h>
+#endif
+
+#include <assert.h>
+#include <pthread.h>
+
+#include <iostream>
+
+#include <boost/weak_ptr.hpp>
+
+namespace apache { namespace thrift { namespace concurrency {
+
+using boost::shared_ptr;
+using boost::weak_ptr;
+
+/**
+ * The POSIX thread class.
+ *
+ * Wraps one pthread bound to a Runnable. PosixThreadFactory gives each
+ * instance a weak reference to itself (weakRef) so that start() can hand the
+ * new thread a strong reference that keeps this object alive while it runs.
+ *
+ * @version $Id:$
+ */
+class PthreadThread: public Thread {
+ public:
+
+  enum STATE {
+    uninitialized,
+    starting,
+    started,
+    stopping,
+    stopped
+  };
+
+  static const int MB = 1024 * 1024;
+
+  static void* threadMain(void* arg);
+
+ private:
+  pthread_t pthread_;
+  STATE state_;
+  int policy_;
+  int priority_;
+  int stackSize_;
+  weak_ptr<PthreadThread> self_;
+  bool detached_;
+
+  /**
+   * Destroys a pthread_attr_t when it goes out of scope, so that start()
+   * releases the attribute object on every exit path, including throws.
+   */
+  class AttrGuard {
+   public:
+    explicit AttrGuard(pthread_attr_t* attr) : attr_(attr) {}
+    ~AttrGuard() { pthread_attr_destroy(attr_); }
+   private:
+    AttrGuard(const AttrGuard&);
+    AttrGuard& operator=(const AttrGuard&);
+    pthread_attr_t* attr_;
+  };
+
+ public:
+
+  PthreadThread(int policy, int priority, int stackSize, bool detached, shared_ptr<Runnable> runnable) :
+    pthread_(0),
+    state_(uninitialized),
+    policy_(policy),
+    priority_(priority),
+    stackSize_(stackSize),
+    detached_(detached) {
+
+    this->Thread::runnable(runnable);
+  }
+
+  ~PthreadThread() {
+    /* Nothing references this thread. If it is not detached, do a join
+       now, otherwise the thread-id and, possibly, other resources will
+       be leaked. */
+    if(!detached_) {
+      try {
+        join();
+      } catch(...) {
+        // We're really hosed.
+      }
+    }
+  }
+
+  /**
+   * Creates the underlying pthread with the configured detach state, stack
+   * size (in megabytes), scheduling policy and priority. Does nothing if the
+   * thread has already been started.
+   *
+   * @throws SystemResourceException if a pthread call fails. In that case no
+   *         thread exists and this object is back in the uninitialized state.
+   */
+  void start() {
+    if (state_ != uninitialized) {
+      return;
+    }
+
+    pthread_attr_t thread_attr;
+    if (pthread_attr_init(&thread_attr) != 0) {
+        throw SystemResourceException("pthread_attr_init failed");
+    }
+    // From here on thread_attr is initialized and must always be destroyed.
+    AttrGuard attrGuard(&thread_attr);
+
+    if(pthread_attr_setdetachstate(&thread_attr,
+                                   detached_ ?
+                                   PTHREAD_CREATE_DETACHED :
+                                   PTHREAD_CREATE_JOINABLE) != 0) {
+        throw SystemResourceException("pthread_attr_setdetachstate failed");
+    }
+
+    // Set thread stack size. Multiply in size_t so large megabyte counts do
+    // not overflow int.
+    if (pthread_attr_setstacksize(&thread_attr, static_cast<size_t>(MB) * stackSize_) != 0) {
+      throw SystemResourceException("pthread_attr_setstacksize failed");
+    }
+
+    // Set thread policy
+    if (pthread_attr_setschedpolicy(&thread_attr, policy_) != 0) {
+      throw SystemResourceException("pthread_attr_setschedpolicy failed");
+    }
+
+    struct sched_param sched_param;
+    sched_param.sched_priority = priority_;
+
+    // Set thread priority
+    if (pthread_attr_setschedparam(&thread_attr, &sched_param) != 0) {
+      throw SystemResourceException("pthread_attr_setschedparam failed");
+    }
+
+    // Give the new thread its own strong reference to this object.
+    // threadMain copies it and deletes the heap-allocated shared_ptr.
+    shared_ptr<PthreadThread>* selfRef = new shared_ptr<PthreadThread>();
+    *selfRef = self_.lock();
+
+    state_ = starting;
+
+    if (pthread_create(&pthread_, &thread_attr, threadMain, (void*)selfRef) != 0) {
+      // No thread was created: threadMain will never free selfRef, and
+      // join() (called from the destructor) must not try to join a thread
+      // that does not exist.
+      delete selfRef;
+      state_ = uninitialized;
+      throw SystemResourceException("pthread_create failed");
+    }
+  }
+
+  void join() {
+    if (!detached_ && state_ != uninitialized) {
+      void* ignore;
+      /* XXX
+         If join fails it is most likely due to the fact
+         that the last reference was the thread itself and cannot
+         join.  This results in leaked threads and will eventually
+         cause the process to run out of thread resources.
+         We're beyond the point of throwing an exception.  Not clear how
+         best to handle this. */
+      detached_ = pthread_join(pthread_, &ignore) == 0;
+    }
+  }
+
+  Thread::id_t getId() {
+    return (Thread::id_t)pthread_;
+  }
+
+  shared_ptr<Runnable> runnable() const { return Thread::runnable(); }
+
+  void runnable(shared_ptr<Runnable> value) { Thread::runnable(value); }
+
+  // Records the owning shared_ptr so start() can create strong references.
+  void weakRef(shared_ptr<PthreadThread> self) {
+    assert(self.get() == this);
+    self_ = weak_ptr<PthreadThread>(self);
+  }
+};
+
+/**
+ * pthread entry point. `arg` is the heap-allocated shared_ptr created by
+ * PthreadThread::start(). It is copied and then deleted here, so the thread
+ * object stays alive for as long as this function runs.
+ */
+void* PthreadThread::threadMain(void* arg) {
+  shared_ptr<PthreadThread> thread = *(shared_ptr<PthreadThread>*)arg;
+  delete reinterpret_cast<shared_ptr<PthreadThread>*>(arg);
+
+  if (thread == NULL) {
+    return (void*)0;
+  }
+
+  if (thread->state_ != starting) {
+    return (void*)0;
+  }
+
+#if GOOGLE_PERFTOOLS_REGISTER_THREAD
+  ProfilerRegisterThread();
+#endif
+
+  // start() left the state at `starting`; the thread is now actually running.
+  thread->state_ = started;
+  thread->runnable()->run();
+  if (thread->state_ != stopping && thread->state_ != stopped) {
+    thread->state_ = stopping;
+  }
+
+  return (void*)0;
+}
+
+/**
+ * POSIX Thread factory implementation
+ */
+class PosixThreadFactory::Impl {
+
+ private:
+  POLICY policy_;
+  PRIORITY priority_;
+  int stackSize_;
+  bool detached_;
+
+  /**
+   * Converts generic posix thread schedule policy enums into pthread
+   * API values.
+   */
+  static int toPthreadPolicy(POLICY policy) {
+    switch (policy) {
+    case OTHER:
+      return SCHED_OTHER;
+    case FIFO:
+      return SCHED_FIFO;
+    case ROUND_ROBIN:
+      return SCHED_RR;
+    }
+    return SCHED_OTHER;
+  }
+
+  /**
+   * Converts relative thread priorities to absolute value based on posix
+   * thread scheduler policy
+   *
+   * The idea is simply to divide up the priority range for the given policy
+   * into the corresponding relative priority level (lowest..highest) and
+   * then pro-rate accordingly.
+   */
+  static int toPthreadPriority(POLICY policy, PRIORITY priority) {
+    int pthread_policy = toPthreadPolicy(policy);
+    int min_priority = sched_get_priority_min(pthread_policy);
+    int max_priority = sched_get_priority_max(pthread_policy);
+    int quanta = (HIGHEST - LOWEST) + 1;
+    // Divide in floating point: integer division would truncate the step
+    // size and squeeze every relative level toward the bottom of the range.
+    float stepsperquanta = (float)(max_priority - min_priority) / quanta;
+
+    if (priority <= HIGHEST) {
+      return (int)(min_priority + stepsperquanta * priority);
+    } else {
+      // should never get here for priority increments.
+      assert(false);
+      return (int)(min_priority + stepsperquanta * NORMAL);
+    }
+  }
+
+ public:
+
+  Impl(POLICY policy, PRIORITY priority, int stackSize, bool detached) :
+    policy_(policy),
+    priority_(priority),
+    stackSize_(stackSize),
+    detached_(detached) {}
+
+  /**
+   * Creates a new POSIX thread to run the runnable object
+   *
+   * @param runnable A runnable object
+   */
+  shared_ptr<Thread> newThread(shared_ptr<Runnable> runnable) const {
+    shared_ptr<PthreadThread> result = shared_ptr<PthreadThread>(new PthreadThread(toPthreadPolicy(policy_), toPthreadPriority(policy_, priority_), stackSize_, detached_, runnable));
+    result->weakRef(result);
+    runnable->thread(result);
+    return result;
+  }
+
+  int getStackSize() const { return stackSize_; }
+
+  void setStackSize(int value) { stackSize_ = value; }
+
+  PRIORITY getPriority() const { return priority_; }
+
+  /**
+   * Sets priority.
+   *
+   *  XXX
+   *  Need to handle incremental priorities properly.
+   */
+  void setPriority(PRIORITY value) { priority_ = value; }
+
+  bool isDetached() const { return detached_; }
+
+  void setDetached(bool value) { detached_ = value; }
+
+  Thread::id_t getCurrentThreadId() const {
+    // Must match PthreadThread::getId(), which casts straight to
+    // Thread::id_t. An unqualified `id_t` here would name POSIX ::id_t from
+    // <sys/types.h>, which can be narrower than pthread_t, making ids from
+    // this function disagree with those recorded from getId().
+    // TODO(dreiss): Stop using C-style casts.
+    return (Thread::id_t)pthread_self();
+  }
+
+};
+
+PosixThreadFactory::PosixThreadFactory(POLICY policy, PRIORITY priority, int stackSize, bool detached)
+  : impl_(new PosixThreadFactory::Impl(policy, priority, stackSize, detached)) {
+}
+
+// The public API below forwards unchanged to the private Impl.
+
+shared_ptr<Thread> PosixThreadFactory::newThread(shared_ptr<Runnable> runnable) const {
+  return impl_->newThread(runnable);
+}
+
+int PosixThreadFactory::getStackSize() const {
+  return impl_->getStackSize();
+}
+
+void PosixThreadFactory::setStackSize(int value) {
+  impl_->setStackSize(value);
+}
+
+PosixThreadFactory::PRIORITY PosixThreadFactory::getPriority() const {
+  return impl_->getPriority();
+}
+
+void PosixThreadFactory::setPriority(PosixThreadFactory::PRIORITY value) {
+  impl_->setPriority(value);
+}
+
+bool PosixThreadFactory::isDetached() const {
+  return impl_->isDetached();
+}
+
+void PosixThreadFactory::setDetached(bool value) {
+  impl_->setDetached(value);
+}
+
+Thread::id_t PosixThreadFactory::getCurrentThreadId() const {
+  return impl_->getCurrentThreadId();
+}
+
+}}} // apache::thrift::concurrency
diff --git a/lib/cpp/src/concurrency/PosixThreadFactory.h b/lib/cpp/src/concurrency/PosixThreadFactory.h
new file mode 100644
index 0000000..d6d83a3
--- /dev/null
+++ b/lib/cpp/src/concurrency/PosixThreadFactory.h
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_CONCURRENCY_POSIXTHREADFACTORY_H_
+#define _THRIFT_CONCURRENCY_POSIXTHREADFACTORY_H_ 1
+
+#include "Thread.h"
+
+#include <boost/shared_ptr.hpp>
+
+namespace apache { namespace thrift { namespace concurrency {
+
+/**
+ * A thread factory to create posix threads
+ *
+ * @version $Id:$
+ */
+class PosixThreadFactory : public ThreadFactory {
+
+ public:
+
+  /**
+   * POSIX Thread scheduler policies
+   */
+  enum POLICY {
+    OTHER,
+    FIFO,
+    ROUND_ROBIN
+  };
+
+  /**
+   * POSIX Thread scheduler relative priorities,
+   *
+   * Absolute priority is determined by scheduler policy and OS. This
+   * enumeration specifies relative priorities such that one can specify a
+   * priority within a given scheduler policy without knowing the absolute
+   * value of the priority.
+   *
+   * NOTE(review): INCREMENT and DECREMENT are not mapped by the
+   * implementation's priority conversion, which asserts on values above
+   * HIGHEST.
+   */
+  enum PRIORITY {
+    LOWEST = 0,
+    LOWER = 1,
+    LOW = 2,
+    NORMAL = 3,
+    HIGH = 4,
+    HIGHER = 5,
+    HIGHEST = 6,
+    INCREMENT = 7,
+    DECREMENT = 8
+  };
+
+  /**
+   * Posix thread (pthread) factory.  All threads created by a factory are reference-counted
+   * via boost::shared_ptr and boost::weak_ptr.  The factory guarantees that threads and
+   * the Runnable tasks they host will be properly cleaned up once the last strong reference
+   * to both is given up.
+   *
+   * Threads are created with the specified policy, priority, stack-size and detachable-mode.
+   * Detached means the thread is free-running and will release all system resources
+   * when it completes.  A detachable thread is not joinable.  The join method
+   * of a detachable thread will return immediately with no error.
+   *
+   * By default threads are detached, i.e. not joinable.
+   *
+   * @param policy    scheduler policy (default ROUND_ROBIN)
+   * @param priority  priority relative to the policy's range (default NORMAL)
+   * @param stackSize stack size in megabytes (default 1)
+   * @param detached  whether created threads are detached (default true)
+   */
+
+  PosixThreadFactory(POLICY policy=ROUND_ROBIN, PRIORITY priority=NORMAL, int stackSize=1, bool detached=true);
+
+  // From ThreadFactory;
+  boost::shared_ptr<Thread> newThread(boost::shared_ptr<Runnable> runnable) const;
+
+  // From ThreadFactory;
+  Thread::id_t getCurrentThreadId() const;
+
+  /**
+   * Gets stack size for created threads
+   *
+   * @return int size in megabytes
+   */
+  virtual int getStackSize() const;
+
+  /**
+   * Sets stack size for created threads
+   *
+   * @param value size in megabytes
+   */
+  virtual void setStackSize(int value);
+
+  /**
+   * Gets priority relative to current policy
+   */
+  virtual PRIORITY getPriority() const;
+
+  /**
+   * Sets priority relative to current policy
+   */
+  virtual void setPriority(PRIORITY priority);
+
+  /**
+   * Sets detached mode of threads
+   */
+  virtual void setDetached(bool detached);
+
+  /**
+   * Gets current detached mode
+   */
+  virtual bool isDetached() const;
+
+ private:
+  class Impl;
+  boost::shared_ptr<Impl> impl_;
+};
+
+}}} // apache::thrift::concurrency
+
+#endif // #ifndef _THRIFT_CONCURRENCY_POSIXTHREADFACTORY_H_
diff --git a/lib/cpp/src/concurrency/Thread.h b/lib/cpp/src/concurrency/Thread.h
new file mode 100644
index 0000000..d4282ad
--- /dev/null
+++ b/lib/cpp/src/concurrency/Thread.h
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_CONCURRENCY_THREAD_H_
+#define _THRIFT_CONCURRENCY_THREAD_H_ 1
+
+#include <stdint.h>
+#include <boost/shared_ptr.hpp>
+#include <boost/weak_ptr.hpp>
+
+namespace apache { namespace thrift { namespace concurrency {
+
+class Thread;
+
+/**
+ * Minimal runnable class.  More or less analogous to java.lang.Runnable.
+ *
+ * The hosting Thread is remembered only through a boost::weak_ptr, so a
+ * Runnable does not keep its thread object alive.
+ *
+ * @version $Id:$
+ */
+class Runnable {
+
+ public:
+  virtual ~Runnable() {};
+  // The work to perform; supplied by subclasses.
+  virtual void run() = 0;
+
+  /**
+   * Gets the thread object that is hosting this runnable object  - can return
+   * an empty boost::shared_ptr if no references remain on that thread object.
+   */
+  virtual boost::shared_ptr<Thread> thread() { return thread_.lock(); }
+
+  /**
+   * Sets the thread that is executing this object.  This is only meant for
+   * use by Thread / ThreadFactory implementations, not by user code.
+   */
+  virtual void thread(boost::shared_ptr<Thread> value) { thread_ = value; }
+
+ private:
+  boost::weak_ptr<Thread> thread_;
+};
+
+/**
+ * Minimal thread class. Returned by a thread factory, bound to a Runnable
+ * object and ready to start execution. More or less analogous to
+ * java.lang.Thread (minus all the thread group, priority, mode and other
+ * baggage, since that is difficult to abstract across platforms and is left
+ * for platform-specific ThreadFactory implementations to deal with).
+ *
+ * @see apache::thrift::concurrency::ThreadFactory
+ */
+class Thread {
+
+ public:
+
+  // Platform-neutral thread identifier type returned by getId().
+  typedef uint64_t id_t;
+
+  virtual ~Thread() {};
+
+  /**
+   * Starts the thread. Does platform specific thread creation and
+   * configuration then invokes the run method of the Runnable object bound
+   * to this thread.
+   */
+  virtual void start() = 0;
+
+  /**
+   * Join this thread. Current thread blocks until this target thread
+   * completes.
+   */
+  virtual void join() = 0;
+
+  /**
+   * Gets the thread's platform-specific ID
+   */
+  virtual id_t getId() = 0;
+
+  /**
+   * Gets the runnable object this thread is hosting
+   */
+  virtual boost::shared_ptr<Runnable> runnable() const { return _runnable; }
+
+ protected:
+  // Binds the Runnable this thread executes; for use by subclasses.
+  virtual void runnable(boost::shared_ptr<Runnable> value) { _runnable = value; }
+
+ private:
+  // Strong reference: the thread keeps its Runnable alive.
+  boost::shared_ptr<Runnable> _runnable;
+
+};
+
+/**
+ * Factory to create platform-specific thread objects and bind them to
+ * Runnable objects for execution.
+ */
+class ThreadFactory {
+
+ public:
+  virtual ~ThreadFactory() {}
+  // Creates a Thread bound to `runnable`; the caller is expected to start() it.
+  virtual boost::shared_ptr<Thread> newThread(boost::shared_ptr<Runnable> runnable) const = 0;
+
+  // Sentinel id for threads that were not created by a thrift factory.
+  // NOTE(review): its value is defined outside this header.
+  static const Thread::id_t unknown_thread_id;
+
+  /** Gets the current thread id or unknown_thread_id if the current thread is not a thrift thread.
+   *  NOTE(review): PosixThreadFactory returns pthread_self() for any thread. */
+  virtual Thread::id_t getCurrentThreadId() const = 0;
+};
+
+}}} // apache::thrift::concurrency
+
+#endif // #ifndef _THRIFT_CONCURRENCY_THREAD_H_
diff --git a/lib/cpp/src/concurrency/ThreadManager.cpp b/lib/cpp/src/concurrency/ThreadManager.cpp
new file mode 100644
index 0000000..abfcf6e
--- /dev/null
+++ b/lib/cpp/src/concurrency/ThreadManager.cpp
@@ -0,0 +1,493 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "ThreadManager.h"
+#include "Exception.h"
+#include "Monitor.h"
+
+#include <boost/shared_ptr.hpp>
+
+#include <assert.h>
+#include <queue>
+#include <set>
+
+#if defined(DEBUG)
+#include <iostream>
+#endif //defined(DEBUG)
+
+namespace apache { namespace thrift { namespace concurrency {
+
+using boost::shared_ptr;
+using boost::dynamic_pointer_cast;
+
+/**
+ * ThreadManager class
+ *
+ * Default ThreadManager implementation. It keeps a queue of pending Tasks and
+ * a pool of Worker threads that are created through the configured
+ * ThreadFactory (addWorker) and retired on request (removeWorker). It tracks
+ * the worker count, idle-worker count and pending-task backlog.
+ * monitor_ guards the task queue and counters; workerMonitor_ is used to wait
+ * until the live worker count reaches its target.
+ *
+ * @version $Id:$
+ */
+class ThreadManager::Impl : public ThreadManager {
+
+ public:
+  Impl() :
+    workerCount_(0),
+    workerMaxCount_(0),
+    idleCount_(0),
+    pendingTaskCountMax_(0),
+    state_(ThreadManager::UNINITIALIZED) {}
+
+  ~Impl() { stop(); }
+
+  void start();
+
+  void stop() { stopImpl(false); }
+
+  void join() { stopImpl(true); }
+
+  // Read without the monitor; the result may already be stale when used.
+  const ThreadManager::STATE state() const {
+    return state_;
+  }
+
+  shared_ptr<ThreadFactory> threadFactory() const {
+    Synchronized s(monitor_);
+    return threadFactory_;
+  }
+
+  void threadFactory(shared_ptr<ThreadFactory> value) {
+    Synchronized s(monitor_);
+    threadFactory_ = value;
+  }
+
+  void addWorker(size_t value);
+
+  void removeWorker(size_t value);
+
+  // Workers update idleCount_ while holding monitor_, so read it under the
+  // monitor too, like the other counter accessors below.
+  size_t idleWorkerCount() const {
+    Synchronized s(monitor_);
+    return idleCount_;
+  }
+
+  size_t workerCount() const {
+    Synchronized s(monitor_);
+    return workerCount_;
+  }
+
+  size_t pendingTaskCount() const {
+    Synchronized s(monitor_);
+    return tasks_.size();
+  }
+
+  // Queued tasks plus tasks currently being run by busy (non-idle) workers.
+  size_t totalTaskCount() const {
+    Synchronized s(monitor_);
+    return tasks_.size() + workerCount_ - idleCount_;
+  }
+
+  // 0 means the pending-task queue is unbounded.
+  size_t pendingTaskCountMax() const {
+    Synchronized s(monitor_);
+    return pendingTaskCountMax_;
+  }
+
+  void pendingTaskCountMax(const size_t value) {
+    Synchronized s(monitor_);
+    pendingTaskCountMax_ = value;
+  }
+
+  bool canSleep();
+
+  void add(shared_ptr<Runnable> value, int64_t timeout);
+
+  void remove(shared_ptr<Runnable> task);
+
+private:
+  void stopImpl(bool join);
+
+  size_t workerCount_;
+  size_t workerMaxCount_;
+  size_t idleCount_;
+  size_t pendingTaskCountMax_;
+
+  ThreadManager::STATE state_;
+  shared_ptr<ThreadFactory> threadFactory_;
+
+
+  friend class ThreadManager::Task;
+  std::queue<shared_ptr<Task> > tasks_;
+  Monitor monitor_;
+  Monitor workerMonitor_;
+
+  friend class ThreadManager::Worker;
+  std::set<shared_ptr<Thread> > workers_;
+  std::set<shared_ptr<Thread> > deadWorkers_;
+  // Worker thread ids -> threads; consulted by canSleep().
+  std::map<const Thread::id_t, shared_ptr<Thread> > idMap_;
+};
+
+/**
+ * Queue entry wrapping a user Runnable. Workers move a task from WAITING to
+ * EXECUTING before running it.
+ */
+class ThreadManager::Task : public Runnable {
+
+ public:
+  enum STATE {
+    WAITING,
+    EXECUTING,
+    CANCELLED,
+    COMPLETE
+  };
+
+  Task(shared_ptr<Runnable> runnable)
+    : runnable_(runnable),
+      state_(WAITING) {
+  }
+
+  ~Task() {}
+
+  // Runs the wrapped runnable only when the task is EXECUTING, and marks it
+  // COMPLETE once the runnable returns normally.
+  void run() {
+    if (state_ != EXECUTING) {
+      return;
+    }
+    runnable_->run();
+    state_ = COMPLETE;
+  }
+
+ private:
+  shared_ptr<Runnable> runnable_;
+  friend class ThreadManager::Worker;
+  STATE state_;
+};
+
+/**
+ * Runnable executed by each pool thread. A worker first registers itself in
+ * manager_->workerCount_, then repeatedly takes Tasks off the manager's queue
+ * and runs them until isActive() turns false. On exit it records its thread
+ * in deadWorkers_ so that removeWorker() can reclaim it.
+ */
+class ThreadManager::Worker: public Runnable {
+  enum STATE {
+    UNINITIALIZED,
+    STARTING,
+    STARTED,
+    STOPPING,
+    STOPPED
+  };
+
+ public:
+  Worker(ThreadManager::Impl* manager) :
+    manager_(manager),
+    state_(UNINITIALIZED),
+    idle_(false) {}
+
+  ~Worker() {}
+
+ private:
+  /**
+   * A worker stays active while the pool is not above its target size, or,
+   * when the manager is JOINING, while tasks remain queued. Every caller in
+   * this class invokes it with manager_->monitor_ held.
+   */
+  bool isActive() const {
+    return
+      (manager_->workerCount_ <= manager_->workerMaxCount_) ||
+      (manager_->state_ == JOINING && !manager_->tasks_.empty());
+  }
+
+ public:
+  /**
+   * Worker entry point
+   *
+   * As long as worker thread is running, pull tasks off the task queue and
+   * execute.
+   */
+  void run() {
+    bool active = false;
+    bool notifyManager = false;
+
+    /**
+     * Increment worker semaphore and notify manager if worker count reached
+     * desired max
+     *
+     * Note: We have to release the monitor and acquire the workerMonitor
+     * since that is what the manager blocks on for worker add/remove
+     */
+    {
+      Synchronized s(manager_->monitor_);
+      active = manager_->workerCount_ < manager_->workerMaxCount_;
+      if (active) {
+        manager_->workerCount_++;
+        notifyManager = manager_->workerCount_ == manager_->workerMaxCount_;
+      }
+    }
+
+    // Wake addWorker(), which waits on workerMonitor_ for the target count.
+    if (notifyManager) {
+      Synchronized s(manager_->workerMonitor_);
+      manager_->workerMonitor_.notify();
+      notifyManager = false;
+    }
+
+    while (active) {
+      shared_ptr<ThreadManager::Task> task;
+
+      /**
+       * While holding manager monitor block for non-empty task queue (Also
+       * check that the thread hasn't been requested to stop). Once the queue
+       * is non-empty, dequeue a task, release monitor, and execute. If the
+       * worker max count has been decremented such that we exceed it, mark
+       * ourself inactive, decrement the worker count and notify the manager
+       * (technically we're notifying the next blocked thread but eventually
+       * the manager will see it.
+       */
+      {
+        Synchronized s(manager_->monitor_);
+        active = isActive();
+
+        // Count ourselves idle while waiting so add()/removeWorker() know
+        // whom to wake.
+        while (active && manager_->tasks_.empty()) {
+          manager_->idleCount_++;
+          idle_ = true;
+          manager_->monitor_.wait();
+          active = isActive();
+          idle_ = false;
+          manager_->idleCount_--;
+        }
+
+        if (active) {
+          if (!manager_->tasks_.empty()) {
+            task = manager_->tasks_.front();
+            manager_->tasks_.pop();
+            // Only WAITING tasks are promoted; any other state is skipped below.
+            if (task->state_ == ThreadManager::Task::WAITING) {
+              task->state_ = ThreadManager::Task::EXECUTING;
+            }
+
+            /* If we have a pending task max and we just dropped below it, wakeup any
+               thread that might be blocked on add. */
+            if (manager_->pendingTaskCountMax_ != 0 &&
+                manager_->tasks_.size() == manager_->pendingTaskCountMax_ - 1) {
+              manager_->monitor_.notify();
+            }
+          }
+        } else {
+          // Retiring: give up our slot and tell removeWorker() once the
+          // count reaches the new target.
+          idle_ = true;
+          manager_->workerCount_--;
+          notifyManager = (manager_->workerCount_ == manager_->workerMaxCount_);
+        }
+      }
+
+      // The task runs outside the monitor; exceptions are swallowed so one
+      // failing task does not kill the worker.
+      if (task != NULL) {
+        if (task->state_ == ThreadManager::Task::EXECUTING) {
+          try {
+            task->run();
+          } catch(...) {
+            // XXX need to log this
+          }
+        }
+      }
+    }
+
+    {
+      Synchronized s(manager_->workerMonitor_);
+      manager_->deadWorkers_.insert(this->thread());
+      if (notifyManager) {
+        manager_->workerMonitor_.notify();
+      }
+    }
+
+    return;
+  }
+
+ private:
+  ThreadManager::Impl* manager_;
+  friend class ThreadManager::Impl;
+  // Set to STARTING by addWorker(); not otherwise consulted in this class.
+  STATE state_;
+  bool idle_;
+};
+
+
+ /**
+  * Creates `value` worker threads through the thread factory, starts them,
+  * and blocks until every worker has registered itself
+  * (workerCount_ == workerMaxCount_).
+  */
+ void ThreadManager::Impl::addWorker(size_t value) {
+  std::set<shared_ptr<Thread> > newThreads;
+  for (size_t ix = 0; ix < value; ix++) {
+    shared_ptr<ThreadManager::Worker> worker = shared_ptr<ThreadManager::Worker>(new ThreadManager::Worker(this));
+    newThreads.insert(threadFactory_->newThread(worker));
+  }
+
+  {
+    Synchronized s(monitor_);
+    workerMaxCount_ += value;
+    workers_.insert(newThreads.begin(), newThreads.end());
+  }
+
+  for (std::set<shared_ptr<Thread> >::iterator ix = newThreads.begin(); ix != newThreads.end(); ix++) {
+    shared_ptr<ThreadManager::Worker> worker = dynamic_pointer_cast<ThreadManager::Worker, Runnable>((*ix)->runnable());
+    worker->state_ = ThreadManager::Worker::STARTING;
+    (*ix)->start();
+    // idMap_ is read by canSleep() from add() while monitor_ is held, so it
+    // must be modified under the same monitor. The thread is started first,
+    // without the lock, because the worker itself needs monitor_ to register.
+    Synchronized s(monitor_);
+    idMap_.insert(std::pair<const Thread::id_t, shared_ptr<Thread> >((*ix)->getId(), *ix));
+  }
+
+  {
+    Synchronized s(workerMonitor_);
+    while (workerCount_ != workerMaxCount_) {
+      workerMonitor_.wait();
+    }
+  }
+}
+
+/**
+ * Moves the manager from UNINITIALIZED to STARTED and wakes any waiters.
+ * Does nothing once the manager has been STOPPED.
+ *
+ * @throws InvalidArgumentException if no thread factory has been set.
+ */
+void ThreadManager::Impl::start() {
+  if (state_ == ThreadManager::STOPPED) {
+    return;
+  }
+
+  Synchronized s(monitor_);
+
+  if (state_ == ThreadManager::UNINITIALIZED) {
+    if (threadFactory_ == NULL) {
+      throw InvalidArgumentException();
+    }
+    state_ = ThreadManager::STARTED;
+    monitor_.notifyAll();
+  }
+
+  // Wait out a STARTING state (this function itself never sets it).
+  while (state_ == STARTING) {
+    monitor_.wait();
+  }
+}
+
+/**
+ * Shared implementation of stop() (join == false) and join() (join == true).
+ * The first caller moves the state to STOPPING or JOINING and retires every
+ * registered worker. Afterwards the state is set to STOPPED. Calls made once
+ * the manager is STOPPED return immediately.
+ */
+void ThreadManager::Impl::stopImpl(bool join) {
+  bool doStop = false;
+  size_t workersToRemove = 0;
+  if (state_ == ThreadManager::STOPPED) {
+    return;
+  }
+
+  {
+    Synchronized s(monitor_);
+    if (state_ != ThreadManager::STOPPING &&
+        state_ != ThreadManager::JOINING &&
+        state_ != ThreadManager::STOPPED) {
+      doStop = true;
+      state_ = join ? ThreadManager::JOINING : ThreadManager::STOPPING;
+      // Workers update workerCount_ under monitor_; take the snapshot here
+      // instead of reading it unsynchronized after the lock is released.
+      workersToRemove = workerCount_;
+    }
+  }
+
+  if (doStop) {
+    removeWorker(workersToRemove);
+  }
+
+  // XXX
+  // should be able to block here for transition to STOPPED since we're no
+  // using shared_ptrs
+
+  {
+    Synchronized s(monitor_);
+    state_ = ThreadManager::STOPPED;
+  }
+
+}
+
+/**
+ * Lowers the target worker count by `value`, wakes idle workers so that the
+ * surplus ones notice and exit, waits until workerCount_ matches the new
+ * target, then forgets the threads of workers that have exited.
+ *
+ * @throws InvalidArgumentException if value exceeds the current target.
+ */
+void ThreadManager::Impl::removeWorker(size_t value) {
+  {
+    Synchronized s(monitor_);
+    if (value > workerMaxCount_) {
+      throw InvalidArgumentException();
+    }
+
+    workerMaxCount_ -= value;
+
+    if (idleCount_ < value) {
+      for (size_t ix = 0; ix < idleCount_; ix++) {
+        monitor_.notify();
+      }
+    } else {
+      monitor_.notifyAll();
+    }
+  }
+
+  {
+    Synchronized s(workerMonitor_);
+
+    while (workerCount_ != workerMaxCount_) {
+      workerMonitor_.wait();
+    }
+
+    // workers_ is filled by addWorker() under monitor_, and idMap_ is read by
+    // canSleep() from add() under monitor_, so hold it while erasing. Lock
+    // order here is workerMonitor_ -> monitor_; no code in this file takes
+    // workerMonitor_ while already holding monitor_.
+    Synchronized m(monitor_);
+    for (std::set<shared_ptr<Thread> >::iterator ix = deadWorkers_.begin(); ix != deadWorkers_.end(); ix++) {
+      workers_.erase(*ix);
+      idMap_.erase((*ix)->getId());
+    }
+
+    deadWorkers_.clear();
+  }
+}
+
+ /**
+  * True when the calling thread is not one of this manager's workers (its id
+  * is absent from idMap_). add() calls this with monitor_ held to decide
+  * whether it may wait for queue space.
+  */
+ bool ThreadManager::Impl::canSleep() {
+  const Thread::id_t caller = threadFactory_->getCurrentThreadId();
+  return idMap_.count(caller) == 0;
+ }
+
+ /**
+  * Queues `value` for execution by a worker.
+  *
+  * If a pending-task limit is set and the queue is full, a caller that is
+  * not a worker thread (see canSleep) and passes timeout >= 0 waits on the
+  * monitor, with that timeout, until there is room. Any other caller gets
+  * TooManyPendingTasksException.
+  *
+  * @throws IllegalStateException if the manager is not STARTED.
+  */
+ void ThreadManager::Impl::add(shared_ptr<Runnable> value, int64_t timeout) {
+  Synchronized s(monitor_);
+
+  if (state_ != ThreadManager::STARTED) {
+    throw IllegalStateException();
+  }
+
+  const bool queueFull = pendingTaskCountMax_ > 0 && tasks_.size() >= pendingTaskCountMax_;
+  if (queueFull) {
+    if (!canSleep() || timeout < 0) {
+      throw TooManyPendingTasksException();
+    }
+    while (pendingTaskCountMax_ > 0 && tasks_.size() >= pendingTaskCountMax_) {
+      monitor_.wait(timeout);
+    }
+  }
+
+  tasks_.push(shared_ptr<ThreadManager::Task>(new ThreadManager::Task(value)));
+
+  // Wake one idle worker if any; busy workers will reach the task when they
+  // next check the queue.
+  if (idleCount_ > 0) {
+    monitor_.notify();
+  }
+ }
+
+/**
+ * Placeholder: only verifies that the manager is STARTED; the task is not
+ * actually taken off the queue.
+ *
+ * @throws IllegalStateException if the manager is not STARTED.
+ */
+void ThreadManager::Impl::remove(shared_ptr<Runnable> task) {
+  (void)task;  // NOTE(review): task removal is not implemented.
+  Synchronized s(monitor_);
+  if (state_ == ThreadManager::STARTED) {
+    return;
+  }
+  throw IllegalStateException();
+}
+
+/**
+ * ThreadManager with a fixed worker count and an optional pending-task
+ * limit (0 = unlimited), both applied when start() is called.
+ */
+class SimpleThreadManager : public ThreadManager::Impl {
+
+ public:
+  SimpleThreadManager(size_t workerCount=4, size_t pendingTaskCountMax=0) :
+    workerCount_(workerCount),
+    pendingTaskCountMax_(pendingTaskCountMax) {
+  }
+
+  // Applies the task limit, starts the manager and spawns workerCount_
+  // workers. NOTE(review): workers are added on every call, including calls
+  // made after the manager has been stopped.
+  void start() {
+    ThreadManager::Impl::pendingTaskCountMax(pendingTaskCountMax_);
+    ThreadManager::Impl::start();
+    addWorker(workerCount_);
+  }
+
+ private:
+  // The former unused `Monitor monitor_` (which shadowed the base class's
+  // monitor_ name and allocated its own lock) and `firstTime_` flag were
+  // removed; all synchronization goes through ThreadManager::Impl.
+  const size_t workerCount_;
+  const size_t pendingTaskCountMax_;
+};
+
+
+/**
+ * Creates a bare ThreadManager with no thread factory and no workers.
+ */
+shared_ptr<ThreadManager> ThreadManager::newThreadManager() {
+  ThreadManager::Impl* manager = new ThreadManager::Impl();
+  return shared_ptr<ThreadManager>(manager);
+}
+
+/**
+ * Creates a SimpleThreadManager that adds `count` workers and applies
+ * `pendingTaskCountMax` (0 = unlimited) when start() is called.
+ */
+shared_ptr<ThreadManager> ThreadManager::newSimpleThreadManager(size_t count, size_t pendingTaskCountMax) {
+  SimpleThreadManager* manager = new SimpleThreadManager(count, pendingTaskCountMax);
+  return shared_ptr<ThreadManager>(manager);
+}
+
+}}} // apache::thrift::concurrency
+
diff --git a/lib/cpp/src/concurrency/ThreadManager.h b/lib/cpp/src/concurrency/ThreadManager.h
new file mode 100644
index 0000000..6e5a178
--- /dev/null
+++ b/lib/cpp/src/concurrency/ThreadManager.h
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_CONCURRENCY_THREADMANAGER_H_
+#define _THRIFT_CONCURRENCY_THREADMANAGER_H_ 1
+
+#include <boost/shared_ptr.hpp>
+#include <sys/types.h>
+#include "Thread.h"
+
+namespace apache { namespace thrift { namespace concurrency {
+
+/**
+ * Thread Pool Manager and related classes
+ *
+ * @version $Id:$
+ */
+class ThreadManager;
+
+/**
+ * ThreadManager class
+ *
+ * This class manages a pool of worker threads, which it creates through a
+ * ThreadFactory (see threadFactory()). Callers queue Runnable tasks with
+ * add() and worker threads execute them. The pool is resized explicitly via
+ * addWorker() and removeWorker(); the intended design is for a policy object
+ * (called PoolPolicy in these notes) to watch idle/active thread counts and
+ * the task backlog and call those methods when the pool size should change.
+ * NOTE(review): no PoolPolicy implementation is visible alongside this
+ * class; confirm before relying on one.
+ *
+ * This design lets different policy implementations reuse this code for
+ * basic worker thread management and task execution and focus on policy
+ * issues. The simplest policy just creates a fixed number of threads, as
+ * newSimpleThreadManager() does.
+ */
+class ThreadManager {
+
+ protected:
+ ThreadManager() {}
+
+ public:
+ virtual ~ThreadManager() {}
+
+ /**
+ * Starts the thread manager. Verifies all attributes have been properly
+ * initialized, then allocates necessary resources to begin operation.
+ */
+ virtual void start() = 0;
+
+ /**
+ * Stops the thread manager. Aborts all remaining unprocessed tasks, shuts
+ * down all created worker threads, and releases all allocated resources.
+ * This method blocks for all worker threads to complete, thus it can
+ * potentially block forever if a worker thread is running a task that
+ * won't terminate.
+ */
+ virtual void stop() = 0;
+
+ /**
+ * Joins the thread manager. This is the same as stop, except that it will
+ * block until all the workers have finished their work. At that point
+ * the ThreadManager will transition into the STOPPED state.
+ */
+ virtual void join() = 0;
+
+ enum STATE {
+ UNINITIALIZED,
+ STARTING,
+ STARTED,
+ JOINING,
+ STOPPING,
+ STOPPED
+ };
+ /** Returns the current lifecycle state. */
+ virtual const STATE state() const = 0;
+ /** Gets the factory used to create worker threads. */
+ virtual boost::shared_ptr<ThreadFactory> threadFactory() const = 0;
+ /** Sets the factory used to create worker threads. */
+ virtual void threadFactory(boost::shared_ptr<ThreadFactory> value) = 0;
+ /** Grows the worker pool by value threads. */
+ virtual void addWorker(size_t value=1) = 0;
+ /** Shrinks the worker pool by value threads. */
+ virtual void removeWorker(size_t value=1) = 0;
+
+ /**
+ * Gets the current number of idle worker threads
+ */
+ virtual size_t idleWorkerCount() const = 0;
+
+ /**
+ * Gets the current number of total worker threads
+ */
+ virtual size_t workerCount() const = 0;
+
+ /**
+ * Gets the current number of pending tasks
+ */
+ virtual size_t pendingTaskCount() const = 0;
+
+ /**
+ * Gets the current number of pending and executing tasks
+ */
+ virtual size_t totalTaskCount() const = 0;
+
+ /**
+ * Gets the maximum pending task count. 0 indicates no maximum.
+ */
+ virtual size_t pendingTaskCountMax() const = 0;
+
+ /**
+ * Adds a task to be executed at some time in the future by a worker thread.
+ *
+ * This method will block if pendingTaskCountMax() is not zero and pendingTaskCount()
+ * is greater than or equal to pendingTaskCountMax(). If this method is called in the
+ * context of a ThreadManager worker thread it will throw a
+ * TooManyPendingTasksException
+ *
+ * @param task The task to queue for execution
+ *
+ * @param timeout Time to wait in milliseconds to add a task when a pending-task-count
+ * is specified. Specific cases:
+ * timeout = 0 : Wait forever to queue task.
+ * timeout < 0 (e.g. -1) : Throw immediately if the pending task count is at the max.
+ * NOTE(review): a positive timeout that expires appears to surface as TimedOutException.
+ * @throws TooManyPendingTasksException Pending task count is at or above the maximum
+ */
+ virtual void add(boost::shared_ptr<Runnable>task, int64_t timeout=0LL) = 0;
+
+ /**
+ * Removes a pending task. NOTE(review): Impl::remove currently only checks the state.
+ */
+ virtual void remove(boost::shared_ptr<Runnable> task) = 0;
+ /** Creates a plain ThreadManager::Impl; this factory adds no worker threads. */
+ static boost::shared_ptr<ThreadManager> newThreadManager();
+
+ /**
+ * Creates a simple thread manager that uses count worker threads and allows at most
+ * pendingTaskCountMax pending tasks. The default, 0, specifies no limit
+ * on pending tasks.
+ */
+ static boost::shared_ptr<ThreadManager> newSimpleThreadManager(size_t count=4, size_t pendingTaskCountMax=0);
+
+ class Task;
+
+ class Worker;
+
+ class Impl;
+};
+
+}}} // apache::thrift::concurrency
+
+#endif // #ifndef _THRIFT_CONCURRENCY_THREADMANAGER_H_
diff --git a/lib/cpp/src/concurrency/TimerManager.cpp b/lib/cpp/src/concurrency/TimerManager.cpp
new file mode 100644
index 0000000..25515dc
--- /dev/null
+++ b/lib/cpp/src/concurrency/TimerManager.cpp
@@ -0,0 +1,284 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "TimerManager.h"
+#include "Exception.h"
+#include "Util.h"
+
+#include <assert.h>
+#include <iostream>
+#include <set>
+
+namespace apache { namespace thrift { namespace concurrency {
+
+using boost::shared_ptr;
+// Iterator (and iterator pair) over TimerManager::taskMap_, keyed by absolute expiration time in ms.
+typedef std::multimap<int64_t, shared_ptr<TimerManager::Task> >::iterator task_iterator;
+typedef std::pair<task_iterator, task_iterator> task_range;
+
+/**
+ * TimerManager::Task
+ *
+ * Wraps a scheduled Runnable together with a state that the Dispatcher advances.
+ */
+class TimerManager::Task : public Runnable {
+ public:
+  enum STATE {
+    WAITING,
+    EXECUTING,
+    CANCELLED,
+    COMPLETE
+  };
+
+  explicit Task(shared_ptr<Runnable> runnable) :
+    runnable_(runnable),
+    state_(WAITING) {}
+
+  ~Task() {
+  }
+
+  // Runs the wrapped Runnable only if the Dispatcher marked this task EXECUTING.
+  void run() {
+    if (state_ == EXECUTING) {
+      runnable_->run();
+      state_ = COMPLETE;
+    }
+  }
+
+ private:
+  shared_ptr<Runnable> runnable_;
+  // The Dispatcher sets state_ to EXECUTING while holding TimerManager::monitor_.
+  friend class TimerManager::Dispatcher;
+  STATE state_;
+};
+
+class TimerManager::Dispatcher: public Runnable {
+ public:
+  Dispatcher(TimerManager* manager) :
+    manager_(manager) {}
+  ~Dispatcher() {}
+
+  /**
+   * Dispatcher entry point: marks the manager STARTED, then repeatedly waits
+   * for the earliest task in taskMap_ to expire, removes all expired tasks
+   * while holding the monitor and runs them outside it in expiration order.
+   * Exits once the manager leaves STARTED, moving STOPPING to STOPPED.
+   */
+  void run() {
+    {
+      Synchronized s(manager_->monitor_);
+      if (manager_->state_ == TimerManager::STARTING) {
+        manager_->state_ = TimerManager::STARTED;
+        manager_->monitor_.notifyAll();
+      }
+    }
+
+    for (;;) {
+      // Keyed by expiration time (ties keep insertion order), so each batch
+      // runs in the order its tasks fell due rather than by pointer value.
+      std::multimap<int64_t, shared_ptr<TimerManager::Task> > expiredTasks;
+      {
+        Synchronized s(manager_->monitor_);
+        task_iterator expiredTaskEnd;
+        int64_t now = Util::currentTime();
+        while (manager_->state_ == TimerManager::STARTED &&
+               (expiredTaskEnd = manager_->taskMap_.upper_bound(now)) == manager_->taskMap_.begin()) {
+          int64_t timeout = 0LL;  // stays 0 only when nothing is scheduled (see assert)
+          if (!manager_->taskMap_.empty()) {
+            timeout = manager_->taskMap_.begin()->first - now;
+          }
+          assert((timeout != 0 && manager_->taskCount_ > 0) || (timeout == 0 && manager_->taskCount_ == 0));
+          try {
+            manager_->monitor_.wait(timeout);
+          } catch (TimedOutException&) {}
+          now = Util::currentTime();
+        }
+        // Read the state while still holding the monitor; break releases it.
+        if (manager_->state_ != TimerManager::STARTED) {
+          break;
+        }
+
+        for (task_iterator ix = manager_->taskMap_.begin(); ix != expiredTaskEnd; ++ix) {
+          if (ix->second->state_ == TimerManager::Task::WAITING) {
+            ix->second->state_ = TimerManager::Task::EXECUTING;
+          }
+          manager_->taskCount_--;
+        }
+        expiredTasks.insert(manager_->taskMap_.begin(), expiredTaskEnd);
+        manager_->taskMap_.erase(manager_->taskMap_.begin(), expiredTaskEnd);
+      }
+
+      // Run outside the lock, so a task may itself call add() without deadlocking.
+      for (task_iterator ix = expiredTasks.begin(); ix != expiredTasks.end(); ++ix) {
+        ix->second->run();
+      }
+    }
+
+    {
+      Synchronized s(manager_->monitor_);
+      if (manager_->state_ == TimerManager::STOPPING) {
+        manager_->state_ = TimerManager::STOPPED;
+        manager_->monitor_.notifyAll();  // wake every thread blocked in stop()
+      }
+    }
+  }
+
+ private:
+  TimerManager* manager_;
+  friend class TimerManager;
+};
+
+// The dispatcher object is created up front; its thread is only spawned by start().
+TimerManager::TimerManager() :
+  taskCount_(0), state_(TimerManager::UNINITIALIZED),
+  dispatcher_(new Dispatcher(this)) {
+}
+
+
+TimerManager::~TimerManager() {
+
+  // If we haven't been explicitly stopped, do so now. We don't need to grab
+  // the monitor here, since stop already takes care of reentrancy.
+
+  if (state_ != STOPPED) {
+    try {
+      stop();
+    } catch (...) {
+      // Destructors must not throw (rethrowing here would terminate the
+      // process), so stopping is best-effort at this point.
+    }
+  }
+}
+
+void TimerManager::start() {
+  bool doStart = false;
+  {
+    Synchronized s(monitor_);
+    if (!threadFactory_) {  // the dispatcher thread cannot be created without a factory
+      throw InvalidArgumentException();
+    }
+    if (state_ == TimerManager::UNINITIALIZED) {
+      state_ = TimerManager::STARTING;
+      doStart = true;
+    }
+  }
+  // Only the caller that moved UNINITIALIZED -> STARTING spawns the dispatcher.
+  if (doStart) {
+    dispatcherThread_ = threadFactory_->newThread(dispatcher_);
+    dispatcherThread_->start();
+  }
+
+  // Wait for the dispatcher to move the state past STARTING.
+  {
+    Synchronized s(monitor_);
+    while (state_ == TimerManager::STARTING) {
+      monitor_.wait();
+    }
+  }
+}
+
+void TimerManager::stop() {
+  bool doStop = false;
+  {
+    Synchronized s(monitor_);
+    if (state_ == TimerManager::UNINITIALIZED) {
+      state_ = TimerManager::STOPPED;
+    } else if (state_ != STOPPING && state_ != STOPPED) {
+      doStop = true;
+      state_ = STOPPING;
+      monitor_.notifyAll();
+    }
+    // The dispatcher moves STOPPING to STOPPED when its loop exits.
+    while (state_ != STOPPED) {
+      monitor_.wait();
+    }
+    if (doStop) {
+      // Discard tasks that never fell due (clear(): erasing inside an iterator loop is UB).
+      taskMap_.clear();
+      taskCount_ = 0;
+    }
+  }
+
+  if (doStop) {
+    dispatcher_->manager_ = NULL;  // remove the dispatcher's reference to us
+  }
+}
+
+// Returns the factory start() uses to spawn the dispatcher thread.
+shared_ptr<const ThreadFactory> TimerManager::threadFactory() const {
+  Synchronized guard(monitor_); return threadFactory_;
+}
+
+// Replaces that factory; it is only read by start().
+void TimerManager::threadFactory(shared_ptr<const ThreadFactory> value) {
+  Synchronized guard(monitor_); threadFactory_ = value;
+}
+
+// Number of added tasks not yet dispatched. taskCount_ is modified under monitor_
+// (by add() and the dispatcher), so read it under the same lock.
+size_t TimerManager::taskCount() const { Synchronized s(monitor_); return taskCount_; }
+
+void TimerManager::add(shared_ptr<Runnable> task, int64_t timeout) {
+  int64_t now = Util::currentTime();
+  timeout += now;  // relative delay -> absolute expiration time (ms)
+
+  {
+    Synchronized s(monitor_);
+    if (state_ != TimerManager::STARTED) {
+      throw IllegalStateException();
+    }
+
+    // Decide whether to kick the dispatcher BEFORE inserting: it must re-arm its
+    // wait if the map was empty or this task expires before every queued one.
+    // (After insertion begin() may be this very task, so that test could never pass.)
+    bool notifyRequired = taskMap_.empty() || timeout < taskMap_.begin()->first;
+    taskCount_++;
+    taskMap_.insert(std::pair<int64_t, shared_ptr<Task> >(timeout, shared_ptr<Task>(new Task(task))));
+    if (notifyRequired) {
+      monitor_.notify();
+    }
+  }
+}
+
+void TimerManager::add(shared_ptr<Runnable> task, const struct timespec& value) {
+  // Convert the absolute deadline into a delay relative to now, then delegate.
+  int64_t expiration;
+  Util::toMilliseconds(expiration, value);
+  const int64_t now = Util::currentTime();
+
+  if (expiration < now) {
+    throw InvalidArgumentException();  // the deadline has already passed
+  }
+
+  const int64_t delay = expiration - now;
+  add(task, delay);
+}
+
+
+void TimerManager::remove(shared_ptr<Runnable> task) {
+  // NOTE(review): only the STARTED precondition is enforced; `task` is not removed.
+  Synchronized s(monitor_);
+  if (state_ != TimerManager::STARTED) { throw IllegalStateException(); }
+  (void)task;
+}
+// After construction state_ is only written while holding monitor_, so read it under that lock.
+const TimerManager::STATE TimerManager::state() const { Synchronized s(monitor_); return state_; }
+
+}}} // apache::thrift::concurrency
+
diff --git a/lib/cpp/src/concurrency/TimerManager.h b/lib/cpp/src/concurrency/TimerManager.h
new file mode 100644
index 0000000..f3f799f
--- /dev/null
+++ b/lib/cpp/src/concurrency/TimerManager.h
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_CONCURRENCY_TIMERMANAGER_H_
+#define _THRIFT_CONCURRENCY_TIMERMANAGER_H_ 1
+
+#include "Exception.h"
+#include "Monitor.h"
+#include "Thread.h"
+
+#include <boost/shared_ptr.hpp>
+#include <map>
+#include <time.h>
+
+namespace apache { namespace thrift { namespace concurrency {
+
+/**
+ * Timer Manager
+ *
+ * Dispatches timer tasks when they fall due: a single dispatcher thread, created
+ * by start() from the configured ThreadFactory, runs each expired task itself.
+ * @version $Id:$
+ */
+class TimerManager {
+
+ public:
+ /** Creates a manager in the UNINITIALIZED state; add() requires start() first. */
+ TimerManager();
+ /** Calls stop() unless the manager is already STOPPED. */
+ virtual ~TimerManager();
+ /** Gets the factory start() uses to create the dispatcher thread. */
+ virtual boost::shared_ptr<const ThreadFactory> threadFactory() const;
+ /** Sets the factory start() uses to create the dispatcher thread. */
+ virtual void threadFactory(boost::shared_ptr<const ThreadFactory> value);
+
+ /**
+ * Starts the timer manager service
+ *
+ * @throws InvalidArgumentException Missing thread factory attribute
+ */
+ virtual void start();
+
+ /**
+ * Stops the timer manager service; blocks until the manager reaches STOPPED.
+ */
+ virtual void stop();
+ /** Gets the number of added tasks that have not yet been dispatched. */
+ virtual size_t taskCount() const ;
+
+ /**
+ * Adds a task to be executed at some time in the future by the dispatcher thread.
+ *
+ * @param task The task to execute
+ * @param timeout Time in milliseconds to delay before executing task
+ */
+ virtual void add(boost::shared_ptr<Runnable> task, int64_t timeout);
+
+ /**
+ * Adds a task to be executed at an absolute time by the dispatcher thread.
+ * @param task The task to execute
+ * @param timeout Absolute time in the future to execute task.
+ * @throws InvalidArgumentException timeout is already in the past
+ */
+ virtual void add(boost::shared_ptr<Runnable> task, const struct timespec& timeout);
+
+ /**
+ * Removes a pending task
+ *
+ * NOTE(review): the current implementation only checks that the manager is
+ * STARTED (throwing IllegalStateException otherwise); it does not remove the
+ * task and never throws the exceptions below, which describe the intent.
+ *
+ * @throws NoSuchTaskException Specified task doesn't exist (intended)
+ * @throws UncancellableTaskException Task is executing or has completed (intended)
+ */
+ virtual void remove(boost::shared_ptr<Runnable> task);
+ /** Lifecycle states; see state(). */
+ enum STATE {
+ UNINITIALIZED,
+ STARTING,
+ STARTED,
+ STOPPING,
+ STOPPED
+ };
+ /** Returns the current lifecycle state. */
+ virtual const STATE state() const;
+
+ private:
+ boost::shared_ptr<const ThreadFactory> threadFactory_;
+ class Task;
+ friend class Task;
+ std::multimap<int64_t, boost::shared_ptr<Task> > taskMap_;
+ size_t taskCount_;
+ Monitor monitor_;
+ STATE state_;
+ class Dispatcher;
+ friend class Dispatcher;
+ boost::shared_ptr<Dispatcher> dispatcher_;
+ boost::shared_ptr<Thread> dispatcherThread_;
+};
+
+}}} // apache::thrift::concurrency
+
+#endif // #ifndef _THRIFT_CONCURRENCY_TIMERMANAGER_H_
diff --git a/lib/cpp/src/concurrency/Util.cpp b/lib/cpp/src/concurrency/Util.cpp
new file mode 100644
index 0000000..1c44937
--- /dev/null
+++ b/lib/cpp/src/concurrency/Util.cpp
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "Util.h"
+
+#ifdef HAVE_CONFIG_H
+#include <config.h>
+#endif
+
+#if defined(HAVE_CLOCK_GETTIME)
+#include <time.h>
+#elif defined(HAVE_GETTIMEOFDAY)
+#include <sys/time.h>
+#endif // defined(HAVE_CLOCK_GETTIME)
+
+namespace apache { namespace thrift { namespace concurrency {
+
+const int64_t Util::currentTime() {
+  // Wall-clock time in ms since the epoch; 0 if the clock call fails in an NDEBUG build.
+  int64_t result = 0;
+#if defined(HAVE_CLOCK_GETTIME)
+  struct timespec now;
+  int ret = clock_gettime(CLOCK_REALTIME, &now);
+  assert(ret == 0);
+  if (ret == 0) { toMilliseconds(result, now); }  // never convert an unfilled struct
+#elif defined(HAVE_GETTIMEOFDAY)
+  struct timeval now;
+  int ret = gettimeofday(&now, NULL);
+  assert(ret == 0);
+  if (ret == 0) { toMilliseconds(result, now); }
+#else
+#error "No high-precision clock is available."
+#endif // defined(HAVE_CLOCK_GETTIME)
+
+  return result;
+}
+
+
+}}} // apache::thrift::concurrency
diff --git a/lib/cpp/src/concurrency/Util.h b/lib/cpp/src/concurrency/Util.h
new file mode 100644
index 0000000..25fcc20
--- /dev/null
+++ b/lib/cpp/src/concurrency/Util.h
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_CONCURRENCY_UTIL_H_
+#define _THRIFT_CONCURRENCY_UTIL_H_ 1
+
+#include <assert.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <time.h>
+#include <sys/time.h>
+
+namespace apache { namespace thrift { namespace concurrency {
+
+/**
+ * Utility methods
+ *
+ * This class contains basic utility methods for converting time formats,
+ * and other common platform-dependent concurrency operations.
+ * It should not be included in API headers for other concurrency library
+ * headers, since it will, by definition, pull in all sorts of
+ * platform-dependent headers. Rather, it should be included directly in
+ * concurrency library implementation source.
+ *
+ * @version $Id:$
+ */
+class Util {
+  // Conversion factors between seconds, milliseconds, microseconds and nanoseconds.
+  static const int64_t NS_PER_S = 1000000000LL;
+  static const int64_t US_PER_S = 1000000LL;
+  static const int64_t MS_PER_S = 1000LL;
+
+  static const int64_t NS_PER_MS = NS_PER_S / MS_PER_S;
+  static const int64_t US_PER_MS = US_PER_S / MS_PER_S;
+
+ public:
+
+  /**
+   * Converts a millisecond timestamp or duration into a timespec struct.
+   *
+   * @param result Receives value split into seconds and nanoseconds
+   * @param value Time or duration in milliseconds (assumed non-negative)
+   */
+  static void toTimespec(struct timespec& result, int64_t value) {
+    result.tv_sec = value / MS_PER_S; // ms to s
+    result.tv_nsec = (value % MS_PER_S) * NS_PER_MS; // ms to ns
+  }
+  /** Converts a millisecond timestamp or duration into a timeval struct. */
+  static void toTimeval(struct timeval& result, int64_t value) {
+    result.tv_sec = value / MS_PER_S; // ms to s
+    result.tv_usec = (value % MS_PER_S) * US_PER_MS; // ms to us
+  }
+
+  /**
+   * Converts struct timespec to milliseconds, rounding to the nearest ms (half up).
+   */
+  static void toMilliseconds(int64_t& result, const struct timespec& value) {
+    result = (value.tv_sec * MS_PER_S) + (value.tv_nsec / NS_PER_MS);
+    // round up -- int64_t cast is to avoid a compiler error for some GCCs
+    if (int64_t(value.tv_nsec) % NS_PER_MS >= (NS_PER_MS / 2)) {
+      ++result;
+    }
+  }
+
+  /**
+   * Converts struct timeval to milliseconds, rounding to the nearest ms (half up).
+   */
+  static void toMilliseconds(int64_t& result, const struct timeval& value) {
+    result = (value.tv_sec * MS_PER_S) + (value.tv_usec / US_PER_MS);
+    // round up -- int64_t cast is to avoid a compiler error for some GCCs
+    if (int64_t(value.tv_usec) % US_PER_MS >= (US_PER_MS / 2)) {
+      ++result;
+    }
+  }
+
+  /**
+   * Gets the current wall-clock time in milliseconds since the epoch.
+   */
+  static const int64_t currentTime();
+};
+
+}}} // apache::thrift::concurrency
+
+#endif // #ifndef _THRIFT_CONCURRENCY_UTIL_H_
diff --git a/lib/cpp/src/concurrency/test/Tests.cpp b/lib/cpp/src/concurrency/test/Tests.cpp
new file mode 100644
index 0000000..c80bb88
--- /dev/null
+++ b/lib/cpp/src/concurrency/test/Tests.cpp
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <iostream>
+#include <vector>
+#include <string>
+
+#include "ThreadFactoryTests.h"
+#include "TimerManagerTests.h"
+#include "ThreadManagerTests.h"
+
+int main(int argc, char** argv) {
+  // Checked tests return 1 on failure; assert() would compile them out entirely under NDEBUG.
+  // args[0] names the suite to run; with no arguments it defaults to "all" (every suite).
+
+  std::vector<std::string> args(argc - 1 > 1 ? argc - 1 : 1);
+
+  args[0] = "all";
+
+  for (int ix = 1; ix < argc; ix++) {
+    args[ix - 1] = std::string(argv[ix]);
+  }
+
+  bool runAll = args[0].compare("all") == 0;
+
+  if (runAll || args[0].compare("thread-factory") == 0) {
+
+    ThreadFactoryTests threadFactoryTests;
+
+    std::cout << "ThreadFactory tests..." << std::endl;
+
+    size_t count = 1000;
+    size_t floodLoops = 1;
+    size_t floodCount = 100000;
+
+    std::cout << "\t\tThreadFactory reap N threads test: N = " << count << std::endl;
+    // NOTE(review): the signature is reapNThreads(int loop, int count), so this runs `count` loops of 10 threads.
+    if (!threadFactoryTests.reapNThreads(count)) { return 1; }
+
+    std::cout << "\t\tThreadFactory floodN threads test: N = " << floodCount << std::endl;
+
+    if (!threadFactoryTests.floodNTest(floodLoops, floodCount)) { return 1; }
+
+    std::cout << "\t\tThreadFactory synchronous start test" << std::endl;
+
+    if (!threadFactoryTests.synchStartTest()) { return 1; }
+
+    std::cout << "\t\tThreadFactory monitor timeout test" << std::endl;
+
+    if (!threadFactoryTests.monitorTimeoutTest()) { return 1; }
+  }
+
+  if (runAll || args[0].compare("util") == 0) {
+
+    std::cout << "Util tests..." << std::endl;
+
+    std::cout << "\t\tUtil minimum time" << std::endl;
+
+    int64_t time00 = Util::currentTime();
+    int64_t time01 = Util::currentTime();
+
+    std::cout << "\t\t\tMinimum time: " << time01 - time00 << "ms" << std::endl;
+
+    time00 = Util::currentTime();
+    time01 = time00;
+    size_t count = 0;
+
+    while (time01 < time00 + 10) {
+      count++;
+      time01 = Util::currentTime();
+    }
+
+    std::cout << "\t\t\tscall per ms: " << count / (time01 - time00) << std::endl;
+  }
+
+
+  if (runAll || args[0].compare("timer-manager") == 0) {
+
+    std::cout << "TimerManager tests..." << std::endl;
+
+    std::cout << "\t\tTimerManager test00" << std::endl;
+
+    TimerManagerTests timerManagerTests;
+
+    if (!timerManagerTests.test00()) { return 1; }
+  }
+
+  if (runAll || args[0].compare("thread-manager") == 0) {
+
+    std::cout << "ThreadManager tests..." << std::endl;
+
+    {
+
+      size_t workerCount = 100;
+
+      size_t taskCount = 100000;
+
+      int64_t delay = 10LL;
+
+      std::cout << "\t\tThreadManager load test: worker count: " << workerCount << " task count: " << taskCount << " delay: " << delay << std::endl;
+
+      ThreadManagerTests threadManagerTests;
+
+      if (!threadManagerTests.loadTest(taskCount, delay, workerCount)) { return 1; }
+
+      std::cout << "\t\tThreadManager block test: worker count: " << workerCount << " delay: " << delay << std::endl;
+
+      if (!threadManagerTests.blockTest(delay, workerCount)) { return 1; }
+
+    }
+  }
+
+  if (runAll || args[0].compare("thread-manager-benchmark") == 0) {
+
+    std::cout << "ThreadManager benchmark tests..." << std::endl;
+
+    {
+
+      size_t minWorkerCount = 2;
+
+      size_t maxWorkerCount = 512;
+
+      size_t tasksPerWorker = 1000;
+
+      int64_t delay = 10LL;
+
+      for (size_t workerCount = minWorkerCount; workerCount < maxWorkerCount; workerCount*= 2) {
+
+        size_t taskCount = workerCount * tasksPerWorker;
+
+        std::cout << "\t\tThreadManager load test: worker count: " << workerCount << " task count: " << taskCount << " delay: " << delay << std::endl;
+
+        ThreadManagerTests threadManagerTests;
+        // Benchmark only: the loadTest() result is deliberately not checked.
+        threadManagerTests.loadTest(taskCount, delay, workerCount);
+      }
+    }
+  }
+}
diff --git a/lib/cpp/src/concurrency/test/ThreadFactoryTests.h b/lib/cpp/src/concurrency/test/ThreadFactoryTests.h
new file mode 100644
index 0000000..859fbaf
--- /dev/null
+++ b/lib/cpp/src/concurrency/test/ThreadFactoryTests.h
@@ -0,0 +1,357 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <config.h>
+#include <concurrency/Thread.h>
+#include <concurrency/PosixThreadFactory.h>
+#include <concurrency/Monitor.h>
+#include <concurrency/Util.h>
+
+#include <assert.h>
+#include <iostream>
+#include <set>
+
+namespace apache { namespace thrift { namespace concurrency { namespace test {
+
+using boost::shared_ptr;
+using namespace apache::thrift::concurrency;
+
+/**
+ * ThreadFactoryTests class
+ *
+ * @version $Id:$
+ */
+class ThreadFactoryTests {
+
+public:
+  // Largest relative timing error monitorTimeoutTest() accepts.
+  static const double ERROR;
+
+  class Task: public Runnable {
+
+  public:
+
+    Task() {}
+
+    void run() {
+      std::cout << "\t\t\tHello World" << std::endl;
+    }
+  };
+
+  /**
+   * Hello world test
+   */
+  bool helloWorldTest() {
+
+    PosixThreadFactory threadFactory = PosixThreadFactory();
+
+    shared_ptr<Task> task = shared_ptr<Task>(new ThreadFactoryTests::Task());
+
+    shared_ptr<Thread> thread = threadFactory.newThread(task);
+
+    thread->start();
+
+    thread->join();
+
+    std::cout << "\t\t\tSuccess!" << std::endl;
+
+    return true;
+  }
+
+  /**
+   * Reap N threads
+   */
+  class ReapNTask: public Runnable {
+
+  public:
+
+    ReapNTask(Monitor& monitor, int& activeCount) :
+      _monitor(monitor),
+      _count(activeCount) {}
+
+    void run() {
+      Synchronized s(_monitor);
+
+      _count--;
+
+      //std::cout << "\t\t\tthread count: " << _count << std::endl;
+
+      if (_count == 0) {
+        _monitor.notify();
+      }
+    }
+
+    Monitor& _monitor;
+
+    int& _count;
+  };
+
+  bool reapNThreads(int loop=1, int count=10) {
+
+    PosixThreadFactory threadFactory = PosixThreadFactory();
+
+    Monitor monitor;  // shared by every ReapNTask; outlives all batches
+
+    for(int lix = 0; lix < loop; lix++) {
+
+      int activeCount = count;  // each ReapNTask decrements this under monitor
+
+      std::set<shared_ptr<Thread> > threads;
+
+      int tix;
+
+      for (tix = 0; tix < count; tix++) {
+        try {
+          threads.insert(threadFactory.newThread(shared_ptr<Runnable>(new ReapNTask(monitor, activeCount))));
+        } catch(SystemResourceException& e) {
+          std::cout << "\t\t\tfailed to create " << lix * count + tix << " thread " << e.what() << std::endl;
+          throw;  // rethrow the original exception rather than a copy
+        }
+      }
+
+      tix = 0;
+      for (std::set<shared_ptr<Thread> >::const_iterator thread = threads.begin(); thread != threads.end(); tix++, ++thread) {
+
+        try {
+          (*thread)->start();
+        } catch(SystemResourceException& e) {
+          std::cout << "\t\t\tfailed to start " << lix * count + tix << " thread " << e.what() << std::endl;
+          throw;
+        }
+      }
+
+      {
+        Synchronized s(monitor);
+        while (activeCount > 0) {
+          try { monitor.wait(1000); } catch (TimedOutException&) {}  // timed out: re-check the count
+        }
+      }
+
+      // Every task has run; release this batch's Thread handles. clear() avoids
+      // erase()-ing from the set inside a loop that then advances the erased iterator.
+      threads.clear();
+
+      std::cout << "\t\t\treaped " << (lix + 1) * count << " threads" << std::endl;
+    }
+
+    std::cout << "\t\t\tSuccess!" << std::endl;
+
+    return true;
+  }
+
+  class SynchStartTask: public Runnable {
+
+  public:
+
+    enum STATE {
+      UNINITIALIZED,
+      STARTING,
+      STARTED,
+      STOPPING,
+      STOPPED
+    };
+
+    SynchStartTask(Monitor& monitor, volatile STATE& state) :
+      _monitor(monitor),
+      _state(state) {}
+
+    void run() {
+      {
+        Synchronized s(_monitor);
+        if (_state == SynchStartTask::STARTING) {
+          _state = SynchStartTask::STARTED;
+          _monitor.notify();
+        }
+      }
+
+      {
+        Synchronized s(_monitor);
+        while (_state == SynchStartTask::STARTED) {
+          _monitor.wait();
+        }
+
+        if (_state == SynchStartTask::STOPPING) {
+          _state = SynchStartTask::STOPPED;
+          _monitor.notifyAll();
+        }
+      }
+    }
+
+  private:
+    Monitor& _monitor;
+    volatile STATE& _state;
+  };
+
+  bool synchStartTest() {
+
+    Monitor monitor;
+
+    SynchStartTask::STATE state = SynchStartTask::UNINITIALIZED;
+
+    shared_ptr<SynchStartTask> task = shared_ptr<SynchStartTask>(new SynchStartTask(monitor, state));
+
+    PosixThreadFactory threadFactory = PosixThreadFactory();
+
+    shared_ptr<Thread> thread = threadFactory.newThread(task);
+
+    if (state == SynchStartTask::UNINITIALIZED) {
+
+      state = SynchStartTask::STARTING;
+
+      thread->start();
+    }
+
+    {
+      Synchronized s(monitor);
+      while (state == SynchStartTask::STARTING) {
+        monitor.wait();
+      }
+    }
+
+    assert(state != SynchStartTask::STARTING);
+
+    {
+      Synchronized s(monitor);
+
+      try {
+        monitor.wait(100);
+      } catch(TimedOutException& e) {
+      }
+
+      if (state == SynchStartTask::STARTED) {
+
+        state = SynchStartTask::STOPPING;
+
+        monitor.notify();
+      }
+
+      while (state == SynchStartTask::STOPPING) {
+        monitor.wait();
+      }
+    }
+
+    assert(state == SynchStartTask::STOPPED);
+
+    bool success = true;
+
+    std::cout << "\t\t\t" << (success ? "Success" : "Failure") << "!" << std::endl;
+
+    return true;
+  }
+
+  /** See how accurate monitor timeout is. */
+
+  bool monitorTimeoutTest(size_t count=1000, int64_t timeout=10) {
+
+    Monitor monitor;
+
+    int64_t startTime = Util::currentTime();
+
+    for (size_t ix = 0; ix < count; ix++) {
+      {
+        Synchronized s(monitor);
+        try {
+          monitor.wait(timeout);
+        } catch(TimedOutException& e) {
+        }
+      }
+    }
+
+    int64_t endTime = Util::currentTime();
+
+    int64_t expected = static_cast<int64_t>(count) * timeout;  // signed, so an early finish stays negative
+    double error = ((endTime - startTime) - expected) / (double)expected;
+    if (error < 0.0) {
+      // Compare the magnitude of the relative error against ERROR.
+      error = -error;
+    }
+
+    bool success = error < ThreadFactoryTests::ERROR;
+
+    std::cout << "\t\t\t" << (success ? "Success" : "Failure") << "! expected time: " << count * timeout << "ms elapsed time: "<< endTime - startTime << "ms error%: " << error * 100.0 << std::endl;
+
+    return success;
+  }
+
+
+  class FloodTask : public Runnable {
+  public:
+
+    FloodTask(const size_t id) :_id(id) {}
+    ~FloodTask(){
+      if(_id % 1000 == 0) {
+        std::cout << "\t\tthread " << _id << " done" << std::endl;
+      }
+    }
+
+    void run(){
+      if(_id % 1000 == 0) {
+        std::cout << "\t\tthread " << _id << " started" << std::endl;
+      }
+
+      usleep(1);
+    }
+    const size_t _id;
+  };
+
+  void foo(PosixThreadFactory *tf) {
+  }
+
+  bool floodNTest(size_t loop=1, size_t count=100000) {
+
+    bool success = false;
+
+    for(size_t lix = 0; lix < loop; lix++) {
+
+      PosixThreadFactory threadFactory = PosixThreadFactory();
+      threadFactory.setDetached(true);
+
+      for(size_t tix = 0; tix < count; tix++) {
+
+        try {
+
+          shared_ptr<FloodTask> task(new FloodTask(lix * count + tix ));
+
+          shared_ptr<Thread> thread = threadFactory.newThread(task);
+
+          thread->start();
+
+          usleep(1);
+
+        } catch (TException& e) {
+
+          std::cout << "\t\t\tfailed to start " << lix * count + tix << " thread " << e.what() << std::endl;
+
+          return success;
+        }
+      }
+
+      std::cout << "\t\t\tflooded " << (lix + 1) * count << " threads" << std::endl;
+
+      success = true;
+    }
+
+    return success;
+  }
+};
+
+const double ThreadFactoryTests::ERROR = .20;
+
+}}}} // apache::thrift::concurrency::test
+
diff --git a/lib/cpp/src/concurrency/test/ThreadManagerTests.h b/lib/cpp/src/concurrency/test/ThreadManagerTests.h
new file mode 100644
index 0000000..e7b5174
--- /dev/null
+++ b/lib/cpp/src/concurrency/test/ThreadManagerTests.h
@@ -0,0 +1,366 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <config.h>
+#include <concurrency/ThreadManager.h>
+#include <concurrency/PosixThreadFactory.h>
+#include <concurrency/Monitor.h>
+#include <concurrency/Util.h>
+
+#include <assert.h>
+#include <set>
+#include <iostream>
+#include <set>
+#include <stdint.h>
+
+namespace apache { namespace thrift { namespace concurrency { namespace test {
+
+using namespace apache::thrift::concurrency;
+
+/**
+ * ThreadManagerTests class
+ *
+ * Tests for ThreadManager: loadTest() measures how long a batch of timed
+ * tasks takes on a fixed worker pool; blockTest() checks that add() honours
+ * the pending-task limit.
+ *
+ * @version $Id:$
+ */
+class ThreadManagerTests {
+
+public:
+
+  // Maximum tolerated relative error between expected and measured time in loadTest().
+  static const double ERROR;
+
+  /**
+   * Task that waits up to `timeout` ms on a private monitor, records its start
+   * and end times, then decrements the shared `count` under `monitor` and
+   * notifies `monitor` when the count reaches zero.
+   */
+  class Task: public Runnable {
+
+  public:
+
+    Task(Monitor& monitor, size_t& count, int64_t timeout) :
+      _monitor(monitor),
+      _count(count),
+      _timeout(timeout),
+      _startTime(0),
+      _endTime(0),
+      _done(false) {}
+
+    void run() {
+
+      _startTime = Util::currentTime();
+
+      {
+        Synchronized s(_sleep);
+
+        try {
+          _sleep.wait(_timeout);
+        } catch (const TimedOutException&) {
+          // Timing out is the normal outcome: nothing in this file notifies
+          // _sleep, so the wait acts as a sleep.
+        } catch (...) {
+          assert(0);
+        }
+      }
+
+      _endTime = Util::currentTime();
+
+      _done = true;
+
+      {
+        Synchronized s(_monitor);
+
+        _count--;
+
+        if (_count == 0) {
+          _monitor.notify();
+        }
+      }
+    }
+
+    Monitor& _monitor;
+    size_t& _count;
+    int64_t _timeout;
+    int64_t _startTime;
+    int64_t _endTime;
+    bool _done;
+    Monitor _sleep;
+  };
+
+  /**
+   * Dispatch count tasks, each of which blocks for timeout milliseconds then
+   * completes. Verify that all tasks completed and that thread manager cleans
+   * up properly on delete.
+   *
+   * Success means the wall-clock time for the whole batch is within ERROR
+   * (relative) of ceil(count / workerCount) * timeout.
+   */
+  bool loadTest(size_t count=100, int64_t timeout=100LL, size_t workerCount=4) {
+
+    Monitor monitor;
+
+    size_t activeCount = count;
+
+    shared_ptr<ThreadManager> threadManager = ThreadManager::newSimpleThreadManager(workerCount);
+
+    shared_ptr<PosixThreadFactory> threadFactory = shared_ptr<PosixThreadFactory>(new PosixThreadFactory());
+
+    threadFactory->setPriority(PosixThreadFactory::HIGHEST);
+
+    threadManager->threadFactory(threadFactory);
+
+    threadManager->start();
+
+    std::set<shared_ptr<ThreadManagerTests::Task> > tasks;
+
+    for (size_t ix = 0; ix < count; ix++) {
+      tasks.insert(shared_ptr<ThreadManagerTests::Task>(new ThreadManagerTests::Task(monitor, activeCount, timeout)));
+    }
+
+    int64_t time00 = Util::currentTime();
+
+    for (std::set<shared_ptr<ThreadManagerTests::Task> >::iterator ix = tasks.begin(); ix != tasks.end(); ++ix) {
+      threadManager->add(*ix);
+    }
+
+    {
+      Synchronized s(monitor);
+
+      while (activeCount > 0) {
+        monitor.wait();
+      }
+    }
+
+    int64_t time01 = Util::currentTime();
+
+    int64_t firstTime = 9223372036854775807LL;
+    int64_t lastTime = 0;
+
+    double averageTime = 0;
+    int64_t minTime = 9223372036854775807LL;
+    int64_t maxTime = 0;
+
+    for (std::set<shared_ptr<ThreadManagerTests::Task> >::iterator ix = tasks.begin(); ix != tasks.end(); ++ix) {
+
+      shared_ptr<ThreadManagerTests::Task> task = *ix;
+
+      int64_t delta = task->_endTime - task->_startTime;
+
+      assert(delta > 0);
+
+      if (task->_startTime < firstTime) {
+        firstTime = task->_startTime;
+      }
+
+      if (task->_endTime > lastTime) {
+        lastTime = task->_endTime;
+      }
+
+      if (delta < minTime) {
+        minTime = delta;
+      }
+
+      if (delta > maxTime) {
+        maxTime = delta;
+      }
+
+      averageTime += delta;
+    }
+
+    averageTime /= count;
+
+    std::cout << "\t\t\tfirst start: " << firstTime << "ms Last end: " << lastTime << "ms min: " << minTime << "ms max: " << maxTime << "ms average: " << averageTime << "ms" << std::endl;
+
+    // Tasks run in ceil(count / workerCount) waves of `timeout` ms each.
+    double expectedTime = ((count + (workerCount - 1)) / workerCount) * timeout;
+
+    double error = ((time01 - time00) - expectedTime) / expectedTime;
+
+    if (error < 0) {
+      error *= -1.0;
+    }
+
+    bool success = error < ERROR;
+
+    std::cout << "\t\t\t" << (success ? "Success" : "Failure") << "! expected time: " << expectedTime << "ms elapsed time: "<< time01 - time00 << "ms error%: " << error * 100.0 << std::endl;
+
+    return success;
+  }
+
+  /**
+   * Task that blocks on `bmonitor` until notified, then decrements the shared
+   * `count` under `monitor` and notifies `monitor` when it reaches zero.
+   *
+   * NOTE(review): the wait has no predicate. A notifyAll() issued before the
+   * task reaches wait() is lost, and the task then blocks until the next one.
+   */
+  class BlockTask: public Runnable {
+
+  public:
+
+    BlockTask(Monitor& monitor, Monitor& bmonitor, size_t& count) :
+      _monitor(monitor),
+      _bmonitor(bmonitor),
+      _count(count) {}
+
+    void run() {
+      {
+        Synchronized s(_bmonitor);
+
+        _bmonitor.wait();
+      }
+
+      {
+        Synchronized s(_monitor);
+
+        _count--;
+
+        if (_count == 0) {
+          _monitor.notify();
+        }
+      }
+    }
+
+    Monitor& _monitor;
+    Monitor& _bmonitor;
+    size_t& _count;
+  };
+
+  /**
+   * Block test. Fill the workers (workerCount tasks) and the pending queue
+   * (pendingTaskMaxCount tasks), verify that adding one more task times out,
+   * then verify that it can be added once the running tasks complete. Finally
+   * drain everything and check that no tasks remain.
+   *
+   * @return true only if every step behaved as expected. Any failure is
+   *         reported on stdout and yields false.
+   */
+  bool blockTest(int64_t timeout=100LL, size_t workerCount=2) {
+
+    bool success = false;
+
+    try {
+
+      Monitor bmonitor;
+      Monitor monitor;
+
+      size_t pendingTaskMaxCount = workerCount;
+
+      // Countdowns per group: [0] tasks added first (run immediately),
+      // [1] tasks added next (pending), [2] the extra task.
+      size_t activeCounts[] = {workerCount, pendingTaskMaxCount, 1};
+
+      shared_ptr<ThreadManager> threadManager = ThreadManager::newSimpleThreadManager(workerCount, pendingTaskMaxCount);
+
+      shared_ptr<PosixThreadFactory> threadFactory = shared_ptr<PosixThreadFactory>(new PosixThreadFactory());
+
+      threadFactory->setPriority(PosixThreadFactory::HIGHEST);
+
+      threadManager->threadFactory(threadFactory);
+
+      threadManager->start();
+
+      std::set<shared_ptr<ThreadManagerTests::BlockTask> > tasks;
+
+      // Add each task as soon as it is created. Iterating `tasks` (ordered by
+      // pointer address) would not guarantee that group [0] is dispatched
+      // before group [1], which the waits below depend on.
+      for (size_t ix = 0; ix < workerCount; ix++) {
+        shared_ptr<ThreadManagerTests::BlockTask> task(new ThreadManagerTests::BlockTask(monitor, bmonitor, activeCounts[0]));
+        tasks.insert(task);
+        threadManager->add(task);
+      }
+
+      for (size_t ix = 0; ix < pendingTaskMaxCount; ix++) {
+        shared_ptr<ThreadManagerTests::BlockTask> task(new ThreadManagerTests::BlockTask(monitor, bmonitor, activeCounts[1]));
+        tasks.insert(task);
+        threadManager->add(task);
+      }
+
+      success = (threadManager->totalTaskCount() == pendingTaskMaxCount + workerCount);
+      if (!success) {
+        throw TException("Unexpected pending task count");
+      }
+
+      shared_ptr<ThreadManagerTests::BlockTask> extraTask(new ThreadManagerTests::BlockTask(monitor, bmonitor, activeCounts[2]));
+
+      try {
+        threadManager->add(extraTask, 1);
+        throw TException("Unexpected success adding task in excess of pending task count");
+      } catch (const TimedOutException&) {
+        // Expected: the pending queue is full, so add() must time out.
+      }
+
+      std::cout << "\t\t\t" << "Pending tasks " << threadManager->pendingTaskCount() << std::endl;
+
+      {
+        Synchronized s(bmonitor);
+
+        bmonitor.notifyAll();
+      }
+
+      {
+        Synchronized s(monitor);
+
+        while (activeCounts[0] != 0) {
+          monitor.wait();
+        }
+      }
+
+      std::cout << "\t\t\t" << "Pending tasks " << threadManager->pendingTaskCount() << std::endl;
+
+      try {
+        threadManager->add(extraTask, 1);
+      } catch (const TimedOutException&) {
+        std::cout << "\t\t\t" << "add timed out unexpectedly" << std::endl;
+        throw TException("Unexpected timeout adding task");
+      } catch (const TooManyPendingTasksException&) {
+        std::cout << "\t\t\t" << "add encountered too many pending tasks" << std::endl;
+        throw TException("Unexpected too many pending tasks adding task");
+      }
+
+      // Wake up tasks that were pending before and wait for them to complete
+
+      {
+        Synchronized s(bmonitor);
+
+        bmonitor.notifyAll();
+      }
+
+      {
+        Synchronized s(monitor);
+
+        while (activeCounts[1] != 0) {
+          monitor.wait();
+        }
+      }
+
+      // Wake up the extra task and wait for it to complete
+
+      {
+        Synchronized s(bmonitor);
+
+        bmonitor.notifyAll();
+      }
+
+      {
+        Synchronized s(monitor);
+
+        while (activeCounts[2] != 0) {
+          monitor.wait();
+        }
+      }
+
+      success = (threadManager->totalTaskCount() == 0);
+      if (!success) {
+        throw TException("Unexpected pending task count");
+      }
+
+    } catch (const TException& e) {
+      // The first task-count check may already have set `success` to true,
+      // so clear it explicitly. Also report why the test failed instead of
+      // swallowing the cause.
+      success = false;
+      std::cout << "\t\t\t" << "Exception: " << e.what() << std::endl;
+    }
+
+    std::cout << "\t\t\t" << (success ? "Success" : "Failure") << std::endl;
+    return success;
+  }
+};
+
+const double ThreadManagerTests::ERROR = .20;
+
+}}}} // apache::thrift::concurrency
+
+using namespace apache::thrift::concurrency::test;
+
diff --git a/lib/cpp/src/concurrency/test/TimerManagerTests.h b/lib/cpp/src/concurrency/test/TimerManagerTests.h
new file mode 100644
index 0000000..e6fe6ce
--- /dev/null
+++ b/lib/cpp/src/concurrency/test/TimerManagerTests.h
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <concurrency/TimerManager.h>
+#include <concurrency/PosixThreadFactory.h>
+#include <concurrency/Monitor.h>
+#include <concurrency/Util.h>
+
+#include <assert.h>
+#include <iostream>
+
+namespace apache { namespace thrift { namespace concurrency { namespace test {
+
+using namespace apache::thrift::concurrency;
+
+/**
+ * TimerManagerTests class
+ *
+ * @version $Id:$
+ */
+class TimerManagerTests {
+
+ public:
+
+  // Maximum tolerated relative error between a task's requested delay and
+  // the delay actually observed.
+  static const double ERROR;
+
+  /**
+   * Timer task that captures its construction time. When run, it compares
+   * the elapsed time with `timeout` and sets _success if the relative error
+   * is below ERROR. It then sets _done and notifies `monitor`.
+   */
+  class Task: public Runnable {
+   public:
+
+    Task(Monitor& monitor, int64_t timeout) :
+      _timeout(timeout),
+      _startTime(Util::currentTime()),
+      _endTime(0),
+      _monitor(monitor),
+      _success(false),
+      _done(false) {}
+
+    // Debug trace of task destruction.
+    ~Task() { std::cerr << this << std::endl; }
+
+    void run() {
+
+      _endTime = Util::currentTime();
+
+      // Figure out error percentage
+      int64_t delta = _endTime - _startTime;
+      delta = delta > _timeout ? delta - _timeout : _timeout - delta;
+
+      // Divide in floating point. Integer division truncated every error
+      // below 100% to 0, so the ERROR tolerance was never enforced.
+      double error = static_cast<double>(delta) / static_cast<double>(_timeout);
+
+      if (error < ERROR) {
+        _success = true;
+      }
+
+      std::cout << "\t\t\tTimerManagerTests::Task[" << this << "] done" << std::endl; //debug
+
+      {
+        Synchronized s(_monitor);
+        // Set under the monitor so a waiter that re-checks _done after
+        // wait() returns reads it without a data race.
+        _done = true;
+        _monitor.notifyAll();
+      }
+    }
+
+    int64_t _timeout;
+    int64_t _startTime;
+    int64_t _endTime;
+    Monitor& _monitor;
+    bool _success;
+    bool _done;
+  };
+
+  /**
+   * This test creates two tasks and waits for the first to expire within
+   * ERROR (relative) of the expected expiration time. It then verifies that
+   * the timer manager properly cleans up itself and the remaining orphaned
+   * timeout task when the manager goes out of scope and its destructor is
+   * called.
+   *
+   * @return true if the short task fired within the tolerated error.
+   */
+  bool test00(int64_t timeout=1000LL) {
+
+    shared_ptr<TimerManagerTests::Task> orphanTask = shared_ptr<TimerManagerTests::Task>(new TimerManagerTests::Task(_monitor, 10 * timeout));
+
+    bool success = false;
+
+    {
+
+      TimerManager timerManager;
+
+      timerManager.threadFactory(shared_ptr<PosixThreadFactory>(new PosixThreadFactory()));
+
+      timerManager.start();
+
+      assert(timerManager.state() == TimerManager::STARTED);
+
+      shared_ptr<TimerManagerTests::Task> task = shared_ptr<TimerManagerTests::Task>(new TimerManagerTests::Task(_monitor, timeout));
+
+      {
+        Synchronized s(_monitor);
+
+        timerManager.add(orphanTask, 10 * timeout);
+
+        timerManager.add(task, timeout);
+
+        // Re-check the predicate so that a spurious wakeup cannot end the
+        // wait before the short task has actually run.
+        while (!task->_done) {
+          _monitor.wait();
+        }
+      }
+
+      assert(task->_done);
+
+      success = task->_success;
+
+      std::cout << "\t\t\t" << (success ? "Success" : "Failure") << "!" << std::endl;
+    }
+
+    // timerManager.stop(); This is where it happens via destructor
+
+    assert(!orphanTask->_done);
+
+    return success;
+  }
+
+  friend class TestTask;
+
+  Monitor _monitor;
+};
+
+const double TimerManagerTests::ERROR = .20;
+
+}}}} // apache::thrift::concurrency
+
diff --git a/lib/cpp/src/processor/PeekProcessor.cpp b/lib/cpp/src/processor/PeekProcessor.cpp
new file mode 100644
index 0000000..c721861
--- /dev/null
+++ b/lib/cpp/src/processor/PeekProcessor.cpp
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "PeekProcessor.h"
+
+using namespace apache::thrift::transport;
+using namespace apache::thrift::protocol;
+using namespace apache::thrift;
+
+namespace apache { namespace thrift { namespace processor {
+
+/**
+ * By default, the target transport is a private TMemoryBuffer.
+ * setTargetTransport() can replace it.
+ */
+PeekProcessor::PeekProcessor()
+  : memoryBuffer_(new TMemoryBuffer()),
+    targetTransport_(memoryBuffer_) {
+}
+
+PeekProcessor::~PeekProcessor() {}
+
+/**
+ * Wire up the processor.
+ *
+ * @param actualProcessor  processor that ultimately handles each call
+ * @param protocolFactory  builds the protocol over the current target
+ *                         transport that process() hands to actualProcessor
+ * @param transportFactory factory used by getPipedTransport(). It is told
+ *                         about the current target transport.
+ */
+void PeekProcessor::initialize(boost::shared_ptr<TProcessor> actualProcessor,
+                               boost::shared_ptr<TProtocolFactory> protocolFactory,
+                               boost::shared_ptr<TPipedTransportFactory> transportFactory) {
+  this->actualProcessor_ = actualProcessor;
+  this->pipedProtocol_ = protocolFactory->getProtocol(this->targetTransport_);
+  this->transportFactory_ = transportFactory;
+  this->transportFactory_->initializeTargetTransport(this->targetTransport_);
+}
+
+// Wraps `in` using the piped transport factory supplied to initialize().
+boost::shared_ptr<TTransport> PeekProcessor::getPipedTransport(boost::shared_ptr<TTransport> in) {
+  boost::shared_ptr<TTransport> piped = transportFactory_->getTransport(in);
+  return piped;
+}
+
+/**
+ * Replace the transport that peeked data is written into.
+ *
+ * The target must be a TMemoryBuffer, or a TPipedTransport whose own target
+ * is a TMemoryBuffer. That buffer is what process() inspects and resets.
+ *
+ * @throws TException for any other target. In that case the processor's
+ *         state is left unchanged.
+ *
+ * NOTE(review): pipedProtocol_ and the transport factory were bound to the
+ * previous target in initialize() and are not rebound here. Call
+ * initialize() again if they must use the new target.
+ */
+void PeekProcessor::setTargetTransport(boost::shared_ptr<TTransport> targetTransport) {
+  boost::shared_ptr<TMemoryBuffer> buffer =
+    boost::dynamic_pointer_cast<TMemoryBuffer>(targetTransport);
+  if (!buffer) {
+    boost::shared_ptr<TPipedTransport> piped =
+      boost::dynamic_pointer_cast<TPipedTransport>(targetTransport);
+    if (piped) {
+      buffer = boost::dynamic_pointer_cast<TMemoryBuffer>(piped->getTargetTransport());
+    }
+  }
+
+  // Validate before mutating anything. Otherwise an unsupported target could
+  // keep a stale buffer from an earlier target and pass the check, or leave
+  // targetTransport_ and memoryBuffer_ out of sync.
+  if (!buffer) {
+    throw TException("Target transport must be a TMemoryBuffer or a TPipedTransport with TMemoryBuffer");
+  }
+
+  targetTransport_ = targetTransport;
+  memoryBuffer_ = buffer;
+}
+
+/**
+ * Read one call from `in`, letting the peek hooks inspect it, then dispatch
+ * the buffered copy to the actual processor.
+ *
+ * Presumably `in` sits on a transport obtained from getPipedTransport(), so
+ * the bytes read here end up in memoryBuffer_, which pipedProtocol_ then
+ * replays to actualProcessor_.
+ *
+ * @throws TException if the message is not a T_CALL. Any exception from the
+ *         hooks or the actual processor is rethrown.
+ * @return the actual processor's result.
+ */
+bool PeekProcessor::process(boost::shared_ptr<TProtocol> in,
+                            boost::shared_ptr<TProtocol> out) {
+  try {
+    std::string fname;
+    TMessageType mtype;
+    int32_t seqid;
+    in->readMessageBegin(fname, mtype, seqid);
+
+    if (mtype != T_CALL) {
+      throw TException("Unexpected message type");
+    }
+
+    // Peek at the name
+    peekName(fname);
+
+    // The call arguments are encoded as a struct. Protocols that encode
+    // struct boundaries need the begin/end calls to stay in sync.
+    std::string structName;
+    in->readStructBegin(structName);
+
+    std::string fieldName;
+    TType ftype;
+    int16_t fid;
+    while (true) {
+      in->readFieldBegin(fieldName, ftype, fid);
+      if (ftype == T_STOP) {
+        break;
+      }
+
+      // Peek at the variable
+      peek(in, ftype, fid);
+      in->readFieldEnd();
+    }
+    in->readStructEnd();
+    in->readMessageEnd();
+    in->getTransport()->readEnd();
+
+    //
+    // All the data is now in memoryBuffer_ and ready to be processed
+    //
+
+    // Let's first take a peek at the full data in memory
+    uint8_t* buffer;
+    uint32_t size;
+    memoryBuffer_->getBuffer(&buffer, &size);
+    peekBuffer(buffer, size);
+
+    // Done peeking at variables
+    peekEnd();
+
+    bool ret = actualProcessor_->process(pipedProtocol_, out);
+    memoryBuffer_->resetBuffer();
+    return ret;
+  } catch (...) {
+    // Discard whatever was buffered for this request, so its leftover bytes
+    // are not read ahead of the next request, then let the caller see the
+    // error.
+    memoryBuffer_->resetBuffer();
+    throw;
+  }
+}
+
+// Default hook: ignores the called method's name. Override to inspect it.
+void PeekProcessor::peekName(const std::string& /* fname */) {}
+
+// Default hook: ignores the raw buffered request bytes. Override to inspect them.
+void PeekProcessor::peekBuffer(uint8_t* /* buffer */, uint32_t /* size */) {}
+
+// Default hook: consumes the field's value with skip() without inspecting it.
+// process() calls readFieldEnd() right after this, so overrides must likewise
+// consume exactly one value of type `ftype` from `in`.
+void PeekProcessor::peek(boost::shared_ptr<TProtocol> in,
+                         TType ftype,
+                         int16_t /* fid */) {
+  in->skip(ftype);
+}
+
+// Default hook called once all fields have been peeked at. Does nothing.
+void PeekProcessor::peekEnd() {
+}
+
+}}}
diff --git a/lib/cpp/src/processor/PeekProcessor.h b/lib/cpp/src/processor/PeekProcessor.h
new file mode 100644
index 0000000..0f7c016
--- /dev/null
+++ b/lib/cpp/src/processor/PeekProcessor.h
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef PEEKPROCESSOR_H
+#define PEEKPROCESSOR_H
+
+#include <string>
+#include <TProcessor.h>
+#include <transport/TTransport.h>
+#include <transport/TTransportUtils.h>
+#include <transport/TBufferTransports.h>
+#include <boost/shared_ptr.hpp>
+
+namespace apache { namespace thrift { namespace processor {
+
+/*
+ * Processor that lets a derived class peek at the raw data of each call
+ * before the call is handed to another ("actual") processor, giving the
+ * derived class a chance to change its behavior accordingly.
+ */
+class PeekProcessor : public apache::thrift::TProcessor {
+
+ public:
+  PeekProcessor();
+  virtual ~PeekProcessor();
+
+  // Input here: actualProcessor - the underlying processor
+  //             protocolFactory - the protocol factory used to wrap the memory buffer
+  //             transportFactory - this TPipedTransportFactory is used to wrap the source transport
+  //                                via a call to getPipedTransport
+  void initialize(boost::shared_ptr<apache::thrift::TProcessor> actualProcessor,
+                  boost::shared_ptr<apache::thrift::protocol::TProtocolFactory> protocolFactory,
+                  boost::shared_ptr<apache::thrift::transport::TPipedTransportFactory> transportFactory);
+
+  // Wraps `in` with a transport from the factory passed to initialize().
+  boost::shared_ptr<apache::thrift::transport::TTransport> getPipedTransport(boost::shared_ptr<apache::thrift::transport::TTransport> in);
+
+  // Replaces the target transport. It must be a TMemoryBuffer or a
+  // TPipedTransport targeting a TMemoryBuffer; otherwise TException is thrown.
+  void setTargetTransport(boost::shared_ptr<apache::thrift::transport::TTransport> targetTransport);
+
+  virtual bool process(boost::shared_ptr<apache::thrift::protocol::TProtocol> in,
+                       boost::shared_ptr<apache::thrift::protocol::TProtocol> out);
+
+  // The following four hooks can be overridden by child classes to achieve
+  // the desired peeking behavior. peek() must consume the field's value from
+  // `in`; the default implementation skips it.
+  virtual void peekName(const std::string& fname);
+  virtual void peekBuffer(uint8_t* buffer, uint32_t size);
+  virtual void peek(boost::shared_ptr<apache::thrift::protocol::TProtocol> in,
+                    apache::thrift::protocol::TType ftype,
+                    int16_t fid);
+  virtual void peekEnd();
+
+ private:
+  boost::shared_ptr<apache::thrift::TProcessor> actualProcessor_;
+  boost::shared_ptr<apache::thrift::protocol::TProtocol> pipedProtocol_;
+  boost::shared_ptr<apache::thrift::transport::TPipedTransportFactory> transportFactory_;
+  boost::shared_ptr<apache::thrift::transport::TMemoryBuffer> memoryBuffer_;
+  boost::shared_ptr<apache::thrift::transport::TTransport> targetTransport_;
+};
+
+}}} // apache::thrift::processor
+
+#endif
diff --git a/lib/cpp/src/processor/StatsProcessor.h b/lib/cpp/src/processor/StatsProcessor.h
new file mode 100644
index 0000000..820b3ad
--- /dev/null
+++ b/lib/cpp/src/processor/StatsProcessor.h
@@ -0,0 +1,264 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef STATSPROCESSOR_H
+#define STATSPROCESSOR_H
+
+#include <stdio.h>
+#include <map>
+#include <string>
+#include <boost/shared_ptr.hpp>
+#include <transport/TTransport.h>
+#include <protocol/TProtocol.h>
+#include <TProcessor.h>
+
+namespace apache { namespace thrift { namespace processor {
+
+/*
+ * Processor that decodes each incoming call. If `print` is set, it prints
+ * the method name and arguments to stdout. If `frequency` is set, it counts
+ * how often each method name is seen. The request is consumed, but nothing
+ * is written to the output protocol.
+ */
+class StatsProcessor : public apache::thrift::TProcessor {
+public:
+  StatsProcessor(bool print, bool frequency)
+    : print_(print),
+      frequency_(frequency)
+  {}
+  virtual ~StatsProcessor() {}
+
+  /**
+   * Read one T_CALL message from piprot, printing and/or counting it.
+   *
+   * @throws apache::thrift::TException if the message is not a T_CALL.
+   * @return always true; poprot is never written to.
+   */
+  virtual bool process(boost::shared_ptr<apache::thrift::protocol::TProtocol> piprot, boost::shared_ptr<apache::thrift::protocol::TProtocol> poprot) {
+
+    piprot_ = piprot;
+
+    std::string fname;
+    apache::thrift::protocol::TMessageType mtype;
+    int32_t seqid;
+
+    piprot_->readMessageBegin(fname, mtype, seqid);
+    if (mtype != apache::thrift::protocol::T_CALL) {
+      if (print_) {
+        printf("Unknown message type\n");
+      }
+      throw apache::thrift::TException("Unexpected message type");
+    }
+    if (print_) {
+      printf("%s (", fname.c_str());
+    }
+    if (frequency_) {
+      // operator[] value-initializes a missing count to 0 before incrementing.
+      ++frequency_map_[fname];
+    }
+
+    // The arguments are encoded as a struct. Read it with the same
+    // begin/field/end calls that printAndPassToBuffer() uses for nested structs.
+    std::string structName;
+    piprot_->readStructBegin(structName);
+
+    std::string fieldName;
+    apache::thrift::protocol::TType ftype;
+    int16_t fid;
+
+    while (true) {
+      piprot_->readFieldBegin(fieldName, ftype, fid);
+      if (ftype == apache::thrift::protocol::T_STOP) {
+        break;
+      }
+
+      printAndPassToBuffer(ftype);
+      if (print_) {
+        printf(", ");
+      }
+      piprot_->readFieldEnd();
+    }
+    piprot_->readStructEnd();
+    piprot_->readMessageEnd();
+
+    if (print_) {
+      printf("\b\b)\n");
+    }
+    return true;
+  }
+
+  // Number of calls seen per method name (filled only when `frequency` is set).
+  const std::map<std::string, int64_t>& get_frequency_map() const {
+    return frequency_map_;
+  }
+
+protected:
+  /**
+   * Consume one value of type `ftype` from piprot_, printing it when print_ is
+   * set. Structs and containers are handled recursively.
+   *
+   * NOTE(review): values of an unrecognized type are neither read nor
+   * skipped, which would leave the input out of sync.
+   */
+  void printAndPassToBuffer(apache::thrift::protocol::TType ftype) {
+    switch (ftype) {
+      case apache::thrift::protocol::T_BOOL:
+        {
+          bool boolv;
+          piprot_->readBool(boolv);
+          if (print_) {
+            printf("%d", boolv);
+          }
+        }
+        break;
+      case apache::thrift::protocol::T_BYTE:
+        {
+          int8_t bytev;
+          piprot_->readByte(bytev);
+          if (print_) {
+            printf("%d", bytev);
+          }
+        }
+        break;
+      case apache::thrift::protocol::T_I16:
+        {
+          int16_t i16;
+          piprot_->readI16(i16);
+          if (print_) {
+            printf("%d", i16);
+          }
+        }
+        break;
+      case apache::thrift::protocol::T_I32:
+        {
+          int32_t i32;
+          piprot_->readI32(i32);
+          if (print_) {
+            printf("%d", i32);
+          }
+        }
+        break;
+      case apache::thrift::protocol::T_I64:
+        {
+          int64_t i64;
+          piprot_->readI64(i64);
+          if (print_) {
+            // int64_t is not `long` on every platform, so "%ld" is not
+            // portable. Print it through long long instead.
+            printf("%lld", (long long)i64);
+          }
+        }
+        break;
+      case apache::thrift::protocol::T_DOUBLE:
+        {
+          double dub;
+          piprot_->readDouble(dub);
+          if (print_) {
+            printf("%f", dub);
+          }
+        }
+        break;
+      case apache::thrift::protocol::T_STRING:
+        {
+          std::string str;
+          piprot_->readString(str);
+          if (print_) {
+            printf("%s", str.c_str());
+          }
+        }
+        break;
+      case apache::thrift::protocol::T_STRUCT:
+        {
+          std::string name;
+          int16_t fid;
+          apache::thrift::protocol::TType ftype;
+          piprot_->readStructBegin(name);
+          if (print_) {
+            printf("<");
+          }
+          while (true) {
+            piprot_->readFieldBegin(name, ftype, fid);
+            if (ftype == apache::thrift::protocol::T_STOP) {
+              break;
+            }
+            printAndPassToBuffer(ftype);
+            if (print_) {
+              printf(",");
+            }
+            piprot_->readFieldEnd();
+          }
+          piprot_->readStructEnd();
+          if (print_) {
+            printf("\b>");
+          }
+        }
+        break;
+      case apache::thrift::protocol::T_MAP:
+        {
+          apache::thrift::protocol::TType keyType;
+          apache::thrift::protocol::TType valType;
+          uint32_t i, size;
+          piprot_->readMapBegin(keyType, valType, size);
+          if (print_) {
+            printf("{");
+          }
+          for (i = 0; i < size; i++) {
+            printAndPassToBuffer(keyType);
+            if (print_) {
+              printf("=>");
+            }
+            printAndPassToBuffer(valType);
+            if (print_) {
+              printf(",");
+            }
+          }
+          piprot_->readMapEnd();
+          if (print_) {
+            printf("\b}");
+          }
+        }
+        break;
+      case apache::thrift::protocol::T_SET:
+        {
+          apache::thrift::protocol::TType elemType;
+          uint32_t i, size;
+          piprot_->readSetBegin(elemType, size);
+          if (print_) {
+            printf("{");
+          }
+          for (i = 0; i < size; i++) {
+            printAndPassToBuffer(elemType);
+            if (print_) {
+              printf(",");
+            }
+          }
+          piprot_->readSetEnd();
+          if (print_) {
+            printf("\b}");
+          }
+        }
+        break;
+      case apache::thrift::protocol::T_LIST:
+        {
+          apache::thrift::protocol::TType elemType;
+          uint32_t i, size;
+          piprot_->readListBegin(elemType, size);
+          if (print_) {
+            printf("[");
+          }
+          for (i = 0; i < size; i++) {
+            printAndPassToBuffer(elemType);
+            if (print_) {
+              printf(",");
+            }
+          }
+          piprot_->readListEnd();
+          if (print_) {
+            printf("\b]");
+          }
+        }
+        break;
+      default:
+        break;
+    }
+  }
+
+  boost::shared_ptr<apache::thrift::protocol::TProtocol> piprot_;
+  std::map<std::string, int64_t> frequency_map_;
+
+  bool print_;
+  bool frequency_;
+};
+
+}}} // apache::thrift::processor
+
+#endif
diff --git a/lib/cpp/src/protocol/TBase64Utils.cpp b/lib/cpp/src/protocol/TBase64Utils.cpp
new file mode 100644
index 0000000..14481c4
--- /dev/null
+++ b/lib/cpp/src/protocol/TBase64Utils.cpp
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "TBase64Utils.h"
+
+#include <boost/static_assert.hpp>
+
+using std::string;
+
+namespace apache { namespace thrift { namespace protocol {
+
+
+// Standard base64 alphabet (RFC 4648, section 4). Both the pointer and the
+// characters are const, so the table can be neither modified nor re-pointed.
+static const uint8_t* const kBase64EncodeTable = reinterpret_cast<const uint8_t*>(
+  "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/");
+
+/**
+ * Encode `len` (1, 2 or 3) bytes from `in` as len + 1 base64 characters
+ * written to buf[0..len]. No '=' padding is added.
+ */
+void base64_encode(const uint8_t *in, uint32_t len, uint8_t *buf) {
+  buf[0] = kBase64EncodeTable[(in[0] >> 2) & 0x3f];
+  if (len == 3) {
+    buf[1] = kBase64EncodeTable[((in[0] << 4) + (in[1] >> 4)) & 0x3f];
+    buf[2] = kBase64EncodeTable[((in[1] << 2) + (in[2] >> 6)) & 0x3f];
+    buf[3] = kBase64EncodeTable[in[2] & 0x3f];
+  } else if (len == 2) {
+    buf[1] = kBase64EncodeTable[((in[0] << 4) + (in[1] >> 4)) & 0x3f];
+    buf[2] = kBase64EncodeTable[(in[1] << 2) & 0x3f];
+  } else { // len == 1
+    buf[1] = kBase64EncodeTable[(in[0] << 4) & 0x3f];
+  }
+}
+
+// Maps a base64 character to its 6-bit value. 0xff marks characters outside
+// the alphabet. (Brace-initializing a uint8_t with -1 is an ill-formed
+// narrowing conversion in C++11, so 0xff is spelled out explicitly.)
+static const uint8_t kBase64DecodeTable[256] = {
+  0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,
+  0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,
+  0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,  62,0xff,0xff,0xff,  63,
+    52,  53,  54,  55,  56,  57,  58,  59,  60,  61,0xff,0xff,0xff,0xff,0xff,0xff,
+  0xff,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,  14,
+    15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,0xff,0xff,0xff,0xff,0xff,
+  0xff,  26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,
+    41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,0xff,0xff,0xff,0xff,0xff,
+  0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,
+  0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,
+  0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,
+  0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,
+  0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,
+  0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,
+  0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,
+  0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,
+};
+
+/**
+ * Decode `len` (2, 3 or 4) unpadded base64 characters in buf, in place,
+ * into len - 1 bytes written to the front of buf. Input characters are not
+ * validated.
+ */
+void base64_decode(uint8_t *buf, uint32_t len) {
+  buf[0] = (kBase64DecodeTable[buf[0]] << 2) |
+           (kBase64DecodeTable[buf[1]] >> 4);
+  if (len > 2) {
+    buf[1] = ((kBase64DecodeTable[buf[1]] << 4) & 0xf0) |
+             (kBase64DecodeTable[buf[2]] >> 2);
+    if (len > 3) {
+      buf[2] = ((kBase64DecodeTable[buf[2]] << 6) & 0xc0) |
+               (kBase64DecodeTable[buf[3]]);
+    }
+  }
+}
+
+
+}}} // apache::thrift::protocol
diff --git a/lib/cpp/src/protocol/TBase64Utils.h b/lib/cpp/src/protocol/TBase64Utils.h
new file mode 100644
index 0000000..3def733
--- /dev/null
+++ b/lib/cpp/src/protocol/TBase64Utils.h
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_PROTOCOL_TBASE64UTILS_H_
+#define _THRIFT_PROTOCOL_TBASE64UTILS_H_
+
+#include <stdint.h>
+#include <string>
+
+namespace apache { namespace thrift { namespace protocol {
+
+// in must be at least len bytes
+// len must be 1, 2, or 3
+// buf must be a buffer of at least 4 bytes and must not overlap in;
+// exactly len + 1 characters are written to the start of buf
+// the data is not padded with '='; the caller can do this if desired
+void base64_encode(const uint8_t *in, uint32_t len, uint8_t *buf);
+
+// buf must be a buffer of at least 4 bytes and contain base64 encoded values
+// buf is decoded in place: the len - 1 output bytes are written to the
+// start of buf
+// len is number of bytes to consume from input (must be 2, 3, or 4)
+// no '=' padding should be included in the input
+void base64_decode(uint8_t *buf, uint32_t len);
+
+}}} // apache::thrift::protocol
+
+#endif // _THRIFT_PROTOCOL_TBASE64UTILS_H_
diff --git a/lib/cpp/src/protocol/TBinaryProtocol.cpp b/lib/cpp/src/protocol/TBinaryProtocol.cpp
new file mode 100644
index 0000000..6a4838b
--- /dev/null
+++ b/lib/cpp/src/protocol/TBinaryProtocol.cpp
@@ -0,0 +1,394 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "TBinaryProtocol.h"
+
+#include <limits>
+
+using std::string;
+
+namespace apache { namespace thrift { namespace protocol {
+
+/**
+ * Writes a message header.
+ *
+ * Strict mode: version word (VERSION_1 OR'ed with the message type), then
+ * the name, then the sequence id. Otherwise the legacy unversioned layout
+ * is used: name, message type byte, sequence id.
+ *
+ * @return number of bytes written
+ */
+uint32_t TBinaryProtocol::writeMessageBegin(const std::string& name,
+                                            const TMessageType messageType,
+                                            const int32_t seqid) {
+  if (!strict_write_) {
+    // Legacy (pre-versioned) header.
+    uint32_t written = writeString(name);
+    written += writeByte((int8_t)messageType);
+    written += writeI32(seqid);
+    return written;
+  }
+
+  const int32_t versionWord = VERSION_1 | (int32_t)messageType;
+  uint32_t written = writeI32(versionWord);
+  written += writeString(name);
+  written += writeI32(seqid);
+  return written;
+}
+
+// Message and struct boundaries occupy no bytes in the binary encoding, so
+// these hooks write nothing and report zero bytes written.
+
+uint32_t TBinaryProtocol::writeMessageEnd() {
+  return 0;
+}
+
+uint32_t TBinaryProtocol::writeStructBegin(const char* /* name */) {
+  return 0;
+}
+
+uint32_t TBinaryProtocol::writeStructEnd() {
+  return 0;
+}
+
+/**
+ * Writes a field header: one type byte followed by the 16-bit field id.
+ * The field name is not encoded.
+ */
+uint32_t TBinaryProtocol::writeFieldBegin(const char* /* name */,
+                                          const TType fieldType,
+                                          const int16_t fieldId) {
+  const uint32_t typeBytes = writeByte((int8_t)fieldType);
+  return typeBytes + writeI16(fieldId);
+}
+
+// Nothing follows an individual field on the wire.
+uint32_t TBinaryProtocol::writeFieldEnd() {
+  return 0;
+}
+
+/**
+ * Writes the T_STOP type byte that terminates a struct's field list.
+ */
+uint32_t TBinaryProtocol::writeFieldStop() {
+  return writeByte((int8_t)T_STOP);
+}
+
+/**
+ * Writes a map header: key type byte, value type byte, then the element
+ * count as a signed 32-bit integer.
+ */
+uint32_t TBinaryProtocol::writeMapBegin(const TType keyType,
+                                        const TType valType,
+                                        const uint32_t size) {
+  uint32_t written = writeByte((int8_t)keyType);
+  written += writeByte((int8_t)valType);
+  written += writeI32((int32_t)size);
+  return written;
+}
+
+uint32_t TBinaryProtocol::writeMapEnd() {
+  return 0;
+}
+
+/**
+ * Writes a list header: element type byte, then the element count.
+ */
+uint32_t TBinaryProtocol::writeListBegin(const TType elemType,
+                                         const uint32_t size) {
+  const uint32_t typeBytes = writeByte((int8_t)elemType);
+  return typeBytes + writeI32((int32_t)size);
+}
+
+uint32_t TBinaryProtocol::writeListEnd() {
+  return 0;
+}
+
+/**
+ * Writes a set header; the layout matches a list header.
+ */
+uint32_t TBinaryProtocol::writeSetBegin(const TType elemType,
+                                        const uint32_t size) {
+  const uint32_t typeBytes = writeByte((int8_t)elemType);
+  return typeBytes + writeI32((int32_t)size);
+}
+
+uint32_t TBinaryProtocol::writeSetEnd() {
+  return 0;
+}
+
+/**
+ * Writes a bool as a single byte: 1 for true, 0 for false.
+ */
+uint32_t TBinaryProtocol::writeBool(const bool value) {
+  uint8_t encoded = 0;
+  if (value) {
+    encoded = 1;
+  }
+  trans_->write(&encoded, 1);
+  return 1;
+}
+
+/**
+ * Writes a single raw byte.
+ */
+uint32_t TBinaryProtocol::writeByte(const int8_t byte) {
+  uint8_t raw = (uint8_t)byte;
+  trans_->write(&raw, 1);
+  return 1;
+}
+
+// Fixed-width integers are written in network (big-endian) byte order.
+
+uint32_t TBinaryProtocol::writeI16(const int16_t i16) {
+  int16_t wire = (int16_t)htons(i16);
+  trans_->write((uint8_t*)&wire, (uint32_t)sizeof(wire));
+  return (uint32_t)sizeof(wire);
+}
+
+uint32_t TBinaryProtocol::writeI32(const int32_t i32) {
+  int32_t wire = (int32_t)htonl(i32);
+  trans_->write((uint8_t*)&wire, (uint32_t)sizeof(wire));
+  return (uint32_t)sizeof(wire);
+}
+
+uint32_t TBinaryProtocol::writeI64(const int64_t i64) {
+  int64_t wire = (int64_t)htonll(i64);
+  trans_->write((uint8_t*)&wire, (uint32_t)sizeof(wire));
+  return (uint32_t)sizeof(wire);
+}
+
+/**
+ * Writes a double as its 8-byte IEEE-754 bit pattern in network byte order.
+ */
+uint32_t TBinaryProtocol::writeDouble(const double dub) {
+  // The wire format assumes a 64-bit IEEE-754 double.
+  BOOST_STATIC_ASSERT(sizeof(double) == sizeof(uint64_t));
+  BOOST_STATIC_ASSERT(std::numeric_limits<double>::is_iec559);
+
+  const uint64_t hostBits = bitwise_cast<uint64_t>(dub);
+  uint64_t wire = htonll(hostBits);
+  trans_->write((uint8_t*)&wire, 8);
+  return 8;
+}
+
+
+/**
+ * Writes a string as a signed 32-bit byte length followed by the raw bytes.
+ *
+ * @return number of bytes written (4 + length)
+ * @throws TProtocolException(SIZE_LIMIT) if the string is longer than
+ *         INT32_MAX bytes. Without this check the size_t length would be
+ *         silently truncated to 32 bits, and could also be written as a
+ *         negative int32 that readers reject as NEGATIVE_SIZE.
+ */
+uint32_t TBinaryProtocol::writeString(const string& str) {
+  if (str.size() > static_cast<string::size_type>((std::numeric_limits<int32_t>::max)())) {
+    throw TProtocolException(TProtocolException::SIZE_LIMIT);
+  }
+  uint32_t size = static_cast<uint32_t>(str.size());
+  uint32_t result = writeI32((int32_t)size);
+  if (size > 0) {
+    trans_->write((uint8_t*)str.data(), size);
+  }
+  return result + size;
+}
+
+/**
+ * Binary blobs share the string encoding.
+ */
+uint32_t TBinaryProtocol::writeBinary(const string& str) {
+  return TBinaryProtocol::writeString(str);
+}
+
+/**
+ * Reading functions
+ */
+
+/**
+ * Reads a message header in the versioned format or, when strict_read_ is
+ * off, the legacy unversioned format.
+ *
+ * A negative leading int32 marks a versioned header: its VERSION_MASK bits
+ * must equal VERSION_1 and its low byte holds the message type. A
+ * non-negative leading int32 is the name length of a legacy header.
+ *
+ * @throws TProtocolException(BAD_VERSION) on a version mismatch, or on a
+ *         legacy header while strict_read_ is set
+ */
+uint32_t TBinaryProtocol::readMessageBegin(std::string& name,
+                                           TMessageType& messageType,
+                                           int32_t& seqid) {
+  int32_t header;
+  uint32_t consumed = readI32(header);
+
+  if (header >= 0) {
+    if (strict_read_) {
+      throw TProtocolException(TProtocolException::BAD_VERSION, "No version identifier... old protocol client in strict mode?");
+    }
+    // Pre-versioned input: the header is the length of the name.
+    consumed += readStringBody(name, header);
+    int8_t rawType;
+    consumed += readByte(rawType);
+    messageType = (TMessageType)rawType;
+    consumed += readI32(seqid);
+    return consumed;
+  }
+
+  if ((header & VERSION_MASK) != VERSION_1) {
+    throw TProtocolException(TProtocolException::BAD_VERSION, "Bad version identifier");
+  }
+  messageType = (TMessageType)(header & 0x000000ff);
+  consumed += readString(name);
+  consumed += readI32(seqid);
+  return consumed;
+}
+
+// Message and struct boundaries occupy no bytes, so these readers consume
+// nothing. Struct names are not on the wire; readStructBegin yields "".
+
+uint32_t TBinaryProtocol::readMessageEnd() {
+  return 0;
+}
+
+uint32_t TBinaryProtocol::readStructBegin(string& name) {
+  name.clear();
+  return 0;
+}
+
+uint32_t TBinaryProtocol::readStructEnd() {
+  return 0;
+}
+
+/**
+ * Reads a field header: a type byte, then (unless the type is T_STOP) the
+ * 16-bit field id. Field names are not on the wire, so name is left as is.
+ * For T_STOP, fieldId is set to 0.
+ */
+uint32_t TBinaryProtocol::readFieldBegin(string& /* name */,
+                                         TType& fieldType,
+                                         int16_t& fieldId) {
+  int8_t rawType;
+  uint32_t consumed = readByte(rawType);
+  fieldType = (TType)rawType;
+  if (fieldType != T_STOP) {
+    consumed += readI16(fieldId);
+  } else {
+    fieldId = 0;
+  }
+  return consumed;
+}
+
+// Nothing follows an individual field on the wire.
+uint32_t TBinaryProtocol::readFieldEnd() {
+  return 0;
+}
+
+/**
+ * Validates a container element count read from the wire.
+ *
+ * @param count the signed count as read
+ * @param limit the configured maximum; 0 disables the limit check
+ * @throws TProtocolException(NEGATIVE_SIZE) if count is negative
+ * @throws TProtocolException(SIZE_LIMIT) if a limit is set and exceeded
+ */
+static void validateBinaryContainerSize(int32_t count, int32_t limit) {
+  if (count < 0) {
+    throw TProtocolException(TProtocolException::NEGATIVE_SIZE);
+  }
+  if (limit && count > limit) {
+    throw TProtocolException(TProtocolException::SIZE_LIMIT);
+  }
+}
+
+/**
+ * Reads a map header: key type byte, value type byte, signed 32-bit count.
+ */
+uint32_t TBinaryProtocol::readMapBegin(TType& keyType,
+                                       TType& valType,
+                                       uint32_t& size) {
+  int8_t rawKey;
+  int8_t rawVal;
+  int32_t count;
+  uint32_t consumed = readByte(rawKey);
+  keyType = (TType)rawKey;
+  consumed += readByte(rawVal);
+  valType = (TType)rawVal;
+  consumed += readI32(count);
+  validateBinaryContainerSize(count, container_limit_);
+  size = (uint32_t)count;
+  return consumed;
+}
+
+uint32_t TBinaryProtocol::readMapEnd() {
+  return 0;
+}
+
+/**
+ * Reads a list header: element type byte, signed 32-bit count.
+ */
+uint32_t TBinaryProtocol::readListBegin(TType& elemType,
+                                        uint32_t& size) {
+  int8_t rawElem;
+  int32_t count;
+  uint32_t consumed = readByte(rawElem);
+  elemType = (TType)rawElem;
+  consumed += readI32(count);
+  validateBinaryContainerSize(count, container_limit_);
+  size = (uint32_t)count;
+  return consumed;
+}
+
+uint32_t TBinaryProtocol::readListEnd() {
+  return 0;
+}
+
+/**
+ * Reads a set header; the layout matches a list header.
+ */
+uint32_t TBinaryProtocol::readSetBegin(TType& elemType,
+                                       uint32_t& size) {
+  int8_t rawElem;
+  int32_t count;
+  uint32_t consumed = readByte(rawElem);
+  elemType = (TType)rawElem;
+  consumed += readI32(count);
+  validateBinaryContainerSize(count, container_limit_);
+  size = (uint32_t)count;
+  return consumed;
+}
+
+uint32_t TBinaryProtocol::readSetEnd() {
+  return 0;
+}
+
+/**
+ * Reads one byte; any non-zero value is true.
+ */
+uint32_t TBinaryProtocol::readBool(bool& value) {
+  uint8_t raw;
+  trans_->readAll(&raw, 1);
+  value = (raw != 0);
+  return 1;
+}
+
+/**
+ * Reads a single raw byte.
+ */
+uint32_t TBinaryProtocol::readByte(int8_t& byte) {
+  uint8_t raw;
+  trans_->readAll(&raw, 1);
+  byte = (int8_t)raw;
+  return 1;
+}
+
+/**
+ * Decodes 8 big-endian (network order) bytes into a uint64_t.
+ */
+static uint64_t decodeBigEndian64(const uint8_t* b) {
+  uint64_t v = 0;
+  for (int i = 0; i < 8; ++i) {
+    v = (v << 8) | b[i];
+  }
+  return v;
+}
+
+// The fixed-width readers below decode network (big-endian) byte order by
+// assembling the bytes explicitly. The previous form, *(intN_t*)b,
+// accessed a uint8_t array through an integer lvalue (a strict-aliasing
+// violation) and could be a misaligned load on strict-alignment CPUs.
+
+uint32_t TBinaryProtocol::readI16(int16_t& i16) {
+  uint8_t b[2];
+  trans_->readAll(b, 2);
+  i16 = (int16_t)(((uint16_t)b[0] << 8) | (uint16_t)b[1]);
+  return 2;
+}
+
+uint32_t TBinaryProtocol::readI32(int32_t& i32) {
+  uint8_t b[4];
+  trans_->readAll(b, 4);
+  i32 = (int32_t)(((uint32_t)b[0] << 24) |
+                  ((uint32_t)b[1] << 16) |
+                  ((uint32_t)b[2] << 8) |
+                  (uint32_t)b[3]);
+  return 4;
+}
+
+uint32_t TBinaryProtocol::readI64(int64_t& i64) {
+  uint8_t b[8];
+  trans_->readAll(b, 8);
+  i64 = (int64_t)decodeBigEndian64(b);
+  return 8;
+}
+
+/**
+ * Reads an 8-byte IEEE-754 double in network byte order.
+ */
+uint32_t TBinaryProtocol::readDouble(double& dub) {
+  BOOST_STATIC_ASSERT(sizeof(double) == sizeof(uint64_t));
+  BOOST_STATIC_ASSERT(std::numeric_limits<double>::is_iec559);
+
+  uint8_t b[8];
+  trans_->readAll(b, 8);
+  uint64_t bits = decodeBigEndian64(b);
+  dub = bitwise_cast<double>(bits);
+  return 8;
+}
+
+/**
+ * Reads a string encoded as a signed 32-bit length followed by that many
+ * bytes. Length validation is done by readStringBody.
+ */
+uint32_t TBinaryProtocol::readString(string& str) {
+  int32_t length;
+  const uint32_t lengthBytes = readI32(length);
+  return lengthBytes + readStringBody(str, length);
+}
+
+// Binary blobs share the string encoding.
+uint32_t TBinaryProtocol::readBinary(string& str) {
+  return TBinaryProtocol::readString(str);
+}
+
+/**
+ * Reads exactly size bytes from the transport into str.
+ *
+ * The bytes are staged in string_buf_, which is grown with realloc when
+ * too small and kept for the lifetime of the protocol, so repeated reads
+ * reuse it instead of allocating.
+ *
+ * @return number of bytes consumed (size)
+ * @throws TProtocolException(NEGATIVE_SIZE) if size < 0
+ * @throws TProtocolException(SIZE_LIMIT) if string_limit_ > 0 and exceeded
+ * @throws TProtocolException(UNKNOWN) if the staging buffer cannot be grown
+ */
+uint32_t TBinaryProtocol::readStringBody(string& str, int32_t size) {
+  // Catch error cases
+  if (size < 0) {
+    throw TProtocolException(TProtocolException::NEGATIVE_SIZE);
+  }
+  if (string_limit_ > 0 && size > string_limit_) {
+    throw TProtocolException(TProtocolException::SIZE_LIMIT);
+  }
+
+  // Catch empty string case
+  if (size == 0) {
+    str.clear();
+    return 0;
+  }
+
+  // Use the heap here to prevent stack overflow for v. large strings
+  if (size > string_buf_size_ || string_buf_ == NULL) {
+    void* new_string_buf = std::realloc(string_buf_, (uint32_t)size);
+    if (new_string_buf == NULL) {
+      throw TProtocolException(TProtocolException::UNKNOWN, "Out of memory in TBinaryProtocol::readString");
+    }
+    string_buf_ = (uint8_t*)new_string_buf;
+    string_buf_size_ = size;
+  }
+  trans_->readAll(string_buf_, size);
+  // assign() copies directly into str's storage. The previous
+  // "str = string(...)" built a temporary first, which costs an extra
+  // copy before C++11 move assignment.
+  str.assign((char*)string_buf_, size);
+  return (uint32_t)size;
+}
+
+}}} // apache::thrift::protocol
diff --git a/lib/cpp/src/protocol/TBinaryProtocol.h b/lib/cpp/src/protocol/TBinaryProtocol.h
new file mode 100644
index 0000000..7fd3de6
--- /dev/null
+++ b/lib/cpp/src/protocol/TBinaryProtocol.h
@@ -0,0 +1,254 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_PROTOCOL_TBINARYPROTOCOL_H_
+#define _THRIFT_PROTOCOL_TBINARYPROTOCOL_H_ 1
+
+#include "TProtocol.h"
+
+#include <boost/shared_ptr.hpp>
+
+namespace apache { namespace thrift { namespace protocol {
+
+/**
+ * The default binary protocol for thrift. Writes all data in a very basic
+ * binary format, essentially just spitting out the raw bytes.
+ *
+ * Fixed-width integers and doubles are written in network (big-endian)
+ * byte order; strings and containers are prefixed with a signed 32-bit
+ * length/count.
+ */
+class TBinaryProtocol : public TProtocol {
+ protected:
+  // 0xffff0000 and 0x80010000 exceed INT32_MAX. The explicit casts make
+  // the intended two's-complement values explicit instead of relying on an
+  // implementation-defined implicit conversion, which also draws compiler
+  // overflow warnings.
+  static const int32_t VERSION_MASK = ((int32_t)0xffff0000);
+  static const int32_t VERSION_1 = ((int32_t)0x80010000);
+  // VERSION_2 (0x80020000) is taken by TDenseProtocol.
+
+ public:
+  /**
+   * Creates a protocol with no string or container size limits, lenient
+   * reads (unversioned message headers accepted) and versioned writes.
+   */
+  TBinaryProtocol(boost::shared_ptr<TTransport> trans) :
+    TProtocol(trans),
+    string_limit_(0),
+    container_limit_(0),
+    strict_read_(false),
+    strict_write_(true),
+    string_buf_(NULL),
+    string_buf_size_(0) {}
+
+  /**
+   * @param trans           transport to read from and write to
+   * @param string_limit    max string length accepted on read; 0 = no limit
+   * @param container_limit max container size accepted on read; 0 = no limit
+   * @param strict_read     reject message headers without a version word
+   * @param strict_write    write versioned message headers
+   */
+  TBinaryProtocol(boost::shared_ptr<TTransport> trans,
+                  int32_t string_limit,
+                  int32_t container_limit,
+                  bool strict_read,
+                  bool strict_write) :
+    TProtocol(trans),
+    string_limit_(string_limit),
+    container_limit_(container_limit),
+    strict_read_(strict_read),
+    strict_write_(strict_write),
+    string_buf_(NULL),
+    string_buf_size_(0) {}
+
+  // Releases the string staging buffer.
+  ~TBinaryProtocol() {
+    if (string_buf_ != NULL) {
+      std::free(string_buf_);
+      string_buf_size_ = 0;
+    }
+  }
+
+  void setStringSizeLimit(int32_t string_limit) {
+    string_limit_ = string_limit;
+  }
+
+  void setContainerSizeLimit(int32_t container_limit) {
+    container_limit_ = container_limit;
+  }
+
+  void setStrict(bool strict_read, bool strict_write) {
+    strict_read_ = strict_read;
+    strict_write_ = strict_write;
+  }
+
+  /**
+   * Writing functions.
+   */
+
+  virtual uint32_t writeMessageBegin(const std::string& name,
+                                     const TMessageType messageType,
+                                     const int32_t seqid);
+
+  virtual uint32_t writeMessageEnd();
+
+
+  uint32_t writeStructBegin(const char* name);
+
+  uint32_t writeStructEnd();
+
+  uint32_t writeFieldBegin(const char* name,
+                           const TType fieldType,
+                           const int16_t fieldId);
+
+  uint32_t writeFieldEnd();
+
+  uint32_t writeFieldStop();
+
+  uint32_t writeMapBegin(const TType keyType,
+                         const TType valType,
+                         const uint32_t size);
+
+  uint32_t writeMapEnd();
+
+  uint32_t writeListBegin(const TType elemType,
+                          const uint32_t size);
+
+  uint32_t writeListEnd();
+
+  uint32_t writeSetBegin(const TType elemType,
+                         const uint32_t size);
+
+  uint32_t writeSetEnd();
+
+  uint32_t writeBool(const bool value);
+
+  uint32_t writeByte(const int8_t byte);
+
+  uint32_t writeI16(const int16_t i16);
+
+  uint32_t writeI32(const int32_t i32);
+
+  uint32_t writeI64(const int64_t i64);
+
+  uint32_t writeDouble(const double dub);
+
+  uint32_t writeString(const std::string& str);
+
+  uint32_t writeBinary(const std::string& str);
+
+  /**
+   * Reading functions
+   */
+
+
+  uint32_t readMessageBegin(std::string& name,
+                            TMessageType& messageType,
+                            int32_t& seqid);
+
+  uint32_t readMessageEnd();
+
+  uint32_t readStructBegin(std::string& name);
+
+  uint32_t readStructEnd();
+
+  uint32_t readFieldBegin(std::string& name,
+                          TType& fieldType,
+                          int16_t& fieldId);
+
+  uint32_t readFieldEnd();
+
+  uint32_t readMapBegin(TType& keyType,
+                        TType& valType,
+                        uint32_t& size);
+
+  uint32_t readMapEnd();
+
+  uint32_t readListBegin(TType& elemType,
+                         uint32_t& size);
+
+  uint32_t readListEnd();
+
+  uint32_t readSetBegin(TType& elemType,
+                        uint32_t& size);
+
+  uint32_t readSetEnd();
+
+  uint32_t readBool(bool& value);
+
+  uint32_t readByte(int8_t& byte);
+
+  uint32_t readI16(int16_t& i16);
+
+  uint32_t readI32(int32_t& i32);
+
+  uint32_t readI64(int64_t& i64);
+
+  uint32_t readDouble(double& dub);
+
+  uint32_t readString(std::string& str);
+
+  uint32_t readBinary(std::string& str);
+
+ protected:
+  uint32_t readStringBody(std::string& str, int32_t sz);
+
+  // Read-side size limits; 0 disables the corresponding check.
+  int32_t string_limit_;
+  int32_t container_limit_;
+
+  // strict_read_: enforce presence of the version identifier on read.
+  // strict_write_: write the versioned message header format.
+  bool strict_read_;
+  bool strict_write_;
+
+  // Buffer for reading strings, saved for the lifetime of the protocol to
+  // avoid memory churn allocating memory on every string read. Owned by
+  // this object: grown with realloc, released with free in the destructor.
+  uint8_t* string_buf_;
+  int32_t string_buf_size_;
+
+ private:
+  // A memberwise copy would share string_buf_ and free it twice, so
+  // copying is disabled: declared private and never defined (pre-C++11
+  // equivalent of "= delete").
+  TBinaryProtocol(const TBinaryProtocol&);
+  TBinaryProtocol& operator=(const TBinaryProtocol&);
+};
+
+/**
+ * Factory producing TBinaryProtocol instances that share this factory's
+ * size-limit and strictness settings. The defaults match TBinaryProtocol's
+ * single-argument constructor: no limits, lenient reads, strict writes.
+ */
+class TBinaryProtocolFactory : public TProtocolFactory {
+ public:
+  TBinaryProtocolFactory()
+    : string_limit_(0)
+    , container_limit_(0)
+    , strict_read_(false)
+    , strict_write_(true) {}
+
+  TBinaryProtocolFactory(int32_t string_limit,
+                         int32_t container_limit,
+                         bool strict_read,
+                         bool strict_write)
+    : string_limit_(string_limit)
+    , container_limit_(container_limit)
+    , strict_read_(strict_read)
+    , strict_write_(strict_write) {}
+
+  virtual ~TBinaryProtocolFactory() {}
+
+  /** Sets the string limit given to protocols created from now on. */
+  void setStringSizeLimit(int32_t string_limit) { string_limit_ = string_limit; }
+
+  /** Sets the container limit given to protocols created from now on. */
+  void setContainerSizeLimit(int32_t container_limit) { container_limit_ = container_limit; }
+
+  /** Sets the strictness flags given to protocols created from now on. */
+  void setStrict(bool strict_read, bool strict_write) {
+    strict_write_ = strict_write;
+    strict_read_ = strict_read;
+  }
+
+  /** Creates a protocol over trans using the current settings. */
+  boost::shared_ptr<TProtocol> getProtocol(boost::shared_ptr<TTransport> trans) {
+    TBinaryProtocol* protocol = new TBinaryProtocol(trans, string_limit_, container_limit_, strict_read_, strict_write_);
+    return boost::shared_ptr<TProtocol>(protocol);
+  }
+
+ private:
+  int32_t string_limit_;
+  int32_t container_limit_;
+  bool strict_read_;
+  bool strict_write_;
+};
+
+}}} // apache::thrift::protocol
+
+#endif // #ifndef _THRIFT_PROTOCOL_TBINARYPROTOCOL_H_
diff --git a/lib/cpp/src/protocol/TCompactProtocol.cpp b/lib/cpp/src/protocol/TCompactProtocol.cpp
new file mode 100644
index 0000000..ce2ee54
--- /dev/null
+++ b/lib/cpp/src/protocol/TCompactProtocol.cpp
@@ -0,0 +1,736 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "TCompactProtocol.h"
+
+#include <config.h>
+#include <limits>
+#include <sstream>
+
+/*
+ * TCompactProtocol::i*ToZigzag depend on the fact that the right shift
+ * operator on a signed integer is an arithmetic (sign-extending) shift.
+ * If this is not the case, the current implementation will not work.
+ * If anyone encounters this error, we can try to figure out the best
+ * way to implement an arithmetic right shift on their platform.
+ */
+#if !defined(SIGNED_RIGHT_SHIFT_IS) || !defined(ARITHMETIC_RIGHT_SHIFT)
+# error "Unable to determine the behavior of a signed right shift"
+#endif
+#if SIGNED_RIGHT_SHIFT_IS != ARITHMETIC_RIGHT_SHIFT
+# error "TCompactProtocol currenly only works if a signed right shift is arithmetic"
+#endif
+
+#ifdef __GNUC__
+#define UNLIKELY(val) (__builtin_expect((val), 0))
+#else
+#define UNLIKELY(val) (val)
+#endif
+
+namespace apache { namespace thrift { namespace protocol {
+
+// Maps a TType (used as the index) to its compact-protocol type code.
+// Indices that are not TTypes handled by this protocol hold 0. T_BOOL maps
+// to CT_BOOLEAN_TRUE here; bool *fields* instead encode their value in the
+// field header (see writeBool).
+const int8_t TCompactProtocol::TTypeToCType[16] = {
+  /*  0 T_STOP   */ CT_STOP,
+  /*  1 unused   */ 0,
+  /*  2 T_BOOL   */ CT_BOOLEAN_TRUE,
+  /*  3 T_BYTE   */ CT_BYTE,
+  /*  4 T_DOUBLE */ CT_DOUBLE,
+  /*  5 unused   */ 0,
+  /*  6 T_I16    */ CT_I16,
+  /*  7 unused   */ 0,
+  /*  8 T_I32    */ CT_I32,
+  /*  9 unused   */ 0,
+  /* 10 T_I64    */ CT_I64,
+  /* 11 T_STRING */ CT_BINARY,
+  /* 12 T_STRUCT */ CT_STRUCT,
+  /* 13 T_MAP    */ CT_MAP,
+  /* 14 T_SET    */ CT_SET,
+  /* 15 T_LIST   */ CT_LIST,
+};
+
+
+/**
+ * Writes the message header:
+ *  1. the protocol id byte;
+ *  2. one byte packing the version (VERSION_MASK bits) with the message
+ *     type (TYPE_MASK bits);
+ *  3. the sequence id as a varint;
+ *  4. the method name.
+ */
+uint32_t TCompactProtocol::writeMessageBegin(const std::string& name,
+                                             const TMessageType messageType,
+                                             const int32_t seqid) {
+  const int8_t versionAndType =
+      (int8_t)((VERSION_N & VERSION_MASK) |
+               (((int32_t)messageType << TYPE_SHIFT_AMOUNT) & TYPE_MASK));
+  uint32_t written = writeByte(PROTOCOL_ID);
+  written += writeByte(versionAndType);
+  written += writeVarint32(seqid);
+  written += writeString(name);
+  return written;
+}
+
+/**
+ * Writes a field header (see writeFieldBeginInternal for the encoding).
+ *
+ * Bool field headers are deferred: the field info is stashed in
+ * booleanField_ so writeBool can fold the value into the header's type
+ * code. All other field types are written immediately.
+ *
+ * @return bytes written (0 for a deferred bool header)
+ */
+uint32_t TCompactProtocol::writeFieldBegin(const char* name,
+                                           const TType fieldType,
+                                           const int16_t fieldId) {
+  if (fieldType != T_BOOL) {
+    return writeFieldBeginInternal(name, fieldType, fieldId, -1);
+  }
+  booleanField_.name = name;
+  booleanField_.fieldType = fieldType;
+  booleanField_.fieldId = fieldId;
+  return 0;
+}
+
+/**
+ * Writes the stop byte that terminates a struct's field list.
+ */
+uint32_t TCompactProtocol::writeFieldStop() {
+  return writeByte(T_STOP);
+}
+
+/**
+ * Struct boundaries produce no bytes. Entering a struct pushes the
+ * enclosing struct's last field id onto lastField_ and restarts field-id
+ * delta encoding from 0.
+ */
+uint32_t TCompactProtocol::writeStructBegin(const char* /* name */) {
+  lastField_.push(lastFieldId_);
+  lastFieldId_ = 0;
+  return 0;
+}
+
+/**
+ * Leaving a struct restores the enclosing struct's last field id.
+ */
+uint32_t TCompactProtocol::writeStructEnd() {
+  lastFieldId_ = lastField_.top();
+  lastField_.pop();
+  return 0;
+}
+
+/**
+ * Lists and sets share one header encoding; see writeCollectionBegin.
+ */
+uint32_t TCompactProtocol::writeListBegin(const TType elemType,
+                                          const uint32_t size) {
+  return writeCollectionBegin(elemType, size);
+}
+
+uint32_t TCompactProtocol::writeSetBegin(const TType elemType,
+                                         const uint32_t size) {
+  return writeCollectionBegin(elemType, size);
+}
+
+/**
+ * Writes a map header. An empty map is a single 0 byte with no key/value
+ * types, since nothing more is needed to skip it. Otherwise the size is
+ * written as a varint, followed by one byte with the key type in the high
+ * nibble and the value type in the low nibble.
+ */
+uint32_t TCompactProtocol::writeMapBegin(const TType keyType,
+                                         const TType valType,
+                                         const uint32_t size) {
+  if (size == 0) {
+    return writeByte(0);
+  }
+  uint32_t written = writeVarint32(size);
+  written += writeByte(getCompactType(keyType) << 4 | getCompactType(valType));
+  return written;
+}
+
+/**
+ * Writes a bool.
+ *
+ * If a bool field header is pending (booleanField_.name set by
+ * writeFieldBegin), the value is encoded as that header's type code and
+ * the pending state is cleared. Otherwise a standalone byte holding
+ * CT_BOOLEAN_TRUE or CT_BOOLEAN_FALSE is written.
+ */
+uint32_t TCompactProtocol::writeBool(const bool value) {
+  const int8_t code = value ? CT_BOOLEAN_TRUE : CT_BOOLEAN_FALSE;
+  if (booleanField_.name == NULL) {
+    return writeByte(code);
+  }
+  const uint32_t written = writeFieldBeginInternal(booleanField_.name,
+                                                   booleanField_.fieldType,
+                                                   booleanField_.fieldId,
+                                                   code);
+  booleanField_.name = NULL;
+  return written;
+}
+
+/**
+ * Writes a single raw byte.
+ */
+uint32_t TCompactProtocol::writeByte(const int8_t byte) {
+  uint8_t raw = (uint8_t)byte;
+  trans_->write(&raw, 1);
+  return 1;
+}
+
+// Signed integers are zigzag-mapped (small magnitudes become small
+// unsigned values) and then written as varints. i16 uses the 32-bit path.
+
+uint32_t TCompactProtocol::writeI16(const int16_t i16) {
+  const uint32_t zz = i32ToZigzag(i16);
+  return writeVarint32(zz);
+}
+
+uint32_t TCompactProtocol::writeI32(const int32_t i32) {
+  const uint32_t zz = i32ToZigzag(i32);
+  return writeVarint32(zz);
+}
+
+uint32_t TCompactProtocol::writeI64(const int64_t i64) {
+  const uint64_t zz = i64ToZigzag(i64);
+  return writeVarint64(zz);
+}
+
+/**
+ * Writes a double as its 8-byte IEEE-754 bit pattern in little-endian
+ * byte order.
+ */
+uint32_t TCompactProtocol::writeDouble(const double dub) {
+  // The wire format assumes a 64-bit IEEE-754 double.
+  BOOST_STATIC_ASSERT(sizeof(double) == sizeof(uint64_t));
+  BOOST_STATIC_ASSERT(std::numeric_limits<double>::is_iec559);
+
+  const uint64_t hostBits = bitwise_cast<uint64_t>(dub);
+  uint64_t wire = htolell(hostBits);
+  trans_->write((uint8_t*)&wire, 8);
+  return 8;
+}
+
+/**
+ * Strings share the binary encoding: varint length, then raw bytes.
+ */
+uint32_t TCompactProtocol::writeString(const std::string& str) {
+  return writeBinary(str);
+}
+
+/**
+ * Writes a byte string as a varint length followed by the raw bytes.
+ *
+ * @return number of bytes written
+ * @throws TProtocolException(SIZE_LIMIT) if the string is longer than
+ *         INT32_MAX bytes. readBinary decodes the length as int32_t, so a
+ *         longer string would have its size_t length silently truncated
+ *         here or read back as negative.
+ */
+uint32_t TCompactProtocol::writeBinary(const std::string& str) {
+  if (str.size() > static_cast<std::string::size_type>((std::numeric_limits<int32_t>::max)())) {
+    throw TProtocolException(TProtocolException::SIZE_LIMIT);
+  }
+  uint32_t ssize = static_cast<uint32_t>(str.size());
+  uint32_t wsize = writeVarint32(ssize) + ssize;
+  trans_->write((uint8_t*)str.data(), ssize);
+  return wsize;
+}
+
+//
+// Internal Writing methods
+//
+
+/**
+ * Writes a field header, optionally overriding the compact type code.
+ *
+ * If the id is greater than the previous field id in this struct by at
+ * most 15, one byte carries the delta (high nibble) and the type code (low
+ * nibble). Otherwise the type byte is followed by the full id as a zigzag
+ * varint (via writeI16).
+ *
+ * @param typeOverride compact type code to write instead of the one
+ *        derived from fieldType, or -1 for none (used for bool fields)
+ * @return bytes written
+ */
+int32_t TCompactProtocol::writeFieldBeginInternal(const char* /* name */,
+                                                  const TType fieldType,
+                                                  const int16_t fieldId,
+                                                  int8_t typeOverride) {
+  const int8_t typeCode =
+      (typeOverride != -1) ? typeOverride : getCompactType(fieldType);
+  const int delta = fieldId - lastFieldId_;
+
+  uint32_t written;
+  if (fieldId > lastFieldId_ && delta <= 15) {
+    written = writeByte(delta << 4 | typeCode);
+  } else {
+    written = writeByte(typeCode);
+    written += writeI16(fieldId);
+  }
+
+  lastFieldId_ = fieldId;
+  return written;
+}
+
+/**
+ * Writes a list/set header (the two differ only in the type indicator
+ * passed in). Sizes 0-14 are packed into one byte with the size in the
+ * high nibble and the element type in the low nibble. Larger sizes write
+ * 0xF in the high nibble, followed by the size as a varint.
+ *
+ * The size is compared as unsigned. Callers pass a uint32_t, so a count
+ * of 2^31 or more arrives here negative. Treated as signed, it would take
+ * the packed path and left-shift a negative value (undefined behavior),
+ * corrupting the header byte.
+ */
+uint32_t TCompactProtocol::writeCollectionBegin(int8_t elemType, int32_t size) {
+  const uint32_t usize = (uint32_t)size;
+  uint32_t wsize = 0;
+  if (usize <= 14) {
+    wsize += writeByte((int8_t)(usize << 4 | getCompactType(elemType)));
+  } else {
+    wsize += writeByte(0xf0 | getCompactType(elemType));
+    wsize += writeVarint32(usize);
+  }
+  return wsize;
+}
+
+/**
+ * Writes n as a base-128 varint: 7 payload bits per byte, least significant
+ * group first, high bit set on every byte except the last. 1-5 bytes.
+ */
+uint32_t TCompactProtocol::writeVarint32(uint32_t n) {
+  uint8_t buf[5];
+  uint32_t len = 0;
+  while (n > 0x7F) {
+    buf[len++] = (uint8_t)((n & 0x7F) | 0x80);
+    n >>= 7;
+  }
+  buf[len++] = (uint8_t)n;
+  trans_->write(buf, len);
+  return len;
+}
+
+/**
+ * 64-bit variant of writeVarint32. 1-10 bytes.
+ */
+uint32_t TCompactProtocol::writeVarint64(uint64_t n) {
+  uint8_t buf[10];
+  uint32_t len = 0;
+  while (n > 0x7F) {
+    buf[len++] = (uint8_t)((n & 0x7F) | 0x80);
+    n >>= 7;
+  }
+  buf[len++] = (uint8_t)n;
+  trans_->write(buf, len);
+  return len;
+}
+
+/**
+ * Zigzag-encodes a signed value so that small magnitudes, positive or
+ * negative, become small unsigned values (0->0, -1->1, 1->2, -2->3, ...)
+ * and stay short as varints.
+ *
+ * The left shift is applied to the unsigned representation: left-shifting
+ * a negative signed integer is undefined behavior in C++. The right shift
+ * relies on the arithmetic (sign-extending) shift checked by the #error
+ * guards at the top of this file.
+ */
+uint64_t TCompactProtocol::i64ToZigzag(const int64_t l) {
+  return ((uint64_t)l << 1) ^ (uint64_t)(l >> 63);
+}
+
+/**
+ * 32-bit variant of i64ToZigzag.
+ */
+uint32_t TCompactProtocol::i32ToZigzag(const int32_t n) {
+  return ((uint32_t)n << 1) ^ (uint32_t)(n >> 31);
+}
+
+/**
+ * Given a TType value, finds the matching compact type code via
+ * TTypeToCType.
+ *
+ * @throws TException if ttype is outside the 16-entry table; indexing with
+ *         it would read out of bounds.
+ */
+int8_t TCompactProtocol::getCompactType(int8_t ttype) {
+  // 16 matches the size of the TTypeToCType definition in this file.
+  if (ttype < 0 || ttype >= 16) {
+    std::ostringstream msg;
+    msg << "don't know what type: " << (int)ttype;
+    throw TException(msg.str());
+  }
+  return TTypeToCType[ttype];
+}
+
+//
+// Reading Methods
+//
+
+/**
+ * Reads a message header: protocol id byte, a byte packing the version
+ * (VERSION_MASK bits) with the message type (TYPE_MASK bits), the sequence
+ * id as a varint, then the method name.
+ *
+ * @throws TProtocolException(BAD_VERSION) on a wrong protocol id or version
+ */
+uint32_t TCompactProtocol::readMessageBegin(std::string& name,
+                                            TMessageType& messageType,
+                                            int32_t& seqid) {
+  uint32_t rsize = 0;
+  int8_t protocolId;
+  int8_t versionAndType;
+  int8_t version;
+
+  rsize += readByte(protocolId);
+  if (protocolId != PROTOCOL_ID) {
+    throw TProtocolException(TProtocolException::BAD_VERSION, "Bad protocol identifier");
+  }
+
+  rsize += readByte(versionAndType);
+  version = (int8_t)(versionAndType & VERSION_MASK);
+  if (version != VERSION_N) {
+    throw TProtocolException(TProtocolException::BAD_VERSION, "Bad protocol version");
+  }
+
+  // Extract every bit selected by TYPE_MASK, mirroring writeMessageBegin.
+  // The previous "& 0x03" after the shift dropped the top type bit, so
+  // type values 4-7 decoded as 0-3.
+  messageType = (TMessageType)(((uint8_t)versionAndType & (uint8_t)TYPE_MASK) >> TYPE_SHIFT_AMOUNT);
+  rsize += readVarint32(seqid);
+  rsize += readString(name);
+
+  return rsize;
+}
+
+/**
+ * Struct boundaries consume no bytes. Entering a struct pushes the
+ * enclosing struct's last field id onto lastField_ and restarts field-id
+ * delta decoding from 0. Struct names are not on the wire, so name is
+ * set to "".
+ */
+uint32_t TCompactProtocol::readStructBegin(std::string& name) {
+  name.clear();
+  lastField_.push(lastFieldId_);
+  lastFieldId_ = 0;
+  return 0;
+}
+
+/**
+ * Leaving a struct restores the enclosing struct's last field id.
+ */
+uint32_t TCompactProtocol::readStructEnd() {
+  lastFieldId_ = lastField_.top();
+  lastField_.pop();
+  return 0;
+}
+
+/**
+ * Reads a field header.
+ *
+ * The low nibble of the first byte is the compact type code. A non-zero
+ * high nibble is the delta from the previous field id; a zero high nibble
+ * means the full id follows as a zigzag varint. Bool fields carry their
+ * value in the type code, which is stashed in boolValue_ for readBool.
+ * Field names are not on the wire, so name is left as is.
+ */
+uint32_t TCompactProtocol::readFieldBegin(std::string& /* name */,
+                                          TType& fieldType,
+                                          int16_t& fieldId) {
+  int8_t header;
+  uint32_t consumed = readByte(header);
+  const int8_t typeCode = (int8_t)(header & 0x0f);
+
+  // A stop byte ends the struct; there is no id to read.
+  if (typeCode == T_STOP) {
+    fieldType = T_STOP;
+    fieldId = 0;
+    return consumed;
+  }
+
+  const int16_t delta = (int16_t)(((uint8_t)header & 0xf0) >> 4);
+  if (delta != 0) {
+    fieldId = (int16_t)(lastFieldId_ + delta);
+  } else {
+    consumed += readI16(fieldId);
+  }
+  fieldType = getTType(typeCode);
+
+  if (typeCode == CT_BOOLEAN_TRUE || typeCode == CT_BOOLEAN_FALSE) {
+    boolValue_.hasBoolValue = true;
+    boolValue_.boolValue = (typeCode == CT_BOOLEAN_TRUE);
+  }
+
+  // Remember this id so the next field's delta is relative to it.
+  lastFieldId_ = fieldId;
+  return consumed;
+}
+
+/**
+ * Reads a map header: a varint size, then (only for non-empty maps) one
+ * byte with the key type in the high nibble and the value type in the low
+ * nibble. An empty map therefore reports T_STOP for both types.
+ *
+ * @throws TProtocolException(NEGATIVE_SIZE) for a negative size
+ * @throws TProtocolException(SIZE_LIMIT) if container_limit_ is exceeded
+ */
+uint32_t TCompactProtocol::readMapBegin(TType& keyType,
+                                        TType& valType,
+                                        uint32_t& size) {
+  int32_t count = 0;
+  int8_t types = 0;
+  uint32_t consumed = readVarint32(count);
+  if (count != 0) {
+    consumed += readByte(types);
+  }
+
+  if (count < 0) {
+    throw TProtocolException(TProtocolException::NEGATIVE_SIZE);
+  }
+  if (container_limit_ && count > container_limit_) {
+    throw TProtocolException(TProtocolException::SIZE_LIMIT);
+  }
+
+  keyType = getTType((int8_t)((uint8_t)types >> 4));
+  valType = getTType((int8_t)((uint8_t)types & 0xf));
+  size = (uint32_t)count;
+  return consumed;
+}
+
+/**
+ * Reads a list header. The high nibble of the first byte holds a size of
+ * 0-14, or 15 to mean the true size follows as a varint. The low nibble
+ * holds the element type.
+ *
+ * @throws TProtocolException(NEGATIVE_SIZE) for a negative size
+ * @throws TProtocolException(SIZE_LIMIT) if container_limit_ is exceeded
+ */
+uint32_t TCompactProtocol::readListBegin(TType& elemType,
+                                         uint32_t& size) {
+  int8_t header;
+  uint32_t consumed = readByte(header);
+
+  int32_t count = ((uint8_t)header >> 4) & 0x0f;
+  if (count == 15) {
+    consumed += readVarint32(count);
+  }
+
+  if (count < 0) {
+    throw TProtocolException(TProtocolException::NEGATIVE_SIZE);
+  }
+  if (container_limit_ && count > container_limit_) {
+    throw TProtocolException(TProtocolException::SIZE_LIMIT);
+  }
+
+  elemType = getTType((int8_t)(header & 0x0f));
+  size = (uint32_t)count;
+  return consumed;
+}
+
+/**
+ * Set headers are encoded exactly like list headers.
+ */
+uint32_t TCompactProtocol::readSetBegin(TType& elemType,
+                                        uint32_t& size) {
+  return readListBegin(elemType, size);
+}
+
+/**
+ * Reads a bool. If readFieldBegin already decoded a bool field's value
+ * from its header, that stashed value is returned and consumes no bytes.
+ * Otherwise one byte is read and compared against CT_BOOLEAN_TRUE.
+ */
+uint32_t TCompactProtocol::readBool(bool& value) {
+  if (!boolValue_.hasBoolValue) {
+    int8_t code;
+    readByte(code);
+    value = (code == CT_BOOLEAN_TRUE);
+    return 1;
+  }
+  value = boolValue_.boolValue;
+  boolValue_.hasBoolValue = false;
+  return 0;
+}
+
+/**
+ * Reads a single raw byte.
+ */
+uint32_t TCompactProtocol::readByte(int8_t& byte) {
+  uint8_t raw;
+  trans_->readAll(&raw, 1);
+  byte = (int8_t)raw;
+  return 1;
+}
+
+// Signed integers arrive as zigzag-mapped varints; i16 uses the 32-bit
+// path and is narrowed afterwards.
+
+uint32_t TCompactProtocol::readI16(int16_t& i16) {
+  int32_t zz;
+  const uint32_t consumed = readVarint32(zz);
+  i16 = (int16_t)zigzagToI32(zz);
+  return consumed;
+}
+
+uint32_t TCompactProtocol::readI32(int32_t& i32) {
+  int32_t zz;
+  const uint32_t consumed = readVarint32(zz);
+  i32 = zigzagToI32(zz);
+  return consumed;
+}
+
+uint32_t TCompactProtocol::readI64(int64_t& i64) {
+  int64_t zz;
+  const uint32_t consumed = readVarint64(zz);
+  i64 = zigzagToI64(zz);
+  return consumed;
+}
+
+/**
+ * Reads an 8-byte IEEE-754 double stored in little-endian byte order.
+ */
+uint32_t TCompactProtocol::readDouble(double& dub) {
+  BOOST_STATIC_ASSERT(sizeof(double) == sizeof(uint64_t));
+  BOOST_STATIC_ASSERT(std::numeric_limits<double>::is_iec559);
+
+  uint8_t b[8];
+  trans_->readAll(b, 8);
+  // Assemble the little-endian bytes explicitly (b[0] is least
+  // significant). The previous *(uint64_t*)b read a uint8_t array through
+  // a uint64_t lvalue, which breaks strict aliasing and may be a misaligned
+  // load on strict-alignment CPUs.
+  uint64_t bits = 0;
+  for (int i = 7; i >= 0; --i) {
+    bits = (bits << 8) | b[i];
+  }
+  dub = bitwise_cast<double>(bits);
+  return 8;
+}
+
+// Strings share the binary encoding.
+uint32_t TCompactProtocol::readString(std::string& str) {
+  return readBinary(str);
+}
+
+/**
+ * Reads a byte string: a varint length followed by that many bytes.
+ *
+ * The bytes are staged in string_buf_, which is grown with realloc when
+ * too small and kept across calls.
+ *
+ * @throws TProtocolException(NEGATIVE_SIZE) for a negative length
+ * @throws TProtocolException(SIZE_LIMIT) if string_limit_ > 0 and exceeded
+ * @throws TProtocolException(UNKNOWN) if the staging buffer cannot be grown
+ */
+uint32_t TCompactProtocol::readBinary(std::string& str) {
+  int32_t length;
+  const uint32_t lengthBytes = readVarint32(length);
+
+  if (length == 0) {
+    str = "";
+    return lengthBytes;
+  }
+  if (length < 0) {
+    throw TProtocolException(TProtocolException::NEGATIVE_SIZE);
+  }
+  if (string_limit_ > 0 && length > string_limit_) {
+    throw TProtocolException(TProtocolException::SIZE_LIMIT);
+  }
+
+  // Heap staging buffer avoids stack overflow for very large strings.
+  if (string_buf_ == NULL || length > string_buf_size_) {
+    void* grown = std::realloc(string_buf_, (uint32_t)length);
+    if (grown == NULL) {
+      throw TProtocolException(TProtocolException::UNKNOWN, "Out of memory in TCompactProtocol::readString");
+    }
+    string_buf_ = (uint8_t*)grown;
+    string_buf_size_ = length;
+  }
+  trans_->readAll(string_buf_, length);
+  str.assign((char*)string_buf_, length);
+
+  return lengthBytes + (uint32_t)length;
+}
+
+/**
+ * Reads a varint into an int32_t. Decoding is delegated to readVarint64,
+ * so up to 10 bytes are accepted; the result is then truncated to 32 bits.
+ */
+uint32_t TCompactProtocol::readVarint32(int32_t& i32) {
+  int64_t wide;
+  const uint32_t consumed = readVarint64(wide);
+  i32 = (int32_t)wide;
+  return consumed;
+}
+
+/**
+ * Reads a base-128 varint into i64. Each byte contributes its low 7 bits,
+ * least significant group first. A set MSB means another byte follows.
+ * This can read up to 10 bytes.
+ *
+ * @return number of bytes consumed from the transport
+ * @throws TProtocolException(INVALID_DATA) if the 10th byte still has its
+ *         continuation bit set
+ */
+uint32_t TCompactProtocol::readVarint64(int64_t& i64) {
+ uint32_t rsize = 0;
+ uint64_t val = 0;
+ int shift = 0;
+ uint8_t buf[10]; // 64 bits / (7 bits/byte) = 10 bytes.
+ uint32_t buf_size = sizeof(buf);
+ // Ask to look at up to buf_size bytes without consuming them yet.
+ const uint8_t* borrowed = trans_->borrow(buf, &buf_size);
+
+ // Fast path.
+ // Decode directly from the borrowed bytes, then consume() exactly the
+ // bytes used. NOTE(review): indexing borrowed[0..9] assumes borrow()
+ // returns non-NULL only when all 10 requested bytes are available —
+ // confirm against the TTransport::borrow contract.
+ if (borrowed != NULL) {
+ while (true) {
+ uint8_t byte = borrowed[rsize];
+ rsize++;
+ // shift is 7 * (bytes so far - 1), so at most 63 for the 10th byte.
+ val |= (uint64_t)(byte & 0x7f) << shift;
+ shift += 7;
+ if (!(byte & 0x80)) {
+ i64 = val;
+ trans_->consume(rsize);
+ return rsize;
+ }
+ // Have to check for invalid data so we don't crash.
+ if (UNLIKELY(rsize == sizeof(buf))) {
+ throw TProtocolException(TProtocolException::INVALID_DATA, "Variable-length int over 10 bytes.");
+ }
+ }
+ }
+
+ // Slow path.
+ // borrow() returned NULL: read one byte at a time with readAll().
+ else {
+ while (true) {
+ uint8_t byte;
+ rsize += trans_->readAll(&byte, 1);
+ val |= (uint64_t)(byte & 0x7f) << shift;
+ shift += 7;
+ if (!(byte & 0x80)) {
+ i64 = val;
+ return rsize;
+ }
+ // Might as well check for invalid data on the slow path too.
+ if (UNLIKELY(rsize >= sizeof(buf))) {
+ throw TProtocolException(TProtocolException::INVALID_DATA, "Variable-length int over 10 bytes.");
+ }
+ }
+ }
+}
+
+/**
+ * Inverse of i32ToZigzag: 0->0, 1->-1, 2->1, 3->-2, ...
+ * The low bit selects the sign; the remaining bits are the magnitude.
+ */
+int32_t TCompactProtocol::zigzagToI32(uint32_t n) {
+  const uint32_t magnitude = n >> 1;
+  const uint32_t signMask = (uint32_t)0 - (n & 1);
+  return (int32_t)(magnitude ^ signMask);
+}
+
+/**
+ * Inverse of i64ToZigzag.
+ */
+int64_t TCompactProtocol::zigzagToI64(uint64_t n) {
+  const uint64_t magnitude = n >> 1;
+  const uint64_t signMask = (uint64_t)0 - (n & 1);
+  return (int64_t)(magnitude ^ signMask);
+}
+
+/**
+ * Maps a compact-protocol type code to its TType.
+ *
+ * @throws TException for an unrecognized code; the message includes the
+ *         numeric code.
+ */
+TType TCompactProtocol::getTType(int8_t type) {
+  switch (type) {
+    case T_STOP:
+      return T_STOP;
+    case CT_BOOLEAN_FALSE:
+    case CT_BOOLEAN_TRUE:
+      return T_BOOL;
+    case CT_BYTE:
+      return T_BYTE;
+    case CT_I16:
+      return T_I16;
+    case CT_I32:
+      return T_I32;
+    case CT_I64:
+      return T_I64;
+    case CT_DOUBLE:
+      return T_DOUBLE;
+    case CT_BINARY:
+      return T_STRING;
+    case CT_LIST:
+      return T_LIST;
+    case CT_SET:
+      return T_SET;
+    case CT_MAP:
+      return T_MAP;
+    case CT_STRUCT:
+      return T_STRUCT;
+    default: {
+      // The previous `"don't know what type: " + type` advanced the
+      // string literal's pointer by `type` bytes instead of appending the
+      // number, yielding a truncated or out-of-bounds message.
+      std::ostringstream msg;
+      msg << "don't know what type: " << (int)type;
+      throw TException(msg.str());
+    }
+  }
+  return T_STOP;
+}
+
+}}} // apache::thrift::protocol
diff --git a/lib/cpp/src/protocol/TCompactProtocol.h b/lib/cpp/src/protocol/TCompactProtocol.h
new file mode 100644
index 0000000..b4e06f0
--- /dev/null
+++ b/lib/cpp/src/protocol/TCompactProtocol.h
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_PROTOCOL_TCOMPACTPROTOCOL_H_
+#define _THRIFT_PROTOCOL_TCOMPACTPROTOCOL_H_ 1
+
+#include "TProtocol.h"
+
+#include <stack>
+#include <boost/shared_ptr.hpp>
+
+namespace apache { namespace thrift { namespace protocol {
+
+/**
+ * C++ Implementation of the Compact Protocol as described in THRIFT-110
+ */
+class TCompactProtocol : public TProtocol {
+
+ protected:
+  // Wire-format header constants. 0x82 and 0xE0 exceed INT8_MAX, so they are
+  // cast explicitly rather than relying on an implicit narrowing conversion
+  // in the initializer (which compilers warn about). The resulting byte
+  // patterns are unchanged.
+  static const int8_t PROTOCOL_ID = (int8_t)0x82;
+  static const int8_t VERSION_N = 1;
+  static const int8_t VERSION_MASK = 0x1f; // 0001 1111
+  static const int8_t TYPE_MASK = (int8_t)0xE0; // 1110 0000
+  static const int32_t TYPE_SHIFT_AMOUNT = 5;
+
+  /**
+   * (Writing) If we encounter a boolean field begin, save the TField here
+   * so it can have the value incorporated.
+   */
+  struct {
+    const char* name;
+    TType fieldType;
+    int16_t fieldId;
+  } booleanField_;
+
+  /**
+   * (Reading) If we read a field header, and it's a boolean field, save
+   * the boolean value here so that readBool can use it.
+   */
+  struct {
+    bool hasBoolValue;
+    bool boolValue;
+  } boolValue_;
+
+  /**
+   * Used to keep track of the last field for the current and previous structs,
+   * so we can do the delta stuff.
+   */
+  std::stack<int16_t> lastField_;
+  int16_t lastFieldId_;
+
+  /** Compact-protocol type codes (mapped to TType by getTType). */
+  enum Types {
+    CT_STOP = 0x00,
+    CT_BOOLEAN_TRUE = 0x01,
+    CT_BOOLEAN_FALSE = 0x02,
+    CT_BYTE = 0x03,
+    CT_I16 = 0x04,
+    CT_I32 = 0x05,
+    CT_I64 = 0x06,
+    CT_DOUBLE = 0x07,
+    CT_BINARY = 0x08,
+    CT_LIST = 0x09,
+    CT_SET = 0x0A,
+    CT_MAP = 0x0B,
+    CT_STRUCT = 0x0C
+  };
+
+  static const int8_t TTypeToCType[16];
+
+ public:
+  TCompactProtocol(boost::shared_ptr<TTransport> trans) :
+    TProtocol(trans),
+    lastFieldId_(0),
+    string_limit_(0),
+    string_buf_(NULL),
+    string_buf_size_(0),
+    container_limit_(0) {
+    // Initialize every member of the scratch structs, not only the
+    // "pending" flags, so none of them is ever read uninitialized.
+    booleanField_.name = NULL;
+    booleanField_.fieldType = T_STOP;
+    booleanField_.fieldId = 0;
+    boolValue_.hasBoolValue = false;
+    boolValue_.boolValue = false;
+  }
+
+  // NOTE(review): a limit of 0 presumably means "no limit"; confirm against
+  // readString/read*Begin in the .cpp.
+  TCompactProtocol(boost::shared_ptr<TTransport> trans,
+                   int32_t string_limit,
+                   int32_t container_limit) :
+    TProtocol(trans),
+    lastFieldId_(0),
+    string_limit_(string_limit),
+    string_buf_(NULL),
+    string_buf_size_(0),
+    container_limit_(container_limit) {
+    booleanField_.name = NULL;
+    booleanField_.fieldType = T_STOP;
+    booleanField_.fieldId = 0;
+    boolValue_.hasBoolValue = false;
+    boolValue_.boolValue = false;
+  }
+
+  /**
+   * Writing functions
+   */
+
+  virtual uint32_t writeMessageBegin(const std::string& name,
+                                     const TMessageType messageType,
+                                     const int32_t seqid);
+
+  uint32_t writeStructBegin(const char* name);
+
+  uint32_t writeStructEnd();
+
+  uint32_t writeFieldBegin(const char* name,
+                           const TType fieldType,
+                           const int16_t fieldId);
+
+  uint32_t writeFieldStop();
+
+  uint32_t writeListBegin(const TType elemType,
+                          const uint32_t size);
+
+  uint32_t writeSetBegin(const TType elemType,
+                         const uint32_t size);
+
+  virtual uint32_t writeMapBegin(const TType keyType,
+                                 const TType valType,
+                                 const uint32_t size);
+
+  uint32_t writeBool(const bool value);
+
+  uint32_t writeByte(const int8_t byte);
+
+  uint32_t writeI16(const int16_t i16);
+
+  uint32_t writeI32(const int32_t i32);
+
+  uint32_t writeI64(const int64_t i64);
+
+  uint32_t writeDouble(const double dub);
+
+  uint32_t writeString(const std::string& str);
+
+  uint32_t writeBinary(const std::string& str);
+
+  /**
+   * These methods are called by structs, but have no wire output or
+   * purpose in this protocol.
+   */
+  virtual uint32_t writeMessageEnd() { return 0; }
+  uint32_t writeMapEnd() { return 0; }
+  uint32_t writeListEnd() { return 0; }
+  uint32_t writeSetEnd() { return 0; }
+  uint32_t writeFieldEnd() { return 0; }
+
+ protected:
+  int32_t writeFieldBeginInternal(const char* name,
+                                  const TType fieldType,
+                                  const int16_t fieldId,
+                                  int8_t typeOverride);
+  uint32_t writeCollectionBegin(int8_t elemType, int32_t size);
+  uint32_t writeVarint32(uint32_t n);
+  uint32_t writeVarint64(uint64_t n);
+  uint64_t i64ToZigzag(const int64_t l);
+  uint32_t i32ToZigzag(const int32_t n);
+  inline int8_t getCompactType(int8_t ttype);
+
+ public:
+  uint32_t readMessageBegin(std::string& name,
+                            TMessageType& messageType,
+                            int32_t& seqid);
+
+  uint32_t readStructBegin(std::string& name);
+
+  uint32_t readStructEnd();
+
+  uint32_t readFieldBegin(std::string& name,
+                          TType& fieldType,
+                          int16_t& fieldId);
+
+  uint32_t readMapBegin(TType& keyType,
+                        TType& valType,
+                        uint32_t& size);
+
+  uint32_t readListBegin(TType& elemType,
+                         uint32_t& size);
+
+  uint32_t readSetBegin(TType& elemType,
+                        uint32_t& size);
+
+  uint32_t readBool(bool& value);
+
+  uint32_t readByte(int8_t& byte);
+
+  uint32_t readI16(int16_t& i16);
+
+  uint32_t readI32(int32_t& i32);
+
+  uint32_t readI64(int64_t& i64);
+
+  uint32_t readDouble(double& dub);
+
+  uint32_t readString(std::string& str);
+
+  uint32_t readBinary(std::string& str);
+
+  /*
+   * These methods are here for the struct to call, but don't have any wire
+   * encoding.
+   */
+  uint32_t readMessageEnd() { return 0; }
+  uint32_t readFieldEnd() { return 0; }
+  uint32_t readMapEnd() { return 0; }
+  uint32_t readListEnd() { return 0; }
+  uint32_t readSetEnd() { return 0; }
+
+ protected:
+  uint32_t readVarint32(int32_t& i32);
+  uint32_t readVarint64(int64_t& i64);
+  int32_t zigzagToI32(uint32_t n);
+  int64_t zigzagToI64(uint64_t n);
+  TType getTType(int8_t type);
+
+  // Buffer for reading strings, kept for the lifetime of the protocol to
+  // avoid memory churn from allocating on every string read.
+  // NOTE(review): this class declares no destructor, so string_buf_ is only
+  // released if the .cpp does so elsewhere. Confirm that; otherwise every
+  // protocol instance leaks its buffer.
+  int32_t string_limit_;
+  uint8_t* string_buf_;
+  int32_t string_buf_size_;
+  int32_t container_limit_;
+};
+
+/**
+ * Factory producing TCompactProtocol instances that all receive the
+ * string/container size limits configured on this factory.
+ */
+class TCompactProtocolFactory : public TProtocolFactory {
+ public:
+  TCompactProtocolFactory()
+    : string_limit_(0), container_limit_(0) {}
+
+  TCompactProtocolFactory(int32_t string_limit, int32_t container_limit)
+    : string_limit_(string_limit), container_limit_(container_limit) {}
+
+  virtual ~TCompactProtocolFactory() {}
+
+  /** Sets the string size limit handed to protocols created afterwards. */
+  void setStringSizeLimit(int32_t string_limit) { string_limit_ = string_limit; }
+
+  /** Sets the container size limit handed to protocols created afterwards. */
+  void setContainerSizeLimit(int32_t container_limit) { container_limit_ = container_limit; }
+
+  boost::shared_ptr<TProtocol> getProtocol(boost::shared_ptr<TTransport> trans) {
+    TCompactProtocol* proto =
+        new TCompactProtocol(trans, string_limit_, container_limit_);
+    return boost::shared_ptr<TProtocol>(proto);
+  }
+
+ private:
+  int32_t string_limit_;
+  int32_t container_limit_;
+};
+
+}}} // apache::thrift::protocol
+
+#endif
diff --git a/lib/cpp/src/protocol/TDebugProtocol.cpp b/lib/cpp/src/protocol/TDebugProtocol.cpp
new file mode 100644
index 0000000..40aa36b
--- /dev/null
+++ b/lib/cpp/src/protocol/TDebugProtocol.cpp
@@ -0,0 +1,346 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "TDebugProtocol.h"
+
+#include <cassert>
+#include <cctype>
+#include <cstdio>
+#include <stdexcept>
+#include <boost/static_assert.hpp>
+#include <boost/lexical_cast.hpp>
+
+using std::string;
+
+
+/**
+ * Renders a byte as exactly two lowercase hexadecimal digits ("00".."ff").
+ */
+static string byte_to_hex(const uint8_t byte) {
+  static const char kHexDigits[] = "0123456789abcdef";
+  string hex(2, '0');
+  hex[0] = kHexDigits[byte >> 4];
+  hex[1] = kHexDigits[byte & 0x0f];
+  return hex;
+}
+
+
+namespace apache { namespace thrift { namespace protocol {
+
+/**
+ * Returns the lowercase display name for a TType, or "unknown".
+ */
+string TDebugProtocol::fieldTypeName(TType type) {
+  const char* label = "unknown";
+  switch (type) {
+    case T_STOP:   label = "stop";   break;
+    case T_VOID:   label = "void";   break;
+    case T_BOOL:   label = "bool";   break;
+    case T_BYTE:   label = "byte";   break;
+    case T_I16:    label = "i16";    break;
+    case T_I32:    label = "i32";    break;
+    case T_U64:    label = "u64";    break;
+    case T_I64:    label = "i64";    break;
+    case T_DOUBLE: label = "double"; break;
+    case T_STRING: label = "string"; break;
+    case T_STRUCT: label = "struct"; break;
+    case T_MAP:    label = "map";    break;
+    case T_SET:    label = "set";    break;
+    case T_LIST:   label = "list";   break;
+    case T_UTF8:   label = "utf8";   break;
+    case T_UTF16:  label = "utf16";  break;
+    default:                         break;
+  }
+  return label;
+}
+
+void TDebugProtocol::indentUp() {
+  // One nesting level deeper: extend the prefix by indent_inc spaces.
+  indent_str_.append(indent_inc, ' ');
+}
+
+void TDebugProtocol::indentDown() {
+  const string::size_type step = (string::size_type)indent_inc;
+  // An unbalanced *End() call would otherwise shrink past zero.
+  if (indent_str_.length() < step) {
+    throw TProtocolException(TProtocolException::INVALID_DATA);
+  }
+  indent_str_.resize(indent_str_.length() - step);
+}
+
+/** Writes @p str verbatim and returns its length in bytes. */
+uint32_t TDebugProtocol::writePlain(const string& str) {
+  const uint32_t len = str.length();
+  trans_->write((uint8_t*)str.data(), len);
+  return len;
+}
+
+/** Writes the current indentation prefix followed by @p str. */
+uint32_t TDebugProtocol::writeIndented(const string& str) {
+  uint32_t written = writePlain(indent_str_);
+  written += writePlain(str);
+  return written;
+}
+
+/**
+ * Emits whatever must precede an item in the current context: nothing for
+ * top-level and struct items, indentation for set elements and map keys,
+ * an arrow for map values, and an "[i] = " label for list elements.
+ */
+uint32_t TDebugProtocol::startItem() {
+  const write_state_t state = write_state_.back();
+
+  if (state == UNINIT || state == STRUCT) {
+    // XXX figure out what to do for UNINIT.
+    return 0;
+  }
+  if (state == SET || state == MAP_KEY) {
+    return writeIndented("");
+  }
+  if (state == MAP_VALUE) {
+    return writePlain(" -> ");
+  }
+  if (state == LIST) {
+    const uint32_t written = writeIndented(
+        "[" + boost::lexical_cast<string>(list_idx_.back()) + "] = ");
+    ++list_idx_.back();  // number the next element
+    return written;
+  }
+  throw std::logic_error("Invalid enum value.");
+}
+
+/**
+ * Emits whatever must follow an item in the current context, and flips a
+ * map between its key and value states.
+ */
+uint32_t TDebugProtocol::endItem() {
+  switch (write_state_.back()) {
+    case UNINIT:
+      // XXX figure out what to do here.
+      return 0;
+    case MAP_KEY:
+      // The value follows on the same line as its key.
+      write_state_.back() = MAP_VALUE;
+      return 0;
+    case MAP_VALUE:
+      // After a value, the next item is a key again.
+      write_state_.back() = MAP_KEY;
+      return writePlain(",\n");
+    case STRUCT:
+    case SET:
+    case LIST:
+      return writePlain(",\n");
+    default:
+      throw std::logic_error("Invalid enum value.");
+  }
+}
+
+/** Writes one scalar item, framed by the context's prefix and suffix. */
+uint32_t TDebugProtocol::writeItem(const std::string& str) {
+  uint32_t total = startItem();
+  total += writePlain(str);
+  total += endItem();
+  return total;
+}
+
+/**
+ * Writes "(<type>) <name>(" and indents the message body.
+ * Known types print as call/reply/exn. Any other value prints as its
+ * numeric code, so the header never shows an empty "()" label.
+ * The sequence id is not shown.
+ */
+uint32_t TDebugProtocol::writeMessageBegin(const std::string& name,
+                                           const TMessageType messageType,
+                                           const int32_t seqid) {
+  (void)seqid;  // not part of the debug rendering
+
+  string mtype;
+  switch (messageType) {
+    case T_CALL      : mtype = "call"  ; break;
+    case T_REPLY     : mtype = "reply" ; break;
+    case T_EXCEPTION : mtype = "exn"   ; break;
+    default:
+      mtype = boost::lexical_cast<string>((int)messageType);
+      break;
+  }
+
+  uint32_t size = writeIndented("(" + mtype + ") " + name + "(");
+  indentUp();
+  return size;
+}
+
+uint32_t TDebugProtocol::writeMessageEnd() {
+  // Undo writeMessageBegin's indentation, then close the parenthesis.
+  indentDown();
+  const uint32_t written = writeIndented(")\n");
+  return written;
+}
+
+/** Writes "<name> {" and enters struct context for the fields. */
+uint32_t TDebugProtocol::writeStructBegin(const char* name) {
+  uint32_t written = startItem();
+  written += writePlain(string(name) + " {\n");
+  indentUp();
+  write_state_.push_back(STRUCT);
+  return written;
+}
+
+/** Leaves struct context and closes the brace at the parent's indent. */
+uint32_t TDebugProtocol::writeStructEnd() {
+  indentDown();
+  write_state_.pop_back();
+  const uint32_t brace = writeIndented("}");
+  return brace + endItem();
+}
+
+/**
+ * Writes an indented "<id>: <name> (<type>) = " header for a field.
+ * Single-character ids are left-padded with '0'.
+ */
+uint32_t TDebugProtocol::writeFieldBegin(const char* name,
+                                         const TType fieldType,
+                                         const int16_t fieldId) {
+  string id_str = boost::lexical_cast<string>(fieldId);
+  if (id_str.length() == 1) {
+    id_str.insert(id_str.begin(), '0');
+  }
+
+  string header = id_str;
+  header += ": ";
+  header += name;
+  header += " (";
+  header += fieldTypeName(fieldType);
+  header += ") = ";
+  return writeIndented(header);
+}
+
+uint32_t TDebugProtocol::writeFieldEnd() {
+  // Fields only appear inside structs; the separator comes from endItem.
+  assert(write_state_.back() == STRUCT);
+  return 0;
+}
+
+uint32_t TDebugProtocol::writeFieldStop() {
+  // No stop marker is printed; writeStructEnd's closing brace ends the struct.
+  return 0;
+}
+
+/** Writes "map<K,V>[size] {" and expects a key next. */
+uint32_t TDebugProtocol::writeMapBegin(const TType keyType,
+                                       const TType valType,
+                                       const uint32_t size) {
+  // TODO(dreiss): Optimize short maps?
+  uint32_t bsize = startItem();
+
+  string header("map<");
+  header += fieldTypeName(keyType);
+  header += ",";
+  header += fieldTypeName(valType);
+  header += ">[";
+  header += boost::lexical_cast<string>(size);
+  header += "] {\n";
+  bsize += writePlain(header);
+
+  indentUp();
+  write_state_.push_back(MAP_KEY);
+  return bsize;
+}
+
+/** Leaves map context and closes the brace at the parent's indent. */
+uint32_t TDebugProtocol::writeMapEnd() {
+  indentDown();
+  write_state_.pop_back();
+  const uint32_t brace = writeIndented("}");
+  return brace + endItem();
+}
+
+/** Writes "list<E>[size] {" and starts numbering elements at [0]. */
+uint32_t TDebugProtocol::writeListBegin(const TType elemType,
+                                        const uint32_t size) {
+  // TODO(dreiss): Optimize short arrays.
+  uint32_t bsize = startItem();
+
+  string header("list<");
+  header += fieldTypeName(elemType);
+  header += ">[";
+  header += boost::lexical_cast<string>(size);
+  header += "] {\n";
+  bsize += writePlain(header);
+
+  indentUp();
+  write_state_.push_back(LIST);
+  list_idx_.push_back(0);
+  return bsize;
+}
+
+/** Leaves list context, drops its element counter, and closes the brace. */
+uint32_t TDebugProtocol::writeListEnd() {
+  indentDown();
+  write_state_.pop_back();
+  list_idx_.pop_back();
+  const uint32_t brace = writeIndented("}");
+  return brace + endItem();
+}
+
+/** Writes "set<E>[size] {" and enters set context. */
+uint32_t TDebugProtocol::writeSetBegin(const TType elemType,
+                                       const uint32_t size) {
+  // TODO(dreiss): Optimize short sets.
+  uint32_t bsize = startItem();
+
+  string header("set<");
+  header += fieldTypeName(elemType);
+  header += ">[";
+  header += boost::lexical_cast<string>(size);
+  header += "] {\n";
+  bsize += writePlain(header);
+
+  indentUp();
+  write_state_.push_back(SET);
+  return bsize;
+}
+
+/** Leaves set context and closes the brace at the parent's indent. */
+uint32_t TDebugProtocol::writeSetEnd() {
+  indentDown();
+  write_state_.pop_back();
+  const uint32_t brace = writeIndented("}");
+  return brace + endItem();
+}
+
+uint32_t TDebugProtocol::writeBool(const bool value) {
+  const char* text = value ? "true" : "false";
+  return writeItem(text);
+}
+
+uint32_t TDebugProtocol::writeByte(const int8_t byte) {
+  // Bytes are shown as 0x followed by two lowercase hex digits.
+  string text("0x");
+  text += byte_to_hex(byte);
+  return writeItem(text);
+}
+
+uint32_t TDebugProtocol::writeI16(const int16_t i16) {
+  const string text = boost::lexical_cast<string>(i16);
+  return writeItem(text);
+}
+
+uint32_t TDebugProtocol::writeI32(const int32_t i32) {
+  const string text = boost::lexical_cast<string>(i32);
+  return writeItem(text);
+}
+
+uint32_t TDebugProtocol::writeI64(const int64_t i64) {
+  const string text = boost::lexical_cast<string>(i64);
+  return writeItem(text);
+}
+
+uint32_t TDebugProtocol::writeDouble(const double dub) {
+  const string text = boost::lexical_cast<string>(dub);
+  return writeItem(text);
+}
+
+
+/**
+ * Writes @p str as a double-quoted, C-escaped literal.
+ *
+ * Strings longer than string_limit_ bytes are cut to their first
+ * string_prefix_size_ bytes, followed by "[...](<total length>)".
+ * Backslash and double quote are escaped. Printable characters are copied.
+ * Common control characters use their C escapes, and every other byte
+ * becomes \xHH.
+ */
+uint32_t TDebugProtocol::writeString(const string& str) {
+  // XXX Raw/UTF-8?
+
+  // Only materialize the abbreviated form when needed; otherwise escape the
+  // caller's string directly instead of copying it first.
+  string abbreviated;
+  const string* to_show = &str;
+  if (str.length() > (string::size_type)string_limit_) {
+    abbreviated = str.substr(0, string_prefix_size_);
+    abbreviated += "[...](" + boost::lexical_cast<string>(str.length()) + ")";
+    to_show = &abbreviated;
+  }
+
+  string output = "\"";
+
+  for (string::const_iterator it = to_show->begin(); it != to_show->end(); ++it) {
+    // <cctype> functions require a value representable as unsigned char (or
+    // EOF). Passing a plain char >= 0x80 where char is signed is undefined
+    // behavior, so convert first.
+    const unsigned char ch = static_cast<unsigned char>(*it);
+    if (ch == '\\') {
+      output += "\\\\";
+    } else if (ch == '"') {
+      output += "\\\"";
+    } else if (std::isprint(ch)) {
+      output += (char)ch;
+    } else {
+      switch (ch) {
+        case '\a': output += "\\a"; break;
+        case '\b': output += "\\b"; break;
+        case '\f': output += "\\f"; break;
+        case '\n': output += "\\n"; break;
+        case '\r': output += "\\r"; break;
+        case '\t': output += "\\t"; break;
+        case '\v': output += "\\v"; break;
+        default:
+          output += "\\x";
+          output += byte_to_hex(ch);
+      }
+    }
+  }
+
+  output += '\"';
+  return writeItem(output);
+}
+
+uint32_t TDebugProtocol::writeBinary(const string& str) {
+  // Binary payloads are rendered exactly like strings. XXX Hex?
+  // The qualified call deliberately bypasses virtual dispatch.
+  const uint32_t written = TDebugProtocol::writeString(str);
+  return written;
+}
+
+}}} // apache::thrift::protocol
diff --git a/lib/cpp/src/protocol/TDebugProtocol.h b/lib/cpp/src/protocol/TDebugProtocol.h
new file mode 100644
index 0000000..ab69e0c
--- /dev/null
+++ b/lib/cpp/src/protocol/TDebugProtocol.h
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_PROTOCOL_TDEBUGPROTOCOL_H_
+#define _THRIFT_PROTOCOL_TDEBUGPROTOCOL_H_ 1
+
+#include "TProtocol.h"
+#include "TOneWayProtocol.h"
+
+#include <boost/shared_ptr.hpp>
+
+namespace apache { namespace thrift { namespace protocol {
+
+/*
+
+!!! EXPERIMENTAL CODE !!!
+
+This protocol is very much a work in progress.
+It doesn't handle many cases properly.
+It throws exceptions in many cases.
+It probably segfaults in many cases.
+Bug reports and feature requests are welcome.
+Complaints are not. :R
+
+*/
+
+
+/**
+ * Protocol that prints the payload in a nice human-readable format.
+ * Reading from this protocol is not supported.
+ *
+ * Nested structs and containers are indented by indent_inc spaces per
+ * level. Long strings are abbreviated (see setStringSizeLimit).
+ */
+class TDebugProtocol : public TWriteOnlyProtocol {
+ private:
+  /** Context the next item is written into (top of write_state_). */
+  enum write_state_t {
+    UNINIT,     // not inside any struct or container yet
+    STRUCT,
+    LIST,
+    SET,
+    MAP_KEY,    // next item is a map key
+    MAP_VALUE   // next item is a map value
+  };
+
+ public:
+  TDebugProtocol(boost::shared_ptr<TTransport> trans)
+    : TWriteOnlyProtocol(trans, "TDebugProtocol"),
+      string_limit_(DEFAULT_STRING_LIMIT),
+      string_prefix_size_(DEFAULT_STRING_PREFIX_SIZE) {
+    // Bottom-of-stack sentinel so write_state_.back() is always valid.
+    write_state_.push_back(UNINIT);
+  }
+
+  static const int32_t DEFAULT_STRING_LIMIT = 256;
+  static const int32_t DEFAULT_STRING_PREFIX_SIZE = 16;
+
+  /** Strings longer than this many bytes are abbreviated when written. */
+  void setStringSizeLimit(int32_t string_limit) {
+    string_limit_ = string_limit;
+  }
+
+  /** Number of leading bytes kept when a string is abbreviated. */
+  void setStringPrefixSize(int32_t string_prefix_size) {
+    string_prefix_size_ = string_prefix_size;
+  }
+
+  virtual uint32_t writeMessageBegin(const std::string& name,
+                                     const TMessageType messageType,
+                                     const int32_t seqid);
+  virtual uint32_t writeMessageEnd();
+
+  uint32_t writeStructBegin(const char* name);
+  uint32_t writeStructEnd();
+
+  uint32_t writeFieldBegin(const char* name,
+                           const TType fieldType,
+                           const int16_t fieldId);
+  uint32_t writeFieldEnd();
+  uint32_t writeFieldStop();
+
+  uint32_t writeMapBegin(const TType keyType,
+                         const TType valType,
+                         const uint32_t size);
+  uint32_t writeMapEnd();
+
+  uint32_t writeListBegin(const TType elemType, const uint32_t size);
+  uint32_t writeListEnd();
+
+  uint32_t writeSetBegin(const TType elemType, const uint32_t size);
+  uint32_t writeSetEnd();
+
+  uint32_t writeBool(const bool value);
+  uint32_t writeByte(const int8_t byte);
+  uint32_t writeI16(const int16_t i16);
+  uint32_t writeI32(const int32_t i32);
+  uint32_t writeI64(const int64_t i64);
+  uint32_t writeDouble(const double dub);
+  uint32_t writeString(const std::string& str);
+  uint32_t writeBinary(const std::string& str);
+
+ private:
+  // Indentation management.
+  void indentUp();
+  void indentDown();
+
+  // Raw output helpers; each returns the number of bytes written.
+  uint32_t writePlain(const std::string& str);
+  uint32_t writeIndented(const std::string& str);
+
+  // Per-item prefix/suffix, chosen from the current write_state_.
+  uint32_t startItem();
+  uint32_t endItem();
+  uint32_t writeItem(const std::string& str);
+
+  static std::string fieldTypeName(TType type);
+
+  int32_t string_limit_;
+  int32_t string_prefix_size_;
+
+  std::string indent_str_;  // current indentation prefix
+  static const int indent_inc = 2;
+
+  // NOTE(review): std::vector is used without a direct <vector> include in
+  // this header; presumably it arrives transitively via TProtocol.h.
+  std::vector<write_state_t> write_state_;  // one entry per open struct/container
+  std::vector<int> list_idx_;               // next element index of each open list
+};
+
+/**
+ * Factory that wraps each transport in a fresh TDebugProtocol.
+ */
+class TDebugProtocolFactory : public TProtocolFactory {
+ public:
+  TDebugProtocolFactory() {}
+  virtual ~TDebugProtocolFactory() {}
+
+  boost::shared_ptr<TProtocol> getProtocol(boost::shared_ptr<TTransport> trans) {
+    TDebugProtocol* proto = new TDebugProtocol(trans);
+    return boost::shared_ptr<TProtocol>(proto);
+  }
+};
+
+}}} // apache::thrift::protocol
+
+
+// TODO(dreiss): Move (part of) ThriftDebugString into a .cpp file and remove this.
+#include <transport/TBufferTransports.h>
+
+namespace apache { namespace thrift {
+
+/**
+ * Renders any generated Thrift struct (anything with a write(TProtocol*)
+ * method) as the human-readable text produced by TDebugProtocol.
+ */
+template<typename ThriftStruct>
+std::string ThriftDebugString(const ThriftStruct& ts) {
+  using namespace apache::thrift::transport;
+  using namespace apache::thrift::protocol;
+
+  // The shared_ptr owns the buffer; the raw pointer is kept so the rendered
+  // bytes can be read back afterwards.
+  TMemoryBuffer* mem = new TMemoryBuffer;
+  boost::shared_ptr<TTransport> transport(mem);
+  TDebugProtocol printer(transport);
+  ts.write(&printer);
+
+  uint8_t* data = NULL;
+  uint32_t len = 0;
+  mem->getBuffer(&data, &len);
+  return std::string(reinterpret_cast<char*>(data), len);
+}
+
+// TODO(dreiss): This is badly broken. Don't use it unless you are me.
+// Disabled via #if 0. It tries to render a vector of Thrift objects as a
+// list wrapped in a fake struct. NOTE(review): it never calls
+// writeStructEnd and passes a placeholder element type ((TType)99), so the
+// output is not well-formed.
+#if 0
+template<typename Object>
+std::string DebugString(const std::vector<Object>& vec) {
+ using namespace apache::thrift::transport;
+ using namespace apache::thrift::protocol;
+ TMemoryBuffer* buffer = new TMemoryBuffer;
+ boost::shared_ptr<TTransport> trans(buffer);
+ TDebugProtocol protocol(trans);
+
+ // I am gross!
+ protocol.writeStructBegin("SomeRandomVector");
+
+ // TODO: Fix this with a trait.
+ protocol.writeListBegin((TType)99, vec.size());
+ typename std::vector<Object>::const_iterator it;
+ for (it = vec.begin(); it != vec.end(); ++it) {
+ it->write(&protocol);
+ }
+ protocol.writeListEnd();
+
+ uint8_t* buf;
+ uint32_t size;
+ buffer->getBuffer(&buf, &size);
+ return std::string((char*)buf, (unsigned int)size);
+}
+#endif // 0
+
+}} // apache::thrift
+
+
+#endif // #ifndef _THRIFT_PROTOCOL_TDEBUGPROTOCOL_H_
+
+
diff --git a/lib/cpp/src/protocol/TDenseProtocol.cpp b/lib/cpp/src/protocol/TDenseProtocol.cpp
new file mode 100644
index 0000000..8e76dc4
--- /dev/null
+++ b/lib/cpp/src/protocol/TDenseProtocol.cpp
@@ -0,0 +1,762 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*
+
+IMPLEMENTATION DETAILS
+
+TDenseProtocol was designed to have a smaller serialized form than
+TBinaryProtocol. This is accomplished using two techniques. The first is
+variable-length integer encoding. We use the same technique that the Standard
+MIDI File format uses for "variable-length quantities"
+(http://en.wikipedia.org/wiki/Variable-length_quantity).
+All integers (including i16, but not byte) are first cast to uint64_t,
+then written out as variable-length quantities. This has the unfortunate side
+effect that all negative numbers require 10 bytes, but negative numbers tend
+to be far less common than positive ones.
+
+The second technique is eliminating the field ids used by TBinaryProtocol. This
+decision required support from the Thrift compiler and also sacrifices some of
+the backward and forward compatibility of TBinaryProtocol.
+
+We considered implementing this technique by generating separate readers and
+writers for the dense protocol (this is how Pillar, Thrift's predecessor,
+worked), but this idea had a few problems:
+- Our abstractions go out the window.
+- We would have to maintain a second code generator.
+- Preserving compatibility with old versions of the structures would be a
+ nightmare.
+
+Therefore, we chose an alternate implementation that stored the description of
+the data neither in the data itself (like TBinaryProtocol) nor in the
+serialization code (like Pillar), but instead in a separate data structure,
+called a TypeSpec. TypeSpecs are generated by the Thrift compiler
+(specifically in the t_cpp_generator), and their structure should be
+documented there (TODO(dreiss): s/should be/is/).
+
+We maintain a stack of TypeSpecs within the protocol so it knows where the
+generated code is in the reading/writing process. For example, if we are
+writing an i32 contained in a struct bar, contained in a struct foo, then the
+stack would look like: TOP , i32 , struct bar , struct foo , BOTTOM.
+We maintain the following invariant: whenever we are about to read/write an
+object (structBegin, containerBegin, or a scalar), the TypeSpec on the top of
+the stack must match the type being read/written. The main reason this
+invariant must be maintained is that if we ever start reading a structure, we
+must have its exact TypeSpec in order to pass the right tags to the
+deserializer.
+
+We use the following strategies for maintaining this invariant:
+
+- For structures, we have a separate stack of indexes, one for each structure
+ on the TypeSpec stack. These are indexes into the list of fields in the
+ structure's TypeSpec. When we {read,write}FieldBegin, we push on the
+ TypeSpec for the field.
+- When we begin writing a list or set, we push on the TypeSpec for the
+ element type.
+- For maps, we have a separate stack of booleans, one for each map on the
+ TypeSpec stack. The boolean is true if we are writing the key for that
+ map, and false if we are writing the value. Maps are the trickiest case
+ because the generated code does not call any protocol method between
+ the key and the value. As a result, we potentially have to switch
+ between map key state and map value state after reading/writing any object.
+- This job is handled by the stateTransition method. It is called after
+ reading/writing every object. It pops the current TypeSpec off the stack,
+ then optionally pushes a new one on, depending on what the next TypeSpec is.
+ If it is a struct, the job is left to the next writeFieldBegin. If it is a
+ set or list, the just-popped typespec is pushed back on. If it is a map,
+ the top of the key/value stack is toggled, and the appropriate TypeSpec
+ is pushed.
+
+Optional fields are a little tricky also. We write a zero byte if they are
+absent and prefix them with a 0x01 byte if they are present.
+*/
+
+#define __STDC_LIMIT_MACROS
+#include <stdint.h>
+#include "TDenseProtocol.h"
+#include "TReflectionLocal.h"
+
+// Leaving this on for now. Disabling it will turn off asserts, which should
+// give a performance boost. When we have *really* thorough test cases,
+// we should drop this.
+#define DEBUG_TDENSEPROTOCOL
+
+// NOTE: Assertions should *only* be used to detect bugs in code,
+// either in TDenseProtocol itself, or in code using it.
+// (For example, using the wrong TypeSpec.)
+// Invalid data should NEVER cause an assertion failure,
+// no matter how grossly corrupted, nor how ingeniously crafted.
+#ifdef DEBUG_TDENSEPROTOCOL
+#undef NDEBUG
+#else
+#define NDEBUG
+#endif
+#include <cassert>
+
+using std::string;
+
+#ifdef __GNUC__
+#define UNLIKELY(val) (__builtin_expect((val), 0))
+#else
+#define UNLIKELY(val) (val)
+#endif
+
+namespace apache { namespace thrift { namespace protocol {
+
+// Number of structure-fingerprint bytes written ahead of a top-level struct;
+// taken from the reflection metadata so both sides agree.
+const int TDenseProtocol::FP_PREFIX_LEN = apache::thrift::reflection::local::FP_PREFIX_LEN;
+
+// Top TypeSpec. TypeSpec of the structure being encoded.
+#define TTS (ts_stack_.back()) // type = TypeSpec*
+// InDeX. Index into TTS of the current/next field to encode.
+#define IDX (idx_stack_.back()) // type = int
+// Field TypeSpec. TypeSpec of the current/next field to encode.
+#define FTS (TTS->tstruct.specs[IDX]) // type = TypeSpec*
+// Field MeTa. Metadata of the current/next field to encode.
+#define FMT (TTS->tstruct.metas[IDX]) // type = FieldMeta
+// SubType 1/2. TypeSpec of the first/second subtype of this container.
+#define ST1 (TTS->tcontainer.subtype1)
+#define ST2 (TTS->tcontainer.subtype2)
+
+
+/**
+ * Checks that @c ttype is indeed the ttype that we should be writing,
+ * according to our typespec. Aborts if the test fails and debugging is on
+ * (assertions are enabled while DEBUG_TDENSEPROTOCOL is defined). Otherwise
+ * this is a no-op.
+ */
+inline void TDenseProtocol::checkTType(const TType ttype) {
+ assert(!ts_stack_.empty());
+ assert(TTS->ttype == ttype);
+}
+
+/**
+ * Makes sure that the TypeSpec stack is correct for the next object.
+ * See top-of-file comments.
+ *
+ * Pops the TypeSpec of the object just read/written, then pushes the
+ * TypeSpec of the next object according to the enclosing type:
+ *  - struct: nothing; the next {read,write}FieldBegin pushes the field's spec.
+ *  - list/set: the same element spec is pushed back.
+ *  - map: toggles between key (ST1) and value (ST2) and pushes that spec.
+ */
+inline void TDenseProtocol::stateTransition() {
+  TypeSpec* old_tts = ts_stack_.back();
+  ts_stack_.pop_back();
+
+  // If this is the end of the top-level write, we should have just popped
+  // the TypeSpec passed to the constructor.
+  if (ts_stack_.empty()) {
+    // Must be a comparison. An assignment here would always succeed (for a
+    // non-NULL type_spec_) and check nothing.
+    assert(old_tts == type_spec_);
+    return;
+  }
+
+  switch (TTS->ttype) {
+
+    case T_STRUCT:
+      assert(old_tts == FTS);
+      break;
+
+    case T_LIST:
+    case T_SET:
+      assert(old_tts == ST1);
+      ts_stack_.push_back(old_tts);
+      break;
+
+    case T_MAP:
+      // mkv_stack_.back() is true while a key is being processed.
+      assert(old_tts == (mkv_stack_.back() ? ST1 : ST2));
+      mkv_stack_.back() = !mkv_stack_.back();
+      ts_stack_.push_back(mkv_stack_.back() ? ST1 : ST2);
+      break;
+
+    default:
+      assert(!"Invalid TType in stateTransition.");
+      break;
+
+  }
+}
+
+
+/*
+ * Variable-length quantity functions.
+ */
+
+/**
+ * Reads one variable-length quantity into @p vlq and returns the number of
+ * bytes consumed. Each byte contributes its low 7 bits, most significant
+ * group first. A byte with the high bit (0x80) clear terminates the value.
+ *
+ * Throws TProtocolException(INVALID_DATA), after resetState(), if no
+ * terminating byte appears within 10 bytes.
+ * NOTE(review): a 10-byte encoding carries 70 payload bits. The shifts into
+ * a uint64_t silently drop the excess high bits instead of rejecting them.
+ */
+inline uint32_t TDenseProtocol::vlqRead(uint64_t& vlq) {
+ uint32_t used = 0;
+ uint64_t val = 0;
+ uint8_t buf[10]; // 64 bits / (7 bits/byte) = 10 bytes.
+ uint32_t buf_size = sizeof(buf);
+ // NOTE(review): the fast path assumes a non-NULL borrow() result exposes at
+ // least buf_size readable bytes. Verify against TTransport::borrow.
+ const uint8_t* borrowed = trans_->borrow(buf, &buf_size);
+
+ // Fast path. TODO(dreiss): Make it faster.
+ // Bytes are only consumed from the transport once a full value is decoded.
+ if (borrowed != NULL) {
+ while (true) {
+ uint8_t byte = borrowed[used];
+ used++;
+ val = (val << 7) | (byte & 0x7f);
+ if (!(byte & 0x80)) {
+ vlq = val;
+ trans_->consume(used);
+ return used;
+ }
+ // Have to check for invalid data so we don't crash.
+ if (UNLIKELY(used == sizeof(buf))) {
+ resetState();
+ throw TProtocolException(TProtocolException::INVALID_DATA, "Variable-length int over 10 bytes.");
+ }
+ }
+ }
+
+ // Slow path.
+ // Reads one byte at a time directly from the transport.
+ else {
+ while (true) {
+ uint8_t byte;
+ used += trans_->readAll(&byte, 1);
+ val = (val << 7) | (byte & 0x7f);
+ if (!(byte & 0x80)) {
+ vlq = val;
+ return used;
+ }
+ // Might as well check for invalid data on the slow path too.
+ if (UNLIKELY(used >= sizeof(buf))) {
+ resetState();
+ throw TProtocolException(TProtocolException::INVALID_DATA, "Variable-length int over 10 bytes.");
+ }
+ }
+ }
+}
+
+/**
+ * Writes @p vlq as a variable-length quantity and returns the byte count.
+ * Groups of 7 bits are emitted most significant first. Every byte except
+ * the last has its continuation bit (0x80) set. A uint64_t needs at most
+ * 10 bytes.
+ */
+inline uint32_t TDenseProtocol::vlqWrite(uint64_t vlq) {
+  uint8_t buf[10];  // 64 bits / (7 bits/byte) = 10 bytes.
+  uint32_t start = sizeof(buf);
+
+  // The least significant group goes last and carries no continuation bit.
+  buf[--start] = (uint8_t)(vlq & 0x7f);
+  vlq >>= 7;
+
+  // Remaining groups are filled in from back to front.
+  while (vlq != 0) {
+    assert(start > 0);
+    buf[--start] = (uint8_t)((vlq & 0x7f) | 0x80);
+    vlq >>= 7;
+  }
+
+  const uint32_t len = sizeof(buf) - start;
+  trans_->write(buf + start, len);
+  return len;
+}
+
+
+
+/*
+ * Writing functions.
+ */
+
+/**
+ * Messages are not supported by TDenseProtocol yet; always throws
+ * TApplicationException.
+ *
+ * The intended encoding, formerly present as unreachable code after the
+ * throw, is: subWriteI32(VERSION_2 | messageType), subWriteString(name),
+ * subWriteI32(seqid).
+ */
+uint32_t TDenseProtocol::writeMessageBegin(const std::string& name,
+                                           const TMessageType messageType,
+                                           const int32_t seqid) {
+  (void)name;
+  (void)messageType;
+  (void)seqid;
+  throw TApplicationException("TDenseProtocol doesn't work with messages (yet).");
+}
+
+uint32_t TDenseProtocol::writeMessageEnd() {
+  // Nothing to write; messages have no trailer in this protocol.
+  const uint32_t written = 0;
+  return written;
+}
+
+/**
+ * Begins a struct. For the top-level struct (empty TypeSpec stack), this
+ * pushes the constructor's TypeSpec and writes FP_PREFIX_LEN bytes of its
+ * fingerprint. Every struct gets a fresh field index.
+ *
+ * @return number of bytes written (the fingerprint prefix, or 0)
+ * @throws TApplicationException if no TypeSpec was supplied
+ */
+uint32_t TDenseProtocol::writeStructBegin(const char* name) {
+  (void)name;  // struct names are not written by this protocol
+  uint32_t xfer = 0;
+
+  // The TypeSpec stack should be empty if this is the top-level read/write.
+  // If it is, we push the TypeSpec passed to the constructor.
+  if (ts_stack_.empty()) {
+    assert(standalone_);
+
+    if (type_spec_ == NULL) {
+      resetState();
+      throw TApplicationException("TDenseProtocol: No type specified.");
+    }
+
+    assert(type_spec_->ttype == T_STRUCT);
+    ts_stack_.push_back(type_spec_);
+    // Write out a prefix of the structure fingerprint.
+    trans_->write(type_spec_->fp_prefix, FP_PREFIX_LEN);
+    xfer += FP_PREFIX_LEN;
+  }
+
+  // We need a new field index for this structure.
+  idx_stack_.push_back(0);
+
+  // Report the fingerprint bytes actually written (this used to return 0).
+  return xfer;
+}
+
+uint32_t TDenseProtocol::writeStructEnd() {
+  // Drop this structure's field index, then run the shared post-value
+  // bookkeeping (stateTransition, defined elsewhere in this file).
+  idx_stack_.pop_back();
+  stateTransition();
+  const uint32_t written = 0;  // the end of a struct has no wire form
+  return written;
+}
+
+uint32_t TDenseProtocol::writeFieldBegin(const char* name,
+                                         const TType fieldType,
+                                         const int16_t fieldId) {
+  uint32_t xfer = 0;
+
+  // FMT / FTS / IDX are helper macros defined earlier in this file
+  // (presumably the current field's metadata, its TypeSpec, and its index).
+  // Any fields the caller skipped before fieldId must be optional. Emit a
+  // "not present" flag for each so the reader can skip it.
+  while (FMT.tag != fieldId) {
+    // TODO(dreiss): Old meta here.
+    assert(FTS->ttype != T_STOP);
+    assert(FMT.is_optional);
+    xfer += subWriteBool(false);
+    IDX++;
+  }
+
+  // TODO(dreiss): give a better exception.
+  assert(FTS->ttype == fieldType);
+
+  // A present optional field is preceded by a "present" flag.
+  if (FMT.is_optional) {
+    xfer += subWriteBool(true);
+  }
+
+  // writeFieldStop() delegates here with T_STOP to reuse the logic above.
+  // A stop marker carries no value. Only real fields (the common case)
+  // push the TypeSpec for the value about to be written.
+  if (FTS->ttype != T_STOP) {
+    ts_stack_.push_back(FTS);
+  }
+  return xfer;
+}
+
+uint32_t TDenseProtocol::writeFieldEnd() {
+  // Advance to the next field of the enclosing structure; nothing is written.
+  IDX++;
+  const uint32_t written = 0;
+  return written;
+}
+
+uint32_t TDenseProtocol::writeFieldStop() {
+  // Reuse writeFieldBegin's optional-field handling with a T_STOP sentinel.
+  // The call is qualified so it is not dispatched virtually.
+  const uint32_t written = TDenseProtocol::writeFieldBegin("", T_STOP, 0);
+  return written;
+}
+
+uint32_t TDenseProtocol::writeMapBegin(const TType keyType,
+                                       const TType valType,
+                                       const uint32_t size) {
+  checkTType(T_MAP);
+
+  // No element types go on the wire; they must agree with the map's
+  // TypeSpec (ST1 = key spec, ST2 = value spec).
+  assert(keyType == ST1->ttype);
+  assert(valType == ST2->ttype);
+
+  // Keys come first: push the key TypeSpec and mark "on a key"
+  // (mkv_stack_: true = key, false = value).
+  ts_stack_.push_back(ST1);
+  mkv_stack_.push_back(true);
+
+  // Only the element count is encoded.
+  const int32_t count = (int32_t)size;
+  return subWriteI32(count);
+}
+
+uint32_t TDenseProtocol::writeMapEnd() {
+  // Remove the element TypeSpec pushed for the map's entries and our
+  // key/value marker. stateTransition() handles the map's own TypeSpec.
+  ts_stack_.pop_back();
+  mkv_stack_.pop_back();
+  stateTransition();
+  const uint32_t written = 0;
+  return written;
+}
+
+uint32_t TDenseProtocol::writeListBegin(const TType elemType,
+                                        const uint32_t size) {
+  checkTType(T_LIST);
+
+  // The element type is implied by the TypeSpec (ST1), not written.
+  assert(elemType == ST1->ttype);
+  ts_stack_.push_back(ST1);
+
+  const int32_t count = (int32_t)size;
+  return subWriteI32(count);
+}
+
+uint32_t TDenseProtocol::writeListEnd() {
+  // Pop the element TypeSpec; stateTransition() handles the list's own one.
+  ts_stack_.pop_back();
+  stateTransition();
+  const uint32_t written = 0;
+  return written;
+}
+
+uint32_t TDenseProtocol::writeSetBegin(const TType elemType,
+                                       const uint32_t size) {
+  checkTType(T_SET);
+
+  // The element type is implied by the TypeSpec (ST1), not written.
+  assert(elemType == ST1->ttype);
+  ts_stack_.push_back(ST1);
+
+  const int32_t count = (int32_t)size;
+  return subWriteI32(count);
+}
+
+uint32_t TDenseProtocol::writeSetEnd() {
+  // Pop the element TypeSpec; stateTransition() handles the set's own one.
+  ts_stack_.pop_back();
+  stateTransition();
+  const uint32_t written = 0;
+  return written;
+}
+
+uint32_t TDenseProtocol::writeBool(const bool value) {
+  // Validate the expected type and advance state, then reuse
+  // TBinaryProtocol's bool encoding.
+  checkTType(T_BOOL);
+  stateTransition();
+  const uint32_t written = TBinaryProtocol::writeBool(value);
+  return written;
+}
+
+uint32_t TDenseProtocol::writeByte(const int8_t byte) {
+  // Bytes are written raw via TBinaryProtocol after the usual checks.
+  checkTType(T_BYTE);
+  stateTransition();
+  const uint32_t written = TBinaryProtocol::writeByte(byte);
+  return written;
+}
+
+uint32_t TDenseProtocol::writeI16(const int16_t i16) {
+  checkTType(T_I16);
+  stateTransition();
+  // Sign-extended to 64 bits, then varint-encoded.
+  const uint64_t bits = (uint64_t)i16;
+  return vlqWrite(bits);
+}
+
+uint32_t TDenseProtocol::writeI32(const int32_t i32) {
+  checkTType(T_I32);
+  stateTransition();
+  // Sign-extended to 64 bits, then varint-encoded.
+  const uint64_t bits = (uint64_t)i32;
+  return vlqWrite(bits);
+}
+
+uint32_t TDenseProtocol::writeI64(const int64_t i64) {
+  checkTType(T_I64);
+  stateTransition();
+  // Reinterpreted as unsigned 64 bits, then varint-encoded.
+  const uint64_t bits = (uint64_t)i64;
+  return vlqWrite(bits);
+}
+
+uint32_t TDenseProtocol::writeDouble(const double dub) {
+  // Doubles have no compact form here; defer to TBinaryProtocol.
+  checkTType(T_DOUBLE);
+  stateTransition();
+  const uint32_t written = TBinaryProtocol::writeDouble(dub);
+  return written;
+}
+
+uint32_t TDenseProtocol::writeString(const std::string& str) {
+  // Validate and advance, then emit the varint length plus raw bytes.
+  checkTType(T_STRING);
+  stateTransition();
+  const uint32_t written = subWriteString(str);
+  return written;
+}
+
+uint32_t TDenseProtocol::writeBinary(const std::string& str) {
+  // Binary blobs share the string encoding (non-virtual call on purpose).
+  const uint32_t written = TDenseProtocol::writeString(str);
+  return written;
+}
+
+inline uint32_t TDenseProtocol::subWriteI32(const int32_t i32) {
+  // Varint-encode without any type check or state transition. Negative
+  // values are sign-extended, so they always take the maximum 10 bytes.
+  const uint64_t bits = (uint64_t)i32;
+  return vlqWrite(bits);
+}
+
+uint32_t TDenseProtocol::subWriteString(const std::string& str) {
+  // Varint length prefix, then the raw bytes (none for an empty string).
+  const uint32_t len = str.size();
+  const uint32_t header = subWriteI32((int32_t)len);
+  if (len != 0) {
+    trans_->write((uint8_t*)str.data(), len);
+  }
+  return header + len;
+}
+
+
+
+/*
+ * Reading functions
+ *
+ * These have a lot of the same logic as the writing functions, so if
+ * something is confusing, look for comments in the corresponding writer.
+ */
+
+uint32_t TDenseProtocol::readMessageBegin(std::string& name,
+                                          TMessageType& messageType,
+                                          int32_t& seqid) {
+  // Message framing is not supported yet, so every call fails immediately.
+  throw TApplicationException("TDenseProtocol doesn't work with messages (yet).");
+
+  // Unreachable: a sketch of the intended header. A negative i32 carries
+  // the version and message type, followed by the name and seqid.
+  int32_t header;
+  uint32_t consumed = subReadI32(header);
+
+  if (header >= 0) {
+    throw TProtocolException(TProtocolException::BAD_VERSION, "No version identifier... old protocol client in strict mode?");
+  }
+  if ((header & VERSION_MASK) != VERSION_2) {
+    throw TProtocolException(TProtocolException::BAD_VERSION, "Bad version identifier");
+  }
+  messageType = (TMessageType)(header & 0x000000ff);
+  consumed += subReadString(name);
+  consumed += subReadI32(seqid);
+  return consumed;
+}
+
+uint32_t TDenseProtocol::readMessageEnd() {
+  // Nothing trails a message in this protocol; no bytes are consumed.
+  const uint32_t consumed = 0;
+  return consumed;
+}
+
+uint32_t TDenseProtocol::readStructBegin(string& name) {
+  uint32_t xfer = 0;
+
+  // An empty TypeSpec stack means this is the top-level structure of a
+  // standalone read. Seed the stack and verify the fingerprint prefix that
+  // writeStructBegin emitted.
+  if (ts_stack_.empty()) {
+    assert(standalone_);
+
+    if (type_spec_ == NULL) {
+      resetState();
+      throw TApplicationException("TDenseProtocol: No type specified.");
+    }
+
+    assert(type_spec_->ttype == T_STRUCT);
+    ts_stack_.push_back(type_spec_);
+
+    // Use readAll rather than read: read() may return fewer bytes, which
+    // would leave part of buf uninitialized for the comparison below.
+    uint8_t buf[FP_PREFIX_LEN];
+    xfer += trans_->readAll(buf, FP_PREFIX_LEN);
+    if (std::memcmp(buf, type_spec_->fp_prefix, FP_PREFIX_LEN) != 0) {
+      resetState();
+      throw TProtocolException(TProtocolException::INVALID_DATA,
+                               "Fingerprint in data does not match type_spec.");
+    }
+  }
+
+  // Every structure tracks its own field index, starting at the first field.
+  idx_stack_.push_back(0);
+
+  // Include the fingerprint bytes in the count reported to the caller.
+  return xfer;
+}
+
+uint32_t TDenseProtocol::readStructEnd() {
+  // Drop this structure's field index, then run the shared post-value
+  // bookkeeping (stateTransition, defined elsewhere in this file).
+  idx_stack_.pop_back();
+  stateTransition();
+  const uint32_t consumed = 0;
+  return consumed;
+}
+
+uint32_t TDenseProtocol::readFieldBegin(string& name,
+                                        TType& fieldType,
+                                        int16_t& fieldId) {
+  uint32_t consumed = 0;
+
+  // Every optional field is preceded by a presence flag. Skip absent ones
+  // until we reach a present optional field or a mandatory one.
+  // (FMT / FTS / IDX are helper macros defined earlier in this file.)
+  for (; FMT.is_optional; IDX++) {
+    bool present;
+    consumed += subReadBool(present);
+    if (present) {
+      break;
+    }
+  }
+
+  // FMT / FTS now describe the field whose value comes next.
+  fieldId = FMT.tag;
+  fieldType = FTS->ttype;
+
+  // T_STOP has no value to read, so only real fields push a TypeSpec.
+  if (FTS->ttype != T_STOP) {
+    ts_stack_.push_back(FTS);
+  }
+  return consumed;
+}
+
+uint32_t TDenseProtocol::readFieldEnd() {
+  // Advance to the next field of the enclosing structure; nothing is read.
+  IDX++;
+  const uint32_t consumed = 0;
+  return consumed;
+}
+
+uint32_t TDenseProtocol::readMapBegin(TType& keyType,
+                                      TType& valType,
+                                      uint32_t& size) {
+  checkTType(T_MAP);
+
+  // Only the element count is on the wire; validate it before use.
+  int32_t count;
+  const uint32_t consumed = subReadI32(count);
+  if (count < 0) {
+    resetState();
+    throw TProtocolException(TProtocolException::NEGATIVE_SIZE);
+  }
+  if (container_limit_ && count > container_limit_) {
+    resetState();
+    throw TProtocolException(TProtocolException::SIZE_LIMIT);
+  }
+  size = (uint32_t)count;
+
+  // Key/value types come from the TypeSpec (ST1 = key, ST2 = value).
+  keyType = ST1->ttype;
+  valType = ST2->ttype;
+
+  // Start on a key (mkv_stack_: true = key, false = value).
+  ts_stack_.push_back(ST1);
+  mkv_stack_.push_back(true);
+
+  return consumed;
+}
+
+uint32_t TDenseProtocol::readMapEnd() {
+  // Mirror of writeMapEnd: pop the entry TypeSpec and key/value marker.
+  ts_stack_.pop_back();
+  mkv_stack_.pop_back();
+  stateTransition();
+  const uint32_t consumed = 0;
+  return consumed;
+}
+
+uint32_t TDenseProtocol::readListBegin(TType& elemType,
+                                       uint32_t& size) {
+  checkTType(T_LIST);
+
+  // Only the element count is on the wire; validate it before use.
+  int32_t count;
+  const uint32_t consumed = subReadI32(count);
+  if (count < 0) {
+    resetState();
+    throw TProtocolException(TProtocolException::NEGATIVE_SIZE);
+  }
+  if (container_limit_ && count > container_limit_) {
+    resetState();
+    throw TProtocolException(TProtocolException::SIZE_LIMIT);
+  }
+  size = (uint32_t)count;
+
+  // The element type comes from the TypeSpec.
+  elemType = ST1->ttype;
+  ts_stack_.push_back(ST1);
+
+  return consumed;
+}
+
+uint32_t TDenseProtocol::readListEnd() {
+  // Pop the element TypeSpec; stateTransition() handles the list's own one.
+  ts_stack_.pop_back();
+  stateTransition();
+  const uint32_t consumed = 0;
+  return consumed;
+}
+
+uint32_t TDenseProtocol::readSetBegin(TType& elemType,
+                                      uint32_t& size) {
+  checkTType(T_SET);
+
+  // Only the element count is on the wire; validate it before use.
+  int32_t count;
+  const uint32_t consumed = subReadI32(count);
+  if (count < 0) {
+    resetState();
+    throw TProtocolException(TProtocolException::NEGATIVE_SIZE);
+  }
+  if (container_limit_ && count > container_limit_) {
+    resetState();
+    throw TProtocolException(TProtocolException::SIZE_LIMIT);
+  }
+  size = (uint32_t)count;
+
+  // The element type comes from the TypeSpec.
+  elemType = ST1->ttype;
+  ts_stack_.push_back(ST1);
+
+  return consumed;
+}
+
+uint32_t TDenseProtocol::readSetEnd() {
+  // Pop the element TypeSpec; stateTransition() handles the set's own one.
+  ts_stack_.pop_back();
+  stateTransition();
+  const uint32_t consumed = 0;
+  return consumed;
+}
+
+uint32_t TDenseProtocol::readBool(bool& value) {
+  // Validate and advance, then reuse TBinaryProtocol's bool decoding.
+  checkTType(T_BOOL);
+  stateTransition();
+  const uint32_t consumed = TBinaryProtocol::readBool(value);
+  return consumed;
+}
+
+uint32_t TDenseProtocol::readByte(int8_t& byte) {
+  // Bytes are read raw via TBinaryProtocol after the usual checks.
+  checkTType(T_BYTE);
+  stateTransition();
+  const uint32_t consumed = TBinaryProtocol::readByte(byte);
+  return consumed;
+}
+
+uint32_t TDenseProtocol::readI16(int16_t& i16) {
+  checkTType(T_I16);
+  stateTransition();
+
+  // Decode the varint as a signed 64-bit value and make sure it fits.
+  uint64_t raw;
+  const uint32_t consumed = vlqRead(raw);
+  const int64_t wide = (int64_t)raw;
+  const bool fits = (wide >= INT16_MIN) && (wide <= INT16_MAX);
+  if (UNLIKELY(!fits)) {
+    resetState();
+    throw TProtocolException(TProtocolException::INVALID_DATA,
+                             "i16 out of range.");
+  }
+  i16 = (int16_t)wide;
+  return consumed;
+}
+
+uint32_t TDenseProtocol::readI32(int32_t& i32) {
+  checkTType(T_I32);
+  stateTransition();
+
+  // Decode the varint as a signed 64-bit value and make sure it fits.
+  uint64_t raw;
+  const uint32_t consumed = vlqRead(raw);
+  const int64_t wide = (int64_t)raw;
+  const bool fits = (wide >= INT32_MIN) && (wide <= INT32_MAX);
+  if (UNLIKELY(!fits)) {
+    resetState();
+    throw TProtocolException(TProtocolException::INVALID_DATA,
+                             "i32 out of range.");
+  }
+  i32 = (int32_t)wide;
+  return consumed;
+}
+
+uint32_t TDenseProtocol::readI64(int64_t& i64) {
+  checkTType(T_I64);
+  stateTransition();
+
+  uint64_t u64;
+  const uint32_t rv = vlqRead(u64);
+
+  // Every 64-bit pattern is a valid i64, so unlike readI16/readI32 there is
+  // nothing to range-check. Comparing an int64_t against INT64_MAX or
+  // INT64_MIN can never be true.
+  i64 = (int64_t)u64;
+  return rv;
+}
+
+uint32_t TDenseProtocol::readDouble(double& dub) {
+  // Doubles have no compact form here; defer to TBinaryProtocol.
+  checkTType(T_DOUBLE);
+  stateTransition();
+  const uint32_t consumed = TBinaryProtocol::readDouble(dub);
+  return consumed;
+}
+
+uint32_t TDenseProtocol::readString(std::string& str) {
+  // Validate and advance, then read the varint length plus raw bytes.
+  checkTType(T_STRING);
+  stateTransition();
+  const uint32_t consumed = subReadString(str);
+  return consumed;
+}
+
+uint32_t TDenseProtocol::readBinary(std::string& str) {
+  // Binary blobs share the string encoding (non-virtual call on purpose).
+  const uint32_t consumed = TDenseProtocol::readString(str);
+  return consumed;
+}
+
+uint32_t TDenseProtocol::subReadI32(int32_t& i32) {
+  // Varint-decode without any type check or state transition, rejecting
+  // values outside the i32 range.
+  uint64_t raw;
+  const uint32_t consumed = vlqRead(raw);
+  const int64_t wide = (int64_t)raw;
+  const bool fits = (wide >= INT32_MIN) && (wide <= INT32_MAX);
+  if (UNLIKELY(!fits)) {
+    resetState();
+    throw TProtocolException(TProtocolException::INVALID_DATA,
+                             "i32 out of range.");
+  }
+  i32 = (int32_t)wide;
+  return consumed;
+}
+
+uint32_t TDenseProtocol::subReadString(std::string& str) {
+  int32_t size;
+  const uint32_t xfer = subReadI32(size);
+
+  // Validate the length here, as the container readers do, so the
+  // TypeSpec/index stacks are reset before the exception propagates.
+  // readStringBody would also reject these lengths, but without resetState(),
+  // which would leave this object unusable for the next read.
+  if (size < 0) {
+    resetState();
+    throw TProtocolException(TProtocolException::NEGATIVE_SIZE);
+  } else if (string_limit_ > 0 && size > string_limit_) {
+    resetState();
+    throw TProtocolException(TProtocolException::SIZE_LIMIT);
+  }
+
+  return xfer + readStringBody(str, size);
+}
+
+}}} // apache::thrift::protocol
diff --git a/lib/cpp/src/protocol/TDenseProtocol.h b/lib/cpp/src/protocol/TDenseProtocol.h
new file mode 100644
index 0000000..7655a47
--- /dev/null
+++ b/lib/cpp/src/protocol/TDenseProtocol.h
@@ -0,0 +1,253 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_PROTOCOL_TDENSEPROTOCOL_H_
+#define _THRIFT_PROTOCOL_TDENSEPROTOCOL_H_ 1
+
+#include "TBinaryProtocol.h"
+
+namespace apache { namespace thrift { namespace protocol {
+
+/**
+ * !!!WARNING!!!
+ * This class is still highly experimental. Incompatible changes
+ * WILL be made to it without notice. DO NOT USE IT YET unless
+ * you are coordinating your testing with the author.
+ *
+ * The dense protocol is designed to use as little space as possible.
+ *
+ * There are two types of dense protocol instances. Standalone instances
+ * are not used for RPC; they just encode and decode structures of
+ * a predetermined type. Non-standalone instances are used for RPC.
+ * Currently, only standalone instances exist.
+ *
+ * To use a standalone dense protocol object, you must set the type_spec
+ * property (either in the constructor, or with setTypeSpec) to the local
+ * reflection TypeSpec of the structures you will write to (or read from) the
+ * protocol instance.
+ *
+ * BEST PRACTICES:
+ * - Never use optional for primitives or containers.
+ * - Only use optional for structures if they are very big and very rarely set.
+ * - All integers are variable-length, so you can use i64 without bloating.
+ * - NEVER EVER change the struct definitions IN ANY WAY without either
+ * changing your cache keys or talking to dreiss.
+ *
+ * TODO(dreiss): New class write with old meta.
+ *
+ * We override all of TBinaryProtocol's methods.
+ * We inherit so that we can explicitly call TBinaryProtocol's primitive
+ * reading and writing methods within our versions.
+ *
+ */
+class TDenseProtocol : public TBinaryProtocol {
+ protected:
+ // Mask selecting the version bits of a message header word.
+ static const int32_t VERSION_MASK = 0xffff0000;
+ // VERSION_1 (0x80010000) is taken by TBinaryProtocol.
+ static const int32_t VERSION_2 = 0x80020000;
+
+ public:
+ typedef apache::thrift::reflection::local::TypeSpec TypeSpec;
+ // Number of fingerprint-prefix bytes written before (and checked after)
+ // a top-level structure. The value is defined out of line.
+ static const int FP_PREFIX_LEN;
+
+ /**
+ * @param trans The transport to use.
+ * @param type_spec The TypeSpec of the structures using this protocol.
+ * The pointer is stored as-is; this class does not take ownership.
+ */
+ TDenseProtocol(boost::shared_ptr<TTransport> trans,
+ TypeSpec* type_spec = NULL) :
+ TBinaryProtocol(trans),
+ type_spec_(type_spec),
+ standalone_(true)
+ {}
+
+ // Set / get the TypeSpec of the top-level structure (not owned).
+ void setTypeSpec(TypeSpec* type_spec) {
+ type_spec_ = type_spec;
+ }
+ TypeSpec* getTypeSpec() {
+ return type_spec_;
+ }
+
+
+ /*
+ * Writing functions.
+ */
+
+ virtual uint32_t writeMessageBegin(const std::string& name,
+ const TMessageType messageType,
+ const int32_t seqid);
+
+ virtual uint32_t writeMessageEnd();
+
+
+ virtual uint32_t writeStructBegin(const char* name);
+
+ virtual uint32_t writeStructEnd();
+
+ virtual uint32_t writeFieldBegin(const char* name,
+ const TType fieldType,
+ const int16_t fieldId);
+
+ virtual uint32_t writeFieldEnd();
+
+ virtual uint32_t writeFieldStop();
+
+ virtual uint32_t writeMapBegin(const TType keyType,
+ const TType valType,
+ const uint32_t size);
+
+ virtual uint32_t writeMapEnd();
+
+ virtual uint32_t writeListBegin(const TType elemType,
+ const uint32_t size);
+
+ virtual uint32_t writeListEnd();
+
+ virtual uint32_t writeSetBegin(const TType elemType,
+ const uint32_t size);
+
+ virtual uint32_t writeSetEnd();
+
+ virtual uint32_t writeBool(const bool value);
+
+ virtual uint32_t writeByte(const int8_t byte);
+
+ virtual uint32_t writeI16(const int16_t i16);
+
+ virtual uint32_t writeI32(const int32_t i32);
+
+ virtual uint32_t writeI64(const int64_t i64);
+
+ virtual uint32_t writeDouble(const double dub);
+
+ virtual uint32_t writeString(const std::string& str);
+
+ virtual uint32_t writeBinary(const std::string& str);
+
+
+ /*
+ * Helper writing functions (don't do state transitions).
+ */
+ // Writes i32 as a variable-length integer.
+ inline uint32_t subWriteI32(const int32_t i32);
+
+ // Writes a variable-length size prefix followed by the raw bytes.
+ inline uint32_t subWriteString(const std::string& str);
+
+ // Writes a bool using TBinaryProtocol's encoding (used for optional-field
+ // presence flags).
+ uint32_t subWriteBool(const bool value) {
+ return TBinaryProtocol::writeBool(value);
+ }
+
+
+ /*
+ * Reading functions
+ */
+
+ uint32_t readMessageBegin(std::string& name,
+ TMessageType& messageType,
+ int32_t& seqid);
+
+ uint32_t readMessageEnd();
+
+ uint32_t readStructBegin(std::string& name);
+
+ uint32_t readStructEnd();
+
+ uint32_t readFieldBegin(std::string& name,
+ TType& fieldType,
+ int16_t& fieldId);
+
+ uint32_t readFieldEnd();
+
+ uint32_t readMapBegin(TType& keyType,
+ TType& valType,
+ uint32_t& size);
+
+ uint32_t readMapEnd();
+
+ uint32_t readListBegin(TType& elemType,
+ uint32_t& size);
+
+ uint32_t readListEnd();
+
+ uint32_t readSetBegin(TType& elemType,
+ uint32_t& size);
+
+ uint32_t readSetEnd();
+
+ uint32_t readBool(bool& value);
+
+ uint32_t readByte(int8_t& byte);
+
+ uint32_t readI16(int16_t& i16);
+
+ uint32_t readI32(int32_t& i32);
+
+ uint32_t readI64(int64_t& i64);
+
+ uint32_t readDouble(double& dub);
+
+ uint32_t readString(std::string& str);
+
+ uint32_t readBinary(std::string& str);
+
+ /*
+ * Helper reading functions (don't do state transitions).
+ */
+ // Reads a variable-length integer and range-checks it as an i32.
+ inline uint32_t subReadI32(int32_t& i32);
+
+ // Reads a variable-length size prefix followed by that many bytes.
+ inline uint32_t subReadString(std::string& str);
+
+ // Reads a bool using TBinaryProtocol's encoding (used for optional-field
+ // presence flags).
+ uint32_t subReadBool(bool& value) {
+ return TBinaryProtocol::readBool(value);
+ }
+
+
+ private:
+
+ // Implementation functions, documented in the .cpp.
+ inline void checkTType(const TType ttype);
+ inline void stateTransition();
+
+ // Read and write variable-length integers.
+ // Uses the same technique as the MIDI file format.
+ inline uint32_t vlqRead(uint64_t& vlq);
+ inline uint32_t vlqWrite(uint64_t vlq);
+
+ // Called before throwing an exception to make the object reusable.
+ // Clears all three traversal stacks.
+ void resetState() {
+ ts_stack_.clear();
+ idx_stack_.clear();
+ mkv_stack_.clear();
+ }
+
+ // TypeSpec of the top-level structure to write,
+ // for standalone protocol objects.
+ TypeSpec* type_spec_;
+
+ // NOTE: member declaration order fixes layout and initialization order;
+ // the constructor's initializer list relies on it.
+ std::vector<TypeSpec*> ts_stack_; // TypeSpec stack.
+ std::vector<int> idx_stack_; // InDeX stack.
+ std::vector<bool> mkv_stack_; // Map Key/Value stack.
+ // True = key, False = value.
+
+ // True iff this is a standalone instance (no RPC).
+ bool standalone_;
+};
+
+}}} // apache::thrift::protocol
+
+#endif // #ifndef _THRIFT_PROTOCOL_TDENSEPROTOCOL_H_
diff --git a/lib/cpp/src/protocol/TJSONProtocol.cpp b/lib/cpp/src/protocol/TJSONProtocol.cpp
new file mode 100644
index 0000000..2a9c8f0
--- /dev/null
+++ b/lib/cpp/src/protocol/TJSONProtocol.cpp
@@ -0,0 +1,998 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "TJSONProtocol.h"
+
+#include <math.h>
+#include <boost/lexical_cast.hpp>
+#include "TBase64Utils.h"
+#include <transport/TTransportException.h>
+
+using namespace apache::thrift::transport;
+
+namespace apache { namespace thrift { namespace protocol {
+
+
+// Static data
+
+// Structural characters of the JSON grammar.
+static const uint8_t kJSONObjectStart('{');
+static const uint8_t kJSONObjectEnd('}');
+static const uint8_t kJSONArrayStart('[');
+static const uint8_t kJSONArrayEnd(']');
+static const uint8_t kJSONNewline('\n');
+static const uint8_t kJSONPairSeparator(':');
+static const uint8_t kJSONElemSeparator(',');
+
+// Characters involved in string escaping.
+static const uint8_t kJSONBackslash('\\');
+static const uint8_t kJSONStringDelimiter('"');
+static const uint8_t kJSONZeroChar('0');
+static const uint8_t kJSONEscapeChar('u');
+
+// Prefix written before the two hex digits of a "\u00xx" escape.
+static const std::string kJSONEscapePrefix = "\\u00";
+
+// Version number written at the start of every message.
+static const uint32_t kThriftVersion1(1);
+
+// Spellings used for non-finite doubles (always written as JSON strings).
+static const std::string kThriftNan = "NaN";
+static const std::string kThriftInfinity = "Infinity";
+static const std::string kThriftNegativeInfinity = "-Infinity";
+
+// Short type tags written in field and container headers.
+static const std::string kTypeNameBool = "tf";
+static const std::string kTypeNameByte = "i8";
+static const std::string kTypeNameI16 = "i16";
+static const std::string kTypeNameI32 = "i32";
+static const std::string kTypeNameI64 = "i64";
+static const std::string kTypeNameDouble = "dbl";
+static const std::string kTypeNameStruct = "rec";
+static const std::string kTypeNameString = "str";
+static const std::string kTypeNameMap = "map";
+static const std::string kTypeNameList = "lst";
+static const std::string kTypeNameSet = "set";
+
+// Map a Thrift type to the short tag used in JSON field and container
+// headers. Throws NOT_IMPLEMENTED for types that have no tag (e.g. T_STOP).
+static const std::string &getTypeNameForTypeID(TType typeID) {
+  if (typeID == T_BOOL)   { return kTypeNameBool; }
+  if (typeID == T_BYTE)   { return kTypeNameByte; }
+  if (typeID == T_I16)    { return kTypeNameI16; }
+  if (typeID == T_I32)    { return kTypeNameI32; }
+  if (typeID == T_I64)    { return kTypeNameI64; }
+  if (typeID == T_DOUBLE) { return kTypeNameDouble; }
+  if (typeID == T_STRING) { return kTypeNameString; }
+  if (typeID == T_STRUCT) { return kTypeNameStruct; }
+  if (typeID == T_MAP)    { return kTypeNameMap; }
+  if (typeID == T_SET)    { return kTypeNameSet; }
+  if (typeID == T_LIST)   { return kTypeNameList; }
+  throw TProtocolException(TProtocolException::NOT_IMPLEMENTED,
+                           "Unrecognized type");
+}
+
+// Decode a JSON type tag (the inverse of getTypeNameForTypeID).
+// Only the first two characters are inspected; tags shorter than two
+// characters, or with unknown prefixes, throw NOT_IMPLEMENTED.
+static TType getTypeIDForTypeName(const std::string &name) {
+  TType result = T_STOP;  // T_STOP doubles as "not recognized"
+  if (name.length() >= 2) {
+    const char first = name[0];
+    const char second = name[1];
+    if (first == 'd') {
+      result = T_DOUBLE;
+    } else if (first == 'i') {
+      if (second == '8') {
+        result = T_BYTE;
+      } else if (second == '1') {
+        result = T_I16;
+      } else if (second == '3') {
+        result = T_I32;
+      } else if (second == '6') {
+        result = T_I64;
+      }
+    } else if (first == 'l') {
+      result = T_LIST;
+    } else if (first == 'm') {
+      result = T_MAP;
+    } else if (first == 'r') {
+      result = T_STRUCT;
+    } else if (first == 's') {
+      if (second == 't') {
+        result = T_STRING;
+      } else if (second == 'e') {
+        result = T_SET;
+      }
+    } else if (first == 't') {
+      result = T_BOOL;
+    }
+  }
+
+  if (result == T_STOP) {
+    throw TProtocolException(TProtocolException::NOT_IMPLEMENTED,
+                             "Unrecognized type");
+  }
+  return result;
+}
+
+
+// This table describes how writeJSONChar handles each of the first 0x30
+// characters when it appears inside a JSON string:
+// 0 : escape using "\u00xx" notation (control characters)
+// 1 : output the character unchanged
+// <other> : escape using "\<other>" notation, e.g. '\n' becomes "\n"
+// Characters >= 0x30 are handled directly in writeJSONChar.
+static const uint8_t kJSONCharTable[0x30] = {
+// 0 1 2 3 4 5 6 7 8 9 A B C D E F
+ 0, 0, 0, 0, 0, 0, 0, 0,'b','t','n', 0,'f','r', 0, 0, // 0
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 1
+ 1, 1,'"', 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 2
+};
+
+
+// Characters that may follow a backslash in a JSON string escape (other
+// than 'u', which introduces a "\uXXXX" escape and is handled separately).
+// The position of each character here must match the position of its
+// decoded value in kEscapeCharVals. The solidus ('/') is included because
+// RFC 4627 section 2.5 permits the "\/" escape.
+const static std::string kEscapeChars("\"\\/bfnrt");
+
+// Decoded values for the escapes in kEscapeChars, in the same order.
+const static uint8_t kEscapeCharVals[8] = {
+  '"', '\\', '/', '\b', '\f', '\n', '\r', '\t',
+};
+
+
+// Static helper functions
+
+// Read 1 character from the transport trans and verify that it is the
+// expected character ch.
+// Throw a protocol exception if it is not.
+static uint32_t readSyntaxChar(TJSONProtocol::LookaheadReader &reader,
+                               uint8_t ch) {
+  // Consume exactly one character and insist that it is `ch`.
+  const uint8_t actual = reader.read();
+  if (actual == ch) {
+    return 1;
+  }
+  throw TProtocolException(TProtocolException::INVALID_DATA,
+                           "Expected \'" + std::string((char *)&ch, 1) +
+                           "\'; got \'" + std::string((char *)&actual, 1) +
+                           "\'.");
+}
+
+// Return the integer value of a hex character ch.
+// Throw a protocol exception if the character is not [0-9a-f].
+// Return the integer value (0-15) of the hex digit ch.
+// Accepts [0-9a-fA-F], since JSON \u escapes may use either case.
+// Throws a protocol exception for any other character.
+static uint8_t hexVal(uint8_t ch) {
+  if ((ch >= '0') && (ch <= '9')) {
+    return ch - '0';
+  }
+  else if ((ch >= 'a') && (ch <= 'f')) {
+    return ch - 'a' + 10;  // 'a' is 10, not 0
+  }
+  else if ((ch >= 'A') && (ch <= 'F')) {
+    return ch - 'A' + 10;
+  }
+  else {
+    throw TProtocolException(TProtocolException::INVALID_DATA,
+                             "Expected hex val ([0-9a-fA-F]); got \'"
+                             + std::string((char *)&ch, 1) + "\'.");
+  }
+}
+
+// Return the hex character representing the integer val. The value is masked
+// to make sure it is in the correct range.
+// Return the lowercase hex digit for the low nibble of val.
+// The upper four bits of val are ignored.
+static uint8_t hexChar(uint8_t val) {
+  val &= 0x0F;
+  if (val < 10) {
+    return val + '0';
+  }
+  else {
+    return val - 10 + 'a';  // 10..15 map to 'a'..'f'
+  }
+}
+
+// Return true if the character ch is in [-+0-9.Ee]; false otherwise
+// Return true if ch may appear in a JSON number literal, i.e. ch is one of
+// [-+0-9.Ee]; false otherwise.
+static bool isJSONNumeric(uint8_t ch) {
+  if (ch >= '0' && ch <= '9') {
+    return true;
+  }
+  return ch == '+' || ch == '-' || ch == '.' || ch == 'E' || ch == 'e';
+}
+
+
+/**
+ * Class to serve as base JSON context and as base class for other context
+ * implementations
+ */
+class TJSONContext {
+
+ public:
+
+  TJSONContext() {}
+
+  virtual ~TJSONContext() {}
+
+  /**
+   * Emit whatever separator must precede the next value in this context.
+   * The base (top-level) context needs none, so nothing is written.
+   * @return number of bytes written.
+   */
+  virtual uint32_t write(TTransport & /* trans */) {
+    return 0;
+  }
+
+  /**
+   * Consume whatever separator must precede the next value in this context.
+   * The base context expects none, so nothing is read.
+   * @return number of bytes read.
+   */
+  virtual uint32_t read(TJSONProtocol::LookaheadReader & /* reader */) {
+    return 0;
+  }
+
+  /**
+   * Whether a number written in this context must be wrapped in quotes.
+   * Never needed at the base level.
+   */
+  virtual bool escapeNum() {
+    return false;
+  }
+};
+
+// Context class for object member key-value pairs
+class JSONPairContext : public TJSONContext {
+
+public:
+
+  JSONPairContext() : first_(true), colon_(true) {}
+
+  // Nothing precedes the first key. After that, separators alternate:
+  // ':' between a key and its value, ',' between one pair and the next.
+  uint32_t write(TTransport &trans) {
+    if (!first_) {
+      const uint8_t *sep = colon_ ? &kJSONPairSeparator : &kJSONElemSeparator;
+      trans.write(sep, 1);
+      colon_ = !colon_;
+      return 1;
+    }
+    first_ = false;
+    colon_ = true;
+    return 0;
+  }
+
+  // Reading mirrors write(): expect the same alternating separators.
+  uint32_t read(TJSONProtocol::LookaheadReader &reader) {
+    if (!first_) {
+      const uint8_t expected = colon_ ? kJSONPairSeparator : kJSONElemSeparator;
+      colon_ = !colon_;
+      return readSyntaxChar(reader, expected);
+    }
+    first_ = false;
+    colon_ = true;
+    return 0;
+  }
+
+  // colon_ is true while positioned on a key, and JSON object keys must be
+  // strings, so numeric keys get quoted.
+  virtual bool escapeNum() {
+    return colon_;
+  }
+
+ private:
+
+  bool first_;  // no element has been written/read yet
+  bool colon_;  // true: next separator is ':' (we are on a key)
+};
+
+// Context class for lists
+class JSONListContext : public TJSONContext {
+
+public:
+
+  JSONListContext() : first_(true) {}
+
+  // Elements after the first are preceded by ','.
+  uint32_t write(TTransport &trans) {
+    if (!first_) {
+      trans.write(&kJSONElemSeparator, 1);
+      return 1;
+    }
+    first_ = false;
+    return 0;
+  }
+
+  // Reading mirrors write(): expect ',' before every element but the first.
+  uint32_t read(TJSONProtocol::LookaheadReader &reader) {
+    if (!first_) {
+      return readSyntaxChar(reader, kJSONElemSeparator);
+    }
+    first_ = false;
+    return 0;
+  }
+
+ private:
+  bool first_;  // no element has been written/read yet
+};
+
+
+// Start in the base context, which writes/reads no separators. The lookahead
+// reader wraps the same transport.
+TJSONProtocol::TJSONProtocol(boost::shared_ptr<TTransport> ptrans)
+  : TProtocol(ptrans),
+    context_(new TJSONContext()),
+    reader_(*ptrans) {
+}
+
+// Nothing to release explicitly; members clean up after themselves.
+TJSONProtocol::~TJSONProtocol() {
+}
+
+void TJSONProtocol::pushContext(boost::shared_ptr<TJSONContext> next) {
+  // Save the active context, then make `next` the active one.
+  contexts_.push(context_);
+  context_ = next;
+}
+
+void TJSONProtocol::popContext() {
+  // Restore the most recently saved context and remove it from the stack.
+  context_ = contexts_.top();
+  contexts_.pop();
+}
+
+// Write the character ch as a JSON escape sequence ("\u00xx")
+uint32_t TJSONProtocol::writeJSONEscapeChar(uint8_t ch) {
+  // Emit "\u00" followed by the high and then the low nibble of ch as hex.
+  trans_->write((const uint8_t *)kJSONEscapePrefix.c_str(),
+                kJSONEscapePrefix.length());
+  const uint8_t digits[2] = { hexChar(ch >> 4), hexChar(ch) };
+  trans_->write(&digits[0], 1);
+  trans_->write(&digits[1], 1);
+  return kJSONEscapePrefix.length() + 2;  // 4-byte prefix + 2 digits = 6
+}
+
+// Write the character ch as part of a JSON string, escaping as appropriate.
+uint32_t TJSONProtocol::writeJSONChar(uint8_t ch) {
+  // At or above 0x30, only the backslash needs escaping.
+  if (ch >= 0x30) {
+    if (ch != kJSONBackslash) {
+      trans_->write(&ch, 1);
+      return 1;
+    }
+    trans_->write(&kJSONBackslash, 1);
+    trans_->write(&kJSONBackslash, 1);
+    return 2;
+  }
+
+  // Below 0x30, kJSONCharTable decides: 0 = "\u00xx" escape, 1 = verbatim,
+  // anything else = the letter to emit after a backslash.
+  const uint8_t action = kJSONCharTable[ch];
+  if (action == 0) {
+    return writeJSONEscapeChar(ch);
+  }
+  if (action == 1) {
+    trans_->write(&ch, 1);
+    return 1;
+  }
+  trans_->write(&kJSONBackslash, 1);
+  trans_->write(&action, 1);
+  return 2;
+}
+
+// Write out the contents of the string str as a JSON string, escaping
+// characters as appropriate.
+uint32_t TJSONProtocol::writeJSONString(const std::string &str) {
+  // Context separator, opening quote, escaped characters, closing quote.
+  uint32_t total = context_->write(*trans_);
+  trans_->write(&kJSONStringDelimiter, 1);
+  for (std::string::size_type i = 0; i < str.size(); ++i) {
+    total += writeJSONChar(str[i]);
+  }
+  trans_->write(&kJSONStringDelimiter, 1);
+  return total + 2;  // the two quotes
+}
+
+// Write out the contents of the string as JSON string, base64-encoding
+// the string's contents, and escaping as appropriate
+uint32_t TJSONProtocol::writeJSONBase64(const std::string &str) {
+  // Write str base64-encoded inside a JSON string literal.
+  uint32_t total = context_->write(*trans_);
+  total += 2;  // the two quotes
+  trans_->write(&kJSONStringDelimiter, 1);
+
+  const uint8_t *src = (const uint8_t *)str.c_str();
+  uint32_t remaining = str.length();
+  uint8_t encoded[4];
+
+  // Each full 3-byte group produces 4 output bytes.
+  for (; remaining >= 3; remaining -= 3, src += 3) {
+    base64_encode(src, 3, encoded);
+    trans_->write(encoded, 4);
+    total += 4;
+  }
+
+  // A trailing 1- or 2-byte group produces 2 or 3 output bytes; no '='
+  // padding is written.
+  if (remaining != 0) {
+    base64_encode(src, remaining, encoded);
+    trans_->write(encoded, remaining + 1);
+    total += remaining + 1;
+  }
+
+  trans_->write(&kJSONStringDelimiter, 1);
+  return total;
+}
+
+// Convert the given integer type to a JSON number, or a string
+// if the context requires it (eg: key in a map pair).
+template <typename NumberType>
+uint32_t TJSONProtocol::writeJSONInteger(NumberType num) {
+  // The context separator goes out first. This also updates the context
+  // state that escapeNum() consults below.
+  uint32_t total = context_->write(*trans_);
+  const std::string digits(boost::lexical_cast<std::string>(num));
+
+  // Some contexts (e.g. an object key) require the number to be quoted.
+  const bool quoted = context_->escapeNum();
+  if (quoted) {
+    trans_->write(&kJSONStringDelimiter, 1);
+    ++total;
+  }
+  trans_->write((const uint8_t *)digits.c_str(), digits.length());
+  total += digits.length();
+  if (quoted) {
+    trans_->write(&kJSONStringDelimiter, 1);
+    ++total;
+  }
+  return total;
+}
+
+// Convert the given double to a JSON string, which is either the number,
+// "NaN" or "Infinity" or "-Infinity".
+uint32_t TJSONProtocol::writeJSONDouble(double num) {
+  uint32_t total = context_->write(*trans_);
+  std::string text(boost::lexical_cast<std::string>(num));
+
+  // boost::lexical_cast may spell NaN/infinities with either letter case;
+  // normalize them to this protocol's names. These special values are not
+  // JSON numbers, so they are always quoted.
+  bool special = false;
+  const char lead = text[0];
+  if (lead == 'N' || lead == 'n') {
+    text = kThriftNan;
+    special = true;
+  } else if (lead == 'I' || lead == 'i') {
+    text = kThriftInfinity;
+    special = true;
+  } else if (lead == '-' && (text[1] == 'I' || text[1] == 'i')) {
+    text = kThriftNegativeInfinity;
+    special = true;
+  }
+
+  const bool quoted = special || context_->escapeNum();
+  if (quoted) {
+    trans_->write(&kJSONStringDelimiter, 1);
+    ++total;
+  }
+  trans_->write((const uint8_t *)text.c_str(), text.length());
+  total += text.length();
+  if (quoted) {
+    trans_->write(&kJSONStringDelimiter, 1);
+    ++total;
+  }
+  return total;
+}
+
+uint32_t TJSONProtocol::writeJSONObjectStart() {
+  // Separator for the enclosing context, then '{'. Members that follow use
+  // a fresh key/value pair context.
+  const uint32_t prefix = context_->write(*trans_);
+  trans_->write(&kJSONObjectStart, 1);
+  pushContext(boost::shared_ptr<TJSONContext>(new JSONPairContext()));
+  return prefix + 1;
+}
+
+uint32_t TJSONProtocol::writeJSONObjectEnd() {
+  // Leave the object's pair context, then close the object.
+  popContext();
+  trans_->write(&kJSONObjectEnd, 1);
+  const uint32_t written = 1;
+  return written;
+}
+
+uint32_t TJSONProtocol::writeJSONArrayStart() {
+  // Separator for the enclosing context, then '['. Elements that follow use
+  // a fresh list context.
+  const uint32_t prefix = context_->write(*trans_);
+  trans_->write(&kJSONArrayStart, 1);
+  pushContext(boost::shared_ptr<TJSONContext>(new JSONListContext()));
+  return prefix + 1;
+}
+
+uint32_t TJSONProtocol::writeJSONArrayEnd() {
+  // Leave the array's list context, then close the array.
+  popContext();
+  trans_->write(&kJSONArrayEnd, 1);
+  const uint32_t written = 1;
+  return written;
+}
+
+// Opens a message: a JSON array whose first four elements are the protocol
+// version, the message name, the message type and the sequence ID. The
+// array is closed by writeMessageEnd().
+uint32_t TJSONProtocol::writeMessageBegin(const std::string& name,
+                                          const TMessageType messageType,
+                                          const int32_t seqid) {
+  uint32_t written = writeJSONArrayStart();
+  written += writeJSONInteger(kThriftVersion1);
+  written += writeJSONString(name);
+  written += writeJSONInteger(messageType);
+  written += writeJSONInteger(seqid);
+  return written;
+}
+
+uint32_t TJSONProtocol::writeMessageEnd() {
+  // Closes the JSON array opened by writeMessageBegin().
+  const uint32_t written = writeJSONArrayEnd();
+  return written;
+}
+
+uint32_t TJSONProtocol::writeStructBegin(const char* name) {
+  // A struct is a JSON object keyed by field ID; the struct name is not
+  // written.
+  const uint32_t written = writeJSONObjectStart();
+  return written;
+}
+
+uint32_t TJSONProtocol::writeStructEnd() {
+  // Closes the JSON object opened by writeStructBegin().
+  const uint32_t written = writeJSONObjectEnd();
+  return written;
+}
+
+// Writes a field header. The field ID becomes the key in the enclosing struct
+// object, and its value is a one-entry JSON object whose key is the field's
+// type identifier; writeFieldEnd() closes that object. The field name is not
+// written.
+uint32_t TJSONProtocol::writeFieldBegin(const char* name,
+                                        const TType fieldType,
+                                        const int16_t fieldId) {
+  uint32_t written = 0;
+  written += writeJSONInteger(fieldId);
+  written += writeJSONObjectStart();
+  written += writeJSONString(getTypeNameForTypeID(fieldType));
+  return written;
+}
+
+uint32_t TJSONProtocol::writeFieldEnd() {
+  // Closes the one-entry object opened by writeFieldBegin().
+  const uint32_t written = writeJSONObjectEnd();
+  return written;
+}
+
+uint32_t TJSONProtocol::writeFieldStop() {
+  // Nothing is written: a reader detects the end of the fields from the '}'
+  // that writeStructEnd() emits (see readFieldBegin()).
+  const uint32_t bytesWritten = 0;
+  return bytesWritten;
+}
+
+// Opens a map as a JSON array [key type, value type, count, {pairs}]. The
+// inner object that holds the pairs is left open; writeMapEnd() closes both.
+uint32_t TJSONProtocol::writeMapBegin(const TType keyType,
+                                      const TType valType,
+                                      const uint32_t size) {
+  uint32_t written = writeJSONArrayStart();
+  written += writeJSONString(getTypeNameForTypeID(keyType));
+  written += writeJSONString(getTypeNameForTypeID(valType));
+  written += writeJSONInteger(static_cast<int64_t>(size));
+  written += writeJSONObjectStart();
+  return written;
+}
+
+// Closes a map: first the JSON object holding the key/value pairs, then the
+// outer JSON array opened by writeMapBegin(). Returns the bytes written.
+uint32_t TJSONProtocol::writeMapEnd() {
+  // The two calls must be sequenced explicitly. The operands of '+' are
+  // evaluated in an unspecified order, so writing them as one expression
+  // could emit ']' before '}' and pop the contexts in the wrong order.
+  uint32_t result = writeJSONObjectEnd();
+  result += writeJSONArrayEnd();
+  return result;
+}
+
+// Opens a list as a JSON array whose first two elements are the element type
+// identifier and the element count; the elements follow.
+uint32_t TJSONProtocol::writeListBegin(const TType elemType,
+                                       const uint32_t size) {
+  const int64_t count = static_cast<int64_t>(size);
+  uint32_t written = writeJSONArrayStart();
+  written += writeJSONString(getTypeNameForTypeID(elemType));
+  written += writeJSONInteger(count);
+  return written;
+}
+
+uint32_t TJSONProtocol::writeListEnd() {
+  // Closes the JSON array opened by writeListBegin().
+  const uint32_t written = writeJSONArrayEnd();
+  return written;
+}
+
+// Opens a set; the encoding is identical to a list: a JSON array holding the
+// element type identifier, the element count, then the elements.
+uint32_t TJSONProtocol::writeSetBegin(const TType elemType,
+                                      const uint32_t size) {
+  const int64_t count = static_cast<int64_t>(size);
+  uint32_t written = writeJSONArrayStart();
+  written += writeJSONString(getTypeNameForTypeID(elemType));
+  written += writeJSONInteger(count);
+  return written;
+}
+
+uint32_t TJSONProtocol::writeSetEnd() {
+  // Closes the JSON array opened by writeSetBegin().
+  const uint32_t written = writeJSONArrayEnd();
+  return written;
+}
+
+uint32_t TJSONProtocol::writeBool(const bool value) {
+  // Booleans are written as JSON integers via writeJSONInteger<bool>.
+  const uint32_t written = writeJSONInteger(value);
+  return written;
+}
+
+uint32_t TJSONProtocol::writeByte(const int8_t byte) {
+  // Widen before formatting: boost::lexical_cast treats int8_t as a
+  // character type and would emit the raw character instead of a number.
+  const int16_t widened = static_cast<int16_t>(byte);
+  return writeJSONInteger(widened);
+}
+
+uint32_t TJSONProtocol::writeI16(const int16_t i16) {
+  // Written as a JSON number.
+  const uint32_t written = writeJSONInteger(i16);
+  return written;
+}
+
+uint32_t TJSONProtocol::writeI32(const int32_t i32) {
+  // Written as a JSON number.
+  const uint32_t written = writeJSONInteger(i32);
+  return written;
+}
+
+uint32_t TJSONProtocol::writeI64(const int64_t i64) {
+  // Written as a JSON number.
+  const uint32_t written = writeJSONInteger(i64);
+  return written;
+}
+
+uint32_t TJSONProtocol::writeDouble(const double dub) {
+  // NaN/Infinity handling lives in writeJSONDouble().
+  const uint32_t written = writeJSONDouble(dub);
+  return written;
+}
+
+uint32_t TJSONProtocol::writeString(const std::string& str) {
+  // Written as an escaped JSON string.
+  const uint32_t written = writeJSONString(str);
+  return written;
+}
+
+uint32_t TJSONProtocol::writeBinary(const std::string& str) {
+  // Binary data is base64-encoded into a JSON string.
+  const uint32_t written = writeJSONBase64(str);
+  return written;
+}
+
+ /**
+ * Reading functions
+ */
+
+// Consumes one byte of input and verifies that it equals ch; the check and
+// error reporting are delegated to readSyntaxChar().
+uint32_t TJSONProtocol::readJSONSyntaxChar(uint8_t ch) {
+  const uint32_t consumed = readSyntaxChar(reader_, ch);
+  return consumed;
+}
+
+// Decodes the four hex digits that follow "\u" in a JSON string escape and
+// stores the resulting byte in *out. Only escapes of the form \u00XX are
+// supported: the first two digits must both be '0'. Returns 4, the number of
+// characters consumed.
+uint32_t TJSONProtocol::readJSONEscapeChar(uint8_t *out) {
+  // The leading "00" carries no data; readJSONSyntaxChar verifies it.
+  readJSONSyntaxChar(kJSONZeroChar);
+  readJSONSyntaxChar(kJSONZeroChar);
+  const uint8_t highDigit = reader_.read();
+  const uint8_t lowDigit = reader_.read();
+  *out = (hexVal(highDigit) << 4) + hexVal(lowDigit);
+  return 4;
+}
+
+// Reads a quoted JSON string into str, undoing backslash escapes. When
+// skipContext is true, the enclosing context is not consulted (the caller has
+// already done so, as readJSONDouble() does). Returns the number of bytes
+// consumed. Throws TProtocolException(INVALID_DATA) on an unknown escape.
+uint32_t TJSONProtocol::readJSONString(std::string &str, bool skipContext) {
+  uint32_t result = skipContext ? 0 : context_->read(reader_);
+  result += readJSONSyntaxChar(kJSONStringDelimiter);
+  str.clear();
+  for (;;) {
+    uint8_t ch = reader_.read();
+    ++result;
+    if (ch == kJSONStringDelimiter) {
+      return result;
+    }
+    if (ch != kJSONBackslash) {
+      str += ch;
+      continue;
+    }
+    // Escape sequence: either \u00XX or a single-character escape.
+    ch = reader_.read();
+    ++result;
+    if (ch == kJSONEscapeChar) {
+      result += readJSONEscapeChar(&ch);
+    } else {
+      const size_t pos = kEscapeChars.find(ch);
+      if (pos == std::string::npos) {
+        throw TProtocolException(TProtocolException::INVALID_DATA,
+                                 "Expected control char, got '" +
+                                 std::string((const char *)&ch, 1) + "'.");
+      }
+      ch = kEscapeCharVals[pos];
+    }
+    str += ch;
+  }
+}
+
+// Reads a JSON string holding base64 data and stores the decoded bytes in
+// str. Up to two trailing '=' padding characters are ignored. Returns the
+// number of bytes consumed from the transport.
+uint32_t TJSONProtocol::readJSONBase64(std::string &str) {
+  std::string tmp;
+  uint32_t result = readJSONString(tmp);
+  str.clear();
+  if (tmp.empty()) {
+    return result;
+  }
+  // base64_decode() rewrites its input in place, so it needs writable
+  // storage: writing through c_str() is undefined behavior, whereas &tmp[0]
+  // refers to the string's own mutable buffer.
+  uint8_t *b = reinterpret_cast<uint8_t *>(&tmp[0]);
+  uint32_t len = static_cast<uint32_t>(tmp.length());
+  // Padding is not encoded data; dropping it keeps padded input from being
+  // decoded into trailing garbage bytes.
+  for (int pad = 0; pad < 2 && len > 0 && b[len - 1] == '='; ++pad) {
+    --len;
+  }
+  while (len >= 4) {
+    base64_decode(b, 4);
+    str.append(reinterpret_cast<const char *>(b), 3);
+    b += 4;
+    len -= 4;
+  }
+  // Don't decode if we hit the end or got a single leftover byte (invalid
+  // base64 but legal for skip of regular string type)
+  if (len > 1) {
+    base64_decode(b, len);
+    str.append(reinterpret_cast<const char *>(b), len - 1);
+  }
+  return result;
+}
+
+// Collects consecutive JSON numeric characters into str. The first
+// non-numeric character is only peeked, so it remains available to the next
+// read. Returns the number of characters consumed.
+uint32_t TJSONProtocol::readJSONNumericChars(std::string &str) {
+  str.clear();
+  uint32_t consumed = 0;
+  for (uint8_t ch = reader_.peek(); isJSONNumeric(ch); ch = reader_.peek()) {
+    reader_.read();
+    str += ch;
+    ++consumed;
+  }
+  return consumed;
+}
+
+// Reads a JSON integer and converts it to NumberType with boost::lexical_cast,
+// returning it via num. If the current context escapes numbers, the digits
+// must be surrounded by quotes. Returns the number of bytes consumed.
+// Throws TProtocolException(INVALID_DATA) if the text is not a valid
+// NumberType.
+template <typename NumberType>
+uint32_t TJSONProtocol::readJSONInteger(NumberType &num) {
+  uint32_t result = context_->read(reader_);
+  if (context_->escapeNum()) {
+    result += readJSONSyntaxChar(kJSONStringDelimiter);
+  }
+  std::string str;
+  result += readJSONNumericChars(str);
+  try {
+    num = boost::lexical_cast<NumberType>(str);
+  }
+  catch (const boost::bad_lexical_cast&) {
+    // Throw by value (not with `new`): handlers catch TProtocolException by
+    // reference, and a thrown pointer would bypass them and leak.
+    throw TProtocolException(TProtocolException::INVALID_DATA,
+                             "Expected numeric value; got \"" + str +
+                             "\"");
+  }
+  if (context_->escapeNum()) {
+    result += readJSONSyntaxChar(kJSONStringDelimiter);
+  }
+  return result;
+}
+
+// Reads a double in any form writeJSONDouble() can produce: a bare JSON
+// number, a quoted number (legal only when the current context escapes
+// numbers), or one of the quoted special values kThriftNan, kThriftInfinity
+// and kThriftNegativeInfinity. Returns the number of bytes consumed.
+// Throws TProtocolException(INVALID_DATA) on malformed or misplaced input.
+uint32_t TJSONProtocol::readJSONDouble(double &num) {
+  uint32_t result = context_->read(reader_);
+  std::string str;
+  if (reader_.peek() == kJSONStringDelimiter) {
+    // The context was already consulted above, so tell readJSONString to
+    // skip it.
+    result += readJSONString(str, true);
+    // Check for NaN, Infinity and -Infinity
+    if (str == kThriftNan) {
+      num = HUGE_VAL/HUGE_VAL; // generates NaN
+    }
+    else if (str == kThriftInfinity) {
+      num = HUGE_VAL;
+    }
+    else if (str == kThriftNegativeInfinity) {
+      num = -HUGE_VAL;
+    }
+    else {
+      if (!context_->escapeNum()) {
+        // Throw exception -- we should not be in a string in this case
+        throw TProtocolException(TProtocolException::INVALID_DATA,
+                                 "Numeric data unexpectedly quoted");
+      }
+      try {
+        num = boost::lexical_cast<double>(str);
+      }
+      catch (const boost::bad_lexical_cast&) {
+        // Thrown by value so TProtocolException& handlers see it.
+        throw TProtocolException(TProtocolException::INVALID_DATA,
+                                 "Expected numeric value; got \"" + str +
+                                 "\"");
+      }
+    }
+  }
+  else {
+    if (context_->escapeNum()) {
+      // This will throw - we should have had a quote if escapeNum == true
+      readJSONSyntaxChar(kJSONStringDelimiter);
+    }
+    result += readJSONNumericChars(str);
+    try {
+      num = boost::lexical_cast<double>(str);
+    }
+    catch (const boost::bad_lexical_cast&) {
+      // Thrown by value so TProtocolException& handlers see it.
+      throw TProtocolException(TProtocolException::INVALID_DATA,
+                               "Expected numeric value; got \"" + str +
+                               "\"");
+    }
+  }
+  return result;
+}
+
+uint32_t TJSONProtocol::readJSONObjectStart() {
+  // Let the enclosing context consume whatever precedes this value, expect
+  // '{', then enter a pair context for the object's members.
+  uint32_t consumed = context_->read(reader_);
+  consumed += readJSONSyntaxChar(kJSONObjectStart);
+  pushContext(boost::shared_ptr<TJSONContext>(new JSONPairContext()));
+  return consumed;
+}
+
+uint32_t TJSONProtocol::readJSONObjectEnd() {
+  // Expect '}' first, and only then leave the object's pair context.
+  const uint32_t consumed = readJSONSyntaxChar(kJSONObjectEnd);
+  popContext();
+  return consumed;
+}
+
+uint32_t TJSONProtocol::readJSONArrayStart() {
+  // Let the enclosing context consume whatever precedes this value, expect
+  // '[', then enter a list context for the array's elements.
+  uint32_t consumed = context_->read(reader_);
+  consumed += readJSONSyntaxChar(kJSONArrayStart);
+  pushContext(boost::shared_ptr<TJSONContext>(new JSONListContext()));
+  return consumed;
+}
+
+uint32_t TJSONProtocol::readJSONArrayEnd() {
+  // Expect ']' first, and only then leave the array's list context.
+  const uint32_t consumed = readJSONSyntaxChar(kJSONArrayEnd);
+  popContext();
+  return consumed;
+}
+
+// Reads a message header written by writeMessageBegin(): opens the message
+// array and reads the version, name, type and sequence ID.
+// Throws TProtocolException(BAD_VERSION) if the version is not
+// kThriftVersion1.
+uint32_t TJSONProtocol::readMessageBegin(std::string& name,
+                                         TMessageType& messageType,
+                                         int32_t& seqid) {
+  uint32_t result = readJSONArrayStart();
+
+  uint64_t version = 0;
+  result += readJSONInteger(version);
+  if (version != kThriftVersion1) {
+    throw TProtocolException(TProtocolException::BAD_VERSION,
+                             "Message contained bad version.");
+  }
+
+  result += readJSONString(name);
+
+  uint64_t rawType = 0;
+  result += readJSONInteger(rawType);
+  messageType = (TMessageType)rawType;
+
+  uint64_t rawSeqid = 0;
+  result += readJSONInteger(rawSeqid);
+  seqid = rawSeqid;
+  return result;
+}
+
+uint32_t TJSONProtocol::readMessageEnd() {
+  // Closes the array opened by readMessageBegin().
+  const uint32_t consumed = readJSONArrayEnd();
+  return consumed;
+}
+
+uint32_t TJSONProtocol::readStructBegin(std::string& name) {
+  // The struct name is not on the wire, so name is left untouched.
+  const uint32_t consumed = readJSONObjectStart();
+  return consumed;
+}
+
+uint32_t TJSONProtocol::readStructEnd() {
+  // Consumes the '}' that closes the struct object.
+  const uint32_t consumed = readJSONObjectEnd();
+  return consumed;
+}
+
+// Reads a field header written by writeFieldBegin(). When the next byte is
+// '}', the struct has no more fields: fieldType is set to T_STOP and the
+// brace is left for readStructEnd(). The field name is never set.
+uint32_t TJSONProtocol::readFieldBegin(std::string& name,
+                                       TType& fieldType,
+                                       int16_t& fieldId) {
+  if (reader_.peek() == kJSONObjectEnd) {
+    fieldType = apache::thrift::protocol::T_STOP;
+    return 0;
+  }
+  uint32_t result = 0;
+  uint64_t id = 0;
+  result += readJSONInteger(id);
+  fieldId = id;
+  result += readJSONObjectStart();
+  std::string typeName;
+  result += readJSONString(typeName);
+  fieldType = getTypeIDForTypeName(typeName);
+  return result;
+}
+
+uint32_t TJSONProtocol::readFieldEnd() {
+  // Closes the one-entry object opened by readFieldBegin().
+  const uint32_t consumed = readJSONObjectEnd();
+  return consumed;
+}
+
+// Reads a map header written by writeMapBegin(): opens the outer array, reads
+// the key and value type identifiers and the pair count, then opens the JSON
+// object that holds the pairs.
+uint32_t TJSONProtocol::readMapBegin(TType& keyType,
+                                     TType& valType,
+                                     uint32_t& size) {
+  uint32_t result = readJSONArrayStart();
+
+  std::string keyTypeName;
+  result += readJSONString(keyTypeName);
+  keyType = getTypeIDForTypeName(keyTypeName);
+
+  std::string valTypeName;
+  result += readJSONString(valTypeName);
+  valType = getTypeIDForTypeName(valTypeName);
+
+  uint64_t count = 0;
+  result += readJSONInteger(count);
+  size = count;
+
+  result += readJSONObjectStart();
+  return result;
+}
+
+// Closes a map: first the JSON object holding the pairs ('}'), then the outer
+// array opened by readMapBegin() (']'). Returns the bytes consumed.
+uint32_t TJSONProtocol::readMapEnd() {
+  // Sequence the reads explicitly. The operands of '+' are evaluated in an
+  // unspecified order, and reading the array end first would expect ']'
+  // while the next byte is '}'.
+  uint32_t result = readJSONObjectEnd();
+  result += readJSONArrayEnd();
+  return result;
+}
+
+// Reads a list header written by writeListBegin(): opens the array and reads
+// the element type identifier and the element count.
+uint32_t TJSONProtocol::readListBegin(TType& elemType,
+                                      uint32_t& size) {
+  uint32_t result = readJSONArrayStart();
+  std::string typeName;
+  result += readJSONString(typeName);
+  elemType = getTypeIDForTypeName(typeName);
+  uint64_t count = 0;
+  result += readJSONInteger(count);
+  size = count;
+  return result;
+}
+
+uint32_t TJSONProtocol::readListEnd() {
+  // Closes the array opened by readListBegin().
+  const uint32_t consumed = readJSONArrayEnd();
+  return consumed;
+}
+
+// Reads a set header written by writeSetBegin(); the layout matches a list:
+// element type identifier, then element count.
+uint32_t TJSONProtocol::readSetBegin(TType& elemType,
+                                     uint32_t& size) {
+  uint32_t result = readJSONArrayStart();
+  std::string typeName;
+  result += readJSONString(typeName);
+  elemType = getTypeIDForTypeName(typeName);
+  uint64_t count = 0;
+  result += readJSONInteger(count);
+  size = count;
+  return result;
+}
+
+uint32_t TJSONProtocol::readSetEnd() {
+  // Closes the array opened by readSetBegin().
+  const uint32_t consumed = readJSONArrayEnd();
+  return consumed;
+}
+
+uint32_t TJSONProtocol::readBool(bool& value) {
+  // Booleans travel as JSON integers (see writeBool); readJSONInteger<bool>
+  // performs the conversion.
+  const uint32_t consumed = readJSONInteger(value);
+  return consumed;
+}
+
+// Reads a byte. boost::lexical_cast treats int8_t as a character type rather
+// than an integer, so the value is parsed as an int16_t and then narrowed.
+// Throws TProtocolException(INVALID_DATA) if the number cannot be an 8-bit
+// value.
+uint32_t TJSONProtocol::readByte(int8_t& byte) {
+  int16_t tmp = 0;
+  uint32_t result = readJSONInteger(tmp);
+  // The input is untrusted, so validate it with an exception rather than
+  // assert(), which is compiled out of release builds. As before, 128..255
+  // are accepted and wrap; anything outside [-128, 255] is rejected.
+  if (tmp < -128 || tmp > 255) {
+    throw TProtocolException(TProtocolException::INVALID_DATA,
+                             "Byte value out of range: " +
+                             boost::lexical_cast<std::string>(tmp));
+  }
+  byte = static_cast<int8_t>(tmp);
+  return result;
+}
+
+uint32_t TJSONProtocol::readI16(int16_t& i16) {
+  // Parsed directly from a JSON number.
+  const uint32_t consumed = readJSONInteger(i16);
+  return consumed;
+}
+
+uint32_t TJSONProtocol::readI32(int32_t& i32) {
+  // Parsed directly from a JSON number.
+  const uint32_t consumed = readJSONInteger(i32);
+  return consumed;
+}
+
+uint32_t TJSONProtocol::readI64(int64_t& i64) {
+  // Parsed directly from a JSON number.
+  const uint32_t consumed = readJSONInteger(i64);
+  return consumed;
+}
+
+uint32_t TJSONProtocol::readDouble(double& dub) {
+  // Handles both bare numbers and the quoted NaN/Infinity forms.
+  const uint32_t consumed = readJSONDouble(dub);
+  return consumed;
+}
+
+uint32_t TJSONProtocol::readString(std::string &str) {
+  // Reads and unescapes a JSON string.
+  const uint32_t consumed = readJSONString(str);
+  return consumed;
+}
+
+uint32_t TJSONProtocol::readBinary(std::string &str) {
+  // Binary data arrives as a base64-encoded JSON string.
+  const uint32_t consumed = readJSONBase64(str);
+  return consumed;
+}
+
+}}} // apache::thrift::protocol
diff --git a/lib/cpp/src/protocol/TJSONProtocol.h b/lib/cpp/src/protocol/TJSONProtocol.h
new file mode 100644
index 0000000..2df499a
--- /dev/null
+++ b/lib/cpp/src/protocol/TJSONProtocol.h
@@ -0,0 +1,340 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_PROTOCOL_TJSONPROTOCOL_H_
+#define _THRIFT_PROTOCOL_TJSONPROTOCOL_H_ 1
+
+#include "TProtocol.h"
+
+#include <stack>
+
+namespace apache { namespace thrift { namespace protocol {
+
+// Forward declaration
+class TJSONContext;
+
+/**
+ * JSON protocol for Thrift.
+ *
+ * Implements a protocol which uses JSON as the wire-format.
+ *
+ * Thrift types are represented as described below:
+ *
+ * 1. Every Thrift integer type is represented as a JSON number.
+ *
+ * 2. Thrift doubles are represented as JSON numbers. Some special values are
+ *    represented as strings:
+ *    a. "NaN" for not-a-number values
+ *    b. "Infinity" for positive infinity
+ *    c. "-Infinity" for negative infinity
+ *
+ * 3. Thrift string values are emitted as JSON strings, with appropriate
+ *    escaping.
+ *
+ * 4. Thrift binary values are encoded into Base64 and emitted as JSON strings.
+ *    The readBinary() method is written such that it will properly skip if
+ *    called on a Thrift string (although it will decode garbage data).
+ *
+ * 5. Thrift structs are represented as JSON objects, with the field ID as the
+ *    key, and the field value represented as a JSON object with a single
+ *    key-value pair. The key is a short string identifier for that type,
+ *    followed by the value. The valid type identifiers are: "tf" for bool,
+ *    "i8" for byte, "i16" for 16-bit integer, "i32" for 32-bit integer, "i64"
+ *    for 64-bit integer, "dbl" for double-precision floating point, "str" for
+ *    string (including binary), "rec" for struct ("records"), "map" for map,
+ *    "lst" for list, "set" for set.
+ *
+ * 6. Thrift lists and sets are represented as JSON arrays, with the first
+ *    element of the JSON array being the string identifier for the Thrift
+ *    element type and the second element of the JSON array being the count of
+ *    the Thrift elements. The Thrift elements then follow.
+ *
+ * 7. Thrift maps are represented as JSON arrays, with the first two elements
+ *    of the JSON array being the string identifiers for the Thrift key type
+ *    and value type, followed by the count of the Thrift pairs, followed by a
+ *    JSON object containing the key-value pairs. Note that JSON keys can only
+ *    be strings, which means that the key type of the Thrift map should be
+ *    restricted to numeric or string types -- in the case of numerics, they
+ *    are serialized as strings.
+ *
+ * 8. Thrift messages are represented as JSON arrays, with the protocol
+ *    version #, the message name, the message type, and the sequence ID as
+ *    the first 4 elements.
+ *
+ * More discussion of the double handling is probably warranted. The aim of
+ * the current implementation is to match as closely as possible the behavior
+ * of Java's Double.toString(), which has no precision loss. Implementors in
+ * other languages should strive to achieve that where possible. I have not
+ * yet verified whether boost::lexical_cast, which is doing that work for me in
+ * C++, loses any precision, but I am leaving this as a future improvement. I
+ * may try to provide a C component for this, so that other languages could
+ * bind to the same underlying implementation for maximum consistency.
+ *
+ * Note further that JavaScript itself is not capable of representing
+ * floating point infinities -- presumably when we have a JavaScript Thrift
+ * client, this would mean that infinities get converted to not-a-number in
+ * transmission. I don't know of any work-around for this issue.
+ *
+ */
+class TJSONProtocol : public TProtocol {
+ public:
+
+ TJSONProtocol(boost::shared_ptr<TTransport> ptrans);
+
+ ~TJSONProtocol();
+
+ private:
+
+ // Enter a new JSON context (object or array) / return to the enclosing one.
+ // NOTE(review): presumably these maintain contexts_ and context_ below;
+ // the implementations are not in this header.
+ void pushContext(boost::shared_ptr<TJSONContext> c);
+
+ void popContext();
+
+ // Low-level JSON writers used by the public write* methods.
+ uint32_t writeJSONEscapeChar(uint8_t ch);
+
+ uint32_t writeJSONChar(uint8_t ch);
+
+ uint32_t writeJSONString(const std::string &str);
+
+ uint32_t writeJSONBase64(const std::string &str);
+
+ template <typename NumberType>
+ uint32_t writeJSONInteger(NumberType num);
+
+ uint32_t writeJSONDouble(double num);
+
+ uint32_t writeJSONObjectStart() ;
+
+ uint32_t writeJSONObjectEnd();
+
+ uint32_t writeJSONArrayStart();
+
+ uint32_t writeJSONArrayEnd();
+
+ // Low-level JSON readers used by the public read* methods.
+ uint32_t readJSONSyntaxChar(uint8_t ch);
+
+ uint32_t readJSONEscapeChar(uint8_t *out);
+
+ // skipContext: when true, the current context is not consulted before the
+ // opening quote (the caller has already done so).
+ uint32_t readJSONString(std::string &str, bool skipContext = false);
+
+ uint32_t readJSONBase64(std::string &str);
+
+ uint32_t readJSONNumericChars(std::string &str);
+
+ template <typename NumberType>
+ uint32_t readJSONInteger(NumberType &num);
+
+ uint32_t readJSONDouble(double &num);
+
+ uint32_t readJSONObjectStart();
+
+ uint32_t readJSONObjectEnd();
+
+ uint32_t readJSONArrayStart();
+
+ uint32_t readJSONArrayEnd();
+
+ public:
+
+ /**
+ * Writing functions.
+ */
+
+ uint32_t writeMessageBegin(const std::string& name,
+ const TMessageType messageType,
+ const int32_t seqid);
+
+ uint32_t writeMessageEnd();
+
+ uint32_t writeStructBegin(const char* name);
+
+ uint32_t writeStructEnd();
+
+ uint32_t writeFieldBegin(const char* name,
+ const TType fieldType,
+ const int16_t fieldId);
+
+ uint32_t writeFieldEnd();
+
+ uint32_t writeFieldStop();
+
+ uint32_t writeMapBegin(const TType keyType,
+ const TType valType,
+ const uint32_t size);
+
+ uint32_t writeMapEnd();
+
+ uint32_t writeListBegin(const TType elemType,
+ const uint32_t size);
+
+ uint32_t writeListEnd();
+
+ uint32_t writeSetBegin(const TType elemType,
+ const uint32_t size);
+
+ uint32_t writeSetEnd();
+
+ uint32_t writeBool(const bool value);
+
+ uint32_t writeByte(const int8_t byte);
+
+ uint32_t writeI16(const int16_t i16);
+
+ uint32_t writeI32(const int32_t i32);
+
+ uint32_t writeI64(const int64_t i64);
+
+ uint32_t writeDouble(const double dub);
+
+ uint32_t writeString(const std::string& str);
+
+ uint32_t writeBinary(const std::string& str);
+
+ /**
+ * Reading functions
+ */
+
+ uint32_t readMessageBegin(std::string& name,
+ TMessageType& messageType,
+ int32_t& seqid);
+
+ uint32_t readMessageEnd();
+
+ uint32_t readStructBegin(std::string& name);
+
+ uint32_t readStructEnd();
+
+ uint32_t readFieldBegin(std::string& name,
+ TType& fieldType,
+ int16_t& fieldId);
+
+ uint32_t readFieldEnd();
+
+ uint32_t readMapBegin(TType& keyType,
+ TType& valType,
+ uint32_t& size);
+
+ uint32_t readMapEnd();
+
+ uint32_t readListBegin(TType& elemType,
+ uint32_t& size);
+
+ uint32_t readListEnd();
+
+ uint32_t readSetBegin(TType& elemType,
+ uint32_t& size);
+
+ uint32_t readSetEnd();
+
+ uint32_t readBool(bool& value);
+
+ uint32_t readByte(int8_t& byte);
+
+ uint32_t readI16(int16_t& i16);
+
+ uint32_t readI32(int32_t& i32);
+
+ uint32_t readI64(int64_t& i64);
+
+ uint32_t readDouble(double& dub);
+
+ uint32_t readString(std::string& str);
+
+ uint32_t readBinary(std::string& str);
+
+ /**
+ * One byte of lookahead over a transport. peek() returns the next byte
+ * without consuming it; read() consumes it, returning the peeked byte if
+ * one is buffered. Holds a non-owning pointer to the transport, which must
+ * outlive the reader.
+ */
+ class LookaheadReader {
+
+ public:
+
+ LookaheadReader(TTransport &trans) :
+ trans_(&trans),
+ hasData_(false) {
+ }
+
+ uint8_t read() {
+ if (hasData_) {
+ hasData_ = false;
+ }
+ else {
+ trans_->readAll(&data_, 1);
+ }
+ return data_;
+ }
+
+ uint8_t peek() {
+ if (!hasData_) {
+ trans_->readAll(&data_, 1);
+ }
+ hasData_ = true;
+ return data_;
+ }
+
+ private:
+ TTransport *trans_;
+ // True when data_ holds a byte that peek() read but nobody consumed yet.
+ bool hasData_;
+ uint8_t data_;
+ };
+
+ private:
+
+ // contexts_ stacks the enclosing JSON contexts; context_ is the one that
+ // every read/write of a value consults (context_->read / context_->write).
+ std::stack<boost::shared_ptr<TJSONContext> > contexts_;
+ boost::shared_ptr<TJSONContext> context_;
+ LookaheadReader reader_;
+};
+
+/**
+ * Protocol factory that wraps each transport it is handed in a new
+ * TJSONProtocol.
+ */
+class TJSONProtocolFactory : public TProtocolFactory {
+ public:
+  TJSONProtocolFactory() {}
+
+  virtual ~TJSONProtocolFactory() {}
+
+  boost::shared_ptr<TProtocol> getProtocol(boost::shared_ptr<TTransport> trans) {
+    boost::shared_ptr<TProtocol> proto(new TJSONProtocol(trans));
+    return proto;
+  }
+};
+
+}}} // apache::thrift::protocol
+
+
+// TODO(dreiss): Move part of ThriftJSONString into a .cpp file and remove this.
+#include <transport/TBufferTransports.h>
+
+namespace apache { namespace thrift {
+
+/**
+ * Serializes ts with TJSONProtocol into an in-memory buffer and returns the
+ * resulting JSON text.
+ */
+template<typename ThriftStruct>
+std::string ThriftJSONString(const ThriftStruct& ts) {
+  using namespace apache::thrift::transport;
+  using namespace apache::thrift::protocol;
+  // The shared_ptr owns the buffer; the raw pointer is kept only to reach
+  // TMemoryBuffer::getBuffer() afterwards.
+  TMemoryBuffer* memBuf = new TMemoryBuffer;
+  boost::shared_ptr<TTransport> owner(memBuf);
+  TJSONProtocol jsonProto(owner);
+
+  ts.write(&jsonProto);
+
+  uint8_t* bytes = 0;
+  uint32_t numBytes = 0;
+  memBuf->getBuffer(&bytes, &numBytes);
+  return std::string(reinterpret_cast<char*>(bytes), numBytes);
+}
+
+}} // apache::thrift
+
+#endif // #ifndef _THRIFT_PROTOCOL_TJSONPROTOCOL_H_
diff --git a/lib/cpp/src/protocol/TOneWayProtocol.h b/lib/cpp/src/protocol/TOneWayProtocol.h
new file mode 100644
index 0000000..6f08fe1
--- /dev/null
+++ b/lib/cpp/src/protocol/TOneWayProtocol.h
@@ -0,0 +1,304 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_PROTOCOL_TONEWAYPROTOCOL_H_
+#define _THRIFT_PROTOCOL_TONEWAYPROTOCOL_H_ 1
+
+#include "TProtocol.h"
+
+namespace apache { namespace thrift { namespace protocol {
+
+/**
+ * Base class for a protocol that can be written but not read.
+ *
+ * Subclasses supply every write* method. Each read* method here throws
+ * TProtocolException(NOT_IMPLEMENTED) with a message naming the subclass.
+ */
+class TWriteOnlyProtocol : public TProtocol {
+ public:
+  /**
+   * @param trans          Transport passed through to TProtocol.
+   * @param subclass_name  Name of the concrete subclass; it appears in the
+   *                       message of every exception thrown by read* methods.
+   */
+  TWriteOnlyProtocol(boost::shared_ptr<TTransport> trans,
+                     const std::string& subclass_name)
+    : TProtocol(trans)
+    , subclass_(subclass_name)
+  {}
+
+  // The write* methods stay abstract for subclasses to implement.
+
+  // Every read* method below is unsupported and throws.
+
+  uint32_t readMessageBegin(std::string& name,
+                            TMessageType& messageType,
+                            int32_t& seqid) {
+    return readingUnsupported();
+  }
+
+  uint32_t readMessageEnd() {
+    return readingUnsupported();
+  }
+
+  uint32_t readStructBegin(std::string& name) {
+    return readingUnsupported();
+  }
+
+  uint32_t readStructEnd() {
+    return readingUnsupported();
+  }
+
+  uint32_t readFieldBegin(std::string& name,
+                          TType& fieldType,
+                          int16_t& fieldId) {
+    return readingUnsupported();
+  }
+
+  uint32_t readFieldEnd() {
+    return readingUnsupported();
+  }
+
+  uint32_t readMapBegin(TType& keyType,
+                        TType& valType,
+                        uint32_t& size) {
+    return readingUnsupported();
+  }
+
+  uint32_t readMapEnd() {
+    return readingUnsupported();
+  }
+
+  uint32_t readListBegin(TType& elemType,
+                         uint32_t& size) {
+    return readingUnsupported();
+  }
+
+  uint32_t readListEnd() {
+    return readingUnsupported();
+  }
+
+  uint32_t readSetBegin(TType& elemType,
+                        uint32_t& size) {
+    return readingUnsupported();
+  }
+
+  uint32_t readSetEnd() {
+    return readingUnsupported();
+  }
+
+  uint32_t readBool(bool& value) {
+    return readingUnsupported();
+  }
+
+  uint32_t readByte(int8_t& byte) {
+    return readingUnsupported();
+  }
+
+  uint32_t readI16(int16_t& i16) {
+    return readingUnsupported();
+  }
+
+  uint32_t readI32(int32_t& i32) {
+    return readingUnsupported();
+  }
+
+  uint32_t readI64(int64_t& i64) {
+    return readingUnsupported();
+  }
+
+  uint32_t readDouble(double& dub) {
+    return readingUnsupported();
+  }
+
+  uint32_t readString(std::string& str) {
+    return readingUnsupported();
+  }
+
+  uint32_t readBinary(std::string& str) {
+    return readingUnsupported();
+  }
+
+ private:
+  // Throws the NOT_IMPLEMENTED exception shared by all read* methods. It
+  // never returns; the uint32_t return type only lets each caller write
+  // `return readingUnsupported();`.
+  uint32_t readingUnsupported() const {
+    throw TProtocolException(TProtocolException::NOT_IMPLEMENTED,
+                             subclass_ + " does not support reading (yet).");
+  }
+
+  std::string subclass_;
+};
+
+
+/**
+ * Base class for a protocol that can be read but not written.
+ *
+ * Subclasses supply every read* method. Each write* method here throws
+ * TProtocolException(NOT_IMPLEMENTED) with a message naming the subclass.
+ */
+class TReadOnlyProtocol : public TProtocol {
+ public:
+  /**
+   * @param trans          Transport passed through to TProtocol.
+   * @param subclass_name  Name of the concrete subclass; it appears in the
+   *                       message of every exception thrown by write* methods.
+   */
+  TReadOnlyProtocol(boost::shared_ptr<TTransport> trans,
+                    const std::string& subclass_name)
+    : TProtocol(trans)
+    , subclass_(subclass_name)
+  {}
+
+  // The read* methods stay abstract for subclasses to implement.
+
+  // Every write* method below is unsupported and throws.
+
+  uint32_t writeMessageBegin(const std::string& name,
+                             const TMessageType messageType,
+                             const int32_t seqid) {
+    return writingUnsupported();
+  }
+
+  uint32_t writeMessageEnd() {
+    return writingUnsupported();
+  }
+
+  uint32_t writeStructBegin(const char* name) {
+    return writingUnsupported();
+  }
+
+  uint32_t writeStructEnd() {
+    return writingUnsupported();
+  }
+
+  uint32_t writeFieldBegin(const char* name,
+                           const TType fieldType,
+                           const int16_t fieldId) {
+    return writingUnsupported();
+  }
+
+  uint32_t writeFieldEnd() {
+    return writingUnsupported();
+  }
+
+  uint32_t writeFieldStop() {
+    return writingUnsupported();
+  }
+
+  uint32_t writeMapBegin(const TType keyType,
+                         const TType valType,
+                         const uint32_t size) {
+    return writingUnsupported();
+  }
+
+  uint32_t writeMapEnd() {
+    return writingUnsupported();
+  }
+
+  uint32_t writeListBegin(const TType elemType,
+                          const uint32_t size) {
+    return writingUnsupported();
+  }
+
+  uint32_t writeListEnd() {
+    return writingUnsupported();
+  }
+
+  uint32_t writeSetBegin(const TType elemType,
+                         const uint32_t size) {
+    return writingUnsupported();
+  }
+
+  uint32_t writeSetEnd() {
+    return writingUnsupported();
+  }
+
+  uint32_t writeBool(const bool value) {
+    return writingUnsupported();
+  }
+
+  uint32_t writeByte(const int8_t byte) {
+    return writingUnsupported();
+  }
+
+  uint32_t writeI16(const int16_t i16) {
+    return writingUnsupported();
+  }
+
+  uint32_t writeI32(const int32_t i32) {
+    return writingUnsupported();
+  }
+
+  uint32_t writeI64(const int64_t i64) {
+    return writingUnsupported();
+  }
+
+  uint32_t writeDouble(const double dub) {
+    return writingUnsupported();
+  }
+
+  uint32_t writeString(const std::string& str) {
+    return writingUnsupported();
+  }
+
+  uint32_t writeBinary(const std::string& str) {
+    return writingUnsupported();
+  }
+
+ private:
+  // Throws the NOT_IMPLEMENTED exception shared by all write* methods. It
+  // never returns; the uint32_t return type only lets each caller write
+  // `return writingUnsupported();`.
+  uint32_t writingUnsupported() const {
+    throw TProtocolException(TProtocolException::NOT_IMPLEMENTED,
+                             subclass_ + " does not support writing (yet).");
+  }
+
+  std::string subclass_;
+};
+
+}}} // apache::thrift::protocol
+
+#endif // #ifndef _THRIFT_PROTOCOL_TONEWAYPROTOCOL_H_
diff --git a/lib/cpp/src/protocol/TProtocol.h b/lib/cpp/src/protocol/TProtocol.h
new file mode 100644
index 0000000..4025827
--- /dev/null
+++ b/lib/cpp/src/protocol/TProtocol.h
@@ -0,0 +1,438 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_PROTOCOL_TPROTOCOL_H_
+#define _THRIFT_PROTOCOL_TPROTOCOL_H_ 1
+
+#include <transport/TTransport.h>
+#include <protocol/TProtocolException.h>
+
+#include <boost/shared_ptr.hpp>
+#include <boost/static_assert.hpp>
+
+#include <netinet/in.h>
+#include <sys/types.h>
+
+// Byte-order headers must be included at global scope, never inside
+// namespace apache::thrift::protocol.
+#ifdef HAVE_ENDIAN_H
+#include <endian.h>
+#endif
+#if defined(__GNUC__) && defined(__GLIBC__)
+#include <byteswap.h>
+#endif
+
+#include <cstring>
+#include <map>
+#include <string>
+#include <vector>
+
+
+// Reinterprets the bits of `from` as a value of type To without breaking the
+// strict aliasing rules.
+// For example, uint64_t i = bitwise_cast<uint64_t>(returns_double());
+// From and To must have the same size (checked at compile time), and To must
+// be default-constructible.
+//
+// Pointer casts (reinterpret_cast, static_cast through void*, C casts) break
+// under -O2, and reading the inactive member of a union is undefined
+// behavior in C++. Copying the object representation with memcpy is
+// well-defined, and GCC emits the same code for it as for the union trick.
+// For a pretty in-depth explanation of the problem, see
+// http://www.cellperformance.com/mike_acton/2006/06/ (...)
+// understanding_strict_aliasing.html
+template <typename To, typename From>
+static inline To bitwise_cast(From from) {
+  BOOST_STATIC_ASSERT(sizeof(From) == sizeof(To));
+
+  To to;
+  std::memcpy(&to, &from, sizeof(to));
+  return to;
+}
+
+
+namespace apache { namespace thrift { namespace protocol {
+
+using apache::thrift::transport::TTransport;
+
+// 64-bit byte-order helpers:
+//   ntohll / htonll   : network (big-endian) order <-> host order
+//   letohll / htolell : little-endian order <-> host order
+// <endian.h> and <byteswap.h> are included at global scope in the header's
+// include block. Including them here would declare their contents inside
+// apache::thrift::protocol.
+
+#ifndef __BYTE_ORDER
+# if defined(BYTE_ORDER) && defined(LITTLE_ENDIAN) && defined(BIG_ENDIAN)
+#  define __BYTE_ORDER BYTE_ORDER
+#  define __LITTLE_ENDIAN LITTLE_ENDIAN
+#  define __BIG_ENDIAN BIG_ENDIAN
+# else
+#  error "Cannot determine endianness"
+# endif
+#endif
+
+#if __BYTE_ORDER == __BIG_ENDIAN
+# define ntohll(n) (n)
+# define htonll(n) (n)
+# if defined(__GNUC__) && defined(__GLIBC__)
+#  define htolell(n) bswap_64(n)
+#  define letohll(n) bswap_64(n)
+# else /* GNUC & GLIBC */
+#  define bswap_64(n) \
+      ( (((n) & 0xff00000000000000ull) >> 56) \
+      | (((n) & 0x00ff000000000000ull) >> 40) \
+      | (((n) & 0x0000ff0000000000ull) >> 24) \
+      | (((n) & 0x000000ff00000000ull) >> 8)  \
+      | (((n) & 0x00000000ff000000ull) << 8)  \
+      | (((n) & 0x0000000000ff0000ull) << 24) \
+      | (((n) & 0x000000000000ff00ull) << 40) \
+      | (((n) & 0x00000000000000ffull) << 56) )
+   /* These names previously had the typos ntolell/letonll, leaving
+    * htolell/letohll undefined on big-endian non-glibc platforms. */
+#  define htolell(n) bswap_64(n)
+#  define letohll(n) bswap_64(n)
+# endif /* GNUC & GLIBC */
+#elif __BYTE_ORDER == __LITTLE_ENDIAN
+# define htolell(n) (n)
+# define letohll(n) (n)
+# if defined(__GNUC__) && defined(__GLIBC__)
+#  define ntohll(n) bswap_64(n)
+#  define htonll(n) bswap_64(n)
+# else /* GNUC & GLIBC */
+   /* Swap the two 32-bit halves, byte-swapping each one. `n` is
+    * parenthesized so expression arguments shift as a whole. */
+#  define ntohll(n) ( (((unsigned long long)ntohl(n)) << 32) + ntohl((n) >> 32) )
+#  define htonll(n) ( (((unsigned long long)htonl(n)) << 32) + htonl((n) >> 32) )
+# endif /* GNUC & GLIBC */
+#else /* __BYTE_ORDER */
+# error "Can't define htonll or ntohll!"
+#endif
+
+/**
+ * Wire type codes understood by Thrift protocols, listed in value order.
+ * T_STOP terminates a sequence of fields inside a struct. T_I08 is an alias
+ * of T_BYTE, and T_UTF7 is an alias of T_STRING.
+ */
+enum TType {
+  T_STOP   = 0,
+  T_VOID   = 1,
+  T_BOOL   = 2,
+  T_BYTE   = 3,
+  T_I08    = 3,
+  T_DOUBLE = 4,
+  T_I16    = 6,
+  T_I32    = 8,
+  T_U64    = 9,
+  T_I64    = 10,
+  T_STRING = 11,
+  T_UTF7   = 11,
+  T_STRUCT = 12,
+  T_MAP    = 13,
+  T_SET    = 14,
+  T_LIST   = 15,
+  T_UTF8   = 16,
+  T_UTF16  = 17
+};
+
+/**
+ * Message kinds passed to writeMessageBegin() and reported by
+ * readMessageBegin().
+ */
+enum TMessageType {
+  T_CALL      = 1,
+  T_REPLY     = 2,
+  T_EXCEPTION = 3,
+  T_ONEWAY    = 4
+};
+
+/**
+ * Abstract base class for Thrift protocol drivers.
+ *
+ * Subclasses implement the pure-virtual read/write primitives below for
+ * messages, structs, fields, containers (map/list/set) and base types. The
+ * generic skip() is implemented on top of the read primitives.
+ *
+ * A protocol object may need to keep internal state for the stream it is
+ * encoding (e.g. an XML protocol), so a single instance should not be shared
+ * across multiple encoding contexts. Implementations may do their own
+ * buffered reads/writes on the underlying TTransport where appropriate,
+ * e.g. reading an XML stream in batches rather than looking ahead character
+ * by character for a close tag.
+ */
+class TProtocol {
+ public:
+  virtual ~TProtocol() {}
+
+  // ----------------------------------------------------------------------
+  // Writing
+  // ----------------------------------------------------------------------
+
+  virtual uint32_t writeMessageBegin(const std::string& name,
+                                     const TMessageType messageType,
+                                     const int32_t seqid) = 0;
+  virtual uint32_t writeMessageEnd() = 0;
+
+  virtual uint32_t writeStructBegin(const char* name) = 0;
+  virtual uint32_t writeStructEnd() = 0;
+
+  virtual uint32_t writeFieldBegin(const char* name,
+                                   const TType fieldType,
+                                   const int16_t fieldId) = 0;
+  virtual uint32_t writeFieldEnd() = 0;
+  virtual uint32_t writeFieldStop() = 0;
+
+  virtual uint32_t writeMapBegin(const TType keyType,
+                                 const TType valType,
+                                 const uint32_t size) = 0;
+  virtual uint32_t writeMapEnd() = 0;
+
+  virtual uint32_t writeListBegin(const TType elemType,
+                                  const uint32_t size) = 0;
+  virtual uint32_t writeListEnd() = 0;
+
+  virtual uint32_t writeSetBegin(const TType elemType,
+                                 const uint32_t size) = 0;
+  virtual uint32_t writeSetEnd() = 0;
+
+  virtual uint32_t writeBool(const bool value) = 0;
+  virtual uint32_t writeByte(const int8_t byte) = 0;
+  virtual uint32_t writeI16(const int16_t i16) = 0;
+  virtual uint32_t writeI32(const int32_t i32) = 0;
+  virtual uint32_t writeI64(const int64_t i64) = 0;
+  virtual uint32_t writeDouble(const double dub) = 0;
+  virtual uint32_t writeString(const std::string& str) = 0;
+  virtual uint32_t writeBinary(const std::string& str) = 0;
+
+  // ----------------------------------------------------------------------
+  // Reading
+  // ----------------------------------------------------------------------
+
+  virtual uint32_t readMessageBegin(std::string& name,
+                                    TMessageType& messageType,
+                                    int32_t& seqid) = 0;
+  virtual uint32_t readMessageEnd() = 0;
+
+  virtual uint32_t readStructBegin(std::string& name) = 0;
+  virtual uint32_t readStructEnd() = 0;
+
+  virtual uint32_t readFieldBegin(std::string& name,
+                                  TType& fieldType,
+                                  int16_t& fieldId) = 0;
+  virtual uint32_t readFieldEnd() = 0;
+
+  virtual uint32_t readMapBegin(TType& keyType,
+                                TType& valType,
+                                uint32_t& size) = 0;
+  virtual uint32_t readMapEnd() = 0;
+
+  virtual uint32_t readListBegin(TType& elemType,
+                                 uint32_t& size) = 0;
+  virtual uint32_t readListEnd() = 0;
+
+  virtual uint32_t readSetBegin(TType& elemType,
+                                uint32_t& size) = 0;
+  virtual uint32_t readSetEnd() = 0;
+
+  virtual uint32_t readBool(bool& value) = 0;
+  virtual uint32_t readByte(int8_t& byte) = 0;
+  virtual uint32_t readI16(int16_t& i16) = 0;
+  virtual uint32_t readI32(int32_t& i32) = 0;
+  virtual uint32_t readI64(int64_t& i64) = 0;
+  virtual uint32_t readDouble(double& dub) = 0;
+  virtual uint32_t readString(std::string& str) = 0;
+  virtual uint32_t readBinary(std::string& str) = 0;
+
+  /**
+   * Reads a bool directly into a std::vector<bool> element. The element is
+   * a proxy object rather than a bool&, so it cannot bind to readBool(bool&).
+   * Delegates to readBool(bool&) and returns its result.
+   */
+  uint32_t readBool(std::vector<bool>::reference ref) {
+    bool tmp;
+    const uint32_t consumed = readBool(tmp);
+    ref = tmp;
+    return consumed;
+  }
+
+  /**
+   * Reads and discards one value of the given type, recursing into structs
+   * and containers. Returns the sum of the values reported by the read calls
+   * it makes. Types without a case below are not read at all and yield 0.
+   */
+  uint32_t skip(TType type) {
+    switch (type) {
+      case T_BOOL:   { bool v;        return readBool(v); }
+      case T_BYTE:   { int8_t v;      return readByte(v); }
+      case T_I16:    { int16_t v;     return readI16(v); }
+      case T_I32:    { int32_t v;     return readI32(v); }
+      case T_I64:    { int64_t v;     return readI64(v); }
+      case T_DOUBLE: { double v;      return readDouble(v); }
+      case T_STRING: { std::string v; return readBinary(v); }
+
+      case T_STRUCT: {
+        // Struct and field names are read into the same scratch string and
+        // then discarded.
+        std::string scratchName;
+        TType fieldType;
+        int16_t fieldId;
+        uint32_t total = readStructBegin(scratchName);
+        for (;;) {
+          total += readFieldBegin(scratchName, fieldType, fieldId);
+          if (fieldType == T_STOP) {
+            break;  // no readFieldEnd() after the stop marker
+          }
+          total += skip(fieldType);
+          total += readFieldEnd();
+        }
+        total += readStructEnd();
+        return total;
+      }
+
+      case T_MAP: {
+        TType keyType;
+        TType valType;
+        uint32_t count;
+        uint32_t total = readMapBegin(keyType, valType, count);
+        for (uint32_t idx = 0; idx < count; ++idx) {
+          total += skip(keyType);
+          total += skip(valType);
+        }
+        total += readMapEnd();
+        return total;
+      }
+
+      case T_SET: {
+        TType elemType;
+        uint32_t count;
+        uint32_t total = readSetBegin(elemType, count);
+        for (uint32_t idx = 0; idx < count; ++idx) {
+          total += skip(elemType);
+        }
+        total += readSetEnd();
+        return total;
+      }
+
+      case T_LIST: {
+        TType elemType;
+        uint32_t count;
+        uint32_t total = readListBegin(elemType, count);
+        for (uint32_t idx = 0; idx < count; ++idx) {
+          total += skip(elemType);
+        }
+        total += readListEnd();
+        return total;
+      }
+
+      default:
+        return 0;
+    }
+  }
+
+  /** The transport this protocol was constructed with. */
+  boost::shared_ptr<TTransport> getTransport() {
+    return ptrans_;
+  }
+
+  // TODO: remove these two calls, they are for backwards
+  // compatibility. Both return the same transport as getTransport().
+  boost::shared_ptr<TTransport> getInputTransport() {
+    return ptrans_;
+  }
+  boost::shared_ptr<TTransport> getOutputTransport() {
+    return ptrans_;
+  }
+
+ protected:
+  // Keeps shared ownership of the transport plus a raw pointer to it.
+  TProtocol(boost::shared_ptr<TTransport> ptrans)
+    : ptrans_(ptrans),
+      trans_(ptrans.get()) {}
+
+  boost::shared_ptr<TTransport> ptrans_;
+  TTransport* trans_;
+
+ private:
+  TProtocol() {}
+};
+
+/**
+ * Abstract factory producing a TProtocol layered over a supplied transport.
+ */
+class TProtocolFactory {
+ public:
+  TProtocolFactory() {}
+  virtual ~TProtocolFactory() {}
+
+  /** Returns a protocol object that encodes/decodes over `trans`. */
+  virtual boost::shared_ptr<TProtocol> getProtocol(
+      boost::shared_ptr<TTransport> trans) = 0;
+};
+
+}}} // apache::thrift::protocol
+
+#endif // #ifndef _THRIFT_PROTOCOL_TPROTOCOL_H_
diff --git a/lib/cpp/src/protocol/TProtocolException.h b/lib/cpp/src/protocol/TProtocolException.h
new file mode 100644
index 0000000..33011b3
--- /dev/null
+++ b/lib/cpp/src/protocol/TProtocolException.h
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_PROTOCOL_TPROTOCOLEXCEPTION_H_
+#define _THRIFT_PROTOCOL_TPROTOCOLEXCEPTION_H_ 1
+
+#include <string>
+
+namespace apache { namespace thrift { namespace protocol {
+
+/**
+ * Exception type for protocol-level errors. It pairs an optional message
+ * (stored by the TException base) with a TProtocolExceptionType error code.
+ * This gives code that handles errors from many different protocols a
+ * single, common set of error codes to inspect.
+ */
+class TProtocolException : public apache::thrift::TException {
+ public:
+
+  /**
+   * Error codes for the various types of exceptions.
+   */
+  enum TProtocolExceptionType
+  { UNKNOWN = 0
+  , INVALID_DATA = 1
+  , NEGATIVE_SIZE = 2
+  , SIZE_LIMIT = 3
+  , BAD_VERSION = 4
+  , NOT_IMPLEMENTED = 5
+  };
+
+  TProtocolException() :
+    apache::thrift::TException(),
+    type_(UNKNOWN) {}
+
+  TProtocolException(TProtocolExceptionType type) :
+    apache::thrift::TException(),
+    type_(type) {}
+
+  TProtocolException(const std::string& message) :
+    apache::thrift::TException(message),
+    type_(UNKNOWN) {}
+
+  TProtocolException(TProtocolExceptionType type, const std::string& message) :
+    apache::thrift::TException(message),
+    type_(type) {}
+
+  virtual ~TProtocolException() throw() {}
+
+  /**
+   * Returns the error code describing what kind of error occurred.
+   *
+   * Declared const so it can be called on an exception caught by const
+   * reference.
+   *
+   * @return Error code
+   */
+  TProtocolExceptionType getType() const {
+    return type_;
+  }
+
+  /**
+   * Returns the message if one was supplied, otherwise a fixed description
+   * of the error code.
+   */
+  virtual const char* what() const throw() {
+    if (message_.empty()) {
+      switch (type_) {
+        case UNKNOWN         : return "TProtocolException: Unknown protocol exception";
+        case INVALID_DATA    : return "TProtocolException: Invalid data";
+        case NEGATIVE_SIZE   : return "TProtocolException: Negative size";
+        case SIZE_LIMIT      : return "TProtocolException: Exceeded size limit";
+        case BAD_VERSION     : return "TProtocolException: Invalid version";
+        case NOT_IMPLEMENTED : return "TProtocolException: Not implemented";
+        default              : return "TProtocolException: (Invalid exception type)";
+      }
+    } else {
+      return message_.c_str();
+    }
+  }
+
+ protected:
+  /**
+   * Error code
+   */
+  TProtocolExceptionType type_;
+
+};
+
+}}} // apache::thrift::protocol
+
+#endif // #ifndef _THRIFT_PROTOCOL_TPROTOCOLEXCEPTION_H_
diff --git a/lib/cpp/src/protocol/TProtocolTap.h b/lib/cpp/src/protocol/TProtocolTap.h
new file mode 100644
index 0000000..5580216
--- /dev/null
+++ b/lib/cpp/src/protocol/TProtocolTap.h
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_PROTOCOL_TPROTOCOLTAP_H_
+#define _THRIFT_PROTOCOL_TPROTOCOLTAP_H_ 1
+
+#include <protocol/TOneWayProtocol.h>
+
+namespace apache { namespace thrift { namespace protocol {
+
+using apache::thrift::transport::TTransport;
+
+/**
+ * Wiretap protocol. Each read is forwarded to the wrapped `source` protocol,
+ * and the value obtained is then replayed as the matching write on `sink`.
+ * Every read returns whatever the source reported; the sink's return values
+ * are ignored. The write side is inherited from TReadOnlyProtocol.
+ */
+class TProtocolTap : public TReadOnlyProtocol {
+ public:
+  TProtocolTap(boost::shared_ptr<TProtocol> source,
+               boost::shared_ptr<TProtocol> sink)
+    : TReadOnlyProtocol(source->getTransport(), "TProtocolTap"),
+      source_(source),
+      sink_(sink) {}
+
+  virtual uint32_t readMessageBegin(std::string& name,
+                                    TMessageType& messageType,
+                                    int32_t& seqid) {
+    const uint32_t consumed = source_->readMessageBegin(name, messageType, seqid);
+    sink_->writeMessageBegin(name, messageType, seqid);
+    return consumed;
+  }
+
+  virtual uint32_t readMessageEnd() {
+    const uint32_t consumed = source_->readMessageEnd();
+    sink_->writeMessageEnd();
+    return consumed;
+  }
+
+  virtual uint32_t readStructBegin(std::string& name) {
+    const uint32_t consumed = source_->readStructBegin(name);
+    sink_->writeStructBegin(name.c_str());
+    return consumed;
+  }
+
+  virtual uint32_t readStructEnd() {
+    const uint32_t consumed = source_->readStructEnd();
+    sink_->writeStructEnd();
+    return consumed;
+  }
+
+  // A T_STOP field header is mirrored as writeFieldStop(); any other field
+  // header is mirrored as writeFieldBegin().
+  virtual uint32_t readFieldBegin(std::string& name,
+                                  TType& fieldType,
+                                  int16_t& fieldId) {
+    const uint32_t consumed = source_->readFieldBegin(name, fieldType, fieldId);
+    if (fieldType != T_STOP) {
+      sink_->writeFieldBegin(name.c_str(), fieldType, fieldId);
+    } else {
+      sink_->writeFieldStop();
+    }
+    return consumed;
+  }
+
+  virtual uint32_t readFieldEnd() {
+    const uint32_t consumed = source_->readFieldEnd();
+    sink_->writeFieldEnd();
+    return consumed;
+  }
+
+  virtual uint32_t readMapBegin(TType& keyType,
+                                TType& valType,
+                                uint32_t& size) {
+    const uint32_t consumed = source_->readMapBegin(keyType, valType, size);
+    sink_->writeMapBegin(keyType, valType, size);
+    return consumed;
+  }
+
+  virtual uint32_t readMapEnd() {
+    const uint32_t consumed = source_->readMapEnd();
+    sink_->writeMapEnd();
+    return consumed;
+  }
+
+  virtual uint32_t readListBegin(TType& elemType,
+                                 uint32_t& size) {
+    const uint32_t consumed = source_->readListBegin(elemType, size);
+    sink_->writeListBegin(elemType, size);
+    return consumed;
+  }
+
+  virtual uint32_t readListEnd() {
+    const uint32_t consumed = source_->readListEnd();
+    sink_->writeListEnd();
+    return consumed;
+  }
+
+  virtual uint32_t readSetBegin(TType& elemType,
+                                uint32_t& size) {
+    const uint32_t consumed = source_->readSetBegin(elemType, size);
+    sink_->writeSetBegin(elemType, size);
+    return consumed;
+  }
+
+  virtual uint32_t readSetEnd() {
+    const uint32_t consumed = source_->readSetEnd();
+    sink_->writeSetEnd();
+    return consumed;
+  }
+
+  virtual uint32_t readBool(bool& value) {
+    const uint32_t consumed = source_->readBool(value);
+    sink_->writeBool(value);
+    return consumed;
+  }
+
+  virtual uint32_t readByte(int8_t& byte) {
+    const uint32_t consumed = source_->readByte(byte);
+    sink_->writeByte(byte);
+    return consumed;
+  }
+
+  virtual uint32_t readI16(int16_t& i16) {
+    const uint32_t consumed = source_->readI16(i16);
+    sink_->writeI16(i16);
+    return consumed;
+  }
+
+  virtual uint32_t readI32(int32_t& i32) {
+    const uint32_t consumed = source_->readI32(i32);
+    sink_->writeI32(i32);
+    return consumed;
+  }
+
+  virtual uint32_t readI64(int64_t& i64) {
+    const uint32_t consumed = source_->readI64(i64);
+    sink_->writeI64(i64);
+    return consumed;
+  }
+
+  virtual uint32_t readDouble(double& dub) {
+    const uint32_t consumed = source_->readDouble(dub);
+    sink_->writeDouble(dub);
+    return consumed;
+  }
+
+  virtual uint32_t readString(std::string& str) {
+    const uint32_t consumed = source_->readString(str);
+    sink_->writeString(str);
+    return consumed;
+  }
+
+  virtual uint32_t readBinary(std::string& str) {
+    const uint32_t consumed = source_->readBinary(str);
+    sink_->writeBinary(str);
+    return consumed;
+  }
+
+ private:
+  boost::shared_ptr<TProtocol> source_;  // protocol actually read from
+  boost::shared_ptr<TProtocol> sink_;    // protocol every read is mirrored to
+};
+
+}}} // apache::thrift::protocol
+
+#endif // #ifndef _THRIFT_PROTOCOL_TPROTOCOLTAP_H_
diff --git a/lib/cpp/src/server/TNonblockingServer.cpp b/lib/cpp/src/server/TNonblockingServer.cpp
new file mode 100644
index 0000000..45f635c
--- /dev/null
+++ b/lib/cpp/src/server/TNonblockingServer.cpp
@@ -0,0 +1,750 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "TNonblockingServer.h"
+#include <concurrency/Exception.h>
+
+#include <iostream>
+#include <sys/socket.h>
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+#include <netdb.h>
+#include <fcntl.h>
+#include <errno.h>
+#include <assert.h>
+
+namespace apache { namespace thrift { namespace server {
+
+using namespace apache::thrift::protocol;
+using namespace apache::thrift::transport;
+using namespace apache::thrift::concurrency;
+using namespace std;
+
+class TConnection::Task: public Runnable {
+ public:
+ Task(boost::shared_ptr<TProcessor> processor,
+ boost::shared_ptr<TProtocol> input,
+ boost::shared_ptr<TProtocol> output,
+ int taskHandle) :
+ processor_(processor),
+ input_(input),
+ output_(output),
+ taskHandle_(taskHandle) {}
+
+ void run() {
+ try {
+ while (processor_->process(input_, output_)) {
+ if (!input_->getTransport()->peek()) {
+ break;
+ }
+ }
+ } catch (TTransportException& ttx) {
+ cerr << "TNonblockingServer client died: " << ttx.what() << endl;
+ } catch (TException& x) {
+ cerr << "TNonblockingServer exception: " << x.what() << endl;
+ } catch (...) {
+ cerr << "TNonblockingServer uncaught exception." << endl;
+ }
+
+ // Signal completion back to the libevent thread via a socketpair
+ int8_t b = 0;
+ if (-1 == send(taskHandle_, &b, sizeof(int8_t), 0)) {
+ GlobalOutput.perror("TNonblockingServer::Task: send ", errno);
+ }
+ if (-1 == ::close(taskHandle_)) {
+ GlobalOutput.perror("TNonblockingServer::Task: close, possible resource leak ", errno);
+ }
+ }
+
+ private:
+ boost::shared_ptr<TProcessor> processor_;
+ boost::shared_ptr<TProtocol> input_;
+ boost::shared_ptr<TProtocol> output_;
+ int taskHandle_;
+};
+
+/**
+ * (Re)initializes this connection for `socket`, owned by server `s`.
+ *
+ * Resets the buffer positions and state machine, registers the socket event
+ * with `eventFlags`, and builds the transports and protocols from the
+ * server's factories.
+ */
+void TConnection::init(int socket, short eventFlags, TNonblockingServer* s) {
+  socket_ = socket;
+  server_ = s;
+
+  // Cleared first so setFlags() below does not event_del() a stale event.
+  eventFlags_ = 0;
+
+  // Nothing buffered or requested on the read side yet.
+  readBufferPos_ = 0;
+  readWant_ = 0;
+
+  // No pending output.
+  writeBuffer_ = NULL;
+  writeBufferSize_ = 0;
+  writeBufferPos_ = 0;
+
+  socketState_ = SOCKET_RECV;
+  appState_ = APP_INIT;
+
+  // No worker task outstanding.
+  taskHandle_ = -1;
+
+  // Set flags, which also registers the event.
+  setFlags(eventFlags);
+
+  // Wrap the in-memory input/output transports with the server's factories,
+  // then put the protocols on top of the wrapped transports.
+  factoryInputTransport_ = s->getInputTransportFactory()->getTransport(inputTransport_);
+  factoryOutputTransport_ = s->getOutputTransportFactory()->getTransport(outputTransport_);
+
+  inputProtocol_ = s->getInputProtocolFactory()->getProtocol(factoryInputTransport_);
+  outputProtocol_ = s->getOutputProtocolFactory()->getProtocol(factoryOutputTransport_);
+}
+
+/**
+ * Performs one non-blocking socket operation for the current socketState_.
+ *
+ * SOCKET_RECV: grows the read buffer if needed and recv()s up to readWant_
+ * bytes, calling transition() once readWant_ bytes have been read.
+ * SOCKET_SEND: send()s the unsent part of writeBuffer_, calling transition()
+ * once it has all been written.
+ * Transient errors (EAGAIN/EWOULDBLOCK/EINTR) return and wait for the next
+ * event. Other errors, a remote disconnect, or an allocation failure close()
+ * the connection.
+ */
+void TConnection::workSocket() {
+  int flags=0, got=0, left=0, sent=0;
+  uint32_t fetch = 0;
+
+  switch (socketState_) {
+  case SOCKET_RECV:
+    // It is an error to be in this state if we already have all the data
+    assert(readBufferPos_ < readWant_);
+
+    // Double the buffer size until it is big enough
+    if (readWant_ > readBufferSize_) {
+      // Start from 1 when the buffer is empty (e.g. after it was shrunk to
+      // zero) so the doubling loop always terminates. readWant_ never
+      // exceeds INT32_MAX, so doubling cannot overflow a uint32_t here.
+      uint32_t newSize = readBufferSize_ > 0 ? readBufferSize_ : 1;
+      while (readWant_ > newSize) {
+        newSize *= 2;
+      }
+      // Assign through a temporary: on failure realloc leaves the old block
+      // allocated. Overwriting readBuffer_ with NULL would leak it and leave
+      // readBufferSize_ describing memory that no longer exists.
+      uint8_t* newBuffer = (uint8_t*)std::realloc(readBuffer_, newSize);
+      if (newBuffer == NULL) {
+        GlobalOutput("TConnection::workSocket() realloc");
+        close();
+        return;
+      }
+      readBuffer_ = newBuffer;
+      readBufferSize_ = newSize;
+    }
+
+    // Read from the socket
+    fetch = readWant_ - readBufferPos_;
+    got = recv(socket_, readBuffer_ + readBufferPos_, fetch, 0);
+
+    if (got > 0) {
+      // Move along in the buffer
+      readBufferPos_ += got;
+
+      // Check that we did not overdo it
+      assert(readBufferPos_ <= readWant_);
+
+      // We are done reading, move onto the next state
+      if (readBufferPos_ == readWant_) {
+        transition();
+      }
+      return;
+    } else if (got == -1) {
+      // Blocking errors and signal interruptions are okay, just move on
+      if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
+        return;
+      }
+
+      if (errno != ECONNRESET) {
+        GlobalOutput.perror("TConnection::workSocket() recv -1 ", errno);
+      }
+    }
+
+    // Whenever we get down here it means a remote disconnect
+    close();
+
+    return;
+
+  case SOCKET_SEND:
+    // Should never have position past size
+    assert(writeBufferPos_ <= writeBufferSize_);
+
+    // If there is no data to send, then let us move on
+    if (writeBufferPos_ == writeBufferSize_) {
+      GlobalOutput("WARNING: Send state with no data to send\n");
+      transition();
+      return;
+    }
+
+    flags = 0;
+    #ifdef MSG_NOSIGNAL
+    // Note the use of MSG_NOSIGNAL to suppress SIGPIPE errors, instead we
+    // check for the EPIPE return condition and close the socket in that case
+    flags |= MSG_NOSIGNAL;
+    #endif // ifdef MSG_NOSIGNAL
+
+    left = writeBufferSize_ - writeBufferPos_;
+    sent = send(socket_, writeBuffer_ + writeBufferPos_, left, flags);
+
+    if (sent <= 0) {
+      // Blocking errors and signal interruptions are okay, just move on
+      if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
+        return;
+      }
+      if (errno != EPIPE) {
+        GlobalOutput.perror("TConnection::workSocket() send -1 ", errno);
+      }
+      close();
+      return;
+    }
+
+    writeBufferPos_ += sent;
+
+    // Did we overdo it?
+    assert(writeBufferPos_ <= writeBufferSize_);
+
+    // We are done!
+    if (writeBufferPos_ == writeBufferSize_) {
+      transition();
+    }
+
+    return;
+
+  default:
+    GlobalOutput.printf("Shit Got Ill. Socket State %d", socketState_);
+    assert(0);
+  }
+}
+
+/**
+ * Advances the application state machine. It is called when the socket
+ * layer has finished what it needed to do for the current state: a frame
+ * header or request body has been fully read, or a response has been fully
+ * written. It then sets up the next socket operation.
+ *
+ * Error paths that call close() return immediately. close() hands this
+ * object back to the server, so it must not be used afterwards.
+ */
+void TConnection::transition() {
+
+  int sz = 0;
+
+  // Switch upon the state that we are currently in and move to a new state
+  switch (appState_) {
+
+  case APP_READ_REQUEST:
+    // We are done reading the request, package the read buffer into transport
+    // and get back some data from the dispatch function
+    inputTransport_->resetBuffer(readBuffer_, readBufferPos_);
+    ++numReadsSinceReset_;
+    if (numWritesSinceReset_ < 512) {
+      outputTransport_->resetBuffer();
+    } else {
+      // reset the capacity of the output transport if we used it enough times that it might be bloated
+      try {
+        outputTransport_->resetBuffer(true);
+        numWritesSinceReset_ = 0;
+      } catch (TTransportException &ttx) {
+        GlobalOutput.printf("TTransportException: TMemoryBuffer::resetBuffer() %s", ttx.what());
+        close();
+        return;
+      }
+    }
+
+    // Prepend four bytes of blank space to the buffer so we can
+    // write the frame size there later.
+    outputTransport_->getWritePtr(4);
+    outputTransport_->wroteBytes(4);
+
+    if (server_->isThreadPoolProcessing()) {
+      // We are setting up a Task to do this work and we will wait on it.
+      // The task signals completion by writing to its end of a socketpair.
+      int sv[2];
+      if (-1 == socketpair(AF_LOCAL, SOCK_STREAM, 0, sv)) {
+        GlobalOutput.perror("TConnection::socketpair() failed ", errno);
+        // Now we will fall through to the APP_WAIT_TASK block with no response
+      } else {
+        // Create task and dispatch to the thread manager
+        boost::shared_ptr<Runnable> task =
+          boost::shared_ptr<Runnable>(new Task(server_->getProcessor(),
+                                               inputProtocol_,
+                                               outputProtocol_,
+                                               sv[1]));
+        // The application is now waiting on the task to finish
+        appState_ = APP_WAIT_TASK;
+
+        // Create an event to be notified when the task finishes
+        event_set(&taskEvent_,
+                  taskHandle_ = sv[0],
+                  EV_READ,
+                  TConnection::taskHandler,
+                  this);
+
+        // Attach to the base
+        event_base_set(server_->getEventBase(), &taskEvent_);
+
+        // Add the event and start up the server
+        if (-1 == event_add(&taskEvent_, 0)) {
+          GlobalOutput("TNonblockingServer::serve(): coult not event_add");
+          // Without the completion event the connection would sit in
+          // APP_WAIT_TASK forever. The task has not been dispatched, so
+          // release both socketpair descriptors and drop the connection.
+          ::close(sv[0]);
+          ::close(sv[1]);
+          taskHandle_ = -1;
+          close();
+          return;
+        }
+        try {
+          server_->addTask(task);
+        } catch (IllegalStateException & ise) {
+          // The ThreadManager is not ready to handle any more tasks (it's probably shutting down).
+          GlobalOutput.printf("IllegalStateException: Server::process() %s", ise.what());
+          // The task will never run, so it will neither signal nor close
+          // sv[1]. Unregister the completion event and release both ends
+          // before close(). Then return instead of falling into setIdle()
+          // on an object that has already been handed back to the server.
+          event_del(&taskEvent_);
+          ::close(sv[0]);
+          ::close(sv[1]);
+          taskHandle_ = -1;
+          close();
+          return;
+        }
+
+        // Set this connection idle so that libevent doesn't process more
+        // data on it while we're still waiting for the threadmanager to
+        // finish this task
+        setIdle();
+        return;
+      }
+    } else {
+      try {
+        // Invoke the processor
+        server_->getProcessor()->process(inputProtocol_, outputProtocol_);
+      } catch (TTransportException &ttx) {
+        GlobalOutput.printf("TTransportException: Server::process() %s", ttx.what());
+        close();
+        return;
+      } catch (TException &x) {
+        GlobalOutput.printf("TException: Server::process() %s", x.what());
+        close();
+        return;
+      } catch (...) {
+        GlobalOutput.printf("Server::process() unknown exception");
+        close();
+        return;
+      }
+    }
+
+    // Intentionally fall through here, the call to process has written into
+    // the writeBuffer_
+
+  case APP_WAIT_TASK:
+    // We have now finished processing a task and the result has been written
+    // into the outputTransport_, so we grab its contents and place them into
+    // the writeBuffer_ for actual writing by the libevent thread
+
+    // Get the result of the operation
+    outputTransport_->getBuffer(&writeBuffer_, &writeBufferSize_);
+
+    // If the function call generated return data, then move into the send
+    // state and get going
+    // 4 bytes were reserved for frame size
+    if (writeBufferSize_ > 4) {
+
+      // Move into write state
+      writeBufferPos_ = 0;
+      socketState_ = SOCKET_SEND;
+
+      // Put the frame size into the write buffer
+      int32_t frameSize = (int32_t)htonl(writeBufferSize_ - 4);
+      memcpy(writeBuffer_, &frameSize, 4);
+
+      // Socket into write mode
+      appState_ = APP_SEND_RESULT;
+      setWrite();
+
+      // Try to work the socket immediately
+      // workSocket();
+
+      return;
+    }
+
+    // In this case, the request was oneway and we should fall through
+    // right back into the read frame header state
+    goto LABEL_APP_INIT;
+
+  case APP_SEND_RESULT:
+
+    ++numWritesSinceReset_;
+
+    // N.B.: We also intentionally fall through here into the INIT state!
+
+  LABEL_APP_INIT:
+  case APP_INIT:
+
+    // reset the input buffer if we used it enough times that it might be bloated
+    if (numReadsSinceReset_ > 512)
+    {
+      void * new_buffer = std::realloc(readBuffer_, 1024);
+      if (new_buffer == NULL) {
+        GlobalOutput("TConnection::transition() realloc");
+        close();
+        return;
+      }
+      readBuffer_ = (uint8_t*) new_buffer;
+      readBufferSize_ = 1024;
+      numReadsSinceReset_ = 0;
+    }
+
+    // Clear write buffer variables
+    writeBuffer_ = NULL;
+    writeBufferPos_ = 0;
+    writeBufferSize_ = 0;
+
+    // Set up read buffer for getting 4 bytes
+    readBufferPos_ = 0;
+    readWant_ = 4;
+
+    // Into read4 state we go
+    socketState_ = SOCKET_RECV;
+    appState_ = APP_READ_FRAME_SIZE;
+
+    // Register read event
+    setRead();
+
+    // Try to work the socket right away
+    // workSocket();
+
+    return;
+
+  case APP_READ_FRAME_SIZE:
+    // We just read the request length, deserialize it. memcpy avoids a
+    // type-punned (and potentially misaligned) int32_t load from the byte
+    // buffer.
+    {
+      int32_t netFrameSize;
+      memcpy(&netFrameSize, readBuffer_, sizeof(netFrameSize));
+      sz = (int32_t)ntohl(netFrameSize);
+    }
+
+    if (sz <= 0) {
+      GlobalOutput.printf("TConnection:transition() Negative frame size %d, remote side not using TFramedTransport?", sz);
+      close();
+      return;
+    }
+
+    // Reset the read buffer
+    readWant_ = (uint32_t)sz;
+    readBufferPos_= 0;
+
+    // Move into read request state
+    appState_ = APP_READ_REQUEST;
+
+    // Work the socket right away
+    // workSocket();
+
+    return;
+
+  default:
+    GlobalOutput.printf("Totally Fucked. Application State %d", appState_);
+    assert(0);
+  }
+}
+
+/**
+ * Re-registers this connection's socket event with `eventFlags`.
+ *
+ * If the flags are unchanged, nothing happens. Otherwise any currently
+ * registered event is removed with event_del(); if that fails the method
+ * logs and returns without touching eventFlags_. The new flags are stored,
+ * and unless they are zero the event is prepared with event_set(), attached
+ * to the server's base and added with event_add().
+ *
+ * libevent notes:
+ *  - event_set() prepares &event_ to call eventHandler for socket_. The
+ *    flags may include EV_READ and/or EV_WRITE, and EV_PERSIST keeps the
+ *    event added until event_del() is called.
+ *  - Once added, the event struct must stay alive until the event fires
+ *    (without EV_PERSIST) or is removed with event_del().
+ *  - A struct may not be reused for several monitored descriptors; each
+ *    descriptor needs its own.
+ */
+void TConnection::setFlags(short eventFlags) {
+  if (eventFlags == eventFlags_) {
+    return;  // already registered with exactly these flags
+  }
+
+  if (eventFlags_ != 0 && event_del(&event_) == -1) {
+    GlobalOutput("TConnection::setFlags event_del");
+    return;
+  }
+
+  eventFlags_ = eventFlags;
+  if (eventFlags_ == 0) {
+    return;  // nothing to register
+  }
+
+  event_set(&event_, socket_, eventFlags_, TConnection::eventHandler, this);
+  event_base_set(server_->getEventBase(), &event_);
+
+  if (event_add(&event_, 0) == -1) {
+    GlobalOutput("TConnection::setFlags(): could not event_add");
+  }
+}
+
+/**
+ * Closes a connection: unregisters its libevent event, closes the socket and
+ * the factory-produced transports, then hands the object back to the server.
+ *
+ * The server may delete this object inside returnConnection(), so the caller
+ * must not touch the connection after close() returns.
+ */
+void TConnection::close() {
+  // Only delete the event while flags are set, mirroring setFlags(): a zero
+  // eventFlags_ means no event is pending, and the struct may never have been
+  // initialised with event_set().
+  if (eventFlags_ != 0) {
+    if (event_del(&event_) == -1) {
+      GlobalOutput("TConnection::close() event_del");
+    }
+    // Keep the cached flags consistent with the (now removed) event so a
+    // later setFlags() call does not mistake it for still being registered.
+    eventFlags_ = 0;
+  }
+
+  // Close the socket; 0 is used as the "no socket" marker in this class
+  if (socket_ > 0) {
+    ::close(socket_);
+  }
+  socket_ = 0;
+
+  // close any factory produced transports
+  factoryInputTransport_->close();
+  factoryOutputTransport_->close();
+
+  // Give this object back to the server that owns it
+  server_->returnConnection(this);
+}
+
+/**
+ * Shrinks the read buffer to at most `limit` bytes. Used by
+ * TNonblockingServer::returnConnection() before pooling an idle connection.
+ *
+ * If realloc() fails, the existing (larger) buffer is still valid and is
+ * kept unchanged. The old behaviour caused two problems:
+ *  - overwriting readBuffer_ with NULL leaked the old buffer;
+ *  - calling close() from here re-entered returnConnection(), so the same
+ *    connection was pushed onto the idle stack twice.
+ *
+ * @param limit maximum number of bytes an idle read buffer may keep. A limit
+ *              of 0 is treated as 1 because realloc(p, 0) may free p and
+ *              return NULL, which is indistinguishable from a failure.
+ */
+void TConnection::checkIdleBufferMemLimit(uint32_t limit) {
+  if (limit == 0) {
+    limit = 1;
+  }
+  if (readBufferSize_ <= limit) {
+    return;
+  }
+
+  uint8_t* shrunk = (uint8_t*)std::realloc(readBuffer_, limit);
+  if (shrunk == NULL) {
+    // Old buffer is untouched by a failed realloc(); keep using it.
+    GlobalOutput("TConnection::checkIdleBufferMemLimit() realloc");
+    return;
+  }
+  readBuffer_ = shrunk;
+  readBufferSize_ = limit;
+}
+
+/**
+ * Hands out a TConnection for a freshly accepted socket. An idle object from
+ * connectionStack_ is re-initialised when one is available; otherwise a new
+ * object is allocated.
+ *
+ * @param socket accepted client socket
+ * @param flags  libevent flags to register the connection with
+ */
+TConnection* TNonblockingServer::createConnection(int socket, short flags) {
+  if (!connectionStack_.empty()) {
+    TConnection* recycled = connectionStack_.top();
+    connectionStack_.pop();
+    recycled->init(socket, flags, this);
+    return recycled;
+  }
+  return new TConnection(socket, flags, this);
+}
+
+/**
+ * Takes back a connection that is done being used. If the idle pool already
+ * holds connectionStackLimit_ objects (a limit of 0 means "unbounded") the
+ * connection is deleted. Otherwise its read buffer is trimmed to
+ * idleBufferMemLimit_ and it is pushed onto the pool for reuse.
+ */
+void TNonblockingServer::returnConnection(TConnection* connection) {
+  const bool poolIsFull = connectionStackLimit_ != 0 &&
+                          connectionStack_.size() >= connectionStackLimit_;
+  if (poolIsFull) {
+    delete connection;
+    return;
+  }
+  connection->checkIdleBufferMemLimit(idleBufferMemLimit_);
+  connectionStack_.push(connection);
+}
+
+/**
+ * Server socket had something happen. We accept all waiting client
+ * connections on fd and assign TConnection objects to handle those requests.
+ *
+ * @param fd    the listening socket; must be serverSocket_
+ * @param which libevent event flags (not used)
+ */
+void TNonblockingServer::handleEvent(int fd, short which) {
+  // Make sure libevent handed us the listening socket
+  assert(fd == serverSocket_);
+
+  // Peer address buffer. sockaddr_storage is large enough for every address
+  // family; a plain sockaddr is too small for IPv6 peers, and listenSocket()
+  // prefers an AF_INET6 listener.
+  struct sockaddr_storage addr;
+  socklen_t addrLen;
+
+  // Accept as many new clients as possible, even though libevent signaled only
+  // one, this helps us to avoid having to go back into the libevent engine so
+  // many times
+  for (;;) {
+    // addrLen is a value-result argument: accept() overwrites it with the real
+    // address length (which can exceed the buffer size when the address is
+    // truncated), so it must be reset before every call.
+    addrLen = sizeof(addr);
+    int clientSocket = accept(fd, (struct sockaddr*)&addr, &addrLen);
+    if (clientSocket == -1) {
+      break;
+    }
+
+    // Explicitly set this socket to NONBLOCK mode
+    int flags;
+    if ((flags = fcntl(clientSocket, F_GETFL, 0)) < 0 ||
+        fcntl(clientSocket, F_SETFL, flags | O_NONBLOCK) < 0) {
+      GlobalOutput.perror("thriftServerEventHandler: set O_NONBLOCK (fcntl) ", errno);
+      close(clientSocket);
+      return;
+    }
+
+    // Create a new TConnection for this client socket.
+    TConnection* clientConnection =
+      createConnection(clientSocket, EV_READ | EV_PERSIST);
+
+    // Fail fast if we could not create a TConnection object
+    if (clientConnection == NULL) {
+      GlobalOutput.printf("thriftServerEventHandler: failed TConnection factory");
+      close(clientSocket);
+      return;
+    }
+
+    // Put this client connection into the proper state
+    clientConnection->transition();
+  }
+
+  // Done looping accept (errno is still accept()'s), now we have to make sure
+  // the error is due to blocking. Any other error is a problem
+  if (errno != EAGAIN && errno != EWOULDBLOCK) {
+    GlobalOutput.perror("thriftServerEventHandler: accept() ", errno);
+  }
+}
+
+/**
+ * Creates a socket to listen on and binds it to the local port.
+ *
+ * An IPv6 wildcard address is preferred (with IPV6_V6ONLY cleared it can
+ * also serve IPv4-mapped clients); otherwise the last address returned by
+ * getaddrinfo() is used. The bound socket is then configured by
+ * listenSocket(int).
+ *
+ * @throws TException if address resolution, socket creation or bind fails
+ */
+void TNonblockingServer::listenSocket() {
+  int s;
+  struct addrinfo hints, *res, *res0;
+  int error;
+
+  // Big enough for any int printed with %d ("-2147483648" plus NUL)
+  char port[sizeof("-2147483648")];
+  memset(&hints, 0, sizeof(hints));
+  hints.ai_family = PF_UNSPEC;
+  hints.ai_socktype = SOCK_STREAM;
+  hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG;
+  snprintf(port, sizeof(port), "%d", port_);
+
+  // Wildcard address
+  error = getaddrinfo(NULL, port, &hints, &res0);
+  if (error) {
+    // Fail like the other setup errors below; returning silently would leave
+    // serverSocket_ unset and trip registerEvents().
+    throw TException("TNonblockingServer::serve() getaddrinfo " +
+                     string(gai_strerror(error)));
+  }
+
+  // Pick the ipv6 address first since ipv4 addresses can be mapped
+  // into ipv6 space.
+  for (res = res0; res; res = res->ai_next) {
+    if (res->ai_family == AF_INET6 || res->ai_next == NULL)
+      break;
+  }
+
+  // Create the server socket
+  s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
+  if (s == -1) {
+    freeaddrinfo(res0);
+    throw TException("TNonblockingServer::serve() socket() -1");
+  }
+
+#ifdef IPV6_V6ONLY
+  // Accept IPv4-mapped connections on an IPv6 socket. The option only applies
+  // to AF_INET6 sockets, so don't try it (and log a bogus failure) otherwise.
+  if (res->ai_family == AF_INET6) {
+    int zero = 0;
+    if (-1 == setsockopt(s, IPPROTO_IPV6, IPV6_V6ONLY, &zero, sizeof(zero))) {
+      GlobalOutput("TNonblockingServer::serve() IPV6_V6ONLY");
+    }
+  }
+#endif // #ifdef IPV6_V6ONLY
+
+  int one = 1;
+
+  // Set reuseaddr to avoid 2MSL delay on server restart
+  setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one));
+
+  if (bind(s, res->ai_addr, res->ai_addrlen) == -1) {
+    close(s);
+    freeaddrinfo(res0);
+    throw TException("TNonblockingServer::serve() bind");
+  }
+
+  // Done with the addr info
+  freeaddrinfo(res0);
+
+  // Set up this file descriptor for listening
+  listenSocket(s);
+}
+
+/**
+ * Prepares a bound socket (normally the one created by listenSocket()) for
+ * use by the server. The setup steps are:
+ *  - switch the socket to non-blocking mode;
+ *  - enable SO_KEEPALIVE;
+ *  - disable SO_LINGER;
+ *  - set TCP_NODELAY unless TCP_NOPUSH is defined;
+ *  - start listening and store the socket in serverSocket_.
+ *
+ * On a fcntl() or listen() failure the socket is closed and a TException is
+ * thrown.
+ *
+ * @param s bound socket descriptor
+ */
+void TNonblockingServer::listenSocket(int s) {
+  int fileFlags = fcntl(s, F_GETFL, 0);
+  if (fileFlags < 0 || fcntl(s, F_SETFL, fileFlags | O_NONBLOCK) < 0) {
+    close(s);
+    throw TException("TNonblockingServer::serve() O_NONBLOCK");
+  }
+
+  int enable = 1;
+  struct linger noLinger = {0, 0};
+
+  // Keepalive to ensure full result flushing; linger off to avoid hung sockets
+  setsockopt(s, SOL_SOCKET, SO_KEEPALIVE, &enable, sizeof(enable));
+  setsockopt(s, SOL_SOCKET, SO_LINGER, &noLinger, sizeof(noLinger));
+
+  // Mac OS X hack: only set TCP_NODELAY where TCP_NOPUSH is not defined. See
+  // http://lists.danga.com/pipermail/memcached/2005-March/001240.html
+#ifndef TCP_NOPUSH
+  setsockopt(s, IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(enable));
+#endif
+
+  if (listen(s, LISTEN_BACKLOG) == -1) {
+    close(s);
+    throw TException("TNonblockingServer::serve() listen");
+  }
+
+  // The socket is ready: it becomes the server socket
+  serverSocket_ = s;
+}
+
+/**
+ * Register the core libevent events onto the proper base.
+ *
+ * Stores `base` in eventBase_ (connections later attach their own events to
+ * it via getEventBase()) and adds a persistent read event on serverSocket_
+ * that dispatches to TNonblockingServer::eventHandler.
+ *
+ * Preconditions (asserted): listenSocket() has set serverSocket_, and no
+ * event base has been registered yet.
+ *
+ * @param base libevent base to register on
+ * @throws TException if event_add() fails
+ */
+void TNonblockingServer::registerEvents(event_base* base) {
+  assert(serverSocket_ != -1);
+  assert(!eventBase_);
+  eventBase_ = base;
+
+  // Print some libevent stats
+  GlobalOutput.printf("libevent %s method %s",
+                      event_get_version(),
+                      event_get_method());
+
+  // Register the server event
+  event_set(&serverEvent_,
+            serverSocket_,
+            EV_READ | EV_PERSIST,
+            TNonblockingServer::eventHandler,
+            this);
+  event_base_set(eventBase_, &serverEvent_);
+
+  // Add the event and start up the server
+  if (-1 == event_add(&serverEvent_, 0)) {
+    throw TException("TNonblockingServer::serve(): could not event_add");
+  }
+}
+
+/**
+ * Main workhorse function. It runs these steps in order:
+ *  1. open the listening socket on port_;
+ *  2. register the server event on a newly initialised libevent base;
+ *  3. run the optional preServe() hook;
+ *  4. hand control to the libevent loop, which calls eventHandler.
+ *
+ * This call blocks for as long as event_base_loop() keeps running.
+ */
+void TNonblockingServer::serve() {
+  listenSocket();
+
+  event_base* base = static_cast<event_base*>(event_init());
+  registerEvents(base);
+
+  if (eventHandler_ != NULL) {
+    eventHandler_->preServe();
+  }
+
+  event_base_loop(eventBase_, 0);
+}
+
+}}} // apache::thrift::server
diff --git a/lib/cpp/src/server/TNonblockingServer.h b/lib/cpp/src/server/TNonblockingServer.h
new file mode 100644
index 0000000..1684b64
--- /dev/null
+++ b/lib/cpp/src/server/TNonblockingServer.h
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_SERVER_TNONBLOCKINGSERVER_H_
+#define _THRIFT_SERVER_TNONBLOCKINGSERVER_H_ 1
+
+#include <Thrift.h>
+#include <server/TServer.h>
+#include <transport/TBufferTransports.h>
+#include <concurrency/ThreadManager.h>
+#include <stack>
+#include <string>
+#include <errno.h>
+#include <cstdlib>
+#include <event.h>
+
+namespace apache { namespace thrift { namespace server {
+
+using apache::thrift::transport::TMemoryBuffer;
+using apache::thrift::protocol::TProtocol;
+using apache::thrift::concurrency::Runnable;
+using apache::thrift::concurrency::ThreadManager;
+
+// Forward declaration of class
+class TConnection;
+
+/**
+ * This is a non-blocking server in C++ for high performance that operates a
+ * single IO thread. It assumes that all incoming requests are framed with a
+ * 4 byte length indicator and writes out responses using the same framing.
+ *
+ * It does not use the TServerTransport framework, but rather has socket
+ * operations hardcoded for use with libevent. When a ThreadManager is
+ * supplied, request processing can be handed to its thread pool.
+ */
+class TNonblockingServer : public TServer {
+ private:
+
+  // Listen backlog
+  static const int LISTEN_BACKLOG = 1024;
+
+  // Default limit on size of idle connection pool
+  static const size_t CONNECTION_STACK_LIMIT = 1024;
+
+  // Maximum size of buffer allocated to idle connection
+  static const uint32_t IDLE_BUFFER_MEM_LIMIT = 8192;
+
+  // Server socket file descriptor; -1 until listenSocket() has succeeded
+  int serverSocket_;
+
+  // Port server runs on
+  int port_;
+
+  // For processing via thread pool, may be NULL
+  boost::shared_ptr<ThreadManager> threadManager_;
+
+  // Is thread pool processing?
+  bool threadPoolProcessing_;
+
+  // The event base for libevent
+  event_base* eventBase_;
+
+  // Event struct, for use with eventBase_
+  struct event serverEvent_;
+
+  // Number of live TConnection objects (maintained by TConnection's
+  // constructor and destructor)
+  size_t numTConnections_;
+
+  // Limit for how many TConnection objects to cache (0 = unlimited)
+  size_t connectionStackLimit_;
+
+  /**
+   * Max read buffer size for an idle connection. When we place an idle
+   * TConnection into connectionStack_, we ensure that its read buffer is
+   * reduced to this size so that idle connections don't hog memory.
+   */
+  uint32_t idleBufferMemLimit_;
+
+  /**
+   * This is a stack of all the objects that have been created but that
+   * are NOT currently in use. When we close a connection, we place it on this
+   * stack so that the object can be reused later, rather than freeing the
+   * memory and reallocating a new object later.
+   */
+  std::stack<TConnection*> connectionStack_;
+
+  // Accepts pending clients on the listening socket (libevent callback body)
+  void handleEvent(int fd, short which);
+
+ public:
+  TNonblockingServer(boost::shared_ptr<TProcessor> processor,
+                     int port) :
+    TServer(processor),
+    serverSocket_(-1),
+    port_(port),
+    threadPoolProcessing_(false),
+    eventBase_(NULL),
+    numTConnections_(0),
+    connectionStackLimit_(CONNECTION_STACK_LIMIT),
+    idleBufferMemLimit_(IDLE_BUFFER_MEM_LIMIT) {}
+
+  TNonblockingServer(boost::shared_ptr<TProcessor> processor,
+                     boost::shared_ptr<TProtocolFactory> protocolFactory,
+                     int port,
+                     boost::shared_ptr<ThreadManager> threadManager = boost::shared_ptr<ThreadManager>()) :
+    TServer(processor),
+    serverSocket_(-1),
+    port_(port),
+    threadManager_(threadManager),
+    eventBase_(NULL),
+    numTConnections_(0),
+    connectionStackLimit_(CONNECTION_STACK_LIMIT),
+    idleBufferMemLimit_(IDLE_BUFFER_MEM_LIMIT) {
+    setInputTransportFactory(boost::shared_ptr<TTransportFactory>(new TTransportFactory()));
+    setOutputTransportFactory(boost::shared_ptr<TTransportFactory>(new TTransportFactory()));
+    setInputProtocolFactory(protocolFactory);
+    setOutputProtocolFactory(protocolFactory);
+    setThreadManager(threadManager);
+  }
+
+  TNonblockingServer(boost::shared_ptr<TProcessor> processor,
+                     boost::shared_ptr<TTransportFactory> inputTransportFactory,
+                     boost::shared_ptr<TTransportFactory> outputTransportFactory,
+                     boost::shared_ptr<TProtocolFactory> inputProtocolFactory,
+                     boost::shared_ptr<TProtocolFactory> outputProtocolFactory,
+                     int port,
+                     boost::shared_ptr<ThreadManager> threadManager = boost::shared_ptr<ThreadManager>()) :
+    TServer(processor),
+    // -1 like the other constructors, so registerEvents()'s
+    // "serverSocket_ != -1" assertion detects a missing listenSocket().
+    serverSocket_(-1),
+    port_(port),
+    threadManager_(threadManager),
+    eventBase_(NULL),
+    numTConnections_(0),
+    connectionStackLimit_(CONNECTION_STACK_LIMIT),
+    idleBufferMemLimit_(IDLE_BUFFER_MEM_LIMIT) {
+    setInputTransportFactory(inputTransportFactory);
+    setOutputTransportFactory(outputTransportFactory);
+    setInputProtocolFactory(inputProtocolFactory);
+    setOutputProtocolFactory(outputProtocolFactory);
+    setThreadManager(threadManager);
+  }
+
+  ~TNonblockingServer() {}
+
+  // Sets the pool used for processing; a NULL manager disables pooling.
+  void setThreadManager(boost::shared_ptr<ThreadManager> threadManager) {
+    threadManager_ = threadManager;
+    threadPoolProcessing_ = (threadManager != NULL);
+  }
+
+  boost::shared_ptr<ThreadManager> getThreadManager() {
+    return threadManager_;
+  }
+
+  /**
+   * Get the maximum number of unused TConnection we will hold in reserve.
+   *
+   * @return the current limit on TConnection pool size.
+   */
+  size_t getConnectionStackLimit() const {
+    return connectionStackLimit_;
+  }
+
+  /**
+   * Set the maximum number of unused TConnection we will hold in reserve.
+   *
+   * @param sz the new limit for TConnection pool size (0 = unlimited).
+   */
+  void setConnectionStackLimit(size_t sz) {
+    connectionStackLimit_ = sz;
+  }
+
+  bool isThreadPoolProcessing() const {
+    return threadPoolProcessing_;
+  }
+
+  // Queues a task on the thread manager; only valid when one is set.
+  void addTask(boost::shared_ptr<Runnable> task) {
+    threadManager_->add(task);
+  }
+
+  event_base* getEventBase() const {
+    return eventBase_;
+  }
+
+  void incrementNumConnections() {
+    ++numTConnections_;
+  }
+
+  void decrementNumConnections() {
+    --numTConnections_;
+  }
+
+  size_t getNumConnections() {
+    return numTConnections_;
+  }
+
+  size_t getNumIdleConnections() {
+    return connectionStack_.size();
+  }
+
+  /**
+   * Get the maximum limit of memory allocated to idle TConnection objects.
+   *
+   * @return # bytes beyond which we will shrink buffers when idle.
+   */
+  size_t getIdleBufferMemLimit() const {
+    return idleBufferMemLimit_;
+  }
+
+  /**
+   * Set the maximum limit of memory allocated to idle TConnection objects.
+   * If a TConnection object goes idle with more than this much memory
+   * allocated to its buffer, we shrink it to this value.
+   *
+   * @param limit of bytes beyond which we will shrink buffers when idle.
+   *        Values above 0xFFFFFFFF are clamped, since the limit is stored
+   *        as a 32-bit value.
+   */
+  void setIdleBufferMemLimit(size_t limit) {
+    // Clamp instead of silently truncating (e.g. 2^32 would wrap to 0).
+    const uint32_t maxLimit = 0xFFFFFFFFu;
+    idleBufferMemLimit_ = (limit > maxLimit) ? maxLimit
+                                             : static_cast<uint32_t>(limit);
+  }
+
+  TConnection* createConnection(int socket, short flags);
+
+  void returnConnection(TConnection* connection);
+
+  // libevent trampoline for the listening socket; v is the server instance
+  static void eventHandler(int fd, short which, void* v) {
+    ((TNonblockingServer*)v)->handleEvent(fd, which);
+  }
+
+  void listenSocket();
+
+  void listenSocket(int fd);
+
+  void registerEvents(event_base* base);
+
+  void serve();
+};
+
+/**
+ * Direction a TConnection's socket is currently working in: receiving
+ * (SOCKET_RECV) or sending (SOCKET_SEND).
+ */
+enum TSocketState {
+  SOCKET_RECV = 0,
+  SOCKET_SEND = 1
+};
+
+/**
+ * Application-level states of a nonblocking-server connection (five):
+ *   APP_INIT            - initialise for the next request
+ *   APP_READ_FRAME_SIZE - read the 4 byte frame size
+ *   APP_READ_REQUEST    - read the frame of data
+ *   APP_WAIT_TASK       - presumably waiting for a thread-pool task to finish
+ *                         (TODO confirm against transition())
+ *   APP_SEND_RESULT     - send back data (if any)
+ */
+enum TAppState {
+  APP_INIT = 0,
+  APP_READ_FRAME_SIZE = 1,
+  APP_READ_REQUEST = 2,
+  APP_WAIT_TASK = 3,
+  APP_SEND_RESULT = 4
+};
+
+/**
+ * Represents a connection that is handled via libevent. This connection
+ * essentially encapsulates a socket that has some associated libevent state.
+ *
+ * Instances are pooled by TNonblockingServer (createConnection() /
+ * returnConnection()) and own a malloc()ed read buffer, so they are
+ * non-copyable and free that buffer on destruction.
+ */
+class TConnection {
+ private:
+
+  class Task;
+
+  // Server handle
+  TNonblockingServer* server_;
+
+  // Socket handle
+  int socket_;
+
+  // Libevent object
+  struct event event_;
+
+  // Libevent flags
+  short eventFlags_;
+
+  // Socket mode
+  TSocketState socketState_;
+
+  // Application state
+  TAppState appState_;
+
+  // How much data needed to read
+  uint32_t readWant_;
+
+  // Where in the read buffer are we
+  uint32_t readBufferPos_;
+
+  // Read buffer (malloc()ed, owned by this object)
+  uint8_t* readBuffer_;
+
+  // Read buffer size
+  uint32_t readBufferSize_;
+
+  // Write buffer
+  uint8_t* writeBuffer_;
+
+  // Write buffer size
+  uint32_t writeBufferSize_;
+
+  // How far through writing are we?
+  uint32_t writeBufferPos_;
+
+  // How many times have we read since our last buffer reset?
+  uint32_t numReadsSinceReset_;
+
+  // How many times have we written since our last buffer reset?
+  uint32_t numWritesSinceReset_;
+
+  // Task handle
+  int taskHandle_;
+
+  // Task event
+  struct event taskEvent_;
+
+  // Transport to read from
+  boost::shared_ptr<TMemoryBuffer> inputTransport_;
+
+  // Transport that processor writes to
+  boost::shared_ptr<TMemoryBuffer> outputTransport_;
+
+  // extra transport generated by transport factory (e.g. BufferedRouterTransport)
+  boost::shared_ptr<TTransport> factoryInputTransport_;
+  boost::shared_ptr<TTransport> factoryOutputTransport_;
+
+  // Protocol decoder
+  boost::shared_ptr<TProtocol> inputProtocol_;
+
+  // Protocol encoder
+  boost::shared_ptr<TProtocol> outputProtocol_;
+
+  // Go into read mode
+  void setRead() {
+    setFlags(EV_READ | EV_PERSIST);
+  }
+
+  // Go into write mode
+  void setWrite() {
+    setFlags(EV_WRITE | EV_PERSIST);
+  }
+
+  // Set socket idle
+  void setIdle() {
+    setFlags(0);
+  }
+
+  // Set event flags
+  void setFlags(short eventFlags);
+
+  // Libevent handlers
+  void workSocket();
+
+  // Close this client and reset
+  void close();
+
+  // Non-copyable: a copy would share (and then double-free) readBuffer_ and
+  // alias the registered libevent structs. Declared, never defined.
+  TConnection(const TConnection&);
+  TConnection& operator=(const TConnection&);
+
+ public:
+
+  // Constructor
+  TConnection(int socket, short eventFlags, TNonblockingServer *s) {
+    readBuffer_ = (uint8_t*)std::malloc(1024);
+    if (readBuffer_ == NULL) {
+      // Throw by value: a thrown pointer leaks and is not caught by the
+      // usual `catch (TException&)` handlers.
+      throw apache::thrift::TException("Out of memory.");
+    }
+    readBufferSize_ = 1024;
+
+    numReadsSinceReset_ = 0;
+    numWritesSinceReset_ = 0;
+
+    // Allocate input and output transports
+    // these only need to be allocated once per TConnection (they don't need to be
+    // reallocated on init() call)
+    inputTransport_ = boost::shared_ptr<TMemoryBuffer>(new TMemoryBuffer(readBuffer_, readBufferSize_));
+    outputTransport_ = boost::shared_ptr<TMemoryBuffer>(new TMemoryBuffer());
+
+    init(socket, eventFlags, s);
+    server_->incrementNumConnections();
+  }
+
+  ~TConnection() {
+    server_->decrementNumConnections();
+    // readBuffer_ comes from malloc()/realloc() and would otherwise leak when
+    // the server deletes a connection. NOTE(review): relies on
+    // TMemoryBuffer(buf, sz) observing rather than owning the buffer (its
+    // default OBSERVE policy) - confirm if that constructor call changes.
+    std::free(readBuffer_);
+  }
+
+  /**
+   * Check read buffer against a given limit and shrink it if exceeded.
+   *
+   * @param limit we limit buffer size to.
+   */
+  void checkIdleBufferMemLimit(uint32_t limit);
+
+  // Initialize
+  void init(int socket, short eventFlags, TNonblockingServer *s);
+
+  // Transition into a new state
+  void transition();
+
+  // Handler wrapper
+  static void eventHandler(int fd, short /* which */, void* v) {
+    assert(fd == ((TConnection*)v)->socket_);
+    ((TConnection*)v)->workSocket();
+  }
+
+  // Handler wrapper for task block: closes the task handle, then advances
+  // the connection's state machine.
+  static void taskHandler(int fd, short /* which */, void* v) {
+    assert(fd == ((TConnection*)v)->taskHandle_);
+    if (-1 == ::close(((TConnection*)v)->taskHandle_)) {
+      GlobalOutput.perror("TConnection::taskHandler close handle failed, resource leak ", errno);
+    }
+    ((TConnection*)v)->transition();
+  }
+
+};
+
+}}} // apache::thrift::server
+
+#endif // #ifndef _THRIFT_SERVER_TNONBLOCKINGSERVER_H_
diff --git a/lib/cpp/src/server/TServer.cpp b/lib/cpp/src/server/TServer.cpp
new file mode 100644
index 0000000..6b692ab
--- /dev/null
+++ b/lib/cpp/src/server/TServer.cpp
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <sys/time.h>
+#include <sys/resource.h>
+#include <unistd.h>
+
+namespace apache { namespace thrift { namespace server {
+
+/**
+ * Sets both the soft and hard RLIMIT_NOFILE limits to max_fds. Each time
+ * setrlimit() refuses, the request is halved and retried; this stops on
+ * success or when the request reaches 0.
+ *
+ * @return the limit that was applied, or 0 if no attempt succeeded.
+ */
+int increase_max_fds(int max_fds=(1<<24)) {
+  struct rlimit limits;
+  for (;;) {
+    limits.rlim_cur = max_fds;
+    limits.rlim_max = max_fds;
+    if (max_fds == 0 || setrlimit(RLIMIT_NOFILE, &limits) >= 0) {
+      break;
+    }
+    max_fds /= 2;
+  }
+  return limits.rlim_cur;
+}
+
+}}} // apache::thrift::server
diff --git a/lib/cpp/src/server/TServer.h b/lib/cpp/src/server/TServer.h
new file mode 100644
index 0000000..5c4c588
--- /dev/null
+++ b/lib/cpp/src/server/TServer.h
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_SERVER_TSERVER_H_
+#define _THRIFT_SERVER_TSERVER_H_ 1
+
+#include <TProcessor.h>
+#include <transport/TServerTransport.h>
+#include <protocol/TBinaryProtocol.h>
+#include <concurrency/Thread.h>
+
+#include <boost/shared_ptr.hpp>
+
+namespace apache { namespace thrift { namespace server {
+
+using apache::thrift::TProcessor;
+using apache::thrift::protocol::TBinaryProtocolFactory;
+using apache::thrift::protocol::TProtocol;
+using apache::thrift::protocol::TProtocolFactory;
+using apache::thrift::transport::TServerTransport;
+using apache::thrift::transport::TTransport;
+using apache::thrift::transport::TTransportFactory;
+
+/**
+ * Virtual interface class that can handle events from the server core. To
+ * use this you should subclass it and implement the methods that you care
+ * about. Your subclass can also store local data that you may care about,
+ * such as additional "arguments" to these methods (stored in the object
+ * instance's state).
+ *
+ * Every hook has an empty default implementation, so overriding is optional.
+ */
+class TServerEventHandler {
+ public:
+
+  virtual ~TServerEventHandler() {}
+
+  /**
+   * Called before the server begins.
+   */
+  virtual void preServe() {}
+
+  /**
+   * Called when a new client has connected and is about to begin processing.
+   *
+   * The two arguments are the input and output protocols the server will use
+   * for this client.
+   */
+  virtual void clientBegin(boost::shared_ptr<TProtocol> /* input */,
+                           boost::shared_ptr<TProtocol> /* output */) {}
+
+  /**
+   * Called when a client has finished making requests.
+   *
+   * Receives the same input/output protocols that were passed to
+   * clientBegin().
+   */
+  virtual void clientEnd(boost::shared_ptr<TProtocol> /* input */,
+                         boost::shared_ptr<TProtocol> /* output */) {}
+
+ protected:
+
+  /**
+   * Prevent direct instantiation.
+   */
+  TServerEventHandler() {}
+
+};
+
+/**
+ * Thrift server base class.
+ *
+ * It holds the pieces every server needs:
+ *  - the processor;
+ *  - an optional server transport;
+ *  - separate input/output transport and protocol factories;
+ *  - an optional TServerEventHandler.
+ *
+ * Concrete servers implement serve(). run() forwards to serve(), so a server
+ * can be driven as a concurrency::Runnable. Constructors that are not given
+ * factories install a TTransportFactory and a TBinaryProtocolFactory.
+ */
+class TServer : public concurrency::Runnable {
+ public:
+
+  virtual ~TServer() {}
+
+  // Runs the server; blocking behaviour is defined by the subclass.
+  virtual void serve() = 0;
+
+  // Requests shutdown; the base implementation does nothing.
+  virtual void stop() {}
+
+  // Runnable entry point, so a server can run on its own thread.
+  virtual void run() {
+    serve();
+  }
+
+  // ---- accessors -------------------------------------------------------
+
+  boost::shared_ptr<TProcessor> getProcessor() { return processor_; }
+
+  boost::shared_ptr<TServerTransport> getServerTransport() { return serverTransport_; }
+
+  boost::shared_ptr<TTransportFactory> getInputTransportFactory() { return inputTransportFactory_; }
+
+  boost::shared_ptr<TTransportFactory> getOutputTransportFactory() { return outputTransportFactory_; }
+
+  boost::shared_ptr<TProtocolFactory> getInputProtocolFactory() { return inputProtocolFactory_; }
+
+  boost::shared_ptr<TProtocolFactory> getOutputProtocolFactory() { return outputProtocolFactory_; }
+
+  boost::shared_ptr<TServerEventHandler> getEventHandler() { return eventHandler_; }
+
+  // ---- mutators --------------------------------------------------------
+
+  void setInputTransportFactory(boost::shared_ptr<TTransportFactory> inputTransportFactory) {
+    inputTransportFactory_ = inputTransportFactory;
+  }
+
+  void setOutputTransportFactory(boost::shared_ptr<TTransportFactory> outputTransportFactory) {
+    outputTransportFactory_ = outputTransportFactory;
+  }
+
+  void setInputProtocolFactory(boost::shared_ptr<TProtocolFactory> inputProtocolFactory) {
+    inputProtocolFactory_ = inputProtocolFactory;
+  }
+
+  void setOutputProtocolFactory(boost::shared_ptr<TProtocolFactory> outputProtocolFactory) {
+    outputProtocolFactory_ = outputProtocolFactory;
+  }
+
+  void setServerEventHandler(boost::shared_ptr<TServerEventHandler> eventHandler) {
+    eventHandler_ = eventHandler;
+  }
+
+ protected:
+  // Processor only: no server transport, default factories.
+  TServer(boost::shared_ptr<TProcessor> processor) :
+    processor_(processor),
+    inputTransportFactory_(new TTransportFactory()),
+    outputTransportFactory_(new TTransportFactory()),
+    inputProtocolFactory_(new TBinaryProtocolFactory()),
+    outputProtocolFactory_(new TBinaryProtocolFactory()) {}
+
+  // Processor and server transport, default factories.
+  TServer(boost::shared_ptr<TProcessor> processor,
+          boost::shared_ptr<TServerTransport> serverTransport) :
+    processor_(processor),
+    serverTransport_(serverTransport),
+    inputTransportFactory_(new TTransportFactory()),
+    outputTransportFactory_(new TTransportFactory()),
+    inputProtocolFactory_(new TBinaryProtocolFactory()),
+    outputProtocolFactory_(new TBinaryProtocolFactory()) {}
+
+  // One transport factory and one protocol factory shared by both directions.
+  TServer(boost::shared_ptr<TProcessor> processor,
+          boost::shared_ptr<TServerTransport> serverTransport,
+          boost::shared_ptr<TTransportFactory> transportFactory,
+          boost::shared_ptr<TProtocolFactory> protocolFactory) :
+    processor_(processor),
+    serverTransport_(serverTransport),
+    inputTransportFactory_(transportFactory),
+    outputTransportFactory_(transportFactory),
+    inputProtocolFactory_(protocolFactory),
+    outputProtocolFactory_(protocolFactory) {}
+
+  // Fully specified input/output factories.
+  TServer(boost::shared_ptr<TProcessor> processor,
+          boost::shared_ptr<TServerTransport> serverTransport,
+          boost::shared_ptr<TTransportFactory> inputTransportFactory,
+          boost::shared_ptr<TTransportFactory> outputTransportFactory,
+          boost::shared_ptr<TProtocolFactory> inputProtocolFactory,
+          boost::shared_ptr<TProtocolFactory> outputProtocolFactory) :
+    processor_(processor),
+    serverTransport_(serverTransport),
+    inputTransportFactory_(inputTransportFactory),
+    outputTransportFactory_(outputTransportFactory),
+    inputProtocolFactory_(inputProtocolFactory),
+    outputProtocolFactory_(outputProtocolFactory) {}
+
+  // Class variables (declaration order is also the initialisation order)
+  boost::shared_ptr<TProcessor> processor_;
+  boost::shared_ptr<TServerTransport> serverTransport_;
+
+  boost::shared_ptr<TTransportFactory> inputTransportFactory_;
+  boost::shared_ptr<TTransportFactory> outputTransportFactory_;
+
+  boost::shared_ptr<TProtocolFactory> inputProtocolFactory_;
+  boost::shared_ptr<TProtocolFactory> outputProtocolFactory_;
+
+  boost::shared_ptr<TServerEventHandler> eventHandler_;
+};
+
+/**
+ * Helper function to increase the max file descriptors limit
+ * for the current process and all of its children.
+ * By default, tries to increase it to as much as 2^24.
+ *
+ * Both the soft and the hard RLIMIT_NOFILE limit are set to max_fds. If
+ * setrlimit() refuses, max_fds is halved and the attempt repeated until it
+ * succeeds or max_fds reaches 0. Because the hard limit is set as well, this
+ * can also lower an existing, larger limit.
+ *
+ * @return the limit that was applied, or 0 if no attempt succeeded.
+ */
+int increase_max_fds(int max_fds=(1<<24));
+
+
+}}} // apache::thrift::server
+
+#endif // #ifndef _THRIFT_SERVER_TSERVER_H_
diff --git a/lib/cpp/src/server/TSimpleServer.cpp b/lib/cpp/src/server/TSimpleServer.cpp
new file mode 100644
index 0000000..394ce21
--- /dev/null
+++ b/lib/cpp/src/server/TSimpleServer.cpp
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "server/TSimpleServer.h"
+#include "transport/TTransportException.h"
+#include <string>
+#include <iostream>
+
+namespace apache { namespace thrift { namespace server {
+
+using namespace std;
+using namespace apache::thrift;
+using namespace apache::thrift::protocol;
+using namespace apache::thrift::transport;
+using boost::shared_ptr;
+
+/**
+ * A simple single-threaded application server. Perfect for unit tests!
+ *
+ * The server listens on serverTransport_ and then repeats, until stop_ is
+ * set:
+ *  - accept one client;
+ *  - process requests on it until the peer closes or an error occurs;
+ *  - close the client.
+ *
+ * A thrown std::string aborts the accept loop. When stopped, the server
+ * transport is closed and stop_ is cleared.
+ */
+void TSimpleServer::serve() {
+
+  shared_ptr<TTransport> client;
+  shared_ptr<TTransport> inputTransport;
+  shared_ptr<TTransport> outputTransport;
+  shared_ptr<TProtocol> inputProtocol;
+  shared_ptr<TProtocol> outputProtocol;
+
+  try {
+    // Start the server listening
+    serverTransport_->listen();
+  } catch (const TTransportException& ttx) {
+    cerr << "TSimpleServer::run() listen(): " << ttx.what() << endl;
+    return;
+  }
+
+  // Run the preServe event
+  if (eventHandler_ != NULL) {
+    eventHandler_->preServe();
+  }
+
+  // Fetch client from server
+  while (!stop_) {
+    // Drop the previous client's objects so the error handlers below only
+    // ever close transports belonging to this iteration's client.
+    client.reset();
+    inputTransport.reset();
+    outputTransport.reset();
+    inputProtocol.reset();
+    outputProtocol.reset();
+
+    try {
+      client = serverTransport_->accept();
+      inputTransport = inputTransportFactory_->getTransport(client);
+      outputTransport = outputTransportFactory_->getTransport(client);
+      inputProtocol = inputProtocolFactory_->getProtocol(inputTransport);
+      outputProtocol = outputProtocolFactory_->getProtocol(outputTransport);
+      if (eventHandler_ != NULL) {
+        eventHandler_->clientBegin(inputProtocol, outputProtocol);
+      }
+      try {
+        while (processor_->process(inputProtocol, outputProtocol)) {
+          // Peek ahead, is the remote side closed?
+          if (!inputTransport->peek()) {
+            break;
+          }
+        }
+      } catch (const TTransportException& ttx) {
+        cerr << "TSimpleServer client died: " << ttx.what() << endl;
+      } catch (const TException& tx) {
+        cerr << "TSimpleServer exception: " << tx.what() << endl;
+      }
+      if (eventHandler_ != NULL) {
+        eventHandler_->clientEnd(inputProtocol, outputProtocol);
+      }
+      inputTransport->close();
+      outputTransport->close();
+      client->close();
+    } catch (const TTransportException& ttx) {
+      if (inputTransport != NULL) { inputTransport->close(); }
+      if (outputTransport != NULL) { outputTransport->close(); }
+      if (client != NULL) { client->close(); }
+      cerr << "TServerTransport died on accept: " << ttx.what() << endl;
+      continue;
+    } catch (const TException& tx) {
+      if (inputTransport != NULL) { inputTransport->close(); }
+      if (outputTransport != NULL) { outputTransport->close(); }
+      if (client != NULL) { client->close(); }
+      cerr << "Some kind of accept exception: " << tx.what() << endl;
+      continue;
+    } catch (const string& s) {
+      if (inputTransport != NULL) { inputTransport->close(); }
+      if (outputTransport != NULL) { outputTransport->close(); }
+      if (client != NULL) { client->close(); }
+      cerr << "TSimpleServer: Unknown exception: " << s << endl;
+      break;
+    }
+  }
+
+  if (stop_) {
+    try {
+      serverTransport_->close();
+    } catch (const TTransportException& ttx) {
+      cerr << "TServerTransport failed on close: " << ttx.what() << endl;
+    }
+    stop_ = false;
+  }
+}
+
+}}} // apache::thrift::server
diff --git a/lib/cpp/src/server/TSimpleServer.h b/lib/cpp/src/server/TSimpleServer.h
new file mode 100644
index 0000000..c4fc91c
--- /dev/null
+++ b/lib/cpp/src/server/TSimpleServer.h
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_SERVER_TSIMPLESERVER_H_
+#define _THRIFT_SERVER_TSIMPLESERVER_H_ 1
+
+#include "server/TServer.h"
+#include "transport/TServerTransport.h"
+
+namespace apache { namespace thrift { namespace server {
+
+/**
+ * This is the most basic simple server. It is single-threaded and runs a
+ * continuous loop of accepting a single connection, processing requests on
+ * that connection until it closes, and then repeating. It is a good example
+ * of how to extend the TServer interface.
+ *
+ */
+class TSimpleServer : public TServer {
+ public:
+ TSimpleServer(boost::shared_ptr<TProcessor> processor,
+ boost::shared_ptr<TServerTransport> serverTransport,
+ boost::shared_ptr<TTransportFactory> transportFactory,
+ boost::shared_ptr<TProtocolFactory> protocolFactory) :
+ TServer(processor, serverTransport, transportFactory, protocolFactory),
+ stop_(false) {}
+
+ TSimpleServer(boost::shared_ptr<TProcessor> processor,
+ boost::shared_ptr<TServerTransport> serverTransport,
+ boost::shared_ptr<TTransportFactory> inputTransportFactory,
+ boost::shared_ptr<TTransportFactory> outputTransportFactory,
+ boost::shared_ptr<TProtocolFactory> inputProtocolFactory,
+ boost::shared_ptr<TProtocolFactory> outputProtocolFactory):
+ TServer(processor, serverTransport,
+ inputTransportFactory, outputTransportFactory,
+ inputProtocolFactory, outputProtocolFactory),
+ stop_(false) {}
+
+ ~TSimpleServer() {}
+
+ void serve();
+
+ void stop() {
+ stop_ = true;
+ }
+
+ protected:
+ bool stop_;
+
+};
+
+}}} // apache::thrift::server
+
+#endif // #ifndef _THRIFT_SERVER_TSIMPLESERVER_H_
diff --git a/lib/cpp/src/server/TThreadPoolServer.cpp b/lib/cpp/src/server/TThreadPoolServer.cpp
new file mode 100644
index 0000000..0894cfa
--- /dev/null
+++ b/lib/cpp/src/server/TThreadPoolServer.cpp
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "server/TThreadPoolServer.h"
+#include "transport/TTransportException.h"
+#include "concurrency/Thread.h"
+#include "concurrency/ThreadManager.h"
+#include <string>
+#include <iostream>
+
+namespace apache { namespace thrift { namespace server {
+
+using boost::shared_ptr;
+using namespace std;
+using namespace apache::thrift;
+using namespace apache::thrift::concurrency;
+using namespace apache::thrift::protocol;;
+using namespace apache::thrift::transport;
+
+class TThreadPoolServer::Task : public Runnable {
+
+public:
+
+ Task(TThreadPoolServer &server,
+ shared_ptr<TProcessor> processor,
+ shared_ptr<TProtocol> input,
+ shared_ptr<TProtocol> output) :
+ server_(server),
+ processor_(processor),
+ input_(input),
+ output_(output) {
+ }
+
+ ~Task() {}
+
+ void run() {
+ boost::shared_ptr<TServerEventHandler> eventHandler =
+ server_.getEventHandler();
+ if (eventHandler != NULL) {
+ eventHandler->clientBegin(input_, output_);
+ }
+ try {
+ while (processor_->process(input_, output_)) {
+ if (!input_->getTransport()->peek()) {
+ break;
+ }
+ }
+ } catch (TTransportException& ttx) {
+ // This is reasonably expected, client didn't send a full request so just
+ // ignore him
+ // string errStr = string("TThreadPoolServer client died: ") + ttx.what();
+ // GlobalOutput(errStr.c_str());
+ } catch (TException& x) {
+ string errStr = string("TThreadPoolServer exception: ") + x.what();
+ GlobalOutput(errStr.c_str());
+ } catch (std::exception &x) {
+ string errStr = string("TThreadPoolServer, std::exception: ") + x.what();
+ GlobalOutput(errStr.c_str());
+ }
+
+ if (eventHandler != NULL) {
+ eventHandler->clientEnd(input_, output_);
+ }
+
+ try {
+ input_->getTransport()->close();
+ } catch (TTransportException& ttx) {
+ string errStr = string("TThreadPoolServer input close failed: ") + ttx.what();
+ GlobalOutput(errStr.c_str());
+ }
+ try {
+ output_->getTransport()->close();
+ } catch (TTransportException& ttx) {
+ string errStr = string("TThreadPoolServer output close failed: ") + ttx.what();
+ GlobalOutput(errStr.c_str());
+ }
+
+ }
+
+ private:
+ TServer& server_;
+ shared_ptr<TProcessor> processor_;
+ shared_ptr<TProtocol> input_;
+ shared_ptr<TProtocol> output_;
+
+};
+
+TThreadPoolServer::TThreadPoolServer(shared_ptr<TProcessor> processor,
+ shared_ptr<TServerTransport> serverTransport,
+ shared_ptr<TTransportFactory> transportFactory,
+ shared_ptr<TProtocolFactory> protocolFactory,
+ shared_ptr<ThreadManager> threadManager) :
+ TServer(processor, serverTransport, transportFactory, protocolFactory),
+ threadManager_(threadManager),
+ stop_(false), timeout_(0) {}
+
+TThreadPoolServer::TThreadPoolServer(shared_ptr<TProcessor> processor,
+ shared_ptr<TServerTransport> serverTransport,
+ shared_ptr<TTransportFactory> inputTransportFactory,
+ shared_ptr<TTransportFactory> outputTransportFactory,
+ shared_ptr<TProtocolFactory> inputProtocolFactory,
+ shared_ptr<TProtocolFactory> outputProtocolFactory,
+ shared_ptr<ThreadManager> threadManager) :
+ TServer(processor, serverTransport, inputTransportFactory, outputTransportFactory,
+ inputProtocolFactory, outputProtocolFactory),
+ threadManager_(threadManager),
+ stop_(false), timeout_(0) {}
+
+
+TThreadPoolServer::~TThreadPoolServer() {}
+
+void TThreadPoolServer::serve() {
+ shared_ptr<TTransport> client;
+ shared_ptr<TTransport> inputTransport;
+ shared_ptr<TTransport> outputTransport;
+ shared_ptr<TProtocol> inputProtocol;
+ shared_ptr<TProtocol> outputProtocol;
+
+ try {
+ // Start the server listening
+ serverTransport_->listen();
+ } catch (TTransportException& ttx) {
+ string errStr = string("TThreadPoolServer::run() listen(): ") + ttx.what();
+ GlobalOutput(errStr.c_str());
+ return;
+ }
+
+ // Run the preServe event
+ if (eventHandler_ != NULL) {
+ eventHandler_->preServe();
+ }
+
+ while (!stop_) {
+ try {
+ client.reset();
+ inputTransport.reset();
+ outputTransport.reset();
+ inputProtocol.reset();
+ outputProtocol.reset();
+
+ // Fetch client from server
+ client = serverTransport_->accept();
+
+ // Make IO transports
+ inputTransport = inputTransportFactory_->getTransport(client);
+ outputTransport = outputTransportFactory_->getTransport(client);
+ inputProtocol = inputProtocolFactory_->getProtocol(inputTransport);
+ outputProtocol = outputProtocolFactory_->getProtocol(outputTransport);
+
+ // Add to threadmanager pool
+ threadManager_->add(shared_ptr<TThreadPoolServer::Task>(new TThreadPoolServer::Task(*this, processor_, inputProtocol, outputProtocol)), timeout_);
+
+ } catch (TTransportException& ttx) {
+ if (inputTransport != NULL) { inputTransport->close(); }
+ if (outputTransport != NULL) { outputTransport->close(); }
+ if (client != NULL) { client->close(); }
+ if (!stop_ || ttx.getType() != TTransportException::INTERRUPTED) {
+ string errStr = string("TThreadPoolServer: TServerTransport died on accept: ") + ttx.what();
+ GlobalOutput(errStr.c_str());
+ }
+ continue;
+ } catch (TException& tx) {
+ if (inputTransport != NULL) { inputTransport->close(); }
+ if (outputTransport != NULL) { outputTransport->close(); }
+ if (client != NULL) { client->close(); }
+ string errStr = string("TThreadPoolServer: Caught TException: ") + tx.what();
+ GlobalOutput(errStr.c_str());
+ continue;
+ } catch (string s) {
+ if (inputTransport != NULL) { inputTransport->close(); }
+ if (outputTransport != NULL) { outputTransport->close(); }
+ if (client != NULL) { client->close(); }
+ string errStr = "TThreadPoolServer: Unknown exception: " + s;
+ GlobalOutput(errStr.c_str());
+ break;
+ }
+ }
+
+ // If stopped manually, join the existing threads
+ if (stop_) {
+ try {
+ serverTransport_->close();
+ threadManager_->join();
+ } catch (TException &tx) {
+ string errStr = string("TThreadPoolServer: Exception shutting down: ") + tx.what();
+ GlobalOutput(errStr.c_str());
+ }
+ stop_ = false;
+ }
+
+}
+
+int64_t TThreadPoolServer::getTimeout() const {
+ return timeout_;
+}
+
+void TThreadPoolServer::setTimeout(int64_t value) {
+ timeout_ = value;
+}
+
+}}} // apache::thrift::server
diff --git a/lib/cpp/src/server/TThreadPoolServer.h b/lib/cpp/src/server/TThreadPoolServer.h
new file mode 100644
index 0000000..7b7e906
--- /dev/null
+++ b/lib/cpp/src/server/TThreadPoolServer.h
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_SERVER_TTHREADPOOLSERVER_H_
+#define _THRIFT_SERVER_TTHREADPOOLSERVER_H_ 1
+
+#include <concurrency/ThreadManager.h>
+#include <server/TServer.h>
+#include <transport/TServerTransport.h>
+
+#include <boost/shared_ptr.hpp>
+
+namespace apache { namespace thrift { namespace server {
+
+using apache::thrift::concurrency::ThreadManager;
+using apache::thrift::protocol::TProtocolFactory;
+using apache::thrift::transport::TServerTransport;
+using apache::thrift::transport::TTransportFactory;
+
+class TThreadPoolServer : public TServer {
+ public:
+ class Task;
+
+ TThreadPoolServer(boost::shared_ptr<TProcessor> processor,
+ boost::shared_ptr<TServerTransport> serverTransport,
+ boost::shared_ptr<TTransportFactory> transportFactory,
+ boost::shared_ptr<TProtocolFactory> protocolFactory,
+ boost::shared_ptr<ThreadManager> threadManager);
+
+ TThreadPoolServer(boost::shared_ptr<TProcessor> processor,
+ boost::shared_ptr<TServerTransport> serverTransport,
+ boost::shared_ptr<TTransportFactory> inputTransportFactory,
+ boost::shared_ptr<TTransportFactory> outputTransportFactory,
+ boost::shared_ptr<TProtocolFactory> inputProtocolFactory,
+ boost::shared_ptr<TProtocolFactory> outputProtocolFactory,
+ boost::shared_ptr<ThreadManager> threadManager);
+
+ virtual ~TThreadPoolServer();
+
+ virtual void serve();
+
+ virtual int64_t getTimeout() const;
+
+ virtual void setTimeout(int64_t value);
+
+ virtual void stop() {
+ stop_ = true;
+ serverTransport_->interrupt();
+ }
+
+ protected:
+
+ boost::shared_ptr<ThreadManager> threadManager_;
+
+ volatile bool stop_;
+
+ volatile int64_t timeout_;
+
+};
+
+}}} // apache::thrift::server
+
+#endif // #ifndef _THRIFT_SERVER_TTHREADPOOLSERVER_H_
diff --git a/lib/cpp/src/server/TThreadedServer.cpp b/lib/cpp/src/server/TThreadedServer.cpp
new file mode 100644
index 0000000..cc30f8f
--- /dev/null
+++ b/lib/cpp/src/server/TThreadedServer.cpp
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "server/TThreadedServer.h"
+#include "transport/TTransportException.h"
+#include "concurrency/PosixThreadFactory.h"
+
+#include <string>
+#include <iostream>
+#include <pthread.h>
+#include <unistd.h>
+
+namespace apache { namespace thrift { namespace server {
+
+using boost::shared_ptr;
+using namespace std;
+using namespace apache::thrift;
+using namespace apache::thrift::protocol;
+using namespace apache::thrift::transport;
+using namespace apache::thrift::concurrency;
+
+class TThreadedServer::Task: public Runnable {
+
+public:
+
+ Task(TThreadedServer& server,
+ shared_ptr<TProcessor> processor,
+ shared_ptr<TProtocol> input,
+ shared_ptr<TProtocol> output) :
+ server_(server),
+ processor_(processor),
+ input_(input),
+ output_(output) {
+ }
+
+ ~Task() {}
+
+ void run() {
+ boost::shared_ptr<TServerEventHandler> eventHandler =
+ server_.getEventHandler();
+ if (eventHandler != NULL) {
+ eventHandler->clientBegin(input_, output_);
+ }
+ try {
+ while (processor_->process(input_, output_)) {
+ if (!input_->getTransport()->peek()) {
+ break;
+ }
+ }
+ } catch (TTransportException& ttx) {
+ string errStr = string("TThreadedServer client died: ") + ttx.what();
+ GlobalOutput(errStr.c_str());
+ } catch (TException& x) {
+ string errStr = string("TThreadedServer exception: ") + x.what();
+ GlobalOutput(errStr.c_str());
+ } catch (...) {
+ GlobalOutput("TThreadedServer uncaught exception.");
+ }
+ if (eventHandler != NULL) {
+ eventHandler->clientEnd(input_, output_);
+ }
+
+ try {
+ input_->getTransport()->close();
+ } catch (TTransportException& ttx) {
+ string errStr = string("TThreadedServer input close failed: ") + ttx.what();
+ GlobalOutput(errStr.c_str());
+ }
+ try {
+ output_->getTransport()->close();
+ } catch (TTransportException& ttx) {
+ string errStr = string("TThreadedServer output close failed: ") + ttx.what();
+ GlobalOutput(errStr.c_str());
+ }
+
+ // Remove this task from parent bookkeeping
+ {
+ Synchronized s(server_.tasksMonitor_);
+ server_.tasks_.erase(this);
+ if (server_.tasks_.empty()) {
+ server_.tasksMonitor_.notify();
+ }
+ }
+
+ }
+
+ private:
+ TThreadedServer& server_;
+ friend class TThreadedServer;
+
+ shared_ptr<TProcessor> processor_;
+ shared_ptr<TProtocol> input_;
+ shared_ptr<TProtocol> output_;
+};
+
+
+TThreadedServer::TThreadedServer(shared_ptr<TProcessor> processor,
+ shared_ptr<TServerTransport> serverTransport,
+ shared_ptr<TTransportFactory> transportFactory,
+ shared_ptr<TProtocolFactory> protocolFactory):
+ TServer(processor, serverTransport, transportFactory, protocolFactory),
+ stop_(false) {
+ threadFactory_ = shared_ptr<PosixThreadFactory>(new PosixThreadFactory());
+}
+
+TThreadedServer::TThreadedServer(boost::shared_ptr<TProcessor> processor,
+ boost::shared_ptr<TServerTransport> serverTransport,
+ boost::shared_ptr<TTransportFactory> transportFactory,
+ boost::shared_ptr<TProtocolFactory> protocolFactory,
+ boost::shared_ptr<ThreadFactory> threadFactory):
+ TServer(processor, serverTransport, transportFactory, protocolFactory),
+ threadFactory_(threadFactory),
+ stop_(false) {
+}
+
+TThreadedServer::~TThreadedServer() {}
+
+void TThreadedServer::serve() {
+
+ shared_ptr<TTransport> client;
+ shared_ptr<TTransport> inputTransport;
+ shared_ptr<TTransport> outputTransport;
+ shared_ptr<TProtocol> inputProtocol;
+ shared_ptr<TProtocol> outputProtocol;
+
+ try {
+ // Start the server listening
+ serverTransport_->listen();
+ } catch (TTransportException& ttx) {
+ string errStr = string("TThreadedServer::run() listen(): ") +ttx.what();
+ GlobalOutput(errStr.c_str());
+ return;
+ }
+
+ // Run the preServe event
+ if (eventHandler_ != NULL) {
+ eventHandler_->preServe();
+ }
+
+ while (!stop_) {
+ try {
+ client.reset();
+ inputTransport.reset();
+ outputTransport.reset();
+ inputProtocol.reset();
+ outputProtocol.reset();
+
+ // Fetch client from server
+ client = serverTransport_->accept();
+
+ // Make IO transports
+ inputTransport = inputTransportFactory_->getTransport(client);
+ outputTransport = outputTransportFactory_->getTransport(client);
+ inputProtocol = inputProtocolFactory_->getProtocol(inputTransport);
+ outputProtocol = outputProtocolFactory_->getProtocol(outputTransport);
+
+ TThreadedServer::Task* task = new TThreadedServer::Task(*this,
+ processor_,
+ inputProtocol,
+ outputProtocol);
+
+ // Create a task
+ shared_ptr<Runnable> runnable =
+ shared_ptr<Runnable>(task);
+
+ // Create a thread for this task
+ shared_ptr<Thread> thread =
+ shared_ptr<Thread>(threadFactory_->newThread(runnable));
+
+ // Insert thread into the set of threads
+ {
+ Synchronized s(tasksMonitor_);
+ tasks_.insert(task);
+ }
+
+ // Start the thread!
+ thread->start();
+
+ } catch (TTransportException& ttx) {
+ if (inputTransport != NULL) { inputTransport->close(); }
+ if (outputTransport != NULL) { outputTransport->close(); }
+ if (client != NULL) { client->close(); }
+ if (!stop_ || ttx.getType() != TTransportException::INTERRUPTED) {
+ string errStr = string("TThreadedServer: TServerTransport died on accept: ") + ttx.what();
+ GlobalOutput(errStr.c_str());
+ }
+ continue;
+ } catch (TException& tx) {
+ if (inputTransport != NULL) { inputTransport->close(); }
+ if (outputTransport != NULL) { outputTransport->close(); }
+ if (client != NULL) { client->close(); }
+ string errStr = string("TThreadedServer: Caught TException: ") + tx.what();
+ GlobalOutput(errStr.c_str());
+ continue;
+ } catch (string s) {
+ if (inputTransport != NULL) { inputTransport->close(); }
+ if (outputTransport != NULL) { outputTransport->close(); }
+ if (client != NULL) { client->close(); }
+ string errStr = "TThreadedServer: Unknown exception: " + s;
+ GlobalOutput(errStr.c_str());
+ break;
+ }
+ }
+
+ // If stopped manually, make sure to close server transport
+ if (stop_) {
+ try {
+ serverTransport_->close();
+ } catch (TException &tx) {
+ string errStr = string("TThreadedServer: Exception shutting down: ") + tx.what();
+ GlobalOutput(errStr.c_str());
+ }
+ try {
+ Synchronized s(tasksMonitor_);
+ while (!tasks_.empty()) {
+ tasksMonitor_.wait();
+ }
+ } catch (TException &tx) {
+ string errStr = string("TThreadedServer: Exception joining workers: ") + tx.what();
+ GlobalOutput(errStr.c_str());
+ }
+ stop_ = false;
+ }
+
+}
+
+}}} // apache::thrift::server
diff --git a/lib/cpp/src/server/TThreadedServer.h b/lib/cpp/src/server/TThreadedServer.h
new file mode 100644
index 0000000..4d0811a
--- /dev/null
+++ b/lib/cpp/src/server/TThreadedServer.h
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_SERVER_TTHREADEDSERVER_H_
+#define _THRIFT_SERVER_TTHREADEDSERVER_H_ 1
+
#include <server/TServer.h>
#include <transport/TServerTransport.h>
#include <concurrency/Monitor.h>
#include <concurrency/Thread.h>

#include <set>

#include <boost/shared_ptr.hpp>
+
+namespace apache { namespace thrift { namespace server {
+
+using apache::thrift::TProcessor;
+using apache::thrift::transport::TServerTransport;
+using apache::thrift::transport::TTransportFactory;
+using apache::thrift::concurrency::Monitor;
+using apache::thrift::concurrency::ThreadFactory;
+
+class TThreadedServer : public TServer {
+
+ public:
+ class Task;
+
+ TThreadedServer(boost::shared_ptr<TProcessor> processor,
+ boost::shared_ptr<TServerTransport> serverTransport,
+ boost::shared_ptr<TTransportFactory> transportFactory,
+ boost::shared_ptr<TProtocolFactory> protocolFactory);
+
+ TThreadedServer(boost::shared_ptr<TProcessor> processor,
+ boost::shared_ptr<TServerTransport> serverTransport,
+ boost::shared_ptr<TTransportFactory> transportFactory,
+ boost::shared_ptr<TProtocolFactory> protocolFactory,
+ boost::shared_ptr<ThreadFactory> threadFactory);
+
+ virtual ~TThreadedServer();
+
+ virtual void serve();
+
+ void stop() {
+ stop_ = true;
+ serverTransport_->interrupt();
+ }
+
+ protected:
+ boost::shared_ptr<ThreadFactory> threadFactory_;
+ volatile bool stop_;
+
+ Monitor tasksMonitor_;
+ std::set<Task*> tasks_;
+
+};
+
+}}} // apache::thrift::server
+
+#endif // #ifndef _THRIFT_SERVER_TTHREADEDSERVER_H_
diff --git a/lib/cpp/src/transport/TBufferTransports.cpp b/lib/cpp/src/transport/TBufferTransports.cpp
new file mode 100644
index 0000000..7a7e5e9
--- /dev/null
+++ b/lib/cpp/src/transport/TBufferTransports.cpp
@@ -0,0 +1,370 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <cassert>
+#include <algorithm>
+
+#include <transport/TBufferTransports.h>
+
+using std::string;
+
+namespace apache { namespace thrift { namespace transport {
+
+
+uint32_t TBufferedTransport::readSlow(uint8_t* buf, uint32_t len) {
+ uint32_t want = len;
+ uint32_t have = rBound_ - rBase_;
+
+ // We should only take the slow path if we can't satisfy the read
+ // with the data already in the buffer.
+ assert(have < want);
+
+ // Copy out whatever we have.
+ if (have > 0) {
+ memcpy(buf, rBase_, have);
+ want -= have;
+ buf += have;
+ }
+ // Get more from underlying transport up to buffer size.
+ // Note that this makes a lot of sense if len < rBufSize_
+ // and almost no sense otherwise. TODO(dreiss): Fix that
+ // case (possibly including some readv hotness).
+ setReadBuffer(rBuf_.get(), transport_->read(rBuf_.get(), rBufSize_));
+
+ // Hand over whatever we have.
+ uint32_t give = std::min(want, static_cast<uint32_t>(rBound_ - rBase_));
+ memcpy(buf, rBase_, give);
+ rBase_ += give;
+ want -= give;
+
+ return (len - want);
+}
+
+void TBufferedTransport::writeSlow(const uint8_t* buf, uint32_t len) {
+ uint32_t have_bytes = wBase_ - wBuf_.get();
+ uint32_t space = wBound_ - wBase_;
+ // We should only take the slow path if we can't accomodate the write
+ // with the free space already in the buffer.
+ assert(wBound_ - wBase_ < static_cast<ptrdiff_t>(len));
+
+ // Now here's the tricky question: should we copy data from buf into our
+ // internal buffer and write it from there, or should we just write out
+ // the current internal buffer in one syscall and write out buf in another.
+ // If our currently buffered data plus buf is at least double our buffer
+ // size, we will have to do two syscalls no matter what (except in the
+ // degenerate case when our buffer is empty), so there is no use copying.
+ // Otherwise, there is sort of a sliding scale. If we have N-1 bytes
+ // buffered and need to write 2, it would be crazy to do two syscalls.
+ // On the other hand, if we have 2 bytes buffered and are writing 2N-3,
+ // we can save a syscall in the short term by loading up our buffer, writing
+ // it out, and copying the rest of the bytes into our buffer. Of course,
+ // if we get another 2-byte write, we haven't saved any syscalls at all,
+ // and have just copied nearly 2N bytes for nothing. Finding a perfect
+ // policy would require predicting the size of future writes, so we're just
+ // going to always eschew syscalls if we have less than 2N bytes to write.
+
+ // The case where we have to do two syscalls.
+ // This case also covers the case where the buffer is empty,
+ // but it is clearer (I think) to think of it as two separate cases.
+ if ((have_bytes + len >= 2*wBufSize_) || (have_bytes == 0)) {
+ // TODO(dreiss): writev
+ if (have_bytes > 0) {
+ transport_->write(wBuf_.get(), have_bytes);
+ }
+ transport_->write(buf, len);
+ wBase_ = wBuf_.get();
+ return;
+ }
+
+ // Fill up our internal buffer for a write.
+ memcpy(wBase_, buf, space);
+ buf += space;
+ len -= space;
+ transport_->write(wBuf_.get(), wBufSize_);
+
+ // Copy the rest into our buffer.
+ assert(len < wBufSize_);
+ memcpy(wBuf_.get(), buf, len);
+ wBase_ = wBuf_.get() + len;
+ return;
+}
+
+const uint8_t* TBufferedTransport::borrowSlow(uint8_t* buf, uint32_t* len) {
+ // If the request is bigger than our buffer, we are hosed.
+ if (*len > rBufSize_) {
+ return NULL;
+ }
+
+ // The number of bytes of data we have already.
+ uint32_t have = rBound_ - rBase_;
+ // The number of additional bytes we need from the underlying transport.
+ int32_t need = *len - have;
+ // The space from the start of the buffer to the end of our data.
+ uint32_t offset = rBound_ - rBuf_.get();
+ assert(need > 0);
+
+ // If we have less than half our buffer space available, shift the data
+ // we have down to the start. If the borrow is big compared to our buffer,
+ // this could be kind of a waste, but if the borrow is small, it frees up
+ // space at the end of our buffer to do a bigger single read from the
+ // underlying transport. Also, if our needs extend past the end of the
+ // buffer, we have to do a copy no matter what.
+ if ((offset > rBufSize_/2) || (offset + need > rBufSize_)) {
+ memmove(rBuf_.get(), rBase_, have);
+ setReadBuffer(rBuf_.get(), have);
+ }
+
+ // First try to fill up the buffer.
+ uint32_t got = transport_->read(rBound_, rBufSize_ - have);
+ rBound_ += got;
+ need -= got;
+
+ // If that fails, readAll until we get what we need.
+ if (need > 0) {
+ rBound_ += transport_->readAll(rBound_, need);
+ }
+
+ *len = rBound_ - rBase_;
+ return rBase_;
+}
+
+void TBufferedTransport::flush() {
+ // Write out any data waiting in the write buffer.
+ uint32_t have_bytes = wBase_ - wBuf_.get();
+ if (have_bytes > 0) {
+ // Note that we reset wBase_ prior to the underlying write
+ // to ensure we're in a sane state (i.e. internal buffer cleaned)
+ // if the underlying write throws up an exception
+ wBase_ = wBuf_.get();
+ transport_->write(wBuf_.get(), have_bytes);
+ }
+
+ // Flush the underlying transport.
+ transport_->flush();
+}
+
+
+uint32_t TFramedTransport::readSlow(uint8_t* buf, uint32_t len) {
+ uint32_t want = len;
+ uint32_t have = rBound_ - rBase_;
+
+ // We should only take the slow path if we can't satisfy the read
+ // with the data already in the buffer.
+ assert(have < want);
+
+ // Copy out whatever we have.
+ if (have > 0) {
+ memcpy(buf, rBase_, have);
+ want -= have;
+ buf += have;
+ }
+
+ // Read another frame.
+ readFrame();
+
+ // TODO(dreiss): Should we warn when reads cross frames?
+
+ // Hand over whatever we have.
+ uint32_t give = std::min(want, static_cast<uint32_t>(rBound_ - rBase_));
+ memcpy(buf, rBase_, give);
+ rBase_ += give;
+ want -= give;
+
+ return (len - want);
+}
+
+void TFramedTransport::readFrame() {
+ // TODO(dreiss): Think about using readv here, even though it would
+ // result in (gasp) read-ahead.
+
+ // Read the size of the next frame.
+ int32_t sz;
+ transport_->readAll((uint8_t*)&sz, sizeof(sz));
+ sz = ntohl(sz);
+
+ if (sz < 0) {
+ throw TTransportException("Frame size has negative value");
+ }
+
+ // Read the frame payload, and reset markers.
+ if (sz > static_cast<int32_t>(rBufSize_)) {
+ rBuf_.reset(new uint8_t[sz]);
+ rBufSize_ = sz;
+ }
+ transport_->readAll(rBuf_.get(), sz);
+ setReadBuffer(rBuf_.get(), sz);
+}
+
+void TFramedTransport::writeSlow(const uint8_t* buf, uint32_t len) {
+ // Double buffer size until sufficient.
+ uint32_t have = wBase_ - wBuf_.get();
+ while (wBufSize_ < len + have) {
+ wBufSize_ *= 2;
+ }
+
+ // TODO(dreiss): Consider modifying this class to use malloc/free
+ // so we can use realloc here.
+
+ // Allocate new buffer.
+ uint8_t* new_buf = new uint8_t[wBufSize_];
+
+ // Copy the old buffer to the new one.
+ memcpy(new_buf, wBuf_.get(), have);
+
+ // Now point buf to the new one.
+ wBuf_.reset(new_buf);
+ wBase_ = wBuf_.get() + have;
+ wBound_ = wBuf_.get() + wBufSize_;
+
+ // Copy the data into the new buffer.
+ memcpy(wBase_, buf, len);
+ wBase_ += len;
+}
+
+void TFramedTransport::flush() {
+ int32_t sz_hbo, sz_nbo;
+ assert(wBufSize_ > sizeof(sz_nbo));
+
+ // Slip the frame size into the start of the buffer.
+ sz_hbo = wBase_ - (wBuf_.get() + sizeof(sz_nbo));
+ sz_nbo = (int32_t)htonl((uint32_t)(sz_hbo));
+ memcpy(wBuf_.get(), (uint8_t*)&sz_nbo, sizeof(sz_nbo));
+
+ if (sz_hbo > 0) {
+ // Note that we reset wBase_ (with a pad for the frame size)
+ // prior to the underlying write to ensure we're in a sane state
+ // (i.e. internal buffer cleaned) if the underlying write throws
+ // up an exception
+ wBase_ = wBuf_.get() + sizeof(sz_nbo);
+
+ // Write size and frame body.
+ transport_->write(wBuf_.get(), sizeof(sz_nbo)+sz_hbo);
+ }
+
+ // Flush the underlying transport.
+ transport_->flush();
+}
+
+const uint8_t* TFramedTransport::borrowSlow(uint8_t* buf, uint32_t* len) {
+ // Don't try to be clever with shifting buffers.
+ // If the fast path failed let the protocol use its slow path.
+ // Besides, who is going to try to borrow across messages?
+ return NULL;
+}
+
+
+void TMemoryBuffer::computeRead(uint32_t len, uint8_t** out_start, uint32_t* out_give) {
+ // Correct rBound_ so we can use the fast path in the future.
+ rBound_ = wBase_;
+
+ // Decide how much to give.
+ uint32_t give = std::min(len, available_read());
+
+ *out_start = rBase_;
+ *out_give = give;
+
+ // Preincrement rBase_ so the caller doesn't have to.
+ rBase_ += give;
+}
+
+uint32_t TMemoryBuffer::readSlow(uint8_t* buf, uint32_t len) {
+ uint8_t* start;
+ uint32_t give;
+ computeRead(len, &start, &give);
+
+ // Copy into the provided buffer.
+ memcpy(buf, start, give);
+
+ return give;
+}
+
+uint32_t TMemoryBuffer::readAppendToString(std::string& str, uint32_t len) {
+ // Don't get some stupid assertion failure.
+ if (buffer_ == NULL) {
+ return 0;
+ }
+
+ uint8_t* start;
+ uint32_t give;
+ computeRead(len, &start, &give);
+
+ // Append to the provided string.
+ str.append((char*)start, give);
+
+ return give;
+}
+
+void TMemoryBuffer::ensureCanWrite(uint32_t len) {
+ // Check available space
+ uint32_t avail = available_write();
+ if (len <= avail) {
+ return;
+ }
+
+ if (!owner_) {
+ throw TTransportException("Insufficient space in external MemoryBuffer");
+ }
+
+ // Grow the buffer as necessary.
+ while (len > avail) {
+ bufferSize_ *= 2;
+ wBound_ = buffer_ + bufferSize_;
+ avail = available_write();
+ }
+
+ // Allocate into a new pointer so we don't bork ours if it fails.
+ void* new_buffer = std::realloc(buffer_, bufferSize_);
+ if (new_buffer == NULL) {
+ throw TTransportException("Out of memory.");
+ }
+
+ ptrdiff_t offset = (uint8_t*)new_buffer - buffer_;
+ buffer_ += offset;
+ rBase_ += offset;
+ rBound_ += offset;
+ wBase_ += offset;
+ wBound_ += offset;
+}
+
+void TMemoryBuffer::writeSlow(const uint8_t* buf, uint32_t len) {
+ ensureCanWrite(len);
+
+ // Copy into the buffer and increment wBase_.
+ memcpy(wBase_, buf, len);
+ wBase_ += len;
+}
+
+void TMemoryBuffer::wroteBytes(uint32_t len) {
+ uint32_t avail = available_write();
+ if (len > avail) {
+ throw TTransportException("Client wrote more bytes than size of buffer.");
+ }
+ wBase_ += len;
+}
+
+const uint8_t* TMemoryBuffer::borrowSlow(uint8_t* buf, uint32_t* len) {
+ rBound_ = wBase_;
+ if (available_read() >= *len) {
+ *len = available_read();
+ return rBase_;
+ }
+ return NULL;
+}
+
+}}} // apache::thrift::transport
diff --git a/lib/cpp/src/transport/TBufferTransports.h b/lib/cpp/src/transport/TBufferTransports.h
new file mode 100644
index 0000000..1908205
--- /dev/null
+++ b/lib/cpp/src/transport/TBufferTransports.h
@@ -0,0 +1,667 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_TRANSPORT_TBUFFERTRANSPORTS_H_
+#define _THRIFT_TRANSPORT_TBUFFERTRANSPORTS_H_ 1
+
+#include <cstring>
+#include "boost/scoped_array.hpp"
+
+#include <transport/TTransport.h>
+
+// Branch-prediction hints. GCC-compatible compilers get __builtin_expect;
+// elsewhere the macros expand to the bare expression.
+#if defined(__GNUC__)
+#define TDB_LIKELY(val) (__builtin_expect((val), 1))
+#define TDB_UNLIKELY(val) (__builtin_expect((val), 0))
+#else
+#define TDB_LIKELY(val) (val)
+#define TDB_UNLIKELY(val) (val)
+#endif
+
+namespace apache { namespace thrift { namespace transport {
+
+
+/**
+ * Base class for all transports that use read/write buffers for performance.
+ *
+ * TBufferBase is designed to implement the fast-path "memcpy" style
+ * operations that work in the common case. It does so with small and
+ * (eventually) nonvirtual, inlinable methods. TBufferBase is an abstract
+ * class. Subclasses are expected to define the "slow path" operations
+ * that have to be done when the buffers are full or empty.
+ *
+ */
+class TBufferBase : public TTransport {
+
+ public:
+
+  /**
+   * Fast-path read.
+   *
+   * When we have enough data buffered to fulfill the read, we can satisfy it
+   * with a single memcpy, then adjust our internal pointers. If the buffer
+   * is empty, we call out to our slow path, implemented by a subclass.
+   * This method is meant to eventually be nonvirtual and inlinable.
+   *
+   * Availability is checked by comparing lengths (rBound_ - rBase_), as
+   * borrow() and consume() do, rather than by forming rBase_ + len first.
+   * A pointer past the end of the buffer, or NULL + len before any buffer is
+   * set, is undefined behavior even if it is never dereferenced.
+   */
+  uint32_t read(uint8_t* buf, uint32_t len) {
+    if (TDB_LIKELY(static_cast<ptrdiff_t>(len) <= rBound_ - rBase_)) {
+      std::memcpy(buf, rBase_, len);
+      rBase_ += len;
+      return len;
+    }
+    return readSlow(buf, len);
+  }
+
+  /**
+   * Fast-path write.
+   *
+   * When we have enough empty space in our buffer to accommodate the write, we
+   * can satisfy it with a single memcpy, then adjust our internal pointers.
+   * If the buffer is full, we call out to our slow path, implemented by a
+   * subclass. This method is meant to eventually be nonvirtual and
+   * inlinable. Space is checked by length comparison for the same reason
+   * as in read().
+   */
+  void write(const uint8_t* buf, uint32_t len) {
+    if (TDB_LIKELY(static_cast<ptrdiff_t>(len) <= wBound_ - wBase_)) {
+      std::memcpy(wBase_, buf, len);
+      wBase_ += len;
+      return;
+    }
+    writeSlow(buf, len);
+  }
+
+  /**
+   * Fast-path borrow. A lot like the fast-path read.
+   *
+   * If at least *len bytes are buffered, sets *len to the number of buffered
+   * bytes and returns a pointer to them without consuming anything;
+   * otherwise defers to borrowSlow().
+   */
+  const uint8_t* borrow(uint8_t* buf, uint32_t* len) {
+    if (TDB_LIKELY(static_cast<ptrdiff_t>(*len) <= rBound_ - rBase_)) {
+      // With strict aliasing, writing to len shouldn't force us to
+      // refetch rBase_ from memory. TODO(dreiss): Verify this.
+      *len = rBound_ - rBase_;
+      return rBase_;
+    }
+    return borrowSlow(buf, len);
+  }
+
+  /**
+   * Consume doesn't require a slow path.
+   *
+   * @throws TTransportException(BAD_ARGS) if len exceeds the buffered bytes.
+   */
+  void consume(uint32_t len) {
+    if (TDB_LIKELY(static_cast<ptrdiff_t>(len) <= rBound_ - rBase_)) {
+      rBase_ += len;
+    } else {
+      throw TTransportException(TTransportException::BAD_ARGS,
+                                "consume did not follow a borrow.");
+    }
+  }
+
+
+ protected:
+
+  /// Slow path read.
+  virtual uint32_t readSlow(uint8_t* buf, uint32_t len) = 0;
+
+  /// Slow path write.
+  virtual void writeSlow(const uint8_t* buf, uint32_t len) = 0;
+
+  /**
+   * Slow path borrow.
+   *
+   * POSTCONDITION: return == NULL || rBound_ - rBase_ >= *len
+   */
+  virtual const uint8_t* borrowSlow(uint8_t* buf, uint32_t* len) = 0;
+
+  /**
+   * Trivial constructor.
+   *
+   * Initialize pointers safely. Constructing is not a very
+   * performance-sensitive operation, so it is okay to just leave it to
+   * the concrete class to set up pointers correctly.
+   */
+  TBufferBase()
+    : rBase_(NULL)
+    , rBound_(NULL)
+    , wBase_(NULL)
+    , wBound_(NULL)
+  {}
+
+  /// Convenience mutator for setting the read buffer.
+  void setReadBuffer(uint8_t* buf, uint32_t len) {
+    rBase_ = buf;
+    rBound_ = buf+len;
+  }
+
+  /// Convenience mutator for setting the write buffer.
+  void setWriteBuffer(uint8_t* buf, uint32_t len) {
+    wBase_ = buf;
+    wBound_ = buf+len;
+  }
+
+  virtual ~TBufferBase() {}
+
+  /// Reads begin here.
+  uint8_t* rBase_;
+  /// Reads may extend to just before here.
+  uint8_t* rBound_;
+
+  /// Writes begin here.
+  uint8_t* wBase_;
+  /// Writes may extend to just before here.
+  uint8_t* wBound_;
+};
+
+
+/**
+ * Base class for transports that wrap an underlying transport and add their own read and write buffers.
+ */
+class TUnderlyingTransport : public TBufferBase {
+ public:
+  /// Size of each buffer when the caller does not choose one.
+  static const int DEFAULT_BUFFER_SIZE = 512;
+
+  /// Data is available if our read buffer is non-empty or the wrapped
+  /// transport reports some.
+  virtual bool peek() {
+    if (rBase_ < rBound_) {
+      return true;
+    }
+    return transport_->peek();
+  }
+
+  /// Opening is delegated to the wrapped transport.
+  void open() {
+    transport_->open();
+  }
+
+  /// Open state is whatever the wrapped transport reports.
+  bool isOpen() {
+    return transport_->isOpen();
+  }
+
+  /// Pushes out buffered writes, then closes the wrapped transport.
+  void close() {
+    flush();
+    transport_->close();
+  }
+
+  /// Returns the transport this one wraps.
+  boost::shared_ptr<TTransport> getUnderlyingTransport() {
+    return transport_;
+  }
+
+ protected:
+  /// The wrapped transport that actually moves the bytes.
+  boost::shared_ptr<TTransport> transport_;
+
+  /// Capacities of rBuf_ and wBuf_, in bytes.
+  uint32_t rBufSize_;
+  uint32_t wBufSize_;
+  boost::scoped_array<uint8_t> rBuf_;
+  boost::scoped_array<uint8_t> wBuf_;
+
+  /// Allocates read and write buffers of sz bytes each.
+  TUnderlyingTransport(boost::shared_ptr<TTransport> transport, uint32_t sz)
+    : transport_(transport)
+    , rBufSize_(sz)
+    , wBufSize_(sz)
+    , rBuf_(new uint8_t[sz])
+    , wBuf_(new uint8_t[sz]) {}
+
+  /// Allocates read and write buffers of DEFAULT_BUFFER_SIZE bytes each.
+  TUnderlyingTransport(boost::shared_ptr<TTransport> transport)
+    : transport_(transport)
+    , rBufSize_(DEFAULT_BUFFER_SIZE)
+    , wBufSize_(DEFAULT_BUFFER_SIZE)
+    , rBuf_(new uint8_t[DEFAULT_BUFFER_SIZE])
+    , wBuf_(new uint8_t[DEFAULT_BUFFER_SIZE]) {}
+
+  /// Allocates a read buffer of rsz bytes and a write buffer of wsz bytes.
+  TUnderlyingTransport(boost::shared_ptr<TTransport> transport, uint32_t rsz, uint32_t wsz)
+    : transport_(transport)
+    , rBufSize_(rsz)
+    , wBufSize_(wsz)
+    , rBuf_(new uint8_t[rsz])
+    , wBuf_(new uint8_t[wsz]) {}
+};
+
+/**
+ * Buffered transport. For reads it will read more data than is requested
+ * and will serve future data out of a local buffer. For writes, data is
+ * stored to an in memory buffer before being written out.
+ *
+ */
+class TBufferedTransport : public TUnderlyingTransport {
+ public:
+
+  /// Use default buffer sizes.
+  TBufferedTransport(boost::shared_ptr<TTransport> transport)
+    : TUnderlyingTransport(transport) {
+    initPointers();
+  }
+
+  /// Use the same specified size for both buffers.
+  TBufferedTransport(boost::shared_ptr<TTransport> transport, uint32_t sz)
+    : TUnderlyingTransport(transport, sz) {
+    initPointers();
+  }
+
+  /// Use specified read and write buffer sizes.
+  TBufferedTransport(boost::shared_ptr<TTransport> transport, uint32_t rsz, uint32_t wsz)
+    : TUnderlyingTransport(transport, rsz, wsz) {
+    initPointers();
+  }
+
+  /**
+   * If the read buffer is empty, refills it with a single read from the
+   * wrapped transport; then reports whether any bytes are buffered.
+   * (shigin: see THRIFT-96 discussion)
+   */
+  virtual bool peek() {
+    if (rBase_ == rBound_) {
+      const uint32_t got = transport_->read(rBuf_.get(), rBufSize_);
+      setReadBuffer(rBuf_.get(), got);
+    }
+    return rBase_ < rBound_;
+  }
+
+  virtual uint32_t readSlow(uint8_t* buf, uint32_t len);
+
+  virtual void writeSlow(const uint8_t* buf, uint32_t len);
+
+  void flush();
+
+
+  /**
+   * The following behavior is currently implemented by TBufferedTransport,
+   * but that may change in a future version:
+   * 1/ If len is at most rBufSize_, borrow will never return NULL.
+   *    Depending on the underlying transport, it could throw an exception
+   *    or hang forever.
+   * 2/ Some borrow requests may copy bytes internally. However,
+   *    if len is at most rBufSize_/2, none of the copied bytes
+   *    will ever have to be copied again. For optimal performance,
+   *    stay under this limit.
+   */
+  virtual const uint8_t* borrowSlow(uint8_t* buf, uint32_t* len);
+
+ protected:
+  /// Starts with an empty read window and the whole write buffer available.
+  /// The write size never changes after this.
+  void initPointers() {
+    setReadBuffer(rBuf_.get(), 0);
+    setWriteBuffer(wBuf_.get(), wBufSize_);
+  }
+};
+
+
+/**
+ * Wraps a transport into a buffered one.
+ *
+ */
+class TBufferedTransportFactory : public TTransportFactory {
+ public:
+  TBufferedTransportFactory() {}
+
+  virtual ~TBufferedTransportFactory() {}
+
+  /**
+   * Returns a new TBufferedTransport (default buffer sizes) wrapping trans.
+   */
+  virtual boost::shared_ptr<TTransport> getTransport(boost::shared_ptr<TTransport> trans) {
+    boost::shared_ptr<TTransport> buffered(new TBufferedTransport(trans));
+    return buffered;
+  }
+
+};
+
+
+/**
+ * Framed transport. All writes go into an in-memory buffer until flush is
+ * called, at which point the transport writes the length of the entire
+ * binary chunk followed by the data payload. This allows the receiver on the
+ * other end to always do fixed-length reads.
+ *
+ */
+class TFramedTransport : public TUnderlyingTransport {
+ public:
+
+  /// Use default buffer sizes.
+  TFramedTransport(boost::shared_ptr<TTransport> transport)
+    : TUnderlyingTransport(transport) {
+    initPointers();
+  }
+
+  /// Use the same specified size for both buffers.
+  TFramedTransport(boost::shared_ptr<TTransport> transport, uint32_t sz)
+    : TUnderlyingTransport(transport, sz) {
+    initPointers();
+  }
+
+  virtual uint32_t readSlow(uint8_t* buf, uint32_t len);
+
+  virtual void writeSlow(const uint8_t* buf, uint32_t len);
+
+  virtual void flush();
+
+  const uint8_t* borrowSlow(uint8_t* buf, uint32_t* len);
+
+ protected:
+  /**
+   * Reads a frame of input from the underlying stream.
+   */
+  void readFrame();
+
+  /// Starts with no readable data and a write buffer whose first four bytes
+  /// are a zero placeholder, so the frame size can be inserted later.
+  void initPointers() {
+    setReadBuffer(NULL, 0);
+    setWriteBuffer(wBuf_.get(), wBufSize_);
+
+    const int32_t sizePlaceholder = 0;
+    this->write(reinterpret_cast<const uint8_t*>(&sizePlaceholder),
+                sizeof(sizePlaceholder));
+  }
+};
+
+/**
+ * Wraps a transport into a framed one.
+ *
+ */
+class TFramedTransportFactory : public TTransportFactory {
+ public:
+  TFramedTransportFactory() {}
+
+  virtual ~TFramedTransportFactory() {}
+
+  /**
+   * Returns a new TFramedTransport (default buffer sizes) wrapping trans.
+   */
+  virtual boost::shared_ptr<TTransport> getTransport(boost::shared_ptr<TTransport> trans) {
+    boost::shared_ptr<TTransport> framed(new TFramedTransport(trans));
+    return framed;
+  }
+
+};
+
+
+/**
+ * A memory buffer is a transport that simply reads from and writes to an
+ * in-memory buffer. Anytime you call write on it, the data is simply placed
+ * into a buffer, and anytime you call read, data is read from that buffer.
+ *
+ * The buffers are allocated using the C functions malloc and realloc, and the
+ * size doubles as necessary. We've considered using a scoped array type
+ * (e.g. boost::scoped_array) for this storage instead.
+ *
+ */
+class TMemoryBuffer : public TBufferBase {
+ private:
+
+  // Common initialization done by all constructors.
+  //
+  // When buf is NULL and size is non-zero, `size` bytes are malloc()ed
+  // (only valid for an owning buffer). Afterwards the readable region is
+  // [buffer_, buffer_ + wPos) and the writable region is
+  // [buffer_ + wPos, buffer_ + size).
+  void initCommon(uint8_t* buf, uint32_t size, bool owner, uint32_t wPos) {
+    if (buf == NULL && size != 0) {
+      assert(owner);
+      buf = (uint8_t*)std::malloc(size);
+      if (buf == NULL) {
+        throw TTransportException("Out of memory");
+      }
+    }
+
+    buffer_ = buf;
+    bufferSize_ = size;
+
+    rBase_ = buffer_;
+    rBound_ = buffer_ + wPos;
+    // TODO(dreiss): Investigate NULL-ing this if !owner.
+    wBase_ = buffer_ + wPos;
+    wBound_ = buffer_ + bufferSize_;
+
+    owner_ = owner;
+
+    // rBound_ is really an artifact. In principle, it should always be
+    // equal to wBase_. We update it in a few places (computeRead, etc.).
+  }
+
+ public:
+  /// Capacity used by the default constructor and by resetBuffer(true).
+  static const uint32_t defaultSize = 1024;
+
+  /**
+   * This enum specifies how a TMemoryBuffer should treat
+   * memory passed to it via constructors or resetBuffer.
+   *
+   * OBSERVE:
+   *   TMemoryBuffer will simply store a pointer to the memory.
+   *   It is the caller's responsibility to ensure that the pointer
+   *   remains valid for the lifetime of the TMemoryBuffer,
+   *   and that it is properly cleaned up.
+   *   Note that no data can be written to observed buffers.
+   *
+   * COPY:
+   *   TMemoryBuffer will make an internal copy of the buffer.
+   *   The caller has no responsibilities.
+   *
+   * TAKE_OWNERSHIP:
+   *   TMemoryBuffer will become the "owner" of the buffer,
+   *   and will be responsible for freeing it.
+   *   The memory must have been allocated with malloc.
+   */
+  enum MemoryPolicy
+  { OBSERVE = 1
+  , COPY = 2
+  , TAKE_OWNERSHIP = 3
+  };
+
+  /**
+   * Construct a TMemoryBuffer with a default-sized buffer,
+   * owned by the TMemoryBuffer object.
+   */
+  TMemoryBuffer() {
+    initCommon(NULL, defaultSize, true, 0);
+  }
+
+  /**
+   * Construct a TMemoryBuffer with a buffer of a specified size,
+   * owned by the TMemoryBuffer object.
+   *
+   * @param sz The initial size of the buffer.
+   */
+  TMemoryBuffer(uint32_t sz) {
+    initCommon(NULL, sz, true, 0);
+  }
+
+  /**
+   * Construct a TMemoryBuffer with buf as its initial contents.
+   *
+   * @param buf    The initial contents of the buffer.
+   *               Note that, while buf is a non-const pointer,
+   *               TMemoryBuffer will not write to it if policy == OBSERVE,
+   *               so it is safe to const_cast<uint8_t*>(whatever).
+   * @param sz     The size of @c buf.
+   * @param policy See @link MemoryPolicy @endlink .
+   * @throws TTransportException(BAD_ARGS) for a NULL buf with non-zero sz or
+   *         an unknown policy.
+   */
+  TMemoryBuffer(uint8_t* buf, uint32_t sz, MemoryPolicy policy = OBSERVE) {
+    if (buf == NULL && sz != 0) {
+      throw TTransportException(TTransportException::BAD_ARGS,
+                                "TMemoryBuffer given null buffer with non-zero size.");
+    }
+
+    switch (policy) {
+      case OBSERVE:
+      case TAKE_OWNERSHIP:
+        initCommon(buf, sz, policy == TAKE_OWNERSHIP, sz);
+        break;
+      case COPY:
+        initCommon(NULL, sz, true, 0);
+        this->write(buf, sz);
+        break;
+      default:
+        throw TTransportException(TTransportException::BAD_ARGS,
+                                  "Invalid MemoryPolicy for TMemoryBuffer");
+    }
+  }
+
+  /// Frees the storage if (and only if) this object owns it.
+  ~TMemoryBuffer() {
+    if (owner_) {
+      std::free(buffer_);
+    }
+  }
+
+  /// A memory buffer is always open.
+  bool isOpen() {
+    return true;
+  }
+
+  /// True when there are written bytes not yet read.
+  bool peek() {
+    return (rBase_ < wBase_);
+  }
+
+  void open() {}
+
+  void close() {}
+
+  /// Exposes the unread bytes without copying them.
+  // TODO(dreiss): Make bufPtr const.
+  void getBuffer(uint8_t** bufPtr, uint32_t* sz) {
+    *bufPtr = rBase_;
+    *sz = wBase_ - rBase_;
+  }
+
+  /// Returns a copy of the unread bytes ("" if no buffer was allocated).
+  std::string getBufferAsString() {
+    if (buffer_ == NULL) {
+      return "";
+    }
+    uint8_t* buf;
+    uint32_t sz;
+    getBuffer(&buf, &sz);
+    return std::string((char*)buf, (std::string::size_type)sz);
+  }
+
+  /// Appends the unread bytes to str without consuming them.
+  void appendBufferToString(std::string& str) {
+    if (buffer_ == NULL) {
+      return;
+    }
+    uint8_t* buf;
+    uint32_t sz;
+    getBuffer(&buf, &sz);
+    str.append((char*)buf, sz);
+  }
+
+  /**
+   * Discards all buffered data and rewinds the read and write cursors to the
+   * start of the storage.
+   *
+   * @param reset_capacity If true, also realloc() the storage back to
+   *        defaultSize bytes. Only an owning buffer may do this; realloc() on
+   *        memory this object did not allocate is undefined behavior, so a
+   *        non-owning buffer throws TTransportException(BAD_ARGS).
+   *
+   * A non-owning buffer becomes unwritable afterwards (zero-length write
+   * window), since it isn't safe to write into memory we don't own.
+   */
+  void resetBuffer(bool reset_capacity = false) {
+    if (reset_capacity) {
+      // An assert() alone vanishes under NDEBUG, so check explicitly.
+      if (!owner_) {
+        throw TTransportException(TTransportException::BAD_ARGS,
+                                  "TMemoryBuffer cannot reset the capacity of a buffer it does not own.");
+      }
+
+      void* new_buffer = std::realloc(buffer_, defaultSize);
+
+      if (new_buffer == NULL) {
+        throw TTransportException("Out of memory.");
+      }
+
+      buffer_ = (uint8_t*) new_buffer;
+      bufferSize_ = defaultSize;
+
+      wBound_ = buffer_ + bufferSize_;
+    }
+
+    rBase_ = buffer_;
+    rBound_ = buffer_;
+    wBase_ = buffer_;
+    // It isn't safe to write into a buffer we don't own.
+    if (!owner_) {
+      wBound_ = wBase_;
+      bufferSize_ = 0;
+    }
+  }
+
+  /// See constructor documentation.
+  void resetBuffer(uint8_t* buf, uint32_t sz, MemoryPolicy policy = OBSERVE) {
+    // Use a variant of the copy-and-swap trick for assignment operators.
+    // This is sub-optimal in terms of performance for two reasons:
+    //   1/ The constructing and swapping of the (small) values
+    //      in the temporary object takes some time, and is not necessary.
+    //   2/ If policy == COPY, we allocate the new buffer before
+    //      freeing the old one, precluding the possibility of
+    //      reusing that memory.
+    // I doubt that either of these problems could be optimized away,
+    // but the second is probably not a common case, and the first is minor.
+    // I don't expect resetBuffer to be a common operation, so I'm willing to
+    // bite the performance bullet to make the method this simple.
+
+    // Construct the new buffer.
+    TMemoryBuffer new_buffer(buf, sz, policy);
+    // Move it into ourself.
+    this->swap(new_buffer);
+    // Our old self gets destroyed.
+  }
+
+  /// Reads up to len bytes and returns them as a new string.
+  std::string readAsString(uint32_t len) {
+    std::string str;
+    (void)readAppendToString(str, len);
+    return str;
+  }
+
+  uint32_t readAppendToString(std::string& str, uint32_t len);
+
+  /// Once everything written has been read, rewinds to the start of the
+  /// storage via resetBuffer() (which also makes a non-owning buffer
+  /// unwritable).
+  void readEnd() {
+    if (rBase_ == wBase_) {
+      resetBuffer();
+    }
+  }
+
+  /// Number of bytes written but not yet read.
+  uint32_t available_read() const {
+    // Remember, wBase_ is the real rBound_.
+    return wBase_ - rBase_;
+  }
+
+  /// Number of bytes that can be written without growing the buffer.
+  uint32_t available_write() const {
+    return wBound_ - wBase_;
+  }
+
+  // Returns a pointer to where the client can write data to append to
+  // the TMemoryBuffer, and ensures the buffer is big enough to accommodate a
+  // write of the provided length. The returned pointer is very convenient for
+  // passing to read(), recv(), or similar. You must call wroteBytes() as soon
+  // as data is written or the buffer will not be aware that data has changed.
+  uint8_t* getWritePtr(uint32_t len) {
+    ensureCanWrite(len);
+    return wBase_;
+  }
+
+  // Informs the buffer that the client has written 'len' bytes into storage
+  // that had been provided by getWritePtr().
+  void wroteBytes(uint32_t len);
+
+ protected:
+  /// Exchanges all state (storage, cursors, ownership) with that.
+  void swap(TMemoryBuffer& that) {
+    using std::swap;
+    swap(buffer_, that.buffer_);
+    swap(bufferSize_, that.bufferSize_);
+
+    swap(rBase_, that.rBase_);
+    swap(rBound_, that.rBound_);
+    swap(wBase_, that.wBase_);
+    swap(wBound_, that.wBound_);
+
+    swap(owner_, that.owner_);
+  }
+
+  // Make sure there's at least 'len' bytes available for writing.
+  void ensureCanWrite(uint32_t len);
+
+  // Compute the position and available data for reading.
+  void computeRead(uint32_t len, uint8_t** out_start, uint32_t* out_give);
+
+  uint32_t readSlow(uint8_t* buf, uint32_t len);
+
+  void writeSlow(const uint8_t* buf, uint32_t len);
+
+  const uint8_t* borrowSlow(uint8_t* buf, uint32_t* len);
+
+  // NOTE(review): no copy constructor or assignment operator is declared
+  // here, and an owning buffer is freed in the destructor, so copying one
+  // would double-free. Confirm whether a base class already prevents copies.
+
+  // Data buffer
+  uint8_t* buffer_;
+
+  // Allocated buffer size
+  uint32_t bufferSize_;
+
+  // Is this object the owner of the buffer?
+  bool owner_;
+
+  // Don't forget to update constructors, initCommon, and swap if
+  // you add new members.
+};
+
+}}} // apache::thrift::transport
+
+#endif // #ifndef _THRIFT_TRANSPORT_TBUFFERTRANSPORTS_H_
diff --git a/lib/cpp/src/transport/TFDTransport.cpp b/lib/cpp/src/transport/TFDTransport.cpp
new file mode 100644
index 0000000..a042f8b
--- /dev/null
+++ b/lib/cpp/src/transport/TFDTransport.cpp
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <cerrno>
+#include <exception>
+
+#include <transport/TFDTransport.h>
+
+#include <unistd.h>
+
+using namespace std;
+
+namespace apache { namespace thrift { namespace transport {
+
+/**
+ * Closes the descriptor (no-op if already closed) and marks it closed.
+ *
+ * @throws TTransportException(UNKNOWN) carrying errno if ::close() fails,
+ *         unless an exception is already propagating.
+ */
+void TFDTransport::close() {
+  if (isOpen()) {
+    const int result = ::close(fd_);
+    const int saved_errno = errno;  // capture before anything can clobber it
+    fd_ = -1;
+    // Have to check uncaught_exception because this is called in the destructor.
+    if (result < 0 && !std::uncaught_exception()) {
+      throw TTransportException(TTransportException::UNKNOWN,
+                                "TFDTransport::close()",
+                                saved_errno);
+    }
+  }
+}
+
+/**
+ * Reads up to len bytes from the descriptor with a single successful ::read().
+ *
+ * @return the number of bytes read (0 at end of file).
+ * @throws TTransportException(UNKNOWN) carrying errno on a read error.
+ */
+uint32_t TFDTransport::read(uint8_t* buf, uint32_t len) {
+  // A signal that arrives before any data is transferred makes ::read() fail
+  // with EINTR. That is not a real error, so retry. The retry count is bounded
+  // so a persistent stream of signals can still break a blocking read.
+  const unsigned int kMaxEintrRetries = 5;
+  unsigned int retries = 0;
+  for (;;) {
+    ssize_t rv = ::read(fd_, buf, len);
+    if (rv >= 0) {
+      return static_cast<uint32_t>(rv);
+    }
+    int errno_copy = errno;
+    if (errno_copy == EINTR && retries < kMaxEintrRetries) {
+      ++retries;
+      continue;
+    }
+    throw TTransportException(TTransportException::UNKNOWN,
+                              "TFDTransport::read()",
+                              errno_copy);
+  }
+}
+
+/**
+ * Writes all len bytes, looping over partial writes.
+ *
+ * @throws TTransportException(UNKNOWN) carrying errno on a write error, or
+ *         TTransportException(END_OF_FILE) if ::write() makes no progress.
+ */
+void TFDTransport::write(const uint8_t* buf, uint32_t len) {
+  // EINTR means the call was interrupted before writing anything. Retry it,
+  // with a bounded number of consecutive retries, instead of failing.
+  const unsigned int kMaxEintrRetries = 5;
+  unsigned int retries = 0;
+  while (len > 0) {
+    ssize_t rv = ::write(fd_, buf, len);
+
+    if (rv < 0) {
+      int errno_copy = errno;
+      if (errno_copy == EINTR && retries < kMaxEintrRetries) {
+        ++retries;
+        continue;
+      }
+      throw TTransportException(TTransportException::UNKNOWN,
+                                "TFDTransport::write()",
+                                errno_copy);
+    } else if (rv == 0) {
+      throw TTransportException(TTransportException::END_OF_FILE,
+                                "TFDTransport::write()");
+    }
+
+    retries = 0;  // progress was made; the bound applies per interruption run
+    buf += rv;
+    len -= static_cast<uint32_t>(rv);
+  }
+}
+
+}}} // apache::thrift::transport
diff --git a/lib/cpp/src/transport/TFDTransport.h b/lib/cpp/src/transport/TFDTransport.h
new file mode 100644
index 0000000..bda5d82
--- /dev/null
+++ b/lib/cpp/src/transport/TFDTransport.h
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_TRANSPORT_TFDTRANSPORT_H_
+#define _THRIFT_TRANSPORT_TFDTRANSPORT_H_ 1
+
+#include <string>
+#include <sys/time.h>
+
+#include "TTransport.h"
+#include "TServerSocket.h"
+
+namespace apache { namespace thrift { namespace transport {
+
+/**
+ * Dead-simple wrapper around a file descriptor.
+ *
+ */
+class TFDTransport : public TTransport {
+ public:
+  /// Whether the destructor should close the wrapped descriptor.
+  enum ClosePolicy
+  { NO_CLOSE_ON_DESTROY = 0
+  , CLOSE_ON_DESTROY = 1
+  };
+
+  /**
+   * Wraps an existing descriptor; nothing is opened or duplicated here.
+   *
+   * @param fd           The descriptor to read from and write to.
+   * @param close_policy Whether the destructor closes fd.
+   */
+  TFDTransport(int fd, ClosePolicy close_policy = NO_CLOSE_ON_DESTROY)
+    : fd_(fd)
+    , close_policy_(close_policy)
+  {}
+
+  /**
+   * Closes the descriptor when CLOSE_ON_DESTROY was requested.
+   *
+   * close() throws TTransportException when ::close() fails. Letting that
+   * escape a destructor calls std::terminate (destructors are implicitly
+   * noexcept since C++11), so the failure is logged and swallowed here.
+   */
+  ~TFDTransport() {
+    if (close_policy_ == CLOSE_ON_DESTROY) {
+      try {
+        close();
+      } catch (const TTransportException& ex) {
+        GlobalOutput.printf("TFDTransport::~TFDTransport() close failed: %s", ex.what());
+      }
+    }
+  }
+
+  /// A negative descriptor means closed.
+  bool isOpen() { return fd_ >= 0; }
+
+  /// No-op: the descriptor is supplied already open.
+  void open() {}
+
+  void close();
+
+  uint32_t read(uint8_t* buf, uint32_t len);
+
+  void write(const uint8_t* buf, uint32_t len);
+
+  /// Replaces the wrapped descriptor without closing the previous one.
+  void setFD(int fd) { fd_ = fd; }
+  int getFD() { return fd_; }
+
+ protected:
+  int fd_;
+  ClosePolicy close_policy_;
+};
+
+}}} // apache::thrift::transport
+
+#endif // #ifndef _THRIFT_TRANSPORT_TFDTRANSPORT_H_
diff --git a/lib/cpp/src/transport/TFileTransport.cpp b/lib/cpp/src/transport/TFileTransport.cpp
new file mode 100644
index 0000000..f67b9e3
--- /dev/null
+++ b/lib/cpp/src/transport/TFileTransport.cpp
@@ -0,0 +1,953 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include "TFileTransport.h"
+#include "TTransportUtils.h"
+
+#include <pthread.h>
+#ifdef HAVE_SYS_TIME_H
+#include <sys/time.h>
+#else
+#include <time.h>
+#endif
+#include <fcntl.h>
+#include <errno.h>
+#include <unistd.h>
+#ifdef HAVE_STRINGS_H
+#include <strings.h>
+#endif
+#include <cstdlib>
+#include <cstring>
+#include <iostream>
+#include <vector>
+#include <sys/stat.h>
+
+namespace apache { namespace thrift { namespace transport {
+
+using boost::shared_ptr;
+using namespace std;
+using namespace apache::thrift::protocol;
+
+#ifndef HAVE_CLOCK_GETTIME
+
+/**
+ * Minimal clock_gettime() stand-in for systems without it (e.g. darwin),
+ * built on gettimeofday(). The clock id is ignored. Microseconds are
+ * scaled to nanoseconds. Returns gettimeofday()'s result, and leaves *tp
+ * untouched on failure.
+ */
+#define CLOCK_REALTIME 0
+static int clock_gettime(int clk_id /*ignored*/, struct timespec *tp) {
+  struct timeval tv;
+  const int rc = gettimeofday(&tv, NULL);
+  if (rc == 0) {
+    tp->tv_sec = tv.tv_sec;
+    tp->tv_nsec = tv.tv_usec * 1000;
+  }
+  return rc;
+}
+#endif
+
+/**
+ * Creates a file transport for path, initializes its synchronization
+ * primitives, and opens the log file via openLogFile(). The writer thread
+ * and event buffers are not created here; initBufferAndWriteThread() does
+ * that on the first enqueued event.
+ */
+TFileTransport::TFileTransport(string path, bool readOnly)
+  : readState_()
+  , readBuff_(NULL)
+  , currentEvent_(NULL)
+  , readBuffSize_(DEFAULT_READ_BUFF_SIZE)
+  , readTimeout_(NO_TAIL_READ_TIMEOUT)
+  , chunkSize_(DEFAULT_CHUNK_SIZE)
+  , eventBufferSize_(DEFAULT_EVENT_BUFFER_SIZE)
+  , flushMaxUs_(DEFAULT_FLUSH_MAX_US)
+  , flushMaxBytes_(DEFAULT_FLUSH_MAX_BYTES)
+  , maxEventSize_(DEFAULT_MAX_EVENT_SIZE)
+  , maxCorruptedEvents_(DEFAULT_MAX_CORRUPTED_EVENTS)
+  , eofSleepTime_(DEFAULT_EOF_SLEEP_TIME_US)
+  , corruptedEventSleepTime_(DEFAULT_CORRUPTED_SLEEP_TIME_US)
+  , writerThreadId_(0)
+  , dequeueBuffer_(NULL)
+  , enqueueBuffer_(NULL)
+  , closing_(false)
+  , forceFlush_(false)
+  , filename_(path)
+  , fd_(0)
+  , bufferAndThreadInitialized_(false)
+  , offset_(0)
+  , lastBadChunk_(0)
+  , numCorruptedEventsInChunk_(0)
+  , readOnly_(readOnly)
+{
+  // initialize all the condition vars/mutexes
+  pthread_mutex_init(&mutex_, NULL);
+  pthread_cond_init(&notFull_, NULL);
+  pthread_cond_init(&notEmpty_, NULL);
+  pthread_cond_init(&flushed_, NULL);
+
+  openLogFile();
+}
+
+/**
+ * Switches output to a new file.
+ *
+ * If a file is currently open, queued events are flushed and it is closed
+ * first. Then filename_/offset_ are replaced, and either the supplied
+ * descriptor is adopted (fd != 0) or the new file is opened.
+ *
+ * @throws TTransportException(UNKNOWN) if closing the current file fails;
+ *         in that case filename_ and offset_ are left unchanged.
+ */
+void TFileTransport::resetOutputFile(int fd, string filename, int64_t offset) {
+  // Finish with the old file before adopting the new name/offset. Otherwise
+  // the warning below names the wrong file, and events flushed here would be
+  // laid out against the new offset_ (the writer uses offset_ for chunking).
+  if (fd_ > 0) {
+    // flush any events in the queue
+    flush();
+    GlobalOutput.printf("error, current file (%s) not closed", filename_.c_str());
+    if (-1 == ::close(fd_)) {
+      int errno_copy = errno;
+      GlobalOutput.perror("TFileTransport: resetOutputFile() ::close() ", errno_copy);
+      throw TTransportException(TTransportException::UNKNOWN, "TFileTransport: error in file close", errno_copy);
+    }
+    // Don't keep a closed descriptor around if opening the new file fails.
+    fd_ = 0;
+  }
+
+  filename_ = filename;
+  offset_ = offset;
+
+  if (fd) {
+    fd_ = fd;
+  } else {
+    // open file if the input fd is 0
+    openLogFile();
+  }
+}
+
+
+/**
+ * Flushes and stops the writer thread (if one was started), releases the
+ * event and read buffers, and closes the log file if it is still open.
+ */
+TFileTransport::~TFileTransport() {
+  if (writerThreadId_ > 0) {
+    // reduce the flush timeout so that closing is quicker
+    setFlushMaxUs(300 * 1000);
+
+    // Push out the write buffer, then ask the writer thread to finish.
+    flush();
+    closing_ = true;
+
+    // TODO: make sure event queue is empty.
+    // Only the write buffer is flushed above; we don't wait for the queue to
+    // drain. This shouldn't be a big deal in the common case because writing
+    // is quick.
+    pthread_join(writerThreadId_, NULL);
+    writerThreadId_ = 0;
+  }
+
+  // Deleting NULL is a no-op, so these need no guards.
+  delete dequeueBuffer_;
+  dequeueBuffer_ = NULL;
+
+  delete enqueueBuffer_;
+  enqueueBuffer_ = NULL;
+
+  delete[] readBuff_;
+  readBuff_ = NULL;
+
+  delete currentEvent_;
+  currentEvent_ = NULL;
+
+  // close logfile
+  if (fd_ > 0 && ::close(fd_) == -1) {
+    GlobalOutput.perror("TFileTransport: ~TFileTransport() ::close() ", errno);
+  }
+}
+
+/**
+ * Lazily creates the enqueue/dequeue event buffers and starts the writer
+ * thread.
+ *
+ * @return true on success; false if already initialized or if the writer
+ *         thread could not be created (an error is reported via T_ERROR).
+ */
+bool TFileTransport::initBufferAndWriteThread() {
+  if (bufferAndThreadInitialized_) {
+    T_ERROR("Trying to double-init TFileTransport");
+    return false;
+  }
+
+  // Allocate the buffers before starting the writer thread, so the thread
+  // never sees them NULL. Its shutdown check in writerThread() reads both
+  // without taking mutex_. pthread_create() makes these writes visible to the
+  // new thread. Existing buffers are kept if an earlier attempt failed.
+  if (dequeueBuffer_ == NULL) {
+    dequeueBuffer_ = new TFileTransportBuffer(eventBufferSize_);
+  }
+  if (enqueueBuffer_ == NULL) {
+    enqueueBuffer_ = new TFileTransportBuffer(eventBufferSize_);
+  }
+
+  if (writerThreadId_ == 0) {
+    if (pthread_create(&writerThreadId_, NULL, startWriterThread, (void *)this) != 0) {
+      // The thread id is unspecified after a failed create; keep the
+      // "no writer thread" marker so the destructor does not try to join.
+      writerThreadId_ = 0;
+      T_ERROR("Could not create writer thread");
+      return false;
+    }
+  }
+
+  bufferAndThreadInitialized_ = true;
+  return true;
+}
+
+/**
+ * Queues buf as one event for the writer thread (without waiting for a flush).
+ *
+ * @throws TTransportException if the transport was opened read-only.
+ */
+void TFileTransport::write(const uint8_t* buf, uint32_t len) {
+  if (!readOnly_) {
+    enqueueEvent(buf, len, false);
+    return;
+  }
+  throw TTransportException("TFileTransport: attempting to write to file opened readonly");
+}
+
+/**
+ * Copies one event into a heap buffer prefixed with its 4-byte length and
+ * queues it for the writer thread. On first use this starts the thread and
+ * allocates the event buffers.
+ *
+ * Blocks while the enqueue buffer is full and, when blockUntilFlush is set,
+ * until the writer broadcasts flushed_. Events that are too large, empty, or
+ * cannot be allocated are reported via T_ERROR and dropped; events arriving
+ * while closing are silently dropped.
+ */
+void TFileTransport::enqueueEvent(const uint8_t* buf, uint32_t eventLen, bool blockUntilFlush) {
+  // can't enqueue more events if file is going to close
+  if (closing_) {
+    return;
+  }
+
+  // make sure that event size is valid
+  if ((maxEventSize_ > 0) && (eventLen > maxEventSize_)) {
+    T_ERROR("msg size is greater than max event size: %u > %u\n", eventLen, maxEventSize_);
+    return;
+  }
+
+  if (eventLen == 0) {
+    T_ERROR("cannot enqueue an empty event");
+    return;
+  }
+
+  eventInfo* toEnqueue = new eventInfo();
+  // NOTE(review): eventBuff_ is allocated with malloc() here; confirm that
+  // eventInfo's destructor releases it with free() rather than delete[].
+  toEnqueue->eventBuff_ = (uint8_t *)std::malloc((sizeof(uint8_t) * eventLen) + 4);
+  if (toEnqueue->eventBuff_ == NULL) {
+    // Without this check the memcpy()s below would write through NULL.
+    T_ERROR("could not allocate %u bytes for event", eventLen + 4);
+    delete toEnqueue;
+    return;
+  }
+  // first 4 bytes is the event length
+  memcpy(toEnqueue->eventBuff_, (void*)(&eventLen), 4);
+  // actual event contents
+  memcpy(toEnqueue->eventBuff_ + 4, buf, eventLen);
+  toEnqueue->eventSize_ = eventLen + 4;
+
+  // lock mutex
+  pthread_mutex_lock(&mutex_);
+
+  // make sure that enqueue buffer is initialized and writer thread is running
+  if (!bufferAndThreadInitialized_) {
+    if (!initBufferAndWriteThread()) {
+      delete toEnqueue;
+      pthread_mutex_unlock(&mutex_);
+      return;
+    }
+  }
+
+  // Can't enqueue while buffer is full
+  while (enqueueBuffer_->isFull()) {
+    pthread_cond_wait(&notFull_, &mutex_);
+  }
+
+  // add to the buffer
+  if (!enqueueBuffer_->addEvent(toEnqueue)) {
+    delete toEnqueue;
+    pthread_mutex_unlock(&mutex_);
+    return;
+  }
+
+  // signal anybody who's waiting for the buffer to be non-empty
+  pthread_cond_signal(&notEmpty_);
+
+  if (blockUntilFlush) {
+    pthread_cond_wait(&flushed_, &mutex_);
+  }
+
+  // this really should be a loop where it makes sure it got flushed
+  // because condition variables can get triggered by the os for no reason
+  // it is probably a non-factor for the time being
+  pthread_mutex_unlock(&mutex_);
+}
+
+/**
+ * Waits on notEmpty_ (until deadline, or indefinitely when deadline is NULL),
+ * then swaps the enqueue and dequeue buffers if the enqueue buffer holds any
+ * events, so the writer can drain them without holding mutex_.
+ *
+ * @return true if the buffers were swapped; false if the enqueue buffer was
+ *         still empty after the wait (e.g. the wait timed out).
+ */
+bool TFileTransport::swapEventBuffers(struct timespec* deadline) {
+  pthread_mutex_lock(&mutex_);
+  if (deadline != NULL) {
+    // if we were handed a deadline time struct, do a timed wait
+    pthread_cond_timedwait(&notEmpty_, &mutex_, deadline);
+  } else {
+    // just wait until the buffer gets an item
+    pthread_cond_wait(&notEmpty_, &mutex_);
+  }
+
+  bool swapped = false;
+
+  // could be empty if we timed out
+  if (!enqueueBuffer_->isEmpty()) {
+    TFileTransportBuffer *temp = enqueueBuffer_;
+    enqueueBuffer_ = dequeueBuffer_;
+    dequeueBuffer_ = temp;
+
+    swapped = true;
+  }
+
+  // unlock the mutex and signal if required
+  pthread_mutex_unlock(&mutex_);
+
+  if (swapped) {
+    pthread_cond_signal(&notFull_);
+  }
+
+  return swapped;
+}
+
+
+/**
+ * Body of the background writer thread.
+ *
+ * Opens the log file if needed, seeks to its end, and truncates any partial
+ * trailing event. It then repeatedly swaps in queued events and appends them.
+ * When chunking is enabled, it zero-pads so no event crosses a chunk
+ * boundary. It fsync()s when the flush deadline passes, when more than
+ * flushMaxBytes_ are unflushed, or when forceFlush_ is set, and then
+ * broadcasts flushed_. Once closing_ is set and both buffers are empty, it
+ * syncs and closes the file and exits the thread.
+ *
+ * @throws TTransportException(UNKNOWN) on write/close errors.
+ */
+void TFileTransport::writerThread() {
+  // open file if it is not open
+  if (!fd_) {
+    openLogFile();
+  }
+
+  // set the offset to the correct value (EOF)
+  try {
+    seekToEnd();
+  } catch (TException&) {
+    // Deliberately best-effort: if seeking fails, carry on with the current
+    // offset_ and read state for the truncation below.
+  }
+
+  // throw away any partial events
+  offset_ += readState_.lastDispatchPtr_;
+  ftruncate(fd_, offset_);
+  readState_.resetAllValues();
+
+  // Figure out the next time by which a flush must take place
+  struct timespec ts_next_flush;
+  getNextFlushTime(&ts_next_flush);
+  uint32_t unflushed = 0;
+
+  while (1) {
+    // this will only be true when the destructor is being invoked
+    if (closing_) {
+      // empty out both the buffers
+      if (enqueueBuffer_->isEmpty() && dequeueBuffer_->isEmpty()) {
+        // Sync while the descriptor is still valid: fsync() on an
+        // already-closed fd fails with EBADF and syncs nothing.
+        fsync(fd_);
+        if (-1 == ::close(fd_)) {
+          int errno_copy = errno;
+          GlobalOutput.perror("TFileTransport: writerThread() ::close() ", errno_copy);
+          throw TTransportException(TTransportException::UNKNOWN, "TFileTransport: error in file close", errno_copy);
+        }
+        fd_ = 0;
+        pthread_exit(NULL);
+        return;
+      }
+    }
+
+    if (swapEventBuffers(&ts_next_flush)) {
+      eventInfo* outEvent;
+      // The loop condition already guarantees outEvent is non-NULL inside.
+      while (NULL != (outEvent = dequeueBuffer_->getNext())) {
+        // sanity check on event
+        if ((maxEventSize_ > 0) && (outEvent->eventSize_ > maxEventSize_)) {
+          T_ERROR("msg size is greater than max event size: %u > %u\n", outEvent->eventSize_, maxEventSize_);
+          continue;
+        }
+
+        // If chunking is required, then make sure that msg does not cross chunk boundary
+        if ((outEvent->eventSize_ > 0) && (chunkSize_ != 0)) {
+
+          // event size must be less than chunk size
+          if (outEvent->eventSize_ > chunkSize_) {
+            T_ERROR("TFileTransport: event size(%u) is greater than chunk size(%u): skipping event",
+                    outEvent->eventSize_, chunkSize_);
+            continue;
+          }
+
+          int64_t chunk1 = offset_/chunkSize_;
+          int64_t chunk2 = (offset_ + outEvent->eventSize_ - 1)/chunkSize_;
+
+          // if adding this event will cross a chunk boundary, pad the chunk with zeros
+          if (chunk1 != chunk2) {
+            // refetch the offset to keep in sync
+            offset_ = lseek(fd_, 0, SEEK_CUR);
+            int32_t padding = (int32_t)((offset_/chunkSize_ + 1)*chunkSize_ - offset_);
+
+            // The padding can be up to chunkSize_ bytes, far too large for a
+            // stack array (which was also a non-standard VLA), so allocate it
+            // on the heap. The vector value-initializes its bytes to zero.
+            std::vector<uint8_t> zeros(static_cast<size_t>(padding));
+            if (-1 == ::write(fd_, &zeros[0], padding)) {
+              int errno_copy = errno;
+              GlobalOutput.perror("TFileTransport: writerThread() error while padding zeros ", errno_copy);
+              throw TTransportException(TTransportException::UNKNOWN, "TFileTransport: error while padding zeros", errno_copy);
+            }
+            unflushed += padding;
+            offset_ += padding;
+          }
+        }
+
+        // write the dequeued event to the file
+        if (outEvent->eventSize_ > 0) {
+          if (-1 == ::write(fd_, outEvent->eventBuff_, outEvent->eventSize_)) {
+            int errno_copy = errno;
+            GlobalOutput.perror("TFileTransport: error while writing event ", errno_copy);
+            throw TTransportException(TTransportException::UNKNOWN, "TFileTransport: error while writing event", errno_copy);
+          }
+
+          unflushed += outEvent->eventSize_;
+          offset_ += outEvent->eventSize_;
+        }
+      }
+      dequeueBuffer_->reset();
+    }
+
+    bool flushTimeElapsed = false;
+    struct timespec current_time;
+    clock_gettime(CLOCK_REALTIME, &current_time);
+
+    if (current_time.tv_sec > ts_next_flush.tv_sec ||
+        (current_time.tv_sec == ts_next_flush.tv_sec && current_time.tv_nsec > ts_next_flush.tv_nsec)) {
+      flushTimeElapsed = true;
+      getNextFlushTime(&ts_next_flush);
+    }
+
+    // couple of cases from which a flush could be triggered
+    if ((flushTimeElapsed && unflushed > 0) ||
+        unflushed > flushMaxBytes_ ||
+        forceFlush_) {
+
+      // sync (force flush) file to disk
+      fsync(fd_);
+      unflushed = 0;
+
+      // notify anybody waiting for flush completion
+      forceFlush_ = false;
+      pthread_cond_broadcast(&flushed_);
+    }
+  }
+}
+
+/**
+ * Requests a flush from the writer thread and blocks until the writer thread
+ * clears forceFlush_ (which it does right after an fsync()). Returns
+ * immediately if no writer thread has been started.
+ */
+void TFileTransport::flush() {
+  // nothing can be flushed unless the file is open for writing
+  if (writerThreadId_ <= 0) {
+    return;
+  }
+
+  pthread_mutex_lock(&mutex_);
+  forceFlush_ = true;
+  // the request was just raised, so at least one wait is always needed
+  do {
+    pthread_cond_wait(&flushed_, &mutex_);
+  } while (forceFlush_);
+  pthread_mutex_unlock(&mutex_);
+}
+
+
+/**
+ * Reads exactly len bytes into buf, calling read() as often as needed (a
+ * single call returns at most the rest of one event).
+ *
+ * @return len
+ * @throws TEOFException if read() returns 0 before len bytes were gathered
+ */
+uint32_t TFileTransport::readAll(uint8_t* buf, uint32_t len) {
+  uint32_t total = 0;
+  for (; total < len;) {
+    const uint32_t got = read(buf + total, len - total);
+    if (got == 0) {
+      throw TEOFException();
+    }
+    total += got;
+  }
+  return total;
+}
+
+/**
+ * Copies up to len bytes of the current event into buf, fetching the next
+ * event via readEvent() when none is in progress. Never spans two events.
+ *
+ * @return number of bytes copied; 0 when readEvent() produced no event
+ *         (e.g. EOF with a non-tailing timeout)
+ */
+uint32_t TFileTransport::read(uint8_t* buf, uint32_t len) {
+  if (currentEvent_ == NULL) {
+    currentEvent_ = readEvent();
+    if (currentEvent_ == NULL) {
+      // timeout expired or some other condition left us without an event
+      return 0;
+    }
+  }
+
+  const uint8_t* source = currentEvent_->eventBuff_ + currentEvent_->eventBuffPos_;
+  int32_t remaining = currentEvent_->eventSize_ - currentEvent_->eventBuffPos_;
+
+  if (remaining > (int32_t)len) {
+    // only part of the event fits: hand out len bytes and keep the event
+    memcpy(buf, source, len);
+    currentEvent_->eventBuffPos_ += len;
+    return len;
+  }
+
+  // the rest of the event fits: copy it and release the event
+  if (remaining > 0) {
+    memcpy(buf, source, remaining);
+  }
+  delete currentEvent_;
+  currentEvent_ = NULL;
+  return remaining;
+}
+
+/**
+ * Reads the next complete event from the file.
+ *
+ * Data is pulled through readBuff_ (allocated lazily with readBuffSize_
+ * bytes). Each event is a 4-byte length (host byte order) followed by that
+ * many bytes; a length of 0 is padding, and a length header never straddles
+ * a chunk boundary (bytes are skipped up to the boundary instead). Corrupted
+ * sizes trigger performRecovery(), after which reading restarts.
+ *
+ * At EOF the behaviour depends on readTimeout_:
+ *  - TAIL_READ_TIMEOUT: sleep eofSleepTime_ microseconds and retry forever;
+ *  - NO_TAIL_READ_TIMEOUT: return NULL;
+ *  - > 0: sleep readTimeout_ milliseconds once, then return NULL.
+ *
+ * @return a newly allocated event owned by the caller, or NULL (see above)
+ * @throws TTransportException on a read error
+ */
+eventInfo* TFileTransport::readEvent() {
+  int readTries = 0;
+
+  if (!readBuff_) {
+    readBuff_ = new uint8_t[readBuffSize_];
+  }
+
+  while (1) {
+    // read from the file if read buffer is exhausted
+    if (readState_.bufferPtr_ == readState_.bufferLen_) {
+      // advance the offset pointer
+      offset_ += readState_.bufferLen_;
+      readState_.bufferLen_ = ::read(fd_, readBuff_, readBuffSize_);
+      readState_.bufferPtr_ = 0;
+      readState_.lastDispatchPtr_ = 0;
+
+      // read error
+      if (readState_.bufferLen_ == -1) {
+        readState_.resetAllValues();
+        GlobalOutput("TFileTransport: error while reading from file");
+        throw TTransportException("TFileTransport: error while reading from file");
+      } else if (readState_.bufferLen_ == 0) { // EOF
+        // wait indefinitely if there is no timeout
+        if (readTimeout_ == TAIL_READ_TIMEOUT) {
+          usleep(eofSleepTime_);
+          continue;
+        } else if (readTimeout_ == NO_TAIL_READ_TIMEOUT) {
+          // reset state
+          readState_.resetState(0);
+          return NULL;
+        } else if (readTimeout_ > 0) {
+          // timeout already expired once
+          if (readTries > 0) {
+            readState_.resetState(0);
+            return NULL;
+          } else {
+            usleep(readTimeout_ * 1000);
+            readTries++;
+            continue;
+          }
+        }
+        // NOTE(review): any other negative readTimeout_ matches no branch
+        // and the file is re-read immediately (busy loop at EOF).
+      }
+    }
+
+    readTries = 0;
+
+    // attempt to read an event from the buffer
+    while (readState_.bufferPtr_ < readState_.bufferLen_) {
+      if (readState_.readingSize_) {
+        if (readState_.eventSizeBuffPos_ == 0) {
+          if ((offset_ + readState_.bufferPtr_)/chunkSize_ !=
+              ((offset_ + readState_.bufferPtr_ + 3)/chunkSize_)) {
+            // a size header may not straddle a chunk boundary:
+            // skip one byte towards chunk boundary
+            readState_.bufferPtr_++;
+            continue;
+          }
+        }
+
+        readState_.eventSizeBuff_[readState_.eventSizeBuffPos_++] =
+          readBuff_[readState_.bufferPtr_++];
+        if (readState_.eventSizeBuffPos_ == 4) {
+          // decode with memcpy: reading the uint8_t[4] through a uint32_t*
+          // violates strict aliasing and may be a misaligned access
+          uint32_t eventSize;
+          memcpy(&eventSize, readState_.eventSizeBuff_, sizeof(eventSize));
+
+          // 0 length event indicates padding
+          if (eventSize == 0) {
+            readState_.resetState(readState_.lastDispatchPtr_);
+            continue;
+          }
+          // got a valid event
+          readState_.readingSize_ = false;
+          if (readState_.event_) {
+            delete(readState_.event_);
+          }
+          readState_.event_ = new eventInfo();
+          readState_.event_->eventSize_ = eventSize;
+
+          // check if the event is corrupted and perform recovery if required
+          if (isEventCorrupted()) {
+            performRecovery();
+            // start from the top
+            break;
+          }
+        }
+      } else {
+        if (!readState_.event_->eventBuff_) {
+          readState_.event_->eventBuff_ = new uint8_t[readState_.event_->eventSize_];
+          readState_.event_->eventBuffPos_ = 0;
+        }
+        // take either the entire event or the remaining bytes in the buffer
+        int reclaimBuffer = min((uint32_t)(readState_.bufferLen_ - readState_.bufferPtr_),
+                                readState_.event_->eventSize_ - readState_.event_->eventBuffPos_);
+
+        // copy data from read buffer into event buffer
+        memcpy(readState_.event_->eventBuff_ + readState_.event_->eventBuffPos_,
+               readBuff_ + readState_.bufferPtr_,
+               reclaimBuffer);
+
+        // increment position ptrs
+        readState_.event_->eventBuffPos_ += reclaimBuffer;
+        readState_.bufferPtr_ += reclaimBuffer;
+
+        // check if the event has been read in full
+        if (readState_.event_->eventBuffPos_ == readState_.event_->eventSize_) {
+          // hand the completed event to the caller, rewound for reading
+          eventInfo* completeEvent = readState_.event_;
+          completeEvent->eventBuffPos_ = 0;
+
+          readState_.event_ = NULL;
+          readState_.resetState(readState_.bufferPtr_);
+
+          // exit criteria
+          return completeEvent;
+        }
+      }
+    }
+
+  }
+}
+
+/**
+ * Validates the size of the event whose 4-byte length header was just
+ * consumed (readState_.event_->eventSize_). Logs the reason via T_ERROR.
+ *
+ * @return true if the size exceeds maxEventSize_ (when non-zero), exceeds
+ *         chunkSize_, or would make the event cross a chunk boundary
+ */
+bool TFileTransport::isEventCorrupted() {
+  const uint32_t eventSize = readState_.event_->eventSize_;
+  // absolute file offset of this event's 4-byte size header; the header has
+  // already been consumed, so it starts 4 bytes before bufferPtr_
+  const off_t eventStart = offset_ + readState_.bufferPtr_ - 4;
+
+  if ((maxEventSize_ > 0) && (eventSize > maxEventSize_)) {
+    // 1. Event size is larger than user-specified max-event size
+    T_ERROR("Read corrupt event. Event size(%u) greater than max event size (%u)",
+            eventSize, maxEventSize_);
+    return true;
+  }
+
+  if (eventSize > chunkSize_) {
+    // 2. Event size is larger than chunk size
+    T_ERROR("Read corrupt event. Event size(%u) greater than chunk size (%u)",
+            eventSize, chunkSize_);
+    return true;
+  }
+
+  if ((eventStart / chunkSize_) !=
+      ((offset_ + readState_.bufferPtr_ + eventSize - 1) / chunkSize_)) {
+    // 3. size indicates that event crosses chunk boundary; log the offset of
+    // the size header (cast because off_t's width is platform dependent)
+    T_ERROR("Read corrupt event. Event crosses chunk boundary. Event size:%u Offset:%lld",
+            eventSize, (long long)eventStart);
+    return true;
+  }
+
+  return false;
+}
+
+/**
+ * Reacts to a corrupted event found by readEvent().
+ *
+ * Counts corrupted events per chunk. While fewer than maxCorruptedEvents_
+ * were seen in the current chunk, the chunk is re-read from its start.
+ * Otherwise reading skips to the next chunk; at the last chunk it either
+ * waits for the file to grow (when tailing) or gives up.
+ *
+ * @throws TTransportException when the last chunk is corrupt and the
+ *         transport is not tailing
+ */
+void TFileTransport::performRecovery() {
+  uint32_t curChunk = getCurChunk();
+  if (lastBadChunk_ == curChunk) {
+    numCorruptedEventsInChunk_++;
+  } else {
+    lastBadChunk_ = curChunk;
+    numCorruptedEventsInChunk_ = 1;
+  }
+
+  if (numCorruptedEventsInChunk_ < maxCorruptedEvents_) {
+    // maybe there was an error in reading the file from disk
+    // seek to the beginning of chunk and try again
+    seekToChunk(curChunk);
+  } else {
+
+    // just skip ahead to the next chunk if we are not already at the last chunk
+    if (curChunk != (getNumChunks() - 1)) {
+      seekToChunk(curChunk + 1);
+    } else if (readTimeout_ == TAIL_READ_TIMEOUT) {
+      // if tailing the file, wait until there is enough data to start
+      // the next chunk
+      while (curChunk == (getNumChunks() - 1)) {
+        usleep(DEFAULT_CORRUPTED_SLEEP_TIME_US);
+      }
+      seekToChunk(curChunk + 1);
+    } else {
+      // pretty hosed at this stage, rewind the file back to the last successful
+      // point and punt on the error
+      readState_.resetState(readState_.lastDispatchPtr_);
+      currentEvent_ = NULL;
+      char errorMsg[1024];
+      // bounded formatting; off_t is cast because its width varies by platform
+      snprintf(errorMsg, sizeof(errorMsg),
+               "TFileTransport: log file corrupted at offset: %lld",
+               (long long)(offset_ + readState_.lastDispatchPtr_));
+      GlobalOutput(errorMsg);
+      throw TTransportException(errorMsg);
+    }
+  }
+
+}
+
+/**
+ * Positions the reader at the start of the given chunk.
+ *
+ * A negative chunk counts from the end (-1 is the last chunk); values that
+ * are still negative clamp to 0. A chunk at or past getNumChunks() positions
+ * at the last chunk and then reads forward (with NO_TAIL_READ_TIMEOUT) until
+ * reaching the file size observed at the time of the call. Does nothing for
+ * an empty file. Any partially consumed event is discarded.
+ *
+ * @throws TTransportException if the file is not open or lseek() fails
+ */
+void TFileTransport::seekToChunk(int32_t chunk) {
+  if (fd_ <= 0) {
+    throw TTransportException("File not open");
+  }
+
+  int32_t numChunks = getNumChunks();
+
+  // file is empty, seeking to chunk is pointless
+  if (numChunks == 0) {
+    return;
+  }
+
+  // negative indicates reverse seek (from the end)
+  if (chunk < 0) {
+    chunk += numChunks;
+  }
+
+  // too large a value for reverse seek, just seek to beginning
+  if (chunk < 0) {
+    T_DEBUG("Incorrect value for reverse seek. Seeking to beginning...");
+    chunk = 0;
+  }
+
+  // cannot seek past EOF
+  bool seekingToEnd = false;
+  // off_t (not uint32_t) so offsets of files larger than 4GB are not truncated
+  off_t minEndOffset = 0;
+  if (chunk >= numChunks) {
+    T_DEBUG("Trying to seek past EOF. Seeking to EOF instead...");
+    seekingToEnd = true;
+    chunk = numChunks - 1;
+    // this is the min offset to process events till
+    minEndOffset = lseek(fd_, 0, SEEK_END);
+  }
+
+  off_t newOffset = off_t(chunk) * chunkSize_;
+  offset_ = lseek(fd_, newOffset, SEEK_SET);
+  readState_.resetAllValues();
+  // the partially consumed event (if any) is owned by this transport
+  delete currentEvent_;
+  currentEvent_ = NULL;
+  if (offset_ == -1) {
+    GlobalOutput("TFileTransport: lseek error in seekToChunk");
+    throw TTransportException("TFileTransport: lseek error in seekToChunk");
+  }
+
+  // seek to EOF if user wanted to go to last chunk
+  if (seekingToEnd) {
+    int32_t oldReadTimeout = getReadTimeout();
+    setReadTimeout(NO_TAIL_READ_TIMEOUT);
+    try {
+      // keep on reading until the last event at point of seekChunk call;
+      // the skipped events are not handed to anyone, so free them here
+      while (true) {
+        eventInfo* skipped = readEvent();
+        if (skipped == NULL) {
+          break;
+        }
+        delete skipped;
+        if ((offset_ + readState_.bufferPtr_) >= minEndOffset) {
+          break;
+        }
+      }
+    } catch (...) {
+      // do not leave the caller's read timeout overridden
+      setReadTimeout(oldReadTimeout);
+      throw;
+    }
+    setReadTimeout(oldReadTimeout);
+  }
+
+}
+
+/**
+ * Positions the reader after the last complete event in the file by asking
+ * seekToChunk() for a chunk one past the last (no-op for an empty file).
+ */
+void TFileTransport::seekToEnd() {
+  const uint32_t pastLastChunk = getNumChunks();
+  seekToChunk(pastLastChunk);
+}
+
+/**
+ * Number of chunks in the file, derived from fstat(): 0 when no file is open
+ * or the file is empty, otherwise st_size / chunkSize_ + 1. A file whose size
+ * is an exact multiple of chunkSize_ therefore reports one trailing chunk
+ * that holds no data.
+ *
+ * @throws TTransportException (UNKNOWN) if fstat() fails
+ */
+uint32_t TFileTransport::getNumChunks() {
+  if (fd_ <= 0) {
+    // no file open, nothing to count
+    return 0;
+  }
+
+  struct stat fileStat;
+  if (fstat(fd_, &fileStat) < 0) {
+    int errno_copy = errno;
+    throw TTransportException(TTransportException::UNKNOWN,
+                              "TFileTransport::getNumChunks() (fstat)",
+                              errno_copy);
+  }
+
+  if (fileStat.st_size <= 0) {
+    // empty file has no chunks
+    return 0;
+  }
+  return (fileStat.st_size / chunkSize_) + 1;
+}
+
+/**
+ * Index of the chunk containing offset_ (offset_ / chunkSize_).
+ */
+uint32_t TFileTransport::getCurChunk() {
+  return static_cast<uint32_t>(offset_ / chunkSize_);
+}
+
+// Utility Functions
+/**
+ * Opens filename_ into fd_ and resets offset_ to 0. Read-only transports
+ * open with O_RDONLY; otherwise the file is opened read/write, created if
+ * missing (mode 0644) and in append mode.
+ *
+ * @throws TTransportException (NOT_OPEN) if ::open() fails
+ */
+void TFileTransport::openLogFile() {
+  int flags;
+  mode_t mode;
+  if (readOnly_) {
+    flags = O_RDONLY;
+    mode = S_IRUSR | S_IRGRP | S_IROTH;
+  } else {
+    flags = O_RDWR | O_CREAT | O_APPEND;
+    mode = S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH;
+  }
+
+  offset_ = 0;
+  fd_ = ::open(filename_.c_str(), flags, mode);
+
+  if (fd_ != -1) {
+    return;
+  }
+
+  // open failed: report and surface the errno captured right after ::open()
+  int errno_copy = errno;
+  GlobalOutput.perror("TFileTransport: openLogFile() ::open() file: " + filename_, errno_copy);
+  throw TTransportException(TTransportException::NOT_OPEN, filename_, errno_copy);
+}
+
+/**
+ * Stores in *ts_next_flush the absolute CLOCK_REALTIME time that lies
+ * flushMaxUs_ microseconds from now, normalized so tv_nsec < 1e9.
+ */
+void TFileTransport::getNextFlushTime(struct timespec* ts_next_flush) {
+  clock_gettime(CLOCK_REALTIME, ts_next_flush);
+  ts_next_flush->tv_nsec += (flushMaxUs_ % 1000000) * 1000;
+  // tv_nsec must stay in [0, 1e9): exactly 1e9 is an invalid timespec
+  // (e.g. pthread_cond_timedwait() rejects it with EINVAL). Both addends are
+  // below 1e9, so a single carry is enough.
+  if (ts_next_flush->tv_nsec >= 1000000000) {
+    ts_next_flush->tv_nsec -= 1000000000;
+    ts_next_flush->tv_sec += 1;
+  }
+  ts_next_flush->tv_sec += flushMaxUs_ / 1000000;
+}
+
+/**
+ * Creates an empty buffer in WRITE mode with room for `size` event pointers.
+ */
+TFileTransportBuffer::TFileTransportBuffer(uint32_t size)
+  : bufferMode_(WRITE),
+    writePoint_(0),
+    readPoint_(0),
+    size_(size),
+    buffer_(new eventInfo*[size]) {
+}
+
+/**
+ * Deletes every event still held (slots [0, writePoint_)) and the slot array.
+ */
+TFileTransportBuffer::~TFileTransportBuffer() {
+  if (buffer_ == NULL) {
+    return;
+  }
+  uint32_t slot = 0;
+  while (slot < writePoint_) {
+    delete buffer_[slot];
+    ++slot;
+  }
+  delete[] buffer_;
+  buffer_ = NULL;
+}
+
+/**
+ * Appends an event; the buffer takes ownership of it on success.
+ * Writing while in READ mode is reported but still performed.
+ *
+ * @return false (event not stored) when the buffer is full
+ */
+bool TFileTransportBuffer::addEvent(eventInfo *event) {
+  if (bufferMode_ == READ) {
+    GlobalOutput("Trying to write to a buffer in read mode");
+  }
+
+  // buffer is full
+  if (writePoint_ >= size_) {
+    return false;
+  }
+
+  buffer_[writePoint_] = event;
+  ++writePoint_;
+  return true;
+}
+
+/**
+ * Returns the next unread event (still owned by the buffer) and switches the
+ * buffer to READ mode; returns NULL once all written events have been read.
+ */
+eventInfo* TFileTransportBuffer::getNext() {
+  bufferMode_ = READ;
+
+  if (readPoint_ >= writePoint_) {
+    // no more entries
+    return NULL;
+  }
+  return buffer_[readPoint_++];
+}
+
+/**
+ * Deletes all held events and returns the buffer to an empty WRITE state.
+ * Emits a debug message if some events were never read.
+ */
+void TFileTransportBuffer::reset() {
+  const bool hasUnread = (bufferMode_ == WRITE) || (readPoint_ < writePoint_);
+  if (hasUnread) {
+    T_DEBUG("Resetting a buffer with unread entries");
+  }
+
+  // Clean up the old entries
+  for (uint32_t slot = 0; slot < writePoint_; ++slot) {
+    delete buffer_[slot];
+  }
+
+  writePoint_ = 0;
+  readPoint_ = 0;
+  bufferMode_ = WRITE;
+}
+
+// True when every slot holds an event, i.e. addEvent() would fail.
+bool TFileTransportBuffer::isFull() {
+  return size_ == writePoint_;
+}
+
+// True when no event has been written since construction or the last reset().
+bool TFileTransportBuffer::isEmpty() {
+  return 0 == writePoint_;
+}
+
+/**
+ * Uses protocolFactory for both directions and sends output to a fresh
+ * TNullTransport (presumably discarding whatever the processor writes).
+ */
+TFileProcessor::TFileProcessor(shared_ptr<TProcessor> processor,
+                               shared_ptr<TProtocolFactory> protocolFactory,
+                               shared_ptr<TFileReaderTransport> inputTransport)
+  : processor_(processor),
+    inputProtocolFactory_(protocolFactory),
+    outputProtocolFactory_(protocolFactory),
+    inputTransport_(inputTransport),
+    outputTransport_(new TNullTransport()) {
+}
+
+/**
+ * Uses separate input/output protocol factories and sends output to a fresh
+ * TNullTransport (presumably discarding whatever the processor writes).
+ */
+TFileProcessor::TFileProcessor(shared_ptr<TProcessor> processor,
+                               shared_ptr<TProtocolFactory> inputProtocolFactory,
+                               shared_ptr<TProtocolFactory> outputProtocolFactory,
+                               shared_ptr<TFileReaderTransport> inputTransport)
+  : processor_(processor),
+    inputProtocolFactory_(inputProtocolFactory),
+    outputProtocolFactory_(outputProtocolFactory),
+    inputTransport_(inputTransport),
+    outputTransport_(new TNullTransport()) {
+}
+
+/**
+ * Uses protocolFactory for both directions and writes output to the
+ * caller-supplied outputTransport.
+ */
+TFileProcessor::TFileProcessor(shared_ptr<TProcessor> processor,
+                               shared_ptr<TProtocolFactory> protocolFactory,
+                               shared_ptr<TFileReaderTransport> inputTransport,
+                               shared_ptr<TTransport> outputTransport)
+  : processor_(processor),
+    inputProtocolFactory_(protocolFactory),
+    outputProtocolFactory_(protocolFactory),
+    inputTransport_(inputTransport),
+    outputTransport_(outputTransport) {
+}
+
+/**
+ * Feeds events from inputTransport_ to processor_.
+ *
+ * Stops after numEvents events (0 means no limit), at EOF (unless tailing),
+ * or on any other TException, whose message is printed to cerr. When tail is
+ * true, the input's read timeout is switched to
+ * TFileTransport::TAIL_READ_TIMEOUT for the duration of the call and EOF
+ * exceptions are ignored. The previous timeout is restored on every exit
+ * path, including exceptions that propagate to the caller.
+ */
+void TFileProcessor::process(uint32_t numEvents, bool tail) {
+  shared_ptr<TProtocol> inputProtocol = inputProtocolFactory_->getProtocol(inputTransport_);
+  shared_ptr<TProtocol> outputProtocol = outputProtocolFactory_->getProtocol(outputTransport_);
+
+  // save old read timeout so it can be restored
+  int32_t oldReadTimeout = inputTransport_->getReadTimeout();
+  if (tail) {
+    // for TFileTransport this makes reads sleep and retry at EOF
+    inputTransport_->setReadTimeout(TFileTransport::TAIL_READ_TIMEOUT);
+  }
+
+  uint32_t numProcessed = 0;
+  try {
+    while (1) {
+      // bad form to use exceptions for flow control but there is really
+      // no other way around it
+      try {
+        processor_->process(inputProtocol, outputProtocol);
+        numProcessed++;
+        if ((numEvents > 0) && (numProcessed == numEvents)) {
+          // break rather than return so the read timeout is restored below
+          break;
+        }
+      } catch (TEOFException&) {
+        if (!tail) {
+          break;
+        }
+      } catch (TException &te) {
+        cerr << te.what() << endl;
+        break;
+      }
+    }
+  } catch (...) {
+    // exceptions not derived from TException propagate; restore first
+    if (tail) {
+      inputTransport_->setReadTimeout(oldReadTimeout);
+    }
+    throw;
+  }
+
+  // restore old read timeout
+  if (tail) {
+    inputTransport_->setReadTimeout(oldReadTimeout);
+  }
+
+}
+
+/**
+ * Processes events until the input transport moves into a different chunk
+ * than the one it started in, hits EOF, or raises any other TException
+ * (printed to cerr).
+ */
+void TFileProcessor::processChunk() {
+  shared_ptr<TProtocol> inputProtocol = inputProtocolFactory_->getProtocol(inputTransport_);
+  shared_ptr<TProtocol> outputProtocol = outputProtocolFactory_->getProtocol(outputTransport_);
+
+  const uint32_t startChunk = inputTransport_->getCurChunk();
+
+  bool keepGoing = true;
+  while (keepGoing) {
+    // exceptions double as the end-of-input signal here
+    try {
+      processor_->process(inputProtocol, outputProtocol);
+      keepGoing = (inputTransport_->getCurChunk() == startChunk);
+    } catch (TEOFException&) {
+      keepGoing = false;
+    } catch (TException &te) {
+      cerr << te.what() << endl;
+      keepGoing = false;
+    }
+  }
+}
+
+}}} // apache::thrift::transport
diff --git a/lib/cpp/src/transport/TFileTransport.h b/lib/cpp/src/transport/TFileTransport.h
new file mode 100644
index 0000000..fbaf2cd
--- /dev/null
+++ b/lib/cpp/src/transport/TFileTransport.h
@@ -0,0 +1,440 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_TRANSPORT_TFILETRANSPORT_H_
+#define _THRIFT_TRANSPORT_TFILETRANSPORT_H_ 1
+
+#include "TTransport.h"
+#include "Thrift.h"
+#include "TProcessor.h"
+
+#include <string>
+#include <stdio.h>
+
+#include <boost/shared_ptr.hpp>
+
+namespace apache { namespace thrift { namespace transport {
+
+using apache::thrift::TProcessor;
+using apache::thrift::protocol::TProtocolFactory;
+
+// Data pertaining to a single event: a heap buffer owned by this struct
+// (released with delete[] in the destructor), its size, and a read cursor.
+typedef struct eventInfo {
+  uint8_t* eventBuff_;
+  uint32_t eventSize_;
+  uint32_t eventBuffPos_;
+
+  eventInfo() : eventBuff_(NULL), eventSize_(0), eventBuffPos_(0) {}
+  ~eventInfo() {
+    if (eventBuff_) {
+      delete[] eventBuff_;
+    }
+  }
+
+ private:
+  // a member-wise copy would share eventBuff_ and free it twice, so copying
+  // is disabled (declared, never defined)
+  eventInfo(const eventInfo&);
+  eventInfo& operator=(const eventInfo&);
+} eventInfo;
+
+// information about current read state of a TFileTransport reader
+typedef struct readState {
+  // event currently being assembled; owned (deleted on reset/destruction)
+  eventInfo* event_;
+
+  // keep track of event size: bytes of the 4-byte length header read so far
+  uint8_t eventSizeBuff_[4];
+  uint8_t eventSizeBuffPos_;
+  // true while the length header is being read, false while reading the body
+  bool readingSize_;
+
+  // read buffer variables: next unread index and number of valid bytes
+  int32_t bufferPtr_;
+  int32_t bufferLen_;
+
+  // last successful dispatch point: buffer index just past the last
+  // complete event handed out
+  int32_t lastDispatchPtr_;
+
+  // start reading a new length header, remembering the last dispatch point
+  void resetState(uint32_t lastDispatchPtr) {
+    readingSize_ = true;
+    eventSizeBuffPos_ = 0;
+    lastDispatchPtr_ = lastDispatchPtr;
+  }
+
+  // forget everything, including any buffered data and partial event
+  void resetAllValues() {
+    resetState(0);
+    bufferPtr_ = 0;
+    bufferLen_ = 0;
+    if (event_) {
+      delete(event_);
+    }
+    event_ = 0;
+  }
+
+  readState() {
+    event_ = 0;
+    resetAllValues();
+  }
+
+  ~readState() {
+    if (event_) {
+      delete(event_);
+    }
+  }
+
+ private:
+  // owns event_, so a member-wise copy would delete it twice; copying is
+  // disabled (declared, never defined)
+  readState(const readState&);
+  readState& operator=(const readState&);
+} readState;
+
+/**
+ * TFileTransportBuffer - buffer class used by TFileTransport for queueing up events
+ * to be written to disk. Should be used in the following way:
+ *  1) Buffer created
+ *  2) Buffer written to (addEvent)
+ *  3) Buffer read from (getNext)
+ *  4) Buffer reset (reset)
+ *  5) Go back to 2, or destroy buffer
+ *
+ * The buffer should never be written to after it is read from, unless it is
+ * reset first. These rules are only checked with diagnostic messages (not
+ * enforced); they exist mainly to debug its sole client, TFileTransport,
+ * which uses the buffer in this way.
+ *
+ * The buffer owns every event stored with addEvent(); events are deleted by
+ * reset() and by the destructor.
+ */
+class TFileTransportBuffer {
+ public:
+  TFileTransportBuffer(uint32_t size);
+  ~TFileTransportBuffer();
+
+  bool addEvent(eventInfo *event);
+  eventInfo* getNext();
+  void reset();
+  bool isFull();
+  bool isEmpty();
+
+ private:
+  TFileTransportBuffer(); // should not be used
+
+  // owns buffer_ and the events in it, so a member-wise copy would delete
+  // them twice; copying is disabled (declared, never defined)
+  TFileTransportBuffer(const TFileTransportBuffer&);
+  TFileTransportBuffer& operator=(const TFileTransportBuffer&);
+
+  enum mode {
+    WRITE,
+    READ
+  };
+  mode bufferMode_;
+
+  uint32_t writePoint_;
+  uint32_t readPoint_;
+  uint32_t size_;
+  eventInfo** buffer_;
+};
+
+/**
+ * Abstract interface for transports used to read files that are divided
+ * into fixed-size chunks (implemented by TFileTransport, consumed by
+ * TFileProcessor).
+ */
+class TFileReaderTransport : virtual public TTransport {
+ public:
+ // Timeout applied when a read reaches EOF. TFileTransport treats
+ // TAIL_READ_TIMEOUT (-1) as "sleep and retry forever", NO_TAIL_READ_TIMEOUT
+ // (0) as "return immediately", and a positive value as milliseconds to
+ // wait once before giving up.
+ virtual int32_t getReadTimeout() = 0;
+ virtual void setReadTimeout(int32_t readTimeout) = 0;
+
+ // Number of chunks in the file.
+ virtual uint32_t getNumChunks() = 0;
+ // Index of the chunk currently being read.
+ virtual uint32_t getCurChunk() = 0;
+ // Positions the reader at the start of a chunk (negative counts from the
+ // end in TFileTransport).
+ virtual void seekToChunk(int32_t chunk) = 0;
+ // Positions the reader at the end of the file.
+ virtual void seekToEnd() = 0;
+};
+
+/**
+ * Abstract interface for transports used to write chunked files.
+ */
+class TFileWriterTransport : virtual public TTransport {
+ public:
+ // Size in bytes of the chunks the file is divided into; TFileTransport
+ // pads with zeros so that no event crosses a chunk boundary.
+ virtual uint32_t getChunkSize() = 0;
+ virtual void setChunkSize(uint32_t chunkSize) = 0;
+};
+
+/**
+ * File implementation of a transport. Reads and writes are done to a
+ * file on disk.
+ *
+ * The file is treated as a sequence of chunks of chunkSize_ bytes. The writer
+ * pads with zeros so that no event crosses a chunk boundary, and the reader
+ * expects every event to be preceded by a 4-byte length. Writes are queued
+ * in memory (two swapped event buffers) and written to disk by a background
+ * writer thread.
+ */
+class TFileTransport : public TFileReaderTransport,
+                       public TFileWriterTransport {
+ public:
+  TFileTransport(std::string path, bool readOnly=false);
+  ~TFileTransport();
+
+  // TODO: what is the correct behaviour for this?
+  // the log file is generally always open
+  bool isOpen() {
+    return true;
+  }
+
+  void write(const uint8_t* buf, uint32_t len);
+  void flush();
+
+  uint32_t readAll(uint8_t* buf, uint32_t len);
+  uint32_t read(uint8_t* buf, uint32_t len);
+
+  // log-file specific functions
+  void seekToChunk(int32_t chunk);
+  void seekToEnd();
+  uint32_t getNumChunks();
+  uint32_t getCurChunk();
+
+  // for changing the output file
+  void resetOutputFile(int fd, std::string filename, int64_t offset);
+
+  // Setter/Getter functions for user-controllable options
+
+  // Size in bytes of the buffer used when reading the file (0 is ignored).
+  // The buffer is allocated with this size on first use and never
+  // reallocated, so the size cannot change afterwards: a larger value would
+  // make ::read() overrun the existing allocation.
+  void setReadBuffSize(uint32_t readBuffSize) {
+    if (!readBuffSize) {
+      return;
+    }
+    if (readBuff_) {
+      GlobalOutput("Cannot change the read buffer size after reading has started");
+      return;
+    }
+    readBuffSize_ = readBuffSize;
+  }
+  uint32_t getReadBuffSize() {
+    return readBuffSize_;
+  }
+
+  // Read timeout behaviour at EOF: TAIL_READ_TIMEOUT sleeps eofSleepTime_
+  // microseconds and retries indefinitely; NO_TAIL_READ_TIMEOUT returns at
+  // once (read() then yields 0); a positive value waits that many
+  // milliseconds once before giving up.
+  static const int32_t TAIL_READ_TIMEOUT = -1;
+  static const int32_t NO_TAIL_READ_TIMEOUT = 0;
+  void setReadTimeout(int32_t readTimeout) {
+    readTimeout_ = readTimeout;
+  }
+  int32_t getReadTimeout() {
+    return readTimeout_;
+  }
+
+  // Chunk size in bytes (0 is ignored).
+  void setChunkSize(uint32_t chunkSize) {
+    if (chunkSize) {
+      chunkSize_ = chunkSize;
+    }
+  }
+  uint32_t getChunkSize() {
+    return chunkSize_;
+  }
+
+  // Capacity of each event queue buffer (presumably in events); can only be
+  // set before the writer thread and buffers are initialized.
+  void setEventBufferSize(uint32_t bufferSize) {
+    if (bufferAndThreadInitialized_) {
+      GlobalOutput("Cannot change the buffer size after writer thread started");
+      return;
+    }
+    eventBufferSize_ = bufferSize;
+  }
+
+  uint32_t getEventBufferSize() {
+    return eventBufferSize_;
+  }
+
+  // Maximum time in microseconds between fsync()s while data is pending
+  // (0 is ignored).
+  void setFlushMaxUs(uint32_t flushMaxUs) {
+    if (flushMaxUs) {
+      flushMaxUs_ = flushMaxUs;
+    }
+  }
+  uint32_t getFlushMaxUs() {
+    return flushMaxUs_;
+  }
+
+  // The writer fsync()s once more than this many bytes are unflushed
+  // (0 is ignored).
+  void setFlushMaxBytes(uint32_t flushMaxBytes) {
+    if (flushMaxBytes) {
+      flushMaxBytes_ = flushMaxBytes;
+    }
+  }
+  uint32_t getFlushMaxBytes() {
+    return flushMaxBytes_;
+  }
+
+  // Largest accepted event size in bytes; 0 means no limit.
+  void setMaxEventSize(uint32_t maxEventSize) {
+    maxEventSize_ = maxEventSize;
+  }
+  uint32_t getMaxEventSize() {
+    return maxEventSize_;
+  }
+
+  // While fewer than this many corrupted events were seen in a chunk, the
+  // reader re-reads that chunk; afterwards it skips to the next chunk.
+  void setMaxCorruptedEvents(uint32_t maxCorruptedEvents) {
+    maxCorruptedEvents_ = maxCorruptedEvents;
+  }
+  uint32_t getMaxCorruptedEvents() {
+    return maxCorruptedEvents_;
+  }
+
+  // Microseconds to sleep at EOF when tailing (0 is ignored).
+  void setEofSleepTimeUs(uint32_t eofSleepTime) {
+    if (eofSleepTime) {
+      eofSleepTime_ = eofSleepTime;
+    }
+  }
+  uint32_t getEofSleepTimeUs() {
+    return eofSleepTime_;
+  }
+
+ private:
+  // helper functions for writing to a file
+  void enqueueEvent(const uint8_t* buf, uint32_t eventLen, bool blockUntilFlush);
+  bool swapEventBuffers(struct timespec* deadline);
+  bool initBufferAndWriteThread();
+
+  // control for writer thread
+  static void* startWriterThread(void* ptr) {
+    (((TFileTransport*)ptr)->writerThread());
+    return 0;
+  }
+  void writerThread();
+
+  // helper functions for reading from a file
+  eventInfo* readEvent();
+
+  // event corruption-related functions
+  bool isEventCorrupted();
+  void performRecovery();
+
+  // Utility functions
+  void openLogFile();
+  void getNextFlushTime(struct timespec* ts_next_flush);
+
+  // Class variables
+  readState readState_;
+  uint8_t* readBuff_;
+  eventInfo* currentEvent_;
+
+  uint32_t readBuffSize_;
+  static const uint32_t DEFAULT_READ_BUFF_SIZE = 1 * 1024 * 1024;
+
+  int32_t readTimeout_;
+  static const int32_t DEFAULT_READ_TIMEOUT_MS = 200;
+
+  // size of chunks that file will be split up into
+  uint32_t chunkSize_;
+  static const uint32_t DEFAULT_CHUNK_SIZE = 16 * 1024 * 1024;
+
+  // size of event buffers
+  uint32_t eventBufferSize_;
+  static const uint32_t DEFAULT_EVENT_BUFFER_SIZE = 10000;
+
+  // max number of microseconds that can pass without flushing
+  uint32_t flushMaxUs_;
+  static const uint32_t DEFAULT_FLUSH_MAX_US = 3000000;
+
+  // max number of bytes that can be written without flushing
+  uint32_t flushMaxBytes_;
+  static const uint32_t DEFAULT_FLUSH_MAX_BYTES = 1000 * 1024;
+
+  // max event size
+  uint32_t maxEventSize_;
+  static const uint32_t DEFAULT_MAX_EVENT_SIZE = 0;
+
+  // max number of corrupted events per chunk
+  uint32_t maxCorruptedEvents_;
+  static const uint32_t DEFAULT_MAX_CORRUPTED_EVENTS = 0;
+
+  // sleep duration when EOF is hit
+  uint32_t eofSleepTime_;
+  static const uint32_t DEFAULT_EOF_SLEEP_TIME_US = 500 * 1000;
+
+  // sleep duration when a corrupted event is encountered
+  uint32_t corruptedEventSleepTime_;
+  static const uint32_t DEFAULT_CORRUPTED_SLEEP_TIME_US = 1 * 1000 * 1000;
+
+  // writer thread id
+  pthread_t writerThreadId_;
+
+  // buffers to hold data before it is flushed. Each element of the buffer stores a msg that
+  // needs to be written to the file. The buffers are swapped by the writer thread.
+  TFileTransportBuffer *dequeueBuffer_;
+  TFileTransportBuffer *enqueueBuffer_;
+
+  // conditions used to block when the buffer is full or empty
+  pthread_cond_t notFull_, notEmpty_;
+  volatile bool closing_;
+
+  // To keep track of whether the buffer has been flushed
+  pthread_cond_t flushed_;
+  volatile bool forceFlush_;
+
+  // Mutex that is grabbed when enqueueing and swapping the read/write buffers
+  pthread_mutex_t mutex_;
+
+  // File information
+  std::string filename_;
+  int fd_;
+
+  // Whether the writer thread and buffers have been initialized
+  bool bufferAndThreadInitialized_;
+
+  // Offset within the file
+  off_t offset_;
+
+  // event corruption information
+  uint32_t lastBadChunk_;
+  uint32_t numCorruptedEventsInChunk_;
+
+  bool readOnly_;
+};
+
+/**
+ * Exception thrown when EOF is hit: TFileTransport::readAll() throws it when
+ * read() returns no data before the requested length was gathered.
+ * Carries the END_OF_FILE transport exception type.
+ */
+class TEOFException : public TTransportException {
+ public:
+  TEOFException()
+    : TTransportException(TTransportException::END_OF_FILE) {
+  }
+};
+
+
+// Wrapper class that replays thrift events stored in a file through a
+// TProcessor.
+class TFileProcessor {
+ public:
+ /**
+ * Constructor that defaults output transport to null transport
+ *
+ * @param processor processes log-file events
+ * @param protocolFactory protocol factory used for both input and output
+ * @param inputTransport file transport
+ */
+ TFileProcessor(boost::shared_ptr<TProcessor> processor,
+ boost::shared_ptr<TProtocolFactory> protocolFactory,
+ boost::shared_ptr<TFileReaderTransport> inputTransport);
+
+ /**
+ * Constructor with separate input/output protocol factories; the output
+ * transport defaults to a null transport
+ *
+ * @param processor processes log-file events
+ * @param inputProtocolFactory protocol factory for reading events
+ * @param outputProtocolFactory protocol factory for processor output
+ * @param inputTransport file transport
+ */
+ TFileProcessor(boost::shared_ptr<TProcessor> processor,
+ boost::shared_ptr<TProtocolFactory> inputProtocolFactory,
+ boost::shared_ptr<TProtocolFactory> outputProtocolFactory,
+ boost::shared_ptr<TFileReaderTransport> inputTransport);
+
+ /**
+ * Constructor
+ *
+ * @param processor processes log-file events
+ * @param protocolFactory protocol factory used for both input and output
+ * @param inputTransport input file transport
+ * @param outputTransport output transport
+ */
+ TFileProcessor(boost::shared_ptr<TProcessor> processor,
+ boost::shared_ptr<TProtocolFactory> protocolFactory,
+ boost::shared_ptr<TFileReaderTransport> inputTransport,
+ boost::shared_ptr<TTransport> outputTransport);
+
+ /**
+ * processes events from the file
+ *
+ * Stops at EOF (unless tailing) or on any other TException, whose message
+ * is printed to cerr.
+ *
+ * @param numEvents number of events to process (0 for unlimited)
+ * @param tail tails the file if true: the input's read timeout is set to
+ *             TFileTransport::TAIL_READ_TIMEOUT during the call and EOF
+ *             does not stop processing
+ */
+ void process(uint32_t numEvents, bool tail);
+
+ /**
+ * process events until the end of the chunk
+ *
+ * Also stops at EOF or on any other TException (printed to cerr).
+ */
+ void processChunk();
+
+ private:
+ boost::shared_ptr<TProcessor> processor_;
+ boost::shared_ptr<TProtocolFactory> inputProtocolFactory_;
+ boost::shared_ptr<TProtocolFactory> outputProtocolFactory_;
+ boost::shared_ptr<TFileReaderTransport> inputTransport_;
+ // where processor output is written (a TNullTransport unless supplied)
+ boost::shared_ptr<TTransport> outputTransport_;
+};
+
+
+}}} // apache::thrift::transport
+
+#endif // _THRIFT_TRANSPORT_TFILETRANSPORT_H_
diff --git a/lib/cpp/src/transport/THttpClient.cpp b/lib/cpp/src/transport/THttpClient.cpp
new file mode 100644
index 0000000..59f2339
--- /dev/null
+++ b/lib/cpp/src/transport/THttpClient.cpp
@@ -0,0 +1,348 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <cstdlib>
+#include <sstream>
+
+#include "THttpClient.h"
+#include "TSocket.h"
+
+namespace apache { namespace thrift { namespace transport {
+
+using namespace std;
+
+/**
+ * Http client implementation.
+ *
+ */
+
+// Line terminator used by HTTP headers and chunked-encoding framing.
+// Declared as a const array rather than a non-const `const char*`, so the
+// name cannot be re-pointed and its length is derived at compile time.
+static const char CRLF[] = "\r\n";
+static const int CRLF_LEN = sizeof(CRLF) - 1;
+
+/**
+ * Builds an HTTP client on top of an existing transport. Parsing state starts
+ * out "expecting headers"; the 1024-byte scratch buffer is allocated by
+ * init().
+ */
+THttpClient::THttpClient(boost::shared_ptr<TTransport> transport, string host, string path)
+  : transport_(transport), host_(host), path_(path),
+    readHeaders_(true), chunked_(false), chunkedDone_(false),
+    chunkSize_(0), contentLength_(0),
+    httpBuf_(NULL), httpPos_(0), httpBufLen_(0), httpBufSize_(1024) {
+  init();
+}
+
+/**
+ * Builds an HTTP client that talks through a freshly created
+ * TSocket(host, port). Otherwise identical to the transport-taking
+ * constructor.
+ */
+THttpClient::THttpClient(string host, int port, string path)
+  : transport_(new TSocket(host, port)),
+    host_(host), path_(path),
+    readHeaders_(true), chunked_(false), chunkedDone_(false),
+    chunkSize_(0), contentLength_(0),
+    httpBuf_(NULL), httpPos_(0), httpBufLen_(0), httpBufSize_(1024) {
+  init();
+}
+
+/**
+ * Allocates the HTTP scratch buffer. One extra byte is reserved so the
+ * buffer can always be NUL-terminated, which lets readLine() use strstr().
+ *
+ * @throws TTransportException if the allocation fails.
+ */
+void THttpClient::init() {
+  char* buffer = (char*)std::malloc(httpBufSize_ + 1);
+  if (!buffer) {
+    throw TTransportException("Out of memory.");
+  }
+  httpBuf_ = buffer;
+  httpBuf_[httpBufLen_] = '\0';
+}
+
+// Releases the malloc'd HTTP scratch buffer (free(NULL) is a no-op, so no
+// separate check is needed).
+THttpClient::~THttpClient() {
+  std::free(httpBuf_);
+}
+
+/**
+ * Returns up to `len` decoded body bytes. Data is served from readBuffer_;
+ * only when that is drained is another body or chunk decoded from the wire.
+ *
+ * @return bytes copied into `buf`; 0 when no more body data is available.
+ */
+uint32_t THttpClient::read(uint8_t* buf, uint32_t len) {
+  if (readBuffer_.available_read() != 0) {
+    return readBuffer_.read(buf, len);
+  }
+  readBuffer_.resetBuffer();
+  if (readMoreData() == 0) {
+    return 0;
+  }
+  return readBuffer_.read(buf, len);
+}
+
+/**
+ * Finishes reading a response. For chunked bodies, keeps consuming chunks
+ * (including the footer section) until the terminating chunk has been seen,
+ * so none of this response is left unread on the connection.
+ */
+void THttpClient::readEnd() {
+  if (!chunked_) {
+    return;
+  }
+  while (!chunkedDone_) {
+    readChunked();
+  }
+}
+
+/**
+ * Decodes the next piece of the response body into readBuffer_: the whole
+ * Content-Length body, or one chunk when chunked encoding is in use. Parses
+ * the response headers first if they have not been read yet.
+ *
+ * @return number of body bytes appended to readBuffer_ (0 at the end of a
+ *         chunked body, or for an empty/absent Content-Length body)
+ */
+uint32_t THttpClient::readMoreData() {
+  // Only go to the network when everything buffered has been consumed.
+  // readLine() and readContent() refill on their own when they run short.
+  // Refilling unconditionally here would block (or throw "Could not refill
+  // buffer") whenever the rest of the response is already sitting in
+  // httpBuf_, e.g. when the server sent a whole chunked body in one packet.
+  if (httpPos_ == httpBufLen_) {
+    // Nothing unread: rewind to the front so the buffer does not keep
+    // growing across responses, then read more.
+    shift();
+    refill();
+  }
+
+  if (readHeaders_) {
+    readHeaders();
+  }
+
+  if (chunked_) {
+    return readChunked();
+  } else {
+    return readContent(contentLength_);
+  }
+}
+
+/**
+ * Reads one chunk of a chunked body: "<hex size>[;ext]CRLF<data>CRLF".
+ * A size of 0 marks the last chunk, after which the footers are consumed.
+ *
+ * @return number of data bytes appended to readBuffer_ (0 for the last chunk)
+ */
+uint32_t THttpClient::readChunked() {
+  const uint32_t chunkSize = parseChunkSize(readLine());
+  if (chunkSize != 0) {
+    const uint32_t copied = readContent(chunkSize);
+    readLine();  // the CRLF that terminates the chunk data
+    return copied;
+  }
+  readChunkedFooters();
+  return 0;
+}
+
+/**
+ * Discards footer lines after the last chunk, up to and including the empty
+ * line that ends them, then marks the chunked body as complete.
+ */
+void THttpClient::readChunkedFooters() {
+  const char* footer;
+  do {
+    footer = readLine();
+  } while (*footer != '\0');
+  chunkedDone_ = true;
+}
+
+/**
+ * Parses a chunk-size line: a hexadecimal size, optionally followed by
+ * ";"-separated chunk extensions, which are ignored. `line` is modified in
+ * place (the ';' is overwritten with a NUL).
+ *
+ * @throws TTransportException if, after optional leading whitespace, the
+ *         line does not start with a hex number, or the size does not fit
+ *         in 32 bits.
+ */
+uint32_t THttpClient::parseChunkSize(char* line) {
+  char* semi = strchr(line, ';');
+  if (semi != NULL) {
+    *semi = '\0';
+  }
+  // strtoul with an end pointer, instead of sscanf("%x") into an int: "%x"
+  // requires an unsigned int*, and without checking the conversion a
+  // malformed line would silently read as 0, i.e. as the final chunk,
+  // truncating the body.
+  char* end = NULL;
+  unsigned long size = std::strtoul(line, &end, 16);
+  if (end == line || size > 0xFFFFFFFFUL) {
+    throw TTransportException(string("Bad chunk size: ") + line);
+  }
+  return (uint32_t)size;
+}
+
+/**
+ * Copies exactly `size` bytes of body data from httpBuf_ into readBuffer_.
+ * Whenever the buffered bytes run out, the buffer is rewound and refilled
+ * from the transport.
+ *
+ * @return `size`
+ */
+uint32_t THttpClient::readContent(uint32_t size) {
+  uint32_t remaining = size;
+  while (remaining != 0) {
+    if (httpPos_ == httpBufLen_) {
+      // Everything buffered has been handed out: start over at the front.
+      httpPos_ = 0;
+      httpBufLen_ = 0;
+      refill();
+    }
+    uint32_t take = httpBufLen_ - httpPos_;
+    if (take > remaining) {
+      take = remaining;
+    }
+    readBuffer_.write((uint8_t*)(httpBuf_ + httpPos_), take);
+    httpPos_ += take;
+    remaining -= take;
+  }
+  return size;
+}
+
+/**
+ * Returns the next CRLF-terminated line from the buffer, with the CRLF
+ * replaced by a NUL. Until a CRLF shows up in the unread part of the
+ * NUL-terminated buffer, unread bytes are compacted to the front and more
+ * data is read.
+ *
+ * The returned pointer points into httpBuf_ and stays valid only until the
+ * buffer is next shifted or refilled.
+ */
+char* THttpClient::readLine() {
+  char* eol;
+  while ((eol = strstr(httpBuf_ + httpPos_, CRLF)) == NULL) {
+    shift();
+    refill();
+  }
+  *eol = '\0';
+  char* line = httpBuf_ + httpPos_;
+  httpPos_ = (eol - httpBuf_) + CRLF_LEN;
+  return line;
+}
+
+/**
+ * Moves the unread bytes [httpPos_, httpBufLen_) to the start of httpBuf_,
+ * resets httpPos_ to 0 and re-terminates the buffer with a NUL.
+ */
+void THttpClient::shift() {
+  uint32_t unread = 0;
+  if (httpPos_ < httpBufLen_) {
+    unread = httpBufLen_ - httpPos_;
+    memmove(httpBuf_, httpBuf_ + httpPos_, unread);
+  }
+  httpBufLen_ = unread;
+  httpPos_ = 0;
+  httpBuf_[httpBufLen_] = '\0';
+}
+
+/**
+ * Reads more bytes from the underlying transport, appending them at
+ * httpBufLen_ and keeping the buffer NUL-terminated. The buffer doubles in
+ * size whenever a quarter or less of it is free.
+ *
+ * @throws TTransportException if the buffer cannot be grown, or if the
+ *         transport returns no data.
+ */
+void THttpClient::refill() {
+  uint32_t avail = httpBufSize_ - httpBufLen_;
+  if (avail <= (httpBufSize_ / 4)) {
+    // Doubling must not wrap around uint32_t; a wrapped (smaller) size would
+    // make the read below overrun the allocation.
+    if (httpBufSize_ > 0x7FFFFFFFU) {
+      throw TTransportException("HTTP buffer too large.");
+    }
+    // Realloc into a temporary: on failure the old block is still valid and
+    // still owned by httpBuf_ (the destructor frees it), and httpBufSize_
+    // keeps describing the real allocation.
+    uint32_t newSize = httpBufSize_ * 2;
+    char* grown = (char*)std::realloc(httpBuf_, newSize + 1);
+    if (grown == NULL) {
+      throw TTransportException("Out of memory.");
+    }
+    httpBuf_ = grown;
+    httpBufSize_ = newSize;
+  }
+
+  // Read more data
+  uint32_t got = transport_->read((uint8_t*)(httpBuf_ + httpBufLen_), httpBufSize_ - httpBufLen_);
+  httpBufLen_ += got;
+  httpBuf_[httpBufLen_] = '\0';
+
+  if (got == 0) {
+    throw TTransportException("Could not refill buffer");
+  }
+}
+
+/**
+ * Reads the status line and headers of a response. Blank lines seen before
+ * a final status (i.e. after an interim "100 Continue") make the next line
+ * be parsed as a status line again. Returns once the blank line after a
+ * final status has been read, and clears readHeaders_.
+ */
+void THttpClient::readHeaders() {
+  // Reset per-response framing state.
+  contentLength_ = 0;
+  chunked_ = false;
+  chunkedDone_ = false;
+  chunkSize_ = 0;
+
+  bool expectStatus = true;     // next non-empty line is a status line
+  bool gotFinalStatus = false;  // a 200 status line has been parsed
+
+  for (;;) {
+    char* line = readLine();
+    if (line[0] != '\0') {
+      if (expectStatus) {
+        expectStatus = false;
+        gotFinalStatus = parseStatusLine(line);
+      } else {
+        parseHeader(line);
+      }
+      continue;
+    }
+    if (gotFinalStatus) {
+      readHeaders_ = false;
+      return;
+    }
+    // End of an interim response: another status line follows.
+    expectStatus = true;
+  }
+}
+
+/**
+ * Parses an HTTP status line such as "HTTP/1.1 200 OK". `status` is modified
+ * in place.
+ *
+ * @return true for 200 (the final response); false for 100 Continue (an
+ *         interim response, after which another status line follows)
+ * @throws TTransportException for a line without a space or for any other
+ *         status code; the message contains the complete status line.
+ */
+bool THttpClient::parseStatusLine(char* status) {
+  // The parsing below writes NULs into `status`; keep the whole line so that
+  // error messages are not truncated to just the HTTP version.
+  const string fullLine(status);
+
+  char* code = strchr(status, ' ');
+  if (code == NULL) {
+    throw TTransportException(string("Bad Status: ") + fullLine);
+  }
+
+  // Terminate the version token, then step past every separating space.
+  // The scan must start after the NUL just written, otherwise it stops
+  // immediately and no spaces are skipped.
+  *code++ = '\0';
+  while (*code == ' ') {
+    ++code;
+  }
+
+  // Split the code from the reason phrase. The reason phrase is informational
+  // only, so tolerate servers that omit it.
+  char* msg = strchr(code, ' ');
+  if (msg != NULL) {
+    *msg = '\0';
+  }
+
+  if (strcmp(code, "200") == 0) {
+    // HTTP 200 = OK, we got the response
+    return true;
+  } else if (strcmp(code, "100") == 0) {
+    // HTTP 100 = continue, just keep reading
+    return false;
+  }
+  throw TTransportException(string("Bad Status: ") + fullLine);
+}
+
+/**
+ * Case-insensitively compares the header field name `name` (exactly `len`
+ * bytes, not NUL-terminated) with the NUL-terminated `expected`. The whole
+ * name must match, not just a prefix.
+ */
+static bool headerNameEquals(const char* name, uint32_t len, const char* expected) {
+  for (uint32_t i = 0; i < len; ++i) {
+    char a = name[i];
+    char b = expected[i];
+    if (b == '\0') {
+      return false;  // `name` is longer than `expected`
+    }
+    if (a >= 'A' && a <= 'Z') {
+      a = (char)(a - 'A' + 'a');
+    }
+    if (b >= 'A' && b <= 'Z') {
+      b = (char)(b - 'A' + 'a');
+    }
+    if (a != b) {
+      return false;
+    }
+  }
+  return expected[len] == '\0';
+}
+
+/**
+ * Records the framing-relevant headers from a "Name: value" response line.
+ * Lines without a colon and unrecognized headers are ignored.
+ */
+void THttpClient::parseHeader(char* header) {
+  char* colon = strchr(header, ':');
+  if (colon == NULL) {
+    return;
+  }
+  uint32_t sz = colon - header;
+  char* value = colon + 1;
+
+  // HTTP field names are case-insensitive, and must match in full: a prefix
+  // comparison would let e.g. "Content:" count as Content-Length.
+  if (headerNameEquals(header, sz, "Transfer-Encoding")) {
+    if (strstr(value, "chunked") != NULL) {
+      chunked_ = true;
+    }
+  } else if (headerNameEquals(header, sz, "Content-Length")) {
+    // Transfer-Encoding takes precedence over Content-Length, so this must
+    // not clear chunked_ (readHeaders() resets it for every response).
+    contentLength_ = atoi(value);
+  }
+}
+
+// Requests are buffered in full; flush() sends them as a single POST.
+void THttpClient::write(const uint8_t* buf, uint32_t len) {
+  writeBuffer_.write(buf, len);
+}
+
+/**
+ * Sends the buffered request as one HTTP/1.1 POST to path_ on host_: the
+ * header block first, then the body, then flushes the underlying transport.
+ * Afterwards the write buffer is cleared, and the next read() will parse a
+ * fresh set of response headers.
+ */
+void THttpClient::flush() {
+  uint8_t* body;
+  uint32_t bodyLen;
+  writeBuffer_.getBuffer(&body, &bodyLen);
+
+  std::ostringstream request;
+  request << "POST " << path_ << " HTTP/1.1" << CRLF;
+  request << "Host: " << host_ << CRLF;
+  request << "Content-Type: application/x-thrift" << CRLF;
+  request << "Content-Length: " << bodyLen << CRLF;
+  request << "Accept: application/x-thrift" << CRLF;
+  request << "User-Agent: C++/THttpClient" << CRLF;
+  request << CRLF;
+  const string headerBlock = request.str();
+
+  transport_->write((const uint8_t*)headerBlock.c_str(), headerBlock.size());
+  transport_->write(body, bodyLen);
+  transport_->flush();
+
+  writeBuffer_.resetBuffer();
+  readHeaders_ = true;
+}
+
+}}} // apache::thrift::transport
diff --git a/lib/cpp/src/transport/THttpClient.h b/lib/cpp/src/transport/THttpClient.h
new file mode 100644
index 0000000..f4be4c1
--- /dev/null
+++ b/lib/cpp/src/transport/THttpClient.h
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_TRANSPORT_THTTPCLIENT_H_
+#define _THRIFT_TRANSPORT_THTTPCLIENT_H_ 1
+
+#include <transport/TBufferTransports.h>
+
+namespace apache { namespace thrift { namespace transport {
+
+/**
+ * HTTP client implementation of the thrift transport: a deliberately minimal
+ * HTTP/1.1 client, written to avoid depending on a full HTTP library such as
+ * libcurl. Each flush() sends the buffered bytes as one POST request over the
+ * same underlying transport. Reading a response supports "100 Continue"
+ * interim responses, Content-Length bodies and chunked transfer encoding;
+ * only status 200 is accepted as a final response. Tested against Apache.
+ */
+class THttpClient : public TTransport {
+ public:
+  /**
+   * @param transport underlying byte stream (e.g. a TSocket)
+   * @param host      value sent in the Host request header
+   * @param path      request path used in the POST request line
+   */
+  THttpClient(boost::shared_ptr<TTransport> transport, std::string host, std::string path="");
+
+  /** Same as above, communicating through a new TSocket(host, port). */
+  THttpClient(std::string host, int port, std::string path="");
+
+  virtual ~THttpClient();
+
+  void open() {
+    transport_->open();
+  }
+
+  bool isOpen() {
+    return transport_->isOpen();
+  }
+
+  bool peek() {
+    return transport_->peek();
+  }
+
+  void close() {
+    transport_->close();
+  }
+
+  uint32_t read(uint8_t* buf, uint32_t len);
+
+  void readEnd();
+
+  void write(const uint8_t* buf, uint32_t len);
+
+  void flush();
+
+ private:
+  void init();
+
+  // Not copyable: the destructor free()s the raw httpBuf_ allocation, so a
+  // member-wise copy would free it twice. Declared but intentionally never
+  // defined.
+  THttpClient(const THttpClient&);
+  THttpClient& operator=(const THttpClient&);
+
+ protected:
+
+  boost::shared_ptr<TTransport> transport_;
+
+  // Outgoing request body: filled by write(), sent and cleared by flush().
+  TMemoryBuffer writeBuffer_;
+  // Decoded response body bytes waiting to be returned by read().
+  TMemoryBuffer readBuffer_;
+
+  std::string host_;
+  std::string path_;
+
+  // Response framing state, reset by readHeaders().
+  bool readHeaders_;
+  bool chunked_;
+  bool chunkedDone_;
+  uint32_t chunkSize_;
+  uint32_t contentLength_;
+
+  // Raw bytes read from transport_: malloc'd, httpBufSize_ + 1 bytes, always
+  // NUL-terminated at httpBufLen_. httpPos_ is the first unconsumed byte.
+  char* httpBuf_;
+  uint32_t httpPos_;
+  uint32_t httpBufLen_;
+  uint32_t httpBufSize_;
+
+  uint32_t readMoreData();
+  char* readLine();
+
+  void readHeaders();
+  void parseHeader(char* header);
+  bool parseStatusLine(char* status);
+
+  uint32_t readChunked();
+  void readChunkedFooters();
+  uint32_t parseChunkSize(char* line);
+
+  uint32_t readContent(uint32_t size);
+
+  void refill();
+  void shift();
+
+};
+
+}}} // apache::thrift::transport
+
+#endif // #ifndef _THRIFT_TRANSPORT_THTTPCLIENT_H_
diff --git a/lib/cpp/src/transport/TServerSocket.cpp b/lib/cpp/src/transport/TServerSocket.cpp
new file mode 100644
index 0000000..9b47aa5
--- /dev/null
+++ b/lib/cpp/src/transport/TServerSocket.cpp
@@ -0,0 +1,366 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <cstring>
+#include <sys/socket.h>
+#include <sys/poll.h>
+#include <sys/types.h>
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+#include <netdb.h>
+#include <fcntl.h>
+#include <errno.h>
+
+#include "TSocket.h"
+#include "TServerSocket.h"
+#include <boost/shared_ptr.hpp>
+
+namespace apache { namespace thrift { namespace transport {
+
+using namespace std;
+using boost::shared_ptr;
+
+// Server socket for `port` with defaults: listen backlog 1024, no client
+// send/recv timeouts, no bind retries, OS-default TCP buffer sizes. Nothing
+// is opened until listen().
+TServerSocket::TServerSocket(int port)
+  : port_(port), serverSocket_(-1), acceptBacklog_(1024),
+    sendTimeout_(0), recvTimeout_(0),
+    retryLimit_(0), retryDelay_(0),
+    tcpSendBuffer_(0), tcpRecvBuffer_(0),
+    intSock1_(-1), intSock2_(-1) {
+}
+
+// As TServerSocket(port), but with send/recv timeouts that are applied to
+// every accepted client socket when they are > 0.
+TServerSocket::TServerSocket(int port, int sendTimeout, int recvTimeout)
+  : port_(port), serverSocket_(-1), acceptBacklog_(1024),
+    sendTimeout_(sendTimeout), recvTimeout_(recvTimeout),
+    retryLimit_(0), retryDelay_(0),
+    tcpSendBuffer_(0), tcpRecvBuffer_(0),
+    intSock1_(-1), intSock2_(-1) {
+}
+
+// Releases the listening socket and the interrupt socketpair, if open.
+TServerSocket::~TServerSocket() {
+  this->close();
+}
+
+// Option setters. Each one only records the value. The client timeouts are
+// applied to sockets returned by acceptImpl() when > 0. The bind retry
+// settings and the TCP buffer sizes are consumed by listen().
+void TServerSocket::setSendTimeout(int sendTimeout) { sendTimeout_ = sendTimeout; }
+
+void TServerSocket::setRecvTimeout(int recvTimeout) { recvTimeout_ = recvTimeout; }
+
+void TServerSocket::setRetryLimit(int retryLimit) { retryLimit_ = retryLimit; }
+
+void TServerSocket::setRetryDelay(int retryDelay) { retryDelay_ = retryDelay; }
+
+void TServerSocket::setTcpSendBuffer(int tcpSendBuffer) { tcpSendBuffer_ = tcpSendBuffer; }
+
+void TServerSocket::setTcpRecvBuffer(int tcpRecvBuffer) { tcpRecvBuffer_ = tcpRecvBuffer; }
+
+/**
+ * Creates, configures and binds the listening socket, then calls listen().
+ *
+ * Also creates the socketpair used by interrupt() to wake a blocked
+ * acceptImpl(). Failing to create it is logged but not fatal.
+ *
+ * When getaddrinfo() returns an IPv6 wildcard address, that address is used;
+ * otherwise the last address returned is used.
+ *
+ * @throws TTransportException (NOT_OPEN) if any step fails. close() is
+ *         called before throwing.
+ */
+void TServerSocket::listen() {
+  int sv[2];
+  if (-1 == socketpair(AF_LOCAL, SOCK_STREAM, 0, sv)) {
+    GlobalOutput.perror("TServerSocket::listen() socketpair() ", errno);
+    intSock1_ = -1;
+    intSock2_ = -1;
+  } else {
+    intSock1_ = sv[1];
+    intSock2_ = sv[0];
+  }
+
+  struct addrinfo hints, *res, *res0;
+  int error;
+  // Large enough for any 32-bit int, including a negative one, plus the NUL.
+  char port[sizeof("-2147483648")];
+  std::memset(&hints, 0, sizeof(hints));
+  hints.ai_family = PF_UNSPEC;
+  hints.ai_socktype = SOCK_STREAM;
+  hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG;
+  sprintf(port, "%d", port_);
+
+  // Wildcard address
+  error = getaddrinfo(NULL, port, &hints, &res0);
+  if (error) {
+    GlobalOutput.printf("getaddrinfo %d: %s", error, gai_strerror(error));
+    close();
+    throw TTransportException(TTransportException::NOT_OPEN, "Could not resolve host for server socket.");
+  }
+
+  // Pick the ipv6 address first since ipv4 addresses can be mapped
+  // into ipv6 space.
+  for (res = res0; res; res = res->ai_next) {
+    if (res->ai_family == AF_INET6 || res->ai_next == NULL)
+      break;
+  }
+
+  // Copy everything needed from the chosen entry, then release the list
+  // right away so none of the error paths below can leak it.
+  const int family = res->ai_family;
+  const int socktype = res->ai_socktype;
+  const int protocol = res->ai_protocol;
+  const socklen_t addrLen = res->ai_addrlen;
+  struct sockaddr_storage addr;
+  std::memset(&addr, 0, sizeof(addr));
+  std::memcpy(&addr, res->ai_addr, addrLen);
+  freeaddrinfo(res0);
+
+  serverSocket_ = socket(family, socktype, protocol);
+  if (serverSocket_ == -1) {
+    int errno_copy = errno;
+    GlobalOutput.perror("TServerSocket::listen() socket() ", errno_copy);
+    close();
+    throw TTransportException(TTransportException::NOT_OPEN, "Could not create server socket.", errno_copy);
+  }
+
+  // Set reusaddress to prevent 2MSL delay on accept
+  int one = 1;
+  if (-1 == setsockopt(serverSocket_, SOL_SOCKET, SO_REUSEADDR,
+                       &one, sizeof(one))) {
+    int errno_copy = errno;
+    GlobalOutput.perror("TServerSocket::listen() setsockopt() SO_REUSEADDR ", errno_copy);
+    close();
+    throw TTransportException(TTransportException::NOT_OPEN, "Could not set SO_REUSEADDR", errno_copy);
+  }
+
+  // Set TCP buffer sizes
+  if (tcpSendBuffer_ > 0) {
+    if (-1 == setsockopt(serverSocket_, SOL_SOCKET, SO_SNDBUF,
+                         &tcpSendBuffer_, sizeof(tcpSendBuffer_))) {
+      int errno_copy = errno;
+      GlobalOutput.perror("TServerSocket::listen() setsockopt() SO_SNDBUF ", errno_copy);
+      close();
+      throw TTransportException(TTransportException::NOT_OPEN, "Could not set SO_SNDBUF", errno_copy);
+    }
+  }
+
+  if (tcpRecvBuffer_ > 0) {
+    if (-1 == setsockopt(serverSocket_, SOL_SOCKET, SO_RCVBUF,
+                         &tcpRecvBuffer_, sizeof(tcpRecvBuffer_))) {
+      int errno_copy = errno;
+      GlobalOutput.perror("TServerSocket::listen() setsockopt() SO_RCVBUF ", errno_copy);
+      close();
+      throw TTransportException(TTransportException::NOT_OPEN, "Could not set SO_RCVBUF", errno_copy);
+    }
+  }
+
+  // Defer accept. TCP_DEFER_ACCEPT is a TCP-level option: passing it at the
+  // SOL_SOCKET level would set an unrelated socket option with the same
+  // numeric value instead.
+#ifdef TCP_DEFER_ACCEPT
+  if (-1 == setsockopt(serverSocket_, IPPROTO_TCP, TCP_DEFER_ACCEPT,
+                       &one, sizeof(one))) {
+    int errno_copy = errno;
+    GlobalOutput.perror("TServerSocket::listen() setsockopt() TCP_DEFER_ACCEPT ", errno_copy);
+    close();
+    throw TTransportException(TTransportException::NOT_OPEN, "Could not set TCP_DEFER_ACCEPT", errno_copy);
+  }
+#endif // #ifdef TCP_DEFER_ACCEPT
+
+  // Let an IPv6 socket also accept IPv4-mapped connections. The option only
+  // applies to AF_INET6 sockets, so don't attempt it on IPv4 ones.
+#ifdef IPV6_V6ONLY
+  if (family == AF_INET6) {
+    int zero = 0;
+    if (-1 == setsockopt(serverSocket_, IPPROTO_IPV6, IPV6_V6ONLY,
+                         &zero, sizeof(zero))) {
+      GlobalOutput.perror("TServerSocket::listen() IPV6_V6ONLY ", errno);
+    }
+  }
+#endif // #ifdef IPV6_V6ONLY
+
+  // Turn linger off, don't want to block on calls to close
+  struct linger ling = {0, 0};
+  if (-1 == setsockopt(serverSocket_, SOL_SOCKET, SO_LINGER,
+                       &ling, sizeof(ling))) {
+    int errno_copy = errno;
+    GlobalOutput.perror("TServerSocket::listen() setsockopt() SO_LINGER ", errno_copy);
+    close();
+    throw TTransportException(TTransportException::NOT_OPEN, "Could not set SO_LINGER", errno_copy);
+  }
+
+  // TCP Nodelay, speed over bandwidth
+  if (-1 == setsockopt(serverSocket_, IPPROTO_TCP, TCP_NODELAY,
+                       &one, sizeof(one))) {
+    int errno_copy = errno;
+    GlobalOutput.perror("TServerSocket::listen() setsockopt() TCP_NODELAY ", errno_copy);
+    close();
+    throw TTransportException(TTransportException::NOT_OPEN, "Could not set TCP_NODELAY", errno_copy);
+  }
+
+  // Set NONBLOCK on the accept socket
+  int flags = fcntl(serverSocket_, F_GETFL, 0);
+  if (flags == -1) {
+    int errno_copy = errno;
+    GlobalOutput.perror("TServerSocket::listen() fcntl() F_GETFL ", errno_copy);
+    close();
+    throw TTransportException(TTransportException::NOT_OPEN, "fcntl() failed", errno_copy);
+  }
+
+  if (-1 == fcntl(serverSocket_, F_SETFL, flags | O_NONBLOCK)) {
+    int errno_copy = errno;
+    GlobalOutput.perror("TServerSocket::listen() fcntl() O_NONBLOCK ", errno_copy);
+    close();
+    throw TTransportException(TTransportException::NOT_OPEN, "fcntl() failed", errno_copy);
+  }
+
+  // Bind, making up to retryLimit_ extra attempts spaced retryDelay_ seconds
+  // apart, since SO_REUSEADDR doesn't always make the port reusable at once.
+  // Success is tracked explicitly: an interrupted sleep() must not end the
+  // retries and leave an unbound socket that is then reported as success.
+  int retries = 0;
+  int bindErrno = 0;
+  bool bound = false;
+  while (true) {
+    if (0 == bind(serverSocket_, (struct sockaddr*)&addr, addrLen)) {
+      bound = true;
+      break;
+    }
+    bindErrno = errno;
+    if (retries++ >= retryLimit_) {
+      break;
+    }
+    sleep(retryDelay_);
+  }
+
+  // throw an error if we failed to bind properly
+  if (!bound) {
+    char errbuf[1024];
+    sprintf(errbuf, "TServerSocket::listen() BIND %d", port_);
+    GlobalOutput(errbuf);
+    close();
+    throw TTransportException(TTransportException::NOT_OPEN, "Could not bind", bindErrno);
+  }
+
+  // Call listen
+  if (-1 == ::listen(serverSocket_, acceptBacklog_)) {
+    int errno_copy = errno;
+    GlobalOutput.perror("TServerSocket::listen() listen() ", errno_copy);
+    close();
+    throw TTransportException(TTransportException::NOT_OPEN, "Could not listen", errno_copy);
+  }
+
+  // The socket is now listening!
+}
+
+/**
+ * Waits for and accepts one client connection.
+ *
+ * Blocks in poll() on the listening socket and, when present, the interrupt
+ * end of the socketpair. Up to 5 EINTR results from poll() are retried per
+ * call.
+ *
+ * @return a blocking TSocket for the client, with the configured send/recv
+ *         timeouts applied when they are > 0
+ * @throws TTransportException NOT_OPEN if not listening, INTERRUPTED after
+ *         interrupt(), UNKNOWN on poll/accept/fcntl failures
+ */
+shared_ptr<TTransport> TServerSocket::acceptImpl() {
+  if (serverSocket_ < 0) {
+    throw TTransportException(TTransportException::NOT_OPEN, "TServerSocket not listening");
+  }
+
+  struct pollfd fds[2];
+
+  int maxEintrs = 5;
+  int numEintrs = 0;
+
+  while (true) {
+    std::memset(fds, 0, sizeof(fds));
+    fds[0].fd = serverSocket_;
+    fds[0].events = POLLIN;
+    // Only poll the interrupt socket if it exists. A zeroed fds[1] would make
+    // poll() watch fd 0 (stdin), whose POLLHUP/POLLERR could wake this loop
+    // over and over without a connection being ready.
+    int numFds = 1;
+    if (intSock2_ >= 0) {
+      fds[1].fd = intSock2_;
+      fds[1].events = POLLIN;
+      numFds = 2;
+    }
+    int ret = poll(fds, numFds, -1);
+
+    if (ret < 0) {
+      // error cases
+      if (errno == EINTR && (numEintrs++ < maxEintrs)) {
+        // EINTR needs to be handled manually and we can tolerate
+        // a certain number
+        continue;
+      }
+      int errno_copy = errno;
+      GlobalOutput.perror("TServerSocket::acceptImpl() poll() ", errno_copy);
+      throw TTransportException(TTransportException::UNKNOWN, "Unknown", errno_copy);
+    } else if (ret > 0) {
+      // Check for an interrupt signal
+      if (intSock2_ >= 0 && (fds[1].revents & POLLIN)) {
+        int8_t buf;
+        if (-1 == recv(intSock2_, &buf, sizeof(int8_t), 0)) {
+          GlobalOutput.perror("TServerSocket::acceptImpl() recv() interrupt ", errno);
+        }
+        throw TTransportException(TTransportException::INTERRUPTED);
+      }
+
+      // Check for the actual server socket being ready
+      if (fds[0].revents & POLLIN) {
+        break;
+      }
+    } else {
+      GlobalOutput("TServerSocket::acceptImpl() poll 0");
+      throw TTransportException(TTransportException::UNKNOWN);
+    }
+  }
+
+  struct sockaddr_storage clientAddress;
+  // accept() writes through a socklen_t*; casting an int* to it is only
+  // correct where both types happen to have the same size.
+  socklen_t size = sizeof(clientAddress);
+  int clientSocket = ::accept(serverSocket_,
+                              (struct sockaddr *) &clientAddress,
+                              &size);
+
+  if (clientSocket < 0) {
+    int errno_copy = errno;
+    GlobalOutput.perror("TServerSocket::acceptImpl() ::accept() ", errno_copy);
+    throw TTransportException(TTransportException::UNKNOWN, "accept()", errno_copy);
+  }
+
+  // Make sure client socket is blocking. No TSocket owns the descriptor yet,
+  // so close it ourselves on failure instead of leaking it.
+  int flags = fcntl(clientSocket, F_GETFL, 0);
+  if (flags == -1) {
+    int errno_copy = errno;
+    GlobalOutput.perror("TServerSocket::acceptImpl() fcntl() F_GETFL ", errno_copy);
+    ::close(clientSocket);
+    throw TTransportException(TTransportException::UNKNOWN, "fcntl(F_GETFL)", errno_copy);
+  }
+
+  if (-1 == fcntl(clientSocket, F_SETFL, flags & ~O_NONBLOCK)) {
+    int errno_copy = errno;
+    GlobalOutput.perror("TServerSocket::acceptImpl() fcntl() F_SETFL ~O_NONBLOCK ", errno_copy);
+    ::close(clientSocket);
+    throw TTransportException(TTransportException::UNKNOWN, "fcntl(F_SETFL)", errno_copy);
+  }
+
+  shared_ptr<TSocket> client(new TSocket(clientSocket));
+  if (sendTimeout_ > 0) {
+    client->setSendTimeout(sendTimeout_);
+  }
+  if (recvTimeout_ > 0) {
+    client->setRecvTimeout(recvTimeout_);
+  }
+
+  return client;
+}
+
+/**
+ * Wakes a thread blocked in acceptImpl() by sending one byte on this end of
+ * the socketpair; acceptImpl() then throws INTERRUPTED. Does nothing if the
+ * socketpair was never created. A failed send() is only logged.
+ */
+void TServerSocket::interrupt() {
+  if (intSock1_ < 0) {
+    return;
+  }
+  int8_t wakeByte = 0;
+  if (send(intSock1_, &wakeByte, sizeof(wakeByte), 0) == -1) {
+    GlobalOutput.perror("TServerSocket::interrupt() send() ", errno);
+  }
+}
+
+/**
+ * Shuts down and closes the listening socket, then closes both ends of the
+ * interrupt socketpair. Every descriptor is reset to -1, so repeated calls
+ * (e.g. from a failed listen() and later the destructor) are harmless.
+ */
+void TServerSocket::close() {
+  if (serverSocket_ >= 0) {
+    shutdown(serverSocket_, SHUT_RDWR);
+    ::close(serverSocket_);
+  }
+  serverSocket_ = -1;
+
+  if (intSock1_ >= 0) {
+    ::close(intSock1_);
+  }
+  intSock1_ = -1;
+
+  if (intSock2_ >= 0) {
+    ::close(intSock2_);
+  }
+  intSock2_ = -1;
+}
+
+}}} // apache::thrift::transport
diff --git a/lib/cpp/src/transport/TServerSocket.h b/lib/cpp/src/transport/TServerSocket.h
new file mode 100644
index 0000000..a6be017
--- /dev/null
+++ b/lib/cpp/src/transport/TServerSocket.h
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_TRANSPORT_TSERVERSOCKET_H_
+#define _THRIFT_TRANSPORT_TSERVERSOCKET_H_ 1
+
+#include "TServerTransport.h"
+#include <boost/shared_ptr.hpp>
+
+namespace apache { namespace thrift { namespace transport {
+
+class TSocket;
+
+/**
+ * Server socket implementation of TServerTransport. Wrapper around a unix
+ * socket listen and accept calls.
+ *
+ * listen() binds a TCP wildcard address on the configured port; accept()
+ * returns a TSocket per client. A socketpair created by listen() lets
+ * interrupt() wake up a thread blocked in accept().
+ */
+class TServerSocket : public TServerTransport {
+ public:
+ TServerSocket(int port);
+ // sendTimeout/recvTimeout are applied to each accepted client TSocket when
+ // > 0 (units as interpreted by TSocket::setSendTimeout/setRecvTimeout).
+ TServerSocket(int port, int sendTimeout, int recvTimeout);
+
+ ~TServerSocket();
+
+ // Client socket timeouts; used by later accept() calls when > 0.
+ void setSendTimeout(int sendTimeout);
+ void setRecvTimeout(int recvTimeout);
+
+ // Number of extra bind() attempts made by listen(), and the delay between
+ // them in seconds (passed to sleep()).
+ void setRetryLimit(int retryLimit);
+ void setRetryDelay(int retryDelay);
+
+ // SO_SNDBUF / SO_RCVBUF for the listening socket; only set when > 0.
+ void setTcpSendBuffer(int tcpSendBuffer);
+ void setTcpRecvBuffer(int tcpRecvBuffer);
+
+ // listen() throws TTransportException on failure; close() is idempotent.
+ void listen();
+ void close();
+
+ // Makes a blocked accept() throw TTransportException(INTERRUPTED).
+ void interrupt();
+
+ protected:
+ boost::shared_ptr<TTransport> acceptImpl();
+
+ private:
+ int port_;
+ int serverSocket_; // listening descriptor, -1 when closed
+ int acceptBacklog_; // backlog passed to ::listen()
+ int sendTimeout_;
+ int recvTimeout_;
+ int retryLimit_;
+ int retryDelay_;
+ int tcpSendBuffer_;
+ int tcpRecvBuffer_;
+
+ // Ends of the interrupt socketpair: interrupt() writes to intSock1_,
+ // acceptImpl() polls intSock2_. -1 when absent.
+ int intSock1_;
+ int intSock2_;
+};
+
+}}} // apache::thrift::transport
+
+#endif // #ifndef _THRIFT_TRANSPORT_TSERVERSOCKET_H_
diff --git a/lib/cpp/src/transport/TServerTransport.h b/lib/cpp/src/transport/TServerTransport.h
new file mode 100644
index 0000000..40bbc6c
--- /dev/null
+++ b/lib/cpp/src/transport/TServerTransport.h
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_TRANSPORT_TSERVERTRANSPORT_H_
+#define _THRIFT_TRANSPORT_TSERVERTRANSPORT_H_ 1
+
+#include "TTransport.h"
+#include "TTransportException.h"
+#include <boost/shared_ptr.hpp>
+
+namespace apache { namespace thrift { namespace transport {
+
+/**
+ * Server transport framework. A server needs to have some facility for
+ * creating base transports to read/write from.
+ *
+ * Subclasses implement acceptImpl() and close(); listen() and interrupt()
+ * default to doing nothing.
+ */
+class TServerTransport {
+ public:
+ virtual ~TServerTransport() {}
+
+ /**
+ * Starts the server transport listening for new connections. Prior to this
+ * call most transports will not return anything when accept is called.
+ * The default implementation does nothing.
+ *
+ * @throws TTransportException if we were unable to listen
+ */
+ virtual void listen() {}
+
+ /**
+ * Gets a new transport object for the next client from acceptImpl(). The
+ * object is returned in a boost::shared_ptr, so its lifetime is managed by
+ * the smart pointer; the caller does not free it explicitly. The returned
+ * TTransport object must always be in the opened state. NULL is never
+ * returned: if acceptImpl() yields an empty pointer, this throws.
+ *
+ * @return A new TTransport object
+ * @throws TTransportException if there is an error, or if acceptImpl()
+ * returned NULL
+ */
+ boost::shared_ptr<TTransport> accept() {
+ boost::shared_ptr<TTransport> result = acceptImpl();
+ if (result == NULL) {
+ throw TTransportException("accept() may not return NULL");
+ }
+ return result;
+ }
+
+ /**
+ * For "smart" TServerTransport implementations that work in a multi
+ * threaded context this can be used to break out of an accept() call.
+ * It is expected that the transport will throw a TTransportException
+ * with the interrupted error code. The default implementation does nothing.
+ */
+ virtual void interrupt() {}
+
+ /**
+ * Closes this transport. Afterwards accept() is expected to fail rather
+ * than hand out new transports (exact behavior is up to the subclass).
+ */
+ virtual void close() = 0;
+
+ protected:
+ TServerTransport() {}
+
+ /**
+ * Subclasses should implement this function for accept.
+ *
+ * @return A newly allocated TTransport object (must not be NULL)
+ * @throw TTransportException If an error occurs
+ */
+ virtual boost::shared_ptr<TTransport> acceptImpl() = 0;
+
+};
+
+}}} // apache::thrift::transport
+
+#endif // #ifndef _THRIFT_TRANSPORT_TSERVERTRANSPORT_H_
diff --git a/lib/cpp/src/transport/TShortReadTransport.h b/lib/cpp/src/transport/TShortReadTransport.h
new file mode 100644
index 0000000..3df8a57
--- /dev/null
+++ b/lib/cpp/src/transport/TShortReadTransport.h
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_TRANSPORT_TSHORTREADTRANSPORT_H_
+#define _THRIFT_TRANSPORT_TSHORTREADTRANSPORT_H_ 1
+
+#include <cstdlib>
+
+#include <transport/TTransport.h>
+
+namespace apache { namespace thrift { namespace transport { namespace test {
+
+/**
+ * Test-only decorator that forwards every call to a wrapped transport but
+ * randomly shortens reads, to exercise callers' handling of partial reads.
+ * Each read() asks for the full length with probability of roughly
+ * `full_prob`; otherwise it asks for a random length in [1, len]. It uses
+ * rand(), so the pattern depends on the global srand() seed.
+ */
+class TShortReadTransport : public TTransport {
+ public:
+  /**
+   * @param transport transport that every call is forwarded to
+   * @param full_prob probability that a read is passed through unshortened
+   */
+  TShortReadTransport(boost::shared_ptr<TTransport> transport, double full_prob)
+    : transport_(transport), fullProb_(full_prob) {}
+
+  bool isOpen() { return transport_->isOpen(); }
+
+  bool peek() { return transport_->peek(); }
+
+  void open() { transport_->open(); }
+
+  void close() { transport_->close(); }
+
+  uint32_t read(uint8_t* buf, uint32_t len) {
+    if (len == 0) {
+      return 0;
+    }
+    const bool shorten = rand() / (double)RAND_MAX >= fullProb_;
+    const uint32_t request = shorten ? 1 + rand() % len : len;
+    return transport_->read(buf, request);
+  }
+
+  void write(const uint8_t* buf, uint32_t len) { transport_->write(buf, len); }
+
+  void flush() { transport_->flush(); }
+
+  const uint8_t* borrow(uint8_t* buf, uint32_t* len) { return transport_->borrow(buf, len); }
+
+  void consume(uint32_t len) { transport_->consume(len); }
+
+  boost::shared_ptr<TTransport> getUnderlyingTransport() { return transport_; }
+
+ protected:
+  boost::shared_ptr<TTransport> transport_;
+  double fullProb_;
+};
+
+}}}} // apache::thrift::transport::test
+
+#endif // #ifndef _THRIFT_TRANSPORT_TSHORTREADTRANSPORT_H_
diff --git a/lib/cpp/src/transport/TSimpleFileTransport.cpp b/lib/cpp/src/transport/TSimpleFileTransport.cpp
new file mode 100644
index 0000000..e58a574
--- /dev/null
+++ b/lib/cpp/src/transport/TSimpleFileTransport.cpp
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "TSimpleFileTransport.h"
+
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <fcntl.h>
+#include <errno.h>
+
+namespace apache { namespace thrift { namespace transport {
+
+/**
+ * Opens `path` with open(2) and hands the descriptor to TFDTransport, which
+ * was constructed with CLOSE_ON_DESTROY.
+ *
+ * @param path  file to open
+ * @param read  open for reading
+ * @param write open for writing; the file is created if missing (mode 0644,
+ *              subject to umask) and every write appends (O_CREAT | O_APPEND)
+ * @throws TTransportException if neither read nor write is requested, or if
+ *         open(2) fails (errno is passed to the exception)
+ */
+TSimpleFileTransport::
+TSimpleFileTransport(const std::string& path, bool read, bool write)
+  : TFDTransport(-1, TFDTransport::CLOSE_ON_DESTROY) {
+  int flags = 0;
+  if (read && write) {
+    flags = O_RDWR;
+  } else if (read) {
+    flags = O_RDONLY;
+  } else if (write) {
+    flags = O_WRONLY;
+  } else {
+    throw TTransportException("Neither READ nor WRITE specified");
+  }
+  if (write) {
+    flags |= O_CREAT | O_APPEND;
+  }
+  int fd = ::open(path.c_str(),
+                  flags,
+                  S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH);
+  if (fd < 0) {
+    // Capture errno before building the message (allocation may clobber
+    // it), and name the access mode that was actually requested.
+    int errno_copy = errno;
+    const char* mode = (read && write) ? "reading and writing"
+                                       : (write ? "writing" : "reading");
+    throw TTransportException(TTransportException::UNKNOWN,
+                              std::string("failed to open file for ") + mode + ": " + path,
+                              errno_copy);
+  }
+  setFD(fd);
+  open();
+}
+
+}}} // apache::thrift::transport
diff --git a/lib/cpp/src/transport/TSimpleFileTransport.h b/lib/cpp/src/transport/TSimpleFileTransport.h
new file mode 100644
index 0000000..6cc52ea
--- /dev/null
+++ b/lib/cpp/src/transport/TSimpleFileTransport.h
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_TRANSPORT_TSIMPLEFILETRANSPORT_H_
+#define _THRIFT_TRANSPORT_TSIMPLEFILETRANSPORT_H_ 1
+
+#include "TFDTransport.h"
+
+namespace apache { namespace thrift { namespace transport {
+
+/**
+ * Dead-simple wrapper around a file.
+ *
+ * Writeable files are opened with O_CREAT and O_APPEND, so the file is
+ * created if it does not exist and writes always go to its end. The
+ * descriptor is owned by the transport and closed when it is destroyed.
+ */
+class TSimpleFileTransport : public TFDTransport {
+ public:
+  /**
+   * Opens the file at `path` (the constructor performs the open).
+   *
+   * @param path  file to open
+   * @param read  whether to open the file for reading
+   * @param write whether to open the file for writing
+   * @throws TTransportException if neither read nor write is requested,
+   *         or if the file cannot be opened
+   */
+  TSimpleFileTransport(const std::string& path,
+                       bool read = true,
+                       bool write = false);
+};
+
+}}} // apache::thrift::transport
+
+#endif // _THRIFT_TRANSPORT_TSIMPLEFILETRANSPORT_H_
diff --git a/lib/cpp/src/transport/TSocket.cpp b/lib/cpp/src/transport/TSocket.cpp
new file mode 100644
index 0000000..3395dab
--- /dev/null
+++ b/lib/cpp/src/transport/TSocket.cpp
@@ -0,0 +1,633 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <config.h>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <sstream>
+#include <sys/socket.h>
+#include <sys/poll.h>
+#include <sys/types.h>
+#include <arpa/inet.h>
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+#include <netdb.h>
+#include <unistd.h>
+#include <errno.h>
+#include <fcntl.h>
+
+#include "concurrency/Monitor.h"
+#include "TSocket.h"
+#include "TTransportException.h"
+
+namespace apache { namespace thrift { namespace transport {
+
+using namespace std;
+
+// Global var to track total socket sys calls
+uint32_t g_socket_syscalls = 0;
+
+/**
+ * TSocket implementation.
+ *
+ */
+
+TSocket::TSocket(string host, int port) :
+  host_(host),
+  peerPort_(0),
+  port_(port),
+  socket_(-1),
+  connTimeout_(0),
+  sendTimeout_(0),
+  recvTimeout_(0),
+  lingerOn_(1),
+  lingerVal_(0),
+  noDelay_(1),
+  maxRecvRetries_(5) {
+  recvTimeval_.tv_sec = (int)(recvTimeout_/1000);
+  recvTimeval_.tv_usec = (int)((recvTimeout_%1000)*1000);
+}
+
+TSocket::TSocket() :
+  host_(""),
+  peerPort_(0),
+  port_(0),
+  socket_(-1),
+  connTimeout_(0),
+  sendTimeout_(0),
+  recvTimeout_(0),
+  lingerOn_(1),
+  lingerVal_(0),
+  noDelay_(1),
+  maxRecvRetries_(5) {
+  recvTimeval_.tv_sec = (int)(recvTimeout_/1000);
+  recvTimeval_.tv_usec = (int)((recvTimeout_%1000)*1000);
+}
+
+TSocket::TSocket(int socket) :
+  host_(""),
+  peerPort_(0),
+  port_(0),
+  socket_(socket),
+  connTimeout_(0),
+  sendTimeout_(0),
+  recvTimeout_(0),
+  lingerOn_(1),
+  lingerVal_(0),
+  noDelay_(1),
+  maxRecvRetries_(5) {
+  recvTimeval_.tv_sec = (int)(recvTimeout_/1000);
+  recvTimeval_.tv_usec = (int)((recvTimeout_%1000)*1000);
+}
+
+TSocket::~TSocket() {
+  close();
+}
+
+bool TSocket::isOpen() {
+  return (socket_ >= 0);
+}
+
+/**
+ * Peeks at one byte with recv(MSG_PEEK) without consuming it. Returns
+ * false if the socket is not open or the peer has shut down, true if data
+ * is pending. Other recv() failures are logged and thrown (on FreeBSD an
+ * ECONNRESET instead closes the socket and returns false).
+ */
+bool TSocket::peek() {
+  if (!isOpen()) {
+    return false;
+  }
+  uint8_t buf;
+  int r = recv(socket_, &buf, 1, MSG_PEEK);
+  if (r == -1) {
+    int errno_copy = errno;
+    #ifdef __FreeBSD__
+    /* shigin:
+     * freebsd returns -1 and ECONNRESET if socket was closed by
+     * the other side
+     */
+    if (errno_copy == ECONNRESET)
+    {
+      close();
+      return false;
+    }
+    #endif
+    GlobalOutput.perror("TSocket::peek() recv() " + getSocketInfo(), errno_copy);
+    throw TTransportException(TTransportException::UNKNOWN, "recv()", errno_copy);
+  }
+  return (r > 0);
+}
+
+/**
+ * Creates a socket for the resolved address `res`, applies the configured
+ * timeout/linger/nodelay options and connects it. When connTimeout_ > 0 the
+ * connect is non-blocking and bounded by poll(); the socket's original file
+ * status flags are restored afterwards.
+ */
+void TSocket::openConnection(struct addrinfo *res) {
+  if (isOpen()) {
+    throw TTransportException(TTransportException::ALREADY_OPEN);
+  }
+
+  socket_ = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
+  if (socket_ == -1) {
+    int errno_copy = errno;
+    GlobalOutput.perror("TSocket::open() socket() " + getSocketInfo(), errno_copy);
+    throw TTransportException(TTransportException::NOT_OPEN, "socket()", errno_copy);
+  }
+
+  // Send timeout
+  if (sendTimeout_ > 0) {
+    setSendTimeout(sendTimeout_);
+  }
+
+  // Recv timeout
+  if (recvTimeout_ > 0) {
+    setRecvTimeout(recvTimeout_);
+  }
+
+  // Linger
+  setLinger(lingerOn_, lingerVal_);
+
+  // No delay
+  setNoDelay(noDelay_);
+
+  // Set the socket to be non blocking for connect if a timeout exists
+  int flags = fcntl(socket_, F_GETFL, 0);
+  if (flags == -1) {
+    int errno_copy = errno;
+    GlobalOutput.perror("TSocket::open() fcntl() " + getSocketInfo(), errno_copy);
+    throw TTransportException(TTransportException::NOT_OPEN, "fcntl() failed", errno_copy);
+  }
+  if (connTimeout_ > 0) {
+    if (-1 == fcntl(socket_, F_SETFL, flags | O_NONBLOCK)) {
+      int errno_copy = errno;
+      GlobalOutput.perror("TSocket::open() fcntl() " + getSocketInfo(), errno_copy);
+      throw TTransportException(TTransportException::NOT_OPEN, "fcntl() failed", errno_copy);
+    }
+  } else {
+    if (-1 == fcntl(socket_, F_SETFL, flags & ~O_NONBLOCK)) {
+      int errno_copy = errno;
+      GlobalOutput.perror("TSocket::open() fcntl " + getSocketInfo(), errno_copy);
+      throw TTransportException(TTransportException::NOT_OPEN, "fcntl() failed", errno_copy);
+    }
+  }
+
+  // Connect the socket
+  int ret = connect(socket_, res->ai_addr, res->ai_addrlen);
+
+  // success case
+  if (ret == 0) {
+    goto done;
+  }
+
+  if (errno != EINPROGRESS) {
+    int errno_copy = errno;
+    GlobalOutput.perror("TSocket::open() connect() " + getSocketInfo(), errno_copy);
+    throw TTransportException(TTransportException::NOT_OPEN, "connect() failed", errno_copy);
+  }
+
+
+  struct pollfd fds[1];
+  std::memset(fds, 0 , sizeof(fds));
+  fds[0].fd = socket_;
+  fds[0].events = POLLOUT;
+  ret = poll(fds, 1, connTimeout_);
+
+  if (ret > 0) {
+    // Ensure the socket is connected and that there are no errors set
+    int val;
+    socklen_t lon;
+    lon = sizeof(int);
+    int ret2 = getsockopt(socket_, SOL_SOCKET, SO_ERROR, (void *)&val, &lon);
+    if (ret2 == -1) {
+      int errno_copy = errno;
+      GlobalOutput.perror("TSocket::open() getsockopt() " + getSocketInfo(), errno_copy);
+      throw TTransportException(TTransportException::NOT_OPEN, "getsockopt()", errno_copy);
+    }
+    // no errors on socket, go to town
+    if (val == 0) {
+      goto done;
+    }
+    GlobalOutput.perror("TSocket::open() error on socket (after poll) " + getSocketInfo(), val);
+    throw TTransportException(TTransportException::NOT_OPEN, "socket open() error", val);
+  } else if (ret == 0) {
+    // socket timed out
+    string errStr = "TSocket::open() timed out " + getSocketInfo();
+    GlobalOutput(errStr.c_str());
+    throw TTransportException(TTransportException::NOT_OPEN, "open() timed out");
+  } else {
+    // error on poll()
+    int errno_copy = errno;
+    GlobalOutput.perror("TSocket::open() poll() " + getSocketInfo(), errno_copy);
+    throw TTransportException(TTransportException::NOT_OPEN, "poll() failed", errno_copy);
+  }
+
+ done:
+  // Restore the file status flags saved before connect()
+  fcntl(socket_, F_SETFL, flags);
+}
+
+/**
+ * Resolves host_:port_ with getaddrinfo() and tries each returned address
+ * in turn until one connects. If every address fails, the exception from
+ * the last attempt is rethrown.
+ *
+ * @throws TTransportException ALREADY_OPEN, or NOT_OPEN if the port is out
+ *         of range, the host cannot be resolved, or no address connects
+ */
+void TSocket::open() {
+  if (isOpen()) {
+    throw TTransportException(TTransportException::ALREADY_OPEN);
+  }
+
+  // Validate port number: TCP ports are 16-bit (0-65535)
+  if (port_ < 0 || port_ > 65535) {
+    throw TTransportException(TTransportException::NOT_OPEN, "Specified port is invalid");
+  }
+
+  struct addrinfo hints, *res, *res0;
+  res = NULL;
+  res0 = NULL;
+  int error;
+  char port[sizeof("65535")];
+  std::memset(&hints, 0, sizeof(hints));
+  hints.ai_family = PF_UNSPEC;
+  hints.ai_socktype = SOCK_STREAM;
+  hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG;
+  sprintf(port, "%d", port_);
+
+  error = getaddrinfo(host_.c_str(), port, &hints, &res0);
+
+  if (error) {
+    string errStr = "TSocket::open() getaddrinfo() " + getSocketInfo() + string(gai_strerror(error));
+    GlobalOutput(errStr.c_str());
+    close();
+    throw TTransportException(TTransportException::NOT_OPEN, "Could not resolve host for client socket.");
+  }
+
+  // Cycle through all the returned addresses until one
+  // connects or push the exception up.
+  for (res = res0; res; res = res->ai_next) {
+    try {
+      openConnection(res);
+      break;
+    } catch (TTransportException&) {
+      if (res->ai_next) {
+        close();
+      } else {
+        close();
+        freeaddrinfo(res0); // cleanup on failure
+        throw;
+      }
+    }
+  }
+
+  // Free address structure memory
+  freeaddrinfo(res0);
+}
+
+void TSocket::close() {
+  if (socket_ >= 0) {
+    shutdown(socket_, SHUT_RDWR);
+    ::close(socket_);
+  }
+  socket_ = -1;
+}
+
+/**
+ * Reads at most `len` bytes into `buf`. Returns 0 (and closes the socket)
+ * when the peer has closed the connection.
+ *
+ * EINTR is retried up to maxRecvRetries_ times. So is EAGAIN when no
+ * receive timeout is set, or when it arrives well before the timeout could
+ * have expired (taken to mean a transient resource shortage); otherwise
+ * EAGAIN is reported as TIMED_OUT.
+ */
+uint32_t TSocket::read(uint8_t* buf, uint32_t len) {
+  if (socket_ < 0) {
+    throw TTransportException(TTransportException::NOT_OPEN, "Called read on non-open socket");
+  }
+
+  int32_t retries = 0;
+
+  // EAGAIN can be signalled both when a timeout has occurred and when
+  // the system is out of resources (an awesome undocumented feature).
+  // The following is an approximation of the time interval under which
+  // EAGAIN is taken to indicate an out of resources error.
+  uint32_t eagainThresholdMicros = 0;
+  if (recvTimeout_) {
+    // if a readTimeout is specified along with a max number of recv retries, then
+    // the threshold will ensure that the read timeout is not exceeded even in the
+    // case of resource errors
+    eagainThresholdMicros = (recvTimeout_*1000)/ ((maxRecvRetries_>0) ? maxRecvRetries_ : 2);
+  }
+
+ try_again:
+  // Read from the socket
+  struct timeval begin;
+  gettimeofday(&begin, NULL);
+  int got = recv(socket_, buf, len, 0);
+  int errno_copy = errno; //gettimeofday can change errno
+  struct timeval end;
+  gettimeofday(&end, NULL);
+  uint32_t readElapsedMicros = (((end.tv_sec - begin.tv_sec) * 1000 * 1000)
+                               + (((uint64_t)(end.tv_usec - begin.tv_usec))));
+  ++g_socket_syscalls;
+
+  // Check for error on read
+  if (got < 0) {
+    if (errno_copy == EAGAIN) {
+      // check if this is the lack of resources or timeout case
+      if (!eagainThresholdMicros || (readElapsedMicros < eagainThresholdMicros)) {
+        if (retries++ < maxRecvRetries_) {
+          usleep(50);
+          goto try_again;
+        } else {
+          throw TTransportException(TTransportException::TIMED_OUT,
+                                    "EAGAIN (unavailable resources)");
+        }
+      } else {
+        // infer that timeout has been hit
+        throw TTransportException(TTransportException::TIMED_OUT,
+                                  "EAGAIN (timed out)");
+      }
+    }
+
+    // If interrupted, try again
+    if (errno_copy == EINTR && retries++ < maxRecvRetries_) {
+      goto try_again;
+    }
+
+    // Not a retryable condition: this is a real error
+    GlobalOutput.perror("TSocket::read() recv() " + getSocketInfo(), errno_copy);
+
+    // If we disconnect with no linger time
+    if (errno_copy == ECONNRESET) {
+      #ifdef __FreeBSD__
+      /* shigin: freebsd doesn't follow POSIX semantic of recv and fails with
+       * ECONNRESET if peer performed shutdown
+       */
+      close();
+      return 0;
+      #else
+      throw TTransportException(TTransportException::NOT_OPEN, "ECONNRESET");
+      #endif
+    }
+
+    // The socket is not connected
+    if (errno_copy == ENOTCONN) {
+      throw TTransportException(TTransportException::NOT_OPEN, "ENOTCONN");
+    }
+
+    // Timed out!
+    if (errno_copy == ETIMEDOUT) {
+      throw TTransportException(TTransportException::TIMED_OUT, "ETIMEDOUT");
+    }
+
+    // Any other error
+    throw TTransportException(TTransportException::UNKNOWN, "Unknown", errno_copy);
+  }
+
+  // The remote host has closed the socket
+  if (got == 0) {
+    close();
+    return 0;
+  }
+
+  // Return the number of bytes read
+  return got;
+}
+
+/**
+ * Writes all `len` bytes, looping over partial sends. EPIPE, ECONNRESET
+ * and ENOTCONN close the socket and raise NOT_OPEN; other send() errors
+ * raise UNKNOWN. SIGPIPE is suppressed with MSG_NOSIGNAL where available.
+ */
+void TSocket::write(const uint8_t* buf, uint32_t len) {
+  if (socket_ < 0) {
+    throw TTransportException(TTransportException::NOT_OPEN, "Called write on non-open socket");
+  }
+
+  uint32_t sent = 0;
+
+  while (sent < len) {
+
+    int flags = 0;
+    #ifdef MSG_NOSIGNAL
+    // Note the use of MSG_NOSIGNAL to suppress SIGPIPE errors, instead we
+    // check for the EPIPE return condition and close the socket in that case
+    flags |= MSG_NOSIGNAL;
+    #endif // ifdef MSG_NOSIGNAL
+
+    int b = send(socket_, buf + sent, len - sent, flags);
+    ++g_socket_syscalls;
+
+    // Fail on a send error
+    if (b < 0) {
+      int errno_copy = errno;
+      GlobalOutput.perror("TSocket::write() send() " + getSocketInfo(), errno_copy);
+      // Use the saved copy: the logging call above may have changed errno.
+      if (errno_copy == EPIPE || errno_copy == ECONNRESET || errno_copy == ENOTCONN) {
+        close();
+        throw TTransportException(TTransportException::NOT_OPEN, "write() send()", errno_copy);
+      }
+
+      throw TTransportException(TTransportException::UNKNOWN, "write() send()", errno_copy);
+    }
+
+    // Fail on blocked send
+    if (b == 0) {
+      throw TTransportException(TTransportException::NOT_OPEN, "Socket send returned 0.");
+    }
+    sent += b;
+  }
+}
+
+std::string TSocket::getHost() {
+  return host_;
+}
+
+int TSocket::getPort() {
+  return port_;
+}
+
+void TSocket::setHost(string host) {
+  host_ = host;
+}
+
+void TSocket::setPort(int port) {
+  port_ = port;
+}
+
+void TSocket::setLinger(bool on, int linger) {
+  lingerOn_ = on;
+  lingerVal_ = linger;
+  if (socket_ < 0) {
+    return;
+  }
+
+  struct linger l = {(lingerOn_ ? 1 : 0), lingerVal_};
+  int ret = setsockopt(socket_, SOL_SOCKET, SO_LINGER, &l, sizeof(l));
+  if (ret == -1) {
+    int errno_copy = errno;  // Copy errno because we're allocating memory.
+    GlobalOutput.perror("TSocket::setLinger() setsockopt() " + getSocketInfo(), errno_copy);
+  }
+}
+
+void TSocket::setNoDelay(bool noDelay) {
+  noDelay_ = noDelay;
+  if (socket_ < 0) {
+    return;
+  }
+
+  // Set socket to NODELAY
+  int v = noDelay_ ? 1 : 0;
+  int ret = setsockopt(socket_, IPPROTO_TCP, TCP_NODELAY, &v, sizeof(v));
+  if (ret == -1) {
+    int errno_copy = errno;  // Copy errno because we're allocating memory.
+    GlobalOutput.perror("TSocket::setNoDelay() setsockopt() " + getSocketInfo(), errno_copy);
+  }
+}
+
+void TSocket::setConnTimeout(int ms) {
+  connTimeout_ = ms;
+}
+
+void TSocket::setRecvTimeout(int ms) {
+  if (ms < 0) {
+    char errBuf[512];
+    sprintf(errBuf, "TSocket::setRecvTimeout with negative input: %d\n", ms);
+    GlobalOutput(errBuf);
+    return;
+  }
+  recvTimeout_ = ms;
+
+  if (socket_ < 0) {
+    return;
+  }
+
+  recvTimeval_.tv_sec = (int)(recvTimeout_/1000);
+  recvTimeval_.tv_usec = (int)((recvTimeout_%1000)*1000);
+
+  // Pass a local copy of the cached timeval to setsockopt()
+  struct timeval r = recvTimeval_;
+  int ret = setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO, &r, sizeof(r));
+  if (ret == -1) {
+    int errno_copy = errno;  // Copy errno because we're allocating memory.
+    GlobalOutput.perror("TSocket::setRecvTimeout() setsockopt() " + getSocketInfo(), errno_copy);
+  }
+}
+
+void TSocket::setSendTimeout(int ms) {
+  if (ms < 0) {
+    char errBuf[512];
+    sprintf(errBuf, "TSocket::setSendTimeout with negative input: %d\n", ms);
+    GlobalOutput(errBuf);
+    return;
+  }
+  sendTimeout_ = ms;
+
+  if (socket_ < 0) {
+    return;
+  }
+
+  struct timeval s = {(int)(sendTimeout_/1000),
+                      (int)((sendTimeout_%1000)*1000)};
+  int ret = setsockopt(socket_, SOL_SOCKET, SO_SNDTIMEO, &s, sizeof(s));
+  if (ret == -1) {
+    int errno_copy = errno;  // Copy errno because we're allocating memory.
+    GlobalOutput.perror("TSocket::setSendTimeout() setsockopt() " + getSocketInfo(), errno_copy);
+  }
+}
+
+void TSocket::setMaxRecvRetries(int maxRecvRetries) {
+  maxRecvRetries_ = maxRecvRetries;
+}
+
+string TSocket::getSocketInfo() {
+  std::ostringstream oss;
+  oss << "<Host: " << host_ << " Port: " << port_ << ">";
+  return oss.str();
+}
+
+std::string TSocket::getPeerHost() {
+  if (peerHost_.empty()) {
+    struct sockaddr_storage addr;
+    socklen_t addrLen = sizeof(addr);
+
+    if (socket_ < 0) {
+      return host_;
+    }
+
+    int rv = getpeername(socket_, (sockaddr*) &addr, &addrLen);
+
+    if (rv != 0) {
+      return peerHost_;
+    }
+
+    char clienthost[NI_MAXHOST];
+    char clientservice[NI_MAXSERV];
+    if (getnameinfo((sockaddr*) &addr, addrLen,
+                    clienthost, sizeof(clienthost),
+                    clientservice, sizeof(clientservice), 0) != 0) {
+      return peerHost_;  // lookup failed; clienthost was never filled in
+    }
+    peerHost_ = clienthost;
+  }
+  return peerHost_;
+}
+
+std::string TSocket::getPeerAddress() {
+  if (peerAddress_.empty()) {
+    struct sockaddr_storage addr;
+    socklen_t addrLen = sizeof(addr);
+
+    if (socket_ < 0) {
+      return peerAddress_;
+    }
+
+    int rv = getpeername(socket_, (sockaddr*) &addr, &addrLen);
+
+    if (rv != 0) {
+      return peerAddress_;
+    }
+
+    char clienthost[NI_MAXHOST];
+    char clientservice[NI_MAXSERV];
+    if (getnameinfo((sockaddr*) &addr, addrLen,
+                    clienthost, sizeof(clienthost),
+                    clientservice, sizeof(clientservice),
+                    NI_NUMERICHOST|NI_NUMERICSERV) != 0) {
+      return peerAddress_;  // lookup failed; buffers were never filled in
+    }
+    peerAddress_ = clienthost;
+    peerPort_ = std::atoi(clientservice);
+  }
+  return peerAddress_;
+}
+
+int TSocket::getPeerPort() {
+  getPeerAddress();
+  return peerPort_;
+}
+
+}}} // apache::thrift::transport
diff --git a/lib/cpp/src/transport/TSocket.h b/lib/cpp/src/transport/TSocket.h
new file mode 100644
index 0000000..b0f445a
--- /dev/null
+++ b/lib/cpp/src/transport/TSocket.h
@@ -0,0 +1,242 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_TRANSPORT_TSOCKET_H_
+#define _THRIFT_TRANSPORT_TSOCKET_H_ 1
+
+#include <string>
+#include <sys/time.h>
+
+#include "TTransport.h"
+#include "TServerSocket.h"
+
+namespace apache { namespace thrift { namespace transport {
+
+/**
+ * TCP Socket implementation of the TTransport interface.
+ *
+ */
+class TSocket : public TTransport {
+  /**
+   * We allow the TServerSocket acceptImpl() method to access the private
+   * members of a socket so that it can access the TSocket(int socket)
+   * constructor which creates a socket object from the raw UNIX socket
+   * handle.
+   */
+  friend class TServerSocket;
+
+ public:
+  /**
+   * Constructs a new socket. Note that this does NOT actually connect the
+   * socket.
+   * Use setHost() and setPort() to choose the peer before calling open().
+   */
+  TSocket();
+
+  /**
+   * Constructs a new socket. Note that this does NOT actually connect the
+   * socket.
+   *
+   * @param host An IP address or hostname to connect to
+   * @param port The port to connect on
+   */
+  TSocket(std::string host, int port);
+
+  /**
+   * Destroys the socket object, closing it if necessary.
+   */
+  virtual ~TSocket();
+
+  /**
+   * Whether the socket currently holds an open descriptor.
+   *
+   * @return true if the underlying socket descriptor is valid
+   */
+  bool isOpen();
+
+  /**
+   * Returns whether data is available, checked with recv(MSG_PEEK) (nothing is consumed).
+   */
+  bool peek();
+
+  /**
+   * Resolves the host and connects the UNIX socket.
+   *
+   * @throws TTransportException If the socket could not connect
+   */
+  virtual void open();
+
+  /**
+   * Shuts down and closes the socket; a no-op if it is not open.
+   */
+  void close();
+
+  /**
+   * Reads up to len bytes; returns 0 once the peer has closed the socket.
+   */
+  uint32_t read(uint8_t* buf, uint32_t len);
+
+  /**
+   * Writes all len bytes to the underlying socket.
+   */
+  void write(const uint8_t* buf, uint32_t len);
+
+  /**
+   * Get the host that the socket connects (or is connected) to
+   *
+   * @return string host identifier
+   */
+  std::string getHost();
+
+  /**
+   * Get the port that the socket connects (or is connected) to
+   *
+   * @return int port number
+   */
+  int getPort();
+
+  /**
+   * Set the host that socket will connect to
+   *
+   * @param host host identifier
+   */
+  void setHost(std::string host);
+
+  /**
+   * Set the port that socket will connect to
+   *
+   * @param port port number
+   */
+  void setPort(int port);
+
+  /**
+   * Controls whether the linger option is set on the socket.
+   *
+   * @param on Whether SO_LINGER is on
+   * @param linger If linger is active, the number of seconds to linger for
+   */
+  void setLinger(bool on, int linger);
+
+  /**
+   * Whether to enable/disable Nagle's algorithm.
+   *
+   * @param noDelay Whether or not to disable the algorithm.
+   * Takes effect immediately if the socket is open, otherwise on open().
+   */
+  void setNoDelay(bool noDelay);
+
+  /**
+   * Set the connect timeout in milliseconds (values <= 0 disable it)
+   */
+  void setConnTimeout(int ms);
+
+  /**
+   * Set the receive timeout in milliseconds (negative values are logged and ignored)
+   */
+  void setRecvTimeout(int ms);
+
+  /**
+   * Set the send timeout in milliseconds (negative values are logged and ignored)
+   */
+  void setSendTimeout(int ms);
+
+  /**
+   * Set the max number of recv retries in case of an EAGAIN
+   * or EINTR error
+   */
+  void setMaxRecvRetries(int maxRecvRetries);
+
+  /**
+   * Get socket information formatted as a string <Host: x Port: x>
+   */
+  std::string getSocketInfo();
+
+  /**
+   * Returns the DNS name of the connected peer (cached after the first lookup)
+   */
+  std::string getPeerHost();
+
+  /**
+   * Returns the numeric address of the connected peer (cached after the first lookup)
+   */
+  std::string getPeerAddress();
+
+  /**
+   * Returns the port of the host to which the socket is connected
+   */
+  int getPeerPort();
+
+
+ protected:
+  /**
+   * Constructor to create socket from raw UNIX handle. Never called directly
+   * but used by the TServerSocket class.
+   */
+  TSocket(int socket);
+
+  /** Connects to one resolved address; called by open() */
+  void openConnection(struct addrinfo *res);
+
+  /** Host to connect to */
+  std::string host_;
+
+  /** Peer hostname (cache for getPeerHost()) */
+  std::string peerHost_;
+
+  /** Peer address (cache for getPeerAddress()) */
+  std::string peerAddress_;
+
+  /** Peer port (filled in by getPeerAddress()) */
+  int peerPort_;
+
+  /** Port number to connect on */
+  int port_;
+
+  /** Underlying UNIX socket handle, -1 when closed */
+  int socket_;
+
+  /** Connect timeout in ms */
+  int connTimeout_;
+
+  /** Send timeout in ms */
+  int sendTimeout_;
+
+  /** Recv timeout in ms */
+  int recvTimeout_;
+
+  /** Linger on */
+  bool lingerOn_;
+
+  /** Linger time in seconds */
+  int lingerVal_;
+
+  /** Whether TCP_NODELAY (Nagle's algorithm disabled) is requested */
+  bool noDelay_;
+
+  /** Recv EAGAIN retries */
+  int maxRecvRetries_;
+
+  /** Recv timeout timeval */
+  struct timeval recvTimeval_;
+};
+
+}}} // apache::thrift::transport
+
+#endif // #ifndef _THRIFT_TRANSPORT_TSOCKET_H_
+
diff --git a/lib/cpp/src/transport/TSocketPool.cpp b/lib/cpp/src/transport/TSocketPool.cpp
new file mode 100644
index 0000000..1150282
--- /dev/null
+++ b/lib/cpp/src/transport/TSocketPool.cpp
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <algorithm>
+#include <ctime>
+#include <iostream>
+
+#include "TSocketPool.h"
+
+namespace apache { namespace thrift { namespace transport {
+
+using namespace std;
+
+using boost::shared_ptr;
+
+/**
+ * TSocketPoolServer implementation
+ *
+ */
+TSocketPoolServer::TSocketPoolServer()
+  : host_(""),
+    port_(0),
+    socket_(-1),
+    lastFailTime_(0),
+    consecutiveFailures_(0) {}
+
+/**
+ * Constructor for TSocketPool server
+ */
+TSocketPoolServer::TSocketPoolServer(const string &host, int port)
+  : host_(host),
+    port_(port),
+    socket_(-1),
+    lastFailTime_(0),
+    consecutiveFailures_(0) {}
+
+/**
+ * TSocketPool implementation.
+ *
+ */
+
+TSocketPool::TSocketPool() : TSocket(),
+  numRetries_(1),
+  retryInterval_(60),
+  maxConsecutiveFailures_(1),
+  randomize_(true),
+  alwaysTryLast_(true) {
+}
+
+TSocketPool::TSocketPool(const vector<string> &hosts,
+                         const vector<int> &ports) : TSocket(),
+  numRetries_(1),
+  retryInterval_(60),
+  maxConsecutiveFailures_(1),
+  randomize_(true),
+  alwaysTryLast_(true)
+{
+  if (hosts.size() != ports.size()) {
+    GlobalOutput("TSocketPool::TSocketPool: hosts.size != ports.size");
+    throw TTransportException(TTransportException::BAD_ARGS);
+  }
+
+  for (unsigned int i = 0; i < hosts.size(); ++i) {
+    addServer(hosts[i], ports[i]);
+  }
+}
+
+TSocketPool::TSocketPool(const vector<pair<string, int> >& servers) : TSocket(),
+  numRetries_(1),
+  retryInterval_(60),
+  maxConsecutiveFailures_(1),
+  randomize_(true),
+  alwaysTryLast_(true)
+{
+  for (unsigned i = 0; i < servers.size(); ++i) {
+    addServer(servers[i].first, servers[i].second);
+  }
+}
+
+TSocketPool::TSocketPool(const vector< shared_ptr<TSocketPoolServer> >& servers) : TSocket(),
+  servers_(servers),
+  numRetries_(1),
+  retryInterval_(60),
+  maxConsecutiveFailures_(1),
+  randomize_(true),
+  alwaysTryLast_(true)
+{
+}
+
+TSocketPool::TSocketPool(const string& host, int port) : TSocket(),
+  numRetries_(1),
+  retryInterval_(60),
+  maxConsecutiveFailures_(1),
+  randomize_(true),
+  alwaysTryLast_(true)
+{
+  addServer(host, port);
+}
+
+/**
+ * Closes every server socket in the pool that is still open.
+ */
+TSocketPool::~TSocketPool() {
+  vector< shared_ptr<TSocketPoolServer> >::const_iterator iter = servers_.begin();
+  vector< shared_ptr<TSocketPoolServer> >::const_iterator iterEnd = servers_.end();
+  for (; iter != iterEnd; ++iter) {
+    setCurrentServer(*iter);
+    TSocketPool::close();
+  }
+}
+
+void TSocketPool::addServer(const string& host, int port) {
+  servers_.push_back(shared_ptr<TSocketPoolServer>(new TSocketPoolServer(host, port)));
+}
+
+void TSocketPool::setServers(const vector< shared_ptr<TSocketPoolServer> >& servers) {
+  servers_ = servers;
+}
+
+void TSocketPool::getServers(vector< shared_ptr<TSocketPoolServer> >& servers) {
+  servers = servers_;
+}
+
+void TSocketPool::setNumRetries(int numRetries) {
+  numRetries_ = numRetries;
+}
+
+void TSocketPool::setRetryInterval(int retryInterval) {
+  retryInterval_ = retryInterval;
+}
+
+
+void TSocketPool::setMaxConsecutiveFailures(int maxConsecutiveFailures) {
+  maxConsecutiveFailures_ = maxConsecutiveFailures;
+}
+
+void TSocketPool::setRandomize(bool randomize) {
+  randomize_ = randomize;
+}
+
+void TSocketPool::setAlwaysTryLast(bool alwaysTryLast) {
+  alwaysTryLast_ = alwaysTryLast;
+}
+
+/**
+ * Makes `server` the active server by copying its host, port and socket
+ * descriptor into the inherited TSocket state.
+ */
+void TSocketPool::setCurrentServer(const shared_ptr<TSocketPoolServer> &server) {
+  currentServer_ = server;
+  host_ = server->host_;
+  port_ = server->port_;
+  socket_ = server->socket_;
+}
+
+/**
+ * Opens a connection to a server in the pool. Servers are tried in order
+ * (shuffled first when randomize_ is set); a server marked down is skipped
+ * until retryInterval_ seconds have passed, except that the last server is
+ * always tried when alwaysTryLast_ is set. Each candidate gets numRetries_
+ * attempts; once its consecutive failures exceed maxConsecutiveFailures_
+ * it is marked down.
+ *
+ * @throws TTransportException NOT_OPEN if no server could be opened
+ */
+/* TODO: without apc we ignore a lot of functionality from the php version */
+void TSocketPool::open() {
+  if (randomize_) {
+    random_shuffle(servers_.begin(), servers_.end());
+  }
+
+  unsigned int numServers = servers_.size();
+  for (unsigned int i = 0; i < numServers; ++i) {
+
+    shared_ptr<TSocketPoolServer> &server = servers_[i];
+    bool retryIntervalPassed = (server->lastFailTime_ == 0);
+    bool isLastServer = alwaysTryLast_ ? (i == (numServers - 1)) : false;
+
+    // Impersonate the server socket
+    setCurrentServer(server);
+
+    if (isOpen()) {
+      // already open means we're done
+      return;
+    }
+
+    if (server->lastFailTime_ > 0) {
+      // The server was marked as down, so check if enough time has elapsed to retry
+      int elapsedTime = time(NULL) - server->lastFailTime_;
+      if (elapsedTime > retryInterval_) {
+        retryIntervalPassed = true;
+      }
+    }
+
+    if (retryIntervalPassed || isLastServer) {
+      for (int j = 0; j < numRetries_; ++j) {
+        try {
+          TSocket::open();
+
+          // Copy over the opened socket so that we can keep it persistent
+          server->socket_ = socket_;
+
+          // The server is reachable again: clear its failure bookkeeping
+          server->lastFailTime_ = 0;
+          server->consecutiveFailures_ = 0;
+
+          // success
+          return;
+        } catch (const TException& e) {
+          string errStr = "TSocketPool::open failed "+getSocketInfo()+": "+e.what();
+          GlobalOutput(errStr.c_str());
+          // connection failed
+        }
+      }
+
+      ++server->consecutiveFailures_;
+      if (server->consecutiveFailures_ > maxConsecutiveFailures_) {
+        // Mark server as down
+        server->consecutiveFailures_ = 0;
+        server->lastFailTime_ = time(NULL);
+      }
+    }
+  }
+
+  GlobalOutput("TSocketPool::open: all connections failed");
+  throw TTransportException(TTransportException::NOT_OPEN);
+}
+
+/**
+ * Closes the current server's socket, if open, and forgets its descriptor.
+ */
+void TSocketPool::close() {
+  if (isOpen()) {
+    TSocket::close();
+    currentServer_->socket_ = -1;
+  }
+}
+
+}}} // apache::thrift::transport
diff --git a/lib/cpp/src/transport/TSocketPool.h b/lib/cpp/src/transport/TSocketPool.h
new file mode 100644
index 0000000..8c50669
--- /dev/null
+++ b/lib/cpp/src/transport/TSocketPool.h
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_TRANSPORT_TSOCKETPOOL_H_
+#define _THRIFT_TRANSPORT_TSOCKETPOOL_H_ 1
+
+#include <vector>
+#include "TSocket.h"
+
+namespace apache { namespace thrift { namespace transport {
+
+/**
+ * Connection bookkeeping for one (host, port) entry of a TSocketPool.
+ * All members are public and are read/updated directly by TSocketPool.
+ */
+class TSocketPoolServer {
+
+ public:
+  /**
+   * Default constructor: empty host, port 0, no socket (-1), not failed.
+   */
+  TSocketPoolServer();
+
+  /**
+   * Constructor for a server at `host`:`port`, initially unconnected.
+   */
+  TSocketPoolServer(const std::string &host, int port);
+
+  // Host name
+  std::string host_;
+
+  // Port to connect on
+  int port_;
+
+  // Socket descriptor while connected, -1 otherwise
+  int socket_;
+
+  // time(NULL) of the failure that marked this server down; 0 if not down
+  int lastFailTime_;
+
+  // Failed open() rounds counted toward marking this server down
+  int consecutiveFailures_;
+};
+
+/**
+ * TSocket that connects to the first reachable server of a pool, with
+ * retry and mark-down handling for unreachable servers.
+ */
+class TSocketPool : public TSocket {
+
+ public:
+
+  /**
+   * Socket pool constructor for an empty pool (see addServer())
+   */
+  TSocketPool();
+
+  /**
+   * Socket pool constructor
+   *
+   * @param hosts list of host names
+   * @param ports list of port numbers, parallel to hosts (sizes must match)
+   */
+  TSocketPool(const std::vector<std::string> &hosts,
+              const std::vector<int> &ports);
+
+  /**
+   * Socket pool constructor
+   *
+   * @param servers list of pairs of host name and port
+   */
+  TSocketPool(const std::vector<std::pair<std::string, int> >& servers);
+
+  /**
+   * Socket pool constructor
+   *
+   * @param servers list of TSocketPoolServers
+   */
+  TSocketPool(const std::vector< boost::shared_ptr<TSocketPoolServer> >& servers);
+
+  /**
+   * Socket pool constructor
+   *
+   * @param host single host
+   * @param port single port
+   */
+  TSocketPool(const std::string& host, int port);
+
+  /**
+   * Destroys the pool, closing every server socket that is still open.
+   */
+  virtual ~TSocketPool();
+
+  /**
+   * Add a server to the pool
+   */
+  void addServer(const std::string& host, int port);
+
+  /**
+   * Set list of servers in this pool
+   */
+  void setServers(const std::vector< boost::shared_ptr<TSocketPoolServer> >& servers);
+
+  /**
+   * Get list of servers in this pool (copied into `servers`)
+   */
+  void getServers(std::vector< boost::shared_ptr<TSocketPoolServer> >& servers);
+
+  /**
+   * Sets how many connect attempts open() makes per host before moving on.
+   */
+  void setNumRetries(int numRetries);
+
+  /**
+   * Sets how long (in seconds) to wait before retrying a host marked down
+   */
+  void setRetryInterval(int retryInterval);
+
+  /**
+   * Sets the failure count a host must exceed before it is marked down.
+   */
+  void setMaxConsecutiveFailures(int maxConsecutiveFailures);
+
+  /**
+   * Turns randomization in connect order on or off.
+   */
+  void setRandomize(bool randomize);
+
+  /**
+   * Whether to always try the last server, even if it is marked down.
+   */
+  void setAlwaysTryLast(bool alwaysTryLast);
+
+  /**
+   * Connects to the first reachable server in the pool.
+   */
+  void open();
+
+  /**
+   * Closes the current server's socket
+   */
+  void close();
+
+ protected:
+  // Makes `server` current, copying its host/port/socket into TSocket's state
+  void setCurrentServer(const boost::shared_ptr<TSocketPoolServer> &server);
+
+  /** List of servers to connect to */
+  std::vector< boost::shared_ptr<TSocketPoolServer> > servers_;
+
+  /** Server whose socket is currently in use */
+  boost::shared_ptr<TSocketPoolServer> currentServer_;
+
+  /** How many times to retry each host in connect */
+  int numRetries_;
+
+  /** Retry interval in seconds, how long to not try a host if it has been
+   * marked as down.
+   */
+  int retryInterval_;
+
+  /** Failure count that must be exceeded before marking a host down. */
+  int maxConsecutiveFailures_;
+
+  /** Try hosts in order? or Randomized? */
+  bool randomize_;
+
+  /** Always try last host, even if marked down? */
+  bool alwaysTryLast_;
+};
+
+}}} // apache::thrift::transport
+
+#endif // #ifndef _THRIFT_TRANSPORT_TSOCKETPOOL_H_
+
diff --git a/lib/cpp/src/transport/TTransport.h b/lib/cpp/src/transport/TTransport.h
new file mode 100644
index 0000000..eb0d5df
--- /dev/null
+++ b/lib/cpp/src/transport/TTransport.h
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_TRANSPORT_TTRANSPORT_H_
+#define _THRIFT_TRANSPORT_TTRANSPORT_H_ 1
+
+#include <Thrift.h>
+#include <boost/shared_ptr.hpp>
+#include <transport/TTransportException.h>
+#include <string>
+
+namespace apache { namespace thrift { namespace transport {
+
+/**
+ * Generic interface for a method of transporting data. A TTransport may be
+ * capable of either reading or writing, but not necessarily both.
+ *
+ */
+class TTransport {
+ public:
+  /**
+   * Virtual destructor.
+   */
+  virtual ~TTransport() {}
+
+  /**
+   * Whether this transport is open. The base implementation always
+   * returns false.
+   */
+  virtual bool isOpen() {
+    return false;
+  }
+
+  /**
+   * Tests whether there is more data to read or if the remote side is
+   * still open. By default this is true whenever the transport is open,
+   * but implementations should add logic to test for this condition where
+   * possible (e.g. on a socket).
+   * This is used by a server to check if it should listen for another
+   * request.
+   */
+  virtual bool peek() {
+    return isOpen();
+  }
+
+  /**
+   * Opens the transport for communications.
+   *
+   * @throws TTransportException if opening failed (the base implementation
+   *         always throws NOT_OPEN)
+   */
+  virtual void open() {
+    throw TTransportException(TTransportException::NOT_OPEN, "Cannot open base TTransport.");
+  }
+
+  /**
+   * Closes the transport.
+   *
+   * @throws TTransportException if closing failed (the base implementation
+   *         always throws NOT_OPEN)
+   */
+  virtual void close() {
+    throw TTransportException(TTransportException::NOT_OPEN, "Cannot close base TTransport.");
+  }
+
+  /**
+   * Attempt to read up to the specified number of bytes into the buffer.
+   * readAll() treats a return value of 0 as the end of the data.
+   *
+   * @param buf Reference to the location to write the data
+   * @param len How many bytes to read
+   * @return How many bytes were actually read
+   * @throws TTransportException If an error occurs (the base implementation
+   *         always throws NOT_OPEN)
+   */
+  virtual uint32_t read(uint8_t* /* buf */, uint32_t /* len */) {
+    throw TTransportException(TTransportException::NOT_OPEN, "Base TTransport cannot read.");
+  }
+
+  /**
+   * Reads the given amount of data in its entirety, calling read()
+   * repeatedly until exactly len bytes have been read.
+   *
+   * @param buf Reference to the location to write the data
+   * @param len How many bytes to read
+   * @return How many bytes were read, which is equal to len
+   * @throws TTransportException END_OF_FILE if read() returns 0 before len
+   *         bytes have arrived, or anything thrown by read() itself
+   */
+  virtual uint32_t readAll(uint8_t* buf, uint32_t len) {
+    uint32_t have = 0;
+
+    while (have < len) {
+      uint32_t get = read(buf + have, len - have);
+      // read() returns an unsigned count, so 0 is the only "no data" value.
+      if (get == 0) {
+        throw TTransportException(TTransportException::END_OF_FILE,
+                                  "No more data to read.");
+      }
+      have += get;
+    }
+
+    return have;
+  }
+
+  /**
+   * Called when read is completed.
+   * This can be over-ridden to perform a transport-specific action
+   * e.g. logging the request to a file
+   *
+   */
+  virtual void readEnd() {
+    // default behaviour is to do nothing
+    return;
+  }
+
+  /**
+   * Writes len bytes from buf in their entirety to the transport.
+   *
+   * @param buf The data to write out
+   * @param len How many bytes to write
+   * @throws TTransportException if an error occurs (the base implementation
+   *         always throws NOT_OPEN)
+   */
+  virtual void write(const uint8_t* /* buf */, uint32_t /* len */) {
+    throw TTransportException(TTransportException::NOT_OPEN, "Base TTransport cannot write.");
+  }
+
+  /**
+   * Called when write is completed.
+   * This can be over-ridden to perform a transport-specific action
+   * at the end of a request.
+   *
+   */
+  virtual void writeEnd() {
+    // default behaviour is to do nothing
+    return;
+  }
+
+  /**
+   * Flushes any pending data to be written. Typically used with buffered
+   * transport mechanisms. The base implementation does nothing.
+   *
+   * @throws TTransportException if an error occurs
+   */
+  virtual void flush() {}
+
+  /**
+   * Attempts to return a pointer to \c len bytes, possibly copied into \c buf.
+   * Does not consume the bytes read (i.e.: a later read will return the same
+   * data). This method is meant to support protocols that need to read
+   * variable-length fields. They can attempt to borrow the maximum amount of
+   * data that they will need, then consume (see next method) what they
+   * actually use. Some transports will not support this method and others
+   * will fail occasionally, so protocols must be prepared to use read if
+   * borrow fails.
+   *
+   * @param buf A buffer where the data can be stored if needed.
+   *            If borrow doesn't return buf, then the contents of
+   *            buf after the call are undefined.
+   * @param len *len should initially contain the number of bytes to borrow.
+   *            If borrow succeeds, *len will contain the number of bytes
+   *            available in the returned pointer. This will be at least
+   *            what was requested, but may be more if borrow returns
+   *            a pointer to an internal buffer, rather than buf.
+   *            If borrow fails, the contents of *len are undefined.
+   * @return If the borrow succeeds, return a pointer to the borrowed data.
+   *         This might be equal to \c buf, or it might be a pointer into
+   *         the transport's internal buffers. NULL means the borrow
+   *         failed; the base implementation always returns NULL.
+   * @throws TTransportException if an error occurs
+   */
+  virtual const uint8_t* borrow(uint8_t* /* buf */, uint32_t* /* len */) {
+    return NULL;
+  }
+
+  /**
+   * Remove len bytes from the transport. This should always follow a borrow
+   * of at least len bytes, and should always succeed.
+   * TODO(dreiss): Is there any transport that could borrow but fail to
+   * consume, or that would require a buffer to dump the consumed data?
+   *
+   * @param len How many bytes to consume
+   * @throws TTransportException If an error occurs (the base implementation
+   *         always throws NOT_OPEN)
+   */
+  virtual void consume(uint32_t /* len */) {
+    throw TTransportException(TTransportException::NOT_OPEN, "Base TTransport cannot consume.");
+  }
+
+ protected:
+  /**
+   * Simple constructor.
+   */
+  TTransport() {}
+};
+
+/**
+ * Generic factory class to make an input and output transport out of a
+ * source transport. Commonly used inside servers to make input and output
+ * streams out of raw clients.
+ *
+ */
+class TTransportFactory {
+ public:
+  TTransportFactory() {}
+  virtual ~TTransportFactory() {}
+
+  /**
+   * Produces the transport to use for a given source transport.
+   * This base factory performs no wrapping and hands back trans unchanged;
+   * subclasses override it to wrap trans in another transport.
+   *
+   * @param trans The source transport
+   * @return trans itself
+   */
+  virtual boost::shared_ptr<TTransport> getTransport(boost::shared_ptr<TTransport> trans) {
+    return trans;
+  }
+};
+
+}}} // apache::thrift::transport
+
+#endif // #ifndef _THRIFT_TRANSPORT_TTRANSPORT_H_
diff --git a/lib/cpp/src/transport/TTransportException.cpp b/lib/cpp/src/transport/TTransportException.cpp
new file mode 100644
index 0000000..f0aaedc
--- /dev/null
+++ b/lib/cpp/src/transport/TTransportException.cpp
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <transport/TTransportException.h>
+#include <boost/lexical_cast.hpp>
+#include <cstring>
+#include <config.h>
+
+using std::string;
+using boost::lexical_cast;
+
+namespace apache { namespace thrift { namespace transport {
+
+}}} // apache::thrift::transport
+
diff --git a/lib/cpp/src/transport/TTransportException.h b/lib/cpp/src/transport/TTransportException.h
new file mode 100644
index 0000000..330785c
--- /dev/null
+++ b/lib/cpp/src/transport/TTransportException.h
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_TRANSPORT_TTRANSPORTEXCEPTION_H_
+#define _THRIFT_TRANSPORT_TTRANSPORTEXCEPTION_H_ 1
+
+#include <string>
+#include <Thrift.h>
+
+namespace apache { namespace thrift { namespace transport {
+
+/**
+ * Class to encapsulate all the possible types of transport errors that may
+ * occur in various transport systems. This provides a sort of generic
+ * wrapper around the UNIX E* (errno) error codes that lets common
+ * error-handling code be used for various types of transports, e.g.
+ *
+ */
+class TTransportException : public apache::thrift::TException {
+ public:
+  /**
+   * Error codes for the various types of exceptions.
+   */
+  enum TTransportExceptionType {
+    UNKNOWN = 0,
+    NOT_OPEN = 1,
+    ALREADY_OPEN = 2,
+    TIMED_OUT = 3,
+    END_OF_FILE = 4,
+    INTERRUPTED = 5,
+    BAD_ARGS = 6,
+    CORRUPTED_DATA = 7,
+    INTERNAL_ERROR = 8
+  };
+
+  /** An UNKNOWN exception without a message. */
+  TTransportException()
+    : apache::thrift::TException(),
+      type_(UNKNOWN) {}
+
+  /** An exception of the given type, without a message. */
+  TTransportException(TTransportExceptionType type)
+    : apache::thrift::TException(),
+      type_(type) {}
+
+  /** An UNKNOWN exception carrying the given message. */
+  TTransportException(const std::string& message)
+    : apache::thrift::TException(message),
+      type_(UNKNOWN) {}
+
+  /** An exception of the given type carrying the given message. */
+  TTransportException(TTransportExceptionType type, const std::string& message)
+    : apache::thrift::TException(message),
+      type_(type) {}
+
+  /**
+   * An exception of the given type whose message is message + ": " + the
+   * text for errno_copy as produced by TOutput::strerror_s().
+   */
+  TTransportException(TTransportExceptionType type,
+                      const std::string& message,
+                      int errno_copy)
+    : apache::thrift::TException(message + ": " + TOutput::strerror_s(errno_copy)),
+      type_(type) {}
+
+  virtual ~TTransportException() throw() {}
+
+  /**
+   * Returns an error code that provides information about the type of error
+   * that has occurred.
+   *
+   * @return Error code
+   */
+  TTransportExceptionType getType() const throw() {
+    return type_;
+  }
+
+  /**
+   * The message stored in the exception, or, if that is empty, a fixed
+   * description of the exception type.
+   */
+  virtual const char* what() const throw() {
+    if (!message_.empty()) {
+      return message_.c_str();
+    }
+    switch (type_) {
+      case UNKNOWN:        return "TTransportException: Unknown transport exception";
+      case NOT_OPEN:       return "TTransportException: Transport not open";
+      case ALREADY_OPEN:   return "TTransportException: Transport already open";
+      case TIMED_OUT:      return "TTransportException: Timed out";
+      case END_OF_FILE:    return "TTransportException: End of file";
+      case INTERRUPTED:    return "TTransportException: Interrupted";
+      case BAD_ARGS:       return "TTransportException: Invalid arguments";
+      case CORRUPTED_DATA: return "TTransportException: Corrupted Data";
+      case INTERNAL_ERROR: return "TTransportException: Internal error";
+      default:             return "TTransportException: (Invalid exception type)";
+    }
+  }
+
+ protected:
+  /** Just like strerror_r but returns a C++ string object. */
+  // NOTE(review): no definition of this member appears in
+  // TTransportException.cpp; the errno constructor uses TOutput::strerror_s.
+  std::string strerror_s(int errno_copy);
+
+  /** Error code */
+  TTransportExceptionType type_;
+};
+
+}}} // apache::thrift::transport
+
+#endif // #ifndef _THRIFT_TRANSPORT_TTRANSPORTEXCEPTION_H_
diff --git a/lib/cpp/src/transport/TTransportUtils.cpp b/lib/cpp/src/transport/TTransportUtils.cpp
new file mode 100644
index 0000000..a840fa6
--- /dev/null
+++ b/lib/cpp/src/transport/TTransportUtils.cpp
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <transport/TTransportUtils.h>
+
+using std::string;
+
+namespace apache { namespace thrift { namespace transport {
+
+/**
+ * Copies up to len bytes into buf, performing at most one read() on the
+ * source transport per call. Bytes handed out stay in rBuf_ (only rPos_
+ * advances) so that readEnd() can pipe the whole request afterwards.
+ *
+ * @return the number of bytes copied into buf
+ */
+uint32_t TPipedTransport::read(uint8_t* buf, uint32_t len) {
+  uint32_t remaining = len;
+
+  if (rLen_ - rPos_ < remaining) {
+    // Not enough buffered: first hand over everything we already hold...
+    const uint32_t buffered = rLen_ - rPos_;
+    if (buffered > 0) {
+      memcpy(buf, rBuf_ + rPos_, buffered);
+      remaining -= buffered;
+      buf += buffered;
+      rPos_ = rLen_;
+    }
+
+    // ...grow the buffer by doubling when it has no free space left...
+    if (rLen_ == rBufSize_) {
+      rBufSize_ *= 2;
+      rBuf_ = static_cast<uint8_t*>(std::realloc(rBuf_, sizeof(uint8_t) * rBufSize_));
+    }
+
+    // ...and pull whatever the source gives us into the free tail.
+    rLen_ += srcTrans_->read(rBuf_ + rPos_, rBufSize_ - rPos_);
+  }
+
+  // Give the caller as much of the rest of the request as is now buffered.
+  uint32_t chunk = rLen_ - rPos_;
+  if (chunk > remaining) {
+    chunk = remaining;
+  }
+  if (chunk > 0) {
+    memcpy(buf, rBuf_ + rPos_, chunk);
+    rPos_ += chunk;
+    remaining -= chunk;
+  }
+
+  return len - remaining;
+}
+
+/**
+ * Appends len bytes to the write buffer. Nothing is sent here: flush()
+ * writes the buffer to the source transport, and writeEnd() copies it to
+ * the destination transport when pipeOnWrite_ is set.
+ * The buffer grows by doubling until the data fits.
+ */
+void TPipedTransport::write(const uint8_t* buf, uint32_t len) {
+  if (len == 0) {
+    return;
+  }
+
+  // Make the buffer as big as it needs to be
+  if ((len + wLen_) >= wBufSize_) {
+    // A zero capacity (e.g. the constructor was given sz == 0) would never
+    // grow by doubling and this loop would spin forever, so start from 1.
+    uint32_t newBufSize = (wBufSize_ == 0) ? 1 : wBufSize_ * 2;
+    while ((len + wLen_) >= newBufSize) {
+      newBufSize *= 2;
+    }
+    wBuf_ = static_cast<uint8_t*>(std::realloc(wBuf_, sizeof(uint8_t) * newBufSize));
+    wBufSize_ = newBufSize;
+  }
+
+  // Copy into the buffer
+  memcpy(wBuf_ + wLen_, buf, len);
+  wLen_ += len;
+}
+
+/**
+ * Sends everything buffered by write() to the source transport, empties
+ * the write buffer, and then flushes the source transport itself.
+ */
+void TPipedTransport::flush() {
+  if (wLen_ != 0) {
+    srcTrans_->write(wBuf_, wLen_);
+    wLen_ = 0;
+  }
+  srcTrans_->flush();
+}
+
+// srcTrans is passed to TPipedTransport (which does the buffered, piped
+// reading) and also kept with its file-reader type for the
+// TFileReaderTransport functions.
+TPipedFileReaderTransport::TPipedFileReaderTransport(boost::shared_ptr<TFileReaderTransport> srcTrans, boost::shared_ptr<TTransport> dstTrans)
+  : TPipedTransport(srcTrans, dstTrans), srcTrans_(srcTrans) {}
+
+TPipedFileReaderTransport::~TPipedFileReaderTransport() {}
+
+// The TTransport operations below defer to TPipedTransport.
+bool TPipedFileReaderTransport::isOpen() { return TPipedTransport::isOpen(); }
+
+bool TPipedFileReaderTransport::peek() { return TPipedTransport::peek(); }
+
+void TPipedFileReaderTransport::open() { TPipedTransport::open(); }
+
+void TPipedFileReaderTransport::close() { TPipedTransport::close(); }
+
+uint32_t TPipedFileReaderTransport::read(uint8_t* buf, uint32_t len) {
+  return TPipedTransport::read(buf, len);
+}
+
+/**
+ * Reads exactly len bytes by calling read() until the request is satisfied.
+ *
+ * @throws TEOFException if read() returns 0 before len bytes have arrived
+ */
+uint32_t TPipedFileReaderTransport::readAll(uint8_t* buf, uint32_t len) {
+  uint32_t total = 0;
+  while (total < len) {
+    const uint32_t chunk = read(buf + total, len - total);
+    if (chunk == 0) {
+      throw TEOFException();
+    }
+    total += chunk;
+  }
+  return total;
+}
+
+// Request boundaries and the write side are handled by the piping logic
+// in TPipedTransport.
+void TPipedFileReaderTransport::readEnd() { TPipedTransport::readEnd(); }
+
+void TPipedFileReaderTransport::write(const uint8_t* buf, uint32_t len) {
+  TPipedTransport::write(buf, len);
+}
+
+void TPipedFileReaderTransport::writeEnd() { TPipedTransport::writeEnd(); }
+
+void TPipedFileReaderTransport::flush() { TPipedTransport::flush(); }
+
+// TFileReaderTransport operations go straight to the wrapped file reader
+// (this class's own srcTrans_, typed as a TFileReaderTransport).
+// NOTE(review): seeking bypasses TPipedTransport, so any read-ahead already
+// buffered in rBuf_ presumably stays and is served before the new position's
+// data -- confirm whether callers rely on that.
+int32_t TPipedFileReaderTransport::getReadTimeout() { return srcTrans_->getReadTimeout(); }
+
+void TPipedFileReaderTransport::setReadTimeout(int32_t readTimeout) {
+  srcTrans_->setReadTimeout(readTimeout);
+}
+
+uint32_t TPipedFileReaderTransport::getNumChunks() { return srcTrans_->getNumChunks(); }
+
+uint32_t TPipedFileReaderTransport::getCurChunk() { return srcTrans_->getCurChunk(); }
+
+void TPipedFileReaderTransport::seekToChunk(int32_t chunk) { srcTrans_->seekToChunk(chunk); }
+
+void TPipedFileReaderTransport::seekToEnd() { srcTrans_->seekToEnd(); }
+
+}}} // apache::thrift::transport
diff --git a/lib/cpp/src/transport/TTransportUtils.h b/lib/cpp/src/transport/TTransportUtils.h
new file mode 100644
index 0000000..d65c916
--- /dev/null
+++ b/lib/cpp/src/transport/TTransportUtils.h
@@ -0,0 +1,287 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_TRANSPORT_TTRANSPORTUTILS_H_
+#define _THRIFT_TRANSPORT_TTRANSPORTUTILS_H_ 1
+
+#include <cstdlib>
+#include <cstring>
+#include <string>
+#include <algorithm>
+#include <transport/TTransport.h>
+// Include the buffered transports that used to be defined here.
+#include <transport/TBufferTransports.h>
+#include <transport/TFileTransport.h>
+
+namespace apache { namespace thrift { namespace transport {
+
+/**
+ * The null transport is a dummy transport that doesn't actually do anything.
+ * It's sort of an analogy to /dev/null, you can never read anything from it
+ * and it will let you write anything you want to it, though it won't actually
+ * go anywhere.
+ *
+ */
+class TNullTransport : public TTransport {
+ public:
+  TNullTransport() {}
+
+  ~TNullTransport() {}
+
+  /** Always open: there is no underlying resource. */
+  bool isOpen() {
+    return true;
+  }
+
+  /** Nothing to open. */
+  void open() {}
+
+  /**
+   * Nothing to close. Without this override the call would reach
+   * TTransport::close(), which throws, even though open() always succeeds.
+   */
+  void close() {}
+
+  /** Discards the data. */
+  void write(const uint8_t* /* buf */, uint32_t /* len */) {
+    return;
+  }
+
+  // NOTE(review): read() is not overridden, so reads reach TTransport::read(),
+  // which throws NOT_OPEN ("Base TTransport cannot read.").
+};
+
+
+/**
+ * TPipedTransport. This transport allows piping of a request from one
+ * transport to another when readEnd() or writeEnd() is called. The typical
+ * use case for this is to log a request or a reply to disk.
+ * The underlying buffer expands to keep a copy of the entire
+ * request/response.
+ *
+ */
+class TPipedTransport : virtual public TTransport {
+ public:
+  /**
+   * @param srcTrans transport that is actually read from and written to
+   * @param dstTrans transport that receives the piped copy
+   *
+   * Both internal buffers start at 512 bytes. By default the request is
+   * piped when readEnd() is called, and nothing is piped on writeEnd().
+   */
+  TPipedTransport(boost::shared_ptr<TTransport> srcTrans,
+                  boost::shared_ptr<TTransport> dstTrans) :
+    srcTrans_(srcTrans),
+    dstTrans_(dstTrans),
+    rBufSize_(512), rPos_(0), rLen_(0),
+    wBufSize_(512), wLen_(0),
+    pipeOnRead_(true), pipeOnWrite_(false) {
+    rBuf_ = (uint8_t*) std::malloc(sizeof(uint8_t) * rBufSize_);
+    wBuf_ = (uint8_t*) std::malloc(sizeof(uint8_t) * wBufSize_);
+  }
+
+  /**
+   * Same as the two-argument constructor, but the write buffer starts at
+   * sz bytes.
+   */
+  TPipedTransport(boost::shared_ptr<TTransport> srcTrans,
+                  boost::shared_ptr<TTransport> dstTrans,
+                  uint32_t sz) :
+    srcTrans_(srcTrans),
+    dstTrans_(dstTrans),
+    rBufSize_(512), rPos_(0), rLen_(0),
+    wBufSize_(sz), wLen_(0),
+    // These flags used to be left uninitialized here; use the same
+    // defaults as the two-argument constructor.
+    pipeOnRead_(true), pipeOnWrite_(false) {
+    rBuf_ = (uint8_t*) std::malloc(sizeof(uint8_t) * rBufSize_);
+    wBuf_ = (uint8_t*) std::malloc(sizeof(uint8_t) * wBufSize_);
+  }
+
+  ~TPipedTransport() {
+    std::free(rBuf_);
+    std::free(wBuf_);
+  }
+
+  /** Open exactly when the source transport is open. */
+  bool isOpen() {
+    return srcTrans_->isOpen();
+  }
+
+  /**
+   * If no unread bytes are buffered, performs one read from the source
+   * (growing the buffer if full) and reports whether any data is available.
+   */
+  bool peek() {
+    if (rPos_ >= rLen_) {
+      // Double the size of the underlying buffer if it is full
+      if (rLen_ == rBufSize_) {
+        rBufSize_ *= 2;
+        rBuf_ = (uint8_t *)std::realloc(rBuf_, sizeof(uint8_t) * rBufSize_);
+      }
+
+      // try to fill up the buffer
+      rLen_ += srcTrans_->read(rBuf_ + rPos_, rBufSize_ - rPos_);
+    }
+    return (rLen_ > rPos_);
+  }
+
+  void open() {
+    srcTrans_->open();
+  }
+
+  void close() {
+    srcTrans_->close();
+  }
+
+  /** Whether readEnd() copies the request to the destination transport. */
+  void setPipeOnRead(bool pipeVal) {
+    pipeOnRead_ = pipeVal;
+  }
+
+  /** Whether writeEnd() copies the reply to the destination transport. */
+  void setPipeOnWrite(bool pipeVal) {
+    pipeOnWrite_ = pipeVal;
+  }
+
+  uint32_t read(uint8_t* buf, uint32_t len);
+
+  /**
+   * Ends a request: when pipeOnRead_ is set, writes the bytes consumed so
+   * far (rBuf_[0, rPos_)) to dstTrans_ and flushes it; then calls readEnd()
+   * on the source. Read-ahead bytes are kept for the next request.
+   */
+  void readEnd() {
+    if (pipeOnRead_) {
+      dstTrans_->write(rBuf_, rPos_);
+      dstTrans_->flush();
+    }
+
+    srcTrans_->readEnd();
+
+    // If requests are being pipelined, copy down our read-ahead data,
+    // then reset our state. Source and destination ranges may overlap,
+    // so this must be memmove (memcpy on overlapping ranges is undefined).
+    uint32_t read_ahead = rLen_ - rPos_;
+    std::memmove(rBuf_, rBuf_ + rPos_, read_ahead);
+    rPos_ = 0;
+    rLen_ = read_ahead;
+  }
+
+  void write(const uint8_t* buf, uint32_t len);
+
+  /**
+   * When pipeOnWrite_ is set, writes the buffered reply (wBuf_[0, wLen_))
+   * to dstTrans_ and flushes it. The write buffer is emptied by flush().
+   */
+  void writeEnd() {
+    if (pipeOnWrite_) {
+      dstTrans_->write(wBuf_, wLen_);
+      dstTrans_->flush();
+    }
+  }
+
+  void flush();
+
+  /** @return the destination transport that receives the piped copy */
+  boost::shared_ptr<TTransport> getTargetTransport() {
+    return dstTrans_;
+  }
+
+ protected:
+  boost::shared_ptr<TTransport> srcTrans_;  // transport actually read/written
+  boost::shared_ptr<TTransport> dstTrans_;  // receives the piped copy
+
+  // Read buffer: rBuf_[0, rPos_) has been handed out for the current
+  // request, rBuf_[rPos_, rLen_) is read-ahead; rBufSize_ is the capacity.
+  uint8_t* rBuf_;
+  uint32_t rBufSize_;
+  uint32_t rPos_;
+  uint32_t rLen_;
+
+  // Write buffer: wBuf_[0, wLen_) holds data not yet flushed.
+  uint8_t* wBuf_;
+  uint32_t wBufSize_;
+  uint32_t wLen_;
+
+  bool pipeOnRead_;   // pipe the request to dstTrans_ in readEnd()
+  bool pipeOnWrite_;  // pipe the reply to dstTrans_ in writeEnd()
+
+ private:
+  // The destructor frees the raw rBuf_/wBuf_ buffers, so a copy would
+  // double-free them. Copying is disabled (declared, never defined).
+  TPipedTransport(const TPipedTransport&);
+  TPipedTransport& operator=(const TPipedTransport&);
+};
+
+
+/**
+ * Wraps a transport into a pipedTransport instance.
+ *
+ */
+class TPipedTransportFactory : public TTransportFactory {
+ public:
+  TPipedTransportFactory() {}
+
+  /** Creates a factory whose transports pipe to dstTrans. */
+  TPipedTransportFactory(boost::shared_ptr<TTransport> dstTrans) {
+    initializeTargetTransport(dstTrans);
+  }
+
+  virtual ~TPipedTransportFactory() {}
+
+  /**
+   * Wraps srcTrans in a new TPipedTransport that pipes to this factory's
+   * target transport.
+   */
+  virtual boost::shared_ptr<TTransport> getTransport(boost::shared_ptr<TTransport> srcTrans) {
+    boost::shared_ptr<TTransport> piped(new TPipedTransport(srcTrans, dstTrans_));
+    return piped;
+  }
+
+  /**
+   * Sets the target transport. It can be set only once.
+   *
+   * @throws TException if a target transport is already set
+   */
+  virtual void initializeTargetTransport(boost::shared_ptr<TTransport> dstTrans) {
+    if (dstTrans_.get() != NULL) {
+      throw TException("Target transport already initialized");
+    }
+    dstTrans_ = dstTrans;
+  }
+
+ protected:
+  boost::shared_ptr<TTransport> dstTrans_;
+};
+
+/**
+ * TPipedFileReaderTransport. This is just like a TPipedTransport, except
+ * that it also implements TFileReaderTransport (by forwarding to the
+ * source transport), so that clients who rely on a TFileReaderTransport
+ * can still access the original transport's functionality.
+ *
+ */
+class TPipedFileReaderTransport : public TPipedTransport,
+ public TFileReaderTransport {
+ public:
+ // Pipes data read from srcTrans to dstTrans. srcTrans is also kept in
+ // srcTrans_ for the TFileReaderTransport functions below.
+ TPipedFileReaderTransport(boost::shared_ptr<TFileReaderTransport> srcTrans, boost::shared_ptr<TTransport> dstTrans);
+
+ ~TPipedFileReaderTransport();
+
+ // TTransport functions
+ // Each forwards to the TPipedTransport implementation, except readAll(),
+ // which loops over read() and throws TEOFException on a short read.
+ bool isOpen();
+ bool peek();
+ void open();
+ void close();
+ uint32_t read(uint8_t* buf, uint32_t len);
+ uint32_t readAll(uint8_t* buf, uint32_t len);
+ void readEnd();
+ void write(const uint8_t* buf, uint32_t len);
+ void writeEnd();
+ void flush();
+
+ // TFileReaderTransport functions
+ // Each forwards directly to srcTrans_.
+ int32_t getReadTimeout();
+ void setReadTimeout(int32_t readTimeout);
+ uint32_t getNumChunks();
+ uint32_t getCurChunk();
+ void seekToChunk(int32_t chunk);
+ void seekToEnd();
+
+ protected:
+ // shouldn't be used
+ // NOTE(review): declared but not defined in TTransportUtils.cpp.
+ TPipedFileReaderTransport();
+ // The same object passed to TPipedTransport as its source, kept with its
+ // TFileReaderTransport type. It hides TPipedTransport::srcTrans_.
+ boost::shared_ptr<TFileReaderTransport> srcTrans_;
+};
+
+/**
+ * Creates a TPipedFileReaderTransport from a source TFileReaderTransport and a destination transport
+ *
+ */
+class TPipedFileReaderTransportFactory : public TPipedTransportFactory {
+ public:
+  TPipedFileReaderTransportFactory() {}
+
+  /** Creates a factory whose transports pipe to dstTrans. */
+  TPipedFileReaderTransportFactory(boost::shared_ptr<TTransport> dstTrans)
+    : TPipedTransportFactory(dstTrans) {}
+
+  virtual ~TPipedFileReaderTransportFactory() {}
+
+  /**
+   * Wraps srcTrans in a TPipedFileReaderTransport.
+   *
+   * @return the piped transport, or an empty pointer when srcTrans is not
+   *         a TFileReaderTransport
+   */
+  boost::shared_ptr<TTransport> getTransport(boost::shared_ptr<TTransport> srcTrans) {
+    boost::shared_ptr<TFileReaderTransport> fileReader =
+      boost::dynamic_pointer_cast<TFileReaderTransport>(srcTrans);
+    if (fileReader.get() == NULL) {
+      return boost::shared_ptr<TTransport>();
+    }
+    return getFileReaderTransport(fileReader);
+  }
+
+  /** Wraps srcTrans in a TPipedFileReaderTransport that pipes to the target transport. */
+  boost::shared_ptr<TFileReaderTransport> getFileReaderTransport(boost::shared_ptr<TFileReaderTransport> srcTrans) {
+    boost::shared_ptr<TFileReaderTransport> piped(new TPipedFileReaderTransport(srcTrans, dstTrans_));
+    return piped;
+  }
+};
+
+}}} // apache::thrift::transport
+
+#endif // #ifndef _THRIFT_TRANSPORT_TTRANSPORTUTILS_H_
diff --git a/lib/cpp/src/transport/TZlibTransport.cpp b/lib/cpp/src/transport/TZlibTransport.cpp
new file mode 100644
index 0000000..2f14e90
--- /dev/null
+++ b/lib/cpp/src/transport/TZlibTransport.cpp
@@ -0,0 +1,299 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <cassert>
+#include <cstring>
+#include <algorithm>
+#include <transport/TZlibTransport.h>
+#include <zlib.h>
+
+using std::string;
+
+namespace apache { namespace thrift { namespace transport {
+
+// Don't call this outside of the constructor.
+void TZlibTransport::initZlib() {
+  int rv;
+  // Set once inflateInit() has succeeded, i.e. once rstream_ needs inflateEnd().
+  bool inflateReady = false;
+  try {
+    rstream_ = new z_stream;
+    wstream_ = new z_stream;
+
+    // Use zlib's default memory allocator for both streams.
+    rstream_->zalloc = Z_NULL;
+    rstream_->zfree = Z_NULL;
+    rstream_->opaque = Z_NULL;
+    wstream_->zalloc = Z_NULL;
+    wstream_->zfree = Z_NULL;
+    wstream_->opaque = Z_NULL;
+
+    // Read side: compressed input comes from crbuf_ (empty for now) and is
+    // inflated into urbuf_.
+    rstream_->next_in = crbuf_;
+    rstream_->avail_in = 0;
+    rstream_->next_out = urbuf_;
+    rstream_->avail_out = urbuf_size_;
+
+    // Write side: uncompressed input comes from uwbuf_ (empty for now) and is
+    // deflated into cwbuf_.
+    wstream_->next_in = uwbuf_;
+    wstream_->avail_in = 0;
+    wstream_->next_out = cwbuf_;
+    wstream_->avail_out = cwbuf_size_;
+
+    rv = inflateInit(rstream_);
+    checkZlibRv(rv, rstream_->msg);
+    inflateReady = true;
+
+    rv = deflateInit(wstream_, Z_DEFAULT_COMPRESSION);
+    checkZlibRv(rv, wstream_->msg);
+  } catch (...) {
+    // deflateInit is the last call that can throw, so wstream_ is never
+    // initialized here; only the inflate side may need tearing down.
+    if (inflateReady) {
+      rv = inflateEnd(rstream_);
+      checkZlibRvNothrow(rv, rstream_->msg);
+    }
+    // NOTE(review): rstream_/wstream_ are not deleted here; presumably the
+    // constructor's error handling frees them -- confirm.
+    throw;
+  }
+}
+
+// Turns any zlib status other than Z_OK into a TZlibTransportException.
+// Callers that treat Z_STREAM_END as success check for it before calling this.
+inline void TZlibTransport::checkZlibRv(int status, const char* message) {
+  if (status == Z_OK) {
+    return;
+  }
+  throw TZlibTransportException(status, message);
+}
+
+// Non-throwing variant of checkZlibRv for cleanup paths (the destructor and
+// initZlib's error handler). Failures are only reported via GlobalOutput.
+inline void TZlibTransport::checkZlibRvNothrow(int status, const char* message) {
+  if (status == Z_OK) {
+    return;
+  }
+  string output = "TZlibTransport: zlib failure in destructor: ";
+  output += TZlibTransportException::errorMessage(status, message);
+  GlobalOutput(output.c_str());
+}
+
+TZlibTransport::~TZlibTransport() {
+  // Shut down both zlib streams first. Failures are logged, never thrown.
+  int rv = inflateEnd(rstream_);
+  checkZlibRvNothrow(rv, rstream_->msg);
+  rv = deflateEnd(wstream_);
+  checkZlibRvNothrow(rv, wstream_->msg);
+
+  // Then release the buffers and the stream structs themselves.
+  delete[] urbuf_;
+  delete[] crbuf_;
+  delete[] uwbuf_;
+  delete[] cwbuf_;
+  delete rstream_;
+  delete wstream_;
+}
+
+// Reported as open while decompressed bytes remain buffered, even if the
+// underlying transport is no longer open.
+bool TZlibTransport::isOpen() {
+  if (readAvail() > 0) {
+    return true;
+  }
+  return transport_->isOpen();
+}
+
+// READING STRATEGY
+//
+// We have two buffers for reading: one containing the compressed data (crbuf_)
+// and one containing the uncompressed data (urbuf_). When read is called,
+// we repeat the following steps until we have satisfied the request:
+// - Copy data from urbuf_ into the caller's buffer.
+// - If we had enough, return.
+// - If urbuf_ is empty, read some data into it from the underlying transport.
+// - Inflate data from crbuf_ into urbuf_.
+//
+// In standalone objects, we set input_ended_ to true when inflate returns
+// Z_STREAM_END. This lets us make sure that a checksum was verified.
+
+// Inflated bytes in urbuf_ that have not been handed out yet: inflate() has
+// written (urbuf_size_ - avail_out) bytes, of which urpos_ are consumed.
+inline int TZlibTransport::readAvail() {
+  return urbuf_size_ - rstream_->avail_out - urpos_;
+}
+
+// Copies up to len decompressed bytes into buf. Loops until the request is
+// satisfied, the underlying transport returns 0 bytes, or (standalone mode
+// only) zlib has reported the end of the stream. Returns the byte count.
+//
+// NOTE(review): need is an int copy of len, so lengths above INT_MAX are not
+// handled. Also, when !standalone_ a Z_STREAM_END does not set input_ended_;
+// zlib's inflate() makes no further progress on a finished stream, so a later
+// read that still needs data may loop here without end -- confirm.
+uint32_t TZlibTransport::read(uint8_t* buf, uint32_t len) {
+ int need = len;
+
+ // TODO(dreiss): Skip urbuf on big reads.
+
+ while (true) {
+ // Copy out whatever we have available, then give them the min of
+ // what we have and what they want, then advance indices.
+ int give = std::min(readAvail(), need);
+ memcpy(buf, urbuf_ + urpos_, give);
+ need -= give;
+ buf += give;
+ urpos_ += give;
+
+ // If they were satisfied, we are done.
+ if (need == 0) {
+ return len;
+ }
+
+ // If we get to this point, we need to get some more data.
+
+ // If zlib has reported the end of a stream, we can't really do any more.
+ if (input_ended_) {
+ return len - need;
+ }
+
+ // The uncompressed read buffer is empty, so reset the stream fields.
+ rstream_->next_out = urbuf_;
+ rstream_->avail_out = urbuf_size_;
+ urpos_ = 0;
+
+ // If we don't have any more compressed data available,
+ // read some from the underlying transport.
+ if (rstream_->avail_in == 0) {
+ uint32_t got = transport_->read(crbuf_, crbuf_size_);
+ // The underlying transport has nothing more: return a short read.
+ if (got == 0) {
+ return len - need;
+ }
+ rstream_->next_in = crbuf_;
+ rstream_->avail_in = got;
+ }
+
+ // We have some compressed data now. Uncompress it.
+ int zlib_rv = inflate(rstream_, Z_SYNC_FLUSH);
+
+ // Z_STREAM_END is not an error; it only ends input for standalone objects.
+ if (zlib_rv == Z_STREAM_END) {
+ if (standalone_) {
+ input_ended_ = true;
+ }
+ } else {
+ checkZlibRv(zlib_rv, rstream_->msg);
+ }
+
+ // Okay. The read buffer should have whatever we can give it now.
+ // Loop back to the start and try to give some more.
+ }
+}
+
+
+// WRITING STRATEGY
+//
+// We buffer up small writes before sending them to zlib, so our logic is:
+// - Is the write big?
+// - Send the buffer to zlib.
+// - Send this data to zlib.
+// - Is the write small?
+// - Is there insufficient space in the buffer for it?
+// - Send the buffer to zlib.
+// - Copy the data to the buffer.
+//
+// We have two buffers for writing also: the uncompressed buffer (mentioned
+// above) and the compressed buffer. When sending data to zlib we loop over
+// the following until the source (uncompressed buffer or big write) is empty:
+// - Is there no more space in the compressed buffer?
+// - Write the compressed buffer to the underlying transport.
+// - Deflate from the source into the compressed buffer.
+
+void TZlibTransport::write(const uint8_t* buf, uint32_t len) {
+  if ((int)len > MIN_DIRECT_DEFLATE_SIZE) {
+    // Big write: drain the small-write buffer first (to keep ordering), then
+    // deflate the caller's data directly without copying it.
+    flushToZlib(uwbuf_, uwpos_);
+    uwpos_ = 0;
+    flushToZlib(buf, len);
+    return;
+  }
+
+  if (len == 0) {
+    return;
+  }
+
+  // Small write: buffer it, making room first if it would not fit. Calling
+  // deflate() for every tiny write is assumed to cost more than a memcpy.
+  if ((int)len > uwbuf_size_ - uwpos_) {
+    flushToZlib(uwbuf_, uwpos_);
+    uwpos_ = 0;
+  }
+  memcpy(uwbuf_ + uwpos_, buf, len);
+  uwpos_ += len;
+}
+
+/**
+ * Deflates everything buffered by write() with Z_FINISH, sends all remaining
+ * compressed output to the underlying transport and flushes it.
+ *
+ * Afterwards the uncompressed write buffer and the compressed output cursor
+ * are reset. Previously both were left as-is, so a redundant flush() (for
+ * which zlib keeps returning Z_STREAM_END) re-sent the last compressed bytes.
+ *
+ * NOTE(review): Z_FINISH ends the deflate stream and wstream_ is never reset,
+ * so zlib will presumably reject data written after flush(). That suits
+ * standalone objects; confirm it is intended for RPC use.
+ */
+void TZlibTransport::flush() {
+  flushToZlib(uwbuf_, uwpos_, true);
+  uwpos_ = 0;
+
+  uint32_t pending = cwbuf_size_ - wstream_->avail_out;
+  if (pending > 0) {
+    transport_->write(cwbuf_, pending);
+  }
+  wstream_->next_out = cwbuf_;
+  wstream_->avail_out = cwbuf_size_;
+
+  transport_->flush();
+}
+
+// Feeds len bytes of buf to deflate(), writing the compressed buffer to the
+// underlying transport each time it fills up. With finish, uses Z_FINISH and
+// keeps going until zlib reports Z_STREAM_END. Any remaining partial output
+// stays in cwbuf_ for the caller.
+void TZlibTransport::flushToZlib(const uint8_t* buf, int len, bool finish) {
+  const int mode = finish ? Z_FINISH : Z_NO_FLUSH;
+
+  // z_stream::next_in is a non-const pointer, hence the const_cast.
+  wstream_->next_in = const_cast<uint8_t*>(buf);
+  wstream_->avail_in = len;
+
+  for (;;) {
+    // Without finish, stop as soon as all input has been consumed.
+    if (!finish && wstream_->avail_in == 0) {
+      break;
+    }
+
+    // Output buffer full: hand it to the underlying transport and reuse it.
+    if (wstream_->avail_out == 0) {
+      transport_->write(cwbuf_, cwbuf_size_);
+      wstream_->next_out = cwbuf_;
+      wstream_->avail_out = cwbuf_size_;
+    }
+
+    const int zlib_rv = deflate(wstream_, mode);
+    if (finish && zlib_rv == Z_STREAM_END) {
+      assert(wstream_->avail_in == 0);
+      break;
+    }
+    checkZlibRv(zlib_rv, wstream_->msg);
+  }
+}
+
+// Lends out already-inflated bytes from urbuf_ when at least *len of them are
+// available. No buffer shifting is attempted, so otherwise NULL is returned
+// and the protocol falls back to read(). The caller's buffer is never used.
+const uint8_t* TZlibTransport::borrow(uint8_t* /* buf */, uint32_t* len) {
+  const int avail = readAvail();
+  if (avail < (int)*len) {
+    return NULL;
+  }
+  *len = (uint32_t)avail;
+  return urbuf_ + urpos_;
+}
+
+// Marks len previously borrowed bytes of urbuf_ as read.
+// Throws BAD_ARGS when fewer than len inflated bytes are available.
+void TZlibTransport::consume(uint32_t len) {
+  if (readAvail() < (int)len) {
+    throw TTransportException(TTransportException::BAD_ARGS,
+                              "consume did not follow a borrow.");
+  }
+  urpos_ += len;
+}
+
+// Ensures a standalone object's zlib stream was read to its end. Reaching the
+// end implies zlib checked the checksum, since a bad checksum makes inflate()
+// fail. Throws BAD_ARGS for non-standalone objects and CORRUPTED_DATA when the
+// stream is incomplete.
+void TZlibTransport::verifyChecksum() {
+  if (!standalone_) {
+    throw TTransportException(
+        TTransportException::BAD_ARGS,
+        "TZLibTransport can only verify checksums for standalone objects.");
+  }
+
+  if (input_ended_) {
+    return;
+  }
+
+  // Reading may be complete while the checksum has not been fed to zlib yet,
+  // so try to read one more byte to force zlib to finish the stream. Getting
+  // a byte (or still not reaching the end) is treated as unrecoverable, so it
+  // does not matter that the byte cannot be "unread".
+  uint8_t probe[1];
+  const uint32_t extra = this->read(probe, sizeof(probe));
+  if (extra != 0 || !input_ended_) {
+    throw TTransportException(
+        TTransportException::CORRUPTED_DATA,
+        "Zlib stream not complete.");
+  }
+}
+
+
+}}} // apache::thrift::transport
diff --git a/lib/cpp/src/transport/TZlibTransport.h b/lib/cpp/src/transport/TZlibTransport.h
new file mode 100644
index 0000000..1439d9d
--- /dev/null
+++ b/lib/cpp/src/transport/TZlibTransport.h
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_
+#define _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_ 1
+
+#include <boost/lexical_cast.hpp>
+#include <transport/TTransport.h>
+
+struct z_stream_s;
+
+namespace apache { namespace thrift { namespace transport {
+
+/**
+ * TTransportException raised for errors reported by zlib.
+ *
+ * The base exception is tagged INTERNAL_ERROR and carries the readable
+ * message built by errorMessage(). The raw zlib status and message are
+ * also kept so callers can inspect them.
+ */
+class TZlibTransportException : public TTransportException {
+ public:
+  /**
+   * @param status zlib status code that triggered the error.
+   * @param msg    zlib's accompanying message; may be NULL.
+   */
+  TZlibTransportException(int status, const char* msg) :
+    TTransportException(TTransportException::INTERNAL_ERROR,
+                        errorMessage(status, msg)),
+    zlib_status_(status),
+    zlib_msg_(msg == NULL ? "(null)" : msg) {}
+
+  virtual ~TZlibTransportException() throw() {}
+
+  /// The zlib status code passed to the constructor.
+  /// const so it can be called on an exception caught by const reference.
+  int getZlibStatus() const { return zlib_status_; }
+  /// The zlib message passed to the constructor, or "(null)" if none.
+  std::string getZlibMessage() const { return zlib_msg_; }
+
+  /// Builds "zlib error: <msg> (status = <status>)", substituting
+  /// "(no message)" when msg is NULL.
+  static std::string errorMessage(int status, const char* msg) {
+    std::string rv = "zlib error: ";
+    if (msg) {
+      rv += msg;
+    } else {
+      rv += "(no message)";
+    }
+    rv += " (status = ";
+    rv += boost::lexical_cast<std::string>(status);
+    rv += ")";
+    return rv;
+  }
+
+  int zlib_status_;
+  std::string zlib_msg_;
+};
+
+/**
+ * This transport uses zlib's compressed format on the "far" side.
+ *
+ * There are two kinds of TZlibTransport objects:
+ * - Standalone objects are used to encode self-contained chunks of data
+ *   (like structures). They include checksums.
+ * - Non-standalone transports are used for RPC. They are not implemented yet.
+ *
+ * TODO(dreiss): Don't do an extra copy of the compressed data if
+ * the underlying transport is TBuffered or TMemory.
+ *
+ */
+class TZlibTransport : public TTransport {
+ public:
+
+  /**
+   * @param transport The transport to read compressed data from
+   *                  and write compressed data to.
+   * @param use_for_rpc True if this object will be used for RPC,
+   *                    false if this is a standalone object.
+   *                    Only standalone objects are currently accepted.
+   * @param urbuf_size Uncompressed buffer size for reading.
+   * @param crbuf_size Compressed buffer size for reading.
+   * @param uwbuf_size Uncompressed buffer size for writing; must be at
+   *                   least MIN_DIRECT_DEFLATE_SIZE.
+   * @param cwbuf_size Compressed buffer size for writing.
+   * @throws TTransportException (BAD_ARGS) if use_for_rpc is true or
+   *         uwbuf_size is too small. Anything thrown by buffer allocation
+   *         or initZlib() is rethrown after the buffers are freed.
+   *
+   * TODO(dreiss): Write a constructor that isn't a pain.
+   */
+  TZlibTransport(boost::shared_ptr<TTransport> transport,
+                 bool use_for_rpc,
+                 int urbuf_size = DEFAULT_URBUF_SIZE,
+                 int crbuf_size = DEFAULT_CRBUF_SIZE,
+                 int uwbuf_size = DEFAULT_UWBUF_SIZE,
+                 int cwbuf_size = DEFAULT_CWBUF_SIZE) :
+    transport_(transport),
+    standalone_(!use_for_rpc),
+    urpos_(0),
+    uwpos_(0),
+    input_ended_(false),
+    output_flushed_(false),
+    urbuf_size_(urbuf_size),
+    crbuf_size_(crbuf_size),
+    uwbuf_size_(uwbuf_size),
+    cwbuf_size_(cwbuf_size),
+    urbuf_(NULL),
+    crbuf_(NULL),
+    uwbuf_(NULL),
+    cwbuf_(NULL),
+    rstream_(NULL),
+    wstream_(NULL)
+  {
+
+    if (!standalone_) {
+      throw TTransportException(
+          TTransportException::BAD_ARGS,
+          "TZLibTransport has not been tested for RPC.");
+    }
+
+    if (uwbuf_size_ < MIN_DIRECT_DEFLATE_SIZE) {
+      // Have to copy this into a local because of a linking issue:
+      // lexical_cast takes its argument by const reference, which would
+      // need an out-of-class definition of the static constant.
+      int minimum = MIN_DIRECT_DEFLATE_SIZE;
+      throw TTransportException(
+          TTransportException::BAD_ARGS,
+          "TZLibTransport: uncompressed write buffer must be at least "
+          + boost::lexical_cast<std::string>(minimum) + ".");
+    }
+
+    try {
+      urbuf_ = new uint8_t[urbuf_size];
+      crbuf_ = new uint8_t[crbuf_size];
+      uwbuf_ = new uint8_t[uwbuf_size];
+      cwbuf_ = new uint8_t[cwbuf_size];
+
+      // Don't call this outside of the constructor.
+      initZlib();
+
+    } catch (...) {
+      // Buffers not yet allocated are still NULL, so delete[] is a no-op
+      // for them. The destructor will not run for a throwing constructor.
+      delete[] urbuf_;
+      delete[] crbuf_;
+      delete[] uwbuf_;
+      delete[] cwbuf_;
+      throw;
+    }
+  }
+
+  // Don't call this outside of the constructor.
+  void initZlib();
+
+  ~TZlibTransport();
+
+  bool isOpen();
+
+  /// Opens the underlying transport.
+  void open() {
+    transport_->open();
+  }
+
+  /// Closes the underlying transport.
+  void close() {
+    transport_->close();
+  }
+
+  uint32_t read(uint8_t* buf, uint32_t len);
+
+  void write(const uint8_t* buf, uint32_t len);
+
+  void flush();
+
+  /// Returns a pointer into the buffered uncompressed data if at least *len
+  /// bytes are available (updating *len), otherwise NULL.
+  const uint8_t* borrow(uint8_t* buf, uint32_t* len);
+
+  /// Advances past len bytes previously obtained via borrow().
+  void consume(uint32_t len);
+
+  /// Standalone objects only: throws if the compressed stream was not
+  /// read through to its end.
+  void verifyChecksum();
+
+  /**
+   * TODO(someone_smart): Choose smart defaults.
+   */
+  static const int DEFAULT_URBUF_SIZE = 128;
+  static const int DEFAULT_CRBUF_SIZE = 1024;
+  static const int DEFAULT_UWBUF_SIZE = 128;
+  static const int DEFAULT_CWBUF_SIZE = 1024;
+
+ protected:
+
+  // Presumably throws a TZlibTransportException when status signals a zlib
+  // error; the Nothrow variant reports without throwing. TODO confirm.
+  inline void checkZlibRv(int status, const char* msg);
+  inline void checkZlibRvNothrow(int status, const char* msg);
+  // Number of decompressed bytes in urbuf_ not yet handed to the caller
+  // (as used by borrow()/consume()).
+  inline int readAvail();
+  // Runs buf through deflate, writing compressed output to transport_;
+  // finish = true ends the stream.
+  void flushToZlib(const uint8_t* buf, int len, bool finish = false);
+
+  // Writes smaller than this are buffered up.
+  // Larger (or equal) writes are dumped straight to zlib.
+  static const int MIN_DIRECT_DEFLATE_SIZE = 32;
+
+  boost::shared_ptr<TTransport> transport_;
+  bool standalone_;
+
+  // Read offset into urbuf_.
+  int urpos_;
+  // Presumably the fill level of uwbuf_. TODO confirm.
+  int uwpos_;
+
+  /// True iff zlib has reached the end of a stream.
+  /// This is only ever true in standalone protocol objects.
+  bool input_ended_;
+  /// True iff we have flushed the output stream.
+  /// This is only ever true in standalone protocol objects.
+  bool output_flushed_;
+
+  int urbuf_size_;
+  int crbuf_size_;
+  int uwbuf_size_;
+  int cwbuf_size_;
+
+  uint8_t* urbuf_;
+  uint8_t* crbuf_;
+  uint8_t* uwbuf_;
+  uint8_t* cwbuf_;
+
+  struct z_stream_s* rstream_;
+  struct z_stream_s* wstream_;
+
+ private:
+  // Not copyable: a member-wise copy would duplicate the raw buffer and
+  // z_stream pointers allocated by this object (presumably released by
+  // ~TZlibTransport), risking a double free. Declared but intentionally
+  // never defined (pre-C++11 idiom).
+  TZlibTransport(const TZlibTransport&);
+  TZlibTransport& operator=(const TZlibTransport&);
+};
+
+}}} // apache::thrift::transport
+
+#endif // #ifndef _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_