THRIFT-912. java: Fix some bugs in SASL implementation, update protocol spec slightly


git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@1001973 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/doc/thrift-sasl-spec.txt b/doc/thrift-sasl-spec.txt
index 59bfcf9..02cf79e 100644
--- a/doc/thrift-sasl-spec.txt
+++ b/doc/thrift-sasl-spec.txt
@@ -1,6 +1,5 @@
-A Thrift SASL message shall be a byte array of one of the following forms:
+A Thrift SASL message shall be a byte array of the following form:
 
-| 1-byte START status code | 1-byte mechanism name length | variable length mechanism name | 4-byte payload length | variable-length payload |
 | 1-byte status code | 4-byte payload length | variable-length payload |
 
 The length fields shall be interpreted as integers, with the high byte sent
@@ -24,15 +23,10 @@
 name -> mechanism options.
 
 3. At connection time, the client will initiate communication by sending the
-server a START byte, followed by a 1-byte field indicating the length in bytes
-of the underlying security mechanism name that the client would like to use.
+server a START message. The payload of this message will be the name of the
+underlying security mechanism that the client would like to use.
 This mechanism name shall be 1-20 characters in length, and follow the
-specifications for SASL mechanism names specified in RFC 2222. This mechanism
-name shall be followed by a 4-byte, potentially zero-value message length word,
-followed by a potentially zero-length payload. The payload is determined by the
-output byte array of the underlying actual security mechanism, and will be
-empty except for those underlying security protocols which implement the
-optional SASL initial response.
+specifications for SASL mechanism names specified in RFC 2222.
 
 4. The server receives this message and, if the mechanism name provided is
 among the set of mechanisms this server transport is configured to accept,
@@ -44,18 +38,25 @@
 place via this transport. If the mechanism name is one which the server
 supports, then proceed to step 5.
 
-5. The server then provides the byte array of the payload received to its
+5. Following the START message, the client must send another message containing
+the "initial response" of the chosen SASL implementation. The client may send
+this message piggy-backed on the "START" message of step 3. The message type
+of this message must be either "OK" or "COMPLETE", depending on whether the
+SASL implementation indicates that this side of the authentication has been
+satisfied.
+
+6. The server then provides the byte array of the payload received to its
 underlying security mechanism. A challenge is generated by the underlying
 security mechanism on the server, and this is used as the payload for a message
 sent to the client. This message shall consist of an OK byte, followed by the
 non-zero message length word, followed by the payload.
 
-6. The client receives this message from the server and passes the payload to
+7. The client receives this message from the server and passes the payload to
 its underlying security mechanism to generate a response. The client then sends
 the server an OK byte, followed by the non-zero-value length of the response,
 followed by the bytes of the response as the payload.
 
-7. Steps 5 and 6 are repeated until both security mechanisms are satisfied with
+8. Steps 6 and 7 are repeated until both security mechanisms are satisfied with
 the challenge/response exchange. When either side has completed its security
 protocol, its next message shall be the COMPLETE byte, followed by a 4-byte
 potentially zero-value length word, followed by a potentially zero-length
@@ -78,10 +79,10 @@
 appropriate and idiomatic for the particular language these thrift bindings are
 for.
 
-If step 7 completes successfully, then the communication is considered
+If step 8 completes successfully, then the communication is considered
 authenticated and subsequent communication may commence.
 
-If step 7 fails to complete successfully, then no further communication may
+If step 8 fails to complete successfully, then no further communication may
 take place via this transport.
 
 8. All writes to the underlying transport must be prefixed by the 4-byte length
@@ -89,7 +90,7 @@
 should read the 4-byte length word, then read the full quantity of bytes
 specified by this length word.
 
-If no SASL QOP (quality of protection) is negotiated during steps 5 and 6, then
+If no SASL QOP (quality of protection) is negotiated during steps 6 and 7, then
 all subsequent writes to/reads from this transport are written/read unaltered,
 save for the length prefix, to the underlying transport.
 
diff --git a/lib/java/src/org/apache/thrift/transport/TSaslClientTransport.java b/lib/java/src/org/apache/thrift/transport/TSaslClientTransport.java
index fc8a3ea..8c1d0e5 100644
--- a/lib/java/src/org/apache/thrift/transport/TSaslClientTransport.java
+++ b/lib/java/src/org/apache/thrift/transport/TSaslClientTransport.java
@@ -75,6 +75,12 @@
     this.mechanism = mechanism;
   }
 
+
+  @Override
+  protected SaslRole getRole() {
+    return SaslRole.CLIENT;
+  }
+
   /**
    * Performs the client side of the initial portion of the Thrift SASL
    * protocol. Generates and sends the initial response to the server, including
@@ -88,21 +94,15 @@
     if (saslClient.hasInitialResponse())
       initialResponse = saslClient.evaluateChallenge(initialResponse);
 
-    byte[] mechanismBytes = mechanism.getBytes();
-    byte[] messageHeader = new byte[STATUS_BYTES + MECHANISM_NAME_BYTES + mechanismBytes.length
-        + PAYLOAD_LENGTH_BYTES];
-
-    messageHeader[0] = START;
-    messageHeader[1] = (byte) (0xff & mechanismBytes.length);
-    System.arraycopy(mechanismBytes, 0, messageHeader, STATUS_BYTES + MECHANISM_NAME_BYTES,
-        mechanismBytes.length);
-    EncodingUtils.encodeBigEndian(initialResponse.length, messageHeader, STATUS_BYTES
-        + MECHANISM_NAME_BYTES + mechanismBytes.length);
-
     LOGGER.debug("Sending mechanism name {} and initial response of length {}", mechanism,
         initialResponse.length);
-    underlyingTransport.write(messageHeader);
-    underlyingTransport.write(initialResponse);
+
+    byte[] mechanismBytes = mechanism.getBytes();
+    sendSaslMessage(NegotiationStatus.START,
+                    mechanismBytes);
+    // Send initial response
+    sendSaslMessage(saslClient.isComplete() ? NegotiationStatus.COMPLETE : NegotiationStatus.OK,
+                    initialResponse);
     underlyingTransport.flush();
   }
 }
diff --git a/lib/java/src/org/apache/thrift/transport/TSaslServerTransport.java b/lib/java/src/org/apache/thrift/transport/TSaslServerTransport.java
index b07e597..8abcf36 100644
--- a/lib/java/src/org/apache/thrift/transport/TSaslServerTransport.java
+++ b/lib/java/src/org/apache/thrift/transport/TSaslServerTransport.java
@@ -108,6 +108,11 @@
         props, cbh));
   }
 
+  @Override
+  protected SaslRole getRole() {
+    return SaslRole.SERVER;
+  }
+
   /**
    * Performs the server side of the initial portion of the Thrift SASL protocol.
    * Receives the initial response from the client, creates a SASL server using
@@ -116,35 +121,24 @@
    */
   @Override
   protected void handleSaslStartMessage() throws TTransportException, SaslException {
-    // Get the status byte and length of the mechanism name.
-    byte[] messageHeader = new byte[STATUS_BYTES + MECHANISM_NAME_BYTES];
-    underlyingTransport.readAll(messageHeader, 0, messageHeader.length);
-    LOGGER.debug("Received status {} and mechanism name length {}", messageHeader[0],
-        messageHeader[1]);
-    if (messageHeader[0] != START) {
-      sendAndThrowMessage(ERROR, "Expecting START status, received " + messageHeader[0]);
+    SaslResponse message = receiveSaslMessage();
+
+    LOGGER.debug("Received start message with status {}", message.status);
+    if (message.status != NegotiationStatus.START) {
+      sendAndThrowMessage(NegotiationStatus.ERROR, "Expecting START status, received " + message.status);
     }
 
     // Get the mechanism name.
-    byte[] mechanismBytes = new byte[messageHeader[1]];
-    underlyingTransport.readAll(mechanismBytes, 0, mechanismBytes.length);
-
-    String mechanismName = new String(mechanismBytes);
-    TSaslServerDefinition serverDefinition = serverDefinitionMap.get(new String(mechanismBytes));
+    String mechanismName = new String(message.payload);
+    TSaslServerDefinition serverDefinition = serverDefinitionMap.get(mechanismName);
     LOGGER.debug("Received mechanism name '{}'", mechanismName);
 
     if (serverDefinition == null) {
-      sendAndThrowMessage(BAD, "Unsupported mechanism type " + mechanismName);
+      sendAndThrowMessage(NegotiationStatus.BAD, "Unsupported mechanism type " + mechanismName);
     }
     SaslServer saslServer = Sasl.createSaslServer(serverDefinition.mechanism,
         serverDefinition.protocol, serverDefinition.serverName, serverDefinition.props,
         serverDefinition.cbh);
-
-    // Evaluate the initial response and send the first challenge.
-    byte[] initialResponse = new byte[readLength()];
-    sendSaslMessage(saslServer.isComplete() ? COMPLETE : OK, saslServer
-        .evaluateResponse(initialResponse));
-
     setSaslServer(saslServer);
   }
 
@@ -221,7 +215,7 @@
           ret.open();
         } catch (TTransportException e) {
           LOGGER.debug("failed to open server transport", e);
-          return null;
+          throw new RuntimeException(e);
         }
         transportMap.put(base, ret);
       } else {
diff --git a/lib/java/src/org/apache/thrift/transport/TSaslTransport.java b/lib/java/src/org/apache/thrift/transport/TSaslTransport.java
index b5eadb7..24470d9 100644
--- a/lib/java/src/org/apache/thrift/transport/TSaslTransport.java
+++ b/lib/java/src/org/apache/thrift/transport/TSaslTransport.java
@@ -21,6 +21,8 @@
 
 import java.io.UnsupportedEncodingException;
 import java.util.Arrays;
+import java.util.Map;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Set;
 
@@ -48,16 +50,42 @@
   protected static final int STATUS_BYTES = 1;
   protected static final int PAYLOAD_LENGTH_BYTES = 4;
 
+  protected static enum SaslRole {
+    SERVER, CLIENT;
+  }
+
   /**
    * Status bytes used during the initial Thrift SASL handshake.
    */
-  protected static final byte START = 0x01;
-  protected static final byte OK = 0x02;
-  protected static final byte BAD = 0x03;
-  protected static final byte ERROR = 0x04;
-  protected static final byte COMPLETE = 0x05;
+  protected static enum NegotiationStatus {
+    START((byte)0x01),
+    OK((byte)0x02),
+    BAD((byte)0x03),
+    ERROR((byte)0x04),
+    COMPLETE((byte)0x05);
 
-  protected static final Set<Byte> VALID_STATUSES = new HashSet<Byte>(Arrays.asList(START, OK, BAD, ERROR, COMPLETE));
+    private final byte value;
+
+    private static final Map<Byte, NegotiationStatus> reverseMap =
+      new HashMap<Byte, NegotiationStatus>();
+    static {
+      for (NegotiationStatus s : NegotiationStatus.class.getEnumConstants()) {
+        reverseMap.put(s.getValue(), s);
+      }
+    }
+
+    private NegotiationStatus(byte val) {
+      this.value = val;
+    }
+
+    public byte getValue() {
+      return value;
+    }
+
+    public static NegotiationStatus byValue(byte val) {
+      return reverseMap.get(val);
+    }
+  }
 
   /**
    * Transport underlying this one.
@@ -126,14 +154,16 @@
    *          The data to send as the payload of this message.
    * @throws TTransportException
    */
-  protected void sendSaslMessage(byte status, byte[] payload) throws TTransportException {
+  protected void sendSaslMessage(NegotiationStatus status, byte[] payload) throws TTransportException {
     if (payload == null)
       payload = new byte[0];
 
-    messageHeader[0] = status;
+    messageHeader[0] = status.getValue();
     EncodingUtils.encodeBigEndian(payload.length, messageHeader, STATUS_BYTES);
 
-    LOGGER.debug("Writing message with status {} and payload length {}", status, payload.length);
+    if (LOGGER.isDebugEnabled())
+      LOGGER.debug(getRole() + ": Writing message with status {} and payload length {}",
+                   status, payload.length);
     underlyingTransport.write(messageHeader);
     underlyingTransport.write(payload);
     underlyingTransport.flush();
@@ -150,21 +180,25 @@
   protected SaslResponse receiveSaslMessage() throws TTransportException {
     underlyingTransport.readAll(messageHeader, 0, messageHeader.length);
 
-    byte status = messageHeader[0];
+    byte statusByte = messageHeader[0];
     byte[] payload = new byte[EncodingUtils.decodeBigEndian(messageHeader, STATUS_BYTES)];
     underlyingTransport.readAll(payload, 0, payload.length);
 
-    if (!VALID_STATUSES.contains(status))
-      sendAndThrowMessage(ERROR, "Invalid status " + status);
-    else if (status == BAD || status == ERROR) {
+    NegotiationStatus status = NegotiationStatus.byValue(statusByte);
+    if (status == null) {
+      sendAndThrowMessage(NegotiationStatus.ERROR, "Invalid status " + statusByte);
+    } else if (status == NegotiationStatus.BAD || status == NegotiationStatus.ERROR) {
       try {
-        throw new TTransportException(new String(payload, "UTF-8"));
+        String remoteMessage = new String(payload, "UTF-8");
+        throw new TTransportException("Peer indicated failure: " + remoteMessage);
       } catch (UnsupportedEncodingException e) {
         throw new TTransportException(e);
       }
     }
 
-    LOGGER.debug("Received message with status {} and payload length {}", status, payload.length);
+    if (LOGGER.isDebugEnabled())
+      LOGGER.debug(getRole() + ": Received message with status {} and payload length {}",
+                   status, payload.length);
     return new SaslResponse(status, payload);
   }
 
@@ -180,8 +214,13 @@
    * @throws TTransportException
    *           Always thrown with the message provided.
    */
-  protected void sendAndThrowMessage(byte status, String message) throws TTransportException {
-    sendSaslMessage(status, message.getBytes());
+  protected void sendAndThrowMessage(NegotiationStatus status, String message) throws TTransportException {
+    try {
+      sendSaslMessage(status, message.getBytes());
+    } catch (Exception e) {
+      LOGGER.warn("Could not send failure response", e);
+      message += "\nAlso, could not send response: " + e.toString();
+    }
     throw new TTransportException(message);
   }
 
@@ -195,6 +234,8 @@
    */
   abstract protected void handleSaslStartMessage() throws TTransportException, SaslException;
 
+  protected abstract SaslRole getRole();
+
   /**
    * Opens the underlying transport if it's not already open and then performs
    * SASL negotiation. If a QOP is negoiated during this SASL handshake, it used
@@ -210,24 +251,55 @@
       underlyingTransport.open();
 
     try {
+      // Negotiate a SASL mechanism. The client also sends its
+      // initial response, or an empty one.
       handleSaslStartMessage();
+      LOGGER.debug("{}: Start message handled", getRole());
 
-      SaslResponse message;
-      do {
+      SaslResponse message = null;
+      while (!sasl.isComplete()) {
         message = receiveSaslMessage();
-        if (message.status != COMPLETE && message.status != OK) {
+        if (message.status != NegotiationStatus.COMPLETE &&
+            message.status != NegotiationStatus.OK) {
           throw new TTransportException("Expected COMPLETE or OK, got " + message.status);
         }
 
-        if (sasl.isComplete() && message.status == COMPLETE)
-          break;
-
         byte[] challenge = sasl.evaluateChallengeOrResponse(message.payload);
-        sendSaslMessage(sasl.isComplete() ? COMPLETE : OK, challenge);
-      } while (!(sasl.isComplete() && message.status == COMPLETE));
+
+        // If we are the client, and the server indicates COMPLETE, we don't need to
+        // send back any further response.
+        if (message.status == NegotiationStatus.COMPLETE &&
+            getRole() == SaslRole.CLIENT) {
+          LOGGER.debug("{}: All done!", getRole());
+          break;
+        }
+
+        sendSaslMessage(sasl.isComplete() ? NegotiationStatus.COMPLETE : NegotiationStatus.OK,
+                        challenge);
+      }
+      LOGGER.debug("{}: Main negotiation loop complete", getRole());
+
+      assert sasl.isComplete();
+
+      // If we're the client, and we're complete, but the server isn't
+      // complete yet, we need to wait for its response. This will occur
+      // with ANONYMOUS auth, for example, where we send an initial response
+      // and are immediately complete.
+      if (getRole() == SaslRole.CLIENT &&
+          (message == null || message.status == NegotiationStatus.OK)) {
+        LOGGER.debug("{}: SASL Client receiving last message", getRole());
+        message = receiveSaslMessage();
+        if (message.status != NegotiationStatus.COMPLETE) {
+          throw new TTransportException(
+            "Expected SASL COMPLETE, but got " + message.status);
+        }
+      }
     } catch (SaslException e) {
-      underlyingTransport.close();
-      sendAndThrowMessage(BAD, e.getMessage());
+      try {
+        sendAndThrowMessage(NegotiationStatus.BAD, e.getMessage());
+      } finally {
+        underlyingTransport.close();
+      }
     }
 
     String qop = (String) sasl.getNegotiatedProperty(Sasl.QOP);
@@ -241,7 +313,7 @@
    * @return The <code>SaslClient</code>, or <code>null</code> if this transport
    *         is backed by a <code>SaslServer</code>.
    */
-  protected SaslClient getSaslClient() {
+  public SaslClient getSaslClient() {
     return sasl.saslClient;
   }
 
@@ -251,7 +323,7 @@
    * @return The <code>SaslServer</code>, or <code>null</code> if this transport
    *         is backed by a <code>SaslClient</code>.
    */
-  protected SaslServer getSaslServer() {
+  public SaslServer getSaslServer() {
     return sasl.saslServer;
   }
 
@@ -348,7 +420,7 @@
       throw new TTransportException("Read a negative frame size (" + dataLength + ")!");
 
     byte[] buff = new byte[dataLength];
-    LOGGER.debug("reading data length: {}", dataLength);
+    LOGGER.debug("{}: reading data length: {}", getRole(), dataLength);
     underlyingTransport.readAll(buff, 0, dataLength);
     if (shouldWrap) {
       buff = sasl.unwrap(buff, 0, buff.length);
@@ -396,11 +468,11 @@
   /**
    * Used exclusively by readSaslMessage to return both a status and data.
    */
-  private static class SaslResponse {
-    public byte status;
+  protected static class SaslResponse {
+    public NegotiationStatus status;
     public byte[] payload;
 
-    public SaslResponse(byte status, byte[] payload) {
+    public SaslResponse(NegotiationStatus status, byte[] payload) {
       this.status = status;
       this.payload = payload;
     }
diff --git a/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java b/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java
index 812028d..ca121c1 100644
--- a/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java
+++ b/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java
@@ -31,6 +31,10 @@
 import javax.security.sasl.AuthorizeCallback;
 import javax.security.sasl.RealmCallback;
 import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslClient;
+import javax.security.sasl.SaslClientFactory;
+import javax.security.sasl.SaslServer;
+import javax.security.sasl.SaslServerFactory;
 import javax.security.sasl.SaslException;
 
 import org.apache.thrift.TProcessor;
@@ -75,13 +79,19 @@
 
 
   private static class TestSaslCallbackHandler implements CallbackHandler {
+    private final String password;
+
+    public TestSaslCallbackHandler(String password) {
+      this.password = password;
+    }
+
     @Override
     public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
       for (Callback c : callbacks) {
         if (c instanceof NameCallback) {
           ((NameCallback) c).setName(PRINCIPAL);
         } else if (c instanceof PasswordCallback) {
-          ((PasswordCallback) c).setPassword(PASSWORD.toCharArray());
+          ((PasswordCallback) c).setPassword(password.toCharArray());
         } else if (c instanceof AuthorizeCallback) {
           ((AuthorizeCallback) c).setAuthorized(true);
         } else if (c instanceof RealmCallback) {
@@ -93,39 +103,63 @@
     }
   }
 
-  private void testSaslOpen(final String mechanism, final Map<String, String> props)
-      throws SaslException, TTransportException {
-    Thread serverThread = new Thread() {
-      public void run() {
-        try {
-          TServerSocket serverSocket = new TServerSocket(ServerTestBase.PORT);
-          TTransport serverTransport = serverSocket.accept();
-          TTransport saslServerTransport = new TSaslServerTransport(mechanism, SERVICE, HOST,
-              props, new TestSaslCallbackHandler(), serverTransport);
+  private class ServerThread extends Thread {
+    final String mechanism;
+    final Map<String, String> props;
+    volatile Throwable thrown;
 
-          saslServerTransport.open();
+    public ServerThread(String mechanism, Map<String, String> props) {
+      this.mechanism = mechanism;
+      this.props = props;
+    }
 
-          byte[] inBuf = new byte[testMessage1.getBytes().length];
-          // Deliberately read less than the full buffer to ensure
-          // that TSaslTransport is correctly buffering reads. This
-          // will fail for the WRAPPED test, if it doesn't work.
-          saslServerTransport.readAll(inBuf, 0, 5);
-          saslServerTransport.readAll(inBuf, 5, 10);
-          saslServerTransport.readAll(inBuf, 15, inBuf.length - 15);
-          LOGGER.debug("server got: {}", new String(inBuf));
-          assertEquals(new String(inBuf), testMessage1);
-
-          LOGGER.debug("server writing: {}", testMessage2);
-          saslServerTransport.write(testMessage2.getBytes());
-          saslServerTransport.flush();
-
-          serverSocket.close();
-          saslServerTransport.close();
-        } catch (TTransportException e) {
-          fail(e.toString());
-        }
+    public void run() {
+      try {
+        internalRun();
+      } catch (Throwable t) {
+        thrown = t;
       }
-    };
+    }
+
+    private void internalRun() throws Exception {
+      TServerSocket serverSocket = new TServerSocket(ServerTestBase.PORT);
+      try {
+        acceptAndWrite(serverSocket);
+      } finally {
+        serverSocket.close();
+      }
+    }
+
+    private void acceptAndWrite(TServerSocket serverSocket)
+      throws Exception {
+      TTransport serverTransport = serverSocket.accept();
+      TTransport saslServerTransport = new TSaslServerTransport(
+        mechanism, SERVICE, HOST,
+        props, new TestSaslCallbackHandler(PASSWORD), serverTransport);
+
+      saslServerTransport.open();
+
+      byte[] inBuf = new byte[testMessage1.getBytes().length];
+      // Deliberately read less than the full buffer to ensure
+      // that TSaslTransport is correctly buffering reads. This
+      // will fail for the WRAPPED test, if it doesn't work.
+      saslServerTransport.readAll(inBuf, 0, 5);
+      saslServerTransport.readAll(inBuf, 5, 10);
+      saslServerTransport.readAll(inBuf, 15, inBuf.length - 15);
+      LOGGER.debug("server got: {}", new String(inBuf));
+      assertEquals(new String(inBuf), testMessage1);
+
+      LOGGER.debug("server writing: {}", testMessage2);
+      saslServerTransport.write(testMessage2.getBytes());
+      saslServerTransport.flush();
+
+      saslServerTransport.close();
+    }
+  }
+
+  private void testSaslOpen(final String mechanism, final Map<String, String> props)
+      throws Exception {
+    ServerThread serverThread = new ServerThread(mechanism, props);
     serverThread.start();
 
     try {
@@ -134,44 +168,95 @@
       // Ah well.
     }
 
-    TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT);
-    TTransport saslClientTransport = new TSaslClientTransport(mechanism,
-        PRINCIPAL, SERVICE, HOST, props, new TestSaslCallbackHandler(), clientSocket);
-    saslClientTransport.open();
-    LOGGER.debug("client writing: {}", testMessage1);
-    saslClientTransport.write(testMessage1.getBytes());
-    saslClientTransport.flush();
-
-    byte[] inBuf = new byte[testMessage2.getBytes().length];
-    saslClientTransport.readAll(inBuf, 0, inBuf.length);
-    LOGGER.debug("client got: {}", new String(inBuf));
-    assertEquals(new String(inBuf), testMessage2);
-
-    TTransportException expectedException = null;
     try {
+      TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT);
+      TTransport saslClientTransport = new TSaslClientTransport(mechanism,
+                                                                PRINCIPAL, SERVICE, HOST, props, new TestSaslCallbackHandler(PASSWORD), clientSocket);
       saslClientTransport.open();
-    } catch (TTransportException e) {
-      expectedException = e;
-    }
-    assertNotNull(expectedException);
+      LOGGER.debug("client writing: {}", testMessage1);
+      saslClientTransport.write(testMessage1.getBytes());
+      saslClientTransport.flush();
 
-    saslClientTransport.close();
+      byte[] inBuf = new byte[testMessage2.getBytes().length];
+      saslClientTransport.readAll(inBuf, 0, inBuf.length);
+      LOGGER.debug("client got: {}", new String(inBuf));
+      assertEquals(new String(inBuf), testMessage2);
 
-    try {
-      serverThread.join();
-    } catch (InterruptedException e) {
-      // Ah well.
+      TTransportException expectedException = null;
+      try {
+        saslClientTransport.open();
+      } catch (TTransportException e) {
+        expectedException = e;
+      }
+      assertNotNull(expectedException);
+
+      saslClientTransport.close();
+    } catch (Exception e) {
+      LOGGER.warn("Exception caught", e);
+      throw e;
+    } finally {
+      serverThread.interrupt();
+      try {
+        serverThread.join();
+      } catch (InterruptedException e) {
+        // Ah well.
+      }
+      assertNull(serverThread.thrown);
     }
   }
 
-  public void testUnwrappedOpen() throws SaslException, TTransportException {
+  public void testUnwrappedOpen() throws Exception {
     testSaslOpen(UNWRAPPED_MECHANISM, UNWRAPPED_PROPS);
   }
 
-  public void testWrappedOpen() throws SaslException, TTransportException {
+  public void testWrappedOpen() throws Exception {
     testSaslOpen(WRAPPED_MECHANISM, WRAPPED_PROPS);
   }
 
+  public void testAnonymousOpen() throws Exception {
+    testSaslOpen("ANONYMOUS", null);
+  }
+
+  /**
+   * Test that we get the proper exceptions thrown back the server when
+   * the client provides invalid password.
+   */
+  public void testBadPassword() throws Exception {
+    ServerThread serverThread = new ServerThread(UNWRAPPED_MECHANISM, UNWRAPPED_PROPS);
+    serverThread.start();
+
+    try {
+      Thread.sleep(1000);
+    } catch (InterruptedException e) {
+      // Ah well.
+    }
+
+    boolean clientSidePassed = true;
+
+    try {
+      TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT);
+      TTransport saslClientTransport = new TSaslClientTransport(
+        UNWRAPPED_MECHANISM, PRINCIPAL, SERVICE, HOST, UNWRAPPED_PROPS,
+        new TestSaslCallbackHandler("NOT THE PASSWORD"), clientSocket);
+      saslClientTransport.open();
+      clientSidePassed = false;
+      fail("Was able to open transport with bad password");
+    } catch (TTransportException tte) {
+      LOGGER.error("Exception for bad password", tte);
+      assertNotNull(tte.getMessage());
+      assertTrue(tte.getMessage().contains("Invalid response"));
+
+    } finally {
+      serverThread.interrupt();
+      serverThread.join();
+
+      if (clientSidePassed) {
+        assertNotNull(serverThread.thrown);
+        assertTrue(serverThread.thrown.getMessage().contains("Invalid response"));
+      }
+    }
+  }
+
   public void testWithServer() throws Exception {
     new TestTSaslTransportsWithServer().testIt();
   }
@@ -183,8 +268,9 @@
 
     @Override
     public TTransport getClientTransport(TTransport underlyingTransport) throws Exception {
-      return new TSaslClientTransport(WRAPPED_MECHANISM,
-          PRINCIPAL, SERVICE, HOST, WRAPPED_PROPS, new TestSaslCallbackHandler(), underlyingTransport);
+      return new TSaslClientTransport(
+        WRAPPED_MECHANISM, PRINCIPAL, SERVICE, HOST, WRAPPED_PROPS,
+        new TestSaslCallbackHandler(PASSWORD), underlyingTransport);
     }
 
     @Override
@@ -195,8 +281,9 @@
             // Transport
             TServerSocket socket = new TServerSocket(PORT);
 
-            TTransportFactory factory = new TSaslServerTransport.Factory(WRAPPED_MECHANISM,
-                SERVICE, HOST, WRAPPED_PROPS, new TestSaslCallbackHandler());
+            TTransportFactory factory = new TSaslServerTransport.Factory(
+              WRAPPED_MECHANISM, SERVICE, HOST, WRAPPED_PROPS,
+              new TestSaslCallbackHandler(PASSWORD));
             server = new TSimpleServer(processor, socket, factory, protoFactory);
 
             // Run it
@@ -222,4 +309,104 @@
 
   }
 
+
+  /**
+   * Implementation of SASL ANONYMOUS, used for testing client-side
+   * intial responses.
+   */
+  private static class AnonymousClient implements SaslClient {
+    private final String username;
+    private boolean hasProvidedInitialResponse;
+
+    public AnonymousClient(String username) {
+      this.username = username;
+    }
+
+    public String getMechanismName() { return "ANONYMOUS"; }
+    public boolean hasInitialResponse() { return true; }
+    public byte[] evaluateChallenge(byte[] challenge) throws SaslException {
+      if (hasProvidedInitialResponse) {
+        throw new SaslException("Already complete!");
+      }
+
+      try {
+        hasProvidedInitialResponse = true;
+        return username.getBytes("UTF-8");
+      } catch (IOException e) {
+        throw new SaslException(e.toString());
+      }
+    }
+    public boolean isComplete() { return hasProvidedInitialResponse; }
+    public byte[] unwrap(byte[] incoming, int offset, int len) {
+      throw new UnsupportedOperationException();
+    }
+    public byte[] wrap(byte[] outgoing, int offset, int len) {
+      throw new UnsupportedOperationException();
+    }
+    public Object getNegotiatedProperty(String propName) { return null; }
+    public void dispose() {}
+  }
+
+  private static class AnonymousServer implements SaslServer {
+    private String user;
+    public String getMechanismName() { return "ANONYMOUS"; }
+    public byte[] evaluateResponse(byte[] response) throws SaslException {
+      try {
+        this.user = new String(response, "UTF-8");
+      } catch (IOException e) {
+        throw new SaslException(e.toString());
+      }
+      return null;
+    }
+    public boolean isComplete() { return user != null; }
+    public String getAuthorizationID() { return user; }
+    public byte[] unwrap(byte[] incoming, int offset, int len) {
+      throw new UnsupportedOperationException();
+    }
+    public byte[] wrap(byte[] outgoing, int offset, int len) {
+      throw new UnsupportedOperationException();
+    }
+    public Object getNegotiatedProperty(String propName) { return null; }
+    public void dispose() {}
+
+  }
+
+  public static class SaslAnonymousFactory
+    implements SaslClientFactory, SaslServerFactory {
+
+    public SaslClient createSaslClient(
+      String[] mechanisms, String authorizationId, String protocol,
+      String serverName, Map<String,?> props, CallbackHandler cbh)
+    {
+      for (String mech : mechanisms) {
+        if ("ANONYMOUS".equals(mech)) {
+          return new AnonymousClient(authorizationId);
+        }
+      }
+      return null;
+    }
+
+    public SaslServer createSaslServer(
+      String mechanism, String protocol, String serverName, Map<String,?> props, CallbackHandler cbh)
+    {
+      if ("ANONYMOUS".equals(mechanism)) {
+        return new AnonymousServer();
+      }
+      return null;
+    }
+    public String[] getMechanismNames(Map<String, ?> props) {
+      return new String[] { "ANONYMOUS" };
+    }
+  }
+
+  static {
+    java.security.Security.addProvider(new SaslAnonymousProvider());
+  }
+  public static class SaslAnonymousProvider extends java.security.Provider {
+    public SaslAnonymousProvider() {
+      super("ThriftSaslAnonymous", 1.0, "Thrift Anonymous SASL provider");
+      put("SaslClientFactory.ANONYMOUS", SaslAnonymousFactory.class.getName());
+      put("SaslServerFactory.ANONYMOUS", SaslAnonymousFactory.class.getName());
+    }
+  }
 }