Enable tuple type sharding when using a single element InfeedDequeueTuple.
This will enable spatial partitioning for single input Infeeds.

PiperOrigin-RevId: 243594818
parent 555e7e3b65
commit 221a5146df
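In XLA terms, "tuple type sharding" means the dequeued tensor carries an OpSharding proto of type TUPLE whose single entry holds the per-tensor sharding, which matches what InfeedDequeueTuple produces even when it has only one output. A minimal sketch of that shape, assuming the xla_data_pb2 proto bindings already imported by xla_sharding.py; the helper function below is illustrative, not part of this commit:

from tensorflow.compiler.xla import xla_data_pb2

def single_element_tuple_sharding(element_sharding):
  # Wrap one per-tensor OpSharding in a TUPLE-typed OpSharding, the form a
  # single-element InfeedDequeueTuple ends up with after this change.
  proto = xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.TUPLE)
  proto.tuple_shardings.extend([element_sharding])
  return proto

# For example, a replicated sharding for the one dequeued tensor:
replicated = xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED)
tuple_proto = single_element_tuple_sharding(replicated)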
@@ -120,9 +120,14 @@ class Sharding(object):
             tile_assignment_dimensions=tile_assignment_dims,
             tile_assignment_devices=range(num_devices)))
 
-  def apply_to_tensor(self, tensor):
-    """Applies this Sharding attribute to `tensor`."""
-    if len(tensor.op.outputs) > 1:
+  def apply_to_tensor(self, tensor, assign_tuple_sharding=False):
+    """Applies this Sharding attribute to `tensor`.
+
+    Args:
+      tensor: A tf.Tensor to split.
+      assign_tuple_sharding: If the sharding type should be a tuple.
+    """
+    if len(tensor.op.outputs) > 1 or assign_tuple_sharding:
       proto = self._get_or_create_tuple_proto(tensor.op)
       # We can't mutate an element of old_proto.tuple_shardings, so create
       # a new proto.
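With the new keyword argument, a caller can force the tuple-proto path even when tensor.op has a single output; previously that branch was taken only for multi-output ops. A minimal call-site sketch, assuming the xla_sharding module path as of this revision and a toy placeholder standing in for the dequeued tensor:

import tensorflow as tf
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding

# Toy stand-in for the single tensor produced by InfeedDequeueTuple.
x = tf.placeholder(tf.float32, shape=[8, 128])
# Attach a replicated sharding, but wrapped in a tuple proto as a
# single-element InfeedDequeueTuple expects.
xla_sharding.Sharding.replicate().apply_to_tensor(
    x, assign_tuple_sharding=True)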
@@ -166,21 +171,30 @@ class Sharding(object):
 #   tensor = xla_sharding.replicate(tensor)
 
 
-def replicate(tensor):
-  Sharding.replicate().apply_to_tensor(tensor)
+def replicate(tensor, assign_tuple_sharding=False):
+  Sharding.replicate().apply_to_tensor(
+      tensor,
+      assign_tuple_sharding=assign_tuple_sharding)
   return tensor
 
 
-def assign_device(tensor, device):
-  Sharding.assign_device(device).apply_to_tensor(tensor)
+def assign_device(tensor, device, assign_tuple_sharding=False):
+  Sharding.assign_device(device).apply_to_tensor(
+      tensor,
+      assign_tuple_sharding=assign_tuple_sharding)
   return tensor
 
 
-def tile(tensor, tile_assignment):
-  Sharding.tile(tile_assignment).apply_to_tensor(tensor)
+def tile(tensor, tile_assignment, assign_tuple_sharding=False):
+  Sharding.tile(tile_assignment).apply_to_tensor(
+      tensor,
+      assign_tuple_sharding=assign_tuple_sharding
+  )
   return tensor
 
 
-def split(tensor, split_dimension, num_devices):
-  Sharding.split(tensor, split_dimension, num_devices).apply_to_tensor(tensor)
+def split(tensor, split_dimension, num_devices, assign_tuple_sharding=False):
+  Sharding.split(tensor, split_dimension, num_devices).apply_to_tensor(
+      tensor,
+      assign_tuple_sharding=assign_tuple_sharding)
   return tensor
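The free functions simply thread the flag through to apply_to_tensor and return the tensor for chaining. As an example, a hedged sketch of spatially tiling a single dequeued image batch across a 2x2 grid of four cores; the shapes, names, and placeholder setup are assumptions, not taken from the commit:

import numpy as np
import tensorflow as tf
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding

# Toy stand-in for a single dequeued NHWC image batch.
image = tf.placeholder(tf.float32, shape=[4, 128, 128, 3])
# Keep batch and channels whole, split H and W across a 2x2 grid of cores.
# The tuple flag keeps the sharding attribute in the form a single-element
# InfeedDequeueTuple needs.
tile_assignment = np.arange(4).reshape([1, 2, 2, 1])
image = xla_sharding.tile(
    tensor=image,
    tile_assignment=tile_assignment,
    assign_tuple_sharding=True)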
@@ -86,6 +86,8 @@ def partition_or_replicate_on_host(tensor, dims):
 def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims):
   """Tags appropriate XLA sharding attribute to the dequeued tensor.
 
+  The sharding attribute of the dequeued tensor will be a tuple.
+
   Args:
     tensor: The dequeued tensor on TPU.
     dims: A list of integer describes how the tensor is partitioned.
@@ -94,12 +96,15 @@ def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims):
     The same tensor with the xla_sharding attribute.
   """
   if dims is None:
-    return xla_sharding.replicate(tensor)
+    return xla_sharding.replicate(tensor, assign_tuple_sharding=True)
   elif np.prod(dims) == 1:
-    return xla_sharding.assign_device(tensor, 0)
+    return xla_sharding.assign_device(tensor, 0, assign_tuple_sharding=True)
   else:
     tile_assignment = np.arange(np.prod(dims)).reshape(dims)
-    return xla_sharding.tile(tensor=tensor, tile_assignment=tile_assignment)
+    return xla_sharding.tile(
+        tensor=tensor,
+        tile_assignment=tile_assignment,
+        assign_tuple_sharding=True)
 
 
 def tag_sharding_attribute_for_dequeued_tensors(dequeues, dims):
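For reference, the branch selection in the updated helper follows directly from dims; below is a small hedged example of the tiled case, with dims values chosen for illustration only:

import numpy as np

# dims is None        -> xla_sharding.replicate(tensor, assign_tuple_sharding=True)
# np.prod(dims) == 1  -> xla_sharding.assign_device(tensor, 0, assign_tuple_sharding=True)
# otherwise, e.g. dims = [1, 2, 2, 1] splits H and W of an NHWC tensor:
dims = [1, 2, 2, 1]
tile_assignment = np.arange(np.prod(dims)).reshape(dims)
print(tile_assignment.shape)  # (1, 2, 2, 1): cores 0..3 laid out over H x W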