Add gradients for batch_cholesky.

Change: 123328520
A. Unique TensorFlower 2016-05-26 08:54:00 -08:00 committed by TensorFlower Gardener
parent c1b5075d1f
commit cafe948be4
5 changed files with 136 additions and 88 deletions
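For context, a minimal sketch (not part of this change) of what the new gradient registration enables, assuming the 2016-era session API and the tf.batch_cholesky / tf.gradients names that appear in the diffs below:

    import numpy as np
    import tensorflow as tf

    # A batch of two symmetric positive-definite 3x3 matrices.
    x = np.random.randn(2, 3, 3)
    a = tf.constant(np.matmul(x, np.transpose(x, [0, 2, 1])) + 3 * np.eye(3))

    l = tf.batch_cholesky(a)     # per-matrix Cholesky factors, shape [2, 3, 3]
    loss = tf.reduce_sum(l)      # any scalar function of the factors
    # Before this commit BatchCholesky had no registered gradient; after it,
    # this call routes through the new BatchCholeskyGrad op.
    (grad_a,) = tf.gradients(loss, a)

    with tf.Session() as sess:
      print(sess.run(grad_a).shape)   # (2, 3, 3), symmetrized d(loss)/dA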

View File

@@ -13,75 +13,68 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/core/framework/op.h"
 #include "third_party/eigen3/Eigen/Core"
+#include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/kernels/linalg_ops_common.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/binary_linalg_ops_common.h"

 namespace tensorflow {

-template <typename T>
-class CholeskyGrad : public OpKernel {
+template <typename Scalar, bool SupportsBatchOperationT>
+class CholeskyGrad
+    : public BinaryLinearAlgebraOp<Scalar, SupportsBatchOperationT> {
  public:
-  explicit CholeskyGrad(OpKernelConstruction* context) : OpKernel(context) {}
-  ~CholeskyGrad() override {}
+  explicit CholeskyGrad(OpKernelConstruction* context)
+      : BinaryLinearAlgebraOp<Scalar, SupportsBatchOperationT>(context) {}

   using Matrix =
-      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
+      Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
   using ConstMatrixMap = Eigen::Map<const Matrix>;
   using MatrixMap = Eigen::Map<Matrix>;
   using ConstRef = Eigen::Ref<const Matrix>;
   using Ref = Eigen::Ref<Matrix>;

-  void Compute(OpKernelContext* context) override {
-    const Tensor& input_tensor_l = context->input(0);
-    const Tensor& input_tensor_grad = context->input(1);
-    // Check that input tensors represent a matrix.
-    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_tensor_l.shape()),
-                errors::InvalidArgument("In[0] is not a matrix"));
-    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_tensor_grad.shape()),
-                errors::InvalidArgument("In[1] is not a matrix"));
-    // Check that input tensors are square.
-    OP_REQUIRES(context,
-                input_tensor_l.dim_size(0) == input_tensor_l.dim_size(1),
-                errors::InvalidArgument("Input matrix must be square."));
-    OP_REQUIRES(context,
-                input_tensor_grad.dim_size(0) == input_tensor_grad.dim_size(1),
-                errors::InvalidArgument("Input matrix must be square."));
-    // Check that input tensors are of same size.
-    OP_REQUIRES(context,
-                input_tensor_l.dim_size(0) == input_tensor_grad.dim_size(0),
-                errors::InvalidArgument("Input matrices must be same size."));
-
-    // Create an output tensor
-    Tensor* output_tensor = NULL;
-    OP_REQUIRES_OK(context, context->allocate_output(
-                                0, input_tensor_grad.shape(), &output_tensor));
-
-    if (output_tensor->NumElements() == 0) {
-      // the output shape is a 0-element matrix, so there is nothing to do.
-      return;
+  TensorShape GetOutputMatrixShape(
+      const TensorShape& input_matrix_l_full_shape,
+      const TensorShape& input_matrix_grad_shape) override {
+    return input_matrix_l_full_shape;
+  }
+
+  int64 GetCostPerUnit(const TensorShape& input_matrix_shape,
+                       const TensorShape& rhs_matrix_shape) override {
+    const int64 rows = input_matrix_shape.dim_size(0);
+    if (rows > (1LL << 20)) {
+      // A big number to cap the cost in case overflow.
+      return kint64max;
+    } else {
+      return rows * rows * rows;
     }
-    // The next lines are necessary to get Eigen matrix behaviour.
-    const ConstMatrixMap input_matrix_l_full(input_tensor_l.flat<T>().data(),
-                                             input_tensor_l.dim_size(0),
-                                             input_tensor_l.dim_size(1));
-    const ConstMatrixMap input_matrix_grad(input_tensor_grad.flat<T>().data(),
-                                           input_tensor_grad.dim_size(0),
-                                           input_tensor_grad.dim_size(1));
-    MatrixMap output_matrix(output_tensor->template flat<T>().data(),
-                            input_tensor_l.dim_size(0),
-                            input_tensor_l.dim_size(1));
+  }

-    // Algorithm only depends on lower triangular half on input_tensor_l.
+  void ComputeMatrix(OpKernelContext* context,
+                     const ConstMatrixMap& input_matrix_l_full,
+                     const ConstMatrixMap& input_matrix_grad,
+                     MatrixMap* output_matrix) override {
+    OP_REQUIRES(context,
+                input_matrix_l_full.rows() == input_matrix_l_full.cols(),
+                errors::InvalidArgument("Input matrix must be square."));
+    OP_REQUIRES(
+        context, input_matrix_l_full.cols() == input_matrix_grad.cols(),
+        errors::InvalidArgument(
+            "Input matrix and gradient must have same number of cols."));
+    OP_REQUIRES(
+        context, input_matrix_l_full.rows() == input_matrix_grad.rows(),
+        errors::InvalidArgument(
+            "Input matrix and gradient must have same number of rows."));
+    // Algorithm only depends on lower triangular half on input_matrix_l.
     const Matrix input_matrix_l =
         input_matrix_l_full.template triangularView<Eigen::Lower>();
     // Algorithm only depends on lower triangular half on input_matrix_grad.
-    output_matrix = input_matrix_grad.template triangularView<Eigen::Lower>();
+    *output_matrix = input_matrix_grad.template triangularView<Eigen::Lower>();

     const int64 kMatrixSize = input_matrix_l.rows();
     const int64 kMaxBlockSize = 32;
@@ -104,20 +97,21 @@ class CholeskyGrad : public OpKernel {
       auto B = input_matrix_l.block(block_end, 0, trailing_size, block_begin);
       auto B_bar =
-          output_matrix.block(block_end, 0, trailing_size, block_begin);
+          output_matrix->block(block_end, 0, trailing_size, block_begin);

       auto C = input_matrix_l.block(block_end, block_begin, trailing_size,
                                     block_size);
-      auto C_bar = output_matrix.block(block_end, block_begin, trailing_size,
-                                       block_size);
+      auto C_bar = output_matrix->block(block_end, block_begin, trailing_size,
+                                        block_size);

       auto D = input_matrix_l.block(block_begin, block_begin, block_size,
                                     block_size);
-      auto D_bar =
-          output_matrix.block(block_begin, block_begin, block_size, block_size);
+      auto D_bar = output_matrix->block(block_begin, block_begin, block_size,
+                                        block_size);

       auto R = input_matrix_l.block(block_begin, 0, block_size, block_begin);
-      auto R_bar = output_matrix.block(block_begin, 0, block_size, block_begin);
+      auto R_bar =
+          output_matrix->block(block_begin, 0, block_size, block_begin);

       C_bar = D.adjoint().template triangularView<Eigen::Upper>()
                   .solve(C_bar.adjoint()).adjoint();
@@ -127,9 +121,11 @@ class CholeskyGrad : public OpKernel {
       CholeskyGradUnblocked(D, D_bar);
       R_bar -= (D_bar + D_bar.adjoint()) * R;
     }
-    output_matrix = (0.5 * (output_matrix + output_matrix.transpose())).eval();
+    *output_matrix =
+        (0.5 * (*output_matrix + output_matrix->transpose())).eval();
   }
-  void CholeskyGradUnblocked(const ConstRef l_block, Ref grad_block) {
+
+  void CholeskyGradUnblocked(const ConstRef& l_block, Ref grad_block) {
     const int64 kMatrixSize = l_block.rows();
     for (int64 k = kMatrixSize - 1; k >= 0; k--) {
       /* This shows the block structure.
@@ -166,6 +162,11 @@ class CholeskyGrad : public OpKernel {
   }
 };

-REGISTER_LINALG_OP("CholeskyGrad", (CholeskyGrad<float>), float);
-REGISTER_LINALG_OP("CholeskyGrad", (CholeskyGrad<double>), double);
+REGISTER_BINARY_LINALG_OP("CholeskyGrad", (CholeskyGrad<float, false>), float);
+REGISTER_BINARY_LINALG_OP("CholeskyGrad", (CholeskyGrad<double, false>),
+                          double);
+REGISTER_BINARY_LINALG_OP("BatchCholeskyGrad", (CholeskyGrad<float, true>),
+                          float);
+REGISTER_BINARY_LINALG_OP("BatchCholeskyGrad", (CholeskyGrad<double, true>),
+                          double);
 } // namespace tensorflow
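For reference (editor's note, not part of the commit): the quantity this kernel produces is the standard reverse-mode Cholesky pullback from Murray's paper. With L = chol(A), \bar{L} = df/dL (only its lower triangle is used), and \Phi(X) denoting the lower triangle of X with its diagonal halved,

\bar{A} = \tfrac{1}{2}\Big( L^{-\top}\,\Phi\big(L^{\top}\bar{L}\big)\,L^{-1} + \big(L^{-\top}\,\Phi\big(L^{\top}\bar{L}\big)\,L^{-1}\big)^{\top} \Big).

The blocked algorithm above computes this without forming L^{-1} explicitly; its final statement, *output_matrix = 0.5 * (*output_matrix + output_matrix->transpose()), is the symmetrization in this formula.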

View File

@@ -129,11 +129,34 @@ REGISTER_OP("CholeskyGrad")
     .Doc(R"doc(
 Calculates the reverse mode backpropagated gradient of the Cholesky algorithm.
-For an explanation see "Differentiation of the Cholesky algorithm" by Iain Murray http://arxiv.org/abs/1602.07527.
+For an explanation see "Differentiation of the Cholesky algorithm" by
+Iain Murray http://arxiv.org/abs/1602.07527.

-l: Output of Cholesky algorithm l = chol(A). Shape is `[M, M]`. Algorithm depends only on lower triangular part of this matrix.
-grad: df/dl where f is some scalar function. Shape is `[M, M]'. Algorithm depends only on lower triangular part of this matrix.
-output: Symmetrized version of df/dA . Shape is `[M, M]'
+l: Output of Cholesky algorithm l = chol(A). Shape is `[M, M]`.
+Algorithm depends only on lower triangular part of this matrix.
+grad: df/dl where f is some scalar function. Shape is `[M, M]'.
+Algorithm depends only on lower triangular part of this matrix.
+output: Symmetrized version of df/dA . Shape is `[M, M]'.
+)doc");
+
+REGISTER_OP("BatchCholeskyGrad")
+    .Input("l: T")
+    .Input("grad: T")
+    .Output("output: T")
+    .Attr("T: {float, double}")
+    .Doc(R"doc(
+Calculates the reverse mode backpropagated gradient of the Cholesky algorithm.
+For an explanation see "Differentiation of the Cholesky algorithm" by
+Iain Murray http://arxiv.org/abs/1602.07527.
+
+l: Output of batch Cholesky algorithm l = batch_cholesky(A). Shape is `[..., M, M]`.
+Algorithm depends only on lower triangular part of the innermost matrices of
+this tensor.
+grad: df/dl where f is some scalar function. Shape is `[..., M, M]'.
+Algorithm depends only on lower triangular part of the innermost matrices of
+this tensor.
+output: Symmetrized version of df/dA . Shape is `[..., M, M]'
 )doc");

 REGISTER_OP("SelfAdjointEig")
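A small NumPy cross-check of what these op docs describe, based on the closed-form pullback noted above (editor's sketch under that assumption, not part of the commit); BatchCholeskyGrad applies the same single-matrix computation to each innermost matrix:

    import numpy as np

    def phi(x):
      # Lower triangle of x with the diagonal halved.
      return np.tril(x) - 0.5 * np.diag(np.diag(x))

    def cholesky_grad_reference(l, grad):
      """Symmetrized df/dA from l = chol(A) and grad = df/dl (one [M, M] matrix)."""
      p = phi(np.dot(l.T, np.tril(grad)))
      a_bar = np.linalg.solve(l.T, p)           # L^{-T} * phi(L^T grad)
      a_bar = np.linalg.solve(l.T, a_bar.T).T   # ... * L^{-1}
      return 0.5 * (a_bar + a_bar.T)            # symmetrized, as documented

    def batch_cholesky_grad_reference(l, grad):
      # 3-D batch: apply the single-matrix formula to each innermost matrix.
      return np.stack([cholesky_grad_reference(li, gi) for li, gi in zip(l, grad)])

On small random positive-definite inputs this should agree with the CholeskyGrad / BatchCholeskyGrad kernels up to round-off, which makes it a convenient sanity check alongside the finite-difference tests below.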

View File

@@ -71,18 +71,23 @@ class CholeskyOpTest(tf.test.TestCase):
   def testNonSquareMatrix(self):
     with self.assertRaises(ValueError):
       tf.cholesky(np.array([[1., 2., 3.], [3., 4., 5.]]))
+    with self.assertRaises(ValueError):
+      tf.batch_cholesky(np.array([[[1., 2., 3.], [3., 4., 5.]],
+                                  [[1., 2., 3.], [3., 4., 5.]]]))

   def testWrongDimensions(self):
     tensor3 = tf.constant([1., 2.])
     with self.assertRaises(ValueError):
       tf.cholesky(tensor3)
+    with self.assertRaises(ValueError):
+      tf.batch_cholesky(tensor3)

   def testNotInvertible(self):
     # The input should be invertible.
     with self.test_session():
       with self.assertRaisesOpError("LLT decomposition was not successful. The"
                                     " input might not be valid."):
         # All rows of the matrix below add to zero
         self._verifyCholesky(np.array([[1., -1., 0.], [-1., 1., -1.], [0., -1.,
                                         1.]]))
@@ -122,24 +127,36 @@ class CholeskyGradTest(tf.test.TestCase):
                    scalarTest=False):
     with self.test_session(use_gpu=False):
       for shape in shapes:
-        for dtype in dtypes:
-          if not(scalarTest):
-            x = tf.constant(np.random.randn(shape[0], shape[1]), dtype)
-            K = tf.matmul(x, tf.transpose(x)) / shape[0]  # K is posdef
-            y = tf.cholesky(K)
-          else:  # This is designed to be a faster test for larger matrices.
-            x = tf.constant(np.random.randn(), dtype)
-            R = tf.constant(np.random.randn(shape[0], shape[1]), dtype)
-            e = tf.mul(R, x)
-            K = tf.matmul(e, tf.transpose(e)) / shape[0]  # K is posdef
-            y = tf.reduce_mean(tf.cholesky(K))
-          error = tf.test.compute_gradient_error(x, x._shape_as_list(),
-                                                 y, y._shape_as_list())
-          tf.logging.info("error = %f", error)
-          if dtype == tf.float64:
-            self.assertLess(error, 1e-5)
-          else:
-            self.assertLess(error, 2e-3)
+        for batch in False, True:
+          for dtype in dtypes:
+            if not scalarTest:
+              x = tf.constant(np.random.randn(shape[0], shape[1]), dtype)
+              tensor = tf.matmul(x, tf.transpose(x)) / shape[0]
+            else:
+              # This is designed to be a faster test for larger matrices.
+              x = tf.constant(np.random.randn(), dtype)
+              R = tf.constant(np.random.randn(shape[0], shape[1]), dtype)
+              e = tf.mul(R, x)
+              tensor = tf.matmul(e, tf.transpose(e)) / shape[0]
+
+            # Inner-most matrices in tensor are positive definite.
+            if batch:
+              tensor = tf.tile(tf.expand_dims(tensor, 0), [4, 1, 1])
+              op = tf.batch_cholesky
+            else:
+              op = tf.cholesky
+            if not (scalarTest):
+              y = op(tensor)
+            else:
+              y = tf.reduce_mean(op(tensor))
+            error = tf.test.compute_gradient_error(x, x._shape_as_list(), y,
+                                                    y._shape_as_list())
+            tf.logging.info("error = %f", error)
+            if dtype == tf.float64:
+              self.assertLess(error, 1e-5)
+            else:
+              self.assertLess(error, 3e-3)

 if __name__ == "__main__":
   tf.test.main()

View File

@@ -32,6 +32,9 @@ from tensorflow.python.ops import constant_op
 from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops

+ops.NoGradient("CholeskyGrad")
+ops.NoGradient("BatchCholeskyGrad")
+

 @ops.RegisterGradient("MatrixInverse")
 def _MatrixInverseGrad(op, grad):
@@ -76,11 +79,17 @@ def _BatchMatrixDeterminantGrad(op, grad):
 @ops.RegisterGradient("Cholesky")
-def _cholesky_grad(op, grad):
+def _CholeskyGrad(op, grad):
   """Gradient for Cholesky."""
   return linalg_ops.cholesky_grad(op.outputs[0], grad)

+
+@ops.RegisterGradient("BatchCholesky")
+def _BatchCholeskyGrad(op, grad):
+  """Gradient for BatchCholesky."""
+  return linalg_ops.batch_cholesky_grad(op.outputs[0], grad)
+

 @ops.RegisterGradient("MatrixSolve")
 def _MatrixSolveGrad(op, grad):
   """Gradients for MatrixSolve."""

View File

@@ -28,6 +28,7 @@ from tensorflow.python.ops.gen_linalg_ops import *
 @ops.RegisterShape("Cholesky")
+@ops.RegisterShape("CholeskyGrad")
 @ops.RegisterShape("MatrixInverse")
 def _UnchangedSquare(op):
   input_shape = op.inputs[0].get_shape().with_rank(2)
@@ -37,6 +38,7 @@ def _UnchangedSquare(op):
 @ops.RegisterShape("BatchCholesky")
+@ops.RegisterShape("BatchCholeskyGrad")
 @ops.RegisterShape("BatchMatrixInverse")
 def _BatchUnchangedSquare(op):
   input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
@@ -44,10 +46,6 @@ def _BatchUnchangedSquare(op):
   input_shape[-1].assert_is_compatible_with(input_shape[-2])
   return [input_shape]

-@ops.RegisterShape("CholeskyGrad")
-def _cholesky_grad_shape(op):
-  return [op.inputs[0].get_shape()]
-
 @ops.RegisterShape("MatrixDeterminant")
 def _MatrixDeterminantShape(op):
   input_shape = op.inputs[0].get_shape().with_rank(2)