Added new method isUnixDomainSocket() to check type of socket
diff --git a/lib/cpp/src/thrift/transport/TServerSocket.cpp b/lib/cpp/src/thrift/transport/TServerSocket.cpp
index 5e7e2c0..671cabc 100644
--- a/lib/cpp/src/thrift/transport/TServerSocket.cpp
+++ b/lib/cpp/src/thrift/transport/TServerSocket.cpp
@@ -187,7 +187,7 @@
if (!listening_)
return false;
- if (!path_.empty() && (path_[0] != '\0')) {
+ if (isUnixDomainSocket() && (path_[0] != '\0')) {
// On some platforms the domain socket file may not be instantly
// available yet, i.e. the Windows file system can be slow. Therefore
// we should check that the domain socket file actually exists.
@@ -339,7 +339,7 @@
// Defer accept
#ifdef TCP_DEFER_ACCEPT
- if (path_.empty()) {
+ if (!isUnixDomainSocket()) {
if (-1 == setsockopt(serverSocket_, IPPROTO_TCP, TCP_DEFER_ACCEPT, &one, sizeof(one))) {
int errno_copy = THRIFT_GET_SOCKET_ERROR;
GlobalOutput.perror("TServerSocket::listen() setsockopt() TCP_DEFER_ACCEPT ", errno_copy);
@@ -391,8 +391,6 @@
= std::shared_ptr<THRIFT_SOCKET>(new THRIFT_SOCKET(sv[0]), destroyer_of_fine_sockets);
}
- // tcp == false means Unix Domain socket
- bool tcp = (path_.empty());
// Validate port number
if (port_ < 0 || port_ > 0xFFFF) {
@@ -401,7 +399,7 @@
// Resolve host:port strings into an iterable of struct addrinfo*
AddressResolutionHelper resolved_addresses;
- if (tcp) {
+ if (!isUnixDomainSocket()) {
try {
resolved_addresses.resolve(address_, std::to_string(port_), SOCK_STREAM,
#ifdef ANDROID
@@ -422,7 +420,7 @@
int retries = 0;
int errno_copy = 0;
- if (!tcp) {
+ if (isUnixDomainSocket()) {
// -- Unix Domain Socket -- //
serverSocket_ = socket(PF_UNIX, SOCK_STREAM, IPPROTO_IP);
@@ -538,7 +536,7 @@
// throw an error if we failed to bind properly
if (retries > retryLimit_) {
char errbuf[1024];
- if (!tcp) {
+ if (isUnixDomainSocket()) {
THRIFT_SNPRINTF(errbuf, sizeof(errbuf), "TServerSocket::listen() PATH %s", path_.c_str());
} else {
THRIFT_SNPRINTF(errbuf, sizeof(errbuf), "TServerSocket::listen() BIND %d", port_);
@@ -565,10 +563,18 @@
listening_ = true;
}
-int TServerSocket::getPort() {
+int TServerSocket::getPort() const {
return port_;
}
+std::string TServerSocket::getPath() const {
+ return path_;
+}
+
+bool TServerSocket::isUnixDomainSocket() const {
+ return !path_.empty();
+}
+
shared_ptr<TTransport> TServerSocket::acceptImpl() {
if (serverSocket_ == THRIFT_INVALID_SOCKET) {
throw TTransportException(TTransportException::NOT_OPEN, "TServerSocket not listening");
diff --git a/lib/cpp/src/thrift/transport/TServerSocket.h b/lib/cpp/src/thrift/transport/TServerSocket.h
index e4659a0..c87a7f6 100644
--- a/lib/cpp/src/thrift/transport/TServerSocket.h
+++ b/lib/cpp/src/thrift/transport/TServerSocket.h
@@ -125,7 +125,11 @@
THRIFT_SOCKET getSocketFD() override { return serverSocket_; }
- int getPort();
+ int getPort() const;
+
+ std::string getPath() const;
+
+ bool isUnixDomainSocket() const;
void listen() override;
void interrupt() override;
diff --git a/lib/cpp/src/thrift/transport/TSocket.cpp b/lib/cpp/src/thrift/transport/TSocket.cpp
index a9bf442..1542c08 100644
--- a/lib/cpp/src/thrift/transport/TSocket.cpp
+++ b/lib/cpp/src/thrift/transport/TSocket.cpp
@@ -265,7 +265,7 @@
return;
}
- if (!path_.empty()) {
+ if (isUnixDomainSocket()) {
socket_ = socket(PF_UNIX, SOCK_STREAM, IPPROTO_IP);
} else {
socket_ = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
@@ -330,7 +330,7 @@
// Connect the socket
int ret;
- if (!path_.empty()) {
+ if (isUnixDomainSocket()) {
/*
* TODO: seems that windows now support unix sockets,
@@ -408,7 +408,7 @@
throw TTransportException(TTransportException::NOT_OPEN, "THRIFT_FCNTL() failed", errno_copy);
}
- if (path_.empty()) {
+ if (!isUnixDomainSocket()) {
setCachedAddress(res->ai_addr, static_cast<socklen_t>(res->ai_addrlen));
}
}
@@ -417,7 +417,7 @@
if (isOpen()) {
return;
}
- if (!path_.empty()) {
+ if (isUnixDomainSocket()) {
unix_open();
} else {
local_open();
@@ -425,7 +425,7 @@
}
void TSocket::unix_open() {
- if (!path_.empty()) {
+ if (isUnixDomainSocket()) {
// Unix Domain Socket does not need addrinfo struct, so we pass NULL
openConnection(nullptr);
}
@@ -692,18 +692,22 @@
return b;
}
-std::string TSocket::getHost() {
+std::string TSocket::getHost() const {
return host_;
}
-int TSocket::getPort() {
+int TSocket::getPort() const {
return port_;
}
-std::string TSocket::getPath() {
+std::string TSocket::getPath() const {
return path_;
}
+bool TSocket::isUnixDomainSocket() const {
+ return !path_.empty();
+}
+
void TSocket::setHost(string host) {
host_ = host;
}
@@ -739,7 +743,7 @@
void TSocket::setNoDelay(bool noDelay) {
noDelay_ = noDelay;
- if (socket_ == THRIFT_INVALID_SOCKET || !path_.empty()) {
+ if (socket_ == THRIFT_INVALID_SOCKET || isUnixDomainSocket()) {
return;
}
@@ -817,7 +821,7 @@
string TSocket::getSocketInfo() const {
std::ostringstream oss;
- if (path_.empty()) {
+ if (!isUnixDomainSocket()) {
if (host_.empty() || port_ == 0) {
oss << "<Host: " << getPeerAddress();
oss << " Port: " << getPeerPort() << ">";
@@ -835,7 +839,7 @@
}
std::string TSocket::getPeerHost() const {
- if (peerHost_.empty() && path_.empty()) {
+ if (peerHost_.empty() && !isUnixDomainSocket()) {
struct sockaddr_storage addr;
struct sockaddr* addrPtr;
socklen_t addrLen;
@@ -873,7 +877,7 @@
}
std::string TSocket::getPeerAddress() const {
- if (peerAddress_.empty() && path_.empty()) {
+ if (peerAddress_.empty() && !isUnixDomainSocket()) {
struct sockaddr_storage addr;
struct sockaddr* addrPtr;
socklen_t addrLen;
@@ -917,7 +921,7 @@
}
void TSocket::setCachedAddress(const sockaddr* addr, socklen_t len) {
- if (!path_.empty()) {
+ if (isUnixDomainSocket()) {
return;
}
diff --git a/lib/cpp/src/thrift/transport/TSocket.h b/lib/cpp/src/thrift/transport/TSocket.h
index 8a224f2..f14546d 100644
--- a/lib/cpp/src/thrift/transport/TSocket.h
+++ b/lib/cpp/src/thrift/transport/TSocket.h
@@ -141,21 +141,29 @@
*
* @return string host identifier
*/
- std::string getHost();
+ std::string getHost() const;
/**
* Get the port that the socket is connected to
*
* @return int port number
*/
- int getPort();
+ int getPort() const;
/**
* Get the Unix domain socket path that the socket is connected to
*
* @return std::string path
*/
- std::string getPath();
+ std::string getPath() const;
+
+ /**
+ * Whether the socket is a Unix domain socket. This is the same as checking
+ * if getPath() is not empty.
+ *
+ * @return Is the socket a Unix domain socket?
+ */
+ bool isUnixDomainSocket() const;
/**
* Set the host that socket will connect to
@@ -285,7 +293,7 @@
* Constructor to create socket from file descriptor that
* can be interrupted safely.
*/
- TSocket(THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener,
+ TSocket(THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener,
std::shared_ptr<TConfiguration> config = nullptr);
/**