[tf.data] Fix performance modeling bug in `group_by_window`.

Appending the `::Reduce` suffix to the prefix of the per-group iterator prevented the iterator nodes created for iterating through the contents of a group from being properly attached to the parent iterator.

PiperOrigin-RevId: 314218939
Change-Id: Ic38851d0f9740fc3731bb93050c5e386368bdec4
This commit is contained in:
Jiri Simsa 2020-06-01 15:48:07 -07:00 committed by TensorFlower Gardener
parent e0b56ace77
commit 61a0c3bccd
1 changed files with 10 additions and 11 deletions

View File

@ -200,15 +200,8 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
// We are currently processing a group, so try to get the
// next element.
bool end_of_group;
// TODO(b/154341936): Explicitly stopping and starting this iterator
// should not be necessary, but the `::Reduce` added to the prefix
// passed to `current_group_iterator_` when it was created prevents
// the model from identifying this iterator as the output of
// `current_group_iterator_`.
RecordStop(ctx);
TF_RETURN_IF_ERROR(current_group_iterator_->GetNext(
ctx, out_tensors, &end_of_group));
RecordStart(ctx);
if (!end_of_group) {
// Produce the subelement as output.
*end_of_sequence = false;
@ -360,7 +353,8 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name("current_iterator_not_initialized"), ""));
}
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("group_counter"),
group_counter_ - 1));
return Status::OK();
}
@ -371,7 +365,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true;
// Restoring groups
// Restoring groups_
if (reader->Contains(full_name("groups_size"))) {
int64 size;
TF_RETURN_IF_ERROR(
@ -388,7 +382,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
}
}
// Restoring Windows
// Restoring window_sizes_
if (reader->Contains(full_name("window_sizes_size"))) {
int64 size;
TF_RETURN_IF_ERROR(
@ -404,6 +398,10 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
}
}
// Group counter needs to be restored before current group iterator.
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("group_counter"), &group_counter_));
if (reader->Contains(full_name("current_iterator_not_initialized"))) {
current_group_iterator_.reset();
} else {
@ -493,11 +491,12 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
// Create an iterator for the dataset that was returned by `f`.
return returned_dataset->MakeIterator(
ctx, this, strings::StrCat(prefix(), "::Reduce"),
ctx, this, strings::StrCat(prefix(), "[", group_counter_++, "]"),
&current_group_iterator_);
}
mutex mu_;
int64 group_counter_ TF_GUARDED_BY(mu_) = 0;
std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
// TODO(mrry): Optimize for dense key space if appropriate.
bool end_of_input_ TF_GUARDED_BY(mu_) = false;