Merge pull request #34800 from ROCmSoftwarePlatform:google_upstream_rocm_csr_sparse_matrix_support

PiperOrigin-RevId: 289617600
Change-Id: Ic1aa3714126d7b867295ae386b6be643c1dc83e4

commit 0e2b5a9d2a
@@ -3480,14 +3480,18 @@ tf_kernel_library(
 
 tf_kernel_library(
     name = "cuda_sparse",
-    srcs = ["cuda_sparse.cc"],
+    srcs = if_cuda(["cuda_sparse.cc"]) + if_rocm(["rocm_sparse.cc"]),
     hdrs = ["cuda_sparse.h"],
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core/kernels:cuda_solvers",
+    ] + if_cuda([
         "//tensorflow/stream_executor/cuda:cusparse_lib",
-    ] + if_cuda(["@cub_archive//:cub"]),
+        "@cub_archive//:cub",
+    ]) + if_rocm([
+        "@local_config_rocm//rocm:hipsparse",
+    ]),
 )
 
 LINALG_DEPS = [
@@ -69,7 +69,7 @@ inline typename CudaComplexT<T>::type* AsCudaComplex(T* p) {
 }
 
 // A set of initialized handles to the underlying Cuda libraries used by
-// CudaSparse. We maintain one such set of handles per unique stream.
+// GpuSparse. We maintain one such set of handles per unique stream.
 class CudaSparseHandles {
  public:
   explicit CudaSparseHandles(cudaStream_t stream)
@@ -96,8 +96,8 @@ class CudaSparseHandles {
 
   Status Initialize() {
     if (initialized_) return Status::OK();
-    TF_RETURN_IF_CUSPARSE_ERROR(cusparseCreate(&cusparse_handle_));
-    TF_RETURN_IF_CUSPARSE_ERROR(cusparseSetStream(cusparse_handle_, stream_));
+    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreate(&cusparse_handle_));
+    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSetStream(cusparse_handle_, stream_));
     initialized_ = true;
     return Status::OK();
   }
@@ -149,7 +149,7 @@ HandleMap* GetHandleMapSingleton() {
 
 }  // namespace
 
-CudaSparse::CudaSparse(OpKernelContext* context)
+GpuSparse::GpuSparse(OpKernelContext* context)
     : initialized_(false), context_(context) {
   auto cuda_stream_ptr =
       reinterpret_cast<const cudaStream_t*>(context->op_device_context()
@@ -157,25 +157,24 @@ CudaSparse::CudaSparse(OpKernelContext* context)
                                                 ->implementation()
                                                 ->GpuStreamMemberHack());
   DCHECK(cuda_stream_ptr);
-  cuda_stream_ = *cuda_stream_ptr;
+  gpu_stream_ = *cuda_stream_ptr;
 }
 
-Status CudaSparse::Initialize() {
+Status GpuSparse::Initialize() {
   HandleMap* handle_map = GetHandleMapSingleton();
   DCHECK(handle_map);
   mutex_lock lock(handle_map_mutex);
-  auto it = handle_map->find(cuda_stream_);
+  auto it = handle_map->find(gpu_stream_);
   if (it == handle_map->end()) {
-    LOG(INFO) << "Creating CudaSparse handles for stream " << cuda_stream_;
+    LOG(INFO) << "Creating CudaSparse handles for stream " << gpu_stream_;
     // Previously unseen Cuda stream. Initialize a set of Cuda sparse library
     // handles for it.
-    CudaSparseHandles new_handles(cuda_stream_);
+    CudaSparseHandles new_handles(gpu_stream_);
     TF_RETURN_IF_ERROR(new_handles.Initialize());
-    it =
-        handle_map->insert(std::make_pair(cuda_stream_, std::move(new_handles)))
-            .first;
+    it = handle_map->insert(std::make_pair(gpu_stream_, std::move(new_handles)))
+             .first;
   }
-  cusparse_handle_ = &it->second.handle();
+  gpusparse_handle_ = &it->second.handle();
   initialized_ = true;
   return Status::OK();
 }
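For context, GpuSparse::Initialize() above lazily creates one set of sparse-library handles per GPU stream and caches it in a process-wide map guarded by a mutex. A minimal standalone sketch of that caching pattern, using a hypothetical Handle type and StreamKey in place of the real cuSPARSE/hipSPARSE handle set and cudaStream_t/hipStream_t:

    #include <mutex>
    #include <unordered_map>

    // Hypothetical stand-ins; the real code keys on cudaStream_t/hipStream_t
    // and stores CudaSparseHandles/HipSparseHandles objects.
    using StreamKey = void*;
    struct Handle { bool initialized = false; };

    Handle* GetHandleForStream(StreamKey stream) {
      static std::mutex mu;
      static auto* handles = new std::unordered_map<StreamKey, Handle>;
      std::lock_guard<std::mutex> lock(mu);
      auto it = handles->find(stream);
      if (it == handles->end()) {
        // First time this stream is seen: create and initialize a handle set,
        // mirroring what CudaSparseHandles/HipSparseHandles::Initialize() does.
        Handle h;
        h.initialized = true;
        it = handles->emplace(stream, h).first;
      }
      return &it->second;
    }

The map is never torn down (same as the singleton in the real code), which keeps handle creation a one-time cost per stream.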
@ -205,32 +204,32 @@ template <typename Scalar, typename SparseFn>
|
||||
static inline Status GtsvImpl(SparseFn op, cusparseHandle_t cusparse_handle,
|
||||
int m, int n, const Scalar* dl, const Scalar* d,
|
||||
const Scalar* du, Scalar* B, int ldb) {
|
||||
TF_RETURN_IF_CUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl),
|
||||
AsCudaComplex(d), AsCudaComplex(du),
|
||||
AsCudaComplex(B), ldb));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl),
|
||||
AsCudaComplex(d), AsCudaComplex(du),
|
||||
AsCudaComplex(B), ldb));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define GTSV_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status CudaSparse::Gtsv<Scalar>(int m, int n, const Scalar* dl, \
|
||||
const Scalar* d, const Scalar* du, \
|
||||
Scalar* B, int ldb) const { \
|
||||
DCHECK(initialized_); \
|
||||
return GtsvImpl(SPARSE_FN(gtsv, sparse_prefix), *cusparse_handle_, m, n, \
|
||||
dl, d, du, B, ldb); \
|
||||
#define GTSV_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status GpuSparse::Gtsv<Scalar>(int m, int n, const Scalar* dl, \
|
||||
const Scalar* d, const Scalar* du, Scalar* B, \
|
||||
int ldb) const { \
|
||||
DCHECK(initialized_); \
|
||||
return GtsvImpl(SPARSE_FN(gtsv, sparse_prefix), *gpusparse_handle_, m, n, \
|
||||
dl, d, du, B, ldb); \
|
||||
}
|
||||
|
||||
TF_CALL_LAPACK_TYPES(GTSV_INSTANCE);
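The Gtsv wrappers above bind the vendor gtsv routines, which solve a tridiagonal system given its lower diagonal dl, main diagonal d, and upper diagonal du for one or more right-hand sides. As a point of reference, here is a host-side sketch of the same computation for a single right-hand side (the classic Thomas algorithm without pivoting); it illustrates the math only and is not TensorFlow or cuSPARSE code:

    #include <vector>

    // Solves T x = b for tridiagonal T given by dl (sub), d (main), du (super).
    // Thomas algorithm without pivoting; dl[0] and du[m-1] are unused.
    std::vector<double> SolveTridiagonal(std::vector<double> dl,
                                         std::vector<double> d,
                                         std::vector<double> du,
                                         std::vector<double> b) {
      const int m = static_cast<int>(d.size());
      // Forward elimination.
      for (int i = 1; i < m; ++i) {
        const double w = dl[i] / d[i - 1];
        d[i] -= w * du[i - 1];
        b[i] -= w * b[i - 1];
      }
      // Back substitution.
      std::vector<double> x(m);
      x[m - 1] = b[m - 1] / d[m - 1];
      for (int i = m - 2; i >= 0; --i) {
        x[i] = (b[i] - du[i] * x[i + 1]) / d[i];
      }
      return x;
    }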
|
||||
|
||||
#define GTSV_NO_PIVOT_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status CudaSparse::GtsvNoPivot<Scalar>(int m, int n, const Scalar* dl, \
|
||||
const Scalar* d, const Scalar* du, \
|
||||
Scalar* B, int ldb) const { \
|
||||
DCHECK(initialized_); \
|
||||
return GtsvImpl(SPARSE_FN(gtsv_nopivot, sparse_prefix), *cusparse_handle_, \
|
||||
m, n, dl, d, du, B, ldb); \
|
||||
#define GTSV_NO_PIVOT_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status GpuSparse::GtsvNoPivot<Scalar>(int m, int n, const Scalar* dl, \
|
||||
const Scalar* d, const Scalar* du, \
|
||||
Scalar* B, int ldb) const { \
|
||||
DCHECK(initialized_); \
|
||||
return GtsvImpl(SPARSE_FN(gtsv_nopivot, sparse_prefix), \
|
||||
*gpusparse_handle_, m, n, dl, d, du, B, ldb); \
|
||||
}
|
||||
|
||||
TF_CALL_LAPACK_TYPES(GTSV_NO_PIVOT_INSTANCE);
|
||||
@ -242,20 +241,20 @@ static inline Status GtsvStridedBatchImpl(SparseFn op,
|
||||
const Scalar* d, const Scalar* du,
|
||||
Scalar* x, int batchCount,
|
||||
int batchStride) {
|
||||
TF_RETURN_IF_CUSPARSE_ERROR(op(cusparse_handle, m, AsCudaComplex(dl),
|
||||
AsCudaComplex(d), AsCudaComplex(du),
|
||||
AsCudaComplex(x), batchCount, batchStride));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, AsCudaComplex(dl),
|
||||
AsCudaComplex(d), AsCudaComplex(du),
|
||||
AsCudaComplex(x), batchCount, batchStride));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define GTSV_STRIDED_BATCH_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status CudaSparse::GtsvStridedBatch<Scalar>( \
|
||||
Status GpuSparse::GtsvStridedBatch<Scalar>( \
|
||||
int m, const Scalar* dl, const Scalar* d, const Scalar* du, Scalar* x, \
|
||||
int batchCount, int batchStride) const { \
|
||||
DCHECK(initialized_); \
|
||||
return GtsvStridedBatchImpl(SPARSE_FN(gtsvStridedBatch, sparse_prefix), \
|
||||
*cusparse_handle_, m, dl, d, du, x, \
|
||||
*gpusparse_handle_, m, dl, d, du, x, \
|
||||
batchCount, batchStride); \
|
||||
}
|
||||
|
||||
@ -266,32 +265,32 @@ static inline Status Gtsv2Impl(SparseFn op, cusparseHandle_t cusparse_handle,
|
||||
int m, int n, const Scalar* dl, const Scalar* d,
|
||||
const Scalar* du, Scalar* B, int ldb,
|
||||
void* pBuffer) {
|
||||
TF_RETURN_IF_CUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl),
|
||||
AsCudaComplex(d), AsCudaComplex(du),
|
||||
AsCudaComplex(B), ldb, pBuffer));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl),
|
||||
AsCudaComplex(d), AsCudaComplex(du),
|
||||
AsCudaComplex(B), ldb, pBuffer));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define GTSV2_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status CudaSparse::Gtsv2<Scalar>(int m, int n, const Scalar* dl, \
|
||||
const Scalar* d, const Scalar* du, \
|
||||
Scalar* B, int ldb, void* pBuffer) const { \
|
||||
DCHECK(initialized_); \
|
||||
return Gtsv2Impl(SPARSE_FN(gtsv2, sparse_prefix), *cusparse_handle_, m, n, \
|
||||
dl, d, du, B, ldb, pBuffer); \
|
||||
#define GTSV2_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status GpuSparse::Gtsv2<Scalar>(int m, int n, const Scalar* dl, \
|
||||
const Scalar* d, const Scalar* du, \
|
||||
Scalar* B, int ldb, void* pBuffer) const { \
|
||||
DCHECK(initialized_); \
|
||||
return Gtsv2Impl(SPARSE_FN(gtsv2, sparse_prefix), *gpusparse_handle_, m, \
|
||||
n, dl, d, du, B, ldb, pBuffer); \
|
||||
}
|
||||
|
||||
TF_CALL_LAPACK_TYPES(GTSV2_INSTANCE);
|
||||
|
||||
#define GTSV2_NO_PIVOT_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status CudaSparse::Gtsv2NoPivot<Scalar>( \
|
||||
int m, int n, const Scalar* dl, const Scalar* d, const Scalar* du, \
|
||||
Scalar* B, int ldb, void* pBuffer) const { \
|
||||
DCHECK(initialized_); \
|
||||
return Gtsv2Impl(SPARSE_FN(gtsv2_nopivot, sparse_prefix), \
|
||||
*cusparse_handle_, m, n, dl, d, du, B, ldb, pBuffer); \
|
||||
#define GTSV2_NO_PIVOT_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status GpuSparse::Gtsv2NoPivot<Scalar>( \
|
||||
int m, int n, const Scalar* dl, const Scalar* d, const Scalar* du, \
|
||||
Scalar* B, int ldb, void* pBuffer) const { \
|
||||
DCHECK(initialized_); \
|
||||
return Gtsv2Impl(SPARSE_FN(gtsv2_nopivot, sparse_prefix), \
|
||||
*gpusparse_handle_, m, n, dl, d, du, B, ldb, pBuffer); \
|
||||
}
|
||||
|
||||
TF_CALL_LAPACK_TYPES(GTSV2_NO_PIVOT_INSTANCE);
|
||||
@ -303,34 +302,34 @@ static inline Status Gtsv2BufferSizeExtImpl(SparseFn op,
|
||||
const Scalar* d, const Scalar* du,
|
||||
const Scalar* B, int ldb,
|
||||
size_t* bufferSizeInBytes) {
|
||||
TF_RETURN_IF_CUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl),
|
||||
AsCudaComplex(d), AsCudaComplex(du),
|
||||
AsCudaComplex(B), ldb, bufferSizeInBytes));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl),
|
||||
AsCudaComplex(d), AsCudaComplex(du),
|
||||
AsCudaComplex(B), ldb, bufferSizeInBytes));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define GTSV2_BUFFER_SIZE_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status CudaSparse::Gtsv2BufferSizeExt<Scalar>( \
|
||||
int m, int n, const Scalar* dl, const Scalar* d, const Scalar* du, \
|
||||
const Scalar* B, int ldb, size_t* bufferSizeInBytes) const { \
|
||||
DCHECK(initialized_); \
|
||||
return Gtsv2BufferSizeExtImpl( \
|
||||
SPARSE_FN(gtsv2_bufferSizeExt, sparse_prefix), *cusparse_handle_, m, \
|
||||
n, dl, d, du, B, ldb, bufferSizeInBytes); \
|
||||
#define GTSV2_BUFFER_SIZE_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status GpuSparse::Gtsv2BufferSizeExt<Scalar>( \
|
||||
int m, int n, const Scalar* dl, const Scalar* d, const Scalar* du, \
|
||||
const Scalar* B, int ldb, size_t* bufferSizeInBytes) const { \
|
||||
DCHECK(initialized_); \
|
||||
return Gtsv2BufferSizeExtImpl( \
|
||||
SPARSE_FN(gtsv2_bufferSizeExt, sparse_prefix), *gpusparse_handle_, m, \
|
||||
n, dl, d, du, B, ldb, bufferSizeInBytes); \
|
||||
}
|
||||
|
||||
TF_CALL_LAPACK_TYPES(GTSV2_BUFFER_SIZE_INSTANCE);
|
||||
|
||||
#define GTSV2_NO_PIVOT_BUFFER_SIZE_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status CudaSparse::Gtsv2NoPivotBufferSizeExt<Scalar>( \
|
||||
Status GpuSparse::Gtsv2NoPivotBufferSizeExt<Scalar>( \
|
||||
int m, int n, const Scalar* dl, const Scalar* d, const Scalar* du, \
|
||||
const Scalar* B, int ldb, size_t* bufferSizeInBytes) const { \
|
||||
DCHECK(initialized_); \
|
||||
return Gtsv2BufferSizeExtImpl( \
|
||||
SPARSE_FN(gtsv2_nopivot_bufferSizeExt, sparse_prefix), \
|
||||
*cusparse_handle_, m, n, dl, d, du, B, ldb, bufferSizeInBytes); \
|
||||
*gpusparse_handle_, m, n, dl, d, du, B, ldb, bufferSizeInBytes); \
|
||||
}
|
||||
|
||||
TF_CALL_LAPACK_TYPES(GTSV2_NO_PIVOT_BUFFER_SIZE_INSTANCE);
|
||||
@ -342,7 +341,7 @@ static inline Status Gtsv2StridedBatchImpl(SparseFn op,
|
||||
const Scalar* d, const Scalar* du,
|
||||
Scalar* x, int batchCount,
|
||||
int batchStride, void* pBuffer) {
|
||||
TF_RETURN_IF_CUSPARSE_ERROR(op(
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(op(
|
||||
cusparse_handle, m, AsCudaComplex(dl), AsCudaComplex(d),
|
||||
AsCudaComplex(du), AsCudaComplex(x), batchCount, batchStride, pBuffer));
|
||||
return Status::OK();
|
||||
@ -350,12 +349,12 @@ static inline Status Gtsv2StridedBatchImpl(SparseFn op,
|
||||
|
||||
#define GTSV2_STRIDED_BATCH_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status CudaSparse::Gtsv2StridedBatch<Scalar>( \
|
||||
Status GpuSparse::Gtsv2StridedBatch<Scalar>( \
|
||||
int m, const Scalar* dl, const Scalar* d, const Scalar* du, Scalar* x, \
|
||||
int batchCount, int batchStride, void* pBuffer) const { \
|
||||
DCHECK(initialized_); \
|
||||
return Gtsv2StridedBatchImpl(SPARSE_FN(gtsv2StridedBatch, sparse_prefix), \
|
||||
*cusparse_handle_, m, dl, d, du, x, \
|
||||
*gpusparse_handle_, m, dl, d, du, x, \
|
||||
batchCount, batchStride, pBuffer); \
|
||||
}
|
||||
|
||||
@ -366,30 +365,30 @@ static inline Status Gtsv2StridedBatchBufferSizeImpl(
|
||||
SparseFn op, cusparseHandle_t cusparse_handle, int m, const Scalar* dl,
|
||||
const Scalar* d, const Scalar* du, const Scalar* x, int batchCount,
|
||||
int batchStride, size_t* bufferSizeInBytes) {
|
||||
TF_RETURN_IF_CUSPARSE_ERROR(op(cusparse_handle, m, AsCudaComplex(dl),
|
||||
AsCudaComplex(d), AsCudaComplex(du),
|
||||
AsCudaComplex(x), batchCount, batchStride,
|
||||
bufferSizeInBytes));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, AsCudaComplex(dl),
|
||||
AsCudaComplex(d), AsCudaComplex(du),
|
||||
AsCudaComplex(x), batchCount, batchStride,
|
||||
bufferSizeInBytes));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define GTSV2_STRIDED_BATCH_BUFFER_SIZE_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status CudaSparse::Gtsv2StridedBatchBufferSizeExt<Scalar>( \
|
||||
Status GpuSparse::Gtsv2StridedBatchBufferSizeExt<Scalar>( \
|
||||
int m, const Scalar* dl, const Scalar* d, const Scalar* du, \
|
||||
const Scalar* x, int batchCount, int batchStride, \
|
||||
size_t* bufferSizeInBytes) const { \
|
||||
DCHECK(initialized_); \
|
||||
return Gtsv2StridedBatchBufferSizeImpl( \
|
||||
SPARSE_FN(gtsv2StridedBatch_bufferSizeExt, sparse_prefix), \
|
||||
*cusparse_handle_, m, dl, d, du, x, batchCount, batchStride, \
|
||||
*gpusparse_handle_, m, dl, d, du, x, batchCount, batchStride, \
|
||||
bufferSizeInBytes); \
|
||||
}
|
||||
|
||||
TF_CALL_LAPACK_TYPES(GTSV2_STRIDED_BATCH_BUFFER_SIZE_INSTANCE);
|
||||
|
||||
Status CudaSparse::Coo2csr(const int* cooRowInd, int nnz, int m,
|
||||
int* csrRowPtr) const {
|
||||
Status GpuSparse::Coo2csr(const int* cooRowInd, int nnz, int m,
|
||||
int* csrRowPtr) const {
|
||||
// cusparseStatus_t CUSPARSEAPI cusparseXcoo2csr(cusparseHandle_t handle,
|
||||
// const int *cooRowInd,
|
||||
// int nnz,
|
||||
@ -398,14 +397,14 @@ Status CudaSparse::Coo2csr(const int* cooRowInd, int nnz, int m,
|
||||
// cusparseIndexBase_t
|
||||
// idxBase);
|
||||
DCHECK(initialized_);
|
||||
TF_RETURN_IF_CUSPARSE_ERROR(cusparseXcoo2csr(*cusparse_handle_, cooRowInd,
|
||||
nnz, m, csrRowPtr,
|
||||
CUSPARSE_INDEX_BASE_ZERO));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcoo2csr(*gpusparse_handle_, cooRowInd,
|
||||
nnz, m, csrRowPtr,
|
||||
CUSPARSE_INDEX_BASE_ZERO));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CudaSparse::Csr2coo(const int* csrRowPtr, int nnz, int m,
|
||||
int* cooRowInd) const {
|
||||
Status GpuSparse::Csr2coo(const int* csrRowPtr, int nnz, int m,
|
||||
int* cooRowInd) const {
|
||||
// cusparseStatus_t CUSPARSEAPI cusparseXcsr2coo(cusparseHandle_t handle,
|
||||
// const int *csrRowPtr,
|
||||
// int nnz,
|
||||
@ -414,26 +413,26 @@ Status CudaSparse::Csr2coo(const int* csrRowPtr, int nnz, int m,
|
||||
// cusparseIndexBase_t
|
||||
// idxBase);
|
||||
DCHECK(initialized_);
|
||||
TF_RETURN_IF_CUSPARSE_ERROR(cusparseXcsr2coo(*cusparse_handle_, csrRowPtr,
|
||||
nnz, m, cooRowInd,
|
||||
CUSPARSE_INDEX_BASE_ZERO));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsr2coo(*gpusparse_handle_, csrRowPtr,
|
||||
nnz, m, cooRowInd,
|
||||
CUSPARSE_INDEX_BASE_ZERO));
|
||||
return Status::OK();
|
||||
}
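Coo2csr and Csr2coo above wrap cusparseXcoo2csr/cusparseXcsr2coo, which translate between COO row indices and the CSR row-pointer array on the device (the device routines expect indices sorted by row). For reference, a host-side sketch of the COO-to-CSR direction with zero-based indexing; illustration only, not the library routine:

    #include <vector>

    // Builds the CSR row-pointer array (size m + 1) from nnz COO row indices.
    std::vector<int> CooToCsrRowPtr(const std::vector<int>& cooRowInd, int m) {
      std::vector<int> csrRowPtr(m + 1, 0);
      for (int r : cooRowInd) ++csrRowPtr[r + 1];  // count entries per row
      for (int i = 0; i < m; ++i) {
        csrRowPtr[i + 1] += csrRowPtr[i];          // exclusive prefix sum
      }
      return csrRowPtr;
    }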
|
||||
|
||||
Status CudaSparse::CsrgeamNnz(int m, int n, const cusparseMatDescr_t descrA,
|
||||
int nnzA, const int* csrSortedRowPtrA,
|
||||
const int* csrSortedColIndA,
|
||||
const cusparseMatDescr_t descrB, int nnzB,
|
||||
const int* csrSortedRowPtrB,
|
||||
const int* csrSortedColIndB,
|
||||
const cusparseMatDescr_t descrC,
|
||||
int* csrSortedRowPtrC, int* nnzTotalDevHostPtr) {
|
||||
Status GpuSparse::CsrgeamNnz(int m, int n, const cusparseMatDescr_t descrA,
|
||||
int nnzA, const int* csrSortedRowPtrA,
|
||||
const int* csrSortedColIndA,
|
||||
const cusparseMatDescr_t descrB, int nnzB,
|
||||
const int* csrSortedRowPtrB,
|
||||
const int* csrSortedColIndB,
|
||||
const cusparseMatDescr_t descrC,
|
||||
int* csrSortedRowPtrC, int* nnzTotalDevHostPtr) {
|
||||
DCHECK(initialized_);
|
||||
DCHECK(nnzTotalDevHostPtr != nullptr);
|
||||
TF_RETURN_IF_CUSPARSE_ERROR(cusparseXcsrgeamNnz(
|
||||
*cusparse_handle_, m, n, descrA, nnzA, csrSortedRowPtrA, csrSortedColIndA,
|
||||
descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB, descrC,
|
||||
csrSortedRowPtrC, nnzTotalDevHostPtr));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgeamNnz(
|
||||
*gpusparse_handle_, m, n, descrA, nnzA, csrSortedRowPtrA,
|
||||
csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB,
|
||||
descrC, csrSortedRowPtrC, nnzTotalDevHostPtr));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -452,7 +451,7 @@ static inline Status CsrmmImpl(
|
||||
// const float* csrSortedValA, const int* csrSortedRowPtrA,
|
||||
// const int* csrSortedColIndA, const float* B, int ldb, const float*
|
||||
// beta, float* C, int ldc);
|
||||
TF_RETURN_IF_CUSPARSE_ERROR(op(
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(op(
|
||||
cusparse_handle, transA, transB, m, n, k, nnz, AsCudaComplex(alpha_host),
|
||||
descrA, AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
|
||||
AsCudaComplex(B), ldb, AsCudaComplex(beta_host), AsCudaComplex(C), ldc));
|
||||
@ -461,7 +460,7 @@ static inline Status CsrmmImpl(
|
||||
|
||||
#define CSRMM_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status CudaSparse::Csrmm<Scalar>( \
|
||||
Status GpuSparse::Csrmm<Scalar>( \
|
||||
cusparseOperation_t transA, cusparseOperation_t transB, int m, int n, \
|
||||
int k, int nnz, const Scalar* alpha_host, \
|
||||
const cusparseMatDescr_t descrA, const Scalar* csrSortedValA, \
|
||||
@ -470,7 +469,7 @@ static inline Status CsrmmImpl(
|
||||
const { \
|
||||
DCHECK(initialized_); \
|
||||
return CsrmmImpl(SPARSE_FN(csrmm2, sparse_prefix), context_, \
|
||||
*cusparse_handle_, transA, transB, m, n, k, nnz, \
|
||||
*gpusparse_handle_, transA, transB, m, n, k, nnz, \
|
||||
alpha_host, descrA, csrSortedValA, csrSortedRowPtrA, \
|
||||
csrSortedColIndA, B, ldb, beta_host, C, ldc); \
|
||||
}
|
||||
@ -484,7 +483,7 @@ static inline Status CsrmvImpl(
|
||||
const cusparseMatDescr_t descrA, const Scalar* csrSortedValA,
|
||||
const int* csrSortedRowPtrA, const int* csrSortedColIndA, const Scalar* x,
|
||||
const Scalar* beta_host, Scalar* y) {
|
||||
TF_RETURN_IF_CUSPARSE_ERROR(
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
op(cusparse_handle, transA, m, n, nnz, AsCudaComplex(alpha_host), descrA,
|
||||
AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
|
||||
AsCudaComplex(x), AsCudaComplex(beta_host), AsCudaComplex(y)));
|
||||
@ -494,7 +493,7 @@ static inline Status CsrmvImpl(
|
||||
// TODO(ebrevdo,rmlarsen): Use csrmv_mp for all cases when available in CUDA 9.
|
||||
#define CSRMV_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status CudaSparse::Csrmv<Scalar>( \
|
||||
Status GpuSparse::Csrmv<Scalar>( \
|
||||
cusparseOperation_t transA, int m, int n, int nnz, \
|
||||
const Scalar* alpha_host, const cusparseMatDescr_t descrA, \
|
||||
const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
|
||||
@ -503,12 +502,12 @@ static inline Status CsrmvImpl(
|
||||
DCHECK(initialized_); \
|
||||
if (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) { \
|
||||
return CsrmvImpl(SPARSE_FN(csrmv_mp, sparse_prefix), context_, \
|
||||
*cusparse_handle_, transA, m, n, nnz, alpha_host, \
|
||||
*gpusparse_handle_, transA, m, n, nnz, alpha_host, \
|
||||
descrA, csrSortedValA, csrSortedRowPtrA, \
|
||||
csrSortedColIndA, x, beta_host, y); \
|
||||
} else { \
|
||||
return CsrmvImpl(SPARSE_FN(csrmv, sparse_prefix), context_, \
|
||||
*cusparse_handle_, transA, m, n, nnz, alpha_host, \
|
||||
*gpusparse_handle_, transA, m, n, nnz, alpha_host, \
|
||||
descrA, csrSortedValA, csrSortedRowPtrA, \
|
||||
csrSortedColIndA, x, beta_host, y); \
|
||||
} \
|
||||
@ -526,7 +525,7 @@ static inline Status CsrgeamImpl(
|
||||
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
|
||||
const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
|
||||
int* csrSortedRowPtrC, int* csrSortedColIndC) {
|
||||
TF_RETURN_IF_CUSPARSE_ERROR(
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
op(cusparse_handle, m, n, AsCudaComplex(alpha), descrA, nnzA,
|
||||
AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
|
||||
AsCudaComplex(beta), descrB, nnzB, AsCudaComplex(csrSortedValB),
|
||||
@ -537,7 +536,7 @@ static inline Status CsrgeamImpl(
|
||||
|
||||
#define CSRGEAM_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status CudaSparse::Csrgeam<Scalar>( \
|
||||
Status GpuSparse::Csrgeam<Scalar>( \
|
||||
int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA, \
|
||||
int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
|
||||
const int* csrSortedColIndA, const Scalar* beta, \
|
||||
@ -547,7 +546,7 @@ static inline Status CsrgeamImpl(
|
||||
int* csrSortedRowPtrC, int* csrSortedColIndC) { \
|
||||
DCHECK(initialized_); \
|
||||
return CsrgeamImpl(SPARSE_FN(csrgeam, sparse_prefix), context_, \
|
||||
*cusparse_handle_, m, n, alpha, descrA, nnzA, \
|
||||
*gpusparse_handle_, m, n, alpha, descrA, nnzA, \
|
||||
csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, \
|
||||
beta, descrB, nnzB, csrSortedValB, csrSortedRowPtrB, \
|
||||
csrSortedColIndB, descrC, csrSortedValC, \
|
||||
@ -556,7 +555,7 @@ static inline Status CsrgeamImpl(
|
||||
|
||||
TF_CALL_LAPACK_TYPES(CSRGEAM_INSTANCE);
|
||||
|
||||
Status CudaSparse::CsrgemmNnz(
|
||||
Status GpuSparse::CsrgemmNnz(
|
||||
cusparseOperation_t transA, cusparseOperation_t transB, int m, int k, int n,
|
||||
const cusparseMatDescr_t descrA, int nnzA, const int* csrSortedRowPtrA,
|
||||
const int* csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB,
|
||||
@ -565,8 +564,8 @@ Status CudaSparse::CsrgemmNnz(
|
||||
int* nnzTotalDevHostPtr) {
|
||||
DCHECK(initialized_);
|
||||
DCHECK(nnzTotalDevHostPtr != nullptr);
|
||||
TF_RETURN_IF_CUSPARSE_ERROR(cusparseXcsrgemmNnz(
|
||||
*cusparse_handle_, transA, transB, m, k, n, descrA, nnzA,
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgemmNnz(
|
||||
*gpusparse_handle_, transA, transB, m, k, n, descrA, nnzA,
|
||||
csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB,
|
||||
csrSortedColIndB, descrC, csrSortedRowPtrC, nnzTotalDevHostPtr));
|
||||
return Status::OK();
|
||||
@ -582,7 +581,7 @@ static inline Status CsrgemmImpl(
|
||||
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
|
||||
const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
|
||||
int* csrSortedRowPtrC, int* csrSortedColIndC) {
|
||||
TF_RETURN_IF_CUSPARSE_ERROR(
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
op(cusparse_handle, transA, transB, m, k, n, descrA, nnzA,
|
||||
AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
|
||||
descrB, nnzB, AsCudaComplex(csrSortedValB), csrSortedRowPtrB,
|
||||
@ -593,7 +592,7 @@ static inline Status CsrgemmImpl(
|
||||
|
||||
#define CSRGEMM_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status CudaSparse::Csrgemm<Scalar>( \
|
||||
Status GpuSparse::Csrgemm<Scalar>( \
|
||||
cusparseOperation_t transA, cusparseOperation_t transB, int m, int k, \
|
||||
int n, const cusparseMatDescr_t descrA, int nnzA, \
|
||||
const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
|
||||
@ -603,7 +602,7 @@ static inline Status CsrgemmImpl(
|
||||
Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC) { \
|
||||
DCHECK(initialized_); \
|
||||
return CsrgemmImpl(SPARSE_FN(csrgemm, sparse_prefix), context_, \
|
||||
*cusparse_handle_, transA, transB, m, k, n, descrA, \
|
||||
*gpusparse_handle_, transA, transB, m, k, n, descrA, \
|
||||
nnzA, csrSortedValA, csrSortedRowPtrA, \
|
||||
csrSortedColIndA, descrB, nnzB, csrSortedValB, \
|
||||
csrSortedRowPtrB, csrSortedColIndB, descrC, \
|
||||
@ -620,12 +619,12 @@ static inline Status Csru2csrImpl(SparseFnT op, BufferSizeFnT buffer_size_op,
|
||||
const cusparseMatDescr_t descrA,
|
||||
Scalar* csrVal, const int* csrRowPtr,
|
||||
int* csrColInd) {
|
||||
CudaSparseCsrSortingConversionInfo info;
|
||||
GpuSparseCsrSortingConversionInfo info;
|
||||
TF_RETURN_IF_ERROR(info.Initialize());
|
||||
|
||||
size_t pBufferSizeInBytes = 0;
|
||||
|
||||
TF_RETURN_IF_CUSPARSE_ERROR(
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
buffer_size_op(cusparse_handle, m, n, nnz, AsCudaComplex(csrVal),
|
||||
csrRowPtr, csrColInd, info.info(), &pBufferSizeInBytes));
|
||||
|
||||
@ -636,22 +635,22 @@ static inline Status Csru2csrImpl(SparseFnT op, BufferSizeFnT buffer_size_op,
|
||||
auto pBuffer = pBuffer_t.flat<int8>();
|
||||
DCHECK(pBuffer.data() != nullptr);
|
||||
|
||||
TF_RETURN_IF_CUSPARSE_ERROR(op(cusparse_handle, m, n, nnz, descrA,
|
||||
AsCudaComplex(csrVal), csrRowPtr, csrColInd,
|
||||
info.info(), pBuffer.data()));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, nnz, descrA,
|
||||
AsCudaComplex(csrVal), csrRowPtr, csrColInd,
|
||||
info.info(), pBuffer.data()));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define CSRU2CSR_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status CudaSparse::Csru2csr<Scalar>( \
|
||||
Status GpuSparse::Csru2csr<Scalar>( \
|
||||
int m, int n, int nnz, const cusparseMatDescr_t descrA, Scalar* csrVal, \
|
||||
const int* csrRowPtr, int* csrColInd) { \
|
||||
DCHECK(initialized_); \
|
||||
return Csru2csrImpl(SPARSE_FN(csru2csr, sparse_prefix), \
|
||||
BUFSIZE_FN(csru2csr, sparse_prefix), context_, \
|
||||
*cusparse_handle_, m, n, nnz, descrA, csrVal, \
|
||||
*gpusparse_handle_, m, n, nnz, descrA, csrVal, \
|
||||
csrRowPtr, csrColInd); \
|
||||
}
|
||||
|
||||
@ -664,22 +663,22 @@ static inline Status Csr2cscImpl(SparseFnT op, OpKernelContext* context,
|
||||
const int* csrRowPtr, const int* csrColInd,
|
||||
Scalar* cscVal, int* cscRowInd, int* cscColPtr,
|
||||
const cusparseAction_t copyValues) {
|
||||
TF_RETURN_IF_CUSPARSE_ERROR(op(cusparse_handle, m, n, nnz,
|
||||
AsCudaComplex(csrVal), csrRowPtr, csrColInd,
|
||||
AsCudaComplex(cscVal), cscRowInd, cscColPtr,
|
||||
copyValues, CUSPARSE_INDEX_BASE_ZERO));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, nnz,
|
||||
AsCudaComplex(csrVal), csrRowPtr, csrColInd,
|
||||
AsCudaComplex(cscVal), cscRowInd, cscColPtr,
|
||||
copyValues, CUSPARSE_INDEX_BASE_ZERO));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define CSR2CSC_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status CudaSparse::Csr2csc<Scalar>( \
|
||||
Status GpuSparse::Csr2csc<Scalar>( \
|
||||
int m, int n, int nnz, const Scalar* csrVal, const int* csrRowPtr, \
|
||||
const int* csrColInd, Scalar* cscVal, int* cscRowInd, int* cscColPtr, \
|
||||
const cusparseAction_t copyValues) { \
|
||||
DCHECK(initialized_); \
|
||||
return Csr2cscImpl(SPARSE_FN(csr2csc, sparse_prefix), context_, \
|
||||
*cusparse_handle_, m, n, nnz, csrVal, csrRowPtr, \
|
||||
*gpusparse_handle_, m, n, nnz, csrVal, csrRowPtr, \
|
||||
csrColInd, cscVal, cscRowInd, cscColPtr, copyValues); \
|
||||
}
|
||||
|
||||
|
@@ -16,15 +16,38 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_CUDA_SPARSE_H_
 #define TENSORFLOW_CORE_KERNELS_CUDA_SPARSE_H_
 
-// This header declares the class CudaSparse, which contains wrappers of
+// This header declares the class GpuSparse, which contains wrappers of
 // cuSparse libraries for use in TensorFlow kernels.
 
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #include <functional>
 #include <vector>
 
+#if GOOGLE_CUDA
+
 #include "third_party/gpus/cuda/include/cusparse.h"
+
+using gpusparseStatus_t = cusparseStatus_t;
+using gpusparseOperation_t = cusparseOperation_t;
+using gpusparseMatDescr_t = cusparseMatDescr_t;
+using gpusparseAction_t = cusparseAction_t;
+using gpusparseHandle_t = cusparseHandle_t;
+using gpuStream_t = cudaStream_t;
+
+#elif TENSORFLOW_USE_ROCM
+
+#include "rocm/include/hipsparse/hipsparse.h"
+
+using gpusparseStatus_t = hipsparseStatus_t;
+using gpusparseOperation_t = hipsparseOperation_t;
+using gpusparseMatDescr_t = hipsparseMatDescr_t;
+using gpusparseAction_t = hipsparseAction_t;
+using gpusparseHandle_t = hipsparseHandle_t;
+using gpuStream_t = hipStream_t;
+
+#endif
+
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_types.h"
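These gpusparse*/gpuStream_t aliases are what let the rest of this header, and the shared wrapper code in cuda_sparse.cc/rocm_sparse.cc, be written once against backend-neutral names. A compile-time sketch of the same aliasing technique; USE_BACKEND_A, BackendAHandle, and BackendBHandle are made-up stand-ins for the real cusparse.h/hipsparse.h types:

    #include <cstdio>

    // Hypothetical backends; the real header aliases cusparseHandle_t or
    // hipsparseHandle_t depending on GOOGLE_CUDA vs TENSORFLOW_USE_ROCM.
    #if defined(USE_BACKEND_A)
    struct BackendAHandle {};
    using gpusparseHandle_t = BackendAHandle;
    #else
    struct BackendBHandle {};
    using gpusparseHandle_t = BackendBHandle;
    #endif

    // Shared wrapper code mentions only the neutral alias.
    void UseHandle(const gpusparseHandle_t&) { std::puts("backend-neutral call"); }

    int main() {
      gpusparseHandle_t h{};
      UseHandle(h);
      return 0;
    }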
@ -40,13 +63,15 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
inline string ConvertCUSparseErrorToString(const cusparseStatus_t status) {
|
||||
inline string ConvertGPUSparseErrorToString(const gpusparseStatus_t status) {
|
||||
switch (status) {
|
||||
#define STRINGIZE(q) #q
|
||||
#define RETURN_IF_STATUS(err) \
|
||||
case err: \
|
||||
return STRINGIZE(err);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
RETURN_IF_STATUS(CUSPARSE_STATUS_SUCCESS)
|
||||
RETURN_IF_STATUS(CUSPARSE_STATUS_NOT_INITIALIZED)
|
||||
RETURN_IF_STATUS(CUSPARSE_STATUS_ALLOC_FAILED)
|
||||
@ -57,27 +82,62 @@ inline string ConvertCUSparseErrorToString(const cusparseStatus_t status) {
|
||||
RETURN_IF_STATUS(CUSPARSE_STATUS_INTERNAL_ERROR)
|
||||
RETURN_IF_STATUS(CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED)
|
||||
|
||||
#undef RETURN_IF_STATUS
|
||||
#undef STRINGIZE
|
||||
default:
|
||||
return strings::StrCat("Unknown CUSPARSE error: ",
|
||||
static_cast<int>(status));
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
|
||||
RETURN_IF_STATUS(HIPSPARSE_STATUS_SUCCESS)
|
||||
RETURN_IF_STATUS(HIPSPARSE_STATUS_NOT_INITIALIZED)
|
||||
RETURN_IF_STATUS(HIPSPARSE_STATUS_ALLOC_FAILED)
|
||||
RETURN_IF_STATUS(HIPSPARSE_STATUS_INVALID_VALUE)
|
||||
RETURN_IF_STATUS(HIPSPARSE_STATUS_ARCH_MISMATCH)
|
||||
RETURN_IF_STATUS(HIPSPARSE_STATUS_MAPPING_ERROR)
|
||||
RETURN_IF_STATUS(HIPSPARSE_STATUS_EXECUTION_FAILED)
|
||||
RETURN_IF_STATUS(HIPSPARSE_STATUS_INTERNAL_ERROR)
|
||||
RETURN_IF_STATUS(HIPSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED)
|
||||
RETURN_IF_STATUS(HIPSPARSE_STATUS_ZERO_PIVOT)
|
||||
|
||||
default:
|
||||
return strings::StrCat("Unknown hipSPARSE error: ",
|
||||
static_cast<int>(status));
|
||||
#endif
|
||||
|
||||
#undef RETURN_IF_STATUS
|
||||
#undef STRINGIZE
|
||||
}
|
||||
}
|
||||
|
||||
#define TF_RETURN_IF_CUSPARSE_ERROR(expr) \
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
#define TF_RETURN_IF_GPUSPARSE_ERROR(expr) \
|
||||
do { \
|
||||
auto status = (expr); \
|
||||
if (TF_PREDICT_FALSE(status != CUSPARSE_STATUS_SUCCESS)) { \
|
||||
return errors::Internal(__FILE__, ":", __LINE__, " (", TF_STR(expr), \
|
||||
"): cuSparse call failed with status ", \
|
||||
ConvertCUSparseErrorToString(status)); \
|
||||
ConvertGPUSparseErrorToString(status)); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
inline cusparseOperation_t TransposeAndConjugateToCuSparseOp(bool transpose,
|
||||
bool conjugate,
|
||||
Status* status) {
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
|
||||
#define TF_RETURN_IF_GPUSPARSE_ERROR(expr) \
|
||||
do { \
|
||||
auto status = (expr); \
|
||||
if (TF_PREDICT_FALSE(status != HIPSPARSE_STATUS_SUCCESS)) { \
|
||||
return errors::Internal(__FILE__, ":", __LINE__, " (", TF_STR(expr), \
|
||||
"): hipSPARSE call failed with status ", \
|
||||
ConvertGPUSparseErrorToString(status)); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#endif
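Both definitions of TF_RETURN_IF_GPUSPARSE_ERROR follow the same check-and-early-return idiom: evaluate the expression once, compare against the backend's success status, and return an error carrying file, line, and the stringified expression. A minimal standalone sketch of that idiom; RETURN_IF_SPARSE_ERROR, SparseStatus, and FakeSparseCall are hypothetical names, and the real macro returns a tensorflow::Status rather than a bool:

    #include <cstdio>

    enum class SparseStatus { kSuccess, kAllocFailed };  // hypothetical statuses

    // do { ... } while (0) keeps the macro safe inside unbraced if/else bodies.
    #define RETURN_IF_SPARSE_ERROR(expr)                              \
      do {                                                            \
        SparseStatus _s = (expr);                                     \
        if (_s != SparseStatus::kSuccess) {                           \
          std::fprintf(stderr, "%s failed at %s:%d\n", #expr,         \
                       __FILE__, __LINE__);                           \
          return false;                                               \
        }                                                             \
      } while (0)

    SparseStatus FakeSparseCall() { return SparseStatus::kAllocFailed; }

    bool RunPipeline() {
      RETURN_IF_SPARSE_ERROR(FakeSparseCall());  // early-returns on failure
      return true;
    }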
|
||||
|
||||
inline gpusparseOperation_t TransposeAndConjugateToGpuSparseOp(bool transpose,
|
||||
bool conjugate,
|
||||
Status* status) {
|
||||
#if GOOGLE_CUDA
|
||||
if (transpose) {
|
||||
return conjugate ? CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE
|
||||
: CUSPARSE_OPERATION_TRANSPOSE;
|
||||
@ -89,25 +149,38 @@ inline cusparseOperation_t TransposeAndConjugateToCuSparseOp(bool transpose,
|
||||
}
|
||||
return CUSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
}
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
if (transpose) {
|
||||
return conjugate ? HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE
|
||||
: HIPSPARSE_OPERATION_TRANSPOSE;
|
||||
} else {
|
||||
if (conjugate) {
|
||||
DCHECK(status != nullptr);
|
||||
*status = errors::InvalidArgument(
|
||||
"Conjugate == True and transpose == False is not supported.");
|
||||
}
|
||||
return HIPSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// The CudaSparse class provides a simplified templated API for cuSparse
|
||||
// The GpuSparse class provides a simplified templated API for cuSparse
|
||||
// (http://docs.nvidia.com/cuda/cusparse/index.html).
|
||||
// An object of this class wraps static cuSparse instances,
|
||||
// and will launch Cuda kernels on the stream wrapped by the GPU device
|
||||
// in the OpKernelContext provided to the constructor.
|
||||
//
|
||||
// Notice: All the computational member functions are asynchronous and simply
|
||||
// launch one or more Cuda kernels on the Cuda stream wrapped by the CudaSparse
|
||||
// launch one or more Cuda kernels on the Cuda stream wrapped by the GpuSparse
|
||||
// object.
|
||||
|
||||
class CudaSparse {
|
||||
class GpuSparse {
|
||||
public:
|
||||
// This object stores a pointer to context, which must outlive it.
|
||||
explicit CudaSparse(OpKernelContext* context);
|
||||
virtual ~CudaSparse() {}
|
||||
explicit GpuSparse(OpKernelContext* context);
|
||||
virtual ~GpuSparse() {}
|
||||
|
||||
// This initializes the CudaSparse class if it hasn't
|
||||
// This initializes the GpuSparse class if it hasn't
|
||||
// been initialized yet. All following public methods require the
|
||||
// class has been initialized. Can be run multiple times; all
|
||||
// subsequent calls after the first have no effect.
|
||||
@ -218,9 +291,9 @@ class CudaSparse {
|
||||
//
|
||||
// **NOTE** This is an in-place operation for data in C.
|
||||
template <typename Scalar>
|
||||
Status Csrmm(cusparseOperation_t transA, cusparseOperation_t transB, int m,
|
||||
Status Csrmm(gpusparseOperation_t transA, gpusparseOperation_t transB, int m,
|
||||
int n, int k, int nnz, const Scalar* alpha_host,
|
||||
const cusparseMatDescr_t descrA, const Scalar* csrSortedValA,
|
||||
const gpusparseMatDescr_t descrA, const Scalar* csrSortedValA,
|
||||
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
|
||||
const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C,
|
||||
int ldc) const;
|
||||
@ -231,8 +304,8 @@ class CudaSparse {
|
||||
//
|
||||
// **NOTE** This is an in-place operation for data in y.
|
||||
template <typename Scalar>
|
||||
Status Csrmv(cusparseOperation_t transA, int m, int n, int nnz,
|
||||
const Scalar* alpha_host, const cusparseMatDescr_t descrA,
|
||||
Status Csrmv(gpusparseOperation_t transA, int m, int n, int nnz,
|
||||
const Scalar* alpha_host, const gpusparseMatDescr_t descrA,
|
||||
const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
|
||||
const int* csrSortedColIndA, const Scalar* x,
|
||||
const Scalar* beta_host, Scalar* y) const;
|
||||
@ -242,11 +315,11 @@ class CudaSparse {
|
||||
// output. csrSortedRowPtrC must be preallocated on device with
|
||||
// m + 1 entries. See:
|
||||
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgeam.
|
||||
Status CsrgeamNnz(int m, int n, const cusparseMatDescr_t descrA, int nnzA,
|
||||
Status CsrgeamNnz(int m, int n, const gpusparseMatDescr_t descrA, int nnzA,
|
||||
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
|
||||
const cusparseMatDescr_t descrB, int nnzB,
|
||||
const gpusparseMatDescr_t descrB, int nnzB,
|
||||
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
|
||||
const cusparseMatDescr_t descrC, int* csrSortedRowPtrC,
|
||||
const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
|
||||
int* nnzTotalDevHostPtr);
|
||||
|
||||
// Computes sparse - sparse matrix addition of matrices
|
||||
@ -256,12 +329,12 @@ class CudaSparse {
|
||||
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgeam.
|
||||
template <typename Scalar>
|
||||
Status Csrgeam(int m, int n, const Scalar* alpha,
|
||||
const cusparseMatDescr_t descrA, int nnzA,
|
||||
const gpusparseMatDescr_t descrA, int nnzA,
|
||||
const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
|
||||
const int* csrSortedColIndA, const Scalar* beta,
|
||||
const cusparseMatDescr_t descrB, int nnzB,
|
||||
const gpusparseMatDescr_t descrB, int nnzB,
|
||||
const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
|
||||
const int* csrSortedColIndB, const cusparseMatDescr_t descrC,
|
||||
const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
|
||||
Scalar* csrSortedValC, int* csrSortedRowPtrC,
|
||||
int* csrSortedColIndC);
|
||||
|
||||
@ -270,13 +343,13 @@ class CudaSparse {
|
||||
// output. csrSortedRowPtrC must be preallocated on device with
|
||||
// m + 1 entries. See:
|
||||
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
|
||||
Status CsrgemmNnz(cusparseOperation_t transA, cusparseOperation_t transB,
|
||||
int m, int k, int n, const cusparseMatDescr_t descrA,
|
||||
Status CsrgemmNnz(gpusparseOperation_t transA, gpusparseOperation_t transB,
|
||||
int m, int k, int n, const gpusparseMatDescr_t descrA,
|
||||
int nnzA, const int* csrSortedRowPtrA,
|
||||
const int* csrSortedColIndA,
|
||||
const cusparseMatDescr_t descrB, int nnzB,
|
||||
const gpusparseMatDescr_t descrB, int nnzB,
|
||||
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
|
||||
const cusparseMatDescr_t descrC, int* csrSortedRowPtrC,
|
||||
const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
|
||||
int* nnzTotalDevHostPtr);
|
||||
|
||||
// Computes sparse - sparse matrix matmul of matrices
|
||||
@ -285,19 +358,20 @@ class CudaSparse {
|
||||
// with nnzTotalDevHostPtr entries (as calculated by CsrgemmNnz). See:
|
||||
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
|
||||
template <typename Scalar>
|
||||
Status Csrgemm(cusparseOperation_t transA, cusparseOperation_t transB, int m,
|
||||
int k, int n, const cusparseMatDescr_t descrA, int nnzA,
|
||||
const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
|
||||
const int* csrSortedColIndA, const cusparseMatDescr_t descrB,
|
||||
int nnzB, const Scalar* csrSortedValB,
|
||||
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
|
||||
const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
|
||||
int* csrSortedRowPtrC, int* csrSortedColIndC);
|
||||
Status Csrgemm(gpusparseOperation_t transA, gpusparseOperation_t transB,
|
||||
int m, int k, int n, const gpusparseMatDescr_t descrA,
|
||||
int nnzA, const Scalar* csrSortedValA,
|
||||
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
|
||||
const gpusparseMatDescr_t descrB, int nnzB,
|
||||
const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
|
||||
const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
|
||||
Scalar* csrSortedValC, int* csrSortedRowPtrC,
|
||||
int* csrSortedColIndC);
|
||||
|
||||
// In-place reordering of unsorted CSR to sorted CSR.
|
||||
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csru2csr
|
||||
template <typename Scalar>
|
||||
Status Csru2csr(int m, int n, int nnz, const cusparseMatDescr_t descrA,
|
||||
Status Csru2csr(int m, int n, int nnz, const gpusparseMatDescr_t descrA,
|
||||
Scalar* csrVal, const int* csrRowPtr, int* csrColInd);
|
||||
|
||||
// Converts from CSR to CSC format (equivalently, transpose).
|
||||
@ -306,30 +380,30 @@ class CudaSparse {
|
||||
Status Csr2csc(int m, int n, int nnz, const Scalar* csrVal,
|
||||
const int* csrRowPtr, const int* csrColInd, Scalar* cscVal,
|
||||
int* cscRowInd, int* cscColPtr,
|
||||
const cusparseAction_t copyValues);
|
||||
const gpusparseAction_t copyValues);
|
||||
|
||||
private:
|
||||
bool initialized_;
|
||||
OpKernelContext *context_; // not owned.
|
||||
cudaStream_t cuda_stream_;
|
||||
cusparseHandle_t *cusparse_handle_; // not owned.
|
||||
gpuStream_t gpu_stream_;
|
||||
gpusparseHandle_t* gpusparse_handle_; // not owned.
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(CudaSparse);
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(GpuSparse);
|
||||
};
|
||||
|
||||
// A wrapper class to ensure that a CUDA sparse matrix descriptor is initialized
|
||||
// only once. For more details on the descriptor (cusparseMatDescr_t), see:
|
||||
// only once. For more details on the descriptor (gpusparseMatDescr_t), see:
|
||||
// https://docs.nvidia.com/cuda/cusparse/index.html#cusparsematdescrt
|
||||
class CudaSparseMatrixDescriptor {
|
||||
class GpuSparseMatrixDescriptor {
|
||||
public:
|
||||
explicit CudaSparseMatrixDescriptor() : initialized_(false) {}
|
||||
explicit GpuSparseMatrixDescriptor() : initialized_(false) {}
|
||||
|
||||
CudaSparseMatrixDescriptor(CudaSparseMatrixDescriptor&& rhs)
|
||||
GpuSparseMatrixDescriptor(GpuSparseMatrixDescriptor&& rhs)
|
||||
: initialized_(rhs.initialized_), descr_(std::move(rhs.descr_)) {
|
||||
rhs.initialized_ = false;
|
||||
}
|
||||
|
||||
CudaSparseMatrixDescriptor& operator=(CudaSparseMatrixDescriptor&& rhs) {
|
||||
GpuSparseMatrixDescriptor& operator=(GpuSparseMatrixDescriptor&& rhs) {
|
||||
if (this == &rhs) return *this;
|
||||
Release();
|
||||
initialized_ = rhs.initialized_;
|
||||
@ -338,23 +412,27 @@ class CudaSparseMatrixDescriptor {
|
||||
return *this;
|
||||
}
|
||||
|
||||
~CudaSparseMatrixDescriptor() { Release(); }
|
||||
~GpuSparseMatrixDescriptor() { Release(); }
|
||||
|
||||
// Initializes the underlying descriptor. Will fail on the second call if
|
||||
// called more than once.
|
||||
Status Initialize() {
|
||||
DCHECK(!initialized_);
|
||||
TF_RETURN_IF_CUSPARSE_ERROR(cusparseCreateMatDescr(&descr_));
|
||||
#if GOOGLE_CUDA
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descr_));
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descr_));
|
||||
#endif
|
||||
initialized_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
cusparseMatDescr_t& descr() {
|
||||
gpusparseMatDescr_t& descr() {
|
||||
DCHECK(initialized_);
|
||||
return descr_;
|
||||
}
|
||||
|
||||
const cusparseMatDescr_t& descr() const {
|
||||
const gpusparseMatDescr_t& descr() const {
|
||||
DCHECK(initialized_);
|
||||
return descr_;
|
||||
}
|
||||
@ -362,31 +440,37 @@ class CudaSparseMatrixDescriptor {
|
||||
private:
|
||||
void Release() {
|
||||
if (initialized_) {
|
||||
#if GOOGLE_CUDA
|
||||
cusparseDestroyMatDescr(descr_);
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
hipsparseDestroyMatDescr(descr_);
|
||||
#endif
|
||||
initialized_ = false;
|
||||
}
|
||||
}
|
||||
|
||||
bool initialized_;
|
||||
cusparseMatDescr_t descr_;
|
||||
gpusparseMatDescr_t descr_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(CudaSparseMatrixDescriptor);
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(GpuSparseMatrixDescriptor);
|
||||
};
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
// A wrapper class to ensure that an unsorted/sorted CSR conversion information
|
||||
// struct (csru2csrInfo_t) is initialized only once. See:
|
||||
// https://docs.nvidia.com/cuda/cusparse/index.html#csru2csr
|
||||
class CudaSparseCsrSortingConversionInfo {
|
||||
class GpuSparseCsrSortingConversionInfo {
|
||||
public:
|
||||
explicit CudaSparseCsrSortingConversionInfo() : initialized_(false) {}
|
||||
explicit GpuSparseCsrSortingConversionInfo() : initialized_(false) {}
|
||||
|
||||
CudaSparseCsrSortingConversionInfo(CudaSparseCsrSortingConversionInfo&& rhs)
|
||||
GpuSparseCsrSortingConversionInfo(GpuSparseCsrSortingConversionInfo&& rhs)
|
||||
: initialized_(rhs.initialized_), info_(std::move(rhs.info_)) {
|
||||
rhs.initialized_ = false;
|
||||
}
|
||||
|
||||
CudaSparseCsrSortingConversionInfo& operator=(
|
||||
CudaSparseCsrSortingConversionInfo&& rhs) {
|
||||
GpuSparseCsrSortingConversionInfo& operator=(
|
||||
GpuSparseCsrSortingConversionInfo&& rhs) {
|
||||
if (this == &rhs) return *this;
|
||||
Release();
|
||||
initialized_ = rhs.initialized_;
|
||||
@ -395,13 +479,13 @@ class CudaSparseCsrSortingConversionInfo {
|
||||
return *this;
|
||||
}
|
||||
|
||||
~CudaSparseCsrSortingConversionInfo() { Release(); }
|
||||
~GpuSparseCsrSortingConversionInfo() { Release(); }
|
||||
|
||||
// Initializes the underlying info. Will fail on the second call if called
|
||||
// more than once.
|
||||
Status Initialize() {
|
||||
DCHECK(!initialized_);
|
||||
TF_RETURN_IF_CUSPARSE_ERROR(cusparseCreateCsru2csrInfo(&info_));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsru2csrInfo(&info_));
|
||||
initialized_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
@ -427,11 +511,13 @@ class CudaSparseCsrSortingConversionInfo {
|
||||
bool initialized_;
|
||||
csru2csrInfo_t info_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(CudaSparseCsrSortingConversionInfo);
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(GpuSparseCsrSortingConversionInfo);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_CUDA_SPARSE_H_
|
||||
|
tensorflow/core/kernels/rocm_sparse.cc (new file, 330 lines)
@@ -0,0 +1,330 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
|
||||
#include <complex>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
||||
#include "tensorflow/core/lib/core/blocking_counter.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// A set of initialized handles to the underlying ROCm libraries used by
|
||||
// GpuSparse. We maintain one such set of handles per unique stream.
|
||||
class HipSparseHandles {
|
||||
public:
|
||||
explicit HipSparseHandles(hipStream_t stream)
|
||||
: initialized_(false), stream_(stream) {}
|
||||
|
||||
HipSparseHandles(HipSparseHandles&& rhs)
|
||||
: initialized_(rhs.initialized_),
|
||||
stream_(std::move(rhs.stream_)),
|
||||
hipsparse_handle_(rhs.hipsparse_handle_) {
|
||||
rhs.initialized_ = false;
|
||||
}
|
||||
|
||||
HipSparseHandles& operator=(HipSparseHandles&& rhs) {
|
||||
if (this == &rhs) return *this;
|
||||
Release();
|
||||
stream_ = std::move(rhs.stream_);
|
||||
hipsparse_handle_ = std::move(rhs.hipsparse_handle_);
|
||||
initialized_ = rhs.initialized_;
|
||||
rhs.initialized_ = false;
|
||||
return *this;
|
||||
}
|
||||
|
||||
~HipSparseHandles() { Release(); }
|
||||
|
||||
Status Initialize() {
|
||||
if (initialized_) return Status::OK();
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreate(&hipsparse_handle_));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
hipsparseSetStream(hipsparse_handle_, stream_));
|
||||
initialized_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
hipsparseHandle_t& handle() {
|
||||
DCHECK(initialized_);
|
||||
return hipsparse_handle_;
|
||||
}
|
||||
|
||||
const hipsparseHandle_t& handle() const {
|
||||
DCHECK(initialized_);
|
||||
return hipsparse_handle_;
|
||||
}
|
||||
|
||||
private:
|
||||
void Release() {
|
||||
if (initialized_) {
|
||||
// This should never return anything other than success
|
||||
auto err = hipsparseDestroy(hipsparse_handle_);
|
||||
DCHECK(err == HIPSPARSE_STATUS_SUCCESS)
|
||||
<< "Failed to destroy hipSPARSE instance.";
|
||||
initialized_ = false;
|
||||
}
|
||||
}
|
||||
bool initialized_;
|
||||
hipStream_t stream_;
|
||||
hipsparseHandle_t hipsparse_handle_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(HipSparseHandles);
|
||||
};
|
||||
|
||||
// TODO(ebrevdo): Replace global mutex guarding CudaSparseHandles
|
||||
// lookup with one of:
|
||||
// 1. Adding the handle to the CudaStream structure; do the lookup there.
|
||||
// 2. Add a thread-local cusparse, set it to the current stream
|
||||
// upon each call.
|
||||
// #1 seems like the cleanest option but will need to wait until this
|
||||
// is moved into TF core.
|
||||
static mutex handle_map_mutex(LINKER_INITIALIZED);
|
||||
|
||||
using HandleMap = std::unordered_map<hipStream_t, HipSparseHandles>;
|
||||
|
||||
// Returns a singleton map used for storing initialized handles for each unique
|
||||
// cuda stream.
|
||||
HandleMap* GetHandleMapSingleton() {
|
||||
static HandleMap* cm = new HandleMap;
|
||||
return cm;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
GpuSparse::GpuSparse(OpKernelContext* context)
|
||||
: initialized_(false), context_(context) {
|
||||
auto hip_stream_ptr =
|
||||
reinterpret_cast<const hipStream_t*>(context->op_device_context()
|
||||
->stream()
|
||||
->implementation()
|
||||
->GpuStreamMemberHack());
|
||||
DCHECK(hip_stream_ptr);
|
||||
gpu_stream_ = *hip_stream_ptr;
|
||||
}
|
||||
|
||||
Status GpuSparse::Initialize() {
|
||||
HandleMap* handle_map = GetHandleMapSingleton();
|
||||
DCHECK(handle_map);
|
||||
mutex_lock lock(handle_map_mutex);
|
||||
auto it = handle_map->find(gpu_stream_);
|
||||
if (it == handle_map->end()) {
|
||||
LOG(INFO) << "Creating GpuSparse handles for stream " << gpu_stream_;
|
||||
// Previously unseen ROCm stream. Initialize a set of ROCm sparse library
|
||||
// handles for it.
|
||||
HipSparseHandles new_handles(gpu_stream_);
|
||||
TF_RETURN_IF_ERROR(new_handles.Initialize());
|
||||
it = handle_map->insert(std::make_pair(gpu_stream_, std::move(new_handles)))
|
||||
.first;
|
||||
}
|
||||
gpusparse_handle_ = &it->second.handle();
|
||||
initialized_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Macro that specializes a sparse method for all 4 standard
|
||||
// numeric types.
|
||||
#define TF_CALL_HIP_LAPACK_TYPES(m) m(float, S) m(double, D)
|
||||
|
||||
// Macros to construct hipsparse method names.
|
||||
#define SPARSE_FN(method, sparse_prefix) hipsparse##sparse_prefix##method
|
||||
|
||||
Status GpuSparse::Coo2csr(const int* cooRowInd, int nnz, int m,
|
||||
int* csrRowPtr) const {
|
||||
DCHECK(initialized_);
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseXcoo2csr(*gpusparse_handle_, cooRowInd,
|
||||
nnz, m, csrRowPtr,
|
||||
HIPSPARSE_INDEX_BASE_ZERO));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GpuSparse::Csr2coo(const int* csrRowPtr, int nnz, int m,
|
||||
int* cooRowInd) const {
|
||||
DCHECK(initialized_);
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseXcsr2coo(*gpusparse_handle_, csrRowPtr,
|
||||
nnz, m, cooRowInd,
|
||||
HIPSPARSE_INDEX_BASE_ZERO));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename Scalar, typename SparseFnT>
|
||||
static inline Status CsrmmImpl(
|
||||
SparseFnT op, OpKernelContext* context, hipsparseHandle_t hipsparse_handle,
|
||||
hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n,
|
||||
int k, int nnz, const Scalar* alpha_host, const hipsparseMatDescr_t descrA,
|
||||
const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
|
||||
const int* csrSortedColIndA, const Scalar* B, int ldb,
|
||||
const Scalar* beta_host, Scalar* C, int ldc) {
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(op(hipsparse_handle, transA, transB, m, n, k,
|
||||
nnz, alpha_host, descrA, csrSortedValA,
|
||||
csrSortedRowPtrA, csrSortedColIndA, B, ldb,
|
||||
beta_host, C, ldc));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define CSRMM_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status GpuSparse::Csrmm<Scalar>( \
|
||||
hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n, \
|
||||
int k, int nnz, const Scalar* alpha_host, \
|
||||
const hipsparseMatDescr_t descrA, const Scalar* csrSortedValA, \
|
||||
const int* csrSortedRowPtrA, const int* csrSortedColIndA, \
|
||||
const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C, int ldc) \
|
||||
const { \
|
||||
DCHECK(initialized_); \
|
||||
return CsrmmImpl(SPARSE_FN(csrmm2, sparse_prefix), context_, \
|
||||
*gpusparse_handle_, transA, transB, m, n, k, nnz, \
|
||||
alpha_host, descrA, csrSortedValA, csrSortedRowPtrA, \
|
||||
csrSortedColIndA, B, ldb, beta_host, C, ldc); \
|
||||
}
|
||||
|
||||
TF_CALL_HIP_LAPACK_TYPES(CSRMM_INSTANCE);
|
||||
|
||||
template <typename Scalar, typename SparseFnT>
|
||||
static inline Status CsrmvImpl(SparseFnT op, OpKernelContext* context,
|
||||
hipsparseHandle_t hipsparse_handle,
|
||||
hipsparseOperation_t transA, int m, int n,
|
||||
int nnz, const Scalar* alpha_host,
|
||||
const hipsparseMatDescr_t descrA,
|
||||
const Scalar* csrSortedValA,
|
||||
const int* csrSortedRowPtrA,
|
||||
const int* csrSortedColIndA, const Scalar* x,
|
||||
const Scalar* beta_host, Scalar* y) {
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
op(hipsparse_handle, transA, m, n, nnz, alpha_host, descrA, csrSortedValA,
|
||||
csrSortedRowPtrA, csrSortedColIndA, x, beta_host, y));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// TODO(ebrevdo,rmlarsen): Use csrmv_mp for all cases when available in CUDA 9.
|
||||
#define CSRMV_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status GpuSparse::Csrmv<Scalar>( \
|
||||
hipsparseOperation_t transA, int m, int n, int nnz, \
|
||||
const Scalar* alpha_host, const hipsparseMatDescr_t descrA, \
|
||||
const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
|
||||
const int* csrSortedColIndA, const Scalar* x, const Scalar* beta_host, \
|
||||
Scalar* y) const { \
|
||||
DCHECK(initialized_); \
|
||||
return CsrmvImpl(SPARSE_FN(csrmv, sparse_prefix), context_, \
|
||||
*gpusparse_handle_, transA, m, n, nnz, alpha_host, \
|
||||
descrA, csrSortedValA, csrSortedRowPtrA, \
|
||||
csrSortedColIndA, x, beta_host, y); \
|
||||
}
|
||||
|
||||
TF_CALL_HIP_LAPACK_TYPES(CSRMV_INSTANCE);
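The Csrmv wrapper above computes y = alpha * op(A) * x + beta * y for a matrix A stored in CSR form. For reference, a host-side sketch of the non-transposed case; it illustrates the arithmetic only and is not the hipSPARSE kernel:

    #include <vector>

    // y = alpha * A * x + beta * y, with A in CSR form (m rows, zero-based).
    void CsrMatVec(int m, float alpha, const std::vector<float>& val,
                   const std::vector<int>& rowPtr, const std::vector<int>& colInd,
                   const std::vector<float>& x, float beta, std::vector<float>& y) {
      for (int row = 0; row < m; ++row) {
        float acc = 0.0f;
        for (int j = rowPtr[row]; j < rowPtr[row + 1]; ++j) {
          acc += val[j] * x[colInd[j]];  // gather nonzeros in this row
        }
        y[row] = alpha * acc + beta * y[row];
      }
    }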
|
||||
|
||||
Status GpuSparse::CsrgemmNnz(
|
||||
hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n,
|
||||
int k, const hipsparseMatDescr_t descrA, int nnzA,
|
||||
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
|
||||
const hipsparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,
|
||||
const int* csrSortedColIndB, const hipsparseMatDescr_t descrC,
|
||||
int* csrSortedRowPtrC, int* nnzTotalDevHostPtr) {
|
||||
DCHECK(initialized_);
|
||||
DCHECK(nnzTotalDevHostPtr != nullptr);
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseXcsrgemmNnz(
|
||||
*gpusparse_handle_, transA, transB, m, n, k, descrA, nnzA,
|
||||
csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB,
|
||||
csrSortedColIndB, descrC, csrSortedRowPtrC, nnzTotalDevHostPtr));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename Scalar, typename SparseFnT>
|
||||
static inline Status CsrgemmImpl(
|
||||
SparseFnT op, OpKernelContext* context, hipsparseHandle_t hipsparse_handle,
|
||||
hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n,
|
||||
int k, const hipsparseMatDescr_t descrA, int nnzA,
|
||||
const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
|
||||
const int* csrSortedColIndA, const hipsparseMatDescr_t descrB, int nnzB,
|
||||
const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
|
||||
const int* csrSortedColIndB, const hipsparseMatDescr_t descrC,
|
||||
Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC) {
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
op(hipsparse_handle, transA, transB, m, n, k, descrA, nnzA, csrSortedValA,
|
||||
csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, csrSortedValB,
|
||||
csrSortedRowPtrB, csrSortedColIndB, descrC, csrSortedValC,
|
||||
csrSortedRowPtrC, csrSortedColIndC));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define CSRGEMM_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status GpuSparse::Csrgemm<Scalar>( \
|
||||
hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n, \
|
||||
int k, const hipsparseMatDescr_t descrA, int nnzA, \
|
||||
const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
|
||||
const int* csrSortedColIndA, const hipsparseMatDescr_t descrB, int nnzB, \
|
||||
const Scalar* csrSortedValB, const int* csrSortedRowPtrB, \
|
||||
const int* csrSortedColIndB, const hipsparseMatDescr_t descrC, \
|
||||
Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC) { \
|
||||
DCHECK(initialized_); \
|
||||
return CsrgemmImpl(SPARSE_FN(csrgemm, sparse_prefix), context_, \
|
||||
*gpusparse_handle_, transA, transB, m, n, k, descrA, \
|
||||
nnzA, csrSortedValA, csrSortedRowPtrA, \
|
||||
csrSortedColIndA, descrB, nnzB, csrSortedValB, \
|
||||
csrSortedRowPtrB, csrSortedColIndB, descrC, \
|
||||
csrSortedValC, csrSortedRowPtrC, csrSortedColIndC); \
|
||||
}
|
||||
|
||||
TF_CALL_HIP_LAPACK_TYPES(CSRGEMM_INSTANCE);
|
||||
|
||||
template <typename Scalar, typename SparseFnT>
|
||||
static inline Status Csr2cscImpl(SparseFnT op, OpKernelContext* context,
|
||||
hipsparseHandle_t hipsparse_handle, int m,
|
||||
int n, int nnz, const Scalar* csrVal,
|
||||
const int* csrRowPtr, const int* csrColInd,
|
||||
Scalar* cscVal, int* cscRowInd, int* cscColPtr,
|
||||
const hipsparseAction_t copyValues) {
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
op(hipsparse_handle, m, n, nnz, csrVal, csrRowPtr, csrColInd, cscVal,
|
||||
cscRowInd, cscColPtr, copyValues, HIPSPARSE_INDEX_BASE_ZERO));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define CSR2CSC_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status GpuSparse::Csr2csc<Scalar>( \
|
||||
int m, int n, int nnz, const Scalar* csrVal, const int* csrRowPtr, \
|
||||
const int* csrColInd, Scalar* cscVal, int* cscRowInd, int* cscColPtr, \
|
||||
const hipsparseAction_t copyValues) { \
|
||||
DCHECK(initialized_); \
|
||||
return Csr2cscImpl(SPARSE_FN(csr2csc, sparse_prefix), context_, \
|
||||
*gpusparse_handle_, m, n, nnz, csrVal, csrRowPtr, \
|
||||
csrColInd, cscVal, cscRowInd, cscColPtr, copyValues); \
|
||||
}
|
||||
|
||||
TF_CALL_HIP_LAPACK_TYPES(CSR2CSC_INSTANCE);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_USE_ROCM
|
@ -2,10 +2,10 @@
|
||||
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"if_cuda_or_rocm",
|
||||
"tf_cc_test",
|
||||
"tf_kernel_library",
|
||||
)
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
@ -77,7 +77,7 @@ tf_kernel_library(
|
||||
"//tensorflow/core/kernels:scatter_nd_op",
|
||||
"//tensorflow/core/kernels:slice_op",
|
||||
"//tensorflow/core/kernels:transpose_functor",
|
||||
] + if_cuda([
|
||||
] + if_cuda_or_rocm([
|
||||
"//tensorflow/core/kernels:cuda_solvers",
|
||||
"//tensorflow/core/kernels:cuda_sparse",
|
||||
]),
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
#endif
|
||||
|
||||
@ -31,7 +31,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
||||
#include "tensorflow/core/kernels/fill_functor.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
||||
#endif
|
||||
@ -233,8 +233,10 @@ class CSRAddOp : public OpKernel {
|
||||
|
||||
REGISTER_GPU(float)
|
||||
REGISTER_GPU(double)
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER_GPU(complex64)
|
||||
REGISTER_GPU(complex128)
|
||||
#endif
|
||||
|
||||
#undef REGISTER_GPU
|
||||
|
||||
@ -246,7 +248,7 @@ REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(
|
||||
|
||||
#undef REGISTER
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
namespace functor {
|
||||
template <typename T>
|
||||
struct CSRSparseMatrixAdd<GPUDevice, T>
|
||||
@ -324,10 +326,10 @@ struct CSRSparseMatrixAdd<GPUDevice, T>
|
||||
|
||||
private:
|
||||
OpKernelContext* ctx_;
|
||||
CudaSparse cuda_sparse_;
|
||||
CudaSparseMatrixDescriptor descrA_;
|
||||
CudaSparseMatrixDescriptor descrB_;
|
||||
CudaSparseMatrixDescriptor descrC_;
|
||||
GpuSparse cuda_sparse_;
|
||||
GpuSparseMatrixDescriptor descrA_;
|
||||
GpuSparseMatrixDescriptor descrB_;
|
||||
GpuSparseMatrixDescriptor descrC_;
|
||||
const T alpha_;
|
||||
const T beta_;
|
||||
bool initialized_;
|
||||
@ -337,6 +339,6 @@ struct CSRSparseMatrixAdd<GPUDevice, T>
|
||||
|
||||
} // namespace functor
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
#endif
|
||||
|
||||
@ -31,7 +31,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/sparse/kernels.h"
|
||||
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
||||
#endif
|
||||
@ -92,12 +92,12 @@ REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(
|
||||
CONJ_VARIANT_UNARY_OP, DEVICE_CPU, CSRSparseMatrix,
|
||||
(CSRSparseMatrixUnaryHelper<CPUDevice, CSRSparseMatrixConjFunctor>));
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(
|
||||
CONJ_VARIANT_UNARY_OP, DEVICE_GPU, CSRSparseMatrix,
|
||||
(CSRSparseMatrixUnaryHelper<GPUDevice, CSRSparseMatrixConjFunctor>));
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
#endif
|
||||
|
||||
@ -33,7 +33,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
||||
#include "tensorflow/core/util/work_sharder.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
||||
#endif
|
||||
@ -220,19 +220,21 @@ REGISTER_CPU(double)
|
||||
REGISTER_CPU(complex64)
|
||||
REGISTER_CPU(complex128)
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
REGISTER_GPU(float)
|
||||
REGISTER_GPU(double)
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER_GPU(complex64)
|
||||
REGISTER_GPU(complex128)
|
||||
#endif
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#undef REGISTER_CPU
|
||||
#undef REGISTER_GPU
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
namespace functor {
|
||||
template <>
|
||||
@ -256,6 +258,6 @@ extern template struct CSRSparseMatrixToCOOSparseMatrix<GPUDevice>;
|
||||
|
||||
} // namespace functor
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
#endif
|
||||
|
||||
@ -31,7 +31,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
||||
#include "tensorflow/core/util/work_sharder.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
||||
#endif
|
||||
@ -205,18 +205,20 @@ class CSRSparseMatrixToSparseTensorGPUOp : public OpKernel {
|
||||
.HostMemory("dense_shape"), \
|
||||
CSRSparseMatrixToSparseTensorGPUOp<GPUDevice, T>);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
REGISTER_GPU(float)
|
||||
REGISTER_GPU(double)
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER_GPU(complex64)
|
||||
REGISTER_GPU(complex128)
|
||||
#endif
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#undef REGISTER_GPU
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
namespace functor {
|
||||
template <>
|
||||
@ -240,7 +242,7 @@ extern template struct CSRSparseMatrixToCOOSparseMatrix<GPUDevice>;
|
||||
|
||||
} // namespace functor
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#define REGISTER_CPU(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("CSRSparseMatrixToSparseTensor") \
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
#endif
|
||||
|
||||
@ -32,13 +32,18 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/sparse/kernels.h"
|
||||
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
|
||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
||||
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
|
||||
#endif
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
|
||||
using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/stream_executor/rocm/rocm_activation.h"
|
||||
using ::perftools::gputools::rocm::ScopedActivateExecutorContext;
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
@ -138,7 +143,7 @@ REGISTER_CPU(complex128)
|
||||
|
||||
#undef REGISTER_CPU
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
template <typename Device, typename T>
|
||||
class DenseToCSRSparseMatrixGPUOp : public AsyncOpKernel {
|
||||
@ -356,8 +361,10 @@ class DenseToCSRSparseMatrixGPUOp : public AsyncOpKernel {
|
||||
|
||||
REGISTER_GPU(GPU, float)
|
||||
REGISTER_GPU(GPU, double)
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER_GPU(GPU, complex64)
|
||||
REGISTER_GPU(GPU, complex128)
|
||||
#endif
|
||||
|
||||
namespace functor {
|
||||
|
||||
@ -380,7 +387,7 @@ struct COOSparseMatrixToCSRSparseMatrix<GPUDevice> {
|
||||
Status operator()(OpKernelContext* c, const int rows, const int cols,
|
||||
TTypes<int>::UnalignedVec coo_row_ind,
|
||||
TTypes<int>::UnalignedVec csr_row_ptr) {
|
||||
CudaSparse cuda_sparse(c);
|
||||
GpuSparse cuda_sparse(c);
|
||||
TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
|
||||
return cuda_sparse.Coo2csr(coo_row_ind.data(),
|
||||
/*nnz*/ coo_row_ind.size(),
|
||||
@ -391,7 +398,7 @@ extern template struct COOSparseMatrixToCSRSparseMatrix<GPUDevice>;
|
||||
|
||||
} // namespace functor
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#undef REGISTER_GPU
|
||||
|
||||
|
@ -13,15 +13,19 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/cub/device/device_histogram.cuh"
|
||||
#include "third_party/cub/iterator/counting_input_iterator.cuh"
|
||||
#include "third_party/cub/iterator/transform_input_iterator.cuh"
|
||||
#include "third_party/gpus/cuda/include/cusparse.h"
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
#include "rocm/include/hipcub/hipcub.hpp"
|
||||
#endif
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
||||
@ -32,6 +36,12 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
namespace gpuprim = ::cub;
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
namespace gpuprim = ::hipcub;
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
@ -65,9 +75,9 @@ Status CalculateNNZPerBatchMatrixFromIndices<GPUDevice>::operator()(
|
||||
DCHECK_EQ(indices.dimension(1), 3); // batch, row, col
|
||||
|
||||
const int rank = indices.dimension(1);
|
||||
cub::CountingInputIterator<int> row_counter(0);
|
||||
cub::TransformInputIterator<int, StridedDataReader,
|
||||
cub::CountingInputIterator<int>>
|
||||
gpuprim::CountingInputIterator<int> row_counter(0);
|
||||
gpuprim::TransformInputIterator<int, StridedDataReader,
|
||||
gpuprim::CountingInputIterator<int>>
|
||||
indices_first_column(row_counter,
|
||||
StridedDataReader(indices.data(), rank));
|
||||
|
||||
@ -76,7 +86,7 @@ Status CalculateNNZPerBatchMatrixFromIndices<GPUDevice>::operator()(
|
||||
DCHECK_NE(indices.data(), nullptr);
|
||||
DCHECK_NE(nnz_per_batch.data(), nullptr);
|
||||
|
||||
auto first_success = cub::DeviceHistogram::HistogramEven(
|
||||
auto first_success = gpuprim::DeviceHistogram::HistogramEven(
|
||||
/*d_temp_storage*/ nullptr,
|
||||
/*temp_storage_bytes&*/ temp_storage_bytes,
|
||||
/*d_samples*/ indices_first_column,
|
||||
@ -87,12 +97,12 @@ Status CalculateNNZPerBatchMatrixFromIndices<GPUDevice>::operator()(
|
||||
/*num_samples*/ total_nnz,
|
||||
/*stream*/ cu_stream);
|
||||
|
||||
if (first_success != cudaSuccess) {
|
||||
if (first_success != gpuSuccess) {
|
||||
return errors::Internal(
|
||||
"SparseTensorToCSRSparseMatrix: Could not launch "
|
||||
"cub::DeviceHistogram::HistogramEven "
|
||||
"gpuprim::DeviceHistogram::HistogramEven "
|
||||
"to calculate temp_storage_bytes, status: ",
|
||||
cudaGetErrorString(first_success));
|
||||
GpuGetErrorString(first_success));
|
||||
}
|
||||
|
||||
Tensor temp_storage;
|
||||
@ -100,7 +110,7 @@ Status CalculateNNZPerBatchMatrixFromIndices<GPUDevice>::operator()(
|
||||
DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
|
||||
&temp_storage));
|
||||
DCHECK_NE(temp_storage.flat<int8>().data(), nullptr);
|
||||
auto second_success = cub::DeviceHistogram::HistogramEven(
|
||||
auto second_success = gpuprim::DeviceHistogram::HistogramEven(
|
||||
/*d_temp_storage*/ temp_storage.flat<int8>().data(),
|
||||
/*temp_storage_bytes&*/ temp_storage_bytes,
|
||||
/*d_samples*/ indices_first_column,
|
||||
@ -111,12 +121,12 @@ Status CalculateNNZPerBatchMatrixFromIndices<GPUDevice>::operator()(
|
||||
/*num_samples*/ total_nnz,
|
||||
/*stream*/ cu_stream);
|
||||
|
||||
if (second_success != cudaSuccess) {
|
||||
if (second_success != gpuSuccess) {
|
||||
return errors::Internal(
|
||||
"SparseTensorToCSRSparseMatrix: Could not launch "
|
||||
"cub::DeviceHistogram::HistogramEven "
|
||||
"gpuprim::DeviceHistogram::HistogramEven "
|
||||
"to count nnz entries per batch. temp_storage_bytes: ",
|
||||
temp_storage_bytes, ", status: ", cudaGetErrorString(second_success));
|
||||
temp_storage_bytes, ", status: ", GpuGetErrorString(second_success));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
@ -128,11 +138,11 @@ template <>
|
||||
Status CSRSparseMatrixToCOOSparseMatrix<GPUDevice>::operator()(
|
||||
OpKernelContext* c, TTypes<const int>::UnalignedVec csr_row_ptr,
|
||||
TTypes<int>::UnalignedVec coo_row_ind) {
|
||||
CudaSparse cuda_sparse(c);
|
||||
GpuSparse gpu_sparse(c);
|
||||
const int nnz = coo_row_ind.size();
|
||||
TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
|
||||
TF_RETURN_IF_ERROR(gpu_sparse.Initialize());
|
||||
const int m = csr_row_ptr.size() - 1; // rows
|
||||
return cuda_sparse.Csr2coo(csr_row_ptr.data(), nnz, m, coo_row_ind.data());
|
||||
return gpu_sparse.Csr2coo(csr_row_ptr.data(), nnz, m, coo_row_ind.data());
|
||||
}
|
||||
|
||||
template <int stride>
|
||||
@ -140,7 +150,7 @@ __global__ void SparseTensorToCOOMatrixKernel(const int64* indices,
|
||||
int* coo_rows_out,
|
||||
int* coo_cols_out, int size) {
|
||||
const int offset = (stride == 3) ? 1 : 0;
|
||||
CUDA_1D_KERNEL_LOOP(i, size) {
|
||||
GPU_1D_KERNEL_LOOP(i, size) {
|
||||
coo_rows_out[i] = static_cast<int>(ldg(indices + i * stride + offset));
|
||||
coo_cols_out[i] = static_cast<int>(ldg(indices + i * stride + offset + 1));
|
||||
}
|
||||
@ -157,20 +167,22 @@ void SparseTensorToCOOSparseMatrix<GPUDevice>::operator()(
|
||||
const int size = coo_row_ind.dimension(0);
|
||||
GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
|
||||
if (stride == 2) {
|
||||
SparseTensorToCOOMatrixKernel<2>
|
||||
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||
indices.data(), coo_row_ind.data(), coo_col_ind.data(), size);
|
||||
TF_CHECK_OK(GpuLaunchKernel(SparseTensorToCOOMatrixKernel<2>,
|
||||
config.block_count, config.thread_per_block, 0,
|
||||
d.stream(), indices.data(), coo_row_ind.data(),
|
||||
coo_col_ind.data(), size));
|
||||
} else {
|
||||
SparseTensorToCOOMatrixKernel<3>
|
||||
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||
indices.data(), coo_row_ind.data(), coo_col_ind.data(), size);
|
||||
TF_CHECK_OK(GpuLaunchKernel(SparseTensorToCOOMatrixKernel<3>,
|
||||
config.block_count, config.thread_per_block, 0,
|
||||
d.stream(), indices.data(), coo_row_ind.data(),
|
||||
coo_col_ind.data(), size));
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void COOMatrixToSparseTensorKernel2D(const int* coo_rows,
|
||||
const int* coo_cols,
|
||||
int64* indices_out, int size) {
|
||||
CUDA_1D_KERNEL_LOOP(i, size) {
|
||||
GPU_1D_KERNEL_LOOP(i, size) {
|
||||
indices_out[i * 2] = static_cast<int64>(ldg(coo_rows + i));
|
||||
indices_out[i * 2 + 1] = static_cast<int64>(ldg(coo_cols + i));
|
||||
}
|
||||
@ -203,7 +215,7 @@ __global__ void COOMatrixToSparseTensorKernel3D(
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
CUDA_1D_KERNEL_LOOP(i, size) {
|
||||
GPU_1D_KERNEL_LOOP(i, size) {
|
||||
// TODO(ebrevdo): Consider special casing batch_size <= 3,
|
||||
// alternatively doing linear instead of binary search. Requires
|
||||
// some benchmarks.
|
||||
@ -231,9 +243,10 @@ Status COOSparseMatrixToSparseTensor<GPUDevice>::operator()(
|
||||
DCHECK_EQ(size, indices.dimension(0));
|
||||
if (ndims == 2) {
|
||||
GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
|
||||
COOMatrixToSparseTensorKernel2D<<<config.block_count,
|
||||
config.thread_per_block, 0, d.stream()>>>(
|
||||
coo_row_ind.data(), coo_col_ind.data(), indices.data(), size);
|
||||
TF_CHECK_OK(GpuLaunchKernel(COOMatrixToSparseTensorKernel2D,
|
||||
config.block_count, config.thread_per_block, 0,
|
||||
d.stream(), coo_row_ind.data(),
|
||||
coo_col_ind.data(), indices.data(), size));
|
||||
return Status::OK();
|
||||
} else {
|
||||
const int batch_size = host_dense_shape(0);
|
||||
@ -246,11 +259,11 @@ Status COOSparseMatrixToSparseTensor<GPUDevice>::operator()(
|
||||
GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
|
||||
// shared memory stores the batch pointers.
|
||||
const size_t shared_memory_size = sizeof(int) * (batch_size + 1);
|
||||
COOMatrixToSparseTensorKernel3D<<<config.block_count,
|
||||
config.thread_per_block,
|
||||
shared_memory_size, d.stream()>>>(
|
||||
coo_row_ind.data(), coo_col_ind.data(), indices.data(),
|
||||
batch_ptr_copy.data(), batch_size, size);
|
||||
TF_CHECK_OK(
|
||||
GpuLaunchKernel(COOMatrixToSparseTensorKernel3D, config.block_count,
|
||||
config.thread_per_block, shared_memory_size, d.stream(),
|
||||
coo_row_ind.data(), coo_col_ind.data(), indices.data(),
|
||||
batch_ptr_copy.data(), batch_size, size));
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
@ -274,7 +287,7 @@ __global__ void CSRSparseMatrixBatchMulVecKernel3D(
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
CUDA_1D_KERNEL_LOOP(i, total_nnz) {
|
||||
GPU_1D_KERNEL_LOOP(i, total_nnz) {
|
||||
const int b = BinarySearchRange(local_batch_ptr, batch_size, i);
|
||||
c_values[i] = ldg(a_values + i) * local_batch_values[b];
|
||||
}
|
||||
@ -316,10 +329,10 @@ Status CSRSparseMatrixBatchMulVecImpl(OpKernelContext* ctx,
|
||||
const size_t shared_memory_size =
|
||||
(sizeof(int) * (batch_size + 1) // local batch_pointers.
|
||||
+ sizeof(T) * batch_size); // local copy of b.
|
||||
CSRSparseMatrixBatchMulVecKernel3D<T>
|
||||
<<<config.block_count, config.thread_per_block, shared_memory_size,
|
||||
d.stream()>>>(a_values.data(), b.data(), c_values.data(),
|
||||
batch_ptr_copy.data(), batch_size, total_nnz);
|
||||
TF_CHECK_OK(GpuLaunchKernel(
|
||||
CSRSparseMatrixBatchMulVecKernel3D<T>, config.block_count,
|
||||
config.thread_per_block, shared_memory_size, d.stream(), a_values.data(),
|
||||
b.data(), c_values.data(), batch_ptr_copy.data(), batch_size, total_nnz));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
@ -374,7 +387,7 @@ __global__ void CSRSparseMatrixSoftmaxKernel2D(const int rows,
|
||||
// algorithm to distribute the work in case the row sizes are
|
||||
// uneven:
|
||||
// http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf
|
||||
CUDA_1D_KERNEL_LOOP(row, rows) {
|
||||
GPU_1D_KERNEL_LOOP(row, rows) {
|
||||
CalculateRowSoftmax(ldg(row_ptr + row), ldg(row_ptr + row + 1), logits,
|
||||
softmax);
|
||||
}
|
||||
@ -382,7 +395,7 @@ __global__ void CSRSparseMatrixSoftmaxKernel2D(const int rows,
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void CopyFromGpuDeviceArrayToLocal(
|
||||
GpuDeviceArrayStruct<int> cuda_ptr_s, int* local_ptr, int length) {
|
||||
#ifdef __CUDA_ARCH__
|
||||
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
|
||||
const int* cuda_ptr = GetGpuDeviceArrayOnDevice(&cuda_ptr_s);
|
||||
for (int i = threadIdx.x; i < length; i += blockDim.x) {
|
||||
local_ptr[i] = cuda_ptr[i];
|
||||
@ -404,7 +417,7 @@ __global__ void CSRSparseMatrixSoftmaxKernel3D(
|
||||
CopyFromGpuDeviceArrayToLocal(std::move(batch_ptr_s), local_batch_ptr,
|
||||
batch_size + 1);
|
||||
|
||||
CUDA_1D_KERNEL_LOOP(i, size) {
|
||||
GPU_1D_KERNEL_LOOP(i, size) {
|
||||
const int batch = i / rows;
|
||||
const int row = i % rows;
|
||||
const int batch_offset = local_batch_ptr[batch];
|
||||
@ -431,10 +444,10 @@ Status CSRSparseMatrixSoftmaxGPUImpl(OpKernelContext* ctx,
|
||||
const int rows = host_dense_shape(0);
|
||||
DCHECK_EQ(rows, row_ptr.size() - 1);
|
||||
GpuLaunchConfig config = GetGpuLaunchConfig(rows /*size*/, d);
|
||||
CSRSparseMatrixSoftmaxKernel2D<T>
|
||||
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||
rows /*size*/, row_ptr.data(), logits_values.data(),
|
||||
softmax_values.data());
|
||||
TF_CHECK_OK(GpuLaunchKernel(CSRSparseMatrixSoftmaxKernel2D<T>,
|
||||
config.block_count, config.thread_per_block, 0,
|
||||
d.stream(), rows /*size*/, row_ptr.data(),
|
||||
logits_values.data(), softmax_values.data()));
|
||||
} else {
|
||||
const int batch_size = host_dense_shape(0);
|
||||
const int rows = host_dense_shape(1);
|
||||
@ -452,10 +465,11 @@ Status CSRSparseMatrixSoftmaxGPUImpl(OpKernelContext* ctx,
|
||||
GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
|
||||
// shared memory stores the batch pointers.
|
||||
const size_t shared_memory_size = sizeof(int) * (batch_size + 1);
|
||||
CSRSparseMatrixSoftmaxKernel3D<T>
|
||||
<<<config.block_count, config.thread_per_block, shared_memory_size,
|
||||
d.stream()>>>(size, rows, batch_ptr_copy.data(), row_ptr.data(),
|
||||
logits_values.data(), softmax_values.data());
|
||||
TF_CHECK_OK(GpuLaunchKernel(CSRSparseMatrixSoftmaxKernel3D<T>,
|
||||
config.block_count, config.thread_per_block,
|
||||
shared_memory_size, d.stream(), size, rows,
|
||||
batch_ptr_copy.data(), row_ptr.data(),
|
||||
logits_values.data(), softmax_values.data()));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
@ -549,7 +563,7 @@ __global__ void CSRSparseMatrixSoftmaxGradKernel2D(
|
||||
// algorithm to distribute the work in case the row sizes are
|
||||
// uneven:
|
||||
// http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf
|
||||
CUDA_1D_KERNEL_LOOP(row, rows) {
|
||||
GPU_1D_KERNEL_LOOP(row, rows) {
|
||||
CalculateRowSoftmaxGrad(
|
||||
ldg(softmax_row_ptr + row) /*softmax_begin*/,
|
||||
ldg(softmax_row_ptr + row + 1) /*softmax_end*/, softmax_col_ind,
|
||||
@ -579,7 +593,7 @@ __global__ void CSRSparseMatrixSoftmaxGradKernel3D(
|
||||
#define SOFTMAX_BATCH_PTR(i) local_batch_ptr[i];
|
||||
#define GRAD_SOFTMAX_BATCH_PTR(i) local_batch_ptr[batch_size + 1 + i];
|
||||
|
||||
CUDA_1D_KERNEL_LOOP(i, size) {
|
||||
GPU_1D_KERNEL_LOOP(i, size) {
|
||||
const int batch = i / rows;
|
||||
const int row = i % rows;
|
||||
const int softmax_batch_offset = SOFTMAX_BATCH_PTR(batch);
|
||||
@ -625,12 +639,12 @@ Status CSRSparseMatrixSoftmaxGradGPUImpl(
|
||||
DCHECK_EQ(rows + 1, softmax_row_ptr.size());
|
||||
DCHECK_EQ(rows + 1, grad_softmax_row_ptr.size());
|
||||
GpuLaunchConfig config = GetGpuLaunchConfig(rows /*size*/, d);
|
||||
CSRSparseMatrixSoftmaxGradKernel2D<T>
|
||||
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||
rows /*size*/, softmax_row_ptr.data(), softmax_col_ind.data(),
|
||||
softmax_values.data(), grad_softmax_row_ptr.data(),
|
||||
grad_softmax_col_ind.data(), grad_softmax_values.data(),
|
||||
gradient_values.data());
|
||||
TF_CHECK_OK(GpuLaunchKernel(
|
||||
CSRSparseMatrixSoftmaxGradKernel2D<T>, config.block_count,
|
||||
config.thread_per_block, 0, d.stream(), rows /*size*/,
|
||||
softmax_row_ptr.data(), softmax_col_ind.data(), softmax_values.data(),
|
||||
grad_softmax_row_ptr.data(), grad_softmax_col_ind.data(),
|
||||
grad_softmax_values.data(), gradient_values.data()));
|
||||
} else {
|
||||
const int batch_size = host_dense_shape(0);
|
||||
const int rows = host_dense_shape(1);
|
||||
@ -656,13 +670,13 @@ Status CSRSparseMatrixSoftmaxGradGPUImpl(
|
||||
// shared memory stores two copies of batch pointers: one for the
|
||||
// softmax CSR matrix, one for the grad_softmax CSR matrix.
|
||||
const size_t shared_memory_size = 2 * sizeof(int) * (batch_size + 1);
|
||||
CSRSparseMatrixSoftmaxGradKernel3D<T>
|
||||
<<<config.block_count, config.thread_per_block, shared_memory_size,
|
||||
d.stream()>>>(size, rows, softmax_and_grad_batch_ptr_copy.data(),
|
||||
softmax_row_ptr.data(), softmax_col_ind.data(),
|
||||
softmax_values.data(), grad_softmax_row_ptr.data(),
|
||||
grad_softmax_col_ind.data(),
|
||||
grad_softmax_values.data(), gradient_values.data());
|
||||
TF_CHECK_OK(GpuLaunchKernel(
|
||||
CSRSparseMatrixSoftmaxGradKernel3D<T>, config.block_count,
|
||||
config.thread_per_block, shared_memory_size, d.stream(), size, rows,
|
||||
softmax_and_grad_batch_ptr_copy.data(), softmax_row_ptr.data(),
|
||||
softmax_col_ind.data(), softmax_values.data(),
|
||||
grad_softmax_row_ptr.data(), grad_softmax_col_ind.data(),
|
||||
grad_softmax_values.data(), gradient_values.data()));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
@ -687,4 +701,4 @@ DEFINE_SOFTMAX_GRAD_GPU(double);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
#endif
|
||||
|
||||
@ -36,7 +36,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/platform/threadpool.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
||||
#endif
|
||||
@ -694,7 +694,7 @@ REGISTER_CPU(complex128)
|
||||
|
||||
#undef REGISTER_CPU
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#define REGISTER_GPU(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
@ -703,14 +703,16 @@ REGISTER_CPU(complex128)
|
||||
|
||||
REGISTER_GPU(float)
|
||||
REGISTER_GPU(double)
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER_GPU(complex64)
|
||||
REGISTER_GPU(complex128)
|
||||
#endif
|
||||
|
||||
#undef REGISTER_GPU
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
namespace functor {
|
||||
|
||||
@ -723,7 +725,7 @@ class CSRSparseMatrixMatMul<GPUDevice, T> {
|
||||
Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
|
||||
typename TTypes<T>::UnalignedConstMatrix b,
|
||||
typename TTypes<T>::UnalignedMatrix c) {
|
||||
CudaSparse cuda_sparse(ctx);
|
||||
GpuSparse cuda_sparse(ctx);
|
||||
TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
|
||||
{
|
||||
// Use Csrmm to calculate:
|
||||
@ -741,19 +743,34 @@ class CSRSparseMatrixMatMul<GPUDevice, T> {
|
||||
|
||||
// transA must be non-transpose if transB is transpose (cusparse
|
||||
// limitation).
|
||||
const cusparseOperation_t transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
#if GOOGLE_CUDA
|
||||
const gpusparseOperation_t transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
const gpusparseOperation_t transA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
#endif
|
||||
|
||||
// transB: b is row-major, and cusparse requires col-major b (or
|
||||
// equivalently transB == transpose). this version is actually more
|
||||
// efficient.
|
||||
const cusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE;
|
||||
#if GOOGLE_CUDA
|
||||
const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE;
|
||||
|
||||
cusparseMatDescr_t descrA;
|
||||
TF_RETURN_IF_CUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
|
||||
TF_RETURN_IF_CUSPARSE_ERROR(
|
||||
gpusparseMatDescr_t descrA;
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
|
||||
TF_RETURN_IF_CUSPARSE_ERROR(
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE;
|
||||
|
||||
gpusparseMatDescr_t descrA;
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
|
||||
#endif
|
||||
|
||||
// A is (m, k), Bt is (ldb, k) and Ct is (ldc, n)
|
||||
const int k = b.dimension(0);
|
||||
@ -796,13 +813,13 @@ template <typename T>
|
||||
class CSRSparseMatrixMatVec<GPUDevice, T> {
|
||||
public:
|
||||
CSRSparseMatrixMatVec(bool transpose_a, bool conjugate_a)
|
||||
: transA_(TransposeAndConjugateToCuSparseOp(transpose_a, conjugate_a,
|
||||
&status_)) {}
|
||||
: transA_(TransposeAndConjugateToGpuSparseOp(transpose_a, conjugate_a,
|
||||
&status_)) {}
|
||||
|
||||
Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
|
||||
const T* x, T* y) {
|
||||
TF_RETURN_IF_ERROR(status_);
|
||||
CudaSparse cuda_sparse(ctx);
|
||||
GpuSparse cuda_sparse(ctx);
|
||||
TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
|
||||
{
|
||||
// Use Csrmv to calculate:
|
||||
@ -815,12 +832,20 @@ class CSRSparseMatrixMatVec<GPUDevice, T> {
|
||||
const T alpha = 1;
|
||||
const T beta = 0;
|
||||
|
||||
cusparseMatDescr_t descrA;
|
||||
TF_RETURN_IF_CUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
|
||||
TF_RETURN_IF_CUSPARSE_ERROR(
|
||||
gpusparseMatDescr_t descrA;
|
||||
#if GOOGLE_CUDA
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
|
||||
TF_RETURN_IF_CUSPARSE_ERROR(
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
const int m = a.dense_shape_host(0);
|
||||
const int n = a.dense_shape_host(1);
|
||||
@ -836,11 +861,11 @@ class CSRSparseMatrixMatVec<GPUDevice, T> {
|
||||
|
||||
private:
|
||||
Status status_;
|
||||
const cusparseOperation_t transA_;
|
||||
const gpusparseOperation_t transA_;
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
#endif
|
||||
|
||||
@ -28,7 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/sparse/kernels.h"
|
||||
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
||||
#endif
|
||||
|
||||
@ -101,22 +101,24 @@ class CSRMulOp : public OpKernel {
|
||||
Name("SparseMatrixMul").Device(DEVICE_##DEV).TypeConstraint<T>("T"), \
|
||||
CSRMulOp<DEV##Device, T>);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#define REGISTER_GPU(T) REGISTER(GPU, T)
|
||||
|
||||
REGISTER_GPU(float)
|
||||
REGISTER_GPU(double)
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER_GPU(complex64)
|
||||
REGISTER_GPU(complex128)
|
||||
#endif
|
||||
|
||||
#undef REGISTER_GPU
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#undef REGISTER
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
namespace functor {
|
||||
|
||||
@ -159,13 +161,15 @@ class CSRSparseMatrixMulScalar<GPUDevice, T> {
|
||||
|
||||
DECLARE_GPU_SPEC(float);
|
||||
DECLARE_GPU_SPEC(double);
|
||||
#if GOOGLE_CUDA
|
||||
DECLARE_GPU_SPEC(std::complex<float>);
|
||||
DECLARE_GPU_SPEC(std::complex<double>);
|
||||
#endif
|
||||
|
||||
#undef DECLARE_GPU_SPEC
|
||||
|
||||
} // namespace functor
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
#endif
|
||||
|
||||
@ -28,7 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/sparse/kernels.h"
|
||||
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
||||
#endif
|
||||
@ -67,11 +67,11 @@ class CSRNNZOp : public OpKernel {
|
||||
|
||||
REGISTER(CPU)
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
REGISTER(GPU)
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#undef REGISTER
|
||||
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
||||
#define EIGEN_USE_GPU
|
||||
#endif
|
||||
@ -84,7 +84,7 @@ class CSRSoftmaxOp : public OpKernel {
|
||||
}
|
||||
};
|
||||
|
||||
#ifdef GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define REGISTER(DEV, T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("SparseMatrixSoftmax") \
|
||||
.Device(DEVICE_##DEV) \
|
||||
@ -110,7 +110,7 @@ DECLARE_GPU_SPEC(double);
|
||||
#undef DECLARE_GPU_SPEC
|
||||
} // namespace functor
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
template <typename Device, typename T>
|
||||
class CSRSoftmaxGradOp : public OpKernel {
|
||||
@ -193,7 +193,7 @@ class CSRSoftmaxGradOp : public OpKernel {
|
||||
}
|
||||
};
|
||||
|
||||
#ifdef GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define REGISTER(DEV, T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("SparseMatrixSoftmaxGrad") \
|
||||
.Device(DEVICE_##DEV) \
|
||||
@ -220,6 +220,6 @@ DECLARE_GPU_SPEC(double);
|
||||
#undef DECLARE_GPU_SPEC
|
||||
} // namespace functor
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
#endif
|
||||
|
||||
@ -35,7 +35,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
||||
#include "tensorflow/core/util/work_sharder.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
||||
#endif
|
||||
@ -498,22 +498,24 @@ REGISTER_CPU(complex128)
|
||||
.TypeConstraint<T>("type"), \
|
||||
CSRSparseMatMulGPUOp<DEV##Device, T>);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#define REGISTER_GPU(T) REGISTER(GPU, T)
|
||||
|
||||
REGISTER_GPU(float)
|
||||
REGISTER_GPU(double)
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER_GPU(complex64)
|
||||
REGISTER_GPU(complex128)
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
#undef REGISTER_GPU
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#undef REGISTER
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
namespace functor {
|
||||
template <typename T>
|
||||
struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
|
||||
@ -527,11 +529,20 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
|
||||
adjoint_a_(adjoint_a),
|
||||
transpose_b_(transpose_b) {
|
||||
// TODO(ebrevdo): Figure out why transposed implementations crash cuSparse.
|
||||
#if GOOGLE_CUDA
|
||||
transA_ = transpose_a ? (adjoint_a ? CUSPARSE_OPERATION_TRANSPOSE
|
||||
: CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE)
|
||||
: CUSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
transB_ = transpose_b ? CUSPARSE_OPERATION_TRANSPOSE
|
||||
: CUSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
transA_ = transpose_a
|
||||
? (adjoint_a ? HIPSPARSE_OPERATION_TRANSPOSE
|
||||
: HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE)
|
||||
: HIPSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
transB_ = transpose_b ? HIPSPARSE_OPERATION_TRANSPOSE
|
||||
: HIPSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
}
|
||||
|
||||
Status Initialize() {
|
||||
@ -630,20 +641,20 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
|
||||
|
||||
private:
|
||||
OpKernelContext* ctx_;
|
||||
CudaSparse cuda_sparse_;
|
||||
GpuSparse cuda_sparse_;
|
||||
bool initialized_;
|
||||
bool transpose_a_;
|
||||
bool adjoint_a_;
|
||||
bool transpose_b_;
|
||||
CudaSparseMatrixDescriptor descrA_;
|
||||
CudaSparseMatrixDescriptor descrB_;
|
||||
CudaSparseMatrixDescriptor descrC_;
|
||||
cusparseOperation_t transA_;
|
||||
cusparseOperation_t transB_;
|
||||
GpuSparseMatrixDescriptor descrA_;
|
||||
GpuSparseMatrixDescriptor descrB_;
|
||||
GpuSparseMatrixDescriptor descrC_;
|
||||
gpusparseOperation_t transA_;
|
||||
gpusparseOperation_t transB_;
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
#endif
|
||||
|
||||
|
@ -18,7 +18,7 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
#endif
|
||||
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
#endif
|
||||
|
||||
@ -29,7 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/sparse/kernels.h"
|
||||
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
||||
#endif
|
||||
@ -116,12 +116,14 @@ REGISTER(CPU, double)
|
||||
REGISTER(CPU, complex64)
|
||||
REGISTER(CPU, complex128)
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
REGISTER(GPU, float)
|
||||
REGISTER(GPU, double)
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER(GPU, complex64)
|
||||
REGISTER(GPU, complex128)
|
||||
#endif
|
||||
|
||||
#undef REGISTER
|
||||
|
||||
@ -139,12 +141,14 @@ namespace functor {
|
||||
DECLARE_GPU_SPEC(int32);
|
||||
DECLARE_GPU_SPEC(float);
|
||||
DECLARE_GPU_SPEC(double);
|
||||
#if GOOGLE_CUDA
|
||||
DECLARE_GPU_SPEC(complex64);
|
||||
DECLARE_GPU_SPEC(complex128);
|
||||
#endif
|
||||
|
||||
#undef DECLARE_GPU_SPEC
|
||||
} // namespace functor
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
#endif
|
||||
|
||||
@ -30,13 +30,18 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/sparse/kernels.h"
|
||||
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
|
||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
||||
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
|
||||
#endif
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
|
||||
using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/stream_executor/rocm/rocm_activation.h"
|
||||
using ::perftools::gputools::rocm::ScopedActivateExecutorContext;
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
@ -104,7 +109,7 @@ class SparseTensorToCSRSparseMatrixCPUOp : public OpKernel {
|
||||
}
|
||||
};
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
template <typename Device, typename T>
|
||||
class SparseTensorToCSRSparseMatrixGPUOp : public AsyncOpKernel {
|
||||
@ -302,7 +307,7 @@ struct COOSparseMatrixToCSRSparseMatrix<GPUDevice> {
|
||||
Status operator()(OpKernelContext* c, const int rows, const int cols,
|
||||
TTypes<int>::UnalignedVec coo_row_ind,
|
||||
TTypes<int>::UnalignedVec csr_row_ptr) {
|
||||
CudaSparse cuda_sparse(c);
|
||||
GpuSparse cuda_sparse(c);
|
||||
TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
|
||||
return cuda_sparse.Coo2csr(coo_row_ind.data(),
|
||||
/*nnz*/ coo_row_ind.size(),
|
||||
@ -322,12 +327,14 @@ extern template struct COOSparseMatrixToCSRSparseMatrix<GPUDevice>;
|
||||
|
||||
REGISTER_GPU(float)
|
||||
REGISTER_GPU(double)
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER_GPU(complex64)
|
||||
REGISTER_GPU(complex128)
|
||||
#endif
|
||||
|
||||
#undef REGISTER_GPU
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#define REGISTER_CPU(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("SparseTensorToCSRSparseMatrix") \
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
||||
#define EIGEN_USE_GPU
|
||||
#endif
|
||||
@ -132,9 +132,12 @@ REGISTER_TRANSPOSE(CPU, double)
|
||||
REGISTER_TRANSPOSE(CPU, complex64)
|
||||
REGISTER_TRANSPOSE(CPU, complex128)
|
||||
|
||||
#ifdef GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
REGISTER_TRANSPOSE(GPU, float)
|
||||
REGISTER_TRANSPOSE(GPU, double)
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER_TRANSPOSE(GPU, complex64)
|
||||
REGISTER_TRANSPOSE(GPU, complex128)
|
||||
#endif // GOOGLE_CUDA
|
||||
@ -250,16 +253,20 @@ struct CSRSparseMatrixTransposeComponent<CPUDevice, T> {
|
||||
}
|
||||
};
|
||||
|
||||
#ifdef GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
template <typename T>
|
||||
struct CSRSparseMatrixTransposeComponent<GPUDevice, T> {
|
||||
Status operator()(OpKernelContext* ctx, const ConstCSRComponent<T>& x,
|
||||
CSRComponent<T>* y) {
|
||||
TF_RETURN_IF_ERROR(ValidateTransposeInputs(x, *y));
|
||||
CudaSparse cuda_sparse(ctx);
|
||||
GpuSparse cuda_sparse(ctx);
|
||||
TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
|
||||
const cusparseAction_t copyValues = CUSPARSE_ACTION_NUMERIC;
|
||||
#if GOOGLE_CUDA
|
||||
const gpusparseAction_t copyValues = CUSPARSE_ACTION_NUMERIC;
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
const gpusparseAction_t copyValues = HIPSPARSE_ACTION_NUMERIC;
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
const int rank = x.dense_shape_host.size();
|
||||
const int m = x.row_ptr.size() - 1;
|
||||
const int n = x.dense_shape_host(rank - 1);
|
||||
@ -279,7 +286,7 @@ struct CSRSparseMatrixTransposeComponent<GPUDevice, T> {
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
} // namespace functor
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
#endif
|
||||
|
||||
@ -74,7 +74,7 @@ Status CSRSparseMatrixZerosLikeHelper(OpKernelContext* ctx,
|
||||
|
||||
} // namespace
|
||||
|
||||
#ifdef GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define REGISTER(DEV) \
|
||||
REGISTER_KERNEL_BUILDER(Name("SparseMatrixZeros") \
|
||||
.Device(DEVICE_##DEV) \
|
||||
@ -88,6 +88,6 @@ REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(
|
||||
CSRSparseMatrixZerosLikeHelper<GPUDevice>);
|
||||
|
||||
#undef REGISTER
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -18,7 +18,7 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
#endif
|
||||
|
||||
|
@ -156,7 +156,7 @@ class TridiagonalSolveOpGpuLinalg : public LinearAlgebraOp<Scalar> {
|
||||
k);
|
||||
return;
|
||||
}
|
||||
std::unique_ptr<CudaSparse> cusparse_solver(new CudaSparse(context));
|
||||
std::unique_ptr<GpuSparse> cusparse_solver(new GpuSparse(context));
|
||||
OP_REQUIRES_OK(context, cusparse_solver->Initialize());
|
||||
if (k == 1) {
|
||||
// rhs is copied into x, then gtsv replaces x with solution.
|
||||
@ -196,20 +196,20 @@ class TridiagonalSolveOpGpuLinalg : public LinearAlgebraOp<Scalar> {
|
||||
}
|
||||
|
||||
void SolveWithGtsv(OpKernelContext* context,
|
||||
std::unique_ptr<CudaSparse>& cusparse_solver,
|
||||
std::unique_ptr<GpuSparse>& cusparse_solver,
|
||||
const Scalar* superdiag, const Scalar* diag,
|
||||
const Scalar* subdiag, Scalar* rhs, const int num_eqs,
|
||||
const int num_rhs) const {
|
||||
#if CUDA_VERSION < 9000
|
||||
auto function = pivoting_ ? &CudaSparse::Gtsv<Scalar>
|
||||
: &CudaSparse::GtsvNoPivot<Scalar>;
|
||||
auto function =
|
||||
pivoting_ ? &GpuSparse::Gtsv<Scalar> : &GpuSparse::GtsvNoPivot<Scalar>;
|
||||
OP_REQUIRES_OK(
|
||||
context, (cusparse_solver.get()->*function)(
|
||||
num_eqs, num_rhs, subdiag, diag, superdiag, rhs, num_eqs));
|
||||
#else
|
||||
auto buffer_function = pivoting_
|
||||
? &CudaSparse::Gtsv2BufferSizeExt<Scalar>
|
||||
: &CudaSparse::Gtsv2NoPivotBufferSizeExt<Scalar>;
|
||||
? &GpuSparse::Gtsv2BufferSizeExt<Scalar>
|
||||
: &GpuSparse::Gtsv2NoPivotBufferSizeExt<Scalar>;
|
||||
size_t buffer_size;
|
||||
OP_REQUIRES_OK(context, (cusparse_solver.get()->*buffer_function)(
|
||||
num_eqs, num_rhs, subdiag, diag, superdiag, rhs,
|
||||
@ -220,8 +220,8 @@ class TridiagonalSolveOpGpuLinalg : public LinearAlgebraOp<Scalar> {
|
||||
context->allocate_temp(DT_UINT8, temp_shape, &temp_tensor));
|
||||
void* buffer = temp_tensor.flat<std::uint8_t>().data();
|
||||
|
||||
auto solver_function = pivoting_ ? &CudaSparse::Gtsv2<Scalar>
|
||||
: &CudaSparse::Gtsv2NoPivot<Scalar>;
|
||||
auto solver_function = pivoting_ ? &GpuSparse::Gtsv2<Scalar>
|
||||
: &GpuSparse::Gtsv2NoPivot<Scalar>;
|
||||
OP_REQUIRES_OK(context, (cusparse_solver.get()->*solver_function)(
|
||||
num_eqs, num_rhs, subdiag, diag, superdiag, rhs,
|
||||
num_eqs, buffer));
|
||||
@ -315,7 +315,7 @@ class TridiagonalSolveOpGpu : public OpKernel {
|
||||
rhs.flat<Scalar>().size());
|
||||
Scalar* x = output->flat<Scalar>().data();
|
||||
|
||||
std::unique_ptr<CudaSparse> cusparse_solver(new CudaSparse(context));
|
||||
std::unique_ptr<GpuSparse> cusparse_solver(new GpuSparse(context));
|
||||
|
||||
OP_REQUIRES_OK(context, cusparse_solver->Initialize());
|
||||
#if CUDA_VERSION < 9000
|
||||
|
@ -28,7 +28,6 @@ cuda_py_test(
|
||||
size = "medium",
|
||||
srcs = ["csr_sparse_matrix_test.py"],
|
||||
main = "csr_sparse_matrix_test.py",
|
||||
tags = ["no_rocm"],
|
||||
deps = [
|
||||
"//tensorflow/python/ops/linalg/sparse",
|
||||
],
|
||||
@ -40,7 +39,6 @@ cuda_py_test(
|
||||
srcs = ["csr_sparse_matrix_ops_test.py"],
|
||||
main = "csr_sparse_matrix_ops_test.py",
|
||||
shard_count = 10,
|
||||
tags = ["no_rocm"],
|
||||
deps = [
|
||||
"//tensorflow/python/ops/linalg/sparse",
|
||||
"//tensorflow/python/ops/linalg/sparse:gen_sparse_csr_matrix_ops",
|
||||
@ -53,7 +51,6 @@ cuda_py_test(
|
||||
srcs = ["csr_sparse_matrix_grad_test.py"],
|
||||
main = "csr_sparse_matrix_grad_test.py",
|
||||
shard_count = 50,
|
||||
tags = ["no_rocm"],
|
||||
deps = [
|
||||
"//tensorflow/python/ops/linalg/sparse",
|
||||
],
|
||||
@ -65,7 +62,6 @@ cuda_py_test(
|
||||
srcs = ["csr_sparse_matrix_dense_mat_mul_grad_test.py"],
|
||||
main = "csr_sparse_matrix_dense_mat_mul_grad_test.py",
|
||||
shard_count = 50,
|
||||
tags = ["no_rocm"],
|
||||
deps = [
|
||||
"//tensorflow/python/ops/linalg/sparse",
|
||||
],
|
||||
@ -77,7 +73,6 @@ cuda_py_test(
|
||||
srcs = ["csr_sparse_matrix_sparse_mat_mul_grad_test.py"],
|
||||
main = "csr_sparse_matrix_sparse_mat_mul_grad_test.py",
|
||||
shard_count = 50,
|
||||
tags = ["no_rocm"],
|
||||
deps = [
|
||||
"//tensorflow/python/ops/linalg/sparse",
|
||||
],
|
||||
|
@ -106,7 +106,11 @@ class CSRSparseMatrixDenseMatMulGradTest(test.TestCase):
|
||||
|
||||
# These tests are refactored from sparse_csr_matrix_grad_test to keep its size
|
||||
# "medium".
|
||||
for dtype in (np.float32, np.complex64):
|
||||
dtypes_to_test = [np.float32]
|
||||
if not test.is_built_with_rocm:
|
||||
# complex type is not supported on the ROCm platform
|
||||
dtypes_to_test += [np.complex64]
|
||||
for dtype in dtypes_to_test:
|
||||
for (t_a, t_b, adj_a, adj_b, t_out,
|
||||
conj_out) in itertools.product(*(([False, True],) * 6)):
|
||||
|
||||
|
@ -84,6 +84,9 @@ class CSRSparseMatrixGradTest(test.TestCase):
|
||||
if not self._gpu_available:
|
||||
return
|
||||
|
||||
if test.is_built_with_rocm():
|
||||
self.skipTest("sparse-matrix-add op not supported on ROCm")
|
||||
|
||||
sparsify = lambda m: m * (m > 0)
|
||||
for dense_shape in ([53, 65, 127], [127, 65]):
|
||||
a_mats_val = sparsify(np.random.randn(*dense_shape))
|
||||
|
@ -432,6 +432,9 @@ class CSRSparseMatrixOpsTest(test.TestCase):
if not self._gpu_available:
return

if test.is_built_with_rocm():
self.skipTest("sparse-matrix-add op not supported on ROCm")

a_indices = np.array([[0, 0], [2, 3]])
a_values = np.array([1.0, 5.0]).astype(np.float32)
a_dense_shape = [5, 6]

@ -469,6 +472,9 @@ class CSRSparseMatrixOpsTest(test.TestCase):
if not self._gpu_available:
return

if test.is_built_with_rocm():
self.skipTest("sparse-matrix-add op not supported on ROCm")

sparsify = lambda m: m * (m > 0)
dense_shape = [53, 65, 127]
a_mats = sparsify(np.random.randn(*dense_shape)).astype(np.float32)

@ -511,6 +517,9 @@ class CSRSparseMatrixOpsTest(test.TestCase):

@test_util.run_in_graph_and_eager_modes
def testSparseMatrixMatMulConjugateOutput(self):
if test.is_built_with_rocm():
self.skipTest("complex type not supported on ROCm")

for shapes in [[(5, 6), (6, 1)], [(5, 6), (6, 2)]]:
a_indices = np.array([[0, 0], [2, 3]])
a_values = np.array([1.0 + 1.j, 5.0 - 2.j]).astype(np.complex64)

@ -533,8 +542,19 @@ class CSRSparseMatrixOpsTest(test.TestCase):

@test_util.run_in_graph_and_eager_modes
def testLargeBatchSparseMatrixMatMul(self):
dtypes_to_test = [np.float32]
if not test.is_built_with_rocm():
# complex types are not supported on the ROCm platform
dtypes_to_test += [np.complex64]

if test.is_built_with_rocm():
# TODO(rocm): fix this
# This test is currently failing on the ROCm platform
# Re-enable it once the fix is available
self.skipTest("hipSPARSE call failure on the ROCm platform")

sparsify = lambda m: m * (m > 0)
for dtype in np.float32, np.complex64:
for dtype in dtypes_to_test:
for (transpose_a, transpose_b) in ((False, False), (False, True),
(True, False), (True, True)):
for (adjoint_a, adjoint_b) in ((False, False), (False, True),

@ -584,8 +604,19 @@ class CSRSparseMatrixOpsTest(test.TestCase):

@test_util.run_in_graph_and_eager_modes
def testLargeBatchSparseMatrixMatMulTransposed(self):
dtypes_to_test = [np.float32]
if not test.is_built_with_rocm():
# complex types are not supported on the ROCm platform
dtypes_to_test += [np.complex64]

if test.is_built_with_rocm():
# TODO(rocm): fix this
# This test is currently failing on the ROCm platform
# Re-enable it once the fix is available
self.skipTest("hipSPARSE call failure on the ROCm platform")

sparsify = lambda m: m * (m > 0)
for dtype in np.float32, np.complex64:
for dtype in dtypes_to_test:
for (transpose_a, transpose_b) in ((False, False), (False, True),
(True, False), (True, True)):
for (adjoint_a, adjoint_b) in ((False, False), (False, True),

@ -636,6 +667,10 @@ class CSRSparseMatrixOpsTest(test.TestCase):

@test_util.run_in_graph_and_eager_modes
def testLargeBatchSparseMatrixMatMulConjugate(self):
if test.is_built_with_rocm():
# complex types are not yet supported on the ROCm platform
self.skipTest("complex type not supported on ROCm")

sparsify = lambda m: m * (m > 0)
a_dense_shape = [53, 65, 127]
b_dense_shape = [53, 127, 67]

@ -767,6 +802,10 @@ class CSRSparseMatrixOpsTest(test.TestCase):
if not self._gpu_available:
return

if test.is_built_with_rocm():
# sparse-matrix-add op is not yet supported on the ROCm platform
self.skipTest("sparse-matrix-add op not supported on ROCm")

sparsify = lambda m: m * (m > 0)
dense_shape = [53, 65, 127]
matrices = [

@ -1154,9 +1193,10 @@ class CSRSparseMatrixOpsTest(test.TestCase):
] #
]).astype(np.complex128)

data_types = [
dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128
]
data_types = [dtypes.float32, dtypes.float64]
if not test.is_built_with_rocm():
# complex type is not supported on the ROCm platform
data_types += [dtypes.complex64, dtypes.complex128]
for dtype in data_types:
sparse_matrix = dense_to_csr_sparse_matrix(
math_ops.cast(dense_mat, dtype))
@ -154,7 +154,11 @@ class SparseMatrixMatmulTest(test.TestCase):
sparsify = lambda m: m * (m > 0)
dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13]
dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15]
for dtype in np.float32, np.complex64:
dtypes_to_test = [np.float32]
if not test.is_built_with_rocm():
# complex type is not supported on the ROCm platform
dtypes_to_test += [np.complex64]
for dtype in dtypes_to_test:
a_mats = sparsify((np.random.randn(*dense_shape_a) +
1.j * np.random.randn(*dense_shape_a))).astype(dtype)
b_mats = sparsify((np.random.randn(*dense_shape_b) +

@ -194,7 +198,11 @@ class SparseMatrixMatmulTest(test.TestCase):
sparsify = lambda m: m * (m > 0)
dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13]
dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15]
for dtype in np.float32, np.complex64:
dtypes_to_test = [np.float32]
if not test.is_built_with_rocm():
# complex type is not supported on the ROCm platform
dtypes_to_test += [np.complex64]
for dtype in dtypes_to_test:
a_mats = sparsify((np.random.randn(*dense_shape_a) +
1.j * np.random.randn(*dense_shape_a))).astype(dtype)
b_mats = (np.random.randn(*dense_shape_b) +

@ -231,7 +239,11 @@ class SparseMatrixMatmulTest(test.TestCase):
sparsify = lambda m: m * (m > 0)
dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13]
dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15]
for dtype in np.float32, np.complex64:
dtypes_to_test = [np.float32]
if not test.is_built_with_rocm():
# complex type is not supported on the ROCm platform
dtypes_to_test += [np.complex64]
for dtype in dtypes_to_test:
a_mats = (np.random.randn(*dense_shape_a) +
1.j * np.random.randn(*dense_shape_a)).astype(dtype)
b_mats = sparsify((np.random.randn(*dense_shape_b) +
7  third_party/gpus/rocm/BUILD.tpl  vendored

@ -137,4 +137,11 @@ cc_library(
],
)

cc_import(
name = "hipsparse",
hdrs = glob(["rocm/include/hipsparse/**",]),
shared_library = "rocm/lib/%{hipsparse_lib}",
visibility = ["//visibility:public"],
)

%{copy_rules}
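The %{hipsparse_lib} placeholder in the template above is filled in by rocm_configure.bzl further down, where the discovered library file name is substituted into the template. A rough Python sketch of that substitution step, for orientation only (expand_template is a made-up stand-in, not the actual Starlark API):

def expand_template(template_text, substitutions):
  # Replace each %{placeholder} key with its configured value, the way the
  # repository rule expands BUILD.tpl.
  for key, value in substitutions.items():
    template_text = template_text.replace(key, value)
  return template_text

line = 'shared_library = "rocm/lib/%{hipsparse_lib}",'
print(expand_template(line, {"%{hipsparse_lib}": "libhipsparse.so"}))
# shared_library = "rocm/lib/libhipsparse.so",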
2  third_party/gpus/rocm/rocm_config.h.tpl  vendored

@ -16,6 +16,6 @@ limitations under the License.
#ifndef ROCM_ROCM_CONFIG_H_
#define ROCM_ROCM_CONFIG_H_

#define TF_ROCM_TOOLKIT_PATH "/opt/rocm"
#define TF_ROCM_TOOLKIT_PATH "%{rocm_toolkit_path}"

#endif // ROCM_ROCM_CONFIG_H_
85  third_party/gpus/rocm_configure.bzl  vendored

@ -191,50 +191,50 @@ def _rocm_include_path(repository_ctx, rocm_config):
inc_dirs.append(rocm_config.rocm_toolkit_path + "/include")

# Add HSA headers
inc_dirs.append("/opt/rocm/hsa/include")
inc_dirs.append(rocm_config.rocm_toolkit_path + "/hsa/include")

# Add HIP headers
inc_dirs.append("/opt/rocm/include/hip")
inc_dirs.append("/opt/rocm/include/hip/hcc_detail")
inc_dirs.append("/opt/rocm/hip/include")
inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/hip")
inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/hip/hcc_detail")
inc_dirs.append(rocm_config.rocm_toolkit_path + "/hip/include")

# Add HIP-Clang headers
inc_dirs.append("/opt/rocm/llvm/lib/clang/8.0/include")
inc_dirs.append("/opt/rocm/llvm/lib/clang/9.0.0/include")
inc_dirs.append("/opt/rocm/llvm/lib/clang/10.0.0/include")
inc_dirs.append(rocm_config.rocm_toolkit_path + "/llvm/lib/clang/8.0/include")
inc_dirs.append(rocm_config.rocm_toolkit_path + "/llvm/lib/clang/9.0.0/include")
inc_dirs.append(rocm_config.rocm_toolkit_path + "/llvm/lib/clang/10.0.0/include")

# Add rocrand and hiprand headers
inc_dirs.append("/opt/rocm/rocrand/include")
inc_dirs.append("/opt/rocm/hiprand/include")
inc_dirs.append(rocm_config.rocm_toolkit_path + "/rocrand/include")
inc_dirs.append(rocm_config.rocm_toolkit_path + "/hiprand/include")

# Add rocfft headers
inc_dirs.append("/opt/rocm/rocfft/include")
inc_dirs.append(rocm_config.rocm_toolkit_path + "/rocfft/include")

# Add rocBLAS headers
inc_dirs.append("/opt/rocm/rocblas/include")
inc_dirs.append(rocm_config.rocm_toolkit_path + "/rocblas/include")

# Add MIOpen headers
inc_dirs.append("/opt/rocm/miopen/include")
inc_dirs.append(rocm_config.rocm_toolkit_path + "/miopen/include")

# Add RCCL headers
inc_dirs.append("/opt/rocm/rccl/include")
inc_dirs.append(rocm_config.rocm_toolkit_path + "/rccl/include")

# Add hcc headers
inc_dirs.append("/opt/rocm/hcc/include")
inc_dirs.append("/opt/rocm/hcc/compiler/lib/clang/7.0.0/include/")
inc_dirs.append("/opt/rocm/hcc/lib/clang/7.0.0/include")
inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/include")
inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/compiler/lib/clang/7.0.0/include/")
inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/lib/clang/7.0.0/include")

# Newer hcc builds use/are based off of clang 8.0.0.
inc_dirs.append("/opt/rocm/hcc/compiler/lib/clang/8.0.0/include/")
inc_dirs.append("/opt/rocm/hcc/lib/clang/8.0.0/include")
inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/compiler/lib/clang/8.0.0/include/")
inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/lib/clang/8.0.0/include")

# Support hcc based off clang 9.0.0, included in ROCm2.2
inc_dirs.append("/opt/rocm/hcc/compiler/lib/clang/9.0.0/include/")
inc_dirs.append("/opt/rocm/hcc/lib/clang/9.0.0/include")
inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/compiler/lib/clang/9.0.0/include/")
inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/lib/clang/9.0.0/include")

# Support hcc based off clang 10.0.0, included in ROCm2.8
inc_dirs.append("/opt/rocm/hcc/compiler/lib/clang/10.0.0/include/")
inc_dirs.append("/opt/rocm/hcc/lib/clang/10.0.0/include")
inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/compiler/lib/clang/10.0.0/include/")
inc_dirs.append(rocm_config.rocm_toolkit_path + "/hcc/lib/clang/10.0.0/include")

return inc_dirs
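The hunk above replaces every hard-coded "/opt/rocm" prefix with the configured rocm_config.rocm_toolkit_path. A condensed, illustrative Python sketch of the resulting pattern (component list abbreviated; not the actual Starlark code):

def rocm_include_dirs(rocm_toolkit_path="/opt/rocm"):
  # Each ROCm component's headers are resolved relative to the configured
  # toolkit path rather than a fixed /opt/rocm prefix.
  components = ["hsa", "hip", "rocrand", "hiprand", "rocfft",
                "rocblas", "miopen", "rccl", "hcc"]
  return [rocm_toolkit_path + "/include"] + [
      "%s/%s/include" % (rocm_toolkit_path, c) for c in components]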
@ -300,11 +300,12 @@ def _hipcc_env(repository_ctx):
repository_ctx.os.environ[name].strip() + "\";")
return hipcc_env.strip()

def _hipcc_is_hipclang(repository_ctx):
def _hipcc_is_hipclang(repository_ctx, rocm_config):
"""Returns whether hipcc is based on the hip-clang toolchain.

Args:
repository_ctx: The repository context.
rocm_config: The ROCm config, as returned by _get_rocm_config.

Returns:
A string "True" if hipcc is based on hip-clang toolchain.

@ -319,7 +320,7 @@ def _hipcc_is_hipclang(repository_ctx):
# grep for "HIP_COMPILER=clang" in /opt/rocm/hip/lib/.hipInfo
grep_result = _execute(
repository_ctx,
["grep", "HIP_COMPILER=clang", "/opt/rocm/hip/lib/.hipInfo"],
["grep", "HIP_COMPILER=clang", rocm_config.rocm_toolkit_path + "/hip/lib/.hipInfo"],
empty_stdout_fine = True,
)
result = grep_result.stdout.strip()

@ -327,13 +328,14 @@ def _hipcc_is_hipclang(repository_ctx):
return "True"
return "False"

def _if_hipcc_is_hipclang(repository_ctx, if_true, if_false = []):
def _if_hipcc_is_hipclang(repository_ctx, rocm_config, if_true, if_false = []):
"""
Returns either the if_true or if_false arg based on whether hipcc
is based on the hip-clang toolchain

Args:
repository_ctx: The repository context.
rocm_config: The ROCm config, as returned by _get_rocm_config.
if_true: value to return if hipcc is hip-clang based
if_false: value to return if hipcc is not hip-clang based
(optional, defaults to empty list)

@ -341,7 +343,7 @@ def _if_hipcc_is_hipclang(repository_ctx, if_true, if_false = []):
Returns:
either the if_true arg or the if_false arg
"""
if _hipcc_is_hipclang(repository_ctx) == "True":
if _hipcc_is_hipclang(repository_ctx, rocm_config) == "True":
return if_true
return if_false
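As the docstrings above describe, hip-clang detection keys off a HIP_COMPILER=clang entry in the .hipInfo file under the configured toolkit path; the Starlark code shells out to grep via _execute. A rough, purely illustrative Python restatement of the same check:

import os

def hipcc_is_hipclang(rocm_toolkit_path="/opt/rocm"):
  # hipcc is treated as hip-clang based when .hipInfo records
  # HIP_COMPILER=clang under the configured toolkit path.
  hip_info = os.path.join(rocm_toolkit_path, "hip/lib/.hipInfo")
  try:
    with open(hip_info) as f:
      return any("HIP_COMPILER=clang" in line for line in f)
  except OSError:
    return False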
@ -478,6 +480,11 @@ def _find_libs(repository_ctx, rocm_config):
repository_ctx,
rocm_config.rocm_toolkit_path + "/rccl",
),
"hipsparse": _find_rocm_lib(
"hipsparse",
repository_ctx,
rocm_config.rocm_toolkit_path + "/hipsparse",
),
}

def _get_rocm_config(repository_ctx):

@ -558,6 +565,7 @@ def _create_dummy_repository(repository_ctx):
"%{rccl_lib}": _lib_name("rccl"),
"%{rocfft_lib}": _lib_name("rocfft"),
"%{hiprand_lib}": _lib_name("hiprand"),
"%{hipsparse_lib}": _lib_name("hipsparse"),
"%{copy_rules}": "",
"%{rocm_headers}": "",
},

@ -703,6 +711,12 @@ def _create_local_rocm_repository(repository_ctx):
src_dir = rocm_toolkit_path + "/rccl/include",
out_dir = "rocm/include/rccl",
),
make_copy_dir_rule(
repository_ctx,
name = "hipsparse-include",
src_dir = rocm_toolkit_path + "/hipsparse/include",
out_dir = "rocm/include/hipsparse",
),
]

rocm_libs = _find_libs(repository_ctx, rocm_config)

@ -740,16 +754,19 @@ def _create_local_rocm_repository(repository_ctx):
"%{hiprand_lib}": rocm_libs["hiprand"].file_name,
"%{miopen_lib}": rocm_libs["miopen"].file_name,
"%{rccl_lib}": rocm_libs["rccl"].file_name,
"%{hipsparse_lib}": rocm_libs["hipsparse"].file_name,
"%{copy_rules}": "\n".join(copy_rules),
"%{rocm_headers}": ('":rocm-include",\n' +
'":rocfft-include",\n' +
'":rocblas-include",\n' +
'":miopen-include",\n' +
'":rccl-include",'),
'":rccl-include",\n' +
'":hipsparse-include",'),
},
)

# Set up crosstool/

cc = find_cc(repository_ctx)

host_compiler_includes = get_cxx_inc_directories(repository_ctx, cc)

@ -762,7 +779,7 @@ def _create_local_rocm_repository(repository_ctx):

rocm_defines["%{host_compiler_prefix}"] = host_compiler_prefix

rocm_defines["%{linker_bin_path}"] = "/opt/rocm/hcc/compiler/bin"
rocm_defines["%{linker_bin_path}"] = rocm_config.rocm_toolkit_path + "/hcc/compiler/bin"

# For gcc, do not canonicalize system header paths; some versions of gcc
# pick the shortest possible path for system includes when creating the

@ -775,7 +792,7 @@ def _create_local_rocm_repository(repository_ctx):
"-DTENSORFLOW_USE_ROCM=1",
"-D__HIP_PLATFORM_HCC__",
"-DEIGEN_USE_HIP",
] + _if_hipcc_is_hipclang(repository_ctx, [
] + _if_hipcc_is_hipclang(repository_ctx, rocm_config, [
#
# define "TENSORFLOW_COMPILER_IS_HIP_CLANG" when we are using clang
# based hipcc to compile/build tensorflow

@ -815,14 +832,14 @@ def _create_local_rocm_repository(repository_ctx):
"crosstool:clang/bin/crosstool_wrapper_driver_rocm",
{
"%{cpu_compiler}": str(cc),
"%{hipcc_path}": "/opt/rocm/bin/hipcc",
"%{hipcc_path}": rocm_config.rocm_toolkit_path + "/bin/hipcc",
"%{hipcc_env}": _hipcc_env(repository_ctx),
"%{hipcc_is_hipclang}": _hipcc_is_hipclang(repository_ctx),
"%{rocr_runtime_path}": "/opt/rocm/lib",
"%{hipcc_is_hipclang}": _hipcc_is_hipclang(repository_ctx, rocm_config),
"%{rocr_runtime_path}": rocm_config.rocm_toolkit_path + "/lib",
"%{rocr_runtime_library}": "hsa-runtime64",
"%{hip_runtime_path}": "/opt/rocm/hip/lib",
"%{hip_runtime_path}": rocm_config.rocm_toolkit_path + "/hip/lib",
"%{hip_runtime_library}": "hip_hcc",
"%{hcc_runtime_path}": "/opt/rocm/hcc/lib",
"%{hcc_runtime_path}": rocm_config.rocm_toolkit_path + "/hcc/lib",
"%{hcc_runtime_library}": "mcwamp",
"%{crosstool_verbose}": _crosstool_verbose(repository_ctx),
"%{gcc_host_compiler_path}": str(cc),
@ -9,5 +9,5 @@ container_digests = {
"cuda10.1-cudnn7-centos6": "sha256:454b899657e87893ee5e68dc0f87df59b6a0a7418ae09cafcc3dd65ac71feca9",
"cuda10.0-cudnn7-ubuntu16.04-manylinux2010": "sha256:5812d9d0ef0a3276fc5faaf4cd01f3d6e03d635893a6e2d2e04f6f01d626c432",
"cuda10.1-cudnn7-ubuntu16.04-manylinux2010": "sha256:f8e15f08cb501e5f2de3dc450f614609fd3ed19bde74b153fa66d14b2307610c",
"rocm-ubuntu16.04": "sha256:d5cd4120cff3d2a452378aad03746ff5f24699d86cf695c20ee96f366e42975f",
"rocm-ubuntu16.04": "sha256:e645447dd6127325f3e97b8bf23424f637a8579d963b34fcc6772cf7cfaa0ebe",
}

@ -72,7 +72,7 @@ def _tensorflow_rbe_config(name, compiler, python_version, os, rocm_version = No
docker_toolchain_autoconfig(
name = name,
base = base,
bazel_version = "0.29.1",
bazel_version = "1.2.1",
build_bazel_src = build_bazel_src,
config_repos = config_repos,
env = env,
@ -15,6 +15,7 @@ cc_library(
name = "rocm_headers",
hdrs = [
"rocm/rocm_config.h",
":hipsparse-include",
":miopen-include",
":rccl-include",
":rocblas-include",

@ -141,6 +142,13 @@ cc_library(
],
)

cc_import(
name = "hipsparse",
hdrs = glob(["rocm/include/hipsparse/**"]),
shared_library = "rocm/lib/libhipsparse.so",
visibility = ["//visibility:public"],
)

genrule(
name = "rocm-include",
outs = [

@ -175,6 +183,7 @@ genrule(
"rocm/include/hcc/clang-c/CXErrorCode.h",
"rocm/include/hcc/clang-c/CXString.h",
"rocm/include/hcc/clang-c/Documentation.h",
"rocm/include/hcc/clang-c/FatalErrorHandler.h",
"rocm/include/hcc/clang-c/Index.h",
"rocm/include/hcc/clang-c/Platform.h",
"rocm/include/hcc/coordinate",

@ -275,12 +284,14 @@ genrule(
"rocm/include/hip/hcc_detail/hip_prof_str.h",
"rocm/include/hip/hcc_detail/hip_runtime.h",
"rocm/include/hip/hcc_detail/hip_runtime_api.h",
"rocm/include/hip/hcc_detail/hip_runtime_prof.h",
"rocm/include/hip/hcc_detail/hip_surface_types.h",
"rocm/include/hip/hcc_detail/hip_texture_types.h",
"rocm/include/hip/hcc_detail/hip_vector_types.h",
"rocm/include/hip/hcc_detail/hiprtc.h",
"rocm/include/hip/hcc_detail/host_defines.h",
"rocm/include/hip/hcc_detail/hsa_helpers.hpp",
"rocm/include/hip/hcc_detail/library_types.h",
"rocm/include/hip/hcc_detail/llvm_intrinsics.h",
"rocm/include/hip/hcc_detail/macro_based_grid_launch.hpp",
"rocm/include/hip/hcc_detail/math_functions.h",

@ -292,6 +303,7 @@ genrule(
"rocm/include/hip/hip_common.h",
"rocm/include/hip/hip_complex.h",
"rocm/include/hip/hip_cooperative_groups.h",
"rocm/include/hip/hip_ext.h",
"rocm/include/hip/hip_fp16.h",
"rocm/include/hip/hip_hcc.h",
"rocm/include/hip/hip_profile.h",

@ -300,6 +312,7 @@ genrule(
"rocm/include/hip/hip_texture_types.h",
"rocm/include/hip/hip_vector_types.h",
"rocm/include/hip/hiprtc.h",
"rocm/include/hip/library_types.h",
"rocm/include/hip/math_functions.h",
"rocm/include/hip/nvcc_detail/channel_descriptor.h",
"rocm/include/hip/nvcc_detail/hip_complex.h",

@ -441,7 +454,6 @@ genrule(
"rocm/include/ocml.h",
"rocm/include/opencl1.2-c.pch",
"rocm/include/opencl2.0-c.pch",
"rocm/include/profiler/CXLActivityLogger/CXLActivityLogger.h",
"rocm/include/rccl.h",
"rocm/include/rocalution.hpp",
"rocm/include/rocblas-auxiliary.h",

@ -583,6 +595,7 @@ genrule(
"rocm/include/rocrand/rocrand_xorwow.h",
"rocm/include/rocrand/rocrand_xorwow_precomputed.h",
"rocm/include/rocsparse-auxiliary.h",
"rocm/include/rocsparse-complex-types.h",
"rocm/include/rocsparse-export.h",
"rocm/include/rocsparse-functions.h",
"rocm/include/rocsparse-types.h",

@ -1468,6 +1481,16 @@ genrule(
cmd = """cp -rLf "/opt/rocm/rccl/include/." "$(@D)/" """,
)

genrule(
name = "hipsparse-include",
outs = [
"rocm/include/hipsparse/hipsparse-export.h",
"rocm/include/hipsparse/hipsparse-version.h",
"rocm/include/hipsparse/hipsparse.h",
],
cmd = """cp -rLf "/opt/rocm/hipsparse/include/." "$(@D)/rocm/include/hipsparse/" """,
)

genrule(
name = "rocm-lib",
outs = [

@ -1477,11 +1500,13 @@ genrule(
"rocm/lib/libhiprand.so",
"rocm/lib/libMIOpen.so",
"rocm/lib/librccl.so",
"rocm/lib/libhipsparse.so",
],
cmd = """cp -f "/opt/rocm/hip/lib/libhip_hcc.so" "$(location rocm/lib/libhip_hcc.so)" && \
cp -f "/opt/rocm/rocblas/lib/librocblas.so.0.1" "$(location rocm/lib/librocblas.so)" && \
cp -f "/opt/rocm/rocfft/lib/librocfft.so.0.1" "$(location rocm/lib/librocfft.so)" && \
cp -f "/opt/rocm/hiprand/lib/libhiprand.so.1.1" "$(location rocm/lib/libhiprand.so)" && \
cp -f "/opt/rocm/miopen/lib/libMIOpen.so.1" "$(location rocm/lib/libMIOpen.so)" && \
cp -f "/opt/rocm/rccl/lib/librccl.so" "$(location rocm/lib/librccl.so)" """,
cp -f "/opt/rocm/rccl/lib/librccl.so" "$(location rocm/lib/librccl.so)" && \
cp -f "/opt/rocm/hipsparse/lib/libhipsparse.so.0.1" "$(location rocm/lib/libhipsparse.so)" """,
)