[XLA] Fix shape error in triangular solve expander for left_side=False where more than one block is present.

PiperOrigin-RevId: 303374553
Change-Id: I1660c65771338e86a0ce61a777bfc39274ce12b8
This commit is contained in:
Peter Hawkins 2020-03-27 11:46:43 -07:00 committed by TensorFlower Gardener
parent f8ab46a8c8
commit d992938a2b
2 changed files with 5 additions and 12 deletions

View File

@ -320,10 +320,7 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks,
end = {k, std::min(i * block_size, n)};
}
if (!left_side) {
std::swap(end[0], end[1]);
}
if (transpose_a) {
if (!left_side ^ transpose_a) {
std::swap(start[0], start[1]);
std::swap(end[0], end[1]);
}
@ -337,16 +334,12 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks,
}
XlaOp x_update;
auto zero = Zero(builder, S32);
auto start_index = ConstantR0WithType(builder, S32, j * block_size);
std::vector<XlaOp> update_starts = {start_index, zero};
if (left_side) {
x_update =
BatchDot(inv_block, transpose_a, remainder, false, precision);
} else {
x_update =
BatchDot(remainder, false, inv_block, transpose_a, precision);
std::swap(update_starts[0], update_starts[1]);
}
if (i == 0) {

View File

@ -458,7 +458,7 @@ XLA_TEST_P(TriangularSolveParametricTest, Random) {
Array2D<float> avals(spec.m, spec.m);
avals.FillRandom(1.0);
for (int i = 0; i < spec.m; ++i) {
avals(i, i) += 10;
avals(i, i) += 30;
}
std::pair<int, int> bdims = spec.left_side ? std::make_pair(spec.m, spec.n)
@ -481,13 +481,13 @@ XLA_TEST_P(TriangularSolveParametricTest, Random) {
}
ComputeAndCompareR2<float>(&builder, bvals, {a_data.get(), b_data.get()},
ErrorSpec(1e-2, 1e-2));
ErrorSpec(3e-2, 3e-2));
}
std::vector<TriangularSolveTestSpec> TriangularSolveTests() {
std::vector<TriangularSolveTestSpec> specs;
for (int m : {5, 10}) {
for (int n : {5, 10}) {
for (int m : {5, 10, 150}) {
for (int n : {5, 10, 150}) {
for (bool left_side : {false, true}) {
for (bool lower : {false, true}) {
for (TriangularSolveOptions::Transpose transpose_a :