From 8f3272028b674ad08c80ae1e0f31d7ce56f8295e Mon Sep 17 00:00:00 2001
From: Bruce Fontaine
Date: Tue, 18 Feb 2020 13:35:16 -0800
Subject: [PATCH] Update static shape detection to be static batch size
 detection for sparse or ragged tensors.

This is needed as when they are batched by a dataset they will typically
have a shape like (batch_size, None).

PiperOrigin-RevId: 295809971
Change-Id: I64d2fed27e0766c8857141bc28c581086155f77e
---
 tensorflow/python/distribute/input_lib.py     | 31 +++++++-
 .../python/distribute/input_lib_test.py       | 73 +++++++++++++++++--
 2 files changed, 95 insertions(+), 9 deletions(-)

diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py
index aa02323c75e..163f775cc93 100644
--- a/tensorflow/python/distribute/input_lib.py
+++ b/tensorflow/python/distribute/input_lib.py
@@ -202,6 +202,30 @@ def _get_next_as_optional(iterator, strategy, name=None):
   return global_has_value, replicas
 
 
+def _is_statically_shaped(tensor_class, shape):
+  """Test if an iterator output is statically shaped.
+
+  For sparse and ragged tensors this only tests the batch dimension.
+
+  Args:
+    tensor_class: a class from an iterator.output_classes list.
+    shape: a TensorShape from an iterator.output_shapes list.
+
+  Returns:
+    True if the shape is static, False otherwise.
+  """
+  if (tensor_class == sparse_tensor.SparseTensor or
+      isinstance(tensor_class, ragged_tensor.RaggedTensorSpec)):
+    # For sparse or ragged tensors, we should only check the first
+    # dimension in order to use get_next_as_optional. This is because
+    # when these tensors get batched by a dataset only the batch dimension
+    # is set.
+ if shape.rank > 0 and shape.as_list()[0] is None: + return False + return True + return shape.is_fully_defined() + + class DistributedIterator(object): """Common implementation for all input iterators.""" @@ -210,9 +234,10 @@ class DistributedIterator(object): for iterator in iterators: if not isinstance(iterator, _SingleWorkerDatasetIterator): continue - flattened_shapes = nest.flatten(iterator.output_shapes) - for output_shape in flattened_shapes: - if not output_shape.is_fully_defined(): + flattened = zip(nest.flatten(iterator.output_shapes), + nest.flatten(iterator.output_classes)) + for output_shape, output_class in flattened: + if not _is_statically_shaped(output_class, output_shape): static_shape = False break diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py index 3c59d0f5e43..80d5db38403 100644 --- a/tensorflow/python/distribute/input_lib_test.py +++ b/tensorflow/python/distribute/input_lib_test.py @@ -525,13 +525,16 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase, ], input_type=["dataset", "input_fn"], drop_remainder=[False, True], - defun=[lambda f: f, def_function.function], + defun_type=["lambda", "tf_function"], )) - def testRaggedSparse(self, distribution, input_type, drop_remainder, defun): + def testRaggedSparse(self, distribution, input_type, drop_remainder, + defun_type): """Test with `RaggedTensor`s and `SparseTensor`s.""" if not tf2.enabled(): self.skipTest("Only V2 is supported.") + defun = {"lambda": lambda f: f, + "tf_function": def_function.function}[defun_type] distribution.extended.experimental_enable_get_next_as_optional = True global_batch_size = 8 @@ -609,14 +612,72 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase, except (StopIteration, errors.OutOfRangeError): return sums - sums = sum_while_loop( + while_sums = sum_while_loop( iter(dataset), defun(lambda state, iterator: _reduce(state, next(iterator)))) - 
self.assertDictEqual(sums, defun(sum_for_loop)(dataset))
     self.assertAllEqual(
-        nest.flatten(sums),
+        nest.flatten(while_sums),
         # When there's no partial batch, the sum is smaller.
-        [200. if input_type == "dataset" and drop_remainder else 310.] * 3)
+        [200. if drop_remainder else 310.] * 3)
+    for_sums = defun(sum_for_loop)(dataset)
+    # For loops always call get next as optional inside tf functions, so we
+    # expect 310 here when using an input function (as there are 5 batches of
+    # size 4 round-robined over 2 replicas).
+    expected_for_sum = 200.
+    if (not drop_remainder or (
+        defun_type == "tf_function" and input_type == "input_fn")):
+      expected_for_sum = 310.
+    self.assertAllEqual(nest.flatten(for_sums), [expected_for_sum] * 3)
+
+  @combinations.generate(
+      combinations.combine(
+          mode=["eager"],
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
+              strategy_combinations.one_device_strategy,
+              strategy_combinations.mirrored_strategy_with_one_cpu
+          ],
+          input_type=["dataset", "input_fn"],
+          drop_remainder=[False, True],
+          tensor_type=["sparse", "ragged"],
+          enable_get_next_as_optional=[True, False]
+      ))
+  def testRaggedSparseGetNextAsOptional(
+      self, distribution, input_type, drop_remainder, tensor_type,
+      enable_get_next_as_optional):
+    """Test with `RaggedTensor`s and `SparseTensor`s."""
+    if not tf2.enabled():
+      self.skipTest("Only V2 is supported.")
+
+    distribution.extended.experimental_enable_get_next_as_optional = (
+        enable_get_next_as_optional)
+    global_batch_size = 8
+
+    def dataset_fn(ctx=None):
+      ctx = ctx or distribute_lib.InputContext()
+      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
+      # Use 20 which isn't divisible by 8 to test partial batch behavior.
+ row_lengths = np.mod(np.arange(20), 4).astype(np.int64) + ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths( + np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths) + dataset = dataset_ops.DatasetV2.from_tensor_slices({ + tensor_type: (ragged_tensor if tensor_type == "ragged" else + ragged_tensor.to_sparse()), + }) + dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id) + return dataset.batch(batch_size, drop_remainder=drop_remainder) + + if input_type == "dataset": + ds = distribution.experimental_distribute_dataset( + dataset_fn(distribute_lib.InputContext())) + else: + ds = distribution.experimental_distribute_datasets_from_function( + dataset_fn) + iterator = iter(ds) + + self.assertEqual(iterator._enable_get_next_as_optional, + (not drop_remainder) and enable_get_next_as_optional) class DistributedIteratorMultiWorkerTest(