[THRIFT-5871] Add message / container size checking for Rust
This brings the Rust implementation closer to parity with the other language implementations.
I tried four or five different approaches to a "perfect" check, but because trait specialization is not yet stable,
none of them was satisfactory: each was either ugly or added runtime overhead.
So for now this change does not track the full message size or do more precise limit checking. Instead, container
sizes are checked against the configured container and message size limits, which is a solid step in the right direction.
diff --git a/lib/rs/src/protocol/mod.rs b/lib/rs/src/protocol/mod.rs
index 8bbbb3c..573978a 100644
--- a/lib/rs/src/protocol/mod.rs
+++ b/lib/rs/src/protocol/mod.rs
@@ -62,7 +62,7 @@
use std::fmt::{Display, Formatter};
use crate::transport::{TReadTransport, TWriteTransport};
-use crate::{ProtocolError, ProtocolErrorKind};
+use crate::{ProtocolError, ProtocolErrorKind, TConfiguration};
#[cfg(test)]
macro_rules! assert_eq_written_bytes {
@@ -262,6 +262,15 @@
///
/// This method should **never** be used in generated code.
fn read_byte(&mut self) -> crate::Result<u8>;
+
+    /// Returns the fewest bytes a value of `field_type` can occupy on the wire.
+    ///
+    /// The default is the minimum over all protocols, which currently is the compact
+    /// protocol's encoding. Intended as a lower bound for pre-allocation size checks;
+    /// actual values (strings, lists, etc.) may serialize to more bytes.
+    fn min_serialized_size(&self, field_type: TType) -> usize {
+        self::compact::compact_protocol_min_serialized_size(field_type)
+    }
}
/// Converts Thrift identifiers, primitives, containers or structs into a
@@ -444,6 +453,10 @@
fn read_byte(&mut self) -> crate::Result<u8> {
(**self).read_byte()
}
+
+    fn min_serialized_size(&self, field_type: TType) -> usize {
+        TInputProtocol::min_serialized_size(&**self, field_type)
+    }
}
impl<P> TOutputProtocol for Box<P>
@@ -565,7 +578,7 @@
/// let protocol = factory.create(Box::new(channel));
/// ```
pub trait TInputProtocolFactory {
- // Create a `TInputProtocol` that reads bytes from `transport`.
+    /// Creates a boxed [`TInputProtocol`] that reads its bytes from `transport`.
fn create(&self, transport: Box<dyn TReadTransport + Send>) -> Box<dyn TInputProtocol + Send>;
}
@@ -920,6 +933,69 @@
}
}
+/// Validates a container length read from the wire before anything is allocated for it.
+///
+/// `element_size` should be a lower bound on one element's serialized size, which makes
+/// `container_size * element_size` the fewest bytes the container can occupy on the wire.
+///
+/// Fails with `ProtocolErrorKind::NegativeSize` if `container_size` is negative, and with
+/// `ProtocolErrorKind::SizeLimit` if it exceeds the configured max container size, if the
+/// product overflows `usize`, or if the product exceeds the configured max message size.
+pub(crate) fn check_container_size(
+    config: &TConfiguration,
+    container_size: i32,
+    element_size: usize,
+) -> crate::Result<()> {
+    // Lengths arrive as `i32`; a negative one is never a valid container size.
+    if container_size < 0 {
+        return Err(crate::Error::Protocol(ProtocolError::new(
+            ProtocolErrorKind::NegativeSize,
+            format!("Negative container size: {}", container_size),
+        )));
+    }
+
+    // Non-negative here, so the cast cannot sign-extend into a huge value.
+    let size_as_usize = container_size as usize;
+
+    // Limits are optional; an unset limit skips the corresponding check.
+    if let Some(max_size) = config.max_container_size() {
+        if size_as_usize > max_size {
+            return Err(crate::Error::Protocol(ProtocolError::new(
+                ProtocolErrorKind::SizeLimit,
+                format!(
+                    "Container size {} exceeds maximum allowed size of {}",
+                    container_size, max_size
+                ),
+            )));
+        }
+    }
+
+    let min_bytes_needed = size_as_usize.checked_mul(element_size).ok_or_else(|| {
+        crate::Error::Protocol(ProtocolError::new(
+            ProtocolErrorKind::SizeLimit,
+            format!(
+                "Container size {} with element size {} bytes would result in overflow",
+                container_size, element_size
+            ),
+        ))
+    })?;
+
+    // TODO: When Rust trait specialization stabilizes, check against the exact remaining
+    // bytes for transports that track them. For now the message size limit is best-effort.
+    if let Some(max_message_size) = config.max_message_size() {
+        if min_bytes_needed > max_message_size {
+            return Err(crate::Error::Protocol(ProtocolError::new(
+                ProtocolErrorKind::SizeLimit,
+                format!(
+                    "Container would require {} bytes, exceeding message size limit of {}",
+                    min_bytes_needed, max_message_size
+                ),
+            )));
+        }
+    }
+    Ok(())
+}
+
/// Extract the field id from a Thrift field identifier.
///
/// `field_ident` must *not* have `TFieldIdentifier.field_type` of type `TType::Stop`.