diff --git a/tensorflow/core/lib/io/record_reader_writer_test.cc b/tensorflow/core/lib/io/record_reader_writer_test.cc
index a44c35d7fd1..0a656473e4b 100644
--- a/tensorflow/core/lib/io/record_reader_writer_test.cc
+++ b/tensorflow/core/lib/io/record_reader_writer_test.cc
@@ -67,4 +67,42 @@ TEST(RecordReaderWriterTest, TestBasics) {
   }
 }
 
+TEST(RecordReaderWriterTest, TestZlib) {
+  Env* env = Env::Default();
+  string fname = testing::TmpDir() + "/record_reader_writer_zlib_test";
+
+  for (auto buf_size : BufferSizes()) {
+    // Zlib compression needs output buffer size > 1.
+    if (buf_size == 1) continue;
+    {
+      std::unique_ptr<WritableFile> file;
+      TF_CHECK_OK(env->NewWritableFile(fname, &file));
+
+      io::RecordWriterOptions options;
+      options.compression_type = io::RecordWriterOptions::ZLIB_COMPRESSION;
+      options.zlib_options.output_buffer_size = buf_size;
+      io::RecordWriter writer(file.get(), options);
+      TF_CHECK_OK(writer.WriteRecord("abc"));
+      TF_CHECK_OK(writer.WriteRecord("defg"));
+      TF_CHECK_OK(writer.Flush());
+    }
+
+    {
+      std::unique_ptr<RandomAccessFile> read_file;
+      // Read it back with the RecordReader.
+      TF_CHECK_OK(env->NewRandomAccessFile(fname, &read_file));
+      io::RecordReaderOptions options;
+      options.compression_type = io::RecordReaderOptions::ZLIB_COMPRESSION;
+      options.zlib_options.input_buffer_size = buf_size;
+      io::RecordReader reader(read_file.get(), options);
+      uint64 offset = 0;
+      string record;
+      TF_CHECK_OK(reader.ReadRecord(&offset, &record));
+      EXPECT_EQ("abc", record);
+      TF_CHECK_OK(reader.ReadRecord(&offset, &record));
+      EXPECT_EQ("defg", record);
+    }
+  }
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc
index 25873b83ba3..516332d2b73 100644
--- a/tensorflow/core/lib/io/record_writer.cc
+++ b/tensorflow/core/lib/io/record_writer.cc
@@ -33,6 +33,11 @@ RecordWriter::RecordWriter(WritableFile* dest,
     zlib_output_buffer_.reset(new ZlibOutputBuffer(
         dest_, options.zlib_options.input_buffer_size,
         options.zlib_options.output_buffer_size, options.zlib_options));
+    Status s = zlib_output_buffer_->Init();
+    if (!s.ok()) {
+      LOG(FATAL) << "Failed to initialize Zlib outputbuffer. Error: "
+                 << s.ToString();
+    }
 #endif  // IS_SLIM_BUILD
   } else if (options.compression_type == RecordWriterOptions::NONE) {
     // Nothing to do
diff --git a/tensorflow/core/lib/io/zlib_buffers_test.cc b/tensorflow/core/lib/io/zlib_buffers_test.cc
index eaaf1497594..1290e98ce2c 100644
--- a/tensorflow/core/lib/io/zlib_buffers_test.cc
+++ b/tensorflow/core/lib/io/zlib_buffers_test.cc
@@ -73,6 +73,7 @@ void TestAllCombinations(CompressionOptions input_options,
 
         ZlibOutputBuffer out(file_writer.get(), input_buf_size,
                              output_buf_size, output_options);
+        TF_CHECK_OK(out.Init());
 
         TF_CHECK_OK(out.Write(StringPiece(data)));
         TF_CHECK_OK(out.Close());
@@ -120,6 +121,7 @@ void TestMultipleWrites(uint8 input_buf_size, uint8 output_buf_size,
   TF_CHECK_OK(env->NewWritableFile(fname, &file_writer));
   ZlibOutputBuffer out(file_writer.get(), input_buf_size, output_buf_size,
                        output_options);
+  TF_CHECK_OK(out.Init());
 
   for (int i = 0; i < num_writes; i++) {
     TF_CHECK_OK(out.Write(StringPiece(data)));
@@ -172,6 +174,7 @@ TEST(ZlibInputStream, FailsToReadIfWindowBitsAreIncompatible) {
   string result;
   ZlibOutputBuffer out(file_writer.get(), input_buf_size, output_buf_size,
                        output_options);
+  TF_CHECK_OK(out.Init());
 
   TF_CHECK_OK(out.Write(StringPiece(data)));
   TF_CHECK_OK(out.Close());
diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.cc b/tensorflow/core/lib/io/zlib_outputbuffer.cc
index 9493804bcb8..bdedfd00e86 100644
--- a/tensorflow/core/lib/io/zlib_outputbuffer.cc
+++ b/tensorflow/core/lib/io/zlib_outputbuffer.cc
@@ -15,6 +15,8 @@ limitations under the License.
 
 #include "tensorflow/core/lib/io/zlib_outputbuffer.h"
 
+#include "tensorflow/core/lib/core/errors.h"
+
 namespace tensorflow {
 namespace io {
 
@@ -25,30 +27,15 @@ ZlibOutputBuffer::ZlibOutputBuffer(
     const ZlibCompressionOptions&
         zlib_options)  // size of z_stream.next_out buffer
     : file_(file),
+      init_status_(),
       input_buffer_capacity_(input_buffer_bytes),
       output_buffer_capacity_(output_buffer_bytes),
       z_stream_input_(new Bytef[input_buffer_bytes]),
       z_stream_output_(new Bytef[output_buffer_bytes]),
       zlib_options_(zlib_options),
-      z_stream_(new z_stream) {
-  memset(z_stream_.get(), 0, sizeof(z_stream));
-  z_stream_->zalloc = Z_NULL;
-  z_stream_->zfree = Z_NULL;
-  z_stream_->opaque = Z_NULL;
-  int status =
-      deflateInit2(z_stream_.get(), zlib_options.compression_level,
-                   zlib_options.compression_method, zlib_options.window_bits,
-                   zlib_options.mem_level, zlib_options.compression_strategy);
-  if (status != Z_OK) {
-    LOG(FATAL) << "deflateInit failed with status " << status;
-    z_stream_.reset(NULL);
-  } else {
-    z_stream_->next_in = z_stream_input_.get();
-    z_stream_->next_out = z_stream_output_.get();
-    z_stream_->avail_in = 0;
-    z_stream_->avail_out = output_buffer_capacity_;
-  }
-}
+      // Value-initialize (zero) the stream so the destructor's deflateEnd call
+      // never reads indeterminate fields if Init() is skipped or fails early.
+      z_stream_(new z_stream()) {}
 
 ZlibOutputBuffer::~ZlibOutputBuffer() {
   if (z_stream_.get()) {
@@ -56,6 +43,33 @@ ZlibOutputBuffer::~ZlibOutputBuffer() {
   }
 }
 
+Status ZlibOutputBuffer::Init() {
+  // Output buffer size should be greater than 1 because deflation needs at
+  // least one byte for bookkeeping etc.
+  if (output_buffer_capacity_ <= 1) {
+    return errors::InvalidArgument(
+        "output_buffer_bytes should be greater than "
+        "1");
+  }
+  memset(z_stream_.get(), 0, sizeof(z_stream));
+  z_stream_->zalloc = Z_NULL;
+  z_stream_->zfree = Z_NULL;
+  z_stream_->opaque = Z_NULL;
+  int status =
+      deflateInit2(z_stream_.get(), zlib_options_.compression_level,
+                   zlib_options_.compression_method, zlib_options_.window_bits,
+                   zlib_options_.mem_level, zlib_options_.compression_strategy);
+  if (status != Z_OK) {
+    z_stream_.reset(NULL);
+    return errors::InvalidArgument("deflateInit failed with status ", status);
+  }
+  z_stream_->next_in = z_stream_input_.get();
+  z_stream_->next_out = z_stream_output_.get();
+  z_stream_->avail_in = 0;
+  z_stream_->avail_out = output_buffer_capacity_;
+  return Status::OK();
+}
+
 int32 ZlibOutputBuffer::AvailableInputSpace() const {
   return input_buffer_capacity_ - z_stream_->avail_in;
 }
diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.h b/tensorflow/core/lib/io/zlib_outputbuffer.h
index 08455b63b50..a53c40b8fbc 100644
--- a/tensorflow/core/lib/io/zlib_outputbuffer.h
+++ b/tensorflow/core/lib/io/zlib_outputbuffer.h
@@ -45,6 +45,7 @@ class ZlibOutputBuffer {
   // 2. the deflated output
   // with sizes `input_buffer_bytes` and `output_buffer_bytes` respectively.
   // Does not take ownership of `file`.
+  // output_buffer_bytes should be greater than 1.
   ZlibOutputBuffer(
       WritableFile* file,
       int32 input_buffer_bytes,  // size of z_stream.next_in buffer
@@ -53,6 +54,10 @@ class ZlibOutputBuffer {
 
   ~ZlibOutputBuffer();
 
+  // Initializes the zlib stream; must be called before any other operation.
+  // Returns InvalidArgument if output_buffer_bytes <= 1 or deflateInit2 fails.
+  Status Init();
+
   // Adds `data` to the compression pipeline.
   //
   // The input data is buffered in `z_stream_input_` and is compressed in bulk
@@ -78,6 +83,7 @@ class ZlibOutputBuffer {
 
  private:
   WritableFile* file_;  // Not owned
+  Status init_status_;
   size_t input_buffer_capacity_;
   size_t output_buffer_capacity_;
 