[XLA:Python] Split bindings for XLA ops into a separate file. No functional changes.
This is partially to make xla.cc shorter and partially to parallelize its build time. PiperOrigin-RevId: 313307447 Change-Id: I4f6de5723dbef4464599813bc9284b4ac9e271d7
This commit is contained in:
parent
05653928da
commit
518423dd27
@ -186,6 +186,32 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "ops",
|
||||
srcs = ["ops.cc"],
|
||||
hdrs = ["ops.h"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
":types",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/compiler/xla/client/lib:comparators",
|
||||
"//tensorflow/compiler/xla/client/lib:math",
|
||||
"//tensorflow/compiler/xla/client/lib:qr",
|
||||
"//tensorflow/compiler/xla/client/lib:self_adjoint_eig",
|
||||
"//tensorflow/compiler/xla/client/lib:sorting",
|
||||
"//tensorflow/compiler/xla/client/lib:svd",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "enable_gpu",
|
||||
values = {"define": "xla_python_enable_gpu=true"},
|
||||
@ -205,6 +231,7 @@ pybind_extension(
|
||||
deps = [
|
||||
":bfloat16",
|
||||
":dlpack",
|
||||
":ops",
|
||||
":python_ref_manager",
|
||||
":types",
|
||||
"@com_google_absl//absl/base",
|
||||
@ -228,12 +255,6 @@ pybind_extension(
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/compiler/xla/client/lib:comparators",
|
||||
"//tensorflow/compiler/xla/client/lib:math",
|
||||
"//tensorflow/compiler/xla/client/lib:qr",
|
||||
"//tensorflow/compiler/xla/client/lib:self_adjoint_eig",
|
||||
"//tensorflow/compiler/xla/client/lib:sorting",
|
||||
"//tensorflow/compiler/xla/client/lib:svd",
|
||||
"//tensorflow/compiler/xla/pjrt:cpu_device",
|
||||
"//tensorflow/compiler/xla/pjrt:nvidia_gpu_device",
|
||||
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
||||
|
356
tensorflow/compiler/xla/python/ops.cc
Normal file
356
tensorflow/compiler/xla/python/ops.cc
Normal file
@ -0,0 +1,356 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/python/ops.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "pybind11/attr.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/comparators.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/math.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/qr.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/sorting.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/svd.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
#include "tensorflow/compiler/xla/python/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void BuildOpsSubmodule(py::module* m) {
|
||||
// ops submodule, containing free functions that add operators to an
|
||||
// XlaBuilder.
|
||||
py::module ops = m->def_submodule("ops", "XLA operations");
|
||||
|
||||
py::enum_<TriangularSolveOptions::Transpose>(
|
||||
ops, "TriangularSolveOptions_Transpose")
|
||||
.value("TRANSPOSE_INVALID", TriangularSolveOptions::TRANSPOSE_INVALID)
|
||||
.value("NO_TRANSPOSE", TriangularSolveOptions::NO_TRANSPOSE)
|
||||
.value("TRANSPOSE", TriangularSolveOptions::TRANSPOSE)
|
||||
.value("ADJOINT", TriangularSolveOptions::ADJOINT);
|
||||
|
||||
ops.def("AfterAll", &AfterAll, py::arg("builder"), py::arg("tokens"));
|
||||
ops.def(
|
||||
"AllReduce",
|
||||
static_cast<XlaOp (*)(
|
||||
XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
|
||||
const absl::optional<ChannelHandle>&, const absl::optional<Shape>&)>(
|
||||
&AllReduce),
|
||||
py::arg("operand"), py::arg("computation"),
|
||||
py::arg("replica_groups") = py::list(),
|
||||
py::arg("channel_id") = absl::nullopt,
|
||||
py::arg("shape_with_layout") = absl::nullopt);
|
||||
ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"),
|
||||
py::arg("concat_dimension"), py::arg("split_count"),
|
||||
py::arg("replica_groups") = py::list(),
|
||||
py::arg("layout") = absl::nullopt);
|
||||
ops.def("CollectivePermute", &CollectivePermute, py::arg("operand"),
|
||||
py::arg("source_target_pairs"));
|
||||
ops.def("CreateToken", &CreateToken, py::arg("builder"));
|
||||
ops.def("CrossReplicaSum",
|
||||
static_cast<XlaOp (*)(XlaOp, absl::Span<const ReplicaGroup>)>(
|
||||
&CrossReplicaSum),
|
||||
py::arg("operand"), py::arg("replica_groups") = py::list());
|
||||
ops.def("BitcastConvertType", &BitcastConvertType, py::arg("operand"),
|
||||
py::arg("new_element_type"));
|
||||
ops.def("Broadcast", &Broadcast, py::arg("operand"), py::arg("sizes"));
|
||||
ops.def("BroadcastInDim", &BroadcastInDim, py::arg("operand"),
|
||||
py::arg("shape"), py::arg("broadcast_dimensions"));
|
||||
ops.def("Call", &Call, py::arg("builder"), py::arg("computation"),
|
||||
py::arg("operands"));
|
||||
ops.def("Cholesky", &Cholesky, py::arg("a"), py::arg("lower") = true);
|
||||
ops.def("Clamp", &Clamp, py::arg("min"), py::arg("operand"), py::arg("max"));
|
||||
ops.def("Collapse", &Collapse, py::arg("operand"), py::arg("dimensions"));
|
||||
ops.def("ConcatInDim", &ConcatInDim, py::arg("builder"), py::arg("operands"),
|
||||
py::arg("dimension"));
|
||||
ops.def("Conditional",
|
||||
static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaComputation* const>,
|
||||
absl::Span<const XlaOp>)>(&Conditional),
|
||||
py::arg("branch_index"), py::arg("branch_computations"),
|
||||
py::arg("branch_operands"));
|
||||
ops.def("Conditional",
|
||||
static_cast<XlaOp (*)(XlaOp, XlaOp, const XlaComputation&, XlaOp,
|
||||
const XlaComputation&)>(&Conditional),
|
||||
py::arg("predicate"), py::arg("true_operand"),
|
||||
py::arg("true_computation"), py::arg("false_operand"),
|
||||
py::arg("false_computation"));
|
||||
ops.def("Constant", &ConstantLiteral, py::arg("builder"), py::arg("literal"));
|
||||
ops.def("ConstantLiteral", &ConstantLiteral, py::arg("builder"),
|
||||
py::arg("literal"));
|
||||
ops.def("ConvGeneralDilated", &ConvGeneralDilated, py::arg("lhs"),
|
||||
py::arg("rhs"), py::arg("window_strides"), py::arg("padding"),
|
||||
py::arg("lhs_dilation"), py::arg("rhs_dilation"),
|
||||
py::arg("dimension_numbers"), py::arg("feature_group_count") = 1,
|
||||
py::arg("batch_group_count") = 1,
|
||||
py::arg("precision_config") = nullptr);
|
||||
ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"),
|
||||
py::arg("new_element_type"));
|
||||
ops.def(
|
||||
"CustomCall",
|
||||
[](XlaBuilder* builder, const py::bytes& call_target_name,
|
||||
absl::Span<const XlaOp> operands, const Shape& shape,
|
||||
const py::bytes& opaque) -> XlaOp {
|
||||
return CustomCall(builder, call_target_name, operands, shape, opaque);
|
||||
},
|
||||
py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
|
||||
py::arg("shape"), py::arg("opaque") = py::bytes(""));
|
||||
ops.def(
|
||||
"CustomCallWithLayout",
|
||||
[](XlaBuilder* builder, const py::bytes& call_target_name,
|
||||
absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
|
||||
absl::Span<const Shape> operand_shapes_with_layout,
|
||||
const py::bytes& opaque) -> XlaOp {
|
||||
return CustomCallWithLayout(builder, call_target_name, operands,
|
||||
shape_with_layout,
|
||||
operand_shapes_with_layout, opaque);
|
||||
},
|
||||
py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
|
||||
py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"),
|
||||
py::arg("opaque") = py::bytes(""));
|
||||
ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("precision_config") = nullptr);
|
||||
ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("dimension_numbers"), py::arg("precision_config") = nullptr);
|
||||
ops.def("DynamicSlice",
|
||||
static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaOp>,
|
||||
absl::Span<const int64>)>(&DynamicSlice),
|
||||
py::arg("operand"), py::arg("start_indices"), py::arg("slice_sizes"));
|
||||
ops.def("DynamicUpdateSlice",
|
||||
static_cast<XlaOp (*)(XlaOp, XlaOp, absl::Span<const XlaOp>)>(
|
||||
&DynamicUpdateSlice),
|
||||
py::arg("operand"), py::arg("update"), py::arg("start_indices"));
|
||||
|
||||
ops.def("Fft", &Fft, py::arg("operand"), py::arg("fft_type"),
|
||||
py::arg("fft_length"));
|
||||
|
||||
ops.def("Gather", &Gather, py::arg("a"), py::arg("start_indices"),
|
||||
py::arg("dimension_numbers"), py::arg("slice_sizes"),
|
||||
py::arg("indices_are_sorted") = false);
|
||||
ops.def("GetTupleElement", &GetTupleElement, py::arg("tuple_data"),
|
||||
py::arg("index"));
|
||||
ops.def("InfeedWithToken", &InfeedWithToken, py::arg("token"),
|
||||
py::arg("shape"), py::arg("config") = "");
|
||||
ops.def("Iota",
|
||||
static_cast<XlaOp (*)(XlaBuilder*, const Shape&, int64)>(&Iota),
|
||||
py::arg("builder"), py::arg("shape"), py::arg("iota_dimension"));
|
||||
ops.def("Iota",
|
||||
static_cast<XlaOp (*)(XlaBuilder*, PrimitiveType, int64)>(&Iota),
|
||||
py::arg("builder"), py::arg("type"), py::arg("size"));
|
||||
ops.def("Map", &Map, py::arg("builder"), py::arg("operands"),
|
||||
py::arg("computation"), py::arg("dimensions"),
|
||||
py::arg("static_operands") = py::list());
|
||||
ops.def("NextAfter", &NextAfter, py::arg("from"), py::arg("to"));
|
||||
ops.def("OutfeedWithToken", &OutfeedWithToken, py::arg("operand"),
|
||||
py::arg("token"), py::arg("shape_with_layout"),
|
||||
py::arg("outfeed_config") = "");
|
||||
ops.def("Pad", &Pad, py::arg("operand"), py::arg("padding_value"),
|
||||
py::arg("padding_config"));
|
||||
ops.def("Parameter",
|
||||
static_cast<XlaOp (*)(XlaBuilder*, int64, const Shape&,
|
||||
const std::string&, const std::vector<bool>&)>(
|
||||
&Parameter),
|
||||
py::arg("builder"), py::arg("parameter_number"), py::arg("shape"),
|
||||
py::arg("name") = "",
|
||||
py::arg("replicated_at_leaf_buffers") = std::vector<bool>());
|
||||
ops.def(
|
||||
"QR",
|
||||
[](XlaOp a, bool full_matrices) -> StatusOr<std::pair<XlaOp, XlaOp>> {
|
||||
TF_ASSIGN_OR_RETURN(auto qr, QRDecomposition(a, full_matrices));
|
||||
return std::make_pair(qr.q, qr.r);
|
||||
},
|
||||
py::arg("operand"), py::arg("full_matrices"));
|
||||
ops.def(
|
||||
"Eigh",
|
||||
[](XlaOp a, bool lower, int64 max_iter,
|
||||
float epsilon) -> std::pair<XlaOp, XlaOp> {
|
||||
auto eigh = SelfAdjointEig(a, lower, max_iter, epsilon);
|
||||
return std::make_pair(eigh.v, eigh.w);
|
||||
},
|
||||
py::arg("a"), py::arg("lower") = true, py::arg("max_iter") = 100,
|
||||
py::arg("epsilon") = 1e-6);
|
||||
ops.def(
|
||||
"SVD",
|
||||
[](XlaOp a, int64 max_iter,
|
||||
float epsilon) -> std::tuple<XlaOp, XlaOp, XlaOp> {
|
||||
auto svd = SVD(a, max_iter, epsilon);
|
||||
return std::make_tuple(svd.u, svd.d, svd.v);
|
||||
},
|
||||
py::arg("a"), py::arg("max_iter") = 100, py::arg("epsilon") = 1e-6);
|
||||
ops.def("Reduce",
|
||||
static_cast<XlaOp (*)(XlaBuilder*, absl::Span<const XlaOp>,
|
||||
absl::Span<const XlaOp>, const XlaComputation&,
|
||||
absl::Span<const int64>)>(&Reduce),
|
||||
py::arg("builder"), py::arg("operands"), py::arg("init_values"),
|
||||
py::arg("computation"), py::arg("dimensions_to_reduce"));
|
||||
ops.def("ReducePrecision", &ReducePrecision, py::arg("operand"),
|
||||
py::arg("exponent_bits"), py::arg("mantissa_bits"));
|
||||
ops.def("ReduceWindowWithGeneralPadding", &ReduceWindowWithGeneralPadding,
|
||||
py::arg("operand"), py::arg("init_value"), py::arg("computation"),
|
||||
py::arg("window_dimensions"), py::arg("window_strides"),
|
||||
py::arg("base_dilations"), py::arg("window_dilations"),
|
||||
py::arg("padding"));
|
||||
ops.def("ReplicaId", &ReplicaId, py::arg("builder"));
|
||||
ops.def("Reshape",
|
||||
static_cast<XlaOp (*)(XlaOp, absl::Span<const int64>,
|
||||
absl::Span<const int64>)>(&Reshape),
|
||||
py::arg("operand"), py::arg("dimensions"), py::arg("new_sizes"));
|
||||
ops.def("Reshape",
|
||||
static_cast<XlaOp (*)(XlaOp, absl::Span<const int64>)>(&Reshape),
|
||||
py::arg("operand"), py::arg("new_sizes"));
|
||||
ops.def("Rev", &Rev, py::arg("operand"), py::arg("dimensions"));
|
||||
ops.def("RngNormal", &RngNormal, py::arg("mu"), py::arg("sigma"),
|
||||
py::arg("shape"));
|
||||
ops.def("RngUniform", &RngUniform, py::arg("a"), py::arg("b"),
|
||||
py::arg("shape"));
|
||||
ops.def("Scatter", &Scatter, py::arg("input"), py::arg("scatter_indices"),
|
||||
py::arg("updates"), py::arg("update_computation"),
|
||||
py::arg("dimension_numbers"), py::arg("indices_are_sorted") = false,
|
||||
py::arg("unique_indices") = false);
|
||||
ops.def("Select", &Select, py::arg("pred"), py::arg("on_true"),
|
||||
py::arg("on_false"));
|
||||
ops.def("SelectAndScatterWithGeneralPadding",
|
||||
&SelectAndScatterWithGeneralPadding, py::arg("operand"),
|
||||
py::arg("select"), py::arg("window_dimensions"),
|
||||
py::arg("window_strides"), py::arg("padding"), py::arg("source"),
|
||||
py::arg("init_value"), py::arg("scatter"));
|
||||
ops.def("Slice", &Slice, py::arg("operand"), py::arg("start_indices"),
|
||||
py::arg("limit_indices"), py::arg("strides"));
|
||||
ops.def("SliceInDim", &SliceInDim, py::arg("operand"), py::arg("start_index"),
|
||||
py::arg("limit_index"), py::arg("stride"), py::arg("dimno"));
|
||||
ops.def(
|
||||
"Sort",
|
||||
[](XlaBuilder* builder, absl::Span<const XlaOp> operands,
|
||||
absl::optional<const XlaComputation*> comparator, int64 dimension,
|
||||
bool is_stable) -> XlaOp {
|
||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
std::vector<PrimitiveType> operand_types;
|
||||
for (const auto& operand : operands) {
|
||||
TF_ASSIGN_OR_RETURN(auto operand_shape, builder->GetShape(operand));
|
||||
operand_types.push_back(operand_shape.element_type());
|
||||
}
|
||||
|
||||
if (comparator) {
|
||||
return Sort(operands, **comparator, dimension, is_stable);
|
||||
} else {
|
||||
return Sort(operands,
|
||||
CreateScalarLtComputation(operand_types, builder),
|
||||
dimension, is_stable);
|
||||
}
|
||||
});
|
||||
},
|
||||
py::arg("builder"), py::arg("operands"),
|
||||
py::arg("comparator") = absl::nullopt, py::arg("dimension") = -1,
|
||||
py::arg("is_stable") = false);
|
||||
ops.def("TopK", &TopK, py::arg("input"), py::arg("k"));
|
||||
ops.def("Transpose", &Transpose, py::arg("operand"), py::arg("permutation"));
|
||||
ops.def("TriangularSolve", &TriangularSolve, py::arg("a"), py::arg("b"),
|
||||
py::arg("left_side"), py::arg("lower"), py::arg("unit_diagonal"),
|
||||
py::arg("transpose_a"));
|
||||
ops.def("Tuple", &Tuple, py::arg("builder"), py::arg("elements"));
|
||||
ops.def("While", &While, py::arg("condition"), py::arg("body"),
|
||||
py::arg("init"));
|
||||
|
||||
ops.def("Igamma", &Igamma, py::arg("a"), py::arg("x"));
|
||||
ops.def("Igammac", &Igammac, py::arg("a"), py::arg("x"));
|
||||
ops.def("IgammaGradA", &IgammaGradA, py::arg("a"), py::arg("x"));
|
||||
ops.def("RandomGammaGrad", &RandomGammaGrad, py::arg("a"), py::arg("x"));
|
||||
ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta, py::arg("a"),
|
||||
py::arg("b"), py::arg("x"));
|
||||
|
||||
#define BINARY_OP(op) \
|
||||
ops.def( \
|
||||
#op, \
|
||||
[](XlaOp a, XlaOp b, absl::optional<std::vector<int64>> dims) { \
|
||||
return dims ? op(a, b, *dims) : op(a, b); \
|
||||
}, \
|
||||
py::arg("lhs"), py::arg("rhs"), \
|
||||
py::arg("broadcast_dimensions") = absl::nullopt)
|
||||
BINARY_OP(Eq);
|
||||
BINARY_OP(Ne);
|
||||
BINARY_OP(Ge);
|
||||
BINARY_OP(Gt);
|
||||
BINARY_OP(Lt);
|
||||
BINARY_OP(Le);
|
||||
BINARY_OP(Add);
|
||||
BINARY_OP(Sub);
|
||||
BINARY_OP(Mul);
|
||||
BINARY_OP(Div);
|
||||
BINARY_OP(Rem);
|
||||
BINARY_OP(Max);
|
||||
BINARY_OP(Min);
|
||||
BINARY_OP(And);
|
||||
BINARY_OP(Or);
|
||||
BINARY_OP(Xor);
|
||||
BINARY_OP(ShiftLeft);
|
||||
BINARY_OP(ShiftRightArithmetic);
|
||||
BINARY_OP(ShiftRightLogical);
|
||||
BINARY_OP(Atan2);
|
||||
BINARY_OP(Pow);
|
||||
BINARY_OP(Complex);
|
||||
#undef BINARY_OP
|
||||
|
||||
#define UNARY_OP(op) ops.def(#op, &op)
|
||||
UNARY_OP(Not);
|
||||
UNARY_OP(PopulationCount);
|
||||
UNARY_OP(Clz);
|
||||
UNARY_OP(Abs);
|
||||
UNARY_OP(Exp);
|
||||
UNARY_OP(Expm1);
|
||||
UNARY_OP(Floor);
|
||||
UNARY_OP(Ceil);
|
||||
UNARY_OP(Round);
|
||||
UNARY_OP(Log);
|
||||
UNARY_OP(Log1p);
|
||||
UNARY_OP(Sign);
|
||||
UNARY_OP(Cos);
|
||||
UNARY_OP(Sin);
|
||||
UNARY_OP(Tanh);
|
||||
UNARY_OP(IsFinite);
|
||||
UNARY_OP(Neg);
|
||||
UNARY_OP(Sqrt);
|
||||
UNARY_OP(Rsqrt);
|
||||
UNARY_OP(Square);
|
||||
UNARY_OP(Reciprocal);
|
||||
UNARY_OP(Erfc);
|
||||
UNARY_OP(Erf);
|
||||
UNARY_OP(ErfInv);
|
||||
UNARY_OP(Lgamma);
|
||||
UNARY_OP(Digamma);
|
||||
UNARY_OP(BesselI0e);
|
||||
UNARY_OP(BesselI1e);
|
||||
UNARY_OP(Acos);
|
||||
UNARY_OP(Asin);
|
||||
UNARY_OP(Atan);
|
||||
UNARY_OP(Tan);
|
||||
UNARY_OP(Acosh);
|
||||
UNARY_OP(Asinh);
|
||||
UNARY_OP(Atanh);
|
||||
UNARY_OP(Cosh);
|
||||
UNARY_OP(Sinh);
|
||||
UNARY_OP(Real);
|
||||
UNARY_OP(Imag);
|
||||
UNARY_OP(Conj);
|
||||
#undef UNARY_OP
|
||||
}
|
||||
|
||||
} // namespace xla
|
27
tensorflow/compiler/xla/python/ops.h
Normal file
27
tensorflow/compiler/xla/python/ops.h
Normal file
@ -0,0 +1,27 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_OPS_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PYTHON_OPS_H_
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
void BuildOpsSubmodule(pybind11::module* m);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_OPS_H_
|
@ -30,12 +30,6 @@ limitations under the License.
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/pytypes.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/comparators.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/math.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/qr.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/sorting.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/svd.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
@ -48,6 +42,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||
#include "tensorflow/compiler/xla/python/bfloat16.h"
|
||||
#include "tensorflow/compiler/xla/python/dlpack.h"
|
||||
#include "tensorflow/compiler/xla/python/ops.h"
|
||||
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
|
||||
#include "tensorflow/compiler/xla/python/types.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
|
||||
@ -306,321 +301,6 @@ StatusOr<py::dict> PjRtBufferCudaArrayInterface(const PjRtBuffer& buffer) {
|
||||
return result;
|
||||
}
|
||||
|
||||
void BuildOpsSubmodule(py::module* m) {
|
||||
// ops submodule, containing free functions that add operators to an
|
||||
// XlaBuilder.
|
||||
py::module ops = m->def_submodule("ops", "XLA operations");
|
||||
|
||||
py::enum_<TriangularSolveOptions::Transpose>(
|
||||
ops, "TriangularSolveOptions_Transpose")
|
||||
.value("TRANSPOSE_INVALID", TriangularSolveOptions::TRANSPOSE_INVALID)
|
||||
.value("NO_TRANSPOSE", TriangularSolveOptions::NO_TRANSPOSE)
|
||||
.value("TRANSPOSE", TriangularSolveOptions::TRANSPOSE)
|
||||
.value("ADJOINT", TriangularSolveOptions::ADJOINT);
|
||||
|
||||
ops.def("AfterAll", &AfterAll, py::arg("builder"), py::arg("tokens"));
|
||||
ops.def(
|
||||
"AllReduce",
|
||||
static_cast<XlaOp (*)(
|
||||
XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
|
||||
const absl::optional<ChannelHandle>&, const absl::optional<Shape>&)>(
|
||||
&AllReduce),
|
||||
py::arg("operand"), py::arg("computation"),
|
||||
py::arg("replica_groups") = py::list(),
|
||||
py::arg("channel_id") = absl::nullopt,
|
||||
py::arg("shape_with_layout") = absl::nullopt);
|
||||
ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"),
|
||||
py::arg("concat_dimension"), py::arg("split_count"),
|
||||
py::arg("replica_groups") = py::list(),
|
||||
py::arg("layout") = absl::nullopt);
|
||||
ops.def("CollectivePermute", &CollectivePermute, py::arg("operand"),
|
||||
py::arg("source_target_pairs"));
|
||||
ops.def("CreateToken", &CreateToken, py::arg("builder"));
|
||||
ops.def("CrossReplicaSum",
|
||||
static_cast<XlaOp (*)(XlaOp, absl::Span<const ReplicaGroup>)>(
|
||||
&CrossReplicaSum),
|
||||
py::arg("operand"), py::arg("replica_groups") = py::list());
|
||||
ops.def("BitcastConvertType", &BitcastConvertType, py::arg("operand"),
|
||||
py::arg("new_element_type"));
|
||||
ops.def("Broadcast", &Broadcast, py::arg("operand"), py::arg("sizes"));
|
||||
ops.def("BroadcastInDim", &BroadcastInDim, py::arg("operand"),
|
||||
py::arg("shape"), py::arg("broadcast_dimensions"));
|
||||
ops.def("Call", &Call, py::arg("builder"), py::arg("computation"),
|
||||
py::arg("operands"));
|
||||
ops.def("Cholesky", &Cholesky, py::arg("a"), py::arg("lower") = true);
|
||||
ops.def("Clamp", &Clamp, py::arg("min"), py::arg("operand"), py::arg("max"));
|
||||
ops.def("Collapse", &Collapse, py::arg("operand"), py::arg("dimensions"));
|
||||
ops.def("ConcatInDim", &ConcatInDim, py::arg("builder"), py::arg("operands"),
|
||||
py::arg("dimension"));
|
||||
ops.def("Conditional",
|
||||
static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaComputation* const>,
|
||||
absl::Span<const XlaOp>)>(&Conditional),
|
||||
py::arg("branch_index"), py::arg("branch_computations"),
|
||||
py::arg("branch_operands"));
|
||||
ops.def("Conditional",
|
||||
static_cast<XlaOp (*)(XlaOp, XlaOp, const XlaComputation&, XlaOp,
|
||||
const XlaComputation&)>(&Conditional),
|
||||
py::arg("predicate"), py::arg("true_operand"),
|
||||
py::arg("true_computation"), py::arg("false_operand"),
|
||||
py::arg("false_computation"));
|
||||
ops.def("Constant", &ConstantLiteral, py::arg("builder"), py::arg("literal"));
|
||||
ops.def("ConstantLiteral", &ConstantLiteral, py::arg("builder"),
|
||||
py::arg("literal"));
|
||||
ops.def("ConvGeneralDilated", &ConvGeneralDilated, py::arg("lhs"),
|
||||
py::arg("rhs"), py::arg("window_strides"), py::arg("padding"),
|
||||
py::arg("lhs_dilation"), py::arg("rhs_dilation"),
|
||||
py::arg("dimension_numbers"), py::arg("feature_group_count") = 1,
|
||||
py::arg("batch_group_count") = 1,
|
||||
py::arg("precision_config") = nullptr);
|
||||
ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"),
|
||||
py::arg("new_element_type"));
|
||||
ops.def(
|
||||
"CustomCall",
|
||||
[](XlaBuilder* builder, const py::bytes& call_target_name,
|
||||
absl::Span<const XlaOp> operands, const Shape& shape,
|
||||
const py::bytes& opaque) -> XlaOp {
|
||||
return CustomCall(builder, call_target_name, operands, shape, opaque);
|
||||
},
|
||||
py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
|
||||
py::arg("shape"), py::arg("opaque") = py::bytes(""));
|
||||
ops.def(
|
||||
"CustomCallWithLayout",
|
||||
[](XlaBuilder* builder, const py::bytes& call_target_name,
|
||||
absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
|
||||
absl::Span<const Shape> operand_shapes_with_layout,
|
||||
const py::bytes& opaque) -> XlaOp {
|
||||
return CustomCallWithLayout(builder, call_target_name, operands,
|
||||
shape_with_layout,
|
||||
operand_shapes_with_layout, opaque);
|
||||
},
|
||||
py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
|
||||
py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"),
|
||||
py::arg("opaque") = py::bytes(""));
|
||||
ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("precision_config") = nullptr);
|
||||
ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("dimension_numbers"), py::arg("precision_config") = nullptr);
|
||||
ops.def("DynamicSlice",
|
||||
static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaOp>,
|
||||
absl::Span<const int64>)>(&DynamicSlice),
|
||||
py::arg("operand"), py::arg("start_indices"), py::arg("slice_sizes"));
|
||||
ops.def("DynamicUpdateSlice",
|
||||
static_cast<XlaOp (*)(XlaOp, XlaOp, absl::Span<const XlaOp>)>(
|
||||
&DynamicUpdateSlice),
|
||||
py::arg("operand"), py::arg("update"), py::arg("start_indices"));
|
||||
|
||||
ops.def("Fft", &Fft, py::arg("operand"), py::arg("fft_type"),
|
||||
py::arg("fft_length"));
|
||||
|
||||
ops.def("Gather", &Gather, py::arg("a"), py::arg("start_indices"),
|
||||
py::arg("dimension_numbers"), py::arg("slice_sizes"),
|
||||
py::arg("indices_are_sorted") = false);
|
||||
ops.def("GetTupleElement", &GetTupleElement, py::arg("tuple_data"),
|
||||
py::arg("index"));
|
||||
ops.def("InfeedWithToken", &InfeedWithToken, py::arg("token"),
|
||||
py::arg("shape"), py::arg("config") = "");
|
||||
ops.def("Iota",
|
||||
static_cast<XlaOp (*)(XlaBuilder*, const Shape&, int64)>(&Iota),
|
||||
py::arg("builder"), py::arg("shape"), py::arg("iota_dimension"));
|
||||
ops.def("Iota",
|
||||
static_cast<XlaOp (*)(XlaBuilder*, PrimitiveType, int64)>(&Iota),
|
||||
py::arg("builder"), py::arg("type"), py::arg("size"));
|
||||
ops.def("Map", &Map, py::arg("builder"), py::arg("operands"),
|
||||
py::arg("computation"), py::arg("dimensions"),
|
||||
py::arg("static_operands") = py::list());
|
||||
ops.def("NextAfter", &NextAfter, py::arg("from"), py::arg("to"));
|
||||
ops.def("OutfeedWithToken", &OutfeedWithToken, py::arg("operand"),
|
||||
py::arg("token"), py::arg("shape_with_layout"),
|
||||
py::arg("outfeed_config") = "");
|
||||
ops.def("Pad", &Pad, py::arg("operand"), py::arg("padding_value"),
|
||||
py::arg("padding_config"));
|
||||
ops.def("Parameter",
|
||||
static_cast<XlaOp (*)(XlaBuilder*, int64, const Shape&,
|
||||
const std::string&, const std::vector<bool>&)>(
|
||||
&Parameter),
|
||||
py::arg("builder"), py::arg("parameter_number"), py::arg("shape"),
|
||||
py::arg("name") = "",
|
||||
py::arg("replicated_at_leaf_buffers") = std::vector<bool>());
|
||||
ops.def(
|
||||
"QR",
|
||||
[](XlaOp a, bool full_matrices) -> StatusOr<std::pair<XlaOp, XlaOp>> {
|
||||
TF_ASSIGN_OR_RETURN(auto qr, QRDecomposition(a, full_matrices));
|
||||
return std::make_pair(qr.q, qr.r);
|
||||
},
|
||||
py::arg("operand"), py::arg("full_matrices"));
|
||||
ops.def(
|
||||
"Eigh",
|
||||
[](XlaOp a, bool lower, int64 max_iter,
|
||||
float epsilon) -> std::pair<XlaOp, XlaOp> {
|
||||
auto eigh = SelfAdjointEig(a, lower, max_iter, epsilon);
|
||||
return std::make_pair(eigh.v, eigh.w);
|
||||
},
|
||||
py::arg("a"), py::arg("lower") = true, py::arg("max_iter") = 100,
|
||||
py::arg("epsilon") = 1e-6);
|
||||
ops.def(
|
||||
"SVD",
|
||||
[](XlaOp a, int64 max_iter,
|
||||
float epsilon) -> std::tuple<XlaOp, XlaOp, XlaOp> {
|
||||
auto svd = SVD(a, max_iter, epsilon);
|
||||
return std::make_tuple(svd.u, svd.d, svd.v);
|
||||
},
|
||||
py::arg("a"), py::arg("max_iter") = 100, py::arg("epsilon") = 1e-6);
|
||||
ops.def("Reduce",
|
||||
static_cast<XlaOp (*)(XlaBuilder*, absl::Span<const XlaOp>,
|
||||
absl::Span<const XlaOp>, const XlaComputation&,
|
||||
absl::Span<const int64>)>(&Reduce),
|
||||
py::arg("builder"), py::arg("operands"), py::arg("init_values"),
|
||||
py::arg("computation"), py::arg("dimensions_to_reduce"));
|
||||
ops.def("ReducePrecision", &ReducePrecision, py::arg("operand"),
|
||||
py::arg("exponent_bits"), py::arg("mantissa_bits"));
|
||||
ops.def("ReduceWindowWithGeneralPadding", &ReduceWindowWithGeneralPadding,
|
||||
py::arg("operand"), py::arg("init_value"), py::arg("computation"),
|
||||
py::arg("window_dimensions"), py::arg("window_strides"),
|
||||
py::arg("base_dilations"), py::arg("window_dilations"),
|
||||
py::arg("padding"));
|
||||
ops.def("ReplicaId", &ReplicaId, py::arg("builder"));
|
||||
ops.def("Reshape",
|
||||
static_cast<XlaOp (*)(XlaOp, absl::Span<const int64>,
|
||||
absl::Span<const int64>)>(&Reshape),
|
||||
py::arg("operand"), py::arg("dimensions"), py::arg("new_sizes"));
|
||||
ops.def("Reshape",
|
||||
static_cast<XlaOp (*)(XlaOp, absl::Span<const int64>)>(&Reshape),
|
||||
py::arg("operand"), py::arg("new_sizes"));
|
||||
ops.def("Rev", &Rev, py::arg("operand"), py::arg("dimensions"));
|
||||
ops.def("RngNormal", &RngNormal, py::arg("mu"), py::arg("sigma"),
|
||||
py::arg("shape"));
|
||||
ops.def("RngUniform", &RngUniform, py::arg("a"), py::arg("b"),
|
||||
py::arg("shape"));
|
||||
ops.def("Scatter", &Scatter, py::arg("input"), py::arg("scatter_indices"),
|
||||
py::arg("updates"), py::arg("update_computation"),
|
||||
py::arg("dimension_numbers"), py::arg("indices_are_sorted") = false,
|
||||
py::arg("unique_indices") = false);
|
||||
ops.def("Select", &Select, py::arg("pred"), py::arg("on_true"),
|
||||
py::arg("on_false"));
|
||||
ops.def("SelectAndScatterWithGeneralPadding",
|
||||
&SelectAndScatterWithGeneralPadding, py::arg("operand"),
|
||||
py::arg("select"), py::arg("window_dimensions"),
|
||||
py::arg("window_strides"), py::arg("padding"), py::arg("source"),
|
||||
py::arg("init_value"), py::arg("scatter"));
|
||||
ops.def("Slice", &Slice, py::arg("operand"), py::arg("start_indices"),
|
||||
py::arg("limit_indices"), py::arg("strides"));
|
||||
ops.def("SliceInDim", &SliceInDim, py::arg("operand"), py::arg("start_index"),
|
||||
py::arg("limit_index"), py::arg("stride"), py::arg("dimno"));
|
||||
ops.def(
|
||||
"Sort",
|
||||
[](XlaBuilder* builder, absl::Span<const XlaOp> operands,
|
||||
absl::optional<const XlaComputation*> comparator, int64 dimension,
|
||||
bool is_stable) -> XlaOp {
|
||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
std::vector<PrimitiveType> operand_types;
|
||||
for (const auto& operand : operands) {
|
||||
TF_ASSIGN_OR_RETURN(auto operand_shape, builder->GetShape(operand));
|
||||
operand_types.push_back(operand_shape.element_type());
|
||||
}
|
||||
|
||||
if (comparator) {
|
||||
return Sort(operands, **comparator, dimension, is_stable);
|
||||
} else {
|
||||
return Sort(operands,
|
||||
CreateScalarLtComputation(operand_types, builder),
|
||||
dimension, is_stable);
|
||||
}
|
||||
});
|
||||
},
|
||||
py::arg("builder"), py::arg("operands"),
|
||||
py::arg("comparator") = absl::nullopt, py::arg("dimension") = -1,
|
||||
py::arg("is_stable") = false);
|
||||
ops.def("TopK", &TopK, py::arg("input"), py::arg("k"));
|
||||
ops.def("Transpose", &Transpose, py::arg("operand"), py::arg("permutation"));
|
||||
ops.def("TriangularSolve", &TriangularSolve, py::arg("a"), py::arg("b"),
|
||||
py::arg("left_side"), py::arg("lower"), py::arg("unit_diagonal"),
|
||||
py::arg("transpose_a"));
|
||||
ops.def("Tuple", &Tuple, py::arg("builder"), py::arg("elements"));
|
||||
ops.def("While", &While, py::arg("condition"), py::arg("body"),
|
||||
py::arg("init"));
|
||||
|
||||
ops.def("Igamma", &Igamma, py::arg("a"), py::arg("x"));
|
||||
ops.def("Igammac", &Igammac, py::arg("a"), py::arg("x"));
|
||||
ops.def("IgammaGradA", &IgammaGradA, py::arg("a"), py::arg("x"));
|
||||
ops.def("RandomGammaGrad", &RandomGammaGrad, py::arg("a"), py::arg("x"));
|
||||
ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta, py::arg("a"),
|
||||
py::arg("b"), py::arg("x"));
|
||||
|
||||
#define BINARY_OP(op) \
|
||||
ops.def( \
|
||||
#op, \
|
||||
[](XlaOp a, XlaOp b, absl::optional<std::vector<int64>> dims) { \
|
||||
return dims ? op(a, b, *dims) : op(a, b); \
|
||||
}, \
|
||||
py::arg("lhs"), py::arg("rhs"), \
|
||||
py::arg("broadcast_dimensions") = absl::nullopt)
|
||||
BINARY_OP(Eq);
|
||||
BINARY_OP(Ne);
|
||||
BINARY_OP(Ge);
|
||||
BINARY_OP(Gt);
|
||||
BINARY_OP(Lt);
|
||||
BINARY_OP(Le);
|
||||
BINARY_OP(Add);
|
||||
BINARY_OP(Sub);
|
||||
BINARY_OP(Mul);
|
||||
BINARY_OP(Div);
|
||||
BINARY_OP(Rem);
|
||||
BINARY_OP(Max);
|
||||
BINARY_OP(Min);
|
||||
BINARY_OP(And);
|
||||
BINARY_OP(Or);
|
||||
BINARY_OP(Xor);
|
||||
BINARY_OP(ShiftLeft);
|
||||
BINARY_OP(ShiftRightArithmetic);
|
||||
BINARY_OP(ShiftRightLogical);
|
||||
BINARY_OP(Atan2);
|
||||
BINARY_OP(Pow);
|
||||
BINARY_OP(Complex);
|
||||
#undef BINARY_OP
|
||||
|
||||
#define UNARY_OP(op) ops.def(#op, &op)
|
||||
UNARY_OP(Not);
|
||||
UNARY_OP(PopulationCount);
|
||||
UNARY_OP(Clz);
|
||||
UNARY_OP(Abs);
|
||||
UNARY_OP(Exp);
|
||||
UNARY_OP(Expm1);
|
||||
UNARY_OP(Floor);
|
||||
UNARY_OP(Ceil);
|
||||
UNARY_OP(Round);
|
||||
UNARY_OP(Log);
|
||||
UNARY_OP(Log1p);
|
||||
UNARY_OP(Sign);
|
||||
UNARY_OP(Cos);
|
||||
UNARY_OP(Sin);
|
||||
UNARY_OP(Tanh);
|
||||
UNARY_OP(IsFinite);
|
||||
UNARY_OP(Neg);
|
||||
UNARY_OP(Sqrt);
|
||||
UNARY_OP(Rsqrt);
|
||||
UNARY_OP(Square);
|
||||
UNARY_OP(Reciprocal);
|
||||
UNARY_OP(Erfc);
|
||||
UNARY_OP(Erf);
|
||||
UNARY_OP(ErfInv);
|
||||
UNARY_OP(Lgamma);
|
||||
UNARY_OP(Digamma);
|
||||
UNARY_OP(BesselI0e);
|
||||
UNARY_OP(BesselI1e);
|
||||
UNARY_OP(Acos);
|
||||
UNARY_OP(Asin);
|
||||
UNARY_OP(Atan);
|
||||
UNARY_OP(Tan);
|
||||
UNARY_OP(Acosh);
|
||||
UNARY_OP(Asinh);
|
||||
UNARY_OP(Atanh);
|
||||
UNARY_OP(Cosh);
|
||||
UNARY_OP(Sinh);
|
||||
UNARY_OP(Real);
|
||||
UNARY_OP(Imag);
|
||||
UNARY_OP(Conj);
|
||||
#undef UNARY_OP
|
||||
}
|
||||
|
||||
void BuildProfilerSubmodule(py::module* m) {
|
||||
py::module profiler =
|
||||
|
Loading…
Reference in New Issue
Block a user