From c1336e9a40be28faa285b1407ab7b618a1db247a Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Thu, 30 Jul 2020 12:54:35 -0700
Subject: [PATCH] Add number_of_partitions to the InfeedQueue API to allow
 infeed to be processed with pure data parallelism, but the
 partition/resharding happens inside the model function.

PiperOrigin-RevId: 324063661
Change-Id: Ieba2ca2b0a5092cceea4b51e34fb1b30d539d579
---
 tensorflow/python/tpu/tpu_feed.py          | 26 ++++++++--
 tensorflow/python/tpu/tpu_sharding.py      | 55 ++++++++++++++++++++++
 tensorflow/python/tpu/tpu_sharding_test.py | 11 +++++
 3 files changed, 89 insertions(+), 3 deletions(-)

diff --git a/tensorflow/python/tpu/tpu_feed.py b/tensorflow/python/tpu/tpu_feed.py
index ce5f9aa6b8b..d3b66e3fd08 100644
--- a/tensorflow/python/tpu/tpu_feed.py
+++ b/tensorflow/python/tpu/tpu_feed.py
@@ -135,6 +135,7 @@ class InfeedQueue(object):
                tuple_types=None,
                tuple_shapes=None,
                shard_dimensions=None,
+               number_of_partitions=None,
                name=None):
     """Creates a new InfeedQueue with the given configuration.

@@ -150,6 +151,13 @@ class InfeedQueue(object):
       shard_dimensions: if not None, a list of dimensions on which the
         elements of the queue should be sharded during automatic
         parallelization.
+      number_of_partitions: if > 1, the infeed dequeue shape will contain
+        the full shape that includes all partitions, and the corresponding XLA
+        sharding annotation will be added to the infeed dequeue op. In this
+        case the infeed is still data parallel, feeding a per-core batch to
+        each core, while the XLA computation itself may be partitioned. Since
+        XLA requires the infeed dequeue shape to be the per-replica shape,
+        number_of_partitions is needed to compute that unpartitioned shape.
       name: the name of the queue.

     Raises:
@@ -166,6 +174,10 @@ class InfeedQueue(object):
     self._generated_enqueue_ops = False
     self._generated_dequeue_op = False
     self._name = "InfeedQueue" if name is None else name
+    if number_of_partitions is None:
+      self._number_of_partitions = 1
+    else:
+      self._number_of_partitions = number_of_partitions
     if number_of_tuple_elements is None:
       if tuple_types is not None:
         number_of_tuple_elements = len(tuple_types)
@@ -359,6 +371,7 @@ class InfeedQueue(object):
     """
     for policy in self._sharding_policies:
       policy.set_number_of_shards(number_of_shards)
+      policy.set_number_of_partitions(self._number_of_partitions)
     self._validate()

   def set_configuration_from_input_tensors(self, input_tensors):
@@ -485,16 +498,23 @@ class InfeedQueue(object):
     self._generated_dequeue_op = True
     full_name = "%s/dequeue" % self._name
     sharded_shapes = [
-        policy.get_sharded_shape(shape)
+        policy.get_unpartitioned_shape(policy.get_sharded_shape(shape))
         for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
     ]
     if tpu_device is not None:
       with ops.device(tpu.core(tpu_device)):
-        return tpu_ops.infeed_dequeue_tuple(
+        dequeue_op = tpu_ops.infeed_dequeue_tuple(
             dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
     else:
-      return tpu_ops.infeed_dequeue_tuple(
+      dequeue_op = tpu_ops.infeed_dequeue_tuple(
           dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
+    if self._number_of_partitions <= 1:
+      return dequeue_op
+    partitions = [
+        policy.get_unpartitioned_shape([1] * shape.ndims).as_list()
+        for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
+    ]
+    return tag_sharding_attribute_for_dequeued_tensors(dequeue_op, partitions)

   def _generate_enqueue_op(self,
                            inputs,
diff --git a/tensorflow/python/tpu/tpu_sharding.py b/tensorflow/python/tpu/tpu_sharding.py
index c6f5017efbd..6fd4256a8a1 100644
--- a/tensorflow/python/tpu/tpu_sharding.py
+++ b/tensorflow/python/tpu/tpu_sharding.py
@@ -34,6 +34,7 @@ class ShardingPolicy(object):

   def __init__(self):
     self._number_of_shards = None
+    self._number_of_partitions = 1
     self._shard_dimension = None
     self._frozen = False

@@ -92,6 +93,32 @@ class ShardingPolicy(object):
             "Can't set sharding policy to use %s shards; value must be >0" %
             str(number_of_shards))

+  @property
+  def number_of_partitions(self):
+    """Returns the number of partitions of the policy, or 1 if unspecified."""
+    return self._number_of_partitions
+
+  def set_number_of_partitions(self, number_of_partitions):
+    """Sets the number of partitions for the current policy.
+
+    If the policy has been frozen then number_of_partitions must match the
+    existing setting.
+
+    Args:
+      number_of_partitions: The number of partitions to use in the policy.
+
+    Raises:
+      ValueError: If the policy has been frozen and number_of_partitions
+        differs from the frozen value.
+    """
+    if self._frozen:
+      if self._number_of_partitions != number_of_partitions:
+        raise ValueError(
+            "Can't set number_of_partitions to %d since it has been frozen to "
+            "use %d." % (number_of_partitions, self._number_of_partitions))
+    else:
+      self._number_of_partitions = number_of_partitions
+
   @property
   def shard_dimension(self):
     """Returns the shard dimension of the policy or None if unspecified."""
@@ -134,6 +161,34 @@ class ShardingPolicy(object):
     if other.shard_dimension is not None:
       self.set_shard_dimension(other.shard_dimension)

+  def get_unpartitioned_shape(self, shape):
+    """Returns the shape of an unpartitioned Tensor.
+
+    When given the shape of a 'sharded-size' Tensor, returns the full
+    shape of the corresponding unpartitioned Tensor.
+
+    Args:
+      shape: The shape of the sharded Tensor.
+
+    Returns:
+      The shape of the unpartitioned version of the Tensor.
+
+    Raises:
+      ValueError: if shape has an unknown sharded dimension.
+    """
+    shape = tensor_shape.as_shape(shape)
+    dims = shape.as_list()
+    if (self._shard_dimension is None or self._number_of_partitions is None or
+        not dims):
+      return None
+    if dims[self._shard_dimension] is None:
+      raise ValueError("shape %s must have a fixed size for dimension %d "
+                       "that is known at graph construction time." %
+                       (shape.as_list(), self._shard_dimension))
+    if self._number_of_partitions > 1:
+      dims[self._shard_dimension] *= self._number_of_partitions
+    return tensor_shape.as_shape(dims)
+
   def get_sharded_shape(self, shape, shard_index=None):
     """Returns the shape of a shard of a full Tensor.

diff --git a/tensorflow/python/tpu/tpu_sharding_test.py b/tensorflow/python/tpu/tpu_sharding_test.py
index 21d2a0897a0..0d67939adfa 100644
--- a/tensorflow/python/tpu/tpu_sharding_test.py
+++ b/tensorflow/python/tpu/tpu_sharding_test.py
@@ -107,6 +107,17 @@ class ShardingTest(test.TestCase):
     with self.assertRaises(ValueError):
       _ = p.get_sharded_shape([4, 10], shard_index=-1)

+  def testGetUnpartitionedShape(self):
+    """Tests getting an unpartitioned shape."""
+    p = tpu_sharding.ShardingPolicy()
+    p.set_number_of_shards(3)
+    p.set_shard_dimension(1)
+    p.set_number_of_partitions(4)
+    self.assertEqual(p.get_unpartitioned_shape([3, 5]), [3, 20])
+    p.freeze()
+    with self.assertRaises(ValueError):
+      _ = p.get_unpartitioned_shape([3, None])
+
   def testGetUnshardedShape(self):
     """Tests getting an unsharded shape."""
     p = tpu_sharding.ShardingPolicy()
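
Usage sketch (illustrative, not part of the patch above): with
number_of_partitions > 1, ShardingPolicy.get_unpartitioned_shape scales the
per-replica sharded shape by the partition count along the shard dimension,
which is the shape InfeedQueue now requests from infeed_dequeue_tuple before
tagging the dequeued tensors with an XLA sharding attribute. A minimal example
mirroring the added test, assuming TensorFlow is installed so that
tensorflow.python.tpu.tpu_sharding is importable:

from tensorflow.python.tpu import tpu_sharding

policy = tpu_sharding.ShardingPolicy()
policy.set_number_of_shards(3)      # data-parallel replicas fed by infeed
policy.set_shard_dimension(1)       # dimension along which inputs are sharded
policy.set_number_of_partitions(4)  # model partitions inside each replica

# The unpartitioned (per-replica dequeue) shape multiplies the sharded shape
# by the partition count along the shard dimension: [3, 5] -> [3, 20].
print(policy.get_unpartitioned_shape([3, 5]).as_list())  # [3, 20]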