[tf.data] Support checkpointing parallel map datasets with large buffers.
PiperOrigin-RevId: 327095638 Change-Id: I88ca358f4c9688788a33e4348bfe47c68fabd0bc
This commit is contained in:
parent
30bd504acc
commit
4112865ad4
@ -57,11 +57,12 @@ namespace data {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kComponent[] = "component";
|
||||
constexpr char kInvocationResults[] = "invocation_results";
|
||||
constexpr char kSizeSuffix[] = ".size";
|
||||
constexpr char kEndOfInputSuffix[] = ".end_of_input";
|
||||
constexpr char kCodeSuffix[] = ".code";
|
||||
constexpr char kErrorMessage[] = ".error_message";
|
||||
constexpr char kSize[] = "size";
|
||||
constexpr char kEndOfInput[] = "end_of_input";
|
||||
constexpr char kErrorCode[] = "code";
|
||||
constexpr char kErrorMessage[] = "error_message";
|
||||
|
||||
// Period between reporting dataset statistics.
|
||||
constexpr int kStatsReportingPeriodMillis = 1000;
|
||||
@ -274,27 +275,25 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
|
||||
"Unexpected outstanding calls encountered.");
|
||||
}
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(
|
||||
full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
|
||||
invocation_results_.size()));
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(absl::StrCat(prefix(), "::", kInvocationResults),
|
||||
kSize, invocation_results_.size()));
|
||||
for (size_t i = 0; i < invocation_results_.size(); i++) {
|
||||
const auto& result = *(invocation_results_[i]);
|
||||
TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result.status));
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(
|
||||
full_name(
|
||||
strings::StrCat(kInvocationResults, "[", i, "]", kSizeSuffix)),
|
||||
result.return_values.size()));
|
||||
std::string element_prefix =
|
||||
absl::StrCat(prefix(), "::", kInvocationResults, "::", i);
|
||||
TF_RETURN_IF_ERROR(
|
||||
WriteStatusLocked(writer, element_prefix, result.status));
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(element_prefix, kSize,
|
||||
result.return_values.size()));
|
||||
for (size_t j = 0; j < result.return_values.size(); j++) {
|
||||
TF_RETURN_IF_ERROR(writer->WriteTensor(
|
||||
full_name(
|
||||
strings::StrCat(kInvocationResults, "[", i, "][", j, "]")),
|
||||
element_prefix, absl::StrCat(kComponent, "[", j, "]"),
|
||||
result.return_values[j]));
|
||||
}
|
||||
if (result.end_of_input) {
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(
|
||||
full_name(strings::StrCat(kInvocationResults, "[", i, "]",
|
||||
kEndOfInputSuffix)),
|
||||
""));
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(element_prefix, kEndOfInput, ""));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
@ -305,39 +304,36 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
|
||||
mutex_lock l(*mu_);
|
||||
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
|
||||
int64 invocation_results_size;
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(
|
||||
full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
|
||||
&invocation_results_size));
|
||||
TF_RETURN_IF_ERROR(
|
||||
reader->ReadScalar(absl::StrCat(prefix(), "::", kInvocationResults),
|
||||
kSize, &invocation_results_size));
|
||||
if (!invocation_results_.empty()) invocation_results_.clear();
|
||||
for (size_t i = 0; i < invocation_results_size; i++) {
|
||||
invocation_results_.push_back(std::make_shared<InvocationResult>());
|
||||
auto& result = *invocation_results_.back();
|
||||
TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result.status));
|
||||
std::string element_prefix =
|
||||
absl::StrCat(prefix(), "::", kInvocationResults, "::", i);
|
||||
TF_RETURN_IF_ERROR(
|
||||
ReadStatusLocked(reader, element_prefix, &result.status));
|
||||
size_t num_return_values;
|
||||
{
|
||||
int64 size;
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(
|
||||
full_name(strings::StrCat(kInvocationResults, "[", i, "]",
|
||||
kSizeSuffix)),
|
||||
&size));
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(element_prefix, kSize, &size));
|
||||
num_return_values = static_cast<size_t>(size);
|
||||
if (num_return_values != size) {
|
||||
return errors::InvalidArgument(strings::StrCat(
|
||||
full_name(strings::StrCat(kInvocationResults, "[", i, "]",
|
||||
kSizeSuffix)),
|
||||
": ", size, " is not a valid value of type size_t."));
|
||||
return errors::InvalidArgument(
|
||||
element_prefix, ",", kSize, ": ", size,
|
||||
" is not a valid value of type size_t.");
|
||||
}
|
||||
}
|
||||
result.return_values.reserve(num_return_values);
|
||||
for (size_t j = 0; j < num_return_values; j++) {
|
||||
result.return_values.emplace_back();
|
||||
TF_RETURN_IF_ERROR(
|
||||
reader->ReadTensor(full_name(strings::StrCat(
|
||||
kInvocationResults, "[", i, "][", j, "]")),
|
||||
&result.return_values.back()));
|
||||
TF_RETURN_IF_ERROR(reader->ReadTensor(
|
||||
element_prefix, absl::StrCat(kComponent, "[", j, "]"),
|
||||
&result.return_values.back()));
|
||||
}
|
||||
result.end_of_input = reader->Contains(full_name(strings::StrCat(
|
||||
kInvocationResults, "[", i, "]", kEndOfInputSuffix)));
|
||||
result.end_of_input = reader->Contains(element_prefix, kEndOfInput);
|
||||
result.notification.Notify();
|
||||
}
|
||||
return Status::OK();
|
||||
@ -592,28 +588,28 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
}
|
||||
|
||||
Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
|
||||
const Status& status)
|
||||
Status WriteStatusLocked(IteratorStateWriter* writer,
|
||||
const std::string& key, const Status& status)
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(
|
||||
CodeKey(index), static_cast<int64>(status.code())));
|
||||
key, kErrorCode, static_cast<int64>(status.code())));
|
||||
if (!status.ok()) {
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
|
||||
status.error_message()));
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(key, kErrorMessage, status.error_message()));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
|
||||
Status ReadStatusLocked(IteratorStateReader* reader, const std::string& key,
|
||||
Status* status) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
|
||||
int64 code_int;
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(key, kErrorCode, &code_int));
|
||||
error::Code code = static_cast<error::Code>(code_int);
|
||||
|
||||
if (code != error::Code::OK) {
|
||||
tstring error_message;
|
||||
TF_RETURN_IF_ERROR(
|
||||
reader->ReadScalar(ErrorMessageKey(index), &error_message));
|
||||
reader->ReadScalar(key, kErrorMessage, &error_message));
|
||||
*status = Status(code, error_message);
|
||||
} else {
|
||||
*status = Status::OK();
|
||||
@ -621,16 +617,6 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
string CodeKey(size_t index) {
|
||||
return full_name(
|
||||
strings::StrCat(kInvocationResults, "[", index, "]", kCodeSuffix));
|
||||
}
|
||||
|
||||
string ErrorMessageKey(size_t index) {
|
||||
return full_name(
|
||||
strings::StrCat(kInvocationResults, "[", index, "]", kErrorMessage));
|
||||
}
|
||||
|
||||
// Used for coordination between the main thread and the runner thread.
|
||||
const std::shared_ptr<mutex> mu_;
|
||||
// Used for coordination between the main thread and the runner thread. In
|
||||
|
@ -56,6 +56,8 @@ from tensorflow.python.ops.ragged import ragged_concat_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import checkpoint_management
|
||||
from tensorflow.python.training.tracking import util as trackable_utils
|
||||
|
||||
|
||||
def _test_combinations_with_mode_v1(mode):
|
||||
@ -1380,6 +1382,23 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset = apply_map(dataset, map_function)
|
||||
self.assertDatasetProduces(dataset, expected_output=[21])
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testCheckpointLargeBuffer(self):
|
||||
# Tensor of size 100MB (25 * 1000 * 1000 float32 values, 4 bytes each)
|
||||
dataset = dataset_ops.Dataset.from_tensors(
|
||||
array_ops.ones((25, 1000, 1000), dtype=dtypes.float32))
|
||||
# Repeat 30 times so that the 25 parallel map calls below (~100MB of buffered
# output each) can exceed the 2G proto limit
|
||||
dataset = dataset.repeat(30)
|
||||
dataset = dataset.map(lambda x: x * 2, num_parallel_calls=25)
|
||||
|
||||
iterator = iter(dataset)
|
||||
# Call next() to trigger parallel map calls.
|
||||
next(iterator)
|
||||
ckpt = trackable_utils.Checkpoint(iterator=iterator)
|
||||
manager = checkpoint_management.CheckpointManager(
|
||||
ckpt, self.get_temp_dir(), max_to_keep=1)
|
||||
manager.save()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user