diff --git a/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc index a61ebd70141..0a6df24d40a 100644 --- a/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc @@ -200,15 +200,8 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { // We are currently processing a group, so try to get the // next element. bool end_of_group; - // TODO(b/154341936): Explicitly stopping and starting this iterator - // should not be necessary, but the `::Reduce` added to the prefix - // passed to `current_group_iterator_` when it was created prevents - // the model from identifying this iterator as the output of - // `current_group_iterator_`. - RecordStop(ctx); TF_RETURN_IF_ERROR(current_group_iterator_->GetNext( ctx, out_tensors, &end_of_group)); - RecordStart(ctx); if (!end_of_group) { // Produce the subelement as output. *end_of_sequence = false; @@ -360,7 +353,8 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(writer->WriteScalar( full_name("current_iterator_not_initialized"), "")); } - + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("group_counter"), + group_counter_ - 1)); return Status::OK(); } @@ -371,7 +365,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true; - // Restoring groups + // Restoring groups_ if (reader->Contains(full_name("groups_size"))) { int64 size; TF_RETURN_IF_ERROR( @@ -388,7 +382,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { } } - // Restoring Windows + // Restoring window_sizes_ if (reader->Contains(full_name("window_sizes_size"))) { int64 size; TF_RETURN_IF_ERROR( @@ -404,6 +398,10 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { } } + // Group counter needs to be restored before current group iterator. + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("group_counter"), &group_counter_)); + if (reader->Contains(full_name("current_iterator_not_initialized"))) { current_group_iterator_.reset(); } else { @@ -493,11 +491,12 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { // Create an iterator for the dataset that was returned by `f`. return returned_dataset->MakeIterator( - ctx, this, strings::StrCat(prefix(), "::Reduce"), + ctx, this, strings::StrCat(prefix(), "[", group_counter_++, "]"), ¤t_group_iterator_); } mutex mu_; + int64 group_counter_ TF_GUARDED_BY(mu_) = 0; std::unique_ptr input_impl_ TF_GUARDED_BY(mu_); // TODO(mrry): Optimize for dense key space if appropriate. bool end_of_input_ TF_GUARDED_BY(mu_) = false;