[XLA] Fix bug in triangular solve expander.
For a batched matrix whose size was not evenly divisible by the block size, the HLO being produced was not shape-correct. Improves test case coverage to catch the problem. Will fix https://github.com/google/jax/issues/4773 when incorporated into a jaxlib. PiperOrigin-RevId: 340678796 Change-Id: Ifc230aa0bae9aec4902556c5e7829410c60c587f
This commit is contained in:
parent
de5e35d7fe
commit
7a050b85d8
@ -1898,6 +1898,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/client/lib:slicing",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/math.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/matrix.h"
|
||||
@ -43,6 +44,8 @@ XlaOp DiagonalBlocks(XlaOp a, int64 block_size) {
|
||||
int ndims = shape.rank();
|
||||
int64 n = ShapeUtil::GetDimension(shape, -1);
|
||||
int64 num_blocks = n / block_size;
|
||||
absl::Span<int64 const> batch_dims = absl::MakeConstSpan(
|
||||
shape.dimensions().begin(), shape.dimensions().begin() + (ndims - 2));
|
||||
|
||||
XlaOp diag_blocks;
|
||||
|
||||
@ -100,10 +103,10 @@ XlaOp DiagonalBlocks(XlaOp a, int64 block_size) {
|
||||
|
||||
auto eye =
|
||||
IdentityMatrix(builder, shape.element_type(), padding, padding);
|
||||
config = MakeNoPaddingConfig(ndims);
|
||||
config.mutable_dimensions(ndims - 2)->set_edge_padding_low(n %
|
||||
block_size);
|
||||
config = MakeNoPaddingConfig(2);
|
||||
config.mutable_dimensions(0)->set_edge_padding_low(n % block_size);
|
||||
eye = Pad(eye, Zero(builder, shape.element_type()), config);
|
||||
eye = Broadcast(eye, batch_dims);
|
||||
last_blocks = ConcatInDim(builder, {last_blocks, eye}, ndims - 1);
|
||||
|
||||
// Add a singleton dimension
|
||||
|
@ -2669,6 +2669,7 @@ xla_test(
|
||||
],
|
||||
deps = [
|
||||
":test_macros_header",
|
||||
"//tensorflow/compiler/xla:array",
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -2681,6 +2682,7 @@ xla_test(
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:test",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -226,7 +226,13 @@ class ClientLibraryTestBase : public ManifestCheckingTest {
|
||||
absl::Span<const Literal> arguments);
|
||||
void ComputeAndCompare(XlaBuilder* builder,
|
||||
absl::Span<const Literal> arguments, ErrorSpec error);
|
||||
|
||||
// Rank-agnostic comparison helper: `expected` is converted to a Literal and
// the check is delegated to ComputeAndCompareLiteral with `builder` and
// `arguments`.
template <typename NativeT>
|
||||
void ComputeAndCompare(XlaBuilder* builder, const Array<NativeT>& expected,
|
||||
absl::Span<GlobalData* const> arguments);
|
||||
// Same as above, but additionally forwards `error` to
// ComputeAndCompareLiteral. NativeT must be float, double, bfloat16, half,
// complex64 or complex128 (enforced by a static_assert in the definition).
template <typename NativeT>
|
||||
void ComputeAndCompare(XlaBuilder* builder, const Array<NativeT>& expected,
|
||||
absl::Span<GlobalData* const> arguments,
|
||||
ErrorSpec error);
|
||||
// Create scalar operations for use in reductions.
|
||||
XlaComputation CreateScalarRelu();
|
||||
XlaComputation CreateScalarMax();
|
||||
@ -387,6 +393,13 @@ class ClientLibraryTestBase : public ManifestCheckingTest {
|
||||
const Array4D<NativeT>& array_4d, int64 parameter_number,
|
||||
const string& name, XlaBuilder* builder, XlaOp* data_handle);
|
||||
|
||||
template <typename NativeT>
|
||||
std::unique_ptr<GlobalData> CreateParameter(const Array<NativeT>& array_4d,
|
||||
int64 parameter_number,
|
||||
const string& name,
|
||||
XlaBuilder* builder,
|
||||
XlaOp* data_handle);
|
||||
|
||||
// Getter and setter for the use_bfloat16 flag, which indicates whether to run
|
||||
// tests with all float-type input/output converted to bfloat16.
|
||||
bool use_bfloat16() const { return use_bfloat16_; }
|
||||
@ -563,6 +576,31 @@ void ClientLibraryTestBase::ComputeAndCompareR4(
|
||||
arguments, error);
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
void ClientLibraryTestBase::ComputeAndCompare(
|
||||
XlaBuilder* builder, const Array<NativeT>& expected,
|
||||
absl::Span<GlobalData* const> arguments) {
|
||||
Literal expected_literal = LiteralUtil::CreateFromArray<NativeT>(expected);
|
||||
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
|
||||
arguments);
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
void ClientLibraryTestBase::ComputeAndCompare(
|
||||
XlaBuilder* builder, const Array<NativeT>& expected,
|
||||
absl::Span<GlobalData* const> arguments, ErrorSpec error) {
|
||||
static_assert(std::is_same<NativeT, float>::value ||
|
||||
std::is_same<NativeT, double>::value ||
|
||||
std::is_same<NativeT, bfloat16>::value ||
|
||||
std::is_same<NativeT, half>::value ||
|
||||
std::is_same<NativeT, complex64>::value ||
|
||||
std::is_same<NativeT, complex128>::value,
|
||||
"Float or complex type required when specifying an ErrorSpec");
|
||||
Literal expected_literal = LiteralUtil::CreateFromArray<NativeT>(expected);
|
||||
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
|
||||
arguments, error);
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
|
||||
NativeT value, int64 parameter_number, const string& name,
|
||||
@ -633,6 +671,20 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR4Parameter(
|
||||
return data;
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateParameter(
|
||||
const Array<NativeT>& array, int64 parameter_number, const string& name,
|
||||
XlaBuilder* builder, XlaOp* data_handle) {
|
||||
Literal literal = LiteralUtil::CreateFromArray(array);
|
||||
if (use_bfloat16_ && literal.shape().element_type() == F32) {
|
||||
literal = LiteralUtil::ConvertF32ToBF16(literal);
|
||||
}
|
||||
std::unique_ptr<GlobalData> data =
|
||||
client_->TransferToServer(literal).ConsumeValueOrDie();
|
||||
*data_handle = Parameter(builder, parameter_number, literal.shape(), name);
|
||||
return data;
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
std::vector<NativeT> ClientLibraryTestBase::CreatePseudorandomR1(
|
||||
const int width, NativeT min_value, NativeT max_value, uint32 seed) {
|
||||
|
@ -17,6 +17,8 @@ limitations under the License.
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/ascii.h"
|
||||
#include "tensorflow/compiler/xla/array.h"
|
||||
#include "tensorflow/compiler/xla/array2d.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/math.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/matrix.h"
|
||||
@ -440,7 +442,7 @@ XLA_TEST_F(TriangularSolveTest, BatchedLeftUpper) {
|
||||
}
|
||||
|
||||
struct TriangularSolveTestSpec {
|
||||
int m, n; // A is mxm, B is mxn
|
||||
std::vector<int64> dims; // [..., m, n] A is mxm, B is mxn
|
||||
bool left_side;
|
||||
bool lower;
|
||||
TriangularSolveOptions::Transpose transpose_a;
|
||||
@ -455,20 +457,27 @@ XLA_TEST_P(TriangularSolveParametricTest, Random) {
|
||||
|
||||
XlaBuilder builder(TestName());
|
||||
|
||||
Array2D<float> avals(spec.m, spec.m);
|
||||
CHECK_GE(spec.dims.size(), 2);
|
||||
std::vector<int64> a_dims = spec.dims;
|
||||
a_dims.back() = a_dims.at(a_dims.size() - 2);
|
||||
Array<float> avals(a_dims);
|
||||
avals.FillRandom(1.0);
|
||||
for (int i = 0; i < spec.m; ++i) {
|
||||
avals(i, i) += 30;
|
||||
}
|
||||
avals.Each([](absl::Span<const int64> dims, float* v) {
|
||||
if (dims.back() == dims.at(dims.size() - 2)) {
|
||||
*v += 30;
|
||||
}
|
||||
});
|
||||
|
||||
std::pair<int, int> bdims = spec.left_side ? std::make_pair(spec.m, spec.n)
|
||||
: std::make_pair(spec.n, spec.m);
|
||||
Array2D<float> bvals(bdims.first, bdims.second);
|
||||
std::vector<int64> b_dims = spec.dims;
|
||||
if (!spec.left_side) {
|
||||
std::swap(b_dims.back(), b_dims.at(b_dims.size() - 2));
|
||||
}
|
||||
Array<float> bvals(b_dims);
|
||||
bvals.FillRandom(1.0);
|
||||
|
||||
XlaOp a, b;
|
||||
auto a_data = CreateR2Parameter<float>(avals, 0, "a", &builder, &a);
|
||||
auto b_data = CreateR2Parameter<float>(bvals, 1, "b", &builder, &b);
|
||||
auto a_data = CreateParameter<float>(avals, 0, "a", &builder, &a);
|
||||
auto b_data = CreateParameter<float>(bvals, 1, "b", &builder, &b);
|
||||
auto x = TriangularSolve(a, b, spec.left_side, spec.lower,
|
||||
/*unit_diagonal=*/false, spec.transpose_a);
|
||||
auto a_tri = Triangle(a, spec.lower);
|
||||
@ -480,20 +489,26 @@ XLA_TEST_P(TriangularSolveParametricTest, Random) {
|
||||
BatchDot(x, a_tri);
|
||||
}
|
||||
|
||||
ComputeAndCompareR2<float>(&builder, bvals, {a_data.get(), b_data.get()},
|
||||
ErrorSpec(3e-2, 3e-2));
|
||||
ComputeAndCompare<float>(&builder, bvals, {a_data.get(), b_data.get()},
|
||||
ErrorSpec(3e-2, 3e-2));
|
||||
}
|
||||
|
||||
std::vector<TriangularSolveTestSpec> TriangularSolveTests() {
|
||||
std::vector<TriangularSolveTestSpec> specs;
|
||||
for (int m : {5, 10, 150}) {
|
||||
for (int n : {5, 10, 150}) {
|
||||
for (bool left_side : {false, true}) {
|
||||
for (bool lower : {false, true}) {
|
||||
for (TriangularSolveOptions::Transpose transpose_a :
|
||||
{TriangularSolveOptions::NO_TRANSPOSE,
|
||||
TriangularSolveOptions::TRANSPOSE}) {
|
||||
specs.push_back({m, n, left_side, lower, transpose_a});
|
||||
for (auto batch :
|
||||
{std::initializer_list<int64>{}, std::initializer_list<int64>{5}}) {
|
||||
for (int m : {5, 10, 150}) {
|
||||
for (int n : {5, 150}) {
|
||||
for (bool left_side : {false, true}) {
|
||||
for (bool lower : {false, true}) {
|
||||
for (TriangularSolveOptions::Transpose transpose_a :
|
||||
{TriangularSolveOptions::NO_TRANSPOSE,
|
||||
TriangularSolveOptions::TRANSPOSE}) {
|
||||
std::vector<int64> dims(batch.begin(), batch.end());
|
||||
dims.push_back(m);
|
||||
dims.push_back(n);
|
||||
specs.push_back({dims, left_side, lower, transpose_a});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -502,9 +517,18 @@ std::vector<TriangularSolveTestSpec> TriangularSolveTests() {
|
||||
return specs;
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TriangularSolveParametricTestInstantiation,
|
||||
TriangularSolveParametricTest,
|
||||
::testing::ValuesIn(TriangularSolveTests()));
|
||||
// Instantiates the parametric test over every generated spec. Each test is
// named "<dims joined by '_'>_<left|right>_<lower|upper>_<transpose>", where
// <transpose> is the lower-cased enum name of spec.transpose_a.
INSTANTIATE_TEST_SUITE_P(
    TriangularSolveParametricTestInstantiation, TriangularSolveParametricTest,
    ::testing::ValuesIn(TriangularSolveTests()),
    [](const ::testing::TestParamInfo<TriangularSolveTestSpec>& info) {
      const TriangularSolveTestSpec& spec = info.param;
      const char* side = spec.left_side ? "left" : "right";
      const char* triangle = spec.lower ? "lower" : "upper";
      return absl::StrCat(
          absl::StrJoin(spec.dims, "_"), "_", side, "_", triangle, "_",
          absl::AsciiStrToLower(
              TriangularSolveOptions_Transpose_Name(spec.transpose_a)));
    });
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
x
Reference in New Issue
Block a user