Explicitly call RecordStop when calling group iterator for GroupByWindow dataset, since the current_group_iterator isn't correctly wired up to the output node. See b/154341936 for context.
PiperOrigin-RevId: 313709228 Change-Id: I5aa398f71a46713c96aba96f4d42777edfea4fc0
This commit is contained in:
parent
4c674a64c8
commit
e9ad6196a6
@ -200,8 +200,15 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
// We are currently processing a group, so try to get the
|
// We are currently processing a group, so try to get the
|
||||||
// next element.
|
// next element.
|
||||||
bool end_of_group;
|
bool end_of_group;
|
||||||
|
// TODO(b/154341936): Explicitly stopping and starting this iterator
|
||||||
|
// should not be necessary, but the `::Reduce` added to the prefix
|
||||||
|
// passed to `current_group_iterator_` when it was created prevents
|
||||||
|
// the model from identifying this iterator as the output of
|
||||||
|
// `current_group_iterator_`.
|
||||||
|
RecordStop(ctx);
|
||||||
TF_RETURN_IF_ERROR(current_group_iterator_->GetNext(
|
TF_RETURN_IF_ERROR(current_group_iterator_->GetNext(
|
||||||
ctx, out_tensors, &end_of_group));
|
ctx, out_tensors, &end_of_group));
|
||||||
|
RecordStart(ctx);
|
||||||
if (!end_of_group) {
|
if (!end_of_group) {
|
||||||
// Produce the subelement as output.
|
// Produce the subelement as output.
|
||||||
*end_of_sequence = false;
|
*end_of_sequence = false;
|
||||||
|
@ -331,6 +331,16 @@ class GroupByWindowTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
self.assertDatasetProduces(
|
self.assertDatasetProduces(
|
||||||
dataset, expected_output=[[i] for i in range(10)])
|
dataset, expected_output=[[i] for i in range(10)])
|
||||||
|
|
||||||
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
|
def testGroupByWindowWithAutotune(self):
|
||||||
|
dataset = dataset_ops.Dataset.range(1000).apply(
|
||||||
|
grouping.group_by_window(
|
||||||
|
lambda x: x // 10,
|
||||||
|
lambda key, window: dataset_ops.Dataset.from_tensors(key), 4))
|
||||||
|
dataset = dataset.map(lambda x: x + 1, num_parallel_calls=-1)
|
||||||
|
get_next = self.getNext(dataset)
|
||||||
|
self.evaluate(get_next())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user