THRIFT-5774: Add remote client's IP address to ServerContext in TServerEventHandler
diff --git a/lib/java/src/main/java/org/apache/thrift/server/AbstractNonblockingServer.java b/lib/java/src/main/java/org/apache/thrift/server/AbstractNonblockingServer.java
index 0388c11..2bcd4b2 100644
--- a/lib/java/src/main/java/org/apache/thrift/server/AbstractNonblockingServer.java
+++ b/lib/java/src/main/java/org/apache/thrift/server/AbstractNonblockingServer.java
@@ -20,6 +20,7 @@
package org.apache.thrift.server;
import java.io.IOException;
+import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
@@ -31,6 +32,7 @@
import org.apache.thrift.TByteArrayOutputStream;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TProtocol;
+import org.apache.thrift.transport.SocketAddressProvider;
import org.apache.thrift.transport.TIOStreamTransport;
import org.apache.thrift.transport.TMemoryInputTransport;
import org.apache.thrift.transport.TNonblockingServerTransport;
@@ -297,6 +299,11 @@
if (eventHandler_ != null) {
context_ = eventHandler_.createContext(inProt_, outProt_);
+ SocketAddress remoteAddress =
+ trans_ instanceof SocketAddressProvider
+ ? ((SocketAddressProvider) trans_).getRemoteSocketAddress()
+ : null;
+ context_.setRemoteAddress(remoteAddress);
} else {
context_ = null;
}
diff --git a/lib/java/src/main/java/org/apache/thrift/server/ServerContext.java b/lib/java/src/main/java/org/apache/thrift/server/ServerContext.java
index adf2a43..8cc95aa 100644
--- a/lib/java/src/main/java/org/apache/thrift/server/ServerContext.java
+++ b/lib/java/src/main/java/org/apache/thrift/server/ServerContext.java
@@ -20,6 +20,8 @@
/** Interface for storing server's connection context. */
package org.apache.thrift.server;
+import java.net.SocketAddress;
+
public interface ServerContext {
/**
@@ -42,4 +44,12 @@
* unwrapped from this context.
*/
boolean isWrapperFor(Class<?> iface);
+
+ /**
+ * Set the remote socket address for this ServerContext. The remoteAddress is null when transport
+ * is not socket based
+ *
+ * @param remoteAddress The remote socket address, may be null.
+ */
+ default void setRemoteAddress(SocketAddress remoteAddress) {}
}
diff --git a/lib/java/src/main/java/org/apache/thrift/server/TSimpleServer.java b/lib/java/src/main/java/org/apache/thrift/server/TSimpleServer.java
index 69af88a..db1e57f 100644
--- a/lib/java/src/main/java/org/apache/thrift/server/TSimpleServer.java
+++ b/lib/java/src/main/java/org/apache/thrift/server/TSimpleServer.java
@@ -19,9 +19,11 @@
package org.apache.thrift.server;
+import java.net.SocketAddress;
import org.apache.thrift.TException;
import org.apache.thrift.TProcessor;
import org.apache.thrift.protocol.TProtocol;
+import org.apache.thrift.transport.SocketAddressProvider;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;
import org.slf4j.Logger;
@@ -70,6 +72,11 @@
outputProtocol = outputProtocolFactory_.getProtocol(outputTransport);
if (eventHandler_ != null) {
connectionContext = eventHandler_.createContext(inputProtocol, outputProtocol);
+ SocketAddress remoteAddress =
+ client instanceof SocketAddressProvider
+ ? ((SocketAddressProvider) client).getRemoteSocketAddress()
+ : null;
+ connectionContext.setRemoteAddress(remoteAddress);
}
while (true) {
if (eventHandler_ != null) {
diff --git a/lib/java/src/main/java/org/apache/thrift/server/TThreadPoolServer.java b/lib/java/src/main/java/org/apache/thrift/server/TThreadPoolServer.java
index 5409034..073f1bc 100644
--- a/lib/java/src/main/java/org/apache/thrift/server/TThreadPoolServer.java
+++ b/lib/java/src/main/java/org/apache/thrift/server/TThreadPoolServer.java
@@ -19,6 +19,7 @@
package org.apache.thrift.server;
+import java.net.SocketAddress;
import java.net.SocketException;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
@@ -31,6 +32,7 @@
import org.apache.thrift.TException;
import org.apache.thrift.TProcessor;
import org.apache.thrift.protocol.TProtocol;
+import org.apache.thrift.transport.SocketAddressProvider;
import org.apache.thrift.transport.TServerTransport;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;
@@ -239,7 +241,12 @@
eventHandler = Optional.ofNullable(getEventHandler());
if (eventHandler.isPresent()) {
- connectionContext = eventHandler.get().createContext(inputProtocol, outputProtocol);
+ connectionContext = eventHandler_.createContext(inputProtocol, outputProtocol);
+ SocketAddress remoteAddress =
+ client_ instanceof SocketAddressProvider
+ ? ((SocketAddressProvider) client_).getRemoteSocketAddress()
+ : null;
+ connectionContext.setRemoteAddress(remoteAddress);
}
while (true) {
diff --git a/lib/java/src/main/java/org/apache/thrift/transport/SocketAddressProvider.java b/lib/java/src/main/java/org/apache/thrift/transport/SocketAddressProvider.java
new file mode 100644
index 0000000..1f79941
--- /dev/null
+++ b/lib/java/src/main/java/org/apache/thrift/transport/SocketAddressProvider.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport;
+
+import java.net.SocketAddress;
+
+/** Interface that can retrieve the socket address. */
+public interface SocketAddressProvider {
+
+ SocketAddress getRemoteSocketAddress();
+
+ SocketAddress getLocalSocketAddress();
+}
diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TNonblockingSocket.java b/lib/java/src/main/java/org/apache/thrift/transport/TNonblockingSocket.java
index 0f4076c..650b196 100644
--- a/lib/java/src/main/java/org/apache/thrift/transport/TNonblockingSocket.java
+++ b/lib/java/src/main/java/org/apache/thrift/transport/TNonblockingSocket.java
@@ -33,7 +33,7 @@
import org.slf4j.LoggerFactory;
/** Transport for use with async client. */
-public class TNonblockingSocket extends TNonblockingTransport {
+public class TNonblockingSocket extends TNonblockingTransport implements SocketAddressProvider {
private static final Logger LOGGER = LoggerFactory.getLogger(TNonblockingSocket.class.getName());
@@ -205,4 +205,14 @@
+ socketChannel_.socket().getLocalAddress()
+ "]";
}
+
+ @Override
+ public SocketAddress getRemoteSocketAddress() {
+ return socketChannel_.socket().getRemoteSocketAddress();
+ }
+
+ @Override
+ public SocketAddress getLocalSocketAddress() {
+ return socketChannel_.socket().getLocalSocketAddress();
+ }
}
diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TSocket.java b/lib/java/src/main/java/org/apache/thrift/transport/TSocket.java
index 558c4fa..2458d0f 100644
--- a/lib/java/src/main/java/org/apache/thrift/transport/TSocket.java
+++ b/lib/java/src/main/java/org/apache/thrift/transport/TSocket.java
@@ -24,13 +24,14 @@
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
+import java.net.SocketAddress;
import java.net.SocketException;
import org.apache.thrift.TConfiguration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/** Socket implementation of the TTransport interface. To be commented soon! */
-public class TSocket extends TIOStreamTransport {
+public class TSocket extends TIOStreamTransport implements SocketAddressProvider {
private static final Logger LOGGER = LoggerFactory.getLogger(TSocket.class.getName());
@@ -247,4 +248,14 @@
socket_ = null;
}
}
+
+ @Override
+ public SocketAddress getRemoteSocketAddress() {
+ return socket_.getRemoteSocketAddress();
+ }
+
+ @Override
+ public SocketAddress getLocalSocketAddress() {
+ return socket_.getLocalSocketAddress();
+ }
}
diff --git a/lib/java/src/main/java/org/apache/thrift/transport/sasl/NonblockingSaslHandler.java b/lib/java/src/main/java/org/apache/thrift/transport/sasl/NonblockingSaslHandler.java
index da82c89..66a1e5f 100644
--- a/lib/java/src/main/java/org/apache/thrift/transport/sasl/NonblockingSaslHandler.java
+++ b/lib/java/src/main/java/org/apache/thrift/transport/sasl/NonblockingSaslHandler.java
@@ -22,6 +22,7 @@
import static org.apache.thrift.transport.sasl.NegotiationStatus.COMPLETE;
import static org.apache.thrift.transport.sasl.NegotiationStatus.OK;
+import java.net.SocketAddress;
import java.nio.channels.SelectionKey;
import java.nio.charset.StandardCharsets;
import javax.security.sasl.SaslServer;
@@ -31,6 +32,7 @@
import org.apache.thrift.protocol.TProtocolFactory;
import org.apache.thrift.server.ServerContext;
import org.apache.thrift.server.TServerEventHandler;
+import org.apache.thrift.transport.SocketAddressProvider;
import org.apache.thrift.transport.TMemoryTransport;
import org.apache.thrift.transport.TNonblockingTransport;
import org.apache.thrift.transport.TTransportException;
@@ -325,6 +327,11 @@
if (eventHandler != null) {
if (!serverContextCreated) {
serverContext = eventHandler.createContext(requestProtocol, responseProtocol);
+ SocketAddress remoteAddress =
+ underlyingTransport instanceof SocketAddressProvider
+ ? ((SocketAddressProvider) underlyingTransport).getRemoteSocketAddress()
+ : null;
+ serverContext.setRemoteAddress(remoteAddress);
serverContextCreated = true;
}
eventHandler.processContext(serverContext, memoryTransport, memoryTransport);