Implemented header protocol for Ruby client library
diff --git a/lib/rb/spec/header_protocol_spec.rb b/lib/rb/spec/header_protocol_spec.rb
new file mode 100644
index 0000000..3feb9b6
--- /dev/null
+++ b/lib/rb/spec/header_protocol_spec.rb
@@ -0,0 +1,466 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+require 'spec_helper'
+require_relative 'support/header_protocol_helper'
+
+describe 'HeaderProtocol' do
+ include HeaderProtocolHelper
+
+ describe Thrift::HeaderProtocol do
+ before(:each) do
+ @buffer = Thrift::MemoryBufferTransport.new
+ @protocol = Thrift::HeaderProtocol.new(@buffer)
+ end
+
+ it "should provide a to_s" do
+ expect(@protocol.to_s).to match(/header\(compact/)
+ end
+
+ it "should wrap transport in HeaderTransport" do
+ expect(@protocol.trans).to be_a(Thrift::HeaderTransport)
+ end
+
+ it "should use existing HeaderTransport if passed" do
+ header_trans = Thrift::HeaderTransport.new(@buffer)
+ protocol = Thrift::HeaderProtocol.new(header_trans)
+ expect(protocol.trans).to equal(header_trans)
+ end
+
+    describe "header management delegation" do
+      it "should delegate get_headers" do
+        # get_headers reports headers parsed from incoming frames, so first write a frame carrying one
+        @protocol.set_header("key", "value")
+        @protocol.write_message_begin("test", Thrift::MessageTypes::CALL, 1)
+        @protocol.write_struct_begin("args")
+        @protocol.write_field_stop
+        @protocol.write_struct_end
+        @protocol.write_message_end
+        @protocol.trans.flush
+
+        # Parse the flushed bytes with a fresh protocol; read_message_begin consumes the frame and its headers
+        data = @buffer.read(@buffer.available)
+        read_buffer = Thrift::MemoryBufferTransport.new(data)
+        read_protocol = Thrift::HeaderProtocol.new(read_buffer)
+
+        read_protocol.read_message_begin
+        headers = read_protocol.get_headers
+        expect(headers["key"]).to eq("value")
+      end
+
+      it "should delegate set_header" do
+        expect(@protocol.trans).to receive(:set_header).with("key", "value") # forwarded with the same arguments
+        @protocol.set_header("key", "value")
+      end
+
+      it "should delegate clear_headers" do
+        expect(@protocol.trans).to receive(:clear_headers)
+        @protocol.clear_headers
+      end
+
+      it "should delegate add_transform" do
+        expect(@protocol.trans).to receive(:add_transform).with(Thrift::HeaderTransformID::ZLIB)
+        @protocol.add_transform(Thrift::HeaderTransformID::ZLIB)
+      end
+    end
+
+    describe "protocol delegation with Binary protocol" do
+      before(:each) do
+        @buffer = Thrift::MemoryBufferTransport.new
+        @protocol = Thrift::HeaderProtocol.new(@buffer, nil, Thrift::HeaderSubprotocolID::BINARY) # writer pinned to binary
+      end
+
+      it "should write message begin" do
+        @protocol.write_message_begin("test_method", Thrift::MessageTypes::CALL, 123)
+        @protocol.write_message_end
+        @protocol.trans.flush
+
+        # Read back with a default-constructed reader; it must take the binary subprotocol from the frame
+        data = @buffer.read(@buffer.available)
+        read_buffer = Thrift::MemoryBufferTransport.new(data)
+        read_protocol = Thrift::HeaderProtocol.new(read_buffer)
+
+        name, type, seqid = read_protocol.read_message_begin
+        expect(name).to eq("test_method")
+        expect(type).to eq(Thrift::MessageTypes::CALL)
+        expect(seqid).to eq(123)
+      end
+
+      it "should write and read structs" do
+        @protocol.write_message_begin("test", Thrift::MessageTypes::CALL, 1)
+        @protocol.write_struct_begin("TestStruct")
+        @protocol.write_field_begin("field1", Thrift::Types::I32, 1)
+        @protocol.write_i32(42)
+        @protocol.write_field_end
+        @protocol.write_field_stop
+        @protocol.write_struct_end
+        @protocol.write_message_end
+        @protocol.trans.flush # the frame only reaches @buffer on flush
+
+        data = @buffer.read(@buffer.available)
+        read_buffer = Thrift::MemoryBufferTransport.new(data)
+        read_protocol = Thrift::HeaderProtocol.new(read_buffer)
+
+        read_protocol.read_message_begin
+        read_protocol.read_struct_begin
+        name, type, id = read_protocol.read_field_begin # name is not asserted
+        expect(type).to eq(Thrift::Types::I32)
+        expect(id).to eq(1)
+        value = read_protocol.read_i32
+        expect(value).to eq(42)
+      end
+
+      it "should write and read all primitive types" do
+        @protocol.write_message_begin("test", Thrift::MessageTypes::CALL, 1)
+        @protocol.write_struct_begin("TestStruct")
+        # One field per primitive type; the integers use each signed type's maximum value
+        @protocol.write_field_begin("bool", Thrift::Types::BOOL, 1)
+        @protocol.write_bool(true)
+        @protocol.write_field_end
+
+        @protocol.write_field_begin("byte", Thrift::Types::BYTE, 2)
+        @protocol.write_byte(127)
+        @protocol.write_field_end
+
+        @protocol.write_field_begin("i16", Thrift::Types::I16, 3)
+        @protocol.write_i16(32767)
+        @protocol.write_field_end
+
+        @protocol.write_field_begin("i32", Thrift::Types::I32, 4)
+        @protocol.write_i32(2147483647)
+        @protocol.write_field_end
+
+        @protocol.write_field_begin("i64", Thrift::Types::I64, 5)
+        @protocol.write_i64(9223372036854775807)
+        @protocol.write_field_end
+
+        @protocol.write_field_begin("double", Thrift::Types::DOUBLE, 6)
+        @protocol.write_double(3.14159)
+        @protocol.write_field_end
+
+        @protocol.write_field_begin("string", Thrift::Types::STRING, 7)
+        @protocol.write_string("hello")
+        @protocol.write_field_end
+
+        @protocol.write_field_stop
+        @protocol.write_struct_end
+        @protocol.write_message_end
+        @protocol.trans.flush
+
+        data = @buffer.read(@buffer.available)
+        read_buffer = Thrift::MemoryBufferTransport.new(data)
+        read_protocol = Thrift::HeaderProtocol.new(read_buffer)
+
+        read_protocol.read_message_begin
+        read_protocol.read_struct_begin
+        # Fields come back in write order; each is checked for its type tag and then its value
+        # bool
+        _, type, _ = read_protocol.read_field_begin
+        expect(type).to eq(Thrift::Types::BOOL)
+        expect(read_protocol.read_bool).to eq(true)
+        read_protocol.read_field_end
+
+        # byte
+        _, type, _ = read_protocol.read_field_begin
+        expect(type).to eq(Thrift::Types::BYTE)
+        expect(read_protocol.read_byte).to eq(127)
+        read_protocol.read_field_end
+
+        # i16
+        _, type, _ = read_protocol.read_field_begin
+        expect(type).to eq(Thrift::Types::I16)
+        expect(read_protocol.read_i16).to eq(32767)
+        read_protocol.read_field_end
+
+        # i32
+        _, type, _ = read_protocol.read_field_begin
+        expect(type).to eq(Thrift::Types::I32)
+        expect(read_protocol.read_i32).to eq(2147483647)
+        read_protocol.read_field_end
+
+        # i64
+        _, type, _ = read_protocol.read_field_begin
+        expect(type).to eq(Thrift::Types::I64)
+        expect(read_protocol.read_i64).to eq(9223372036854775807)
+        read_protocol.read_field_end
+
+        # double
+        _, type, _ = read_protocol.read_field_begin
+        expect(type).to eq(Thrift::Types::DOUBLE)
+        expect(read_protocol.read_double).to be_within(0.00001).of(3.14159) # float compare with a tolerance
+        read_protocol.read_field_end
+
+        # string
+        _, type, _ = read_protocol.read_field_begin
+        expect(type).to eq(Thrift::Types::STRING)
+        expect(read_protocol.read_string).to eq("hello")
+        read_protocol.read_field_end
+      end
+    end
+
+    describe "protocol delegation with Compact protocol" do
+      before(:each) do
+        @buffer = Thrift::MemoryBufferTransport.new
+        @protocol = Thrift::HeaderProtocol.new(
+          @buffer,
+          nil, # presumably allowed client types (cf. HeaderTransport.new(trans, allowed)); left unset
+          Thrift::HeaderSubprotocolID::COMPACT
+        )
+      end
+
+      it "should use Compact protocol" do
+        expect(@protocol.to_s).to match(/header\(compact/)
+      end
+
+      it "should write and read with Compact protocol" do
+        @protocol.write_message_begin("test", Thrift::MessageTypes::CALL, 1)
+        @protocol.write_struct_begin("Test")
+        @protocol.write_field_begin("field", Thrift::Types::I32, 1)
+        @protocol.write_i32(999)
+        @protocol.write_field_end
+        @protocol.write_field_stop
+        @protocol.write_struct_end
+        @protocol.write_message_end
+        @protocol.trans.flush
+
+        data = @buffer.read(@buffer.available)
+        read_buffer = Thrift::MemoryBufferTransport.new(data)
+        read_protocol = Thrift::HeaderProtocol.new(read_buffer, nil, Thrift::HeaderSubprotocolID::COMPACT) # reader also set to compact
+
+        read_protocol.read_message_begin
+        read_protocol.read_struct_begin
+        _, type, _ = read_protocol.read_field_begin # only the type tag is checked
+        expect(type).to eq(Thrift::Types::I32)
+        expect(read_protocol.read_i32).to eq(999)
+      end
+    end
+
+    describe "unknown protocol handling" do
+      it "should write an exception response on unknown protocol id" do
+        header_data = +""
+        header_data << varint32(0x10) # protocol id 0x10 matches neither BINARY (0x00) nor COMPACT (0x02)
+        header_data << varint32(0) # transform count: none
+        frame = build_header_frame(header_data)
+
+        buffer = Thrift::MemoryBufferTransport.new(frame)
+        protocol = Thrift::HeaderProtocol.new(buffer)
+
+        expect { protocol.read_message_begin }.to raise_error(Thrift::ProtocolException)
+        # The error reply is written back through the same memory buffer, as a Header frame
+        response = buffer.read(buffer.available)
+        expect(response.bytesize).to be > 0
+        magic = response[4, 2].unpack('n').first # bytes 0-3 hold the frame length
+        expect(magic).to eq(Thrift::HeaderTransport::HEADER_MAGIC)
+      end
+    end
+
+    describe "protocol auto-detection with legacy frames" do
+      it "should detect framed compact messages" do
+        write_buffer = Thrift::MemoryBufferTransport.new
+        write_protocol = Thrift::CompactProtocol.new(write_buffer) # plain compact writer, no Header framing
+
+        write_protocol.write_message_begin("legacy_framed", Thrift::MessageTypes::CALL, 7)
+        write_protocol.write_struct_begin("Args")
+        write_protocol.write_field_stop
+        write_protocol.write_struct_end
+        write_protocol.write_message_end
+
+        payload = write_buffer.read(write_buffer.available)
+        framed = [payload.bytesize].pack('N') + payload # legacy framing: 4-byte big-endian length prefix
+
+        read_buffer = Thrift::MemoryBufferTransport.new(framed)
+        protocol = Thrift::HeaderProtocol.new(read_buffer) # default reader must detect the framing by itself
+
+        name, type, seqid = protocol.read_message_begin
+        expect(name).to eq("legacy_framed")
+        expect(type).to eq(Thrift::MessageTypes::CALL)
+        expect(seqid).to eq(7)
+
+        protocol.read_struct_begin
+        _, field_type, _ = protocol.read_field_begin
+        expect(field_type).to eq(Thrift::Types::STOP) # the empty Args struct holds only a STOP
+        protocol.read_struct_end
+        protocol.read_message_end
+      end
+
+      it "should detect unframed compact messages" do
+        write_buffer = Thrift::MemoryBufferTransport.new
+        write_protocol = Thrift::CompactProtocol.new(write_buffer)
+
+        write_protocol.write_message_begin("legacy_unframed", Thrift::MessageTypes::CALL, 9)
+        write_protocol.write_struct_begin("Args")
+        write_protocol.write_field_stop
+        write_protocol.write_struct_end
+        write_protocol.write_message_end
+
+        payload = write_buffer.read(write_buffer.available)
+
+        read_buffer = Thrift::MemoryBufferTransport.new(payload) # raw message, no length prefix
+        protocol = Thrift::HeaderProtocol.new(read_buffer)
+
+        name, type, seqid = protocol.read_message_begin
+        expect(name).to eq("legacy_unframed")
+        expect(type).to eq(Thrift::MessageTypes::CALL)
+        expect(seqid).to eq(9)
+
+        protocol.read_struct_begin
+        _, field_type, _ = protocol.read_field_begin
+        expect(field_type).to eq(Thrift::Types::STOP)
+        protocol.read_struct_end
+        protocol.read_message_end
+      end
+    end
+
+ describe "with compression" do
+ it "should work with ZLIB transform" do
+ @protocol.add_transform(Thrift::HeaderTransformID::ZLIB)
+
+ @protocol.write_message_begin("compressed_test", Thrift::MessageTypes::CALL, 42)
+ @protocol.write_struct_begin("Args")
+ @protocol.write_field_begin("data", Thrift::Types::STRING, 1)
+ @protocol.write_string("a" * 100) # Compressible data
+ @protocol.write_field_end
+ @protocol.write_field_stop
+ @protocol.write_struct_end
+ @protocol.write_message_end
+ @protocol.trans.flush
+
+ data = @buffer.read(@buffer.available)
+ read_buffer = Thrift::MemoryBufferTransport.new(data)
+ read_protocol = Thrift::HeaderProtocol.new(read_buffer)
+
+ name, type, seqid = read_protocol.read_message_begin
+ expect(name).to eq("compressed_test")
+ expect(seqid).to eq(42)
+
+ read_protocol.read_struct_begin
+ _, _, _ = read_protocol.read_field_begin
+ result = read_protocol.read_string
+ expect(result).to eq("a" * 100)
+ end
+ end
+
+ describe "containers" do
+ it "should write and read lists" do
+ @protocol.write_message_begin("test", Thrift::MessageTypes::CALL, 1)
+ @protocol.write_struct_begin("Test")
+ @protocol.write_field_begin("list", Thrift::Types::LIST, 1)
+ @protocol.write_list_begin(Thrift::Types::I32, 3)
+ @protocol.write_i32(1)
+ @protocol.write_i32(2)
+ @protocol.write_i32(3)
+ @protocol.write_list_end
+ @protocol.write_field_end
+ @protocol.write_field_stop
+ @protocol.write_struct_end
+ @protocol.write_message_end
+ @protocol.trans.flush
+
+ data = @buffer.read(@buffer.available)
+ read_buffer = Thrift::MemoryBufferTransport.new(data)
+ read_protocol = Thrift::HeaderProtocol.new(read_buffer)
+
+ read_protocol.read_message_begin
+ read_protocol.read_struct_begin
+ _, _, _ = read_protocol.read_field_begin
+ etype, size = read_protocol.read_list_begin
+ expect(etype).to eq(Thrift::Types::I32)
+ expect(size).to eq(3)
+ expect(read_protocol.read_i32).to eq(1)
+ expect(read_protocol.read_i32).to eq(2)
+ expect(read_protocol.read_i32).to eq(3)
+ end
+
+ it "should write and read maps" do
+ @protocol.write_message_begin("test", Thrift::MessageTypes::CALL, 1)
+ @protocol.write_struct_begin("Test")
+ @protocol.write_field_begin("map", Thrift::Types::MAP, 1)
+ @protocol.write_map_begin(Thrift::Types::STRING, Thrift::Types::I32, 2)
+ @protocol.write_string("a")
+ @protocol.write_i32(1)
+ @protocol.write_string("b")
+ @protocol.write_i32(2)
+ @protocol.write_map_end
+ @protocol.write_field_end
+ @protocol.write_field_stop
+ @protocol.write_struct_end
+ @protocol.write_message_end
+ @protocol.trans.flush
+
+ data = @buffer.read(@buffer.available)
+ read_buffer = Thrift::MemoryBufferTransport.new(data)
+ read_protocol = Thrift::HeaderProtocol.new(read_buffer)
+
+ read_protocol.read_message_begin
+ read_protocol.read_struct_begin
+ _, _, _ = read_protocol.read_field_begin
+ ktype, vtype, size = read_protocol.read_map_begin
+ expect(ktype).to eq(Thrift::Types::STRING)
+ expect(vtype).to eq(Thrift::Types::I32)
+ expect(size).to eq(2)
+ end
+
+ it "should write and read sets" do
+ @protocol.write_message_begin("test", Thrift::MessageTypes::CALL, 1)
+ @protocol.write_struct_begin("Test")
+ @protocol.write_field_begin("set", Thrift::Types::SET, 1)
+ @protocol.write_set_begin(Thrift::Types::STRING, 2)
+ @protocol.write_string("x")
+ @protocol.write_string("y")
+ @protocol.write_set_end
+ @protocol.write_field_end
+ @protocol.write_field_stop
+ @protocol.write_struct_end
+ @protocol.write_message_end
+ @protocol.trans.flush
+
+ data = @buffer.read(@buffer.available)
+ read_buffer = Thrift::MemoryBufferTransport.new(data)
+ read_protocol = Thrift::HeaderProtocol.new(read_buffer)
+
+ read_protocol.read_message_begin
+ read_protocol.read_struct_begin
+ _, _, _ = read_protocol.read_field_begin
+ etype, size = read_protocol.read_set_begin
+ expect(etype).to eq(Thrift::Types::STRING)
+ expect(size).to eq(2)
+ end
+ end
+ end
+
+ describe Thrift::HeaderProtocolFactory do
+ it "should create HeaderProtocol" do
+ factory = Thrift::HeaderProtocolFactory.new
+ buffer = Thrift::MemoryBufferTransport.new
+ protocol = factory.get_protocol(buffer)
+ expect(protocol).to be_a(Thrift::HeaderProtocol)
+ end
+
+ it "should provide a reasonable to_s" do
+ expect(Thrift::HeaderProtocolFactory.new.to_s).to eq("header")
+ end
+
+ it "should pass configuration to protocol" do
+ factory = Thrift::HeaderProtocolFactory.new(nil, Thrift::HeaderSubprotocolID::COMPACT)
+ buffer = Thrift::MemoryBufferTransport.new
+ protocol = factory.get_protocol(buffer)
+ expect(protocol.to_s).to match(/compact/)
+ end
+ end
+end
diff --git a/lib/rb/spec/header_transport_spec.rb b/lib/rb/spec/header_transport_spec.rb
new file mode 100644
index 0000000..2857e15
--- /dev/null
+++ b/lib/rb/spec/header_transport_spec.rb
@@ -0,0 +1,364 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+require 'spec_helper'
+require_relative 'support/header_protocol_helper'
+
+describe 'HeaderTransport' do
+ include HeaderProtocolHelper
+
+ describe Thrift::HeaderClientType do
+ it "should define client type constants" do
+ expect(Thrift::HeaderClientType::HEADERS).to eq(0x00)
+ expect(Thrift::HeaderClientType::FRAMED_BINARY).to eq(0x01)
+ expect(Thrift::HeaderClientType::UNFRAMED_BINARY).to eq(0x02)
+ expect(Thrift::HeaderClientType::FRAMED_COMPACT).to eq(0x03)
+ expect(Thrift::HeaderClientType::UNFRAMED_COMPACT).to eq(0x04)
+ end
+ end
+
+ describe Thrift::HeaderSubprotocolID do
+ it "should define protocol ID constants" do
+ expect(Thrift::HeaderSubprotocolID::BINARY).to eq(0x00)
+ expect(Thrift::HeaderSubprotocolID::COMPACT).to eq(0x02)
+ end
+ end
+
+ describe Thrift::HeaderTransformID do
+ it "should define transform ID constants" do
+ expect(Thrift::HeaderTransformID::ZLIB).to eq(0x01)
+ end
+ end
+
+ describe Thrift::HeaderTransport do
+ before(:each) do
+ @underlying = Thrift::MemoryBufferTransport.new
+ @trans = Thrift::HeaderTransport.new(@underlying)
+ end
+
+ it "should provide a to_s that describes the encapsulation" do
+ expect(@trans.to_s).to eq("header(memory)")
+ end
+
+ it "should pass through open?/open/close" do
+ mock_transport = double("Transport")
+ expect(mock_transport).to receive(:open?).and_return(true)
+ expect(mock_transport).to receive(:open).and_return(nil)
+ expect(mock_transport).to receive(:close).and_return(nil)
+
+ trans = Thrift::HeaderTransport.new(mock_transport)
+ expect(trans.open?).to be true
+ trans.open
+ trans.close
+ end
+
+ describe "header management" do
+ it "should allow setting and getting headers" do
+ @trans.set_header("key1", "value1")
+ @trans.set_header("key2", "value2")
+ # Headers aren't read until we receive data, so write and read back
+ expect(@trans.get_headers).to eq({})
+ end
+
+ it "should clear headers" do
+ @trans.set_header("key1", "value1")
+ @trans.clear_headers
+ # Write and flush to verify headers were cleared
+ @trans.write("test")
+ @trans.flush
+ end
+
+ it "should add transforms" do
+ expect { @trans.add_transform(Thrift::HeaderTransformID::ZLIB) }.not_to raise_error
+ end
+
+ it "should reject unknown transforms" do
+ expect { @trans.add_transform(999) }.to raise_error(Thrift::TransportException)
+ end
+ end
+
+ describe "write and flush" do
+ it "should buffer writes" do
+ @trans.write("hello")
+ @trans.write(" world")
+ expect(@underlying.available).to eq(0)
+ end
+
+ it "should write Header format on flush" do
+ @trans.write("test payload")
+ @trans.flush
+
+ # Read back the frame
+ data = @underlying.read(@underlying.available)
+
+ # Should have frame length (4 bytes) + header + payload
+ expect(data.bytesize).to be > 16
+
+ # First 4 bytes are frame length
+ frame_size = data[0, 4].unpack('N').first
+ expect(frame_size).to eq(data.bytesize - 4)
+
+ # Next 2 bytes should be header magic
+ magic = data[4, 2].unpack('n').first
+ expect(magic).to eq(Thrift::HeaderTransport::HEADER_MAGIC)
+ end
+
+ it "should include headers in frame" do
+ @trans.set_header("test-key", "test-value")
+ @trans.write("payload")
+ @trans.flush
+
+ # Read back and verify it's larger due to headers
+ data = @underlying.read(@underlying.available)
+ expect(data.bytesize).to be > 30 # Should include header key-value
+ end
+
+ it "should apply ZLIB transform" do
+ @trans.add_transform(Thrift::HeaderTransformID::ZLIB)
+ original_payload = "a" * 1000 # Compressible data
+ @trans.write(original_payload)
+ @trans.flush
+
+ data = @underlying.read(@underlying.available)
+ # Compressed frame should be smaller than uncompressed
+ expect(data.bytesize).to be < original_payload.bytesize
+ end
+ end
+
+    describe "frame size limits" do
+      it "should reject payloads larger than max frame size" do
+        @trans.set_max_frame_size(4) # below the 5-byte payload written next
+        @trans.write("12345")
+        expect { @trans.flush }.to raise_error(Thrift::TransportException, /frame that is too large/) # enforced at flush, not on write
+      end
+    end
+
+    describe "read and frame detection" do
+      it "should detect Header format" do
+        # Write a Header frame
+        @trans.write("test data")
+        @trans.flush
+
+        # Hand the flushed bytes to a fresh reader
+        written_data = @underlying.read(@underlying.available)
+        read_transport = Thrift::MemoryBufferTransport.new(written_data)
+        read_trans = Thrift::HeaderTransport.new(read_transport)
+
+        result = read_trans.read(9) # "test data" is 9 bytes
+        expect(result).to eq("test data")
+      end
+
+      it "should detect framed binary protocol" do
+        # Legacy framed binary: 4-byte length prefix, then a message starting with the strict binary version word
+        payload = [Thrift::BinaryProtocol::VERSION_1 | Thrift::MessageTypes::CALL].pack('N')
+        payload << "test"
+        frame = [payload.bytesize].pack('N') + payload
+
+        read_transport = Thrift::MemoryBufferTransport.new(frame)
+        read_trans = Thrift::HeaderTransport.new(read_transport)
+
+        result = read_trans.read(payload.bytesize)
+        expect(result).to eq(payload) # the frame body comes back unchanged, without the length prefix
+      end
+
+      it "should detect unframed binary protocol" do
+        # Unframed binary: the version word is the first thing on the wire (no length prefix)
+        message = [Thrift::BinaryProtocol::VERSION_1 | Thrift::MessageTypes::CALL].pack('N')
+        message << "test"
+
+        read_transport = Thrift::MemoryBufferTransport.new(message)
+        read_trans = Thrift::HeaderTransport.new(read_transport)
+
+        result = read_trans.read(message.bytesize)
+        expect(result).to eq(message)
+      end
+
+      it "should read headers from Header frame" do
+        # Write with headers
+        @trans.set_header("request-id", "12345")
+        @trans.write("payload")
+        @trans.flush
+
+        # Read back; headers become available once the frame has been read
+        written_data = @underlying.read(@underlying.available)
+        read_transport = Thrift::MemoryBufferTransport.new(written_data)
+        read_trans = Thrift::HeaderTransport.new(read_transport)
+
+        read_trans.read(7)
+        headers = read_trans.get_headers
+        expect(headers["request-id"]).to eq("12345")
+      end
+
+      it "should decompress ZLIB payload" do
+        # Write with ZLIB
+        @trans.add_transform(Thrift::HeaderTransformID::ZLIB)
+        original = "hello world this is a test"
+        @trans.write(original)
+        @trans.flush
+
+        # Read back; the reader undoes the transform without being configured for it
+        written_data = @underlying.read(@underlying.available)
+        read_transport = Thrift::MemoryBufferTransport.new(written_data)
+        read_trans = Thrift::HeaderTransport.new(read_transport)
+
+        result = read_trans.read(original.bytesize)
+        expect(result).to eq(original)
+      end
+    end
+
+    describe "header parsing protections" do
+      it "should reject unreasonable header sizes" do
+        frame = build_header_frame("", Thrift::Bytes.empty_byte_buffer, header_words: 16_384) # claims 64 KiB of header, supplies none
+        read_transport = Thrift::MemoryBufferTransport.new(frame)
+        read_trans = Thrift::HeaderTransport.new(read_transport)
+
+        expect { read_trans.read(1) }.to raise_error(Thrift::TransportException, /Header size is unreasonable/)
+      end
+
+      it "should reject header frames that are too small" do
+        frame = Thrift::Bytes.empty_byte_buffer
+        frame << [9].pack('N') # frame length 9: less than the 10 bytes of fixed fields that follow
+        frame << [Thrift::HeaderTransport::HEADER_MAGIC].pack('n')
+        frame << [0].pack('n') # flags
+        frame << [0].pack('N') # sequence id
+        frame << [0].pack('n') # header size in 32-bit words
+        read_transport = Thrift::MemoryBufferTransport.new(frame)
+        read_trans = Thrift::HeaderTransport.new(read_transport)
+
+        expect { read_trans.read(1) }.to raise_error(Thrift::TransportException, /frame is too small/)
+      end
+
+      it "should reject varints that cross header boundary" do
+        header_data = [0x80, 0x80, 0x80, 0x80].pack('C*') # every byte has the varint continuation bit set
+        frame = build_header_frame(header_data)
+        read_transport = Thrift::MemoryBufferTransport.new(frame)
+        read_trans = Thrift::HeaderTransport.new(read_transport)
+
+        expect { read_trans.read(1) }.to raise_error(Thrift::TransportException, /header boundary/)
+      end
+
+      it "should reject strings that exceed header boundary" do
+        header_data = +""
+        header_data << varint32(Thrift::HeaderSubprotocolID::BINARY) # protocol id
+        header_data << varint32(0) # transform count
+        header_data << varint32(Thrift::HeaderInfoType::KEY_VALUE) # info header type
+        header_data << varint32(1) # one key/value pair
+        header_data << varint32(10) # key length claims 10 bytes...
+        header_data << "a" # ...but only one byte (plus padding) follows
+
+        frame = build_header_frame(header_data)
+        read_transport = Thrift::MemoryBufferTransport.new(frame)
+        read_trans = Thrift::HeaderTransport.new(read_transport)
+
+        expect { read_trans.read(1) }.to raise_error(Thrift::TransportException, /Info header length exceeds header size/)
+      end
+    end
+
+ describe "round-trip" do
+ it "should handle complete write-read cycle" do
+ # Write
+ @trans.set_header("trace-id", "abc123")
+ @trans.write("hello world")
+ @trans.flush
+
+ # Read
+ written_data = @underlying.read(@underlying.available)
+ read_transport = Thrift::MemoryBufferTransport.new(written_data)
+ read_trans = Thrift::HeaderTransport.new(read_transport)
+
+ result = read_trans.read(11)
+ expect(result).to eq("hello world")
+ expect(read_trans.get_headers["trace-id"]).to eq("abc123")
+ end
+
+ it "should handle multiple headers" do
+ @trans.set_header("header1", "value1")
+ @trans.set_header("header2", "value2")
+ @trans.set_header("header3", "value3")
+ @trans.write("data")
+ @trans.flush
+
+ written_data = @underlying.read(@underlying.available)
+ read_transport = Thrift::MemoryBufferTransport.new(written_data)
+ read_trans = Thrift::HeaderTransport.new(read_transport)
+
+ read_trans.read(4)
+ headers = read_trans.get_headers
+ expect(headers["header1"]).to eq("value1")
+ expect(headers["header2"]).to eq("value2")
+ expect(headers["header3"]).to eq("value3")
+ end
+
+ it "should handle ZLIB compression round-trip" do
+ @trans.add_transform(Thrift::HeaderTransformID::ZLIB)
+ @trans.set_header("compressed", "true")
+ original = "x" * 500
+ @trans.write(original)
+ @trans.flush
+
+ written_data = @underlying.read(@underlying.available)
+ read_transport = Thrift::MemoryBufferTransport.new(written_data)
+ read_trans = Thrift::HeaderTransport.new(read_transport)
+
+ result = read_trans.read(500)
+ expect(result).to eq(original)
+ expect(read_trans.get_headers["compressed"]).to eq("true")
+ end
+ end
+
+    describe "client type restrictions" do
+      it "should reject disallowed client types" do
+        # Accept Header-framed clients only
+        allowed = [Thrift::HeaderClientType::HEADERS]
+
+        # A legacy framed-binary message (length prefix + strict binary version word), which is not allowed
+        payload = [Thrift::BinaryProtocol::VERSION_1 | Thrift::MessageTypes::CALL].pack('N')
+        frame = [payload.bytesize].pack('N') + payload
+
+        read_transport = Thrift::MemoryBufferTransport.new(frame)
+        read_trans = Thrift::HeaderTransport.new(read_transport, allowed)
+
+        expect { read_trans.read(4) }.to raise_error(Thrift::TransportException) # refused on read, not at construction
+      end
+    end
+ end
+
+ describe Thrift::HeaderTransportFactory do
+ it "should wrap transport in HeaderTransport" do
+ mock_transport = double("Transport")
+ factory = Thrift::HeaderTransportFactory.new
+ result = factory.get_transport(mock_transport)
+ expect(result).to be_a(Thrift::HeaderTransport)
+ end
+
+ it "should provide a reasonable to_s" do
+ expect(Thrift::HeaderTransportFactory.new.to_s).to eq("header")
+ end
+
+ it "should pass allowed_client_types to transport" do
+ allowed = [Thrift::HeaderClientType::HEADERS]
+ factory = Thrift::HeaderTransportFactory.new(allowed)
+
+ mock_transport = Thrift::MemoryBufferTransport.new
+ result = factory.get_transport(mock_transport)
+
+ expect(result).to be_a(Thrift::HeaderTransport)
+ end
+ end
+end
diff --git a/lib/rb/spec/support/header_protocol_helper.rb b/lib/rb/spec/support/header_protocol_helper.rb
new file mode 100644
index 0000000..75875a6
--- /dev/null
+++ b/lib/rb/spec/support/header_protocol_helper.rb
@@ -0,0 +1,54 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+module HeaderProtocolHelper
+ def varint32(n)
+ bytes = []
+ loop do
+ if (n & ~0x7f) == 0
+ bytes << n
+ break
+ else
+ bytes << ((n & 0x7f) | 0x80)
+ n >>= 7
+ end
+ end
+ bytes.pack('C*')
+ end
+
+ def build_header_frame(header_data, payload = Thrift::Bytes.empty_byte_buffer, header_words: nil)
+ header_data = Thrift::Bytes.force_binary_encoding(header_data)
+ if header_words.nil?
+ padding = (4 - (header_data.bytesize % 4)) % 4
+ header_data += "\x00" * padding
+ header_words = header_data.bytesize / 4
+ end
+
+ frame_size = 2 + 2 + 4 + 2 + header_data.bytesize + payload.bytesize
+ frame = Thrift::Bytes.empty_byte_buffer
+ frame << [frame_size].pack('N')
+ frame << [Thrift::HeaderTransport::HEADER_MAGIC].pack('n')
+ frame << [0].pack('n')
+ frame << [0].pack('N')
+ frame << [header_words].pack('n')
+ frame << header_data
+ frame << payload
+ frame
+ end
+end