[XLA] Fix bug in triangular solve expander.

For batched matrices whose size was not evenly divisible by the block size, the HLO being produced was not shape-correct. Also improves test-case coverage to catch the problem.
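
Concretely, the expander splits the diagonal into n / block_size blocks; when n % block_size != 0 the trailing block is ragged and is completed with an identity tail so that every diagonal block has the same shape and remains invertible. A minimal standalone sketch of that arithmetic (the block size of 128 is purely illustrative, not necessarily what the expander picks):

#include <cstdint>
#include <iostream>

int main() {
  // Illustrative numbers: a 150x150 triangular matrix tiled into
  // 128-wide diagonal blocks.
  const int64_t n = 150;
  const int64_t block_size = 128;

  const int64_t num_blocks = n / block_size;  // 1 full block
  const int64_t rem = n % block_size;         // 22: the ragged last block
  // The ragged block is completed to a full block_size-sized block with an
  // identity tail, keeping the padded block invertible.
  const int64_t identity_tail = block_size - rem;  // 106
  std::cout << num_blocks << " full block(s), remainder " << rem
            << ", identity tail of size " << identity_tail << "\n";
  return 0;
}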

Will fix https://github.com/google/jax/issues/4773 when incorporated into a jaxlib release.

PiperOrigin-RevId: 340678796
Change-Id: Ifc230aa0bae9aec4902556c5e7829410c60c587f
Peter Hawkins authored on 2020-11-04 10:20:57 -08:00; committed by TensorFlower Gardener
parent de5e35d7fe, commit 7a050b85d8
5 changed files with 109 additions and 27 deletions

tensorflow/compiler/xla/service/BUILD

@@ -1898,6 +1898,7 @@ cc_library(
         "//tensorflow/compiler/xla/client/lib:slicing",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/types:span",
     ],
 )

tensorflow/compiler/xla/service/triangular_solve_expander.cc

@@ -18,6 +18,7 @@ limitations under the License.
 #include <memory>
 #include <vector>

+#include "absl/types/span.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/lib/math.h"
 #include "tensorflow/compiler/xla/client/lib/matrix.h"
@@ -43,6 +44,8 @@ XlaOp DiagonalBlocks(XlaOp a, int64 block_size) {
   int ndims = shape.rank();
   int64 n = ShapeUtil::GetDimension(shape, -1);
   int64 num_blocks = n / block_size;
+  absl::Span<int64 const> batch_dims = absl::MakeConstSpan(
+      shape.dimensions().begin(), shape.dimensions().begin() + (ndims - 2));

   XlaOp diag_blocks;
@@ -100,10 +103,10 @@ XlaOp DiagonalBlocks(XlaOp a, int64 block_size) {
     auto eye =
         IdentityMatrix(builder, shape.element_type(), padding, padding);
-    config = MakeNoPaddingConfig(ndims);
-    config.mutable_dimensions(ndims - 2)->set_edge_padding_low(n %
-                                                               block_size);
+    config = MakeNoPaddingConfig(2);
+    config.mutable_dimensions(0)->set_edge_padding_low(n % block_size);
     eye = Pad(eye, Zero(builder, shape.element_type()), config);
+    eye = Broadcast(eye, batch_dims);
     last_blocks = ConcatInDim(builder, {last_blocks, eye}, ndims - 1);

     // Add a singleton dimension
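
This hunk is the heart of the fix. The identity `eye` is always rank 2, but the old code built its padding config for the full rank of the (possibly batched) input and padded dimension ndims - 2, so for inputs with batch dimensions the config's rank did not match the rank-2 operand and the emitted HLO was not shape-correct. The new code pads `eye` with a rank-2 config and only then broadcasts it across the batch dimensions. A toy sketch of the shape bookkeeping, with stand-in helpers rather than XLA's real API:

#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Stand-in for xla::Pad's shape rule: the padding config must have exactly
// one entry per dimension of its operand.
Shape PadShape(const Shape& operand, const std::vector<int64_t>& pad_low) {
  assert(pad_low.size() == operand.size() && "padding config rank mismatch");
  Shape out = operand;
  for (size_t i = 0; i < out.size(); ++i) out[i] += pad_low[i];
  return out;
}

// Stand-in for xla::Broadcast: prepends the batch dimensions.
Shape BroadcastShape(const Shape& operand, const Shape& batch_dims) {
  Shape out(batch_dims);
  out.insert(out.end(), operand.begin(), operand.end());
  return out;
}

int main() {
  const int64_t rem = 150 % 128;       // ragged rows in the last block
  Shape eye = {128 - rem, 128 - rem};  // the identity is always rank 2

  // Old code, for a batched input of rank 3: a rank-3 config applied to the
  // rank-2 eye. This trips the rank check (the shape error being fixed):
  //   PadShape(eye, /*pad_low=*/{0, rem, 0});

  // New code: rank-2 padding first, then broadcast to the batch dimensions.
  Shape padded = PadShape(eye, /*pad_low=*/{rem, 0});
  Shape batched = BroadcastShape(padded, /*batch_dims=*/{5});
  assert(batched == (Shape{5, 128, 128 - rem}));
  return 0;
}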

tensorflow/compiler/xla/tests/BUILD

@@ -2669,6 +2669,7 @@ xla_test(
     ],
     deps = [
         ":test_macros_header",
+        "//tensorflow/compiler/xla:array",
         "//tensorflow/compiler/xla:array2d",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:statusor",
@@ -2681,6 +2682,7 @@ xla_test(
         "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:test",
+        "@com_google_absl//absl/strings",
     ],
 )
)

tensorflow/compiler/xla/tests/client_library_test_base.h

@@ -226,7 +226,13 @@ class ClientLibraryTestBase : public ManifestCheckingTest {
                          absl::Span<const Literal> arguments);
   void ComputeAndCompare(XlaBuilder* builder,
                          absl::Span<const Literal> arguments, ErrorSpec error);
+  template <typename NativeT>
+  void ComputeAndCompare(XlaBuilder* builder, const Array<NativeT>& expected,
+                         absl::Span<GlobalData* const> arguments);
+  template <typename NativeT>
+  void ComputeAndCompare(XlaBuilder* builder, const Array<NativeT>& expected,
+                         absl::Span<GlobalData* const> arguments,
+                         ErrorSpec error);

   // Create scalar operations for use in reductions.
   XlaComputation CreateScalarRelu();
   XlaComputation CreateScalarMax();
@@ -387,6 +393,13 @@ class ClientLibraryTestBase : public ManifestCheckingTest {
       const Array4D<NativeT>& array_4d, int64 parameter_number,
       const string& name, XlaBuilder* builder, XlaOp* data_handle);

+  template <typename NativeT>
+  std::unique_ptr<GlobalData> CreateParameter(const Array<NativeT>& array_4d,
+                                              int64 parameter_number,
+                                              const string& name,
+                                              XlaBuilder* builder,
+                                              XlaOp* data_handle);
+
   // Getter and setter for the use_bfloat16 flag, which indicates whether to run
   // tests with all float-type input/output converted to bfloat16.
   bool use_bfloat16() const { return use_bfloat16_; }
@@ -563,6 +576,31 @@ void ClientLibraryTestBase::ComputeAndCompareR4(
                                                   arguments, error);
 }

+template <typename NativeT>
+void ClientLibraryTestBase::ComputeAndCompare(
+    XlaBuilder* builder, const Array<NativeT>& expected,
+    absl::Span<GlobalData* const> arguments) {
+  Literal expected_literal = LiteralUtil::CreateFromArray<NativeT>(expected);
+  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
+                                                  arguments);
+}
+
+template <typename NativeT>
+void ClientLibraryTestBase::ComputeAndCompare(
+    XlaBuilder* builder, const Array<NativeT>& expected,
+    absl::Span<GlobalData* const> arguments, ErrorSpec error) {
+  static_assert(std::is_same<NativeT, float>::value ||
+                    std::is_same<NativeT, double>::value ||
+                    std::is_same<NativeT, bfloat16>::value ||
+                    std::is_same<NativeT, half>::value ||
+                    std::is_same<NativeT, complex64>::value ||
+                    std::is_same<NativeT, complex128>::value,
+                "Float or complex type required when specifying an ErrorSpec");
+  Literal expected_literal = LiteralUtil::CreateFromArray<NativeT>(expected);
+  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
+                                                  arguments, error);
+}
+
 template <typename NativeT>
 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
     NativeT value, int64 parameter_number, const string& name,
@@ -633,6 +671,20 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR4Parameter(
   return data;
 }

+template <typename NativeT>
+std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateParameter(
+    const Array<NativeT>& array, int64 parameter_number, const string& name,
+    XlaBuilder* builder, XlaOp* data_handle) {
+  Literal literal = LiteralUtil::CreateFromArray(array);
+  if (use_bfloat16_ && literal.shape().element_type() == F32) {
+    literal = LiteralUtil::ConvertF32ToBF16(literal);
+  }
+  std::unique_ptr<GlobalData> data =
+      client_->TransferToServer(literal).ConsumeValueOrDie();
+  *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
+  return data;
+}
+
 template <typename NativeT>
 std::vector<NativeT> ClientLibraryTestBase::CreatePseudorandomR1(
     const int width, NativeT min_value, NativeT max_value, uint32 seed) {
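
The new Array<NativeT> overloads mirror the existing R0-R4 helpers but accept arrays of any rank, which is what lets the triangular solve test below parameterize over batch dimensions. A hypothetical test body using them (the fixture name and values are illustrative, not from this change):

// Hypothetical usage of the new rank-agnostic helpers.
XLA_TEST_F(SomeClientLibraryTest, BatchedNegate) {
  XlaBuilder builder(TestName());

  Array<float> vals({2, 3, 3});  // one batch dimension; any rank works
  vals.FillRandom(1.0);

  XlaOp arg;
  auto arg_data = CreateParameter<float>(vals, 0, "arg", &builder, &arg);
  Neg(arg);

  Array<float> expected = vals;
  expected.Each([](absl::Span<const int64> /*idx*/, float* v) { *v = -*v; });
  ComputeAndCompare<float>(&builder, expected, {arg_data.get()},
                           ErrorSpec(1e-4));
}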

tensorflow/compiler/xla/tests/triangular_solve_test.cc

@@ -17,6 +17,8 @@ limitations under the License.
 #include <numeric>
 #include <vector>

+#include "absl/strings/ascii.h"
+#include "tensorflow/compiler/xla/array.h"
 #include "tensorflow/compiler/xla/array2d.h"
 #include "tensorflow/compiler/xla/client/lib/math.h"
 #include "tensorflow/compiler/xla/client/lib/matrix.h"
@@ -440,7 +442,7 @@ XLA_TEST_F(TriangularSolveTest, BatchedLeftUpper) {
 }

 struct TriangularSolveTestSpec {
-  int m, n;  // A is mxm, B is mxn
+  std::vector<int64> dims;  // [..., m, n] A is mxm, B is mxn
   bool left_side;
   bool lower;
   TriangularSolveOptions::Transpose transpose_a;
@@ -455,20 +457,27 @@ XLA_TEST_P(TriangularSolveParametricTest, Random) {
   XlaBuilder builder(TestName());

-  Array2D<float> avals(spec.m, spec.m);
+  CHECK_GE(spec.dims.size(), 2);
+  std::vector<int64> a_dims = spec.dims;
+  a_dims.back() = a_dims.at(a_dims.size() - 2);
+  Array<float> avals(a_dims);
   avals.FillRandom(1.0);
-  for (int i = 0; i < spec.m; ++i) {
-    avals(i, i) += 30;
-  }
+  avals.Each([](absl::Span<const int64> dims, float* v) {
+    if (dims.back() == dims.at(dims.size() - 2)) {
+      *v += 30;
+    }
+  });

-  std::pair<int, int> bdims = spec.left_side ? std::make_pair(spec.m, spec.n)
-                                             : std::make_pair(spec.n, spec.m);
-  Array2D<float> bvals(bdims.first, bdims.second);
+  std::vector<int64> b_dims = spec.dims;
+  if (!spec.left_side) {
+    std::swap(b_dims.back(), b_dims.at(b_dims.size() - 2));
+  }
+  Array<float> bvals(b_dims);
   bvals.FillRandom(1.0);

   XlaOp a, b;
-  auto a_data = CreateR2Parameter<float>(avals, 0, "a", &builder, &a);
-  auto b_data = CreateR2Parameter<float>(bvals, 1, "b", &builder, &b);
+  auto a_data = CreateParameter<float>(avals, 0, "a", &builder, &a);
+  auto b_data = CreateParameter<float>(bvals, 1, "b", &builder, &b);
   auto x = TriangularSolve(a, b, spec.left_side, spec.lower,
                            /*unit_diagonal=*/false, spec.transpose_a);
   auto a_tri = Triangle(a, spec.lower);
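
The rewritten test body builds a random, diagonally dominant A (the `*v += 30` on diagonal entries keeps the system well conditioned), solves op(A)·X = B or X·op(A) = B, multiplies the solution back by the triangular part of A, and expects to recover B. The same round-trip check on a tiny dense system, as plain self-contained C++:

#include <array>
#include <cassert>
#include <cmath>

int main() {
  // A lower-triangular, diagonally dominant 2x2 system, like the test's
  // "+= 30" construction.
  const std::array<std::array<double, 2>, 2> a = {{{31.0, 0.0}, {2.0, 33.0}}};
  const std::array<double, 2> b = {5.0, 7.0};

  // Forward substitution solves A x = b.
  std::array<double, 2> x{};
  x[0] = b[0] / a[0][0];
  x[1] = (b[1] - a[1][0] * x[0]) / a[1][1];

  // Round trip: A * x should reproduce b up to numerical error, which is
  // what the test's BatchDot + ComputeAndCompare checks in batched form.
  for (int i = 0; i < 2; ++i) {
    const double bi = a[i][0] * x[0] + a[i][1] * x[1];
    assert(std::abs(bi - b[i]) < 1e-9);
  }
  return 0;
}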
@@ -480,20 +489,26 @@ XLA_TEST_P(TriangularSolveParametricTest, Random) {
     BatchDot(x, a_tri);
   }

-  ComputeAndCompareR2<float>(&builder, bvals, {a_data.get(), b_data.get()},
-                             ErrorSpec(3e-2, 3e-2));
+  ComputeAndCompare<float>(&builder, bvals, {a_data.get(), b_data.get()},
+                           ErrorSpec(3e-2, 3e-2));
 }

 std::vector<TriangularSolveTestSpec> TriangularSolveTests() {
   std::vector<TriangularSolveTestSpec> specs;
-  for (int m : {5, 10, 150}) {
-    for (int n : {5, 10, 150}) {
-      for (bool left_side : {false, true}) {
-        for (bool lower : {false, true}) {
-          for (TriangularSolveOptions::Transpose transpose_a :
-               {TriangularSolveOptions::NO_TRANSPOSE,
-                TriangularSolveOptions::TRANSPOSE}) {
-            specs.push_back({m, n, left_side, lower, transpose_a});
+  for (auto batch :
+       {std::initializer_list<int64>{}, std::initializer_list<int64>{5}}) {
+    for (int m : {5, 10, 150}) {
+      for (int n : {5, 150}) {
+        for (bool left_side : {false, true}) {
+          for (bool lower : {false, true}) {
+            for (TriangularSolveOptions::Transpose transpose_a :
+                 {TriangularSolveOptions::NO_TRANSPOSE,
+                  TriangularSolveOptions::TRANSPOSE}) {
+              std::vector<int64> dims(batch.begin(), batch.end());
+              dims.push_back(m);
+              dims.push_back(n);
+              specs.push_back({dims, left_side, lower, transpose_a});
+            }
           }
         }
       }
@@ -502,9 +517,18 @@ std::vector<TriangularSolveTestSpec> TriangularSolveTests() {
   return specs;
 }

-INSTANTIATE_TEST_SUITE_P(TriangularSolveParametricTestInstantiation,
-                         TriangularSolveParametricTest,
-                         ::testing::ValuesIn(TriangularSolveTests()));
+INSTANTIATE_TEST_SUITE_P(
+    TriangularSolveParametricTestInstantiation, TriangularSolveParametricTest,
+    ::testing::ValuesIn(TriangularSolveTests()),
+    [](const ::testing::TestParamInfo<TriangularSolveTestSpec>& info) {
+      const TriangularSolveTestSpec& spec = info.param;
+      std::string name = absl::StrCat(
+          absl::StrJoin(spec.dims, "_"), "_", spec.left_side ? "left" : "right",
+          "_", spec.lower ? "lower" : "upper", "_",
+          absl::AsciiStrToLower(
+              TriangularSolveOptions_Transpose_Name(spec.transpose_a)));
+      return name;
+    });

 }  // namespace
 }  // namespace xla
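
The lambda passed to INSTANTIATE_TEST_SUITE_P gives each parameterized case a readable name derived from its spec: the dims joined by underscores, then left/right, lower/upper, and the lowercased transpose enum name. For example, dims = {5, 10, 150} with left_side, lower, and TRANSPOSE should come out as 5_10_150_left_lower_transpose (assuming the proto enum name is "TRANSPOSE"). A std-only sketch of the same string construction:

#include <cctype>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Mirrors the gtest name generator above without the absl helpers.
  const std::vector<int64_t> dims = {5, 10, 150};
  const bool left_side = true;
  const bool lower = true;
  const std::string transpose = "TRANSPOSE";  // assumed proto enum name

  std::string name;
  for (int64_t d : dims) name += std::to_string(d) + "_";
  name += left_side ? "left" : "right";
  name += lower ? "_lower" : "_upper";
  name += '_';
  for (char c : transpose) {
    name += static_cast<char>(std::tolower(static_cast<unsigned char>(c)));
  }
  std::cout << name << "\n";  // 5_10_150_left_lower_transpose
}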