Add Broadcasted Matrix Triangular Solve.

Add Numpy-style broadcasting in the batch dimensions for tf.linalg.triangular_solve op. The last two dimensions of both operands constitute the matrix dimensions. The dimensions beyond these are broadcasted to form a common output shape with the standard NumPy broadcasting rules. (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
Note: This implementation differs from Numpy's behavior in that vectors (rank-1 Tensors) are not promoted to matrices (rank-2 Tensors) by appending/prepending dimensions.
PiperOrigin-RevId: 291857632
Change-Id: Ifce8f1ae3e0e5b990b71cf468978e1cdc7663d1f
This commit is contained in:
Srinivas Vasudevan 2020-01-27 20:40:59 -08:00 committed by TensorFlower Gardener
parent 0a3c298880
commit b105944eb6
21 changed files with 1532 additions and 322 deletions

View File

@ -50,7 +50,9 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase):
atol): atol):
feed_dict = {placeholder_a: a, placeholder_ca: clean_a, placeholder_b: b} feed_dict = {placeholder_a: a, placeholder_ca: clean_a, placeholder_b: b}
verification_np = sess.run(verification, feed_dict) verification_np = sess.run(verification, feed_dict)
self.assertAllClose(b, verification_np, atol=atol) broadcasted_shape = a.shape[:-2] + (b.shape[-2], b.shape[-1])
broadcasted_b = b + np.zeros(shape=broadcasted_shape, dtype=b.dtype)
self.assertAllClose(broadcasted_b, verification_np, atol=atol)
def _VerifyTriangularSolve(self, a, b, lower, adjoint, atol): def _VerifyTriangularSolve(self, a, b, lower, adjoint, atol):
clean_a = np.tril(a) if lower else np.triu(a) clean_a = np.tril(a) if lower else np.triu(a)
@ -111,6 +113,18 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase):
self._VerifyTriangularSolveCombo( self._VerifyTriangularSolveCombo(
a.astype(dtype), b.astype(dtype), atol=1e-3) a.astype(dtype), b.astype(dtype), atol=1e-3)
def testBatchBroadcast(self):
rng = np.random.RandomState(0)
shapes = [((3, 3), (4, 3, 5)), ((1, 2, 2), (3, 2, 1)), ((1, 1), (1, 1, 2)),
((1, 3, 4, 4), (2, 1, 4, 1))]
tuples = itertools.product(self.float_types, shapes)
for dtype, (a_shape, b_shape) in tuples:
n = a_shape[-1]
a = np.tril(rng.rand(*a_shape) - 0.5) / (2.0 * n) + np.eye(n)
b = rng.randn(*b_shape)
self._VerifyTriangularSolveCombo(
a.astype(dtype), b.astype(dtype), atol=1e-3)
def testLarge(self): def testLarge(self):
n = 1024 n = 1024
rng = np.random.RandomState(0) rng = np.random.RandomState(0)

View File

@ -13,10 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/util/bcast.h"
#include "tensorflow/core/util/matmul_bcast.h"
namespace tensorflow { namespace tensorflow {
namespace { namespace {
@ -30,8 +33,28 @@ class MatrixTriangularSolveOp : public XlaOpKernel {
} }
void Compile(XlaOpKernelContext* ctx) override { void Compile(XlaOpKernelContext* ctx) override {
const TensorShape lhs_shape = ctx->InputShape(0);
const TensorShape rhs_shape = ctx->InputShape(1);
// By TensorFlow conventions the inputs may not have the same
// shapes, in which case they will be automatically broadcast if
// possible before mapping. Use the standard TensorFlow helper to
// compute valid broadcast shapes, but rely below on XLA to
// automatically perform the broadcast assuming its valid shapes are
// a superset of TensorFlow's valid shapes.
MatMulBCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape));
if (!bcast.IsValid()) {
ctx->SetStatus(errors::InvalidArgument(
"Incompatible shapes: ", lhs_shape.DebugString(), " vs. ",
rhs_shape.DebugString()));
return;
}
xla::XlaOp a = ctx->Input(0);
xla::XlaOp b = ctx->Input(1);
std::tie(a, b) = Broadcast(a, lhs_shape, b, rhs_shape, bcast);
auto result = xla::TriangularSolve( auto result = xla::TriangularSolve(
ctx->Input(0), ctx->Input(1), /*left_side=*/true, a, b, /*left_side=*/true,
/*lower=*/lower_, /*unit_diagonal=*/false, /*lower=*/lower_, /*unit_diagonal=*/false,
/*transpose_a=*/ /*transpose_a=*/
adjoint_ ? xla::TriangularSolveOptions::ADJOINT adjoint_ ? xla::TriangularSolveOptions::ADJOINT
@ -40,10 +63,41 @@ class MatrixTriangularSolveOp : public XlaOpKernel {
} }
private: private:
static std::pair<xla::XlaOp, xla::XlaOp> Broadcast(
xla::XlaOp lhs, const TensorShape& lhs_shape, xla::XlaOp rhs,
const TensorShape& rhs_shape, const MatMulBCast& broadcast_helper);
bool lower_; bool lower_;
bool adjoint_; bool adjoint_;
}; };
/* static */ std::pair<xla::XlaOp, xla::XlaOp>
MatrixTriangularSolveOp::Broadcast(xla::XlaOp lhs, const TensorShape& lhs_shape,
xla::XlaOp rhs, const TensorShape& rhs_shape,
const MatMulBCast& broadcast_helper) {
// Get the batch shape.
int64 m = lhs_shape.dim_size(lhs_shape.dims() - 1);
int64 n = rhs_shape.dim_size(rhs_shape.dims() - 1);
TensorShape lhs_broadcast_shape(broadcast_helper.output_batch_shape());
lhs_broadcast_shape.AddDim(m);
lhs_broadcast_shape.AddDim(m);
auto lhs_output = BroadcastTo(lhs, lhs_broadcast_shape.dim_sizes());
if (!lhs_output.ok()) {
xla::XlaOp error = lhs.builder()->ReportError(lhs_output.status());
return {error, error};
}
TensorShape rhs_broadcast_shape(broadcast_helper.output_batch_shape());
rhs_broadcast_shape.AddDim(m);
rhs_broadcast_shape.AddDim(n);
auto rhs_output = BroadcastTo(rhs, rhs_broadcast_shape.dim_sizes());
if (!rhs_output.ok()) {
xla::XlaOp error = rhs.builder()->ReportError(rhs_output.status());
return {error, error};
}
return {lhs_output.ValueOrDie(), rhs_output.ValueOrDie()};
}
REGISTER_XLA_OP(Name("MatrixTriangularSolve"), MatrixTriangularSolveOp); REGISTER_XLA_OP(Name("MatrixTriangularSolve"), MatrixTriangularSolveOp);
} // namespace } // namespace

View File

@ -44,15 +44,17 @@ square matrices. If `lower` is `True` then the strictly upper triangular part
of each inner-most matrix is assumed to be zero and not accessed. of each inner-most matrix is assumed to be zero and not accessed.
If `lower` is False then the strictly lower triangular part of each inner-most If `lower` is False then the strictly lower triangular part of each inner-most
matrix is assumed to be zero and not accessed. matrix is assumed to be zero and not accessed.
`rhs` is a tensor of shape `[..., M, K]`. `rhs` is a tensor of shape `[..., M, N]`.
The output is a tensor of shape `[..., M, K]`. If `adjoint` is The output is a tensor of shape `[..., M, N]`. If `adjoint` is
`True` then the innermost matrices in `output` satisfy matrix equations `True` then the innermost matrices in `output` satisfy matrix equations
`matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`. `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`.
If `adjoint` is `False` then the strictly then the innermost matrices in If `adjoint` is `False` then the strictly then the innermost matrices in
`output` satisfy matrix equations `output` satisfy matrix equations
`adjoint(matrix[..., i, k]) * output[..., k, j] = rhs[..., i, j]`. `adjoint(matrix[..., i, k]) * output[..., k, j] = rhs[..., i, j]`.
Note, the batch shapes for the inputs only need to broadcast.
Example: Example:
```python ```python

View File

@ -1,10 +1,4 @@
op { op {
graph_op_name: "MatrixTriangularSolve" graph_op_name: "MatrixTriangularSolve"
endpoint { visibility: HIDDEN
name: "linalg.triangular_solve"
}
endpoint {
name: "matrix_triangular_solve"
deprecation_version: 2
}
} }

View File

@ -3503,6 +3503,22 @@ tf_kernel_library(
], ],
) )
tf_kernel_library(
name = "rocm_solvers",
srcs = ["rocm_solvers.cc"],
hdrs = ["rocm_solvers.h"],
visibility = [":friends"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/stream_executor/lib",
"//tensorflow/stream_executor/platform:dso_loader",
"//tensorflow/stream_executor/rocm:rocblas_plugin",
"//tensorflow/stream_executor/rocm:rocm_gpu_executor",
"@local_config_rocm//rocm:rocprim",
],
)
tf_kernel_library( tf_kernel_library(
name = "cuda_sparse", name = "cuda_sparse",
srcs = if_cuda(["cuda_sparse.cc"]) + if_rocm(["rocm_sparse.cc"]), srcs = if_cuda(["cuda_sparse.cc"]) + if_rocm(["rocm_sparse.cc"]),
@ -3527,6 +3543,8 @@ LINALG_DEPS = [
] + if_cuda([ ] + if_cuda([
":cuda_solvers", ":cuda_solvers",
":transpose_functor", ":transpose_functor",
]) + if_rocm([
":rocm_solvers",
]) ])
tf_kernel_library( tf_kernel_library(
@ -3613,9 +3631,23 @@ tf_kernel_library(
tf_kernel_library( tf_kernel_library(
name = "matrix_triangular_solve_op", name = "matrix_triangular_solve_op",
hdrs = ["matrix_triangular_solve_op_impl.h"],
prefix = "matrix_triangular_solve_op", prefix = "matrix_triangular_solve_op",
deps = LINALG_DEPS + if_cuda([ deps = [
":linalg_ops_common",
"//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
":fill_functor",
"//tensorflow/core:stream_executor",
] + if_cuda([
"//tensorflow/core/platform/default/build_config:cublas_plugin", "//tensorflow/core/platform/default/build_config:cublas_plugin",
":cuda_solvers",
]) + if_rocm([
"@local_config_rocm//rocm:rocprim",
":rocm_solvers",
]) + if_cuda_or_rocm([
":transpose_functor",
]), ]),
) )
@ -4204,6 +4236,25 @@ tf_cuda_cc_test(
], ],
) )
tf_cuda_cc_test(
name = "matrix_triangular_solve_op_test",
size = "small",
srcs = ["matrix_triangular_solve_op_test.cc"],
deps = [
":broadcast_to_op",
":matrix_triangular_solve_op",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_cuda_cc_test( tf_cuda_cc_test(
name = "scan_ops_test", name = "scan_ops_test",
size = "small", size = "small",

View File

@ -900,6 +900,106 @@ static inline Status MatInvBatchedImpl(
TF_CALL_LAPACK_TYPES(MATINV_BATCHED_INSTANCE); TF_CALL_LAPACK_TYPES(MATINV_BATCHED_INSTANCE);
template <typename Scalar, typename SolverFnT>
static inline Status TrsmImpl(SolverFnT solver, cublasHandle_t cublas_handle,
cublasSideMode_t side, cublasFillMode_t uplo,
cublasOperation_t trans, cublasDiagType_t diag,
int m, int n,
const Scalar* alpha, /* host or device pointer */
const Scalar* A, int lda, Scalar* B, int ldb) {
mutex_lock lock(handle_map_mutex);
using CudaScalar = typename CUDAComplexT<Scalar>::type;
TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, side, uplo, trans, diag, m, n,
reinterpret_cast<const CudaScalar*>(alpha),
reinterpret_cast<const CudaScalar*>(A), lda,
reinterpret_cast<CudaScalar*>(B), ldb));
return Status::OK();
}
#define TRSM_INSTANCE(Scalar, type_prefix) \
template <> \
Status CudaSolver::Trsm<Scalar>( \
cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, \
cublasDiagType_t diag, int m, int n, \
const Scalar* alpha, /* host or device pointer */ \
const Scalar* A, int lda, Scalar* B, int ldb) { \
return TrsmImpl(BLAS_SOLVER_FN(trsm, type_prefix), cublas_handle_, side, \
uplo, trans, diag, m, n, alpha, A, lda, B, ldb); \
}
TF_CALL_LAPACK_TYPES(TRSM_INSTANCE);
template <typename Scalar, typename SolverFnT>
static inline Status TrsvImpl(SolverFnT solver, cublasHandle_t cublas_handle,
cublasFillMode_t uplo, cublasOperation_t trans,
cublasDiagType_t diag, int n, const Scalar* A,
int lda, Scalar* x, int incx) {
mutex_lock lock(handle_map_mutex);
using CudaScalar = typename CUDAComplexT<Scalar>::type;
TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, uplo, trans, diag, n,
reinterpret_cast<const CudaScalar*>(A), lda,
reinterpret_cast<CudaScalar*>(x), incx));
return Status::OK();
}
#define TRSV_INSTANCE(Scalar, type_prefix) \
template <> \
Status CudaSolver::Trsv<Scalar>( \
cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, \
int n, const Scalar* A, int lda, Scalar* x, int incx) { \
return TrsvImpl(BLAS_SOLVER_FN(trsv, type_prefix), cublas_handle_, uplo, \
trans, diag, n, A, lda, x, incx); \
}
TF_CALL_LAPACK_TYPES(TRSV_INSTANCE);
template <typename Scalar, typename SolverFnT>
static inline Status TrsmBatchedImpl(
SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
cublasHandle_t cublas_handle, cublasSideMode_t side, cublasFillMode_t uplo,
cublasOperation_t trans, cublasDiagType_t diag, int m, int n,
const Scalar* alpha, const Scalar* const host_a_dev_ptrs[], int lda,
Scalar* host_b_dev_ptrs[], int ldb, int batch_size) {
mutex_lock lock(handle_map_mutex);
using CudaScalar = typename CUDAComplexT<Scalar>::type;
ScratchSpace<uint8> dev_a_dev_ptrs =
cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
/* on_host */ false);
ScratchSpace<uint8> dev_b_dev_ptrs =
cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
/* on_host */ false);
if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
return errors::Internal("TrsmBatched: failed to copy pointers to device");
}
if (!CopyHostToDevice(context, dev_b_dev_ptrs.mutable_data() /* dest */,
host_b_dev_ptrs /* source */, dev_b_dev_ptrs.bytes())) {
return errors::Internal("TrsmBatched: failed to copy pointers to device");
}
TF_RETURN_IF_CUBLAS_ERROR(
solver(cublas_handle, side, uplo, trans, diag, m, n,
reinterpret_cast<const CudaScalar*>(alpha),
reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()),
lda, reinterpret_cast<CudaScalar**>(dev_b_dev_ptrs.mutable_data()),
ldb, batch_size));
return Status::OK();
}
#define TRSM_BATCHED_INSTANCE(Scalar, type_prefix) \
template <> \
Status CudaSolver::TrsmBatched( \
cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, \
cublasDiagType_t diag, int m, int n, const Scalar* alpha, \
const Scalar* const dev_Aarray[], int lda, Scalar* dev_Barray[], \
int ldb, int batch_size) { \
return TrsmBatchedImpl(BLAS_SOLVER_FN(trsmBatched, type_prefix), this, \
context_, cublas_handle_, side, uplo, trans, diag, \
m, n, alpha, dev_Aarray, lda, dev_Barray, ldb, \
batch_size); \
}
TF_CALL_LAPACK_TYPES(TRSM_BATCHED_INSTANCE);
} // namespace tensorflow } // namespace tensorflow
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA

View File

@ -334,6 +334,29 @@ class CudaSolver {
Scalar* dev_V, int ldv, int* dev_lapack_info, Scalar* dev_V, int ldv, int* dev_lapack_info,
int batch_size); int batch_size);
// Triangular solve
// Returns Status::OK() if the kernel was launched successfully.
// See https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-trsm
template <typename Scalar>
Status Trsm(cublasSideMode_t side, cublasFillMode_t uplo,
cublasOperation_t trans, cublasDiagType_t diag, int m, int n,
const Scalar* alpha, const Scalar* A, int lda, Scalar* B,
int ldb);
template <typename Scalar>
Status Trsv(cublasFillMode_t uplo, cublasOperation_t trans,
cublasDiagType_t diag, int n, const Scalar* A, int lda, Scalar* x,
int intcx);
// See
// https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-trsmbatched
template <typename Scalar>
Status TrsmBatched(cublasSideMode_t side, cublasFillMode_t uplo,
cublasOperation_t trans, cublasDiagType_t diag, int m,
int n, const Scalar* alpha,
const Scalar* const dev_Aarray[], int lda,
Scalar* dev_Barray[], int ldb, int batch_size);
private: private:
OpKernelContext* context_; // not owned. OpKernelContext* context_; // not owned.
cudaStream_t cuda_stream_; cudaStream_t cuda_stream_;

View File

@ -1,258 +0,0 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/linalg_ops.cc.
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace tensorflow {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace {
template <typename Scalar>
se::DeviceMemory<Scalar> AsDeviceMemory(const Scalar* gpu_memory) {
se::DeviceMemoryBase wrapped(const_cast<Scalar*>(gpu_memory));
se::DeviceMemory<Scalar> typed(wrapped);
return typed;
}
} // namespace
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <class Scalar>
class MatrixTriangularSolveOp : public LinearAlgebraOp<Scalar> {
public:
INHERIT_LINALG_TYPEDEFS(Scalar);
explicit MatrixTriangularSolveOp(OpKernelConstruction* context)
: Base(context), lower_(true), adjoint_(false) {
OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
}
void ValidateInputMatrixShapes(
OpKernelContext* context,
const TensorShapes& input_matrix_shapes) const final {
Base::ValidateSquareSolver(context, input_matrix_shapes);
}
TensorShapes GetOutputMatrixShapes(
const TensorShapes& input_matrix_shapes) const final {
return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1),
input_matrix_shapes[1].dim_size(1)})});
}
int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0));
double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
double cost = rows * rows * num_rhss *
(Eigen::TensorOpCost::AddCost<Scalar>() +
Eigen::TensorOpCost::MulCost<Scalar>());
return cost >= static_cast<double>(kint64max) ? kint64max
: static_cast<int64>(cost);
}
bool EnableInputForwarding() const final { return false; }
void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
MatrixMaps* outputs) final {
const ConstMatrixMap& matrix = inputs[0];
const ConstMatrixMap& rhs = inputs[1];
MatrixMap& output = outputs->at(0);
if (matrix.rows() == 0 || rhs.rows() == 0 || rhs.cols() == 0) {
// To be consistent with the MatrixInverse op, we define the solution for
// an empty set of equation as the empty matrix.
return;
}
const RealScalar min_abs_pivot = matrix.diagonal().cwiseAbs().minCoeff();
OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
errors::InvalidArgument("Input matrix is not invertible."));
if (lower_) {
auto triangle = matrix.template triangularView<Eigen::Lower>();
if (adjoint_) {
output.noalias() = triangle.adjoint().solve(rhs);
} else {
output.noalias() = triangle.solve(rhs);
}
} else {
auto triangle = matrix.template triangularView<Eigen::Upper>();
if (adjoint_) {
output.noalias() = triangle.adjoint().solve(rhs);
} else {
output.noalias() = triangle.solve(rhs);
}
}
}
private:
bool lower_;
bool adjoint_;
TF_DISALLOW_COPY_AND_ASSIGN(MatrixTriangularSolveOp);
};
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
(MatrixTriangularSolveOp<float>), float);
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
(MatrixTriangularSolveOp<double>), double);
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
(MatrixTriangularSolveOp<complex64>), complex64);
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
(MatrixTriangularSolveOp<complex128>), complex128);
REGISTER_LINALG_OP_CPU("BatchMatrixTriangularSolve",
(MatrixTriangularSolveOp<float>), float);
REGISTER_LINALG_OP_CPU("BatchMatrixTriangularSolve",
(MatrixTriangularSolveOp<double>), double);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// TODO(rmlarsen): Re-factor to
// 1. Enable buffer forwarding from rhs->out.
// 2. Save Memcpy when buffer forwarding is used.
// 3. Copy entire rhs in a single Memcpy when forwarding is not used.
template <class Scalar>
class MatrixTriangularSolveOpGPU : public LinearAlgebraOp<Scalar> {
public:
INHERIT_LINALG_TYPEDEFS(Scalar);
explicit MatrixTriangularSolveOpGPU(OpKernelConstruction* context)
: Base(context), lower_(true), adjoint_(false) {
OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
}
void ValidateInputMatrixShapes(
OpKernelContext* context,
const TensorShapes& input_matrix_shapes) const final {
Base::ValidateSquareSolver(context, input_matrix_shapes);
}
TensorShapes GetOutputMatrixShapes(
const TensorShapes& input_matrix_shapes) const final {
return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1),
input_matrix_shapes[1].dim_size(1)})});
}
int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0));
double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
double cost = rows * rows * num_rhss *
(Eigen::TensorOpCost::AddCost<Scalar>() +
Eigen::TensorOpCost::MulCost<Scalar>());
return cost >= static_cast<double>(kint64max) ? kint64max
: static_cast<int64>(cost);
}
bool EnableInputForwarding() const final { return false; }
void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
MatrixMaps* outputs) final {
const ConstMatrixMap& matrix = inputs[0];
const ConstMatrixMap& rhs = inputs[1];
MatrixMap& output = outputs->at(0);
if (matrix.rows() == 0 || rhs.rows() == 0 || rhs.cols() == 0) {
// To be consistent with the MatrixInverse op, we define the solution for
// an empty set of equation as the empty matrix.
return;
}
auto matrix_ptr = AsDeviceMemory(matrix.data());
auto rhs_ptr = AsDeviceMemory(rhs.data());
auto out_ptr = AsDeviceMemory(output.data());
auto* stream = context->op_device_context()->stream();
uint64 rhs_elems = rhs.rows() * rhs.cols();
bool copy_status =
stream->ThenMemcpyD2D(&out_ptr, rhs_ptr, sizeof(Scalar) * rhs_elems)
.ok();
if (!copy_status) {
context->SetStatus(
errors::Internal("Failed to copy rhs into output before solve"));
}
// Cublas does
// output = matrix \ rhs
// where matrix, rhs and output are assumed to be in column major.
// We want the output to be in row-major, so we can compute
// output' = rhs' / matrix' (' stands for transpose)
// Upper/lower needs to be swapped for this.
se::blas::UpperLower upper_lower_matrix;
se::blas::Transpose transpose_matrix;
if (lower_) {
upper_lower_matrix = se::blas::UpperLower::kUpper;
} else {
upper_lower_matrix = se::blas::UpperLower::kLower;
}
if (adjoint_) {
transpose_matrix = se::blas::Transpose::kConjugateTranspose;
} else {
transpose_matrix = se::blas::Transpose::kNoTranspose;
}
uint64 leading_dim_matrix = matrix.cols();
uint64 leading_dim_output = output.cols();
uint64 colmajor_rows = output.cols();
uint64 colmajor_cols = output.rows();
bool blas_launch_status =
stream
->ThenBlasTrsm(
se::blas::Side::kRight /*side*/, upper_lower_matrix /*uplo*/,
transpose_matrix /*trans*/,
se::blas::Diagonal::kNonUnit /*diag*/, colmajor_rows /*m*/,
colmajor_cols /*n*/, Scalar(1.0) /*alpha*/, matrix_ptr,
leading_dim_matrix /*lda*/, &out_ptr,
leading_dim_output /*ldb*/)
.ok();
if (!blas_launch_status) {
context->SetStatus(errors::Internal("Blas TRSM launch failed"));
}
}
private:
bool lower_;
bool adjoint_;
TF_DISALLOW_COPY_AND_ASSIGN(MatrixTriangularSolveOpGPU);
};
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
(MatrixTriangularSolveOpGPU<float>), float);
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
(MatrixTriangularSolveOpGPU<double>), double);
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
(MatrixTriangularSolveOpGPU<complex64>), complex64);
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
(MatrixTriangularSolveOpGPU<complex128>), complex128);
REGISTER_LINALG_OP_GPU("BatchMatrixTriangularSolve",
(MatrixTriangularSolveOpGPU<float>), float);
REGISTER_LINALG_OP_GPU("BatchMatrixTriangularSolve",
(MatrixTriangularSolveOpGPU<double>), double);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow

View File

@ -0,0 +1,28 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/matrix_triangular_solve_op_impl.h"
namespace tensorflow {
TF_CALL_complex64(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU);
TF_CALL_complex128(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU);
#if GOOGLE_CUDA
TF_CALL_complex64(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU);
TF_CALL_complex128(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow

View File

@ -0,0 +1,437 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/linalg_ops.cc.
//
#ifndef TENSORFLOW_CORE_KERNELS_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/matmul_bcast.h"
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#include "tensorflow/core/kernels/cuda_solvers.h"
#elif TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/rocm_solvers.h"
#endif
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename Scalar>
se::DeviceMemory<Scalar> AsDeviceMemory(const Scalar* gpu_memory) {
se::DeviceMemoryBase wrapped(const_cast<Scalar*>(gpu_memory));
se::DeviceMemory<Scalar> typed(wrapped);
return typed;
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Sequential batch matrix triangular solve kernel that calls Eigen's
// matrix triangular solve.
template <typename Scalar>
struct SequentialMatrixTriangularSolveKernel {
using Matrix =
Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using ConstMatrixMap = Eigen::Map<const Matrix>;
using MatrixMap = Eigen::Map<Matrix>;
using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t,
int slice) {
return ConstMatrixMap(
t.flat<Scalar>().data() + slice * t.dim_size(1) * t.dim_size(2),
t.dim_size(1), t.dim_size(2));
}
static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) {
return MatrixMap(
t->flat<Scalar>().data() + slice * t->dim_size(1) * t->dim_size(2),
t->dim_size(1), t->dim_size(2));
}
static void Run(const Tensor& in_x, const Tensor& in_y, bool lower,
bool adjoint, const MatMulBCast& bcast, Tensor* out,
int start, int limit) {
const bool should_bcast = bcast.IsBroadcastingRequired();
const auto& x_batch_indices = bcast.x_batch_indices();
const auto& y_batch_indices = bcast.y_batch_indices();
for (int64 i = start; i < limit; ++i) {
const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
auto matrix = ConstTensorSliceToEigenMatrix(in_x, x_batch_index);
auto rhs = ConstTensorSliceToEigenMatrix(in_y, y_batch_index);
auto output = TensorSliceToEigenMatrix(out, i);
if (lower) {
auto triangle = matrix.template triangularView<Eigen::Lower>();
if (adjoint) {
output.noalias() = triangle.adjoint().solve(rhs);
} else {
output.noalias() = triangle.solve(rhs);
}
} else {
auto triangle = matrix.template triangularView<Eigen::Upper>();
if (adjoint) {
output.noalias() = triangle.adjoint().solve(rhs);
} else {
output.noalias() = triangle.solve(rhs);
}
}
}
}
};
template <typename Device, typename Scalar>
struct LaunchBatchMatrixTriangularSolve;
template <typename Scalar>
struct LaunchBatchMatrixTriangularSolve<CPUDevice, Scalar> {
static void Launch(OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adjoint, bool lower,
const MatMulBCast& bcast, Tensor* out) {
// Number of matrix triangular solves i.e. size of the batch.
const int64 batch_size = bcast.output_batch_size();
const int64 cost_per_unit =
in_x.dim_size(1) * in_x.dim_size(1) * in_y.dim_size(2) / 2;
auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
using Matrix =
Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using ConstMatrixMap = Eigen::Map<const Matrix>;
using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
// Check diagonal before doing any solves.
auto matrix = ConstMatrixMap(in_x.flat<Scalar>().data(), in_x.dim_size(1),
in_x.dim_size(2));
const RealScalar min_abs_pivot = matrix.diagonal().cwiseAbs().minCoeff();
OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
errors::InvalidArgument("Input matrix is not invertible."));
Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
cost_per_unit,
[&in_x, &in_y, adjoint, lower, &bcast, out](int start, int limit) {
SequentialMatrixTriangularSolveKernel<Scalar>::Run(
in_x, in_y, lower, adjoint, bcast, out, start, limit);
});
}
};
template <typename Device, typename Scalar>
class BaseMatrixTriangularSolveOp : public OpKernel {
public:
explicit BaseMatrixTriangularSolveOp(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
}
~BaseMatrixTriangularSolveOp() override {}
void Compute(OpKernelContext* ctx) override {
const Tensor& in0 = ctx->input(0);
const Tensor& in1 = ctx->input(1);
ValidateInputTensors(ctx, in0, in1);
MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
OP_REQUIRES(
ctx, bcast.IsValid(),
errors::InvalidArgument(
"In[0] and In[1] must have compatible batch dimensions: ",
in0.shape().DebugString(), " vs. ", in1.shape().DebugString()));
TensorShape out_shape = bcast.output_batch_shape();
auto batch_size = bcast.output_batch_size();
auto d0 = in0.dim_size(in0.dims() - 2);
auto d1 = in0.dim_size(in0.dims() - 1);
Tensor in0_reshaped;
OP_REQUIRES(
ctx,
in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})),
errors::Internal("Failed to reshape In[0] from ",
in0.shape().DebugString()));
auto d2 = in1.dim_size(in1.dims() - 2);
auto d3 = in1.dim_size(in1.dims() - 1);
Tensor in1_reshaped;
OP_REQUIRES(
ctx,
in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
errors::Internal("Failed to reshape In[1] from ",
in1.shape().DebugString()));
if (adjoint_) std::swap(d0, d1);
OP_REQUIRES(ctx, d1 == d2,
errors::InvalidArgument(
"In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ",
in0.shape().DebugString(), " ", in1.shape().DebugString(),
" ", lower_, " ", adjoint_));
out_shape.AddDim(d0);
out_shape.AddDim(d3);
Tensor* out = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
if (out->NumElements() == 0) {
return;
}
Tensor out_reshaped;
OP_REQUIRES(ctx,
out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})),
errors::Internal("Failed to reshape output from ",
out->shape().DebugString()));
LaunchBatchMatrixTriangularSolve<Device, Scalar>::Launch(
ctx, in0_reshaped, in1_reshaped, adjoint_, lower_, bcast,
&out_reshaped);
}
private:
virtual void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
const Tensor& in1) = 0;
bool lower_;
bool adjoint_;
};
template <class Device, class Scalar>
class MatrixTriangularSolveOp
: public BaseMatrixTriangularSolveOp<Device, Scalar> {
public:
explicit MatrixTriangularSolveOp(OpKernelConstruction* context)
: BaseMatrixTriangularSolveOp<Device, Scalar>(context) {}
~MatrixTriangularSolveOp() override {}
private:
void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
const Tensor& in1) override {
OP_REQUIRES(
ctx, in0.dims() >= 2,
errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims()));
OP_REQUIRES(
ctx, in1.dims() >= 2,
errors::InvalidArgument("In[0] ndims must be >= 2: ", in1.dims()));
}
};
#define REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU(TYPE) \
REGISTER_KERNEL_BUILDER(Name("MatrixTriangularSolve") \
.Device(DEVICE_CPU) \
.TypeConstraint<TYPE>("T"), \
MatrixTriangularSolveOp<CPUDevice, TYPE>); \
REGISTER_KERNEL_BUILDER(Name("BatchMatrixTriangularSolve") \
.Device(DEVICE_CPU) \
.TypeConstraint<TYPE>("T"), \
MatrixTriangularSolveOp<CPUDevice, TYPE>);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename Scalar>
struct LaunchBatchMatrixTriangularSolve<GPUDevice, Scalar> {
static void Launch(OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adjoint, bool lower,
const MatMulBCast& bcast, Tensor* out) {
auto* stream = context->op_device_context()->stream();
const uint64 m = in_x.dim_size(1);
const uint64 n = out->dim_size(2);
// Do a memcpy when we don't need to broadcast.
if (!bcast.IsBroadcastingRequired() || out->shape() == in_y.shape()) {
auto src_device_mem = AsDeviceMemory(in_y.template flat<Scalar>().data());
auto dst_device_mem = AsDeviceMemory(out->template flat<Scalar>().data());
OP_REQUIRES(
context,
stream
->ThenMemcpyD2D(&dst_device_mem, src_device_mem,
bcast.y_batch_size() * m * n * sizeof(Scalar))
.ok(),
errors::Internal("MatrixTriangularSolveOp: failed to copy rhs "
"from device"));
} else {
std::vector<Scalar*> out_ptrs;
std::vector<const Scalar*> b_tmp_ptrs;
auto* b_base_ptr = in_y.template flat<Scalar>().data();
const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
for (int64 i = 0; i < bcast.y_batch_size(); ++i) {
b_tmp_ptrs.push_back(b_base_ptr + i * m * n);
}
for (int64 i = 0; i < bcast.output_batch_size(); ++i) {
auto src_device_mem = AsDeviceMemory(b_tmp_ptrs[b_batch_indices[i]]);
auto dst_device_mem =
AsDeviceMemory(out->template flat<Scalar>().data() + i * m * n);
OP_REQUIRES(
context,
stream
->ThenMemcpyD2D(&dst_device_mem, src_device_mem,
m * n * sizeof(Scalar))
.ok(),
errors::Internal("MatrixTriangularSolveOp: failed to copy rhs "
"from device"));
}
}
if (out->NumElements() == 0) {
return;
}
#if GOOGLE_CUDA
cublasSideMode_t side = CUBLAS_SIDE_RIGHT;
cublasFillMode_t uplo;
cublasOperation_t trans;
cublasDiagType_t diag = CUBLAS_DIAG_NON_UNIT;
// Cublas does
// output = matrix \ rhs
// where matrix, rhs and output are assumed to be in column major.
// We want the output to be in row-major, so we can compute
// output' = rhs' / matrix' (' stands for transpose)
// Upper/lower needs to be swapped for this.
uplo = lower ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
trans = adjoint ? CUBLAS_OP_C : CUBLAS_OP_N;
auto solver = absl::make_unique<CudaSolver>(context);
#elif TENSORFLOW_USE_ROCM
rocblas_side side = rocblas_side_right;
rocblas_fill uplo;
rocblas_operation trans;
rocblas_diagonal diag = rocblas_diagonal_non_unit;
// rocblas does
// output = matrix \ rhs
// where matrix, rhs and output are assumed to be in column major.
// We want the output to be in row-major, so we can compute
// output' = rhs' / matrix' (' stands for transpose)
// Upper/lower needs to be swapped for this.
uplo = lower ? rocblas_fill_upper : rocblas_fill_upper;
trans = adjoint ? rocblas_operation_conjugate_transpose
: rocblas_operation_none;
auto solver = absl::make_unique<ROCmSolver>(context);
#endif
const uint64 leading_dim_matrix = m;
const uint64 leading_dim_output = n;
const uint64 colmajor_rows = n;
const uint64 colmajor_cols = m;
const int64 batch_size = bcast.output_batch_size();
std::vector<const Scalar*> a_ptrs;
std::vector<Scalar*> out_ptrs;
std::vector<const Scalar*> a_tmp_ptrs;
a_ptrs.reserve(batch_size);
out_ptrs.reserve(batch_size);
a_tmp_ptrs.reserve(bcast.x_batch_size());
auto* a_base_ptr = in_x.template flat<Scalar>().data();
auto* out_base_ptr = out->template flat<Scalar>().data();
if (!bcast.IsBroadcastingRequired()) {
for (int64 i = 0; i < batch_size; ++i) {
a_ptrs.push_back(a_base_ptr + i * m * m);
out_ptrs.push_back(out_base_ptr + i * m * n);
}
} else {
const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
for (int64 i = 0; i < bcast.x_batch_size(); ++i) {
a_tmp_ptrs.push_back(a_base_ptr + i * m * m);
}
for (int64 i = 0; i < batch_size; ++i) {
a_ptrs.push_back(a_tmp_ptrs[a_batch_indices[i]]);
out_ptrs.push_back(out_base_ptr + i * m * n);
}
}
typedef Scalar Coefficient;
const Scalar alpha = Scalar(1.0);
#if GOOGLE_CUDA
// TODO(b/146763573): Consider using Trsv here when the right hand side is
// a vector. This will require an explicit transpose since Trsv assumes
// CUBLAS_SIDE_LEFT.
if (batch_size == 1) {
OP_REQUIRES_OK(
context,
solver->Trsm(side, uplo, trans, diag, colmajor_rows, colmajor_cols,
&alpha, a_ptrs[0], leading_dim_matrix /*lda*/,
out_ptrs[0], leading_dim_output /*ldb*/));
} else {
// Heuristic for choosing between batched interface vs. non-batched
// interface. This is inspired by matrix_solve_op and can probably be
// tuned.
// TODO(b/146763573): Tune this heuristic.
const int kMaxMatrixSizeToBatchSizeRatio = 128;
const bool use_batched_solver =
m <= kMaxMatrixSizeToBatchSizeRatio * batch_size;
if (use_batched_solver) {
OP_REQUIRES_OK(
context, solver->TrsmBatched(
side, uplo, trans, diag, colmajor_rows, colmajor_cols,
&alpha, &a_ptrs[0], leading_dim_matrix /*lda*/,
&out_ptrs[0], leading_dim_output /*ldb*/, batch_size));
} else {
for (int batch = 0; batch < batch_size; ++batch) {
OP_REQUIRES_OK(
context, solver->Trsm(side, uplo, trans, diag, colmajor_rows,
colmajor_cols, &alpha, a_ptrs[batch],
leading_dim_matrix /*lda*/, out_ptrs[batch],
leading_dim_output /*ldb*/));
}
}
}
#elif TENSORFLOW_USE_ROCM
for (int batch = 0; batch < batch_size; ++batch) {
OP_REQUIRES_OK(
context,
solver->Trsm(side, uplo, trans, diag, colmajor_rows, colmajor_cols,
&alpha, a_ptrs[batch], leading_dim_matrix /*lda*/,
out_ptrs[batch], leading_dim_output /*ldb*/));
}
#endif
}
};
#define REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU(TYPE) \
REGISTER_KERNEL_BUILDER(Name("MatrixTriangularSolve") \
.Device(DEVICE_GPU) \
.TypeConstraint<TYPE>("T"), \
MatrixTriangularSolveOp<GPUDevice, TYPE>); \
REGISTER_KERNEL_BUILDER(Name("BatchMatrixTriangularSolve") \
.Device(DEVICE_GPU) \
.TypeConstraint<TYPE>("T"), \
MatrixTriangularSolveOp<GPUDevice, TYPE>);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_

View File

@ -0,0 +1,32 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/matrix_triangular_solve_op_impl.h"
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#endif // GOOGLE_CUDA
namespace tensorflow {
TF_CALL_float(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU);
TF_CALL_double(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TF_CALL_float(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU);
TF_CALL_double(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow

View File

@ -0,0 +1,165 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/broadcast_to_op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
namespace {
Node* BroadcastTo(Graph* g, Node* input, Node* shape) {
Node* ret;
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BroadcastTo")
.Input(input)
.Input(shape)
.Attr("Tidx", DT_INT64)
.Finalize(g, &ret));
return ret;
}
Node* MatrixTriangularSolve(Graph* g, Node* in0, Node* in1, bool adjoint) {
Node* ret;
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatrixTriangularSolve")
.Input(in0)
.Input(in1)
.Attr("lower", true)
.Attr("adjoint", adjoint)
.Finalize(g, &ret));
return ret;
}
template <typename T>
static Graph* MatrixTriangularSolveWithBroadcast(int64 b0, int64 b1, int64 m,
int64 n, bool manual_broadcast,
DataType type) {
Graph* g = new Graph(OpRegistry::Global());
Tensor in0(type, TensorShape({b0, m, m}));
// Set diagonal to non-zero to guarantee invertibility.
in0.flat<T>().setRandom();
auto matrix = Eigen::Map<
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
in0.flat<T>().data(), in0.dim_size(1), in0.dim_size(2));
matrix.diagonal() =
(matrix.diagonal().cwiseAbs().array() + static_cast<T>(0.5));
Tensor in1(type, TensorShape({b1, m, n}));
in1.flat<T>().setRandom();
Tensor broadcasted_in0_shape(DT_INT64, TensorShape({3}));
Tensor broadcasted_in1_shape(DT_INT64, TensorShape({3}));
Node* in0_node = nullptr;
Node* in1_node = nullptr;
if (manual_broadcast) {
auto vec0 = broadcasted_in0_shape.vec<int64>();
auto vec1 = broadcasted_in1_shape.vec<int64>();
for (int i = 0; i < 3; ++i) {
vec0(i) = (i == 0 ? std::max(b0, b1) : in0.shape().dim_size(i));
vec1(i) = (i == 0 ? std::max(b0, b1) : in1.shape().dim_size(i));
}
in0_node = BroadcastTo(g, test::graph::Constant(g, in0),
test::graph::Constant(g, broadcasted_in0_shape));
in1_node = BroadcastTo(g, test::graph::Constant(g, in1),
test::graph::Constant(g, broadcasted_in1_shape));
} else {
in0_node = test::graph::Constant(g, in0);
in1_node = test::graph::Constant(g, in1);
}
MatrixTriangularSolve(g, in0_node, in1_node, false);
return g;
}
// Macro arguments names: --------------------------------------------------- //
// B1: batch size of LHS
// B2: batch size of RHS
// M: inner dimensions of LHS and RHS, outer dimension of LHS
// N: outer dimension of RHS
// MB: boolean indicating whether to use manual broadcasting
// T: C++ type of scalars (e.g. float, std::complex)
// TT: TensorFlow type of scalars (e.g. DT_FLOAT, DT_COMPLEX128
// D: Device (e.g. cpu, gpu)
#define BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, T, TT, D) \
static void \
BM_MatrixTriangularSolve##_##B1##_##B2##_##M##_##N##_##MB##_##TT##_##D( \
int iters) { \
testing::UseRealTime(); \
testing::ItemsProcessed(static_cast<int64>(iters) * std::max(B1, B2) * M * \
M * N * 2); \
test::Benchmark( \
#D, MatrixTriangularSolveWithBroadcast<T>(B1, B2, M, N, MB, TT)) \
.Run(iters); \
} \
BENCHMARK( \
BM_MatrixTriangularSolve##_##B1##_##B2##_##M##_##N##_##MB##_##TT##_##D);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define BM_MatrixTriangularSolve(B1, B2, M, N, MB) \
BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, float, DT_FLOAT, cpu); \
BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, double, DT_DOUBLE, cpu); \
BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, float, DT_FLOAT, gpu); \
BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, double, DT_DOUBLE, gpu);
#else
#define BM_MatrixTriangularSolve(B1, B2, M, N, MB) \
BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, float, DT_FLOAT, cpu); \
BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, double, DT_DOUBLE, cpu);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Square matrix triangular solve.
BM_MatrixTriangularSolve(32, 32, 512, 512, true);
BM_MatrixTriangularSolve(32, 32, 512, 512, false);
BM_MatrixTriangularSolve(1, 32, 512, 512, true);
BM_MatrixTriangularSolve(1, 32, 512, 512, false);
BM_MatrixTriangularSolve(32, 1, 512, 512, true);
BM_MatrixTriangularSolve(32, 1, 512, 512, false);
BM_MatrixTriangularSolve(128, 128, 512, 512, true);
BM_MatrixTriangularSolve(128, 128, 512, 512, false);
BM_MatrixTriangularSolve(1, 128, 512, 512, true);
BM_MatrixTriangularSolve(1, 128, 512, 512, false);
BM_MatrixTriangularSolve(128, 1, 512, 512, true);
BM_MatrixTriangularSolve(128, 1, 512, 512, false);
BM_MatrixTriangularSolve(1, 128, 1024, 1024, true);
BM_MatrixTriangularSolve(1, 128, 1024, 1024, false);
BM_MatrixTriangularSolve(128, 1, 1024, 1024, true);
BM_MatrixTriangularSolve(128, 1, 1024, 1024, false);
// Matrix-vector triangular solve.
BM_MatrixTriangularSolve(1, 128, 200, 1, true);
BM_MatrixTriangularSolve(1, 128, 200, 1, false);
BM_MatrixTriangularSolve(128, 1, 200, 1, true);
BM_MatrixTriangularSolve(128, 1, 200, 1, false);
// Matrix-vector triangular solve, large dimension.
BM_MatrixTriangularSolve(1, 128, 200, 10000, true);
BM_MatrixTriangularSolve(1, 128, 200, 10000, false);
BM_MatrixTriangularSolve(128, 1, 200, 10000, true);
BM_MatrixTriangularSolve(128, 1, 200, 10000, false);
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,245 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
#if TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/rocm_solvers.h"
#include <complex>
#include <unordered_map>
#include <vector>
#include "rocm/include/rocblas.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/platform/default/dso_loader.h"
#include "tensorflow/stream_executor/platform/port.h"
namespace tensorflow {
namespace {
using stream_executor::gpu::GpuExecutor;
using stream_executor::gpu::ScopedActivateExecutorContext;
using stream_executor::internal::CachedDsoLoader::GetRocblasDsoHandle;
namespace wrap {
#ifdef PLATFORM_GOOGLE
#define ROCBLAS_WRAP(__name) \
struct WrapperShim__##__name { \
static const char* kName; \
template <typename... Args> \
rocblas_status operator()(GpuExecutor* parent, Args... args) { \
ScopedActivateExecutorContext sac{parent}; \
return ::__name(args...); \
} \
} __name; \
const char* WrapperShim__##__name::kName = #__name;
#else
#define ROCBLAS_WRAP(__name) \
struct DynLoadShim__##__name { \
static const char* kName; \
using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \
static void* GetDsoHandle() { \
auto s = GetRocblasDsoHandle(); \
return s.ValueOrDie(); \
} \
static FuncPtrT LoadOrDie() { \
void* f; \
auto s = stream_executor::port::Env::Default()->GetSymbolFromLibrary( \
GetDsoHandle(), kName, &f); \
CHECK(s.ok()) << "could not find " << kName \
<< " in rocblas DSO; dlerror: " << s.error_message(); \
return reinterpret_cast<FuncPtrT>(f); \
} \
static FuncPtrT DynLoad() { \
static FuncPtrT f = LoadOrDie(); \
return f; \
} \
template <typename... Args> \
rocblas_status operator()(GpuExecutor* parent, Args... args) { \
ScopedActivateExecutorContext sac{parent}; \
return DynLoad()(args...); \
} \
} __name; \
const char* DynLoadShim__##__name::kName = #__name;
#endif
ROCBLAS_WRAP(rocblas_create_handle)
ROCBLAS_WRAP(rocblas_destroy_handle)
ROCBLAS_WRAP(rocblas_set_stream)
ROCBLAS_WRAP(rocblas_dtrsm)
ROCBLAS_WRAP(rocblas_strsm)
} // namespace wrap
struct ROCmSolverHandles {
explicit ROCmSolverHandles(GpuExecutor* parent, hipStream_t stream) {
parent_ = parent;
CHECK(wrap::rocblas_create_handle(parent_, &rocm_blas_handle) ==
rocblas_status_success)
<< "Failed to create rocBlas instance.";
CHECK(wrap::rocblas_set_stream(parent_, rocm_blas_handle, stream) ==
rocblas_status_success)
<< "Failed to set rocBlas stream.";
}
~ROCmSolverHandles() {
CHECK(wrap::rocblas_destroy_handle(parent_, rocm_blas_handle) ==
rocblas_status_success)
<< "Failed to destroy cuBlas instance.";
}
GpuExecutor* parent_;
rocblas_handle rocm_blas_handle;
};
using HandleMap =
std::unordered_map<hipStream_t, std::unique_ptr<ROCmSolverHandles>>;
// Returns a singleton map used for storing initialized handles for each unique
// gpu stream.
HandleMap* GetHandleMapSingleton() {
static HandleMap* cm = new HandleMap;
return cm;
}
static mutex handle_map_mutex(LINKER_INITIALIZED);
} // namespace
ROCmSolver::ROCmSolver(OpKernelContext* context) : context_(context) {
mutex_lock lock(handle_map_mutex);
GpuExecutor* gpu_executor = static_cast<GpuExecutor*>(
context->op_device_context()->stream()->parent()->implementation());
const hipStream_t* hip_stream_ptr = CHECK_NOTNULL(
reinterpret_cast<const hipStream_t*>(context->op_device_context()
->stream()
->implementation()
->GpuStreamMemberHack()));
hip_stream_ = *hip_stream_ptr;
HandleMap* handle_map = CHECK_NOTNULL(GetHandleMapSingleton());
auto it = handle_map->find(hip_stream_);
if (it == handle_map->end()) {
LOG(INFO) << "Creating ROCmSolver handles for stream " << hip_stream_;
// Previously unseen Gpu stream. Initialize a set of Gpu solver library
// handles for it.
std::unique_ptr<ROCmSolverHandles> new_handles(
new ROCmSolverHandles(gpu_executor, hip_stream_));
it = handle_map->insert(std::make_pair(hip_stream_, std::move(new_handles)))
.first;
}
rocm_blas_handle_ = it->second->rocm_blas_handle;
}
ROCmSolver::~ROCmSolver() {
for (auto tensor_ref : scratch_tensor_refs_) {
tensor_ref.Unref();
}
}
#define TF_RETURN_IF_ROCBLAS_ERROR(expr) \
do { \
auto status = (expr); \
if (TF_PREDICT_FALSE(status != rocblas_status_success)) { \
return errors::Internal(__FILE__, ":", __LINE__, \
": rocBlas call failed status = ", status); \
} \
} while (0)
// Macro that specializes a solver method for all 4 standard
// numeric types.
#define TF_CALL_LAPACK_TYPES(m) \
m(float, s) m(double, d) m(std::complex<float>, c) m(std::complex<double>, z)
#define TF_CALL_LAPACK_TYPES_NO_COMPLEX(m) m(float, s) m(double, d)
#define BLAS_SOLVER_FN(method, type_prefix) \
wrap::rocblas##_##type_prefix##method
// Allocates a temporary tensor. The ROCmSolver object maintains a
// TensorReference to the underlying Tensor to prevent it from being deallocated
// prematurely.
Status ROCmSolver::allocate_scoped_tensor(DataType type,
const TensorShape& shape,
Tensor* out_temp) {
const Status status = context_->allocate_temp(type, shape, out_temp);
if (status.ok()) {
scratch_tensor_refs_.emplace_back(*out_temp);
}
return status;
}
Status ROCmSolver::forward_input_or_allocate_scoped_tensor(
gtl::ArraySlice<int> candidate_input_indices, DataType type,
const TensorShape& shape, Tensor* out_temp) {
const Status status = context_->forward_input_or_allocate_temp(
candidate_input_indices, type, shape, out_temp);
if (status.ok()) {
scratch_tensor_refs_.emplace_back(*out_temp);
}
return status;
}
template <typename Scalar, typename SolverFnT>
static inline Status TrsmImpl(GpuExecutor* gpu_executor, SolverFnT solver,
rocblas_handle rocm_blas_handle,
rocblas_side side, rocblas_fill uplo,
rocblas_operation trans, rocblas_diagonal diag,
int m, int n,
const Scalar* alpha, /* host or device pointer */
const Scalar* A, int lda, Scalar* B, int ldb) {
mutex_lock lock(handle_map_mutex);
using ROCmScalar = typename ROCmComplexT<Scalar>::type;
TF_RETURN_IF_ROCBLAS_ERROR(solver(gpu_executor, rocm_blas_handle, side, uplo,
trans, diag, m, n,
reinterpret_cast<const ROCmScalar*>(alpha),
reinterpret_cast<const ROCmScalar*>(A), lda,
reinterpret_cast<ROCmScalar*>(B), ldb));
return Status::OK();
}
#define TRSM_INSTANCE(Scalar, type_prefix) \
template <> \
Status ROCmSolver::Trsm<Scalar>( \
rocblas_side side, rocblas_fill uplo, rocblas_operation trans, \
rocblas_diagonal diag, int m, int n, \
const Scalar* alpha, /* host or device pointer */ \
const Scalar* A, int lda, Scalar* B, int ldb) { \
GpuExecutor* gpu_executor = static_cast<GpuExecutor*>( \
context_->op_device_context()->stream()->parent()->implementation()); \
return TrsmImpl(gpu_executor, BLAS_SOLVER_FN(trsm, type_prefix), \
rocm_blas_handle_, side, uplo, trans, diag, m, n, alpha, \
A, lda, B, ldb); \
}
TF_CALL_LAPACK_TYPES_NO_COMPLEX(TRSM_INSTANCE);
} // namespace tensorflow
#endif // TENSORFLOW_USE_ROCM

View File

@ -0,0 +1,160 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
#ifndef TENSORFLOW_CORE_KERNELS_ROCM_SOLVERS_H_
#define TENSORFLOW_CORE_KERNELS_ROCM_SOLVERS_H_
// This header declares the class ROCmSolver, which contains wrappers of linear
// algebra solvers in the cuBlas and cuSolverDN libraries for use in TensorFlow
// kernels.
#if TENSORFLOW_USE_ROCM
#include <functional>
#include <vector>
#include "rocm/include/hip/hip_complex.h"
#include "rocm/include/rocblas.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/stream_executor/blas.h"
namespace tensorflow {
// Type traits to get ROCm complex types from std::complex<T>.
template <typename T>
struct ROCmComplexT {
typedef T type;
};
template <>
struct ROCmComplexT<std::complex<float>> {
typedef hipComplex type;
};
template <>
struct ROCmComplexT<std::complex<double>> {
typedef hipDoubleComplex type;
};
// Converts pointers of std::complex<> to pointers of
// cuComplex/cuDoubleComplex. No type conversion for non-complex types.
template <typename T>
inline const typename ROCmComplexT<T>::type* ROCmComplex(const T* p) {
return reinterpret_cast<const typename ROCmComplexT<T>::type*>(p);
}
template <typename T>
inline typename ROCmComplexT<T>::type* ROCmComplex(T* p) {
return reinterpret_cast<typename ROCmComplexT<T>::type*>(p);
}
template <typename Scalar>
class ScratchSpace;
class ROCmSolver {
public:
// This object stores a pointer to context, which must outlive it.
explicit ROCmSolver(OpKernelContext* context);
virtual ~ROCmSolver();
// Allocates a temporary tensor that will live for the duration of the
// ROCmSolver object.
Status allocate_scoped_tensor(DataType type, const TensorShape& shape,
Tensor* scoped_tensor);
Status forward_input_or_allocate_scoped_tensor(
gtl::ArraySlice<int> candidate_input_indices, DataType type,
const TensorShape& shape, Tensor* input_alias_or_new_scoped_tensor);
OpKernelContext* context() { return context_; }
template <typename Scalar>
Status Trsm(rocblas_side side, rocblas_fill uplo, rocblas_operation trans,
rocblas_diagonal diag, int m, int n, const Scalar* alpha,
const Scalar* A, int lda, Scalar* B, int ldb);
private:
OpKernelContext* context_; // not owned.
hipStream_t hip_stream_;
rocblas_handle rocm_blas_handle_;
std::vector<TensorReference> scratch_tensor_refs_;
TF_DISALLOW_COPY_AND_ASSIGN(ROCmSolver);
};
// Helper class to allocate scratch memory and keep track of debug info.
// Mostly a thin wrapper around Tensor & allocate_temp.
template <typename Scalar>
class ScratchSpace {
public:
ScratchSpace(OpKernelContext* context, int64 size, bool on_host)
: ScratchSpace(context, TensorShape({size}), "", on_host) {}
ScratchSpace(OpKernelContext* context, int64 size, const string& debug_info,
bool on_host)
: ScratchSpace(context, TensorShape({size}), debug_info, on_host) {}
ScratchSpace(OpKernelContext* context, const TensorShape& shape,
const string& debug_info, bool on_host)
: context_(context), debug_info_(debug_info), on_host_(on_host) {
AllocatorAttributes alloc_attr;
if (on_host) {
// Allocate pinned memory on the host to avoid unnecessary
// synchronization.
alloc_attr.set_on_host(true);
alloc_attr.set_gpu_compatible(true);
}
TF_CHECK_OK(context->allocate_temp(DataTypeToEnum<Scalar>::value, shape,
&scratch_tensor_, alloc_attr));
}
virtual ~ScratchSpace() {}
Scalar* mutable_data() {
return scratch_tensor_.template flat<Scalar>().data();
}
const Scalar* data() const {
return scratch_tensor_.template flat<Scalar>().data();
}
Scalar& operator()(int64 i) {
return scratch_tensor_.template flat<Scalar>()(i);
}
const Scalar& operator()(int64 i) const {
return scratch_tensor_.template flat<Scalar>()(i);
}
int64 bytes() const { return scratch_tensor_.TotalBytes(); }
int64 size() const { return scratch_tensor_.NumElements(); }
const string& debug_info() const { return debug_info_; }
Tensor& tensor() { return scratch_tensor_; }
const Tensor& tensor() const { return scratch_tensor_; }
// Returns true if this ScratchSpace is in host memory.
bool on_host() const { return on_host_; }
protected:
OpKernelContext* context() const { return context_; }
private:
OpKernelContext* context_; // not owned
const string debug_info_;
const bool on_host_;
Tensor scratch_tensor_;
};
} // namespace tensorflow
#endif // TENSORFLOW_USE_ROCM
#endif // TENSORFLOW_CORE_KERNELS_ROCM_SOLVERS_H_

View File

@ -84,6 +84,34 @@ Status MatrixSolveShapeFn(InferenceContext* c, bool square) {
return Status::OK(); return Status::OK();
} }
// The first input is [...,M,M] and second input is [...,M,N].
// Output is [...,M,N].
Status MatrixTriangularSolveShapeFn(InferenceContext* c) {
ShapeHandle lhs;
ShapeHandle rhs;
TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &lhs));
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));
ShapeHandle lhs_batch_shape;
ShapeHandle rhs_batch_shape;
ShapeHandle output_batch_shape;
// Make the common batch subshape.
TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
c, lhs_batch_shape, rhs_batch_shape, true, &output_batch_shape));
DimensionHandle m;
// lhs and rhs have the same value for m to be compatible.
TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, -1), c->Dim(rhs, -2), &m));
ShapeHandle out;
// Build final shape (batch_shape + m + n) in <out>.
TF_RETURN_IF_ERROR(
c->Concatenate(output_batch_shape, c->Matrix(m, c->Dim(rhs, -1)), &out));
c->set_output(0, out);
return Status::OK();
}
// Input is [...,N,N]. Outputs are: // Input is [...,N,N]. Outputs are:
// [...,N];[0], if compute_v is false, // [...,N];[0], if compute_v is false,
// [...,N];[...,N,N], if compute_v is true. // [...,N];[...,N,N], if compute_v is true.
@ -426,7 +454,7 @@ REGISTER_OP("MatrixTriangularSolve")
.Attr("adjoint: bool = False") .Attr("adjoint: bool = False")
.Attr("T: {double, float, half, complex64, complex128}") .Attr("T: {double, float, half, complex64, complex128}")
.SetShapeFn([](InferenceContext* c) { .SetShapeFn([](InferenceContext* c) {
return MatrixSolveShapeFn(c, true /* square (*/); return MatrixTriangularSolveShapeFn(c);
}); });
REGISTER_OP("MatrixSolveLs") REGISTER_OP("MatrixSolveLs")

View File

@ -122,14 +122,12 @@ TEST(LinalgOpsTest, SelfAdjointEigV2_ShapeFn) {
"[d0_0,d0_1,d0_2,d0_3|d0_4];[d0_0,d0_1,d0_2,d0_3|d0_4,d0_3|d0_4]"); "[d0_0,d0_1,d0_2,d0_3|d0_4];[d0_0,d0_1,d0_2,d0_3|d0_4,d0_3|d0_4]");
} }
TEST(LinalgOpsTest, SquareMatrixSolve_ShapeFn) { TEST(LinalgOpsTest, MatrixSolve_ShapeFn) {
for (const char* op_name : {"MatrixSolve", "MatrixTriangularSolve"}) { ShapeInferenceTestOp op("MatrixSolve");
ShapeInferenceTestOp op(op_name);
INFER_OK(op, "?;?", "?"); INFER_OK(op, "?;?", "?");
INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1];?"); INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1];?");
INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,2];?"); INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,2];?");
INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[5,?,?];[6]");
"[5,?,?];[6]");
INFER_ERROR("Shapes must be equal rank, but are 0 and 1", op, INFER_ERROR("Shapes must be equal rank, but are 0 and 1", op,
"[5,?];[6,?,?]"); "[5,?];[6,?,?]");
@ -149,7 +147,29 @@ TEST(LinalgOpsTest, SquareMatrixSolve_ShapeFn) {
INFER_OK(op, "[10,?,?,1];[?,20,1,?]", "[d0_0,d1_1,d0_3|d1_2,d1_3]"); INFER_OK(op, "[10,?,?,1];[?,20,1,?]", "[d0_0,d1_1,d0_3|d1_2,d1_3]");
INFER_OK(op, "[10,?,1,1];[?,20,?,?]", "[d0_0,d1_1,d0_2,d1_3]"); INFER_OK(op, "[10,?,1,1];[?,20,?,?]", "[d0_0,d1_1,d0_2,d1_3]");
INFER_OK(op, "[10,?,1,1];[?,20,1,?]", "[d0_0,d1_1,d0_2|d0_3|d1_2,d1_3]"); INFER_OK(op, "[10,?,1,1];[?,20,1,?]", "[d0_0,d1_1,d0_2|d0_3|d1_2,d1_3]");
} }
TEST(LinalgOpsTest, MatrixTriangularSolve_ShapeFn) {
ShapeInferenceTestOp op("MatrixTriangularSolve");
INFER_OK(op, "?;?", "?");
INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1];?");
INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,2];?");
INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[5,?,?];[6]");
// Inputs are [...,M,M] and [...,M,K]. Output is [...,M,K].
// First test where ... is empty.
INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]");
INFER_OK(op, "[?,?];[1,?]", "[d1_0,d1_1]");
INFER_OK(op, "[1,?];[1,?]", "[d0_0|d1_0,d1_1]");
INFER_OK(op, "[?,1];[1,?]", "[d0_1|d1_0,d1_1]");
INFER_OK(op, "[1,1];[?,?]", "[d0_0,d1_1]");
INFER_OK(op, "[1,1];[1,?]", "[d0_0|d0_1|d1_0,d1_1]");
// Test with ... being 2-d.
INFER_OK(op, "[10,?,?,?];[?,20,1,?]", "[d0_0,d1_1,d1_2,d1_3]");
INFER_OK(op, "[10,?,1,?];[?,20,1,?]", "[d0_0,d1_1,d0_2|d1_2,d1_3]");
INFER_OK(op, "[10,?,?,1];[?,20,1,?]", "[d0_0,d1_1,d0_3|d1_2,d1_3]");
INFER_OK(op, "[10,?,1,1];[?,20,?,?]", "[d0_0,d1_1,d0_2,d1_3]");
INFER_OK(op, "[10,?,1,1];[?,20,1,?]", "[d0_0,d1_1,d0_2|d0_3|d1_2,d1_3]");
} }
TEST(LinalgOpsTest, MatrixSolveLs_ShapeFn) { TEST(LinalgOpsTest, MatrixSolveLs_ShapeFn) {

View File

@ -755,8 +755,9 @@ cuda_py_test(
cuda_py_test( cuda_py_test(
name = "matrix_triangular_solve_op_test", name = "matrix_triangular_solve_op_test",
size = "small", size = "medium",
srcs = ["matrix_triangular_solve_op_test.py"], srcs = ["matrix_triangular_solve_op_test.py"],
shard_count = 3,
deps = [ deps = [
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python:linalg_ops", "//tensorflow/python:linalg_ops",

View File

@ -20,7 +20,6 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import linalg_ops
@ -68,29 +67,30 @@ class MatrixTriangularSolveOpTest(test.TestCase):
else: else:
a_np = a a_np = a
if adjoint: if adjoint:
a_np = np.conj(np.transpose(a_np)) axes = list(range(len(a_np.shape)))
axes[-2] = -1
axes[-1] = -2
a_np = np.conj(np.transpose(a_np, axes=axes))
if batch_dims is not None: if batch_dims is not None:
a = np.tile(a, batch_dims + [1, 1]) a = np.tile(a, batch_dims + [1, 1])
a_np = np.tile(a_np, batch_dims + [1, 1]) a_np = np.tile(a_np, batch_dims + [1, 1])
b = np.tile(b, batch_dims + [1, 1]) b = np.tile(b, batch_dims + [1, 1])
with self.cached_session(use_gpu=True) as sess: def broadcast(a, b):
b1 = b + np.zeros(a.shape[:-2] + (1, 1), dtype=b.dtype)
return a, b1
a_tf = a
b_tf = b
if use_placeholder: if use_placeholder:
a_tf = array_ops.placeholder(a.dtype) a_tf = array_ops.placeholder_with_default(a_tf, shape=None)
b_tf = array_ops.placeholder(b.dtype) b_tf = array_ops.placeholder_with_default(b_tf, shape=None)
tf_ans = linalg_ops.matrix_triangular_solve(
a_tf, b_tf, lower=lower, adjoint=adjoint)
tf_val = sess.run(tf_ans, feed_dict={a_tf: a, b_tf: b})
np_ans = np.linalg.solve(a_np, b)
else:
a_tf = constant_op.constant(a)
b_tf = constant_op.constant(b)
tf_ans = linalg_ops.matrix_triangular_solve( tf_ans = linalg_ops.matrix_triangular_solve(
a_tf, b_tf, lower=lower, adjoint=adjoint) a_tf, b_tf, lower=lower, adjoint=adjoint)
tf_val = self.evaluate(tf_ans) tf_val = self.evaluate(tf_ans)
a_np, b = broadcast(a_np, b)
np_ans = np.linalg.solve(a_np, b) np_ans = np.linalg.solve(a_np, b)
self.assertEqual(np_ans.shape, tf_ans.get_shape())
self.assertEqual(np_ans.shape, tf_val.shape) self.assertEqual(np_ans.shape, tf_val.shape)
self.assertAllClose(np_ans, tf_val) self.assertAllClose(np_ans, tf_val)
@ -136,6 +136,48 @@ class MatrixTriangularSolveOpTest(test.TestCase):
# Batch of 3x2x2x2 matrices, 3x2x2x3 right-hand sides. # Batch of 3x2x2x2 matrices, 3x2x2x3 right-hand sides.
self._verifySolveAllWaysReal(matrix, rhs, batch_dims=[3, 2]) self._verifySolveAllWaysReal(matrix, rhs, batch_dims=[3, 2])
@test_util.run_deprecated_v1
def testSolveBatchBroadcast(self):
# 2 x 2 x 2
matrix = np.array([[[1., 0.], [3., 4.]], [[1., 0.], [2., 1.]]])
# 2 x 3
rhs = np.array([[1., 0., 1.], [0., 1., 1.]])
# 2 x 2 x 3
self._verifySolveAllWaysReal(matrix, rhs)
# 2 x 2 x 2
matrix2 = np.array([[[1., 0.], [3., 4.]], [[2., 0.], [1., 6.3]]])
# 1 x 2 x 3
rhs = np.array([[[1., 0., 1.], [0., 1., 1.]]])
# 2 x 2 x 3
self._verifySolveAllWaysReal(matrix2, rhs)
@test_util.run_deprecated_v1
def testSolveBatchBroadcastLargerBatches(self):
# 1 x 10 x 10
matrix = np.random.uniform(low=1, high=2., size=[1, 10, 10])
# 10 x 1
rhs = np.random.uniform(size=[10, 1])
# 1 x 10 x 1
self._verifySolveAllWaysReal(matrix, rhs)
# 2 x 10 x 10
matrix = np.random.uniform(low=1, high=2., size=[2, 10, 10])
# 10 x 1
rhs = np.random.uniform(size=[10, 1])
# 2 x 10 x 1
self._verifySolveAllWaysReal(matrix, rhs)
# 2 x 257 x 257
matrix = np.random.uniform(low=1, high=2., size=[2, 257, 257])
# Also ensure the matrix is well conditioned by making it diagonally
# dominant.
np.fill_diagonal(matrix[0, ...], 257 * 2)
np.fill_diagonal(matrix[1, ...], 257 * 2)
# 257 x 1
rhs = np.random.uniform(size=[257, 1])
# 2 x 257 x 1
self._verifySolveAllWaysReal(matrix, rhs)
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testSolveBatchComplex(self): def testSolveBatchComplex(self):
if test.is_built_with_rocm(): if test.is_built_with_rocm():

View File

@ -607,6 +607,7 @@ def _MatrixSolveLsGrad(op, grad):
def _MatrixTriangularSolveGrad(op, grad): def _MatrixTriangularSolveGrad(op, grad):
"""Gradient for MatrixTriangularSolve.""" """Gradient for MatrixTriangularSolve."""
a = op.inputs[0] a = op.inputs[0]
b = op.inputs[1]
adjoint_a = op.get_attr("adjoint") adjoint_a = op.get_attr("adjoint")
lower_a = op.get_attr("lower") lower_a = op.get_attr("lower")
c = op.outputs[0] c = op.outputs[0]
@ -620,7 +621,16 @@ def _MatrixTriangularSolveGrad(op, grad):
grad_a = array_ops.matrix_band_part(grad_a, -1, 0) grad_a = array_ops.matrix_band_part(grad_a, -1, 0)
else: else:
grad_a = array_ops.matrix_band_part(grad_a, 0, -1) grad_a = array_ops.matrix_band_part(grad_a, 0, -1)
return (grad_a, grad_b) # If the static batch shapes are equal, we don't need to unbroadcast.
if (a.shape.is_fully_defined() and b.shape.is_fully_defined() and
a.shape[:-2] == b.shape[:-2]):
return grad_a, grad_b
a_shape = array_ops.shape(a)
b_shape = array_ops.shape(b)
ra, rb = array_ops.broadcast_gradient_args(a_shape[:-2], b_shape[:-2])
grad_a = array_ops.reshape(math_ops.reduce_sum(grad_a, axis=ra), a_shape)
grad_b = array_ops.reshape(math_ops.reduce_sum(grad_b, axis=rb), b_shape)
return grad_a, grad_b
@ops.RegisterGradient("SelfAdjointEigV2") @ops.RegisterGradient("SelfAdjointEigV2")

View File

@ -79,6 +79,68 @@ def _RegularizedGramianCholesky(matrix, l2_regularizer, first_kind):
return gen_linalg_ops.cholesky(gramian) return gen_linalg_ops.cholesky(gramian)
@tf_export(
'linalg.triangular_solve',
v1=['linalg.triangular_solve', 'matrix_triangular_solve'])
def matrix_triangular_solve(matrix, rhs, lower=True, adjoint=False, name=None):
"""Solve systems of linear equations with upper or lower triangular matrices.
`matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form
square matrices. If `lower` is `True` then the strictly upper triangular part
of each inner-most matrix is assumed to be zero and not accessed. If `lower`
is `False` then the strictly lower triangular part of each inner-most matrix
is assumed to be zero and not accessed. `rhs` is a tensor of shape
`[..., M, N]`.
The output is a tensor of shape `[..., M, N]`. If `adjoint` is `True` then the
innermost matrices in output satisfy matrix equations `
sum_k matrix[..., i, k] * output[..., k, j] = rhs[..., i, j]`.
If `adjoint` is `False` then the
innermost matrices in output satisfy matrix equations
`sum_k adjoint(matrix[..., i, k]) * output[..., k, j] = rhs[..., i, j]`.
Example:
>>> a = tf.constant([[3, 0, 0, 0],
... [2, 1, 0, 0],
... [1, 0, 1, 0],
... [1, 1, 1, 1]], dtype=tf.float32)
>>> b = tf.constant([[4], [2], [4], [2]], dtype=tf.float32)
>>> x = tf.linalg.triangular_solve(a, b, lower=True)
>>> x
<tf.Tensor: shape=(4, 1), dtype=float32, numpy=
array([[ 1.3333334 ],
[-0.66666675],
[ 2.6666665 ],
[-1.3333331 ]], dtype=float32)>
>>> tf.matmul(a, x)
<tf.Tensor: shape=(4, 1), dtype=float32, numpy=
array([[4.],
[2.],
[4.],
[2.]], dtype=float32)>
Args:
matrix: A `Tensor`. Must be one of the following types: `float64`,
`float32`, `half`, `complex64`, `complex128`. Shape is `[..., M, M]`.
rhs: A `Tensor`. Must have the same type as `matrix`. Shape is `[..., M,
N]`.
lower: An optional `bool`. Defaults to `True`. Boolean indicating whether
the innermost matrices in matrix are lower or upper triangular.
adjoint: An optional `bool`. Defaults to `False`. Boolean indicating whether
to solve with matrix or its (block-wise) adjoint.
name: A name for the operation (optional).
Returns:
A `Tensor`. Has the same type as matrix, and shape is `[..., M, N]`.
"""
with ops.name_scope(name, 'triangular_solve', [matrix, rhs]):
return gen_linalg_ops.matrix_triangular_solve(
matrix, rhs, lower=lower, adjoint=adjoint)
@tf_export( @tf_export(
'linalg.cholesky_solve', v1=['linalg.cholesky_solve', 'cholesky_solve']) 'linalg.cholesky_solve', v1=['linalg.cholesky_solve', 'cholesky_solve'])
@deprecation.deprecated_endpoints('cholesky_solve') @deprecation.deprecated_endpoints('cholesky_solve')

View File

@ -2872,9 +2872,9 @@ def _convert_log_matrix_determinant(pfor_input):
@RegisterPFor("MatrixTriangularSolve") @RegisterPFor("MatrixTriangularSolve")
def _convert_matrix_triangular_solve(pfor_input): def _convert_matrix_triangular_solve(pfor_input):
pfor_input.stack_inputs() pfor_input.expanddim_inputs_for_broadcast()
matrix = pfor_input.stacked_input(0) matrix = pfor_input.input(0)[0]
rhs = pfor_input.stacked_input(1) rhs = pfor_input.input(1)[0]
lower = pfor_input.get_attr("lower") lower = pfor_input.get_attr("lower")
adjoint = pfor_input.get_attr("adjoint") adjoint = pfor_input.get_attr("adjoint")
output = linalg_ops.matrix_triangular_solve( output = linalg_ops.matrix_triangular_solve(