[XLA] Add support for Hermitian eigendecompositions in eigh expander pass.

PiperOrigin-RevId: 360417733
Change-Id: Ic5f2e8eb60f54742ee0808015e099b78ac00258b
Peter Hawkins 2021-03-02 06:51:52 -08:00 committed by TensorFlower Gardener
parent acb619833a
commit 29ba662acd
8 changed files with 220 additions and 70 deletions


@@ -57,10 +57,8 @@ class SelfAdjointEigV2Op : public XlaOpKernel {
   }
 };
 
-REGISTER_XLA_OP(Name("XlaSelfAdjointEig").TypeConstraint("T", kFloatTypes),
-                XlaSelfAdjointEigOp);
-REGISTER_XLA_OP(Name("SelfAdjointEigV2").TypeConstraint("T", kFloatTypes),
-                SelfAdjointEigV2Op);
+REGISTER_XLA_OP(Name("XlaSelfAdjointEig"), XlaSelfAdjointEigOp);
+REGISTER_XLA_OP(Name("SelfAdjointEigV2"), SelfAdjointEigV2Op);
 
 }  // namespace
 }  // namespace tensorflow


@@ -467,17 +467,15 @@ cc_library(
 xla_test(
     name = "self_adjoint_eig_test",
     srcs = ["self_adjoint_eig_test.cc"],
-    disabled_backends = [
-        "cpu",
-        "gpu",
-    ],
     real_hardware_only = True,
     tags = ["optonly"],
     deps = [
         ":arithmetic",
         ":constants",
+        ":math",
         ":matrix",
         ":self_adjoint_eig",
+        "//tensorflow/compiler/xla:array",
         "//tensorflow/compiler/xla:array2d",
         "//tensorflow/compiler/xla:array3d",
         "//tensorflow/compiler/xla:literal",


@@ -37,6 +37,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/lib/slicing.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -235,6 +236,36 @@ XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); }
 
 XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); }
 
+XlaOp Symmetrize(XlaOp x, bool lower) {
+  XlaBuilder* builder = x.builder();
+  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
+    if (shape.rank() < 2) {
+      return InvalidArgument(
+          "Argument to symmetrize must have >= 2 dimensions, got %s",
+          shape.ToString());
+    }
+    const int64 m = ShapeUtil::GetDimension(shape, -2);
+    const int64 n = ShapeUtil::GetDimension(shape, -1);
+    if (m != n) {
+      return InvalidArgument(
+          "The two most minor dimensions of the argument to symmetrize must be "
+          "equal size, got %s",
+          shape.ToString());
+    }
+    auto mask = lower ? TriangleMask(x, 0) : Not(TriangleMask(x, -1));
+    if (primitive_util::IsComplexType(shape.element_type())) {
+      auto re = Select(mask, Real(x), TransposeInMinorDims(Real(x)));
+      auto im_mask = lower ? TriangleMask(x, -1) : Not(TriangleMask(x, 0));
+      auto im = Select(im_mask, Imag(x), ZerosLike(Imag(x)));
+      im = Select(mask, im, -TransposeInMinorDims(im));
+      return Complex(re, im);
+    } else {
+      return Select(mask, x, TransposeInMinorDims(x));
+    }
+  });
+}
+
 namespace {
 
 absl::optional<std::array<std::vector<int64>, 3>> EinsumDiagonalLabels(
     absl::Span<const int64> config) {
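
For reference, a minimal standalone sketch of the rule Symmetrize implements, written with std::complex on a dense n x n matrix rather than XLA ops (the SymmetrizeDense helper is hypothetical, not part of this commit): one triangle is trusted, the other is overwritten with the conjugate transpose, and the diagonal keeps only its real part.

#include <complex>
#include <vector>

using complex64 = std::complex<float>;
using Matrix = std::vector<std::vector<complex64>>;

// Overwrites the untrusted triangle of `x` so the result is Hermitian.
Matrix SymmetrizeDense(Matrix x, bool lower) {
  const int n = static_cast<int>(x.size());
  for (int i = 0; i < n; ++i) {
    // A Hermitian matrix has a real diagonal; drop any imaginary part.
    x[i][i] = complex64(x[i][i].real(), 0.0f);
    for (int j = i + 1; j < n; ++j) {
      if (lower) {
        x[i][j] = std::conj(x[j][i]);  // upper := conj(transpose(lower))
      } else {
        x[j][i] = std::conj(x[i][j]);  // lower := conj(transpose(upper))
      }
    }
  }
  return x;
}

For real element types the conjugation is a no-op and this degenerates to ordinary symmetrization, matching the Select-based else branch above.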


@@ -50,8 +50,8 @@ XlaOp GetMatrixDiagonalViaGather(XlaOp x, int k = 0);
 // Places diag along the kth diagonal of target.
 XlaOp SetMatrixDiagonal(XlaOp matrix, XlaOp diag, int k = 0);
 
-// Returns a lower-triangular mask, i.e., true below the `diagonal`-th diagonal
-// and false above that diagonal.
+// Returns a lower-triangular mask, i.e., true below and including the
+// `diagonal`-th diagonal and false above that diagonal.
 XlaOp TriangleMask(XlaOp x, int diagonal);
 
 // Get the upper or lower triangle part of the last two dimensions
@@ -63,6 +63,13 @@ XlaOp UpperTriangle(XlaOp x);
 // Get the lower triangle part of the last two dimensions
 XlaOp LowerTriangle(XlaOp x);
 
+// If x is an array of shape [..., n, n], symmetrizes the matrix by replacing
+// the upper triangle with the transpose of the lower triangle (if lower is
+// true, vice-versa otherwise). If the type of `x` is complex, makes the
+// matrix Hermitian by taking the complex conjugate of the mirrored triangle
+// and setting the imaginary part of the diagonal to zero.
+XlaOp Symmetrize(XlaOp x, bool lower);
+
 // Multiplies slices of two tensors in batches.
 // Multiplies all slices of `Tensor` `x` and `y` (each slice can be


@@ -73,6 +73,54 @@ XLA_TEST_F(MatrixTest, Triangle) {
   ComputeAndCompareR3<int32>(&builder, expected, {a_data.get()});
 }
 
+XLA_TEST_F(MatrixTest, Symmetrize) {
+  for (bool lower : {false, true}) {
+    XlaBuilder builder(TestName());
+    float nan = std::numeric_limits<float>::quiet_NaN();
+    Array<float> input = {
+        {1, nan, nan},
+        {2, 3, nan},
+        {4, 5, 6},
+    };
+    XlaOp a;
+    auto a_data = CreateParameter<float>(input, 0, "a", &builder, &a);
+    Symmetrize(lower ? a : TransposeInMinorDims(a), /*lower=*/lower);
+
+    Array<float> expected = {
+        {1, 2, 4},
+        {2, 3, 5},
+        {4, 5, 6},
+    };
+
+    ComputeAndCompare<float>(&builder, expected, {a_data.get()});
+  }
+}
+
+XLA_TEST_F(MatrixTest, SymmetrizeComplex) {
+  for (bool lower : {false, true}) {
+    XlaBuilder builder(TestName());
+    float nan = std::numeric_limits<float>::quiet_NaN();
+    Array<complex64> input = {
+        {complex64{1, nan}, nan, nan},
+        {complex64{2, 7}, complex64{3, nan}, nan},
+        {complex64{4, 8}, complex64{5, 9}, complex64{6, nan}},
+    };
+    XlaOp a;
+    auto a_data = CreateParameter<complex64>(input, 0, "a", &builder, &a);
+    Symmetrize(lower ? a : Conj(TransposeInMinorDims(a)), /*lower=*/lower);
+
+    Array<complex64> expected = {
+        {1, complex64{2, -7}, complex64{4, -8}},
+        {complex64{2, 7}, 3, complex64{5, -9}},
+        {complex64{4, 8}, complex64{5, 9}, 6},
+    };
+
+    ComputeAndCompare<complex64>(&builder, expected, {a_data.get()});
+  }
+}
+
 template <typename T>
 void MatrixTest::TestMatrixDiagonal() {
   XlaBuilder builder("SetMatrixDiagonal");


@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/lib/slicing.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -47,9 +48,12 @@ SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower, int64 max_iter,
                              a_shape.ToString());
     }
     PrimitiveType type = a_shape.element_type();
-    if (!primitive_util::IsFloatingPointType(type)) {
-      return InvalidArgument("Type of the input matrix must be float: got %s.",
-                             a_shape.ToString());
+    if (!primitive_util::IsFloatingPointType(type) &&
+        !primitive_util::IsComplexType(type)) {
+      return InvalidArgument(
+          "Type of the input matrix must be floating point "
+          "or complex: got %s.",
+          a_shape.ToString());
     }
 
     const int64 m = ShapeUtil::GetDimension(a_shape, -2);
@@ -67,10 +71,14 @@ SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower, int64 max_iter,
                                   a_shape.dimensions().begin(),
                                   a_shape.dimensions().begin() + num_batch_dims);
 
+    PrimitiveType eigvals_type =
+        primitive_util::IsComplexType(type)
+            ? primitive_util::ComplexComponentType(type)
+            : type;
     std::vector<int64> eigvals_dims = batch_dims;
     eigvals_dims.push_back(m);
     Shape eigh_shape = ShapeUtil::MakeTupleShape(
-        {a_shape, ShapeUtil::MakeShape(type, eigvals_dims)});
+        {a_shape, ShapeUtil::MakeShape(eigvals_type, eigvals_dims)});
     // TODO(phawkins): upgrade Eigh decomposition to a first-class HLO operator.
     std::string opaque =
         absl::StrFormat("%d,%d,%f", lower ? 1 : 0, max_iter, tol);
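
The eigvals_type logic above encodes the fact that a Hermitian matrix has real eigenvalues: `w` gets the component type of a complex input (C64 -> F32, C128 -> F64) while `v` keeps the input type. The same mapping as a compile-time sketch, not part of the diff:

#include <complex>
#include <type_traits>

template <typename T>
struct EigvalsType {
  using type = T;  // real input: eigenvalues keep the input type
};

template <typename T>
struct EigvalsType<std::complex<T>> {
  using type = T;  // complex input: eigenvalues use the component type
};

static_assert(std::is_same_v<EigvalsType<float>::type, float>);
static_assert(std::is_same_v<EigvalsType<std::complex<float>>::type, float>);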


@@ -15,10 +15,12 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
 
+#include "tensorflow/compiler/xla/array.h"
 #include "tensorflow/compiler/xla/array2d.h"
 #include "tensorflow/compiler/xla/array3d.h"
 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
 #include "tensorflow/compiler/xla/client/lib/matrix.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal.h"
@@ -121,8 +123,12 @@ XlaOp ComputeMatmulVWVt(SelfAdjointEigResult result, XlaBuilder* builder) {
   std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
   broadcast_dims[shape.rank() - 2] = shape.rank() - 1;
-  auto vw = Mul(result.v, BroadcastInDim(result.w, out_dims, broadcast_dims));
-  return BatchDot(vw, TransposeInMinorDims(result.v), PrecisionConfig::HIGHEST);
+  auto vw =
+      Mul(result.v,
+          BroadcastInDim(ConvertElementType(result.w, shape.element_type()),
+                         out_dims, broadcast_dims));
+  return BatchDot(vw, MaybeConjugate(TransposeInMinorDims(result.v), true),
+                  PrecisionConfig::HIGHEST);
 }
 
 XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_2x4x4) {
@@ -137,6 +143,22 @@ XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_2x4x4) {
                            ErrorSpec(1e-3, 1e-3));
 }
 
+XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_3x3_Complex) {
+  XlaBuilder builder(TestName());
+  Array<complex64> input = {
+      {1, complex64{2, -7}, complex64{4, -8}},
+      {complex64{2, 7}, 3, complex64{5, -9}},
+      {complex64{4, 8}, complex64{5, 9}, 6},
+  };
+  XlaOp a;
+  auto a_data = CreateParameter<complex64>(input, 0, "a", &builder, &a);
+  auto result = SelfAdjointEig(a);
+  ComputeMatmulVWVt(result, &builder);
+
+  ComputeAndCompare<complex64>(&builder, input, {a_data.get()},
+                               ErrorSpec(1e-3, 1e-3));
+}
+
 XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Lower_2x4x4) {
   XlaBuilder builder(TestName());
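
In math form, the property this round-trip test exercises (a sketch of the standard identity, not text from the commit): a Hermitian A factors as

$$ A = V \Lambda V^{H}, \qquad \Lambda = \operatorname{diag}(w), \quad w \in \mathbb{R}^{n}, \quad V^{H} V = I, $$

which is why ComputeMatmulVWVt above converts the real `w` back to the matrix element type and multiplies by the conjugate transpose of V rather than the plain transpose.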


@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/lib/slicing.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -83,7 +84,7 @@ XlaOp Hypot(XlaOp x, XlaOp y) {
 // elements correspond to different rows and columns of the original matrix and
 // their rotations do not interfere and hence can be computed in parallel.
 //
-// The algorithm is based on slaev2 from LAPACK, modified to allow for
+// The algorithm is based on slaev2/claev2 from LAPACK, modified to allow for
 // vectorization.
 // In addition, slaev2 always returns the largest eigenvalue as rt1, which has
 // the effect of swapping eigenvalues around in the Jacobi algorithm. This does
@@ -130,15 +131,28 @@ XlaOp Hypot(XlaOp x, XlaOp y) {
 //   cosine, sine = (np.where(same_sign, -sine, cosine),
 //                   np.where(same_sign, cosine, sine))
 //   return rt1, rt2, cosine, sine
-Eigh2x2 SymmetricEigenDecomposition2x2(XlaOp w_tl, XlaOp w_tr, XlaOp w_br) {
-  auto a = GetMatrixDiagonal(w_tl);
+StatusOr<Eigh2x2> HermitianEigenDecomposition2x2(XlaOp w_tl, XlaOp w_tr,
+                                                 XlaOp w_br) {
+  TF_ASSIGN_OR_RETURN(Shape w_tl_shape, w_tl.builder()->GetShape(w_tl));
+  bool is_complex = primitive_util::IsComplexType(w_tl_shape.element_type());
+  auto a = GetMatrixDiagonal(Real(w_tl));
   auto b = GetMatrixDiagonal(w_tr);
-  auto c = GetMatrixDiagonal(w_br);
-  auto zero = ScalarLike(w_tl, 0.0);
-  auto half = ScalarLike(w_tl, 0.5);
-  auto neg_half = ScalarLike(w_tl, -0.5);
-  auto one = ScalarLike(w_tl, 1.0);
-  auto two = ScalarLike(w_tl, 2.0);
+  auto abs_b = Abs(b);
+  XlaOp w;
+  if (is_complex) {
+    w = Select(Eq(abs_b, ZerosLike(abs_b)), FullLike(b, 1),
+               Conj(b) / Complex(abs_b, ZerosLike(abs_b)));
+    b = abs_b;
+  }
+  auto c = GetMatrixDiagonal(Real(w_br));
+  auto zero = ScalarLike(a, 0.0);
+  auto half = ScalarLike(a, 0.5);
+  auto neg_half = ScalarLike(a, -0.5);
+  auto one = ScalarLike(a, 1.0);
+  auto two = ScalarLike(a, 2.0);
 
   auto ac_sum = a + c;
   auto ac_diff = a - c;
@@ -174,7 +188,13 @@ Eigh2x2 SymmetricEigenDecomposition2x2(XlaOp w_tl, XlaOp w_tr, XlaOp w_br) {
 
   // Negate 'sine' because we are returning the first row of the rotation
   // matrix, not the first eigenvector.
-  return {rt1, rt2, cosine, -sine};
+  if (is_complex) {
+    rt1 = Complex(rt1, ZerosLike(rt1));
+    rt2 = Complex(rt2, ZerosLike(rt2));
+    cosine = Complex(cosine, ZerosLike(cosine));
+    sine = Complex(sine, ZerosLike(sine)) * w;
+  }
+  return Eigh2x2{rt1, rt2, cosine, -sine};
 }
 
 // tl, tr, bl, br = (
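
As a cross-check of the phase trick above, a standalone sketch of a single 2x2 Hermitian eigensolve: the unit phase w = conj(b)/|b| reduces [[a, b], [conj(b), c]] (a, c real) to the real matrix [[a, |b|], [|b|, c]], a slaev2-style real solve runs there, and only the sine picks the phase back up. The angle formula below is the textbook Jacobi choice, an assumption rather than a line-by-line transcription of the vectorized code.

#include <cmath>
#include <complex>

struct Rotation2x2 {
  float rt1, rt2;            // eigenvalues
  float cosine;              // cosine stays real
  std::complex<float> sine;  // sine carries the phase w for complex inputs
};

Rotation2x2 HermitianEig2x2(float a, std::complex<float> b, float c) {
  const float abs_b = std::abs(b);
  // Phase that makes the off-diagonal entry real; 1 if it is already zero.
  const std::complex<float> w =
      abs_b == 0.0f ? std::complex<float>(1.0f, 0.0f) : std::conj(b) / abs_b;
  // Real symmetric solve on [[a, abs_b], [abs_b, c]].
  const float rt = std::hypot(a - c, 2.0f * abs_b);
  const float rt1 = 0.5f * (a + c + rt);
  const float rt2 = 0.5f * (a + c - rt);
  // tan(2*theta) = 2*|b| / (a - c).
  const float theta = 0.5f * std::atan2(2.0f * abs_b, a - c);
  // Negate the sine, matching the first-row convention noted above.
  return {rt1, rt2, std::cos(theta), -std::sin(theta) * w};
}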
@@ -191,8 +211,10 @@ void ApplyJacobiRotationOverRows(Eigh2x2 rotation, XlaOp& tl, XlaOp& tr,
   auto c = BroadcastInDim(rotation.c, shape.dimensions(), broadcast_dims);
   auto s = BroadcastInDim(rotation.s, shape.dimensions(), broadcast_dims);
-  std::tie(tl, tr, bl, br) = std::make_tuple(tl * c - bl * s, tr * c - br * s,
-                                             tl * s + bl * c, tr * s + br * c);
+  auto s_conj = MaybeConjugate(s, true);
+  std::tie(tl, tr, bl, br) =
+      std::make_tuple(tl * c - bl * s_conj, tr * c - br * s_conj,
+                      tl * s + bl * c, tr * s + br * c);
 }
 
 // tl, tr, bl, br = (
@@ -210,8 +232,10 @@ void ApplyJacobiRotationOverCols(Eigh2x2 rotation, XlaOp& tl, XlaOp& tr,
   auto c = BroadcastInDim(rotation.c, shape.dimensions(), broadcast_dims);
   auto s = BroadcastInDim(rotation.s, shape.dimensions(), broadcast_dims);
-  std::tie(tl, tr, bl, br) = std::make_tuple(tl * c - tr * s, tl * s + tr * c,
-                                             bl * c - br * s, bl * s + br * c);
+  auto s_conj = MaybeConjugate(s, true);
+  std::tie(tl, tr, bl, br) =
+      std::make_tuple(tl * c - tr * s, tl * s_conj + tr * c, bl * c - br * s,
+                      bl * s_conj + br * c);
 }
 
 // def permute_rows_in_col(top, bottom):
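
Why the sine is conjugated in one position of each update (a sketch of the algebra, not text from the commit): with c real and c^2 + |s|^2 = 1, the complex Givens factor

$$ G = \begin{pmatrix} c & -\bar{s} \\ s & c \end{pmatrix}, \qquad G^{H} G = I, \qquad A \mapsto G\,A\,G^{H}, $$

is unitary, and the row update above is left-multiplication by G (which uses conj(s)) while the column update is right-multiplication by G^H (where conj(s) lands in the opposite positions). For real types MaybeConjugate is the identity and the previous formulas are recovered.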
@@ -270,9 +294,11 @@ void PermuteColumnsInRow(XlaOp& left, XlaOp& right) {
 // implicit way of computing a tournament for n players such that each player
 // plays every other player exactly once in n - 1 rounds. See the Brent/Luk
 // paper for more details.
-void ApplyRotations(int64 n, XlaOp& w_tl, XlaOp& w_tr, XlaOp& w_bl, XlaOp& w_br,
-                    XlaOp& v_tl, XlaOp& v_tr, XlaOp& v_bl, XlaOp& v_br) {
-  Eigh2x2 rotation = SymmetricEigenDecomposition2x2(w_tl, w_tr, w_br);
+Status ApplyRotations(int64 n, XlaOp& w_tl, XlaOp& w_tr, XlaOp& w_bl,
+                      XlaOp& w_br, XlaOp& v_tl, XlaOp& v_tr, XlaOp& v_bl,
+                      XlaOp& v_br) {
+  TF_ASSIGN_OR_RETURN(Eigh2x2 rotation,
+                      HermitianEigenDecomposition2x2(w_tl, w_tr, w_br));
 
   ApplyJacobiRotationOverRows(rotation, w_tl, w_tr, w_bl, w_br);
   ApplyJacobiRotationOverCols(rotation, w_tl, w_tr, w_bl, w_br);
@@ -292,6 +318,7 @@ void ApplyRotations(int64 n, XlaOp& w_tl, XlaOp& w_tr, XlaOp& w_bl, XlaOp& w_br,
   ApplyJacobiRotationOverRows(rotation, v_tl, v_tr, v_bl, v_br);
   PermuteRowsInColumn(v_tl, v_bl);
   PermuteRowsInColumn(v_tr, v_br);
+  return Status::OK();
 }
 
 struct FrobeniusNorms {
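
The "implicit tournament" mentioned in the comment above can be sketched directly (an illustration of the Brent/Luk scheduling idea with a hypothetical NextRound helper, not code from the pass): fix one player and rotate the rest, so every pair meets exactly once across n - 1 rounds.

#include <algorithm>
#include <vector>

// One tournament round: entry i is paired with entry n - 1 - i. Rotating all
// entries except the first yields the next round's pairing.
void NextRound(std::vector<int>& players) {
  std::rotate(players.begin() + 1, players.begin() + 2, players.end());
}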
@@ -304,21 +331,28 @@ StatusOr<FrobeniusNorms> ComputeFrobeniusNorms(XlaOp w_tl, XlaOp w_tr,
   XlaBuilder* builder = w_tl.builder();
   TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(w_tl));
   const int64 num_dims = shape.rank();
+  auto square_norm = [](XlaOp x) -> XlaOp {
+    return Real(x * MaybeConjugate(x, true));
+  };
+  PrimitiveType norm_type =
+      primitive_util::IsComplexType(shape.element_type())
+          ? primitive_util::ComplexComponentType(shape.element_type())
+          : shape.element_type();
+  auto zero = ScalarLike(Real(w_tl), 0.0);
   auto frobenius_norm =
-      Sqrt(Reduce(Square(w_tl) + Square(w_tr) + Square(w_bl) + Square(w_br),
-                  ScalarLike(w_tl, 0.0),
-                  CreateScalarAddComputation(shape.element_type(), builder),
+      Sqrt(Reduce(square_norm(w_tl) + square_norm(w_tr) + square_norm(w_bl) +
+                      square_norm(w_br),
+                  zero, CreateScalarAddComputation(norm_type, builder),
                   {num_dims - 2, num_dims - 1}));
-  auto diag_square =
-      Reduce(Square(GetMatrixDiagonal(w_tl)) + Square(GetMatrixDiagonal(w_br)),
-             ScalarLike(w_tl, 0.0),
-             CreateScalarAddComputation(shape.element_type(), builder),
-             {num_dims - 2});
+  auto diag_square = Reduce(
+      Square(GetMatrixDiagonal(Real(w_tl))) +
+          Square(GetMatrixDiagonal(Real(w_br))),
+      zero, CreateScalarAddComputation(norm_type, builder), {num_dims - 2});
 
   FrobeniusNorms frobenius_norms;
   frobenius_norms.off_diagonal_norm =
-      Sqrt(Max(Square(frobenius_norm) - diag_square, ScalarLike(w_tl, 0.0)));
+      Sqrt(Max(Square(frobenius_norm) - diag_square, zero));
   frobenius_norms.total_norm = frobenius_norm;
 
   return frobenius_norms;
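
The square_norm lambda relies on x * conj(x) being real and equal to |x|^2, which makes one reduction serve both real and complex inputs. The same identity on scalars, as a runnable sketch (std::norm computes exactly this quantity):

#include <complex>
#include <vector>

float SquaredFrobeniusNorm(const std::vector<std::complex<float>>& entries) {
  float sum = 0.0f;
  for (const auto& x : entries) {
    sum += (x * std::conj(x)).real();  // == std::norm(x) == |x|^2
  }
  return sum;
}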
@@ -360,7 +394,8 @@ StatusOr<std::vector<XlaOp>> Sweeps(absl::Span<const XlaOp> initial_values,
             std::make_tuple(values[0], values[1], values[2], values[3],
                             values[4], values[5], values[6], values[7],
                             values[8]);
-        ApplyRotations(n, w_tl, w_tr, w_bl, w_br, v_tl, v_tr, v_bl, v_br);
+        TF_RETURN_IF_ERROR(ApplyRotations(n, w_tl, w_tr, w_bl, w_br, v_tl,
+                                          v_tr, v_bl, v_br));
         return std::vector<XlaOp>{tol, w_tl, w_tr, w_bl, w_br,
                                   v_tl, v_tr, v_bl, v_br};
       },
@@ -376,9 +411,10 @@ StatusOr<std::pair<XlaOp, XlaOp>> SortByEigenvalues(XlaOp v, XlaOp w) {
 StatusOr<std::pair<XlaOp, XlaOp>> SortByEigenvalues(XlaOp v, XlaOp w) {
   XlaBuilder* builder = v.builder();
-  TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(v));
-  const int64 num_dims = shape.rank();
-  auto dimensions = shape.dimensions();
+  TF_ASSIGN_OR_RETURN(Shape v_shape, builder->GetShape(v));
+  TF_ASSIGN_OR_RETURN(Shape w_shape, builder->GetShape(w));
+  const int64 num_dims = v_shape.rank();
+  auto dimensions = v_shape.dimensions();
 
   std::vector<int64> broadcast_dims(num_dims - 1);
   std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
@@ -388,7 +424,7 @@ StatusOr<std::pair<XlaOp, XlaOp>> SortByEigenvalues(XlaOp v, XlaOp w) {
   XlaOp sort_result =
       Sort({w, v},
            CreateScalarLtComputation(
-               {shape.element_type(), shape.element_type()}, builder),
+               {w_shape.element_type(), v_shape.element_type()}, builder),
            num_dims - 1);
   w = GetMatrixDiagonal(GetTupleElement(sort_result, 0));
   v = GetTupleElement(sort_result, 1);
@@ -461,9 +497,12 @@ XlaOp EighExpander::BuildEigh(XlaOp a, bool lower, int64 max_iter, float tol) {
                              a_shape.ToString());
     }
     PrimitiveType type = a_shape.element_type();
-    if (!primitive_util::IsFloatingPointType(type)) {
-      return InvalidArgument("Type of the input matrix must be float: got %s.",
-                             a_shape.ToString());
+    if (!primitive_util::IsFloatingPointType(type) &&
+        !primitive_util::IsComplexType(type)) {
+      return InvalidArgument(
+          "Type of the input matrix must be floating point "
+          "or complex: got %s.",
+          a_shape.ToString());
     }
 
     const int64 m = ShapeUtil::GetDimension(a_shape, -2);
@@ -483,12 +522,10 @@ XlaOp EighExpander::BuildEigh(XlaOp a, bool lower, int64 max_iter, float tol) {
     }
 
     if (m <= 1) {
-      return Tuple(builder, {FullLike(a, 1), GetMatrixDiagonal(a)});
+      return Tuple(builder, {FullLike(a, 1), GetMatrixDiagonal(Real(a))});
     }
 
-    auto eye = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims);
-    a = Triangle(a, lower);
-    a = a + TransposeInMinorDims(a) - a * eye;
+    a = Symmetrize(a, lower);
 
     const int64 k = CeilOfRatio(n, int64{2});
 
     // tl = A[:n // 2, :n // 2]
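
The deleted lines were the real-only symmetrization; written out for lower=true with L = tril(A), the old and new updates are (a sketch of the algebra, not text from the commit):

$$ A \leftarrow L + L^{\mathsf{T}} - \operatorname{diag}(L) \quad\longrightarrow\quad A \leftarrow L + L^{\mathsf{H}} - \operatorname{diag}(\operatorname{Re}(L)), \qquad L = \operatorname{tril}(A), $$

which is what the new Symmetrize in matrix.cc computes with triangle masks; the left form cannot produce a Hermitian matrix from a complex input because it never conjugates the mirrored triangle.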
@@ -520,7 +557,7 @@ XlaOp EighExpander::BuildEigh(XlaOp a, bool lower, int64 max_iter, float tol) {
     TF_ASSIGN_OR_RETURN(auto output, Sweeps(
                                          {
                                              Zero(builder, S32),
-                                             ScalarLike(a, tol),
+                                             ScalarLike(Real(a), tol),
                                              tl,
                                              tr,
                                              bl,
@@ -532,22 +569,23 @@
                                          },
                                          k * 2, max_iter, S32, builder));
 
-    std::tie(tl, tr, bl, br) =
-        std::make_tuple(output[2], output[3], output[4], output[5]);
-    std::tie(v_tl, v_tr, v_bl, v_br) =
-        std::make_tuple(output[6], output[7], output[8], output[9]);
-
-    auto w = ConcatInDim(builder, {GetMatrixDiagonal(tl), GetMatrixDiagonal(br)},
-                         num_dims - 2);
-    auto v = ConcatInDim(builder,
-                         {ConcatInDim(builder, {v_tl, v_tr}, num_dims - 1),
-                          ConcatInDim(builder, {v_bl, v_br}, num_dims - 1)},
-                         num_dims - 2);
-    if (n % 2) {
-      w = SliceInMinorDims(w, {0}, {n});
-      v = SliceInMinorDims(v, {0, 0}, {n, n});
-    }
-    v = TransposeInMinorDims(v);
+    std::tie(tl, tr, bl, br) =
+        std::make_tuple(output[2], output[3], output[4], output[5]);
+    std::tie(v_tl, v_tr, v_bl, v_br) =
+        std::make_tuple(output[6], output[7], output[8], output[9]);
+
+    auto w = ConcatInDim(
+        builder, {GetMatrixDiagonal(Real(tl)), GetMatrixDiagonal(Real(br))},
+        num_dims - 2);
+    auto v = ConcatInDim(builder,
+                         {ConcatInDim(builder, {v_tl, v_tr}, num_dims - 1),
+                          ConcatInDim(builder, {v_bl, v_br}, num_dims - 1)},
+                         num_dims - 2);
+    if (n % 2) {
+      w = SliceInMinorDims(w, {0}, {n});
+      v = SliceInMinorDims(v, {0, 0}, {n, n});
+    }
+    v = MaybeConjugate(TransposeInMinorDims(v), true);
 
     TF_ASSIGN_OR_RETURN(std::tie(v, w), SortByEigenvalues(v, w));
     return Tuple(builder, {v, w});
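
Taken together, a hedged usage sketch of the client-library entry point on a complex Hermitian input (the EighLower wrapper is hypothetical; the SelfAdjointEig signature and result fields are the ones declared in self_adjoint_eig.h and used by the tests above):

#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

// `a` is a [..., n, n] C64 operand whose lower triangle holds a Hermitian
// matrix. After this change, result.v is C64 and result.w is F32.
xla::XlaOp EighLower(xla::XlaOp a) {
  xla::SelfAdjointEigResult result =
      xla::SelfAdjointEig(a, /*lower=*/true, /*max_iter=*/100, /*tol=*/1e-6);
  return xla::Tuple(a.builder(), {result.v, result.w});
}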