[THRIFT-5871] Add message / container size checking for Rust
Bring the Rust implementation closer to parity with the other Thrift language implementations.
I tried four or five different approaches to get the "perfect" check, but because trait specialization is not yet stable,
I was not able to arrive at a solution I'm happy with (the code was either ugly or had runtime overhead).
So for now, we avoid full message-size tracking and more precise limit checking, but this is still a strong step
in the right direction.
diff --git a/lib/rs/src/configuration.rs b/lib/rs/src/configuration.rs
new file mode 100644
index 0000000..0f786f4
--- /dev/null
+++ b/lib/rs/src/configuration.rs
@@ -0,0 +1,178 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// Configuration for Thrift protocols.
+#[derive(Debug, Clone)]
+pub struct TConfiguration {
+ max_message_size: Option<usize>,
+ max_frame_size: Option<usize>,
+ max_recursion_depth: Option<usize>,
+ max_container_size: Option<usize>,
+ max_string_size: Option<usize>,
+}
+
+impl TConfiguration {
+ // this value is used consistently across all Thrift libraries
+ pub const DEFAULT_MAX_MESSAGE_SIZE: usize = 100 * 1024 * 1024;
+
+ // this value is used consistently across all Thrift libraries
+ pub const DEFAULT_MAX_FRAME_SIZE: usize = 16_384_000;
+
+ pub const DEFAULT_RECURSION_LIMIT: usize = 64;
+
+ pub const DEFAULT_CONTAINER_LIMIT: Option<usize> = None;
+
+ pub const DEFAULT_STRING_LIMIT: usize = 100 * 1024 * 1024;
+
+ pub fn no_limits() -> Self {
+ Self {
+ max_message_size: None,
+ max_frame_size: None,
+ max_recursion_depth: None,
+ max_container_size: None,
+ max_string_size: None,
+ }
+ }
+
+ pub fn max_message_size(&self) -> Option<usize> {
+ self.max_message_size
+ }
+
+ pub fn max_frame_size(&self) -> Option<usize> {
+ self.max_frame_size
+ }
+
+ pub fn max_recursion_depth(&self) -> Option<usize> {
+ self.max_recursion_depth
+ }
+
+ pub fn max_container_size(&self) -> Option<usize> {
+ self.max_container_size
+ }
+
+ pub fn max_string_size(&self) -> Option<usize> {
+ self.max_string_size
+ }
+
+ pub fn builder() -> TConfigurationBuilder {
+ TConfigurationBuilder::default()
+ }
+}
+
+impl Default for TConfiguration {
+ fn default() -> Self {
+ Self {
+ max_message_size: Some(Self::DEFAULT_MAX_MESSAGE_SIZE),
+ max_frame_size: Some(Self::DEFAULT_MAX_FRAME_SIZE),
+ max_recursion_depth: Some(Self::DEFAULT_RECURSION_LIMIT),
+ max_container_size: Self::DEFAULT_CONTAINER_LIMIT,
+ max_string_size: Some(Self::DEFAULT_STRING_LIMIT),
+ }
+ }
+}
+
+#[derive(Debug, Default)]
+pub struct TConfigurationBuilder {
+ config: TConfiguration,
+}
+
+impl TConfigurationBuilder {
+ pub fn max_message_size(mut self, limit: Option<usize>) -> Self {
+ self.config.max_message_size = limit;
+ self
+ }
+
+ pub fn max_frame_size(mut self, limit: Option<usize>) -> Self {
+ self.config.max_frame_size = limit;
+ self
+ }
+
+ pub fn max_recursion_depth(mut self, limit: Option<usize>) -> Self {
+ self.config.max_recursion_depth = limit;
+ self
+ }
+
+ pub fn max_container_size(mut self, limit: Option<usize>) -> Self {
+ self.config.max_container_size = limit;
+ self
+ }
+
+ pub fn max_string_size(mut self, limit: Option<usize>) -> Self {
+ self.config.max_string_size = limit;
+ self
+ }
+
+ pub fn build(self) -> crate::Result<TConfiguration> {
+ if let (Some(frame_size), Some(message_size)) =
+ (self.config.max_frame_size, self.config.max_message_size)
+ {
+ if frame_size > message_size {
+ // FIXME: This should probably be a different error type.
+ return Err(crate::Error::Application(crate::ApplicationError::new(
+ crate::ApplicationErrorKind::Unknown,
+ format!(
+ "Invalid configuration: max_frame_size ({}) cannot exceed max_message_size ({})",
+ frame_size, message_size
+ ),
+ )));
+ }
+ }
+
+ Ok(self.config)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_custom_configuration_builder() {
+ let config = TConfiguration::builder()
+ .max_message_size(Some(1024))
+ .max_frame_size(Some(512))
+ .max_recursion_depth(Some(10))
+ .max_container_size(Some(100))
+ .max_string_size(Some(256))
+ .build()
+ .unwrap();
+
+ assert_eq!(config.max_message_size(), Some(1024));
+ assert_eq!(config.max_frame_size(), Some(512));
+ assert_eq!(config.max_recursion_depth(), Some(10));
+ assert_eq!(config.max_container_size(), Some(100));
+ assert_eq!(config.max_string_size(), Some(256));
+ }
+
+ #[test]
+ fn test_invalid_configuration() {
+ // Test that builder catches invalid configurations
+ let result = TConfiguration::builder()
+ .max_frame_size(Some(1000))
+ .max_message_size(Some(500)) // frame size > message size is invalid
+ .build();
+
+ assert!(result.is_err());
+ match result {
+ Err(crate::Error::Application(e)) => {
+ assert!(e.message.contains("max_frame_size"));
+ assert!(e.message.contains("cannot exceed max_message_size"));
+ }
+ _ => panic!("Expected Application error"),
+ }
+ }
+}
diff --git a/lib/rs/src/lib.rs b/lib/rs/src/lib.rs
index 2f60188..d3804ec 100644
--- a/lib/rs/src/lib.rs
+++ b/lib/rs/src/lib.rs
@@ -21,10 +21,11 @@
//! Thrift server and client. It is divided into the following modules:
//!
//! 1. errors
-//! 2. protocol
-//! 3. transport
-//! 4. server
-//! 5. autogen
+//! 2. configuration
+//! 3. protocol
+//! 4. transport
+//! 5. server
+//! 6. autogen
//!
//! The modules are layered as shown in the diagram below. The `autogen'd`
//! layer is generated by the Thrift compiler's Rust plugin. It uses the
@@ -82,6 +83,9 @@
mod autogen;
pub use crate::autogen::*;
+mod configuration;
+pub use crate::configuration::*;
+
/// Result type returned by all runtime library functions.
///
/// As is convention this is a typedef of `std::result::Result`
diff --git a/lib/rs/src/protocol/binary.rs b/lib/rs/src/protocol/binary.rs
index b4b51f6..596285f 100644
--- a/lib/rs/src/protocol/binary.rs
+++ b/lib/rs/src/protocol/binary.rs
@@ -24,7 +24,7 @@
};
use super::{TOutputProtocol, TOutputProtocolFactory, TSetIdentifier, TStructIdentifier, TType};
use crate::transport::{TReadTransport, TWriteTransport};
-use crate::{ProtocolError, ProtocolErrorKind};
+use crate::{ProtocolError, ProtocolErrorKind, TConfiguration};
const BINARY_PROTOCOL_VERSION_1: u32 = 0x8001_0000;
@@ -57,6 +57,8 @@
{
strict: bool,
pub transport: T, // FIXME: shouldn't be public
+ config: TConfiguration,
+ recursion_depth: usize,
}
impl<T> TBinaryInputProtocol<T>
@@ -67,8 +69,29 @@
///
/// Set `strict` to `true` if all incoming messages contain the protocol
/// version number in the protocol header.
- pub fn new(transport: T, strict: bool) -> TBinaryInputProtocol<T> {
- TBinaryInputProtocol { strict, transport }
+ pub fn new(transport: T, strict: bool) -> Self {
+ Self::with_config(transport, strict, TConfiguration::default())
+ }
+
+ pub fn with_config(transport: T, strict: bool, config: TConfiguration) -> Self {
+ TBinaryInputProtocol {
+ strict,
+ transport,
+ config,
+ recursion_depth: 0,
+ }
+ }
+
+ fn check_recursion_depth(&self) -> crate::Result<()> {
+ if let Some(limit) = self.config.max_recursion_depth() {
+ if self.recursion_depth >= limit {
+ return Err(crate::Error::Protocol(ProtocolError::new(
+ ProtocolErrorKind::DepthLimit,
+ format!("Maximum recursion depth {} exceeded", limit),
+ )));
+ }
+ }
+ Ok(())
}
}
@@ -78,6 +101,7 @@
{
#[allow(clippy::collapsible_if)]
fn read_message_begin(&mut self) -> crate::Result<TMessageIdentifier> {
+ // TODO: Once specialization is stable, call the message size tracking here
let mut first_bytes = vec![0; 4];
self.transport.read_exact(&mut first_bytes[..])?;
@@ -130,10 +154,13 @@
}
fn read_struct_begin(&mut self) -> crate::Result<Option<TStructIdentifier>> {
+ self.check_recursion_depth()?;
+ self.recursion_depth += 1;
Ok(None)
}
fn read_struct_end(&mut self) -> crate::Result<()> {
+ self.recursion_depth -= 1;
Ok(())
}
@@ -154,8 +181,28 @@
}
fn read_bytes(&mut self) -> crate::Result<Vec<u8>> {
- let num_bytes = self.transport.read_i32::<BigEndian>()? as usize;
- let mut buf = vec![0u8; num_bytes];
+ let num_bytes = self.transport.read_i32::<BigEndian>()?;
+
+ if num_bytes < 0 {
+ return Err(crate::Error::Protocol(ProtocolError::new(
+ ProtocolErrorKind::NegativeSize,
+ format!("Negative byte array size: {}", num_bytes),
+ )));
+ }
+
+ if let Some(max_size) = self.config.max_string_size() {
+ if num_bytes as usize > max_size {
+ return Err(crate::Error::Protocol(ProtocolError::new(
+ ProtocolErrorKind::SizeLimit,
+ format!(
+ "Byte array size {} exceeds maximum allowed size of {}",
+ num_bytes, max_size
+ ),
+ )));
+ }
+ }
+
+ let mut buf = vec![0u8; num_bytes as usize];
self.transport
.read_exact(&mut buf)
.map(|_| buf)
@@ -206,6 +253,8 @@
fn read_list_begin(&mut self) -> crate::Result<TListIdentifier> {
let element_type: TType = self.read_byte().and_then(field_type_from_u8)?;
let size = self.read_i32()?;
+ let min_element_size = self.min_serialized_size(element_type);
+ super::check_container_size(&self.config, size, min_element_size)?;
Ok(TListIdentifier::new(element_type, size))
}
@@ -216,6 +265,8 @@
fn read_set_begin(&mut self) -> crate::Result<TSetIdentifier> {
let element_type: TType = self.read_byte().and_then(field_type_from_u8)?;
let size = self.read_i32()?;
+ let min_element_size = self.min_serialized_size(element_type);
+ super::check_container_size(&self.config, size, min_element_size)?;
Ok(TSetIdentifier::new(element_type, size))
}
@@ -227,6 +278,12 @@
let key_type: TType = self.read_byte().and_then(field_type_from_u8)?;
let value_type: TType = self.read_byte().and_then(field_type_from_u8)?;
let size = self.read_i32()?;
+
+ let key_min_size = self.min_serialized_size(key_type);
+ let value_min_size = self.min_serialized_size(value_type);
+ let element_size = key_min_size + value_min_size;
+ super::check_container_size(&self.config, size, element_size)?;
+
Ok(TMapIdentifier::new(key_type, value_type, size))
}
@@ -240,6 +297,26 @@
fn read_byte(&mut self) -> crate::Result<u8> {
self.transport.read_u8().map_err(From::from)
}
+
+ fn min_serialized_size(&self, field_type: TType) -> usize {
+ match field_type {
+ TType::Stop => 1, // 1 byte minimum
+ TType::Void => 1, // 1 byte minimum
+ TType::Bool => 1, // 1 byte
+ TType::I08 => 1, // 1 byte
+ TType::Double => 8, // 8 bytes
+ TType::I16 => 2, // 2 bytes
+ TType::I32 => 4, // 4 bytes
+ TType::I64 => 8, // 8 bytes
+ TType::String => 4, // 4 bytes for length prefix
+ TType::Struct => 1, // 1 byte minimum (stop field)
+ TType::Map => 4, // 4 bytes size
+ TType::Set => 4, // 4 bytes size
+ TType::List => 4, // 4 bytes size
+ TType::Uuid => 16, // 16 bytes
+ TType::Utf7 => 1, // 1 byte
+ }
+ }
}
/// Factory for creating instances of `TBinaryInputProtocol`.
@@ -514,14 +591,13 @@
#[cfg(test)]
mod tests {
+ use super::*;
use crate::protocol::{
TFieldIdentifier, TInputProtocol, TListIdentifier, TMapIdentifier, TMessageIdentifier,
TMessageType, TOutputProtocol, TSetIdentifier, TStructIdentifier, TType,
};
use crate::transport::{ReadHalf, TBufferChannel, TIoChannel, WriteHalf};
- use super::*;
-
#[test]
fn must_write_strict_message_call_begin() {
let (_, mut o_prot) = test_objects(true);
@@ -759,13 +835,26 @@
fn must_round_trip_list_begin() {
let (mut i_prot, mut o_prot) = test_objects(true);
- let ident = TListIdentifier::new(TType::List, 900);
+ let ident = TListIdentifier::new(TType::I32, 4);
assert!(o_prot.write_list_begin(&ident).is_ok());
+ assert!(o_prot.write_i32(10).is_ok());
+ assert!(o_prot.write_i32(20).is_ok());
+ assert!(o_prot.write_i32(30).is_ok());
+ assert!(o_prot.write_i32(40).is_ok());
+
+ assert!(o_prot.write_list_end().is_ok());
copy_write_buffer_to_read_buffer!(o_prot);
let received_ident = assert_success!(i_prot.read_list_begin());
assert_eq!(&received_ident, &ident);
+
+ assert_eq!(i_prot.read_i32().unwrap(), 10);
+ assert_eq!(i_prot.read_i32().unwrap(), 20);
+ assert_eq!(i_prot.read_i32().unwrap(), 30);
+ assert_eq!(i_prot.read_i32().unwrap(), 40);
+
+ assert!(i_prot.read_list_end().is_ok());
}
#[test]
@@ -789,14 +878,25 @@
fn must_round_trip_set_begin() {
let (mut i_prot, mut o_prot) = test_objects(true);
- let ident = TSetIdentifier::new(TType::I64, 2000);
+ let ident = TSetIdentifier::new(TType::I64, 3);
assert!(o_prot.write_set_begin(&ident).is_ok());
+ assert!(o_prot.write_i64(123).is_ok());
+ assert!(o_prot.write_i64(456).is_ok());
+ assert!(o_prot.write_i64(789).is_ok());
+
+ assert!(o_prot.write_set_end().is_ok());
copy_write_buffer_to_read_buffer!(o_prot);
let received_ident_result = i_prot.read_set_begin();
assert!(received_ident_result.is_ok());
assert_eq!(&received_ident_result.unwrap(), &ident);
+
+ assert_eq!(i_prot.read_i64().unwrap(), 123);
+ assert_eq!(i_prot.read_i64().unwrap(), 456);
+ assert_eq!(i_prot.read_i64().unwrap(), 789);
+
+ assert!(i_prot.read_set_end().is_ok());
}
#[test]
@@ -820,13 +920,26 @@
fn must_round_trip_map_begin() {
let (mut i_prot, mut o_prot) = test_objects(true);
- let ident = TMapIdentifier::new(TType::Map, TType::Set, 100);
+ let ident = TMapIdentifier::new(TType::String, TType::I32, 2);
assert!(o_prot.write_map_begin(&ident).is_ok());
+ assert!(o_prot.write_string("key1").is_ok());
+ assert!(o_prot.write_i32(100).is_ok());
+ assert!(o_prot.write_string("key2").is_ok());
+ assert!(o_prot.write_i32(200).is_ok());
+
+ assert!(o_prot.write_map_end().is_ok());
copy_write_buffer_to_read_buffer!(o_prot);
let received_ident = assert_success!(i_prot.read_map_begin());
assert_eq!(&received_ident, &ident);
+
+ assert_eq!(i_prot.read_string().unwrap(), "key1");
+ assert_eq!(i_prot.read_i32().unwrap(), 100);
+ assert_eq!(i_prot.read_string().unwrap(), "key2");
+ assert_eq!(i_prot.read_i32().unwrap(), 200);
+
+ assert!(i_prot.read_map_end().is_ok());
}
#[test]
@@ -963,7 +1076,7 @@
TBinaryInputProtocol<ReadHalf<TBufferChannel>>,
TBinaryOutputProtocol<WriteHalf<TBufferChannel>>,
) {
- let mem = TBufferChannel::with_capacity(40, 40);
+ let mem = TBufferChannel::with_capacity(200, 200);
let (r_mem, w_mem) = mem.split().unwrap();
@@ -981,4 +1094,154 @@
assert!(write_fn(&mut o_prot).is_ok());
assert_eq!(o_prot.transport.write_bytes().len(), 0);
}
+
+ #[test]
+ fn must_enforce_recursion_depth_limit() {
+ let mem = TBufferChannel::with_capacity(40, 40);
+ let (r_mem, _) = mem.split().unwrap();
+
+ let config = TConfiguration::builder()
+ .max_recursion_depth(Some(2))
+ .build()
+ .unwrap();
+ let mut i_prot = TBinaryInputProtocol::with_config(r_mem, true, config);
+
+ assert!(i_prot.read_struct_begin().is_ok());
+ assert_eq!(i_prot.recursion_depth, 1);
+
+ assert!(i_prot.read_struct_begin().is_ok());
+ assert_eq!(i_prot.recursion_depth, 2);
+
+ let result = i_prot.read_struct_begin();
+ assert!(result.is_err());
+ match result {
+ Err(crate::Error::Protocol(e)) => {
+ assert_eq!(e.kind, ProtocolErrorKind::DepthLimit);
+ }
+ _ => panic!("Expected protocol error with DepthLimit"),
+ }
+
+ assert!(i_prot.read_struct_end().is_ok());
+ assert_eq!(i_prot.recursion_depth, 1);
+ assert!(i_prot.read_struct_end().is_ok());
+ assert_eq!(i_prot.recursion_depth, 0);
+ }
+
+ #[test]
+ fn must_reject_negative_container_sizes() {
+ let mem = TBufferChannel::with_capacity(40, 40);
+ let (r_mem, mut w_mem) = mem.split().unwrap();
+
+ let mut i_prot = TBinaryInputProtocol::new(r_mem, true);
+
+ w_mem.set_readable_bytes(&[0x0F, 0xFF, 0xFF, 0xFF, 0xFF]);
+
+ let result = i_prot.read_list_begin();
+ assert!(result.is_err());
+ match result {
+ Err(crate::Error::Protocol(e)) => {
+ assert_eq!(e.kind, ProtocolErrorKind::NegativeSize);
+ }
+ _ => panic!("Expected protocol error with NegativeSize"),
+ }
+ }
+
+ #[test]
+ fn must_enforce_container_size_limit() {
+ let mem = TBufferChannel::with_capacity(40, 40);
+ let (r_mem, mut w_mem) = mem.split().unwrap();
+
+ let config = TConfiguration::builder()
+ .max_container_size(Some(100))
+ .build()
+ .unwrap();
+
+ let mut i_prot = TBinaryInputProtocol::with_config(r_mem, true, config);
+
+ w_mem.set_readable_bytes(&[0x0F, 0x00, 0x00, 0x00, 0xC8]);
+
+ let result = i_prot.read_list_begin();
+ assert!(result.is_err());
+ match result {
+ Err(crate::Error::Protocol(e)) => {
+ assert_eq!(e.kind, ProtocolErrorKind::SizeLimit);
+ assert!(e
+ .message
+ .contains("Container size 200 exceeds maximum allowed size of 100"));
+ }
+ _ => panic!("Expected protocol error with SizeLimit"),
+ }
+ }
+
+ #[test]
+ fn must_allow_containers_within_limit() {
+ let mem = TBufferChannel::with_capacity(200, 200);
+ let (r_mem, mut w_mem) = mem.split().unwrap();
+
+ // Create protocol with container limit of 100
+ let config = TConfiguration::builder()
+ .max_container_size(Some(100))
+ .build()
+ .unwrap();
+ let mut i_prot = TBinaryInputProtocol::with_config(r_mem, true, config);
+
+ let mut data = vec![0x08]; // TType::I32
+ data.extend_from_slice(&5i32.to_be_bytes()); // size = 5
+
+ for i in 1i32..=5i32 {
+ data.extend_from_slice(&(i * 10).to_be_bytes());
+ }
+
+ w_mem.set_readable_bytes(&data);
+
+ let result = i_prot.read_list_begin();
+ assert!(result.is_ok());
+ let list_ident = result.unwrap();
+ assert_eq!(list_ident.size, 5);
+ assert_eq!(list_ident.element_type, TType::I32);
+ }
+
+ #[test]
+ fn must_enforce_string_size_limit() {
+ let mem = TBufferChannel::with_capacity(100, 100);
+ let (r_mem, mut w_mem) = mem.split().unwrap();
+
+ let config = TConfiguration::builder()
+ .max_string_size(Some(1000))
+ .build()
+ .unwrap();
+ let mut i_prot = TBinaryInputProtocol::with_config(r_mem, true, config);
+
+ w_mem.set_readable_bytes(&[0x00, 0x00, 0x07, 0xD0]);
+
+ let result = i_prot.read_string();
+ assert!(result.is_err());
+ match result {
+ Err(crate::Error::Protocol(e)) => {
+ assert_eq!(e.kind, ProtocolErrorKind::SizeLimit);
+ assert!(e
+ .message
+ .contains("Byte array size 2000 exceeds maximum allowed size of 1000"));
+ }
+ _ => panic!("Expected protocol error with SizeLimit"),
+ }
+ }
+
+ #[test]
+ fn must_allow_strings_within_limit() {
+ let mem = TBufferChannel::with_capacity(100, 100);
+ let (r_mem, mut w_mem) = mem.split().unwrap();
+
+ let config = TConfiguration::builder()
+ .max_string_size(Some(1000))
+ .build()
+ .unwrap();
+ let mut i_prot = TBinaryInputProtocol::with_config(r_mem, true, config);
+
+ w_mem.set_readable_bytes(&[0x00, 0x00, 0x00, 0x05, b'h', b'e', b'l', b'l', b'o']);
+
+ let result = i_prot.read_string();
+ assert!(result.is_ok());
+ assert_eq!(result.unwrap(), "hello");
+ }
}
diff --git a/lib/rs/src/protocol/compact.rs b/lib/rs/src/protocol/compact.rs
index 4dc45ca..7e1a751 100644
--- a/lib/rs/src/protocol/compact.rs
+++ b/lib/rs/src/protocol/compact.rs
@@ -26,6 +26,7 @@
};
use super::{TOutputProtocol, TOutputProtocolFactory, TSetIdentifier, TStructIdentifier, TType};
use crate::transport::{TReadTransport, TWriteTransport};
+use crate::{ProtocolError, ProtocolErrorKind, TConfiguration};
const COMPACT_PROTOCOL_ID: u8 = 0x82;
const COMPACT_VERSION: u8 = 0x01;
@@ -64,6 +65,10 @@
pending_read_bool_value: Option<bool>,
// Underlying transport used for byte-level operations.
transport: T,
+ // Configuration
+ config: TConfiguration,
+ // Current recursion depth
+ recursion_depth: usize,
}
impl<T> TCompactInputProtocol<T>
@@ -72,11 +77,18 @@
{
/// Create a `TCompactInputProtocol` that reads bytes from `transport`.
pub fn new(transport: T) -> TCompactInputProtocol<T> {
+ Self::with_config(transport, TConfiguration::default())
+ }
+
+ /// Create a `TCompactInputProtocol` with custom configuration.
+ pub fn with_config(transport: T, config: TConfiguration) -> TCompactInputProtocol<T> {
TCompactInputProtocol {
last_read_field_id: 0,
read_field_id_stack: Vec::new(),
pending_read_bool_value: None,
transport,
+ config,
+ recursion_depth: 0,
}
}
@@ -92,8 +104,23 @@
self.transport.read_varint::<u32>()? as i32
};
+ let min_element_size = self.min_serialized_size(element_type);
+ super::check_container_size(&self.config, element_count, min_element_size)?;
+
Ok((element_type, element_count))
}
+
+ fn check_recursion_depth(&self) -> crate::Result<()> {
+ if let Some(limit) = self.config.max_recursion_depth() {
+ if self.recursion_depth >= limit {
+ return Err(crate::Error::Protocol(ProtocolError::new(
+ ProtocolErrorKind::DepthLimit,
+ format!("Maximum recursion depth {} exceeded", limit),
+ )));
+ }
+ }
+ Ok(())
+ }
}
impl<T> TInputProtocol for TCompactInputProtocol<T>
@@ -101,6 +128,7 @@
T: TReadTransport,
{
fn read_message_begin(&mut self) -> crate::Result<TMessageIdentifier> {
+ // TODO: Once specialization is stable, call the message size tracking here
let compact_id = self.read_byte()?;
if compact_id != COMPACT_PROTOCOL_ID {
Err(crate::Error::Protocol(crate::ProtocolError {
@@ -145,12 +173,15 @@
}
fn read_struct_begin(&mut self) -> crate::Result<Option<TStructIdentifier>> {
+ self.check_recursion_depth()?;
+ self.recursion_depth += 1;
self.read_field_id_stack.push(self.last_read_field_id);
self.last_read_field_id = 0;
Ok(None)
}
fn read_struct_end(&mut self) -> crate::Result<()> {
+ self.recursion_depth -= 1;
self.last_read_field_id = self
.read_field_id_stack
.pop()
@@ -227,6 +258,19 @@
fn read_bytes(&mut self) -> crate::Result<Vec<u8>> {
let len = self.transport.read_varint::<u32>()?;
+
+ if let Some(max_size) = self.config.max_string_size() {
+ if len as usize > max_size {
+ return Err(crate::Error::Protocol(ProtocolError::new(
+ ProtocolErrorKind::SizeLimit,
+ format!(
+ "Byte array size {} exceeds maximum allowed size of {}",
+ len, max_size
+ ),
+ )));
+ }
+ }
+
let mut buf = vec![0u8; len as usize];
self.transport
.read_exact(&mut buf)
@@ -291,6 +335,12 @@
let type_header = self.read_byte()?;
let key_type = collection_u8_to_type((type_header & 0xF0) >> 4)?;
let val_type = collection_u8_to_type(type_header & 0x0F)?;
+
+ let key_min_size = self.min_serialized_size(key_type);
+ let value_min_size = self.min_serialized_size(val_type);
+ let element_size = key_min_size + value_min_size;
+ super::check_container_size(&self.config, element_count, element_size)?;
+
Ok(TMapIdentifier::new(key_type, val_type, element_count))
}
}
@@ -309,6 +359,30 @@
.map_err(From::from)
.map(|_| buf[0])
}
+
+ fn min_serialized_size(&self, field_type: TType) -> usize {
+ compact_protocol_min_serialized_size(field_type)
+ }
+}
+
+pub(crate) fn compact_protocol_min_serialized_size(field_type: TType) -> usize {
+ match field_type {
+ TType::Stop => 1, // 1 byte
+ TType::Void => 1, // 1 byte
+ TType::Bool => 1, // 1 byte
+ TType::I08 => 1, // 1 byte
+ TType::Double => 8, // 8 bytes (not varint encoded)
+ TType::I16 => 1, // 1 byte minimum (varint)
+ TType::I32 => 1, // 1 byte minimum (varint)
+ TType::I64 => 1, // 1 byte minimum (varint)
+ TType::String => 1, // 1 byte minimum for length (varint)
+ TType::Struct => 1, // 1 byte minimum (stop field)
+ TType::Map => 1, // 1 byte minimum
+ TType::Set => 1, // 1 byte minimum
+ TType::List => 1, // 1 byte minimum
+ TType::Uuid => 16, // 16 bytes
+ TType::Utf7 => 1, // 1 byte
+ }
}
impl<T> io::Seek for TCompactInputProtocol<T>
@@ -2573,14 +2647,25 @@
fn must_round_trip_small_sized_list_begin() {
let (mut i_prot, mut o_prot) = test_objects();
- let ident = TListIdentifier::new(TType::I08, 10);
-
+ let ident = TListIdentifier::new(TType::I32, 3);
assert_success!(o_prot.write_list_begin(&ident));
+ assert_success!(o_prot.write_i32(100));
+ assert_success!(o_prot.write_i32(200));
+ assert_success!(o_prot.write_i32(300));
+
+ assert_success!(o_prot.write_list_end());
+
copy_write_buffer_to_read_buffer!(o_prot);
let res = assert_success!(i_prot.read_list_begin());
assert_eq!(&res, &ident);
+
+ assert_eq!(i_prot.read_i32().unwrap(), 100);
+ assert_eq!(i_prot.read_i32().unwrap(), 200);
+ assert_eq!(i_prot.read_i32().unwrap(), 300);
+
+ assert_success!(i_prot.read_list_end());
}
#[test]
@@ -2600,10 +2685,9 @@
#[test]
fn must_round_trip_large_sized_list_begin() {
- let (mut i_prot, mut o_prot) = test_objects();
+ let (mut i_prot, mut o_prot) = test_objects_no_limits();
let ident = TListIdentifier::new(TType::Set, 47381);
-
assert_success!(o_prot.write_list_begin(&ident));
copy_write_buffer_to_read_buffer!(o_prot);
@@ -2632,14 +2716,25 @@
fn must_round_trip_small_sized_set_begin() {
let (mut i_prot, mut o_prot) = test_objects();
- let ident = TSetIdentifier::new(TType::I16, 7);
-
+ let ident = TSetIdentifier::new(TType::I16, 3);
assert_success!(o_prot.write_set_begin(&ident));
+ assert_success!(o_prot.write_i16(111));
+ assert_success!(o_prot.write_i16(222));
+ assert_success!(o_prot.write_i16(333));
+
+ assert_success!(o_prot.write_set_end());
+
copy_write_buffer_to_read_buffer!(o_prot);
let res = assert_success!(i_prot.read_set_begin());
assert_eq!(&res, &ident);
+
+ assert_eq!(i_prot.read_i16().unwrap(), 111);
+ assert_eq!(i_prot.read_i16().unwrap(), 222);
+ assert_eq!(i_prot.read_i16().unwrap(), 333);
+
+ assert_success!(i_prot.read_set_end());
}
#[test]
@@ -2658,10 +2753,9 @@
#[test]
fn must_round_trip_large_sized_set_begin() {
- let (mut i_prot, mut o_prot) = test_objects();
+ let (mut i_prot, mut o_prot) = test_objects_no_limits();
let ident = TSetIdentifier::new(TType::Map, 3_928_429);
-
assert_success!(o_prot.write_set_begin(&ident));
copy_write_buffer_to_read_buffer!(o_prot);
@@ -2725,10 +2819,9 @@
#[test]
fn must_round_trip_map_begin() {
- let (mut i_prot, mut o_prot) = test_objects();
+ let (mut i_prot, mut o_prot) = test_objects_no_limits();
let ident = TMapIdentifier::new(TType::Map, TType::List, 1_928_349);
-
assert_success!(o_prot.write_map_begin(&ident));
copy_write_buffer_to_read_buffer!(o_prot);
@@ -2804,7 +2897,7 @@
TCompactInputProtocol<ReadHalf<TBufferChannel>>,
TCompactOutputProtocol<WriteHalf<TBufferChannel>>,
) {
- let mem = TBufferChannel::with_capacity(80, 80);
+ let mem = TBufferChannel::with_capacity(200, 200);
let (r_mem, w_mem) = mem.split().unwrap();
@@ -2814,6 +2907,20 @@
(i_prot, o_prot)
}
+ fn test_objects_no_limits() -> (
+ TCompactInputProtocol<ReadHalf<TBufferChannel>>,
+ TCompactOutputProtocol<WriteHalf<TBufferChannel>>,
+ ) {
+ let mem = TBufferChannel::with_capacity(200, 200);
+
+ let (r_mem, w_mem) = mem.split().unwrap();
+
+ let i_prot = TCompactInputProtocol::with_config(r_mem, TConfiguration::no_limits());
+ let o_prot = TCompactOutputProtocol::new(w_mem);
+
+ (i_prot, o_prot)
+ }
+
#[test]
fn must_read_write_double() {
let (mut i_prot, mut o_prot) = test_objects();
@@ -2883,4 +2990,248 @@
assert_success!(i_prot.read_list_end());
}
+
+ #[test]
+ fn must_enforce_recursion_depth_limit() {
+ let channel = TBufferChannel::with_capacity(100, 100);
+
+ // Create a configuration with a small recursion limit
+ let config = TConfiguration::builder()
+ .max_recursion_depth(Some(2))
+ .build()
+ .unwrap();
+
+ let mut protocol = TCompactInputProtocol::with_config(channel, config);
+
+ // First struct - should succeed
+ assert!(protocol.read_struct_begin().is_ok());
+
+ // Second struct - should succeed (at limit)
+ assert!(protocol.read_struct_begin().is_ok());
+
+ // Third struct - should fail (exceeds limit)
+ let result = protocol.read_struct_begin();
+ assert!(result.is_err());
+ match result {
+ Err(crate::Error::Protocol(e)) => {
+ assert_eq!(e.kind, ProtocolErrorKind::DepthLimit);
+ }
+ _ => panic!("Expected protocol error with DepthLimit"),
+ }
+ }
+
+ #[test]
+ fn must_check_container_size_overflow() {
+ // Configure a small message size limit
+ let config = TConfiguration::builder()
+ .max_message_size(Some(1000))
+ .max_frame_size(Some(1000))
+ .build()
+ .unwrap();
+ let transport = TBufferChannel::with_capacity(100, 0);
+ let mut i_prot = TCompactInputProtocol::with_config(transport, config);
+
+ // Write a list header that would require more memory than message size limit
+ // List of 100 UUIDs (16 bytes each) = 1600 bytes > 1000 limit
+ i_prot.transport.set_readable_bytes(&[
+ 0xFD, // element type UUID (0x0D) | count in next bytes (0xF0)
+ 0x64, // varint 100
+ ]);
+
+ let result = i_prot.read_list_begin();
+ assert!(result.is_err());
+ match result {
+ Err(crate::Error::Protocol(e)) => {
+ assert_eq!(e.kind, ProtocolErrorKind::SizeLimit);
+ assert!(e
+ .message
+ .contains("1600 bytes, exceeding message size limit of 1000"));
+ }
+ _ => panic!("Expected protocol error with SizeLimit"),
+ }
+ }
+
+ #[test]
+ fn must_reject_negative_container_sizes() {
+ let mut channel = TBufferChannel::with_capacity(100, 100);
+
+ let mut protocol = TCompactInputProtocol::new(channel.clone());
+
+ // Write header with negative size when decoded
+ // In compact protocol, lists/sets use a header byte followed by size
+ // We'll use the header byte 0xF0 (size nibble 0x0F, so the real size follows as a varint) and then a varint-encoded negative number
+ channel.set_readable_bytes(&[
+ 0xF0, // Header: 15 in upper nibble (triggers varint read), element type nibble (0x0) in lower
+ 0xFF, 0xFF, 0xFF, 0xFF, 0x0F, // Varint encoding of -1
+ ]);
+
+ let result = protocol.read_list_begin();
+ assert!(result.is_err());
+ match result {
+ Err(crate::Error::Protocol(e)) => {
+ assert_eq!(e.kind, ProtocolErrorKind::NegativeSize);
+ }
+ _ => panic!("Expected protocol error with NegativeSize"),
+ }
+ }
+
+ #[test]
+ fn must_enforce_container_size_limit() {
+ let channel = TBufferChannel::with_capacity(100, 100);
+ let (r_channel, mut w_channel) = channel.split().unwrap();
+
+ // Create protocol with explicit container size limit
+ let config = TConfiguration::builder()
+ .max_container_size(Some(1000))
+ .build()
+ .unwrap();
+ let mut protocol = TCompactInputProtocol::with_config(r_channel, config);
+
+ // Write header with large size
+ // Compact protocol: 0xF0 means size >= 15 is encoded as varint
+ // Then we write a varint encoding 10000 (exceeds our limit of 1000)
+ w_channel.set_readable_bytes(&[
+ 0xF0, // Header: 15 in upper nibble (triggers varint read), element type in lower
+ 0x90, 0x4E, // Varint encoding of 10000
+ ]);
+
+ let result = protocol.read_list_begin();
+ assert!(result.is_err());
+ match result {
+ Err(crate::Error::Protocol(e)) => {
+ assert_eq!(e.kind, ProtocolErrorKind::SizeLimit);
+ assert!(e.message.contains("exceeds maximum allowed size"));
+ }
+ _ => panic!("Expected protocol error with SizeLimit"),
+ }
+ }
+
+ #[test]
+ fn must_handle_varint_size_overflow() {
+ // Test that compact protocol properly handles varint-encoded sizes that would cause overflow
+ let mut channel = TBufferChannel::with_capacity(100, 100);
+
+ let mut protocol = TCompactInputProtocol::new(channel.clone());
+
+ // Create input that encodes a very large size using varint encoding
+ // 0xFA = list header with size >= 15 (so size follows as varint)
+ // Then multiple 0xFF bytes which in varint encoding create a very large number
+ channel.set_readable_bytes(&[
+ 0xFA, // List header: size >= 15, element type = 0x0A
+ 0xFF, 0xFF, 0xFF, 0xFF, 0x7F, // Varint encoding of a huge number
+ ]);
+
+ let result = protocol.read_list_begin();
+ assert!(result.is_err());
+ match result {
+ Err(crate::Error::Protocol(e)) => {
+ // The varint decoder might interpret this as negative, which is also fine
+ assert!(
+ e.kind == ProtocolErrorKind::SizeLimit
+ || e.kind == ProtocolErrorKind::NegativeSize,
+ "Expected SizeLimit or NegativeSize but got {:?}",
+ e.kind
+ );
+ }
+ _ => panic!("Expected protocol error"),
+ }
+ }
+
+ #[test]
+ fn must_enforce_string_size_limit() {
+ let channel = TBufferChannel::with_capacity(100, 100);
+ let (r_channel, mut w_channel) = channel.split().unwrap();
+
+ // Create protocol with string limit of 100 bytes
+ let config = TConfiguration::builder()
+ .max_string_size(Some(100))
+ .build()
+ .unwrap();
+ let mut protocol = TCompactInputProtocol::with_config(r_channel, config);
+
+ // Write a varint-encoded string size that exceeds the limit
+ w_channel.set_readable_bytes(&[
+ 0xC8, 0x01, // Varint encoding of 200
+ ]);
+
+ let result = protocol.read_string();
+ assert!(result.is_err());
+ match result {
+ Err(crate::Error::Protocol(e)) => {
+ assert_eq!(e.kind, ProtocolErrorKind::SizeLimit);
+ assert!(e.message.contains("exceeds maximum allowed size"));
+ }
+ _ => panic!("Expected protocol error with SizeLimit"),
+ }
+ }
+
+ #[test]
+ fn must_allow_no_limit_configuration() {
+ let channel = TBufferChannel::with_capacity(40, 40);
+
+ let config = TConfiguration::no_limits();
+ let mut protocol = TCompactInputProtocol::with_config(channel, config);
+
+ // Should be able to nest structs deeply without limit
+ for _ in 0..100 {
+ assert!(protocol.read_struct_begin().is_ok());
+ }
+
+ for _ in 0..100 {
+ assert!(protocol.read_struct_end().is_ok());
+ }
+ }
+
+ #[test]
+ fn must_allow_containers_within_limit() {
+ let channel = TBufferChannel::with_capacity(200, 200);
+ let (r_channel, mut w_channel) = channel.split().unwrap();
+
+ // Create protocol with container limit of 100
+ let config = TConfiguration::builder()
+ .max_container_size(Some(100))
+ .build()
+ .unwrap();
+ let mut protocol = TCompactInputProtocol::with_config(r_channel, config);
+
+ // Write a list with 5 i32 elements (well within limit of 100)
+ // Compact protocol: size < 15 is encoded in header
+ w_channel.set_readable_bytes(&[
+ 0x55, // Header: size=5, element type=5 (i32)
+ // 5 varint-encoded i32 values
+ 0x0A, // 10
+ 0x14, // 20
+ 0x1E, // 30
+ 0x28, // 40
+ 0x32, // 50
+ ]);
+
+ let result = protocol.read_list_begin();
+ assert!(result.is_ok());
+ let list_ident = result.unwrap();
+ assert_eq!(list_ident.size, 5);
+ assert_eq!(list_ident.element_type, TType::I32);
+ }
+
+ #[test]
+ fn must_allow_strings_within_limit() {
+ let channel = TBufferChannel::with_capacity(100, 100);
+ let (r_channel, mut w_channel) = channel.split().unwrap();
+
+ let config = TConfiguration::builder()
+ .max_string_size(Some(1000))
+ .build()
+ .unwrap();
+ let mut protocol = TCompactInputProtocol::with_config(r_channel, config);
+
+ // Write a string "hello" (5 bytes, well within limit)
+ w_channel.set_readable_bytes(&[
+ 0x05, // Varint-encoded length: 5
+ b'h', b'e', b'l', b'l', b'o',
+ ]);
+
+ let result = protocol.read_string();
+ assert!(result.is_ok());
+ assert_eq!(result.unwrap(), "hello");
+ }
}
diff --git a/lib/rs/src/protocol/mod.rs b/lib/rs/src/protocol/mod.rs
index 8bbbb3c..573978a 100644
--- a/lib/rs/src/protocol/mod.rs
+++ b/lib/rs/src/protocol/mod.rs
@@ -62,7 +62,7 @@
use std::fmt::{Display, Formatter};
use crate::transport::{TReadTransport, TWriteTransport};
-use crate::{ProtocolError, ProtocolErrorKind};
+use crate::{ProtocolError, ProtocolErrorKind, TConfiguration};
#[cfg(test)]
macro_rules! assert_eq_written_bytes {
@@ -262,6 +262,15 @@
///
/// This method should **never** be used in generated code.
fn read_byte(&mut self) -> crate::Result<u8>;
+
+    /// Returns the fewest bytes a value of `field_type` can occupy on the wire.
+    ///
+    /// The default implementation defers to the compact protocol's figures, which the
+    /// author notes are the minimum across all protocols. Meant for pre-allocation size
+    /// checks; actual values (strings, lists, etc.) may serialize to more bytes.
+    fn min_serialized_size(&self, field_type: TType) -> usize {
+        self::compact::compact_protocol_min_serialized_size(field_type)
+    }
}
/// Converts Thrift identifiers, primitives, containers or structs into a
@@ -444,6 +453,10 @@
fn read_byte(&mut self) -> crate::Result<u8> {
(**self).read_byte()
}
+
+    fn min_serialized_size(&self, field_type: TType) -> usize {
+        (**self).min_serialized_size(field_type) // forwards to the boxed protocol
+    }
}
impl<P> TOutputProtocol for Box<P>
@@ -565,7 +578,7 @@
/// let protocol = factory.create(Box::new(channel));
/// ```
pub trait TInputProtocolFactory {
- // Create a `TInputProtocol` that reads bytes from `transport`.
+ /// Create a `TInputProtocol` that reads bytes from `transport`.
fn create(&self, transport: Box<dyn TReadTransport + Send>) -> Box<dyn TInputProtocol + Send>;
}
@@ -920,6 +933,69 @@
}
}
+/// Validates the element count announced by a container header.
+///
+/// Shared by all protocols. Returns a protocol error when:
+/// - `container_size` is negative (`NegativeSize`)
+/// - `container_size` exceeds the configured maximum container size (`SizeLimit`)
+/// - `container_size * element_size` overflows `usize` (`SizeLimit`)
+/// - that product exceeds the configured maximum message size (`SizeLimit`)
+pub(crate) fn check_container_size(
+    config: &TConfiguration,
+    container_size: i32,
+    element_size: usize,
+) -> crate::Result<()> {
+    // Every rejection is a protocol error; only the kind and message vary.
+    let reject = |kind: ProtocolErrorKind, message: String| -> crate::Result<()> {
+        Err(crate::Error::Protocol(ProtocolError::new(kind, message)))
+    };
+
+    if container_size < 0 {
+        return reject(
+            ProtocolErrorKind::NegativeSize,
+            format!("Negative container size: {}", container_size),
+        );
+    }
+    // The cast cannot wrap: negative values were rejected above.
+    let element_count = container_size as usize;
+
+    if let Some(limit) = config.max_container_size().filter(|&cap| element_count > cap) {
+        return reject(
+            ProtocolErrorKind::SizeLimit,
+            format!(
+                "Container size {} exceeds maximum allowed size of {}",
+                container_size, limit
+            ),
+        );
+    }
+
+    let required_bytes = match element_count.checked_mul(element_size) {
+        Some(bytes) => bytes,
+        None => {
+            return reject(
+                ProtocolErrorKind::SizeLimit,
+                format!(
+                    "Container size {} with element size {} bytes would result in overflow",
+                    container_size, element_size
+                ),
+            );
+        }
+    };
+
+    // TODO: once Rust trait specialization stabilizes, check precisely against transports
+    // that track their remaining bytes; until then the message size limit is a best effort.
+    match config.max_message_size() {
+        Some(limit) if required_bytes > limit => reject(
+            ProtocolErrorKind::SizeLimit,
+            format!(
+                "Container would require {} bytes, exceeding message size limit of {}",
+                required_bytes, limit
+            ),
+        ),
+        _ => Ok(()),
+    }
+}
+
/// Extract the field id from a Thrift field identifier.
///
/// `field_ident` must *not* have `TFieldIdentifier.field_type` of type `TType::Stop`.
diff --git a/lib/rs/src/transport/framed.rs b/lib/rs/src/transport/framed.rs
index d8a7448..cf959be 100644
--- a/lib/rs/src/transport/framed.rs
+++ b/lib/rs/src/transport/framed.rs
@@ -21,6 +21,7 @@
use std::io::{Read, Write};
use super::{TReadTransport, TReadTransportFactory, TWriteTransport, TWriteTransportFactory};
+use crate::TConfiguration;
/// Default capacity of the read buffer in bytes.
const READ_CAPACITY: usize = 4096;
@@ -61,6 +62,7 @@
pos: usize,
cap: usize,
chan: C,
+ config: TConfiguration,
}
impl<C> TFramedReadTransport<C>
@@ -81,6 +83,7 @@
pos: 0,
cap: 0,
chan: channel,
+ config: TConfiguration::default(),
}
}
}
@@ -91,7 +94,28 @@
{
fn read(&mut self, b: &mut [u8]) -> io::Result<usize> {
if self.cap - self.pos == 0 {
- let message_size = self.chan.read_i32::<BigEndian>()? as usize;
+ let frame_size_bytes = self.chan.read_i32::<BigEndian>()?;
+
+ if frame_size_bytes < 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!("Negative frame size: {}", frame_size_bytes),
+ ));
+ }
+
+ let message_size = frame_size_bytes as usize;
+
+ if let Some(max_frame) = self.config.max_frame_size() {
+ if message_size > max_frame {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!(
+ "Frame size {} exceeds maximum allowed size of {}",
+ message_size, max_frame
+ ),
+ ));
+ }
+ }
let buf_capacity = cmp::max(message_size, READ_CAPACITY);
self.buf.resize(buf_capacity, 0);
@@ -125,7 +149,6 @@
Box::new(TFramedReadTransport::new(channel))
}
}
-
/// Transport that writes framed messages.
///
/// A `TFramedWriteTransport` maintains a fixed-size internal write buffer. All