[xla] add CollectivePermute and AllReduce to XLA Python client
PiperOrigin-RevId: 247541333
This commit is contained in:
parent
c23fd17c37
commit
92cec71856
@ -301,7 +301,12 @@ PYBIND11_MODULE(xla_extension, m) {
|
|||||||
// XlaBuilder.
|
// XlaBuilder.
|
||||||
py::module ops = m.def_submodule("ops", "XLA operations");
|
py::module ops = m.def_submodule("ops", "XLA operations");
|
||||||
|
|
||||||
|
ops.def("AllReduce",
|
||||||
|
static_cast<XlaOp (*)(
|
||||||
|
XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
|
||||||
|
const absl::optional<ChannelHandle>&)>(&CrossReplicaSum));
|
||||||
ops.def("AllToAll", &AllToAll);
|
ops.def("AllToAll", &AllToAll);
|
||||||
|
ops.def("CollectivePermute", &CollectivePermute);
|
||||||
ops.def("CrossReplicaSum",
|
ops.def("CrossReplicaSum",
|
||||||
static_cast<XlaOp (*)(XlaOp, absl::Span<const ReplicaGroup>)>(
|
static_cast<XlaOp (*)(XlaOp, absl::Span<const ReplicaGroup>)>(
|
||||||
&CrossReplicaSum));
|
&CrossReplicaSum));
|
||||||
|
@ -1105,6 +1105,27 @@ class ComputationBuilder(object):
|
|||||||
dimensions = tuple(range(ndim))
|
dimensions = tuple(range(ndim))
|
||||||
return ops.Reshape(operand, dimensions, new_sizes)
|
return ops.Reshape(operand, dimensions, new_sizes)
|
||||||
|
|
||||||
|
def AllReduce(self,
|
||||||
|
operand,
|
||||||
|
computation,
|
||||||
|
replica_groups=None):
|
||||||
|
"""AllReduce op.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
operand: XlaOp representing the input array
|
||||||
|
computation: a Computation object - binary reduction function.
|
||||||
|
replica_groups: optional, list of lists of ints encoding a partition of
|
||||||
|
the set {0, 1, ..., num_replicas} into equally-sized replica groups
|
||||||
|
within which the all-to-all is performed. If not supplied or None (the
|
||||||
|
default), all replicas belong to the same group.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An XlaOp that represents the all-reduced result.
|
||||||
|
"""
|
||||||
|
replica_groups_protos = _get_replica_groups_protos(replica_groups)
|
||||||
|
return ops.AllReduce(operand, computation.computation,
|
||||||
|
replica_groups_protos, None)
|
||||||
|
|
||||||
def AllToAll(self,
|
def AllToAll(self,
|
||||||
operand,
|
operand,
|
||||||
split_dimension,
|
split_dimension,
|
||||||
@ -1125,13 +1146,7 @@ class ComputationBuilder(object):
|
|||||||
Returns:
|
Returns:
|
||||||
An XlaOp that represents the all-to-all concatenation.
|
An XlaOp that represents the all-to-all concatenation.
|
||||||
"""
|
"""
|
||||||
if replica_groups is None:
|
replica_groups_protos = _get_replica_groups_protos(replica_groups)
|
||||||
replica_groups_protos = [] # special value for XLA API
|
|
||||||
else:
|
|
||||||
replica_groups = list(replica_groups)
|
|
||||||
replica_groups_protos = [
|
|
||||||
_make_replica_group_proto(group) for group in replica_groups
|
|
||||||
]
|
|
||||||
if not replica_groups:
|
if not replica_groups:
|
||||||
split_count = 1
|
split_count = 1
|
||||||
else:
|
else:
|
||||||
@ -1740,6 +1755,7 @@ _OTHER_OPS = [
|
|||||||
'Cholesky',
|
'Cholesky',
|
||||||
'Clamp',
|
'Clamp',
|
||||||
'Collapse',
|
'Collapse',
|
||||||
|
'CollectivePermute',
|
||||||
'ConvertElementType',
|
'ConvertElementType',
|
||||||
'Dot',
|
'Dot',
|
||||||
'Gather',
|
'Gather',
|
||||||
@ -1893,3 +1909,14 @@ def _make_replica_group_proto(replica_group):
|
|||||||
replica_group_proto = ReplicaGroup()
|
replica_group_proto = ReplicaGroup()
|
||||||
replica_group_proto.replica_ids.extend(replica_group)
|
replica_group_proto.replica_ids.extend(replica_group)
|
||||||
return replica_group_proto
|
return replica_group_proto
|
||||||
|
|
||||||
|
|
||||||
|
def _get_replica_groups_protos(replica_groups):
|
||||||
|
if replica_groups is None:
|
||||||
|
replica_groups_protos = [] # special value for XLA API
|
||||||
|
else:
|
||||||
|
replica_groups = list(replica_groups)
|
||||||
|
replica_groups_protos = [
|
||||||
|
_make_replica_group_proto(group) for group in replica_groups
|
||||||
|
]
|
||||||
|
return replica_groups_protos
|
||||||
|
Loading…
x
Reference in New Issue
Block a user