Enable tuple-type sharding when using a single-element InfeedDequeueTuple.

This will enable spatial partitioning for single-input infeeds.

PiperOrigin-RevId: 243594818
A. Unique TensorFlower 2019-04-15 05:01:20 -07:00 committed by TensorFlower Gardener
parent 555e7e3b65
commit 221a5146df
2 changed files with 33 additions and 14 deletions
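
For context, the change below threads a new assign_tuple_sharding flag through the xla_sharding helper functions so that the single output of an InfeedDequeueTuple can still carry a tuple-typed sharding attribute, which is what the infeed spatial-partitioning path expects. A minimal usage sketch, not part of the commit: the import path is the TF 1.x location of xla_sharding and may differ in other versions, and the image tensor is a made-up stand-in for the dequeued value.

# Sketch only: partition the two spatial dimensions of a dequeued image batch
# across a 2x2 grid of 4 TPU cores (graph mode assumed).
import numpy as np
import tensorflow as tf
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding

images = tf.zeros([8, 224, 224, 3])  # stand-in for the single infeed output
dims = [1, 2, 2, 1]                  # no split on batch/channels, 2x2 on H/W
tile_assignment = np.arange(np.prod(dims)).reshape(dims)
images = xla_sharding.tile(
    tensor=images,
    tile_assignment=tile_assignment,
    assign_tuple_sharding=True)  # new flag: wrap the sharding in a tuple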

View File

@@ -120,9 +120,14 @@ class Sharding(object):
             tile_assignment_dimensions=tile_assignment_dims,
             tile_assignment_devices=range(num_devices)))
 
-  def apply_to_tensor(self, tensor):
-    """Applies this Sharding attribute to `tensor`."""
-    if len(tensor.op.outputs) > 1:
+  def apply_to_tensor(self, tensor, assign_tuple_sharding=False):
+    """Applies this Sharding attribute to `tensor`.
+
+    Args:
+      tensor: A tf.Tensor to split.
+      assign_tuple_sharding: If the sharding type should be a tuple.
+    """
+    if len(tensor.op.outputs) > 1 or assign_tuple_sharding:
       proto = self._get_or_create_tuple_proto(tensor.op)
       # We can't mutate an element of old_proto.tuple_shardings, so create
       # a new proto.
@@ -166,21 +171,30 @@ class Sharding(object):
 #   tensor = xla_sharding.replicate(tensor)
 
 
-def replicate(tensor):
-  Sharding.replicate().apply_to_tensor(tensor)
+def replicate(tensor, assign_tuple_sharding=False):
+  Sharding.replicate().apply_to_tensor(
+      tensor,
+      assign_tuple_sharding=assign_tuple_sharding)
   return tensor
 
 
-def assign_device(tensor, device):
-  Sharding.assign_device(device).apply_to_tensor(tensor)
+def assign_device(tensor, device, assign_tuple_sharding=False):
+  Sharding.assign_device(device).apply_to_tensor(
+      tensor,
+      assign_tuple_sharding=assign_tuple_sharding)
   return tensor
 
 
-def tile(tensor, tile_assignment):
-  Sharding.tile(tile_assignment).apply_to_tensor(tensor)
+def tile(tensor, tile_assignment, assign_tuple_sharding=False):
+  Sharding.tile(tile_assignment).apply_to_tensor(
+      tensor,
+      assign_tuple_sharding=assign_tuple_sharding
+  )
   return tensor
 
 
-def split(tensor, split_dimension, num_devices):
-  Sharding.split(tensor, split_dimension, num_devices).apply_to_tensor(tensor)
+def split(tensor, split_dimension, num_devices, assign_tuple_sharding=False):
+  Sharding.split(tensor, split_dimension, num_devices).apply_to_tensor(
+      tensor,
+      assign_tuple_sharding=assign_tuple_sharding)
   return tensor
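
In effect, the new flag controls the shape of the OpSharding proto that apply_to_tensor attaches to the op's _XlaSharding attribute: without it, a single-output op gets the element sharding directly; with it, the element sharding is wrapped in a one-element tuple. A rough sketch of the two proto shapes, using the xla_data_pb2 definitions that xla_sharding already imports (the variable names are illustrative only):

from tensorflow.compiler.xla import xla_data_pb2

# Element sharding for a replicated tensor.
replicated = xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED)

# Default path for a single-output op: the element proto is attached as-is.
plain_attr = replicated

# With assign_tuple_sharding=True, the same element is carried inside a
# TUPLE-typed proto with one entry per op output (here, exactly one).
tuple_attr = xla_data_pb2.OpSharding(
    type=xla_data_pb2.OpSharding.TUPLE,
    tuple_shardings=[replicated])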

View File

@@ -86,6 +86,8 @@ def partition_or_replicate_on_host(tensor, dims):
 def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims):
   """Tags appropriate XLA sharding attribute to the dequeued tensor.
 
+  The sharding attribute of the dequeued tensor will be a tuple.
+
   Args:
     tensor: The dequeued tensor on TPU.
     dims: A list of integer describes how the tensor is partitioned.
@@ -94,12 +96,15 @@ def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims):
     The same tensor with the xla_sharding attribute.
   """
   if dims is None:
-    return xla_sharding.replicate(tensor)
+    return xla_sharding.replicate(tensor, assign_tuple_sharding=True)
   elif np.prod(dims) == 1:
-    return xla_sharding.assign_device(tensor, 0)
+    return xla_sharding.assign_device(tensor, 0, assign_tuple_sharding=True)
   else:
     tile_assignment = np.arange(np.prod(dims)).reshape(dims)
-    return xla_sharding.tile(tensor=tensor, tile_assignment=tile_assignment)
+    return xla_sharding.tile(
+        tensor=tensor,
+        tile_assignment=tile_assignment,
+        assign_tuple_sharding=True)
 
 
 def tag_sharding_attribute_for_dequeued_tensors(dequeues, dims):
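
As a quick reading aid for the branch logic above, the three cases correspond to three shapes of dims; the values below are made-up examples, not part of the change:

import numpy as np

# dims is None        -> xla_sharding.replicate(tensor, assign_tuple_sharding=True)
# dims = [1, 1, 1, 1] -> np.prod(dims) == 1, so the whole tensor is placed on
#                        device 0 via assign_device(tensor, 0, ...)
# dims = [1, 2, 2, 1] -> tiled across a 2x2 grid of 4 devices:
dims = [1, 2, 2, 1]
tile_assignment = np.arange(np.prod(dims)).reshape(dims)
# Device ids laid out over the two partitioned (spatial) dimensions:
print(tile_assignment.reshape(2, 2))
# [[0 1]
#  [2 3]]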