Adding ROCm support for the GpuSparse API (TF wrapper for cuSPARSE/hipSPARSE)
commit 7e8ccbd22b
parent f725b46454

tensorflow/core/kernels/BUILD
@@ -3459,14 +3459,18 @@ tf_kernel_library(

 tf_kernel_library(
     name = "cuda_sparse",
-    srcs = ["cuda_sparse.cc"],
+    srcs = if_cuda(["cuda_sparse.cc"]) + if_rocm(["rocm_sparse.cc"]),
     hdrs = ["cuda_sparse.h"],
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core/kernels:cuda_solvers",
-    ] + if_cuda(["@cub_archive//:cub"]),
+    ] + if_cuda([
+        "//tensorflow/stream_executor/cuda:cusparse_lib",
+        "@cub_archive//:cub",
+    ]) + if_rocm([
+        "@local_config_rocm//rocm:hipsparse",
+    ]),
 )

 LINALG_DEPS = [
tensorflow/core/kernels/cuda_sparse.h

@@ -19,11 +19,13 @@ limitations under the License.

 // This header declares the class GpuSparse, which contains wrappers of
 // cuSparse libraries for use in TensorFlow kernels.

-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

 #include <functional>
 #include <vector>

+#if GOOGLE_CUDA
+
 #include "third_party/gpus/cuda/include/cusparse.h"

 using gpusparseStatus_t = cusparseStatus_t;
@@ -33,6 +35,19 @@ using gpusparseAction_t = cusparseAction_t;
 using gpusparseHandle_t = cusparseHandle_t;
 using gpuStream_t = cudaStream_t;

+#elif TENSORFLOW_USE_ROCM
+
+#include "rocm/include/hipsparse/hipsparse.h"
+
+using gpusparseStatus_t = hipsparseStatus_t;
+using gpusparseOperation_t = hipsparseOperation_t;
+using gpusparseMatDescr_t = hipsparseMatDescr_t;
+using gpusparseAction_t = hipsparseAction_t;
+using gpusparseHandle_t = hipsparseHandle_t;
+using gpuStream_t = hipStream_t;
+
+#endif
+
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_types.h"
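With these aliases, kernel code can be written once against the gpusparse* names on both backends; only the enum constants still differ per platform. A minimal sketch of the pattern (this helper is illustrative, not part of the commit):

    // Resolves to cusparseOperation_t under CUDA and hipsparseOperation_t
    // under ROCm, via the aliases declared in the hunk above.
    inline gpusparseOperation_t NonTransposeOp() {
    #if GOOGLE_CUDA
      return CUSPARSE_OPERATION_NON_TRANSPOSE;
    #elif TENSORFLOW_USE_ROCM
      return HIPSPARSE_OPERATION_NON_TRANSPOSE;
    #endif
    }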
@@ -55,6 +70,8 @@ inline string ConvertGPUSparseErrorToString(const gpusparseStatus_t status) {
     case err:               \
       return STRINGIZE(err);

+#if GOOGLE_CUDA
+
     RETURN_IF_STATUS(CUSPARSE_STATUS_SUCCESS)
     RETURN_IF_STATUS(CUSPARSE_STATUS_NOT_INITIALIZED)
     RETURN_IF_STATUS(CUSPARSE_STATUS_ALLOC_FAILED)
@@ -65,14 +82,34 @@ inline string ConvertGPUSparseErrorToString(const gpusparseStatus_t status) {
     RETURN_IF_STATUS(CUSPARSE_STATUS_INTERNAL_ERROR)
     RETURN_IF_STATUS(CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED)

-#undef RETURN_IF_STATUS
-#undef STRINGIZE
     default:
       return strings::StrCat("Unknown CUSPARSE error: ",
                              static_cast<int>(status));
+#elif TENSORFLOW_USE_ROCM
+
+    RETURN_IF_STATUS(HIPSPARSE_STATUS_SUCCESS)
+    RETURN_IF_STATUS(HIPSPARSE_STATUS_NOT_INITIALIZED)
+    RETURN_IF_STATUS(HIPSPARSE_STATUS_ALLOC_FAILED)
+    RETURN_IF_STATUS(HIPSPARSE_STATUS_INVALID_VALUE)
+    RETURN_IF_STATUS(HIPSPARSE_STATUS_ARCH_MISMATCH)
+    RETURN_IF_STATUS(HIPSPARSE_STATUS_MAPPING_ERROR)
+    RETURN_IF_STATUS(HIPSPARSE_STATUS_EXECUTION_FAILED)
+    RETURN_IF_STATUS(HIPSPARSE_STATUS_INTERNAL_ERROR)
+    RETURN_IF_STATUS(HIPSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED)
+    RETURN_IF_STATUS(HIPSPARSE_STATUS_ZERO_PIVOT)
+
+    default:
+      return strings::StrCat("Unknown hipSPARSE error: ",
+                             static_cast<int>(status));
+#endif
+
+#undef RETURN_IF_STATUS
+#undef STRINGIZE
   }
 }

+#if GOOGLE_CUDA
+
 #define TF_RETURN_IF_GPUSPARSE_ERROR(expr)                                  \
   do {                                                                      \
     auto status = (expr);                                                   \
@@ -83,9 +120,24 @@ inline string ConvertGPUSparseErrorToString(const gpusparseStatus_t status) {
     }                                                                       \
   } while (0)

+#elif TENSORFLOW_USE_ROCM
+
+#define TF_RETURN_IF_GPUSPARSE_ERROR(expr)                                  \
+  do {                                                                      \
+    auto status = (expr);                                                   \
+    if (TF_PREDICT_FALSE(status != HIPSPARSE_STATUS_SUCCESS)) {             \
+      return errors::Internal(__FILE__, ":", __LINE__, " (", TF_STR(expr),  \
+                              "): hipSPARSE call failed with status ",      \
+                              ConvertGPUSparseErrorToString(status));       \
+    }                                                                       \
+  } while (0)
+
+#endif
+
 inline gpusparseOperation_t TransposeAndConjugateToGpuSparseOp(bool transpose,
                                                                bool conjugate,
                                                                Status* status) {
+#if GOOGLE_CUDA
   if (transpose) {
     return conjugate ? CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE
                      : CUSPARSE_OPERATION_TRANSPOSE;
@@ -97,6 +149,19 @@ inline gpusparseOperation_t TransposeAndConjugateToGpuSparseOp(bool transpose,
     }
     return CUSPARSE_OPERATION_NON_TRANSPOSE;
   }
+#elif TENSORFLOW_USE_ROCM
+  if (transpose) {
+    return conjugate ? HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE
+                     : HIPSPARSE_OPERATION_TRANSPOSE;
+  } else {
+    if (conjugate) {
+      DCHECK(status != nullptr);
+      *status = errors::InvalidArgument(
+          "Conjugate == True and transpose == False is not supported.");
+    }
+    return HIPSPARSE_OPERATION_NON_TRANSPOSE;
+  }
+#endif
 }

 // The GpuSparse class provides a simplified templated API for cuSparse
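Together, the backend-specific TF_RETURN_IF_GPUSPARSE_ERROR macros and this helper give kernels a single error path on both platforms. A hedged usage sketch (the wrapper below is illustrative, not part of the commit):

    // Validates the transpose/conjugate flags and returns the
    // backend-neutral operation enum, surfacing errors as a Status.
    Status MakeSparseOp(bool transpose, bool conjugate,
                        gpusparseOperation_t* op) {
      Status status;
      *op = TransposeAndConjugateToGpuSparseOp(transpose, conjugate, &status);
      TF_RETURN_IF_ERROR(status);
      return Status::OK();
    }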
@@ -353,7 +418,11 @@ class GpuSparseMatrixDescriptor {
   // called more than once.
   Status Initialize() {
     DCHECK(!initialized_);
+#if GOOGLE_CUDA
     TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descr_));
+#elif TENSORFLOW_USE_ROCM
+    TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descr_));
+#endif
     initialized_ = true;
     return Status::OK();
   }
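As a usage sketch, a kernel initializes the descriptor once and then hands it to the CSR routines; this assumes the descr() accessor that GpuSparseMatrixDescriptor exposes elsewhere in this header:

    Status UseDescriptor() {
      GpuSparseMatrixDescriptor descrA;
      TF_RETURN_IF_ERROR(descrA.Initialize());
      // descrA.descr() is now valid as the descrA argument of
      // GpuSparse::Csrmm / Csrmv / Csrgemm on either backend.
      return Status::OK();
    }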
@@ -371,7 +440,11 @@ class GpuSparseMatrixDescriptor {
  private:
   void Release() {
     if (initialized_) {
+#if GOOGLE_CUDA
       cusparseDestroyMatDescr(descr_);
+#elif TENSORFLOW_USE_ROCM
+      hipsparseDestroyMatDescr(descr_);
+#endif
       initialized_ = false;
     }
   }
@@ -382,6 +455,8 @@ class GpuSparseMatrixDescriptor {
   TF_DISALLOW_COPY_AND_ASSIGN(GpuSparseMatrixDescriptor);
 };

+#if GOOGLE_CUDA
+
 // A wrapper class to ensure that an unsorted/sorted CSR conversion information
 // struct (csru2csrInfo_t) is initialized only once. See:
 // https://docs.nvidia.com/cuda/cusparse/index.html#csru2csr
@@ -439,8 +514,10 @@ class GpuSparseCsrSortingConversionInfo {
   TF_DISALLOW_COPY_AND_ASSIGN(GpuSparseCsrSortingConversionInfo);
 };

-}  // namespace tensorflow
-
 #endif  // GOOGLE_CUDA

+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
 #endif  // TENSORFLOW_CORE_KERNELS_CUDA_SPARSE_H_
tensorflow/core/kernels/rocm_sparse.cc (new file, 330 lines)

@@ -0,0 +1,330 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if TENSORFLOW_USE_ROCM

#include <complex>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace {
// A set of initialized handles to the underlying ROCm libraries used by
// GpuSparse. We maintain one such set of handles per unique stream.
class HipSparseHandles {
 public:
  explicit HipSparseHandles(hipStream_t stream)
      : initialized_(false), stream_(stream) {}

  HipSparseHandles(HipSparseHandles&& rhs)
      : initialized_(rhs.initialized_),
        stream_(std::move(rhs.stream_)),
        hipsparse_handle_(rhs.hipsparse_handle_) {
    rhs.initialized_ = false;
  }

  HipSparseHandles& operator=(HipSparseHandles&& rhs) {
    if (this == &rhs) return *this;
    Release();
    stream_ = std::move(rhs.stream_);
    hipsparse_handle_ = std::move(rhs.hipsparse_handle_);
    initialized_ = rhs.initialized_;
    rhs.initialized_ = false;
    return *this;
  }

  ~HipSparseHandles() { Release(); }

  Status Initialize() {
    if (initialized_) return Status::OK();
    TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreate(&hipsparse_handle_));
    TF_RETURN_IF_GPUSPARSE_ERROR(
        hipsparseSetStream(hipsparse_handle_, stream_));
    initialized_ = true;
    return Status::OK();
  }

  hipsparseHandle_t& handle() {
    DCHECK(initialized_);
    return hipsparse_handle_;
  }

  const hipsparseHandle_t& handle() const {
    DCHECK(initialized_);
    return hipsparse_handle_;
  }

 private:
  void Release() {
    if (initialized_) {
      // This should never return anything other than success.
      auto err = hipsparseDestroy(hipsparse_handle_);
      DCHECK(err == HIPSPARSE_STATUS_SUCCESS)
          << "Failed to destroy hipSPARSE instance.";
      initialized_ = false;
    }
  }

  bool initialized_;
  hipStream_t stream_;
  hipsparseHandle_t hipsparse_handle_;

  TF_DISALLOW_COPY_AND_ASSIGN(HipSparseHandles);
};
// TODO(ebrevdo): Replace global mutex guarding HipSparseHandles
// lookup with one of:
//   1. Adding the handle to the GpuStream structure; do the lookup there.
//   2. Add a thread-local hipsparse handle, set it to the current stream
//      upon each call.
// #1 seems like the cleanest option but will need to wait until this
// is moved into TF core.
static mutex handle_map_mutex(LINKER_INITIALIZED);

using HandleMap = std::unordered_map<hipStream_t, HipSparseHandles>;

// Returns a singleton map used for storing initialized handles for each unique
// ROCm stream.
HandleMap* GetHandleMapSingleton() {
  static HandleMap* cm = new HandleMap;
  return cm;
}

}  // namespace
GpuSparse::GpuSparse(OpKernelContext* context)
    : initialized_(false), context_(context) {
  auto hip_stream_ptr =
      reinterpret_cast<const hipStream_t*>(context->op_device_context()
                                               ->stream()
                                               ->implementation()
                                               ->GpuStreamMemberHack());
  DCHECK(hip_stream_ptr);
  gpu_stream_ = *hip_stream_ptr;
}

Status GpuSparse::Initialize() {
  HandleMap* handle_map = GetHandleMapSingleton();
  DCHECK(handle_map);
  mutex_lock lock(handle_map_mutex);
  auto it = handle_map->find(gpu_stream_);
  if (it == handle_map->end()) {
    LOG(INFO) << "Creating GpuSparse handles for stream " << gpu_stream_;
    // Previously unseen ROCm stream. Initialize a set of ROCm sparse library
    // handles for it.
    HipSparseHandles new_handles(gpu_stream_);
    TF_RETURN_IF_ERROR(new_handles.Initialize());
    it = handle_map->insert(std::make_pair(gpu_stream_, std::move(new_handles)))
             .first;
  }
  gpusparse_handle_ = &it->second.handle();
  initialized_ = true;
  return Status::OK();
}
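// Usage sketch (illustrative, not part of this file): a kernel constructs
// GpuSparse against its OpKernelContext and calls Initialize() before any
// sparse call; repeated initializations on the same stream hit the handle
// cache above rather than creating a new hipSPARSE context:
//
//   void ComputeSketch(OpKernelContext* ctx) {
//     GpuSparse gpu_sparse(ctx);  // captures the op's hipStream_t
//     OP_REQUIRES_OK(ctx, gpu_sparse.Initialize());  // find-or-create handle
//     // ... gpu_sparse.Csrmv(...), gpu_sparse.Csrmm(...), etc.
//   }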
// Macro that specializes a sparse method for the two numeric types currently
// supported by the ROCm build (float and double).
#define TF_CALL_HIP_LAPACK_TYPES(m) m(float, S) m(double, D)

// Macros to construct hipsparse method names.
#define SPARSE_FN(method, sparse_prefix) hipsparse##sparse_prefix##method
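// For reference, SPARSE_FN pastes the scalar prefix between the library
// prefix and the method name:
//   SPARSE_FN(csrmm2, S) -> hipsparseScsrmm2  (float)
//   SPARSE_FN(csrmm2, D) -> hipsparseDcsrmm2  (double)
// so TF_CALL_HIP_LAPACK_TYPES(CSRMM_INSTANCE) stamps out one explicit
// specialization of GpuSparse::Csrmm per supported scalar type.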
Status GpuSparse::Coo2csr(const int* cooRowInd, int nnz, int m,
                          int* csrRowPtr) const {
  DCHECK(initialized_);
  TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseXcoo2csr(*gpusparse_handle_, cooRowInd,
                                                 nnz, m, csrRowPtr,
                                                 HIPSPARSE_INDEX_BASE_ZERO));
  return Status::OK();
}

Status GpuSparse::Csr2coo(const int* csrRowPtr, int nnz, int m,
                          int* cooRowInd) const {
  DCHECK(initialized_);
  TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseXcsr2coo(*gpusparse_handle_, csrRowPtr,
                                                 nnz, m, cooRowInd,
                                                 HIPSPARSE_INDEX_BASE_ZERO));
  return Status::OK();
}
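// Worked example of the row-index conversion (standard CSR semantics, added
// here for illustration): for m = 3 rows and nnz = 4 entries with sorted
//   cooRowInd = [0, 0, 1, 2],
// Coo2csr writes the m + 1 row offsets
//   csrRowPtr = [0, 2, 3, 4],
// and Csr2coo inverts the mapping. Values and column indices are identical
// in the COO and CSR forms, so only the row arrays need converting.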
template <typename Scalar, typename SparseFnT>
static inline Status CsrmmImpl(
    SparseFnT op, OpKernelContext* context, hipsparseHandle_t hipsparse_handle,
    hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n,
    int k, int nnz, const Scalar* alpha_host, const hipsparseMatDescr_t descrA,
    const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
    const int* csrSortedColIndA, const Scalar* B, int ldb,
    const Scalar* beta_host, Scalar* C, int ldc) {
  TF_RETURN_IF_GPUSPARSE_ERROR(op(hipsparse_handle, transA, transB, m, n, k,
                                  nnz, alpha_host, descrA, csrSortedValA,
                                  csrSortedRowPtrA, csrSortedColIndA, B, ldb,
                                  beta_host, C, ldc));
  return Status::OK();
}

#define CSRMM_INSTANCE(Scalar, sparse_prefix)                                 \
  template <>                                                                 \
  Status GpuSparse::Csrmm<Scalar>(                                            \
      hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n, \
      int k, int nnz, const Scalar* alpha_host,                               \
      const hipsparseMatDescr_t descrA, const Scalar* csrSortedValA,          \
      const int* csrSortedRowPtrA, const int* csrSortedColIndA,               \
      const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C, int ldc)  \
      const {                                                                 \
    DCHECK(initialized_);                                                     \
    return CsrmmImpl(SPARSE_FN(csrmm2, sparse_prefix), context_,              \
                     *gpusparse_handle_, transA, transB, m, n, k, nnz,        \
                     alpha_host, descrA, csrSortedValA, csrSortedRowPtrA,     \
                     csrSortedColIndA, B, ldb, beta_host, C, ldc);            \
  }

TF_CALL_HIP_LAPACK_TYPES(CSRMM_INSTANCE);
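// Illustrative call (not part of this file): computes C = alpha * A * B +
// beta * C for an m x k CSR matrix A (nnz nonzeros) and column-major dense
// B (k x n, ldb >= k) and C (m x n, ldc >= m), all device-resident:
//
//   const float alpha = 1.0f, beta = 0.0f;
//   TF_RETURN_IF_ERROR(gpu_sparse.Csrmm<float>(
//       HIPSPARSE_OPERATION_NON_TRANSPOSE, HIPSPARSE_OPERATION_NON_TRANSPOSE,
//       m, n, k, nnz, &alpha, descrA, csr_vals, csr_row_ptr, csr_col_ind,
//       B, /*ldb=*/k, &beta, C, /*ldc=*/m));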
template <typename Scalar, typename SparseFnT>
static inline Status CsrmvImpl(SparseFnT op, OpKernelContext* context,
                               hipsparseHandle_t hipsparse_handle,
                               hipsparseOperation_t transA, int m, int n,
                               int nnz, const Scalar* alpha_host,
                               const hipsparseMatDescr_t descrA,
                               const Scalar* csrSortedValA,
                               const int* csrSortedRowPtrA,
                               const int* csrSortedColIndA, const Scalar* x,
                               const Scalar* beta_host, Scalar* y) {
  TF_RETURN_IF_GPUSPARSE_ERROR(
      op(hipsparse_handle, transA, m, n, nnz, alpha_host, descrA, csrSortedValA,
         csrSortedRowPtrA, csrSortedColIndA, x, beta_host, y));
  return Status::OK();
}

// TODO(ebrevdo,rmlarsen): Use a merge-path csrmv (csrmv_mp in cuSPARSE) for
// all cases when an equivalent becomes available in hipSPARSE.
#define CSRMV_INSTANCE(Scalar, sparse_prefix)                                \
  template <>                                                                \
  Status GpuSparse::Csrmv<Scalar>(                                           \
      hipsparseOperation_t transA, int m, int n, int nnz,                    \
      const Scalar* alpha_host, const hipsparseMatDescr_t descrA,            \
      const Scalar* csrSortedValA, const int* csrSortedRowPtrA,              \
      const int* csrSortedColIndA, const Scalar* x, const Scalar* beta_host, \
      Scalar* y) const {                                                     \
    DCHECK(initialized_);                                                    \
    return CsrmvImpl(SPARSE_FN(csrmv, sparse_prefix), context_,              \
                     *gpusparse_handle_, transA, m, n, nnz, alpha_host,      \
                     descrA, csrSortedValA, csrSortedRowPtrA,                \
                     csrSortedColIndA, x, beta_host, y);                     \
  }

TF_CALL_HIP_LAPACK_TYPES(CSRMV_INSTANCE);
Status GpuSparse::CsrgemmNnz(
    hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n,
    int k, const hipsparseMatDescr_t descrA, int nnzA,
    const int* csrSortedRowPtrA, const int* csrSortedColIndA,
    const hipsparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,
    const int* csrSortedColIndB, const hipsparseMatDescr_t descrC,
    int* csrSortedRowPtrC, int* nnzTotalDevHostPtr) {
  DCHECK(initialized_);
  DCHECK(nnzTotalDevHostPtr != nullptr);
  TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseXcsrgemmNnz(
      *gpusparse_handle_, transA, transB, m, n, k, descrA, nnzA,
      csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB,
      csrSortedColIndB, descrC, csrSortedRowPtrC, nnzTotalDevHostPtr));
  return Status::OK();
}
template <typename Scalar, typename SparseFnT>
static inline Status CsrgemmImpl(
    SparseFnT op, OpKernelContext* context, hipsparseHandle_t hipsparse_handle,
    hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n,
    int k, const hipsparseMatDescr_t descrA, int nnzA,
    const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
    const int* csrSortedColIndA, const hipsparseMatDescr_t descrB, int nnzB,
    const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
    const int* csrSortedColIndB, const hipsparseMatDescr_t descrC,
    Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC) {
  TF_RETURN_IF_GPUSPARSE_ERROR(
      op(hipsparse_handle, transA, transB, m, n, k, descrA, nnzA, csrSortedValA,
         csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, csrSortedValB,
         csrSortedRowPtrB, csrSortedColIndB, descrC, csrSortedValC,
         csrSortedRowPtrC, csrSortedColIndC));
  return Status::OK();
}

#define CSRGEMM_INSTANCE(Scalar, sparse_prefix)                               \
  template <>                                                                 \
  Status GpuSparse::Csrgemm<Scalar>(                                          \
      hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n, \
      int k, const hipsparseMatDescr_t descrA, int nnzA,                      \
      const Scalar* csrSortedValA, const int* csrSortedRowPtrA,               \
      const int* csrSortedColIndA, const hipsparseMatDescr_t descrB,          \
      int nnzB, const Scalar* csrSortedValB, const int* csrSortedRowPtrB,     \
      const int* csrSortedColIndB, const hipsparseMatDescr_t descrC,          \
      Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC) {  \
    DCHECK(initialized_);                                                     \
    return CsrgemmImpl(SPARSE_FN(csrgemm, sparse_prefix), context_,           \
                       *gpusparse_handle_, transA, transB, m, n, k, descrA,   \
                       nnzA, csrSortedValA, csrSortedRowPtrA,                 \
                       csrSortedColIndA, descrB, nnzB, csrSortedValB,         \
                       csrSortedRowPtrB, csrSortedColIndB, descrC,            \
                       csrSortedValC, csrSortedRowPtrC, csrSortedColIndC);    \
  }

TF_CALL_HIP_LAPACK_TYPES(CSRGEMM_INSTANCE);
template <typename Scalar, typename SparseFnT>
static inline Status Csr2cscImpl(SparseFnT op, OpKernelContext* context,
                                 hipsparseHandle_t hipsparse_handle, int m,
                                 int n, int nnz, const Scalar* csrVal,
                                 const int* csrRowPtr, const int* csrColInd,
                                 Scalar* cscVal, int* cscRowInd, int* cscColPtr,
                                 const hipsparseAction_t copyValues) {
  TF_RETURN_IF_GPUSPARSE_ERROR(
      op(hipsparse_handle, m, n, nnz, csrVal, csrRowPtr, csrColInd, cscVal,
         cscRowInd, cscColPtr, copyValues, HIPSPARSE_INDEX_BASE_ZERO));
  return Status::OK();
}

#define CSR2CSC_INSTANCE(Scalar, sparse_prefix)                              \
  template <>                                                                \
  Status GpuSparse::Csr2csc<Scalar>(                                         \
      int m, int n, int nnz, const Scalar* csrVal, const int* csrRowPtr,     \
      const int* csrColInd, Scalar* cscVal, int* cscRowInd, int* cscColPtr,  \
      const hipsparseAction_t copyValues) {                                  \
    DCHECK(initialized_);                                                    \
    return Csr2cscImpl(SPARSE_FN(csr2csc, sparse_prefix), context_,          \
                       *gpusparse_handle_, m, n, nnz, csrVal, csrRowPtr,     \
                       csrColInd, cscVal, cscRowInd, cscColPtr, copyValues); \
  }

TF_CALL_HIP_LAPACK_TYPES(CSR2CSC_INSTANCE);
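// Note: converting an m x n CSR matrix to CSC produces index arrays that,
// read as CSR, describe the n x m transpose, so Csr2csc also serves as a
// materialized-transpose path. With copyValues = HIPSPARSE_ACTION_NUMERIC
// the values are converted as well; HIPSPARSE_ACTION_SYMBOLIC fills only
// the index arrays.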
}  // namespace tensorflow

#endif  // TENSORFLOW_USE_ROCM