Add Broadcasted Matrix Triangular Solve.

Add NumPy-style broadcasting in the batch dimensions for the tf.linalg.triangular_solve op. The last two dimensions of both operands constitute the matrix dimensions. The dimensions beyond these are broadcast to form a common output shape, following the standard NumPy broadcasting rules (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
Note: This implementation differs from NumPy's behavior in that vectors (rank-1 Tensors) are not pr...
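A quick sketch of the broadcasting behavior described above (illustrative only; shapes and values are made up, and this assumes an eager TF 2.x session):

```python
import numpy as np
import tensorflow as tf

# matrix has batch shape [2]; rhs has an empty batch shape -> output batch [2].
matrix = tf.constant([[[2., 0.], [3., 4.]],
                      [[1., 0.], [2., 1.]]])      # shape [2, 2, 2]
rhs = tf.constant([[1., 0., 1.], [0., 1., 1.]])   # shape [2, 3]

x = tf.linalg.triangular_solve(matrix, rhs, lower=True)
print(x.shape)  # (2, 2, 3)

# Same result as explicitly tiling rhs to the common batch shape.
rhs_tiled = tf.broadcast_to(rhs, [2, 2, 3])
x_tiled = tf.linalg.triangular_solve(matrix, rhs_tiled, lower=True)
np.testing.assert_allclose(x.numpy(), x_tiled.numpy())
```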
PiperOrigin-RevId: 289978628
Change-Id: I66e41e292e57e6df8111745cbe47ccffacb53edc
Smit Hinsu 2020-01-15 18:21:45 -08:00 committed by TensorFlower Gardener
parent a5218435ec
commit c8e8ba577e
16 changed files with 316 additions and 1019 deletions


@@ -44,17 +44,15 @@ square matrices. If `lower` is `True` then the strictly upper triangular part
of each inner-most matrix is assumed to be zero and not accessed.
If `lower` is `False` then the strictly lower triangular part of each inner-most
matrix is assumed to be zero and not accessed.
`rhs` is a tensor of shape `[..., M, N]`.
`rhs` is a tensor of shape `[..., M, K]`.
The output is a tensor of shape `[..., M, N]`. If `adjoint` is
The output is a tensor of shape `[..., M, K]`. If `adjoint` is
`False` then the innermost matrices in `output` satisfy matrix equations
`matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`.
If `adjoint` is `True` then the innermost matrices in
`output` satisfy matrix equations
`adjoint(matrix[..., i, k]) * output[..., k, j] = rhs[..., i, j]`.
Note, the batch shapes for the inputs only need to be broadcastable.
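To make the two equations concrete, a minimal NumPy check of the solve semantics (an illustrative sketch, not part of the op documentation):

```python
import numpy as np

matrix = np.array([[2., 0.],
                   [3., 4.]])        # lower triangular
rhs = np.array([[2.], [11.]])

# adjoint = False: matrix @ output = rhs
out = np.linalg.solve(matrix, rhs)
assert np.allclose(matrix @ out, rhs)

# adjoint = True: adjoint(matrix) @ output = rhs
out_adj = np.linalg.solve(matrix.conj().T, rhs)
assert np.allclose(matrix.conj().T @ out_adj, rhs)
```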
Example:
```python


@@ -1,4 +1,10 @@
op {
graph_op_name: "MatrixTriangularSolve"
visibility: HIDDEN
endpoint {
name: "linalg.triangular_solve"
}
endpoint {
name: "matrix_triangular_solve"
deprecation_version: 2
}
}
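In Python these endpoints correspond to the stable `tf.linalg.triangular_solve` and the deprecated v1 alias (a small usage sketch, assuming TF 2.x with `tf.compat.v1` available):

```python
import tensorflow as tf

a = tf.constant([[2., 0.], [3., 4.]])
b = tf.constant([[2.], [11.]])

x = tf.linalg.triangular_solve(a, b, lower=True)               # stable endpoint
x_v1 = tf.compat.v1.matrix_triangular_solve(a, b, lower=True)  # deprecated v1 alias
```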


@@ -3588,14 +3588,10 @@ tf_kernel_library(
tf_kernel_library(
name = "matrix_triangular_solve_op",
hdrs = ["matrix_triangular_solve_op_impl.h"],
prefix = "matrix_triangular_solve_op",
deps = LINALG_DEPS + if_cuda([
"//tensorflow/core/platform/default/build_config:cublas_plugin",
]) + [
":fill_functor",
"//tensorflow/core:stream_executor",
],
]),
)
tf_kernel_library(
@@ -4183,25 +4179,6 @@ tf_cuda_cc_test(
],
)
tf_cuda_cc_test(
name = "matrix_triangular_solve_op_test",
size = "small",
srcs = ["matrix_triangular_solve_op_test.cc"],
deps = [
":broadcast_to_op",
":matrix_triangular_solve_op",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_cuda_cc_test(
name = "scan_ops_test",
size = "small",


@@ -900,106 +900,6 @@ static inline Status MatInvBatchedImpl(
TF_CALL_LAPACK_TYPES(MATINV_BATCHED_INSTANCE);
template <typename Scalar, typename SolverFnT>
static inline Status TrsmImpl(SolverFnT solver, cublasHandle_t cublas_handle,
cublasSideMode_t side, cublasFillMode_t uplo,
cublasOperation_t trans, cublasDiagType_t diag,
int m, int n,
const Scalar* alpha, /* host or device pointer */
const Scalar* A, int lda, Scalar* B, int ldb) {
mutex_lock lock(handle_map_mutex);
using CudaScalar = typename CUDAComplexT<Scalar>::type;
TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, side, uplo, trans, diag, m, n,
reinterpret_cast<const CudaScalar*>(alpha),
reinterpret_cast<const CudaScalar*>(A), lda,
reinterpret_cast<CudaScalar*>(B), ldb));
return Status::OK();
}
#define TRSM_INSTANCE(Scalar, type_prefix) \
template <> \
Status CudaSolver::Trsm<Scalar>( \
cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, \
cublasDiagType_t diag, int m, int n, \
const Scalar* alpha, /* host or device pointer */ \
const Scalar* A, int lda, Scalar* B, int ldb) { \
return TrsmImpl(BLAS_SOLVER_FN(trsm, type_prefix), cublas_handle_, side, \
uplo, trans, diag, m, n, alpha, A, lda, B, ldb); \
}
TF_CALL_LAPACK_TYPES(TRSM_INSTANCE);
template <typename Scalar, typename SolverFnT>
static inline Status TrsvImpl(SolverFnT solver, cublasHandle_t cublas_handle,
cublasFillMode_t uplo, cublasOperation_t trans,
cublasDiagType_t diag, int n, const Scalar* A,
int lda, Scalar* x, int incx) {
mutex_lock lock(handle_map_mutex);
using CudaScalar = typename CUDAComplexT<Scalar>::type;
TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, uplo, trans, diag, n,
reinterpret_cast<const CudaScalar*>(A), lda,
reinterpret_cast<CudaScalar*>(x), incx));
return Status::OK();
}
#define TRSV_INSTANCE(Scalar, type_prefix) \
template <> \
Status CudaSolver::Trsv<Scalar>( \
cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, \
int n, const Scalar* A, int lda, Scalar* x, int incx) { \
return TrsvImpl(BLAS_SOLVER_FN(trsv, type_prefix), cublas_handle_, uplo, \
trans, diag, n, A, lda, x, incx); \
}
TF_CALL_LAPACK_TYPES(TRSV_INSTANCE);
template <typename Scalar, typename SolverFnT>
static inline Status TrsmBatchedImpl(
SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
cublasHandle_t cublas_handle, cublasSideMode_t side, cublasFillMode_t uplo,
cublasOperation_t trans, cublasDiagType_t diag, int m, int n,
const Scalar* alpha, const Scalar* const host_a_dev_ptrs[], int lda,
Scalar* host_b_dev_ptrs[], int ldb, int batch_size) {
mutex_lock lock(handle_map_mutex);
using CudaScalar = typename CUDAComplexT<Scalar>::type;
ScratchSpace<uint8> dev_a_dev_ptrs =
cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
/* on_host */ false);
ScratchSpace<uint8> dev_b_dev_ptrs =
cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
/* on_host */ false);
if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
return errors::Internal("TrsmBatched: failed to copy pointers to device");
}
if (!CopyHostToDevice(context, dev_b_dev_ptrs.mutable_data() /* dest */,
host_b_dev_ptrs /* source */, dev_b_dev_ptrs.bytes())) {
return errors::Internal("TrsmBatched: failed to copy pointers to device");
}
TF_RETURN_IF_CUBLAS_ERROR(
solver(cublas_handle, side, uplo, trans, diag, m, n,
reinterpret_cast<const CudaScalar*>(alpha),
reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()),
lda, reinterpret_cast<CudaScalar**>(dev_b_dev_ptrs.mutable_data()),
ldb, batch_size));
return Status::OK();
}
#define TRSM_BATCHED_INSTANCE(Scalar, type_prefix) \
template <> \
Status CudaSolver::TrsmBatched( \
cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, \
cublasDiagType_t diag, int m, int n, const Scalar* alpha, \
const Scalar* const dev_Aarray[], int lda, Scalar* dev_Barray[], \
int ldb, int batch_size) { \
return TrsmBatchedImpl(BLAS_SOLVER_FN(trsmBatched, type_prefix), this, \
context_, cublas_handle_, side, uplo, trans, diag, \
m, n, alpha, dev_Aarray, lda, dev_Barray, ldb, \
batch_size); \
}
TF_CALL_LAPACK_TYPES(TRSM_BATCHED_INSTANCE);
} // namespace tensorflow
#endif // GOOGLE_CUDA


@@ -333,28 +333,6 @@ class CudaSolver {
int lda, Scalar* dev_S, Scalar* dev_U, int ldu,
Scalar* dev_V, int ldv, int* dev_lapack_info,
int batch_size);
// Triangular solve
// Returns Status::OK() if the kernel was launched successfully.
// See https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-trsm
template <typename Scalar>
Status Trsm(cublasSideMode_t side, cublasFillMode_t uplo,
cublasOperation_t trans, cublasDiagType_t diag, int m, int n,
const Scalar* alpha, const Scalar* A, int lda, Scalar* B,
int ldb);
template <typename Scalar>
Status Trsv(cublasFillMode_t uplo, cublasOperation_t trans,
cublasDiagType_t diag, int n, const Scalar* A, int lda, Scalar* x,
int incx);
// See
// https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-trsmbatched
template <typename Scalar>
Status TrsmBatched(cublasSideMode_t side, cublasFillMode_t uplo,
cublasOperation_t trans, cublasDiagType_t diag, int m,
int n, const Scalar* alpha,
const Scalar* const dev_Aarray[], int lda,
Scalar* dev_Barray[], int ldb, int batch_size);
private:
OpKernelContext* context_; // not owned.


@@ -0,0 +1,258 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/linalg_ops.cc.
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace tensorflow {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace {
template <typename Scalar>
se::DeviceMemory<Scalar> AsDeviceMemory(const Scalar* gpu_memory) {
se::DeviceMemoryBase wrapped(const_cast<Scalar*>(gpu_memory));
se::DeviceMemory<Scalar> typed(wrapped);
return typed;
}
} // namespace
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <class Scalar>
class MatrixTriangularSolveOp : public LinearAlgebraOp<Scalar> {
public:
INHERIT_LINALG_TYPEDEFS(Scalar);
explicit MatrixTriangularSolveOp(OpKernelConstruction* context)
: Base(context), lower_(true), adjoint_(false) {
OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
}
void ValidateInputMatrixShapes(
OpKernelContext* context,
const TensorShapes& input_matrix_shapes) const final {
Base::ValidateSquareSolver(context, input_matrix_shapes);
}
TensorShapes GetOutputMatrixShapes(
const TensorShapes& input_matrix_shapes) const final {
return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1),
input_matrix_shapes[1].dim_size(1)})});
}
int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0));
double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
double cost = rows * rows * num_rhss *
(Eigen::TensorOpCost::AddCost<Scalar>() +
Eigen::TensorOpCost::MulCost<Scalar>());
return cost >= static_cast<double>(kint64max) ? kint64max
: static_cast<int64>(cost);
}
bool EnableInputForwarding() const final { return false; }
void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
MatrixMaps* outputs) final {
const ConstMatrixMap& matrix = inputs[0];
const ConstMatrixMap& rhs = inputs[1];
MatrixMap& output = outputs->at(0);
if (matrix.rows() == 0 || rhs.rows() == 0 || rhs.cols() == 0) {
// To be consistent with the MatrixInverse op, we define the solution for
// an empty set of equations as the empty matrix.
return;
}
const RealScalar min_abs_pivot = matrix.diagonal().cwiseAbs().minCoeff();
OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
errors::InvalidArgument("Input matrix is not invertible."));
if (lower_) {
auto triangle = matrix.template triangularView<Eigen::Lower>();
if (adjoint_) {
output.noalias() = triangle.adjoint().solve(rhs);
} else {
output.noalias() = triangle.solve(rhs);
}
} else {
auto triangle = matrix.template triangularView<Eigen::Upper>();
if (adjoint_) {
output.noalias() = triangle.adjoint().solve(rhs);
} else {
output.noalias() = triangle.solve(rhs);
}
}
}
private:
bool lower_;
bool adjoint_;
TF_DISALLOW_COPY_AND_ASSIGN(MatrixTriangularSolveOp);
};
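For reference, a rough SciPy equivalent of what `ComputeMatrix` does for a single matrix, covering only the happy path (a sketch that assumes SciPy's `solve_triangular`; it is not the kernel itself):

```python
import numpy as np
from scipy.linalg import solve_triangular

def compute_matrix(matrix, rhs, lower=True, adjoint=False):
    # Same singularity check as the kernel: the smallest |diagonal| must be > 0.
    if np.abs(np.diag(matrix)).min() <= 0.0:
        raise ValueError("Input matrix is not invertible.")
    # Eigen's triangularView<Lower/Upper>() plus optional .adjoint() maps to
    # solve_triangular's `lower` flag and `trans='C'` (conjugate transpose).
    return solve_triangular(matrix, rhs, lower=lower,
                            trans='C' if adjoint else 'N')
```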
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
(MatrixTriangularSolveOp<float>), float);
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
(MatrixTriangularSolveOp<double>), double);
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
(MatrixTriangularSolveOp<complex64>), complex64);
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
(MatrixTriangularSolveOp<complex128>), complex128);
REGISTER_LINALG_OP_CPU("BatchMatrixTriangularSolve",
(MatrixTriangularSolveOp<float>), float);
REGISTER_LINALG_OP_CPU("BatchMatrixTriangularSolve",
(MatrixTriangularSolveOp<double>), double);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// TODO(rmlarsen): Re-factor to
// 1. Enable buffer forwarding from rhs->out.
// 2. Save Memcpy when buffer forwarding is used.
// 3. Copy entire rhs in a single Memcpy when forwarding is not used.
template <class Scalar>
class MatrixTriangularSolveOpGPU : public LinearAlgebraOp<Scalar> {
public:
INHERIT_LINALG_TYPEDEFS(Scalar);
explicit MatrixTriangularSolveOpGPU(OpKernelConstruction* context)
: Base(context), lower_(true), adjoint_(false) {
OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
}
void ValidateInputMatrixShapes(
OpKernelContext* context,
const TensorShapes& input_matrix_shapes) const final {
Base::ValidateSquareSolver(context, input_matrix_shapes);
}
TensorShapes GetOutputMatrixShapes(
const TensorShapes& input_matrix_shapes) const final {
return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1),
input_matrix_shapes[1].dim_size(1)})});
}
int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0));
double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
double cost = rows * rows * num_rhss *
(Eigen::TensorOpCost::AddCost<Scalar>() +
Eigen::TensorOpCost::MulCost<Scalar>());
return cost >= static_cast<double>(kint64max) ? kint64max
: static_cast<int64>(cost);
}
bool EnableInputForwarding() const final { return false; }
void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
MatrixMaps* outputs) final {
const ConstMatrixMap& matrix = inputs[0];
const ConstMatrixMap& rhs = inputs[1];
MatrixMap& output = outputs->at(0);
if (matrix.rows() == 0 || rhs.rows() == 0 || rhs.cols() == 0) {
// To be consistent with the MatrixInverse op, we define the solution for
// an empty set of equations as the empty matrix.
return;
}
auto matrix_ptr = AsDeviceMemory(matrix.data());
auto rhs_ptr = AsDeviceMemory(rhs.data());
auto out_ptr = AsDeviceMemory(output.data());
auto* stream = context->op_device_context()->stream();
uint64 rhs_elems = rhs.rows() * rhs.cols();
bool copy_status =
stream->ThenMemcpyD2D(&out_ptr, rhs_ptr, sizeof(Scalar) * rhs_elems)
.ok();
if (!copy_status) {
context->SetStatus(
errors::Internal("Failed to copy rhs into output before solve"));
}
// Cublas does
// output = matrix \ rhs
// where matrix, rhs and output are assumed to be in column major.
// We want the output to be in row-major, so we can compute
// output' = rhs' / matrix' (' stands for transpose)
// Upper/lower needs to be swapped for this.
se::blas::UpperLower upper_lower_matrix;
se::blas::Transpose transpose_matrix;
if (lower_) {
upper_lower_matrix = se::blas::UpperLower::kUpper;
} else {
upper_lower_matrix = se::blas::UpperLower::kLower;
}
if (adjoint_) {
transpose_matrix = se::blas::Transpose::kConjugateTranspose;
} else {
transpose_matrix = se::blas::Transpose::kNoTranspose;
}
uint64 leading_dim_matrix = matrix.cols();
uint64 leading_dim_output = output.cols();
uint64 colmajor_rows = output.cols();
uint64 colmajor_cols = output.rows();
bool blas_launch_status =
stream
->ThenBlasTrsm(
se::blas::Side::kRight /*side*/, upper_lower_matrix /*uplo*/,
transpose_matrix /*trans*/,
se::blas::Diagonal::kNonUnit /*diag*/, colmajor_rows /*m*/,
colmajor_cols /*n*/, Scalar(1.0) /*alpha*/, matrix_ptr,
leading_dim_matrix /*lda*/, &out_ptr,
leading_dim_output /*ldb*/)
.ok();
if (!blas_launch_status) {
context->SetStatus(errors::Internal("Blas TRSM launch failed"));
}
}
private:
bool lower_;
bool adjoint_;
TF_DISALLOW_COPY_AND_ASSIGN(MatrixTriangularSolveOpGPU);
};
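The column-major trick in the comment above can be sanity-checked in NumPy/SciPy (an illustrative sketch; the variable names are made up):

```python
import numpy as np
from scipy.linalg import solve_triangular

rng = np.random.default_rng(0)
m, k = 4, 3
a = np.tril(rng.standard_normal((m, m))) + m * np.eye(m)  # lower triangular
b = rng.standard_normal((m, k))

x = solve_triangular(a, b, lower=True)  # what the op computes: a @ x = b

# The same row-major bytes read column-major are a.T, which is *upper*
# triangular, and the system seen from that side is x.T @ a.T = b.T --
# hence side = Right and the swapped Upper/Lower flag passed to Blas TRSM.
assert np.allclose(x.T @ a.T, b.T)
```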
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
(MatrixTriangularSolveOpGPU<float>), float);
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
(MatrixTriangularSolveOpGPU<double>), double);
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
(MatrixTriangularSolveOpGPU<complex64>), complex64);
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
(MatrixTriangularSolveOpGPU<complex128>), complex128);
REGISTER_LINALG_OP_GPU("BatchMatrixTriangularSolve",
(MatrixTriangularSolveOpGPU<float>), float);
REGISTER_LINALG_OP_GPU("BatchMatrixTriangularSolve",
(MatrixTriangularSolveOpGPU<double>), double);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow


@@ -1,28 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/matrix_triangular_solve_op_impl.h"
namespace tensorflow {
TF_CALL_complex64(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU);
TF_CALL_complex128(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TF_CALL_complex64(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU);
TF_CALL_complex128(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow


@@ -1,431 +0,0 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/linalg_ops.cc.
//
#ifndef TENSORFLOW_CORE_KERNELS_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/matmul_bcast.h"
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename Scalar>
se::DeviceMemory<Scalar> AsDeviceMemory(const Scalar* gpu_memory) {
se::DeviceMemoryBase wrapped(const_cast<Scalar*>(gpu_memory));
se::DeviceMemory<Scalar> typed(wrapped);
return typed;
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Sequential batch matrix triangular solve kernel that calls Eigen's
// matrix triangular solve.
template <typename Scalar>
struct SequentialMatrixTriangularSolveKernel {
using Matrix =
Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using ConstMatrixMap = Eigen::Map<const Matrix>;
using MatrixMap = Eigen::Map<Matrix>;
using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t,
int slice) {
return ConstMatrixMap(
t.flat<Scalar>().data() + slice * t.dim_size(1) * t.dim_size(2),
t.dim_size(1), t.dim_size(2));
}
static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) {
return MatrixMap(
t->flat<Scalar>().data() + slice * t->dim_size(1) * t->dim_size(2),
t->dim_size(1), t->dim_size(2));
}
static void Run(const Tensor& in_x, const Tensor& in_y, bool lower,
bool adjoint, const MatMulBCast& bcast, Tensor* out,
int start, int limit) {
const bool should_bcast = bcast.IsBroadcastingRequired();
const auto& x_batch_indices = bcast.x_batch_indices();
const auto& y_batch_indices = bcast.y_batch_indices();
for (int64 i = start; i < limit; ++i) {
const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
auto matrix = ConstTensorSliceToEigenMatrix(in_x, x_batch_index);
auto rhs = ConstTensorSliceToEigenMatrix(in_y, y_batch_index);
auto output = TensorSliceToEigenMatrix(out, i);
if (lower) {
auto triangle = matrix.template triangularView<Eigen::Lower>();
if (adjoint) {
output.noalias() = triangle.adjoint().solve(rhs);
} else {
output.noalias() = triangle.solve(rhs);
}
} else {
auto triangle = matrix.template triangularView<Eigen::Upper>();
if (adjoint) {
output.noalias() = triangle.adjoint().solve(rhs);
} else {
output.noalias() = triangle.solve(rhs);
}
}
}
}
};
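The `x_batch_indices`/`y_batch_indices` bookkeeping used in `Run` can be pictured with a small NumPy helper (a sketch of the mapping `MatMulBCast` provides, not its actual implementation):

```python
import numpy as np

def batch_indices(in_batch_shape, out_batch_shape):
    """For each flattened output-batch index, the flattened index of the
    (possibly broadcast) input batch slice it should read from."""
    # Pad the input batch shape with leading 1s up to the output rank.
    pad = len(out_batch_shape) - len(in_batch_shape)
    in_shape = (1,) * pad + tuple(in_batch_shape)
    coords = np.unravel_index(np.arange(int(np.prod(out_batch_shape))),
                              out_batch_shape)
    # Broadcast (size-1) input dimensions always map to coordinate 0.
    in_coords = tuple(c if d != 1 else np.zeros_like(c)
                      for c, d in zip(coords, in_shape))
    return np.ravel_multi_index(in_coords, in_shape)

# matrix batch shape (2, 1) against rhs batch shape (3,) -> output batch (2, 3).
print(batch_indices((2, 1), (2, 3)))  # [0 0 0 1 1 1]
print(batch_indices((3,),   (2, 3)))  # [0 1 2 0 1 2]
```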
template <typename Device, typename Scalar>
struct LaunchBatchMatrixTriangularSolve;
template <typename Scalar>
struct LaunchBatchMatrixTriangularSolve<CPUDevice, Scalar> {
static void Launch(OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adjoint, bool lower,
const MatMulBCast& bcast, Tensor* out) {
// Number of matrix triangular solves i.e. size of the batch.
const int64 batch_size = bcast.output_batch_size();
const int64 cost_per_unit =
in_x.dim_size(1) * in_x.dim_size(1) * in_y.dim_size(2) / 2;
auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
using Matrix =
Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using ConstMatrixMap = Eigen::Map<const Matrix>;
using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
// Check diagonal before doing any solves.
auto matrix = ConstMatrixMap(in_x.flat<Scalar>().data(), in_x.dim_size(1),
in_x.dim_size(2));
const RealScalar min_abs_pivot = matrix.diagonal().cwiseAbs().minCoeff();
OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
errors::InvalidArgument("Input matrix is not invertible."));
Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
cost_per_unit,
[&in_x, &in_y, adjoint, lower, &bcast, out](int start, int limit) {
SequentialMatrixTriangularSolveKernel<Scalar>::Run(
in_x, in_y, lower, adjoint, bcast, out, start, limit);
});
}
};
template <typename Device, typename Scalar>
class BaseMatrixTriangularSolveOp : public OpKernel {
public:
explicit BaseMatrixTriangularSolveOp(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
}
~BaseMatrixTriangularSolveOp() override {}
void Compute(OpKernelContext* ctx) override {
const Tensor& in0 = ctx->input(0);
const Tensor& in1 = ctx->input(1);
ValidateInputTensors(ctx, in0, in1);
MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
OP_REQUIRES(
ctx, bcast.IsValid(),
errors::InvalidArgument(
"In[0] and In[1] must have compatible batch dimensions: ",
in0.shape().DebugString(), " vs. ", in1.shape().DebugString()));
TensorShape out_shape = bcast.output_batch_shape();
auto batch_size = bcast.output_batch_size();
auto d0 = in0.dim_size(in0.dims() - 2);
auto d1 = in0.dim_size(in0.dims() - 1);
Tensor in0_reshaped;
OP_REQUIRES(
ctx,
in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})),
errors::Internal("Failed to reshape In[0] from ",
in0.shape().DebugString()));
auto d2 = in1.dim_size(in1.dims() - 2);
auto d3 = in1.dim_size(in1.dims() - 1);
Tensor in1_reshaped;
OP_REQUIRES(
ctx,
in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
errors::Internal("Failed to reshape In[1] from ",
in1.shape().DebugString()));
if (adjoint_) std::swap(d0, d1);
OP_REQUIRES(ctx, d1 == d2,
errors::InvalidArgument(
"In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ",
in0.shape().DebugString(), " ", in1.shape().DebugString(),
" ", lower_, " ", adjoint_));
out_shape.AddDim(d0);
out_shape.AddDim(d3);
Tensor* out = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
if (out->NumElements() == 0) {
return;
}
Tensor out_reshaped;
OP_REQUIRES(ctx,
out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})),
errors::Internal("Failed to reshape output from ",
out->shape().DebugString()));
LaunchBatchMatrixTriangularSolve<Device, Scalar>::Launch(
ctx, in0_reshaped, in1_reshaped, adjoint_, lower_, bcast,
&out_reshaped);
}
private:
virtual void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
const Tensor& in1) = 0;
bool lower_;
bool adjoint_;
};
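A rough NumPy sketch of the reshaping that `Compute` performs before launching the per-batch solves (illustrative only; names are made up):

```python
import numpy as np

def flatten_batch(t):
    """Collapse all batch dimensions into one, as Compute does via CopyFrom."""
    *batch, rows, cols = t.shape
    return t.reshape((-1, rows, cols)), tuple(batch)

a = np.zeros((10, 1, 3, 3))                        # lhs with batch shape (10, 1)
b = np.zeros((5, 3, 7))                            # rhs with batch shape (5,)
a_flat, a_batch = flatten_batch(a)                 # (10, 3, 3)
b_flat, b_batch = flatten_batch(b)                 # (5, 3, 7)
out_batch = np.broadcast_shapes(a_batch, b_batch)  # (10, 5)
print(a_flat.shape, b_flat.shape, out_batch)
```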
template <class Device, class Scalar>
class MatrixTriangularSolveOp
: public BaseMatrixTriangularSolveOp<Device, Scalar> {
public:
explicit MatrixTriangularSolveOp(OpKernelConstruction* context)
: BaseMatrixTriangularSolveOp<Device, Scalar>(context) {}
~MatrixTriangularSolveOp() override {}
private:
void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
const Tensor& in1) override {
// Disallow broadcasting support. Ensure that all batch dimensions of the
// input tensors match.
OP_REQUIRES(ctx, in0.dims() == in1.dims(),
errors::InvalidArgument("In[0] and In[1] has different ndims: ",
in0.shape().DebugString(), " vs. ",
in1.shape().DebugString()));
const int ndims = in0.dims();
OP_REQUIRES(
ctx, ndims >= 2,
errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims));
for (int i = 0; i < ndims - 2; ++i) {
OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
errors::InvalidArgument(
"In[0].dim(", i, ") and In[1].dim(", i,
") must be the same: ", in0.shape().DebugString(), " vs ",
in1.shape().DebugString()));
}
}
};
template <class Device, class Scalar>
class MatrixTriangularSolveOpV2
: public BaseMatrixTriangularSolveOp<Device, Scalar> {
public:
explicit MatrixTriangularSolveOpV2(OpKernelConstruction* context)
: BaseMatrixTriangularSolveOp<Device, Scalar>(context) {}
~MatrixTriangularSolveOpV2() override {}
private:
void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
const Tensor& in1) override {
OP_REQUIRES(
ctx, in0.dims() >= 2,
errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims()));
OP_REQUIRES(
ctx, in1.dims() >= 2,
errors::InvalidArgument("In[0] ndims must be >= 2: ", in1.dims()));
}
};
#define REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU(TYPE) \
REGISTER_KERNEL_BUILDER(Name("MatrixTriangularSolve") \
.Device(DEVICE_CPU) \
.TypeConstraint<TYPE>("T"), \
MatrixTriangularSolveOpV2<CPUDevice, TYPE>); \
REGISTER_KERNEL_BUILDER(Name("BatchMatrixTriangularSolve") \
.Device(DEVICE_CPU) \
.TypeConstraint<TYPE>("T"), \
MatrixTriangularSolveOpV2<CPUDevice, TYPE>);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename Scalar>
struct LaunchBatchMatrixTriangularSolve<GPUDevice, Scalar> {
static void Launch(OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adjoint, bool lower,
const MatMulBCast& bcast, Tensor* out) {
auto* stream = context->op_device_context()->stream();
const uint64 m = in_x.dim_size(1);
const uint64 n = out->dim_size(2);
// Do a memcpy when we don't need to broadcast.
if (!bcast.IsBroadcastingRequired() || out->shape() == in_y.shape()) {
auto src_device_mem = AsDeviceMemory(in_y.template flat<Scalar>().data());
auto dst_device_mem = AsDeviceMemory(out->template flat<Scalar>().data());
OP_REQUIRES(
context,
stream
->ThenMemcpyD2D(&dst_device_mem, src_device_mem,
bcast.y_batch_size() * m * n * sizeof(Scalar))
.ok(),
errors::Internal("MatrixTriangularSolveOpV2: failed to copy rhs "
"from device"));
} else {
std::vector<Scalar*> out_ptrs;
std::vector<const Scalar*> b_tmp_ptrs;
auto* b_base_ptr = in_y.template flat<Scalar>().data();
const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
for (int64 i = 0; i < bcast.y_batch_size(); ++i) {
b_tmp_ptrs.push_back(b_base_ptr + i * m * n);
}
for (int64 i = 0; i < bcast.output_batch_size(); ++i) {
auto src_device_mem = AsDeviceMemory(b_tmp_ptrs[b_batch_indices[i]]);
auto dst_device_mem =
AsDeviceMemory(out->template flat<Scalar>().data() + i * m * n);
OP_REQUIRES(
context,
stream
->ThenMemcpyD2D(&dst_device_mem, src_device_mem,
m * n * sizeof(Scalar))
.ok(),
errors::Internal("MatrixTriangularSolveOpV2: failed to copy rhs "
"from device"));
}
}
if (out->NumElements() == 0) {
return;
}
cublasSideMode_t side = CUBLAS_SIDE_RIGHT;
cublasFillMode_t uplo;
cublasOperation_t trans;
cublasDiagType_t diag = CUBLAS_DIAG_NON_UNIT;
// Cublas does
// output = matrix \ rhs
// where matrix, rhs and output are assumed to be in column major.
// We want the output to be in row-major, so we can compute
// output' = rhs' / matrix' (' stands for transpose)
// Upper/lower needs to be swapped for this.
uplo = lower ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
trans = adjoint ? CUBLAS_OP_C : CUBLAS_OP_N;
auto solver = absl::make_unique<CudaSolver>(context);
const uint64 leading_dim_matrix = m;
const uint64 leading_dim_output = n;
const uint64 colmajor_rows = n;
const uint64 colmajor_cols = m;
const int64 batch_size = bcast.output_batch_size();
std::vector<const Scalar*> a_ptrs;
std::vector<Scalar*> out_ptrs;
std::vector<const Scalar*> a_tmp_ptrs;
a_ptrs.reserve(batch_size);
out_ptrs.reserve(batch_size);
a_tmp_ptrs.reserve(bcast.x_batch_size());
auto* a_base_ptr = in_x.template flat<Scalar>().data();
auto* out_base_ptr = out->template flat<Scalar>().data();
if (!bcast.IsBroadcastingRequired()) {
for (int64 i = 0; i < batch_size; ++i) {
a_ptrs.push_back(a_base_ptr + i * m * m);
out_ptrs.push_back(out_base_ptr + i * m * n);
}
} else {
const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
for (int64 i = 0; i < bcast.x_batch_size(); ++i) {
a_tmp_ptrs.push_back(a_base_ptr + i * m * m);
}
for (int64 i = 0; i < batch_size; ++i) {
a_ptrs.push_back(a_tmp_ptrs[a_batch_indices[i]]);
out_ptrs.push_back(out_base_ptr + i * m * n);
}
}
typedef Scalar Coefficient;
const Scalar alpha = Scalar(1.0);
// TODO(b/146763573): Consider using Trsv here when the right hand side is
// a vector. This will require an explicit transpose since Trsv assumes
// CUBLAS_SIDE_LEFT.
if (batch_size == 1) {
OP_REQUIRES_OK(
context,
solver->Trsm(side, uplo, trans, diag, colmajor_rows, colmajor_cols,
&alpha, a_ptrs[0], leading_dim_matrix /*lda*/,
out_ptrs[0], leading_dim_output /*ldb*/));
} else {
// Heuristic for choosing between batched interface vs. non-batched
// interface. This is inspired by matrix_solve_op and can probably be
// tuned.
// TODO(b/146763573): Tune this heuristic.
const int kMaxMatrixSizeToBatchSizeRatio = 128;
const bool use_batched_solver =
m <= kMaxMatrixSizeToBatchSizeRatio * batch_size;
if (use_batched_solver) {
OP_REQUIRES_OK(
context, solver->TrsmBatched(
side, uplo, trans, diag, colmajor_rows, colmajor_cols,
&alpha, &a_ptrs[0], leading_dim_matrix /*lda*/,
&out_ptrs[0], leading_dim_output /*ldb*/, batch_size));
} else {
for (int batch = 0; batch < batch_size; ++batch) {
OP_REQUIRES_OK(
context, solver->Trsm(side, uplo, trans, diag, colmajor_rows,
colmajor_cols, &alpha, a_ptrs[batch],
leading_dim_matrix /*lda*/, out_ptrs[batch],
leading_dim_output /*ldb*/));
}
}
}
}
};
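The batched-versus-looped dispatch above reduces to a one-line heuristic; a sketch with the constant copied from the code (the TODO notes it is untuned):

```python
# Constant copied from the code above; the TODO notes it is untuned.
K_MAX_MATRIX_SIZE_TO_BATCH_SIZE_RATIO = 128

def use_batched_solver(m, batch_size):
    """Prefer the batched TrsmBatched call unless the matrices are very large
    relative to the batch size, in which case loop over per-matrix Trsm calls."""
    return m <= K_MAX_MATRIX_SIZE_TO_BATCH_SIZE_RATIO * batch_size

print(use_batched_solver(m=512, batch_size=32))  # True  -> TrsmBatched
print(use_batched_solver(m=4096, batch_size=1))  # False -> loop over Trsm
```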
#define REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU(TYPE) \
REGISTER_KERNEL_BUILDER(Name("MatrixTriangularSolve") \
.Device(DEVICE_GPU) \
.TypeConstraint<TYPE>("T"), \
MatrixTriangularSolveOpV2<GPUDevice, TYPE>); \
REGISTER_KERNEL_BUILDER(Name("BatchMatrixTriangularSolve") \
.Device(DEVICE_GPU) \
.TypeConstraint<TYPE>("T"), \
MatrixTriangularSolveOpV2<GPUDevice, TYPE>);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_


@@ -1,32 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/matrix_triangular_solve_op_impl.h"
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#endif // GOOGLE_CUDA
namespace tensorflow {
TF_CALL_float(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU);
TF_CALL_double(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TF_CALL_float(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU);
TF_CALL_double(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow


@@ -1,165 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/broadcast_to_op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
namespace {
Node* BroadcastTo(Graph* g, Node* input, Node* shape) {
Node* ret;
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BroadcastTo")
.Input(input)
.Input(shape)
.Attr("Tidx", DT_INT64)
.Finalize(g, &ret));
return ret;
}
Node* MatrixTriangularSolve(Graph* g, Node* in0, Node* in1, bool adjoint) {
Node* ret;
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatrixTriangularSolve")
.Input(in0)
.Input(in1)
.Attr("lower", true)
.Attr("adjoint", adjoint)
.Finalize(g, &ret));
return ret;
}
template <typename T>
static Graph* MatrixTriangularSolveWithBroadcast(int64 b0, int64 b1, int64 m,
int64 n, bool manual_broadcast,
DataType type) {
Graph* g = new Graph(OpRegistry::Global());
Tensor in0(type, TensorShape({b0, m, m}));
// Set diagonal to non-zero to guarantee invertibility.
in0.flat<T>().setRandom();
auto matrix = Eigen::Map<
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
in0.flat<T>().data(), in0.dim_size(1), in0.dim_size(2));
matrix.diagonal() =
(matrix.diagonal().cwiseAbs().array() + static_cast<T>(0.5));
Tensor in1(type, TensorShape({b1, m, n}));
in1.flat<T>().setRandom();
Tensor broadcasted_in0_shape(DT_INT64, TensorShape({3}));
Tensor broadcasted_in1_shape(DT_INT64, TensorShape({3}));
Node* in0_node = nullptr;
Node* in1_node = nullptr;
if (manual_broadcast) {
auto vec0 = broadcasted_in0_shape.vec<int64>();
auto vec1 = broadcasted_in1_shape.vec<int64>();
for (int i = 0; i < 3; ++i) {
vec0(i) = (i == 0 ? std::max(b0, b1) : in0.shape().dim_size(i));
vec1(i) = (i == 0 ? std::max(b0, b1) : in1.shape().dim_size(i));
}
in0_node = BroadcastTo(g, test::graph::Constant(g, in0),
test::graph::Constant(g, broadcasted_in0_shape));
in1_node = BroadcastTo(g, test::graph::Constant(g, in1),
test::graph::Constant(g, broadcasted_in1_shape));
} else {
in0_node = test::graph::Constant(g, in0);
in1_node = test::graph::Constant(g, in1);
}
MatrixTriangularSolve(g, in0_node, in1_node, false);
return g;
}
// Macro argument names: ---------------------------------------------------- //
// B1: batch size of LHS
// B2: batch size of RHS
// M: inner dimensions of LHS and RHS, outer dimension of LHS
// N: outer dimension of RHS
// MB: boolean indicating whether to use manual broadcasting
// T: C++ type of scalars (e.g. float, std::complex)
// TT: TensorFlow type of scalars (e.g. DT_FLOAT, DT_COMPLEX128)
// D: Device (e.g. cpu, gpu)
#define BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, T, TT, D) \
static void \
BM_MatrixTriangularSolve##_##B1##_##B2##_##M##_##N##_##MB##_##TT##_##D( \
int iters) { \
testing::UseRealTime(); \
testing::ItemsProcessed(static_cast<int64>(iters) * std::max(B1, B2) * M * \
M * N * 2); \
test::Benchmark( \
#D, MatrixTriangularSolveWithBroadcast<T>(B1, B2, M, N, MB, TT)) \
.Run(iters); \
} \
BENCHMARK( \
BM_MatrixTriangularSolve##_##B1##_##B2##_##M##_##N##_##MB##_##TT##_##D);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define BM_MatrixTriangularSolve(B1, B2, M, N, MB) \
BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, float, DT_FLOAT, cpu); \
BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, double, DT_DOUBLE, cpu); \
BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, float, DT_FLOAT, gpu); \
BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, double, DT_DOUBLE, gpu);
#else
#define BM_MatrixTriangularSolve(B1, B2, M, N, MB) \
BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, float, DT_FLOAT, cpu); \
BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, double, DT_DOUBLE, cpu);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Square matrix triangular solve.
BM_MatrixTriangularSolve(32, 32, 512, 512, true);
BM_MatrixTriangularSolve(32, 32, 512, 512, false);
BM_MatrixTriangularSolve(1, 32, 512, 512, true);
BM_MatrixTriangularSolve(1, 32, 512, 512, false);
BM_MatrixTriangularSolve(32, 1, 512, 512, true);
BM_MatrixTriangularSolve(32, 1, 512, 512, false);
BM_MatrixTriangularSolve(128, 128, 512, 512, true);
BM_MatrixTriangularSolve(128, 128, 512, 512, false);
BM_MatrixTriangularSolve(1, 128, 512, 512, true);
BM_MatrixTriangularSolve(1, 128, 512, 512, false);
BM_MatrixTriangularSolve(128, 1, 512, 512, true);
BM_MatrixTriangularSolve(128, 1, 512, 512, false);
BM_MatrixTriangularSolve(1, 128, 1024, 1024, true);
BM_MatrixTriangularSolve(1, 128, 1024, 1024, false);
BM_MatrixTriangularSolve(128, 1, 1024, 1024, true);
BM_MatrixTriangularSolve(128, 1, 1024, 1024, false);
// Matrix-vector triangular solve.
BM_MatrixTriangularSolve(1, 128, 200, 1, true);
BM_MatrixTriangularSolve(1, 128, 200, 1, false);
BM_MatrixTriangularSolve(128, 1, 200, 1, true);
BM_MatrixTriangularSolve(128, 1, 200, 1, false);
// Matrix-vector triangular solve, large dimension.
BM_MatrixTriangularSolve(1, 128, 200, 10000, true);
BM_MatrixTriangularSolve(1, 128, 200, 10000, false);
BM_MatrixTriangularSolve(128, 1, 200, 10000, true);
BM_MatrixTriangularSolve(128, 1, 200, 10000, false);
} // namespace
} // namespace tensorflow


@@ -84,34 +84,6 @@ Status MatrixSolveShapeFn(InferenceContext* c, bool square) {
return Status::OK();
}
// The first input is [...,M,M] and second input is [...,M,N].
// Output is [...,M,N].
Status MatrixTriangularSolveShapeFn(InferenceContext* c) {
ShapeHandle lhs;
ShapeHandle rhs;
TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &lhs));
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));
ShapeHandle lhs_batch_shape;
ShapeHandle rhs_batch_shape;
ShapeHandle output_batch_shape;
// Make the common batch subshape.
TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
c, lhs_batch_shape, rhs_batch_shape, true, &output_batch_shape));
DimensionHandle m;
// lhs and rhs have the same value for m to be compatible.
TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, -1), c->Dim(rhs, -2), &m));
ShapeHandle out;
// Build final shape (batch_shape + m + n) in <out>.
TF_RETURN_IF_ERROR(
c->Concatenate(output_batch_shape, c->Matrix(m, c->Dim(rhs, -1)), &out));
c->set_output(0, out);
return Status::OK();
}
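A rough Python equivalent of this shape function for fully known static shapes (a sketch, not the inference code itself):

```python
import numpy as np

def matrix_triangular_solve_output_shape(lhs_shape, rhs_shape):
    *lhs_batch, m1, m2 = lhs_shape
    *rhs_batch, m3, k = rhs_shape
    assert m1 == m2, "lhs must be square in its two innermost dims"
    assert m2 == m3, "lhs and rhs must agree on M"
    batch = np.broadcast_shapes(tuple(lhs_batch), tuple(rhs_batch))
    return (*batch, m1, k)

print(matrix_triangular_solve_output_shape((10, 1, 3, 3), (5, 3, 7)))  # (10, 5, 3, 7)
```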
// Input is [...,N,N]. Outputs are:
// [...,N];[0], if compute_v is false,
// [...,N];[...,N,N], if compute_v is true.
@@ -454,7 +426,7 @@ REGISTER_OP("MatrixTriangularSolve")
.Attr("adjoint: bool = False")
.Attr("T: {double, float, half, complex64, complex128}")
.SetShapeFn([](InferenceContext* c) {
return MatrixTriangularSolveShapeFn(c);
return MatrixSolveShapeFn(c, true /* square */);
});
REGISTER_OP("MatrixSolveLs")


@@ -122,54 +122,34 @@ TEST(LinalgOpsTest, SelfAdjointEigV2_ShapeFn) {
"[d0_0,d0_1,d0_2,d0_3|d0_4];[d0_0,d0_1,d0_2,d0_3|d0_4,d0_3|d0_4]");
}
TEST(LinalgOpsTest, MatrixSolve_ShapeFn) {
ShapeInferenceTestOp op("MatrixSolve");
INFER_OK(op, "?;?", "?");
INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1];?");
INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,2];?");
INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[5,?,?];[6]");
INFER_ERROR("Shapes must be equal rank, but are 0 and 1", op,
"[5,?];[6,?,?]");
TEST(LinalgOpsTest, SquareMatrixSolve_ShapeFn) {
for (const char* op_name : {"MatrixSolve", "MatrixTriangularSolve"}) {
ShapeInferenceTestOp op(op_name);
INFER_OK(op, "?;?", "?");
INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1];?");
INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,2];?");
INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
"[5,?,?];[6]");
INFER_ERROR("Shapes must be equal rank, but are 0 and 1", op,
"[5,?];[6,?,?]");
INFER_OK(op, "[?,?];?", "[d0_0|d0_1,?]");
INFER_OK(op, "[?,?];?", "[d0_0|d0_1,?]");
// Inputs are [...,M,M] and [...,M,K]. Output is [...,M,K].
// First test where ... is empty.
INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]");
INFER_OK(op, "[?,?];[1,?]", "[d1_0,d1_1]");
INFER_OK(op, "[1,?];[1,?]", "[d0_0|d1_0,d1_1]");
INFER_OK(op, "[?,1];[1,?]", "[d0_1|d1_0,d1_1]");
INFER_OK(op, "[1,1];[?,?]", "[d0_0,d1_1]");
INFER_OK(op, "[1,1];[1,?]", "[d0_0|d0_1|d1_0,d1_1]");
// Test with ... being 2-d.
INFER_OK(op, "[10,?,?,?];[?,20,1,?]", "[d0_0,d1_1,d1_2,d1_3]");
INFER_OK(op, "[10,?,1,?];[?,20,1,?]", "[d0_0,d1_1,d0_2|d1_2,d1_3]");
INFER_OK(op, "[10,?,?,1];[?,20,1,?]", "[d0_0,d1_1,d0_3|d1_2,d1_3]");
INFER_OK(op, "[10,?,1,1];[?,20,?,?]", "[d0_0,d1_1,d0_2,d1_3]");
INFER_OK(op, "[10,?,1,1];[?,20,1,?]", "[d0_0,d1_1,d0_2|d0_3|d1_2,d1_3]");
}
TEST(LinalgOpsTest, MatrixTriangularSolve_ShapeFn) {
ShapeInferenceTestOp op("MatrixTriangularSolve");
INFER_OK(op, "?;?", "?");
INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1];?");
INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,2];?");
INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[5,?,?];[6]");
// Inputs are [...,M,M] and [...,M,K]. Output is [...,M,K].
// First test where ... is empty.
INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]");
INFER_OK(op, "[?,?];[1,?]", "[d1_0,d1_1]");
INFER_OK(op, "[1,?];[1,?]", "[d0_0|d1_0,d1_1]");
INFER_OK(op, "[?,1];[1,?]", "[d0_1|d1_0,d1_1]");
INFER_OK(op, "[1,1];[?,?]", "[d0_0,d1_1]");
INFER_OK(op, "[1,1];[1,?]", "[d0_0|d0_1|d1_0,d1_1]");
// Test with ... being 2-d.
INFER_OK(op, "[10,?,?,?];[?,20,1,?]", "[d0_0,d1_1,d1_2,d1_3]");
INFER_OK(op, "[10,?,1,?];[?,20,1,?]", "[d0_0,d1_1,d0_2|d1_2,d1_3]");
INFER_OK(op, "[10,?,?,1];[?,20,1,?]", "[d0_0,d1_1,d0_3|d1_2,d1_3]");
INFER_OK(op, "[10,?,1,1];[?,20,?,?]", "[d0_0,d1_1,d0_2,d1_3]");
INFER_OK(op, "[10,?,1,1];[?,20,1,?]", "[d0_0,d1_1,d0_2|d0_3|d1_2,d1_3]");
// Inputs are [...,M,M] and [...,M,K]. Output is [...,M,K].
// First test where ... is empty.
INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]");
INFER_OK(op, "[?,?];[1,?]", "[d1_0,d1_1]");
INFER_OK(op, "[1,?];[1,?]", "[d0_0|d1_0,d1_1]");
INFER_OK(op, "[?,1];[1,?]", "[d0_1|d1_0,d1_1]");
INFER_OK(op, "[1,1];[?,?]", "[d0_0,d1_1]");
INFER_OK(op, "[1,1];[1,?]", "[d0_0|d0_1|d1_0,d1_1]");
// Test with ... being 2-d.
INFER_OK(op, "[10,?,?,?];[?,20,1,?]", "[d0_0,d1_1,d1_2,d1_3]");
INFER_OK(op, "[10,?,1,?];[?,20,1,?]", "[d0_0,d1_1,d0_2|d1_2,d1_3]");
INFER_OK(op, "[10,?,?,1];[?,20,1,?]", "[d0_0,d1_1,d0_3|d1_2,d1_3]");
INFER_OK(op, "[10,?,1,1];[?,20,?,?]", "[d0_0,d1_1,d0_2,d1_3]");
INFER_OK(op, "[10,?,1,1];[?,20,1,?]", "[d0_0,d1_1,d0_2|d0_3|d1_2,d1_3]");
}
}
TEST(LinalgOpsTest, MatrixSolveLs_ShapeFn) {


@@ -756,7 +756,6 @@ cuda_py_test(
name = "matrix_triangular_solve_op_test",
size = "small",
srcs = ["matrix_triangular_solve_op_test.py"],
shard_count = 2,
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:linalg_ops",


@@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
@@ -67,32 +68,31 @@ class MatrixTriangularSolveOpTest(test.TestCase):
else:
a_np = a
if adjoint:
axes = list(range(len(a_np.shape)))
axes[-2] = -1
axes[-1] = -2
a_np = np.conj(np.transpose(a_np, axes=axes))
a_np = np.conj(np.transpose(a_np))
if batch_dims is not None:
a = np.tile(a, batch_dims + [1, 1])
a_np = np.tile(a_np, batch_dims + [1, 1])
b = np.tile(b, batch_dims + [1, 1])
def broadcast(a, b):
b1 = b + np.zeros(a.shape[:-2] + (1, 1), dtype=b.dtype)
return a, b1
a_tf = a
b_tf = b
if use_placeholder:
a_tf = array_ops.placeholder_with_default(a_tf, shape=None)
b_tf = array_ops.placeholder_with_default(b_tf, shape=None)
tf_ans = linalg_ops.matrix_triangular_solve(
a_tf, b_tf, lower=lower, adjoint=adjoint)
tf_val = self.evaluate(tf_ans)
a_np, b = broadcast(a_np, b)
np_ans = np.linalg.solve(a_np, b)
self.assertEqual(np_ans.shape, tf_val.shape)
self.assertAllClose(np_ans, tf_val)
with self.cached_session(use_gpu=True) as sess:
if use_placeholder:
a_tf = array_ops.placeholder(a.dtype)
b_tf = array_ops.placeholder(b.dtype)
tf_ans = linalg_ops.matrix_triangular_solve(
a_tf, b_tf, lower=lower, adjoint=adjoint)
tf_val = sess.run(tf_ans, feed_dict={a_tf: a, b_tf: b})
np_ans = np.linalg.solve(a_np, b)
else:
a_tf = constant_op.constant(a)
b_tf = constant_op.constant(b)
tf_ans = linalg_ops.matrix_triangular_solve(
a_tf, b_tf, lower=lower, adjoint=adjoint)
tf_val = self.evaluate(tf_ans)
np_ans = np.linalg.solve(a_np, b)
self.assertEqual(np_ans.shape, tf_ans.get_shape())
self.assertEqual(np_ans.shape, tf_val.shape)
self.assertAllClose(np_ans, tf_val)
@test_util.run_deprecated_v1
def testSolve(self):
@@ -136,50 +136,6 @@ class MatrixTriangularSolveOpTest(test.TestCase):
# Batch of 3x2x2x2 matrices, 3x2x2x3 right-hand sides.
self._verifySolveAllWaysReal(matrix, rhs, batch_dims=[3, 2])
@test_util.run_deprecated_v1
@test_util.disable_xla("XLA cannot broadcast triangular solve.")
def testSolveBatchBroadcast(self):
# 2 x 2 x 2
matrix = np.array([[[1., 0.], [3., 4.]], [[1., 0.], [2., 1.]]])
# 2 x 3
rhs = np.array([[1., 0., 1.], [0., 1., 1.]])
# 2 x 2 x 3
self._verifySolveAllWaysReal(matrix, rhs)
# 2 x 2 x 2
matrix2 = np.array([[[1., 0.], [3., 4.]], [[2., 0.], [1., 6.3]]])
# 1 x 2 x 3
rhs = np.array([[[1., 0., 1.], [0., 1., 1.]]])
# 2 x 2 x 3
self._verifySolveAllWaysReal(matrix2, rhs)
@test_util.run_deprecated_v1
@test_util.disable_xla("XLA cannot broadcast triangular solve.")
def testSolveBatchBroadcastLargerBatches(self):
# 1 x 10 x 10
matrix = np.random.uniform(low=1, high=2., size=[1, 10, 10])
# 10 x 1
rhs = np.random.uniform(size=[10, 1])
# 1 x 10 x 1
self._verifySolveAllWaysReal(matrix, rhs)
# 2 x 10 x 10
matrix = np.random.uniform(low=1, high=2., size=[2, 10, 10])
# 10 x 1
rhs = np.random.uniform(size=[10, 1])
# 2 x 10 x 1
self._verifySolveAllWaysReal(matrix, rhs)
# 2 x 257 x 257
matrix = np.random.uniform(low=1, high=2., size=[2, 257, 257])
# Also ensure the matrix is well conditioned by making it diagonally
# dominant.
np.fill_diagonal(matrix[0, ...], 257 * 2)
np.fill_diagonal(matrix[1, ...], 257 * 2)
# 257 x 1
rhs = np.random.uniform(size=[257, 1])
# 2 x 257 x 1
self._verifySolveAllWaysReal(matrix, rhs)
@test_util.run_deprecated_v1
def testSolveBatchComplex(self):
if test.is_built_with_rocm():


@@ -607,7 +607,6 @@ def _MatrixSolveLsGrad(op, grad):
def _MatrixTriangularSolveGrad(op, grad):
"""Gradient for MatrixTriangularSolve."""
a = op.inputs[0]
b = op.inputs[1]
adjoint_a = op.get_attr("adjoint")
lower_a = op.get_attr("lower")
c = op.outputs[0]
@@ -621,16 +620,7 @@ def _MatrixTriangularSolveGrad(op, grad):
grad_a = array_ops.matrix_band_part(grad_a, -1, 0)
else:
grad_a = array_ops.matrix_band_part(grad_a, 0, -1)
# If the static batch shapes are equal, we don't need to unbroadcast.
if (a.shape.is_fully_defined() and b.shape.is_fully_defined() and
a.shape[:-2] == b.shape[:-2]):
return grad_a, grad_b
a_shape = array_ops.shape(a)
b_shape = array_ops.shape(b)
ra, rb = array_ops.broadcast_gradient_args(a_shape[:-2], b_shape[:-2])
grad_a = array_ops.reshape(math_ops.reduce_sum(grad_a, axis=ra), a_shape)
grad_b = array_ops.reshape(math_ops.reduce_sum(grad_b, axis=rb), b_shape)
return grad_a, grad_b
return (grad_a, grad_b)
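The unbroadcasting logic shown above can be sketched in NumPy; the real code uses `array_ops.broadcast_gradient_args` on the batch dimensions only, but the idea is the same (illustrative sketch):

```python
import numpy as np

def unbroadcast(grad, target_shape):
    """Sum `grad` back down to `target_shape` after batch broadcasting."""
    # Sum away leading axes introduced by broadcasting.
    while grad.ndim > len(target_shape):
        grad = grad.sum(axis=0)
    # Sum over axes that were broadcast up from size 1.
    for axis, size in enumerate(target_shape):
        if size == 1 and grad.shape[axis] != 1:
            grad = grad.sum(axis=axis, keepdims=True)
    return grad

g = np.ones((4, 5, 3, 2))                   # gradient on the broadcast output
print(unbroadcast(g, (5, 3, 2)).shape)      # (5, 3, 2)
print(unbroadcast(g, (4, 1, 3, 2)).shape)   # (4, 1, 3, 2)
```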
@ops.RegisterGradient("SelfAdjointEigV2")


@@ -79,67 +79,6 @@ def _RegularizedGramianCholesky(matrix, l2_regularizer, first_kind):
return gen_linalg_ops.cholesky(gramian)
@tf_export(
'linalg.triangular_solve',
v1=['linalg.triangular_solve', 'matrix_triangular_solve'])
def matrix_triangular_solve(matrix, rhs, lower=True, adjoint=False, name=None):
"""Solve systems of linear equations with upper or lower triangular matrices.
`matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form
square matrices. If `lower` is `True` then the strictly upper triangular part
of each inner-most matrix is assumed to be zero and not accessed. If `lower`
is `False` then the strictly lower triangular part of each inner-most matrix
is assumed to be zero and not accessed. `rhs` is a tensor of shape
`[..., M, N]`.
The output is a tensor of shape `[..., M, N]`. If `adjoint` is `False` then the
innermost matrices in output satisfy matrix equations `matrix[..., i, k] *
output[..., k, j] = rhs[..., i, j]`. If `adjoint` is `True` then the
innermost matrices in output satisfy matrix equations
`adjoint(matrix[..., i, k]) * output[..., k, j] = rhs[..., i, j]`.
Example:
>>> a = tf.constant([[3, 0, 0, 0],
... [2, 1, 0, 0],
... [1, 0, 1, 0],
... [1, 1, 1, 1]], dtype=tf.float32)
>>> b = tf.constant([[4], [2], [4], [2]], dtype=tf.float32)
>>> x = tf.linalg.triangular_solve(a, b, lower=True)
>>> x
<tf.Tensor: shape=(4, 1), dtype=float32, numpy=
array([[ 1.3333334 ],
[-0.66666675],
[ 2.6666665 ],
[-1.3333331 ]], dtype=float32)>
>>> tf.matmul(a, x)
<tf.Tensor: shape=(4, 1), dtype=float32, numpy=
array([[4.],
[2.],
[4.],
[2.]], dtype=float32)>
Args:
matrix: A `Tensor`. Must be one of the following types: `float64`,
`float32`, `half`, `complex64`, `complex128`. Shape is `[..., M, M]`.
rhs: A `Tensor`. Must have the same type as `matrix`. Shape is `[..., M,
N]`.
lower: An optional `bool`. Defaults to `True`. Boolean indicating whether
the innermost matrices in matrix are lower or upper triangular.
adjoint: An optional `bool`. Defaults to `False`. Boolean indicating whether
to solve with matrix or its (block-wise) adjoint.
name: A name for the operation (optional).
Returns:
A `Tensor`. Has the same type as matrix, and shape is `[..., M, N]`.
"""
with ops.name_scope(name, 'triangular_solve', [matrix, rhs]):
return gen_linalg_ops.matrix_triangular_solve(
matrix, rhs, lower=lower, adjoint=adjoint)
@tf_export(
'linalg.cholesky_solve', v1=['linalg.cholesky_solve', 'cholesky_solve'])
@deprecation.deprecated_endpoints('cholesky_solve')