From 518423dd27e1673a9cafa76507178c11c83de560 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 26 May 2020 18:34:20 -0700 Subject: [PATCH] [XLA:Python] Split bindings for XLA ops into a separate file. No functional changes. This is partially to make xla.cc shorter and partially to parallelize its build time. PiperOrigin-RevId: 313307447 Change-Id: I4f6de5723dbef4464599813bc9284b4ac9e271d7 --- tensorflow/compiler/xla/python/BUILD | 33 ++- tensorflow/compiler/xla/python/ops.cc | 356 ++++++++++++++++++++++++++ tensorflow/compiler/xla/python/ops.h | 27 ++ tensorflow/compiler/xla/python/xla.cc | 322 +---------------------- 4 files changed, 411 insertions(+), 327 deletions(-) create mode 100644 tensorflow/compiler/xla/python/ops.cc create mode 100644 tensorflow/compiler/xla/python/ops.h diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 5b4182b75e1..3dcdc46040a 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -186,6 +186,32 @@ cc_library( ], ) +cc_library( + name = "ops", + srcs = ["ops.cc"], + hdrs = ["ops.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":types", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:comparators", + "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:qr", + "//tensorflow/compiler/xla/client/lib:self_adjoint_eig", + "//tensorflow/compiler/xla/client/lib:sorting", + "//tensorflow/compiler/xla/client/lib:svd", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@pybind11", + ], +) + config_setting( name = "enable_gpu", values = {"define": "xla_python_enable_gpu=true"}, @@ -205,6 +231,7 @@ pybind_extension( deps = [ ":bfloat16", ":dlpack", + ":ops", ":python_ref_manager", ":types", "@com_google_absl//absl/base", @@ -228,12 +255,6 @@ pybind_extension( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/lib:comparators", - "//tensorflow/compiler/xla/client/lib:math", - "//tensorflow/compiler/xla/client/lib:qr", - "//tensorflow/compiler/xla/client/lib:self_adjoint_eig", - "//tensorflow/compiler/xla/client/lib:sorting", - "//tensorflow/compiler/xla/client/lib:svd", "//tensorflow/compiler/xla/pjrt:cpu_device", "//tensorflow/compiler/xla/pjrt:nvidia_gpu_device", "//tensorflow/compiler/xla/pjrt:pjrt_client", diff --git a/tensorflow/compiler/xla/python/ops.cc b/tensorflow/compiler/xla/python/ops.cc new file mode 100644 index 00000000000..89891d39f78 --- /dev/null +++ b/tensorflow/compiler/xla/python/ops.cc @@ -0,0 +1,356 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/ops.h" + +#include +#include + +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "pybind11/attr.h" +#include "pybind11/pybind11.h" +#include "tensorflow/compiler/xla/client/lib/comparators.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/qr.h" +#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" +#include "tensorflow/compiler/xla/client/lib/sorting.h" +#include "tensorflow/compiler/xla/client/lib/svd.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/python/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +namespace py = pybind11; + +void BuildOpsSubmodule(py::module* m) { + // ops submodule, containing free functions that add operators to an + // XlaBuilder. + py::module ops = m->def_submodule("ops", "XLA operations"); + + py::enum_( + ops, "TriangularSolveOptions_Transpose") + .value("TRANSPOSE_INVALID", TriangularSolveOptions::TRANSPOSE_INVALID) + .value("NO_TRANSPOSE", TriangularSolveOptions::NO_TRANSPOSE) + .value("TRANSPOSE", TriangularSolveOptions::TRANSPOSE) + .value("ADJOINT", TriangularSolveOptions::ADJOINT); + + ops.def("AfterAll", &AfterAll, py::arg("builder"), py::arg("tokens")); + ops.def( + "AllReduce", + static_cast, + const absl::optional&, const absl::optional&)>( + &AllReduce), + py::arg("operand"), py::arg("computation"), + py::arg("replica_groups") = py::list(), + py::arg("channel_id") = absl::nullopt, + py::arg("shape_with_layout") = absl::nullopt); + ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"), + py::arg("concat_dimension"), py::arg("split_count"), + py::arg("replica_groups") = py::list(), + py::arg("layout") = absl::nullopt); + ops.def("CollectivePermute", &CollectivePermute, py::arg("operand"), + py::arg("source_target_pairs")); + ops.def("CreateToken", &CreateToken, py::arg("builder")); + ops.def("CrossReplicaSum", + static_cast)>( + &CrossReplicaSum), + py::arg("operand"), py::arg("replica_groups") = py::list()); + ops.def("BitcastConvertType", &BitcastConvertType, py::arg("operand"), + py::arg("new_element_type")); + ops.def("Broadcast", &Broadcast, py::arg("operand"), py::arg("sizes")); + ops.def("BroadcastInDim", &BroadcastInDim, py::arg("operand"), + py::arg("shape"), py::arg("broadcast_dimensions")); + ops.def("Call", &Call, py::arg("builder"), py::arg("computation"), + py::arg("operands")); + ops.def("Cholesky", &Cholesky, py::arg("a"), py::arg("lower") = true); + ops.def("Clamp", &Clamp, py::arg("min"), py::arg("operand"), py::arg("max")); + ops.def("Collapse", &Collapse, py::arg("operand"), py::arg("dimensions")); + ops.def("ConcatInDim", &ConcatInDim, py::arg("builder"), py::arg("operands"), + py::arg("dimension")); + ops.def("Conditional", + static_cast, + absl::Span)>(&Conditional), + py::arg("branch_index"), py::arg("branch_computations"), + py::arg("branch_operands")); + ops.def("Conditional", + static_cast(&Conditional), + py::arg("predicate"), py::arg("true_operand"), + py::arg("true_computation"), py::arg("false_operand"), + py::arg("false_computation")); + ops.def("Constant", &ConstantLiteral, py::arg("builder"), py::arg("literal")); + ops.def("ConstantLiteral", &ConstantLiteral, py::arg("builder"), + py::arg("literal")); + ops.def("ConvGeneralDilated", &ConvGeneralDilated, py::arg("lhs"), + py::arg("rhs"), py::arg("window_strides"), py::arg("padding"), + py::arg("lhs_dilation"), py::arg("rhs_dilation"), + py::arg("dimension_numbers"), py::arg("feature_group_count") = 1, + py::arg("batch_group_count") = 1, + py::arg("precision_config") = nullptr); + ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"), + py::arg("new_element_type")); + ops.def( + "CustomCall", + [](XlaBuilder* builder, const py::bytes& call_target_name, + absl::Span operands, const Shape& shape, + const py::bytes& opaque) -> XlaOp { + return CustomCall(builder, call_target_name, operands, shape, opaque); + }, + py::arg("builder"), py::arg("call_target_name"), py::arg("operands"), + py::arg("shape"), py::arg("opaque") = py::bytes("")); + ops.def( + "CustomCallWithLayout", + [](XlaBuilder* builder, const py::bytes& call_target_name, + absl::Span operands, const Shape& shape_with_layout, + absl::Span operand_shapes_with_layout, + const py::bytes& opaque) -> XlaOp { + return CustomCallWithLayout(builder, call_target_name, operands, + shape_with_layout, + operand_shapes_with_layout, opaque); + }, + py::arg("builder"), py::arg("call_target_name"), py::arg("operands"), + py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"), + py::arg("opaque") = py::bytes("")); + ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"), + py::arg("precision_config") = nullptr); + ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"), + py::arg("dimension_numbers"), py::arg("precision_config") = nullptr); + ops.def("DynamicSlice", + static_cast, + absl::Span)>(&DynamicSlice), + py::arg("operand"), py::arg("start_indices"), py::arg("slice_sizes")); + ops.def("DynamicUpdateSlice", + static_cast)>( + &DynamicUpdateSlice), + py::arg("operand"), py::arg("update"), py::arg("start_indices")); + + ops.def("Fft", &Fft, py::arg("operand"), py::arg("fft_type"), + py::arg("fft_length")); + + ops.def("Gather", &Gather, py::arg("a"), py::arg("start_indices"), + py::arg("dimension_numbers"), py::arg("slice_sizes"), + py::arg("indices_are_sorted") = false); + ops.def("GetTupleElement", &GetTupleElement, py::arg("tuple_data"), + py::arg("index")); + ops.def("InfeedWithToken", &InfeedWithToken, py::arg("token"), + py::arg("shape"), py::arg("config") = ""); + ops.def("Iota", + static_cast(&Iota), + py::arg("builder"), py::arg("shape"), py::arg("iota_dimension")); + ops.def("Iota", + static_cast(&Iota), + py::arg("builder"), py::arg("type"), py::arg("size")); + ops.def("Map", &Map, py::arg("builder"), py::arg("operands"), + py::arg("computation"), py::arg("dimensions"), + py::arg("static_operands") = py::list()); + ops.def("NextAfter", &NextAfter, py::arg("from"), py::arg("to")); + ops.def("OutfeedWithToken", &OutfeedWithToken, py::arg("operand"), + py::arg("token"), py::arg("shape_with_layout"), + py::arg("outfeed_config") = ""); + ops.def("Pad", &Pad, py::arg("operand"), py::arg("padding_value"), + py::arg("padding_config")); + ops.def("Parameter", + static_cast&)>( + &Parameter), + py::arg("builder"), py::arg("parameter_number"), py::arg("shape"), + py::arg("name") = "", + py::arg("replicated_at_leaf_buffers") = std::vector()); + ops.def( + "QR", + [](XlaOp a, bool full_matrices) -> StatusOr> { + TF_ASSIGN_OR_RETURN(auto qr, QRDecomposition(a, full_matrices)); + return std::make_pair(qr.q, qr.r); + }, + py::arg("operand"), py::arg("full_matrices")); + ops.def( + "Eigh", + [](XlaOp a, bool lower, int64 max_iter, + float epsilon) -> std::pair { + auto eigh = SelfAdjointEig(a, lower, max_iter, epsilon); + return std::make_pair(eigh.v, eigh.w); + }, + py::arg("a"), py::arg("lower") = true, py::arg("max_iter") = 100, + py::arg("epsilon") = 1e-6); + ops.def( + "SVD", + [](XlaOp a, int64 max_iter, + float epsilon) -> std::tuple { + auto svd = SVD(a, max_iter, epsilon); + return std::make_tuple(svd.u, svd.d, svd.v); + }, + py::arg("a"), py::arg("max_iter") = 100, py::arg("epsilon") = 1e-6); + ops.def("Reduce", + static_cast, + absl::Span, const XlaComputation&, + absl::Span)>(&Reduce), + py::arg("builder"), py::arg("operands"), py::arg("init_values"), + py::arg("computation"), py::arg("dimensions_to_reduce")); + ops.def("ReducePrecision", &ReducePrecision, py::arg("operand"), + py::arg("exponent_bits"), py::arg("mantissa_bits")); + ops.def("ReduceWindowWithGeneralPadding", &ReduceWindowWithGeneralPadding, + py::arg("operand"), py::arg("init_value"), py::arg("computation"), + py::arg("window_dimensions"), py::arg("window_strides"), + py::arg("base_dilations"), py::arg("window_dilations"), + py::arg("padding")); + ops.def("ReplicaId", &ReplicaId, py::arg("builder")); + ops.def("Reshape", + static_cast, + absl::Span)>(&Reshape), + py::arg("operand"), py::arg("dimensions"), py::arg("new_sizes")); + ops.def("Reshape", + static_cast)>(&Reshape), + py::arg("operand"), py::arg("new_sizes")); + ops.def("Rev", &Rev, py::arg("operand"), py::arg("dimensions")); + ops.def("RngNormal", &RngNormal, py::arg("mu"), py::arg("sigma"), + py::arg("shape")); + ops.def("RngUniform", &RngUniform, py::arg("a"), py::arg("b"), + py::arg("shape")); + ops.def("Scatter", &Scatter, py::arg("input"), py::arg("scatter_indices"), + py::arg("updates"), py::arg("update_computation"), + py::arg("dimension_numbers"), py::arg("indices_are_sorted") = false, + py::arg("unique_indices") = false); + ops.def("Select", &Select, py::arg("pred"), py::arg("on_true"), + py::arg("on_false")); + ops.def("SelectAndScatterWithGeneralPadding", + &SelectAndScatterWithGeneralPadding, py::arg("operand"), + py::arg("select"), py::arg("window_dimensions"), + py::arg("window_strides"), py::arg("padding"), py::arg("source"), + py::arg("init_value"), py::arg("scatter")); + ops.def("Slice", &Slice, py::arg("operand"), py::arg("start_indices"), + py::arg("limit_indices"), py::arg("strides")); + ops.def("SliceInDim", &SliceInDim, py::arg("operand"), py::arg("start_index"), + py::arg("limit_index"), py::arg("stride"), py::arg("dimno")); + ops.def( + "Sort", + [](XlaBuilder* builder, absl::Span operands, + absl::optional comparator, int64 dimension, + bool is_stable) -> XlaOp { + return builder->ReportErrorOrReturn([&]() -> StatusOr { + std::vector operand_types; + for (const auto& operand : operands) { + TF_ASSIGN_OR_RETURN(auto operand_shape, builder->GetShape(operand)); + operand_types.push_back(operand_shape.element_type()); + } + + if (comparator) { + return Sort(operands, **comparator, dimension, is_stable); + } else { + return Sort(operands, + CreateScalarLtComputation(operand_types, builder), + dimension, is_stable); + } + }); + }, + py::arg("builder"), py::arg("operands"), + py::arg("comparator") = absl::nullopt, py::arg("dimension") = -1, + py::arg("is_stable") = false); + ops.def("TopK", &TopK, py::arg("input"), py::arg("k")); + ops.def("Transpose", &Transpose, py::arg("operand"), py::arg("permutation")); + ops.def("TriangularSolve", &TriangularSolve, py::arg("a"), py::arg("b"), + py::arg("left_side"), py::arg("lower"), py::arg("unit_diagonal"), + py::arg("transpose_a")); + ops.def("Tuple", &Tuple, py::arg("builder"), py::arg("elements")); + ops.def("While", &While, py::arg("condition"), py::arg("body"), + py::arg("init")); + + ops.def("Igamma", &Igamma, py::arg("a"), py::arg("x")); + ops.def("Igammac", &Igammac, py::arg("a"), py::arg("x")); + ops.def("IgammaGradA", &IgammaGradA, py::arg("a"), py::arg("x")); + ops.def("RandomGammaGrad", &RandomGammaGrad, py::arg("a"), py::arg("x")); + ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta, py::arg("a"), + py::arg("b"), py::arg("x")); + +#define BINARY_OP(op) \ + ops.def( \ + #op, \ + [](XlaOp a, XlaOp b, absl::optional> dims) { \ + return dims ? op(a, b, *dims) : op(a, b); \ + }, \ + py::arg("lhs"), py::arg("rhs"), \ + py::arg("broadcast_dimensions") = absl::nullopt) + BINARY_OP(Eq); + BINARY_OP(Ne); + BINARY_OP(Ge); + BINARY_OP(Gt); + BINARY_OP(Lt); + BINARY_OP(Le); + BINARY_OP(Add); + BINARY_OP(Sub); + BINARY_OP(Mul); + BINARY_OP(Div); + BINARY_OP(Rem); + BINARY_OP(Max); + BINARY_OP(Min); + BINARY_OP(And); + BINARY_OP(Or); + BINARY_OP(Xor); + BINARY_OP(ShiftLeft); + BINARY_OP(ShiftRightArithmetic); + BINARY_OP(ShiftRightLogical); + BINARY_OP(Atan2); + BINARY_OP(Pow); + BINARY_OP(Complex); +#undef BINARY_OP + +#define UNARY_OP(op) ops.def(#op, &op) + UNARY_OP(Not); + UNARY_OP(PopulationCount); + UNARY_OP(Clz); + UNARY_OP(Abs); + UNARY_OP(Exp); + UNARY_OP(Expm1); + UNARY_OP(Floor); + UNARY_OP(Ceil); + UNARY_OP(Round); + UNARY_OP(Log); + UNARY_OP(Log1p); + UNARY_OP(Sign); + UNARY_OP(Cos); + UNARY_OP(Sin); + UNARY_OP(Tanh); + UNARY_OP(IsFinite); + UNARY_OP(Neg); + UNARY_OP(Sqrt); + UNARY_OP(Rsqrt); + UNARY_OP(Square); + UNARY_OP(Reciprocal); + UNARY_OP(Erfc); + UNARY_OP(Erf); + UNARY_OP(ErfInv); + UNARY_OP(Lgamma); + UNARY_OP(Digamma); + UNARY_OP(BesselI0e); + UNARY_OP(BesselI1e); + UNARY_OP(Acos); + UNARY_OP(Asin); + UNARY_OP(Atan); + UNARY_OP(Tan); + UNARY_OP(Acosh); + UNARY_OP(Asinh); + UNARY_OP(Atanh); + UNARY_OP(Cosh); + UNARY_OP(Sinh); + UNARY_OP(Real); + UNARY_OP(Imag); + UNARY_OP(Conj); +#undef UNARY_OP +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/python/ops.h b/tensorflow/compiler/xla/python/ops.h new file mode 100644 index 00000000000..7fe34e941ba --- /dev/null +++ b/tensorflow/compiler/xla/python/ops.h @@ -0,0 +1,27 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_OPS_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_OPS_H_ + +#include "pybind11/pybind11.h" + +namespace xla { + +void BuildOpsSubmodule(pybind11::module* m); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_OPS_H_ diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index abf0937d057..fb7d7df58f7 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -30,12 +30,6 @@ limitations under the License. #include "pybind11/pybind11.h" #include "pybind11/pytypes.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/lib/comparators.h" -#include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/lib/qr.h" -#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" -#include "tensorflow/compiler/xla/client/lib/sorting.h" -#include "tensorflow/compiler/xla/client/lib/svd.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" @@ -48,6 +42,7 @@ limitations under the License. #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/python/bfloat16.h" #include "tensorflow/compiler/xla/python/dlpack.h" +#include "tensorflow/compiler/xla/python/ops.h" #include "tensorflow/compiler/xla/python/python_ref_manager.h" #include "tensorflow/compiler/xla/python/types.h" #include "tensorflow/compiler/xla/service/custom_call_target_registry.h" @@ -306,321 +301,6 @@ StatusOr PjRtBufferCudaArrayInterface(const PjRtBuffer& buffer) { return result; } -void BuildOpsSubmodule(py::module* m) { - // ops submodule, containing free functions that add operators to an - // XlaBuilder. - py::module ops = m->def_submodule("ops", "XLA operations"); - - py::enum_( - ops, "TriangularSolveOptions_Transpose") - .value("TRANSPOSE_INVALID", TriangularSolveOptions::TRANSPOSE_INVALID) - .value("NO_TRANSPOSE", TriangularSolveOptions::NO_TRANSPOSE) - .value("TRANSPOSE", TriangularSolveOptions::TRANSPOSE) - .value("ADJOINT", TriangularSolveOptions::ADJOINT); - - ops.def("AfterAll", &AfterAll, py::arg("builder"), py::arg("tokens")); - ops.def( - "AllReduce", - static_cast, - const absl::optional&, const absl::optional&)>( - &AllReduce), - py::arg("operand"), py::arg("computation"), - py::arg("replica_groups") = py::list(), - py::arg("channel_id") = absl::nullopt, - py::arg("shape_with_layout") = absl::nullopt); - ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"), - py::arg("concat_dimension"), py::arg("split_count"), - py::arg("replica_groups") = py::list(), - py::arg("layout") = absl::nullopt); - ops.def("CollectivePermute", &CollectivePermute, py::arg("operand"), - py::arg("source_target_pairs")); - ops.def("CreateToken", &CreateToken, py::arg("builder")); - ops.def("CrossReplicaSum", - static_cast)>( - &CrossReplicaSum), - py::arg("operand"), py::arg("replica_groups") = py::list()); - ops.def("BitcastConvertType", &BitcastConvertType, py::arg("operand"), - py::arg("new_element_type")); - ops.def("Broadcast", &Broadcast, py::arg("operand"), py::arg("sizes")); - ops.def("BroadcastInDim", &BroadcastInDim, py::arg("operand"), - py::arg("shape"), py::arg("broadcast_dimensions")); - ops.def("Call", &Call, py::arg("builder"), py::arg("computation"), - py::arg("operands")); - ops.def("Cholesky", &Cholesky, py::arg("a"), py::arg("lower") = true); - ops.def("Clamp", &Clamp, py::arg("min"), py::arg("operand"), py::arg("max")); - ops.def("Collapse", &Collapse, py::arg("operand"), py::arg("dimensions")); - ops.def("ConcatInDim", &ConcatInDim, py::arg("builder"), py::arg("operands"), - py::arg("dimension")); - ops.def("Conditional", - static_cast, - absl::Span)>(&Conditional), - py::arg("branch_index"), py::arg("branch_computations"), - py::arg("branch_operands")); - ops.def("Conditional", - static_cast(&Conditional), - py::arg("predicate"), py::arg("true_operand"), - py::arg("true_computation"), py::arg("false_operand"), - py::arg("false_computation")); - ops.def("Constant", &ConstantLiteral, py::arg("builder"), py::arg("literal")); - ops.def("ConstantLiteral", &ConstantLiteral, py::arg("builder"), - py::arg("literal")); - ops.def("ConvGeneralDilated", &ConvGeneralDilated, py::arg("lhs"), - py::arg("rhs"), py::arg("window_strides"), py::arg("padding"), - py::arg("lhs_dilation"), py::arg("rhs_dilation"), - py::arg("dimension_numbers"), py::arg("feature_group_count") = 1, - py::arg("batch_group_count") = 1, - py::arg("precision_config") = nullptr); - ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"), - py::arg("new_element_type")); - ops.def( - "CustomCall", - [](XlaBuilder* builder, const py::bytes& call_target_name, - absl::Span operands, const Shape& shape, - const py::bytes& opaque) -> XlaOp { - return CustomCall(builder, call_target_name, operands, shape, opaque); - }, - py::arg("builder"), py::arg("call_target_name"), py::arg("operands"), - py::arg("shape"), py::arg("opaque") = py::bytes("")); - ops.def( - "CustomCallWithLayout", - [](XlaBuilder* builder, const py::bytes& call_target_name, - absl::Span operands, const Shape& shape_with_layout, - absl::Span operand_shapes_with_layout, - const py::bytes& opaque) -> XlaOp { - return CustomCallWithLayout(builder, call_target_name, operands, - shape_with_layout, - operand_shapes_with_layout, opaque); - }, - py::arg("builder"), py::arg("call_target_name"), py::arg("operands"), - py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"), - py::arg("opaque") = py::bytes("")); - ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"), - py::arg("precision_config") = nullptr); - ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"), - py::arg("dimension_numbers"), py::arg("precision_config") = nullptr); - ops.def("DynamicSlice", - static_cast, - absl::Span)>(&DynamicSlice), - py::arg("operand"), py::arg("start_indices"), py::arg("slice_sizes")); - ops.def("DynamicUpdateSlice", - static_cast)>( - &DynamicUpdateSlice), - py::arg("operand"), py::arg("update"), py::arg("start_indices")); - - ops.def("Fft", &Fft, py::arg("operand"), py::arg("fft_type"), - py::arg("fft_length")); - - ops.def("Gather", &Gather, py::arg("a"), py::arg("start_indices"), - py::arg("dimension_numbers"), py::arg("slice_sizes"), - py::arg("indices_are_sorted") = false); - ops.def("GetTupleElement", &GetTupleElement, py::arg("tuple_data"), - py::arg("index")); - ops.def("InfeedWithToken", &InfeedWithToken, py::arg("token"), - py::arg("shape"), py::arg("config") = ""); - ops.def("Iota", - static_cast(&Iota), - py::arg("builder"), py::arg("shape"), py::arg("iota_dimension")); - ops.def("Iota", - static_cast(&Iota), - py::arg("builder"), py::arg("type"), py::arg("size")); - ops.def("Map", &Map, py::arg("builder"), py::arg("operands"), - py::arg("computation"), py::arg("dimensions"), - py::arg("static_operands") = py::list()); - ops.def("NextAfter", &NextAfter, py::arg("from"), py::arg("to")); - ops.def("OutfeedWithToken", &OutfeedWithToken, py::arg("operand"), - py::arg("token"), py::arg("shape_with_layout"), - py::arg("outfeed_config") = ""); - ops.def("Pad", &Pad, py::arg("operand"), py::arg("padding_value"), - py::arg("padding_config")); - ops.def("Parameter", - static_cast&)>( - &Parameter), - py::arg("builder"), py::arg("parameter_number"), py::arg("shape"), - py::arg("name") = "", - py::arg("replicated_at_leaf_buffers") = std::vector()); - ops.def( - "QR", - [](XlaOp a, bool full_matrices) -> StatusOr> { - TF_ASSIGN_OR_RETURN(auto qr, QRDecomposition(a, full_matrices)); - return std::make_pair(qr.q, qr.r); - }, - py::arg("operand"), py::arg("full_matrices")); - ops.def( - "Eigh", - [](XlaOp a, bool lower, int64 max_iter, - float epsilon) -> std::pair { - auto eigh = SelfAdjointEig(a, lower, max_iter, epsilon); - return std::make_pair(eigh.v, eigh.w); - }, - py::arg("a"), py::arg("lower") = true, py::arg("max_iter") = 100, - py::arg("epsilon") = 1e-6); - ops.def( - "SVD", - [](XlaOp a, int64 max_iter, - float epsilon) -> std::tuple { - auto svd = SVD(a, max_iter, epsilon); - return std::make_tuple(svd.u, svd.d, svd.v); - }, - py::arg("a"), py::arg("max_iter") = 100, py::arg("epsilon") = 1e-6); - ops.def("Reduce", - static_cast, - absl::Span, const XlaComputation&, - absl::Span)>(&Reduce), - py::arg("builder"), py::arg("operands"), py::arg("init_values"), - py::arg("computation"), py::arg("dimensions_to_reduce")); - ops.def("ReducePrecision", &ReducePrecision, py::arg("operand"), - py::arg("exponent_bits"), py::arg("mantissa_bits")); - ops.def("ReduceWindowWithGeneralPadding", &ReduceWindowWithGeneralPadding, - py::arg("operand"), py::arg("init_value"), py::arg("computation"), - py::arg("window_dimensions"), py::arg("window_strides"), - py::arg("base_dilations"), py::arg("window_dilations"), - py::arg("padding")); - ops.def("ReplicaId", &ReplicaId, py::arg("builder")); - ops.def("Reshape", - static_cast, - absl::Span)>(&Reshape), - py::arg("operand"), py::arg("dimensions"), py::arg("new_sizes")); - ops.def("Reshape", - static_cast)>(&Reshape), - py::arg("operand"), py::arg("new_sizes")); - ops.def("Rev", &Rev, py::arg("operand"), py::arg("dimensions")); - ops.def("RngNormal", &RngNormal, py::arg("mu"), py::arg("sigma"), - py::arg("shape")); - ops.def("RngUniform", &RngUniform, py::arg("a"), py::arg("b"), - py::arg("shape")); - ops.def("Scatter", &Scatter, py::arg("input"), py::arg("scatter_indices"), - py::arg("updates"), py::arg("update_computation"), - py::arg("dimension_numbers"), py::arg("indices_are_sorted") = false, - py::arg("unique_indices") = false); - ops.def("Select", &Select, py::arg("pred"), py::arg("on_true"), - py::arg("on_false")); - ops.def("SelectAndScatterWithGeneralPadding", - &SelectAndScatterWithGeneralPadding, py::arg("operand"), - py::arg("select"), py::arg("window_dimensions"), - py::arg("window_strides"), py::arg("padding"), py::arg("source"), - py::arg("init_value"), py::arg("scatter")); - ops.def("Slice", &Slice, py::arg("operand"), py::arg("start_indices"), - py::arg("limit_indices"), py::arg("strides")); - ops.def("SliceInDim", &SliceInDim, py::arg("operand"), py::arg("start_index"), - py::arg("limit_index"), py::arg("stride"), py::arg("dimno")); - ops.def( - "Sort", - [](XlaBuilder* builder, absl::Span operands, - absl::optional comparator, int64 dimension, - bool is_stable) -> XlaOp { - return builder->ReportErrorOrReturn([&]() -> StatusOr { - std::vector operand_types; - for (const auto& operand : operands) { - TF_ASSIGN_OR_RETURN(auto operand_shape, builder->GetShape(operand)); - operand_types.push_back(operand_shape.element_type()); - } - - if (comparator) { - return Sort(operands, **comparator, dimension, is_stable); - } else { - return Sort(operands, - CreateScalarLtComputation(operand_types, builder), - dimension, is_stable); - } - }); - }, - py::arg("builder"), py::arg("operands"), - py::arg("comparator") = absl::nullopt, py::arg("dimension") = -1, - py::arg("is_stable") = false); - ops.def("TopK", &TopK, py::arg("input"), py::arg("k")); - ops.def("Transpose", &Transpose, py::arg("operand"), py::arg("permutation")); - ops.def("TriangularSolve", &TriangularSolve, py::arg("a"), py::arg("b"), - py::arg("left_side"), py::arg("lower"), py::arg("unit_diagonal"), - py::arg("transpose_a")); - ops.def("Tuple", &Tuple, py::arg("builder"), py::arg("elements")); - ops.def("While", &While, py::arg("condition"), py::arg("body"), - py::arg("init")); - - ops.def("Igamma", &Igamma, py::arg("a"), py::arg("x")); - ops.def("Igammac", &Igammac, py::arg("a"), py::arg("x")); - ops.def("IgammaGradA", &IgammaGradA, py::arg("a"), py::arg("x")); - ops.def("RandomGammaGrad", &RandomGammaGrad, py::arg("a"), py::arg("x")); - ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta, py::arg("a"), - py::arg("b"), py::arg("x")); - -#define BINARY_OP(op) \ - ops.def( \ - #op, \ - [](XlaOp a, XlaOp b, absl::optional> dims) { \ - return dims ? op(a, b, *dims) : op(a, b); \ - }, \ - py::arg("lhs"), py::arg("rhs"), \ - py::arg("broadcast_dimensions") = absl::nullopt) - BINARY_OP(Eq); - BINARY_OP(Ne); - BINARY_OP(Ge); - BINARY_OP(Gt); - BINARY_OP(Lt); - BINARY_OP(Le); - BINARY_OP(Add); - BINARY_OP(Sub); - BINARY_OP(Mul); - BINARY_OP(Div); - BINARY_OP(Rem); - BINARY_OP(Max); - BINARY_OP(Min); - BINARY_OP(And); - BINARY_OP(Or); - BINARY_OP(Xor); - BINARY_OP(ShiftLeft); - BINARY_OP(ShiftRightArithmetic); - BINARY_OP(ShiftRightLogical); - BINARY_OP(Atan2); - BINARY_OP(Pow); - BINARY_OP(Complex); -#undef BINARY_OP - -#define UNARY_OP(op) ops.def(#op, &op) - UNARY_OP(Not); - UNARY_OP(PopulationCount); - UNARY_OP(Clz); - UNARY_OP(Abs); - UNARY_OP(Exp); - UNARY_OP(Expm1); - UNARY_OP(Floor); - UNARY_OP(Ceil); - UNARY_OP(Round); - UNARY_OP(Log); - UNARY_OP(Log1p); - UNARY_OP(Sign); - UNARY_OP(Cos); - UNARY_OP(Sin); - UNARY_OP(Tanh); - UNARY_OP(IsFinite); - UNARY_OP(Neg); - UNARY_OP(Sqrt); - UNARY_OP(Rsqrt); - UNARY_OP(Square); - UNARY_OP(Reciprocal); - UNARY_OP(Erfc); - UNARY_OP(Erf); - UNARY_OP(ErfInv); - UNARY_OP(Lgamma); - UNARY_OP(Digamma); - UNARY_OP(BesselI0e); - UNARY_OP(BesselI1e); - UNARY_OP(Acos); - UNARY_OP(Asin); - UNARY_OP(Atan); - UNARY_OP(Tan); - UNARY_OP(Acosh); - UNARY_OP(Asinh); - UNARY_OP(Atanh); - UNARY_OP(Cosh); - UNARY_OP(Sinh); - UNARY_OP(Real); - UNARY_OP(Imag); - UNARY_OP(Conj); -#undef UNARY_OP -} void BuildProfilerSubmodule(py::module* m) { py::module profiler =