diff --git a/tensorflow/compiler/xla/service/triangular_solve_expander.cc b/tensorflow/compiler/xla/service/triangular_solve_expander.cc index a19f17996be..cc483c310e8 100644 --- a/tensorflow/compiler/xla/service/triangular_solve_expander.cc +++ b/tensorflow/compiler/xla/service/triangular_solve_expander.cc @@ -320,10 +320,7 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, end = {k, std::min(i * block_size, n)}; } - if (!left_side) { - std::swap(end[0], end[1]); - } - if (transpose_a) { + if (!left_side ^ transpose_a) { std::swap(start[0], start[1]); std::swap(end[0], end[1]); } @@ -337,16 +334,12 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, } XlaOp x_update; - auto zero = Zero(builder, S32); - auto start_index = ConstantR0WithType(builder, S32, j * block_size); - std::vector update_starts = {start_index, zero}; if (left_side) { x_update = BatchDot(inv_block, transpose_a, remainder, false, precision); } else { x_update = BatchDot(remainder, false, inv_block, transpose_a, precision); - std::swap(update_starts[0], update_starts[1]); } if (i == 0) { diff --git a/tensorflow/compiler/xla/tests/triangular_solve_test.cc b/tensorflow/compiler/xla/tests/triangular_solve_test.cc index f2a95ab126a..f3358f65ce3 100644 --- a/tensorflow/compiler/xla/tests/triangular_solve_test.cc +++ b/tensorflow/compiler/xla/tests/triangular_solve_test.cc @@ -458,7 +458,7 @@ XLA_TEST_P(TriangularSolveParametricTest, Random) { Array2D avals(spec.m, spec.m); avals.FillRandom(1.0); for (int i = 0; i < spec.m; ++i) { - avals(i, i) += 10; + avals(i, i) += 30; } std::pair bdims = spec.left_side ? std::make_pair(spec.m, spec.n) @@ -481,13 +481,13 @@ XLA_TEST_P(TriangularSolveParametricTest, Random) { } ComputeAndCompareR2(&builder, bvals, {a_data.get(), b_data.get()}, - ErrorSpec(1e-2, 1e-2)); + ErrorSpec(3e-2, 3e-2)); } std::vector TriangularSolveTests() { std::vector specs; - for (int m : {5, 10}) { - for (int n : {5, 10}) { + for (int m : {5, 10, 150}) { + for (int n : {5, 10, 150}) { for (bool left_side : {false, true}) { for (bool lower : {false, true}) { for (TriangularSolveOptions::Transpose transpose_a :