From e9ad6196a699454754581f61f47d8a8572c7f21f Mon Sep 17 00:00:00 2001 From: Rachel Lim Date: Thu, 28 May 2020 20:15:47 -0700 Subject: [PATCH] Explicitly call RecordStop when calling group iterator for GroupByWindow dataset, since the current_group_iterator isn't correctly wired up to the output node. See b/154341936 for context. PiperOrigin-RevId: 313709228 Change-Id: I5aa398f71a46713c96aba96f4d42777edfea4fc0 --- .../data/experimental/group_by_window_dataset_op.cc | 7 +++++++ .../experimental/kernel_tests/group_by_window_test.py | 10 ++++++++++ 2 files changed, 17 insertions(+) diff --git a/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc index 462f8ce6ef7..a61ebd70141 100644 --- a/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc @@ -200,8 +200,15 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { // We are currently processing a group, so try to get the // next element. bool end_of_group; + // TODO(b/154341936): Explicitly stopping and starting this iterator + // should not be necessary, but the `::Reduce` added to the prefix + // passed to `current_group_iterator_` when it was created prevents + // the model from identifying this iterator as the output of + // `current_group_iterator_`. + RecordStop(ctx); TF_RETURN_IF_ERROR(current_group_iterator_->GetNext( ctx, out_tensors, &end_of_group)); + RecordStart(ctx); if (!end_of_group) { // Produce the subelement as output. *end_of_sequence = false; diff --git a/tensorflow/python/data/experimental/kernel_tests/group_by_window_test.py b/tensorflow/python/data/experimental/kernel_tests/group_by_window_test.py index 2495083cf63..581d8f42792 100644 --- a/tensorflow/python/data/experimental/kernel_tests/group_by_window_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/group_by_window_test.py @@ -331,6 +331,16 @@ class GroupByWindowTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertDatasetProduces( dataset, expected_output=[[i] for i in range(10)]) + @combinations.generate(test_base.default_test_combinations()) + def testGroupByWindowWithAutotune(self): + dataset = dataset_ops.Dataset.range(1000).apply( + grouping.group_by_window( + lambda x: x // 10, + lambda key, window: dataset_ops.Dataset.from_tensors(key), 4)) + dataset = dataset.map(lambda x: x + 1, num_parallel_calls=-1) + get_next = self.getNext(dataset) + self.evaluate(get_next()) + if __name__ == "__main__": test.main()