From cafe948be40c7883ce116e2516c4c67a2045558b Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Thu, 26 May 2016 08:54:00 -0800
Subject: [PATCH] Add gradients for batch_cholesky.

Change: 123328520
---
 tensorflow/core/kernels/cholesky_grad.cc      | 119 +++++++++---------
 tensorflow/core/ops/linalg_ops.cc             |  31 ++++-
 .../python/kernel_tests/cholesky_op_test.py   |  57 ++++++---
 tensorflow/python/ops/linalg_grad.py          |  11 +-
 tensorflow/python/ops/linalg_ops.py           |   6 +-
 5 files changed, 136 insertions(+), 88 deletions(-)

diff --git a/tensorflow/core/kernels/cholesky_grad.cc b/tensorflow/core/kernels/cholesky_grad.cc
index 4fefcee55e4..7a1c44da426 100644
--- a/tensorflow/core/kernels/cholesky_grad.cc
+++ b/tensorflow/core/kernels/cholesky_grad.cc
@@ -13,75 +13,68 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/framework/op.h"
 #include "third_party/eigen3/Eigen/Core"
-
+#include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
-
-#include "tensorflow/core/kernels/linalg_ops_common.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/binary_linalg_ops_common.h"
 
 namespace tensorflow {
 
-template <typename T>
-class CholeskyGrad : public OpKernel {
+template <typename Scalar, bool SupportsBatchOperationT>
+class CholeskyGrad
+    : public BinaryLinearAlgebraOp<Scalar, SupportsBatchOperationT> {
  public:
-  explicit CholeskyGrad(OpKernelConstruction* context) : OpKernel(context) {}
+  explicit CholeskyGrad(OpKernelConstruction* context)
+      : BinaryLinearAlgebraOp<Scalar, SupportsBatchOperationT>(context) {}
+  ~CholeskyGrad() override {}
+
   using Matrix =
-      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
+      Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
   using ConstMatrixMap = Eigen::Map<const Matrix>;
   using MatrixMap = Eigen::Map<Matrix>;
   using ConstRef = Eigen::Ref<const Matrix>;
   using Ref = Eigen::Ref<Matrix>;
 
-  void Compute(OpKernelContext* context) override {
-    const Tensor& input_tensor_l = context->input(0);
-    const Tensor& input_tensor_grad = context->input(1);
-    // Check that input tensors represent a matrix.
-    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_tensor_l.shape()),
-                errors::InvalidArgument("In[0] is not a matrix"));
-    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_tensor_grad.shape()),
-                errors::InvalidArgument("In[1] is not a matrix"));
-    // Check that input tensors are square.
-    OP_REQUIRES(context,
-                input_tensor_l.dim_size(0) == input_tensor_l.dim_size(1),
-                errors::InvalidArgument("Input matrix must be square."));
-    OP_REQUIRES(context,
-                input_tensor_grad.dim_size(0) == input_tensor_grad.dim_size(1),
-                errors::InvalidArgument("Input matrix must be square."));
+  TensorShape GetOutputMatrixShape(
+      const TensorShape& input_matrix_l_full_shape,
+      const TensorShape& input_matrix_grad_shape) override {
+    return input_matrix_l_full_shape;
+  }
 
-    // Check that input tensors are of same size.
-    OP_REQUIRES(context,
-                input_tensor_l.dim_size(0) == input_tensor_grad.dim_size(0),
-                errors::InvalidArgument("Input matrices must be same size."));
-
-    // Create an output tensor
-    Tensor* output_tensor = NULL;
-    OP_REQUIRES_OK(context, context->allocate_output(
-                                0, input_tensor_grad.shape(), &output_tensor));
-
-    if (output_tensor->NumElements() == 0) {
-      // the output shape is a 0-element matrix, so there is nothing to do.
-      return;
+  int64 GetCostPerUnit(const TensorShape& input_matrix_shape,
+                       const TensorShape& rhs_matrix_shape) override {
+    const int64 rows = input_matrix_shape.dim_size(0);
+    if (rows > (1LL << 20)) {
+      // A big number to cap the cost in case overflow.
+      return kint64max;
+    } else {
+      return rows * rows * rows;
     }
-    // The next lines are necessary to get Eigen matrix behaviour.
-    const ConstMatrixMap input_matrix_l_full(input_tensor_l.flat<T>().data(),
-                                             input_tensor_l.dim_size(0),
-                                             input_tensor_l.dim_size(1));
-    const ConstMatrixMap input_matrix_grad(input_tensor_grad.flat<T>().data(),
-                                           input_tensor_grad.dim_size(0),
-                                           input_tensor_grad.dim_size(1));
-    MatrixMap output_matrix(output_tensor->template flat<T>().data(),
-                            input_tensor_l.dim_size(0),
-                            input_tensor_l.dim_size(1));
+  }
 
-    // Algorithm only depends on lower triangular half on input_tensor_l.
+  void ComputeMatrix(OpKernelContext* context,
+                     const ConstMatrixMap& input_matrix_l_full,
+                     const ConstMatrixMap& input_matrix_grad,
+                     MatrixMap* output_matrix) override {
+    OP_REQUIRES(context,
+                input_matrix_l_full.rows() == input_matrix_l_full.cols(),
+                errors::InvalidArgument("Input matrix must be square."));
+    OP_REQUIRES(
+        context, input_matrix_l_full.cols() == input_matrix_grad.cols(),
+        errors::InvalidArgument(
+            "Input matrix and gradient must have same number of cols."));
+    OP_REQUIRES(
+        context, input_matrix_l_full.rows() == input_matrix_grad.rows(),
+        errors::InvalidArgument(
+            "Input matrix and gradient must have same number of rows."));
+
+    // Algorithm only depends on lower triangular half on input_matrix_l.
     const Matrix input_matrix_l =
         input_matrix_l_full.template triangularView<Eigen::Lower>();
     // Algorithm only depends on lower triangular half on input_matrix_grad.
-    output_matrix = input_matrix_grad.template triangularView<Eigen::Lower>();
+    *output_matrix = input_matrix_grad.template triangularView<Eigen::Lower>();
 
     const int64 kMatrixSize = input_matrix_l.rows();
     const int64 kMaxBlockSize = 32;
@@ -104,20 +97,21 @@ class CholeskyGrad : public OpKernel {
 
       auto B = input_matrix_l.block(block_end, 0, trailing_size, block_begin);
       auto B_bar =
-          output_matrix.block(block_end, 0, trailing_size, block_begin);
+          output_matrix->block(block_end, 0, trailing_size, block_begin);
 
       auto C = input_matrix_l.block(block_end, block_begin, trailing_size,
                                     block_size);
-      auto C_bar = output_matrix.block(block_end, block_begin, trailing_size,
-                                       block_size);
+      auto C_bar = output_matrix->block(block_end, block_begin, trailing_size,
+                                        block_size);
 
      auto D = input_matrix_l.block(block_begin, block_begin, block_size,
                                    block_size);
-      auto D_bar =
-          output_matrix.block(block_begin, block_begin, block_size, block_size);
+      auto D_bar = output_matrix->block(block_begin, block_begin, block_size,
+                                        block_size);
 
       auto R = input_matrix_l.block(block_begin, 0, block_size, block_begin);
-      auto R_bar = output_matrix.block(block_begin, 0, block_size, block_begin);
+      auto R_bar =
+          output_matrix->block(block_begin, 0, block_size, block_begin);
 
       C_bar = D.adjoint().template triangularView<Eigen::Upper>()
                   .solve(C_bar.adjoint()).adjoint();
@@ -127,9 +121,11 @@ class CholeskyGrad : public OpKernel {
       CholeskyGradUnblocked(D, D_bar);
       R_bar -= (D_bar + D_bar.adjoint()) * R;
     }
-    output_matrix = (0.5 * (output_matrix + output_matrix.transpose())).eval();
+    *output_matrix =
+        (0.5 * (*output_matrix + output_matrix->transpose())).eval();
   }
-  void CholeskyGradUnblocked(const ConstRef l_block, Ref grad_block) {
+
+  void CholeskyGradUnblocked(const ConstRef& l_block, Ref grad_block) {
     const int64 kMatrixSize = l_block.rows();
     for (int64 k = kMatrixSize - 1; k >= 0; k--) {
       /* This shows the block structure.
@@ -166,6 +162,11 @@ class CholeskyGrad : public OpKernel {
   }
 };
 
-REGISTER_LINALG_OP("CholeskyGrad", (CholeskyGrad<float>), float);
-REGISTER_LINALG_OP("CholeskyGrad", (CholeskyGrad<double>), double);
+REGISTER_BINARY_LINALG_OP("CholeskyGrad", (CholeskyGrad<float, false>), float);
+REGISTER_BINARY_LINALG_OP("CholeskyGrad", (CholeskyGrad<double, false>),
+                          double);
+REGISTER_BINARY_LINALG_OP("BatchCholeskyGrad", (CholeskyGrad<float, true>),
+                          float);
+REGISTER_BINARY_LINALG_OP("BatchCholeskyGrad", (CholeskyGrad<double, true>),
+                          double);
 }  // namespace tensorflow
diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc
index edd60df8ef9..be87022c0a8 100644
--- a/tensorflow/core/ops/linalg_ops.cc
+++ b/tensorflow/core/ops/linalg_ops.cc
@@ -129,11 +129,34 @@ REGISTER_OP("CholeskyGrad")
     .Doc(R"doc(
 Calculates the reverse mode backpropagated gradient of the Cholesky algorithm.
 
-For an explanation see "Differentiation of the Cholesky algorithm" by Iain Murray http://arxiv.org/abs/1602.07527.
+For an explanation see "Differentiation of the Cholesky algorithm" by
+Iain Murray http://arxiv.org/abs/1602.07527.
 
-l: Output of Cholesky algorithm l = chol(A). Shape is `[M, M]`. Algorithm depends only on lower triangular part of this matrix.
-grad: df/dl where f is some scalar function. Shape is `[M, M]'. Algorithm depends only on lower triangular part of this matrix.
-output: Symmetrized version of df/dA . Shape is `[M, M]'
+l: Output of Cholesky algorithm l = chol(A). Shape is `[M, M]`.
+  Algorithm depends only on lower triangular part of this matrix.
+grad: df/dl where f is some scalar function. Shape is `[M, M]'.
+  Algorithm depends only on lower triangular part of this matrix.
+output: Symmetrized version of df/dA . Shape is `[M, M]'.
+)doc");
+
+REGISTER_OP("BatchCholeskyGrad")
+    .Input("l: T")
+    .Input("grad: T")
+    .Output("output: T")
+    .Attr("T: {float, double}")
+    .Doc(R"doc(
+Calculates the reverse mode backpropagated gradient of the Cholesky algorithm.
+
+For an explanation see "Differentiation of the Cholesky algorithm" by
+Iain Murray http://arxiv.org/abs/1602.07527.
+
+l: Output of batch Cholesky algorithm l = batch_cholesky(A). Shape is `[..., M, M]`.
+  Algorithm depends only on lower triangular part of the innermost matrices of
+  this tensor.
+grad: df/dl where f is some scalar function. Shape is `[..., M, M]'.
+  Algorithm depends only on lower triangular part of the innermost matrices of
+  this tensor.
+output: Symmetrized version of df/dA . Shape is `[..., M, M]'
 )doc");
 
 REGISTER_OP("SelfAdjointEig")
diff --git a/tensorflow/python/kernel_tests/cholesky_op_test.py b/tensorflow/python/kernel_tests/cholesky_op_test.py
index 199b54512e0..0189280e248 100644
--- a/tensorflow/python/kernel_tests/cholesky_op_test.py
+++ b/tensorflow/python/kernel_tests/cholesky_op_test.py
@@ -71,18 +71,23 @@ class CholeskyOpTest(tf.test.TestCase):
   def testNonSquareMatrix(self):
     with self.assertRaises(ValueError):
       tf.cholesky(np.array([[1., 2., 3.], [3., 4., 5.]]))
+    with self.assertRaises(ValueError):
+      tf.batch_cholesky(np.array([[[1., 2., 3.], [3., 4., 5.]],
+                                  [[1., 2., 3.], [3., 4., 5.]]]))
 
   def testWrongDimensions(self):
     tensor3 = tf.constant([1., 2.])
     with self.assertRaises(ValueError):
       tf.cholesky(tensor3)
+    with self.assertRaises(ValueError):
+      tf.batch_cholesky(tensor3)
 
   def testNotInvertible(self):
-    # The input should be invertible.
+    # The input should be invertible.
     with self.test_session():
       with self.assertRaisesOpError("LLT decomposition was not successful. The"
                                     " input might not be valid."):
-        # All rows of the matrix below add to zero
+        # All rows of the matrix below add to zero
         self._verifyCholesky(np.array([[1., -1., 0.], [-1., 1., -1.],
                                        [0., -1., 1.]]))
 
@@ -122,24 +127,36 @@ class CholeskyGradTest(tf.test.TestCase):
                   scalarTest=False):
     with self.test_session(use_gpu=False):
       for shape in shapes:
-        for dtype in dtypes:
-          if not(scalarTest):
-            x = tf.constant(np.random.randn(shape[0], shape[1]), dtype)
-            K = tf.matmul(x, tf.transpose(x)) / shape[0]  # K is posdef
-            y = tf.cholesky(K)
-          else:  # This is designed to be a faster test for larger matrices.
-            x = tf.constant(np.random.randn(), dtype)
-            R = tf.constant(np.random.randn(shape[0], shape[1]), dtype)
-            e = tf.mul(R, x)
-            K = tf.matmul(e, tf.transpose(e)) / shape[0]  # K is posdef
-            y = tf.reduce_mean(tf.cholesky(K))
-          error = tf.test.compute_gradient_error(x, x._shape_as_list(),
-                                                 y, y._shape_as_list())
-          tf.logging.info("error = %f", error)
-          if dtype == tf.float64:
-            self.assertLess(error, 1e-5)
-          else:
-            self.assertLess(error, 2e-3)
+        for batch in False, True:
+          for dtype in dtypes:
+            if not scalarTest:
+              x = tf.constant(np.random.randn(shape[0], shape[1]), dtype)
+              tensor = tf.matmul(x, tf.transpose(x)) / shape[0]
+            else:
+              # This is designed to be a faster test for larger matrices.
+              x = tf.constant(np.random.randn(), dtype)
+              R = tf.constant(np.random.randn(shape[0], shape[1]), dtype)
+              e = tf.mul(R, x)
+              tensor = tf.matmul(e, tf.transpose(e)) / shape[0]
+
+            # Inner-most matrices in tensor are positive definite.
+            if batch:
+              tensor = tf.tile(tf.expand_dims(tensor, 0), [4, 1, 1])
+              op = tf.batch_cholesky
+            else:
+              op = tf.cholesky
+
+            if not (scalarTest):
+              y = op(tensor)
+            else:
+              y = tf.reduce_mean(op(tensor))
+            error = tf.test.compute_gradient_error(x, x._shape_as_list(), y,
+                                                   y._shape_as_list())
+            tf.logging.info("error = %f", error)
+            if dtype == tf.float64:
+              self.assertLess(error, 1e-5)
+            else:
+              self.assertLess(error, 3e-3)
 
 if __name__ == "__main__":
   tf.test.main()
diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py
index 92911009eb7..36c82278584 100644
--- a/tensorflow/python/ops/linalg_grad.py
+++ b/tensorflow/python/ops/linalg_grad.py
@@ -32,6 +32,9 @@ from tensorflow.python.ops import constant_op
 from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
 
+ops.NoGradient("CholeskyGrad")
+ops.NoGradient("BatchCholeskyGrad")
+
 
 @ops.RegisterGradient("MatrixInverse")
 def _MatrixInverseGrad(op, grad):
@@ -76,11 +79,17 @@ def _BatchMatrixDeterminantGrad(op, grad):
 
 
 @ops.RegisterGradient("Cholesky")
-def _cholesky_grad(op, grad):
+def _CholeskyGrad(op, grad):
   """Gradient for Cholesky."""
   return linalg_ops.cholesky_grad(op.outputs[0], grad)
 
 
+@ops.RegisterGradient("BatchCholesky")
+def _BatchCholeskyGrad(op, grad):
+  """Gradient for BatchCholesky."""
+  return linalg_ops.batch_cholesky_grad(op.outputs[0], grad)
+
+
 @ops.RegisterGradient("MatrixSolve")
 def _MatrixSolveGrad(op, grad):
   """Gradients for MatrixSolve."""
diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py
index 983851e09e4..66a64a998ba 100644
--- a/tensorflow/python/ops/linalg_ops.py
+++ b/tensorflow/python/ops/linalg_ops.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops.gen_linalg_ops import *
 
 
 @ops.RegisterShape("Cholesky")
+@ops.RegisterShape("CholeskyGrad")
 @ops.RegisterShape("MatrixInverse")
 def _UnchangedSquare(op):
   input_shape = op.inputs[0].get_shape().with_rank(2)
@@ -37,6 +38,7 @@
 
 
 @ops.RegisterShape("BatchCholesky")
+@ops.RegisterShape("BatchCholeskyGrad")
 @ops.RegisterShape("BatchMatrixInverse")
 def _BatchUnchangedSquare(op):
   input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
@@ -44,10 +46,6 @@ def _BatchUnchangedSquare(op):
   input_shape[-1].assert_is_compatible_with(input_shape[-2])
   return [input_shape]
 
-@ops.RegisterShape("CholeskyGrad")
-def _cholesky_grad_shape(op):
-  return [op.inputs[0].get_shape()]
-
 @ops.RegisterShape("MatrixDeterminant")
 def _MatrixDeterminantShape(op):
   input_shape = op.inputs[0].get_shape().with_rank(2)
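
Taken together, the hunks above register a `BatchCholeskyGrad` op and kernel, hook it up as the gradient of `BatchCholesky` through `linalg_ops.batch_cholesky_grad`, and reuse the batch shape function for the new op. The snippet below is a usage sketch only, not part of the patch: it assumes the TF 0.x-era Python API exercised in the test above (`tf.batch_cholesky`, `tf.gradients`), and all variable names are illustrative.

```python
# Illustrative sketch, not part of this patch. Assumes the TF 0.x API used in
# cholesky_op_test.py above (tf.batch_cholesky, tf.gradients).
import numpy as np
import tensorflow as tf

# Batch of 4 symmetric positive-definite 3x3 matrices: A = X X^T + eps * I.
x = np.random.randn(4, 3, 3)
a = tf.constant(np.matmul(x, x.transpose(0, 2, 1)) + 1e-3 * np.eye(3),
                dtype=tf.float64)

l = tf.batch_cholesky(a)     # [4, 3, 3] lower-triangular factors
loss = tf.reduce_sum(l)      # any scalar function of the factors
# Backprop now routes through the "BatchCholesky" gradient registered above,
# which invokes the new BatchCholeskyGrad kernel on each matrix in the batch.
grad_a = tf.gradients(loss, [a])[0]

with tf.Session() as sess:
  print(sess.run(grad_a).shape)  # (4, 3, 3): symmetrized df/dA per member
```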