[THRIFT-5871] Add message / container size checking for Rust
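Bring the Rust implementation closer to parity with the other language implementations. For example, `TCompactInputProtocol` now enforces the limits configured in `TConfiguration` while reading: maximum recursion (struct nesting) depth, maximum string/binary size, and container sizes checked against each element type's minimum serialized size.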
I tried four or five approaches to a precise check, but because trait specialization is not yet stable,
none of them was satisfactory: the code was either ugly or added runtime overhead.
So, for now, full message-size tracking and more precise limit checking are left out, but this is a solid step
in the right direction.
diff --git a/lib/rs/src/protocol/compact.rs b/lib/rs/src/protocol/compact.rs
index 4dc45ca..7e1a751 100644
--- a/lib/rs/src/protocol/compact.rs
+++ b/lib/rs/src/protocol/compact.rs
@@ -26,6 +26,7 @@
};
use super::{TOutputProtocol, TOutputProtocolFactory, TSetIdentifier, TStructIdentifier, TType};
use crate::transport::{TReadTransport, TWriteTransport};
+use crate::{ProtocolError, ProtocolErrorKind, TConfiguration};
const COMPACT_PROTOCOL_ID: u8 = 0x82;
const COMPACT_VERSION: u8 = 0x01;
@@ -64,6 +65,10 @@
pending_read_bool_value: Option<bool>,
// Underlying transport used for byte-level operations.
transport: T,
+ // Read limits (recursion depth, string size, container size) enforced by this protocol.
+ config: TConfiguration,
+ // Number of structs currently open: incremented by `read_struct_begin`, decremented by `read_struct_end`.
+ recursion_depth: usize,
}
impl<T> TCompactInputProtocol<T>
@@ -72,11 +77,18 @@
{
/// Create a `TCompactInputProtocol` that reads bytes from `transport`.
+ ///
+ /// Uses `TConfiguration::default()` limits; call `with_config` to supply different ones.
pub fn new(transport: T) -> TCompactInputProtocol<T> {
+ Self::with_config(transport, TConfiguration::default())
+ }
+
+ /// Create a `TCompactInputProtocol` that reads bytes from `transport` and enforces the
+ /// read limits (recursion depth, string size, container size) configured in `config`.
+ pub fn with_config(transport: T, config: TConfiguration) -> TCompactInputProtocol<T> {
TCompactInputProtocol {
last_read_field_id: 0,
read_field_id_stack: Vec::new(),
pending_read_bool_value: None,
transport,
+ config,
+ recursion_depth: 0,
}
}
@@ -92,8 +104,23 @@
self.transport.read_varint::<u32>()? as i32
};
+        // Validate `element_count` against the configured limits (via `super::check_container_size`),
+        // assuming the smallest possible encoding for every element.
+        super::check_container_size(
+            &self.config,
+            element_count,
+            self.min_serialized_size(element_type),
+        )?;
+
Ok((element_type, element_count))
}
+
+ fn check_recursion_depth(&self) -> crate::Result<()> {
+ if let Some(limit) = self.config.max_recursion_depth() {
+ if self.recursion_depth >= limit {
+ return Err(crate::Error::Protocol(ProtocolError::new(
+ ProtocolErrorKind::DepthLimit,
+ format!("Maximum recursion depth {} exceeded", limit),
+ )));
+ }
+ }
+ Ok(())
+ }
}
impl<T> TInputProtocol for TCompactInputProtocol<T>
@@ -101,6 +128,7 @@
T: TReadTransport,
{
fn read_message_begin(&mut self) -> crate::Result<TMessageIdentifier> {
+ // TODO: once trait specialization is stable, start whole-message size tracking here (not implemented yet; only per-read limits are checked).
let compact_id = self.read_byte()?;
if compact_id != COMPACT_PROTOCOL_ID {
Err(crate::Error::Protocol(crate::ProtocolError {
@@ -145,12 +173,15 @@
}
fn read_struct_begin(&mut self) -> crate::Result<Option<TStructIdentifier>> {
+ // Fail with `DepthLimit` before descending if the configured nesting limit is already reached.
+ self.check_recursion_depth()?;
+ self.recursion_depth += 1;
self.read_field_id_stack.push(self.last_read_field_id);
self.last_read_field_id = 0;
Ok(None)
}
fn read_struct_end(&mut self) -> crate::Result<()> {
+        // Saturate instead of underflowing on an unbalanced `read_struct_end`: `-= 1` at zero
+        // panics in debug builds (pre-empting the stack-pop error below) and wraps to
+        // `usize::MAX` in release builds, after which every `read_struct_begin` would fail the
+        // recursion-depth check whenever a limit is configured.
+        self.recursion_depth = self.recursion_depth.saturating_sub(1);
self.last_read_field_id = self
.read_field_id_stack
.pop()
@@ -227,6 +258,19 @@
fn read_bytes(&mut self) -> crate::Result<Vec<u8>> {
let len = self.transport.read_varint::<u32>()?;
+
+        // Reject the declared length before allocating a buffer for it.
+        let exceeded_limit = self
+            .config
+            .max_string_size()
+            .filter(|&max_size| len as usize > max_size);
+        if let Some(max_size) = exceeded_limit {
+            return Err(crate::Error::Protocol(ProtocolError::new(
+                ProtocolErrorKind::SizeLimit,
+                format!(
+                    "Byte array size {} exceeds maximum allowed size of {}",
+                    len, max_size
+                ),
+            )));
+        }
+
let mut buf = vec![0u8; len as usize];
self.transport
.read_exact(&mut buf)
@@ -291,6 +335,12 @@
let type_header = self.read_byte()?;
let key_type = collection_u8_to_type((type_header & 0xF0) >> 4)?;
let val_type = collection_u8_to_type(type_header & 0x0F)?;
+
+        // Each map entry needs at least the minimum encoding of its key plus that of its value.
+        let entry_min_size =
+            self.min_serialized_size(key_type) + self.min_serialized_size(val_type);
+        super::check_container_size(&self.config, element_count, entry_min_size)?;
+
Ok(TMapIdentifier::new(key_type, val_type, element_count))
}
}
@@ -309,6 +359,30 @@
.map_err(From::from)
.map(|_| buf[0])
}
+
+ /// Smallest number of bytes a value of `field_type` can occupy in the compact encoding;
+ /// delegates to `compact_protocol_min_serialized_size`.
+ fn min_serialized_size(&self, field_type: TType) -> usize {
+ compact_protocol_min_serialized_size(field_type)
+ }
+}
+
+/// Lower bound, in bytes, on the compact-protocol encoding of a single value of `field_type`.
+///
+/// Used to reject container headers whose element count could not possibly fit within the
+/// configured limits.
+pub(crate) fn compact_protocol_min_serialized_size(field_type: TType) -> usize {
+    match field_type {
+        // Fixed-width encodings (not varint encoded).
+        TType::Double => 8,
+        TType::Uuid => 16,
+        // Everything else needs at least one byte: single-byte values, varint-encoded
+        // integers, the length varint of a string, a container's header byte, or a
+        // struct's stop field.
+        TType::Stop
+        | TType::Void
+        | TType::Bool
+        | TType::I08
+        | TType::I16
+        | TType::I32
+        | TType::I64
+        | TType::String
+        | TType::Struct
+        | TType::Map
+        | TType::Set
+        | TType::List
+        | TType::Utf7 => 1,
+    }
}
impl<T> io::Seek for TCompactInputProtocol<T>
@@ -2573,14 +2647,25 @@
fn must_round_trip_small_sized_list_begin() {
let (mut i_prot, mut o_prot) = test_objects();
- let ident = TListIdentifier::new(TType::I08, 10);
-
+ // Round-trip a complete list (header, three i32 elements, end), not just the header.
+ let ident = TListIdentifier::new(TType::I32, 3);
assert_success!(o_prot.write_list_begin(&ident));
+ assert_success!(o_prot.write_i32(100));
+ assert_success!(o_prot.write_i32(200));
+ assert_success!(o_prot.write_i32(300));
+
+ assert_success!(o_prot.write_list_end());
+
copy_write_buffer_to_read_buffer!(o_prot);
let res = assert_success!(i_prot.read_list_begin());
assert_eq!(&res, &ident);
+
+ // The elements must come back in the order they were written.
+ assert_eq!(i_prot.read_i32().unwrap(), 100);
+ assert_eq!(i_prot.read_i32().unwrap(), 200);
+ assert_eq!(i_prot.read_i32().unwrap(), 300);
+
+ assert_success!(i_prot.read_list_end());
}
#[test]
@@ -2600,10 +2685,9 @@
#[test]
fn must_round_trip_large_sized_list_begin() {
- let (mut i_prot, mut o_prot) = test_objects();
+ let (mut i_prot, mut o_prot) = test_objects_no_limits();
let ident = TListIdentifier::new(TType::Set, 47381);
-
assert_success!(o_prot.write_list_begin(&ident));
copy_write_buffer_to_read_buffer!(o_prot);
@@ -2632,14 +2716,25 @@
fn must_round_trip_small_sized_set_begin() {
let (mut i_prot, mut o_prot) = test_objects();
- let ident = TSetIdentifier::new(TType::I16, 7);
-
+ // Round-trip a complete set (header, three i16 elements, end), not just the header.
+ let ident = TSetIdentifier::new(TType::I16, 3);
assert_success!(o_prot.write_set_begin(&ident));
+ assert_success!(o_prot.write_i16(111));
+ assert_success!(o_prot.write_i16(222));
+ assert_success!(o_prot.write_i16(333));
+
+ assert_success!(o_prot.write_set_end());
+
copy_write_buffer_to_read_buffer!(o_prot);
let res = assert_success!(i_prot.read_set_begin());
assert_eq!(&res, &ident);
+
+ // The elements must come back in the order they were written.
+ assert_eq!(i_prot.read_i16().unwrap(), 111);
+ assert_eq!(i_prot.read_i16().unwrap(), 222);
+ assert_eq!(i_prot.read_i16().unwrap(), 333);
+
+ assert_success!(i_prot.read_set_end());
}
#[test]
@@ -2658,10 +2753,9 @@
#[test]
fn must_round_trip_large_sized_set_begin() {
- let (mut i_prot, mut o_prot) = test_objects();
+ let (mut i_prot, mut o_prot) = test_objects_no_limits();
let ident = TSetIdentifier::new(TType::Map, 3_928_429);
-
assert_success!(o_prot.write_set_begin(&ident));
copy_write_buffer_to_read_buffer!(o_prot);
@@ -2725,10 +2819,9 @@
#[test]
fn must_round_trip_map_begin() {
- let (mut i_prot, mut o_prot) = test_objects();
+ let (mut i_prot, mut o_prot) = test_objects_no_limits();
let ident = TMapIdentifier::new(TType::Map, TType::List, 1_928_349);
-
assert_success!(o_prot.write_map_begin(&ident));
copy_write_buffer_to_read_buffer!(o_prot);
@@ -2804,7 +2897,7 @@
TCompactInputProtocol<ReadHalf<TBufferChannel>>,
TCompactOutputProtocol<WriteHalf<TBufferChannel>>,
) {
- let mem = TBufferChannel::with_capacity(80, 80);
+ let mem = TBufferChannel::with_capacity(200, 200);
let (r_mem, w_mem) = mem.split().unwrap();
@@ -2814,6 +2907,20 @@
(i_prot, o_prot)
}
+    /// Builds a connected compact input/output protocol pair whose input side uses
+    /// `TConfiguration::no_limits()`, for round-trip tests that declare very large containers.
+    fn test_objects_no_limits() -> (
+        TCompactInputProtocol<ReadHalf<TBufferChannel>>,
+        TCompactOutputProtocol<WriteHalf<TBufferChannel>>,
+    ) {
+        let (read_half, write_half) = TBufferChannel::with_capacity(200, 200).split().unwrap();
+        (
+            TCompactInputProtocol::with_config(read_half, TConfiguration::no_limits()),
+            TCompactOutputProtocol::new(write_half),
+        )
+    }
+
#[test]
fn must_read_write_double() {
let (mut i_prot, mut o_prot) = test_objects();
@@ -2883,4 +2990,248 @@
assert_success!(i_prot.read_list_end());
}
+
+    #[test]
+    fn must_enforce_recursion_depth_limit() {
+        // Allow at most two nested structs.
+        let config = TConfiguration::builder()
+            .max_recursion_depth(Some(2))
+            .build()
+            .unwrap();
+        let mut protocol =
+            TCompactInputProtocol::with_config(TBufferChannel::with_capacity(100, 100), config);
+
+        // Nesting levels 1 and 2 are within the limit.
+        for _ in 0..2 {
+            assert!(protocol.read_struct_begin().is_ok());
+        }
+
+        // Level 3 exceeds it.
+        let result = protocol.read_struct_begin();
+        assert!(result.is_err());
+        if let Err(crate::Error::Protocol(e)) = result {
+            assert_eq!(e.kind, ProtocolErrorKind::DepthLimit);
+        } else {
+            panic!("Expected protocol error with DepthLimit");
+        }
+    }
+
+    #[test]
+    fn must_check_container_size_overflow() {
+        let config = TConfiguration::builder()
+            .max_message_size(Some(1000))
+            .max_frame_size(Some(1000))
+            .build()
+            .unwrap();
+        let mut i_prot =
+            TCompactInputProtocol::with_config(TBufferChannel::with_capacity(100, 0), config);
+
+        // 100 UUIDs at 16 bytes each need at least 1600 bytes: over the 1000-byte limit.
+        i_prot.transport.set_readable_bytes(&[
+            0xFD, // size follows as a varint (0xF_), element type UUID (0x_D)
+            0x64, // varint 100
+        ]);
+
+        let result = i_prot.read_list_begin();
+        assert!(result.is_err());
+        if let Err(crate::Error::Protocol(e)) = result {
+            assert_eq!(e.kind, ProtocolErrorKind::SizeLimit);
+            assert!(e
+                .message
+                .contains("1600 bytes, exceeding message size limit of 1000"));
+        } else {
+            panic!("Expected protocol error with SizeLimit");
+        }
+    }
+
+    #[test]
+    fn must_reject_negative_container_sizes() {
+        let mut channel = TBufferChannel::with_capacity(100, 100);
+
+        let mut protocol = TCompactInputProtocol::new(channel.clone());
+
+        // List header whose varint-encoded size becomes negative once read as an i32.
+        // An upper nibble of 0xF means the size follows as a varint; the lower nibble is the
+        // element type (0 here).
+        channel.set_readable_bytes(&[
+            0xF0, // Header: size follows as varint (0xF_), element type 0 (0x_0)
+            0xFF, 0xFF, 0xFF, 0xFF, 0x0F, // Varint encoding of 0xFFFF_FFFF, i.e. -1 as an i32
+        ]);
+
+        let result = protocol.read_list_begin();
+        assert!(result.is_err());
+        match result {
+            Err(crate::Error::Protocol(e)) => {
+                assert_eq!(e.kind, ProtocolErrorKind::NegativeSize);
+            }
+            _ => panic!("Expected protocol error with NegativeSize"),
+        }
+    }
+
+    #[test]
+    fn must_enforce_container_size_limit() {
+        // At most 1000 elements per container.
+        let config = TConfiguration::builder()
+            .max_container_size(Some(1000))
+            .build()
+            .unwrap();
+        let (r_channel, mut w_channel) = TBufferChannel::with_capacity(100, 100).split().unwrap();
+        let mut protocol = TCompactInputProtocol::with_config(r_channel, config);
+
+        // A list header announcing 10000 elements, ten times the limit.
+        w_channel.set_readable_bytes(&[
+            0xF0, // size follows as a varint (0xF_), element type in the lower nibble
+            0x90, 0x4E, // varint 10000
+        ]);
+
+        let result = protocol.read_list_begin();
+        assert!(result.is_err());
+        if let Err(crate::Error::Protocol(e)) = result {
+            assert_eq!(e.kind, ProtocolErrorKind::SizeLimit);
+            assert!(e.message.contains("exceeds maximum allowed size"));
+        } else {
+            panic!("Expected protocol error with SizeLimit");
+        }
+    }
+
+    #[test]
+    fn must_handle_varint_size_overflow() {
+        // A list size whose varint carries more than 32 bits of payload must be rejected
+        // rather than accepted as a huge (or wrapped) element count.
+        let mut channel = TBufferChannel::with_capacity(100, 100);
+        let mut protocol = TCompactInputProtocol::new(channel.clone());
+
+        channel.set_readable_bytes(&[
+            0xFA, // size follows as a varint (0xF_), element type 0x0A (set) in the lower nibble
+            0xFF, 0xFF, 0xFF, 0xFF, 0x7F, // five-byte varint whose payload exceeds 32 bits
+        ]);
+
+        let result = protocol.read_list_begin();
+        assert!(result.is_err());
+        if let Err(crate::Error::Protocol(e)) = result {
+            // Depending on how the decoder narrows the value, it may come out negative.
+            assert!(
+                matches!(
+                    e.kind,
+                    ProtocolErrorKind::SizeLimit | ProtocolErrorKind::NegativeSize
+                ),
+                "Expected SizeLimit or NegativeSize but got {:?}",
+                e.kind
+            );
+        } else {
+            panic!("Expected protocol error");
+        }
+    }
+
+    #[test]
+    fn must_enforce_string_size_limit() {
+        // Strings may be at most 100 bytes long.
+        let config = TConfiguration::builder()
+            .max_string_size(Some(100))
+            .build()
+            .unwrap();
+        let (r_channel, mut w_channel) = TBufferChannel::with_capacity(100, 100).split().unwrap();
+        let mut protocol = TCompactInputProtocol::with_config(r_channel, config);
+
+        // Only the length prefix is needed: the check must fire before any payload is read.
+        w_channel.set_readable_bytes(&[
+            0xC8, 0x01, // varint 200
+        ]);
+
+        let result = protocol.read_string();
+        assert!(result.is_err());
+        if let Err(crate::Error::Protocol(e)) = result {
+            assert_eq!(e.kind, ProtocolErrorKind::SizeLimit);
+            assert!(e.message.contains("exceeds maximum allowed size"));
+        } else {
+            panic!("Expected protocol error with SizeLimit");
+        }
+    }
+
+    #[test]
+    fn must_allow_no_limit_configuration() {
+        const DEPTH: usize = 100;
+        let mut protocol = TCompactInputProtocol::with_config(
+            TBufferChannel::with_capacity(40, 40),
+            TConfiguration::no_limits(),
+        );
+
+        // Without a recursion limit, deep nesting is accepted...
+        for _ in 0..DEPTH {
+            assert!(protocol.read_struct_begin().is_ok());
+        }
+        // ...and can be unwound again.
+        for _ in 0..DEPTH {
+            assert!(protocol.read_struct_end().is_ok());
+        }
+    }
+
+    #[test]
+    fn must_allow_containers_within_limit() {
+        let channel = TBufferChannel::with_capacity(200, 200);
+        let (r_channel, mut w_channel) = channel.split().unwrap();
+
+        // Create protocol with container limit of 100
+        let config = TConfiguration::builder()
+            .max_container_size(Some(100))
+            .build()
+            .unwrap();
+        let mut protocol = TCompactInputProtocol::with_config(r_channel, config);
+
+        // A list of 5 i32 elements (well within the limit of 100).
+        // Compact protocol: a size < 15 is encoded in the header's upper nibble.
+        w_channel.set_readable_bytes(&[
+            0x55, // Header: size=5, element type=5 (i32)
+            // 5 zigzag varint-encoded i32 values
+            0x0A, // 5
+            0x14, // 10
+            0x1E, // 15
+            0x28, // 20
+            0x32, // 25
+        ]);
+
+        let result = protocol.read_list_begin();
+        assert!(result.is_ok());
+        let list_ident = result.unwrap();
+        assert_eq!(list_ident.size, 5);
+        assert_eq!(list_ident.element_type, TType::I32);
+
+        // Accepting the header must not disturb reading the elements that follow it.
+        for expected in [5, 10, 15, 20, 25] {
+            assert_eq!(protocol.read_i32().unwrap(), expected);
+        }
+        assert!(protocol.read_list_end().is_ok());
+    }
+
+    #[test]
+    fn must_allow_strings_within_limit() {
+        // A 1000-byte string limit leaves plenty of room for a short string.
+        let config = TConfiguration::builder()
+            .max_string_size(Some(1000))
+            .build()
+            .unwrap();
+        let (r_channel, mut w_channel) = TBufferChannel::with_capacity(100, 100).split().unwrap();
+        let mut protocol = TCompactInputProtocol::with_config(r_channel, config);
+
+        // Varint length prefix 5 followed by the bytes of "hello".
+        w_channel.set_readable_bytes(&[0x05, b'h', b'e', b'l', b'l', b'o']);
+
+        let result = protocol.read_string();
+        assert!(result.is_ok());
+        assert_eq!(result.unwrap(), "hello");
+    }
}