[XLA] Fix bug in triangular solve expander.
For batched matrix whose size was not evenly divided by the number of blocks, the HLO being produced was not shape-correct. Improves test case coverage to catch the problem. Will fix https://github.com/google/jax/issues/4773 when incorporated into a jaxlib. PiperOrigin-RevId: 340678796 Change-Id: Ifc230aa0bae9aec4902556c5e7829410c60c587f
This commit is contained in:
parent
de5e35d7fe
commit
7a050b85d8
@ -1898,6 +1898,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/client/lib:slicing",
|
"//tensorflow/compiler/xla/client/lib:slicing",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
|
"@com_google_absl//absl/types:span",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||||
#include "tensorflow/compiler/xla/client/lib/math.h"
|
#include "tensorflow/compiler/xla/client/lib/math.h"
|
||||||
#include "tensorflow/compiler/xla/client/lib/matrix.h"
|
#include "tensorflow/compiler/xla/client/lib/matrix.h"
|
||||||
@ -43,6 +44,8 @@ XlaOp DiagonalBlocks(XlaOp a, int64 block_size) {
|
|||||||
int ndims = shape.rank();
|
int ndims = shape.rank();
|
||||||
int64 n = ShapeUtil::GetDimension(shape, -1);
|
int64 n = ShapeUtil::GetDimension(shape, -1);
|
||||||
int64 num_blocks = n / block_size;
|
int64 num_blocks = n / block_size;
|
||||||
|
absl::Span<int64 const> batch_dims = absl::MakeConstSpan(
|
||||||
|
shape.dimensions().begin(), shape.dimensions().begin() + (ndims - 2));
|
||||||
|
|
||||||
XlaOp diag_blocks;
|
XlaOp diag_blocks;
|
||||||
|
|
||||||
@ -100,10 +103,10 @@ XlaOp DiagonalBlocks(XlaOp a, int64 block_size) {
|
|||||||
|
|
||||||
auto eye =
|
auto eye =
|
||||||
IdentityMatrix(builder, shape.element_type(), padding, padding);
|
IdentityMatrix(builder, shape.element_type(), padding, padding);
|
||||||
config = MakeNoPaddingConfig(ndims);
|
config = MakeNoPaddingConfig(2);
|
||||||
config.mutable_dimensions(ndims - 2)->set_edge_padding_low(n %
|
config.mutable_dimensions(0)->set_edge_padding_low(n % block_size);
|
||||||
block_size);
|
|
||||||
eye = Pad(eye, Zero(builder, shape.element_type()), config);
|
eye = Pad(eye, Zero(builder, shape.element_type()), config);
|
||||||
|
eye = Broadcast(eye, batch_dims);
|
||||||
last_blocks = ConcatInDim(builder, {last_blocks, eye}, ndims - 1);
|
last_blocks = ConcatInDim(builder, {last_blocks, eye}, ndims - 1);
|
||||||
|
|
||||||
// Add a singleton dimension
|
// Add a singleton dimension
|
||||||
|
@ -2669,6 +2669,7 @@ xla_test(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":test_macros_header",
|
":test_macros_header",
|
||||||
|
"//tensorflow/compiler/xla:array",
|
||||||
"//tensorflow/compiler/xla:array2d",
|
"//tensorflow/compiler/xla:array2d",
|
||||||
"//tensorflow/compiler/xla:literal",
|
"//tensorflow/compiler/xla:literal",
|
||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
@ -2681,6 +2682,7 @@ xla_test(
|
|||||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -226,7 +226,13 @@ class ClientLibraryTestBase : public ManifestCheckingTest {
|
|||||||
absl::Span<const Literal> arguments);
|
absl::Span<const Literal> arguments);
|
||||||
void ComputeAndCompare(XlaBuilder* builder,
|
void ComputeAndCompare(XlaBuilder* builder,
|
||||||
absl::Span<const Literal> arguments, ErrorSpec error);
|
absl::Span<const Literal> arguments, ErrorSpec error);
|
||||||
|
template <typename NativeT>
|
||||||
|
void ComputeAndCompare(XlaBuilder* builder, const Array<NativeT>& expected,
|
||||||
|
absl::Span<GlobalData* const> arguments);
|
||||||
|
template <typename NativeT>
|
||||||
|
void ComputeAndCompare(XlaBuilder* builder, const Array<NativeT>& expected,
|
||||||
|
absl::Span<GlobalData* const> arguments,
|
||||||
|
ErrorSpec error);
|
||||||
// Create scalar operations for use in reductions.
|
// Create scalar operations for use in reductions.
|
||||||
XlaComputation CreateScalarRelu();
|
XlaComputation CreateScalarRelu();
|
||||||
XlaComputation CreateScalarMax();
|
XlaComputation CreateScalarMax();
|
||||||
@ -387,6 +393,13 @@ class ClientLibraryTestBase : public ManifestCheckingTest {
|
|||||||
const Array4D<NativeT>& array_4d, int64 parameter_number,
|
const Array4D<NativeT>& array_4d, int64 parameter_number,
|
||||||
const string& name, XlaBuilder* builder, XlaOp* data_handle);
|
const string& name, XlaBuilder* builder, XlaOp* data_handle);
|
||||||
|
|
||||||
|
template <typename NativeT>
|
||||||
|
std::unique_ptr<GlobalData> CreateParameter(const Array<NativeT>& array_4d,
|
||||||
|
int64 parameter_number,
|
||||||
|
const string& name,
|
||||||
|
XlaBuilder* builder,
|
||||||
|
XlaOp* data_handle);
|
||||||
|
|
||||||
// Getter and setter for the use_bfloat16 flag, which indicates whether to run
|
// Getter and setter for the use_bfloat16 flag, which indicates whether to run
|
||||||
// tests with all float-type input/output converted to bfloat16.
|
// tests with all float-type input/output converted to bfloat16.
|
||||||
bool use_bfloat16() const { return use_bfloat16_; }
|
bool use_bfloat16() const { return use_bfloat16_; }
|
||||||
@ -563,6 +576,31 @@ void ClientLibraryTestBase::ComputeAndCompareR4(
|
|||||||
arguments, error);
|
arguments, error);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename NativeT>
|
||||||
|
void ClientLibraryTestBase::ComputeAndCompare(
|
||||||
|
XlaBuilder* builder, const Array<NativeT>& expected,
|
||||||
|
absl::Span<GlobalData* const> arguments) {
|
||||||
|
Literal expected_literal = LiteralUtil::CreateFromArray<NativeT>(expected);
|
||||||
|
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
|
||||||
|
arguments);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename NativeT>
|
||||||
|
void ClientLibraryTestBase::ComputeAndCompare(
|
||||||
|
XlaBuilder* builder, const Array<NativeT>& expected,
|
||||||
|
absl::Span<GlobalData* const> arguments, ErrorSpec error) {
|
||||||
|
static_assert(std::is_same<NativeT, float>::value ||
|
||||||
|
std::is_same<NativeT, double>::value ||
|
||||||
|
std::is_same<NativeT, bfloat16>::value ||
|
||||||
|
std::is_same<NativeT, half>::value ||
|
||||||
|
std::is_same<NativeT, complex64>::value ||
|
||||||
|
std::is_same<NativeT, complex128>::value,
|
||||||
|
"Float or complex type required when specifying an ErrorSpec");
|
||||||
|
Literal expected_literal = LiteralUtil::CreateFromArray<NativeT>(expected);
|
||||||
|
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
|
||||||
|
arguments, error);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
|
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
|
||||||
NativeT value, int64 parameter_number, const string& name,
|
NativeT value, int64 parameter_number, const string& name,
|
||||||
@ -633,6 +671,20 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR4Parameter(
|
|||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename NativeT>
|
||||||
|
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateParameter(
|
||||||
|
const Array<NativeT>& array, int64 parameter_number, const string& name,
|
||||||
|
XlaBuilder* builder, XlaOp* data_handle) {
|
||||||
|
Literal literal = LiteralUtil::CreateFromArray(array);
|
||||||
|
if (use_bfloat16_ && literal.shape().element_type() == F32) {
|
||||||
|
literal = LiteralUtil::ConvertF32ToBF16(literal);
|
||||||
|
}
|
||||||
|
std::unique_ptr<GlobalData> data =
|
||||||
|
client_->TransferToServer(literal).ConsumeValueOrDie();
|
||||||
|
*data_handle = Parameter(builder, parameter_number, literal.shape(), name);
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
std::vector<NativeT> ClientLibraryTestBase::CreatePseudorandomR1(
|
std::vector<NativeT> ClientLibraryTestBase::CreatePseudorandomR1(
|
||||||
const int width, NativeT min_value, NativeT max_value, uint32 seed) {
|
const int width, NativeT min_value, NativeT max_value, uint32 seed) {
|
||||||
|
@ -17,6 +17,8 @@ limitations under the License.
|
|||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/strings/ascii.h"
|
||||||
|
#include "tensorflow/compiler/xla/array.h"
|
||||||
#include "tensorflow/compiler/xla/array2d.h"
|
#include "tensorflow/compiler/xla/array2d.h"
|
||||||
#include "tensorflow/compiler/xla/client/lib/math.h"
|
#include "tensorflow/compiler/xla/client/lib/math.h"
|
||||||
#include "tensorflow/compiler/xla/client/lib/matrix.h"
|
#include "tensorflow/compiler/xla/client/lib/matrix.h"
|
||||||
@ -440,7 +442,7 @@ XLA_TEST_F(TriangularSolveTest, BatchedLeftUpper) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct TriangularSolveTestSpec {
|
struct TriangularSolveTestSpec {
|
||||||
int m, n; // A is mxm, B is mxn
|
std::vector<int64> dims; // [..., m, n] A is mxm, B is mxn
|
||||||
bool left_side;
|
bool left_side;
|
||||||
bool lower;
|
bool lower;
|
||||||
TriangularSolveOptions::Transpose transpose_a;
|
TriangularSolveOptions::Transpose transpose_a;
|
||||||
@ -455,20 +457,27 @@ XLA_TEST_P(TriangularSolveParametricTest, Random) {
|
|||||||
|
|
||||||
XlaBuilder builder(TestName());
|
XlaBuilder builder(TestName());
|
||||||
|
|
||||||
Array2D<float> avals(spec.m, spec.m);
|
CHECK_GE(spec.dims.size(), 2);
|
||||||
|
std::vector<int64> a_dims = spec.dims;
|
||||||
|
a_dims.back() = a_dims.at(a_dims.size() - 2);
|
||||||
|
Array<float> avals(a_dims);
|
||||||
avals.FillRandom(1.0);
|
avals.FillRandom(1.0);
|
||||||
for (int i = 0; i < spec.m; ++i) {
|
avals.Each([](absl::Span<const int64> dims, float* v) {
|
||||||
avals(i, i) += 30;
|
if (dims.back() == dims.at(dims.size() - 2)) {
|
||||||
}
|
*v += 30;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
std::pair<int, int> bdims = spec.left_side ? std::make_pair(spec.m, spec.n)
|
std::vector<int64> b_dims = spec.dims;
|
||||||
: std::make_pair(spec.n, spec.m);
|
if (!spec.left_side) {
|
||||||
Array2D<float> bvals(bdims.first, bdims.second);
|
std::swap(b_dims.back(), b_dims.at(b_dims.size() - 2));
|
||||||
|
}
|
||||||
|
Array<float> bvals(b_dims);
|
||||||
bvals.FillRandom(1.0);
|
bvals.FillRandom(1.0);
|
||||||
|
|
||||||
XlaOp a, b;
|
XlaOp a, b;
|
||||||
auto a_data = CreateR2Parameter<float>(avals, 0, "a", &builder, &a);
|
auto a_data = CreateParameter<float>(avals, 0, "a", &builder, &a);
|
||||||
auto b_data = CreateR2Parameter<float>(bvals, 1, "b", &builder, &b);
|
auto b_data = CreateParameter<float>(bvals, 1, "b", &builder, &b);
|
||||||
auto x = TriangularSolve(a, b, spec.left_side, spec.lower,
|
auto x = TriangularSolve(a, b, spec.left_side, spec.lower,
|
||||||
/*unit_diagonal=*/false, spec.transpose_a);
|
/*unit_diagonal=*/false, spec.transpose_a);
|
||||||
auto a_tri = Triangle(a, spec.lower);
|
auto a_tri = Triangle(a, spec.lower);
|
||||||
@ -480,20 +489,26 @@ XLA_TEST_P(TriangularSolveParametricTest, Random) {
|
|||||||
BatchDot(x, a_tri);
|
BatchDot(x, a_tri);
|
||||||
}
|
}
|
||||||
|
|
||||||
ComputeAndCompareR2<float>(&builder, bvals, {a_data.get(), b_data.get()},
|
ComputeAndCompare<float>(&builder, bvals, {a_data.get(), b_data.get()},
|
||||||
ErrorSpec(3e-2, 3e-2));
|
ErrorSpec(3e-2, 3e-2));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<TriangularSolveTestSpec> TriangularSolveTests() {
|
std::vector<TriangularSolveTestSpec> TriangularSolveTests() {
|
||||||
std::vector<TriangularSolveTestSpec> specs;
|
std::vector<TriangularSolveTestSpec> specs;
|
||||||
for (int m : {5, 10, 150}) {
|
for (auto batch :
|
||||||
for (int n : {5, 10, 150}) {
|
{std::initializer_list<int64>{}, std::initializer_list<int64>{5}}) {
|
||||||
for (bool left_side : {false, true}) {
|
for (int m : {5, 10, 150}) {
|
||||||
for (bool lower : {false, true}) {
|
for (int n : {5, 150}) {
|
||||||
for (TriangularSolveOptions::Transpose transpose_a :
|
for (bool left_side : {false, true}) {
|
||||||
{TriangularSolveOptions::NO_TRANSPOSE,
|
for (bool lower : {false, true}) {
|
||||||
TriangularSolveOptions::TRANSPOSE}) {
|
for (TriangularSolveOptions::Transpose transpose_a :
|
||||||
specs.push_back({m, n, left_side, lower, transpose_a});
|
{TriangularSolveOptions::NO_TRANSPOSE,
|
||||||
|
TriangularSolveOptions::TRANSPOSE}) {
|
||||||
|
std::vector<int64> dims(batch.begin(), batch.end());
|
||||||
|
dims.push_back(m);
|
||||||
|
dims.push_back(n);
|
||||||
|
specs.push_back({dims, left_side, lower, transpose_a});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -502,9 +517,18 @@ std::vector<TriangularSolveTestSpec> TriangularSolveTests() {
|
|||||||
return specs;
|
return specs;
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(TriangularSolveParametricTestInstantiation,
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
TriangularSolveParametricTest,
|
TriangularSolveParametricTestInstantiation, TriangularSolveParametricTest,
|
||||||
::testing::ValuesIn(TriangularSolveTests()));
|
::testing::ValuesIn(TriangularSolveTests()),
|
||||||
|
[](const ::testing::TestParamInfo<TriangularSolveTestSpec>& info) {
|
||||||
|
const TriangularSolveTestSpec& spec = info.param;
|
||||||
|
std::string name = absl::StrCat(
|
||||||
|
absl::StrJoin(spec.dims, "_"), "_", spec.left_side ? "left" : "right",
|
||||||
|
"_", spec.lower ? "lower" : "upper", "_",
|
||||||
|
absl::AsciiStrToLower(
|
||||||
|
TriangularSolveOptions_Transpose_Name(spec.transpose_a)));
|
||||||
|
return name;
|
||||||
|
});
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
Loading…
x
Reference in New Issue
Block a user