THRIFT-4658: TBinaryInputProtocol fails when strict is false
Client: rs
diff --git a/lib/rs/src/protocol/binary.rs b/lib/rs/src/protocol/binary.rs
index 8505b63..42c6c97 100644
--- a/lib/rs/src/protocol/binary.rs
+++ b/lib/rs/src/protocol/binary.rs
@@ -123,7 +123,7 @@
// is the message name. strings (byte arrays) are length-prefixed,
// so we've just read the length in the first 4 bytes
let name_size = BigEndian::read_i32(&first_bytes) as usize;
- let mut name_buf: Vec<u8> = Vec::with_capacity(name_size);
+ let mut name_buf: Vec<u8> = vec![0; name_size];
self.transport.read_exact(&mut name_buf)?;
let name = String::from_utf8(name_buf)?;
@@ -544,8 +544,8 @@
use super::*;
#[test]
- fn must_write_message_call_begin() {
- let (_, mut o_prot) = test_objects();
+ fn must_write_strict_message_call_begin() {
+ let (_, mut o_prot) = test_objects(true);
let ident = TMessageIdentifier::new("test", TMessageType::Call, 1);
assert!(o_prot.write_message_begin(&ident).is_ok());
@@ -573,8 +573,34 @@
}
#[test]
- fn must_write_message_reply_begin() {
- let (_, mut o_prot) = test_objects();
+ fn must_write_non_strict_message_call_begin() {
+ let (_, mut o_prot) = test_objects(false);
+
+ let ident = TMessageIdentifier::new("test", TMessageType::Call, 1);
+ assert!(o_prot.write_message_begin(&ident).is_ok());
+
+ let expected: [u8; 13] = [
+ 0x00,
+ 0x00,
+ 0x00,
+ 0x04,
+ 0x74,
+ 0x65,
+ 0x73,
+ 0x74,
+ 0x01,
+ 0x00,
+ 0x00,
+ 0x00,
+ 0x01,
+ ];
+
+ assert_eq_written_bytes!(o_prot, expected);
+ }
+
+ #[test]
+ fn must_write_strict_message_reply_begin() {
+ let (_, mut o_prot) = test_objects(true);
let ident = TMessageIdentifier::new("test", TMessageType::Reply, 10);
assert!(o_prot.write_message_begin(&ident).is_ok());
@@ -602,8 +628,47 @@
}
#[test]
+ fn must_write_non_strict_message_reply_begin() {
+ let (_, mut o_prot) = test_objects(false);
+
+ let ident = TMessageIdentifier::new("test", TMessageType::Reply, 10);
+ assert!(o_prot.write_message_begin(&ident).is_ok());
+
+ let expected: [u8; 13] = [
+ 0x00,
+ 0x00,
+ 0x00,
+ 0x04,
+ 0x74,
+ 0x65,
+ 0x73,
+ 0x74,
+ 0x02,
+ 0x00,
+ 0x00,
+ 0x00,
+ 0x0A,
+ ];
+
+ assert_eq_written_bytes!(o_prot, expected);
+ }
+
+ #[test]
fn must_round_trip_strict_message_begin() {
- let (mut i_prot, mut o_prot) = test_objects();
+ let (mut i_prot, mut o_prot) = test_objects(true);
+
+ let sent_ident = TMessageIdentifier::new("test", TMessageType::Call, 1);
+ assert!(o_prot.write_message_begin(&sent_ident).is_ok());
+
+ copy_write_buffer_to_read_buffer!(o_prot);
+
+ let received_ident = assert_success!(i_prot.read_message_begin());
+ assert_eq!(&received_ident, &sent_ident);
+ }
+
+ #[test]
+ fn must_round_trip_non_strict_message_begin() {
+ let (mut i_prot, mut o_prot) = test_objects(false);
let sent_ident = TMessageIdentifier::new("test", TMessageType::Call, 1);
assert!(o_prot.write_message_begin(&sent_ident).is_ok());
@@ -616,22 +681,22 @@
#[test]
fn must_write_message_end() {
- assert_no_write(|o| o.write_message_end());
+ assert_no_write(|o| o.write_message_end(), true);
}
#[test]
fn must_write_struct_begin() {
- assert_no_write(|o| o.write_struct_begin(&TStructIdentifier::new("foo")));
+ assert_no_write(|o| o.write_struct_begin(&TStructIdentifier::new("foo")), true);
}
#[test]
fn must_write_struct_end() {
- assert_no_write(|o| o.write_struct_end());
+ assert_no_write(|o| o.write_struct_end(), true);
}
#[test]
fn must_write_field_begin() {
- let (_, mut o_prot) = test_objects();
+ let (_, mut o_prot) = test_objects(true);
assert!(
o_prot
@@ -645,7 +710,7 @@
#[test]
fn must_round_trip_field_begin() {
- let (mut i_prot, mut o_prot) = test_objects();
+ let (mut i_prot, mut o_prot) = test_objects(true);
let sent_field_ident = TFieldIdentifier::new("foo", TType::I64, 20);
assert!(o_prot.write_field_begin(&sent_field_ident).is_ok());
@@ -663,7 +728,7 @@
#[test]
fn must_write_stop_field() {
- let (_, mut o_prot) = test_objects();
+ let (_, mut o_prot) = test_objects(true);
assert!(o_prot.write_field_stop().is_ok());
@@ -673,7 +738,7 @@
#[test]
fn must_round_trip_field_stop() {
- let (mut i_prot, mut o_prot) = test_objects();
+ let (mut i_prot, mut o_prot) = test_objects(true);
assert!(o_prot.write_field_stop().is_ok());
@@ -691,12 +756,12 @@
#[test]
fn must_write_field_end() {
- assert_no_write(|o| o.write_field_end());
+ assert_no_write(|o| o.write_field_end(), true);
}
#[test]
fn must_write_list_begin() {
- let (_, mut o_prot) = test_objects();
+ let (_, mut o_prot) = test_objects(true);
assert!(
o_prot
@@ -710,7 +775,7 @@
#[test]
fn must_round_trip_list_begin() {
- let (mut i_prot, mut o_prot) = test_objects();
+ let (mut i_prot, mut o_prot) = test_objects(true);
let ident = TListIdentifier::new(TType::List, 900);
assert!(o_prot.write_list_begin(&ident).is_ok());
@@ -723,12 +788,12 @@
#[test]
fn must_write_list_end() {
- assert_no_write(|o| o.write_list_end());
+ assert_no_write(|o| o.write_list_end(), true);
}
#[test]
fn must_write_set_begin() {
- let (_, mut o_prot) = test_objects();
+ let (_, mut o_prot) = test_objects(true);
assert!(
o_prot
@@ -742,7 +807,7 @@
#[test]
fn must_round_trip_set_begin() {
- let (mut i_prot, mut o_prot) = test_objects();
+ let (mut i_prot, mut o_prot) = test_objects(true);
let ident = TSetIdentifier::new(TType::I64, 2000);
assert!(o_prot.write_set_begin(&ident).is_ok());
@@ -756,12 +821,12 @@
#[test]
fn must_write_set_end() {
- assert_no_write(|o| o.write_set_end());
+ assert_no_write(|o| o.write_set_end(), true);
}
#[test]
fn must_write_map_begin() {
- let (_, mut o_prot) = test_objects();
+ let (_, mut o_prot) = test_objects(true);
assert!(
o_prot
@@ -775,7 +840,7 @@
#[test]
fn must_round_trip_map_begin() {
- let (mut i_prot, mut o_prot) = test_objects();
+ let (mut i_prot, mut o_prot) = test_objects(true);
let ident = TMapIdentifier::new(TType::Map, TType::Set, 100);
assert!(o_prot.write_map_begin(&ident).is_ok());
@@ -788,12 +853,12 @@
#[test]
fn must_write_map_end() {
- assert_no_write(|o| o.write_map_end());
+ assert_no_write(|o| o.write_map_end(), true);
}
#[test]
fn must_write_bool_true() {
- let (_, mut o_prot) = test_objects();
+ let (_, mut o_prot) = test_objects(true);
assert!(o_prot.write_bool(true).is_ok());
@@ -803,7 +868,7 @@
#[test]
fn must_write_bool_false() {
- let (_, mut o_prot) = test_objects();
+ let (_, mut o_prot) = test_objects(true);
assert!(o_prot.write_bool(false).is_ok());
@@ -813,7 +878,7 @@
#[test]
fn must_read_bool_true() {
- let (mut i_prot, _) = test_objects();
+ let (mut i_prot, _) = test_objects(true);
set_readable_bytes!(i_prot, &[0x01]);
@@ -823,7 +888,7 @@
#[test]
fn must_read_bool_false() {
- let (mut i_prot, _) = test_objects();
+ let (mut i_prot, _) = test_objects(true);
set_readable_bytes!(i_prot, &[0x00]);
@@ -833,7 +898,7 @@
#[test]
fn must_allow_any_non_zero_value_to_be_interpreted_as_bool_true() {
- let (mut i_prot, _) = test_objects();
+ let (mut i_prot, _) = test_objects(true);
set_readable_bytes!(i_prot, &[0xAC]);
@@ -843,7 +908,7 @@
#[test]
fn must_write_bytes() {
- let (_, mut o_prot) = test_objects();
+ let (_, mut o_prot) = test_objects(true);
let bytes: [u8; 10] = [0x0A, 0xCC, 0xD1, 0x84, 0x99, 0x12, 0xAB, 0xBB, 0x45, 0xDF];
@@ -856,7 +921,7 @@
#[test]
fn must_round_trip_bytes() {
- let (mut i_prot, mut o_prot) = test_objects();
+ let (mut i_prot, mut o_prot) = test_objects(true);
let bytes: [u8; 25] = [
0x20,
@@ -894,7 +959,7 @@
assert_eq!(&received_bytes, &bytes);
}
- fn test_objects()
+ fn test_objects(strict: bool)
-> (TBinaryInputProtocol<ReadHalf<TBufferChannel>>,
TBinaryOutputProtocol<WriteHalf<TBufferChannel>>)
{
@@ -902,17 +967,17 @@
let (r_mem, w_mem) = mem.split().unwrap();
- let i_prot = TBinaryInputProtocol::new(r_mem, true);
- let o_prot = TBinaryOutputProtocol::new(w_mem, true);
+ let i_prot = TBinaryInputProtocol::new(r_mem, strict);
+ let o_prot = TBinaryOutputProtocol::new(w_mem, strict);
(i_prot, o_prot)
}
- fn assert_no_write<F>(mut write_fn: F)
+ fn assert_no_write<F>(mut write_fn: F, strict: bool)
where
F: FnMut(&mut TBinaryOutputProtocol<WriteHalf<TBufferChannel>>) -> ::Result<()>,
{
- let (_, mut o_prot) = test_objects();
+ let (_, mut o_prot) = test_objects(strict);
assert!(write_fn(&mut o_prot).is_ok());
assert_eq!(o_prot.transport.write_bytes().len(), 0);
}