[tf.data] Support checkpointing parallel map datasets with large buffers.

PiperOrigin-RevId: 327095638
Change-Id: I88ca358f4c9688788a33e4348bfe47c68fabd0bc
Author: Andrew Audibert (committed by TensorFlower Gardener)
Date: 2020-08-17 14:10:01 -07:00
parent 30bd504acc
commit 4112865ad4
2 changed files with 59 additions and 54 deletions
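What this change does: the parallel map iterator buffers in-flight invocation results, and it used to checkpoint that buffer under flat string keys built with full_name() and suffixes like ".size". The diff below gives every buffered element its own (name, key) address, so a large buffer no longer has to serialize into a single protocol buffer, which is capped at 2GB (the limit the new test's comment refers to). A minimal sketch of the user-facing round trip, assuming TF 2.x eager mode; the dataset, parallelism, and checkpoint path are illustrative:

import tensorflow as tf

# Checkpoint and restore an iterator over a parallel map. With this commit,
# the save works even when the map buffer holds gigabytes of tensors.
dataset = tf.data.Dataset.range(100).map(lambda x: x * 2,
                                         num_parallel_calls=4)
iterator = iter(dataset)
print(next(iterator).numpy())  # 0; parallel calls are now in flight

ckpt = tf.train.Checkpoint(iterator=iterator)
path = ckpt.save("/tmp/map_iterator")  # buffered results are saved too

print(next(iterator).numpy())  # 2
ckpt.restore(path)                     # rewind to the saved state
print(next(iterator).numpy())  # 2 again, the iterator state was rewound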

File: tensorflow/core/kernels/data/parallel_map_dataset_op.cc

@@ -57,11 +57,12 @@ namespace data {
 namespace {
+constexpr char kComponent[] = "component";
 constexpr char kInvocationResults[] = "invocation_results";
-constexpr char kSizeSuffix[] = ".size";
-constexpr char kEndOfInputSuffix[] = ".end_of_input";
-constexpr char kCodeSuffix[] = ".code";
-constexpr char kErrorMessage[] = ".error_message";
+constexpr char kSize[] = "size";
+constexpr char kEndOfInput[] = "end_of_input";
+constexpr char kErrorCode[] = "code";
+constexpr char kErrorMessage[] = "error_message";
 
 // Period between reporting dataset statistics.
 constexpr int kStatsReportingPeriodMillis = 1000;
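The constant renames above reflect the new addressing scheme: checkpoint entries are now a (name, key) pair instead of one flat string with suffixes. Roughly, for buffered element 0 (plain-Python illustration, not a TensorFlow API; the "Iterator::ParallelMapV2" prefix is a made-up example of what prefix() might return):

# Before: one flat key per datum, all under the iterator's full_name().
old_keys = [
    "Iterator::ParallelMapV2::invocation_results.size",
    "Iterator::ParallelMapV2::invocation_results[0].size",
    "Iterator::ParallelMapV2::invocation_results[0][0]",  # component tensor
    "Iterator::ParallelMapV2::invocation_results[0].end_of_input",
]
# After: each element gets its own name, with short keys beneath it.
new_keys = [
    ("Iterator::ParallelMapV2::invocation_results", "size"),
    ("Iterator::ParallelMapV2::invocation_results::0", "size"),
    ("Iterator::ParallelMapV2::invocation_results::0", "component[0]"),
    ("Iterator::ParallelMapV2::invocation_results::0", "end_of_input"),
]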
@@ -274,27 +275,25 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
             "Unexpected outstanding calls encountered.");
       }
       TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
-      TF_RETURN_IF_ERROR(writer->WriteScalar(
-          full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
-          invocation_results_.size()));
+      TF_RETURN_IF_ERROR(
+          writer->WriteScalar(absl::StrCat(prefix(), "::", kInvocationResults),
+                              kSize, invocation_results_.size()));
       for (size_t i = 0; i < invocation_results_.size(); i++) {
         const auto& result = *(invocation_results_[i]);
-        TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result.status));
-        TF_RETURN_IF_ERROR(writer->WriteScalar(
-            full_name(
-                strings::StrCat(kInvocationResults, "[", i, "]", kSizeSuffix)),
-            result.return_values.size()));
+        std::string element_prefix =
+            absl::StrCat(prefix(), "::", kInvocationResults, "::", i);
+        TF_RETURN_IF_ERROR(
+            WriteStatusLocked(writer, element_prefix, result.status));
+        TF_RETURN_IF_ERROR(writer->WriteScalar(element_prefix, kSize,
+                                               result.return_values.size()));
         for (size_t j = 0; j < result.return_values.size(); j++) {
           TF_RETURN_IF_ERROR(writer->WriteTensor(
-              full_name(
-                  strings::StrCat(kInvocationResults, "[", i, "][", j, "]")),
+              element_prefix, absl::StrCat(kComponent, "[", j, "]"),
              result.return_values[j]));
         }
         if (result.end_of_input) {
-          TF_RETURN_IF_ERROR(writer->WriteScalar(
-              full_name(strings::StrCat(kInvocationResults, "[", i, "]",
-                                        kEndOfInputSuffix)),
-              ""));
+          TF_RETURN_IF_ERROR(
+              writer->WriteScalar(element_prefix, kEndOfInput, ""));
         }
       }
       return Status::OK();
@@ -305,39 +304,36 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
       mutex_lock l(*mu_);
       TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
       int64 invocation_results_size;
-      TF_RETURN_IF_ERROR(reader->ReadScalar(
-          full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
-          &invocation_results_size));
+      TF_RETURN_IF_ERROR(
+          reader->ReadScalar(absl::StrCat(prefix(), "::", kInvocationResults),
+                             kSize, &invocation_results_size));
       if (!invocation_results_.empty()) invocation_results_.clear();
       for (size_t i = 0; i < invocation_results_size; i++) {
         invocation_results_.push_back(std::make_shared<InvocationResult>());
         auto& result = *invocation_results_.back();
-        TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result.status));
+        std::string element_prefix =
+            absl::StrCat(prefix(), "::", kInvocationResults, "::", i);
+        TF_RETURN_IF_ERROR(
+            ReadStatusLocked(reader, element_prefix, &result.status));
         size_t num_return_values;
         {
           int64 size;
-          TF_RETURN_IF_ERROR(reader->ReadScalar(
-              full_name(strings::StrCat(kInvocationResults, "[", i, "]",
-                                        kSizeSuffix)),
-              &size));
+          TF_RETURN_IF_ERROR(reader->ReadScalar(element_prefix, kSize, &size));
           num_return_values = static_cast<size_t>(size);
           if (num_return_values != size) {
-            return errors::InvalidArgument(strings::StrCat(
-                full_name(strings::StrCat(kInvocationResults, "[", i, "]",
-                                          kSizeSuffix)),
-                ": ", size, " is not a valid value of type size_t."));
+            return errors::InvalidArgument(
+                element_prefix, ",", kSize, ": ", size,
+                " is not a valid value of type size_t.");
           }
         }
         result.return_values.reserve(num_return_values);
         for (size_t j = 0; j < num_return_values; j++) {
           result.return_values.emplace_back();
-          TF_RETURN_IF_ERROR(
-              reader->ReadTensor(full_name(strings::StrCat(
-                                     kInvocationResults, "[", i, "][", j, "]")),
-                                 &result.return_values.back()));
+          TF_RETURN_IF_ERROR(reader->ReadTensor(
+              element_prefix, absl::StrCat(kComponent, "[", j, "]"),
+              &result.return_values.back()));
         }
-        result.end_of_input = reader->Contains(full_name(strings::StrCat(
-            kInvocationResults, "[", i, "]", kEndOfInputSuffix)));
+        result.end_of_input = reader->Contains(element_prefix, kEndOfInput);
         result.notification.Notify();
       }
       return Status::OK();
@@ -592,28 +588,28 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
       }
     }
 
-    Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
-                             const Status& status)
+    Status WriteStatusLocked(IteratorStateWriter* writer,
+                             const std::string& key, const Status& status)
         TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       TF_RETURN_IF_ERROR(writer->WriteScalar(
-          CodeKey(index), static_cast<int64>(status.code())));
+          key, kErrorCode, static_cast<int64>(status.code())));
       if (!status.ok()) {
-        TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
-                                               status.error_message()));
+        TF_RETURN_IF_ERROR(
+            writer->WriteScalar(key, kErrorMessage, status.error_message()));
       }
       return Status::OK();
     }
 
-    Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
+    Status ReadStatusLocked(IteratorStateReader* reader, const std::string& key,
                             Status* status) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       int64 code_int;
-      TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
+      TF_RETURN_IF_ERROR(reader->ReadScalar(key, kErrorCode, &code_int));
       error::Code code = static_cast<error::Code>(code_int);
       if (code != error::Code::OK) {
         tstring error_message;
         TF_RETURN_IF_ERROR(
-            reader->ReadScalar(ErrorMessageKey(index), &error_message));
+            reader->ReadScalar(key, kErrorMessage, &error_message));
         *status = Status(code, error_message);
       } else {
         *status = Status::OK();
@@ -621,16 +617,6 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
       return Status::OK();
     }
 
-    string CodeKey(size_t index) {
-      return full_name(
-          strings::StrCat(kInvocationResults, "[", index, "]", kCodeSuffix));
-    }
-
-    string ErrorMessageKey(size_t index) {
-      return full_name(
-          strings::StrCat(kInvocationResults, "[", index, "]", kErrorMessage));
-    }
-
     // Used for coordination between the main thread and the runner thread.
     const std::shared_ptr<mutex> mu_;
     // Used for coordination between the main thread and the runner thread. In
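The test below is the regression test for this scenario: it checkpoints an iterator whose parallel-map buffer alone is too big for a single proto. A quick size check (assuming 4 bytes per float32; the numbers mirror the test body):

# Rough sizing of the buffer the test asks the checkpoint to hold.
element_bytes = 25 * 1000 * 1000 * 4   # one dataset element: 100 MB
buffered_elements = 25                 # num_parallel_calls in the test
print(element_bytes * buffered_elements / 2**30)  # ~2.3 GiB, past the 2 GiB proto limit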

File: tensorflow/python/data/kernel_tests/map_test.py

@@ -56,6 +56,8 @@ from tensorflow.python.ops.ragged import ragged_concat_ops
 from tensorflow.python.ops.ragged import ragged_factory_ops
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training.tracking import util as trackable_utils
 
 
 def _test_combinations_with_mode_v1(mode):
@@ -1380,6 +1382,23 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
     dataset = apply_map(dataset, map_function)
     self.assertDatasetProduces(dataset, expected_output=[21])
 
+  @combinations.generate(test_base.eager_only_combinations())
+  def testCheckpointLargeBuffer(self):
+    # Each element is a 100MB tensor of ones.
+    dataset = dataset_ops.Dataset.from_tensors(
+        array_ops.ones((25, 1000, 1000), dtype=dtypes.float32))
+    # Repeat 30 times to exceed the 2GB proto limit.
+    dataset = dataset.repeat(30)
+    dataset = dataset.map(lambda x: x * 2, num_parallel_calls=25)
+    iterator = iter(dataset)
+    # Call next() to trigger parallel map calls.
+    next(iterator)
+    ckpt = trackable_utils.Checkpoint(iterator=iterator)
+    manager = checkpoint_management.CheckpointManager(
+        ckpt, self.get_temp_dir(), max_to_keep=1)
+    manager.save()
+
 
 if __name__ == "__main__":
   test.main()
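The test stops at manager.save(). For completeness, a hypothetical continuation (not part of this commit) that restores the saved state into a fresh iterator, using the same modules the test imports:

# Hypothetical follow-on: restore into a new iterator over the same dataset
# and resume from where the saved iterator left off.
restored_iterator = iter(dataset)
restore_ckpt = trackable_utils.Checkpoint(iterator=restored_iterator)
restore_ckpt.restore(manager.latest_checkpoint)
next(restored_iterator)  # continues from the restored buffer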