Add number_of_partitions to the InfeedQueue API to allow infeed to be processed with pure data parallelism while the partitioning/resharding happens inside the model function.

PiperOrigin-RevId: 324063661
Change-Id: Ieba2ca2b0a5092cceea4b51e34fb1b30d539d579
A. Unique TensorFlower 2020-07-30 12:54:35 -07:00 committed by TensorFlower Gardener
parent f999bb1785
commit c1336e9a40
3 changed files with 89 additions and 3 deletions


@@ -135,6 +135,7 @@ class InfeedQueue(object):
tuple_types=None,
tuple_shapes=None,
shard_dimensions=None,
number_of_partitions=None,
name=None):
"""Creates a new InfeedQueue with the given configuration.
@@ -150,6 +151,13 @@ class InfeedQueue(object):
shard_dimensions: if not None, a list of dimensions on which the
elements of the queue should be sharded during automatic
parallelization.
number_of_partitions: if > 1, the infeed dequeue shape will contain
the full shape that includes all partitions, and a corresponding XLA
annotation will be added to the infeed dequeue op. In this case the
infeed is still data parallel, feeding a per-core batch to each core,
while the XLA computation may be partitioned. Since XLA requires the
infeed dequeue shape to be the per-replica shape, number_of_partitions
is needed here to compute the per-replica unpartitioned shape.
name: the name of the queue.
Raises:
@@ -166,6 +174,10 @@ class InfeedQueue(object):
self._generated_enqueue_ops = False
self._generated_dequeue_op = False
self._name = "InfeedQueue" if name is None else name
if number_of_partitions is None:
self._number_of_partitions = 1
else:
self._number_of_partitions = number_of_partitions
if number_of_tuple_elements is None:
if tuple_types is not None:
number_of_tuple_elements = len(tuple_types)
@@ -359,6 +371,7 @@ class InfeedQueue(object):
"""
for policy in self._sharding_policies:
policy.set_number_of_shards(number_of_shards)
policy.set_number_of_partitions(self._number_of_partitions)
self._validate()
def set_configuration_from_input_tensors(self, input_tensors):
@@ -485,16 +498,23 @@ class InfeedQueue(object):
self._generated_dequeue_op = True
full_name = "%s/dequeue" % self._name
sharded_shapes = [
- policy.get_sharded_shape(shape)
+ policy.get_unpartitioned_shape(policy.get_sharded_shape(shape))
for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
]
if tpu_device is not None:
with ops.device(tpu.core(tpu_device)):
- return tpu_ops.infeed_dequeue_tuple(
+ dequeue_op = tpu_ops.infeed_dequeue_tuple(
dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
else:
- return tpu_ops.infeed_dequeue_tuple(
+ dequeue_op = tpu_ops.infeed_dequeue_tuple(
dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
if self._number_of_partitions <= 1:
return dequeue_op
partitions = [
policy.get_unpartitioned_shape([1] * shape.ndims).as_list()
for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
]
return tag_sharding_attribute_for_dequeued_tensors(dequeue_op, partitions)
def _generate_enqueue_op(self,
inputs,
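
For illustration only (not part of this commit), a minimal sketch of how the new argument might be used, assuming the InfeedQueue API from tensorflow/python/tpu/tpu_feed.py shown above; the dtypes, shapes, and shard/partition counts are made up for the example:

import tensorflow as tf
from tensorflow.python.tpu import tpu_feed

with tf.Graph().as_default():
  # One-element queue: inputs are sharded along dimension 0 across the
  # data-parallel replicas, while the XLA computation inside the model
  # function is partitioned 2-ways.
  queue = tpu_feed.InfeedQueue(
      tuple_types=[tf.float32],
      tuple_shapes=[[16, 128]],
      shard_dimensions=[0],
      number_of_partitions=2)
  queue.set_number_of_shards(2)

  # With number_of_partitions > 1, the dequeue op advertises the per-replica
  # (unpartitioned) shape to XLA and is tagged with a sharding annotation,
  # as added in the hunk above.
  dequeued_tensors = queue.generate_dequeue_op()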


@@ -34,6 +34,7 @@ class ShardingPolicy(object):
def __init__(self):
self._number_of_shards = None
self._number_of_partitions = 1
self._shard_dimension = None
self._frozen = False
@@ -92,6 +93,32 @@ class ShardingPolicy(object):
"Can't set sharding policy to use %s shards; value must be >0" %
str(number_of_shards))
@property
def number_of_partitions(self):
"""Returns the number of partitions of the policy or None if unspecified."""
return self._number_of_partitions
def set_number_of_partitions(self, number_of_partitions):
"""Sets the number of partitions for the current policy.
If the policy has been frozen then number_of_partitions must match the
existing setting.
Args:
number_of_partitions: The number of partitions to use in the policy.
Raises:
ValueError: If the policy has been frozen and number_of_partitions
differs from the frozen value.
"""
if self._frozen:
if self._number_of_partitions != number_of_partitions:
raise ValueError(
"Can't set number_of_partitions to %d since it has been frozen to "
"use %d." % (number_of_partitions, self._number_of_partitions))
else:
self._number_of_partitions = number_of_partitions
@property
def shard_dimension(self):
"""Returns the shard dimension of the policy or None if unspecified."""
@@ -134,6 +161,34 @@ class ShardingPolicy(object):
if other.shard_dimension is not None:
self.set_shard_dimension(other.shard_dimension)
def get_unpartitioned_shape(self, shape):
"""Returns the shape of an unpartitioned Tensor.
When given the shape of a 'sharded-size' Tensor, returns the full
shape of its unpartitioned Tensor.
Args:
shape: The shape of the sharded Tensor.
Returns:
The shape of the unpartitioned version of the Tensor.
Raises:
ValueError: if the shape's sharded dimension has an unknown size.
"""
shape = tensor_shape.as_shape(shape)
dims = shape.as_list()
if (self._shard_dimension is None or self._number_of_partitions is None or
not dims):
return None
if dims[self._shard_dimension] is None:
raise ValueError("shape %s must have a fixed size for dimension %d "
"that is known at graph construction time." %
(shape.as_list(), self._shard_dimension))
if self._number_of_partitions > 1:
dims[self._shard_dimension] *= self._number_of_partitions
return tensor_shape.as_shape(dims)
def get_sharded_shape(self, shape, shard_index=None):
"""Returns the shape of a shard of a full Tensor.


@@ -107,6 +107,17 @@ class ShardingTest(test.TestCase):
with self.assertRaises(ValueError):
_ = p.get_sharded_shape([4, 10], shard_index=-1)
def testGetUnpartitionedShape(self):
"""Tests getting a sharded shape."""
p = tpu_sharding.ShardingPolicy()
p.set_number_of_shards(3)
p.set_shard_dimension(1)
p.set_number_of_partitions(4)
self.assertEqual(p.get_unpartitioned_shape([3, 5]), [3, 20])
p.freeze()
with self.assertRaises(ValueError):
_ = p.get_unpartitioned_shape([3, None])
def testGetUnshardedShape(self):
"""Tests getting an unsharded shape."""
p = tpu_sharding.ShardingPolicy()