THRIFT-926. cpp: Make TZlibTransport::flush() behave like other transports
Previously, TZlibTransport::flush() finished the zlib stream, so calling
write() after flush() would result in an error. Now it just flushes the
data, without finishing the stream. A new TZlibTransport::finish()
function has been added to finish the stream.
This breaks compatibility. I'm not aware of anyone using this code outside
of Facebook, though.
git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@1005151 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/cpp/src/transport/TTransport.h b/lib/cpp/src/transport/TTransport.h
index 8f2bd3d..fa0ed8a 100644
--- a/lib/cpp/src/transport/TTransport.h
+++ b/lib/cpp/src/transport/TTransport.h
@@ -144,6 +144,12 @@
/**
* Writes the string in its entirety to the buffer.
*
+ * Note: You must call flush() to ensure the data is actually written,
+ * and available to be read back in the future. Destroying a TTransport
+ * object does not automatically flush pending data--if you destroy a
+ * TTransport object with written but unflushed data, that data may be
+ * discarded.
+ *
* @param buf The data to write out
* @throws TTransportException if an error occurs
*/
diff --git a/lib/cpp/src/transport/TZlibTransport.cpp b/lib/cpp/src/transport/TZlibTransport.cpp
index 2f14e90..8953742 100644
--- a/lib/cpp/src/transport/TZlibTransport.cpp
+++ b/lib/cpp/src/transport/TZlibTransport.cpp
@@ -90,8 +90,16 @@
int rv;
rv = inflateEnd(rstream_);
checkZlibRvNothrow(rv, rstream_->msg);
+
rv = deflateEnd(wstream_);
- checkZlibRvNothrow(rv, wstream_->msg);
+ // Z_DATA_ERROR may be returned if the caller has written data, but not
+ // called flush() to actually finish writing the data out to the underlying
+ // transport. The defined TTransport behavior in this case is that this data
+ // may be discarded, so we ignore the error and silently discard the data.
+ // For other errors, log a message.
+ if (rv != Z_DATA_ERROR) {
+ checkZlibRvNothrow(rv, wstream_->msg);
+ }
delete[] urbuf_;
delete[] crbuf_;
@@ -200,15 +208,20 @@
// - Deflate from the source into the compressed buffer.
void TZlibTransport::write(const uint8_t* buf, uint32_t len) {
+ if (output_finished_) {
+ throw TTransportException(TTransportException::BAD_ARGS,
+ "write() called after finish()");
+ }
+
// zlib's "deflate" function has enough logic in it that I think
// we're better off (performance-wise) buffering up small writes.
if ((int)len > MIN_DIRECT_DEFLATE_SIZE) {
- flushToZlib(uwbuf_, uwpos_);
+ flushToZlib(uwbuf_, uwpos_, Z_NO_FLUSH);
uwpos_ = 0;
- flushToZlib(buf, len);
+ flushToZlib(buf, len, Z_NO_FLUSH);
} else if (len > 0) {
if (uwbuf_size_ - uwpos_ < (int)len) {
- flushToZlib(uwbuf_, uwpos_);
+ flushToZlib(uwbuf_, uwpos_, Z_NO_FLUSH);
uwpos_ = 0;
}
memcpy(uwbuf_ + uwpos_, buf, len);
@@ -217,19 +230,46 @@
}
void TZlibTransport::flush() {
- flushToZlib(uwbuf_, uwpos_, true);
- assert((int)wstream_->avail_out != cwbuf_size_);
+ if (output_finished_) {
+ throw TTransportException(TTransportException::BAD_ARGS,
+ "flush() called after finish()");
+ }
+
+ flushToTransport(Z_SYNC_FLUSH);
+}
+
+void TZlibTransport::finish() {
+ if (output_finished_) {
+ throw TTransportException(TTransportException::BAD_ARGS,
+ "finish() called more than once");
+ }
+
+ flushToTransport(Z_FINISH);
+}
+
+void TZlibTransport::flushToTransport(int flush) {
+ // write pending data in uwbuf_ to zlib
+ flushToZlib(uwbuf_, uwpos_, flush);
+ uwpos_ = 0;
+
+ // write all available data from zlib to the transport
transport_->write(cwbuf_, cwbuf_size_ - wstream_->avail_out);
+ wstream_->next_out = cwbuf_;
+ wstream_->avail_out = cwbuf_size_;
+
+ // flush the transport
transport_->flush();
}
-void TZlibTransport::flushToZlib(const uint8_t* buf, int len, bool finish) {
- int flush = (finish ? Z_FINISH : Z_NO_FLUSH);
-
+void TZlibTransport::flushToZlib(const uint8_t* buf, int len, int flush) {
wstream_->next_in = const_cast<uint8_t*>(buf);
wstream_->avail_in = len;
- while (wstream_->avail_in > 0 || finish) {
+ while (true) {
+ if (flush == Z_NO_FLUSH && wstream_->avail_in == 0) {
+ break;
+ }
+
// If our ouput buffer is full, flush to the underlying transport.
if (wstream_->avail_out == 0) {
transport_->write(cwbuf_, cwbuf_size_);
@@ -239,12 +279,18 @@
int zlib_rv = deflate(wstream_, flush);
- if (finish && zlib_rv == Z_STREAM_END) {
+ if (flush == Z_FINISH && zlib_rv == Z_STREAM_END) {
assert(wstream_->avail_in == 0);
+ output_finished_ = true;
break;
}
checkZlibRv(zlib_rv, wstream_->msg);
+
+ if ((flush == Z_SYNC_FLUSH || flush == Z_FULL_FLUSH) &&
+ wstream_->avail_in == 0 && wstream_->avail_out != 0) {
+ break;
+ }
}
}
diff --git a/lib/cpp/src/transport/TZlibTransport.h b/lib/cpp/src/transport/TZlibTransport.h
index 61c43fe..0f9815e 100644
--- a/lib/cpp/src/transport/TZlibTransport.h
+++ b/lib/cpp/src/transport/TZlibTransport.h
@@ -96,7 +96,7 @@
urpos_(0),
uwpos_(0),
input_ended_(false),
- output_flushed_(false),
+ output_finished_(false),
urbuf_size_(urbuf_size),
crbuf_size_(crbuf_size),
uwbuf_size_(uwbuf_size),
@@ -145,6 +145,13 @@
// Don't call this outside of the constructor.
void initZlib();
+ /**
+ * TZlibTransport destructor.
+ *
+ * Warning: Destroying a TZlibTransport object may discard any written but
+ * unflushed data. You must explicitly call flush() or finish() to ensure
+ * that data is actually written and flushed to the underlying transport.
+ */
~TZlibTransport();
bool isOpen();
@@ -163,10 +170,25 @@
void flush();
+ /**
+ * Finalize the zlib stream.
+ *
+ * This causes zlib to flush any pending write data and write end-of-stream
+ * information, including the checksum. Once finish() has been called, no
+ * new data can be written to the stream.
+ */
+ void finish();
+
const uint8_t* borrow(uint8_t* buf, uint32_t* len);
void consume(uint32_t len);
+ /**
+ * Verify the checksum at the end of the zlib stream.
+ *
+ * This may only be called after all data has been read.
+ * It verifies the checksum that was written by the finish() call.
+ */
void verifyChecksum();
/**
@@ -182,7 +204,8 @@
inline void checkZlibRv(int status, const char* msg);
inline void checkZlibRvNothrow(int status, const char* msg);
inline int readAvail();
- void flushToZlib(const uint8_t* buf, int len, bool finish = false);
+ void flushToTransport(int flush);
+ void flushToZlib(const uint8_t* buf, int len, int flush);
// Writes smaller than this are buffered up.
// Larger (or equal) writes are dumped straight to zlib.
@@ -197,9 +220,9 @@
/// True iff zlib has reached the end of a stream.
/// This is only ever true in standalone protcol objects.
bool input_ended_;
- /// True iff we have flushed the output stream.
+ /// True iff we have finished the output stream.
/// This is only ever true in standalone protcol objects.
- bool output_flushed_;
+ bool output_finished_;
int urbuf_size_;
int crbuf_size_;
diff --git a/lib/cpp/test/TransportTest.cpp b/lib/cpp/test/TransportTest.cpp
index e5ddeee..7f95e38 100644
--- a/lib/cpp/test/TransportTest.cpp
+++ b/lib/cpp/test/TransportTest.cpp
@@ -36,6 +36,7 @@
#include <transport/TBufferTransports.h>
#include <transport/TFDTransport.h>
#include <transport/TFileTransport.h>
+#include <transport/TZlibTransport.h>
using namespace apache::thrift::transport;
@@ -178,6 +179,22 @@
boost::shared_ptr<TMemoryBuffer> buf;
};
+class CoupledZlibTransports : public CoupledTransports<TZlibTransport> {
+ public:
+ CoupledZlibTransports() :
+ buf(new TMemoryBuffer) {
+ in = new TZlibTransport(buf, false);
+ out = new TZlibTransport(buf, false);
+ }
+
+ ~CoupledZlibTransports() {
+ delete in;
+ delete out;
+ }
+
+ boost::shared_ptr<TMemoryBuffer> buf;
+};
+
class CoupledFDTransports : public CoupledTransports<TFDTransport> {
public:
CoupledFDTransports() {
@@ -363,7 +380,16 @@
read_size = rchunk_size - chunk_read;
}
- int bytes_read = transports.in->read(rbuf.get() + total_read, read_size);
+ int bytes_read = -1;
+ try {
+ bytes_read = transports.in->read(rbuf.get() + total_read, read_size);
+ } catch (TTransportException& e) {
+ BOOST_FAIL("read(pos=" << total_read << ", size=" << read_size <<
+ ") threw exception \"" << e.what() <<
+ "\"; written so far: " << total_written << " / " <<
+ totalSize << " bytes");
+ }
+
BOOST_REQUIRE_MESSAGE(bytes_read > 0,
"read(pos=" << total_read << ", size=" <<
read_size << ") returned " << bytes_read <<
@@ -449,6 +475,17 @@
BUFFER_TESTS(CoupledBufferedTransports)
BUFFER_TESTS(CoupledFramedTransports)
+ TEST_RW(CoupledZlibTransports, 1024*1024*10, 0, 0);
+ TEST_RW(CoupledZlibTransports, 1024*1024*10, rand4k, rand4k);
+ TEST_RW(CoupledZlibTransports, 1024*1024*5, 167, 163);
+ TEST_RW(CoupledZlibTransports, 1024*64, 1, 1);
+
+ TEST_RW(CoupledZlibTransports, 1024*1024*10, 0, 0, rand4k, rand4k);
+ TEST_RW(CoupledZlibTransports, 1024*1024*10,
+ rand4k, rand4k, rand4k, rand4k);
+ TEST_RW(CoupledZlibTransports, 1024*1024*5, 167, 163, rand4k, rand4k);
+ TEST_RW(CoupledZlibTransports, 1024*64, 1, 1, rand4k, rand4k);
+
// TFDTransport tests
// Since CoupledFDTransports tests with a pipe, writes will block
// if there is too much outstanding unread data in the pipe.
diff --git a/lib/cpp/test/ZlibTest.cpp b/lib/cpp/test/ZlibTest.cpp
index e952e71..e2403d7 100644
--- a/lib/cpp/test/ZlibTest.cpp
+++ b/lib/cpp/test/ZlibTest.cpp
@@ -145,7 +145,7 @@
shared_ptr<TMemoryBuffer> membuf(new TMemoryBuffer());
shared_ptr<TZlibTransport> zlib_trans(new TZlibTransport(membuf, false));
zlib_trans->write(buf, buf_len);
- zlib_trans->flush();
+ zlib_trans->finish();
boost::shared_array<uint8_t> mirror(new uint8_t[buf_len]);
uint32_t got = zlib_trans->read(mirror.get(), buf_len);
@@ -164,7 +164,7 @@
shared_ptr<TMemoryBuffer> membuf(new TMemoryBuffer());
shared_ptr<TZlibTransport> zlib_trans(new TZlibTransport(membuf, false));
zlib_trans->write(buf, buf_len);
- zlib_trans->flush();
+ zlib_trans->finish();
string tmp_buf;
membuf->appendBufferToString(tmp_buf);
zlib_trans.reset(new TZlibTransport(membuf, false,
@@ -184,7 +184,7 @@
shared_ptr<TMemoryBuffer> membuf(new TMemoryBuffer());
shared_ptr<TZlibTransport> zlib_trans(new TZlibTransport(membuf, false));
zlib_trans->write(buf, buf_len);
- zlib_trans->flush();
+ zlib_trans->finish();
string tmp_buf;
membuf->appendBufferToString(tmp_buf);
tmp_buf.erase(tmp_buf.length() - 1);
@@ -222,7 +222,7 @@
tot += write_len;
}
- zlib_trans->flush();
+ zlib_trans->finish();
tot = 0;
boost::shared_array<uint8_t> mirror(new uint8_t[buf_len]);
@@ -246,7 +246,7 @@
shared_ptr<TMemoryBuffer> membuf(new TMemoryBuffer());
shared_ptr<TZlibTransport> zlib_trans(new TZlibTransport(membuf, false));
zlib_trans->write(buf, buf_len);
- zlib_trans->flush();
+ zlib_trans->finish();
string tmp_buf;
membuf->appendBufferToString(tmp_buf);
// Modify a byte at the end of the buffer (part of the checksum).
@@ -279,6 +279,54 @@
}
}
+void test_write_after_flush(const uint8_t* buf, uint32_t buf_len) {
+ // write some data
+ shared_ptr<TMemoryBuffer> membuf(new TMemoryBuffer());
+ shared_ptr<TZlibTransport> zlib_trans(new TZlibTransport(membuf, false));
+ zlib_trans->write(buf, buf_len);
+
+ // call finish()
+ zlib_trans->finish();
+
+ // make sure write() throws an error
+ try {
+ uint8_t write_buf[] = "a";
+ zlib_trans->write(write_buf, 1);
+ BOOST_ERROR("write() after finish() did not raise an exception");
+ } catch (TTransportException& ex) {
+ BOOST_CHECK_EQUAL(ex.getType(), TTransportException::BAD_ARGS);
+ }
+
+ // make sure flush() throws an error
+ try {
+ zlib_trans->flush();
+ BOOST_ERROR("flush() after finish() did not raise an exception");
+ } catch (TTransportException& ex) {
+ BOOST_CHECK_EQUAL(ex.getType(), TTransportException::BAD_ARGS);
+ }
+
+ // make sure finish() throws an error
+ try {
+ zlib_trans->finish();
+ BOOST_ERROR("finish() after finish() did not raise an exception");
+ } catch (TTransportException& ex) {
+ BOOST_CHECK_EQUAL(ex.getType(), TTransportException::BAD_ARGS);
+ }
+}
+
+void test_no_write() {
+ // Verify that no data is written to the underlying transport if we
+ // never write data to the TZlibTransport.
+ shared_ptr<TMemoryBuffer> membuf(new TMemoryBuffer());
+ {
+ // Create a TZlibTransport object, and immediately destroy it
+ // when it goes out of scope.
+ TZlibTransport w_zlib_trans(membuf, false);
+ }
+
+ BOOST_CHECK_EQUAL(membuf->available_read(), 0);
+}
+
/*
* Initialization
*/
@@ -301,6 +349,7 @@
ADD_TEST_CASE(suite, name, test_separate_checksum, buf, buf_len);
ADD_TEST_CASE(suite, name, test_incomplete_checksum, buf, buf_len);
ADD_TEST_CASE(suite, name, test_invalid_checksum, buf, buf_len);
+ ADD_TEST_CASE(suite, name, test_write_after_flush, buf, buf_len);
shared_ptr<SizeGenerator> size_32k(new ConstantSizeGenerator(1<<15));
shared_ptr<SizeGenerator> size_lognormal(new LogNormalSizeGenerator(20, 30));
@@ -397,5 +446,7 @@
add_tests(suite, gen_compressible_buffer(buf_len), buf_len, "compressible");
add_tests(suite, gen_random_buffer(buf_len), buf_len, "random");
+ suite->add(BOOST_TEST_CASE(test_no_write));
+
return suite;
}