[XLA] [TF:XLA] Move Cholesky decomposition into xla/client/lib/cholesky.*

Move loop helpers used by Cholesky decomposition into xla/client/lib/loops.*.

PiperOrigin-RevId: 225037112
This commit is contained in:
Peter Hawkins 2018-12-11 11:15:58 -08:00 committed by TensorFlower Gardener
parent c99ecfa992
commit 9b964193d9
14 changed files with 361 additions and 194 deletions

View File

@ -1,16 +1,11 @@
load("//tensorflow:tensorflow.bzl", "tf_copts", "tf_kernel_library")
licenses(["notice"]) # Apache 2.0
package(
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
)
load("//tensorflow:tensorflow.bzl", "tf_copts")
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
)
tf_kernel_library(
name = "xla_ops",
srcs = [
@ -122,12 +117,10 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/lib:broadcast",
"//tensorflow/compiler/tf2xla/lib:cholesky",
"//tensorflow/compiler/tf2xla/lib:qr",
"//tensorflow/compiler/tf2xla/lib:random",
"//tensorflow/compiler/tf2xla/lib:scatter",
"//tensorflow/compiler/tf2xla/lib:util",
"//tensorflow/compiler/tf2xla/lib:while_loop",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla:literal",
@ -140,7 +133,9 @@ tf_kernel_library(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:cholesky",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:loops",
"//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/client/lib:pooling",

View File

@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/lib/cholesky.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/cholesky.h"
namespace tensorflow {
namespace {
@ -24,7 +24,7 @@ class CholeskyOp : public XlaOpKernel {
public:
explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
ctx->SetOutput(0, Cholesky(ctx->Input(0)));
ctx->SetOutput(0, xla::Cholesky(ctx->Input(0)));
}
};

View File

@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"

View File

@ -15,12 +15,12 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/sorting.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
@ -505,9 +505,9 @@ class NonMaxSuppressionOp : public XlaOpKernel {
init_values.push_back(included_iou);
auto suppress_loop_result =
XlaWhileLoop(WhileCondFn(num_boxes, output_size),
SuppressBodyFn(num_boxes), init_values, "suppress_loop",
builder)
xla::WhileLoopHelper(WhileCondFn(num_boxes, output_size),
SuppressBodyFn(num_boxes), init_values,
"suppress_loop", builder)
.ValueOrDie();
xla::XlaOp included_score =

View File

@ -20,12 +20,12 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/random.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
@ -175,7 +175,7 @@ class RandomShuffleOp : public XlaOpKernel {
};
// for i in range(n):
auto swap_loop_result =
XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices},
xla::ForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices},
"indices_swap_loop", builder)
.ValueOrDie();
auto swapped_indices = swap_loop_result[1];

View File

@ -15,8 +15,6 @@ filegroup(
]),
)
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
cc_library(
name = "broadcast",
srcs = ["broadcast.cc"],
@ -33,27 +31,6 @@ cc_library(
],
)
cc_library(
name = "cholesky",
srcs = ["cholesky.cc"],
hdrs = ["cholesky.h"],
deps = [
":util",
":while_loop",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/client/lib:slicing",
"//tensorflow/compiler/xla/client/lib:triangular_solve",
"//tensorflow/core:lib",
],
)
cc_library(
name = "random",
srcs = ["random.cc"],
@ -75,7 +52,6 @@ cc_library(
hdrs = ["qr.h"],
deps = [
":util",
":while_loop",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@ -84,6 +60,7 @@ cc_library(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:loops",
"//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/client/lib:slicing",
@ -97,7 +74,6 @@ cc_library(
hdrs = ["scatter.h"],
deps = [
":util",
":while_loop",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@ -128,19 +104,3 @@ cc_library(
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "while_loop",
srcs = ["while_loop.cc"],
hdrs = ["while_loop.h"],
deps = [
":util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)

View File

@ -19,9 +19,9 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
@ -225,7 +225,7 @@ xla::StatusOr<QRBlockResult> QRBlock(
builder, xla::ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n})));
TF_ASSIGN_OR_RETURN(auto values,
XlaForEachIndex(std::min(m, n), xla::S32, qr_body_fn,
xla::ForEachIndex(std::min(m, n), xla::S32, qr_body_fn,
{a, vs, taus}, "qr", builder));
QRBlockResult result;
@ -301,7 +301,7 @@ xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
w = UpdateSliceInMinorDims(w, bv, {0});
TF_ASSIGN_OR_RETURN(
auto values, XlaForEachIndex(n - 1, xla::S32, body_fn, {w, y, vs, taus},
auto values, xla::ForEachIndex(n - 1, xla::S32, body_fn, {w, y, vs, taus},
"wy", builder));
return values[0];
}

View File

@ -20,7 +20,6 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"

View File

@ -1,5 +1,7 @@
# Common computation builders for XLA.
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites", "xla_test")
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//tensorflow/compiler/xla/client:friends"])
@ -13,9 +15,6 @@ filegroup(
]),
)
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites")
# Generate test_suites for all backends, named "${backend}_tests".
generate_backend_suites()
@ -35,6 +34,48 @@ cc_library(
],
)
cc_library(
name = "cholesky",
srcs = ["cholesky.cc"],
hdrs = ["cholesky.h"],
deps = [
":math",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:loops",
"//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/client/lib:slicing",
"//tensorflow/compiler/xla/client/lib:triangular_solve",
"//tensorflow/core:lib",
],
)
xla_test(
name = "cholesky_test",
srcs = ["cholesky_test.cc"],
tags = ["optonly"],
deps = [
":arithmetic",
":cholesky",
":matrix",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)
cc_library(
name = "constants",
srcs = ["constants.cc"],
@ -75,6 +116,22 @@ cc_library(
],
)
cc_library(
name = "loops",
srcs = ["loops.cc"],
hdrs = ["loops.h"],
deps = [
":constants",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "math",
srcs = ["math.cc"],

View File

@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/lib/cholesky.h"
#include "tensorflow/compiler/xla/client/lib/cholesky.h"
#include <memory>
#include <vector>
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/lib/triangular_solve.h"
@ -31,7 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
namespace xla {
namespace {
@ -50,26 +50,25 @@ namespace {
// l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) /
// l[..., j, j]
// return l
xla::XlaOp CholeskyUnblocked(xla::XlaOp a,
xla::PrecisionConfig::Precision precision) {
xla::XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
const int n_dims = xla::ShapeUtil::Rank(a_shape);
const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
auto major_dims = xla::AsInt64Slice(a_shape.dimensions())
XlaOp CholeskyUnblocked(XlaOp a, PrecisionConfig::Precision precision) {
XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
const int n_dims = ShapeUtil::Rank(a_shape);
const int64 n = ShapeUtil::GetDimension(a_shape, -1);
auto major_dims = AsInt64Slice(a_shape.dimensions())
.subspan(
/*pos=*/0,
/*len=*/n_dims - 2);
xla::XlaOp l = xla::ZerosLike(a);
XlaOp l = ZerosLike(a);
// Construct the for loop body to iterate over rows.
auto body_fn = [&](xla::XlaOp i, absl::Span<const xla::XlaOp> loop_vars,
xla::XlaBuilder* body_builder)
-> xla::StatusOr<std::vector<xla::XlaOp>> {
xla::Shape col_shape;
xla::Shape row_shape;
auto body_fn =
[&](XlaOp i, absl::Span<const XlaOp> loop_vars,
XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
Shape col_shape;
Shape row_shape;
for (int64 d : major_dims) {
row_shape.add_dimensions(d);
col_shape.add_dimensions(d);
@ -77,43 +76,40 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a,
row_shape.add_dimensions(1);
row_shape.add_dimensions(n);
row_shape.set_element_type(a_shape.element_type());
auto mask_zeros_row = xla::Zeros(body_builder, row_shape);
auto mask_zeros_row = Zeros(body_builder, row_shape);
col_shape.add_dimensions(n);
col_shape.add_dimensions(1);
col_shape.set_element_type(a_shape.element_type());
auto mask_zeros_col = xla::Zeros(body_builder, col_shape);
auto mask_zeros_col = Zeros(body_builder, col_shape);
std::vector<int32> mask_vector(n);
std::iota(mask_vector.begin(), mask_vector.end(), 0);
auto mask_range = xla::ConstantR1<int32>(body_builder, mask_vector);
auto mask_range = ConstantR1<int32>(body_builder, mask_vector);
auto mask_range_row =
xla::Broadcast(xla::Reshape(mask_range, {0}, {1, n}), major_dims);
Broadcast(Reshape(mask_range, {0}, {1, n}), major_dims);
auto mask_range_col =
xla::Broadcast(xla::Reshape(mask_range, {0}, {n, 1}), major_dims);
Broadcast(Reshape(mask_range, {0}, {n, 1}), major_dims);
auto body_a = loop_vars[0];
auto body_l = loop_vars[1];
// row = l[..., i, :i]
// select the whole i-th row, then mask out all columns past i-1
auto zero = xla::ConstantR0<int32>(body_builder, 0);
auto zero = ConstantR0<int32>(body_builder, 0);
auto l_i = DynamicSliceInMinorDims(body_l, {i, zero}, {1, n});
auto row = xla::Select(xla::Ge(mask_range_row, i), mask_zeros_row, l_i);
auto row = Select(Ge(mask_range_row, i), mask_zeros_row, l_i);
// a[..., i, i]
auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1});
// np.dot(row, np.swapaxes(row, -1, -2))
auto diag_dot = BatchDot(row, TransposeInMinorDims(row), precision);
// l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row,
// np.swapaxes(row, -1, -2)))
auto l_ii =
xla::Pow(a_ii - diag_dot,
FloatLiteral(body_builder, a_shape.element_type(), 0.5));
auto l_ii = Sqrt(a_ii - diag_dot);
// a[..., i+1:, i]
// select the whole i-th column, then mask out all rows above i+1
auto a_0i = DynamicSliceInMinorDims(body_a, {i}, {1});
auto a_ip1i =
xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, a_0i);
auto a_ip1i = Select(Le(mask_range_col, i), mask_zeros_col, a_0i);
// l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) /
// l[..., i, i]
@ -122,8 +118,7 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a,
// r.T)
auto dot = BatchDot(body_l, TransposeInMinorDims(row), precision);
// np.dot(l[..., i+1:, :i], r.T)
auto dot_ip1 =
xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot);
auto dot_ip1 = Select(Le(mask_range_col, i), mask_zeros_col, dot);
body_l =
DynamicUpdateSliceInMinorDims(body_l, (a_ip1i - dot_ip1) / l_ii, {i});
@ -131,12 +126,12 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a,
// column assign will wrap around and overwrite the diagonal assign.
body_l = DynamicUpdateSliceInMinorDims(body_l, l_ii, {i, i});
return std::vector<xla::XlaOp>{body_a, body_l};
return std::vector<XlaOp>{body_a, body_l};
};
TF_ASSIGN_OR_RETURN(
auto cholesky_while,
XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder));
ForEachIndex(n, S32, body_fn, {a, l}, "unblocked", builder));
return cholesky_while[1];
});
@ -144,34 +139,35 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a,
} // namespace
xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size,
xla::PrecisionConfig::Precision precision) {
xla::XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
const int ndims = xla::ShapeUtil::Rank(a_shape);
XlaOp Cholesky(XlaOp a, int64 block_size,
PrecisionConfig::Precision precision) {
XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
const int ndims = ShapeUtil::Rank(a_shape);
if (ndims < 2) {
return errors::InvalidArgument(
"Arguments to Cholesky must have rank >= 2: ", ndims);
return InvalidArgument(
"Argument to Cholesky must have rank >= 2; shape was %s",
a_shape.ToString());
}
const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) {
return errors::InvalidArgument(
"Arguments to Cholesky must be square matrices: ",
xla::ShapeUtil::HumanString(a_shape));
const int64 n = ShapeUtil::GetDimension(a_shape, -1);
if (n != ShapeUtil::GetDimension(a_shape, -2)) {
return InvalidArgument(
"Argument to Cholesky must be batched square matrices; got shape %s",
ShapeUtil::HumanString(a_shape));
}
if (block_size < 1) {
return errors::InvalidArgument(
"block_size argument to Cholesky must be >= 1; got ", block_size);
return InvalidArgument(
"block_size argument to Cholesky must be >= 1; got %d", block_size);
}
// Blocked left-looking Cholesky factorization.
// Algorithm 1 from
// Haidar, Azzam, et al. "High-performance Cholesky factorization for
// GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017.
xla::XlaOp l = xla::ZerosLike(a);
XlaOp l = ZerosLike(a);
for (int64 i = 0; i < n; i += block_size) {
int64 k = std::min(block_size, n - i);
if (i > 0) {
@ -207,4 +203,4 @@ xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size,
});
}
} // namespace tensorflow
} // namespace xla

View File

@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace tensorflow {
namespace xla {
// Computes the Cholesky decompositions of a batch of symmetric positive
// definite matrices.
@ -34,6 +34,6 @@ xla::XlaOp Cholesky(
xla::XlaOp a, int64 block_size = 256,
xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST);
} // namespace tensorflow
} // namespace xla
#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_

View File

@ -0,0 +1,166 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/cholesky.h"
#include <memory>
#include <numeric>
#include <vector>
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace {
using xla::int64;
using CholeskyTest = xla::ClientLibraryTestBase;
XLA_TEST_F(CholeskyTest, Simple) {
xla::XlaBuilder builder(TestName());
xla::Array2D<float> a_vals({
{4, 6, 8, 10},
{6, 45, 54, 63},
{8, 54, 146, 166},
{10, 63, 166, 310},
});
xla::XlaOp a;
auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
xla::Cholesky(a, /*block_size=*/2);
xla::Array2D<float> expected({
{2, 0, 0, 0},
{3, 6, 0, 0},
{4, 7, 9, 0},
{5, 8, 10, 11},
});
ComputeAndCompareR2<float>(&builder, expected, {a_data.get()},
xla::ErrorSpec(1e-4, 1e-4));
}
XLA_TEST_F(CholeskyTest, Simple2) {
xla::XlaBuilder builder(TestName());
xla::Array2D<float> a_vals({
{16, 24, 8, 12},
{24, 61, 82, 48},
{8, 82, 456, 106},
{12, 48, 106, 62},
});
xla::XlaOp a;
auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
xla::Cholesky(a);
xla::Array2D<float> expected(
{{4, 0, 0, 0}, {6, 5, 0, 0}, {2, 14, 16, 0}, {3, 6, 1, 4}});
ComputeAndCompareR2<float>(&builder, expected, {a_data.get()},
xla::ErrorSpec(1e-4, 1e-4));
}
XLA_TEST_F(CholeskyTest, SimpleBatched) {
xla::XlaBuilder builder(TestName());
xla::Array3D<float> a_vals({
{
{4, 6, 8, 10},
{6, 45, 54, 63},
{8, 54, 146, 166},
{10, 63, 166, 310},
},
{
{16, 24, 8, 12},
{24, 61, 82, 48},
{8, 82, 456, 106},
{12, 48, 106, 62},
},
});
xla::XlaOp a;
auto a_data = CreateR3Parameter<float>(a_vals, 0, "a", &builder, &a);
xla::Cholesky(a);
xla::Array3D<float> expected({
{
{2, 0, 0, 0},
{3, 6, 0, 0},
{4, 7, 9, 0},
{5, 8, 10, 11},
},
{{4, 0, 0, 0}, {6, 5, 0, 0}, {2, 14, 16, 0}, {3, 6, 1, 4}},
});
ComputeAndCompareR3<float>(&builder, expected, {a_data.get()},
xla::ErrorSpec(1e-4, 1e-4));
}
using CholeskyTestCase = std::tuple<int64, int64>;
class RandomCholeskyTest
: public xla::ClientLibraryTestBase,
public ::testing::WithParamInterface<CholeskyTestCase> {};
XLA_TEST_P(RandomCholeskyTest, Random) {
xla::XlaBuilder builder(TestName());
auto test_params = GetParam();
std::vector<int64> dimensions = {std::get<0>(test_params),
std::get<1>(test_params),
std::get<1>(test_params)};
xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, dimensions);
TF_ASSERT_OK_AND_ASSIGN(
auto literal,
xla::LiteralUtil::CreateRandomLiteral<xla::F32>(shape, 0.0, 1.0));
auto input = xla::Parameter(&builder, 0, shape, "input");
// Form a random positive definite matrix.
auto matrix = xla::BatchDot(input, TransposeInMinorDims(input),
xla::PrecisionConfig::HIGHEST);
auto cholesky = xla::Cholesky(matrix, /*block_size=*/4);
// Verify that ||matrix - cholesky * cholesky_t||_2 ~= 0
auto verification = xla::BatchDot(cholesky, TransposeInMinorDims(cholesky),
xla::PrecisionConfig::HIGHEST);
auto delta = matrix - verification;
xla::Reduce(delta * delta, xla::ConstantR0<float>(&builder, 0.0),
CreateScalarAddComputation(xla::F32, &builder), {0, 1, 2});
TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(literal));
ComputeAndCompareR0<float>(&builder, 0.0, {input_data.get()},
xla::ErrorSpec(1e-4, 1e-4));
}
INSTANTIATE_TEST_CASE_P(RandomCholeskyTestInstance, RandomCholeskyTest,
::testing::Values(CholeskyTestCase{1, 1},
CholeskyTestCase{1, 2},
CholeskyTestCase{10, 5},
CholeskyTestCase{2, 20}));
} // namespace

View File

@ -13,44 +13,43 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
namespace tensorflow {
namespace xla {
xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
const LoopConditionFunction& condition_function,
const LoopBodyFunction& body_function,
absl::Span<const xla::XlaOp> initial_values, absl::string_view name,
xla::XlaBuilder* builder) {
StatusOr<std::vector<XlaOp>> WhileLoopHelper(
const WhileLoopHelperConditionFunction& condition_function,
const WhileLoopHelperBodyFunction& body_function,
absl::Span<const XlaOp> initial_values, absl::string_view name,
XlaBuilder* builder) {
int arity = initial_values.size();
std::vector<xla::Shape> var_shapes;
std::vector<Shape> var_shapes;
var_shapes.reserve(arity);
for (const xla::XlaOp& input : initial_values) {
for (const XlaOp& input : initial_values) {
TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(input));
var_shapes.push_back(std::move(shape));
}
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(var_shapes);
Shape tuple_shape = ShapeUtil::MakeTupleShape(var_shapes);
// Unpacks a tuple into its component parts.
auto unpack_tuple = [](xla::XlaOp tuple, int arity,
xla::XlaBuilder* builder) {
std::vector<xla::XlaOp> elements(arity);
auto unpack_tuple = [](XlaOp tuple, int arity, XlaBuilder* builder) {
std::vector<XlaOp> elements(arity);
for (int i = 0; i < arity; ++i) {
elements[i] = xla::GetTupleElement(tuple, i);
elements[i] = GetTupleElement(tuple, i);
}
return elements;
};
// Build the condition.
std::unique_ptr<xla::XlaBuilder> cond_builder =
std::unique_ptr<XlaBuilder> cond_builder =
builder->CreateSubBuilder(absl::StrCat(name, "_condition"));
{
auto parameter =
xla::Parameter(cond_builder.get(), 0, tuple_shape, "parameter");
auto parameter = Parameter(cond_builder.get(), 0, tuple_shape, "parameter");
TF_RETURN_IF_ERROR(
condition_function(unpack_tuple(parameter, arity, cond_builder.get()),
@ -60,11 +59,10 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build());
// Build the body.
std::unique_ptr<xla::XlaBuilder> body_builder =
std::unique_ptr<XlaBuilder> body_builder =
builder->CreateSubBuilder(absl::StrCat(name, "_body"));
{
auto parameter =
xla::Parameter(body_builder.get(), 0, tuple_shape, "parameter");
auto parameter = Parameter(body_builder.get(), 0, tuple_shape, "parameter");
TF_ASSIGN_OR_RETURN(
auto result,
@ -72,56 +70,54 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
body_builder.get()));
TF_RET_CHECK(result.size() == initial_values.size());
xla::Tuple(body_builder.get(), result);
Tuple(body_builder.get(), result);
}
TF_ASSIGN_OR_RETURN(auto body, body_builder->Build());
auto outputs = xla::While(cond, body, xla::Tuple(builder, initial_values));
auto outputs = While(cond, body, Tuple(builder, initial_values));
return unpack_tuple(outputs, arity, builder);
}
xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
int64 num_iterations, xla::PrimitiveType num_iterations_type,
StatusOr<std::vector<XlaOp>> ForEachIndex(
int64 num_iterations, PrimitiveType num_iterations_type,
const ForEachIndexBodyFunction& body_function,
absl::Span<const xla::XlaOp> initial_values, absl::string_view name,
xla::XlaBuilder* builder) {
auto while_cond_fn =
[&](absl::Span<const xla::XlaOp> values,
xla::XlaBuilder* cond_builder) -> xla::StatusOr<xla::XlaOp> {
return xla::Lt(values[0], IntegerLiteral(cond_builder, num_iterations_type,
absl::Span<const XlaOp> initial_values, absl::string_view name,
XlaBuilder* builder) {
auto while_cond_fn = [&](absl::Span<const XlaOp> values,
XlaBuilder* cond_builder) -> StatusOr<XlaOp> {
return Lt(values[0], ConstantR0WithType(cond_builder, num_iterations_type,
num_iterations));
};
auto while_body_fn = [&](absl::Span<const xla::XlaOp> values,
xla::XlaBuilder* body_builder)
-> xla::StatusOr<std::vector<xla::XlaOp>> {
xla::XlaOp iteration = values[0];
auto while_body_fn =
[&](absl::Span<const XlaOp> values,
XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
XlaOp iteration = values[0];
std::vector<xla::XlaOp> updated_values;
std::vector<XlaOp> updated_values;
updated_values.reserve(values.size());
updated_values.push_back(xla::Add(
updated_values.push_back(Add(
iteration,
xla::ConstantLiteral(body_builder,
xla::LiteralUtil::One(num_iterations_type))));
ConstantLiteral(body_builder, LiteralUtil::One(num_iterations_type))));
values.remove_prefix(1);
TF_ASSIGN_OR_RETURN(std::vector<xla::XlaOp> body_outputs,
TF_ASSIGN_OR_RETURN(std::vector<XlaOp> body_outputs,
body_function(iteration, values, body_builder));
updated_values.insert(updated_values.end(), body_outputs.begin(),
body_outputs.end());
return updated_values;
};
std::vector<xla::XlaOp> values;
std::vector<XlaOp> values;
values.reserve(initial_values.size() + 1);
values.push_back(xla::ConstantLiteral(
builder, xla::LiteralUtil::Zero(num_iterations_type)));
values.push_back(
ConstantLiteral(builder, LiteralUtil::Zero(num_iterations_type)));
values.insert(values.end(), initial_values.begin(), initial_values.end());
TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values,
name, builder));
TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn,
values, name, builder));
values.erase(values.begin(), values.begin() + 1);
return values;
}
} // namespace tensorflow
} // namespace xla

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_
#include <functional>
#include <vector>
@ -25,19 +25,18 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace tensorflow {
namespace xla {
// Function that builds a loop condition. Takes as input a sequence of input
// values, and returns a boolean value representing if the condition succeeds.
typedef std::function<xla::StatusOr<xla::XlaOp>(absl::Span<const xla::XlaOp>,
xla::XlaBuilder*)>
LoopConditionFunction;
typedef std::function<StatusOr<XlaOp>(absl::Span<const XlaOp>, XlaBuilder*)>
WhileLoopHelperConditionFunction;
// Function that builds a loop body. Takes as input a sequence of input values
// and returns a sequence of output values.
typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>(
absl::Span<const xla::XlaOp>, xla::XlaBuilder*)>
LoopBodyFunction;
typedef std::function<StatusOr<std::vector<XlaOp>>(absl::Span<const XlaOp>,
XlaBuilder*)>
WhileLoopHelperBodyFunction;
// Helper function for building an XLA while loop, where the values carried by
// the loop are a tuple of values, e.g., (a, b, c):
@ -47,27 +46,27 @@ typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>(
// init: (a, b, c)
// )
// 'name' is a descriptive name for the loop.
xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
const LoopConditionFunction& condition_function,
const LoopBodyFunction& body_function,
absl::Span<const xla::XlaOp> initial_values, absl::string_view name,
xla::XlaBuilder* builder);
StatusOr<std::vector<XlaOp>> WhileLoopHelper(
const WhileLoopHelperConditionFunction& condition_function,
const WhileLoopHelperBodyFunction& body_function,
absl::Span<const XlaOp> initial_values, absl::string_view name,
XlaBuilder* builder);
// Builds an XLA loop that repeats a computation `num_iterations` times.
//
// The body function (ForEachIndexBodyFunction) takes as input a pair of
// (current iteration number, loop-carried values), and returns an updated
// vector of the loop-carried values.
typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>(
xla::XlaOp, absl::Span<const xla::XlaOp>, xla::XlaBuilder*)>
typedef std::function<StatusOr<std::vector<XlaOp>>(
XlaOp, absl::Span<const XlaOp>, XlaBuilder*)>
ForEachIndexBodyFunction;
xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
int64 num_iterations, xla::PrimitiveType num_iterations_type,
StatusOr<std::vector<XlaOp>> ForEachIndex(
int64 num_iterations, PrimitiveType num_iterations_type,
const ForEachIndexBodyFunction& body_function,
absl::Span<const xla::XlaOp> initial_values, absl::string_view name,
xla::XlaBuilder* builder);
absl::Span<const XlaOp> initial_values, absl::string_view name,
XlaBuilder* builder);
} // namespace tensorflow
} // namespace xla
#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_