Making a few improvements and bug fixes to snapshot
1. After restoring from a checkpoint during write mode, its possible that some new files might have been opened up since the last save. In that case the next_file_index_ is the old one and we'll end up trying to write the new data to an existing file. This doesn't play nice with some file systems. So we take the safer option which is getting a list of all the filenames in the directory and then using that to determine which should be the next file to write to. 2. Adding tensorboard visualizations to the reader / writer buffer size to see utilization. This can be helpful during experimentation so that we can perhaps reduce the buffer size and save memory 3. Adding a few VLOG's that help with some debugging. 4. Doubling the size of the snappy output buffer size. PiperOrigin-RevId: 293470358 Change-Id: I38675d7125ce00eb8e1c005292c62d2b8e3a28a5
This commit is contained in:
parent
ffc08f6c44
commit
a3594d8126
@ -70,7 +70,8 @@ const int64 kSnappyWriterOutputBufferSizeBytes = 16 << 20; // 16 MiB
|
||||
// will throw an error if the compressed block length cannot fit in the input
|
||||
// buffer.
|
||||
const int64 kSnappyReaderInputBufferSizeBytes = 1 << 30; // 1 GiB
|
||||
const int64 kSnappyReaderOutputBufferSizeBytes = 16 << 20; // 16 MiB
|
||||
// TODO(b/148804377): Set this in a smarter fashion.
|
||||
const int64 kSnappyReaderOutputBufferSizeBytes = 32 << 20; // 32 MiB
|
||||
|
||||
const size_t kHeaderSize = sizeof(uint64);
|
||||
|
||||
@ -241,12 +242,15 @@ class SnapshotReader {
|
||||
|
||||
#if defined(PLATFORM_GOOGLE)
|
||||
Status ReadRecord(absl::Cord* record) {
|
||||
profiler::TraceMe activity(
|
||||
[&]() { return absl::StrCat(kClassName, kSeparator, kReadCord); },
|
||||
profiler::TraceMeLevel::kInfo);
|
||||
tstring header;
|
||||
TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
|
||||
uint64 length = core::DecodeFixed64(header.data());
|
||||
profiler::TraceMe activity(
|
||||
[&]() {
|
||||
return absl::StrCat(kClassName, kSeparator, kReadCord,
|
||||
"#length=", length, "#");
|
||||
},
|
||||
profiler::TraceMeLevel::kInfo);
|
||||
|
||||
if (compression_type_ == io::compression::kNone) {
|
||||
return input_stream_->ReadNBytes(length, record);
|
||||
@ -655,6 +659,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
TF_RETURN_IF_ERROR(DetermineOpState(
|
||||
dataset()->mode_, s, metadata,
|
||||
dataset()->pending_snapshot_expiry_seconds_, &state_));
|
||||
VLOG(2) << "Snapshot state: " << state_;
|
||||
TF_RETURN_IF_ERROR(InitializeIterator(ctx, metadata));
|
||||
}
|
||||
return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
|
||||
@ -667,6 +672,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name(kState), static_cast<int64>(state_)));
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kHashDir), hash_dir_));
|
||||
VLOG(2) << "Saving Snapshot iterator: " << state_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -689,6 +695,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
experimental::SnapshotMetadataRecord metadata;
|
||||
TF_RETURN_IF_ERROR(ReadMetadataFile(hash_dir_, &metadata));
|
||||
TF_RETURN_IF_ERROR(InitializeIterator(ctx, metadata));
|
||||
VLOG(2) << "Restoring Snapshot iterator: " << state_;
|
||||
return RestoreInput(ctx, reader, iterator_);
|
||||
}
|
||||
|
||||
@ -820,6 +827,10 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
absl::StrCat(dataset()->node_name(), kSeparator,
|
||||
kSnapshotReadElements),
|
||||
static_cast<float>(num_elements_read_), elements_produced_);
|
||||
stats_aggregator->AddScalar(
|
||||
absl::StrCat(dataset()->node_name(), kSeparator,
|
||||
"snapshot_reader_buffer_size"),
|
||||
static_cast<float>(buffer_.size()), elements_produced_);
|
||||
}
|
||||
|
||||
if (!buffer_.empty()) {
|
||||
@ -899,6 +910,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
writer->WriteScalar(full_name(kNumFilesDone), num_files_done_));
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kNumElementsRead),
|
||||
num_elements_read_));
|
||||
VLOG(2) << "Saving SnapshotReaderIterator: " << num_elements_read_
|
||||
<< "; elements_produced: " << elements_produced_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -962,6 +975,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
reader->ReadScalar(full_name(kNumFilesDone), &num_files_done_));
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNumElementsRead),
|
||||
&num_elements_read_));
|
||||
VLOG(2) << "Restoring SnapshotReaderIterator: " << num_elements_read_
|
||||
<< "; elements_produced: " << elements_produced_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -1257,6 +1272,10 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
kSnapshotWrittenElements),
|
||||
static_cast<float>(num_elements_written_),
|
||||
elements_produced_);
|
||||
stats_aggregator->AddScalar(
|
||||
absl::StrCat(dataset()->node_name(), kSeparator,
|
||||
"snapshot_writer_buffer_size"),
|
||||
static_cast<float>(buffer_.size()), elements_produced_);
|
||||
}
|
||||
|
||||
absl::Time end = absl::Now();
|
||||
@ -1315,8 +1334,6 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
buffer_element.value[j]));
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name(kNextFileIndex), next_file_index_));
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kNumElementsWritten),
|
||||
num_elements_written_));
|
||||
if (next_elem_.end_of_sequence) {
|
||||
@ -1332,6 +1349,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
full_name(strings::StrCat(kNextElem, "[", i, "]")),
|
||||
next_elem_.value[i]));
|
||||
}
|
||||
VLOG(2) << "Saving SnapshotWriterIterator: " << num_elements_written_
|
||||
<< "; elements_produced: " << elements_produced_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -1395,12 +1414,25 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
&buffer_element.value.back()));
|
||||
}
|
||||
}
|
||||
{
|
||||
int64 temp;
|
||||
TF_RETURN_IF_ERROR(
|
||||
reader->ReadScalar(full_name(kNextFileIndex), &temp));
|
||||
next_file_index_ = static_cast<uint64>(temp);
|
||||
// Since the last save we might have written out some files. So we
|
||||
// get a list of files in the directory and take the final filename
|
||||
// written. We use the name of the snapshot file to figure out
|
||||
// next_file_index_;
|
||||
std::vector<std::string> filenames;
|
||||
TF_RETURN_IF_ERROR(ctx->env()->GetMatchingPaths(
|
||||
absl::StrCat(absl::string_view(run_dir_), "/*"), &filenames));
|
||||
std::sort(filenames.begin(), filenames.end());
|
||||
std::string final_filename = filenames.back();
|
||||
std::vector<std::string> split_filename =
|
||||
absl::StrSplit(final_filename, '/');
|
||||
std::vector<std::string> split_snapshot_filename =
|
||||
absl::StrSplit(split_filename.back(), '.');
|
||||
std::string max_num_str = split_snapshot_filename[0];
|
||||
uint64 max_num;
|
||||
if (!strings::safe_strtou64(max_num_str, &max_num)) {
|
||||
return errors::Internal("Could not parse: ", max_num, " as uint64");
|
||||
}
|
||||
next_file_index_ = max_num + 1;
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNumElementsWritten),
|
||||
&num_elements_written_));
|
||||
size_t next_elem_size;
|
||||
@ -1423,6 +1455,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
full_name(strings::StrCat(kNextElem, "[", i, "]")),
|
||||
&next_elem_.value.back()));
|
||||
}
|
||||
VLOG(2) << "Restoring SnapshotWriterIterator: "
|
||||
<< num_elements_written_
|
||||
<< "; elements_produced: " << elements_produced_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -26,6 +26,7 @@ from tensorflow.python.data.experimental.ops import snapshot
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -36,21 +37,23 @@ class SnapshotDatasetSerializationTest(
|
||||
def _build_snapshot_dataset(self,
|
||||
num_threads=1,
|
||||
repeat=False,
|
||||
pending_snapshot_expiry_seconds=-1):
|
||||
pending_snapshot_expiry_seconds=-1,
|
||||
shard_size_bytes=None):
|
||||
|
||||
def ds_fn():
|
||||
snapshot_dir = os.path.join(self.get_temp_dir(), "snapshot")
|
||||
if not os.path.exists(snapshot_dir):
|
||||
os.mkdir(snapshot_dir)
|
||||
self.snapshot_dir = os.path.join(self.get_temp_dir(), "snapshot")
|
||||
if not os.path.exists(self.snapshot_dir):
|
||||
os.mkdir(self.snapshot_dir)
|
||||
dataset = dataset_ops.Dataset.range(1000)
|
||||
dataset = dataset.apply(
|
||||
snapshot.snapshot(
|
||||
snapshot_dir,
|
||||
self.snapshot_dir,
|
||||
num_writer_threads=num_threads,
|
||||
writer_buffer_size=2 * num_threads,
|
||||
num_reader_threads=num_threads,
|
||||
reader_buffer_size=2 * num_threads,
|
||||
pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds))
|
||||
pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds,
|
||||
shard_size_bytes=shard_size_bytes))
|
||||
if repeat:
|
||||
dataset = dataset.repeat(2)
|
||||
return dataset
|
||||
@ -71,6 +74,51 @@ class SnapshotDatasetSerializationTest(
|
||||
ds_fn, [], 900, ckpt_saved=True, verify_exhausted=False))
|
||||
self.assertSequenceEqual(outputs, range(1000))
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
test_base.default_test_combinations(),
|
||||
combinations.combine(pending_snapshot_expiry_seconds=[None, 1])))
|
||||
def testCheckpointBeforeOneEpochThenRunFewStepsSmallShardMultiThread(
|
||||
self, pending_snapshot_expiry_seconds):
|
||||
ds_fn = self._build_snapshot_dataset(
|
||||
pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds,
|
||||
shard_size_bytes=100)
|
||||
|
||||
outputs = []
|
||||
with ops.Graph().as_default() as g:
|
||||
init_op, get_next_op, saver = self._build_graph(ds_fn)
|
||||
with self.session(graph=g) as sess:
|
||||
self._initialize(init_op, sess)
|
||||
start = 0
|
||||
end = 100
|
||||
num_iters = end - start
|
||||
for _ in range(num_iters):
|
||||
outputs.append(sess.run(get_next_op))
|
||||
self._save(sess, saver)
|
||||
start = 100
|
||||
end = 400
|
||||
num_iters = end - start
|
||||
for _ in range(num_iters):
|
||||
outputs.append(sess.run(get_next_op))
|
||||
self.assertSequenceEqual(outputs, range(400))
|
||||
|
||||
outputs = outputs[:100]
|
||||
outputs.extend(
|
||||
self.gen_outputs(
|
||||
ds_fn, [], 900, ckpt_saved=True, verify_exhausted=False))
|
||||
self.assertSequenceEqual(outputs, range(1000))
|
||||
fp_dir_list = os.listdir(self.snapshot_dir)
|
||||
self.assertLen(list(fp_dir_list), 2)
|
||||
for d in fp_dir_list:
|
||||
if not d.endswith("-graph.pbtxt"):
|
||||
fp_dir = os.path.join(self.snapshot_dir, d)
|
||||
run_dir_list = os.listdir(fp_dir)
|
||||
self.assertLen(list(run_dir_list), 2)
|
||||
for e in run_dir_list:
|
||||
if e != "snapshot.metadata":
|
||||
run_dir = os.path.join(fp_dir, e)
|
||||
self.assertLen(list(os.listdir(run_dir)), 258)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
test_base.default_test_combinations(),
|
||||
|
Loading…
x
Reference in New Issue
Block a user