From c8e8ba577e9a2e94885f4f423a84d42e45015652 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Wed, 15 Jan 2020 18:21:45 -0800 Subject: [PATCH] Add Broadcasted Matrix Triangular Solve. Add Numpy-style broadcasting in the batch dimensions for tf.linalg.triangular_solve op. The last two dimensions of both operands constitute the matrix dimensions. The dimensions beyond these are broadcasted to form a common output shape with the standard NumPy broadcasting rules. (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) Note: This implementation differs from Numpy's behavior in that vectors (rank-1 Tensors) are not pr... PiperOrigin-RevId: 289978628 Change-Id: I66e41e292e57e6df8111745cbe47ccffacb53edc --- .../api_def_MatrixTriangularSolve.pbtxt | 6 +- .../api_def_MatrixTriangularSolve.pbtxt | 8 +- tensorflow/core/kernels/BUILD | 25 +- tensorflow/core/kernels/cuda_solvers.cc | 100 ---- tensorflow/core/kernels/cuda_solvers.h | 22 - .../kernels/matrix_triangular_solve_op.cc | 258 +++++++++++ .../matrix_triangular_solve_op_complex.cc | 28 -- .../kernels/matrix_triangular_solve_op_impl.h | 431 ------------------ .../matrix_triangular_solve_op_real.cc | 32 -- .../matrix_triangular_solve_op_test.cc | 165 ------- tensorflow/core/ops/linalg_ops.cc | 30 +- tensorflow/core/ops/linalg_ops_test.cc | 72 ++- tensorflow/python/kernel_tests/BUILD | 1 - .../matrix_triangular_solve_op_test.py | 84 +--- tensorflow/python/ops/linalg_grad.py | 12 +- tensorflow/python/ops/linalg_ops.py | 61 --- 16 files changed, 316 insertions(+), 1019 deletions(-) create mode 100644 tensorflow/core/kernels/matrix_triangular_solve_op.cc delete mode 100644 tensorflow/core/kernels/matrix_triangular_solve_op_complex.cc delete mode 100644 tensorflow/core/kernels/matrix_triangular_solve_op_impl.h delete mode 100644 tensorflow/core/kernels/matrix_triangular_solve_op_real.cc delete mode 100644 tensorflow/core/kernels/matrix_triangular_solve_op_test.cc diff --git 
a/tensorflow/core/api_def/base_api/api_def_MatrixTriangularSolve.pbtxt b/tensorflow/core/api_def/base_api/api_def_MatrixTriangularSolve.pbtxt index bf31b2d9e4d..0ecd7937995 100644 --- a/tensorflow/core/api_def/base_api/api_def_MatrixTriangularSolve.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_MatrixTriangularSolve.pbtxt @@ -44,17 +44,15 @@ square matrices. If `lower` is `True` then the strictly upper triangular part of each inner-most matrix is assumed to be zero and not accessed. If `lower` is False then the strictly lower triangular part of each inner-most matrix is assumed to be zero and not accessed. -`rhs` is a tensor of shape `[..., M, N]`. +`rhs` is a tensor of shape `[..., M, K]`. -The output is a tensor of shape `[..., M, N]`. If `adjoint` is +The output is a tensor of shape `[..., M, K]`. If `adjoint` is `True` then the innermost matrices in `output` satisfy matrix equations `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`. If `adjoint` is `False` then the strictly then the innermost matrices in `output` satisfy matrix equations `adjoint(matrix[..., i, k]) * output[..., k, j] = rhs[..., i, j]`. -Note, the batch shapes for the inputs only need to broadcast. 
- Example: ```python diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixTriangularSolve.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixTriangularSolve.pbtxt index 8022c6d0556..17dc57335ae 100644 --- a/tensorflow/core/api_def/python_api/api_def_MatrixTriangularSolve.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_MatrixTriangularSolve.pbtxt @@ -1,4 +1,10 @@ op { graph_op_name: "MatrixTriangularSolve" - visibility: HIDDEN + endpoint { + name: "linalg.triangular_solve" + } + endpoint { + name: "matrix_triangular_solve" + deprecation_version: 2 + } } diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index c42dc636e8d..26a2d2892e0 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -3588,14 +3588,10 @@ tf_kernel_library( tf_kernel_library( name = "matrix_triangular_solve_op", - hdrs = ["matrix_triangular_solve_op_impl.h"], prefix = "matrix_triangular_solve_op", deps = LINALG_DEPS + if_cuda([ "//tensorflow/core/platform/default/build_config:cublas_plugin", - ]) + [ - ":fill_functor", - "//tensorflow/core:stream_executor", - ], + ]), ) tf_kernel_library( @@ -4183,25 +4179,6 @@ tf_cuda_cc_test( ], ) -tf_cuda_cc_test( - name = "matrix_triangular_solve_op_test", - size = "small", - srcs = ["matrix_triangular_solve_op_test.cc"], - deps = [ - ":broadcast_to_op", - ":matrix_triangular_solve_op", - ":ops_testutil", - ":ops_util", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ], -) - tf_cuda_cc_test( name = "scan_ops_test", size = "small", diff --git a/tensorflow/core/kernels/cuda_solvers.cc b/tensorflow/core/kernels/cuda_solvers.cc index dcf40ef6798..1c569204265 100644 --- a/tensorflow/core/kernels/cuda_solvers.cc +++ b/tensorflow/core/kernels/cuda_solvers.cc @@ -900,106 +900,6 @@ static inline Status MatInvBatchedImpl( 
TF_CALL_LAPACK_TYPES(MATINV_BATCHED_INSTANCE); -template -static inline Status TrsmImpl(SolverFnT solver, cublasHandle_t cublas_handle, - cublasSideMode_t side, cublasFillMode_t uplo, - cublasOperation_t trans, cublasDiagType_t diag, - int m, int n, - const Scalar* alpha, /* host or device pointer */ - const Scalar* A, int lda, Scalar* B, int ldb) { - mutex_lock lock(handle_map_mutex); - using CudaScalar = typename CUDAComplexT::type; - TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, side, uplo, trans, diag, m, n, - reinterpret_cast(alpha), - reinterpret_cast(A), lda, - reinterpret_cast(B), ldb)); - return Status::OK(); -} - -#define TRSM_INSTANCE(Scalar, type_prefix) \ - template <> \ - Status CudaSolver::Trsm( \ - cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, \ - cublasDiagType_t diag, int m, int n, \ - const Scalar* alpha, /* host or device pointer */ \ - const Scalar* A, int lda, Scalar* B, int ldb) { \ - return TrsmImpl(BLAS_SOLVER_FN(trsm, type_prefix), cublas_handle_, side, \ - uplo, trans, diag, m, n, alpha, A, lda, B, ldb); \ - } - -TF_CALL_LAPACK_TYPES(TRSM_INSTANCE); - -template -static inline Status TrsvImpl(SolverFnT solver, cublasHandle_t cublas_handle, - cublasFillMode_t uplo, cublasOperation_t trans, - cublasDiagType_t diag, int n, const Scalar* A, - int lda, Scalar* x, int incx) { - mutex_lock lock(handle_map_mutex); - using CudaScalar = typename CUDAComplexT::type; - TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, uplo, trans, diag, n, - reinterpret_cast(A), lda, - reinterpret_cast(x), incx)); - return Status::OK(); -} - -#define TRSV_INSTANCE(Scalar, type_prefix) \ - template <> \ - Status CudaSolver::Trsv( \ - cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, \ - int n, const Scalar* A, int lda, Scalar* x, int incx) { \ - return TrsvImpl(BLAS_SOLVER_FN(trsv, type_prefix), cublas_handle_, uplo, \ - trans, diag, n, A, lda, x, incx); \ - } - -TF_CALL_LAPACK_TYPES(TRSV_INSTANCE); - -template -static 
inline Status TrsmBatchedImpl( - SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context, - cublasHandle_t cublas_handle, cublasSideMode_t side, cublasFillMode_t uplo, - cublasOperation_t trans, cublasDiagType_t diag, int m, int n, - const Scalar* alpha, const Scalar* const host_a_dev_ptrs[], int lda, - Scalar* host_b_dev_ptrs[], int ldb, int batch_size) { - mutex_lock lock(handle_map_mutex); - using CudaScalar = typename CUDAComplexT::type; - ScratchSpace dev_a_dev_ptrs = - cuda_solver->GetScratchSpace(sizeof(CudaScalar*) * batch_size, "", - /* on_host */ false); - ScratchSpace dev_b_dev_ptrs = - cuda_solver->GetScratchSpace(sizeof(CudaScalar*) * batch_size, "", - /* on_host */ false); - if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */, - host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) { - return errors::Internal("TrsmBatched: failed to copy pointers to device"); - } - if (!CopyHostToDevice(context, dev_b_dev_ptrs.mutable_data() /* dest */, - host_b_dev_ptrs /* source */, dev_b_dev_ptrs.bytes())) { - return errors::Internal("TrsmBatched: failed to copy pointers to device"); - } - TF_RETURN_IF_CUBLAS_ERROR( - solver(cublas_handle, side, uplo, trans, diag, m, n, - reinterpret_cast(alpha), - reinterpret_cast(dev_a_dev_ptrs.data()), - lda, reinterpret_cast(dev_b_dev_ptrs.mutable_data()), - ldb, batch_size)); - return Status::OK(); -} - -#define TRSM_BATCHED_INSTANCE(Scalar, type_prefix) \ - template <> \ - Status CudaSolver::TrsmBatched( \ - cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, \ - cublasDiagType_t diag, int m, int n, const Scalar* alpha, \ - const Scalar* const dev_Aarray[], int lda, Scalar* dev_Barray[], \ - int ldb, int batch_size) { \ - return TrsmBatchedImpl(BLAS_SOLVER_FN(trsmBatched, type_prefix), this, \ - context_, cublas_handle_, side, uplo, trans, diag, \ - m, n, alpha, dev_Aarray, lda, dev_Barray, ldb, \ - batch_size); \ - } - -TF_CALL_LAPACK_TYPES(TRSM_BATCHED_INSTANCE); - } // 
namespace tensorflow #endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cuda_solvers.h b/tensorflow/core/kernels/cuda_solvers.h index f1e5e71b16a..104ee09a2bc 100644 --- a/tensorflow/core/kernels/cuda_solvers.h +++ b/tensorflow/core/kernels/cuda_solvers.h @@ -333,28 +333,6 @@ class CudaSolver { int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_V, int ldv, int* dev_lapack_info, int batch_size); - // Triangular solve - // Returns Status::OK() if the kernel was launched successfully. - // See https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-trsm - template - Status Trsm(cublasSideMode_t side, cublasFillMode_t uplo, - cublasOperation_t trans, cublasDiagType_t diag, int m, int n, - const Scalar* alpha, const Scalar* A, int lda, Scalar* B, - int ldb); - - template - Status Trsv(cublasFillMode_t uplo, cublasOperation_t trans, - cublasDiagType_t diag, int n, const Scalar* A, int lda, Scalar* x, - int incx); - - // See - // https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-trsmbatched - template - Status TrsmBatched(cublasSideMode_t side, cublasFillMode_t uplo, - cublasOperation_t trans, cublasDiagType_t diag, int m, - int n, const Scalar* alpha, - const Scalar* const dev_Aarray[], int lda, - Scalar* dev_Barray[], int ldb, int batch_size); private: OpKernelContext* context_; // not owned. diff --git a/tensorflow/core/kernels/matrix_triangular_solve_op.cc b/tensorflow/core/kernels/matrix_triangular_solve_op.cc new file mode 100644 index 00000000000..61bc4aad214 --- /dev/null +++ b/tensorflow/core/kernels/matrix_triangular_solve_op.cc @@ -0,0 +1,258 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/linalg_ops.cc. + +#include "third_party/eigen3/Eigen/Core" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/linalg_ops_common.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "tensorflow/core/platform/stream_executor.h" +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +namespace tensorflow { + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +namespace { +template +se::DeviceMemory AsDeviceMemory(const Scalar* gpu_memory) { + se::DeviceMemoryBase wrapped(const_cast(gpu_memory)); + se::DeviceMemory typed(wrapped); + return typed; +} +} // namespace +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +template +class MatrixTriangularSolveOp : public LinearAlgebraOp { + public: + INHERIT_LINALG_TYPEDEFS(Scalar); + + explicit MatrixTriangularSolveOp(OpKernelConstruction* context) + : Base(context), lower_(true), adjoint_(false) { + OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_)); + OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_)); + } + + void ValidateInputMatrixShapes( + OpKernelContext* context, + const TensorShapes& input_matrix_shapes) const final { + Base::ValidateSquareSolver(context, input_matrix_shapes); + } + + 
TensorShapes GetOutputMatrixShapes( + const TensorShapes& input_matrix_shapes) const final { + return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1), + input_matrix_shapes[1].dim_size(1)})}); + } + + int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final { + double rows = static_cast(input_matrix_shapes[0].dim_size(0)); + double num_rhss = static_cast(input_matrix_shapes[1].dim_size(1)); + double cost = rows * rows * num_rhss * + (Eigen::TensorOpCost::AddCost() + + Eigen::TensorOpCost::MulCost()); + return cost >= static_cast(kint64max) ? kint64max + : static_cast(cost); + } + + bool EnableInputForwarding() const final { return false; } + + void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs, + MatrixMaps* outputs) final { + const ConstMatrixMap& matrix = inputs[0]; + const ConstMatrixMap& rhs = inputs[1]; + MatrixMap& output = outputs->at(0); + + if (matrix.rows() == 0 || rhs.rows() == 0 || rhs.cols() == 0) { + // To be consistent with the MatrixInverse op, we define the solution for + // an empty set of equations as the empty matrix. 
+ return; + } + const RealScalar min_abs_pivot = matrix.diagonal().cwiseAbs().minCoeff(); + OP_REQUIRES(context, min_abs_pivot > RealScalar(0), + errors::InvalidArgument("Input matrix is not invertible.")); + if (lower_) { + auto triangle = matrix.template triangularView(); + if (adjoint_) { + output.noalias() = triangle.adjoint().solve(rhs); + } else { + output.noalias() = triangle.solve(rhs); + } + } else { + auto triangle = matrix.template triangularView(); + if (adjoint_) { + output.noalias() = triangle.adjoint().solve(rhs); + } else { + output.noalias() = triangle.solve(rhs); + } + } + } + + private: + bool lower_; + bool adjoint_; + + TF_DISALLOW_COPY_AND_ASSIGN(MatrixTriangularSolveOp); +}; + +REGISTER_LINALG_OP_CPU("MatrixTriangularSolve", + (MatrixTriangularSolveOp), float); +REGISTER_LINALG_OP_CPU("MatrixTriangularSolve", + (MatrixTriangularSolveOp), double); +REGISTER_LINALG_OP_CPU("MatrixTriangularSolve", + (MatrixTriangularSolveOp), complex64); +REGISTER_LINALG_OP_CPU("MatrixTriangularSolve", + (MatrixTriangularSolveOp), complex128); +REGISTER_LINALG_OP_CPU("BatchMatrixTriangularSolve", + (MatrixTriangularSolveOp), float); +REGISTER_LINALG_OP_CPU("BatchMatrixTriangularSolve", + (MatrixTriangularSolveOp), double); + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +// TODO(rmlarsen): Re-factor to +// 1. Enable buffer forwarding from rhs->out. +// 2. Save Memcpy when buffer forwarding is used. +// 3. Copy entire rhs in a single Memcpy when forwarding is not used. 
+template +class MatrixTriangularSolveOpGPU : public LinearAlgebraOp { + public: + INHERIT_LINALG_TYPEDEFS(Scalar); + + explicit MatrixTriangularSolveOpGPU(OpKernelConstruction* context) + : Base(context), lower_(true), adjoint_(false) { + OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_)); + OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_)); + } + + void ValidateInputMatrixShapes( + OpKernelContext* context, + const TensorShapes& input_matrix_shapes) const final { + Base::ValidateSquareSolver(context, input_matrix_shapes); + } + + TensorShapes GetOutputMatrixShapes( + const TensorShapes& input_matrix_shapes) const final { + return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1), + input_matrix_shapes[1].dim_size(1)})}); + } + + int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final { + double rows = static_cast(input_matrix_shapes[0].dim_size(0)); + double num_rhss = static_cast(input_matrix_shapes[1].dim_size(1)); + double cost = rows * rows * num_rhss * + (Eigen::TensorOpCost::AddCost() + + Eigen::TensorOpCost::MulCost()); + return cost >= static_cast(kint64max) ? kint64max + : static_cast(cost); + } + + bool EnableInputForwarding() const final { return false; } + + void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs, + MatrixMaps* outputs) final { + const ConstMatrixMap& matrix = inputs[0]; + const ConstMatrixMap& rhs = inputs[1]; + MatrixMap& output = outputs->at(0); + + if (matrix.rows() == 0 || rhs.rows() == 0 || rhs.cols() == 0) { + // To be consistent with the MatrixInverse op, we define the solution for + // an empty set of equations as the empty matrix. 
+ return; + } + + auto matrix_ptr = AsDeviceMemory(matrix.data()); + auto rhs_ptr = AsDeviceMemory(rhs.data()); + auto out_ptr = AsDeviceMemory(output.data()); + + auto* stream = context->op_device_context()->stream(); + uint64 rhs_elems = rhs.rows() * rhs.cols(); + bool copy_status = + stream->ThenMemcpyD2D(&out_ptr, rhs_ptr, sizeof(Scalar) * rhs_elems) + .ok(); + if (!copy_status) { + context->SetStatus( + errors::Internal("Failed to copy rhs into output before solve")); + } + + // Cublas does + // output = matrix \ rhs + // where matrix, rhs and output are assumed to be in column major. + // We want the output to be in row-major, so we can compute + // output' = rhs' / matrix' (' stands for transpose) + // Upper/lower needs to be swapped for this. + + se::blas::UpperLower upper_lower_matrix; + se::blas::Transpose transpose_matrix; + if (lower_) { + upper_lower_matrix = se::blas::UpperLower::kUpper; + } else { + upper_lower_matrix = se::blas::UpperLower::kLower; + } + if (adjoint_) { + transpose_matrix = se::blas::Transpose::kConjugateTranspose; + } else { + transpose_matrix = se::blas::Transpose::kNoTranspose; + } + uint64 leading_dim_matrix = matrix.cols(); + uint64 leading_dim_output = output.cols(); + uint64 colmajor_rows = output.cols(); + uint64 colmajor_cols = output.rows(); + bool blas_launch_status = + stream + ->ThenBlasTrsm( + se::blas::Side::kRight /*side*/, upper_lower_matrix /*uplo*/, + transpose_matrix /*trans*/, + se::blas::Diagonal::kNonUnit /*diag*/, colmajor_rows /*m*/, + colmajor_cols /*n*/, Scalar(1.0) /*alpha*/, matrix_ptr, + leading_dim_matrix /*lda*/, &out_ptr, + leading_dim_output /*ldb*/) + .ok(); + if (!blas_launch_status) { + context->SetStatus(errors::Internal("Blas TRSM launch failed")); + } + } + + private: + bool lower_; + bool adjoint_; + + TF_DISALLOW_COPY_AND_ASSIGN(MatrixTriangularSolveOpGPU); +}; + +REGISTER_LINALG_OP_GPU("MatrixTriangularSolve", + (MatrixTriangularSolveOpGPU), float); 
+REGISTER_LINALG_OP_GPU("MatrixTriangularSolve", + (MatrixTriangularSolveOpGPU), double); +REGISTER_LINALG_OP_GPU("MatrixTriangularSolve", + (MatrixTriangularSolveOpGPU), complex64); +REGISTER_LINALG_OP_GPU("MatrixTriangularSolve", + (MatrixTriangularSolveOpGPU), complex128); +REGISTER_LINALG_OP_GPU("BatchMatrixTriangularSolve", + (MatrixTriangularSolveOpGPU), float); +REGISTER_LINALG_OP_GPU("BatchMatrixTriangularSolve", + (MatrixTriangularSolveOpGPU), double); + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/matrix_triangular_solve_op_complex.cc b/tensorflow/core/kernels/matrix_triangular_solve_op_complex.cc deleted file mode 100644 index 1efd89367ca..00000000000 --- a/tensorflow/core/kernels/matrix_triangular_solve_op_complex.cc +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/kernels/matrix_triangular_solve_op_impl.h" - -namespace tensorflow { - -TF_CALL_complex64(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU); -TF_CALL_complex128(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU); - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -TF_CALL_complex64(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU); -TF_CALL_complex128(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU); -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -} // namespace tensorflow diff --git a/tensorflow/core/kernels/matrix_triangular_solve_op_impl.h b/tensorflow/core/kernels/matrix_triangular_solve_op_impl.h deleted file mode 100644 index 926296b3760..00000000000 --- a/tensorflow/core/kernels/matrix_triangular_solve_op_impl.h +++ /dev/null @@ -1,431 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// See docs in ../ops/linalg_ops.cc. 
-// -#ifndef TENSORFLOW_CORE_KERNELS_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_ -#define TENSORFLOW_CORE_KERNELS_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_ - -#include "third_party/eigen3/Eigen/Core" -#include "tensorflow/core/framework/kernel_def_builder.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/kernels/fill_functor.h" -#include "tensorflow/core/kernels/linalg_ops_common.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/matmul_bcast.h" - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/kernels/cuda_solvers.h" -#include "tensorflow/core/kernels/transpose_functor.h" -#include "tensorflow/core/platform/stream_executor.h" -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -namespace tensorflow { - -typedef Eigen::ThreadPoolDevice CPUDevice; -typedef Eigen::GpuDevice GPUDevice; - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -template -se::DeviceMemory AsDeviceMemory(const Scalar* gpu_memory) { - se::DeviceMemoryBase wrapped(const_cast(gpu_memory)); - se::DeviceMemory typed(wrapped); - return typed; -} - -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -// Sequential batch matrix triangular solve kernel that calls Eigen's -// matrix triangular solve. 
-template -struct SequentialMatrixTriangularSolveKernel { - using Matrix = - Eigen::Matrix; - using ConstMatrixMap = Eigen::Map; - using MatrixMap = Eigen::Map; - using RealScalar = typename Eigen::NumTraits::Real; - - static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t, - int slice) { - return ConstMatrixMap( - t.flat().data() + slice * t.dim_size(1) * t.dim_size(2), - t.dim_size(1), t.dim_size(2)); - } - - static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) { - return MatrixMap( - t->flat().data() + slice * t->dim_size(1) * t->dim_size(2), - t->dim_size(1), t->dim_size(2)); - } - - static void Run(const Tensor& in_x, const Tensor& in_y, bool lower, - bool adjoint, const MatMulBCast& bcast, Tensor* out, - int start, int limit) { - const bool should_bcast = bcast.IsBroadcastingRequired(); - const auto& x_batch_indices = bcast.x_batch_indices(); - const auto& y_batch_indices = bcast.y_batch_indices(); - for (int64 i = start; i < limit; ++i) { - const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i; - const int64 y_batch_index = should_bcast ? 
y_batch_indices[i] : i; - auto matrix = ConstTensorSliceToEigenMatrix(in_x, x_batch_index); - auto rhs = ConstTensorSliceToEigenMatrix(in_y, y_batch_index); - auto output = TensorSliceToEigenMatrix(out, i); - if (lower) { - auto triangle = matrix.template triangularView(); - if (adjoint) { - output.noalias() = triangle.adjoint().solve(rhs); - } else { - output.noalias() = triangle.solve(rhs); - } - } else { - auto triangle = matrix.template triangularView(); - if (adjoint) { - output.noalias() = triangle.adjoint().solve(rhs); - } else { - output.noalias() = triangle.solve(rhs); - } - } - } - } -}; - -template -struct LaunchBatchMatrixTriangularSolve; - -template -struct LaunchBatchMatrixTriangularSolve { - static void Launch(OpKernelContext* context, const Tensor& in_x, - const Tensor& in_y, bool adjoint, bool lower, - const MatMulBCast& bcast, Tensor* out) { - // Number of matrix triangular solves i.e. size of the batch. - const int64 batch_size = bcast.output_batch_size(); - const int64 cost_per_unit = - in_x.dim_size(1) * in_x.dim_size(1) * in_y.dim_size(2) / 2; - auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); - - using Matrix = - Eigen::Matrix; - using ConstMatrixMap = Eigen::Map; - using RealScalar = typename Eigen::NumTraits::Real; - // Check diagonal before doing any solves. 
- auto matrix = ConstMatrixMap(in_x.flat().data(), in_x.dim_size(1), - in_x.dim_size(2)); - const RealScalar min_abs_pivot = matrix.diagonal().cwiseAbs().minCoeff(); - OP_REQUIRES(context, min_abs_pivot > RealScalar(0), - errors::InvalidArgument("Input matrix is not invertible.")); - - Shard(worker_threads.num_threads, worker_threads.workers, batch_size, - cost_per_unit, - [&in_x, &in_y, adjoint, lower, &bcast, out](int start, int limit) { - SequentialMatrixTriangularSolveKernel::Run( - in_x, in_y, lower, adjoint, bcast, out, start, limit); - }); - } -}; - -template -class BaseMatrixTriangularSolveOp : public OpKernel { - public: - explicit BaseMatrixTriangularSolveOp(OpKernelConstruction* context) - : OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_)); - OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_)); - } - - ~BaseMatrixTriangularSolveOp() override {} - - void Compute(OpKernelContext* ctx) override { - const Tensor& in0 = ctx->input(0); - const Tensor& in1 = ctx->input(1); - - ValidateInputTensors(ctx, in0, in1); - - MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes()); - OP_REQUIRES( - ctx, bcast.IsValid(), - errors::InvalidArgument( - "In[0] and In[1] must have compatible batch dimensions: ", - in0.shape().DebugString(), " vs. 
", in1.shape().DebugString())); - - TensorShape out_shape = bcast.output_batch_shape(); - auto batch_size = bcast.output_batch_size(); - auto d0 = in0.dim_size(in0.dims() - 2); - auto d1 = in0.dim_size(in0.dims() - 1); - Tensor in0_reshaped; - OP_REQUIRES( - ctx, - in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})), - errors::Internal("Failed to reshape In[0] from ", - in0.shape().DebugString())); - auto d2 = in1.dim_size(in1.dims() - 2); - auto d3 = in1.dim_size(in1.dims() - 1); - Tensor in1_reshaped; - OP_REQUIRES( - ctx, - in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})), - errors::Internal("Failed to reshape In[1] from ", - in1.shape().DebugString())); - if (adjoint_) std::swap(d0, d1); - OP_REQUIRES(ctx, d1 == d2, - errors::InvalidArgument( - "In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ", - in0.shape().DebugString(), " ", in1.shape().DebugString(), - " ", lower_, " ", adjoint_)); - out_shape.AddDim(d0); - out_shape.AddDim(d3); - Tensor* out = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out)); - if (out->NumElements() == 0) { - return; - } - Tensor out_reshaped; - OP_REQUIRES(ctx, - out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})), - errors::Internal("Failed to reshape output from ", - out->shape().DebugString())); - LaunchBatchMatrixTriangularSolve::Launch( - ctx, in0_reshaped, in1_reshaped, adjoint_, lower_, bcast, - &out_reshaped); - } - - private: - virtual void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0, - const Tensor& in1) = 0; - bool lower_; - bool adjoint_; -}; - -template -class MatrixTriangularSolveOp - : public BaseMatrixTriangularSolveOp { - public: - explicit MatrixTriangularSolveOp(OpKernelConstruction* context) - : BaseMatrixTriangularSolveOp(context) {} - - ~MatrixTriangularSolveOp() override {} - - private: - void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0, - const Tensor& in1) override { - // Disallow broadcasting 
support. Ensure that all batch dimensions of the - // input tensors match. - OP_REQUIRES(ctx, in0.dims() == in1.dims(), - errors::InvalidArgument("In[0] and In[1] has different ndims: ", - in0.shape().DebugString(), " vs. ", - in1.shape().DebugString())); - const int ndims = in0.dims(); - OP_REQUIRES( - ctx, ndims >= 2, - errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims)); - for (int i = 0; i < ndims - 2; ++i) { - OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i), - errors::InvalidArgument( - "In[0].dim(", i, ") and In[1].dim(", i, - ") must be the same: ", in0.shape().DebugString(), " vs ", - in1.shape().DebugString())); - } - } -}; - -template -class MatrixTriangularSolveOpV2 - : public BaseMatrixTriangularSolveOp { - public: - explicit MatrixTriangularSolveOpV2(OpKernelConstruction* context) - : BaseMatrixTriangularSolveOp(context) {} - - ~MatrixTriangularSolveOpV2() override {} - - private: - void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0, - const Tensor& in1) override { - OP_REQUIRES( - ctx, in0.dims() >= 2, - errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims())); - - OP_REQUIRES( - ctx, in1.dims() >= 2, - errors::InvalidArgument("In[0] ndims must be >= 2: ", in1.dims())); - } -}; - -#define REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU(TYPE) \ - REGISTER_KERNEL_BUILDER(Name("MatrixTriangularSolve") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T"), \ - MatrixTriangularSolveOpV2); \ - REGISTER_KERNEL_BUILDER(Name("BatchMatrixTriangularSolve") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T"), \ - MatrixTriangularSolveOpV2); - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -template -struct LaunchBatchMatrixTriangularSolve { - static void Launch(OpKernelContext* context, const Tensor& in_x, - const Tensor& in_y, bool adjoint, bool lower, - const MatMulBCast& bcast, Tensor* out) { - auto* stream = context->op_device_context()->stream(); - - const uint64 m = in_x.dim_size(1); - const uint64 n = out->dim_size(2); - - // 
Do a memcpy when we don't need to broadcast. - if (!bcast.IsBroadcastingRequired() || out->shape() == in_y.shape()) { - auto src_device_mem = AsDeviceMemory(in_y.template flat().data()); - auto dst_device_mem = AsDeviceMemory(out->template flat().data()); - OP_REQUIRES( - context, - stream - ->ThenMemcpyD2D(&dst_device_mem, src_device_mem, - bcast.y_batch_size() * m * n * sizeof(Scalar)) - .ok(), - errors::Internal("MatrixTriangularSolveOpV2: failed to copy rhs " - "from device")); - } else { - std::vector out_ptrs; - std::vector b_tmp_ptrs; - auto* b_base_ptr = in_y.template flat().data(); - const std::vector& b_batch_indices = bcast.y_batch_indices(); - for (int64 i = 0; i < bcast.y_batch_size(); ++i) { - b_tmp_ptrs.push_back(b_base_ptr + i * m * n); - } - for (int64 i = 0; i < bcast.output_batch_size(); ++i) { - auto src_device_mem = AsDeviceMemory(b_tmp_ptrs[b_batch_indices[i]]); - auto dst_device_mem = - AsDeviceMemory(out->template flat().data() + i * m * n); - OP_REQUIRES( - context, - stream - ->ThenMemcpyD2D(&dst_device_mem, src_device_mem, - m * n * sizeof(Scalar)) - .ok(), - errors::Internal("MatrixTriangularSolveOpV2: failed to copy rhs " - "from device")); - } - } - - if (out->NumElements() == 0) { - return; - } - - cublasSideMode_t side = CUBLAS_SIDE_RIGHT; - cublasFillMode_t uplo; - cublasOperation_t trans; - cublasDiagType_t diag = CUBLAS_DIAG_NON_UNIT; - - // Cublas does - // output = matrix \ rhs - // where matrix, rhs and output are assumed to be in column major. - // We want the output to be in row-major, so we can compute - // output' = rhs' / matrix' (' stands for transpose) - // Upper/lower needs to be swapped for this. - - uplo = lower ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER; - trans = adjoint ? 
CUBLAS_OP_C : CUBLAS_OP_N; - auto solver = absl::make_unique(context); - - const uint64 leading_dim_matrix = m; - const uint64 leading_dim_output = n; - const uint64 colmajor_rows = n; - const uint64 colmajor_cols = m; - - const int64 batch_size = bcast.output_batch_size(); - std::vector a_ptrs; - std::vector out_ptrs; - std::vector a_tmp_ptrs; - a_ptrs.reserve(batch_size); - out_ptrs.reserve(batch_size); - a_tmp_ptrs.reserve(bcast.x_batch_size()); - auto* a_base_ptr = in_x.template flat().data(); - auto* out_base_ptr = out->template flat().data(); - - if (!bcast.IsBroadcastingRequired()) { - for (int64 i = 0; i < batch_size; ++i) { - a_ptrs.push_back(a_base_ptr + i * m * m); - out_ptrs.push_back(out_base_ptr + i * m * n); - } - } else { - const std::vector& a_batch_indices = bcast.x_batch_indices(); - for (int64 i = 0; i < bcast.x_batch_size(); ++i) { - a_tmp_ptrs.push_back(a_base_ptr + i * m * m); - } - for (int64 i = 0; i < batch_size; ++i) { - a_ptrs.push_back(a_tmp_ptrs[a_batch_indices[i]]); - out_ptrs.push_back(out_base_ptr + i * m * n); - } - } - - typedef Scalar Coefficient; - const Scalar alpha = Scalar(1.0); - - // TODO(b/146763573): Consider using Trsv here when the right hand side is - // a vector. This will require an explicit transpose since Trsv assumes - // CUBLAS_SIDE_LEFT. - if (batch_size == 1) { - OP_REQUIRES_OK( - context, - solver->Trsm(side, uplo, trans, diag, colmajor_rows, colmajor_cols, - &alpha, a_ptrs[0], leading_dim_matrix /*lda*/, - out_ptrs[0], leading_dim_output /*ldb*/)); - } else { - // Heuristic for choosing between batched interface vs. non-batched - // interface. This is inspired by matrix_solve_op and can probably be - // tuned. - // TODO(b/146763573): Tune this heuristic. 
- const int kMaxMatrixSizeToBatchSizeRatio = 128; - const bool use_batched_solver = - m <= kMaxMatrixSizeToBatchSizeRatio * batch_size; - if (use_batched_solver) { - OP_REQUIRES_OK( - context, solver->TrsmBatched( - side, uplo, trans, diag, colmajor_rows, colmajor_cols, - &alpha, &a_ptrs[0], leading_dim_matrix /*lda*/, - &out_ptrs[0], leading_dim_output /*ldb*/, batch_size)); - } else { - for (int batch = 0; batch < batch_size; ++batch) { - OP_REQUIRES_OK( - context, solver->Trsm(side, uplo, trans, diag, colmajor_rows, - colmajor_cols, &alpha, a_ptrs[batch], - leading_dim_matrix /*lda*/, out_ptrs[batch], - leading_dim_output /*ldb*/)); - } - } - } - } -}; - -#define REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU(TYPE) \ - REGISTER_KERNEL_BUILDER(Name("MatrixTriangularSolve") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("T"), \ - MatrixTriangularSolveOpV2); \ - REGISTER_KERNEL_BUILDER(Name("BatchMatrixTriangularSolve") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("T"), \ - MatrixTriangularSolveOpV2); - -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_KERNELS_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_ diff --git a/tensorflow/core/kernels/matrix_triangular_solve_op_real.cc b/tensorflow/core/kernels/matrix_triangular_solve_op_real.cc deleted file mode 100644 index 0f92964dd72..00000000000 --- a/tensorflow/core/kernels/matrix_triangular_solve_op_real.cc +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/kernels/matrix_triangular_solve_op_impl.h" - -#if GOOGLE_CUDA -#include "third_party/gpus/cuda/include/cuda.h" -#endif // GOOGLE_CUDA - -namespace tensorflow { - -TF_CALL_float(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU); -TF_CALL_double(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU); - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -TF_CALL_float(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU); -TF_CALL_double(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU); -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -} // namespace tensorflow diff --git a/tensorflow/core/kernels/matrix_triangular_solve_op_test.cc b/tensorflow/core/kernels/matrix_triangular_solve_op_test.cc deleted file mode 100644 index 7bb71ae8b68..00000000000 --- a/tensorflow/core/kernels/matrix_triangular_solve_op_test.cc +++ /dev/null @@ -1,165 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/node_builder.h" -#include "tensorflow/core/graph/testlib.h" -#include "tensorflow/core/kernels/broadcast_to_op.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/test_benchmark.h" - -namespace tensorflow { -namespace { - -Node* BroadcastTo(Graph* g, Node* input, Node* shape) { - Node* ret; - TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BroadcastTo") - .Input(input) - .Input(shape) - .Attr("Tidx", DT_INT64) - .Finalize(g, &ret)); - return ret; -} - -Node* MatrixTriangularSolve(Graph* g, Node* in0, Node* in1, bool adjoint) { - Node* ret; - TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatrixTriangularSolve") - .Input(in0) - .Input(in1) - .Attr("lower", true) - .Attr("adjoint", adjoint) - .Finalize(g, &ret)); - return ret; -} - -template -static Graph* MatrixTriangularSolveWithBroadcast(int64 b0, int64 b1, int64 m, - int64 n, bool manual_broadcast, - DataType type) { - Graph* g = new Graph(OpRegistry::Global()); - Tensor in0(type, TensorShape({b0, m, m})); - // Set diagonal to non-zero to guarantee invertibility. 
- in0.flat().setRandom(); - auto matrix = Eigen::Map< - Eigen::Matrix>( - in0.flat().data(), in0.dim_size(1), in0.dim_size(2)); - - matrix.diagonal() = - (matrix.diagonal().cwiseAbs().array() + static_cast(0.5)); - Tensor in1(type, TensorShape({b1, m, n})); - in1.flat().setRandom(); - - Tensor broadcasted_in0_shape(DT_INT64, TensorShape({3})); - Tensor broadcasted_in1_shape(DT_INT64, TensorShape({3})); - - Node* in0_node = nullptr; - Node* in1_node = nullptr; - if (manual_broadcast) { - auto vec0 = broadcasted_in0_shape.vec(); - auto vec1 = broadcasted_in1_shape.vec(); - for (int i = 0; i < 3; ++i) { - vec0(i) = (i == 0 ? std::max(b0, b1) : in0.shape().dim_size(i)); - vec1(i) = (i == 0 ? std::max(b0, b1) : in1.shape().dim_size(i)); - } - in0_node = BroadcastTo(g, test::graph::Constant(g, in0), - test::graph::Constant(g, broadcasted_in0_shape)); - in1_node = BroadcastTo(g, test::graph::Constant(g, in1), - test::graph::Constant(g, broadcasted_in1_shape)); - } else { - in0_node = test::graph::Constant(g, in0); - in1_node = test::graph::Constant(g, in1); - } - - MatrixTriangularSolve(g, in0_node, in1_node, false); - return g; -} - -// Macro arguments names: --------------------------------------------------- // -// B1: batch size of LHS -// B2: batch size of RHS -// M: inner dimensions of LHS and RHS, outer dimension of LHS -// N: outer dimension of RHS -// MB: boolean indicating whether to use manual broadcasting -// T: C++ type of scalars (e.g. float, std::complex) -// TT: TensorFlow type of scalars (e.g. DT_FLOAT, DT_COMPLEX128 -// D: Device (e.g. 
cpu, gpu) -#define BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, T, TT, D) \ - static void \ - BM_MatrixTriangularSolve##_##B1##_##B2##_##M##_##N##_##MB##_##TT##_##D( \ - int iters) { \ - testing::UseRealTime(); \ - testing::ItemsProcessed(static_cast(iters) * std::max(B1, B2) * M * \ - M * N * 2); \ - test::Benchmark( \ - #D, MatrixTriangularSolveWithBroadcast(B1, B2, M, N, MB, TT)) \ - .Run(iters); \ - } \ - BENCHMARK( \ - BM_MatrixTriangularSolve##_##B1##_##B2##_##M##_##N##_##MB##_##TT##_##D); - -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -#define BM_MatrixTriangularSolve(B1, B2, M, N, MB) \ - BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, float, DT_FLOAT, cpu); \ - BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, double, DT_DOUBLE, cpu); \ - BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, float, DT_FLOAT, gpu); \ - BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, double, DT_DOUBLE, gpu); - -#else - -#define BM_MatrixTriangularSolve(B1, B2, M, N, MB) \ - BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, float, DT_FLOAT, cpu); \ - BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, double, DT_DOUBLE, cpu); - -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - -// Square matrix triangular solve. 
-BM_MatrixTriangularSolve(32, 32, 512, 512, true); -BM_MatrixTriangularSolve(32, 32, 512, 512, false); -BM_MatrixTriangularSolve(1, 32, 512, 512, true); -BM_MatrixTriangularSolve(1, 32, 512, 512, false); -BM_MatrixTriangularSolve(32, 1, 512, 512, true); -BM_MatrixTriangularSolve(32, 1, 512, 512, false); -BM_MatrixTriangularSolve(128, 128, 512, 512, true); -BM_MatrixTriangularSolve(128, 128, 512, 512, false); -BM_MatrixTriangularSolve(1, 128, 512, 512, true); -BM_MatrixTriangularSolve(1, 128, 512, 512, false); -BM_MatrixTriangularSolve(128, 1, 512, 512, true); -BM_MatrixTriangularSolve(128, 1, 512, 512, false); -BM_MatrixTriangularSolve(1, 128, 1024, 1024, true); -BM_MatrixTriangularSolve(1, 128, 1024, 1024, false); -BM_MatrixTriangularSolve(128, 1, 1024, 1024, true); -BM_MatrixTriangularSolve(128, 1, 1024, 1024, false); - -// Matrix-vector triangular solve. -BM_MatrixTriangularSolve(1, 128, 200, 1, true); -BM_MatrixTriangularSolve(1, 128, 200, 1, false); -BM_MatrixTriangularSolve(128, 1, 200, 1, true); -BM_MatrixTriangularSolve(128, 1, 200, 1, false); - -// Matrix-vector triangular solve, large dimension. -BM_MatrixTriangularSolve(1, 128, 200, 10000, true); -BM_MatrixTriangularSolve(1, 128, 200, 10000, false); -BM_MatrixTriangularSolve(128, 1, 200, 10000, true); -BM_MatrixTriangularSolve(128, 1, 200, 10000, false); - -} // namespace -} // namespace tensorflow diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc index 75340b28eb0..4572df279b7 100644 --- a/tensorflow/core/ops/linalg_ops.cc +++ b/tensorflow/core/ops/linalg_ops.cc @@ -84,34 +84,6 @@ Status MatrixSolveShapeFn(InferenceContext* c, bool square) { return Status::OK(); } -// The first input is [...,M,M] and second input is [...,M,N]. -// Output is [...,M,N]. 
-Status MatrixTriangularSolveShapeFn(InferenceContext* c) { - ShapeHandle lhs; - ShapeHandle rhs; - TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &lhs)); - TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs)); - - ShapeHandle lhs_batch_shape; - ShapeHandle rhs_batch_shape; - ShapeHandle output_batch_shape; - // Make the common batch subshape. - TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape)); - TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape)); - TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper( - c, lhs_batch_shape, rhs_batch_shape, true, &output_batch_shape)); - DimensionHandle m; - // lhs and rhs have the same value for m to be compatible. - TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, -1), c->Dim(rhs, -2), &m)); - - ShapeHandle out; - // Build final shape (batch_shape + m + n) in . - TF_RETURN_IF_ERROR( - c->Concatenate(output_batch_shape, c->Matrix(m, c->Dim(rhs, -1)), &out)); - c->set_output(0, out); - return Status::OK(); -} - // Input is [...,N,N]. Outputs are: // [...,N];[0], if compute_v is false, // [...,N];[...,N,N], if compute_v is true. 
@@ -454,7 +426,7 @@ REGISTER_OP("MatrixTriangularSolve") .Attr("adjoint: bool = False") .Attr("T: {double, float, half, complex64, complex128}") .SetShapeFn([](InferenceContext* c) { - return MatrixTriangularSolveShapeFn(c); + return MatrixSolveShapeFn(c, true /* square (*/); }); REGISTER_OP("MatrixSolveLs") diff --git a/tensorflow/core/ops/linalg_ops_test.cc b/tensorflow/core/ops/linalg_ops_test.cc index 7e5ddc02339..682a994e890 100644 --- a/tensorflow/core/ops/linalg_ops_test.cc +++ b/tensorflow/core/ops/linalg_ops_test.cc @@ -122,54 +122,34 @@ TEST(LinalgOpsTest, SelfAdjointEigV2_ShapeFn) { "[d0_0,d0_1,d0_2,d0_3|d0_4];[d0_0,d0_1,d0_2,d0_3|d0_4,d0_3|d0_4]"); } -TEST(LinalgOpsTest, MatrixSolve_ShapeFn) { - ShapeInferenceTestOp op("MatrixSolve"); - INFER_OK(op, "?;?", "?"); - INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1];?"); - INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,2];?"); - INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[5,?,?];[6]"); - INFER_ERROR("Shapes must be equal rank, but are 0 and 1", op, - "[5,?];[6,?,?]"); +TEST(LinalgOpsTest, SquareMatrixSolve_ShapeFn) { + for (const char* op_name : {"MatrixSolve", "MatrixTriangularSolve"}) { + ShapeInferenceTestOp op(op_name); + INFER_OK(op, "?;?", "?"); + INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1];?"); + INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,2];?"); + INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, + "[5,?,?];[6]"); + INFER_ERROR("Shapes must be equal rank, but are 0 and 1", op, + "[5,?];[6,?,?]"); - INFER_OK(op, "[?,?];?", "[d0_0|d0_1,?]"); + INFER_OK(op, "[?,?];?", "[d0_0|d0_1,?]"); - // Inputs are [...,M,M] and [...,M,K]. Output is [...,M,K]. - // First test where ... is empty. 
- INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]"); - INFER_OK(op, "[?,?];[1,?]", "[d1_0,d1_1]"); - INFER_OK(op, "[1,?];[1,?]", "[d0_0|d1_0,d1_1]"); - INFER_OK(op, "[?,1];[1,?]", "[d0_1|d1_0,d1_1]"); - INFER_OK(op, "[1,1];[?,?]", "[d0_0,d1_1]"); - INFER_OK(op, "[1,1];[1,?]", "[d0_0|d0_1|d1_0,d1_1]"); - // Test with ... being 2-d. - INFER_OK(op, "[10,?,?,?];[?,20,1,?]", "[d0_0,d1_1,d1_2,d1_3]"); - INFER_OK(op, "[10,?,1,?];[?,20,1,?]", "[d0_0,d1_1,d0_2|d1_2,d1_3]"); - INFER_OK(op, "[10,?,?,1];[?,20,1,?]", "[d0_0,d1_1,d0_3|d1_2,d1_3]"); - INFER_OK(op, "[10,?,1,1];[?,20,?,?]", "[d0_0,d1_1,d0_2,d1_3]"); - INFER_OK(op, "[10,?,1,1];[?,20,1,?]", "[d0_0,d1_1,d0_2|d0_3|d1_2,d1_3]"); -} - -TEST(LinalgOpsTest, MatrixTriangularSolve_ShapeFn) { - ShapeInferenceTestOp op("MatrixTriangularSolve"); - INFER_OK(op, "?;?", "?"); - INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1];?"); - INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,2];?"); - INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[5,?,?];[6]"); - - // Inputs are [...,M,M] and [...,M,K]. Output is [...,M,K]. - // First test where ... is empty. - INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]"); - INFER_OK(op, "[?,?];[1,?]", "[d1_0,d1_1]"); - INFER_OK(op, "[1,?];[1,?]", "[d0_0|d1_0,d1_1]"); - INFER_OK(op, "[?,1];[1,?]", "[d0_1|d1_0,d1_1]"); - INFER_OK(op, "[1,1];[?,?]", "[d0_0,d1_1]"); - INFER_OK(op, "[1,1];[1,?]", "[d0_0|d0_1|d1_0,d1_1]"); - // Test with ... being 2-d. - INFER_OK(op, "[10,?,?,?];[?,20,1,?]", "[d0_0,d1_1,d1_2,d1_3]"); - INFER_OK(op, "[10,?,1,?];[?,20,1,?]", "[d0_0,d1_1,d0_2|d1_2,d1_3]"); - INFER_OK(op, "[10,?,?,1];[?,20,1,?]", "[d0_0,d1_1,d0_3|d1_2,d1_3]"); - INFER_OK(op, "[10,?,1,1];[?,20,?,?]", "[d0_0,d1_1,d0_2,d1_3]"); - INFER_OK(op, "[10,?,1,1];[?,20,1,?]", "[d0_0,d1_1,d0_2|d0_3|d1_2,d1_3]"); + // Inputs are [...,M,M] and [...,M,K]. Output is [...,M,K]. + // First test where ... is empty. 
+ INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]"); + INFER_OK(op, "[?,?];[1,?]", "[d1_0,d1_1]"); + INFER_OK(op, "[1,?];[1,?]", "[d0_0|d1_0,d1_1]"); + INFER_OK(op, "[?,1];[1,?]", "[d0_1|d1_0,d1_1]"); + INFER_OK(op, "[1,1];[?,?]", "[d0_0,d1_1]"); + INFER_OK(op, "[1,1];[1,?]", "[d0_0|d0_1|d1_0,d1_1]"); + // Test with ... being 2-d. + INFER_OK(op, "[10,?,?,?];[?,20,1,?]", "[d0_0,d1_1,d1_2,d1_3]"); + INFER_OK(op, "[10,?,1,?];[?,20,1,?]", "[d0_0,d1_1,d0_2|d1_2,d1_3]"); + INFER_OK(op, "[10,?,?,1];[?,20,1,?]", "[d0_0,d1_1,d0_3|d1_2,d1_3]"); + INFER_OK(op, "[10,?,1,1];[?,20,?,?]", "[d0_0,d1_1,d0_2,d1_3]"); + INFER_OK(op, "[10,?,1,1];[?,20,1,?]", "[d0_0,d1_1,d0_2|d0_3|d1_2,d1_3]"); + } } TEST(LinalgOpsTest, MatrixSolveLs_ShapeFn) { diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 5b7b1b9ecbe..6ea17b4fa5a 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -756,7 +756,6 @@ cuda_py_test( name = "matrix_triangular_solve_op_test", size = "small", srcs = ["matrix_triangular_solve_op_test.py"], - shard_count = 2, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:linalg_ops", diff --git a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py index 1c2407a7c72..32ab6125717 100644 --- a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py +++ b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.framework import constant_op from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops @@ -67,32 +68,31 @@ class MatrixTriangularSolveOpTest(test.TestCase): else: a_np = a if adjoint: - axes = list(range(len(a_np.shape))) - axes[-2] = -1 - axes[-1] = -2 - a_np = np.conj(np.transpose(a_np, axes=axes)) + a_np = 
np.conj(np.transpose(a_np)) if batch_dims is not None: a = np.tile(a, batch_dims + [1, 1]) a_np = np.tile(a_np, batch_dims + [1, 1]) b = np.tile(b, batch_dims + [1, 1]) - def broadcast(a, b): - b1 = b + np.zeros(a.shape[:-2] + (1, 1), dtype=b.dtype) - return a, b1 - - a_tf = a - b_tf = b - if use_placeholder: - a_tf = array_ops.placeholder_with_default(a_tf, shape=None) - b_tf = array_ops.placeholder_with_default(b_tf, shape=None) - tf_ans = linalg_ops.matrix_triangular_solve( - a_tf, b_tf, lower=lower, adjoint=adjoint) - tf_val = self.evaluate(tf_ans) - a_np, b = broadcast(a_np, b) - np_ans = np.linalg.solve(a_np, b) - self.assertEqual(np_ans.shape, tf_val.shape) - self.assertAllClose(np_ans, tf_val) + with self.cached_session(use_gpu=True) as sess: + if use_placeholder: + a_tf = array_ops.placeholder(a.dtype) + b_tf = array_ops.placeholder(b.dtype) + tf_ans = linalg_ops.matrix_triangular_solve( + a_tf, b_tf, lower=lower, adjoint=adjoint) + tf_val = sess.run(tf_ans, feed_dict={a_tf: a, b_tf: b}) + np_ans = np.linalg.solve(a_np, b) + else: + a_tf = constant_op.constant(a) + b_tf = constant_op.constant(b) + tf_ans = linalg_ops.matrix_triangular_solve( + a_tf, b_tf, lower=lower, adjoint=adjoint) + tf_val = self.evaluate(tf_ans) + np_ans = np.linalg.solve(a_np, b) + self.assertEqual(np_ans.shape, tf_ans.get_shape()) + self.assertEqual(np_ans.shape, tf_val.shape) + self.assertAllClose(np_ans, tf_val) @test_util.run_deprecated_v1 def testSolve(self): @@ -136,50 +136,6 @@ class MatrixTriangularSolveOpTest(test.TestCase): # Batch of 3x2x2x2 matrices, 3x2x2x3 right-hand sides. 
self._verifySolveAllWaysReal(matrix, rhs, batch_dims=[3, 2]) - @test_util.run_deprecated_v1 - @test_util.disable_xla("XLA cannot broadcast triangular solve.") - def testSolveBatchBroadcast(self): - # 2 x 2 x 2 - matrix = np.array([[[1., 0.], [3., 4.]], [[1., 0.], [2., 1.]]]) - # 2 x 3 - rhs = np.array([[1., 0., 1.], [0., 1., 1.]]) - # 2 x 2 x 3 - self._verifySolveAllWaysReal(matrix, rhs) - # 2 x 2 x 2 - matrix2 = np.array([[[1., 0.], [3., 4.]], [[2., 0.], [1., 6.3]]]) - # 1 x 2 x 3 - rhs = np.array([[[1., 0., 1.], [0., 1., 1.]]]) - # 2 x 2 x 3 - self._verifySolveAllWaysReal(matrix2, rhs) - - @test_util.run_deprecated_v1 - @test_util.disable_xla("XLA cannot broadcast triangular solve.") - def testSolveBatchBroadcastLargerBatches(self): - # 1 x 10 x 10 - matrix = np.random.uniform(low=1, high=2., size=[1, 10, 10]) - # 10 x 1 - rhs = np.random.uniform(size=[10, 1]) - # 1 x 10 x 1 - self._verifySolveAllWaysReal(matrix, rhs) - - # 2 x 10 x 10 - matrix = np.random.uniform(low=1, high=2., size=[2, 10, 10]) - # 10 x 1 - rhs = np.random.uniform(size=[10, 1]) - # 2 x 10 x 1 - self._verifySolveAllWaysReal(matrix, rhs) - - # 2 x 257 x 257 - matrix = np.random.uniform(low=1, high=2., size=[2, 257, 257]) - # Also ensure the matrix is well conditioned by making it diagonally - # dominant. 
- np.fill_diagonal(matrix[0, ...], 257 * 2) - np.fill_diagonal(matrix[1, ...], 257 * 2) - # 257 x 1 - rhs = np.random.uniform(size=[257, 1]) - # 2 x 257 x 1 - self._verifySolveAllWaysReal(matrix, rhs) - @test_util.run_deprecated_v1 def testSolveBatchComplex(self): if test.is_built_with_rocm(): diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py index 94ef2a9bff4..3e6d22accec 100644 --- a/tensorflow/python/ops/linalg_grad.py +++ b/tensorflow/python/ops/linalg_grad.py @@ -607,7 +607,6 @@ def _MatrixSolveLsGrad(op, grad): def _MatrixTriangularSolveGrad(op, grad): """Gradient for MatrixTriangularSolve.""" a = op.inputs[0] - b = op.inputs[1] adjoint_a = op.get_attr("adjoint") lower_a = op.get_attr("lower") c = op.outputs[0] @@ -621,16 +620,7 @@ def _MatrixTriangularSolveGrad(op, grad): grad_a = array_ops.matrix_band_part(grad_a, -1, 0) else: grad_a = array_ops.matrix_band_part(grad_a, 0, -1) - # If the static batch shapes are equal, we don't need to unbroadcast. 
- if (a.shape.is_fully_defined() and b.shape.is_fully_defined() and - a.shape[:-2] == b.shape[:-2]): - return grad_a, grad_b - a_shape = array_ops.shape(a) - b_shape = array_ops.shape(b) - ra, rb = array_ops.broadcast_gradient_args(a_shape[:-2], b_shape[:-2]) - grad_a = array_ops.reshape(math_ops.reduce_sum(grad_a, axis=ra), a_shape) - grad_b = array_ops.reshape(math_ops.reduce_sum(grad_b, axis=rb), b_shape) - return grad_a, grad_b + return (grad_a, grad_b) @ops.RegisterGradient("SelfAdjointEigV2") diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py index 04678cca8e5..bb84c3f7dd9 100644 --- a/tensorflow/python/ops/linalg_ops.py +++ b/tensorflow/python/ops/linalg_ops.py @@ -79,67 +79,6 @@ def _RegularizedGramianCholesky(matrix, l2_regularizer, first_kind): return gen_linalg_ops.cholesky(gramian) -@tf_export( - 'linalg.triangular_solve', - v1=['linalg.triangular_solve', 'matrix_triangular_solve']) -def matrix_triangular_solve(matrix, rhs, lower=True, adjoint=False, name=None): - """Solve systems of linear equations with upper or lower triangular matrices. - - `matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form - square matrices. If `lower` is `True` then the strictly upper triangular part - of each inner-most matrix is assumed to be zero and not accessed. If `lower` - is `False` then the strictly lower triangular part of each inner-most matrix - is assumed to be zero and not accessed. `rhs` is a tensor of shape - `[..., M, N]`. - - The output is a tensor of shape `[..., M, N]`. If `adjoint` is `True` then the - innermost matrices in output satisfy matrix equations `matrix[..., i, k] * - output[..., k, j] = rhs[..., i, j]`. If `adjoint` is `False` then the - innermost matrices in output satisfy matrix equations - `adjoint(matrix[..., i, k]) * output[..., k, j] = rhs[..., i, j]`. - - Example: - - >>> a = tf.constant([[3, 0, 0, 0], - ... [2, 1, 0, 0], - ... [1, 0, 1, 0], - ... 
[1, 1, 1, 1]], dtype=tf.float32) - - >>> b = tf.constant([[4], [2], [4], [2]], dtype=tf.float32) - >>> x = tf.linalg.triangular_solve(a, b, lower=True) - >>> x - - >>> tf.matmul(a, x) - - - Args: - matrix: A `Tensor`. Must be one of the following types: `float64`, - `float32`, `half`, `complex64`, `complex128`. Shape is `[..., M, M]`. - rhs: A `Tensor`. Must have the same type as `matrix`. Shape is `[..., M, - N]`. - lower: An optional `bool`. Defaults to `True`. Boolean indicating whether - the innermost matrices in matrix are lower or upper triangular. - adjoint: An optional `bool`. Defaults to `False`. Boolean indicating whether - to solve with matrix or its (block-wise) adjoint. - name: A name for the operation (optional). - - Returns: - A `Tensor`. Has the same type as matrix, and shape is `[..., M, N]`. - - """ - with ops.name_scope(name, 'triangular_solve', [matrix, rhs]): - return gen_linalg_ops.matrix_triangular_solve( - matrix, rhs, lower=lower, adjoint=adjoint) - - @tf_export( 'linalg.cholesky_solve', v1=['linalg.cholesky_solve', 'cholesky_solve']) @deprecation.deprecated_endpoints('cholesky_solve')