Fixed serialization logic for collections containing typedefs or enums


			       


git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@664746 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/compiler/src/cpp_generator.py b/compiler/src/cpp_generator.py
index 922df53..834877e 100644
--- a/compiler/src/cpp_generator.py
+++ b/compiler/src/cpp_generator.py
@@ -17,7 +17,7 @@
 #if !defined(${source}_types_h_)
 #define ${source}_types_h_ 1
 
-#include <thrift/Thrift.h>
+#include <Thrift.h>
 """)
 
 CPP_TYPES_FOOTER = Template("""
@@ -28,7 +28,10 @@
 #if !defined(${source}_h_)
 #define ${source}_h_ 1
 
-#include <thrift/Thrift.h>
+#include <Thrift.h>
+#include <TProcessor.h>
+#include <protocol/TProtocol.h>
+#include <transport/TTransport.h>
 #include \"${source}_types.h\"
 """)
 
@@ -234,8 +237,8 @@
 CPP_PROCESSORP = CPP_SP.substitute(klass=CPP_PROCESSOR)
 
 CPP_PROTOCOL_NS = CPP_THRIFT_NS+"::protocol"
-CPP_PROTOCOL = CPP_PROTOCOL_NS+"::TProcotol"
-CPP_PROTOCOLP = CPP_SP.substitute(klass=CPP_PROTOCOL)
+CPP_PROTOCOL = CPP_PROTOCOL_NS+"::TProtocol"
+CPP_PROTOCOLP = CPP_SP.substitute(klass="const "+CPP_PROTOCOL)
 
 
 CPP_TRANSPORT_NS = CPP_THRIFT_NS+"::transport"
@@ -311,12 +314,17 @@
 
     public:
 
-    ${service}Client("""+CPP_TRANSPORTP+""" transport, """+CPP_PROTOCOLP+""" protocol): _itrans(transport), _otrans(transport), _iprot(protocol), _oprot(protocol {}
+    ${service}Client("""+CPP_TRANSPORTP+""" transport, """+CPP_PROTOCOLP+""" protocol): _itrans(transport), _otrans(transport), _iprot(protocol), _oprot(protocol) {}
 
-    ${service}Client("""+CPP_TRANSPORTP+""" itrans, """+CPP_TRANSPORTP+""" otrans, """+CPP_PROTOCOLP+""" iprot, """+CPP_PROTOCOLP+""" oprot) : _itrans(itrans), _otrans(otrans), _iprot(iprot), _oprot(oprot)x {}
+    ${service}Client("""+CPP_TRANSPORTP+""" itrans, """+CPP_TRANSPORTP+""" otrans, """+CPP_PROTOCOLP+""" iprot, """+CPP_PROTOCOLP+""" oprot) : _itrans(itrans), _otrans(otrans), _iprot(iprot), _oprot(oprot) {}
 
-${functionDeclarations}};
-""")
+${functionDeclarations}
+    private:
+    """+CPP_TRANSPORTP+""" _itrans;
+    """+CPP_TRANSPORTP+""" _otrans;
+    """+CPP_PROTOCOLP+""" _iprot;
+    """+CPP_PROTOCOLP+""" _oprot;
+};""")
 
 def writeClientDeclaration(cfile, service, debugp=None):
 
@@ -431,12 +439,16 @@
     cfile.close()
 
 
-CPP_STRUCT_READ = Template("""void read${name}Struct("""+CPP_PROTOCOLP+""" _iprot, """+CPP_TRANSPORTP+""" itrans, ${declaration}& value) {
+CPP_STRUCT_READ = Template("""
+uint32_t read${name}Struct("""+CPP_PROTOCOLP+""" _iprot, """+CPP_TRANSPORTP+""" itrans, ${declaration}& value) {
+
     std::string name;
     uint32_t id;
     uint32_t type;
+    uint32_t xfer = 0;
+
     while(true) {
-        _iprot->readFieldBegin(_itrans, name, type, id);
+        xfer+= _iprot->readFieldBegin(_itrans, name, type, id);
         if(type == """+CPP_PROTOCOL_TSTOP+""") {
             break;
         }
@@ -444,6 +456,10 @@
 ${readFieldListSwitch}
         }
     }
+
+    xfer+= _iprot->readStructEnd(_itrans);
+
+    return xfer;
 }
 """)
 
@@ -490,10 +506,10 @@
         return "struct_"+ttype.name
 
     elif isinstance(ttype, TypeDef):
-        return typeToIOMethodSuffix(ttype.definitionType)
+        return ttype.name
 
     elif isinstance(ttype, Enum):
-        return typeToIOMethodSuffix(U32_TYPE)
+        return ttype.name
 
     else:
         raise Exception, "Unknown type "+str(ttype)
@@ -503,19 +519,19 @@
     suffix = typeToIOMethodSuffix(ttype)
 
     if isinstance(ttype, PrimitiveType):
-        return "iprot->read"+suffix+"(itrans, "+value+")"
+        return "xfer += iprot->read"+suffix+"(itrans, "+value+")"
 
     elif isinstance(ttype, CollectionType):
-        return "read_"+suffix+"(iprot, itrans, "+value+")"
+        return "xfer+= read_"+suffix+"(iprot, itrans, "+value+")"
 
     elif isinstance(ttype, Struct):
-        return "read_"+suffix+"(iprot, itrans, "+value+")"
+        return "xfer+= read_"+suffix+"(iprot, itrans, "+value+")"
 
     elif isinstance(ttype, TypeDef):
-        return toReaderCall(value, ttype.definitionType)
+        return toReaderCall("reinterpret_cast<"+typeToCTypeDeclaration(ttype.definitionType)+"&>("+value+")", ttype.definitionType)
 
     elif isinstance(ttype, Enum):
-        return toReaderCall(value, U32_TYPE)
+        return toReaderCall("reinterpret_cast<"+typeToCTypeDeclaration(I32_TYPE)+"&>("+value+")", I32_TYPE)
 
     else:
         raise Exception, "Unknown type "+str(ttype)
@@ -525,82 +541,86 @@
     suffix = typeToIOMethodSuffix(ttype)
 
     if isinstance(ttype, PrimitiveType):
-        return "oprot->write"+suffix+"(otrans, "+value+")"
+        return "xfer+= oprot->write"+suffix+"(otrans, "+value+")"
 
     elif isinstance(ttype, CollectionType):
-        return "write_"+suffix+"(oprot, otrans, "+value+")"
+        return "xfer+= write_"+suffix+"(oprot, otrans, "+value+")"
 
     elif isinstance(ttype, Struct):
-        return "write_"+suffix+"(oprot, otrans, "+value+")"
+        return "xfer+= write_"+suffix+"(oprot, otrans, "+value+")"
 
     elif isinstance(ttype, TypeDef):
-        return toWriterCall(value, ttype.definitionType)
+        return toWriterCall("reinterpret_cast<const "+typeToCTypeDeclaration(ttype.definitionType)+"&>("+value+")", ttype.definitionType)
 
     elif isinstance(ttype, Enum):
-        return toWriterCall(value, U32_TYPE)
+        return toWriterCall("reinterpret_cast<const "+typeToCTypeDeclaration(I32_TYPE)+"&>("+value+")", I32_TYPE)
 
     else:
         raise Exception, "Unknown type "+str(ttype)
 
 CPP_READ_MAP_DEFINITION = Template("""
-void read_${suffix}("""+CPP_PROTOCOLP+""" iprot, """+CPP_TRANSPORTP+""" itrans, ${declaration}& value) {
+uint32_t read_${suffix}("""+CPP_PROTOCOLP+""" iprot, """+CPP_TRANSPORTP+""" itrans, ${declaration}& value) {
 
    uint32_t count;
    ${keyType} key;
    ${valueType} elem;
+   uint32_t xfer = 0;
 
-   iprot->readU32(itrans, count);
+   xfer += iprot->readU32(itrans, count);
 
    for(int ix = 0; ix <  count; ix++) {
        ${keyReaderCall};
        ${valueReaderCall};
        value.insert(std::make_pair(key, elem));
    }
+
+   return xfer;
 }
 """)
     
 CPP_WRITE_MAP_DEFINITION = Template("""
-void write_${suffix}("""+CPP_PROTOCOLP+""" oprot, """+CPP_TRANSPORTP+""" otrans, ${declaration}& value) {
+uint32_t write_${suffix}("""+CPP_PROTOCOLP+""" oprot, """+CPP_TRANSPORTP+""" otrans, const ${declaration}& value) {
 
-   uint32_t count;
-   ${keyType} key;
-   ${valueType} elem;
+   uint32_t xfer = 0;
 
-   oprot->writeU32(otrans, value.size());
+   xfer += oprot->writeU32(otrans, value.size());
 
-   for(${declaration}::iterator ix = value.begin(); ix != value.end(); ++ix) {
+   for(${declaration}::const_iterator ix = value.begin(); ix != value.end(); ++ix) {
        ${keyWriterCall};
        ${valueWriterCall};
    }
+   return xfer;
 }
 """)
     
 CPP_READ_LIST_DEFINITION = Template("""
-void read_${suffix}("""+CPP_PROTOCOLP+""" iprot, """+CPP_TRANSPORTP+""" itrans, ${declaration}& value) {
+uint32_t read_${suffix}("""+CPP_PROTOCOLP+""" iprot, """+CPP_TRANSPORTP+""" itrans, ${declaration}& value) {
 
    uint32_t count;
    ${valueType} elem;
+   uint32_t xfer = 0;
 
-   iprot->readU32(itrans,  count);
+   xfer+= iprot->readU32(itrans,  count);
 
    for(int ix = 0; ix < count; ix++) {
        ${valueReaderCall};
-       value.insert(elem);
+       value.${insert}(elem);
    }
+   return xfer;
 }
 """)
     
 CPP_WRITE_LIST_DEFINITION = Template("""
-void write_${suffix}("""+CPP_PROTOCOLP+""" oprot, """+CPP_TRANSPORTP+""" otrans, ${declaration}& value) {
+uint32_t write_${suffix}("""+CPP_PROTOCOLP+""" oprot, """+CPP_TRANSPORTP+""" otrans, const ${declaration}& value) {
 
-   uint32_t count;
-   ${valueType} elem;
+   uint32_t xfer = 0;
 
-   oprot->writeU32(otrans, value.size());
+   xfer+= oprot->writeU32(otrans, value.size());
 
-   for(${declaration}::iterator ix = value.begin(); ix != value.end(); ++ix) {
+   for(${declaration}::const_iterator ix = value.begin(); ix != value.end(); ++ix) {
        ${valueWriterCall};
    }
+   return xfer;
 }
 """)
     
@@ -621,9 +641,15 @@
                                                   valueReaderCall=valueReaderCall)
 
     else:
+	if isinstance(ttype, List):
+	    insert="push_back"
+	else:
+	    insert="insert"
+
         return CPP_READ_LIST_DEFINITION.substitute(suffix=suffix, declaration=typeToCTypeDeclaration(ttype),
                                                    valueReaderCall=valueReaderCall,
-                                                   valueType=typeToCTypeDeclaration(ttype.valueType))
+                                                   valueType=typeToCTypeDeclaration(ttype.valueType),
+						   insert=insert)
 
 
 def toCollectionWriterDefinition(ttype):
@@ -631,9 +657,11 @@
     suffix = typeToIOMethodSuffix(ttype)
 
     if isinstance(ttype, Map):
-        keyWriterCall = toWriterCall("key", ttype.keyType)
+        keyWriterCall = toWriterCall("ix->first", ttype.keyType)
+        valueWriterCall = toWriterCall("ix->second", ttype.valueType)
 
-    valueWriterCall= toWriterCall("elem", ttype.valueType)
+    else:
+	valueWriterCall= toWriterCall("*ix", ttype.valueType)
 
     if isinstance(ttype, Map):
         return CPP_WRITE_MAP_DEFINITION.substitute(suffix=suffix, declaration=typeToCTypeDeclaration(ttype),
@@ -649,24 +677,26 @@
 
 
 CPP_READ_STRUCT_DEFINITION = Template("""
-void read_${suffix}("""+CPP_PROTOCOLP+""" iprot, """+CPP_TRANSPORTP+""" itrans, ${declaration}& value) {
+uint32_t read_${suffix}("""+CPP_PROTOCOLP+""" iprot, """+CPP_TRANSPORTP+""" itrans, ${declaration}& value) {
 
     std::string name;
     """+CPP_PROTOCOL_TTYPE+""" type;
     uint16_t id;
+    uint32_t xfer = 0;
 
     while(true) {
 
-        iprot->readFieldBegin(itrans, name, type, id);
+        xfer+= iprot->readFieldBegin(itrans, name, type, id);
 
         if(type == """+CPP_PROTOCOL_TSTOP+""") {break;}
 
         switch(id) {
 ${fieldSwitch}
-            default:v iprot->skip(itrans, type); break;}
+            default: xfer += iprot->skip(itrans, type); break;}
 
-        iprot->readFieldEnd(itrans);
+        xfer+= iprot->readFieldEnd(itrans);
     }
+    return xfer;
 }
 """)
     
@@ -677,13 +707,15 @@
 """)
     
 CPP_WRITE_STRUCT_DEFINITION = Template("""
-void write_${suffix}("""+CPP_PROTOCOLP+""" oprot, """+CPP_TRANSPORTP+""" otrans, const ${declaration}& value) {
+uint32_t write_${suffix}("""+CPP_PROTOCOLP+""" oprot, """+CPP_TRANSPORTP+""" otrans, const ${declaration}& value) {
 
-    oprot->writeStructBegin(otrans, \"${name}\");
+    uint32_t xfer = 0;
+
+    xfer+= oprot->writeStructBegin(otrans, \"${name}\");
 ${fieldWriterCalls}
-    oprot->writeFieldStop(otrans);
-    oprot->writeStructEnd(otrans);
-    }
+    xfer+= oprot->writeFieldStop(otrans);
+    xfer += oprot->writeStructEnd(otrans);
+    return xfer;
 }
 """)
     
@@ -726,6 +758,15 @@
     elif isinstance(ttype, Struct):
         return toStructReaderDefinition(ttype)
 
+    elif isinstance(ttype, TypeDef):
+	return ""
+
+    elif isinstance(ttype, Enum):
+	return ""
+
+    else:
+	raise Exception, "Unsupported type: "+str(ttype)
+
 def toWriterDefinition(ttype):
     if isinstance(ttype, CollectionType):
         return toCollectionWriterDefinition(ttype)
@@ -733,6 +774,15 @@
     elif isinstance(ttype, Struct):
         return toStructWriterDefinition(ttype)
 
+    elif isinstance(ttype, TypeDef):
+	return ""
+
+    elif isinstance(ttype, Enum):
+	return ""
+
+    else:
+	raise Exception, "Unsupported type: "+str(ttype)
+
 def toOrderedIOList(ttype, result=None):
     if not result:
 	result = []
@@ -758,9 +808,11 @@
 	result.append(ttype)
 
     elif isinstance(ttype, TypeDef):
+	result.append(ttype)
 	return result
 
     elif isinstance(ttype, Enum):
+	result.append(ttype)
 	return result
 
     elif isinstance(ttype, Program):
@@ -795,7 +847,6 @@
     result = ""
 
     for ttype in iolist:
-	
 	result+= toReaderDefinition(ttype)
 	result+= toWriterDefinition(ttype)
 
@@ -843,4 +894,3 @@
         writeServicesHeader(program, filename, genDir, debugp)
         
         writeImplementationSource(program, filename, genDir, debugp)
-    
diff --git a/compiler/src/parser.py b/compiler/src/parser.py
index edafccb..af428fb 100644
--- a/compiler/src/parser.py
+++ b/compiler/src/parser.py
@@ -867,7 +867,7 @@
     def p_listtype(self, p):
         'listtype : LIST LANGLE fieldtype RANGLE'
 	self.pdebug("p_listtype", p)
-	p[0] = Set(p, p[3])
+	p[0] = List(p, p[3])
 
     def p_error(self, p):
         self.errors.append(SyntaxError(p))
@@ -921,4 +921,3 @@
 	    outf = file(os.path.splitext(filename)[0]+".thyc", "w")
 
 	    pickle.dump(self.program, outf)
-
diff --git a/compiler/src/thrift.py b/compiler/src/thrift.py
index 8321bcd..9282620 100644
--- a/compiler/src/thrift.py
+++ b/compiler/src/thrift.py
@@ -24,5 +24,8 @@
 
     p.parse(filename, False)
 
+    if len(p.errors):
+	sys.exit(-1)
+
     [g(p.program, filename) for g in generators]