[XLA:Python] Add default arguments to some collective ops.

Update some tests to use XlaBuilder instead of the ComputationBuilder wrapper.

PiperOrigin-RevId: 308158895
Change-Id: Ib68832a4c44a7b42f030a8d189fa2db6f42587e4
This commit is contained in:
Peter Hawkins 2020-04-23 16:58:40 -07:00 committed by TensorFlower Gardener
parent 431c46b008
commit c27b63fc5e

View File

@ -324,12 +324,13 @@ void BuildOpsSubmodule(py::module* m) {
XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
const absl::optional<ChannelHandle>&, const absl::optional<Shape>&)>(
&AllReduce),
py::arg("operand"), py::arg("computation"), py::arg("replica_groups"),
py::arg("operand"), py::arg("computation"),
py::arg("replica_groups") = py::list(),
py::arg("channel_id") = absl::nullopt,
py::arg("shape_with_layout") = absl::nullopt);
ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"),
py::arg("concat_dimension"), py::arg("split_count"),
py::arg("replica_groups"));
py::arg("replica_groups") = py::list());
ops.def("CollectivePermute", &CollectivePermute, py::arg("operand"),
py::arg("source_target_pairs"));
ops.def("CreateToken", &CreateToken, py::arg("builder"));