Add SnappyInputStream implementation.

This is to support snappy compression/decompression on new version of snapshot dataset, as the dataset uses input streams instead.

PiperOrigin-RevId: 317435680
Change-Id: Ie57b43e73b6b7911398d883c3c5a0de72973288e
This commit is contained in:
Frank Chen 2020-06-19 22:49:37 -07:00 committed by TensorFlower Gardener
parent 7f2bfd5709
commit 27cb9aa834
5 changed files with 427 additions and 15 deletions

View File

@ -1865,6 +1865,7 @@ cc_library(
"//tensorflow/core/lib/io:record_reader",
"//tensorflow/core/lib/io:record_writer",
"//tensorflow/core/lib/io:snappy_inputbuffer",
"//tensorflow/core/lib/io:snappy_inputstream",
"//tensorflow/core/lib/io:snappy_outputbuffer",
"//tensorflow/core/lib/io:table",
"//tensorflow/core/lib/io:table_options",

View File

@ -208,6 +208,19 @@ cc_library(
alwayslink = True,
)
cc_library(
name = "snappy_inputstream",
srcs = ["snappy/snappy_inputstream.cc"],
hdrs = ["snappy/snappy_inputstream.h"],
deps = [
":inputstream_interface",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:platform_port",
"@com_google_absl//absl/memory",
],
alwayslink = True,
)
cc_library(
name = "cache",
srcs = [
@ -354,6 +367,7 @@ filegroup(
"record_reader.h",
"record_writer.h",
"snappy/snappy_inputbuffer.h",
"snappy/snappy_inputstream.h",
"snappy/snappy_outputbuffer.h",
"table.h",
"table_builder.h",
@ -377,7 +391,7 @@ filegroup(
"random_inputstream_test.cc",
"record_reader_writer_test.cc",
"recordio_test.cc",
"snappy/snappy_buffers_test.cc",
"snappy/snappy_test.cc",
"table_test.cc",
"zlib_buffers_test.cc",
],
@ -409,6 +423,7 @@ filegroup(
"inputbuffer.h",
"iterator.h",
"snappy/snappy_inputbuffer.h",
"snappy/snappy_inputstream.h",
"snappy/snappy_outputbuffer.h",
"zlib_compression_options.h",
"zlib_inputstream.h",

View File

@ -0,0 +1,153 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/lib/io/snappy/snappy_inputstream.h"
#include "absl/memory/memory.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/snappy.h"
namespace tensorflow {
namespace io {
SnappyInputStream::SnappyInputStream(InputStreamInterface* input_stream,
size_t output_buffer_bytes,
bool owns_input_stream)
: input_stream_(input_stream),
output_buffer_bytes_(output_buffer_bytes),
owns_input_stream_(owns_input_stream),
bytes_read_(0),
output_buffer_(new char[output_buffer_bytes]),
next_out_(nullptr),
avail_out_(0) {}
SnappyInputStream::SnappyInputStream(InputStreamInterface* input_stream,
size_t output_buffer_bytes)
: SnappyInputStream(input_stream, output_buffer_bytes, false) {}
SnappyInputStream::~SnappyInputStream() {
if (owns_input_stream_) {
delete input_stream_;
}
}
Status SnappyInputStream::ReadNBytes(int64 bytes_to_read, tstring* result) {
result->clear();
result->resize_uninitialized(bytes_to_read);
char* result_ptr = result->mdata();
// Read as many bytes as possible from the cache.
size_t bytes_read = ReadBytesFromCache(bytes_to_read, result_ptr);
bytes_to_read -= bytes_read;
result_ptr += bytes_read;
while (bytes_to_read > 0) {
DCHECK_EQ(avail_out_, 0);
// Fill the cache with more data.
TF_RETURN_IF_ERROR(Inflate());
size_t bytes_read = ReadBytesFromCache(bytes_to_read, result_ptr);
bytes_to_read -= bytes_read;
result_ptr += bytes_read;
}
return Status::OK();
}
#if defined(PLATFORM_GOOGLE)
Status SnappyInputStream::ReadNBytes(int64 bytes_to_read, absl::Cord* result) {
// TODO(frankchn): Optimize this instead of bouncing through the buffer.
tstring buf;
TF_RETURN_IF_ERROR(ReadNBytes(bytes_to_read, &buf));
result->Clear();
result->Append(buf.data());
return Status::OK();
}
#endif
Status SnappyInputStream::Inflate() {
tstring compressed_block_length_ts;
uint32 compressed_block_length;
TF_RETURN_IF_ERROR(
input_stream_->ReadNBytes(sizeof(uint32), &compressed_block_length_ts));
for (int i = 0; i < sizeof(uint32); ++i) {
compressed_block_length =
(compressed_block_length << 8) |
static_cast<unsigned char>(compressed_block_length_ts.data()[i]);
}
tstring compressed_block;
compressed_block.resize_uninitialized(compressed_block_length);
Status s =
input_stream_->ReadNBytes(compressed_block_length, &compressed_block);
if (errors::IsOutOfRange(s)) {
return errors::DataLoss("Failed to read ", compressed_block_length,
" bytes from file. Possible data corruption.");
}
TF_RETURN_IF_ERROR(s);
size_t uncompressed_length;
if (!port::Snappy_GetUncompressedLength(compressed_block.data(),
compressed_block_length,
&uncompressed_length)) {
return errors::DataLoss("Parsing error in Snappy_GetUncompressedLength");
}
DCHECK_EQ(avail_out_, 0);
if (output_buffer_bytes_ < uncompressed_length) {
return errors::ResourceExhausted(
"Output buffer(size: ", output_buffer_bytes_,
" bytes"
") too small. Should be larger than ",
uncompressed_length, " bytes.");
}
next_out_ = output_buffer_.get();
if (!port::Snappy_Uncompress(compressed_block.data(), compressed_block_length,
output_buffer_.get())) {
return errors::DataLoss("Snappy_Uncompress failed.");
}
avail_out_ += uncompressed_length;
return Status::OK();
}
size_t SnappyInputStream::ReadBytesFromCache(size_t bytes_to_read,
char* result) {
size_t can_read_bytes = std::min(bytes_to_read, avail_out_);
if (can_read_bytes) {
memcpy(result, next_out_, can_read_bytes);
next_out_ += can_read_bytes;
avail_out_ -= can_read_bytes;
}
bytes_read_ += can_read_bytes;
return can_read_bytes;
}
int64 SnappyInputStream::Tell() const { return bytes_read_; }
Status SnappyInputStream::Reset() {
TF_RETURN_IF_ERROR(input_stream_->Reset());
avail_out_ = 0;
bytes_read_ = 0;
return Status::OK();
}
} // namespace io
} // namespace tensorflow

View File

@ -0,0 +1,89 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_LIB_IO_SNAPPY_SNAPPY_INPUTSTREAM_H_
#define TENSORFLOW_CORE_LIB_IO_SNAPPY_SNAPPY_INPUTSTREAM_H_
#include "tensorflow/core/lib/io/inputstream_interface.h"
namespace tensorflow {
namespace io {
class SnappyInputStream : public InputStreamInterface {
public:
// Creates a SnappyInputStream for `input_stream`.
//
// Takes ownership of `input_stream` iff `owns_input_stream` is true.
SnappyInputStream(InputStreamInterface* input_stream,
size_t output_buffer_bytes, bool owns_input_stream);
// Equivalent to the previous constructor with owns_input_stream = false.
explicit SnappyInputStream(InputStreamInterface* input_stream,
size_t output_buffer_bytes);
~SnappyInputStream() override;
// Reads bytes_to_read bytes into *result, overwriting *result.
//
// Return Status codes:
// OK: If successful.
// OUT_OF_RANGE: If there are not enough bytes to read before
// the end of the stream.
// ABORTED: If inflate() fails, we return the error code with the
// error message in `z_stream_->msg`.
// others: If reading from stream failed.
Status ReadNBytes(int64 bytes_to_read, tstring* result) override;
#if defined(PLATFORM_GOOGLE)
Status ReadNBytes(int64 bytes_to_read, absl::Cord* result) override;
#endif
int64 Tell() const override;
Status Reset() override;
private:
// Decompress the next chunk of data and place the data into the cache.
Status Inflate();
// Attempt to read `bytes_to_read` from the decompressed data cache. Returns
// the actual number of bytes read.
size_t ReadBytesFromCache(size_t bytes_to_read, char* result);
InputStreamInterface* input_stream_;
const size_t output_buffer_bytes_;
const bool owns_input_stream_;
// Specifies the number of decompressed bytes currently read.
int64 bytes_read_;
// output_buffer_ contains decompressed data not yet read by the client.
std::unique_ptr<char[]> output_buffer_;
// next_out_ points to the position in the `output_buffer_` that contains the
// next unread byte.
char* next_out_;
// avail_out_ specifies the number of bytes left in the output_buffers_ that
// is not yet read.
size_t avail_out_;
TF_DISALLOW_COPY_AND_ASSIGN(SnappyInputStream);
};
} // namespace io
} // namespace tensorflow
#endif // TENSORFLOW_CORE_LIB_IO_SNAPPY_SNAPPY_INPUTSTREAM_H_

View File

@ -15,7 +15,9 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/inputbuffer.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
#include "tensorflow/core/lib/io/snappy/snappy_inputbuffer.h"
#include "tensorflow/core/lib/io/snappy/snappy_inputstream.h"
#include "tensorflow/core/lib/io/snappy/snappy_outputbuffer.h"
namespace tensorflow {
@ -50,18 +52,17 @@ static string GenTestString(int copies = 1) {
return result;
}
Status TestMultipleWrites(size_t compress_input_buf_size,
size_t compress_output_buf_size,
size_t uncompress_input_buf_size,
size_t uncompress_output_buf_size, int num_writes = 1,
bool with_flush = false, int num_copies = 1,
bool corrupt_compressed_file = false) {
Status TestMultipleWritesWriteFile(size_t compress_input_buf_size,
size_t compress_output_buf_size,
int num_writes, bool with_flush,
int num_copies, bool corrupt_compressed_file,
string& fname, string& data,
string& expected_result) {
Env* env = Env::Default();
string fname = testing::TmpDir() + "/snappy_buffers_test";
string data = GenTestString(num_copies);
fname = testing::TmpDir() + "/snappy_buffers_test";
data = GenTestString(num_copies);
std::unique_ptr<WritableFile> file_writer;
string expected_result;
TF_RETURN_IF_ERROR(env->NewWritableFile(fname, &file_writer));
io::SnappyOutputBuffer out(file_writer.get(), compress_input_buf_size,
@ -112,6 +113,25 @@ Status TestMultipleWrites(size_t compress_input_buf_size,
fname = corrupt_fname;
}
return Status::OK();
}
Status TestMultipleWrites(size_t compress_input_buf_size,
size_t compress_output_buf_size,
size_t uncompress_input_buf_size,
size_t uncompress_output_buf_size, int num_writes = 1,
bool with_flush = false, int num_copies = 1,
bool corrupt_compressed_file = false) {
Env* env = Env::Default();
string expected_result;
string fname;
string data;
TF_RETURN_IF_ERROR(TestMultipleWritesWriteFile(
compress_input_buf_size, compress_output_buf_size, num_writes, with_flush,
num_copies, corrupt_compressed_file, fname, data, expected_result));
std::unique_ptr<RandomAccessFile> file_reader;
TF_RETURN_IF_ERROR(env->NewRandomAccessFile(fname, &file_reader));
io::SnappyInputBuffer in(file_reader.get(), uncompress_input_buf_size,
@ -131,15 +151,56 @@ Status TestMultipleWrites(size_t compress_input_buf_size,
}
TF_RETURN_IF_ERROR(in.Reset());
}
return Status::OK();
}
void TestTell(size_t compress_input_buf_size, size_t compress_output_buf_size,
size_t uncompress_input_buf_size,
size_t uncompress_output_buf_size, int num_copies = 1) {
Status TestMultipleWritesInputStream(
size_t compress_input_buf_size, size_t compress_output_buf_size,
size_t uncompress_input_buf_size, size_t uncompress_output_buf_size,
int num_writes = 1, bool with_flush = false, int num_copies = 1,
bool corrupt_compressed_file = false) {
Env* env = Env::Default();
string fname = testing::TmpDir() + "/snappy_buffers_test";
string data = GenTestString(num_copies);
string expected_result;
string fname;
string data;
TF_RETURN_IF_ERROR(TestMultipleWritesWriteFile(
compress_input_buf_size, compress_output_buf_size, num_writes, with_flush,
num_copies, corrupt_compressed_file, fname, data, expected_result));
std::unique_ptr<RandomAccessFile> file_reader;
TF_RETURN_IF_ERROR(env->NewRandomAccessFile(fname, &file_reader));
io::RandomAccessInputStream random_input_stream(file_reader.get(), false);
io::SnappyInputStream snappy_input_stream(&random_input_stream,
uncompress_output_buf_size);
for (int attempt = 0; attempt < 2; ++attempt) {
string actual_result;
for (int i = 0; i < num_writes; ++i) {
tstring decompressed_output;
TF_RETURN_IF_ERROR(
snappy_input_stream.ReadNBytes(data.size(), &decompressed_output));
strings::StrAppend(&actual_result, decompressed_output);
}
if (actual_result.compare(expected_result)) {
return errors::DataLoss("Actual and expected results don't match.");
}
TF_RETURN_IF_ERROR(snappy_input_stream.Reset());
}
return Status::OK();
}
void TestTellWriteFile(size_t compress_input_buf_size,
size_t compress_output_buf_size,
size_t uncompress_input_buf_size,
size_t uncompress_output_buf_size, int num_copies,
string& fname, string& data) {
Env* env = Env::Default();
fname = testing::TmpDir() + "/snappy_buffers_test";
data = GenTestString(num_copies);
// Write the compressed file.
std::unique_ptr<WritableFile> file_writer;
@ -150,6 +211,18 @@ void TestTell(size_t compress_input_buf_size, size_t compress_output_buf_size,
TF_CHECK_OK(out.Flush());
TF_CHECK_OK(file_writer->Flush());
TF_CHECK_OK(file_writer->Close());
}
void TestTell(size_t compress_input_buf_size, size_t compress_output_buf_size,
size_t uncompress_input_buf_size,
size_t uncompress_output_buf_size, int num_copies = 1) {
Env* env = Env::Default();
string data;
string fname;
TestTellWriteFile(compress_input_buf_size, compress_output_buf_size,
uncompress_input_buf_size, uncompress_output_buf_size,
num_copies, fname, data);
tstring first_half(string(data, 0, data.size() / 2));
tstring bytes_read;
@ -175,6 +248,43 @@ void TestTell(size_t compress_input_buf_size, size_t compress_output_buf_size,
EXPECT_EQ(bytes_read, data);
}
void TestTellInputStream(size_t compress_input_buf_size,
size_t compress_output_buf_size,
size_t uncompress_input_buf_size,
size_t uncompress_output_buf_size,
int num_copies = 1) {
Env* env = Env::Default();
string data;
string fname;
TestTellWriteFile(compress_input_buf_size, compress_output_buf_size,
uncompress_input_buf_size, uncompress_output_buf_size,
num_copies, fname, data);
tstring first_half(string(data, 0, data.size() / 2));
tstring bytes_read;
std::unique_ptr<RandomAccessFile> file_reader;
TF_CHECK_OK(env->NewRandomAccessFile(fname, &file_reader));
io::RandomAccessInputStream random_input_stream(file_reader.get(), false);
io::SnappyInputStream in(&random_input_stream, uncompress_output_buf_size);
// Read the first half of the uncompressed file and expect that Tell()
// returns half the uncompressed length of the file.
TF_CHECK_OK(in.ReadNBytes(first_half.size(), &bytes_read));
EXPECT_EQ(in.Tell(), first_half.size());
EXPECT_EQ(bytes_read, first_half);
// Read the remaining half of the uncompressed file and expect that
// Tell() points past the end of file.
tstring second_half;
TF_CHECK_OK(in.ReadNBytes(data.size() - first_half.size(), &second_half));
EXPECT_EQ(in.Tell(), data.size());
bytes_read.append(second_half);
// Expect that the file is correctly read.
EXPECT_EQ(bytes_read, data);
}
static bool SnappyCompressionSupported() {
string out;
StringPiece in = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
@ -187,6 +297,7 @@ TEST(SnappyBuffers, MultipleWritesWithoutFlush) {
return;
}
TF_CHECK_OK(TestMultipleWrites(10000, 10000, 10000, 10000, 2));
TF_CHECK_OK(TestMultipleWritesInputStream(10000, 10000, 10000, 10000, 2));
}
TEST(SnappyBuffers, MultipleWriteCallsWithFlush) {
@ -195,6 +306,8 @@ TEST(SnappyBuffers, MultipleWriteCallsWithFlush) {
return;
}
TF_CHECK_OK(TestMultipleWrites(10000, 10000, 10000, 10000, 2, true));
TF_CHECK_OK(
TestMultipleWritesInputStream(10000, 10000, 10000, 10000, 2, true));
}
TEST(SnappyBuffers, SmallUncompressInputBuffer) {
@ -208,6 +321,17 @@ TEST(SnappyBuffers, SmallUncompressInputBuffer) {
COMPRESSED_RECORD_SIZE, " bytes."));
}
TEST(SnappyBuffers, SmallUncompressInputStream) {
if (!SnappyCompressionSupported()) {
fprintf(stderr, "skipping compression tests\n");
return;
}
CHECK_EQ(TestMultipleWritesInputStream(10000, 10000, 10000, 10, 2, true),
errors::ResourceExhausted(
"Output buffer(size: 10 bytes) too small. ",
"Should be larger than ", GetRecord().size(), " bytes."));
}
TEST(SnappyBuffers, CorruptBlock) {
if (!SnappyCompressionSupported()) {
fprintf(stderr, "skipping compression tests\n");
@ -218,6 +342,17 @@ TEST(SnappyBuffers, CorruptBlock) {
" bytes from file. ", "Possible data corruption."));
}
TEST(SnappyBuffers, CorruptBlockInputStream) {
if (!SnappyCompressionSupported()) {
fprintf(stderr, "skipping compression tests\n");
return;
}
CHECK_EQ(
TestMultipleWritesInputStream(10000, 10000, 700, 10000, 2, true, 1, true),
errors::DataLoss("Failed to read ", COMPRESSED_RECORD_SIZE,
" bytes from file. ", "Possible data corruption."));
}
TEST(SnappyBuffers, CorruptBlockLargeInputBuffer) {
if (!SnappyCompressionSupported()) {
fprintf(stderr, "skipping compression tests\n");
@ -227,6 +362,17 @@ TEST(SnappyBuffers, CorruptBlockLargeInputBuffer) {
errors::OutOfRange("EOF reached"));
}
TEST(SnappyBuffers, CorruptBlockLargeInputStream) {
if (!SnappyCompressionSupported()) {
fprintf(stderr, "skipping compression tests\n");
return;
}
CHECK_EQ(TestMultipleWritesInputStream(10000, 10000, 2000, 10000, 2, true, 1,
true),
errors::DataLoss("Failed to read ", COMPRESSED_RECORD_SIZE,
" bytes from file. Possible data corruption."));
}
TEST(SnappyBuffers, Tell) {
if (!SnappyCompressionSupported()) {
fprintf(stderr, "skipping compression tests\n");
@ -235,4 +381,12 @@ TEST(SnappyBuffers, Tell) {
TestTell(10000, 10000, 2000, 10000, 2);
}
TEST(SnappyBuffers, TellInputStream) {
if (!SnappyCompressionSupported()) {
fprintf(stderr, "skipping compression tests\n");
return;
}
TestTellInputStream(10000, 10000, 2000, 10000, 2);
}
} // namespace tensorflow