diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
index fea13103efb..17437e7e1f5 100644
--- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
@@ -70,7 +70,8 @@ const int64 kSnappyWriterOutputBufferSizeBytes = 16 << 20;  // 16 MiB
 // will throw an error if the compressed block length cannot fit in the input
 // buffer.
 const int64 kSnappyReaderInputBufferSizeBytes = 1 << 30;  // 1 GiB
-const int64 kSnappyReaderOutputBufferSizeBytes = 16 << 20;  // 16 MiB
+// TODO(b/148804377): Set this in a smarter fashion.
+const int64 kSnappyReaderOutputBufferSizeBytes = 32 << 20;  // 32 MiB
 
 const size_t kHeaderSize = sizeof(uint64);
 
@@ -241,12 +242,15 @@ class SnapshotReader {
 
 #if defined(PLATFORM_GOOGLE)
   Status ReadRecord(absl::Cord* record) {
-    profiler::TraceMe activity(
-        [&]() { return absl::StrCat(kClassName, kSeparator, kReadCord); },
-        profiler::TraceMeLevel::kInfo);
     tstring header;
     TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
     uint64 length = core::DecodeFixed64(header.data());
+    profiler::TraceMe activity(
+        [&]() {
+          return absl::StrCat(kClassName, kSeparator, kReadCord,
+                              "#length=", length, "#");
+        },
+        profiler::TraceMeLevel::kInfo);
 
     if (compression_type_ == io::compression::kNone) {
       return input_stream_->ReadNBytes(length, record);
@@ -655,6 +659,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
         TF_RETURN_IF_ERROR(DetermineOpState(
             dataset()->mode_, s, metadata,
             dataset()->pending_snapshot_expiry_seconds_, &state_));
+        VLOG(2) << "Snapshot state: " << state_;
         TF_RETURN_IF_ERROR(InitializeIterator(ctx, metadata));
       }
       return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
@@ -667,6 +672,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
       TF_RETURN_IF_ERROR(
          writer->WriteScalar(full_name(kState), static_cast<int64>(state_)));
       TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kHashDir), hash_dir_));
+      VLOG(2) << "Saving Snapshot iterator: " << state_;
       return Status::OK();
     }
 
@@ -689,6 +695,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
       experimental::SnapshotMetadataRecord metadata;
       TF_RETURN_IF_ERROR(ReadMetadataFile(hash_dir_, &metadata));
       TF_RETURN_IF_ERROR(InitializeIterator(ctx, metadata));
+      VLOG(2) << "Restoring Snapshot iterator: " << state_;
       return RestoreInput(ctx, reader, iterator_);
     }
 
@@ -820,6 +827,10 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
                 absl::StrCat(dataset()->node_name(), kSeparator,
                              kSnapshotReadElements),
                 static_cast<float>(num_elements_read_), elements_produced_);
+            stats_aggregator->AddScalar(
+                absl::StrCat(dataset()->node_name(), kSeparator,
+                             "snapshot_reader_buffer_size"),
+                static_cast<float>(buffer_.size()), elements_produced_);
           }
 
           if (!buffer_.empty()) {
@@ -899,6 +910,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
             writer->WriteScalar(full_name(kNumFilesDone), num_files_done_));
         TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kNumElementsRead),
                                                num_elements_read_));
+        VLOG(2) << "Saving SnapshotReaderIterator: " << num_elements_read_
+                << "; elements_produced: " << elements_produced_;
         return Status::OK();
       }
 
@@ -962,6 +975,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
             reader->ReadScalar(full_name(kNumFilesDone), &num_files_done_));
         TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNumElementsRead),
                                               &num_elements_read_));
+        VLOG(2) << "Restoring SnapshotReaderIterator: " << num_elements_read_
+                << "; elements_produced: " << elements_produced_;
         return Status::OK();
       }
 
@@ -1257,6 +1272,10 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
                              kSnapshotWrittenElements),
                 static_cast<float>(num_elements_written_),
                 elements_produced_);
+            stats_aggregator->AddScalar(
+                absl::StrCat(dataset()->node_name(), kSeparator,
+                             "snapshot_writer_buffer_size"),
+                static_cast<float>(buffer_.size()), elements_produced_);
           }
 
           absl::Time end = absl::Now();
@@ -1315,8 +1334,6 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
                 buffer_element.value[j]));
           }
         }
-        TF_RETURN_IF_ERROR(
-            writer->WriteScalar(full_name(kNextFileIndex), next_file_index_));
         TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kNumElementsWritten),
                                                num_elements_written_));
         if (next_elem_.end_of_sequence) {
@@ -1332,6 +1349,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
               full_name(strings::StrCat(kNextElem, "[", i, "]")),
               next_elem_.value[i]));
         }
+        VLOG(2) << "Saving SnapshotWriterIterator: " << num_elements_written_
+                << "; elements_produced: " << elements_produced_;
         return Status::OK();
       }
 
@@ -1395,12 +1414,25 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
               &buffer_element.value.back()));
           }
         }
-        {
-          int64 temp;
-          TF_RETURN_IF_ERROR(
-              reader->ReadScalar(full_name(kNextFileIndex), &temp));
-          next_file_index_ = static_cast<uint64>(temp);
+        // Since the last save we might have written out some files. So we
+        // get a list of files in the directory and take the final filename
+        // written. We use the name of the snapshot file to figure out
+        // next_file_index_;
+        std::vector<std::string> filenames;
+        TF_RETURN_IF_ERROR(ctx->env()->GetMatchingPaths(
+            absl::StrCat(absl::string_view(run_dir_), "/*"), &filenames));
+        std::sort(filenames.begin(), filenames.end());
+        std::string final_filename = filenames.back();
+        std::vector<std::string> split_filename =
+            absl::StrSplit(final_filename, '/');
+        std::vector<std::string> split_snapshot_filename =
+            absl::StrSplit(split_filename.back(), '.');
+        std::string max_num_str = split_snapshot_filename[0];
+        uint64 max_num;
+        if (!strings::safe_strtou64(max_num_str, &max_num)) {
+          return errors::Internal("Could not parse: ", max_num, " as uint64");
         }
+        next_file_index_ = max_num + 1;
         TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNumElementsWritten),
                                               &num_elements_written_));
         size_t next_elem_size;
@@ -1423,6 +1455,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
               full_name(strings::StrCat(kNextElem, "[", i, "]")),
               &next_elem_.value.back()));
         }
+        VLOG(2) << "Restoring SnapshotWriterIterator: "
+                << num_elements_written_
+                << "; elements_produced: " << elements_produced_;
         return Status::OK();
       }
 
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/snapshot_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/snapshot_dataset_serialization_test.py
index 6bdb6e089e2..53261d4b298 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/snapshot_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/snapshot_dataset_serialization_test.py
@@ -26,6 +26,7 @@ from tensorflow.python.data.experimental.ops import snapshot
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import combinations
+from tensorflow.python.framework import ops
 from tensorflow.python.platform import test
 
 
@@ -36,21 +37,23 @@ class SnapshotDatasetSerializationTest(
   def _build_snapshot_dataset(self,
                               num_threads=1,
                               repeat=False,
-                              pending_snapshot_expiry_seconds=-1):
+                              pending_snapshot_expiry_seconds=-1,
+                              shard_size_bytes=None):
 
     def ds_fn():
-      snapshot_dir = os.path.join(self.get_temp_dir(), "snapshot")
-      if not os.path.exists(snapshot_dir):
-        os.mkdir(snapshot_dir)
+      self.snapshot_dir = os.path.join(self.get_temp_dir(), "snapshot")
+      if not os.path.exists(self.snapshot_dir):
+        os.mkdir(self.snapshot_dir)
       dataset = dataset_ops.Dataset.range(1000)
       dataset = dataset.apply(
           snapshot.snapshot(
-              snapshot_dir,
+              self.snapshot_dir,
              num_writer_threads=num_threads,
              writer_buffer_size=2 * num_threads,
              num_reader_threads=num_threads,
              reader_buffer_size=2 * num_threads,
-              pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds))
+              pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds,
+              shard_size_bytes=shard_size_bytes))
       if repeat:
         dataset = dataset.repeat(2)
       return dataset
@@ -71,6 +74,51 @@ class SnapshotDatasetSerializationTest(
             ds_fn, [], 900, ckpt_saved=True, verify_exhausted=False))
     self.assertSequenceEqual(outputs, range(1000))
 
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(pending_snapshot_expiry_seconds=[None, 1])))
+  def testCheckpointBeforeOneEpochThenRunFewStepsSmallShardMultiThread(
+      self, pending_snapshot_expiry_seconds):
+    ds_fn = self._build_snapshot_dataset(
+        pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds,
+        shard_size_bytes=100)
+
+    outputs = []
+    with ops.Graph().as_default() as g:
+      init_op, get_next_op, saver = self._build_graph(ds_fn)
+      with self.session(graph=g) as sess:
+        self._initialize(init_op, sess)
+        start = 0
+        end = 100
+        num_iters = end - start
+        for _ in range(num_iters):
+          outputs.append(sess.run(get_next_op))
+        self._save(sess, saver)
+        start = 100
+        end = 400
+        num_iters = end - start
+        for _ in range(num_iters):
+          outputs.append(sess.run(get_next_op))
+    self.assertSequenceEqual(outputs, range(400))
+
+    outputs = outputs[:100]
+    outputs.extend(
+        self.gen_outputs(
+            ds_fn, [], 900, ckpt_saved=True, verify_exhausted=False))
+    self.assertSequenceEqual(outputs, range(1000))
+    fp_dir_list = os.listdir(self.snapshot_dir)
+    self.assertLen(list(fp_dir_list), 2)
+    for d in fp_dir_list:
+      if not d.endswith("-graph.pbtxt"):
+        fp_dir = os.path.join(self.snapshot_dir, d)
+        run_dir_list = os.listdir(fp_dir)
+        self.assertLen(list(run_dir_list), 2)
+        for e in run_dir_list:
+          if e != "snapshot.metadata":
+            run_dir = os.path.join(fp_dir, e)
+            self.assertLen(list(os.listdir(run_dir)), 258)
+
   @combinations.generate(
       combinations.times(
           test_base.default_test_combinations(),
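
Note on the kNextFileIndex change above: rather than serializing next_file_index_, the restore
path now recovers it by listing the run directory, taking the lexicographically last shard file,
and parsing the numeric prefix of its name. The standalone C++ sketch below only illustrates that
recovery idea under stated assumptions (shard files named "<zero-padded index>.snapshot", a
hypothetical /tmp/snapshot_run directory); it uses std::filesystem and std::stoull in place of
TensorFlow's Env::GetMatchingPaths and strings::safe_strtou64 and is not the actual implementation.

// Simplified sketch of "derive the next shard index from the files already written".
// Assumption: shard files are named "<zero-padded index>.snapshot", as the diff suggests.
#include <algorithm>
#include <cstdint>
#include <filesystem>
#include <iostream>
#include <string>
#include <vector>

// Returns 1 + the largest numeric prefix among "*.snapshot" files in run_dir,
// or 0 when the directory holds no shard files yet.
uint64_t NextFileIndex(const std::string& run_dir) {
  std::vector<std::string> filenames;
  for (const auto& entry : std::filesystem::directory_iterator(run_dir)) {
    if (entry.path().extension() == ".snapshot") {
      filenames.push_back(entry.path().filename().string());
    }
  }
  if (filenames.empty()) return 0;
  // Zero-padded names sort lexicographically in numeric order, so the last
  // entry after sorting is the most recently written shard.
  std::sort(filenames.begin(), filenames.end());
  const std::string& last = filenames.back();
  const std::string max_num_str = last.substr(0, last.find('.'));
  return std::stoull(max_num_str) + 1;
}

int main() {
  // Hypothetical run directory; point this at a real snapshot run directory.
  std::cout << NextFileIndex("/tmp/snapshot_run") << "\n";
  return 0;
}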