Add number_of_partitions to the InfeedQueue API so that infeed can be processed with pure data parallelism while the partitioning/resharding happens inside the model function.
PiperOrigin-RevId: 324063661
Change-Id: Ieba2ca2b0a5092cceea4b51e34fb1b30d539d579
Commit c1336e9a40 (parent f999bb1785)
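For orientation, here is a minimal usage sketch of the new argument. The dtypes, shapes, and shard/partition counts below are illustrative assumptions, the enqueue ops and the tpu.replicate() wiring that a real input pipeline needs are omitted, and InfeedQueue is assumed to live in tensorflow.python.tpu.tpu_feed:

# Hypothetical configuration: 8 TPU cores fed data-parallel, with each
# replica's XLA computation partitioned across 2 of those cores.
import tensorflow as tf
from tensorflow.python.tpu import tpu_feed

infeed = tpu_feed.InfeedQueue(
    tuple_types=[tf.float32, tf.int32],
    tuple_shapes=[[64, 128], [64]],   # full (unsharded) shapes of the tuple
    number_of_partitions=2)           # XLA computation is 2-way partitioned
infeed.set_number_of_shards(8)        # infeed stays data parallel over 8 cores

With this configuration, the dequeue op generated inside the replicated computation reports the per-replica, unpartitioned shape and carries the XLA sharding annotation introduced by this change.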
@@ -135,6 +135,7 @@ class InfeedQueue(object):
                tuple_types=None,
                tuple_shapes=None,
                shard_dimensions=None,
+               number_of_partitions=None,
                name=None):
     """Creates a new InfeedQueue with the given configuration.

@@ -150,6 +151,13 @@ class InfeedQueue(object):
       shard_dimensions: if not None, a list of dimensions on which the
         elements of the queue should be sharded during automatic
         parallelization.
+      number_of_partitions: if > 1, the infeed dequeue shape will contain
+        the full shape that includes all partitions, and the corresponding XLA
+        annotation is added to the infeed dequeue op. In this case the infeed
+        is still data parallel, feeding the per-core batch size to each core,
+        while the XLA computation may be partitioned. Since XLA requires the
+        dequeue shape to be the per-replica shape, number_of_partitions is
+        needed here to compute the per-replica unpartitioned shape.
       name: the name of the queue.

     Raises:
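To make the docstring's shape rule concrete, a small worked calculation follows; the batch sizes are arbitrary assumptions and the variable names are not part of the API:

per_core_batch = 8    # batch size the host feeds to each core (data parallel)
num_partitions = 2    # cores over which one replica's computation is partitioned

# XLA wants the dequeue shape expressed per replica, so the sharded (batch)
# dimension is scaled back up by the number of partitions:
per_replica_batch = per_core_batch * num_partitions
assert per_replica_batch == 16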
@@ -166,6 +174,10 @@ class InfeedQueue(object):
     self._generated_enqueue_ops = False
     self._generated_dequeue_op = False
     self._name = "InfeedQueue" if name is None else name
+    if number_of_partitions is None:
+      self._number_of_partitions = 1
+    else:
+      self._number_of_partitions = number_of_partitions
     if number_of_tuple_elements is None:
       if tuple_types is not None:
         number_of_tuple_elements = len(tuple_types)
@@ -359,6 +371,7 @@ class InfeedQueue(object):
     """
     for policy in self._sharding_policies:
       policy.set_number_of_shards(number_of_shards)
+      policy.set_number_of_partitions(self._number_of_partitions)
     self._validate()

   def set_configuration_from_input_tensors(self, input_tensors):
@@ -485,16 +498,23 @@ class InfeedQueue(object):
     self._generated_dequeue_op = True
     full_name = "%s/dequeue" % self._name
     sharded_shapes = [
-        policy.get_sharded_shape(shape)
+        policy.get_unpartitioned_shape(policy.get_sharded_shape(shape))
         for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
     ]
     if tpu_device is not None:
       with ops.device(tpu.core(tpu_device)):
-        return tpu_ops.infeed_dequeue_tuple(
+        dequeue_op = tpu_ops.infeed_dequeue_tuple(
             dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
     else:
-      return tpu_ops.infeed_dequeue_tuple(
+      dequeue_op = tpu_ops.infeed_dequeue_tuple(
          dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
+    if self._number_of_partitions <= 1:
+      return dequeue_op
+    partitions = [
+        policy.get_unpartitioned_shape([1] * shape.ndims).as_list()
+        for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
+    ]
+    return tag_sharding_attribute_for_dequeued_tensors(dequeue_op, partitions)

   def _generate_enqueue_op(self,
                            inputs,
@@ -34,6 +34,7 @@ class ShardingPolicy(object):

   def __init__(self):
     self._number_of_shards = None
+    self._number_of_partitions = 1
     self._shard_dimension = None
     self._frozen = False

@@ -92,6 +93,32 @@ class ShardingPolicy(object):
           "Can't set sharding policy to use %s shards; value must be >0" %
           str(number_of_shards))

+  @property
+  def number_of_partitions(self):
+    """Returns the number of partitions of the policy, or 1 if unspecified."""
+    return self._number_of_partitions
+
+  def set_number_of_partitions(self, number_of_partitions):
+    """Sets the number of partitions for the current policy.
+
+    If the policy has been frozen then number_of_partitions must match the
+    existing setting.
+
+    Args:
+      number_of_partitions: The number of partitions to use in the policy.
+
+    Raises:
+      ValueError: If the policy has been frozen and number_of_partitions
+        differs from the frozen value.
+    """
+    if self._frozen:
+      if self._number_of_partitions != number_of_partitions:
+        raise ValueError(
+            "Can't set number_of_partitions to %d since it has been frozen to "
+            "use %d." % (number_of_partitions, self._number_of_partitions))
+    else:
+      self._number_of_partitions = number_of_partitions
+
   @property
   def shard_dimension(self):
     """Returns the shard dimension of the policy or None if unspecified."""
@@ -134,6 +161,34 @@ class ShardingPolicy(object):
     if other.shard_dimension is not None:
       self.set_shard_dimension(other.shard_dimension)

+  def get_unpartitioned_shape(self, shape):
+    """Returns the shape of an unpartitioned Tensor.
+
+    When given the shape of a 'sharded-size' Tensor, returns the full
+    shape of its unpartitioned Tensor.
+
+    Args:
+      shape: The shape of the sharded Tensor.
+
+    Returns:
+      The shape of the unpartitioned version of the Tensor.
+
+    Raises:
+      ValueError: if shape has an unknown sharded dimension.
+    """
+    shape = tensor_shape.as_shape(shape)
+    dims = shape.as_list()
+    if (self._shard_dimension is None or self._number_of_partitions is None or
+        not dims):
+      return None
+    if dims[self._shard_dimension] is None:
+      raise ValueError("shape %s must have a fixed size for dimension %d "
+                       "that is known at graph construction time." %
+                       (shape.as_list(), self._shard_dimension))
+    if self._number_of_partitions > 1:
+      dims[self._shard_dimension] *= self._number_of_partitions
+    return tensor_shape.as_shape(dims)
+
   def get_sharded_shape(self, shape, shard_index=None):
     """Returns the shape of a shard of a full Tensor.

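As a runnable illustration of how the new method composes with get_sharded_shape on the dequeue path above, here is a small sketch; the shard/partition counts and shapes are assumptions chosen so the batch dimension divides evenly, and ShardingPolicy is assumed to live in tensorflow.python.tpu.tpu_sharding:

from tensorflow.python.tpu import tpu_sharding

policy = tpu_sharding.ShardingPolicy()
policy.set_number_of_shards(8)        # 8 cores are fed data-parallel
policy.set_shard_dimension(0)         # shard along the batch dimension
policy.set_number_of_partitions(2)    # each replica spans 2 partitions

full_shape = [64, 128]                                  # unsharded tuple element
per_core = policy.get_sharded_shape(full_shape)         # -> [8, 128] per core
per_replica = policy.get_unpartitioned_shape(per_core)  # -> [16, 128] per replica
print(per_core, per_replica)

The test added below exercises the same method with different numbers.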
@@ -107,6 +107,17 @@ class ShardingTest(test.TestCase):
     with self.assertRaises(ValueError):
       _ = p.get_sharded_shape([4, 10], shard_index=-1)

+  def testGetUnpartitionedShape(self):
+    """Tests getting an unpartitioned shape."""
+    p = tpu_sharding.ShardingPolicy()
+    p.set_number_of_shards(3)
+    p.set_shard_dimension(1)
+    p.set_number_of_partitions(4)
+    self.assertEqual(p.get_unpartitioned_shape([3, 5]), [3, 20])
+    p.freeze()
+    with self.assertRaises(ValueError):
+      _ = p.get_unpartitioned_shape([3, None])
+
   def testGetUnshardedShape(self):
     """Tests getting an unsharded shape."""
     p = tpu_sharding.ShardingPolicy()