Always enable get_next_as_optional unless the dataset is finite.
PiperOrigin-RevId: 341945136 Change-Id: I79fdec366be2119b6a28063f193e6cecb7a5f9e2
This commit is contained in:
parent
3e03af4c7f
commit
5a0ed634af
@ -2147,7 +2147,8 @@ def _enable_get_next_as_optional(strategy, dataset):
|
||||
# dataset is created in eager mode, as we need to evaluate the dataset
|
||||
# cardinality.
|
||||
with ops.device(dataset._variant_tensor.device): # pylint: disable=protected-access
|
||||
return dataset.cardinality().numpy() != cardinality.INFINITE
|
||||
if dataset.cardinality().numpy() == cardinality.INFINITE:
|
||||
return False
|
||||
|
||||
return not _is_statically_shaped(
|
||||
dataset.element_spec) or strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access
|
||||
|
@ -1118,21 +1118,21 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
|
||||
except (StopIteration, errors.OutOfRangeError):
|
||||
return sums
|
||||
|
||||
expected_for_sum = 200.
|
||||
if (not drop_remainder or input_type == "input_fn"):
|
||||
expected_for_sum = 310.
|
||||
while_sums = sum_while_loop(
|
||||
iter(dataset),
|
||||
defun(lambda state, iterator: _reduce(state, next(iterator))))
|
||||
self.assertAllEqual(nest.flatten(while_sums), [expected_for_sum] * 3)
|
||||
|
||||
self.assertAllEqual(
|
||||
nest.flatten(while_sums),
|
||||
# When there's no partial batch, the sum is smaller.
|
||||
[200. if drop_remainder else 310.] * 3)
|
||||
for_sums = defun(sum_for_loop)(dataset)
|
||||
# For loops always call get next as optional inside tf functions, so we
|
||||
# expect 310 here when using an input function (as there are 5 batches of
|
||||
# size 4 round robined over 2 replicas.
|
||||
expected_for_sum = 200.
|
||||
if (not drop_remainder or input_type == "input_fn"):
|
||||
if (not drop_remainder or (
|
||||
defun_type == "tf_function" and input_type == "input_fn")):
|
||||
expected_for_sum = 310.
|
||||
for_sums = defun(sum_for_loop)(dataset)
|
||||
self.assertAllEqual(nest.flatten(for_sums), [expected_for_sum] * 3)
|
||||
|
||||
@combinations.generate(
|
||||
@ -1146,12 +1146,12 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
|
||||
],
|
||||
input_type=["dataset", "input_fn"],
|
||||
drop_remainder=[False, True],
|
||||
repeat=[False, True],
|
||||
tensor_type=["sparse", "ragged"],
|
||||
enable_get_next_as_optional=[True, False]))
|
||||
def testRaggedSparseGetNextAsOptional(self, distribution, input_type,
|
||||
drop_remainder, repeat, tensor_type,
|
||||
enable_get_next_as_optional):
|
||||
enable_get_next_as_optional=[True, False]
|
||||
))
|
||||
def testRaggedSparseGetNextAsOptional(
|
||||
self, distribution, input_type, drop_remainder, tensor_type,
|
||||
enable_get_next_as_optional):
|
||||
"""Test with `RaggedTensor`s and `SparseTensor`s."""
|
||||
if not tf2.enabled():
|
||||
self.skipTest("Only V2 is supported.")
|
||||
@ -1172,8 +1172,6 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
|
||||
ragged_tensor.to_sparse()),
|
||||
})
|
||||
dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
|
||||
if repeat:
|
||||
dataset = dataset.repeat()
|
||||
return dataset.batch(batch_size, drop_remainder=drop_remainder)
|
||||
|
||||
if input_type == "dataset":
|
||||
@ -1183,8 +1181,8 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
|
||||
ds = distribution.distribute_datasets_from_function(dataset_fn)
|
||||
iterator = iter(ds)
|
||||
|
||||
self.assertEqual(iterator._enable_get_next_as_optional, (not repeat) and
|
||||
enable_get_next_as_optional)
|
||||
self.assertEqual(iterator._enable_get_next_as_optional,
|
||||
(not drop_remainder) and enable_get_next_as_optional)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
|
Loading…
Reference in New Issue
Block a user