[ROCm] Fix for ROCm CSB breakage on 200507
The following PR/commit introduces a build error on the ROCm platform: https://github.com/tensorflow/tensorflow/pull/38802. The error is caused by a call to the `CsrgemmBufferSize` routine, which exists only on the CUDA side: the call was not guarded by the same `#if` block that guards the routine's declaration and definition. Adding the missing `#if` block fixes the issue. This PR also adds explicit `GOOGLE_CUDA &&` and `|| TENSORFLOW_USE_ROCM` conditions to some `#if` directives to make the intent clear.
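For context on the failure mode: when a function's declaration is guarded by a preprocessor condition, every call site must be guarded by the same condition, or builds that disable the guard see a call to an undeclared identifier. Below is a minimal stand-alone sketch of the pattern (hypothetical names, not the actual TensorFlow code):

// Minimal reproduction of the guard mismatch; names are hypothetical,
// not the actual TensorFlow declarations.
#include <cstdio>

#define GOOGLE_CUDA 0  // pretend we are on the ROCm side

#if GOOGLE_CUDA
// Declaration exists only in the CUDA build.
int CsrgemmBufferSizeImpl(int m, int n, int k);
#endif

int main() {
  int buffer_size = 0;
#if GOOGLE_CUDA
  // Correct: the call is guarded by the same condition as the
  // declaration, so the ROCm build never sees it.
  buffer_size = CsrgemmBufferSizeImpl(2, 3, 4);
#endif
  // An unguarded call at this point would fail to compile on ROCm,
  // because the declaration above has been preprocessed away.
  std::printf("buffer_size = %d\n", buffer_size);
  return 0;
}

The hunks below apply exactly this discipline: the CUDA-only declarations and their call sites end up behind identical conditions.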
This commit is contained in:
parent a65ece1e46
commit 5bf0bab331
tensorflow/core/kernels
@@ -259,7 +259,7 @@ class GpuSparse {
   // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-coo2csr.
   Status Coo2csr(const int* cooRowInd, int nnz, int m, int* csrRowPtr) const;
 
-#if CUDA_VERSION < 10020
+#if (GOOGLE_CUDA && (CUDA_VERSION < 10020)) || TENSORFLOW_USE_ROCM
   // Sparse-dense matrix multiplication C = alpha * op(A) * op(B) + beta * C,
   // where A is a sparse matrix in CSR format, B and C are dense tall
   // matrices. This routine allows transposition of matrix B, which
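For readers less familiar with the storage format the Csrmm comments assume, here is a small CPU reference (illustrative only, not the TensorFlow or cuSPARSE implementation) of C = alpha * A * B + beta * C with A held as the CSR triplet (values, row_ptr, col_ind):

// CPU reference for C = alpha * A * B + beta * C with A in CSR form.
#include <cstdio>
#include <vector>

int main() {
  // A = [[1 0 2],      values/col_ind list the nonzeros row by row;
  //      [0 3 0]]      row_ptr[i]..row_ptr[i+1] spans row i.
  const int m = 2, n = 2;
  std::vector<float> values = {1.f, 2.f, 3.f};
  std::vector<int> col_ind = {0, 2, 1};
  std::vector<int> row_ptr = {0, 2, 3};

  // Dense B (3 x n) and C (m x n), row-major.
  std::vector<float> B = {1.f, 0.f, 0.f, 1.f, 1.f, 1.f};
  std::vector<float> C(m * n, 0.f);
  const float alpha = 1.f, beta = 0.f;

  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int p = row_ptr[i]; p < row_ptr[i + 1]; ++p)
        acc += values[p] * B[col_ind[p] * n + j];
      C[i * n + j] = alpha * acc + beta * C[i * n + j];
    }

  for (int i = 0; i < m; ++i)
    std::printf("%g %g\n", C[i * n], C[i * n + 1]);  // prints 3 2 / 0 3
  return 0;
}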
@@ -311,7 +311,7 @@ class GpuSparse {
   // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrmv_mergepath
   //
   // **NOTE** This is an in-place operation for data in y.
-#if CUDA_VERSION < 10020
+#if (GOOGLE_CUDA && (CUDA_VERSION < 10020)) || TENSORFLOW_USE_ROCM
   template <typename Scalar>
   Status Csrmv(gpusparseOperation_t transA, int m, int n, int nnz,
                const Scalar* alpha_host, const gpusparseMatDescr_t descrA,
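The in-place note on Csrmv means the prior contents of y participate in the result via beta before being overwritten. A tiny CPU analogue (illustrative, not the library code):

// CPU reference for y = alpha * A * x + beta * y with A in CSR form,
// showing why Csrmv is documented as in-place for y: old contents of
// y are read (scaled by beta) and then overwritten.
#include <cstdio>
#include <vector>

int main() {
  // Same 2x3 matrix A as before: [[1 0 2], [0 3 0]].
  std::vector<float> values = {1.f, 2.f, 3.f};
  std::vector<int> col_ind = {0, 2, 1};
  std::vector<int> row_ptr = {0, 2, 3};

  std::vector<float> x = {1.f, 1.f, 1.f};
  std::vector<float> y = {10.f, 10.f};  // prior contents participate
  const float alpha = 1.f, beta = 0.5f;

  for (int i = 0; i < 2; ++i) {
    float acc = 0.f;
    for (int p = row_ptr[i]; p < row_ptr[i + 1]; ++p)
      acc += values[p] * x[col_ind[p]];
    y[i] = alpha * acc + beta * y[i];  // in-place update of y
  }
  std::printf("y = [%g, %g]\n", y[0], y[1]);  // y = [8, 8]
  return 0;
}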
@@ -366,7 +366,7 @@ class GpuSparse {
                  Scalar* csrSortedValC, int* csrSortedRowPtrC,
                  int* csrSortedColIndC, void* workspace);
 
-#if CUDA_VERSION >= 10000
+#if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
   // Computes sparse-sparse matrix multiplication of matrices
   // stored in CSR format. This is part zero: calculate required workspace
   // size.
@@ -383,7 +383,7 @@ class GpuSparse {
   // output. csrSortedRowPtrC must be preallocated on device with
   // m + 1 entries. See:
   // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
-#if CUDA_VERSION < 10000
+#if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
   Status CsrgemmNnz(gpusparseOperation_t transA, gpusparseOperation_t transB,
                     int m, int k, int n, const gpusparseMatDescr_t descrA,
                     int nnzA, const int* csrSortedRowPtrA,
@@ -408,7 +408,7 @@ class GpuSparse {
   // addition. csrValC and csrColIndC must be allocated on the device
   // with nnzTotalDevHostPtr entries (as calculated by CsrgemmNnz). See:
   // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
-#if CUDA_VERSION < 10000
+#if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
   template <typename Scalar>
   Status Csrgemm(gpusparseOperation_t transA, gpusparseOperation_t transB,
                  int m, int k, int n, const gpusparseMatDescr_t descrA,
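The CsrgemmNnz/Csrgemm pair implements the usual two-pass sparse-sparse multiply protocol: a symbolic pass sizes the output (row pointers plus total nonzero count) so the caller can allocate csrValC and csrColIndC, then a numeric pass fills them. A compact CPU sketch of that calling convention (needs C++17; it mirrors the protocol only, not the cuSPARSE implementation):

#include <cstdio>
#include <map>
#include <vector>

struct Csr {
  int rows = 0;
  std::vector<int> row_ptr, col_ind;
  std::vector<float> values;
};

// Accumulate row i of A*B into a sorted map (column -> value).
static std::map<int, float> RowProduct(const Csr& a, const Csr& b, int i) {
  std::map<int, float> acc;
  for (int p = a.row_ptr[i]; p < a.row_ptr[i + 1]; ++p)
    for (int q = b.row_ptr[a.col_ind[p]]; q < b.row_ptr[a.col_ind[p] + 1]; ++q)
      acc[b.col_ind[q]] += a.values[p] * b.values[q];
  return acc;
}

// Pass 1 ("CsrgemmNnz"): fill row_ptr of C, report total nnz.
static int SymbolicPass(const Csr& a, const Csr& b, Csr* c) {
  c->rows = a.rows;
  c->row_ptr.assign(a.rows + 1, 0);
  for (int i = 0; i < a.rows; ++i)
    c->row_ptr[i + 1] = c->row_ptr[i] + (int)RowProduct(a, b, i).size();
  return c->row_ptr[a.rows];
}

// Pass 2 ("Csrgemm"): caller has allocated values/col_ind; fill them.
static void NumericPass(const Csr& a, const Csr& b, Csr* c) {
  for (int i = 0; i < a.rows; ++i) {
    int out = c->row_ptr[i];
    for (const auto& [col, val] : RowProduct(a, b, i)) {
      c->col_ind[out] = col;
      c->values[out++] = val;
    }
  }
}

int main() {
  // A = [[1 0 2],[0 3 0]], B = [[1 0 0],[0 1 0],[1 0 1]].
  Csr a{2, {0, 2, 3}, {0, 2, 1}, {1.f, 2.f, 3.f}};
  Csr b{3, {0, 1, 2, 4}, {0, 1, 0, 2}, {1.f, 1.f, 1.f, 1.f}};

  Csr c;
  int nnz = SymbolicPass(a, b, &c);  // size the output first...
  c.col_ind.resize(nnz);
  c.values.resize(nnz);
  NumericPass(a, b, &c);             // ...then fill it.

  std::printf("nnz(C) = %d\n", nnz);  // A*B = [[3 0 2],[0 3 0]] -> 3
  return 0;
}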
@@ -728,12 +728,14 @@ namespace {
 template <typename T>
 struct GPUDataType;
 
+// GPUDataType templates are currently not instantiated in the ROCm flow,
+// so leaving out the #elif TENSORFLOW_USE_ROCM blocks for now.
+// The hipblas library is not (yet) being pulled in via rocm_configure.bzl,
+// so we cannot reference types from hipblas headers here.
 template <>
 struct GPUDataType<Eigen::half> {
 #if GOOGLE_CUDA
   static constexpr cudaDataType_t type = CUDA_R_16F;
-#elif TENSORFLOW_USE_ROCM
-  static constexpr hipblasDataType_t type = HIPBLAS_R_16F;
 #endif
 };
 
@@ -741,8 +743,6 @@ template <>
 struct GPUDataType<float> {
 #if GOOGLE_CUDA
   static constexpr cudaDataType_t type = CUDA_R_32F;
-#elif TENSORFLOW_USE_ROCM
-  static constexpr hipblasDataType_t type = HIPBLAS_R_32F;
 #endif
 };
 
@@ -750,8 +750,6 @@ template <>
 struct GPUDataType<std::complex<float>> {
 #if GOOGLE_CUDA
   static constexpr cudaDataType_t type = CUDA_C_32F;
-#elif TENSORFLOW_USE_ROCM
-  static constexpr hipblasDataType_t type = HIPBLAS_C_32F;
 #endif
 };
 
@@ -759,8 +757,6 @@ template <>
 struct GPUDataType<double> {
 #if GOOGLE_CUDA
   static constexpr cudaDataType_t type = CUDA_R_64F;
-#elif TENSORFLOW_USE_ROCM
-  static constexpr hipblasDataType_t type = HIPBLAS_R_64F;
 #endif
 };
 
@@ -768,8 +764,6 @@ template <>
 struct GPUDataType<std::complex<double>> {
 #if GOOGLE_CUDA
   static constexpr cudaDataType_t type = CUDA_C_64F;
-#elif TENSORFLOW_USE_ROCM
-  static constexpr hipblasDataType_t type = HIPBLAS_C_64F;
 #endif
 };
 
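The GPUDataType specializations being trimmed above are a standard type-traits idiom: map a C++ scalar type to a library datatype enum at compile time. A stripped-down, library-free sketch of the same idiom (MyDataType is a stand-in, not the real cudaDataType_t or hipblasDataType_t):

#include <complex>
#include <cstdio>

enum class MyDataType { R32F, R64F, C32F, C64F };

template <typename T>
struct GPUDataType;  // primary template left undefined on purpose:
                     // unsupported types fail at compile time.

template <>
struct GPUDataType<float> {
  static constexpr MyDataType type = MyDataType::R32F;
};

template <>
struct GPUDataType<std::complex<float>> {
  static constexpr MyDataType type = MyDataType::C32F;
};

int main() {
  // Resolved entirely at compile time; useful for forwarding a template
  // parameter into a C-style API that takes an enum tag.
  static_assert(GPUDataType<float>::type == MyDataType::R32F, "");
  std::printf("float maps to tag %d\n",
              static_cast<int>(GPUDataType<float>::type));
  return 0;
}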
@@ -957,7 +951,7 @@ class CSRSparseMatrixMatVec<GPUDevice, T> {
     const int n = a.dense_shape_host(1);
     const int nnz = a.values.size();
     DCHECK_EQ(nnz, a.col_ind.size());
-#if CUDA_VERSION >= 10020
+#if GOOGLE_CUDA && (CUDA_VERSION >= 10020)
     TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha,
                                          a.values.data(), a.row_ptr.data(),
                                          a.col_ind.data(), x, &beta, y));
@@ -417,7 +417,7 @@ class CSRSparseMatMulGPUOp : public OpKernel {
     }
     auto b_input_dense_shape = b_input_matrix->dense_shape().vec<int64>();
 
-#if CUDA_VERSION >= 10000
+#if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
     size_t maxWorkspaceSize = 0;
     for (int i = 0; i < batch_size; ++i) {
       // Calculate maximum workspace size over batch.
@@ -558,7 +558,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
         initialized_(false),
         transpose_a_(transpose_a),
         adjoint_a_(adjoint_a),
-#if CUDA_VERSION < 10000
+#if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
         transpose_b_(transpose_b) {
 #else
         transpose_b_(transpose_b),
@@ -573,7 +573,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
                          : GPUSPARSE(OPERATION_NON_TRANSPOSE);
   }
 
-#if CUDA_VERSION >= 10000
+#if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
   ~CSRSparseSparseMatrixMatMul() {
     if (initialized_) {
       cusparseDestroyCsrgemm2Info(info_);
@@ -591,7 +591,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
     TF_RETURN_IF_ERROR(descrA_.Initialize());
     TF_RETURN_IF_ERROR(descrB_.Initialize());
     TF_RETURN_IF_ERROR(descrC_.Initialize());
-#if CUDA_VERSION >= 10000
+#if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
     TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsrgemm2Info(&info_));
 #endif
     initialized_ = true;
@@ -600,6 +600,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
 
   Status GetWorkspaceSize(const ConstCSRComponent<T>& a,
                           const ConstCSRComponent<T>& b, size_t* bufferSize) {
+#if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
     DCHECK(initialized_);
     const int m =
         a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 1 : 2));
@@ -621,6 +622,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
         m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(), a.col_ind.data(),
         descrB_.descr(), nnzB, b.row_ptr.data(), b.col_ind.data(), info_,
         bufferSize));
+#endif
 
     return Status::OK();
   }
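The net effect of the two GetWorkspaceSize hunks above: on ROCm (and pre-10.0 CUDA) the entire body is preprocessed away and the method degenerates to an immediate Status::OK(). A minimal stand-alone model of that pattern (illustrative names, not the TensorFlow ones):

#include <cstddef>
#include <cstdio>

#define GOOGLE_CUDA 0  // model the ROCm / pre-CUDA-10 build

struct Status {
  bool ok = true;
};

Status GetWorkspaceSize(std::size_t* buffer_size) {
#if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
  // The CUDA-only csrgemm2 buffer-size query would run here.
  *buffer_size = QueryCsrgemm2BufferSize();  // hypothetical helper
#endif
  return Status{};  // compiles (and reports success) on every platform
}

int main() {
  std::size_t buffer_size = 0;
  Status s = GetWorkspaceSize(&buffer_size);
  std::printf("ok=%d buffer_size=%zu\n", s.ok ? 1 : 0, buffer_size);
  return 0;
}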
@@ -650,7 +652,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
 
     *output_nnz = -1;
 
-#if CUDA_VERSION < 10000
+#if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
     TF_RETURN_IF_ERROR(cuda_sparse_.CsrgemmNnz(
         transA_, transB_, m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(),
         a.col_ind.data(), descrB_.descr(), nnzB, b.row_ptr.data(),
@@ -693,7 +695,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
         b.dense_shape_host(b.dense_shape_host.size() - (transpose_b_ ? 2 : 1));
     DCHECK_EQ(n, c->dense_shape_host(c->dense_shape_host.size() - 1));
 
-#if CUDA_VERSION < 10000
+#if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
     TF_RETURN_IF_ERROR(cuda_sparse_.Csrgemm(
         transA_, transB_, m, k, n, descrA_.descr(), nnzA, a.values.data(),
         a.row_ptr.data(), a.col_ind.data(), descrB_.descr(), nnzB,
@@ -732,7 +734,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
   GpuSparseMatrixDescriptor descrC_;
   gpusparseOperation_t transA_;
   gpusparseOperation_t transB_;
-#if CUDA_VERSION >= 10000
+#if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
   csrgemm2Info_t info_;
 #endif
 };