Update static shape detection to be static batch size detection for sparse or ragged tensors. This is needed because, when these tensors are batched by a dataset, they typically have a shape like (batch_size, None).

PiperOrigin-RevId: 295809971
Change-Id: I64d2fed27e0766c8857141bc28c581086155f77e
Bruce Fontaine 2020-02-18 13:35:16 -08:00 committed by TensorFlower Gardener
parent 2c452720ee
commit 8f3272028b
2 changed files with 95 additions and 9 deletions
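
For context (an editor's sketch, not part of the commit): batching a ragged tensor with tf.data shows why only the batch dimension can be expected to be static, and why drop_remainder pins it. Assumes the TF 2.x public API.

import tensorflow as tf

rt = tf.ragged.constant([[1.0], [2.0, 3.0], [], [4.0, 5.0, 6.0]])
ds = tf.data.Dataset.from_tensor_slices(rt)

# Row lengths stay unknown after batching; without drop_remainder even
# the batch dimension is unknown:
print(ds.batch(2).element_spec.shape)                       # (None, None)
# With drop_remainder=True the batch dimension becomes static:
print(ds.batch(2, drop_remainder=True).element_spec.shape)  # (2, None)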

@@ -202,6 +202,30 @@ def _get_next_as_optional(iterator, strategy, name=None):
 
   return global_has_value, replicas
 
 
+def _is_statically_shaped(tensor_class, shape):
+  """Test if an iterator output is statically shaped.
+
+  For sparse and ragged tensors this only tests the batch dimension.
+
+  Args:
+    tensor_class: a class from an iterator.output_classes list.
+    shape: a TensorShape from an iterator.output_shapes list.
+
+  Returns:
+    True if the shape is static, false otherwise.
+  """
+  if (tensor_class == sparse_tensor.SparseTensor or
+      isinstance(tensor_class, ragged_tensor.RaggedTensorSpec)):
+    # For sparse or ragged tensors we only need to check the first
+    # (batch) dimension in order to decide whether get_next_as_optional
+    # is required, because when these tensors are batched by a dataset
+    # only the batch dimension is statically set.
+    if shape.rank > 0 and shape.as_list()[0] is None:
+      return False
+    return True
+  return shape.is_fully_defined()
+
+
 class DistributedIterator(object):
   """Common implementation for all input iterators."""
@@ -210,9 +234,10 @@ class DistributedIterator(object):
     for iterator in iterators:
       if not isinstance(iterator, _SingleWorkerDatasetIterator):
         continue
-      flattened_shapes = nest.flatten(iterator.output_shapes)
-      for output_shape in flattened_shapes:
-        if not output_shape.is_fully_defined():
+      flattened = zip(nest.flatten(iterator.output_shapes),
+                      nest.flatten(iterator.output_classes))
+      for output_shape, output_class in flattened:
+        if not _is_statically_shaped(output_class, output_shape):
           static_shape = False
           break

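Restated outside the diff, the new predicate boils down to the following standalone sketch (written against public TypeSpec classes rather than iterator.output_classes; the function name is illustrative):

import tensorflow as tf

def is_statically_shaped(spec):
  """Sketch: does this iterator output have a static (batch) shape?"""
  if isinstance(spec, (tf.SparseTensorSpec, tf.RaggedTensorSpec)):
    # Only the batch (first) dimension matters for sparse/ragged values;
    # dataset batching leaves their remaining dimensions unknown anyway.
    if spec.shape.rank > 0 and spec.shape.as_list()[0] is None:
      return False
    return True
  # Dense tensors still need every dimension to be known.
  return spec.shape.is_fully_defined()

For example, a RaggedTensorSpec with shape (8, None) counts as static here, while a dense TensorSpec with the same shape does not.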
@@ -525,13 +525,16 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
         ],
         input_type=["dataset", "input_fn"],
         drop_remainder=[False, True],
-        defun=[lambda f: f, def_function.function],
+        defun_type=["lambda", "tf_function"],
     ))
-  def testRaggedSparse(self, distribution, input_type, drop_remainder, defun):
+  def testRaggedSparse(self, distribution, input_type, drop_remainder,
+                       defun_type):
     """Test with `RaggedTensor`s and `SparseTensor`s."""
     if not tf2.enabled():
       self.skipTest("Only V2 is supported.")
 
+    defun = {"lambda": lambda f: f,
+             "tf_function": def_function.function}[defun_type]
     distribution.extended.experimental_enable_get_next_as_optional = True
     global_batch_size = 8
@@ -609,14 +612,72 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
       except (StopIteration, errors.OutOfRangeError):
         return sums
 
-    sums = sum_while_loop(
+    while_sums = sum_while_loop(
         iter(dataset),
         defun(lambda state, iterator: _reduce(state, next(iterator))))
-    self.assertDictEqual(sums, defun(sum_for_loop)(dataset))
     self.assertAllEqual(
-        nest.flatten(sums),
+        nest.flatten(while_sums),
         # When there's no partial batch, the sum is smaller.
-        [200. if input_type == "dataset" and drop_remainder else 310.] * 3)
+        [200. if drop_remainder else 310.] * 3)
+    for_sums = defun(sum_for_loop)(dataset)
+    # For loops always call get_next_as_optional inside tf.functions, so we
+    # expect 310 here when using an input function (there are 5 batches of
+    # size 4 round-robined over 2 replicas).
+    expected_for_sum = 200.
+    if (not drop_remainder or (
+        defun_type == "tf_function" and input_type == "input_fn")):
+      expected_for_sum = 310.
+    self.assertAllEqual(nest.flatten(for_sums), [expected_for_sum] * 3)
+  @combinations.generate(
+      combinations.combine(
+          mode=["eager"],
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
+              strategy_combinations.one_device_strategy,
+              strategy_combinations.mirrored_strategy_with_one_cpu
+          ],
+          input_type=["dataset", "input_fn"],
+          drop_remainder=[False, True],
+          tensor_type=["sparse", "ragged"],
+          enable_get_next_as_optional=[True, False]
+      ))
+  def testRaggedSparseGetNextAsOptional(
+      self, distribution, input_type, drop_remainder, tensor_type,
+      enable_get_next_as_optional):
+    """Test with `RaggedTensor`s and `SparseTensor`s."""
+    if not tf2.enabled():
+      self.skipTest("Only V2 is supported.")
+
+    distribution.extended.experimental_enable_get_next_as_optional = (
+        enable_get_next_as_optional)
+    global_batch_size = 8
+
+    def dataset_fn(ctx=None):
+      ctx = ctx or distribute_lib.InputContext()
+      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
+      # Use 20 which isn't divisible by 8 to test partial batch behavior.
+      row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
+      ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
+          np.repeat(np.arange(20, dtype=np.float32), row_lengths),
+          row_lengths)
+      dataset = dataset_ops.DatasetV2.from_tensor_slices({
+          tensor_type: (ragged_tensor if tensor_type == "ragged" else
+                        ragged_tensor.to_sparse()),
+      })
+      dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
+      return dataset.batch(batch_size, drop_remainder=drop_remainder)
+
+    if input_type == "dataset":
+      ds = distribution.experimental_distribute_dataset(
+          dataset_fn(distribute_lib.InputContext()))
+    else:
+      ds = distribution.experimental_distribute_datasets_from_function(
+          dataset_fn)
+    iterator = iter(ds)
+
+    self.assertEqual(iterator._enable_get_next_as_optional,
+                     (not drop_remainder) and enable_get_next_as_optional)
+
+
 class DistributedIteratorMultiWorkerTest(
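
The final assertion in testRaggedSparseGetNextAsOptional encodes when the get_next_as_optional control flow is actually needed. A plain-Python restatement (hypothetical helper, not a TF API):

def needs_get_next_as_optional(enable_flag, drop_remainder):
  # drop_remainder=True makes the batch dimension static, so a partial
  # batch is impossible and the cheaper fast path is safe even when the
  # flag is on.
  return enable_flag and not drop_remainder

assert needs_get_next_as_optional(True, False)
assert not needs_get_next_as_optional(True, True)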