THRIFT-5237 Implement MAX_MESSAGE_SIZE and consolidate limits into a TConfiguration class
Client: cpp
Patch: Zezeng Wang
This closes #2185
diff --git a/lib/cpp/Makefile.am b/lib/cpp/Makefile.am
index a536d17..3a0c4e6 100755
--- a/lib/cpp/Makefile.am
+++ b/lib/cpp/Makefile.am
@@ -141,7 +141,8 @@
src/thrift/TApplicationException.h \
src/thrift/TLogging.h \
src/thrift/TToString.h \
- src/thrift/TBase.h
+ src/thrift/TBase.h \
+ src/thrift/TConfiguration.h
include_concurrencydir = $(include_thriftdir)/concurrency
include_concurrency_HEADERS = \
@@ -156,6 +157,10 @@
include_protocoldir = $(include_thriftdir)/protocol
include_protocol_HEADERS = \
+ src/thrift/protocol/TEnum.h \
+ src/thrift/protocol/TList.h \
+ src/thrift/protocol/TSet.h \
+ src/thrift/protocol/TMap.h \
src/thrift/protocol/TBinaryProtocol.h \
src/thrift/protocol/TBinaryProtocol.tcc \
src/thrift/protocol/TCompactProtocol.h \
diff --git a/lib/cpp/src/thrift/TConfiguration.h b/lib/cpp/src/thrift/TConfiguration.h
new file mode 100644
index 0000000..5bff440
--- /dev/null
+++ b/lib/cpp/src/thrift/TConfiguration.h
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef THRIFT_TCONFIGURATION_H
+#define THRIFT_TCONFIGURATION_H
+
+namespace apache {
+namespace thrift {
+
+class TConfiguration
+{
+public:
+ TConfiguration(int maxMessageSize = DEFAULT_MAX_MESSAGE_SIZE,
+ int maxFrameSize = DEFAULT_MAX_FRAME_SIZE, int recursionLimit = DEFAULT_RECURSION_DEPTH)
+ : maxMessageSize_(maxMessageSize), maxFrameSize_(maxFrameSize), recursionLimit_(recursionLimit) {}
+
+ const static int DEFAULT_MAX_MESSAGE_SIZE = 100 * 1024 * 1024;
+ const static int DEFAULT_MAX_FRAME_SIZE = 16384000; // this value is used consistently across all Thrift libraries
+ const static int DEFAULT_RECURSION_DEPTH = 64;
+
+ inline int getMaxMessageSize() { return maxMessageSize_; }
+ inline void setMaxMessageSize(int maxMessageSize) { maxMessageSize_ = maxMessageSize; }
+ inline int getMaxFrameSize() { return maxFrameSize_; }
+ inline void setMaxFrameSize(int maxFrameSize) { maxFrameSize_ = maxFrameSize; }
+ inline int getRecursionLimit() { return recursionLimit_; }
+ inline void setRecursionLimit(int recursionLimit) { recursionLimit_ = recursionLimit; }
+
+private:
+ int maxMessageSize_ = DEFAULT_MAX_MESSAGE_SIZE;
+ int maxFrameSize_ = DEFAULT_MAX_FRAME_SIZE;
+ int recursionLimit_ = DEFAULT_RECURSION_DEPTH;
+
+ // TODO(someone_smart): add connection and i/o timeouts
+};
+}
+} // apache::thrift
+
+#endif /* THRIFT_TCONFIGURATION_H */
+
diff --git a/lib/cpp/src/thrift/protocol/TBinaryProtocol.h b/lib/cpp/src/thrift/protocol/TBinaryProtocol.h
index 6bd5fb8..b431440 100644
--- a/lib/cpp/src/thrift/protocol/TBinaryProtocol.h
+++ b/lib/cpp/src/thrift/protocol/TBinaryProtocol.h
@@ -166,6 +166,24 @@
inline uint32_t readBinary(std::string& str);
+ int getMinSerializedSize(TType type);
+
+ void checkReadBytesAvailable(TSet& set)
+ {
+ trans_->checkReadBytesAvailable(set.size_ * getMinSerializedSize(set.elemType_));
+ }
+
+ void checkReadBytesAvailable(TList& list)
+ {
+ trans_->checkReadBytesAvailable(list.size_ * getMinSerializedSize(list.elemType_));
+ }
+
+ void checkReadBytesAvailable(TMap& map)
+ {
+ int elmSize = getMinSerializedSize(map.keyType_) + getMinSerializedSize(map.valueType_);
+ trans_->checkReadBytesAvailable(map.size_ * elmSize);
+ }
+
protected:
template <typename StrType>
uint32_t readStringBody(StrType& str, int32_t sz);
diff --git a/lib/cpp/src/thrift/protocol/TBinaryProtocol.tcc b/lib/cpp/src/thrift/protocol/TBinaryProtocol.tcc
index 2964f25..755f243 100644
--- a/lib/cpp/src/thrift/protocol/TBinaryProtocol.tcc
+++ b/lib/cpp/src/thrift/protocol/TBinaryProtocol.tcc
@@ -21,6 +21,7 @@
#define _THRIFT_PROTOCOL_TBINARYPROTOCOL_TCC_ 1
#include <thrift/protocol/TBinaryProtocol.h>
+#include <thrift/transport/TTransportException.h>
#include <limits>
@@ -285,6 +286,10 @@
throw TProtocolException(TProtocolException::SIZE_LIMIT);
}
size = (uint32_t)sizei;
+
+ TMap map(keyType, valType, size);
+ checkReadBytesAvailable(map);
+
return result;
}
@@ -307,6 +312,10 @@
throw TProtocolException(TProtocolException::SIZE_LIMIT);
}
size = (uint32_t)sizei;
+
+ TList list(elemType, size);
+ checkReadBytesAvailable(list);
+
return result;
}
@@ -329,6 +338,10 @@
throw TProtocolException(TProtocolException::SIZE_LIMIT);
}
size = (uint32_t)sizei;
+
+ TSet set(elemType, size);
+ checkReadBytesAvailable(set);
+
return result;
}
@@ -447,6 +460,30 @@
this->trans_->readAll(reinterpret_cast<uint8_t*>(&str[0]), size);
return (uint32_t)size;
}
+
+// Return the minimum number of bytes a type will consume on the wire
+template <class Transport_, class ByteOrder_>
+int TBinaryProtocolT<Transport_, ByteOrder_>::getMinSerializedSize(TType type)
+{
+ switch (type)
+ {
+ case T_STOP: return 0;
+ case T_VOID: return 0;
+ case T_BOOL: return sizeof(int8_t);
+ case T_BYTE: return sizeof(int8_t);
+ case T_DOUBLE: return sizeof(double);
+ case T_I16: return sizeof(short);
+ case T_I32: return sizeof(int);
+ case T_I64: return sizeof(long);
+ case T_STRING: return sizeof(int); // string length
+ case T_STRUCT: return 0; // empty struct
+ case T_MAP: return sizeof(int); // element count
+ case T_SET: return sizeof(int); // element count
+ case T_LIST: return sizeof(int); // element count
+ default: throw TProtocolException(TProtocolException::UNKNOWN, "unrecognized type code");
+ }
+}
+
}
}
} // apache::thrift::protocol
diff --git a/lib/cpp/src/thrift/protocol/TCompactProtocol.h b/lib/cpp/src/thrift/protocol/TCompactProtocol.h
index 2930aba..6f990b2 100644
--- a/lib/cpp/src/thrift/protocol/TCompactProtocol.h
+++ b/lib/cpp/src/thrift/protocol/TCompactProtocol.h
@@ -140,6 +140,24 @@
uint32_t writeBinary(const std::string& str);
+ int getMinSerializedSize(TType type);
+
+ void checkReadBytesAvailable(TSet& set)
+ {
+ trans_->checkReadBytesAvailable(set.size_ * getMinSerializedSize(set.elemType_));
+ }
+
+ void checkReadBytesAvailable(TList& list)
+ {
+ trans_->checkReadBytesAvailable(list.size_ * getMinSerializedSize(list.elemType_));
+ }
+
+ void checkReadBytesAvailable(TMap& map)
+ {
+ int elmSize = getMinSerializedSize(map.keyType_) + getMinSerializedSize(map.valueType_);
+ trans_->checkReadBytesAvailable(map.size_ * elmSize);
+ }
+
/**
* These methods are called by structs, but don't actually have any wired
* output or purpose
diff --git a/lib/cpp/src/thrift/protocol/TCompactProtocol.tcc b/lib/cpp/src/thrift/protocol/TCompactProtocol.tcc
index d1e342e..1678091 100644
--- a/lib/cpp/src/thrift/protocol/TCompactProtocol.tcc
+++ b/lib/cpp/src/thrift/protocol/TCompactProtocol.tcc
@@ -538,6 +538,9 @@
valType = getTType((int8_t)((uint8_t)kvType & 0xf));
size = (uint32_t)msize;
+ TMap map(keyType, valType, size);
+ checkReadBytesAvailable(map);
+
return rsize;
}
@@ -570,6 +573,9 @@
elemType = getTType((int8_t)(size_and_type & 0x0f));
size = (uint32_t)lsize;
+ TList list(elemType, size);
+ checkReadBytesAvailable(list);
+
return rsize;
}
@@ -706,6 +712,8 @@
trans_->readAll(string_buf_, size);
str.assign((char*)string_buf_, size);
+ trans_->checkReadBytesAvailable(rsize + (uint32_t)size);
+
return rsize + (uint32_t)size;
}
@@ -821,6 +829,30 @@
}
}
+// Return the minimum number of bytes a type will consume on the wire
+template <class Transport_>
+int TCompactProtocolT<Transport_>::getMinSerializedSize(TType type)
+{
+ switch (type)
+ {
+ case T_STOP: return 0;
+ case T_VOID: return 0;
+ case T_BOOL: return sizeof(int8_t);
+ case T_DOUBLE: return 8; // uses fixedLongToBytes() which always writes 8 bytes
+ case T_BYTE: return sizeof(int8_t);
+ case T_I16: return sizeof(int8_t); // zigzag
+ case T_I32: return sizeof(int8_t); // zigzag
+ case T_I64: return sizeof(int8_t); // zigzag
+ case T_STRING: return sizeof(int8_t); // string length
+ case T_STRUCT: return 0; // empty struct
+ case T_MAP: return sizeof(int8_t); // element count
+ case T_SET: return sizeof(int8_t); // element count
+ case T_LIST: return sizeof(int8_t); // element count
+ default: throw TProtocolException(TProtocolException::UNKNOWN, "unrecognized type code");
+ }
+}
+
+
}}} // apache::thrift::protocol
#endif // _THRIFT_PROTOCOL_TCOMPACTPROTOCOL_TCC_
diff --git a/lib/cpp/src/thrift/protocol/TEnum.h b/lib/cpp/src/thrift/protocol/TEnum.h
new file mode 100644
index 0000000..9636785
--- /dev/null
+++ b/lib/cpp/src/thrift/protocol/TEnum.h
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_ENUM_H_
+#define _THRIFT_ENUM_H_
+
+namespace apache {
+namespace thrift {
+namespace protocol {
+
+/**
+ * Enumerated definition of the types that the Thrift protocol supports.
+ * Take special note of the T_END type which is used specifically to mark
+ * the end of a sequence of fields.
+ */
+enum TType {
+ T_STOP = 0,
+ T_VOID = 1,
+ T_BOOL = 2,
+ T_BYTE = 3,
+ T_I08 = 3,
+ T_I16 = 6,
+ T_I32 = 8,
+ T_U64 = 9,
+ T_I64 = 10,
+ T_DOUBLE = 4,
+ T_STRING = 11,
+ T_UTF7 = 11,
+ T_STRUCT = 12,
+ T_MAP = 13,
+ T_SET = 14,
+ T_LIST = 15,
+ T_UTF8 = 16,
+ T_UTF16 = 17
+};
+
+/**
+ * Enumerated definition of the message types that the Thrift protocol
+ * supports.
+ */
+enum TMessageType {
+ T_CALL = 1,
+ T_REPLY = 2,
+ T_EXCEPTION = 3,
+ T_ONEWAY = 4
+};
+
+}}} // apache::thrift::protocol
+
+#endif // #define _THRIFT_ENUM_H_
diff --git a/lib/cpp/src/thrift/protocol/TJSONProtocol.cpp b/lib/cpp/src/thrift/protocol/TJSONProtocol.cpp
index 28d0da2..6e4e8ef 100644
--- a/lib/cpp/src/thrift/protocol/TJSONProtocol.cpp
+++ b/lib/cpp/src/thrift/protocol/TJSONProtocol.cpp
@@ -1013,6 +1013,10 @@
throw TProtocolException(TProtocolException::SIZE_LIMIT);
size = static_cast<uint32_t>(tmpVal);
result += readJSONObjectStart();
+
+ TMap map(keyType, valType, size);
+ checkReadBytesAvailable(map);
+
return result;
}
@@ -1032,6 +1036,10 @@
if (tmpVal > (std::numeric_limits<uint32_t>::max)())
throw TProtocolException(TProtocolException::SIZE_LIMIT);
size = static_cast<uint32_t>(tmpVal);
+
+ TList list(elemType, size);
+ checkReadBytesAvailable(list);
+
return result;
}
@@ -1049,6 +1057,10 @@
if (tmpVal > (std::numeric_limits<uint32_t>::max)())
throw TProtocolException(TProtocolException::SIZE_LIMIT);
size = static_cast<uint32_t>(tmpVal);
+
+ TSet set(elemType, size);
+ checkReadBytesAvailable(set);
+
return result;
}
@@ -1093,6 +1105,29 @@
uint32_t TJSONProtocol::readBinary(std::string& str) {
return readJSONBase64(str);
}
+
+// Return the minimum number of bytes a type will consume on the wire
+int TJSONProtocol::getMinSerializedSize(TType type)
+{
+ switch (type)
+ {
+ case T_STOP: return 0;
+ case T_VOID: return 0;
+ case T_BOOL: return 1; // written as int
+ case T_BYTE: return 1;
+ case T_DOUBLE: return 1;
+ case T_I16: return 1;
+ case T_I32: return 1;
+ case T_I64: return 1;
+ case T_STRING: return 2; // empty string
+ case T_STRUCT: return 2; // empty struct
+ case T_MAP: return 2; // empty map
+ case T_SET: return 2; // empty set
+ case T_LIST: return 2; // empty list
+ default: throw TProtocolException(TProtocolException::UNKNOWN, "unrecognized type code");
+ }
+}
+
}
}
} // apache::thrift::protocol
diff --git a/lib/cpp/src/thrift/protocol/TJSONProtocol.h b/lib/cpp/src/thrift/protocol/TJSONProtocol.h
index 420995e..e775240 100644
--- a/lib/cpp/src/thrift/protocol/TJSONProtocol.h
+++ b/lib/cpp/src/thrift/protocol/TJSONProtocol.h
@@ -245,6 +245,24 @@
uint32_t readBinary(std::string& str);
+ int getMinSerializedSize(TType type);
+
+ void checkReadBytesAvailable(TSet& set)
+ {
+ trans_->checkReadBytesAvailable(set.size_ * getMinSerializedSize(set.elemType_));
+ }
+
+ void checkReadBytesAvailable(TList& list)
+ {
+ trans_->checkReadBytesAvailable(list.size_ * getMinSerializedSize(list.elemType_));
+ }
+
+ void checkReadBytesAvailable(TMap& map)
+ {
+ int elmSize = getMinSerializedSize(map.keyType_) + getMinSerializedSize(map.valueType_);
+ trans_->checkReadBytesAvailable(map.size_ * elmSize);
+ }
+
class LookaheadReader {
public:
diff --git a/lib/cpp/src/thrift/protocol/TList.h b/lib/cpp/src/thrift/protocol/TList.h
new file mode 100644
index 0000000..bf2c1f9
--- /dev/null
+++ b/lib/cpp/src/thrift/protocol/TList.h
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_TLIST_H_
+#define _THRIFT_TLIST_H_
+
+#include <thrift/protocol/TEnum.h>
+
+namespace apache {
+namespace thrift {
+namespace protocol {
+
+// using namespace apache::thrift::protocol;
+
+/**
+ * Helper class that encapsulates list metadata.
+ *
+ */
+class TList {
+public:
+ TList() : elemType_(T_STOP),
+ size_(0) {
+
+ }
+
+ TList(TType t = T_STOP, int s = 0)
+ : elemType_(t),
+ size_(s) {
+
+ }
+
+ TType elemType_;
+ int size_;
+};
+}
+}
+} // apache::thrift::protocol
+
+#endif // #ifndef _THRIFT_TLIST_H_
diff --git a/lib/cpp/src/thrift/protocol/TMap.h b/lib/cpp/src/thrift/protocol/TMap.h
new file mode 100644
index 0000000..b52ea8f
--- /dev/null
+++ b/lib/cpp/src/thrift/protocol/TMap.h
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_TMAP_H_
+#define _THRIFT_TMAP_H_
+
+#include <thrift/protocol/TEnum.h>
+
+namespace apache {
+namespace thrift {
+namespace protocol {
+
+using namespace apache::thrift::protocol;
+
+/**
+ * Helper class that encapsulates map metadata.
+ *
+ */
+class TMap {
+public:
+ TMap()
+ : keyType_(T_STOP),
+ valueType_(T_STOP),
+ size_(0) {
+
+ }
+
+ TMap(TType k, TType v, int s)
+ : keyType_(k),
+ valueType_(v),
+ size_(s) {
+
+ }
+
+ TType keyType_;
+ TType valueType_;
+ int size_;
+};
+}
+}
+} // apache::thrift::protocol
+
+#endif // #ifndef _THRIFT_TMAP_H_
diff --git a/lib/cpp/src/thrift/protocol/TProtocol.h b/lib/cpp/src/thrift/protocol/TProtocol.h
index df9c5c3..867ceb0 100644
--- a/lib/cpp/src/thrift/protocol/TProtocol.h
+++ b/lib/cpp/src/thrift/protocol/TProtocol.h
@@ -27,6 +27,10 @@
#include <thrift/transport/TTransport.h>
#include <thrift/protocol/TProtocolException.h>
+#include <thrift/protocol/TEnum.h>
+#include <thrift/protocol/TList.h>
+#include <thrift/protocol/TSet.h>
+#include <thrift/protocol/TMap.h>
#include <memory>
@@ -171,45 +175,6 @@
using apache::thrift::transport::TTransport;
/**
- * Enumerated definition of the types that the Thrift protocol supports.
- * Take special note of the T_END type which is used specifically to mark
- * the end of a sequence of fields.
- */
-enum TType {
- T_STOP = 0,
- T_VOID = 1,
- T_BOOL = 2,
- T_BYTE = 3,
- T_I08 = 3,
- T_I16 = 6,
- T_I32 = 8,
- T_U64 = 9,
- T_I64 = 10,
- T_DOUBLE = 4,
- T_STRING = 11,
- T_UTF7 = 11,
- T_STRUCT = 12,
- T_MAP = 13,
- T_SET = 14,
- T_LIST = 15,
- T_UTF8 = 16,
- T_UTF16 = 17
-};
-
-/**
- * Enumerated definition of the message types that the Thrift protocol
- * supports.
- */
-enum TMessageType {
- T_CALL = 1,
- T_REPLY = 2,
- T_EXCEPTION = 3,
- T_ONEWAY = 4
-};
-
-static const uint32_t DEFAULT_RECURSION_LIMIT = 64;
-
-/**
* Abstract class for a thrift protocol driver. These are all the methods that
* a protocol must implement. Essentially, there must be some way of reading
* and writing all the base types, plus a mechanism for writing out structs
@@ -578,11 +543,34 @@
uint32_t getRecursionLimit() const {return recursion_limit_;}
void setRecurisionLimit(uint32_t depth) {recursion_limit_ = depth;}
+ // Returns the minimum amount of bytes needed to store the smallest possible instance of TType.
+ virtual int getMinSerializedSize(TType type) {
+ THRIFT_UNUSED_VARIABLE(type);
+ return 0;
+ }
+
protected:
TProtocol(std::shared_ptr<TTransport> ptrans)
- : ptrans_(ptrans), input_recursion_depth_(0), output_recursion_depth_(0), recursion_limit_(DEFAULT_RECURSION_LIMIT)
+ : ptrans_(ptrans), input_recursion_depth_(0), output_recursion_depth_(0),
+ recursion_limit_(ptrans->getConfiguration()->getRecursionLimit())
{}
+ virtual void checkReadBytesAvailable(TSet& set)
+ {
+ ptrans_->checkReadBytesAvailable(set.size_ * getMinSerializedSize(set.elemType_));
+ }
+
+ virtual void checkReadBytesAvailable(TList& list)
+ {
+ ptrans_->checkReadBytesAvailable(list.size_ * getMinSerializedSize(list.elemType_));
+ }
+
+ virtual void checkReadBytesAvailable(TMap& map)
+ {
+ int elmSize = getMinSerializedSize(map.keyType_) + getMinSerializedSize(map.valueType_);
+ ptrans_->checkReadBytesAvailable(map.size_ * elmSize);
+ }
+
std::shared_ptr<TTransport> ptrans_;
private:
diff --git a/lib/cpp/src/thrift/protocol/TSet.h b/lib/cpp/src/thrift/protocol/TSet.h
new file mode 100644
index 0000000..3a4718c
--- /dev/null
+++ b/lib/cpp/src/thrift/protocol/TSet.h
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef _THRIFT_TSET_H_
+#define _THRIFT_TSET_H_
+
+#include <thrift/protocol/TEnum.h>
+#include <thrift/protocol/TList.h>
+
+namespace apache {
+namespace thrift {
+namespace protocol {
+
+using namespace apache::thrift::protocol;
+
+/**
+ * Helper class that encapsulates set metadata.
+ *
+ */
+class TSet {
+public:
+ TSet() : elemType_(T_STOP), size_(0) {
+
+ }
+
+ TSet(TType t, int s)
+ : elemType_(t),
+ size_(s) {
+
+ }
+
+ TSet(TList list)
+ : elemType_(list.elemType_),
+ size_(list.size_) {
+
+ }
+
+ TType elemType_;
+ int size_;
+};
+}
+}
+} // apache::thrift::protocol
+
+#endif // #ifndef _THRIFT_TSET_H_
diff --git a/lib/cpp/src/thrift/transport/TBufferTransports.cpp b/lib/cpp/src/thrift/transport/TBufferTransports.cpp
index d8a1b3e..45c0c9b 100644
--- a/lib/cpp/src/thrift/transport/TBufferTransports.cpp
+++ b/lib/cpp/src/thrift/transport/TBufferTransports.cpp
@@ -118,6 +118,7 @@
}
void TBufferedTransport::flush() {
+ resetConsumedMessageSize();
// Write out any data waiting in the write buffer.
auto have_bytes = static_cast<uint32_t>(wBase_ - wBuf_.get());
if (have_bytes > 0) {
@@ -248,6 +249,7 @@
}
void TFramedTransport::flush() {
+ resetConsumedMessageSize();
int32_t sz_hbo, sz_nbo;
assert(wBufSize_ > sizeof(sz_nbo));
diff --git a/lib/cpp/src/thrift/transport/TBufferTransports.h b/lib/cpp/src/thrift/transport/TBufferTransports.h
index 86f0c5a..179934b 100644
--- a/lib/cpp/src/thrift/transport/TBufferTransports.h
+++ b/lib/cpp/src/thrift/transport/TBufferTransports.h
@@ -62,6 +62,7 @@
* This method is meant to eventually be nonvirtual and inlinable.
*/
uint32_t read(uint8_t* buf, uint32_t len) {
+ checkReadBytesAvailable(len);
uint8_t* new_rBase = rBase_ + len;
if (TDB_LIKELY(new_rBase <= rBound_)) {
std::memcpy(buf, rBase_, len);
@@ -120,6 +121,7 @@
* Consume doesn't require a slow path.
*/
void consume(uint32_t len) {
+ countConsumedMessageBytes(len);
if (TDB_LIKELY(static_cast<ptrdiff_t>(len) <= rBound_ - rBase_)) {
rBase_ += len;
} else {
@@ -148,7 +150,8 @@
* performance-sensitive operation, so it is okay to just leave it to
* the concrete class to set up pointers correctly.
*/
- TBufferBase() : rBase_(nullptr), rBound_(nullptr), wBase_(nullptr), wBound_(nullptr) {}
+ TBufferBase(std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(config), rBase_(nullptr), rBound_(nullptr), wBase_(nullptr), wBound_(nullptr) {}
/// Convenience mutator for setting the read buffer.
void setReadBuffer(uint8_t* buf, uint32_t len) {
@@ -186,8 +189,9 @@
static const int DEFAULT_BUFFER_SIZE = 512;
/// Use default buffer sizes.
- TBufferedTransport(std::shared_ptr<TTransport> transport)
- : transport_(transport),
+ TBufferedTransport(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(config),
+ transport_(transport),
rBufSize_(DEFAULT_BUFFER_SIZE),
wBufSize_(DEFAULT_BUFFER_SIZE),
rBuf_(new uint8_t[rBufSize_]),
@@ -196,8 +200,9 @@
}
/// Use specified buffer sizes.
- TBufferedTransport(std::shared_ptr<TTransport> transport, uint32_t sz)
- : transport_(transport),
+ TBufferedTransport(std::shared_ptr<TTransport> transport, uint32_t sz, std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(config),
+ transport_(transport),
rBufSize_(sz),
wBufSize_(sz),
rBuf_(new uint8_t[rBufSize_]),
@@ -206,8 +211,10 @@
}
/// Use specified read and write buffer sizes.
- TBufferedTransport(std::shared_ptr<TTransport> transport, uint32_t rsz, uint32_t wsz)
- : transport_(transport),
+ TBufferedTransport(std::shared_ptr<TTransport> transport, uint32_t rsz, uint32_t wsz,
+ std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(config),
+ transport_(transport),
rBufSize_(rsz),
wBufSize_(wsz),
rBuf_(new uint8_t[rBufSize_]),
@@ -309,8 +316,9 @@
static const int DEFAULT_MAX_FRAME_SIZE = 256 * 1024 * 1024;
/// Use default buffer sizes.
- TFramedTransport()
- : transport_(),
+ TFramedTransport(std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(config),
+ transport_(),
rBufSize_(0),
wBufSize_(DEFAULT_BUFFER_SIZE),
rBuf_(),
@@ -319,27 +327,30 @@
initPointers();
}
- TFramedTransport(std::shared_ptr<TTransport> transport)
- : transport_(transport),
+ TFramedTransport(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(config),
+ transport_(transport),
rBufSize_(0),
wBufSize_(DEFAULT_BUFFER_SIZE),
rBuf_(),
wBuf_(new uint8_t[wBufSize_]),
bufReclaimThresh_((std::numeric_limits<uint32_t>::max)()),
- maxFrameSize_(DEFAULT_MAX_FRAME_SIZE) {
+ maxFrameSize_(configuration_->getMaxFrameSize()) {
initPointers();
}
TFramedTransport(std::shared_ptr<TTransport> transport,
uint32_t sz,
- uint32_t bufReclaimThresh = (std::numeric_limits<uint32_t>::max)())
- : transport_(transport),
+ uint32_t bufReclaimThresh = (std::numeric_limits<uint32_t>::max)(),
+ std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(config),
+ transport_(transport),
rBufSize_(0),
wBufSize_(sz),
rBuf_(),
wBuf_(new uint8_t[wBufSize_]),
bufReclaimThresh_(bufReclaimThresh),
- maxFrameSize_(DEFAULT_MAX_FRAME_SIZE) {
+ maxFrameSize_(configuration_->getMaxFrameSize()) {
initPointers();
}
@@ -503,7 +514,10 @@
* Construct a TMemoryBuffer with a default-sized buffer,
* owned by the TMemoryBuffer object.
*/
- TMemoryBuffer() { initCommon(nullptr, defaultSize, true, 0); }
+ TMemoryBuffer(std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(config) {
+ initCommon(nullptr, defaultSize, true, 0);
+ }
/**
* Construct a TMemoryBuffer with a buffer of a specified size,
@@ -511,7 +525,10 @@
*
* @param sz The initial size of the buffer.
*/
- TMemoryBuffer(uint32_t sz) { initCommon(nullptr, sz, true, 0); }
+ TMemoryBuffer(uint32_t sz, std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(config) {
+ initCommon(nullptr, sz, true, 0);
+ }
/**
* Construct a TMemoryBuffer with buf as its initial contents.
@@ -523,7 +540,8 @@
* @param sz The size of @c buf.
* @param policy See @link MemoryPolicy @endlink .
*/
- TMemoryBuffer(uint8_t* buf, uint32_t sz, MemoryPolicy policy = OBSERVE) {
+ TMemoryBuffer(uint8_t* buf, uint32_t sz, MemoryPolicy policy = OBSERVE, std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(config) {
if (buf == nullptr && sz != 0) {
throw TTransportException(TTransportException::BAD_ARGS,
"TMemoryBuffer given null buffer with non-zero size.");
diff --git a/lib/cpp/src/thrift/transport/TFDTransport.cpp b/lib/cpp/src/thrift/transport/TFDTransport.cpp
index 93dd100..fa7f0da 100644
--- a/lib/cpp/src/thrift/transport/TFDTransport.cpp
+++ b/lib/cpp/src/thrift/transport/TFDTransport.cpp
@@ -52,6 +52,7 @@
}
uint32_t TFDTransport::read(uint8_t* buf, uint32_t len) {
+ checkReadBytesAvailable(len);
unsigned int maxRetries = 5; // same as the TSocket default
unsigned int retries = 0;
while (true) {
diff --git a/lib/cpp/src/thrift/transport/TFDTransport.h b/lib/cpp/src/thrift/transport/TFDTransport.h
index a3cf519..fb84c9d 100644
--- a/lib/cpp/src/thrift/transport/TFDTransport.h
+++ b/lib/cpp/src/thrift/transport/TFDTransport.h
@@ -40,8 +40,10 @@
public:
enum ClosePolicy { NO_CLOSE_ON_DESTROY = 0, CLOSE_ON_DESTROY = 1 };
- TFDTransport(int fd, ClosePolicy close_policy = NO_CLOSE_ON_DESTROY)
- : fd_(fd), close_policy_(close_policy) {}
+ TFDTransport(int fd, ClosePolicy close_policy = NO_CLOSE_ON_DESTROY,
+ std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(config), fd_(fd), close_policy_(close_policy) {
+ }
~TFDTransport() override {
if (close_policy_ == CLOSE_ON_DESTROY) {
diff --git a/lib/cpp/src/thrift/transport/TFileTransport.cpp b/lib/cpp/src/thrift/transport/TFileTransport.cpp
index eaf2bc3..08372b3 100644
--- a/lib/cpp/src/thrift/transport/TFileTransport.cpp
+++ b/lib/cpp/src/thrift/transport/TFileTransport.cpp
@@ -63,8 +63,9 @@
using namespace apache::thrift::protocol;
using namespace apache::thrift::concurrency;
-TFileTransport::TFileTransport(string path, bool readOnly)
- : readState_(),
+TFileTransport::TFileTransport(string path, bool readOnly, std::shared_ptr<TConfiguration> config)
+ : TTransport(config),
+ readState_(),
readBuff_(nullptr),
currentEvent_(nullptr),
readBuffSize_(DEFAULT_READ_BUFF_SIZE),
@@ -519,6 +520,7 @@
}
void TFileTransport::flush() {
+ resetConsumedMessageSize();
// file must be open for writing for any flushing to take place
if (!writerThread_.get()) {
return;
@@ -537,6 +539,7 @@
}
uint32_t TFileTransport::readAll(uint8_t* buf, uint32_t len) {
+ checkReadBytesAvailable(len);
uint32_t have = 0;
uint32_t get = 0;
@@ -568,6 +571,7 @@
}
uint32_t TFileTransport::read(uint8_t* buf, uint32_t len) {
+ checkReadBytesAvailable(len);
// check if there an event is ready to be read
if (!currentEvent_) {
currentEvent_ = readEvent();
diff --git a/lib/cpp/src/thrift/transport/TFileTransport.h b/lib/cpp/src/thrift/transport/TFileTransport.h
index 0df5cf9..608cff1 100644
--- a/lib/cpp/src/thrift/transport/TFileTransport.h
+++ b/lib/cpp/src/thrift/transport/TFileTransport.h
@@ -173,7 +173,7 @@
*/
class TFileTransport : public TFileReaderTransport, public TFileWriterTransport {
public:
- TFileTransport(std::string path, bool readOnly = false);
+ TFileTransport(std::string path, bool readOnly = false, std::shared_ptr<TConfiguration> config = nullptr);
~TFileTransport() override;
// TODO: what is the correct behaviour for this?
diff --git a/lib/cpp/src/thrift/transport/THeaderTransport.cpp b/lib/cpp/src/thrift/transport/THeaderTransport.cpp
index b582d8d..b3b8333 100644
--- a/lib/cpp/src/thrift/transport/THeaderTransport.cpp
+++ b/lib/cpp/src/thrift/transport/THeaderTransport.cpp
@@ -415,6 +415,7 @@
}
void THeaderTransport::flush() {
+ resetConsumedMessageSize();
// Write out any data waiting in the write buffer.
uint32_t haveBytes = getWriteBytes();
diff --git a/lib/cpp/src/thrift/transport/THeaderTransport.h b/lib/cpp/src/thrift/transport/THeaderTransport.h
index d1e9d43..63a4ac8 100644
--- a/lib/cpp/src/thrift/transport/THeaderTransport.h
+++ b/lib/cpp/src/thrift/transport/THeaderTransport.h
@@ -74,8 +74,9 @@
static const int THRIFT_MAX_VARINT32_BYTES = 5;
/// Use default buffer sizes.
- explicit THeaderTransport(const std::shared_ptr<TTransport>& transport)
- : TVirtualTransport(transport),
+ explicit THeaderTransport(const std::shared_ptr<TTransport>& transport,
+ std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(transport, config),
outTransport_(transport),
protoId(T_COMPACT_PROTOCOL),
clientType(THRIFT_HEADER_CLIENT_TYPE),
@@ -88,8 +89,9 @@
}
THeaderTransport(const std::shared_ptr<TTransport> inTransport,
- const std::shared_ptr<TTransport> outTransport)
- : TVirtualTransport(inTransport),
+ const std::shared_ptr<TTransport> outTransport,
+ std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(inTransport, config),
outTransport_(outTransport),
protoId(T_COMPACT_PROTOCOL),
clientType(THRIFT_HEADER_CLIENT_TYPE),
diff --git a/lib/cpp/src/thrift/transport/THttpClient.cpp b/lib/cpp/src/thrift/transport/THttpClient.cpp
index fdee787..ea2eb99 100644
--- a/lib/cpp/src/thrift/transport/THttpClient.cpp
+++ b/lib/cpp/src/thrift/transport/THttpClient.cpp
@@ -34,12 +34,16 @@
THttpClient::THttpClient(std::shared_ptr<TTransport> transport,
std::string host,
- std::string path)
- : THttpTransport(transport), host_(host), path_(path) {
+ std::string path,
+ std::shared_ptr<TConfiguration> config)
+ : THttpTransport(transport, config),
+ host_(host),
+ path_(path) {
}
-THttpClient::THttpClient(string host, int port, string path)
- : THttpTransport(std::shared_ptr<TTransport>(new TSocket(host, port))),
+THttpClient::THttpClient(string host, int port, string path,
+ std::shared_ptr<TConfiguration> config)
+ : THttpTransport(std::shared_ptr<TTransport>(new TSocket(host, port)), config),
host_(host),
path_(path) {
}
@@ -93,6 +97,7 @@
}
void THttpClient::flush() {
+ resetConsumedMessageSize();
// Fetch the contents of the write buffer
uint8_t* buf;
uint32_t len;
diff --git a/lib/cpp/src/thrift/transport/THttpClient.h b/lib/cpp/src/thrift/transport/THttpClient.h
index 81ddc56..f0d7e8b 100644
--- a/lib/cpp/src/thrift/transport/THttpClient.h
+++ b/lib/cpp/src/thrift/transport/THttpClient.h
@@ -40,13 +40,16 @@
*/
THttpClient(std::shared_ptr<TTransport> transport,
std::string host = "localhost",
- std::string path = "/service");
+ std::string path = "/service",
+ std::shared_ptr<TConfiguration> config = nullptr);
/**
* @brief Constructor that will create a new socket transport using the host
* and port.
*/
- THttpClient(std::string host, int port, std::string path = "");
+ THttpClient(std::string host, int port,
+ std::string path = "",
+ std::shared_ptr<TConfiguration> config = nullptr);
~THttpClient() override;
diff --git a/lib/cpp/src/thrift/transport/THttpServer.cpp b/lib/cpp/src/thrift/transport/THttpServer.cpp
index 98518fd..91a1c39 100644
--- a/lib/cpp/src/thrift/transport/THttpServer.cpp
+++ b/lib/cpp/src/thrift/transport/THttpServer.cpp
@@ -34,7 +34,9 @@
namespace thrift {
namespace transport {
-THttpServer::THttpServer(std::shared_ptr<TTransport> transport) : THttpTransport(transport) {
+THttpServer::THttpServer(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config)
+ : THttpTransport(transport, config) {
+
}
THttpServer::~THttpServer() = default;
@@ -118,6 +120,7 @@
}
void THttpServer::flush() {
+ resetConsumedMessageSize();
// Fetch the contents of the write buffer
uint8_t* buf;
uint32_t len;
diff --git a/lib/cpp/src/thrift/transport/THttpServer.h b/lib/cpp/src/thrift/transport/THttpServer.h
index d219691..bc98986 100644
--- a/lib/cpp/src/thrift/transport/THttpServer.h
+++ b/lib/cpp/src/thrift/transport/THttpServer.h
@@ -28,7 +28,7 @@
class THttpServer : public THttpTransport {
public:
- THttpServer(std::shared_ptr<TTransport> transport);
+ THttpServer(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config = nullptr);
~THttpServer() override;
diff --git a/lib/cpp/src/thrift/transport/THttpTransport.cpp b/lib/cpp/src/thrift/transport/THttpTransport.cpp
index aea2b28..305221e 100644
--- a/lib/cpp/src/thrift/transport/THttpTransport.cpp
+++ b/lib/cpp/src/thrift/transport/THttpTransport.cpp
@@ -31,8 +31,9 @@
const char* THttpTransport::CRLF = "\r\n";
const int THttpTransport::CRLF_LEN = 2;
-THttpTransport::THttpTransport(std::shared_ptr<TTransport> transport)
- : transport_(transport),
+THttpTransport::THttpTransport(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config)
+ : TVirtualTransport(config),
+ transport_(transport),
origin_(""),
readHeaders_(true),
chunked_(false),
@@ -61,6 +62,7 @@
}
uint32_t THttpTransport::read(uint8_t* buf, uint32_t len) {
+ checkReadBytesAvailable(len);
if (readBuffer_.available_read() == 0) {
readBuffer_.resetBuffer();
uint32_t got = readMoreData();
diff --git a/lib/cpp/src/thrift/transport/THttpTransport.h b/lib/cpp/src/thrift/transport/THttpTransport.h
index 75f0d8c..5d2bd37 100644
--- a/lib/cpp/src/thrift/transport/THttpTransport.h
+++ b/lib/cpp/src/thrift/transport/THttpTransport.h
@@ -36,7 +36,7 @@
*/
class THttpTransport : public TVirtualTransport<THttpTransport> {
public:
- THttpTransport(std::shared_ptr<TTransport> transport);
+ THttpTransport(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config = nullptr);
~THttpTransport() override;
@@ -54,7 +54,9 @@
void write(const uint8_t* buf, uint32_t len);
- void flush() override = 0;
+ void flush() override {
+ resetConsumedMessageSize();
+ };
const std::string getOrigin() const override;
diff --git a/lib/cpp/src/thrift/transport/TPipe.cpp b/lib/cpp/src/thrift/transport/TPipe.cpp
index 4c2fea9..953cec1 100644
--- a/lib/cpp/src/thrift/transport/TPipe.cpp
+++ b/lib/cpp/src/thrift/transport/TPipe.cpp
@@ -222,30 +222,35 @@
}
//---- Constructors ----
-TPipe::TPipe(TAutoHandle &Pipe)
- : impl_(new TWaitableNamedPipeImpl(Pipe)), TimeoutSeconds_(3), isAnonymous_(false) {
+TPipe::TPipe(TAutoHandle &Pipe, std::shared_ptr<TConfiguration> config)
+ : impl_(new TWaitableNamedPipeImpl(Pipe)), TimeoutSeconds_(3),
+ isAnonymous_(false), TVirtualTransport(config) {
}
-TPipe::TPipe(HANDLE Pipe)
- : TimeoutSeconds_(3), isAnonymous_(false)
+TPipe::TPipe(HANDLE Pipe, std::shared_ptr<TConfiguration> config)
+ : TimeoutSeconds_(3), isAnonymous_(false), TVirtualTransport(config)
{
TAutoHandle pipeHandle(Pipe);
impl_.reset(new TWaitableNamedPipeImpl(pipeHandle));
}
-TPipe::TPipe(const char* pipename) : TimeoutSeconds_(3), isAnonymous_(false) {
+TPipe::TPipe(const char* pipename, std::shared_ptr<TConfiguration> config) : TimeoutSeconds_(3),
+ isAnonymous_(false), TVirtualTransport(config) {
setPipename(pipename);
}
-TPipe::TPipe(const std::string& pipename) : TimeoutSeconds_(3), isAnonymous_(false) {
+TPipe::TPipe(const std::string& pipename, std::shared_ptr<TConfiguration> config) : TimeoutSeconds_(3),
+ isAnonymous_(false), TVirtualTransport(config) {
setPipename(pipename);
}
-TPipe::TPipe(HANDLE PipeRd, HANDLE PipeWrt)
- : impl_(new TAnonPipeImpl(PipeRd, PipeWrt)), TimeoutSeconds_(3), isAnonymous_(true) {
+TPipe::TPipe(HANDLE PipeRd, HANDLE PipeWrt, std::shared_ptr<TConfiguration> config)
+ : impl_(new TAnonPipeImpl(PipeRd, PipeWrt)), TimeoutSeconds_(3), isAnonymous_(true),
+ TVirtualTransport(config) {
}
-TPipe::TPipe() : TimeoutSeconds_(3), isAnonymous_(false) {
+TPipe::TPipe(std::shared_ptr<TConfiguration> config) : TimeoutSeconds_(3), isAnonymous_(false),
+ TVirtualTransport(config) {
}
TPipe::~TPipe() {
@@ -299,6 +304,7 @@
}
uint32_t TPipe::read(uint8_t* buf, uint32_t len) {
+ checkReadBytesAvailable(len);
if (!isOpen())
throw TTransportException(TTransportException::NOT_OPEN, "Called read on non-open pipe");
return impl_->read(buf, len);
diff --git a/lib/cpp/src/thrift/transport/TPipe.h b/lib/cpp/src/thrift/transport/TPipe.h
index ba149b1..7795151 100644
--- a/lib/cpp/src/thrift/transport/TPipe.h
+++ b/lib/cpp/src/thrift/transport/TPipe.h
@@ -49,15 +49,15 @@
class TPipe : public TVirtualTransport<TPipe> {
public:
// Constructs a new pipe object.
- TPipe();
+ TPipe(std::shared_ptr<TConfiguration> config = nullptr);
// Named pipe constructors -
- explicit TPipe(HANDLE Pipe); // HANDLE is a void*
- explicit TPipe(TAutoHandle& Pipe); // this ctor will clear out / move from Pipe
+ explicit TPipe(HANDLE Pipe, std::shared_ptr<TConfiguration> config = nullptr); // HANDLE is a void*
+ explicit TPipe(TAutoHandle& Pipe, std::shared_ptr<TConfiguration> config = nullptr); // this ctor will clear out / move from Pipe
// need a const char * overload so string literals don't go to the HANDLE overload
- explicit TPipe(const char* pipename);
- explicit TPipe(const std::string& pipename);
+ explicit TPipe(const char* pipename, std::shared_ptr<TConfiguration> config = nullptr);
+ explicit TPipe(const std::string& pipename, std::shared_ptr<TConfiguration> config = nullptr);
// Anonymous pipe -
- TPipe(HANDLE PipeRd, HANDLE PipeWrt);
+ TPipe(HANDLE PipeRd, HANDLE PipeWrt, std::shared_ptr<TConfiguration> config = nullptr);
// Destroys the pipe object, closing it if necessary.
virtual ~TPipe();
diff --git a/lib/cpp/src/thrift/transport/TSSLSocket.cpp b/lib/cpp/src/thrift/transport/TSSLSocket.cpp
index aa76980..9efc5fc 100644
--- a/lib/cpp/src/thrift/transport/TSSLSocket.cpp
+++ b/lib/cpp/src/thrift/transport/TSSLSocket.cpp
@@ -214,34 +214,37 @@
}
// TSSLSocket implementation
-TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx)
- : TSocket(), server_(false), ssl_(nullptr), ctx_(ctx) {
+TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<TConfiguration> config)
+ : TSocket(config), server_(false), ssl_(nullptr), ctx_(ctx) {
init();
}
-TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<THRIFT_SOCKET> interruptListener)
- : TSocket(), server_(false), ssl_(nullptr), ctx_(ctx) {
+TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<THRIFT_SOCKET> interruptListener,
+ std::shared_ptr<TConfiguration> config)
+ : TSocket(config), server_(false), ssl_(nullptr), ctx_(ctx) {
init();
interruptListener_ = interruptListener;
}
-TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket)
- : TSocket(socket), server_(false), ssl_(nullptr), ctx_(ctx) {
+TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<TConfiguration> config)
+ : TSocket(socket, config), server_(false), ssl_(nullptr), ctx_(ctx) {
init();
}
-TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener)
- : TSocket(socket, interruptListener), server_(false), ssl_(nullptr), ctx_(ctx) {
+TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener,
+ std::shared_ptr<TConfiguration> config)
+ : TSocket(socket, interruptListener, config), server_(false), ssl_(nullptr), ctx_(ctx) {
init();
}
-TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, string host, int port)
- : TSocket(host, port), server_(false), ssl_(nullptr), ctx_(ctx) {
+TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, string host, int port, std::shared_ptr<TConfiguration> config)
+ : TSocket(host, port, config), server_(false), ssl_(nullptr), ctx_(ctx) {
init();
}
-TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, string host, int port, std::shared_ptr<THRIFT_SOCKET> interruptListener)
- : TSocket(host, port), server_(false), ssl_(nullptr), ctx_(ctx) {
+TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, string host, int port, std::shared_ptr<THRIFT_SOCKET> interruptListener,
+ std::shared_ptr<TConfiguration> config)
+ : TSocket(host, port, config), server_(false), ssl_(nullptr), ctx_(ctx) {
init();
interruptListener_ = interruptListener;
}
@@ -391,6 +394,7 @@
* exception incase of failure.
*/
uint32_t TSSLSocket::read(uint8_t* buf, uint32_t len) {
+ checkReadBytesAvailable(len);
initializeHandshake();
if (!checkHandshake())
throw TTransportException(TTransportException::UNKNOWN, "retry again");
@@ -553,6 +557,7 @@
}
void TSSLSocket::flush() {
+ resetConsumedMessageSize();
// Don't throw exception if not open. Thrift servers close socket twice.
if (ssl_ == nullptr) {
return;
diff --git a/lib/cpp/src/thrift/transport/TSSLSocket.h b/lib/cpp/src/thrift/transport/TSSLSocket.h
index a78112c..5afc571 100644
--- a/lib/cpp/src/thrift/transport/TSSLSocket.h
+++ b/lib/cpp/src/thrift/transport/TSSLSocket.h
@@ -111,37 +111,40 @@
/**
* Constructor.
*/
- TSSLSocket(std::shared_ptr<SSLContext> ctx);
+ TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructor with an interrupt signal.
*/
- TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<THRIFT_SOCKET> interruptListener);
+ TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<THRIFT_SOCKET> interruptListener,
+ std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructor, create an instance of TSSLSocket given an existing socket.
*
* @param socket An existing socket
*/
- TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket);
+ TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructor, create an instance of TSSLSocket given an existing socket that can be interrupted.
*
* @param socket An existing socket
*/
- TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener);
+ TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener,
+ std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructor.
*
* @param host Remote host name
* @param port Remote port number
*/
- TSSLSocket(std::shared_ptr<SSLContext> ctx, std::string host, int port);
+ TSSLSocket(std::shared_ptr<SSLContext> ctx, std::string host, int port, std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructor with an interrupt signal.
*
* @param host Remote host name
* @param port Remote port number
*/
- TSSLSocket(std::shared_ptr<SSLContext> ctx, std::string host, int port, std::shared_ptr<THRIFT_SOCKET> interruptListener);
+ TSSLSocket(std::shared_ptr<SSLContext> ctx, std::string host, int port, std::shared_ptr<THRIFT_SOCKET> interruptListener,
+ std::shared_ptr<TConfiguration> config = nullptr);
/**
* Authorize peer access after SSL handshake completes.
*/
diff --git a/lib/cpp/src/thrift/transport/TShortReadTransport.h b/lib/cpp/src/thrift/transport/TShortReadTransport.h
index 185c78d..c99e6a7 100644
--- a/lib/cpp/src/thrift/transport/TShortReadTransport.h
+++ b/lib/cpp/src/thrift/transport/TShortReadTransport.h
@@ -38,8 +38,10 @@
*/
class TShortReadTransport : public TVirtualTransport<TShortReadTransport> {
public:
- TShortReadTransport(std::shared_ptr<TTransport> transport, double full_prob)
- : transport_(transport), fullProb_(full_prob) {}
+ TShortReadTransport(std::shared_ptr<TTransport> transport, double full_prob,
+ std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(config), transport_(transport), fullProb_(full_prob) {
+ }
bool isOpen() const override { return transport_->isOpen(); }
@@ -50,6 +52,7 @@
void close() override { transport_->close(); }
uint32_t read(uint8_t* buf, uint32_t len) {
+ checkReadBytesAvailable(len);
if (len == 0) {
return 0;
}
@@ -62,11 +65,17 @@
void write(const uint8_t* buf, uint32_t len) { transport_->write(buf, len); }
- void flush() override { transport_->flush(); }
+ void flush() override {
+ resetConsumedMessageSize();
+ transport_->flush();
+ }
const uint8_t* borrow(uint8_t* buf, uint32_t* len) { return transport_->borrow(buf, len); }
- void consume(uint32_t len) { return transport_->consume(len); }
+ void consume(uint32_t len) {
+ countConsumedMessageBytes(len);
+ return transport_->consume(len);
+ }
std::shared_ptr<TTransport> getUnderlyingTransport() { return transport_; }
diff --git a/lib/cpp/src/thrift/transport/TSimpleFileTransport.cpp b/lib/cpp/src/thrift/transport/TSimpleFileTransport.cpp
index 4b1399e..c41affb 100644
--- a/lib/cpp/src/thrift/transport/TSimpleFileTransport.cpp
+++ b/lib/cpp/src/thrift/transport/TSimpleFileTransport.cpp
@@ -35,8 +35,8 @@
namespace thrift {
namespace transport {
-TSimpleFileTransport::TSimpleFileTransport(const std::string& path, bool read, bool write)
- : TFDTransport(-1, TFDTransport::CLOSE_ON_DESTROY) {
+TSimpleFileTransport::TSimpleFileTransport(const std::string& path, bool read, bool write, std::shared_ptr<TConfiguration> config)
+ : TFDTransport(-1, TFDTransport::CLOSE_ON_DESTROY, config) {
int flags = 0;
if (read && write) {
flags = O_RDWR;
diff --git a/lib/cpp/src/thrift/transport/TSimpleFileTransport.h b/lib/cpp/src/thrift/transport/TSimpleFileTransport.h
index 32e1897..24741b0 100644
--- a/lib/cpp/src/thrift/transport/TSimpleFileTransport.h
+++ b/lib/cpp/src/thrift/transport/TSimpleFileTransport.h
@@ -33,7 +33,8 @@
*/
class TSimpleFileTransport : public TFDTransport {
public:
- TSimpleFileTransport(const std::string& path, bool read = true, bool write = false);
+ TSimpleFileTransport(const std::string& path, bool read = true, bool write = false,
+ std::shared_ptr<TConfiguration> config = nullptr);
};
}
}
diff --git a/lib/cpp/src/thrift/transport/TSocket.cpp b/lib/cpp/src/thrift/transport/TSocket.cpp
index a1a6dfb..81aaccf 100644
--- a/lib/cpp/src/thrift/transport/TSocket.cpp
+++ b/lib/cpp/src/thrift/transport/TSocket.cpp
@@ -77,8 +77,9 @@
*
*/
-TSocket::TSocket(const string& host, int port)
- : host_(host),
+TSocket::TSocket(const string& host, int port, std::shared_ptr<TConfiguration> config)
+ : TVirtualTransport(config),
+ host_(host),
port_(port),
socket_(THRIFT_INVALID_SOCKET),
peerPort_(0),
@@ -92,8 +93,9 @@
maxRecvRetries_(5) {
}
-TSocket::TSocket(const string& path)
- : port_(0),
+TSocket::TSocket(const string& path, std::shared_ptr<TConfiguration> config)
+ : TVirtualTransport(config),
+ port_(0),
path_(path),
socket_(THRIFT_INVALID_SOCKET),
peerPort_(0),
@@ -108,8 +110,9 @@
cachedPeerAddr_.ipv4.sin_family = AF_UNSPEC;
}
-TSocket::TSocket()
- : port_(0),
+TSocket::TSocket(std::shared_ptr<TConfiguration> config)
+ : TVirtualTransport(config),
+ port_(0),
socket_(THRIFT_INVALID_SOCKET),
peerPort_(0),
connTimeout_(0),
@@ -123,8 +126,9 @@
cachedPeerAddr_.ipv4.sin_family = AF_UNSPEC;
}
-TSocket::TSocket(THRIFT_SOCKET socket)
- : port_(0),
+TSocket::TSocket(THRIFT_SOCKET socket, std::shared_ptr<TConfiguration> config)
+ : TVirtualTransport(config),
+ port_(0),
socket_(socket),
peerPort_(0),
connTimeout_(0),
@@ -144,8 +148,10 @@
#endif
}
-TSocket::TSocket(THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener)
- : port_(0),
+TSocket::TSocket(THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener,
+ std::shared_ptr<TConfiguration> config)
+ : TVirtualTransport(config),
+ port_(0),
socket_(socket),
peerPort_(0),
interruptListener_(interruptListener),
@@ -522,6 +528,7 @@
}
uint32_t TSocket::read(uint8_t* buf, uint32_t len) {
+ checkReadBytesAvailable(len);
if (socket_ == THRIFT_INVALID_SOCKET) {
throw TTransportException(TTransportException::NOT_OPEN, "Called read on non-open socket");
}
diff --git a/lib/cpp/src/thrift/transport/TSocket.h b/lib/cpp/src/thrift/transport/TSocket.h
index b0e8ade..043f0de 100644
--- a/lib/cpp/src/thrift/transport/TSocket.h
+++ b/lib/cpp/src/thrift/transport/TSocket.h
@@ -52,7 +52,7 @@
* socket.
*
*/
- TSocket();
+ TSocket(std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructs a new socket. Note that this does NOT actually connect the
@@ -61,7 +61,7 @@
* @param host An IP address or hostname to connect to
* @param port The port to connect on
*/
- TSocket(const std::string& host, int port);
+ TSocket(const std::string& host, int port, std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructs a new Unix domain socket.
@@ -69,7 +69,7 @@
*
* @param path The Unix domain socket e.g. "/tmp/ThriftTest.binary.thrift"
*/
- TSocket(const std::string& path);
+ TSocket(const std::string& path, std::shared_ptr<TConfiguration> config = nullptr);
/**
* Destroyes the socket object, closing it if necessary.
@@ -264,13 +264,14 @@
/**
* Constructor to create socket from file descriptor.
*/
- TSocket(THRIFT_SOCKET socket);
+ TSocket(THRIFT_SOCKET socket, std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructor to create socket from file descriptor that
* can be interrupted safely.
*/
- TSocket(THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener);
+ TSocket(THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener,
+ std::shared_ptr<TConfiguration> config = nullptr);
/**
* Set a cache of the peer address (used when trivially available: e.g.
diff --git a/lib/cpp/src/thrift/transport/TTransport.h b/lib/cpp/src/thrift/transport/TTransport.h
index 6397882..5f657f8 100644
--- a/lib/cpp/src/thrift/transport/TTransport.h
+++ b/lib/cpp/src/thrift/transport/TTransport.h
@@ -21,6 +21,7 @@
#define _THRIFT_TRANSPORT_TTRANSPORT_H_ 1
#include <thrift/Thrift.h>
+#include <thrift/TConfiguration.h>
#include <thrift/transport/TTransportException.h>
#include <memory>
#include <string>
@@ -55,6 +56,15 @@
*/
class TTransport {
public:
+ TTransport(std::shared_ptr<TConfiguration> config = nullptr) {
+ if(config == nullptr) {
+ configuration_ = std::shared_ptr<TConfiguration> (new TConfiguration());
+ } else {
+ configuration_ = config;
+ }
+ resetConsumedMessageSize();
+ }
+
/**
* Virtual deconstructor.
*/
@@ -238,11 +248,87 @@
*/
virtual const std::string getOrigin() const { return "Unknown"; }
-protected:
+ std::shared_ptr<TConfiguration> getConfiguration() { return configuration_; }
+
+ void setConfiguration(std::shared_ptr<TConfiguration> config) {
+ if (config != nullptr) configuration_ = config;
+ }
+
/**
- * Simple constructor.
+ * Updates RemainingMessageSize to reflect then known real message size (e.g. framed transport).
+ * Will throw if we already consumed too many bytes or if the new size is larger than allowed.
+ *
+ * @param size real message size
*/
- TTransport() = default;
+ void updateKnownMessageSize(long int size)
+ {
+ long int consumed = knownMessageSize_ - remainingMessageSize_;
+ resetConsumedMessageSize(size);
+ countConsumedMessageBytes(consumed);
+ }
+
+ /**
+ * Throws if there are not enough bytes in the input stream to satisfy a read of numBytes bytes of data
+ *
+ * @param numBytes numBytes bytes of data
+ */
+ void checkReadBytesAvailable(long int numBytes)
+ {
+ if (remainingMessageSize_ < numBytes)
+ throw new TTransportException(TTransportException::END_OF_FILE, "MaxMessageSize reached");
+ }
+
+protected:
+ std::shared_ptr<TConfiguration> configuration_;
+ long int remainingMessageSize_;
+ long int knownMessageSize_;
+
+ inline long int getRemainingMessageSize() { return remainingMessageSize_; }
+ inline void setRemainingMessageSize(long int remainingMessageSize) { remainingMessageSize_ = remainingMessageSize; }
+ inline int getMaxMessageSize() { return configuration_->getMaxMessageSize(); }
+ inline long int getKnownMessageSize() { return knownMessageSize_; }
+ void setKnownMessageSize(long int knownMessageSize) { knownMessageSize_ = knownMessageSize; }
+
+ /**
+ * Resets RemainingMessageSize to the configured maximum
+ *
+ * @param newSize configured size
+ */
+ void resetConsumedMessageSize(long newSize = -1)
+ {
+ // full reset
+ if (newSize < 0)
+ {
+ knownMessageSize_ = getMaxMessageSize();
+ remainingMessageSize_ = getMaxMessageSize();
+ return;
+ }
+
+ // update only: message size can shrink, but not grow
+ if (newSize > knownMessageSize_)
+ throw new TTransportException(TTransportException::END_OF_FILE, "MaxMessageSize reached");
+
+ knownMessageSize_ = newSize;
+ remainingMessageSize_ = newSize;
+ }
+
+ /**
+ * Consumes numBytes from the RemainingMessageSize.
+ *
+ * @param numBytes Consumes numBytes
+ */
+ void countConsumedMessageBytes(long int numBytes)
+ {
+ if (remainingMessageSize_ >= numBytes)
+ {
+ remainingMessageSize_ -= numBytes;
+ }
+ else
+ {
+ remainingMessageSize_ = 0;
+ throw new TTransportException(TTransportException::END_OF_FILE, "MaxMessageSize reached");
+ }
+ }
};
/**
diff --git a/lib/cpp/src/thrift/transport/TTransportUtils.cpp b/lib/cpp/src/thrift/transport/TTransportUtils.cpp
index 69372f3..427a2e7 100644
--- a/lib/cpp/src/thrift/transport/TTransportUtils.cpp
+++ b/lib/cpp/src/thrift/transport/TTransportUtils.cpp
@@ -26,6 +26,7 @@
namespace transport {
uint32_t TPipedTransport::read(uint8_t* buf, uint32_t len) {
+ checkReadBytesAvailable(len);
uint32_t need = len;
// We don't have enough data yet
@@ -104,8 +105,9 @@
TPipedFileReaderTransport::TPipedFileReaderTransport(
std::shared_ptr<TFileReaderTransport> srcTrans,
- std::shared_ptr<TTransport> dstTrans)
- : TPipedTransport(srcTrans, dstTrans), srcTrans_(srcTrans) {
+ std::shared_ptr<TTransport> dstTrans,
+ std::shared_ptr<TConfiguration> config)
+ : TPipedTransport(srcTrans, dstTrans, config), srcTrans_(srcTrans) {
}
TPipedFileReaderTransport::~TPipedFileReaderTransport() = default;
@@ -131,6 +133,7 @@
}
uint32_t TPipedFileReaderTransport::readAll(uint8_t* buf, uint32_t len) {
+ checkReadBytesAvailable(len);
uint32_t have = 0;
uint32_t get = 0;
diff --git a/lib/cpp/src/thrift/transport/TTransportUtils.h b/lib/cpp/src/thrift/transport/TTransportUtils.h
index 28c93d2..68c25f4 100644
--- a/lib/cpp/src/thrift/transport/TTransportUtils.h
+++ b/lib/cpp/src/thrift/transport/TTransportUtils.h
@@ -63,8 +63,10 @@
*/
class TPipedTransport : virtual public TTransport {
public:
- TPipedTransport(std::shared_ptr<TTransport> srcTrans, std::shared_ptr<TTransport> dstTrans)
- : srcTrans_(srcTrans),
+ TPipedTransport(std::shared_ptr<TTransport> srcTrans, std::shared_ptr<TTransport> dstTrans,
+ std::shared_ptr<TConfiguration> config = nullptr)
+ : TTransport(config),
+ srcTrans_(srcTrans),
dstTrans_(dstTrans),
rBufSize_(512),
rPos_(0),
@@ -88,8 +90,10 @@
TPipedTransport(std::shared_ptr<TTransport> srcTrans,
std::shared_ptr<TTransport> dstTrans,
- uint32_t sz)
- : srcTrans_(srcTrans),
+ uint32_t sz,
+ std::shared_ptr<TConfiguration> config = nullptr)
+ : TTransport(config),
+ srcTrans_(srcTrans),
dstTrans_(dstTrans),
rBufSize_(512),
rPos_(0),
@@ -241,7 +245,8 @@
class TPipedFileReaderTransport : public TPipedTransport, public TFileReaderTransport {
public:
TPipedFileReaderTransport(std::shared_ptr<TFileReaderTransport> srcTrans,
- std::shared_ptr<TTransport> dstTrans);
+ std::shared_ptr<TTransport> dstTrans,
+ std::shared_ptr<TConfiguration> config = nullptr);
~TPipedFileReaderTransport() override;
diff --git a/lib/cpp/src/thrift/transport/TVirtualTransport.h b/lib/cpp/src/thrift/transport/TVirtualTransport.h
index 0a04857..44bfa13 100644
--- a/lib/cpp/src/thrift/transport/TVirtualTransport.h
+++ b/lib/cpp/src/thrift/transport/TVirtualTransport.h
@@ -57,7 +57,7 @@
void consume(uint32_t len) { this->TTransport::consume_virt(len); }
protected:
- TTransportDefaults() = default;
+ TTransportDefaults(std::shared_ptr<TConfiguration> config = nullptr) : TTransport(config) {}
};
/**
@@ -118,7 +118,7 @@
}
protected:
- TVirtualTransport() = default;
+ TVirtualTransport() : Super_() {}
/*
* Templatized constructors, to allow arguments to be passed to the Super_
diff --git a/lib/cpp/src/thrift/transport/TWebSocketServer.h b/lib/cpp/src/thrift/transport/TWebSocketServer.h
index 2e94c83..7f39f36 100644
--- a/lib/cpp/src/thrift/transport/TWebSocketServer.h
+++ b/lib/cpp/src/thrift/transport/TWebSocketServer.h
@@ -53,8 +53,8 @@
template <bool binary>
class TWebSocketServer : public THttpServer {
public:
- TWebSocketServer(std::shared_ptr<TTransport> transport)
- : THttpServer(transport) {
+ TWebSocketServer(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config = nullptr)
+ : THttpServer(transport, config) {
resetHandshake();
}
@@ -98,6 +98,7 @@
}
void flush() override {
+ resetConsumedMessageSize();
writeFrameHeader();
uint8_t* buffer;
uint32_t length;
diff --git a/lib/cpp/src/thrift/transport/TZlibTransport.cpp b/lib/cpp/src/thrift/transport/TZlibTransport.cpp
index b4c43d6..657ce52 100644
--- a/lib/cpp/src/thrift/transport/TZlibTransport.cpp
+++ b/lib/cpp/src/thrift/transport/TZlibTransport.cpp
@@ -136,6 +136,7 @@
}
uint32_t TZlibTransport::read(uint8_t* buf, uint32_t len) {
+ checkReadBytesAvailable(len);
uint32_t need = len;
// TODO(dreiss): Skip urbuf on big reads.
@@ -265,6 +266,7 @@
}
flushToTransport(Z_FULL_FLUSH);
+ resetConsumedMessageSize();
}
void TZlibTransport::finish() {
@@ -335,6 +337,7 @@
}
void TZlibTransport::consume(uint32_t len) {
+ countConsumedMessageBytes(len);
if (readAvail() >= (int)len) {
urpos_ += len;
} else {
diff --git a/lib/cpp/src/thrift/transport/TZlibTransport.h b/lib/cpp/src/thrift/transport/TZlibTransport.h
index 4990aff..85765e6 100644
--- a/lib/cpp/src/thrift/transport/TZlibTransport.h
+++ b/lib/cpp/src/thrift/transport/TZlibTransport.h
@@ -83,8 +83,10 @@
int crbuf_size = DEFAULT_CRBUF_SIZE,
int uwbuf_size = DEFAULT_UWBUF_SIZE,
int cwbuf_size = DEFAULT_CWBUF_SIZE,
- int16_t comp_level = Z_DEFAULT_COMPRESSION)
- : transport_(transport),
+ int16_t comp_level = Z_DEFAULT_COMPRESSION,
+ std::shared_ptr<TConfiguration> config = nullptr)
+ : TVirtualTransport(config),
+ transport_(transport),
urpos_(0),
uwpos_(0),
input_ended_(false),
diff --git a/lib/cpp/test/CMakeLists.txt b/lib/cpp/test/CMakeLists.txt
index 48e2fd3..ced78a2 100644
--- a/lib/cpp/test/CMakeLists.txt
+++ b/lib/cpp/test/CMakeLists.txt
@@ -81,6 +81,7 @@
TypedefTest.cpp
TServerSocketTest.cpp
TServerTransportTest.cpp
+ ThrifttReadCheckTests.cpp
)
add_executable(UnitTests ${UnitTest_SOURCES})
diff --git a/lib/cpp/test/Makefile.am b/lib/cpp/test/Makefile.am
index 8982683..7f630db 100755
--- a/lib/cpp/test/Makefile.am
+++ b/lib/cpp/test/Makefile.am
@@ -130,7 +130,8 @@
TypedefTest.cpp \
TServerSocketTest.cpp \
TServerTransportTest.cpp \
- TTransportCheckThrow.h
+ TTransportCheckThrow.h \
+ ThrifttReadCheckTests.cpp
UnitTests_LDADD = \
libtestgencpp.la \
diff --git a/lib/cpp/test/ThrifttReadCheckTests.cpp b/lib/cpp/test/ThrifttReadCheckTests.cpp
new file mode 100644
index 0000000..4a594e6
--- /dev/null
+++ b/lib/cpp/test/ThrifttReadCheckTests.cpp
@@ -0,0 +1,227 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#define MAX_MESSAGE_SIZE 2
+
+#include <boost/test/auto_unit_test.hpp>
+#include <boost/test/unit_test.hpp>
+#include <iostream>
+#include <climits>
+#include <vector>
+#include <thrift/TConfiguration.h>
+#include <thrift/protocol/TBinaryProtocol.h>
+#include <thrift/protocol/TCompactProtocol.h>
+#include <thrift/protocol/TJSONProtocol.h>
+#include <thrift/Thrift.h>
+#include <memory>
+#include <thrift/transport/TTransportUtils.h>
+#include <thrift/transport/TBufferTransports.h>
+#include <thrift/transport/TSimpleFileTransport.h>
+#include <thrift/transport/TFileTransport.h>
+#include <thrift/protocol/TEnum.h>
+#include <thrift/protocol/TList.h>
+#include <thrift/protocol/TSet.h>
+#include <thrift/protocol/TMap.h>
+
+BOOST_AUTO_TEST_SUITE(ThriftReadCheckExceptionTest)
+
+using apache::thrift::TConfiguration;
+using apache::thrift::protocol::TBinaryProtocol;
+using apache::thrift::protocol::TCompactProtocol;
+using apache::thrift::protocol::TJSONProtocol;
+using apache::thrift::protocol::TType;
+using apache::thrift::transport::TPipedTransport;
+using apache::thrift::transport::TMemoryBuffer;
+using apache::thrift::transport::TSimpleFileTransport;
+using apache::thrift::transport::TFileTransport;
+using apache::thrift::transport::TFDTransport;
+using apache::thrift::transport::TTransportException;
+using apache::thrift::transport::TBufferedTransport;
+using apache::thrift::transport::TFramedTransport;
+using std::shared_ptr;
+using std::cout;
+using std::endl;
+using std::string;
+using std::memset;
+using namespace apache::thrift;
+using namespace apache::thrift::protocol;
+
+
+BOOST_AUTO_TEST_CASE(test_tmemorybuffer_read_check_exception) {
+ std::shared_ptr<TConfiguration> config(new TConfiguration(MAX_MESSAGE_SIZE));
+ TMemoryBuffer trans_out(config);
+ uint8_t buffer[6] = {1, 2, 3, 4, 5, 6};
+ trans_out.write((const uint8_t*)buffer, sizeof(buffer));
+ trans_out.close();
+
+ TMemoryBuffer trans_in(config);
+ memset(buffer, 0, sizeof(buffer));
+ BOOST_CHECK_THROW(trans_in.read(buffer, sizeof(buffer)), TTransportException*);
+ trans_in.close();
+}
+
+BOOST_AUTO_TEST_CASE(test_tpipedtransport_read_check_exception) {
+ std::shared_ptr<TConfiguration> config(new TConfiguration(MAX_MESSAGE_SIZE));
+ std::shared_ptr<TMemoryBuffer> pipe(new TMemoryBuffer);
+ std::shared_ptr<TMemoryBuffer> underlying(new TMemoryBuffer);
+ std::shared_ptr<TPipedTransport> trans(new TPipedTransport(underlying, pipe, config));
+
+ uint8_t buffer[4];
+
+ underlying->write((uint8_t*)"abcd", 4);
+ BOOST_CHECK_THROW(trans->read(buffer, sizeof(buffer)), TTransportException*);
+ BOOST_CHECK_THROW(trans->readAll(buffer, sizeof(buffer)), TTransportException*);
+ trans->readEnd();
+ pipe->resetBuffer();
+ underlying->write((uint8_t*)"ef", 2);
+ BOOST_CHECK_THROW(trans->read(buffer, sizeof(buffer)), TTransportException*);
+ BOOST_CHECK_THROW(trans->readAll(buffer, sizeof(buffer)), TTransportException*);
+ trans->readEnd();
+}
+
+BOOST_AUTO_TEST_CASE(test_tsimplefiletransport_read_check_exception) {
+ std::shared_ptr<TConfiguration> config(new TConfiguration(MAX_MESSAGE_SIZE));
+ TSimpleFileTransport trans_out("data", false, true, config);
+ uint8_t buffer[6] = {1, 2, 3, 4, 5, 6};
+ trans_out.write((const uint8_t*)buffer, sizeof(buffer));
+ trans_out.close();
+
+ TSimpleFileTransport trans_in("data",true, false, config);
+ memset(buffer, 0, sizeof(buffer));
+ BOOST_CHECK_THROW(trans_in.read(buffer, sizeof(buffer)), TTransportException*);
+ trans_in.close();
+
+ remove("./data");
+}
+
+BOOST_AUTO_TEST_CASE(test_tfiletransport_read_check_exception) {
+ std::shared_ptr<TConfiguration> config(new TConfiguration(MAX_MESSAGE_SIZE));
+ TFileTransport trans_out("data", false, config);
+ uint8_t buffer[6] = {1, 2, 3, 4, 5, 6};
+ trans_out.write((const uint8_t*)buffer, sizeof(buffer));
+
+ TFileTransport trans_in("data", false, config);
+ memset(buffer, 0, sizeof(buffer));
+ BOOST_CHECK_THROW(trans_in.read(buffer, sizeof(buffer)), TTransportException*);
+
+ remove("./data");
+}
+
+BOOST_AUTO_TEST_CASE(test_tbufferedtransport_read_check_exception) {
+ uint8_t arr[4] = {1, 2, 3, 4};
+ std::shared_ptr<TMemoryBuffer> buffer (new TMemoryBuffer(arr, sizeof(arr)));
+ std::shared_ptr<TConfiguration> config (new TConfiguration(MAX_MESSAGE_SIZE));
+ std::shared_ptr<TBufferedTransport> trans (new TBufferedTransport(buffer, config));
+
+ trans->write((const uint8_t*)arr, sizeof(arr));
+ BOOST_CHECK_THROW(trans->read(arr, sizeof(arr)), TTransportException*);
+}
+
+BOOST_AUTO_TEST_CASE(test_tframedtransport_read_check_exception) {
+ uint8_t arr[4] = {1, 2, 3, 4};
+ std::shared_ptr<TMemoryBuffer> buffer (new TMemoryBuffer(arr, sizeof(arr)));
+ std::shared_ptr<TConfiguration> config (new TConfiguration(MAX_MESSAGE_SIZE));
+ std::shared_ptr<TFramedTransport> trans (new TFramedTransport(buffer, config));
+
+ trans->write((const uint8_t*)arr, sizeof(arr));
+ BOOST_CHECK_THROW(trans->read(arr, sizeof(arr)), TTransportException*);
+}
+
+BOOST_AUTO_TEST_CASE(test_tthriftbinaryprotocol_read_check_exception) {
+ std::shared_ptr<TConfiguration> config (new TConfiguration(MAX_MESSAGE_SIZE));
+ std::shared_ptr<TMemoryBuffer> transport(new TMemoryBuffer(config));
+ std::shared_ptr<TBinaryProtocol> protocol(new TBinaryProtocol(transport));
+
+ uint32_t val = 0;
+ TType elemType = apache::thrift::protocol::T_STOP;
+ TType elemType1 = apache::thrift::protocol::T_STOP;
+ TList list(T_I32, 8);
+ protocol->writeListBegin(list.elemType_, list.size_);
+ protocol->writeListEnd();
+ BOOST_CHECK_THROW(protocol->readListBegin(elemType, val), TTransportException*);
+ protocol->readListEnd();
+
+ TSet set(T_I32, 8);
+ protocol->writeSetBegin(set.elemType_, set.size_);
+ protocol->writeSetEnd();
+ BOOST_CHECK_THROW(protocol->readSetBegin(elemType, val), TTransportException*);
+ protocol->readSetEnd();
+
+ TMap map(T_I32, T_I32, 8);
+ protocol->writeMapBegin(map.keyType_, map.valueType_, map.size_);
+ protocol->writeMapEnd();
+ BOOST_CHECK_THROW(protocol->readMapBegin(elemType, elemType1, val), TTransportException*);
+ protocol->readMapEnd();
+}
+
+BOOST_AUTO_TEST_CASE(test_tthriftcompactprotocol_read_check_exception) {
+ std::shared_ptr<TConfiguration> config (new TConfiguration(MAX_MESSAGE_SIZE));
+ std::shared_ptr<TMemoryBuffer> transport(new TMemoryBuffer(config));
+ std::shared_ptr<TCompactProtocol> protocol(new TCompactProtocol(transport));
+
+ uint32_t val = 0;
+ TType elemType = apache::thrift::protocol::T_STOP;
+ TType elemType1 = apache::thrift::protocol::T_STOP;
+ TList list(T_I32, 8);
+ protocol->writeListBegin(list.elemType_, list.size_);
+ protocol->writeListEnd();
+ BOOST_CHECK_THROW(protocol->readListBegin(elemType, val), TTransportException*);
+ protocol->readListEnd();
+
+ TSet set(T_I32, 8);
+ protocol->writeSetBegin(set.elemType_, set.size_);
+ protocol->writeSetEnd();
+ BOOST_CHECK_THROW(protocol->readSetBegin(elemType, val), TTransportException*);
+ protocol->readSetEnd();
+
+ TMap map(T_I32, T_I32, 8);
+ protocol->writeMapBegin(map.keyType_, map.valueType_, map.size_);
+ protocol->writeMapEnd();
+ BOOST_CHECK_THROW(protocol->readMapBegin(elemType, elemType1, val), TTransportException*);
+ protocol->readMapEnd();
+}
+
+BOOST_AUTO_TEST_CASE(test_tthriftjsonprotocol_read_check_exception) {
+ std::shared_ptr<TConfiguration> config (new TConfiguration(MAX_MESSAGE_SIZE));
+ std::shared_ptr<TMemoryBuffer> transport(new TMemoryBuffer(config));
+ std::shared_ptr<TJSONProtocol> protocol(new TJSONProtocol(transport));
+
+ uint32_t val = 0;
+ TType elemType = apache::thrift::protocol::T_STOP;
+ TType elemType1 = apache::thrift::protocol::T_STOP;
+ TList list(T_I32, 8);
+ protocol->writeListBegin(list.elemType_, list.size_);
+ protocol->writeListEnd();
+ BOOST_CHECK_THROW(protocol->readListBegin(elemType, val), TTransportException*);
+ protocol->readListEnd();
+
+ TSet set(T_I32, 8);
+ protocol->writeSetBegin(set.elemType_, set.size_);
+ protocol->writeSetEnd();
+ BOOST_CHECK_THROW(protocol->readSetBegin(elemType, val), TTransportException*);
+ protocol->readSetEnd();
+
+ TMap map(T_I32, T_I32, 8);
+ protocol->writeMapBegin(map.keyType_, map.valueType_, map.size_);
+ protocol->writeMapEnd();
+ BOOST_CHECK_THROW(protocol->readMapBegin(elemType, elemType1, val), TTransportException*);
+ protocol->readMapEnd();
+}
+
+BOOST_AUTO_TEST_SUITE_END()