Thrift Binary protocol improvements and application exceptions

Summary: Add application exceptions (for unknown methods, etc.), and let the binary protocol enforce size limits on containers and strings.

Reviewed By: aditya, xp-wayne


git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@665003 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/cpp/src/protocol/TBinaryProtocol.cpp b/lib/cpp/src/protocol/TBinaryProtocol.cpp
index 05de125..cea248b 100644
--- a/lib/cpp/src/protocol/TBinaryProtocol.cpp
+++ b/lib/cpp/src/protocol/TBinaryProtocol.cpp
@@ -123,9 +123,12 @@
 
   
 uint32_t TBinaryProtocol::writeString(const string& str) {
-  uint32_t result = writeI32(str.size());
-  trans_->write((uint8_t*)str.data(), str.size());
-  return result + str.size();
+  uint32_t size = str.size();  // string length in bytes, held as uint32_t
+  uint32_t result = writeI32((int32_t)size);  // length prefix is written first
+  if (size > 0) {  // empty string: only the length prefix is written
+    trans_->write((uint8_t*)str.data(), size);
+  }
+  return result + size;  // writeI32's return value plus the payload bytes
 }
 
 /**
@@ -188,7 +191,11 @@
   result += readByte(v);
   valType = (TType)v;
   result += readI32(sizei);
-  // TODO(mcslee): check for negative size
+  if (sizei < 0) {
+    throw TProtocolException(TProtocolException::NEGATIVE_SIZE);
+  } else if (container_limit_ && sizei > container_limit_) {
+    throw TProtocolException(TProtocolException::SIZE_LIMIT);
+  }
   size = (uint32_t)sizei;
   return result;
 }
@@ -205,7 +212,11 @@
   result += readByte(e);
   elemType = (TType)e;
   result += readI32(sizei);
-  // TODO(mcslee): check for negative size
+  if (sizei < 0) {
+    throw TProtocolException(TProtocolException::NEGATIVE_SIZE);
+  } else if (container_limit_ && sizei > container_limit_) {
+    throw TProtocolException(TProtocolException::SIZE_LIMIT);
+  }
   size = (uint32_t)sizei;
   return result;
 }
@@ -222,7 +233,11 @@
   result += readByte(e);
   elemType = (TType)e;
   result += readI32(sizei);
-  // TODO(mcslee): check for negative size
+  if (sizei < 0) {
+    throw TProtocolException(TProtocolException::NEGATIVE_SIZE);
+  } else if (container_limit_ && sizei > container_limit_) {
+    throw TProtocolException(TProtocolException::SIZE_LIMIT);
+  }
   size = (uint32_t)sizei;
   return result;
 }
@@ -290,13 +305,31 @@
   int32_t size;
   result = readI32(size);
 
-  // TODO(mcslee): check for negative size
+  // Catch error cases
+  if (size < 0) {
+    throw TProtocolException(TProtocolException::NEGATIVE_SIZE);
+  }
+  if (string_limit_ > 0 && size > string_limit_) {
+    throw TProtocolException(TProtocolException::SIZE_LIMIT);
+  }
+
+  // Catch empty string case
+  if (size == 0) {
+    str = "";
+    return result;
+  }
 
   // Use the heap here to prevent stack overflow for v. large strings
-  uint8_t *b = new uint8_t[size];
-  trans_->readAll(b, size);
-  str = string((char*)b, size);
-  delete [] b;
+  if (size > string_buf_size_ || string_buf_ == NULL) {
+    void* new_buf = realloc(string_buf_, (uint32_t)size);  // old block stays valid on failure
+    if (new_buf == NULL) {
+      throw TProtocolException(TProtocolException::UNKNOWN, "Out of memory in TBinaryProtocol::readString");
+    }
+    string_buf_ = (uint8_t*)new_buf;
+    string_buf_size_ = size;
+  }
+  trans_->readAll(string_buf_, size);
+  str = string((char*)string_buf_, size);
 
   return result + (uint32_t)size;
 }
diff --git a/lib/cpp/src/protocol/TBinaryProtocol.h b/lib/cpp/src/protocol/TBinaryProtocol.h
index c0a1837..3414cec 100644
--- a/lib/cpp/src/protocol/TBinaryProtocol.h
+++ b/lib/cpp/src/protocol/TBinaryProtocol.h
@@ -15,12 +15,38 @@
  *
  * @author Mark Slee <mcslee@facebook.com>
  */
-    class TBinaryProtocol : public TProtocol {
+class TBinaryProtocol : public TProtocol {
  public:
   TBinaryProtocol(shared_ptr<TTransport> trans) :
-    TProtocol(trans) {}
+    TProtocol(trans),
+    string_limit_(0),
+    container_limit_(0),
+    string_buf_(NULL),
+    string_buf_size_(0) {}
 
-  ~TBinaryProtocol() {}
+  TBinaryProtocol(shared_ptr<TTransport> trans,
+                  int32_t string_limit,
+                  int32_t container_limit) :
+    TProtocol(trans),
+    string_limit_(string_limit),
+    container_limit_(container_limit),
+    string_buf_(NULL),
+    string_buf_size_(0) {}
+
+  ~TBinaryProtocol() {
+    if (string_buf_ != NULL) {
+      free(string_buf_);
+      string_buf_size_ = 0;
+    }
+  }
+
+  void setStringSizeLimit(int32_t string_limit) {
+    string_limit_ = string_limit;
+  }
+
+  void setContainerSizeLimit(int32_t container_limit) {
+    container_limit_ = container_limit;
+  }
 
   /**
    * Writing functions.
@@ -126,6 +152,16 @@
   uint32_t readDouble(double& dub);
 
   uint32_t readString(std::string& str);
+
+ private:
+  int32_t string_limit_;
+  int32_t container_limit_;
+
+  // Buffer for reading strings, save for the lifetime of the protocol to
+  // avoid memory churn allocating memory on every string read
+  uint8_t* string_buf_;
+  int32_t string_buf_size_;
+
 };
 
 /**
@@ -133,13 +169,32 @@
  */
 class TBinaryProtocolFactory : public TProtocolFactory {
  public:
-  TBinaryProtocolFactory() {}
+  TBinaryProtocolFactory() :  // default: both limits start at 0
+    string_limit_(0),
+    container_limit_(0) {}
+
+  TBinaryProtocolFactory(int32_t string_limit, int32_t container_limit) :
+    string_limit_(string_limit),
+    container_limit_(container_limit) {}
 
   virtual ~TBinaryProtocolFactory() {}
 
-  boost::shared_ptr<TProtocol> getProtocol(boost::shared_ptr<TTransport> trans) {
-    return boost::shared_ptr<TProtocol>(new TBinaryProtocol(trans));
+  void setStringSizeLimit(int32_t string_limit) {  // used by later getProtocol() calls
+    string_limit_ = string_limit;
   }
+
+  void setContainerSizeLimit(int32_t container_limit) {  // used by later getProtocol() calls
+    container_limit_ = container_limit;
+  }
+
+  boost::shared_ptr<TProtocol> getProtocol(boost::shared_ptr<TTransport> trans) {
+    return boost::shared_ptr<TProtocol>(new TBinaryProtocol(trans, string_limit_, container_limit_));  // limits are passed by value
+  }
+
+ private:
+  int32_t string_limit_;  // handed to each TBinaryProtocol created by getProtocol()
+  int32_t container_limit_;  // handed to each TBinaryProtocol created by getProtocol()
+
 };
 
 }}} // facebook::thrift::protocol
diff --git a/lib/cpp/src/protocol/TProtocol.h b/lib/cpp/src/protocol/TProtocol.h
index 2ad5a57..42bea2b 100644
--- a/lib/cpp/src/protocol/TProtocol.h
+++ b/lib/cpp/src/protocol/TProtocol.h
@@ -2,6 +2,7 @@
 #define _THRIFT_PROTOCOL_TPROTOCOL_H_ 1
 
 #include <transport/TTransport.h>
+#include <protocol/TProtocolException.h>
 
 #include <boost/shared_ptr.hpp>
 
@@ -59,7 +60,8 @@
  */
 enum TMessageType {
   T_CALL       = 1,
-  T_REPLY      = 2
+  T_REPLY      = 2,
+  T_EXCEPTION  = 3
 };
 
 /**
diff --git a/lib/cpp/src/protocol/TProtocolException.h b/lib/cpp/src/protocol/TProtocolException.h
new file mode 100644
index 0000000..74c55cd
--- /dev/null
+++ b/lib/cpp/src/protocol/TProtocolException.h
@@ -0,0 +1,78 @@
+#ifndef _THRIFT_PROTOCOL_TPROTOCOLEXCEPTION_H_
+#define _THRIFT_PROTOCOL_TPROTOCOLEXCEPTION_H_ 1
+
+#include <boost/lexical_cast.hpp>
+#include <string>
+
+namespace facebook { namespace thrift { namespace protocol { 
+
+/**
+ * Exception type for errors raised by protocol implementations. Each
+ * instance carries a TProtocolExceptionType error code and, optionally,
+ * a message, so that a common code base of error handling can be used
+ * across the various protocols.
+ *
+ * @author Mark Slee <mcslee@facebook.com>
+ */
+class TProtocolException : public facebook::thrift::TException {
+ public:
+
+  /** Error codes for the various types of exceptions. */
+  enum TProtocolExceptionType {
+    UNKNOWN = 0,
+    INVALID_DATA = 1,
+    NEGATIVE_SIZE = 2,
+    SIZE_LIMIT = 3
+  };
+
+  TProtocolException() :
+    facebook::thrift::TException(),
+    type_(UNKNOWN) {}
+
+  TProtocolException(TProtocolExceptionType type) :
+    facebook::thrift::TException(),
+    type_(type) {}
+
+  TProtocolException(const std::string& message) :
+    facebook::thrift::TException(message),
+    type_(UNKNOWN) {}
+
+  TProtocolException(TProtocolExceptionType type, const std::string& message) :
+    facebook::thrift::TException(message),
+    type_(type) {}
+
+  virtual ~TProtocolException() throw() {}
+
+  /**
+   * Returns the error code describing the type of error that occurred.
+   * Const so it can be called on exceptions caught by const reference.
+   */
+  TProtocolExceptionType getType() const { return type_; }
+
+  /**
+   * Returns the message given at construction or, if it is empty, a fixed
+   * default naming the numeric error code. Only the stored message or a
+   * string literal is returned, so the pointer never refers to a temporary.
+   */
+  virtual const char* what() const throw() {
+    if (!message_.empty()) {
+      return message_.c_str();
+    }
+    switch (type_) {
+      case UNKNOWN:       return "Default Protocol Exception: 0";
+      case INVALID_DATA:  return "Default Protocol Exception: 1";
+      case NEGATIVE_SIZE: return "Default Protocol Exception: 2";
+      case SIZE_LIMIT:    return "Default Protocol Exception: 3";
+    }
+    return "Default Protocol Exception: (invalid type)";
+  }
+
+ protected:
+  /** Error code */
+  TProtocolExceptionType type_;
+ 
+};
+
+}}} // facebook::thrift::protocol
+
+#endif // #ifndef _THRIFT_PROTOCOL_TPROTOCOLEXCEPTION_H_