diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD
index d041ab5ac6a..a68b3faeb37 100644
--- a/tensorflow/core/kernels/data/experimental/BUILD
+++ b/tensorflow/core/kernels/data/experimental/BUILD
@@ -434,6 +434,7 @@ tf_kernel_library(
     name = "snapshot_dataset_op",
     srcs = ["snapshot_dataset_op.cc"],
     deps = [
+        "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:dataset_ops_op_lib",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
@@ -441,6 +442,7 @@ tf_kernel_library(
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/grappler:graph_view",
         "//tensorflow/core/kernels/data:dataset_utils",
+        "//tensorflow/core/platform:platform_port",
         "//tensorflow/core/profiler/lib:traceme",
         "@com_google_absl//absl/time",
     ],
diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
index ae3015bc833..68ee3c4c134 100644
--- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
@@ -15,6 +15,7 @@ limitations under the License.
 #include <random>
 
 #include "absl/time/clock.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
 #include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/partial_tensor_shape.h"
@@ -32,7 +33,9 @@ limitations under the License.
 #include "tensorflow/core/lib/io/compression.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/lib/io/random_inputstream.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/file_system.h"
+#include "tensorflow/core/platform/snappy.h"
 #if !defined(IS_SLIM_BUILD)
 #include "tensorflow/core/lib/io/snappy/snappy_inputbuffer.h"
 #include "tensorflow/core/lib/io/snappy/snappy_outputbuffer.h"
@@ -63,9 +66,6 @@ enum SnapshotMode { READER = 0, WRITER = 1, PASSTHROUGH = 2 };
 // Defaults to 10 GiB per shard.
 const int64 kDefaultShardSizeBytes = 10LL * 1024 * 1024 * 1024;
 
-const int64 kSnappyWriterInputBufferSizeBytes = 16 << 20;   // 16 MiB
-const int64 kSnappyWriterOutputBufferSizeBytes = 16 << 20;  // 16 MiB
-
 // The reader input buffer size is deliberately large because the input reader
 // will throw an error if the compressed block length cannot fit in the input
 // buffer.
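Note on the hunks that follow: the streaming SnappyOutputBuffer constants go away because the new version-1 format does not compress a byte stream; instead each element is written as two length-prefixed records, a SnapshotTensorMetadata proto followed by one snappy-compressed block containing the concatenated tensor bytes. The standalone sketch below (not TensorFlow code; layout details are assumptions drawn from the later hunks) shows the 8-byte little-endian length framing that WriteRecord/ReadRecord continue to use for both records:

#include <cstdint>
#include <string>

// Append one framed record: an 8-byte little-endian length header
// (kHeaderSize == sizeof(uint64)) followed by the payload, mirroring
// SnapshotWriter::WriteRecord and core::EncodeFixed64.
void AppendRecord(const std::string& payload, std::string* file_contents) {
  uint64_t length = payload.size();
  char header[sizeof(uint64_t)];
  for (int i = 0; i < 8; ++i) {
    header[i] = static_cast<char>(length >> (8 * i));
  }
  file_contents->append(header, sizeof(header));
  file_contents->append(payload);
}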
@@ -75,6 +75,8 @@ const int64 kSnappyReaderOutputBufferSizeBytes = 32 << 20; // 32 MiB const size_t kHeaderSize = sizeof(uint64); +const int64 kCurrentVersion = 1; + constexpr char kModeAuto[] = "auto"; constexpr char kModeWrite[] = "write"; constexpr char kModeRead[] = "read"; @@ -95,6 +97,7 @@ constexpr char kState[] = "state"; constexpr char kHashDir[] = "hash_dir"; constexpr char kRunId[] = "run_id"; constexpr char kRunDir[] = "run_dir"; +constexpr char kVersionStr[] = "version"; constexpr char kFilenames[] = "filenames"; constexpr char kCurrentFilenames[] = "current_filenames"; constexpr char kElementsProduced[] = "elements_produced"; @@ -115,9 +118,9 @@ class SnapshotWriter { static constexpr const char* const kWriteStringPiece = "WriteStringPiece"; static constexpr const char* const kWriteCord = "WriteCord"; - explicit SnapshotWriter(WritableFile* dest, const string& compression_type = - io::compression::kNone) - : dest_(dest), compression_type_(compression_type) { + explicit SnapshotWriter(WritableFile* dest, const string& compression_type, + int version, const DataTypeVector& dtypes) + : dest_(dest), compression_type_(compression_type), version_(version) { #if defined(IS_SLIM_BUILD) if (compression_type != io::compression::kNone) { LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning " @@ -134,41 +137,100 @@ class SnapshotWriter { TF_CHECK_OK(zlib_output_buffer->Init()); dest_ = zlib_output_buffer; dest_is_owned_ = true; - } else if (compression_type == io::compression::kSnappy) { - io::SnappyOutputBuffer* snappy_output_buffer = new io::SnappyOutputBuffer( - dest, /*input_buffer_bytes=*/kSnappyWriterInputBufferSizeBytes, - /*output_buffer_bytes=*/kSnappyWriterOutputBufferSizeBytes); - dest_ = snappy_output_buffer; - dest_is_owned_ = true; } #endif // IS_SLIM_BUILD + simple_tensor_mask_.reserve(dtypes.size()); + for (const auto& dtype : dtypes) { + if (DataTypeCanUseMemcpy(dtype)) { + simple_tensor_mask_.push_back(true); + num_simple_++; + } else { + simple_tensor_mask_.push_back(false); + num_complex_++; + } + } } - Status WriteRecord(const StringPiece& data) { - profiler::TraceMe activity( - [&]() { - return absl::StrCat(kClassName, kSeparator, kWriteStringPiece); - }, - profiler::TraceMeLevel::kInfo); - char header[kHeaderSize]; - core::EncodeFixed64(header, data.size()); - TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header)))); - return dest_->Append(data); - } - + Status WriteTensors(const std::vector& tensors) { + if (compression_type_ != io::compression::kSnappy) { + experimental::SnapshotRecord record; + for (const auto& tensor : tensors) { + TensorProto* t = record.add_tensor(); + tensor.AsProtoTensorContent(t); + } #if defined(PLATFORM_GOOGLE) - Status WriteRecord(const absl::Cord& data) { - profiler::TraceMe activity( - [&]() { return absl::StrCat(kClassName, kSeparator, kWriteCord); }, - profiler::TraceMeLevel::kInfo); - char header[kHeaderSize]; - core::EncodeFixed64(header, data.size()); - - TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header)))); - - return dest_->Append(data); - } + return WriteRecord(record.SerializeAsCord()); +#else // PLATFORM_GOOGLE + return WriteRecord(record.SerializeAsString()); #endif // PLATFORM_GOOGLE + } + + if (version_ != 1) { + return errors::InvalidArgument("Version: ", version_, + " is not supported."); + } + if (compression_type_ != io::compression::kSnappy) { + return errors::InvalidArgument( + "Version 1 is only compatible with snappy compression"); + } + + std::vector 
tensor_buffers; + tensor_buffers.reserve(num_simple_); + std::vector tensor_protos; + tensor_protos.reserve(num_complex_); + SnapshotTensorMetadata metadata; + int64 total_size = 0; + for (int i = 0; i < tensors.size(); ++i) { + const Tensor& tensor = tensors[i]; + TensorMetadata* tensor_metadata = metadata.add_tensor_metadata(); + tensor.shape().AsProto(tensor_metadata->mutable_tensor_shape()); + int64 size = 0; + if (simple_tensor_mask_[i]) { + auto tensor_buffer = DMAHelper::buffer(&tensor); + tensor_buffers.push_back(tensor_buffer); + size = tensor_buffer->size(); + } else { + TensorProto proto; + tensor.AsProtoTensorContent(&proto); + size = proto.ByteSizeLong(); + tensor_protos.push_back(std::move(proto)); + } + tensor_metadata->set_tensor_size_bytes(size); + total_size += size; + } + + std::vector uncompressed(total_size); + char* position = uncompressed.data(); + int buffer_index = 0; + int proto_index = 0; + for (int i = 0; i < tensors.size(); ++i) { + const auto& tensor_metadata = metadata.tensor_metadata(i); + if (simple_tensor_mask_[i]) { + memcpy(position, tensor_buffers[buffer_index]->data(), + tensor_metadata.tensor_size_bytes()); + buffer_index++; + } else { + tensor_protos[proto_index].SerializeToArray( + position, tensor_metadata.tensor_size_bytes()); + proto_index++; + } + position += tensor_metadata.tensor_size_bytes(); + } + DCHECK_EQ(position, uncompressed.data() + total_size); + + string output; + if (!port::Snappy_Compress(uncompressed.data(), total_size, &output)) { + return errors::Internal("Failed to compress using snappy."); + } +#if defined(PLATFORM_GOOGLE) + absl::Cord metadata_serialized = metadata.SerializeAsCord(); +#else // PLATFORM_GOOGLE + std::string metadata_serialized = metadata.SerializeAsString(); +#endif // PLATFORM_GOOGLE + TF_RETURN_IF_ERROR(WriteRecord(metadata_serialized)); + TF_RETURN_IF_ERROR(WriteRecord(output)); + return Status::OK(); + } Status Sync() { return dest_->Sync(); } @@ -192,9 +254,29 @@ class SnapshotWriter { } private: + Status WriteRecord(const StringPiece& data) { + char header[kHeaderSize]; + core::EncodeFixed64(header, data.size()); + TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header)))); + return dest_->Append(data); + } + +#if defined(PLATFORM_GOOGLE) + Status WriteRecord(const absl::Cord& data) { + char header[kHeaderSize]; + core::EncodeFixed64(header, data.size()); + TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header)))); + return dest_->Append(data); + } +#endif // PLATFORM_GOOGLE + WritableFile* dest_; bool dest_is_owned_ = false; const string compression_type_; + const int version_; + std::vector simple_tensor_mask_; // true for simple, false for complex. + int num_simple_ = 0; + int num_complex_ = 0; }; class SnapshotReader { @@ -203,12 +285,14 @@ class SnapshotReader { static constexpr const char* const kReadString = "ReadString"; static constexpr const char* const kReadCord = "ReadCord"; - explicit SnapshotReader( - RandomAccessFile* file, - const string& compression_type = io::compression::kNone) + explicit SnapshotReader(RandomAccessFile* file, + const string& compression_type, int version, + const DataTypeVector& dtypes) : file_(file), input_stream_(new io::RandomAccessInputStream(file)), - compression_type_(compression_type) { + compression_type_(compression_type), + version_(version), + dtypes_(dtypes) { #if defined(IS_SLIM_BUILD) if (compression_type_ != io::compression::kNone) { LOG(ERROR) << "Compression is unsupported on mobile platforms. 
Turning " @@ -223,17 +307,167 @@ class SnapshotReader { input_stream_.release(), zlib_options.input_buffer_size, zlib_options.output_buffer_size, zlib_options, true); } else if (compression_type_ == io::compression::kSnappy) { - input_stream_ = absl::make_unique( - file_, /*input_buffer_bytes=*/kSnappyReaderInputBufferSizeBytes, - /*output_buffer_bytes=*/kSnappyReaderOutputBufferSizeBytes); + if (version_ == 0) { + input_stream_ = absl::make_unique( + file_, /*input_buffer_bytes=*/kSnappyReaderInputBufferSizeBytes, + /*output_buffer_bytes=*/kSnappyReaderOutputBufferSizeBytes); + } else { + input_stream_ = + absl::make_unique(file_, 64 << 20); + } } #endif // IS_SLIM_BUILD + simple_tensor_mask_.reserve(dtypes.size()); + for (const auto& dtype : dtypes) { + if (DataTypeCanUseMemcpy(dtype)) { + simple_tensor_mask_.push_back(true); + num_simple_++; + } else { + simple_tensor_mask_.push_back(false); + num_complex_++; + } + } + } + + Status ReadTensors(std::vector* read_tensors) { + profiler::TraceMe activity( + [&]() { return absl::StrCat(kClassName, kSeparator, "ReadTensors"); }, + profiler::TraceMeLevel::kInfo); + if (version_ == 0 || compression_type_ != io::compression::kSnappy) { + return ReadTensorsV0(read_tensors); + } + if (version_ != 1) { + return errors::InvalidArgument("Version: ", version_, + " is not supported."); + } + if (compression_type_ != io::compression::kSnappy) { + return errors::InvalidArgument("Version 1 only supports snappy."); + } + + SnapshotTensorMetadata metadata; + tstring metadata_str; + TF_RETURN_IF_ERROR(ReadRecord(&metadata_str)); + if (!metadata.ParseFromArray(metadata_str.data(), metadata_str.size())) { + return errors::DataLoss("Could not parse SnapshotTensorMetadata"); + } + read_tensors->reserve(metadata.tensor_metadata_size()); + + std::vector simple_tensors; + simple_tensors.reserve(num_simple_); + std::vector, size_t>> tensor_proto_strs; + tensor_proto_strs.reserve(num_complex_); + TF_RETURN_IF_ERROR( + SnappyUncompress(metadata, &simple_tensors, &tensor_proto_strs)); + + int simple_index = 0; + int complex_index = 0; + for (int i = 0; i < simple_tensor_mask_.size(); ++i) { + if (simple_tensor_mask_[i]) { + read_tensors->push_back(std::move(simple_tensors[simple_index])); + simple_index++; + } else { + auto tensor_proto_str = + std::move(tensor_proto_strs[complex_index].first); + size_t tensor_proto_size = tensor_proto_strs[complex_index].second; + TensorProto tp; +#if defined(PLATFORM_GOOGLE) + auto tensor_proto_ptr = tensor_proto_str.release(); + absl::Cord c; + c.AppendExternalMemory( + absl::string_view(tensor_proto_ptr, tensor_proto_size), + tensor_proto_ptr, + [](void* arg) { delete[] static_cast(arg); }); + if (!tp.ParseFromCord(c)) { + return errors::Internal("Could not parse TensorProto"); + } +#else // PLATFORM_GOOGLE + if (!tp.ParseFromArray(tensor_proto_str.get(), tensor_proto_size)) { + return errors::Internal("Could not parse TensorProto"); + } +#endif // PLATFORM_GOOGLE + Tensor t; + if (!t.FromProto(tp)) { + return errors::Internal("Could not parse Tensor"); + } + read_tensors->push_back(std::move(t)); + complex_index++; + } + } + return Status::OK(); + } + + private: + Status ReadTensorsV0(std::vector* read_tensors) { + experimental::SnapshotRecord record; +#if defined(PLATFORM_GOOGLE) + absl::Cord c; + TF_RETURN_IF_ERROR(ReadRecord(&c)); + record.ParseFromCord(c); +#else // PLATFORM_GOOGLE + tstring record_bytes; + TF_RETURN_IF_ERROR(ReadRecord(&record_bytes)); + record.ParseFromArray(record_bytes.data(), record_bytes.size()); +#endif 
// PLATFORM_GOOGLE + read_tensors->reserve(record.tensor_size()); + for (int i = 0; i < record.tensor_size(); ++i) { + read_tensors->emplace_back(); + if (!read_tensors->back().FromProto(record.tensor(i))) { + return errors::DataLoss("Unable to parse tensor from proto."); + } + } + return Status::OK(); + } + + Status SnappyUncompress( + const SnapshotTensorMetadata& metadata, + std::vector* simple_tensors, + std::vector, size_t>>* + tensor_proto_strs) { + tstring compressed; + TF_RETURN_IF_ERROR(ReadRecord(&compressed)); + size_t size; + if (!port::Snappy_GetUncompressedLength(compressed.data(), + compressed.size(), &size)) { + return errors::Internal("Could not get snappy uncompressed length"); + } + + int num_tensors = metadata.tensor_metadata_size(); + std::vector iov(num_tensors); + int index = 0; + int64 total_size = 0; + for (int i = 0; i < simple_tensor_mask_.size(); ++i) { + const auto& tensor_metadata = metadata.tensor_metadata(i); + if (simple_tensor_mask_[i]) { + TensorShape shape(tensor_metadata.tensor_shape()); + Tensor simple_tensor(dtypes_[i], shape); + TensorBuffer* buffer = DMAHelper::buffer(&simple_tensor); + iov[index].iov_base = buffer->data(); + iov[index].iov_len = buffer->size(); + simple_tensors->push_back(std::move(simple_tensor)); + } else { + auto tensor_proto_str = + absl::make_unique(tensor_metadata.tensor_size_bytes()); + iov[index].iov_base = tensor_proto_str.get(); + iov[index].iov_len = tensor_metadata.tensor_size_bytes(); + tensor_proto_strs->push_back(std::make_pair( + std::move(tensor_proto_str), tensor_metadata.tensor_size_bytes())); + } + total_size += iov[index].iov_len; + index++; + } + if (size != total_size) { + return errors::Internal("Uncompressed size mismatch. Snappy expects ", + size, " whereas the tensor metadata suggests ", + total_size); + } + if (!port::Snappy_UncompressToIOVec(compressed.data(), compressed.size(), + iov.data(), num_tensors)) { + return errors::Internal("Failed to perform snappy decompression."); + } + return Status::OK(); } Status ReadRecord(tstring* record) { - profiler::TraceMe activity( - [&]() { return absl::StrCat(kClassName, kSeparator, kReadString); }, - profiler::TraceMeLevel::kInfo); tstring header; TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header)); uint64 length = core::DecodeFixed64(header.data()); @@ -245,13 +479,6 @@ class SnapshotReader { tstring header; TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header)); uint64 length = core::DecodeFixed64(header.data()); - profiler::TraceMe activity( - [&]() { - return absl::StrCat(kClassName, kSeparator, kReadCord, - "#length=", length, "#"); - }, - profiler::TraceMeLevel::kInfo); - if (compression_type_ == io::compression::kNone) { return input_stream_->ReadNBytes(length, record); } else { @@ -268,50 +495,31 @@ class SnapshotReader { } #endif - private: RandomAccessFile* file_; std::unique_ptr input_stream_; const string compression_type_; + const int version_; + const DataTypeVector dtypes_; + int num_simple_ = 0; + int num_complex_ = 0; + std::vector simple_tensor_mask_; // true for simple, false for complex. 
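The SnappyUncompress helper above is the core of the version-1 read path: it builds one iovec entry per tensor, pointing either at a freshly allocated simple tensor's backing buffer or at scratch space for a serialized TensorProto, and decompresses the whole record directly into those destinations, skipping an intermediate copy. A minimal standalone illustration of the same technique against the snappy library (POSIX iovec assumed; plain strings stand in for tensor data; this is not the TF API):

#include <sys/uio.h>

#include <cassert>
#include <string>
#include <vector>

#include "snappy.h"

int main() {
  // Writer side: concatenate two raw buffers and compress them as one block.
  std::string a(1 << 20, 'a'), b(1 << 10, 'b');  // stand-ins for tensor buffers
  std::string uncompressed = a + b;
  std::string compressed;
  snappy::Compress(uncompressed.data(), uncompressed.size(), &compressed);

  // Reader side: scatter the decompressed bytes straight into preallocated
  // destination buffers, one iovec entry per tensor.
  std::vector<char> dst_a(a.size()), dst_b(b.size());
  struct iovec iov[2];
  iov[0].iov_base = dst_a.data();
  iov[0].iov_len = dst_a.size();
  iov[1].iov_base = dst_b.data();
  iov[1].iov_len = dst_b.size();
  bool ok = snappy::RawUncompressToIOVec(compressed.data(), compressed.size(),
                                         iov, 2);
  assert(ok && dst_a[0] == 'a' && dst_b.back() == 'b');
  return ok ? 0 : 1;
}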
}; Status WriteMetadataFile(const string& hash_dir, const experimental::SnapshotMetadataRecord& metadata) { string metadata_filename = io::JoinPath(hash_dir, kSnapshotFilename); TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(hash_dir)); - std::string tmp_filename = absl::StrCat(metadata_filename, "-tmp-", random::New64()); - - std::unique_ptr file; - TF_RETURN_IF_ERROR(Env::Default()->NewWritableFile(tmp_filename, &file)); - - auto writer = absl::make_unique(file.get()); - TF_RETURN_IF_ERROR(writer->WriteRecord(metadata.SerializeAsString())); - TF_RETURN_IF_ERROR(writer->Close()); - TF_RETURN_IF_ERROR(file->Sync()); - TF_RETURN_IF_ERROR(file->Close()); - - TF_RETURN_IF_ERROR( - Env::Default()->RenameFile(tmp_filename, metadata_filename)); - - return Status::OK(); + TF_RETURN_IF_ERROR(WriteBinaryProto(Env::Default(), tmp_filename, metadata)); + return Env::Default()->RenameFile(tmp_filename, metadata_filename); } Status ReadMetadataFile(const string& hash_dir, experimental::SnapshotMetadataRecord* metadata) { string metadata_filename = io::JoinPath(hash_dir, kSnapshotFilename); TF_RETURN_IF_ERROR(Env::Default()->FileExists(metadata_filename)); - - std::unique_ptr file; - TF_RETURN_IF_ERROR( - Env::Default()->NewRandomAccessFile(metadata_filename, &file)); - - tstring record_bytes; - SnapshotReader reader(file.get()); - TF_RETURN_IF_ERROR(reader.ReadRecord(&record_bytes)); - - metadata->ParseFromArray(record_bytes.data(), record_bytes.size()); - return Status::OK(); + return ReadBinaryProto(Env::Default(), metadata_filename, metadata); } Status DumpDatasetGraph(const std::string& path, uint64 hash, @@ -332,6 +540,10 @@ Status DetermineOpState(const std::string& mode_string, const uint64 pending_snapshot_expiry_seconds, SnapshotMode* mode) { if (mode_string == kModeRead) { + // In read mode, we should expect a metadata file is written. + if (errors::IsNotFound(file_status)) { + return file_status; + } LOG(INFO) << "Overriding mode to reader."; *mode = READER; return Status::OK(); @@ -727,10 +939,25 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { if (run_id.empty()) { run_id = metadata.run_id(); } + // dtypes in metadata should be the same as dataset()->output_dtypes + if (metadata.dtype_size() != dataset()->output_dtypes().size()) { + return errors::Internal( + "Expected number of dtypes: ", + dataset()->output_dtypes().size(), + " but number in snapshot: ", metadata.dtype_size()); + } + for (int i = 0; i < metadata.dtype_size(); ++i) { + if (metadata.dtype(i) != dataset()->output_dtypes()[i]) { + return errors::Internal( + "Type: ", i, + " doesn't match. 
Snapshot: ", metadata.dtype(i), + "; dataset: ", dataset()->output_dtypes()[i]); + } + } iterator_ = absl::make_unique( SnapshotReaderIterator::Params{ dataset(), absl::StrCat(prefix(), "ReaderImpl")}, - hash_dir_, run_id); + hash_dir_, run_id, metadata.version()); break; case PASSTHROUGH: iterator_ = absl::make_unique( @@ -748,10 +975,11 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { explicit SnapshotReaderIterator(const Params& params, const string& hash_dir, - const string& run_id) + const string& run_id, int64 version) : DatasetIterator(params), hash_dir_(hash_dir), - run_id_(run_id) {} + run_id_(run_id), + version_(version) {} ~SnapshotReaderIterator() override { mutex_lock l(mu_); @@ -889,6 +1117,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { writer->WriteScalar(full_name(kHashDir), hash_dir_)); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kRunId), run_id_)); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kRunDir), run_dir_)); + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name(kVersionStr), version_)); TF_RETURN_IF_ERROR(writer->WriteScalar( full_name(strings::StrCat(kFilenames, kSizeSuffix)), filenames_.size())); @@ -932,6 +1162,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { } TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunId), &run_id_)); TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunDir), &run_dir_)); + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name(kVersionStr), &version_)); curr_filenames_.clear(); curr_filenames_.reserve(dataset()->num_reader_threads_); for (auto i = 0; i < dataset()->num_reader_threads_; ++i) { @@ -986,7 +1218,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr file; TF_RETURN_IF_ERROR( Env::Default()->NewRandomAccessFile(filename, &file)); - SnapshotReader reader(file.get(), dataset()->compression_); + SnapshotReader reader(file.get(), dataset()->compression_, version_, + dataset()->output_dtypes()); while (true) { // Wait for a slot in the buffer. 
@@ -1003,30 +1236,14 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { "ReadFile"); } } -#if !defined(PLATFORM_GOOGLE) - tstring record_bytes; - Status s = reader.ReadRecord(&record_bytes); -#else - absl::Cord record_cord; - Status s = reader.ReadRecord(&record_cord); -#endif + std::vector read_tensors; + Status s = reader.ReadTensors(&read_tensors); if (s.ok()) { profiler::TraceMe activity( [&]() { return absl::StrCat(prefix(), kSeparator, kParse); }, profiler::TraceMeLevel::kInfo); - experimental::SnapshotRecord record; -#if !defined(PLATFORM_GOOGLE) - record.ParseFromArray(record_bytes.data(), record_bytes.size()); -#else - record.ParseFromCord(record_cord); -#endif BufferElement elem; - for (int i = 0; i < record.tensor_size(); ++i) { - elem.value.emplace_back(); - if (!elem.value.back().FromProto(record.tensor(i))) { - return errors::DataLoss("Unable to parse tensor from proto."); - } - } + elem.value = std::move(read_tensors); elem.status = Status::OK(); mutex_lock l(mu_); buffer_.push_back(std::move(elem)); @@ -1142,9 +1359,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { condition_variable cond_var_; const string hash_dir_; - const experimental::SnapshotMetadataRecord metadata_; tstring run_id_ GUARDED_BY(mu_); tstring run_dir_ GUARDED_BY(mu_); + int64 version_; std::vector filenames_; uint64 elements_produced_ GUARDED_BY(mu_) = 0; @@ -1220,6 +1437,10 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { metadata.set_creation_timestamp(EnvTime::NowMicros()); metadata.set_graph_hash(dataset()->graph_hash_); metadata.set_run_id(run_id_.data(), run_id_.size()); + metadata.set_version(kCurrentVersion); + for (const auto& output_dtype : dataset()->output_dtypes()) { + metadata.add_dtype(output_dtype); + } metadata.set_finalized(false); TF_RETURN_IF_ERROR(WriteMetadataFile(hash_dir_, metadata)); } @@ -1564,11 +1785,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { } if (produced_elem) { - experimental::SnapshotRecord record; for (const auto& out_tensor : elem.value) { *bytes_written += out_tensor.TotalBytes(); - TensorProto* t = record.add_tensor(); - out_tensor.AsProtoTensorContent(t); } bool should_close; @@ -1584,16 +1802,11 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(Env::Default()->NewAppendableFile( *snapshot_data_filename, file)); *writer = absl::make_unique( - file->get(), dataset()->compression_); + file->get(), dataset()->compression_, kCurrentVersion, + dataset()->output_dtypes()); *bytes_written = 0; } -#if defined(PLATFORM_GOOGLE) - TF_RETURN_IF_ERROR( - (*writer)->WriteRecord(record.SerializeAsCord())); -#else // PLATFORM_GOOGLE - TF_RETURN_IF_ERROR( - (*writer)->WriteRecord(record.SerializeAsString())); -#endif // PLATFORM_GOOGLE + TF_RETURN_IF_ERROR((*writer)->WriteTensors(elem.value)); return Status::OK(); } @@ -1641,7 +1854,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { return; } std::unique_ptr writer( - new SnapshotWriter(file.get(), dataset()->compression_)); + new SnapshotWriter(file.get(), dataset()->compression_, + kCurrentVersion, dataset()->output_dtypes())); bool end_of_processing = false; while (!end_of_processing) { diff --git a/tensorflow/core/platform/default/port.cc b/tensorflow/core/platform/default/port.cc index 47f4abae3bb..756e7e8a93a 100644 --- a/tensorflow/core/platform/default/port.cc +++ b/tensorflow/core/platform/default/port.cc @@ -332,6 +332,16 @@ bool Snappy_Uncompress(const char* input, size_t length, char* output) { #endif } +bool 
Snappy_UncompressToIOVec(const char* compressed, size_t compressed_length, + const struct iovec* iov, size_t iov_cnt) { +#ifdef TF_USE_SNAPPY + return snappy::RawUncompressToIOVec(compressed, compressed_length, iov, + iov_cnt); +#else + return false; +#endif +} + string Demangle(const char* mangled) { return mangled; } double NominalCPUFrequency() { diff --git a/tensorflow/core/platform/snappy.h b/tensorflow/core/platform/snappy.h index 5477b097ef0..df06f3dcc1e 100644 --- a/tensorflow/core/platform/snappy.h +++ b/tensorflow/core/platform/snappy.h @@ -18,6 +18,17 @@ limitations under the License. #include "tensorflow/core/platform/types.h" +#if !defined(PLATFORM_WINDOWS) +#include +#else +namespace tensorflow { +struct iovec { + void* iov_base; + size_t iov_len; +}; +} // namespace tensorflow +#endif + namespace tensorflow { namespace port { @@ -28,6 +39,9 @@ bool Snappy_GetUncompressedLength(const char* input, size_t length, size_t* result); bool Snappy_Uncompress(const char* input, size_t length, char* output); +bool Snappy_UncompressToIOVec(const char* compressed, size_t compressed_length, + const struct iovec* iov, size_t iov_cnt); + } // namespace port } // namespace tensorflow diff --git a/tensorflow/core/platform/windows/port.cc b/tensorflow/core/platform/windows/port.cc index 2303b587ce6..547af76bdf6 100644 --- a/tensorflow/core/platform/windows/port.cc +++ b/tensorflow/core/platform/windows/port.cc @@ -157,6 +157,17 @@ bool Snappy_Uncompress(const char* input, size_t length, char* output) { #endif } +bool Snappy_UncompressToIOVec(const char* compressed, size_t compressed_length, + const struct iovec* iov, size_t iov_cnt) { +#ifdef TF_USE_SNAPPY + const snappy::iovec* snappy_iov = reinterpret_cast(iov); + return snappy::RawUncompressToIOVec(compressed, compressed_length, snappy_iov, + iov_cnt); +#else + return false; +#endif +} + string Demangle(const char* mangled) { return mangled; } double NominalCPUFrequency() { diff --git a/tensorflow/core/protobuf/data/experimental/snapshot.proto b/tensorflow/core/protobuf/data/experimental/snapshot.proto index 422602d3760..e013deb2ee1 100644 --- a/tensorflow/core/protobuf/data/experimental/snapshot.proto +++ b/tensorflow/core/protobuf/data/experimental/snapshot.proto @@ -3,6 +3,8 @@ syntax = "proto3"; package tensorflow.data.experimental; import "tensorflow/core/framework/tensor.proto"; +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; // Each SnapshotRecord represents one batch of pre-processed input data. A batch // consists of a list of tensors that we encode as TensorProtos. This message @@ -13,9 +15,29 @@ message SnapshotRecord { // This stores the metadata information present in each snapshot record. message SnapshotMetadataRecord { + // Stores the fingerprint of the graph that describes the dataset that is + // snapshotted. string graph_hash = 1; + // Run ID that this snapshot corresponds to. string run_id = 2; + // Time when we started creating this snapshot. int64 creation_timestamp = 3; + // Version of the snapshot data file format. + int64 version = 4; + // A list of tensor dtype corresponding to each element of the snapshot. + repeated .tensorflow.DataType dtype = 5; bool finalized = 1000; } + +// Metadata for a single tensor in the Snapshot Record. +message TensorMetadata { + .tensorflow.TensorShapeProto tensor_shape = 2; + // Number of uncompressed bytes used to store the tensor representation. 
+ int64 tensor_size_bytes = 3; +} + +// Metadata for all the tensors in a Snapshot Record. +message SnapshotTensorMetadata { + repeated TensorMetadata tensor_metadata = 1; +} diff --git a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py index 96b3b764864..535cf884dc6 100644 --- a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py @@ -161,17 +161,49 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase, self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) - @combinations.generate(test_base.default_test_combinations()) - def testWriteSnapshotRepeatAfterwards(self): + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(compression=[ + snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP, + snapshot.COMPRESSION_SNAPPY + ]))) + def testWriteSnapshotRepeatAfterwards(self, compression): tmpdir = self.snapshot_dir dataset = dataset_ops.Dataset.range(10) - dataset = dataset.apply(snapshot.snapshot(tmpdir)) + dataset = dataset.apply(snapshot.snapshot(tmpdir, compression=compression)) dataset = dataset.repeat(10) self.assertDatasetProduces(dataset, list(range(10)) * 10) self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(compression=[ + snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP, + snapshot.COMPRESSION_SNAPPY + ]))) + def testWriteSnapshotMixTypes(self, compression): + tmpdir = self.snapshot_dir + + dataset = dataset_ops.Dataset.range(10) + + def map_fn(x): + return (x, string_ops.as_string(x), string_ops.as_string(2 * x), 2 * x) + + dataset = dataset.map(map_fn) + dataset = dataset.apply(snapshot.snapshot(tmpdir, compression=compression)) + dataset = dataset.repeat(10) + + expected = [] + for i in range(10): + expected.append((i, str(i), str(2 * i), 2 * i)) + self.assertDatasetProduces(dataset, expected * 10) + + self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1) + @combinations.generate(test_base.default_test_combinations()) def testSpecifySnapshotNameWriteAndRead(self): tmpdir = self.snapshot_dir @@ -365,8 +397,14 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase, res3 = self.evaluate(next3()) self.assertEqual(res2, res3) - @combinations.generate(test_base.default_test_combinations()) - def testReadSnapshotParallelAfterWrite(self): + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(compression=[ + snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP, + snapshot.COMPRESSION_SNAPPY + ]))) + def testReadSnapshotParallelAfterWrite(self, compression): self.setUpTFRecord(10, 4000) filenames = self.test_filenames @@ -383,7 +421,8 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase, tmpdir, shard_size_bytes=1024 * 1024, num_reader_threads=2, - reader_buffer_size=10)) + reader_buffer_size=10, + compression=compression)) self.assertDatasetProduces(dataset, expected, assert_items_equal=True) # remove the original files and try to read the data back only from @@ -396,7 +435,8 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase, tmpdir, shard_size_bytes=1024 * 1024, num_reader_threads=2, - reader_buffer_size=10)) + reader_buffer_size=10, + compression=compression)) 
self.assertDatasetProduces(dataset2, expected, assert_items_equal=True) # Not testing Snappy here because Snappy reads currently require a lot of @@ -514,21 +554,31 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase, self.evaluate(next2()) self.assertSnapshotDirectoryContains(tmpdir, 2, 1, 1) - @combinations.generate(test_base.default_test_combinations()) - def testSpecifyShardSize(self): + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(compression=[ + snapshot.COMPRESSION_NONE, snapshot.COMPRESSION_GZIP, + snapshot.COMPRESSION_SNAPPY + ]))) + def testSpecifyShardSize(self, compression): tmpdir = self.snapshot_dir dataset = dataset_ops.Dataset.from_tensor_slices([1.0]) dataset = dataset.map(lambda x: gen_array_ops.broadcast_to(x, [1024, 1024])) dataset = dataset.repeat(10) dataset = dataset.apply( - snapshot.snapshot(tmpdir, shard_size_bytes=10 * 1024 * 1024)) + snapshot.snapshot( + tmpdir, shard_size_bytes=10 * 1024 * 1024, compression=compression)) next_fn = self.getNext(dataset) for _ in range(10): self.evaluate(next_fn()) - self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 3) + num_files = 1 + if compression == snapshot.COMPRESSION_NONE: + num_files = 3 + self.assertSnapshotDirectoryContains(tmpdir, 1, 1, num_files) @combinations.generate(test_base.default_test_combinations()) def testAdditionalOperationsAfterReadBack(self): diff --git a/third_party/snappy.BUILD b/third_party/snappy.BUILD index d93f0307690..a2ab4924f29 100644 --- a/third_party/snappy.BUILD +++ b/third_party/snappy.BUILD @@ -27,6 +27,10 @@ cc_library( "-Wno-implicit-function-declaration", ], }), + defines = select({ + "@org_tensorflow//tensorflow:windows": [], + "//conditions:default": ["HAVE_SYS_UIO_H"], + }), ) genrule(
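The HAVE_SYS_UIO_H define added above makes snappy use the system struct iovec on non-Windows builds, so the pointers handed over by tensorflow::port already have the type snappy expects. On Windows, where the snappy.h change earlier in this patch declares a tensorflow::iovec substitute and snappy supplies its own snappy::iovec, the reinterpret_cast in windows/port.cc is only sound if the two structs share a layout. A hypothetical compile-time check of that assumption (stand-in struct names, not the real headers):

#include <cstddef>

// Stand-ins for tensorflow::iovec and snappy::iovec as they would exist in a
// Windows build; both are expected to be {void* iov_base; size_t iov_len;}.
struct TfIovec { void* iov_base; std::size_t iov_len; };
struct SnappyIovec { void* iov_base; std::size_t iov_len; };

static_assert(sizeof(TfIovec) == sizeof(SnappyIovec),
              "sizes must match to reinterpret_cast between the two");
static_assert(offsetof(TfIovec, iov_base) == offsetof(SnappyIovec, iov_base) &&
              offsetof(TfIovec, iov_len) == offsetof(SnappyIovec, iov_len),
              "field offsets must match for the cast to be safe");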