[xla] add CollectivePermute and AllReduce to XLA Python client
PiperOrigin-RevId: 247541333
This commit is contained in:
parent
c23fd17c37
commit
92cec71856
@ -301,7 +301,12 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
// XlaBuilder.
|
||||
py::module ops = m.def_submodule("ops", "XLA operations");
|
||||
|
||||
ops.def("AllReduce",
|
||||
static_cast<XlaOp (*)(
|
||||
XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
|
||||
const absl::optional<ChannelHandle>&)>(&CrossReplicaSum));
|
||||
ops.def("AllToAll", &AllToAll);
|
||||
ops.def("CollectivePermute", &CollectivePermute);
|
||||
ops.def("CrossReplicaSum",
|
||||
static_cast<XlaOp (*)(XlaOp, absl::Span<const ReplicaGroup>)>(
|
||||
&CrossReplicaSum));
|
||||
|
@ -1105,6 +1105,27 @@ class ComputationBuilder(object):
|
||||
dimensions = tuple(range(ndim))
|
||||
return ops.Reshape(operand, dimensions, new_sizes)
|
||||
|
||||
def AllReduce(self,
|
||||
operand,
|
||||
computation,
|
||||
replica_groups=None):
|
||||
"""AllReduce op.
|
||||
|
||||
Args:
|
||||
operand: XlaOp representing the input array
|
||||
computation: a Computation object - binary reduction function.
|
||||
replica_groups: optional, list of lists of ints encoding a partition of
|
||||
the set {0, 1, ..., num_replicas} into equally-sized replica groups
|
||||
within which the all-to-all is performed. If not supplied or None (the
|
||||
default), all replicas belong to the same group.
|
||||
|
||||
Returns:
|
||||
An XlaOp that represents the all-reduced result.
|
||||
"""
|
||||
replica_groups_protos = _get_replica_groups_protos(replica_groups)
|
||||
return ops.AllReduce(operand, computation.computation,
|
||||
replica_groups_protos, None)
|
||||
|
||||
def AllToAll(self,
|
||||
operand,
|
||||
split_dimension,
|
||||
@ -1125,13 +1146,7 @@ class ComputationBuilder(object):
|
||||
Returns:
|
||||
An XlaOp that represents the all-to-all concatenation.
|
||||
"""
|
||||
if replica_groups is None:
|
||||
replica_groups_protos = [] # special value for XLA API
|
||||
else:
|
||||
replica_groups = list(replica_groups)
|
||||
replica_groups_protos = [
|
||||
_make_replica_group_proto(group) for group in replica_groups
|
||||
]
|
||||
replica_groups_protos = _get_replica_groups_protos(replica_groups)
|
||||
if not replica_groups:
|
||||
split_count = 1
|
||||
else:
|
||||
@ -1740,6 +1755,7 @@ _OTHER_OPS = [
|
||||
'Cholesky',
|
||||
'Clamp',
|
||||
'Collapse',
|
||||
'CollectivePermute',
|
||||
'ConvertElementType',
|
||||
'Dot',
|
||||
'Gather',
|
||||
@ -1893,3 +1909,14 @@ def _make_replica_group_proto(replica_group):
|
||||
replica_group_proto = ReplicaGroup()
|
||||
replica_group_proto.replica_ids.extend(replica_group)
|
||||
return replica_group_proto
|
||||
|
||||
|
||||
def _get_replica_groups_protos(replica_groups):
|
||||
if replica_groups is None:
|
||||
replica_groups_protos = [] # special value for XLA API
|
||||
else:
|
||||
replica_groups = list(replica_groups)
|
||||
replica_groups_protos = [
|
||||
_make_replica_group_proto(group) for group in replica_groups
|
||||
]
|
||||
return replica_groups_protos
|
||||
|
Loading…
x
Reference in New Issue
Block a user