[THRIFT-5871] Add message / container size checking for Rust
Bring the Rust implementation closer to parity with the other language implementations.
I tried 4-5 different approaches to get a "perfect" check, but because trait specialization is not yet stable,
I could not arrive at a solution I was happy with (the code was either ugly or added runtime overhead).
So for now we skip full message-size tracking and more precise limit checking, but this is a strong step
in the right direction.
diff --git a/lib/rs/src/protocol/binary.rs b/lib/rs/src/protocol/binary.rs
index b4b51f6..596285f 100644
--- a/lib/rs/src/protocol/binary.rs
+++ b/lib/rs/src/protocol/binary.rs
@@ -24,7 +24,7 @@
};
use super::{TOutputProtocol, TOutputProtocolFactory, TSetIdentifier, TStructIdentifier, TType};
use crate::transport::{TReadTransport, TWriteTransport};
-use crate::{ProtocolError, ProtocolErrorKind};
+use crate::{ProtocolError, ProtocolErrorKind, TConfiguration};
const BINARY_PROTOCOL_VERSION_1: u32 = 0x8001_0000;
@@ -57,6 +57,8 @@
{
strict: bool,
pub transport: T, // FIXME: shouldn't be public
+ config: TConfiguration,
+ recursion_depth: usize,
}
impl<T> TBinaryInputProtocol<T>
@@ -67,8 +69,29 @@
///
/// Set `strict` to `true` if all incoming messages contain the protocol
/// version number in the protocol header.
- pub fn new(transport: T, strict: bool) -> TBinaryInputProtocol<T> {
- TBinaryInputProtocol { strict, transport }
+ pub fn new(transport: T, strict: bool) -> Self {
+ let default_config = TConfiguration::default();
+ Self::with_config(transport, strict, default_config)
+ }
+
+ /// Create a `TBinaryInputProtocol` that reads bytes from `transport` and
+ /// enforces the limits held in `config` (recursion depth, string and
+ /// container sizes).
+ ///
+ /// Set `strict` to `true` if all incoming messages contain the protocol
+ /// version number in the protocol header.
+ pub fn with_config(transport: T, strict: bool, config: TConfiguration) -> Self {
+ // Reading always starts outside of any struct.
+ Self {
+ transport,
+ strict,
+ config,
+ recursion_depth: 0,
+ }
+ }
+
+ /// Fail with a `DepthLimit` protocol error when the current struct nesting
+ /// depth has already reached the configured maximum. Does nothing when the
+ /// configuration sets no recursion limit.
+ fn check_recursion_depth(&self) -> crate::Result<()> {
+ match self.config.max_recursion_depth() {
+ Some(limit) if self.recursion_depth >= limit => Err(crate::Error::Protocol(
+ ProtocolError::new(
+ ProtocolErrorKind::DepthLimit,
+ format!("Maximum recursion depth {} exceeded", limit),
+ ),
+ )),
+ _ => Ok(()),
+ }
}
}
@@ -78,6 +101,7 @@
{
#[allow(clippy::collapsible_if)]
fn read_message_begin(&mut self) -> crate::Result<TMessageIdentifier> {
+ // TODO: Once specialization is stable, call the message size tracking here
let mut first_bytes = vec![0; 4];
self.transport.read_exact(&mut first_bytes[..])?;
@@ -130,10 +154,13 @@
}
fn read_struct_begin(&mut self) -> crate::Result<Option<TStructIdentifier>> {
+ // Refuse to descend into another struct once the configured nesting limit is reached;
+ // the depth is only incremented when the check passes.
+ self.check_recursion_depth()?;
+ self.recursion_depth += 1;
Ok(None)
}
fn read_struct_end(&mut self) -> crate::Result<()> {
+ // Saturate rather than `-= 1`: an unbalanced extra `read_struct_end` must not panic
+ // on underflow in debug builds, nor wrap to `usize::MAX` in release builds (which
+ // would make every later `read_struct_begin` fail with DepthLimit).
+ self.recursion_depth = self.recursion_depth.saturating_sub(1);
Ok(())
}
@@ -154,8 +181,28 @@
}
fn read_bytes(&mut self) -> crate::Result<Vec<u8>> {
- let num_bytes = self.transport.read_i32::<BigEndian>()? as usize;
- let mut buf = vec![0u8; num_bytes];
+ let num_bytes = self.transport.read_i32::<BigEndian>()?;
+
+ if num_bytes < 0 {
+ return Err(crate::Error::Protocol(ProtocolError::new(
+ ProtocolErrorKind::NegativeSize,
+ format!("Negative byte array size: {}", num_bytes),
+ )));
+ }
+
+ if let Some(max_size) = self.config.max_string_size() {
+ if num_bytes as usize > max_size {
+ return Err(crate::Error::Protocol(ProtocolError::new(
+ ProtocolErrorKind::SizeLimit,
+ format!(
+ "Byte array size {} exceeds maximum allowed size of {}",
+ num_bytes, max_size
+ ),
+ )));
+ }
+ }
+
+ let mut buf = vec![0u8; num_bytes as usize];
self.transport
.read_exact(&mut buf)
.map(|_| buf)
@@ -206,6 +253,8 @@
fn read_list_begin(&mut self) -> crate::Result<TListIdentifier> {
let element_type: TType = self.read_byte().and_then(field_type_from_u8)?;
let size = self.read_i32()?;
+ // The element count comes off the wire untrusted: reject negative counts and counts
+ // above the configured container limit before the caller iterates over `size` items.
+ // The per-element minimum encoded size is passed along for the shared check.
let min_element_size = self.min_serialized_size(element_type);
super::check_container_size(&self.config, size, min_element_size)?;
Ok(TListIdentifier::new(element_type, size))
}
@@ -216,6 +265,8 @@
fn read_set_begin(&mut self) -> crate::Result<TSetIdentifier> {
let element_type: TType = self.read_byte().and_then(field_type_from_u8)?;
let size = self.read_i32()?;
+ // Same validation as for lists: the untrusted element count is checked against the
+ // configuration, together with the minimum encoded size of one element.
let min_element_size = self.min_serialized_size(element_type);
super::check_container_size(&self.config, size, min_element_size)?;
Ok(TSetIdentifier::new(element_type, size))
}
@@ -227,6 +278,12 @@
let key_type: TType = self.read_byte().and_then(field_type_from_u8)?;
let value_type: TType = self.read_byte().and_then(field_type_from_u8)?;
let size = self.read_i32()?;
+
+ let key_min_size = self.min_serialized_size(key_type);
+ let value_min_size = self.min_serialized_size(value_type);
+ let element_size = key_min_size + value_min_size;
+ super::check_container_size(&self.config, size, element_size)?;
+
Ok(TMapIdentifier::new(key_type, value_type, size))
}
@@ -240,6 +297,26 @@
fn read_byte(&mut self) -> crate::Result<u8> {
self.transport.read_u8().map_err(From::from)
}
+
+ /// Smallest number of bytes one value of `field_type` occupies in the binary
+ /// encoding; used as the per-element size when validating container headers.
+ ///
+ /// Variable-length types report only their fixed part (the i32 length/count
+ /// prefix of strings and containers, or a struct's stop byte), so the result
+ /// is a lower bound.
+ fn min_serialized_size(&self, field_type: TType) -> usize {
+ match field_type {
+ // One byte: bool, i8 and an empty struct's stop byte; Stop, Void and Utf7
+ // are also given a nominal single byte.
+ TType::Stop | TType::Void | TType::Bool | TType::I08 | TType::Struct | TType::Utf7 => 1,
+ TType::I16 => 2,
+ // A plain i32, or the i32 length/count prefix of strings and containers.
+ TType::I32 | TType::String | TType::Map | TType::Set | TType::List => 4,
+ TType::I64 | TType::Double => 8,
+ TType::Uuid => 16,
+ }
+ }
}
/// Factory for creating instances of `TBinaryInputProtocol`.
@@ -514,14 +591,13 @@
#[cfg(test)]
mod tests {
+ use super::*;
use crate::protocol::{
TFieldIdentifier, TInputProtocol, TListIdentifier, TMapIdentifier, TMessageIdentifier,
TMessageType, TOutputProtocol, TSetIdentifier, TStructIdentifier, TType,
};
use crate::transport::{ReadHalf, TBufferChannel, TIoChannel, WriteHalf};
- use super::*;
-
#[test]
fn must_write_strict_message_call_begin() {
let (_, mut o_prot) = test_objects(true);
@@ -759,13 +835,26 @@
fn must_round_trip_list_begin() {
let (mut i_prot, mut o_prot) = test_objects(true);
+ // Round-trip a complete list (header plus its four i32 elements) rather than a bare
+ // header; presumably so the declared size is backed by payload now that
+ // `read_list_begin` validates the element count.
- let ident = TListIdentifier::new(TType::List, 900);
+ let ident = TListIdentifier::new(TType::I32, 4);
assert!(o_prot.write_list_begin(&ident).is_ok());
+ assert!(o_prot.write_i32(10).is_ok());
+ assert!(o_prot.write_i32(20).is_ok());
+ assert!(o_prot.write_i32(30).is_ok());
+ assert!(o_prot.write_i32(40).is_ok());
+
+ assert!(o_prot.write_list_end().is_ok());
copy_write_buffer_to_read_buffer!(o_prot);
let received_ident = assert_success!(i_prot.read_list_begin());
assert_eq!(&received_ident, &ident);
+
+ assert_eq!(i_prot.read_i32().unwrap(), 10);
+ assert_eq!(i_prot.read_i32().unwrap(), 20);
+ assert_eq!(i_prot.read_i32().unwrap(), 30);
+ assert_eq!(i_prot.read_i32().unwrap(), 40);
+
+ assert!(i_prot.read_list_end().is_ok());
}
#[test]
@@ -789,14 +878,25 @@
fn must_round_trip_set_begin() {
let (mut i_prot, mut o_prot) = test_objects(true);
+ // Round-trip a complete set of three i64 values so the declared size is backed by
+ // payload (presumably required now that `read_set_begin` validates the count).
- let ident = TSetIdentifier::new(TType::I64, 2000);
+ let ident = TSetIdentifier::new(TType::I64, 3);
assert!(o_prot.write_set_begin(&ident).is_ok());
+ assert!(o_prot.write_i64(123).is_ok());
+ assert!(o_prot.write_i64(456).is_ok());
+ assert!(o_prot.write_i64(789).is_ok());
+
+ assert!(o_prot.write_set_end().is_ok());
copy_write_buffer_to_read_buffer!(o_prot);
let received_ident_result = i_prot.read_set_begin();
assert!(received_ident_result.is_ok());
assert_eq!(&received_ident_result.unwrap(), &ident);
+
+ assert_eq!(i_prot.read_i64().unwrap(), 123);
+ assert_eq!(i_prot.read_i64().unwrap(), 456);
+ assert_eq!(i_prot.read_i64().unwrap(), 789);
+
+ assert!(i_prot.read_set_end().is_ok());
}
#[test]
@@ -820,13 +920,26 @@
fn must_round_trip_map_begin() {
let (mut i_prot, mut o_prot) = test_objects(true);
+ // Round-trip a complete map with two string -> i32 entries so the declared size is
+ // backed by payload (presumably required now that `read_map_begin` validates the count).
- let ident = TMapIdentifier::new(TType::Map, TType::Set, 100);
+ let ident = TMapIdentifier::new(TType::String, TType::I32, 2);
assert!(o_prot.write_map_begin(&ident).is_ok());
+ assert!(o_prot.write_string("key1").is_ok());
+ assert!(o_prot.write_i32(100).is_ok());
+ assert!(o_prot.write_string("key2").is_ok());
+ assert!(o_prot.write_i32(200).is_ok());
+
+ assert!(o_prot.write_map_end().is_ok());
copy_write_buffer_to_read_buffer!(o_prot);
let received_ident = assert_success!(i_prot.read_map_begin());
assert_eq!(&received_ident, &ident);
+
+ assert_eq!(i_prot.read_string().unwrap(), "key1");
+ assert_eq!(i_prot.read_i32().unwrap(), 100);
+ assert_eq!(i_prot.read_string().unwrap(), "key2");
+ assert_eq!(i_prot.read_i32().unwrap(), 200);
+
+ assert!(i_prot.read_map_end().is_ok());
}
#[test]
@@ -963,7 +1076,7 @@
TBinaryInputProtocol<ReadHalf<TBufferChannel>>,
TBinaryOutputProtocol<WriteHalf<TBufferChannel>>,
) {
- let mem = TBufferChannel::with_capacity(40, 40);
+ let mem = TBufferChannel::with_capacity(200, 200);
let (r_mem, w_mem) = mem.split().unwrap();
@@ -981,4 +1094,154 @@
assert!(write_fn(&mut o_prot).is_ok());
assert_eq!(o_prot.transport.write_bytes().len(), 0);
}
+
+ #[test]
+ fn must_enforce_recursion_depth_limit() {
+ let channel = TBufferChannel::with_capacity(40, 40);
+ let (read_half, _) = channel.split().unwrap();
+ let config = TConfiguration::builder()
+ .max_recursion_depth(Some(2))
+ .build()
+ .unwrap();
+ let mut protocol = TBinaryInputProtocol::with_config(read_half, true, config);
+
+ // Two nested structs fit within a limit of 2...
+ for expected_depth in 1..=2 {
+ assert!(protocol.read_struct_begin().is_ok());
+ assert_eq!(protocol.recursion_depth, expected_depth);
+ }
+
+ // ...but a third is rejected with DepthLimit.
+ match protocol.read_struct_begin() {
+ Err(crate::Error::Protocol(e)) => assert_eq!(e.kind, ProtocolErrorKind::DepthLimit),
+ _ => panic!("Expected protocol error with DepthLimit"),
+ }
+
+ // Closing the two successfully opened structs brings the depth back to zero.
+ for expected_depth in (0..=1).rev() {
+ assert!(protocol.read_struct_end().is_ok());
+ assert_eq!(protocol.recursion_depth, expected_depth);
+ }
+ }
+
+ #[test]
+ fn must_reject_negative_container_sizes() {
+ let channel = TBufferChannel::with_capacity(40, 40);
+ let (read_half, mut write_half) = channel.split().unwrap();
+ let mut protocol = TBinaryInputProtocol::new(read_half, true);
+
+ // List header: element type 0x0F followed by a size of -1 (0xFFFFFFFF).
+ write_half.set_readable_bytes(&[0x0F, 0xFF, 0xFF, 0xFF, 0xFF]);
+
+ match protocol.read_list_begin() {
+ Err(crate::Error::Protocol(e)) => assert_eq!(e.kind, ProtocolErrorKind::NegativeSize),
+ _ => panic!("Expected protocol error with NegativeSize"),
+ }
+ }
+
+ #[test]
+ fn must_enforce_container_size_limit() {
+ let channel = TBufferChannel::with_capacity(40, 40);
+ let (read_half, mut write_half) = channel.split().unwrap();
+ let config = TConfiguration::builder()
+ .max_container_size(Some(100))
+ .build()
+ .unwrap();
+ let mut protocol = TBinaryInputProtocol::with_config(read_half, true, config);
+
+ // List header declaring 200 elements (0x000000C8), twice the configured limit.
+ write_half.set_readable_bytes(&[0x0F, 0x00, 0x00, 0x00, 0xC8]);
+
+ let err = match protocol.read_list_begin() {
+ Err(crate::Error::Protocol(e)) => e,
+ _ => panic!("Expected protocol error with SizeLimit"),
+ };
+ assert_eq!(err.kind, ProtocolErrorKind::SizeLimit);
+ assert!(err
+ .message
+ .contains("Container size 200 exceeds maximum allowed size of 100"));
+ }
+
+ #[test]
+ fn must_allow_containers_within_limit() {
+ let channel = TBufferChannel::with_capacity(200, 200);
+ let (read_half, mut write_half) = channel.split().unwrap();
+ // Container limit of 100 elements.
+ let config = TConfiguration::builder()
+ .max_container_size(Some(100))
+ .build()
+ .unwrap();
+ let mut protocol = TBinaryInputProtocol::with_config(read_half, true, config);
+
+ // Header (element type 0x08 = i32, size 5) followed by the values 10, 20, ..., 50.
+ let mut data = Vec::with_capacity(1 + 4 + 5 * 4);
+ data.push(0x08);
+ data.extend_from_slice(&5i32.to_be_bytes());
+ (1..=5i32).for_each(|i| data.extend_from_slice(&(i * 10).to_be_bytes()));
+ write_half.set_readable_bytes(&data);
+
+ let list_ident = protocol.read_list_begin().unwrap();
+ assert_eq!(list_ident.size, 5);
+ assert_eq!(list_ident.element_type, TType::I32);
+ }
+
+ #[test]
+ fn must_enforce_string_size_limit() {
+ let channel = TBufferChannel::with_capacity(100, 100);
+ let (read_half, mut write_half) = channel.split().unwrap();
+ let config = TConfiguration::builder()
+ .max_string_size(Some(1000))
+ .build()
+ .unwrap();
+ let mut protocol = TBinaryInputProtocol::with_config(read_half, true, config);
+
+ // Length prefix of 2000 (0x000007D0), twice the configured limit, with no payload.
+ write_half.set_readable_bytes(&2000i32.to_be_bytes());
+
+ let err = match protocol.read_string() {
+ Err(crate::Error::Protocol(e)) => e,
+ _ => panic!("Expected protocol error with SizeLimit"),
+ };
+ assert_eq!(err.kind, ProtocolErrorKind::SizeLimit);
+ assert!(err
+ .message
+ .contains("Byte array size 2000 exceeds maximum allowed size of 1000"));
+ }
+
+ #[test]
+ fn must_allow_strings_within_limit() {
+ let channel = TBufferChannel::with_capacity(100, 100);
+ let (read_half, mut write_half) = channel.split().unwrap();
+ let config = TConfiguration::builder()
+ .max_string_size(Some(1000))
+ .build()
+ .unwrap();
+ let mut protocol = TBinaryInputProtocol::with_config(read_half, true, config);
+
+ // Big-endian length prefix 5 followed by the bytes of "hello".
+ let mut data = 5i32.to_be_bytes().to_vec();
+ data.extend_from_slice(b"hello");
+ write_half.set_readable_bytes(&data);
+
+ assert_eq!(protocol.read_string().unwrap(), "hello");
+ }
}