Added a check for output_buffer_size <= 1 for ZlibOutputBuffer. Also added some tests for Zlib compression reading / writing.

Change: 132370925
This commit is contained in:
Rohan Jain 2016-09-06 14:39:08 -08:00 committed by TensorFlower Gardener
parent d7bc08fd5f
commit 680966059e
5 changed files with 83 additions and 19 deletions

View File

@ -67,4 +67,42 @@ TEST(RecordReaderWriterTest, TestBasics) {
}
}
// Round-trips two records through a zlib-compressed RecordWriter and
// RecordReader for every buffer size returned by BufferSizes(), using the same
// size for the writer's output buffer and the reader's input buffer.
TEST(RecordReaderWriterTest, TestZlib) {
  Env* env = Env::Default();
  string fname = testing::TmpDir() + "/record_reader_writer_zlib_test";

  for (auto buf_size : BufferSizes()) {
    // Zlib compression needs output buffer size > 1.
    if (buf_size == 1) continue;
    {
      std::unique_ptr<WritableFile> file;
      TF_CHECK_OK(env->NewWritableFile(fname, &file));

      io::RecordWriterOptions options;
      options.compression_type = io::RecordWriterOptions::ZLIB_COMPRESSION;
      options.zlib_options.output_buffer_size = buf_size;
      io::RecordWriter writer(file.get(), options);
      // Check the write statuses so that a failed write is reported here
      // rather than as a confusing mismatch while reading back below.
      TF_CHECK_OK(writer.WriteRecord("abc"));
      TF_CHECK_OK(writer.WriteRecord("defg"));
      TF_CHECK_OK(writer.Flush());
    }

    {
      std::unique_ptr<RandomAccessFile> read_file;
      // Read it back with the RecordReader.
      TF_CHECK_OK(env->NewRandomAccessFile(fname, &read_file));
      io::RecordReaderOptions options;
      options.compression_type = io::RecordReaderOptions::ZLIB_COMPRESSION;
      options.zlib_options.input_buffer_size = buf_size;
      io::RecordReader reader(read_file.get(), options);
      uint64 offset = 0;
      string record;
      TF_CHECK_OK(reader.ReadRecord(&offset, &record));
      EXPECT_EQ("abc", record);
      TF_CHECK_OK(reader.ReadRecord(&offset, &record));
      EXPECT_EQ("defg", record);
    }
  }
}
} // namespace tensorflow

View File

@ -33,6 +33,11 @@ RecordWriter::RecordWriter(WritableFile* dest,
zlib_output_buffer_.reset(new ZlibOutputBuffer(
dest_, options.zlib_options.input_buffer_size,
options.zlib_options.output_buffer_size, options.zlib_options));
Status s = zlib_output_buffer_->Init();
if (!s.ok()) {
LOG(FATAL) << "Failed to initialize Zlib inputbuffer. Error: "
<< s.ToString();
}
#endif // IS_SLIM_BUILD
} else if (options.compression_type == RecordWriterOptions::NONE) {
// Nothing to do

View File

@ -73,6 +73,7 @@ void TestAllCombinations(CompressionOptions input_options,
ZlibOutputBuffer out(file_writer.get(), input_buf_size, output_buf_size,
output_options);
TF_CHECK_OK(out.Init());
TF_CHECK_OK(out.Write(StringPiece(data)));
TF_CHECK_OK(out.Close());
@ -120,6 +121,7 @@ void TestMultipleWrites(uint8 input_buf_size, uint8 output_buf_size,
TF_CHECK_OK(env->NewWritableFile(fname, &file_writer));
ZlibOutputBuffer out(file_writer.get(), input_buf_size, output_buf_size,
output_options);
TF_CHECK_OK(out.Init());
for (int i = 0; i < num_writes; i++) {
TF_CHECK_OK(out.Write(StringPiece(data)));
@ -172,6 +174,7 @@ TEST(ZlibInputStream, FailsToReadIfWindowBitsAreIncompatible) {
string result;
ZlibOutputBuffer out(file_writer.get(), input_buf_size, output_buf_size,
output_options);
TF_CHECK_OK(out.Init());
TF_CHECK_OK(out.Write(StringPiece(data)));
TF_CHECK_OK(out.Close());

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
namespace io {
@ -25,30 +27,13 @@ ZlibOutputBuffer::ZlibOutputBuffer(
const ZlibCompressionOptions&
zlib_options) // size of z_stream.next_out buffer
: file_(file),
init_status_(),
input_buffer_capacity_(input_buffer_bytes),
output_buffer_capacity_(output_buffer_bytes),
z_stream_input_(new Bytef[input_buffer_bytes]),
z_stream_output_(new Bytef[output_buffer_bytes]),
zlib_options_(zlib_options),
z_stream_(new z_stream) {
memset(z_stream_.get(), 0, sizeof(z_stream));
z_stream_->zalloc = Z_NULL;
z_stream_->zfree = Z_NULL;
z_stream_->opaque = Z_NULL;
int status =
deflateInit2(z_stream_.get(), zlib_options.compression_level,
zlib_options.compression_method, zlib_options.window_bits,
zlib_options.mem_level, zlib_options.compression_strategy);
if (status != Z_OK) {
LOG(FATAL) << "deflateInit failed with status " << status;
z_stream_.reset(NULL);
} else {
z_stream_->next_in = z_stream_input_.get();
z_stream_->next_out = z_stream_output_.get();
z_stream_->avail_in = 0;
z_stream_->avail_out = output_buffer_capacity_;
}
}
z_stream_(new z_stream) {}
ZlibOutputBuffer::~ZlibOutputBuffer() {
if (z_stream_.get()) {
@ -56,6 +41,33 @@ ZlibOutputBuffer::~ZlibOutputBuffer() {
}
}
// Validates the configured output buffer size and sets up the zlib deflate
// stream from `zlib_options_`.
//
// Returns InvalidArgument if the output buffer capacity is <= 1 or if
// deflateInit2 reports anything other than Z_OK (in which case `z_stream_` is
// released). On success the stream's input/output pointers are aimed at the
// internal buffers and OK is returned.
Status ZlibOutputBuffer::Init() {
  // Output buffer size should be greater than 1 because deflation needs at
  // least one byte for book keeping etc.
  if (output_buffer_capacity_ <= 1) {
    return errors::InvalidArgument(
        "output_buffer_bytes should be greater than "
        "1");
  }
  memset(z_stream_.get(), 0, sizeof(z_stream));
  z_stream_->zalloc = Z_NULL;
  z_stream_->zfree = Z_NULL;
  z_stream_->opaque = Z_NULL;
  int status =
      deflateInit2(z_stream_.get(), zlib_options_.compression_level,
                   zlib_options_.compression_method, zlib_options_.window_bits,
                   zlib_options_.mem_level, zlib_options_.compression_strategy);
  if (status != Z_OK) {
    z_stream_.reset(NULL);
    // Trailing space keeps the status code from running into the message text.
    return errors::InvalidArgument("deflateInit failed with status ", status);
  }
  // Start with an empty input buffer and a completely free output buffer.
  z_stream_->next_in = z_stream_input_.get();
  z_stream_->next_out = z_stream_output_.get();
  z_stream_->avail_in = 0;
  z_stream_->avail_out = output_buffer_capacity_;
  return Status::OK();
}
// Returns the input buffer capacity minus the number of bytes the deflate
// stream currently reports as pending input (`avail_in`).
int32 ZlibOutputBuffer::AvailableInputSpace() const {
  const auto pending_input = z_stream_->avail_in;
  return input_buffer_capacity_ - pending_input;
}

View File

@ -45,6 +45,7 @@ class ZlibOutputBuffer {
// 2. the deflated output
// with sizes `input_buffer_bytes` and `output_buffer_bytes` respectively.
// Does not take ownership of `file`.
// output_buffer_bytes should be greater than 1.
ZlibOutputBuffer(
WritableFile* file,
int32 input_buffer_bytes, // size of z_stream.next_in buffer
@ -53,6 +54,10 @@ class ZlibOutputBuffer {
~ZlibOutputBuffer();
// Initializes the state necessary for the output buffer, including the
// underlying zlib deflate stream. This call is required before any other
// operation on the buffer.
//
// Returns an InvalidArgument error if `output_buffer_bytes` is not greater
// than 1 or if zlib's deflateInit2 fails; returns OK otherwise.
Status Init();
// Adds `data` to the compression pipeline.
//
// The input data is buffered in `z_stream_input_` and is compressed in bulk
@ -78,6 +83,7 @@ class ZlibOutputBuffer {
private:
WritableFile* file_; // Not owned
Status init_status_;
size_t input_buffer_capacity_;
size_t output_buffer_capacity_;