[XLA:SPMD] Add partial sharding API to SPMD and bridge support
PiperOrigin-RevId: 326348623
Change-Id: I32e94f708ff7c13aa17fabac645100f178c2e1be
This commit is contained in: parent 8b03b9681e, commit 97b414dd46

Directories touched:
  tensorflow/compiler/xla/experimental/xla_sharding
  tensorflow/core/tpu/graph_rewrite
In tensorflow/compiler/xla/experimental/xla_sharding:

@@ -89,6 +89,32 @@ class Sharding(object):
             tile_assignment_dimensions=dims,
             tile_assignment_devices=list(flattened_devices)))
 
+  @classmethod
+  def partial_tile(cls, tile_assignment):
+    """Returns a partially tiled sharding attribute.
+
+    This is similar to tile(), but tile_assignment has one more dimension than
+    the tensor, and tiles in the last dimension of tile_assignment are
+    replicated.
+
+    Args:
+      tile_assignment: An np.ndarray describing the topology of the tiling and
+        which device will compute which part of the topology.
+
+    Raises:
+      TypeError: tile_assignment was not of np.array type.
+    """
+    if not isinstance(tile_assignment, _np.ndarray):
+      raise TypeError('PartialTile assignment must be of type np.ndarray')
+    dims = list(tile_assignment.shape)
+    flattened_devices = tile_assignment.reshape(-1, order='C')
+    return Sharding(
+        proto=xla_data_pb2.OpSharding(
+            type=xla_data_pb2.OpSharding.OTHER,
+            tile_assignment_dimensions=dims,
+            tile_assignment_devices=list(flattened_devices),
+            replicate_on_last_tile_dim=True))
+
   @classmethod
   def split(cls, tensor, split_dimension, num_devices, input_shape=None):
     """Returns a Sharding that splits a tensor across a dimension.
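Usage sketch for the new classmethod (not part of the diff; the 8-device 2x2x2 assignment and the source-tree import path are illustrative assumptions):

  import numpy as np
  from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding

  # 8 assumed devices: a 2x2 tiling where each tile is replicated on 2
  # devices. The last dimension of the assignment is the replication one.
  assignment = np.arange(8).reshape(2, 2, 2)
  sharding = xla_sharding.Sharding.partial_tile(assignment)
  # The resulting OpSharding proto records tile_assignment_dimensions
  # [2, 2, 2] with replicate_on_last_tile_dim=True.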
@@ -245,6 +271,23 @@ def split(tensor,
   return tensor
 
 
+def partial_tile(tensor, tile_assignment, use_sharding_op=False):
+  """Returns a tensor that has tiled sharding.
+
+  Args:
+    tensor: A tf.Tensor to shard.
+    tile_assignment: An np.ndarray describing the topology of the tiling and
+      which device will compute which part of the topology. It must have one
+      more dimension than tensor, and the last dimension represents partially
+      replicated tiles.
+    use_sharding_op: If true, adds a sharding op to set the sharding.
+  """
+  if use_sharding_op:
+    tensor = tf2xla.sharding(tensor)
+  Sharding.partial_tile(tile_assignment).apply_to_tensor(tensor)
+  return tensor
+
+
 def get_op_sharding(op):
   """Returns sharding attribute of an op.
 
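The module-level wrapper in use, as a hedged sketch (the tensor rank and surrounding TPU computation are assumed; sharding annotations only take effect under XLA compilation):

  import numpy as np
  from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding

  def computation(x):  # x: a rank-2 tf.Tensor inside a TPU computation
    assignment = np.arange(8).reshape(2, 2, 2)
    return xla_sharding.partial_tile(x, assignment, use_sharding_op=True)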
@@ -313,20 +356,30 @@ def mesh_split(tensor,
     use_sharding_op: If true, adds a sharding op to set the sharding.
 
   Raises:
-    ValueError: The number of tensor split dimensions is different from device
-      mesh rank.
+    ValueError: The number of tensor split dimensions is larger than device mesh
+      rank.
   """
   permutation = [d for d in tensor_split_dims_mapping if d >= 0]
-  if len(permutation) != len(device_mesh.shape):
+  if len(permutation) > len(device_mesh.shape):
     raise ValueError(
-        'Number of tensor split dimensions (%r) is different from device mesh '
+        'Number of tensor split dimensions (%r) is larger than device mesh '
         'rank (%r). tensor_split_dims_mapping: %r, device_mesh.shape: %r' %
         (len(permutation), len(
             device_mesh.shape), tensor_split_dims_mapping, device_mesh.shape))
-  tile_assignment = _np.transpose(device_mesh, permutation)
+  # Append replicated dimensions to the end.
+  transpose_permutation = permutation + [
+      d for d in range(len(device_mesh.shape)) if d not in permutation
+  ]
+  tile_assignment = _np.transpose(device_mesh, transpose_permutation)
   tile_shape = [
       1 if d < 0 else device_mesh.shape[d] for d in tensor_split_dims_mapping
   ]
+  partial = len(permutation) < len(device_mesh.shape)
+  if partial:
+    tile_shape.append(_np.prod(device_mesh.shape) // _np.prod(tile_shape))
   tile_assignment = _np.reshape(tile_assignment, tile_shape)
+
+  if partial:
+    return partial_tile(
+        tensor, tile_assignment, use_sharding_op=use_sharding_op)
   return tile(tensor, tile_assignment, use_sharding_op=use_sharding_op)
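A worked sketch of the new partial path (mesh shape and mapping are illustrative): with a 4x2 mesh and only one tensor dimension split, the leftover mesh dimension becomes the replication dimension.

  import numpy as np
  from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding

  device_mesh = np.arange(8).reshape(4, 2)
  # Split tensor dim 0 over mesh dim 0; tensor dim 1 is not split (-1).
  tensor_split_dims_mapping = [0, -1]
  # permutation = [0], so partial = True (1 split dim < mesh rank 2).
  # tile_shape starts as [4, 1]; 8 // 4 = 2 is appended -> [4, 1, 2]:
  # 4x1 tiles, each replicated across 2 devices, via partial_tile().

  def computation(x):  # x: a rank-2 tensor inside a TPU computation
    return xla_sharding.mesh_split(x, device_mesh, tensor_split_dims_mapping,
                                   use_sharding_op=True)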
In tensorflow/core/tpu/graph_rewrite:

@@ -599,8 +599,11 @@ Status GetStepMarkerLocation(const Node& replicate_node,
 // sharding attribute.
 Status GetDimensionIndicesAndNumSplitsFromSharding(
     const xla::OpSharding& sharding, std::map<int, int>* split_dimension_map) {
-  for (int dim_index = 0;
-       dim_index < sharding.tile_assignment_dimensions_size(); dim_index++) {
+  int64 tensor_tile_rank = sharding.tile_assignment_dimensions_size();
+  if (sharding.replicate_on_last_tile_dim()) {
+    tensor_tile_rank--;
+  }
+  for (int dim_index = 0; dim_index < tensor_tile_rank; dim_index++) {
     if (sharding.tile_assignment_dimensions(dim_index) > 1) {
       split_dimension_map->emplace(
           dim_index, sharding.tile_assignment_dimensions(dim_index));
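Mirrored as a hypothetical Python helper (not in the commit), the change means the trailing replication dimension no longer contributes a split dimension:

  def split_dimensions(tile_dims, replicate_on_last_tile_dim):
    # Drop the trailing replication dimension before looking for splits.
    rank = len(tile_dims) - (1 if replicate_on_last_tile_dim else 0)
    return {d: n for d, n in enumerate(tile_dims[:rank]) if n > 1}

  assert split_dimensions([2, 2, 2], True) == {0: 2, 1: 2}
  assert split_dimensions([2, 2, 2], False) == {0: 2, 1: 2, 2: 2}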
@@ -777,8 +780,9 @@ xla::StatusOr<ShardedInputInfo> CreateOrGetSplitNodesForInputSharding(
   // `split_nodes_for_dimension` now includes final split nodes
   // from which sharded data will be fed into TPUExecute nodes -- sorted by
   // row major order.
-  std::vector<NodeOut> sharded_inputs_list;
-  sharded_inputs_list.reserve(split_nodes_for_dimension.size());
+  std::vector<NodeOut> sharded_inputs_list(
+      sharding.tile_assignment_devices_size());
+  int64 next_core_tile_index = 0;
   while (!split_nodes_for_dimension.empty()) {
     Node* split_node = split_nodes_for_dimension.front();
     split_nodes_for_dimension.pop();
@@ -786,7 +790,14 @@ xla::StatusOr<ShardedInputInfo> CreateOrGetSplitNodesForInputSharding(
     TF_RETURN_IF_ERROR(
         GetNodeAttr(split_node->def(), "num_split", &num_splits));
     for (int out_index = 0; out_index < num_splits; ++out_index) {
-      sharded_inputs_list.emplace_back(NodeOut{split_node, out_index});
+      int64 repeat_count = sharding.replicate_on_last_tile_dim()
+                               ? *sharding.tile_assignment_dimensions().rbegin()
+                               : 1;
+      for (int64 i = 0; i < repeat_count; ++i) {
+        int64 next_core =
+            sharding.tile_assignment_devices(next_core_tile_index++);
+        sharded_inputs_list[next_core] = NodeOut{split_node, out_index};
+      }
     }
   }
 
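In effect, the loop walks tile_assignment_devices in row-major order and hands the same split output to every core of a replicated group. A Python sketch with an assumed 2x2x2 assignment:

  # tile_assignment_devices for a 2x2x2 assignment, row-major order;
  # the last dimension (size 2) is the replicated one.
  devices = list(range(8))
  repeat_count = 2                  # *tile_assignment_dimensions().rbegin()
  sharded_inputs = [None] * len(devices)
  next_core_tile_index = 0
  for out_index in range(len(devices) // repeat_count):  # distinct tiles
    for _ in range(repeat_count):
      core = devices[next_core_tile_index]
      next_core_tile_index += 1
      # Each core in the replicated group receives the same split output.
      sharded_inputs[core] = ('split_output', out_index)
  assert sharded_inputs[0] == sharded_inputs[1] == ('split_output', 0)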
@@ -889,19 +900,6 @@ xla::StatusOr<Node*> CreateConcatNodesForRetval(
   return inputs_to_sharded_retval.at(0).node;
 }
 
-absl::optional<int> GetCoreIndexInSharding(const xla::OpSharding& sharding,
-                                           int64 core) {
-  absl::optional<int> output_index;
-  for (int i = 0; i < sharding.tile_assignment_devices_size(); i++) {
-    int64 assigned_core = sharding.tile_assignment_devices(i);
-    if (assigned_core == core) {
-      output_index = i;
-      break;
-    }
-  }
-  return output_index;
-}
-
 // Set the padding ops the same devices as the original inputs. If the original
 // inputs are on TPUs, the padding ops will be placed on TPUs and XLA on demand
 // mode will be triggered, so we don't need to copy the data back to the host
@@ -2763,14 +2761,8 @@ Status DistributedTPURewritePass::BuildExecuteNodes(
                 sharding, orig_arg_num, dtype, replica,
                 edge->src_output(), edge->src(), control_predecessor,
                 graph, &input_index_to_sharded_inputs));
 
-            // Calculate which output we should receive from the Split node.
-            absl::optional<int> output_index =
-                GetCoreIndexInSharding(sharding, core);
-            TF_RET_CHECK(output_index);
-
             NodeOut split_node_and_index =
-                sharded_input_info.sharded_inputs.at(output_index.value());
+                sharded_input_info.sharded_inputs.at(core);
             // Connect with Split node output.
             graph->AddEdge(split_node_and_index.node,
                            split_node_and_index.index, node, i);
@@ -2850,13 +2842,8 @@ Status DistributedTPURewritePass::BuildExecuteNodes(
                 arg_shapes[orig_arg_num].handle_type, replica,
                 var_data.index, var_data.node, control_predecessor, graph,
                 &input_index_to_sharded_inputs));
 
-            // Calculate which output we should receive from the Split node.
-            absl::optional<int> output_index =
-                GetCoreIndexInSharding(sharding, core);
-            TF_RET_CHECK(output_index);
             NodeOut split_node_and_index =
-                sharded_input_info.sharded_inputs[output_index.value()];
+                sharded_input_info.sharded_inputs[core];
             // Connect with Split node output.
             graph->AddEdge(split_node_and_index.node,
                            split_node_and_index.index, node, i);
@@ -2919,7 +2906,16 @@ Status DistributedTPURewritePass::BuildExecuteNodes(
 
           // Add a Concat node.
           std::vector<NodeOut> orig_inputs;
-          for (int64 core_id : sharding.tile_assignment_devices()) {
+          for (int64 tile_index = 0;
+               tile_index < sharding.tile_assignment_devices_size();
+               ++tile_index) {
+            int64 last_tile_dim_size =
+                *sharding.tile_assignment_dimensions().rbegin();
+            if (sharding.replicate_on_last_tile_dim() &&
+                tile_index % last_tile_dim_size != 0) {
+              continue;
+            }
+            int64 core_id = sharding.tile_assignment_devices(tile_index);
             int core_retval_index =
                 retval_index_to_output_index_mapping[retval_index][core_id];
             orig_inputs.push_back(
@@ -2987,7 +2983,16 @@ Status DistributedTPURewritePass::BuildExecuteNodes(
 
           // Add a Concat node.
           std::vector<NodeOut> orig_inputs;
-          for (int64 core_id : sharding.tile_assignment_devices()) {
+          for (int64 tile_index = 0;
+               tile_index < sharding.tile_assignment_devices_size();
+               ++tile_index) {
+            int64 last_tile_dim_size =
+                *sharding.tile_assignment_dimensions().rbegin();
+            if (sharding.replicate_on_last_tile_dim() &&
+                tile_index % last_tile_dim_size != 0) {
+              continue;
+            }
+            int64 core_id = sharding.tile_assignment_devices(tile_index);
             int core_retval_num =
                 orig_arg_num_to_output_index_mapping[orig_arg_num][core_id];
             orig_inputs.push_back(
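Both Concat loops apply the same filter. As a sketch with an assumed 2x2x2 assignment, only the first core of each replicated group contributes an input:

  devices = list(range(8))          # tile_assignment_devices
  last_tile_dim_size = 2            # *tile_assignment_dimensions().rbegin()
  concat_core_ids = [
      devices[tile_index] for tile_index in range(len(devices))
      if tile_index % last_tile_dim_size == 0
  ]
  assert concat_core_ids == [0, 2, 4, 6]  # one representative per tile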