[ROCm] Fix for ROCm CSB breakage on 200507

The following PR/commit introduces a build error on the ROCm platform

https://github.com/tensorflow/tensorflow/pull/38802

The error is caused by a call to the `CsrgemmBufferSize` routine which only exists on the CUDA side. The call to it was not guarded by the same #if block that guards the function declaration + definition. Adding the missing #if block fixes the issue.

This PR also adds some explicit `GOOGLE_CUDA &&` and `|| TENSORFLOW_USE_ROCM` conditions to some `#if` to make things clear.
This commit is contained in:
Deven Desai 2020-05-07 20:55:49 +00:00
parent a65ece1e46
commit 5bf0bab331
3 changed files with 19 additions and 23 deletions

View File

@ -259,7 +259,7 @@ class GpuSparse {
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-coo2csr. // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-coo2csr.
Status Coo2csr(const int* cooRowInd, int nnz, int m, int* csrRowPtr) const; Status Coo2csr(const int* cooRowInd, int nnz, int m, int* csrRowPtr) const;
#if CUDA_VERSION < 10020 #if (GOOGLE_CUDA && (CUDA_VERSION < 10020)) || TENSORFLOW_USE_ROCM
// Sparse-dense matrix multiplication C = alpha * op(A) * op(B) + beta * C, // Sparse-dense matrix multiplication C = alpha * op(A) * op(B) + beta * C,
// where A is a sparse matrix in CSR format, B and C are dense tall // where A is a sparse matrix in CSR format, B and C are dense tall
// matrices. This routine allows transposition of matrix B, which // matrices. This routine allows transposition of matrix B, which
@ -311,7 +311,7 @@ class GpuSparse {
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrmv_mergepath // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrmv_mergepath
// //
// **NOTE** This is an in-place operation for data in y. // **NOTE** This is an in-place operation for data in y.
#if CUDA_VERSION < 10020 #if (GOOGLE_CUDA && (CUDA_VERSION < 10020)) || TENSORFLOW_USE_ROCM
template <typename Scalar> template <typename Scalar>
Status Csrmv(gpusparseOperation_t transA, int m, int n, int nnz, Status Csrmv(gpusparseOperation_t transA, int m, int n, int nnz,
const Scalar* alpha_host, const gpusparseMatDescr_t descrA, const Scalar* alpha_host, const gpusparseMatDescr_t descrA,
@ -366,7 +366,7 @@ class GpuSparse {
Scalar* csrSortedValC, int* csrSortedRowPtrC, Scalar* csrSortedValC, int* csrSortedRowPtrC,
int* csrSortedColIndC, void* workspace); int* csrSortedColIndC, void* workspace);
#if CUDA_VERSION >= 10000 #if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
// Computes sparse-sparse matrix multiplication of matrices // Computes sparse-sparse matrix multiplication of matrices
// stored in CSR format. This is part zero: calculate required workspace // stored in CSR format. This is part zero: calculate required workspace
// size. // size.
@ -383,7 +383,7 @@ class GpuSparse {
// output. csrSortedRowPtrC must be preallocated on device with // output. csrSortedRowPtrC must be preallocated on device with
// m + 1 entries. See: // m + 1 entries. See:
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm. // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
#if CUDA_VERSION < 10000 #if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
Status CsrgemmNnz(gpusparseOperation_t transA, gpusparseOperation_t transB, Status CsrgemmNnz(gpusparseOperation_t transA, gpusparseOperation_t transB,
int m, int k, int n, const gpusparseMatDescr_t descrA, int m, int k, int n, const gpusparseMatDescr_t descrA,
int nnzA, const int* csrSortedRowPtrA, int nnzA, const int* csrSortedRowPtrA,
@ -408,7 +408,7 @@ class GpuSparse {
// addition. csrValC and csrColIndC must be allocated on the device // addition. csrValC and csrColIndC must be allocated on the device
// with nnzTotalDevHostPtr entries (as calculated by CsrgemmNnz). See: // with nnzTotalDevHostPtr entries (as calculated by CsrgemmNnz). See:
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm. // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
#if CUDA_VERSION < 10000 #if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
template <typename Scalar> template <typename Scalar>
Status Csrgemm(gpusparseOperation_t transA, gpusparseOperation_t transB, Status Csrgemm(gpusparseOperation_t transA, gpusparseOperation_t transB,
int m, int k, int n, const gpusparseMatDescr_t descrA, int m, int k, int n, const gpusparseMatDescr_t descrA,

View File

@ -728,12 +728,14 @@ namespace {
template <typename T> template <typename T>
struct GPUDataType; struct GPUDataType;
// GPUDataType templates are currently not instantiated in the ROCm flow
// So leaving out the #elif TENSORFLOW_USE_ROCM blocks for now
// hipblas library is not (yet) being pulled in via rocm_configure.bzl
// so cannot reference tyeps from hipblas headers here
template <> template <>
struct GPUDataType<Eigen::half> { struct GPUDataType<Eigen::half> {
#if GOOGLE_CUDA #if GOOGLE_CUDA
static constexpr cudaDataType_t type = CUDA_R_16F; static constexpr cudaDataType_t type = CUDA_R_16F;
#elif TENSORFLOW_USE_ROCM
static constexpr hipblasDataType_t type = HIPBLAS_R_16F;
#endif #endif
}; };
@ -741,8 +743,6 @@ template <>
struct GPUDataType<float> { struct GPUDataType<float> {
#if GOOGLE_CUDA #if GOOGLE_CUDA
static constexpr cudaDataType_t type = CUDA_R_32F; static constexpr cudaDataType_t type = CUDA_R_32F;
#elif TENSORFLOW_USE_ROCM
static constexpr hipblasDataType_t type = HIPBLAS_R_32F;
#endif #endif
}; };
@ -750,8 +750,6 @@ template <>
struct GPUDataType<std::complex<float>> { struct GPUDataType<std::complex<float>> {
#if GOOGLE_CUDA #if GOOGLE_CUDA
static constexpr cudaDataType_t type = CUDA_C_32F; static constexpr cudaDataType_t type = CUDA_C_32F;
#elif TENSORFLOW_USE_ROCM
static constexpr hipblasDataType_t type = HIPBLAS_C_32F;
#endif #endif
}; };
@ -759,8 +757,6 @@ template <>
struct GPUDataType<double> { struct GPUDataType<double> {
#if GOOGLE_CUDA #if GOOGLE_CUDA
static constexpr cudaDataType_t type = CUDA_R_64F; static constexpr cudaDataType_t type = CUDA_R_64F;
#elif TENSORFLOW_USE_ROCM
static constexpr hipblasDataType_t type = HIPBLAS_R_64F;
#endif #endif
}; };
@ -768,8 +764,6 @@ template <>
struct GPUDataType<std::complex<double>> { struct GPUDataType<std::complex<double>> {
#if GOOGLE_CUDA #if GOOGLE_CUDA
static constexpr cudaDataType_t type = CUDA_C_64F; static constexpr cudaDataType_t type = CUDA_C_64F;
#elif TENSORFLOW_USE_ROCM
static constexpr hipblasDataType_t type = HIPBLAS_C_64F;
#endif #endif
}; };
@ -957,7 +951,7 @@ class CSRSparseMatrixMatVec<GPUDevice, T> {
const int n = a.dense_shape_host(1); const int n = a.dense_shape_host(1);
const int nnz = a.values.size(); const int nnz = a.values.size();
DCHECK_EQ(nnz, a.col_ind.size()); DCHECK_EQ(nnz, a.col_ind.size());
#if CUDA_VERSION >= 10020 #if GOOGLE_CUDA && (CUDA_VERSION >= 10020)
TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha, TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha,
a.values.data(), a.row_ptr.data(), a.values.data(), a.row_ptr.data(),
a.col_ind.data(), x, &beta, y)); a.col_ind.data(), x, &beta, y));

View File

@ -417,7 +417,7 @@ class CSRSparseMatMulGPUOp : public OpKernel {
} }
auto b_input_dense_shape = b_input_matrix->dense_shape().vec<int64>(); auto b_input_dense_shape = b_input_matrix->dense_shape().vec<int64>();
#if CUDA_VERSION >= 10000 #if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
size_t maxWorkspaceSize = 0; size_t maxWorkspaceSize = 0;
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
// Calculate maximum workspace size over batch. // Calculate maximum workspace size over batch.
@ -558,7 +558,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
initialized_(false), initialized_(false),
transpose_a_(transpose_a), transpose_a_(transpose_a),
adjoint_a_(adjoint_a), adjoint_a_(adjoint_a),
#if CUDA_VERSION < 10000 #if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
transpose_b_(transpose_b) { transpose_b_(transpose_b) {
#else #else
transpose_b_(transpose_b), transpose_b_(transpose_b),
@ -573,7 +573,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
: GPUSPARSE(OPERATION_NON_TRANSPOSE); : GPUSPARSE(OPERATION_NON_TRANSPOSE);
} }
#if CUDA_VERSION >= 10000 #if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
~CSRSparseSparseMatrixMatMul() { ~CSRSparseSparseMatrixMatMul() {
if (initialized_) { if (initialized_) {
cusparseDestroyCsrgemm2Info(info_); cusparseDestroyCsrgemm2Info(info_);
@ -591,7 +591,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
TF_RETURN_IF_ERROR(descrA_.Initialize()); TF_RETURN_IF_ERROR(descrA_.Initialize());
TF_RETURN_IF_ERROR(descrB_.Initialize()); TF_RETURN_IF_ERROR(descrB_.Initialize());
TF_RETURN_IF_ERROR(descrC_.Initialize()); TF_RETURN_IF_ERROR(descrC_.Initialize());
#if CUDA_VERSION >= 10000 #if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsrgemm2Info(&info_)); TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsrgemm2Info(&info_));
#endif #endif
initialized_ = true; initialized_ = true;
@ -600,6 +600,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
Status GetWorkspaceSize(const ConstCSRComponent<T>& a, Status GetWorkspaceSize(const ConstCSRComponent<T>& a,
const ConstCSRComponent<T>& b, size_t* bufferSize) { const ConstCSRComponent<T>& b, size_t* bufferSize) {
#if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
DCHECK(initialized_); DCHECK(initialized_);
const int m = const int m =
a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 1 : 2)); a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 1 : 2));
@ -621,6 +622,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(), a.col_ind.data(), m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(), a.col_ind.data(),
descrB_.descr(), nnzB, b.row_ptr.data(), b.col_ind.data(), info_, descrB_.descr(), nnzB, b.row_ptr.data(), b.col_ind.data(), info_,
bufferSize)); bufferSize));
#endif
return Status::OK(); return Status::OK();
} }
@ -650,7 +652,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
*output_nnz = -1; *output_nnz = -1;
#if CUDA_VERSION < 10000 #if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
TF_RETURN_IF_ERROR(cuda_sparse_.CsrgemmNnz( TF_RETURN_IF_ERROR(cuda_sparse_.CsrgemmNnz(
transA_, transB_, m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(), transA_, transB_, m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(),
a.col_ind.data(), descrB_.descr(), nnzB, b.row_ptr.data(), a.col_ind.data(), descrB_.descr(), nnzB, b.row_ptr.data(),
@ -693,7 +695,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
b.dense_shape_host(b.dense_shape_host.size() - (transpose_b_ ? 2 : 1)); b.dense_shape_host(b.dense_shape_host.size() - (transpose_b_ ? 2 : 1));
DCHECK_EQ(n, c->dense_shape_host(c->dense_shape_host.size() - 1)); DCHECK_EQ(n, c->dense_shape_host(c->dense_shape_host.size() - 1));
#if CUDA_VERSION < 10000 #if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
TF_RETURN_IF_ERROR(cuda_sparse_.Csrgemm( TF_RETURN_IF_ERROR(cuda_sparse_.Csrgemm(
transA_, transB_, m, k, n, descrA_.descr(), nnzA, a.values.data(), transA_, transB_, m, k, n, descrA_.descr(), nnzA, a.values.data(),
a.row_ptr.data(), a.col_ind.data(), descrB_.descr(), nnzB, a.row_ptr.data(), a.col_ind.data(), descrB_.descr(), nnzB,
@ -732,7 +734,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
GpuSparseMatrixDescriptor descrC_; GpuSparseMatrixDescriptor descrC_;
gpusparseOperation_t transA_; gpusparseOperation_t transA_;
gpusparseOperation_t transB_; gpusparseOperation_t transB_;
#if CUDA_VERSION >= 10000 #if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
csrgemm2Info_t info_; csrgemm2Info_t info_;
#endif #endif
}; };