Add gradients for batch_cholesky.
Change: 123328520
This commit is contained in:
parent
c1b5075d1f
commit
cafe948be4
tensorflow
core
python
@ -13,75 +13,68 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
|
||||
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/binary_linalg_ops_common.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
template <typename T>
|
||||
class CholeskyGrad : public OpKernel {
|
||||
template <typename Scalar, bool SupportsBatchOperationT>
|
||||
class CholeskyGrad
|
||||
: public BinaryLinearAlgebraOp<Scalar, SupportsBatchOperationT> {
|
||||
public:
|
||||
explicit CholeskyGrad(OpKernelConstruction* context) : OpKernel(context) {}
|
||||
explicit CholeskyGrad(OpKernelConstruction* context)
|
||||
: BinaryLinearAlgebraOp<Scalar, SupportsBatchOperationT>(context) {}
|
||||
~CholeskyGrad() override {}
|
||||
|
||||
using Matrix =
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
|
||||
Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
|
||||
using ConstMatrixMap = Eigen::Map<const Matrix>;
|
||||
using MatrixMap = Eigen::Map<Matrix>;
|
||||
using ConstRef = Eigen::Ref<const Matrix>;
|
||||
using Ref = Eigen::Ref<Matrix>;
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor& input_tensor_l = context->input(0);
|
||||
const Tensor& input_tensor_grad = context->input(1);
|
||||
// Check that input tensors represent a matrix.
|
||||
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_tensor_l.shape()),
|
||||
errors::InvalidArgument("In[0] is not a matrix"));
|
||||
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_tensor_grad.shape()),
|
||||
errors::InvalidArgument("In[1] is not a matrix"));
|
||||
// Check that input tensors are square.
|
||||
OP_REQUIRES(context,
|
||||
input_tensor_l.dim_size(0) == input_tensor_l.dim_size(1),
|
||||
errors::InvalidArgument("Input matrix must be square."));
|
||||
OP_REQUIRES(context,
|
||||
input_tensor_grad.dim_size(0) == input_tensor_grad.dim_size(1),
|
||||
errors::InvalidArgument("Input matrix must be square."));
|
||||
TensorShape GetOutputMatrixShape(
|
||||
const TensorShape& input_matrix_l_full_shape,
|
||||
const TensorShape& input_matrix_grad_shape) override {
|
||||
return input_matrix_l_full_shape;
|
||||
}
|
||||
|
||||
// Check that input tensors are of same size.
|
||||
OP_REQUIRES(context,
|
||||
input_tensor_l.dim_size(0) == input_tensor_grad.dim_size(0),
|
||||
errors::InvalidArgument("Input matrices must be same size."));
|
||||
|
||||
// Create an output tensor
|
||||
Tensor* output_tensor = NULL;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(
|
||||
0, input_tensor_grad.shape(), &output_tensor));
|
||||
|
||||
if (output_tensor->NumElements() == 0) {
|
||||
// the output shape is a 0-element matrix, so there is nothing to do.
|
||||
return;
|
||||
int64 GetCostPerUnit(const TensorShape& input_matrix_shape,
|
||||
const TensorShape& rhs_matrix_shape) override {
|
||||
const int64 rows = input_matrix_shape.dim_size(0);
|
||||
if (rows > (1LL << 20)) {
|
||||
// A big number to cap the cost in case overflow.
|
||||
return kint64max;
|
||||
} else {
|
||||
return rows * rows * rows;
|
||||
}
|
||||
// The next lines are necessary to get Eigen matrix behaviour.
|
||||
const ConstMatrixMap input_matrix_l_full(input_tensor_l.flat<T>().data(),
|
||||
input_tensor_l.dim_size(0),
|
||||
input_tensor_l.dim_size(1));
|
||||
const ConstMatrixMap input_matrix_grad(input_tensor_grad.flat<T>().data(),
|
||||
input_tensor_grad.dim_size(0),
|
||||
input_tensor_grad.dim_size(1));
|
||||
MatrixMap output_matrix(output_tensor->template flat<T>().data(),
|
||||
input_tensor_l.dim_size(0),
|
||||
input_tensor_l.dim_size(1));
|
||||
}
|
||||
|
||||
// Algorithm only depends on lower triangular half on input_tensor_l.
|
||||
void ComputeMatrix(OpKernelContext* context,
|
||||
const ConstMatrixMap& input_matrix_l_full,
|
||||
const ConstMatrixMap& input_matrix_grad,
|
||||
MatrixMap* output_matrix) override {
|
||||
OP_REQUIRES(context,
|
||||
input_matrix_l_full.rows() == input_matrix_l_full.cols(),
|
||||
errors::InvalidArgument("Input matrix must be square."));
|
||||
OP_REQUIRES(
|
||||
context, input_matrix_l_full.cols() == input_matrix_grad.cols(),
|
||||
errors::InvalidArgument(
|
||||
"Input matrix and gradient must have same number of cols."));
|
||||
OP_REQUIRES(
|
||||
context, input_matrix_l_full.rows() == input_matrix_grad.rows(),
|
||||
errors::InvalidArgument(
|
||||
"Input matrix and gradient must have same number of rows."));
|
||||
|
||||
// Algorithm only depends on lower triangular half on input_matrix_l.
|
||||
const Matrix input_matrix_l =
|
||||
input_matrix_l_full.template triangularView<Eigen::Lower>();
|
||||
// Algorithm only depends on lower triangular half on input_matrix_grad.
|
||||
output_matrix = input_matrix_grad.template triangularView<Eigen::Lower>();
|
||||
*output_matrix = input_matrix_grad.template triangularView<Eigen::Lower>();
|
||||
|
||||
const int64 kMatrixSize = input_matrix_l.rows();
|
||||
const int64 kMaxBlockSize = 32;
|
||||
@ -104,20 +97,21 @@ class CholeskyGrad : public OpKernel {
|
||||
|
||||
auto B = input_matrix_l.block(block_end, 0, trailing_size, block_begin);
|
||||
auto B_bar =
|
||||
output_matrix.block(block_end, 0, trailing_size, block_begin);
|
||||
output_matrix->block(block_end, 0, trailing_size, block_begin);
|
||||
|
||||
auto C = input_matrix_l.block(block_end, block_begin, trailing_size,
|
||||
block_size);
|
||||
auto C_bar = output_matrix.block(block_end, block_begin, trailing_size,
|
||||
block_size);
|
||||
auto C_bar = output_matrix->block(block_end, block_begin, trailing_size,
|
||||
block_size);
|
||||
|
||||
auto D = input_matrix_l.block(block_begin, block_begin, block_size,
|
||||
block_size);
|
||||
auto D_bar =
|
||||
output_matrix.block(block_begin, block_begin, block_size, block_size);
|
||||
auto D_bar = output_matrix->block(block_begin, block_begin, block_size,
|
||||
block_size);
|
||||
|
||||
auto R = input_matrix_l.block(block_begin, 0, block_size, block_begin);
|
||||
auto R_bar = output_matrix.block(block_begin, 0, block_size, block_begin);
|
||||
auto R_bar =
|
||||
output_matrix->block(block_begin, 0, block_size, block_begin);
|
||||
|
||||
C_bar = D.adjoint().template triangularView<Eigen::Upper>()
|
||||
.solve(C_bar.adjoint()).adjoint();
|
||||
@ -127,9 +121,11 @@ class CholeskyGrad : public OpKernel {
|
||||
CholeskyGradUnblocked(D, D_bar);
|
||||
R_bar -= (D_bar + D_bar.adjoint()) * R;
|
||||
}
|
||||
output_matrix = (0.5 * (output_matrix + output_matrix.transpose())).eval();
|
||||
*output_matrix =
|
||||
(0.5 * (*output_matrix + output_matrix->transpose())).eval();
|
||||
}
|
||||
void CholeskyGradUnblocked(const ConstRef l_block, Ref grad_block) {
|
||||
|
||||
void CholeskyGradUnblocked(const ConstRef& l_block, Ref grad_block) {
|
||||
const int64 kMatrixSize = l_block.rows();
|
||||
for (int64 k = kMatrixSize - 1; k >= 0; k--) {
|
||||
/* This shows the block structure.
|
||||
@ -166,6 +162,11 @@ class CholeskyGrad : public OpKernel {
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_LINALG_OP("CholeskyGrad", (CholeskyGrad<float>), float);
|
||||
REGISTER_LINALG_OP("CholeskyGrad", (CholeskyGrad<double>), double);
|
||||
REGISTER_BINARY_LINALG_OP("CholeskyGrad", (CholeskyGrad<float, false>), float);
|
||||
REGISTER_BINARY_LINALG_OP("CholeskyGrad", (CholeskyGrad<double, false>),
|
||||
double);
|
||||
REGISTER_BINARY_LINALG_OP("BatchCholeskyGrad", (CholeskyGrad<float, true>),
|
||||
float);
|
||||
REGISTER_BINARY_LINALG_OP("BatchCholeskyGrad", (CholeskyGrad<double, true>),
|
||||
double);
|
||||
} // namespace tensorflow
|
||||
|
@ -129,11 +129,34 @@ REGISTER_OP("CholeskyGrad")
|
||||
.Doc(R"doc(
|
||||
Calculates the reverse mode backpropagated gradient of the Cholesky algorithm.
|
||||
|
||||
For an explanation see "Differentiation of the Cholesky algorithm" by Iain Murray http://arxiv.org/abs/1602.07527.
|
||||
For an explanation see "Differentiation of the Cholesky algorithm" by
|
||||
Iain Murray http://arxiv.org/abs/1602.07527.
|
||||
|
||||
l: Output of Cholesky algorithm l = chol(A). Shape is `[M, M]`. Algorithm depends only on lower triangular part of this matrix.
|
||||
grad: df/dl where f is some scalar function. Shape is `[M, M]'. Algorithm depends only on lower triangular part of this matrix.
|
||||
output: Symmetrized version of df/dA . Shape is `[M, M]'
|
||||
l: Output of Cholesky algorithm l = chol(A). Shape is `[M, M]`.
|
||||
Algorithm depends only on lower triangular part of this matrix.
|
||||
grad: df/dl where f is some scalar function. Shape is `[M, M]'.
|
||||
Algorithm depends only on lower triangular part of this matrix.
|
||||
output: Symmetrized version of df/dA . Shape is `[M, M]'.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("BatchCholeskyGrad")
|
||||
.Input("l: T")
|
||||
.Input("grad: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: {float, double}")
|
||||
.Doc(R"doc(
|
||||
Calculates the reverse mode backpropagated gradient of the Cholesky algorithm.
|
||||
|
||||
For an explanation see "Differentiation of the Cholesky algorithm" by
|
||||
Iain Murray http://arxiv.org/abs/1602.07527.
|
||||
|
||||
l: Output of batch Cholesky algorithm l = batch_cholesky(A). Shape is `[..., M, M]`.
|
||||
Algorithm depends only on lower triangular part of the innermost matrices of
|
||||
this tensor.
|
||||
grad: df/dl where f is some scalar function. Shape is `[..., M, M]'.
|
||||
Algorithm depends only on lower triangular part of the innermost matrices of
|
||||
this tensor.
|
||||
output: Symmetrized version of df/dA . Shape is `[..., M, M]'
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("SelfAdjointEig")
|
||||
|
@ -71,18 +71,23 @@ class CholeskyOpTest(tf.test.TestCase):
|
||||
def testNonSquareMatrix(self):
|
||||
with self.assertRaises(ValueError):
|
||||
tf.cholesky(np.array([[1., 2., 3.], [3., 4., 5.]]))
|
||||
with self.assertRaises(ValueError):
|
||||
tf.batch_cholesky(np.array([[[1., 2., 3.], [3., 4., 5.]],
|
||||
[[1., 2., 3.], [3., 4., 5.]]]))
|
||||
|
||||
def testWrongDimensions(self):
|
||||
tensor3 = tf.constant([1., 2.])
|
||||
with self.assertRaises(ValueError):
|
||||
tf.cholesky(tensor3)
|
||||
with self.assertRaises(ValueError):
|
||||
tf.batch_cholesky(tensor3)
|
||||
|
||||
def testNotInvertible(self):
|
||||
# The input should be invertible.
|
||||
# The input should be invertible.
|
||||
with self.test_session():
|
||||
with self.assertRaisesOpError("LLT decomposition was not successful. The"
|
||||
" input might not be valid."):
|
||||
# All rows of the matrix below add to zero
|
||||
# All rows of the matrix below add to zero
|
||||
self._verifyCholesky(np.array([[1., -1., 0.], [-1., 1., -1.], [0., -1.,
|
||||
1.]]))
|
||||
|
||||
@ -122,24 +127,36 @@ class CholeskyGradTest(tf.test.TestCase):
|
||||
scalarTest=False):
|
||||
with self.test_session(use_gpu=False):
|
||||
for shape in shapes:
|
||||
for dtype in dtypes:
|
||||
if not(scalarTest):
|
||||
x = tf.constant(np.random.randn(shape[0], shape[1]), dtype)
|
||||
K = tf.matmul(x, tf.transpose(x)) / shape[0] # K is posdef
|
||||
y = tf.cholesky(K)
|
||||
else: # This is designed to be a faster test for larger matrices.
|
||||
x = tf.constant(np.random.randn(), dtype)
|
||||
R = tf.constant(np.random.randn(shape[0], shape[1]), dtype)
|
||||
e = tf.mul(R, x)
|
||||
K = tf.matmul(e, tf.transpose(e)) / shape[0] # K is posdef
|
||||
y = tf.reduce_mean(tf.cholesky(K))
|
||||
error = tf.test.compute_gradient_error(x, x._shape_as_list(),
|
||||
y, y._shape_as_list())
|
||||
tf.logging.info("error = %f", error)
|
||||
if dtype == tf.float64:
|
||||
self.assertLess(error, 1e-5)
|
||||
else:
|
||||
self.assertLess(error, 2e-3)
|
||||
for batch in False, True:
|
||||
for dtype in dtypes:
|
||||
if not scalarTest:
|
||||
x = tf.constant(np.random.randn(shape[0], shape[1]), dtype)
|
||||
tensor = tf.matmul(x, tf.transpose(x)) / shape[0]
|
||||
else:
|
||||
# This is designed to be a faster test for larger matrices.
|
||||
x = tf.constant(np.random.randn(), dtype)
|
||||
R = tf.constant(np.random.randn(shape[0], shape[1]), dtype)
|
||||
e = tf.mul(R, x)
|
||||
tensor = tf.matmul(e, tf.transpose(e)) / shape[0]
|
||||
|
||||
# Inner-most matrices in tensor are positive definite.
|
||||
if batch:
|
||||
tensor = tf.tile(tf.expand_dims(tensor, 0), [4, 1, 1])
|
||||
op = tf.batch_cholesky
|
||||
else:
|
||||
op = tf.cholesky
|
||||
|
||||
if not (scalarTest):
|
||||
y = op(tensor)
|
||||
else:
|
||||
y = tf.reduce_mean(op(tensor))
|
||||
error = tf.test.compute_gradient_error(x, x._shape_as_list(), y,
|
||||
y._shape_as_list())
|
||||
tf.logging.info("error = %f", error)
|
||||
if dtype == tf.float64:
|
||||
self.assertLess(error, 1e-5)
|
||||
else:
|
||||
self.assertLess(error, 3e-3)
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
||||
|
@ -32,6 +32,9 @@ from tensorflow.python.ops import constant_op
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
ops.NoGradient("CholeskyGrad")
|
||||
ops.NoGradient("BatchCholeskyGrad")
|
||||
|
||||
|
||||
@ops.RegisterGradient("MatrixInverse")
|
||||
def _MatrixInverseGrad(op, grad):
|
||||
@ -76,11 +79,17 @@ def _BatchMatrixDeterminantGrad(op, grad):
|
||||
|
||||
|
||||
@ops.RegisterGradient("Cholesky")
|
||||
def _cholesky_grad(op, grad):
|
||||
def _CholeskyGrad(op, grad):
|
||||
"""Gradient for Cholesky."""
|
||||
return linalg_ops.cholesky_grad(op.outputs[0], grad)
|
||||
|
||||
|
||||
@ops.RegisterGradient("BatchCholesky")
|
||||
def _BatchCholeskyGrad(op, grad):
|
||||
"""Gradient for BatchCholesky."""
|
||||
return linalg_ops.batch_cholesky_grad(op.outputs[0], grad)
|
||||
|
||||
|
||||
@ops.RegisterGradient("MatrixSolve")
|
||||
def _MatrixSolveGrad(op, grad):
|
||||
"""Gradients for MatrixSolve."""
|
||||
|
@ -28,6 +28,7 @@ from tensorflow.python.ops.gen_linalg_ops import *
|
||||
|
||||
|
||||
@ops.RegisterShape("Cholesky")
|
||||
@ops.RegisterShape("CholeskyGrad")
|
||||
@ops.RegisterShape("MatrixInverse")
|
||||
def _UnchangedSquare(op):
|
||||
input_shape = op.inputs[0].get_shape().with_rank(2)
|
||||
@ -37,6 +38,7 @@ def _UnchangedSquare(op):
|
||||
|
||||
|
||||
@ops.RegisterShape("BatchCholesky")
|
||||
@ops.RegisterShape("BatchCholeskyGrad")
|
||||
@ops.RegisterShape("BatchMatrixInverse")
|
||||
def _BatchUnchangedSquare(op):
|
||||
input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
|
||||
@ -44,10 +46,6 @@ def _BatchUnchangedSquare(op):
|
||||
input_shape[-1].assert_is_compatible_with(input_shape[-2])
|
||||
return [input_shape]
|
||||
|
||||
@ops.RegisterShape("CholeskyGrad")
|
||||
def _cholesky_grad_shape(op):
|
||||
return [op.inputs[0].get_shape()]
|
||||
|
||||
@ops.RegisterShape("MatrixDeterminant")
|
||||
def _MatrixDeterminantShape(op):
|
||||
input_shape = op.inputs[0].get_shape().with_rank(2)
|
||||
|
Loading…
Reference in New Issue
Block a user