[XLA] Add support for Hermitian eigendecompositions in eigh expander pass.
PiperOrigin-RevId: 360417733
Change-Id: Ic5f2e8eb60f54742ee0808015e099b78ac00258b
parent acb619833a
commit 29ba662acd
@@ -57,10 +57,8 @@ class SelfAdjointEigV2Op : public XlaOpKernel {
   }
 };
 
-REGISTER_XLA_OP(Name("XlaSelfAdjointEig").TypeConstraint("T", kFloatTypes),
-                XlaSelfAdjointEigOp);
-REGISTER_XLA_OP(Name("SelfAdjointEigV2").TypeConstraint("T", kFloatTypes),
-                SelfAdjointEigV2Op);
+REGISTER_XLA_OP(Name("XlaSelfAdjointEig"), XlaSelfAdjointEigOp);
+REGISTER_XLA_OP(Name("SelfAdjointEigV2"), SelfAdjointEigV2Op);
 
 }  // namespace
 }  // namespace tensorflow
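Dropping the float-only TypeConstraint lets these kernels register for every supported dtype, including complex. An editor's sketch (not part of the commit) of what this enables at the TensorFlow level; note `jit_compile` is the modern spelling of what this era of TF called `experimental_compile`, and is assumed here to force the XLA path:

import tensorflow as tf

@tf.function(jit_compile=True)  # compile through XLA
def eigh(x):
    return tf.linalg.eigh(x)    # lowers to the SelfAdjointEigV2 op

a = tf.constant([[1, 2 - 7j], [2 + 7j, 3]], dtype=tf.complex64)
w, v = eigh(a)  # previously rejected: complex64 had no XLA kernel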
@@ -467,17 +467,15 @@ cc_library(
 xla_test(
     name = "self_adjoint_eig_test",
     srcs = ["self_adjoint_eig_test.cc"],
-    disabled_backends = [
-        "cpu",
-        "gpu",
-    ],
     real_hardware_only = True,
     tags = ["optonly"],
     deps = [
         ":arithmetic",
         ":constants",
+        ":math",
         ":matrix",
         ":self_adjoint_eig",
+        "//tensorflow/compiler/xla:array",
         "//tensorflow/compiler/xla:array2d",
         "//tensorflow/compiler/xla:array3d",
         "//tensorflow/compiler/xla:literal",
@@ -37,6 +37,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/lib/slicing.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -235,6 +236,36 @@ XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); }
 
 XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); }
 
+XlaOp Symmetrize(XlaOp x, bool lower) {
+  XlaBuilder* builder = x.builder();
+  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
+    if (shape.rank() < 2) {
+      return InvalidArgument(
+          "Argument to symmetrize must have >= 2 dimensions, got %s",
+          shape.ToString());
+    }
+    const int64 m = ShapeUtil::GetDimension(shape, -2);
+    const int64 n = ShapeUtil::GetDimension(shape, -1);
+    if (m != n) {
+      return InvalidArgument(
+          "The two most minor dimensions of the argument to symmetrize must be "
+          "equal size, got %s",
+          shape.ToString());
+    }
+    auto mask = lower ? TriangleMask(x, 0) : Not(TriangleMask(x, -1));
+    if (primitive_util::IsComplexType(shape.element_type())) {
+      auto re = Select(mask, Real(x), TransposeInMinorDims(Real(x)));
+      auto im_mask = lower ? TriangleMask(x, -1) : Not(TriangleMask(x, 0));
+      auto im = Select(im_mask, Imag(x), ZerosLike(Imag(x)));
+      im = Select(mask, im, -TransposeInMinorDims(im));
+      return Complex(re, im);
+    } else {
+      return Select(mask, x, TransposeInMinorDims(x));
+    }
+  });
+}
+
 namespace {
 
 absl::optional<std::array<std::vector<int64>, 3>> EinsumDiagonalLabels(
     absl::Span<const int64> config) {
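For reference, a minimal NumPy sketch of the semantics of the new Symmetrize (an editor's illustration, not code from the commit): the chosen triangle is mirrored onto the other half; for complex types the mirrored entries are conjugated and the diagonal keeps only its real part, so the result is exactly Hermitian.

import numpy as np

def symmetrize(x, lower=True):
    t = np.tril(x) if lower else np.triu(x)
    if not np.iscomplexobj(x):
        return t + t.T - np.diag(np.diag(t))
    strict = t - np.diag(np.diag(t))             # strict triangle
    return strict + strict.conj().T + np.diag(np.diag(t).real)

a = np.array([[1 + 5j, 0, 0], [2 + 7j, 3 - 1j, 0], [4 + 8j, 5 + 9j, 6 + 2j]])
h = symmetrize(a, lower=True)
assert np.allclose(h, h.conj().T)        # Hermitian by construction
assert np.allclose(np.diag(h).imag, 0)   # imaginary diagonal set to zero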
@@ -50,8 +50,8 @@ XlaOp GetMatrixDiagonalViaGather(XlaOp x, int k = 0);
 // Places diag along the kth diagonal of target.
 XlaOp SetMatrixDiagonal(XlaOp matrix, XlaOp diag, int k = 0);
 
-// Returns a lower-triangular mask, i.e., true below the `diagonal`-th diagonal
-// and false above that diagonal.
+// Returns a lower-triangular mask, i.e., true below and including the
+// `diagonal`-th diagonal and false above that diagonal.
 XlaOp TriangleMask(XlaOp x, int diagonal);
 
 // Get the upper or lower triangle part of the last two dimensions
@@ -63,6 +63,13 @@ XlaOp UpperTriangle(XlaOp x);
 // Get the lower triangle part of the last two dimensions
 XlaOp LowerTriangle(XlaOp x);
 
+// If x is an array of shape [..., n, n], symmetrizes the matrix by replacing
+// the upper triangle with the transpose of the lower triangle (if lower is
+// True, vice-versa otherwise). If the type of `x` is complex, makes the matrix
+// Hermitian by taking the conjugate of the complex part and setting the
+// complex diagonal to zero.
+XlaOp Symmetrize(XlaOp x, bool lower);
+
 // Multiplies slices of two tensors in batches.
 
 // Multiplies all slices of `Tensor` `x` and `y` (each slice can be
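The comment correction above is load-bearing for the new Symmetrize: TriangleMask(x, 0) is true on and below the main diagonal, while TriangleMask(x, -1) is true strictly below it, which is how the implementation keeps the real diagonal but zeroes the imaginary one. In NumPy terms (editor's illustration; np.tri is the analogous inclusive mask):

import numpy as np
print(np.tri(3, 3, k=0, dtype=bool))   # on and below the diagonal: TriangleMask(x, 0)
print(np.tri(3, 3, k=-1, dtype=bool))  # strictly below:            TriangleMask(x, -1)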
@@ -73,6 +73,54 @@ XLA_TEST_F(MatrixTest, Triangle) {
   ComputeAndCompareR3<int32>(&builder, expected, {a_data.get()});
 }
 
+XLA_TEST_F(MatrixTest, Symmetrize) {
+  for (bool lower : {false, true}) {
+    XlaBuilder builder(TestName());
+    float nan = std::numeric_limits<float>::quiet_NaN();
+    Array<float> input = {
+        {1, nan, nan},
+        {2, 3, nan},
+        {4, 5, 6},
+    };
+
+    XlaOp a;
+    auto a_data = CreateParameter<float>(input, 0, "a", &builder, &a);
+    Symmetrize(lower ? a : TransposeInMinorDims(a), /*lower=*/lower);
+
+    Array<float> expected = {
+        {1, 2, 4},
+        {2, 3, 5},
+        {4, 5, 6},
+    };
+
+    ComputeAndCompare<float>(&builder, expected, {a_data.get()});
+  }
+}
+
+XLA_TEST_F(MatrixTest, SymmetrizeComplex) {
+  for (bool lower : {false, true}) {
+    XlaBuilder builder(TestName());
+    float nan = std::numeric_limits<float>::quiet_NaN();
+    Array<complex64> input = {
+        {complex64{1, nan}, nan, nan},
+        {complex64{2, 7}, complex64{3, nan}, nan},
+        {complex64{4, 8}, complex64{5, 9}, complex64{6, nan}},
+    };
+
+    XlaOp a;
+    auto a_data = CreateParameter<complex64>(input, 0, "a", &builder, &a);
+    Symmetrize(lower ? a : Conj(TransposeInMinorDims(a)), /*lower=*/lower);
+
+    Array<complex64> expected = {
+        {1, complex64{2, -7}, complex64{4, -8}},
+        {complex64{2, 7}, 3, complex64{5, -9}},
+        {complex64{4, 8}, complex64{5, 9}, 6},
+    };
+
+    ComputeAndCompare<complex64>(&builder, expected, {a_data.get()});
+  }
+}
+
 template <typename T>
 void MatrixTest::TestMatrixDiagonal() {
   XlaBuilder builder("SetMatrixDiagonal");
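The NaNs in the masked triangle (and on the imaginary diagonal in the complex case) verify that Symmetrize never reads those entries. A standalone NumPy restatement of the SymmetrizeComplex expectation (editor's check, not commit code):

import numpy as np

lower_data = np.array([[1, 0, 0], [2 + 7j, 3, 0], [4 + 8j, 5 + 9j, 6]])
expected = np.array([[1, 2 - 7j, 4 - 8j],
                     [2 + 7j, 3, 5 - 9j],
                     [4 + 8j, 5 + 9j, 6]])
strict = np.tril(lower_data, -1)             # strict lower triangle
h = strict + strict.conj().T + np.diag(np.diag(lower_data).real)
assert np.allclose(h, expected)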
@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/lib/slicing.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -47,9 +48,12 @@ SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower, int64 max_iter,
         a_shape.ToString());
   }
   PrimitiveType type = a_shape.element_type();
-  if (!primitive_util::IsFloatingPointType(type)) {
-    return InvalidArgument("Type of the input matrix must be float: got %s.",
-                           a_shape.ToString());
+  if (!primitive_util::IsFloatingPointType(type) &&
+      !primitive_util::IsComplexType(type)) {
+    return InvalidArgument(
+        "Type of the input matrix must be floating point "
+        "or complex: got %s.",
+        a_shape.ToString());
   }
 
   const int64 m = ShapeUtil::GetDimension(a_shape, -2);
@@ -67,10 +71,14 @@ SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower, int64 max_iter,
                                     a_shape.dimensions().begin(),
                                     a_shape.dimensions().begin() + num_batch_dims);
 
+  PrimitiveType eigvals_type =
+      primitive_util::IsComplexType(type)
+          ? primitive_util::ComplexComponentType(type)
+          : type;
   std::vector<int64> eigvals_dims = batch_dims;
   eigvals_dims.push_back(m);
   Shape eigh_shape = ShapeUtil::MakeTupleShape(
-      {a_shape, ShapeUtil::MakeShape(type, eigvals_dims)});
+      {a_shape, ShapeUtil::MakeShape(eigvals_type, eigvals_dims)});
   // TODO(phawkins): upgrade Eigh decomposition to a first-class HLO operator.
   std::string opaque =
       absl::StrFormat("%d,%d,%f", lower ? 1 : 0, max_iter, tol);
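The new eigvals_type reflects the fact the rest of the change leans on: a Hermitian matrix has real eigenvalues, so for complex inputs the eigenvalue component of the result tuple uses the complex component type (complex64 yields f32). NumPy's eigh behaves the same way (editor's illustration):

import numpy as np
a = np.array([[1, 2 - 7j], [2 + 7j, 3]], dtype=np.complex64)
w, v = np.linalg.eigh(a)
print(w.dtype, v.dtype)  # float32 complex64: eigenvalues are real-typed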
@@ -15,10 +15,12 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
 
+#include "tensorflow/compiler/xla/array.h"
 #include "tensorflow/compiler/xla/array2d.h"
 #include "tensorflow/compiler/xla/array3d.h"
 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
 #include "tensorflow/compiler/xla/client/lib/matrix.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal.h"
@@ -121,8 +123,12 @@ XlaOp ComputeMatmulVWVt(SelfAdjointEigResult result, XlaBuilder* builder) {
   std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
 
   broadcast_dims[shape.rank() - 2] = shape.rank() - 1;
-  auto vw = Mul(result.v, BroadcastInDim(result.w, out_dims, broadcast_dims));
-  return BatchDot(vw, TransposeInMinorDims(result.v), PrecisionConfig::HIGHEST);
+  auto vw =
+      Mul(result.v,
+          BroadcastInDim(ConvertElementType(result.w, shape.element_type()),
+                         out_dims, broadcast_dims));
+  return BatchDot(vw, MaybeConjugate(TransposeInMinorDims(result.v), true),
+                  PrecisionConfig::HIGHEST);
 }
 
 XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_2x4x4) {
@@ -137,6 +143,22 @@ XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_2x4x4) {
                           ErrorSpec(1e-3, 1e-3));
 }
 
+XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_3x3_Complex) {
+  XlaBuilder builder(TestName());
+  Array<complex64> input = {
+      {1, complex64{2, -7}, complex64{4, -8}},
+      {complex64{2, 7}, 3, complex64{5, -9}},
+      {complex64{4, 8}, complex64{5, 9}, 6},
+  };
+  XlaOp a;
+  auto a_data = CreateParameter<complex64>(input, 0, "a", &builder, &a);
+  auto result = SelfAdjointEig(a);
+  ComputeMatmulVWVt(result, &builder);
+
+  ComputeAndCompare<complex64>(&builder, input, {a_data.get()},
+                               ErrorSpec(1e-3, 1e-3));
+}
+
 XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Lower_2x4x4) {
   XlaBuilder builder(TestName());
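ComputeMatmulVWVt now reconstructs A as V diag(w) V^H: ConvertElementType is needed because w is real-typed while v is complex, and MaybeConjugate supplies the conjugate transpose. An editor's NumPy analogue using the same matrix as the new test:

import numpy as np
a = np.array([[1, 2 - 7j, 4 - 8j],
              [2 + 7j, 3, 5 - 9j],
              [4 + 8j, 5 + 9j, 6]])
w, v = np.linalg.eigh(a)   # w is real, v is complex
vw = v * w                 # scale column j of V by w[j] (the Mul + BroadcastInDim)
assert np.allclose(vw @ v.conj().T, a, atol=1e-3)  # V W V^H == A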
@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/lib/slicing.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -83,7 +84,7 @@ XlaOp Hypot(XlaOp x, XlaOp y) {
 // elements correspond to different rows and columns of the original matrix and
 // their rotations do not interfere and hence can be computed in parallel.
 //
-// The algorithm is based on slaev2 from LAPACK, modified to allow for
+// The algorithm is based on slaev2/claev2 from LAPACK, modified to allow for
 // vectorization.
 // In addition, slaev2 always returns the largest eigenvalue as rt1, which has
 // the effect of swapping eigenvalues around in the Jacobi algorithm. This does
@@ -130,15 +131,28 @@ XlaOp Hypot(XlaOp x, XlaOp y) {
 //    cosine, sine = (np.where(same_sign, -sine, cosine),
 //                    np.where(same_sign, cosine, sine))
 //    return rt1, rt2, cosine, sine
-Eigh2x2 SymmetricEigenDecomposition2x2(XlaOp w_tl, XlaOp w_tr, XlaOp w_br) {
-  auto a = GetMatrixDiagonal(w_tl);
+StatusOr<Eigh2x2> HermitianEigenDecomposition2x2(XlaOp w_tl, XlaOp w_tr,
+                                                 XlaOp w_br) {
+  TF_ASSIGN_OR_RETURN(Shape w_tl_shape, w_tl.builder()->GetShape(w_tl));
+  bool is_complex = primitive_util::IsComplexType(w_tl_shape.element_type());
+
+  auto a = GetMatrixDiagonal(Real(w_tl));
   auto b = GetMatrixDiagonal(w_tr);
-  auto c = GetMatrixDiagonal(w_br);
-  auto zero = ScalarLike(w_tl, 0.0);
-  auto half = ScalarLike(w_tl, 0.5);
-  auto neg_half = ScalarLike(w_tl, -0.5);
-  auto one = ScalarLike(w_tl, 1.0);
-  auto two = ScalarLike(w_tl, 2.0);
+  auto abs_b = Abs(b);
+
+  XlaOp w;
+  if (is_complex) {
+    w = Select(Eq(abs_b, ZerosLike(abs_b)), FullLike(b, 1),
+               Conj(b) / Complex(abs_b, ZerosLike(abs_b)));
+    b = abs_b;
+  }
+
+  auto c = GetMatrixDiagonal(Real(w_br));
+  auto zero = ScalarLike(a, 0.0);
+  auto half = ScalarLike(a, 0.5);
+  auto neg_half = ScalarLike(a, -0.5);
+  auto one = ScalarLike(a, 1.0);
+  auto two = ScalarLike(a, 2.0);
 
   auto ac_sum = a + c;
   auto ac_diff = a - c;
@@ -174,7 +188,13 @@ Eigh2x2 SymmetricEigenDecomposition2x2(XlaOp w_tl, XlaOp w_tr, XlaOp w_br) {
 
   // Negate 'sine' because we are returning the first row of the rotation matrix
   // not the first eigenvector.
-  return {rt1, rt2, cosine, -sine};
+  if (is_complex) {
+    rt1 = Complex(rt1, ZerosLike(rt1));
+    rt2 = Complex(rt2, ZerosLike(rt2));
+    cosine = Complex(cosine, ZerosLike(cosine));
+    sine = Complex(sine, ZerosLike(sine)) * w;
+  }
+  return Eigh2x2{rt1, rt2, cosine, -sine};
 }
 
 // tl, tr, bl, br = (
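The core of the complex support is a claev2-style phase reduction: with a = Re(A11), c = Re(A22) and b = A12, the unit phase w = conj(b)/|b| turns the Hermitian 2x2 [[a, b], [conj(b), c]] into the real symmetric [[a, |b|], [|b|, c]] by a unitary similarity, so the existing real rotation formulas apply unchanged; multiplying sine by w afterwards restores the phase. An editor's NumPy check of that equivalence:

import numpy as np
rng = np.random.default_rng(0)
a, c = rng.normal(size=2)
b = complex(rng.normal(), rng.normal())
herm = np.array([[a, b], [np.conj(b), c]])
w = np.conj(b) / abs(b)                        # the code's phase factor `w`
real_sym = np.array([[a, abs(b)], [abs(b), c]])
# diag(1, w) is unitary, so both matrices share the same (real) eigenvalues.
assert np.allclose(np.linalg.eigvalsh(herm), np.linalg.eigvalsh(real_sym))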
@@ -191,8 +211,10 @@ void ApplyJacobiRotationOverRows(Eigh2x2 rotation, XlaOp& tl, XlaOp& tr,
   auto c = BroadcastInDim(rotation.c, shape.dimensions(), broadcast_dims);
   auto s = BroadcastInDim(rotation.s, shape.dimensions(), broadcast_dims);
 
-  std::tie(tl, tr, bl, br) = std::make_tuple(tl * c - bl * s, tr * c - br * s,
-                                             tl * s + bl * c, tr * s + br * c);
+  auto s_conj = MaybeConjugate(s, true);
+  std::tie(tl, tr, bl, br) =
+      std::make_tuple(tl * c - bl * s_conj, tr * c - br * s_conj,
+                      tl * s + bl * c, tr * s + br * c);
 }
 
 // tl, tr, bl, br = (
@@ -210,8 +232,10 @@ void ApplyJacobiRotationOverCols(Eigh2x2 rotation, XlaOp& tl, XlaOp& tr,
   auto c = BroadcastInDim(rotation.c, shape.dimensions(), broadcast_dims);
   auto s = BroadcastInDim(rotation.s, shape.dimensions(), broadcast_dims);
 
-  std::tie(tl, tr, bl, br) = std::make_tuple(tl * c - tr * s, tl * s + tr * c,
-                                             bl * c - br * s, bl * s + br * c);
+  auto s_conj = MaybeConjugate(s, true);
+  std::tie(tl, tr, bl, br) =
+      std::make_tuple(tl * c - tr * s, tl * s_conj + tr * c, bl * c - br * s,
+                      bl * s_conj + br * c);
 }
 
 // def permute_rows_in_col(top, bottom):
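Written as 2x2 block matrices, the two helpers together apply W to J^H W J with J = [[c, conj(s)], [-s, c]] (this layout is the editor's reading of the updated tuples, not spelled out in the commit). The conjugated copy of s is exactly what keeps J unitary once s is complex:

import numpy as np
rng = np.random.default_rng(1)
s = complex(rng.normal(), rng.normal())
s /= 2 * abs(s)                      # any |s| < 1 works; here |s| = 1/2
c = np.sqrt(1 - abs(s) ** 2)         # cosine stays real, c^2 + |s|^2 = 1
j = np.array([[c, np.conj(s)], [-s, c]])
assert np.allclose(j @ j.conj().T, np.eye(2))  # unitary only thanks to conj(s)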
@@ -270,9 +294,11 @@ void PermuteColumnsInRow(XlaOp& left, XlaOp& right) {
 // implicit way of computing a tournament for n players such that each player
 // plays every other player exactly once in n - 1 rounds. See the Brent/Luk
 // paper for more details.
-void ApplyRotations(int64 n, XlaOp& w_tl, XlaOp& w_tr, XlaOp& w_bl, XlaOp& w_br,
-                    XlaOp& v_tl, XlaOp& v_tr, XlaOp& v_bl, XlaOp& v_br) {
-  Eigh2x2 rotation = SymmetricEigenDecomposition2x2(w_tl, w_tr, w_br);
+Status ApplyRotations(int64 n, XlaOp& w_tl, XlaOp& w_tr, XlaOp& w_bl,
+                      XlaOp& w_br, XlaOp& v_tl, XlaOp& v_tr, XlaOp& v_bl,
+                      XlaOp& v_br) {
+  TF_ASSIGN_OR_RETURN(Eigh2x2 rotation,
+                      HermitianEigenDecomposition2x2(w_tl, w_tr, w_br));
 
   ApplyJacobiRotationOverRows(rotation, w_tl, w_tr, w_bl, w_br);
   ApplyJacobiRotationOverCols(rotation, w_tl, w_tr, w_bl, w_br);
@@ -292,6 +318,7 @@ void ApplyRotations(int64 n, XlaOp& w_tl, XlaOp& w_tr, XlaOp& w_bl, XlaOp& w_br,
   ApplyJacobiRotationOverRows(rotation, v_tl, v_tr, v_bl, v_br);
   PermuteRowsInColumn(v_tl, v_bl);
   PermuteRowsInColumn(v_tr, v_br);
+  return Status::OK();
 }
 
 struct FrobeniusNorms {
@@ -304,21 +331,28 @@ StatusOr<FrobeniusNorms> ComputeFrobeniusNorms(XlaOp w_tl, XlaOp w_tr,
   XlaBuilder* builder = w_tl.builder();
   TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(w_tl));
   const int64 num_dims = shape.rank();
+  auto square_norm = [](XlaOp x) -> XlaOp {
+    return Real(x * MaybeConjugate(x, true));
+  };
+  PrimitiveType norm_type =
+      primitive_util::IsComplexType(shape.element_type())
+          ? primitive_util::ComplexComponentType(shape.element_type())
+          : shape.element_type();
+  auto zero = ScalarLike(Real(w_tl), 0.0);
   auto frobenius_norm =
-      Sqrt(Reduce(Square(w_tl) + Square(w_tr) + Square(w_bl) + Square(w_br),
-                  ScalarLike(w_tl, 0.0),
-                  CreateScalarAddComputation(shape.element_type(), builder),
+      Sqrt(Reduce(square_norm(w_tl) + square_norm(w_tr) + square_norm(w_bl) +
+                      square_norm(w_br),
+                  zero, CreateScalarAddComputation(norm_type, builder),
                   {num_dims - 2, num_dims - 1}));
-  auto diag_square =
-      Reduce(Square(GetMatrixDiagonal(w_tl)) + Square(GetMatrixDiagonal(w_br)),
-             ScalarLike(w_tl, 0.0),
-             CreateScalarAddComputation(shape.element_type(), builder),
-             {num_dims - 2});
+  auto diag_square = Reduce(
+      Square(GetMatrixDiagonal(Real(w_tl))) +
+          Square(GetMatrixDiagonal(Real(w_br))),
+      zero, CreateScalarAddComputation(norm_type, builder), {num_dims - 2});
 
   FrobeniusNorms frobenius_norms;
 
   frobenius_norms.off_diagonal_norm =
-      Sqrt(Max(Square(frobenius_norm) - diag_square, ScalarLike(w_tl, 0.0)));
+      Sqrt(Max(Square(frobenius_norm) - diag_square, zero));
   frobenius_norms.total_norm = frobenius_norm;
 
   return frobenius_norms;
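square_norm computes Re(x * conj(x)) = |x|^2, so the Frobenius norm is accumulated in the real component type, and the off-diagonal norm still follows from ||W||_F^2 = ||diag W||^2 + ||offdiag W||^2 (the diagonal of a Hermitian matrix is real, hence the Real() on the diagonal terms). An editor's NumPy restatement:

import numpy as np
rng = np.random.default_rng(2)
x = rng.normal(size=(4, 4)) + 1j * rng.normal(size=(4, 4))
h = x + x.conj().T                                 # Hermitian, so diag(h) is real
total = np.sqrt(np.sum((h * np.conj(h)).real))     # square_norm summed, then Sqrt
diag_sq = np.sum(np.diag(h).real ** 2)
off_diag = np.sqrt(max(total ** 2 - diag_sq, 0.0)) # the Max(..., zero) guard
assert np.allclose(off_diag, np.linalg.norm(h - np.diag(np.diag(h))))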
@@ -360,7 +394,8 @@ StatusOr<std::vector<XlaOp>> Sweeps(absl::Span<const XlaOp> initial_values,
             std::make_tuple(values[0], values[1], values[2], values[3],
                             values[4], values[5], values[6], values[7],
                             values[8]);
-        ApplyRotations(n, w_tl, w_tr, w_bl, w_br, v_tl, v_tr, v_bl, v_br);
+        TF_RETURN_IF_ERROR(ApplyRotations(n, w_tl, w_tr, w_bl, w_br, v_tl,
+                                          v_tr, v_bl, v_br));
         return std::vector<XlaOp>{tol, w_tl, w_tr, w_bl, w_br,
                                   v_tl, v_tr, v_bl, v_br};
       },
@@ -376,9 +411,10 @@ StatusOr<std::vector<XlaOp>> Sweeps(absl::Span<const XlaOp> initial_values,
 
 StatusOr<std::pair<XlaOp, XlaOp>> SortByEigenvalues(XlaOp v, XlaOp w) {
   XlaBuilder* builder = v.builder();
-  TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(v));
-  const int64 num_dims = shape.rank();
-  auto dimensions = shape.dimensions();
+  TF_ASSIGN_OR_RETURN(Shape v_shape, builder->GetShape(v));
+  TF_ASSIGN_OR_RETURN(Shape w_shape, builder->GetShape(w));
+  const int64 num_dims = v_shape.rank();
+  auto dimensions = v_shape.dimensions();
 
   std::vector<int64> broadcast_dims(num_dims - 1);
   std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
@@ -388,7 +424,7 @@ StatusOr<std::pair<XlaOp, XlaOp>> SortByEigenvalues(XlaOp v, XlaOp w) {
   XlaOp sort_result =
       Sort({w, v},
            CreateScalarLtComputation(
-               {shape.element_type(), shape.element_type()}, builder),
+               {w_shape.element_type(), v_shape.element_type()}, builder),
           num_dims - 1);
   w = GetMatrixDiagonal(GetTupleElement(sort_result, 0));
   v = GetTupleElement(sort_result, 1);
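SortByEigenvalues is a key-value sort: the real eigenvalues are the keys and the eigenvector data is the payload, which is why the comparator now names two element types, w_shape's and v_shape's. A simplified NumPy equivalent (editor's illustration; the XLA code sorts along the last dimension with w broadcast against v):

import numpy as np
w = np.array([3.0, 1.0, 2.0])                          # real eigenvalues
v = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + 0j   # complex eigenvectors
order = np.argsort(w)          # key-value sort keyed on the eigenvalues
w, v = w[order], v[:, order]   # reorder eigenvectors to match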
@@ -461,9 +497,12 @@ XlaOp EighExpander::BuildEigh(XlaOp a, bool lower, int64 max_iter, float tol) {
         a_shape.ToString());
   }
   PrimitiveType type = a_shape.element_type();
-  if (!primitive_util::IsFloatingPointType(type)) {
-    return InvalidArgument("Type of the input matrix must be float: got %s.",
-                           a_shape.ToString());
+  if (!primitive_util::IsFloatingPointType(type) &&
+      !primitive_util::IsComplexType(type)) {
+    return InvalidArgument(
+        "Type of the input matrix must be floating point "
+        "or complex: got %s.",
+        a_shape.ToString());
   }
 
   const int64 m = ShapeUtil::GetDimension(a_shape, -2);
@@ -483,12 +522,10 @@ XlaOp EighExpander::BuildEigh(XlaOp a, bool lower, int64 max_iter, float tol) {
   }
 
   if (m <= 1) {
-    return Tuple(builder, {FullLike(a, 1), GetMatrixDiagonal(a)});
+    return Tuple(builder, {FullLike(a, 1), GetMatrixDiagonal(Real(a))});
   }
 
   auto eye = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims);
-  a = Triangle(a, lower);
-  a = a + TransposeInMinorDims(a) - a * eye;
+  a = Symmetrize(a, lower);
 
   const int64 k = CeilOfRatio(n, int64{2});
   // tl = A[:n // 2, :n // 2]
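The replaced two lines only work for real input: tril(a) + tril(a)^T - a*I produces a symmetric matrix, but for complex input the mirrored triangle must also be conjugated, which is what Symmetrize does. Editor's NumPy contrast:

import numpy as np
a = np.array([[1, 0], [2 + 7j, 3]])
t = np.tril(a)
old = t + t.T - t * np.eye(2)               # the removed triangle trick
assert not np.allclose(old, old.conj().T)   # symmetric, but not Hermitian
strict = np.tril(a, -1)
new = strict + strict.conj().T + np.diag(np.diag(a).real)
assert np.allclose(new, new.conj().T)       # Hermitian, as Symmetrize produces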
@@ -520,7 +557,7 @@ XlaOp EighExpander::BuildEigh(XlaOp a, bool lower, int64 max_iter, float tol) {
   TF_ASSIGN_OR_RETURN(auto output, Sweeps(
                                        {
                                            Zero(builder, S32),
-                                           ScalarLike(a, tol),
+                                           ScalarLike(Real(a), tol),
                                            tl,
                                            tr,
                                            bl,
@@ -532,22 +569,23 @@ XlaOp EighExpander::BuildEigh(XlaOp a, bool lower, int64 max_iter, float tol) {
                                        },
                                        k * 2, max_iter, S32, builder));
 
   std::tie(tl, tr, bl, br) =
       std::make_tuple(output[2], output[3], output[4], output[5]);
   std::tie(v_tl, v_tr, v_bl, v_br) =
       std::make_tuple(output[6], output[7], output[8], output[9]);
 
-  auto w = ConcatInDim(builder, {GetMatrixDiagonal(tl), GetMatrixDiagonal(br)},
-                       num_dims - 2);
+  auto w = ConcatInDim(
+      builder, {GetMatrixDiagonal(Real(tl)), GetMatrixDiagonal(Real(br))},
+      num_dims - 2);
   auto v = ConcatInDim(builder,
                        {ConcatInDim(builder, {v_tl, v_tr}, num_dims - 1),
                         ConcatInDim(builder, {v_bl, v_br}, num_dims - 1)},
                        num_dims - 2);
   if (n % 2) {
     w = SliceInMinorDims(w, {0}, {n});
     v = SliceInMinorDims(v, {0, 0}, {n, n});
   }
-  v = TransposeInMinorDims(v);
+  v = MaybeConjugate(TransposeInMinorDims(v), true);
 
   TF_ASSIGN_OR_RETURN(std::tie(v, w), SortByEigenvalues(v, w));
   return Tuple(builder, {v, w});
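The final change, MaybeConjugate(TransposeInMinorDims(v), true), assembles the eigenvector matrix with a conjugate transpose rather than a plain transpose: for complex input the eigenvector matrix is unitary under V^H, not V^T. Editor's NumPy check:

import numpy as np
a = np.array([[1, 2 - 7j], [2 + 7j, 3]])
w, v = np.linalg.eigh(a)
assert np.allclose(v.conj().T @ v, np.eye(2))  # V^H V = I (unitary)
assert not np.allclose(v.T @ v, np.eye(2))     # plain transpose is not enough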