Add Broadcasted Matrix Triangular Solve.

Add NumPy-style broadcasting in the batch dimensions for the tf.linalg.triangular_solve op. The last two dimensions of both operands form the matrix dimensions; the dimensions beyond these are broadcast to a common output shape using the standard NumPy broadcasting rules (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).

Note: This implementation differs from NumPy's behavior in that vectors (rank-1 Tensors) are not pr...

PiperOrigin-RevId: 289978628
Change-Id: I66e41e292e57e6df8111745cbe47ccffacb53edc
This commit is contained in:
parent a5218435ec
commit c8e8ba577e
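As context for the change, a minimal sketch (not part of the diff) of the batch broadcasting this commit enables, assuming the change is applied; the shapes are hypothetical:

```python
import tensorflow as tf

# matrix has batch shape [2, 1]; rhs has batch shape [1, 3]. Under NumPy
# broadcasting rules the common batch shape is [2, 3], so the result has
# shape [2, 3, 4, 5]: each 4x4 triangular system is solved against the
# matching broadcast right-hand side.
matrix = tf.eye(4, batch_shape=[2, 1]) * 2.0  # shape [2, 1, 4, 4], invertible
rhs = tf.ones([1, 3, 4, 5])                   # shape [1, 3, 4, 5]
out = tf.linalg.triangular_solve(matrix, rhs, lower=True)
print(out.shape)                              # (2, 3, 4, 5)
```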
@@ -44,17 +44,15 @@ square matrices. If `lower` is `True` then the strictly upper triangular part
of each inner-most matrix is assumed to be zero and not accessed.
If `lower` is `False` then the strictly lower triangular part of each inner-most
matrix is assumed to be zero and not accessed.
`rhs` is a tensor of shape `[..., M, N]`.
`rhs` is a tensor of shape `[..., M, K]`.

The output is a tensor of shape `[..., M, N]`. If `adjoint` is
The output is a tensor of shape `[..., M, K]`. If `adjoint` is
`False` then the innermost matrices in `output` satisfy matrix equations
`matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`.
If `adjoint` is `True` then the innermost matrices in
`output` satisfy matrix equations
`adjoint(matrix[..., i, k]) * output[..., k, j] = rhs[..., i, j]`.

Note, the batch shapes for the inputs only need to broadcast.

Example:

```python
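To make the `adjoint` convention above concrete, a small NumPy check (an illustrative sketch, not part of the change):

```python
import numpy as np

matrix = np.array([[3., 0.], [2., 1.]])   # lower triangular, [M, M]
rhs = np.array([[4.], [2.]])              # [M, K]

# adjoint=False solves  matrix @ output = rhs
out_plain = np.linalg.solve(matrix, rhs)
assert np.allclose(matrix @ out_plain, rhs)

# adjoint=True solves  adjoint(matrix) @ output = rhs
out_adj = np.linalg.solve(matrix.conj().T, rhs)
assert np.allclose(matrix.conj().T @ out_adj, rhs)
```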
@@ -1,4 +1,10 @@
op {
  graph_op_name: "MatrixTriangularSolve"
  visibility: HIDDEN
  endpoint {
    name: "linalg.triangular_solve"
  }
  endpoint {
    name: "matrix_triangular_solve"
    deprecation_version: 2
  }
}
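Both endpoints above dispatch to the same MatrixTriangularSolve op; a rough sketch of the two call paths (endpoint names taken from the pbtxt above, TF 1.x compatibility API assumed):

```python
import tensorflow.compat.v1 as tf

a = tf.constant([[2., 0.], [1., 1.]])
b = tf.constant([[2.], [3.]])

# Same op under two names; per `deprecation_version: 2`, the bare
# `matrix_triangular_solve` endpoint is only kept on the TF1 API surface.
x1 = tf.linalg.triangular_solve(a, b, lower=True)
x2 = tf.matrix_triangular_solve(a, b, lower=True)
```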
@@ -3588,14 +3588,10 @@ tf_kernel_library(

tf_kernel_library(
    name = "matrix_triangular_solve_op",
    hdrs = ["matrix_triangular_solve_op_impl.h"],
    prefix = "matrix_triangular_solve_op",
    deps = LINALG_DEPS + if_cuda([
        "//tensorflow/core/platform/default/build_config:cublas_plugin",
    ]) + [
        ":fill_functor",
        "//tensorflow/core:stream_executor",
    ],
    ]),
)

tf_kernel_library(
@@ -4183,25 +4179,6 @@ tf_cuda_cc_test(
    ],
)

tf_cuda_cc_test(
    name = "matrix_triangular_solve_op_test",
    size = "small",
    srcs = ["matrix_triangular_solve_op_test.cc"],
    deps = [
        ":broadcast_to_op",
        ":matrix_triangular_solve_op",
        ":ops_testutil",
        ":ops_util",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ],
)

tf_cuda_cc_test(
    name = "scan_ops_test",
    size = "small",
@ -900,106 +900,6 @@ static inline Status MatInvBatchedImpl(
|
||||
|
||||
TF_CALL_LAPACK_TYPES(MATINV_BATCHED_INSTANCE);
|
||||
|
||||
template <typename Scalar, typename SolverFnT>
|
||||
static inline Status TrsmImpl(SolverFnT solver, cublasHandle_t cublas_handle,
|
||||
cublasSideMode_t side, cublasFillMode_t uplo,
|
||||
cublasOperation_t trans, cublasDiagType_t diag,
|
||||
int m, int n,
|
||||
const Scalar* alpha, /* host or device pointer */
|
||||
const Scalar* A, int lda, Scalar* B, int ldb) {
|
||||
mutex_lock lock(handle_map_mutex);
|
||||
using CudaScalar = typename CUDAComplexT<Scalar>::type;
|
||||
TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, side, uplo, trans, diag, m, n,
|
||||
reinterpret_cast<const CudaScalar*>(alpha),
|
||||
reinterpret_cast<const CudaScalar*>(A), lda,
|
||||
reinterpret_cast<CudaScalar*>(B), ldb));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define TRSM_INSTANCE(Scalar, type_prefix) \
|
||||
template <> \
|
||||
Status CudaSolver::Trsm<Scalar>( \
|
||||
cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, \
|
||||
cublasDiagType_t diag, int m, int n, \
|
||||
const Scalar* alpha, /* host or device pointer */ \
|
||||
const Scalar* A, int lda, Scalar* B, int ldb) { \
|
||||
return TrsmImpl(BLAS_SOLVER_FN(trsm, type_prefix), cublas_handle_, side, \
|
||||
uplo, trans, diag, m, n, alpha, A, lda, B, ldb); \
|
||||
}
|
||||
|
||||
TF_CALL_LAPACK_TYPES(TRSM_INSTANCE);
|
||||
|
||||
template <typename Scalar, typename SolverFnT>
|
||||
static inline Status TrsvImpl(SolverFnT solver, cublasHandle_t cublas_handle,
|
||||
cublasFillMode_t uplo, cublasOperation_t trans,
|
||||
cublasDiagType_t diag, int n, const Scalar* A,
|
||||
int lda, Scalar* x, int incx) {
|
||||
mutex_lock lock(handle_map_mutex);
|
||||
using CudaScalar = typename CUDAComplexT<Scalar>::type;
|
||||
TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, uplo, trans, diag, n,
|
||||
reinterpret_cast<const CudaScalar*>(A), lda,
|
||||
reinterpret_cast<CudaScalar*>(x), incx));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define TRSV_INSTANCE(Scalar, type_prefix) \
|
||||
template <> \
|
||||
Status CudaSolver::Trsv<Scalar>( \
|
||||
cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, \
|
||||
int n, const Scalar* A, int lda, Scalar* x, int incx) { \
|
||||
return TrsvImpl(BLAS_SOLVER_FN(trsv, type_prefix), cublas_handle_, uplo, \
|
||||
trans, diag, n, A, lda, x, incx); \
|
||||
}
|
||||
|
||||
TF_CALL_LAPACK_TYPES(TRSV_INSTANCE);
|
||||
|
||||
template <typename Scalar, typename SolverFnT>
|
||||
static inline Status TrsmBatchedImpl(
|
||||
SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
|
||||
cublasHandle_t cublas_handle, cublasSideMode_t side, cublasFillMode_t uplo,
|
||||
cublasOperation_t trans, cublasDiagType_t diag, int m, int n,
|
||||
const Scalar* alpha, const Scalar* const host_a_dev_ptrs[], int lda,
|
||||
Scalar* host_b_dev_ptrs[], int ldb, int batch_size) {
|
||||
mutex_lock lock(handle_map_mutex);
|
||||
using CudaScalar = typename CUDAComplexT<Scalar>::type;
|
||||
ScratchSpace<uint8> dev_a_dev_ptrs =
|
||||
cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
|
||||
/* on_host */ false);
|
||||
ScratchSpace<uint8> dev_b_dev_ptrs =
|
||||
cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
|
||||
/* on_host */ false);
|
||||
if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
|
||||
host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
|
||||
return errors::Internal("TrsmBatched: failed to copy pointers to device");
|
||||
}
|
||||
if (!CopyHostToDevice(context, dev_b_dev_ptrs.mutable_data() /* dest */,
|
||||
host_b_dev_ptrs /* source */, dev_b_dev_ptrs.bytes())) {
|
||||
return errors::Internal("TrsmBatched: failed to copy pointers to device");
|
||||
}
|
||||
TF_RETURN_IF_CUBLAS_ERROR(
|
||||
solver(cublas_handle, side, uplo, trans, diag, m, n,
|
||||
reinterpret_cast<const CudaScalar*>(alpha),
|
||||
reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()),
|
||||
lda, reinterpret_cast<CudaScalar**>(dev_b_dev_ptrs.mutable_data()),
|
||||
ldb, batch_size));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define TRSM_BATCHED_INSTANCE(Scalar, type_prefix) \
|
||||
template <> \
|
||||
Status CudaSolver::TrsmBatched( \
|
||||
cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, \
|
||||
cublasDiagType_t diag, int m, int n, const Scalar* alpha, \
|
||||
const Scalar* const dev_Aarray[], int lda, Scalar* dev_Barray[], \
|
||||
int ldb, int batch_size) { \
|
||||
return TrsmBatchedImpl(BLAS_SOLVER_FN(trsmBatched, type_prefix), this, \
|
||||
context_, cublas_handle_, side, uplo, trans, diag, \
|
||||
m, n, alpha, dev_Aarray, lda, dev_Barray, ldb, \
|
||||
batch_size); \
|
||||
}
|
||||
|
||||
TF_CALL_LAPACK_TYPES(TRSM_BATCHED_INSTANCE);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
|
@@ -333,28 +333,6 @@ class CudaSolver {
                     int lda, Scalar* dev_S, Scalar* dev_U, int ldu,
                     Scalar* dev_V, int ldv, int* dev_lapack_info,
                     int batch_size);
  // Triangular solve
  // Returns Status::OK() if the kernel was launched successfully.
  // See https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-trsm
  template <typename Scalar>
  Status Trsm(cublasSideMode_t side, cublasFillMode_t uplo,
              cublasOperation_t trans, cublasDiagType_t diag, int m, int n,
              const Scalar* alpha, const Scalar* A, int lda, Scalar* B,
              int ldb);

  template <typename Scalar>
  Status Trsv(cublasFillMode_t uplo, cublasOperation_t trans,
              cublasDiagType_t diag, int n, const Scalar* A, int lda, Scalar* x,
              int incx);

  // See
  // https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-trsmbatched
  template <typename Scalar>
  Status TrsmBatched(cublasSideMode_t side, cublasFillMode_t uplo,
                     cublasOperation_t trans, cublasDiagType_t diag, int m,
                     int n, const Scalar* alpha,
                     const Scalar* const dev_Aarray[], int lda,
                     Scalar* dev_Barray[], int ldb, int batch_size);

 private:
  OpKernelContext* context_;  // not owned.
258
tensorflow/core/kernels/matrix_triangular_solve_op.cc
Normal file
@ -0,0 +1,258 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// See docs in ../ops/linalg_ops.cc.
|
||||
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
namespace {
|
||||
template <typename Scalar>
|
||||
se::DeviceMemory<Scalar> AsDeviceMemory(const Scalar* gpu_memory) {
|
||||
se::DeviceMemoryBase wrapped(const_cast<Scalar*>(gpu_memory));
|
||||
se::DeviceMemory<Scalar> typed(wrapped);
|
||||
return typed;
|
||||
}
|
||||
} // namespace
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
template <class Scalar>
|
||||
class MatrixTriangularSolveOp : public LinearAlgebraOp<Scalar> {
|
||||
public:
|
||||
INHERIT_LINALG_TYPEDEFS(Scalar);
|
||||
|
||||
explicit MatrixTriangularSolveOp(OpKernelConstruction* context)
|
||||
: Base(context), lower_(true), adjoint_(false) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
|
||||
}
|
||||
|
||||
void ValidateInputMatrixShapes(
|
||||
OpKernelContext* context,
|
||||
const TensorShapes& input_matrix_shapes) const final {
|
||||
Base::ValidateSquareSolver(context, input_matrix_shapes);
|
||||
}
|
||||
|
||||
TensorShapes GetOutputMatrixShapes(
|
||||
const TensorShapes& input_matrix_shapes) const final {
|
||||
return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1),
|
||||
input_matrix_shapes[1].dim_size(1)})});
|
||||
}
|
||||
|
||||
int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
|
||||
double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0));
|
||||
double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
|
||||
double cost = rows * rows * num_rhss *
|
||||
(Eigen::TensorOpCost::AddCost<Scalar>() +
|
||||
Eigen::TensorOpCost::MulCost<Scalar>());
|
||||
return cost >= static_cast<double>(kint64max) ? kint64max
|
||||
: static_cast<int64>(cost);
|
||||
}
|
||||
|
||||
bool EnableInputForwarding() const final { return false; }
|
||||
|
||||
void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
|
||||
MatrixMaps* outputs) final {
|
||||
const ConstMatrixMap& matrix = inputs[0];
|
||||
const ConstMatrixMap& rhs = inputs[1];
|
||||
MatrixMap& output = outputs->at(0);
|
||||
|
||||
if (matrix.rows() == 0 || rhs.rows() == 0 || rhs.cols() == 0) {
|
||||
// To be consistent with the MatrixInverse op, we define the solution for
|
||||
// an empty set of equation as the empty matrix.
|
||||
return;
|
||||
}
|
||||
const RealScalar min_abs_pivot = matrix.diagonal().cwiseAbs().minCoeff();
|
||||
OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
|
||||
errors::InvalidArgument("Input matrix is not invertible."));
|
||||
if (lower_) {
|
||||
auto triangle = matrix.template triangularView<Eigen::Lower>();
|
||||
if (adjoint_) {
|
||||
output.noalias() = triangle.adjoint().solve(rhs);
|
||||
} else {
|
||||
output.noalias() = triangle.solve(rhs);
|
||||
}
|
||||
} else {
|
||||
auto triangle = matrix.template triangularView<Eigen::Upper>();
|
||||
if (adjoint_) {
|
||||
output.noalias() = triangle.adjoint().solve(rhs);
|
||||
} else {
|
||||
output.noalias() = triangle.solve(rhs);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
bool lower_;
|
||||
bool adjoint_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(MatrixTriangularSolveOp);
|
||||
};
|
||||
|
||||
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
|
||||
(MatrixTriangularSolveOp<float>), float);
|
||||
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
|
||||
(MatrixTriangularSolveOp<double>), double);
|
||||
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
|
||||
(MatrixTriangularSolveOp<complex64>), complex64);
|
||||
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
|
||||
(MatrixTriangularSolveOp<complex128>), complex128);
|
||||
REGISTER_LINALG_OP_CPU("BatchMatrixTriangularSolve",
|
||||
(MatrixTriangularSolveOp<float>), float);
|
||||
REGISTER_LINALG_OP_CPU("BatchMatrixTriangularSolve",
|
||||
(MatrixTriangularSolveOp<double>), double);
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
// TODO(rmlarsen): Re-factor to
|
||||
// 1. Enable buffer forwarding from rhs->out.
|
||||
// 2. Save Memcpy when buffer forwarding is used.
|
||||
// 3. Copy entire rhs in a single Memcpy when forwarding is not used.
|
||||
template <class Scalar>
|
||||
class MatrixTriangularSolveOpGPU : public LinearAlgebraOp<Scalar> {
|
||||
public:
|
||||
INHERIT_LINALG_TYPEDEFS(Scalar);
|
||||
|
||||
explicit MatrixTriangularSolveOpGPU(OpKernelConstruction* context)
|
||||
: Base(context), lower_(true), adjoint_(false) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
|
||||
}
|
||||
|
||||
void ValidateInputMatrixShapes(
|
||||
OpKernelContext* context,
|
||||
const TensorShapes& input_matrix_shapes) const final {
|
||||
Base::ValidateSquareSolver(context, input_matrix_shapes);
|
||||
}
|
||||
|
||||
TensorShapes GetOutputMatrixShapes(
|
||||
const TensorShapes& input_matrix_shapes) const final {
|
||||
return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1),
|
||||
input_matrix_shapes[1].dim_size(1)})});
|
||||
}
|
||||
|
||||
int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
|
||||
double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0));
|
||||
double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
|
||||
double cost = rows * rows * num_rhss *
|
||||
(Eigen::TensorOpCost::AddCost<Scalar>() +
|
||||
Eigen::TensorOpCost::MulCost<Scalar>());
|
||||
return cost >= static_cast<double>(kint64max) ? kint64max
|
||||
: static_cast<int64>(cost);
|
||||
}
|
||||
|
||||
bool EnableInputForwarding() const final { return false; }
|
||||
|
||||
void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
|
||||
MatrixMaps* outputs) final {
|
||||
const ConstMatrixMap& matrix = inputs[0];
|
||||
const ConstMatrixMap& rhs = inputs[1];
|
||||
MatrixMap& output = outputs->at(0);
|
||||
|
||||
if (matrix.rows() == 0 || rhs.rows() == 0 || rhs.cols() == 0) {
|
||||
// To be consistent with the MatrixInverse op, we define the solution for
|
||||
// an empty set of equation as the empty matrix.
|
||||
return;
|
||||
}
|
||||
|
||||
auto matrix_ptr = AsDeviceMemory(matrix.data());
|
||||
auto rhs_ptr = AsDeviceMemory(rhs.data());
|
||||
auto out_ptr = AsDeviceMemory(output.data());
|
||||
|
||||
auto* stream = context->op_device_context()->stream();
|
||||
uint64 rhs_elems = rhs.rows() * rhs.cols();
|
||||
bool copy_status =
|
||||
stream->ThenMemcpyD2D(&out_ptr, rhs_ptr, sizeof(Scalar) * rhs_elems)
|
||||
.ok();
|
||||
if (!copy_status) {
|
||||
context->SetStatus(
|
||||
errors::Internal("Failed to copy rhs into output before solve"));
|
||||
}
|
||||
|
||||
// Cublas does
|
||||
// output = matrix \ rhs
|
||||
// where matrix, rhs and output are assumed to be in column major.
|
||||
// We want the output to be in row-major, so we can compute
|
||||
// output' = rhs' / matrix' (' stands for transpose)
|
||||
// Upper/lower needs to be swapped for this.
|
||||
|
||||
se::blas::UpperLower upper_lower_matrix;
|
||||
se::blas::Transpose transpose_matrix;
|
||||
if (lower_) {
|
||||
upper_lower_matrix = se::blas::UpperLower::kUpper;
|
||||
} else {
|
||||
upper_lower_matrix = se::blas::UpperLower::kLower;
|
||||
}
|
||||
if (adjoint_) {
|
||||
transpose_matrix = se::blas::Transpose::kConjugateTranspose;
|
||||
} else {
|
||||
transpose_matrix = se::blas::Transpose::kNoTranspose;
|
||||
}
|
||||
uint64 leading_dim_matrix = matrix.cols();
|
||||
uint64 leading_dim_output = output.cols();
|
||||
uint64 colmajor_rows = output.cols();
|
||||
uint64 colmajor_cols = output.rows();
|
||||
bool blas_launch_status =
|
||||
stream
|
||||
->ThenBlasTrsm(
|
||||
se::blas::Side::kRight /*side*/, upper_lower_matrix /*uplo*/,
|
||||
transpose_matrix /*trans*/,
|
||||
se::blas::Diagonal::kNonUnit /*diag*/, colmajor_rows /*m*/,
|
||||
colmajor_cols /*n*/, Scalar(1.0) /*alpha*/, matrix_ptr,
|
||||
leading_dim_matrix /*lda*/, &out_ptr,
|
||||
leading_dim_output /*ldb*/)
|
||||
.ok();
|
||||
if (!blas_launch_status) {
|
||||
context->SetStatus(errors::Internal("Blas TRSM launch failed"));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
bool lower_;
|
||||
bool adjoint_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(MatrixTriangularSolveOpGPU);
|
||||
};
|
||||
|
||||
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
|
||||
(MatrixTriangularSolveOpGPU<float>), float);
|
||||
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
|
||||
(MatrixTriangularSolveOpGPU<double>), double);
|
||||
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
|
||||
(MatrixTriangularSolveOpGPU<complex64>), complex64);
|
||||
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
|
||||
(MatrixTriangularSolveOpGPU<complex128>), complex128);
|
||||
REGISTER_LINALG_OP_GPU("BatchMatrixTriangularSolve",
|
||||
(MatrixTriangularSolveOpGPU<float>), float);
|
||||
REGISTER_LINALG_OP_GPU("BatchMatrixTriangularSolve",
|
||||
(MatrixTriangularSolveOpGPU<double>), double);
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
} // namespace tensorflow
|
@ -1,28 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/matrix_triangular_solve_op_impl.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
TF_CALL_complex64(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU);
|
||||
TF_CALL_complex128(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU);
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
TF_CALL_complex64(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU);
|
||||
TF_CALL_complex128(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU);
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
} // namespace tensorflow
|
@ -1,431 +0,0 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// See docs in ../ops/linalg_ops.cc.
|
||||
//
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_
|
||||
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/kernels/fill_functor.h"
|
||||
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/matmul_bcast.h"
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
||||
#include "tensorflow/core/kernels/transpose_functor.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
template <typename Scalar>
|
||||
se::DeviceMemory<Scalar> AsDeviceMemory(const Scalar* gpu_memory) {
|
||||
se::DeviceMemoryBase wrapped(const_cast<Scalar*>(gpu_memory));
|
||||
se::DeviceMemory<Scalar> typed(wrapped);
|
||||
return typed;
|
||||
}
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
// Sequential batch matrix triangular solve kernel that calls Eigen's
|
||||
// matrix triangular solve.
|
||||
template <typename Scalar>
|
||||
struct SequentialMatrixTriangularSolveKernel {
|
||||
using Matrix =
|
||||
Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
|
||||
using ConstMatrixMap = Eigen::Map<const Matrix>;
|
||||
using MatrixMap = Eigen::Map<Matrix>;
|
||||
using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
|
||||
|
||||
static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t,
|
||||
int slice) {
|
||||
return ConstMatrixMap(
|
||||
t.flat<Scalar>().data() + slice * t.dim_size(1) * t.dim_size(2),
|
||||
t.dim_size(1), t.dim_size(2));
|
||||
}
|
||||
|
||||
static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) {
|
||||
return MatrixMap(
|
||||
t->flat<Scalar>().data() + slice * t->dim_size(1) * t->dim_size(2),
|
||||
t->dim_size(1), t->dim_size(2));
|
||||
}
|
||||
|
||||
static void Run(const Tensor& in_x, const Tensor& in_y, bool lower,
|
||||
bool adjoint, const MatMulBCast& bcast, Tensor* out,
|
||||
int start, int limit) {
|
||||
const bool should_bcast = bcast.IsBroadcastingRequired();
|
||||
const auto& x_batch_indices = bcast.x_batch_indices();
|
||||
const auto& y_batch_indices = bcast.y_batch_indices();
|
||||
for (int64 i = start; i < limit; ++i) {
|
||||
const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
|
||||
const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
|
||||
auto matrix = ConstTensorSliceToEigenMatrix(in_x, x_batch_index);
|
||||
auto rhs = ConstTensorSliceToEigenMatrix(in_y, y_batch_index);
|
||||
auto output = TensorSliceToEigenMatrix(out, i);
|
||||
if (lower) {
|
||||
auto triangle = matrix.template triangularView<Eigen::Lower>();
|
||||
if (adjoint) {
|
||||
output.noalias() = triangle.adjoint().solve(rhs);
|
||||
} else {
|
||||
output.noalias() = triangle.solve(rhs);
|
||||
}
|
||||
} else {
|
||||
auto triangle = matrix.template triangularView<Eigen::Upper>();
|
||||
if (adjoint) {
|
||||
output.noalias() = triangle.adjoint().solve(rhs);
|
||||
} else {
|
||||
output.noalias() = triangle.solve(rhs);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Device, typename Scalar>
|
||||
struct LaunchBatchMatrixTriangularSolve;
|
||||
|
||||
template <typename Scalar>
|
||||
struct LaunchBatchMatrixTriangularSolve<CPUDevice, Scalar> {
|
||||
static void Launch(OpKernelContext* context, const Tensor& in_x,
|
||||
const Tensor& in_y, bool adjoint, bool lower,
|
||||
const MatMulBCast& bcast, Tensor* out) {
|
||||
// Number of matrix triangular solves i.e. size of the batch.
|
||||
const int64 batch_size = bcast.output_batch_size();
|
||||
const int64 cost_per_unit =
|
||||
in_x.dim_size(1) * in_x.dim_size(1) * in_y.dim_size(2) / 2;
|
||||
auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
|
||||
|
||||
using Matrix =
|
||||
Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
|
||||
using ConstMatrixMap = Eigen::Map<const Matrix>;
|
||||
using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
|
||||
// Check diagonal before doing any solves.
|
||||
auto matrix = ConstMatrixMap(in_x.flat<Scalar>().data(), in_x.dim_size(1),
|
||||
in_x.dim_size(2));
|
||||
const RealScalar min_abs_pivot = matrix.diagonal().cwiseAbs().minCoeff();
|
||||
OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
|
||||
errors::InvalidArgument("Input matrix is not invertible."));
|
||||
|
||||
Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
|
||||
cost_per_unit,
|
||||
[&in_x, &in_y, adjoint, lower, &bcast, out](int start, int limit) {
|
||||
SequentialMatrixTriangularSolveKernel<Scalar>::Run(
|
||||
in_x, in_y, lower, adjoint, bcast, out, start, limit);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Device, typename Scalar>
|
||||
class BaseMatrixTriangularSolveOp : public OpKernel {
|
||||
public:
|
||||
explicit BaseMatrixTriangularSolveOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
|
||||
}
|
||||
|
||||
~BaseMatrixTriangularSolveOp() override {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor& in0 = ctx->input(0);
|
||||
const Tensor& in1 = ctx->input(1);
|
||||
|
||||
ValidateInputTensors(ctx, in0, in1);
|
||||
|
||||
MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
|
||||
OP_REQUIRES(
|
||||
ctx, bcast.IsValid(),
|
||||
errors::InvalidArgument(
|
||||
"In[0] and In[1] must have compatible batch dimensions: ",
|
||||
in0.shape().DebugString(), " vs. ", in1.shape().DebugString()));
|
||||
|
||||
TensorShape out_shape = bcast.output_batch_shape();
|
||||
auto batch_size = bcast.output_batch_size();
|
||||
auto d0 = in0.dim_size(in0.dims() - 2);
|
||||
auto d1 = in0.dim_size(in0.dims() - 1);
|
||||
Tensor in0_reshaped;
|
||||
OP_REQUIRES(
|
||||
ctx,
|
||||
in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})),
|
||||
errors::Internal("Failed to reshape In[0] from ",
|
||||
in0.shape().DebugString()));
|
||||
auto d2 = in1.dim_size(in1.dims() - 2);
|
||||
auto d3 = in1.dim_size(in1.dims() - 1);
|
||||
Tensor in1_reshaped;
|
||||
OP_REQUIRES(
|
||||
ctx,
|
||||
in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
|
||||
errors::Internal("Failed to reshape In[1] from ",
|
||||
in1.shape().DebugString()));
|
||||
if (adjoint_) std::swap(d0, d1);
|
||||
OP_REQUIRES(ctx, d1 == d2,
|
||||
errors::InvalidArgument(
|
||||
"In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ",
|
||||
in0.shape().DebugString(), " ", in1.shape().DebugString(),
|
||||
" ", lower_, " ", adjoint_));
|
||||
out_shape.AddDim(d0);
|
||||
out_shape.AddDim(d3);
|
||||
Tensor* out = nullptr;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
|
||||
if (out->NumElements() == 0) {
|
||||
return;
|
||||
}
|
||||
Tensor out_reshaped;
|
||||
OP_REQUIRES(ctx,
|
||||
out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})),
|
||||
errors::Internal("Failed to reshape output from ",
|
||||
out->shape().DebugString()));
|
||||
LaunchBatchMatrixTriangularSolve<Device, Scalar>::Launch(
|
||||
ctx, in0_reshaped, in1_reshaped, adjoint_, lower_, bcast,
|
||||
&out_reshaped);
|
||||
}
|
||||
|
||||
private:
|
||||
virtual void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
|
||||
const Tensor& in1) = 0;
|
||||
bool lower_;
|
||||
bool adjoint_;
|
||||
};
|
||||
|
||||
template <class Device, class Scalar>
|
||||
class MatrixTriangularSolveOp
|
||||
: public BaseMatrixTriangularSolveOp<Device, Scalar> {
|
||||
public:
|
||||
explicit MatrixTriangularSolveOp(OpKernelConstruction* context)
|
||||
: BaseMatrixTriangularSolveOp<Device, Scalar>(context) {}
|
||||
|
||||
~MatrixTriangularSolveOp() override {}
|
||||
|
||||
private:
|
||||
void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
|
||||
const Tensor& in1) override {
|
||||
// Disallow broadcasting support. Ensure that all batch dimensions of the
|
||||
// input tensors match.
|
||||
OP_REQUIRES(ctx, in0.dims() == in1.dims(),
|
||||
errors::InvalidArgument("In[0] and In[1] has different ndims: ",
|
||||
in0.shape().DebugString(), " vs. ",
|
||||
in1.shape().DebugString()));
|
||||
const int ndims = in0.dims();
|
||||
OP_REQUIRES(
|
||||
ctx, ndims >= 2,
|
||||
errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims));
|
||||
for (int i = 0; i < ndims - 2; ++i) {
|
||||
OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
|
||||
errors::InvalidArgument(
|
||||
"In[0].dim(", i, ") and In[1].dim(", i,
|
||||
") must be the same: ", in0.shape().DebugString(), " vs ",
|
||||
in1.shape().DebugString()));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <class Device, class Scalar>
|
||||
class MatrixTriangularSolveOpV2
|
||||
: public BaseMatrixTriangularSolveOp<Device, Scalar> {
|
||||
public:
|
||||
explicit MatrixTriangularSolveOpV2(OpKernelConstruction* context)
|
||||
: BaseMatrixTriangularSolveOp<Device, Scalar>(context) {}
|
||||
|
||||
~MatrixTriangularSolveOpV2() override {}
|
||||
|
||||
private:
|
||||
void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
|
||||
const Tensor& in1) override {
|
||||
OP_REQUIRES(
|
||||
ctx, in0.dims() >= 2,
|
||||
errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims()));
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, in1.dims() >= 2,
|
||||
errors::InvalidArgument("In[0] ndims must be >= 2: ", in1.dims()));
|
||||
}
|
||||
};
|
||||
|
||||
#define REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU(TYPE) \
|
||||
REGISTER_KERNEL_BUILDER(Name("MatrixTriangularSolve") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<TYPE>("T"), \
|
||||
MatrixTriangularSolveOpV2<CPUDevice, TYPE>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("BatchMatrixTriangularSolve") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<TYPE>("T"), \
|
||||
MatrixTriangularSolveOpV2<CPUDevice, TYPE>);
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
template <typename Scalar>
|
||||
struct LaunchBatchMatrixTriangularSolve<GPUDevice, Scalar> {
|
||||
static void Launch(OpKernelContext* context, const Tensor& in_x,
|
||||
const Tensor& in_y, bool adjoint, bool lower,
|
||||
const MatMulBCast& bcast, Tensor* out) {
|
||||
auto* stream = context->op_device_context()->stream();
|
||||
|
||||
const uint64 m = in_x.dim_size(1);
|
||||
const uint64 n = out->dim_size(2);
|
||||
|
||||
// Do a memcpy when we don't need to broadcast.
|
||||
if (!bcast.IsBroadcastingRequired() || out->shape() == in_y.shape()) {
|
||||
auto src_device_mem = AsDeviceMemory(in_y.template flat<Scalar>().data());
|
||||
auto dst_device_mem = AsDeviceMemory(out->template flat<Scalar>().data());
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
stream
|
||||
->ThenMemcpyD2D(&dst_device_mem, src_device_mem,
|
||||
bcast.y_batch_size() * m * n * sizeof(Scalar))
|
||||
.ok(),
|
||||
errors::Internal("MatrixTriangularSolveOpV2: failed to copy rhs "
|
||||
"from device"));
|
||||
} else {
|
||||
std::vector<Scalar*> out_ptrs;
|
||||
std::vector<const Scalar*> b_tmp_ptrs;
|
||||
auto* b_base_ptr = in_y.template flat<Scalar>().data();
|
||||
const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
|
||||
for (int64 i = 0; i < bcast.y_batch_size(); ++i) {
|
||||
b_tmp_ptrs.push_back(b_base_ptr + i * m * n);
|
||||
}
|
||||
for (int64 i = 0; i < bcast.output_batch_size(); ++i) {
|
||||
auto src_device_mem = AsDeviceMemory(b_tmp_ptrs[b_batch_indices[i]]);
|
||||
auto dst_device_mem =
|
||||
AsDeviceMemory(out->template flat<Scalar>().data() + i * m * n);
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
stream
|
||||
->ThenMemcpyD2D(&dst_device_mem, src_device_mem,
|
||||
m * n * sizeof(Scalar))
|
||||
.ok(),
|
||||
errors::Internal("MatrixTriangularSolveOpV2: failed to copy rhs "
|
||||
"from device"));
|
||||
}
|
||||
}
|
||||
|
||||
if (out->NumElements() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
cublasSideMode_t side = CUBLAS_SIDE_RIGHT;
|
||||
cublasFillMode_t uplo;
|
||||
cublasOperation_t trans;
|
||||
cublasDiagType_t diag = CUBLAS_DIAG_NON_UNIT;
|
||||
|
||||
// Cublas does
|
||||
// output = matrix \ rhs
|
||||
// where matrix, rhs and output are assumed to be in column major.
|
||||
// We want the output to be in row-major, so we can compute
|
||||
// output' = rhs' / matrix' (' stands for transpose)
|
||||
// Upper/lower needs to be swapped for this.
|
||||
|
||||
uplo = lower ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
|
||||
trans = adjoint ? CUBLAS_OP_C : CUBLAS_OP_N;
|
||||
auto solver = absl::make_unique<CudaSolver>(context);
|
||||
|
||||
const uint64 leading_dim_matrix = m;
|
||||
const uint64 leading_dim_output = n;
|
||||
const uint64 colmajor_rows = n;
|
||||
const uint64 colmajor_cols = m;
|
||||
|
||||
const int64 batch_size = bcast.output_batch_size();
|
||||
std::vector<const Scalar*> a_ptrs;
|
||||
std::vector<Scalar*> out_ptrs;
|
||||
std::vector<const Scalar*> a_tmp_ptrs;
|
||||
a_ptrs.reserve(batch_size);
|
||||
out_ptrs.reserve(batch_size);
|
||||
a_tmp_ptrs.reserve(bcast.x_batch_size());
|
||||
auto* a_base_ptr = in_x.template flat<Scalar>().data();
|
||||
auto* out_base_ptr = out->template flat<Scalar>().data();
|
||||
|
||||
if (!bcast.IsBroadcastingRequired()) {
|
||||
for (int64 i = 0; i < batch_size; ++i) {
|
||||
a_ptrs.push_back(a_base_ptr + i * m * m);
|
||||
out_ptrs.push_back(out_base_ptr + i * m * n);
|
||||
}
|
||||
} else {
|
||||
const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
|
||||
for (int64 i = 0; i < bcast.x_batch_size(); ++i) {
|
||||
a_tmp_ptrs.push_back(a_base_ptr + i * m * m);
|
||||
}
|
||||
for (int64 i = 0; i < batch_size; ++i) {
|
||||
a_ptrs.push_back(a_tmp_ptrs[a_batch_indices[i]]);
|
||||
out_ptrs.push_back(out_base_ptr + i * m * n);
|
||||
}
|
||||
}
|
||||
|
||||
typedef Scalar Coefficient;
|
||||
const Scalar alpha = Scalar(1.0);
|
||||
|
||||
// TODO(b/146763573): Consider using Trsv here when the right hand side is
|
||||
// a vector. This will require an explicit transpose since Trsv assumes
|
||||
// CUBLAS_SIDE_LEFT.
|
||||
if (batch_size == 1) {
|
||||
OP_REQUIRES_OK(
|
||||
context,
|
||||
solver->Trsm(side, uplo, trans, diag, colmajor_rows, colmajor_cols,
|
||||
&alpha, a_ptrs[0], leading_dim_matrix /*lda*/,
|
||||
out_ptrs[0], leading_dim_output /*ldb*/));
|
||||
} else {
|
||||
// Heuristic for choosing between batched interface vs. non-batched
|
||||
// interface. This is inspired by matrix_solve_op and can probably be
|
||||
// tuned.
|
||||
// TODO(b/146763573): Tune this heuristic.
|
||||
const int kMaxMatrixSizeToBatchSizeRatio = 128;
|
||||
const bool use_batched_solver =
|
||||
m <= kMaxMatrixSizeToBatchSizeRatio * batch_size;
|
||||
if (use_batched_solver) {
|
||||
OP_REQUIRES_OK(
|
||||
context, solver->TrsmBatched(
|
||||
side, uplo, trans, diag, colmajor_rows, colmajor_cols,
|
||||
&alpha, &a_ptrs[0], leading_dim_matrix /*lda*/,
|
||||
&out_ptrs[0], leading_dim_output /*ldb*/, batch_size));
|
||||
} else {
|
||||
for (int batch = 0; batch < batch_size; ++batch) {
|
||||
OP_REQUIRES_OK(
|
||||
context, solver->Trsm(side, uplo, trans, diag, colmajor_rows,
|
||||
colmajor_cols, &alpha, a_ptrs[batch],
|
||||
leading_dim_matrix /*lda*/, out_ptrs[batch],
|
||||
leading_dim_output /*ldb*/));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#define REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU(TYPE) \
|
||||
REGISTER_KERNEL_BUILDER(Name("MatrixTriangularSolve") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<TYPE>("T"), \
|
||||
MatrixTriangularSolveOpV2<GPUDevice, TYPE>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("BatchMatrixTriangularSolve") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<TYPE>("T"), \
|
||||
MatrixTriangularSolveOpV2<GPUDevice, TYPE>);
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_
|
@ -1,32 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/matrix_triangular_solve_op_impl.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
TF_CALL_float(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU);
|
||||
TF_CALL_double(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU);
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
TF_CALL_float(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU);
|
||||
TF_CALL_double(REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU);
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
} // namespace tensorflow
|
@ -1,165 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/graph/testlib.h"
|
||||
#include "tensorflow/core/kernels/broadcast_to_op.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
Node* BroadcastTo(Graph* g, Node* input, Node* shape) {
|
||||
Node* ret;
|
||||
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BroadcastTo")
|
||||
.Input(input)
|
||||
.Input(shape)
|
||||
.Attr("Tidx", DT_INT64)
|
||||
.Finalize(g, &ret));
|
||||
return ret;
|
||||
}
|
||||
|
||||
Node* MatrixTriangularSolve(Graph* g, Node* in0, Node* in1, bool adjoint) {
|
||||
Node* ret;
|
||||
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatrixTriangularSolve")
|
||||
.Input(in0)
|
||||
.Input(in1)
|
||||
.Attr("lower", true)
|
||||
.Attr("adjoint", adjoint)
|
||||
.Finalize(g, &ret));
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static Graph* MatrixTriangularSolveWithBroadcast(int64 b0, int64 b1, int64 m,
|
||||
int64 n, bool manual_broadcast,
|
||||
DataType type) {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
Tensor in0(type, TensorShape({b0, m, m}));
|
||||
// Set diagonal to non-zero to guarantee invertibility.
|
||||
in0.flat<T>().setRandom();
|
||||
auto matrix = Eigen::Map<
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
|
||||
in0.flat<T>().data(), in0.dim_size(1), in0.dim_size(2));
|
||||
|
||||
matrix.diagonal() =
|
||||
(matrix.diagonal().cwiseAbs().array() + static_cast<T>(0.5));
|
||||
Tensor in1(type, TensorShape({b1, m, n}));
|
||||
in1.flat<T>().setRandom();
|
||||
|
||||
Tensor broadcasted_in0_shape(DT_INT64, TensorShape({3}));
|
||||
Tensor broadcasted_in1_shape(DT_INT64, TensorShape({3}));
|
||||
|
||||
Node* in0_node = nullptr;
|
||||
Node* in1_node = nullptr;
|
||||
if (manual_broadcast) {
|
||||
auto vec0 = broadcasted_in0_shape.vec<int64>();
|
||||
auto vec1 = broadcasted_in1_shape.vec<int64>();
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
vec0(i) = (i == 0 ? std::max(b0, b1) : in0.shape().dim_size(i));
|
||||
vec1(i) = (i == 0 ? std::max(b0, b1) : in1.shape().dim_size(i));
|
||||
}
|
||||
in0_node = BroadcastTo(g, test::graph::Constant(g, in0),
|
||||
test::graph::Constant(g, broadcasted_in0_shape));
|
||||
in1_node = BroadcastTo(g, test::graph::Constant(g, in1),
|
||||
test::graph::Constant(g, broadcasted_in1_shape));
|
||||
} else {
|
||||
in0_node = test::graph::Constant(g, in0);
|
||||
in1_node = test::graph::Constant(g, in1);
|
||||
}
|
||||
|
||||
MatrixTriangularSolve(g, in0_node, in1_node, false);
|
||||
return g;
|
||||
}
|
||||
|
||||
// Macro arguments names: --------------------------------------------------- //
|
||||
// B1: batch size of LHS
|
||||
// B2: batch size of RHS
|
||||
// M: inner dimensions of LHS and RHS, outer dimension of LHS
|
||||
// N: outer dimension of RHS
|
||||
// MB: boolean indicating whether to use manual broadcasting
|
||||
// T: C++ type of scalars (e.g. float, std::complex)
|
||||
// TT: TensorFlow type of scalars (e.g. DT_FLOAT, DT_COMPLEX128
|
||||
// D: Device (e.g. cpu, gpu)
|
||||
#define BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, T, TT, D) \
|
||||
static void \
|
||||
BM_MatrixTriangularSolve##_##B1##_##B2##_##M##_##N##_##MB##_##TT##_##D( \
|
||||
int iters) { \
|
||||
testing::UseRealTime(); \
|
||||
testing::ItemsProcessed(static_cast<int64>(iters) * std::max(B1, B2) * M * \
|
||||
M * N * 2); \
|
||||
test::Benchmark( \
|
||||
#D, MatrixTriangularSolveWithBroadcast<T>(B1, B2, M, N, MB, TT)) \
|
||||
.Run(iters); \
|
||||
} \
|
||||
BENCHMARK( \
|
||||
BM_MatrixTriangularSolve##_##B1##_##B2##_##M##_##N##_##MB##_##TT##_##D);
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#define BM_MatrixTriangularSolve(B1, B2, M, N, MB) \
|
||||
BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, float, DT_FLOAT, cpu); \
|
||||
BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, double, DT_DOUBLE, cpu); \
|
||||
BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, float, DT_FLOAT, gpu); \
|
||||
BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, double, DT_DOUBLE, gpu);
|
||||
|
||||
#else
|
||||
|
||||
#define BM_MatrixTriangularSolve(B1, B2, M, N, MB) \
|
||||
BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, float, DT_FLOAT, cpu); \
|
||||
BM_MatrixTriangularSolveDev(B1, B2, M, N, MB, double, DT_DOUBLE, cpu);
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
// Square matrix triangular solve.
|
||||
BM_MatrixTriangularSolve(32, 32, 512, 512, true);
|
||||
BM_MatrixTriangularSolve(32, 32, 512, 512, false);
|
||||
BM_MatrixTriangularSolve(1, 32, 512, 512, true);
|
||||
BM_MatrixTriangularSolve(1, 32, 512, 512, false);
|
||||
BM_MatrixTriangularSolve(32, 1, 512, 512, true);
|
||||
BM_MatrixTriangularSolve(32, 1, 512, 512, false);
|
||||
BM_MatrixTriangularSolve(128, 128, 512, 512, true);
|
||||
BM_MatrixTriangularSolve(128, 128, 512, 512, false);
|
||||
BM_MatrixTriangularSolve(1, 128, 512, 512, true);
|
||||
BM_MatrixTriangularSolve(1, 128, 512, 512, false);
|
||||
BM_MatrixTriangularSolve(128, 1, 512, 512, true);
|
||||
BM_MatrixTriangularSolve(128, 1, 512, 512, false);
|
||||
BM_MatrixTriangularSolve(1, 128, 1024, 1024, true);
|
||||
BM_MatrixTriangularSolve(1, 128, 1024, 1024, false);
|
||||
BM_MatrixTriangularSolve(128, 1, 1024, 1024, true);
|
||||
BM_MatrixTriangularSolve(128, 1, 1024, 1024, false);
|
||||
|
||||
// Matrix-vector triangular solve.
|
||||
BM_MatrixTriangularSolve(1, 128, 200, 1, true);
|
||||
BM_MatrixTriangularSolve(1, 128, 200, 1, false);
|
||||
BM_MatrixTriangularSolve(128, 1, 200, 1, true);
|
||||
BM_MatrixTriangularSolve(128, 1, 200, 1, false);
|
||||
|
||||
// Matrix-vector triangular solve, large dimension.
|
||||
BM_MatrixTriangularSolve(1, 128, 200, 10000, true);
|
||||
BM_MatrixTriangularSolve(1, 128, 200, 10000, false);
|
||||
BM_MatrixTriangularSolve(128, 1, 200, 10000, true);
|
||||
BM_MatrixTriangularSolve(128, 1, 200, 10000, false);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@@ -84,34 +84,6 @@ Status MatrixSolveShapeFn(InferenceContext* c, bool square) {
  return Status::OK();
}

// The first input is [...,M,M] and second input is [...,M,N].
// Output is [...,M,N].
Status MatrixTriangularSolveShapeFn(InferenceContext* c) {
  ShapeHandle lhs;
  ShapeHandle rhs;
  TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &lhs));
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));

  ShapeHandle lhs_batch_shape;
  ShapeHandle rhs_batch_shape;
  ShapeHandle output_batch_shape;
  // Make the common batch subshape.
  TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
  TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
  TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
      c, lhs_batch_shape, rhs_batch_shape, true, &output_batch_shape));
  DimensionHandle m;
  // lhs and rhs have the same value for m to be compatible.
  TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, -1), c->Dim(rhs, -2), &m));

  ShapeHandle out;
  // Build final shape (batch_shape + m + n) in <out>.
  TF_RETURN_IF_ERROR(
      c->Concatenate(output_batch_shape, c->Matrix(m, c->Dim(rhs, -1)), &out));
  c->set_output(0, out);
  return Status::OK();
}

// Input is [...,N,N]. Outputs are:
// [...,N];[0], if compute_v is false,
// [...,N];[...,N,N], if compute_v is true.
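In Python terms, the shape rule implemented above amounts to the following (an illustrative sketch with a hypothetical helper, not the real shape-inference machinery):

```python
import numpy as np

def triangular_solve_output_shape(lhs_shape, rhs_shape):
    # lhs is [..., M, M], rhs is [..., M, N]; the leading batch dims broadcast.
    *lhs_batch, m1, m2 = lhs_shape
    *rhs_batch, m3, n = rhs_shape
    assert m1 == m2 == m3, "lhs must be square and match the rows of rhs"
    batch = np.broadcast_shapes(tuple(lhs_batch), tuple(rhs_batch))
    return (*batch, m1, n)

assert triangular_solve_output_shape((2, 1, 4, 4), (1, 3, 4, 5)) == (2, 3, 4, 5)
```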
@@ -454,7 +426,7 @@ REGISTER_OP("MatrixTriangularSolve")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn([](InferenceContext* c) {
      return MatrixTriangularSolveShapeFn(c);
      return MatrixSolveShapeFn(c, true /* square */);
    });

REGISTER_OP("MatrixSolveLs")
@ -122,12 +122,14 @@ TEST(LinalgOpsTest, SelfAdjointEigV2_ShapeFn) {
|
||||
"[d0_0,d0_1,d0_2,d0_3|d0_4];[d0_0,d0_1,d0_2,d0_3|d0_4,d0_3|d0_4]");
|
||||
}
|
||||
|
||||
TEST(LinalgOpsTest, MatrixSolve_ShapeFn) {
|
||||
ShapeInferenceTestOp op("MatrixSolve");
|
||||
TEST(LinalgOpsTest, SquareMatrixSolve_ShapeFn) {
|
||||
for (const char* op_name : {"MatrixSolve", "MatrixTriangularSolve"}) {
|
||||
ShapeInferenceTestOp op(op_name);
|
||||
INFER_OK(op, "?;?", "?");
|
||||
INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1];?");
|
||||
INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,2];?");
|
||||
INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[5,?,?];[6]");
|
||||
INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
|
||||
"[5,?,?];[6]");
|
||||
INFER_ERROR("Shapes must be equal rank, but are 0 and 1", op,
|
||||
"[5,?];[6,?,?]");
|
||||
|
||||
@ -147,29 +149,7 @@ TEST(LinalgOpsTest, MatrixSolve_ShapeFn) {
|
||||
INFER_OK(op, "[10,?,?,1];[?,20,1,?]", "[d0_0,d1_1,d0_3|d1_2,d1_3]");
|
||||
INFER_OK(op, "[10,?,1,1];[?,20,?,?]", "[d0_0,d1_1,d0_2,d1_3]");
|
||||
INFER_OK(op, "[10,?,1,1];[?,20,1,?]", "[d0_0,d1_1,d0_2|d0_3|d1_2,d1_3]");
|
||||
}
|
||||
|
||||
TEST(LinalgOpsTest, MatrixTriangularSolve_ShapeFn) {
|
||||
ShapeInferenceTestOp op("MatrixTriangularSolve");
|
||||
INFER_OK(op, "?;?", "?");
|
||||
INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1];?");
|
||||
INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,2];?");
|
||||
INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[5,?,?];[6]");
|
||||
|
||||
// Inputs are [...,M,M] and [...,M,K]. Output is [...,M,K].
|
||||
// First test where ... is empty.
|
||||
INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]");
|
||||
INFER_OK(op, "[?,?];[1,?]", "[d1_0,d1_1]");
|
||||
INFER_OK(op, "[1,?];[1,?]", "[d0_0|d1_0,d1_1]");
|
||||
INFER_OK(op, "[?,1];[1,?]", "[d0_1|d1_0,d1_1]");
|
||||
INFER_OK(op, "[1,1];[?,?]", "[d0_0,d1_1]");
|
||||
INFER_OK(op, "[1,1];[1,?]", "[d0_0|d0_1|d1_0,d1_1]");
|
||||
// Test with ... being 2-d.
|
||||
INFER_OK(op, "[10,?,?,?];[?,20,1,?]", "[d0_0,d1_1,d1_2,d1_3]");
|
||||
INFER_OK(op, "[10,?,1,?];[?,20,1,?]", "[d0_0,d1_1,d0_2|d1_2,d1_3]");
|
||||
INFER_OK(op, "[10,?,?,1];[?,20,1,?]", "[d0_0,d1_1,d0_3|d1_2,d1_3]");
|
||||
INFER_OK(op, "[10,?,1,1];[?,20,?,?]", "[d0_0,d1_1,d0_2,d1_3]");
|
||||
INFER_OK(op, "[10,?,1,1];[?,20,1,?]", "[d0_0,d1_1,d0_2|d0_3|d1_2,d1_3]");
|
||||
}
|
||||
}
|
||||
|
||||
TEST(LinalgOpsTest, MatrixSolveLs_ShapeFn) {
|
||||
|
@@ -756,7 +756,6 @@ cuda_py_test(
    name = "matrix_triangular_solve_op_test",
    size = "small",
    srcs = ["matrix_triangular_solve_op_test.py"],
    shard_count = 2,
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:linalg_ops",
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
@ -67,30 +68,29 @@ class MatrixTriangularSolveOpTest(test.TestCase):
|
||||
else:
|
||||
a_np = a
|
||||
if adjoint:
|
||||
axes = list(range(len(a_np.shape)))
|
||||
axes[-2] = -1
|
||||
axes[-1] = -2
|
||||
a_np = np.conj(np.transpose(a_np, axes=axes))
|
||||
a_np = np.conj(np.transpose(a_np))
|
||||
|
||||
if batch_dims is not None:
|
||||
a = np.tile(a, batch_dims + [1, 1])
|
||||
a_np = np.tile(a_np, batch_dims + [1, 1])
|
||||
b = np.tile(b, batch_dims + [1, 1])
|
||||
|
||||
def broadcast(a, b):
|
||||
b1 = b + np.zeros(a.shape[:-2] + (1, 1), dtype=b.dtype)
|
||||
return a, b1
|
||||
|
||||
a_tf = a
|
||||
b_tf = b
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
if use_placeholder:
|
||||
a_tf = array_ops.placeholder_with_default(a_tf, shape=None)
|
||||
b_tf = array_ops.placeholder_with_default(b_tf, shape=None)
|
||||
a_tf = array_ops.placeholder(a.dtype)
|
||||
b_tf = array_ops.placeholder(b.dtype)
|
||||
tf_ans = linalg_ops.matrix_triangular_solve(
|
||||
a_tf, b_tf, lower=lower, adjoint=adjoint)
|
||||
tf_val = sess.run(tf_ans, feed_dict={a_tf: a, b_tf: b})
|
||||
np_ans = np.linalg.solve(a_np, b)
|
||||
else:
|
||||
a_tf = constant_op.constant(a)
|
||||
b_tf = constant_op.constant(b)
|
||||
tf_ans = linalg_ops.matrix_triangular_solve(
|
||||
a_tf, b_tf, lower=lower, adjoint=adjoint)
|
||||
tf_val = self.evaluate(tf_ans)
|
||||
a_np, b = broadcast(a_np, b)
|
||||
np_ans = np.linalg.solve(a_np, b)
|
||||
self.assertEqual(np_ans.shape, tf_ans.get_shape())
|
||||
self.assertEqual(np_ans.shape, tf_val.shape)
|
||||
self.assertAllClose(np_ans, tf_val)
|
||||
|
||||
@ -136,50 +136,6 @@ class MatrixTriangularSolveOpTest(test.TestCase):
|
||||
# Batch of 3x2x2x2 matrices, 3x2x2x3 right-hand sides.
|
||||
self._verifySolveAllWaysReal(matrix, rhs, batch_dims=[3, 2])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla("XLA cannot broadcast triangular solve.")
|
||||
def testSolveBatchBroadcast(self):
|
||||
# 2 x 2 x 2
|
||||
matrix = np.array([[[1., 0.], [3., 4.]], [[1., 0.], [2., 1.]]])
|
||||
# 2 x 3
|
||||
rhs = np.array([[1., 0., 1.], [0., 1., 1.]])
|
||||
# 2 x 2 x 3
|
||||
self._verifySolveAllWaysReal(matrix, rhs)
|
||||
# 2 x 2 x 2
|
||||
matrix2 = np.array([[[1., 0.], [3., 4.]], [[2., 0.], [1., 6.3]]])
|
||||
# 1 x 2 x 3
|
||||
rhs = np.array([[[1., 0., 1.], [0., 1., 1.]]])
|
||||
# 2 x 2 x 3
|
||||
self._verifySolveAllWaysReal(matrix2, rhs)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla("XLA cannot broadcast triangular solve.")
|
||||
def testSolveBatchBroadcastLargerBatches(self):
|
||||
# 1 x 10 x 10
|
||||
matrix = np.random.uniform(low=1, high=2., size=[1, 10, 10])
|
||||
# 10 x 1
|
||||
rhs = np.random.uniform(size=[10, 1])
|
||||
# 1 x 10 x 1
|
||||
self._verifySolveAllWaysReal(matrix, rhs)
|
||||
|
||||
# 2 x 10 x 10
|
||||
matrix = np.random.uniform(low=1, high=2., size=[2, 10, 10])
|
||||
# 10 x 1
|
||||
rhs = np.random.uniform(size=[10, 1])
|
||||
# 2 x 10 x 1
|
||||
self._verifySolveAllWaysReal(matrix, rhs)
|
||||
|
||||
# 2 x 257 x 257
|
||||
matrix = np.random.uniform(low=1, high=2., size=[2, 257, 257])
|
||||
# Also ensure the matrix is well conditioned by making it diagonally
|
||||
# dominant.
|
||||
np.fill_diagonal(matrix[0, ...], 257 * 2)
|
||||
np.fill_diagonal(matrix[1, ...], 257 * 2)
|
||||
# 257 x 1
|
||||
rhs = np.random.uniform(size=[257, 1])
|
||||
# 2 x 257 x 1
|
||||
self._verifySolveAllWaysReal(matrix, rhs)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSolveBatchComplex(self):
|
||||
if test.is_built_with_rocm():
|
||||
|
@@ -607,7 +607,6 @@ def _MatrixSolveLsGrad(op, grad):
def _MatrixTriangularSolveGrad(op, grad):
  """Gradient for MatrixTriangularSolve."""
  a = op.inputs[0]
  b = op.inputs[1]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]

@@ -621,16 +620,7 @@ def _MatrixTriangularSolveGrad(op, grad):
    grad_a = array_ops.matrix_band_part(grad_a, -1, 0)
  else:
    grad_a = array_ops.matrix_band_part(grad_a, 0, -1)
  # If the static batch shapes are equal, we don't need to unbroadcast.
  if (a.shape.is_fully_defined() and b.shape.is_fully_defined() and
      a.shape[:-2] == b.shape[:-2]):
    return grad_a, grad_b
  a_shape = array_ops.shape(a)
  b_shape = array_ops.shape(b)
  ra, rb = array_ops.broadcast_gradient_args(a_shape[:-2], b_shape[:-2])
  grad_a = array_ops.reshape(math_ops.reduce_sum(grad_a, axis=ra), a_shape)
  grad_b = array_ops.reshape(math_ops.reduce_sum(grad_b, axis=rb), b_shape)
  return grad_a, grad_b
  return (grad_a, grad_b)

@ops.RegisterGradient("SelfAdjointEigV2")
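The unbroadcasting step in the gradient above reduces each gradient back to its operand's original batch shape. A NumPy sketch of the same idea (hypothetical shapes, standing in for broadcast_gradient_args + reduce_sum + reshape):

```python
import numpy as np

# Suppose `a` had batch shape (1,) and `b` had batch shape (4,), so the
# forward output, and hence the incoming gradient, has batch shape (4,).
grad_a_broadcast = np.random.randn(4, 3, 3)

# Axis 0 of `a` was broadcast from 1 to 4, so the gradient contributions
# along that axis are summed and the size-1 dimension is restored.
grad_a = grad_a_broadcast.sum(axis=0).reshape((1, 3, 3))
assert grad_a.shape == (1, 3, 3)
```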
@@ -79,67 +79,6 @@ def _RegularizedGramianCholesky(matrix, l2_regularizer, first_kind):
  return gen_linalg_ops.cholesky(gramian)


@tf_export(
    'linalg.triangular_solve',
    v1=['linalg.triangular_solve', 'matrix_triangular_solve'])
def matrix_triangular_solve(matrix, rhs, lower=True, adjoint=False, name=None):
  """Solve systems of linear equations with upper or lower triangular matrices.

  `matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form
  square matrices. If `lower` is `True` then the strictly upper triangular part
  of each inner-most matrix is assumed to be zero and not accessed. If `lower`
  is `False` then the strictly lower triangular part of each inner-most matrix
  is assumed to be zero and not accessed. `rhs` is a tensor of shape
  `[..., M, N]`.

  The output is a tensor of shape `[..., M, N]`. If `adjoint` is `False` then
  the innermost matrices in output satisfy matrix equations
  `matrix[..., i, k] * output[..., k, j] = rhs[..., i, j]`. If `adjoint` is
  `True` then the innermost matrices in output satisfy matrix equations
  `adjoint(matrix[..., i, k]) * output[..., k, j] = rhs[..., i, j]`.

  Example:

  >>> a = tf.constant([[3, 0, 0, 0],
  ...                  [2, 1, 0, 0],
  ...                  [1, 0, 1, 0],
  ...                  [1, 1, 1, 1]], dtype=tf.float32)

  >>> b = tf.constant([[4], [2], [4], [2]], dtype=tf.float32)
  >>> x = tf.linalg.triangular_solve(a, b, lower=True)
  >>> x
  <tf.Tensor: shape=(4, 1), dtype=float32, numpy=
  array([[ 1.3333334 ],
         [-0.66666675],
         [ 2.6666665 ],
         [-1.3333331 ]], dtype=float32)>
  >>> tf.matmul(a, x)
  <tf.Tensor: shape=(4, 1), dtype=float32, numpy=
  array([[4.],
         [2.],
         [4.],
         [2.]], dtype=float32)>

  Args:
    matrix: A `Tensor`. Must be one of the following types: `float64`,
      `float32`, `half`, `complex64`, `complex128`. Shape is `[..., M, M]`.
    rhs: A `Tensor`. Must have the same type as `matrix`. Shape is `[..., M,
      N]`.
    lower: An optional `bool`. Defaults to `True`. Boolean indicating whether
      the innermost matrices in matrix are lower or upper triangular.
    adjoint: An optional `bool`. Defaults to `False`. Boolean indicating
      whether to solve with matrix or its (block-wise) adjoint.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as matrix, and shape is `[..., M, N]`.
  """
  with ops.name_scope(name, 'triangular_solve', [matrix, rhs]):
    return gen_linalg_ops.matrix_triangular_solve(
        matrix, rhs, lower=lower, adjoint=adjoint)


@tf_export(
    'linalg.cholesky_solve', v1=['linalg.cholesky_solve', 'cholesky_solve'])
@deprecation.deprecated_endpoints('cholesky_solve')