diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py
index 16563bab5bc..a926e8b3c88 100644
--- a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py
+++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py
@@ -89,6 +89,32 @@ class Sharding(object):
             tile_assignment_dimensions=dims,
             tile_assignment_devices=list(flattened_devices)))

+  @classmethod
+  def partial_tile(cls, tile_assignment):
+    """Returns a partially tiled sharding attribute.
+
+    This is similar to tile(), but tile_assignment has one more dimension than
+    the tensor, and tiles in the last dimension of tile_assignment are
+    replicated.
+
+    Args:
+      tile_assignment: An np.ndarray describing the topology of the tiling and
+        which device will compute which part of the topology.
+
+    Raises:
+      TypeError: tile_assignment was not of np.array type.
+    """
+    if not isinstance(tile_assignment, _np.ndarray):
+      raise TypeError('PartialTile assignment must be of type np.ndarray')
+    dims = list(tile_assignment.shape)
+    flattened_devices = tile_assignment.reshape(-1, order='C')
+    return Sharding(
+        proto=xla_data_pb2.OpSharding(
+            type=xla_data_pb2.OpSharding.OTHER,
+            tile_assignment_dimensions=dims,
+            tile_assignment_devices=list(flattened_devices),
+            replicate_on_last_tile_dim=True))
+
   @classmethod
   def split(cls, tensor, split_dimension, num_devices, input_shape=None):
     """Returns a Sharding that splits a tensor across a dimension.
@@ -245,6 +271,23 @@ def split(tensor,
   return tensor


+def partial_tile(tensor, tile_assignment, use_sharding_op=False):
+  """Returns a tensor that has tiled sharding.
+
+  Args:
+    tensor: A tf.Tensor to shard.
+    tile_assignment: An np.ndarray describing the topology of the tiling and
+      which device will compute which part of the topology. It must have one
+      more dimension than tensor, and the last dimension represents partially
+      replicated tiles.
+    use_sharding_op: If true, adds a sharding op to set the sharding.
+  """
+  if use_sharding_op:
+    tensor = tf2xla.sharding(tensor)
+  Sharding.partial_tile(tile_assignment).apply_to_tensor(tensor)
+  return tensor
+
+
 def get_op_sharding(op):
   """Returns sharding attribute of an op.

@@ -313,20 +356,30 @@ def mesh_split(tensor,
     use_sharding_op: If true, adds a sharding op to set the sharding.

   Raises:
-    ValueError: The number of tensor split dimensions is different from device
-      mesh rank.
+    ValueError: The number of tensor split dimensions is larger than device mesh
+      rank.
   """
   permutation = [d for d in tensor_split_dims_mapping if d >= 0]
-  if len(permutation) != len(device_mesh.shape):
+  if len(permutation) > len(device_mesh.shape):
     raise ValueError(
-        'Number of tensor split dimensions (%r) is different from device mesh '
+        'Number of tensor split dimensions (%r) is larger than device mesh '
         'rank (%r). tensor_split_dims_mapping: %r, device_mesh.shape: %r' %
        (len(permutation), len(
            device_mesh.shape), tensor_split_dims_mapping, device_mesh.shape))
-  tile_assignment = _np.transpose(device_mesh, permutation)
+  # Append replicated dimensions to the end.
+  transpose_permutation = permutation + [
+      d for d in range(len(device_mesh.shape)) if d not in permutation
+  ]
+  tile_assignment = _np.transpose(device_mesh, transpose_permutation)
   tile_shape = [
       1 if d < 0 else device_mesh.shape[d] for d in tensor_split_dims_mapping
   ]
+  partial = len(permutation) < len(device_mesh.shape)
+  if partial:
+    tile_shape.append(_np.prod(device_mesh.shape) // _np.prod(tile_shape))
   tile_assignment = _np.reshape(tile_assignment, tile_shape)
+  if partial:
+    return partial_tile(
+        tensor, tile_assignment, use_sharding_op=use_sharding_op)
   return tile(tensor, tile_assignment, use_sharding_op=use_sharding_op)
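
Note: the following is an illustrative sketch, not part of the patch. It traces what the new mesh_split path computes for a partially replicated sharding; the device mesh, the tensor_split_dims_mapping value, and the values in the comments are assumptions chosen for the example (in user code the same result would come from xla_sharding.mesh_split(tensor, device_mesh, [0, -1])).

import numpy as np

device_mesh = np.arange(8).reshape((2, 4))  # cores 0..7 arranged as a 2x4 mesh
tensor_split_dims_mapping = [0, -1]         # split tensor dim 0 over mesh dim 0;
                                            # tensor dim 1 is not split

permutation = [d for d in tensor_split_dims_mapping if d >= 0]  # [0]
transpose_permutation = permutation + [
    d for d in range(len(device_mesh.shape)) if d not in permutation
]                                                               # [0, 1]
tile_assignment = np.transpose(device_mesh, transpose_permutation)
tile_shape = [
    1 if d < 0 else device_mesh.shape[d] for d in tensor_split_dims_mapping
]                                                               # [2, 1]
partial = len(permutation) < len(device_mesh.shape)             # True
if partial:
  # Remaining mesh devices become the replicated last dimension: 8 // 2 == 4.
  tile_shape.append(np.prod(device_mesh.shape) // np.prod(tile_shape))
tile_assignment = np.reshape(tile_assignment, tile_shape)

print(tile_assignment.shape)                   # (2, 1, 4)
print(tile_assignment.reshape(-1, order='C'))  # [0 1 2 3 4 5 6 7]

Sharding.partial_tile(tile_assignment) then emits OpSharding(type=OTHER, tile_assignment_dimensions=[2, 1, 4], tile_assignment_devices=[0, ..., 7], replicate_on_last_tile_dim=True): cores 0-3 hold the first half of tensor dim 0 and cores 4-7 hold the second half, with each half replicated across its group of four cores.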
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
index 73510319b0a..882947c1c65 100644
--- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
@@ -599,8 +599,11 @@ Status GetStepMarkerLocation(const Node& replicate_node,
 // sharding attribute.
 Status GetDimensionIndicesAndNumSplitsFromSharding(
     const xla::OpSharding& sharding, std::map<int, int>* split_dimension_map) {
-  for (int dim_index = 0;
-       dim_index < sharding.tile_assignment_dimensions_size(); dim_index++) {
+  int64 tensor_tile_rank = sharding.tile_assignment_dimensions_size();
+  if (sharding.replicate_on_last_tile_dim()) {
+    tensor_tile_rank--;
+  }
+  for (int dim_index = 0; dim_index < tensor_tile_rank; dim_index++) {
     if (sharding.tile_assignment_dimensions(dim_index) > 1) {
       split_dimension_map->emplace(
           dim_index, sharding.tile_assignment_dimensions(dim_index));
@@ -777,8 +780,9 @@ xla::StatusOr<ShardedInputInfo> CreateOrGetSplitNodesForInputSharding(
   // `split_nodes_for_dimension` now includes final split nodes
   // from which sharded data will be fed into TPUExcute nodes -- sorted by
   // row major order.
-  std::vector<NodeOut> sharded_inputs_list;
-  sharded_inputs_list.reserve(split_nodes_for_dimension.size());
+  std::vector<NodeOut> sharded_inputs_list(
+      sharding.tile_assignment_devices_size());
+  int64 next_core_tile_index = 0;
   while (!split_nodes_for_dimension.empty()) {
     Node* split_node = split_nodes_for_dimension.front();
     split_nodes_for_dimension.pop();
@@ -786,7 +790,14 @@ xla::StatusOr<ShardedInputInfo> CreateOrGetSplitNodesForInputSharding(
     TF_RETURN_IF_ERROR(
         GetNodeAttr(split_node->def(), "num_split", &num_splits));
     for (int out_index = 0; out_index < num_splits; ++out_index) {
-      sharded_inputs_list.emplace_back(NodeOut{split_node, out_index});
+      int64 repeat_count = sharding.replicate_on_last_tile_dim()
+                               ? *sharding.tile_assignment_dimensions().rbegin()
+                               : 1;
+      for (int64 i = 0; i < repeat_count; ++i) {
+        int64 next_core =
+            sharding.tile_assignment_devices(next_core_tile_index++);
+        sharded_inputs_list[next_core] = NodeOut{split_node, out_index};
+      }
     }
   }

@@ -889,19 +900,6 @@ xla::StatusOr<Node*> CreateConcatNodesForRetval(
   return inputs_to_sharded_retval.at(0).node;
 }

-absl::optional<int64> GetCoreIndexInSharding(const xla::OpSharding& sharding,
-                                             int64 core) {
-  absl::optional<int64> output_index;
-  for (int i = 0; i < sharding.tile_assignment_devices_size(); i++) {
-    int64 assigned_core = sharding.tile_assignment_devices(i);
-    if (assigned_core == core) {
-      output_index = i;
-      break;
-    }
-  }
-  return output_index;
-}
-
 // Set the padding ops the same devices as the original inputs. If the original
 // inputs are on TPUs, the padding ops will be placed on TPUs and XLA on demand
 // mode will be triggered, so we don't need to copy the data back to the host
@@ -2763,14 +2761,8 @@ Status DistributedTPURewritePass::BuildExecuteNodes(
                 sharding, orig_arg_num, dtype, replica, edge->src_output(),
                 edge->src(), control_predecessor, graph,
                 &input_index_to_sharded_inputs));
-
-            // Calculate which output we should receive from the Split node.
-            absl::optional<int64> output_index =
-                GetCoreIndexInSharding(sharding, core);
-            TF_RET_CHECK(output_index);
-
             NodeOut split_node_and_index =
-                sharded_input_info.sharded_inputs.at(output_index.value());
+                sharded_input_info.sharded_inputs.at(core);
             // Connect with Split node output.
             graph->AddEdge(split_node_and_index.node, split_node_and_index.index,
                            node, i);
@@ -2850,13 +2842,8 @@ Status DistributedTPURewritePass::BuildExecuteNodes(
                     arg_shapes[orig_arg_num].handle_type, replica,
                     var_data.index, var_data.node, control_predecessor, graph,
                     &input_index_to_sharded_inputs));
-
-            // Calculate which output we should receive from the Split node.
-            absl::optional<int64> output_index =
-                GetCoreIndexInSharding(sharding, core);
-            TF_RET_CHECK(output_index);
             NodeOut split_node_and_index =
-                sharded_input_info.sharded_inputs[output_index.value()];
+                sharded_input_info.sharded_inputs[core];
             // Connect with Split node output.
             graph->AddEdge(split_node_and_index.node, split_node_and_index.index,
                            node, i);
@@ -2919,7 +2906,16 @@ Status DistributedTPURewritePass::BuildExecuteNodes(
           // Add a Concat node.
           std::vector<NodeOut> orig_inputs;
-          for (int64 core_id : sharding.tile_assignment_devices()) {
+          for (int64 tile_index = 0;
+               tile_index < sharding.tile_assignment_devices_size();
+               ++tile_index) {
+            int64 last_tile_dim_size =
+                *sharding.tile_assignment_dimensions().rbegin();
+            if (sharding.replicate_on_last_tile_dim() &&
+                tile_index % last_tile_dim_size != 0) {
+              continue;
+            }
+            int64 core_id = sharding.tile_assignment_devices(tile_index);
             int core_retval_index =
                 retval_index_to_output_index_mapping[retval_index][core_id];
             orig_inputs.push_back(
@@ -2987,7 +2983,16 @@ Status DistributedTPURewritePass::BuildExecuteNodes(
           // Add a Concat node.
           std::vector<NodeOut> orig_inputs;
-          for (int64 core_id : sharding.tile_assignment_devices()) {
+          for (int64 tile_index = 0;
+               tile_index < sharding.tile_assignment_devices_size();
+               ++tile_index) {
+            int64 last_tile_dim_size =
+                *sharding.tile_assignment_dimensions().rbegin();
+            if (sharding.replicate_on_last_tile_dim() &&
+                tile_index % last_tile_dim_size != 0) {
+              continue;
+            }
+            int64 core_id = sharding.tile_assignment_devices(tile_index);
             int core_retval_num =
                 orig_arg_num_to_output_index_mapping[orig_arg_num][core_id];
             orig_inputs.push_back(
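
Note: the following is a minimal Python sketch, not code from this patch, that mirrors the C++ changes above under the assumption of the [2, 1, 4] sharding from the earlier example. It shows how the rewritten CreateOrGetSplitNodesForInputSharding indexes sharded_inputs_list directly by core id, repeating each split output across the replicated last tile dimension, and why the Concat path keeps only one tile index per replicated group (making GetCoreIndexInSharding unnecessary). Names and values are illustrative.

tile_assignment_dimensions = [2, 1, 4]
tile_assignment_devices = list(range(8))
replicate_on_last_tile_dim = True

last_tile_dim_size = tile_assignment_dimensions[-1]
repeat_count = last_tile_dim_size if replicate_on_last_tile_dim else 1

# Input side: sharded_inputs_list is sized by device count and indexed by core
# id; each split output is recorded for every core in its replicated group.
sharded_inputs_list = [None] * len(tile_assignment_devices)
next_core_tile_index = 0
for out_index in range(2):                     # the final Split node has two outputs
  for _ in range(repeat_count):
    core = tile_assignment_devices[next_core_tile_index]
    next_core_tile_index += 1
    sharded_inputs_list[core] = ('split_output', out_index)
# cores 0..3 all read split output 0; cores 4..7 all read split output 1.

# Output side: only the first tile index of each replicated group feeds the
# Concat that reassembles the full result.
concat_cores = [
    tile_assignment_devices[i]
    for i in range(len(tile_assignment_devices))
    if not replicate_on_last_tile_dim or i % last_tile_dim_size == 0
]
print(sharded_inputs_list)  # [('split_output', 0)] * 4 + [('split_output', 1)] * 4
print(concat_cores)         # [0, 4]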