[THRIFT-5871] Add message / container size checking for Rust

Bring the Rust implementation closer to parity with the other Thrift implementations by checking
recursion depth, string size and container size while reading.
I tried 4-5 different ways to get the "perfect" check, but since trait specialization is not yet stable,
I was not able to arrive at a solution I'm happy with (the code was either ugly or had runtime overhead).

So for now, this change leaves out full message size tracking and more precise limit checking,
but it is a strong step in the right direction.
diff --git a/lib/rs/src/protocol/binary.rs b/lib/rs/src/protocol/binary.rs
index b4b51f6..596285f 100644
--- a/lib/rs/src/protocol/binary.rs
+++ b/lib/rs/src/protocol/binary.rs
@@ -24,7 +24,7 @@
 };
 use super::{TOutputProtocol, TOutputProtocolFactory, TSetIdentifier, TStructIdentifier, TType};
 use crate::transport::{TReadTransport, TWriteTransport};
-use crate::{ProtocolError, ProtocolErrorKind};
+use crate::{ProtocolError, ProtocolErrorKind, TConfiguration};
 
 const BINARY_PROTOCOL_VERSION_1: u32 = 0x8001_0000;
 
@@ -57,6 +57,8 @@
 {
     strict: bool,
     pub transport: T, // FIXME: shouldn't be public
+    config: TConfiguration,
+    recursion_depth: usize,
 }
 
 impl<T> TBinaryInputProtocol<T>
@@ -67,8 +69,29 @@
     ///
     /// Set `strict` to `true` if all incoming messages contain the protocol
     /// version number in the protocol header.
-    pub fn new(transport: T, strict: bool) -> TBinaryInputProtocol<T> {
-        TBinaryInputProtocol { strict, transport }
+    pub fn new(transport: T, strict: bool) -> Self {
+        Self::with_config(transport, strict, TConfiguration::default())
+    }
+
+    pub fn with_config(transport: T, strict: bool, config: TConfiguration) -> Self {
+        TBinaryInputProtocol {
+            strict,
+            transport,
+            config,
+            recursion_depth: 0,
+        }
+    }
+
+    /// Errors with `DepthLimit` once `recursion_depth` reaches the configured limit, if any.
+    fn check_recursion_depth(&self) -> crate::Result<()> {
+        match self.config.max_recursion_depth() {
+            Some(limit) if self.recursion_depth >= limit => {
+                let message = format!("Maximum recursion depth {} exceeded", limit);
+                let kind = ProtocolErrorKind::DepthLimit;
+                Err(crate::Error::Protocol(ProtocolError::new(kind, message)))
+            }
+            _ => Ok(()),
+        }
     }
 }
 
@@ -78,6 +101,7 @@
 {
     #[allow(clippy::collapsible_if)]
     fn read_message_begin(&mut self) -> crate::Result<TMessageIdentifier> {
+        // TODO: Once specialization is stable, call the message size tracking here
         let mut first_bytes = vec![0; 4];
         self.transport.read_exact(&mut first_bytes[..])?;
 
@@ -130,10 +154,13 @@
     }
 
     fn read_struct_begin(&mut self) -> crate::Result<Option<TStructIdentifier>> {
+        self.check_recursion_depth()?;
+        self.recursion_depth += 1;
         Ok(None)
     }
 
     fn read_struct_end(&mut self) -> crate::Result<()> {
+        self.recursion_depth = self.recursion_depth.saturating_sub(1); // no underflow on stray end
         Ok(())
     }
 
@@ -154,8 +181,28 @@
     }
 
     fn read_bytes(&mut self) -> crate::Result<Vec<u8>> {
-        let num_bytes = self.transport.read_i32::<BigEndian>()? as usize;
-        let mut buf = vec![0u8; num_bytes];
+        let num_bytes = self.transport.read_i32::<BigEndian>()?;
+
+        if num_bytes < 0 {
+            return Err(crate::Error::Protocol(ProtocolError::new(
+                ProtocolErrorKind::NegativeSize,
+                format!("Negative byte array size: {}", num_bytes),
+            )));
+        }
+
+        if let Some(max_size) = self.config.max_string_size() {
+            if num_bytes as usize > max_size {
+                return Err(crate::Error::Protocol(ProtocolError::new(
+                    ProtocolErrorKind::SizeLimit,
+                    format!(
+                        "Byte array size {} exceeds maximum allowed size of {}",
+                        num_bytes, max_size
+                    ),
+                )));
+            }
+        }
+
+        let mut buf = vec![0u8; num_bytes as usize];
         self.transport
             .read_exact(&mut buf)
             .map(|_| buf)
@@ -206,6 +253,8 @@
     fn read_list_begin(&mut self) -> crate::Result<TListIdentifier> {
         let element_type: TType = self.read_byte().and_then(field_type_from_u8)?;
         let size = self.read_i32()?;
+        let min_element_size = self.min_serialized_size(element_type);
+        super::check_container_size(&self.config, size, min_element_size)?;
         Ok(TListIdentifier::new(element_type, size))
     }
 
@@ -216,6 +265,8 @@
     fn read_set_begin(&mut self) -> crate::Result<TSetIdentifier> {
         let element_type: TType = self.read_byte().and_then(field_type_from_u8)?;
         let size = self.read_i32()?;
+        let min_element_size = self.min_serialized_size(element_type);
+        super::check_container_size(&self.config, size, min_element_size)?;
         Ok(TSetIdentifier::new(element_type, size))
     }
 
@@ -227,6 +278,12 @@
         let key_type: TType = self.read_byte().and_then(field_type_from_u8)?;
         let value_type: TType = self.read_byte().and_then(field_type_from_u8)?;
         let size = self.read_i32()?;
+
+        let key_min_size = self.min_serialized_size(key_type);
+        let value_min_size = self.min_serialized_size(value_type);
+        let element_size = key_min_size + value_min_size;
+        super::check_container_size(&self.config, size, element_size)?;
+
         Ok(TMapIdentifier::new(key_type, value_type, size))
     }
 
@@ -240,6 +297,26 @@
     fn read_byte(&mut self) -> crate::Result<u8> {
         self.transport.read_u8().map_err(From::from)
     }
+
+    fn min_serialized_size(&self, field_type: TType) -> usize {
+        match field_type {
+            TType::Stop => 1,   // 1 byte minimum
+            TType::Void => 1,   // 1 byte minimum
+            TType::Bool => 1,   // 1 byte
+            TType::I08 => 1,    // 1 byte
+            TType::Double => 8, // 8 bytes
+            TType::I16 => 2,    // 2 bytes
+            TType::I32 => 4,    // 4 bytes
+            TType::I64 => 8,    // 8 bytes
+            TType::String => 4, // 4 bytes for length prefix
+            TType::Struct => 1, // 1 byte minimum (stop field)
+            TType::Map => 4,    // 4 bytes size
+            TType::Set => 4,    // 4 bytes size
+            TType::List => 4,   // 4 bytes size
+            TType::Uuid => 16,  // 16 bytes
+            TType::Utf7 => 1,   // 1 byte
+        }
+    }
 }
 
 /// Factory for creating instances of `TBinaryInputProtocol`.
@@ -514,14 +591,13 @@
 #[cfg(test)]
 mod tests {
 
+    use super::*;
     use crate::protocol::{
         TFieldIdentifier, TInputProtocol, TListIdentifier, TMapIdentifier, TMessageIdentifier,
         TMessageType, TOutputProtocol, TSetIdentifier, TStructIdentifier, TType,
     };
     use crate::transport::{ReadHalf, TBufferChannel, TIoChannel, WriteHalf};
 
-    use super::*;
-
     #[test]
     fn must_write_strict_message_call_begin() {
         let (_, mut o_prot) = test_objects(true);
@@ -759,13 +835,26 @@
     fn must_round_trip_list_begin() {
         let (mut i_prot, mut o_prot) = test_objects(true);
 
-        let ident = TListIdentifier::new(TType::List, 900);
+        let ident = TListIdentifier::new(TType::I32, 4);
         assert!(o_prot.write_list_begin(&ident).is_ok());
+        assert!(o_prot.write_i32(10).is_ok());
+        assert!(o_prot.write_i32(20).is_ok());
+        assert!(o_prot.write_i32(30).is_ok());
+        assert!(o_prot.write_i32(40).is_ok());
+
+        assert!(o_prot.write_list_end().is_ok());
 
         copy_write_buffer_to_read_buffer!(o_prot);
 
         let received_ident = assert_success!(i_prot.read_list_begin());
         assert_eq!(&received_ident, &ident);
+
+        assert_eq!(i_prot.read_i32().unwrap(), 10);
+        assert_eq!(i_prot.read_i32().unwrap(), 20);
+        assert_eq!(i_prot.read_i32().unwrap(), 30);
+        assert_eq!(i_prot.read_i32().unwrap(), 40);
+
+        assert!(i_prot.read_list_end().is_ok());
     }
 
     #[test]
@@ -789,14 +878,25 @@
     fn must_round_trip_set_begin() {
         let (mut i_prot, mut o_prot) = test_objects(true);
 
-        let ident = TSetIdentifier::new(TType::I64, 2000);
+        let ident = TSetIdentifier::new(TType::I64, 3);
         assert!(o_prot.write_set_begin(&ident).is_ok());
+        assert!(o_prot.write_i64(123).is_ok());
+        assert!(o_prot.write_i64(456).is_ok());
+        assert!(o_prot.write_i64(789).is_ok());
+
+        assert!(o_prot.write_set_end().is_ok());
 
         copy_write_buffer_to_read_buffer!(o_prot);
 
         let received_ident_result = i_prot.read_set_begin();
         assert!(received_ident_result.is_ok());
         assert_eq!(&received_ident_result.unwrap(), &ident);
+
+        assert_eq!(i_prot.read_i64().unwrap(), 123);
+        assert_eq!(i_prot.read_i64().unwrap(), 456);
+        assert_eq!(i_prot.read_i64().unwrap(), 789);
+
+        assert!(i_prot.read_set_end().is_ok());
     }
 
     #[test]
@@ -820,13 +920,26 @@
     fn must_round_trip_map_begin() {
         let (mut i_prot, mut o_prot) = test_objects(true);
 
-        let ident = TMapIdentifier::new(TType::Map, TType::Set, 100);
+        let ident = TMapIdentifier::new(TType::String, TType::I32, 2);
         assert!(o_prot.write_map_begin(&ident).is_ok());
+        assert!(o_prot.write_string("key1").is_ok());
+        assert!(o_prot.write_i32(100).is_ok());
+        assert!(o_prot.write_string("key2").is_ok());
+        assert!(o_prot.write_i32(200).is_ok());
+
+        assert!(o_prot.write_map_end().is_ok());
 
         copy_write_buffer_to_read_buffer!(o_prot);
 
         let received_ident = assert_success!(i_prot.read_map_begin());
         assert_eq!(&received_ident, &ident);
+
+        assert_eq!(i_prot.read_string().unwrap(), "key1");
+        assert_eq!(i_prot.read_i32().unwrap(), 100);
+        assert_eq!(i_prot.read_string().unwrap(), "key2");
+        assert_eq!(i_prot.read_i32().unwrap(), 200);
+
+        assert!(i_prot.read_map_end().is_ok());
     }
 
     #[test]
@@ -963,7 +1076,7 @@
         TBinaryInputProtocol<ReadHalf<TBufferChannel>>,
         TBinaryOutputProtocol<WriteHalf<TBufferChannel>>,
     ) {
-        let mem = TBufferChannel::with_capacity(40, 40);
+        let mem = TBufferChannel::with_capacity(200, 200);
 
         let (r_mem, w_mem) = mem.split().unwrap();
 
@@ -981,4 +1094,154 @@
         assert!(write_fn(&mut o_prot).is_ok());
         assert_eq!(o_prot.transport.write_bytes().len(), 0);
     }
+
+    #[test]
+    fn must_enforce_recursion_depth_limit() {
+        let mem = TBufferChannel::with_capacity(40, 40);
+        let (r_mem, _) = mem.split().unwrap();
+
+        let config = TConfiguration::builder()
+            .max_recursion_depth(Some(2))
+            .build()
+            .unwrap();
+        let mut i_prot = TBinaryInputProtocol::with_config(r_mem, true, config);
+
+        assert!(i_prot.read_struct_begin().is_ok());
+        assert_eq!(i_prot.recursion_depth, 1);
+
+        assert!(i_prot.read_struct_begin().is_ok());
+        assert_eq!(i_prot.recursion_depth, 2);
+
+        let result = i_prot.read_struct_begin();
+        assert!(result.is_err());
+        match result {
+            Err(crate::Error::Protocol(e)) => {
+                assert_eq!(e.kind, ProtocolErrorKind::DepthLimit);
+            }
+            _ => panic!("Expected protocol error with DepthLimit"),
+        }
+
+        assert!(i_prot.read_struct_end().is_ok());
+        assert_eq!(i_prot.recursion_depth, 1);
+        assert!(i_prot.read_struct_end().is_ok());
+        assert_eq!(i_prot.recursion_depth, 0);
+    }
+
+    #[test]
+    fn must_reject_negative_container_sizes() {
+        let mem = TBufferChannel::with_capacity(40, 40);
+        let (r_mem, mut w_mem) = mem.split().unwrap();
+
+        let mut i_prot = TBinaryInputProtocol::new(r_mem, true);
+
+        w_mem.set_readable_bytes(&[0x0F, 0xFF, 0xFF, 0xFF, 0xFF]);
+
+        let result = i_prot.read_list_begin();
+        assert!(result.is_err());
+        match result {
+            Err(crate::Error::Protocol(e)) => {
+                assert_eq!(e.kind, ProtocolErrorKind::NegativeSize);
+            }
+            _ => panic!("Expected protocol error with NegativeSize"),
+        }
+    }
+
+    #[test]
+    fn must_enforce_container_size_limit() {
+        let mem = TBufferChannel::with_capacity(40, 40);
+        let (r_mem, mut w_mem) = mem.split().unwrap();
+
+        let config = TConfiguration::builder()
+            .max_container_size(Some(100))
+            .build()
+            .unwrap();
+
+        let mut i_prot = TBinaryInputProtocol::with_config(r_mem, true, config);
+
+        w_mem.set_readable_bytes(&[0x0F, 0x00, 0x00, 0x00, 0xC8]);
+
+        let result = i_prot.read_list_begin();
+        assert!(result.is_err());
+        match result {
+            Err(crate::Error::Protocol(e)) => {
+                assert_eq!(e.kind, ProtocolErrorKind::SizeLimit);
+                assert!(e
+                    .message
+                    .contains("Container size 200 exceeds maximum allowed size of 100"));
+            }
+            _ => panic!("Expected protocol error with SizeLimit"),
+        }
+    }
+
+    #[test]
+    fn must_allow_containers_within_limit() {
+        let mem = TBufferChannel::with_capacity(200, 200);
+        let (r_mem, mut w_mem) = mem.split().unwrap();
+
+        // Create protocol with container limit of 100
+        let config = TConfiguration::builder()
+            .max_container_size(Some(100))
+            .build()
+            .unwrap();
+        let mut i_prot = TBinaryInputProtocol::with_config(r_mem, true, config);
+
+        let mut data = vec![0x08]; // TType::I32
+        data.extend_from_slice(&5i32.to_be_bytes()); // size = 5
+
+        for i in 1i32..=5i32 {
+            data.extend_from_slice(&(i * 10).to_be_bytes());
+        }
+
+        w_mem.set_readable_bytes(&data);
+
+        let result = i_prot.read_list_begin();
+        assert!(result.is_ok());
+        let list_ident = result.unwrap();
+        assert_eq!(list_ident.size, 5);
+        assert_eq!(list_ident.element_type, TType::I32);
+    }
+
+    #[test]
+    fn must_enforce_string_size_limit() {
+        let mem = TBufferChannel::with_capacity(100, 100);
+        let (r_mem, mut w_mem) = mem.split().unwrap();
+
+        let config = TConfiguration::builder()
+            .max_string_size(Some(1000))
+            .build()
+            .unwrap();
+        let mut i_prot = TBinaryInputProtocol::with_config(r_mem, true, config);
+
+        w_mem.set_readable_bytes(&[0x00, 0x00, 0x07, 0xD0]);
+
+        let result = i_prot.read_string();
+        assert!(result.is_err());
+        match result {
+            Err(crate::Error::Protocol(e)) => {
+                assert_eq!(e.kind, ProtocolErrorKind::SizeLimit);
+                assert!(e
+                    .message
+                    .contains("Byte array size 2000 exceeds maximum allowed size of 1000"));
+            }
+            _ => panic!("Expected protocol error with SizeLimit"),
+        }
+    }
+
+    #[test]
+    fn must_allow_strings_within_limit() {
+        let mem = TBufferChannel::with_capacity(100, 100);
+        let (r_mem, mut w_mem) = mem.split().unwrap();
+
+        let config = TConfiguration::builder()
+            .max_string_size(Some(1000))
+            .build()
+            .unwrap();
+        let mut i_prot = TBinaryInputProtocol::with_config(r_mem, true, config);
+
+        w_mem.set_readable_bytes(&[0x00, 0x00, 0x00, 0x05, b'h', b'e', b'l', b'l', b'o']);
+
+        let result = i_prot.read_string();
+        assert!(result.is_ok());
+        assert_eq!(result.unwrap(), "hello");
+    }
 }