diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py
index 5dc2619ac94..58157168182 100644
--- a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py
+++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py
@@ -50,7 +50,9 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase):
                                 atol):
     feed_dict = {placeholder_a: a, placeholder_ca: clean_a, placeholder_b: b}
     verification_np = sess.run(verification, feed_dict)
-    self.assertAllClose(b, verification_np, atol=atol)
+    broadcasted_shape = a.shape[:-2] + (b.shape[-2], b.shape[-1])
+    broadcasted_b = b + np.zeros(shape=broadcasted_shape, dtype=b.dtype)
+    self.assertAllClose(broadcasted_b, verification_np, atol=atol)

   def _VerifyTriangularSolve(self, a, b, lower, adjoint, atol):
     clean_a = np.tril(a) if lower else np.triu(a)
@@ -111,6 +113,18 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase):
       self._VerifyTriangularSolveCombo(
           a.astype(dtype), b.astype(dtype), atol=1e-3)

+  def testBatchBroadcast(self):
+    rng = np.random.RandomState(0)
+    shapes = [((3, 3), (4, 3, 5)), ((1, 2, 2), (3, 2, 1)), ((1, 1), (1, 1, 2)),
+              ((1, 3, 4, 4), (2, 1, 4, 1))]
+    tuples = itertools.product(self.float_types, shapes)
+    for dtype, (a_shape, b_shape) in tuples:
+      n = a_shape[-1]
+      a = np.tril(rng.rand(*a_shape) - 0.5) / (2.0 * n) + np.eye(n)
+      b = rng.randn(*b_shape)
+      self._VerifyTriangularSolveCombo(
+          a.astype(dtype), b.astype(dtype), atol=1e-3)
+
   def testLarge(self):
     n = 1024
     rng = np.random.RandomState(0)
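The updated assertion first materializes the batch broadcast that the op now performs implicitly. A minimal numpy sketch of the same trick outside the test harness (shapes chosen for illustration only):

```python
import numpy as np

a = np.tril(np.random.rand(1, 3, 4, 4))  # batch shape (1, 3)
b = np.random.rand(2, 1, 4, 1)           # batch shape (2, 1)

# Adding a zeros tensor of the target shape is a cheap way to make
# numpy materialize the implicit broadcast of b.
broadcast_shape = a.shape[:-2] + (b.shape[-2], b.shape[-1])
broadcast_b = b + np.zeros(broadcast_shape, dtype=b.dtype)

# Equivalent to broadcasting the batch dimensions explicitly.
expected_batch = np.broadcast_shapes(a.shape[:-2], b.shape[:-2])
assert broadcast_b.shape == expected_batch + b.shape[-2:]
```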
", + rhs_shape.DebugString())); + return; + } + + xla::XlaOp a = ctx->Input(0); + xla::XlaOp b = ctx->Input(1); + std::tie(a, b) = Broadcast(a, lhs_shape, b, rhs_shape, bcast); auto result = xla::TriangularSolve( - ctx->Input(0), ctx->Input(1), /*left_side=*/true, + a, b, /*left_side=*/true, /*lower=*/lower_, /*unit_diagonal=*/false, /*transpose_a=*/ adjoint_ ? xla::TriangularSolveOptions::ADJOINT @@ -40,10 +63,41 @@ class MatrixTriangularSolveOp : public XlaOpKernel { } private: + static std::pair Broadcast( + xla::XlaOp lhs, const TensorShape& lhs_shape, xla::XlaOp rhs, + const TensorShape& rhs_shape, const MatMulBCast& broadcast_helper); bool lower_; bool adjoint_; }; +/* static */ std::pair +MatrixTriangularSolveOp::Broadcast(xla::XlaOp lhs, const TensorShape& lhs_shape, + xla::XlaOp rhs, const TensorShape& rhs_shape, + const MatMulBCast& broadcast_helper) { + // Get the batch shape. + int64 m = lhs_shape.dim_size(lhs_shape.dims() - 1); + int64 n = rhs_shape.dim_size(rhs_shape.dims() - 1); + + TensorShape lhs_broadcast_shape(broadcast_helper.output_batch_shape()); + lhs_broadcast_shape.AddDim(m); + lhs_broadcast_shape.AddDim(m); + auto lhs_output = BroadcastTo(lhs, lhs_broadcast_shape.dim_sizes()); + if (!lhs_output.ok()) { + xla::XlaOp error = lhs.builder()->ReportError(lhs_output.status()); + return {error, error}; + } + + TensorShape rhs_broadcast_shape(broadcast_helper.output_batch_shape()); + rhs_broadcast_shape.AddDim(m); + rhs_broadcast_shape.AddDim(n); + auto rhs_output = BroadcastTo(rhs, rhs_broadcast_shape.dim_sizes()); + if (!rhs_output.ok()) { + xla::XlaOp error = rhs.builder()->ReportError(rhs_output.status()); + return {error, error}; + } + return {lhs_output.ValueOrDie(), rhs_output.ValueOrDie()}; +} + REGISTER_XLA_OP(Name("MatrixTriangularSolve"), MatrixTriangularSolveOp); } // namespace diff --git a/tensorflow/core/api_def/base_api/api_def_MatrixTriangularSolve.pbtxt b/tensorflow/core/api_def/base_api/api_def_MatrixTriangularSolve.pbtxt index 0ecd7937995..bf31b2d9e4d 100644 --- a/tensorflow/core/api_def/base_api/api_def_MatrixTriangularSolve.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_MatrixTriangularSolve.pbtxt @@ -44,15 +44,17 @@ square matrices. If `lower` is `True` then the strictly upper triangular part of each inner-most matrix is assumed to be zero and not accessed. If `lower` is False then the strictly lower triangular part of each inner-most matrix is assumed to be zero and not accessed. -`rhs` is a tensor of shape `[..., M, K]`. +`rhs` is a tensor of shape `[..., M, N]`. -The output is a tensor of shape `[..., M, K]`. If `adjoint` is +The output is a tensor of shape `[..., M, N]`. If `adjoint` is `True` then the innermost matrices in `output` satisfy matrix equations `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`. If `adjoint` is `False` then the strictly then the innermost matrices in `output` satisfy matrix equations `adjoint(matrix[..., i, k]) * output[..., k, j] = rhs[..., i, j]`. +Note, the batch shapes for the inputs only need to broadcast. 
diff --git a/tensorflow/core/api_def/base_api/api_def_MatrixTriangularSolve.pbtxt b/tensorflow/core/api_def/base_api/api_def_MatrixTriangularSolve.pbtxt
index 0ecd7937995..bf31b2d9e4d 100644
--- a/tensorflow/core/api_def/base_api/api_def_MatrixTriangularSolve.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_MatrixTriangularSolve.pbtxt
@@ -44,15 +44,17 @@ square matrices. If `lower` is `True` then the strictly upper triangular part
 of each inner-most matrix is assumed to be zero and not accessed.
 If `lower` is False then the strictly lower triangular part of each inner-most
 matrix is assumed to be zero and not accessed.
-`rhs` is a tensor of shape `[..., M, K]`.
+`rhs` is a tensor of shape `[..., M, N]`.

-The output is a tensor of shape `[..., M, K]`. If `adjoint` is
+The output is a tensor of shape `[..., M, N]`. If `adjoint` is
 `False` then the innermost matrices in `output` satisfy matrix equations
 `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`.
 If `adjoint` is `True` then the innermost matrices in `output` satisfy
 matrix equations
 `adjoint(matrix[..., i, k]) * output[..., k, j] = rhs[..., i, j]`.
+
+Note: the batch shapes of the two inputs only need to be broadcastable to
+each other; they do not have to match.

 Example:
 ```python
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixTriangularSolve.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixTriangularSolve.pbtxt
index 17dc57335ae..8022c6d0556 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatrixTriangularSolve.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixTriangularSolve.pbtxt
@@ -1,10 +1,4 @@
 op {
   graph_op_name: "MatrixTriangularSolve"
-  endpoint {
-    name: "linalg.triangular_solve"
-  }
-  endpoint {
-    name: "matrix_triangular_solve"
-    deprecation_version: 2
-  }
+  visibility: HIDDEN
 }
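Given the broadcasting semantics now documented above, a pure-numpy reference for the op's contract could look like the sketch below. This mirrors the documented behavior, not the kernel's actual code path:

```python
import numpy as np

def triangular_solve_reference(matrix, rhs, lower=True, adjoint=False):
  """Batch-broadcasting reference: matrix is [..., M, M], rhs is [..., M, N]."""
  tri = np.tril(matrix) if lower else np.triu(matrix)
  if adjoint:
    tri = np.conj(np.swapaxes(tri, -1, -2))
  batch = np.broadcast_shapes(matrix.shape[:-2], rhs.shape[:-2])
  tri = np.broadcast_to(tri, batch + tri.shape[-2:])
  rhs = np.broadcast_to(rhs, batch + rhs.shape[-2:])
  # np.linalg.solve already vectorizes over the leading batch dimensions.
  return np.linalg.solve(tri, rhs)

a = np.tril(np.random.rand(4, 4)) + 4 * np.eye(4)  # batch shape ()
b = np.random.rand(3, 4, 2)                        # batch shape (3,)
assert triangular_solve_reference(a, b).shape == (3, 4, 2)
```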
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 0109597e9cc..d579878b33d 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3503,6 +3503,22 @@ tf_kernel_library(
     ],
 )

+tf_kernel_library(
+    name = "rocm_solvers",
+    srcs = ["rocm_solvers.cc"],
+    hdrs = ["rocm_solvers.h"],
+    visibility = [":friends"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/stream_executor/lib",
+        "//tensorflow/stream_executor/platform:dso_loader",
+        "//tensorflow/stream_executor/rocm:rocblas_plugin",
+        "//tensorflow/stream_executor/rocm:rocm_gpu_executor",
+        "@local_config_rocm//rocm:rocprim",
+    ],
+)
+
 tf_kernel_library(
     name = "cuda_sparse",
     srcs = if_cuda(["cuda_sparse.cc"]) + if_rocm(["rocm_sparse.cc"]),
@@ -3527,6 +3543,8 @@ LINALG_DEPS = [
 ] + if_cuda([
     ":cuda_solvers",
     ":transpose_functor",
+]) + if_rocm([
+    ":rocm_solvers",
 ])

 tf_kernel_library(
@@ -3613,9 +3631,23 @@ tf_kernel_library(

 tf_kernel_library(
     name = "matrix_triangular_solve_op",
+    hdrs = ["matrix_triangular_solve_op_impl.h"],
     prefix = "matrix_triangular_solve_op",
-    deps = LINALG_DEPS + if_cuda([
+    deps = [
+        ":linalg_ops_common",
+        "//third_party/eigen3",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        ":fill_functor",
+        "//tensorflow/core:stream_executor",
+    ] + if_cuda([
         "//tensorflow/core/platform/default/build_config:cublas_plugin",
+        ":cuda_solvers",
+    ]) + if_rocm([
+        "@local_config_rocm//rocm:rocprim",
+        ":rocm_solvers",
+    ]) + if_cuda_or_rocm([
+        ":transpose_functor",
     ]),
 )

@@ -4204,6 +4236,25 @@ tf_cuda_cc_test(
     ],
 )

+tf_cuda_cc_test(
+    name = "matrix_triangular_solve_op_test",
+    size = "small",
+    srcs = ["matrix_triangular_solve_op_test.cc"],
+    deps = [
+        ":broadcast_to_op",
+        ":matrix_triangular_solve_op",
+        ":ops_testutil",
+        ":ops_util",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
 tf_cuda_cc_test(
     name = "scan_ops_test",
     size = "small",
diff --git a/tensorflow/core/kernels/cuda_solvers.cc b/tensorflow/core/kernels/cuda_solvers.cc
index 1c569204265..dcf40ef6798 100644
--- a/tensorflow/core/kernels/cuda_solvers.cc
+++ b/tensorflow/core/kernels/cuda_solvers.cc
@@ -900,6 +900,106 @@ static inline Status MatInvBatchedImpl(

 TF_CALL_LAPACK_TYPES(MATINV_BATCHED_INSTANCE);

+template <typename Scalar, typename SolverFnT>
+static inline Status TrsmImpl(SolverFnT solver, cublasHandle_t cublas_handle,
+                              cublasSideMode_t side, cublasFillMode_t uplo,
+                              cublasOperation_t trans, cublasDiagType_t diag,
+                              int m, int n,
+                              const Scalar* alpha, /* host or device pointer */
+                              const Scalar* A, int lda, Scalar* B, int ldb) {
+  mutex_lock lock(handle_map_mutex);
+  using CudaScalar = typename CUDAComplexT<Scalar>::type;
+  TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, side, uplo, trans, diag, m,
+                                   n, reinterpret_cast<const CudaScalar*>(alpha),
+                                   reinterpret_cast<const CudaScalar*>(A), lda,
+                                   reinterpret_cast<CudaScalar*>(B), ldb));
+  return Status::OK();
+}
+
+#define TRSM_INSTANCE(Scalar, type_prefix)                                    \
+  template <>                                                                 \
+  Status CudaSolver::Trsm(                                                    \
+      cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans,  \
+      cublasDiagType_t diag, int m, int n,                                    \
+      const Scalar* alpha, /* host or device pointer */                       \
+      const Scalar* A, int lda, Scalar* B, int ldb) {                         \
+    return TrsmImpl(BLAS_SOLVER_FN(trsm, type_prefix), cublas_handle_, side,  \
+                    uplo, trans, diag, m, n, alpha, A, lda, B, ldb);          \
+  }
+
+TF_CALL_LAPACK_TYPES(TRSM_INSTANCE);
+
+template <typename Scalar, typename SolverFnT>
+static inline Status TrsvImpl(SolverFnT solver, cublasHandle_t cublas_handle,
+                              cublasFillMode_t uplo, cublasOperation_t trans,
+                              cublasDiagType_t diag, int n, const Scalar* A,
+                              int lda, Scalar* x, int incx) {
+  mutex_lock lock(handle_map_mutex);
+  using CudaScalar = typename CUDAComplexT<Scalar>::type;
+  TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, uplo, trans, diag, n,
+                                   reinterpret_cast<const CudaScalar*>(A), lda,
+                                   reinterpret_cast<CudaScalar*>(x), incx));
+  return Status::OK();
+}
+
+#define TRSV_INSTANCE(Scalar, type_prefix)                                    \
+  template <>                                                                 \
+  Status CudaSolver::Trsv(                                                    \
+      cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag,  \
+      int n, const Scalar* A, int lda, Scalar* x, int incx) {                 \
+    return TrsvImpl(BLAS_SOLVER_FN(trsv, type_prefix), cublas_handle_, uplo,  \
+                    trans, diag, n, A, lda, x, incx);                         \
+  }
+
+TF_CALL_LAPACK_TYPES(TRSV_INSTANCE);
+
+template <typename Scalar, typename SolverFnT>
+static inline Status TrsmBatchedImpl(
+    SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
+    cublasHandle_t cublas_handle, cublasSideMode_t side, cublasFillMode_t uplo,
+    cublasOperation_t trans, cublasDiagType_t diag, int m, int n,
+    const Scalar* alpha, const Scalar* const host_a_dev_ptrs[], int lda,
+    Scalar* host_b_dev_ptrs[], int ldb, int batch_size) {
+  mutex_lock lock(handle_map_mutex);
+  using CudaScalar = typename CUDAComplexT<Scalar>::type;
+  ScratchSpace<uint8> dev_a_dev_ptrs =
+      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
+                                          /* on_host */ false);
+  ScratchSpace<uint8> dev_b_dev_ptrs =
+      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
+                                          /* on_host */ false);
+  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
+                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
+    return errors::Internal("TrsmBatched: failed to copy pointers to device");
+  }
+  if (!CopyHostToDevice(context, dev_b_dev_ptrs.mutable_data() /* dest */,
+                        host_b_dev_ptrs /* source */, dev_b_dev_ptrs.bytes())) {
+    return errors::Internal("TrsmBatched: failed to copy pointers to device");
+  }
+  TF_RETURN_IF_CUBLAS_ERROR(
+      solver(cublas_handle, side, uplo, trans, diag, m, n,
+             reinterpret_cast<const CudaScalar*>(alpha),
+             reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()),
+             lda, reinterpret_cast<CudaScalar**>(dev_b_dev_ptrs.mutable_data()),
+             ldb, batch_size));
+  return Status::OK();
+}
+
+#define TRSM_BATCHED_INSTANCE(Scalar, type_prefix)                            \
+  template <>                                                                 \
+  Status CudaSolver::TrsmBatched(                                             \
+      cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans,  \
+      cublasDiagType_t diag, int m, int n, const Scalar* alpha,               \
+      const Scalar* const dev_Aarray[], int lda, Scalar* dev_Barray[],        \
+      int ldb, int batch_size) {                                              \
+    return TrsmBatchedImpl(BLAS_SOLVER_FN(trsmBatched, type_prefix), this,    \
+                           context_, cublas_handle_, side, uplo, trans, diag, \
+                           m, n, alpha, dev_Aarray, lda, dev_Barray, ldb,     \
+                           batch_size);                                       \
+  }
+
+TF_CALL_LAPACK_TYPES(TRSM_BATCHED_INSTANCE);
+
 }  // namespace tensorflow

 #endif  // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cuda_solvers.h b/tensorflow/core/kernels/cuda_solvers.h
index 104ee09a2bc..fa0984d05c7 100644
--- a/tensorflow/core/kernels/cuda_solvers.h
+++ b/tensorflow/core/kernels/cuda_solvers.h
@@ -334,6 +334,29 @@ class CudaSolver {
                      Scalar* dev_V, int ldv, int* dev_lapack_info,
                      int batch_size);

+  // Triangular solve.
+  // Returns Status::OK() if the kernel was launched successfully.
+  // See https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-trsm
+  template <typename Scalar>
+  Status Trsm(cublasSideMode_t side, cublasFillMode_t uplo,
+              cublasOperation_t trans, cublasDiagType_t diag, int m, int n,
+              const Scalar* alpha, const Scalar* A, int lda, Scalar* B,
+              int ldb);
+
+  template <typename Scalar>
+  Status Trsv(cublasFillMode_t uplo, cublasOperation_t trans,
+              cublasDiagType_t diag, int n, const Scalar* A, int lda,
+              Scalar* x, int incx);
+
+  // See
+  // https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-trsmbatched
+  template <typename Scalar>
+  Status TrsmBatched(cublasSideMode_t side, cublasFillMode_t uplo,
+                     cublasOperation_t trans, cublasDiagType_t diag, int m,
+                     int n, const Scalar* alpha,
+                     const Scalar* const dev_Aarray[], int lda,
+                     Scalar* dev_Barray[], int ldb, int batch_size);
+
  private:
   OpKernelContext* context_;  // not owned.
   cudaStream_t cuda_stream_;
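Both the cuBLAS wrappers here and the rocBLAS path later in this change lean on the layout trick spelled out in the kernel comments: a column-major BLAS sees row-major buffers as their transposes, so a left-side solve becomes a right-side solve with `uplo` flipped. A small numpy check of the identity behind it (illustrative only; `np.linalg.inv` stands in for the in-place `trsm`):

```python
import numpy as np

rng = np.random.RandomState(0)
m, n = 4, 3
a = np.tril(rng.rand(m, m)) + m * np.eye(m)  # lower triangular, well conditioned
b = rng.rand(m, n)

# What we want, in row-major terms: solve a @ x = b from the left.
x = np.linalg.solve(a, b)

# A column-major BLAS sees the same buffers as a.T and b.T. Transposing
# the equation gives x.T @ a.T = b.T, i.e. a *right-side* solve against
# a.T -- and the transpose of a lower matrix is upper, hence the uplo swap.
x_t = b.T @ np.linalg.inv(a).T  # stands in for trsm(side=right, uplo=upper)
assert np.allclose(x_t.T, x)
```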
diff --git a/tensorflow/core/kernels/matrix_triangular_solve_op.cc b/tensorflow/core/kernels/matrix_triangular_solve_op.cc
deleted file mode 100644
index 61bc4aad214..00000000000
--- a/tensorflow/core/kernels/matrix_triangular_solve_op.cc
+++ /dev/null
@@ -1,258 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// See docs in ../ops/linalg_ops.cc.
-
-#include "third_party/eigen3/Eigen/Core"
-#include "tensorflow/core/framework/kernel_def_builder.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/kernels/linalg_ops_common.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/types.h"
-
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-#include "tensorflow/core/platform/stream_executor.h"
-#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-
-namespace tensorflow {
-
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-namespace {
-template <typename Scalar>
-se::DeviceMemory<Scalar> AsDeviceMemory(const Scalar* gpu_memory) {
-  se::DeviceMemoryBase wrapped(const_cast<Scalar*>(gpu_memory));
-  se::DeviceMemory<Scalar> typed(wrapped);
-  return typed;
-}
-}  // namespace
-#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-
-template <class Scalar>
-class MatrixTriangularSolveOp : public LinearAlgebraOp<Scalar> {
- public:
-  INHERIT_LINALG_TYPEDEFS(Scalar);
-
-  explicit MatrixTriangularSolveOp(OpKernelConstruction* context)
-      : Base(context), lower_(true), adjoint_(false) {
-    OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
-    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
-  }
-
-  void ValidateInputMatrixShapes(
-      OpKernelContext* context,
-      const TensorShapes& input_matrix_shapes) const final {
-    Base::ValidateSquareSolver(context, input_matrix_shapes);
-  }
-
-  TensorShapes GetOutputMatrixShapes(
-      const TensorShapes& input_matrix_shapes) const final {
-    return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1),
-                                      input_matrix_shapes[1].dim_size(1)})});
-  }
-
-  int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
-    double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0));
-    double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
-    double cost = rows * rows * num_rhss *
-                  (Eigen::TensorOpCost::AddCost<Scalar>() +
-                   Eigen::TensorOpCost::MulCost<Scalar>());
-    return cost >= static_cast<double>(kint64max) ? kint64max
-                                                  : static_cast<int64>(cost);
-  }
-
-  bool EnableInputForwarding() const final { return false; }
-
-  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
-                     MatrixMaps* outputs) final {
-    const ConstMatrixMap& matrix = inputs[0];
-    const ConstMatrixMap& rhs = inputs[1];
-    MatrixMap& output = outputs->at(0);
-
-    if (matrix.rows() == 0 || rhs.rows() == 0 || rhs.cols() == 0) {
-      // To be consistent with the MatrixInverse op, we define the solution for
-      // an empty set of equation as the empty matrix.
-      return;
-    }
-    const RealScalar min_abs_pivot = matrix.diagonal().cwiseAbs().minCoeff();
-    OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
-                errors::InvalidArgument("Input matrix is not invertible."));
-    if (lower_) {
-      auto triangle = matrix.template triangularView<Eigen::Lower>();
-      if (adjoint_) {
-        output.noalias() = triangle.adjoint().solve(rhs);
-      } else {
-        output.noalias() = triangle.solve(rhs);
-      }
-    } else {
-      auto triangle = matrix.template triangularView<Eigen::Upper>();
-      if (adjoint_) {
-        output.noalias() = triangle.adjoint().solve(rhs);
-      } else {
-        output.noalias() = triangle.solve(rhs);
-      }
-    }
-  }
-
- private:
-  bool lower_;
-  bool adjoint_;
-
-  TF_DISALLOW_COPY_AND_ASSIGN(MatrixTriangularSolveOp);
-};
-
-REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
-                       (MatrixTriangularSolveOp<float>), float);
-REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
-                       (MatrixTriangularSolveOp<double>), double);
-REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
-                       (MatrixTriangularSolveOp<complex64>), complex64);
-REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
-                       (MatrixTriangularSolveOp<complex128>), complex128);
-REGISTER_LINALG_OP_CPU("BatchMatrixTriangularSolve",
-                       (MatrixTriangularSolveOp<float>), float);
-REGISTER_LINALG_OP_CPU("BatchMatrixTriangularSolve",
-                       (MatrixTriangularSolveOp<double>), double);
-
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-
-// TODO(rmlarsen): Re-factor to
-// 1. Enable buffer forwarding from rhs->out.
-// 2. Save Memcpy when buffer forwarding is used.
-// 3. Copy entire rhs in a single Memcpy when forwarding is not used.
-template <class Scalar>
-class MatrixTriangularSolveOpGPU : public LinearAlgebraOp<Scalar> {
- public:
-  INHERIT_LINALG_TYPEDEFS(Scalar);
-
-  explicit MatrixTriangularSolveOpGPU(OpKernelConstruction* context)
-      : Base(context), lower_(true), adjoint_(false) {
-    OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
-    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
-  }
-
-  void ValidateInputMatrixShapes(
-      OpKernelContext* context,
-      const TensorShapes& input_matrix_shapes) const final {
-    Base::ValidateSquareSolver(context, input_matrix_shapes);
-  }
-
-  TensorShapes GetOutputMatrixShapes(
-      const TensorShapes& input_matrix_shapes) const final {
-    return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1),
-                                      input_matrix_shapes[1].dim_size(1)})});
-  }
-
-  int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
-    double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0));
-    double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
-    double cost = rows * rows * num_rhss *
-                  (Eigen::TensorOpCost::AddCost<Scalar>() +
-                   Eigen::TensorOpCost::MulCost<Scalar>());
-    return cost >= static_cast<double>(kint64max) ? kint64max
-                                                  : static_cast<int64>(cost);
-  }
-
-  bool EnableInputForwarding() const final { return false; }
-
-  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
-                     MatrixMaps* outputs) final {
-    const ConstMatrixMap& matrix = inputs[0];
-    const ConstMatrixMap& rhs = inputs[1];
-    MatrixMap& output = outputs->at(0);
-
-    if (matrix.rows() == 0 || rhs.rows() == 0 || rhs.cols() == 0) {
-      // To be consistent with the MatrixInverse op, we define the solution for
-      // an empty set of equation as the empty matrix.
-      return;
-    }
-
-    auto matrix_ptr = AsDeviceMemory(matrix.data());
-    auto rhs_ptr = AsDeviceMemory(rhs.data());
-    auto out_ptr = AsDeviceMemory(output.data());
-
-    auto* stream = context->op_device_context()->stream();
-    uint64 rhs_elems = rhs.rows() * rhs.cols();
-    bool copy_status =
-        stream->ThenMemcpyD2D(&out_ptr, rhs_ptr, sizeof(Scalar) * rhs_elems)
-            .ok();
-    if (!copy_status) {
-      context->SetStatus(
-          errors::Internal("Failed to copy rhs into output before solve"));
-    }
-
-    // Cublas does
-    // output = matrix \ rhs
-    // where matrix, rhs and output are assumed to be in column major.
-    // We want the output to be in row-major, so we can compute
-    //   output' = rhs' / matrix' (' stands for transpose)
-    // Upper/lower needs to be swapped for this.
-
-    se::blas::UpperLower upper_lower_matrix;
-    se::blas::Transpose transpose_matrix;
-    if (lower_) {
-      upper_lower_matrix = se::blas::UpperLower::kUpper;
-    } else {
-      upper_lower_matrix = se::blas::UpperLower::kLower;
-    }
-    if (adjoint_) {
-      transpose_matrix = se::blas::Transpose::kConjugateTranspose;
-    } else {
-      transpose_matrix = se::blas::Transpose::kNoTranspose;
-    }
-    uint64 leading_dim_matrix = matrix.cols();
-    uint64 leading_dim_output = output.cols();
-    uint64 colmajor_rows = output.cols();
-    uint64 colmajor_cols = output.rows();
-    bool blas_launch_status =
-        stream
-            ->ThenBlasTrsm(
-                se::blas::Side::kRight /*side*/, upper_lower_matrix /*uplo*/,
-                transpose_matrix /*trans*/,
-                se::blas::Diagonal::kNonUnit /*diag*/, colmajor_rows /*m*/,
-                colmajor_cols /*n*/, Scalar(1.0) /*alpha*/, matrix_ptr,
-                leading_dim_matrix /*lda*/, &out_ptr,
-                leading_dim_output /*ldb*/)
-            .ok();
-    if (!blas_launch_status) {
-      context->SetStatus(errors::Internal("Blas TRSM launch failed"));
-    }
-  }
-
- private:
-  bool lower_;
-  bool adjoint_;
-
-  TF_DISALLOW_COPY_AND_ASSIGN(MatrixTriangularSolveOpGPU);
-};
-
-REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
-                       (MatrixTriangularSolveOpGPU<float>), float);
-REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
-                       (MatrixTriangularSolveOpGPU<double>), double);
-REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
-                       (MatrixTriangularSolveOpGPU<complex64>), complex64);
-REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
-                       (MatrixTriangularSolveOpGPU<complex128>), complex128);
-REGISTER_LINALG_OP_GPU("BatchMatrixTriangularSolve",
-                       (MatrixTriangularSolveOpGPU<float>), float);
-REGISTER_LINALG_OP_GPU("BatchMatrixTriangularSolve",
-                       (MatrixTriangularSolveOpGPU<double>), double);
-
-#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-
-}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/matrix_triangular_solve_op_complex.cc b/tensorflow/core/kernels/matrix_triangular_solve_op_complex.cc
new file mode 100644
index 00000000000..47f958ff6a9
--- /dev/null
+++ b/tensorflow/core/kernels/matrix_triangular_solve_op_complex.cc
@@ -0,0 +1,28 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/matrix_triangular_solve_op_impl.h"
+
+namespace tensorflow {
+
+TF_CALL_complex64(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU);
+TF_CALL_complex128(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU);
+
+#if GOOGLE_CUDA
+TF_CALL_complex64(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU);
+TF_CALL_complex128(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU);
+#endif  // GOOGLE_CUDA
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/matrix_triangular_solve_op_impl.h b/tensorflow/core/kernels/matrix_triangular_solve_op_impl.h
new file mode 100644
index 00000000000..48f2eec11a6
--- /dev/null
+++ b/tensorflow/core/kernels/matrix_triangular_solve_op_impl.h
@@ -0,0 +1,437 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/linalg_ops.cc.
+//
+#ifndef TENSORFLOW_CORE_KERNELS_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_
+
+#include "third_party/eigen3/Eigen/Core"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/fill_functor.h"
+#include "tensorflow/core/kernels/linalg_ops_common.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/matmul_bcast.h"
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/transpose_functor.h"
+#include "tensorflow/core/platform/stream_executor.h"
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+#if GOOGLE_CUDA
+#include "tensorflow/core/kernels/cuda_solvers.h"
+#elif TENSORFLOW_USE_ROCM
+#include "tensorflow/core/kernels/rocm_solvers.h"
+#endif
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+template <typename Scalar>
+se::DeviceMemory<Scalar> AsDeviceMemory(const Scalar* gpu_memory) {
+  se::DeviceMemoryBase wrapped(const_cast<Scalar*>(gpu_memory));
+  se::DeviceMemory<Scalar> typed(wrapped);
+  return typed;
+}
+
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+// Sequential batch matrix triangular solve kernel that calls Eigen's
+// matrix triangular solve.
+template <typename Scalar>
+struct SequentialMatrixTriangularSolveKernel {
+  using Matrix =
+      Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
+  using ConstMatrixMap = Eigen::Map<const Matrix>;
+  using MatrixMap = Eigen::Map<Matrix>;
+  using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
+
+  static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t,
+                                                      int slice) {
+    return ConstMatrixMap(
+        t.flat<Scalar>().data() + slice * t.dim_size(1) * t.dim_size(2),
+        t.dim_size(1), t.dim_size(2));
+  }
+
+  static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) {
+    return MatrixMap(
+        t->flat<Scalar>().data() + slice * t->dim_size(1) * t->dim_size(2),
+        t->dim_size(1), t->dim_size(2));
+  }
+
+  static void Run(const Tensor& in_x, const Tensor& in_y, bool lower,
+                  bool adjoint, const MatMulBCast& bcast, Tensor* out,
+                  int start, int limit) {
+    const bool should_bcast = bcast.IsBroadcastingRequired();
+    const auto& x_batch_indices = bcast.x_batch_indices();
+    const auto& y_batch_indices = bcast.y_batch_indices();
+    for (int64 i = start; i < limit; ++i) {
+      const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
+      const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
+      auto matrix = ConstTensorSliceToEigenMatrix(in_x, x_batch_index);
+      auto rhs = ConstTensorSliceToEigenMatrix(in_y, y_batch_index);
+      auto output = TensorSliceToEigenMatrix(out, i);
+      if (lower) {
+        auto triangle = matrix.template triangularView<Eigen::Lower>();
+        if (adjoint) {
+          output.noalias() = triangle.adjoint().solve(rhs);
+        } else {
+          output.noalias() = triangle.solve(rhs);
+        }
+      } else {
+        auto triangle = matrix.template triangularView<Eigen::Upper>();
+        if (adjoint) {
+          output.noalias() = triangle.adjoint().solve(rhs);
+        } else {
+          output.noalias() = triangle.solve(rhs);
+        }
+      }
+    }
+  }
+};
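+
+// A concrete reading of the index mapping above: with an lhs batch shape
+// of [1, 3] and an rhs batch shape of [2, 1], x_batch_size() is 3,
+// y_batch_size() is 2, and output_batch_size() is 6. For each output
+// slice i, x_batch_indices()[i] and y_batch_indices()[i] select which
+// input slices to read, so on this CPU path neither operand is ever
+// materialized at the full broadcast shape.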
+
+template <typename Device, typename Scalar>
+struct LaunchBatchMatrixTriangularSolve;
+
+template <typename Scalar>
+struct LaunchBatchMatrixTriangularSolve<CPUDevice, Scalar> {
+  static void Launch(OpKernelContext* context, const Tensor& in_x,
+                     const Tensor& in_y, bool adjoint, bool lower,
+                     const MatMulBCast& bcast, Tensor* out) {
+    // Number of matrix triangular solves, i.e. the size of the batch.
+    const int64 batch_size = bcast.output_batch_size();
+    const int64 cost_per_unit =
+        in_x.dim_size(1) * in_x.dim_size(1) * in_y.dim_size(2) / 2;
+    auto worker_threads =
+        *(context->device()->tensorflow_cpu_worker_threads());
+
+    using Matrix =
+        Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
+    using ConstMatrixMap = Eigen::Map<const Matrix>;
+    using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
+    // Check diagonal before doing any solves.
+    auto matrix = ConstMatrixMap(in_x.flat<Scalar>().data(), in_x.dim_size(1),
+                                 in_x.dim_size(2));
+    const RealScalar min_abs_pivot = matrix.diagonal().cwiseAbs().minCoeff();
+    OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
+                errors::InvalidArgument("Input matrix is not invertible."));
+
+    Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
+          cost_per_unit,
+          [&in_x, &in_y, adjoint, lower, &bcast, out](int start, int limit) {
+            SequentialMatrixTriangularSolveKernel<Scalar>::Run(
+                in_x, in_y, lower, adjoint, bcast, out, start, limit);
+          });
+  }
+};
+
+template <typename Device, typename Scalar>
+class BaseMatrixTriangularSolveOp : public OpKernel {
+ public:
+  explicit BaseMatrixTriangularSolveOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
+    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
+  }
+
+  ~BaseMatrixTriangularSolveOp() override {}
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor& in0 = ctx->input(0);
+    const Tensor& in1 = ctx->input(1);
+
+    ValidateInputTensors(ctx, in0, in1);
+
+    MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
+    OP_REQUIRES(
+        ctx, bcast.IsValid(),
+        errors::InvalidArgument(
+            "In[0] and In[1] must have compatible batch dimensions: ",
+            in0.shape().DebugString(), " vs. ", in1.shape().DebugString()));
+
+    TensorShape out_shape = bcast.output_batch_shape();
+    auto batch_size = bcast.output_batch_size();
+    auto d0 = in0.dim_size(in0.dims() - 2);
+    auto d1 = in0.dim_size(in0.dims() - 1);
+    Tensor in0_reshaped;
+    OP_REQUIRES(
+        ctx,
+        in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})),
+        errors::Internal("Failed to reshape In[0] from ",
+                         in0.shape().DebugString()));
+    auto d2 = in1.dim_size(in1.dims() - 2);
+    auto d3 = in1.dim_size(in1.dims() - 1);
+    Tensor in1_reshaped;
+    OP_REQUIRES(
+        ctx,
+        in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
+        errors::Internal("Failed to reshape In[1] from ",
+                         in1.shape().DebugString()));
+    if (adjoint_) std::swap(d0, d1);
+    OP_REQUIRES(ctx, d1 == d2,
+                errors::InvalidArgument(
+                    "In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ",
+                    in0.shape().DebugString(), " ", in1.shape().DebugString(),
+                    " ", lower_, " ", adjoint_));
+    out_shape.AddDim(d0);
+    out_shape.AddDim(d3);
+    Tensor* out = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
+    if (out->NumElements() == 0) {
+      return;
+    }
+    Tensor out_reshaped;
+    OP_REQUIRES(ctx,
+                out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})),
+                errors::Internal("Failed to reshape output from ",
+                                 out->shape().DebugString()));
+    LaunchBatchMatrixTriangularSolve<Device, Scalar>::Launch(
+        ctx, in0_reshaped, in1_reshaped, adjoint_, lower_, bcast,
+        &out_reshaped);
+  }
+
+ private:
+  virtual void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
+                                    const Tensor& in1) = 0;
+  bool lower_;
+  bool adjoint_;
+};
+
+template <typename Device, typename Scalar>
+class MatrixTriangularSolveOp
+    : public BaseMatrixTriangularSolveOp<Device, Scalar> {
+ public:
+  explicit MatrixTriangularSolveOp(OpKernelConstruction* context)
+      : BaseMatrixTriangularSolveOp<Device, Scalar>(context) {}
+
+  ~MatrixTriangularSolveOp() override {}
+
+ private:
+  void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
+                            const Tensor& in1) override {
+    OP_REQUIRES(
+        ctx, in0.dims() >= 2,
+        errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims()));
+
+    OP_REQUIRES(
+        ctx, in1.dims() >= 2,
+        errors::InvalidArgument("In[1] ndims must be >= 2: ", in1.dims()));
+  }
+};
+
+#define REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU(TYPE)             \
+  REGISTER_KERNEL_BUILDER(Name("MatrixTriangularSolve")              \
+                              .Device(DEVICE_CPU)                    \
+                              .TypeConstraint<TYPE>("T"),            \
+                          MatrixTriangularSolveOp<CPUDevice, TYPE>); \
+  REGISTER_KERNEL_BUILDER(Name("BatchMatrixTriangularSolve")         \
+                              .Device(DEVICE_CPU)                    \
+                              .TypeConstraint<TYPE>("T"),            \
+                          MatrixTriangularSolveOp<CPUDevice, TYPE>);
", d2, ": ", + in0.shape().DebugString(), " ", in1.shape().DebugString(), + " ", lower_, " ", adjoint_)); + out_shape.AddDim(d0); + out_shape.AddDim(d3); + Tensor* out = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out)); + if (out->NumElements() == 0) { + return; + } + Tensor out_reshaped; + OP_REQUIRES(ctx, + out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})), + errors::Internal("Failed to reshape output from ", + out->shape().DebugString())); + LaunchBatchMatrixTriangularSolve::Launch( + ctx, in0_reshaped, in1_reshaped, adjoint_, lower_, bcast, + &out_reshaped); + } + + private: + virtual void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0, + const Tensor& in1) = 0; + bool lower_; + bool adjoint_; +}; + +template +class MatrixTriangularSolveOp + : public BaseMatrixTriangularSolveOp { + public: + explicit MatrixTriangularSolveOp(OpKernelConstruction* context) + : BaseMatrixTriangularSolveOp(context) {} + + ~MatrixTriangularSolveOp() override {} + + private: + void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0, + const Tensor& in1) override { + OP_REQUIRES( + ctx, in0.dims() >= 2, + errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims())); + + OP_REQUIRES( + ctx, in1.dims() >= 2, + errors::InvalidArgument("In[0] ndims must be >= 2: ", in1.dims())); + } +}; + +#define REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU(TYPE) \ + REGISTER_KERNEL_BUILDER(Name("MatrixTriangularSolve") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + MatrixTriangularSolveOp); \ + REGISTER_KERNEL_BUILDER(Name("BatchMatrixTriangularSolve") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + MatrixTriangularSolveOp); + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +template +struct LaunchBatchMatrixTriangularSolve { + static void Launch(OpKernelContext* context, const Tensor& in_x, + const Tensor& in_y, bool adjoint, bool lower, + const MatMulBCast& bcast, Tensor* out) { + auto* stream = context->op_device_context()->stream(); + + const uint64 m = in_x.dim_size(1); + const uint64 n = out->dim_size(2); + + // Do a memcpy when we don't need to broadcast. 
+
+#if GOOGLE_CUDA
+
+    cublasSideMode_t side = CUBLAS_SIDE_RIGHT;
+    cublasFillMode_t uplo;
+    cublasOperation_t trans;
+    cublasDiagType_t diag = CUBLAS_DIAG_NON_UNIT;
+
+    // Cublas does
+    // output = matrix \ rhs
+    // where matrix, rhs and output are assumed to be in column major.
+    // We want the output to be in row-major, so we can compute
+    //   output' = rhs' / matrix' (' stands for transpose)
+    // Upper/lower needs to be swapped for this.
+
+    uplo = lower ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
+    trans = adjoint ? CUBLAS_OP_C : CUBLAS_OP_N;
+    auto solver = absl::make_unique<CudaSolver>(context);
+
+#elif TENSORFLOW_USE_ROCM
+    rocblas_side side = rocblas_side_right;
+    rocblas_fill uplo;
+    rocblas_operation trans;
+    rocblas_diagonal diag = rocblas_diagonal_non_unit;
+
+    // rocblas does
+    // output = matrix \ rhs
+    // where matrix, rhs and output are assumed to be in column major.
+    // We want the output to be in row-major, so we can compute
+    //   output' = rhs' / matrix' (' stands for transpose)
+    // Upper/lower needs to be swapped for this.
+
+    uplo = lower ? rocblas_fill_upper : rocblas_fill_lower;
+    trans = adjoint ? rocblas_operation_conjugate_transpose
+                    : rocblas_operation_none;
+    auto solver = absl::make_unique<ROCmSolver>(context);
+
+#endif
+
+    const uint64 leading_dim_matrix = m;
+    const uint64 leading_dim_output = n;
+    const uint64 colmajor_rows = n;
+    const uint64 colmajor_cols = m;
+
+    const int64 batch_size = bcast.output_batch_size();
+    std::vector<const Scalar*> a_ptrs;
+    std::vector<Scalar*> out_ptrs;
+    std::vector<const Scalar*> a_tmp_ptrs;
+    a_ptrs.reserve(batch_size);
+    out_ptrs.reserve(batch_size);
+    a_tmp_ptrs.reserve(bcast.x_batch_size());
+    auto* a_base_ptr = in_x.template flat<Scalar>().data();
+    auto* out_base_ptr = out->template flat<Scalar>().data();
+
+    if (!bcast.IsBroadcastingRequired()) {
+      for (int64 i = 0; i < batch_size; ++i) {
+        a_ptrs.push_back(a_base_ptr + i * m * m);
+        out_ptrs.push_back(out_base_ptr + i * m * n);
+      }
+    } else {
+      const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
+      for (int64 i = 0; i < bcast.x_batch_size(); ++i) {
+        a_tmp_ptrs.push_back(a_base_ptr + i * m * m);
+      }
+      for (int64 i = 0; i < batch_size; ++i) {
+        a_ptrs.push_back(a_tmp_ptrs[a_batch_indices[i]]);
+        out_ptrs.push_back(out_base_ptr + i * m * n);
+      }
+    }
+
+    typedef Scalar Coefficient;
+    const Scalar alpha = Scalar(1.0);
+
+#if GOOGLE_CUDA
+
+    // TODO(b/146763573): Consider using Trsv here when the right hand side is
+    // a vector. This will require an explicit transpose since Trsv assumes
+    // CUBLAS_SIDE_LEFT.
+    if (batch_size == 1) {
+      OP_REQUIRES_OK(
+          context,
+          solver->Trsm(side, uplo, trans, diag, colmajor_rows, colmajor_cols,
+                       &alpha, a_ptrs[0], leading_dim_matrix /*lda*/,
+                       out_ptrs[0], leading_dim_output /*ldb*/));
+    } else {
+      // Heuristic for choosing between the batched interface and the
+      // non-batched interface. This is inspired by matrix_solve_op and can
+      // probably be tuned.
+      // TODO(b/146763573): Tune this heuristic.
+      const int kMaxMatrixSizeToBatchSizeRatio = 128;
+      const bool use_batched_solver =
+          m <= kMaxMatrixSizeToBatchSizeRatio * batch_size;
+      if (use_batched_solver) {
+        OP_REQUIRES_OK(
+            context, solver->TrsmBatched(
+                         side, uplo, trans, diag, colmajor_rows, colmajor_cols,
+                         &alpha, &a_ptrs[0], leading_dim_matrix /*lda*/,
+                         &out_ptrs[0], leading_dim_output /*ldb*/, batch_size));
+      } else {
+        for (int batch = 0; batch < batch_size; ++batch) {
+          OP_REQUIRES_OK(
+              context, solver->Trsm(side, uplo, trans, diag, colmajor_rows,
+                                    colmajor_cols, &alpha, a_ptrs[batch],
+                                    leading_dim_matrix /*lda*/, out_ptrs[batch],
+                                    leading_dim_output /*ldb*/));
+        }
+      }
+    }
+#elif TENSORFLOW_USE_ROCM
+    for (int batch = 0; batch < batch_size; ++batch) {
+      OP_REQUIRES_OK(
+          context,
+          solver->Trsm(side, uplo, trans, diag, colmajor_rows, colmajor_cols,
+                       &alpha, a_ptrs[batch], leading_dim_matrix /*lda*/,
+                       out_ptrs[batch], leading_dim_output /*ldb*/));
+    }
+#endif
+  }
+};
+
+#define REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU(TYPE)             \
+  REGISTER_KERNEL_BUILDER(Name("MatrixTriangularSolve")              \
+                              .Device(DEVICE_GPU)                    \
+                              .TypeConstraint<TYPE>("T"),            \
+                          MatrixTriangularSolveOp<GPUDevice, TYPE>); \
+  REGISTER_KERNEL_BUILDER(Name("BatchMatrixTriangularSolve")         \
+                              .Device(DEVICE_GPU)                    \
+                              .TypeConstraint<TYPE>("T"),            \
+                          MatrixTriangularSolveOp<GPUDevice, TYPE>);
+
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_
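The batched-versus-looped choice above is, as the comment says, a heuristic borrowed from matrix_solve_op. A sketch of the decision rule in Python, using the constant from the code (the function name is illustrative):

```python
KMAX_MATRIX_SIZE_TO_BATCH_SIZE_RATIO = 128

def use_batched_solver(m, batch_size):
  """Prefer trsmBatched for many small matrices, looped trsm for few large ones."""
  return m <= KMAX_MATRIX_SIZE_TO_BATCH_SIZE_RATIO * batch_size

assert use_batched_solver(m=64, batch_size=128)      # many small solves
assert not use_batched_solver(m=4096, batch_size=1)  # one big solve
```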
diff --git a/tensorflow/core/kernels/matrix_triangular_solve_op_real.cc b/tensorflow/core/kernels/matrix_triangular_solve_op_real.cc
new file mode 100644
index 00000000000..0f92964dd72
--- /dev/null
+++ b/tensorflow/core/kernels/matrix_triangular_solve_op_real.cc
@@ -0,0 +1,32 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/matrix_triangular_solve_op_impl.h"
+
+#if GOOGLE_CUDA
+#include "third_party/gpus/cuda/include/cuda.h"
+#endif  // GOOGLE_CUDA
+
+namespace tensorflow {
+
+TF_CALL_float(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU);
+TF_CALL_double(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU);
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+TF_CALL_float(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU);
+TF_CALL_double(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU);
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/matrix_triangular_solve_op_test.cc b/tensorflow/core/kernels/matrix_triangular_solve_op_test.cc
new file mode 100644
index 00000000000..7bb71ae8b68
--- /dev/null
+++ b/tensorflow/core/kernels/matrix_triangular_solve_op_test.cc
@@ -0,0 +1,165 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/kernels/broadcast_to_op.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+Node* BroadcastTo(Graph* g, Node* input, Node* shape) {
+  Node* ret;
+  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BroadcastTo")
+                  .Input(input)
+                  .Input(shape)
+                  .Attr("Tidx", DT_INT64)
+                  .Finalize(g, &ret));
+  return ret;
+}
+
+Node* MatrixTriangularSolve(Graph* g, Node* in0, Node* in1, bool adjoint) {
+  Node* ret;
+  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatrixTriangularSolve")
+                  .Input(in0)
+                  .Input(in1)
+                  .Attr("lower", true)
+                  .Attr("adjoint", adjoint)
+                  .Finalize(g, &ret));
+  return ret;
+}
+
+template <typename T>
+static Graph* MatrixTriangularSolveWithBroadcast(int64 b0, int64 b1, int64 m,
+                                                 int64 n, bool manual_broadcast,
+                                                 DataType type) {
+  Graph* g = new Graph(OpRegistry::Global());
+  Tensor in0(type, TensorShape({b0, m, m}));
+  // Set the diagonal to non-zero to guarantee invertibility.
+  in0.flat<T>().setRandom();
+  auto matrix = Eigen::Map<
+      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
+      in0.flat<T>().data(), in0.dim_size(1), in0.dim_size(2));
+
+  matrix.diagonal() =
+      (matrix.diagonal().cwiseAbs().array() + static_cast<T>(0.5));
+  Tensor in1(type, TensorShape({b1, m, n}));
+  in1.flat<T>().setRandom();
+
+  Tensor broadcasted_in0_shape(DT_INT64, TensorShape({3}));
+  Tensor broadcasted_in1_shape(DT_INT64, TensorShape({3}));
+
+  Node* in0_node = nullptr;
+  Node* in1_node = nullptr;
+  if (manual_broadcast) {
+    auto vec0 = broadcasted_in0_shape.vec<int64>();
+    auto vec1 = broadcasted_in1_shape.vec<int64>();
+    for (int i = 0; i < 3; ++i) {
+      vec0(i) = (i == 0 ? std::max(b0, b1) : in0.shape().dim_size(i));
+      vec1(i) = (i == 0 ? std::max(b0, b1) : in1.shape().dim_size(i));
+    }
+    in0_node = BroadcastTo(g, test::graph::Constant(g, in0),
+                           test::graph::Constant(g, broadcasted_in0_shape));
+    in1_node = BroadcastTo(g, test::graph::Constant(g, in1),
+                           test::graph::Constant(g, broadcasted_in1_shape));
+  } else {
+    in0_node = test::graph::Constant(g, in0);
+    in1_node = test::graph::Constant(g, in1);
+  }
+
+  MatrixTriangularSolve(g, in0_node, in1_node, false);
+  return g;
+}
+
+// Macro argument names: ------------------------------------------------- //
+//   B1: batch size of LHS
+//   B2: batch size of RHS
+//   M: inner dimensions of LHS and RHS, outer dimension of LHS
+//   N: outer dimension of RHS
+//   MB: boolean indicating whether to use manual broadcasting
+//   T: C++ type of scalars (e.g. float, std::complex<float>)
+//   TT: TensorFlow type of scalars (e.g. DT_FLOAT, DT_COMPLEX128)
+//   D: Device (e.g. cpu, gpu)
+#define BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, T, TT, D)                \
+  static void                                                                  \
+      BM_MatrixTriangularSolve##_##B1##_##B2##_##M##_##N##_##MB##_##TT##_##D(  \
+          int iters) {                                                         \
+    testing::UseRealTime();                                                    \
+    testing::ItemsProcessed(static_cast<int64>(iters) * std::max(B1, B2) * M * \
+                            M * N * 2);                                        \
+    test::Benchmark(                                                           \
+        #D, MatrixTriangularSolveWithBroadcast<T>(B1, B2, M, N, MB, TT))       \
+        .Run(iters);                                                           \
+  }                                                                            \
+  BENCHMARK(                                                                   \
+      BM_MatrixTriangularSolve##_##B1##_##B2##_##M##_##N##_##MB##_##TT##_##D);
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+#define BM_MatrixTriangularSolve(B1, B2, M, N, MB)                       \
+  BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, float, DT_FLOAT, cpu);   \
+  BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, double, DT_DOUBLE, cpu); \
+  BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, float, DT_FLOAT, gpu);   \
+  BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, double, DT_DOUBLE, gpu);
+
+#else
+
+#define BM_MatrixTriangularSolve(B1, B2, M, N, MB)                     \
+  BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, float, DT_FLOAT, cpu); \
+  BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, double, DT_DOUBLE, cpu);
+
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+// Square matrix triangular solve.
+BM_MatrixTriangularSolve(32, 32, 512, 512, true);
+BM_MatrixTriangularSolve(32, 32, 512, 512, false);
+BM_MatrixTriangularSolve(1, 32, 512, 512, true);
+BM_MatrixTriangularSolve(1, 32, 512, 512, false);
+BM_MatrixTriangularSolve(32, 1, 512, 512, true);
+BM_MatrixTriangularSolve(32, 1, 512, 512, false);
+BM_MatrixTriangularSolve(128, 128, 512, 512, true);
+BM_MatrixTriangularSolve(128, 128, 512, 512, false);
+BM_MatrixTriangularSolve(1, 128, 512, 512, true);
+BM_MatrixTriangularSolve(1, 128, 512, 512, false);
+BM_MatrixTriangularSolve(128, 1, 512, 512, true);
+BM_MatrixTriangularSolve(128, 1, 512, 512, false);
+BM_MatrixTriangularSolve(1, 128, 1024, 1024, true);
+BM_MatrixTriangularSolve(1, 128, 1024, 1024, false);
+BM_MatrixTriangularSolve(128, 1, 1024, 1024, true);
+BM_MatrixTriangularSolve(128, 1, 1024, 1024, false);
+
+// Matrix-vector triangular solve.
+BM_MatrixTriangularSolve(1, 128, 200, 1, true);
+BM_MatrixTriangularSolve(1, 128, 200, 1, false);
+BM_MatrixTriangularSolve(128, 1, 200, 1, true);
+BM_MatrixTriangularSolve(128, 1, 200, 1, false);
+
+// Matrix-vector triangular solve, large dimension.
+BM_MatrixTriangularSolve(1, 128, 200, 10000, true);
+BM_MatrixTriangularSolve(1, 128, 200, 10000, false);
+BM_MatrixTriangularSolve(128, 1, 200, 10000, true);
+BM_MatrixTriangularSolve(128, 1, 200, 10000, false);
+
+}  // namespace
+}  // namespace tensorflow
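The benchmark's ItemsProcessed count matches the usual flop estimate for a dense triangular solve: roughly M*M*N multiply-adds, i.e. 2*M*M*N flops per solve, over the broadcast batch. The same arithmetic as a short sketch (helper name illustrative):

```python
def triangular_solve_flops(b1, b2, m, n):
  """Mirrors the benchmark's ItemsProcessed: max(B1, B2) * M * M * N * 2."""
  batch = max(b1, b2)  # the broadcast batch size
  return batch * m * m * n * 2

# One of the benchmarked shapes: 32 batches of 512x512 against 512 rhs columns.
assert triangular_solve_flops(32, 32, 512, 512) == 32 * 2 * 512**3
```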
diff --git a/tensorflow/core/kernels/rocm_solvers.cc b/tensorflow/core/kernels/rocm_solvers.cc
new file mode 100644
index 00000000000..5faf718332e
--- /dev/null
+++ b/tensorflow/core/kernels/rocm_solvers.cc
@@ -0,0 +1,245 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+   ==============================================================================
+*/
+#if TENSORFLOW_USE_ROCM
+#include "tensorflow/core/kernels/rocm_solvers.h"
+
+#include <complex>
+#include <unordered_map>
+#include <vector>
+
+#include "rocm/include/rocblas.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/stream_executor.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/gpu/gpu_activation.h"
+#include "tensorflow/stream_executor/gpu/gpu_executor.h"
+#include "tensorflow/stream_executor/lib/env.h"
+#include "tensorflow/stream_executor/platform/default/dso_loader.h"
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace tensorflow {
+namespace {
+
+using stream_executor::gpu::GpuExecutor;
+using stream_executor::gpu::ScopedActivateExecutorContext;
+using stream_executor::internal::CachedDsoLoader::GetRocblasDsoHandle;
+
+namespace wrap {
+#ifdef PLATFORM_GOOGLE
+#define ROCBLAS_WRAP(__name)                                        \
+  struct WrapperShim__##__name {                                    \
+    static const char* kName;                                       \
+    template <typename... Args>                                     \
+    rocblas_status operator()(GpuExecutor* parent, Args... args) {  \
+      ScopedActivateExecutorContext sac{parent};                    \
+      return ::__name(args...);                                     \
+    }                                                               \
+  } __name;                                                         \
+  const char* WrapperShim__##__name::kName = #__name;
+
+#else
+
+#define ROCBLAS_WRAP(__name)                                                 \
+  struct DynLoadShim__##__name {                                             \
+    static const char* kName;                                                \
+    using FuncPtrT = std::add_pointer<decltype(::__name)>::type;             \
+    static void* GetDsoHandle() {                                            \
+      auto s = GetRocblasDsoHandle();                                        \
+      return s.ValueOrDie();                                                 \
+    }                                                                        \
+    static FuncPtrT LoadOrDie() {                                            \
+      void* f;                                                               \
+      auto s = stream_executor::port::Env::Default()->GetSymbolFromLibrary(  \
+          GetDsoHandle(), kName, &f);                                        \
+      CHECK(s.ok()) << "could not find " << kName                            \
+                    << " in rocblas DSO; dlerror: " << s.error_message();    \
+      return reinterpret_cast<FuncPtrT>(f);                                  \
+    }                                                                        \
+    static FuncPtrT DynLoad() {                                              \
+      static FuncPtrT f = LoadOrDie();                                       \
+      return f;                                                              \
+    }                                                                        \
+    template <typename... Args>                                              \
+    rocblas_status operator()(GpuExecutor* parent, Args... args) {           \
+      ScopedActivateExecutorContext sac{parent};                             \
+      return DynLoad()(args...);                                             \
+    }                                                                        \
+  } __name;                                                                  \
+  const char* DynLoadShim__##__name::kName = #__name;
+
+#endif
+
+ROCBLAS_WRAP(rocblas_create_handle)
+ROCBLAS_WRAP(rocblas_destroy_handle)
+ROCBLAS_WRAP(rocblas_set_stream)
+ROCBLAS_WRAP(rocblas_dtrsm)
+ROCBLAS_WRAP(rocblas_strsm)
+
+}  // namespace wrap
+
+struct ROCmSolverHandles {
+  explicit ROCmSolverHandles(GpuExecutor* parent, hipStream_t stream) {
+    parent_ = parent;
+    CHECK(wrap::rocblas_create_handle(parent_, &rocm_blas_handle) ==
+          rocblas_status_success)
+        << "Failed to create rocBlas instance.";
+    CHECK(wrap::rocblas_set_stream(parent_, rocm_blas_handle, stream) ==
+          rocblas_status_success)
+        << "Failed to set rocBlas stream.";
+  }
+
+  ~ROCmSolverHandles() {
+    CHECK(wrap::rocblas_destroy_handle(parent_, rocm_blas_handle) ==
+          rocblas_status_success)
+        << "Failed to destroy rocBlas instance.";
+  }
+  GpuExecutor* parent_;
+  rocblas_handle rocm_blas_handle;
+};
+
+using HandleMap =
+    std::unordered_map<hipStream_t, std::unique_ptr<ROCmSolverHandles>>;
+
+// Returns a singleton map used for storing initialized handles for each unique
+// gpu stream.
+HandleMap* GetHandleMapSingleton() {
+  static HandleMap* cm = new HandleMap;
+  return cm;
+}
+
+static mutex handle_map_mutex(LINKER_INITIALIZED);
+
+}  // namespace
+
+ROCmSolver::ROCmSolver(OpKernelContext* context) : context_(context) {
+  mutex_lock lock(handle_map_mutex);
+  GpuExecutor* gpu_executor = static_cast<GpuExecutor*>(
+      context->op_device_context()->stream()->parent()->implementation());
+  const hipStream_t* hip_stream_ptr = CHECK_NOTNULL(
+      reinterpret_cast<const hipStream_t*>(context->op_device_context()
+                                               ->stream()
+                                               ->implementation()
+                                               ->GpuStreamMemberHack()));
+
+  hip_stream_ = *hip_stream_ptr;
+  HandleMap* handle_map = CHECK_NOTNULL(GetHandleMapSingleton());
+  auto it = handle_map->find(hip_stream_);
+  if (it == handle_map->end()) {
+    LOG(INFO) << "Creating ROCmSolver handles for stream " << hip_stream_;
+    // Previously unseen Gpu stream. Initialize a set of Gpu solver library
+    // handles for it.
+    std::unique_ptr<ROCmSolverHandles> new_handles(
+        new ROCmSolverHandles(gpu_executor, hip_stream_));
+    it = handle_map->insert(std::make_pair(hip_stream_, std::move(new_handles)))
+             .first;
+  }
+  rocm_blas_handle_ = it->second->rocm_blas_handle;
+}
+
+ROCmSolver::~ROCmSolver() {
+  for (auto tensor_ref : scratch_tensor_refs_) {
+    tensor_ref.Unref();
+  }
+}
+
+#define TF_RETURN_IF_ROCBLAS_ERROR(expr)                                   \
+  do {                                                                     \
+    auto status = (expr);                                                  \
+    if (TF_PREDICT_FALSE(status != rocblas_status_success)) {              \
+      return errors::Internal(__FILE__, ":", __LINE__,                     \
+                              ": rocBlas call failed status = ", status);  \
+    }                                                                      \
+  } while (0)
+
+// Macro that specializes a solver method for all 4 standard
+// numeric types.
+#define TF_CALL_LAPACK_TYPES(m) \
+  m(float, s) m(double, d) m(std::complex<float>, c) m(std::complex<double>, z)
+#define TF_CALL_LAPACK_TYPES_NO_COMPLEX(m) m(float, s) m(double, d)
+
+#define BLAS_SOLVER_FN(method, type_prefix) \
+  wrap::rocblas##_##type_prefix##method
+
+// Allocates a temporary tensor. The ROCmSolver object maintains a
+// TensorReference to the underlying Tensor to prevent it from being
+// deallocated prematurely.
+Status ROCmSolver::allocate_scoped_tensor(DataType type,
+                                          const TensorShape& shape,
+                                          Tensor* out_temp) {
+  const Status status = context_->allocate_temp(type, shape, out_temp);
+  if (status.ok()) {
+    scratch_tensor_refs_.emplace_back(*out_temp);
+  }
+  return status;
+}
+
+Status ROCmSolver::forward_input_or_allocate_scoped_tensor(
+    gtl::ArraySlice<int> candidate_input_indices, DataType type,
+    const TensorShape& shape, Tensor* out_temp) {
+  const Status status = context_->forward_input_or_allocate_temp(
+      candidate_input_indices, type, shape, out_temp);
+  if (status.ok()) {
+    scratch_tensor_refs_.emplace_back(*out_temp);
+  }
+  return status;
+}
+
+template <typename Scalar, typename SolverFnT>
+static inline Status TrsmImpl(GpuExecutor* gpu_executor, SolverFnT solver,
+                              rocblas_handle rocm_blas_handle,
+                              rocblas_side side, rocblas_fill uplo,
+                              rocblas_operation trans, rocblas_diagonal diag,
+                              int m, int n,
+                              const Scalar* alpha, /* host or device pointer */
+                              const Scalar* A, int lda, Scalar* B, int ldb) {
+  mutex_lock lock(handle_map_mutex);
+  using ROCmScalar = typename ROCmComplexT<Scalar>::type;
+
+  TF_RETURN_IF_ROCBLAS_ERROR(solver(gpu_executor, rocm_blas_handle, side, uplo,
+                                    trans, diag, m, n,
+                                    reinterpret_cast<const ROCmScalar*>(alpha),
+                                    reinterpret_cast<const ROCmScalar*>(A), lda,
+                                    reinterpret_cast<ROCmScalar*>(B), ldb));
+
+  return Status::OK();
+}
+
+#define TRSM_INSTANCE(Scalar, type_prefix)                                     \
+  template <>                                                                  \
+  Status ROCmSolver::Trsm(                                                     \
+      rocblas_side side, rocblas_fill uplo, rocblas_operation trans,           \
+      rocblas_diagonal diag, int m, int n,                                     \
+      const Scalar* alpha, /* host or device pointer */                        \
+      const Scalar* A, int lda, Scalar* B, int ldb) {                          \
+    GpuExecutor* gpu_executor = static_cast<GpuExecutor*>(                     \
+        context_->op_device_context()->stream()->parent()->implementation()); \
+    return TrsmImpl(gpu_executor, BLAS_SOLVER_FN(trsm, type_prefix),           \
+                    rocm_blas_handle_, side, uplo, trans, diag, m, n, alpha,   \
+                    A, lda, B, ldb);                                           \
+  }
+
+TF_CALL_LAPACK_TYPES_NO_COMPLEX(TRSM_INSTANCE);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_USE_ROCM
+
+#if TENSORFLOW_USE_ROCM
+
+#include <functional>
+#include <vector>
+
+#include "rocm/include/hip/hip_complex.h"
+#include "rocm/include/rocblas.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/stream_executor.h"
+#include "tensorflow/stream_executor/blas.h"
+
+namespace tensorflow {
+
+// Type traits to get ROCm complex types from std::complex<>.
+template <typename T>
+struct ROCmComplexT {
+  typedef T type;
+};
+template <>
+struct ROCmComplexT<std::complex<float>> {
+  typedef hipComplex type;
+};
+template <>
+struct ROCmComplexT<std::complex<double>> {
+  typedef hipDoubleComplex type;
+};
+// Converts pointers of std::complex<> to pointers of
+// hipComplex/hipDoubleComplex. No type conversion for non-complex types.
+template <typename T>
+inline const typename ROCmComplexT<T>::type* ROCmComplex(const T* p) {
+  return reinterpret_cast<const typename ROCmComplexT<T>::type*>(p);
+}
+template <typename T>
+inline typename ROCmComplexT<T>::type* ROCmComplex(T* p) {
+  return reinterpret_cast<typename ROCmComplexT<T>::type*>(p);
+}
+
+template <typename Scalar>
+class ScratchSpace;
+
+class ROCmSolver {
+ public:
+  // This object stores a pointer to context, which must outlive it.
+  explicit ROCmSolver(OpKernelContext* context);
+  virtual ~ROCmSolver();
+
+  // Allocates a temporary tensor that will live for the duration of the
+  // ROCmSolver object.
+  Status allocate_scoped_tensor(DataType type, const TensorShape& shape,
+                                Tensor* scoped_tensor);
+  Status forward_input_or_allocate_scoped_tensor(
+      gtl::ArraySlice<int> candidate_input_indices, DataType type,
+      const TensorShape& shape, Tensor* input_alias_or_new_scoped_tensor);
+
+  OpKernelContext* context() { return context_; }
+
+  template <typename Scalar>
+  Status Trsm(rocblas_side side, rocblas_fill uplo, rocblas_operation trans,
+              rocblas_diagonal diag, int m, int n, const Scalar* alpha,
+              const Scalar* A, int lda, Scalar* B, int ldb);
+
+ private:
+  OpKernelContext* context_;  // not owned.
+  hipStream_t hip_stream_;
+  rocblas_handle rocm_blas_handle_;
+  std::vector<TensorReference> scratch_tensor_refs_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(ROCmSolver);
+};
+
+// Helper class to allocate scratch memory and keep track of debug info.
+// Mostly a thin wrapper around Tensor & allocate_temp.
+template <typename Scalar>
+class ScratchSpace {
+ public:
+  ScratchSpace(OpKernelContext* context, int64 size, bool on_host)
+      : ScratchSpace(context, TensorShape({size}), "", on_host) {}
+
+  ScratchSpace(OpKernelContext* context, int64 size, const string& debug_info,
+               bool on_host)
+      : ScratchSpace(context, TensorShape({size}), debug_info, on_host) {}
+
+  ScratchSpace(OpKernelContext* context, const TensorShape& shape,
+               const string& debug_info, bool on_host)
+      : context_(context), debug_info_(debug_info), on_host_(on_host) {
+    AllocatorAttributes alloc_attr;
+    if (on_host) {
+      // Allocate pinned memory on the host to avoid unnecessary
+      // synchronization.
+      alloc_attr.set_on_host(true);
+      alloc_attr.set_gpu_compatible(true);
+    }
+    TF_CHECK_OK(context->allocate_temp(DataTypeToEnum<Scalar>::value, shape,
+                                       &scratch_tensor_, alloc_attr));
+  }
+
+  virtual ~ScratchSpace() {}
+
+  Scalar* mutable_data() {
+    return scratch_tensor_.template flat<Scalar>().data();
+  }
+  const Scalar* data() const {
+    return scratch_tensor_.template flat<Scalar>().data();
+  }
+  Scalar& operator()(int64 i) {
+    return scratch_tensor_.template flat<Scalar>()(i);
+  }
+  const Scalar& operator()(int64 i) const {
+    return scratch_tensor_.template flat<Scalar>()(i);
+  }
+  int64 bytes() const { return scratch_tensor_.TotalBytes(); }
+  int64 size() const { return scratch_tensor_.NumElements(); }
+  const string& debug_info() const { return debug_info_; }
+
+  Tensor& tensor() { return scratch_tensor_; }
+  const Tensor& tensor() const { return scratch_tensor_; }
+
+  // Returns true if this ScratchSpace is in host memory.
+  bool on_host() const { return on_host_; }
+
+ protected:
+  OpKernelContext* context() const { return context_; }
+
+ private:
+  OpKernelContext* context_;  // not owned
+  const string debug_info_;
+  const bool on_host_;
+  Tensor scratch_tensor_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_USE_ROCM
+
+#endif  // TENSORFLOW_CORE_KERNELS_ROCM_SOLVERS_H_
diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc
index 4572df279b7..75340b28eb0 100644
--- a/tensorflow/core/ops/linalg_ops.cc
+++ b/tensorflow/core/ops/linalg_ops.cc
@@ -84,6 +84,34 @@ Status MatrixSolveShapeFn(InferenceContext* c, bool square) {
   return Status::OK();
 }
 
+// The first input is [...,M,M] and second input is [...,M,N].
+// Output is [...,M,N].
+Status MatrixTriangularSolveShapeFn(InferenceContext* c) {
+  ShapeHandle lhs;
+  ShapeHandle rhs;
+  TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &lhs));
+  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));
+
+  ShapeHandle lhs_batch_shape;
+  ShapeHandle rhs_batch_shape;
+  ShapeHandle output_batch_shape;
+  // Make the common batch subshape.
+  TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
+  TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
+  TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
+      c, lhs_batch_shape, rhs_batch_shape, true, &output_batch_shape));
+  DimensionHandle m;
+  // lhs and rhs must agree on m to be compatible.
+  TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, -1), c->Dim(rhs, -2), &m));
+
+  ShapeHandle out;
+  // Build final shape (batch_shape + m + n) in <out>.
+  TF_RETURN_IF_ERROR(
+      c->Concatenate(output_batch_shape, c->Matrix(m, c->Dim(rhs, -1)), &out));
+  c->set_output(0, out);
+  return Status::OK();
+}
+
 // Input is [...,N,N]. Outputs are:
 //   [...,N];[0], if compute_v is false,
 //   [...,N];[...,N,N], if compute_v is true.
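The shape function added above is the crux of the shape-inference change: the batch dimensions of the two inputs broadcast against each other exactly as in an elementwise binary op, while the trailing two dimensions are pinned to [M, N]. A minimal NumPy sketch of the same shape arithmetic, for illustration only (the function name is ours, not TensorFlow's, and `np.broadcast_shapes` requires NumPy >= 1.20; it stands in for `BroadcastBinaryOpOutputShapeFnHelper`):

    import numpy as np

    def triangular_solve_output_shape(lhs_shape, rhs_shape):
      # lhs is [..., M, M], rhs is [..., M, N]; result is [broadcast batch, M, N].
      if len(lhs_shape) < 2 or len(rhs_shape) < 2:
        raise ValueError("both inputs must be at least rank 2")
      if lhs_shape[-1] != lhs_shape[-2]:
        raise ValueError("lhs must be square in its last two dimensions")
      if lhs_shape[-1] != rhs_shape[-2]:
        raise ValueError("lhs and rhs must agree on M")
      # Batch dimensions broadcast like any elementwise binary op.
      batch = np.broadcast_shapes(lhs_shape[:-2], rhs_shape[:-2])
      return batch + (rhs_shape[-2], rhs_shape[-1])

    assert triangular_solve_output_shape((3, 3), (4, 3, 5)) == (4, 3, 5)
    assert triangular_solve_output_shape((1, 3, 4, 4), (2, 1, 4, 1)) == (2, 3, 4, 1)

The two asserts mirror shapes exercised by the `testBatchBroadcast` cases earlier in this change.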
@@ -426,7 +454,7 @@ REGISTER_OP("MatrixTriangularSolve")
     .Attr("adjoint: bool = False")
     .Attr("T: {double, float, half, complex64, complex128}")
     .SetShapeFn([](InferenceContext* c) {
-      return MatrixSolveShapeFn(c, true /* square (*/);
+      return MatrixTriangularSolveShapeFn(c);
     });
 
 REGISTER_OP("MatrixSolveLs")
diff --git a/tensorflow/core/ops/linalg_ops_test.cc b/tensorflow/core/ops/linalg_ops_test.cc
index 682a994e890..7e5ddc02339 100644
--- a/tensorflow/core/ops/linalg_ops_test.cc
+++ b/tensorflow/core/ops/linalg_ops_test.cc
@@ -122,34 +122,54 @@ TEST(LinalgOpsTest, SelfAdjointEigV2_ShapeFn) {
       "[d0_0,d0_1,d0_2,d0_3|d0_4];[d0_0,d0_1,d0_2,d0_3|d0_4,d0_3|d0_4]");
 }
 
-TEST(LinalgOpsTest, SquareMatrixSolve_ShapeFn) {
-  for (const char* op_name : {"MatrixSolve", "MatrixTriangularSolve"}) {
-    ShapeInferenceTestOp op(op_name);
-    INFER_OK(op, "?;?", "?");
-    INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1];?");
-    INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,2];?");
-    INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
-                "[5,?,?];[6]");
-    INFER_ERROR("Shapes must be equal rank, but are 0 and 1", op,
-                "[5,?];[6,?,?]");
+TEST(LinalgOpsTest, MatrixSolve_ShapeFn) {
+  ShapeInferenceTestOp op("MatrixSolve");
+  INFER_OK(op, "?;?", "?");
+  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1];?");
+  INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,2];?");
+  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[5,?,?];[6]");
+  INFER_ERROR("Shapes must be equal rank, but are 0 and 1", op,
+              "[5,?];[6,?,?]");
 
-    INFER_OK(op, "[?,?];?", "[d0_0|d0_1,?]");
+  INFER_OK(op, "[?,?];?", "[d0_0|d0_1,?]");
 
-    // Inputs are [...,M,M] and [...,M,K]. Output is [...,M,K].
-    // First test where ... is empty.
-    INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]");
-    INFER_OK(op, "[?,?];[1,?]", "[d1_0,d1_1]");
-    INFER_OK(op, "[1,?];[1,?]", "[d0_0|d1_0,d1_1]");
-    INFER_OK(op, "[?,1];[1,?]", "[d0_1|d1_0,d1_1]");
-    INFER_OK(op, "[1,1];[?,?]", "[d0_0,d1_1]");
-    INFER_OK(op, "[1,1];[1,?]", "[d0_0|d0_1|d1_0,d1_1]");
-    // Test with ... being 2-d.
-    INFER_OK(op, "[10,?,?,?];[?,20,1,?]", "[d0_0,d1_1,d1_2,d1_3]");
-    INFER_OK(op, "[10,?,1,?];[?,20,1,?]", "[d0_0,d1_1,d0_2|d1_2,d1_3]");
-    INFER_OK(op, "[10,?,?,1];[?,20,1,?]", "[d0_0,d1_1,d0_3|d1_2,d1_3]");
-    INFER_OK(op, "[10,?,1,1];[?,20,?,?]", "[d0_0,d1_1,d0_2,d1_3]");
-    INFER_OK(op, "[10,?,1,1];[?,20,1,?]", "[d0_0,d1_1,d0_2|d0_3|d1_2,d1_3]");
-  }
+  // Inputs are [...,M,M] and [...,M,K]. Output is [...,M,K].
+  // First test where ... is empty.
+  INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]");
+  INFER_OK(op, "[?,?];[1,?]", "[d1_0,d1_1]");
+  INFER_OK(op, "[1,?];[1,?]", "[d0_0|d1_0,d1_1]");
+  INFER_OK(op, "[?,1];[1,?]", "[d0_1|d1_0,d1_1]");
+  INFER_OK(op, "[1,1];[?,?]", "[d0_0,d1_1]");
+  INFER_OK(op, "[1,1];[1,?]", "[d0_0|d0_1|d1_0,d1_1]");
+  // Test with ... being 2-d.
+  INFER_OK(op, "[10,?,?,?];[?,20,1,?]", "[d0_0,d1_1,d1_2,d1_3]");
+  INFER_OK(op, "[10,?,1,?];[?,20,1,?]", "[d0_0,d1_1,d0_2|d1_2,d1_3]");
+  INFER_OK(op, "[10,?,?,1];[?,20,1,?]", "[d0_0,d1_1,d0_3|d1_2,d1_3]");
+  INFER_OK(op, "[10,?,1,1];[?,20,?,?]", "[d0_0,d1_1,d0_2,d1_3]");
+  INFER_OK(op, "[10,?,1,1];[?,20,1,?]", "[d0_0,d1_1,d0_2|d0_3|d1_2,d1_3]");
+}
+
+TEST(LinalgOpsTest, MatrixTriangularSolve_ShapeFn) {
+  ShapeInferenceTestOp op("MatrixTriangularSolve");
+  INFER_OK(op, "?;?", "?");
+  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1];?");
+  INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,2];?");
+  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[5,?,?];[6]");
+
+  // Inputs are [...,M,M] and [...,M,K]. Output is [...,M,K].
+  // First test where ... is empty.
+  INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]");
+  INFER_OK(op, "[?,?];[1,?]", "[d1_0,d1_1]");
+  INFER_OK(op, "[1,?];[1,?]", "[d0_0|d1_0,d1_1]");
+  INFER_OK(op, "[?,1];[1,?]", "[d0_1|d1_0,d1_1]");
+  INFER_OK(op, "[1,1];[?,?]", "[d0_0,d1_1]");
+  INFER_OK(op, "[1,1];[1,?]", "[d0_0|d0_1|d1_0,d1_1]");
+  // Test with ... being 2-d.
+  INFER_OK(op, "[10,?,?,?];[?,20,1,?]", "[d0_0,d1_1,d1_2,d1_3]");
+  INFER_OK(op, "[10,?,1,?];[?,20,1,?]", "[d0_0,d1_1,d0_2|d1_2,d1_3]");
+  INFER_OK(op, "[10,?,?,1];[?,20,1,?]", "[d0_0,d1_1,d0_3|d1_2,d1_3]");
+  INFER_OK(op, "[10,?,1,1];[?,20,?,?]", "[d0_0,d1_1,d0_2,d1_3]");
+  INFER_OK(op, "[10,?,1,1];[?,20,1,?]", "[d0_0,d1_1,d0_2|d0_3|d1_2,d1_3]");
 }
 
 TEST(LinalgOpsTest, MatrixSolveLs_ShapeFn) {
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 19763f92d50..bd80e3341cd 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -755,8 +755,9 @@ cuda_py_test(
 
 cuda_py_test(
     name = "matrix_triangular_solve_op_test",
-    size = "small",
+    size = "medium",
     srcs = ["matrix_triangular_solve_op_test.py"],
+    shard_count = 3,
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:linalg_ops",
diff --git a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
index 32ab6125717..683b1188ffb 100644
--- a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
@@ -20,7 +20,6 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import linalg_ops
@@ -68,31 +67,32 @@ class MatrixTriangularSolveOpTest(test.TestCase):
     else:
       a_np = a
     if adjoint:
-      a_np = np.conj(np.transpose(a_np))
+      axes = list(range(len(a_np.shape)))
+      axes[-2] = -1
+      axes[-1] = -2
+      a_np = np.conj(np.transpose(a_np, axes=axes))
 
     if batch_dims is not None:
       a = np.tile(a, batch_dims + [1, 1])
       a_np = np.tile(a_np, batch_dims + [1, 1])
       b = np.tile(b, batch_dims + [1, 1])
 
-    with self.cached_session(use_gpu=True) as sess:
-      if use_placeholder:
-        a_tf = array_ops.placeholder(a.dtype)
-        b_tf = array_ops.placeholder(b.dtype)
-        tf_ans = linalg_ops.matrix_triangular_solve(
-            a_tf, b_tf, lower=lower, adjoint=adjoint)
-        tf_val = sess.run(tf_ans, feed_dict={a_tf: a, b_tf: b})
-        np_ans = np.linalg.solve(a_np, b)
-      else:
-        a_tf = constant_op.constant(a)
-        b_tf = constant_op.constant(b)
-        tf_ans = linalg_ops.matrix_triangular_solve(
-            a_tf, b_tf, lower=lower, adjoint=adjoint)
-        tf_val = self.evaluate(tf_ans)
-        np_ans = np.linalg.solve(a_np, b)
-      self.assertEqual(np_ans.shape, tf_ans.get_shape())
-      self.assertEqual(np_ans.shape, tf_val.shape)
-      self.assertAllClose(np_ans, tf_val)
+    def broadcast(a, b):
+      b1 = b + np.zeros(a.shape[:-2] + (1, 1), dtype=b.dtype)
+      return a, b1
+
+    a_tf = a
+    b_tf = b
+    if use_placeholder:
+      a_tf = array_ops.placeholder_with_default(a_tf, shape=None)
+      b_tf = array_ops.placeholder_with_default(b_tf, shape=None)
+    tf_ans = linalg_ops.matrix_triangular_solve(
+        a_tf, b_tf, lower=lower, adjoint=adjoint)
+    tf_val = self.evaluate(tf_ans)
+    a_np, b = broadcast(a_np, b)
+    np_ans = np.linalg.solve(a_np, b)
+    self.assertEqual(np_ans.shape, tf_val.shape)
+    self.assertAllClose(np_ans, tf_val)
 
   @test_util.run_deprecated_v1
   def testSolve(self):
@@ -136,6 +136,48 @@ class MatrixTriangularSolveOpTest(test.TestCase):
     # Batch of 3x2x2x2 matrices, 3x2x2x3 right-hand sides.
     self._verifySolveAllWaysReal(matrix, rhs, batch_dims=[3, 2])
 
+  @test_util.run_deprecated_v1
+  def testSolveBatchBroadcast(self):
+    # 2 x 2 x 2
+    matrix = np.array([[[1., 0.], [3., 4.]], [[1., 0.], [2., 1.]]])
+    # 2 x 3
+    rhs = np.array([[1., 0., 1.], [0., 1., 1.]])
+    # 2 x 2 x 3
+    self._verifySolveAllWaysReal(matrix, rhs)
+    # 2 x 2 x 2
+    matrix2 = np.array([[[1., 0.], [3., 4.]], [[2., 0.], [1., 6.3]]])
+    # 1 x 2 x 3
+    rhs = np.array([[[1., 0., 1.], [0., 1., 1.]]])
+    # 2 x 2 x 3
+    self._verifySolveAllWaysReal(matrix2, rhs)
+
+  @test_util.run_deprecated_v1
+  def testSolveBatchBroadcastLargerBatches(self):
+    # 1 x 10 x 10
+    matrix = np.random.uniform(low=1, high=2., size=[1, 10, 10])
+    # 10 x 1
+    rhs = np.random.uniform(size=[10, 1])
+    # 1 x 10 x 1
+    self._verifySolveAllWaysReal(matrix, rhs)
+
+    # 2 x 10 x 10
+    matrix = np.random.uniform(low=1, high=2., size=[2, 10, 10])
+    # 10 x 1
+    rhs = np.random.uniform(size=[10, 1])
+    # 2 x 10 x 1
+    self._verifySolveAllWaysReal(matrix, rhs)
+
+    # 2 x 257 x 257
+    matrix = np.random.uniform(low=1, high=2., size=[2, 257, 257])
+    # Also ensure the matrix is well conditioned by making it diagonally
+    # dominant.
+    np.fill_diagonal(matrix[0, ...], 257 * 2)
+    np.fill_diagonal(matrix[1, ...], 257 * 2)
+    # 257 x 1
+    rhs = np.random.uniform(size=[257, 1])
+    # 2 x 257 x 1
+    self._verifySolveAllWaysReal(matrix, rhs)
+
   @test_util.run_deprecated_v1
   def testSolveBatchComplex(self):
     if test.is_built_with_rocm():
diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py
index 3e6d22accec..94ef2a9bff4 100644
--- a/tensorflow/python/ops/linalg_grad.py
+++ b/tensorflow/python/ops/linalg_grad.py
@@ -607,6 +607,7 @@ def _MatrixSolveLsGrad(op, grad):
 def _MatrixTriangularSolveGrad(op, grad):
   """Gradient for MatrixTriangularSolve."""
   a = op.inputs[0]
+  b = op.inputs[1]
   adjoint_a = op.get_attr("adjoint")
   lower_a = op.get_attr("lower")
   c = op.outputs[0]
@@ -620,7 +621,16 @@ def _MatrixTriangularSolveGrad(op, grad):
     grad_a = array_ops.matrix_band_part(grad_a, -1, 0)
   else:
     grad_a = array_ops.matrix_band_part(grad_a, 0, -1)
-  return (grad_a, grad_b)
+  # If the static batch shapes are equal, we don't need to unbroadcast.
+  if (a.shape.is_fully_defined() and b.shape.is_fully_defined() and
+      a.shape[:-2] == b.shape[:-2]):
+    return grad_a, grad_b
+  a_shape = array_ops.shape(a)
+  b_shape = array_ops.shape(b)
+  ra, rb = array_ops.broadcast_gradient_args(a_shape[:-2], b_shape[:-2])
+  grad_a = array_ops.reshape(math_ops.reduce_sum(grad_a, axis=ra), a_shape)
+  grad_b = array_ops.reshape(math_ops.reduce_sum(grad_b, axis=rb), b_shape)
+  return grad_a, grad_b
 
 
 @ops.RegisterGradient("SelfAdjointEigV2")
diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py
index bb84c3f7dd9..fcbfd51e394 100644
--- a/tensorflow/python/ops/linalg_ops.py
+++ b/tensorflow/python/ops/linalg_ops.py
@@ -79,6 +79,68 @@ def _RegularizedGramianCholesky(matrix, l2_regularizer, first_kind):
     return gen_linalg_ops.cholesky(gramian)
 
 
+@tf_export(
+    'linalg.triangular_solve',
+    v1=['linalg.triangular_solve', 'matrix_triangular_solve'])
+def matrix_triangular_solve(matrix, rhs, lower=True, adjoint=False, name=None):
+  """Solve systems of linear equations with upper or lower triangular matrices.
+
+  `matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form
+  square matrices. If `lower` is `True` then the strictly upper triangular part
+  of each inner-most matrix is assumed to be zero and not accessed. If `lower`
+  is `False` then the strictly lower triangular part of each inner-most matrix
+  is assumed to be zero and not accessed. `rhs` is a tensor of shape
+  `[..., M, N]`.
+
+  The output is a tensor of shape `[..., M, N]`. If `adjoint` is `False` then
+  the innermost matrices in `output` satisfy matrix equations
+  `sum_k matrix[..., i, k] * output[..., k, j] = rhs[..., i, j]`.
+  If `adjoint` is `True` then the innermost matrices in `output` satisfy matrix
+  equations
+  `sum_k adjoint(matrix[..., i, k]) * output[..., k, j] = rhs[..., i, j]`.
+
+  Example:
+
+  >>> a = tf.constant([[3, 0, 0, 0],
+  ...                  [2, 1, 0, 0],
+  ...                  [1, 0, 1, 0],
+  ...                  [1, 1, 1, 1]], dtype=tf.float32)
+
+  >>> b = tf.constant([[4], [2], [4], [2]], dtype=tf.float32)
+  >>> x = tf.linalg.triangular_solve(a, b, lower=True)
+  >>> x
+  <tf.Tensor: shape=(4, 1), dtype=float32, numpy=
+  array([[ 1.3333334 ],
+         [-0.66666675],
+         [ 2.6666665 ],
+         [-1.3333331 ]], dtype=float32)>
+
+  >>> tf.matmul(a, x)
+  <tf.Tensor: shape=(4, 1), dtype=float32, numpy=
+  array([[4.       ],
+         [2.       ],
+         [4.0000005],
+         [1.9999999]], dtype=float32)>
+
+  Args:
+    matrix: A `Tensor`. Must be one of the following types: `float64`,
+      `float32`, `half`, `complex64`, `complex128`. Shape is `[..., M, M]`.
+    rhs: A `Tensor`. Must have the same type as `matrix`. Shape is `[..., M,
+      N]`.
+    lower: An optional `bool`. Defaults to `True`. Boolean indicating whether
+      the innermost matrices in `matrix` are lower or upper triangular.
+    adjoint: An optional `bool`. Defaults to `False`. Boolean indicating
+      whether to solve with `matrix` or its (block-wise) adjoint.
+    name: A name for the operation (optional).
+
+  Returns:
+    A `Tensor`. Has the same type as `matrix`, and shape is `[..., M, N]`.
+ + """ + with ops.name_scope(name, 'triangular_solve', [matrix, rhs]): + return gen_linalg_ops.matrix_triangular_solve( + matrix, rhs, lower=lower, adjoint=adjoint) + + @tf_export( 'linalg.cholesky_solve', v1=['linalg.cholesky_solve', 'cholesky_solve']) @deprecation.deprecated_endpoints('cholesky_solve') diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py index c6caf2b7f17..218e2db2095 100644 --- a/tensorflow/python/ops/parallel_for/pfor.py +++ b/tensorflow/python/ops/parallel_for/pfor.py @@ -2872,9 +2872,9 @@ def _convert_log_matrix_determinant(pfor_input): @RegisterPFor("MatrixTriangularSolve") def _convert_matrix_triangular_solve(pfor_input): - pfor_input.stack_inputs() - matrix = pfor_input.stacked_input(0) - rhs = pfor_input.stacked_input(1) + pfor_input.expanddim_inputs_for_broadcast() + matrix = pfor_input.input(0)[0] + rhs = pfor_input.input(1)[0] lower = pfor_input.get_attr("lower") adjoint = pfor_input.get_attr("adjoint") output = linalg_ops.matrix_triangular_solve(
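Taken together, the op registration, XLA and ROCm kernels, shape function, gradient, and pfor converter above make batch broadcasting work end to end. A sketch of the intended user-visible behavior, assuming a TensorFlow 2.x build that includes this change (eager mode; the shapes are the point, not the values):

    import numpy as np
    import tensorflow as tf

    # One well-conditioned lower-triangular matrix with empty batch shape.
    a = np.tril(np.random.rand(3, 3)) + 3.0 * np.eye(3)  # [3, 3]
    # A batch of right-hand sides with batch shape [4].
    b = np.random.rand(4, 3, 5)                          # [4, 3, 5]

    # The batch dimensions broadcast, so the solve yields shape [4, 3, 5].
    x = tf.linalg.triangular_solve(a, b, lower=True)
    print(x.shape)  # (4, 3, 5)

    # Per _MatrixTriangularSolveGrad above, each gradient is reduced back to
    # its input's own batch shape rather than the broadcast batch shape.
    at, bt = tf.constant(a), tf.constant(b)
    with tf.GradientTape() as tape:
      tape.watch([at, bt])
      x = tf.linalg.triangular_solve(at, bt, lower=True)
    grad_a, grad_b = tape.gradient(tf.reduce_sum(x), [at, bt])
    print(grad_a.shape, grad_b.shape)  # (3, 3) (4, 3, 5)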