diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 0f709750897..695035c91e9 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1864,6 +1864,7 @@ cc_library( "//tensorflow/core/lib/io:random_inputstream", "//tensorflow/core/lib/io:record_reader", "//tensorflow/core/lib/io:record_writer", + "//tensorflow/core/lib/io:snappy_compression_options", "//tensorflow/core/lib/io:snappy_inputbuffer", "//tensorflow/core/lib/io:snappy_inputstream", "//tensorflow/core/lib/io:snappy_outputbuffer", diff --git a/tensorflow/core/kernels/data/experimental/snapshot_util_test.cc b/tensorflow/core/kernels/data/experimental/snapshot_util_test.cc index aedc0e194d7..e253014bf94 100644 --- a/tensorflow/core/kernels/data/experimental/snapshot_util_test.cc +++ b/tensorflow/core/kernels/data/experimental/snapshot_util_test.cc @@ -88,6 +88,7 @@ TEST(SnapshotUtilTest, CombinationRoundTripTest) { SnapshotRoundTrip(io::compression::kNone, 2); SnapshotRoundTrip(io::compression::kGzip, 2); + SnapshotRoundTrip(io::compression::kSnappy, 2); } void SnapshotReaderBenchmarkLoop(int iters, std::string compression_type, @@ -195,11 +196,16 @@ void SnapshotTFRecordWriterGzipBenchmark(int iters) { SnapshotWriterBenchmarkLoop(iters, io::compression::kGzip, 2); } +void SnapshotTFRecordWriterSnappyBenchmark(int iters) { + SnapshotWriterBenchmarkLoop(iters, io::compression::kSnappy, 2); +} + BENCHMARK(SnapshotCustomWriterNoneBenchmark); BENCHMARK(SnapshotCustomWriterGzipBenchmark); BENCHMARK(SnapshotCustomWriterSnappyBenchmark); BENCHMARK(SnapshotTFRecordWriterNoneBenchmark); BENCHMARK(SnapshotTFRecordWriterGzipBenchmark); +BENCHMARK(SnapshotTFRecordWriterSnappyBenchmark); } // namespace } // namespace snapshot_util diff --git a/tensorflow/core/lib/io/BUILD b/tensorflow/core/lib/io/BUILD index 5e1704a50c1..797e9ad1a4b 100644 --- a/tensorflow/core/lib/io/BUILD +++ b/tensorflow/core/lib/io/BUILD @@ -145,6 +145,8 @@ cc_library( ":compression", ":inputstream_interface", ":random_inputstream", + ":snappy_compression_options", + ":snappy_inputstream", ":zlib_compression_options", ":zlib_inputstream", "//tensorflow/core/lib/core:coding", @@ -164,6 +166,8 @@ cc_library( hdrs = ["record_writer.h"], deps = [ ":compression", + ":snappy_compression_options", + ":snappy_outputbuffer", ":zlib_compression_options", ":zlib_outputbuffer", "//tensorflow/core/lib/core:coding", @@ -221,6 +225,15 @@ cc_library( alwayslink = True, ) +cc_library( + name = "snappy_compression_options", + hdrs = ["snappy/snappy_compression_options.h"], + deps = [ + "//tensorflow/core/platform:types", + ], + alwayslink = True, +) + cc_library( name = "cache", srcs = [ @@ -336,6 +349,9 @@ filegroup( "random_inputstream.h", "record_reader.cc", "record_reader.h", + "snappy/snappy_compression_options.h", + "snappy/snappy_inputstream.cc", + "snappy/snappy_inputstream.h", "table.cc", "table.h", "table_builder.cc", @@ -366,6 +382,7 @@ filegroup( "random_inputstream.h", "record_reader.h", "record_writer.h", + "snappy/snappy_compression_options.h", "snappy/snappy_inputbuffer.h", "snappy/snappy_inputstream.h", "snappy/snappy_outputbuffer.h", @@ -422,6 +439,7 @@ filegroup( srcs = [ "inputbuffer.h", "iterator.h", + "snappy/snappy_compression_options.h", "snappy/snappy_inputbuffer.h", "snappy/snappy_inputstream.h", "snappy/snappy_outputbuffer.h", diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc index 1af81bd902c..40e516f5ef9 100644 --- a/tensorflow/core/lib/io/record_reader.cc +++ b/tensorflow/core/lib/io/record_reader.cc @@ -31,26 +31,26 @@ namespace io { RecordReaderOptions RecordReaderOptions::CreateRecordReaderOptions( const string& compression_type) { RecordReaderOptions options; + +#if defined(IS_SLIM_BUILD) + if (compression_type != compression::kNone) { + LOG(ERROR) << "Compression is not supported but compression_type is set." + << " No compression will be used."; + } +#else if (compression_type == compression::kZlib) { options.compression_type = io::RecordReaderOptions::ZLIB_COMPRESSION; -#if defined(IS_SLIM_BUILD) - LOG(ERROR) << "Compression is not supported but compression_type is set." - << " No compression will be used."; -#else options.zlib_options = io::ZlibCompressionOptions::DEFAULT(); -#endif // IS_SLIM_BUILD } else if (compression_type == compression::kGzip) { options.compression_type = io::RecordReaderOptions::ZLIB_COMPRESSION; -#if defined(IS_SLIM_BUILD) - LOG(ERROR) << "Compression is not supported but compression_type is set." - << " No compression will be used."; -#else options.zlib_options = io::ZlibCompressionOptions::GZIP(); -#endif // IS_SLIM_BUILD + } else if (compression_type == compression::kSnappy) { + options.compression_type = io::RecordReaderOptions::SNAPPY_COMPRESSION; } else if (compression_type != compression::kNone) { LOG(ERROR) << "Unsupported compression_type:" << compression_type << ". No compression will be used."; } +#endif return options; } @@ -63,20 +63,26 @@ RecordReader::RecordReader(RandomAccessFile* file, input_stream_.reset(new BufferedInputStream(input_stream_.release(), options.buffer_size, true)); } - if (options.compression_type == RecordReaderOptions::ZLIB_COMPRESSION) { -// We don't have zlib available on all embedded platforms, so fail. #if defined(IS_SLIM_BUILD) - LOG(FATAL) << "Zlib compression is unsupported on mobile platforms."; -#else // IS_SLIM_BUILD + if (options.compression_type != RecordReaderOptions::NONE) { + LOG(FATAL) << "Compression is unsupported on mobile platforms."; + } +#else + if (options.compression_type == RecordReaderOptions::ZLIB_COMPRESSION) { input_stream_.reset(new ZlibInputStream( input_stream_.release(), options.zlib_options.input_buffer_size, options.zlib_options.output_buffer_size, options.zlib_options, true)); -#endif // IS_SLIM_BUILD + } else if (options.compression_type == + RecordReaderOptions::SNAPPY_COMPRESSION) { + input_stream_.reset( + new SnappyInputStream(input_stream_.release(), + options.snappy_options.output_buffer_size, true)); } else if (options.compression_type == RecordReaderOptions::NONE) { // Nothing to do. } else { LOG(FATAL) << "Unrecognized compression type :" << options.compression_type; } +#endif } // Read n+4 bytes from file, verify that checksum of first n bytes is diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h index dd7def79f05..07709990a64 100644 --- a/tensorflow/core/lib/io/record_reader.h +++ b/tensorflow/core/lib/io/record_reader.h @@ -20,6 +20,8 @@ limitations under the License. #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/inputstream_interface.h" #if !defined(IS_SLIM_BUILD) +#include "tensorflow/core/lib/io/snappy/snappy_compression_options.h" +#include "tensorflow/core/lib/io/snappy/snappy_inputstream.h" #include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/core/lib/io/zlib_inputstream.h" #endif // IS_SLIM_BUILD @@ -32,9 +34,12 @@ class RandomAccessFile; namespace io { -class RecordReaderOptions { - public: - enum CompressionType { NONE = 0, ZLIB_COMPRESSION = 1 }; +struct RecordReaderOptions { + enum CompressionType { + NONE = 0, + ZLIB_COMPRESSION = 1, + SNAPPY_COMPRESSION = 2 + }; CompressionType compression_type = NONE; // If buffer_size is non-zero, then all reads must be sequential, and no @@ -46,8 +51,9 @@ class RecordReaderOptions { const string& compression_type); #if !defined(IS_SLIM_BUILD) - // Options specific to zlib compression. + // Options specific to compression. ZlibCompressionOptions zlib_options; + SnappyCompressionOptions snappy_options; #endif // IS_SLIM_BUILD }; diff --git a/tensorflow/core/lib/io/record_reader_writer_test.cc b/tensorflow/core/lib/io/record_reader_writer_test.cc index 373c0d8b664..486b238bd29 100644 --- a/tensorflow/core/lib/io/record_reader_writer_test.cc +++ b/tensorflow/core/lib/io/record_reader_writer_test.cc @@ -158,6 +158,44 @@ TEST(RecordReaderWriterTest, TestBasics) { } } +TEST(RecordReaderWriterTest, TestSnappy) { + Env* env = Env::Default(); + string fname = testing::TmpDir() + "/record_reader_writer_snappy_test"; + + for (auto buf_size : BufferSizes()) { + // Snappy compression needs output buffer size > 1. + if (buf_size == 1) continue; + { + std::unique_ptr file; + TF_CHECK_OK(env->NewWritableFile(fname, &file)); + + io::RecordWriterOptions options; + options.compression_type = io::RecordWriterOptions::SNAPPY_COMPRESSION; + options.zlib_options.output_buffer_size = buf_size; + io::RecordWriter writer(file.get(), options); + TF_EXPECT_OK(writer.WriteRecord("abc")); + TF_EXPECT_OK(writer.WriteRecord("defg")); + TF_CHECK_OK(writer.Flush()); + } + + { + std::unique_ptr read_file; + // Read it back with the RecordReader. + TF_CHECK_OK(env->NewRandomAccessFile(fname, &read_file)); + io::RecordReaderOptions options; + options.compression_type = io::RecordReaderOptions::SNAPPY_COMPRESSION; + options.zlib_options.input_buffer_size = buf_size; + io::RecordReader reader(read_file.get(), options); + uint64 offset = 0; + tstring record; + TF_CHECK_OK(reader.ReadRecord(&offset, &record)); + EXPECT_EQ("abc", record); + TF_CHECK_OK(reader.ReadRecord(&offset, &record)); + EXPECT_EQ("defg", record); + } + } +} + TEST(RecordReaderWriterTest, TestZlib) { Env* env = Env::Default(); string fname = testing::TmpDir() + "/record_reader_writer_zlib_test"; diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc index 52d0ef9a358..c82963d40c2 100644 --- a/tensorflow/core/lib/io/record_writer.cc +++ b/tensorflow/core/lib/io/record_writer.cc @@ -23,45 +23,49 @@ limitations under the License. namespace tensorflow { namespace io { namespace { -bool IsZlibCompressed(RecordWriterOptions options) { +bool IsZlibCompressed(const RecordWriterOptions& options) { return options.compression_type == RecordWriterOptions::ZLIB_COMPRESSION; } + +bool IsSnappyCompressed(const RecordWriterOptions& options) { + return options.compression_type == RecordWriterOptions::SNAPPY_COMPRESSION; +} } // namespace RecordWriterOptions RecordWriterOptions::CreateRecordWriterOptions( const string& compression_type) { RecordWriterOptions options; +#if defined(IS_SLIM_BUILD) + if (compression_type != compression::kNone) { + LOG(ERROR) << "Compression is not supported but compression_type is set." + << " No compression will be used."; + } +#else if (compression_type == compression::kZlib) { options.compression_type = io::RecordWriterOptions::ZLIB_COMPRESSION; -#if defined(IS_SLIM_BUILD) - LOG(ERROR) << "Compression is not supported but compression_type is set." - << " No compression will be used."; -#else options.zlib_options = io::ZlibCompressionOptions::DEFAULT(); -#endif // IS_SLIM_BUILD } else if (compression_type == compression::kGzip) { options.compression_type = io::RecordWriterOptions::ZLIB_COMPRESSION; -#if defined(IS_SLIM_BUILD) - LOG(ERROR) << "Compression is not supported but compression_type is set." - << " No compression will be used."; -#else options.zlib_options = io::ZlibCompressionOptions::GZIP(); -#endif // IS_SLIM_BUILD + } else if (compression_type == compression::kSnappy) { + options.compression_type = io::RecordWriterOptions::SNAPPY_COMPRESSION; } else if (compression_type != compression::kNone) { LOG(ERROR) << "Unsupported compression_type:" << compression_type << ". No compression will be used."; } +#endif return options; } RecordWriter::RecordWriter(WritableFile* dest, const RecordWriterOptions& options) : dest_(dest), options_(options) { - if (IsZlibCompressed(options)) { -// We don't have zlib available on all embedded platforms, so fail. #if defined(IS_SLIM_BUILD) - LOG(FATAL) << "Zlib compression is unsupported on mobile platforms."; -#else // IS_SLIM_BUILD + if (compression_type != compression::kNone) { + LOG(FATAL) << "Compression is unsupported on mobile platforms."; + } +#else + if (IsZlibCompressed(options)) { ZlibOutputBuffer* zlib_output_buffer = new ZlibOutputBuffer( dest, options.zlib_options.input_buffer_size, options.zlib_options.output_buffer_size, options.zlib_options); @@ -71,12 +75,16 @@ RecordWriter::RecordWriter(WritableFile* dest, << s.ToString(); } dest_ = zlib_output_buffer; -#endif // IS_SLIM_BUILD + } else if (IsSnappyCompressed(options)) { + dest_ = + new SnappyOutputBuffer(dest, options.snappy_options.input_buffer_size, + options.snappy_options.output_buffer_size); } else if (options.compression_type == RecordWriterOptions::NONE) { // Nothing to do } else { LOG(FATAL) << "Unspecified compression type :" << options.compression_type; } +#endif } RecordWriter::~RecordWriter() { @@ -130,14 +138,12 @@ Status RecordWriter::WriteRecord(const absl::Cord& data) { Status RecordWriter::Close() { if (dest_ == nullptr) return Status::OK(); -#if !defined(IS_SLIM_BUILD) - if (IsZlibCompressed(options_)) { + if (IsZlibCompressed(options_) || IsSnappyCompressed(options_)) { Status s = dest_->Close(); delete dest_; dest_ = nullptr; return s; } -#endif // IS_SLIM_BUILD return Status::OK(); } diff --git a/tensorflow/core/lib/io/record_writer.h b/tensorflow/core/lib/io/record_writer.h index 012c2fbbc91..243dc847ec5 100644 --- a/tensorflow/core/lib/io/record_writer.h +++ b/tensorflow/core/lib/io/record_writer.h @@ -21,6 +21,8 @@ limitations under the License. #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/hash/crc32c.h" #if !defined(IS_SLIM_BUILD) +#include "tensorflow/core/lib/io/snappy/snappy_compression_options.h" +#include "tensorflow/core/lib/io/snappy/snappy_outputbuffer.h" #include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/core/lib/io/zlib_outputbuffer.h" #endif // IS_SLIM_BUILD @@ -34,17 +36,22 @@ class WritableFile; namespace io { -class RecordWriterOptions { +struct RecordWriterOptions { public: - enum CompressionType { NONE = 0, ZLIB_COMPRESSION = 1 }; + enum CompressionType { + NONE = 0, + ZLIB_COMPRESSION = 1, + SNAPPY_COMPRESSION = 2 + }; CompressionType compression_type = NONE; static RecordWriterOptions CreateRecordWriterOptions( const string& compression_type); -// Options specific to zlib compression. #if !defined(IS_SLIM_BUILD) + // Options specific to compression. tensorflow::io::ZlibCompressionOptions zlib_options; + tensorflow::io::SnappyCompressionOptions snappy_options; #endif // IS_SLIM_BUILD }; @@ -70,7 +77,7 @@ class RecordWriter { // implicit Close() call in the destructor. ~RecordWriter(); - Status WriteRecord(StringPiece slice); + Status WriteRecord(StringPiece data); #if defined(PLATFORM_GOOGLE) Status WriteRecord(const absl::Cord& data); diff --git a/tensorflow/core/lib/io/snappy/snappy_compression_options.h b/tensorflow/core/lib/io/snappy/snappy_compression_options.h new file mode 100644 index 00000000000..d3d798bfa8f --- /dev/null +++ b/tensorflow/core/lib/io/snappy/snappy_compression_options.h @@ -0,0 +1,36 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_IO_SNAPPY_SNAPPY_COMPRESSION_OPTIONS_H_ +#define TENSORFLOW_CORE_LIB_IO_SNAPPY_SNAPPY_COMPRESSION_OPTIONS_H_ + +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace io { + +struct SnappyCompressionOptions { + // Size of the buffer used for caching the data read from source file. + int64 input_buffer_size = 256 << 10; + + // Size of the sink buffer where the compressed/decompressed data produced by + // snappy is cached. + int64 output_buffer_size = 256 << 10; +}; + +} // namespace io +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_IO_SNAPPY_SNAPPY_COMPRESSION_OPTIONS_H_