THRIFT-1115 python TBase class for dynamic (de)serialization, and __slots__ option for memory savings
Patch: Will Pierce
git-svn-id: https://svn.apache.org/repos/asf/thrift/trunk@1169492 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/.gitignore b/.gitignore
index bfb61d2..082f7b7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -179,6 +179,7 @@
 /test/py/Makefile
 /test/py/Makefile.in
 /test/py/gen-py
+/test/py/gen-py-*
 /test/py.twisted/Makefile
 /test/py.twisted/Makefile.in
 /test/py.twisted/_trial_temp/
diff --git a/compiler/cpp/src/generate/t_py_generator.cc b/compiler/cpp/src/generate/t_py_generator.cc
index 6a82bd7..34acba4 100644
--- a/compiler/cpp/src/generate/t_py_generator.cc
+++ b/compiler/cpp/src/generate/t_py_generator.cc
@@ -52,12 +52,44 @@
     iter = parsed_options.find("new_style");
     gen_newstyle_ = (iter != parsed_options.end());
 
+    iter = parsed_options.find("slots");
+    gen_slots_ = (iter != parsed_options.end());
+
+    iter = parsed_options.find("dynamic");
+    gen_dynamic_ = (iter != parsed_options.end());
+
+    if (gen_dynamic_) {
+      gen_newstyle_ = 0; // dynamic is newstyle
+      gen_dynbaseclass_ = "TBase";
+      gen_dynbaseclass_exc_ = "TExceptionBase";
+      import_dynbase_ = "from thrift.protocol.TBase import TBase, TExceptionBase\n";
+    }
+
+    iter = parsed_options.find("dynbase");
+    if (iter != parsed_options.end()) {
+      gen_dynbase_ = true;
+      gen_dynbaseclass_ = (iter->second);
+    }
+
+    iter = parsed_options.find("dynexc");
+    if (iter != parsed_options.end()) {
+      gen_dynbaseclass_exc_ = (iter->second);
+    }
+
+    iter = parsed_options.find("dynimport");
+    if (iter != parsed_options.end()) {
+      gen_dynbase_ = true;
+      import_dynbase_ = (iter->second);
+    }
+
     iter = parsed_options.find("twisted");
     gen_twisted_ = (iter != parsed_options.end());
 
     iter = parsed_options.find("utf8strings");
     gen_utf8strings_ = (iter != parsed_options.end());
 
+    copy_options_ = option_string;
+    
     if (gen_twisted_){
       out_dir_base_ = "gen-py.twisted";
     } else {
@@ -214,17 +246,32 @@
  private:
 
   /**
-   * True iff we should generate new-style classes.
+   * True if we should generate new-style classes.
    */
   bool gen_newstyle_;
 
+   /**
+   * True if we should generate dynamic style classes.
+   */
+  bool gen_dynamic_;
+ 
+  bool gen_dynbase_;
+  std::string gen_dynbaseclass_;
+  std::string gen_dynbaseclass_exc_;
+ 
+  std::string import_dynbase_;
+
+  bool gen_slots_;
+
+  std::string copy_options_;
+ 
   /**
-   * True iff we should generate Twisted-friendly RPC services.
+   * True if we should generate Twisted-friendly RPC services.
    */
   bool gen_twisted_;
 
   /**
-   * True iff strings should be encoded using utf-8.
+   * True if strings should be encoded using utf-8.
    */
   bool gen_utf8strings_;
 
@@ -325,13 +372,19 @@
  * Renders all the imports necessary to use the accelerated TBinaryProtocol
  */
 string t_py_generator::render_fastbinary_includes() {
-  return
-    "from thrift.transport import TTransport\n"
-    "from thrift.protocol import TBinaryProtocol, TProtocol\n"
-    "try:\n"
-    "  from thrift.protocol import fastbinary\n"
-    "except:\n"
-    "  fastbinary = None\n";
+  string hdr = "";
+  if (gen_dynamic_) {
+    hdr += std::string(import_dynbase_);
+  } else {
+    hdr +=
+      "from thrift.transport import TTransport\n"
+      "from thrift.protocol import TBinaryProtocol, TProtocol\n"
+      "try:\n"
+      "  from thrift.protocol import fastbinary\n"
+      "except:\n"
+      "  fastbinary = None\n";
+  }
+  return hdr;
 }
 
 /**
@@ -343,6 +396,8 @@
     "# Autogenerated by Thrift Compiler (" + THRIFT_VERSION + ")\n" +
     "#\n" +
     "# DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING\n" +
+    "#\n" +
+    "#  options string: " + copy_options_  + "\n" +
     "#\n";
 }
 
@@ -351,7 +406,7 @@
  */
 string t_py_generator::py_imports() {
   return
-    string("from thrift.Thrift import *");
+    string("from thrift.Thrift import TType, TMessageType");
 }
 
 /**
@@ -384,6 +439,7 @@
   f_types_ <<
     "class " << tenum->get_name() <<
     (gen_newstyle_ ? "(object)" : "") <<
+    (gen_dynamic_ ? "(" + gen_dynbaseclass_ + ")" : "") <<  
     ":" << endl;
   indent_up();
   generate_python_docstring(f_types_, tenum);
@@ -575,12 +631,19 @@
   out << std::endl <<
     "class " << tstruct->get_name();
   if (is_exception) {
-    out << "(Exception)";
-  } else if (gen_newstyle_) {
-    out << "(object)";
+    if (gen_dynamic_) {
+      out << "(" << gen_dynbaseclass_exc_ << ")";
+    } else {
+      out << "(Exception)";
+    }
+  } else {
+    if (gen_newstyle_) {
+      out << "(object)";
+    } else if (gen_dynamic_) {
+      out << "(" << gen_dynbaseclass_ << ")";
+    }
   }
-  out <<
-    ":" << endl;
+  out << ":" << endl;
   indent_up();
   generate_python_docstring(out, tstruct);
 
@@ -606,6 +669,17 @@
      TODO(dreiss): Consider making this work for structs with negative tags.
   */
 
+  if (gen_slots_) {
+    indent(out) << "__slots__ = [ " << endl;
+    indent_up();
+    for (m_iter = sorted_members.begin(); m_iter != sorted_members.end(); ++m_iter) {
+      indent(out) <<  "'" << (*m_iter)->get_name()  << "'," << endl;
+    }
+    indent_down();
+    indent(out) << " ]" << endl << endl;
+
+  }
+
   // TODO(dreiss): Look into generating an empty tuple instead of None
   // for structures with no members.
   // TODO(dreiss): Test encoding of structs where some inner structs
@@ -672,8 +746,10 @@
     out << endl;
   }
 
-  generate_py_struct_reader(out, tstruct);
-  generate_py_struct_writer(out, tstruct);
+  if (!gen_dynamic_) {
+    generate_py_struct_reader(out, tstruct);
+    generate_py_struct_writer(out, tstruct);
+  }
 
   // For exceptions only, generate a __str__ method. This is
   // because when raised exceptions are printed to the console, __repr__
@@ -685,31 +761,61 @@
       endl;
   }
 
-  // Printing utilities so that on the command line thrift
-  // structs look pretty like dictionaries
-  out <<
-    indent() << "def __repr__(self):" << endl <<
-    indent() << "  L = ['%s=%r' % (key, value)" << endl <<
-    indent() << "    for key, value in self.__dict__.iteritems()]" << endl <<
-    indent() << "  return '%s(%s)' % (self.__class__.__name__, ', '.join(L))" << endl <<
-    endl;
+  if (!gen_slots_) {
+    // Printing utilities so that on the command line thrift
+    // structs look pretty like dictionaries
+    out <<
+      indent() << "def __repr__(self):" << endl <<
+      indent() << "  L = ['%s=%r' % (key, value)" << endl <<
+      indent() << "    for key, value in self.__dict__.iteritems()]" << endl <<
+      indent() << "  return '%s(%s)' % (self.__class__.__name__, ', '.join(L))" << endl <<
+      endl;
 
-  // Equality and inequality methods that compare by value
-  out <<
-    indent() << "def __eq__(self, other):" << endl;
-  indent_up();
-  out <<
-    indent() << "return isinstance(other, self.__class__) and "
-                "self.__dict__ == other.__dict__" << endl;
-  indent_down();
-  out << endl;
+    // Equality and inequality methods that compare by value
+    out <<
+      indent() << "def __eq__(self, other):" << endl;
+    indent_up();
+    out <<
+      indent() << "return isinstance(other, self.__class__) and "
+                  "self.__dict__ == other.__dict__" << endl;
+    indent_down();
+    out << endl;
 
-  out <<
-    indent() << "def __ne__(self, other):" << endl;
-  indent_up();
-  out <<
-    indent() << "return not (self == other)" << endl;
-  indent_down();
+    out <<
+      indent() << "def __ne__(self, other):" << endl;
+    indent_up();
+
+    out <<
+      indent() << "return not (self == other)" << endl;
+    indent_down();
+  } else if (!gen_dynamic_) {
+    // no base class available to implement __eq__ and __repr__ and __ne__ for us
+    // so we must provide one that uses __slots__
+    out <<
+      indent() << "def __repr__(self):" << endl <<
+      indent() << "  L = ['%s=%r' % (key, getattr(self, key))" << endl <<
+      indent() << "    for key in self.__slots__]" << endl <<
+      indent() << "  return '%s(%s)' % (self.__class__.__name__, ', '.join(L))" << endl <<
+      endl;
+    
+    // Equality method that compares each attribute by value and type, walking __slots__
+    out <<
+      indent() << "def __eq__(self, other):" << endl <<
+      indent() << "  if not isinstance(other, self.__class__):" << endl <<
+      indent() << "    return False" << endl <<
+      indent() << "  for attr in self.__slots__:" << endl <<
+      indent() << "    my_val = getattr(self, attr)" << endl <<
+      indent() << "    other_val = getattr(other, attr)" << endl <<
+      indent() << "    if my_val != other_val:" << endl <<
+      indent() << "      return False" << endl <<
+      indent() << "  return True" << endl <<
+      endl;
+    
+    out <<
+      indent() << "def __ne__(self, other):" << endl <<
+      indent() << "  return not (self == other)" << endl <<
+      endl;
+  }
   indent_down();
 }
 
@@ -984,7 +1090,7 @@
   } else {
     if (gen_twisted_) {
       extends_if = "(Interface)";
-    } else if (gen_newstyle_) {
+    } else if (gen_newstyle_ || gen_dynamic_) {
       extends_if = "(object)";
     }
   }
@@ -1031,8 +1137,8 @@
       extends_client = extends + ".Client, ";
     }
   } else {
-    if (gen_twisted_ && gen_newstyle_) {
-        extends_client = "(object)";
+    if (gen_twisted_ && (gen_newstyle_ || gen_dynamic_)) {
+      extends_client = "(object)";
     }
   }
 
@@ -2388,6 +2494,11 @@
 THRIFT_REGISTER_GENERATOR(py, "Python",
 "    new_style:       Generate new-style classes.\n" \
 "    twisted:         Generate Twisted-friendly RPC services.\n" \
-"    utf8strings:     Encode/decode strings using utf8 in the generated code.\n"
-)
+"    utf8strings:     Encode/decode strings using utf8 in the generated code.\n" \
+"    slots:           Generate code using slots for instance members.\n" \
+"    dynamic:         Generate dynamic code, less code generated but slower.\n" \
+"    dynbase=CLS      Derive generated classes from class CLS instead of TBase.\n" \
+"    dynexc=CLS       Derive generated exceptions from CLS instead of TExceptionBase.\n" \
+"    dynimport='from foo.bar import CLS'\n" \
+"                     Add an import line to generated code to find the dynbase class.\n")
 
diff --git a/lib/py/src/Thrift.py b/lib/py/src/Thrift.py
index af6f58d..1d271fc 100644
--- a/lib/py/src/Thrift.py
+++ b/lib/py/src/Thrift.py
@@ -38,6 +38,25 @@
   UTF8   = 16
   UTF16  = 17
 
+  _VALUES_TO_NAMES = ( 'STOP',
+                      'VOID',
+                      'BOOL',
+                      'BYTE',
+                      'DOUBLE',
+                      None,
+                      'I16',
+                      None,
+                      'I32',
+                      None,
+                       'I64',
+                       'STRING',
+                       'STRUCT',
+                       'MAP',
+                       'SET',
+                       'LIST',
+                       'UTF8',
+                       'UTF16' )
+
 class TMessageType:
   CALL  = 1
   REPLY = 2
diff --git a/lib/py/src/protocol/TBase.py b/lib/py/src/protocol/TBase.py
new file mode 100644
index 0000000..e675c7d
--- /dev/null
+++ b/lib/py/src/protocol/TBase.py
@@ -0,0 +1,72 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from thrift.Thrift import *
+from thrift.protocol import TBinaryProtocol
+from thrift.transport import TTransport
+
+try:
+  from thrift.protocol import fastbinary
+except:
+  fastbinary = None
+
+class TBase(object):
+  __slots__ = []
+
+  def __repr__(self):
+    L = ['%s=%r' % (key, getattr(self, key))
+              for key in self.__slots__ ]
+    return '%s(%s)' % (self.__class__.__name__, ', '.join(L))
+
+  def __eq__(self, other):
+    if not isinstance(other, self.__class__):
+      return False
+    for attr in self.__slots__:
+      my_val = getattr(self, attr)
+      other_val = getattr(other, attr)
+      if my_val != other_val:
+        return False
+    return True
+    
+  def __ne__(self, other):
+    return not (self == other)
+  
+  def read(self, iprot):
+    if iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None and fastbinary is not None:
+      fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec))
+      return
+    iprot.readStruct(self, self.thrift_spec)
+
+  def write(self, oprot):
+    if oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and self.thrift_spec is not None and fastbinary is not None:
+      oprot.trans.write(fastbinary.encode_binary(self, (self.__class__, self.thrift_spec)))
+      return
+    oprot.writeStruct(self, self.thrift_spec)
+
+class TExceptionBase(Exception):
+  # old style class so python2.4 can raise exceptions derived from this
+  #  This can't inherit from TBase because of that limitation.
+  __slots__ = []
+  
+  __repr__ = TBase.__repr__.im_func
+  __eq__ = TBase.__eq__.im_func
+  __ne__ = TBase.__ne__.im_func
+  read = TBase.read.im_func
+  write = TBase.write.im_func
+  
diff --git a/lib/py/src/protocol/TCompactProtocol.py b/lib/py/src/protocol/TCompactProtocol.py
index 6d57aeb..016a331 100644
--- a/lib/py/src/protocol/TCompactProtocol.py
+++ b/lib/py/src/protocol/TCompactProtocol.py
@@ -1,3 +1,22 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
 from TProtocol import *
 from struct import pack, unpack
 
diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py
index be3cb14..7338ff6 100644
--- a/lib/py/src/protocol/TProtocol.py
+++ b/lib/py/src/protocol/TProtocol.py
@@ -200,6 +200,205 @@
         self.skip(etype)
       self.readListEnd()
 
+  # tuple of: ( 'reader method' name, is_container boolean, 'writer_method' name )
+  _TTYPE_HANDLERS = (
+       (None, None, False), # 0 == TType,STOP
+       (None, None, False), # 1 == TType.VOID # TODO: handle void?
+       ('readBool', 'writeBool', False), # 2 == TType.BOOL
+       ('readByte',  'writeByte', False), # 3 == TType.BYTE and I08
+       ('readDouble', 'writeDouble', False), # 4 == TType.DOUBLE
+       (None, None, False), # 5, undefined
+       ('readI16', 'writeI16', False), # 6 == TType.I16
+       (None, None, False), # 7, undefined
+       ('readI32', 'writeI32', False), # 8 == TType.I32
+       (None, None, False), # 9, undefined
+       ('readI64', 'writeI64', False), # 10 == TType.I64
+       ('readString', 'writeString', False), # 11 == TType.STRING and UTF7
+       ('readContainerStruct', 'writeContainerStruct', True), # 12 == TType.STRUCT
+       ('readContainerMap', 'writeContainerMap', True), # 13 == TType.MAP
+       ('readContainerSet', 'writeContainerSet', True), # 14 == TType.SET
+       ('readContainerList', 'writeContainerList', True), # 15 == TType.LIST
+       (None, None, False), # 16 == TType.UTF8 # TODO: handle utf8 types?
+       (None, None, False)# 17 == TType.UTF16 # TODO: handle utf16 types?
+      )
+
+  def readFieldByTType(self, ttype, spec):
+    try:
+      (r_handler, w_handler, is_container) = self._TTYPE_HANDLERS[ttype]
+    except IndexError:
+      raise TProtocolException(type=TProtocolException.INVALID_DATA,
+                               message='Invalid field type %d' % (ttype))
+    if r_handler is None:
+      raise TProtocolException(type=TProtocolException.INVALID_DATA,
+                               message='Invalid field type %d' % (ttype))
+    reader = getattr(self, r_handler)
+    if not is_container:
+      return reader()
+    return reader(spec)
+
+  def readContainerList(self, spec):
+    results = []
+    ttype, tspec = spec[0], spec[1]
+    r_handler = self._TTYPE_HANDLERS[ttype][0]
+    reader = getattr(self, r_handler)
+    (list_type, list_len) = self.readListBegin()
+    if tspec is None:
+      # list values are simple types
+      for idx in xrange(list_len):
+        results.append(reader())
+    else:
+      # this is like an inlined readFieldByTType
+      container_reader = self._TTYPE_HANDLERS[list_type][0]
+      val_reader = getattr(self, container_reader)
+      for idx in xrange(list_len):
+        val = val_reader(tspec)
+        results.append(val)
+    self.readListEnd()
+    return results
+
+  def readContainerSet(self, spec):
+    results = set()
+    ttype, tspec = spec[0], spec[1]
+    r_handler = self._TTYPE_HANDLERS[ttype][0]
+    reader = getattr(self, r_handler)
+    (set_type, set_len) = self.readSetBegin()
+    if tspec is None:
+      # set members are simple types
+      for idx in xrange(set_len):
+        results.add(reader())
+    else:
+      container_reader = self._TTYPE_HANDLERS[set_type][0]
+      val_reader = getattr(self, container_reader)
+      for idx in xrange(set_len):
+        results.add(val_reader(tspec)) 
+    self.readSetEnd()
+    return results
+
+  def readContainerStruct(self, spec):
+    (obj_class, obj_spec) = spec
+    obj = obj_class()
+    obj.read(self)
+    return obj
+  
+  def readContainerMap(self, spec):
+    results = dict()
+    key_ttype, key_spec = spec[0], spec[1]
+    val_ttype, val_spec = spec[2], spec[3]
+    (map_ktype, map_vtype, map_len) = self.readMapBegin()
+    # TODO: compare types we just decoded with thrift_spec and abort/skip if types disagree
+    key_reader = getattr(self, self._TTYPE_HANDLERS[key_ttype][0])
+    val_reader = getattr(self, self._TTYPE_HANDLERS[val_ttype][0])
+    # list values are simple types
+    for idx in xrange(map_len):
+      if key_spec is None:
+        k_val = key_reader()
+      else:
+        k_val = self.readFieldByTType(key_ttype, key_spec)
+      if val_spec is None:
+        v_val = val_reader()
+      else:
+        v_val = self.readFieldByTType(val_ttype, val_spec)
+      # this raises a TypeError with unhashable keys types. i.e. d=dict(); d[[0,1]] = 2 fails
+      results[k_val] = v_val
+    self.readMapEnd()
+    return results
+
+  def readStruct(self, obj, thrift_spec):
+    self.readStructBegin()
+    while True:
+      (fname, ftype, fid) = self.readFieldBegin()
+      if ftype == TType.STOP:
+        break
+      try:
+        field = thrift_spec[fid]
+      except IndexError:
+        self.skip(ftype)
+      else:
+        if field is not None and ftype == field[1]:
+          fname = field[2]
+          fspec = field[3]
+          val = self.readFieldByTType(ftype, fspec)
+          setattr(obj, fname, val)
+        else:
+          self.skip(ftype)
+      self.readFieldEnd()
+    self.readStructEnd()
+
+  def writeContainerStruct(self, val, spec):
+    val.write(self)
+
+  def writeContainerList(self, val, spec):
+    self.writeListBegin(spec[0], len(val))
+    r_handler, w_handler, is_container  = self._TTYPE_HANDLERS[spec[0]]
+    e_writer = getattr(self, w_handler)
+    if not is_container:
+      for elem in val:
+        e_writer(elem)
+    else:
+      for elem in val:
+        e_writer(elem, spec[1])
+    self.writeListEnd()
+
+  def writeContainerSet(self, val, spec):
+    self.writeSetBegin(spec[0], len(val))
+    r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]]
+    e_writer = getattr(self, w_handler)
+    if not is_container:
+      for elem in val:
+        e_writer(elem)
+    else:
+      for elem in val:
+        e_writer(elem, spec[1])
+    self.writeSetEnd()
+
+  def writeContainerMap(self, val, spec):
+    k_type = spec[0]
+    v_type = spec[2]
+    ignore, ktype_name, k_is_container = self._TTYPE_HANDLERS[k_type]
+    ignore, vtype_name, v_is_container = self._TTYPE_HANDLERS[v_type]
+    k_writer = getattr(self, ktype_name)
+    v_writer = getattr(self, vtype_name)
+    self.writeMapBegin(k_type, v_type, len(val))
+    for m_key, m_val in val.iteritems():
+      if not k_is_container:
+        k_writer(m_key)
+      else:
+        k_writer(m_key, spec[1])
+      if not v_is_container:
+        v_writer(m_val)
+      else:
+        v_writer(m_val, spec[3])
+    self.writeMapEnd()
+
+  def writeStruct(self, obj, thrift_spec):
+    self.writeStructBegin(obj.__class__.__name__)
+    for field in thrift_spec:
+      if field is None:
+        continue
+      fname = field[2]
+      val = getattr(obj, fname)
+      if val is None:
+        # skip writing out unset fields
+        continue
+      fid = field[0]
+      ftype = field[1]
+      fspec = field[3]
+      # get the writer method for this value
+      self.writeFieldBegin(fname, ftype, fid)
+      self.writeFieldByTType(ftype, val, fspec)
+      self.writeFieldEnd()
+    self.writeFieldStop()
+    self.writeStructEnd()
+
+  def writeFieldByTType(self, ttype, val, spec):
+    r_handler, w_handler, is_container = self._TTYPE_HANDLERS[ttype]
+    writer = getattr(self, w_handler)
+    if is_container:
+      writer(val, spec)
+    else:
+      writer(val)
+
 class TProtocolFactory:
   def getProtocol(self, trans):
     pass
+
diff --git a/lib/py/src/protocol/__init__.py b/lib/py/src/protocol/__init__.py
index 01bfe18..d53359b 100644
--- a/lib/py/src/protocol/__init__.py
+++ b/lib/py/src/protocol/__init__.py
@@ -17,4 +17,4 @@
 # under the License.
 #
 
-__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary']
+__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary', 'TBase']
diff --git a/test/ThriftTest.thrift b/test/ThriftTest.thrift
index b6cd939..17b0295 100644
--- a/test/ThriftTest.thrift
+++ b/test/ThriftTest.thrift
@@ -208,3 +208,21 @@
   3000: VersioningTestV2 vertwo3000,
   4000: list<i32> big_numbers
 }
+
+struct NestedListsI32x2 {
+  1: list<list<i32>> integerlist
+}
+struct NestedListsI32x3 {
+  1: list<list<list<i32>>> integerlist
+}
+struct NestedMixedx2 {
+  1: list<set<i32>> int_set_list
+  2: map<i32,set<string>> map_int_strset
+  3: list<map<i32,set<string>>> map_int_strset_list
+}
+struct ListBonks {
+  1: list<Bonk> bonk
+}
+struct NestedListsBonk {
+  1: list<list<list<Bonk>>> bonk
+}
diff --git a/test/py/Makefile.am b/test/py/Makefile.am
index 63b7a89..2317ef6 100644
--- a/test/py/Makefile.am
+++ b/test/py/Makefile.am
@@ -19,22 +19,30 @@
 
 THRIFT = $(top_srcdir)/compiler/cpp/thrift
 
-py_unit_tests =                                 \
-        SerializationTest.py                    \
-        TestEof.py                              \
-        TestSyntax.py                           \
-        RunClientServer.py
+py_unit_tests = RunClientServer.py
 
 thrift_gen =                                    \
         gen-py/ThriftTest/__init__.py           \
-        gen-py/DebugProtoTest/__init__.py
+        gen-py/DebugProtoTest/__init__.py \
+        gen-py-default/ThriftTest/__init__.py           \
+        gen-py-default/DebugProtoTest/__init__.py \
+        gen-py-slots/ThriftTest/__init__.py           \
+        gen-py-slots/DebugProtoTest/__init__.py \
+        gen-py-newstyle/ThriftTest/__init__.py           \
+        gen-py-newstyle/DebugProtoTest/__init__.py \
+        gen-py-newstyleslots/ThriftTest/__init__.py           \
+        gen-py-newstyleslots/DebugProtoTest/__init__.py \
+        gen-py-dynamic/ThriftTest/__init__.py           \
+        gen-py-dynamic/DebugProtoTest/__init__.py \
+        gen-py-dynamicslots/ThriftTest/__init__.py           \
+        gen-py-dynamicslots/DebugProtoTest/__init__.py
 
 helper_scripts=                                 \
         TestClient.py                           \
         TestServer.py
 
 check_SCRIPTS=                                  \
-        $(thrift_gen)                           \
+        $(thrift_gen) \
         $(py_unit_tests)                        \
         $(helper_scripts)
 
@@ -42,7 +50,29 @@
 
 
 gen-py/%/__init__.py: ../%.thrift
-	$(THRIFT) --gen py $<
+	$(THRIFT) --gen py  $<
+	test -d gen-py-default || mkdir gen-py-default
+	$(THRIFT) --gen py -out gen-py-default $<
+
+gen-py-slots/%/__init__.py: ../%.thrift
+	test -d gen-py-slots || mkdir gen-py-slots
+	$(THRIFT) --gen py:slots -out gen-py-slots $<
+
+gen-py-newstyle/%/__init__.py: ../%.thrift
+	test -d gen-py-newstyle || mkdir gen-py-newstyle
+	$(THRIFT) --gen py:new_style -out gen-py-newstyle $<
+
+gen-py-newstyleslots/%/__init__.py: ../%.thrift
+	test -d gen-py-newstyleslots || mkdir gen-py-newstyleslots
+	$(THRIFT) --gen py:new_style,slots -out gen-py-newstyleslots $<
+
+gen-py-dynamic/%/__init__.py: ../%.thrift
+	test -d gen-py-dynamic || mkdir gen-py-dynamic
+	$(THRIFT) --gen py:dynamic -out gen-py-dynamic $<
+
+gen-py-dynamicslots/%/__init__.py: ../%.thrift
+	test -d gen-py-dynamicslots || mkdir gen-py-dynamicslots
+	$(THRIFT) --gen py:dynamic,slots -out gen-py-dynamicslots $<
 
 clean-local:
-	$(RM) -r gen-py
+	$(RM) -r gen-py gen-py-slots gen-py-default gen-py-newstyle gen-py-newstyleslots gen-py-dynamic gen-py-dynamicslots
diff --git a/test/py/RunClientServer.py b/test/py/RunClientServer.py
index 633856f..8a7fda6 100755
--- a/test/py/RunClientServer.py
+++ b/test/py/RunClientServer.py
@@ -28,6 +28,9 @@
 from optparse import OptionParser
 
 parser = OptionParser()
+parser.add_option('--genpydirs', type='string', dest='genpydirs',
+    default='default,slots,newstyle,newstyleslots,dynamic,dynamicslots',
+    help='directory extensions for generated code, used as suffixes for \"gen-py-*\" added sys.path for individual tests')
 parser.add_option("--port", type="int", dest="port", default=9090,
     help="port number for server to listen on")
 parser.add_option('-v', '--verbose', action="store_const", 
@@ -39,11 +42,15 @@
 parser.set_defaults(verbose=1)
 options, args = parser.parse_args()
 
+generated_dirs = []
+for gp_dir in options.genpydirs.split(','):
+  generated_dirs.append('gen-py-%s' % (gp_dir))
+
+SCRIPTS = ['SerializationTest.py', 'TestEof.py', 'TestSyntax.py', 'TestSocket.py']
 FRAMED = ["TNonblockingServer"]
 SKIP_ZLIB = ['TNonblockingServer', 'THttpServer']
 SKIP_SSL = ['TNonblockingServer', 'THttpServer']
-EXTRA_DELAY = ['TProcessPoolServer']
-EXTRA_SLEEP = 3.5
+EXTRA_DELAY = dict(TProcessPoolServer=3.5)
 
 PROTOS= [
     'accel',
@@ -85,11 +92,21 @@
 def relfile(fname):
     return os.path.join(os.path.dirname(__file__), fname)
 
-def runTest(server_class, proto, port, use_zlib, use_ssl):
+def runScriptTest(genpydir, script):
+  script_args = [sys.executable, relfile(script) ]
+  script_args.append('--genpydir=%s' % genpydir)
+  serverproc = subprocess.Popen(script_args)
+  print '\nTesting script: %s\n----' % (' '.join(script_args))
+  ret = subprocess.call(script_args)
+  if ret != 0:
+    raise Exception("Script subprocess failed, retcode=%d, args: %s" % (ret, ' '.join(script_args)))
+  
+def runServiceTest(genpydir, server_class, proto, port, use_zlib, use_ssl):
   # Build command line arguments
   server_args = [sys.executable, relfile('TestServer.py') ]
   cli_args = [sys.executable, relfile('TestClient.py') ]
   for which in (server_args, cli_args):
+    which.append('--genpydir=%s' % genpydir)
     which.append('--proto=%s' % proto) # accel, binary or compact
     which.append('--port=%d' % port) # default to 9090
     if use_zlib:
@@ -110,7 +127,7 @@
   if options.verbose > 0:
     print 'Testing server %s: %s' % (server_class, ' '.join(server_args))
   serverproc = subprocess.Popen(server_args)
-  time.sleep(0.2)
+  time.sleep(0.15)
   try:
     if options.verbose > 0:
       print 'Testing client: %s' % (' '.join(cli_args))
@@ -124,29 +141,47 @@
       print 'FAIL: Server process (%s) failed with retcode %d' % (' '.join(server_args), serverproc.returncode)
       raise Exception('Server subprocess %s died, args: %s' % (server_class, ' '.join(server_args)))
     else:
-      if server_class in EXTRA_DELAY:
-        if options.verbose > 0:
-          print 'Giving %s (proto=%s,zlib=%s,ssl=%s) an extra %d seconds for child processes to terminate via alarm' % (server_class,
-                proto, use_zlib, use_ssl, EXTRA_SLEEP)
-        time.sleep(EXTRA_SLEEP)
+      extra_sleep = EXTRA_DELAY.get(server_class, 0)
+      if extra_sleep > 0 and options.verbose > 0:
+        print 'Giving %s (proto=%s,zlib=%s,ssl=%s) an extra %d seconds for child processes to terminate via alarm' % (server_class,
+              proto, use_zlib, use_ssl, extra_sleep)
+        time.sleep(extra_sleep)
       os.kill(serverproc.pid, signal.SIGKILL)
   # wait for shutdown
-  time.sleep(0.1)
+  time.sleep(0.05)
 
 test_count = 0
+# run tests without a client/server first
+print '----------------'
+print ' Executing individual test scripts with various generated code directories'
+print ' Directories to be tested: ' + ', '.join(generated_dirs)
+print ' Scripts to be tested: ' + ', '.join(SCRIPTS)
+print '----------------'
+for genpydir in generated_dirs:
+  for script in SCRIPTS:
+    runScriptTest(genpydir, script)
+  
+print '----------------'
+print ' Executing Client/Server tests with various generated code directories'
+print ' Servers to be tested: ' + ', '.join(SERVERS)
+print ' Directories to be tested: ' + ', '.join(generated_dirs)
+print ' Protocols to be tested: ' + ', '.join(PROTOS)
+print ' Options to be tested: ZLIB(yes/no), SSL(yes/no)'
+print '----------------'
 for try_server in SERVERS:
-  for try_proto in PROTOS:
-    for with_zlib in (False, True):
-      # skip any servers that don't work with the Zlib transport
-      if with_zlib and try_server in SKIP_ZLIB:
-        continue
-      for with_ssl in (False, True):
-        # skip any servers that don't work with SSL
-        if with_ssl and try_server in SKIP_SSL:
+  for genpydir in generated_dirs:
+    for try_proto in PROTOS:
+      for with_zlib in (False, True):
+        # skip any servers that don't work with the Zlib transport
+        if with_zlib and try_server in SKIP_ZLIB:
           continue
-        test_count += 1
-        if options.verbose > 0:
-          print '\nTest run #%d:  Server=%s,  Proto=%s,  zlib=%s,  SSL=%s' % (test_count, try_server, try_proto, with_zlib, with_ssl)
-        runTest(try_server, try_proto, options.port, with_zlib, with_ssl)
-        if options.verbose > 0:
-          print 'OK: Finished  %s / %s proto / zlib=%s / SSL=%s.   %d combinations tested.' % (try_server, try_proto, with_zlib, with_ssl, test_count)
+        for with_ssl in (False, True):
+          # skip any servers that don't work with SSL
+          if with_ssl and try_server in SKIP_SSL:
+            continue
+          test_count += 1
+          if options.verbose > 0:
+            print '\nTest run #%d:  (includes %s) Server=%s,  Proto=%s,  zlib=%s,  SSL=%s' % (test_count, genpydir, try_server, try_proto, with_zlib, with_ssl)
+          runServiceTest(genpydir, try_server, try_proto, options.port, with_zlib, with_ssl)
+          if options.verbose > 0:
+            print 'OK: Finished (includes %s)  %s / %s proto / zlib=%s / SSL=%s.   %d combinations tested.' % (genpydir, try_server, try_proto, with_zlib, with_ssl, test_count)
diff --git a/test/py/SerializationTest.py b/test/py/SerializationTest.py
index 3ba76fb..0664146 100755
--- a/test/py/SerializationTest.py
+++ b/test/py/SerializationTest.py
@@ -20,7 +20,12 @@
 #
 
 import sys, glob
-sys.path.insert(0, './gen-py')
+from optparse import OptionParser
+parser = OptionParser()
+parser.add_option('--genpydir', type='string', dest='genpydir', default='gen-py')
+options, args = parser.parse_args()
+del sys.argv[1:] # clean up hack so unittest doesn't complain
+sys.path.insert(0, options.genpydir)
 sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0])
 
 from ThriftTest.ttypes import *
@@ -119,28 +124,86 @@
           byte_list_map={0 : [], 1 : [1], 2 : [1, 2]},
           )
 
+      self.nested_lists_i32x2 = NestedListsI32x2(
+                                              [
+                                                [ 1, 1, 2 ],
+                                                [ 2, 7, 9 ],
+                                                [ 3, 5, 8 ]
+                                              ]
+                                            )
+
+      self.nested_lists_i32x3 = NestedListsI32x3(
+                                              [
+                                                [
+                                                  [ 2, 7, 9 ],
+                                                  [ 3, 5, 8 ]
+                                                ],
+                                                [
+                                                  [ 1, 1, 2 ],
+                                                  [ 1, 4, 9 ]
+                                                ]
+                                              ]
+                                            )
+
+      self.nested_mixedx2 = NestedMixedx2( int_set_list=[
+                                            set([1,2,3]),
+                                            set([1,4,9]),
+                                            set([1,2,3,5,8,13,21]),
+                                            set([-1, 0, 1])
+                                            ],
+                                            # note, the sets below are sets of chars, since the strings are iterated
+                                            map_int_strset={ 10:set('abc'), 20:set('def'), 30:set('GHI') },
+                                            map_int_strset_list=[
+                                                                 { 10:set('abc'), 20:set('def'), 30:set('GHI') },
+                                                                 { 100:set('lmn'), 200:set('opq'), 300:set('RST') },
+                                                                 { 1000:set('uvw'), 2000:set('wxy'), 3000:set('XYZ') }
+                                                                 ]
+                                          )
+
+      self.nested_lists_bonk = NestedListsBonk(
+                                              [
+                                                [
+                                                  [
+                                                    Bonk(message='inner A first', type=1),
+                                                    Bonk(message='inner A second', type=1)
+                                                  ],
+                                                  [
+                                                  Bonk(message='inner B first', type=2),
+                                                  Bonk(message='inner B second', type=2)
+                                                  ]
+                                                ]
+                                              ]
+                                            )
+
+      self.list_bonks = ListBonks(
+                                    [
+                                      Bonk(message='inner A', type=1),
+                                      Bonk(message='inner B', type=2),
+                                      Bonk(message='inner C', type=0)
+                                    ]
+                                  )
 
   def _serialize(self, obj):
-      trans = TTransport.TMemoryBuffer()
-      prot = self.protocol_factory.getProtocol(trans)
-      obj.write(prot)
-      return trans.getvalue()
+    trans = TTransport.TMemoryBuffer()
+    prot = self.protocol_factory.getProtocol(trans)
+    obj.write(prot)
+    return trans.getvalue()
 
   def _deserialize(self, objtype, data):
-      prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data))
-      ret = objtype()
-      ret.read(prot)
-      return ret
+    prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data))
+    ret = objtype()
+    ret.read(prot)
+    return ret
 
   def testForwards(self):
-      obj = self._deserialize(VersioningTestV2, self._serialize(self.v1obj))
-      self.assertEquals(obj.begin_in_both, self.v1obj.begin_in_both)
-      self.assertEquals(obj.end_in_both, self.v1obj.end_in_both)
+    obj = self._deserialize(VersioningTestV2, self._serialize(self.v1obj))
+    self.assertEquals(obj.begin_in_both, self.v1obj.begin_in_both)
+    self.assertEquals(obj.end_in_both, self.v1obj.end_in_both)
 
   def testBackwards(self):
-      obj = self._deserialize(VersioningTestV1, self._serialize(self.v2obj))
-      self.assertEquals(obj.begin_in_both, self.v2obj.begin_in_both)
-      self.assertEquals(obj.end_in_both, self.v2obj.end_in_both)
+    obj = self._deserialize(VersioningTestV1, self._serialize(self.v2obj))
+    self.assertEquals(obj.begin_in_both, self.v2obj.begin_in_both)
+    self.assertEquals(obj.end_in_both, self.v2obj.end_in_both)
 
   def testSerializeV1(self):
     obj = self._deserialize(VersioningTestV1, self._serialize(self.v1obj))
@@ -152,20 +215,57 @@
 
   def testBools(self):
     self.assertNotEquals(self.bools, self.bools_flipped)
+    self.assertNotEquals(self.bools, self.v1obj)
     obj = self._deserialize(Bools, self._serialize(self.bools))
     self.assertEquals(obj, self.bools)
     obj = self._deserialize(Bools, self._serialize(self.bools_flipped))
     self.assertEquals(obj, self.bools_flipped)
+    rep = repr(self.bools)
+    self.assertTrue(len(rep) > 0)
 
   def testLargeDeltas(self):
     # test large field deltas (meaningful in CompactProto only)
     obj = self._deserialize(LargeDeltas, self._serialize(self.large_deltas))
     self.assertEquals(obj, self.large_deltas)
+    rep = repr(self.large_deltas)
+    self.assertTrue(len(rep) > 0)
+
+  def testNestedListsI32x2(self):
+    obj = self._deserialize(NestedListsI32x2, self._serialize(self.nested_lists_i32x2))
+    self.assertEquals(obj, self.nested_lists_i32x2)
+    rep = repr(self.nested_lists_i32x2)
+    self.assertTrue(len(rep) > 0)
+
+  def testNestedListsI32x3(self):
+    obj = self._deserialize(NestedListsI32x3, self._serialize(self.nested_lists_i32x3))
+    self.assertEquals(obj, self.nested_lists_i32x3)
+    rep = repr(self.nested_lists_i32x3)
+    self.assertTrue(len(rep) > 0)
+
+  def testNestedMixedx2(self):
+    obj = self._deserialize(NestedMixedx2, self._serialize(self.nested_mixedx2))
+    self.assertEquals(obj, self.nested_mixedx2)
+    rep = repr(self.nested_mixedx2)
+    self.assertTrue(len(rep) > 0)
+
+  def testNestedListsBonk(self):
+    obj = self._deserialize(NestedListsBonk, self._serialize(self.nested_lists_bonk))
+    self.assertEquals(obj, self.nested_lists_bonk)
+    rep = repr(self.nested_lists_bonk)
+    self.assertTrue(len(rep) > 0)
+
+  def testListBonks(self):
+    obj = self._deserialize(ListBonks, self._serialize(self.list_bonks))
+    self.assertEquals(obj, self.list_bonks)
+    rep = repr(self.list_bonks)
+    self.assertTrue(len(rep) > 0)
 
   def testCompactStruct(self):
     # test large field deltas (meaningful in CompactProto only)
     obj = self._deserialize(CompactProtoTestStruct, self._serialize(self.compact_struct))
     self.assertEquals(obj, self.compact_struct)
+    rep = repr(self.compact_struct)
+    self.assertTrue(len(rep) > 0)
 
 class NormalBinaryTest(AbstractTest):
   protocol_factory = TBinaryProtocol.TBinaryProtocolFactory()
diff --git a/test/py/TestClient.py b/test/py/TestClient.py
index 6429ec3..e5d4326 100755
--- a/test/py/TestClient.py
+++ b/test/py/TestClient.py
@@ -20,23 +20,16 @@
 #
 
 import sys, glob
-sys.path.insert(0, './gen-py')
 sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0])
 
-from ThriftTest import ThriftTest
-from ThriftTest.ttypes import *
-from thrift.transport import TTransport
-from thrift.transport import TSocket
-from thrift.transport import THttpClient
-from thrift.transport import TZlibTransport
-from thrift.protocol import TBinaryProtocol
-from thrift.protocol import TCompactProtocol
 import unittest
 import time
 from optparse import OptionParser
 
-
 parser = OptionParser()
+parser.add_option('--genpydir', type='string', dest='genpydir',
+                  default='gen-py',
+                  help='include this local directory in sys.path for locating generated code')
 parser.add_option("--port", type="int", dest="port",
     help="connect to server at port")
 parser.add_option("--host", type="string", dest="host",
@@ -60,6 +53,17 @@
 parser.set_defaults(framed=False, http_path=None, verbose=1, host='localhost', port=9090, proto='binary')
 options, args = parser.parse_args()
 
+sys.path.insert(0, options.genpydir)
+
+from ThriftTest import ThriftTest
+from ThriftTest.ttypes import *
+from thrift.transport import TTransport
+from thrift.transport import TSocket
+from thrift.transport import THttpClient
+from thrift.transport import TZlibTransport
+from thrift.protocol import TBinaryProtocol
+from thrift.protocol import TCompactProtocol
+
 class AbstractTest(unittest.TestCase):
   def setUp(self):
     if options.http_path:
@@ -176,6 +180,9 @@
     except Xception, x:
       self.assertEqual(x.errorCode, 1001)
       self.assertEqual(x.message, 'Xception')
+      # ensure exception's repr method works
+      x_repr = repr(x)
+      self.assertEqual(x_repr, 'Xception(errorCode=1001, message=\'Xception\')')
 
     try:
       self.client.testException("throw_undeclared")
@@ -225,4 +232,4 @@
         self.createTests()
 
 if __name__ == "__main__":
-  OwnArgsTestProgram(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2))
+  OwnArgsTestProgram(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=1))
diff --git a/test/py/TestEof.py b/test/py/TestEof.py
index 7ff0b42..a9d81f1 100755
--- a/test/py/TestEof.py
+++ b/test/py/TestEof.py
@@ -20,7 +20,12 @@
 #
 
 import sys, glob
-sys.path.insert(0, './gen-py')
+from optparse import OptionParser
+parser = OptionParser()
+parser.add_option('--genpydir', type='string', dest='genpydir', default='gen-py')
+options, args = parser.parse_args()
+del sys.argv[1:] # clean up hack so unittest doesn't complain
+sys.path.insert(0, options.genpydir)
 sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0])
 
 from ThriftTest import ThriftTest
diff --git a/test/py/TestServer.py b/test/py/TestServer.py
index fa62765..6f4af44 100755
--- a/test/py/TestServer.py
+++ b/test/py/TestServer.py
@@ -20,24 +20,13 @@
 #
 from __future__ import division
 import sys, glob, time
-sys.path.insert(0, './gen-py')
 sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0])
 from optparse import OptionParser
 
-from ThriftTest import ThriftTest
-from ThriftTest.ttypes import *
-from thrift.transport import TTransport
-from thrift.transport import TSocket
-from thrift.transport import TZlibTransport
-from thrift.protocol import TBinaryProtocol
-from thrift.protocol import TCompactProtocol
-from thrift.server import TServer, TNonblockingServer, THttpServer
-
-PROT_FACTORIES = {'binary': TBinaryProtocol.TBinaryProtocolFactory,
-    'accel': TBinaryProtocol.TBinaryProtocolAcceleratedFactory,
-    'compact': TCompactProtocol.TCompactProtocolFactory}
-
 parser = OptionParser()
+parser.add_option('--genpydir', type='string', dest='genpydir',
+                  default='gen-py',
+                  help='include this local directory in sys.path for locating generated code')
 parser.add_option("--port", type="int", dest="port",
     help="port number for server to listen on")
 parser.add_option("--zlib", action="store_true", dest="zlib",
@@ -55,6 +44,21 @@
 parser.set_defaults(port=9090, verbose=1, proto='binary')
 options, args = parser.parse_args()
 
+sys.path.insert(0, options.genpydir)
+
+from ThriftTest import ThriftTest
+from ThriftTest.ttypes import *
+from thrift.transport import TTransport
+from thrift.transport import TSocket
+from thrift.transport import TZlibTransport
+from thrift.protocol import TBinaryProtocol
+from thrift.protocol import TCompactProtocol
+from thrift.server import TServer, TNonblockingServer, THttpServer
+
+PROT_FACTORIES = {'binary': TBinaryProtocol.TBinaryProtocolFactory,
+    'accel': TBinaryProtocol.TBinaryProtocolAcceleratedFactory,
+    'compact': TCompactProtocol.TCompactProtocolFactory}
+
 class TestHandler:
 
   def testVoid(self):
@@ -105,7 +109,7 @@
       x.message = str
       raise x
     elif str == "throw_undeclared":
-      raise ValueError("foo")
+      raise ValueError("Exception test PASSES.")
 
   def testOneway(self, seconds):
     if options.verbose > 1:
@@ -206,7 +210,10 @@
         worker.terminate()
       if options.verbose > 0:
         print 'Requesting server to stop()'
-      server.stop()
+      try:
+        server.stop()
+      except:
+        pass
     signal.signal(signal.SIGALRM, clean_shutdown)
     signal.alarm(2)
   set_alarm()
diff --git a/test/py/TestSocket.py b/test/py/TestSocket.py
index 2f7353f..b9bdf27 100755
--- a/test/py/TestSocket.py
+++ b/test/py/TestSocket.py
@@ -20,7 +20,12 @@
 #
 
 import sys, glob
-sys.path.insert(0, './gen-py')
+from optparse import OptionParser
+parser = OptionParser()
+parser.add_option('--genpydir', type='string', dest='genpydir', default='gen-py')
+options, args = parser.parse_args()
+del sys.argv[1:] # clean up hack so unittest doesn't complain
+sys.path.insert(0, options.genpydir)
 sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0])
 
 from ThriftTest import ThriftTest
diff --git a/test/py/TestSyntax.py b/test/py/TestSyntax.py
index df67d48..9f71cf5 100755
--- a/test/py/TestSyntax.py
+++ b/test/py/TestSyntax.py
@@ -20,7 +20,12 @@
 #
 
 import sys, glob
-sys.path.insert(0, './gen-py')
+from optparse import OptionParser
+parser = OptionParser()
+parser.add_option('--genpydir', type='string', dest='genpydir', default='gen-py')
+options, args = parser.parse_args()
+del sys.argv[1:] # clean up hack so unittest doesn't complain
+sys.path.insert(0, options.genpydir)
 sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0])
 
 # Just import these generated files to make sure they are syntactically valid