[tf.data] Fix performance modeling bug in `group_by_window`.
The use of the `::Reduce` suffix caused the iterator nodes created for iterating through the contents of a group to not be properly attached to the parent iterator. PiperOrigin-RevId: 314218939 Change-Id: Ic38851d0f9740fc3731bb93050c5e386368bdec4
This commit is contained in:
parent
e0b56ace77
commit
61a0c3bccd
|
@ -200,15 +200,8 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
|
|||
// We are currently processing a group, so try to get the
|
||||
// next element.
|
||||
bool end_of_group;
|
||||
// TODO(b/154341936): Explicitly stopping and starting this iterator
|
||||
// should not be necessary, but the `::Reduce` added to the prefix
|
||||
// passed to `current_group_iterator_` when it was created prevents
|
||||
// the model from identifying this iterator as the output of
|
||||
// `current_group_iterator_`.
|
||||
RecordStop(ctx);
|
||||
TF_RETURN_IF_ERROR(current_group_iterator_->GetNext(
|
||||
ctx, out_tensors, &end_of_group));
|
||||
RecordStart(ctx);
|
||||
if (!end_of_group) {
|
||||
// Produce the subelement as output.
|
||||
*end_of_sequence = false;
|
||||
|
@ -360,7 +353,8 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
|
|||
TF_RETURN_IF_ERROR(writer->WriteScalar(
|
||||
full_name("current_iterator_not_initialized"), ""));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("group_counter"),
|
||||
group_counter_ - 1));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -371,7 +365,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
|
|||
|
||||
if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true;
|
||||
|
||||
// Restoring groups
|
||||
// Restoring groups_
|
||||
if (reader->Contains(full_name("groups_size"))) {
|
||||
int64 size;
|
||||
TF_RETURN_IF_ERROR(
|
||||
|
@ -388,7 +382,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
|
|||
}
|
||||
}
|
||||
|
||||
// Restoring Windows
|
||||
// Restoring window_sizes_
|
||||
if (reader->Contains(full_name("window_sizes_size"))) {
|
||||
int64 size;
|
||||
TF_RETURN_IF_ERROR(
|
||||
|
@ -404,6 +398,10 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
|
|||
}
|
||||
}
|
||||
|
||||
// Group counter needs to be restored before current group iterator.
|
||||
TF_RETURN_IF_ERROR(
|
||||
reader->ReadScalar(full_name("group_counter"), &group_counter_));
|
||||
|
||||
if (reader->Contains(full_name("current_iterator_not_initialized"))) {
|
||||
current_group_iterator_.reset();
|
||||
} else {
|
||||
|
@ -493,11 +491,12 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
|
|||
|
||||
// Create an iterator for the dataset that was returned by `f`.
|
||||
return returned_dataset->MakeIterator(
|
||||
ctx, this, strings::StrCat(prefix(), "::Reduce"),
|
||||
ctx, this, strings::StrCat(prefix(), "[", group_counter_++, "]"),
|
||||
&current_group_iterator_);
|
||||
}
|
||||
|
||||
mutex mu_;
|
||||
int64 group_counter_ TF_GUARDED_BY(mu_) = 0;
|
||||
std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
|
||||
// TODO(mrry): Optimize for dense key space if appropriate.
|
||||
bool end_of_input_ TF_GUARDED_BY(mu_) = false;
|
||||
|
|
Loading…
Reference in New Issue