Update static shape detection to be static batch size detection for sparse or ragged tensors. This is needed because, when these tensors are batched by a dataset, they typically have a shape like (batch_size, None).
PiperOrigin-RevId: 295809971
Change-Id: I64d2fed27e0766c8857141bc28c581086155f77e
commit 8f3272028b
parent 2c452720ee
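For context, a small standalone sketch (not part of this change) of the shape a batched ragged element reports; the exact spec repr varies by TF version:

import numpy as np
import tensorflow as tf

row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
rt = tf.RaggedTensor.from_row_lengths(
    np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)

batched = tf.data.Dataset.from_tensor_slices(rt).batch(4, drop_remainder=True)
print(batched.element_spec)
# Prints something like RaggedTensorSpec(TensorShape([4, None]), tf.float32, ...):
# the batch dimension is static, but the ragged dimension never is, so
# TensorShape.is_fully_defined() is always False for such elements.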
@@ -202,6 +202,30 @@ def _get_next_as_optional(iterator, strategy, name=None):
   return global_has_value, replicas


+def _is_statically_shaped(tensor_class, shape):
+  """Test if an iterator output is statically shaped.
+
+  For sparse and ragged tensors this only tests the batch dimension.
+
+  Args:
+    tensor_class: a class from an iterator.output_classes list.
+    shape: a TensorShape from an iterator.output_shapes list.
+
+  Returns:
+    True if the shape is static, False otherwise.
+  """
+  if (tensor_class == sparse_tensor.SparseTensor or
+      isinstance(tensor_class, ragged_tensor.RaggedTensorSpec)):
+    # For sparse or ragged tensors we only need to check the first (batch)
+    # dimension for get_next_as_optional, because when these tensors are
+    # batched by a dataset only the batch dimension is statically set.
+    if shape.rank > 0 and shape.as_list()[0] is None:
+      return False
+    return True
+  return shape.is_fully_defined()
+
+
 class DistributedIterator(object):
   """Common implementation for all input iterators."""
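The same check, expressed against public type specs rather than the iterator's output_classes/output_shapes lists, looks roughly like this (hypothetical helper name, a sketch only):

import tensorflow as tf

def has_static_batch(spec):
  """Hypothetical mirror of _is_statically_shaped for tf.TypeSpec objects."""
  shape = spec.shape
  if isinstance(spec, (tf.SparseTensorSpec, tf.RaggedTensorSpec)):
    # Batched sparse/ragged elements typically look like (batch_size, None),
    # so only the batch dimension is meaningful here.
    if shape.rank > 0 and shape.as_list()[0] is None:
      return False
    return True
  return shape.is_fully_defined()

assert has_static_batch(tf.RaggedTensorSpec([8, None], tf.float32))
assert not has_static_batch(tf.RaggedTensorSpec([None, None], tf.float32))
assert not has_static_batch(tf.TensorSpec([8, None], tf.float32))  # dense still needs a full shape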
@@ -210,9 +234,10 @@ class DistributedIterator(object):
     for iterator in iterators:
       if not isinstance(iterator, _SingleWorkerDatasetIterator):
         continue
-      flattened_shapes = nest.flatten(iterator.output_shapes)
-      for output_shape in flattened_shapes:
-        if not output_shape.is_fully_defined():
+      flattened = zip(nest.flatten(iterator.output_shapes),
+                      nest.flatten(iterator.output_classes))
+      for output_shape, output_class in flattened:
+        if not _is_statically_shaped(output_class, output_shape):
           static_shape = False
           break

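To see what the zipped output_classes/output_shapes pairs contain, a rough illustration using the v1 compat accessors (printed values are approximate):

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices({
    "dense": tf.zeros([20, 3]),
    "ragged": tf.ragged.constant([[1.0], [1.0, 2.0]] * 10),
}).batch(4, drop_remainder=True)

shapes = tf.nest.flatten(tf.compat.v1.data.get_output_shapes(dataset))
classes = tf.nest.flatten(tf.compat.v1.data.get_output_classes(dataset))
for output_shape, output_class in zip(shapes, classes):
  print(output_class, output_shape)
# Roughly: the dense component reports (4, 3) with a Tensor class, while the
# ragged component reports (4, None) with a RaggedTensorSpec, which is why the
# class/spec is needed to interpret a shape that is not fully defined.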
@@ -525,13 +525,16 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
           ],
           input_type=["dataset", "input_fn"],
           drop_remainder=[False, True],
-          defun=[lambda f: f, def_function.function],
+          defun_type=["lambda", "tf_function"],
       ))
-  def testRaggedSparse(self, distribution, input_type, drop_remainder, defun):
+  def testRaggedSparse(self, distribution, input_type, drop_remainder,
+                       defun_type):
     """Test with `RaggedTensor`s and `SparseTensor`s."""
     if not tf2.enabled():
       self.skipTest("Only V2 is supported.")

+    defun = {"lambda": lambda f: f,
+             "tf_function": def_function.function}[defun_type]
     distribution.extended.experimental_enable_get_next_as_optional = True
     global_batch_size = 8

@@ -609,14 +612,72 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
       except (StopIteration, errors.OutOfRangeError):
         return sums

-    sums = sum_while_loop(
+    while_sums = sum_while_loop(
         iter(dataset),
         defun(lambda state, iterator: _reduce(state, next(iterator))))
-    self.assertDictEqual(sums, defun(sum_for_loop)(dataset))
     self.assertAllEqual(
-        nest.flatten(sums),
+        nest.flatten(while_sums),
         # When there's no partial batch, the sum is smaller.
-        [200. if input_type == "dataset" and drop_remainder else 310.] * 3)
+        [200. if drop_remainder else 310.] * 3)
+    for_sums = defun(sum_for_loop)(dataset)
+    # For loops always call get_next_as_optional inside tf.functions, so we
+    # expect 310 here when using an input function (as there are 5 batches of
+    # size 4 round-robined over 2 replicas).
+    expected_for_sum = 200.
+    if (not drop_remainder or (
+        defun_type == "tf_function" and input_type == "input_fn")):
+      expected_for_sum = 310.
+    self.assertAllEqual(nest.flatten(for_sums), [expected_for_sum] * 3)
+
+  @combinations.generate(
+      combinations.combine(
+          mode=["eager"],
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
+              strategy_combinations.one_device_strategy,
+              strategy_combinations.mirrored_strategy_with_one_cpu
+          ],
+          input_type=["dataset", "input_fn"],
+          drop_remainder=[False, True],
+          tensor_type=["sparse", "ragged"],
+          enable_get_next_as_optional=[True, False]
+      ))
+  def testRaggedSparseGetNextAsOptional(
+      self, distribution, input_type, drop_remainder, tensor_type,
+      enable_get_next_as_optional):
+    """Test with `RaggedTensor`s and `SparseTensor`s."""
+    if not tf2.enabled():
+      self.skipTest("Only V2 is supported.")
+
+    distribution.extended.experimental_enable_get_next_as_optional = (
+        enable_get_next_as_optional)
+    global_batch_size = 8
+
+    def dataset_fn(ctx=None):
+      ctx = ctx or distribute_lib.InputContext()
+      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
+      # Use 20 which isn't divisible by 8 to test partial batch behavior.
+      row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
+      ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
+          np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
+      dataset = dataset_ops.DatasetV2.from_tensor_slices({
+          tensor_type: (ragged_tensor if tensor_type == "ragged" else
+                        ragged_tensor.to_sparse()),
+      })
+      dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
+      return dataset.batch(batch_size, drop_remainder=drop_remainder)
+
+    if input_type == "dataset":
+      ds = distribution.experimental_distribute_dataset(
+          dataset_fn(distribute_lib.InputContext()))
+    else:
+      ds = distribution.experimental_distribute_datasets_from_function(
+          dataset_fn)
+    iterator = iter(ds)
+
+    self.assertEqual(iterator._enable_get_next_as_optional,
+                     (not drop_remainder) and enable_get_next_as_optional)
+

 class DistributedIteratorMultiWorkerTest(
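For reference, a rough end-to-end sketch of the user-facing scenario this change targets (strategy choice and batch size are illustrative, not taken from the tests):

import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
rt = tf.RaggedTensor.from_row_lengths(
    np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)

# With drop_remainder=True the batch dimension is static even though the
# ragged dimension is not.
dataset = tf.data.Dataset.from_tensor_slices(rt).batch(8, drop_remainder=True)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

per_replica = next(iter(dist_dataset))
print(strategy.experimental_local_results(per_replica)[0].shape)
# Expect something like (8, None) on a single replica: a static batch size, so
# the distributed iterator can skip the get_next_as_optional fallback.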