Merge pull request #34800 from ROCmSoftwarePlatform:google_upstream_rocm_csr_sparse_matrix_support

PiperOrigin-RevId: 289617600
Change-Id: Ic1aa3714126d7b867295ae386b6be643c1dc83e4
TensorFlower Gardener 2020-01-14 03:19:11 -08:00
commit 0e2b5a9d2a
35 changed files with 1031 additions and 424 deletions

File: tensorflow/core/kernels/BUILD

@@ -3480,14 +3480,18 @@ tf_kernel_library(
 tf_kernel_library(
 name = "cuda_sparse",
-srcs = ["cuda_sparse.cc"],
+srcs = if_cuda(["cuda_sparse.cc"]) + if_rocm(["rocm_sparse.cc"]),
 hdrs = ["cuda_sparse.h"],
 deps = [
 "//tensorflow/core:framework",
 "//tensorflow/core:lib",
 "//tensorflow/core/kernels:cuda_solvers",
-"//tensorflow/stream_executor/cuda:cusparse_lib",
-] + if_cuda(["@cub_archive//:cub"]),
+] + if_cuda([
+"//tensorflow/stream_executor/cuda:cusparse_lib",
+"@cub_archive//:cub",
+]) + if_rocm([
+"@local_config_rocm//rocm:hipsparse",
+]),
 )
 LINALG_DEPS = [
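The srcs/deps selection above happens at build time; the selected sources then pick the matching sparse library at compile time through the platform macros. A minimal, illustrative sketch of that pattern (this block is not part of the change; the real aliases are defined in cuda_sparse.h later in this diff):

// Illustrative only: compile-time counterpart of the if_cuda()/if_rocm()
// selection in the BUILD rule above. One translation unit, two backends.
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cusparse.h"
using gpusparseHandle_t = cusparseHandle_t;   // CUDA builds wrap cuSPARSE.
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/hipsparse/hipsparse.h"
using gpusparseHandle_t = hipsparseHandle_t;  // ROCm builds wrap hipSPARSE.
#endif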

File: tensorflow/core/kernels/cuda_sparse.cc

@@ -69,7 +69,7 @@ inline typename CudaComplexT<T>::type* AsCudaComplex(T* p) {
 }
 // A set of initialized handles to the underlying Cuda libraries used by
-// CudaSparse. We maintain one such set of handles per unique stream.
+// GpuSparse. We maintain one such set of handles per unique stream.
 class CudaSparseHandles {
 public:
 explicit CudaSparseHandles(cudaStream_t stream)
@@ -96,8 +96,8 @@ class CudaSparseHandles {
 Status Initialize() {
 if (initialized_) return Status::OK();
-TF_RETURN_IF_CUSPARSE_ERROR(cusparseCreate(&cusparse_handle_));
-TF_RETURN_IF_CUSPARSE_ERROR(cusparseSetStream(cusparse_handle_, stream_));
+TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreate(&cusparse_handle_));
+TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSetStream(cusparse_handle_, stream_));
 initialized_ = true;
 return Status::OK();
 }
@@ -149,7 +149,7 @@ HandleMap* GetHandleMapSingleton() {
 } // namespace
-CudaSparse::CudaSparse(OpKernelContext* context)
+GpuSparse::GpuSparse(OpKernelContext* context)
 : initialized_(false), context_(context) {
 auto cuda_stream_ptr =
 reinterpret_cast<const cudaStream_t*>(context->op_device_context()
@@ -157,25 +157,24 @@ CudaSparse::CudaSparse(OpKernelContext* context)
 ->implementation()
 ->GpuStreamMemberHack());
 DCHECK(cuda_stream_ptr);
-cuda_stream_ = *cuda_stream_ptr;
+gpu_stream_ = *cuda_stream_ptr;
 }
-Status CudaSparse::Initialize() {
+Status GpuSparse::Initialize() {
 HandleMap* handle_map = GetHandleMapSingleton();
 DCHECK(handle_map);
 mutex_lock lock(handle_map_mutex);
-auto it = handle_map->find(cuda_stream_);
+auto it = handle_map->find(gpu_stream_);
 if (it == handle_map->end()) {
-LOG(INFO) << "Creating CudaSparse handles for stream " << cuda_stream_;
+LOG(INFO) << "Creating CudaSparse handles for stream " << gpu_stream_;
 // Previously unseen Cuda stream. Initialize a set of Cuda sparse library
 // handles for it.
-CudaSparseHandles new_handles(cuda_stream_);
+CudaSparseHandles new_handles(gpu_stream_);
 TF_RETURN_IF_ERROR(new_handles.Initialize());
-it =
-handle_map->insert(std::make_pair(cuda_stream_, std::move(new_handles)))
-.first;
+it = handle_map->insert(std::make_pair(gpu_stream_, std::move(new_handles)))
+.first;
 }
-cusparse_handle_ = &it->second.handle();
+gpusparse_handle_ = &it->second.handle();
 initialized_ = true;
 return Status::OK();
 }
@@ -205,7 +204,7 @@ template <typename Scalar, typename SparseFn>
 static inline Status GtsvImpl(SparseFn op, cusparseHandle_t cusparse_handle,
 int m, int n, const Scalar* dl, const Scalar* d,
 const Scalar* du, Scalar* B, int ldb) {
-TF_RETURN_IF_CUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl),
+TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl),
 AsCudaComplex(d), AsCudaComplex(du),
 AsCudaComplex(B), ldb));
 return Status::OK();
@@ -213,11 +212,11 @@ static inline Status GtsvImpl(SparseFn op, cusparseHandle_t cusparse_handle,
 #define GTSV_INSTANCE(Scalar, sparse_prefix) \
 template <> \
-Status CudaSparse::Gtsv<Scalar>(int m, int n, const Scalar* dl, \
-const Scalar* d, const Scalar* du, \
-Scalar* B, int ldb) const { \
+Status GpuSparse::Gtsv<Scalar>(int m, int n, const Scalar* dl, \
+const Scalar* d, const Scalar* du, Scalar* B, \
+int ldb) const { \
 DCHECK(initialized_); \
-return GtsvImpl(SPARSE_FN(gtsv, sparse_prefix), *cusparse_handle_, m, n, \
+return GtsvImpl(SPARSE_FN(gtsv, sparse_prefix), *gpusparse_handle_, m, n, \
 dl, d, du, B, ldb); \
 }
@@ -225,12 +224,12 @@ TF_CALL_LAPACK_TYPES(GTSV_INSTANCE);
 #define GTSV_NO_PIVOT_INSTANCE(Scalar, sparse_prefix) \
 template <> \
-Status CudaSparse::GtsvNoPivot<Scalar>(int m, int n, const Scalar* dl, \
+Status GpuSparse::GtsvNoPivot<Scalar>(int m, int n, const Scalar* dl, \
 const Scalar* d, const Scalar* du, \
 Scalar* B, int ldb) const { \
 DCHECK(initialized_); \
-return GtsvImpl(SPARSE_FN(gtsv_nopivot, sparse_prefix), *cusparse_handle_, \
-m, n, dl, d, du, B, ldb); \
+return GtsvImpl(SPARSE_FN(gtsv_nopivot, sparse_prefix), \
+*gpusparse_handle_, m, n, dl, d, du, B, ldb); \
 }
 TF_CALL_LAPACK_TYPES(GTSV_NO_PIVOT_INSTANCE);
@@ -242,7 +241,7 @@ static inline Status GtsvStridedBatchImpl(SparseFn op,
 const Scalar* d, const Scalar* du,
 Scalar* x, int batchCount,
 int batchStride) {
-TF_RETURN_IF_CUSPARSE_ERROR(op(cusparse_handle, m, AsCudaComplex(dl),
+TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, AsCudaComplex(dl),
 AsCudaComplex(d), AsCudaComplex(du),
 AsCudaComplex(x), batchCount, batchStride));
 return Status::OK();
@@ -250,12 +249,12 @@ static inline Status GtsvStridedBatchImpl(SparseFn op,
 #define GTSV_STRIDED_BATCH_INSTANCE(Scalar, sparse_prefix) \
 template <> \
-Status CudaSparse::GtsvStridedBatch<Scalar>( \
+Status GpuSparse::GtsvStridedBatch<Scalar>( \
 int m, const Scalar* dl, const Scalar* d, const Scalar* du, Scalar* x, \
 int batchCount, int batchStride) const { \
 DCHECK(initialized_); \
 return GtsvStridedBatchImpl(SPARSE_FN(gtsvStridedBatch, sparse_prefix), \
-*cusparse_handle_, m, dl, d, du, x, \
+*gpusparse_handle_, m, dl, d, du, x, \
 batchCount, batchStride); \
 }
@@ -266,7 +265,7 @@ static inline Status Gtsv2Impl(SparseFn op, cusparseHandle_t cusparse_handle,
 int m, int n, const Scalar* dl, const Scalar* d,
 const Scalar* du, Scalar* B, int ldb,
 void* pBuffer) {
-TF_RETURN_IF_CUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl),
+TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl),
 AsCudaComplex(d), AsCudaComplex(du),
 AsCudaComplex(B), ldb, pBuffer));
 return Status::OK();
@@ -274,24 +273,24 @@ static inline Status Gtsv2Impl(SparseFn op, cusparseHandle_t cusparse_handle,
 #define GTSV2_INSTANCE(Scalar, sparse_prefix) \
 template <> \
-Status CudaSparse::Gtsv2<Scalar>(int m, int n, const Scalar* dl, \
+Status GpuSparse::Gtsv2<Scalar>(int m, int n, const Scalar* dl, \
 const Scalar* d, const Scalar* du, \
 Scalar* B, int ldb, void* pBuffer) const { \
 DCHECK(initialized_); \
-return Gtsv2Impl(SPARSE_FN(gtsv2, sparse_prefix), *cusparse_handle_, m, n, \
-dl, d, du, B, ldb, pBuffer); \
+return Gtsv2Impl(SPARSE_FN(gtsv2, sparse_prefix), *gpusparse_handle_, m, \
+n, dl, d, du, B, ldb, pBuffer); \
 }
 TF_CALL_LAPACK_TYPES(GTSV2_INSTANCE);
 #define GTSV2_NO_PIVOT_INSTANCE(Scalar, sparse_prefix) \
 template <> \
-Status CudaSparse::Gtsv2NoPivot<Scalar>( \
+Status GpuSparse::Gtsv2NoPivot<Scalar>( \
 int m, int n, const Scalar* dl, const Scalar* d, const Scalar* du, \
 Scalar* B, int ldb, void* pBuffer) const { \
 DCHECK(initialized_); \
 return Gtsv2Impl(SPARSE_FN(gtsv2_nopivot, sparse_prefix), \
-*cusparse_handle_, m, n, dl, d, du, B, ldb, pBuffer); \
+*gpusparse_handle_, m, n, dl, d, du, B, ldb, pBuffer); \
 }
 TF_CALL_LAPACK_TYPES(GTSV2_NO_PIVOT_INSTANCE);
@@ -303,7 +302,7 @@ static inline Status Gtsv2BufferSizeExtImpl(SparseFn op,
 const Scalar* d, const Scalar* du,
 const Scalar* B, int ldb,
 size_t* bufferSizeInBytes) {
-TF_RETURN_IF_CUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl),
+TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl),
 AsCudaComplex(d), AsCudaComplex(du),
 AsCudaComplex(B), ldb, bufferSizeInBytes));
 return Status::OK();
@@ -311,12 +310,12 @@ static inline Status Gtsv2BufferSizeExtImpl(SparseFn op,
 #define GTSV2_BUFFER_SIZE_INSTANCE(Scalar, sparse_prefix) \
 template <> \
-Status CudaSparse::Gtsv2BufferSizeExt<Scalar>( \
+Status GpuSparse::Gtsv2BufferSizeExt<Scalar>( \
 int m, int n, const Scalar* dl, const Scalar* d, const Scalar* du, \
 const Scalar* B, int ldb, size_t* bufferSizeInBytes) const { \
 DCHECK(initialized_); \
 return Gtsv2BufferSizeExtImpl( \
-SPARSE_FN(gtsv2_bufferSizeExt, sparse_prefix), *cusparse_handle_, m, \
+SPARSE_FN(gtsv2_bufferSizeExt, sparse_prefix), *gpusparse_handle_, m, \
 n, dl, d, du, B, ldb, bufferSizeInBytes); \
 }
@@ -324,13 +323,13 @@ TF_CALL_LAPACK_TYPES(GTSV2_BUFFER_SIZE_INSTANCE);
 #define GTSV2_NO_PIVOT_BUFFER_SIZE_INSTANCE(Scalar, sparse_prefix) \
 template <> \
-Status CudaSparse::Gtsv2NoPivotBufferSizeExt<Scalar>( \
+Status GpuSparse::Gtsv2NoPivotBufferSizeExt<Scalar>( \
 int m, int n, const Scalar* dl, const Scalar* d, const Scalar* du, \
 const Scalar* B, int ldb, size_t* bufferSizeInBytes) const { \
 DCHECK(initialized_); \
 return Gtsv2BufferSizeExtImpl( \
 SPARSE_FN(gtsv2_nopivot_bufferSizeExt, sparse_prefix), \
-*cusparse_handle_, m, n, dl, d, du, B, ldb, bufferSizeInBytes); \
+*gpusparse_handle_, m, n, dl, d, du, B, ldb, bufferSizeInBytes); \
 }
 TF_CALL_LAPACK_TYPES(GTSV2_NO_PIVOT_BUFFER_SIZE_INSTANCE);
@@ -342,7 +341,7 @@ static inline Status Gtsv2StridedBatchImpl(SparseFn op,
 const Scalar* d, const Scalar* du,
 Scalar* x, int batchCount,
 int batchStride, void* pBuffer) {
-TF_RETURN_IF_CUSPARSE_ERROR(op(
+TF_RETURN_IF_GPUSPARSE_ERROR(op(
 cusparse_handle, m, AsCudaComplex(dl), AsCudaComplex(d),
 AsCudaComplex(du), AsCudaComplex(x), batchCount, batchStride, pBuffer));
 return Status::OK();
@@ -350,12 +349,12 @@ static inline Status Gtsv2StridedBatchImpl(SparseFn op,
 #define GTSV2_STRIDED_BATCH_INSTANCE(Scalar, sparse_prefix) \
 template <> \
-Status CudaSparse::Gtsv2StridedBatch<Scalar>( \
+Status GpuSparse::Gtsv2StridedBatch<Scalar>( \
 int m, const Scalar* dl, const Scalar* d, const Scalar* du, Scalar* x, \
 int batchCount, int batchStride, void* pBuffer) const { \
 DCHECK(initialized_); \
 return Gtsv2StridedBatchImpl(SPARSE_FN(gtsv2StridedBatch, sparse_prefix), \
-*cusparse_handle_, m, dl, d, du, x, \
+*gpusparse_handle_, m, dl, d, du, x, \
 batchCount, batchStride, pBuffer); \
 }
@@ -366,7 +365,7 @@ static inline Status Gtsv2StridedBatchBufferSizeImpl(
 SparseFn op, cusparseHandle_t cusparse_handle, int m, const Scalar* dl,
 const Scalar* d, const Scalar* du, const Scalar* x, int batchCount,
 int batchStride, size_t* bufferSizeInBytes) {
-TF_RETURN_IF_CUSPARSE_ERROR(op(cusparse_handle, m, AsCudaComplex(dl),
+TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, AsCudaComplex(dl),
 AsCudaComplex(d), AsCudaComplex(du),
 AsCudaComplex(x), batchCount, batchStride,
 bufferSizeInBytes));
@@ -375,20 +374,20 @@ static inline Status Gtsv2StridedBatchBufferSizeImpl(
 #define GTSV2_STRIDED_BATCH_BUFFER_SIZE_INSTANCE(Scalar, sparse_prefix) \
 template <> \
-Status CudaSparse::Gtsv2StridedBatchBufferSizeExt<Scalar>( \
+Status GpuSparse::Gtsv2StridedBatchBufferSizeExt<Scalar>( \
 int m, const Scalar* dl, const Scalar* d, const Scalar* du, \
 const Scalar* x, int batchCount, int batchStride, \
 size_t* bufferSizeInBytes) const { \
 DCHECK(initialized_); \
 return Gtsv2StridedBatchBufferSizeImpl( \
 SPARSE_FN(gtsv2StridedBatch_bufferSizeExt, sparse_prefix), \
-*cusparse_handle_, m, dl, d, du, x, batchCount, batchStride, \
+*gpusparse_handle_, m, dl, d, du, x, batchCount, batchStride, \
 bufferSizeInBytes); \
 }
 TF_CALL_LAPACK_TYPES(GTSV2_STRIDED_BATCH_BUFFER_SIZE_INSTANCE);
-Status CudaSparse::Coo2csr(const int* cooRowInd, int nnz, int m,
+Status GpuSparse::Coo2csr(const int* cooRowInd, int nnz, int m,
 int* csrRowPtr) const {
 // cusparseStatus_t CUSPARSEAPI cusparseXcoo2csr(cusparseHandle_t handle,
 // const int *cooRowInd,
@@ -398,13 +397,13 @@ Status CudaSparse::Coo2csr(const int* cooRowInd, int nnz, int m,
 // cusparseIndexBase_t
 // idxBase);
 DCHECK(initialized_);
-TF_RETURN_IF_CUSPARSE_ERROR(cusparseXcoo2csr(*cusparse_handle_, cooRowInd,
+TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcoo2csr(*gpusparse_handle_, cooRowInd,
 nnz, m, csrRowPtr,
 CUSPARSE_INDEX_BASE_ZERO));
 return Status::OK();
 }
-Status CudaSparse::Csr2coo(const int* csrRowPtr, int nnz, int m,
+Status GpuSparse::Csr2coo(const int* csrRowPtr, int nnz, int m,
 int* cooRowInd) const {
 // cusparseStatus_t CUSPARSEAPI cusparseXcsr2coo(cusparseHandle_t handle,
 // const int *csrRowPtr,
@@ -414,13 +413,13 @@ Status CudaSparse::Csr2coo(const int* csrRowPtr, int nnz, int m,
 // cusparseIndexBase_t
 // idxBase);
 DCHECK(initialized_);
-TF_RETURN_IF_CUSPARSE_ERROR(cusparseXcsr2coo(*cusparse_handle_, csrRowPtr,
+TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsr2coo(*gpusparse_handle_, csrRowPtr,
 nnz, m, cooRowInd,
 CUSPARSE_INDEX_BASE_ZERO));
 return Status::OK();
 }
-Status CudaSparse::CsrgeamNnz(int m, int n, const cusparseMatDescr_t descrA,
+Status GpuSparse::CsrgeamNnz(int m, int n, const cusparseMatDescr_t descrA,
 int nnzA, const int* csrSortedRowPtrA,
 const int* csrSortedColIndA,
 const cusparseMatDescr_t descrB, int nnzB,
@@ -430,10 +429,10 @@ Status CudaSparse::CsrgeamNnz(int m, int n, const cusparseMatDescr_t descrA,
 int* csrSortedRowPtrC, int* nnzTotalDevHostPtr) {
 DCHECK(initialized_);
 DCHECK(nnzTotalDevHostPtr != nullptr);
-TF_RETURN_IF_CUSPARSE_ERROR(cusparseXcsrgeamNnz(
-*cusparse_handle_, m, n, descrA, nnzA, csrSortedRowPtrA, csrSortedColIndA,
-descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB, descrC,
-csrSortedRowPtrC, nnzTotalDevHostPtr));
+TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgeamNnz(
+*gpusparse_handle_, m, n, descrA, nnzA, csrSortedRowPtrA,
+csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB,
+descrC, csrSortedRowPtrC, nnzTotalDevHostPtr));
 return Status::OK();
 }
@@ -452,7 +451,7 @@ static inline Status CsrmmImpl(
 // const float* csrSortedValA, const int* csrSortedRowPtrA,
 // const int* csrSortedColIndA, const float* B, int ldb, const float*
 // beta, float* C, int ldc);
-TF_RETURN_IF_CUSPARSE_ERROR(op(
+TF_RETURN_IF_GPUSPARSE_ERROR(op(
 cusparse_handle, transA, transB, m, n, k, nnz, AsCudaComplex(alpha_host),
 descrA, AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
 AsCudaComplex(B), ldb, AsCudaComplex(beta_host), AsCudaComplex(C), ldc));
@@ -461,7 +460,7 @@ static inline Status CsrmmImpl(
 #define CSRMM_INSTANCE(Scalar, sparse_prefix) \
 template <> \
-Status CudaSparse::Csrmm<Scalar>( \
+Status GpuSparse::Csrmm<Scalar>( \
 cusparseOperation_t transA, cusparseOperation_t transB, int m, int n, \
 int k, int nnz, const Scalar* alpha_host, \
 const cusparseMatDescr_t descrA, const Scalar* csrSortedValA, \
@@ -470,7 +469,7 @@ static inline Status CsrmmImpl(
 const { \
 DCHECK(initialized_); \
 return CsrmmImpl(SPARSE_FN(csrmm2, sparse_prefix), context_, \
-*cusparse_handle_, transA, transB, m, n, k, nnz, \
+*gpusparse_handle_, transA, transB, m, n, k, nnz, \
 alpha_host, descrA, csrSortedValA, csrSortedRowPtrA, \
 csrSortedColIndA, B, ldb, beta_host, C, ldc); \
 }
@@ -484,7 +483,7 @@ static inline Status CsrmvImpl(
 const cusparseMatDescr_t descrA, const Scalar* csrSortedValA,
 const int* csrSortedRowPtrA, const int* csrSortedColIndA, const Scalar* x,
 const Scalar* beta_host, Scalar* y) {
-TF_RETURN_IF_CUSPARSE_ERROR(
+TF_RETURN_IF_GPUSPARSE_ERROR(
 op(cusparse_handle, transA, m, n, nnz, AsCudaComplex(alpha_host), descrA,
 AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
 AsCudaComplex(x), AsCudaComplex(beta_host), AsCudaComplex(y)));
@@ -494,7 +493,7 @@ static inline Status CsrmvImpl(
 // TODO(ebrevdo,rmlarsen): Use csrmv_mp for all cases when available in CUDA 9.
 #define CSRMV_INSTANCE(Scalar, sparse_prefix) \
 template <> \
-Status CudaSparse::Csrmv<Scalar>( \
+Status GpuSparse::Csrmv<Scalar>( \
 cusparseOperation_t transA, int m, int n, int nnz, \
 const Scalar* alpha_host, const cusparseMatDescr_t descrA, \
 const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
@@ -503,12 +502,12 @@ static inline Status CsrmvImpl(
 DCHECK(initialized_); \
 if (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) { \
 return CsrmvImpl(SPARSE_FN(csrmv_mp, sparse_prefix), context_, \
-*cusparse_handle_, transA, m, n, nnz, alpha_host, \
+*gpusparse_handle_, transA, m, n, nnz, alpha_host, \
 descrA, csrSortedValA, csrSortedRowPtrA, \
 csrSortedColIndA, x, beta_host, y); \
 } else { \
 return CsrmvImpl(SPARSE_FN(csrmv, sparse_prefix), context_, \
-*cusparse_handle_, transA, m, n, nnz, alpha_host, \
+*gpusparse_handle_, transA, m, n, nnz, alpha_host, \
 descrA, csrSortedValA, csrSortedRowPtrA, \
 csrSortedColIndA, x, beta_host, y); \
 } \
@@ -526,7 +525,7 @@ static inline Status CsrgeamImpl(
 const int* csrSortedRowPtrB, const int* csrSortedColIndB,
 const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
 int* csrSortedRowPtrC, int* csrSortedColIndC) {
-TF_RETURN_IF_CUSPARSE_ERROR(
+TF_RETURN_IF_GPUSPARSE_ERROR(
 op(cusparse_handle, m, n, AsCudaComplex(alpha), descrA, nnzA,
 AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
 AsCudaComplex(beta), descrB, nnzB, AsCudaComplex(csrSortedValB),
@@ -537,7 +536,7 @@ static inline Status CsrgeamImpl(
 #define CSRGEAM_INSTANCE(Scalar, sparse_prefix) \
 template <> \
-Status CudaSparse::Csrgeam<Scalar>( \
+Status GpuSparse::Csrgeam<Scalar>( \
 int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA, \
 int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
 const int* csrSortedColIndA, const Scalar* beta, \
@@ -547,7 +546,7 @@ static inline Status CsrgeamImpl(
 int* csrSortedRowPtrC, int* csrSortedColIndC) { \
 DCHECK(initialized_); \
 return CsrgeamImpl(SPARSE_FN(csrgeam, sparse_prefix), context_, \
-*cusparse_handle_, m, n, alpha, descrA, nnzA, \
+*gpusparse_handle_, m, n, alpha, descrA, nnzA, \
 csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, \
 beta, descrB, nnzB, csrSortedValB, csrSortedRowPtrB, \
 csrSortedColIndB, descrC, csrSortedValC, \
@@ -556,7 +555,7 @@ static inline Status CsrgeamImpl(
 TF_CALL_LAPACK_TYPES(CSRGEAM_INSTANCE);
-Status CudaSparse::CsrgemmNnz(
+Status GpuSparse::CsrgemmNnz(
 cusparseOperation_t transA, cusparseOperation_t transB, int m, int k, int n,
 const cusparseMatDescr_t descrA, int nnzA, const int* csrSortedRowPtrA,
 const int* csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB,
@@ -565,8 +564,8 @@ Status CudaSparse::CsrgemmNnz(
 int* nnzTotalDevHostPtr) {
 DCHECK(initialized_);
 DCHECK(nnzTotalDevHostPtr != nullptr);
-TF_RETURN_IF_CUSPARSE_ERROR(cusparseXcsrgemmNnz(
-*cusparse_handle_, transA, transB, m, k, n, descrA, nnzA,
+TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgemmNnz(
+*gpusparse_handle_, transA, transB, m, k, n, descrA, nnzA,
 csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB,
 csrSortedColIndB, descrC, csrSortedRowPtrC, nnzTotalDevHostPtr));
 return Status::OK();
@@ -582,7 +581,7 @@ static inline Status CsrgemmImpl(
 const int* csrSortedRowPtrB, const int* csrSortedColIndB,
 const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
 int* csrSortedRowPtrC, int* csrSortedColIndC) {
-TF_RETURN_IF_CUSPARSE_ERROR(
+TF_RETURN_IF_GPUSPARSE_ERROR(
 op(cusparse_handle, transA, transB, m, k, n, descrA, nnzA,
 AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
 descrB, nnzB, AsCudaComplex(csrSortedValB), csrSortedRowPtrB,
@@ -593,7 +592,7 @@ static inline Status CsrgemmImpl(
 #define CSRGEMM_INSTANCE(Scalar, sparse_prefix) \
 template <> \
-Status CudaSparse::Csrgemm<Scalar>( \
+Status GpuSparse::Csrgemm<Scalar>( \
 cusparseOperation_t transA, cusparseOperation_t transB, int m, int k, \
 int n, const cusparseMatDescr_t descrA, int nnzA, \
 const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
@@ -603,7 +602,7 @@ static inline Status CsrgemmImpl(
 Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC) { \
 DCHECK(initialized_); \
 return CsrgemmImpl(SPARSE_FN(csrgemm, sparse_prefix), context_, \
-*cusparse_handle_, transA, transB, m, k, n, descrA, \
+*gpusparse_handle_, transA, transB, m, k, n, descrA, \
 nnzA, csrSortedValA, csrSortedRowPtrA, \
 csrSortedColIndA, descrB, nnzB, csrSortedValB, \
 csrSortedRowPtrB, csrSortedColIndB, descrC, \
@@ -620,12 +619,12 @@ static inline Status Csru2csrImpl(SparseFnT op, BufferSizeFnT buffer_size_op,
 const cusparseMatDescr_t descrA,
 Scalar* csrVal, const int* csrRowPtr,
 int* csrColInd) {
-CudaSparseCsrSortingConversionInfo info;
+GpuSparseCsrSortingConversionInfo info;
 TF_RETURN_IF_ERROR(info.Initialize());
 size_t pBufferSizeInBytes = 0;
-TF_RETURN_IF_CUSPARSE_ERROR(
+TF_RETURN_IF_GPUSPARSE_ERROR(
 buffer_size_op(cusparse_handle, m, n, nnz, AsCudaComplex(csrVal),
 csrRowPtr, csrColInd, info.info(), &pBufferSizeInBytes));
@@ -636,7 +635,7 @@ static inline Status Csru2csrImpl(SparseFnT op, BufferSizeFnT buffer_size_op,
 auto pBuffer = pBuffer_t.flat<int8>();
 DCHECK(pBuffer.data() != nullptr);
-TF_RETURN_IF_CUSPARSE_ERROR(op(cusparse_handle, m, n, nnz, descrA,
+TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, nnz, descrA,
 AsCudaComplex(csrVal), csrRowPtr, csrColInd,
 info.info(), pBuffer.data()));
@@ -645,13 +644,13 @@ static inline Status Csru2csrImpl(SparseFnT op, BufferSizeFnT buffer_size_op,
 #define CSRU2CSR_INSTANCE(Scalar, sparse_prefix) \
 template <> \
-Status CudaSparse::Csru2csr<Scalar>( \
+Status GpuSparse::Csru2csr<Scalar>( \
 int m, int n, int nnz, const cusparseMatDescr_t descrA, Scalar* csrVal, \
 const int* csrRowPtr, int* csrColInd) { \
 DCHECK(initialized_); \
 return Csru2csrImpl(SPARSE_FN(csru2csr, sparse_prefix), \
 BUFSIZE_FN(csru2csr, sparse_prefix), context_, \
-*cusparse_handle_, m, n, nnz, descrA, csrVal, \
+*gpusparse_handle_, m, n, nnz, descrA, csrVal, \
 csrRowPtr, csrColInd); \
 }
@@ -664,7 +663,7 @@ static inline Status Csr2cscImpl(SparseFnT op, OpKernelContext* context,
 const int* csrRowPtr, const int* csrColInd,
 Scalar* cscVal, int* cscRowInd, int* cscColPtr,
 const cusparseAction_t copyValues) {
-TF_RETURN_IF_CUSPARSE_ERROR(op(cusparse_handle, m, n, nnz,
+TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, nnz,
 AsCudaComplex(csrVal), csrRowPtr, csrColInd,
 AsCudaComplex(cscVal), cscRowInd, cscColPtr,
 copyValues, CUSPARSE_INDEX_BASE_ZERO));
@@ -673,13 +672,13 @@ static inline Status Csr2cscImpl(SparseFnT op, OpKernelContext* context,
 #define CSR2CSC_INSTANCE(Scalar, sparse_prefix) \
 template <> \
-Status CudaSparse::Csr2csc<Scalar>( \
+Status GpuSparse::Csr2csc<Scalar>( \
 int m, int n, int nnz, const Scalar* csrVal, const int* csrRowPtr, \
 const int* csrColInd, Scalar* cscVal, int* cscRowInd, int* cscColPtr, \
 const cusparseAction_t copyValues) { \
 DCHECK(initialized_); \
 return Csr2cscImpl(SPARSE_FN(csr2csc, sparse_prefix), context_, \
-*cusparse_handle_, m, n, nnz, csrVal, csrRowPtr, \
+*gpusparse_handle_, m, n, nnz, csrVal, csrRowPtr, \
 csrColInd, cscVal, cscRowInd, cscColPtr, copyValues); \
 }
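The wrapper keeps the same usage pattern after the rename; only the class and handle names change. As a rough, hypothetical sketch of how a kernel drives it (the op class, input layout, and names below are illustrative, not part of this PR): construct a GpuSparse bound to the kernel's OpKernelContext, call Initialize() to create or reuse the per-stream library handle, then invoke the wrapped routine, which is enqueued asynchronously on that stream.

// Hypothetical usage sketch (not from this PR). Converts CSR row pointers to
// COO row indices with the GpuSparse wrapper defined above.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/cuda_sparse.h"

namespace tensorflow {

class CsrToCooSketchOp : public OpKernel {  // name is illustrative only
 public:
  explicit CsrToCooSketchOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& row_ptrs = ctx->input(0);  // [m + 1] CSR row pointers
    const Tensor& col_inds = ctx->input(1);  // [nnz] CSR column indices
    const int m = row_ptrs.NumElements() - 1;
    const int nnz = col_inds.NumElements();

    Tensor* row_inds = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({nnz}), &row_inds));

    GpuSparse gpu_sparse(ctx);                     // binds to ctx's GPU stream
    OP_REQUIRES_OK(ctx, gpu_sparse.Initialize());  // per-stream handle lookup
    // Asynchronous: enqueues cusparseXcsr2coo (or hipsparseXcsr2coo on ROCm).
    OP_REQUIRES_OK(ctx, gpu_sparse.Csr2coo(row_ptrs.flat<int32>().data(), nnz,
                                           m, row_inds->flat<int32>().data()));
  }
};

}  // namespace tensorflow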

File: tensorflow/core/kernels/cuda_sparse.h

@@ -16,15 +16,38 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_CUDA_SPARSE_H_
 #define TENSORFLOW_CORE_KERNELS_CUDA_SPARSE_H_
-// This header declares the class CudaSparse, which contains wrappers of
+// This header declares the class GpuSparse, which contains wrappers of
 // cuSparse libraries for use in TensorFlow kernels.
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include <functional>
 #include <vector>
+#if GOOGLE_CUDA
 #include "third_party/gpus/cuda/include/cusparse.h"
+using gpusparseStatus_t = cusparseStatus_t;
+using gpusparseOperation_t = cusparseOperation_t;
+using gpusparseMatDescr_t = cusparseMatDescr_t;
+using gpusparseAction_t = cusparseAction_t;
+using gpusparseHandle_t = cusparseHandle_t;
+using gpuStream_t = cudaStream_t;
+#elif TENSORFLOW_USE_ROCM
+#include "rocm/include/hipsparse/hipsparse.h"
+using gpusparseStatus_t = hipsparseStatus_t;
+using gpusparseOperation_t = hipsparseOperation_t;
+using gpusparseMatDescr_t = hipsparseMatDescr_t;
+using gpusparseAction_t = hipsparseAction_t;
+using gpusparseHandle_t = hipsparseHandle_t;
+using gpuStream_t = hipStream_t;
+#endif
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_types.h"
@@ -40,13 +63,15 @@ limitations under the License.
 namespace tensorflow {
-inline string ConvertCUSparseErrorToString(const cusparseStatus_t status) {
+inline string ConvertGPUSparseErrorToString(const gpusparseStatus_t status) {
 switch (status) {
 #define STRINGIZE(q) #q
 #define RETURN_IF_STATUS(err) \
 case err: \
 return STRINGIZE(err);
+#if GOOGLE_CUDA
 RETURN_IF_STATUS(CUSPARSE_STATUS_SUCCESS)
 RETURN_IF_STATUS(CUSPARSE_STATUS_NOT_INITIALIZED)
 RETURN_IF_STATUS(CUSPARSE_STATUS_ALLOC_FAILED)
@@ -57,27 +82,62 @@ inline string ConvertCUSparseErrorToString(const cusparseStatus_t status) {
 RETURN_IF_STATUS(CUSPARSE_STATUS_INTERNAL_ERROR)
 RETURN_IF_STATUS(CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED)
-#undef RETURN_IF_STATUS
-#undef STRINGIZE
 default:
 return strings::StrCat("Unknown CUSPARSE error: ",
 static_cast<int>(status));
+#elif TENSORFLOW_USE_ROCM
+RETURN_IF_STATUS(HIPSPARSE_STATUS_SUCCESS)
+RETURN_IF_STATUS(HIPSPARSE_STATUS_NOT_INITIALIZED)
+RETURN_IF_STATUS(HIPSPARSE_STATUS_ALLOC_FAILED)
+RETURN_IF_STATUS(HIPSPARSE_STATUS_INVALID_VALUE)
+RETURN_IF_STATUS(HIPSPARSE_STATUS_ARCH_MISMATCH)
+RETURN_IF_STATUS(HIPSPARSE_STATUS_MAPPING_ERROR)
+RETURN_IF_STATUS(HIPSPARSE_STATUS_EXECUTION_FAILED)
+RETURN_IF_STATUS(HIPSPARSE_STATUS_INTERNAL_ERROR)
+RETURN_IF_STATUS(HIPSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED)
+RETURN_IF_STATUS(HIPSPARSE_STATUS_ZERO_PIVOT)
+default:
+return strings::StrCat("Unknown hipSPARSE error: ",
+static_cast<int>(status));
+#endif
+#undef RETURN_IF_STATUS
+#undef STRINGIZE
 }
 }
-#define TF_RETURN_IF_CUSPARSE_ERROR(expr) \
+#if GOOGLE_CUDA
+#define TF_RETURN_IF_GPUSPARSE_ERROR(expr) \
 do { \
 auto status = (expr); \
 if (TF_PREDICT_FALSE(status != CUSPARSE_STATUS_SUCCESS)) { \
 return errors::Internal(__FILE__, ":", __LINE__, " (", TF_STR(expr), \
 "): cuSparse call failed with status ", \
-ConvertCUSparseErrorToString(status)); \
+ConvertGPUSparseErrorToString(status)); \
 } \
 } while (0)
-inline cusparseOperation_t TransposeAndConjugateToCuSparseOp(bool transpose,
+#elif TENSORFLOW_USE_ROCM
+#define TF_RETURN_IF_GPUSPARSE_ERROR(expr) \
+do { \
+auto status = (expr); \
+if (TF_PREDICT_FALSE(status != HIPSPARSE_STATUS_SUCCESS)) { \
+return errors::Internal(__FILE__, ":", __LINE__, " (", TF_STR(expr), \
+"): hipSPARSE call failed with status ", \
+ConvertGPUSparseErrorToString(status)); \
+} \
+} while (0)
+#endif
+inline gpusparseOperation_t TransposeAndConjugateToGpuSparseOp(bool transpose,
 bool conjugate,
 Status* status) {
+#if GOOGLE_CUDA
 if (transpose) {
 return conjugate ? CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE
 : CUSPARSE_OPERATION_TRANSPOSE;
@@ -89,25 +149,38 @@ inline cusparseOperation_t TransposeAndConjugateToCuSparseOp(bool transpose,
 }
 return CUSPARSE_OPERATION_NON_TRANSPOSE;
 }
+#elif TENSORFLOW_USE_ROCM
+if (transpose) {
+return conjugate ? HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE
+: HIPSPARSE_OPERATION_TRANSPOSE;
+} else {
+if (conjugate) {
+DCHECK(status != nullptr);
+*status = errors::InvalidArgument(
+"Conjugate == True and transpose == False is not supported.");
+}
+return HIPSPARSE_OPERATION_NON_TRANSPOSE;
+}
+#endif
 }
-// The CudaSparse class provides a simplified templated API for cuSparse
+// The GpuSparse class provides a simplified templated API for cuSparse
 // (http://docs.nvidia.com/cuda/cusparse/index.html).
 // An object of this class wraps static cuSparse instances,
 // and will launch Cuda kernels on the stream wrapped by the GPU device
 // in the OpKernelContext provided to the constructor.
 //
 // Notice: All the computational member functions are asynchronous and simply
-// launch one or more Cuda kernels on the Cuda stream wrapped by the CudaSparse
+// launch one or more Cuda kernels on the Cuda stream wrapped by the GpuSparse
 // object.
-class CudaSparse {
+class GpuSparse {
 public:
 // This object stores a pointer to context, which must outlive it.
-explicit CudaSparse(OpKernelContext* context);
-virtual ~CudaSparse() {}
+explicit GpuSparse(OpKernelContext* context);
+virtual ~GpuSparse() {}
-// This initializes the CudaSparse class if it hasn't
+// This initializes the GpuSparse class if it hasn't
 // been initialized yet. All following public methods require the
 // class has been initialized. Can be run multiple times; all
 // subsequent calls after the first have no effect.
@@ -218,9 +291,9 @@ class CudaSparse {
 //
 // **NOTE** This is an in-place operation for data in C.
 template <typename Scalar>
-Status Csrmm(cusparseOperation_t transA, cusparseOperation_t transB, int m,
+Status Csrmm(gpusparseOperation_t transA, gpusparseOperation_t transB, int m,
 int n, int k, int nnz, const Scalar* alpha_host,
-const cusparseMatDescr_t descrA, const Scalar* csrSortedValA,
+const gpusparseMatDescr_t descrA, const Scalar* csrSortedValA,
 const int* csrSortedRowPtrA, const int* csrSortedColIndA,
 const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C,
 int ldc) const;
@@ -231,8 +304,8 @@ class CudaSparse {
 //
 // **NOTE** This is an in-place operation for data in y.
 template <typename Scalar>
-Status Csrmv(cusparseOperation_t transA, int m, int n, int nnz,
-const Scalar* alpha_host, const cusparseMatDescr_t descrA,
+Status Csrmv(gpusparseOperation_t transA, int m, int n, int nnz,
+const Scalar* alpha_host, const gpusparseMatDescr_t descrA,
 const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
 const int* csrSortedColIndA, const Scalar* x,
 const Scalar* beta_host, Scalar* y) const;
@@ -242,11 +315,11 @@ class CudaSparse {
 // output. csrSortedRowPtrC must be preallocated on device with
 // m + 1 entries. See:
 // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgeam.
-Status CsrgeamNnz(int m, int n, const cusparseMatDescr_t descrA, int nnzA,
+Status CsrgeamNnz(int m, int n, const gpusparseMatDescr_t descrA, int nnzA,
 const int* csrSortedRowPtrA, const int* csrSortedColIndA,
-const cusparseMatDescr_t descrB, int nnzB,
+const gpusparseMatDescr_t descrB, int nnzB,
 const int* csrSortedRowPtrB, const int* csrSortedColIndB,
-const cusparseMatDescr_t descrC, int* csrSortedRowPtrC,
+const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
 int* nnzTotalDevHostPtr);
 // Computes sparse - sparse matrix addition of matrices
@@ -256,12 +329,12 @@ class CudaSparse {
 // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgeam.
 template <typename Scalar>
 Status Csrgeam(int m, int n, const Scalar* alpha,
-const cusparseMatDescr_t descrA, int nnzA,
+const gpusparseMatDescr_t descrA, int nnzA,
 const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
 const int* csrSortedColIndA, const Scalar* beta,
-const cusparseMatDescr_t descrB, int nnzB,
+const gpusparseMatDescr_t descrB, int nnzB,
 const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
-const int* csrSortedColIndB, const cusparseMatDescr_t descrC,
+const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
 Scalar* csrSortedValC, int* csrSortedRowPtrC,
 int* csrSortedColIndC);
@@ -270,13 +343,13 @@ class CudaSparse {
 // output. csrSortedRowPtrC must be preallocated on device with
 // m + 1 entries. See:
 // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
-Status CsrgemmNnz(cusparseOperation_t transA, cusparseOperation_t transB,
-int m, int k, int n, const cusparseMatDescr_t descrA,
+Status CsrgemmNnz(gpusparseOperation_t transA, gpusparseOperation_t transB,
+int m, int k, int n, const gpusparseMatDescr_t descrA,
 int nnzA, const int* csrSortedRowPtrA,
 const int* csrSortedColIndA,
-const cusparseMatDescr_t descrB, int nnzB,
+const gpusparseMatDescr_t descrB, int nnzB,
 const int* csrSortedRowPtrB, const int* csrSortedColIndB,
-const cusparseMatDescr_t descrC, int* csrSortedRowPtrC,
+const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
 int* nnzTotalDevHostPtr);
 // Computes sparse - sparse matrix matmul of matrices
@@ -285,19 +358,20 @@ class CudaSparse {
 // with nnzTotalDevHostPtr entries (as calculated by CsrgemmNnz). See:
 // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
 template <typename Scalar>
-Status Csrgemm(cusparseOperation_t transA, cusparseOperation_t transB, int m,
-int k, int n, const cusparseMatDescr_t descrA, int nnzA,
-const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
-const int* csrSortedColIndA, const cusparseMatDescr_t descrB,
-int nnzB, const Scalar* csrSortedValB,
-const int* csrSortedRowPtrB, const int* csrSortedColIndB,
-const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
-int* csrSortedRowPtrC, int* csrSortedColIndC);
+Status Csrgemm(gpusparseOperation_t transA, gpusparseOperation_t transB,
+int m, int k, int n, const gpusparseMatDescr_t descrA,
+int nnzA, const Scalar* csrSortedValA,
+const int* csrSortedRowPtrA, const int* csrSortedColIndA,
+const gpusparseMatDescr_t descrB, int nnzB,
+const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
+const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
+Scalar* csrSortedValC, int* csrSortedRowPtrC,
+int* csrSortedColIndC);
 // In-place reordering of unsorted CSR to sorted CSR.
 // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csru2csr
 template <typename Scalar>
-Status Csru2csr(int m, int n, int nnz, const cusparseMatDescr_t descrA,
+Status Csru2csr(int m, int n, int nnz, const gpusparseMatDescr_t descrA,
 Scalar* csrVal, const int* csrRowPtr, int* csrColInd);
 // Converts from CSR to CSC format (equivalently, transpose).
@@ -306,30 +380,30 @@ class CudaSparse {
 Status Csr2csc(int m, int n, int nnz, const Scalar* csrVal,
 const int* csrRowPtr, const int* csrColInd, Scalar* cscVal,
 int* cscRowInd, int* cscColPtr,
-const cusparseAction_t copyValues);
+const gpusparseAction_t copyValues);
 private:
 bool initialized_;
 OpKernelContext *context_; // not owned.
-cudaStream_t cuda_stream_;
-cusparseHandle_t *cusparse_handle_; // not owned.
+gpuStream_t gpu_stream_;
+gpusparseHandle_t* gpusparse_handle_; // not owned.
-TF_DISALLOW_COPY_AND_ASSIGN(CudaSparse);
+TF_DISALLOW_COPY_AND_ASSIGN(GpuSparse);
 };
 // A wrapper class to ensure that a CUDA sparse matrix descriptor is initialized
-// only once. For more details on the descriptor (cusparseMatDescr_t), see:
+// only once. For more details on the descriptor (gpusparseMatDescr_t), see:
 // https://docs.nvidia.com/cuda/cusparse/index.html#cusparsematdescrt
-class CudaSparseMatrixDescriptor {
+class GpuSparseMatrixDescriptor {
 public:
-explicit CudaSparseMatrixDescriptor() : initialized_(false) {}
+explicit GpuSparseMatrixDescriptor() : initialized_(false) {}
-CudaSparseMatrixDescriptor(CudaSparseMatrixDescriptor&& rhs)
+GpuSparseMatrixDescriptor(GpuSparseMatrixDescriptor&& rhs)
 : initialized_(rhs.initialized_), descr_(std::move(rhs.descr_)) {
 rhs.initialized_ = false;
 }
-CudaSparseMatrixDescriptor& operator=(CudaSparseMatrixDescriptor&& rhs) {
+GpuSparseMatrixDescriptor& operator=(GpuSparseMatrixDescriptor&& rhs) {
 if (this == &rhs) return *this;
 Release();
 initialized_ = rhs.initialized_;
@@ -338,23 +412,27 @@ class CudaSparseMatrixDescriptor {
 return *this;
 }
-~CudaSparseMatrixDescriptor() { Release(); }
+~GpuSparseMatrixDescriptor() { Release(); }
 // Initializes the underlying descriptor. Will fail on the second call if
 // called more than once.
 Status Initialize() {
 DCHECK(!initialized_);
-TF_RETURN_IF_CUSPARSE_ERROR(cusparseCreateMatDescr(&descr_));
+#if GOOGLE_CUDA
+TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descr_));
+#elif TENSORFLOW_USE_ROCM
+TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descr_));
+#endif
 initialized_ = true;
 return Status::OK();
 }
-cusparseMatDescr_t& descr() {
+gpusparseMatDescr_t& descr() {
 DCHECK(initialized_);
 return descr_;
 }
-const cusparseMatDescr_t& descr() const {
+const gpusparseMatDescr_t& descr() const {
 DCHECK(initialized_);
 return descr_;
 }
@@ -362,31 +440,37 @@ class CudaSparseMatrixDescriptor {
 private:
 void Release() {
 if (initialized_) {
+#if GOOGLE_CUDA
 cusparseDestroyMatDescr(descr_);
+#elif TENSORFLOW_USE_ROCM
+hipsparseDestroyMatDescr(descr_);
+#endif
 initialized_ = false;
 }
 }
 bool initialized_;
-cusparseMatDescr_t descr_;
+gpusparseMatDescr_t descr_;
-TF_DISALLOW_COPY_AND_ASSIGN(CudaSparseMatrixDescriptor);
+TF_DISALLOW_COPY_AND_ASSIGN(GpuSparseMatrixDescriptor);
 };
+#if GOOGLE_CUDA
 // A wrapper class to ensure that an unsorted/sorted CSR conversion information
 // struct (csru2csrInfo_t) is initialized only once. See:
 // https://docs.nvidia.com/cuda/cusparse/index.html#csru2csr
-class CudaSparseCsrSortingConversionInfo {
+class GpuSparseCsrSortingConversionInfo {
 public:
-explicit CudaSparseCsrSortingConversionInfo() : initialized_(false) {}
+explicit GpuSparseCsrSortingConversionInfo() : initialized_(false) {}
-CudaSparseCsrSortingConversionInfo(CudaSparseCsrSortingConversionInfo&& rhs)
+GpuSparseCsrSortingConversionInfo(GpuSparseCsrSortingConversionInfo&& rhs)
 : initialized_(rhs.initialized_), info_(std::move(rhs.info_)) {
 rhs.initialized_ = false;
 }
-CudaSparseCsrSortingConversionInfo& operator=(
-CudaSparseCsrSortingConversionInfo&& rhs) {
+GpuSparseCsrSortingConversionInfo& operator=(
+GpuSparseCsrSortingConversionInfo&& rhs) {
 if (this == &rhs) return *this;
 Release();
 initialized_ = rhs.initialized_;
@@ -395,13 +479,13 @@ class CudaSparseCsrSortingConversionInfo {
 return *this;
 }
-~CudaSparseCsrSortingConversionInfo() { Release(); }
+~GpuSparseCsrSortingConversionInfo() { Release(); }
 // Initializes the underlying info. Will fail on the second call if called
 // more than once.
 Status Initialize() {
 DCHECK(!initialized_);
-TF_RETURN_IF_CUSPARSE_ERROR(cusparseCreateCsru2csrInfo(&info_));
+TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsru2csrInfo(&info_));
 initialized_ = true;
 return Status::OK();
 }
@@ -427,11 +511,13 @@ class CudaSparseCsrSortingConversionInfo {
 bool initialized_;
 csru2csrInfo_t info_;
-TF_DISALLOW_COPY_AND_ASSIGN(CudaSparseCsrSortingConversionInfo);
+TF_DISALLOW_COPY_AND_ASSIGN(GpuSparseCsrSortingConversionInfo);
 };
-} // namespace tensorflow
 #endif // GOOGLE_CUDA
+} // namespace tensorflow
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #endif // TENSORFLOW_CORE_KERNELS_CUDA_SPARSE_H_
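Because the header now exposes only the gpusparse* aliases and a single TF_RETURN_IF_GPUSPARSE_ERROR macro, helper code written against it stays backend-neutral. A hypothetical sketch under that assumption (the helper name and argument list are illustrative, not part of this PR):

// Hypothetical helper (not from this PR): the same source compiles against
// cuSPARSE under GOOGLE_CUDA and hipSPARSE under TENSORFLOW_USE_ROCM because
// it only touches the gpusparse* aliases declared in cuda_sparse.h.
#include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

template <typename Scalar>
Status CsrMatVecSketch(const GpuSparse& gpu_sparse, bool transpose,
                       bool conjugate, int m, int n, int nnz,
                       const Scalar* alpha_host,
                       const gpusparseMatDescr_t descrA, const Scalar* vals,
                       const int* row_ptrs, const int* col_inds,
                       const Scalar* x, const Scalar* beta_host, Scalar* y) {
  Status status;
  // Maps (transpose, conjugate) to CUSPARSE_OPERATION_* or
  // HIPSPARSE_OPERATION_* depending on the build.
  const gpusparseOperation_t transA =
      TransposeAndConjugateToGpuSparseOp(transpose, conjugate, &status);
  TF_RETURN_IF_ERROR(status);
  // Asynchronously enqueues csrmv on the stream the GpuSparse object wraps.
  return gpu_sparse.Csrmv(transA, m, n, nnz, alpha_host, descrA, vals,
                          row_ptrs, col_inds, x, beta_host, y);
}

}  // namespace tensorflow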

File: tensorflow/core/kernels/rocm_sparse.cc (new file)

@@ -0,0 +1,330 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if TENSORFLOW_USE_ROCM
#include <complex>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace {
// A set of initialized handles to the underlying ROCm libraries used by
// GpuSparse. We maintain one such set of handles per unique stream.
class HipSparseHandles {
public:
explicit HipSparseHandles(hipStream_t stream)
: initialized_(false), stream_(stream) {}
HipSparseHandles(HipSparseHandles&& rhs)
: initialized_(rhs.initialized_),
stream_(std::move(rhs.stream_)),
hipsparse_handle_(rhs.hipsparse_handle_) {
rhs.initialized_ = false;
}
HipSparseHandles& operator=(HipSparseHandles&& rhs) {
if (this == &rhs) return *this;
Release();
stream_ = std::move(rhs.stream_);
hipsparse_handle_ = std::move(rhs.hipsparse_handle_);
initialized_ = rhs.initialized_;
rhs.initialized_ = false;
return *this;
}
~HipSparseHandles() { Release(); }
Status Initialize() {
if (initialized_) return Status::OK();
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreate(&hipsparse_handle_));
TF_RETURN_IF_GPUSPARSE_ERROR(
hipsparseSetStream(hipsparse_handle_, stream_));
initialized_ = true;
return Status::OK();
}
hipsparseHandle_t& handle() {
DCHECK(initialized_);
return hipsparse_handle_;
}
const hipsparseHandle_t& handle() const {
DCHECK(initialized_);
return hipsparse_handle_;
}
private:
void Release() {
if (initialized_) {
// This should never return anything other than success
auto err = hipsparseDestroy(hipsparse_handle_);
DCHECK(err == HIPSPARSE_STATUS_SUCCESS)
<< "Failed to destroy hipSPARSE instance.";
initialized_ = false;
}
}
bool initialized_;
hipStream_t stream_;
hipsparseHandle_t hipsparse_handle_;
TF_DISALLOW_COPY_AND_ASSIGN(HipSparseHandles);
};
// TODO(ebrevdo): Replace the global mutex guarding the HipSparseHandles
// lookup with one of:
// 1. Adding the handle to the GpuStream structure; do the lookup there.
// 2. Add a thread-local hipsparse handle, set it to the current stream
//    upon each call.
// #1 seems like the cleanest option but will need to wait until this
// is moved into TF core.
static mutex handle_map_mutex(LINKER_INITIALIZED);
using HandleMap = std::unordered_map<hipStream_t, HipSparseHandles>;
// Returns a singleton map used for storing initialized handles for each unique
// ROCm stream.
HandleMap* GetHandleMapSingleton() {
static HandleMap* cm = new HandleMap;
return cm;
}
} // namespace
GpuSparse::GpuSparse(OpKernelContext* context)
: initialized_(false), context_(context) {
auto hip_stream_ptr =
reinterpret_cast<const hipStream_t*>(context->op_device_context()
->stream()
->implementation()
->GpuStreamMemberHack());
DCHECK(hip_stream_ptr);
gpu_stream_ = *hip_stream_ptr;
}
Status GpuSparse::Initialize() {
HandleMap* handle_map = GetHandleMapSingleton();
DCHECK(handle_map);
mutex_lock lock(handle_map_mutex);
auto it = handle_map->find(gpu_stream_);
if (it == handle_map->end()) {
LOG(INFO) << "Creating GpuSparse handles for stream " << gpu_stream_;
// Previously unseen ROCm stream. Initialize a set of ROCm sparse library
// handles for it.
HipSparseHandles new_handles(gpu_stream_);
TF_RETURN_IF_ERROR(new_handles.Initialize());
it = handle_map->insert(std::make_pair(gpu_stream_, std::move(new_handles)))
.first;
}
gpusparse_handle_ = &it->second.handle();
initialized_ = true;
return Status::OK();
}
// Macro that specializes a sparse method for the two numeric types this
// hipSPARSE wrapper currently supports (float and double).
#define TF_CALL_HIP_LAPACK_TYPES(m) m(float, S) m(double, D)
// Macros to construct hipsparse method names.
#define SPARSE_FN(method, sparse_prefix) hipsparse##sparse_prefix##method
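// For illustration (not part of the original file): SPARSE_FN(csrmm2, S)
// expands to hipsparseScsrmm2, and TF_CALL_HIP_LAPACK_TYPES(CSRMM_INSTANCE)
// below instantiates only the float (S) and double (D) specializations.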
Status GpuSparse::Coo2csr(const int* cooRowInd, int nnz, int m,
int* csrRowPtr) const {
DCHECK(initialized_);
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseXcoo2csr(*gpusparse_handle_, cooRowInd,
nnz, m, csrRowPtr,
HIPSPARSE_INDEX_BASE_ZERO));
return Status::OK();
}
Status GpuSparse::Csr2coo(const int* csrRowPtr, int nnz, int m,
int* cooRowInd) const {
DCHECK(initialized_);
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseXcsr2coo(*gpusparse_handle_, csrRowPtr,
nnz, m, cooRowInd,
HIPSPARSE_INDEX_BASE_ZERO));
return Status::OK();
}
template <typename Scalar, typename SparseFnT>
static inline Status CsrmmImpl(
SparseFnT op, OpKernelContext* context, hipsparseHandle_t hipsparse_handle,
hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n,
int k, int nnz, const Scalar* alpha_host, const hipsparseMatDescr_t descrA,
const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
const int* csrSortedColIndA, const Scalar* B, int ldb,
const Scalar* beta_host, Scalar* C, int ldc) {
TF_RETURN_IF_GPUSPARSE_ERROR(op(hipsparse_handle, transA, transB, m, n, k,
nnz, alpha_host, descrA, csrSortedValA,
csrSortedRowPtrA, csrSortedColIndA, B, ldb,
beta_host, C, ldc));
return Status::OK();
}
#define CSRMM_INSTANCE(Scalar, sparse_prefix) \
template <> \
Status GpuSparse::Csrmm<Scalar>( \
hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n, \
int k, int nnz, const Scalar* alpha_host, \
const hipsparseMatDescr_t descrA, const Scalar* csrSortedValA, \
const int* csrSortedRowPtrA, const int* csrSortedColIndA, \
const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C, int ldc) \
const { \
DCHECK(initialized_); \
return CsrmmImpl(SPARSE_FN(csrmm2, sparse_prefix), context_, \
*gpusparse_handle_, transA, transB, m, n, k, nnz, \
alpha_host, descrA, csrSortedValA, csrSortedRowPtrA, \
csrSortedColIndA, B, ldb, beta_host, C, ldc); \
}
TF_CALL_HIP_LAPACK_TYPES(CSRMM_INSTANCE);
template <typename Scalar, typename SparseFnT>
static inline Status CsrmvImpl(SparseFnT op, OpKernelContext* context,
hipsparseHandle_t hipsparse_handle,
hipsparseOperation_t transA, int m, int n,
int nnz, const Scalar* alpha_host,
const hipsparseMatDescr_t descrA,
const Scalar* csrSortedValA,
const int* csrSortedRowPtrA,
const int* csrSortedColIndA, const Scalar* x,
const Scalar* beta_host, Scalar* y) {
TF_RETURN_IF_GPUSPARSE_ERROR(
op(hipsparse_handle, transA, m, n, nnz, alpha_host, descrA, csrSortedValA,
csrSortedRowPtrA, csrSortedColIndA, x, beta_host, y));
return Status::OK();
}
// TODO(ebrevdo,rmlarsen): Use csrmv_mp for all cases when available in CUDA 9.
#define CSRMV_INSTANCE(Scalar, sparse_prefix) \
template <> \
Status GpuSparse::Csrmv<Scalar>( \
hipsparseOperation_t transA, int m, int n, int nnz, \
const Scalar* alpha_host, const hipsparseMatDescr_t descrA, \
const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
const int* csrSortedColIndA, const Scalar* x, const Scalar* beta_host, \
Scalar* y) const { \
DCHECK(initialized_); \
return CsrmvImpl(SPARSE_FN(csrmv, sparse_prefix), context_, \
*gpusparse_handle_, transA, m, n, nnz, alpha_host, \
descrA, csrSortedValA, csrSortedRowPtrA, \
csrSortedColIndA, x, beta_host, y); \
}
TF_CALL_HIP_LAPACK_TYPES(CSRMV_INSTANCE);
Status GpuSparse::CsrgemmNnz(
hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n,
int k, const hipsparseMatDescr_t descrA, int nnzA,
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
const hipsparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,
const int* csrSortedColIndB, const hipsparseMatDescr_t descrC,
int* csrSortedRowPtrC, int* nnzTotalDevHostPtr) {
DCHECK(initialized_);
DCHECK(nnzTotalDevHostPtr != nullptr);
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseXcsrgemmNnz(
*gpusparse_handle_, transA, transB, m, n, k, descrA, nnzA,
csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB,
csrSortedColIndB, descrC, csrSortedRowPtrC, nnzTotalDevHostPtr));
return Status::OK();
}
template <typename Scalar, typename SparseFnT>
static inline Status CsrgemmImpl(
SparseFnT op, OpKernelContext* context, hipsparseHandle_t hipsparse_handle,
hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n,
int k, const hipsparseMatDescr_t descrA, int nnzA,
const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
const int* csrSortedColIndA, const hipsparseMatDescr_t descrB, int nnzB,
const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
const int* csrSortedColIndB, const hipsparseMatDescr_t descrC,
Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC) {
TF_RETURN_IF_GPUSPARSE_ERROR(
op(hipsparse_handle, transA, transB, m, n, k, descrA, nnzA, csrSortedValA,
csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, csrSortedValB,
csrSortedRowPtrB, csrSortedColIndB, descrC, csrSortedValC,
csrSortedRowPtrC, csrSortedColIndC));
return Status::OK();
}
#define CSRGEMM_INSTANCE(Scalar, sparse_prefix) \
template <> \
Status GpuSparse::Csrgemm<Scalar>( \
hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n, \
int k, const hipsparseMatDescr_t descrA, int nnzA, \
const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
const int* csrSortedColIndA, const hipsparseMatDescr_t descrB, int nnzB, \
const Scalar* csrSortedValB, const int* csrSortedRowPtrB, \
const int* csrSortedColIndB, const hipsparseMatDescr_t descrC, \
Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC) { \
DCHECK(initialized_); \
return CsrgemmImpl(SPARSE_FN(csrgemm, sparse_prefix), context_, \
*gpusparse_handle_, transA, transB, m, n, k, descrA, \
nnzA, csrSortedValA, csrSortedRowPtrA, \
csrSortedColIndA, descrB, nnzB, csrSortedValB, \
csrSortedRowPtrB, csrSortedColIndB, descrC, \
csrSortedValC, csrSortedRowPtrC, csrSortedColIndC); \
}
TF_CALL_HIP_LAPACK_TYPES(CSRGEMM_INSTANCE);
template <typename Scalar, typename SparseFnT>
static inline Status Csr2cscImpl(SparseFnT op, OpKernelContext* context,
hipsparseHandle_t hipsparse_handle, int m,
int n, int nnz, const Scalar* csrVal,
const int* csrRowPtr, const int* csrColInd,
Scalar* cscVal, int* cscRowInd, int* cscColPtr,
const hipsparseAction_t copyValues) {
TF_RETURN_IF_GPUSPARSE_ERROR(
op(hipsparse_handle, m, n, nnz, csrVal, csrRowPtr, csrColInd, cscVal,
cscRowInd, cscColPtr, copyValues, HIPSPARSE_INDEX_BASE_ZERO));
return Status::OK();
}
#define CSR2CSC_INSTANCE(Scalar, sparse_prefix) \
template <> \
Status GpuSparse::Csr2csc<Scalar>( \
int m, int n, int nnz, const Scalar* csrVal, const int* csrRowPtr, \
const int* csrColInd, Scalar* cscVal, int* cscRowInd, int* cscColPtr, \
const hipsparseAction_t copyValues) { \
DCHECK(initialized_); \
return Csr2cscImpl(SPARSE_FN(csr2csc, sparse_prefix), context_, \
*gpusparse_handle_, m, n, nnz, csrVal, csrRowPtr, \
csrColInd, cscVal, cscRowInd, cscColPtr, copyValues); \
}
TF_CALL_HIP_LAPACK_TYPES(CSR2CSC_INSTANCE);
} // namespace tensorflow
#endif // TENSORFLOW_USE_ROCM
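As a usage reference, a minimal sketch (hypothetical function name and buffers) of how a kernel drives the wrapper defined above for the COO/CSR index conversions; the construct-Initialize-call pattern mirrors the COOSparseMatrixToCSRSparseMatrix functor later in this change.

// Hypothetical sketch: convert COO row indices into a CSR row-pointer array
// using the GpuSparse wrapper defined above.
Status CooToCsrRowPtr(OpKernelContext* ctx, const int* coo_row_ind, int nnz,
                      int num_rows, int* csr_row_ptr) {
  GpuSparse gpu_sparse(ctx);
  TF_RETURN_IF_ERROR(gpu_sparse.Initialize());  // binds to ctx's GPU stream
  return gpu_sparse.Coo2csr(coo_row_ind, nnz, num_rows, csr_row_ptr);
}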
View File
@ -2,10 +2,10 @@
load( load(
"//tensorflow:tensorflow.bzl", "//tensorflow:tensorflow.bzl",
"if_cuda_or_rocm",
"tf_cc_test", "tf_cc_test",
"tf_kernel_library", "tf_kernel_library",
) )
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
package( package(
default_visibility = ["//visibility:public"], default_visibility = ["//visibility:public"],
@ -77,7 +77,7 @@ tf_kernel_library(
"//tensorflow/core/kernels:scatter_nd_op", "//tensorflow/core/kernels:scatter_nd_op",
"//tensorflow/core/kernels:slice_op", "//tensorflow/core/kernels:slice_op",
"//tensorflow/core/kernels:transpose_functor", "//tensorflow/core/kernels:transpose_functor",
] + if_cuda([ ] + if_cuda_or_rocm([
"//tensorflow/core/kernels:cuda_solvers", "//tensorflow/core/kernels:cuda_solvers",
"//tensorflow/core/kernels:cuda_sparse", "//tensorflow/core/kernels:cuda_sparse",
]), ]),
View File
@ -15,7 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS #define EIGEN_USE_THREADS
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif #endif
@ -31,7 +31,7 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse/sparse_matrix.h" #include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#include "tensorflow/core/kernels/fill_functor.h" #include "tensorflow/core/kernels/fill_functor.h"
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_solvers.h" #include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h" #include "tensorflow/core/kernels/cuda_sparse.h"
#endif #endif
@ -233,8 +233,10 @@ class CSRAddOp : public OpKernel {
REGISTER_GPU(float) REGISTER_GPU(float)
REGISTER_GPU(double) REGISTER_GPU(double)
#if GOOGLE_CUDA
REGISTER_GPU(complex64) REGISTER_GPU(complex64)
REGISTER_GPU(complex128) REGISTER_GPU(complex128)
#endif
#undef REGISTER_GPU #undef REGISTER_GPU
@ -246,7 +248,7 @@ REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(
#undef REGISTER #undef REGISTER
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace functor { namespace functor {
template <typename T> template <typename T>
struct CSRSparseMatrixAdd<GPUDevice, T> struct CSRSparseMatrixAdd<GPUDevice, T>
@ -324,10 +326,10 @@ struct CSRSparseMatrixAdd<GPUDevice, T>
private: private:
OpKernelContext* ctx_; OpKernelContext* ctx_;
CudaSparse cuda_sparse_; GpuSparse cuda_sparse_;
CudaSparseMatrixDescriptor descrA_; GpuSparseMatrixDescriptor descrA_;
CudaSparseMatrixDescriptor descrB_; GpuSparseMatrixDescriptor descrB_;
CudaSparseMatrixDescriptor descrC_; GpuSparseMatrixDescriptor descrC_;
const T alpha_; const T alpha_;
const T beta_; const T beta_;
bool initialized_; bool initialized_;
@ -337,6 +339,6 @@ struct CSRSparseMatrixAdd<GPUDevice, T>
} // namespace functor } // namespace functor
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow } // namespace tensorflow
View File
@ -15,7 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS #define EIGEN_USE_THREADS
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif #endif
@ -31,7 +31,7 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse/kernels.h" #include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h" #include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_solvers.h" #include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h" #include "tensorflow/core/kernels/cuda_sparse.h"
#endif #endif
@ -92,12 +92,12 @@ REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(
CONJ_VARIANT_UNARY_OP, DEVICE_CPU, CSRSparseMatrix, CONJ_VARIANT_UNARY_OP, DEVICE_CPU, CSRSparseMatrix,
(CSRSparseMatrixUnaryHelper<CPUDevice, CSRSparseMatrixConjFunctor>)); (CSRSparseMatrixUnaryHelper<CPUDevice, CSRSparseMatrixConjFunctor>));
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION( REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(
CONJ_VARIANT_UNARY_OP, DEVICE_GPU, CSRSparseMatrix, CONJ_VARIANT_UNARY_OP, DEVICE_GPU, CSRSparseMatrix,
(CSRSparseMatrixUnaryHelper<GPUDevice, CSRSparseMatrixConjFunctor>)); (CSRSparseMatrixUnaryHelper<GPUDevice, CSRSparseMatrixConjFunctor>));
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow } // namespace tensorflow
View File
@ -15,7 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS #define EIGEN_USE_THREADS
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif #endif
@ -33,7 +33,7 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse/sparse_matrix.h" #include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#include "tensorflow/core/util/work_sharder.h" #include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_solvers.h" #include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h" #include "tensorflow/core/kernels/cuda_sparse.h"
#endif #endif
@ -220,19 +220,21 @@ REGISTER_CPU(double)
REGISTER_CPU(complex64) REGISTER_CPU(complex64)
REGISTER_CPU(complex128) REGISTER_CPU(complex128)
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_GPU(float) REGISTER_GPU(float)
REGISTER_GPU(double) REGISTER_GPU(double)
#if GOOGLE_CUDA
REGISTER_GPU(complex64) REGISTER_GPU(complex64)
REGISTER_GPU(complex128) REGISTER_GPU(complex128)
#endif
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#undef REGISTER_CPU #undef REGISTER_CPU
#undef REGISTER_GPU #undef REGISTER_GPU
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace functor { namespace functor {
template <> template <>
@ -256,6 +258,6 @@ extern template struct CSRSparseMatrixToCOOSparseMatrix<GPUDevice>;
} // namespace functor } // namespace functor
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow } // namespace tensorflow
View File
@ -15,7 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS #define EIGEN_USE_THREADS
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif #endif
@ -31,7 +31,7 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse/sparse_matrix.h" #include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#include "tensorflow/core/util/work_sharder.h" #include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_solvers.h" #include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h" #include "tensorflow/core/kernels/cuda_sparse.h"
#endif #endif
@ -205,18 +205,20 @@ class CSRSparseMatrixToSparseTensorGPUOp : public OpKernel {
.HostMemory("dense_shape"), \ .HostMemory("dense_shape"), \
CSRSparseMatrixToSparseTensorGPUOp<GPUDevice, T>); CSRSparseMatrixToSparseTensorGPUOp<GPUDevice, T>);
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_GPU(float) REGISTER_GPU(float)
REGISTER_GPU(double) REGISTER_GPU(double)
#if GOOGLE_CUDA
REGISTER_GPU(complex64) REGISTER_GPU(complex64)
REGISTER_GPU(complex128) REGISTER_GPU(complex128)
#endif
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#undef REGISTER_GPU #undef REGISTER_GPU
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace functor { namespace functor {
template <> template <>
@ -240,7 +242,7 @@ extern template struct CSRSparseMatrixToCOOSparseMatrix<GPUDevice>;
} // namespace functor } // namespace functor
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_CPU(T) \ #define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER(Name("CSRSparseMatrixToSparseTensor") \ REGISTER_KERNEL_BUILDER(Name("CSRSparseMatrixToSparseTensor") \
View File
@ -15,7 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS #define EIGEN_USE_THREADS
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif #endif
@ -32,13 +32,18 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse/kernels.h" #include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h" #include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/cuda_solvers.h" #include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h" #include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h" #endif
#if GOOGLE_CUDA
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
using ::perftools::gputools::cuda::ScopedActivateExecutorContext; using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
#elif TENSORFLOW_USE_ROCM
#include "tensorflow/stream_executor/rocm/rocm_activation.h"
using ::perftools::gputools::rocm::ScopedActivateExecutorContext;
#endif #endif
namespace tensorflow { namespace tensorflow {
@ -138,7 +143,7 @@ REGISTER_CPU(complex128)
#undef REGISTER_CPU #undef REGISTER_CPU
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename Device, typename T> template <typename Device, typename T>
class DenseToCSRSparseMatrixGPUOp : public AsyncOpKernel { class DenseToCSRSparseMatrixGPUOp : public AsyncOpKernel {
@ -356,8 +361,10 @@ class DenseToCSRSparseMatrixGPUOp : public AsyncOpKernel {
REGISTER_GPU(GPU, float) REGISTER_GPU(GPU, float)
REGISTER_GPU(GPU, double) REGISTER_GPU(GPU, double)
#if GOOGLE_CUDA
REGISTER_GPU(GPU, complex64) REGISTER_GPU(GPU, complex64)
REGISTER_GPU(GPU, complex128) REGISTER_GPU(GPU, complex128)
#endif
namespace functor { namespace functor {
@ -380,7 +387,7 @@ struct COOSparseMatrixToCSRSparseMatrix<GPUDevice> {
Status operator()(OpKernelContext* c, const int rows, const int cols, Status operator()(OpKernelContext* c, const int rows, const int cols,
TTypes<int>::UnalignedVec coo_row_ind, TTypes<int>::UnalignedVec coo_row_ind,
TTypes<int>::UnalignedVec csr_row_ptr) { TTypes<int>::UnalignedVec csr_row_ptr) {
CudaSparse cuda_sparse(c); GpuSparse cuda_sparse(c);
TF_RETURN_IF_ERROR(cuda_sparse.Initialize()); TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
return cuda_sparse.Coo2csr(coo_row_ind.data(), return cuda_sparse.Coo2csr(coo_row_ind.data(),
/*nnz*/ coo_row_ind.size(), /*nnz*/ coo_row_ind.size(),
@ -391,7 +398,7 @@ extern template struct COOSparseMatrixToCSRSparseMatrix<GPUDevice>;
} // namespace functor } // namespace functor
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#undef REGISTER_GPU #undef REGISTER_GPU
View File
@ -13,15 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#if GOOGLE_CUDA
#include "third_party/cub/device/device_histogram.cuh" #include "third_party/cub/device/device_histogram.cuh"
#include "third_party/cub/iterator/counting_input_iterator.cuh" #include "third_party/cub/iterator/counting_input_iterator.cuh"
#include "third_party/cub/iterator/transform_input_iterator.cuh" #include "third_party/cub/iterator/transform_input_iterator.cuh"
#include "third_party/gpus/cuda/include/cusparse.h" #include "third_party/gpus/cuda/include/cusparse.h"
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/hipcub/hipcub.hpp"
#endif
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/cuda_sparse.h" #include "tensorflow/core/kernels/cuda_sparse.h"
@ -32,6 +36,12 @@ limitations under the License.
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_kernel_helper.h" #include "tensorflow/core/util/gpu_kernel_helper.h"
#if GOOGLE_CUDA
namespace gpuprim = ::cub;
#elif TENSORFLOW_USE_ROCM
namespace gpuprim = ::hipcub;
#endif
namespace tensorflow { namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice; typedef Eigen::GpuDevice GPUDevice;
@ -65,9 +75,9 @@ Status CalculateNNZPerBatchMatrixFromIndices<GPUDevice>::operator()(
DCHECK_EQ(indices.dimension(1), 3); // batch, row, col DCHECK_EQ(indices.dimension(1), 3); // batch, row, col
const int rank = indices.dimension(1); const int rank = indices.dimension(1);
cub::CountingInputIterator<int> row_counter(0); gpuprim::CountingInputIterator<int> row_counter(0);
cub::TransformInputIterator<int, StridedDataReader, gpuprim::TransformInputIterator<int, StridedDataReader,
cub::CountingInputIterator<int>> gpuprim::CountingInputIterator<int>>
indices_first_column(row_counter, indices_first_column(row_counter,
StridedDataReader(indices.data(), rank)); StridedDataReader(indices.data(), rank));
@ -76,7 +86,7 @@ Status CalculateNNZPerBatchMatrixFromIndices<GPUDevice>::operator()(
DCHECK_NE(indices.data(), nullptr); DCHECK_NE(indices.data(), nullptr);
DCHECK_NE(nnz_per_batch.data(), nullptr); DCHECK_NE(nnz_per_batch.data(), nullptr);
auto first_success = cub::DeviceHistogram::HistogramEven( auto first_success = gpuprim::DeviceHistogram::HistogramEven(
/*d_temp_storage*/ nullptr, /*d_temp_storage*/ nullptr,
/*temp_storage_bytes&*/ temp_storage_bytes, /*temp_storage_bytes&*/ temp_storage_bytes,
/*d_samples*/ indices_first_column, /*d_samples*/ indices_first_column,
@ -87,12 +97,12 @@ Status CalculateNNZPerBatchMatrixFromIndices<GPUDevice>::operator()(
/*num_samples*/ total_nnz, /*num_samples*/ total_nnz,
/*stream*/ cu_stream); /*stream*/ cu_stream);
if (first_success != cudaSuccess) { if (first_success != gpuSuccess) {
return errors::Internal( return errors::Internal(
"SparseTensorToCSRSparseMatrix: Could not launch " "SparseTensorToCSRSparseMatrix: Could not launch "
"cub::DeviceHistogram::HistogramEven " "gpuprim::DeviceHistogram::HistogramEven "
"to calculate temp_storage_bytes, status: ", "to calculate temp_storage_bytes, status: ",
cudaGetErrorString(first_success)); GpuGetErrorString(first_success));
} }
Tensor temp_storage; Tensor temp_storage;
@ -100,7 +110,7 @@ Status CalculateNNZPerBatchMatrixFromIndices<GPUDevice>::operator()(
DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}), DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
&temp_storage)); &temp_storage));
DCHECK_NE(temp_storage.flat<int8>().data(), nullptr); DCHECK_NE(temp_storage.flat<int8>().data(), nullptr);
auto second_success = cub::DeviceHistogram::HistogramEven( auto second_success = gpuprim::DeviceHistogram::HistogramEven(
/*d_temp_storage*/ temp_storage.flat<int8>().data(), /*d_temp_storage*/ temp_storage.flat<int8>().data(),
/*temp_storage_bytes&*/ temp_storage_bytes, /*temp_storage_bytes&*/ temp_storage_bytes,
/*d_samples*/ indices_first_column, /*d_samples*/ indices_first_column,
@ -111,12 +121,12 @@ Status CalculateNNZPerBatchMatrixFromIndices<GPUDevice>::operator()(
/*num_samples*/ total_nnz, /*num_samples*/ total_nnz,
/*stream*/ cu_stream); /*stream*/ cu_stream);
if (second_success != cudaSuccess) { if (second_success != gpuSuccess) {
return errors::Internal( return errors::Internal(
"SparseTensorToCSRSparseMatrix: Could not launch " "SparseTensorToCSRSparseMatrix: Could not launch "
"cub::DeviceHistogram::HistogramEven " "gpuprim::DeviceHistogram::HistogramEven "
"to count nnz entries per batch. temp_storage_bytes: ", "to count nnz entries per batch. temp_storage_bytes: ",
temp_storage_bytes, ", status: ", cudaGetErrorString(second_success)); temp_storage_bytes, ", status: ", GpuGetErrorString(second_success));
} }
return Status::OK(); return Status::OK();
@ -128,11 +138,11 @@ template <>
Status CSRSparseMatrixToCOOSparseMatrix<GPUDevice>::operator()( Status CSRSparseMatrixToCOOSparseMatrix<GPUDevice>::operator()(
OpKernelContext* c, TTypes<const int>::UnalignedVec csr_row_ptr, OpKernelContext* c, TTypes<const int>::UnalignedVec csr_row_ptr,
TTypes<int>::UnalignedVec coo_row_ind) { TTypes<int>::UnalignedVec coo_row_ind) {
CudaSparse cuda_sparse(c); GpuSparse gpu_sparse(c);
const int nnz = coo_row_ind.size(); const int nnz = coo_row_ind.size();
TF_RETURN_IF_ERROR(cuda_sparse.Initialize()); TF_RETURN_IF_ERROR(gpu_sparse.Initialize());
const int m = csr_row_ptr.size() - 1; // rows const int m = csr_row_ptr.size() - 1; // rows
return cuda_sparse.Csr2coo(csr_row_ptr.data(), nnz, m, coo_row_ind.data()); return gpu_sparse.Csr2coo(csr_row_ptr.data(), nnz, m, coo_row_ind.data());
} }
template <int stride> template <int stride>
@ -140,7 +150,7 @@ __global__ void SparseTensorToCOOMatrixKernel(const int64* indices,
int* coo_rows_out, int* coo_rows_out,
int* coo_cols_out, int size) { int* coo_cols_out, int size) {
const int offset = (stride == 3) ? 1 : 0; const int offset = (stride == 3) ? 1 : 0;
CUDA_1D_KERNEL_LOOP(i, size) { GPU_1D_KERNEL_LOOP(i, size) {
coo_rows_out[i] = static_cast<int>(ldg(indices + i * stride + offset)); coo_rows_out[i] = static_cast<int>(ldg(indices + i * stride + offset));
coo_cols_out[i] = static_cast<int>(ldg(indices + i * stride + offset + 1)); coo_cols_out[i] = static_cast<int>(ldg(indices + i * stride + offset + 1));
} }
@ -157,20 +167,22 @@ void SparseTensorToCOOSparseMatrix<GPUDevice>::operator()(
const int size = coo_row_ind.dimension(0); const int size = coo_row_ind.dimension(0);
GpuLaunchConfig config = GetGpuLaunchConfig(size, d); GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
if (stride == 2) { if (stride == 2) {
SparseTensorToCOOMatrixKernel<2> TF_CHECK_OK(GpuLaunchKernel(SparseTensorToCOOMatrixKernel<2>,
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>( config.block_count, config.thread_per_block, 0,
indices.data(), coo_row_ind.data(), coo_col_ind.data(), size); d.stream(), indices.data(), coo_row_ind.data(),
coo_col_ind.data(), size));
} else { } else {
SparseTensorToCOOMatrixKernel<3> TF_CHECK_OK(GpuLaunchKernel(SparseTensorToCOOMatrixKernel<3>,
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>( config.block_count, config.thread_per_block, 0,
indices.data(), coo_row_ind.data(), coo_col_ind.data(), size); d.stream(), indices.data(), coo_row_ind.data(),
coo_col_ind.data(), size));
} }
} }
__global__ void COOMatrixToSparseTensorKernel2D(const int* coo_rows, __global__ void COOMatrixToSparseTensorKernel2D(const int* coo_rows,
const int* coo_cols, const int* coo_cols,
int64* indices_out, int size) { int64* indices_out, int size) {
CUDA_1D_KERNEL_LOOP(i, size) { GPU_1D_KERNEL_LOOP(i, size) {
indices_out[i * 2] = static_cast<int64>(ldg(coo_rows + i)); indices_out[i * 2] = static_cast<int64>(ldg(coo_rows + i));
indices_out[i * 2 + 1] = static_cast<int64>(ldg(coo_cols + i)); indices_out[i * 2 + 1] = static_cast<int64>(ldg(coo_cols + i));
} }
@ -203,7 +215,7 @@ __global__ void COOMatrixToSparseTensorKernel3D(
} }
__syncthreads(); __syncthreads();
CUDA_1D_KERNEL_LOOP(i, size) { GPU_1D_KERNEL_LOOP(i, size) {
// TODO(ebrevdo): Consider special casing batch_size <= 3, // TODO(ebrevdo): Consider special casing batch_size <= 3,
// alternatively doing linear instead of binary search. Requires // alternatively doing linear instead of binary search. Requires
// some benchmarks. // some benchmarks.
@ -231,9 +243,10 @@ Status COOSparseMatrixToSparseTensor<GPUDevice>::operator()(
DCHECK_EQ(size, indices.dimension(0)); DCHECK_EQ(size, indices.dimension(0));
if (ndims == 2) { if (ndims == 2) {
GpuLaunchConfig config = GetGpuLaunchConfig(size, d); GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
COOMatrixToSparseTensorKernel2D<<<config.block_count, TF_CHECK_OK(GpuLaunchKernel(COOMatrixToSparseTensorKernel2D,
config.thread_per_block, 0, d.stream()>>>( config.block_count, config.thread_per_block, 0,
coo_row_ind.data(), coo_col_ind.data(), indices.data(), size); d.stream(), coo_row_ind.data(),
coo_col_ind.data(), indices.data(), size));
return Status::OK(); return Status::OK();
} else { } else {
const int batch_size = host_dense_shape(0); const int batch_size = host_dense_shape(0);
@ -246,11 +259,11 @@ Status COOSparseMatrixToSparseTensor<GPUDevice>::operator()(
GpuLaunchConfig config = GetGpuLaunchConfig(size, d); GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
// shared memory stores the batch pointers. // shared memory stores the batch pointers.
const size_t shared_memory_size = sizeof(int) * (batch_size + 1); const size_t shared_memory_size = sizeof(int) * (batch_size + 1);
COOMatrixToSparseTensorKernel3D<<<config.block_count, TF_CHECK_OK(
config.thread_per_block, GpuLaunchKernel(COOMatrixToSparseTensorKernel3D, config.block_count,
shared_memory_size, d.stream()>>>( config.thread_per_block, shared_memory_size, d.stream(),
coo_row_ind.data(), coo_col_ind.data(), indices.data(), coo_row_ind.data(), coo_col_ind.data(), indices.data(),
batch_ptr_copy.data(), batch_size, size); batch_ptr_copy.data(), batch_size, size));
return Status::OK(); return Status::OK();
} }
} }
@ -274,7 +287,7 @@ __global__ void CSRSparseMatrixBatchMulVecKernel3D(
} }
__syncthreads(); __syncthreads();
CUDA_1D_KERNEL_LOOP(i, total_nnz) { GPU_1D_KERNEL_LOOP(i, total_nnz) {
const int b = BinarySearchRange(local_batch_ptr, batch_size, i); const int b = BinarySearchRange(local_batch_ptr, batch_size, i);
c_values[i] = ldg(a_values + i) * local_batch_values[b]; c_values[i] = ldg(a_values + i) * local_batch_values[b];
} }
@ -316,10 +329,10 @@ Status CSRSparseMatrixBatchMulVecImpl(OpKernelContext* ctx,
const size_t shared_memory_size = const size_t shared_memory_size =
(sizeof(int) * (batch_size + 1) // local batch_pointers. (sizeof(int) * (batch_size + 1) // local batch_pointers.
+ sizeof(T) * batch_size); // local copy of b. + sizeof(T) * batch_size); // local copy of b.
CSRSparseMatrixBatchMulVecKernel3D<T> TF_CHECK_OK(GpuLaunchKernel(
<<<config.block_count, config.thread_per_block, shared_memory_size, CSRSparseMatrixBatchMulVecKernel3D<T>, config.block_count,
d.stream()>>>(a_values.data(), b.data(), c_values.data(), config.thread_per_block, shared_memory_size, d.stream(), a_values.data(),
batch_ptr_copy.data(), batch_size, total_nnz); b.data(), c_values.data(), batch_ptr_copy.data(), batch_size, total_nnz));
return Status::OK(); return Status::OK();
} }
@ -374,7 +387,7 @@ __global__ void CSRSparseMatrixSoftmaxKernel2D(const int rows,
// algorithm to distribute the work in case the row sizes are // algorithm to distribute the work in case the row sizes are
// uneven: // uneven:
// http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf // http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf
CUDA_1D_KERNEL_LOOP(row, rows) { GPU_1D_KERNEL_LOOP(row, rows) {
CalculateRowSoftmax(ldg(row_ptr + row), ldg(row_ptr + row + 1), logits, CalculateRowSoftmax(ldg(row_ptr + row), ldg(row_ptr + row + 1), logits,
softmax); softmax);
} }
@ -382,7 +395,7 @@ __global__ void CSRSparseMatrixSoftmaxKernel2D(const int rows,
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void CopyFromGpuDeviceArrayToLocal( EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void CopyFromGpuDeviceArrayToLocal(
GpuDeviceArrayStruct<int> cuda_ptr_s, int* local_ptr, int length) { GpuDeviceArrayStruct<int> cuda_ptr_s, int* local_ptr, int length) {
#ifdef __CUDA_ARCH__ #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
const int* cuda_ptr = GetGpuDeviceArrayOnDevice(&cuda_ptr_s); const int* cuda_ptr = GetGpuDeviceArrayOnDevice(&cuda_ptr_s);
for (int i = threadIdx.x; i < length; i += blockDim.x) { for (int i = threadIdx.x; i < length; i += blockDim.x) {
local_ptr[i] = cuda_ptr[i]; local_ptr[i] = cuda_ptr[i];
@ -404,7 +417,7 @@ __global__ void CSRSparseMatrixSoftmaxKernel3D(
CopyFromGpuDeviceArrayToLocal(std::move(batch_ptr_s), local_batch_ptr, CopyFromGpuDeviceArrayToLocal(std::move(batch_ptr_s), local_batch_ptr,
batch_size + 1); batch_size + 1);
CUDA_1D_KERNEL_LOOP(i, size) { GPU_1D_KERNEL_LOOP(i, size) {
const int batch = i / rows; const int batch = i / rows;
const int row = i % rows; const int row = i % rows;
const int batch_offset = local_batch_ptr[batch]; const int batch_offset = local_batch_ptr[batch];
@ -431,10 +444,10 @@ Status CSRSparseMatrixSoftmaxGPUImpl(OpKernelContext* ctx,
const int rows = host_dense_shape(0); const int rows = host_dense_shape(0);
DCHECK_EQ(rows, row_ptr.size() - 1); DCHECK_EQ(rows, row_ptr.size() - 1);
GpuLaunchConfig config = GetGpuLaunchConfig(rows /*size*/, d); GpuLaunchConfig config = GetGpuLaunchConfig(rows /*size*/, d);
CSRSparseMatrixSoftmaxKernel2D<T> TF_CHECK_OK(GpuLaunchKernel(CSRSparseMatrixSoftmaxKernel2D<T>,
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>( config.block_count, config.thread_per_block, 0,
rows /*size*/, row_ptr.data(), logits_values.data(), d.stream(), rows /*size*/, row_ptr.data(),
softmax_values.data()); logits_values.data(), softmax_values.data()));
} else { } else {
const int batch_size = host_dense_shape(0); const int batch_size = host_dense_shape(0);
const int rows = host_dense_shape(1); const int rows = host_dense_shape(1);
@ -452,10 +465,11 @@ Status CSRSparseMatrixSoftmaxGPUImpl(OpKernelContext* ctx,
GpuLaunchConfig config = GetGpuLaunchConfig(size, d); GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
// shared memory stores the batch pointers. // shared memory stores the batch pointers.
const size_t shared_memory_size = sizeof(int) * (batch_size + 1); const size_t shared_memory_size = sizeof(int) * (batch_size + 1);
CSRSparseMatrixSoftmaxKernel3D<T> TF_CHECK_OK(GpuLaunchKernel(CSRSparseMatrixSoftmaxKernel3D<T>,
<<<config.block_count, config.thread_per_block, shared_memory_size, config.block_count, config.thread_per_block,
d.stream()>>>(size, rows, batch_ptr_copy.data(), row_ptr.data(), shared_memory_size, d.stream(), size, rows,
logits_values.data(), softmax_values.data()); batch_ptr_copy.data(), row_ptr.data(),
logits_values.data(), softmax_values.data()));
} }
return Status::OK(); return Status::OK();
@ -549,7 +563,7 @@ __global__ void CSRSparseMatrixSoftmaxGradKernel2D(
// algorithm to distribute the work in case the row sizes are // algorithm to distribute the work in case the row sizes are
// uneven: // uneven:
// http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf // http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf
CUDA_1D_KERNEL_LOOP(row, rows) { GPU_1D_KERNEL_LOOP(row, rows) {
CalculateRowSoftmaxGrad( CalculateRowSoftmaxGrad(
ldg(softmax_row_ptr + row) /*softmax_begin*/, ldg(softmax_row_ptr + row) /*softmax_begin*/,
ldg(softmax_row_ptr + row + 1) /*softmax_end*/, softmax_col_ind, ldg(softmax_row_ptr + row + 1) /*softmax_end*/, softmax_col_ind,
@ -579,7 +593,7 @@ __global__ void CSRSparseMatrixSoftmaxGradKernel3D(
#define SOFTMAX_BATCH_PTR(i) local_batch_ptr[i]; #define SOFTMAX_BATCH_PTR(i) local_batch_ptr[i];
#define GRAD_SOFTMAX_BATCH_PTR(i) local_batch_ptr[batch_size + 1 + i]; #define GRAD_SOFTMAX_BATCH_PTR(i) local_batch_ptr[batch_size + 1 + i];
CUDA_1D_KERNEL_LOOP(i, size) { GPU_1D_KERNEL_LOOP(i, size) {
const int batch = i / rows; const int batch = i / rows;
const int row = i % rows; const int row = i % rows;
const int softmax_batch_offset = SOFTMAX_BATCH_PTR(batch); const int softmax_batch_offset = SOFTMAX_BATCH_PTR(batch);
@ -625,12 +639,12 @@ Status CSRSparseMatrixSoftmaxGradGPUImpl(
DCHECK_EQ(rows + 1, softmax_row_ptr.size()); DCHECK_EQ(rows + 1, softmax_row_ptr.size());
DCHECK_EQ(rows + 1, grad_softmax_row_ptr.size()); DCHECK_EQ(rows + 1, grad_softmax_row_ptr.size());
GpuLaunchConfig config = GetGpuLaunchConfig(rows /*size*/, d); GpuLaunchConfig config = GetGpuLaunchConfig(rows /*size*/, d);
CSRSparseMatrixSoftmaxGradKernel2D<T> TF_CHECK_OK(GpuLaunchKernel(
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>( CSRSparseMatrixSoftmaxGradKernel2D<T>, config.block_count,
rows /*size*/, softmax_row_ptr.data(), softmax_col_ind.data(), config.thread_per_block, 0, d.stream(), rows /*size*/,
softmax_values.data(), grad_softmax_row_ptr.data(), softmax_row_ptr.data(), softmax_col_ind.data(), softmax_values.data(),
grad_softmax_col_ind.data(), grad_softmax_values.data(), grad_softmax_row_ptr.data(), grad_softmax_col_ind.data(),
gradient_values.data()); grad_softmax_values.data(), gradient_values.data()));
} else { } else {
const int batch_size = host_dense_shape(0); const int batch_size = host_dense_shape(0);
const int rows = host_dense_shape(1); const int rows = host_dense_shape(1);
@ -656,13 +670,13 @@ Status CSRSparseMatrixSoftmaxGradGPUImpl(
// shared memory stores two copies of batch pointers: one for the // shared memory stores two copies of batch pointers: one for the
// softmax CSR matrix, one for the grad_softmax CSR matrix. // softmax CSR matrix, one for the grad_softmax CSR matrix.
const size_t shared_memory_size = 2 * sizeof(int) * (batch_size + 1); const size_t shared_memory_size = 2 * sizeof(int) * (batch_size + 1);
CSRSparseMatrixSoftmaxGradKernel3D<T> TF_CHECK_OK(GpuLaunchKernel(
<<<config.block_count, config.thread_per_block, shared_memory_size, CSRSparseMatrixSoftmaxGradKernel3D<T>, config.block_count,
d.stream()>>>(size, rows, softmax_and_grad_batch_ptr_copy.data(), config.thread_per_block, shared_memory_size, d.stream(), size, rows,
softmax_row_ptr.data(), softmax_col_ind.data(), softmax_and_grad_batch_ptr_copy.data(), softmax_row_ptr.data(),
softmax_values.data(), grad_softmax_row_ptr.data(), softmax_col_ind.data(), softmax_values.data(),
grad_softmax_col_ind.data(), grad_softmax_row_ptr.data(), grad_softmax_col_ind.data(),
grad_softmax_values.data(), gradient_values.data()); grad_softmax_values.data(), gradient_values.data()));
} }
return Status::OK(); return Status::OK();
@ -687,4 +701,4 @@ DEFINE_SOFTMAX_GRAD_GPU(double);
} // namespace tensorflow } // namespace tensorflow
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
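For readers less familiar with the launch-syntax change applied throughout this file, a minimal before/after sketch with a hypothetical kernel: the CUDA-only chevron launch becomes a GpuLaunchKernel call, which also builds under ROCm.

// Hypothetical example kernel and launcher illustrating the conversion.
__global__ void ExampleFillKernel(int* out, int size) {
  GPU_1D_KERNEL_LOOP(i, size) { out[i] = i; }
}

void LaunchExampleFill(const GPUDevice& d, int* out, int size) {
  GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
  // Previously (CUDA only):
  //   ExampleFillKernel<<<config.block_count, config.thread_per_block, 0,
  //                       d.stream()>>>(out, size);
  TF_CHECK_OK(GpuLaunchKernel(ExampleFillKernel, config.block_count,
                              config.thread_per_block,
                              /*shared_memory_size=*/0, d.stream(), out,
                              size));
}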
View File
@ -15,7 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS #define EIGEN_USE_THREADS
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif #endif
@ -36,7 +36,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/threadpool.h" #include "tensorflow/core/platform/threadpool.h"
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_solvers.h" #include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h" #include "tensorflow/core/kernels/cuda_sparse.h"
#endif #endif
@ -694,7 +694,7 @@ REGISTER_CPU(complex128)
#undef REGISTER_CPU #undef REGISTER_CPU
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU(T) \ #define REGISTER_GPU(T) \
REGISTER_KERNEL_BUILDER( \ REGISTER_KERNEL_BUILDER( \
@ -703,14 +703,16 @@ REGISTER_CPU(complex128)
REGISTER_GPU(float) REGISTER_GPU(float)
REGISTER_GPU(double) REGISTER_GPU(double)
#if GOOGLE_CUDA
REGISTER_GPU(complex64) REGISTER_GPU(complex64)
REGISTER_GPU(complex128) REGISTER_GPU(complex128)
#endif
#undef REGISTER_GPU #undef REGISTER_GPU
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace functor { namespace functor {
@ -723,7 +725,7 @@ class CSRSparseMatrixMatMul<GPUDevice, T> {
Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a, Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
typename TTypes<T>::UnalignedConstMatrix b, typename TTypes<T>::UnalignedConstMatrix b,
typename TTypes<T>::UnalignedMatrix c) { typename TTypes<T>::UnalignedMatrix c) {
CudaSparse cuda_sparse(ctx); GpuSparse cuda_sparse(ctx);
TF_RETURN_IF_ERROR(cuda_sparse.Initialize()); TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
{ {
// Use Csrmm to calculate: // Use Csrmm to calculate:
@ -741,19 +743,34 @@ class CSRSparseMatrixMatMul<GPUDevice, T> {
// transA must be non-transpose if transB is transpose (cusparse // transA must be non-transpose if transB is transpose (cusparse
// limitation). // limitation).
const cusparseOperation_t transA = CUSPARSE_OPERATION_NON_TRANSPOSE; #if GOOGLE_CUDA
const gpusparseOperation_t transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
#elif TENSORFLOW_USE_ROCM
const gpusparseOperation_t transA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
#endif
// transB: b is row-major, and cusparse requires col-major b (or // transB: b is row-major, and cusparse requires col-major b (or
// equivalently transB == transpose). this version is actually more // equivalently transB == transpose). this version is actually more
// efficient. // efficient.
const cusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE; #if GOOGLE_CUDA
const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE;
cusparseMatDescr_t descrA; gpusparseMatDescr_t descrA;
TF_RETURN_IF_CUSPARSE_ERROR(cusparseCreateMatDescr(&descrA)); TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
TF_RETURN_IF_CUSPARSE_ERROR( TF_RETURN_IF_GPUSPARSE_ERROR(
cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL)); cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
TF_RETURN_IF_CUSPARSE_ERROR( TF_RETURN_IF_GPUSPARSE_ERROR(
cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO)); cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
#elif TENSORFLOW_USE_ROCM
const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE;
gpusparseMatDescr_t descrA;
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA));
TF_RETURN_IF_GPUSPARSE_ERROR(
hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
TF_RETURN_IF_GPUSPARSE_ERROR(
hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
#endif
// A is (m, k), Bt is (ldb, k) and Ct is (ldc, n) // A is (m, k), Bt is (ldb, k) and Ct is (ldc, n)
const int k = b.dimension(0); const int k = b.dimension(0);
@ -796,13 +813,13 @@ template <typename T>
class CSRSparseMatrixMatVec<GPUDevice, T> { class CSRSparseMatrixMatVec<GPUDevice, T> {
public: public:
CSRSparseMatrixMatVec(bool transpose_a, bool conjugate_a) CSRSparseMatrixMatVec(bool transpose_a, bool conjugate_a)
: transA_(TransposeAndConjugateToCuSparseOp(transpose_a, conjugate_a, : transA_(TransposeAndConjugateToGpuSparseOp(transpose_a, conjugate_a,
&status_)) {} &status_)) {}
Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a, Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
const T* x, T* y) { const T* x, T* y) {
TF_RETURN_IF_ERROR(status_); TF_RETURN_IF_ERROR(status_);
CudaSparse cuda_sparse(ctx); GpuSparse cuda_sparse(ctx);
TF_RETURN_IF_ERROR(cuda_sparse.Initialize()); TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
{ {
// Use Csrmv to calculate: // Use Csrmv to calculate:
@ -815,12 +832,20 @@ class CSRSparseMatrixMatVec<GPUDevice, T> {
const T alpha = 1; const T alpha = 1;
const T beta = 0; const T beta = 0;
cusparseMatDescr_t descrA; gpusparseMatDescr_t descrA;
TF_RETURN_IF_CUSPARSE_ERROR(cusparseCreateMatDescr(&descrA)); #if GOOGLE_CUDA
TF_RETURN_IF_CUSPARSE_ERROR( TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
TF_RETURN_IF_GPUSPARSE_ERROR(
cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL)); cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
TF_RETURN_IF_CUSPARSE_ERROR( TF_RETURN_IF_GPUSPARSE_ERROR(
cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO)); cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
#elif TENSORFLOW_USE_ROCM
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA));
TF_RETURN_IF_GPUSPARSE_ERROR(
hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
TF_RETURN_IF_GPUSPARSE_ERROR(
hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
const int m = a.dense_shape_host(0); const int m = a.dense_shape_host(0);
const int n = a.dense_shape_host(1); const int n = a.dense_shape_host(1);
@ -836,11 +861,11 @@ class CSRSparseMatrixMatVec<GPUDevice, T> {
private: private:
Status status_; Status status_;
const cusparseOperation_t transA_; const gpusparseOperation_t transA_;
}; };
} // namespace functor } // namespace functor
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow } // namespace tensorflow
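To summarize the matrix-vector path above, a condensed sketch (hypothetical helper; buffers supplied by the caller) of the Csrmv call that follows the descriptor setup. The wrapper signature is the one instantiated in rocm_sparse.cc and its cuSPARSE counterpart.

// Hypothetical sketch: y = alpha * op(A) * x + beta * y through GpuSparse,
// once descrA has been created with the toolchain-specific calls above.
template <typename T>
Status CsrMatVec(OpKernelContext* ctx, gpusparseOperation_t transA,
                 const gpusparseMatDescr_t descrA, int m, int n, int nnz,
                 const T* csr_values, const int* csr_row_ptr,
                 const int* csr_col_ind, const T* x, T* y) {
  GpuSparse gpu_sparse(ctx);
  TF_RETURN_IF_ERROR(gpu_sparse.Initialize());
  const T alpha = 1;
  const T beta = 0;
  return gpu_sparse.Csrmv(transA, m, n, nnz, &alpha, descrA, csr_values,
                          csr_row_ptr, csr_col_ind, x, &beta, y);
}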
View File
@ -15,7 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS #define EIGEN_USE_THREADS
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif #endif
@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse/kernels.h" #include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h" #include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_sparse.h" #include "tensorflow/core/kernels/cuda_sparse.h"
#endif #endif
@ -101,22 +101,24 @@ class CSRMulOp : public OpKernel {
Name("SparseMatrixMul").Device(DEVICE_##DEV).TypeConstraint<T>("T"), \ Name("SparseMatrixMul").Device(DEVICE_##DEV).TypeConstraint<T>("T"), \
CSRMulOp<DEV##Device, T>); CSRMulOp<DEV##Device, T>);
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU(T) REGISTER(GPU, T) #define REGISTER_GPU(T) REGISTER(GPU, T)
REGISTER_GPU(float) REGISTER_GPU(float)
REGISTER_GPU(double) REGISTER_GPU(double)
#if GOOGLE_CUDA
REGISTER_GPU(complex64) REGISTER_GPU(complex64)
REGISTER_GPU(complex128) REGISTER_GPU(complex128)
#endif
#undef REGISTER_GPU #undef REGISTER_GPU
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#undef REGISTER #undef REGISTER
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace functor { namespace functor {
@ -159,13 +161,15 @@ class CSRSparseMatrixMulScalar<GPUDevice, T> {
DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double); DECLARE_GPU_SPEC(double);
#if GOOGLE_CUDA
DECLARE_GPU_SPEC(std::complex<float>); DECLARE_GPU_SPEC(std::complex<float>);
DECLARE_GPU_SPEC(std::complex<double>); DECLARE_GPU_SPEC(std::complex<double>);
#endif
#undef DECLARE_GPU_SPEC #undef DECLARE_GPU_SPEC
} // namespace functor } // namespace functor
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow } // namespace tensorflow
View File
@ -15,7 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS #define EIGEN_USE_THREADS
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif #endif
@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse/kernels.h" #include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h" #include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_solvers.h" #include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h" #include "tensorflow/core/kernels/cuda_sparse.h"
#endif #endif
@ -67,11 +67,11 @@ class CSRNNZOp : public OpKernel {
REGISTER(CPU) REGISTER(CPU)
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER(GPU) REGISTER(GPU)
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#undef REGISTER #undef REGISTER
View File
@ -19,7 +19,7 @@ limitations under the License.
#define EIGEN_USE_THREADS #define EIGEN_USE_THREADS
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_sparse.h" #include "tensorflow/core/kernels/cuda_sparse.h"
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif #endif
@ -84,7 +84,7 @@ class CSRSoftmaxOp : public OpKernel {
} }
}; };
#ifdef GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER(DEV, T) \ #define REGISTER(DEV, T) \
REGISTER_KERNEL_BUILDER(Name("SparseMatrixSoftmax") \ REGISTER_KERNEL_BUILDER(Name("SparseMatrixSoftmax") \
.Device(DEVICE_##DEV) \ .Device(DEVICE_##DEV) \
@ -110,7 +110,7 @@ DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC #undef DECLARE_GPU_SPEC
} // namespace functor } // namespace functor
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename Device, typename T> template <typename Device, typename T>
class CSRSoftmaxGradOp : public OpKernel { class CSRSoftmaxGradOp : public OpKernel {
@ -193,7 +193,7 @@ class CSRSoftmaxGradOp : public OpKernel {
} }
}; };
#ifdef GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER(DEV, T) \ #define REGISTER(DEV, T) \
REGISTER_KERNEL_BUILDER(Name("SparseMatrixSoftmaxGrad") \ REGISTER_KERNEL_BUILDER(Name("SparseMatrixSoftmaxGrad") \
.Device(DEVICE_##DEV) \ .Device(DEVICE_##DEV) \
@ -220,6 +220,6 @@ DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC #undef DECLARE_GPU_SPEC
} // namespace functor } // namespace functor
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow } // namespace tensorflow
View File
@ -15,7 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS #define EIGEN_USE_THREADS
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif #endif
@ -35,7 +35,7 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse/sparse_matrix.h" #include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#include "tensorflow/core/util/work_sharder.h" #include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_solvers.h" #include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h" #include "tensorflow/core/kernels/cuda_sparse.h"
#endif #endif
@ -498,22 +498,24 @@ REGISTER_CPU(complex128)
.TypeConstraint<T>("type"), \ .TypeConstraint<T>("type"), \
CSRSparseMatMulGPUOp<DEV##Device, T>); CSRSparseMatMulGPUOp<DEV##Device, T>);
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU(T) REGISTER(GPU, T) #define REGISTER_GPU(T) REGISTER(GPU, T)
REGISTER_GPU(float) REGISTER_GPU(float)
REGISTER_GPU(double) REGISTER_GPU(double)
#if GOOGLE_CUDA
REGISTER_GPU(complex64) REGISTER_GPU(complex64)
REGISTER_GPU(complex128) REGISTER_GPU(complex128)
#endif // GOOGLE_CUDA
#undef REGISTER_GPU #undef REGISTER_GPU
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#undef REGISTER #undef REGISTER
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace functor { namespace functor {
template <typename T> template <typename T>
struct CSRSparseSparseMatrixMatMul<GPUDevice, T> struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
@ -527,11 +529,20 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
adjoint_a_(adjoint_a), adjoint_a_(adjoint_a),
transpose_b_(transpose_b) { transpose_b_(transpose_b) {
// TODO(ebrevdo): Figure out why transposed implementations crash cuSparse. // TODO(ebrevdo): Figure out why transposed implementations crash cuSparse.
#if GOOGLE_CUDA
transA_ = transpose_a ? (adjoint_a ? CUSPARSE_OPERATION_TRANSPOSE transA_ = transpose_a ? (adjoint_a ? CUSPARSE_OPERATION_TRANSPOSE
: CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE) : CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE)
: CUSPARSE_OPERATION_NON_TRANSPOSE; : CUSPARSE_OPERATION_NON_TRANSPOSE;
transB_ = transpose_b ? CUSPARSE_OPERATION_TRANSPOSE transB_ = transpose_b ? CUSPARSE_OPERATION_TRANSPOSE
: CUSPARSE_OPERATION_NON_TRANSPOSE; : CUSPARSE_OPERATION_NON_TRANSPOSE;
#elif TENSORFLOW_USE_ROCM
transA_ = transpose_a
? (adjoint_a ? HIPSPARSE_OPERATION_TRANSPOSE
: HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE)
: HIPSPARSE_OPERATION_NON_TRANSPOSE;
transB_ = transpose_b ? HIPSPARSE_OPERATION_TRANSPOSE
: HIPSPARSE_OPERATION_NON_TRANSPOSE;
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} }
Status Initialize() { Status Initialize() {
@ -630,20 +641,20 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
private: private:
OpKernelContext* ctx_; OpKernelContext* ctx_;
CudaSparse cuda_sparse_; GpuSparse cuda_sparse_;
bool initialized_; bool initialized_;
bool transpose_a_; bool transpose_a_;
bool adjoint_a_; bool adjoint_a_;
bool transpose_b_; bool transpose_b_;
CudaSparseMatrixDescriptor descrA_; GpuSparseMatrixDescriptor descrA_;
CudaSparseMatrixDescriptor descrB_; GpuSparseMatrixDescriptor descrB_;
CudaSparseMatrixDescriptor descrC_; GpuSparseMatrixDescriptor descrC_;
cusparseOperation_t transA_; gpusparseOperation_t transA_;
cusparseOperation_t transB_; gpusparseOperation_t transB_;
}; };
} // namespace functor } // namespace functor
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow } // namespace tensorflow
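For context on the sparse-sparse product these functors wrap, a compressed sketch of the two-phase csrgemm flow: CsrgemmNnz computes the output row pointers and total nnz, then Csrgemm fills the column indices and values. Buffer names are hypothetical and output allocation is elided.

// Hypothetical sketch of the two-phase C = op(A) * op(B) computation via the
// GpuSparse wrapper; error handling via TF_RETURN_IF_ERROR.
int nnzC = 0;
// Phase 1: output structure (row pointers) and nnz count.
TF_RETURN_IF_ERROR(gpu_sparse.CsrgemmNnz(
    transA, transB, m, n, k, descrA, nnzA, rowPtrA, colIndA,
    descrB, nnzB, rowPtrB, colIndB, descrC, rowPtrC, &nnzC));
// ... allocate valC (nnzC scalars) and colIndC (nnzC ints) here ...
// Phase 2: numeric computation into the allocated buffers.
TF_RETURN_IF_ERROR(gpu_sparse.Csrgemm(
    transA, transB, m, n, k, descrA, nnzA, valA, rowPtrA, colIndA,
    descrB, nnzB, valB, rowPtrB, colIndB, descrC, valC, rowPtrC, colIndC));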
View File
@ -15,7 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS #define EIGEN_USE_THREADS
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif #endif
View File
@ -18,7 +18,7 @@ limitations under the License.
#define EIGEN_USE_THREADS #define EIGEN_USE_THREADS
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif #endif
View File
@ -15,7 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS #define EIGEN_USE_THREADS
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif #endif
@ -29,7 +29,7 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse/kernels.h" #include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h" #include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_solvers.h" #include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h" #include "tensorflow/core/kernels/cuda_sparse.h"
#endif #endif
@ -116,12 +116,14 @@ REGISTER(CPU, double)
REGISTER(CPU, complex64) REGISTER(CPU, complex64)
REGISTER(CPU, complex128) REGISTER(CPU, complex128)
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER(GPU, float) REGISTER(GPU, float)
REGISTER(GPU, double) REGISTER(GPU, double)
#if GOOGLE_CUDA
REGISTER(GPU, complex64) REGISTER(GPU, complex64)
REGISTER(GPU, complex128) REGISTER(GPU, complex128)
#endif
#undef REGISTER #undef REGISTER
@ -139,12 +141,14 @@ namespace functor {
DECLARE_GPU_SPEC(int32); DECLARE_GPU_SPEC(int32);
DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double); DECLARE_GPU_SPEC(double);
#if GOOGLE_CUDA
DECLARE_GPU_SPEC(complex64); DECLARE_GPU_SPEC(complex64);
DECLARE_GPU_SPEC(complex128); DECLARE_GPU_SPEC(complex128);
#endif
#undef DECLARE_GPU_SPEC #undef DECLARE_GPU_SPEC
} // namespace functor } // namespace functor
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow } // namespace tensorflow
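The registration change above shows the layering used throughout this commit: float and double GPU kernels are registered on both platforms, while the complex64/complex128 variants stay behind a nested #if GOOGLE_CUDA because the hipSPARSE path does not cover them yet. A compressed sketch of that pattern follows; REGISTER here is a stand-in macro, not the real kernel registration.

// Placeholder macro; the real code expands to REGISTER_KERNEL_BUILDER(...).
#define REGISTER(DEV, T)

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER(GPU, float)    // built for both the cuSPARSE and hipSPARSE backends
REGISTER(GPU, double)
#if GOOGLE_CUDA         // complex kernels remain CUDA-only for now
REGISTER(GPU, complex64)
REGISTER(GPU, complex128)
#endif  // GOOGLE_CUDA
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER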
View File
@ -15,7 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS #define EIGEN_USE_THREADS
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif #endif
@ -30,13 +30,18 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse/kernels.h" #include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h" #include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/cuda_solvers.h" #include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h" #include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h" #endif
#if GOOGLE_CUDA
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
using ::perftools::gputools::cuda::ScopedActivateExecutorContext; using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
#elif TENSORFLOW_USE_ROCM
#include "tensorflow/stream_executor/rocm/rocm_activation.h"
using ::perftools::gputools::rocm::ScopedActivateExecutorContext;
#endif #endif
namespace tensorflow { namespace tensorflow {
@ -104,7 +109,7 @@ class SparseTensorToCSRSparseMatrixCPUOp : public OpKernel {
} }
}; };
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename Device, typename T> template <typename Device, typename T>
class SparseTensorToCSRSparseMatrixGPUOp : public AsyncOpKernel { class SparseTensorToCSRSparseMatrixGPUOp : public AsyncOpKernel {
@ -302,7 +307,7 @@ struct COOSparseMatrixToCSRSparseMatrix<GPUDevice> {
Status operator()(OpKernelContext* c, const int rows, const int cols, Status operator()(OpKernelContext* c, const int rows, const int cols,
TTypes<int>::UnalignedVec coo_row_ind, TTypes<int>::UnalignedVec coo_row_ind,
TTypes<int>::UnalignedVec csr_row_ptr) { TTypes<int>::UnalignedVec csr_row_ptr) {
CudaSparse cuda_sparse(c); GpuSparse cuda_sparse(c);
TF_RETURN_IF_ERROR(cuda_sparse.Initialize()); TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
return cuda_sparse.Coo2csr(coo_row_ind.data(), return cuda_sparse.Coo2csr(coo_row_ind.data(),
/*nnz*/ coo_row_ind.size(), /*nnz*/ coo_row_ind.size(),
@ -322,12 +327,14 @@ extern template struct COOSparseMatrixToCSRSparseMatrix<GPUDevice>;
REGISTER_GPU(float) REGISTER_GPU(float)
REGISTER_GPU(double) REGISTER_GPU(double)
#if GOOGLE_CUDA
REGISTER_GPU(complex64) REGISTER_GPU(complex64)
REGISTER_GPU(complex128) REGISTER_GPU(complex128)
#endif
#undef REGISTER_GPU #undef REGISTER_GPU
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_CPU(T) \ #define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER(Name("SparseTensorToCSRSparseMatrix") \ REGISTER_KERNEL_BUILDER(Name("SparseTensorToCSRSparseMatrix") \
View File
@ -19,7 +19,7 @@ limitations under the License.
#define EIGEN_USE_THREADS #define EIGEN_USE_THREADS
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_sparse.h" #include "tensorflow/core/kernels/cuda_sparse.h"
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif #endif
@ -132,9 +132,12 @@ REGISTER_TRANSPOSE(CPU, double)
REGISTER_TRANSPOSE(CPU, complex64) REGISTER_TRANSPOSE(CPU, complex64)
REGISTER_TRANSPOSE(CPU, complex128) REGISTER_TRANSPOSE(CPU, complex128)
#ifdef GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_TRANSPOSE(GPU, float) REGISTER_TRANSPOSE(GPU, float)
REGISTER_TRANSPOSE(GPU, double) REGISTER_TRANSPOSE(GPU, double)
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
REGISTER_TRANSPOSE(GPU, complex64) REGISTER_TRANSPOSE(GPU, complex64)
REGISTER_TRANSPOSE(GPU, complex128) REGISTER_TRANSPOSE(GPU, complex128)
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA
@ -250,16 +253,20 @@ struct CSRSparseMatrixTransposeComponent<CPUDevice, T> {
} }
}; };
#ifdef GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename T> template <typename T>
struct CSRSparseMatrixTransposeComponent<GPUDevice, T> { struct CSRSparseMatrixTransposeComponent<GPUDevice, T> {
Status operator()(OpKernelContext* ctx, const ConstCSRComponent<T>& x, Status operator()(OpKernelContext* ctx, const ConstCSRComponent<T>& x,
CSRComponent<T>* y) { CSRComponent<T>* y) {
TF_RETURN_IF_ERROR(ValidateTransposeInputs(x, *y)); TF_RETURN_IF_ERROR(ValidateTransposeInputs(x, *y));
CudaSparse cuda_sparse(ctx); GpuSparse cuda_sparse(ctx);
TF_RETURN_IF_ERROR(cuda_sparse.Initialize()); TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
const cusparseAction_t copyValues = CUSPARSE_ACTION_NUMERIC; #if GOOGLE_CUDA
const gpusparseAction_t copyValues = CUSPARSE_ACTION_NUMERIC;
#elif TENSORFLOW_USE_ROCM
const gpusparseAction_t copyValues = HIPSPARSE_ACTION_NUMERIC;
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
const int rank = x.dense_shape_host.size(); const int rank = x.dense_shape_host.size();
const int m = x.row_ptr.size() - 1; const int m = x.row_ptr.size() - 1;
const int n = x.dense_shape_host(rank - 1); const int n = x.dense_shape_host(rank - 1);
@ -279,7 +286,7 @@ struct CSRSparseMatrixTransposeComponent<GPUDevice, T> {
return Status::OK(); return Status::OK();
} }
}; };
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace functor } // namespace functor
} // namespace tensorflow } // namespace tensorflow
View File
@ -15,7 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS #define EIGEN_USE_THREADS
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif #endif
@ -74,7 +74,7 @@ Status CSRSparseMatrixZerosLikeHelper(OpKernelContext* ctx,
} // namespace } // namespace
#ifdef GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER(DEV) \ #define REGISTER(DEV) \
REGISTER_KERNEL_BUILDER(Name("SparseMatrixZeros") \ REGISTER_KERNEL_BUILDER(Name("SparseMatrixZeros") \
.Device(DEVICE_##DEV) \ .Device(DEVICE_##DEV) \
@ -88,6 +88,6 @@ REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(
CSRSparseMatrixZerosLikeHelper<GPUDevice>); CSRSparseMatrixZerosLikeHelper<GPUDevice>);
#undef REGISTER #undef REGISTER
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow } // namespace tensorflow
View File
@ -18,7 +18,7 @@ limitations under the License.
#define EIGEN_USE_THREADS #define EIGEN_USE_THREADS
#if GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif #endif
View File
@ -156,7 +156,7 @@ class TridiagonalSolveOpGpuLinalg : public LinearAlgebraOp<Scalar> {
k); k);
return; return;
} }
std::unique_ptr<CudaSparse> cusparse_solver(new CudaSparse(context)); std::unique_ptr<GpuSparse> cusparse_solver(new GpuSparse(context));
OP_REQUIRES_OK(context, cusparse_solver->Initialize()); OP_REQUIRES_OK(context, cusparse_solver->Initialize());
if (k == 1) { if (k == 1) {
// rhs is copied into x, then gtsv replaces x with solution. // rhs is copied into x, then gtsv replaces x with solution.
@ -196,20 +196,20 @@ class TridiagonalSolveOpGpuLinalg : public LinearAlgebraOp<Scalar> {
} }
void SolveWithGtsv(OpKernelContext* context, void SolveWithGtsv(OpKernelContext* context,
std::unique_ptr<CudaSparse>& cusparse_solver, std::unique_ptr<GpuSparse>& cusparse_solver,
const Scalar* superdiag, const Scalar* diag, const Scalar* superdiag, const Scalar* diag,
const Scalar* subdiag, Scalar* rhs, const int num_eqs, const Scalar* subdiag, Scalar* rhs, const int num_eqs,
const int num_rhs) const { const int num_rhs) const {
#if CUDA_VERSION < 9000 #if CUDA_VERSION < 9000
auto function = pivoting_ ? &CudaSparse::Gtsv<Scalar> auto function =
: &CudaSparse::GtsvNoPivot<Scalar>; pivoting_ ? &GpuSparse::Gtsv<Scalar> : &GpuSparse::GtsvNoPivot<Scalar>;
OP_REQUIRES_OK( OP_REQUIRES_OK(
context, (cusparse_solver.get()->*function)( context, (cusparse_solver.get()->*function)(
num_eqs, num_rhs, subdiag, diag, superdiag, rhs, num_eqs)); num_eqs, num_rhs, subdiag, diag, superdiag, rhs, num_eqs));
#else #else
auto buffer_function = pivoting_ auto buffer_function = pivoting_
? &CudaSparse::Gtsv2BufferSizeExt<Scalar> ? &GpuSparse::Gtsv2BufferSizeExt<Scalar>
: &CudaSparse::Gtsv2NoPivotBufferSizeExt<Scalar>; : &GpuSparse::Gtsv2NoPivotBufferSizeExt<Scalar>;
size_t buffer_size; size_t buffer_size;
OP_REQUIRES_OK(context, (cusparse_solver.get()->*buffer_function)( OP_REQUIRES_OK(context, (cusparse_solver.get()->*buffer_function)(
num_eqs, num_rhs, subdiag, diag, superdiag, rhs, num_eqs, num_rhs, subdiag, diag, superdiag, rhs,
@ -220,8 +220,8 @@ class TridiagonalSolveOpGpuLinalg : public LinearAlgebraOp<Scalar> {
context->allocate_temp(DT_UINT8, temp_shape, &temp_tensor)); context->allocate_temp(DT_UINT8, temp_shape, &temp_tensor));
void* buffer = temp_tensor.flat<std::uint8_t>().data(); void* buffer = temp_tensor.flat<std::uint8_t>().data();
auto solver_function = pivoting_ ? &CudaSparse::Gtsv2<Scalar> auto solver_function = pivoting_ ? &GpuSparse::Gtsv2<Scalar>
: &CudaSparse::Gtsv2NoPivot<Scalar>; : &GpuSparse::Gtsv2NoPivot<Scalar>;
OP_REQUIRES_OK(context, (cusparse_solver.get()->*solver_function)( OP_REQUIRES_OK(context, (cusparse_solver.get()->*solver_function)(
num_eqs, num_rhs, subdiag, diag, superdiag, rhs, num_eqs, num_rhs, subdiag, diag, superdiag, rhs,
num_eqs, buffer)); num_eqs, buffer));
@ -315,7 +315,7 @@ class TridiagonalSolveOpGpu : public OpKernel {
rhs.flat<Scalar>().size()); rhs.flat<Scalar>().size());
Scalar* x = output->flat<Scalar>().data(); Scalar* x = output->flat<Scalar>().data();
std::unique_ptr<CudaSparse> cusparse_solver(new CudaSparse(context)); std::unique_ptr<GpuSparse> cusparse_solver(new GpuSparse(context));
OP_REQUIRES_OK(context, cusparse_solver->Initialize()); OP_REQUIRES_OK(context, cusparse_solver->Initialize());
#if CUDA_VERSION < 9000 #if CUDA_VERSION < 9000
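SolveWithGtsv above picks between the pivoting and non-pivoting gtsv entry points by taking a pointer to the corresponding GpuSparse member function and invoking it through the smart pointer. The following self-contained sketch shows only that dispatch shape; FakeSolver and its methods are placeholders, not the real GpuSparse API.

#include <iostream>
#include <memory>

struct FakeSolver {
  int Gtsv(int n) { return n; }          // stands in for GpuSparse::Gtsv<Scalar>
  int GtsvNoPivot(int n) { return -n; }  // stands in for GpuSparse::GtsvNoPivot<Scalar>
};

int Solve(bool pivoting, std::unique_ptr<FakeSolver>& solver, int n) {
  // Same call shape as (cusparse_solver.get()->*function)(...) above.
  auto function = pivoting ? &FakeSolver::Gtsv : &FakeSolver::GtsvNoPivot;
  return (solver.get()->*function)(n);
}

int main() {
  auto solver = std::make_unique<FakeSolver>();
  std::cout << Solve(true, solver, 3) << " " << Solve(false, solver, 3) << "\n";
  return 0;
}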
View File
@ -28,7 +28,6 @@ cuda_py_test(
size = "medium", size = "medium",
srcs = ["csr_sparse_matrix_test.py"], srcs = ["csr_sparse_matrix_test.py"],
main = "csr_sparse_matrix_test.py", main = "csr_sparse_matrix_test.py",
tags = ["no_rocm"],
deps = [ deps = [
"//tensorflow/python/ops/linalg/sparse", "//tensorflow/python/ops/linalg/sparse",
], ],
@ -40,7 +39,6 @@ cuda_py_test(
srcs = ["csr_sparse_matrix_ops_test.py"], srcs = ["csr_sparse_matrix_ops_test.py"],
main = "csr_sparse_matrix_ops_test.py", main = "csr_sparse_matrix_ops_test.py",
shard_count = 10, shard_count = 10,
tags = ["no_rocm"],
deps = [ deps = [
"//tensorflow/python/ops/linalg/sparse", "//tensorflow/python/ops/linalg/sparse",
"//tensorflow/python/ops/linalg/sparse:gen_sparse_csr_matrix_ops", "//tensorflow/python/ops/linalg/sparse:gen_sparse_csr_matrix_ops",
@ -53,7 +51,6 @@ cuda_py_test(
srcs = ["csr_sparse_matrix_grad_test.py"], srcs = ["csr_sparse_matrix_grad_test.py"],
main = "csr_sparse_matrix_grad_test.py", main = "csr_sparse_matrix_grad_test.py",
shard_count = 50, shard_count = 50,
tags = ["no_rocm"],
deps = [ deps = [
"//tensorflow/python/ops/linalg/sparse", "//tensorflow/python/ops/linalg/sparse",
], ],
@ -65,7 +62,6 @@ cuda_py_test(
srcs = ["csr_sparse_matrix_dense_mat_mul_grad_test.py"], srcs = ["csr_sparse_matrix_dense_mat_mul_grad_test.py"],
main = "csr_sparse_matrix_dense_mat_mul_grad_test.py", main = "csr_sparse_matrix_dense_mat_mul_grad_test.py",
shard_count = 50, shard_count = 50,
tags = ["no_rocm"],
deps = [ deps = [
"//tensorflow/python/ops/linalg/sparse", "//tensorflow/python/ops/linalg/sparse",
], ],
@ -77,7 +73,6 @@ cuda_py_test(
srcs = ["csr_sparse_matrix_sparse_mat_mul_grad_test.py"], srcs = ["csr_sparse_matrix_sparse_mat_mul_grad_test.py"],
main = "csr_sparse_matrix_sparse_mat_mul_grad_test.py", main = "csr_sparse_matrix_sparse_mat_mul_grad_test.py",
shard_count = 50, shard_count = 50,
tags = ["no_rocm"],
deps = [ deps = [
"//tensorflow/python/ops/linalg/sparse", "//tensorflow/python/ops/linalg/sparse",
], ],
View File
@ -106,7 +106,11 @@ class CSRSparseMatrixDenseMatMulGradTest(test.TestCase):
# These tests are refactored from sparse_csr_matrix_grad_test to keep its size # These tests are refactored from sparse_csr_matrix_grad_test to keep its size
# "medium". # "medium".
for dtype in (np.float32, np.complex64): dtypes_to_test = [np.float32]
if not test.is_built_with_rocm():
# complex type is not supported on the ROCm platform
dtypes_to_test += [np.complex64]
for dtype in dtypes_to_test:
for (t_a, t_b, adj_a, adj_b, t_out, for (t_a, t_b, adj_a, adj_b, t_out,
conj_out) in itertools.product(*(([False, True],) * 6)): conj_out) in itertools.product(*(([False, True],) * 6)):
View File
@ -84,6 +84,9 @@ class CSRSparseMatrixGradTest(test.TestCase):
if not self._gpu_available: if not self._gpu_available:
return return
if test.is_built_with_rocm():
self.skipTest("sparse-matrix-add op not supported on ROCm")
sparsify = lambda m: m * (m > 0) sparsify = lambda m: m * (m > 0)
for dense_shape in ([53, 65, 127], [127, 65]): for dense_shape in ([53, 65, 127], [127, 65]):
a_mats_val = sparsify(np.random.randn(*dense_shape)) a_mats_val = sparsify(np.random.randn(*dense_shape))
View File
@ -432,6 +432,9 @@ class CSRSparseMatrixOpsTest(test.TestCase):
if not self._gpu_available: if not self._gpu_available:
return return
if test.is_built_with_rocm():
self.skipTest("sparse-matrix-add op not supported on ROCm")
a_indices = np.array([[0, 0], [2, 3]]) a_indices = np.array([[0, 0], [2, 3]])
a_values = np.array([1.0, 5.0]).astype(np.float32) a_values = np.array([1.0, 5.0]).astype(np.float32)
a_dense_shape = [5, 6] a_dense_shape = [5, 6]
@ -469,6 +472,9 @@ class CSRSparseMatrixOpsTest(test.TestCase):
if not self._gpu_available: if not self._gpu_available:
return return
if test.is_built_with_rocm():
self.skipTest("sparse-matrix-add op not supported on ROCm")
sparsify = lambda m: m * (m > 0) sparsify = lambda m: m * (m > 0)
dense_shape = [53, 65, 127] dense_shape = [53, 65, 127]
a_mats = sparsify(np.random.randn(*dense_shape)).astype(np.float32) a_mats = sparsify(np.random.randn(*dense_shape)).astype(np.float32)
@ -511,6 +517,9 @@ class CSRSparseMatrixOpsTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testSparseMatrixMatMulConjugateOutput(self): def testSparseMatrixMatMulConjugateOutput(self):
if test.is_built_with_rocm():
self.skipTest("complex type not supported on ROCm")
for shapes in [[(5, 6), (6, 1)], [(5, 6), (6, 2)]]: for shapes in [[(5, 6), (6, 1)], [(5, 6), (6, 2)]]:
a_indices = np.array([[0, 0], [2, 3]]) a_indices = np.array([[0, 0], [2, 3]])
a_values = np.array([1.0 + 1.j, 5.0 - 2.j]).astype(np.complex64) a_values = np.array([1.0 + 1.j, 5.0 - 2.j]).astype(np.complex64)
@ -533,8 +542,19 @@ class CSRSparseMatrixOpsTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testLargeBatchSparseMatrixMatMul(self): def testLargeBatchSparseMatrixMatMul(self):
dtypes_to_test = [np.float32]
if not test.is_built_with_rocm():
# complex types are not supported on the ROCm platform
dtypes_to_test += [np.complex64]
if test.is_built_with_rocm():
# TODO(rocm): fix this
# This test is currently failing on the ROCm platform
# Re-enable it once the fix is available
self.skipTest("hipSPARSE all failure on the ROCm platform")
sparsify = lambda m: m * (m > 0) sparsify = lambda m: m * (m > 0)
for dtype in np.float32, np.complex64: for dtype in dtypes_to_test:
for (transpose_a, transpose_b) in ((False, False), (False, True), for (transpose_a, transpose_b) in ((False, False), (False, True),
(True, False), (True, True)): (True, False), (True, True)):
for (adjoint_a, adjoint_b) in ((False, False), (False, True), for (adjoint_a, adjoint_b) in ((False, False), (False, True),
@ -584,8 +604,19 @@ class CSRSparseMatrixOpsTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testLargeBatchSparseMatrixMatMulTransposed(self): def testLargeBatchSparseMatrixMatMulTransposed(self):
dtypes_to_test = [np.float32]
if not test.is_built_with_rocm():
# complex types are not supported on the ROCm platform
dtypes_to_test += [np.complex64]
if test.is_built_with_rocm():
# TODO(rocm): fix this
# This test is currently failing on the ROCm platform
# Re-enable it once the fix is available
self.skipTest("hipSPARSE all failure on the ROCm platform")
sparsify = lambda m: m * (m > 0) sparsify = lambda m: m * (m > 0)
for dtype in np.float32, np.complex64: for dtype in dtypes_to_test:
for (transpose_a, transpose_b) in ((False, False), (False, True), for (transpose_a, transpose_b) in ((False, False), (False, True),
(True, False), (True, True)): (True, False), (True, True)):
for (adjoint_a, adjoint_b) in ((False, False), (False, True), for (adjoint_a, adjoint_b) in ((False, False), (False, True),
@ -636,6 +667,10 @@ class CSRSparseMatrixOpsTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testLargeBatchSparseMatrixMatMulConjugate(self): def testLargeBatchSparseMatrixMatMulConjugate(self):
if test.is_built_with_rocm():
# complex types are not yet supported on the ROCm platform
self.skipTest("complex type not supported on ROCm")
sparsify = lambda m: m * (m > 0) sparsify = lambda m: m * (m > 0)
a_dense_shape = [53, 65, 127] a_dense_shape = [53, 65, 127]
b_dense_shape = [53, 127, 67] b_dense_shape = [53, 127, 67]
@ -767,6 +802,10 @@ class CSRSparseMatrixOpsTest(test.TestCase):
if not self._gpu_available: if not self._gpu_available:
return return
if test.is_built_with_rocm():
# sparse-matrix-add op is not yet supported on the ROCm platform
self.skipTest("sparse-matrix-add op not supported on ROCm")
sparsify = lambda m: m * (m > 0) sparsify = lambda m: m * (m > 0)
dense_shape = [53, 65, 127] dense_shape = [53, 65, 127]
matrices = [ matrices = [
@ -1154,9 +1193,10 @@ class CSRSparseMatrixOpsTest(test.TestCase):
] # ] #
]).astype(np.complex128) ]).astype(np.complex128)
data_types = [ data_types = [dtypes.float32, dtypes.float64]
dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128 if not test.is_built_with_rocm():
] # complex type is not supported on the ROCm platform
data_types += [dtypes.complex64, dtypes.complex128]
for dtype in data_types: for dtype in data_types:
sparse_matrix = dense_to_csr_sparse_matrix( sparse_matrix = dense_to_csr_sparse_matrix(
math_ops.cast(dense_mat, dtype)) math_ops.cast(dense_mat, dtype))
View File
@ -154,7 +154,11 @@ class SparseMatrixMatmulTest(test.TestCase):
sparsify = lambda m: m * (m > 0) sparsify = lambda m: m * (m > 0)
dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13] dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13]
dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15] dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15]
for dtype in np.float32, np.complex64: dtypes_to_test = [np.float32]
if not test.is_built_with_rocm():
# complex type is not supported on the ROCm platform
dtypes_to_test += [np.complex64]
for dtype in dtypes_to_test:
a_mats = sparsify((np.random.randn(*dense_shape_a) + a_mats = sparsify((np.random.randn(*dense_shape_a) +
1.j * np.random.randn(*dense_shape_a))).astype(dtype) 1.j * np.random.randn(*dense_shape_a))).astype(dtype)
b_mats = sparsify((np.random.randn(*dense_shape_b) + b_mats = sparsify((np.random.randn(*dense_shape_b) +
@ -194,7 +198,11 @@ class SparseMatrixMatmulTest(test.TestCase):
sparsify = lambda m: m * (m > 0) sparsify = lambda m: m * (m > 0)
dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13] dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13]
dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15] dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15]
for dtype in np.float32, np.complex64: dtypes_to_test = [np.float32]
if not test.is_built_with_rocm():
# complex type is not supported on the ROCm platform
dtypes_to_test += [np.complex64]
for dtype in dtypes_to_test:
a_mats = sparsify((np.random.randn(*dense_shape_a) + a_mats = sparsify((np.random.randn(*dense_shape_a) +
1.j * np.random.randn(*dense_shape_a))).astype(dtype) 1.j * np.random.randn(*dense_shape_a))).astype(dtype)
b_mats = (np.random.randn(*dense_shape_b) + b_mats = (np.random.randn(*dense_shape_b) +
@ -231,7 +239,11 @@ class SparseMatrixMatmulTest(test.TestCase):
sparsify = lambda m: m * (m > 0) sparsify = lambda m: m * (m > 0)
dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13] dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13]
dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15] dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15]
for dtype in np.float32, np.complex64: dtypes_to_test = [np.float32]
if not test.is_built_with_rocm():
# complex type is not supported on the ROCm platform
dtypes_to_test += [np.complex64]
for dtype in dtypes_to_test:
a_mats = (np.random.randn(*dense_shape_a) + a_mats = (np.random.randn(*dense_shape_a) +
1.j * np.random.randn(*dense_shape_a)).astype(dtype) 1.j * np.random.randn(*dense_shape_a)).astype(dtype)
b_mats = sparsify((np.random.randn(*dense_shape_b) + b_mats = sparsify((np.random.randn(*dense_shape_b) +
View File
@ -137,4 +137,11 @@ cc_library(
], ],
) )
cc_import(
name = "hipsparse",
hdrs = glob(["rocm/include/hipsparse/**",]),
shared_library = "rocm/lib/%{hipsparse_lib}",
visibility = ["//visibility:public"],
)
%{copy_rules} %{copy_rules}
View File
@ -16,6 +16,6 @@ limitations under the License.
#ifndef ROCM_ROCM_CONFIG_H_ #ifndef ROCM_ROCM_CONFIG_H_
#define ROCM_ROCM_CONFIG_H_ #define ROCM_ROCM_CONFIG_H_
#define TF_ROCM_TOOLKIT_PATH "/opt/rocm" #define TF_ROCM_TOOLKIT_PATH "%{rocm_toolkit_path}"
#endif // ROCM_ROCM_CONFIG_H_ #endif // ROCM_ROCM_CONFIG_H_
View File
@ -191,50 +191,50 @@ def _rocm_include_path(repository_ctx, rocm_config):
inc_dirs.append(rocm_config.rocm_toolkit_path + "/include") inc_dirs.append(rocm_config.rocm_toolkit_path + "/include")
# Add HSA headers # Add HSA headers
inc_dirs.append("/opt/rocm/hsa/include") inc_dirs.append(rocm_config.rocm_toolkit_path + "/hsa/include")
# Add HIP headers # Add HIP headers
inc_dirs.append("/opt/rocm/include/hip") inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/hip")
inc_dirs.append("/opt/rocm/include/hip/hcc_detail") inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/hip/hcc_detail")
inc_dirs.append("/opt/rocm/hip/include") inc_dirs.append(rocm_config.rocm_toolkit_path + "/hip/include")
# Add HIP-Clang headers # Add HIP-Clang headers
inc_dirs.append("/opt/rocm/llvm/lib/clang/8.0/include") inc_dirs.append(rocm_config.rocm_toolkit_path + "/llvm/lib/clang/8.0/include")
inc_dirs.append("/opt/rocm/llvm/lib/clang/9.0.0/include") inc_dirs.append(rocm_config.rocm_toolkit_path + "/llvm/lib/clang/9.0.0/include")
inc_dirs.append("/opt/rocm/llvm/lib/clang/10.0.0/include") inc_dirs.append(rocm_config.rocm_toolkit_path + "/llvm/lib/clang/10.0.0/include")
# Add rocrand and hiprand headers # Add rocrand and hiprand headers
inc_dirs.append("/opt/rocm/rocrand/include") inc_dirs.append(rocm_config.rocm_toolkit_path + "/rocrand/include")
inc_dirs.append("/opt/rocm/hiprand/include") inc_dirs.append(rocm_config.rocm_toolkit_path + "/hiprand/include")
# Add rocfft headers # Add rocfft headers
inc_dirs.append("/opt/rocm/rocfft/include") inc_dirs.append(rocm_config.rocm_toolkit_path + "/rocfft/include")
# Add rocBLAS headers # Add rocBLAS headers
inc_dirs.append("/opt/rocm/rocblas/include") inc_dirs.append(rocm_config.rocm_toolkit_path + "/rocblas/include")
# Add MIOpen headers # Add MIOpen headers
inc_dirs.append("/opt/rocm/miopen/include") inc_dirs.append(rocm_config.rocm_toolkit_path + "/miopen/include")
# Add RCCL headers # Add RCCL headers
inc_dirs.append("/opt/rocm/rccl/include") inc_dirs.append(rocm_config.rocm_toolkit_path + "/rccl/include")
# Add hcc headers # Add hcc headers
inc_dirs.append("/opt/rocm/hcc/include") inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/include")
inc_dirs.append("/opt/rocm/hcc/compiler/lib/clang/7.0.0/include/") inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/compiler/lib/clang/7.0.0/include/")
inc_dirs.append("/opt/rocm/hcc/lib/clang/7.0.0/include") inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/lib/clang/7.0.0/include")
# Newer hcc builds use/are based off of clang 8.0.0. # Newer hcc builds use/are based off of clang 8.0.0.
inc_dirs.append("/opt/rocm/hcc/compiler/lib/clang/8.0.0/include/") inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/compiler/lib/clang/8.0.0/include/")
inc_dirs.append("/opt/rocm/hcc/lib/clang/8.0.0/include") inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/lib/clang/8.0.0/include")
# Support hcc based off clang 9.0.0, included in ROCm2.2 # Support hcc based off clang 9.0.0, included in ROCm2.2
inc_dirs.append("/opt/rocm/hcc/compiler/lib/clang/9.0.0/include/") inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/compiler/lib/clang/9.0.0/include/")
inc_dirs.append("/opt/rocm/hcc/lib/clang/9.0.0/include") inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/lib/clang/9.0.0/include")
# Support hcc based off clang 10.0.0, included in ROCm2.8 # Support hcc based off clang 10.0.0, included in ROCm2.8
inc_dirs.append("/opt/rocm/hcc/compiler/lib/clang/10.0.0/include/") inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/compiler/lib/clang/10.0.0/include/")
inc_dirs.append("/opt/rocm/hcc/lib/clang/10.0.0/include") inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/lib/clang/10.0.0/include")
return inc_dirs return inc_dirs
@ -300,11 +300,12 @@ def _hipcc_env(repository_ctx):
repository_ctx.os.environ[name].strip() + "\";") repository_ctx.os.environ[name].strip() + "\";")
return hipcc_env.strip() return hipcc_env.strip()
def _hipcc_is_hipclang(repository_ctx): def _hipcc_is_hipclang(repository_ctx, rocm_config):
"""Returns if hipcc is based on hip-clang toolchain. """Returns if hipcc is based on hip-clang toolchain.
Args: Args:
repository_ctx: The repository context. repository_ctx: The repository context.
rocm_config: The ROCm config, as returned by _get_rocm_config.
Returns: Returns:
A string "True" if hipcc is based on hip-clang toolchain. A string "True" if hipcc is based on hip-clang toolchain.
@ -319,7 +320,7 @@ def _hipcc_is_hipclang(repository_ctx):
# grep for "HIP_COMPILER=clang" in /opt/rocm/hip/lib/.hipInfo # grep for "HIP_COMPILER=clang" in /opt/rocm/hip/lib/.hipInfo
grep_result = _execute( grep_result = _execute(
repository_ctx, repository_ctx,
["grep", "HIP_COMPILER=clang", "/opt/rocm/hip/lib/.hipInfo"], ["grep", "HIP_COMPILER=clang", rocm_config.rocm_toolkit_path + "/hip/lib/.hipInfo"],
empty_stdout_fine = True, empty_stdout_fine = True,
) )
result = grep_result.stdout.strip() result = grep_result.stdout.strip()
@ -327,13 +328,14 @@ def _hipcc_is_hipclang(repository_ctx):
return "True" return "True"
return "False" return "False"
def _if_hipcc_is_hipclang(repository_ctx, if_true, if_false = []): def _if_hipcc_is_hipclang(repository_ctx, rocm_config, if_true, if_false = []):
""" """
Returns either the if_true or if_false arg based on whether hipcc Returns either the if_true or if_false arg based on whether hipcc
is based on the hip-clang toolchain is based on the hip-clang toolchain
Args : Args :
repository_ctx: The repository context. repository_ctx: The repository context.
rocm_config: The ROCm config, as returned by _get_rocm_config.
if_true : value to return if hipcc is hip-clang based if_true : value to return if hipcc is hip-clang based
if_false : value to return if hipcc is not hip-clang based if_false : value to return if hipcc is not hip-clang based
(optional, defaults to empty list) (optional, defaults to empty list)
@ -341,7 +343,7 @@ def _if_hipcc_is_hipclang(repository_ctx, if_true, if_false = []):
Returns : Returns :
either the if_true arg or the if_false arg either the if_true arg or the if_false arg
""" """
if _hipcc_is_hipclang(repository_ctx) == "True": if _hipcc_is_hipclang(repository_ctx, rocm_config) == "True":
return if_true return if_true
return if_false return if_false
@ -478,6 +480,11 @@ def _find_libs(repository_ctx, rocm_config):
repository_ctx, repository_ctx,
rocm_config.rocm_toolkit_path + "/rccl", rocm_config.rocm_toolkit_path + "/rccl",
), ),
"hipsparse": _find_rocm_lib(
"hipsparse",
repository_ctx,
rocm_config.rocm_toolkit_path + "/hipsparse",
),
} }
def _get_rocm_config(repository_ctx): def _get_rocm_config(repository_ctx):
@ -558,6 +565,7 @@ def _create_dummy_repository(repository_ctx):
"%{rccl_lib}": _lib_name("rccl"), "%{rccl_lib}": _lib_name("rccl"),
"%{rocfft_lib}": _lib_name("rocfft"), "%{rocfft_lib}": _lib_name("rocfft"),
"%{hiprand_lib}": _lib_name("hiprand"), "%{hiprand_lib}": _lib_name("hiprand"),
"%{hipsparse_lib}": _lib_name("hipsparse"),
"%{copy_rules}": "", "%{copy_rules}": "",
"%{rocm_headers}": "", "%{rocm_headers}": "",
}, },
@ -703,6 +711,12 @@ def _create_local_rocm_repository(repository_ctx):
src_dir = rocm_toolkit_path + "/rccl/include", src_dir = rocm_toolkit_path + "/rccl/include",
out_dir = "rocm/include/rccl", out_dir = "rocm/include/rccl",
), ),
make_copy_dir_rule(
repository_ctx,
name = "hipsparse-include",
src_dir = rocm_toolkit_path + "/hipsparse/include",
out_dir = "rocm/include/hipsparse",
),
] ]
rocm_libs = _find_libs(repository_ctx, rocm_config) rocm_libs = _find_libs(repository_ctx, rocm_config)
@ -740,16 +754,19 @@ def _create_local_rocm_repository(repository_ctx):
"%{hiprand_lib}": rocm_libs["hiprand"].file_name, "%{hiprand_lib}": rocm_libs["hiprand"].file_name,
"%{miopen_lib}": rocm_libs["miopen"].file_name, "%{miopen_lib}": rocm_libs["miopen"].file_name,
"%{rccl_lib}": rocm_libs["rccl"].file_name, "%{rccl_lib}": rocm_libs["rccl"].file_name,
"%{hipsparse_lib}": rocm_libs["hipsparse"].file_name,
"%{copy_rules}": "\n".join(copy_rules), "%{copy_rules}": "\n".join(copy_rules),
"%{rocm_headers}": ('":rocm-include",\n' + "%{rocm_headers}": ('":rocm-include",\n' +
'":rocfft-include",\n' + '":rocfft-include",\n' +
'":rocblas-include",\n' + '":rocblas-include",\n' +
'":miopen-include",\n' + '":miopen-include",\n' +
'":rccl-include",'), '":rccl-include",\n' +
'":hipsparse-include",'),
}, },
) )
# Set up crosstool/ # Set up crosstool/
cc = find_cc(repository_ctx) cc = find_cc(repository_ctx)
host_compiler_includes = get_cxx_inc_directories(repository_ctx, cc) host_compiler_includes = get_cxx_inc_directories(repository_ctx, cc)
@ -762,7 +779,7 @@ def _create_local_rocm_repository(repository_ctx):
rocm_defines["%{host_compiler_prefix}"] = host_compiler_prefix rocm_defines["%{host_compiler_prefix}"] = host_compiler_prefix
rocm_defines["%{linker_bin_path}"] = "/opt/rocm/hcc/compiler/bin" rocm_defines["%{linker_bin_path}"] = rocm_config.rocm_toolkit_path + "/hcc/compiler/bin"
# For gcc, do not canonicalize system header paths; some versions of gcc # For gcc, do not canonicalize system header paths; some versions of gcc
# pick the shortest possible path for system includes when creating the # pick the shortest possible path for system includes when creating the
@ -775,7 +792,7 @@ def _create_local_rocm_repository(repository_ctx):
"-DTENSORFLOW_USE_ROCM=1", "-DTENSORFLOW_USE_ROCM=1",
"-D__HIP_PLATFORM_HCC__", "-D__HIP_PLATFORM_HCC__",
"-DEIGEN_USE_HIP", "-DEIGEN_USE_HIP",
] + _if_hipcc_is_hipclang(repository_ctx, [ ] + _if_hipcc_is_hipclang(repository_ctx, rocm_config, [
# #
# define "TENSORFLOW_COMPILER_IS_HIP_CLANG" when we are using clang # define "TENSORFLOW_COMPILER_IS_HIP_CLANG" when we are using clang
# based hipcc to compile/build tensorflow # based hipcc to compile/build tensorflow
@ -815,14 +832,14 @@ def _create_local_rocm_repository(repository_ctx):
"crosstool:clang/bin/crosstool_wrapper_driver_rocm", "crosstool:clang/bin/crosstool_wrapper_driver_rocm",
{ {
"%{cpu_compiler}": str(cc), "%{cpu_compiler}": str(cc),
"%{hipcc_path}": "/opt/rocm/bin/hipcc", "%{hipcc_path}": rocm_config.rocm_toolkit_path + "/bin/hipcc",
"%{hipcc_env}": _hipcc_env(repository_ctx), "%{hipcc_env}": _hipcc_env(repository_ctx),
"%{hipcc_is_hipclang}": _hipcc_is_hipclang(repository_ctx), "%{hipcc_is_hipclang}": _hipcc_is_hipclang(repository_ctx, rocm_config),
"%{rocr_runtime_path}": "/opt/rocm/lib", "%{rocr_runtime_path}": rocm_config.rocm_toolkit_path + "/lib",
"%{rocr_runtime_library}": "hsa-runtime64", "%{rocr_runtime_library}": "hsa-runtime64",
"%{hip_runtime_path}": "/opt/rocm/hip/lib", "%{hip_runtime_path}": rocm_config.rocm_toolkit_path + "/hip/lib",
"%{hip_runtime_library}": "hip_hcc", "%{hip_runtime_library}": "hip_hcc",
"%{hcc_runtime_path}": "/opt/rocm/hcc/lib", "%{hcc_runtime_path}": rocm_config.rocm_toolkit_path + "/hcc/lib",
"%{hcc_runtime_library}": "mcwamp", "%{hcc_runtime_library}": "mcwamp",
"%{crosstool_verbose}": _crosstool_verbose(repository_ctx), "%{crosstool_verbose}": _crosstool_verbose(repository_ctx),
"%{gcc_host_compiler_path}": str(cc), "%{gcc_host_compiler_path}": str(cc),
View File
@ -9,5 +9,5 @@ container_digests = {
"cuda10.1-cudnn7-centos6": "sha256:454b899657e87893ee5e68dc0f87df59b6a0a7418ae09cafcc3dd65ac71feca9", "cuda10.1-cudnn7-centos6": "sha256:454b899657e87893ee5e68dc0f87df59b6a0a7418ae09cafcc3dd65ac71feca9",
"cuda10.0-cudnn7-ubuntu16.04-manylinux2010": "sha256:5812d9d0ef0a3276fc5faaf4cd01f3d6e03d635893a6e2d2e04f6f01d626c432", "cuda10.0-cudnn7-ubuntu16.04-manylinux2010": "sha256:5812d9d0ef0a3276fc5faaf4cd01f3d6e03d635893a6e2d2e04f6f01d626c432",
"cuda10.1-cudnn7-ubuntu16.04-manylinux2010": "sha256:f8e15f08cb501e5f2de3dc450f614609fd3ed19bde74b153fa66d14b2307610c", "cuda10.1-cudnn7-ubuntu16.04-manylinux2010": "sha256:f8e15f08cb501e5f2de3dc450f614609fd3ed19bde74b153fa66d14b2307610c",
"rocm-ubuntu16.04": "sha256:d5cd4120cff3d2a452378aad03746ff5f24699d86cf695c20ee96f366e42975f", "rocm-ubuntu16.04": "sha256:e645447dd6127325f3e97b8bf23424f637a8579d963b34fcc6772cf7cfaa0ebe",
} }
View File
@ -72,7 +72,7 @@ def _tensorflow_rbe_config(name, compiler, python_version, os, rocm_version = No
docker_toolchain_autoconfig( docker_toolchain_autoconfig(
name = name, name = name,
base = base, base = base,
bazel_version = "0.29.1", bazel_version = "1.2.1",
build_bazel_src = build_bazel_src, build_bazel_src = build_bazel_src,
config_repos = config_repos, config_repos = config_repos,
env = env, env = env,
View File
@ -15,6 +15,7 @@ cc_library(
name = "rocm_headers", name = "rocm_headers",
hdrs = [ hdrs = [
"rocm/rocm_config.h", "rocm/rocm_config.h",
":hipsparse-include",
":miopen-include", ":miopen-include",
":rccl-include", ":rccl-include",
":rocblas-include", ":rocblas-include",
@ -141,6 +142,13 @@ cc_library(
], ],
) )
cc_import(
name = "hipsparse",
hdrs = glob(["rocm/include/hipsparse/**"]),
shared_library = "rocm/lib/libhipsparse.so",
visibility = ["//visibility:public"],
)
genrule( genrule(
name = "rocm-include", name = "rocm-include",
outs = [ outs = [
@ -175,6 +183,7 @@ genrule(
"rocm/include/hcc/clang-c/CXErrorCode.h", "rocm/include/hcc/clang-c/CXErrorCode.h",
"rocm/include/hcc/clang-c/CXString.h", "rocm/include/hcc/clang-c/CXString.h",
"rocm/include/hcc/clang-c/Documentation.h", "rocm/include/hcc/clang-c/Documentation.h",
"rocm/include/hcc/clang-c/FatalErrorHandler.h",
"rocm/include/hcc/clang-c/Index.h", "rocm/include/hcc/clang-c/Index.h",
"rocm/include/hcc/clang-c/Platform.h", "rocm/include/hcc/clang-c/Platform.h",
"rocm/include/hcc/coordinate", "rocm/include/hcc/coordinate",
@ -275,12 +284,14 @@ genrule(
"rocm/include/hip/hcc_detail/hip_prof_str.h", "rocm/include/hip/hcc_detail/hip_prof_str.h",
"rocm/include/hip/hcc_detail/hip_runtime.h", "rocm/include/hip/hcc_detail/hip_runtime.h",
"rocm/include/hip/hcc_detail/hip_runtime_api.h", "rocm/include/hip/hcc_detail/hip_runtime_api.h",
"rocm/include/hip/hcc_detail/hip_runtime_prof.h",
"rocm/include/hip/hcc_detail/hip_surface_types.h", "rocm/include/hip/hcc_detail/hip_surface_types.h",
"rocm/include/hip/hcc_detail/hip_texture_types.h", "rocm/include/hip/hcc_detail/hip_texture_types.h",
"rocm/include/hip/hcc_detail/hip_vector_types.h", "rocm/include/hip/hcc_detail/hip_vector_types.h",
"rocm/include/hip/hcc_detail/hiprtc.h", "rocm/include/hip/hcc_detail/hiprtc.h",
"rocm/include/hip/hcc_detail/host_defines.h", "rocm/include/hip/hcc_detail/host_defines.h",
"rocm/include/hip/hcc_detail/hsa_helpers.hpp", "rocm/include/hip/hcc_detail/hsa_helpers.hpp",
"rocm/include/hip/hcc_detail/library_types.h",
"rocm/include/hip/hcc_detail/llvm_intrinsics.h", "rocm/include/hip/hcc_detail/llvm_intrinsics.h",
"rocm/include/hip/hcc_detail/macro_based_grid_launch.hpp", "rocm/include/hip/hcc_detail/macro_based_grid_launch.hpp",
"rocm/include/hip/hcc_detail/math_functions.h", "rocm/include/hip/hcc_detail/math_functions.h",
@ -292,6 +303,7 @@ genrule(
"rocm/include/hip/hip_common.h", "rocm/include/hip/hip_common.h",
"rocm/include/hip/hip_complex.h", "rocm/include/hip/hip_complex.h",
"rocm/include/hip/hip_cooperative_groups.h", "rocm/include/hip/hip_cooperative_groups.h",
"rocm/include/hip/hip_ext.h",
"rocm/include/hip/hip_fp16.h", "rocm/include/hip/hip_fp16.h",
"rocm/include/hip/hip_hcc.h", "rocm/include/hip/hip_hcc.h",
"rocm/include/hip/hip_profile.h", "rocm/include/hip/hip_profile.h",
@ -300,6 +312,7 @@ genrule(
"rocm/include/hip/hip_texture_types.h", "rocm/include/hip/hip_texture_types.h",
"rocm/include/hip/hip_vector_types.h", "rocm/include/hip/hip_vector_types.h",
"rocm/include/hip/hiprtc.h", "rocm/include/hip/hiprtc.h",
"rocm/include/hip/library_types.h",
"rocm/include/hip/math_functions.h", "rocm/include/hip/math_functions.h",
"rocm/include/hip/nvcc_detail/channel_descriptor.h", "rocm/include/hip/nvcc_detail/channel_descriptor.h",
"rocm/include/hip/nvcc_detail/hip_complex.h", "rocm/include/hip/nvcc_detail/hip_complex.h",
@ -441,7 +454,6 @@ genrule(
"rocm/include/ocml.h", "rocm/include/ocml.h",
"rocm/include/opencl1.2-c.pch", "rocm/include/opencl1.2-c.pch",
"rocm/include/opencl2.0-c.pch", "rocm/include/opencl2.0-c.pch",
"rocm/include/profiler/CXLActivityLogger/CXLActivityLogger.h",
"rocm/include/rccl.h", "rocm/include/rccl.h",
"rocm/include/rocalution.hpp", "rocm/include/rocalution.hpp",
"rocm/include/rocblas-auxiliary.h", "rocm/include/rocblas-auxiliary.h",
@ -583,6 +595,7 @@ genrule(
"rocm/include/rocrand/rocrand_xorwow.h", "rocm/include/rocrand/rocrand_xorwow.h",
"rocm/include/rocrand/rocrand_xorwow_precomputed.h", "rocm/include/rocrand/rocrand_xorwow_precomputed.h",
"rocm/include/rocsparse-auxiliary.h", "rocm/include/rocsparse-auxiliary.h",
"rocm/include/rocsparse-complex-types.h",
"rocm/include/rocsparse-export.h", "rocm/include/rocsparse-export.h",
"rocm/include/rocsparse-functions.h", "rocm/include/rocsparse-functions.h",
"rocm/include/rocsparse-types.h", "rocm/include/rocsparse-types.h",
@ -1468,6 +1481,16 @@ genrule(
cmd = """cp -rLf "/opt/rocm/rccl/include/." "$(@D)/" """, cmd = """cp -rLf "/opt/rocm/rccl/include/." "$(@D)/" """,
) )
genrule(
name = "hipsparse-include",
outs = [
"rocm/include/hipsparse/hipsparse-export.h",
"rocm/include/hipsparse/hipsparse-version.h",
"rocm/include/hipsparse/hipsparse.h",
],
cmd = """cp -rLf "/opt/rocm/hipsparse/include/." "$(@D)/rocm/include/hipsparse/" """,
)
genrule( genrule(
name = "rocm-lib", name = "rocm-lib",
outs = [ outs = [
@ -1477,11 +1500,13 @@ genrule(
"rocm/lib/libhiprand.so", "rocm/lib/libhiprand.so",
"rocm/lib/libMIOpen.so", "rocm/lib/libMIOpen.so",
"rocm/lib/librccl.so", "rocm/lib/librccl.so",
"rocm/lib/libhipsparse.so",
], ],
cmd = """cp -f "/opt/rocm/hip/lib/libhip_hcc.so" "$(location rocm/lib/libhip_hcc.so)" && \ cmd = """cp -f "/opt/rocm/hip/lib/libhip_hcc.so" "$(location rocm/lib/libhip_hcc.so)" && \
cp -f "/opt/rocm/rocblas/lib/librocblas.so.0.1" "$(location rocm/lib/librocblas.so)" && \ cp -f "/opt/rocm/rocblas/lib/librocblas.so.0.1" "$(location rocm/lib/librocblas.so)" && \
cp -f "/opt/rocm/rocfft/lib/librocfft.so.0.1" "$(location rocm/lib/librocfft.so)" && \ cp -f "/opt/rocm/rocfft/lib/librocfft.so.0.1" "$(location rocm/lib/librocfft.so)" && \
cp -f "/opt/rocm/hiprand/lib/libhiprand.so.1.1" "$(location rocm/lib/libhiprand.so)" && \ cp -f "/opt/rocm/hiprand/lib/libhiprand.so.1.1" "$(location rocm/lib/libhiprand.so)" && \
cp -f "/opt/rocm/miopen/lib/libMIOpen.so.1" "$(location rocm/lib/libMIOpen.so)" && \ cp -f "/opt/rocm/miopen/lib/libMIOpen.so.1" "$(location rocm/lib/libMIOpen.so)" && \
cp -f "/opt/rocm/rccl/lib/librccl.so" "$(location rocm/lib/librccl.so)" """, cp -f "/opt/rocm/rccl/lib/librccl.so" "$(location rocm/lib/librccl.so)" && \
cp -f "/opt/rocm/hipsparse/lib/libhipsparse.so.0.1" "$(location rocm/lib/libhipsparse.so)" """,
) )