[XLA] Add support for complex numbers to Qr decomposition expander.

PiperOrigin-RevId: 333208193
Change-Id: Ic9adc699a11ffcc23a0ae518b54ee29cce8569ce
commit 3daf30f97d
parent 50540465b5
@@ -74,8 +74,14 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase):

   def _test(self, dtype, shape, full_matrices):
     np.random.seed(1)
-    x_np = np.random.uniform(
-        low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype)
+
+    def rng():
+      return np.random.uniform(
+          low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype)
+
+    x_np = rng()
+    if np.issubdtype(dtype, np.complexfloating):
+      x_np += rng() * dtype(1j)

     with self.session() as sess:
       x_tf = array_ops.placeholder(dtype)
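For complex dtypes the test now draws the real and imaginary parts as two independent uniform samples. A standalone sketch of the same recipe (the random_input name is ours, not part of the change):

import numpy as np

def random_input(dtype, shape):
  # Mirrors the test's rng() helper: uniform in [-1, 1), plus an
  # independent imaginary part when the dtype is complex.
  def rng():
    return np.random.uniform(
        low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype)

  x = rng()
  if np.issubdtype(dtype, np.complexfloating):
    x += rng() * dtype(1j)
  return x

x = random_input(np.complex64, (5, 3))
assert x.dtype == np.complex64 and np.abs(x.imag).max() > 0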
@@ -102,7 +108,7 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase):
       self.CheckUnitary(q_tf_val)

 SIZES = [1, 2, 5, 10, 32, 100, 300]
-DTYPES = [np.float32]
+DTYPES = [np.float32, np.complex64]
 PARAMS = itertools.product(SIZES, SIZES, DTYPES)

 @parameterized.parameters(*PARAMS)
@@ -41,7 +41,7 @@ class QROp : public XlaOpKernel {
   bool full_matrices_;
 };

-REGISTER_XLA_OP(Name("Qr").TypeConstraint("T", kFloatTypes), QROp);
+REGISTER_XLA_OP(Name("Qr"), QROp);

 }  // namespace
 }  // namespace tensorflow
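Dropping the float-only TypeConstraint registers the XLA Qr kernel for every type the expander now handles, including complex64. A hypothetical end-to-end check, assuming a TF build that includes this change (on builds contemporary with this commit the flag was spelled experimental_compile rather than jit_compile):

import numpy as np
import tensorflow as tf

@tf.function(jit_compile=True)  # force the XLA path, and with it the expander
def qr(x):
  return tf.linalg.qr(x, full_matrices=False)

x = (np.random.randn(6, 4) + 1j * np.random.randn(6, 4)).astype(np.complex64)
q, r = qr(tf.constant(x))
np.testing.assert_allclose(q.numpy() @ r.numpy(), x, atol=1e-4)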
@@ -63,13 +63,16 @@ std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
 //   x_copy = np.copy(x)
 //   x_copy[:k+1] = 0
 //   xnorm = norm2(x_copy)
-//   if xnorm == 0:
+//   if xnorm == 0 and np.imag(alpha) == 0:
 //     beta = alpha
 //     tau = 0
 //     v = np.zeros_like(x)
 //   else:
-//     beta = - np.sign(alpha) * dlapy2(alpha, xnorm)
-//     tau = (beta - alpha) / beta
+//     beta = -np.sign(np.real(alpha)) * np.sqrt(alpha * np.conj(alpha) + xnorm)
+//     if np.issubdtype(x.dtype, np.complexfloating):
+//       tau = (beta - np.real(alpha)) / beta + (-np.imag(alpha) / beta) * 1j
+//     else:
+//       tau = (beta - alpha) / beta
 //     v = x / (alpha - beta)
 //   v[k] = 1
 //   return (v, tau, beta)
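This pseudocode is the heart of the change: for complex inputs, beta is kept real by reflecting against the magnitude of alpha, and tau becomes complex. A runnable NumPy sketch of house() under those formulas (the function name and the sign handling for real(alpha) == 0, which follows the C++ Select below rather than np.sign, are ours):

import numpy as np

def house(x, k):
  """Householder reflector: applying I - conj(tau)*outer(v, conj(v)) to x
  leaves x[:k] alone, sets x[k] to the real scalar beta, zeroes x[k+1:]."""
  alpha = x[k]
  x_copy = np.copy(x)
  x_copy[:k + 1] = 0
  sigma = np.sum(np.abs(x_copy) ** 2)        # squared norm of the tail
  if sigma == 0 and np.imag(alpha) == 0:
    v = np.zeros_like(x)
    v[k] = 1
    return v, x.dtype.type(0), np.real(alpha)
  mu = np.sqrt(np.abs(alpha) ** 2 + sigma)
  beta = mu if np.real(alpha) < 0 else -mu   # beta is real even for complex x
  if np.issubdtype(x.dtype, np.complexfloating):
    tau = (beta - np.real(alpha)) / beta + (-np.imag(alpha) / beta) * 1j
  else:
    tau = (beta - alpha) / beta
  v = x_copy / (alpha - beta)                # zero above k by construction
  v[k] = 1
  return v, x.dtype.type(tau), beta

x = (np.random.randn(6) + 1j * np.random.randn(6)).astype(np.complex128)
v, tau, beta = house(x, 2)
hx = x - np.conj(tau) * v * np.vdot(v, x)    # (I - conj(tau) v v^H) @ x
np.testing.assert_allclose(hx[3:], 0, atol=1e-12)
np.testing.assert_allclose(hx[2], beta, atol=1e-12)
np.testing.assert_allclose(hx[:2], x[:2], atol=1e-12)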
@@ -86,7 +89,6 @@ Status House(XlaOp x, XlaOp k, absl::Span<const int64> batch_dims,
   const int64 minor_dim = batch_dims.size();

   XlaOp zero = ScalarLike(x, 0.0);
-  XlaOp one = ScalarLike(x, 1.0);

   // alpha = x[k]
   XlaOp alpha = Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims);
@@ -96,20 +98,46 @@ Status House(XlaOp x, XlaOp k, absl::Span<const int64> batch_dims,
   XlaOp x_after_k = Mul(x, ConvertElementType(Gt(iota, k), type),
                         /*broadcast_dimensions=*/{minor_dim});

-  // sigma = np.dot(x[k+1:], x[k+1:])
-  // TODO(phawkins): this calculation may be numerically unstable.
-  auto sigma = Reduce(x_after_k * x_after_k, zero,
-                      CreateScalarAddComputation(type, builder), {minor_dim});
-  // mu = np.sqrt(x[k]*x[k] + sigma)
-  auto mu = Sqrt(Square(alpha) + sigma);
+  XlaOp sigma_is_zero;
+  if (primitive_util::IsComplexType(type)) {
+    // sigma = np.dot(x[k+1:], np.conj(x[k+1:]))
+    // TODO(phawkins): this calculation may be numerically unstable.
+    auto x_squared = Real(x_after_k * Conj(x_after_k));
+    auto sigma =
+        Reduce(x_squared, ScalarLike(x_squared, 0.0),
+               CreateScalarAddComputation(
+                   primitive_util::ComplexComponentType(type), builder),
+               {minor_dim});
+    // mu = np.sqrt(x[k]*np.conj(x[k]) + sigma)
+    auto mu = Sqrt(Real(alpha * Conj(alpha)) + sigma);

-  auto sigma_is_zero = Eq(sigma, zero);
+    sigma_is_zero = Eq(sigma, ScalarLike(sigma, 0));
+    sigma_is_zero = And(sigma_is_zero, Eq(Imag(alpha), ScalarLike(sigma, 0)));

-  *beta = Select(sigma_is_zero, alpha, Select(Lt(alpha, zero), one, -one) * mu);
-  *tau = Select(sigma_is_zero, Broadcast(zero, batch_dims),
-                (*beta - alpha) / *beta);
-  auto divisor =
-      Select(sigma_is_zero, Broadcast(one, batch_dims), alpha - *beta);
+    *beta = Select(Lt(Real(alpha), ScalarLike(sigma, 0)), ScalarLike(mu, 1),
+                   ScalarLike(mu, -1)) *
+            mu;
+    *beta = Select(sigma_is_zero, Real(alpha), *beta);
+    *tau = Complex((*beta - Real(alpha)) / *beta, -Imag(alpha) / *beta);
+  } else {
+    // sigma = np.dot(x[k+1:], x[k+1:])
+    // TODO(phawkins): this calculation may be numerically unstable.
+    auto sigma = Reduce(x_after_k * x_after_k, zero,
+                        CreateScalarAddComputation(type, builder), {minor_dim});
+    // mu = np.sqrt(x[k]*x[k] + sigma)
+    auto mu = Sqrt(Square(alpha) + sigma);
+    sigma_is_zero = Eq(sigma, zero);
+
+    XlaOp one = ScalarLike(x, 1.0);
+    *beta = Select(Lt(alpha, zero), one, -one) * mu;
+    *beta = Select(sigma_is_zero, alpha, *beta);
+    *tau = (*beta - alpha) / *beta;
+  }
+  *tau = Select(sigma_is_zero, ZerosLike(*tau), *tau);
+
+  auto divisor =
+      Select(sigma_is_zero, Broadcast(ScalarLike(alpha, 1), batch_dims),
+             alpha - ConvertElementType(*beta, type));

   auto e_k = Broadcast(ConvertElementType(Eq(iota, k), type),
                        std::vector<int64>(batch_dims.size(), 1));
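In the complex branch, beta stays real by construction, so tau can be assembled from real components; the Complex(...) line above is just (beta - alpha) / beta written out. A quick check of that identity (the values are arbitrary):

import numpy as np

alpha = np.complex64(0.3 - 0.8j)
beta = np.float32(-1.7)                       # beta is real by construction
tau_split = (beta - np.real(alpha)) / beta + (-np.imag(alpha) / beta) * 1j
tau_direct = (beta - alpha) / beta
assert np.isclose(tau_split, tau_direct)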
@@ -136,8 +164,8 @@ Status House(XlaOp x, XlaOp k, absl::Span<const int64> batch_dims,
 //   taus = np.zeros([n])
 //   for j in xrange(min(m, n)):
 //     v, tau, beta = house(a[:, j], j)
-//     a[:, j+1:] -= tau * np.dot(v[:, np.newaxis],
-//                                np.dot(v[np.newaxis, :], a[:, j+1:]))
+//     a[:, j+1:] -= np.conj(tau) * np.dot(v[:, np.newaxis],
+//                                np.dot(np.conj(v[np.newaxis, :]), a[:, j+1:]))
 //     # Form column j explicitly rather than relying on the precision of the
 //     # Householder update.
 //     a[j, j] = beta
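For complex inputs the update applies I - conj(tau) v v^H, i.e. the conjugate transpose of the reflector. The loop above written out in NumPy, reusing the house() sketch from earlier (qr_block and the Q-reconstruction test at the end are our additions, not part of the change):

import numpy as np

def qr_block(a):
  """Unblocked Householder QR; returns R packed with the reflectors
  below the diagonal, plus the taus. Uses house() sketched above."""
  a = a.copy()
  m, n = a.shape
  taus = np.zeros(n, dtype=a.dtype)
  for j in range(min(m, n)):
    v, tau, beta = house(a[:, j], j)
    a[:, j + 1:] -= np.conj(tau) * np.outer(v, np.conj(v) @ a[:, j + 1:])
    # Form column j explicitly rather than relying on the precision of
    # the Householder update.
    a[j, j] = beta
    a[j + 1:, j] = v[j + 1:]
    taus[j] = tau
  return a, taus

x = np.random.randn(5, 4) + 1j * np.random.randn(5, 4)
r_packed, taus = qr_block(x)
q = np.eye(5, dtype=complex)
for j in reversed(range(4)):            # Q = H_0 @ H_1 @ ... @ H_3
  v = np.zeros(5, dtype=complex)
  v[j], v[j + 1:] = 1, r_packed[j + 1:, j]
  q -= taus[j] * np.outer(v, np.conj(v) @ q)
np.testing.assert_allclose(q @ np.triu(r_packed), x, atol=1e-12)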
@@ -187,13 +215,14 @@ StatusOr<QrExpander::QrResult> QrExpander::QrBlock(
     shape.push_back(1);
     shape.push_back(m);
     auto v_broadcast = Reshape(v, shape);
-    // a[:, j+1:] -= tau * (v[:, np.newaxis] @ (v[np.newaxis, :] @ a[:, j+1:]))
+    // a[:, j+1:] -= np.conj(tau) * (v[:, np.newaxis] @
+    //                               (np.conj(v[np.newaxis, :]) @ a[:, j+1:]))
     // We use masking rather than a loop-variant shape to handle the j+1:
     // indexing.
-    auto vva = BatchDot(v_broadcast, Select(Lt(j, iota_mn), a, ZerosLike(a)),
-                        precision);
+    auto vva = BatchDot(MaybeConjugate(v_broadcast, true),
+                        Select(Lt(j, iota_mn), a, ZerosLike(a)), precision);
     vva = BatchDot(v_broadcast, true, vva, false, precision);
-    a = a - Mul(tau, vva,
+    a = a - Mul(MaybeConjugate(tau, true), vva,
                 /*broadcast_dimensions=*/batch_dim_indices);

     // a[j, j] = beta
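The masking trick keeps every iteration's shapes static: instead of slicing a[:, j+1:], whose width depends on the loop counter, columns up to j are zeroed and a full-width product is taken. A NumPy illustration of the equivalence (all names here are ours):

import numpy as np

rng = np.random.default_rng(0)
m, n, j = 4, 5, 1
a = rng.standard_normal((m, n)) + 1j * rng.standard_normal((m, n))
v = rng.standard_normal(m) + 1j * rng.standard_normal(m)

masked = np.where(np.arange(n)[None, :] > j, a, 0)  # Select(Lt(j, iota), a, 0)
full = np.conj(v) @ masked                          # static, full-width product
np.testing.assert_allclose(full[j + 1:], np.conj(v) @ a[:, j + 1:], atol=1e-12)
np.testing.assert_allclose(full[:j + 1], 0, atol=1e-12)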
@@ -205,7 +234,8 @@ StatusOr<QrExpander::QrResult> QrExpander::QrBlock(
     auto successor_mask = Gt(Iota(a.builder(), S32, m), j);
     auto new_x = Mul(x, predecessor_mask,
                      /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) +
-                 Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices);
+                 Mul(ConvertElementType(beta, type), mask,
+                     /*broadcast_dimensions=*/batch_dim_indices);
     new_x = Add(
         new_x, Select(Broadcast(successor_mask, batch_dims), v, ZerosLike(v)),
         /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {minor_dim}));
@@ -257,7 +287,7 @@ StatusOr<QrExpander::QrResult> QrExpander::QrBlock(
 //   t = np.eye(n) * -taus
 //   # We premultiply Y.T @ vs, since we would prefer to compute a single matrix
 //   # multiplication to many matrix-vector products.
-//   vtv = -taus[None, :] * np.triu(vs.T @ vs, 1) + np.eye(n)
+//   vtv = -taus[None, :] * np.triu(np.conj(vs.T) @ vs, 1) + np.eye(n)
 //   for i in range(1, n):
 //     t[:, i] = scipy.linalg.blas.strmm(t, vtv[:, i])
 //   return t
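This recurrence builds the triangular factor T of the compact WY form, so that the product of a block of reflectors collapses to I + vs @ T @ vs^H (the sign of taus is folded into T, which is why the later panel updates are "+="). A NumPy sketch plus an algebraic check against the explicit product (compact_wy and the test harness are ours):

import numpy as np

def compact_wy(vs, taus):
  """T with H_0 @ ... @ H_{n-1} == I + vs @ T @ conj(vs.T),
  where H_j = I - taus[j] * outer(vs[:, j], conj(vs[:, j]))."""
  n = vs.shape[1]
  vtv = -taus[None, :] * np.triu(np.conj(vs.T) @ vs, 1) + np.eye(n)
  t = np.eye(n) * -taus
  for i in range(1, n):        # the strmm in the comment: t[:, i] = t @ vtv[:, i]
    t[:, i] = t @ vtv[:, i]
  return t

rng = np.random.default_rng(1)
m, n = 6, 3
vs = np.tril(rng.standard_normal((m, n)) + 1j * rng.standard_normal((m, n)), -1)
vs += np.eye(m, n)                          # unit diagonal, zeros above
taus = rng.standard_normal(n) + 1j * rng.standard_normal(n)

prod = np.eye(m, dtype=complex)             # explicit product of reflectors
for j in range(n):
  prod = prod @ (np.eye(m) - taus[j] * np.outer(vs[:, j], np.conj(vs[:, j])))

t = compact_wy(vs, taus)
np.testing.assert_allclose(np.eye(m) + vs @ t @ np.conj(vs.T), prod, atol=1e-12)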
@@ -293,8 +323,8 @@ StatusOr<XlaOp> QrExpander::CompactWYRepresentation(
   auto eye = Broadcast(IdentityMatrix(builder, type, n, n), batch_dims);
   auto t = eye;

-  auto vtv =
-      BatchDot(vs, /*transpose_x=*/true, vs, /*transpose_y=*/false, precision);
+  auto vtv = BatchDot(MaybeConjugate(vs, true), /*transpose_x=*/true, vs,
+                      /*transpose_y=*/false, precision);
   vtv = Select(TriangleMask(vtv, 0), ZerosLike(vtv), vtv);
   vtv = (vtv + eye) * tau_scale;
@@ -313,8 +343,8 @@ StatusOr<XlaOp> QrExpander::CompactWYRepresentation(
 //     (a, taus) = qr(a[i:, i:i+k])
 //     y = np.eye(m, n) + np.tril(a, -1)
 //     t = CompactWYRepresentation(vs, taus, m-i, k)
-//     a[i:, i+k:] += (y @ t.T) @ (y.T @ a[i:, i+k:])
-//     q[:, i:] += (q[:, i:] @ y) @ (y @ t.T).T
+//     a[i:, i+k:] += (y @ np.conj(t.T)) @ (np.conj(y.T) @ a[i:, i+k:])
+//     q[:, i:] += (q[:, i:] @ y) @ np.conj((y @ np.conj(t.T)).T)
 //   return (q, a)
 StatusOr<XlaOp> QrExpander::BuildQrDecomposition(
     XlaOp a, int64 block_size, PrecisionConfig::Precision precision) {
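With Q_block = I + y @ t @ y^H from the compact WY form, the a-panel update applies Q_block^H to the trailing columns and the q-panel update applies Q_block from the right; conjugating t.T gives exactly T^H. A check of the a-panel identity, continuing with vs, t, rng, and m from the previous sketch (y here plays the role of vs):

y = vs
a_panel = rng.standard_normal((m, 4)) + 1j * rng.standard_normal((m, 4))

q_block = np.eye(m) + y @ t @ np.conj(y.T)
updated = a_panel + (y @ np.conj(t.T)) @ (np.conj(y.T) @ a_panel)
np.testing.assert_allclose(updated, np.conj(q_block.T) @ a_panel, atol=1e-12)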
@@ -361,21 +391,23 @@ StatusOr<XlaOp> QrExpander::BuildQrDecomposition(
         auto t, CompactWYRepresentation(type, batch_dims, y, qr_block.taus,
                                         m - i, k, precision));

-    // a[i:, i+k:] += (y @ t.T) @ (y.T @ a[i:, i+k:])
-    auto yt =
-        BatchDot(y, /*transpose_x=*/false, t, /*transpose_y=*/true, precision);
+    // a[i:, i+k:] += (y @ np.conj(t.T)) @ (np.conj(y.T) @ a[i:, i+k:])
+    auto yt = BatchDot(y, /*transpose_x=*/false, MaybeConjugate(t, true),
+                       /*transpose_y=*/true, precision);
     auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n});
-    auto a_update = BatchDot(y, /*transpose_x=*/true, a_panel,
-                             /*transpose_y=*/false, precision);
+    auto a_update =
+        BatchDot(MaybeConjugate(y, true), /*transpose_x=*/true, a_panel,
+                 /*transpose_y=*/false, precision);
     a_update = BatchDot(yt, a_update, precision);
     a_panel = a_panel + a_update;
     a = UpdateSliceInMinorDims(a, a_panel, {i, i + k});

-    // q[:, i:] += (q[:, i:] @ y) @ (y @ t.T).T
+    // q[:, i:] += (q[:, i:] @ y) @ np.conj((y @ np.conj(t.T)).T)
     auto q_panel = SliceInMinorDims(q, {0, i}, {m, m});
     auto q_update = BatchDot(q_panel, y, precision);
-    q_update = BatchDot(q_update, /*transpose_x=*/false, yt,
-                        /*transpose_y=*/true, precision);
+    q_update =
+        BatchDot(q_update, /*transpose_x=*/false, MaybeConjugate(yt, true),
+                 /*transpose_y=*/true, precision);
     q_panel = q_panel + q_update;
     q = UpdateSliceInMinorDims(q, q_panel, {0, i});
   }