Always enable get_next_as_optional unless the dataset is finite.

PiperOrigin-RevId: 341945136
Change-Id: I79fdec366be2119b6a28063f193e6cecb7a5f9e2
Author: Ruoxin Sang (2020-11-11 17:34:10 -08:00), committed by TensorFlower Gardener
parent 3e03af4c7f
commit 5a0ed634af
2 changed files with 16 additions and 17 deletions
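
Background for this change: `get_next_as_optional` lets an iterator signal end-of-input through a `tf.experimental.Optional` instead of raising `OutOfRangeError`, which is what makes a final partial batch representable inside compiled control flow. A minimal sketch of the mechanism on a plain `tf.data` iterator (the dataset and sizes here are illustrative, not from this commit):

    import tensorflow as tf

    dataset = tf.data.Dataset.range(5).batch(2)  # final batch [4] is partial
    iterator = iter(dataset)
    while True:
      optional = iterator.get_next_as_optional()  # tf.experimental.Optional
      if not optional.has_value():
        break  # end of input; no OutOfRangeError is raised
      print(optional.get_value())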


@@ -2147,7 +2147,8 @@ def _enable_get_next_as_optional(strategy, dataset):
     # dataset is created in eager mode, as we need to evaluate the dataset
     # cardinality.
     with ops.device(dataset._variant_tensor.device):  # pylint: disable=protected-access
-      return dataset.cardinality().numpy() != cardinality.INFINITE
+      if dataset.cardinality().numpy() == cardinality.INFINITE:
+        return False
   return not _is_statically_shaped(
       dataset.element_spec) or strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access
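
The new early return above only short-circuits for infinite datasets; finite datasets now fall through to the static-shape check instead of unconditionally enabling the feature. A rough standalone equivalent of that decision, written against the public tf.data API (the function name and the `multi_worker` flag are illustrative):

    import tensorflow as tf

    def needs_partial_batch_handling(dataset, multi_worker=False):
      # Infinite datasets never end in a partial batch.
      if dataset.cardinality() == tf.data.experimental.INFINITE_CARDINALITY:
        return False
      # Finite datasets: handle partial batches only when some element
      # shape is not fully static (or when running multi-worker).
      static = all(spec.shape.is_fully_defined()
                   for spec in tf.nest.flatten(dataset.element_spec))
      return (not static) or multi_worker

    ds = tf.data.Dataset.range(8)
    print(needs_partial_batch_handling(ds.batch(4, drop_remainder=True)))  # False
    print(needs_partial_batch_handling(ds.batch(4)))                       # True
    print(needs_partial_batch_handling(ds.batch(4).repeat()))              # False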


@@ -1118,21 +1118,21 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
       except (StopIteration, errors.OutOfRangeError):
         return sums
 
+    expected_for_sum = 200.
+    if (not drop_remainder or input_type == "input_fn"):
+      expected_for_sum = 310.
     while_sums = sum_while_loop(
         iter(dataset),
         defun(lambda state, iterator: _reduce(state, next(iterator))))
-    self.assertAllEqual(
-        nest.flatten(while_sums),
-        # When there's no partial batch, the sum is smaller.
-        [200. if drop_remainder else 310.] * 3)
-    for_sums = defun(sum_for_loop)(dataset)
+    self.assertAllEqual(nest.flatten(while_sums), [expected_for_sum] * 3)
     # For loops always call get next as optional inside tf functions, so we
     # expect 310 here when using an input function (as there are 5 batches of
     # size 4 round robined over 2 replicas.
     expected_for_sum = 200.
-    if (not drop_remainder or input_type == "input_fn"):
+    if (not drop_remainder or (
+        defun_type == "tf_function" and input_type == "input_fn")):
       expected_for_sum = 310.
+    for_sums = defun(sum_for_loop)(dataset)
     self.assertAllEqual(nest.flatten(for_sums), [expected_for_sum] * 3)
 
   @combinations.generate(
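
The comment in this hunk carries the behavioral point: inside a `tf.function`, a Python `for` loop over a distributed dataset is traced into optional-based control flow, so the loop still sees the final partial batch. A hedged sketch of that pattern (the strategy, sizes, and function name are illustrative; the replica count depends on available devices):

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()
    dataset = tf.data.Dataset.range(18).map(
        lambda x: tf.cast(x, tf.float32)).batch(4)  # last batch has 2 elements
    dist_dataset = strategy.experimental_distribute_dataset(dataset)

    @tf.function
    def sum_for_loop(ds):
      total = tf.constant(0.)
      for batch in ds:  # traced into get-next-as-optional control flow
        per_replica = strategy.run(tf.reduce_sum, args=(batch,))
        total += strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica,
                                 axis=None)
      return total

    print(sum_for_loop(dist_dataset))  # 153.0 == sum(range(18))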
@@ -1146,12 +1146,12 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
           ],
           input_type=["dataset", "input_fn"],
           drop_remainder=[False, True],
-          repeat=[False, True],
           tensor_type=["sparse", "ragged"],
-          enable_get_next_as_optional=[True, False]))
-  def testRaggedSparseGetNextAsOptional(self, distribution, input_type,
-                                        drop_remainder, repeat, tensor_type,
-                                        enable_get_next_as_optional):
+          enable_get_next_as_optional=[True, False]
+      ))
+  def testRaggedSparseGetNextAsOptional(
+      self, distribution, input_type, drop_remainder, tensor_type,
+      enable_get_next_as_optional):
     """Test with `RaggedTensor`s and `SparseTensor`s."""
     if not tf2.enabled():
       self.skipTest("Only V2 is supported.")
@@ -1172,8 +1172,6 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
               ragged_tensor.to_sparse()),
       })
       dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
-      if repeat:
-        dataset = dataset.repeat()
       return dataset.batch(batch_size, drop_remainder=drop_remainder)
 
     if input_type == "dataset":
@@ -1183,8 +1181,8 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
       ds = distribution.distribute_datasets_from_function(dataset_fn)
 
     iterator = iter(ds)
-    self.assertEqual(iterator._enable_get_next_as_optional, (not repeat) and
-                     enable_get_next_as_optional)
+    self.assertEqual(iterator._enable_get_next_as_optional,
+                     (not drop_remainder) and enable_get_next_as_optional)
 
   @combinations.generate(
       combinations.combine(
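
The updated assertion encodes the new rule directly: with the infinite (`repeat`) case removed from the test, whether the iterator enables `get_next_as_optional` is driven by shape staticness, and `drop_remainder` is precisely what makes batch shapes static. A quick illustration (dataset and batch size are illustrative):

    import tensorflow as tf

    ds = tf.data.Dataset.range(10)
    print(ds.batch(4, drop_remainder=True).element_spec.shape)   # (4,): static
    print(ds.batch(4, drop_remainder=False).element_spec.shape)  # (None,): last batch may be partial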