[XLA] Add support for complex numbers to Qr decomposition expander.

PiperOrigin-RevId: 333208193
Change-Id: Ic9adc699a11ffcc23a0ae518b54ee29cce8569ce
Author: Peter Hawkins, 2020-09-22 19:28:18 -07:00; committed by TensorFlower Gardener
parent 50540465b5
commit 3daf30f97d
3 changed files with 77 additions and 39 deletions


@@ -74,8 +74,14 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase):
   def _test(self, dtype, shape, full_matrices):
     np.random.seed(1)
-    x_np = np.random.uniform(
-        low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype)
+    def rng():
+      return np.random.uniform(
+          low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype)
+    x_np = rng()
+    if np.issubdtype(dtype, np.complexfloating):
+      x_np += rng() * dtype(1j)
     with self.session() as sess:
       x_tf = array_ops.placeholder(dtype)
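
The same sampling scheme can be reproduced outside the test harness; a minimal NumPy-only sketch (the `random_matrix` helper name is ours, and `np.linalg.qr` is used only as a reference check):

```python
import numpy as np

def random_matrix(shape, dtype):
  """Samples a test matrix the same way the updated test does."""
  def rng():
    return np.random.uniform(
        low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype)
  x = rng()
  if np.issubdtype(dtype, np.complexfloating):
    x += rng() * dtype(1j)   # independent imaginary part for complex dtypes
  return x

x_np = random_matrix((5, 3), np.complex64)
q, r = np.linalg.qr(x_np)                    # NumPy reference decomposition
assert np.allclose(q @ r, x_np, atol=1e-5)   # Q @ R reconstructs the input
```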
@@ -102,7 +108,7 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase):
       self.CheckUnitary(q_tf_val)
   SIZES = [1, 2, 5, 10, 32, 100, 300]
-  DTYPES = [np.float32]
+  DTYPES = [np.float32, np.complex64]
   PARAMS = itertools.product(SIZES, SIZES, DTYPES)
   @parameterized.parameters(*PARAMS)


@@ -41,7 +41,7 @@ class QROp : public XlaOpKernel {
   bool full_matrices_;
 };
-REGISTER_XLA_OP(Name("Qr").TypeConstraint("T", kFloatTypes), QROp);
+REGISTER_XLA_OP(Name("Qr"), QROp);
 }  // namespace
 }  // namespace tensorflow


@@ -63,13 +63,16 @@ std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
 //   x_copy = np.copy(x)
 //   x_copy[:k+1] = 0
 //   xnorm = norm2(x_copy)
-//   if xnorm == 0:
+//   if xnorm == 0 and np.imag(alpha) == 0:
 //     beta = alpha
 //     tau = 0
 //     v = np.zeros_like(x)
 //   else:
-//     beta = - np.sign(alpha) * dlapy2(alpha, xnorm)
-//     tau = (beta - alpha) / beta
+//     beta = -np.sign(np.real(alpha)) * np.sqrt(alpha * np.conj(alpha) + xnorm ** 2)
+//     if np.issubdtype(x.dtype, np.complexfloating):
+//       tau = (beta - np.real(alpha)) / beta + (-np.imag(alpha) / beta) * 1j
+//     else:
+//       tau = (beta - alpha) / beta
 //     v = x / (alpha - beta)
 //     v[k] = 1
 //   return (v, tau, beta)
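
A runnable NumPy transcription of the house() pseudocode above may help. This is a sketch: the helper name is ours, the sign handling uses a conditional rather than np.sign (mirroring the Select in the XLA code below), and the check applies the reflector as I - conj(tau) * v * v^H, matching the conjugated update used later in QrBlock:

```python
import numpy as np

def house(x, k):
  """Householder reflector that zeroes x[k+1:]; returns (v, tau, beta)."""
  alpha = x[k]
  x_copy = np.copy(x)
  x_copy[:k + 1] = 0
  xnorm = np.linalg.norm(x_copy)  # Euclidean norm of the entries below k
  if xnorm == 0 and np.imag(alpha) == 0:
    beta, tau, v = alpha, 0, np.zeros_like(x)
  else:
    # Mirrors Select(Lt(Real(alpha), 0), 1, -1) * mu; avoids np.sign(0) == 0.
    beta = (1.0 if np.real(alpha) < 0 else -1.0) * np.sqrt(
        np.abs(alpha) ** 2 + xnorm ** 2)
    if np.issubdtype(x.dtype, np.complexfloating):
      tau = (beta - np.real(alpha)) / beta + (-np.imag(alpha) / beta) * 1j
    else:
      tau = (beta - alpha) / beta
    v = x / (alpha - beta)
    v[:k + 1] = 0
  v[k] = 1
  return v, tau, beta

x = np.array([1 + 2j, -0.5j, 0.25 + 0.75j])
v, tau, beta = house(x, 0)
y = x - np.conj(tau) * v * np.vdot(v, x)  # apply H^H = I - conj(tau) v v^H
assert np.allclose(y[1:], 0) and np.allclose(y[0], beta)
```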
@@ -86,7 +89,6 @@ Status House(XlaOp x, XlaOp k, absl::Span<const int64> batch_dims,
   const int64 minor_dim = batch_dims.size();
   XlaOp zero = ScalarLike(x, 0.0);
-  XlaOp one = ScalarLike(x, 1.0);
   // alpha = x[k]
   XlaOp alpha = Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims);
@@ -96,20 +98,46 @@ Status House(XlaOp x, XlaOp k, absl::Span<const int64> batch_dims,
   XlaOp x_after_k = Mul(x, ConvertElementType(Gt(iota, k), type),
                         /*broadcast_dimensions=*/{minor_dim});
-  // sigma = np.dot(x[k+1:], x[k+1:])
-  // TODO(phawkins): this calculation may be numerically unstable.
-  auto sigma = Reduce(x_after_k * x_after_k, zero,
-                      CreateScalarAddComputation(type, builder), {minor_dim});
-  // mu = np.sqrt(x[k]*x[k] + sigma)
-  auto mu = Sqrt(Square(alpha) + sigma);
+  XlaOp sigma_is_zero;
+  if (primitive_util::IsComplexType(type)) {
+    // sigma = np.dot(x[k+1:], np.conj(x[k+1:]))
+    // TODO(phawkins): this calculation may be numerically unstable.
+    auto x_squared = Real(x_after_k * Conj(x_after_k));
+    auto sigma =
+        Reduce(x_squared, ScalarLike(x_squared, 0.0),
+               CreateScalarAddComputation(
+                   primitive_util::ComplexComponentType(type), builder),
+               {minor_dim});
+    // mu = np.sqrt(x[k] * np.conj(x[k]) + sigma)
+    auto mu = Sqrt(Real(alpha * Conj(alpha)) + sigma);
-  auto sigma_is_zero = Eq(sigma, zero);
+    sigma_is_zero = Eq(sigma, ScalarLike(sigma, 0));
+    sigma_is_zero = And(sigma_is_zero, Eq(Imag(alpha), ScalarLike(sigma, 0)));
+    *beta = Select(Lt(Real(alpha), ScalarLike(sigma, 0)), ScalarLike(mu, 1),
+                   ScalarLike(mu, -1)) *
+            mu;
+    *beta = Select(sigma_is_zero, Real(alpha), *beta);
+    *tau = Complex((*beta - Real(alpha)) / *beta, -Imag(alpha) / *beta);
+  } else {
+    // sigma = np.dot(x[k+1:], x[k+1:])
+    // TODO(phawkins): this calculation may be numerically unstable.
+    auto sigma = Reduce(x_after_k * x_after_k, zero,
+                        CreateScalarAddComputation(type, builder), {minor_dim});
+    // mu = np.sqrt(x[k]*x[k] + sigma)
+    auto mu = Sqrt(Square(alpha) + sigma);
+    sigma_is_zero = Eq(sigma, zero);
+    XlaOp one = ScalarLike(x, 1.0);
+    *beta = Select(Lt(alpha, zero), one, -one) * mu;
+    *beta = Select(sigma_is_zero, alpha, *beta);
+    *tau = (*beta - alpha) / *beta;
+  }
+  *tau = Select(sigma_is_zero, ZerosLike(*tau), *tau);
-  *beta = Select(sigma_is_zero, alpha, Select(Lt(alpha, zero), one, -one) * mu);
-  *tau = Select(sigma_is_zero, Broadcast(zero, batch_dims),
-                (*beta - alpha) / *beta);
   auto divisor =
-      Select(sigma_is_zero, Broadcast(one, batch_dims), alpha - *beta);
+      Select(sigma_is_zero, Broadcast(ScalarLike(alpha, 1), batch_dims),
+             alpha - ConvertElementType(*beta, type));
   auto e_k = Broadcast(ConvertElementType(Eq(iota, k), type),
                        std::vector<int64>(batch_dims.size(), 1));
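
One detail worth calling out: in the complex branch, sigma == 0 alone is not enough to skip the reflection, because beta (the new diagonal entry of R) must be real. A tiny check using the house() sketch from above:

```python
import numpy as np

# Tail below k is already zero, but alpha = x[0] has a nonzero imaginary part,
# so a nontrivial reflection is still needed to make the diagonal entry real.
x = np.array([2j, 0.0, 0.0], dtype=np.complex128)
v, tau, beta = house(x, 0)
assert tau != 0 and np.isclose(beta, -2.0)
```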
@@ -136,8 +164,8 @@ Status House(XlaOp x, XlaOp k, absl::Span<const int64> batch_dims,
 //   taus = np.zeros([n])
 //   for j in xrange(min(m, n)):
 //     v, tau, beta = house(a[:, j], j)
-//     a[:, j+1:] -= tau * np.dot(v[:, np.newaxis],
-//                                np.dot(v[np.newaxis, :], a[:, j+1:]))
+//     a[:, j+1:] -= np.conj(tau) * np.dot(v[:, np.newaxis],
+//                                         np.dot(np.conj(v[np.newaxis, :]), a[:, j+1:]))
 //     # Form column j explicitly rather than relying on the precision of the
 //     # Householder update.
 //     a[j, j] = beta
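
A NumPy sketch of the column loop above, reusing the house() helper from the earlier sketch (the qr_block name and the Q accumulation at the end are ours; the expander itself never materializes Q this way):

```python
import numpy as np

def qr_block(a):
  """Unblocked Householder QR following the loop in the comment above."""
  a = a.copy()
  m, n = a.shape
  taus = np.zeros(n, dtype=a.dtype)
  vs = np.zeros((m, n), dtype=a.dtype)
  for j in range(min(m, n)):
    v, tau, beta = house(a[:, j], j)
    # a[:, j+1:] -= conj(tau) * v (v^H a[:, j+1:])
    a[:, j + 1:] -= np.conj(tau) * np.outer(v, np.conj(v) @ a[:, j + 1:])
    # Form column j explicitly rather than relying on the update's precision.
    a[j, j] = beta
    a[j + 1:, j] = v[j + 1:]
    taus[j], vs[:, j] = tau, v
  return a, vs, taus

rng = np.random.default_rng(0)
x = rng.standard_normal((6, 4)) + 1j * rng.standard_normal((6, 4))
packed, vs, taus = qr_block(x)
r = np.triu(packed)                       # R lives in the upper triangle
q = np.eye(6, dtype=complex)
for j in reversed(range(4)):              # Q = H_0 @ H_1 @ ... @ H_3
  q -= taus[j] * np.outer(vs[:, j], np.conj(vs[:, j]) @ q)
assert np.allclose(q @ r, x)
```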
@@ -187,13 +215,14 @@ StatusOr<QrExpander::QrResult> QrExpander::QrBlock(
     shape.push_back(1);
     shape.push_back(m);
     auto v_broadcast = Reshape(v, shape);
-    // a[:, j+1:] -= tau * (v[:, np.newaxis] @ (v[np.newaxis, :] @ a[:, j+1:]))
+    // a[:, j+1:] -= np.conj(tau) * (v[:, np.newaxis] @
+    //                               (np.conj(v[np.newaxis, :]) @ a[:, j+1:]))
     // We use masking rather than a loop-variant shape to handle the j+1:
     // indexing.
-    auto vva = BatchDot(v_broadcast, Select(Lt(j, iota_mn), a, ZerosLike(a)),
-                        precision);
+    auto vva = BatchDot(MaybeConjugate(v_broadcast, true),
+                        Select(Lt(j, iota_mn), a, ZerosLike(a)), precision);
     vva = BatchDot(v_broadcast, true, vva, false, precision);
-    a = a - Mul(tau, vva,
+    a = a - Mul(MaybeConjugate(tau, true), vva,
                 /*broadcast_dimensions=*/batch_dim_indices);
     // a[j, j] = beta
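
The masking trick mentioned in the comment can be illustrated in NumPy: zeroing columns <= j before the two products gives the same result as slicing a[:, j+1:], while every intermediate keeps a shape that does not depend on j. The values below are arbitrary; this is only an equivalence check:

```python
import numpy as np

rng = np.random.default_rng(1)
m, n, j = 5, 4, 1
a = rng.standard_normal((m, n)) + 1j * rng.standard_normal((m, n))
v = rng.standard_normal(m) + 1j * rng.standard_normal(m)
tau = 0.3 + 0.1j

# Sliced form of the update, as in the pseudocode: only columns j+1: change.
sliced = a.copy()
sliced[:, j + 1:] -= np.conj(tau) * np.outer(v, np.conj(v) @ a[:, j + 1:])

# Masked form: zero out columns <= j instead of slicing, so the intermediates
# keep a loop-invariant (m, n) shape, which is what the expander needs.
mask = np.arange(n) > j
masked = a - np.conj(tau) * np.outer(v, np.conj(v) @ np.where(mask, a, 0))

assert np.allclose(sliced, masked)
```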
@@ -205,7 +234,8 @@ StatusOr<QrExpander::QrResult> QrExpander::QrBlock(
     auto successor_mask = Gt(Iota(a.builder(), S32, m), j);
     auto new_x = Mul(x, predecessor_mask,
                      /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) +
-                 Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices);
+                 Mul(ConvertElementType(beta, type), mask,
+                     /*broadcast_dimensions=*/batch_dim_indices);
     new_x = Add(
         new_x, Select(Broadcast(successor_mask, batch_dims), v, ZerosLike(v)),
         /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {minor_dim}));
@@ -257,7 +287,7 @@ StatusOr<QrExpander::QrResult> QrExpander::QrBlock(
 //   t = np.eye(n) * -taus
 //   # We premultiply Y.T @ vs, since we would prefer to compute a single matrix
 //   # multiplication to many matrix-vector products.
-//   vtv = -taus[None, :] * np.triu(vs.T @ vs, 1) + np.eye(n)
+//   vtv = -taus[None, :] * np.triu(np.conj(vs.T) @ vs, 1) + np.eye(n)
 //   for i in range(1, n):
 //     t[:, i] = scipy.linalg.blas.strmm(t, vtv[:, i])
 //   return t
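
A NumPy sketch of the T-matrix construction described above (the compact WY representation). The dense triangular matrix-vector product stands in for the strmm call, and with the -taus sign convention the product of reflectors satisfies H_0 @ ... @ H_{k-1} == I + vs @ t @ conj(vs).T:

```python
import numpy as np

def compact_wy(vs, taus):
  """T such that I + vs @ T @ vs^H == prod_j (I - taus[j] * v_j v_j^H)."""
  n = vs.shape[1]
  t = np.eye(n, dtype=vs.dtype) * -taus          # diagonal is -taus
  vtv = -taus[np.newaxis, :] * np.triu(np.conj(vs.T) @ vs, 1) + np.eye(n)
  for i in range(1, n):
    t[:, i] = t @ vtv[:, i]                      # upper-triangular product
  return t

rng = np.random.default_rng(2)
m, k = 6, 3
vs = rng.standard_normal((m, k)) + 1j * rng.standard_normal((m, k))
taus = rng.standard_normal(k) + 1j * rng.standard_normal(k)
t = compact_wy(vs, taus)

q_ref = np.eye(m, dtype=complex)
for j in range(k):                               # Q = H_0 @ H_1 @ ... @ H_{k-1}
  q_ref = q_ref @ (np.eye(m) - taus[j] * np.outer(vs[:, j], np.conj(vs[:, j])))
assert np.allclose(np.eye(m) + vs @ t @ np.conj(vs.T), q_ref)
```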
@@ -293,8 +323,8 @@ StatusOr<XlaOp> QrExpander::CompactWYRepresentation(
   auto eye = Broadcast(IdentityMatrix(builder, type, n, n), batch_dims);
   auto t = eye;
-  auto vtv =
-      BatchDot(vs, /*transpose_x=*/true, vs, /*transpose_y=*/false, precision);
+  auto vtv = BatchDot(MaybeConjugate(vs, true), /*transpose_x=*/true, vs,
+                      /*transpose_y=*/false, precision);
   vtv = Select(TriangleMask(vtv, 0), ZerosLike(vtv), vtv);
   vtv = (vtv + eye) * tau_scale;
@@ -313,8 +343,8 @@ StatusOr<XlaOp> QrExpander::CompactWYRepresentation(
 //     (a, taus) = qr(a[i:, i:i+k])
 //     y = np.eye(m, n) + np.tril(a, -1)
 //     t = CompactWYRepresentation(vs, taus, m-i, k)
-//     a[i:, i+k:] += (y @ t.T) @ (y.T @ a[i:, i+k:])
-//     q[:, i:] += (q[:, i:] @ y) @ (y @ t.T).T
+//     a[i:, i+k:] += (y @ np.conj(t.T)) @ (np.conj(y.T) @ a[i:, i+k:])
+//     q[:, i:] += (q[:, i:] @ y) @ np.conj((y @ np.conj(t.T)).T)
 //   return (q, a)
 StatusOr<XlaOp> QrExpander::BuildQrDecomposition(
     XlaOp a, int64 block_size, PrecisionConfig::Precision precision) {
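
Putting the pieces together, a NumPy sketch of the blocked driver described above, reusing the qr_block and compact_wy helpers from the earlier sketches (panel handling is simplified; the XLA expander works on slices of a single buffer and on batches):

```python
import numpy as np

def blocked_qr(a, block_size):
  """Blocked Householder QR following the pseudocode above."""
  a = a.astype(complex)                            # work on a copy
  m, n = a.shape
  q = np.eye(m, dtype=complex)
  for i in range(0, min(m, n), block_size):
    k = min(block_size, min(m, n) - i)
    panel, vs, taus = qr_block(a[i:, i:i + k])     # (a, taus) = qr(panel)
    a[i:, i:i + k] = panel
    y = np.eye(m - i, k, dtype=complex) + np.tril(panel, -1)
    t = compact_wy(vs, taus)
    yt = y @ np.conj(t.T)
    # a[i:, i+k:] += (y @ conj(t.T)) @ (conj(y.T) @ a[i:, i+k:])
    a[i:, i + k:] += yt @ (np.conj(y.T) @ a[i:, i + k:])
    # q[:, i:] += (q[:, i:] @ y) @ conj((y @ conj(t.T)).T)
    q[:, i:] += (q[:, i:] @ y) @ np.conj(yt.T)
  return q, np.triu(a)

rng = np.random.default_rng(3)
x = rng.standard_normal((8, 6)) + 1j * rng.standard_normal((8, 6))
q, r = blocked_qr(x, block_size=3)
assert np.allclose(q @ r, x)
assert np.allclose(np.conj(q.T) @ q, np.eye(8))
```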
@@ -361,21 +391,23 @@ StatusOr<XlaOp> QrExpander::BuildQrDecomposition(
         auto t, CompactWYRepresentation(type, batch_dims, y, qr_block.taus,
                                         m - i, k, precision));
-    // a[i:, i+k:] += (y @ t.T) @ (y.T @ a[i:, i+k:])
-    auto yt =
-        BatchDot(y, /*transpose_x=*/false, t, /*transpose_y=*/true, precision);
+    // a[i:, i+k:] += (y @ np.conj(t.T)) @ (np.conj(y.T) @ a[i:, i+k:])
+    auto yt = BatchDot(y, /*transpose_x=*/false, MaybeConjugate(t, true),
+                       /*transpose_y=*/true, precision);
     auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n});
-    auto a_update = BatchDot(y, /*transpose_x=*/true, a_panel,
-                             /*transpose_y=*/false, precision);
+    auto a_update =
+        BatchDot(MaybeConjugate(y, true), /*transpose_x=*/true, a_panel,
+                 /*transpose_y=*/false, precision);
     a_update = BatchDot(yt, a_update, precision);
     a_panel = a_panel + a_update;
     a = UpdateSliceInMinorDims(a, a_panel, {i, i + k});
-    // q[:, i:] += (q[:, i:] @ y) @ (y @ t.T).T
+    // q[:, i:] += (q[:, i:] @ y) @ np.conj((y @ np.conj(t.T)).T)
     auto q_panel = SliceInMinorDims(q, {0, i}, {m, m});
     auto q_update = BatchDot(q_panel, y, precision);
-    q_update = BatchDot(q_update, /*transpose_x=*/false, yt,
-                        /*transpose_y=*/true, precision);
+    q_update =
+        BatchDot(q_update, /*transpose_x=*/false, MaybeConjugate(yt, true),
+                 /*transpose_y=*/true, precision);
     q_panel = q_panel + q_update;
     q = UpdateSliceInMinorDims(q, q_panel, {0, i});
   }
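
End to end, this is what the change enables at the TensorFlow level. A sketch assuming a build that includes this commit and an XLA-capable device (on older releases the tf.function flag is spelled experimental_compile rather than jit_compile):

```python
import numpy as np
import tensorflow as tf

@tf.function(jit_compile=True)      # force the XLA path
def qr(x):
  return tf.linalg.qr(x, full_matrices=False)

x = (np.random.uniform(-1.0, 1.0, (10, 4)) +
     1j * np.random.uniform(-1.0, 1.0, (10, 4))).astype(np.complex64)
q, r = qr(tf.constant(x))
np.testing.assert_allclose(tf.matmul(q, r).numpy(), x, atol=1e-4)
```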