THRIFT-4889 Nonblocking server with sasl support
Client: Java
Patch: Qinghui Xu
This closes #1892
diff --git a/lib/java/src/org/apache/thrift/server/TSaslNonblockingServer.java b/lib/java/src/org/apache/thrift/server/TSaslNonblockingServer.java
new file mode 100644
index 0000000..89dbb78
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/server/TSaslNonblockingServer.java
@@ -0,0 +1,480 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.server;
+
+import java.io.IOException;
+import java.nio.channels.SelectionKey;
+import java.nio.channels.Selector;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+import javax.security.auth.callback.CallbackHandler;
+
+import org.apache.thrift.TProcessor;
+import org.apache.thrift.transport.TNonblockingServerSocket;
+import org.apache.thrift.transport.TNonblockingServerTransport;
+import org.apache.thrift.transport.TNonblockingTransport;
+import org.apache.thrift.transport.TTransportException;
+import org.apache.thrift.transport.sasl.NonblockingSaslHandler;
+import org.apache.thrift.transport.sasl.NonblockingSaslHandler.Phase;
+import org.apache.thrift.transport.sasl.TBaseSaslProcessorFactory;
+import org.apache.thrift.transport.sasl.TSaslProcessorFactory;
+import org.apache.thrift.transport.sasl.TSaslServerFactory;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * TServer with sasl support, using asynchronous execution and nonblocking io.
+ */
+public class TSaslNonblockingServer extends TServer {
+ private static final Logger LOGGER = LoggerFactory.getLogger(TSaslNonblockingServer.class);
+
+ private static final int DEFAULT_NETWORK_THREADS = 1;
+ private static final int DEFAULT_AUTHENTICATION_THREADS = 1;
+ private static final int DEFAULT_PROCESSING_THREADS = Runtime.getRuntime().availableProcessors();
+
+ private final AcceptorThread acceptor;
+ private final NetworkThreadPool networkThreadPool;
+ private final ExecutorService authenticationExecutor;
+ private final ExecutorService processingExecutor;
+ private final TSaslServerFactory saslServerFactory;
+ private final TSaslProcessorFactory saslProcessorFactory;
+
+ public TSaslNonblockingServer(Args args) throws IOException {
+ super(args);
+ acceptor = new AcceptorThread((TNonblockingServerSocket) serverTransport_);
+ networkThreadPool = new NetworkThreadPool(args.networkThreads);
+ authenticationExecutor = Executors.newFixedThreadPool(args.saslThreads);
+ processingExecutor = Executors.newFixedThreadPool(args.processingThreads);
+ saslServerFactory = args.saslServerFactory;
+ saslProcessorFactory = args.saslProcessorFactory;
+ }
+
+ @Override
+ public void serve() {
+ if (eventHandler_ != null) {
+ eventHandler_.preServe();
+ }
+ networkThreadPool.start();
+ acceptor.start();
+ setServing(true);
+ }
+
+ /**
+ * Trigger a graceful shutdown, but it does not block to wait for the shutdown to finish.
+ */
+ @Override
+ public void stop() {
+ if (!stopped_) {
+ setServing(false);
+ stopped_ = true;
+ acceptor.wakeup();
+ networkThreadPool.wakeupAll();
+ authenticationExecutor.shutdownNow();
+ processingExecutor.shutdownNow();
+ }
+ }
+
+ /**
+ * Gracefully shut down the server and block until all threads are stopped.
+ *
+ * @throws InterruptedException if is interrupted while waiting for shutdown.
+ */
+ public void shutdown() throws InterruptedException {
+ stop();
+ acceptor.join();
+ for (NetworkThread networkThread : networkThreadPool.networkThreads) {
+ networkThread.join();
+ }
+ while (!authenticationExecutor.isTerminated()) {
+ authenticationExecutor.awaitTermination(10, TimeUnit.SECONDS);
+ }
+ while (!processingExecutor.isTerminated()) {
+ processingExecutor.awaitTermination(10, TimeUnit.SECONDS);
+ }
+ }
+
+ private class AcceptorThread extends Thread {
+
+ private final TNonblockingServerTransport serverTransport;
+ private final Selector acceptSelector;
+
+ private AcceptorThread(TNonblockingServerSocket serverTransport) throws IOException {
+ super("acceptor-thread");
+ this.serverTransport = serverTransport;
+ acceptSelector = Selector.open();
+ serverTransport.registerSelector(acceptSelector);
+ }
+
+ @Override
+ public void run() {
+ try {
+ serverTransport.listen();
+ while (!stopped_) {
+ select();
+ acceptNewConnection();
+ }
+ } catch (TTransportException e) {
+ // Failed to listen.
+ LOGGER.error("Failed to listen on server socket, error " + e.getType(), e);
+ } catch (Throwable e) {
+ // Unexpected errors.
+ LOGGER.error("Unexpected error in acceptor thread.", e);
+ } finally {
+ TSaslNonblockingServer.this.stop();
+ close();
+ }
+ }
+
+ void wakeup() {
+ acceptSelector.wakeup();
+ }
+
+ private void acceptNewConnection() {
+ Iterator<SelectionKey> selectedKeyItr = acceptSelector.selectedKeys().iterator();
+ while (!stopped_ && selectedKeyItr.hasNext()) {
+ SelectionKey selected = selectedKeyItr.next();
+ selectedKeyItr.remove();
+ if (selected.isAcceptable()) {
+ try {
+ while (true) {
+ // Accept all available connections from the backlog.
+ TNonblockingTransport connection = serverTransport.accept();
+ if (connection == null) {
+ break;
+ }
+ if (!networkThreadPool.acceptNewConnection(connection)) {
+ LOGGER.error("Network thread does not accept: " + connection);
+ connection.close();
+ }
+ }
+ } catch (TTransportException e) {
+ LOGGER.warn("Failed to accept incoming connection.", e);
+ }
+ } else {
+ LOGGER.error("Not acceptable selection: " + selected.channel());
+ }
+ }
+ }
+
+ private void select() {
+ try {
+ acceptSelector.select();
+ } catch (IOException e) {
+ LOGGER.error("Failed to select on the server socket.", e);
+ }
+ }
+
+ private void close() {
+ LOGGER.info("Closing acceptor thread.");
+ serverTransport.close();
+ try {
+ acceptSelector.close();
+ } catch (IOException e) {
+ LOGGER.error("Failed to close accept selector.", e);
+ }
+ }
+ }
+
+ private class NetworkThread extends Thread {
+ private final BlockingQueue<TNonblockingTransport> incomingConnections = new LinkedBlockingQueue<>();
+ private final BlockingQueue<NonblockingSaslHandler> stateTransitions = new LinkedBlockingQueue<>();
+ private final Selector ioSelector;
+
+ NetworkThread(String name) throws IOException {
+ super(name);
+ ioSelector = Selector.open();
+ }
+
+ @Override
+ public void run() {
+ try {
+ while (!stopped_) {
+ handleIncomingConnections();
+ handleStateChanges();
+ select();
+ handleIO();
+ }
+ } catch (Throwable e) {
+ LOGGER.error("Unreoverable error in " + getName(), e);
+ } finally {
+ close();
+ }
+ }
+
+ private void handleStateChanges() {
+ while (true) {
+ NonblockingSaslHandler statemachine = stateTransitions.poll();
+ if (statemachine == null) {
+ return;
+ }
+ tryRunNextPhase(statemachine);
+ }
+ }
+
+ private void select() {
+ try {
+ ioSelector.select();
+ } catch (IOException e) {
+ LOGGER.error("Failed to select in " + getName(), e);
+ }
+ }
+
+ private void handleIO() {
+ Iterator<SelectionKey> selectedKeyItr = ioSelector.selectedKeys().iterator();
+ while (!stopped_ && selectedKeyItr.hasNext()) {
+ SelectionKey selected = selectedKeyItr.next();
+ selectedKeyItr.remove();
+ if (!selected.isValid()) {
+ closeChannel(selected);
+ }
+ NonblockingSaslHandler saslHandler = (NonblockingSaslHandler) selected.attachment();
+ if (selected.isReadable()) {
+ saslHandler.handleRead();
+ } else if (selected.isWritable()) {
+ saslHandler.handleWrite();
+ } else {
+ LOGGER.error("Invalid intrest op " + selected.interestOps());
+ closeChannel(selected);
+ continue;
+ }
+ if (saslHandler.isCurrentPhaseDone()) {
+ tryRunNextPhase(saslHandler);
+ }
+ }
+ }
+
+ // The following methods are modifying the registered channel set on the selector, which itself
+ // is not thread safe. Thus we need a lock to protect it from race condition.
+
+ private synchronized void handleIncomingConnections() {
+ while (true) {
+ TNonblockingTransport connection = incomingConnections.poll();
+ if (connection == null) {
+ return;
+ }
+ if (!connection.isOpen()) {
+ LOGGER.warn("Incoming connection is already closed");
+ continue;
+ }
+ try {
+ SelectionKey selectionKey = connection.registerSelector(ioSelector, SelectionKey.OP_READ);
+ if (selectionKey.isValid()) {
+ NonblockingSaslHandler saslHandler = new NonblockingSaslHandler(selectionKey, connection,
+ saslServerFactory, saslProcessorFactory, inputProtocolFactory_, outputProtocolFactory_,
+ eventHandler_);
+ selectionKey.attach(saslHandler);
+ }
+ } catch (IOException e) {
+ LOGGER.error("Failed to register connection for the selector, close it.", e);
+ connection.close();
+ }
+ }
+ }
+
+ private synchronized void close() {
+ LOGGER.warn("Closing " + getName());
+ while (true) {
+ TNonblockingTransport incomingConnection = incomingConnections.poll();
+ if (incomingConnection == null) {
+ break;
+ }
+ incomingConnection.close();
+ }
+ Set<SelectionKey> registered = ioSelector.keys();
+ for (SelectionKey selection : registered) {
+ closeChannel(selection);
+ }
+ try {
+ ioSelector.close();
+ } catch (IOException e) {
+ LOGGER.error("Failed to close io selector " + getName(), e);
+ }
+ }
+
+ private synchronized void closeChannel(SelectionKey selectionKey) {
+ if (selectionKey.attachment() == null) {
+ try {
+ selectionKey.channel().close();
+ } catch (IOException e) {
+ LOGGER.error("Failed to close channel.", e);
+ } finally {
+ selectionKey.cancel();
+ }
+ } else {
+ NonblockingSaslHandler saslHandler = (NonblockingSaslHandler) selectionKey.attachment();
+ saslHandler.close();
+ }
+ }
+
+ private void tryRunNextPhase(NonblockingSaslHandler saslHandler) {
+ Phase nextPhase = saslHandler.getNextPhase();
+ saslHandler.stepToNextPhase();
+ switch (nextPhase) {
+ case EVALUATING_SASL_RESPONSE:
+ authenticationExecutor.submit(new Computation(saslHandler));
+ break;
+ case PROCESSING:
+ processingExecutor.submit(new Computation(saslHandler));
+ break;
+ case CLOSING:
+ saslHandler.runCurrentPhase();
+ break;
+ default: // waiting for next io event for the current state machine
+ }
+ }
+
+ public boolean accept(TNonblockingTransport connection) {
+ if (stopped_) {
+ return false;
+ }
+ if (incomingConnections.offer(connection)) {
+ wakeup();
+ return true;
+ }
+ return false;
+ }
+
+ private void wakeup() {
+ ioSelector.wakeup();
+ }
+
+ private class Computation implements Runnable {
+
+ private final NonblockingSaslHandler statemachine;
+
+ private Computation(NonblockingSaslHandler statemachine) {
+ this.statemachine = statemachine;
+ }
+
+ @Override
+ public void run() {
+ try {
+ while (!statemachine.isCurrentPhaseDone()) {
+ statemachine.runCurrentPhase();
+ }
+ stateTransitions.add(statemachine);
+ wakeup();
+ } catch (Throwable e) {
+ LOGGER.error("Damn it!", e);
+ }
+ }
+ }
+ }
+
+ private class NetworkThreadPool {
+ private final List<NetworkThread> networkThreads;
+ private int accepted = 0;
+
+ NetworkThreadPool(int size) throws IOException {
+ networkThreads = new ArrayList<>(size);
+ int digits = (int) Math.log10(size) + 1;
+ String threadNamePattern = "network-thread-%0" + digits + "d";
+ for (int i = 0; i < size; i++) {
+ networkThreads.add(new NetworkThread(String.format(threadNamePattern, i)));
+ }
+ }
+
+ /**
+ * Round robin new connection among all the network threads.
+ *
+ * @param connection incoming connection.
+ * @return true if the incoming connection is accepted by network thread pool.
+ */
+ boolean acceptNewConnection(TNonblockingTransport connection) {
+ return networkThreads.get((accepted ++) % networkThreads.size()).accept(connection);
+ }
+
+ public void start() {
+ for (NetworkThread thread : networkThreads) {
+ thread.start();
+ }
+ }
+
+ void wakeupAll() {
+ for (NetworkThread networkThread : networkThreads) {
+ networkThread.wakeup();
+ }
+ }
+ }
+
+ public static class Args extends AbstractServerArgs<Args> {
+
+ private int networkThreads = DEFAULT_NETWORK_THREADS;
+ private int saslThreads = DEFAULT_AUTHENTICATION_THREADS;
+ private int processingThreads = DEFAULT_PROCESSING_THREADS;
+ private TSaslServerFactory saslServerFactory = new TSaslServerFactory();
+ private TSaslProcessorFactory saslProcessorFactory;
+
+ public Args(TNonblockingServerTransport transport) {
+ super(transport);
+ }
+
+ public Args networkThreads(int networkThreads) {
+ this.networkThreads = networkThreads <= 0 ? DEFAULT_NETWORK_THREADS : networkThreads;
+ return this;
+ }
+
+ public Args saslThreads(int authenticationThreads) {
+ this.saslThreads = authenticationThreads <= 0 ? DEFAULT_AUTHENTICATION_THREADS : authenticationThreads;
+ return this;
+ }
+
+ public Args processingThreads(int processingThreads) {
+ this.processingThreads = processingThreads <= 0 ? DEFAULT_PROCESSING_THREADS : processingThreads;
+ return this;
+ }
+
+ public Args processor(TProcessor processor) {
+ saslProcessorFactory = new TBaseSaslProcessorFactory(processor);
+ return this;
+ }
+
+ public Args saslProcessorFactory(TSaslProcessorFactory saslProcessorFactory) {
+ if (saslProcessorFactory == null) {
+ throw new NullPointerException("Processor factory cannot be null");
+ }
+ this.saslProcessorFactory = saslProcessorFactory;
+ return this;
+ }
+
+ public Args addSaslMechanism(String mechanism, String protocol, String serverName,
+ Map<String, String> props, CallbackHandler cbh) {
+ saslServerFactory.addSaslMechanism(mechanism, protocol, serverName, props, cbh);
+ return this;
+ }
+
+ public Args saslServerFactory(TSaslServerFactory saslServerFactory) {
+ if (saslServerFactory == null) {
+ throw new NullPointerException("saslServerFactory cannot be null");
+ }
+ this.saslServerFactory = saslServerFactory;
+ return this;
+ }
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/server/TServerEventHandler.java b/lib/java/src/org/apache/thrift/server/TServerEventHandler.java
index f069b9b..3bd7959 100644
--- a/lib/java/src/org/apache/thrift/server/TServerEventHandler.java
+++ b/lib/java/src/org/apache/thrift/server/TServerEventHandler.java
@@ -28,6 +28,10 @@
* about. Your subclass can also store local data that you may care about,
* such as additional "arguments" to these methods (stored in the object
* instance's state).
+ *
+ * TODO: It seems this is a custom code entry point created for some resource management purpose in hive.
+ * But when looking into hive code, we see that the argments of TProtocol and TTransport are never used.
+ * We probably should remove these arguments from all the methods.
*/
public interface TServerEventHandler {
@@ -56,4 +60,4 @@
void processContext(ServerContext serverContext,
TTransport inputTransport, TTransport outputTransport);
-}
\ No newline at end of file
+}
diff --git a/lib/java/src/org/apache/thrift/transport/TEOFException.java b/lib/java/src/org/apache/thrift/transport/TEOFException.java
new file mode 100644
index 0000000..b5ae6ef
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/TEOFException.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport;
+
+/**
+ * End of file, especially, the underlying socket is closed.
+ */
+public class TEOFException extends TTransportException {
+
+ public TEOFException(String message) {
+ super(TTransportException.END_OF_FILE, message);
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/TMemoryTransport.java b/lib/java/src/org/apache/thrift/transport/TMemoryTransport.java
new file mode 100644
index 0000000..f41bc09
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/TMemoryTransport.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport;
+
+import java.nio.ByteBuffer;
+
+import org.apache.thrift.TByteArrayOutputStream;
+
+/**
+ * In memory transport with separate buffers for input and output.
+ */
+public class TMemoryTransport extends TTransport {
+
+ private final ByteBuffer inputBuffer;
+ private final TByteArrayOutputStream outputBuffer;
+
+ public TMemoryTransport(byte[] input) {
+ inputBuffer = ByteBuffer.wrap(input);
+ outputBuffer = new TByteArrayOutputStream(1024);
+ }
+
+ @Override
+ public boolean isOpen() {
+ return true;
+ }
+
+ /**
+ * Opening on an in memory transport should have no effect.
+ */
+ @Override
+ public void open() {
+ // Do nothing.
+ }
+
+ @Override
+ public void close() {
+ // Do nothing.
+ }
+
+ @Override
+ public int read(byte[] buf, int off, int len) throws TTransportException {
+ int remaining = inputBuffer.remaining();
+ if (remaining < len) {
+ throw new TTransportException(TTransportException.END_OF_FILE,
+ "There's only " + remaining + "bytes, but it asks for " + len);
+ }
+ inputBuffer.get(buf, off, len);
+ return len;
+ }
+
+ @Override
+ public void write(byte[] buf, int off, int len) throws TTransportException {
+ outputBuffer.write(buf, off, len);
+ }
+
+ /**
+ * Get all the bytes written by thrift output protocol.
+ *
+ * @return a byte array.
+ */
+ public TByteArrayOutputStream getOutput() {
+ return outputBuffer;
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/TNonblockingServerSocket.java b/lib/java/src/org/apache/thrift/transport/TNonblockingServerSocket.java
index df37cb0..1631892 100644
--- a/lib/java/src/org/apache/thrift/transport/TNonblockingServerSocket.java
+++ b/lib/java/src/org/apache/thrift/transport/TNonblockingServerSocket.java
@@ -108,7 +108,8 @@
}
}
- protected TNonblockingSocket acceptImpl() throws TTransportException {
+ @Override
+ public TNonblockingSocket accept() throws TTransportException {
if (serverSocket_ == null) {
throw new TTransportException(TTransportException.NOT_OPEN, "No underlying server socket.");
}
@@ -160,4 +161,9 @@
return serverSocket_.getLocalPort();
}
+ // Expose it for test purpose.
+ ServerSocketChannel getServerSocketChannel() {
+ return serverSocketChannel;
+ }
+
}
diff --git a/lib/java/src/org/apache/thrift/transport/TNonblockingServerTransport.java b/lib/java/src/org/apache/thrift/transport/TNonblockingServerTransport.java
index ba45b09..daac0d5 100644
--- a/lib/java/src/org/apache/thrift/transport/TNonblockingServerTransport.java
+++ b/lib/java/src/org/apache/thrift/transport/TNonblockingServerTransport.java
@@ -28,4 +28,12 @@
public abstract class TNonblockingServerTransport extends TServerTransport {
public abstract void registerSelector(Selector selector);
+
+ /**
+ *
+ * @return an incoming connection or null if there is none.
+ * @throws TTransportException
+ */
+ @Override
+ public abstract TNonblockingTransport accept() throws TTransportException;
}
diff --git a/lib/java/src/org/apache/thrift/transport/TNonblockingSocket.java b/lib/java/src/org/apache/thrift/transport/TNonblockingSocket.java
index f86a48b..37a66d6 100644
--- a/lib/java/src/org/apache/thrift/transport/TNonblockingSocket.java
+++ b/lib/java/src/org/apache/thrift/transport/TNonblockingSocket.java
@@ -207,4 +207,9 @@
return socketChannel_.finishConnect();
}
+ @Override
+ public String toString() {
+ return "[remote: " + socketChannel_.socket().getRemoteSocketAddress() +
+ ", local: " + socketChannel_.socket().getLocalAddress() + "]" ;
+ }
}
diff --git a/lib/java/src/org/apache/thrift/transport/TSaslClientTransport.java b/lib/java/src/org/apache/thrift/transport/TSaslClientTransport.java
index 4b1ca0a..5fc7cff 100644
--- a/lib/java/src/org/apache/thrift/transport/TSaslClientTransport.java
+++ b/lib/java/src/org/apache/thrift/transport/TSaslClientTransport.java
@@ -27,6 +27,7 @@
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
+import org.apache.thrift.transport.sasl.NegotiationStatus;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
diff --git a/lib/java/src/org/apache/thrift/transport/TSaslServerTransport.java b/lib/java/src/org/apache/thrift/transport/TSaslServerTransport.java
index 39b81ca..31f309e 100644
--- a/lib/java/src/org/apache/thrift/transport/TSaslServerTransport.java
+++ b/lib/java/src/org/apache/thrift/transport/TSaslServerTransport.java
@@ -31,6 +31,8 @@
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
+import org.apache.thrift.transport.sasl.NegotiationStatus;
+import org.apache.thrift.transport.sasl.TSaslServerDefinition;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -50,29 +52,9 @@
private Map<String, TSaslServerDefinition> serverDefinitionMap = new HashMap<String, TSaslServerDefinition>();
/**
- * Contains all the parameters used to define a SASL server implementation.
- */
- private static class TSaslServerDefinition {
- public String mechanism;
- public String protocol;
- public String serverName;
- public Map<String, String> props;
- public CallbackHandler cbh;
-
- public TSaslServerDefinition(String mechanism, String protocol, String serverName,
- Map<String, String> props, CallbackHandler cbh) {
- this.mechanism = mechanism;
- this.protocol = protocol;
- this.serverName = serverName;
- this.props = props;
- this.cbh = cbh;
- }
- }
-
- /**
* Uses the given underlying transport. Assumes that addServerDefinition is
* called later.
- *
+ *
* @param transport
* Transport underlying this one.
*/
@@ -84,7 +66,7 @@
* Creates a <code>SaslServer</code> using the given SASL-specific parameters.
* See the Java documentation for <code>Sasl.createSaslServer</code> for the
* details of the parameters.
- *
+ *
* @param transport
* The underlying Thrift transport.
*/
diff --git a/lib/java/src/org/apache/thrift/transport/TSaslTransport.java b/lib/java/src/org/apache/thrift/transport/TSaslTransport.java
index 4685d64..d1a3d31 100644
--- a/lib/java/src/org/apache/thrift/transport/TSaslTransport.java
+++ b/lib/java/src/org/apache/thrift/transport/TSaslTransport.java
@@ -20,8 +20,6 @@
package org.apache.thrift.transport;
import java.nio.charset.StandardCharsets;
-import java.util.HashMap;
-import java.util.Map;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
@@ -30,6 +28,7 @@
import org.apache.thrift.EncodingUtils;
import org.apache.thrift.TByteArrayOutputStream;
+import org.apache.thrift.transport.sasl.NegotiationStatus;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -52,39 +51,6 @@
}
/**
- * Status bytes used during the initial Thrift SASL handshake.
- */
- protected static enum NegotiationStatus {
- START((byte)0x01),
- OK((byte)0x02),
- BAD((byte)0x03),
- ERROR((byte)0x04),
- COMPLETE((byte)0x05);
-
- private final byte value;
-
- private static final Map<Byte, NegotiationStatus> reverseMap =
- new HashMap<Byte, NegotiationStatus>();
- static {
- for (NegotiationStatus s : NegotiationStatus.class.getEnumConstants()) {
- reverseMap.put(s.getValue(), s);
- }
- }
-
- private NegotiationStatus(byte val) {
- this.value = val;
- }
-
- public byte getValue() {
- return value;
- }
-
- public static NegotiationStatus byValue(byte val) {
- return reverseMap.get(val);
- }
- }
-
- /**
* Transport underlying this one.
*/
protected TTransport underlyingTransport;
diff --git a/lib/java/src/org/apache/thrift/transport/TServerSocket.java b/lib/java/src/org/apache/thrift/transport/TServerSocket.java
index 79f7b7f..eb302fd 100644
--- a/lib/java/src/org/apache/thrift/transport/TServerSocket.java
+++ b/lib/java/src/org/apache/thrift/transport/TServerSocket.java
@@ -121,18 +121,23 @@
}
}
- protected TSocket acceptImpl() throws TTransportException {
+ @Override
+ public TSocket accept() throws TTransportException {
if (serverSocket_ == null) {
throw new TTransportException(TTransportException.NOT_OPEN, "No underlying server socket.");
}
+ Socket result;
try {
- Socket result = serverSocket_.accept();
- TSocket result2 = new TSocket(result);
- result2.setTimeout(clientTimeout_);
- return result2;
- } catch (IOException iox) {
- throw new TTransportException(iox);
+ result = serverSocket_.accept();
+ } catch (Exception e) {
+ throw new TTransportException(e);
}
+ if (result == null) {
+ throw new TTransportException("Blocking server's accept() may not return NULL");
+ }
+ TSocket socket = new TSocket(result);
+ socket.setTimeout(clientTimeout_);
+ return socket;
}
public void close() {
diff --git a/lib/java/src/org/apache/thrift/transport/TServerTransport.java b/lib/java/src/org/apache/thrift/transport/TServerTransport.java
index 424e4fa..55ef0c4 100644
--- a/lib/java/src/org/apache/thrift/transport/TServerTransport.java
+++ b/lib/java/src/org/apache/thrift/transport/TServerTransport.java
@@ -56,18 +56,18 @@
public abstract void listen() throws TTransportException;
- public final TTransport accept() throws TTransportException {
- TTransport transport = acceptImpl();
- if (transport == null) {
- throw new TTransportException("accept() may not return NULL");
- }
- return transport;
- }
+ /**
+ * Accept incoming connection on the server socket. When there is no incoming connection available:
+ * either it should block infinitely in a blocking implementation, either it should return null in
+ * a nonblocking implementation.
+ *
+ * @return new connection
+ * @throws TTransportException if IO error.
+ */
+ public abstract TTransport accept() throws TTransportException;
public abstract void close();
- protected abstract TTransport acceptImpl() throws TTransportException;
-
/**
* Optional method implementation. This signals to the server transport
* that it should break out of any accept() or listen() that it is currently
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/DataFrameHeaderReader.java b/lib/java/src/org/apache/thrift/transport/sasl/DataFrameHeaderReader.java
new file mode 100644
index 0000000..2900df9
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/DataFrameHeaderReader.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+/**
+ * The header for data frame, it only contains a 4-byte payload size.
+ */
+public class DataFrameHeaderReader extends FixedSizeHeaderReader {
+ public static final int PAYLOAD_LENGTH_BYTES = 4;
+
+ private int payloadSize;
+
+ @Override
+ protected int headerSize() {
+ return PAYLOAD_LENGTH_BYTES;
+ }
+
+ @Override
+ protected void onComplete() throws TInvalidSaslFrameException {
+ payloadSize = byteBuffer.getInt(0);
+ if (payloadSize < 0) {
+ throw new TInvalidSaslFrameException("Payload size is negative: " + payloadSize);
+ }
+ }
+
+ @Override
+ public int payloadSize() {
+ return payloadSize;
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/DataFrameReader.java b/lib/java/src/org/apache/thrift/transport/sasl/DataFrameReader.java
new file mode 100644
index 0000000..e6900bb
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/DataFrameReader.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+/**
+ * Frames for thrift (serialized) messages.
+ */
+public class DataFrameReader extends FrameReader<DataFrameHeaderReader> {
+
+ public DataFrameReader() {
+ super(new DataFrameHeaderReader());
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/DataFrameWriter.java b/lib/java/src/org/apache/thrift/transport/sasl/DataFrameWriter.java
new file mode 100644
index 0000000..a2dd15a
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/DataFrameWriter.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import java.nio.ByteBuffer;
+
+import org.apache.thrift.EncodingUtils;
+import org.apache.thrift.utils.StringUtils;
+
+import static org.apache.thrift.transport.sasl.DataFrameHeaderReader.PAYLOAD_LENGTH_BYTES;
+
+/**
+ * Write frames of thrift messages. It expects an empty/null header to be provided with a payload
+ * to be written out. Non empty headers are considered as error.
+ */
+public class DataFrameWriter extends FrameWriter {
+
+ @Override
+ public void withOnlyPayload(byte[] payload, int offset, int length) {
+ if (!isComplete()) {
+ throw new IllegalStateException("Previsous write is not yet complete, with " +
+ frameBytes.remaining() + " bytes left.");
+ }
+ frameBytes = buildFrameWithPayload(payload, offset, length);
+ }
+
+ @Override
+ protected ByteBuffer buildFrame(byte[] header, int headerOffset, int headerLength,
+ byte[] payload, int payloadOffset, int payloadLength) {
+ if (header != null && headerLength > 0) {
+ throw new IllegalArgumentException("Extra header [" + StringUtils.bytesToHexString(header) +
+ "] offset " + payloadOffset + " length " + payloadLength);
+ }
+ return buildFrameWithPayload(payload, payloadOffset, payloadLength);
+ }
+
+ private ByteBuffer buildFrameWithPayload(byte[] payload, int offset, int length) {
+ byte[] bytes = new byte[PAYLOAD_LENGTH_BYTES + length];
+ EncodingUtils.encodeBigEndian(length, bytes, 0);
+ System.arraycopy(payload, offset, bytes, PAYLOAD_LENGTH_BYTES, length);
+ return ByteBuffer.wrap(bytes);
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/FixedSizeHeaderReader.java b/lib/java/src/org/apache/thrift/transport/sasl/FixedSizeHeaderReader.java
new file mode 100644
index 0000000..1cbc0ac
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/FixedSizeHeaderReader.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import org.apache.thrift.transport.TTransport;
+import org.apache.thrift.transport.TTransportException;
+import org.apache.thrift.utils.StringUtils;
+
+import java.nio.ByteBuffer;
+
+/**
+ * Headers' size should be predefined.
+ */
+public abstract class FixedSizeHeaderReader implements FrameHeaderReader {
+
+ protected final ByteBuffer byteBuffer = ByteBuffer.allocate(headerSize());
+
+ @Override
+ public boolean isComplete() {
+ return !byteBuffer.hasRemaining();
+ }
+
+ @Override
+ public void clear() {
+ byteBuffer.clear();
+ }
+
+ @Override
+ public byte[] toBytes() {
+ if (!isComplete()) {
+ throw new IllegalStateException("Header is not yet complete " + StringUtils.bytesToHexString(byteBuffer.array(), 0, byteBuffer.position()));
+ }
+ return byteBuffer.array();
+ }
+
+ @Override
+ public boolean read(TTransport transport) throws TTransportException {
+ FrameReader.readAvailable(transport, byteBuffer);
+ if (byteBuffer.hasRemaining()) {
+ return false;
+ }
+ onComplete();
+ return true;
+ }
+
+ /**
+ * @return Size of the header.
+ */
+ protected abstract int headerSize();
+
+ /**
+ * Actions (e.g. validation) to carry out when the header is complete.
+ *
+ * @throws TTransportException
+ */
+ protected abstract void onComplete() throws TTransportException;
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/FrameHeaderReader.java b/lib/java/src/org/apache/thrift/transport/sasl/FrameHeaderReader.java
new file mode 100644
index 0000000..f7c6593
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/FrameHeaderReader.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import org.apache.thrift.transport.TTransport;
+import org.apache.thrift.transport.TTransportException;
+
+/**
+ * Read headers for a frame. For each frame, the header contains payload size and other metadata.
+ */
+public interface FrameHeaderReader {
+
+ /**
+ * As the thrift sasl specification states, all sasl messages (both for negotiatiing and for
+ * sending data) should have a header to indicate the size of the payload.
+ *
+ * @return size of the payload.
+ */
+ int payloadSize();
+
+ /**
+ *
+ * @return The received bytes for the header.
+ * @throws IllegalStateException if isComplete returns false.
+ */
+ byte[] toBytes();
+
+ /**
+ * @return true if this header has all its fields set.
+ */
+ boolean isComplete();
+
+ /**
+ * Clear the header and make it available to read a new header.
+ */
+ void clear();
+
+ /**
+ * (Nonblocking) Read fields from underlying transport layer.
+ *
+ * @param transport underlying transport.
+ * @return true if header is complete after read.
+ * @throws TSaslNegotiationException if fail to read a valid header of a sasl negotiation message.
+ * @throws TTransportException if io error.
+ */
+ boolean read(TTransport transport) throws TSaslNegotiationException, TTransportException;
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/FrameReader.java b/lib/java/src/org/apache/thrift/transport/sasl/FrameReader.java
new file mode 100644
index 0000000..acb4b73
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/FrameReader.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import org.apache.thrift.transport.TEOFException;
+import org.apache.thrift.transport.TTransport;
+import org.apache.thrift.transport.TTransportException;
+
+import java.nio.ByteBuffer;
+
+/**
+ * Read frames from a transport. Each frame has a header and a payload. A header will indicate
+ * the size of the payload and other informations about how to decode payload.
+ * Implementations should subclass it by providing a header reader implementation.
+ *
+ * @param <T> Header type.
+ */
+public abstract class FrameReader<T extends FrameHeaderReader> {
+ private final T header;
+ private ByteBuffer payload;
+
+ protected FrameReader(T header) {
+ this.header = header;
+ }
+
+ /**
+ * (Nonblocking) Read available bytes out of the transport without blocking to wait for incoming
+ * data.
+ *
+ * @param transport TTransport
+ * @return true if current frame is complete after read.
+ * @throws TSaslNegotiationException if fail to read back a valid sasl negotiation message.
+ * @throws TTransportException if io error.
+ */
+ public boolean read(TTransport transport) throws TSaslNegotiationException, TTransportException {
+ if (!header.isComplete()) {
+ if (readHeader(transport)) {
+ payload = ByteBuffer.allocate(header.payloadSize());
+ } else {
+ return false;
+ }
+ }
+ if (header.payloadSize() == 0) {
+ return true;
+ }
+ return readPayload(transport);
+ }
+
+ /**
+ * (Nonblocking) Try to read available header bytes from transport.
+ *
+ * @return true if header is complete after read.
+ * @throws TSaslNegotiationException if fail to read back a validd sasl negotiation header.
+ * @throws TTransportException if io error.
+ */
+ private boolean readHeader(TTransport transport) throws TSaslNegotiationException, TTransportException {
+ return header.read(transport);
+ }
+
+ /**
+ * (Nonblocking) Try to read available
+ *
+ * @param transport underlying transport.
+ * @return true if payload is complete after read.
+ * @throws TTransportException if io error.
+ */
+ private boolean readPayload(TTransport transport) throws TTransportException {
+ readAvailable(transport, payload);
+ return payload.hasRemaining();
+ }
+
+ /**
+ *
+ * @return header of the frame
+ */
+ public T getHeader() {
+ return header;
+ }
+
+ /**
+ *
+ * @return number of bytes of the header
+ */
+ public int getHeaderSize() {
+ return header.toBytes().length;
+ }
+
+ /**
+ *
+ * @return byte array of the payload
+ */
+ public byte[] getPayload() {
+ return payload.array();
+ }
+
+ /**
+ *
+ * @return size of the payload
+ */
+ public int getPayloadSize() {
+ return header.payloadSize();
+ }
+
+ /**
+ *
+ * @return true if the reader has fully read a frame
+ */
+ public boolean isComplete() {
+ return !(payload == null || payload.hasRemaining());
+ }
+
+ /**
+ * Reset the state of the reader so that it can be reused to read a new frame.
+ */
+ public void clear() {
+ header.clear();
+ payload = null;
+ }
+
+ /**
+ * Read immediately available bytes from the transport into the byte buffer.
+ *
+ * @param transport TTransport
+ * @param recipient ByteBuffer
+ * @return number of bytes read out of the transport
+ * @throws TTransportException if io error
+ */
+ static int readAvailable(TTransport transport, ByteBuffer recipient) throws TTransportException {
+ if (!recipient.hasRemaining()) {
+ throw new IllegalStateException("Trying to fill a full recipient with " + recipient.limit()
+ + " bytes");
+ }
+ int currentPosition = recipient.position();
+ byte[] bytes = recipient.array();
+ int offset = recipient.arrayOffset() + currentPosition;
+ int expectedLength = recipient.remaining();
+ int got = transport.read(bytes, offset, expectedLength);
+ if (got < 0) {
+ throw new TEOFException("Transport is closed, while trying to read " + expectedLength +
+ " bytes");
+ }
+ recipient.position(currentPosition + got);
+ return got;
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/FrameWriter.java b/lib/java/src/org/apache/thrift/transport/sasl/FrameWriter.java
new file mode 100644
index 0000000..5f48121
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/FrameWriter.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+import org.apache.thrift.transport.TNonblockingTransport;
+
+/**
+ * Write frame (header and payload) to transport in a nonblocking way.
+ */
+public abstract class FrameWriter {
+
+ protected ByteBuffer frameBytes;
+
+ /**
+ * Provide (maybe empty) header and payload to the frame. This can be called only when isComplete
+ * returns true (last frame has been written out).
+ *
+ * @param header Some extra header bytes (without the 4 bytes for payload length), which will be
+ * the start of the frame. It can be empty, depending on the message format
+ * @param payload Payload as a byte array
+ * @throws IllegalStateException if it is called when isComplete returns false
+ * @throws IllegalArgumentException if header or payload is invalid
+ */
+ public void withHeaderAndPayload(byte[] header, byte[] payload) {
+ if (payload == null) {
+ payload = new byte[0];
+ }
+ if (header == null) {
+ withOnlyPayload(payload);
+ } else {
+ withHeaderAndPayload(header, 0, header.length, payload, 0, payload.length);
+ }
+ }
+
+ /**
+ * Provide extra header and payload to the frame.
+ *
+ * @param header byte array containing the extra header
+ * @param headerOffset starting offset of the header portition
+ * @param headerLength length of the extra header
+ * @param payload byte array containing the payload
+ * @param payloadOffset starting offset of the payload portion
+ * @param payloadLength length of the payload
+ * @throws IllegalStateException if preivous frame is not yet complete (isComplete returns fals)
+ * @throws IllegalArgumentException if header or payload is invalid
+ */
+ public void withHeaderAndPayload(byte[] header, int headerOffset, int headerLength,
+ byte[] payload, int payloadOffset, int payloadLength) {
+ if (!isComplete()) {
+ throw new IllegalStateException("Previsous write is not yet complete, with " +
+ frameBytes.remaining() + " bytes left.");
+ }
+ frameBytes = buildFrame(header, headerOffset, headerLength, payload, payloadOffset, payloadLength);
+ }
+
+ /**
+ * Provide only payload to the frame. Throws UnsupportedOperationException if the frame expects
+ * a header.
+ *
+ * @param payload payload as a byte array
+ */
+ public void withOnlyPayload(byte[] payload) {
+ withOnlyPayload(payload, 0, payload.length);
+ }
+
+ /**
+ * Provide only payload to the frame. Throws UnsupportedOperationException if the frame expects
+ * a header.
+ *
+ * @param payload The underlying byte array as a recipient of the payload
+ * @param offset The offset in the byte array starting from where the payload is located
+ * @param length The length of the payload
+ */
+ public abstract void withOnlyPayload(byte[] payload, int offset, int length);
+
+ protected abstract ByteBuffer buildFrame(byte[] header, int headerOffset, int headerLength,
+ byte[] payload, int payloadOffset, int payloadeLength);
+
+ /**
+ * Nonblocking write to the underlying transport.
+ *
+ * @throws IOException
+ */
+ public void write(TNonblockingTransport transport) throws IOException {
+ transport.write(frameBytes);
+ }
+
+ /**
+ *
+ * @return true when no more data needs to be written out
+ */
+ public boolean isComplete() {
+ return frameBytes == null || !frameBytes.hasRemaining();
+ }
+
+ /**
+ * Release the byte buffer.
+ */
+ public void clear() {
+ frameBytes = null;
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/NegotiationStatus.java b/lib/java/src/org/apache/thrift/transport/sasl/NegotiationStatus.java
new file mode 100644
index 0000000..ad704a0
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/NegotiationStatus.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.apache.thrift.transport.sasl.TSaslNegotiationException.ErrorType.PROTOCOL_ERROR;
+
+/**
+ * Status bytes used during the initial Thrift SASL handshake.
+ */
+public enum NegotiationStatus {
+ START((byte)0x01),
+ OK((byte)0x02),
+ BAD((byte)0x03),
+ ERROR((byte)0x04),
+ COMPLETE((byte)0x05);
+
+ private static final Map<Byte, NegotiationStatus> reverseMap = new HashMap<>();
+
+ static {
+ for (NegotiationStatus s : NegotiationStatus.values()) {
+ reverseMap.put(s.getValue(), s);
+ }
+ }
+
+ private final byte value;
+
+ NegotiationStatus(byte val) {
+ this.value = val;
+ }
+
+ public byte getValue() {
+ return value;
+ }
+
+ public static NegotiationStatus byValue(byte val) throws TSaslNegotiationException {
+ if (!reverseMap.containsKey(val)) {
+ throw new TSaslNegotiationException(PROTOCOL_ERROR, "Invalid status " + val);
+ }
+ return reverseMap.get(val);
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/NonblockingSaslHandler.java b/lib/java/src/org/apache/thrift/transport/sasl/NonblockingSaslHandler.java
new file mode 100644
index 0000000..4557146
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/NonblockingSaslHandler.java
@@ -0,0 +1,528 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import java.io.IOException;
+import java.nio.channels.SelectionKey;
+import java.nio.charset.StandardCharsets;
+
+import javax.security.sasl.SaslServer;
+
+import org.apache.thrift.TByteArrayOutputStream;
+import org.apache.thrift.TProcessor;
+import org.apache.thrift.protocol.TProtocol;
+import org.apache.thrift.protocol.TProtocolFactory;
+import org.apache.thrift.server.ServerContext;
+import org.apache.thrift.server.TServerEventHandler;
+import org.apache.thrift.transport.TMemoryTransport;
+import org.apache.thrift.transport.TNonblockingTransport;
+import org.apache.thrift.transport.TTransportException;
+import org.apache.thrift.transport.sasl.TSaslNegotiationException.ErrorType;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static org.apache.thrift.transport.sasl.NegotiationStatus.COMPLETE;
+import static org.apache.thrift.transport.sasl.NegotiationStatus.OK;
+
+/**
+ * State machine managing one sasl connection in a nonblocking way.
+ */
+public class NonblockingSaslHandler {
+ private static final Logger LOGGER = LoggerFactory.getLogger(NonblockingSaslHandler.class);
+
+ private static final int INTEREST_NONE = 0;
+ private static final int INTEREST_READ = SelectionKey.OP_READ;
+ private static final int INTEREST_WRITE = SelectionKey.OP_WRITE;
+
+ // Tracking the current running phase
+ private Phase currentPhase = Phase.INITIIALIIZING;
+ // Tracking the next phase on the next invocation of the state machine.
+ // It should be the same as current phase if current phase is not yet finished.
+ // Otherwise, if it is different from current phase, the statemachine is in a transition state:
+ // current phase is done, and next phase is not yet started.
+ private Phase nextPhase = currentPhase;
+
+ // Underlying nonblocking transport
+ private SelectionKey selectionKey;
+ private TNonblockingTransport underlyingTransport;
+
+ // APIs for intercepting event / customizing behaviors:
+ // Factories (decorating the base implementations) & EventHandler (intercepting)
+ private TSaslServerFactory saslServerFactory;
+ private TSaslProcessorFactory processorFactory;
+ private TProtocolFactory inputProtocolFactory;
+ private TProtocolFactory outputProtocolFactory;
+ private TServerEventHandler eventHandler;
+ private ServerContext serverContext;
+ // It turns out the event handler implementation in hive sometimes creates a null ServerContext.
+ // In order to know whether TServerEventHandler#createContext is called we use such a flag.
+ private boolean serverContextCreated = false;
+
+ // Wrapper around sasl server
+ private ServerSaslPeer saslPeer;
+
+ // Sasl negotiation io
+ private SaslNegotiationFrameReader saslResponse;
+ private SaslNegotiationFrameWriter saslChallenge;
+ // IO for request from and response to the socket
+ private DataFrameReader requestReader;
+ private DataFrameWriter responseWriter;
+ // If sasl is negotiated for integrity/confidentiality protection
+ private boolean dataProtected;
+
+ public NonblockingSaslHandler(SelectionKey selectionKey, TNonblockingTransport underlyingTransport,
+ TSaslServerFactory saslServerFactory, TSaslProcessorFactory processorFactory,
+ TProtocolFactory inputProtocolFactory, TProtocolFactory outputProtocolFactory,
+ TServerEventHandler eventHandler) {
+ this.selectionKey = selectionKey;
+ this.underlyingTransport = underlyingTransport;
+ this.saslServerFactory = saslServerFactory;
+ this.processorFactory = processorFactory;
+ this.inputProtocolFactory = inputProtocolFactory;
+ this.outputProtocolFactory = outputProtocolFactory;
+ this.eventHandler = eventHandler;
+
+ saslResponse = new SaslNegotiationFrameReader();
+ saslChallenge = new SaslNegotiationFrameWriter();
+ requestReader = new DataFrameReader();
+ responseWriter = new DataFrameWriter();
+ }
+
+ /**
+ * Get current phase of the state machine.
+ *
+ * @return current phase.
+ */
+ public Phase getCurrentPhase() {
+ return currentPhase;
+ }
+
+ /**
+ * Get next phase of the state machine.
+ * It is different from current phase iff current phase is done (and next phase not yet started).
+ *
+ * @return next phase.
+ */
+ public Phase getNextPhase() {
+ return nextPhase;
+ }
+
+ /**
+ *
+ * @return underlying nonblocking socket
+ */
+ public TNonblockingTransport getUnderlyingTransport() {
+ return underlyingTransport;
+ }
+
+ /**
+ *
+ * @return SaslServer instance
+ */
+ public SaslServer getSaslServer() {
+ return saslPeer.getSaslServer();
+ }
+
+ /**
+ *
+ * @return true if current phase is done.
+ */
+ public boolean isCurrentPhaseDone() {
+ return currentPhase != nextPhase;
+ }
+
+ /**
+ * Run state machine.
+ *
+ * @throws IllegalStateException if current state is already done.
+ */
+ public void runCurrentPhase() {
+ currentPhase.runStateMachine(this);
+ }
+
+ /**
+ * When current phase is intrested in read selection, calling this will run the current phase and
+ * its following phases if the following ones are interested to read, until there is nothing
+ * available in the underlying transport.
+ *
+ * @throws IllegalStateException if is called in an irrelevant phase.
+ */
+ public void handleRead() {
+ handleOps(INTEREST_READ);
+ }
+
+ /**
+ * Similiar to handleRead. But it is for write ops.
+ *
+ * @throws IllegalStateException if it is called in an irrelevant phase.
+ */
+ public void handleWrite() {
+ handleOps(INTEREST_WRITE);
+ }
+
+ private void handleOps(int interestOps) {
+ if (currentPhase.selectionInterest != interestOps) {
+ throw new IllegalStateException("Current phase " + currentPhase + " but got interest " +
+ interestOps);
+ }
+ runCurrentPhase();
+ if (isCurrentPhaseDone() && nextPhase.selectionInterest == interestOps) {
+ stepToNextPhase();
+ handleOps(interestOps);
+ }
+ }
+
+ /**
+ * When current phase is finished, it's expected to call this method first before running the
+ * state machine again.
+ * By calling this, "next phase" is marked as started (and not done), thus is ready to run.
+ *
+ * @throws IllegalArgumentException if current phase is not yet done.
+ */
+ public void stepToNextPhase() {
+ if (!isCurrentPhaseDone()) {
+ throw new IllegalArgumentException("Not yet done with current phase: " + currentPhase);
+ }
+ LOGGER.debug("Switch phase {} to {}", currentPhase, nextPhase);
+ switch (nextPhase) {
+ case INITIIALIIZING:
+ throw new IllegalStateException("INITIALIZING cannot be the next phase of " + currentPhase);
+ default:
+ }
+ // If next phase's interest is not the same as current, nor the same as the selection key,
+ // we need to change interest on the selector.
+ if (!(nextPhase.selectionInterest == currentPhase.selectionInterest ||
+ nextPhase.selectionInterest == selectionKey.interestOps())) {
+ changeSelectionInterest(nextPhase.selectionInterest);
+ }
+ currentPhase = nextPhase;
+ }
+
+ private void changeSelectionInterest(int selectionInterest) {
+ selectionKey.interestOps(selectionInterest);
+ }
+
+ // sasl negotiaion failure handling
+ private void failSaslNegotiation(TSaslNegotiationException e) {
+ LOGGER.error("Sasl negotiation failed", e);
+ String errorMsg = e.getDetails();
+ saslChallenge.withHeaderAndPayload(new byte[]{e.getErrorType().code.getValue()},
+ errorMsg.getBytes(StandardCharsets.UTF_8));
+ nextPhase = Phase.WRITING_FAILURE_MESSAGE;
+ }
+
+ private void fail(Exception e) {
+ LOGGER.error("Failed io in " + currentPhase, e);
+ nextPhase = Phase.CLOSING;
+ }
+
+ private void failIO(TTransportException e) {
+ StringBuilder errorMsg = new StringBuilder("IO failure ")
+ .append(e.getType())
+ .append(" in ")
+ .append(currentPhase);
+ if (e.getMessage() != null) {
+ errorMsg.append(": ").append(e.getMessage());
+ }
+ LOGGER.error(errorMsg.toString(), e);
+ nextPhase = Phase.CLOSING;
+ }
+
+ // Read handlings
+
+ private void handleInitializing() {
+ try {
+ saslResponse.read(underlyingTransport);
+ if (saslResponse.isComplete()) {
+ SaslNegotiationHeaderReader startHeader = saslResponse.getHeader();
+ if (startHeader.getStatus() != NegotiationStatus.START) {
+ throw new TInvalidSaslFrameException("Expecting START status but got " + startHeader.getStatus());
+ }
+ String mechanism = new String(saslResponse.getPayload(), StandardCharsets.UTF_8);
+ saslPeer = saslServerFactory.getSaslPeer(mechanism);
+ saslResponse.clear();
+ nextPhase = Phase.READING_SASL_RESPONSE;
+ }
+ } catch (TSaslNegotiationException e) {
+ failSaslNegotiation(e);
+ } catch (TTransportException e) {
+ failIO(e);
+ }
+ }
+
+ private void handleReadingSaslResponse() {
+ try {
+ saslResponse.read(underlyingTransport);
+ if (saslResponse.isComplete()) {
+ nextPhase = Phase.EVALUATING_SASL_RESPONSE;
+ }
+ } catch (TSaslNegotiationException e) {
+ failSaslNegotiation(e);
+ } catch (TTransportException e) {
+ failIO(e);
+ }
+ }
+
+ private void handleReadingRequest() {
+ try {
+ requestReader.read(underlyingTransport);
+ if (requestReader.isComplete()) {
+ nextPhase = Phase.PROCESSING;
+ }
+ } catch (TTransportException e) {
+ failIO(e);
+ }
+ }
+
+ // Computation executions
+
+ private void executeEvaluatingSaslResponse() {
+ if (!(saslResponse.getHeader().getStatus() == OK || saslResponse.getHeader().getStatus() == COMPLETE)) {
+ String error = "Expect status OK or COMPLETE, but got " + saslResponse.getHeader().getStatus();
+ failSaslNegotiation(new TSaslNegotiationException(ErrorType.PROTOCOL_ERROR, error));
+ return;
+ }
+ try {
+ byte[] response = saslResponse.getPayload();
+ saslResponse.clear();
+ byte[] newChallenge = saslPeer.evaluate(response);
+ if (saslPeer.isAuthenticated()) {
+ dataProtected = saslPeer.isDataProtected();
+ saslChallenge.withHeaderAndPayload(new byte[]{COMPLETE.getValue()}, newChallenge);
+ nextPhase = Phase.WRITING_SUCCESS_MESSAGE;
+ } else {
+ saslChallenge.withHeaderAndPayload(new byte[]{OK.getValue()}, newChallenge);
+ nextPhase = Phase.WRITING_SASL_CHALLENGE;
+ }
+ } catch (TSaslNegotiationException e) {
+ failSaslNegotiation(e);
+ }
+ }
+
+ private void executeProcessing() {
+ try {
+ byte[] inputPayload = requestReader.getPayload();
+ requestReader.clear();
+ byte[] rawInput = dataProtected ? saslPeer.unwrap(inputPayload) : inputPayload;
+ TMemoryTransport memoryTransport = new TMemoryTransport(rawInput);
+ TProtocol requestProtocol = inputProtocolFactory.getProtocol(memoryTransport);
+ TProtocol responseProtocol = outputProtocolFactory.getProtocol(memoryTransport);
+
+ if (eventHandler != null) {
+ if (!serverContextCreated) {
+ serverContext = eventHandler.createContext(requestProtocol, responseProtocol);
+ serverContextCreated = true;
+ }
+ eventHandler.processContext(serverContext, memoryTransport, memoryTransport);
+ }
+
+ TProcessor processor = processorFactory.getProcessor(this);
+ processor.process(requestProtocol, responseProtocol);
+ TByteArrayOutputStream rawOutput = memoryTransport.getOutput();
+ if (rawOutput.len() == 0) {
+ // This is a oneway request, no response to send back. Waiting for next incoming request.
+ nextPhase = Phase.READING_REQUEST;
+ return;
+ }
+ if (dataProtected) {
+ byte[] outputPayload = saslPeer.wrap(rawOutput.get(), 0, rawOutput.len());
+ responseWriter.withOnlyPayload(outputPayload);
+ } else {
+ responseWriter.withOnlyPayload(rawOutput.get(), 0 ,rawOutput.len());
+ }
+ nextPhase = Phase.WRITING_RESPONSE;
+ } catch (TTransportException e) {
+ failIO(e);
+ } catch (Exception e) {
+ fail(e);
+ }
+ }
+
+ // Write handlings
+
+ private void handleWritingSaslChallenge() {
+ try {
+ saslChallenge.write(underlyingTransport);
+ if (saslChallenge.isComplete()) {
+ saslChallenge.clear();
+ nextPhase = Phase.READING_SASL_RESPONSE;
+ }
+ } catch (IOException e) {
+ fail(e);
+ }
+ }
+
+ private void handleWritingSuccessMessage() {
+ try {
+ saslChallenge.write(underlyingTransport);
+ if (saslChallenge.isComplete()) {
+ LOGGER.debug("Authentication is done.");
+ saslChallenge = null;
+ saslResponse = null;
+ nextPhase = Phase.READING_REQUEST;
+ }
+ } catch (IOException e) {
+ fail(e);
+ }
+ }
+
+ private void handleWritingFailureMessage() {
+ try {
+ saslChallenge.write(underlyingTransport);
+ if (saslChallenge.isComplete()) {
+ nextPhase = Phase.CLOSING;
+ }
+ } catch (IOException e) {
+ fail(e);
+ }
+ }
+
+ private void handleWritingResponse() {
+ try {
+ responseWriter.write(underlyingTransport);
+ if (responseWriter.isComplete()) {
+ responseWriter.clear();
+ nextPhase = Phase.READING_REQUEST;
+ }
+ } catch (IOException e) {
+ fail(e);
+ }
+ }
+
+ /**
+ * Release all the resources managed by this state machine (connection, selection and sasl server).
+ * To avoid being blocked, this should be invoked in the network thread that manages the selector.
+ */
+ public void close() {
+ underlyingTransport.close();
+ selectionKey.cancel();
+ if (saslPeer != null) {
+ saslPeer.dispose();
+ }
+ if (serverContextCreated) {
+ eventHandler.deleteContext(serverContext,
+ inputProtocolFactory.getProtocol(underlyingTransport),
+ outputProtocolFactory.getProtocol(underlyingTransport));
+ }
+ nextPhase = Phase.CLOSED;
+ currentPhase = Phase.CLOSED;
+ LOGGER.trace("Connection closed: {}", underlyingTransport);
+ }
+
+ public enum Phase {
+ INITIIALIIZING(INTEREST_READ) {
+ @Override
+ void unsafeRun(NonblockingSaslHandler statemachine) {
+ statemachine.handleInitializing();
+ }
+ },
+ READING_SASL_RESPONSE(INTEREST_READ) {
+ @Override
+ void unsafeRun(NonblockingSaslHandler statemachine) {
+ statemachine.handleReadingSaslResponse();
+ }
+ },
+ EVALUATING_SASL_RESPONSE(INTEREST_NONE) {
+ @Override
+ void unsafeRun(NonblockingSaslHandler statemachine) {
+ statemachine.executeEvaluatingSaslResponse();
+ }
+ },
+ WRITING_SASL_CHALLENGE(INTEREST_WRITE) {
+ @Override
+ void unsafeRun(NonblockingSaslHandler statemachine) {
+ statemachine.handleWritingSaslChallenge();
+ }
+ },
+ WRITING_SUCCESS_MESSAGE(INTEREST_WRITE) {
+ @Override
+ void unsafeRun(NonblockingSaslHandler statemachine) {
+ statemachine.handleWritingSuccessMessage();
+ }
+ },
+ WRITING_FAILURE_MESSAGE(INTEREST_WRITE) {
+ @Override
+ void unsafeRun(NonblockingSaslHandler statemachine) {
+ statemachine.handleWritingFailureMessage();
+ }
+ },
+ READING_REQUEST(INTEREST_READ) {
+ @Override
+ void unsafeRun(NonblockingSaslHandler statemachine) {
+ statemachine.handleReadingRequest();
+ }
+ },
+ PROCESSING(INTEREST_NONE) {
+ @Override
+ void unsafeRun(NonblockingSaslHandler statemachine) {
+ statemachine.executeProcessing();
+ }
+ },
+ WRITING_RESPONSE(INTEREST_WRITE) {
+ @Override
+ void unsafeRun(NonblockingSaslHandler statemachine) {
+ statemachine.handleWritingResponse();
+ }
+ },
+ CLOSING(INTEREST_NONE) {
+ @Override
+ void unsafeRun(NonblockingSaslHandler statemachine) {
+ statemachine.close();
+ }
+ },
+ CLOSED(INTEREST_NONE) {
+ @Override
+ void unsafeRun(NonblockingSaslHandler statemachine) {
+ // Do nothing.
+ }
+ }
+ ;
+
+ // The interest on the selection key during the phase
+ private int selectionInterest;
+
+ Phase(int selectionInterest) {
+ this.selectionInterest = selectionInterest;
+ }
+
+ /**
+ * Provide the execution to run for the state machine in current phase. The execution should
+ * return the next phase after running on the state machine.
+ *
+ * @param statemachine The state machine to run.
+ * @throws IllegalArgumentException if the state machine's current phase is different.
+ * @throws IllegalStateException if the state machine' current phase is already done.
+ */
+ void runStateMachine(NonblockingSaslHandler statemachine) {
+ if (statemachine.currentPhase != this) {
+ throw new IllegalArgumentException("State machine is " + statemachine.currentPhase +
+ " but is expected to be " + this);
+ }
+ if (statemachine.isCurrentPhaseDone()) {
+ throw new IllegalStateException("State machine should step into " + statemachine.nextPhase);
+ }
+ unsafeRun(statemachine);
+ }
+
+ // Run the state machine without checkiing its own phase
+ // It should not be called direcly by users.
+ abstract void unsafeRun(NonblockingSaslHandler statemachine);
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/SaslNegotiationFrameReader.java b/lib/java/src/org/apache/thrift/transport/sasl/SaslNegotiationFrameReader.java
new file mode 100644
index 0000000..01c1728
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/SaslNegotiationFrameReader.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+/**
+ * Read frames for sasl negotiatiions.
+ */
+public class SaslNegotiationFrameReader extends FrameReader<SaslNegotiationHeaderReader> {
+
+ public SaslNegotiationFrameReader() {
+ super(new SaslNegotiationHeaderReader());
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/SaslNegotiationFrameWriter.java b/lib/java/src/org/apache/thrift/transport/sasl/SaslNegotiationFrameWriter.java
new file mode 100644
index 0000000..1e9ad15
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/SaslNegotiationFrameWriter.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import java.nio.ByteBuffer;
+
+import org.apache.thrift.EncodingUtils;
+import org.apache.thrift.utils.StringUtils;
+
+import static org.apache.thrift.transport.sasl.SaslNegotiationHeaderReader.PAYLOAD_LENGTH_BYTES;
+import static org.apache.thrift.transport.sasl.SaslNegotiationHeaderReader.STATUS_BYTES;
+
+/**
+ * Writer for sasl negotiation frames. It expect a status byte as header with a payload to be
+ * written out (any header whose size is not equal to 1 would be considered as error).
+ */
+public class SaslNegotiationFrameWriter extends FrameWriter {
+
+ public static final int HEADER_BYTES = STATUS_BYTES + PAYLOAD_LENGTH_BYTES;
+
+ @Override
+ public void withOnlyPayload(byte[] payload, int offset, int length) {
+ throw new UnsupportedOperationException("Status byte is expected for sasl frame header.");
+ }
+
+ @Override
+ protected ByteBuffer buildFrame(byte[] header, int headerOffset, int headerLength,
+ byte[] payload, int payloadOffset, int payloadLength) {
+ if (header == null || headerLength != STATUS_BYTES) {
+ throw new IllegalArgumentException("Header " + StringUtils.bytesToHexString(header) +
+ " does not have expected length " + STATUS_BYTES);
+ }
+ byte[] bytes = new byte[HEADER_BYTES + payloadLength];
+ System.arraycopy(header, headerOffset, bytes, 0, STATUS_BYTES);
+ EncodingUtils.encodeBigEndian(payloadLength, bytes, STATUS_BYTES);
+ System.arraycopy(payload, payloadOffset, bytes, HEADER_BYTES, payloadLength);
+ return ByteBuffer.wrap(bytes);
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/SaslNegotiationHeaderReader.java b/lib/java/src/org/apache/thrift/transport/sasl/SaslNegotiationHeaderReader.java
new file mode 100644
index 0000000..2d76ddb
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/SaslNegotiationHeaderReader.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import static org.apache.thrift.transport.sasl.TSaslNegotiationException.ErrorType.PROTOCOL_ERROR;
+
+/**
+ * Header for sasl negotiation frames. It contains status byte of negotiation and a 4-byte integer
+ * (payload size).
+ */
+public class SaslNegotiationHeaderReader extends FixedSizeHeaderReader {
+ public static final int STATUS_BYTES = 1;
+ public static final int PAYLOAD_LENGTH_BYTES = 4;
+
+ private NegotiationStatus negotiationStatus;
+ private int payloadSize;
+
+ @Override
+ protected int headerSize() {
+ return STATUS_BYTES + PAYLOAD_LENGTH_BYTES;
+ }
+
+ @Override
+ protected void onComplete() throws TSaslNegotiationException {
+ negotiationStatus = NegotiationStatus.byValue(byteBuffer.get(0));
+ payloadSize = byteBuffer.getInt(1);
+ if (payloadSize < 0) {
+ throw new TSaslNegotiationException(PROTOCOL_ERROR, "Payload size is negative: " + payloadSize);
+ }
+ }
+
+ @Override
+ public int payloadSize() {
+ return payloadSize;
+ }
+
+ public NegotiationStatus getStatus() {
+ return negotiationStatus;
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/SaslPeer.java b/lib/java/src/org/apache/thrift/transport/sasl/SaslPeer.java
new file mode 100644
index 0000000..8f81380
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/SaslPeer.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import org.apache.thrift.transport.TTransportException;
+
+/**
+ * A peer in a sasl negotiation.
+ */
+public interface SaslPeer {
+
+ /**
+ * Evaluate and validate the negotiation message (response/challenge) received from peer.
+ *
+ * @param negotiationMessage response/challenge received from peer.
+ * @return new response/challenge to send to peer, can be null if authentication becomes success.
+ * @throws TSaslNegotiationException if sasl authentication fails.
+ */
+ byte[] evaluate(byte[] negotiationMessage) throws TSaslNegotiationException;
+
+ /**
+ * @return true if authentication is done.
+ */
+ boolean isAuthenticated();
+
+ /**
+ * This method can only be called when the negotiation is complete (isAuthenticated returns true).
+ * Otherwise it will throw IllegalStateExceptiion.
+ *
+ * @return if the qop requires some integrity/confidential protection.
+ * @throws IllegalStateException if negotiation is not yet complete.
+ */
+ boolean isDataProtected();
+
+ /**
+ * Wrap raw bytes to protect it.
+ *
+ * @param data raw bytes.
+ * @param offset the start position of the content to wrap.
+ * @param length the length of the content to wrap.
+ * @return bytes with protection to send to peer.
+ * @throws TTransportException if failure.
+ */
+ byte[] wrap(byte[] data, int offset, int length) throws TTransportException;
+
+ /**
+ * Wrap the whole byte array.
+ *
+ * @param data raw bytes.
+ * @return wrapped bytes.
+ * @throws TTransportException if failure.
+ */
+ default byte[] wrap(byte[] data) throws TTransportException {
+ return wrap(data, 0, data.length);
+ }
+
+ /**
+ * Unwrap protected data to raw bytes.
+ *
+ * @param data protected data received from peer.
+ * @param offset the start position of the content to unwrap.
+ * @param length the length of the content to unwrap.
+ * @return raw bytes.
+ * @throws TTransportException if failed.
+ */
+ byte[] unwrap(byte[] data, int offset, int length) throws TTransportException;
+
+ /**
+ * Unwrap the whole byte array.
+ *
+ * @param data wrapped bytes.
+ * @return raw bytes.
+ * @throws TTransportException if failure.
+ */
+ default byte[] unwrap(byte[] data) throws TTransportException {
+ return unwrap(data, 0, data.length);
+ }
+
+ /**
+ * Close this peer and release resources.
+ */
+ void dispose();
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/ServerSaslPeer.java b/lib/java/src/org/apache/thrift/transport/sasl/ServerSaslPeer.java
new file mode 100644
index 0000000..31992e5
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/ServerSaslPeer.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+import javax.security.sasl.SaslServer;
+
+import org.apache.thrift.transport.TTransportException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static org.apache.thrift.transport.sasl.TSaslNegotiationException.ErrorType.AUTHENTICATION_FAILURE;
+
+/**
+ * Server side sasl peer, a wrapper around SaslServer to provide some handy methods.
+ */
+public class ServerSaslPeer implements SaslPeer {
+ private static final Logger LOGGER = LoggerFactory.getLogger(ServerSaslPeer.class);
+
+ private static final String QOP_AUTH_INT = "auth-int";
+ private static final String QOP_AUTH_CONF = "auth-conf";
+
+ private final SaslServer saslServer;
+
+ public ServerSaslPeer(SaslServer saslServer) {
+ this.saslServer = saslServer;
+ }
+
+ @Override
+ public byte[] evaluate(byte[] negotiationMessage) throws TSaslNegotiationException {
+ try {
+ return saslServer.evaluateResponse(negotiationMessage);
+ } catch (SaslException e) {
+ throw new TSaslNegotiationException(AUTHENTICATION_FAILURE,
+ "Authentication failed with " + saslServer.getMechanismName(), e);
+ }
+ }
+
+ @Override
+ public boolean isAuthenticated() {
+ return saslServer.isComplete();
+ }
+
+ @Override
+ public boolean isDataProtected() {
+ Object qop = saslServer.getNegotiatedProperty(Sasl.QOP);
+ if (qop == null) {
+ return false;
+ }
+ for (String word : qop.toString().split("\\s*,\\s*")) {
+ String lowerCaseWord = word.toLowerCase();
+ if (QOP_AUTH_INT.equals(lowerCaseWord) || QOP_AUTH_CONF.equals(lowerCaseWord)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ @Override
+ public byte[] wrap(byte[] data, int offset, int length) throws TTransportException {
+ try {
+ return saslServer.wrap(data, offset, length);
+ } catch (SaslException e) {
+ throw new TTransportException("Failed to wrap data", e);
+ }
+ }
+
+ @Override
+ public byte[] unwrap(byte[] data, int offset, int length) throws TTransportException {
+ try {
+ return saslServer.unwrap(data, offset, length);
+ } catch (SaslException e) {
+ throw new TTransportException(TTransportException.CORRUPTED_DATA, "Failed to unwrap data", e);
+ }
+ }
+
+ @Override
+ public void dispose() {
+ try {
+ saslServer.dispose();
+ } catch (Exception e) {
+ LOGGER.warn("Failed to close sasl server " + saslServer.getMechanismName(), e);
+ }
+ }
+
+ SaslServer getSaslServer() {
+ return saslServer;
+ }
+
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/TBaseSaslProcessorFactory.java b/lib/java/src/org/apache/thrift/transport/sasl/TBaseSaslProcessorFactory.java
new file mode 100644
index 0000000..c08884c
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/TBaseSaslProcessorFactory.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import org.apache.thrift.TProcessor;
+
+public class TBaseSaslProcessorFactory implements TSaslProcessorFactory {
+
+ private final TProcessor processor;
+
+ public TBaseSaslProcessorFactory(TProcessor processor) {
+ this.processor = processor;
+ }
+
+ @Override
+ public TProcessor getProcessor(NonblockingSaslHandler saslHandler) {
+ return processor;
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/TInvalidSaslFrameException.java b/lib/java/src/org/apache/thrift/transport/sasl/TInvalidSaslFrameException.java
new file mode 100644
index 0000000..ff57ea5
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/TInvalidSaslFrameException.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+/**
+ * Got an invalid frame that does not respect the thrift sasl protocol.
+ */
+public class TInvalidSaslFrameException extends TSaslNegotiationException {
+
+ public TInvalidSaslFrameException(String message) {
+ super(ErrorType.PROTOCOL_ERROR, message);
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/TSaslNegotiationException.java b/lib/java/src/org/apache/thrift/transport/sasl/TSaslNegotiationException.java
new file mode 100644
index 0000000..9b1fa06
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/TSaslNegotiationException.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import org.apache.thrift.transport.TTransportException;
+
+/**
+ * Exception for sasl negotiation errors.
+ */
+public class TSaslNegotiationException extends TTransportException {
+
+ private final ErrorType error;
+
+ public TSaslNegotiationException(ErrorType error, String summary) {
+ super(summary);
+ this.error = error;
+ }
+
+ public TSaslNegotiationException(ErrorType error, String summary, Throwable cause) {
+ super(summary, cause);
+ this.error = error;
+ }
+
+ public ErrorType getErrorType() {
+ return error;
+ }
+
+ /**
+ * @return Errory type plus the message.
+ */
+ public String getSummary() {
+ return error.name() + ": " + getMessage();
+ }
+
+ /**
+ * @return Summary and eventually the cause's message.
+ */
+ public String getDetails() {
+ return getCause() == null ? getSummary() : getSummary() + "\nReason: " + getCause().getMessage();
+ }
+
+ public enum ErrorType {
+ // Unexpected system internal error during negotiation (e.g. sasl initialization failure)
+ INTERNAL_ERROR(NegotiationStatus.ERROR),
+ // Cannot read correct sasl frames from the connection => Send "ERROR" status byte to peer
+ PROTOCOL_ERROR(NegotiationStatus.ERROR),
+ // Peer is using unsupported sasl mechanisms => Send "BAD" status byte to peer
+ MECHANISME_MISMATCH(NegotiationStatus.BAD),
+ // Sasl authentication failure => Send "BAD" status byte to peer
+ AUTHENTICATION_FAILURE(NegotiationStatus.BAD),
+ ;
+
+ public final NegotiationStatus code;
+
+ ErrorType(NegotiationStatus code) {
+ this.code = code;
+ }
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/TSaslProcessorFactory.java b/lib/java/src/org/apache/thrift/transport/sasl/TSaslProcessorFactory.java
new file mode 100644
index 0000000..877d049
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/TSaslProcessorFactory.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import org.apache.thrift.TException;
+import org.apache.thrift.TProcessor;
+
+/**
+ * Get processor for a given state machine, so that users can customize the behavior of a TProcessor
+ * by interacting with the state machine.
+ */
+public interface TSaslProcessorFactory {
+
+ TProcessor getProcessor(NonblockingSaslHandler saslHandler) throws TException;
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/TSaslServerDefinition.java b/lib/java/src/org/apache/thrift/transport/sasl/TSaslServerDefinition.java
new file mode 100644
index 0000000..5486641
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/TSaslServerDefinition.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import javax.security.auth.callback.CallbackHandler;
+import java.util.Map;
+
+/**
+ * Contains all the parameters used to define a SASL server implementation.
+ */
+public class TSaslServerDefinition {
+ public final String mechanism;
+ public final String protocol;
+ public final String serverName;
+ public final Map<String, String> props;
+ public final CallbackHandler cbh;
+
+ public TSaslServerDefinition(String mechanism, String protocol, String serverName,
+ Map<String, String> props, CallbackHandler cbh) {
+ this.mechanism = mechanism;
+ this.protocol = protocol;
+ this.serverName = serverName;
+ this.props = props;
+ this.cbh = cbh;
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/TSaslServerFactory.java b/lib/java/src/org/apache/thrift/transport/sasl/TSaslServerFactory.java
new file mode 100644
index 0000000..06cf534
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/TSaslServerFactory.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+import javax.security.sasl.SaslServer;
+
+import static org.apache.thrift.transport.sasl.TSaslNegotiationException.ErrorType.MECHANISME_MISMATCH;
+import static org.apache.thrift.transport.sasl.TSaslNegotiationException.ErrorType.PROTOCOL_ERROR;
+
+/**
+ * Factory to create sasl server. Users can extend this class to customize the SaslServer creation.
+ */
+public class TSaslServerFactory {
+
+ private final Map<String, TSaslServerDefinition> saslMechanisms;
+
+ public TSaslServerFactory() {
+ this.saslMechanisms = new HashMap<>();
+ }
+
+ public void addSaslMechanism(String mechanism, String protocol, String serverName,
+ Map<String, String> props, CallbackHandler cbh) {
+ TSaslServerDefinition definition = new TSaslServerDefinition(mechanism, protocol, serverName,
+ props, cbh);
+ saslMechanisms.put(definition.mechanism, definition);
+ }
+
+ public ServerSaslPeer getSaslPeer(String mechanism) throws TSaslNegotiationException {
+ if (!saslMechanisms.containsKey(mechanism)) {
+ throw new TSaslNegotiationException(MECHANISME_MISMATCH, "Unsupported mechanism " + mechanism);
+ }
+ TSaslServerDefinition saslDef = saslMechanisms.get(mechanism);
+ try {
+ SaslServer saslServer = Sasl.createSaslServer(saslDef.mechanism, saslDef.protocol,
+ saslDef.serverName, saslDef.props, saslDef.cbh);
+ return new ServerSaslPeer(saslServer);
+ } catch (SaslException e) {
+ throw new TSaslNegotiationException(PROTOCOL_ERROR, "Fail to create sasl server " + mechanism, e);
+ }
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/utils/StringUtils.java b/lib/java/src/org/apache/thrift/utils/StringUtils.java
new file mode 100644
index 0000000..15183a3
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/utils/StringUtils.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.utils;
+
+public final class StringUtils {
+
+ private StringUtils() {
+ // Utility class.
+ }
+
+ private static final char[] HEX_CHARS = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'};
+
+ /**
+ * Stringify a byte array to the hex representation for each byte.
+ *
+ * @param bytes
+ * @return hex string.
+ */
+ public static String bytesToHexString(byte[] bytes) {
+ if (bytes == null) {
+ return null;
+ }
+ return bytesToHexString(bytes, 0, bytes.length);
+ }
+
+ /**
+ * Stringify a portion of the byte array.
+ *
+ * @param bytes byte array.
+ * @param offset portion start.
+ * @param length portion length.
+ * @return hex string.
+ */
+ public static String bytesToHexString(byte[] bytes, int offset, int length) {
+ if (length < 0) {
+ throw new IllegalArgumentException("Negative length " + length);
+ }
+ if (offset < 0) {
+ throw new IndexOutOfBoundsException("Negative start offset " + offset);
+ }
+ char[] chars = new char[length * 2];
+ for (int i = 0; i < length; i++) {
+ int unsignedInt = bytes[i + offset] & 0xFF;
+ chars[2 * i] = HEX_CHARS[unsignedInt >>> 4];
+ chars[2 * i + 1] = HEX_CHARS[unsignedInt & 0x0F];
+ }
+ return new String(chars);
+ }
+}
diff --git a/lib/java/test/org/apache/thrift/server/TestSaslNonblockingServer.java b/lib/java/test/org/apache/thrift/server/TestSaslNonblockingServer.java
new file mode 100644
index 0000000..d0a6746
--- /dev/null
+++ b/lib/java/test/org/apache/thrift/server/TestSaslNonblockingServer.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.server;
+
+import org.apache.thrift.TProcessor;
+import org.apache.thrift.protocol.TBinaryProtocol;
+import org.apache.thrift.protocol.TProtocolFactory;
+import org.apache.thrift.transport.TNonblockingServerSocket;
+import org.apache.thrift.transport.TNonblockingServerTransport;
+import org.apache.thrift.transport.TSaslClientTransport;
+import org.apache.thrift.transport.TSocket;
+import org.apache.thrift.transport.TTransportException;
+import org.apache.thrift.transport.TTransportFactory;
+import org.apache.thrift.transport.TestTSaslTransports;
+import org.apache.thrift.transport.TestTSaslTransports.TestSaslCallbackHandler;
+import org.apache.thrift.transport.sasl.TSaslNegotiationException;
+import thrift.test.ThriftTest;
+
+import static org.apache.thrift.transport.sasl.TSaslNegotiationException.ErrorType.AUTHENTICATION_FAILURE;
+
+public class TestSaslNonblockingServer extends TestTSaslTransports.TestTSaslTransportsWithServer {
+
+ private TSaslNonblockingServer server;
+
+ @Override
+ public void startServer(TProcessor processor, TProtocolFactory protoFactory, TTransportFactory factory)
+ throws Exception {
+ TNonblockingServerTransport serverSocket = new TNonblockingServerSocket(
+ new TNonblockingServerSocket.NonblockingAbstractServerSocketArgs().port(PORT));
+ TSaslNonblockingServer.Args args = new TSaslNonblockingServer.Args(serverSocket)
+ .processor(processor)
+ .transportFactory(factory)
+ .protocolFactory(protoFactory)
+ .addSaslMechanism(TestTSaslTransports.WRAPPED_MECHANISM, TestTSaslTransports.SERVICE,
+ TestTSaslTransports.HOST, TestTSaslTransports.WRAPPED_PROPS,
+ new TestSaslCallbackHandler(TestTSaslTransports.PASSWORD));
+ server = new TSaslNonblockingServer(args);
+ server.serve();
+ }
+
+ @Override
+ public void stopServer() throws Exception {
+ server.shutdown();
+ }
+
+ @Override
+ public void testIt() throws Exception {
+ super.testIt();
+ }
+
+ public void testBadPassword() throws Exception {
+ TProtocolFactory protocolFactory = new TBinaryProtocol.Factory();
+ TProcessor processor = new ThriftTest.Processor<>(new TestHandler());
+ startServer(processor, protocolFactory);
+
+ TSocket socket = new TSocket(HOST, PORT);
+ socket.setTimeout(SOCKET_TIMEOUT);
+ TSaslClientTransport client = new TSaslClientTransport(TestTSaslTransports.WRAPPED_MECHANISM,
+ TestTSaslTransports.PRINCIPAL, TestTSaslTransports.SERVICE, TestTSaslTransports.HOST,
+ TestTSaslTransports.WRAPPED_PROPS, new TestSaslCallbackHandler("bad_password"), socket);
+ try {
+ client.open();
+ fail("Client should fail with sasl negotiation.");
+ } catch (TTransportException error) {
+ TSaslNegotiationException serverSideError = new TSaslNegotiationException(AUTHENTICATION_FAILURE,
+ "Authentication failed with " + TestTSaslTransports.WRAPPED_MECHANISM);
+ assertTrue("Server should return error message \"" + serverSideError.getSummary() + "\"",
+ error.getMessage().contains(serverSideError.getSummary()));
+ } finally {
+ stopServer();
+ client.close();
+ }
+ }
+
+ @Override
+ public void testTransportFactory() {
+ // This test is irrelevant here, so skipped.
+ }
+}
diff --git a/lib/java/test/org/apache/thrift/transport/TestNonblockingServerSocket.java b/lib/java/test/org/apache/thrift/transport/TestNonblockingServerSocket.java
new file mode 100644
index 0000000..6b28dfd
--- /dev/null
+++ b/lib/java/test/org/apache/thrift/transport/TestNonblockingServerSocket.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.nio.channels.ServerSocketChannel;
+
+public class TestNonblockingServerSocket {
+
+ @Test
+ public void testSocketChannelBlockingMode() throws TTransportException {
+ try (TNonblockingServerSocket nonblockingServer = new TNonblockingServerSocket(0)){
+ ServerSocketChannel socketChannel = nonblockingServer.getServerSocketChannel();
+ Assert.assertFalse("Socket channel should be nonblocking", socketChannel.isBlocking());
+ }
+ }
+}
diff --git a/lib/java/test/org/apache/thrift/transport/TestTMemoryTransport.java b/lib/java/test/org/apache/thrift/transport/TestTMemoryTransport.java
new file mode 100644
index 0000000..2e20ffe
--- /dev/null
+++ b/lib/java/test/org/apache/thrift/transport/TestTMemoryTransport.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport;
+
+import org.apache.thrift.TByteArrayOutputStream;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.nio.ByteBuffer;
+import java.util.Random;
+
+public class TestTMemoryTransport {
+
+ @Test
+ public void testReadBatches() throws TTransportException {
+ byte[] inputBytes = {0x10, 0x7A, (byte) 0xBF, (byte) 0xFE, 0x53, (byte) 0x82, (byte) 0xFF};
+ TMemoryTransport transport = new TMemoryTransport(inputBytes);
+ byte[] read = new byte[inputBytes.length];
+ int firstBatch = new Random().nextInt(inputBytes.length);
+ int secondBatch = inputBytes.length - firstBatch;
+ transport.read(read, 0, firstBatch);
+ transport.read(read, firstBatch, secondBatch);
+ boolean equal = true;
+ for (int i = 0; i < inputBytes.length; i++) {
+ equal = equal && inputBytes[i] == read[i];
+ }
+ Assert.assertEquals(ByteBuffer.wrap(inputBytes), ByteBuffer.wrap(read));
+ }
+
+ @Test (expected = TTransportException.class)
+ public void testReadMoreThanRemaining() throws TTransportException {
+ TMemoryTransport transport = new TMemoryTransport(new byte[] {0x00, 0x32});
+ byte[] read = new byte[3];
+ transport.read(read, 0, 3);
+ }
+
+ @Test
+ public void testWrite() throws TTransportException {
+ TMemoryTransport transport = new TMemoryTransport(new byte[0]);
+ byte[] output1 = {0x72, 0x56, 0x29, (byte) 0xAF, (byte) 0x9B};
+ transport.write(output1);
+ byte[] output2 = {(byte) 0x83, 0x10, 0x00};
+ transport.write(output2, 0, 2);
+ byte[] expected = {0x72, 0x56, 0x29, (byte) 0xAF, (byte) 0x9B, (byte) 0x83, 0x10};
+ TByteArrayOutputStream outputByteArray = transport.getOutput();
+ Assert.assertEquals(ByteBuffer.wrap(expected), ByteBuffer.wrap(outputByteArray.get(), 0, outputByteArray.len()));
+ }
+}
diff --git a/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java b/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java
index 36a06e9..6eb38e7 100644
--- a/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java
+++ b/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java
@@ -53,17 +53,17 @@
private static final Logger LOGGER = LoggerFactory.getLogger(TestTSaslTransports.class);
- private static final String HOST = "localhost";
- private static final String SERVICE = "thrift-test";
- private static final String PRINCIPAL = "thrift-test-principal";
- private static final String PASSWORD = "super secret password";
- private static final String REALM = "thrift-test-realm";
+ public static final String HOST = "localhost";
+ public static final String SERVICE = "thrift-test";
+ public static final String PRINCIPAL = "thrift-test-principal";
+ public static final String PASSWORD = "super secret password";
+ public static final String REALM = "thrift-test-realm";
- private static final String UNWRAPPED_MECHANISM = "CRAM-MD5";
- private static final Map<String, String> UNWRAPPED_PROPS = null;
+ public static final String UNWRAPPED_MECHANISM = "CRAM-MD5";
+ public static final Map<String, String> UNWRAPPED_PROPS = null;
- private static final String WRAPPED_MECHANISM = "DIGEST-MD5";
- private static final Map<String, String> WRAPPED_PROPS = new HashMap<String, String>();
+ public static final String WRAPPED_MECHANISM = "DIGEST-MD5";
+ public static final Map<String, String> WRAPPED_PROPS = new HashMap<String, String>();
static {
WRAPPED_PROPS.put(Sasl.QOP, "auth-int");
@@ -80,7 +80,7 @@
+ "'We hold these truths to be self-evident, that all men are created equal.'";
- private static class TestSaslCallbackHandler implements CallbackHandler {
+ public static class TestSaslCallbackHandler implements CallbackHandler {
private final String password;
public TestSaslCallbackHandler(String password) {
@@ -265,7 +265,7 @@
new TestTSaslTransportsWithServer().testIt();
}
- private static class TestTSaslTransportsWithServer extends ServerTestBase {
+ public static class TestTSaslTransportsWithServer extends ServerTestBase {
private Thread serverThread;
private TServer server;
diff --git a/lib/java/test/org/apache/thrift/transport/sasl/TestDataFrameReader.java b/lib/java/test/org/apache/thrift/transport/sasl/TestDataFrameReader.java
new file mode 100644
index 0000000..9ae0e1e
--- /dev/null
+++ b/lib/java/test/org/apache/thrift/transport/sasl/TestDataFrameReader.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import org.apache.thrift.transport.TMemoryInputTransport;
+import org.apache.thrift.transport.TTransportException;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.nio.ByteBuffer;
+
+public class TestDataFrameReader {
+
+ @Test
+ public void testRead() throws TTransportException {
+ // Prepare data
+ int payloadSize = 23;
+ ByteBuffer buffer = ByteBuffer.allocate(DataFrameHeaderReader.PAYLOAD_LENGTH_BYTES + payloadSize);
+ buffer.putInt(payloadSize);
+ for (int i = 0; i < payloadSize; i++) {
+ buffer.put((byte) i);
+ }
+ buffer.rewind();
+
+ TMemoryInputTransport transport = new TMemoryInputTransport();
+ DataFrameReader dataFrameReader = new DataFrameReader();
+ // No bytes received.
+ dataFrameReader.read(transport);
+ Assert.assertFalse("No bytes received", dataFrameReader.isComplete());
+ Assert.assertFalse("No bytes received", dataFrameReader.getHeader().isComplete());
+ // Payload size (header) and part of the payload are received.
+ transport.reset(buffer.array(), 0, 6);
+ dataFrameReader.read(transport);
+ Assert.assertFalse("Only header is complete", dataFrameReader.isComplete());
+ Assert.assertTrue("Header should be complete", dataFrameReader.getHeader().isComplete());
+ Assert.assertEquals("Payload size should be " + payloadSize, payloadSize, dataFrameReader.getHeader().payloadSize());
+ // Read the rest of payload.
+ transport.reset(buffer.array(), 6, 21);
+ dataFrameReader.read(transport);
+ Assert.assertTrue("Reader should be complete", dataFrameReader.isComplete());
+ buffer.position(DataFrameHeaderReader.PAYLOAD_LENGTH_BYTES);
+ Assert.assertEquals("Payload should be the same as from the transport", buffer, ByteBuffer.wrap(dataFrameReader.getPayload()));
+ }
+}
diff --git a/lib/java/test/org/apache/thrift/transport/sasl/TestDataFrameWriter.java b/lib/java/test/org/apache/thrift/transport/sasl/TestDataFrameWriter.java
new file mode 100644
index 0000000..d242593
--- /dev/null
+++ b/lib/java/test/org/apache/thrift/transport/sasl/TestDataFrameWriter.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+import org.apache.thrift.EncodingUtils;
+import org.apache.thrift.transport.TNonblockingTransport;
+import org.junit.Assert;
+import org.junit.Test;
+import org.mockito.Mockito;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+import static org.apache.thrift.transport.sasl.DataFrameHeaderReader.PAYLOAD_LENGTH_BYTES;
+
+public class TestDataFrameWriter {
+
+ private static final byte[] BYTES = new byte[]{0x32, 0x2A, (byte) 0xE1, 0x18, (byte) 0x90, 0x75};
+
+ @Test
+ public void testProvideEntireByteArrayAsPayload() {
+ DataFrameWriter frameWriter = new DataFrameWriter();
+ frameWriter.withOnlyPayload(BYTES);
+ byte[] expectedBytes = new byte[BYTES.length + PAYLOAD_LENGTH_BYTES];
+ EncodingUtils.encodeBigEndian(BYTES.length, expectedBytes);
+ System.arraycopy(BYTES, 0, expectedBytes, PAYLOAD_LENGTH_BYTES, BYTES.length);
+ Assert.assertEquals(ByteBuffer.wrap(expectedBytes), frameWriter.frameBytes);
+ }
+
+ @Test
+ public void testProvideByteArrayPortionAsPayload() {
+ DataFrameWriter frameWriter = new DataFrameWriter();
+ int portionOffset = 2;
+ int portionLength = 3;
+ frameWriter.withOnlyPayload(BYTES, portionOffset, portionLength);
+ byte[] expectedBytes = new byte[portionLength + PAYLOAD_LENGTH_BYTES];
+ EncodingUtils.encodeBigEndian(portionLength, expectedBytes);
+ System.arraycopy(BYTES, portionOffset, expectedBytes, PAYLOAD_LENGTH_BYTES, portionLength);
+ Assert.assertEquals(ByteBuffer.wrap(expectedBytes), frameWriter.frameBytes);
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testProvideHeaderAndPayload() {
+ DataFrameWriter frameWriter = new DataFrameWriter();
+ frameWriter.withHeaderAndPayload(new byte[1], new byte[1]);
+ }
+
+ @Test(expected = IllegalStateException.class)
+ public void testProvidePayloadToIncompleteFrame() {
+ DataFrameWriter frameWriter = new DataFrameWriter();
+ frameWriter.withOnlyPayload(BYTES);
+ frameWriter.withOnlyPayload(new byte[1]);
+ }
+
+ @Test
+ public void testWrite() throws IOException {
+ DataFrameWriter frameWriter = new DataFrameWriter();
+ frameWriter.withOnlyPayload(BYTES);
+ // Slow socket which writes one byte per call.
+ TNonblockingTransport transport = Mockito.mock(TNonblockingTransport.class);
+ SlowWriting slowWriting = new SlowWriting();
+ Mockito.when(transport.write(frameWriter.frameBytes)).thenAnswer(slowWriting);
+ frameWriter.write(transport);
+ while (slowWriting.written < frameWriter.frameBytes.limit()) {
+ Assert.assertFalse("Frame writer should not be complete", frameWriter.isComplete());
+ frameWriter.write(transport);
+ }
+ Assert.assertTrue("Frame writer should be complete", frameWriter.isComplete());
+ }
+
+ private static class SlowWriting implements Answer<Integer> {
+ int written = 0;
+
+ @Override
+ public Integer answer(InvocationOnMock invocation) throws Throwable {
+ ByteBuffer bytes = (ByteBuffer) invocation.getArguments()[0];
+ bytes.get();
+ written ++;
+ return 1;
+ }
+ }
+}
diff --git a/lib/java/test/org/apache/thrift/transport/sasl/TestSaslNegotiationFrameReader.java b/lib/java/test/org/apache/thrift/transport/sasl/TestSaslNegotiationFrameReader.java
new file mode 100644
index 0000000..f2abbe6
--- /dev/null
+++ b/lib/java/test/org/apache/thrift/transport/sasl/TestSaslNegotiationFrameReader.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import org.apache.thrift.transport.TMemoryInputTransport;
+import org.apache.thrift.transport.TTransportException;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.nio.ByteBuffer;
+
+public class TestSaslNegotiationFrameReader {
+
+ @Test
+ public void testRead() throws TTransportException {
+ TMemoryInputTransport transport = new TMemoryInputTransport();
+ SaslNegotiationFrameReader negotiationReader = new SaslNegotiationFrameReader();
+ // No bytes received
+ negotiationReader.read(transport);
+ Assert.assertFalse("No bytes received", negotiationReader.isComplete());
+ Assert.assertFalse("No bytes received", negotiationReader.getHeader().isComplete());
+ // Read header
+ ByteBuffer buffer = ByteBuffer.allocate(5);
+ buffer.put(0, NegotiationStatus.OK.getValue());
+ buffer.putInt(1, 10);
+ transport.reset(buffer.array());
+ negotiationReader.read(transport);
+ Assert.assertFalse("Only header is complete", negotiationReader.isComplete());
+ Assert.assertTrue("Header should be complete", negotiationReader.getHeader().isComplete());
+ Assert.assertEquals("Payload size should be 10", 10, negotiationReader.getHeader().payloadSize());
+ // Read payload
+ transport.reset(new byte[20]);
+ negotiationReader.read(transport);
+ Assert.assertTrue("Reader should be complete", negotiationReader.isComplete());
+ Assert.assertEquals("Payload length should be 10", 10, negotiationReader.getPayload().length);
+ }
+
+ @Test (expected = TSaslNegotiationException.class)
+ public void testReadInvalidNegotiationStatus() throws TTransportException {
+ byte[] bytes = new byte[5];
+ // Invalid status byte.
+ bytes[0] = -1;
+ TMemoryInputTransport transport = new TMemoryInputTransport(bytes);
+ SaslNegotiationFrameReader negotiationReader = new SaslNegotiationFrameReader();
+ negotiationReader.read(transport);
+ }
+}
diff --git a/lib/java/test/org/apache/thrift/transport/sasl/TestSaslNegotiationFrameWriter.java b/lib/java/test/org/apache/thrift/transport/sasl/TestSaslNegotiationFrameWriter.java
new file mode 100644
index 0000000..ce7ff29
--- /dev/null
+++ b/lib/java/test/org/apache/thrift/transport/sasl/TestSaslNegotiationFrameWriter.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import java.nio.ByteBuffer;
+
+import org.apache.thrift.EncodingUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+import static org.apache.thrift.transport.sasl.SaslNegotiationFrameWriter.HEADER_BYTES;
+
+public class TestSaslNegotiationFrameWriter {
+
+ private static final byte[] PAYLOAD = {0x11, 0x08, 0x3F, 0x58, 0x73, 0x22, 0x00, (byte) 0xFF};
+
+ @Test
+ public void testWithHeaderAndPayload() {
+ SaslNegotiationFrameWriter frameWriter = new SaslNegotiationFrameWriter();
+ frameWriter.withHeaderAndPayload(new byte[] {NegotiationStatus.OK.getValue()}, PAYLOAD);
+ byte[] expectedBytes = new byte[HEADER_BYTES + PAYLOAD.length];
+ expectedBytes[0] = NegotiationStatus.OK.getValue();
+ EncodingUtils.encodeBigEndian(PAYLOAD.length, expectedBytes, 1);
+ System.arraycopy(PAYLOAD, 0, expectedBytes, HEADER_BYTES, PAYLOAD.length);
+ Assert.assertEquals(ByteBuffer.wrap(expectedBytes), frameWriter.frameBytes);
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testWithInvalidHeaderLength() {
+ SaslNegotiationFrameWriter frameWriter = new SaslNegotiationFrameWriter();
+ frameWriter.withHeaderAndPayload(new byte[5], 0, 2, PAYLOAD, 0, 1);
+ }
+
+ @Test(expected = UnsupportedOperationException.class)
+ public void testWithOnlyPayload() {
+ SaslNegotiationFrameWriter frameWriter = new SaslNegotiationFrameWriter();
+ frameWriter.withOnlyPayload(new byte[0]);
+ }
+}
diff --git a/lib/java/test/org/apache/thrift/utils/TestStringUtils.java b/lib/java/test/org/apache/thrift/utils/TestStringUtils.java
new file mode 100644
index 0000000..3a8cf39
--- /dev/null
+++ b/lib/java/test/org/apache/thrift/utils/TestStringUtils.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.utils;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TestStringUtils {
+
+ @Test
+ public void testToHexString() {
+ byte[] bytes = {0x00, 0x1A, (byte) 0xEF, (byte) 0xAB, (byte) 0x92};
+ Assert.assertEquals("001AEFAB92", StringUtils.bytesToHexString(bytes));
+ Assert.assertEquals("EFAB92", StringUtils.bytesToHexString(bytes, 2, 3));
+ Assert.assertNull(StringUtils.bytesToHexString(null));
+ }
+}