diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py
index c34e84efc80..7c458844a93 100644
--- a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py
+++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py
@@ -120,9 +120,14 @@ class Sharding(object):
             tile_assignment_dimensions=tile_assignment_dims,
             tile_assignment_devices=range(num_devices)))
 
-  def apply_to_tensor(self, tensor):
-    """Applies this Sharding attribute to `tensor`."""
-    if len(tensor.op.outputs) > 1:
+  def apply_to_tensor(self, tensor, assign_tuple_sharding=False):
+    """Applies this Sharding attribute to `tensor`.
+
+    Args:
+      tensor: A tf.Tensor to split.
+      assign_tuple_sharding: If the sharding type should be a tuple.
+    """
+    if len(tensor.op.outputs) > 1 or assign_tuple_sharding:
       proto = self._get_or_create_tuple_proto(tensor.op)
       # We can't mutate an element of old_proto.tuple_shardings, so create
       # a new proto.
@@ -166,21 +171,30 @@ class Sharding(object):
 #   tensor = xla_sharding.replicate(tensor)
 
 
-def replicate(tensor):
-  Sharding.replicate().apply_to_tensor(tensor)
+def replicate(tensor, assign_tuple_sharding=False):
+  Sharding.replicate().apply_to_tensor(
+      tensor,
+      assign_tuple_sharding=assign_tuple_sharding)
   return tensor
 
 
-def assign_device(tensor, device):
-  Sharding.assign_device(device).apply_to_tensor(tensor)
+def assign_device(tensor, device, assign_tuple_sharding=False):
+  Sharding.assign_device(device).apply_to_tensor(
+      tensor,
+      assign_tuple_sharding=assign_tuple_sharding)
   return tensor
 
 
-def tile(tensor, tile_assignment):
-  Sharding.tile(tile_assignment).apply_to_tensor(tensor)
+def tile(tensor, tile_assignment, assign_tuple_sharding=False):
+  Sharding.tile(tile_assignment).apply_to_tensor(
+      tensor,
+      assign_tuple_sharding=assign_tuple_sharding
+  )
   return tensor
 
 
-def split(tensor, split_dimension, num_devices):
-  Sharding.split(tensor, split_dimension, num_devices).apply_to_tensor(tensor)
+def split(tensor, split_dimension, num_devices, assign_tuple_sharding=False):
+  Sharding.split(tensor, split_dimension, num_devices).apply_to_tensor(
+      tensor,
+      assign_tuple_sharding=assign_tuple_sharding)
   return tensor
diff --git a/tensorflow/python/tpu/tpu_feed.py b/tensorflow/python/tpu/tpu_feed.py
index de1adc80e60..159131c2bc1 100644
--- a/tensorflow/python/tpu/tpu_feed.py
+++ b/tensorflow/python/tpu/tpu_feed.py
@@ -86,6 +86,8 @@ def partition_or_replicate_on_host(tensor, dims):
 def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims):
   """Tags appropriate XLA sharding attribute to the dequeued tensor.
 
+  The sharding attribute of the dequeued tensor will be a tuple.
+
   Args:
     tensor: The dequeued tensor on TPU.
     dims: A list of integer describes how the tensor is partitioned.
@@ -94,12 +96,15 @@ def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims):
     The same tensor with the xla_sharding attribute.
   """
   if dims is None:
-    return xla_sharding.replicate(tensor)
+    return xla_sharding.replicate(tensor, assign_tuple_sharding=True)
   elif np.prod(dims) == 1:
-    return xla_sharding.assign_device(tensor, 0)
+    return xla_sharding.assign_device(tensor, 0, assign_tuple_sharding=True)
   else:
     tile_assignment = np.arange(np.prod(dims)).reshape(dims)
-    return xla_sharding.tile(tensor=tensor, tile_assignment=tile_assignment)
+    return xla_sharding.tile(
+        tensor=tensor,
+        tile_assignment=tile_assignment,
+        assign_tuple_sharding=True)
 
 
 def tag_sharding_attribute_for_dequeued_tensors(dequeues, dims):
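
Usage sketch: the new assign_tuple_sharding flag forces a tuple-typed OpSharding even when the annotated op has a single output, which is what the tpu_feed change relies on for dequeued tensors. The Python snippet below is a minimal illustration under assumed conditions (a rank-2 graph-mode tensor split across two logical TPU cores); the helper name shard_dequeued_tensor, the placeholder shape, and the 2x1 tile assignment are illustrative and not taken from either file.

# Minimal sketch, assuming the patched xla_sharding module is on the path.
import numpy as np
import tensorflow as tf

from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding


def shard_dequeued_tensor(t):
  # Split dimension 0 of `t` across two logical cores. Passing
  # assign_tuple_sharding=True makes the attached OpSharding a tuple even
  # though `t.op` has a single output, mirroring how
  # _tag_sharding_attribute_for_dequeued_tensor tags dequeued tensors.
  tile_assignment = np.arange(2).reshape([2, 1])
  return xla_sharding.tile(
      tensor=t,
      tile_assignment=tile_assignment,
      assign_tuple_sharding=True)


with tf.Graph().as_default():
  t = tf.compat.v1.placeholder(tf.float32, shape=[8, 128])
  t = shard_dequeued_tensor(t)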