THRIFT-5283: add support for Unix Domain Sockets in lib/rs (#2545)
Client: rs
diff --git a/test/rs/src/bin/test_client.rs b/test/rs/src/bin/test_client.rs
index 8623915..8274aae 100644
--- a/test/rs/src/bin/test_client.rs
+++ b/test/rs/src/bin/test_client.rs
@@ -21,7 +21,12 @@
use std::collections::{BTreeMap, BTreeSet};
use std::fmt::Debug;
-use std::net::TcpStream;
+use std::net::{TcpStream, ToSocketAddrs};
+
+#[cfg(unix)]
+use std::os::unix::net::UnixStream;
+#[cfg(unix)]
+use std::path::Path;
use thrift;
use thrift::protocol::{
@@ -35,6 +40,11 @@
use thrift::OrderedFloat;
use thrift_test::*;
+/// Type-erased protocol halves used for every generated client in this binary.
+type BoxedInputProtocol = Box<dyn TInputProtocol>;
+type BoxedOutputProtocol = Box<dyn TOutputProtocol>;
+
+/// What `bind` / `bind_uds` hand back to `run`: the `ThriftTest` client, plus a
+/// `SecondService` client that is only present when a multiplexed protocol was requested.
+type ThriftClientPair = (
+    ThriftTestSyncClient<BoxedInputProtocol, BoxedOutputProtocol>,
+    Option<SecondServiceSyncClient<BoxedInputProtocol, BoxedOutputProtocol>>,
+);
+
fn main() {
env_logger::init();
@@ -51,7 +61,6 @@
fn run() -> thrift::Result<()> {
// unsupported options:
- // --domain-socket
// --pipe
// --anon-pipes
// --ssl
@@ -62,41 +71,38 @@
(about: "Rust Thrift test client")
(@arg host: --host +takes_value "Host on which the Thrift test server is located")
(@arg port: --port +takes_value "Port on which the Thrift test server is listening")
- (@arg transport: --transport +takes_value "Thrift transport implementation to use (\"buffered\", \"framed\")")
+ (@arg domain_socket: --("domain-socket") +takes_value "Unix Domain Socket on which the Thrift test server is listening")
(@arg protocol: --protocol +takes_value "Thrift protocol implementation to use (\"binary\", \"compact\", \"multi\", \"multic\")")
+ (@arg transport: --transport +takes_value "Thrift transport implementation to use (\"buffered\", \"framed\")")
(@arg testloops: -n --testloops +takes_value "Number of times to run tests")
)
.get_matches();
let host = matches.value_of("host").unwrap_or("127.0.0.1");
let port = value_t!(matches, "port", u16).unwrap_or(9090);
- let testloops = value_t!(matches, "testloops", u8).unwrap_or(1);
- let transport = matches.value_of("transport").unwrap_or("buffered");
+ let domain_socket = matches.value_of("domain_socket");
let protocol = matches.value_of("protocol").unwrap_or("binary");
+ let transport = matches.value_of("transport").unwrap_or("buffered");
+ let testloops = value_t!(matches, "testloops", u8).unwrap_or(1);
- // create a TCPStream that will be shared by all Thrift clients
- // service calls from multiple Thrift clients will be interleaved over the same connection
- // this isn't a problem for us because we're single-threaded and all calls block to completion
- let shared_stream = TcpStream::connect(format!("{}:{}", host, port))?;
-
- let mut second_service_client = if protocol.starts_with("multi") {
- let shared_stream_clone = shared_stream.try_clone()?;
- let (i_prot, o_prot) = build(shared_stream_clone, transport, protocol, "SecondService")?;
- Some(SecondServiceSyncClient::new(i_prot, o_prot))
- } else {
- None
+ let (mut thrift_test_client, mut second_service_client) = match domain_socket {
+ None => {
+ let listen_address = format!("{}:{}", host, port);
+ info!(
+ "Client binds to {} with {}+{} stack",
+ listen_address, protocol, transport
+ );
+ bind(listen_address.as_str(), protocol, transport)?
+ }
+ Some(domain_socket) => {
+ info!(
+ "Client binds to {} (UDS) with {}+{} stack",
+ domain_socket, protocol, transport
+ );
+ bind_uds(domain_socket, protocol, transport)?
+ }
};
- let mut thrift_test_client = {
- let (i_prot, o_prot) = build(shared_stream, transport, protocol, "ThriftTest")?;
- ThriftTestSyncClient::new(i_prot, o_prot)
- };
-
- info!(
- "connecting to {}:{} with {}+{} stack",
- host, port, protocol, transport
- );
-
for _ in 0..testloops {
make_thrift_calls(&mut thrift_test_client, &mut second_service_client)?
}
@@ -104,14 +110,68 @@
Ok(())
}
-fn build(
- stream: TcpStream,
+/// Connects to the test server over TCP and builds the clients that `run` drives.
+///
+/// The `ThriftTest` client is always built; a `SecondService` client is added only when
+/// `protocol` names one of the multiplexed protocols (i.e. it starts with `"multi"`).
+///
+/// Note: despite the function name and the `listen_address` parameter, this is the client
+/// side — it *connects* to `listen_address`, it does not bind or listen on it.
+fn bind<A: ToSocketAddrs>(
+    listen_address: A,
+    protocol: &str,
+    transport: &str,
+) -> Result<ThriftClientPair, thrift::Error> {
+    // A single TCP connection carries the calls of every client built below, so their calls
+    // interleave on the wire. That is fine here: this client is single-threaded and every
+    // call blocks to completion before the next one is issued.
+    let stream = TcpStream::connect(listen_address)?;
+
+    let multiplexed = protocol.starts_with("multi");
+    let second_service_client = if !multiplexed {
+        None
+    } else {
+        // `try_clone` yields another handle to the *same* socket, not a new connection.
+        let channel = TTcpChannel::with_stream(stream.try_clone()?);
+        let (i_prot, o_prot) = build(channel, transport, protocol, "SecondService")?;
+        Some(SecondServiceSyncClient::new(i_prot, o_prot))
+    };
+
+    let (i_prot, o_prot) = build(
+        TTcpChannel::with_stream(stream),
+        transport,
+        protocol,
+        "ThriftTest",
+    )?;
+
+    Ok((ThriftTestSyncClient::new(i_prot, o_prot), second_service_client))
+}
+
+/// Connects to the test server over a Unix Domain Socket and builds the clients that `run`
+/// drives.
+///
+/// Mirrors `bind`, except the shared connection is a `UnixStream` opened at `domain_socket`,
+/// which is handed to `build` directly as the I/O channel. A `SecondService` client is only
+/// built when `protocol` starts with `"multi"`.
+#[cfg(unix)]
+fn bind_uds<P: AsRef<Path>>(
+    domain_socket: P,
+    protocol: &str,
+    transport: &str,
+) -> Result<ThriftClientPair, thrift::Error> {
+    // One socket is shared by all clients, so their calls interleave on it. This is safe
+    // because the test client is single-threaded and every call blocks to completion.
+    let shared_stream = UnixStream::connect(domain_socket)?;
+
+    let second_service_client = if protocol.starts_with("multi") {
+        // `try_clone` yields another handle to the same socket, not a new connection.
+        let shared_stream_clone = shared_stream.try_clone()?;
+        let (i_prot, o_prot) = build(shared_stream_clone, transport, protocol, "SecondService")?;
+        Some(SecondServiceSyncClient::new(i_prot, o_prot))
+    } else {
+        None
+    };
+
+    let thrift_test_client = {
+        let (i_prot, o_prot) = build(shared_stream, transport, protocol, "ThriftTest")?;
+        ThriftTestSyncClient::new(i_prot, o_prot)
+    };
+
+    Ok((thrift_test_client, second_service_client))
+}
+
+/// Fallback for targets without Unix Domain Sockets.
+///
+/// `run` calls `bind_uds` whenever `--domain-socket` is given, and that call site is not
+/// `cfg`-gated, so a definition must exist on every target for this binary to compile. On
+/// non-unix targets it reports the option as unsupported instead of connecting.
+#[cfg(not(unix))]
+fn bind_uds<P: AsRef<std::path::Path>>(
+    domain_socket: P,
+    _protocol: &str,
+    _transport: &str,
+) -> Result<ThriftClientPair, thrift::Error> {
+    Err(thrift::Error::from(std::io::Error::new(
+        std::io::ErrorKind::Other,
+        format!(
+            "cannot connect to {}: Unix Domain Sockets are only supported on Unix platforms",
+            domain_socket.as_ref().display()
+        ),
+    )))
+}
+
+fn build<C: TIoChannel + 'static>(
+ channel: C,
transport: &str,
protocol: &str,
service_name: &str,
) -> thrift::Result<(Box<dyn TInputProtocol>, Box<dyn TOutputProtocol>)> {
- let c = TTcpChannel::with_stream(stream);
- let (i_chan, o_chan) = c.split()?;
+ let (i_chan, o_chan) = channel.split()?;
let (i_tran, o_tran): (Box<dyn TReadTransport>, Box<dyn TWriteTransport>) = match transport {
"buffered" => (