[XLA] Fix shape error in triangular solve expander for left_side=False where more than one block is present.
PiperOrigin-RevId: 303374553 Change-Id: I1660c65771338e86a0ce61a777bfc39274ce12b8
This commit is contained in:
parent
f8ab46a8c8
commit
d992938a2b
@ -320,10 +320,7 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks,
|
||||
end = {k, std::min(i * block_size, n)};
|
||||
}
|
||||
|
||||
if (!left_side) {
|
||||
std::swap(end[0], end[1]);
|
||||
}
|
||||
if (transpose_a) {
|
||||
if (!left_side ^ transpose_a) {
|
||||
std::swap(start[0], start[1]);
|
||||
std::swap(end[0], end[1]);
|
||||
}
|
||||
@ -337,16 +334,12 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks,
|
||||
}
|
||||
|
||||
XlaOp x_update;
|
||||
auto zero = Zero(builder, S32);
|
||||
auto start_index = ConstantR0WithType(builder, S32, j * block_size);
|
||||
std::vector<XlaOp> update_starts = {start_index, zero};
|
||||
if (left_side) {
|
||||
x_update =
|
||||
BatchDot(inv_block, transpose_a, remainder, false, precision);
|
||||
} else {
|
||||
x_update =
|
||||
BatchDot(remainder, false, inv_block, transpose_a, precision);
|
||||
std::swap(update_starts[0], update_starts[1]);
|
||||
}
|
||||
|
||||
if (i == 0) {
|
||||
|
||||
@ -458,7 +458,7 @@ XLA_TEST_P(TriangularSolveParametricTest, Random) {
|
||||
Array2D<float> avals(spec.m, spec.m);
|
||||
avals.FillRandom(1.0);
|
||||
for (int i = 0; i < spec.m; ++i) {
|
||||
avals(i, i) += 10;
|
||||
avals(i, i) += 30;
|
||||
}
|
||||
|
||||
std::pair<int, int> bdims = spec.left_side ? std::make_pair(spec.m, spec.n)
|
||||
@ -481,13 +481,13 @@ XLA_TEST_P(TriangularSolveParametricTest, Random) {
|
||||
}
|
||||
|
||||
ComputeAndCompareR2<float>(&builder, bvals, {a_data.get(), b_data.get()},
|
||||
ErrorSpec(1e-2, 1e-2));
|
||||
ErrorSpec(3e-2, 3e-2));
|
||||
}
|
||||
|
||||
std::vector<TriangularSolveTestSpec> TriangularSolveTests() {
|
||||
std::vector<TriangularSolveTestSpec> specs;
|
||||
for (int m : {5, 10}) {
|
||||
for (int n : {5, 10}) {
|
||||
for (int m : {5, 10, 150}) {
|
||||
for (int n : {5, 10, 150}) {
|
||||
for (bool left_side : {false, true}) {
|
||||
for (bool lower : {false, true}) {
|
||||
for (TriangularSolveOptions::Transpose transpose_a :
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user