THRIFT-926. cpp: Make TZlibTransport::flush() behave like other transports

Previously, TZlibTransport::flush() finished the zlib stream, so calling
write() after flush() would result in an error.  Now it just flushes the
data, without finishing the stream.  A new TZlibTransport::finish()
function has been added to finish the stream.

This breaks compatibility.  I'm aware of anyone using this code outside
of Facebook, though.

git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@1005151 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/cpp/src/transport/TTransport.h b/lib/cpp/src/transport/TTransport.h
index 8f2bd3d..fa0ed8a 100644
--- a/lib/cpp/src/transport/TTransport.h
+++ b/lib/cpp/src/transport/TTransport.h
@@ -144,6 +144,12 @@
   /**
    * Writes the string in its entirety to the buffer.
    *
+   * Note: You must call flush() to ensure the data is actually written,
+   * and available to be read back in the future.  Destroying a TTransport
+   * object does not automatically flush pending data--if you destroy a
+   * TTransport object with written but unflushed data, that data may be
+   * discarded.
+   *
    * @param buf  The data to write out
    * @throws TTransportException if an error occurs
    */
diff --git a/lib/cpp/src/transport/TZlibTransport.cpp b/lib/cpp/src/transport/TZlibTransport.cpp
index 2f14e90..8953742 100644
--- a/lib/cpp/src/transport/TZlibTransport.cpp
+++ b/lib/cpp/src/transport/TZlibTransport.cpp
@@ -90,8 +90,16 @@
   int rv;
   rv = inflateEnd(rstream_);
   checkZlibRvNothrow(rv, rstream_->msg);
+
   rv = deflateEnd(wstream_);
-  checkZlibRvNothrow(rv, wstream_->msg);
+  // Z_DATA_ERROR may be returned if the caller has written data, but not
+  // called flush() to actually finish writing the data out to the underlying
+  // transport.  The defined TTransport behavior in this case is that this data
+  // may be discarded, so we ignore the error and silently discard the data.
+  // For other erros, log a message.
+  if (rv != Z_DATA_ERROR) {
+    checkZlibRvNothrow(rv, wstream_->msg);
+  }
 
   delete[] urbuf_;
   delete[] crbuf_;
@@ -200,15 +208,20 @@
 // - Deflate from the source into the compressed buffer.
 
 void TZlibTransport::write(const uint8_t* buf, uint32_t len) {
+  if (output_finished_) {
+    throw TTransportException(TTransportException::BAD_ARGS,
+                              "write() called after finish()");
+  }
+
   // zlib's "deflate" function has enough logic in it that I think
   // we're better off (performance-wise) buffering up small writes.
   if ((int)len > MIN_DIRECT_DEFLATE_SIZE) {
-    flushToZlib(uwbuf_, uwpos_);
+    flushToZlib(uwbuf_, uwpos_, Z_NO_FLUSH);
     uwpos_ = 0;
-    flushToZlib(buf, len);
+    flushToZlib(buf, len, Z_NO_FLUSH);
   } else if (len > 0) {
     if (uwbuf_size_ - uwpos_ < (int)len) {
-      flushToZlib(uwbuf_, uwpos_);
+      flushToZlib(uwbuf_, uwpos_, Z_NO_FLUSH);
       uwpos_ = 0;
     }
     memcpy(uwbuf_ + uwpos_, buf, len);
@@ -217,19 +230,46 @@
 }
 
 void TZlibTransport::flush()  {
-  flushToZlib(uwbuf_, uwpos_, true);
-  assert((int)wstream_->avail_out != cwbuf_size_);
+  if (output_finished_) {
+    throw TTransportException(TTransportException::BAD_ARGS,
+                              "flush() called after finish()");
+  }
+
+  flushToTransport(Z_SYNC_FLUSH);
+}
+
+void TZlibTransport::finish()  {
+  if (output_finished_) {
+    throw TTransportException(TTransportException::BAD_ARGS,
+                              "finish() called more than once");
+  }
+
+  flushToTransport(Z_FINISH);
+}
+
+void TZlibTransport::flushToTransport(int flush)  {
+  // write pending data in uwbuf_ to zlib
+  flushToZlib(uwbuf_, uwpos_, flush);
+  uwpos_ = 0;
+
+  // write all available data from zlib to the transport
   transport_->write(cwbuf_, cwbuf_size_ - wstream_->avail_out);
+  wstream_->next_out = cwbuf_;
+  wstream_->avail_out = cwbuf_size_;
+
+  // flush the transport
   transport_->flush();
 }
 
-void TZlibTransport::flushToZlib(const uint8_t* buf, int len, bool finish) {
-  int flush = (finish ? Z_FINISH : Z_NO_FLUSH);
-
+void TZlibTransport::flushToZlib(const uint8_t* buf, int len, int flush) {
   wstream_->next_in  = const_cast<uint8_t*>(buf);
   wstream_->avail_in = len;
 
-  while (wstream_->avail_in > 0 || finish) {
+  while (true) {
+    if (flush == Z_NO_FLUSH && wstream_->avail_in == 0) {
+      break;
+    }
+
     // If our ouput buffer is full, flush to the underlying transport.
     if (wstream_->avail_out == 0) {
       transport_->write(cwbuf_, cwbuf_size_);
@@ -239,12 +279,18 @@
 
     int zlib_rv = deflate(wstream_, flush);
 
-    if (finish && zlib_rv == Z_STREAM_END) {
+    if (flush == Z_FINISH && zlib_rv == Z_STREAM_END) {
       assert(wstream_->avail_in == 0);
+      output_finished_ = true;
       break;
     }
 
     checkZlibRv(zlib_rv, wstream_->msg);
+
+    if ((flush == Z_SYNC_FLUSH || flush == Z_FULL_FLUSH) &&
+        wstream_->avail_in == 0 && wstream_->avail_out != 0) {
+      break;
+    }
   }
 }
 
diff --git a/lib/cpp/src/transport/TZlibTransport.h b/lib/cpp/src/transport/TZlibTransport.h
index 61c43fe..0f9815e 100644
--- a/lib/cpp/src/transport/TZlibTransport.h
+++ b/lib/cpp/src/transport/TZlibTransport.h
@@ -96,7 +96,7 @@
     urpos_(0),
     uwpos_(0),
     input_ended_(false),
-    output_flushed_(false),
+    output_finished_(false),
     urbuf_size_(urbuf_size),
     crbuf_size_(crbuf_size),
     uwbuf_size_(uwbuf_size),
@@ -145,6 +145,13 @@
   // Don't call this outside of the constructor.
   void initZlib();
 
+  /**
+   * TZlibTransport destructor.
+   *
+   * Warning: Destroying a TZlibTransport object may discard any written but
+   * unflushed data.  You must explicitly call flush() or finish() to ensure
+   * that data is actually written and flushed to the underlying transport.
+   */
   ~TZlibTransport();
 
   bool isOpen();
@@ -163,10 +170,25 @@
 
   void flush();
 
+  /**
+   * Finalize the zlib stream.
+   *
+   * This causes zlib to flush any pending write data and write end-of-stream
+   * information, including the checksum.  Once finish() has been called, no
+   * new data can be written to the stream.
+   */
+  void finish();
+
   const uint8_t* borrow(uint8_t* buf, uint32_t* len);
 
   void consume(uint32_t len);
 
+  /**
+   * Verify the checksum at the end of the zlib stream.
+   *
+   * This may only be called after all data has been read.
+   * It verifies the checksum that was written by the finish() call.
+   */
   void verifyChecksum();
 
    /**
@@ -182,7 +204,8 @@
   inline void checkZlibRv(int status, const char* msg);
   inline void checkZlibRvNothrow(int status, const char* msg);
   inline int readAvail();
-  void flushToZlib(const uint8_t* buf, int len, bool finish = false);
+  void flushToTransport(int flush);
+  void flushToZlib(const uint8_t* buf, int len, int flush);
 
   // Writes smaller than this are buffered up.
   // Larger (or equal) writes are dumped straight to zlib.
@@ -197,9 +220,9 @@
   /// True iff zlib has reached the end of a stream.
   /// This is only ever true in standalone protcol objects.
   bool input_ended_;
-  /// True iff we have flushed the output stream.
+  /// True iff we have finished the output stream.
   /// This is only ever true in standalone protcol objects.
-  bool output_flushed_;
+  bool output_finished_;
 
   int urbuf_size_;
   int crbuf_size_;
diff --git a/lib/cpp/test/TransportTest.cpp b/lib/cpp/test/TransportTest.cpp
index e5ddeee..7f95e38 100644
--- a/lib/cpp/test/TransportTest.cpp
+++ b/lib/cpp/test/TransportTest.cpp
@@ -36,6 +36,7 @@
 #include <transport/TBufferTransports.h>
 #include <transport/TFDTransport.h>
 #include <transport/TFileTransport.h>
+#include <transport/TZlibTransport.h>
 
 using namespace apache::thrift::transport;
 
@@ -178,6 +179,22 @@
   boost::shared_ptr<TMemoryBuffer> buf;
 };
 
+class CoupledZlibTransports : public CoupledTransports<TZlibTransport> {
+ public:
+  CoupledZlibTransports() :
+    buf(new TMemoryBuffer) {
+    in = new TZlibTransport(buf, false);
+    out = new TZlibTransport(buf, false);
+  }
+
+  ~CoupledZlibTransports() {
+    delete in;
+    delete out;
+  }
+
+  boost::shared_ptr<TMemoryBuffer> buf;
+};
+
 class CoupledFDTransports : public CoupledTransports<TFDTransport> {
  public:
   CoupledFDTransports() {
@@ -363,7 +380,16 @@
         read_size = rchunk_size - chunk_read;
       }
 
-      int bytes_read = transports.in->read(rbuf.get() + total_read, read_size);
+      int bytes_read = -1;
+      try {
+        bytes_read = transports.in->read(rbuf.get() + total_read, read_size);
+      } catch (TTransportException& e) {
+        BOOST_FAIL("read(pos=" << total_read << ", size=" << read_size <<
+                   ") threw exception \"" << e.what() <<
+                   "\"; written so far: " << total_written << " / " <<
+                   totalSize << " bytes");
+      }
+
       BOOST_REQUIRE_MESSAGE(bytes_read > 0,
                             "read(pos=" << total_read << ", size=" <<
                             read_size << ") returned " << bytes_read <<
@@ -449,6 +475,17 @@
     BUFFER_TESTS(CoupledBufferedTransports)
     BUFFER_TESTS(CoupledFramedTransports)
 
+    TEST_RW(CoupledZlibTransports, 1024*1024*10, 0, 0);
+    TEST_RW(CoupledZlibTransports, 1024*1024*10, rand4k, rand4k);
+    TEST_RW(CoupledZlibTransports, 1024*1024*5, 167, 163);
+    TEST_RW(CoupledZlibTransports, 1024*64, 1, 1);
+
+    TEST_RW(CoupledZlibTransports, 1024*1024*10, 0, 0, rand4k, rand4k);
+    TEST_RW(CoupledZlibTransports, 1024*1024*10,
+            rand4k, rand4k, rand4k, rand4k);
+    TEST_RW(CoupledZlibTransports, 1024*1024*5, 167, 163, rand4k, rand4k);
+    TEST_RW(CoupledZlibTransports, 1024*64, 1, 1, rand4k, rand4k);
+
     // TFDTransport tests
     // Since CoupledFDTransports tests with a pipe, writes will block
     // if there is too much outstanding unread data in the pipe.
diff --git a/lib/cpp/test/ZlibTest.cpp b/lib/cpp/test/ZlibTest.cpp
index e952e71..e2403d7 100644
--- a/lib/cpp/test/ZlibTest.cpp
+++ b/lib/cpp/test/ZlibTest.cpp
@@ -145,7 +145,7 @@
   shared_ptr<TMemoryBuffer> membuf(new TMemoryBuffer());
   shared_ptr<TZlibTransport> zlib_trans(new TZlibTransport(membuf, false));
   zlib_trans->write(buf, buf_len);
-  zlib_trans->flush();
+  zlib_trans->finish();
 
   boost::shared_array<uint8_t> mirror(new uint8_t[buf_len]);
   uint32_t got = zlib_trans->read(mirror.get(), buf_len);
@@ -164,7 +164,7 @@
   shared_ptr<TMemoryBuffer> membuf(new TMemoryBuffer());
   shared_ptr<TZlibTransport> zlib_trans(new TZlibTransport(membuf, false));
   zlib_trans->write(buf, buf_len);
-  zlib_trans->flush();
+  zlib_trans->finish();
   string tmp_buf;
   membuf->appendBufferToString(tmp_buf);
   zlib_trans.reset(new TZlibTransport(membuf, false,
@@ -184,7 +184,7 @@
   shared_ptr<TMemoryBuffer> membuf(new TMemoryBuffer());
   shared_ptr<TZlibTransport> zlib_trans(new TZlibTransport(membuf, false));
   zlib_trans->write(buf, buf_len);
-  zlib_trans->flush();
+  zlib_trans->finish();
   string tmp_buf;
   membuf->appendBufferToString(tmp_buf);
   tmp_buf.erase(tmp_buf.length() - 1);
@@ -222,7 +222,7 @@
     tot += write_len;
   }
 
-  zlib_trans->flush();
+  zlib_trans->finish();
 
   tot = 0;
   boost::shared_array<uint8_t> mirror(new uint8_t[buf_len]);
@@ -246,7 +246,7 @@
   shared_ptr<TMemoryBuffer> membuf(new TMemoryBuffer());
   shared_ptr<TZlibTransport> zlib_trans(new TZlibTransport(membuf, false));
   zlib_trans->write(buf, buf_len);
-  zlib_trans->flush();
+  zlib_trans->finish();
   string tmp_buf;
   membuf->appendBufferToString(tmp_buf);
   // Modify a byte at the end of the buffer (part of the checksum).
@@ -279,6 +279,54 @@
   }
 }
 
+void test_write_after_flush(const uint8_t* buf, uint32_t buf_len) {
+  // write some data
+  shared_ptr<TMemoryBuffer> membuf(new TMemoryBuffer());
+  shared_ptr<TZlibTransport> zlib_trans(new TZlibTransport(membuf, false));
+  zlib_trans->write(buf, buf_len);
+
+  // call finish()
+  zlib_trans->finish();
+
+  // make sure write() throws an error
+  try {
+    uint8_t write_buf[] = "a";
+    zlib_trans->write(write_buf, 1);
+    BOOST_ERROR("write() after finish() did not raise an exception");
+  } catch (TTransportException& ex) {
+    BOOST_CHECK_EQUAL(ex.getType(), TTransportException::BAD_ARGS);
+  }
+
+  // make sure flush() throws an error
+  try {
+    zlib_trans->flush();
+    BOOST_ERROR("flush() after finish() did not raise an exception");
+  } catch (TTransportException& ex) {
+    BOOST_CHECK_EQUAL(ex.getType(), TTransportException::BAD_ARGS);
+  }
+
+  // make sure finish() throws an error
+  try {
+    zlib_trans->finish();
+    BOOST_ERROR("finish() after finish() did not raise an exception");
+  } catch (TTransportException& ex) {
+    BOOST_CHECK_EQUAL(ex.getType(), TTransportException::BAD_ARGS);
+  }
+}
+
+void test_no_write() {
+  // Verify that no data is written to the underlying transport if we
+  // never write data to the TZlibTransport.
+  shared_ptr<TMemoryBuffer> membuf(new TMemoryBuffer());
+  {
+    // Create a TZlibTransport object, and immediately destroy it
+    // when it goes out of scope.
+    TZlibTransport w_zlib_trans(membuf, false);
+  }
+
+  BOOST_CHECK_EQUAL(membuf->available_read(), 0);
+}
+
 /*
  * Initialization
  */
@@ -301,6 +349,7 @@
   ADD_TEST_CASE(suite, name, test_separate_checksum, buf, buf_len);
   ADD_TEST_CASE(suite, name, test_incomplete_checksum, buf, buf_len);
   ADD_TEST_CASE(suite, name, test_invalid_checksum, buf, buf_len);
+  ADD_TEST_CASE(suite, name, test_write_after_flush, buf, buf_len);
 
   shared_ptr<SizeGenerator> size_32k(new ConstantSizeGenerator(1<<15));
   shared_ptr<SizeGenerator> size_lognormal(new LogNormalSizeGenerator(20, 30));
@@ -397,5 +446,7 @@
   add_tests(suite, gen_compressible_buffer(buf_len), buf_len, "compressible");
   add_tests(suite, gen_random_buffer(buf_len), buf_len, "random");
 
+  suite->add(BOOST_TEST_CASE(test_no_write));
+
   return suite;
 }