From 4bba44c1df9d706537f4d6b1ef736eacd3c0a7e5 Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Tue, 25 Feb 2020 10:35:46 -0800 Subject: [PATCH] Improving snapshot read performance (under snappy compression) by reducing the number of copies by 2. We do the following in this CL. 1) Get rid of the snappy input and output buffers that were previously being used to do the compression. This saved one copy. 2) Directly decompress the compressed bytes into the TensorBuffer for simple types (not string, variant, resource). This saves another copy during Tensor creation. For complex types, we still continue to use the TensorProto encoding and pay a copy there. 3) As a result, we end up changing the on-disk format for Snapshot. For a group of tensors that make up one element of an IteratorGetNext output, we first write out a metadata proto that describes the types, shapes and sizes of tensors. After that we lay out the Tensor data (TensorBuffers for simple types and TensorProtos serialized for complex ones) and compress them via snappy. 4) Add a version to the SnapshotMetadata. If it isn't set its assumed to be 0 and the old code path runs. We now set it to 1 while writing so that all new snapshots are written in this data format. PiperOrigin-RevId: 297149479 Change-Id: I2c9a35c5a254189a5fad946b2995f25cdc452308 --- .../core/kernels/data/experimental/BUILD | 2 + .../data/experimental/snapshot_dataset_op.cc | 444 +++++++++++++----- tensorflow/core/platform/default/port.cc | 10 + tensorflow/core/platform/snappy.h | 14 + tensorflow/core/platform/windows/port.cc | 11 + .../protobuf/data/experimental/snapshot.proto | 22 + .../kernel_tests/snapshot_test.py | 72 ++- third_party/snappy.BUILD | 4 + 8 files changed, 453 insertions(+), 126 deletions(-) diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD index d041ab5ac6a..a68b3faeb37 100644 --- a/tensorflow/core/kernels/data/experimental/BUILD +++ b/tensorflow/core/kernels/data/experimental/BUILD @@ -434,6 +434,7 @@ tf_kernel_library( name = "snapshot_dataset_op", srcs = ["snapshot_dataset_op.cc"], deps = [ + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -441,6 +442,7 @@ tf_kernel_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:graph_view", "//tensorflow/core/kernels/data:dataset_utils", + "//tensorflow/core/platform:platform_port", "//tensorflow/core/profiler/lib:traceme", "@com_google_absl//absl/time", ], diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc index ae3015bc833..68ee3c4c134 100644 --- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include "absl/time/clock.h" +#include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" @@ -32,7 +33,9 @@ limitations under the License. #include "tensorflow/core/lib/io/compression.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/io/random_inputstream.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/file_system.h" +#include "tensorflow/core/platform/snappy.h" #if !defined(IS_SLIM_BUILD) #include "tensorflow/core/lib/io/snappy/snappy_inputbuffer.h" #include "tensorflow/core/lib/io/snappy/snappy_outputbuffer.h" @@ -63,9 +66,6 @@ enum SnapshotMode { READER = 0, WRITER = 1, PASSTHROUGH = 2 }; // Defaults to 10 GiB per shard. const int64 kDefaultShardSizeBytes = 10LL * 1024 * 1024 * 1024; -const int64 kSnappyWriterInputBufferSizeBytes = 16 << 20; // 16 MiB -const int64 kSnappyWriterOutputBufferSizeBytes = 16 << 20; // 16 MiB - // The reader input buffer size is deliberately large because the input reader // will throw an error if the compressed block length cannot fit in the input // buffer. @@ -75,6 +75,8 @@ const int64 kSnappyReaderOutputBufferSizeBytes = 32 << 20; // 32 MiB const size_t kHeaderSize = sizeof(uint64); +const int64 kCurrentVersion = 1; + constexpr char kModeAuto[] = "auto"; constexpr char kModeWrite[] = "write"; constexpr char kModeRead[] = "read"; @@ -95,6 +97,7 @@ constexpr char kState[] = "state"; constexpr char kHashDir[] = "hash_dir"; constexpr char kRunId[] = "run_id"; constexpr char kRunDir[] = "run_dir"; +constexpr char kVersionStr[] = "version"; constexpr char kFilenames[] = "filenames"; constexpr char kCurrentFilenames[] = "current_filenames"; constexpr char kElementsProduced[] = "elements_produced"; @@ -115,9 +118,9 @@ class SnapshotWriter { static constexpr const char* const kWriteStringPiece = "WriteStringPiece"; static constexpr const char* const kWriteCord = "WriteCord"; - explicit SnapshotWriter(WritableFile* dest, const string& compression_type = - io::compression::kNone) - : dest_(dest), compression_type_(compression_type) { + explicit SnapshotWriter(WritableFile* dest, const string& compression_type, + int version, const DataTypeVector& dtypes) + : dest_(dest), compression_type_(compression_type), version_(version) { #if defined(IS_SLIM_BUILD) if (compression_type != io::compression::kNone) { LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning " @@ -134,41 +137,100 @@ class SnapshotWriter { TF_CHECK_OK(zlib_output_buffer->Init()); dest_ = zlib_output_buffer; dest_is_owned_ = true; - } else if (compression_type == io::compression::kSnappy) { - io::SnappyOutputBuffer* snappy_output_buffer = new io::SnappyOutputBuffer( - dest, /*input_buffer_bytes=*/kSnappyWriterInputBufferSizeBytes, - /*output_buffer_bytes=*/kSnappyWriterOutputBufferSizeBytes); - dest_ = snappy_output_buffer; - dest_is_owned_ = true; } #endif // IS_SLIM_BUILD + simple_tensor_mask_.reserve(dtypes.size()); + for (const auto& dtype : dtypes) { + if (DataTypeCanUseMemcpy(dtype)) { + simple_tensor_mask_.push_back(true); + num_simple_++; + } else { + simple_tensor_mask_.push_back(false); + num_complex_++; + } + } } - Status WriteRecord(const StringPiece& data) { - profiler::TraceMe activity( - [&]() { - return absl::StrCat(kClassName, kSeparator, kWriteStringPiece); - }, - profiler::TraceMeLevel::kInfo); - char header[kHeaderSize]; - core::EncodeFixed64(header, data.size()); - TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header)))); - return dest_->Append(data); - } - + Status WriteTensors(const std::vector& tensors) { + if (compression_type_ != io::compression::kSnappy) { + experimental::SnapshotRecord record; + for (const auto& tensor : tensors) { + TensorProto* t = record.add_tensor(); + tensor.AsProtoTensorContent(t); + } #if defined(PLATFORM_GOOGLE) - Status WriteRecord(const absl::Cord& data) { - profiler::TraceMe activity( - [&]() { return absl::StrCat(kClassName, kSeparator, kWriteCord); }, - profiler::TraceMeLevel::kInfo); - char header[kHeaderSize]; - core::EncodeFixed64(header, data.size()); - - TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header)))); - - return dest_->Append(data); - } + return WriteRecord(record.SerializeAsCord()); +#else // PLATFORM_GOOGLE + return WriteRecord(record.SerializeAsString()); #endif // PLATFORM_GOOGLE + } + + if (version_ != 1) { + return errors::InvalidArgument("Version: ", version_, + " is not supported."); + } + if (compression_type_ != io::compression::kSnappy) { + return errors::InvalidArgument( + "Version 1 is only compatible with snappy compression"); + } + + std::vector tensor_buffers; + tensor_buffers.reserve(num_simple_); + std::vector tensor_protos; + tensor_protos.reserve(num_complex_); + SnapshotTensorMetadata metadata; + int64 total_size = 0; + for (int i = 0; i < tensors.size(); ++i) { + const Tensor& tensor = tensors[i]; + TensorMetadata* tensor_metadata = metadata.add_tensor_metadata(); + tensor.shape().AsProto(tensor_metadata->mutable_tensor_shape()); + int64 size = 0; + if (simple_tensor_mask_[i]) { + auto tensor_buffer = DMAHelper::buffer(&tensor); + tensor_buffers.push_back(tensor_buffer); + size = tensor_buffer->size(); + } else { + TensorProto proto; + tensor.AsProtoTensorContent(&proto); + size = proto.ByteSizeLong(); + tensor_protos.push_back(std::move(proto)); + } + tensor_metadata->set_tensor_size_bytes(size); + total_size += size; + } + + std::vector uncompressed(total_size); + char* position = uncompressed.data(); + int buffer_index = 0; + int proto_index = 0; + for (int i = 0; i < tensors.size(); ++i) { + const auto& tensor_metadata = metadata.tensor_metadata(i); + if (simple_tensor_mask_[i]) { + memcpy(position, tensor_buffers[buffer_index]->data(), + tensor_metadata.tensor_size_bytes()); + buffer_index++; + } else { + tensor_protos[proto_index].SerializeToArray( + position, tensor_metadata.tensor_size_bytes()); + proto_index++; + } + position += tensor_metadata.tensor_size_bytes(); + } + DCHECK_EQ(position, uncompressed.data() + total_size); + + string output; + if (!port::Snappy_Compress(uncompressed.data(), total_size, &output)) { + return errors::Internal("Failed to compress using snappy."); + } +#if defined(PLATFORM_GOOGLE) + absl::Cord metadata_serialized = metadata.SerializeAsCord(); +#else // PLATFORM_GOOGLE + std::string metadata_serialized = metadata.SerializeAsString(); +#endif // PLATFORM_GOOGLE + TF_RETURN_IF_ERROR(WriteRecord(metadata_serialized)); + TF_RETURN_IF_ERROR(WriteRecord(output)); + return Status::OK(); + } Status Sync() { return dest_->Sync(); } @@ -192,9 +254,29 @@ class SnapshotWriter { } private: + Status WriteRecord(const StringPiece& data) { + char header[kHeaderSize]; + core::EncodeFixed64(header, data.size()); + TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header)))); + return dest_->Append(data); + } + +#if defined(PLATFORM_GOOGLE) + Status WriteRecord(const absl::Cord& data) { + char header[kHeaderSize]; + core::EncodeFixed64(header, data.size()); + TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header)))); + return dest_->Append(data); + } +#endif // PLATFORM_GOOGLE + WritableFile* dest_; bool dest_is_owned_ = false; const string compression_type_; + const int version_; + std::vector simple_tensor_mask_; // true for simple, false for complex. + int num_simple_ = 0; + int num_complex_ = 0; }; class SnapshotReader { @@ -203,12 +285,14 @@ class SnapshotReader { static constexpr const char* const kReadString = "ReadString"; static constexpr const char* const kReadCord = "ReadCord"; - explicit SnapshotReader( - RandomAccessFile* file, - const string& compression_type = io::compression::kNone) + explicit SnapshotReader(RandomAccessFile* file, + const string& compression_type, int version, + const DataTypeVector& dtypes) : file_(file), input_stream_(new io::RandomAccessInputStream(file)), - compression_type_(compression_type) { + compression_type_(compression_type), + version_(version), + dtypes_(dtypes) { #if defined(IS_SLIM_BUILD) if (compression_type_ != io::compression::kNone) { LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning " @@ -223,17 +307,167 @@ class SnapshotReader { input_stream_.release(), zlib_options.input_buffer_size, zlib_options.output_buffer_size, zlib_options, true); } else if (compression_type_ == io::compression::kSnappy) { - input_stream_ = absl::make_unique( - file_, /*input_buffer_bytes=*/kSnappyReaderInputBufferSizeBytes, - /*output_buffer_bytes=*/kSnappyReaderOutputBufferSizeBytes); + if (version_ == 0) { + input_stream_ = absl::make_unique( + file_, /*input_buffer_bytes=*/kSnappyReaderInputBufferSizeBytes, + /*output_buffer_bytes=*/kSnappyReaderOutputBufferSizeBytes); + } else { + input_stream_ = + absl::make_unique(file_, 64 << 20); + } } #endif // IS_SLIM_BUILD + simple_tensor_mask_.reserve(dtypes.size()); + for (const auto& dtype : dtypes) { + if (DataTypeCanUseMemcpy(dtype)) { + simple_tensor_mask_.push_back(true); + num_simple_++; + } else { + simple_tensor_mask_.push_back(false); + num_complex_++; + } + } + } + + Status ReadTensors(std::vector* read_tensors) { + profiler::TraceMe activity( + [&]() { return absl::StrCat(kClassName, kSeparator, "ReadTensors"); }, + profiler::TraceMeLevel::kInfo); + if (version_ == 0 || compression_type_ != io::compression::kSnappy) { + return ReadTensorsV0(read_tensors); + } + if (version_ != 1) { + return errors::InvalidArgument("Version: ", version_, + " is not supported."); + } + if (compression_type_ != io::compression::kSnappy) { + return errors::InvalidArgument("Version 1 only supports snappy."); + } + + SnapshotTensorMetadata metadata; + tstring metadata_str; + TF_RETURN_IF_ERROR(ReadRecord(&metadata_str)); + if (!metadata.ParseFromArray(metadata_str.data(), metadata_str.size())) { + return errors::DataLoss("Could not parse SnapshotTensorMetadata"); + } + read_tensors->reserve(metadata.tensor_metadata_size()); + + std::vector simple_tensors; + simple_tensors.reserve(num_simple_); + std::vector, size_t>> tensor_proto_strs; + tensor_proto_strs.reserve(num_complex_); + TF_RETURN_IF_ERROR( + SnappyUncompress(metadata, &simple_tensors, &tensor_proto_strs)); + + int simple_index = 0; + int complex_index = 0; + for (int i = 0; i < simple_tensor_mask_.size(); ++i) { + if (simple_tensor_mask_[i]) { + read_tensors->push_back(std::move(simple_tensors[simple_index])); + simple_index++; + } else { + auto tensor_proto_str = + std::move(tensor_proto_strs[complex_index].first); + size_t tensor_proto_size = tensor_proto_strs[complex_index].second; + TensorProto tp; +#if defined(PLATFORM_GOOGLE) + auto tensor_proto_ptr = tensor_proto_str.release(); + absl::Cord c; + c.AppendExternalMemory( + absl::string_view(tensor_proto_ptr, tensor_proto_size), + tensor_proto_ptr, + [](void* arg) { delete[] static_cast(arg); }); + if (!tp.ParseFromCord(c)) { + return errors::Internal("Could not parse TensorProto"); + } +#else // PLATFORM_GOOGLE + if (!tp.ParseFromArray(tensor_proto_str.get(), tensor_proto_size)) { + return errors::Internal("Could not parse TensorProto"); + } +#endif // PLATFORM_GOOGLE + Tensor t; + if (!t.FromProto(tp)) { + return errors::Internal("Could not parse Tensor"); + } + read_tensors->push_back(std::move(t)); + complex_index++; + } + } + return Status::OK(); + } + + private: + Status ReadTensorsV0(std::vector* read_tensors) { + experimental::SnapshotRecord record; +#if defined(PLATFORM_GOOGLE) + absl::Cord c; + TF_RETURN_IF_ERROR(ReadRecord(&c)); + record.ParseFromCord(c); +#else // PLATFORM_GOOGLE + tstring record_bytes; + TF_RETURN_IF_ERROR(ReadRecord(&record_bytes)); + record.ParseFromArray(record_bytes.data(), record_bytes.size()); +#endif // PLATFORM_GOOGLE + read_tensors->reserve(record.tensor_size()); + for (int i = 0; i < record.tensor_size(); ++i) { + read_tensors->emplace_back(); + if (!read_tensors->back().FromProto(record.tensor(i))) { + return errors::DataLoss("Unable to parse tensor from proto."); + } + } + return Status::OK(); + } + + Status SnappyUncompress( + const SnapshotTensorMetadata& metadata, + std::vector* simple_tensors, + std::vector, size_t>>* + tensor_proto_strs) { + tstring compressed; + TF_RETURN_IF_ERROR(ReadRecord(&compressed)); + size_t size; + if (!port::Snappy_GetUncompressedLength(compressed.data(), + compressed.size(), &size)) { + return errors::Internal("Could not get snappy uncompressed length"); + } + + int num_tensors = metadata.tensor_metadata_size(); + std::vector iov(num_tensors); + int index = 0; + int64 total_size = 0; + for (int i = 0; i < simple_tensor_mask_.size(); ++i) { + const auto& tensor_metadata = metadata.tensor_metadata(i); + if (simple_tensor_mask_[i]) { + TensorShape shape(tensor_metadata.tensor_shape()); + Tensor simple_tensor(dtypes_[i], shape); + TensorBuffer* buffer = DMAHelper::buffer(&simple_tensor); + iov[index].iov_base = buffer->data(); + iov[index].iov_len = buffer->size(); + simple_tensors->push_back(std::move(simple_tensor)); + } else { + auto tensor_proto_str = + absl::make_unique(tensor_metadata.tensor_size_bytes()); + iov[index].iov_base = tensor_proto_str.get(); + iov[index].iov_len = tensor_metadata.tensor_size_bytes(); + tensor_proto_strs->push_back(std::make_pair( + std::move(tensor_proto_str), tensor_metadata.tensor_size_bytes())); + } + total_size += iov[index].iov_len; + index++; + } + if (size != total_size) { + return errors::Internal("Uncompressed size mismatch. Snappy expects ", + size, " whereas the tensor metadata suggests ", + total_size); + } + if (!port::Snappy_UncompressToIOVec(compressed.data(), compressed.size(), + iov.data(), num_tensors)) { + return errors::Internal("Failed to perform snappy decompression."); + } + return Status::OK(); } Status ReadRecord(tstring* record) { - profiler::TraceMe activity( - [&]() { return absl::StrCat(kClassName, kSeparator, kReadString); }, - profiler::TraceMeLevel::kInfo); tstring header; TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header)); uint64 length = core::DecodeFixed64(header.data()); @@ -245,13 +479,6 @@ class SnapshotReader { tstring header; TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header)); uint64 length = core::DecodeFixed64(header.data()); - profiler::TraceMe activity( - [&]() { - return absl::StrCat(kClassName, kSeparator, kReadCord, - "#length=", length, "#"); - }, - profiler::TraceMeLevel::kInfo); - if (compression_type_ == io::compression::kNone) { return input_stream_->ReadNBytes(length, record); } else { @@ -268,50 +495,31 @@ class SnapshotReader { } #endif - private: RandomAccessFile* file_; std::unique_ptr input_stream_; const string compression_type_; + const int version_; + const DataTypeVector dtypes_; + int num_simple_ = 0; + int num_complex_ = 0; + std::vector simple_tensor_mask_; // true for simple, false for complex. }; Status WriteMetadataFile(const string& hash_dir, const experimental::SnapshotMetadataRecord& metadata) { string metadata_filename = io::JoinPath(hash_dir, kSnapshotFilename); TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(hash_dir)); - std::string tmp_filename = absl::StrCat(metadata_filename, "-tmp-", random::New64()); - - std::unique_ptr file; - TF_RETURN_IF_ERROR(Env::Default()->NewWritableFile(tmp_filename, &file)); - - auto writer = absl::make_unique(file.get()); - TF_RETURN_IF_ERROR(writer->WriteRecord(metadata.SerializeAsString())); - TF_RETURN_IF_ERROR(writer->Close()); - TF_RETURN_IF_ERROR(file->Sync()); - TF_RETURN_IF_ERROR(file->Close()); - - TF_RETURN_IF_ERROR( - Env::Default()->RenameFile(tmp_filename, metadata_filename)); - - return Status::OK(); + TF_RETURN_IF_ERROR(WriteBinaryProto(Env::Default(), tmp_filename, metadata)); + return Env::Default()->RenameFile(tmp_filename, metadata_filename); } Status ReadMetadataFile(const string& hash_dir, experimental::SnapshotMetadataRecord* metadata) { string metadata_filename = io::JoinPath(hash_dir, kSnapshotFilename); TF_RETURN_IF_ERROR(Env::Default()->FileExists(metadata_filename)); - - std::unique_ptr file; - TF_RETURN_IF_ERROR( - Env::Default()->NewRandomAccessFile(metadata_filename, &file)); - - tstring record_bytes; - SnapshotReader reader(file.get()); - TF_RETURN_IF_ERROR(reader.ReadRecord(&record_bytes)); - - metadata->ParseFromArray(record_bytes.data(), record_bytes.size()); - return Status::OK(); + return ReadBinaryProto(Env::Default(), metadata_filename, metadata); } Status DumpDatasetGraph(const std::string& path, uint64 hash, @@ -332,6 +540,10 @@ Status DetermineOpState(const std::string& mode_string, const uint64 pending_snapshot_expiry_seconds, SnapshotMode* mode) { if (mode_string == kModeRead) { + // In read mode, we should expect a metadata file is written. + if (errors::IsNotFound(file_status)) { + return file_status; + } LOG(INFO) << "Overriding mode to reader."; *mode = READER; return Status::OK(); @@ -727,10 +939,25 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { if (run_id.empty()) { run_id = metadata.run_id(); } + // dtypes in metadata should be the same as dataset()->output_dtypes + if (metadata.dtype_size() != dataset()->output_dtypes().size()) { + return errors::Internal( + "Expected number of dtypes: ", + dataset()->output_dtypes().size(), + " but number in snapshot: ", metadata.dtype_size()); + } + for (int i = 0; i < metadata.dtype_size(); ++i) { + if (metadata.dtype(i) != dataset()->output_dtypes()[i]) { + return errors::Internal( + "Type: ", i, + " doesn't match. Snapshot: ", metadata.dtype(i), + "; dataset: ", dataset()->output_dtypes()[i]); + } + } iterator_ = absl::make_unique( SnapshotReaderIterator::Params{ dataset(), absl::StrCat(prefix(), "ReaderImpl")}, - hash_dir_, run_id); + hash_dir_, run_id, metadata.version()); break; case PASSTHROUGH: iterator_ = absl::make_unique( @@ -748,10 +975,11 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { explicit SnapshotReaderIterator(const Params& params, const string& hash_dir, - const string& run_id) + const string& run_id, int64 version) : DatasetIterator(params), hash_dir_(hash_dir), - run_id_(run_id) {} + run_id_(run_id), + version_(version) {} ~SnapshotReaderIterator() override { mutex_lock l(mu_); @@ -889,6 +1117,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { writer->WriteScalar(full_name(kHashDir), hash_dir_)); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kRunId), run_id_)); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kRunDir), run_dir_)); + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name(kVersionStr), version_)); TF_RETURN_IF_ERROR(writer->WriteScalar( full_name(strings::StrCat(kFilenames, kSizeSuffix)), filenames_.size())); @@ -932,6 +1162,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { } TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunId), &run_id_)); TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunDir), &run_dir_)); + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name(kVersionStr), &version_)); curr_filenames_.clear(); curr_filenames_.reserve(dataset()->num_reader_threads_); for (auto i = 0; i < dataset()->num_reader_threads_; ++i) { @@ -986,7 +1218,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr file; TF_RETURN_IF_ERROR( Env::Default()->NewRandomAccessFile(filename, &file)); - SnapshotReader reader(file.get(), dataset()->compression_); + SnapshotReader reader(file.get(), dataset()->compression_, version_, + dataset()->output_dtypes()); while (true) { // Wait for a slot in the buffer. @@ -1003,30 +1236,14 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { "ReadFile"); } } -#if !defined(PLATFORM_GOOGLE) - tstring record_bytes; - Status s = reader.ReadRecord(&record_bytes); -#else - absl::Cord record_cord; - Status s = reader.ReadRecord(&record_cord); -#endif + std::vector read_tensors; + Status s = reader.ReadTensors(&read_tensors); if (s.ok()) { profiler::TraceMe activity( [&]() { return absl::StrCat(prefix(), kSeparator, kParse); }, profiler::TraceMeLevel::kInfo); - experimental::SnapshotRecord record; -#if !defined(PLATFORM_GOOGLE) - record.ParseFromArray(record_bytes.data(), record_bytes.size()); -#else - record.ParseFromCord(record_cord); -#endif BufferElement elem; - for (int i = 0; i < record.tensor_size(); ++i) { - elem.value.emplace_back(); - if (!elem.value.back().FromProto(record.tensor(i))) { - return errors::DataLoss("Unable to parse tensor from proto."); - } - } + elem.value = std::move(read_tensors); elem.status = Status::OK(); mutex_lock l(mu_); buffer_.push_back(std::move(elem)); @@ -1142,9 +1359,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { condition_variable cond_var_; const string hash_dir_; - const experimental::SnapshotMetadataRecord metadata_; tstring run_id_ GUARDED_BY(mu_); tstring run_dir_ GUARDED_BY(mu_); + int64 version_; std::vector filenames_; uint64 elements_produced_ GUARDED_BY(mu_) = 0; @@ -1220,6 +1437,10 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { metadata.set_creation_timestamp(EnvTime::NowMicros()); metadata.set_graph_hash(dataset()->graph_hash_); metadata.set_run_id(run_id_.data(), run_id_.size()); + metadata.set_version(kCurrentVersion); + for (const auto& output_dtype : dataset()->output_dtypes()) { + metadata.add_dtype(output_dtype); + } metadata.set_finalized(false); TF_RETURN_IF_ERROR(WriteMetadataFile(hash_dir_, metadata)); } @@ -1564,11 +1785,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { } if (produced_elem) { - experimental::SnapshotRecord record; for (const auto& out_tensor : elem.value) { *bytes_written += out_tensor.TotalBytes(); - TensorProto* t = record.add_tensor(); - out_tensor.AsProtoTensorContent(t); } bool should_close; @@ -1584,16 +1802,11 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(Env::Default()->NewAppendableFile( *snapshot_data_filename, file)); *writer = absl::make_unique( - file->get(), dataset()->compression_); + file->get(), dataset()->compression_, kCurrentVersion, + dataset()->output_dtypes()); *bytes_written = 0; } -#if defined(PLATFORM_GOOGLE) - TF_RETURN_IF_ERROR( - (*writer)->WriteRecord(record.SerializeAsCord())); -#else // PLATFORM_GOOGLE - TF_RETURN_IF_ERROR( - (*writer)->WriteRecord(record.SerializeAsString())); -#endif // PLATFORM_GOOGLE + TF_RETURN_IF_ERROR((*writer)->WriteTensors(elem.value)); return Status::OK(); } @@ -1641,7 +1854,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { return; } std::unique_ptr writer( - new SnapshotWriter(file.get(), dataset()->compression_)); + new SnapshotWriter(file.get(), dataset()->compression_, + kCurrentVersion, dataset()->output_dtypes())); bool end_of_processing = false; while (!end_of_processing) { diff --git a/tensorflow/core/platform/default/port.cc b/tensorflow/core/platform/default/port.cc index 47f4abae3bb..756e7e8a93a 100644 --- a/tensorflow/core/platform/default/port.cc +++ b/tensorflow/core/platform/default/port.cc @@ -332,6 +332,16 @@ bool Snappy_Uncompress(const char* input, size_t length, char* output) { #endif } +bool Snappy_UncompressToIOVec(const char* compressed, size_t compressed_length, + const struct iovec* iov, size_t iov_cnt) { +#ifdef TF_USE_SNAPPY + return snappy::RawUncompressToIOVec(compressed, compressed_length, iov, + iov_cnt); +#else + return false; +#endif +} + string Demangle(const char* mangled) { return mangled; } double NominalCPUFrequency() { diff --git a/tensorflow/core/platform/snappy.h b/tensorflow/core/platform/snappy.h index 5477b097ef0..df06f3dcc1e 100644 --- a/tensorflow/core/platform/snappy.h +++ b/tensorflow/core/platform/snappy.h @@ -18,6 +18,17 @@ limitations under the License. #include "tensorflow/core/platform/types.h" +#if !defined(PLATFORM_WINDOWS) +#include +#else +namespace tensorflow { +struct iovec { + void* iov_base; + size_t iov_len; +}; +} // namespace tensorflow +#endif + namespace tensorflow { namespace port { @@ -28,6 +39,9 @@ bool Snappy_GetUncompressedLength(const char* input, size_t length, size_t* result); bool Snappy_Uncompress(const char* input, size_t length, char* output); +bool Snappy_UncompressToIOVec(const char* compressed, size_t compressed_length, + const struct iovec* iov, size_t iov_cnt); + } // namespace port } // namespace tensorflow diff --git a/tensorflow/core/platform/windows/port.cc b/tensorflow/core/platform/windows/port.cc index 2303b587ce6..547af76bdf6 100644 --- a/tensorflow/core/platform/windows/port.cc +++ b/tensorflow/core/platform/windows/port.cc @@ -157,6 +157,17 @@ bool Snappy_Uncompress(const char* input, size_t length, char* output) { #endif } +bool Snappy_UncompressToIOVec(const char* compressed, size_t compressed_length, + const struct iovec* iov, size_t iov_cnt) { +#ifdef TF_USE_SNAPPY + const snappy::iovec* snappy_iov = reinterpret_cast(iov); + return snappy::RawUncompressToIOVec(compressed, compressed_length, snappy_iov, + iov_cnt); +#else + return false; +#endif +} + string Demangle(const char* mangled) { return mangled; } double NominalCPUFrequency() { diff --git a/tensorflow/core/protobuf/data/experimental/snapshot.proto b/tensorflow/core/protobuf/data/experimental/snapshot.proto index 422602d3760..e013deb2ee1 100644 --- a/tensorflow/core/protobuf/data/experimental/snapshot.proto +++ b/tensorflow/core/protobuf/data/experimental/snapshot.proto @@ -3,6 +3,8 @@ syntax = "proto3"; package tensorflow.data.experimental; import "tensorflow/core/framework/tensor.proto"; +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; // Each SnapshotRecord represents one batch of pre-processed input data. A batch // consists of a list of tensors that we encode as TensorProtos. This message @@ -13,9 +15,29 @@ message SnapshotRecord { // This stores the metadata information present in each snapshot record. message SnapshotMetadataRecord { + // Stores the fingerprint of the graph that describes the dataset that is + // snapshotted. string graph_hash = 1; + // Run ID that this snapshot corresponds to. string run_id = 2; + // Time when we started creating this snapshot. int64 creation_timestamp = 3; + // Version of the snapshot data file format. + int64 version = 4; + // A list of tensor dtype corresponding to each element of the snapshot. + repeated .tensorflow.DataType dtype = 5; bool finalized = 1000; } + +// Metadata for a single tensor in the Snapshot Record. +message TensorMetadata { + .tensorflow.TensorShapeProto tensor_shape = 2; + // Number of uncompressed bytes used to store the tensor representation. + int64 tensor_size_bytes = 3; +} + +// Metadata for all the tensors in a Snapshot Record. +message SnapshotTensorMetadata { + repeated TensorMetadata tensor_metadata = 1; +} diff --git a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py index 96b3b764864..535cf884dc6 100644 --- a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py @@ -161,17 +161,49 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase, self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) - @combinations.generate(test_base.default_test_combinations()) - def testWriteSnapshotRepeatAfterwards(self): + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(compression=[ + snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP, + snapshot.COMPRESSION_SNAPPY + ]))) + def testWriteSnapshotRepeatAfterwards(self, compression): tmpdir = self.snapshot_dir dataset = dataset_ops.Dataset.range(10) - dataset = dataset.apply(snapshot.snapshot(tmpdir)) + dataset = dataset.apply(snapshot.snapshot(tmpdir, compression=compression)) dataset = dataset.repeat(10) self.assertDatasetProduces(dataset, list(range(10)) * 10) self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(compression=[ + snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP, + snapshot.COMPRESSION_SNAPPY + ]))) + def testWriteSnapshotMixTypes(self, compression): + tmpdir = self.snapshot_dir + + dataset = dataset_ops.Dataset.range(10) + + def map_fn(x): + return (x, string_ops.as_string(x), string_ops.as_string(2 * x), 2 * x) + + dataset = dataset.map(map_fn) + dataset = dataset.apply(snapshot.snapshot(tmpdir, compression=compression)) + dataset = dataset.repeat(10) + + expected = [] + for i in range(10): + expected.append((i, str(i), str(2 * i), 2 * i)) + self.assertDatasetProduces(dataset, expected * 10) + + self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) + @combinations.generate(test_base.default_test_combinations()) def testSpecifySnapshotNameWriteAndRead(self): tmpdir = self.snapshot_dir @@ -365,8 +397,14 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase, res3 = self.evaluate(next3()) self.assertEqual(res2, res3) - @combinations.generate(test_base.default_test_combinations()) - def testReadSnapshotParallelAfterWrite(self): + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(compression=[ + snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP, + snapshot.COMPRESSION_SNAPPY + ]))) + def testReadSnapshotParallelAfterWrite(self, compression): self.setUpTFRecord(10, 4000) filenames = self.test_filenames @@ -383,7 +421,8 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase, tmpdir, shard_size_bytes=1024 * 1024, num_reader_threads=2, - reader_buffer_size=10)) + reader_buffer_size=10, + compression=compression)) self.assertDatasetProduces(dataset, expected, assert_items_equal=True) # remove the original files and try to read the data back only from @@ -396,7 +435,8 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase, tmpdir, shard_size_bytes=1024 * 1024, num_reader_threads=2, - reader_buffer_size=10)) + reader_buffer_size=10, + compression=compression)) self.assertDatasetProduces(dataset2, expected, assert_items_equal=True) # Not testing Snappy here because Snappy reads currently require a lot of @@ -514,21 +554,31 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase, self.evaluate(next2()) self.assertSnapshotDirectoryContains(tmpdir, 2, 1, 1) - @combinations.generate(test_base.default_test_combinations()) - def testSpecifyShardSize(self): + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(compression=[ + snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP, + snapshot.COMPRESSION_SNAPPY + ]))) + def testSpecifyShardSize(self, compression): tmpdir = self.snapshot_dir dataset = dataset_ops.Dataset.from_tensor_slices([1.0]) dataset = dataset.map(lambda x: gen_array_ops.broadcast_to(x, [1024, 1024])) dataset = dataset.repeat(10) dataset = dataset.apply( - snapshot.snapshot(tmpdir, shard_size_bytes=10 * 1024 * 1024)) + snapshot.snapshot( + tmpdir, shard_size_bytes=10 * 1024 * 1024, compression=compression)) next_fn = self.getNext(dataset) for _ in range(10): self.evaluate(next_fn()) - self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 3) + num_files = 1 + if compression == snapshot.COMPRESSION_NONE: + num_files = 3 + self.assertSnapshotDirectoryContains(tmpdir, 1, 1, num_files) @combinations.generate(test_base.default_test_combinations()) def testAdditionalOperationsAfterReadBack(self): diff --git a/third_party/snappy.BUILD b/third_party/snappy.BUILD index d93f0307690..a2ab4924f29 100644 --- a/third_party/snappy.BUILD +++ b/third_party/snappy.BUILD @@ -27,6 +27,10 @@ cc_library( "-Wno-implicit-function-declaration", ], }), + defines = select({ + "@org_tensorflow//tensorflow:windows": [], + "//conditions:default": ["HAVE_SYS_UIO_H"], + }), ) genrule(