Merge pull request from jonathanwyatt16:matrix_square_root

PiperOrigin-RevId: 218197028
TensorFlower Gardener 2018-10-22 11:09:36 -07:00
commit 3d715da989
15 changed files with 393 additions and 0 deletions

View File

@ -0,0 +1,37 @@
op {
  graph_op_name: "MatrixSquareRoot"
  in_arg {
    name: "input"
    description: <<END
Shape is `[..., M, M]`.
END
  }
  out_arg {
    name: "output"
    description: <<END
Shape is `[..., M, M]`.

@compatibility(scipy)
Equivalent to scipy.linalg.sqrtm
@end_compatibility
END
  }
  summary: "Computes the matrix square root of one or more square matrices:"
  description: <<END
matmul(sqrtm(A), sqrtm(A)) = A

The input matrix should be invertible. If the input matrix is real, it should
have no eigenvalues which are real and negative (pairs of complex conjugate
eigenvalues are allowed).

The matrix square root is computed by first reducing the matrix to
quasi-triangular form with the real Schur decomposition. The square root
of the quasi-triangular matrix is then computed directly. Details of
the algorithm can be found in: Nicholas J. Higham, "Computing real
square roots of a real matrix", Linear Algebra Appl., 1987.

The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
form square matrices. The output is a tensor of the same shape as the input
containing the matrix square root for all input submatrices `[..., :, :]`.
END
}
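As a quick illustration of the Schur method cited above, here is a minimal NumPy/SciPy sketch (complex Schur form for brevity; the kernel itself relies on Eigen's real quasi-triangular variant, and `sqrtm_schur` is a hypothetical helper name, not part of this change):

import numpy as np
from scipy.linalg import schur, sqrtm

def sqrtm_schur(a):
  """Principal matrix square root via A = Q T Q^H, then R with R R = T."""
  t, q = schur(a, output='complex')  # T is upper triangular
  n = t.shape[0]
  r = np.zeros_like(t)
  for i in range(n):
    r[i, i] = np.sqrt(t[i, i])
  # Off-diagonal recurrence from (R R)[i, j] = T[i, j], solved column by column.
  for j in range(1, n):
    for i in range(j - 1, -1, -1):
      s = t[i, j] - r[i, i + 1:j] @ r[i + 1:j, j]
      r[i, j] = s / (r[i, i] + r[j, j])
  x = q @ r @ q.conj().T
  # For real inputs meeting the eigenvalue condition above, the result is real.
  return x.real if np.isrealobj(a) else x

a = np.array([[33., 24.], [48., 57.]])  # square root is exactly [[5, 2], [4, 7]]
np.testing.assert_allclose(sqrtm_schur(a) @ sqrtm_schur(a), a, atol=1e-10)
np.testing.assert_allclose(sqrtm_schur(a), sqrtm(a), atol=1e-10)  # scipy agrees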

View File

@ -0,0 +1,9 @@
op {
  graph_op_name: "MatrixSquareRoot"
  endpoint {
    name: "linalg.sqrtm"
  }
  endpoint {
    name: "matrix_square_root"
  }
}
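Both endpoints surface the same op in the Python API; a minimal usage sketch, assuming a TF 1.x graph/session context and a build that includes this change:

import numpy as np
import tensorflow as tf

a = tf.constant([[33., 24.], [48., 57.]])
with tf.Session() as sess:
  s1, s2 = sess.run([tf.linalg.sqrtm(a), tf.matrix_square_root(a)])
np.testing.assert_allclose(s1, s2)  # two names, one op
np.testing.assert_allclose(s1 @ s1, [[33., 24.], [48., 57.]], atol=1e-3)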

View File

@ -2629,6 +2629,7 @@ cc_library(
        ":matrix_logarithm_op",
        ":matrix_solve_ls_op",
        ":matrix_solve_op",
        ":matrix_square_root_op",
        ":matrix_triangular_solve_op",
        ":qr_op",
        ":self_adjoint_eig_op",
@ -2738,6 +2739,12 @@ tf_kernel_library(
    deps = LINALG_DEPS,
)

tf_kernel_library(
    name = "matrix_square_root_op",
    prefix = "matrix_square_root_op",
    deps = LINALG_DEPS,
)

tf_kernel_library(
    name = "matrix_triangular_solve_op",
    prefix = "matrix_triangular_solve_op",

View File

@ -0,0 +1,58 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/linalg_ops.cc.

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/MatrixFunctions"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

template <class Scalar>
class MatrixSquareRootOp : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit MatrixSquareRootOp(OpKernelConstruction* context) : Base(context) {}

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    const ConstMatrixMap& input = inputs[0];
    // Nothing to compute for an empty (0 x 0) matrix.
    if (input.rows() == 0) return;
    using Matrix =
        Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
    Matrix tmp = input;
    // Eigen's unsupported MatrixFunctions module computes the principal
    // matrix square root via a Schur decomposition.
    outputs->at(0) = tmp.sqrt();
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(MatrixSquareRootOp);
};

REGISTER_LINALG_OP("MatrixSquareRoot", (MatrixSquareRootOp<float>), float);
REGISTER_LINALG_OP("MatrixSquareRoot", (MatrixSquareRootOp<double>), double);
REGISTER_LINALG_OP("MatrixSquareRoot", (MatrixSquareRootOp<complex64>),
                   complex64);
REGISTER_LINALG_OP("MatrixSquareRoot", (MatrixSquareRootOp<complex128>),
                   complex128);

}  // namespace tensorflow
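The kernel above delegates the per-matrix math to Eigen (`tmp.sqrt()`), and LinearAlgebraOp supplies the batching over the leading dimensions. A hedged NumPy sketch of the batch contract, looping scipy.linalg.sqrtm over the inner `[M, M]` blocks (`batch_sqrtm` is an illustrative name, not the kernel's API):

import numpy as np
from scipy.linalg import sqrtm

def batch_sqrtm(x):
  """Apply sqrtm to every inner [M, M] block of a [..., M, M] array."""
  flat = x.reshape((-1,) + x.shape[-2:])
  out = np.stack([sqrtm(m) for m in flat])
  return out.reshape(x.shape)

batch = np.tile(np.array([[2., 1.], [1., 2.]]), [2, 3, 1, 1])  # shape [2, 3, 2, 2]
roots = batch_sqrtm(batch)
np.testing.assert_allclose(np.matmul(roots, roots), batch, atol=1e-10)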

View File

@ -323,6 +323,12 @@ REGISTER_OP("MatrixSolveLs")
      return MatrixSolveShapeFn(c, false /* square */);
    });

REGISTER_OP("MatrixSquareRoot")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

REGISTER_OP("Qr")
    .Input("input: T")
    .Output("q: T")

View File

@ -16084,6 +16084,29 @@ op {
}
}
}
op {
  name: "MatrixSquareRoot"
  input_arg {
    name: "input"
    type_attr: "T"
  }
  output_arg {
    name: "output"
    type_attr: "T"
  }
  attr {
    name: "T"
    type: "type"
    allowed_values {
      list {
        type: DT_DOUBLE
        type: DT_FLOAT
        type: DT_COMPLEX64
        type: DT_COMPLEX128
      }
    }
  }
}
op {
  name: "MatrixTriangularSolve"
  input_arg {

View File

@ -16660,6 +16660,46 @@ func MatrixSolveLs(scope *Scope, matrix tf.Output, rhs tf.Output, l2_regularizer
	return op.Output(0)
}

// Computes the matrix square root of one or more square matrices:
//
// matmul(sqrtm(A), sqrtm(A)) = A
//
// The input matrix should be invertible. If the input matrix is real,
// it should have no eigenvalues which are real and negative
// (pairs of complex conjugate eigenvalues are allowed).
//
// The matrix square root is computed by first reducing the matrix to
// quasi-triangular form with the real Schur decomposition. The square root
// of the quasi-triangular matrix is then computed directly. Details of
// the algorithm can be found in: Nicholas J. Higham, "Computing real
// square roots of a real matrix", Linear Algebra Appl., 1987.
//
// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
// form square matrices. The output is a tensor of the same shape as the input
// containing the matrix square root for all input submatrices `[..., :, :]`.
//
// Arguments:
// input: Shape is `[..., M, M]`.
//
// Returns Shape is `[..., M, M]`.
//
// @compatibility(scipy)
// Equivalent to scipy.linalg.sqrtm
// @end_compatibility
func MatrixSquareRoot(scope *Scope, input tf.Output) (output tf.Output) {
	if scope.Err() != nil {
		return
	}
	opspec := tf.OpSpec{
		Type: "MatrixSquareRoot",
		Input: []tf.Input{
			input,
		},
	}
	op := scope.AddOperation(opspec)
	return op.Output(0)
}

// MaxPool3DAttr is an optional argument to MaxPool3D.
type MaxPool3DAttr func(optionalAttr)

View File

@ -66,6 +66,10 @@ def _GetMatrixUnaryFunctorGradientTest(functor_, dtype_, shape_, **kwargs_):
          low=-1.0, high=1.0,
          size=np.prod(shape_)).reshape(shape_).astype(dtype_)
      a = constant_op.constant(a_np)
      if functor_.__name__ == 'matrix_square_root':
        # Square the input matrix to ensure that its matrix square root exists
        a = math_ops.matmul(a, a)
        a_np = a.eval()
      b = functor_(a, **kwargs_)
      # Optimal stepsize for central difference is O(epsilon^{1/3}).
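On the stepsize comment just above: central differencing has truncation error O(h^2) and roundoff error O(eps / h), which balance at h = O(eps^{1/3}). A hedged NumPy sketch of such a Jacobian check for sqrtm, with the analytic answer taken from the Kronecker-sum identity used by the gradient registered later in this change (helper names are illustrative, not the test's compute_gradient_error machinery):

import numpy as np
from scipy.linalg import sqrtm

def central_diff_jacobian(f, x, h):
  """Central-difference Jacobian of a matrix function, flattened row-major."""
  n = x.size
  jac = np.zeros((n, n))
  for k in range(n):
    e = np.zeros(n)
    e[k] = h
    e = e.reshape(x.shape)
    jac[:, k] = (f(x + e) - f(x - e)).reshape(-1) / (2.0 * h)
  return jac

x = np.array([[2., 1.], [1., 2.]])
h = np.finfo(np.float64).eps ** (1.0 / 3.0)  # the O(epsilon^{1/3}) stepsize
jac = central_diff_jacobian(lambda m: sqrtm(m).real, x, h)

# vec(R dR + dR R) = (R (x) I + I (x) R^T) vec(dR) under row-major vec, so the
# Jacobian of sqrtm is the inverse of that Kronecker sum.
r = sqrtm(x).real
k = np.kron(r, np.eye(2)) + np.kron(np.eye(2), r.T)
np.testing.assert_allclose(jac, np.linalg.inv(k), rtol=1e-6, atol=1e-9)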
@ -189,6 +193,17 @@ if __name__ == '__main__':
                     lambda x: linalg_ops.log_matrix_determinant(x)[1],
                     dtype, shape))

        # The numerical Jacobian is consistently invalid for these four shapes
        # because the matrix square root of the perturbed input doesn't exist
        if shape in {(2, 5, 5), (3, 5, 5), (3, 10, 10), (3, 2, 5, 5)}:
          # Alternative shape that consistently produces a valid numerical
          # Jacobian
          shape = extra + (size + 1, size + 1)
          name = '%s_%s' % (dtype.__name__, '_'.join(map(str, shape)))
        _AddTest(
            MatrixUnaryFunctorGradientTest, 'MatrixSquareRootGradient', name,
            _GetMatrixUnaryFunctorGradientTest(linalg_ops.matrix_square_root,
                                               dtype, shape))

  # Tests for gradients of matrix_solve_ls
  for dtype in np.float32, np.float64:
    for rows in 2, 5, 10:

View File

@ -0,0 +1,116 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.math_ops.matrix_square_root."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
class SquareRootOpTest(test.TestCase):
def _verifySquareRoot(self, matrix, np_type):
matrix = matrix.astype(np_type)
with self.test_session(use_gpu=True):
# Verify that matmul(sqrtm(A), sqrtm(A)) = A
sqrt = gen_linalg_ops.matrix_square_root(matrix)
square = math_ops.matmul(sqrt, sqrt)
self.assertShapeEqual(matrix, square)
self.assertAllClose(matrix, square, rtol=1e-4, atol=1e-3)
def _verifySquareRootReal(self, x):
for np_type in [np.float32, np.float64]:
self._verifySquareRoot(x, np_type)
def _verifySquareRootComplex(self, x):
for np_type in [np.complex64, np.complex128]:
self._verifySquareRoot(x, np_type)
def _makeBatch(self, matrix1, matrix2):
matrix_batch = np.concatenate(
[np.expand_dims(matrix1, 0),
np.expand_dims(matrix2, 0)])
matrix_batch = np.tile(matrix_batch, [2, 3, 1, 1])
return matrix_batch
def _testMatrices(self, matrix1, matrix2):
# Real
self._verifySquareRootReal(matrix1)
self._verifySquareRootReal(matrix2)
self._verifySquareRootReal(self._makeBatch(matrix1, matrix2))
# Complex
matrix1 = matrix1.astype(np.complex64)
matrix2 = matrix2.astype(np.complex64)
matrix1 += 1j * matrix1
matrix2 += 1j * matrix2
self._verifySquareRootComplex(matrix1)
self._verifySquareRootComplex(matrix2)
self._verifySquareRootComplex(self._makeBatch(matrix1, matrix2))
def testSymmetricPositiveDefinite(self):
matrix1 = np.array([[2., 1.], [1., 2.]])
matrix2 = np.array([[3., -1.], [-1., 3.]])
self._testMatrices(matrix1, matrix2)
def testAsymmetric(self):
matrix1 = np.array([[0., 4.], [-1., 5.]])
matrix2 = np.array([[33., 24.], [48., 57.]])
self._testMatrices(matrix1, matrix2)
def testIdentityMatrix(self):
# 2x2
identity = np.array([[1., 0], [0, 1.]])
self._verifySquareRootReal(identity)
# 3x3
identity = np.array([[1., 0, 0], [0, 1., 0], [0, 0, 1.]])
self._verifySquareRootReal(identity)
def testEmpty(self):
self._verifySquareRootReal(np.empty([0, 2, 2]))
self._verifySquareRootReal(np.empty([2, 0, 0]))
def testWrongDimensions(self):
# The input to the square root should be at least a 2-dimensional tensor.
tensor = constant_op.constant([1., 2.])
with self.assertRaises(ValueError):
gen_linalg_ops.matrix_square_root(tensor)
def testNotSquare(self):
with self.test_session():
with self.assertRaises(ValueError):
tensor = constant_op.constant([[1., 0., -1.], [-1., 1., 0.]])
gen_linalg_ops.matrix_square_root(tensor).eval()
def testConcurrentExecutesWithoutError(self):
with self.test_session(use_gpu=True) as sess:
matrix1 = random_ops.random_normal([5, 5], seed=42)
matrix2 = random_ops.random_normal([5, 5], seed=42)
sqrt1 = gen_linalg_ops.matrix_square_root(matrix1)
sqrt2 = gen_linalg_ops.matrix_square_root(matrix2)
all_ops = [sqrt1, sqrt2]
sqrt = sess.run(all_ops)
self.assertAllEqual(sqrt[0], sqrt[1])
if __name__ == "__main__":
test.main()

View File

@ -50,6 +50,7 @@ norm = linalg_ops.norm
qr = linalg_ops.qr
set_diag = array_ops.matrix_set_diag
solve = linalg_ops.matrix_solve
sqrtm = linalg_ops.matrix_square_root
svd = linalg_ops.svd
tensordot = math_ops.tensordot
trace = math_ops.trace

View File

@ -55,6 +55,71 @@ def _MatrixDeterminantGrad(op, grad):
  return multipliers * a_adj_inv


@ops.RegisterGradient("MatrixSquareRoot")
def _MatrixSquareRootGrad(op, grad):
  """Gradient for MatrixSquareRoot."""

  # Let A be an m x m square matrix (or batch of matrices)
  # Let R = sqrtm(A)
  # By definition, A = RR
  # Take the differential: dA = d(RR) = R dR + dR R
  # Solve the resulting Sylvester equation for dR

  # Used to find Kronecker products within the Sylvester equation
  def _KroneckerProduct(b1, b2):
    """Computes the Kronecker product of two batches of square matrices."""
    b1_shape = array_ops.shape(b1)
    b2_shape = array_ops.shape(b2)
    b1_order = b1_shape[-1]
    b2_order = b2_shape[-1]

    shape_slice_size = [math_ops.subtract(array_ops.size(b1_shape), 2)]
    shape_slice = array_ops.slice(b1_shape, [0],
                                  shape_slice_size)  # Same for both batches
    b1_reshape_shape = array_ops.concat(
        [shape_slice, [b1_order], [1], [b1_order], [1]], 0)
    b2_reshape_shape = array_ops.concat(
        [shape_slice, [1], [b2_order], [1], [b2_order]], 0)

    b1_reshape = array_ops.reshape(b1, b1_reshape_shape)
    b2_reshape = array_ops.reshape(b2, b2_reshape_shape)

    order_prod = b1_order * b2_order
    kprod_shape = array_ops.concat([shape_slice, [order_prod], [order_prod]],
                                   0)
    return array_ops.reshape(b1_reshape * b2_reshape, kprod_shape)

  sqrtm = op.outputs[0]  # R
  shape = array_ops.shape(sqrtm)
  order = shape[-1]  # m
  matrix_count = math_ops.reduce_prod(shape[0:-2])

  # Get batch of m x m identity matrices
  eye = linalg_ops.eye(order, dtype=sqrtm.dtype)  # m x m identity matrix
  eye_flat = array_ops.reshape(eye, [-1])
  eye_tiled = array_ops.tile(eye_flat, [matrix_count])
  eye_batch = array_ops.reshape(eye_tiled, shape)

  # The transpose of R is taken in the k1 term instead of k2 in
  # order to prevent redundant transposition of R (i.e. (R')' = R)
  sqrtm_transpose = array_ops.matrix_transpose(sqrtm)
  k1 = _KroneckerProduct(eye_batch, sqrtm_transpose)
  k2 = _KroneckerProduct(sqrtm, eye_batch)
  ksum = math_ops.add(k1, k2)

  # Vectorize dA
  shape_slice_size = [math_ops.subtract(array_ops.size(shape), 2)]
  shape_slice = array_ops.slice(shape, [0], shape_slice_size)
  shape_vec_da = array_ops.concat([shape_slice, [order * order], [1]], 0)
  vec_da = array_ops.reshape(array_ops.matrix_transpose(grad), shape_vec_da)

  # Solve for vec(dR)
  vec_dsqrtm = linalg_ops.matrix_solve(ksum, vec_da)

  # Solve for dR by inverse vectorizing vec(dR)
  dsqrtm_transpose = array_ops.reshape(vec_dsqrtm, shape)
  return array_ops.matrix_transpose(dsqrtm_transpose)
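A hedged single-matrix sanity check of the vectorized Sylvester setup above, in plain NumPy (no batching; "(x)" below denotes a Kronecker product):

import numpy as np
from scipy.linalg import sqrtm

a = np.array([[33., 24.], [48., 57.]])
r = sqrtm(a).real  # R, with matmul(R, R) = A
da = np.array([[1e-3, 2e-3], [0.0, 1e-3]])  # a small perturbation dA
# Row-major vec: vec(R dR + dR R) = (R (x) I + I (x) R^T) vec(dR)
ksum = np.kron(r, np.eye(2)) + np.kron(np.eye(2), r.T)
dr = np.linalg.solve(ksum, da.reshape(-1)).reshape(2, 2)
np.testing.assert_allclose(r @ dr + dr @ r, da, atol=1e-12)  # Sylvester holds
np.testing.assert_allclose((r + dr) @ (r + dr), a + da, atol=1e-5)  # first order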
@ops.RegisterGradient("LogMatrixDeterminant")
def _LogMatrixDeterminantGrad(op, _, grad_b):
"""Gradient for LogMatrixDeterminant."""

View File

@ -156,6 +156,10 @@ tf_module {
name: "solve"
argspec: "args=[\'matrix\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "sqrtm"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "svd"
argspec: "args=[\'tensor\', \'full_matrices\', \'compute_uv\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "

View File

@ -1504,6 +1504,10 @@ tf_module {
name: "matrix_solve_ls"
argspec: "args=[\'matrix\', \'rhs\', \'l2_regularizer\', \'fast\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'True\', \'None\'], "
}
member_method {
name: "matrix_square_root"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "matrix_transpose"
argspec: "args=[\'a\', \'name\', \'conjugate\'], varargs=None, keywords=None, defaults=[\'matrix_transpose\', \'False\'], "

View File

@ -156,6 +156,10 @@ tf_module {
name: "solve"
argspec: "args=[\'matrix\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "sqrtm"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "svd"
argspec: "args=[\'tensor\', \'full_matrices\', \'compute_uv\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "

View File

@ -1120,6 +1120,10 @@ tf_module {
name: "matrix_solve"
argspec: "args=[\'matrix\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "matrix_square_root"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "matrix_triangular_solve"
argspec: "args=[\'matrix\', \'rhs\', \'lower\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'None\'], "