From aaea82e6bcd0948a9b2bf10396684c2f01fb60ea Mon Sep 17 00:00:00 2001
From: Ben Barsdell
Date: Mon, 6 Jul 2020 21:09:32 +1000
Subject: [PATCH] Add cublasLt wrappers to stream_executor

- Adds ThenBlasLtMatmul routines that behave similarly to
  ThenBlasGemmWithAlgorithm but call into the cublasLt library and allow
  separation of plan creation and execution.
- A list of heuristically-prioritized opaque algorithm objects can be
  obtained via GetBlasLtMatmulAlgorithms.
- These routines are only supported when the CUDA version is >= 11.0.
---
 tensorflow/stream_executor/blas.cc            |  25 +
 tensorflow/stream_executor/blas.h             | 218 ++++
 tensorflow/stream_executor/cuda/BUILD         |  24 +
 .../stream_executor/cuda/cublasLt_11_0.inc    | 415 +++++++++
 .../stream_executor/cuda/cublasLt_stub.cc     |  59 ++
 tensorflow/stream_executor/cuda/cuda_blas.cc  | 680 +++++++++++++++++-
 tensorflow/stream_executor/cuda/cuda_blas.h   |  30 +
 .../platform/default/dlopen_checker.cc        |   3 +-
 .../platform/default/dso_loader.cc            |   9 +
 .../platform/default/dso_loader.h             |   2 +
 tensorflow/stream_executor/stream.cc          | 137 ++++
 tensorflow/stream_executor/stream.h           |  50 ++
 .../stream_executor/stream_executor_pimpl.cc  |  43 ++
 .../stream_executor/stream_executor_pimpl.h   |  29 +
 third_party/gpus/cuda/BUILD.tpl               |   8 +
 third_party/gpus/cuda_configure.bzl           |  12 +
 16 files changed, 1741 insertions(+), 3 deletions(-)
 create mode 100644 tensorflow/stream_executor/cuda/cublasLt_11_0.inc
 create mode 100644 tensorflow/stream_executor/cuda/cublasLt_stub.cc

diff --git a/tensorflow/stream_executor/blas.cc b/tensorflow/stream_executor/blas.cc
index f499b3003d0..f55e318e88b 100644
--- a/tensorflow/stream_executor/blas.cc
+++ b/tensorflow/stream_executor/blas.cc
@@ -95,5 +95,30 @@ std::ostream& operator<<(std::ostream& os, ComputationType ty) {
   return os << ComputationTypeString(ty);
 }
 
+string DataTypeString(DataType ty) {
+  switch (ty) {
+    case DataType::kF16:
+      return "f16";
+    case DataType::kF32:
+      return "f32";
+    case DataType::kF64:
+      return "f64";
+    case DataType::kI8:
+      return "i8";
+    case DataType::kI32:
+      return "i32";
+    case DataType::kComplexF32:
+      return "complex f32";
+    case DataType::kComplexF64:
+      return "complex f64";
+    default:
+      LOG(FATAL) << "Unknown DataType " << static_cast<int>(ty);
+  }
+}
+
+std::ostream& operator<<(std::ostream& os, DataType ty) {
+  return os << DataTypeString(ty);
+}
+
 }  // namespace blas
 }  // namespace stream_executor
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index 5018d487ed1..583fba2a505 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -101,6 +101,10 @@ enum class ComputationType {
   kI32,         // 32-bit integer
   kComplexF32,  // Complex number comprised of two f32s.
   kComplexF64,  // Complex number comprised of two f64s.
+  // The below values are only supported for BlasLt routines (both real and
+  // complex).
+  kF32FastTF32,  // 32-bit floating-point with reduced (>=10-bit) mantissa
+  kF32FastBF16,  // 32-bit floating-point with reduced (7-bit) mantissa
 };
 
 // Converts a ComputationType to a string.
@@ -108,6 +112,61 @@ std::string ComputationTypeString(ComputationType ty);
 
 std::ostream &operator<<(std::ostream &os, ComputationType ty);
 
+// Type with which inputs and outputs of a blaslt routine are performed.
+enum class DataType {
+  kF16,         // 16-bit floating-point
+  kF32,         // 32-bit floating-point
+  kF64,         // 64-bit floating-point
+  kI8,          // 8-bit integer
+  kI32,         // 32-bit integer
+  kComplexF32,  // Complex number comprised of two f32s
+  kComplexF64,  // Complex number comprised of two f64s
+};
+
+// Describes the type of pointers for the scaling factors alpha and beta in
+// blaslt routines.
+enum class PointerMode {
+  kHost,
+  kDevice,
+};
+
+// Converts a DataType to a string.
+string DataTypeString(DataType ty);
+
+std::ostream &operator<<(std::ostream &os, DataType ty);
+
+// Converts a compile-time type to a DataType value.
+template <typename T>
+struct ToDataType {};
+template <>
+struct ToDataType<Eigen::half> {
+  static constexpr const DataType value = DataType::kF16;
+};
+template <>
+struct ToDataType<float> {
+  static constexpr const DataType value = DataType::kF32;
+};
+template <>
+struct ToDataType<double> {
+  static constexpr const DataType value = DataType::kF64;
+};
+template <>
+struct ToDataType<int8> {
+  static constexpr const DataType value = DataType::kI8;
+};
+template <>
+struct ToDataType<int32> {
+  static constexpr const DataType value = DataType::kI32;
+};
+template <>
+struct ToDataType<std::complex<float>> {
+  static constexpr const DataType value = DataType::kComplexF32;
+};
+template <>
+struct ToDataType<std::complex<double>> {
+  static constexpr const DataType value = DataType::kComplexF64;
+};
+
 // Opaque identifier for an "algorithm" used by a blas routine.  This functions
 // as a hint to the blas library.
 typedef int64 AlgorithmType;
@@ -163,6 +222,19 @@ class AlgorithmConfig {
   AlgorithmType algorithm_;
 };
 
+struct IBlasLtMatmulPlan {
+  virtual ~IBlasLtMatmulPlan() {}
+};
+
+struct IBlasLtMatmulAlgorithm {
+  virtual ~IBlasLtMatmulAlgorithm() {}
+  // Returns the index of the algorithm within the list returned by
+  // GetBlasLtMatmulAlgorithms.
+  virtual AlgorithmType index() const = 0;
+  // Returns the workspace size required by the algorithm in bytes.
+  virtual size_t workspace_size() const = 0;
+};
+
 // BLAS support interface -- this can be derived from a GPU executor when the
 // underlying platform has an BLAS library implementation available. See
 // StreamExecutor::AsBlas().
@@ -1383,6 +1455,93 @@ class BlasSupport {
                           const DeviceMemory<std::complex<double>> &a, int lda,
                           DeviceMemory<std::complex<double>> *b, int ldb) = 0;
 
+  // Creates a backend-specific plan object for a blaslt matmul operation, which
+  // can then be passed to DoBlasLtMatmul(). When possible, plans should be
+  // created once and reused for multiple calls to DoBlasLtMatmul().
+  // Returns a null pointer on failure.
+  std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
+      blas::DataType ab_type, blas::DataType c_type,
+      blas::ComputationType computation_type, blas::PointerMode pointer_mode,
+      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+      uint64 k, int64 lda, int64 ldb, int64 ldc) {
+    return CreateBlasLtMatmulPlanStridedBatched(
+        ab_type, c_type, computation_type, pointer_mode, transa, transb, m, n,
+        k, 1, lda, 0, ldb, 0, ldc, 0);
+  }
+
+  // A more general version of CreateBlasLtMatmulPlan supporting
+  // batched operations.
+  virtual std::unique_ptr<blas::IBlasLtMatmulPlan>
+  CreateBlasLtMatmulPlanStridedBatched(
+      blas::DataType ab_type, blas::DataType c_type,
+      blas::ComputationType computation_type, blas::PointerMode pointer_mode,
+      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+      uint64 k, int batch_count, int64 lda, int64 stride_a, int64 ldb,
+      int64 stride_b, int64 ldc, int64 stride_c) = 0;
+
+  // Gets a list of supported algorithms for DoBlasLtMatmul.
The algorithms are + // returned in the order of increasing estimated compute time according to an + // internal heuristic. The first returned algorithm can be used as the default + // algorithm if no autotuning is to be performed. + virtual bool GetBlasLtMatmulAlgorithms( + const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size, + int max_algorithm_count, + std::vector>* + out_algorithms) = 0; + + // Executes a blaslt matmul operation on the stream. If output_profile_result + // is not nullptr, the operation is profiled, error messages are + // suppressed, and output_profile_result->algorithm() is set to + // algorithm->index(). + virtual bool DoBlasLtMatmul( + Stream* stream, const blas::IBlasLtMatmulPlan* plan, + const HostOrDeviceScalar& alpha, const DeviceMemory& a, + const DeviceMemory& b, const HostOrDeviceScalar& beta, + DeviceMemory* c, ScratchAllocator* scratch_allocator, + const blas::IBlasLtMatmulAlgorithm* algorithm, + blas::ProfileResult* output_profile_result = nullptr) = 0; + virtual bool DoBlasLtMatmul( + Stream* stream, const blas::IBlasLtMatmulPlan* plan, + const HostOrDeviceScalar& alpha, + const DeviceMemory& a, const DeviceMemory& b, + const HostOrDeviceScalar& beta, DeviceMemory* c, + ScratchAllocator* scratch_allocator, + const blas::IBlasLtMatmulAlgorithm* algorithm, + blas::ProfileResult* output_profile_result = nullptr) = 0; + virtual bool DoBlasLtMatmul( + Stream* stream, const blas::IBlasLtMatmulPlan* plan, + const HostOrDeviceScalar& alpha, const DeviceMemory& a, + const DeviceMemory& b, const HostOrDeviceScalar& beta, + DeviceMemory* c, ScratchAllocator* scratch_allocator, + const blas::IBlasLtMatmulAlgorithm* algorithm, + blas::ProfileResult* output_profile_result = nullptr) = 0; + virtual bool DoBlasLtMatmul( + Stream* stream, const blas::IBlasLtMatmulPlan* plan, + const HostOrDeviceScalar& alpha, const DeviceMemory& a, + const DeviceMemory& b, const HostOrDeviceScalar& beta, + DeviceMemory* c, ScratchAllocator* scratch_allocator, + const blas::IBlasLtMatmulAlgorithm* algorithm, + blas::ProfileResult* output_profile_result = nullptr) = 0; + virtual bool DoBlasLtMatmul( + Stream* stream, const blas::IBlasLtMatmulPlan* plan, + const HostOrDeviceScalar>& alpha, + const DeviceMemory>& a, + const DeviceMemory>& b, + const HostOrDeviceScalar>& beta, + DeviceMemory>* c, ScratchAllocator* scratch_allocator, + const blas::IBlasLtMatmulAlgorithm* algorithm, + blas::ProfileResult* output_profile_result = nullptr) = 0; + virtual bool DoBlasLtMatmul( + Stream* stream, const blas::IBlasLtMatmulPlan* plan, + const HostOrDeviceScalar>& alpha, + const DeviceMemory>& a, + const DeviceMemory>& b, + const HostOrDeviceScalar>& beta, + DeviceMemory>* c, + ScratchAllocator* scratch_allocator, + const blas::IBlasLtMatmulAlgorithm* algorithm, + blas::ProfileResult* output_profile_result = nullptr) = 0; + virtual port::Status GetVersion(std::string *version) = 0; protected: @@ -2196,6 +2355,65 @@ class BlasSupport { uint64 n, std::complex alpha, \ const DeviceMemory> &a, int lda, \ DeviceMemory> *b, int ldb) override; \ + std::unique_ptr \ + CreateBlasLtMatmulPlanStridedBatched( \ + blas::DataType ab_type, blas::DataType cd_type, \ + blas::ComputationType computation_type, blas::PointerMode pointer_mode, \ + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, \ + uint64 k, int batch_count, int64 lda, int64 stride_a, int64 ldb, \ + int64 stride_b, int64 ldc, int64 stride_c) override; \ + bool GetBlasLtMatmulAlgorithms( \ + const blas::IBlasLtMatmulPlan* plan, 
size_t max_workspace_size, \ + int max_algorithm_count, \ + std::vector>* \ + out_algorithms) override; \ + bool DoBlasLtMatmul( \ + Stream* stream, const blas::IBlasLtMatmulPlan* plan, \ + const HostOrDeviceScalar& alpha, const DeviceMemory& a, \ + const DeviceMemory& b, const HostOrDeviceScalar& beta, \ + DeviceMemory* c, ScratchAllocator* scratch_allocator, \ + const blas::IBlasLtMatmulAlgorithm* algorithm, \ + blas::ProfileResult* output_profile_result = nullptr) override; \ + bool DoBlasLtMatmul( \ + Stream* stream, const blas::IBlasLtMatmulPlan* plan, \ + const HostOrDeviceScalar& alpha, \ + const DeviceMemory& a, const DeviceMemory& b, \ + const HostOrDeviceScalar& beta, \ + DeviceMemory* c, ScratchAllocator* scratch_allocator, \ + const blas::IBlasLtMatmulAlgorithm* algorithm, \ + blas::ProfileResult* output_profile_result) override; \ + bool DoBlasLtMatmul( \ + Stream* stream, const blas::IBlasLtMatmulPlan* plan, \ + const HostOrDeviceScalar& alpha, const DeviceMemory& a, \ + const DeviceMemory& b, const HostOrDeviceScalar& beta, \ + DeviceMemory* c, ScratchAllocator* scratch_allocator, \ + const blas::IBlasLtMatmulAlgorithm* algorithm, \ + blas::ProfileResult* output_profile_result) override; \ + bool DoBlasLtMatmul( \ + Stream* stream, const blas::IBlasLtMatmulPlan* plan, \ + const HostOrDeviceScalar& alpha, const DeviceMemory& a, \ + const DeviceMemory& b, const HostOrDeviceScalar& beta, \ + DeviceMemory* c, ScratchAllocator* scratch_allocator, \ + const blas::IBlasLtMatmulAlgorithm* algorithm, \ + blas::ProfileResult* output_profile_result) override; \ + bool DoBlasLtMatmul(Stream* stream, const blas::IBlasLtMatmulPlan* plan, \ + const HostOrDeviceScalar>& alpha, \ + const DeviceMemory>& a, \ + const DeviceMemory>& b, \ + const HostOrDeviceScalar>& beta, \ + DeviceMemory>* c, \ + ScratchAllocator* scratch_allocator, \ + const blas::IBlasLtMatmulAlgorithm* algorithm, \ + blas::ProfileResult* output_profile_result) override; \ + bool DoBlasLtMatmul(Stream* stream, const blas::IBlasLtMatmulPlan* plan, \ + const HostOrDeviceScalar>& alpha, \ + const DeviceMemory>& a, \ + const DeviceMemory>& b, \ + const HostOrDeviceScalar>& beta, \ + DeviceMemory>* c, \ + ScratchAllocator* scratch_allocator, \ + const blas::IBlasLtMatmulAlgorithm* algorithm, \ + blas::ProfileResult* output_profile_result) override; \ port::Status GetVersion(std::string *version) override; } // namespace blas diff --git a/tensorflow/stream_executor/cuda/BUILD b/tensorflow/stream_executor/cuda/BUILD index dccdab8877e..87cb64490a6 100644 --- a/tensorflow/stream_executor/cuda/BUILD +++ b/tensorflow/stream_executor/cuda/BUILD @@ -242,6 +242,29 @@ alias( visibility = ["//visibility:public"], ) +cc_library( + name = "cublasLt_stub", + srcs = if_cuda_is_configured(["cublasLt_stub.cc"]), + textual_hdrs = glob(["cublasLt_*.inc"]), + deps = if_cuda_is_configured([ + # LINT.IfChange + "@local_config_cuda//cuda:cublas_headers", + # LINT.ThenChange(//tensorflow/copy.bara.sky:cublasLt_headers) + "@local_config_cuda//cuda:cuda_headers", + "//tensorflow/stream_executor/lib", + "//tensorflow/stream_executor/platform:dso_loader", + ]), +) + +alias( + name = "cublasLt_lib", + actual = select({ + "//tensorflow:oss": ":cublasLt_stub", + "//conditions:default": "@local_config_cuda//cuda:cublasLt", + }), + visibility = ["//visibility:public"], +) + cc_library( name = "cublas_plugin", srcs = if_cuda_is_configured(["cuda_blas.cc"]), @@ -249,6 +272,7 @@ cc_library( visibility = ["//visibility:public"], deps = if_cuda_is_configured([ 
":cublas_lib", + ":cublasLt_lib", ":cuda_activation", ":cuda_gpu_executor", ":cuda_platform_id", diff --git a/tensorflow/stream_executor/cuda/cublasLt_11_0.inc b/tensorflow/stream_executor/cuda/cublasLt_11_0.inc new file mode 100644 index 00000000000..819dfced4ff --- /dev/null +++ b/tensorflow/stream_executor/cuda/cublasLt_11_0.inc @@ -0,0 +1,415 @@ +// Auto-generated, do not edit. + +extern "C" { + +cublasStatus_t CUBLASWINAPI +cublasLtCreate(cublasLtHandle_t *lightHandle) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtHandle_t *); + static auto func_ptr = LoadSymbol("cublasLtCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(lightHandle); +} + +cublasStatus_t CUBLASWINAPI +cublasLtDestroy(cublasLtHandle_t lightHandle) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtHandle_t); + static auto func_ptr = LoadSymbol("cublasLtDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(lightHandle); +} + +size_t CUBLASWINAPI +cublasLtGetVersion(void) { + using FuncPtr = size_t (CUBLASWINAPI *)(); + static auto func_ptr = LoadSymbol("cublasLtGetVersion"); + if (!func_ptr) return 0; + return func_ptr(); +} + +size_t CUBLASWINAPI +cublasLtGetCudartVersion(void) { + using FuncPtr = size_t (CUBLASWINAPI *)(); + static auto func_ptr = LoadSymbol("cublasLtGetCudartVersion"); + if (!func_ptr) return 0; + return func_ptr(); +} + +cublasStatus_t CUBLASWINAPI +cublasLtGetProperty(libraryPropertyType type, int *value) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(libraryPropertyType, int *); + static auto func_ptr = LoadSymbol("cublasLtGetProperty"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(type, value); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatmul(cublasLtHandle_t lightHandle, + cublasLtMatmulDesc_t computeDesc, + const void *alpha, /* host or device pointer */ + const void *A, + cublasLtMatrixLayout_t Adesc, + const void *B, + cublasLtMatrixLayout_t Bdesc, + const void *beta, /* host or device pointer */ + const void *C, + cublasLtMatrixLayout_t Cdesc, + void *D, + cublasLtMatrixLayout_t Ddesc, + const cublasLtMatmulAlgo_t *algo, + void *workspace, + size_t workspaceSizeInBytes, + cudaStream_t stream) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtHandle_t, cublasLtMatmulDesc_t, const void *, const void *, cublasLtMatrixLayout_t, const void *, cublasLtMatrixLayout_t, const void *, const void *, cublasLtMatrixLayout_t, void *, cublasLtMatrixLayout_t, const cublasLtMatmulAlgo_t *, void *, size_t, cudaStream_t); + static auto func_ptr = LoadSymbol("cublasLtMatmul"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(lightHandle, computeDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, D, Ddesc, algo, workspace, workspaceSizeInBytes, stream); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatrixTransform(cublasLtHandle_t lightHandle, + cublasLtMatrixTransformDesc_t transformDesc, + const void *alpha, /* host or device pointer */ + const void *A, + cublasLtMatrixLayout_t Adesc, + const void *beta, /* host or device pointer */ + const void *B, + cublasLtMatrixLayout_t Bdesc, + void *C, + cublasLtMatrixLayout_t Cdesc, + cudaStream_t stream) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtHandle_t, cublasLtMatrixTransformDesc_t, const void *, const void *, cublasLtMatrixLayout_t, const void *, const void *, cublasLtMatrixLayout_t, void *, cublasLtMatrixLayout_t, cudaStream_t); + static auto func_ptr = LoadSymbol("cublasLtMatrixTransform"); + if (!func_ptr) return 
GetSymbolNotFoundError(); + return func_ptr(lightHandle, transformDesc, alpha, A, Adesc, beta, B, Bdesc, C, Cdesc, stream); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatrixLayoutInit_internal( // + cublasLtMatrixLayout_t matLayout, + size_t size, + cudaDataType type, + uint64_t rows, + uint64_t cols, + int64_t ld) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(// + cublasLtMatrixLayout_t, size_t, cudaDataType, uint64_t, uint64_t, int64_t); + static auto func_ptr = LoadSymbol("cublasLtMatrixLayoutInit_internal"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(matLayout, size, type, rows, cols, ld); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatrixLayoutCreate( // + cublasLtMatrixLayout_t *matLayout, + cudaDataType type, + uint64_t rows, + uint64_t cols, + int64_t ld) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(// + cublasLtMatrixLayout_t *, cudaDataType, uint64_t, uint64_t, int64_t); + static auto func_ptr = LoadSymbol("cublasLtMatrixLayoutCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(matLayout, type, rows, cols, ld); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatrixLayoutDestroy(cublasLtMatrixLayout_t matLayout) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatrixLayout_t); + static auto func_ptr = LoadSymbol("cublasLtMatrixLayoutDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(matLayout); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatrixLayoutSetAttribute( // + cublasLtMatrixLayout_t matLayout, + cublasLtMatrixLayoutAttribute_t attr, + const void *buf, + size_t sizeInBytes) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(// + cublasLtMatrixLayout_t, cublasLtMatrixLayoutAttribute_t, const void *, size_t); + static auto func_ptr = LoadSymbol("cublasLtMatrixLayoutSetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(matLayout, attr, buf, sizeInBytes); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatrixLayoutGetAttribute( // + cublasLtMatrixLayout_t matLayout, + cublasLtMatrixLayoutAttribute_t attr, + void *buf, + size_t sizeInBytes, + size_t *sizeWritten) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(// + cublasLtMatrixLayout_t, cublasLtMatrixLayoutAttribute_t, void *, size_t, size_t *); + static auto func_ptr = LoadSymbol("cublasLtMatrixLayoutGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(matLayout, attr, buf, sizeInBytes, sizeWritten); +} + +cublasStatus_t CUBLASWINAPI cublasLtMatmulDescInit_internal( // + cublasLtMatmulDesc_t matmulDesc, + size_t size, + cublasComputeType_t computeType, + cudaDataType_t scaleType) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(// + cublasLtMatmulDesc_t, size_t, cublasComputeType_t, cudaDataType_t); + static auto func_ptr = LoadSymbol("cublasLtMatmulDescInit_internal"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(matmulDesc, size, computeType, scaleType); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatmulDescCreate(cublasLtMatmulDesc_t *matmulDesc, cublasComputeType_t computeType, cudaDataType_t scaleType) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatmulDesc_t *, cublasComputeType_t, cudaDataType_t); + static auto func_ptr = LoadSymbol("cublasLtMatmulDescCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(matmulDesc, computeType, scaleType); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatmulDescDestroy(cublasLtMatmulDesc_t matmulDesc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI 
*)(cublasLtMatmulDesc_t); + static auto func_ptr = LoadSymbol("cublasLtMatmulDescDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(matmulDesc); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatmulDescSetAttribute( // + cublasLtMatmulDesc_t matmulDesc, + cublasLtMatmulDescAttributes_t attr, + const void *buf, + size_t sizeInBytes) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(// + cublasLtMatmulDesc_t, cublasLtMatmulDescAttributes_t, const void *, size_t); + static auto func_ptr = LoadSymbol("cublasLtMatmulDescSetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(matmulDesc, attr, buf, sizeInBytes); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatmulDescGetAttribute( // + cublasLtMatmulDesc_t matmulDesc, + cublasLtMatmulDescAttributes_t attr, + void *buf, + size_t sizeInBytes, + size_t *sizeWritten) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(// + cublasLtMatmulDesc_t, cublasLtMatmulDescAttributes_t, void *, size_t, size_t *); + static auto func_ptr = LoadSymbol("cublasLtMatmulDescGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(matmulDesc, attr, buf, sizeInBytes, sizeWritten); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatrixTransformDescInit_internal(cublasLtMatrixTransformDesc_t transformDesc, size_t size, cudaDataType scaleType) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatrixTransformDesc_t, size_t, cudaDataType); + static auto func_ptr = LoadSymbol("cublasLtMatrixTransformDescInit_internal"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(transformDesc, size, scaleType); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatrixTransformDescCreate(cublasLtMatrixTransformDesc_t *transformDesc, cudaDataType scaleType) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatrixTransformDesc_t *, cudaDataType); + static auto func_ptr = LoadSymbol("cublasLtMatrixTransformDescCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(transformDesc, scaleType); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatrixTransformDescDestroy(cublasLtMatrixTransformDesc_t transformDesc) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatrixTransformDesc_t); + static auto func_ptr = LoadSymbol("cublasLtMatrixTransformDescDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(transformDesc); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatrixTransformDescSetAttribute( // + cublasLtMatrixTransformDesc_t transformDesc, + cublasLtMatrixTransformDescAttributes_t attr, + const void *buf, + size_t sizeInBytes) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(// + cublasLtMatrixTransformDesc_t, cublasLtMatrixTransformDescAttributes_t, const void *, size_t); + static auto func_ptr = LoadSymbol("cublasLtMatrixTransformDescSetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(transformDesc, attr, buf, sizeInBytes); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatrixTransformDescGetAttribute( // + cublasLtMatrixTransformDesc_t transformDesc, + cublasLtMatrixTransformDescAttributes_t attr, + void *buf, + size_t sizeInBytes, + size_t *sizeWritten) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(// + cublasLtMatrixTransformDesc_t, cublasLtMatrixTransformDescAttributes_t, void *, size_t, size_t *); + static auto func_ptr = LoadSymbol("cublasLtMatrixTransformDescGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(transformDesc, attr, buf, sizeInBytes, sizeWritten); 
+} + +cublasStatus_t CUBLASWINAPI +cublasLtMatmulPreferenceInit_internal(cublasLtMatmulPreference_t pref, size_t size) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatmulPreference_t, size_t); + static auto func_ptr = LoadSymbol("cublasLtMatmulPreferenceInit_internal"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pref, size); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatmulPreferenceCreate(cublasLtMatmulPreference_t *pref) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatmulPreference_t *); + static auto func_ptr = LoadSymbol("cublasLtMatmulPreferenceCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pref); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatmulPreferenceDestroy(cublasLtMatmulPreference_t pref) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatmulPreference_t); + static auto func_ptr = LoadSymbol("cublasLtMatmulPreferenceDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pref); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatmulPreferenceSetAttribute( // + cublasLtMatmulPreference_t pref, + cublasLtMatmulPreferenceAttributes_t attr, + const void *buf, + size_t sizeInBytes) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(// + cublasLtMatmulPreference_t, cublasLtMatmulPreferenceAttributes_t, const void *, size_t); + static auto func_ptr = LoadSymbol("cublasLtMatmulPreferenceSetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pref, attr, buf, sizeInBytes); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatmulPreferenceGetAttribute( // + cublasLtMatmulPreference_t pref, + cublasLtMatmulPreferenceAttributes_t attr, + void *buf, + size_t sizeInBytes, + size_t *sizeWritten) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(// + cublasLtMatmulPreference_t, cublasLtMatmulPreferenceAttributes_t, void *, size_t, size_t *); + static auto func_ptr = LoadSymbol("cublasLtMatmulPreferenceGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pref, attr, buf, sizeInBytes, sizeWritten); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatmulAlgoGetHeuristic( + cublasLtHandle_t lightHandle, + cublasLtMatmulDesc_t operationDesc, + cublasLtMatrixLayout_t Adesc, + cublasLtMatrixLayout_t Bdesc, + cublasLtMatrixLayout_t Cdesc, + cublasLtMatrixLayout_t Ddesc, + cublasLtMatmulPreference_t preference, + int requestedAlgoCount, + cublasLtMatmulHeuristicResult_t heuristicResultsArray[], + int *returnAlgoCount) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtHandle_t, cublasLtMatmulDesc_t, cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, cublasLtMatmulPreference_t, int, cublasLtMatmulHeuristicResult_t [], int *); + static auto func_ptr = LoadSymbol("cublasLtMatmulAlgoGetHeuristic"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(lightHandle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, requestedAlgoCount, heuristicResultsArray, returnAlgoCount); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatmulAlgoGetIds( + cublasLtHandle_t lightHandle, + cublasComputeType_t computeType, + cudaDataType_t scaleType, + cudaDataType_t Atype, + cudaDataType_t Btype, + cudaDataType_t Ctype, + cudaDataType_t Dtype, + int requestedAlgoCount, + int algoIdsArray[], + int *returnAlgoCount) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtHandle_t, cublasComputeType_t, cudaDataType_t, cudaDataType_t, cudaDataType_t, cudaDataType_t, cudaDataType_t, int, int [], 
int *); + static auto func_ptr = LoadSymbol("cublasLtMatmulAlgoGetIds"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(lightHandle, computeType, scaleType, Atype, Btype, Ctype, Dtype, requestedAlgoCount, algoIdsArray, returnAlgoCount); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatmulAlgoInit ( cublasLtHandle_t lightHandle, + cublasComputeType_t computeType, + cudaDataType_t scaleType, + cudaDataType_t Atype, + cudaDataType_t Btype, + cudaDataType_t Ctype, + cudaDataType_t Dtype, + int algoId, + cublasLtMatmulAlgo_t *algo) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtHandle_t, cublasComputeType_t, cudaDataType_t, cudaDataType_t, cudaDataType_t, cudaDataType_t, cudaDataType_t, int, cublasLtMatmulAlgo_t *); + static auto func_ptr = LoadSymbol("cublasLtMatmulAlgoInit"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(lightHandle, computeType, scaleType, Atype, Btype, Ctype, Dtype, algoId, algo); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatmulAlgoCheck( // + cublasLtHandle_t lightHandle, + cublasLtMatmulDesc_t operationDesc, + cublasLtMatrixLayout_t Adesc, + cublasLtMatrixLayout_t Bdesc, + cublasLtMatrixLayout_t Cdesc, + cublasLtMatrixLayout_t Ddesc, + const cublasLtMatmulAlgo_t *algo, ///< may point to result->algo + cublasLtMatmulHeuristicResult_t *result) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(// + cublasLtHandle_t, cublasLtMatmulDesc_t, cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, const cublasLtMatmulAlgo_t *, ///< may point to result->algo + cublasLtMatmulHeuristicResult_t *); + static auto func_ptr = LoadSymbol("cublasLtMatmulAlgoCheck"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(lightHandle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, algo, result); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatmulAlgoCapGetAttribute( + const cublasLtMatmulAlgo_t *algo, + cublasLtMatmulAlgoCapAttributes_t attr, + void *buf, + size_t sizeInBytes, + size_t *sizeWritten) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(const cublasLtMatmulAlgo_t *, cublasLtMatmulAlgoCapAttributes_t, void *, size_t, size_t *); + static auto func_ptr = LoadSymbol("cublasLtMatmulAlgoCapGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(algo, attr, buf, sizeInBytes, sizeWritten); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatmulAlgoConfigSetAttribute( + cublasLtMatmulAlgo_t *algo, + cublasLtMatmulAlgoConfigAttributes_t attr, + const void *buf, + size_t sizeInBytes) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatmulAlgo_t *, cublasLtMatmulAlgoConfigAttributes_t, const void *, size_t); + static auto func_ptr = LoadSymbol("cublasLtMatmulAlgoConfigSetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(algo, attr, buf, sizeInBytes); +} + +cublasStatus_t CUBLASWINAPI +cublasLtMatmulAlgoConfigGetAttribute( + const cublasLtMatmulAlgo_t *algo, + cublasLtMatmulAlgoConfigAttributes_t attr, + void *buf, + size_t sizeInBytes, + size_t *sizeWritten) { + using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(const cublasLtMatmulAlgo_t *, cublasLtMatmulAlgoConfigAttributes_t, void *, size_t, size_t *); + static auto func_ptr = LoadSymbol("cublasLtMatmulAlgoConfigGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(algo, attr, buf, sizeInBytes, sizeWritten); +} + +} // extern "C" diff --git a/tensorflow/stream_executor/cuda/cublasLt_stub.cc b/tensorflow/stream_executor/cuda/cublasLt_stub.cc new 
file mode 100644
index 00000000000..aae8a94285b
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cublasLt_stub.cc
@@ -0,0 +1,59 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "third_party/gpus/cuda/include/cublasLt.h"
+#include "third_party/gpus/cuda/include/cuda.h"
+#include "tensorflow/stream_executor/lib/env.h"
+#include "tensorflow/stream_executor/platform/dso_loader.h"
+
+// Implements the cuBLASLt API by forwarding to cuBLASLt loaded from the DSO.
+
+namespace {
+// Returns DSO handle or null if loading the DSO fails.
+void* GetDsoHandle() {
+#ifdef PLATFORM_GOOGLE
+  return nullptr;
+#else
+  static auto handle = []() -> void* {
+    auto handle_or =
+        stream_executor::internal::DsoLoader::GetCublasLtDsoHandle();
+    if (!handle_or.ok()) return nullptr;
+    return handle_or.ValueOrDie();
+  }();
+  return handle;
+#endif
+}
+
+template <typename T>
+T LoadSymbol(const char* symbol_name) {
+  void* symbol = nullptr;
+  if (auto handle = GetDsoHandle()) {
+    stream_executor::port::Env::Default()
+        ->GetSymbolFromLibrary(handle, symbol_name, &symbol)
+        .IgnoreError();
+  }
+  return reinterpret_cast<T>(symbol);
+}
+
+void LogFatalSymbolNotFound(const char* symbol_name) {
+  LOG(FATAL) << symbol_name << " symbol not found.";
+}
+
+cublasStatus_t GetSymbolNotFoundError() { return CUBLAS_STATUS_INTERNAL_ERROR; }
+}  // namespace
+
+// We only use cublasLt from CUDA 11.0 onward.
+#if CUDA_VERSION >= 11000
+#include "tensorflow/stream_executor/cuda/cublasLt_11_0.inc"
+#endif
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index 4b659bb81e1..565a1c02fb4 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/ #include "third_party/gpus/cuda/include/cublas_v2.h" +#include "third_party/gpus/cuda/include/cublasLt.h" #include "third_party/gpus/cuda/include/cuda.h" #define SE_CUDA_DATA_HALF CUDA_R_16F @@ -226,17 +227,38 @@ bool CUDABlas::Init() { return false; } +#if CUDA_VERSION >= 11000 + ret = cublasLtCreate(&blasLt_); + if (ret != CUBLAS_STATUS_SUCCESS) { + LOG(ERROR) << "failed to create cublasLt handle: " << ToString(ret); + return false; + } +#endif // CUDA_VERSION >= 11000 + return true; } -CUDABlas::CUDABlas(gpu::GpuExecutor *parent) - : parent_(CHECK_NOTNULL(parent)), blas_(nullptr) {} +CUDABlas::CUDABlas(gpu::GpuExecutor* parent) + : parent_(CHECK_NOTNULL(parent)), + blas_(nullptr) +#if CUDA_VERSION >= 11000 + , + blasLt_(nullptr) +#endif +{ +} CUDABlas::~CUDABlas() { if (blas_ != nullptr) { gpu::ScopedActivateExecutorContext sac{parent_}; cublasDestroy(blas_); } +#if CUDA_VERSION >= 11000 + if (blasLt_ != nullptr) { + gpu::ScopedActivateExecutorContext sac{parent_}; + cublasLtDestroy(blasLt_); + } +#endif } bool CUDABlas::SetStream(Stream *stream) { @@ -253,6 +275,13 @@ bool CUDABlas::SetStream(Stream *stream) { return true; } +cudaStream_t CUDABlas::CUDAStream(Stream* stream) { + CHECK(stream != nullptr); + CHECK(AsGpuStreamValue(stream) != nullptr); + gpu::ScopedActivateExecutorContext sac{parent_}; + return AsGpuStreamValue(stream); +} + namespace { // Helper functions transforming blas arguments into cuBLAS arguments. @@ -381,6 +410,82 @@ cudaDataType_t CUDAComputationType(blas::ComputationType ty) { return CUDA_C_32F; case blas::ComputationType::kComplexF64: return CUDA_C_64F; + case blas::ComputationType::kF32FastTF32: // fall-through + case blas::ComputationType::kF32FastBF16: + // These cases are currently only supported in the blasLt routines, which + // use CUBLASComputationType() instead. + LOG(FATAL) << "Invalid value of blas::ComputationType."; + } +} + +#if CUDA_VERSION >= 11000 +cublasComputeType_t CUBLASComputationType(blas::ComputationType ty) { + switch (ty) { + case blas::ComputationType::kF16: + return CUBLAS_COMPUTE_16F; + case blas::ComputationType::kF32: // fall-through + case blas::ComputationType::kComplexF32: + return CUBLAS_COMPUTE_32F; + case blas::ComputationType::kF64: // fall-through + case blas::ComputationType::kComplexF64: + return CUBLAS_COMPUTE_64F; + case blas::ComputationType::kI32: + return CUBLAS_COMPUTE_32I; + case blas::ComputationType::kF32FastTF32: + return CUBLAS_COMPUTE_32F_FAST_TF32; + case blas::ComputationType::kF32FastBF16: + return CUBLAS_COMPUTE_32F_FAST_16BF; + } +} +#endif // CUDA_VERSION >= 11000 + +blas::DataType GetScaleType(blas::DataType data_type, + blas::ComputationType compute_type) { + bool is_complex = data_type == blas::DataType::kComplexF32 || + data_type == blas::DataType::kComplexF64; + switch (compute_type) { + case blas::ComputationType::kF16: + return blas::DataType::kF16; + case blas::ComputationType::kF32: // fall-through + case blas::ComputationType::kComplexF32: // fall-through + case blas::ComputationType::kF32FastTF32: // fall-through + case blas::ComputationType::kF32FastBF16: + return is_complex ? blas::DataType::kComplexF32 : blas::DataType::kF32; + case blas::ComputationType::kF64: // fall-through + case blas::ComputationType::kComplexF64: + return is_complex ? 
blas::DataType::kComplexF64 : blas::DataType::kF64; + case blas::ComputationType::kI32: + return blas::DataType::kI32; + } +} + +#if CUDA_VERSION >= 11000 +cublasLtPointerMode_t CUBLASPointerMode(blas::PointerMode pointer_mode) { + switch (pointer_mode) { + case blas::PointerMode::kHost: + return CUBLASLT_POINTER_MODE_HOST; + case blas::PointerMode::kDevice: + return CUBLASLT_POINTER_MODE_DEVICE; + } +} +#endif // CUDA_VERSION >= 11000 + +cudaDataType_t GetCUDADataType(blas::DataType ty) { + switch (ty) { + case blas::DataType::kF16: + return CUDA_R_16F; + case blas::DataType::kF32: + return CUDA_R_32F; + case blas::DataType::kF64: + return CUDA_R_64F; + case blas::DataType::kI8: + return CUDA_R_8I; + case blas::DataType::kI32: + return CUDA_R_32I; + case blas::DataType::kComplexF32: + return CUDA_C_32F; + case blas::DataType::kComplexF64: + return CUDA_C_64F; } } } // namespace @@ -2912,6 +3017,577 @@ bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side, GpuComplex(GpuMemoryMutable(b)), ldb); } +// We only use cublasLt from CUDA 11.0 onward. +#if CUDA_VERSION >= 11000 + +namespace { + +template +inline bool SetCublasLtAttr(cublasLtMatrixLayout_t handle, + cublasLtMatrixLayoutAttribute_t attr, + const T& value) { + cublasStatus_t status = + cublasLtMatrixLayoutSetAttribute(handle, attr, &value, sizeof(T)); + if (status != CUBLAS_STATUS_SUCCESS) { + VLOG(2) << "cublasLtMatrixLayoutSetAttribute(attr=" << attr + << ", value=" << value << ") failed: " << ToString(status); + return false; + } + return true; +} + +template +inline bool SetCublasLtAttr(cublasLtMatmulAlgo_t* handle, + cublasLtMatmulAlgoConfigAttributes_t attr, + const T& value) { + cublasStatus_t status = + cublasLtMatmulAlgoConfigSetAttribute(handle, attr, &value, sizeof(T)); + if (status != CUBLAS_STATUS_SUCCESS) { + VLOG(2) << "cublasLtMatmulAlgoConfigSetAttribute(attr=" << attr + << ", value=" << value << ") failed: " << ToString(status); + return false; + } + return true; +} + +template +inline bool SetCublasLtAttr(cublasLtMatmulPreference_t handle, + cublasLtMatmulPreferenceAttributes_t attr, + const T& value) { + cublasStatus_t status = + cublasLtMatmulPreferenceSetAttribute(handle, attr, &value, sizeof(value)); + if (status != CUBLAS_STATUS_SUCCESS) { + VLOG(2) << "cublasLtMatmulPreferenceSetAttribute(attr=" << attr + << ", value=" << value << ") failed: " << ToString(status); + return false; + } + return true; +} + +template +inline bool GetCublasLtAttr(const cublasLtMatmulAlgo_t* handle, + cublasLtMatmulAlgoConfigAttributes_t attr, + T* value) { + auto mutable_handle = const_cast(handle); + size_t bytes_written = 0; + return cublasLtMatmulAlgoConfigGetAttribute(mutable_handle, attr, value, + sizeof(T), &bytes_written) == + CUBLAS_STATUS_SUCCESS && + bytes_written == sizeof(T); +} + +template +inline bool SetCublasLtAttr(cublasLtMatmulDesc_t handle, + cublasLtMatmulDescAttributes_t attr, + const T& value) { + cublasStatus_t status = + cublasLtMatmulDescSetAttribute(handle, attr, &value, sizeof(value)); + if (status != CUBLAS_STATUS_SUCCESS) { + VLOG(2) << "cublasLtMatmulDescSetAttribute(attr=" << attr + << ", value=" << value << ") failed: " << ToString(status); + return false; + } + return true; +} + +struct MatmulDescDestroyer { + void operator()(cublasLtMatmulDesc_t matmul_desc) const { + cublasLtMatmulDescDestroy(matmul_desc); + } +}; +struct LayoutDestroyer { + void operator()(cublasLtMatrixLayout_t layout) const { + cublasLtMatrixLayoutDestroy(layout); + } +}; +struct MatmulPreferenceDestroyer { + void 
operator()(cublasLtMatmulPreference_t matmul_pref) const { + cublasLtMatmulPreferenceDestroy(matmul_pref); + } +}; +using UniqueOpDesc = + std::unique_ptr::type, + MatmulDescDestroyer>; +using UniqueLayoutDesc = + std::unique_ptr::type, + LayoutDestroyer>; +using UniqueMatmulPreference = + std::unique_ptr::type, + MatmulPreferenceDestroyer>; + +UniqueOpDesc CreateCublasLtOperationDesc( + blas::ComputationType computation_type, blas::DataType scale_type, + blas::PointerMode pointer_mode, blas::Transpose transa, + blas::Transpose transb) { + cublasOperation_t cuda_transa = CUDABlasTranspose(transa); + cublasOperation_t cuda_transb = CUDABlasTranspose(transb); + cublasLtMatmulDesc_t desc; + cublasComputeType_t cublas_compute_type = + CUBLASComputationType(computation_type); + cudaDataType_t cuda_scale_type = GetCUDADataType(scale_type); + cublasStatus_t status = + cublasLtMatmulDescCreate(&desc, cublas_compute_type, cuda_scale_type); + if (status != CUBLAS_STATUS_SUCCESS) { + VLOG(2) << "cublasLtMatmulDescCreate(computation_type=" << computation_type + << ") failed: " << ToString(status); + return nullptr; + } + UniqueOpDesc unique_desc(desc); + if (!SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_POINTER_MODE, + CUBLASPointerMode(pointer_mode)) || + !SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSA, cuda_transa) || + !SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSB, cuda_transb)) { + return nullptr; + } + return unique_desc; +} + +UniqueLayoutDesc CreateCublasLtLayoutDesc(blas::DataType data_type, uint64 rows, + uint64 cols, int64 ld, int64 stride, + int batch_count) { + cublasLtMatrixLayout_t desc; + cublasStatus_t status = cublasLtMatrixLayoutCreate( + &desc, GetCUDADataType(data_type), rows, cols, ld); + if (status != CUBLAS_STATUS_SUCCESS) { + VLOG(2) << "cublasLtMatrixLayoutCreate failed: " << ToString(status); + return nullptr; + } + UniqueLayoutDesc unique_desc(desc); + if (!SetCublasLtAttr(desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_count) || + !SetCublasLtAttr(desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + stride)) { + return nullptr; + } + return unique_desc; +} + +UniqueMatmulPreference CreateCublasLtMatmulPreference( + size_t max_workspace_bytes) { + cublasLtMatmulPreference_t preference; + cublasStatus_t status = cublasLtMatmulPreferenceCreate(&preference); + if (status != CUBLAS_STATUS_SUCCESS) { + VLOG(2) << "cublasLtMatmulPreferenceCreate failed: " << ToString(status); + return nullptr; + } + UniqueMatmulPreference unique_preference(preference); + if (!SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + max_workspace_bytes)) { + return nullptr; + } + return unique_preference; +} + +// Helper function to allocate workspace. 
+port::Status AllocateWorkspace(void** workspace, + ScratchAllocator* scratch_allocator, + size_t num_bytes) { + SE_ASSIGN_OR_RETURN(DeviceMemory workspace_bytes, + scratch_allocator->AllocateBytes(num_bytes)); + *workspace = (void*)GpuMemoryMutable(&workspace_bytes); + return port::Status::OK(); +} + +template +blas::ComputationType ToComputationType(); +template <> +blas::ComputationType ToComputationType() { + return blas::ComputationType::kF16; +} +template <> +blas::ComputationType ToComputationType() { + return blas::ComputationType::kF32; +} +template <> +blas::ComputationType ToComputationType() { + return blas::ComputationType::kF64; +} +template <> +blas::ComputationType ToComputationType>() { + return blas::ComputationType::kComplexF32; +}template <> +blas::ComputationType ToComputationType>() { + return blas::ComputationType::kComplexF64; +} + +class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan { + public: + CUDABlasLtMatmulPlan(blas::DataType ab_type, blas::DataType cd_type, + blas::ComputationType compute_type, + blas::PointerMode pointer_mode, blas::Transpose transa, + blas::Transpose transb, uint64 m, uint64 n, uint64 k, + int batch_count, int64 lda, int64 stride_a, int64 ldb, + int64 stride_b, int64 ldc, int64 stride_c, int64 ldd, + int64 stride_d); + + cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); } + cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); } + cublasLtMatrixLayout_t b_desc() const { return b_desc_.get(); } + cublasLtMatrixLayout_t c_desc() const { return c_desc_.get(); } + cublasLtMatrixLayout_t d_desc() const { return d_desc_.get(); } + bool ok() { return op_desc_ && a_desc_ && b_desc_ && c_desc_ && d_desc_; } + + blas::DataType ab_type() const { return ab_type_; } + blas::DataType cd_type() const { return cd_type_; } + blas::DataType scale_type() const { return scale_type_; } + blas::PointerMode pointer_mode() const { return pointer_mode_; } + + private: + UniqueOpDesc op_desc_; + UniqueLayoutDesc a_desc_; + UniqueLayoutDesc b_desc_; + UniqueLayoutDesc c_desc_; + UniqueLayoutDesc d_desc_; + blas::DataType ab_type_; + blas::DataType cd_type_; + blas::DataType scale_type_; + blas::PointerMode pointer_mode_; +}; + +CUDABlasLtMatmulPlan::CUDABlasLtMatmulPlan( + blas::DataType ab_type, blas::DataType cd_type, + blas::ComputationType computation_type, blas::PointerMode pointer_mode, + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, int batch_count, int64 lda, int64 stride_a, int64 ldb, + int64 stride_b, int64 ldc, int64 stride_c, int64 ldd, int64 stride_d) + : op_desc_(CreateCublasLtOperationDesc( + computation_type, GetScaleType(cd_type, computation_type), + pointer_mode, transa, transb)), + a_desc_(nullptr), + b_desc_(nullptr), + c_desc_( + CreateCublasLtLayoutDesc(cd_type, m, n, ldc, stride_c, batch_count)), + d_desc_( + CreateCublasLtLayoutDesc(cd_type, m, n, ldd, stride_d, batch_count)), + ab_type_(ab_type), + cd_type_(cd_type), + scale_type_(GetScaleType(cd_type, computation_type)), + pointer_mode_(pointer_mode) { + uint64 rows_a = transa == blas::Transpose::kNoTranspose ? m : k; + uint64 cols_a = transa == blas::Transpose::kNoTranspose ? k : m; + uint64 rows_b = transb == blas::Transpose::kNoTranspose ? k : n; + uint64 cols_b = transb == blas::Transpose::kNoTranspose ? 
n : k; + a_desc_ = CreateCublasLtLayoutDesc(ab_type, rows_a, cols_a, lda, stride_a, + batch_count); + b_desc_ = CreateCublasLtLayoutDesc(ab_type, rows_b, cols_b, ldb, stride_b, + batch_count); +} + +class CUDABlasLtMatmulAlgorithm final : public blas::IBlasLtMatmulAlgorithm { + public: + CUDABlasLtMatmulAlgorithm(blas::AlgorithmType index, + cublasLtMatmulAlgo_t algo, size_t workspace_size) + : index_(index), algo_(algo), workspace_size_(workspace_size) {} + + blas::AlgorithmType index() const override { return index_; } + + size_t workspace_size() const override { return workspace_size_; } + + const cublasLtMatmulAlgo_t* algo() const { return &algo_; } + + int algo_id() const { + int id; + GetCublasLtAttr(&algo_, CUBLASLT_ALGO_CONFIG_ID, &id); + return id; + } + + private: + blas::AlgorithmType index_; + cublasLtMatmulAlgo_t algo_; + size_t workspace_size_; +}; + +} // namespace + +#endif // CUDA_VERSION >= 11000 + +std::unique_ptr +CUDABlas::CreateBlasLtMatmulPlanStridedBatched( + blas::DataType ab_type, blas::DataType cd_type, + blas::ComputationType computation_type, blas::PointerMode pointer_mode, + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, int batch_count, int64 lda, int64 stride_a, int64 ldb, + int64 stride_b, int64 ldc, int64 stride_c) { +#if CUDA_VERSION >= 11000 + auto result = std::make_unique( + ab_type, cd_type, computation_type, pointer_mode, transa, transb, m, n, k, + batch_count, lda, stride_a, ldb, stride_b, ldc, stride_c, ldc, stride_c); + if (!result->ok()) { + result.reset(); + } + return result; +#else + return nullptr; +#endif +} + +bool CUDABlas::GetBlasLtMatmulAlgorithms( + const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size, + int max_algorithm_count, + std::vector>* + out_algorithms) { +#if CUDA_VERSION >= 11000 + UniqueMatmulPreference preference = + CreateCublasLtMatmulPreference(max_workspace_size); + if (!preference) return false; + + std::vector results(max_algorithm_count); + { + absl::MutexLock lock(&mu_); + + CHECK(blasLt_ != nullptr); + + gpu::ScopedActivateExecutorContext sac{parent_}; + + int found_algorithm_count = 0; + const auto& cuda_plan = *static_cast(plan); + cublasStatus_t status = cublasLtMatmulAlgoGetHeuristic( + blasLt_, cuda_plan.op_desc(), cuda_plan.a_desc(), cuda_plan.b_desc(), + cuda_plan.c_desc(), cuda_plan.d_desc(), preference.get(), + max_algorithm_count, results.data(), &found_algorithm_count); + if (status != CUBLAS_STATUS_SUCCESS) { + VLOG(2) << "cublasLtMatmulAlgoGetHeuristic failed: " << ToString(status); + return false; + } + results.resize(found_algorithm_count); + } + + for (size_t i = 0; i < results.size(); ++i) { + const auto& result = results[i]; + if (result.state != CUBLAS_STATUS_SUCCESS) continue; // Skip failed algos + out_algorithms->emplace_back(std::make_unique( + i, result.algo, result.workspaceSize)); + } + return true; +#else // if CUDA_VERSION < 11000 + return false; +#endif +} + +#if CUDA_VERSION >= 11000 +template +bool CUDABlas::DoBlasLtMatmulInternalImpl( + Stream* stream, bool err_on_failure, const blas::IBlasLtMatmulPlan* plan, + const HostOrDeviceScalar& alpha, const ABType* a, + const ABType* b, const HostOrDeviceScalar& beta, const CDType* c, + CDType* d, ScratchAllocator* scratch_allocator, + const blas::IBlasLtMatmulAlgorithm* algorithm) { + const auto& cuda_plan = *static_cast(plan); + const auto& cuda_algo = + *static_cast(algorithm); + + if (cuda_plan.ab_type() != blas::ToDataType::value) { + VLOG(2) << "DoBlasLtMatmul returning false because plan has 
wrong ab_type: " + "expected " + << blas::ToDataType::value << ", got " + << cuda_plan.ab_type(); + return false; + } + if (cuda_plan.cd_type() != blas::ToDataType::value) { + VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong cd_type: " + "expected " + << blas::ToDataType::value << ", got " + << cuda_plan.cd_type(); + return false; + } + if (cuda_plan.scale_type() != blas::ToDataType::value) { + VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong " + "scale_type: expected " + << blas::ToDataType::value << ", got " + << cuda_plan.cd_type(); + return false; + } + if (alpha.is_pointer() != beta.is_pointer()) { + VLOG(2) << "DoBlasLtMatmul returning false because one of `alpha` " + "and `beta` is a pointer, but the other is not."; + return false; + } + bool is_pointer_mode_host = !alpha.is_pointer(); + if ((cuda_plan.pointer_mode() == blas::PointerMode::kHost) != + is_pointer_mode_host) { + VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong " + "pointer_mode for the given alpha/beta."; + return false; + } + const ScaleType* alpha_ptr = + alpha.is_pointer() ? GpuMemory(alpha.pointer()) : &alpha.value(); + const ScaleType* beta_ptr = + beta.is_pointer() ? GpuMemory(beta.pointer()) : &beta.value(); + + void* workspace = nullptr; + if (cuda_algo.workspace_size()) { + port::Status allocation_status = AllocateWorkspace( + &workspace, scratch_allocator, cuda_algo.workspace_size()); + if (!allocation_status.ok()) { + if (err_on_failure || VLOG_IS_ON(3)) { + LOG(ERROR) + << "Failed to allocate workspace for cublasLtMatmul algo with id: " + << cuda_algo.algo_id() << " requiring " + << cuda_algo.workspace_size() << " bytes of workspace"; + } + return false; + } + } + + cudaStream_t cuda_stream = CUDAStream(stream); + + absl::MutexLock lock(&mu_); + + CHECK(blasLt_ != nullptr); + + gpu::ScopedActivateExecutorContext sac{parent_}; + + cublasStatus_t ret = cublasLtMatmul( + blasLt_, cuda_plan.op_desc(), alpha_ptr, a, cuda_plan.a_desc(), b, + cuda_plan.b_desc(), beta_ptr, c, cuda_plan.c_desc(), d, + cuda_plan.d_desc(), cuda_algo.algo(), workspace, + cuda_algo.workspace_size(), cuda_stream); + if (ret != CUBLAS_STATUS_SUCCESS) { + if (err_on_failure || VLOG_IS_ON(3)) { + LOG(ERROR) << "failed to run cublasLtMatmul routine: " << ToString(ret); + } + return false; + } + return true; +} +#endif // CUDA_VERSION >= 11000 + +template +bool CUDABlas::DoBlasLtMatmulInternal( + Stream* stream, const blas::IBlasLtMatmulPlan* plan, + const HostOrDeviceScalar& alpha, const DeviceMemory& a, + const DeviceMemory& b, const HostOrDeviceScalar& beta, + const DeviceMemory& c, DeviceMemory* d, + ScratchAllocator* scratch_allocator, + const blas::IBlasLtMatmulAlgorithm* algorithm, + blas::ProfileResult* output_profile_result) { +#if CUDA_VERSION >= 11000 + std::unique_ptr timer; + if (output_profile_result) { + timer.reset(new GpuTimer(parent_)); + if (!timer->Init() || !timer->Start(AsGpuStream(stream))) { + return false; + } + } + + bool err_on_failure = timer != nullptr; + bool result = DoBlasLtMatmulInternalImpl( + stream, err_on_failure, plan, alpha, GpuMemory(a), GpuMemory(b), beta, + GpuMemory(c), GpuMemoryMutable(d), scratch_allocator, algorithm); + + if (timer && result) { + // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error + // state. 
+ if (!timer->Stop(AsGpuStream(stream))) { + return false; + } + output_profile_result->set_is_valid(true); + output_profile_result->set_algorithm(algorithm->index()); + output_profile_result->set_elapsed_time_in_ms( + timer->GetElapsedMilliseconds()); + } + return result; +#else // if CUDA_VERSION < 11000 + return false; +#endif +} + +bool CUDABlas::DoBlasLtMatmul( + Stream* stream, const blas::IBlasLtMatmulPlan* plan, + const HostOrDeviceScalar& alpha, const DeviceMemory& a, + const DeviceMemory& b, const HostOrDeviceScalar& beta, + DeviceMemory* c, ScratchAllocator* scratch_allocator, + const blas::IBlasLtMatmulAlgorithm* algorithm, + blas::ProfileResult* output_profile_result) { + return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c, + scratch_allocator, algorithm, + output_profile_result); +} + +bool CUDABlas::DoBlasLtMatmul(Stream* stream, + const blas::IBlasLtMatmulPlan* plan, + const HostOrDeviceScalar& alpha, + const DeviceMemory& a, + const DeviceMemory& b, + const HostOrDeviceScalar& beta, + DeviceMemory* c, + ScratchAllocator* scratch_allocator, + const blas::IBlasLtMatmulAlgorithm* algorithm, + blas::ProfileResult* output_profile_result) { +#if CUDA_VERSION >= 11000 + const auto& cuda_plan = *static_cast(plan); + if (cuda_plan.scale_type() == blas::DataType::kF32) { + // F32* computation types require F32 alpha/beta type, so we must cast them. + if (alpha.is_pointer() || beta.is_pointer()) { + // We cannot easily convert a pointer to f16 memory to a pointer to f32 + // memory from here, so we don't support this for now. + return false; + } + HostOrDeviceScalar float_alpha(static_cast(alpha.value())); + HostOrDeviceScalar float_beta(static_cast(beta.value())); + return DoBlasLtMatmulInternal(stream, plan, float_alpha, a, b, float_beta, + *c, c, scratch_allocator, algorithm, + output_profile_result); + } + return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c, + scratch_allocator, algorithm, + output_profile_result); +#else // if CUDA_VERSION < 11000 + return false; +#endif +} + +bool CUDABlas::DoBlasLtMatmul( + Stream* stream, const blas::IBlasLtMatmulPlan* plan, + const HostOrDeviceScalar& alpha, const DeviceMemory& a, + const DeviceMemory& b, const HostOrDeviceScalar& beta, + DeviceMemory* c, ScratchAllocator* scratch_allocator, + const blas::IBlasLtMatmulAlgorithm* algorithm, + blas::ProfileResult* output_profile_result) { + return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c, + scratch_allocator, algorithm, + output_profile_result); +} + +bool CUDABlas::DoBlasLtMatmul( + Stream* stream, const blas::IBlasLtMatmulPlan* plan, + const HostOrDeviceScalar& alpha, const DeviceMemory& a, + const DeviceMemory& b, const HostOrDeviceScalar& beta, + DeviceMemory* c, ScratchAllocator* scratch_allocator, + const blas::IBlasLtMatmulAlgorithm* algorithm, + blas::ProfileResult* output_profile_result) { + return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c, + scratch_allocator, algorithm, + output_profile_result); +} + +bool CUDABlas::DoBlasLtMatmul( + Stream* stream, const blas::IBlasLtMatmulPlan* plan, + const HostOrDeviceScalar>& alpha, + const DeviceMemory>& a, + const DeviceMemory>& b, + const HostOrDeviceScalar>& beta, + DeviceMemory>* c, ScratchAllocator* scratch_allocator, + const blas::IBlasLtMatmulAlgorithm* algorithm, + blas::ProfileResult* output_profile_result) { + return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c, + scratch_allocator, algorithm, + output_profile_result); +} + +bool 
CUDABlas::DoBlasLtMatmul( + Stream* stream, const blas::IBlasLtMatmulPlan* plan, + const HostOrDeviceScalar>& alpha, + const DeviceMemory>& a, + const DeviceMemory>& b, + const HostOrDeviceScalar>& beta, + DeviceMemory>* c, ScratchAllocator* scratch_allocator, + const blas::IBlasLtMatmulAlgorithm* algorithm, + blas::ProfileResult* output_profile_result) { + return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c, + scratch_allocator, algorithm, + output_profile_result); +} + port::Status CUDABlas::GetVersion(std::string *version) { absl::MutexLock lock(&mu_); diff --git a/tensorflow/stream_executor/cuda/cuda_blas.h b/tensorflow/stream_executor/cuda/cuda_blas.h index 9ff63102aaa..351a7778c01 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.h +++ b/tensorflow/stream_executor/cuda/cuda_blas.h @@ -22,6 +22,8 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "third_party/gpus/cuda/include/cublas_v2.h" +#include "third_party/gpus/cuda/include/cublasLt.h" +#include "third_party/gpus/cuda/include/cuda.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/stream_executor/blas.h" #include "tensorflow/stream_executor/host_or_device_scalar.h" @@ -71,6 +73,9 @@ class CUDABlas : public blas::BlasSupport { // invoked before calling into cuBLAS. bool SetStream(Stream *stream) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Returns the underlying CUDA stream. + cudaStream_t CUDAStream(Stream* stream); + // A helper function that calls the real cuBLAS function together with error // handling. // @@ -134,6 +139,26 @@ class CUDABlas : public blas::BlasSupport { const T &beta, DeviceMemory *y, int incy, blas::ProfileResult *output_profile_result); + // Helper function for implementing DoBlasLtMatmul. + template + bool DoBlasLtMatmulInternal( + Stream* stream, const blas::IBlasLtMatmulPlan* plan, + const HostOrDeviceScalar& alpha, const DeviceMemory& a, + const DeviceMemory& b, const HostOrDeviceScalar& beta, + const DeviceMemory& c, DeviceMemory* d, + ScratchAllocator* scratch_allocator, + const blas::IBlasLtMatmulAlgorithm* algorithm, + blas::ProfileResult* output_profile_result); + + // Helper function for implementing DoBlasLtMatmulInternal. + template + bool DoBlasLtMatmulInternalImpl( + Stream* stream, bool err_on_failure, const blas::IBlasLtMatmulPlan* plan, + const HostOrDeviceScalar& alpha, const ABType* a, + const ABType* b, const HostOrDeviceScalar& beta, + const CDType* c, CDType* d, ScratchAllocator* scratch_allocator, + const blas::IBlasLtMatmulAlgorithm* algorithm); + // Guards the cuBLAS handle for this device. absl::Mutex mu_; @@ -144,6 +169,11 @@ class CUDABlas : public blas::BlasSupport { // cuBLAS library handle on the device. cublasHandle_t blas_ TF_GUARDED_BY(mu_); +#if CUDA_VERSION >= 11000 + // cuBLASLt library handle on the device. 
+ cublasLtHandle_t blasLt_ GUARDED_BY(mu_); +#endif + SE_DISALLOW_COPY_AND_ASSIGN(CUDABlas); }; diff --git a/tensorflow/stream_executor/platform/default/dlopen_checker.cc b/tensorflow/stream_executor/platform/default/dlopen_checker.cc index b55c9f53793..7b38dfcfec0 100644 --- a/tensorflow/stream_executor/platform/default/dlopen_checker.cc +++ b/tensorflow/stream_executor/platform/default/dlopen_checker.cc @@ -23,6 +23,7 @@ namespace DsoLoader { port::Status TryDlopenCUDALibraries() { auto cudart_status = GetCudaRuntimeDsoHandle(); auto cublas_status = GetCublasDsoHandle(); + auto cublaslt_status = GetCublasLtDsoHandle(); auto cufft_status = GetCufftDsoHandle(); auto curand_status = GetCurandDsoHandle(); auto cusolver_status = GetCusolverDsoHandle(); @@ -31,7 +32,7 @@ port::Status TryDlopenCUDALibraries() { if (!cudart_status.status().ok() || !cublas_status.status().ok() || !cufft_status.status().ok() || !curand_status.status().ok() || !cusolver_status.status().ok() || !cusparse_status.status().ok() || - !cudnn_status.status().ok()) { + !cudnn_status.status().ok() || !cublaslt_status.status().ok()) { return port::Status(port::error::INTERNAL, absl::StrCat("Cannot dlopen all CUDA libraries.")); } else { diff --git a/tensorflow/stream_executor/platform/default/dso_loader.cc b/tensorflow/stream_executor/platform/default/dso_loader.cc index 70b1ebe070a..66cf3f2b43b 100644 --- a/tensorflow/stream_executor/platform/default/dso_loader.cc +++ b/tensorflow/stream_executor/platform/default/dso_loader.cc @@ -84,6 +84,10 @@ port::StatusOr GetCublasDsoHandle() { return GetDsoHandle("cublas", GetCublasVersion()); } +port::StatusOr GetCublasLtDsoHandle() { + return GetDsoHandle("cublasLt", GetCublasVersion()); +} + port::StatusOr GetCufftDsoHandle() { return GetDsoHandle("cufft", GetCufftVersion()); } @@ -160,6 +164,11 @@ port::StatusOr GetCublasDsoHandle() { return *result; } +port::StatusOr GetCublasLtDsoHandle() { + static auto result = new auto(DsoLoader::GetCublasLtDsoHandle()); + return *result; +} + port::StatusOr GetCurandDsoHandle() { static auto result = new auto(DsoLoader::GetCurandDsoHandle()); return *result; diff --git a/tensorflow/stream_executor/platform/default/dso_loader.h b/tensorflow/stream_executor/platform/default/dso_loader.h index 91138f713bd..7f087349fcf 100644 --- a/tensorflow/stream_executor/platform/default/dso_loader.h +++ b/tensorflow/stream_executor/platform/default/dso_loader.h @@ -37,6 +37,7 @@ namespace DsoLoader { port::StatusOr GetCudaDriverDsoHandle(); port::StatusOr GetCudaRuntimeDsoHandle(); port::StatusOr GetCublasDsoHandle(); +port::StatusOr GetCublasLtDsoHandle(); port::StatusOr GetCufftDsoHandle(); port::StatusOr GetCurandDsoHandle(); port::StatusOr GetCusolverDsoHandle(); @@ -72,6 +73,7 @@ namespace CachedDsoLoader { port::StatusOr GetCudaDriverDsoHandle(); port::StatusOr GetCudaRuntimeDsoHandle(); port::StatusOr GetCublasDsoHandle(); +port::StatusOr GetCublasLtDsoHandle(); port::StatusOr GetCufftDsoHandle(); port::StatusOr GetCurandDsoHandle(); port::StatusOr GetCusolverDsoHandle(); diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index 62689e61be1..144af92185c 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -4801,6 +4801,143 @@ Stream &Stream::ThenBlasGemmStridedBatched( c, ldc, stride_c, batch_count); } +Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan, + const HostOrDeviceScalar& alpha, + const DeviceMemory& a, + const DeviceMemory& b, + const 
HostOrDeviceScalar& beta, + DeviceMemory* c, + ScratchAllocator* scratch_allocator, + const blas::IBlasLtMatmulAlgorithm* algorithm, + blas::ProfileResult* output_profile_result) { + VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta), + PARAM(c), PARAM(algorithm)); + + ThenBlasWithProfileImpl< + const blas::IBlasLtMatmulPlan*, const HostOrDeviceScalar&, + const DeviceMemory&, const DeviceMemory&, + const HostOrDeviceScalar&, DeviceMemory*, ScratchAllocator*, + const blas::IBlasLtMatmulAlgorithm*> + impl; + return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta, + c, scratch_allocator, algorithm, output_profile_result); +} + +Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan, + const HostOrDeviceScalar& alpha, + const DeviceMemory& a, + const DeviceMemory& b, + const HostOrDeviceScalar& beta, + DeviceMemory* c, + ScratchAllocator* scratch_allocator, + const blas::IBlasLtMatmulAlgorithm* algorithm, + blas::ProfileResult* output_profile_result) { + VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta), + PARAM(c), PARAM(algorithm)); + + ThenBlasWithProfileImpl< + const blas::IBlasLtMatmulPlan*, const HostOrDeviceScalar&, + const DeviceMemory&, const DeviceMemory&, + const HostOrDeviceScalar&, DeviceMemory*, + ScratchAllocator*, const blas::IBlasLtMatmulAlgorithm*> + impl; + return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta, + c, scratch_allocator, algorithm, output_profile_result); +} + +Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan, + const HostOrDeviceScalar& alpha, + const DeviceMemory& a, + const DeviceMemory& b, + const HostOrDeviceScalar& beta, + DeviceMemory* c, + ScratchAllocator* scratch_allocator, + const blas::IBlasLtMatmulAlgorithm* algorithm, + blas::ProfileResult* output_profile_result) { + VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta), + PARAM(c), PARAM(algorithm)); + + ThenBlasWithProfileImpl< + const blas::IBlasLtMatmulPlan*, const HostOrDeviceScalar&, + const DeviceMemory&, const DeviceMemory&, + const HostOrDeviceScalar&, DeviceMemory*, ScratchAllocator*, + const blas::IBlasLtMatmulAlgorithm*> + impl; + return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta, + c, scratch_allocator, algorithm, output_profile_result); +} + +Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan, + const HostOrDeviceScalar& alpha, + const DeviceMemory& a, + const DeviceMemory& b, + const HostOrDeviceScalar& beta, + DeviceMemory* c, + ScratchAllocator* scratch_allocator, + const blas::IBlasLtMatmulAlgorithm* algorithm, + blas::ProfileResult* output_profile_result) { + VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta), + PARAM(c), PARAM(algorithm)); + + ThenBlasWithProfileImpl< + const blas::IBlasLtMatmulPlan*, const HostOrDeviceScalar&, + const DeviceMemory&, const DeviceMemory&, + const HostOrDeviceScalar&, DeviceMemory*, + ScratchAllocator*, const blas::IBlasLtMatmulAlgorithm*> + impl; + return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta, + c, scratch_allocator, algorithm, output_profile_result); +} + +Stream& Stream::ThenBlasLtMatmul( + const blas::IBlasLtMatmulPlan* plan, + const HostOrDeviceScalar>& alpha, + const DeviceMemory>& a, + const DeviceMemory>& b, + const HostOrDeviceScalar>& beta, + DeviceMemory>* c, ScratchAllocator* scratch_allocator, + const blas::IBlasLtMatmulAlgorithm* algorithm, + blas::ProfileResult* output_profile_result) { + 
VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta), + PARAM(c), PARAM(algorithm)); + + ThenBlasWithProfileImpl>&, + const DeviceMemory>&, + const DeviceMemory>&, + const HostOrDeviceScalar>&, + DeviceMemory>*, ScratchAllocator*, + const blas::IBlasLtMatmulAlgorithm*> + impl; + return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta, + c, scratch_allocator, algorithm, output_profile_result); +} + +Stream& Stream::ThenBlasLtMatmul( + const blas::IBlasLtMatmulPlan* plan, + const HostOrDeviceScalar>& alpha, + const DeviceMemory>& a, + const DeviceMemory>& b, + const HostOrDeviceScalar>& beta, + DeviceMemory>* c, ScratchAllocator* scratch_allocator, + const blas::IBlasLtMatmulAlgorithm* algorithm, + blas::ProfileResult* output_profile_result) { + VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta), + PARAM(c), PARAM(algorithm)); + + ThenBlasWithProfileImpl>&, + const DeviceMemory>&, + const DeviceMemory>&, + const HostOrDeviceScalar>&, + DeviceMemory>*, + ScratchAllocator*, + const blas::IBlasLtMatmulAlgorithm*> + impl; + return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta, + c, scratch_allocator, algorithm, output_profile_result); +} + Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) { VLOG_CALL(PARAM(seed), PARAM(seed_bytes)); diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index bfe442641ad..15f5dfc936f 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -1665,6 +1665,56 @@ class Stream { const DeviceMemory> &a, int lda, DeviceMemory> *b, int ldb); + // See BlasSupport::DoBlatLtMatmul. + Stream& ThenBlasLtMatmul( + const blas::IBlasLtMatmulPlan* plan, + const HostOrDeviceScalar& alpha, const DeviceMemory& a, + const DeviceMemory& b, const HostOrDeviceScalar& beta, + DeviceMemory* c, ScratchAllocator* scratch_allocator, + const blas::IBlasLtMatmulAlgorithm* algorithm, + blas::ProfileResult* output_profile_result = nullptr); + Stream& ThenBlasLtMatmul( + const blas::IBlasLtMatmulPlan* plan, + const HostOrDeviceScalar& alpha, + const DeviceMemory& a, const DeviceMemory& b, + const HostOrDeviceScalar& beta, DeviceMemory* c, + ScratchAllocator* scratch_allocator, + const blas::IBlasLtMatmulAlgorithm* algorithm, + blas::ProfileResult* output_profile_result = nullptr); + Stream& ThenBlasLtMatmul( + const blas::IBlasLtMatmulPlan* plan, + const HostOrDeviceScalar& alpha, const DeviceMemory& a, + const DeviceMemory& b, const HostOrDeviceScalar& beta, + DeviceMemory* c, ScratchAllocator* scratch_allocator, + const blas::IBlasLtMatmulAlgorithm* algorithm, + blas::ProfileResult* output_profile_result = nullptr); + Stream& ThenBlasLtMatmul( + const blas::IBlasLtMatmulPlan* plan, + const HostOrDeviceScalar& alpha, const DeviceMemory& a, + const DeviceMemory& b, const HostOrDeviceScalar& beta, + DeviceMemory* c, ScratchAllocator* scratch_allocator, + const blas::IBlasLtMatmulAlgorithm* algorithm, + blas::ProfileResult* output_profile_result = nullptr); + Stream& ThenBlasLtMatmul( + const blas::IBlasLtMatmulPlan* plan, + const HostOrDeviceScalar>& alpha, + const DeviceMemory>& a, + const DeviceMemory>& b, + const HostOrDeviceScalar>& beta, + DeviceMemory>* c, ScratchAllocator* scratch_allocator, + const blas::IBlasLtMatmulAlgorithm* algorithm, + blas::ProfileResult* output_profile_result = nullptr); + Stream& ThenBlasLtMatmul( + const blas::IBlasLtMatmulPlan* plan, + const HostOrDeviceScalar>& alpha, + const DeviceMemory>& a, + 
const DeviceMemory>& b, + const HostOrDeviceScalar>& beta, + DeviceMemory>* c, + ScratchAllocator* scratch_allocator, + const blas::IBlasLtMatmulAlgorithm* algorithm, + blas::ProfileResult* output_profile_result = nullptr); + // See FftSupport::DoFft. Stream &ThenFft(fft::Plan *plan, const DeviceMemory> &input, diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc index db4e8f9b694..3fbbc3f2aac 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.cc +++ b/tensorflow/stream_executor/stream_executor_pimpl.cc @@ -336,6 +336,49 @@ bool StreamExecutor::GetBlasGemmAlgorithms( return blas_support->GetBlasGemmAlgorithms(out_algorithms); } +std::unique_ptr StreamExecutor::CreateBlasLtMatmulPlan( + blas::DataType ab_type, blas::DataType cd_type, + blas::ComputationType computation_type, blas::PointerMode pointer_mode, + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, int64 lda, int64 ldb, int64 ldc) { + blas::BlasSupport *blas_support = AsBlas(); + if (!blas_support) { + return nullptr; + } + return blas_support->CreateBlasLtMatmulPlan( + ab_type, cd_type, computation_type, pointer_mode, transa, transb, m, n, k, + lda, ldb, ldc); +} + +std::unique_ptr +StreamExecutor::CreateBlasLtMatmulPlanStridedBatched( + blas::DataType ab_type, blas::DataType cd_type, + blas::ComputationType computation_type, blas::PointerMode pointer_mode, + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, uint64 batch_count, int64 lda, int64 stride_a, int64 ldb, + int64 stride_b, int64 ldc, int64 stride_c) { + blas::BlasSupport *blas_support = AsBlas(); + if (!blas_support) { + return nullptr; + } + return blas_support->CreateBlasLtMatmulPlanStridedBatched( + ab_type, cd_type, computation_type, pointer_mode, transa, transb, m, n, k, + batch_count, lda, stride_a, ldb, stride_b, ldc, stride_c); +} + +bool StreamExecutor::GetBlasLtMatmulAlgorithms( + const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size, + int max_algorithm_count, + std::vector>* + out_algorithms) { + blas::BlasSupport *blas_support = AsBlas(); + if (!blas_support) { + return false; + } + return blas_support->GetBlasLtMatmulAlgorithms( + plan, max_workspace_size, max_algorithm_count, out_algorithms); +} + port::StatusOr> StreamExecutor::createRnnDescriptor( int num_layers, int hidden_size, int input_size, int cell_size, diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h index b9b118ca42c..90137417250 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.h +++ b/tensorflow/stream_executor/stream_executor_pimpl.h @@ -394,6 +394,35 @@ class StreamExecutor { // Get the list of supported algorithms for BLAS gemm. bool GetBlasGemmAlgorithms(std::vector *out_algorithms); + // Creates a backend-specific plan object for a blaslt matmul operation, which + // can then be passed to DoBlasLtMatmul(). When possible, plans should be + // created once and reused for multiple calls to DoBlasLtMatmul(). + // Returns a null pointer on failure. + std::unique_ptr CreateBlasLtMatmulPlan( + blas::DataType ab_type, blas::DataType cd_type, + blas::ComputationType computation_type, blas::PointerMode pointer_mode, + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, int64 lda, int64 ldb, int64 ldc); + + // A more general version of CreateBlasLtMatmulPlan supporting + // batched operations. 
+ std::unique_ptr CreateBlasLtMatmulPlanStridedBatched( + blas::DataType ab_type, blas::DataType cd_type, + blas::ComputationType computation_type, blas::PointerMode pointer_mode, + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, uint64 batch_count, int64 lda, int64 stride_a, int64 ldb, + int64 stride_b, int64 ldc, int64 stride_c); + + // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are + // returned in the order of increasing estimated compute time according to an + // internal heuristic. The first returned algorithm can be used as the default + // algorithm if no autotuning is to be performed. + bool GetBlasLtMatmulAlgorithms( + const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size, + int max_algorithm_count, + std::vector>* + out_algorithms); + // Create an RNN descriptor based on model shapes and configurations. // The caller retains the ownership of the descriptor. port::StatusOr> createRnnDescriptor( diff --git a/third_party/gpus/cuda/BUILD.tpl b/third_party/gpus/cuda/BUILD.tpl index a4a21abc367..70eacf82883 100644 --- a/third_party/gpus/cuda/BUILD.tpl +++ b/third_party/gpus/cuda/BUILD.tpl @@ -127,6 +127,13 @@ cc_library( linkstatic = 1, ) +cc_library( + name = "cublasLt", + srcs = ["cuda/lib/%{cublasLt_lib}"], + data = ["cuda/lib/%{cublasLt_lib}"], + linkstatic = 1, +) + cc_library( name = "cusolver", srcs = ["cuda/lib/%{cusolver_lib}"], @@ -168,6 +175,7 @@ cc_library( name = "cuda", deps = [ ":cublas", + ":cublasLt", ":cuda_headers", ":cudart", ":cudnn", diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index ea33963fe19..55bcd6e5ccc 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -551,6 +551,13 @@ def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config): cuda_config.cublas_version, static = False, ), + "cublasLt": _check_cuda_lib_params( + "cublasLt", + cpu_value, + cuda_config.config["cublas_library_dir"], + cuda_config.cublas_version, + static = False, + ), "cusolver": _check_cuda_lib_params( "cusolver", cpu_value, @@ -771,6 +778,7 @@ def _create_dummy_repository(repository_ctx): "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value), "%{cudart_lib}": lib_name("cudart", cpu_value), "%{cublas_lib}": lib_name("cublas", cpu_value), + "%{cublasLt_lib}": lib_name("cublasLt", cpu_value), "%{cusolver_lib}": lib_name("cusolver", cpu_value), "%{cudnn_lib}": lib_name("cudnn", cpu_value), "%{cufft_lib}": lib_name("cufft", cpu_value), @@ -802,6 +810,7 @@ filegroup(name="cudnn-include") "cuda/cuda/lib/%s" % lib_name("cudart_static", cpu_value), ) repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublas", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublasLt", cpu_value)) repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusolver", cpu_value)) repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudnn", cpu_value)) repository_ctx.file("cuda/cuda/lib/%s" % lib_name("curand", cpu_value)) @@ -992,11 +1001,13 @@ def _create_local_cuda_repository(repository_ctx): cublas_include_path + "/cublas.h", cublas_include_path + "/cublas_v2.h", cublas_include_path + "/cublas_api.h", + cublas_include_path + "/cublasLt.h", ], outs = [ "cublas/include/cublas.h", "cublas/include/cublas_v2.h", "cublas/include/cublas_api.h", + "cublas/include/cublasLt.h", ], )) @@ -1137,6 +1148,7 @@ def _create_local_cuda_repository(repository_ctx): "%{cudart_static_linkopt}": _cudart_static_linkopt(cuda_config.cpu_value), "%{cudart_lib}": 
_basename(repository_ctx, cuda_libs["cudart"]),
             "%{cublas_lib}": _basename(repository_ctx, cuda_libs["cublas"]),
+            "%{cublasLt_lib}": _basename(repository_ctx, cuda_libs["cublasLt"]),
             "%{cusolver_lib}": _basename(repository_ctx, cuda_libs["cusolver"]),
             "%{cudnn_lib}": _basename(repository_ctx, cuda_libs["cudnn"]),
             "%{cufft_lib}": _basename(repository_ctx, cuda_libs["cufft"]),
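
For context, the end-to-end flow enabled by this patch is sketched below using an fp32 GEMM as an example. This is illustrative only and not part of the patch: the names `executor`, `stream`, the DeviceMemory<float> buffers `a`, `b`, `c`, and `scratch_allocator` are assumed to exist already, the problem sizes and workspace limit are arbitrary, and error handling is reduced to comments.

// Sketch only: assumes an initialized stream_executor::StreamExecutor* executor,
// a stream_executor::Stream* stream, DeviceMemory<float> buffers a, b, c, and a
// ScratchAllocator* scratch_allocator; all of these names are illustrative.
using stream_executor::HostOrDeviceScalar;
namespace blas = stream_executor::blas;

uint64 m = 128, n = 128, k = 64;  // Example GEMM sizes, column-major layout.

// 1) Create a reusable plan (the cuBLASLt backend requires CUDA >= 11.0).
std::unique_ptr<blas::IBlasLtMatmulPlan> plan = executor->CreateBlasLtMatmulPlan(
    /*ab_type=*/blas::DataType::kF32, /*cd_type=*/blas::DataType::kF32,
    blas::ComputationType::kF32, blas::PointerMode::kHost,
    blas::Transpose::kNoTranspose, blas::Transpose::kNoTranspose, m, n, k,
    /*lda=*/m, /*ldb=*/k, /*ldc=*/m);

// 2) Ask for heuristically ranked algorithms; the first entry is a reasonable
//    default if no autotuning is performed.
std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>> algorithms;
if (plan != nullptr &&
    executor->GetBlasLtMatmulAlgorithms(plan.get(),
                                        /*max_workspace_size=*/1 << 22,
                                        /*max_algorithm_count=*/1,
                                        &algorithms) &&
    !algorithms.empty()) {
  // 3) Enqueue the matmul; the plan and algorithm can be reused across calls.
  HostOrDeviceScalar<float> alpha(1.0f);
  HostOrDeviceScalar<float> beta(0.0f);
  stream->ThenBlasLtMatmul(plan.get(), alpha, a, b, beta, &c,
                           scratch_allocator, algorithms[0].get());
} else {
  // Plan creation or algorithm query failed (e.g. CUDA < 11.0); fall back to
  // the existing ThenBlasGemm path.
}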