[XLA:Python] Split bindings for XLA ops into a separate file. No functional changes.

This is partially to make xla.cc shorter and partially to allow its build to be parallelized.

PiperOrigin-RevId: 313307447
Change-Id: I4f6de5723dbef4464599813bc9284b4ac9e271d7
This commit is contained in:
Peter Hawkins 2020-05-26 18:34:20 -07:00 committed by TensorFlower Gardener
parent 05653928da
commit 518423dd27
4 changed files with 411 additions and 327 deletions
tensorflow/compiler/xla/python

View File

@ -186,6 +186,32 @@ cc_library(
],
)
# Bindings for XLA's op-building free functions (the "ops" Python submodule).
# Split out of the main extension so it can compile in parallel with xla.cc.
cc_library(
name = "ops",
srcs = ["ops.cc"],
hdrs = ["ops.h"],
copts = [
# Build with C++ exceptions enabled and strict aliasing disabled
# (presumably required by pybind11 — TODO confirm against the other
# pybind targets in this package).
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
deps = [
":types",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:comparators",
"//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/lib:qr",
"//tensorflow/compiler/xla/client/lib:self_adjoint_eig",
"//tensorflow/compiler/xla/client/lib:sorting",
"//tensorflow/compiler/xla/client/lib:svd",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
"@pybind11",
],
)
config_setting(
name = "enable_gpu",
values = {"define": "xla_python_enable_gpu=true"},
@ -205,6 +231,7 @@ pybind_extension(
deps = [
":bfloat16",
":dlpack",
":ops",
":python_ref_manager",
":types",
"@com_google_absl//absl/base",
@ -228,12 +255,6 @@ pybind_extension(
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:comparators",
"//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/lib:qr",
"//tensorflow/compiler/xla/client/lib:self_adjoint_eig",
"//tensorflow/compiler/xla/client/lib:sorting",
"//tensorflow/compiler/xla/client/lib:svd",
"//tensorflow/compiler/xla/pjrt:cpu_device",
"//tensorflow/compiler/xla/pjrt:nvidia_gpu_device",
"//tensorflow/compiler/xla/pjrt:pjrt_client",

View File

@ -0,0 +1,356 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/ops.h"
#include <string>
#include <vector>
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "pybind11/attr.h"
#include "pybind11/pybind11.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/qr.h"
#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
#include "tensorflow/compiler/xla/client/lib/sorting.h"
#include "tensorflow/compiler/xla/client/lib/svd.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace py = pybind11;
// Registers the "ops" submodule on the given pybind11 module. Every binding
// below is a free function that appends an operator to an XlaBuilder; the
// operator semantics are those of the corresponding C++ builder function.
void BuildOpsSubmodule(py::module* m) {
// ops submodule, containing free functions that add operators to an
// XlaBuilder.
py::module ops = m->def_submodule("ops", "XLA operations");
// Expose TriangularSolveOptions::Transpose so callers can supply the
// `transpose_a` argument of TriangularSolve below.
py::enum_<TriangularSolveOptions::Transpose>(
ops, "TriangularSolveOptions_Transpose")
.value("TRANSPOSE_INVALID", TriangularSolveOptions::TRANSPOSE_INVALID)
.value("NO_TRANSPOSE", TriangularSolveOptions::NO_TRANSPOSE)
.value("TRANSPOSE", TriangularSolveOptions::TRANSPOSE)
.value("ADJOINT", TriangularSolveOptions::ADJOINT);
// Token and collective/cross-replica operations.
ops.def("AfterAll", &AfterAll, py::arg("builder"), py::arg("tokens"));
ops.def(
"AllReduce",
// The static_cast selects one specific AllReduce overload for pybind11.
static_cast<XlaOp (*)(
XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
const absl::optional<ChannelHandle>&, const absl::optional<Shape>&)>(
&AllReduce),
py::arg("operand"), py::arg("computation"),
py::arg("replica_groups") = py::list(),
py::arg("channel_id") = absl::nullopt,
py::arg("shape_with_layout") = absl::nullopt);
ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"),
py::arg("concat_dimension"), py::arg("split_count"),
py::arg("replica_groups") = py::list(),
py::arg("layout") = absl::nullopt);
ops.def("CollectivePermute", &CollectivePermute, py::arg("operand"),
py::arg("source_target_pairs"));
ops.def("CreateToken", &CreateToken, py::arg("builder"));
ops.def("CrossReplicaSum",
static_cast<XlaOp (*)(XlaOp, absl::Span<const ReplicaGroup>)>(
&CrossReplicaSum),
py::arg("operand"), py::arg("replica_groups") = py::list());
// Shape/type manipulation and structural operations.
ops.def("BitcastConvertType", &BitcastConvertType, py::arg("operand"),
py::arg("new_element_type"));
ops.def("Broadcast", &Broadcast, py::arg("operand"), py::arg("sizes"));
ops.def("BroadcastInDim", &BroadcastInDim, py::arg("operand"),
py::arg("shape"), py::arg("broadcast_dimensions"));
ops.def("Call", &Call, py::arg("builder"), py::arg("computation"),
py::arg("operands"));
ops.def("Cholesky", &Cholesky, py::arg("a"), py::arg("lower") = true);
ops.def("Clamp", &Clamp, py::arg("min"), py::arg("operand"), py::arg("max"));
ops.def("Collapse", &Collapse, py::arg("operand"), py::arg("dimensions"));
ops.def("ConcatInDim", &ConcatInDim, py::arg("builder"), py::arg("operands"),
py::arg("dimension"));
// Two Conditional overloads: N-way branch-index form and boolean
// predicate form; the casts disambiguate them for pybind11.
ops.def("Conditional",
static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaComputation* const>,
absl::Span<const XlaOp>)>(&Conditional),
py::arg("branch_index"), py::arg("branch_computations"),
py::arg("branch_operands"));
ops.def("Conditional",
static_cast<XlaOp (*)(XlaOp, XlaOp, const XlaComputation&, XlaOp,
const XlaComputation&)>(&Conditional),
py::arg("predicate"), py::arg("true_operand"),
py::arg("true_computation"), py::arg("false_operand"),
py::arg("false_computation"));
// "Constant" and "ConstantLiteral" are aliases for the same function.
ops.def("Constant", &ConstantLiteral, py::arg("builder"), py::arg("literal"));
ops.def("ConstantLiteral", &ConstantLiteral, py::arg("builder"),
py::arg("literal"));
ops.def("ConvGeneralDilated", &ConvGeneralDilated, py::arg("lhs"),
py::arg("rhs"), py::arg("window_strides"), py::arg("padding"),
py::arg("lhs_dilation"), py::arg("rhs_dilation"),
py::arg("dimension_numbers"), py::arg("feature_group_count") = 1,
py::arg("batch_group_count") = 1,
py::arg("precision_config") = nullptr);
ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"),
py::arg("new_element_type"));
// CustomCall wrappers take call_target_name and opaque as py::bytes so
// arbitrary byte payloads (not just text) can be passed through.
ops.def(
"CustomCall",
[](XlaBuilder* builder, const py::bytes& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape,
const py::bytes& opaque) -> XlaOp {
return CustomCall(builder, call_target_name, operands, shape, opaque);
},
py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
py::arg("shape"), py::arg("opaque") = py::bytes(""));
ops.def(
"CustomCallWithLayout",
[](XlaBuilder* builder, const py::bytes& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
absl::Span<const Shape> operand_shapes_with_layout,
const py::bytes& opaque) -> XlaOp {
return CustomCallWithLayout(builder, call_target_name, operands,
shape_with_layout,
operand_shapes_with_layout, opaque);
},
py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"),
py::arg("opaque") = py::bytes(""));
ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"),
py::arg("precision_config") = nullptr);
ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"),
py::arg("dimension_numbers"), py::arg("precision_config") = nullptr);
ops.def("DynamicSlice",
static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaOp>,
absl::Span<const int64>)>(&DynamicSlice),
py::arg("operand"), py::arg("start_indices"), py::arg("slice_sizes"));
ops.def("DynamicUpdateSlice",
static_cast<XlaOp (*)(XlaOp, XlaOp, absl::Span<const XlaOp>)>(
&DynamicUpdateSlice),
py::arg("operand"), py::arg("update"), py::arg("start_indices"));
ops.def("Fft", &Fft, py::arg("operand"), py::arg("fft_type"),
py::arg("fft_length"));
ops.def("Gather", &Gather, py::arg("a"), py::arg("start_indices"),
py::arg("dimension_numbers"), py::arg("slice_sizes"),
py::arg("indices_are_sorted") = false);
ops.def("GetTupleElement", &GetTupleElement, py::arg("tuple_data"),
py::arg("index"));
ops.def("InfeedWithToken", &InfeedWithToken, py::arg("token"),
py::arg("shape"), py::arg("config") = "");
// Two Iota overloads: shape-based (with explicit iota dimension) and
// 1-D primitive-type-based.
ops.def("Iota",
static_cast<XlaOp (*)(XlaBuilder*, const Shape&, int64)>(&Iota),
py::arg("builder"), py::arg("shape"), py::arg("iota_dimension"));
ops.def("Iota",
static_cast<XlaOp (*)(XlaBuilder*, PrimitiveType, int64)>(&Iota),
py::arg("builder"), py::arg("type"), py::arg("size"));
ops.def("Map", &Map, py::arg("builder"), py::arg("operands"),
py::arg("computation"), py::arg("dimensions"),
py::arg("static_operands") = py::list());
ops.def("NextAfter", &NextAfter, py::arg("from"), py::arg("to"));
ops.def("OutfeedWithToken", &OutfeedWithToken, py::arg("operand"),
py::arg("token"), py::arg("shape_with_layout"),
py::arg("outfeed_config") = "");
ops.def("Pad", &Pad, py::arg("operand"), py::arg("padding_value"),
py::arg("padding_config"));
ops.def("Parameter",
static_cast<XlaOp (*)(XlaBuilder*, int64, const Shape&,
const std::string&, const std::vector<bool>&)>(
&Parameter),
py::arg("builder"), py::arg("parameter_number"), py::arg("shape"),
py::arg("name") = "",
py::arg("replicated_at_leaf_buffers") = std::vector<bool>());
// Linear-algebra decompositions. Each lambda unpacks the C++ result
// struct into a plain tuple/pair for Python.
ops.def(
"QR",
// Returns (q, r); failures propagate through StatusOr.
[](XlaOp a, bool full_matrices) -> StatusOr<std::pair<XlaOp, XlaOp>> {
TF_ASSIGN_OR_RETURN(auto qr, QRDecomposition(a, full_matrices));
return std::make_pair(qr.q, qr.r);
},
py::arg("operand"), py::arg("full_matrices"));
ops.def(
"Eigh",
// Returns (v, w): eigenvectors and eigenvalues.
[](XlaOp a, bool lower, int64 max_iter,
float epsilon) -> std::pair<XlaOp, XlaOp> {
auto eigh = SelfAdjointEig(a, lower, max_iter, epsilon);
return std::make_pair(eigh.v, eigh.w);
},
py::arg("a"), py::arg("lower") = true, py::arg("max_iter") = 100,
py::arg("epsilon") = 1e-6);
ops.def(
"SVD",
// Returns (u, d, v).
[](XlaOp a, int64 max_iter,
float epsilon) -> std::tuple<XlaOp, XlaOp, XlaOp> {
auto svd = SVD(a, max_iter, epsilon);
return std::make_tuple(svd.u, svd.d, svd.v);
},
py::arg("a"), py::arg("max_iter") = 100, py::arg("epsilon") = 1e-6);
ops.def("Reduce",
static_cast<XlaOp (*)(XlaBuilder*, absl::Span<const XlaOp>,
absl::Span<const XlaOp>, const XlaComputation&,
absl::Span<const int64>)>(&Reduce),
py::arg("builder"), py::arg("operands"), py::arg("init_values"),
py::arg("computation"), py::arg("dimensions_to_reduce"));
ops.def("ReducePrecision", &ReducePrecision, py::arg("operand"),
py::arg("exponent_bits"), py::arg("mantissa_bits"));
ops.def("ReduceWindowWithGeneralPadding", &ReduceWindowWithGeneralPadding,
py::arg("operand"), py::arg("init_value"), py::arg("computation"),
py::arg("window_dimensions"), py::arg("window_strides"),
py::arg("base_dilations"), py::arg("window_dilations"),
py::arg("padding"));
ops.def("ReplicaId", &ReplicaId, py::arg("builder"));
// Two Reshape overloads: with and without an explicit input dimension
// permutation.
ops.def("Reshape",
static_cast<XlaOp (*)(XlaOp, absl::Span<const int64>,
absl::Span<const int64>)>(&Reshape),
py::arg("operand"), py::arg("dimensions"), py::arg("new_sizes"));
ops.def("Reshape",
static_cast<XlaOp (*)(XlaOp, absl::Span<const int64>)>(&Reshape),
py::arg("operand"), py::arg("new_sizes"));
ops.def("Rev", &Rev, py::arg("operand"), py::arg("dimensions"));
ops.def("RngNormal", &RngNormal, py::arg("mu"), py::arg("sigma"),
py::arg("shape"));
ops.def("RngUniform", &RngUniform, py::arg("a"), py::arg("b"),
py::arg("shape"));
ops.def("Scatter", &Scatter, py::arg("input"), py::arg("scatter_indices"),
py::arg("updates"), py::arg("update_computation"),
py::arg("dimension_numbers"), py::arg("indices_are_sorted") = false,
py::arg("unique_indices") = false);
ops.def("Select", &Select, py::arg("pred"), py::arg("on_true"),
py::arg("on_false"));
ops.def("SelectAndScatterWithGeneralPadding",
&SelectAndScatterWithGeneralPadding, py::arg("operand"),
py::arg("select"), py::arg("window_dimensions"),
py::arg("window_strides"), py::arg("padding"), py::arg("source"),
py::arg("init_value"), py::arg("scatter"));
ops.def("Slice", &Slice, py::arg("operand"), py::arg("start_indices"),
py::arg("limit_indices"), py::arg("strides"));
ops.def("SliceInDim", &SliceInDim, py::arg("operand"), py::arg("start_index"),
py::arg("limit_index"), py::arg("stride"), py::arg("dimno"));
// Sort: if no comparator computation is supplied, a default less-than
// comparator is synthesized from the operands' element types.
ops.def(
"Sort",
[](XlaBuilder* builder, absl::Span<const XlaOp> operands,
absl::optional<const XlaComputation*> comparator, int64 dimension,
bool is_stable) -> XlaOp {
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
std::vector<PrimitiveType> operand_types;
for (const auto& operand : operands) {
TF_ASSIGN_OR_RETURN(auto operand_shape, builder->GetShape(operand));
operand_types.push_back(operand_shape.element_type());
}
if (comparator) {
return Sort(operands, **comparator, dimension, is_stable);
} else {
return Sort(operands,
CreateScalarLtComputation(operand_types, builder),
dimension, is_stable);
}
});
},
py::arg("builder"), py::arg("operands"),
py::arg("comparator") = absl::nullopt, py::arg("dimension") = -1,
py::arg("is_stable") = false);
ops.def("TopK", &TopK, py::arg("input"), py::arg("k"));
ops.def("Transpose", &Transpose, py::arg("operand"), py::arg("permutation"));
ops.def("TriangularSolve", &TriangularSolve, py::arg("a"), py::arg("b"),
py::arg("left_side"), py::arg("lower"), py::arg("unit_diagonal"),
py::arg("transpose_a"));
ops.def("Tuple", &Tuple, py::arg("builder"), py::arg("elements"));
ops.def("While", &While, py::arg("condition"), py::arg("body"),
py::arg("init"));
ops.def("Igamma", &Igamma, py::arg("a"), py::arg("x"));
ops.def("Igammac", &Igammac, py::arg("a"), py::arg("x"));
ops.def("IgammaGradA", &IgammaGradA, py::arg("a"), py::arg("x"));
ops.def("RandomGammaGrad", &RandomGammaGrad, py::arg("a"), py::arg("x"));
ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta, py::arg("a"),
py::arg("b"), py::arg("x"));
// Binds a binary op taking (lhs, rhs, broadcast_dimensions): when
// broadcast_dimensions is None the two-argument overload is called.
#define BINARY_OP(op) \
ops.def( \
#op, \
[](XlaOp a, XlaOp b, absl::optional<std::vector<int64>> dims) { \
return dims ? op(a, b, *dims) : op(a, b); \
}, \
py::arg("lhs"), py::arg("rhs"), \
py::arg("broadcast_dimensions") = absl::nullopt)
BINARY_OP(Eq);
BINARY_OP(Ne);
BINARY_OP(Ge);
BINARY_OP(Gt);
BINARY_OP(Lt);
BINARY_OP(Le);
BINARY_OP(Add);
BINARY_OP(Sub);
BINARY_OP(Mul);
BINARY_OP(Div);
BINARY_OP(Rem);
BINARY_OP(Max);
BINARY_OP(Min);
BINARY_OP(And);
BINARY_OP(Or);
BINARY_OP(Xor);
BINARY_OP(ShiftLeft);
BINARY_OP(ShiftRightArithmetic);
BINARY_OP(ShiftRightLogical);
BINARY_OP(Atan2);
BINARY_OP(Pow);
BINARY_OP(Complex);
#undef BINARY_OP
// Binds a unary op directly under its own name (single positional arg).
#define UNARY_OP(op) ops.def(#op, &op)
UNARY_OP(Not);
UNARY_OP(PopulationCount);
UNARY_OP(Clz);
UNARY_OP(Abs);
UNARY_OP(Exp);
UNARY_OP(Expm1);
UNARY_OP(Floor);
UNARY_OP(Ceil);
UNARY_OP(Round);
UNARY_OP(Log);
UNARY_OP(Log1p);
UNARY_OP(Sign);
UNARY_OP(Cos);
UNARY_OP(Sin);
UNARY_OP(Tanh);
UNARY_OP(IsFinite);
UNARY_OP(Neg);
UNARY_OP(Sqrt);
UNARY_OP(Rsqrt);
UNARY_OP(Square);
UNARY_OP(Reciprocal);
UNARY_OP(Erfc);
UNARY_OP(Erf);
UNARY_OP(ErfInv);
UNARY_OP(Lgamma);
UNARY_OP(Digamma);
UNARY_OP(BesselI0e);
UNARY_OP(BesselI1e);
UNARY_OP(Acos);
UNARY_OP(Asin);
UNARY_OP(Atan);
UNARY_OP(Tan);
UNARY_OP(Acosh);
UNARY_OP(Asinh);
UNARY_OP(Atanh);
UNARY_OP(Cosh);
UNARY_OP(Sinh);
UNARY_OP(Real);
UNARY_OP(Imag);
UNARY_OP(Conj);
#undef UNARY_OP
}
} // namespace xla

View File

@ -0,0 +1,27 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_OPS_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_OPS_H_
#include "pybind11/pybind11.h"
namespace xla {
// Registers the "ops" submodule — free functions that add XLA operators to
// an XlaBuilder — on the pybind11 module `m`. Defined in ops.cc.
void BuildOpsSubmodule(pybind11::module* m);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_OPS_H_

View File

@ -30,12 +30,6 @@ limitations under the License.
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/qr.h"
#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
#include "tensorflow/compiler/xla/client/lib/sorting.h"
#include "tensorflow/compiler/xla/client/lib/svd.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
@ -48,6 +42,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/bfloat16.h"
#include "tensorflow/compiler/xla/python/dlpack.h"
#include "tensorflow/compiler/xla/python/ops.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
@ -306,321 +301,6 @@ StatusOr<py::dict> PjRtBufferCudaArrayInterface(const PjRtBuffer& buffer) {
return result;
}
// NOTE(review): this is the original xla.cc copy of BuildOpsSubmodule; the
// diff above removes it in favor of the identical definition in ops.cc.
// Registers the "ops" submodule: free functions that append operators to an
// XlaBuilder.
void BuildOpsSubmodule(py::module* m) {
// ops submodule, containing free functions that add operators to an
// XlaBuilder.
py::module ops = m->def_submodule("ops", "XLA operations");
// Enum for the `transpose_a` argument of TriangularSolve.
py::enum_<TriangularSolveOptions::Transpose>(
ops, "TriangularSolveOptions_Transpose")
.value("TRANSPOSE_INVALID", TriangularSolveOptions::TRANSPOSE_INVALID)
.value("NO_TRANSPOSE", TriangularSolveOptions::NO_TRANSPOSE)
.value("TRANSPOSE", TriangularSolveOptions::TRANSPOSE)
.value("ADJOINT", TriangularSolveOptions::ADJOINT);
ops.def("AfterAll", &AfterAll, py::arg("builder"), py::arg("tokens"));
ops.def(
"AllReduce",
// static_cast disambiguates the intended AllReduce overload.
static_cast<XlaOp (*)(
XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
const absl::optional<ChannelHandle>&, const absl::optional<Shape>&)>(
&AllReduce),
py::arg("operand"), py::arg("computation"),
py::arg("replica_groups") = py::list(),
py::arg("channel_id") = absl::nullopt,
py::arg("shape_with_layout") = absl::nullopt);
ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"),
py::arg("concat_dimension"), py::arg("split_count"),
py::arg("replica_groups") = py::list(),
py::arg("layout") = absl::nullopt);
ops.def("CollectivePermute", &CollectivePermute, py::arg("operand"),
py::arg("source_target_pairs"));
ops.def("CreateToken", &CreateToken, py::arg("builder"));
ops.def("CrossReplicaSum",
static_cast<XlaOp (*)(XlaOp, absl::Span<const ReplicaGroup>)>(
&CrossReplicaSum),
py::arg("operand"), py::arg("replica_groups") = py::list());
ops.def("BitcastConvertType", &BitcastConvertType, py::arg("operand"),
py::arg("new_element_type"));
ops.def("Broadcast", &Broadcast, py::arg("operand"), py::arg("sizes"));
ops.def("BroadcastInDim", &BroadcastInDim, py::arg("operand"),
py::arg("shape"), py::arg("broadcast_dimensions"));
ops.def("Call", &Call, py::arg("builder"), py::arg("computation"),
py::arg("operands"));
ops.def("Cholesky", &Cholesky, py::arg("a"), py::arg("lower") = true);
ops.def("Clamp", &Clamp, py::arg("min"), py::arg("operand"), py::arg("max"));
ops.def("Collapse", &Collapse, py::arg("operand"), py::arg("dimensions"));
ops.def("ConcatInDim", &ConcatInDim, py::arg("builder"), py::arg("operands"),
py::arg("dimension"));
// N-way (branch-index) and boolean-predicate Conditional overloads.
ops.def("Conditional",
static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaComputation* const>,
absl::Span<const XlaOp>)>(&Conditional),
py::arg("branch_index"), py::arg("branch_computations"),
py::arg("branch_operands"));
ops.def("Conditional",
static_cast<XlaOp (*)(XlaOp, XlaOp, const XlaComputation&, XlaOp,
const XlaComputation&)>(&Conditional),
py::arg("predicate"), py::arg("true_operand"),
py::arg("true_computation"), py::arg("false_operand"),
py::arg("false_computation"));
// "Constant" is an alias of "ConstantLiteral".
ops.def("Constant", &ConstantLiteral, py::arg("builder"), py::arg("literal"));
ops.def("ConstantLiteral", &ConstantLiteral, py::arg("builder"),
py::arg("literal"));
ops.def("ConvGeneralDilated", &ConvGeneralDilated, py::arg("lhs"),
py::arg("rhs"), py::arg("window_strides"), py::arg("padding"),
py::arg("lhs_dilation"), py::arg("rhs_dilation"),
py::arg("dimension_numbers"), py::arg("feature_group_count") = 1,
py::arg("batch_group_count") = 1,
py::arg("precision_config") = nullptr);
ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"),
py::arg("new_element_type"));
// CustomCall wrappers accept name/opaque as raw bytes.
ops.def(
"CustomCall",
[](XlaBuilder* builder, const py::bytes& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape,
const py::bytes& opaque) -> XlaOp {
return CustomCall(builder, call_target_name, operands, shape, opaque);
},
py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
py::arg("shape"), py::arg("opaque") = py::bytes(""));
ops.def(
"CustomCallWithLayout",
[](XlaBuilder* builder, const py::bytes& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
absl::Span<const Shape> operand_shapes_with_layout,
const py::bytes& opaque) -> XlaOp {
return CustomCallWithLayout(builder, call_target_name, operands,
shape_with_layout,
operand_shapes_with_layout, opaque);
},
py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"),
py::arg("opaque") = py::bytes(""));
ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"),
py::arg("precision_config") = nullptr);
ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"),
py::arg("dimension_numbers"), py::arg("precision_config") = nullptr);
ops.def("DynamicSlice",
static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaOp>,
absl::Span<const int64>)>(&DynamicSlice),
py::arg("operand"), py::arg("start_indices"), py::arg("slice_sizes"));
ops.def("DynamicUpdateSlice",
static_cast<XlaOp (*)(XlaOp, XlaOp, absl::Span<const XlaOp>)>(
&DynamicUpdateSlice),
py::arg("operand"), py::arg("update"), py::arg("start_indices"));
ops.def("Fft", &Fft, py::arg("operand"), py::arg("fft_type"),
py::arg("fft_length"));
ops.def("Gather", &Gather, py::arg("a"), py::arg("start_indices"),
py::arg("dimension_numbers"), py::arg("slice_sizes"),
py::arg("indices_are_sorted") = false);
ops.def("GetTupleElement", &GetTupleElement, py::arg("tuple_data"),
py::arg("index"));
ops.def("InfeedWithToken", &InfeedWithToken, py::arg("token"),
py::arg("shape"), py::arg("config") = "");
ops.def("Iota",
static_cast<XlaOp (*)(XlaBuilder*, const Shape&, int64)>(&Iota),
py::arg("builder"), py::arg("shape"), py::arg("iota_dimension"));
ops.def("Iota",
static_cast<XlaOp (*)(XlaBuilder*, PrimitiveType, int64)>(&Iota),
py::arg("builder"), py::arg("type"), py::arg("size"));
ops.def("Map", &Map, py::arg("builder"), py::arg("operands"),
py::arg("computation"), py::arg("dimensions"),
py::arg("static_operands") = py::list());
ops.def("NextAfter", &NextAfter, py::arg("from"), py::arg("to"));
ops.def("OutfeedWithToken", &OutfeedWithToken, py::arg("operand"),
py::arg("token"), py::arg("shape_with_layout"),
py::arg("outfeed_config") = "");
ops.def("Pad", &Pad, py::arg("operand"), py::arg("padding_value"),
py::arg("padding_config"));
ops.def("Parameter",
static_cast<XlaOp (*)(XlaBuilder*, int64, const Shape&,
const std::string&, const std::vector<bool>&)>(
&Parameter),
py::arg("builder"), py::arg("parameter_number"), py::arg("shape"),
py::arg("name") = "",
py::arg("replicated_at_leaf_buffers") = std::vector<bool>());
// Decompositions: lambdas unpack result structs into Python tuples.
ops.def(
"QR",
[](XlaOp a, bool full_matrices) -> StatusOr<std::pair<XlaOp, XlaOp>> {
TF_ASSIGN_OR_RETURN(auto qr, QRDecomposition(a, full_matrices));
return std::make_pair(qr.q, qr.r);
},
py::arg("operand"), py::arg("full_matrices"));
ops.def(
"Eigh",
[](XlaOp a, bool lower, int64 max_iter,
float epsilon) -> std::pair<XlaOp, XlaOp> {
auto eigh = SelfAdjointEig(a, lower, max_iter, epsilon);
return std::make_pair(eigh.v, eigh.w);
},
py::arg("a"), py::arg("lower") = true, py::arg("max_iter") = 100,
py::arg("epsilon") = 1e-6);
ops.def(
"SVD",
[](XlaOp a, int64 max_iter,
float epsilon) -> std::tuple<XlaOp, XlaOp, XlaOp> {
auto svd = SVD(a, max_iter, epsilon);
return std::make_tuple(svd.u, svd.d, svd.v);
},
py::arg("a"), py::arg("max_iter") = 100, py::arg("epsilon") = 1e-6);
ops.def("Reduce",
static_cast<XlaOp (*)(XlaBuilder*, absl::Span<const XlaOp>,
absl::Span<const XlaOp>, const XlaComputation&,
absl::Span<const int64>)>(&Reduce),
py::arg("builder"), py::arg("operands"), py::arg("init_values"),
py::arg("computation"), py::arg("dimensions_to_reduce"));
ops.def("ReducePrecision", &ReducePrecision, py::arg("operand"),
py::arg("exponent_bits"), py::arg("mantissa_bits"));
ops.def("ReduceWindowWithGeneralPadding", &ReduceWindowWithGeneralPadding,
py::arg("operand"), py::arg("init_value"), py::arg("computation"),
py::arg("window_dimensions"), py::arg("window_strides"),
py::arg("base_dilations"), py::arg("window_dilations"),
py::arg("padding"));
ops.def("ReplicaId", &ReplicaId, py::arg("builder"));
ops.def("Reshape",
static_cast<XlaOp (*)(XlaOp, absl::Span<const int64>,
absl::Span<const int64>)>(&Reshape),
py::arg("operand"), py::arg("dimensions"), py::arg("new_sizes"));
ops.def("Reshape",
static_cast<XlaOp (*)(XlaOp, absl::Span<const int64>)>(&Reshape),
py::arg("operand"), py::arg("new_sizes"));
ops.def("Rev", &Rev, py::arg("operand"), py::arg("dimensions"));
ops.def("RngNormal", &RngNormal, py::arg("mu"), py::arg("sigma"),
py::arg("shape"));
ops.def("RngUniform", &RngUniform, py::arg("a"), py::arg("b"),
py::arg("shape"));
ops.def("Scatter", &Scatter, py::arg("input"), py::arg("scatter_indices"),
py::arg("updates"), py::arg("update_computation"),
py::arg("dimension_numbers"), py::arg("indices_are_sorted") = false,
py::arg("unique_indices") = false);
ops.def("Select", &Select, py::arg("pred"), py::arg("on_true"),
py::arg("on_false"));
ops.def("SelectAndScatterWithGeneralPadding",
&SelectAndScatterWithGeneralPadding, py::arg("operand"),
py::arg("select"), py::arg("window_dimensions"),
py::arg("window_strides"), py::arg("padding"), py::arg("source"),
py::arg("init_value"), py::arg("scatter"));
ops.def("Slice", &Slice, py::arg("operand"), py::arg("start_indices"),
py::arg("limit_indices"), py::arg("strides"));
ops.def("SliceInDim", &SliceInDim, py::arg("operand"), py::arg("start_index"),
py::arg("limit_index"), py::arg("stride"), py::arg("dimno"));
// Sort: synthesizes a default less-than comparator when none is given.
ops.def(
"Sort",
[](XlaBuilder* builder, absl::Span<const XlaOp> operands,
absl::optional<const XlaComputation*> comparator, int64 dimension,
bool is_stable) -> XlaOp {
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
std::vector<PrimitiveType> operand_types;
for (const auto& operand : operands) {
TF_ASSIGN_OR_RETURN(auto operand_shape, builder->GetShape(operand));
operand_types.push_back(operand_shape.element_type());
}
if (comparator) {
return Sort(operands, **comparator, dimension, is_stable);
} else {
return Sort(operands,
CreateScalarLtComputation(operand_types, builder),
dimension, is_stable);
}
});
},
py::arg("builder"), py::arg("operands"),
py::arg("comparator") = absl::nullopt, py::arg("dimension") = -1,
py::arg("is_stable") = false);
ops.def("TopK", &TopK, py::arg("input"), py::arg("k"));
ops.def("Transpose", &Transpose, py::arg("operand"), py::arg("permutation"));
ops.def("TriangularSolve", &TriangularSolve, py::arg("a"), py::arg("b"),
py::arg("left_side"), py::arg("lower"), py::arg("unit_diagonal"),
py::arg("transpose_a"));
ops.def("Tuple", &Tuple, py::arg("builder"), py::arg("elements"));
ops.def("While", &While, py::arg("condition"), py::arg("body"),
py::arg("init"));
ops.def("Igamma", &Igamma, py::arg("a"), py::arg("x"));
ops.def("Igammac", &Igammac, py::arg("a"), py::arg("x"));
ops.def("IgammaGradA", &IgammaGradA, py::arg("a"), py::arg("x"));
ops.def("RandomGammaGrad", &RandomGammaGrad, py::arg("a"), py::arg("x"));
ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta, py::arg("a"),
py::arg("b"), py::arg("x"));
// Binary op with optional broadcast_dimensions argument.
#define BINARY_OP(op) \
ops.def( \
#op, \
[](XlaOp a, XlaOp b, absl::optional<std::vector<int64>> dims) { \
return dims ? op(a, b, *dims) : op(a, b); \
}, \
py::arg("lhs"), py::arg("rhs"), \
py::arg("broadcast_dimensions") = absl::nullopt)
BINARY_OP(Eq);
BINARY_OP(Ne);
BINARY_OP(Ge);
BINARY_OP(Gt);
BINARY_OP(Lt);
BINARY_OP(Le);
BINARY_OP(Add);
BINARY_OP(Sub);
BINARY_OP(Mul);
BINARY_OP(Div);
BINARY_OP(Rem);
BINARY_OP(Max);
BINARY_OP(Min);
BINARY_OP(And);
BINARY_OP(Or);
BINARY_OP(Xor);
BINARY_OP(ShiftLeft);
BINARY_OP(ShiftRightArithmetic);
BINARY_OP(ShiftRightLogical);
BINARY_OP(Atan2);
BINARY_OP(Pow);
BINARY_OP(Complex);
#undef BINARY_OP
// Unary op bound directly under its own name.
#define UNARY_OP(op) ops.def(#op, &op)
UNARY_OP(Not);
UNARY_OP(PopulationCount);
UNARY_OP(Clz);
UNARY_OP(Abs);
UNARY_OP(Exp);
UNARY_OP(Expm1);
UNARY_OP(Floor);
UNARY_OP(Ceil);
UNARY_OP(Round);
UNARY_OP(Log);
UNARY_OP(Log1p);
UNARY_OP(Sign);
UNARY_OP(Cos);
UNARY_OP(Sin);
UNARY_OP(Tanh);
UNARY_OP(IsFinite);
UNARY_OP(Neg);
UNARY_OP(Sqrt);
UNARY_OP(Rsqrt);
UNARY_OP(Square);
UNARY_OP(Reciprocal);
UNARY_OP(Erfc);
UNARY_OP(Erf);
UNARY_OP(ErfInv);
UNARY_OP(Lgamma);
UNARY_OP(Digamma);
UNARY_OP(BesselI0e);
UNARY_OP(BesselI1e);
UNARY_OP(Acos);
UNARY_OP(Asin);
UNARY_OP(Atan);
UNARY_OP(Tan);
UNARY_OP(Acosh);
UNARY_OP(Asinh);
UNARY_OP(Atanh);
UNARY_OP(Cosh);
UNARY_OP(Sinh);
UNARY_OP(Real);
UNARY_OP(Imag);
UNARY_OP(Conj);
#undef UNARY_OP
}
void BuildProfilerSubmodule(py::module* m) {
py::module profiler =