[XLA] [TF:XLA] Move Cholesky decomposition into xla/client/lib/cholesky.*
Move loop helpers used by Cholesky decomposition into xla/client/lib/loops.*.
PiperOrigin-RevId: 225037112
Commit 9b964193d9 (parent c99ecfa992)
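Caller-side sketch (illustrative, not part of this commit): the helpers keep their calling conventions but move out of the tensorflow namespace in tensorflow/compiler/tf2xla/lib/{cholesky,while_loop}.h into the xla namespace in tensorflow/compiler/xla/client/lib/{cholesky,loops}.h, with XlaWhileLoop and XlaForEachIndex renamed to xla::WhileLoopHelper and xla::ForEachIndex. A minimal sketch of the new spelling, assuming a builder-constructed operand `a`:

#include "tensorflow/compiler/xla/client/lib/cholesky.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Returns the lower-triangular Cholesky factor of a batch of SPD matrices,
// using the new xla::Cholesky entry point and its default arguments.
xla::XlaOp LowerTriangularFactor(xla::XlaOp a) {
  return xla::Cholesky(a);
}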
@@ -1,16 +1,11 @@
load("//tensorflow:tensorflow.bzl", "tf_copts", "tf_kernel_library")

licenses(["notice"]) # Apache 2.0

package(
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
)

load("//tensorflow:tensorflow.bzl", "tf_copts")
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
)

tf_kernel_library(
name = "xla_ops",
srcs = [
@@ -122,12 +117,10 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/lib:broadcast",
"//tensorflow/compiler/tf2xla/lib:cholesky",
"//tensorflow/compiler/tf2xla/lib:qr",
"//tensorflow/compiler/tf2xla/lib:random",
"//tensorflow/compiler/tf2xla/lib:scatter",
"//tensorflow/compiler/tf2xla/lib:util",
"//tensorflow/compiler/tf2xla/lib:while_loop",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla:literal",
@@ -140,7 +133,9 @@ tf_kernel_library(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:cholesky",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:loops",
"//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/client/lib:pooling",
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/lib/cholesky.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/cholesky.h"

namespace tensorflow {
namespace {
@@ -24,7 +24,7 @@ class CholeskyOp : public XlaOpKernel {
public:
explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
ctx->SetOutput(0, Cholesky(ctx->Input(0)));
ctx->SetOutput(0, xla::Cholesky(ctx->Input(0)));
}
};
@@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
@@ -15,12 +15,12 @@ limitations under the License.

#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/sorting.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -505,9 +505,9 @@ class NonMaxSuppressionOp : public XlaOpKernel {
init_values.push_back(included_iou);

auto suppress_loop_result =
XlaWhileLoop(WhileCondFn(num_boxes, output_size),
SuppressBodyFn(num_boxes), init_values, "suppress_loop",
builder)
xla::WhileLoopHelper(WhileCondFn(num_boxes, output_size),
SuppressBodyFn(num_boxes), init_values,
"suppress_loop", builder)
.ValueOrDie();

xla::XlaOp included_score =
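For reference, a hedged sketch of the renamed while-loop helper under the signature declared in loops.h later in this diff; the count-to-ten computation and the CountToTen name are made up for illustration:

#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/statusor.h"

xla::StatusOr<std::vector<xla::XlaOp>> CountToTen(xla::XlaBuilder* builder) {
  // Condition: keep looping while the single carried S32 value is < 10.
  auto cond_fn = [](absl::Span<const xla::XlaOp> vals,
                    xla::XlaBuilder* b) -> xla::StatusOr<xla::XlaOp> {
    return xla::Lt(vals[0], xla::ConstantR0<int32_t>(b, 10));
  };
  // Body: increment the carried value by one.
  auto body_fn = [](absl::Span<const xla::XlaOp> vals, xla::XlaBuilder* b)
      -> xla::StatusOr<std::vector<xla::XlaOp>> {
    return std::vector<xla::XlaOp>{vals[0] + xla::ConstantR0<int32_t>(b, 1)};
  };
  std::vector<xla::XlaOp> init = {xla::ConstantR0<int32_t>(builder, 0)};
  return xla::WhileLoopHelper(cond_fn, body_fn, init, "count_to_ten", builder);
}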
@@ -20,12 +20,12 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/random.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
@@ -175,8 +175,8 @@ class RandomShuffleOp : public XlaOpKernel {
};
// for i in range(n):
auto swap_loop_result =
XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices},
"indices_swap_loop", builder)
xla::ForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices},
"indices_swap_loop", builder)
.ValueOrDie();
auto swapped_indices = swap_loop_result[1];
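Likewise, xla::ForEachIndex is a rename of XlaForEachIndex with the same argument order; a minimal sketch, where the index-summing body and the SumOfIndices name are illustrative only:

#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/statusor.h"

xla::XlaOp SumOfIndices(int64_t n, xla::XlaBuilder* builder) {
  // The body receives (current index, carried values); vals[0] is the sum.
  auto body_fn = [](xla::XlaOp i, absl::Span<const xla::XlaOp> vals,
                    xla::XlaBuilder* b)
      -> xla::StatusOr<std::vector<xla::XlaOp>> {
    return std::vector<xla::XlaOp>{vals[0] + i};
  };
  std::vector<xla::XlaOp> init = {xla::ConstantR0<int32_t>(builder, 0)};
  auto results = xla::ForEachIndex(n, xla::S32, body_fn, init,
                                   "sum_of_indices", builder)
                     .ValueOrDie();
  // ForEachIndex strips its internal counter, so results[0] is the sum.
  return results[0];
}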
@@ -15,8 +15,6 @@ filegroup(
]),
)

load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")

cc_library(
name = "broadcast",
srcs = ["broadcast.cc"],
@@ -33,27 +31,6 @@ cc_library(
],
)

cc_library(
name = "cholesky",
srcs = ["cholesky.cc"],
hdrs = ["cholesky.h"],
deps = [
":util",
":while_loop",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/client/lib:slicing",
"//tensorflow/compiler/xla/client/lib:triangular_solve",
"//tensorflow/core:lib",
],
)

cc_library(
name = "random",
srcs = ["random.cc"],
@@ -75,7 +52,6 @@ cc_library(
hdrs = ["qr.h"],
deps = [
":util",
":while_loop",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -84,6 +60,7 @@ cc_library(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:loops",
"//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/client/lib:slicing",
@@ -97,7 +74,6 @@ cc_library(
hdrs = ["scatter.h"],
deps = [
":util",
":while_loop",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -128,19 +104,3 @@ cc_library(
"@com_google_absl//absl/types:span",
],
)

cc_library(
name = "while_loop",
srcs = ["while_loop.cc"],
hdrs = ["while_loop.h"],
deps = [
":util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
@@ -19,9 +19,9 @@ limitations under the License.
#include <vector>

#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
@@ -225,8 +225,8 @@ xla::StatusOr<QRBlockResult> QRBlock(
builder, xla::ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n})));

TF_ASSIGN_OR_RETURN(auto values,
XlaForEachIndex(std::min(m, n), xla::S32, qr_body_fn,
{a, vs, taus}, "qr", builder));
xla::ForEachIndex(std::min(m, n), xla::S32, qr_body_fn,
{a, vs, taus}, "qr", builder));

QRBlockResult result;
result.r = values[0];
@@ -301,8 +301,8 @@ xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
w = UpdateSliceInMinorDims(w, bv, {0});

TF_ASSIGN_OR_RETURN(
auto values, XlaForEachIndex(n - 1, xla::S32, body_fn, {w, y, vs, taus},
"wy", builder));
auto values, xla::ForEachIndex(n - 1, xla::S32, body_fn, {w, y, vs, taus},
"wy", builder));
return values[0];
}

@@ -20,7 +20,6 @@ limitations under the License.

#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
@@ -1,5 +1,7 @@
# Common computation builders for XLA.

load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites", "xla_test")

licenses(["notice"]) # Apache 2.0

package(default_visibility = ["//tensorflow/compiler/xla/client:friends"])
@@ -13,9 +15,6 @@ filegroup(
]),
)

load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites")

# Generate test_suites for all backends, named "${backend}_tests".
generate_backend_suites()

@@ -35,6 +34,48 @@ cc_library(
],
)

cc_library(
name = "cholesky",
srcs = ["cholesky.cc"],
hdrs = ["cholesky.h"],
deps = [
":math",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:loops",
"//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/client/lib:slicing",
"//tensorflow/compiler/xla/client/lib:triangular_solve",
"//tensorflow/core:lib",
],
)

xla_test(
name = "cholesky_test",
srcs = ["cholesky_test.cc"],
tags = ["optonly"],
deps = [
":arithmetic",
":cholesky",
":matrix",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)

cc_library(
name = "constants",
srcs = ["constants.cc"],
@@ -75,6 +116,22 @@ cc_library(
],
)

cc_library(
name = "loops",
srcs = ["loops.cc"],
hdrs = ["loops.h"],
deps = [
":constants",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)

cc_library(
name = "math",
srcs = ["math.cc"],
@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/lib/cholesky.h"
#include "tensorflow/compiler/xla/client/lib/cholesky.h"

#include <memory>
#include <vector>

#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/lib/triangular_solve.h"
@@ -31,7 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace xla {

namespace {

@@ -50,26 +50,25 @@ namespace {
// l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) /
//                   l[..., j, j]
// return l
xla::XlaOp CholeskyUnblocked(xla::XlaOp a,
xla::PrecisionConfig::Precision precision) {
xla::XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
const int n_dims = xla::ShapeUtil::Rank(a_shape);
const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
auto major_dims = xla::AsInt64Slice(a_shape.dimensions())
XlaOp CholeskyUnblocked(XlaOp a, PrecisionConfig::Precision precision) {
XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
const int n_dims = ShapeUtil::Rank(a_shape);
const int64 n = ShapeUtil::GetDimension(a_shape, -1);
auto major_dims = AsInt64Slice(a_shape.dimensions())
.subspan(
/*pos=*/0,
/*len=*/n_dims - 2);

xla::XlaOp l = xla::ZerosLike(a);
XlaOp l = ZerosLike(a);

// Construct the for loop body to iterate over rows.
auto body_fn = [&](xla::XlaOp i, absl::Span<const xla::XlaOp> loop_vars,
xla::XlaBuilder* body_builder)
-> xla::StatusOr<std::vector<xla::XlaOp>> {
xla::Shape col_shape;
xla::Shape row_shape;
auto body_fn =
[&](XlaOp i, absl::Span<const XlaOp> loop_vars,
XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
Shape col_shape;
Shape row_shape;
for (int64 d : major_dims) {
row_shape.add_dimensions(d);
col_shape.add_dimensions(d);
@@ -77,43 +76,40 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a,
row_shape.add_dimensions(1);
row_shape.add_dimensions(n);
row_shape.set_element_type(a_shape.element_type());
auto mask_zeros_row = xla::Zeros(body_builder, row_shape);
auto mask_zeros_row = Zeros(body_builder, row_shape);

col_shape.add_dimensions(n);
col_shape.add_dimensions(1);
col_shape.set_element_type(a_shape.element_type());
auto mask_zeros_col = xla::Zeros(body_builder, col_shape);
auto mask_zeros_col = Zeros(body_builder, col_shape);

std::vector<int32> mask_vector(n);
std::iota(mask_vector.begin(), mask_vector.end(), 0);
auto mask_range = xla::ConstantR1<int32>(body_builder, mask_vector);
auto mask_range = ConstantR1<int32>(body_builder, mask_vector);
auto mask_range_row =
xla::Broadcast(xla::Reshape(mask_range, {0}, {1, n}), major_dims);
Broadcast(Reshape(mask_range, {0}, {1, n}), major_dims);
auto mask_range_col =
xla::Broadcast(xla::Reshape(mask_range, {0}, {n, 1}), major_dims);
Broadcast(Reshape(mask_range, {0}, {n, 1}), major_dims);
auto body_a = loop_vars[0];
auto body_l = loop_vars[1];

// row = l[..., i, :i]
// select the whole i-th row, then mask out all columns past i-1
auto zero = xla::ConstantR0<int32>(body_builder, 0);
auto zero = ConstantR0<int32>(body_builder, 0);
auto l_i = DynamicSliceInMinorDims(body_l, {i, zero}, {1, n});
auto row = xla::Select(xla::Ge(mask_range_row, i), mask_zeros_row, l_i);
auto row = Select(Ge(mask_range_row, i), mask_zeros_row, l_i);
// a[..., i, i]
auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1});
// np.dot(row, np.swapaxes(row, -1, -2))
auto diag_dot = BatchDot(row, TransposeInMinorDims(row), precision);
// l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row,
//                                              np.swapaxes(row, -1, -2)))
auto l_ii =
xla::Pow(a_ii - diag_dot,
FloatLiteral(body_builder, a_shape.element_type(), 0.5));
auto l_ii = Sqrt(a_ii - diag_dot);

// a[..., i+1:, i]
// select the whole i-th column, then mask out all rows above i+1
auto a_0i = DynamicSliceInMinorDims(body_a, {i}, {1});
auto a_ip1i =
xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, a_0i);
auto a_ip1i = Select(Le(mask_range_col, i), mask_zeros_col, a_0i);

// l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) /
//                   l[..., i, i]
@@ -122,8 +118,7 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a,
// r.T)
auto dot = BatchDot(body_l, TransposeInMinorDims(row), precision);
// np.dot(l[..., i+1:, :i], r.T)
auto dot_ip1 =
xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot);
auto dot_ip1 = Select(Le(mask_range_col, i), mask_zeros_col, dot);

body_l =
DynamicUpdateSliceInMinorDims(body_l, (a_ip1i - dot_ip1) / l_ii, {i});
@@ -131,12 +126,12 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a,
// column assign will wrap around and overwrite the diagonal assign.
body_l = DynamicUpdateSliceInMinorDims(body_l, l_ii, {i, i});

return std::vector<xla::XlaOp>{body_a, body_l};
return std::vector<XlaOp>{body_a, body_l};
};

TF_ASSIGN_OR_RETURN(
auto cholesky_while,
XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder));
ForEachIndex(n, S32, body_fn, {a, l}, "unblocked", builder));

return cholesky_while[1];
});
@@ -144,34 +139,35 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a,

} // namespace

xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size,
xla::PrecisionConfig::Precision precision) {
xla::XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
const int ndims = xla::ShapeUtil::Rank(a_shape);
XlaOp Cholesky(XlaOp a, int64 block_size,
PrecisionConfig::Precision precision) {
XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
const int ndims = ShapeUtil::Rank(a_shape);
if (ndims < 2) {
return errors::InvalidArgument(
"Arguments to Cholesky must have rank >= 2: ", ndims);
return InvalidArgument(
"Argument to Cholesky must have rank >= 2; shape was %s",
a_shape.ToString());
}

const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) {
return errors::InvalidArgument(
"Arguments to Cholesky must be square matrices: ",
xla::ShapeUtil::HumanString(a_shape));
const int64 n = ShapeUtil::GetDimension(a_shape, -1);
if (n != ShapeUtil::GetDimension(a_shape, -2)) {
return InvalidArgument(
"Argument to Cholesky must be batched square matrices; got shape %s",
ShapeUtil::HumanString(a_shape));
}

if (block_size < 1) {
return errors::InvalidArgument(
"block_size argument to Cholesky must be >= 1; got ", block_size);
return InvalidArgument(
"block_size argument to Cholesky must be >= 1; got %d", block_size);
}

// Blocked left-looking Cholesky factorization.
// Algorithm 1 from
// Haidar, Azzam, et al. "High-performance Cholesky factorization for
// GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017.
xla::XlaOp l = xla::ZerosLike(a);
XlaOp l = ZerosLike(a);
for (int64 i = 0; i < n; i += block_size) {
int64 k = std::min(block_size, n - i);
if (i > 0) {
@@ -207,4 +203,4 @@ xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size,
});
}

} // namespace tensorflow
} // namespace xla
@@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace tensorflow {
namespace xla {

// Computes the Cholesky decompositions of a batch of symmetric positive
// definite matrices.
@@ -34,6 +34,6 @@ xla::XlaOp Cholesky(
xla::XlaOp a, int64 block_size = 256,
xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST);

} // namespace tensorflow
} // namespace xla

#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_
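A hedged usage sketch of the relocated entry point with the defaults from the declaration above spelled out explicitly; the BlockedCholesky wrapper name is hypothetical:

#include "tensorflow/compiler/xla/client/lib/cholesky.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

// `a` holds a batch of symmetric positive definite matrices of shape
// [..., n, n]; the result is the lower-triangular l with l * l^T == a.
xla::XlaOp BlockedCholesky(xla::XlaOp a) {
  return xla::Cholesky(a, /*block_size=*/256, xla::PrecisionConfig::HIGHEST);
}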
tensorflow/compiler/xla/client/lib/cholesky_test.cc (new file, 166 lines)
@@ -0,0 +1,166 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/lib/cholesky.h"

#include <memory>
#include <numeric>
#include <vector>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace {

using xla::int64;

using CholeskyTest = xla::ClientLibraryTestBase;

XLA_TEST_F(CholeskyTest, Simple) {
xla::XlaBuilder builder(TestName());

xla::Array2D<float> a_vals({
{4, 6, 8, 10},
{6, 45, 54, 63},
{8, 54, 146, 166},
{10, 63, 166, 310},
});

xla::XlaOp a;
auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
xla::Cholesky(a, /*block_size=*/2);

xla::Array2D<float> expected({
{2, 0, 0, 0},
{3, 6, 0, 0},
{4, 7, 9, 0},
{5, 8, 10, 11},
});

ComputeAndCompareR2<float>(&builder, expected, {a_data.get()},
xla::ErrorSpec(1e-4, 1e-4));
}

XLA_TEST_F(CholeskyTest, Simple2) {
xla::XlaBuilder builder(TestName());

xla::Array2D<float> a_vals({
{16, 24, 8, 12},
{24, 61, 82, 48},
{8, 82, 456, 106},
{12, 48, 106, 62},
});

xla::XlaOp a;
auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
xla::Cholesky(a);

xla::Array2D<float> expected(
{{4, 0, 0, 0}, {6, 5, 0, 0}, {2, 14, 16, 0}, {3, 6, 1, 4}});

ComputeAndCompareR2<float>(&builder, expected, {a_data.get()},
xla::ErrorSpec(1e-4, 1e-4));
}

XLA_TEST_F(CholeskyTest, SimpleBatched) {
xla::XlaBuilder builder(TestName());

xla::Array3D<float> a_vals({
{
{4, 6, 8, 10},
{6, 45, 54, 63},
{8, 54, 146, 166},
{10, 63, 166, 310},
},
{
{16, 24, 8, 12},
{24, 61, 82, 48},
{8, 82, 456, 106},
{12, 48, 106, 62},
},
});

xla::XlaOp a;
auto a_data = CreateR3Parameter<float>(a_vals, 0, "a", &builder, &a);
xla::Cholesky(a);

xla::Array3D<float> expected({
{
{2, 0, 0, 0},
{3, 6, 0, 0},
{4, 7, 9, 0},
{5, 8, 10, 11},
},
{{4, 0, 0, 0}, {6, 5, 0, 0}, {2, 14, 16, 0}, {3, 6, 1, 4}},
});

ComputeAndCompareR3<float>(&builder, expected, {a_data.get()},
xla::ErrorSpec(1e-4, 1e-4));
}

using CholeskyTestCase = std::tuple<int64, int64>;

class RandomCholeskyTest
: public xla::ClientLibraryTestBase,
public ::testing::WithParamInterface<CholeskyTestCase> {};

XLA_TEST_P(RandomCholeskyTest, Random) {
xla::XlaBuilder builder(TestName());

auto test_params = GetParam();
std::vector<int64> dimensions = {std::get<0>(test_params),
std::get<1>(test_params),
std::get<1>(test_params)};
xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, dimensions);
TF_ASSERT_OK_AND_ASSIGN(
auto literal,
xla::LiteralUtil::CreateRandomLiteral<xla::F32>(shape, 0.0, 1.0));

auto input = xla::Parameter(&builder, 0, shape, "input");
// Form a random positive definite matrix.
auto matrix = xla::BatchDot(input, TransposeInMinorDims(input),
xla::PrecisionConfig::HIGHEST);

auto cholesky = xla::Cholesky(matrix, /*block_size=*/4);

// Verify that ||matrix - cholesky * cholesky_t||_2 ~= 0
auto verification = xla::BatchDot(cholesky, TransposeInMinorDims(cholesky),
xla::PrecisionConfig::HIGHEST);
auto delta = matrix - verification;
xla::Reduce(delta * delta, xla::ConstantR0<float>(&builder, 0.0),
CreateScalarAddComputation(xla::F32, &builder), {0, 1, 2});

TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(literal));
ComputeAndCompareR0<float>(&builder, 0.0, {input_data.get()},
xla::ErrorSpec(1e-4, 1e-4));
}

INSTANTIATE_TEST_CASE_P(RandomCholeskyTestInstance, RandomCholeskyTest,
::testing::Values(CholeskyTestCase{1, 1},
CholeskyTestCase{1, 2},
CholeskyTestCase{10, 5},
CholeskyTestCase{2, 20}));

} // namespace
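Arithmetic spot check of the Simple test case above (not part of the commit): l00 = sqrt(4) = 2, l10 = 6/2 = 3, l11 = sqrt(45 - 3^2) = 6, l20 = 8/2 = 4, l21 = (54 - 4*3)/6 = 7, l22 = sqrt(146 - 4^2 - 7^2) = 9, l30 = 10/2 = 5, l31 = (63 - 5*3)/6 = 8, l32 = (166 - 5*4 - 8*7)/9 = 10, l33 = sqrt(310 - 5^2 - 8^2 - 10^2) = 11, matching the expected lower-triangular factor.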
@@ -13,44 +13,43 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"

#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace tensorflow {
namespace xla {

xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
const LoopConditionFunction& condition_function,
const LoopBodyFunction& body_function,
absl::Span<const xla::XlaOp> initial_values, absl::string_view name,
xla::XlaBuilder* builder) {
StatusOr<std::vector<XlaOp>> WhileLoopHelper(
const WhileLoopHelperConditionFunction& condition_function,
const WhileLoopHelperBodyFunction& body_function,
absl::Span<const XlaOp> initial_values, absl::string_view name,
XlaBuilder* builder) {
int arity = initial_values.size();
std::vector<xla::Shape> var_shapes;
std::vector<Shape> var_shapes;
var_shapes.reserve(arity);
for (const xla::XlaOp& input : initial_values) {
for (const XlaOp& input : initial_values) {
TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(input));
var_shapes.push_back(std::move(shape));
}
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(var_shapes);
Shape tuple_shape = ShapeUtil::MakeTupleShape(var_shapes);

// Unpacks a tuple into its component parts.
auto unpack_tuple = [](xla::XlaOp tuple, int arity,
xla::XlaBuilder* builder) {
std::vector<xla::XlaOp> elements(arity);
auto unpack_tuple = [](XlaOp tuple, int arity, XlaBuilder* builder) {
std::vector<XlaOp> elements(arity);
for (int i = 0; i < arity; ++i) {
elements[i] = xla::GetTupleElement(tuple, i);
elements[i] = GetTupleElement(tuple, i);
}
return elements;
};

// Build the condition.
std::unique_ptr<xla::XlaBuilder> cond_builder =
std::unique_ptr<XlaBuilder> cond_builder =
builder->CreateSubBuilder(absl::StrCat(name, "_condition"));
{
auto parameter =
xla::Parameter(cond_builder.get(), 0, tuple_shape, "parameter");
auto parameter = Parameter(cond_builder.get(), 0, tuple_shape, "parameter");

TF_RETURN_IF_ERROR(
condition_function(unpack_tuple(parameter, arity, cond_builder.get()),
@@ -60,11 +59,10 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build());

// Build the body.
std::unique_ptr<xla::XlaBuilder> body_builder =
std::unique_ptr<XlaBuilder> body_builder =
builder->CreateSubBuilder(absl::StrCat(name, "_body"));
{
auto parameter =
xla::Parameter(body_builder.get(), 0, tuple_shape, "parameter");
auto parameter = Parameter(body_builder.get(), 0, tuple_shape, "parameter");

TF_ASSIGN_OR_RETURN(
auto result,
@@ -72,56 +70,54 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
body_builder.get()));

TF_RET_CHECK(result.size() == initial_values.size());
xla::Tuple(body_builder.get(), result);
Tuple(body_builder.get(), result);
}
TF_ASSIGN_OR_RETURN(auto body, body_builder->Build());

auto outputs = xla::While(cond, body, xla::Tuple(builder, initial_values));
auto outputs = While(cond, body, Tuple(builder, initial_values));

return unpack_tuple(outputs, arity, builder);
}

xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
int64 num_iterations, xla::PrimitiveType num_iterations_type,
StatusOr<std::vector<XlaOp>> ForEachIndex(
int64 num_iterations, PrimitiveType num_iterations_type,
const ForEachIndexBodyFunction& body_function,
absl::Span<const xla::XlaOp> initial_values, absl::string_view name,
xla::XlaBuilder* builder) {
auto while_cond_fn =
[&](absl::Span<const xla::XlaOp> values,
xla::XlaBuilder* cond_builder) -> xla::StatusOr<xla::XlaOp> {
return xla::Lt(values[0], IntegerLiteral(cond_builder, num_iterations_type,
num_iterations));
absl::Span<const XlaOp> initial_values, absl::string_view name,
XlaBuilder* builder) {
auto while_cond_fn = [&](absl::Span<const XlaOp> values,
XlaBuilder* cond_builder) -> StatusOr<XlaOp> {
return Lt(values[0], ConstantR0WithType(cond_builder, num_iterations_type,
num_iterations));
};
auto while_body_fn = [&](absl::Span<const xla::XlaOp> values,
xla::XlaBuilder* body_builder)
-> xla::StatusOr<std::vector<xla::XlaOp>> {
xla::XlaOp iteration = values[0];
auto while_body_fn =
[&](absl::Span<const XlaOp> values,
XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
XlaOp iteration = values[0];

std::vector<xla::XlaOp> updated_values;
std::vector<XlaOp> updated_values;
updated_values.reserve(values.size());
updated_values.push_back(xla::Add(
updated_values.push_back(Add(
iteration,
xla::ConstantLiteral(body_builder,
xla::LiteralUtil::One(num_iterations_type))));
ConstantLiteral(body_builder, LiteralUtil::One(num_iterations_type))));

values.remove_prefix(1);
TF_ASSIGN_OR_RETURN(std::vector<xla::XlaOp> body_outputs,
TF_ASSIGN_OR_RETURN(std::vector<XlaOp> body_outputs,
body_function(iteration, values, body_builder));
updated_values.insert(updated_values.end(), body_outputs.begin(),
body_outputs.end());
return updated_values;
};

std::vector<xla::XlaOp> values;
std::vector<XlaOp> values;
values.reserve(initial_values.size() + 1);
values.push_back(xla::ConstantLiteral(
builder, xla::LiteralUtil::Zero(num_iterations_type)));
values.push_back(
ConstantLiteral(builder, LiteralUtil::Zero(num_iterations_type)));
values.insert(values.end(), initial_values.begin(), initial_values.end());

TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values,
name, builder));
TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn,
values, name, builder));
values.erase(values.begin(), values.begin() + 1);
return values;
}

} // namespace tensorflow
} // namespace xla
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_

#include <functional>
#include <vector>
@@ -25,19 +25,18 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace tensorflow {
namespace xla {

// Function that builds a loop condition. Takes as input a sequence of input
// values, and returns a boolean value representing if the condition succeeds.
typedef std::function<xla::StatusOr<xla::XlaOp>(absl::Span<const xla::XlaOp>,
xla::XlaBuilder*)>
LoopConditionFunction;
typedef std::function<StatusOr<XlaOp>(absl::Span<const XlaOp>, XlaBuilder*)>
WhileLoopHelperConditionFunction;

// Function that builds a loop body. Takes as input a sequence of input values
// and returns a sequence of output values.
typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>(
absl::Span<const xla::XlaOp>, xla::XlaBuilder*)>
LoopBodyFunction;
typedef std::function<StatusOr<std::vector<XlaOp>>(absl::Span<const XlaOp>,
XlaBuilder*)>
WhileLoopHelperBodyFunction;

// Helper function for building an XLA while loop, where the values carried by
// the loop are a tuple of values, e.g., (a, b, c):
@@ -47,27 +46,27 @@ typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>(
// init: (a, b, c)
// )
// 'name' is a descriptive name for the loop.
xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
const LoopConditionFunction& condition_function,
const LoopBodyFunction& body_function,
absl::Span<const xla::XlaOp> initial_values, absl::string_view name,
xla::XlaBuilder* builder);
StatusOr<std::vector<XlaOp>> WhileLoopHelper(
const WhileLoopHelperConditionFunction& condition_function,
const WhileLoopHelperBodyFunction& body_function,
absl::Span<const XlaOp> initial_values, absl::string_view name,
XlaBuilder* builder);

// Builds an XLA loop that repeats a computation `num_iterations` times.
//
// The body function (ForEachIndexBodyFunction) takes as input a pair of
// (current iteration number, loop-carried values), and returns an updated
// vector of the loop-carried values.
typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>(
xla::XlaOp, absl::Span<const xla::XlaOp>, xla::XlaBuilder*)>
typedef std::function<StatusOr<std::vector<XlaOp>>(
XlaOp, absl::Span<const XlaOp>, XlaBuilder*)>
ForEachIndexBodyFunction;

xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
int64 num_iterations, xla::PrimitiveType num_iterations_type,
StatusOr<std::vector<XlaOp>> ForEachIndex(
int64 num_iterations, PrimitiveType num_iterations_type,
const ForEachIndexBodyFunction& body_function,
absl::Span<const xla::XlaOp> initial_values, absl::string_view name,
xla::XlaBuilder* builder);
absl::Span<const XlaOp> initial_values, absl::string_view name,
XlaBuilder* builder);

} // namespace tensorflow
} // namespace xla

#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_