[XLA:SPMD] Add partial sharding API to SPMD and bridge support

PiperOrigin-RevId: 326348623
Change-Id: I32e94f708ff7c13aa17fabac645100f178c2e1be
Yuanzhong Xu 2020-08-12 17:28:50 -07:00 committed by TensorFlower Gardener
parent 8b03b9681e
commit 97b414dd46
2 changed files with 96 additions and 38 deletions
tensorflow/compiler/xla/experimental/xla_sharding
tensorflow/core/tpu/graph_rewrite

tensorflow/compiler/xla/experimental/xla_sharding

@@ -89,6 +89,32 @@ class Sharding(object):
             tile_assignment_dimensions=dims,
             tile_assignment_devices=list(flattened_devices)))
 
+  @classmethod
+  def partial_tile(cls, tile_assignment):
+    """Returns a partially tiled sharding attribute.
+
+    This is similar to tile(), but tile_assignment has one more dimension than
+    the tensor, and tiles in the last dimension of tile_assignment are
+    replicated.
+
+    Args:
+      tile_assignment: An np.ndarray describing the topology of the tiling and
+        which device will compute which part of the topology.
+
+    Raises:
+      TypeError: tile_assignment was not of np.array type.
+    """
+    if not isinstance(tile_assignment, _np.ndarray):
+      raise TypeError('PartialTile assignment must be of type np.ndarray')
+    dims = list(tile_assignment.shape)
+    flattened_devices = tile_assignment.reshape(-1, order='C')
+    return Sharding(
+        proto=xla_data_pb2.OpSharding(
+            type=xla_data_pb2.OpSharding.OTHER,
+            tile_assignment_dimensions=dims,
+            tile_assignment_devices=list(flattened_devices),
+            replicate_on_last_tile_dim=True))
+
   @classmethod
   def split(cls, tensor, split_dimension, num_devices, input_shape=None):
     """Returns a Sharding that splits a tensor across a dimension.
@@ -245,6 +271,23 @@ def split(tensor,
   return tensor
 
 
+def partial_tile(tensor, tile_assignment, use_sharding_op=False):
+  """Returns a tensor that has tiled sharding.
+
+  Args:
+    tensor: A tf.Tensor to shard.
+    tile_assignment: An np.ndarray describing the topology of the tiling and
+      which device will compute which part of the topology. It must have one
+      more dimension than tensor, and the last dimension represents partially
+      replicated tiles.
+    use_sharding_op: If true, adds a sharding op to set the sharding.
+  """
+  if use_sharding_op:
+    tensor = tf2xla.sharding(tensor)
+  Sharding.partial_tile(tile_assignment).apply_to_tensor(tensor)
+  return tensor
+
+
 def get_op_sharding(op):
   """Returns sharding attribute of an op.
@@ -313,20 +356,30 @@ def mesh_split(tensor,
     use_sharding_op: If true, adds a sharding op to set the sharding.
 
   Raises:
-    ValueError: The number of tensor split dimensions is different from device
-      mesh rank.
+    ValueError: The number of tensor split dimensions is larger than device mesh
+      rank.
   """
   permutation = [d for d in tensor_split_dims_mapping if d >= 0]
-  if len(permutation) != len(device_mesh.shape):
+  if len(permutation) > len(device_mesh.shape):
     raise ValueError(
-        'Number of tensor split dimensions (%r) is different from device mesh '
+        'Number of tensor split dimensions (%r) is larger than device mesh '
         'rank (%r). tensor_split_dims_mapping: %r, device_mesh.shape: %r' %
         (len(permutation), len(
             device_mesh.shape), tensor_split_dims_mapping, device_mesh.shape))
-  tile_assignment = _np.transpose(device_mesh, permutation)
+  # Append replicated dimensions to the end.
+  transpose_permutation = permutation + [
+      d for d in range(len(device_mesh.shape)) if d not in permutation
+  ]
+  tile_assignment = _np.transpose(device_mesh, transpose_permutation)
   tile_shape = [
       1 if d < 0 else device_mesh.shape[d] for d in tensor_split_dims_mapping
   ]
+  partial = len(permutation) < len(device_mesh.shape)
+  if partial:
+    tile_shape.append(_np.prod(device_mesh.shape) // _np.prod(tile_shape))
   tile_assignment = _np.reshape(tile_assignment, tile_shape)
+  if partial:
+    return partial_tile(
+        tensor, tile_assignment, use_sharding_op=use_sharding_op)
   return tile(tensor, tile_assignment, use_sharding_op=use_sharding_op)
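
To see what the new partial branch produces, here is a small worked trace of the same numpy steps. The 2x4 device mesh and the dims mapping are invented for the example.

  import numpy as np

  device_mesh = np.arange(8).reshape((2, 4))   # 8 cores arranged as a 2x4 mesh
  tensor_split_dims_mapping = [-1, 1]          # dim 0 replicated, dim 1 on mesh axis 1

  permutation = [d for d in tensor_split_dims_mapping if d >= 0]          # [1]
  transpose_permutation = permutation + [
      d for d in range(len(device_mesh.shape)) if d not in permutation]   # [1, 0]
  tile_assignment = np.transpose(device_mesh, transpose_permutation)      # shape (4, 2)
  tile_shape = [1 if d < 0 else device_mesh.shape[d]
                for d in tensor_split_dims_mapping]                       # [1, 4]
  tile_shape.append(device_mesh.size // np.prod(tile_shape))              # partial -> [1, 4, 2]
  tile_assignment = np.reshape(tile_assignment, tile_shape)
  # tile_assignment == [[[0, 4], [1, 5], [2, 6], [3, 7]]]: each of the four
  # column blocks is jointly held (replicated) by two cores.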

tensorflow/core/tpu/graph_rewrite

@@ -599,8 +599,11 @@ Status GetStepMarkerLocation(const Node& replicate_node,
 // sharding attribute.
 Status GetDimensionIndicesAndNumSplitsFromSharding(
     const xla::OpSharding& sharding, std::map<int, int>* split_dimension_map) {
-  for (int dim_index = 0;
-       dim_index < sharding.tile_assignment_dimensions_size(); dim_index++) {
+  int64 tensor_tile_rank = sharding.tile_assignment_dimensions_size();
+  if (sharding.replicate_on_last_tile_dim()) {
+    tensor_tile_rank--;
+  }
+  for (int dim_index = 0; dim_index < tensor_tile_rank; dim_index++) {
     if (sharding.tile_assignment_dimensions(dim_index) > 1) {
       split_dimension_map->emplace(
           dim_index, sharding.tile_assignment_dimensions(dim_index));
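
In Python terms, the adjusted loop amounts to the sketch below (a paraphrase of the C++ above, not part of the commit): the trailing replication dimension never contributes a split dimension.

  def split_dims(tile_assignment_dimensions, replicate_on_last_tile_dim):
    """Maps tensor dim -> num_splits, ignoring the trailing replication dim."""
    rank = len(tile_assignment_dimensions)
    if replicate_on_last_tile_dim:
      rank -= 1
    return {d: n for d, n in enumerate(tile_assignment_dimensions[:rank]) if n > 1}

  # e.g. tile dims [2, 4, 2] with partial replication -> split dim 0 by 2, dim 1 by 4
  assert split_dims([2, 4, 2], True) == {0: 2, 1: 4}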
@@ -777,8 +780,9 @@ xla::StatusOr<ShardedInputInfo> CreateOrGetSplitNodesForInputSharding(
   // `split_nodes_for_dimension` now includes final split nodes
   // from which sharded data will be fed into TPUExcute nodes -- sorted by
   // row major order.
-  std::vector<NodeOut> sharded_inputs_list;
-  sharded_inputs_list.reserve(split_nodes_for_dimension.size());
+  std::vector<NodeOut> sharded_inputs_list(
+      sharding.tile_assignment_devices_size());
+  int64 next_core_tile_index = 0;
   while (!split_nodes_for_dimension.empty()) {
     Node* split_node = split_nodes_for_dimension.front();
     split_nodes_for_dimension.pop();
@@ -786,7 +790,14 @@ xla::StatusOr<ShardedInputInfo> CreateOrGetSplitNodesForInputSharding(
     TF_RETURN_IF_ERROR(
         GetNodeAttr(split_node->def(), "num_split", &num_splits));
     for (int out_index = 0; out_index < num_splits; ++out_index) {
-      sharded_inputs_list.emplace_back(NodeOut{split_node, out_index});
+      int64 repeat_count = sharding.replicate_on_last_tile_dim()
+                               ? *sharding.tile_assignment_dimensions().rbegin()
+                               : 1;
+      for (int64 i = 0; i < repeat_count; ++i) {
+        int64 next_core =
+            sharding.tile_assignment_devices(next_core_tile_index++);
+        sharded_inputs_list[next_core] = NodeOut{split_node, out_index};
+      }
     }
   }
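
The effect of the new indexing can be sketched in Python (illustrative values, not part of the commit): each split output now fans out to as many cores as the last tile dimension, consuming tile_assignment_devices in row-major tile order, and the resulting list is indexed directly by core id.

  tile_dims = [1, 4, 2]                  # last dim 2: each data tile feeds 2 cores
  devices = [0, 4, 1, 5, 2, 6, 3, 7]     # tile_assignment_devices, row-major
  repeat = tile_dims[-1]

  next_core_tile_index = 0
  input_for_core = {}
  for out_index in range(len(devices) // repeat):   # 4 distinct split outputs
    for _ in range(repeat):
      core = devices[next_core_tile_index]
      next_core_tile_index += 1
      input_for_core[core] = out_index
  # input_for_core == {0: 0, 4: 0, 1: 1, 5: 1, 2: 2, 6: 2, 3: 3, 7: 3}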
@@ -889,19 +900,6 @@ xla::StatusOr<Node*> CreateConcatNodesForRetval(
   return inputs_to_sharded_retval.at(0).node;
 }
 
-absl::optional<int> GetCoreIndexInSharding(const xla::OpSharding& sharding,
-                                           int64 core) {
-  absl::optional<int> output_index;
-  for (int i = 0; i < sharding.tile_assignment_devices_size(); i++) {
-    int64 assigned_core = sharding.tile_assignment_devices(i);
-    if (assigned_core == core) {
-      output_index = i;
-      break;
-    }
-  }
-  return output_index;
-}
-
 // Set the padding ops the same devices as the original inputs. If the original
 // inputs are on TPUs, the padding ops will be placed on TPUs and XLA on demand
 // mode will be triggered, so we don't need to copy the data back to the host
@@ -2763,14 +2761,8 @@ Status DistributedTPURewritePass::BuildExecuteNodes(
                     sharding, orig_arg_num, dtype, replica,
                     edge->src_output(), edge->src(), control_predecessor,
                     graph, &input_index_to_sharded_inputs));
-            // Calculate which output we should receive from the Split node.
-            absl::optional<int> output_index =
-                GetCoreIndexInSharding(sharding, core);
-            TF_RET_CHECK(output_index);
             NodeOut split_node_and_index =
-                sharded_input_info.sharded_inputs.at(output_index.value());
+                sharded_input_info.sharded_inputs.at(core);
             // Connect with Split node output.
             graph->AddEdge(split_node_and_index.node,
                            split_node_and_index.index, node, i);
@@ -2850,13 +2842,8 @@ Status DistributedTPURewritePass::BuildExecuteNodes(
                     arg_shapes[orig_arg_num].handle_type, replica,
                     var_data.index, var_data.node, control_predecessor, graph,
                     &input_index_to_sharded_inputs));
-            // Calculate which output we should receive from the Split node.
-            absl::optional<int> output_index =
-                GetCoreIndexInSharding(sharding, core);
-            TF_RET_CHECK(output_index);
             NodeOut split_node_and_index =
-                sharded_input_info.sharded_inputs[output_index.value()];
+                sharded_input_info.sharded_inputs[core];
             // Connect with Split node output.
             graph->AddEdge(split_node_and_index.node,
                            split_node_and_index.index, node, i);
@@ -2919,7 +2906,16 @@ Status DistributedTPURewritePass::BuildExecuteNodes(
           // Add a Concat node.
           std::vector<NodeOut> orig_inputs;
-          for (int64 core_id : sharding.tile_assignment_devices()) {
+          for (int64 tile_index = 0;
+               tile_index < sharding.tile_assignment_devices_size();
+               ++tile_index) {
+            int64 last_tile_dim_size =
+                *sharding.tile_assignment_dimensions().rbegin();
+            if (sharding.replicate_on_last_tile_dim() &&
+                tile_index % last_tile_dim_size != 0) {
+              continue;
+            }
+            int64 core_id = sharding.tile_assignment_devices(tile_index);
             int core_retval_index =
                 retval_index_to_output_index_mapping[retval_index][core_id];
             orig_inputs.push_back(
@@ -2987,7 +2983,16 @@ Status DistributedTPURewritePass::BuildExecuteNodes(
           // Add a Concat node.
           std::vector<NodeOut> orig_inputs;
-          for (int64 core_id : sharding.tile_assignment_devices()) {
+          for (int64 tile_index = 0;
+               tile_index < sharding.tile_assignment_devices_size();
+               ++tile_index) {
+            int64 last_tile_dim_size =
+                *sharding.tile_assignment_dimensions().rbegin();
+            if (sharding.replicate_on_last_tile_dim() &&
+                tile_index % last_tile_dim_size != 0) {
+              continue;
+            }
+            int64 core_id = sharding.tile_assignment_devices(tile_index);
             int core_retval_num =
                 orig_arg_num_to_output_index_mapping[orig_arg_num][core_id];
             orig_inputs.push_back(
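
For the return-value path, the new loop can be paraphrased as the Python sketch below (illustrative values, not part of the commit): with partial replication, only one replica per data tile feeds the Concat node.

  tile_dims = [1, 4, 2]                  # last dim is the replication dimension
  devices = [0, 4, 1, 5, 2, 6, 3, 7]     # tile_assignment_devices
  last = tile_dims[-1]

  concat_sources = [core for tile_index, core in enumerate(devices)
                    if tile_index % last == 0]
  # concat_sources == [0, 1, 2, 3]: the first replica of each tile is concatenated.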