[tf.data] Support checkpointing parallel map datasets with large buffers.

PiperOrigin-RevId: 327095638
Change-Id: I88ca358f4c9688788a33e4348bfe47c68fabd0bc
This commit is contained in:
Andrew Audibert 2020-08-17 14:10:01 -07:00 committed by TensorFlower Gardener
parent 30bd504acc
commit 4112865ad4
2 changed files with 59 additions and 54 deletions

View File

@ -57,11 +57,12 @@ namespace data {
namespace {
constexpr char kComponent[] = "component";
constexpr char kInvocationResults[] = "invocation_results";
constexpr char kSizeSuffix[] = ".size";
constexpr char kEndOfInputSuffix[] = ".end_of_input";
constexpr char kCodeSuffix[] = ".code";
constexpr char kErrorMessage[] = ".error_message";
constexpr char kSize[] = "size";
constexpr char kEndOfInput[] = "end_of_input";
constexpr char kErrorCode[] = "code";
constexpr char kErrorMessage[] = "error_message";
// Period between reporting dataset statistics.
constexpr int kStatsReportingPeriodMillis = 1000;
@ -274,27 +275,25 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
"Unexpected outstanding calls encountered.");
}
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
invocation_results_.size()));
TF_RETURN_IF_ERROR(
writer->WriteScalar(absl::StrCat(prefix(), "::", kInvocationResults),
kSize, invocation_results_.size()));
for (size_t i = 0; i < invocation_results_.size(); i++) {
const auto& result = *(invocation_results_[i]);
TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result.status));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(
strings::StrCat(kInvocationResults, "[", i, "]", kSizeSuffix)),
result.return_values.size()));
std::string element_prefix =
absl::StrCat(prefix(), "::", kInvocationResults, "::", i);
TF_RETURN_IF_ERROR(
WriteStatusLocked(writer, element_prefix, result.status));
TF_RETURN_IF_ERROR(writer->WriteScalar(element_prefix, kSize,
result.return_values.size()));
for (size_t j = 0; j < result.return_values.size(); j++) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
full_name(
strings::StrCat(kInvocationResults, "[", i, "][", j, "]")),
element_prefix, absl::StrCat(kComponent, "[", j, "]"),
result.return_values[j]));
}
if (result.end_of_input) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat(kInvocationResults, "[", i, "]",
kEndOfInputSuffix)),
""));
TF_RETURN_IF_ERROR(
writer->WriteScalar(element_prefix, kEndOfInput, ""));
}
}
return Status::OK();
@ -305,39 +304,36 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
mutex_lock l(*mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
int64 invocation_results_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
&invocation_results_size));
TF_RETURN_IF_ERROR(
reader->ReadScalar(absl::StrCat(prefix(), "::", kInvocationResults),
kSize, &invocation_results_size));
if (!invocation_results_.empty()) invocation_results_.clear();
for (size_t i = 0; i < invocation_results_size; i++) {
invocation_results_.push_back(std::make_shared<InvocationResult>());
auto& result = *invocation_results_.back();
TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result.status));
std::string element_prefix =
absl::StrCat(prefix(), "::", kInvocationResults, "::", i);
TF_RETURN_IF_ERROR(
ReadStatusLocked(reader, element_prefix, &result.status));
size_t num_return_values;
{
int64 size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(kInvocationResults, "[", i, "]",
kSizeSuffix)),
&size));
TF_RETURN_IF_ERROR(reader->ReadScalar(element_prefix, kSize, &size));
num_return_values = static_cast<size_t>(size);
if (num_return_values != size) {
return errors::InvalidArgument(strings::StrCat(
full_name(strings::StrCat(kInvocationResults, "[", i, "]",
kSizeSuffix)),
": ", size, " is not a valid value of type size_t."));
return errors::InvalidArgument(
element_prefix, ",", kSize, ": ", size,
" is not a valid value of type size_t.");
}
}
result.return_values.reserve(num_return_values);
for (size_t j = 0; j < num_return_values; j++) {
result.return_values.emplace_back();
TF_RETURN_IF_ERROR(
reader->ReadTensor(full_name(strings::StrCat(
kInvocationResults, "[", i, "][", j, "]")),
&result.return_values.back()));
TF_RETURN_IF_ERROR(reader->ReadTensor(
element_prefix, absl::StrCat(kComponent, "[", j, "]"),
&result.return_values.back()));
}
result.end_of_input = reader->Contains(full_name(strings::StrCat(
kInvocationResults, "[", i, "]", kEndOfInputSuffix)));
result.end_of_input = reader->Contains(element_prefix, kEndOfInput);
result.notification.Notify();
}
return Status::OK();
@ -592,28 +588,28 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
}
}
Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
const Status& status)
Status WriteStatusLocked(IteratorStateWriter* writer,
const std::string& key, const Status& status)
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
CodeKey(index), static_cast<int64>(status.code())));
key, kErrorCode, static_cast<int64>(status.code())));
if (!status.ok()) {
TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
status.error_message()));
TF_RETURN_IF_ERROR(
writer->WriteScalar(key, kErrorMessage, status.error_message()));
}
return Status::OK();
}
Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
Status ReadStatusLocked(IteratorStateReader* reader, const std::string& key,
Status* status) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
int64 code_int;
TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
TF_RETURN_IF_ERROR(reader->ReadScalar(key, kErrorCode, &code_int));
error::Code code = static_cast<error::Code>(code_int);
if (code != error::Code::OK) {
tstring error_message;
TF_RETURN_IF_ERROR(
reader->ReadScalar(ErrorMessageKey(index), &error_message));
reader->ReadScalar(key, kErrorMessage, &error_message));
*status = Status(code, error_message);
} else {
*status = Status::OK();
@ -621,16 +617,6 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
return Status::OK();
}
string CodeKey(size_t index) {
return full_name(
strings::StrCat(kInvocationResults, "[", index, "]", kCodeSuffix));
}
string ErrorMessageKey(size_t index) {
return full_name(
strings::StrCat(kInvocationResults, "[", index, "]", kErrorMessage));
}
// Used for coordination between the main thread and the runner thread.
const std::shared_ptr<mutex> mu_;
// Used for coordination between the main thread and the runner thread. In

View File

@ -56,6 +56,8 @@ from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.tracking import util as trackable_utils
def _test_combinations_with_mode_v1(mode):
@ -1380,6 +1382,23 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = apply_map(dataset, map_function)
self.assertDatasetProduces(dataset, expected_output=[21])
@combinations.generate(test_base.eager_only_combinations())
def testCheckpointLargeBuffer(self):
# Tensor of size 100MB (25 * 1000 * 1000 float32 elements, 4 bytes each).
dataset = dataset_ops.Dataset.from_tensors(
array_ops.ones((25, 1000, 1000), dtype=dtypes.float32))
# Repeat 30 times so the buffered elements exceed the 2GB proto limit.
dataset = dataset.repeat(30)
dataset = dataset.map(lambda x: x * 2, num_parallel_calls=25)
iterator = iter(dataset)
# Call next() to trigger parallel map calls.
next(iterator)
ckpt = trackable_utils.Checkpoint(iterator=iterator)
manager = checkpoint_management.CheckpointManager(
ckpt, self.get_temp_dir(), max_to_keep=1)
manager.save()
if __name__ == "__main__":
test.main()