THRIFT-5299: Encode sequence numbers as non-zigzag varint
Client: rs
diff --git a/lib/rs/src/protocol/compact.rs b/lib/rs/src/protocol/compact.rs
index 6fa364f..3687354 100644
--- a/lib/rs/src/protocol/compact.rs
+++ b/lib/rs/src/protocol/compact.rs
@@ -80,6 +80,24 @@
}
}
+ /// Reads a varint for an `i32` without attempting to zigzag-decode it.
+ fn read_non_zigzag_varint(&mut self) -> crate::Result<i32> {
+ let mut num = 0i32;
+
+ let mut buf = [0u8; 1];
+ let mut shift_bits = 0u32;
+ let mut should_continue = true;
+
+ while should_continue {
+ self.transport.read_exact(&mut buf)?;
+ num |= ((buf[0] & 0x7F) as i32) << shift_bits;
+ shift_bits += 7;
+ should_continue = (buf[0] & 0x80) > 0;
+ }
+
+ Ok(num)
+ }
+
fn read_list_set_begin(&mut self) -> crate::Result<(TType, i32)> {
let header = self.read_byte()?;
let element_type = collection_u8_to_type(header & 0x0F)?;
@@ -128,7 +146,7 @@
// NOTE: unsigned right shift will pad with 0s
let message_type: TMessageType = TMessageType::try_from(type_and_byte >> 5)?;
- let sequence_number = self.read_i32()?;
+ let sequence_number = self.read_non_zigzag_varint()?;
let service_call_name = self.read_string()?;
self.last_read_field_id = 0;
@@ -247,7 +265,9 @@
}
fn read_double(&mut self) -> crate::Result<f64> {
- self.transport.read_f64::<LittleEndian>().map_err(From::from)
+ self.transport
+ .read_f64::<LittleEndian>()
+ .map_err(From::from)
}
fn read_string(&mut self) -> crate::Result<String> {
@@ -375,6 +395,27 @@
}
}
+ /// Writes a varint for an `i32` without attempting to zigzag-encode it first.
+ fn write_non_zigzag_varint(&mut self, i: i32) -> crate::Result<()> {
+ let mut output = Vec::with_capacity(5);
+ let mut i = i as u32; // avoids sign extension on bitshift
+
+ loop {
+ if (i & !0x7F) == 0 {
+ output.push(i as u8);
+ break;
+ } else {
+ let varint_byte = ((i as u8) & 0x7F) | 0x80;
+ output.push(varint_byte);
+ i >>= 7;
+ }
+ }
+
+ self.transport
+ .write_all(output.as_slice())
+ .map_err(From::from)
+ }
+
// FIXME: field_type as unconstrained u8 is bad
fn write_field_header(&mut self, field_type: u8, field_id: i16) -> crate::Result<()> {
let field_delta = field_id - self.last_write_field_id;
@@ -388,7 +429,11 @@
Ok(())
}
- fn write_list_set_begin(&mut self, element_type: TType, element_count: i32) -> crate::Result<()> {
+ fn write_list_set_begin(
+ &mut self,
+ element_type: TType,
+ element_count: i32,
+ ) -> crate::Result<()> {
let elem_identifier = collection_type_to_u8(element_type);
if element_count <= 14 {
let header = (element_count as u8) << 4 | elem_identifier;
@@ -396,6 +441,10 @@
} else {
let header = 0xF0 | elem_identifier;
self.write_byte(header)?;
+ // size is strictly positive as per the spec, so:
+ // 1. we first cast to u32
+ // 2. write as varint
+ // which means that integer_encoding will **not** zigzag it first
self.transport
.write_varint(element_count as u32)
.map_err(From::from)
@@ -417,7 +466,7 @@
fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> crate::Result<()> {
self.write_byte(COMPACT_PROTOCOL_ID)?;
self.write_byte((u8::from(identifier.message_type) << 5) | COMPACT_VERSION)?;
- self.write_i32(identifier.sequence_number)?;
+ self.write_non_zigzag_varint(identifier.sequence_number)?;
self.write_string(&identifier.name)?;
Ok(())
}
@@ -491,6 +540,10 @@
}
fn write_bytes(&mut self, b: &[u8]) -> crate::Result<()> {
+ // size is strictly positive as per the spec, so:
+ // 1. we first cast to u32
+ // 2. write as varint
+ // which means that integer_encoding will **not** zigzag it first
self.transport.write_varint(b.len() as u32)?;
self.transport.write_all(b).map_err(From::from)
}
@@ -501,7 +554,7 @@
fn write_i16(&mut self, i: i16) -> crate::Result<()> {
self.transport
- .write_varint(i)
+ .write_varint(i as i32)
.map_err(From::from)
.map(|_| ())
}
@@ -521,7 +574,9 @@
}
fn write_double(&mut self, d: f64) -> crate::Result<()> {
- self.transport.write_f64::<LittleEndian>(d).map_err(From::from)
+ self.transport
+ .write_f64::<LittleEndian>(d)
+ .map_err(From::from)
}
fn write_string(&mut self, s: &str) -> crate::Result<()> {
@@ -548,6 +603,10 @@
if identifier.size == 0 {
self.write_byte(0)
} else {
+ // size is strictly positive as per the spec, so:
+ // 1. we first cast to u32
+ // 2. write as varint
+ // which means that integer_encoding will **not** zigzag it first
self.transport.write_varint(identifier.size as u32)?;
let key_type = identifier
@@ -593,7 +652,10 @@
}
impl TOutputProtocolFactory for TCompactOutputProtocolFactory {
- fn create(&self, transport: Box<dyn TWriteTransport + Send>) -> Box<dyn TOutputProtocol + Send> {
+ fn create(
+ &self,
+ transport: Box<dyn TWriteTransport + Send>,
+ ) -> Box<dyn TOutputProtocol + Send> {
Box::new(TCompactOutputProtocol::new(transport))
}
}
@@ -655,6 +717,8 @@
#[cfg(test)]
mod tests {
+ use std::i32;
+
use crate::protocol::{
TFieldIdentifier, TInputProtocol, TListIdentifier, TMapIdentifier, TMessageIdentifier,
TMessageType, TOutputProtocol, TSetIdentifier, TStructIdentifier, TType,
@@ -664,7 +728,62 @@
use super::*;
#[test]
- fn must_write_message_begin_0() {
+ fn must_write_message_begin_largest_maximum_positive_sequence_number() {
+ let (_, mut o_prot) = test_objects();
+
+ assert_success!(o_prot.write_message_begin(&TMessageIdentifier::new(
+ "bar",
+ TMessageType::Reply,
+ i32::MAX
+ )));
+
+ #[rustfmt::skip]
+ let expected: [u8; 11] = [
+ 0x82, /* protocol ID */
+ 0x41, /* message type | protocol version */
+ 0xFF,
+ 0xFF,
+ 0xFF,
+ 0xFF,
+ 0x07, /* non-zig-zag varint sequence number */
+ 0x03, /* message-name length */
+ 0x62,
+ 0x61,
+ 0x72 /* "bar" */,
+ ];
+
+ assert_eq_written_bytes!(o_prot, expected);
+ }
+
+ #[test]
+ fn must_read_message_begin_largest_maximum_positive_sequence_number() {
+ let (mut i_prot, _) = test_objects();
+
+ #[rustfmt::skip]
+ let source_bytes: [u8; 11] = [
+ 0x82, /* protocol ID */
+ 0x41, /* message type | protocol version */
+ 0xFF,
+ 0xFF,
+ 0xFF,
+ 0xFF,
+ 0x07, /* non-zig-zag varint sequence number */
+ 0x03, /* message-name length */
+ 0x62,
+ 0x61,
+ 0x72 /* "bar" */,
+ ];
+
+ i_prot.transport.set_readable_bytes(&source_bytes);
+
+ let expected = TMessageIdentifier::new("bar", TMessageType::Reply, i32::MAX);
+ let res = assert_success!(i_prot.read_message_begin());
+
+ assert_eq!(&expected, &res);
+ }
+
+ #[test]
+ fn must_write_message_begin_positive_sequence_number_0() {
let (_, mut o_prot) = test_objects();
assert_success!(o_prot.write_message_begin(&TMessageIdentifier::new(
@@ -677,8 +796,8 @@
let expected: [u8; 8] = [
0x82, /* protocol ID */
0x21, /* message type | protocol version */
- 0xDE,
- 0x06, /* zig-zag varint sequence number */
+ 0xAF,
+ 0x03, /* non-zig-zag varint sequence number */
0x03, /* message-name length */
0x66,
0x6F,
@@ -689,7 +808,31 @@
}
#[test]
- fn must_write_message_begin_1() {
+ fn must_read_message_begin_positive_sequence_number_0() {
+ let (mut i_prot, _) = test_objects();
+
+ #[rustfmt::skip]
+ let source_bytes: [u8; 8] = [
+ 0x82, /* protocol ID */
+ 0x21, /* message type | protocol version */
+ 0xAF,
+ 0x03, /* non-zig-zag varint sequence number */
+ 0x03, /* message-name length */
+ 0x66,
+ 0x6F,
+ 0x6F /* "foo" */,
+ ];
+
+ i_prot.transport.set_readable_bytes(&source_bytes);
+
+ let expected = TMessageIdentifier::new("foo", TMessageType::Call, 431);
+ let res = assert_success!(i_prot.read_message_begin());
+
+ assert_eq!(&expected, &res);
+ }
+
+ #[test]
+ fn must_write_message_begin_positive_sequence_number_1() {
let (_, mut o_prot) = test_objects();
assert_success!(o_prot.write_message_begin(&TMessageIdentifier::new(
@@ -702,9 +845,9 @@
let expected: [u8; 9] = [
0x82, /* protocol ID */
0x41, /* message type | protocol version */
- 0xA8,
- 0x89,
- 0x79, /* zig-zag varint sequence number */
+ 0xD4,
+ 0xC4,
+ 0x3C, /* non-zig-zag varint sequence number */
0x03, /* message-name length */
0x62,
0x61,
@@ -715,6 +858,305 @@
}
#[test]
+ fn must_read_message_begin_positive_sequence_number_1() {
+ let (mut i_prot, _) = test_objects();
+
+ #[rustfmt::skip]
+ let source_bytes: [u8; 9] = [
+ 0x82, /* protocol ID */
+ 0x41, /* message type | protocol version */
+ 0xD4,
+ 0xC4,
+ 0x3C, /* non-zig-zag varint sequence number */
+ 0x03, /* message-name length */
+ 0x62,
+ 0x61,
+ 0x72 /* "bar" */,
+ ];
+
+ i_prot.transport.set_readable_bytes(&source_bytes);
+
+ let expected = TMessageIdentifier::new("bar", TMessageType::Reply, 991_828);
+ let res = assert_success!(i_prot.read_message_begin());
+
+ assert_eq!(&expected, &res);
+ }
+
+ #[test]
+ fn must_write_message_begin_zero_sequence_number() {
+ let (_, mut o_prot) = test_objects();
+
+ assert_success!(o_prot.write_message_begin(&TMessageIdentifier::new(
+ "bar",
+ TMessageType::Reply,
+ 0
+ )));
+
+ #[rustfmt::skip]
+ let expected: [u8; 7] = [
+ 0x82, /* protocol ID */
+ 0x41, /* message type | protocol version */
+ 0x00, /* non-zig-zag varint sequence number */
+ 0x03, /* message-name length */
+ 0x62,
+ 0x61,
+ 0x72 /* "bar" */,
+ ];
+
+ assert_eq_written_bytes!(o_prot, expected);
+ }
+
+ #[test]
+ fn must_read_message_begin_zero_sequence_number() {
+ let (mut i_prot, _) = test_objects();
+
+ #[rustfmt::skip]
+ let source_bytes: [u8; 7] = [
+ 0x82, /* protocol ID */
+ 0x41, /* message type | protocol version */
+ 0x00, /* non-zig-zag varint sequence number */
+ 0x03, /* message-name length */
+ 0x62,
+ 0x61,
+ 0x72 /* "bar" */,
+ ];
+
+ i_prot.transport.set_readable_bytes(&source_bytes);
+
+ let expected = TMessageIdentifier::new("bar", TMessageType::Reply, 0);
+ let res = assert_success!(i_prot.read_message_begin());
+
+ assert_eq!(&expected, &res);
+ }
+
+ #[test]
+ fn must_write_message_begin_largest_minimum_negative_sequence_number() {
+ let (_, mut o_prot) = test_objects();
+
+ assert_success!(o_prot.write_message_begin(&TMessageIdentifier::new(
+ "bar",
+ TMessageType::Reply,
+ i32::MIN
+ )));
+
+ // two's complement notation of i32::MIN = 1000_0000_0000_0000_0000_0000_0000_0000
+ #[rustfmt::skip]
+ let expected: [u8; 11] = [
+ 0x82, /* protocol ID */
+ 0x41, /* message type | protocol version */
+ 0x80,
+ 0x80,
+ 0x80,
+ 0x80,
+ 0x08, /* non-zig-zag varint sequence number */
+ 0x03, /* message-name length */
+ 0x62,
+ 0x61,
+ 0x72 /* "bar" */,
+ ];
+
+ assert_eq_written_bytes!(o_prot, expected);
+ }
+
+ #[test]
+ fn must_read_message_begin_largest_minimum_negative_sequence_number() {
+ let (mut i_prot, _) = test_objects();
+
+ // two's complement notation of i32::MIN = 1000_0000_0000_0000_0000_0000_0000_0000
+ #[rustfmt::skip]
+ let source_bytes: [u8; 11] = [
+ 0x82, /* protocol ID */
+ 0x41, /* message type | protocol version */
+ 0x80,
+ 0x80,
+ 0x80,
+ 0x80,
+ 0x08, /* non-zig-zag varint sequence number */
+ 0x03, /* message-name length */
+ 0x62,
+ 0x61,
+ 0x72 /* "bar" */,
+ ];
+
+ i_prot.transport.set_readable_bytes(&source_bytes);
+
+ let expected = TMessageIdentifier::new("bar", TMessageType::Reply, i32::MIN);
+ let res = assert_success!(i_prot.read_message_begin());
+
+ assert_eq!(&expected, &res);
+ }
+
+ #[test]
+ fn must_write_message_begin_negative_sequence_number_0() {
+ let (_, mut o_prot) = test_objects();
+
+ assert_success!(o_prot.write_message_begin(&TMessageIdentifier::new(
+ "foo",
+ TMessageType::Call,
+ -431
+ )));
+
+ // signed two's complement of -431 = 1111_1111_1111_1111_1111_1110_0101_0001
+ #[rustfmt::skip]
+ let expected: [u8; 11] = [
+ 0x82, /* protocol ID */
+ 0x21, /* message type | protocol version */
+ 0xD1,
+ 0xFC,
+ 0xFF,
+ 0xFF,
+ 0x0F, /* non-zig-zag varint sequence number */
+ 0x03, /* message-name length */
+ 0x66,
+ 0x6F,
+ 0x6F /* "foo" */,
+ ];
+
+ assert_eq_written_bytes!(o_prot, expected);
+ }
+
+ #[test]
+ fn must_read_message_begin_negative_sequence_number_0() {
+ let (mut i_prot, _) = test_objects();
+
+ // signed two's complement of -431 = 1111_1111_1111_1111_1111_1110_0101_0001
+ #[rustfmt::skip]
+ let source_bytes: [u8; 11] = [
+ 0x82, /* protocol ID */
+ 0x21, /* message type | protocol version */
+ 0xD1,
+ 0xFC,
+ 0xFF,
+ 0xFF,
+ 0x0F, /* non-zig-zag varint sequence number */
+ 0x03, /* message-name length */
+ 0x66,
+ 0x6F,
+ 0x6F /* "foo" */,
+ ];
+
+ i_prot.transport.set_readable_bytes(&source_bytes);
+
+ let expected = TMessageIdentifier::new("foo", TMessageType::Call, -431);
+ let res = assert_success!(i_prot.read_message_begin());
+
+ assert_eq!(&expected, &res);
+ }
+
+ #[test]
+ fn must_write_message_begin_negative_sequence_number_1() {
+ let (_, mut o_prot) = test_objects();
+
+ assert_success!(o_prot.write_message_begin(&TMessageIdentifier::new(
+ "foo",
+ TMessageType::Call,
+ -73_184_125
+ )));
+
+ // signed two's complement of -73184125 = 1111_1011_1010_0011_0100_1100_1000_0011
+ #[rustfmt::skip]
+ let expected: [u8; 11] = [
+ 0x82, /* protocol ID */
+ 0x21, /* message type | protocol version */
+ 0x83,
+ 0x99,
+ 0x8D,
+ 0xDD,
+ 0x0F, /* non-zig-zag varint sequence number */
+ 0x03, /* message-name length */
+ 0x66,
+ 0x6F,
+ 0x6F /* "foo" */,
+ ];
+
+ assert_eq_written_bytes!(o_prot, expected);
+ }
+
+ #[test]
+ fn must_read_message_begin_negative_sequence_number_1() {
+ let (mut i_prot, _) = test_objects();
+
+ // signed two's complement of -73184125 = 1111_1011_1010_0011_0100_1100_1000_0011
+ #[rustfmt::skip]
+ let source_bytes: [u8; 11] = [
+ 0x82, /* protocol ID */
+ 0x21, /* message type | protocol version */
+ 0x83,
+ 0x99,
+ 0x8D,
+ 0xDD,
+ 0x0F, /* non-zig-zag varint sequence number */
+ 0x03, /* message-name length */
+ 0x66,
+ 0x6F,
+ 0x6F /* "foo" */,
+ ];
+
+ i_prot.transport.set_readable_bytes(&source_bytes);
+
+ let expected = TMessageIdentifier::new("foo", TMessageType::Call, -73_184_125);
+ let res = assert_success!(i_prot.read_message_begin());
+
+ assert_eq!(&expected, &res);
+ }
+
+ #[test]
+ fn must_write_message_begin_negative_sequence_number_2() {
+ let (_, mut o_prot) = test_objects();
+
+ assert_success!(o_prot.write_message_begin(&TMessageIdentifier::new(
+ "foo",
+ TMessageType::Call,
+ -1_073_741_823
+ )));
+
+ // signed two's complement of -1073741823 = 1100_0000_0000_0000_0000_0000_0000_0001
+ #[rustfmt::skip]
+ let expected: [u8; 11] = [
+ 0x82, /* protocol ID */
+ 0x21, /* message type | protocol version */
+ 0x81,
+ 0x80,
+ 0x80,
+ 0x80,
+ 0x0C, /* non-zig-zag varint sequence number */
+ 0x03, /* message-name length */
+ 0x66,
+ 0x6F,
+ 0x6F /* "foo" */,
+ ];
+
+ assert_eq_written_bytes!(o_prot, expected);
+ }
+
+ #[test]
+ fn must_read_message_begin_negative_sequence_number_2() {
+ let (mut i_prot, _) = test_objects();
+
+ // signed two's complement of -1073741823 = 1100_0000_0000_0000_0000_0000_0000_0001
+ let source_bytes: [u8; 11] = [
+ 0x82, /* protocol ID */
+ 0x21, /* message type | protocol version */
+ 0x81,
+ 0x80,
+ 0x80,
+ 0x80,
+ 0x0C, /* non-zig-zag varint sequence number */
+ 0x03, /* message-name length */
+ 0x66,
+ 0x6F,
+ 0x6F /* "foo" */,
+ ];
+
+ i_prot.transport.set_readable_bytes(&source_bytes);
+
+ let expected = TMessageIdentifier::new("foo", TMessageType::Call, -1_073_741_823);
+ let res = assert_success!(i_prot.read_message_begin());
+
+ assert_eq!(&expected, &res);
+ }
+
+ #[test]
fn must_round_trip_upto_i64_maxvalue() {
// See https://issues.apache.org/jira/browse/THRIFT-5131
for i in 0..64 {
@@ -722,11 +1164,7 @@
let val: i64 = ((1u64 << i) - 1) as i64;
o_prot
- .write_field_begin(&TFieldIdentifier::new(
- "val",
- TType::I64,
- 1
- ))
+ .write_field_begin(&TFieldIdentifier::new("val", TType::I64, 1))
.unwrap();
o_prot.write_i64(val).unwrap();
o_prot.write_field_end().unwrap();