From a3594d812679c87355b434850e5a04e7db4a31df Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Wed, 5 Feb 2020 15:48:28 -0800 Subject: [PATCH] Making a few improvements and bug fixes to snapshot 1. After restoring from a checkpoint during write mode, its possible that some new files might have been opened up since the last save. In that case the next_file_index_ is the old one and we'll end up trying to write the new data to an existing file. This doesn't play nice with some file systems. So we take the safer option which is getting a list of all the filenames in the directory and then using that to determine which should be the next file to write to. 2. Adding tensorboard visualizations to the reader / writer buffer size to see utilization. This can be helpful during experimentation so that we can perhaps reduce the buffer size and save memory 3. Adding a few VLOG's that help with some debugging. 4. Doubling the size of the snappy output buffer size. PiperOrigin-RevId: 293470358 Change-Id: I38675d7125ce00eb8e1c005292c62d2b8e3a28a5 --- .../data/experimental/snapshot_dataset_op.cc | 57 ++++++++++++++---- .../snapshot_dataset_serialization_test.py | 60 +++++++++++++++++-- 2 files changed, 100 insertions(+), 17 deletions(-) diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc index fea13103efb..17437e7e1f5 100644 --- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc @@ -70,7 +70,8 @@ const int64 kSnappyWriterOutputBufferSizeBytes = 16 << 20; // 16 MiB // will throw an error if the compressed block length cannot fit in the input // buffer. const int64 kSnappyReaderInputBufferSizeBytes = 1 << 30; // 1 GiB -const int64 kSnappyReaderOutputBufferSizeBytes = 16 << 20; // 16 MiB +// TODO(b/148804377): Set this in a smarter fashion. +const int64 kSnappyReaderOutputBufferSizeBytes = 32 << 20; // 32 MiB const size_t kHeaderSize = sizeof(uint64); @@ -241,12 +242,15 @@ class SnapshotReader { #if defined(PLATFORM_GOOGLE) Status ReadRecord(absl::Cord* record) { - profiler::TraceMe activity( - [&]() { return absl::StrCat(kClassName, kSeparator, kReadCord); }, - profiler::TraceMeLevel::kInfo); tstring header; TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header)); uint64 length = core::DecodeFixed64(header.data()); + profiler::TraceMe activity( + [&]() { + return absl::StrCat(kClassName, kSeparator, kReadCord, + "#length=", length, "#"); + }, + profiler::TraceMeLevel::kInfo); if (compression_type_ == io::compression::kNone) { return input_stream_->ReadNBytes(length, record); @@ -655,6 +659,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(DetermineOpState( dataset()->mode_, s, metadata, dataset()->pending_snapshot_expiry_seconds_, &state_)); + VLOG(2) << "Snapshot state: " << state_; TF_RETURN_IF_ERROR(InitializeIterator(ctx, metadata)); } return iterator_->GetNext(ctx, out_tensors, end_of_sequence); @@ -667,6 +672,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name(kState), static_cast(state_))); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kHashDir), hash_dir_)); + VLOG(2) << "Saving Snapshot iterator: " << state_; return Status::OK(); } @@ -689,6 +695,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { experimental::SnapshotMetadataRecord metadata; TF_RETURN_IF_ERROR(ReadMetadataFile(hash_dir_, &metadata)); TF_RETURN_IF_ERROR(InitializeIterator(ctx, metadata)); + VLOG(2) << "Restoring Snapshot iterator: " << state_; return RestoreInput(ctx, reader, iterator_); } @@ -820,6 +827,10 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { absl::StrCat(dataset()->node_name(), kSeparator, kSnapshotReadElements), static_cast(num_elements_read_), elements_produced_); + stats_aggregator->AddScalar( + absl::StrCat(dataset()->node_name(), kSeparator, + "snapshot_reader_buffer_size"), + static_cast(buffer_.size()), elements_produced_); } if (!buffer_.empty()) { @@ -899,6 +910,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { writer->WriteScalar(full_name(kNumFilesDone), num_files_done_)); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kNumElementsRead), num_elements_read_)); + VLOG(2) << "Saving SnapshotReaderIterator: " << num_elements_read_ + << "; elements_produced: " << elements_produced_; return Status::OK(); } @@ -962,6 +975,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { reader->ReadScalar(full_name(kNumFilesDone), &num_files_done_)); TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNumElementsRead), &num_elements_read_)); + VLOG(2) << "Restoring SnapshotReaderIterator: " << num_elements_read_ + << "; elements_produced: " << elements_produced_; return Status::OK(); } @@ -1257,6 +1272,10 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { kSnapshotWrittenElements), static_cast(num_elements_written_), elements_produced_); + stats_aggregator->AddScalar( + absl::StrCat(dataset()->node_name(), kSeparator, + "snapshot_writer_buffer_size"), + static_cast(buffer_.size()), elements_produced_); } absl::Time end = absl::Now(); @@ -1315,8 +1334,6 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { buffer_element.value[j])); } } - TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name(kNextFileIndex), next_file_index_)); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kNumElementsWritten), num_elements_written_)); if (next_elem_.end_of_sequence) { @@ -1332,6 +1349,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { full_name(strings::StrCat(kNextElem, "[", i, "]")), next_elem_.value[i])); } + VLOG(2) << "Saving SnapshotWriterIterator: " << num_elements_written_ + << "; elements_produced: " << elements_produced_; return Status::OK(); } @@ -1395,12 +1414,25 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { &buffer_element.value.back())); } } - { - int64 temp; - TF_RETURN_IF_ERROR( - reader->ReadScalar(full_name(kNextFileIndex), &temp)); - next_file_index_ = static_cast(temp); + // Since the last save we might have written out some files. So we + // get a list of files in the directory and take the final filename + // written. We use the name of the snapshot file to figure out + // next_file_index_; + std::vector filenames; + TF_RETURN_IF_ERROR(ctx->env()->GetMatchingPaths( + absl::StrCat(absl::string_view(run_dir_), "/*"), &filenames)); + std::sort(filenames.begin(), filenames.end()); + std::string final_filename = filenames.back(); + std::vector split_filename = + absl::StrSplit(final_filename, '/'); + std::vector split_snapshot_filename = + absl::StrSplit(split_filename.back(), '.'); + std::string max_num_str = split_snapshot_filename[0]; + uint64 max_num; + if (!strings::safe_strtou64(max_num_str, &max_num)) { + return errors::Internal("Could not parse: ", max_num, " as uint64"); } + next_file_index_ = max_num + 1; TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNumElementsWritten), &num_elements_written_)); size_t next_elem_size; @@ -1423,6 +1455,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { full_name(strings::StrCat(kNextElem, "[", i, "]")), &next_elem_.value.back())); } + VLOG(2) << "Restoring SnapshotWriterIterator: " + << num_elements_written_ + << "; elements_produced: " << elements_produced_; return Status::OK(); } diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/snapshot_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/snapshot_dataset_serialization_test.py index 6bdb6e089e2..53261d4b298 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/snapshot_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/snapshot_dataset_serialization_test.py @@ -26,6 +26,7 @@ from tensorflow.python.data.experimental.ops import snapshot from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations +from tensorflow.python.framework import ops from tensorflow.python.platform import test @@ -36,21 +37,23 @@ class SnapshotDatasetSerializationTest( def _build_snapshot_dataset(self, num_threads=1, repeat=False, - pending_snapshot_expiry_seconds=-1): + pending_snapshot_expiry_seconds=-1, + shard_size_bytes=None): def ds_fn(): - snapshot_dir = os.path.join(self.get_temp_dir(), "snapshot") - if not os.path.exists(snapshot_dir): - os.mkdir(snapshot_dir) + self.snapshot_dir = os.path.join(self.get_temp_dir(), "snapshot") + if not os.path.exists(self.snapshot_dir): + os.mkdir(self.snapshot_dir) dataset = dataset_ops.Dataset.range(1000) dataset = dataset.apply( snapshot.snapshot( - snapshot_dir, + self.snapshot_dir, num_writer_threads=num_threads, writer_buffer_size=2 * num_threads, num_reader_threads=num_threads, reader_buffer_size=2 * num_threads, - pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds)) + pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds, + shard_size_bytes=shard_size_bytes)) if repeat: dataset = dataset.repeat(2) return dataset @@ -71,6 +74,51 @@ class SnapshotDatasetSerializationTest( ds_fn, [], 900, ckpt_saved=True, verify_exhausted=False)) self.assertSequenceEqual(outputs, range(1000)) + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(pending_snapshot_expiry_seconds=[None, 1]))) + def testCheckpointBeforeOneEpochThenRunFewStepsSmallShardMultiThread( + self, pending_snapshot_expiry_seconds): + ds_fn = self._build_snapshot_dataset( + pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds, + shard_size_bytes=100) + + outputs = [] + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = self._build_graph(ds_fn) + with self.session(graph=g) as sess: + self._initialize(init_op, sess) + start = 0 + end = 100 + num_iters = end - start + for _ in range(num_iters): + outputs.append(sess.run(get_next_op)) + self._save(sess, saver) + start = 100 + end = 400 + num_iters = end - start + for _ in range(num_iters): + outputs.append(sess.run(get_next_op)) + self.assertSequenceEqual(outputs, range(400)) + + outputs = outputs[:100] + outputs.extend( + self.gen_outputs( + ds_fn, [], 900, ckpt_saved=True, verify_exhausted=False)) + self.assertSequenceEqual(outputs, range(1000)) + fp_dir_list = os.listdir(self.snapshot_dir) + self.assertLen(list(fp_dir_list), 2) + for d in fp_dir_list: + if not d.endswith("-graph.pbtxt"): + fp_dir = os.path.join(self.snapshot_dir, d) + run_dir_list = os.listdir(fp_dir) + self.assertLen(list(run_dir_list), 2) + for e in run_dir_list: + if e != "snapshot.metadata": + run_dir = os.path.join(fp_dir, e) + self.assertLen(list(os.listdir(run_dir)), 258) + @combinations.generate( combinations.times( test_base.default_test_combinations(),