Update static shape detection to be static batch size detection for sparse or ragged tensors. This is needed because when sparse or ragged tensors are batched by a dataset, they will typically have a shape like (batch_size, None), where only the batch dimension is statically known.
PiperOrigin-RevId: 295809971 Change-Id: I64d2fed27e0766c8857141bc28c581086155f77e
This commit is contained in:
parent
2c452720ee
commit
8f3272028b
|
@ -202,6 +202,30 @@ def _get_next_as_optional(iterator, strategy, name=None):
|
|||
return global_has_value, replicas
|
||||
|
||||
|
||||
def _is_statically_shaped(tensor_class, shape):
  """Test if an iterator output is statically shaped.

  For sparse and ragged tensors this only tests the batch dimension.

  Args:
    tensor_class: a class from an iterator.output_classes list.
    shape: a TensorShape from an iterator.output_shapes list.

  Returns:
    True if the shape is static, false otherwise.
  """
  # NOTE(review): sparse tensors are matched by class equality while ragged
  # tensors are matched by `isinstance` against a RaggedTensorSpec — the
  # asymmetry suggests output_classes holds the SparseTensor class itself but
  # a RaggedTensorSpec *instance* for ragged outputs; confirm against
  # iterator.output_classes semantics.
  if (tensor_class == sparse_tensor.SparseTensor or
      isinstance(tensor_class, ragged_tensor.RaggedTensorSpec)):
    # For sparse or ragged tensor, we should only check the first
    # dimension in order to get_next_as_optional. This is because
    # when these tensors get batched by dataset only the batch dimension
    # is set.
    if shape.rank > 0 and shape.as_list()[0] is None:
      return False
    return True
  return shape.is_fully_defined()
|
||||
|
||||
|
||||
class DistributedIterator(object):
|
||||
"""Common implementation for all input iterators."""
|
||||
|
||||
|
@ -210,9 +234,10 @@ class DistributedIterator(object):
|
|||
for iterator in iterators:
|
||||
if not isinstance(iterator, _SingleWorkerDatasetIterator):
|
||||
continue
|
||||
flattened_shapes = nest.flatten(iterator.output_shapes)
|
||||
for output_shape in flattened_shapes:
|
||||
if not output_shape.is_fully_defined():
|
||||
flattened = zip(nest.flatten(iterator.output_shapes),
|
||||
nest.flatten(iterator.output_classes))
|
||||
for output_shape, output_class in flattened:
|
||||
if not _is_statically_shaped(output_class, output_shape):
|
||||
static_shape = False
|
||||
break
|
||||
|
||||
|
|
|
@ -525,13 +525,16 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
|
|||
],
|
||||
input_type=["dataset", "input_fn"],
|
||||
drop_remainder=[False, True],
|
||||
defun=[lambda f: f, def_function.function],
|
||||
defun_type=["lambda", "tf_function"],
|
||||
))
|
||||
def testRaggedSparse(self, distribution, input_type, drop_remainder, defun):
|
||||
def testRaggedSparse(self, distribution, input_type, drop_remainder,
|
||||
defun_type):
|
||||
"""Test with `RaggedTensor`s and `SparseTensor`s."""
|
||||
if not tf2.enabled():
|
||||
self.skipTest("Only V2 is supported.")
|
||||
|
||||
defun = {"lambda": lambda f: f,
|
||||
"tf_function": def_function.function}[defun_type]
|
||||
distribution.extended.experimental_enable_get_next_as_optional = True
|
||||
global_batch_size = 8
|
||||
|
||||
|
@ -609,14 +612,72 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
|
|||
except (StopIteration, errors.OutOfRangeError):
|
||||
return sums
|
||||
|
||||
sums = sum_while_loop(
|
||||
while_sums = sum_while_loop(
|
||||
iter(dataset),
|
||||
defun(lambda state, iterator: _reduce(state, next(iterator))))
|
||||
self.assertDictEqual(sums, defun(sum_for_loop)(dataset))
|
||||
self.assertAllEqual(
|
||||
nest.flatten(sums),
|
||||
nest.flatten(while_sums),
|
||||
# When there's no partial batch, the sum is smaller.
|
||||
[200. if input_type == "dataset" and drop_remainder else 310.] * 3)
|
||||
[200. if drop_remainder else 310.] * 3)
|
||||
for_sums = defun(sum_for_loop)(dataset)
|
||||
# For loops always call get next as optional inside tf functions, so we
|
||||
# expect 310 here when using an input function (as there are 5 batches of
|
||||
# size 4 round robined over 2 replicas.
|
||||
expected_for_sum = 200.
|
||||
if (not drop_remainder or (
|
||||
defun_type == "tf_function" and input_type == "input_fn")):
|
||||
expected_for_sum = 310.
|
||||
self.assertAllEqual(nest.flatten(for_sums), [expected_for_sum] * 3)
|
||||
|
||||
  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu
          ],
          input_type=["dataset", "input_fn"],
          drop_remainder=[False, True],
          tensor_type=["sparse", "ragged"],
          enable_get_next_as_optional=[True, False]
      ))
  def testRaggedSparseGetNextAsOptional(
      self, distribution, input_type, drop_remainder, tensor_type,
      enable_get_next_as_optional):
    """Test with `RaggedTensor`s and `SparseTensor`s.

    Checks that the distributed iterator only enables get_next_as_optional
    when the strategy requests it AND the batch dimension is not statically
    known (i.e. drop_remainder is False).
    """
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    global_batch_size = 8

    def dataset_fn(ctx=None):
      # Builds a dataset of a single ragged (or sparse) feature, batched so
      # that only the batch dimension of the output shape is set.
      ctx = ctx or distribute_lib.InputContext()
      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
      # Use 20 which isn't divisible by 8 to test partial batch behavior.
      row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
      ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
          np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
      dataset = dataset_ops.DatasetV2.from_tensor_slices({
          tensor_type: (ragged_tensor if tensor_type == "ragged" else
                        ragged_tensor.to_sparse()),
      })
      dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
      return dataset.batch(batch_size, drop_remainder=drop_remainder)

    if input_type == "dataset":
      ds = distribution.experimental_distribute_dataset(
          dataset_fn(distribute_lib.InputContext()))
    else:
      ds = distribution.experimental_distribute_datasets_from_function(
          dataset_fn)
    iterator = iter(ds)

    # With drop_remainder=True the batch dimension is static, so
    # get_next_as_optional should stay disabled regardless of the flag.
    self.assertEqual(iterator._enable_get_next_as_optional,
                     (not drop_remainder) and enable_get_next_as_optional)
|
||||
|
||||
|
||||
class DistributedIteratorMultiWorkerTest(
|
||||
|
|
Loading…
Reference in New Issue