[xla] add CollectivePermute and AllReduce to XLA Python client

PiperOrigin-RevId: 247541333
This commit is contained in:
A. Unique TensorFlower 2019-05-09 19:42:59 -07:00 committed by TensorFlower Gardener
parent c23fd17c37
commit 92cec71856
2 changed files with 39 additions and 7 deletions

View File

@ -301,7 +301,12 @@ PYBIND11_MODULE(xla_extension, m) {
// XlaBuilder.
py::module ops = m.def_submodule("ops", "XLA operations");
ops.def("AllReduce",
static_cast<XlaOp (*)(
XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
const absl::optional<ChannelHandle>&)>(&CrossReplicaSum));
ops.def("AllToAll", &AllToAll);
ops.def("CollectivePermute", &CollectivePermute);
ops.def("CrossReplicaSum",
static_cast<XlaOp (*)(XlaOp, absl::Span<const ReplicaGroup>)>(
&CrossReplicaSum));

View File

@ -1105,6 +1105,27 @@ class ComputationBuilder(object):
dimensions = tuple(range(ndim))
return ops.Reshape(operand, dimensions, new_sizes)
def AllReduce(self,
operand,
computation,
replica_groups=None):
"""AllReduce op.
Args:
operand: XlaOp representing the input array
computation: a Computation object - binary reduction function.
replica_groups: optional, list of lists of ints encoding a partition of
the set {0, 1, ..., num_replicas} into equally-sized replica groups
within which the all-to-all is performed. If not supplied or None (the
default), all replicas belong to the same group.
Returns:
An XlaOp that represents the all-reduced result.
"""
replica_groups_protos = _get_replica_groups_protos(replica_groups)
return ops.AllReduce(operand, computation.computation,
replica_groups_protos, None)
def AllToAll(self,
operand,
split_dimension,
@ -1125,13 +1146,7 @@ class ComputationBuilder(object):
Returns:
An XlaOp that represents the all-to-all concatenation.
"""
if replica_groups is None:
replica_groups_protos = [] # special value for XLA API
else:
replica_groups = list(replica_groups)
replica_groups_protos = [
_make_replica_group_proto(group) for group in replica_groups
]
replica_groups_protos = _get_replica_groups_protos(replica_groups)
if not replica_groups:
split_count = 1
else:
@ -1740,6 +1755,7 @@ _OTHER_OPS = [
'Cholesky',
'Clamp',
'Collapse',
'CollectivePermute',
'ConvertElementType',
'Dot',
'Gather',
@ -1893,3 +1909,14 @@ def _make_replica_group_proto(replica_group):
replica_group_proto = ReplicaGroup()
replica_group_proto.replica_ids.extend(replica_group)
return replica_group_proto
def _get_replica_groups_protos(replica_groups):
if replica_groups is None:
replica_groups_protos = [] # special value for XLA API
else:
replica_groups = list(replica_groups)
replica_groups_protos = [
_make_replica_group_proto(group) for group in replica_groups
]
return replica_groups_protos