Add gradients for batch_cholesky.
Change: 123328520
commit cafe948be4 (parent c1b5075d1f)
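Before the diff itself, a rough usage sketch (not part of this commit): with the BatchCholeskyGrad kernel and the Python registrations added below, tf.gradients can backpropagate through tf.batch_cholesky. The matrix construction, sizes, and names here are illustrative assumptions only; tf.batch_cholesky, tf.gradients, and the graph/session style follow the 0.x-era API used in the changed test file.

    import numpy as np
    import tensorflow as tf

    # Hypothetical batch of symmetric positive-definite 3x3 matrices, built in NumPy.
    x = np.random.randn(4, 3, 3)
    a_np = np.matmul(x, np.transpose(x, (0, 2, 1))) + 3.0 * np.eye(3)

    with tf.Session():
      a = tf.constant(a_np, dtype=tf.float64)
      l = tf.batch_cholesky(a)   # one Cholesky factor per innermost matrix
      y = tf.reduce_sum(l)       # any scalar function of the factors
      # Before this change no gradient was registered for BatchCholesky; with it,
      # the call below dispatches to the new BatchCholeskyGrad op.
      grad_a = tf.gradients(y, [a])[0]
      print(grad_a.eval())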
@@ -13,75 +13,68 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#include "tensorflow/core/framework/op.h"
 #include "third_party/eigen3/Eigen/Core"
+#include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"

-#include "tensorflow/core/kernels/linalg_ops_common.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/binary_linalg_ops_common.h"

 namespace tensorflow {

-template <typename T>
-class CholeskyGrad : public OpKernel {
+template <typename Scalar, bool SupportsBatchOperationT>
+class CholeskyGrad
+    : public BinaryLinearAlgebraOp<Scalar, SupportsBatchOperationT> {
  public:
-  explicit CholeskyGrad(OpKernelConstruction* context) : OpKernel(context) {}
+  explicit CholeskyGrad(OpKernelConstruction* context)
+      : BinaryLinearAlgebraOp<Scalar, SupportsBatchOperationT>(context) {}
+  ~CholeskyGrad() override {}

   using Matrix =
-      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
+      Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
   using ConstMatrixMap = Eigen::Map<const Matrix>;
   using MatrixMap = Eigen::Map<Matrix>;
   using ConstRef = Eigen::Ref<const Matrix>;
   using Ref = Eigen::Ref<Matrix>;

-  void Compute(OpKernelContext* context) override {
-    const Tensor& input_tensor_l = context->input(0);
-    const Tensor& input_tensor_grad = context->input(1);
-    // Check that input tensors represent a matrix.
-    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_tensor_l.shape()),
-                errors::InvalidArgument("In[0] is not a matrix"));
-    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_tensor_grad.shape()),
-                errors::InvalidArgument("In[1] is not a matrix"));
-    // Check that input tensors are square.
-    OP_REQUIRES(context,
-                input_tensor_l.dim_size(0) == input_tensor_l.dim_size(1),
-                errors::InvalidArgument("Input matrix must be square."));
-    OP_REQUIRES(context,
-                input_tensor_grad.dim_size(0) == input_tensor_grad.dim_size(1),
-                errors::InvalidArgument("Input matrix must be square."));
-
-    // Check that input tensors are of same size.
-    OP_REQUIRES(context,
-                input_tensor_l.dim_size(0) == input_tensor_grad.dim_size(0),
-                errors::InvalidArgument("Input matrices must be same size."));
-
-    // Create an output tensor
-    Tensor* output_tensor = NULL;
-    OP_REQUIRES_OK(context, context->allocate_output(
-                                0, input_tensor_grad.shape(), &output_tensor));
-
-    if (output_tensor->NumElements() == 0) {
-      // the output shape is a 0-element matrix, so there is nothing to do.
-      return;
-    }
-    // The next lines are necessary to get Eigen matrix behaviour.
-    const ConstMatrixMap input_matrix_l_full(input_tensor_l.flat<T>().data(),
-                                             input_tensor_l.dim_size(0),
-                                             input_tensor_l.dim_size(1));
-    const ConstMatrixMap input_matrix_grad(input_tensor_grad.flat<T>().data(),
-                                           input_tensor_grad.dim_size(0),
-                                           input_tensor_grad.dim_size(1));
-    MatrixMap output_matrix(output_tensor->template flat<T>().data(),
-                            input_tensor_l.dim_size(0),
-                            input_tensor_l.dim_size(1));
-
-    // Algorithm only depends on lower triangular half on input_tensor_l.
+  TensorShape GetOutputMatrixShape(
+      const TensorShape& input_matrix_l_full_shape,
+      const TensorShape& input_matrix_grad_shape) override {
+    return input_matrix_l_full_shape;
+  }
+
+  int64 GetCostPerUnit(const TensorShape& input_matrix_shape,
+                       const TensorShape& rhs_matrix_shape) override {
+    const int64 rows = input_matrix_shape.dim_size(0);
+    if (rows > (1LL << 20)) {
+      // A big number to cap the cost in case overflow.
+      return kint64max;
+    } else {
+      return rows * rows * rows;
+    }
+  }
+
+  void ComputeMatrix(OpKernelContext* context,
+                     const ConstMatrixMap& input_matrix_l_full,
+                     const ConstMatrixMap& input_matrix_grad,
+                     MatrixMap* output_matrix) override {
+    OP_REQUIRES(context,
+                input_matrix_l_full.rows() == input_matrix_l_full.cols(),
+                errors::InvalidArgument("Input matrix must be square."));
+    OP_REQUIRES(
+        context, input_matrix_l_full.cols() == input_matrix_grad.cols(),
+        errors::InvalidArgument(
+            "Input matrix and gradient must have same number of cols."));
+    OP_REQUIRES(
+        context, input_matrix_l_full.rows() == input_matrix_grad.rows(),
+        errors::InvalidArgument(
+            "Input matrix and gradient must have same number of rows."));
+
+    // Algorithm only depends on lower triangular half on input_matrix_l.
     const Matrix input_matrix_l =
         input_matrix_l_full.template triangularView<Eigen::Lower>();
     // Algorithm only depends on lower triangular half on input_matrix_grad.
-    output_matrix = input_matrix_grad.template triangularView<Eigen::Lower>();
+    *output_matrix = input_matrix_grad.template triangularView<Eigen::Lower>();

     const int64 kMatrixSize = input_matrix_l.rows();
     const int64 kMaxBlockSize = 32;
@@ -104,20 +97,21 @@ class CholeskyGrad : public OpKernel {

       auto B = input_matrix_l.block(block_end, 0, trailing_size, block_begin);
       auto B_bar =
-          output_matrix.block(block_end, 0, trailing_size, block_begin);
+          output_matrix->block(block_end, 0, trailing_size, block_begin);

       auto C = input_matrix_l.block(block_end, block_begin, trailing_size,
                                     block_size);
-      auto C_bar = output_matrix.block(block_end, block_begin, trailing_size,
+      auto C_bar = output_matrix->block(block_end, block_begin, trailing_size,
                                        block_size);

       auto D = input_matrix_l.block(block_begin, block_begin, block_size,
                                     block_size);
-      auto D_bar =
-          output_matrix.block(block_begin, block_begin, block_size, block_size);
+      auto D_bar = output_matrix->block(block_begin, block_begin, block_size,
+                                        block_size);

       auto R = input_matrix_l.block(block_begin, 0, block_size, block_begin);
-      auto R_bar = output_matrix.block(block_begin, 0, block_size, block_begin);
+      auto R_bar =
+          output_matrix->block(block_begin, 0, block_size, block_begin);

       C_bar = D.adjoint().template triangularView<Eigen::Upper>()
                   .solve(C_bar.adjoint()).adjoint();
@@ -127,9 +121,11 @@ class CholeskyGrad : public OpKernel {
       CholeskyGradUnblocked(D, D_bar);
       R_bar -= (D_bar + D_bar.adjoint()) * R;
     }
-    output_matrix = (0.5 * (output_matrix + output_matrix.transpose())).eval();
+    *output_matrix =
+        (0.5 * (*output_matrix + output_matrix->transpose())).eval();
   }
-  void CholeskyGradUnblocked(const ConstRef l_block, Ref grad_block) {
+
+  void CholeskyGradUnblocked(const ConstRef& l_block, Ref grad_block) {
     const int64 kMatrixSize = l_block.rows();
     for (int64 k = kMatrixSize - 1; k >= 0; k--) {
       /* This shows the block structure.
@@ -166,6 +162,11 @@ class CholeskyGrad : public OpKernel {
   }
 };

-REGISTER_LINALG_OP("CholeskyGrad", (CholeskyGrad<float>), float);
-REGISTER_LINALG_OP("CholeskyGrad", (CholeskyGrad<double>), double);
+REGISTER_BINARY_LINALG_OP("CholeskyGrad", (CholeskyGrad<float, false>), float);
+REGISTER_BINARY_LINALG_OP("CholeskyGrad", (CholeskyGrad<double, false>),
+                          double);
+REGISTER_BINARY_LINALG_OP("BatchCholeskyGrad", (CholeskyGrad<float, true>),
+                          float);
+REGISTER_BINARY_LINALG_OP("BatchCholeskyGrad", (CholeskyGrad<double, true>),
+                          double);
 }  // namespace tensorflow
@@ -129,11 +129,34 @@ REGISTER_OP("CholeskyGrad")
     .Doc(R"doc(
 Calculates the reverse mode backpropagated gradient of the Cholesky algorithm.

-For an explanation see "Differentiation of the Cholesky algorithm" by Iain Murray http://arxiv.org/abs/1602.07527.
+For an explanation see "Differentiation of the Cholesky algorithm" by
+Iain Murray http://arxiv.org/abs/1602.07527.

-l: Output of Cholesky algorithm l = chol(A). Shape is `[M, M]`. Algorithm depends only on lower triangular part of this matrix.
-grad: df/dl where f is some scalar function. Shape is `[M, M]'. Algorithm depends only on lower triangular part of this matrix.
-output: Symmetrized version of df/dA . Shape is `[M, M]'
+l: Output of Cholesky algorithm l = chol(A). Shape is `[M, M]`.
+Algorithm depends only on lower triangular part of this matrix.
+grad: df/dl where f is some scalar function. Shape is `[M, M]'.
+Algorithm depends only on lower triangular part of this matrix.
+output: Symmetrized version of df/dA . Shape is `[M, M]'.
+)doc");
+
+REGISTER_OP("BatchCholeskyGrad")
+    .Input("l: T")
+    .Input("grad: T")
+    .Output("output: T")
+    .Attr("T: {float, double}")
+    .Doc(R"doc(
+Calculates the reverse mode backpropagated gradient of the Cholesky algorithm.
+
+For an explanation see "Differentiation of the Cholesky algorithm" by
+Iain Murray http://arxiv.org/abs/1602.07527.
+
+l: Output of batch Cholesky algorithm l = batch_cholesky(A). Shape is `[..., M, M]`.
+Algorithm depends only on lower triangular part of the innermost matrices of
+this tensor.
+grad: df/dl where f is some scalar function. Shape is `[..., M, M]'.
+Algorithm depends only on lower triangular part of the innermost matrices of
+this tensor.
+output: Symmetrized version of df/dA . Shape is `[..., M, M]'
 )doc");

 REGISTER_OP("SelfAdjointEig")
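A side note, not part of the diff: both doc strings describe the output as the symmetrized version of df/dA. That wording matches the kernel's final statement above, 0.5 * (*output_matrix + output_matrix->transpose()); writing Abar for the lower-triangular accumulation the kernel builds up, each returned matrix is

    output = 0.5 * (Abar + Abar^T)

so the per-matrix gradient is symmetric, consistent with the symmetry of the input A = l * l^T.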
@@ -71,18 +71,23 @@ class CholeskyOpTest(tf.test.TestCase):
   def testNonSquareMatrix(self):
     with self.assertRaises(ValueError):
       tf.cholesky(np.array([[1., 2., 3.], [3., 4., 5.]]))
+    with self.assertRaises(ValueError):
+      tf.batch_cholesky(np.array([[[1., 2., 3.], [3., 4., 5.]],
+                                  [[1., 2., 3.], [3., 4., 5.]]]))

   def testWrongDimensions(self):
     tensor3 = tf.constant([1., 2.])
     with self.assertRaises(ValueError):
       tf.cholesky(tensor3)
+    with self.assertRaises(ValueError):
+      tf.batch_cholesky(tensor3)

   def testNotInvertible(self):
     # The input should be invertible.
     with self.test_session():
       with self.assertRaisesOpError("LLT decomposition was not successful. The"
                                     " input might not be valid."):
         # All rows of the matrix below add to zero
         self._verifyCholesky(np.array([[1., -1., 0.], [-1., 1., -1.], [0., -1.,
                                        1.]]))
@@ -122,24 +127,36 @@ class CholeskyGradTest(tf.test.TestCase):
                            scalarTest=False):
     with self.test_session(use_gpu=False):
       for shape in shapes:
-        for dtype in dtypes:
-          if not(scalarTest):
-            x = tf.constant(np.random.randn(shape[0], shape[1]), dtype)
-            K = tf.matmul(x, tf.transpose(x)) / shape[0]  # K is posdef
-            y = tf.cholesky(K)
-          else:  # This is designed to be a faster test for larger matrices.
-            x = tf.constant(np.random.randn(), dtype)
-            R = tf.constant(np.random.randn(shape[0], shape[1]), dtype)
-            e = tf.mul(R, x)
-            K = tf.matmul(e, tf.transpose(e)) / shape[0]  # K is posdef
-            y = tf.reduce_mean(tf.cholesky(K))
-          error = tf.test.compute_gradient_error(x, x._shape_as_list(),
-                                                 y, y._shape_as_list())
-          tf.logging.info("error = %f", error)
-          if dtype == tf.float64:
-            self.assertLess(error, 1e-5)
-          else:
-            self.assertLess(error, 2e-3)
+        for batch in False, True:
+          for dtype in dtypes:
+            if not scalarTest:
+              x = tf.constant(np.random.randn(shape[0], shape[1]), dtype)
+              tensor = tf.matmul(x, tf.transpose(x)) / shape[0]
+            else:
+              # This is designed to be a faster test for larger matrices.
+              x = tf.constant(np.random.randn(), dtype)
+              R = tf.constant(np.random.randn(shape[0], shape[1]), dtype)
+              e = tf.mul(R, x)
+              tensor = tf.matmul(e, tf.transpose(e)) / shape[0]
+
+            # Inner-most matrices in tensor are positive definite.
+            if batch:
+              tensor = tf.tile(tf.expand_dims(tensor, 0), [4, 1, 1])
+              op = tf.batch_cholesky
+            else:
+              op = tf.cholesky
+
+            if not (scalarTest):
+              y = op(tensor)
+            else:
+              y = tf.reduce_mean(op(tensor))
+            error = tf.test.compute_gradient_error(x, x._shape_as_list(), y,
+                                                   y._shape_as_list())
+            tf.logging.info("error = %f", error)
+            if dtype == tf.float64:
+              self.assertLess(error, 1e-5)
+            else:
+              self.assertLess(error, 3e-3)

 if __name__ == "__main__":
   tf.test.main()
@@ -32,6 +32,9 @@ from tensorflow.python.ops import constant_op
 from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops

+ops.NoGradient("CholeskyGrad")
+ops.NoGradient("BatchCholeskyGrad")
+

 @ops.RegisterGradient("MatrixInverse")
 def _MatrixInverseGrad(op, grad):
@@ -76,11 +79,17 @@ def _BatchMatrixDeterminantGrad(op, grad):


 @ops.RegisterGradient("Cholesky")
-def _cholesky_grad(op, grad):
+def _CholeskyGrad(op, grad):
   """Gradient for Cholesky."""
   return linalg_ops.cholesky_grad(op.outputs[0], grad)


+@ops.RegisterGradient("BatchCholesky")
+def _BatchCholeskyGrad(op, grad):
+  """Gradient for BatchCholesky."""
+  return linalg_ops.batch_cholesky_grad(op.outputs[0], grad)
+
+
 @ops.RegisterGradient("MatrixSolve")
 def _MatrixSolveGrad(op, grad):
   """Gradients for MatrixSolve."""
@@ -28,6 +28,7 @@ from tensorflow.python.ops.gen_linalg_ops import *


 @ops.RegisterShape("Cholesky")
+@ops.RegisterShape("CholeskyGrad")
 @ops.RegisterShape("MatrixInverse")
 def _UnchangedSquare(op):
   input_shape = op.inputs[0].get_shape().with_rank(2)
@@ -37,6 +38,7 @@ def _UnchangedSquare(op):


 @ops.RegisterShape("BatchCholesky")
+@ops.RegisterShape("BatchCholeskyGrad")
 @ops.RegisterShape("BatchMatrixInverse")
 def _BatchUnchangedSquare(op):
   input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
@@ -44,10 +46,6 @@ def _BatchUnchangedSquare(op):
   input_shape[-1].assert_is_compatible_with(input_shape[-2])
   return [input_shape]


-@ops.RegisterShape("CholeskyGrad")
-def _cholesky_grad_shape(op):
-  return [op.inputs[0].get_shape()]
-
-
 @ops.RegisterShape("MatrixDeterminant")
 def _MatrixDeterminantShape(op):
   input_shape = op.inputs[0].get_shape().with_rank(2)