diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 8bc32922964..901b97736b1 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -1,16 +1,11 @@ +load("//tensorflow:tensorflow.bzl", "tf_copts", "tf_kernel_library") + licenses(["notice"]) # Apache 2.0 package( default_visibility = ["//tensorflow/compiler/tf2xla:internal"], ) -load("//tensorflow:tensorflow.bzl", "tf_copts") -load("//tensorflow:tensorflow.bzl", "tf_kernel_library") -load( - "//third_party/mkl:build_defs.bzl", - "if_mkl", -) - tf_kernel_library( name = "xla_ops", srcs = [ @@ -122,12 +117,10 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:broadcast", - "//tensorflow/compiler/tf2xla/lib:cholesky", "//tensorflow/compiler/tf2xla/lib:qr", "//tensorflow/compiler/tf2xla/lib:random", "//tensorflow/compiler/tf2xla/lib:scatter", "//tensorflow/compiler/tf2xla/lib:util", - "//tensorflow/compiler/tf2xla/lib:while_loop", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal", @@ -140,7 +133,9 @@ tf_kernel_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:cholesky", "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:loops", "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/lib:matrix", "//tensorflow/compiler/xla/client/lib:pooling", diff --git a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc index 9fcbc86adc0..0ed3044efa5 100644 --- a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/cholesky.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/cholesky.h" namespace tensorflow { namespace { @@ -24,7 +24,7 @@ class CholeskyOp : public XlaOpKernel { public: explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - ctx->SetOutput(0, Cholesky(ctx->Input(0))); + ctx->SetOutput(0, xla::Cholesky(ctx->Input(0))); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 20b0de193dc..41c31d0ed58 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index e9bb0a77e99..96ddd42e2ae 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -15,12 +15,12 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/lib/sorting.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -505,9 +505,9 @@ class NonMaxSuppressionOp : public XlaOpKernel { init_values.push_back(included_iou); auto suppress_loop_result = - XlaWhileLoop(WhileCondFn(num_boxes, output_size), - SuppressBodyFn(num_boxes), init_values, "suppress_loop", - builder) + xla::WhileLoopHelper(WhileCondFn(num_boxes, output_size), + SuppressBodyFn(num_boxes), init_values, + "suppress_loop", builder) .ValueOrDie(); xla::XlaOp included_score = diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 8822e29f7e7..2d92056e4f5 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -20,12 +20,12 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/lib/random.h" #include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -175,8 +175,8 @@ class RandomShuffleOp : public XlaOpKernel { }; // for i in range(n): auto swap_loop_result = - XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices}, - "indices_swap_loop", builder) + xla::ForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices}, + "indices_swap_loop", builder) .ValueOrDie(); auto swapped_indices = swap_loop_result[1]; diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 3e7a7611203..9ec9e9bdc08 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -15,8 +15,6 @@ filegroup( ]), ) -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") - cc_library( name = "broadcast", srcs = ["broadcast.cc"], @@ -33,27 +31,6 @@ cc_library( ], ) -cc_library( - name = "cholesky", - srcs = ["cholesky.cc"], - hdrs = ["cholesky.h"], - deps = [ - ":util", - ":while_loop", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:matrix", - "//tensorflow/compiler/xla/client/lib:slicing", - "//tensorflow/compiler/xla/client/lib:triangular_solve", - "//tensorflow/core:lib", - ], -) - cc_library( name = "random", srcs = ["random.cc"], @@ -75,7 +52,6 @@ cc_library( hdrs = ["qr.h"], deps = [ ":util", - ":while_loop", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -84,6 +60,7 @@ cc_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:loops", "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/lib:matrix", "//tensorflow/compiler/xla/client/lib:slicing", @@ -97,7 +74,6 @@ cc_library( hdrs = ["scatter.h"], deps = [ ":util", - ":while_loop", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -128,19 +104,3 @@ cc_library( "@com_google_absl//absl/types:span", ], ) - -cc_library( - name = "while_loop", - srcs = ["while_loop.cc"], - hdrs = ["while_loop.h"], - deps = [ - ":util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc index d6007748609..057045fc0c0 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.cc +++ b/tensorflow/compiler/tf2xla/lib/qr.cc @@ -19,9 +19,9 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" @@ -225,8 +225,8 @@ xla::StatusOr QRBlock( builder, xla::ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n}))); TF_ASSIGN_OR_RETURN(auto values, - XlaForEachIndex(std::min(m, n), xla::S32, qr_body_fn, - {a, vs, taus}, "qr", builder)); + xla::ForEachIndex(std::min(m, n), xla::S32, qr_body_fn, + {a, vs, taus}, "qr", builder)); QRBlockResult result; result.r = values[0]; @@ -301,8 +301,8 @@ xla::StatusOr ComputeWYRepresentation( w = UpdateSliceInMinorDims(w, bv, {0}); TF_ASSIGN_OR_RETURN( - auto values, XlaForEachIndex(n - 1, xla::S32, body_fn, {w, y, vs, taus}, - "wy", builder)); + auto values, xla::ForEachIndex(n - 1, xla::S32, body_fn, {w, y, vs, taus}, + "wy", builder)); return values[0]; } diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index 2b1c2ced925..688056791f9 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -20,7 +20,6 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 41db8de29ff..bf21b267c53 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -1,5 +1,7 @@ # Common computation builders for XLA. +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites", "xla_test") + licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow/compiler/xla/client:friends"]) @@ -13,9 +15,6 @@ filegroup( ]), ) -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites") - # Generate test_suites for all backends, named "${backend}_tests". generate_backend_suites() @@ -35,6 +34,48 @@ cc_library( ], ) +cc_library( + name = "cholesky", + srcs = ["cholesky.cc"], + hdrs = ["cholesky.h"], + deps = [ + ":math", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:loops", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/client/lib:slicing", + "//tensorflow/compiler/xla/client/lib:triangular_solve", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "cholesky_test", + srcs = ["cholesky_test.cc"], + tags = ["optonly"], + deps = [ + ":arithmetic", + ":cholesky", + ":matrix", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "constants", srcs = ["constants.cc"], @@ -75,6 +116,22 @@ cc_library( ], ) +cc_library( + name = "loops", + srcs = ["loops.cc"], + hdrs = ["loops.h"], + deps = [ + ":constants", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "math", srcs = ["math.cc"], diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/xla/client/lib/cholesky.cc similarity index 68% rename from tensorflow/compiler/tf2xla/lib/cholesky.cc rename to tensorflow/compiler/xla/client/lib/cholesky.cc index 550ab5b0569..fd980499684 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/xla/client/lib/cholesky.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/cholesky.h" +#include "tensorflow/compiler/xla/client/lib/cholesky.h" #include #include -#include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/lib/triangular_solve.h" @@ -31,7 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/errors.h" -namespace tensorflow { +namespace xla { namespace { @@ -50,26 +50,25 @@ namespace { // l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) / // l[..., j, j] // return l -xla::XlaOp CholeskyUnblocked(xla::XlaOp a, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int n_dims = xla::ShapeUtil::Rank(a_shape); - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); - auto major_dims = xla::AsInt64Slice(a_shape.dimensions()) +XlaOp CholeskyUnblocked(XlaOp a, PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int n_dims = ShapeUtil::Rank(a_shape); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); + auto major_dims = AsInt64Slice(a_shape.dimensions()) .subspan( /*pos=*/0, /*len=*/n_dims - 2); - xla::XlaOp l = xla::ZerosLike(a); + XlaOp l = ZerosLike(a); // Construct the for loop body to iterate over rows. - auto body_fn = [&](xla::XlaOp i, absl::Span loop_vars, - xla::XlaBuilder* body_builder) - -> xla::StatusOr> { - xla::Shape col_shape; - xla::Shape row_shape; + auto body_fn = + [&](XlaOp i, absl::Span loop_vars, + XlaBuilder* body_builder) -> StatusOr> { + Shape col_shape; + Shape row_shape; for (int64 d : major_dims) { row_shape.add_dimensions(d); col_shape.add_dimensions(d); @@ -77,43 +76,40 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, row_shape.add_dimensions(1); row_shape.add_dimensions(n); row_shape.set_element_type(a_shape.element_type()); - auto mask_zeros_row = xla::Zeros(body_builder, row_shape); + auto mask_zeros_row = Zeros(body_builder, row_shape); col_shape.add_dimensions(n); col_shape.add_dimensions(1); col_shape.set_element_type(a_shape.element_type()); - auto mask_zeros_col = xla::Zeros(body_builder, col_shape); + auto mask_zeros_col = Zeros(body_builder, col_shape); std::vector mask_vector(n); std::iota(mask_vector.begin(), mask_vector.end(), 0); - auto mask_range = xla::ConstantR1(body_builder, mask_vector); + auto mask_range = ConstantR1(body_builder, mask_vector); auto mask_range_row = - xla::Broadcast(xla::Reshape(mask_range, {0}, {1, n}), major_dims); + Broadcast(Reshape(mask_range, {0}, {1, n}), major_dims); auto mask_range_col = - xla::Broadcast(xla::Reshape(mask_range, {0}, {n, 1}), major_dims); + Broadcast(Reshape(mask_range, {0}, {n, 1}), major_dims); auto body_a = loop_vars[0]; auto body_l = loop_vars[1]; // row = l[..., i, :i] // select the whole i-th row, then mask out all columns past i-1 - auto zero = xla::ConstantR0(body_builder, 0); + auto zero = ConstantR0(body_builder, 0); auto l_i = DynamicSliceInMinorDims(body_l, {i, zero}, {1, n}); - auto row = xla::Select(xla::Ge(mask_range_row, i), mask_zeros_row, l_i); + auto row = Select(Ge(mask_range_row, i), mask_zeros_row, l_i); // a[..., i, i] auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1}); // np.dot(row, np.swapaxes(row, -1, -2)) auto diag_dot = BatchDot(row, TransposeInMinorDims(row), precision); // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, // np.swapaxes(row, -1, -2))) - auto l_ii = - xla::Pow(a_ii - diag_dot, - FloatLiteral(body_builder, a_shape.element_type(), 0.5)); + auto l_ii = Sqrt(a_ii - diag_dot); // a[..., i+1:, i] // select the whole i-th column, then mask out all rows above i+1 auto a_0i = DynamicSliceInMinorDims(body_a, {i}, {1}); - auto a_ip1i = - xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, a_0i); + auto a_ip1i = Select(Le(mask_range_col, i), mask_zeros_col, a_0i); // l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) / // l[..., i, i] @@ -122,8 +118,7 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, // r.T) auto dot = BatchDot(body_l, TransposeInMinorDims(row), precision); // np.dot(l[..., i+1:, :i], r.T) - auto dot_ip1 = - xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot); + auto dot_ip1 = Select(Le(mask_range_col, i), mask_zeros_col, dot); body_l = DynamicUpdateSliceInMinorDims(body_l, (a_ip1i - dot_ip1) / l_ii, {i}); @@ -131,12 +126,12 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, // column assign will wrap around and overwrite the diagonal assign. body_l = DynamicUpdateSliceInMinorDims(body_l, l_ii, {i, i}); - return std::vector{body_a, body_l}; + return std::vector{body_a, body_l}; }; TF_ASSIGN_OR_RETURN( auto cholesky_while, - XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder)); + ForEachIndex(n, S32, body_fn, {a, l}, "unblocked", builder)); return cholesky_while[1]; }); @@ -144,34 +139,35 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, } // namespace -xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int ndims = xla::ShapeUtil::Rank(a_shape); +XlaOp Cholesky(XlaOp a, int64 block_size, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int ndims = ShapeUtil::Rank(a_shape); if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to Cholesky must have rank >= 2: ", ndims); + return InvalidArgument( + "Argument to Cholesky must have rank >= 2; shape was %s", + a_shape.ToString()); } - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); - if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) { - return errors::InvalidArgument( - "Arguments to Cholesky must be square matrices: ", - xla::ShapeUtil::HumanString(a_shape)); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); + if (n != ShapeUtil::GetDimension(a_shape, -2)) { + return InvalidArgument( + "Argument to Cholesky must be batched square matrices; got shape %s", + ShapeUtil::HumanString(a_shape)); } if (block_size < 1) { - return errors::InvalidArgument( - "block_size argument to Cholesky must be >= 1; got ", block_size); + return InvalidArgument( + "block_size argument to Cholesky must be >= 1; got %d", block_size); } // Blocked left-looking Cholesky factorization. // Algorithm 1 from // Haidar, Azzam, et al. "High-performance Cholesky factorization for // GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017. - xla::XlaOp l = xla::ZerosLike(a); + XlaOp l = ZerosLike(a); for (int64 i = 0; i < n; i += block_size) { int64 k = std::min(block_size, n - i); if (i > 0) { @@ -207,4 +203,4 @@ xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size, }); } -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/xla/client/lib/cholesky.h similarity index 87% rename from tensorflow/compiler/tf2xla/lib/cholesky.h rename to tensorflow/compiler/xla/client/lib/cholesky.h index 9a561c34b92..0bae26837c0 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.h +++ b/tensorflow/compiler/xla/client/lib/cholesky.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -namespace tensorflow { +namespace xla { // Computes the Cholesky decompositions of a batch of symmetric positive // definite matrices. @@ -34,6 +34,6 @@ xla::XlaOp Cholesky( xla::XlaOp a, int64 block_size = 256, xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); -} // namespace tensorflow +} // namespace xla -#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_ diff --git a/tensorflow/compiler/xla/client/lib/cholesky_test.cc b/tensorflow/compiler/xla/client/lib/cholesky_test.cc new file mode 100644 index 00000000000..ba9580a3d32 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/cholesky_test.cc @@ -0,0 +1,166 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/cholesky.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace { + +using xla::int64; + +using CholeskyTest = xla::ClientLibraryTestBase; + +XLA_TEST_F(CholeskyTest, Simple) { + xla::XlaBuilder builder(TestName()); + + xla::Array2D a_vals({ + {4, 6, 8, 10}, + {6, 45, 54, 63}, + {8, 54, 146, 166}, + {10, 63, 166, 310}, + }); + + xla::XlaOp a; + auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); + xla::Cholesky(a, /*block_size=*/2); + + xla::Array2D expected({ + {2, 0, 0, 0}, + {3, 6, 0, 0}, + {4, 7, 9, 0}, + {5, 8, 10, 11}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +XLA_TEST_F(CholeskyTest, Simple2) { + xla::XlaBuilder builder(TestName()); + + xla::Array2D a_vals({ + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }); + + xla::XlaOp a; + auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); + xla::Cholesky(a); + + xla::Array2D expected( + {{4, 0, 0, 0}, {6, 5, 0, 0}, {2, 14, 16, 0}, {3, 6, 1, 4}}); + + ComputeAndCompareR2(&builder, expected, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +XLA_TEST_F(CholeskyTest, SimpleBatched) { + xla::XlaBuilder builder(TestName()); + + xla::Array3D a_vals({ + { + {4, 6, 8, 10}, + {6, 45, 54, 63}, + {8, 54, 146, 166}, + {10, 63, 166, 310}, + }, + { + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }, + }); + + xla::XlaOp a; + auto a_data = CreateR3Parameter(a_vals, 0, "a", &builder, &a); + xla::Cholesky(a); + + xla::Array3D expected({ + { + {2, 0, 0, 0}, + {3, 6, 0, 0}, + {4, 7, 9, 0}, + {5, 8, 10, 11}, + }, + {{4, 0, 0, 0}, {6, 5, 0, 0}, {2, 14, 16, 0}, {3, 6, 1, 4}}, + }); + + ComputeAndCompareR3(&builder, expected, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +using CholeskyTestCase = std::tuple; + +class RandomCholeskyTest + : public xla::ClientLibraryTestBase, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(RandomCholeskyTest, Random) { + xla::XlaBuilder builder(TestName()); + + auto test_params = GetParam(); + std::vector dimensions = {std::get<0>(test_params), + std::get<1>(test_params), + std::get<1>(test_params)}; + xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, dimensions); + TF_ASSERT_OK_AND_ASSIGN( + auto literal, + xla::LiteralUtil::CreateRandomLiteral(shape, 0.0, 1.0)); + + auto input = xla::Parameter(&builder, 0, shape, "input"); + // Form a random positive definite matrix. + auto matrix = xla::BatchDot(input, TransposeInMinorDims(input), + xla::PrecisionConfig::HIGHEST); + + auto cholesky = xla::Cholesky(matrix, /*block_size=*/4); + + // Verify that ||matrix - cholesky * cholesky_t||_2 ~= 0 + auto verification = xla::BatchDot(cholesky, TransposeInMinorDims(cholesky), + xla::PrecisionConfig::HIGHEST); + auto delta = matrix - verification; + xla::Reduce(delta * delta, xla::ConstantR0(&builder, 0.0), + CreateScalarAddComputation(xla::F32, &builder), {0, 1, 2}); + + TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(literal)); + ComputeAndCompareR0(&builder, 0.0, {input_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +INSTANTIATE_TEST_CASE_P(RandomCholeskyTestInstance, RandomCholeskyTest, + ::testing::Values(CholeskyTestCase{1, 1}, + CholeskyTestCase{1, 2}, + CholeskyTestCase{10, 5}, + CholeskyTestCase{2, 20})); + +} // namespace diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/xla/client/lib/loops.cc similarity index 50% rename from tensorflow/compiler/tf2xla/lib/while_loop.cc rename to tensorflow/compiler/xla/client/lib/loops.cc index 594ab1dfd07..721f987628a 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.cc +++ b/tensorflow/compiler/xla/client/lib/loops.cc @@ -13,44 +13,43 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" -#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" + +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" -namespace tensorflow { +namespace xla { -xla::StatusOr> XlaWhileLoop( - const LoopConditionFunction& condition_function, - const LoopBodyFunction& body_function, - absl::Span initial_values, absl::string_view name, - xla::XlaBuilder* builder) { +StatusOr> WhileLoopHelper( + const WhileLoopHelperConditionFunction& condition_function, + const WhileLoopHelperBodyFunction& body_function, + absl::Span initial_values, absl::string_view name, + XlaBuilder* builder) { int arity = initial_values.size(); - std::vector var_shapes; + std::vector var_shapes; var_shapes.reserve(arity); - for (const xla::XlaOp& input : initial_values) { + for (const XlaOp& input : initial_values) { TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(input)); var_shapes.push_back(std::move(shape)); } - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(var_shapes); + Shape tuple_shape = ShapeUtil::MakeTupleShape(var_shapes); // Unpacks a tuple into its component parts. - auto unpack_tuple = [](xla::XlaOp tuple, int arity, - xla::XlaBuilder* builder) { - std::vector elements(arity); + auto unpack_tuple = [](XlaOp tuple, int arity, XlaBuilder* builder) { + std::vector elements(arity); for (int i = 0; i < arity; ++i) { - elements[i] = xla::GetTupleElement(tuple, i); + elements[i] = GetTupleElement(tuple, i); } return elements; }; // Build the condition. - std::unique_ptr cond_builder = + std::unique_ptr cond_builder = builder->CreateSubBuilder(absl::StrCat(name, "_condition")); { - auto parameter = - xla::Parameter(cond_builder.get(), 0, tuple_shape, "parameter"); + auto parameter = Parameter(cond_builder.get(), 0, tuple_shape, "parameter"); TF_RETURN_IF_ERROR( condition_function(unpack_tuple(parameter, arity, cond_builder.get()), @@ -60,11 +59,10 @@ xla::StatusOr> XlaWhileLoop( TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build()); // Build the body. - std::unique_ptr body_builder = + std::unique_ptr body_builder = builder->CreateSubBuilder(absl::StrCat(name, "_body")); { - auto parameter = - xla::Parameter(body_builder.get(), 0, tuple_shape, "parameter"); + auto parameter = Parameter(body_builder.get(), 0, tuple_shape, "parameter"); TF_ASSIGN_OR_RETURN( auto result, @@ -72,56 +70,54 @@ xla::StatusOr> XlaWhileLoop( body_builder.get())); TF_RET_CHECK(result.size() == initial_values.size()); - xla::Tuple(body_builder.get(), result); + Tuple(body_builder.get(), result); } TF_ASSIGN_OR_RETURN(auto body, body_builder->Build()); - auto outputs = xla::While(cond, body, xla::Tuple(builder, initial_values)); + auto outputs = While(cond, body, Tuple(builder, initial_values)); return unpack_tuple(outputs, arity, builder); } -xla::StatusOr> XlaForEachIndex( - int64 num_iterations, xla::PrimitiveType num_iterations_type, +StatusOr> ForEachIndex( + int64 num_iterations, PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - absl::Span initial_values, absl::string_view name, - xla::XlaBuilder* builder) { - auto while_cond_fn = - [&](absl::Span values, - xla::XlaBuilder* cond_builder) -> xla::StatusOr { - return xla::Lt(values[0], IntegerLiteral(cond_builder, num_iterations_type, - num_iterations)); + absl::Span initial_values, absl::string_view name, + XlaBuilder* builder) { + auto while_cond_fn = [&](absl::Span values, + XlaBuilder* cond_builder) -> StatusOr { + return Lt(values[0], ConstantR0WithType(cond_builder, num_iterations_type, + num_iterations)); }; - auto while_body_fn = [&](absl::Span values, - xla::XlaBuilder* body_builder) - -> xla::StatusOr> { - xla::XlaOp iteration = values[0]; + auto while_body_fn = + [&](absl::Span values, + XlaBuilder* body_builder) -> StatusOr> { + XlaOp iteration = values[0]; - std::vector updated_values; + std::vector updated_values; updated_values.reserve(values.size()); - updated_values.push_back(xla::Add( + updated_values.push_back(Add( iteration, - xla::ConstantLiteral(body_builder, - xla::LiteralUtil::One(num_iterations_type)))); + ConstantLiteral(body_builder, LiteralUtil::One(num_iterations_type)))); values.remove_prefix(1); - TF_ASSIGN_OR_RETURN(std::vector body_outputs, + TF_ASSIGN_OR_RETURN(std::vector body_outputs, body_function(iteration, values, body_builder)); updated_values.insert(updated_values.end(), body_outputs.begin(), body_outputs.end()); return updated_values; }; - std::vector values; + std::vector values; values.reserve(initial_values.size() + 1); - values.push_back(xla::ConstantLiteral( - builder, xla::LiteralUtil::Zero(num_iterations_type))); + values.push_back( + ConstantLiteral(builder, LiteralUtil::Zero(num_iterations_type))); values.insert(values.end(), initial_values.begin(), initial_values.end()); - TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values, - name, builder)); + TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn, + values, name, builder)); values.erase(values.begin(), values.begin() + 1); return values; } -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/xla/client/lib/loops.h similarity index 62% rename from tensorflow/compiler/tf2xla/lib/while_loop.h rename to tensorflow/compiler/xla/client/lib/loops.h index f2134bb4495..e11de59493e 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.h +++ b/tensorflow/compiler/xla/client/lib/loops.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_ #include #include @@ -25,19 +25,18 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" -namespace tensorflow { +namespace xla { // Function that builds a loop condition. Takes as input a sequence of input // values, and returns a boolean value representing if the condition succeeds. -typedef std::function(absl::Span, - xla::XlaBuilder*)> - LoopConditionFunction; +typedef std::function(absl::Span, XlaBuilder*)> + WhileLoopHelperConditionFunction; // Function that builds a loop body. Takes as input a sequence of input values // and returns a sequence of output values. -typedef std::function>( - absl::Span, xla::XlaBuilder*)> - LoopBodyFunction; +typedef std::function>(absl::Span, + XlaBuilder*)> + WhileLoopHelperBodyFunction; // Helper function for building an XLA while loop, where the values carried by // the loop are a tuple of values, e.g., (a, b, c): @@ -47,27 +46,27 @@ typedef std::function>( // init: (a, b, c) // ) // 'name' is a descriptive name for the loop. -xla::StatusOr> XlaWhileLoop( - const LoopConditionFunction& condition_function, - const LoopBodyFunction& body_function, - absl::Span initial_values, absl::string_view name, - xla::XlaBuilder* builder); +StatusOr> WhileLoopHelper( + const WhileLoopHelperConditionFunction& condition_function, + const WhileLoopHelperBodyFunction& body_function, + absl::Span initial_values, absl::string_view name, + XlaBuilder* builder); // Builds an XLA loop that repeats a computation `num_iterations` times. // // The body function (ForEachIndexBodyFunction) takes as input a pair of // (current iteration number, loop-carried values), and returns an updated // vector of the loop-carried values. -typedef std::function>( - xla::XlaOp, absl::Span, xla::XlaBuilder*)> +typedef std::function>( + XlaOp, absl::Span, XlaBuilder*)> ForEachIndexBodyFunction; -xla::StatusOr> XlaForEachIndex( - int64 num_iterations, xla::PrimitiveType num_iterations_type, +StatusOr> ForEachIndex( + int64 num_iterations, PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - absl::Span initial_values, absl::string_view name, - xla::XlaBuilder* builder); + absl::Span initial_values, absl::string_view name, + XlaBuilder* builder); -} // namespace tensorflow +} // namespace xla -#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_