diff --git a/tensorflow/core/kernels/sparse/mat_mul_op.cc b/tensorflow/core/kernels/sparse/mat_mul_op.cc index bf9de570fbf..799e33000ad 100644 --- a/tensorflow/core/kernels/sparse/mat_mul_op.cc +++ b/tensorflow/core/kernels/sparse/mat_mul_op.cc @@ -886,11 +886,11 @@ class CSRSparseMatrixMatMul { const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE; gpusparseMatDescr_t descrA; - TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA)); + TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateMatDescr(&descrA)); TF_RETURN_IF_GPUSPARSE_ERROR( - hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL)); + wrap::hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL)); TF_RETURN_IF_GPUSPARSE_ERROR( - hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO)); + wrap::hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO)); #endif // GOOGLE_CUDA TF_RETURN_IF_ERROR( @@ -940,11 +940,11 @@ class CSRSparseMatrixMatVec { cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO)); #elif TENSORFLOW_USE_ROCM gpusparseMatDescr_t descrA; - TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA)); + TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateMatDescr(&descrA)); TF_RETURN_IF_GPUSPARSE_ERROR( - hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL)); + wrap::hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL)); TF_RETURN_IF_GPUSPARSE_ERROR( - hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO)); + wrap::hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO)); #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM const int m = a.dense_shape_host(0); diff --git a/tensorflow/core/util/BUILD b/tensorflow/core/util/BUILD index dcb2787e309..0dc8f84aadf 100644 --- a/tensorflow/core/util/BUILD +++ b/tensorflow/core/util/BUILD @@ -626,7 +626,7 @@ tf_kernel_library( "//tensorflow/stream_executor/cuda:cusparse_lib", "@cub_archive//:cub", ]) + if_rocm([ - "@local_config_rocm//rocm:hipsparse", + "//tensorflow/stream_executor/rocm:hipsparse_wrapper", ]), ) diff --git a/tensorflow/core/util/cuda_sparse.h b/tensorflow/core/util/cuda_sparse.h index 76580766d69..cd10ba8d8cb 100644 --- a/tensorflow/core/util/cuda_sparse.h +++ b/tensorflow/core/util/cuda_sparse.h @@ -46,7 +46,7 @@ using gpusparseSpMMAlg_t = cusparseSpMMAlg_t; #elif TENSORFLOW_USE_ROCM -#include "rocm/include/hipsparse/hipsparse.h" +#include "tensorflow/stream_executor/rocm/hipsparse_wrapper.h" using gpusparseStatus_t = hipsparseStatus_t; using gpusparseOperation_t = hipsparseOperation_t; @@ -485,7 +485,7 @@ class GpuSparseMatrixDescriptor { #if GOOGLE_CUDA TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descr_)); #elif TENSORFLOW_USE_ROCM - TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descr_)); + TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateMatDescr(&descr_)); #endif initialized_ = true; return Status::OK(); @@ -507,7 +507,7 @@ class GpuSparseMatrixDescriptor { #if GOOGLE_CUDA cusparseDestroyMatDescr(descr_); #elif TENSORFLOW_USE_ROCM - hipsparseDestroyMatDescr(descr_); + wrap::hipsparseDestroyMatDescr(descr_); #endif initialized_ = false; } diff --git a/tensorflow/core/util/rocm_sparse.cc b/tensorflow/core/util/rocm_sparse.cc index cc7b56fdc01..22c2af780c7 100644 --- a/tensorflow/core/util/rocm_sparse.cc +++ b/tensorflow/core/util/rocm_sparse.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/core/util/cuda_sparse.h" namespace tensorflow { + namespace { // A set of initialized handles to the underlying ROCm libraries used by @@ -67,9 +68,9 @@ class HipSparseHandles { Status Initialize() { if (initialized_) return Status::OK(); - TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreate(&hipsparse_handle_)); + TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreate(&hipsparse_handle_)); TF_RETURN_IF_GPUSPARSE_ERROR( - hipsparseSetStream(hipsparse_handle_, stream_)); + wrap::hipsparseSetStream(hipsparse_handle_, stream_)); initialized_ = true; return Status::OK(); } @@ -88,7 +89,7 @@ class HipSparseHandles { void Release() { if (initialized_) { // This should never return anything other than success - auto err = hipsparseDestroy(hipsparse_handle_); + auto err = wrap::hipsparseDestroy(hipsparse_handle_); DCHECK(err == HIPSPARSE_STATUS_SUCCESS) << "Failed to destroy hipSPARSE instance."; initialized_ = false; @@ -156,23 +157,23 @@ Status GpuSparse::Initialize() { #define TF_CALL_HIP_LAPACK_TYPES(m) m(float, S) m(double, D) // Macros to construct hipsparse method names. -#define SPARSE_FN(method, sparse_prefix) hipsparse##sparse_prefix##method +#define SPARSE_FN(method, sparse_prefix) wrap::hipsparse##sparse_prefix##method Status GpuSparse::Coo2csr(const int* cooRowInd, int nnz, int m, int* csrRowPtr) const { DCHECK(initialized_); - TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseXcoo2csr(*gpusparse_handle_, cooRowInd, - nnz, m, csrRowPtr, - HIPSPARSE_INDEX_BASE_ZERO)); + TF_RETURN_IF_GPUSPARSE_ERROR( + wrap::hipsparseXcoo2csr(*gpusparse_handle_, cooRowInd, nnz, m, csrRowPtr, + HIPSPARSE_INDEX_BASE_ZERO)); return Status::OK(); } Status GpuSparse::Csr2coo(const int* csrRowPtr, int nnz, int m, int* cooRowInd) const { DCHECK(initialized_); - TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseXcsr2coo(*gpusparse_handle_, csrRowPtr, - nnz, m, cooRowInd, - HIPSPARSE_INDEX_BASE_ZERO)); + TF_RETURN_IF_GPUSPARSE_ERROR( + wrap::hipsparseXcsr2coo(*gpusparse_handle_, csrRowPtr, nnz, m, cooRowInd, + HIPSPARSE_INDEX_BASE_ZERO)); return Status::OK(); } @@ -252,7 +253,7 @@ Status GpuSparse::CsrgemmNnz( int* csrSortedRowPtrC, int* nnzTotalDevHostPtr) { DCHECK(initialized_); DCHECK(nnzTotalDevHostPtr != nullptr); - TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseXcsrgemmNnz( + TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseXcsrgemmNnz( *gpusparse_handle_, transA, transB, m, n, k, descrA, nnzA, csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB, descrC, csrSortedRowPtrC, nnzTotalDevHostPtr)); diff --git a/tensorflow/stream_executor/platform/default/dso_loader.cc b/tensorflow/stream_executor/platform/default/dso_loader.cc index 6e0113ab05a..70b1ebe070a 100644 --- a/tensorflow/stream_executor/platform/default/dso_loader.cc +++ b/tensorflow/stream_executor/platform/default/dso_loader.cc @@ -136,6 +136,10 @@ port::StatusOr GetRocrandDsoHandle() { return GetDsoHandle("rocrand", ""); } +port::StatusOr GetHipsparseDsoHandle() { + return GetDsoHandle("hipsparse", ""); +} + port::StatusOr GetHipDsoHandle() { return GetDsoHandle("hip_hcc", ""); } } // namespace DsoLoader @@ -206,6 +210,11 @@ port::StatusOr GetRocrandDsoHandle() { return *result; } +port::StatusOr GetHipsparseDsoHandle() { + static auto result = new auto(DsoLoader::GetHipsparseDsoHandle()); + return *result; +} + port::StatusOr GetHipDsoHandle() { static auto result = new auto(DsoLoader::GetHipDsoHandle()); return *result; diff --git a/tensorflow/stream_executor/platform/default/dso_loader.h b/tensorflow/stream_executor/platform/default/dso_loader.h index 7eee2e60785..91138f713bd 100644 --- a/tensorflow/stream_executor/platform/default/dso_loader.h +++ b/tensorflow/stream_executor/platform/default/dso_loader.h @@ -50,6 +50,7 @@ port::StatusOr GetRocblasDsoHandle(); port::StatusOr GetMiopenDsoHandle(); port::StatusOr GetRocfftDsoHandle(); port::StatusOr GetRocrandDsoHandle(); +port::StatusOr GetHipsparseDsoHandle(); port::StatusOr GetHipDsoHandle(); // The following method tries to dlopen all necessary GPU libraries for the GPU @@ -82,6 +83,7 @@ port::StatusOr GetRocblasDsoHandle(); port::StatusOr GetMiopenDsoHandle(); port::StatusOr GetRocfftDsoHandle(); port::StatusOr GetRocrandDsoHandle(); +port::StatusOr GetHipsparseDsoHandle(); port::StatusOr GetHipDsoHandle(); } // namespace CachedDsoLoader diff --git a/tensorflow/stream_executor/rocm/BUILD b/tensorflow/stream_executor/rocm/BUILD index bd924125d77..bd4c45382f8 100644 --- a/tensorflow/stream_executor/rocm/BUILD +++ b/tensorflow/stream_executor/rocm/BUILD @@ -277,6 +277,23 @@ cc_library( alwayslink = True, ) +cc_library( + name = "hipsparse_wrapper", + srcs = if_rocm_is_configured(["hipsparse_wrapper.h"]), + hdrs = if_rocm_is_configured(["hipsparse_wrapper.h"]), + deps = if_rocm_is_configured([ + ":rocm_gpu_executor", + ":rocm_platform_id", + "@local_config_rocm//rocm:rocm_headers", + "//tensorflow/stream_executor/lib", + "//tensorflow/stream_executor/platform", + "//tensorflow/stream_executor/platform:dso_loader", + ] + if_static([ + "@local_config_rocm//rocm:hiprand", + ])), + alwayslink = True, +) + cc_library( name = "all_runtime", copts = tf_copts(), diff --git a/tensorflow/stream_executor/rocm/hipsparse_wrapper.h b/tensorflow/stream_executor/rocm/hipsparse_wrapper.h new file mode 100644 index 00000000000..6444f015cf8 --- /dev/null +++ b/tensorflow/stream_executor/rocm/hipsparse_wrapper.h @@ -0,0 +1,105 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file wraps hipsparse API calls with dso loader so that we don't need to +// have explicit linking to libhipsparse. All TF hipsarse API usage should route +// through this wrapper. + +#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_ +#define TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_ + +#include "rocm/include/hipsparse/hipsparse.h" +#include "tensorflow/stream_executor/lib/env.h" +#include "tensorflow/stream_executor/platform/dso_loader.h" +#include "tensorflow/stream_executor/platform/port.h" + +namespace tensorflow { +namespace wrap { + +#ifdef PLATFORM_GOOGLE + +#define HIPSPARSE_API_WRAPPER(__name) \ + struct WrapperShim__##__name { \ + template \ + hipsparseStatus_t operator()(Args... args) { \ + hipSparseStatus_t retval = ::__name(args...); \ + return retval; \ + } \ + } __name; + +#else + +#define HIPSPARSE_API_WRAPPER(__name) \ + struct DynLoadShim__##__name { \ + static const char* kName; \ + using FuncPtrT = std::add_pointer::type; \ + static void* GetDsoHandle() { \ + auto s = \ + stream_executor::internal::CachedDsoLoader::GetHipsparseDsoHandle(); \ + return s.ValueOrDie(); \ + } \ + static FuncPtrT LoadOrDie() { \ + void* f; \ + auto s = \ + Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), kName, &f); \ + CHECK(s.ok()) << "could not find " << kName \ + << " in miopen DSO; dlerror: " << s.error_message(); \ + return reinterpret_cast(f); \ + } \ + static FuncPtrT DynLoad() { \ + static FuncPtrT f = LoadOrDie(); \ + return f; \ + } \ + template \ + hipsparseStatus_t operator()(Args... args) { \ + return DynLoad()(args...); \ + } \ + } __name; \ + const char* DynLoadShim__##__name::kName = #__name; + +#endif + +// clang-format off +#define FOREACH_HIPSPARSE_API(__macro) \ + __macro(hipsparseCreate) \ + __macro(hipsparseCreateMatDescr) \ + __macro(hipsparseDcsr2csc) \ + __macro(hipsparseDcsrgemm) \ + __macro(hipsparseDcsrmm2) \ + __macro(hipsparseDcsrmv) \ + __macro(hipsparseDestroy) \ + __macro(hipsparseDestroyMatDescr) \ + __macro(hipsparseScsr2csc) \ + __macro(hipsparseScsrgemm) \ + __macro(hipsparseScsrmm2) \ + __macro(hipsparseScsrmv) \ + __macro(hipsparseSetStream) \ + __macro(hipsparseSetMatIndexBase) \ + __macro(hipsparseSetMatType) \ + __macro(hipsparseXcoo2csr) \ + __macro(hipsparseXcsr2coo) \ + __macro(hipsparseXcsrgemmNnz) + +// clang-format on + +FOREACH_HIPSPARSE_API(HIPSPARSE_API_WRAPPER) + +#undef FOREACH_HIPSPARSE_API +#undef HIPSPARSE_API_WRAPPER + +} // namespace wrap +} // namespace tensorflow + +#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_ diff --git a/third_party/gpus/rocm/BUILD.tpl b/third_party/gpus/rocm/BUILD.tpl index cf8950b5bc7..3c233b4f5b0 100644 --- a/third_party/gpus/rocm/BUILD.tpl +++ b/third_party/gpus/rocm/BUILD.tpl @@ -108,6 +108,7 @@ cc_library( ":rocfft", ":hiprand", ":miopen", + ":hipsparse", ], ) @@ -137,11 +138,9 @@ cc_library( ], ) -cc_import( +cc_library( name = "hipsparse", - hdrs = glob(["rocm/include/hipsparse/**",]), - shared_library = "rocm/lib/%{hipsparse_lib}", - visibility = ["//visibility:public"], + data = ["rocm/lib/%{hipsparse_lib}"], ) %{copy_rules}