Modified the C++ code generator to build read/write methods for each non-primitive type rather than inlining all serialization in the client/server function handlers.

Modified the parser to assign negative numbers to autogenerated struct fields and function args.
			       


git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@664745 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/compiler/src/cpp_generator.py b/compiler/src/cpp_generator.py
index f340ee8..922df53 100644
--- a/compiler/src/cpp_generator.py
+++ b/compiler/src/cpp_generator.py
@@ -29,11 +29,18 @@
 #define ${source}_h_ 1
 
 #include <thrift/Thrift.h>
+#include \"${source}_types.h\"
 """)
 
 CPP_SERVICES_FOOTER = Template("""
 #endif // !defined(${source}_h_)""")
 
+CPP_IMPL_HEADER = Template(HEADER_COMMENT+"""
+#include \"${source}.h\"
+""")
+
+CPP_IMPL_FOOTER = Template("")
+
 def cpp_debug(arg):
     print(arg)
 
@@ -241,6 +248,41 @@
 CPP_PROTOCOL_TSTOP = CPP_PROTOCOL_NS+"::T_STOP"
 CPP_PROTOCOL_TTYPE = CPP_PROTOCOL_NS+"::TType"
 
+CPP_TTYPE_MAP = {
+    STOP_TYPE : CPP_PROTOCOL_NS+"::T_STOP",
+    VOID_TYPE : CPP_PROTOCOL_NS+"::T_VOID",
+    BOOL_TYPE : CPP_PROTOCOL_NS+"::T_BOOL",
+    UTF7_TYPE : CPP_PROTOCOL_NS+"::T_UTF7",
+    UTF7_TYPE : CPP_PROTOCOL_NS+"::T_UTF7",
+    UTF8_TYPE : CPP_PROTOCOL_NS+"::T_UTF8",
+    UTF16_TYPE : CPP_PROTOCOL_NS+"::T_UTF16",
+    U08_TYPE : CPP_PROTOCOL_NS+"::T_U08",
+    I08_TYPE : CPP_PROTOCOL_NS+"::T_I08",
+    I16_TYPE : CPP_PROTOCOL_NS+"::T_I16",
+    I32_TYPE : CPP_PROTOCOL_NS+"::T_I32",
+    I64_TYPE : CPP_PROTOCOL_NS+"::T_I64",
+    U08_TYPE : CPP_PROTOCOL_NS+"::T_U08",
+    U16_TYPE : CPP_PROTOCOL_NS+"::T_U16",
+    U32_TYPE : CPP_PROTOCOL_NS+"::T_U32",
+    U64_TYPE : CPP_PROTOCOL_NS+"::T_U64",
+    FLOAT_TYPE : CPP_PROTOCOL_NS+"::T_FLOAT",
+    Struct : CPP_PROTOCOL_NS+"::T_STRUCT",
+    List : CPP_PROTOCOL_NS+"::T_LIST",
+    Map : CPP_PROTOCOL_NS+"::T_MAP",
+    Set : CPP_PROTOCOL_NS+"::T_SET"
+}
+
+def toWireType(ttype):
+
+    if isinstance(ttype, PrimitiveType):
+	return CPP_TTYPE_MAP[ttype]
+
+    elif isinstance(ttype, Struct) or isinstance(ttype, CollectionType):
+	return CPP_TTYPE_MAP[type(ttype)]
+
+    else:
+	raise Exception, "No wire type for thrift type: "+str(ttype)
+
 CPP_SERVER_DECLARATION = Template("""
 class ${service}ServerIf : public ${service}If, public """+CPP_PROCESSOR+""" {
     public:
@@ -423,9 +465,9 @@
 }
 
 CPP_COLLECTION_TYPE_IO_METHOD_SUFFIX_MAP = {
-    Map : "stdmap",
-    List : "stdlist",
-    Set : "stdset"
+    Map : "map",
+    List : "list",
+    Set : "set"
 }
 
 def typeToIOMethodSuffix(ttype):
@@ -456,7 +498,7 @@
     else:
         raise Exception, "Unknown type "+str(ttype)
 
-def toReadCall(value, ttype):
+def toReaderCall(value, ttype):
 
     suffix = typeToIOMethodSuffix(ttype)
 
@@ -470,10 +512,32 @@
         return "read_"+suffix+"(iprot, itrans, "+value+")"
 
     elif isinstance(ttype, TypeDef):
-        return toReadCall(value, ttype.definitionType)
+        return toReaderCall(value, ttype.definitionType)
 
     elif isinstance(ttype, Enum):
-        return toReadCall(value, U32_TYPE)
+        return toReaderCall(value, U32_TYPE)
+
+    else:
+        raise Exception, "Unknown type "+str(ttype)
+
+def toWriterCall(value, ttype):
+
+    suffix = typeToIOMethodSuffix(ttype)
+
+    if isinstance(ttype, PrimitiveType):
+        return "oprot->write"+suffix+"(otrans, "+value+")"
+
+    elif isinstance(ttype, CollectionType):
+        return "write_"+suffix+"(oprot, otrans, "+value+")"
+
+    elif isinstance(ttype, Struct):
+        return "write_"+suffix+"(oprot, otrans, "+value+")"
+
+    elif isinstance(ttype, TypeDef):
+        return toWriterCall(value, ttype.definitionType)
+
+    elif isinstance(ttype, Enum):
+        return toWriterCall(value, U32_TYPE)
 
     else:
         raise Exception, "Unknown type "+str(ttype)
@@ -485,11 +549,11 @@
    ${keyType} key;
    ${valueType} elem;
 
-   _iprot->readU32(itrans, count);
+   iprot->readU32(itrans, count);
 
    for(int ix = 0; ix <  count; ix++) {
-       ${keyReadCall};
-       ${valueReadCall};
+       ${keyReaderCall};
+       ${valueReaderCall};
        value.insert(std::make_pair(key, elem));
    }
 }
@@ -502,51 +566,85 @@
    ${keyType} key;
    ${valueType} elem;
 
-   _oprot->writeU32(otrans, count);
+   oprot->writeU32(otrans, value.size());
 
-   for(int ix = 0; ix <  count; ix++) {
-       ${keyReadCall};
-       ${valueReadCall};
-       value.insert(std::make_pair(key, elem));
+   for(${declaration}::iterator ix = value.begin(); ix != value.end(); ++ix) {
+       ${keyWriterCall};
+       ${valueWriterCall};
    }
 }
 """)
     
-
 CPP_READ_LIST_DEFINITION = Template("""
 void read_${suffix}("""+CPP_PROTOCOLP+""" iprot, """+CPP_TRANSPORTP+""" itrans, ${declaration}& value) {
 
    uint32_t count;
    ${valueType} elem;
 
-   _iprot->readU32(itrans,  count);
+   iprot->readU32(itrans,  count);
 
    for(int ix = 0; ix < count; ix++) {
-       ${valueReadCall};
+       ${valueReaderCall};
        value.insert(elem);
    }
 }
 """)
     
-def toCollectionReadDefinition(ttype):
+CPP_WRITE_LIST_DEFINITION = Template("""
+void write_${suffix}("""+CPP_PROTOCOLP+""" oprot, """+CPP_TRANSPORTP+""" otrans, ${declaration}& value) {
+
+   uint32_t count;
+   ${valueType} elem;
+
+   oprot->writeU32(otrans, value.size());
+
+   for(${declaration}::iterator ix = value.begin(); ix != value.end(); ++ix) {
+       ${valueWriterCall};
+   }
+}
+""")
+    
+def toCollectionReaderDefinition(ttype):
 
     suffix = typeToIOMethodSuffix(ttype)
 
     if isinstance(ttype, Map):
-        keyReadCall = toReadCall("key", ttype.keyType)
+        keyReaderCall = toReaderCall("key", ttype.keyType)
 
-    valueReadCall= toReadCall("elem", ttype.valueType)
+    valueReaderCall= toReaderCall("elem", ttype.valueType)
 
     if isinstance(ttype, Map):
         return CPP_READ_MAP_DEFINITION.substitute(suffix=suffix, declaration=typeToCTypeDeclaration(ttype),
                                                   keyType=typeToCTypeDeclaration(ttype.keyType),
-                                                  keyReadCall=keyReadCall,
+                                                  keyReaderCall=keyReaderCall,
                                                   valueType=typeToCTypeDeclaration(ttype.valueType),
-                                                  valueReadCall=valueReadCall)
+                                                  valueReaderCall=valueReaderCall)
 
     else:
         return CPP_READ_LIST_DEFINITION.substitute(suffix=suffix, declaration=typeToCTypeDeclaration(ttype),
-                                                   valueReadCall=valueReadCall,
+                                                   valueReaderCall=valueReaderCall,
+                                                   valueType=typeToCTypeDeclaration(ttype.valueType))
+
+
+def toCollectionWriterDefinition(ttype):
+    """Return the C++ write_<suffix>() definition for a collection type."""
+    suffix = typeToIOMethodSuffix(ttype)
+
+    # The write templates walk the container with an iterator named ix, so
+    # the generated calls must serialize what ix points at.
+    if isinstance(ttype, Map):
+        keyWriterCall = toWriterCall("ix->first", ttype.keyType)
+        valueWriterCall = toWriterCall("ix->second", ttype.valueType)
+        return CPP_WRITE_MAP_DEFINITION.substitute(suffix=suffix, declaration=typeToCTypeDeclaration(ttype),
+                                                  keyType=typeToCTypeDeclaration(ttype.keyType),
+                                                  keyWriterCall=keyWriterCall,
+                                                  valueType=typeToCTypeDeclaration(ttype.valueType),
+                                                  valueWriterCall=valueWriterCall)
+
+    else:
+        valueWriterCall = toWriterCall("*ix", ttype.valueType)
+        return CPP_WRITE_LIST_DEFINITION.substitute(suffix=suffix, declaration=typeToCTypeDeclaration(ttype),
+                                                   valueWriterCall=valueWriterCall,
                                                    valueType=typeToCTypeDeclaration(ttype.valueType))
 
 
@@ -559,23 +657,37 @@
 
     while(true) {
 
-        _iprot->readFieldBegin(itrans, name, type, id);
+        iprot->readFieldBegin(itrans, name, type, id);
 
         if(type == """+CPP_PROTOCOL_TSTOP+""") {break;}
 
         switch(id) {
 ${fieldSwitch}
-            default:
-            iprot->skip(itrans, type);
-            break;
-        }
+            default:v iprot->skip(itrans, type); break;}
 
-        _iprot->readFieldEnd(itrans);
+        iprot->readFieldEnd(itrans);
     }
 }
 """)
     
-def toStructReadDefinition(ttype):
+CPP_WRITE_FIELD_DEFINITION  = Template("""
+    oprot->writeFieldBegin(otrans, \"${name}\", ${type}, ${id});
+    ${fieldWriterCall};
+    oprot->writeFieldEnd(otrans);
+""")
+    
+# Template for the generated C++ write_<suffix>() function of a struct.
+CPP_WRITE_STRUCT_DEFINITION = Template("""
+void write_${suffix}("""+CPP_PROTOCOLP+""" oprot, """+CPP_TRANSPORTP+""" otrans, const ${declaration}& value) {
+
+    oprot->writeStructBegin(otrans, \"${name}\");
+${fieldWriterCalls}
+    oprot->writeFieldStop(otrans);
+    oprot->writeStructEnd(otrans);
+}
+""")
+    
+def toStructReaderDefinition(ttype):
 
     suffix = typeToIOMethodSuffix(ttype)
 
@@ -590,24 +702,145 @@
 
     for field in fieldList:
         fieldSwitch+= "            case "+str(field.id)+": "
-        fieldSwitch+= toReadCall("value."+field.name, field.type)+"; break;\n"
+        fieldSwitch+= toReaderCall("value."+field.name, field.type)+"; break;\n"
 
     return CPP_READ_STRUCT_DEFINITION.substitute(suffix=suffix, declaration=typeToCTypeDeclaration(ttype), fieldSwitch=fieldSwitch)
+
+def toStructWriterDefinition(ttype):
+
+    suffix = typeToIOMethodSuffix(ttype)
+
+    writeCalls = ""
+
+    for field in ttype.fieldList:
+
+	writeCalls+= CPP_WRITE_FIELD_DEFINITION.substitute(name=field.name, type=toWireType(field.type), id=field.id,
+							   fieldWriterCall=toWriterCall("value."+field.name, field.type))
+				   
+    return CPP_WRITE_STRUCT_DEFINITION.substitute(name=ttype.name, suffix=suffix, declaration=typeToCTypeDeclaration(ttype), fieldWriterCalls=writeCalls)
     
-def toReadDefinition(ttype):
+def toReaderDefinition(ttype):
     if isinstance(ttype, CollectionType):
-        return toCollectionReadDefinition(ttype)
+        return toCollectionReaderDefinition(ttype)
 
     elif isinstance(ttype, Struct):
-        return toStructReadDefinition(ttype)
+        return toStructReaderDefinition(ttype)
+
+def toWriterDefinition(ttype):
+    if isinstance(ttype, CollectionType):
+        return toCollectionWriterDefinition(ttype)
+
+    elif isinstance(ttype, Struct):
+        return toStructWriterDefinition(ttype)
+
+def toOrderedIOList(ttype, result=None):
+    """Return the collection and struct types reachable from ttype.
+
+    Key, value and field types are appended before the type that holds them.
+    """
+    if result is None:
+        result = []
+
+    if ttype in result:
+        return result
+
+    elif isinstance(ttype, PrimitiveType):
+        return result
+
+    elif isinstance(ttype, CollectionType):
+        if isinstance(ttype, Map):
+            result = toOrderedIOList(ttype.keyType, result)
+        result = toOrderedIOList(ttype.valueType, result)
+        result.append(ttype)
+
+    elif isinstance(ttype, Struct):
+        for field in ttype.fieldList:
+            result = toOrderedIOList(field.type, result)
+        result.append(ttype)
+
+    elif isinstance(ttype, TypeDef):
+        # toReaderCall/toWriterCall resolve a typedef to its definition type,
+        # so that type needs its read/write methods generated as well.
+        result = toOrderedIOList(ttype.definitionType, result)
+
+    elif isinstance(ttype, Enum):
+        return result
+
+    elif isinstance(ttype, Program):
+        for struct in ttype.structMap.values():
+            result = toOrderedIOList(struct, result)
+        for service in ttype.serviceMap.values():
+            result = toOrderedIOList(service, result)
+
+    elif isinstance(ttype, Service):
+        for function in ttype.functionList:
+            result = toOrderedIOList(function, result)
+
+    elif isinstance(ttype, Function):
+        result = toOrderedIOList(ttype.resultType, result)
+        for arg in ttype.argFieldList:
+            result = toOrderedIOList(arg.type, result)
+
+    else:
+        raise Exception, "Unsupported thrift type: "+str(ttype)
+
+    return result
+
+def toIOMethodImplementations(program):
+    """Return the reader/writer definitions for every type in program.
+
+    Each type from toOrderedIOList(program) contributes its reader, then its writer.
+    """
+    # get ordered list of all types that need marshallers:
+    iolist = toOrderedIOList(program)
+
+    definitions = []
+    for ttype in iolist:
+        definitions.append(toReaderDefinition(ttype))
+        definitions.append(toWriterDefinition(ttype))
+
+    return "".join(definitions)
+
+def toImplementationSourceName(filename, genDir=None, debugp=None):
+    """Return the path of the .cc file generated for filename inside genDir."""
+    if not genDir:
+        genDir = toGenDir(filename)
+
+    basename = toBasename(filename)
+    result = os.path.join(genDir, basename+".cc")
+
+    if debugp:
+        # Trace this function's own name and the path it returns.
+        debugp("toImplementationSourceName("+str(filename)+", "+str(genDir)+") => "+str(result))
+
+    return result
+
+def writeImplementationSource(program, filename, genDir=None, debugp=None):
+
+    implementationSource = toImplementationSourceName(filename, genDir)
+
+    if debugp:
+        debugp("implementationSource: "+str(implementationSource))
+
+    cfile = CFile(implementationSource, "w")
+
+    basename = toBasename(filename)
+
+    cfile.writeln(CPP_IMPL_HEADER.substitute(source=basename, date=time.ctime()))
+
+    cfile.write(toIOMethodImplementations(program))
+
+    cfile.writeln(CPP_IMPL_FOOTER.substitute(source=basename))
+
+    cfile.close()
 
 class CPPGenerator(Generator):
 
     def __call__(self, program, filename, genDir=None, debugp=None):
 
-        writeDefinitionHeader(program, filename, gendir, debugp)
+        writeDefinitionHeader(program, filename, genDir, debugp)
         
-        writeServicesHeader(program, filename, gendir, debugp)
+        writeServicesHeader(program, filename, genDir, debugp)
         
-        writeClientHeader(program, filename, gendir, debugp)
+        writeImplementationSource(program, filename, genDir, debugp)
     
diff --git a/compiler/src/thrift.py b/compiler/src/thrift.py
new file mode 100644
index 0000000..8321bcd
--- /dev/null
+++ b/compiler/src/thrift.py
@@ -0,0 +1,28 @@
+import sys
+import generator
+import cpp_generator
+import parser
+
+if __name__ == '__main__':
+    # Usage: thrift.py [--cpp] [--debug] <filename>
+    args = sys.argv[1:]
+    generators = []
+    debug = False
+
+    if "--cpp" in args:
+        generators.append(cpp_generator.CPPGenerator())
+        args.remove("--cpp")
+    if "--debug" in args:
+        debug = True
+        args.remove("--debug")
+
+    # Whatever is last once the flags are removed is the file to parse.
+    filename = args[-1]
+
+    p = parser.Parser(debug=debug)
+    p.parse(filename, False)
+
+    # Generators run for their side effects, so loop instead of building a list.
+    for g in generators:
+        g(p.program, filename)
+