[XLA] [TF:XLA] Move Cholesky decomposition into xla/client/lib/cholesky.*
Move loop helpers used by Cholesky decomposition into xla/client/lib/loops.*. PiperOrigin-RevId: 225037112
This commit is contained in:
parent
c99ecfa992
commit
9b964193d9
@ -1,16 +1,11 @@
|
|||||||
|
load("//tensorflow:tensorflow.bzl", "tf_copts", "tf_kernel_library")
|
||||||
|
|
||||||
licenses(["notice"]) # Apache 2.0
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
package(
|
package(
|
||||||
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
|
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
|
||||||
)
|
)
|
||||||
|
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_copts")
|
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
|
|
||||||
load(
|
|
||||||
"//third_party/mkl:build_defs.bzl",
|
|
||||||
"if_mkl",
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_kernel_library(
|
tf_kernel_library(
|
||||||
name = "xla_ops",
|
name = "xla_ops",
|
||||||
srcs = [
|
srcs = [
|
||||||
@ -122,12 +117,10 @@ tf_kernel_library(
|
|||||||
"//tensorflow/compiler/tf2xla:common",
|
"//tensorflow/compiler/tf2xla:common",
|
||||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||||
"//tensorflow/compiler/tf2xla/lib:broadcast",
|
"//tensorflow/compiler/tf2xla/lib:broadcast",
|
||||||
"//tensorflow/compiler/tf2xla/lib:cholesky",
|
|
||||||
"//tensorflow/compiler/tf2xla/lib:qr",
|
"//tensorflow/compiler/tf2xla/lib:qr",
|
||||||
"//tensorflow/compiler/tf2xla/lib:random",
|
"//tensorflow/compiler/tf2xla/lib:random",
|
||||||
"//tensorflow/compiler/tf2xla/lib:scatter",
|
"//tensorflow/compiler/tf2xla/lib:scatter",
|
||||||
"//tensorflow/compiler/tf2xla/lib:util",
|
"//tensorflow/compiler/tf2xla/lib:util",
|
||||||
"//tensorflow/compiler/tf2xla/lib:while_loop",
|
|
||||||
"//tensorflow/compiler/tf2xla/ops:xla_ops",
|
"//tensorflow/compiler/tf2xla/ops:xla_ops",
|
||||||
"//tensorflow/compiler/xla:array4d",
|
"//tensorflow/compiler/xla:array4d",
|
||||||
"//tensorflow/compiler/xla:literal",
|
"//tensorflow/compiler/xla:literal",
|
||||||
@ -140,7 +133,9 @@ tf_kernel_library(
|
|||||||
"//tensorflow/compiler/xla/client:xla_builder",
|
"//tensorflow/compiler/xla/client:xla_builder",
|
||||||
"//tensorflow/compiler/xla/client:xla_computation",
|
"//tensorflow/compiler/xla/client:xla_computation",
|
||||||
"//tensorflow/compiler/xla/client/lib:arithmetic",
|
"//tensorflow/compiler/xla/client/lib:arithmetic",
|
||||||
|
"//tensorflow/compiler/xla/client/lib:cholesky",
|
||||||
"//tensorflow/compiler/xla/client/lib:constants",
|
"//tensorflow/compiler/xla/client/lib:constants",
|
||||||
|
"//tensorflow/compiler/xla/client/lib:loops",
|
||||||
"//tensorflow/compiler/xla/client/lib:math",
|
"//tensorflow/compiler/xla/client/lib:math",
|
||||||
"//tensorflow/compiler/xla/client/lib:matrix",
|
"//tensorflow/compiler/xla/client/lib:matrix",
|
||||||
"//tensorflow/compiler/xla/client/lib:pooling",
|
"//tensorflow/compiler/xla/client/lib:pooling",
|
||||||
|
@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/compiler/tf2xla/lib/cholesky.h"
|
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/lib/cholesky.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
@ -24,7 +24,7 @@ class CholeskyOp : public XlaOpKernel {
|
|||||||
public:
|
public:
|
||||||
explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||||
void Compile(XlaOpKernelContext* ctx) override {
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
ctx->SetOutput(0, Cholesky(ctx->Input(0)));
|
ctx->SetOutput(0, xla::Cholesky(ctx->Input(0)));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -14,7 +14,6 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
|
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
|
||||||
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
|
|
||||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
||||||
|
@ -15,12 +15,12 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
|
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
|
||||||
#include "tensorflow/compiler/tf2xla/lib/util.h"
|
#include "tensorflow/compiler/tf2xla/lib/util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
|
|
||||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
||||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/lib/loops.h"
|
||||||
#include "tensorflow/compiler/xla/client/lib/sorting.h"
|
#include "tensorflow/compiler/xla/client/lib/sorting.h"
|
||||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
@ -505,9 +505,9 @@ class NonMaxSuppressionOp : public XlaOpKernel {
|
|||||||
init_values.push_back(included_iou);
|
init_values.push_back(included_iou);
|
||||||
|
|
||||||
auto suppress_loop_result =
|
auto suppress_loop_result =
|
||||||
XlaWhileLoop(WhileCondFn(num_boxes, output_size),
|
xla::WhileLoopHelper(WhileCondFn(num_boxes, output_size),
|
||||||
SuppressBodyFn(num_boxes), init_values, "suppress_loop",
|
SuppressBodyFn(num_boxes), init_values,
|
||||||
builder)
|
"suppress_loop", builder)
|
||||||
.ValueOrDie();
|
.ValueOrDie();
|
||||||
|
|
||||||
xla::XlaOp included_score =
|
xla::XlaOp included_score =
|
||||||
|
@ -20,12 +20,12 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
|
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
|
||||||
#include "tensorflow/compiler/tf2xla/lib/random.h"
|
#include "tensorflow/compiler/tf2xla/lib/random.h"
|
||||||
#include "tensorflow/compiler/tf2xla/lib/util.h"
|
#include "tensorflow/compiler/tf2xla/lib/util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
|
|
||||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/lib/loops.h"
|
||||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
@ -175,7 +175,7 @@ class RandomShuffleOp : public XlaOpKernel {
|
|||||||
};
|
};
|
||||||
// for i in range(n):
|
// for i in range(n):
|
||||||
auto swap_loop_result =
|
auto swap_loop_result =
|
||||||
XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices},
|
xla::ForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices},
|
||||||
"indices_swap_loop", builder)
|
"indices_swap_loop", builder)
|
||||||
.ValueOrDie();
|
.ValueOrDie();
|
||||||
auto swapped_indices = swap_loop_result[1];
|
auto swapped_indices = swap_loop_result[1];
|
||||||
|
@ -15,8 +15,6 @@ filegroup(
|
|||||||
]),
|
]),
|
||||||
)
|
)
|
||||||
|
|
||||||
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "broadcast",
|
name = "broadcast",
|
||||||
srcs = ["broadcast.cc"],
|
srcs = ["broadcast.cc"],
|
||||||
@ -33,27 +31,6 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "cholesky",
|
|
||||||
srcs = ["cholesky.cc"],
|
|
||||||
hdrs = ["cholesky.h"],
|
|
||||||
deps = [
|
|
||||||
":util",
|
|
||||||
":while_loop",
|
|
||||||
"//tensorflow/compiler/xla:literal",
|
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
|
||||||
"//tensorflow/compiler/xla:status_macros",
|
|
||||||
"//tensorflow/compiler/xla:statusor",
|
|
||||||
"//tensorflow/compiler/xla:xla_data_proto",
|
|
||||||
"//tensorflow/compiler/xla/client:xla_builder",
|
|
||||||
"//tensorflow/compiler/xla/client/lib:constants",
|
|
||||||
"//tensorflow/compiler/xla/client/lib:matrix",
|
|
||||||
"//tensorflow/compiler/xla/client/lib:slicing",
|
|
||||||
"//tensorflow/compiler/xla/client/lib:triangular_solve",
|
|
||||||
"//tensorflow/core:lib",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "random",
|
name = "random",
|
||||||
srcs = ["random.cc"],
|
srcs = ["random.cc"],
|
||||||
@ -75,7 +52,6 @@ cc_library(
|
|||||||
hdrs = ["qr.h"],
|
hdrs = ["qr.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":util",
|
":util",
|
||||||
":while_loop",
|
|
||||||
"//tensorflow/compiler/xla:literal_util",
|
"//tensorflow/compiler/xla:literal_util",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:status_macros",
|
"//tensorflow/compiler/xla:status_macros",
|
||||||
@ -84,6 +60,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/client:xla_builder",
|
"//tensorflow/compiler/xla/client:xla_builder",
|
||||||
"//tensorflow/compiler/xla/client/lib:arithmetic",
|
"//tensorflow/compiler/xla/client/lib:arithmetic",
|
||||||
"//tensorflow/compiler/xla/client/lib:constants",
|
"//tensorflow/compiler/xla/client/lib:constants",
|
||||||
|
"//tensorflow/compiler/xla/client/lib:loops",
|
||||||
"//tensorflow/compiler/xla/client/lib:math",
|
"//tensorflow/compiler/xla/client/lib:math",
|
||||||
"//tensorflow/compiler/xla/client/lib:matrix",
|
"//tensorflow/compiler/xla/client/lib:matrix",
|
||||||
"//tensorflow/compiler/xla/client/lib:slicing",
|
"//tensorflow/compiler/xla/client/lib:slicing",
|
||||||
@ -97,7 +74,6 @@ cc_library(
|
|||||||
hdrs = ["scatter.h"],
|
hdrs = ["scatter.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":util",
|
":util",
|
||||||
":while_loop",
|
|
||||||
"//tensorflow/compiler/xla:literal",
|
"//tensorflow/compiler/xla:literal",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:status_macros",
|
"//tensorflow/compiler/xla:status_macros",
|
||||||
@ -128,19 +104,3 @@ cc_library(
|
|||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "while_loop",
|
|
||||||
srcs = ["while_loop.cc"],
|
|
||||||
hdrs = ["while_loop.h"],
|
|
||||||
deps = [
|
|
||||||
":util",
|
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
|
||||||
"//tensorflow/compiler/xla:status_macros",
|
|
||||||
"//tensorflow/compiler/xla:statusor",
|
|
||||||
"//tensorflow/compiler/xla/client:xla_builder",
|
|
||||||
"//tensorflow/compiler/xla/client:xla_computation",
|
|
||||||
"@com_google_absl//absl/strings",
|
|
||||||
"@com_google_absl//absl/types:span",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
@ -19,9 +19,9 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/compiler/tf2xla/lib/util.h"
|
#include "tensorflow/compiler/tf2xla/lib/util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
|
|
||||||
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
||||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/lib/loops.h"
|
||||||
#include "tensorflow/compiler/xla/client/lib/math.h"
|
#include "tensorflow/compiler/xla/client/lib/math.h"
|
||||||
#include "tensorflow/compiler/xla/client/lib/matrix.h"
|
#include "tensorflow/compiler/xla/client/lib/matrix.h"
|
||||||
#include "tensorflow/compiler/xla/client/lib/slicing.h"
|
#include "tensorflow/compiler/xla/client/lib/slicing.h"
|
||||||
@ -225,7 +225,7 @@ xla::StatusOr<QRBlockResult> QRBlock(
|
|||||||
builder, xla::ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n})));
|
builder, xla::ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n})));
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(auto values,
|
TF_ASSIGN_OR_RETURN(auto values,
|
||||||
XlaForEachIndex(std::min(m, n), xla::S32, qr_body_fn,
|
xla::ForEachIndex(std::min(m, n), xla::S32, qr_body_fn,
|
||||||
{a, vs, taus}, "qr", builder));
|
{a, vs, taus}, "qr", builder));
|
||||||
|
|
||||||
QRBlockResult result;
|
QRBlockResult result;
|
||||||
@ -301,7 +301,7 @@ xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
|
|||||||
w = UpdateSliceInMinorDims(w, bv, {0});
|
w = UpdateSliceInMinorDims(w, bv, {0});
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
auto values, XlaForEachIndex(n - 1, xla::S32, body_fn, {w, y, vs, taus},
|
auto values, xla::ForEachIndex(n - 1, xla::S32, body_fn, {w, y, vs, taus},
|
||||||
"wy", builder));
|
"wy", builder));
|
||||||
return values[0];
|
return values[0];
|
||||||
}
|
}
|
||||||
|
@ -20,7 +20,6 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/compiler/tf2xla/lib/util.h"
|
#include "tensorflow/compiler/tf2xla/lib/util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
|
|
||||||
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
||||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||||
#include "tensorflow/compiler/xla/literal.h"
|
#include "tensorflow/compiler/xla/literal.h"
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
# Common computation builders for XLA.
|
# Common computation builders for XLA.
|
||||||
|
|
||||||
|
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites", "xla_test")
|
||||||
|
|
||||||
licenses(["notice"]) # Apache 2.0
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
package(default_visibility = ["//tensorflow/compiler/xla/client:friends"])
|
package(default_visibility = ["//tensorflow/compiler/xla/client:friends"])
|
||||||
@ -13,9 +15,6 @@ filegroup(
|
|||||||
]),
|
]),
|
||||||
)
|
)
|
||||||
|
|
||||||
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
|
|
||||||
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites")
|
|
||||||
|
|
||||||
# Generate test_suites for all backends, named "${backend}_tests".
|
# Generate test_suites for all backends, named "${backend}_tests".
|
||||||
generate_backend_suites()
|
generate_backend_suites()
|
||||||
|
|
||||||
@ -35,6 +34,48 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "cholesky",
|
||||||
|
srcs = ["cholesky.cc"],
|
||||||
|
hdrs = ["cholesky.h"],
|
||||||
|
deps = [
|
||||||
|
":math",
|
||||||
|
"//tensorflow/compiler/xla:literal",
|
||||||
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
|
"//tensorflow/compiler/xla:status_macros",
|
||||||
|
"//tensorflow/compiler/xla:statusor",
|
||||||
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
|
"//tensorflow/compiler/xla/client:xla_builder",
|
||||||
|
"//tensorflow/compiler/xla/client/lib:constants",
|
||||||
|
"//tensorflow/compiler/xla/client/lib:loops",
|
||||||
|
"//tensorflow/compiler/xla/client/lib:matrix",
|
||||||
|
"//tensorflow/compiler/xla/client/lib:slicing",
|
||||||
|
"//tensorflow/compiler/xla/client/lib:triangular_solve",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
xla_test(
|
||||||
|
name = "cholesky_test",
|
||||||
|
srcs = ["cholesky_test.cc"],
|
||||||
|
tags = ["optonly"],
|
||||||
|
deps = [
|
||||||
|
":arithmetic",
|
||||||
|
":cholesky",
|
||||||
|
":matrix",
|
||||||
|
"//tensorflow/compiler/xla:array2d",
|
||||||
|
"//tensorflow/compiler/xla:literal",
|
||||||
|
"//tensorflow/compiler/xla:statusor",
|
||||||
|
"//tensorflow/compiler/xla:test",
|
||||||
|
"//tensorflow/compiler/xla:types",
|
||||||
|
"//tensorflow/compiler/xla/client:xla_builder",
|
||||||
|
"//tensorflow/compiler/xla/tests:client_library_test_base",
|
||||||
|
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||||
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "constants",
|
name = "constants",
|
||||||
srcs = ["constants.cc"],
|
srcs = ["constants.cc"],
|
||||||
@ -75,6 +116,22 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "loops",
|
||||||
|
srcs = ["loops.cc"],
|
||||||
|
hdrs = ["loops.h"],
|
||||||
|
deps = [
|
||||||
|
":constants",
|
||||||
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
|
"//tensorflow/compiler/xla:status_macros",
|
||||||
|
"//tensorflow/compiler/xla:statusor",
|
||||||
|
"//tensorflow/compiler/xla/client:xla_builder",
|
||||||
|
"//tensorflow/compiler/xla/client:xla_computation",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_absl//absl/types:span",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "math",
|
name = "math",
|
||||||
srcs = ["math.cc"],
|
srcs = ["math.cc"],
|
||||||
|
@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/compiler/tf2xla/lib/cholesky.h"
|
#include "tensorflow/compiler/xla/client/lib/cholesky.h"
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/compiler/tf2xla/lib/util.h"
|
|
||||||
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
|
|
||||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/lib/loops.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/lib/math.h"
|
||||||
#include "tensorflow/compiler/xla/client/lib/matrix.h"
|
#include "tensorflow/compiler/xla/client/lib/matrix.h"
|
||||||
#include "tensorflow/compiler/xla/client/lib/slicing.h"
|
#include "tensorflow/compiler/xla/client/lib/slicing.h"
|
||||||
#include "tensorflow/compiler/xla/client/lib/triangular_solve.h"
|
#include "tensorflow/compiler/xla/client/lib/triangular_solve.h"
|
||||||
@ -31,7 +31,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace xla {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
@ -50,26 +50,25 @@ namespace {
|
|||||||
// l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) /
|
// l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) /
|
||||||
// l[..., j, j]
|
// l[..., j, j]
|
||||||
// return l
|
// return l
|
||||||
xla::XlaOp CholeskyUnblocked(xla::XlaOp a,
|
XlaOp CholeskyUnblocked(XlaOp a, PrecisionConfig::Precision precision) {
|
||||||
xla::PrecisionConfig::Precision precision) {
|
XlaBuilder* builder = a.builder();
|
||||||
xla::XlaBuilder* builder = a.builder();
|
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
|
TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
|
||||||
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
|
const int n_dims = ShapeUtil::Rank(a_shape);
|
||||||
const int n_dims = xla::ShapeUtil::Rank(a_shape);
|
const int64 n = ShapeUtil::GetDimension(a_shape, -1);
|
||||||
const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
|
auto major_dims = AsInt64Slice(a_shape.dimensions())
|
||||||
auto major_dims = xla::AsInt64Slice(a_shape.dimensions())
|
|
||||||
.subspan(
|
.subspan(
|
||||||
/*pos=*/0,
|
/*pos=*/0,
|
||||||
/*len=*/n_dims - 2);
|
/*len=*/n_dims - 2);
|
||||||
|
|
||||||
xla::XlaOp l = xla::ZerosLike(a);
|
XlaOp l = ZerosLike(a);
|
||||||
|
|
||||||
// Construct the for loop body to iterate over rows.
|
// Construct the for loop body to iterate over rows.
|
||||||
auto body_fn = [&](xla::XlaOp i, absl::Span<const xla::XlaOp> loop_vars,
|
auto body_fn =
|
||||||
xla::XlaBuilder* body_builder)
|
[&](XlaOp i, absl::Span<const XlaOp> loop_vars,
|
||||||
-> xla::StatusOr<std::vector<xla::XlaOp>> {
|
XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
|
||||||
xla::Shape col_shape;
|
Shape col_shape;
|
||||||
xla::Shape row_shape;
|
Shape row_shape;
|
||||||
for (int64 d : major_dims) {
|
for (int64 d : major_dims) {
|
||||||
row_shape.add_dimensions(d);
|
row_shape.add_dimensions(d);
|
||||||
col_shape.add_dimensions(d);
|
col_shape.add_dimensions(d);
|
||||||
@ -77,43 +76,40 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a,
|
|||||||
row_shape.add_dimensions(1);
|
row_shape.add_dimensions(1);
|
||||||
row_shape.add_dimensions(n);
|
row_shape.add_dimensions(n);
|
||||||
row_shape.set_element_type(a_shape.element_type());
|
row_shape.set_element_type(a_shape.element_type());
|
||||||
auto mask_zeros_row = xla::Zeros(body_builder, row_shape);
|
auto mask_zeros_row = Zeros(body_builder, row_shape);
|
||||||
|
|
||||||
col_shape.add_dimensions(n);
|
col_shape.add_dimensions(n);
|
||||||
col_shape.add_dimensions(1);
|
col_shape.add_dimensions(1);
|
||||||
col_shape.set_element_type(a_shape.element_type());
|
col_shape.set_element_type(a_shape.element_type());
|
||||||
auto mask_zeros_col = xla::Zeros(body_builder, col_shape);
|
auto mask_zeros_col = Zeros(body_builder, col_shape);
|
||||||
|
|
||||||
std::vector<int32> mask_vector(n);
|
std::vector<int32> mask_vector(n);
|
||||||
std::iota(mask_vector.begin(), mask_vector.end(), 0);
|
std::iota(mask_vector.begin(), mask_vector.end(), 0);
|
||||||
auto mask_range = xla::ConstantR1<int32>(body_builder, mask_vector);
|
auto mask_range = ConstantR1<int32>(body_builder, mask_vector);
|
||||||
auto mask_range_row =
|
auto mask_range_row =
|
||||||
xla::Broadcast(xla::Reshape(mask_range, {0}, {1, n}), major_dims);
|
Broadcast(Reshape(mask_range, {0}, {1, n}), major_dims);
|
||||||
auto mask_range_col =
|
auto mask_range_col =
|
||||||
xla::Broadcast(xla::Reshape(mask_range, {0}, {n, 1}), major_dims);
|
Broadcast(Reshape(mask_range, {0}, {n, 1}), major_dims);
|
||||||
auto body_a = loop_vars[0];
|
auto body_a = loop_vars[0];
|
||||||
auto body_l = loop_vars[1];
|
auto body_l = loop_vars[1];
|
||||||
|
|
||||||
// row = l[..., i, :i]
|
// row = l[..., i, :i]
|
||||||
// select the whole i-th row, then mask out all columns past i-1
|
// select the whole i-th row, then mask out all columns past i-1
|
||||||
auto zero = xla::ConstantR0<int32>(body_builder, 0);
|
auto zero = ConstantR0<int32>(body_builder, 0);
|
||||||
auto l_i = DynamicSliceInMinorDims(body_l, {i, zero}, {1, n});
|
auto l_i = DynamicSliceInMinorDims(body_l, {i, zero}, {1, n});
|
||||||
auto row = xla::Select(xla::Ge(mask_range_row, i), mask_zeros_row, l_i);
|
auto row = Select(Ge(mask_range_row, i), mask_zeros_row, l_i);
|
||||||
// a[..., i, i]
|
// a[..., i, i]
|
||||||
auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1});
|
auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1});
|
||||||
// np.dot(row, np.swapaxes(row, -1, -2))
|
// np.dot(row, np.swapaxes(row, -1, -2))
|
||||||
auto diag_dot = BatchDot(row, TransposeInMinorDims(row), precision);
|
auto diag_dot = BatchDot(row, TransposeInMinorDims(row), precision);
|
||||||
// l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row,
|
// l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row,
|
||||||
// np.swapaxes(row, -1, -2)))
|
// np.swapaxes(row, -1, -2)))
|
||||||
auto l_ii =
|
auto l_ii = Sqrt(a_ii - diag_dot);
|
||||||
xla::Pow(a_ii - diag_dot,
|
|
||||||
FloatLiteral(body_builder, a_shape.element_type(), 0.5));
|
|
||||||
|
|
||||||
// a[..., i+1:, i]
|
// a[..., i+1:, i]
|
||||||
// select the whole i-th column, then mask out all rows above i+1
|
// select the whole i-th column, then mask out all rows above i+1
|
||||||
auto a_0i = DynamicSliceInMinorDims(body_a, {i}, {1});
|
auto a_0i = DynamicSliceInMinorDims(body_a, {i}, {1});
|
||||||
auto a_ip1i =
|
auto a_ip1i = Select(Le(mask_range_col, i), mask_zeros_col, a_0i);
|
||||||
xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, a_0i);
|
|
||||||
|
|
||||||
// l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) /
|
// l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) /
|
||||||
// l[..., i, i]
|
// l[..., i, i]
|
||||||
@ -122,8 +118,7 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a,
|
|||||||
// r.T)
|
// r.T)
|
||||||
auto dot = BatchDot(body_l, TransposeInMinorDims(row), precision);
|
auto dot = BatchDot(body_l, TransposeInMinorDims(row), precision);
|
||||||
// np.dot(l[..., i+1:, :i], r.T)
|
// np.dot(l[..., i+1:, :i], r.T)
|
||||||
auto dot_ip1 =
|
auto dot_ip1 = Select(Le(mask_range_col, i), mask_zeros_col, dot);
|
||||||
xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot);
|
|
||||||
|
|
||||||
body_l =
|
body_l =
|
||||||
DynamicUpdateSliceInMinorDims(body_l, (a_ip1i - dot_ip1) / l_ii, {i});
|
DynamicUpdateSliceInMinorDims(body_l, (a_ip1i - dot_ip1) / l_ii, {i});
|
||||||
@ -131,12 +126,12 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a,
|
|||||||
// column assign will wrap around and overwrite the diagonal assign.
|
// column assign will wrap around and overwrite the diagonal assign.
|
||||||
body_l = DynamicUpdateSliceInMinorDims(body_l, l_ii, {i, i});
|
body_l = DynamicUpdateSliceInMinorDims(body_l, l_ii, {i, i});
|
||||||
|
|
||||||
return std::vector<xla::XlaOp>{body_a, body_l};
|
return std::vector<XlaOp>{body_a, body_l};
|
||||||
};
|
};
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
auto cholesky_while,
|
auto cholesky_while,
|
||||||
XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder));
|
ForEachIndex(n, S32, body_fn, {a, l}, "unblocked", builder));
|
||||||
|
|
||||||
return cholesky_while[1];
|
return cholesky_while[1];
|
||||||
});
|
});
|
||||||
@ -144,34 +139,35 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a,
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size,
|
XlaOp Cholesky(XlaOp a, int64 block_size,
|
||||||
xla::PrecisionConfig::Precision precision) {
|
PrecisionConfig::Precision precision) {
|
||||||
xla::XlaBuilder* builder = a.builder();
|
XlaBuilder* builder = a.builder();
|
||||||
return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
|
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
|
TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
|
||||||
const int ndims = xla::ShapeUtil::Rank(a_shape);
|
const int ndims = ShapeUtil::Rank(a_shape);
|
||||||
if (ndims < 2) {
|
if (ndims < 2) {
|
||||||
return errors::InvalidArgument(
|
return InvalidArgument(
|
||||||
"Arguments to Cholesky must have rank >= 2: ", ndims);
|
"Argument to Cholesky must have rank >= 2; shape was %s",
|
||||||
|
a_shape.ToString());
|
||||||
}
|
}
|
||||||
|
|
||||||
const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
|
const int64 n = ShapeUtil::GetDimension(a_shape, -1);
|
||||||
if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) {
|
if (n != ShapeUtil::GetDimension(a_shape, -2)) {
|
||||||
return errors::InvalidArgument(
|
return InvalidArgument(
|
||||||
"Arguments to Cholesky must be square matrices: ",
|
"Argument to Cholesky must be batched square matrices; got shape %s",
|
||||||
xla::ShapeUtil::HumanString(a_shape));
|
ShapeUtil::HumanString(a_shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (block_size < 1) {
|
if (block_size < 1) {
|
||||||
return errors::InvalidArgument(
|
return InvalidArgument(
|
||||||
"block_size argument to Cholesky must be >= 1; got ", block_size);
|
"block_size argument to Cholesky must be >= 1; got %d", block_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Blocked left-looking Cholesky factorization.
|
// Blocked left-looking Cholesky factorization.
|
||||||
// Algorithm 1 from
|
// Algorithm 1 from
|
||||||
// Haidar, Azzam, et al. "High-performance Cholesky factorization for
|
// Haidar, Azzam, et al. "High-performance Cholesky factorization for
|
||||||
// GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017.
|
// GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017.
|
||||||
xla::XlaOp l = xla::ZerosLike(a);
|
XlaOp l = ZerosLike(a);
|
||||||
for (int64 i = 0; i < n; i += block_size) {
|
for (int64 i = 0; i < n; i += block_size) {
|
||||||
int64 k = std::min(block_size, n - i);
|
int64 k = std::min(block_size, n - i);
|
||||||
if (i > 0) {
|
if (i > 0) {
|
||||||
@ -207,4 +203,4 @@ xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size,
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace xla
|
@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_
|
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_
|
||||||
#define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_
|
#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace xla {
|
||||||
|
|
||||||
// Computes the Cholesky decompositions of a batch of symmetric positive
|
// Computes the Cholesky decompositions of a batch of symmetric positive
|
||||||
// definite matrices.
|
// definite matrices.
|
||||||
@ -34,6 +34,6 @@ xla::XlaOp Cholesky(
|
|||||||
xla::XlaOp a, int64 block_size = 256,
|
xla::XlaOp a, int64 block_size = 256,
|
||||||
xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST);
|
xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace xla
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_
|
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_
|
166
tensorflow/compiler/xla/client/lib/cholesky_test.cc
Normal file
166
tensorflow/compiler/xla/client/lib/cholesky_test.cc
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/client/lib/cholesky.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <numeric>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/array2d.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/lib/matrix.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||||
|
#include "tensorflow/compiler/xla/literal.h"
|
||||||
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
|
#include "tensorflow/compiler/xla/test.h"
|
||||||
|
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
||||||
|
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||||
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using xla::int64;
|
||||||
|
|
||||||
|
using CholeskyTest = xla::ClientLibraryTestBase;
|
||||||
|
|
||||||
|
XLA_TEST_F(CholeskyTest, Simple) {
|
||||||
|
xla::XlaBuilder builder(TestName());
|
||||||
|
|
||||||
|
xla::Array2D<float> a_vals({
|
||||||
|
{4, 6, 8, 10},
|
||||||
|
{6, 45, 54, 63},
|
||||||
|
{8, 54, 146, 166},
|
||||||
|
{10, 63, 166, 310},
|
||||||
|
});
|
||||||
|
|
||||||
|
xla::XlaOp a;
|
||||||
|
auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
|
||||||
|
xla::Cholesky(a, /*block_size=*/2);
|
||||||
|
|
||||||
|
xla::Array2D<float> expected({
|
||||||
|
{2, 0, 0, 0},
|
||||||
|
{3, 6, 0, 0},
|
||||||
|
{4, 7, 9, 0},
|
||||||
|
{5, 8, 10, 11},
|
||||||
|
});
|
||||||
|
|
||||||
|
ComputeAndCompareR2<float>(&builder, expected, {a_data.get()},
|
||||||
|
xla::ErrorSpec(1e-4, 1e-4));
|
||||||
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(CholeskyTest, Simple2) {
|
||||||
|
xla::XlaBuilder builder(TestName());
|
||||||
|
|
||||||
|
xla::Array2D<float> a_vals({
|
||||||
|
{16, 24, 8, 12},
|
||||||
|
{24, 61, 82, 48},
|
||||||
|
{8, 82, 456, 106},
|
||||||
|
{12, 48, 106, 62},
|
||||||
|
});
|
||||||
|
|
||||||
|
xla::XlaOp a;
|
||||||
|
auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
|
||||||
|
xla::Cholesky(a);
|
||||||
|
|
||||||
|
xla::Array2D<float> expected(
|
||||||
|
{{4, 0, 0, 0}, {6, 5, 0, 0}, {2, 14, 16, 0}, {3, 6, 1, 4}});
|
||||||
|
|
||||||
|
ComputeAndCompareR2<float>(&builder, expected, {a_data.get()},
|
||||||
|
xla::ErrorSpec(1e-4, 1e-4));
|
||||||
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(CholeskyTest, SimpleBatched) {
|
||||||
|
xla::XlaBuilder builder(TestName());
|
||||||
|
|
||||||
|
xla::Array3D<float> a_vals({
|
||||||
|
{
|
||||||
|
{4, 6, 8, 10},
|
||||||
|
{6, 45, 54, 63},
|
||||||
|
{8, 54, 146, 166},
|
||||||
|
{10, 63, 166, 310},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
{16, 24, 8, 12},
|
||||||
|
{24, 61, 82, 48},
|
||||||
|
{8, 82, 456, 106},
|
||||||
|
{12, 48, 106, 62},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
xla::XlaOp a;
|
||||||
|
auto a_data = CreateR3Parameter<float>(a_vals, 0, "a", &builder, &a);
|
||||||
|
xla::Cholesky(a);
|
||||||
|
|
||||||
|
xla::Array3D<float> expected({
|
||||||
|
{
|
||||||
|
{2, 0, 0, 0},
|
||||||
|
{3, 6, 0, 0},
|
||||||
|
{4, 7, 9, 0},
|
||||||
|
{5, 8, 10, 11},
|
||||||
|
},
|
||||||
|
{{4, 0, 0, 0}, {6, 5, 0, 0}, {2, 14, 16, 0}, {3, 6, 1, 4}},
|
||||||
|
});
|
||||||
|
|
||||||
|
ComputeAndCompareR3<float>(&builder, expected, {a_data.get()},
|
||||||
|
xla::ErrorSpec(1e-4, 1e-4));
|
||||||
|
}
|
||||||
|
|
||||||
|
using CholeskyTestCase = std::tuple<int64, int64>;
|
||||||
|
|
||||||
|
class RandomCholeskyTest
|
||||||
|
: public xla::ClientLibraryTestBase,
|
||||||
|
public ::testing::WithParamInterface<CholeskyTestCase> {};
|
||||||
|
|
||||||
|
XLA_TEST_P(RandomCholeskyTest, Random) {
|
||||||
|
xla::XlaBuilder builder(TestName());
|
||||||
|
|
||||||
|
auto test_params = GetParam();
|
||||||
|
std::vector<int64> dimensions = {std::get<0>(test_params),
|
||||||
|
std::get<1>(test_params),
|
||||||
|
std::get<1>(test_params)};
|
||||||
|
xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, dimensions);
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
|
auto literal,
|
||||||
|
xla::LiteralUtil::CreateRandomLiteral<xla::F32>(shape, 0.0, 1.0));
|
||||||
|
|
||||||
|
auto input = xla::Parameter(&builder, 0, shape, "input");
|
||||||
|
// Form a random positive definite matrix.
|
||||||
|
auto matrix = xla::BatchDot(input, TransposeInMinorDims(input),
|
||||||
|
xla::PrecisionConfig::HIGHEST);
|
||||||
|
|
||||||
|
auto cholesky = xla::Cholesky(matrix, /*block_size=*/4);
|
||||||
|
|
||||||
|
// Verify that ||matrix - cholesky * cholesky_t||_2 ~= 0
|
||||||
|
auto verification = xla::BatchDot(cholesky, TransposeInMinorDims(cholesky),
|
||||||
|
xla::PrecisionConfig::HIGHEST);
|
||||||
|
auto delta = matrix - verification;
|
||||||
|
xla::Reduce(delta * delta, xla::ConstantR0<float>(&builder, 0.0),
|
||||||
|
CreateScalarAddComputation(xla::F32, &builder), {0, 1, 2});
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(literal));
|
||||||
|
ComputeAndCompareR0<float>(&builder, 0.0, {input_data.get()},
|
||||||
|
xla::ErrorSpec(1e-4, 1e-4));
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(RandomCholeskyTestInstance, RandomCholeskyTest,
|
||||||
|
::testing::Values(CholeskyTestCase{1, 1},
|
||||||
|
CholeskyTestCase{1, 2},
|
||||||
|
CholeskyTestCase{10, 5},
|
||||||
|
CholeskyTestCase{2, 20}));
|
||||||
|
|
||||||
|
} // namespace
|
@ -13,44 +13,43 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
|
#include "tensorflow/compiler/xla/client/lib/loops.h"
|
||||||
#include "tensorflow/compiler/tf2xla/lib/util.h"
|
|
||||||
|
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/xla/status_macros.h"
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace xla {
|
||||||
|
|
||||||
xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
|
StatusOr<std::vector<XlaOp>> WhileLoopHelper(
|
||||||
const LoopConditionFunction& condition_function,
|
const WhileLoopHelperConditionFunction& condition_function,
|
||||||
const LoopBodyFunction& body_function,
|
const WhileLoopHelperBodyFunction& body_function,
|
||||||
absl::Span<const xla::XlaOp> initial_values, absl::string_view name,
|
absl::Span<const XlaOp> initial_values, absl::string_view name,
|
||||||
xla::XlaBuilder* builder) {
|
XlaBuilder* builder) {
|
||||||
int arity = initial_values.size();
|
int arity = initial_values.size();
|
||||||
std::vector<xla::Shape> var_shapes;
|
std::vector<Shape> var_shapes;
|
||||||
var_shapes.reserve(arity);
|
var_shapes.reserve(arity);
|
||||||
for (const xla::XlaOp& input : initial_values) {
|
for (const XlaOp& input : initial_values) {
|
||||||
TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(input));
|
TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(input));
|
||||||
var_shapes.push_back(std::move(shape));
|
var_shapes.push_back(std::move(shape));
|
||||||
}
|
}
|
||||||
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(var_shapes);
|
Shape tuple_shape = ShapeUtil::MakeTupleShape(var_shapes);
|
||||||
|
|
||||||
// Unpacks a tuple into its component parts.
|
// Unpacks a tuple into its component parts.
|
||||||
auto unpack_tuple = [](xla::XlaOp tuple, int arity,
|
auto unpack_tuple = [](XlaOp tuple, int arity, XlaBuilder* builder) {
|
||||||
xla::XlaBuilder* builder) {
|
std::vector<XlaOp> elements(arity);
|
||||||
std::vector<xla::XlaOp> elements(arity);
|
|
||||||
for (int i = 0; i < arity; ++i) {
|
for (int i = 0; i < arity; ++i) {
|
||||||
elements[i] = xla::GetTupleElement(tuple, i);
|
elements[i] = GetTupleElement(tuple, i);
|
||||||
}
|
}
|
||||||
return elements;
|
return elements;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Build the condition.
|
// Build the condition.
|
||||||
std::unique_ptr<xla::XlaBuilder> cond_builder =
|
std::unique_ptr<XlaBuilder> cond_builder =
|
||||||
builder->CreateSubBuilder(absl::StrCat(name, "_condition"));
|
builder->CreateSubBuilder(absl::StrCat(name, "_condition"));
|
||||||
{
|
{
|
||||||
auto parameter =
|
auto parameter = Parameter(cond_builder.get(), 0, tuple_shape, "parameter");
|
||||||
xla::Parameter(cond_builder.get(), 0, tuple_shape, "parameter");
|
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
condition_function(unpack_tuple(parameter, arity, cond_builder.get()),
|
condition_function(unpack_tuple(parameter, arity, cond_builder.get()),
|
||||||
@ -60,11 +59,10 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
|
|||||||
TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build());
|
TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build());
|
||||||
|
|
||||||
// Build the body.
|
// Build the body.
|
||||||
std::unique_ptr<xla::XlaBuilder> body_builder =
|
std::unique_ptr<XlaBuilder> body_builder =
|
||||||
builder->CreateSubBuilder(absl::StrCat(name, "_body"));
|
builder->CreateSubBuilder(absl::StrCat(name, "_body"));
|
||||||
{
|
{
|
||||||
auto parameter =
|
auto parameter = Parameter(body_builder.get(), 0, tuple_shape, "parameter");
|
||||||
xla::Parameter(body_builder.get(), 0, tuple_shape, "parameter");
|
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
auto result,
|
auto result,
|
||||||
@ -72,56 +70,54 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
|
|||||||
body_builder.get()));
|
body_builder.get()));
|
||||||
|
|
||||||
TF_RET_CHECK(result.size() == initial_values.size());
|
TF_RET_CHECK(result.size() == initial_values.size());
|
||||||
xla::Tuple(body_builder.get(), result);
|
Tuple(body_builder.get(), result);
|
||||||
}
|
}
|
||||||
TF_ASSIGN_OR_RETURN(auto body, body_builder->Build());
|
TF_ASSIGN_OR_RETURN(auto body, body_builder->Build());
|
||||||
|
|
||||||
auto outputs = xla::While(cond, body, xla::Tuple(builder, initial_values));
|
auto outputs = While(cond, body, Tuple(builder, initial_values));
|
||||||
|
|
||||||
return unpack_tuple(outputs, arity, builder);
|
return unpack_tuple(outputs, arity, builder);
|
||||||
}
|
}
|
||||||
|
|
||||||
xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
|
StatusOr<std::vector<XlaOp>> ForEachIndex(
|
||||||
int64 num_iterations, xla::PrimitiveType num_iterations_type,
|
int64 num_iterations, PrimitiveType num_iterations_type,
|
||||||
const ForEachIndexBodyFunction& body_function,
|
const ForEachIndexBodyFunction& body_function,
|
||||||
absl::Span<const xla::XlaOp> initial_values, absl::string_view name,
|
absl::Span<const XlaOp> initial_values, absl::string_view name,
|
||||||
xla::XlaBuilder* builder) {
|
XlaBuilder* builder) {
|
||||||
auto while_cond_fn =
|
auto while_cond_fn = [&](absl::Span<const XlaOp> values,
|
||||||
[&](absl::Span<const xla::XlaOp> values,
|
XlaBuilder* cond_builder) -> StatusOr<XlaOp> {
|
||||||
xla::XlaBuilder* cond_builder) -> xla::StatusOr<xla::XlaOp> {
|
return Lt(values[0], ConstantR0WithType(cond_builder, num_iterations_type,
|
||||||
return xla::Lt(values[0], IntegerLiteral(cond_builder, num_iterations_type,
|
|
||||||
num_iterations));
|
num_iterations));
|
||||||
};
|
};
|
||||||
auto while_body_fn = [&](absl::Span<const xla::XlaOp> values,
|
auto while_body_fn =
|
||||||
xla::XlaBuilder* body_builder)
|
[&](absl::Span<const XlaOp> values,
|
||||||
-> xla::StatusOr<std::vector<xla::XlaOp>> {
|
XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
|
||||||
xla::XlaOp iteration = values[0];
|
XlaOp iteration = values[0];
|
||||||
|
|
||||||
std::vector<xla::XlaOp> updated_values;
|
std::vector<XlaOp> updated_values;
|
||||||
updated_values.reserve(values.size());
|
updated_values.reserve(values.size());
|
||||||
updated_values.push_back(xla::Add(
|
updated_values.push_back(Add(
|
||||||
iteration,
|
iteration,
|
||||||
xla::ConstantLiteral(body_builder,
|
ConstantLiteral(body_builder, LiteralUtil::One(num_iterations_type))));
|
||||||
xla::LiteralUtil::One(num_iterations_type))));
|
|
||||||
|
|
||||||
values.remove_prefix(1);
|
values.remove_prefix(1);
|
||||||
TF_ASSIGN_OR_RETURN(std::vector<xla::XlaOp> body_outputs,
|
TF_ASSIGN_OR_RETURN(std::vector<XlaOp> body_outputs,
|
||||||
body_function(iteration, values, body_builder));
|
body_function(iteration, values, body_builder));
|
||||||
updated_values.insert(updated_values.end(), body_outputs.begin(),
|
updated_values.insert(updated_values.end(), body_outputs.begin(),
|
||||||
body_outputs.end());
|
body_outputs.end());
|
||||||
return updated_values;
|
return updated_values;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<xla::XlaOp> values;
|
std::vector<XlaOp> values;
|
||||||
values.reserve(initial_values.size() + 1);
|
values.reserve(initial_values.size() + 1);
|
||||||
values.push_back(xla::ConstantLiteral(
|
values.push_back(
|
||||||
builder, xla::LiteralUtil::Zero(num_iterations_type)));
|
ConstantLiteral(builder, LiteralUtil::Zero(num_iterations_type)));
|
||||||
values.insert(values.end(), initial_values.begin(), initial_values.end());
|
values.insert(values.end(), initial_values.begin(), initial_values.end());
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values,
|
TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn,
|
||||||
name, builder));
|
values, name, builder));
|
||||||
values.erase(values.begin(), values.begin() + 1);
|
values.erase(values.begin(), values.begin() + 1);
|
||||||
return values;
|
return values;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace xla
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_
|
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_
|
||||||
#define TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_
|
#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_
|
||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@ -25,19 +25,18 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace xla {
|
||||||
|
|
||||||
// Function that builds a loop condition. Takes as input a sequence of input
|
// Function that builds a loop condition. Takes as input a sequence of input
|
||||||
// values, and returns a boolean value representing if the condition succeeds.
|
// values, and returns a boolean value representing if the condition succeeds.
|
||||||
typedef std::function<xla::StatusOr<xla::XlaOp>(absl::Span<const xla::XlaOp>,
|
typedef std::function<StatusOr<XlaOp>(absl::Span<const XlaOp>, XlaBuilder*)>
|
||||||
xla::XlaBuilder*)>
|
WhileLoopHelperConditionFunction;
|
||||||
LoopConditionFunction;
|
|
||||||
|
|
||||||
// Function that builds a loop body. Takes as input a sequence of input values
|
// Function that builds a loop body. Takes as input a sequence of input values
|
||||||
// and returns a sequence of output values.
|
// and returns a sequence of output values.
|
||||||
typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>(
|
typedef std::function<StatusOr<std::vector<XlaOp>>(absl::Span<const XlaOp>,
|
||||||
absl::Span<const xla::XlaOp>, xla::XlaBuilder*)>
|
XlaBuilder*)>
|
||||||
LoopBodyFunction;
|
WhileLoopHelperBodyFunction;
|
||||||
|
|
||||||
// Helper function for building an XLA while loop, where the values carried by
|
// Helper function for building an XLA while loop, where the values carried by
|
||||||
// the loop are a tuple of values, e.g., (a, b, c):
|
// the loop are a tuple of values, e.g., (a, b, c):
|
||||||
@ -47,27 +46,27 @@ typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>(
|
|||||||
// init: (a, b, c)
|
// init: (a, b, c)
|
||||||
// )
|
// )
|
||||||
// 'name' is a descriptive name for the loop.
|
// 'name' is a descriptive name for the loop.
|
||||||
xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
|
StatusOr<std::vector<XlaOp>> WhileLoopHelper(
|
||||||
const LoopConditionFunction& condition_function,
|
const WhileLoopHelperConditionFunction& condition_function,
|
||||||
const LoopBodyFunction& body_function,
|
const WhileLoopHelperBodyFunction& body_function,
|
||||||
absl::Span<const xla::XlaOp> initial_values, absl::string_view name,
|
absl::Span<const XlaOp> initial_values, absl::string_view name,
|
||||||
xla::XlaBuilder* builder);
|
XlaBuilder* builder);
|
||||||
|
|
||||||
// Builds an XLA loop that repeats a computation `num_iterations` times.
|
// Builds an XLA loop that repeats a computation `num_iterations` times.
|
||||||
//
|
//
|
||||||
// The body function (ForEachIndexBodyFunction) takes as input a pair of
|
// The body function (ForEachIndexBodyFunction) takes as input a pair of
|
||||||
// (current iteration number, loop-carried values), and returns an updated
|
// (current iteration number, loop-carried values), and returns an updated
|
||||||
// vector of the loop-carried values.
|
// vector of the loop-carried values.
|
||||||
typedef std::function<xla::StatusOr<std::vector<xla::XlaOp>>(
|
typedef std::function<StatusOr<std::vector<XlaOp>>(
|
||||||
xla::XlaOp, absl::Span<const xla::XlaOp>, xla::XlaBuilder*)>
|
XlaOp, absl::Span<const XlaOp>, XlaBuilder*)>
|
||||||
ForEachIndexBodyFunction;
|
ForEachIndexBodyFunction;
|
||||||
|
|
||||||
xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
|
StatusOr<std::vector<XlaOp>> ForEachIndex(
|
||||||
int64 num_iterations, xla::PrimitiveType num_iterations_type,
|
int64 num_iterations, PrimitiveType num_iterations_type,
|
||||||
const ForEachIndexBodyFunction& body_function,
|
const ForEachIndexBodyFunction& body_function,
|
||||||
absl::Span<const xla::XlaOp> initial_values, absl::string_view name,
|
absl::Span<const XlaOp> initial_values, absl::string_view name,
|
||||||
xla::XlaBuilder* builder);
|
XlaBuilder* builder);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace xla
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_
|
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_
|
Loading…
Reference in New Issue
Block a user