THRIFT-768. java: Async client for Java
This patch adds an implementation of a fully-asynchronous client that makes use of NIO. Stubs for the async method calls are generated along with the existing synchronous ones.
git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@948492 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/compiler/cpp/src/generate/t_java_generator.cc b/compiler/cpp/src/generate/t_java_generator.cc
index f1eb566..2db3ca3 100644
--- a/compiler/cpp/src/generate/t_java_generator.cc
+++ b/compiler/cpp/src/generate/t_java_generator.cc
@@ -115,8 +115,10 @@
std::string isset_field_id(t_field* field);
void generate_service_interface (t_service* tservice);
+ void generate_service_async_interface(t_service* tservice);
void generate_service_helpers (t_service* tservice);
void generate_service_client (t_service* tservice);
+ void generate_service_async_client(t_service* tservice);
void generate_service_server (t_service* tservice);
void generate_process_function (t_service* tservice, t_function* tfunction);
@@ -215,13 +217,16 @@
std::string base_type_name(t_base_type* tbase, bool in_container=false);
std::string declare_field(t_field* tfield, bool init=false);
std::string function_signature(t_function* tfunction, std::string prefix="");
- std::string argument_list(t_struct* tstruct);
+ std::string function_signature_async(t_function* tfunction, bool use_base_method = false, std::string prefix="");
+ std::string argument_list(t_struct* tstruct, bool include_types = true);
+ std::string async_function_call_arglist(t_function* tfunc, bool use_base_method = true, bool include_types = true);
+ std::string async_argument_list(t_function* tfunct, t_struct* tstruct, t_type* ttype, bool include_types=false);
std::string type_to_enum(t_type* ttype);
std::string get_enum_class_name(t_type* type);
void generate_struct_desc(ofstream& out, t_struct* tstruct);
void generate_field_descs(ofstream& out, t_struct* tstruct);
void generate_field_name_constants(ofstream& out, t_struct* tstruct);
-
+
bool type_can_be_null(t_type* ttype) {
ttype = get_true_type(ttype);
@@ -330,7 +335,9 @@
return
string() +
"import org.apache.thrift.*;\n" +
+ "import org.apache.thrift.async.*;\n" +
"import org.apache.thrift.meta_data.*;\n" +
+ "import org.apache.thrift.transport.*;\n" +
"import org.apache.thrift.protocol.*;\n\n";
}
@@ -2133,7 +2140,9 @@
// Generate the three main parts of the service
generate_service_interface(tservice);
+ generate_service_async_interface(tservice);
generate_service_client(tservice);
+ generate_service_async_client(tservice);
generate_service_server(tservice);
generate_service_helpers(tservice);
@@ -2164,13 +2173,29 @@
vector<t_function*>::iterator f_iter;
for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
generate_java_doc(f_service_, *f_iter);
- indent(f_service_) << "public " << function_signature(*f_iter) << ";" <<
- endl << endl;
+ indent(f_service_) << "public " << function_signature(*f_iter) << ";" << endl << endl;
}
indent_down();
- f_service_ <<
- indent() << "}" << endl <<
- endl;
+ f_service_ << indent() << "}" << endl << endl;
+}
+
+void t_java_generator::generate_service_async_interface(t_service* tservice) {
+ string extends = "";
+ string extends_iface = "";
+ if (tservice->get_extends() != NULL) {
+ extends = type_name(tservice->get_extends());
+ extends_iface = " extends " + extends + " .AsyncIface";
+ }
+
+ f_service_ << indent() << "public interface AsyncIface" << extends_iface << " {" << endl << endl;
+ indent_up();
+ vector<t_function*> functions = tservice->get_functions();
+ vector<t_function*>::iterator f_iter;
+ for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
+ indent(f_service_) << "public " << function_signature_async(*f_iter, true) << " throws TException;" << endl << endl;
+ }
+ indent_down();
+ f_service_ << indent() << "}" << endl << endl;
}
/**
@@ -2404,6 +2429,138 @@
"}" << endl;
}
+void t_java_generator::generate_service_async_client(t_service* tservice) {
+ string extends = "TAsyncClient";
+ string extends_client = "";
+ if (tservice->get_extends() != NULL) {
+ extends = type_name(tservice->get_extends()) + ".AsyncClient";
+ // extends_client = " extends " + extends + ".AsyncClient";
+ }
+
+ indent(f_service_) <<
+ "public static class AsyncClient extends " << extends << " implements AsyncIface {" << endl;
+ indent_up();
+
+ // Factory method
+ indent(f_service_) << "public static class Factory implements TAsyncClientFactory<AsyncClient> {" << endl;
+ indent(f_service_) << " private TAsyncClientManager clientManager;" << endl;
+ indent(f_service_) << " private TProtocolFactory protocolFactory;" << endl;
+ indent(f_service_) << " public Factory(TAsyncClientManager clientManager, TProtocolFactory protocolFactory) {" << endl;
+ indent(f_service_) << " this.clientManager = clientManager;" << endl;
+ indent(f_service_) << " this.protocolFactory = protocolFactory;" << endl;
+ indent(f_service_) << " }" << endl;
+ indent(f_service_) << " public AsyncClient getAsyncClient(TNonblockingTransport transport) {" << endl;
+ indent(f_service_) << " return new AsyncClient(protocolFactory, clientManager, transport);" << endl;
+ indent(f_service_) << " }" << endl;
+ indent(f_service_) << "}" << endl << endl;
+
+ indent(f_service_) << "public AsyncClient(TProtocolFactory protocolFactory, TAsyncClientManager clientManager, TNonblockingTransport transport) {" << endl;
+ indent(f_service_) << " super(protocolFactory, clientManager, transport);" << endl;
+ indent(f_service_) << "}" << endl << endl;
+
+ // Generate client method implementations
+ vector<t_function*> functions = tservice->get_functions();
+ vector<t_function*>::const_iterator f_iter;
+ for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
+ string funname = (*f_iter)->get_name();
+ t_type* ret_type = (*f_iter)->get_returntype();
+ t_struct* arg_struct = (*f_iter)->get_arglist();
+ string funclassname = funname + "_call";
+ const vector<t_field*>& fields = arg_struct->get_members();
+ const std::vector<t_field*>& xceptions = (*f_iter)->get_xceptions()->get_members();
+ vector<t_field*>::const_iterator fld_iter;
+ string args_name = (*f_iter)->get_name() + "_args";
+ string result_name = (*f_iter)->get_name() + "_result";
+
+ // Main method body
+ indent(f_service_) << "public " << function_signature_async(*f_iter, false) << " throws TException {" << endl;
+ indent(f_service_) << " checkReady();" << endl;
+ indent(f_service_) << " " << funclassname << " method_call = new " + funclassname + "(" << async_argument_list(*f_iter, arg_struct, ret_type) << ", this, protocolFactory, transport);" << endl;
+ indent(f_service_) << " manager.call(method_call);" << endl;
+ indent(f_service_) << "}" << endl;
+
+ f_service_ << endl;
+
+ // TAsyncMethod object for this function call
+ indent(f_service_) << "public static class " + funclassname + " extends TAsyncMethodCall {" << endl;
+ indent_up();
+
+ // Member variables
+ for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) {
+ indent(f_service_) << "private " + type_name((*fld_iter)->get_type()) + " " + (*fld_iter)->get_name() + ";" << endl;
+ }
+
+ // NOTE since we use a new Client instance to deserialize, let's keep seqid to 0 for now
+ // indent(f_service_) << "private int seqid;" << endl << endl;
+
+ // Constructor
+ indent(f_service_) << "public " + funclassname + "(" + async_argument_list(*f_iter, arg_struct, ret_type, true) << ", TAsyncClient client, TProtocolFactory protocolFactory, TNonblockingTransport transport) throws TException {" << endl;
+ indent(f_service_) << " super(client, protocolFactory, transport, resultHandler, " << ((*f_iter)->is_oneway() ? "true" : "false") << ");" << endl;
+
+ // Assign member variables
+ for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) {
+ indent(f_service_) << " this." + (*fld_iter)->get_name() + " = " + (*fld_iter)->get_name() + ";" << endl;
+ }
+
+ indent(f_service_) << "}" << endl << endl;
+
+ indent(f_service_) << "public void write_args(TProtocol prot) throws TException {" << endl;
+ indent_up();
+
+ // Serialize request
+ // NOTE we are leaving seqid as 0, for now (see above)
+ f_service_ <<
+ indent() << "prot.writeMessageBegin(new TMessage(\"" << funname << "\", TMessageType.CALL, 0));" << endl <<
+ indent() << args_name << " args = new " << args_name << "();" << endl;
+
+ for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) {
+ f_service_ << indent() << "args.set" << get_cap_name((*fld_iter)->get_name()) << "(" << (*fld_iter)->get_name() << ");" << endl;
+ }
+
+ f_service_ <<
+ indent() << "args.write(prot);" << endl <<
+ indent() << "prot.writeMessageEnd();" << endl;
+
+ indent_down();
+ indent(f_service_) << "}" << endl << endl;
+
+ // Return method
+ indent(f_service_) << "public " + type_name(ret_type) + " getResult() throws ";
+ vector<t_field*>::const_iterator x_iter;
+ for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) {
+ f_service_ << type_name((*x_iter)->get_type(), false, false) + ", ";
+ }
+ f_service_ << "TException {" << endl;
+
+ indent_up();
+ f_service_ <<
+ indent() << "if (getState() != State.RESPONSE_READ) {" << endl <<
+ indent() << " throw new IllegalStateException(\"Method call not finished!\");" << endl <<
+ indent() << "}" << endl <<
+ indent() << "TMemoryInputTransport memoryTransport = new TMemoryInputTransport(getFrameBuffer().array());" << endl <<
+ indent() << "TProtocol prot = client.getProtocolFactory().getProtocol(memoryTransport);" << endl;
+ if (!(*f_iter)->is_oneway()) {
+ indent(f_service_);
+ if (!ret_type->is_void()) {
+ f_service_ << "return ";
+ }
+ f_service_ << "(new Client(prot)).recv_" + funname + "();" << endl;
+ }
+
+ // Close function
+ indent_down();
+ indent(f_service_) << "}" << endl;
+
+ // Close class
+ indent_down();
+ indent(f_service_) << "}" << endl << endl;
+ }
+
+ // Close AsyncClient
+ scope_down(f_service_);
+ f_service_ << endl;
+}
+
/**
* Generates a service server definition.
*
@@ -3247,9 +3404,48 @@
}
/**
+ * Renders a function signature of the form 'void name(args, resultHandler)'
+ *
+ * @params tfunction Function definition
+ * @return String of rendered function definition
+ */
+string t_java_generator::function_signature_async(t_function* tfunction, bool use_base_method, string prefix) {
+ std::string arglist = async_function_call_arglist(tfunction, use_base_method, true);
+
+ std::string ret_type = "";
+ if (use_base_method) {
+ ret_type += "AsyncClient.";
+ }
+ ret_type += tfunction->get_name() + "_call";
+
+ std::string result = prefix + "void " + tfunction->get_name() + "(" + arglist + ")";
+ return result;
+}
+
+string t_java_generator::async_function_call_arglist(t_function* tfunc, bool use_base_method, bool include_types) {
+ std::string arglist = "";
+ if (tfunc->get_arglist()->get_members().size() > 0) {
+ arglist = argument_list(tfunc->get_arglist(), include_types) + ", ";
+ }
+
+ std::string ret_type = "";
+ if (use_base_method) {
+ ret_type += "AsyncClient.";
+ }
+ ret_type += tfunc->get_name() + "_call";
+
+ if (include_types) {
+ arglist += "AsyncMethodCallback<" + ret_type + "> ";
+ }
+ arglist += "resultHandler";
+
+ return arglist;
+}
+
+/**
* Renders a comma separated field list, with type names
*/
-string t_java_generator::argument_list(t_struct* tstruct) {
+string t_java_generator::argument_list(t_struct* tstruct, bool include_types) {
string result = "";
const vector<t_field*>& fields = tstruct->get_members();
@@ -3261,11 +3457,40 @@
} else {
result += ", ";
}
- result += type_name((*f_iter)->get_type()) + " " + (*f_iter)->get_name();
+ if (include_types) {
+ result += type_name((*f_iter)->get_type()) + " ";
+ }
+ result += (*f_iter)->get_name();
}
return result;
}
+string t_java_generator::async_argument_list(t_function* tfunct, t_struct* tstruct, t_type* ttype, bool include_types) {
+ string result = "";
+ const vector<t_field*>& fields = tstruct->get_members();
+ vector<t_field*>::const_iterator f_iter;
+ bool first = true;
+ for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
+ if (first) {
+ first = false;
+ } else {
+ result += ", ";
+ }
+ if (include_types) {
+ result += type_name((*f_iter)->get_type()) + " ";
+ }
+ result += (*f_iter)->get_name();
+ }
+ if (!first) {
+ result += ", ";
+ }
+ if (include_types) {
+ result += "AsyncMethodCallback<" + tfunct->get_name() + "_call" + "> ";
+ }
+ result += "resultHandler";
+ return result;
+}
+
/**
* Converts the parse type to a C++ enum string for the given type.
*/
diff --git a/lib/java/src/org/apache/thrift/TByteArrayOutputStream.java b/lib/java/src/org/apache/thrift/TByteArrayOutputStream.java
index e35fbcb..9ed83c0 100644
--- a/lib/java/src/org/apache/thrift/TByteArrayOutputStream.java
+++ b/lib/java/src/org/apache/thrift/TByteArrayOutputStream.java
@@ -35,7 +35,6 @@
super();
}
-
public byte[] get() {
return buf;
}
diff --git a/lib/java/src/org/apache/thrift/async/AsyncMethodCallback.java b/lib/java/src/org/apache/thrift/async/AsyncMethodCallback.java
new file mode 100644
index 0000000..b8cd9ed
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/async/AsyncMethodCallback.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.thrift.async;
+
+public interface AsyncMethodCallback<T> {
+ /**
+ * This method will be called when the remote side has completed invoking
+ * your method call and the result is fully read. For oneway method calls,
+ * this method will be called as soon as we have completed writing out the
+ * request.
+ * @param response
+ */
+ public void onComplete(T response);
+
+ /**
+ * This method will be called when there is an unexpected clientside
+ * exception. This does not include application-defined exceptions that
+ * appear in the IDL, but rather things like IOExceptions.
+ * @param throwable
+ */
+ public void onError(Throwable throwable);
+}
diff --git a/lib/java/src/org/apache/thrift/async/TAsyncClient.java b/lib/java/src/org/apache/thrift/async/TAsyncClient.java
new file mode 100644
index 0000000..2e8dea3
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/async/TAsyncClient.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.thrift.async;
+
+import org.apache.thrift.protocol.TProtocolFactory;
+import org.apache.thrift.transport.TNonblockingTransport;
+
+public abstract class TAsyncClient {
+ protected final TProtocolFactory protocolFactory;
+ protected final TNonblockingTransport transport;
+ protected final TAsyncClientManager manager;
+ private TAsyncMethodCall currentMethod;
+ private Throwable error;
+
+ public TAsyncClient(TProtocolFactory protocolFactory, TAsyncClientManager manager, TNonblockingTransport transport) {
+ this.protocolFactory = protocolFactory;
+ this.manager = manager;
+ this.transport = transport;
+ }
+
+ public TProtocolFactory getProtocolFactory() {
+ return protocolFactory;
+ }
+
+ /**
+ * Is the client in an error state?
+ * @return
+ */
+ public boolean hasError() {
+ return error != null;
+ }
+
+ /**
+ * Get the client's error - returns null if no error
+ * @return
+ */
+ public Throwable getError() {
+ return error;
+ }
+
+ protected void checkReady() {
+ // Ensure we are not currently executing a method
+ if (currentMethod != null) {
+ throw new IllegalStateException("Client is currently executing another method: " + currentMethod.getClass().getName());
+ }
+
+ // Ensure we're not in an error state
+ if (error != null) {
+ throw new IllegalStateException("Client has an error!", error);
+ }
+ }
+
+ /**
+ * Called by delegate method when finished
+ */
+ protected void onComplete() {
+ currentMethod = null;
+ }
+
+ /**
+ * Called by delegate method on error
+ */
+ protected void onError(Throwable throwable) {
+ transport.close();
+ currentMethod = null;
+ error = throwable;
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/async/TAsyncClientFactory.java b/lib/java/src/org/apache/thrift/async/TAsyncClientFactory.java
new file mode 100644
index 0000000..28feb73
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/async/TAsyncClientFactory.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.thrift.async;
+
+import org.apache.thrift.transport.TNonblockingTransport;
+
+public interface TAsyncClientFactory<T extends TAsyncClient> {
+ public T getAsyncClient(TNonblockingTransport transport);
+}
diff --git a/lib/java/src/org/apache/thrift/async/TAsyncClientManager.java b/lib/java/src/org/apache/thrift/async/TAsyncClientManager.java
new file mode 100644
index 0000000..8636bc8
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/async/TAsyncClientManager.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.thrift.async;
+
+import java.io.IOException;
+import java.nio.channels.SelectionKey;
+import java.nio.channels.Selector;
+import java.nio.channels.spi.SelectorProvider;
+import java.util.Iterator;
+import java.util.concurrent.ConcurrentLinkedQueue;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Contains selector thread which transitions method call objects
+ */
+public class TAsyncClientManager {
+ private static final Logger LOGGER = LoggerFactory.getLogger(TAsyncClientManager.class.getName());
+
+ private final SelectThread selectThread;
+ private final ConcurrentLinkedQueue<TAsyncMethodCall> pendingCalls = new ConcurrentLinkedQueue<TAsyncMethodCall>();
+
+ public TAsyncClientManager() throws IOException {
+ this.selectThread = new SelectThread();
+ selectThread.start();
+ }
+
+ public void call(TAsyncMethodCall method) {
+ pendingCalls.add(method);
+ selectThread.getSelector().wakeup();
+ }
+
+ public void stop() {
+ selectThread.finish();
+ }
+
+ private class SelectThread extends Thread {
+ private final Selector selector;
+ private volatile boolean running;
+
+ public SelectThread() throws IOException {
+ this.selector = SelectorProvider.provider().openSelector();
+ this.running = true;
+ // We don't want to hold up the JVM when shutting down
+ setDaemon(true);
+ }
+
+ public Selector getSelector() {
+ return selector;
+ }
+
+ public void finish() {
+ running = false;
+ selector.wakeup();
+ }
+
+ public void run() {
+ while (running) {
+ try {
+ selector.select();
+ } catch (IOException e) {
+ LOGGER.error("Caught IOException in TAsyncClientManager!", e);
+ }
+
+ // Handle any ready channels calls
+ Iterator<SelectionKey> keys = selector.selectedKeys().iterator();
+ while (keys.hasNext()) {
+ SelectionKey key = keys.next();
+ keys.remove();
+ if (!key.isValid()) {
+ // this should only have happened if the method call experienced an
+ // error and the key was cancelled. just skip it.
+ continue;
+ }
+ TAsyncMethodCall method = (TAsyncMethodCall)key.attachment();
+ method.transition(key);
+ }
+
+ // Start any new calls
+ TAsyncMethodCall methodCall;
+ while ((methodCall = pendingCalls.poll()) != null) {
+ try {
+ SelectionKey key = methodCall.registerWithSelector(selector);
+ methodCall.transition(key);
+ } catch (IOException e) {
+ LOGGER.warn("Caught IOException in TAsyncClientManager!", e);
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/async/TAsyncMethodCall.java b/lib/java/src/org/apache/thrift/async/TAsyncMethodCall.java
new file mode 100644
index 0000000..e130087
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/async/TAsyncMethodCall.java
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.thrift.async;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.SelectionKey;
+import java.nio.channels.Selector;
+
+import org.apache.thrift.TException;
+import org.apache.thrift.protocol.TProtocol;
+import org.apache.thrift.protocol.TProtocolFactory;
+import org.apache.thrift.transport.TFramedTransport;
+import org.apache.thrift.transport.TMemoryBuffer;
+import org.apache.thrift.transport.TNonblockingTransport;
+import org.apache.thrift.transport.TTransportException;
+
+/**
+ * Encapsulates an async method call
+ * Need to generate:
+ * - private void write_args(TProtocol protocol)
+ * - public T getResult() throws <Exception_1>, <Exception_2>, ...
+ * @param <T>
+ */
+public abstract class TAsyncMethodCall<T extends TAsyncMethodCall> {
+ public static enum State {
+ WRITING_REQUEST_SIZE,
+ WRITING_REQUEST_BODY,
+ READING_RESPONSE_SIZE,
+ READING_RESPONSE_BODY,
+ RESPONSE_READ,
+ ERROR;
+ }
+
+ private static final int INITIAL_MEMORY_BUFFER_SIZE = 128;
+
+ protected final TNonblockingTransport transport;
+ private final TProtocolFactory protocolFactory;
+ protected final TAsyncClient client;
+ private final AsyncMethodCallback<T> callback;
+ private final boolean isOneway;
+
+ private ByteBuffer sizeBuffer;
+ private final byte[] sizeBufferArray = new byte[4];
+
+ private ByteBuffer frameBuffer;
+ private State state;
+
+ protected TAsyncMethodCall(TAsyncClient client, TProtocolFactory protocolFactory, TNonblockingTransport transport, AsyncMethodCallback<T> callback, boolean isOneway) throws TException {
+ this.transport = transport;
+ this.callback = callback;
+ this.protocolFactory = protocolFactory;
+ this.client = client;
+ this.isOneway = isOneway;
+
+ this.state = State.WRITING_REQUEST_SIZE;
+ prepareMethodCall();
+ }
+
+ protected State getState() {
+ return state;
+ }
+
+ protected abstract void write_args(TProtocol protocol) throws TException;
+
+ private void prepareMethodCall() throws TException {
+ TMemoryBuffer memoryBuffer = new TMemoryBuffer(INITIAL_MEMORY_BUFFER_SIZE);
+ TProtocol protocol = protocolFactory.getProtocol(memoryBuffer);
+ write_args(protocol);
+
+ int length = memoryBuffer.length();
+ frameBuffer = ByteBuffer.wrap(memoryBuffer.getArray(), 0, length);
+
+ TFramedTransport.encodeFrameSize(length, sizeBufferArray);
+ sizeBuffer = ByteBuffer.wrap(sizeBufferArray);
+ }
+
+ SelectionKey registerWithSelector(Selector sel) throws IOException {
+ SelectionKey key = transport.registerSelector(sel, SelectionKey.OP_WRITE);
+ key.attach(this);
+ return key;
+ }
+
+ protected ByteBuffer getFrameBuffer() {
+ return frameBuffer;
+ }
+
+ /**
+ * Transition to next state, doing whatever work is required. Since this
+ * method is only called by the selector thread, we can make changes to our
+ * select interests without worrying about concurrency.
+ * @param key
+ */
+ protected void transition(SelectionKey key) {
+ // Ensure key is valid
+ if (!key.isValid()) {
+ key.cancel();
+ Exception e = new TTransportException("Selection key not valid!");
+ client.onError(e);
+ callback.onError(e);
+ return;
+ }
+
+ // Transition function
+ try {
+ switch (state) {
+ case WRITING_REQUEST_SIZE:
+ doWritingRequestSize();
+ break;
+ case WRITING_REQUEST_BODY:
+ doWritingRequestBody(key);
+ break;
+ case READING_RESPONSE_SIZE:
+ doReadingResponseSize();
+ break;
+ case READING_RESPONSE_BODY:
+ doReadingResponseBody(key);
+ break;
+ case RESPONSE_READ:
+ case ERROR:
+ throw new IllegalStateException("Method call in state " + state
+ + " but selector called transition method. Seems like a bug...");
+ }
+ } catch (Throwable e) {
+ state = State.ERROR;
+ key.cancel();
+ key.attach(null);
+ client.onError(e);
+ callback.onError(e);
+ }
+ }
+
+ private void doReadingResponseBody(SelectionKey key) throws IOException {
+ if (transport.read(frameBuffer) < 0) {
+ throw new IOException("Read call frame failed");
+ }
+ if (frameBuffer.remaining() == 0) {
+ cleanUpAndFireCallback(key);
+ }
+ }
+
+ private void cleanUpAndFireCallback(SelectionKey key) {
+ state = State.RESPONSE_READ;
+ key.interestOps(0);
+ // this ensures that the TAsyncMethod instance doesn't hang around
+ key.attach(null);
+ key.cancel();
+ client.onComplete();
+ callback.onComplete((T)this);
+ }
+
+ private void doReadingResponseSize() throws IOException {
+ if (transport.read(sizeBuffer) < 0) {
+ throw new IOException("Read call frame size failed");
+ }
+ if (sizeBuffer.remaining() == 0) {
+ state = State.READING_RESPONSE_BODY;
+ frameBuffer = ByteBuffer.allocate(TFramedTransport.decodeFrameSize(sizeBufferArray));
+ }
+ }
+
+ private void doWritingRequestBody(SelectionKey key) throws IOException {
+ if (transport.write(frameBuffer) < 0) {
+ throw new IOException("Write call frame failed");
+ }
+ if (frameBuffer.remaining() == 0) {
+ if (isOneway) {
+ cleanUpAndFireCallback(key);
+ } else {
+ state = State.READING_RESPONSE_SIZE;
+ sizeBuffer.rewind(); // Prepare to read incoming frame size
+ key.interestOps(SelectionKey.OP_READ);
+ }
+ }
+ }
+
+ private void doWritingRequestSize() throws IOException {
+ if (transport.write(sizeBuffer) < 0) {
+ throw new IOException("Write call frame size failed");
+ }
+ if (sizeBuffer.remaining() == 0) {
+ state = State.WRITING_REQUEST_BODY;
+ }
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/TFramedTransport.java b/lib/java/src/org/apache/thrift/transport/TFramedTransport.java
index fab9c9b..32483ee 100644
--- a/lib/java/src/org/apache/thrift/transport/TFramedTransport.java
+++ b/lib/java/src/org/apache/thrift/transport/TFramedTransport.java
@@ -23,7 +23,7 @@
/**
* TFramedTransport is a buffered TTransport that ensures a fully read message
- * every time by preceeding messages with a 4-byte frame size.
+ * every time by preceeding messages with a 4-byte frame size.
*/
public class TFramedTransport extends TTransport {
@@ -58,6 +58,7 @@
maxLength_ = maxLength;
}
+ @Override
public TTransport getTransport(TTransport base) {
return new TFramedTransport(base, maxLength_);
}
@@ -122,14 +123,11 @@
readBuffer_.consumeBuffer(len);
}
- private final byte[] i32rd = new byte[4];
+ private final byte[] i32buf = new byte[4];
+
private void readFrame() throws TTransportException {
- transport_.readAll(i32rd, 0, 4);
- int size =
- ((i32rd[0] & 0xff) << 24) |
- ((i32rd[1] & 0xff) << 16) |
- ((i32rd[2] & 0xff) << 8) |
- ((i32rd[3] & 0xff));
+ transport_.readAll(i32buf, 0, 4);
+ int size = decodeFrameSize(i32buf);
if (size < 0) {
throw new TTransportException("Read a negative frame size (" + size + ")!");
@@ -148,18 +146,30 @@
writeBuffer_.write(buf, off, len);
}
+ @Override
public void flush() throws TTransportException {
byte[] buf = writeBuffer_.get();
int len = writeBuffer_.len();
writeBuffer_.reset();
- byte[] i32out = new byte[4];
- i32out[0] = (byte)(0xff & (len >> 24));
- i32out[1] = (byte)(0xff & (len >> 16));
- i32out[2] = (byte)(0xff & (len >> 8));
- i32out[3] = (byte)(0xff & (len));
- transport_.write(i32out, 0, 4);
+ encodeFrameSize(len, i32buf);
+ transport_.write(i32buf, 0, 4);
transport_.write(buf, 0, len);
transport_.flush();
}
+
+ public static final void encodeFrameSize(final int frameSize, final byte[] buf) {
+ buf[0] = (byte)(0xff & (frameSize >> 24));
+ buf[1] = (byte)(0xff & (frameSize >> 16));
+ buf[2] = (byte)(0xff & (frameSize >> 8));
+ buf[3] = (byte)(0xff & (frameSize));
+ }
+
+ public static final int decodeFrameSize(final byte[] buf) {
+ return
+ ((buf[0] & 0xff) << 24) |
+ ((buf[1] & 0xff) << 16) |
+ ((buf[2] & 0xff) << 8) |
+ ((buf[3] & 0xff));
+ }
}
diff --git a/lib/java/src/org/apache/thrift/transport/TMemoryBuffer.java b/lib/java/src/org/apache/thrift/transport/TMemoryBuffer.java
index 886fcbf..9b906db 100644
--- a/lib/java/src/org/apache/thrift/transport/TMemoryBuffer.java
+++ b/lib/java/src/org/apache/thrift/transport/TMemoryBuffer.java
@@ -24,12 +24,12 @@
/**
* Memory buffer-based implementation of the TTransport interface.
- *
*/
public class TMemoryBuffer extends TTransport {
-
/**
- *
+ * Create a TMemoryBuffer with an initial buffer size of <i>size</i>. The
+ * internal buffer will grow as necessary to accomodate the size of the data
+ * being written to it.
*/
public TMemoryBuffer(int size) {
arr_ = new TByteArrayOutputStream(size);
@@ -90,9 +90,13 @@
// Position to read next byte from
private int pos_;
-
+
public int length() {
return arr_.size();
}
+
+ public byte[] getArray() {
+ return arr_.get();
+ }
}
diff --git a/lib/java/src/org/apache/thrift/transport/TNonblockingSocket.java b/lib/java/src/org/apache/thrift/transport/TNonblockingSocket.java
index bc2d539..313ef85 100644
--- a/lib/java/src/org/apache/thrift/transport/TNonblockingSocket.java
+++ b/lib/java/src/org/apache/thrift/transport/TNonblockingSocket.java
@@ -21,6 +21,7 @@
package org.apache.thrift.transport;
import java.io.IOException;
+import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketException;
import java.nio.ByteBuffer;
@@ -41,21 +42,22 @@
private Socket socket_ = null;
/**
- * Remote host
- */
- private String host_ = null;
-
- /**
- * Remote port
- */
- private int port_ = 0;
-
- /**
* Socket timeout
*/
private int timeout_ = 0;
/**
+ * Create a new nonblocking socket transport connected to host:port.
+ * @param host
+ * @param port
+ * @throws TTransportException
+ * @throws IOException
+ */
+ public TNonblockingSocket(String host, int port) throws TTransportException, IOException {
+ this(SocketChannel.open(new InetSocketAddress(host, port)));
+ }
+
+ /**
* Constructor that takes an already created socket.
*
* @param socketChannel Already created SocketChannel object
diff --git a/lib/java/test/org/apache/thrift/async/TestTAsyncClientManager.java b/lib/java/test/org/apache/thrift/async/TestTAsyncClientManager.java
new file mode 100644
index 0000000..5c8ff76
--- /dev/null
+++ b/lib/java/test/org/apache/thrift/async/TestTAsyncClientManager.java
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.thrift.async;
+
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import junit.framework.TestCase;
+
+import org.apache.thrift.TException;
+import org.apache.thrift.protocol.TBinaryProtocol;
+import org.apache.thrift.server.TNonblockingServer;
+import org.apache.thrift.transport.TNonblockingServerSocket;
+import org.apache.thrift.transport.TNonblockingSocket;
+
+import thrift.test.CompactProtoTestStruct;
+import thrift.test.Srv;
+import thrift.test.Srv.Iface;
+import thrift.test.Srv.AsyncClient.Janky_call;
+import thrift.test.Srv.AsyncClient.onewayMethod_call;
+import thrift.test.Srv.AsyncClient.voidMethod_call;
+
+public class TestTAsyncClientManager extends TestCase {
+ private static abstract class FailureLessCallback<T extends TAsyncMethodCall> implements AsyncMethodCallback<T> {
+ @Override
+ public void onError(Throwable throwable) {
+ throwable.printStackTrace();
+ fail("unexpected error " + throwable);
+ }
+ }
+
+ public class SrvHandler implements Iface {
+ @Override
+ public int Janky(int arg) throws TException {
+ return 0;
+ }
+
+ @Override
+ public void methodWithDefaultArgs(int something) throws TException {
+ }
+
+ @Override
+ public int primitiveMethod() throws TException {
+ return 0;
+ }
+
+ @Override
+ public CompactProtoTestStruct structMethod() throws TException {
+ return null;
+ }
+
+ @Override
+ public void voidMethod() throws TException {
+ }
+
+ @Override
+ public void onewayMethod() throws TException {
+ }
+ }
+
+ public void testIt() throws Exception {
+ // put up a server
+ final TNonblockingServer s = new TNonblockingServer(new Srv.Processor(new SrvHandler()), new TNonblockingServerSocket(12345));
+ new Thread(new Runnable() {
+ @Override
+ public void run() {
+ s.serve();
+ }
+ }).start();
+ Thread.sleep(1000);
+
+ // set up async client manager
+ TAsyncClientManager acm = new TAsyncClientManager();
+
+ // connect an async client
+ TNonblockingSocket clientSock = new TNonblockingSocket("localhost", 12345);
+ Srv.AsyncClient client = new Srv.AsyncClient(new TBinaryProtocol.Factory(), acm, clientSock);
+
+ final Object o = new Object();
+
+ // make a standard method call
+ final AtomicBoolean jankyReturned = new AtomicBoolean(false);
+ client.Janky(1, new FailureLessCallback<Srv.AsyncClient.Janky_call>() {
+ @Override
+ public void onComplete(Janky_call response) {
+ try {
+ assertEquals(0, response.getResult());
+ jankyReturned.set(true);
+ } catch (TException e) {
+ fail("unexpected exception: " + e);
+ }
+ synchronized(o) {
+ o.notifyAll();
+ }
+ }
+ });
+
+ synchronized(o) {
+ o.wait(100000);
+ }
+ assertTrue(jankyReturned.get());
+
+ // make a void method call
+ final AtomicBoolean voidMethodReturned = new AtomicBoolean(false);
+ client.voidMethod(new FailureLessCallback<Srv.AsyncClient.voidMethod_call>() {
+ @Override
+ public void onComplete(voidMethod_call response) {
+ try {
+ response.getResult();
+ voidMethodReturned.set(true);
+ } catch (TException e) {
+ fail("unexpected exception " + e);
+ }
+ synchronized (o) {
+ o.notifyAll();
+ }
+ }
+ });
+
+ synchronized(o) {
+ o.wait(1000);
+ }
+ assertTrue(voidMethodReturned.get());
+
+ // make a oneway method call
+ final AtomicBoolean onewayReturned = new AtomicBoolean(false);
+ client.onewayMethod(new FailureLessCallback<onewayMethod_call>() {
+ @Override
+ public void onComplete(onewayMethod_call response) {
+ try {
+ response.getResult();
+ onewayReturned.set(true);
+ } catch (TException e) {
+ fail("unexpected exception " + e);
+ }
+ synchronized(o) {
+ o.notifyAll();
+ }
+ }
+ });
+ synchronized(o) {
+ o.wait(1000);
+ }
+
+ assertTrue(onewayReturned.get());
+
+ // make another standard method call
+ final AtomicBoolean voidAfterOnewayReturned = new AtomicBoolean(false);
+ client.voidMethod(new FailureLessCallback<voidMethod_call>() {
+ @Override
+ public void onComplete(voidMethod_call response) {
+ try {
+ response.getResult();
+ voidAfterOnewayReturned.set(true);
+ } catch (TException e) {
+ fail("unexpected exception " + e);
+ }
+ synchronized(o) {
+ o.notifyAll();
+ }
+ }
+ });
+ synchronized(o) {
+ o.wait(1000);
+ }
+
+ assertTrue(voidAfterOnewayReturned.get());
+ }
+}
diff --git a/lib/java/test/org/apache/thrift/protocol/ProtocolTestBase.java b/lib/java/test/org/apache/thrift/protocol/ProtocolTestBase.java
index 365cef7..da0de05 100644
--- a/lib/java/test/org/apache/thrift/protocol/ProtocolTestBase.java
+++ b/lib/java/test/org/apache/thrift/protocol/ProtocolTestBase.java
@@ -305,6 +305,10 @@
public void methodWithDefaultArgs(int something) throws TException {
}
+
+ @Override
+ public void onewayMethod() throws TException {
+ }
};
Srv.Processor testProcessor = new Srv.Processor(handler);
diff --git a/test/DebugProtoTest.thrift b/test/DebugProtoTest.thrift
index dbce93e..5e361d2 100644
--- a/test/DebugProtoTest.thrift
+++ b/test/DebugProtoTest.thrift
@@ -228,14 +228,16 @@
service Srv {
i32 Janky(1: i32 arg);
-
+
// return type only methods
-
+
void voidMethod();
i32 primitiveMethod();
CompactProtoTestStruct structMethod();
-
+
void methodWithDefaultArgs(1: i32 something = MYCONST);
+
+ oneway void onewayMethod();
}
service Inherited extends Srv {