THRIFT-876. java: Add SASL support
This patch adds support for a SASL-secured transport to the Java library. In its current form, it only works for the blocking-IO servers.
Patch: Aaron T Meyers
git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@993563 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/doc/thrift-sasl-spec.txt b/doc/thrift-sasl-spec.txt
new file mode 100644
index 0000000..59bfcf9
--- /dev/null
+++ b/doc/thrift-sasl-spec.txt
@@ -0,0 +1,107 @@
+A Thrift SASL message shall be a byte array of one of the following forms:
+
+| 1-byte START status code | 1-byte mechanism name length | variable length mechanism name | 4-byte payload length | variable-length payload |
+| 1-byte status code | 4-byte payload length | variable-length payload |
+
+The length fields shall be interpreted as integers, with the high byte sent
+first. This indicates the length of the field immediately following it, not
+including the status code or the length bytes.
+
+The possible status codes are:
+
+0x01 - START - Hello, let's go on a date.
+0x02 - OK - Everything's been going alright so far, let's see each other again.
+0x03 - BAD - I understand what you're saying. I really do. I just don't like it. We have to break up.
+0x04 - ERROR - We can't go on like this. It's like you're speaking another language.
+0x05 - COMPLETE - Will you marry me?
+
+The Thrift SASL communication will proceed as follows:
+
+1. The client is configured at instantiation of the transport with a single
+underlying SASL security mechanism that it supports.
+
+2. The server is configured with a mapping of underlying security mechanism
+name -> mechanism options.
+
+3. At connection time, the client will initiate communication by sending the
+server a START byte, followed by a 1-byte field indicating the length in bytes
+of the underlying security mechanism name that the client would like to use.
+This mechanism name shall be 1-20 characters in length, and follow the
+specifications for SASL mechanism names specified in RFC 2222. This mechanism
+name shall be followed by a 4-byte, potentially zero-value message length word,
+followed by a potentially zero-length payload. The payload is determined by the
+output byte array of the underlying actual security mechanism, and will be
+empty except for those underlying security protocols which implement the
+optional SASL initial response.
+
+4. The server receives this message and, if the mechanism name provided is
+among the set of mechanisms this server transport is configured to accept,
+appropriate initialization of the underlying security mechanism may take place.
+If the mechanism name is not one which the server is configured to support, the
+server shall return the BAD byte, followed by a 4-byte, potentially zero-value
+message length, followed by the potentially zero-length payload which may be a
+status code or message indicating failure. No further communication may take
+place via this transport. If the mechanism name is one which the server
+supports, then proceed to step 5.
+
+5. The server then provides the byte array of the payload received to its
+underlying security mechanism. A challenge is generated by the underlying
+security mechanism on the server, and this is used as the payload for a message
+sent to the client. This message shall consist of an OK byte, followed by the
+non-zero message length word, followed by the payload.
+
+6. The client receives this message from the server and passes the payload to
+its underlying security mechanism to generate a response. The client then sends
+the server an OK byte, followed by the non-zero-value length of the response,
+followed by the bytes of the response as the payload.
+
+7. Steps 5 and 6 are repeated until both security mechanisms are satisfied with
+the challenge/response exchange. When either side has completed its security
+protocol, its next message shall be the COMPLETE byte, followed by a 4-byte
+potentially zero-value length word, followed by a potentially zero-length
+payload. This payload will be empty except for those underlying security
+mechanisms which provide additional data with success.
+
+If at any point in time either side is able to interpret the challenge or
+response sent by the other, but is dissatisfied with the contents thereof, this
+side should send the other a BAD byte, followed by a 4-byte potentially
+zero-value length word, followed by an optional, potentially zero-length
+message encoded in UTF-8 indicating failure. This message should be passed to
+the protocol above the thrift transport by whatever mechanism is appropriate
+and idiomatic for the particular language these thrift bindings are for.
+
+If at any point in time either side fails to interpret the challenge or
+response sent by the other, this side should send the other an ERROR byte,
+followed by a 4-byte potentially zero-value length word, followed by an
+optional, potentially zero-length message encoded in UTF-8. This message should
+be passed to the protocol above the thrift transport by whatever mechanism is
+appropriate and idiomatic for the particular language these thrift bindings are
+for.
+
+If step 7 completes successfully, then the communication is considered
+authenticated and subsequent communication may commence.
+
+If step 7 fails to complete successfully, then no further communication may
+take place via this transport.
+
+8. All writes to the underlying transport must be prefixed by the 4-byte length
+of the payload data, followed by the payload. All reads from this transport
+should read the 4-byte length word, then read the full quantity of bytes
+specified by this length word.
+
+If no SASL QOP (quality of protection) is negotiated during steps 5 and 6, then
+all subsequent writes to/reads from this transport are written/read unaltered,
+save for the length prefix, to the underlying transport.
+
+If a SASL QOP is negotiated, then this must be used by the Thrift transport for
+all subsequent communication. This is done by wrapping subsequent writes to the
+transport using the underlying security mechanism, and unwrapping subsequent
+reads from the underlying transport. Note that in this case, the length prefix
+of the write to the underlying transport is the length of the data after it has
+been wrapped by the underlying security mechanism. Note that the complete
+message must be read before giving this data to the underlying security
+mechanism for unwrapping.
+
+If at any point in time reading of a message fails either because of a
+malformed length word or failure to unwrap by the underlying security
+mechanism, then all further communication on this transport must cease.
diff --git a/lib/java/src/org/apache/thrift/EncodingUtils.java b/lib/java/src/org/apache/thrift/EncodingUtils.java
new file mode 100644
index 0000000..072de93
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/EncodingUtils.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift;
+
+/**
+ * Utility methods for use when encoding/decoding raw data as byte arrays.
+ */
+public class EncodingUtils {
+
+ /**
+ * Encode <code>integer</code> as a series of 4 bytes into <code>buf</code>
+ * starting at position 0 within that buffer.
+ *
+ * @param integer
+ * The integer to encode.
+ * @param buf
+ * The buffer to write to.
+ */
+ public static final void encodeBigEndian(final int integer, final byte[] buf) {
+ encodeBigEndian(integer, buf, 0);
+ }
+
+ /**
+ * Encode <code>integer</code> as a series of 4 bytes into <code>buf</code>
+ * starting at position <code>offset</code>.
+ *
+ * @param integer
+ * The integer to encode.
+ * @param buf
+ * The buffer to write to.
+ * @param offset
+ * The offset within <code>buf</code> to start the encoding.
+ */
+ public static final void encodeBigEndian(final int integer, final byte[] buf, int offset) {
+ buf[offset] = (byte) (0xff & (integer >> 24));
+ buf[offset + 1] = (byte) (0xff & (integer >> 16));
+ buf[offset + 2] = (byte) (0xff & (integer >> 8));
+ buf[offset + 3] = (byte) (0xff & (integer));
+ }
+
+ /**
+ * Decode a series of 4 bytes from <code>buf</code>, starting at position 0,
+ * and interpret them as an integer.
+ *
+ * @param buf
+ * The buffer to read from.
+ * @return An integer, as read from the buffer.
+ */
+ public static final int decodeBigEndian(final byte[] buf) {
+ return decodeBigEndian(buf, 0);
+ }
+
+ /**
+ * Decode a series of 4 bytes from <code>buf</code>, start at
+ * <code>offset</code>, and interpret them as an integer.
+ *
+ * @param buf
+ * The buffer to read from.
+ * @param offset
+ * The offset with <code>buf</code> to start the decoding.
+ * @return An integer, as read from the buffer.
+ */
+ public static final int decodeBigEndian(final byte[] buf, int offset) {
+ return ((buf[offset] & 0xff) << 24) | ((buf[offset + 1] & 0xff) << 16)
+ | ((buf[offset + 2] & 0xff) << 8) | ((buf[offset + 3] & 0xff));
+ }
+
+}
diff --git a/lib/java/src/org/apache/thrift/transport/TSaslClientTransport.java b/lib/java/src/org/apache/thrift/transport/TSaslClientTransport.java
new file mode 100644
index 0000000..fc8a3ea
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/TSaslClientTransport.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport;
+
+import java.util.Map;
+
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslClient;
+import javax.security.sasl.SaslException;
+
+import org.apache.thrift.EncodingUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Wraps another Thrift <code>TTransport</code>, but performs SASL client
+ * negotiation on the call to <code>open()</code>. This class will wrap ensuing
+ * communication over it, if a SASL QOP is negotiated with the other party.
+ */
+public class TSaslClientTransport extends TSaslTransport {
+
+ private static final Logger LOGGER = LoggerFactory.getLogger(TSaslClientTransport.class);
+
+ /**
+ * The name of the mechanism this client supports.
+ */
+ private final String mechanism;
+
+ /**
+ * Uses the given <code>SaslClient</code>.
+ *
+ * @param saslClient
+ * The <code>SaslClient</code> to use for the subsequent SASL
+ * negotiation.
+ * @param transport
+ * Transport underlying this one.
+ */
+ public TSaslClientTransport(SaslClient saslClient, TTransport transport) {
+ super(saslClient, transport);
+ mechanism = saslClient.getMechanismName();
+ }
+
+ /**
+ * Creates a <code>SaslClient</code> using the given SASL-specific parameters.
+ * See the Java documentation for <code>Sasl.createSaslClient</code> for the
+ * details of the parameters.
+ *
+ * @param transport
+ * The underlying Thrift transport.
+ * @throws SaslException
+ */
+ public TSaslClientTransport(String mechanism, String authorizationId, String protocol,
+ String serverName, Map<String, String> props, CallbackHandler cbh, TTransport transport)
+ throws SaslException {
+ super(Sasl.createSaslClient(new String[] { mechanism }, authorizationId, protocol, serverName,
+ props, cbh), transport);
+ this.mechanism = mechanism;
+ }
+
+ /**
+ * Performs the client side of the initial portion of the Thrift SASL
+ * protocol. Generates and sends the initial response to the server, including
+ * which mechanism this client wants to use.
+ */
+ @Override
+ protected void handleSaslStartMessage() throws TTransportException, SaslException {
+ SaslClient saslClient = getSaslClient();
+
+ byte[] initialResponse = new byte[0];
+ if (saslClient.hasInitialResponse())
+ initialResponse = saslClient.evaluateChallenge(initialResponse);
+
+ byte[] mechanismBytes = mechanism.getBytes();
+ byte[] messageHeader = new byte[STATUS_BYTES + MECHANISM_NAME_BYTES + mechanismBytes.length
+ + PAYLOAD_LENGTH_BYTES];
+
+ messageHeader[0] = START;
+ messageHeader[1] = (byte) (0xff & mechanismBytes.length);
+ System.arraycopy(mechanismBytes, 0, messageHeader, STATUS_BYTES + MECHANISM_NAME_BYTES,
+ mechanismBytes.length);
+ EncodingUtils.encodeBigEndian(initialResponse.length, messageHeader, STATUS_BYTES
+ + MECHANISM_NAME_BYTES + mechanismBytes.length);
+
+ LOGGER.debug("Sending mechanism name {} and initial response of length {}", mechanism,
+ initialResponse.length);
+ underlyingTransport.write(messageHeader);
+ underlyingTransport.write(initialResponse);
+ underlyingTransport.flush();
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/TSaslServerTransport.java b/lib/java/src/org/apache/thrift/transport/TSaslServerTransport.java
new file mode 100644
index 0000000..b07e597
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/TSaslServerTransport.java
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.WeakHashMap;
+
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+import javax.security.sasl.SaslServer;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Wraps another Thrift <code>TTransport</code>, but performs SASL server
+ * negotiation on the call to <code>open()</code>. This class will wrap ensuing
+ * communication over it, if a SASL QOP is negotiated with the other party.
+ */
+public class TSaslServerTransport extends TSaslTransport {
+
+ private static final Logger LOGGER = LoggerFactory.getLogger(TSaslServerTransport.class);
+
+ /**
+ * Mapping from SASL mechanism name -> all the parameters required to
+ * instantiate a SASL server.
+ */
+ private Map<String, TSaslServerDefinition> serverDefinitionMap = new HashMap<String, TSaslServerDefinition>();
+
+ /**
+ * Contains all the parameters used to define a SASL server implementation.
+ */
+ private static class TSaslServerDefinition {
+ public String mechanism;
+ public String protocol;
+ public String serverName;
+ public Map<String, String> props;
+ public CallbackHandler cbh;
+
+ public TSaslServerDefinition(String mechanism, String protocol, String serverName,
+ Map<String, String> props, CallbackHandler cbh) {
+ this.mechanism = mechanism;
+ this.protocol = protocol;
+ this.serverName = serverName;
+ this.props = props;
+ this.cbh = cbh;
+ }
+ }
+
+ /**
+ * Uses the given underlying transport. Assumes that addServerDefinition is
+ * called later.
+ *
+ * @param transport
+ * Transport underlying this one.
+ */
+ public TSaslServerTransport(TTransport transport) {
+ super(transport);
+ }
+
+ /**
+ * Creates a <code>SaslServer</code> using the given SASL-specific parameters.
+ * See the Java documentation for <code>Sasl.createSaslServer</code> for the
+ * details of the parameters.
+ *
+ * @param transport
+ * The underlying Thrift transport.
+ */
+ public TSaslServerTransport(String mechanism, String protocol, String serverName,
+ Map<String, String> props, CallbackHandler cbh, TTransport transport) {
+ super(transport);
+ addServerDefinition(mechanism, protocol, serverName, props, cbh);
+ }
+
+ private TSaslServerTransport(Map<String, TSaslServerDefinition> serverDefinitionMap, TTransport transport) {
+ super(transport);
+ this.serverDefinitionMap.putAll(serverDefinitionMap);
+ }
+
+ /**
+ * Add a supported server definition to this transport. See the Java
+ * documentation for <code>Sasl.createSaslServer</code> for the details of the
+ * parameters.
+ */
+ public void addServerDefinition(String mechanism, String protocol, String serverName,
+ Map<String, String> props, CallbackHandler cbh) {
+ serverDefinitionMap.put(mechanism, new TSaslServerDefinition(mechanism, protocol, serverName,
+ props, cbh));
+ }
+
+ /**
+ * Performs the server side of the initial portion of the Thrift SASL protocol.
+ * Receives the initial response from the client, creates a SASL server using
+ * the mechanism requested by the client (if this server supports it), and
+ * sends the first challenge back to the client.
+ */
+ @Override
+ protected void handleSaslStartMessage() throws TTransportException, SaslException {
+ // Get the status byte and length of the mechanism name.
+ byte[] messageHeader = new byte[STATUS_BYTES + MECHANISM_NAME_BYTES];
+ underlyingTransport.readAll(messageHeader, 0, messageHeader.length);
+ LOGGER.debug("Received status {} and mechanism name length {}", messageHeader[0],
+ messageHeader[1]);
+ if (messageHeader[0] != START) {
+ sendAndThrowMessage(ERROR, "Expecting START status, received " + messageHeader[0]);
+ }
+
+ // Get the mechanism name.
+ byte[] mechanismBytes = new byte[messageHeader[1]];
+ underlyingTransport.readAll(mechanismBytes, 0, mechanismBytes.length);
+
+ String mechanismName = new String(mechanismBytes);
+ TSaslServerDefinition serverDefinition = serverDefinitionMap.get(new String(mechanismBytes));
+ LOGGER.debug("Received mechanism name '{}'", mechanismName);
+
+ if (serverDefinition == null) {
+ sendAndThrowMessage(BAD, "Unsupported mechanism type " + mechanismName);
+ }
+ SaslServer saslServer = Sasl.createSaslServer(serverDefinition.mechanism,
+ serverDefinition.protocol, serverDefinition.serverName, serverDefinition.props,
+ serverDefinition.cbh);
+
+ // Evaluate the initial response and send the first challenge.
+ byte[] initialResponse = new byte[readLength()];
+ sendSaslMessage(saslServer.isComplete() ? COMPLETE : OK, saslServer
+ .evaluateResponse(initialResponse));
+
+ setSaslServer(saslServer);
+ }
+
+ /**
+ * <code>TTransportFactory</code> to create
+ * <code>TSaslServerTransports<c/ode>. Ensures that a given
+ * underlying <code>TTransport</code> instance receives the same
+ * <code>TSaslServerTransport</code>. This is kind of an awful hack to work
+ * around the fact that Thrift is designed assuming that
+ * <code>TTransport</code> instances are stateless, and thus the existing
+ * <code>TServers</code> use different <code>TTransport</code> instances for
+ * input and output.
+ */
+ public static class Factory extends TTransportFactory {
+
+ /**
+ * This is the implementation of the awful hack described above.
+ * <code>WeakHashMap</code> is used to ensure that we don't leak memory.
+ */
+ private static Map<TTransport, TSaslServerTransport> transportMap =
+ Collections.synchronizedMap(new WeakHashMap<TTransport, TSaslServerTransport>());
+
+ /**
+ * Mapping from SASL mechanism name -> all the parameters required to
+ * instantiate a SASL server.
+ */
+ private Map<String, TSaslServerDefinition> serverDefinitionMap = new HashMap<String, TSaslServerDefinition>();
+
+ /**
+ * Create a new Factory. Assumes that <code>addServerDefinition</code> will
+ * be called later.
+ */
+ public Factory() {
+ super();
+ }
+
+ /**
+ * Create a new <code>Factory</code>, initially with the single server
+ * definition given. You may still call <code>addServerDefinition</code>
+ * later. See the Java documentation for <code>Sasl.createSaslServer</code>
+ * for the details of the parameters.
+ */
+ public Factory(String mechanism, String protocol, String serverName,
+ Map<String, String> props, CallbackHandler cbh) {
+ super();
+ addServerDefinition(mechanism, protocol, serverName, props, cbh);
+ }
+
+ /**
+ * Add a supported server definition to the transports created by this
+ * factory. See the Java documentation for
+ * <code>Sasl.createSaslServer</code> for the details of the parameters.
+ */
+ public void addServerDefinition(String mechanism, String protocol, String serverName,
+ Map<String, String> props, CallbackHandler cbh) {
+ serverDefinitionMap.put(mechanism, new TSaslServerDefinition(mechanism, protocol, serverName,
+ props, cbh));
+ }
+
+ /**
+ * Get a new <code>TSaslServerTransport</code> instance, or reuse the
+ * existing one if a <code>TSaslServerTransport</code> has already been
+ * created before using the given <code>TTransport</code> as an underlying
+ * transport. This ensures that a given underlying transport instance
+ * receives the same <code>TSaslServerTransport</code>.
+ */
+ @Override
+ public TTransport getTransport(TTransport base) {
+ TSaslServerTransport ret = transportMap.get(base);
+ if (ret == null) {
+ LOGGER.debug("transport map does not contain key", base);
+ ret = new TSaslServerTransport(serverDefinitionMap, base);
+ try {
+ ret.open();
+ } catch (TTransportException e) {
+ LOGGER.debug("failed to open server transport", e);
+ return null;
+ }
+ transportMap.put(base, ret);
+ } else {
+ LOGGER.debug("transport map does contain key {}", base);
+ }
+ return ret;
+ }
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/TSaslTransport.java b/lib/java/src/org/apache/thrift/transport/TSaslTransport.java
new file mode 100644
index 0000000..b5eadb7
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/TSaslTransport.java
@@ -0,0 +1,470 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport;
+
+import java.io.UnsupportedEncodingException;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Set;
+
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslClient;
+import javax.security.sasl.SaslException;
+import javax.security.sasl.SaslServer;
+
+import org.apache.thrift.EncodingUtils;
+import org.apache.thrift.TByteArrayOutputStream;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A superclass for SASL client/server thrift transports. A subclass need only
+ * implement the <code>open</open> method.
+ */
+abstract class TSaslTransport extends TTransport {
+
+ private static final Logger LOGGER = LoggerFactory.getLogger(TSaslTransport.class);
+
+ protected static final int DEFAULT_MAX_LENGTH = 0x7FFFFFFF;
+
+ protected static final int MECHANISM_NAME_BYTES = 1;
+ protected static final int STATUS_BYTES = 1;
+ protected static final int PAYLOAD_LENGTH_BYTES = 4;
+
+ /**
+ * Status bytes used during the initial Thrift SASL handshake.
+ */
+ protected static final byte START = 0x01;
+ protected static final byte OK = 0x02;
+ protected static final byte BAD = 0x03;
+ protected static final byte ERROR = 0x04;
+ protected static final byte COMPLETE = 0x05;
+
+ protected static final Set<Byte> VALID_STATUSES = new HashSet<Byte>(Arrays.asList(START, OK, BAD, ERROR, COMPLETE));
+
+ /**
+ * Transport underlying this one.
+ */
+ protected TTransport underlyingTransport;
+
+ /**
+ * Either a SASL client or a SASL server.
+ */
+ private SaslParticipant sasl;
+
+ /**
+ * Whether or not we should wrap/unwrap reads/writes. Determined by whether or
+ * not a QOP is negotiated during the SASL handshake.
+ */
+ private boolean shouldWrap = false;
+
+ /**
+ * Buffer for input.
+ */
+ private TMemoryInputTransport readBuffer = new TMemoryInputTransport();
+
+ /**
+ * Buffer for output.
+ */
+ private final TByteArrayOutputStream writeBuffer = new TByteArrayOutputStream(1024);
+
+ /**
+ * Create a TSaslTransport. It's assumed that setSaslServer will be called
+ * later to initialize the SASL endpoint underlying this transport.
+ *
+ * @param underlyingTransport
+ * The thrift transport which this transport is wrapping.
+ */
+ protected TSaslTransport(TTransport underlyingTransport) {
+ this.underlyingTransport = underlyingTransport;
+ }
+
+ /**
+ * Create a TSaslTransport which acts as a client.
+ *
+ * @param saslClient
+ * The <code>SaslClient</code> which this transport will use for SASL
+ * negotiation.
+ * @param underlyingTransport
+ * The thrift transport which this transport is wrapping.
+ */
+ protected TSaslTransport(SaslClient saslClient, TTransport underlyingTransport) {
+ sasl = new SaslParticipant(saslClient);
+ this.underlyingTransport = underlyingTransport;
+ }
+
+ protected void setSaslServer(SaslServer saslServer) {
+ sasl = new SaslParticipant(saslServer);
+ }
+
+ // Used to read the status byte and payload length.
+ private final byte[] messageHeader = new byte[STATUS_BYTES + PAYLOAD_LENGTH_BYTES];
+
+ /**
+ * Send a complete Thrift SASL message.
+ *
+ * @param status
+ * The status to send.
+ * @param payload
+ * The data to send as the payload of this message.
+ * @throws TTransportException
+ */
+ protected void sendSaslMessage(byte status, byte[] payload) throws TTransportException {
+ if (payload == null)
+ payload = new byte[0];
+
+ messageHeader[0] = status;
+ EncodingUtils.encodeBigEndian(payload.length, messageHeader, STATUS_BYTES);
+
+ LOGGER.debug("Writing message with status {} and payload length {}", status, payload.length);
+ underlyingTransport.write(messageHeader);
+ underlyingTransport.write(payload);
+ underlyingTransport.flush();
+ }
+
+ /**
+ * Read a complete Thrift SASL message.
+ *
+ * @return The SASL status and payload from this message.
+ * @throws TTransportException
+ * Thrown if there is a failure reading from the underlying
+ * transport, or if a status code of BAD or ERROR is encountered.
+ */
+ protected SaslResponse receiveSaslMessage() throws TTransportException {
+ underlyingTransport.readAll(messageHeader, 0, messageHeader.length);
+
+ byte status = messageHeader[0];
+ byte[] payload = new byte[EncodingUtils.decodeBigEndian(messageHeader, STATUS_BYTES)];
+ underlyingTransport.readAll(payload, 0, payload.length);
+
+ if (!VALID_STATUSES.contains(status))
+ sendAndThrowMessage(ERROR, "Invalid status " + status);
+ else if (status == BAD || status == ERROR) {
+ try {
+ throw new TTransportException(new String(payload, "UTF-8"));
+ } catch (UnsupportedEncodingException e) {
+ throw new TTransportException(e);
+ }
+ }
+
+ LOGGER.debug("Received message with status {} and payload length {}", status, payload.length);
+ return new SaslResponse(status, payload);
+ }
+
+ /**
+ * Send a Thrift SASL message with the given status (usaully BAD or ERROR) and
+ * string message, and then throw a TTransportException with the given
+ * message.
+ *
+ * @param status
+ * The Thrift SASL status code to send. Usually BAD or ERROR.
+ * @param message
+ * The optional message to send to the other side.
+ * @throws TTransportException
+ * Always thrown with the message provided.
+ */
+ protected void sendAndThrowMessage(byte status, String message) throws TTransportException {
+ sendSaslMessage(status, message.getBytes());
+ throw new TTransportException(message);
+ }
+
+ /**
+ * Implemented by subclasses to start the Thrift SASL handshake process. When
+ * this method completes, the <code>SaslParticipant</code> in this class is
+ * assumed to be initialized.
+ *
+ * @throws TTransportException
+ * @throws SaslException
+ */
+ abstract protected void handleSaslStartMessage() throws TTransportException, SaslException;
+
+ /**
+ * Opens the underlying transport if it's not already open and then performs
+ * SASL negotiation. If a QOP is negoiated during this SASL handshake, it used
+ * for all communication on this transport after this call is complete.
+ */
+ @Override
+ public void open() throws TTransportException {
+ LOGGER.debug("opening transport {}", this);
+ if (sasl != null && sasl.isComplete())
+ throw new TTransportException("SASL transport already open");
+
+ if (!underlyingTransport.isOpen())
+ underlyingTransport.open();
+
+ try {
+ handleSaslStartMessage();
+
+ SaslResponse message;
+ do {
+ message = receiveSaslMessage();
+ if (message.status != COMPLETE && message.status != OK) {
+ throw new TTransportException("Expected COMPLETE or OK, got " + message.status);
+ }
+
+ if (sasl.isComplete() && message.status == COMPLETE)
+ break;
+
+ byte[] challenge = sasl.evaluateChallengeOrResponse(message.payload);
+ sendSaslMessage(sasl.isComplete() ? COMPLETE : OK, challenge);
+ } while (!(sasl.isComplete() && message.status == COMPLETE));
+ } catch (SaslException e) {
+ underlyingTransport.close();
+ sendAndThrowMessage(BAD, e.getMessage());
+ }
+
+ String qop = (String) sasl.getNegotiatedProperty(Sasl.QOP);
+ if (qop != null && !qop.equalsIgnoreCase("auth"))
+ shouldWrap = true;
+ }
+
+ /**
+ * Get the underlying <code>SaslClient</code>.
+ *
+ * @return The <code>SaslClient</code>, or <code>null</code> if this transport
+ * is backed by a <code>SaslServer</code>.
+ */
+ protected SaslClient getSaslClient() {
+ return sasl.saslClient;
+ }
+
+ /**
+ * Get the underlying <code>SaslServer</code>.
+ *
+ * @return The <code>SaslServer</code>, or <code>null</code> if this transport
+ * is backed by a <code>SaslClient</code>.
+ */
+ protected SaslServer getSaslServer() {
+ return sasl.saslServer;
+ }
+
+ /**
+ * Read a 4-byte word from the underlying transport and interpret it as an
+ * integer.
+ *
+ * @return The length prefix of the next SASL message to read.
+ * @throws TTransportException
+ * Thrown if reading from the underlying transport fails.
+ */
+ protected int readLength() throws TTransportException {
+ byte[] lenBuf = new byte[4];
+ underlyingTransport.readAll(lenBuf, 0, lenBuf.length);
+ return EncodingUtils.decodeBigEndian(lenBuf);
+ }
+
+ /**
+ * Write the given integer as 4 bytes to the underlying transport.
+ *
+ * @param length
+ * The length prefix of the next SASL message to write.
+ * @throws TTransportException
+ * Thrown if writing to the underlying transport fails.
+ */
+ protected void writeLength(int length) throws TTransportException {
+ byte[] lenBuf = new byte[4];
+ TFramedTransport.encodeFrameSize(length, lenBuf);
+ underlyingTransport.write(lenBuf);
+ }
+
+ // Below is the SASL implementation of the TTransport interface.
+
+ /**
+ * Closes the underlying transport and disposes of the SASL implementation
+ * underlying this transport.
+ */
+ @Override
+ public void close() {
+ underlyingTransport.close();
+ try {
+ sasl.dispose();
+ } catch (SaslException e) {
+ // Not much we can do here.
+ }
+ }
+
+ /**
+ * True if the underlying transport is open and the SASL handshake is
+ * complete.
+ */
+ @Override
+ public boolean isOpen() {
+ return underlyingTransport.isOpen() && sasl != null && sasl.isComplete();
+ }
+
+ /**
+ * Read from the underlying transport. Unwraps the contents if a QOP was
+ * negotiated during the SASL handshake.
+ */
+ @Override
+ public int read(byte[] buf, int off, int len) throws TTransportException {
+ if (!isOpen())
+ throw new TTransportException("SASL authentication not complete");
+
+ int got = readBuffer.read(buf, off, len);
+ if (got > 0) {
+ return got;
+ }
+
+ // Read another frame of data
+ try {
+ readFrame();
+ } catch (SaslException e) {
+ throw new TTransportException(e);
+ }
+
+ return readBuffer.read(buf, off, len);
+ }
+
+ /**
+ * Read a single frame of data from the underlying transport, unwrapping if
+ * necessary.
+ *
+ * @throws TTransportException
+ * Thrown if there's an error reading from the underlying transport.
+ * @throws SaslException
+ * Thrown if there's an error unwrapping the data.
+ */
+ private void readFrame() throws TTransportException, SaslException {
+ int dataLength = readLength();
+
+ if (dataLength < 0)
+ throw new TTransportException("Read a negative frame size (" + dataLength + ")!");
+
+ byte[] buff = new byte[dataLength];
+ LOGGER.debug("reading data length: {}", dataLength);
+ underlyingTransport.readAll(buff, 0, dataLength);
+ if (shouldWrap) {
+ buff = sasl.unwrap(buff, 0, buff.length);
+ LOGGER.debug("data length after unwrap: {}", buff.length);
+ }
+ readBuffer.reset(buff);
+ }
+
+ /**
+ * Write to the underlying transport.
+ */
+ @Override
+ public void write(byte[] buf, int off, int len) throws TTransportException {
+ if (!isOpen())
+ throw new TTransportException("SASL authentication not complete");
+
+ writeBuffer.write(buf, off, len);
+ }
+
+ /**
+ * Flushes to the underlying transport. Wraps the contents if a QOP was
+ * negotiated during the SASL handshake.
+ */
+ @Override
+ public void flush() throws TTransportException {
+ byte[] buf = writeBuffer.get();
+ int dataLength = writeBuffer.len();
+ writeBuffer.reset();
+
+ if (shouldWrap) {
+ LOGGER.debug("data length before wrap: {}", dataLength);
+ try {
+ buf = sasl.wrap(buf, 0, dataLength);
+ } catch (SaslException e) {
+ throw new TTransportException(e);
+ }
+ dataLength = buf.length;
+ }
+ LOGGER.debug("writing data length: {}", dataLength);
+ writeLength(dataLength);
+ underlyingTransport.write(buf, 0, dataLength);
+ underlyingTransport.flush();
+ }
+
+ /**
+ * Used exclusively by readSaslMessage to return both a status and data.
+ */
+ private static class SaslResponse {
+ public byte status;
+ public byte[] payload;
+
+ public SaslResponse(byte status, byte[] payload) {
+ this.status = status;
+ this.payload = payload;
+ }
+ }
+
+ /**
+ * Used to abstract over the <code>SaslServer</code> and
+ * <code>SaslClient</code> classes, which share a lot of their interface, but
+ * unfortunately don't share a common superclass.
+ */
+ private static class SaslParticipant {
+ // One of these will always be null.
+ public SaslServer saslServer;
+ public SaslClient saslClient;
+
+ public SaslParticipant(SaslServer saslServer) {
+ this.saslServer = saslServer;
+ }
+
+ public SaslParticipant(SaslClient saslClient) {
+ this.saslClient = saslClient;
+ }
+
+ public byte[] evaluateChallengeOrResponse(byte[] challengeOrResponse) throws SaslException {
+ if (saslClient != null) {
+ return saslClient.evaluateChallenge(challengeOrResponse);
+ } else {
+ return saslServer.evaluateResponse(challengeOrResponse);
+ }
+ }
+
+ public boolean isComplete() {
+ if (saslClient != null)
+ return saslClient.isComplete();
+ else
+ return saslServer.isComplete();
+ }
+
+ public void dispose() throws SaslException {
+ if (saslClient != null)
+ saslClient.dispose();
+ else
+ saslServer.dispose();
+ }
+
+ public byte[] unwrap(byte[] buf, int off, int len) throws SaslException {
+ if (saslClient != null)
+ return saslClient.unwrap(buf, off, len);
+ else
+ return saslServer.unwrap(buf, off, len);
+ }
+
+ public byte[] wrap(byte[] buf, int off, int len) throws SaslException {
+ if (saslClient != null)
+ return saslClient.wrap(buf, off, len);
+ else
+ return saslServer.wrap(buf, off, len);
+ }
+
+ public Object getNegotiatedProperty(String propName) {
+ if (saslClient != null)
+ return saslClient.getNegotiatedProperty(propName);
+ else
+ return saslServer.getNegotiatedProperty(propName);
+ }
+ }
+}
diff --git a/lib/java/test/org/apache/thrift/server/ServerTestBase.java b/lib/java/test/org/apache/thrift/server/ServerTestBase.java
index 88430e6..3bfc8d7 100644
--- a/lib/java/test/org/apache/thrift/server/ServerTestBase.java
+++ b/lib/java/test/org/apache/thrift/server/ServerTestBase.java
@@ -34,7 +34,6 @@
import org.apache.thrift.protocol.TCompactProtocol;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.protocol.TProtocolFactory;
-import org.apache.thrift.transport.TFramedTransport;
import org.apache.thrift.transport.TSocket;
import org.apache.thrift.transport.TTransport;
@@ -286,7 +285,7 @@
public abstract void stopServer() throws Exception;
- public abstract TTransport getTransport() throws Exception;
+ public abstract TTransport getClientTransport(TTransport underlyingTransport) throws Exception;
private void testByte(ThriftTest.Client testClient) throws TException {
byte i8 = testClient.testByte((byte)1);
@@ -374,12 +373,9 @@
startServer(processor, protoFactory);
- TTransport transport;
-
TSocket socket = new TSocket(HOST, PORT);
socket.setTimeout(SOCKET_TIMEOUT);
- transport = socket;
- transport = new TFramedTransport(transport);
+ TTransport transport = getClientTransport(socket);
TProtocol protocol = protoFactory.getProtocol(transport);
ThriftTest.Client testClient = new ThriftTest.Client(protocol);
diff --git a/lib/java/test/org/apache/thrift/server/TestNonblockingServer.java b/lib/java/test/org/apache/thrift/server/TestNonblockingServer.java
index c43b473..e202435 100644
--- a/lib/java/test/org/apache/thrift/server/TestNonblockingServer.java
+++ b/lib/java/test/org/apache/thrift/server/TestNonblockingServer.java
@@ -23,7 +23,6 @@
import org.apache.thrift.protocol.TProtocolFactory;
import org.apache.thrift.transport.TFramedTransport;
import org.apache.thrift.transport.TNonblockingServerSocket;
-import org.apache.thrift.transport.TSocket;
import org.apache.thrift.transport.TTransport;
public class TestNonblockingServer extends ServerTestBase {
@@ -68,9 +67,7 @@
}
@Override
- public TTransport getTransport() throws Exception {
- TSocket socket = new TSocket(HOST, PORT);
- socket.setTimeout(SOCKET_TIMEOUT);
- return new TFramedTransport(socket);
+ public TTransport getClientTransport(TTransport underlyingTransport) throws Exception {
+ return new TFramedTransport(underlyingTransport);
}
}
diff --git a/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java b/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java
new file mode 100644
index 0000000..812028d
--- /dev/null
+++ b/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.PasswordCallback;
+import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.sasl.AuthorizeCallback;
+import javax.security.sasl.RealmCallback;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+
+import org.apache.thrift.TProcessor;
+import org.apache.thrift.protocol.TProtocolFactory;
+import org.apache.thrift.server.ServerTestBase;
+import org.apache.thrift.server.TServer;
+import org.apache.thrift.server.TSimpleServer;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import junit.framework.TestCase;
+
+public class TestTSaslTransports extends TestCase {
+
+ private static final Logger LOGGER = LoggerFactory.getLogger(TestTSaslTransports.class);
+
+ private static final String HOST = "localhost";
+ private static final String SERVICE = "thrift-test";
+ private static final String PRINCIPAL = "thrift-test-principal";
+ private static final String PASSWORD = "super secret password";
+ private static final String REALM = "thrift-test-realm";
+
+ private static final String UNWRAPPED_MECHANISM = "CRAM-MD5";
+ private static final Map<String, String> UNWRAPPED_PROPS = null;
+
+ private static final String WRAPPED_MECHANISM = "DIGEST-MD5";
+ private static final Map<String, String> WRAPPED_PROPS = new HashMap<String, String>();
+
+ static {
+ WRAPPED_PROPS.put(Sasl.QOP, "auth-int");
+ WRAPPED_PROPS.put("com.sun.security.sasl.digest.realm", REALM);
+ }
+
+ private static final String testMessage1 = "Hello, world! Also, four "
+ + "score and seven years ago our fathers brought forth on this "
+ + "continent a new nation, conceived in liberty, and dedicated to the "
+ + "proposition that all men are created equal.";
+
+ private static final String testMessage2 = "I have a dream that one day "
+ + "this nation will rise up and live out the true meaning of its creed: "
+ + "'We hold these truths to be self-evident, that all men are created equal.'";
+
+
+ private static class TestSaslCallbackHandler implements CallbackHandler {
+ @Override
+ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
+ for (Callback c : callbacks) {
+ if (c instanceof NameCallback) {
+ ((NameCallback) c).setName(PRINCIPAL);
+ } else if (c instanceof PasswordCallback) {
+ ((PasswordCallback) c).setPassword(PASSWORD.toCharArray());
+ } else if (c instanceof AuthorizeCallback) {
+ ((AuthorizeCallback) c).setAuthorized(true);
+ } else if (c instanceof RealmCallback) {
+ ((RealmCallback) c).setText(REALM);
+ } else {
+ throw new UnsupportedCallbackException(c);
+ }
+ }
+ }
+ }
+
+ private void testSaslOpen(final String mechanism, final Map<String, String> props)
+ throws SaslException, TTransportException {
+ Thread serverThread = new Thread() {
+ public void run() {
+ try {
+ TServerSocket serverSocket = new TServerSocket(ServerTestBase.PORT);
+ TTransport serverTransport = serverSocket.accept();
+ TTransport saslServerTransport = new TSaslServerTransport(mechanism, SERVICE, HOST,
+ props, new TestSaslCallbackHandler(), serverTransport);
+
+ saslServerTransport.open();
+
+ byte[] inBuf = new byte[testMessage1.getBytes().length];
+ // Deliberately read less than the full buffer to ensure
+ // that TSaslTransport is correctly buffering reads. This
+ // will fail for the WRAPPED test, if it doesn't work.
+ saslServerTransport.readAll(inBuf, 0, 5);
+ saslServerTransport.readAll(inBuf, 5, 10);
+ saslServerTransport.readAll(inBuf, 15, inBuf.length - 15);
+ LOGGER.debug("server got: {}", new String(inBuf));
+ assertEquals(new String(inBuf), testMessage1);
+
+ LOGGER.debug("server writing: {}", testMessage2);
+ saslServerTransport.write(testMessage2.getBytes());
+ saslServerTransport.flush();
+
+ serverSocket.close();
+ saslServerTransport.close();
+ } catch (TTransportException e) {
+ fail(e.toString());
+ }
+ }
+ };
+ serverThread.start();
+
+ try {
+ Thread.sleep(1000);
+ } catch (InterruptedException e) {
+ // Ah well.
+ }
+
+ TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT);
+ TTransport saslClientTransport = new TSaslClientTransport(mechanism,
+ PRINCIPAL, SERVICE, HOST, props, new TestSaslCallbackHandler(), clientSocket);
+ saslClientTransport.open();
+ LOGGER.debug("client writing: {}", testMessage1);
+ saslClientTransport.write(testMessage1.getBytes());
+ saslClientTransport.flush();
+
+ byte[] inBuf = new byte[testMessage2.getBytes().length];
+ saslClientTransport.readAll(inBuf, 0, inBuf.length);
+ LOGGER.debug("client got: {}", new String(inBuf));
+ assertEquals(new String(inBuf), testMessage2);
+
+ TTransportException expectedException = null;
+ try {
+ saslClientTransport.open();
+ } catch (TTransportException e) {
+ expectedException = e;
+ }
+ assertNotNull(expectedException);
+
+ saslClientTransport.close();
+
+ try {
+ serverThread.join();
+ } catch (InterruptedException e) {
+ // Ah well.
+ }
+ }
+
+ public void testUnwrappedOpen() throws SaslException, TTransportException {
+ testSaslOpen(UNWRAPPED_MECHANISM, UNWRAPPED_PROPS);
+ }
+
+ public void testWrappedOpen() throws SaslException, TTransportException {
+ testSaslOpen(WRAPPED_MECHANISM, WRAPPED_PROPS);
+ }
+
+ public void testWithServer() throws Exception {
+ new TestTSaslTransportsWithServer().testIt();
+ }
+
+ private static class TestTSaslTransportsWithServer extends ServerTestBase {
+
+ private Thread serverThread;
+ private TServer server;
+
+ @Override
+ public TTransport getClientTransport(TTransport underlyingTransport) throws Exception {
+ return new TSaslClientTransport(WRAPPED_MECHANISM,
+ PRINCIPAL, SERVICE, HOST, WRAPPED_PROPS, new TestSaslCallbackHandler(), underlyingTransport);
+ }
+
+ @Override
+ public void startServer(final TProcessor processor, final TProtocolFactory protoFactory) throws Exception {
+ serverThread = new Thread() {
+ public void run() {
+ try {
+ // Transport
+ TServerSocket socket = new TServerSocket(PORT);
+
+ TTransportFactory factory = new TSaslServerTransport.Factory(WRAPPED_MECHANISM,
+ SERVICE, HOST, WRAPPED_PROPS, new TestSaslCallbackHandler());
+ server = new TSimpleServer(processor, socket, factory, protoFactory);
+
+ // Run it
+ LOGGER.debug("Starting the server on port {}", PORT);
+ server.serve();
+ } catch (Exception e) {
+ e.printStackTrace();
+ fail();
+ }
+ }
+ };
+ serverThread.start();
+ Thread.sleep(1000);
+ }
+
+ @Override
+ public void stopServer() throws Exception {
+ server.stop();
+ try {
+ serverThread.join();
+ } catch (InterruptedException e) {}
+ }
+
+ }
+
+}