[XLA:Python] Add default arguments to some collective ops.
Update some tests to use XlaBuilder instead of the ComputationBuilder wrapper. PiperOrigin-RevId: 308158895 Change-Id: Ib68832a4c44a7b42f030a8d189fa2db6f42587e4
This commit is contained in:
parent
431c46b008
commit
c27b63fc5e
@ -324,12 +324,13 @@ void BuildOpsSubmodule(py::module* m) {
|
||||
XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
|
||||
const absl::optional<ChannelHandle>&, const absl::optional<Shape>&)>(
|
||||
&AllReduce),
|
||||
py::arg("operand"), py::arg("computation"), py::arg("replica_groups"),
|
||||
py::arg("operand"), py::arg("computation"),
|
||||
py::arg("replica_groups") = py::list(),
|
||||
py::arg("channel_id") = absl::nullopt,
|
||||
py::arg("shape_with_layout") = absl::nullopt);
|
||||
ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"),
|
||||
py::arg("concat_dimension"), py::arg("split_count"),
|
||||
py::arg("replica_groups"));
|
||||
py::arg("replica_groups") = py::list());
|
||||
ops.def("CollectivePermute", &CollectivePermute, py::arg("operand"),
|
||||
py::arg("source_target_pairs"));
|
||||
ops.def("CreateToken", &CreateToken, py::arg("builder"));
|
||||
|
Loading…
Reference in New Issue
Block a user