Adding ROCm support for the GpuSparse API (TF wrapper for cuSPARSE/hipSPARSE)
This commit is contained in:
parent
f725b46454
commit
7e8ccbd22b
@ -3459,14 +3459,18 @@ tf_kernel_library(
|
|||||||
|
|
||||||
tf_kernel_library(
|
tf_kernel_library(
|
||||||
name = "cuda_sparse",
|
name = "cuda_sparse",
|
||||||
srcs = ["cuda_sparse.cc"],
|
srcs = if_cuda(["cuda_sparse.cc"]) + if_rocm(["rocm_sparse.cc"]),
|
||||||
hdrs = ["cuda_sparse.h"],
|
hdrs = ["cuda_sparse.h"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core/kernels:cuda_solvers",
|
"//tensorflow/core/kernels:cuda_solvers",
|
||||||
|
] + if_cuda([
|
||||||
"//tensorflow/stream_executor/cuda:cusparse_lib",
|
"//tensorflow/stream_executor/cuda:cusparse_lib",
|
||||||
] + if_cuda(["@cub_archive//:cub"]),
|
"@cub_archive//:cub",
|
||||||
|
]) + if_rocm([
|
||||||
|
"@local_config_rocm//rocm:hipsparse",
|
||||||
|
]),
|
||||||
)
|
)
|
||||||
|
|
||||||
LINALG_DEPS = [
|
LINALG_DEPS = [
|
||||||
|
@ -19,11 +19,13 @@ limitations under the License.
|
|||||||
// This header declares the class GpuSparse, which contains wrappers of
|
// This header declares the class GpuSparse, which contains wrappers of
|
||||||
// cuSparse libraries for use in TensorFlow kernels.
|
// cuSparse (CUDA) or hipSPARSE (ROCm) libraries for use in TensorFlow kernels.
|
||||||
|
|
||||||
#ifdef GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
|
||||||
#include "third_party/gpus/cuda/include/cusparse.h"
|
#include "third_party/gpus/cuda/include/cusparse.h"
|
||||||
|
|
||||||
using gpusparseStatus_t = cusparseStatus_t;
|
using gpusparseStatus_t = cusparseStatus_t;
|
||||||
@ -33,6 +35,19 @@ using gpusparseAction_t = cusparseAction_t;
|
|||||||
using gpusparseHandle_t = cusparseHandle_t;
|
using gpusparseHandle_t = cusparseHandle_t;
|
||||||
using gpuStream_t = cudaStream_t;
|
using gpuStream_t = cudaStream_t;
|
||||||
|
|
||||||
|
#elif TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
|
#include "rocm/include/hipsparse/hipsparse.h"
|
||||||
|
|
||||||
|
using gpusparseStatus_t = hipsparseStatus_t;
|
||||||
|
using gpusparseOperation_t = hipsparseOperation_t;
|
||||||
|
using gpusparseMatDescr_t = hipsparseMatDescr_t;
|
||||||
|
using gpusparseAction_t = hipsparseAction_t;
|
||||||
|
using gpusparseHandle_t = hipsparseHandle_t;
|
||||||
|
using gpuStream_t = hipStream_t;
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
@ -55,6 +70,8 @@ inline string ConvertGPUSparseErrorToString(const gpusparseStatus_t status) {
|
|||||||
case err: \
|
case err: \
|
||||||
return STRINGIZE(err);
|
return STRINGIZE(err);
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
|
||||||
RETURN_IF_STATUS(CUSPARSE_STATUS_SUCCESS)
|
RETURN_IF_STATUS(CUSPARSE_STATUS_SUCCESS)
|
||||||
RETURN_IF_STATUS(CUSPARSE_STATUS_NOT_INITIALIZED)
|
RETURN_IF_STATUS(CUSPARSE_STATUS_NOT_INITIALIZED)
|
||||||
RETURN_IF_STATUS(CUSPARSE_STATUS_ALLOC_FAILED)
|
RETURN_IF_STATUS(CUSPARSE_STATUS_ALLOC_FAILED)
|
||||||
@ -65,14 +82,34 @@ inline string ConvertGPUSparseErrorToString(const gpusparseStatus_t status) {
|
|||||||
RETURN_IF_STATUS(CUSPARSE_STATUS_INTERNAL_ERROR)
|
RETURN_IF_STATUS(CUSPARSE_STATUS_INTERNAL_ERROR)
|
||||||
RETURN_IF_STATUS(CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED)
|
RETURN_IF_STATUS(CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED)
|
||||||
|
|
||||||
#undef RETURN_IF_STATUS
|
|
||||||
#undef STRINGIZE
|
|
||||||
default:
|
default:
|
||||||
return strings::StrCat("Unknown CUSPARSE error: ",
|
return strings::StrCat("Unknown CUSPARSE error: ",
|
||||||
static_cast<int>(status));
|
static_cast<int>(status));
|
||||||
|
#elif TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
|
RETURN_IF_STATUS(HIPSPARSE_STATUS_SUCCESS)
|
||||||
|
RETURN_IF_STATUS(HIPSPARSE_STATUS_NOT_INITIALIZED)
|
||||||
|
RETURN_IF_STATUS(HIPSPARSE_STATUS_ALLOC_FAILED)
|
||||||
|
RETURN_IF_STATUS(HIPSPARSE_STATUS_INVALID_VALUE)
|
||||||
|
RETURN_IF_STATUS(HIPSPARSE_STATUS_ARCH_MISMATCH)
|
||||||
|
RETURN_IF_STATUS(HIPSPARSE_STATUS_MAPPING_ERROR)
|
||||||
|
RETURN_IF_STATUS(HIPSPARSE_STATUS_EXECUTION_FAILED)
|
||||||
|
RETURN_IF_STATUS(HIPSPARSE_STATUS_INTERNAL_ERROR)
|
||||||
|
RETURN_IF_STATUS(HIPSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED)
|
||||||
|
RETURN_IF_STATUS(HIPSPARSE_STATUS_ZERO_PIVOT)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return strings::StrCat("Unknown hipSPARSE error: ",
|
||||||
|
static_cast<int>(status));
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#undef RETURN_IF_STATUS
|
||||||
|
#undef STRINGIZE
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
|
||||||
#define TF_RETURN_IF_GPUSPARSE_ERROR(expr) \
|
#define TF_RETURN_IF_GPUSPARSE_ERROR(expr) \
|
||||||
do { \
|
do { \
|
||||||
auto status = (expr); \
|
auto status = (expr); \
|
||||||
@ -83,9 +120,24 @@ inline string ConvertGPUSparseErrorToString(const gpusparseStatus_t status) {
|
|||||||
} \
|
} \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
|
#elif TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
|
#define TF_RETURN_IF_GPUSPARSE_ERROR(expr) \
|
||||||
|
do { \
|
||||||
|
auto status = (expr); \
|
||||||
|
if (TF_PREDICT_FALSE(status != HIPSPARSE_STATUS_SUCCESS)) { \
|
||||||
|
return errors::Internal(__FILE__, ":", __LINE__, " (", TF_STR(expr), \
|
||||||
|
"): hipSPARSE call failed with status ", \
|
||||||
|
ConvertGPUSparseErrorToString(status)); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
inline gpusparseOperation_t TransposeAndConjugateToGpuSparseOp(bool transpose,
|
inline gpusparseOperation_t TransposeAndConjugateToGpuSparseOp(bool transpose,
|
||||||
bool conjugate,
|
bool conjugate,
|
||||||
Status* status) {
|
Status* status) {
|
||||||
|
#if GOOGLE_CUDA
|
||||||
if (transpose) {
|
if (transpose) {
|
||||||
return conjugate ? CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE
|
return conjugate ? CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE
|
||||||
: CUSPARSE_OPERATION_TRANSPOSE;
|
: CUSPARSE_OPERATION_TRANSPOSE;
|
||||||
@ -97,6 +149,19 @@ inline gpusparseOperation_t TransposeAndConjugateToGpuSparseOp(bool transpose,
|
|||||||
}
|
}
|
||||||
return CUSPARSE_OPERATION_NON_TRANSPOSE;
|
return CUSPARSE_OPERATION_NON_TRANSPOSE;
|
||||||
}
|
}
|
||||||
|
#elif TENSORFLOW_USE_ROCM
|
||||||
|
if (transpose) {
|
||||||
|
return conjugate ? HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE
|
||||||
|
: HIPSPARSE_OPERATION_TRANSPOSE;
|
||||||
|
} else {
|
||||||
|
if (conjugate) {
|
||||||
|
DCHECK(status != nullptr);
|
||||||
|
*status = errors::InvalidArgument(
|
||||||
|
"Conjugate == True and transpose == False is not supported.");
|
||||||
|
}
|
||||||
|
return HIPSPARSE_OPERATION_NON_TRANSPOSE;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
// The GpuSparse class provides a simplified templated API for cuSparse
|
// The GpuSparse class provides a simplified templated API for cuSparse
|
||||||
@ -353,7 +418,11 @@ class GpuSparseMatrixDescriptor {
|
|||||||
// called more than once.
|
// called more than once.
|
||||||
Status Initialize() {
|
Status Initialize() {
|
||||||
DCHECK(!initialized_);
|
DCHECK(!initialized_);
|
||||||
|
#if GOOGLE_CUDA
|
||||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descr_));
|
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descr_));
|
||||||
|
#elif TENSORFLOW_USE_ROCM
|
||||||
|
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descr_));
|
||||||
|
#endif
|
||||||
initialized_ = true;
|
initialized_ = true;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -371,7 +440,11 @@ class GpuSparseMatrixDescriptor {
|
|||||||
private:
|
private:
|
||||||
void Release() {
|
void Release() {
|
||||||
if (initialized_) {
|
if (initialized_) {
|
||||||
|
#if GOOGLE_CUDA
|
||||||
cusparseDestroyMatDescr(descr_);
|
cusparseDestroyMatDescr(descr_);
|
||||||
|
#elif TENSORFLOW_USE_ROCM
|
||||||
|
hipsparseDestroyMatDescr(descr_);
|
||||||
|
#endif
|
||||||
initialized_ = false;
|
initialized_ = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -382,6 +455,8 @@ class GpuSparseMatrixDescriptor {
|
|||||||
TF_DISALLOW_COPY_AND_ASSIGN(GpuSparseMatrixDescriptor);
|
TF_DISALLOW_COPY_AND_ASSIGN(GpuSparseMatrixDescriptor);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
|
||||||
// A wrapper class to ensure that an unsorted/sorted CSR conversion information
|
// A wrapper class to ensure that an unsorted/sorted CSR conversion information
|
||||||
// struct (csru2csrInfo_t) is initialized only once. See:
|
// struct (csru2csrInfo_t) is initialized only once. See:
|
||||||
// https://docs.nvidia.com/cuda/cusparse/index.html#csru2csr
|
// https://docs.nvidia.com/cuda/cusparse/index.html#csru2csr
|
||||||
@ -439,8 +514,10 @@ class GpuSparseCsrSortingConversionInfo {
|
|||||||
TF_DISALLOW_COPY_AND_ASSIGN(GpuSparseCsrSortingConversionInfo);
|
TF_DISALLOW_COPY_AND_ASSIGN(GpuSparseCsrSortingConversionInfo);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_KERNELS_CUDA_SPARSE_H_
|
#endif // TENSORFLOW_CORE_KERNELS_CUDA_SPARSE_H_
|
||||||
|
330
tensorflow/core/kernels/rocm_sparse.cc
Normal file
330
tensorflow/core/kernels/rocm_sparse.cc
Normal file
@ -0,0 +1,330 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#if TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
|
#include <complex>
|
||||||
|
#include <memory>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/types.h"
|
||||||
|
#include "tensorflow/core/kernels/cuda_solvers.h"
|
||||||
|
#include "tensorflow/core/kernels/cuda_sparse.h"
|
||||||
|
#include "tensorflow/core/lib/core/blocking_counter.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||||
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
|
#include "tensorflow/core/platform/macros.h"
|
||||||
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
|
#include "tensorflow/core/platform/stream_executor.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// A set of initialized handles to the underlying ROCm libraries used by
|
||||||
|
// GpuSparse. We maintain one such set of handles per unique stream.
|
||||||
|
class HipSparseHandles {
|
||||||
|
public:
|
||||||
|
explicit HipSparseHandles(hipStream_t stream)
|
||||||
|
: initialized_(false), stream_(stream) {}
|
||||||
|
|
||||||
|
HipSparseHandles(HipSparseHandles&& rhs)
|
||||||
|
: initialized_(rhs.initialized_),
|
||||||
|
stream_(std::move(rhs.stream_)),
|
||||||
|
hipsparse_handle_(rhs.hipsparse_handle_) {
|
||||||
|
rhs.initialized_ = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
HipSparseHandles& operator=(HipSparseHandles&& rhs) {
|
||||||
|
if (this == &rhs) return *this;
|
||||||
|
Release();
|
||||||
|
stream_ = std::move(rhs.stream_);
|
||||||
|
hipsparse_handle_ = std::move(rhs.hipsparse_handle_);
|
||||||
|
initialized_ = rhs.initialized_;
|
||||||
|
rhs.initialized_ = false;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
~HipSparseHandles() { Release(); }
|
||||||
|
|
||||||
|
Status Initialize() {
|
||||||
|
if (initialized_) return Status::OK();
|
||||||
|
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreate(&hipsparse_handle_));
|
||||||
|
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||||
|
hipsparseSetStream(hipsparse_handle_, stream_));
|
||||||
|
initialized_ = true;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
hipsparseHandle_t& handle() {
|
||||||
|
DCHECK(initialized_);
|
||||||
|
return hipsparse_handle_;
|
||||||
|
}
|
||||||
|
|
||||||
|
const hipsparseHandle_t& handle() const {
|
||||||
|
DCHECK(initialized_);
|
||||||
|
return hipsparse_handle_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void Release() {
|
||||||
|
if (initialized_) {
|
||||||
|
// This should never return anything other than success
|
||||||
|
auto err = hipsparseDestroy(hipsparse_handle_);
|
||||||
|
DCHECK(err == HIPSPARSE_STATUS_SUCCESS)
|
||||||
|
<< "Failed to destroy hipSPARSE instance.";
|
||||||
|
initialized_ = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bool initialized_;
|
||||||
|
hipStream_t stream_;
|
||||||
|
hipsparseHandle_t hipsparse_handle_;
|
||||||
|
|
||||||
|
TF_DISALLOW_COPY_AND_ASSIGN(HipSparseHandles);
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO(ebrevdo): Replace global mutex guarding HipSparseHandles
|
||||||
|
// lookup with one of:
|
||||||
|
// 1. Adding the handle to the CudaStream structure; do the lookup there.
|
||||||
|
// 2. Add a thread-local hipsparse handle, set it to the current stream
|
||||||
|
// upon each call.
|
||||||
|
// #1 seems like the cleanest option but will need to wait until this
|
||||||
|
// is moved into TF core.
|
||||||
|
static mutex handle_map_mutex(LINKER_INITIALIZED);
|
||||||
|
|
||||||
|
using HandleMap = std::unordered_map<hipStream_t, HipSparseHandles>;
|
||||||
|
|
||||||
|
// Returns a singleton map used for storing initialized handles for each unique
|
||||||
|
// cuda stream.
|
||||||
|
HandleMap* GetHandleMapSingleton() {
|
||||||
|
static HandleMap* cm = new HandleMap;
|
||||||
|
return cm;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Captures the HIP stream that `context`'s op device context runs on.
// No hipSPARSE handle is looked up or created here; that happens in
// Initialize().
GpuSparse::GpuSparse(OpKernelContext* context)
    : initialized_(false), context_(context) {
  // Walk from the kernel context down to the stream implementation, which
  // exposes the underlying platform stream as an opaque pointer.
  auto* stream_impl = context->op_device_context()->stream()->implementation();
  auto hip_stream_ptr = reinterpret_cast<const hipStream_t*>(
      stream_impl->GpuStreamMemberHack());
  DCHECK(hip_stream_ptr);
  gpu_stream_ = *hip_stream_ptr;
}
|
||||||
|
|
||||||
|
// Binds this object to the hipSPARSE handle for its stream, creating and
// registering a new handle in the global handle map the first time the stream
// is seen. The whole lookup/creation runs under handle_map_mutex.
//
// Returns the error from HipSparseHandles::Initialize() if a new handle could
// not be set up; in that case nothing is added to the map.
Status GpuSparse::Initialize() {
  HandleMap* handle_map = GetHandleMapSingleton();
  DCHECK(handle_map);
  mutex_lock guard(handle_map_mutex);

  auto entry = handle_map->find(gpu_stream_);
  if (entry == handle_map->end()) {
    LOG(INFO) << "Creating GpuSparse handles for stream " << gpu_stream_;
    // First use of this ROCm stream: build its sparse library handles and
    // hand ownership over to the map.
    HipSparseHandles fresh_handles(gpu_stream_);
    TF_RETURN_IF_ERROR(fresh_handles.Initialize());
    entry = handle_map->emplace(gpu_stream_, std::move(fresh_handles)).first;
  }

  // The map owns the handle; this object only keeps a pointer to it.
  gpusparse_handle_ = &entry->second.handle();
  initialized_ = true;
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Macro that specializes a sparse method for the float and double
|
||||||
|
// numeric types.
|
||||||
|
#define TF_CALL_HIP_LAPACK_TYPES(m) m(float, S) m(double, D)
|
||||||
|
|
||||||
|
// Macros to construct hipsparse method names.
|
||||||
|
#define SPARSE_FN(method, sparse_prefix) hipsparse##sparse_prefix##method
|
||||||
|
|
||||||
|
// Forwards to hipsparseXcoo2csr with zero-based indexing, reading COO row
// indices from `cooRowInd` and writing CSR row pointers to `csrRowPtr`.
// Requires a prior successful Initialize() (checked in debug builds).
// Returns an Internal error if the hipSPARSE call does not succeed.
Status GpuSparse::Coo2csr(const int* cooRowInd, int nnz, int m,
                          int* csrRowPtr) const {
  DCHECK(initialized_);
  TF_RETURN_IF_GPUSPARSE_ERROR(
      hipsparseXcoo2csr(*gpusparse_handle_, cooRowInd, nnz, m, csrRowPtr,
                        HIPSPARSE_INDEX_BASE_ZERO));
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Forwards to hipsparseXcsr2coo with zero-based indexing, reading CSR row
// pointers from `csrRowPtr` and writing COO row indices to `cooRowInd`.
// Requires a prior successful Initialize() (checked in debug builds).
// Returns an Internal error if the hipSPARSE call does not succeed.
Status GpuSparse::Csr2coo(const int* csrRowPtr, int nnz, int m,
                          int* cooRowInd) const {
  DCHECK(initialized_);
  TF_RETURN_IF_GPUSPARSE_ERROR(
      hipsparseXcsr2coo(*gpusparse_handle_, csrRowPtr, nnz, m, cooRowInd,
                        HIPSPARSE_INDEX_BASE_ZERO));
  return Status::OK();
}
|
||||||
|
|
||||||
|
template <typename Scalar, typename SparseFnT>
|
||||||
|
static inline Status CsrmmImpl(
|
||||||
|
SparseFnT op, OpKernelContext* context, hipsparseHandle_t hipsparse_handle,
|
||||||
|
hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n,
|
||||||
|
int k, int nnz, const Scalar* alpha_host, const hipsparseMatDescr_t descrA,
|
||||||
|
const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
|
||||||
|
const int* csrSortedColIndA, const Scalar* B, int ldb,
|
||||||
|
const Scalar* beta_host, Scalar* C, int ldc) {
|
||||||
|
TF_RETURN_IF_GPUSPARSE_ERROR(op(hipsparse_handle, transA, transB, m, n, k,
|
||||||
|
nnz, alpha_host, descrA, csrSortedValA,
|
||||||
|
csrSortedRowPtrA, csrSortedColIndA, B, ldb,
|
||||||
|
beta_host, C, ldc));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
#define CSRMM_INSTANCE(Scalar, sparse_prefix) \
|
||||||
|
template <> \
|
||||||
|
Status GpuSparse::Csrmm<Scalar>( \
|
||||||
|
hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n, \
|
||||||
|
int k, int nnz, const Scalar* alpha_host, \
|
||||||
|
const hipsparseMatDescr_t descrA, const Scalar* csrSortedValA, \
|
||||||
|
const int* csrSortedRowPtrA, const int* csrSortedColIndA, \
|
||||||
|
const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C, int ldc) \
|
||||||
|
const { \
|
||||||
|
DCHECK(initialized_); \
|
||||||
|
return CsrmmImpl(SPARSE_FN(csrmm2, sparse_prefix), context_, \
|
||||||
|
*gpusparse_handle_, transA, transB, m, n, k, nnz, \
|
||||||
|
alpha_host, descrA, csrSortedValA, csrSortedRowPtrA, \
|
||||||
|
csrSortedColIndA, B, ldb, beta_host, C, ldc); \
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_CALL_HIP_LAPACK_TYPES(CSRMM_INSTANCE);
|
||||||
|
|
||||||
|
template <typename Scalar, typename SparseFnT>
|
||||||
|
static inline Status CsrmvImpl(SparseFnT op, OpKernelContext* context,
|
||||||
|
hipsparseHandle_t hipsparse_handle,
|
||||||
|
hipsparseOperation_t transA, int m, int n,
|
||||||
|
int nnz, const Scalar* alpha_host,
|
||||||
|
const hipsparseMatDescr_t descrA,
|
||||||
|
const Scalar* csrSortedValA,
|
||||||
|
const int* csrSortedRowPtrA,
|
||||||
|
const int* csrSortedColIndA, const Scalar* x,
|
||||||
|
const Scalar* beta_host, Scalar* y) {
|
||||||
|
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||||
|
op(hipsparse_handle, transA, m, n, nnz, alpha_host, descrA, csrSortedValA,
|
||||||
|
csrSortedRowPtrA, csrSortedColIndA, x, beta_host, y));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(ebrevdo,rmlarsen): Use csrmv_mp for all cases when available in CUDA 9.
|
||||||
|
#define CSRMV_INSTANCE(Scalar, sparse_prefix) \
|
||||||
|
template <> \
|
||||||
|
Status GpuSparse::Csrmv<Scalar>( \
|
||||||
|
hipsparseOperation_t transA, int m, int n, int nnz, \
|
||||||
|
const Scalar* alpha_host, const hipsparseMatDescr_t descrA, \
|
||||||
|
const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
|
||||||
|
const int* csrSortedColIndA, const Scalar* x, const Scalar* beta_host, \
|
||||||
|
Scalar* y) const { \
|
||||||
|
DCHECK(initialized_); \
|
||||||
|
return CsrmvImpl(SPARSE_FN(csrmv, sparse_prefix), context_, \
|
||||||
|
*gpusparse_handle_, transA, m, n, nnz, alpha_host, \
|
||||||
|
descrA, csrSortedValA, csrSortedRowPtrA, \
|
||||||
|
csrSortedColIndA, x, beta_host, y); \
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_CALL_HIP_LAPACK_TYPES(CSRMV_INSTANCE);
|
||||||
|
|
||||||
|
// Forwards to hipsparseXcsrgemmNnz with this object's handle. The call fills
// `csrSortedRowPtrC` and `nnzTotalDevHostPtr`, which must be non-null (checked
// in debug builds). Requires a prior successful Initialize() (checked in debug
// builds). Returns an Internal error if the hipSPARSE call does not succeed.
Status GpuSparse::CsrgemmNnz(hipsparseOperation_t transA,
                             hipsparseOperation_t transB, int m, int n, int k,
                             const hipsparseMatDescr_t descrA, int nnzA,
                             const int* csrSortedRowPtrA,
                             const int* csrSortedColIndA,
                             const hipsparseMatDescr_t descrB, int nnzB,
                             const int* csrSortedRowPtrB,
                             const int* csrSortedColIndB,
                             const hipsparseMatDescr_t descrC,
                             int* csrSortedRowPtrC, int* nnzTotalDevHostPtr) {
  DCHECK(initialized_);
  DCHECK(nnzTotalDevHostPtr != nullptr);
  TF_RETURN_IF_GPUSPARSE_ERROR(
      hipsparseXcsrgemmNnz(*gpusparse_handle_, transA, transB, m, n, k, descrA,
                           nnzA, csrSortedRowPtrA, csrSortedColIndA, descrB,
                           nnzB, csrSortedRowPtrB, csrSortedColIndB, descrC,
                           csrSortedRowPtrC, nnzTotalDevHostPtr));
  return Status::OK();
}
|
||||||
|
|
||||||
|
template <typename Scalar, typename SparseFnT>
|
||||||
|
static inline Status CsrgemmImpl(
|
||||||
|
SparseFnT op, OpKernelContext* context, hipsparseHandle_t hipsparse_handle,
|
||||||
|
hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n,
|
||||||
|
int k, const hipsparseMatDescr_t descrA, int nnzA,
|
||||||
|
const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
|
||||||
|
const int* csrSortedColIndA, const hipsparseMatDescr_t descrB, int nnzB,
|
||||||
|
const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
|
||||||
|
const int* csrSortedColIndB, const hipsparseMatDescr_t descrC,
|
||||||
|
Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC) {
|
||||||
|
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||||
|
op(hipsparse_handle, transA, transB, m, n, k, descrA, nnzA, csrSortedValA,
|
||||||
|
csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, csrSortedValB,
|
||||||
|
csrSortedRowPtrB, csrSortedColIndB, descrC, csrSortedValC,
|
||||||
|
csrSortedRowPtrC, csrSortedColIndC));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
#define CSRGEMM_INSTANCE(Scalar, sparse_prefix) \
|
||||||
|
template <> \
|
||||||
|
Status GpuSparse::Csrgemm<Scalar>( \
|
||||||
|
hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n, \
|
||||||
|
int k, const hipsparseMatDescr_t descrA, int nnzA, \
|
||||||
|
const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
|
||||||
|
const int* csrSortedColIndA, const hipsparseMatDescr_t descrB, int nnzB, \
|
||||||
|
const Scalar* csrSortedValB, const int* csrSortedRowPtrB, \
|
||||||
|
const int* csrSortedColIndB, const hipsparseMatDescr_t descrC, \
|
||||||
|
Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC) { \
|
||||||
|
DCHECK(initialized_); \
|
||||||
|
return CsrgemmImpl(SPARSE_FN(csrgemm, sparse_prefix), context_, \
|
||||||
|
*gpusparse_handle_, transA, transB, m, n, k, descrA, \
|
||||||
|
nnzA, csrSortedValA, csrSortedRowPtrA, \
|
||||||
|
csrSortedColIndA, descrB, nnzB, csrSortedValB, \
|
||||||
|
csrSortedRowPtrB, csrSortedColIndB, descrC, \
|
||||||
|
csrSortedValC, csrSortedRowPtrC, csrSortedColIndC); \
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_CALL_HIP_LAPACK_TYPES(CSRGEMM_INSTANCE);
|
||||||
|
|
||||||
|
template <typename Scalar, typename SparseFnT>
|
||||||
|
static inline Status Csr2cscImpl(SparseFnT op, OpKernelContext* context,
|
||||||
|
hipsparseHandle_t hipsparse_handle, int m,
|
||||||
|
int n, int nnz, const Scalar* csrVal,
|
||||||
|
const int* csrRowPtr, const int* csrColInd,
|
||||||
|
Scalar* cscVal, int* cscRowInd, int* cscColPtr,
|
||||||
|
const hipsparseAction_t copyValues) {
|
||||||
|
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||||
|
op(hipsparse_handle, m, n, nnz, csrVal, csrRowPtr, csrColInd, cscVal,
|
||||||
|
cscRowInd, cscColPtr, copyValues, HIPSPARSE_INDEX_BASE_ZERO));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
#define CSR2CSC_INSTANCE(Scalar, sparse_prefix) \
|
||||||
|
template <> \
|
||||||
|
Status GpuSparse::Csr2csc<Scalar>( \
|
||||||
|
int m, int n, int nnz, const Scalar* csrVal, const int* csrRowPtr, \
|
||||||
|
const int* csrColInd, Scalar* cscVal, int* cscRowInd, int* cscColPtr, \
|
||||||
|
const hipsparseAction_t copyValues) { \
|
||||||
|
DCHECK(initialized_); \
|
||||||
|
return Csr2cscImpl(SPARSE_FN(csr2csc, sparse_prefix), context_, \
|
||||||
|
*gpusparse_handle_, m, n, nnz, csrVal, csrRowPtr, \
|
||||||
|
csrColInd, cscVal, cscRowInd, cscColPtr, copyValues); \
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_CALL_HIP_LAPACK_TYPES(CSR2CSC_INSTANCE);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_USE_ROCM
|
Loading…
x
Reference in New Issue
Block a user