[XLA] [TF:XLA] Move Cholesky decomposition into xla/client/lib/cholesky.*

Move loop helpers used by Cholesky decomposition into xla/client/lib/loops.*.

PiperOrigin-RevId: 225037112
This commit is contained in:
Peter Hawkins 2018-12-11 11:15:58 -08:00 committed by TensorFlower Gardener
parent c99ecfa992
commit 9b964193d9
14 changed files with 361 additions and 194 deletions

View File

@ -1,16 +1,11 @@
load("//tensorflow:tensorflow.bzl", "tf_copts", "tf_kernel_library")
licenses(["notice"]) # Apache 2.0 licenses(["notice"]) # Apache 2.0
package( package(
default_visibility = ["//tensorflow/compiler/tf2xla:internal"], default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
) )
load("//tensorflow:tensorflow.bzl", "tf_copts")
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
)
tf_kernel_library( tf_kernel_library(
name = "xla_ops", name = "xla_ops",
srcs = [ srcs = [
@ -122,12 +117,10 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/lib:broadcast", "//tensorflow/compiler/tf2xla/lib:broadcast",
"//tensorflow/compiler/tf2xla/lib:cholesky",
"//tensorflow/compiler/tf2xla/lib:qr", "//tensorflow/compiler/tf2xla/lib:qr",
"//tensorflow/compiler/tf2xla/lib:random", "//tensorflow/compiler/tf2xla/lib:random",
"//tensorflow/compiler/tf2xla/lib:scatter", "//tensorflow/compiler/tf2xla/lib:scatter",
"//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/tf2xla/lib:util",
"//tensorflow/compiler/tf2xla/lib:while_loop",
"//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal",
@ -140,7 +133,9 @@ tf_kernel_library(
"//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:cholesky",
"//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:loops",
"//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/lib:matrix", "//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/client/lib:pooling", "//tensorflow/compiler/xla/client/lib:pooling",

View File

@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/tf2xla/lib/cholesky.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/cholesky.h"
namespace tensorflow { namespace tensorflow {
namespace { namespace {
@ -24,7 +24,7 @@ class CholeskyOp : public XlaOpKernel {
public: public:
explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override { void Compile(XlaOpKernelContext* ctx) override {
ctx->SetOutput(0, Cholesky(ctx->Input(0))); ctx->SetOutput(0, xla::Cholesky(ctx->Input(0)));
} }
}; };

View File

@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_context.h"

View File

@ -15,12 +15,12 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/sorting.h" #include "tensorflow/compiler/xla/client/lib/sorting.h"
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
@ -505,9 +505,9 @@ class NonMaxSuppressionOp : public XlaOpKernel {
init_values.push_back(included_iou); init_values.push_back(included_iou);
auto suppress_loop_result = auto suppress_loop_result =
XlaWhileLoop(WhileCondFn(num_boxes, output_size), xla::WhileLoopHelper(WhileCondFn(num_boxes, output_size),
SuppressBodyFn(num_boxes), init_values, "suppress_loop", SuppressBodyFn(num_boxes), init_values,
builder) "suppress_loop", builder)
.ValueOrDie(); .ValueOrDie();
xla::XlaOp included_score = xla::XlaOp included_score =

View File

@ -20,12 +20,12 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/random.h" #include "tensorflow/compiler/tf2xla/lib/random.h"
#include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
@ -175,7 +175,7 @@ class RandomShuffleOp : public XlaOpKernel {
}; };
// for i in range(n): // for i in range(n):
auto swap_loop_result = auto swap_loop_result =
XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices}, xla::ForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices},
"indices_swap_loop", builder) "indices_swap_loop", builder)
.ValueOrDie(); .ValueOrDie();
auto swapped_indices = swap_loop_result[1]; auto swapped_indices = swap_loop_result[1];

View File

@ -15,8 +15,6 @@ filegroup(
]), ]),
) )
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
cc_library( cc_library(
name = "broadcast", name = "broadcast",
srcs = ["broadcast.cc"], srcs = ["broadcast.cc"],
@ -33,27 +31,6 @@ cc_library(
], ],
) )
cc_library(
name = "cholesky",
srcs = ["cholesky.cc"],
hdrs = ["cholesky.h"],
deps = [
":util",
":while_loop",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/client/lib:slicing",
"//tensorflow/compiler/xla/client/lib:triangular_solve",
"//tensorflow/core:lib",
],
)
cc_library( cc_library(
name = "random", name = "random",
srcs = ["random.cc"], srcs = ["random.cc"],
@ -75,7 +52,6 @@ cc_library(
hdrs = ["qr.h"], hdrs = ["qr.h"],
deps = [ deps = [
":util", ":util",
":while_loop",
"//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:status_macros",
@ -84,6 +60,7 @@ cc_library(
"//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:loops",
"//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/lib:matrix", "//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/client/lib:slicing", "//tensorflow/compiler/xla/client/lib:slicing",
@ -97,7 +74,6 @@ cc_library(
hdrs = ["scatter.h"], hdrs = ["scatter.h"],
deps = [ deps = [
":util", ":util",
":while_loop",
"//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:status_macros",
@ -128,19 +104,3 @@ cc_library(
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
], ],
) )
cc_library(
name = "while_loop",
srcs = ["while_loop.cc"],
hdrs = ["while_loop.h"],
deps = [
":util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)

View File

@ -19,9 +19,9 @@ limitations under the License.
#include <vector> #include <vector>
#include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/lib/slicing.h"
@ -225,7 +225,7 @@ xla::StatusOr<QRBlockResult> QRBlock(
builder, xla::ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n}))); builder, xla::ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n})));
TF_ASSIGN_OR_RETURN(auto values, TF_ASSIGN_OR_RETURN(auto values,
XlaForEachIndex(std::min(m, n), xla::S32, qr_body_fn, xla::ForEachIndex(std::min(m, n), xla::S32, qr_body_fn,
{a, vs, taus}, "qr", builder)); {a, vs, taus}, "qr", builder));
QRBlockResult result; QRBlockResult result;
@ -301,7 +301,7 @@ xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
w = UpdateSliceInMinorDims(w, bv, {0}); w = UpdateSliceInMinorDims(w, bv, {0});
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
auto values, XlaForEachIndex(n - 1, xla::S32, body_fn, {w, y, vs, taus}, auto values, xla::ForEachIndex(n - 1, xla::S32, body_fn, {w, y, vs, taus},
"wy", builder)); "wy", builder));
return values[0]; return values[0];
} }

View File

@ -20,7 +20,6 @@ limitations under the License.
#include "absl/types/span.h" #include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal.h"

View File

@ -1,5 +1,7 @@
# Common computation builders for XLA. # Common computation builders for XLA.
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites", "xla_test")
licenses(["notice"]) # Apache 2.0 licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//tensorflow/compiler/xla/client:friends"]) package(default_visibility = ["//tensorflow/compiler/xla/client:friends"])
@ -13,9 +15,6 @@ filegroup(
]), ]),
) )
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites")
# Generate test_suites for all backends, named "${backend}_tests". # Generate test_suites for all backends, named "${backend}_tests".
generate_backend_suites() generate_backend_suites()
@ -35,6 +34,48 @@ cc_library(
], ],
) )
cc_library(
name = "cholesky",
srcs = ["cholesky.cc"],
hdrs = ["cholesky.h"],
deps = [
":math",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:loops",
"//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/client/lib:slicing",
"//tensorflow/compiler/xla/client/lib:triangular_solve",
"//tensorflow/core:lib",
],
)
xla_test(
name = "cholesky_test",
srcs = ["cholesky_test.cc"],
tags = ["optonly"],
deps = [
":arithmetic",
":cholesky",
":matrix",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)
cc_library( cc_library(
name = "constants", name = "constants",
srcs = ["constants.cc"], srcs = ["constants.cc"],
@ -75,6 +116,22 @@ cc_library(
], ],
) )
cc_library(
name = "loops",
srcs = ["loops.cc"],
hdrs = ["loops.h"],
deps = [
":constants",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
cc_library( cc_library(
name = "math", name = "math",
srcs = ["math.cc"], srcs = ["math.cc"],

View File

@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/tf2xla/lib/cholesky.h" #include "tensorflow/compiler/xla/client/lib/cholesky.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" #include "tensorflow/compiler/xla/client/lib/triangular_solve.h"
@ -31,7 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
namespace tensorflow { namespace xla {
namespace { namespace {
@ -50,26 +50,25 @@ namespace {
// l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) / // l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) /
// l[..., j, j] // l[..., j, j]
// return l // return l
xla::XlaOp CholeskyUnblocked(xla::XlaOp a, XlaOp CholeskyUnblocked(XlaOp a, PrecisionConfig::Precision precision) {
xla::PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder();
xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); const int n_dims = ShapeUtil::Rank(a_shape);
const int n_dims = xla::ShapeUtil::Rank(a_shape); const int64 n = ShapeUtil::GetDimension(a_shape, -1);
const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); auto major_dims = AsInt64Slice(a_shape.dimensions())
auto major_dims = xla::AsInt64Slice(a_shape.dimensions())
.subspan( .subspan(
/*pos=*/0, /*pos=*/0,
/*len=*/n_dims - 2); /*len=*/n_dims - 2);
xla::XlaOp l = xla::ZerosLike(a); XlaOp l = ZerosLike(a);
// Construct the for loop body to iterate over rows. // Construct the for loop body to iterate over rows.
auto body_fn = [&](xla::XlaOp i, absl::Span<const xla::XlaOp> loop_vars, auto body_fn =
xla::XlaBuilder* body_builder) [&](XlaOp i, absl::Span<const XlaOp> loop_vars,
-> xla::StatusOr<std::vector<xla::XlaOp>> { XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
xla::Shape col_shape; Shape col_shape;
xla::Shape row_shape; Shape row_shape;
for (int64 d : major_dims) { for (int64 d : major_dims) {
row_shape.add_dimensions(d); row_shape.add_dimensions(d);
col_shape.add_dimensions(d); col_shape.add_dimensions(d);
@ -77,43 +76,40 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a,
row_shape.add_dimensions(1); row_shape.add_dimensions(1);
row_shape.add_dimensions(n); row_shape.add_dimensions(n);
row_shape.set_element_type(a_shape.element_type()); row_shape.set_element_type(a_shape.element_type());
auto mask_zeros_row = xla::Zeros(body_builder, row_shape); auto mask_zeros_row = Zeros(body_builder, row_shape);
col_shape.add_dimensions(n); col_shape.add_dimensions(n);
col_shape.add_dimensions(1); col_shape.add_dimensions(1);
col_shape.set_element_type(a_shape.element_type()); col_shape.set_element_type(a_shape.element_type());
auto mask_zeros_col = xla::Zeros(body_builder, col_shape); auto mask_zeros_col = Zeros(body_builder, col_shape);
std::vector<int32> mask_vector(n); std::vector<int32> mask_vector(n);
std::iota(mask_vector.begin(), mask_vector.end(), 0); std::iota(mask_vector.begin(), mask_vector.end(), 0);
auto mask_range = xla::ConstantR1<int32>(body_builder, mask_vector); auto mask_range = ConstantR1<int32>(body_builder, mask_vector);
auto mask_range_row = auto mask_range_row =
xla::Broadcast(xla::Reshape(mask_range, {0}, {1, n}), major_dims); Broadcast(Reshape(mask_range, {0}, {1, n}), major_dims);
auto mask_range_col = auto mask_range_col =
xla::Broadcast(xla::Reshape(mask_range, {0}, {n, 1}), major_dims); Broadcast(Reshape(mask_range, {0}, {n, 1}), major_dims);
auto body_a = loop_vars[0]; auto body_a = loop_vars[0];
auto body_l = loop_vars[1]; auto body_l = loop_vars[1];
// row = l[..., i, :i] // row = l[..., i, :i]
// select the whole i-th row, then mask out all columns past i-1 // select the whole i-th row, then mask out all columns past i-1
auto zero = xla::ConstantR0<int32>(body_builder, 0); auto zero = ConstantR0<int32>(body_builder, 0);
auto l_i = DynamicSliceInMinorDims(body_l, {i, zero}, {1, n}); auto l_i = DynamicSliceInMinorDims(body_l, {i, zero}, {1, n});
auto row = xla::Select(xla::Ge(mask_range_row, i), mask_zeros_row, l_i); auto row = Select(Ge(mask_range_row, i), mask_zeros_row, l_i);
// a[..., i, i] // a[..., i, i]
auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1}); auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1});
// np.dot(row, np.swapaxes(row, -1, -2)) // np.dot(row, np.swapaxes(row, -1, -2))
auto diag_dot = BatchDot(row, TransposeInMinorDims(row), precision); auto diag_dot = BatchDot(row, TransposeInMinorDims(row), precision);
// l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row,
// np.swapaxes(row, -1, -2))) // np.swapaxes(row, -1, -2)))
auto l_ii = auto l_ii = Sqrt(a_ii - diag_dot);
xla::Pow(a_ii - diag_dot,
FloatLiteral(body_builder, a_shape.element_type(), 0.5));
// a[..., i+1:, i] // a[..., i+1:, i]
// select the whole i-th column, then mask out all rows above i+1 // select the whole i-th column, then mask out all rows above i+1
auto a_0i = DynamicSliceInMinorDims(body_a, {i}, {1}); auto a_0i = DynamicSliceInMinorDims(body_a, {i}, {1});
auto a_ip1i = auto a_ip1i = Select(Le(mask_range_col, i), mask_zeros_col, a_0i);
xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, a_0i);
// l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) / // l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) /
// l[..., i, i] // l[..., i, i]
@ -122,8 +118,7 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a,
// r.T) // r.T)
auto dot = BatchDot(body_l, TransposeInMinorDims(row), precision); auto dot = BatchDot(body_l, TransposeInMinorDims(row), precision);
// np.dot(l[..., i+1:, :i], r.T) // np.dot(l[..., i+1:, :i], r.T)
auto dot_ip1 = auto dot_ip1 = Select(Le(mask_range_col, i), mask_zeros_col, dot);
xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot);
body_l = body_l =
DynamicUpdateSliceInMinorDims(body_l, (a_ip1i - dot_ip1) / l_ii, {i}); DynamicUpdateSliceInMinorDims(body_l, (a_ip1i - dot_ip1) / l_ii, {i});
@ -131,12 +126,12 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a,
// column assign will wrap around and overwrite the diagonal assign. // column assign will wrap around and overwrite the diagonal assign.
body_l = DynamicUpdateSliceInMinorDims(body_l, l_ii, {i, i}); body_l = DynamicUpdateSliceInMinorDims(body_l, l_ii, {i, i});
return std::vector<xla::XlaOp>{body_a, body_l}; return std::vector<XlaOp>{body_a, body_l};
}; };
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
auto cholesky_while, auto cholesky_while,
XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder)); ForEachIndex(n, S32, body_fn, {a, l}, "unblocked", builder));
return cholesky_while[1]; return cholesky_while[1];
}); });
@ -144,34 +139,35 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a,
} // namespace } // namespace
xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size, XlaOp Cholesky(XlaOp a, int64 block_size,
xla::PrecisionConfig::Precision precision) { PrecisionConfig::Precision precision) {
xla::XlaBuilder* builder = a.builder(); XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> { return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
const int ndims = xla::ShapeUtil::Rank(a_shape); const int ndims = ShapeUtil::Rank(a_shape);
if (ndims < 2) { if (ndims < 2) {
return errors::InvalidArgument( return InvalidArgument(
"Arguments to Cholesky must have rank >= 2: ", ndims); "Argument to Cholesky must have rank >= 2; shape was %s",
a_shape.ToString());
} }
const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); const int64 n = ShapeUtil::GetDimension(a_shape, -1);
if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) { if (n != ShapeUtil::GetDimension(a_shape, -2)) {
return errors::InvalidArgument( return InvalidArgument(
"Arguments to Cholesky must be square matrices: ", "Argument to Cholesky must be batched square matrices; got shape %s",
xla::ShapeUtil::HumanString(a_shape)); ShapeUtil::HumanString(a_shape));
} }
if (block_size < 1) { if (block_size < 1) {
return errors::InvalidArgument( return InvalidArgument(
"block_size argument to Cholesky must be >= 1; got ", block_size); "block_size argument to Cholesky must be >= 1; got %d", block_size);
} }
// Blocked left-looking Cholesky factorization. // Blocked left-looking Cholesky factorization.
// Algorithm 1 from // Algorithm 1 from
// Haidar, Azzam, et al. "High-performance Cholesky factorization for // Haidar, Azzam, et al. "High-performance Cholesky factorization for
// GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017. // GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017.
xla::XlaOp l = xla::ZerosLike(a); XlaOp l = ZerosLike(a);
for (int64 i = 0; i < n; i += block_size) { for (int64 i = 0; i < n; i += block_size) {
int64 k = std::min(block_size, n - i); int64 k = std::min(block_size, n - i);
if (i > 0) { if (i > 0) {
@ -207,4 +203,4 @@ xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size,
}); });
} }
} // namespace tensorflow } // namespace xla

View File

@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
namespace tensorflow { namespace xla {
// Computes the Cholesky decompositions of a batch of symmetric positive // Computes the Cholesky decompositions of a batch of symmetric positive
// definite matrices. // definite matrices.
@ -34,6 +34,6 @@ xla::XlaOp Cholesky(
xla::XlaOp a, int64 block_size = 256, xla::XlaOp a, int64 block_size = 256,
xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST);
} // namespace tensorflow } // namespace xla
#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_

View File

@ -0,0 +1,166 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/cholesky.h"
#include <memory>
#include <numeric>
#include <vector>
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace {
using xla::int64;
using CholeskyTest = xla::ClientLibraryTestBase;
// Factors a hand-built 4x4 symmetric matrix using a block size (2) smaller
// than the matrix, and compares the result with the known lower-triangular
// factor.
XLA_TEST_F(CholeskyTest, Simple) {
  // Known factor L; the input matrix below equals L * L^T.
  const xla::Array2D<float> expected_factor({
      {2, 0, 0, 0},
      {3, 6, 0, 0},
      {4, 7, 9, 0},
      {5, 8, 10, 11},
  });
  const xla::Array2D<float> input_matrix({
      {4, 6, 8, 10},
      {6, 45, 54, 63},
      {8, 54, 146, 166},
      {10, 63, 166, 310},
  });
  const xla::ErrorSpec tolerance(1e-4, 1e-4);

  xla::XlaBuilder builder(TestName());
  xla::XlaOp input_op;
  auto input_handle =
      CreateR2Parameter<float>(input_matrix, 0, "a", &builder, &input_op);
  // The decomposition is the last op added to the builder before comparing.
  xla::Cholesky(input_op, /*block_size=*/2);

  ComputeAndCompareR2<float>(&builder, expected_factor, {input_handle.get()},
                             tolerance);
}
// Factors a second hand-built 4x4 symmetric matrix, this time relying on the
// default block_size argument of xla::Cholesky.
XLA_TEST_F(CholeskyTest, Simple2) {
  // Known factor L; the input matrix below equals L * L^T.
  const xla::Array2D<float> expected_factor({
      {4, 0, 0, 0},
      {6, 5, 0, 0},
      {2, 14, 16, 0},
      {3, 6, 1, 4},
  });
  const xla::Array2D<float> input_matrix({
      {16, 24, 8, 12},
      {24, 61, 82, 48},
      {8, 82, 456, 106},
      {12, 48, 106, 62},
  });
  const xla::ErrorSpec tolerance(1e-4, 1e-4);

  xla::XlaBuilder builder(TestName());
  xla::XlaOp input_op;
  auto input_handle =
      CreateR2Parameter<float>(input_matrix, 0, "a", &builder, &input_op);
  // The decomposition is the last op added to the builder before comparing.
  xla::Cholesky(input_op);

  ComputeAndCompareR2<float>(&builder, expected_factor, {input_handle.get()},
                             tolerance);
}
// Factors a batch of two 4x4 symmetric matrices (the inputs of the Simple and
// Simple2 tests stacked along a leading dimension) in a single call.
XLA_TEST_F(CholeskyTest, SimpleBatched) {
  // Known per-matrix factors; each input matrix below equals L * L^T for the
  // factor at the same batch index.
  const xla::Array3D<float> expected_factors({
      {
          {2, 0, 0, 0},
          {3, 6, 0, 0},
          {4, 7, 9, 0},
          {5, 8, 10, 11},
      },
      {
          {4, 0, 0, 0},
          {6, 5, 0, 0},
          {2, 14, 16, 0},
          {3, 6, 1, 4},
      },
  });
  const xla::Array3D<float> input_batch({
      {
          {4, 6, 8, 10},
          {6, 45, 54, 63},
          {8, 54, 146, 166},
          {10, 63, 166, 310},
      },
      {
          {16, 24, 8, 12},
          {24, 61, 82, 48},
          {8, 82, 456, 106},
          {12, 48, 106, 62},
      },
  });
  const xla::ErrorSpec tolerance(1e-4, 1e-4);

  xla::XlaBuilder builder(TestName());
  xla::XlaOp input_op;
  auto input_handle =
      CreateR3Parameter<float>(input_batch, 0, "a", &builder, &input_op);
  // The decomposition is the last op added to the builder before comparing.
  xla::Cholesky(input_op);

  ComputeAndCompareR3<float>(&builder, expected_factors, {input_handle.get()},
                             tolerance);
}
// Parameter for the random-input test. The Random test body reads element 0 as
// the leading (first) dimension of the input and element 1 as the size of both
// trailing dimensions, i.e. the input shape is [get<0>, get<1>, get<1>].
using CholeskyTestCase = std::tuple<int64, int64>;
// Value-parameterized fixture for the random-input Cholesky test; it adds no
// state of its own beyond the base class and the parameter interface.
class RandomCholeskyTest
: public xla::ClientLibraryTestBase,
public ::testing::WithParamInterface<CholeskyTestCase> {};
// Round-trip check on random inputs: builds matrix = input * input^T from a
// random F32 input, factors it with block_size 4, rebuilds
// cholesky * cholesky^T, and compares the summed squared difference with 0.
XLA_TEST_P(RandomCholeskyTest, Random) {
xla::XlaBuilder builder(TestName());
auto test_params = GetParam();
// Input shape is [get<0>, get<1>, get<1>]: the second tuple element is used
// for both trailing dimensions, so each trailing 2-D slice is square.
std::vector<int64> dimensions = {std::get<0>(test_params),
std::get<1>(test_params),
std::get<1>(test_params)};
xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, dimensions);
// NOTE(review): the 0.0 and 1.0 arguments are presumably the mean and
// standard deviation of the generated values — confirm against
// CreateRandomLiteral.
TF_ASSERT_OK_AND_ASSIGN(
auto literal,
xla::LiteralUtil::CreateRandomLiteral<xla::F32>(shape, 0.0, 1.0));
auto input = xla::Parameter(&builder, 0, shape, "input");
// Form a random positive definite matrix: input * input^T is symmetric and
// positive semi-definite by construction.
// NOTE(review): strict positive definiteness relies on the random input being
// full rank.
auto matrix = xla::BatchDot(input, TransposeInMinorDims(input),
xla::PrecisionConfig::HIGHEST);
auto cholesky = xla::Cholesky(matrix, /*block_size=*/4);
// Verify that matrix ~= cholesky * cholesky_t: the element-wise squared
// difference is reduced over all three dimensions with the scalar add
// computation, and the resulting scalar is compared with 0 below.
auto verification = xla::BatchDot(cholesky, TransposeInMinorDims(cholesky),
xla::PrecisionConfig::HIGHEST);
auto delta = matrix - verification;
xla::Reduce(delta * delta, xla::ConstantR0<float>(&builder, 0.0),
CreateScalarAddComputation(xla::F32, &builder), {0, 1, 2});
// Upload the random literal as the argument for parameter 0.
TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(literal));
ComputeAndCompareR0<float>(&builder, 0.0, {input_data.get()},
xla::ErrorSpec(1e-4, 1e-4));
}
// Instantiates RandomCholeskyTest once per tuple below. In each tuple the
// first value is the leading dimension of the test input and the second is
// used for both trailing dimensions (see the Random test body).
INSTANTIATE_TEST_CASE_P(RandomCholeskyTestInstance, RandomCholeskyTest,
::testing::Values(CholeskyTestCase{1, 1},
CholeskyTestCase{1, 2},
CholeskyTestCase{10, 5},
CholeskyTestCase{2, 20}));
} // namespace

View File

@ -13,44 +13,43 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/status_macros.h"
namespace tensorflow { namespace xla {
xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop( StatusOr<std::vector<XlaOp>> WhileLoopHelper(
const LoopConditionFunction& condition_function, const WhileLoopHelperConditionFunction& condition_function,
const LoopBodyFunction& body_function, const WhileLoopHelperBodyFunction& body_function,
absl::Span<const xla::XlaOp> initial_values, absl::string_view name, absl::Span<const XlaOp> initial_values, absl::string_view name,
xla::XlaBuilder* builder) { XlaBuilder* builder) {
int arity = initial_values.size(); int arity = initial_values.size();
std::vector<xla::Shape> var_shapes; std::vector<Shape> var_shapes;
var_shapes.reserve(arity); var_shapes.reserve(arity);
for (const xla::XlaOp& input : initial_values) { for (const XlaOp& input : initial_values) {
TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(input)); TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(input));
var_shapes.push_back(std::move(shape)); var_shapes.push_back(std::move(shape));
} }
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(var_shapes); Shape tuple_shape = ShapeUtil::MakeTupleShape(var_shapes);
// Unpacks a tuple into its component parts. // Unpacks a tuple into its component parts.
auto unpack_tuple = [](xla::XlaOp tuple, int arity, auto unpack_tuple = [](XlaOp tuple, int arity, XlaBuilder* builder) {
xla::XlaBuilder* builder) { std::vector<XlaOp> elements(arity);
std::vector<xla::XlaOp> elements(arity);
for (int i = 0; i < arity; ++i) { for (int i = 0; i < arity; ++i) {
elements[i] = xla::GetTupleElement(tuple, i); elements[i] = GetTupleElement(tuple, i);
} }
return elements; return elements;
}; };
// Build the condition. // Build the condition.
std::unique_ptr<xla::XlaBuilder> cond_builder = std::unique_ptr<XlaBuilder> cond_builder =
builder->CreateSubBuilder(absl::StrCat(name, "_condition")); builder->CreateSubBuilder(absl::StrCat(name, "_condition"));
{ {
auto parameter = auto parameter = Parameter(cond_builder.get(), 0, tuple_shape, "parameter");
xla::Parameter(cond_builder.get(), 0, tuple_shape, "parameter");
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
condition_function(unpack_tuple(parameter, arity, cond_builder.get()), condition_function(unpack_tuple(parameter, arity, cond_builder.get()),
@ -60,11 +59,10 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build()); TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build());
// Build the body. // Build the body.
std::unique_ptr<xla::XlaBuilder> body_builder = std::unique_ptr<XlaBuilder> body_builder =
builder->CreateSubBuilder(absl::StrCat(name, "_body")); builder->CreateSubBuilder(absl::StrCat(name, "_body"));
{ {
auto parameter = auto parameter = Parameter(body_builder.get(), 0, tuple_shape, "parameter");
xla::Parameter(body_builder.get(), 0, tuple_shape, "parameter");
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
auto result, auto result,
@ -72,56 +70,54 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
body_builder.get())); body_builder.get()));
TF_RET_CHECK(result.size() == initial_values.size()); TF_RET_CHECK(result.size() == initial_values.size());
xla::Tuple(body_builder.get(), result); Tuple(body_builder.get(), result);
} }
TF_ASSIGN_OR_RETURN(auto body, body_builder->Build()); TF_ASSIGN_OR_RETURN(auto body, body_builder->Build());
auto outputs = xla::While(cond, body, xla::Tuple(builder, initial_values)); auto outputs = While(cond, body, Tuple(builder, initial_values));
return unpack_tuple(outputs, arity, builder); return unpack_tuple(outputs, arity, builder);
} }
xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex( StatusOr<std::vector<XlaOp>> ForEachIndex(
int64 num_iterations, xla::PrimitiveType num_iterations_type, int64 num_iterations, PrimitiveType num_iterations_type,
const ForEachIndexBodyFunction& body_function, const ForEachIndexBodyFunction& body_function,
absl::Span<const xla::XlaOp> initial_values, absl::string_view name, absl::Span<const XlaOp> initial_values, absl::string_view name,
xla::XlaBuilder* builder) { XlaBuilder* builder) {
auto while_cond_fn = auto while_cond_fn = [&](absl::Span<const XlaOp> values,
[&](absl::Span<const xla::XlaOp> values, XlaBuilder* cond_builder) -> StatusOr<XlaOp> {
xla::XlaBuilder* cond_builder) -> xla::StatusOr<xla::XlaOp> { return Lt(values[0], ConstantR0WithType(cond_builder, num_iterations_type,
return xla::Lt(values[0], IntegerLiteral(cond_builder, num_iterations_type,
num_iterations)); num_iterations));
}; };
auto while_body_fn = [&](absl::Span<const xla::XlaOp> values, auto while_body_fn =
xla::XlaBuilder* body_builder) [&](absl::Span<const XlaOp> values,
-> xla::StatusOr<std::vector<xla::XlaOp>> { XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
xla::XlaOp iteration = values[0]; XlaOp iteration = values[0];
std::vector<xla::XlaOp> updated_values; std::vector<XlaOp> updated_values;
updated_values.reserve(values.size()); updated_values.reserve(values.size());
updated_values.push_back(xla::Add( updated_values.push_back(Add(
iteration, iteration,
xla::ConstantLiteral(body_builder, ConstantLiteral(body_builder, LiteralUtil::One(num_iterations_type))));
xla::LiteralUtil::One(num_iterations_type))));
values.remove_prefix(1); values.remove_prefix(1);
TF_ASSIGN_OR_RETURN(std::vector<xla::XlaOp> body_outputs, TF_ASSIGN_OR_RETURN(std::vector<XlaOp> body_outputs,
body_function(iteration, values, body_builder)); body_function(iteration, values, body_builder));
updated_values.insert(updated_values.end(), body_outputs.begin(), updated_values.insert(updated_values.end(), body_outputs.begin(),
body_outputs.end()); body_outputs.end());
return updated_values; return updated_values;
}; };
std::vector<xla::XlaOp> values; std::vector<XlaOp> values;
values.reserve(initial_values.size() + 1); values.reserve(initial_values.size() + 1);
values.push_back(xla::ConstantLiteral( values.push_back(
builder, xla::LiteralUtil::Zero(num_iterations_type))); ConstantLiteral(builder, LiteralUtil::Zero(num_iterations_type)));
values.insert(values.end(), initial_values.begin(), initial_values.end()); values.insert(values.end(), initial_values.begin(), initial_values.end());
TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values, TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn,
name, builder)); values, name, builder));
values.erase(values.begin(), values.begin() + 1); values.erase(values.begin(), values.begin() + 1);
return values; return values;
} }
} // namespace tensorflow } // namespace xla

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_
#include <functional> #include <functional>
#include <vector> #include <vector>
@ -25,19 +25,18 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
namespace tensorflow { namespace xla {
// Function that builds a loop condition. Takes as input a sequence of input // Function that builds a loop condition. Takes as input a sequence of input
// values, and returns a boolean value representing if the condition succeeds. // values, and returns a boolean value representing if the condition succeeds.
typedef std::function<xla::StatusOr<xla::XlaOp>(absl::Span<const xla::XlaOp>, typedef std::function<StatusOr<XlaOp>(absl::Span<const XlaOp>, XlaBuilder*)>
xla::XlaBuilder*)> WhileLoopHelperConditionFunction;
LoopConditionFunction;
// Function that builds a loop body. Takes as input a sequence of input values // Function that builds a loop body. Takes as input a sequence of input values
// and returns a sequence of output values. // and returns a sequence of output values.
typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>( typedef std::function<StatusOr<std::vector<XlaOp>>(absl::Span<const XlaOp>,
absl::Span<const xla::XlaOp>, xla::XlaBuilder*)> XlaBuilder*)>
LoopBodyFunction; WhileLoopHelperBodyFunction;
// Helper function for building an XLA while loop, where the values carried by // Helper function for building an XLA while loop, where the values carried by
// the loop are a tuple of values, e.g., (a, b, c): // the loop are a tuple of values, e.g., (a, b, c):
@ -47,27 +46,27 @@ typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>(
// init: (a, b, c) // init: (a, b, c)
// ) // )
// 'name' is a descriptive name for the loop. // 'name' is a descriptive name for the loop.
xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop( StatusOr<std::vector<XlaOp>> WhileLoopHelper(
const LoopConditionFunction& condition_function, const WhileLoopHelperConditionFunction& condition_function,
const LoopBodyFunction& body_function, const WhileLoopHelperBodyFunction& body_function,
absl::Span<const xla::XlaOp> initial_values, absl::string_view name, absl::Span<const XlaOp> initial_values, absl::string_view name,
xla::XlaBuilder* builder); XlaBuilder* builder);
// Builds an XLA loop that repeats a computation `num_iterations` times. // Builds an XLA loop that repeats a computation `num_iterations` times.
// //
// The body function (ForEachIndexBodyFunction) takes as input a pair of // The body function (ForEachIndexBodyFunction) takes as input a pair of
// (current iteration number, loop-carried values), and returns an updated // (current iteration number, loop-carried values), and returns an updated
// vector of the loop-carried values. // vector of the loop-carried values.
typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>( typedef std::function<StatusOr<std::vector<XlaOp>>(
xla::XlaOp, absl::Span<const xla::XlaOp>, xla::XlaBuilder*)> XlaOp, absl::Span<const XlaOp>, XlaBuilder*)>
ForEachIndexBodyFunction; ForEachIndexBodyFunction;
xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex( StatusOr<std::vector<XlaOp>> ForEachIndex(
int64 num_iterations, xla::PrimitiveType num_iterations_type, int64 num_iterations, PrimitiveType num_iterations_type,
const ForEachIndexBodyFunction& body_function, const ForEachIndexBodyFunction& body_function,
absl::Span<const xla::XlaOp> initial_values, absl::string_view name, absl::Span<const XlaOp> initial_values, absl::string_view name,
xla::XlaBuilder* builder); XlaBuilder* builder);
} // namespace tensorflow } // namespace xla
#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_