Always enable get_next_as_optional unless the dataset is finite.
PiperOrigin-RevId: 354672864
Change-Id: I3a490952e8bd075bf035a0126e62b9cf5082104e

commit 055896a275
parent cf3d55222d
@@ -2157,7 +2157,8 @@ def _enable_get_next_as_optional(strategy, dataset):
     # dataset is created in eager mode, as we need to evaluate the dataset
     # cardinality.
     with ops.device(dataset._variant_tensor.device):  # pylint: disable=protected-access
-      return dataset.cardinality().numpy() != cardinality.INFINITE
+      if dataset.cardinality().numpy() == cardinality.INFINITE:
+        return False
 
   return not _is_statically_shaped(
       dataset.element_spec) or strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access
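In effect, the eager-mode cardinality check now only short-circuits the infinite case: an infinite dataset can never end on a partial batch, so the optional path is disabled for it, while finite datasets fall through to the static-shape / multi-worker check. A minimal sketch of the resulting decision, assuming TF 2.3+ for Dataset.cardinality(); enable_get_next_as_optional, statically_shaped, and in_multi_worker here are illustrative standalone names, not the TF-internal helpers:

import tensorflow as tf

def enable_get_next_as_optional(dataset, statically_shaped, in_multi_worker):
  # An infinite dataset never produces a partial final batch, so the
  # cheaper plain get_next() path is always safe for it.
  if dataset.cardinality().numpy() == tf.data.INFINITE_CARDINALITY:
    return False
  # Otherwise a partial batch is possible whenever shapes are dynamic, and
  # multi-worker runs may see different batch counts per worker.
  return not statically_shaped or in_multi_worker

finite = tf.data.Dataset.range(10).batch(4)  # last batch holds only 2 elements
infinite = finite.repeat()
print(enable_get_next_as_optional(finite, statically_shaped=False,
                                  in_multi_worker=False))    # True
print(enable_get_next_as_optional(infinite, statically_shaped=False,
                                  in_multi_worker=False))    # False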
@@ -1120,21 +1120,21 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
       except (StopIteration, errors.OutOfRangeError):
         return sums
 
-    expected_for_sum = 200.
-    if (not drop_remainder or input_type == "input_fn"):
-      expected_for_sum = 310.
     while_sums = sum_while_loop(
         iter(dataset),
         defun(lambda state, iterator: _reduce(state, next(iterator))))
-    self.assertAllEqual(nest.flatten(while_sums), [expected_for_sum] * 3)
+    self.assertAllEqual(
+        nest.flatten(while_sums),
+        # When there's no partial batch, the sum is smaller.
+        [200. if drop_remainder else 310.] * 3)
+    for_sums = defun(sum_for_loop)(dataset)
     # For loops always call get next as optional inside tf functions, so we
     # expect 310 here when using an input function (as there are 5 batches of
     # size 4 round robined over 2 replicas.
     expected_for_sum = 200.
-    if (not drop_remainder or input_type == "input_fn"):
+    if (not drop_remainder or (
+        defun_type == "tf_function" and input_type == "input_fn")):
       expected_for_sum = 310.
-    for_sums = defun(sum_for_loop)(dataset)
     self.assertAllEqual(nest.flatten(for_sums), [expected_for_sum] * 3)
 
   @combinations.generate(
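The while-loop test drives the iterator with an explicit next(), which ends by raising StopIteration or OutOfRangeError; per the comment in the hunk, for loops inside tf.functions instead use the optional-based protocol, which signals exhaustion through a value rather than an exception. A small standalone sketch of that protocol with a plain tf.data iterator (shown eagerly for simplicity, not the distributed iterator the test exercises), assuming TF 2.x:

import tensorflow as tf

dataset = tf.data.Dataset.range(10).batch(4)  # batches: [0..3], [4..7], [8, 9]
iterator = iter(dataset)

total = 0
opt = iterator.get_next_as_optional()   # a tf.experimental.Optional
while opt.has_value():                  # no OutOfRangeError to catch
  total += int(tf.reduce_sum(opt.get_value()))
  opt = iterator.get_next_as_optional()
print(total)  # 45, the partial batch [8, 9] included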
@@ -1148,11 +1148,11 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
           ],
           input_type=["dataset", "input_fn"],
           drop_remainder=[False, True],
-          repeat=[False, True],
           tensor_type=["sparse", "ragged"],
-          enable_get_next_as_optional=[True, False]))
-  def testRaggedSparseGetNextAsOptional(self, distribution, input_type,
-                                        drop_remainder, repeat, tensor_type,
-                                        enable_get_next_as_optional):
+          enable_get_next_as_optional=[True, False]
+      ))
+  def testRaggedSparseGetNextAsOptional(
+      self, distribution, input_type, drop_remainder, tensor_type,
+      enable_get_next_as_optional):
     """Test with `RaggedTensor`s and `SparseTensor`s."""
     if not tf2.enabled():
@@ -1174,8 +1174,6 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
               ragged_tensor.to_sparse()),
       })
       dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
-      if repeat:
-        dataset = dataset.repeat()
       return dataset.batch(batch_size, drop_remainder=drop_remainder)
 
     if input_type == "dataset":
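Dropping the repeat combination (and the dataset.repeat() branch above) follows from the input_lib change: a repeated dataset has infinite cardinality, so the optional path is now unconditionally disabled for it and the axis no longer exercises distinct behavior. For reference, a quick check assuming TF 2.3+ for Dataset.cardinality():

import tensorflow as tf

ds = tf.data.Dataset.range(20).batch(4)
print(ds.cardinality().numpy())          # 5 finite batches
repeated = ds.repeat()
print(repeated.cardinality().numpy() == tf.data.INFINITE_CARDINALITY)  # True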
@@ -1185,8 +1183,8 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
       ds = distribution.distribute_datasets_from_function(dataset_fn)
     iterator = iter(ds)
 
-    self.assertEqual(iterator._enable_get_next_as_optional, (not repeat) and
-                     enable_get_next_as_optional)
+    self.assertEqual(iterator._enable_get_next_as_optional,
+                     (not drop_remainder) and enable_get_next_as_optional)
 
   @combinations.generate(
       combinations.combine(
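The updated assertion ties _enable_get_next_as_optional to drop_remainder instead of repeat: with drop_remainder=True every batch shape is fully defined, so in this single-worker eager setup the optional path stays off even for a finite dataset. The shape difference is visible on the element spec; a quick check, assuming TF 2.x:

import tensorflow as tf

static = tf.data.Dataset.range(20).batch(4, drop_remainder=True)
dynamic = tf.data.Dataset.range(20).batch(4, drop_remainder=False)
print(static.element_spec.shape)   # (4,)   fully defined, no partial batch
print(dynamic.element_spec.shape)  # (None,) batch dimension unknown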