Enable tuple-type sharding when using a single-element InfeedDequeueTuple.

This enables spatial partitioning for single-input Infeeds.

PiperOrigin-RevId: 243594818
commit 221a5146df (parent 555e7e3b65)
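In XLA's OpSharding proto, a tuple-type sharding wraps one per-element sharding for each component of a tuple-shaped value; for a single-element InfeedDequeueTuple that is a tuple with exactly one entry. A minimal sketch of such a proto follows, assuming the xla_data_pb2 proto module that xla_sharding.py already imports (the helper name is illustrative, not from this commit):

# Sketch only: a tuple-type sharding holding one tiled element sharding.
from tensorflow.compiler.xla import xla_data_pb2

def single_element_tuple_sharding(dims, devices):
  """Wraps one tiled element sharding in a TUPLE-type OpSharding (illustrative)."""
  element = xla_data_pb2.OpSharding(
      type=xla_data_pb2.OpSharding.OTHER,          # tiled sharding
      tile_assignment_dimensions=dims,             # e.g. [2, 1]: split dim 0 in two
      tile_assignment_devices=devices)             # e.g. [0, 1]
  return xla_data_pb2.OpSharding(
      type=xla_data_pb2.OpSharding.TUPLE,
      tuple_shardings=[element])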
@@ -120,9 +120,14 @@ class Sharding(object):
             tile_assignment_dimensions=tile_assignment_dims,
             tile_assignment_devices=range(num_devices)))
 
-  def apply_to_tensor(self, tensor):
-    """Applies this Sharding attribute to `tensor`."""
-    if len(tensor.op.outputs) > 1:
+  def apply_to_tensor(self, tensor, assign_tuple_sharding=False):
+    """Applies this Sharding attribute to `tensor`.
+
+    Args:
+      tensor: A tf.Tensor to split.
+      assign_tuple_sharding: If the sharding type should be a tuple.
+    """
+    if len(tensor.op.outputs) > 1 or assign_tuple_sharding:
       proto = self._get_or_create_tuple_proto(tensor.op)
       # We can't mutate an element of old_proto.tuple_shardings, so create
       # a new proto.
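With the new flag, apply_to_tensor can emit a tuple-shaped sharding even when the op has only one output, which is exactly the single-element InfeedDequeueTuple case. A hedged usage sketch, not from this commit; the import path is an assumption about where the module lives:

from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding

def replicate_single_element_dequeue(dequeued_tensor):
  # Without assign_tuple_sharding=True, a one-output dequeue op would receive a
  # plain (non-tuple) sharding proto; the flag forces the tuple branch above.
  xla_sharding.Sharding.replicate().apply_to_tensor(
      dequeued_tensor, assign_tuple_sharding=True)
  return dequeued_tensor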
@@ -166,21 +171,30 @@ class Sharding(object):
 #   tensor = xla_sharding.replicate(tensor)
 
 
-def replicate(tensor):
-  Sharding.replicate().apply_to_tensor(tensor)
+def replicate(tensor, assign_tuple_sharding=False):
+  Sharding.replicate().apply_to_tensor(
+      tensor,
+      assign_tuple_sharding=assign_tuple_sharding)
   return tensor
 
 
-def assign_device(tensor, device):
-  Sharding.assign_device(device).apply_to_tensor(tensor)
+def assign_device(tensor, device, assign_tuple_sharding=False):
+  Sharding.assign_device(device).apply_to_tensor(
+      tensor,
+      assign_tuple_sharding=assign_tuple_sharding)
   return tensor
 
 
-def tile(tensor, tile_assignment):
-  Sharding.tile(tile_assignment).apply_to_tensor(tensor)
+def tile(tensor, tile_assignment, assign_tuple_sharding=False):
+  Sharding.tile(tile_assignment).apply_to_tensor(
+      tensor,
+      assign_tuple_sharding=assign_tuple_sharding
+  )
   return tensor
 
 
-def split(tensor, split_dimension, num_devices):
-  Sharding.split(tensor, split_dimension, num_devices).apply_to_tensor(tensor)
+def split(tensor, split_dimension, num_devices, assign_tuple_sharding=False):
+  Sharding.split(tensor, split_dimension, num_devices).apply_to_tensor(
+      tensor,
+      assign_tuple_sharding=assign_tuple_sharding)
   return tensor
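The module-level helpers simply forward the new keyword, so callers that tag dequeued infeed tensors can opt into tuple sharding without touching the Sharding class directly. A short usage sketch under the same import assumption as above; shapes and device counts are illustrative:

import numpy as np
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding  # path is an assumption

def tile_dequeued_tensor(tensor):
  # Illustrative: tile a rank-2 tensor across 2 cores along dimension 0 and
  # tag the result as a single-element tuple sharding.
  tile_assignment = np.arange(2).reshape([2, 1])  # core 0 -> top half, core 1 -> bottom half
  return xla_sharding.tile(
      tensor, tile_assignment=tile_assignment, assign_tuple_sharding=True)

def split_dequeued_tensor(tensor):
  # Equivalent convenience form: split along one dimension across 2 devices.
  return xla_sharding.split(
      tensor, split_dimension=0, num_devices=2, assign_tuple_sharding=True)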
@@ -86,6 +86,8 @@ def partition_or_replicate_on_host(tensor, dims):
 def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims):
   """Tags appropriate XLA sharding attribute to the dequeued tensor.
 
+  The sharding attribute of the dequeued tensor will be a tuple.
+
   Args:
     tensor: The dequeued tensor on TPU.
     dims: A list of integer describes how the tensor is partitioned.
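To check the behavior the new docstring line promises, the attached attribute can be read back and inspected. In this sketch the "_XlaSharding" op attribute name is an assumption based on how the xla_sharding module stores the serialized proto; it is not part of this diff:

from tensorflow.compiler.xla import xla_data_pb2

def assert_tuple_sharded(tensor):
  # Parse the serialized OpSharding attached to the dequeue op and confirm it
  # is tuple-typed (attribute name "_XlaSharding" is an assumption).
  proto = xla_data_pb2.OpSharding()
  proto.ParseFromString(tensor.op.get_attr("_XlaSharding"))
  assert proto.type == xla_data_pb2.OpSharding.TUPLE
  return proto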
@@ -94,12 +96,15 @@ def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims):
     The same tensor with the xla_sharding attribute.
   """
   if dims is None:
-    return xla_sharding.replicate(tensor)
+    return xla_sharding.replicate(tensor, assign_tuple_sharding=True)
   elif np.prod(dims) == 1:
-    return xla_sharding.assign_device(tensor, 0)
+    return xla_sharding.assign_device(tensor, 0, assign_tuple_sharding=True)
   else:
     tile_assignment = np.arange(np.prod(dims)).reshape(dims)
-    return xla_sharding.tile(tensor=tensor, tile_assignment=tile_assignment)
+    return xla_sharding.tile(
+        tensor=tensor,
+        tile_assignment=tile_assignment,
+        assign_tuple_sharding=True)
 
 
 def tag_sharding_attribute_for_dequeued_tensors(dequeues, dims):
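A worked example of the dims dispatch above, with illustrative values: dims=None replicates, a dims whose product is 1 pins the tensor to device 0, and anything else tiles with a row-major device assignment. All three branches now pass assign_tuple_sharding=True so the attribute stays tuple-shaped:

import numpy as np

# dims = [2, 2] for a rank-2 dequeued tensor yields:
tile_assignment = np.arange(np.prod([2, 2])).reshape([2, 2])
# array([[0, 1],
#        [2, 3]])
# i.e. core 0 owns the top-left quarter, core 1 the top-right, core 2 the
# bottom-left, core 3 the bottom-right. dims=None and dims=[1, 1] fall back to
# the replicate / assign_device branches respectively.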