Implement THeader protocol and transport support in the Ruby library
diff --git a/lib/rb/benchmark/benchmark.rb b/lib/rb/benchmark/benchmark.rb
index 4a520a5..0111ee8 100644
--- a/lib/rb/benchmark/benchmark.rb
+++ b/lib/rb/benchmark/benchmark.rb
@@ -35,19 +35,21 @@
attr_accessor :interpreter
attr_accessor :host
attr_accessor :port
+ attr_accessor :protocol_type
def initialize(opts)
@serverclass = opts.fetch(:class, Thrift::NonblockingServer)
@interpreter = opts.fetch(:interpreter, "ruby")
@host = opts.fetch(:host, ::HOST)
@port = opts.fetch(:port, ::PORT)
+ @protocol_type = opts.fetch(:protocol_type, 'binary')
@tls = opts.fetch(:tls, false)
end
def start
return if @serverclass == Object
args = (File.basename(@interpreter) == "jruby" ? "-J-server" : "")
- @pipe = IO.popen("#{@interpreter} #{args} #{File.dirname(__FILE__)}/server.rb #{"-tls" if @tls} #{@host} #{@port} #{@serverclass.name}", "r+")
+ @pipe = IO.popen("#{@interpreter} #{args} #{File.dirname(__FILE__)}/server.rb #{"-tls" if @tls} #{@host} #{@port} #{@serverclass.name} #{@protocol_type}", "r+")
Marshal.load(@pipe) # wait until the server has started
sleep 0.4 # give the server time to actually start spawning sockets
end
@@ -77,6 +79,7 @@
@interpreter = opts.fetch(:interpreter, "ruby")
@server = server
@log_exceptions = opts.fetch(:log_exceptions, false)
+ @protocol_type = opts.fetch(:protocol_type, 'binary')
@tls = opts.fetch(:tls, false)
end
@@ -96,7 +99,7 @@
end
def spawn
- pipe = IO.popen("#{@interpreter} #{File.dirname(__FILE__)}/client.rb #{"-log-exceptions" if @log_exceptions} #{"-tls" if @tls} #{@host} #{@port} #{@clients_per_process} #{@calls_per_client}")
+ pipe = IO.popen("#{@interpreter} #{File.dirname(__FILE__)}/client.rb #{"-log-exceptions" if @log_exceptions} #{"-tls" if @tls} #{@host} #{@port} #{@clients_per_process} #{@calls_per_client} #{@protocol_type}")
@pool << pipe
end
@@ -202,6 +205,7 @@
[["Server class", "%s"], @server.serverclass == Object ? "" : @server.serverclass],
[["Server interpreter", "%s"], @server.interpreter],
[["Client interpreter", "%s"], @interpreter],
+ [["Protocol type", "%s"], @protocol_type],
[["Socket class", "%s"], socket_class],
["Number of processes", @num_processes],
["Clients per process", @clients_per_process],
@@ -255,12 +259,14 @@
end
puts "Starting server..."
+protocol_type = ENV['THRIFT_PROTOCOL'] || 'binary'
args = {}
args[:interpreter] = ENV['THRIFT_SERVER_INTERPRETER'] || ENV['THRIFT_INTERPRETER'] || "ruby"
args[:class] = resolve_const(ENV['THRIFT_SERVER']) || Thrift::NonblockingServer
args[:host] = ENV['THRIFT_HOST'] || HOST
args[:port] = (ENV['THRIFT_PORT'] || PORT).to_i
args[:tls] = ENV['THRIFT_TLS'] == 'true'
+args[:protocol_type] = protocol_type
server = Server.new(args)
server.start
@@ -273,6 +279,7 @@
args[:calls_per_client] = (ENV['THRIFT_NUM_CALLS'] || 50).to_i
args[:interpreter] = ENV['THRIFT_CLIENT_INTERPRETER'] || ENV['THRIFT_INTERPRETER'] || "ruby"
args[:log_exceptions] = !!ENV['THRIFT_LOG_EXCEPTIONS']
+args[:protocol_type] = protocol_type
BenchmarkManager.new(args, server).run
server.shutdown
diff --git a/lib/rb/benchmark/client.rb b/lib/rb/benchmark/client.rb
index 693bf60..304e6d8 100644
--- a/lib/rb/benchmark/client.rb
+++ b/lib/rb/benchmark/client.rb
@@ -25,13 +25,36 @@
require 'benchmark_service'
class Client
- def initialize(host, port, clients_per_process, calls_per_client, log_exceptions, tls)
+ def initialize(host, port, clients_per_process, calls_per_client, log_exceptions, tls, protocol_type)
@host = host
@port = port
@clients_per_process = clients_per_process
@calls_per_client = calls_per_client
@log_exceptions = log_exceptions
@tls = tls
+ @protocol_type = protocol_type || 'binary'
+ end
+
+  # Builds the protocol stack for @protocol_type. Plain 'binary'/'compact' use
+  # framed transports; the 'header*' variants wrap the socket in HeaderProtocol
+  # with an explicit subprotocol, so 'header' (binary) differs from 'header-compact'.
+  # Returns the protocol; callers reach the transport through protocol.trans.
+  def create_protocol(socket)
+    case @protocol_type
+    when 'compact'
+      Thrift::CompactProtocol.new(Thrift::FramedTransport.new(socket))
+    when 'header'
+      Thrift::HeaderProtocol.new(socket, nil, Thrift::HeaderSubprotocolID::BINARY)
+    when 'header-compact'
+      Thrift::HeaderProtocol.new(socket, nil, Thrift::HeaderSubprotocolID::COMPACT)
+    when 'header-zlib'
+      protocol = Thrift::HeaderProtocol.new(socket, nil, Thrift::HeaderSubprotocolID::BINARY)
+      protocol.add_transform(Thrift::HeaderTransformID::ZLIB)
+      protocol
+    else
+      # 'binary' (and any unrecognized type) uses framed binary.
+      Thrift::BinaryProtocol.new(Thrift::FramedTransport.new(socket))
+    end
   end
def run
@@ -53,8 +76,8 @@
else
Thrift::Socket.new(@host, @port)
end
- transport = Thrift::FramedTransport.new(socket)
- protocol = Thrift::BinaryProtocol.new(transport)
+ protocol = create_protocol(socket)
+ transport = protocol.trans
client = ThriftBenchmark::BenchmarkService::Client.new(protocol)
begin
start = Time.now
@@ -89,6 +112,6 @@
log_exceptions = true if ARGV[0] == '-log-exceptions' and ARGV.shift
tls = true if ARGV[0] == '-tls' and ARGV.shift
-host, port, clients_per_process, calls_per_client = ARGV
+host, port, clients_per_process, calls_per_client, protocol_type = ARGV
-Client.new(host, port.to_i, clients_per_process.to_i, calls_per_client.to_i, log_exceptions, tls).run
+Client.new(host, port.to_i, clients_per_process.to_i, calls_per_client.to_i, log_exceptions, tls, protocol_type).run
diff --git a/lib/rb/benchmark/server.rb b/lib/rb/benchmark/server.rb
index 153eb0f..6df3fa9 100644
--- a/lib/rb/benchmark/server.rb
+++ b/lib/rb/benchmark/server.rb
@@ -38,7 +38,25 @@
end
end
- def self.start_server(host, port, serverClass, tls)
+  # Returns [transport_factory, protocol_factory] for the benchmark protocol type;
+  # unrecognized types fall back to framed binary.
+  def self.create_factories(protocol_type)
+    case protocol_type
+    when 'compact'
+      [FramedTransportFactory.new, CompactProtocolFactory.new]
+    when 'header', 'header-zlib'
+      # No server-side transforms are configured, and HeaderTransport does not copy
+      # the client's ZLIB transform onto its replies, so responses go out uncompressed.
+      [HeaderTransportFactory.new, HeaderProtocolFactory.new]
+    when 'header-compact'
+      [HeaderTransportFactory.new, HeaderProtocolFactory.new(nil, HeaderSubprotocolID::COMPACT)]
+    else
+      # 'binary' (and anything unrecognized)
+      [FramedTransportFactory.new, BinaryProtocolFactory.new]
+    end
+  end
+
+ def self.start_server(host, port, serverClass, tls, protocol_type = nil)
handler = BenchmarkHandler.new
processor = ThriftBenchmark::BenchmarkService::Processor.new(handler)
transport = if tls
@@ -58,8 +76,8 @@
else
ServerSocket.new(host, port)
end
- transport_factory = FramedTransportFactory.new
- args = [processor, transport, transport_factory, nil, 20]
+ transport_factory, protocol_factory = create_factories(protocol_type || 'binary')
+ args = [processor, transport, transport_factory, protocol_factory, 20]
if serverClass == NonblockingServer
logger = Logger.new(STDERR)
logger.level = Logger::WARN
@@ -88,9 +106,9 @@
tls = true if ARGV[0] == '-tls' and ARGV.shift
-host, port, serverklass = ARGV
+host, port, serverklass, protocol_type = ARGV
-Server.start_server(host, port.to_i, resolve_const(serverklass), tls)
+Server.start_server(host, port.to_i, resolve_const(serverklass), tls, protocol_type)
# let our host know that the interpreter has started
# ideally we'd wait until the server was serving, but we don't have a hook for that
diff --git a/lib/rb/lib/thrift.rb b/lib/rb/lib/thrift.rb
index 8b9a8a5..f12f10f 100644
--- a/lib/rb/lib/thrift.rb
+++ b/lib/rb/lib/thrift.rb
@@ -44,6 +44,7 @@
require 'thrift/protocol/compact_protocol'
require 'thrift/protocol/json_protocol'
require 'thrift/protocol/multiplexed_protocol'
+require 'thrift/protocol/header_protocol'
# transport
require 'thrift/transport/base_transport'
@@ -56,6 +57,7 @@
require 'thrift/transport/unix_server_socket'
require 'thrift/transport/buffered_transport'
require 'thrift/transport/framed_transport'
+require 'thrift/transport/header_transport'
require 'thrift/transport/http_client_transport'
require 'thrift/transport/io_stream_transport'
require 'thrift/transport/memory_buffer_transport'
diff --git a/lib/rb/lib/thrift/protocol/header_protocol.rb b/lib/rb/lib/thrift/protocol/header_protocol.rb
new file mode 100644
index 0000000..8bf4bbe
--- /dev/null
+++ b/lib/rb/lib/thrift/protocol/header_protocol.rb
@@ -0,0 +1,336 @@
+# encoding: ascii-8bit
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+module Thrift
+  # HeaderProtocol is a protocol that wraps HeaderTransport and delegates
+  # to either BinaryProtocol or CompactProtocol based on auto-detection.
+  #
+  # It provides access to header management (get_headers, set_header, etc.)
+  # through the underlying HeaderTransport.
+  #
+  # Example usage:
+  #   socket = Thrift::Socket.new('localhost', 9090)
+  #   protocol = Thrift::HeaderProtocol.new(socket)
+  #   client = MyService::Client.new(protocol)
+  #   protocol.trans.open
+  #   client.some_method()
+  #   protocol.trans.close
+  #
+  class HeaderProtocol < BaseProtocol
+    # Creates a new HeaderProtocol.
+    #
+    # @param transport [BaseTransport, HeaderTransport] The transport to wrap.
+    #   If not already a HeaderTransport, it will be wrapped in one.
+    # @param allowed_client_types [Array<Integer>] Allowed client types for auto-detection
+    #   (only used when +transport+ is wrapped here)
+    # @param default_protocol [Integer] Default protocol ID (BINARY or COMPACT) for the
+    #   HeaderTransport created here. When an existing HeaderTransport is passed, the
+    #   delegate protocol follows that transport's protocol_id instead, so the payload
+    #   encoding always matches the protocol ID the transport advertises.
+    def initialize(transport, allowed_client_types = nil, default_protocol = HeaderSubprotocolID::COMPACT)
+      # Wrap transport in HeaderTransport if not already wrapped
+      if transport.is_a?(HeaderTransport)
+        @header_transport = transport
+      else
+        @header_transport = HeaderTransport.new(transport, allowed_client_types, default_protocol)
+      end
+      # Let BaseProtocol run its own initializer with the wrapped transport.
+      super(@header_transport)
+
+      @default_protocol = default_protocol
+      # Start from the transport's protocol ID (equal to default_protocol for a freshly
+      # wrapped transport) so the first write is encoded as the header advertises.
+      @current_protocol_id = @header_transport.protocol_id
+
+      # Create initial protocol
+      @protocol = create_protocol(@current_protocol_id)
+    end
+
+    # Returns the HeaderTransport
+    def trans
+      @header_transport
+    end
+
+    # Returns headers read from the last message
+    def get_headers
+      @header_transport.get_headers
+    end
+
+    # Sets a header to be sent with the next message
+    def set_header(key, value)
+      @header_transport.set_header(key, value)
+    end
+
+    # Clears all write headers
+    def clear_headers
+      @header_transport.clear_headers
+    end
+
+    # Adds a transform (e.g., ZLIB compression)
+    def add_transform(transform_id)
+      @header_transport.add_transform(transform_id)
+    end
+
+    # Write methods - delegate to underlying protocol
+    def write_message_begin(name, type, seqid)
+      @protocol.write_message_begin(name, type, seqid)
+    end
+
+    def write_message_end
+      @protocol.write_message_end
+    end
+
+    def write_struct_begin(name)
+      @protocol.write_struct_begin(name)
+    end
+
+    def write_struct_end
+      @protocol.write_struct_end
+    end
+
+    def write_field_begin(name, type, id)
+      @protocol.write_field_begin(name, type, id)
+    end
+
+    def write_field_end
+      @protocol.write_field_end
+    end
+
+    def write_field_stop
+      @protocol.write_field_stop
+    end
+
+    def write_map_begin(ktype, vtype, size)
+      @protocol.write_map_begin(ktype, vtype, size)
+    end
+
+    def write_map_end
+      @protocol.write_map_end
+    end
+
+    def write_list_begin(etype, size)
+      @protocol.write_list_begin(etype, size)
+    end
+
+    def write_list_end
+      @protocol.write_list_end
+    end
+
+    def write_set_begin(etype, size)
+      @protocol.write_set_begin(etype, size)
+    end
+
+    def write_set_end
+      @protocol.write_set_end
+    end
+
+    def write_bool(bool)
+      @protocol.write_bool(bool)
+    end
+
+    def write_byte(byte)
+      @protocol.write_byte(byte)
+    end
+
+    def write_i16(i16)
+      @protocol.write_i16(i16)
+    end
+
+    def write_i32(i32)
+      @protocol.write_i32(i32)
+    end
+
+    def write_i64(i64)
+      @protocol.write_i64(i64)
+    end
+
+    def write_double(dub)
+      @protocol.write_double(dub)
+    end
+
+    def write_string(str)
+      @protocol.write_string(str)
+    end
+
+    def write_binary(buf)
+      @protocol.write_binary(buf)
+    end
+
+    def write_uuid(uuid)
+      @protocol.write_uuid(uuid)
+    end
+
+    # Read methods - delegate to underlying protocol.
+    # read_message_begin first lets the transport read the next frame and detect
+    # the peer's protocol, then switches the delegate protocol to match. An
+    # unsupported protocol ID is answered with an INVALID_PROTOCOL
+    # ApplicationException before the ProtocolException is re-raised.
+    # NOTE(review): that reply is encoded with the previously active protocol while
+    # the transport still advertises the unsupported protocol ID it just read.
+    def read_message_begin
+      begin
+        @header_transport.reset_protocol
+        reset_protocol_if_needed
+      rescue ProtocolException => ex
+        app_ex = ApplicationException.new(ApplicationException::INVALID_PROTOCOL, ex.message)
+        write_message_begin("", MessageTypes::EXCEPTION, 0)
+        app_ex.write(self)
+        write_message_end
+        @header_transport.flush
+        raise
+      end
+      @protocol.read_message_begin
+    end
+
+    def read_message_end
+      @protocol.read_message_end
+    end
+
+    def read_struct_begin
+      @protocol.read_struct_begin
+    end
+
+    def read_struct_end
+      @protocol.read_struct_end
+    end
+
+    def read_field_begin
+      @protocol.read_field_begin
+    end
+
+    def read_field_end
+      @protocol.read_field_end
+    end
+
+    def read_map_begin
+      @protocol.read_map_begin
+    end
+
+    def read_map_end
+      @protocol.read_map_end
+    end
+
+    def read_list_begin
+      @protocol.read_list_begin
+    end
+
+    def read_list_end
+      @protocol.read_list_end
+    end
+
+    def read_set_begin
+      @protocol.read_set_begin
+    end
+
+    def read_set_end
+      @protocol.read_set_end
+    end
+
+    def read_bool
+      @protocol.read_bool
+    end
+
+    def read_byte
+      @protocol.read_byte
+    end
+
+    def read_i16
+      @protocol.read_i16
+    end
+
+    def read_i32
+      @protocol.read_i32
+    end
+
+    def read_i64
+      @protocol.read_i64
+    end
+
+    def read_double
+      @protocol.read_double
+    end
+
+    def read_string
+      @protocol.read_string
+    end
+
+    def read_binary
+      @protocol.read_binary
+    end
+
+    def read_uuid
+      @protocol.read_uuid
+    end
+
+    def to_s
+      "header(#{@protocol.to_s})"
+    end
+
+    private
+
+    # Switches the delegate protocol when the transport's protocol ID has changed
+    # (e.g. after a frame from the peer was read in read_message_begin).
+    def reset_protocol_if_needed
+      new_protocol_id = @header_transport.protocol_id
+      if new_protocol_id != @current_protocol_id
+        @protocol = create_protocol(new_protocol_id)
+        @current_protocol_id = new_protocol_id
+      end
+    end
+
+    # Creates a protocol instance based on protocol ID
+    # @raise [ProtocolException] if the ID is neither BINARY nor COMPACT
+    def create_protocol(protocol_id)
+      case protocol_id
+      when HeaderSubprotocolID::BINARY
+        BinaryProtocol.new(@header_transport)
+      when HeaderSubprotocolID::COMPACT
+        CompactProtocol.new(@header_transport)
+      else
+        raise ProtocolException.new(
+          ProtocolException::INVALID_DATA,
+          "Unknown protocol ID: #{protocol_id}"
+        )
+      end
+    end
+  end
+
+  # Factory for creating HeaderProtocol instances
+  class HeaderProtocolFactory < BaseProtocolFactory
+    # Creates a new HeaderProtocolFactory.
+    #
+    # @param allowed_client_types [Array<Integer>] Allowed client types for auto-detection
+    # @param default_protocol [Integer] Default protocol ID (BINARY or COMPACT). Both
+    #   settings only apply when get_protocol wraps a non-Header transport. Note:
+    #   default_protocol defaults to BINARY, whereas HeaderProtocol.new uses COMPACT.
+    def initialize(allowed_client_types = nil, default_protocol = HeaderSubprotocolID::BINARY)
+      @allowed_client_types = allowed_client_types
+      @default_protocol = default_protocol
+    end
+
+    def get_protocol(trans)
+      HeaderProtocol.new(trans, @allowed_client_types, @default_protocol)
+    end
+
+    def to_s
+      "header"
+    end
+  end
+end
diff --git a/lib/rb/lib/thrift/transport/header_transport.rb b/lib/rb/lib/thrift/transport/header_transport.rb
new file mode 100644
index 0000000..54d3a6c
--- /dev/null
+++ b/lib/rb/lib/thrift/transport/header_transport.rb
@@ -0,0 +1,548 @@
+# encoding: ascii-8bit
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+require 'stringio'
+require 'zlib'
+
+module Thrift
+  # Client (framing) types HeaderTransport can detect on read: the THeader
+  # format itself plus legacy framed/unframed binary and compact.
+  module HeaderClientType
+    HEADERS = 0x00
+    FRAMED_BINARY = 0x01
+    UNFRAMED_BINARY = 0x02
+    FRAMED_COMPACT = 0x03
+    UNFRAMED_COMPACT = 0x04
+  end
+
+  # Subprotocol IDs carried in the THeader protocol-ID field
+  module HeaderSubprotocolID
+    BINARY = 0x00
+    COMPACT = 0x02
+  end
+
+  # Transform IDs carried in the THeader transform list (only ZLIB is supported)
+  module HeaderTransformID
+    ZLIB = 0x01
+  end
+
+  # Info-header type IDs (only key/value headers are understood)
+  module HeaderInfoType
+    KEY_VALUE = 0x01
+  end
+
+  # HeaderTransport implements the THeader framing protocol.
+  #
+  # THeader is a transport that adds headers and supports multiple protocols
+  # and transforms. It can auto-detect and communicate with legacy protocols
+  # (framed/unframed binary/compact) for backward compatibility: after a frame
+  # is read, replies are written with the framing that was detected.
+  #
+  # Wire format of a THeader frame (fixed-width integers are big-endian):
+  #
+  #   LENGTH           4 bytes   frame length, excluding these 4 bytes
+  #   HEADER MAGIC     2 bytes   0x0FFF
+  #   FLAGS            2 bytes
+  #   SEQUENCE NUMBER  4 bytes
+  #   HEADER SIZE      2 bytes   length of HEADER DATA divided by 4
+  #   HEADER DATA      variable  varint protocol ID, varint transform count + IDs,
+  #                              optional key/value info headers, zero-padded to 4 bytes
+  #   PAYLOAD          variable  serialized message, after transforms (e.g. ZLIB)
+  #
+  class HeaderTransport < BaseTransport
+    # Header magic value (first 2 bytes of header)
+    HEADER_MAGIC = 0x0FFF
+
+    # Maximum frame size (~1GB)
+    MAX_FRAME_SIZE = 0x3FFFFFFF
+
+    # Smallest frame accepted on read: every supported framed payload starts with
+    # a 4-byte word that identifies it (binary version, compact ID/version, or THeader magic + flags).
+    MIN_FRAME_SIZE = 4
+
+    # Exclusive upper bound on the HEADER SIZE field (in 32-bit words). Enforced on
+    # both read and write so this transport never emits a header it would reject.
+    HEADER_WORDS_LIMIT = 16_384
+
+    # Binary protocol version mask and version 1
+    BINARY_VERSION_MASK = 0xffff0000
+    BINARY_VERSION_1 = 0x80010000
+
+    # Compact protocol ID
+    COMPACT_PROTOCOL_ID = 0x82
+    COMPACT_VERSION_MASK = 0x1f
+    COMPACT_VERSION = 0x01
+
+    attr_reader :protocol_id, :sequence_id, :flags
+
+    # Creates a new HeaderTransport wrapping the given transport.
+    #
+    # @param transport [BaseTransport] The underlying transport to wrap
+    # @param allowed_client_types [Array<Integer>] Allowed client types for auto-detection.
+    #   Defaults to all types for backward compatibility.
+    # @param default_protocol [Integer] Default protocol ID (BINARY or COMPACT); used for
+    #   writes until a frame read from the peer sets a different protocol ID.
+    def initialize(transport, allowed_client_types = nil, default_protocol = HeaderSubprotocolID::COMPACT)
+      @transport = transport
+      @client_type = HeaderClientType::HEADERS
+      @protocol_id = default_protocol
+      @allowed_client_types = allowed_client_types || [
+        HeaderClientType::HEADERS,
+        HeaderClientType::FRAMED_BINARY,
+        HeaderClientType::UNFRAMED_BINARY,
+        HeaderClientType::FRAMED_COMPACT,
+        HeaderClientType::UNFRAMED_COMPACT
+      ]
+
+      @read_buffer = StringIO.new(Bytes.empty_byte_buffer)
+      @write_buffer = StringIO.new(Bytes.empty_byte_buffer)
+
+      @read_headers = {}
+      @write_headers = {}
+      @write_transforms = []
+
+      @sequence_id = 0
+      @flags = 0
+      @max_frame_size = MAX_FRAME_SIZE
+    end
+
+    def open?
+      @transport.open?
+    end
+
+    def open
+      @transport.open
+    end
+
+    def close
+      @transport.close
+    end
+
+    # Returns the key/value headers parsed from the most recent THeader frame
+    def get_headers
+      @read_headers
+    end
+
+    # Sets a header to be written with the next flush. Headers are only sent
+    # with THeader framing and are cleared once they have been written.
+    #
+    # @param key [#to_s] Header key; converted with to_s and forced to binary encoding
+    # @param value [#to_s] Header value; converted with to_s and forced to binary encoding
+    def set_header(key, value)
+      key = Bytes.force_binary_encoding(key.to_s)
+      value = Bytes.force_binary_encoding(value.to_s)
+      @write_headers[key] = value
+    end
+
+    # Clears all write headers
+    def clear_headers
+      @write_headers.clear
+    end
+
+    # Adds a transform to apply when writing THeader frames (duplicates are ignored)
+    #
+    # @param transform_id [Integer] Transform ID; only HeaderTransformID::ZLIB is supported
+    # @raise [TransportException] for any other transform ID
+    def add_transform(transform_id)
+      unless transform_id == HeaderTransformID::ZLIB
+        raise TransportException.new(TransportException::UNKNOWN, "Unknown transform: #{transform_id}")
+      end
+      @write_transforms << transform_id unless @write_transforms.include?(transform_id)
+    end
+
+    # Sets the maximum allowed frame size (checked on both read and write)
+    # @raise [ArgumentError] unless 0 < size <= MAX_FRAME_SIZE
+    def set_max_frame_size(size)
+      if size <= 0 || size > MAX_FRAME_SIZE
+        raise ArgumentError, "max_frame_size must be > 0 and <= #{MAX_FRAME_SIZE}"
+      end
+      @max_frame_size = size
+    end
+
+    # Reads up to +sz+ bytes, pulling in the next frame (or, for unframed peers,
+    # reading straight from the wrapped transport). May return fewer than +sz+ bytes.
+    def read(sz)
+      # Try reading from existing buffer
+      data = @read_buffer.read(sz)
+      data = Bytes.empty_byte_buffer if data.nil?
+
+      bytes_left = sz - data.bytesize
+      return data if bytes_left == 0
+
+      # Handle unframed passthrough - read directly from underlying transport
+      if @client_type == HeaderClientType::UNFRAMED_BINARY ||
+         @client_type == HeaderClientType::UNFRAMED_COMPACT
+        return data + @transport.read(bytes_left)
+      end
+
+      # Need to read the next frame
+      read_frame(bytes_left)
+      additional = @read_buffer.read(bytes_left)
+      data + (additional || Bytes.empty_byte_buffer)
+    end
+
+    # Buffers +buf+ until the next flush
+    def write(buf)
+      @write_buffer.write(Bytes.force_binary_encoding(buf))
+    end
+
+    # Sends buffered data framed to match the current client type: THeader unless a
+    # legacy framed/unframed peer was detected on read. Does nothing if the buffer is empty.
+    def flush
+      payload = @write_buffer.string
+      @write_buffer = StringIO.new(Bytes.empty_byte_buffer)
+
+      return if payload.empty?
+      if payload.bytesize > @max_frame_size
+        raise TransportException.new(TransportException::UNKNOWN, "Attempting to send frame that is too large")
+      end
+
+      case @client_type
+      when HeaderClientType::HEADERS
+        flush_header_format(payload)
+      when HeaderClientType::FRAMED_BINARY, HeaderClientType::FRAMED_COMPACT
+        flush_framed(payload)
+      when HeaderClientType::UNFRAMED_BINARY, HeaderClientType::UNFRAMED_COMPACT
+        @transport.write(payload)
+        @transport.flush
+      else
+        flush_header_format(payload)
+      end
+    end
+
+    def to_s
+      "header(#{@transport.to_s})"
+    end
+
+    # Reads the next frame to detect protocol/client type before decoding.
+    # Does nothing while unread data from the current frame remains buffered.
+    def reset_protocol
+      return unless @read_buffer.nil? || @read_buffer.eof?
+
+      read_frame(0)
+    end
+
+    private
+
+    # Records the detected client type, rejecting types not in @allowed_client_types
+    def set_client_type(client_type)
+      unless @allowed_client_types.include?(client_type)
+        raise TransportException.new(TransportException::UNKNOWN, "Client type #{client_type} not allowed by server")
+      end
+      @client_type = client_type
+    end
+
+    # Reads the next frame from the underlying transport into @read_buffer,
+    # detecting the peer's client type and protocol ID from its first bytes.
+    #
+    # @param req_sz [Integer] bytes the caller needs; only used for unframed
+    #   peers, where that many bytes are read through after the first word.
+    def read_frame(req_sz)
+      # Read first 4 bytes - could be frame length or protocol magic
+      first_word = @transport.read_all(4)
+      frame_size = first_word.unpack('N').first
+
+      # Check for unframed binary protocol
+      if (frame_size & BINARY_VERSION_MASK) == BINARY_VERSION_1
+        set_client_type(HeaderClientType::UNFRAMED_BINARY)
+        @protocol_id = HeaderSubprotocolID::BINARY
+        handle_unframed(first_word, req_sz)
+        return
+      end
+
+      # Check for unframed compact protocol
+      if Bytes.get_string_byte(first_word, 0) == COMPACT_PROTOCOL_ID &&
+         (Bytes.get_string_byte(first_word, 1) & COMPACT_VERSION_MASK) == COMPACT_VERSION
+        set_client_type(HeaderClientType::UNFRAMED_COMPACT)
+        @protocol_id = HeaderSubprotocolID::COMPACT
+        handle_unframed(first_word, req_sz)
+        return
+      end
+
+      # It's a framed protocol - validate frame size. A frame shorter than
+      # MIN_FRAME_SIZE has no complete second word to inspect below.
+      if frame_size < MIN_FRAME_SIZE
+        raise TransportException.new(TransportException::UNKNOWN, "Frame size #{frame_size} is too small")
+      end
+      if frame_size > @max_frame_size
+        raise TransportException.new(TransportException::UNKNOWN, "Frame size #{frame_size} exceeds maximum #{@max_frame_size}")
+      end
+
+      # Read the complete frame
+      frame_data = @transport.read_all(frame_size)
+      frame_buf = StringIO.new(frame_data)
+
+      # Check the second word for protocol type
+      second_word = frame_buf.read(4)
+      frame_buf.rewind
+
+      magic = second_word.unpack('n').first
+
+      if magic == HEADER_MAGIC
+        if frame_size < 10
+          raise TransportException.new(TransportException::UNKNOWN, "Header transport frame is too small")
+        end
+        set_client_type(HeaderClientType::HEADERS)
+        @read_buffer = parse_header_format(frame_buf)
+      elsif (second_word.unpack('N').first & BINARY_VERSION_MASK) == BINARY_VERSION_1
+        set_client_type(HeaderClientType::FRAMED_BINARY)
+        @protocol_id = HeaderSubprotocolID::BINARY
+        @read_buffer = frame_buf
+      elsif Bytes.get_string_byte(second_word, 0) == COMPACT_PROTOCOL_ID &&
+            (Bytes.get_string_byte(second_word, 1) & COMPACT_VERSION_MASK) == COMPACT_VERSION
+        set_client_type(HeaderClientType::FRAMED_COMPACT)
+        @protocol_id = HeaderSubprotocolID::COMPACT
+        @read_buffer = frame_buf
+      else
+        raise TransportException.new(TransportException::UNKNOWN, "Could not detect client transport type")
+      end
+    end
+
+    # Handles unframed protocol - puts first_word back in buffer
+    def handle_unframed(first_word, req_sz)
+      bytes_left = req_sz - 4
+      if bytes_left > 0
+        rest = @transport.read(bytes_left)
+        @read_buffer = StringIO.new(first_word + rest)
+      else
+        @read_buffer = StringIO.new(first_word)
+      end
+    end
+
+    # Parses a THeader-format frame from +buf+ (positioned at the magic), storing
+    # flags, sequence ID, protocol ID and info headers; returns a StringIO over the payload.
+    def parse_header_format(buf)
+      # Skip magic (already identified)
+      buf.read(2)
+
+      # Read flags and sequence ID
+      @flags = buf.read(2).unpack('n').first
+      @sequence_id = buf.read(4).unpack('N').first
+
+      # Read header length (in 32-bit words)
+      header_words = buf.read(2).unpack('n').first
+      if header_words >= HEADER_WORDS_LIMIT
+        raise TransportException.new(TransportException::UNKNOWN, "Header size is unreasonable")
+      end
+      header_length = header_words * 4
+      end_of_headers = buf.pos + header_length
+
+      if end_of_headers > buf.string.bytesize
+        raise TransportException.new(TransportException::UNKNOWN, "Header size exceeds frame size")
+      end
+
+      # Read protocol ID
+      @protocol_id = read_varint32(buf, end_of_headers)
+
+      # Read transforms
+      transforms = []
+      transform_count = read_varint32(buf, end_of_headers)
+      transform_count.times do
+        transform_id = read_varint32(buf, end_of_headers)
+        unless transform_id == HeaderTransformID::ZLIB
+          raise TransportException.new(TransportException::UNKNOWN, "Unknown transform: #{transform_id}")
+        end
+        transforms << transform_id
+      end
+      # Read info headers
+      @read_headers = {}
+      while buf.pos < end_of_headers
+        info_type = read_varint32(buf, end_of_headers)
+        if info_type == 0
+          # header padding
+          break
+        elsif info_type == HeaderInfoType::KEY_VALUE
+          count = read_varint32(buf, end_of_headers)
+          count.times do
+            key = read_varstring(buf, end_of_headers)
+            value = read_varstring(buf, end_of_headers)
+            @read_headers[key] = value
+          end
+        else
+          # Unknown info type, skip to end of headers
+          break
+        end
+      end
+
+      # Skip any remaining header padding
+      buf.pos = end_of_headers
+
+      # Read payload and undo transforms
+      payload = buf.read
+      begin
+        transforms.each do |transform_id|
+          payload = Zlib::Inflate.inflate(payload) if transform_id == HeaderTransformID::ZLIB
+        end
+      rescue Zlib::Error => e
+        # Report corrupt/truncated compressed data as a TransportException, like the
+        # other malformed-frame errors in this class, instead of leaking Zlib::Error.
+        raise TransportException.new(TransportException::UNKNOWN, "Failed to decompress payload: #{e.message}")
+      end
+
+      StringIO.new(payload)
+    end
+
+    # Flushes data in THeader format, applying write transforms to the payload.
+    # Write headers are cleared once they have been serialized.
+    def flush_header_format(payload)
+      # Apply transforms
+      @write_transforms.each do |transform_id|
+        if transform_id == HeaderTransformID::ZLIB
+          payload = Zlib::Deflate.deflate(payload)
+        end
+      end
+
+      # Build header data
+      header_buf = StringIO.new(Bytes.empty_byte_buffer)
+
+      # Protocol ID
+      write_varint32(header_buf, @protocol_id)
+
+      # Transforms
+      write_varint32(header_buf, @write_transforms.size)
+      @write_transforms.each { |t| write_varint32(header_buf, t) }
+
+      # Info headers (key-value pairs)
+      unless @write_headers.empty?
+        write_varint32(header_buf, HeaderInfoType::KEY_VALUE)
+        write_varint32(header_buf, @write_headers.size)
+        @write_headers.each do |key, value|
+          write_varstring(header_buf, key)
+          write_varstring(header_buf, value)
+        end
+        @write_headers = {}
+      end
+
+      # Pad header to 4-byte boundary
+      header_data = header_buf.string
+      padding = (4 - (header_data.bytesize % 4)) % 4
+      header_data += "\x00" * padding
+      header_words = header_data.bytesize / 4
+      # HEADER SIZE is only 16 bits and our own reader rejects >= HEADER_WORDS_LIMIT words.
+      if header_words >= HEADER_WORDS_LIMIT
+        raise TransportException.new(TransportException::UNKNOWN, "Header size is unreasonable")
+      end
+
+      # Calculate total frame size (excludes the 4-byte length field itself)
+      # Frame = magic(2) + flags(2) + seqid(4) + header_len(2) + header_data + payload
+      frame_size = 2 + 2 + 4 + 2 + header_data.bytesize + payload.bytesize
+      # flush only checked the untransformed payload; re-check the real frame size.
+      if frame_size > @max_frame_size
+        raise TransportException.new(TransportException::UNKNOWN, "Attempting to send frame that is too large")
+      end
+
+      # Write complete frame
+      frame = Bytes.empty_byte_buffer
+      frame << [frame_size].pack('N')   # Length
+      frame << [HEADER_MAGIC].pack('n') # Magic
+      frame << [@flags].pack('n')       # Flags
+      frame << [@sequence_id].pack('N') # Sequence ID
+      frame << [header_words].pack('n') # Header length (in 32-bit words)
+      frame << header_data              # Header data
+      frame << payload                  # Payload
+
+      @transport.write(frame)
+      @transport.flush
+    end
+
+    # Flushes data in simple framed format (for legacy compatibility)
+    def flush_framed(payload)
+      frame = [payload.bytesize].pack('N') + payload
+      @transport.write(frame)
+      @transport.flush
+    end
+
+    # Reads a varint32 from +io+: 7 bits per byte, least-significant group first,
+    # high bit set on every byte except the last. Never reads at or past
+    # +boundary_pos+ when one is given.
+    def read_varint32(io, boundary_pos = nil)
+      shift = 0
+      result = 0
+      loop do
+        if boundary_pos && io.pos >= boundary_pos
+          raise TransportException.new(TransportException::UNKNOWN, "Trying to read past header boundary")
+        end
+        byte = io.getbyte
+        raise TransportException.new(TransportException::END_OF_FILE, "Unexpected EOF reading varint") if byte.nil?
+        result |= (byte & 0x7f) << shift
+        break if (byte & 0x80) == 0
+        shift += 7
+        # A 32-bit value needs at most 5 bytes (shifts 0..28); anything longer is malformed.
+        raise TransportException.new(TransportException::UNKNOWN, "Varint32 is too long") if shift > 28
+      end
+      result
+    end
+
+    # Writes +n+ to +io+ using the varint32 encoding read by read_varint32.
+    # +n+ must be non-negative; a negative value would make the loop below spin forever.
+    def write_varint32(io, n)
+      raise ArgumentError, "varint32 value must be non-negative: #{n}" if n < 0
+      loop do
+        if (n & ~0x7F) == 0
+          io.write([n].pack('C'))
+          break
+        else
+          io.write([(n & 0x7F) | 0x80].pack('C'))
+          n >>= 7
+        end
+      end
+    end
+
+    # Reads a varint-prefixed string from the given IO
+    def read_varstring(io, boundary_pos = nil)
+      size = read_varint32(io, boundary_pos)
+      if size < 0
+        raise TransportException.new(TransportException::UNKNOWN, "Negative string size: #{size}")
+      end
+      if boundary_pos && size > (boundary_pos - io.pos)
+        raise TransportException.new(TransportException::UNKNOWN, "Info header length exceeds header size")
+      end
+      data = io.read(size)
+      if data.nil? || data.bytesize < size
+        raise TransportException.new(TransportException::END_OF_FILE, "Unexpected EOF reading string")
+      end
+      data
+    end
+
+    # Writes a varint-prefixed string to the given IO
+    def write_varstring(io, value)
+      value = Bytes.force_binary_encoding(value)
+      write_varint32(io, value.bytesize)
+      io.write(value)
+    end
+  end
+
+  # Factory for creating HeaderTransport instances. Note: it defaults to BINARY,
+  # whereas HeaderTransport.new defaults to COMPACT.
+  class HeaderTransportFactory < BaseTransportFactory
+    def initialize(allowed_client_types = nil, default_protocol = HeaderSubprotocolID::BINARY)
+      @allowed_client_types = allowed_client_types
+      @default_protocol = default_protocol
+    end
+
+    def get_transport(transport)
+      HeaderTransport.new(transport, @allowed_client_types, @default_protocol)
+    end
+
+    def to_s
+      "header"
+    end
+  end
+end
diff --git a/lib/rb/spec/header_protocol_spec.rb b/lib/rb/spec/header_protocol_spec.rb
new file mode 100644
index 0000000..3feb9b6
--- /dev/null
+++ b/lib/rb/spec/header_protocol_spec.rb
@@ -0,0 +1,466 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+require 'spec_helper'
+require_relative 'support/header_protocol_helper'
+
+describe 'HeaderProtocol' do
+ include HeaderProtocolHelper
+
+ describe Thrift::HeaderProtocol do
+ # Each example gets a fresh in-memory buffer wrapped by a HeaderProtocol
+ # built with the constructor defaults.
+ before(:each) do
+ @buffer = Thrift::MemoryBufferTransport.new
+ @protocol = Thrift::HeaderProtocol.new(@buffer)
+ end
+
+ # With no explicit protocol id, to_s reports a compact subprotocol.
+ it "should provide a to_s" do
+ expect(@protocol.to_s).to match(/header\(compact/)
+ end
+
+ # A plain transport passed to the constructor ends up wrapped in a HeaderTransport.
+ it "should wrap transport in HeaderTransport" do
+ expect(@protocol.trans).to be_a(Thrift::HeaderTransport)
+ end
+
+ # An existing HeaderTransport is used as-is (identity check), not wrapped again.
+ it "should use existing HeaderTransport if passed" do
+ header_trans = Thrift::HeaderTransport.new(@buffer)
+ protocol = Thrift::HeaderProtocol.new(header_trans)
+ expect(protocol.trans).to equal(header_trans)
+ end
+
+ # HeaderProtocol forwards header/transform management to its transport
+ # (@protocol.trans).
+ describe "header management delegation" do
+ # get_headers is checked end to end: write a frame carrying a header,
+ # then read it with a second protocol and look the header up.
+ it "should delegate get_headers" do
+ # Write with headers and read back to populate read headers
+ @protocol.set_header("key", "value")
+ @protocol.write_message_begin("test", Thrift::MessageTypes::CALL, 1)
+ @protocol.write_struct_begin("args")
+ @protocol.write_field_stop
+ @protocol.write_struct_end
+ @protocol.write_message_end
+ @protocol.trans.flush
+
+ # Read back
+ data = @buffer.read(@buffer.available)
+ read_buffer = Thrift::MemoryBufferTransport.new(data)
+ read_protocol = Thrift::HeaderProtocol.new(read_buffer)
+
+ read_protocol.read_message_begin
+ headers = read_protocol.get_headers
+ expect(headers["key"]).to eq("value")
+ end
+
+ # The remaining examples only verify that each call reaches the transport with the same arguments.
+ it "should delegate set_header" do
+ expect(@protocol.trans).to receive(:set_header).with("key", "value")
+ @protocol.set_header("key", "value")
+ end
+
+ it "should delegate clear_headers" do
+ expect(@protocol.trans).to receive(:clear_headers)
+ @protocol.clear_headers
+ end
+
+ it "should delegate add_transform" do
+ expect(@protocol.trans).to receive(:add_transform).with(Thrift::HeaderTransformID::ZLIB)
+ @protocol.add_transform(Thrift::HeaderTransformID::ZLIB)
+ end
+ end
+
+ # Round-trip examples with the writer explicitly set to the BINARY
+ # subprotocol. Readers are built with constructor defaults, so these
+ # examples rely on the reader picking the subprotocol up from the frame.
+ describe "protocol delegation with Binary protocol" do
+ before(:each) do
+ @buffer = Thrift::MemoryBufferTransport.new
+ @protocol = Thrift::HeaderProtocol.new(@buffer, nil, Thrift::HeaderSubprotocolID::BINARY)
+ end
+
+ # A bare message envelope (no struct) survives the round trip.
+ it "should write message begin" do
+ @protocol.write_message_begin("test_method", Thrift::MessageTypes::CALL, 123)
+ @protocol.write_message_end
+ @protocol.trans.flush
+
+ # Verify we can read it back
+ data = @buffer.read(@buffer.available)
+ read_buffer = Thrift::MemoryBufferTransport.new(data)
+ read_protocol = Thrift::HeaderProtocol.new(read_buffer)
+
+ name, type, seqid = read_protocol.read_message_begin
+ expect(name).to eq("test_method")
+ expect(type).to eq(Thrift::MessageTypes::CALL)
+ expect(seqid).to eq(123)
+ end
+
+ # A struct with a single i32 field survives the round trip.
+ it "should write and read structs" do
+ @protocol.write_message_begin("test", Thrift::MessageTypes::CALL, 1)
+ @protocol.write_struct_begin("TestStruct")
+ @protocol.write_field_begin("field1", Thrift::Types::I32, 1)
+ @protocol.write_i32(42)
+ @protocol.write_field_end
+ @protocol.write_field_stop
+ @protocol.write_struct_end
+ @protocol.write_message_end
+ @protocol.trans.flush
+
+ data = @buffer.read(@buffer.available)
+ read_buffer = Thrift::MemoryBufferTransport.new(data)
+ read_protocol = Thrift::HeaderProtocol.new(read_buffer)
+
+ read_protocol.read_message_begin
+ read_protocol.read_struct_begin
+ name, type, id = read_protocol.read_field_begin
+ expect(type).to eq(Thrift::Types::I32)
+ expect(id).to eq(1)
+ value = read_protocol.read_i32
+ expect(value).to eq(42)
+ end
+
+ # One field per primitive type, mostly at the positive edge of each
+ # integer type; fields are read back in write order.
+ it "should write and read all primitive types" do
+ @protocol.write_message_begin("test", Thrift::MessageTypes::CALL, 1)
+ @protocol.write_struct_begin("TestStruct")
+
+ @protocol.write_field_begin("bool", Thrift::Types::BOOL, 1)
+ @protocol.write_bool(true)
+ @protocol.write_field_end
+
+ @protocol.write_field_begin("byte", Thrift::Types::BYTE, 2)
+ @protocol.write_byte(127)
+ @protocol.write_field_end
+
+ @protocol.write_field_begin("i16", Thrift::Types::I16, 3)
+ @protocol.write_i16(32767)
+ @protocol.write_field_end
+
+ @protocol.write_field_begin("i32", Thrift::Types::I32, 4)
+ @protocol.write_i32(2147483647)
+ @protocol.write_field_end
+
+ @protocol.write_field_begin("i64", Thrift::Types::I64, 5)
+ @protocol.write_i64(9223372036854775807)
+ @protocol.write_field_end
+
+ @protocol.write_field_begin("double", Thrift::Types::DOUBLE, 6)
+ @protocol.write_double(3.14159)
+ @protocol.write_field_end
+
+ @protocol.write_field_begin("string", Thrift::Types::STRING, 7)
+ @protocol.write_string("hello")
+ @protocol.write_field_end
+
+ @protocol.write_field_stop
+ @protocol.write_struct_end
+ @protocol.write_message_end
+ @protocol.trans.flush
+
+ data = @buffer.read(@buffer.available)
+ read_buffer = Thrift::MemoryBufferTransport.new(data)
+ read_protocol = Thrift::HeaderProtocol.new(read_buffer)
+
+ read_protocol.read_message_begin
+ read_protocol.read_struct_begin
+
+ # bool
+ _, type, _ = read_protocol.read_field_begin
+ expect(type).to eq(Thrift::Types::BOOL)
+ expect(read_protocol.read_bool).to eq(true)
+ read_protocol.read_field_end
+
+ # byte
+ _, type, _ = read_protocol.read_field_begin
+ expect(type).to eq(Thrift::Types::BYTE)
+ expect(read_protocol.read_byte).to eq(127)
+ read_protocol.read_field_end
+
+ # i16
+ _, type, _ = read_protocol.read_field_begin
+ expect(type).to eq(Thrift::Types::I16)
+ expect(read_protocol.read_i16).to eq(32767)
+ read_protocol.read_field_end
+
+ # i32
+ _, type, _ = read_protocol.read_field_begin
+ expect(type).to eq(Thrift::Types::I32)
+ expect(read_protocol.read_i32).to eq(2147483647)
+ read_protocol.read_field_end
+
+ # i64
+ _, type, _ = read_protocol.read_field_begin
+ expect(type).to eq(Thrift::Types::I64)
+ expect(read_protocol.read_i64).to eq(9223372036854775807)
+ read_protocol.read_field_end
+
+ # double
+ _, type, _ = read_protocol.read_field_begin
+ expect(type).to eq(Thrift::Types::DOUBLE)
+ expect(read_protocol.read_double).to be_within(0.00001).of(3.14159)
+ read_protocol.read_field_end
+
+ # string
+ _, type, _ = read_protocol.read_field_begin
+ expect(type).to eq(Thrift::Types::STRING)
+ expect(read_protocol.read_string).to eq("hello")
+ read_protocol.read_field_end
+ end
+ end
+
+ # Same round trip as above, with COMPACT selected explicitly on both the
+ # writer and the reader.
+ describe "protocol delegation with Compact protocol" do
+ before(:each) do
+ @buffer = Thrift::MemoryBufferTransport.new
+ @protocol = Thrift::HeaderProtocol.new(
+ @buffer,
+ nil,
+ Thrift::HeaderSubprotocolID::COMPACT
+ )
+ end
+
+ # NOTE(review): an unconfigured HeaderProtocol also reports compact, so
+ # this check alone does not prove the explicit id is honoured.
+ it "should use Compact protocol" do
+ expect(@protocol.to_s).to match(/header\(compact/)
+ end
+
+ it "should write and read with Compact protocol" do
+ @protocol.write_message_begin("test", Thrift::MessageTypes::CALL, 1)
+ @protocol.write_struct_begin("Test")
+ @protocol.write_field_begin("field", Thrift::Types::I32, 1)
+ @protocol.write_i32(999)
+ @protocol.write_field_end
+ @protocol.write_field_stop
+ @protocol.write_struct_end
+ @protocol.write_message_end
+ @protocol.trans.flush
+
+ data = @buffer.read(@buffer.available)
+ read_buffer = Thrift::MemoryBufferTransport.new(data)
+ read_protocol = Thrift::HeaderProtocol.new(read_buffer, nil, Thrift::HeaderSubprotocolID::COMPACT)
+
+ read_protocol.read_message_begin
+ read_protocol.read_struct_begin
+ _, type, _ = read_protocol.read_field_begin
+ expect(type).to eq(Thrift::Types::I32)
+ expect(read_protocol.read_i32).to eq(999)
+ end
+ end
+
+ describe "unknown protocol handling" do
+ # Hand-builds a Header frame whose header declares subprotocol id 0x10
+ # with zero transforms. Reading it must raise ProtocolException, and a
+ # Header-framed response (an exception reply, per the example name) must
+ # be written back to the same buffer. The response is checked only for
+ # non-emptiness and for HEADER_MAGIC right after the 4-byte length.
+ it "should write an exception response on unknown protocol id" do
+ header_data = +""
+ header_data << varint32(0x10)
+ header_data << varint32(0)
+ frame = build_header_frame(header_data)
+
+ buffer = Thrift::MemoryBufferTransport.new(frame)
+ protocol = Thrift::HeaderProtocol.new(buffer)
+
+ expect { protocol.read_message_begin }.to raise_error(Thrift::ProtocolException)
+
+ response = buffer.read(buffer.available)
+ expect(response.bytesize).to be > 0
+ magic = response[4, 2].unpack('n').first
+ expect(magic).to eq(Thrift::HeaderTransport::HEADER_MAGIC)
+ end
+ end
+
+ # Messages written by a plain CompactProtocol (no Header framing) must
+ # still be readable through a default HeaderProtocol.
+ describe "protocol auto-detection with legacy frames" do
+ # Compact payload prefixed by a 4-byte big-endian length (framed transport style).
+ it "should detect framed compact messages" do
+ write_buffer = Thrift::MemoryBufferTransport.new
+ write_protocol = Thrift::CompactProtocol.new(write_buffer)
+
+ write_protocol.write_message_begin("legacy_framed", Thrift::MessageTypes::CALL, 7)
+ write_protocol.write_struct_begin("Args")
+ write_protocol.write_field_stop
+ write_protocol.write_struct_end
+ write_protocol.write_message_end
+
+ payload = write_buffer.read(write_buffer.available)
+ framed = [payload.bytesize].pack('N') + payload
+
+ read_buffer = Thrift::MemoryBufferTransport.new(framed)
+ protocol = Thrift::HeaderProtocol.new(read_buffer)
+
+ name, type, seqid = protocol.read_message_begin
+ expect(name).to eq("legacy_framed")
+ expect(type).to eq(Thrift::MessageTypes::CALL)
+ expect(seqid).to eq(7)
+
+ protocol.read_struct_begin
+ _, field_type, _ = protocol.read_field_begin
+ expect(field_type).to eq(Thrift::Types::STOP)
+ protocol.read_struct_end
+ protocol.read_message_end
+ end
+
+ # Raw compact payload with no length prefix at all.
+ it "should detect unframed compact messages" do
+ write_buffer = Thrift::MemoryBufferTransport.new
+ write_protocol = Thrift::CompactProtocol.new(write_buffer)
+
+ write_protocol.write_message_begin("legacy_unframed", Thrift::MessageTypes::CALL, 9)
+ write_protocol.write_struct_begin("Args")
+ write_protocol.write_field_stop
+ write_protocol.write_struct_end
+ write_protocol.write_message_end
+
+ payload = write_buffer.read(write_buffer.available)
+
+ read_buffer = Thrift::MemoryBufferTransport.new(payload)
+ protocol = Thrift::HeaderProtocol.new(read_buffer)
+
+ name, type, seqid = protocol.read_message_begin
+ expect(name).to eq("legacy_unframed")
+ expect(type).to eq(Thrift::MessageTypes::CALL)
+ expect(seqid).to eq(9)
+
+ protocol.read_struct_begin
+ _, field_type, _ = protocol.read_field_begin
+ expect(field_type).to eq(Thrift::Types::STOP)
+ protocol.read_struct_end
+ protocol.read_message_end
+ end
+ end
+
+ describe "with compression" do
+ it "should work with ZLIB transform" do
+ @protocol.add_transform(Thrift::HeaderTransformID::ZLIB)
+
+ @protocol.write_message_begin("compressed_test", Thrift::MessageTypes::CALL, 42)
+ @protocol.write_struct_begin("Args")
+ @protocol.write_field_begin("data", Thrift::Types::STRING, 1)
+ @protocol.write_string("a" * 100) # Compressible data
+ @protocol.write_field_end
+ @protocol.write_field_stop
+ @protocol.write_struct_end
+ @protocol.write_message_end
+ @protocol.trans.flush
+
+ data = @buffer.read(@buffer.available)
+ read_buffer = Thrift::MemoryBufferTransport.new(data)
+ read_protocol = Thrift::HeaderProtocol.new(read_buffer)
+
+ name, type, seqid = read_protocol.read_message_begin
+ expect(name).to eq("compressed_test")
+ expect(seqid).to eq(42)
+
+ read_protocol.read_struct_begin
+ _, _, _ = read_protocol.read_field_begin
+ result = read_protocol.read_string
+ expect(result).to eq("a" * 100)
+ end
+ end
+
+ describe "containers" do
+ it "should write and read lists" do
+ @protocol.write_message_begin("test", Thrift::MessageTypes::CALL, 1)
+ @protocol.write_struct_begin("Test")
+ @protocol.write_field_begin("list", Thrift::Types::LIST, 1)
+ @protocol.write_list_begin(Thrift::Types::I32, 3)
+ @protocol.write_i32(1)
+ @protocol.write_i32(2)
+ @protocol.write_i32(3)
+ @protocol.write_list_end
+ @protocol.write_field_end
+ @protocol.write_field_stop
+ @protocol.write_struct_end
+ @protocol.write_message_end
+ @protocol.trans.flush
+
+ data = @buffer.read(@buffer.available)
+ read_buffer = Thrift::MemoryBufferTransport.new(data)
+ read_protocol = Thrift::HeaderProtocol.new(read_buffer)
+
+ read_protocol.read_message_begin
+ read_protocol.read_struct_begin
+ _, _, _ = read_protocol.read_field_begin
+ etype, size = read_protocol.read_list_begin
+ expect(etype).to eq(Thrift::Types::I32)
+ expect(size).to eq(3)
+ expect(read_protocol.read_i32).to eq(1)
+ expect(read_protocol.read_i32).to eq(2)
+ expect(read_protocol.read_i32).to eq(3)
+ end
+
+ it "should write and read maps" do
+ @protocol.write_message_begin("test", Thrift::MessageTypes::CALL, 1)
+ @protocol.write_struct_begin("Test")
+ @protocol.write_field_begin("map", Thrift::Types::MAP, 1)
+ @protocol.write_map_begin(Thrift::Types::STRING, Thrift::Types::I32, 2)
+ @protocol.write_string("a")
+ @protocol.write_i32(1)
+ @protocol.write_string("b")
+ @protocol.write_i32(2)
+ @protocol.write_map_end
+ @protocol.write_field_end
+ @protocol.write_field_stop
+ @protocol.write_struct_end
+ @protocol.write_message_end
+ @protocol.trans.flush
+
+ data = @buffer.read(@buffer.available)
+ read_buffer = Thrift::MemoryBufferTransport.new(data)
+ read_protocol = Thrift::HeaderProtocol.new(read_buffer)
+
+ read_protocol.read_message_begin
+ read_protocol.read_struct_begin
+ _, _, _ = read_protocol.read_field_begin
+ ktype, vtype, size = read_protocol.read_map_begin
+ expect(ktype).to eq(Thrift::Types::STRING)
+ expect(vtype).to eq(Thrift::Types::I32)
+ expect(size).to eq(2)
+ end
+
+ it "should write and read sets" do
+ @protocol.write_message_begin("test", Thrift::MessageTypes::CALL, 1)
+ @protocol.write_struct_begin("Test")
+ @protocol.write_field_begin("set", Thrift::Types::SET, 1)
+ @protocol.write_set_begin(Thrift::Types::STRING, 2)
+ @protocol.write_string("x")
+ @protocol.write_string("y")
+ @protocol.write_set_end
+ @protocol.write_field_end
+ @protocol.write_field_stop
+ @protocol.write_struct_end
+ @protocol.write_message_end
+ @protocol.trans.flush
+
+ data = @buffer.read(@buffer.available)
+ read_buffer = Thrift::MemoryBufferTransport.new(data)
+ read_protocol = Thrift::HeaderProtocol.new(read_buffer)
+
+ read_protocol.read_message_begin
+ read_protocol.read_struct_begin
+ _, _, _ = read_protocol.read_field_begin
+ etype, size = read_protocol.read_set_begin
+ expect(etype).to eq(Thrift::Types::STRING)
+ expect(size).to eq(2)
+ end
+ end
+ end
+
+ describe Thrift::HeaderProtocolFactory do
+ it "should create HeaderProtocol" do
+ factory = Thrift::HeaderProtocolFactory.new
+ buffer = Thrift::MemoryBufferTransport.new
+ protocol = factory.get_protocol(buffer)
+ expect(protocol).to be_a(Thrift::HeaderProtocol)
+ end
+
+ it "should provide a reasonable to_s" do
+ expect(Thrift::HeaderProtocolFactory.new.to_s).to eq("header")
+ end
+
+ it "should pass configuration to protocol" do
+ factory = Thrift::HeaderProtocolFactory.new(nil, Thrift::HeaderSubprotocolID::COMPACT)
+ buffer = Thrift::MemoryBufferTransport.new
+ protocol = factory.get_protocol(buffer)
+ expect(protocol.to_s).to match(/compact/)
+ end
+ end
+end
diff --git a/lib/rb/spec/header_transport_spec.rb b/lib/rb/spec/header_transport_spec.rb
new file mode 100644
index 0000000..2857e15
--- /dev/null
+++ b/lib/rb/spec/header_transport_spec.rb
@@ -0,0 +1,364 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+require 'spec_helper'
+require_relative 'support/header_protocol_helper'
+
+describe 'HeaderTransport' do
+ include HeaderProtocolHelper
+
+ describe Thrift::HeaderClientType do
+ it "should define client type constants" do
+ expect(Thrift::HeaderClientType::HEADERS).to eq(0x00)
+ expect(Thrift::HeaderClientType::FRAMED_BINARY).to eq(0x01)
+ expect(Thrift::HeaderClientType::UNFRAMED_BINARY).to eq(0x02)
+ expect(Thrift::HeaderClientType::FRAMED_COMPACT).to eq(0x03)
+ expect(Thrift::HeaderClientType::UNFRAMED_COMPACT).to eq(0x04)
+ end
+ end
+
+ describe Thrift::HeaderSubprotocolID do
+ it "should define protocol ID constants" do
+ expect(Thrift::HeaderSubprotocolID::BINARY).to eq(0x00)
+ expect(Thrift::HeaderSubprotocolID::COMPACT).to eq(0x02)
+ end
+ end
+
+ describe Thrift::HeaderTransformID do
+ it "should define transform ID constants" do
+ expect(Thrift::HeaderTransformID::ZLIB).to eq(0x01)
+ end
+ end
+
+ describe Thrift::HeaderTransport do
+ # Each example writes through a HeaderTransport over a fresh in-memory
+ # buffer; @underlying is kept so examples can inspect the raw bytes.
+ before(:each) do
+ @underlying = Thrift::MemoryBufferTransport.new
+ @trans = Thrift::HeaderTransport.new(@underlying)
+ end
+
+ it "should provide a to_s that describes the encapsulation" do
+ expect(@trans.to_s).to eq("header(memory)")
+ end
+
+ # open?/open/close are each expected to reach the wrapped transport once.
+ it "should pass through open?/open/close" do
+ mock_transport = double("Transport")
+ expect(mock_transport).to receive(:open?).and_return(true)
+ expect(mock_transport).to receive(:open).and_return(nil)
+ expect(mock_transport).to receive(:close).and_return(nil)
+
+ trans = Thrift::HeaderTransport.new(mock_transport)
+ expect(trans.open?).to be true
+ trans.open
+ trans.close
+ end
+
+ describe "header management" do
+ it "should allow setting and getting headers" do
+ @trans.set_header("key1", "value1")
+ @trans.set_header("key2", "value2")
+ # Headers aren't read until we receive data, so write and read back
+ expect(@trans.get_headers).to eq({})
+ end
+
+ it "should clear headers" do
+ @trans.set_header("key1", "value1")
+ @trans.clear_headers
+ # Write and flush to verify headers were cleared
+ @trans.write("test")
+ @trans.flush
+ end
+
+ it "should add transforms" do
+ expect { @trans.add_transform(Thrift::HeaderTransformID::ZLIB) }.not_to raise_error
+ end
+
+ it "should reject unknown transforms" do
+ expect { @trans.add_transform(999) }.to raise_error(Thrift::TransportException)
+ end
+ end
+
+ # Writes are held until flush, which emits a single length-prefixed Header frame.
+ describe "write and flush" do
+ it "should buffer writes" do
+ @trans.write("hello")
+ @trans.write(" world")
+ expect(@underlying.available).to eq(0)
+ end
+
+ # Checks the outer framing: the big-endian length prefix covers every
+ # byte after itself, and HEADER_MAGIC follows immediately.
+ it "should write Header format on flush" do
+ @trans.write("test payload")
+ @trans.flush
+
+ # Read back the frame
+ data = @underlying.read(@underlying.available)
+
+ # Should have frame length (4 bytes) + header + payload
+ expect(data.bytesize).to be > 16
+
+ # First 4 bytes are frame length
+ frame_size = data[0, 4].unpack('N').first
+ expect(frame_size).to eq(data.bytesize - 4)
+
+ # Next 2 bytes should be header magic
+ magic = data[4, 2].unpack('n').first
+ expect(magic).to eq(Thrift::HeaderTransport::HEADER_MAGIC)
+ end
+
+ # Size-only check; the header content itself is verified by the read-side examples.
+ it "should include headers in frame" do
+ @trans.set_header("test-key", "test-value")
+ @trans.write("payload")
+ @trans.flush
+
+ # Read back and verify it's larger due to headers
+ data = @underlying.read(@underlying.available)
+ expect(data.bytesize).to be > 30 # Should include header key-value
+ end
+
+ # The whole frame, framing overhead included, must be smaller than the 1000-byte input.
+ it "should apply ZLIB transform" do
+ @trans.add_transform(Thrift::HeaderTransformID::ZLIB)
+ original_payload = "a" * 1000 # Compressible data
+ @trans.write(original_payload)
+ @trans.flush
+
+ data = @underlying.read(@underlying.available)
+ # Compressed frame should be smaller than uncompressed
+ expect(data.bytesize).to be < original_payload.bytesize
+ end
+ end
+
+ describe "frame size limits" do
+ # A 5-byte payload against a 4-byte limit must fail at flush time, not at write time.
+ it "should reject payloads larger than max frame size" do
+ @trans.set_max_frame_size(4)
+ @trans.write("12345")
+ expect { @trans.flush }.to raise_error(Thrift::TransportException, /frame that is too large/)
+ end
+ end
+
+ # A reading HeaderTransport must accept Header frames as well as legacy
+ # framed and unframed binary messages, returning the payload bytes.
+ describe "read and frame detection" do
+ it "should detect Header format" do
+ # Write a Header frame
+ @trans.write("test data")
+ @trans.flush
+
+ # Reset for reading
+ written_data = @underlying.read(@underlying.available)
+ read_transport = Thrift::MemoryBufferTransport.new(written_data)
+ read_trans = Thrift::HeaderTransport.new(read_transport)
+
+ result = read_trans.read(9)
+ expect(result).to eq("test data")
+ end
+
+ # Length prefix followed by a payload that starts with the binary protocol's VERSION_1 word.
+ it "should detect framed binary protocol" do
+ # Create a framed binary message
+ payload = [Thrift::BinaryProtocol::VERSION_1 | Thrift::MessageTypes::CALL].pack('N')
+ payload << "test"
+ frame = [payload.bytesize].pack('N') + payload
+
+ read_transport = Thrift::MemoryBufferTransport.new(frame)
+ read_trans = Thrift::HeaderTransport.new(read_transport)
+
+ result = read_trans.read(payload.bytesize)
+ expect(result).to eq(payload)
+ end
+
+ # No length prefix: the stream begins directly with the VERSION_1 word.
+ it "should detect unframed binary protocol" do
+ # Create an unframed binary message (version word first)
+ message = [Thrift::BinaryProtocol::VERSION_1 | Thrift::MessageTypes::CALL].pack('N')
+ message << "test"
+
+ read_transport = Thrift::MemoryBufferTransport.new(message)
+ read_trans = Thrift::HeaderTransport.new(read_transport)
+
+ result = read_trans.read(message.bytesize)
+ expect(result).to eq(message)
+ end
+
+ # Headers written on one side show up in get_headers on the reading side after a read.
+ it "should read headers from Header frame" do
+ # Write with headers
+ @trans.set_header("request-id", "12345")
+ @trans.write("payload")
+ @trans.flush
+
+ # Read back
+ written_data = @underlying.read(@underlying.available)
+ read_transport = Thrift::MemoryBufferTransport.new(written_data)
+ read_trans = Thrift::HeaderTransport.new(read_transport)
+
+ read_trans.read(7)
+ headers = read_trans.get_headers
+ expect(headers["request-id"]).to eq("12345")
+ end
+
+ # The reader is not told about the transform; it must undo ZLIB on its own.
+ it "should decompress ZLIB payload" do
+ # Write with ZLIB
+ @trans.add_transform(Thrift::HeaderTransformID::ZLIB)
+ original = "hello world this is a test"
+ @trans.write(original)
+ @trans.flush
+
+ # Read back
+ written_data = @underlying.read(@underlying.available)
+ read_transport = Thrift::MemoryBufferTransport.new(written_data)
+ read_trans = Thrift::HeaderTransport.new(read_transport)
+
+ result = read_trans.read(original.bytesize)
+ expect(result).to eq(original)
+ end
+ end
+
+ # Malformed, hand-built frames must be rejected with a TransportException
+ # whose message identifies the problem.
+ describe "header parsing protections" do
+ # Header-length field of 16_384 words (64 KiB) while no header bytes are present.
+ it "should reject unreasonable header sizes" do
+ frame = build_header_frame("", Thrift::Bytes.empty_byte_buffer, header_words: 16_384)
+ read_transport = Thrift::MemoryBufferTransport.new(frame)
+ read_trans = Thrift::HeaderTransport.new(read_transport)
+
+ expect { read_trans.read(1) }.to raise_error(Thrift::TransportException, /Header size is unreasonable/)
+ end
+
+ # Declared frame length 9 is one byte short of the 10 bytes of fixed
+ # fields (magic, flags, sequence id, header length) that follow it.
+ it "should reject header frames that are too small" do
+ frame = Thrift::Bytes.empty_byte_buffer
+ frame << [9].pack('N')
+ frame << [Thrift::HeaderTransport::HEADER_MAGIC].pack('n')
+ frame << [0].pack('n')
+ frame << [0].pack('N')
+ frame << [0].pack('n')
+ read_transport = Thrift::MemoryBufferTransport.new(frame)
+ read_trans = Thrift::HeaderTransport.new(read_transport)
+
+ expect { read_trans.read(1) }.to raise_error(Thrift::TransportException, /frame is too small/)
+ end
+
+ # Every header byte has its continuation bit set, so the first varint runs into the header end.
+ it "should reject varints that cross header boundary" do
+ header_data = [0x80, 0x80, 0x80, 0x80].pack('C*')
+ frame = build_header_frame(header_data)
+ read_transport = Thrift::MemoryBufferTransport.new(frame)
+ read_trans = Thrift::HeaderTransport.new(read_transport)
+
+ expect { read_trans.read(1) }.to raise_error(Thrift::TransportException, /header boundary/)
+ end
+
+ # A key/value info header with one entry whose key claims 10 bytes,
+ # but only "a" plus padding remain in the header.
+ it "should reject strings that exceed header boundary" do
+ header_data = +""
+ header_data << varint32(Thrift::HeaderSubprotocolID::BINARY)
+ header_data << varint32(0)
+ header_data << varint32(Thrift::HeaderInfoType::KEY_VALUE)
+ header_data << varint32(1)
+ header_data << varint32(10)
+ header_data << "a"
+
+ frame = build_header_frame(header_data)
+ read_transport = Thrift::MemoryBufferTransport.new(frame)
+ read_trans = Thrift::HeaderTransport.new(read_transport)
+
+ expect { read_trans.read(1) }.to raise_error(Thrift::TransportException, /Info header length exceeds header size/)
+ end
+ end
+
+ # Write with one HeaderTransport and read the bytes back with another,
+ # checking payload and headers together.
+ describe "round-trip" do
+ it "should handle complete write-read cycle" do
+ # Write
+ @trans.set_header("trace-id", "abc123")
+ @trans.write("hello world")
+ @trans.flush
+
+ # Read
+ written_data = @underlying.read(@underlying.available)
+ read_transport = Thrift::MemoryBufferTransport.new(written_data)
+ read_trans = Thrift::HeaderTransport.new(read_transport)
+
+ result = read_trans.read(11)
+ expect(result).to eq("hello world")
+ expect(read_trans.get_headers["trace-id"]).to eq("abc123")
+ end
+
+ it "should handle multiple headers" do
+ @trans.set_header("header1", "value1")
+ @trans.set_header("header2", "value2")
+ @trans.set_header("header3", "value3")
+ @trans.write("data")
+ @trans.flush
+
+ written_data = @underlying.read(@underlying.available)
+ read_transport = Thrift::MemoryBufferTransport.new(written_data)
+ read_trans = Thrift::HeaderTransport.new(read_transport)
+
+ read_trans.read(4)
+ headers = read_trans.get_headers
+ expect(headers["header1"]).to eq("value1")
+ expect(headers["header2"]).to eq("value2")
+ expect(headers["header3"]).to eq("value3")
+ end
+
+ # Headers and a ZLIB-compressed payload in the same frame.
+ it "should handle ZLIB compression round-trip" do
+ @trans.add_transform(Thrift::HeaderTransformID::ZLIB)
+ @trans.set_header("compressed", "true")
+ original = "x" * 500
+ @trans.write(original)
+ @trans.flush
+
+ written_data = @underlying.read(@underlying.available)
+ read_transport = Thrift::MemoryBufferTransport.new(written_data)
+ read_trans = Thrift::HeaderTransport.new(read_transport)
+
+ result = read_trans.read(500)
+ expect(result).to eq(original)
+ expect(read_trans.get_headers["compressed"]).to eq("true")
+ end
+ end
+
+ describe "client type restrictions" do
+ # With only HEADERS allowed, a legacy framed-binary message must be
+ # refused when read rather than passed through.
+ it "should reject disallowed client types" do
+ # Only allow HEADERS
+ allowed = [Thrift::HeaderClientType::HEADERS]
+
+ # Create framed binary message
+ payload = [Thrift::BinaryProtocol::VERSION_1 | Thrift::MessageTypes::CALL].pack('N')
+ frame = [payload.bytesize].pack('N') + payload
+
+ read_transport = Thrift::MemoryBufferTransport.new(frame)
+ read_trans = Thrift::HeaderTransport.new(read_transport, allowed)
+
+ expect { read_trans.read(4) }.to raise_error(Thrift::TransportException)
+ end
+ end
+ end
+
+ describe Thrift::HeaderTransportFactory do
+ it "should wrap transport in HeaderTransport" do
+ mock_transport = double("Transport")
+ factory = Thrift::HeaderTransportFactory.new
+ result = factory.get_transport(mock_transport)
+ expect(result).to be_a(Thrift::HeaderTransport)
+ end
+
+ it "should provide a reasonable to_s" do
+ expect(Thrift::HeaderTransportFactory.new.to_s).to eq("header")
+ end
+
+ it "should pass allowed_client_types to transport" do
+ allowed = [Thrift::HeaderClientType::HEADERS]
+ factory = Thrift::HeaderTransportFactory.new(allowed)
+
+ mock_transport = Thrift::MemoryBufferTransport.new
+ result = factory.get_transport(mock_transport)
+
+ expect(result).to be_a(Thrift::HeaderTransport)
+ end
+ end
+end
diff --git a/lib/rb/spec/support/header_protocol_helper.rb b/lib/rb/spec/support/header_protocol_helper.rb
new file mode 100644
index 0000000..75875a6
--- /dev/null
+++ b/lib/rb/spec/support/header_protocol_helper.rb
@@ -0,0 +1,54 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+module HeaderProtocolHelper
+ def varint32(n)
+ bytes = []
+ loop do
+ if (n & ~0x7f) == 0
+ bytes << n
+ break
+ else
+ bytes << ((n & 0x7f) | 0x80)
+ n >>= 7
+ end
+ end
+ bytes.pack('C*')
+ end
+
+ def build_header_frame(header_data, payload = Thrift::Bytes.empty_byte_buffer, header_words: nil)
+ header_data = Thrift::Bytes.force_binary_encoding(header_data)
+ if header_words.nil?
+ padding = (4 - (header_data.bytesize % 4)) % 4
+ header_data += "\x00" * padding
+ header_words = header_data.bytesize / 4
+ end
+
+ frame_size = 2 + 2 + 4 + 2 + header_data.bytesize + payload.bytesize
+ frame = Thrift::Bytes.empty_byte_buffer
+ frame << [frame_size].pack('N')
+ frame << [Thrift::HeaderTransport::HEADER_MAGIC].pack('n')
+ frame << [0].pack('n')
+ frame << [0].pack('N')
+ frame << [header_words].pack('n')
+ frame << header_data
+ frame << payload
+ frame
+ end
+end