Explicitly call RecordStop when calling the group iterator for the GroupByWindow dataset, since `current_group_iterator_` isn't correctly wired up to the output node. See b/154341936 for context.
PiperOrigin-RevId: 313709228
Change-Id: I5aa398f71a46713c96aba96f4d42777edfea4fc0
This commit is contained in:
  parent 4c674a64c8
  commit e9ad6196a6

Changed directories:
  tensorflow/core/kernels/data/experimental
  tensorflow/python/data/experimental/kernel_tests
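For orientation, below is a minimal, standalone sketch (not part of the change itself) of the kind of pipeline this fix targets: a GroupByWindow dataset followed by an autotuned transformation, which is what builds the autotune performance model whose time accounting the RecordStop/RecordStart calls correct. It assumes the public TF 2.x API (tf.data.experimental.group_by_window and tf.data.experimental.AUTOTUNE) and loosely mirrors the regression test added in the second hunk rather than reproducing it exactly.

import tensorflow as tf

# Bucket integers by tens and reduce each window of up to 4 elements to its key.
dataset = tf.data.Dataset.range(1000).apply(
    tf.data.experimental.group_by_window(
        key_func=lambda x: x // 10,
        reduce_func=lambda key, window: tf.data.Dataset.from_tensors(key),
        window_size=4))

# The autotuned map after group_by_window is what creates the performance
# model that must account for time spent inside the group iterator.
dataset = dataset.map(lambda x: x + 1,
                      num_parallel_calls=tf.data.experimental.AUTOTUNE)

for element in dataset.take(5):
  print(element.numpy())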
@@ -200,8 +200,15 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
           // We are currently processing a group, so try to get the
           // next element.
           bool end_of_group;
+          // TODO(b/154341936): Explicitly stopping and starting this iterator
+          // should not be necessary, but the `::Reduce` added to the prefix
+          // passed to `current_group_iterator_` when it was created prevents
+          // the model from identifying this iterator as the output of
+          // `current_group_iterator_`.
+          RecordStop(ctx);
           TF_RETURN_IF_ERROR(current_group_iterator_->GetNext(
               ctx, out_tensors, &end_of_group));
+          RecordStart(ctx);
           if (!end_of_group) {
             // Produce the subelement as output.
             *end_of_sequence = false;
@@ -331,6 +331,16 @@ class GroupByWindowTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertDatasetProduces(
         dataset, expected_output=[[i] for i in range(10)])
 
+  @combinations.generate(test_base.default_test_combinations())
+  def testGroupByWindowWithAutotune(self):
+    dataset = dataset_ops.Dataset.range(1000).apply(
+        grouping.group_by_window(
+            lambda x: x // 10,
+            lambda key, window: dataset_ops.Dataset.from_tensors(key), 4))
+    dataset = dataset.map(lambda x: x + 1, num_parallel_calls=-1)
+    get_next = self.getNext(dataset)
+    self.evaluate(get_next())
+
 
 if __name__ == "__main__":
   test.main()