THRIFT-447. java: Make an abstract base Client class so we can generate less code

This patch introduces a handful of abstract, non-generated classes that allow us to generate much less code for service implementations.

git-svn-id: https://svn.apache.org/repos/asf/thrift/trunk@1068487 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/compiler/cpp/src/generate/t_java_generator.cc b/compiler/cpp/src/generate/t_java_generator.cc
index b56d720..e84fd1d 100644
--- a/compiler/cpp/src/generate/t_java_generator.cc
+++ b/compiler/cpp/src/generate/t_java_generator.cc
@@ -2267,13 +2267,15 @@
 void t_java_generator::generate_service_client(t_service* tservice) {
   string extends = "";
   string extends_client = "";
-  if (tservice->get_extends() != NULL) {
+  if (tservice->get_extends() == NULL) {
+    extends_client = "org.apache.thrift.TServiceClient";
+  } else {
     extends = type_name(tservice->get_extends());
-    extends_client = " extends " + extends + ".Client";
+    extends_client = extends + ".Client";
   }
 
   indent(f_service_) <<
-    "public static class Client" << extends_client << " implements org.apache.thrift.TServiceClient, Iface {" << endl;
+    "public static class Client extends " << extends_client << " implements Iface {" << endl;
   indent_up();
 
   indent(f_service_) << "public static class Factory implements org.apache.thrift.TServiceClientFactory<Client> {" << endl;
@@ -2296,49 +2298,14 @@
     "public Client(org.apache.thrift.protocol.TProtocol prot)" << endl;
   scope_up(f_service_);
   indent(f_service_) <<
-    "this(prot, prot);" << endl;
+    "super(prot, prot);" << endl;
   scope_down(f_service_);
   f_service_ << endl;
 
   indent(f_service_) <<
-    "public Client(org.apache.thrift.protocol.TProtocol iprot, org.apache.thrift.protocol.TProtocol oprot)" << endl;
-  scope_up(f_service_);
-  if (extends.empty()) {
-    f_service_ <<
-      indent() << "iprot_ = iprot;" << endl <<
-      indent() << "oprot_ = oprot;" << endl;
-  } else {
-    f_service_ <<
-      indent() << "super(iprot, oprot);" << endl;
-  }
-  scope_down(f_service_);
-  f_service_ << endl;
-
-  if (extends.empty()) {
-    f_service_ <<
-      indent() << "protected org.apache.thrift.protocol.TProtocol iprot_;"  << endl <<
-      indent() << "protected org.apache.thrift.protocol.TProtocol oprot_;"  << endl <<
-      endl <<
-      indent() << "protected int seqid_;" << endl <<
-      endl;
-
-    indent(f_service_) <<
-      "public org.apache.thrift.protocol.TProtocol getInputProtocol()" << endl;
-    scope_up(f_service_);
-    indent(f_service_) <<
-      "return this.iprot_;" << endl;
-    scope_down(f_service_);
-    f_service_ << endl;
-
-    indent(f_service_) <<
-      "public org.apache.thrift.protocol.TProtocol getOutputProtocol()" << endl;
-    scope_up(f_service_);
-    indent(f_service_) <<
-      "return this.oprot_;" << endl;
-    scope_down(f_service_);
-    f_service_ << endl;
-
-  }
+    "public Client(org.apache.thrift.protocol.TProtocol iprot, org.apache.thrift.protocol.TProtocol oprot) {" << endl;
+  indent(f_service_) << "  super(iprot, oprot);" << endl;
+  indent(f_service_) << "}" << endl << endl;
 
   // Generate client method implementations
   vector<t_function*> functions = tservice->get_functions();
@@ -2393,19 +2360,14 @@
     scope_up(f_service_);
 
     // Serialize the request
-    f_service_ <<
-      indent() << "oprot_.writeMessageBegin(new org.apache.thrift.protocol.TMessage(\"" << funname << "\", org.apache.thrift.protocol.TMessageType.CALL, ++seqid_));" << endl <<
-      indent() << argsname << " args = new " << argsname << "();" << endl;
+    indent(f_service_) << argsname << " args = new " << argsname << "();" << endl;
 
     for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) {
       f_service_ <<
         indent() << "args.set" << get_cap_name((*fld_iter)->get_name()) << "(" << (*fld_iter)->get_name() << ");" << endl;
     }
 
-    f_service_ <<
-      indent() << "args.write(oprot_);" << endl <<
-      indent() << "oprot_.writeMessageEnd();" << endl <<
-      indent() << "oprot_.getTransport().flush();" << endl;
+    indent(f_service_) << "sendBase(\"" << funname << "\", args);" << endl;
 
     scope_down(f_service_);
     f_service_ << endl;
@@ -2424,18 +2386,8 @@
       scope_up(f_service_);
 
       f_service_ <<
-        indent() << "org.apache.thrift.protocol.TMessage msg = iprot_.readMessageBegin();" << endl <<
-        indent() << "if (msg.type == org.apache.thrift.protocol.TMessageType.EXCEPTION) {" << endl <<
-        indent() << "  org.apache.thrift.TApplicationException x = org.apache.thrift.TApplicationException.read(iprot_);" << endl <<
-        indent() << "  iprot_.readMessageEnd();" << endl <<
-        indent() << "  throw x;" << endl <<
-        indent() << "}" << endl <<
-        indent() << "if (msg.seqid != seqid_) {" << endl <<
-        indent() << "  throw new org.apache.thrift.TApplicationException(org.apache.thrift.TApplicationException.BAD_SEQUENCE_ID, \"" << (*f_iter)->get_name() << " failed: out of sequence response\");" << endl <<
-        indent() << "}" << endl <<
         indent() << resultname << " result = new " << resultname << "();" << endl <<
-        indent() << "result.read(iprot_);" << endl <<
-        indent() << "iprot_.readMessageEnd();" << endl;
+        indent() << "receiveBase(result, \"" << funname << "\");" << endl;
 
       // Careful, only return _result if not a void function
       if (!(*f_iter)->get_returntype()->is_void()) {
@@ -2620,83 +2572,36 @@
   // Extends stuff
   string extends = "";
   string extends_processor = "";
-  if (tservice->get_extends() != NULL) {
+  if (tservice->get_extends() == NULL) {
+    extends_processor = "org.apache.thrift.TBaseProcessor";
+  } else {
     extends = type_name(tservice->get_extends());
-    extends_processor = " extends " + extends + ".Processor";
+    extends_processor = extends + ".Processor";
   }
 
   // Generate the header portion
   indent(f_service_) <<
-    "public static class Processor" << extends_processor << " implements org.apache.thrift.TProcessor {" << endl;
+    "public static class Processor<I extends Iface> extends " << extends_processor << " implements org.apache.thrift.TProcessor {" << endl;
   indent_up();
 
   indent(f_service_) << "private static final Logger LOGGER = LoggerFactory.getLogger(Processor.class.getName());" << endl;
 
-  indent(f_service_) <<
-    "public Processor(Iface iface)" << endl;
-  scope_up(f_service_);
-  if (!extends.empty()) {
-    f_service_ <<
-      indent() << "super(iface);" << endl;
-  }
-  f_service_ <<
-    indent() << "iface_ = iface;" << endl;
+  indent(f_service_) << "public Processor(I iface) {" << endl;
+  indent(f_service_) << "  super(iface, getProcessMap(new HashMap<String, org.apache.thrift.ProcessFunction<I, ? extends org.apache.thrift.TBase>>()));" << endl;
+  indent(f_service_) << "}" << endl << endl;
 
+  indent(f_service_) << "protected Processor(I iface, Map<String,  org.apache.thrift.ProcessFunction<I, ? extends  org.apache.thrift.TBase>> processMap) {" << endl;
+  indent(f_service_) << "  super(iface, getProcessMap(processMap));" << endl;
+  indent(f_service_) << "}" << endl << endl;
+
+  indent(f_service_) << "private static <I extends Iface> Map<String,  org.apache.thrift.ProcessFunction<I, ? extends  org.apache.thrift.TBase>> getProcessMap(Map<String,  org.apache.thrift.ProcessFunction<I, ? extends  org.apache.thrift.TBase>> processMap) {" << endl;
+  indent_up();
   for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
-    f_service_ <<
-      indent() << "processMap_.put(\"" << (*f_iter)->get_name() << "\", new " << (*f_iter)->get_name() << "());" << endl;
+    indent(f_service_) << "processMap.put(\"" << (*f_iter)->get_name() << "\", new " << (*f_iter)->get_name() << "());" << endl;
   }
-
-  scope_down(f_service_);
-  f_service_ << endl;
-
-  if (extends.empty()) {
-    f_service_ <<
-      indent() << "protected static interface ProcessFunction {" << endl <<
-      indent() << "  public void process(int seqid, org.apache.thrift.protocol.TProtocol iprot, org.apache.thrift.protocol.TProtocol oprot) throws org.apache.thrift.TException;" << endl <<
-      indent() << "}" << endl <<
-      endl;
-  }
-
-  f_service_ <<
-    indent() << "private Iface iface_;" << endl;
-
-  if (extends.empty()) {
-    f_service_ <<
-      indent() << "protected final HashMap<String,ProcessFunction> processMap_ = new HashMap<String,ProcessFunction>();" << endl;
-  }
-
-  f_service_ << endl;
-
-  // Generate the server implementation
-  indent(f_service_) <<
-    "public boolean process(org.apache.thrift.protocol.TProtocol iprot, org.apache.thrift.protocol.TProtocol oprot) throws org.apache.thrift.TException" << endl;
-  scope_up(f_service_);
-
-  f_service_ <<
-    indent() << "org.apache.thrift.protocol.TMessage msg = iprot.readMessageBegin();" << endl;
-
-  // TODO(mcslee): validate message, was the seqid etc. legit?
-
-  f_service_ <<
-    indent() << "ProcessFunction fn = processMap_.get(msg.name);" << endl <<
-    indent() << "if (fn == null) {" << endl <<
-    indent() << "  org.apache.thrift.protocol.TProtocolUtil.skip(iprot, org.apache.thrift.protocol.TType.STRUCT);" << endl <<
-    indent() << "  iprot.readMessageEnd();" << endl <<
-    indent() << "  org.apache.thrift.TApplicationException x = new org.apache.thrift.TApplicationException(org.apache.thrift.TApplicationException.UNKNOWN_METHOD, \"Invalid method name: '\"+msg.name+\"'\");" << endl <<
-    indent() << "  oprot.writeMessageBegin(new org.apache.thrift.protocol.TMessage(msg.name, org.apache.thrift.protocol.TMessageType.EXCEPTION, msg.seqid));" << endl <<
-    indent() << "  x.write(oprot);" << endl <<
-    indent() << "  oprot.writeMessageEnd();" << endl <<
-    indent() << "  oprot.getTransport().flush();" << endl <<
-    indent() << "  return true;" << endl <<
-    indent() << "}" << endl <<
-    indent() << "fn.process(msg.seqid, iprot, oprot);" << endl;
-
-  f_service_ <<
-    indent() << "return true;" << endl;
-
-  scope_down(f_service_);
-  f_service_ << endl;
+  indent(f_service_) << "return processMap;" << endl;
+  indent_down();
+  indent(f_service_) << "}" << endl << endl;
 
   // Generate the process subfunctions
   for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
@@ -2704,9 +2609,7 @@
   }
 
   indent_down();
-  indent(f_service_) <<
-    "}" << endl <<
-    endl;
+  indent(f_service_) << "}" << endl << endl;
 }
 
 /**
@@ -2742,53 +2645,36 @@
  */
 void t_java_generator::generate_process_function(t_service* tservice,
                                                  t_function* tfunction) {
+  string argsname = tfunction->get_name() + "_args";
+  string resultname = tfunction->get_name() + "_result";
+  if (tfunction->is_oneway()) {
+    resultname = "org.apache.thrift.TBase";
+  }
+
   (void) tservice;
   // Open class
   indent(f_service_) <<
-    "private class " << tfunction->get_name() << " implements ProcessFunction {" << endl;
+    "private static class " << tfunction->get_name() << "<I extends Iface> extends org.apache.thrift.ProcessFunction<I, " << argsname << "> {" << endl;
   indent_up();
 
-  // Open function
-  indent(f_service_) <<
-    "public void process(int seqid, org.apache.thrift.protocol.TProtocol iprot, org.apache.thrift.protocol.TProtocol oprot) throws org.apache.thrift.TException" << endl;
-  scope_up(f_service_);
+  indent(f_service_) << "public " << tfunction->get_name() << "() {" << endl;
+  indent(f_service_) << "  super(\"" << tfunction->get_name() << "\");" << endl;
+  indent(f_service_) << "}" << endl << endl;
 
-  string argsname = tfunction->get_name() + "_args";
-  string resultname = tfunction->get_name() + "_result";
+  indent(f_service_) << "protected " << argsname << " getEmptyArgsInstance() {" << endl;
+  indent(f_service_) << "  return new " << argsname << "();" << endl;
+  indent(f_service_) << "}" << endl << endl;
 
-  f_service_ <<
-    indent() << argsname << " args = new " << argsname << "();" << endl <<
-    indent() << "try {" << endl;
+  indent(f_service_) << "protected " << resultname << " getResult(I iface, " << argsname << " args) throws org.apache.thrift.TException {" << endl;
   indent_up();
-  f_service_ <<
-    indent() << "args.read(iprot);" << endl;
-  indent_down();
-  f_service_ << 
-    indent() << "} catch (org.apache.thrift.protocol.TProtocolException e) {" << endl;
-  indent_up();
-  f_service_ <<
-    indent() << "iprot.readMessageEnd();" << endl <<
-    indent() << "org.apache.thrift.TApplicationException x = new org.apache.thrift.TApplicationException(org.apache.thrift.TApplicationException.PROTOCOL_ERROR, e.getMessage());" << endl <<
-    indent() << "oprot.writeMessageBegin(new org.apache.thrift.protocol.TMessage(\"" << tfunction->get_name() << "\", org.apache.thrift.protocol.TMessageType.EXCEPTION, seqid));" << endl <<
-    indent() << "x.write(oprot);" << endl <<
-    indent() << "oprot.writeMessageEnd();" << endl <<
-    indent() << "oprot.getTransport().flush();" << endl <<
-    indent() << "return;" << endl;
-  indent_down();
-  f_service_ << indent() << "}" << endl;
-  f_service_ <<
-    indent() << "iprot.readMessageEnd();" << endl;
+  if (!tfunction->is_oneway()) {
+    indent(f_service_) << resultname << " result = new " << resultname << "();" << endl;
+  }
 
   t_struct* xs = tfunction->get_xceptions();
   const std::vector<t_field*>& xceptions = xs->get_members();
   vector<t_field*>::const_iterator x_iter;
 
-  // Declare result for non oneway function
-  if (!tfunction->is_oneway()) {
-    f_service_ <<
-      indent() << resultname << " result = new " << resultname << "();" << endl;
-  }
-
   // Try block for a function with exceptions
   if (xceptions.size() > 0) {
     f_service_ <<
@@ -2800,13 +2686,13 @@
   t_struct* arg_struct = tfunction->get_arglist();
   const std::vector<t_field*>& fields = arg_struct->get_members();
   vector<t_field*>::const_iterator f_iter;
-
   f_service_ << indent();
+
   if (!tfunction->is_oneway() && !tfunction->get_returntype()->is_void()) {
     f_service_ << "result.success = ";
   }
   f_service_ <<
-    "iface_." << tfunction->get_name() << "(";
+    "iface." << tfunction->get_name() << "(";
   bool first = true;
   for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
     if (first) {
@@ -2839,42 +2725,18 @@
         f_service_ << "}";
       }
     }
-    f_service_ << " catch (Throwable th) {" << endl;
-    indent_up();
-    f_service_ <<
-      indent() << "LOGGER.error(\"Internal error processing " << tfunction->get_name() << "\", th);" << endl <<
-      indent() << "org.apache.thrift.TApplicationException x = new org.apache.thrift.TApplicationException(org.apache.thrift.TApplicationException.INTERNAL_ERROR, \"Internal error processing " << tfunction->get_name() << "\");" << endl <<
-      indent() << "oprot.writeMessageBegin(new org.apache.thrift.protocol.TMessage(\"" << tfunction->get_name() << "\", org.apache.thrift.protocol.TMessageType.EXCEPTION, seqid));" << endl <<
-      indent() << "x.write(oprot);" << endl <<
-      indent() << "oprot.writeMessageEnd();" << endl <<
-      indent() << "oprot.getTransport().flush();" << endl <<
-      indent() << "return;" << endl;
-    indent_down();
-    f_service_ << indent() << "}" << endl;
+    f_service_ << endl;
   }
 
-  // Shortcut out here for oneway functions
   if (tfunction->is_oneway()) {
-    f_service_ <<
-      indent() << "return;" << endl;
-    scope_down(f_service_);
-
-    // Close class
-    indent_down();
-    f_service_ <<
-      indent() << "}" << endl <<
-      endl;
-    return;
+    indent(f_service_) << "return null;" << endl;
+  } else {
+    indent(f_service_) << "return result;" << endl;
   }
-
-  f_service_ <<
-    indent() << "oprot.writeMessageBegin(new org.apache.thrift.protocol.TMessage(\"" << tfunction->get_name() << "\", org.apache.thrift.protocol.TMessageType.REPLY, seqid));" << endl <<
-    indent() << "result.write(oprot);" << endl <<
-    indent() << "oprot.writeMessageEnd();" << endl <<
-    indent() << "oprot.getTransport().flush();" << endl;
+  indent_down();
+  indent(f_service_) << "}";
 
   // Close function
-  scope_down(f_service_);
   f_service_ << endl;
 
   // Close class
diff --git a/lib/java/src/org/apache/thrift/ProcessFunction.java b/lib/java/src/org/apache/thrift/ProcessFunction.java
new file mode 100644
index 0000000..e0cdc7b
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/ProcessFunction.java
@@ -0,0 +1,46 @@
+/**
+ * 
+ */
+package org.apache.thrift;
+
+import org.apache.thrift.protocol.TMessage;
+import org.apache.thrift.protocol.TMessageType;
+import org.apache.thrift.protocol.TProtocol;
+import org.apache.thrift.protocol.TProtocolException;
+
+public abstract class ProcessFunction<I, T extends TBase> {
+  private final String methodName;
+
+  public ProcessFunction(String methodName) {
+    this.methodName = methodName;
+  }
+
+  public final void process(int seqid, TProtocol iprot, TProtocol oprot, I iface) throws TException {
+    T args = getEmptyArgsInstance();
+    try {
+      args.read(iprot);
+    } catch (TProtocolException e) {
+      iprot.readMessageEnd();
+      TApplicationException x = new TApplicationException(TApplicationException.PROTOCOL_ERROR, e.getMessage());
+      oprot.writeMessageBegin(new TMessage(getMethodName(), TMessageType.EXCEPTION, seqid));
+      x.write(oprot);
+      oprot.writeMessageEnd();
+      oprot.getTransport().flush();
+      return;
+    }
+    iprot.readMessageEnd();
+    TBase result = getResult(iface, args);
+    oprot.writeMessageBegin(new TMessage(getMethodName(), TMessageType.REPLY, seqid));
+    result.write(oprot);
+    oprot.writeMessageEnd();
+    oprot.getTransport().flush();
+  }
+
+  protected abstract TBase getResult(I iface, T args) throws TException;
+
+  protected abstract T getEmptyArgsInstance();
+
+  public String getMethodName() {
+    return methodName;
+  }
+}
\ No newline at end of file
diff --git a/lib/java/src/org/apache/thrift/TBaseProcessor.java b/lib/java/src/org/apache/thrift/TBaseProcessor.java
new file mode 100644
index 0000000..f93b133
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/TBaseProcessor.java
@@ -0,0 +1,37 @@
+package org.apache.thrift;
+
+import java.util.Map;
+
+import org.apache.thrift.protocol.TMessage;
+import org.apache.thrift.protocol.TMessageType;
+import org.apache.thrift.protocol.TProtocol;
+import org.apache.thrift.protocol.TProtocolUtil;
+import org.apache.thrift.protocol.TType;
+
+public abstract class TBaseProcessor<I> implements TProcessor {
+  private final I iface;
+  private final Map<String,ProcessFunction<I, ? extends TBase>> processMap;
+
+  protected TBaseProcessor(I iface, Map<String, ProcessFunction<I, ? extends TBase>> processFunctionMap) {
+    this.iface = iface;
+    this.processMap = processFunctionMap;
+  }
+
+  @Override
+  public boolean process(TProtocol in, TProtocol out) throws TException {
+    TMessage msg = in.readMessageBegin();
+    ProcessFunction fn = processMap.get(msg.name);
+    if (fn == null) {
+      TProtocolUtil.skip(in, TType.STRUCT);
+      in.readMessageEnd();
+      TApplicationException x = new TApplicationException(TApplicationException.UNKNOWN_METHOD, "Invalid method name: '"+msg.name+"'");
+      out.writeMessageBegin(new TMessage(msg.name, TMessageType.EXCEPTION, msg.seqid));
+      x.write(out);
+      out.writeMessageEnd();
+      out.getTransport().flush();
+      return true;
+    }
+    fn.process(msg.seqid, in, out, iface);
+    return true;
+  }
+}
diff --git a/lib/java/src/org/apache/thrift/TServiceClient.java b/lib/java/src/org/apache/thrift/TServiceClient.java
index ee07b78..c70e66f 100644
--- a/lib/java/src/org/apache/thrift/TServiceClient.java
+++ b/lib/java/src/org/apache/thrift/TServiceClient.java
@@ -19,21 +19,63 @@
 
 package org.apache.thrift;
 
+import org.apache.thrift.protocol.TMessage;
+import org.apache.thrift.protocol.TMessageType;
 import org.apache.thrift.protocol.TProtocol;
 
 /**
  * A TServiceClient is used to communicate with a TService implementation
  * across protocols and transports.
  */
-public interface TServiceClient {
+public abstract class TServiceClient {
+  public TServiceClient(TProtocol prot) {
+    this(prot, prot);
+  }
+
+  public TServiceClient(TProtocol iprot, TProtocol oprot) {
+    iprot_ = iprot;
+    oprot_ = oprot;
+  }
+
+  protected TProtocol iprot_;
+  protected TProtocol oprot_;
+
+  protected int seqid_;
+
   /**
    * Get the TProtocol being used as the input (read) protocol.
    * @return
    */
-  public TProtocol getInputProtocol();
+  public TProtocol getInputProtocol() {
+    return this.iprot_;
+  }
+
   /**
    * Get the TProtocol being used as the output (write) protocol.
    * @return
    */
-  public TProtocol getOutputProtocol();
+  public TProtocol getOutputProtocol() {
+    return this.oprot_;
+  }
+
+  protected void sendBase(String methodName, TBase args) throws TException {
+    oprot_.writeMessageBegin(new TMessage(methodName, TMessageType.CALL, ++seqid_));
+    args.write(oprot_);
+    oprot_.writeMessageEnd();
+    oprot_.getTransport().flush();
+  }
+
+  protected void receiveBase(TBase result, String methodName) throws TException {
+    TMessage msg = iprot_.readMessageBegin();
+    if (msg.type == TMessageType.EXCEPTION) {
+      TApplicationException x = TApplicationException.read(iprot_);
+      iprot_.readMessageEnd();
+      throw x;
+    }
+    if (msg.seqid != seqid_) {
+      throw new TApplicationException(TApplicationException.BAD_SEQUENCE_ID, methodName + " failed: out of sequence response");
+    }
+    result.read(iprot_);
+    iprot_.readMessageEnd();
+  }
 }