Explicitly call RecordStop when calling group iterator for GroupByWindow dataset, since the current_group_iterator isn't correctly wired up to the output node. See b/154341936 for context.

PiperOrigin-RevId: 313709228
Change-Id: I5aa398f71a46713c96aba96f4d42777edfea4fc0
This commit is contained in:
Rachel Lim 2020-05-28 20:15:47 -07:00 committed by TensorFlower Gardener
parent 4c674a64c8
commit e9ad6196a6
2 changed files with 17 additions and 0 deletions
tensorflow
core/kernels/data/experimental
python/data/experimental/kernel_tests

View File

@ -200,8 +200,15 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
// We are currently processing a group, so try to get the
// next element.
bool end_of_group;
// TODO(b/154341936): Explicitly stopping and starting this iterator
// should not be necessary, but the `::Reduce` added to the prefix
// passed to `current_group_iterator_` when it was created prevents
// the model from identifying this iterator as the output of
// `current_group_iterator_`.
RecordStop(ctx);
TF_RETURN_IF_ERROR(current_group_iterator_->GetNext(
ctx, out_tensors, &end_of_group));
RecordStart(ctx);
if (!end_of_group) {
// Produce the subelement as output.
*end_of_sequence = false;

View File

@ -331,6 +331,16 @@ class GroupByWindowTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertDatasetProduces(
dataset, expected_output=[[i] for i in range(10)])
@combinations.generate(test_base.default_test_combinations())
def testGroupByWindowWithAutotune(self):
dataset = dataset_ops.Dataset.range(1000).apply(
grouping.group_by_window(
lambda x: x // 10,
lambda key, window: dataset_ops.Dataset.from_tensors(key), 4))
dataset = dataset.map(lambda x: x + 1, num_parallel_calls=-1)
get_next = self.getNext(dataset)
self.evaluate(get_next())
if __name__ == "__main__":
test.main()