[XLA:SPMD] Add partial sharding API to SPMD and bridge support
PiperOrigin-RevId: 326348623
Change-Id: I32e94f708ff7c13aa17fabac645100f178c2e1be
Parent: 8b03b9681e
Commit: 97b414dd46

Changed files:
  tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py
  tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc
tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py:

@@ -89,6 +89,32 @@ class Sharding(object):
             tile_assignment_dimensions=dims,
             tile_assignment_devices=list(flattened_devices)))
 
+  @classmethod
+  def partial_tile(cls, tile_assignment):
+    """Returns a partially tiled sharding attribute.
+
+    This is similar to tile(), but tile_assignment has one more dimension than
+    the tensor, and tiles in the last dimension of tile_assignment are
+    replicated.
+
+    Args:
+      tile_assignment: An np.ndarray describing the topology of the tiling and
+        which device will compute which part of the topology.
+
+    Raises:
+      TypeError: tile_assignment was not of np.ndarray type.
+    """
+    if not isinstance(tile_assignment, _np.ndarray):
+      raise TypeError('PartialTile assignment must be of type np.ndarray')
+    dims = list(tile_assignment.shape)
+    flattened_devices = tile_assignment.reshape(-1, order='C')
+    return Sharding(
+        proto=xla_data_pb2.OpSharding(
+            type=xla_data_pb2.OpSharding.OTHER,
+            tile_assignment_dimensions=dims,
+            tile_assignment_devices=list(flattened_devices),
+            replicate_on_last_tile_dim=True))
+
   @classmethod
   def split(cls, tensor, split_dimension, num_devices, input_shape=None):
     """Returns a Sharding that splits a tensor across a dimension.
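For reference, a minimal usage sketch of the new classmethod (the import path and the device numbering below are assumptions for illustration, not part of this change):

    import numpy as np
    from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding

    # Devices 0..7 arranged as [2, 2, 2]: the leading [2, 2] tiles a 2D
    # tensor, and the trailing 2 is the per-tile replication factor.
    tile_assignment = np.arange(8).reshape(2, 2, 2)
    sharding = xla_sharding.Sharding.partial_tile(tile_assignment)
    # The resulting OpSharding proto has type=OTHER,
    # tile_assignment_dimensions=[2, 2, 2], tile_assignment_devices=[0..7],
    # and replicate_on_last_tile_dim=True.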
@@ -245,6 +271,23 @@ def split(tensor,
   return tensor
 
 
+def partial_tile(tensor, tile_assignment, use_sharding_op=False):
+  """Returns a tensor that has partially tiled sharding.
+
+  Args:
+    tensor: A tf.Tensor to shard.
+    tile_assignment: An np.ndarray describing the topology of the tiling and
+      which device will compute which part of the topology. It must have one
+      more dimension than tensor, and the last dimension represents partially
+      replicated tiles.
+    use_sharding_op: If true, adds a sharding op to set the sharding.
+  """
+  if use_sharding_op:
+    tensor = tf2xla.sharding(tensor)
+  Sharding.partial_tile(tile_assignment).apply_to_tensor(tensor)
+  return tensor
+
+
 def get_op_sharding(op):
   """Returns sharding attribute of an op.
 
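A hedged usage sketch for the module-level helper; the tensor name, shape, and TF1-style graph setup are illustrative assumptions:

    import numpy as np
    import tensorflow.compat.v1 as tf
    from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding

    # TF1-style graph code (placeholder requires v1 graph mode).
    x = tf.placeholder(tf.float32, shape=[8, 128])
    # Tile x into a 2x2 grid and replicate each of the 4 tiles on 2 of the
    # 8 devices; use_sharding_op attaches a separate sharding op instead of
    # annotating x's producer directly.
    x = xla_sharding.partial_tile(
        x, np.arange(8).reshape(2, 2, 2), use_sharding_op=True)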
@@ -313,20 +356,30 @@ def mesh_split(tensor,
     use_sharding_op: If true, adds a sharding op to set the sharding.
 
   Raises:
-    ValueError: The number of tensor split dimensions is different from device
-      mesh rank.
+    ValueError: The number of tensor split dimensions is larger than device mesh
+      rank.
   """
   permutation = [d for d in tensor_split_dims_mapping if d >= 0]
-  if len(permutation) != len(device_mesh.shape):
+  if len(permutation) > len(device_mesh.shape):
     raise ValueError(
-        'Number of tensor split dimensions (%r) is different from device mesh '
+        'Number of tensor split dimensions (%r) is larger than device mesh '
         'rank (%r). tensor_split_dims_mapping: %r, device_mesh.shape: %r' %
         (len(permutation), len(
             device_mesh.shape), tensor_split_dims_mapping, device_mesh.shape))
-  tile_assignment = _np.transpose(device_mesh, permutation)
+  # Append replicated dimensions to the end.
+  transpose_permutation = permutation + [
+      d for d in range(len(device_mesh.shape)) if d not in permutation
+  ]
+  tile_assignment = _np.transpose(device_mesh, transpose_permutation)
   tile_shape = [
       1 if d < 0 else device_mesh.shape[d] for d in tensor_split_dims_mapping
   ]
+  partial = len(permutation) < len(device_mesh.shape)
+  if partial:
+    tile_shape.append(_np.prod(device_mesh.shape) // _np.prod(tile_shape))
   tile_assignment = _np.reshape(tile_assignment, tile_shape)
 
+  if partial:
+    return partial_tile(
+        tensor, tile_assignment, use_sharding_op=use_sharding_op)
   return tile(tensor, tile_assignment, use_sharding_op=use_sharding_op)
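With this change, mesh_split produces a partial sharding whenever some mesh dimensions are left unmapped. A sketch under assumed names and signature:

    import numpy as np
    import tensorflow.compat.v1 as tf
    from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding

    w = tf.placeholder(tf.float32, shape=[16, 64])
    device_mesh = np.arange(8).reshape(2, 4)
    # Map tensor dim 0 to mesh dim 0; tensor dim 1 is unmapped (-1), so mesh
    # dim 1 (4 devices) becomes the replicated last tile dimension.
    w = xla_sharding.mesh_split(w, device_mesh, [0, -1], use_sharding_op=True)
    # tile_shape starts as [2, 1] and is extended to [2, 1, 4]; partial_tile()
    # is then called, which sets replicate_on_last_tile_dim=True.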
tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc:

@@ -599,8 +599,11 @@ Status GetStepMarkerLocation(const Node& replicate_node,
 // sharding attribute.
 Status GetDimensionIndicesAndNumSplitsFromSharding(
     const xla::OpSharding& sharding, std::map<int, int>* split_dimension_map) {
-  for (int dim_index = 0;
-       dim_index < sharding.tile_assignment_dimensions_size(); dim_index++) {
+  int64 tensor_tile_rank = sharding.tile_assignment_dimensions_size();
+  if (sharding.replicate_on_last_tile_dim()) {
+    tensor_tile_rank--;
+  }
+  for (int dim_index = 0; dim_index < tensor_tile_rank; dim_index++) {
     if (sharding.tile_assignment_dimensions(dim_index) > 1) {
       split_dimension_map->emplace(
           dim_index, sharding.tile_assignment_dimensions(dim_index));
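The same logic in illustrative Python (not part of the commit): when the last tile dimension is replicated, it is a replication factor rather than a tensor dimension, so it is excluded when collecting split dimensions:

    def dimension_splits(tile_assignment_dimensions,
                         replicate_on_last_tile_dim):
      """Maps each split tensor dimension to its number of splits."""
      tensor_tile_rank = len(tile_assignment_dimensions)
      if replicate_on_last_tile_dim:
        # The last dimension is a replication factor, not a tensor dimension.
        tensor_tile_rank -= 1
      return {d: n for d, n in
              enumerate(tile_assignment_dimensions[:tensor_tile_rank])
              if n > 1}

    assert dimension_splits([2, 2, 2], True) == {0: 2, 1: 2}
    assert dimension_splits([2, 2, 2], False) == {0: 2, 1: 2, 2: 2}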
@@ -777,8 +780,9 @@ xla::StatusOr<ShardedInputInfo> CreateOrGetSplitNodesForInputSharding(
   // `split_nodes_for_dimension` now includes final split nodes
   // from which sharded data will be fed into TPUExecute nodes -- sorted by
   // row major order.
-  std::vector<NodeOut> sharded_inputs_list;
-  sharded_inputs_list.reserve(split_nodes_for_dimension.size());
+  std::vector<NodeOut> sharded_inputs_list(
+      sharding.tile_assignment_devices_size());
+  int64 next_core_tile_index = 0;
   while (!split_nodes_for_dimension.empty()) {
     Node* split_node = split_nodes_for_dimension.front();
     split_nodes_for_dimension.pop();
@@ -786,7 +790,14 @@ xla::StatusOr<ShardedInputInfo> CreateOrGetSplitNodesForInputSharding(
     TF_RETURN_IF_ERROR(
         GetNodeAttr(split_node->def(), "num_split", &num_splits));
     for (int out_index = 0; out_index < num_splits; ++out_index) {
-      sharded_inputs_list.emplace_back(NodeOut{split_node, out_index});
+      int64 repeat_count = sharding.replicate_on_last_tile_dim()
+                               ? *sharding.tile_assignment_dimensions().rbegin()
+                               : 1;
+      for (int64 i = 0; i < repeat_count; ++i) {
+        int64 next_core =
+            sharding.tile_assignment_devices(next_core_tile_index++);
+        sharded_inputs_list[next_core] = NodeOut{split_node, out_index};
+      }
     }
   }
 
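Taken together, the two hunks above change sharded_inputs_list from split-output order to core-id order. An illustrative Python model of the mapping (the helper name is hypothetical):

    def map_split_outputs_to_cores(num_outputs, tile_assignment_devices,
                                   repeat_count):
      """Returns, per core id, the split output index that feeds it."""
      sharded_inputs = [None] * len(tile_assignment_devices)
      next_core_tile_index = 0
      for out_index in range(num_outputs):
        # Each split output feeds `repeat_count` consecutive entries of
        # tile_assignment_devices (1 when there is no replication).
        for _ in range(repeat_count):
          core = tile_assignment_devices[next_core_tile_index]
          next_core_tile_index += 1
          sharded_inputs[core] = out_index
      return sharded_inputs

    # 2 split outputs, each replicated on 2 of 4 cores:
    assert map_split_outputs_to_cores(2, [0, 1, 2, 3], 2) == [0, 0, 1, 1]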
@@ -889,19 +900,6 @@ xla::StatusOr<Node*> CreateConcatNodesForRetval(
   return inputs_to_sharded_retval.at(0).node;
 }
 
-absl::optional<int> GetCoreIndexInSharding(const xla::OpSharding& sharding,
-                                           int64 core) {
-  absl::optional<int> output_index;
-  for (int i = 0; i < sharding.tile_assignment_devices_size(); i++) {
-    int64 assigned_core = sharding.tile_assignment_devices(i);
-    if (assigned_core == core) {
-      output_index = i;
-      break;
-    }
-  }
-  return output_index;
-}
-
 // Set the padding ops on the same devices as the original inputs. If the original
 // inputs are on TPUs, the padding ops will be placed on TPUs and XLA on demand
 // mode will be triggered, so we don't need to copy the data back to the host
@@ -2763,14 +2761,8 @@ Status DistributedTPURewritePass::BuildExecuteNodes(
             sharding, orig_arg_num, dtype, replica,
             edge->src_output(), edge->src(), control_predecessor,
            graph, &input_index_to_sharded_inputs));
-
-        // Calculate which output we should receive from the Split node.
-        absl::optional<int> output_index =
-            GetCoreIndexInSharding(sharding, core);
-        TF_RET_CHECK(output_index);
-
         NodeOut split_node_and_index =
-            sharded_input_info.sharded_inputs.at(output_index.value());
+            sharded_input_info.sharded_inputs.at(core);
         // Connect with Split node output.
         graph->AddEdge(split_node_and_index.node,
                        split_node_and_index.index, node, i);
@@ -2850,13 +2842,8 @@ Status DistributedTPURewritePass::BuildExecuteNodes(
             arg_shapes[orig_arg_num].handle_type, replica,
             var_data.index, var_data.node, control_predecessor, graph,
             &input_index_to_sharded_inputs));
-
-        // Calculate which output we should receive from the Split node.
-        absl::optional<int> output_index =
-            GetCoreIndexInSharding(sharding, core);
-        TF_RET_CHECK(output_index);
         NodeOut split_node_and_index =
-            sharded_input_info.sharded_inputs[output_index.value()];
+            sharded_input_info.sharded_inputs[core];
         // Connect with Split node output.
         graph->AddEdge(split_node_and_index.node,
                        split_node_and_index.index, node, i);
@@ -2919,7 +2906,16 @@ Status DistributedTPURewritePass::BuildExecuteNodes(
 
       // Add a Concat node.
       std::vector<NodeOut> orig_inputs;
-      for (int64 core_id : sharding.tile_assignment_devices()) {
+      for (int64 tile_index = 0;
+           tile_index < sharding.tile_assignment_devices_size();
+           ++tile_index) {
+        int64 last_tile_dim_size =
+            *sharding.tile_assignment_dimensions().rbegin();
+        if (sharding.replicate_on_last_tile_dim() &&
+            tile_index % last_tile_dim_size != 0) {
+          continue;
+        }
+        int64 core_id = sharding.tile_assignment_devices(tile_index);
         int core_retval_index =
             retval_index_to_output_index_mapping[retval_index][core_id];
         orig_inputs.push_back(
@@ -2987,7 +2983,16 @@ Status DistributedTPURewritePass::BuildExecuteNodes(
 
       // Add a Concat node.
      std::vector<NodeOut> orig_inputs;
-      for (int64 core_id : sharding.tile_assignment_devices()) {
+      for (int64 tile_index = 0;
+           tile_index < sharding.tile_assignment_devices_size();
+           ++tile_index) {
+        int64 last_tile_dim_size =
+            *sharding.tile_assignment_dimensions().rbegin();
+        if (sharding.replicate_on_last_tile_dim() &&
+            tile_index % last_tile_dim_size != 0) {
+          continue;
+        }
+        int64 core_id = sharding.tile_assignment_devices(tile_index);
         int core_retval_num =
             orig_arg_num_to_output_index_mapping[orig_arg_num][core_id];
         orig_inputs.push_back(
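Both concat hunks apply the same deduplication: with a replicated last tile dimension, each run of last_tile_dim_size consecutive entries in tile_assignment_devices holds copies of one tile, so only the first core of each run feeds the Concat. Sketched in Python (illustrative only; the function name is hypothetical):

    def concat_input_cores(tile_assignment_devices,
                           tile_assignment_dimensions,
                           replicate_on_last_tile_dim):
      """Returns the cores whose outputs are concatenated, one per tile."""
      last_tile_dim_size = tile_assignment_dimensions[-1]
      cores = []
      for tile_index, core_id in enumerate(tile_assignment_devices):
        if (replicate_on_last_tile_dim and
            tile_index % last_tile_dim_size != 0):
          continue  # Skip the extra replicas of this tile.
        cores.append(core_id)
      return cores

    # 2x2 tiles, each replicated on 2 cores: cores 0, 2, 4, 6 contribute.
    assert concat_input_cores(list(range(8)), [2, 2, 2], True) == [0, 2, 4, 6]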