Add number_of_partitions to the InfeedQueue API so that infeed can still be processed with pure data parallelism while the partitioning/resharding happens inside the model function.
PiperOrigin-RevId: 324063661
Change-Id: Ieba2ca2b0a5092cceea4b51e34fb1b30d539d579
parent f999bb1785
commit c1336e9a40
tensorflow/python/tpu
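For orientation, a minimal usage sketch of the new argument (not part of the commit; the dtypes, shapes, and the 2-way partitioning are illustrative assumptions, and the class is assumed to live at its usual tensorflow.python.tpu.tpu_feed location): the queue is still fed per-core batches, but the dequeue op it later generates inside the TPU computation reports the per-replica, unpartitioned shape and carries the XLA sharding annotation added below.

import tensorflow as tf
from tensorflow.python.tpu import tpu_feed

# Data-parallel infeed for a model that is partitioned across 2 cores per
# replica. Only the constructor call below uses the new argument.
infeed_queue = tpu_feed.InfeedQueue(
    tuple_types=[tf.float32, tf.int32],
    tuple_shapes=[[16, 224, 224, 3], [16]],
    shard_dimensions=[0, 0],
    number_of_partitions=2)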
@@ -135,6 +135,7 @@ class InfeedQueue(object):
                tuple_types=None,
                tuple_shapes=None,
                shard_dimensions=None,
+               number_of_partitions=None,
                name=None):
     """Creates a new InfeedQueue with the given configuration.
 
@@ -150,6 +151,13 @@ class InfeedQueue(object):
       shard_dimensions: if not None, a list of dimensions on which the
         elements of the queue should be sharded during automatic
         parallelization.
+      number_of_partitions: if > 1, the infeed dequeue shape will contain
+        the full shape that includes all partitions, and a corresponding XLA
+        annotation is added to the infeed dequeue op. In this case the infeed
+        is still data parallel and feeds the per-core batch size to each core,
+        while the XLA computation may be partitioned. Because XLA requires the
+        infeed dequeue shape to be the per-replica shape, number_of_partitions
+        is needed to compute the per-replica unpartitioned shape.
       name: the name of the queue.
 
     Raises:
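To make the shape bookkeeping in that docstring concrete, a small pure-Python illustration (the numbers are assumptions, not taken from the commit): the dequeue shape is the per-core sharded shape with the shard dimension scaled back up by number_of_partitions, which yields the per-replica shape XLA expects.

# Each core is fed a per-core batch; the dequeue op inside the partitioned
# XLA computation reports the per-replica shape.
per_core_shape = [16, 128]   # sharded shape fed to one core (illustrative)
shard_dimension = 0          # dimension the batch is sharded on
number_of_partitions = 2     # cores per replica used for model partitioning

per_replica_shape = list(per_core_shape)
per_replica_shape[shard_dimension] *= number_of_partitions
print(per_replica_shape)     # [32, 128]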
@@ -166,6 +174,10 @@ class InfeedQueue(object):
     self._generated_enqueue_ops = False
     self._generated_dequeue_op = False
     self._name = "InfeedQueue" if name is None else name
+    if number_of_partitions is None:
+      self._number_of_partitions = 1
+    else:
+      self._number_of_partitions = number_of_partitions
     if number_of_tuple_elements is None:
       if tuple_types is not None:
         number_of_tuple_elements = len(tuple_types)
@@ -359,6 +371,7 @@ class InfeedQueue(object):
     """
     for policy in self._sharding_policies:
       policy.set_number_of_shards(number_of_shards)
+      policy.set_number_of_partitions(self._number_of_partitions)
     self._validate()
 
   def set_configuration_from_input_tensors(self, input_tensors):
@@ -485,16 +498,23 @@ class InfeedQueue(object):
     self._generated_dequeue_op = True
     full_name = "%s/dequeue" % self._name
     sharded_shapes = [
-        policy.get_sharded_shape(shape)
+        policy.get_unpartitioned_shape(policy.get_sharded_shape(shape))
         for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
     ]
     if tpu_device is not None:
       with ops.device(tpu.core(tpu_device)):
-        return tpu_ops.infeed_dequeue_tuple(
+        dequeue_op = tpu_ops.infeed_dequeue_tuple(
             dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
     else:
-      return tpu_ops.infeed_dequeue_tuple(
+      dequeue_op = tpu_ops.infeed_dequeue_tuple(
           dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
+    if self._number_of_partitions <= 1:
+      return dequeue_op
+    partitions = [
+        policy.get_unpartitioned_shape([1] * shape.ndims).as_list()
+        for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
+    ]
+    return tag_sharding_attribute_for_dequeued_tensors(dequeue_op, partitions)
 
   def _generate_enqueue_op(self,
                            inputs,
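The partitions annotation handed to tag_sharding_attribute_for_dequeued_tensors above is simply a per-dimension split count: 1 for every dimension except the shard dimension, which get_unpartitioned_shape scales up to number_of_partitions. A pure-Python sketch of that computation (rank and settings are illustrative):

# For a rank-4 dequeued tensor sharded on dimension 0 with 2 partitions,
# the annotation means "split dimension 0 two ways, leave the rest whole".
ndims = 4
shard_dimension = 0
number_of_partitions = 2

partitions = [1] * ndims
partitions[shard_dimension] *= number_of_partitions
print(partitions)  # [2, 1, 1, 1]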
@@ -34,6 +34,7 @@ class ShardingPolicy(object):
 
   def __init__(self):
     self._number_of_shards = None
+    self._number_of_partitions = 1
     self._shard_dimension = None
     self._frozen = False
 
@@ -92,6 +93,32 @@ class ShardingPolicy(object):
           "Can't set sharding policy to use %s shards; value must be >0" %
           str(number_of_shards))
 
+  @property
+  def number_of_partitions(self):
+    """Returns the number of partitions of the policy or None if unspecified."""
+    return self._number_of_partitions
+
+  def set_number_of_partitions(self, number_of_partitions):
+    """Sets the number of partitions for the current policy.
+
+    If the policy has been frozen, then number_of_partitions must match the
+    existing setting.
+
+    Args:
+      number_of_partitions: The number of partitions to use in the policy.
+
+    Raises:
+      ValueError: If the policy has been frozen and number_of_partitions
+        differs from the frozen value.
+    """
+    if self._frozen:
+      if self._number_of_partitions != number_of_partitions:
+        raise ValueError(
+            "Can't set number_of_partitions to %d since it has been frozen to "
+            "use %d." % (number_of_partitions, self._number_of_partitions))
+    else:
+      self._number_of_partitions = number_of_partitions
+
   @property
   def shard_dimension(self):
     """Returns the shard dimension of the policy or None if unspecified."""
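A short usage sketch of the new setter and its freeze behaviour (the values are illustrative; the import path is the module this class is defined in):

from tensorflow.python.tpu import tpu_sharding

policy = tpu_sharding.ShardingPolicy()
policy.set_number_of_shards(2)
policy.set_shard_dimension(0)
policy.set_number_of_partitions(4)
policy.freeze()

policy.set_number_of_partitions(4)    # OK: matches the frozen value.
try:
  policy.set_number_of_partitions(2)  # Conflicts with the frozen value.
except ValueError as e:
  print(e)  # Can't set number_of_partitions to 2 since it has been frozen to use 4.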
@@ -134,6 +161,34 @@ class ShardingPolicy(object):
     if other.shard_dimension is not None:
       self.set_shard_dimension(other.shard_dimension)
 
+  def get_unpartitioned_shape(self, shape):
+    """Returns the shape of an unpartitioned Tensor.
+
+    When given the shape of a 'sharded-size' Tensor, returns the full shape
+    of its unpartitioned Tensor.
+
+    Args:
+      shape: The shape of the sharded Tensor.
+
+    Returns:
+      The shape of the unpartitioned version of the Tensor.
+
+    Raises:
+      ValueError: if shape has an unknown sharded dimension.
+    """
+    shape = tensor_shape.as_shape(shape)
+    dims = shape.as_list()
+    if (self._shard_dimension is None or self._number_of_partitions is None or
+        not dims):
+      return None
+    if dims[self._shard_dimension] is None:
+      raise ValueError("shape %s must have a fixed size for dimension %d "
+                       "that is known at graph construction time." %
+                       (shape.as_list(), self._shard_dimension))
+    if self._number_of_partitions > 1:
+      dims[self._shard_dimension] *= self._number_of_partitions
+    return tensor_shape.as_shape(dims)
+
   def get_sharded_shape(self, shape, shard_index=None):
     """Returns the shape of a shard of a full Tensor.
 
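Beyond the happy path exercised by the new test below, get_unpartitioned_shape has an early-out worth noting: if the policy has no shard dimension yet (or the shape is empty) it returns None instead of a scaled shape. An illustrative sketch:

from tensorflow.python.tpu import tpu_sharding

p = tpu_sharding.ShardingPolicy()
print(p.get_unpartitioned_shape([8, 16]))  # None: shard_dimension is unset.

p.set_number_of_shards(2)
p.set_shard_dimension(0)
p.set_number_of_partitions(2)
print(p.get_unpartitioned_shape([8, 16]))  # TensorShape([16, 16])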
@@ -107,6 +107,17 @@ class ShardingTest(test.TestCase):
     with self.assertRaises(ValueError):
       _ = p.get_sharded_shape([4, 10], shard_index=-1)
 
+  def testGetUnpartitionedShape(self):
+    """Tests getting an unpartitioned shape."""
+    p = tpu_sharding.ShardingPolicy()
+    p.set_number_of_shards(3)
+    p.set_shard_dimension(1)
+    p.set_number_of_partitions(4)
+    self.assertEqual(p.get_unpartitioned_shape([3, 5]), [3, 20])
+    p.freeze()
+    with self.assertRaises(ValueError):
+      _ = p.get_unpartitioned_shape([3, None])
+
   def testGetUnshardedShape(self):
     """Tests getting an unsharded shape."""
     p = tpu_sharding.ShardingPolicy()