Merge pull request #20611 from jonathanwyatt16:matrix_square_root
PiperOrigin-RevId: 218197028
This commit is contained in:
		
						commit
						3d715da989
					
				@ -0,0 +1,37 @@
 | 
				
			|||||||
 | 
					op {
 | 
				
			||||||
 | 
					  graph_op_name: "MatrixSquareRoot"
 | 
				
			||||||
 | 
					  in_arg {
 | 
				
			||||||
 | 
					    name: "input"
 | 
				
			||||||
 | 
					    description: <<END
 | 
				
			||||||
 | 
					Shape is `[..., M, M]`.
 | 
				
			||||||
 | 
					END
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  out_arg {
 | 
				
			||||||
 | 
					    name: "output"
 | 
				
			||||||
 | 
					    description: <<END
 | 
				
			||||||
 | 
					Shape is `[..., M, M]`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@compatibility(scipy)
 | 
				
			||||||
 | 
					Equivalent to scipy.linalg.sqrtm
 | 
				
			||||||
 | 
					@end_compatibility
 | 
				
			||||||
 | 
					END
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  summary: "Computes the matrix square root of one or more square matrices:"
 | 
				
			||||||
 | 
					  description: <<END
 | 
				
			||||||
 | 
					matmul(sqrtm(A), sqrtm(A)) = A
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					The input matrix should be invertible. If the input matrix is real, it should
 | 
				
			||||||
 | 
					have no eigenvalues which are real and negative (pairs of complex conjugate
 | 
				
			||||||
 | 
					eigenvalues are allowed).
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					The matrix square root is computed by first reducing the matrix to 
 | 
				
			||||||
 | 
					quasi-triangular form with the real Schur decomposition. The square root 
 | 
				
			||||||
 | 
					of the quasi-triangular matrix is then computed directly. Details of 
 | 
				
			||||||
 | 
					the algorithm can be found in: Nicholas J. Higham, "Computing real 
 | 
				
			||||||
 | 
					square roots of a real matrix", Linear Algebra Appl., 1987.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
 | 
				
			||||||
 | 
					form square matrices. The output is a tensor of the same shape as the input
 | 
				
			||||||
 | 
					containing the matrix square root for all input submatrices `[..., :, :]`.
 | 
				
			||||||
 | 
					END
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -0,0 +1,9 @@
 | 
				
			|||||||
 | 
					op {
 | 
				
			||||||
 | 
					  graph_op_name: "MatrixSquareRoot"
 | 
				
			||||||
 | 
					  endpoint {
 | 
				
			||||||
 | 
					    name: "linalg.sqrtm"
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  endpoint {
 | 
				
			||||||
 | 
					    name: "matrix_square_root"
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -2629,6 +2629,7 @@ cc_library(
 | 
				
			|||||||
        ":matrix_logarithm_op",
 | 
					        ":matrix_logarithm_op",
 | 
				
			||||||
        ":matrix_solve_ls_op",
 | 
					        ":matrix_solve_ls_op",
 | 
				
			||||||
        ":matrix_solve_op",
 | 
					        ":matrix_solve_op",
 | 
				
			||||||
 | 
					        ":matrix_square_root_op",
 | 
				
			||||||
        ":matrix_triangular_solve_op",
 | 
					        ":matrix_triangular_solve_op",
 | 
				
			||||||
        ":qr_op",
 | 
					        ":qr_op",
 | 
				
			||||||
        ":self_adjoint_eig_op",
 | 
					        ":self_adjoint_eig_op",
 | 
				
			||||||
@ -2738,6 +2739,12 @@ tf_kernel_library(
 | 
				
			|||||||
    deps = LINALG_DEPS,
 | 
					    deps = LINALG_DEPS,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					tf_kernel_library(
 | 
				
			||||||
 | 
					    name = "matrix_square_root_op",
 | 
				
			||||||
 | 
					    prefix = "matrix_square_root_op",
 | 
				
			||||||
 | 
					    deps = LINALG_DEPS,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
tf_kernel_library(
 | 
					tf_kernel_library(
 | 
				
			||||||
    name = "matrix_triangular_solve_op",
 | 
					    name = "matrix_triangular_solve_op",
 | 
				
			||||||
    prefix = "matrix_triangular_solve_op",
 | 
					    prefix = "matrix_triangular_solve_op",
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										58
									
								
								tensorflow/core/kernels/matrix_square_root_op.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								tensorflow/core/kernels/matrix_square_root_op.cc
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,58 @@
 | 
				
			|||||||
 | 
					/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					==============================================================================*/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// See docs in ../ops/linalg_ops.cc.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "third_party/eigen3/Eigen/Core"
 | 
				
			||||||
 | 
					#include "third_party/eigen3/unsupported/Eigen/MatrixFunctions"
 | 
				
			||||||
 | 
					#include "tensorflow/core/framework/kernel_def_builder.h"
 | 
				
			||||||
 | 
					#include "tensorflow/core/framework/op_kernel.h"
 | 
				
			||||||
 | 
					#include "tensorflow/core/framework/tensor_shape.h"
 | 
				
			||||||
 | 
					#include "tensorflow/core/kernels/linalg_ops_common.h"
 | 
				
			||||||
 | 
					#include "tensorflow/core/lib/core/errors.h"
 | 
				
			||||||
 | 
					#include "tensorflow/core/platform/logging.h"
 | 
				
			||||||
 | 
					#include "tensorflow/core/platform/macros.h"
 | 
				
			||||||
 | 
					#include "tensorflow/core/platform/types.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace tensorflow {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <class Scalar>
 | 
				
			||||||
 | 
					class MatrixSquareRootOp : public LinearAlgebraOp<Scalar> {
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  INHERIT_LINALG_TYPEDEFS(Scalar);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  explicit MatrixSquareRootOp(OpKernelConstruction* context) : Base(context) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
 | 
				
			||||||
 | 
					                     MatrixMaps* outputs) final {
 | 
				
			||||||
 | 
					    const ConstMatrixMap& input = inputs[0];
 | 
				
			||||||
 | 
					    if (input.rows() == 0) return;
 | 
				
			||||||
 | 
					    using Matrix =
 | 
				
			||||||
 | 
					        Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
 | 
				
			||||||
 | 
					    Matrix tmp = input;
 | 
				
			||||||
 | 
					    outputs->at(0) = tmp.sqrt();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 private:
 | 
				
			||||||
 | 
					  TF_DISALLOW_COPY_AND_ASSIGN(MatrixSquareRootOp);
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					REGISTER_LINALG_OP("MatrixSquareRoot", (MatrixSquareRootOp<float>), float);
 | 
				
			||||||
 | 
					REGISTER_LINALG_OP("MatrixSquareRoot", (MatrixSquareRootOp<double>), double);
 | 
				
			||||||
 | 
					REGISTER_LINALG_OP("MatrixSquareRoot", (MatrixSquareRootOp<complex64>),
 | 
				
			||||||
 | 
					                   complex64);
 | 
				
			||||||
 | 
					REGISTER_LINALG_OP("MatrixSquareRoot", (MatrixSquareRootOp<complex128>),
 | 
				
			||||||
 | 
					                   complex128);
 | 
				
			||||||
 | 
					}  // namespace tensorflow
 | 
				
			||||||
@ -323,6 +323,12 @@ REGISTER_OP("MatrixSolveLs")
 | 
				
			|||||||
      return MatrixSolveShapeFn(c, false /* square */);
 | 
					      return MatrixSolveShapeFn(c, false /* square */);
 | 
				
			||||||
    });
 | 
					    });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					REGISTER_OP("MatrixSquareRoot")
 | 
				
			||||||
 | 
					    .Input("input: T")
 | 
				
			||||||
 | 
					    .Output("output: T")
 | 
				
			||||||
 | 
					    .Attr("T: {double, float, complex64, complex128}")
 | 
				
			||||||
 | 
					    .SetShapeFn(BatchUnchangedSquareShapeFn);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
REGISTER_OP("Qr")
 | 
					REGISTER_OP("Qr")
 | 
				
			||||||
    .Input("input: T")
 | 
					    .Input("input: T")
 | 
				
			||||||
    .Output("q: T")
 | 
					    .Output("q: T")
 | 
				
			||||||
 | 
				
			|||||||
@ -16084,6 +16084,29 @@ op {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					op {
 | 
				
			||||||
 | 
					  name: "MatrixSquareRoot"
 | 
				
			||||||
 | 
					  input_arg {
 | 
				
			||||||
 | 
					    name: "matrix"
 | 
				
			||||||
 | 
					    type_attr: "T"
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  output_arg {
 | 
				
			||||||
 | 
					    name: "output"
 | 
				
			||||||
 | 
					    type_attr: "T"
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  attr {
 | 
				
			||||||
 | 
					    name: "T"
 | 
				
			||||||
 | 
					    type: "type"
 | 
				
			||||||
 | 
					    allowed_values {
 | 
				
			||||||
 | 
					      list {
 | 
				
			||||||
 | 
					        type: DT_DOUBLE
 | 
				
			||||||
 | 
					        type: DT_FLOAT
 | 
				
			||||||
 | 
					        type: DT_COMPLEX64
 | 
				
			||||||
 | 
					        type: DT_COMPLEX128
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
op {
 | 
					op {
 | 
				
			||||||
  name: "MatrixTriangularSolve"
 | 
					  name: "MatrixTriangularSolve"
 | 
				
			||||||
  input_arg {
 | 
					  input_arg {
 | 
				
			||||||
 | 
				
			|||||||
@ -16660,6 +16660,46 @@ func MatrixSolveLs(scope *Scope, matrix tf.Output, rhs tf.Output, l2_regularizer
 | 
				
			|||||||
	return op.Output(0)
 | 
						return op.Output(0)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Computes the matrix square root of one or more square matrices:
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// matmul(sqrtm(A), sqrtm(A)) = A
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// The input matrix should be invertible. If the input matrix is real,
 | 
				
			||||||
 | 
					// it should have no eigenvalues which are real and negative
 | 
				
			||||||
 | 
					// (pairs of complex conjugate eigenvalues are allowed).
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// The matrix square root is computed by first reducing the matrix to
 | 
				
			||||||
 | 
					// quasi-triangular form with the real Schur decomposition. The square root
 | 
				
			||||||
 | 
					// of the quasi-triangular matrix is then computed directly. Details of
 | 
				
			||||||
 | 
					// the algorithm can be found in: Nicholas J. Higham, "Computing real
 | 
				
			||||||
 | 
					// square roots of a real matrix", Linear Algebra Appl., 1987.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
 | 
				
			||||||
 | 
					// form square matrices. The output is a tensor of the same shape as the input
 | 
				
			||||||
 | 
					// containing the matrix square root for all input submatrices `[..., :, :]`.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Arguments:
 | 
				
			||||||
 | 
					//	input: Shape is `[..., M, M]`.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Returns Shape is `[..., M, M]`.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// @compatibility(scipy)
 | 
				
			||||||
 | 
					// Equivalent to scipy.linalg.sqrtm
 | 
				
			||||||
 | 
					// @end_compatibility
 | 
				
			||||||
 | 
					func MatrixSquareRoot(scope *Scope, input tf.Output) (output tf.Output) {
 | 
				
			||||||
 | 
						if scope.Err() != nil {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						opspec := tf.OpSpec{
 | 
				
			||||||
 | 
							Type: "MatrixSquareRoot",
 | 
				
			||||||
 | 
							Input: []tf.Input{
 | 
				
			||||||
 | 
								input,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						op := scope.AddOperation(opspec)
 | 
				
			||||||
 | 
						return op.Output(0)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// MaxPool3DAttr is an optional argument to MaxPool3D.
 | 
					// MaxPool3DAttr is an optional argument to MaxPool3D.
 | 
				
			||||||
type MaxPool3DAttr func(optionalAttr)
 | 
					type MaxPool3DAttr func(optionalAttr)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -66,6 +66,10 @@ def _GetMatrixUnaryFunctorGradientTest(functor_, dtype_, shape_, **kwargs_):
 | 
				
			|||||||
          low=-1.0, high=1.0,
 | 
					          low=-1.0, high=1.0,
 | 
				
			||||||
          size=np.prod(shape_)).reshape(shape_).astype(dtype_)
 | 
					          size=np.prod(shape_)).reshape(shape_).astype(dtype_)
 | 
				
			||||||
      a = constant_op.constant(a_np)
 | 
					      a = constant_op.constant(a_np)
 | 
				
			||||||
 | 
					      if functor_.__name__ == 'matrix_square_root':
 | 
				
			||||||
 | 
					        # Square the input matrix to ensure that its matrix square root exists
 | 
				
			||||||
 | 
					        a = math_ops.matmul(a, a)
 | 
				
			||||||
 | 
					        a_np = a.eval()
 | 
				
			||||||
      b = functor_(a, **kwargs_)
 | 
					      b = functor_(a, **kwargs_)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      # Optimal stepsize for central difference is O(epsilon^{1/3}).
 | 
					      # Optimal stepsize for central difference is O(epsilon^{1/3}).
 | 
				
			||||||
@ -189,6 +193,17 @@ if __name__ == '__main__':
 | 
				
			|||||||
                lambda x: linalg_ops.log_matrix_determinant(x)[1],
 | 
					                lambda x: linalg_ops.log_matrix_determinant(x)[1],
 | 
				
			||||||
                dtype, shape))
 | 
					                dtype, shape))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # The numerical Jacobian is consistently invalid for these four shapes
 | 
				
			||||||
 | 
					        # because the matrix square root of the perturbed input doesn't exist
 | 
				
			||||||
 | 
					        if shape in {(2, 5, 5), (3, 5, 5), (3, 10, 10), (3, 2, 5, 5)}:
 | 
				
			||||||
 | 
					          # Alternative shape that consistently produces a valid numerical Jacobian
 | 
				
			||||||
 | 
					          shape = extra + (size + 1, size + 1)
 | 
				
			||||||
 | 
					          name = '%s_%s' % (dtype.__name__, '_'.join(map(str, shape)))
 | 
				
			||||||
 | 
					        _AddTest(
 | 
				
			||||||
 | 
					            MatrixUnaryFunctorGradientTest, 'MatrixSquareRootGradient', name,
 | 
				
			||||||
 | 
					            _GetMatrixUnaryFunctorGradientTest(linalg_ops.matrix_square_root,
 | 
				
			||||||
 | 
					                                               dtype, shape))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  # Tests for gradients of matrix_solve_ls
 | 
					  # Tests for gradients of matrix_solve_ls
 | 
				
			||||||
  for dtype in np.float32, np.float64:
 | 
					  for dtype in np.float32, np.float64:
 | 
				
			||||||
    for rows in 2, 5, 10:
 | 
					    for rows in 2, 5, 10:
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										116
									
								
								tensorflow/python/kernel_tests/matrix_square_root_op_test.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										116
									
								
								tensorflow/python/kernel_tests/matrix_square_root_op_test.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,116 @@
 | 
				
			|||||||
 | 
					# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					# ==============================================================================
 | 
				
			||||||
 | 
					"""Tests for tensorflow.ops.math_ops.matrix_square_root."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from __future__ import absolute_import
 | 
				
			||||||
 | 
					from __future__ import division
 | 
				
			||||||
 | 
					from __future__ import print_function
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from tensorflow.python.framework import constant_op
 | 
				
			||||||
 | 
					from tensorflow.python.ops import gen_linalg_ops
 | 
				
			||||||
 | 
					from tensorflow.python.ops import math_ops
 | 
				
			||||||
 | 
					from tensorflow.python.ops import random_ops
 | 
				
			||||||
 | 
					from tensorflow.python.platform import test
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class SquareRootOpTest(test.TestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _verifySquareRoot(self, matrix, np_type):
 | 
				
			||||||
 | 
					    matrix = matrix.astype(np_type)
 | 
				
			||||||
 | 
					    with self.test_session(use_gpu=True):
 | 
				
			||||||
 | 
					      # Verify that matmul(sqrtm(A), sqrtm(A)) = A
 | 
				
			||||||
 | 
					      sqrt = gen_linalg_ops.matrix_square_root(matrix)
 | 
				
			||||||
 | 
					      square = math_ops.matmul(sqrt, sqrt)
 | 
				
			||||||
 | 
					      self.assertShapeEqual(matrix, square)
 | 
				
			||||||
 | 
					      self.assertAllClose(matrix, square, rtol=1e-4, atol=1e-3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _verifySquareRootReal(self, x):
 | 
				
			||||||
 | 
					    for np_type in [np.float32, np.float64]:
 | 
				
			||||||
 | 
					      self._verifySquareRoot(x, np_type)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _verifySquareRootComplex(self, x):
 | 
				
			||||||
 | 
					    for np_type in [np.complex64, np.complex128]:
 | 
				
			||||||
 | 
					      self._verifySquareRoot(x, np_type)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _makeBatch(self, matrix1, matrix2):
 | 
				
			||||||
 | 
					    matrix_batch = np.concatenate(
 | 
				
			||||||
 | 
					        [np.expand_dims(matrix1, 0),
 | 
				
			||||||
 | 
					         np.expand_dims(matrix2, 0)])
 | 
				
			||||||
 | 
					    matrix_batch = np.tile(matrix_batch, [2, 3, 1, 1])
 | 
				
			||||||
 | 
					    return matrix_batch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _testMatrices(self, matrix1, matrix2):
 | 
				
			||||||
 | 
					    # Real
 | 
				
			||||||
 | 
					    self._verifySquareRootReal(matrix1)
 | 
				
			||||||
 | 
					    self._verifySquareRootReal(matrix2)
 | 
				
			||||||
 | 
					    self._verifySquareRootReal(self._makeBatch(matrix1, matrix2))
 | 
				
			||||||
 | 
					    # Complex
 | 
				
			||||||
 | 
					    matrix1 = matrix1.astype(np.complex64)
 | 
				
			||||||
 | 
					    matrix2 = matrix2.astype(np.complex64)
 | 
				
			||||||
 | 
					    matrix1 += 1j * matrix1
 | 
				
			||||||
 | 
					    matrix2 += 1j * matrix2
 | 
				
			||||||
 | 
					    self._verifySquareRootComplex(matrix1)
 | 
				
			||||||
 | 
					    self._verifySquareRootComplex(matrix2)
 | 
				
			||||||
 | 
					    self._verifySquareRootComplex(self._makeBatch(matrix1, matrix2))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testSymmetricPositiveDefinite(self):
 | 
				
			||||||
 | 
					    matrix1 = np.array([[2., 1.], [1., 2.]])
 | 
				
			||||||
 | 
					    matrix2 = np.array([[3., -1.], [-1., 3.]])
 | 
				
			||||||
 | 
					    self._testMatrices(matrix1, matrix2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testAsymmetric(self):
 | 
				
			||||||
 | 
					    matrix1 = np.array([[0., 4.], [-1., 5.]])
 | 
				
			||||||
 | 
					    matrix2 = np.array([[33., 24.], [48., 57.]])
 | 
				
			||||||
 | 
					    self._testMatrices(matrix1, matrix2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testIdentityMatrix(self):
 | 
				
			||||||
 | 
					    # 2x2
 | 
				
			||||||
 | 
					    identity = np.array([[1., 0], [0, 1.]])
 | 
				
			||||||
 | 
					    self._verifySquareRootReal(identity)
 | 
				
			||||||
 | 
					    # 3x3
 | 
				
			||||||
 | 
					    identity = np.array([[1., 0, 0], [0, 1., 0], [0, 0, 1.]])
 | 
				
			||||||
 | 
					    self._verifySquareRootReal(identity)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testEmpty(self):
 | 
				
			||||||
 | 
					    self._verifySquareRootReal(np.empty([0, 2, 2]))
 | 
				
			||||||
 | 
					    self._verifySquareRootReal(np.empty([2, 0, 0]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testWrongDimensions(self):
 | 
				
			||||||
 | 
					    # The input to the square root should be at least a 2-dimensional tensor.
 | 
				
			||||||
 | 
					    tensor = constant_op.constant([1., 2.])
 | 
				
			||||||
 | 
					    with self.assertRaises(ValueError):
 | 
				
			||||||
 | 
					      gen_linalg_ops.matrix_square_root(tensor)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testNotSquare(self):
 | 
				
			||||||
 | 
					    with self.test_session():
 | 
				
			||||||
 | 
					      with self.assertRaises(ValueError):
 | 
				
			||||||
 | 
					        tensor = constant_op.constant([[1., 0., -1.], [-1., 1., 0.]])
 | 
				
			||||||
 | 
					        gen_linalg_ops.matrix_square_root(tensor).eval()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def testConcurrentExecutesWithoutError(self):
 | 
				
			||||||
 | 
					    with self.test_session(use_gpu=True) as sess:
 | 
				
			||||||
 | 
					      matrix1 = random_ops.random_normal([5, 5], seed=42)
 | 
				
			||||||
 | 
					      matrix2 = random_ops.random_normal([5, 5], seed=42)
 | 
				
			||||||
 | 
					      sqrt1 = gen_linalg_ops.matrix_square_root(matrix1)
 | 
				
			||||||
 | 
					      sqrt2 = gen_linalg_ops.matrix_square_root(matrix2)
 | 
				
			||||||
 | 
					      all_ops = [sqrt1, sqrt2]
 | 
				
			||||||
 | 
					      sqrt = sess.run(all_ops)
 | 
				
			||||||
 | 
					      self.assertAllEqual(sqrt[0], sqrt[1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					  test.main()
 | 
				
			||||||
@ -50,6 +50,7 @@ norm = linalg_ops.norm
 | 
				
			|||||||
qr = linalg_ops.qr
 | 
					qr = linalg_ops.qr
 | 
				
			||||||
set_diag = array_ops.matrix_set_diag
 | 
					set_diag = array_ops.matrix_set_diag
 | 
				
			||||||
solve = linalg_ops.matrix_solve
 | 
					solve = linalg_ops.matrix_solve
 | 
				
			||||||
 | 
					sqrtm = linalg_ops.matrix_square_root
 | 
				
			||||||
svd = linalg_ops.svd
 | 
					svd = linalg_ops.svd
 | 
				
			||||||
tensordot = math_ops.tensordot
 | 
					tensordot = math_ops.tensordot
 | 
				
			||||||
trace = math_ops.trace
 | 
					trace = math_ops.trace
 | 
				
			||||||
 | 
				
			|||||||
@ -55,6 +55,71 @@ def _MatrixDeterminantGrad(op, grad):
 | 
				
			|||||||
  return multipliers * a_adj_inv
 | 
					  return multipliers * a_adj_inv
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@ops.RegisterGradient("MatrixSquareRoot")
 | 
				
			||||||
 | 
					def _MatrixSquareRootGrad(op, grad):
 | 
				
			||||||
 | 
					  """Gradient for MatrixSquareRoot."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  # Let A be an m x m square matrix (or batch of matrices)
 | 
				
			||||||
 | 
					  # Let R = sqrtm(A)
 | 
				
			||||||
 | 
					  # By definition, A = RR
 | 
				
			||||||
 | 
					  # Take the differential: dA = d(RR) = RdR + dRR
 | 
				
			||||||
 | 
					  # Solve the resulting Sylvester equation for dR
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  # Used to find Kronecker products within the Sylvester equation
 | 
				
			||||||
 | 
					  def _KroneckerProduct(b1, b2):
 | 
				
			||||||
 | 
					    """Computes the Kronecker product of two batches of square matrices"""
 | 
				
			||||||
 | 
					    b1_shape = array_ops.shape(b1)
 | 
				
			||||||
 | 
					    b2_shape = array_ops.shape(b2)
 | 
				
			||||||
 | 
					    b1_order = b1_shape[-1]
 | 
				
			||||||
 | 
					    b2_order = b2_shape[-1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    shape_slice_size = [math_ops.subtract(array_ops.size(b1_shape), 2)]
 | 
				
			||||||
 | 
					    shape_slice = array_ops.slice(b1_shape, [0],
 | 
				
			||||||
 | 
					                                  shape_slice_size)  # Same for both batches
 | 
				
			||||||
 | 
					    b1_reshape_shape = array_ops.concat(
 | 
				
			||||||
 | 
					        [shape_slice, [b1_order], [1], [b1_order], [1]], 0)
 | 
				
			||||||
 | 
					    b2_reshape_shape = array_ops.concat(
 | 
				
			||||||
 | 
					        [shape_slice, [1], [b2_order], [1], [b2_order]], 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    b1_reshape = array_ops.reshape(b1, b1_reshape_shape)
 | 
				
			||||||
 | 
					    b2_reshape = array_ops.reshape(b2, b2_reshape_shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    order_prod = b1_order * b2_order
 | 
				
			||||||
 | 
					    kprod_shape = array_ops.concat([shape_slice, [order_prod], [order_prod]], 0)
 | 
				
			||||||
 | 
					    return array_ops.reshape(b1_reshape * b2_reshape, kprod_shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  sqrtm = op.outputs[0]  # R
 | 
				
			||||||
 | 
					  shape = array_ops.shape(sqrtm)
 | 
				
			||||||
 | 
					  order = shape[-1]  # m
 | 
				
			||||||
 | 
					  matrix_count = math_ops.reduce_prod(shape[0:-2])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  # Get batch of m x m identity matrices
 | 
				
			||||||
 | 
					  eye = linalg_ops.eye(order, dtype=sqrtm.dtype)  # m x m identity matrix
 | 
				
			||||||
 | 
					  eye_flat = array_ops.reshape(eye, [-1])
 | 
				
			||||||
 | 
					  eye_tiled = array_ops.tile(eye_flat, [matrix_count])
 | 
				
			||||||
 | 
					  eye_batch = array_ops.reshape(eye_tiled, shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  # The transpose of R is taken in the k1 term instead of k2 in
 | 
				
			||||||
 | 
					  # order to prevent redundant transposition of R (i.e. (R')' = R)
 | 
				
			||||||
 | 
					  sqrtm_transpose = array_ops.matrix_transpose(sqrtm)
 | 
				
			||||||
 | 
					  k1 = _KroneckerProduct(eye_batch, sqrtm_transpose)
 | 
				
			||||||
 | 
					  k2 = _KroneckerProduct(sqrtm, eye_batch)
 | 
				
			||||||
 | 
					  ksum = math_ops.add(k1, k2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  # Vectorize dA
 | 
				
			||||||
 | 
					  shape_slice_size = [math_ops.subtract(array_ops.size(shape), 2)]
 | 
				
			||||||
 | 
					  shape_slice = array_ops.slice(shape, [0], shape_slice_size)
 | 
				
			||||||
 | 
					  shape_vec_da = array_ops.concat([shape_slice, [order * order], [1]], 0)
 | 
				
			||||||
 | 
					  vec_da = array_ops.reshape(array_ops.matrix_transpose(grad), shape_vec_da)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  # Solve for vec(dR)
 | 
				
			||||||
 | 
					  vec_dsqrtm = linalg_ops.matrix_solve(ksum, vec_da)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  # Solve for dR by inverse vectorizing vec(dR)
 | 
				
			||||||
 | 
					  dsqrtm_transpose = array_ops.reshape(vec_dsqrtm, shape)
 | 
				
			||||||
 | 
					  return array_ops.matrix_transpose(dsqrtm_transpose)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ops.RegisterGradient("LogMatrixDeterminant")
 | 
					@ops.RegisterGradient("LogMatrixDeterminant")
 | 
				
			||||||
def _LogMatrixDeterminantGrad(op, _, grad_b):
 | 
					def _LogMatrixDeterminantGrad(op, _, grad_b):
 | 
				
			||||||
  """Gradient for LogMatrixDeterminant."""
 | 
					  """Gradient for LogMatrixDeterminant."""
 | 
				
			||||||
 | 
				
			|||||||
@ -156,6 +156,10 @@ tf_module {
 | 
				
			|||||||
    name: "solve"
 | 
					    name: "solve"
 | 
				
			||||||
    argspec: "args=[\'matrix\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
 | 
					    argspec: "args=[\'matrix\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					  member_method {
 | 
				
			||||||
 | 
					    name: "sqrtm"
 | 
				
			||||||
 | 
					    argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
  member_method {
 | 
					  member_method {
 | 
				
			||||||
    name: "svd"
 | 
					    name: "svd"
 | 
				
			||||||
    argspec: "args=[\'tensor\', \'full_matrices\', \'compute_uv\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
 | 
					    argspec: "args=[\'tensor\', \'full_matrices\', \'compute_uv\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
 | 
				
			||||||
 | 
				
			|||||||
@ -1504,6 +1504,10 @@ tf_module {
 | 
				
			|||||||
    name: "matrix_solve_ls"
 | 
					    name: "matrix_solve_ls"
 | 
				
			||||||
    argspec: "args=[\'matrix\', \'rhs\', \'l2_regularizer\', \'fast\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'True\', \'None\'], "
 | 
					    argspec: "args=[\'matrix\', \'rhs\', \'l2_regularizer\', \'fast\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'True\', \'None\'], "
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					  member_method {
 | 
				
			||||||
 | 
					    name: "matrix_square_root"
 | 
				
			||||||
 | 
					    argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
  member_method {
 | 
					  member_method {
 | 
				
			||||||
    name: "matrix_transpose"
 | 
					    name: "matrix_transpose"
 | 
				
			||||||
    argspec: "args=[\'a\', \'name\', \'conjugate\'], varargs=None, keywords=None, defaults=[\'matrix_transpose\', \'False\'], "
 | 
					    argspec: "args=[\'a\', \'name\', \'conjugate\'], varargs=None, keywords=None, defaults=[\'matrix_transpose\', \'False\'], "
 | 
				
			||||||
 | 
				
			|||||||
@ -156,6 +156,10 @@ tf_module {
 | 
				
			|||||||
    name: "solve"
 | 
					    name: "solve"
 | 
				
			||||||
    argspec: "args=[\'matrix\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
 | 
					    argspec: "args=[\'matrix\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					  member_method {
 | 
				
			||||||
 | 
					    name: "sqrtm"
 | 
				
			||||||
 | 
					    argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
  member_method {
 | 
					  member_method {
 | 
				
			||||||
    name: "svd"
 | 
					    name: "svd"
 | 
				
			||||||
    argspec: "args=[\'tensor\', \'full_matrices\', \'compute_uv\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
 | 
					    argspec: "args=[\'tensor\', \'full_matrices\', \'compute_uv\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
 | 
				
			||||||
 | 
				
			|||||||
@ -1120,6 +1120,10 @@ tf_module {
 | 
				
			|||||||
    name: "matrix_solve"
 | 
					    name: "matrix_solve"
 | 
				
			||||||
    argspec: "args=[\'matrix\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
 | 
					    argspec: "args=[\'matrix\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					  member_method {
 | 
				
			||||||
 | 
					    name: "matrix_square_root"
 | 
				
			||||||
 | 
					    argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
  member_method {
 | 
					  member_method {
 | 
				
			||||||
    name: "matrix_triangular_solve"
 | 
					    name: "matrix_triangular_solve"
 | 
				
			||||||
    argspec: "args=[\'matrix\', \'rhs\', \'lower\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'None\'], "
 | 
					    argspec: "args=[\'matrix\', \'rhs\', \'lower\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'None\'], "
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user