diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 5fe7fbdb1d8..b33147f3153 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1898,6 +1898,7 @@ cc_library(
         "//tensorflow/compiler/xla/client/lib:slicing",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/types:span",
     ],
 )
 
diff --git a/tensorflow/compiler/xla/service/triangular_solve_expander.cc b/tensorflow/compiler/xla/service/triangular_solve_expander.cc
index 8df078d1377..952ec137954 100644
--- a/tensorflow/compiler/xla/service/triangular_solve_expander.cc
+++ b/tensorflow/compiler/xla/service/triangular_solve_expander.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include <memory>
 #include <vector>
 
+#include "absl/types/span.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/lib/math.h"
 #include "tensorflow/compiler/xla/client/lib/matrix.h"
@@ -43,6 +44,8 @@ XlaOp DiagonalBlocks(XlaOp a, int64 block_size) {
     int ndims = shape.rank();
     int64 n = ShapeUtil::GetDimension(shape, -1);
     int64 num_blocks = n / block_size;
+    absl::Span<const int64> batch_dims = absl::MakeConstSpan(
+        shape.dimensions().begin(), shape.dimensions().begin() + (ndims - 2));
 
     XlaOp diag_blocks;
 
@@ -100,10 +103,10 @@ XlaOp DiagonalBlocks(XlaOp a, int64 block_size) {
       auto eye =
           IdentityMatrix(builder, shape.element_type(), padding, padding);
-      config = MakeNoPaddingConfig(ndims);
-      config.mutable_dimensions(ndims - 2)->set_edge_padding_low(n %
-                                                                 block_size);
+      config = MakeNoPaddingConfig(2);
+      config.mutable_dimensions(0)->set_edge_padding_low(n % block_size);
       eye = Pad(eye, Zero(builder, shape.element_type()), config);
+      eye = Broadcast(eye, batch_dims);
       last_blocks = ConcatInDim(builder, {last_blocks, eye}, ndims - 1);
 
       // Add a singleton dimension
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index a429bf7f2bc..55a31b7f650 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -2669,6 +2669,7 @@ xla_test(
     ],
     deps = [
         ":test_macros_header",
+        "//tensorflow/compiler/xla:array",
        "//tensorflow/compiler/xla:array2d",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:statusor",
@@ -2681,6 +2682,7 @@ xla_test(
         "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:test",
+        "@com_google_absl//absl/strings",
     ],
 )
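Note on the `DiagonalBlocks` hunk above: the old code built the padding config with rank `ndims` even though `eye` is always a rank-2 matrix, so the `Pad` (and the subsequent `ConcatInDim`) broke for batched inputs of rank > 2. The fix pads `eye` with a rank-2 config and then explicitly broadcasts it over the batch dimensions. A minimal standalone sketch of the shape bookkeeping, using illustrative values rather than anything from the patch:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Batched input a of shape [3, 11, 11]: one batch dim, n = 11.
  const std::vector<int64_t> shape = {3, 11, 11};
  const int64_t block_size = 4;
  const int64_t n = shape.back();
  const int64_t rem = n % block_size;        // 3: width of the partial block
  const int64_t padding = block_size - rem;  // 1: size of the identity pad

  // Mirrors the new absl::MakeConstSpan(...): all dims but the minor two.
  const std::vector<int64_t> batch_dims(shape.begin(), shape.end() - 2);

  // eye starts as [padding, padding]; padding dim 0 low by `rem` yields
  // [block_size, padding], which is why the config must now be rank 2.
  const std::vector<int64_t> eye_dims = {rem + padding, padding};

  // Broadcast(eye, batch_dims) prepends the batch dims: [3, 4, 1] here.
  // Concatenated with last_blocks ([3, 4, 3]) along the last dimension,
  // this completes the square [3, 4, 4] trailing diagonal block.
  std::vector<int64_t> bcast(batch_dims);
  bcast.insert(bcast.end(), eye_dims.begin(), eye_dims.end());
  for (int64_t d : bcast) std::cout << d << ' ';  // prints: 3 4 1
  std::cout << '\n';
  return 0;
}
```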
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index 3c9e37b8fa4..2e038ea27a9 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -226,7 +226,13 @@ class ClientLibraryTestBase : public ManifestCheckingTest {
                          absl::Span<GlobalData* const> arguments);
   void ComputeAndCompare(XlaBuilder* builder,
                          absl::Span<GlobalData* const> arguments,
                          ErrorSpec error);
-
+  template <typename NativeT>
+  void ComputeAndCompare(XlaBuilder* builder, const Array<NativeT>& expected,
+                         absl::Span<GlobalData* const> arguments);
+  template <typename NativeT>
+  void ComputeAndCompare(XlaBuilder* builder, const Array<NativeT>& expected,
+                         absl::Span<GlobalData* const> arguments,
+                         ErrorSpec error);
   // Create scalar operations for use in reductions.
   XlaComputation CreateScalarRelu();
   XlaComputation CreateScalarMax();
@@ -387,6 +393,13 @@ class ClientLibraryTestBase : public ManifestCheckingTest {
                                                 const Array4D<NativeT>& array_4d,
                                                 int64 parameter_number,
                                                 const string& name,
                                                 XlaBuilder* builder,
                                                 XlaOp* data_handle);
+  template <typename NativeT>
+  std::unique_ptr<GlobalData> CreateParameter(const Array<NativeT>& array_4d,
+                                              int64 parameter_number,
+                                              const string& name,
+                                              XlaBuilder* builder,
+                                              XlaOp* data_handle);
+
   // Getter and setter for the use_bfloat16 flag, which indicates whether to run
   // tests with all float-type input/output converted to bfloat16.
   bool use_bfloat16() const { return use_bfloat16_; }
@@ -563,6 +576,31 @@ void ClientLibraryTestBase::ComputeAndCompareR4(
                                                  arguments, error);
 }
 
+template <typename NativeT>
+void ClientLibraryTestBase::ComputeAndCompare(
+    XlaBuilder* builder, const Array<NativeT>& expected,
+    absl::Span<GlobalData* const> arguments) {
+  Literal expected_literal = LiteralUtil::CreateFromArray(expected);
+  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
+                                                  arguments);
+}
+
+template <typename NativeT>
+void ClientLibraryTestBase::ComputeAndCompare(
+    XlaBuilder* builder, const Array<NativeT>& expected,
+    absl::Span<GlobalData* const> arguments, ErrorSpec error) {
+  static_assert(std::is_same<NativeT, float>::value ||
+                    std::is_same<NativeT, double>::value ||
+                    std::is_same<NativeT, bfloat16>::value ||
+                    std::is_same<NativeT, half>::value ||
+                    std::is_same<NativeT, complex64>::value ||
+                    std::is_same<NativeT, complex128>::value,
+                "Float or complex type required when specifying an ErrorSpec");
+  Literal expected_literal = LiteralUtil::CreateFromArray(expected);
+  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
+                                                  arguments, error);
+}
+
 template <typename NativeT>
 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
     NativeT value, int64 parameter_number, const string& name,
@@ -633,6 +671,20 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR4Parameter(
   return data;
 }
 
+template <typename NativeT>
+std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateParameter(
+    const Array<NativeT>& array, int64 parameter_number, const string& name,
+    XlaBuilder* builder, XlaOp* data_handle) {
+  Literal literal = LiteralUtil::CreateFromArray(array);
+  if (use_bfloat16_ && literal.shape().element_type() == F32) {
+    literal = LiteralUtil::ConvertF32ToBF16(literal);
+  }
+  std::unique_ptr<GlobalData> data =
+      client_->TransferToServer(literal).ConsumeValueOrDie();
+  *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
+  return data;
+}
+
 template <typename NativeT>
 std::vector<NativeT> ClientLibraryTestBase::CreatePseudorandomR1(
     const int width, NativeT min_value, NativeT max_value, uint32 seed) {
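The `CreateParameter` / `ComputeAndCompare` overloads added above generalize the rank-specific helpers (`CreateR2Parameter`, `ComputeAndCompareR2`, ...) to `Array`s of any rank. A hypothetical usage sketch; the test name and values are illustrative and not part of this change:

```cpp
XLA_TEST_F(ClientLibraryTestBase, NegateRank3) {  // hypothetical test
  XlaBuilder builder(TestName());

  // A rank-3 input that the old R2-only helpers could not express.
  Array<float> input(std::vector<int64>{2, 3, 4});
  input.FillRandom(1.0f);

  XlaOp param;
  auto data = CreateParameter<float>(input, 0, "input", &builder, &param);
  Neg(param);

  Array<float> expected(input);
  expected.Each([](absl::Span<const int64> /*idx*/, float* v) { *v = -*v; });

  // The ErrorSpec overload static_asserts NativeT is a float/complex type.
  ComputeAndCompare<float>(&builder, expected, {data.get()}, ErrorSpec(1e-4));
}
```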
diff --git a/tensorflow/compiler/xla/tests/triangular_solve_test.cc b/tensorflow/compiler/xla/tests/triangular_solve_test.cc
index f3358f65ce3..65d18baae5b 100644
--- a/tensorflow/compiler/xla/tests/triangular_solve_test.cc
+++ b/tensorflow/compiler/xla/tests/triangular_solve_test.cc
@@ -17,6 +17,8 @@ limitations under the License.
 #include <memory>
 #include <vector>
 
+#include "absl/strings/ascii.h"
+#include "tensorflow/compiler/xla/array.h"
 #include "tensorflow/compiler/xla/array2d.h"
 #include "tensorflow/compiler/xla/client/lib/math.h"
 #include "tensorflow/compiler/xla/client/lib/matrix.h"
@@ -440,7 +442,7 @@ XLA_TEST_F(TriangularSolveTest, BatchedLeftUpper) {
 }
 
 struct TriangularSolveTestSpec {
-  int m, n;  // A is mxm, B is mxn
+  std::vector<int64> dims;  // [..., m, n] A is mxm, B is mxn
   bool left_side;
   bool lower;
   TriangularSolveOptions::Transpose transpose_a;
@@ -455,20 +457,27 @@ XLA_TEST_P(TriangularSolveParametricTest, Random) {
 
   XlaBuilder builder(TestName());
 
-  Array2D<float> avals(spec.m, spec.m);
+  CHECK_GE(spec.dims.size(), 2);
+  std::vector<int64> a_dims = spec.dims;
+  a_dims.back() = a_dims.at(a_dims.size() - 2);
+  Array<float> avals(a_dims);
   avals.FillRandom(1.0);
-  for (int i = 0; i < spec.m; ++i) {
-    avals(i, i) += 30;
-  }
+  avals.Each([](absl::Span<const int64> dims, float* v) {
+    if (dims.back() == dims.at(dims.size() - 2)) {
+      *v += 30;
+    }
+  });
 
-  std::pair<int, int> bdims = spec.left_side ? std::make_pair(spec.m, spec.n)
-                                             : std::make_pair(spec.n, spec.m);
-  Array2D<float> bvals(bdims.first, bdims.second);
+  std::vector<int64> b_dims = spec.dims;
+  if (!spec.left_side) {
+    std::swap(b_dims.back(), b_dims.at(b_dims.size() - 2));
+  }
+  Array<float> bvals(b_dims);
   bvals.FillRandom(1.0);
 
   XlaOp a, b;
-  auto a_data = CreateR2Parameter<float>(avals, 0, "a", &builder, &a);
-  auto b_data = CreateR2Parameter<float>(bvals, 1, "b", &builder, &b);
+  auto a_data = CreateParameter<float>(avals, 0, "a", &builder, &a);
+  auto b_data = CreateParameter<float>(bvals, 1, "b", &builder, &b);
   auto x = TriangularSolve(a, b, spec.left_side, spec.lower,
                            /*unit_diagonal=*/false, spec.transpose_a);
   auto a_tri = Triangle(a, spec.lower);
@@ -480,20 +489,26 @@ XLA_TEST_P(TriangularSolveParametricTest, Random) {
     BatchDot(x, a_tri);
   }
 
-  ComputeAndCompareR2<float>(&builder, bvals, {a_data.get(), b_data.get()},
-                             ErrorSpec(3e-2, 3e-2));
+  ComputeAndCompare<float>(&builder, bvals, {a_data.get(), b_data.get()},
+                           ErrorSpec(3e-2, 3e-2));
 }
 
 std::vector<TriangularSolveTestSpec> TriangularSolveTests() {
   std::vector<TriangularSolveTestSpec> specs;
-  for (int m : {5, 10, 150}) {
-    for (int n : {5, 10, 150}) {
-      for (bool left_side : {false, true}) {
-        for (bool lower : {false, true}) {
-          for (TriangularSolveOptions::Transpose transpose_a :
-               {TriangularSolveOptions::NO_TRANSPOSE,
-                TriangularSolveOptions::TRANSPOSE}) {
-            specs.push_back({m, n, left_side, lower, transpose_a});
+  for (auto batch :
+       {std::initializer_list<int64>{}, std::initializer_list<int64>{5}}) {
+    for (int m : {5, 10, 150}) {
+      for (int n : {5, 150}) {
+        for (bool left_side : {false, true}) {
+          for (bool lower : {false, true}) {
+            for (TriangularSolveOptions::Transpose transpose_a :
+                 {TriangularSolveOptions::NO_TRANSPOSE,
+                  TriangularSolveOptions::TRANSPOSE}) {
+              std::vector<int64> dims(batch.begin(), batch.end());
+              dims.push_back(m);
+              dims.push_back(n);
+              specs.push_back({dims, left_side, lower, transpose_a});
+            }
           }
         }
       }
@@ -502,9 +517,18 @@ std::vector<TriangularSolveTestSpec> TriangularSolveTests() {
   return specs;
 }
 
-INSTANTIATE_TEST_SUITE_P(TriangularSolveParametricTestInstantiation,
-                         TriangularSolveParametricTest,
-                         ::testing::ValuesIn(TriangularSolveTests()));
+INSTANTIATE_TEST_SUITE_P(
+    TriangularSolveParametricTestInstantiation, TriangularSolveParametricTest,
+    ::testing::ValuesIn(TriangularSolveTests()),
+    [](const ::testing::TestParamInfo<TriangularSolveTestSpec>& info) {
+      const TriangularSolveTestSpec& spec = info.param;
+      std::string name = absl::StrCat(
+          absl::StrJoin(spec.dims, "_"), "_",
+          spec.left_side ? "left" : "right",
"lower" : "upper", "_", + absl::AsciiStrToLower( + TriangularSolveOptions_Transpose_Name(spec.transpose_a))); + return name; + }); } // namespace } // namespace xla