THRIFT-5283: add support for Unix Domain Sockets in lib/rs (#2545)
Client: rs
diff --git a/lib/rs/src/server/threaded.rs b/lib/rs/src/server/threaded.rs
index 897235c..ad55b44 100644
--- a/lib/rs/src/server/threaded.rs
+++ b/lib/rs/src/server/threaded.rs
@@ -17,10 +17,15 @@
use log::warn;
-use std::net::{TcpListener, TcpStream, ToSocketAddrs};
+use std::net::{TcpListener, ToSocketAddrs};
use std::sync::Arc;
use threadpool::ThreadPool;
+#[cfg(unix)]
+use std::os::unix::net::UnixListener;
+#[cfg(unix)]
+use std::path::Path;
+
use crate::protocol::{
TInputProtocol, TInputProtocolFactory, TOutputProtocol, TOutputProtocolFactory,
};
@@ -178,10 +183,8 @@
for stream in listener.incoming() {
match stream {
Ok(s) => {
- let (i_prot, o_prot) = self.new_protocols_for_connection(s)?;
- let processor = self.processor.clone();
- self.worker_pool
- .execute(move || handle_incoming_connection(processor, i_prot, o_prot));
+ let channel = TTcpChannel::with_stream(s);
+ self.handle_stream(channel)?;
}
Err(e) => {
warn!("failed to accept remote connection with error {:?}", e);
@@ -195,19 +198,55 @@
}))
}
- fn new_protocols_for_connection(
+ /// Listen for incoming connections on `listen_path`.
+ ///
+ /// `listen_path` should implement `AsRef<Path>` trait.
+ ///
+ /// Return `()` if successful.
+ ///
+ /// Return `Err` when the server cannot bind to `listen_path` or there
+ /// is an unrecoverable error.
+ #[cfg(unix)]
+ pub fn listen_uds<P: AsRef<Path>>(&mut self, listen_path: P) -> crate::Result<()> {
+ let listener = UnixListener::bind(listen_path)?;
+ for stream in listener.incoming() {
+ match stream {
+ Ok(s) => {
+ self.handle_stream(s)?;
+ }
+ Err(e) => {
+ warn!(
+ "failed to accept connection via unix domain socket with error {:?}",
+ e
+ );
+ }
+ }
+ }
+
+ Err(crate::Error::Application(ApplicationError {
+ kind: ApplicationErrorKind::Unknown,
+ message: "aborted listen loop".into(),
+ }))
+ }
+
+ fn handle_stream<S: TIoChannel + Send + 'static>(&mut self, stream: S) -> crate::Result<()> {
+ let (i_prot, o_prot) = self.new_protocols_for_connection(stream)?;
+ let processor = self.processor.clone();
+ self.worker_pool
+ .execute(move || handle_incoming_connection(processor, i_prot, o_prot));
+ Ok(())
+ }
+
+ fn new_protocols_for_connection<S: TIoChannel + Send + 'static>(
&mut self,
- stream: TcpStream,
+ stream: S,
) -> crate::Result<(
Box<dyn TInputProtocol + Send>,
Box<dyn TOutputProtocol + Send>,
)> {
- // create the shared tcp stream
- let channel = TTcpChannel::with_stream(stream);
-
// split it into two - one to be owned by the
// input tran/proto and the other by the output
- let (r_chan, w_chan) = channel.split()?;
+ let (r_chan, w_chan) = stream.split()?;
// input protocol and transport
let r_tran = self.r_trans_factory.create(Box::new(r_chan));
diff --git a/lib/rs/src/transport/socket.rs b/lib/rs/src/transport/socket.rs
index 275bcd4..48d6dda 100644
--- a/lib/rs/src/transport/socket.rs
+++ b/lib/rs/src/transport/socket.rs
@@ -20,6 +20,9 @@
use std::io::{ErrorKind, Read, Write};
use std::net::{Shutdown, TcpStream, ToSocketAddrs};
+#[cfg(unix)]
+use std::os::unix::net::UnixStream;
+
use super::{ReadHalf, TIoChannel, WriteHalf};
use crate::{new_transport_error, TransportErrorKind};
@@ -166,3 +169,15 @@
self.if_set(|s| s.flush())
}
}
+
+#[cfg(unix)]
+impl TIoChannel for UnixStream {
+ fn split(self) -> crate::Result<(ReadHalf<Self>, WriteHalf<Self>)>
+ where
+ Self: Sized,
+ {
+ let socket_rx = self.try_clone().unwrap();
+
+ Ok((ReadHalf::new(self), WriteHalf::new(socket_rx)))
+ }
+}
diff --git a/lib/rs/test/Cargo.toml b/lib/rs/test/Cargo.toml
index 0ba96fd..47b8cbf 100644
--- a/lib/rs/test/Cargo.toml
+++ b/lib/rs/test/Cargo.toml
@@ -9,6 +9,7 @@
[dependencies]
clap = "~2.33"
bitflags = "=1.2"
+log = "0.4"
[dependencies.thrift]
path = "../"
diff --git a/lib/rs/test/src/bin/kitchen_sink_client.rs b/lib/rs/test/src/bin/kitchen_sink_client.rs
index 74197de..b98afb8 100644
--- a/lib/rs/test/src/bin/kitchen_sink_client.rs
+++ b/lib/rs/test/src/bin/kitchen_sink_client.rs
@@ -16,8 +16,16 @@
// under the License.
use clap::{clap_app, value_t};
+use log::*;
use std::convert::Into;
+use std::net::TcpStream;
+use std::net::ToSocketAddrs;
+
+#[cfg(unix)]
+use std::os::unix::net::UnixStream;
+#[cfg(unix)]
+use std::path::Path;
use kitchen_sink::base_two::{TNapkinServiceSyncClient, TRamenServiceSyncClient};
use kitchen_sink::midlayer::{MealServiceSyncClient, TMealServiceSyncClient};
@@ -30,9 +38,9 @@
TBinaryInputProtocol, TBinaryOutputProtocol, TCompactInputProtocol, TCompactOutputProtocol,
TInputProtocol, TOutputProtocol,
};
-use thrift::transport::{
- ReadHalf, TFramedReadTransport, TFramedWriteTransport, TIoChannel, TTcpChannel, WriteHalf,
-};
+use thrift::transport::{TFramedReadTransport, TFramedWriteTransport, TIoChannel, TTcpChannel};
+
+type IoProtocol = (Box<dyn TInputProtocol>, Box<dyn TOutputProtocol>);
fn main() {
match run() {
@@ -51,6 +59,7 @@
(about: "Thrift Rust kitchen sink client")
(@arg host: --host +takes_value "Host on which the Thrift test server is located")
(@arg port: --port +takes_value "Port on which the Thrift test server is listening")
+ (@arg domain_socket: --("domain-socket") + takes_value "Unix Domain Socket on which the Thrift test server is listening")
(@arg protocol: --protocol +takes_value "Thrift protocol implementation to use (\"binary\", \"compact\")")
(@arg service: --service +takes_value "Service type to contact (\"part\", \"full\", \"recursive\")")
)
@@ -58,10 +67,47 @@
let host = matches.value_of("host").unwrap_or("127.0.0.1");
let port = value_t!(matches, "port", u16).unwrap_or(9090);
+ let domain_socket = matches.value_of("domain_socket");
let protocol = matches.value_of("protocol").unwrap_or("compact");
let service = matches.value_of("service").unwrap_or("part");
- let (i_chan, o_chan) = tcp_channel(host, port)?;
+ let (i_prot, o_prot) = match domain_socket {
+ None => {
+ let listen_address = format!("{}:{}", host, port);
+ info!("Client binds to {} with {}", listen_address, protocol);
+ bind(listen_address, protocol)?
+ }
+ Some(domain_socket) => {
+ info!("Client binds to {} (UDS) with {}", domain_socket, protocol);
+ bind_uds(domain_socket, protocol)?
+ }
+ };
+
+ run_client(service, i_prot, o_prot)
+}
+
+fn bind<A: ToSocketAddrs>(listen_address: A, protocol: &str) -> Result<IoProtocol, thrift::Error> {
+ let stream = TcpStream::connect(listen_address)?;
+ let channel = TTcpChannel::with_stream(stream);
+
+ let (i_prot, o_prot) = build(channel, protocol)?;
+ Ok((i_prot, o_prot))
+}
+
+#[cfg(unix)]
+fn bind_uds<P: AsRef<Path>>(domain_socket: P, protocol: &str) -> Result<IoProtocol, thrift::Error> {
+ let stream = UnixStream::connect(domain_socket)?;
+
+ let (i_prot, o_prot) = build(stream, protocol)?;
+ Ok((i_prot, o_prot))
+}
+
+fn build<C: TIoChannel + 'static>(
+ channel: C,
+ protocol: &str,
+) -> thrift::Result<(Box<dyn TInputProtocol>, Box<dyn TOutputProtocol>)> {
+ let (i_chan, o_chan) = channel.split()?;
+
let (i_tran, o_tran) = (
TFramedReadTransport::new(i_chan),
TFramedWriteTransport::new(o_chan),
@@ -79,7 +125,7 @@
unmatched => return Err(format!("unsupported protocol {}", unmatched).into()),
};
- run_client(service, i_prot, o_prot)
+ Ok((i_prot, o_prot))
}
fn run_client(
@@ -98,15 +144,6 @@
}
}
-fn tcp_channel(
- host: &str,
- port: u16,
-) -> thrift::Result<(ReadHalf<TTcpChannel>, WriteHalf<TTcpChannel>)> {
- let mut c = TTcpChannel::new();
- c.open(&format!("{}:{}", host, port))?;
- c.split()
-}
-
fn exec_meal_client(
i_prot: Box<dyn TInputProtocol>,
o_prot: Box<dyn TOutputProtocol>,
diff --git a/lib/rs/test/src/bin/kitchen_sink_server.rs b/lib/rs/test/src/bin/kitchen_sink_server.rs
index 8b910b3..ea571c6 100644
--- a/lib/rs/test/src/bin/kitchen_sink_server.rs
+++ b/lib/rs/test/src/bin/kitchen_sink_server.rs
@@ -16,6 +16,7 @@
// under the License.
use clap::{clap_app, value_t};
+use log::*;
use thrift;
use thrift::protocol::{
@@ -28,6 +29,7 @@
TWriteTransportFactory,
};
+use crate::Socket::{ListenAddress, UnixDomainSocket};
use kitchen_sink::base_one::Noodle;
use kitchen_sink::base_two::{
BrothType, Napkin, NapkinServiceSyncHandler, Ramen, RamenServiceSyncHandler,
@@ -42,6 +44,11 @@
FullMealServiceSyncHandler,
};
+enum Socket {
+ ListenAddress(String),
+ UnixDomainSocket(String),
+}
+
fn main() {
match run() {
Ok(()) => println!("kitchen sink server completed successfully"),
@@ -57,18 +64,29 @@
(version: "0.1.0")
(author: "Apache Thrift Developers <dev@thrift.apache.org>")
(about: "Thrift Rust kitchen sink test server")
- (@arg port: --port +takes_value "port on which the test server listens")
+ (@arg port: --port +takes_value "Port on which the Thrift test server listens")
+ (@arg domain_socket: --("domain-socket") + takes_value "Unix Domain Socket on which the Thrift test server listens")
(@arg protocol: --protocol +takes_value "Thrift protocol implementation to use (\"binary\", \"compact\")")
(@arg service: --service +takes_value "Service type to contact (\"part\", \"full\", \"recursive\")")
)
.get_matches();
let port = value_t!(matches, "port", u16).unwrap_or(9090);
+ let domain_socket = matches.value_of("domain_socket");
let protocol = matches.value_of("protocol").unwrap_or("compact");
let service = matches.value_of("service").unwrap_or("part");
let listen_address = format!("127.0.0.1:{}", port);
- println!("binding to {}", listen_address);
+ let socket = match domain_socket {
+ None => {
+ info!("Server is binding to {}", listen_address);
+ Socket::ListenAddress(listen_address)
+ }
+ Some(domain_socket) => {
+ info!("Server is binding to {} (UDS)", domain_socket);
+ Socket::UnixDomainSocket(domain_socket.to_string())
+ }
+ };
let r_transport_factory = TFramedReadTransportFactory::new();
let w_transport_factory = TFramedWriteTransportFactory::new();
@@ -102,21 +120,21 @@
// Since what I'm doing is uncommon I'm just going to duplicate the code
match &*service {
"part" => run_meal_server(
- &listen_address,
+ socket,
r_transport_factory,
i_protocol_factory,
w_transport_factory,
o_protocol_factory,
),
"full" => run_full_meal_server(
- &listen_address,
+ socket,
r_transport_factory,
i_protocol_factory,
w_transport_factory,
o_protocol_factory,
),
"recursive" => run_recursive_server(
- &listen_address,
+ socket,
r_transport_factory,
i_protocol_factory,
w_transport_factory,
@@ -127,7 +145,7 @@
}
fn run_meal_server<RTF, IPF, WTF, OPF>(
- listen_address: &str,
+ socket: Socket,
r_transport_factory: RTF,
i_protocol_factory: IPF,
w_transport_factory: WTF,
@@ -149,11 +167,14 @@
1,
);
- server.listen(listen_address)
+ match socket {
+ ListenAddress(listen_address) => server.listen(listen_address),
+ UnixDomainSocket(s) => server.listen_uds(s),
+ }
}
fn run_full_meal_server<RTF, IPF, WTF, OPF>(
- listen_address: &str,
+ socket: Socket,
r_transport_factory: RTF,
i_protocol_factory: IPF,
w_transport_factory: WTF,
@@ -175,7 +196,10 @@
1,
);
- server.listen(listen_address)
+ match socket {
+ ListenAddress(listen_address) => server.listen(listen_address),
+ UnixDomainSocket(s) => server.listen_uds(s),
+ }
}
struct PartHandler;
@@ -267,7 +291,7 @@
}
fn run_recursive_server<RTF, IPF, WTF, OPF>(
- listen_address: &str,
+ socket: Socket,
r_transport_factory: RTF,
i_protocol_factory: IPF,
w_transport_factory: WTF,
@@ -289,7 +313,10 @@
1,
);
- server.listen(listen_address)
+ match socket {
+ ListenAddress(listen_address) => server.listen(listen_address),
+ UnixDomainSocket(s) => server.listen_uds(s),
+ }
}
struct RecursiveTestServerHandler;
diff --git a/test/rs/src/bin/test_client.rs b/test/rs/src/bin/test_client.rs
index 8623915..8274aae 100644
--- a/test/rs/src/bin/test_client.rs
+++ b/test/rs/src/bin/test_client.rs
@@ -21,7 +21,12 @@
use std::collections::{BTreeMap, BTreeSet};
use std::fmt::Debug;
-use std::net::TcpStream;
+use std::net::{TcpStream, ToSocketAddrs};
+
+#[cfg(unix)]
+use std::os::unix::net::UnixStream;
+#[cfg(unix)]
+use std::path::Path;
use thrift;
use thrift::protocol::{
@@ -35,6 +40,11 @@
use thrift::OrderedFloat;
use thrift_test::*;
+type ThriftClientPair = (
+ ThriftTestSyncClient<Box<dyn TInputProtocol>, Box<dyn TOutputProtocol>>,
+ Option<SecondServiceSyncClient<Box<dyn TInputProtocol>, Box<dyn TOutputProtocol>>>,
+);
+
fn main() {
env_logger::init();
@@ -51,7 +61,6 @@
fn run() -> thrift::Result<()> {
// unsupported options:
- // --domain-socket
// --pipe
// --anon-pipes
// --ssl
@@ -62,41 +71,38 @@
(about: "Rust Thrift test client")
(@arg host: --host +takes_value "Host on which the Thrift test server is located")
(@arg port: --port +takes_value "Port on which the Thrift test server is listening")
- (@arg transport: --transport +takes_value "Thrift transport implementation to use (\"buffered\", \"framed\")")
+ (@arg domain_socket: --("domain-socket") +takes_value "Unix Domain Socket on which the Thrift test server is listening")
(@arg protocol: --protocol +takes_value "Thrift protocol implementation to use (\"binary\", \"compact\", \"multi\", \"multic\")")
+ (@arg transport: --transport +takes_value "Thrift transport implementation to use (\"buffered\", \"framed\")")
(@arg testloops: -n --testloops +takes_value "Number of times to run tests")
)
.get_matches();
let host = matches.value_of("host").unwrap_or("127.0.0.1");
let port = value_t!(matches, "port", u16).unwrap_or(9090);
- let testloops = value_t!(matches, "testloops", u8).unwrap_or(1);
- let transport = matches.value_of("transport").unwrap_or("buffered");
+ let domain_socket = matches.value_of("domain_socket");
let protocol = matches.value_of("protocol").unwrap_or("binary");
+ let transport = matches.value_of("transport").unwrap_or("buffered");
+ let testloops = value_t!(matches, "testloops", u8).unwrap_or(1);
- // create a TCPStream that will be shared by all Thrift clients
- // service calls from multiple Thrift clients will be interleaved over the same connection
- // this isn't a problem for us because we're single-threaded and all calls block to completion
- let shared_stream = TcpStream::connect(format!("{}:{}", host, port))?;
-
- let mut second_service_client = if protocol.starts_with("multi") {
- let shared_stream_clone = shared_stream.try_clone()?;
- let (i_prot, o_prot) = build(shared_stream_clone, transport, protocol, "SecondService")?;
- Some(SecondServiceSyncClient::new(i_prot, o_prot))
- } else {
- None
+ let (mut thrift_test_client, mut second_service_client) = match domain_socket {
+ None => {
+ let listen_address = format!("{}:{}", host, port);
+ info!(
+ "Client binds to {} with {}+{} stack",
+ listen_address, protocol, transport
+ );
+ bind(listen_address.as_str(), protocol, transport)?
+ }
+ Some(domain_socket) => {
+ info!(
+ "Client binds to {} (UDS) with {}+{} stack",
+ domain_socket, protocol, transport
+ );
+ bind_uds(domain_socket, protocol, transport)?
+ }
};
- let mut thrift_test_client = {
- let (i_prot, o_prot) = build(shared_stream, transport, protocol, "ThriftTest")?;
- ThriftTestSyncClient::new(i_prot, o_prot)
- };
-
- info!(
- "connecting to {}:{} with {}+{} stack",
- host, port, protocol, transport
- );
-
for _ in 0..testloops {
make_thrift_calls(&mut thrift_test_client, &mut second_service_client)?
}
@@ -104,14 +110,68 @@
Ok(())
}
-fn build(
- stream: TcpStream,
+fn bind<A: ToSocketAddrs>(
+ listen_address: A,
+ protocol: &str,
+ transport: &str,
+) -> Result<ThriftClientPair, thrift::Error> {
+ // create a TCPStream that will be shared by all Thrift clients
+ // service calls from multiple Thrift clients will be interleaved over the same connection
+ // this isn't a problem for us because we're single-threaded and all calls block to completion
+ let shared_stream = TcpStream::connect(listen_address)?;
+
+ let second_service_client = if protocol.starts_with("multi") {
+ let shared_stream_clone = shared_stream.try_clone()?;
+ let channel = TTcpChannel::with_stream(shared_stream_clone);
+ let (i_prot, o_prot) = build(channel, transport, protocol, "SecondService")?;
+ Some(SecondServiceSyncClient::new(i_prot, o_prot))
+ } else {
+ None
+ };
+
+ let thrift_test_client = {
+ let channel = TTcpChannel::with_stream(shared_stream);
+ let (i_prot, o_prot) = build(channel, transport, protocol, "ThriftTest")?;
+ ThriftTestSyncClient::new(i_prot, o_prot)
+ };
+
+ Ok((thrift_test_client, second_service_client))
+}
+
+#[cfg(unix)]
+fn bind_uds<P: AsRef<Path>>(
+ domain_socket: P,
+ protocol: &str,
+ transport: &str,
+) -> Result<ThriftClientPair, thrift::Error> {
+ // create a UnixStream that will be shared by all Thrift clients
+ // service calls from multiple Thrift clients will be interleaved over the same connection
+ // this isn't a problem for us because we're single-threaded and all calls block to completion
+ let shared_stream = UnixStream::connect(domain_socket)?;
+
+ let second_service_client = if protocol.starts_with("multi") {
+ let shared_stream_clone = shared_stream.try_clone()?;
+ let (i_prot, o_prot) = build(shared_stream_clone, transport, protocol, "SecondService")?;
+ Some(SecondServiceSyncClient::new(i_prot, o_prot))
+ } else {
+ None
+ };
+
+ let thrift_test_client = {
+ let (i_prot, o_prot) = build(shared_stream, transport, protocol, "ThriftTest")?;
+ ThriftTestSyncClient::new(i_prot, o_prot)
+ };
+
+ Ok((thrift_test_client, second_service_client))
+}
+
+fn build<C: TIoChannel + 'static>(
+ channel: C,
transport: &str,
protocol: &str,
service_name: &str,
) -> thrift::Result<(Box<dyn TInputProtocol>, Box<dyn TOutputProtocol>)> {
- let c = TTcpChannel::with_stream(stream);
- let (i_chan, o_chan) = c.split()?;
+ let (i_chan, o_chan) = channel.split()?;
let (i_tran, o_tran): (Box<dyn TReadTransport>, Box<dyn TWriteTransport>) = match transport {
"buffered" => (
diff --git a/test/rs/src/bin/test_server.rs b/test/rs/src/bin/test_server.rs
index 6a05e79..7e6d08f 100644
--- a/test/rs/src/bin/test_server.rs
+++ b/test/rs/src/bin/test_server.rs
@@ -52,7 +52,6 @@
fn run() -> thrift::Result<()> {
// unsupported options:
- // --domain-socket
// --pipe
// --ssl
let matches = clap_app!(rust_test_client =>
@@ -60,21 +59,26 @@
(author: "Apache Thrift Developers <dev@thrift.apache.org>")
(about: "Rust Thrift test server")
(@arg port: --port +takes_value "port on which the test server listens")
+ (@arg domain_socket: --("domain-socket") +takes_value "Unix Domain Socket on which the test server listens")
(@arg transport: --transport +takes_value "transport implementation to use (\"buffered\", \"framed\")")
(@arg protocol: --protocol +takes_value "protocol implementation to use (\"binary\", \"compact\")")
- (@arg server_type: --server_type +takes_value "type of server instantiated (\"simple\", \"thread-pool\")")
+ (@arg server_type: --("server-type") +takes_value "type of server instantiated (\"simple\", \"thread-pool\")")
(@arg workers: -n --workers +takes_value "number of thread-pool workers (\"4\")")
)
- .get_matches();
+ .get_matches();
let port = value_t!(matches, "port", u16).unwrap_or(9090);
+ let domain_socket = matches.value_of("domain_socket");
let transport = matches.value_of("transport").unwrap_or("buffered");
let protocol = matches.value_of("protocol").unwrap_or("binary");
let server_type = matches.value_of("server_type").unwrap_or("thread-pool");
let workers = value_t!(matches, "workers", usize).unwrap_or(4);
let listen_address = format!("127.0.0.1:{}", port);
- info!("binding to {}", listen_address);
+ match domain_socket {
+ None => info!("Server is binding to {}", listen_address),
+ Some(domain_socket) => info!("Server is binding to {} (UDS)", domain_socket),
+ }
let (i_transport_factory, o_transport_factory): (
Box<dyn TReadTransportFactory>,
@@ -135,7 +139,10 @@
workers,
);
- server.listen(&listen_address)
+ match domain_socket {
+ None => server.listen(&listen_address),
+ Some(domain_socket) => server.listen_uds(domain_socket),
+ }
} else {
let mut server = TServer::new(
i_transport_factory,
@@ -146,9 +153,13 @@
workers,
);
- server.listen(&listen_address)
+ match domain_socket {
+ None => server.listen(&listen_address),
+ Some(domain_socket) => server.listen_uds(domain_socket),
+ }
}
}
+
unknown => Err(format!("unsupported server type {}", unknown).into()),
}
}
diff --git a/test/rs/src/lib.rs b/test/rs/src/lib.rs
index 3c7cfc0..9cfd7a6 100644
--- a/test/rs/src/lib.rs
+++ b/test/rs/src/lib.rs
@@ -15,9 +15,5 @@
// specific language governing permissions and limitations
// under the License.
-
-
-
-
mod thrift_test;
pub use crate::thrift_test::*;
diff --git a/test/tests.json b/test/tests.json
index a8dbef7..3563dc9 100644
--- a/test/tests.json
+++ b/test/tests.json
@@ -679,7 +679,8 @@
]
},
"sockets": [
- "ip"
+ "ip",
+ "domain"
],
"transports": [
"buffered",