Add cublasLt wrappers to stream_executor

- Adds ThenBlasLtMatmul routines that behave similarly to
  ThenBlasGemmWithAlgorithm but call into the cublasLt library and allow
  separation of plan creation and execution.
- A list of heuristically-prioritized opaque algorithm objects can be
  obtained via GetBlasLtMatmulAlgorithms.
- These routines are only supported when the CUDA version is >= 11.0.
Ben Barsdell 2020-07-06 21:09:32 +10:00
parent 8ee3640e16
commit aaea82e6bc
16 changed files with 1741 additions and 3 deletions
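A minimal sketch of the intended call pattern for the routines described above (illustrative only: the helper function, buffer setup, workspace limit, and leading-dimension choices are assumptions, not code from this commit):

```cpp
#include <memory>
#include <vector>

#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/host_or_device_scalar.h"
#include "tensorflow/stream_executor/scratch_allocator.h"
#include "tensorflow/stream_executor/stream.h"

namespace se = stream_executor;

// Hypothetical caller: create a plan once, pick the top heuristic algorithm,
// then execute. All names except the BlasSupport entry points are
// illustrative; `stream`, the buffers, and `allocator` are set up elsewhere.
bool RunPlannedMatmul(se::Stream* stream, se::blas::BlasSupport* blas,
                      const se::DeviceMemory<float>& a,
                      const se::DeviceMemory<float>& b,
                      se::DeviceMemory<float>* c, uint64_t m, uint64_t n,
                      uint64_t k, se::ScratchAllocator* allocator) {
  // Plans are intended to be created once and reused across executions.
  std::unique_ptr<se::blas::IBlasLtMatmulPlan> plan =
      blas->CreateBlasLtMatmulPlan(
          se::blas::DataType::kF32, se::blas::DataType::kF32,
          se::blas::ComputationType::kF32, se::blas::PointerMode::kHost,
          se::blas::Transpose::kNoTranspose, se::blas::Transpose::kNoTranspose,
          m, n, k, /*lda=*/m, /*ldb=*/k, /*ldc=*/m);
  if (!plan) return false;  // Null on failure (e.g. CUDA < 11.0).

  // The returned list is ordered by the library's heuristic; the first entry
  // is a reasonable default when no autotuning is performed.
  std::vector<std::unique_ptr<se::blas::IBlasLtMatmulAlgorithm>> algorithms;
  if (!blas->GetBlasLtMatmulAlgorithms(plan.get(),
                                       /*max_workspace_size=*/1 << 20,
                                       /*max_algorithm_count=*/1,
                                       &algorithms) ||
      algorithms.empty()) {
    return false;
  }

  // Alpha/beta are host scalars because the plan used PointerMode::kHost.
  return blas->DoBlasLtMatmul(stream, plan.get(),
                              se::HostOrDeviceScalar<float>(1.0f), a, b,
                              se::HostOrDeviceScalar<float>(0.0f), c, allocator,
                              algorithms[0].get());
}
```

The plan-creation call here uses the non-batched convenience overload, which, as the diff below shows, forwards to CreateBlasLtMatmulPlanStridedBatched with batch_count = 1 and zero strides.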


@@ -95,5 +95,30 @@ std::ostream& operator<<(std::ostream& os, ComputationType ty) {
return os << ComputationTypeString(ty);
}
string DataTypeString(DataType ty) {
switch (ty) {
case DataType::kF16:
return "f16";
case DataType::kF32:
return "f32";
case DataType::kF64:
return "f64";
case DataType::kI8:
return "i8";
case DataType::kI32:
return "i32";
case DataType::kComplexF32:
return "complex f32";
case DataType::kComplexF64:
return "complex f64";
default:
LOG(FATAL) << "Unknown DataType " << static_cast<int32>(ty);
}
}
std::ostream& operator<<(std::ostream& os, DataType ty) {
return os << DataTypeString(ty);
}
} // namespace blas
} // namespace stream_executor


@@ -101,6 +101,10 @@ enum class ComputationType {
kI32, // 32-bit integer
kComplexF32, // Complex number comprised of two f32s.
kComplexF64, // Complex number comprised of two f64s.
// The below values are only supported for BlasLt routines (both real and
// complex).
kF32FastTF32, // 32-bit floating-point with reduced (>=10-bit) mantissa
kF32FastBF16, // 32-bit floating-point with reduced (7-bit) mantissa
};
// Converts a ComputationType to a string.
@@ -108,6 +112,61 @@ std::string ComputationTypeString(ComputationType ty);
std::ostream &operator<<(std::ostream &os, ComputationType ty);
// Type of the inputs and outputs of a blaslt routine.
enum class DataType {
kF16, // 16-bit floating-point
kF32, // 32-bit floating-point
kF64, // 64-bit floating-point
kI8, // 8-bit integer
kI32, // 32-bit integer
kComplexF32, // Complex number comprised of two f32s
kComplexF64, // Complex number comprised of two f64s
};
// Describes the type of pointers for the scaling factors alpha and beta in
// blaslt routines.
enum class PointerMode {
kHost,
kDevice,
};
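The pointer mode chosen at plan-creation time must match how alpha and beta are later passed to DoBlasLtMatmul via HostOrDeviceScalar (the execution path checks this and fails otherwise). A hedged illustration; the helper name is hypothetical:

```cpp
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/host_or_device_scalar.h"

namespace se = stream_executor;

// Hypothetical helper: a plan built with PointerMode::kHost expects immediate
// host values; one built with PointerMode::kDevice expects device pointers.
se::HostOrDeviceScalar<float> MakeAlpha(
    const se::DeviceMemory<float>* device_value) {
  if (device_value != nullptr) {
    // For PointerMode::kDevice plans: alpha lives in device memory.
    return se::HostOrDeviceScalar<float>(*device_value);
  }
  // For PointerMode::kHost plans: alpha is an immediate host value.
  return se::HostOrDeviceScalar<float>(1.0f);
}
```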
// Converts a DataType to a string.
string DataTypeString(DataType ty);
std::ostream &operator<<(std::ostream &os, DataType ty);
// Converts a compile-time type to a DataType value.
template <typename T>
struct ToDataType {};
template <>
struct ToDataType<Eigen::half> {
static constexpr const DataType value = DataType::kF16;
};
template <>
struct ToDataType<float> {
static constexpr const DataType value = DataType::kF32;
};
template <>
struct ToDataType<double> {
static constexpr const DataType value = DataType::kF64;
};
template <>
struct ToDataType<int8> {
static constexpr const DataType value = DataType::kI8;
};
template <>
struct ToDataType<int32> {
static constexpr const DataType value = DataType::kI32;
};
template <>
struct ToDataType<std::complex<float>> {
static constexpr const DataType value = DataType::kComplexF32;
};
template <>
struct ToDataType<std::complex<double>> {
static constexpr const DataType value = DataType::kComplexF64;
};
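A hypothetical illustration of how templated code can use this trait (not part of the commit); the runtime type checks added later in cuda_blas.cc follow the same pattern:

```cpp
// Hypothetical: check at runtime that a plan's element type matches T,
// assuming the DataType and ToDataType declarations above are in scope.
template <typename T>
bool MatchesDataType(blas::DataType runtime_type) {
  return runtime_type == blas::ToDataType<T>::value;
}
// e.g. MatchesDataType<float>(blas::DataType::kF32) returns true.
```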
// Opaque identifier for an "algorithm" used by a blas routine. This functions
// as a hint to the blas library.
typedef int64 AlgorithmType;
@@ -163,6 +222,19 @@ class AlgorithmConfig {
AlgorithmType algorithm_;
};
struct IBlasLtMatmulPlan {
virtual ~IBlasLtMatmulPlan() {}
};
struct IBlasLtMatmulAlgorithm {
virtual ~IBlasLtMatmulAlgorithm() {}
// Returns the index of the algorithm within the list returned by
// GetBlasLtMatmulAlgorithms.
virtual AlgorithmType index() const = 0;
// Returns the workspace size required by the algorithm in bytes.
virtual size_t workspace_size() const = 0;
};
// BLAS support interface -- this can be derived from a GPU executor when the
// underlying platform has a BLAS library implementation available. See
// StreamExecutor::AsBlas().
@@ -1383,6 +1455,93 @@ class BlasSupport {
const DeviceMemory<std::complex<double>> &a, int lda,
DeviceMemory<std::complex<double>> *b, int ldb) = 0;
// Creates a backend-specific plan object for a blaslt matmul operation, which
// can then be passed to DoBlasLtMatmul(). When possible, plans should be
// created once and reused for multiple calls to DoBlasLtMatmul().
// Returns a null pointer on failure.
std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
blas::DataType ab_type, blas::DataType c_type,
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, int64 lda, int64 ldb, int64 ldc) {
return CreateBlasLtMatmulPlanStridedBatched(
ab_type, c_type, computation_type, pointer_mode, transa, transb, m, n,
k, 1, lda, 0, ldb, 0, ldc, 0);
}
// A more general version of CreateBlasLtMatmulPlan supporting
// batched operations.
virtual std::unique_ptr<blas::IBlasLtMatmulPlan>
CreateBlasLtMatmulPlanStridedBatched(
blas::DataType ab_type, blas::DataType c_type,
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, int batch_count, int64 lda, int64 stride_a, int64 ldb,
int64 stride_b, int64 ldc, int64 stride_c) = 0;
// Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
// returned in the order of increasing estimated compute time according to an
// internal heuristic. The first returned algorithm can be used as the default
// algorithm if no autotuning is to be performed.
virtual bool GetBlasLtMatmulAlgorithms(
const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size,
int max_algorithm_count,
std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>*
out_algorithms) = 0;
// Executes a blaslt matmul operation on the stream. If output_profile_result
// is not nullptr, the operation is profiled, error messages are
// suppressed, and output_profile_result->algorithm() is set to
// algorithm->index().
virtual bool DoBlasLtMatmul(
Stream* stream, const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<int32>& alpha, const DeviceMemory<int8>& a,
const DeviceMemory<int8>& b, const HostOrDeviceScalar<int32>& beta,
DeviceMemory<int32>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
blas::ProfileResult* output_profile_result = nullptr) = 0;
virtual bool DoBlasLtMatmul(
Stream* stream, const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<Eigen::half>& alpha,
const DeviceMemory<Eigen::half>& a, const DeviceMemory<Eigen::half>& b,
const HostOrDeviceScalar<Eigen::half>& beta, DeviceMemory<Eigen::half>* c,
ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
blas::ProfileResult* output_profile_result = nullptr) = 0;
virtual bool DoBlasLtMatmul(
Stream* stream, const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<float>& alpha, const DeviceMemory<float>& a,
const DeviceMemory<float>& b, const HostOrDeviceScalar<float>& beta,
DeviceMemory<float>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
blas::ProfileResult* output_profile_result = nullptr) = 0;
virtual bool DoBlasLtMatmul(
Stream* stream, const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<double>& alpha, const DeviceMemory<double>& a,
const DeviceMemory<double>& b, const HostOrDeviceScalar<double>& beta,
DeviceMemory<double>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
blas::ProfileResult* output_profile_result = nullptr) = 0;
virtual bool DoBlasLtMatmul(
Stream* stream, const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<std::complex<float>>& alpha,
const DeviceMemory<std::complex<float>>& a,
const DeviceMemory<std::complex<float>>& b,
const HostOrDeviceScalar<std::complex<float>>& beta,
DeviceMemory<std::complex<float>>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
blas::ProfileResult* output_profile_result = nullptr) = 0;
virtual bool DoBlasLtMatmul(
Stream* stream, const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<std::complex<double>>& alpha,
const DeviceMemory<std::complex<double>>& a,
const DeviceMemory<std::complex<double>>& b,
const HostOrDeviceScalar<std::complex<double>>& beta,
DeviceMemory<std::complex<double>>* c,
ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
blas::ProfileResult* output_profile_result = nullptr) = 0;
virtual port::Status GetVersion(std::string *version) = 0;
protected:
@@ -2196,6 +2355,65 @@ class BlasSupport {
uint64 n, std::complex<double> alpha, \
const DeviceMemory<std::complex<double>> &a, int lda, \
DeviceMemory<std::complex<double>> *b, int ldb) override; \
std::unique_ptr<blas::IBlasLtMatmulPlan> \
CreateBlasLtMatmulPlanStridedBatched( \
blas::DataType ab_type, blas::DataType cd_type, \
blas::ComputationType computation_type, blas::PointerMode pointer_mode, \
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, \
uint64 k, int batch_count, int64 lda, int64 stride_a, int64 ldb, \
int64 stride_b, int64 ldc, int64 stride_c) override; \
bool GetBlasLtMatmulAlgorithms( \
const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size, \
int max_algorithm_count, \
std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>* \
out_algorithms) override; \
bool DoBlasLtMatmul( \
Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
const HostOrDeviceScalar<int32>& alpha, const DeviceMemory<int8>& a, \
const DeviceMemory<int8>& b, const HostOrDeviceScalar<int32>& beta, \
DeviceMemory<int32>* c, ScratchAllocator* scratch_allocator, \
const blas::IBlasLtMatmulAlgorithm* algorithm, \
blas::ProfileResult* output_profile_result) override; \
bool DoBlasLtMatmul( \
Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
const HostOrDeviceScalar<Eigen::half>& alpha, \
const DeviceMemory<Eigen::half>& a, const DeviceMemory<Eigen::half>& b, \
const HostOrDeviceScalar<Eigen::half>& beta, \
DeviceMemory<Eigen::half>* c, ScratchAllocator* scratch_allocator, \
const blas::IBlasLtMatmulAlgorithm* algorithm, \
blas::ProfileResult* output_profile_result) override; \
bool DoBlasLtMatmul( \
Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
const HostOrDeviceScalar<float>& alpha, const DeviceMemory<float>& a, \
const DeviceMemory<float>& b, const HostOrDeviceScalar<float>& beta, \
DeviceMemory<float>* c, ScratchAllocator* scratch_allocator, \
const blas::IBlasLtMatmulAlgorithm* algorithm, \
blas::ProfileResult* output_profile_result) override; \
bool DoBlasLtMatmul( \
Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
const HostOrDeviceScalar<double>& alpha, const DeviceMemory<double>& a, \
const DeviceMemory<double>& b, const HostOrDeviceScalar<double>& beta, \
DeviceMemory<double>* c, ScratchAllocator* scratch_allocator, \
const blas::IBlasLtMatmulAlgorithm* algorithm, \
blas::ProfileResult* output_profile_result) override; \
bool DoBlasLtMatmul(Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
const HostOrDeviceScalar<std::complex<float>>& alpha, \
const DeviceMemory<std::complex<float>>& a, \
const DeviceMemory<std::complex<float>>& b, \
const HostOrDeviceScalar<std::complex<float>>& beta, \
DeviceMemory<std::complex<float>>* c, \
ScratchAllocator* scratch_allocator, \
const blas::IBlasLtMatmulAlgorithm* algorithm, \
blas::ProfileResult* output_profile_result) override; \
bool DoBlasLtMatmul(Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
const HostOrDeviceScalar<std::complex<double>>& alpha, \
const DeviceMemory<std::complex<double>>& a, \
const DeviceMemory<std::complex<double>>& b, \
const HostOrDeviceScalar<std::complex<double>>& beta, \
DeviceMemory<std::complex<double>>* c, \
ScratchAllocator* scratch_allocator, \
const blas::IBlasLtMatmulAlgorithm* algorithm, \
blas::ProfileResult* output_profile_result) override; \
port::Status GetVersion(std::string *version) override;
} // namespace blas
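Taken together, GetBlasLtMatmulAlgorithms and the output_profile_result parameter of DoBlasLtMatmul support a simple autotuning flow: enumerate candidate algorithms, run each under profiling, and keep the fastest. A hedged sketch under the same assumptions as the earlier example (all non-API names are illustrative):

```cpp
#include <limits>
#include <memory>
#include <vector>

#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/host_or_device_scalar.h"
#include "tensorflow/stream_executor/scratch_allocator.h"
#include "tensorflow/stream_executor/stream.h"

namespace se = stream_executor;

// Hypothetical autotuner over the heuristic-ordered candidate list.
se::blas::AlgorithmType PickBestBlasLtAlgorithm(
    se::Stream* stream, se::blas::BlasSupport* blas,
    const se::blas::IBlasLtMatmulPlan* plan, const se::DeviceMemory<float>& a,
    const se::DeviceMemory<float>& b, se::DeviceMemory<float>* c,
    se::ScratchAllocator* allocator) {
  std::vector<std::unique_ptr<se::blas::IBlasLtMatmulAlgorithm>> candidates;
  if (!blas->GetBlasLtMatmulAlgorithms(plan, /*max_workspace_size=*/1 << 22,
                                       /*max_algorithm_count=*/8,
                                       &candidates) ||
      candidates.empty()) {
    return se::blas::kDefaultAlgorithm;
  }
  se::blas::AlgorithmType best = se::blas::kDefaultAlgorithm;
  float best_ms = std::numeric_limits<float>::infinity();
  for (const auto& algo : candidates) {
    se::blas::ProfileResult profile;
    // Passing a ProfileResult enables timing and suppresses error logging.
    if (blas->DoBlasLtMatmul(stream, plan,
                             se::HostOrDeviceScalar<float>(1.0f), a, b,
                             se::HostOrDeviceScalar<float>(0.0f), c, allocator,
                             algo.get(), &profile) &&
        profile.is_valid() && profile.elapsed_time_in_ms() < best_ms) {
      best_ms = profile.elapsed_time_in_ms();
      best = profile.algorithm();
    }
  }
  return best;
}
```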


@@ -242,6 +242,29 @@ alias(
visibility = ["//visibility:public"],
)
cc_library(
name = "cublasLt_stub",
srcs = if_cuda_is_configured(["cublasLt_stub.cc"]),
textual_hdrs = glob(["cublasLt_*.inc"]),
deps = if_cuda_is_configured([
# LINT.IfChange
"@local_config_cuda//cuda:cublas_headers",
# LINT.ThenChange(//tensorflow/copy.bara.sky:cublasLt_headers)
"@local_config_cuda//cuda:cuda_headers",
"//tensorflow/stream_executor/lib",
"//tensorflow/stream_executor/platform:dso_loader",
]),
)
alias(
name = "cublasLt_lib",
actual = select({
"//tensorflow:oss": ":cublasLt_stub",
"//conditions:default": "@local_config_cuda//cuda:cublasLt",
}),
visibility = ["//visibility:public"],
)
cc_library(
name = "cublas_plugin",
srcs = if_cuda_is_configured(["cuda_blas.cc"]),
@@ -249,6 +272,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = if_cuda_is_configured([
":cublas_lib",
":cublasLt_lib",
":cuda_activation",
":cuda_gpu_executor",
":cuda_platform_id",


@@ -0,0 +1,415 @@
// Auto-generated, do not edit.
extern "C" {
cublasStatus_t CUBLASWINAPI
cublasLtCreate(cublasLtHandle_t *lightHandle) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtHandle_t *);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtCreate");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(lightHandle);
}
cublasStatus_t CUBLASWINAPI
cublasLtDestroy(cublasLtHandle_t lightHandle) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtHandle_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtDestroy");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(lightHandle);
}
size_t CUBLASWINAPI
cublasLtGetVersion(void) {
using FuncPtr = size_t (CUBLASWINAPI *)();
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtGetVersion");
if (!func_ptr) return 0;
return func_ptr();
}
size_t CUBLASWINAPI
cublasLtGetCudartVersion(void) {
using FuncPtr = size_t (CUBLASWINAPI *)();
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtGetCudartVersion");
if (!func_ptr) return 0;
return func_ptr();
}
cublasStatus_t CUBLASWINAPI
cublasLtGetProperty(libraryPropertyType type, int *value) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(libraryPropertyType, int *);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtGetProperty");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(type, value);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatmul(cublasLtHandle_t lightHandle,
cublasLtMatmulDesc_t computeDesc,
const void *alpha, /* host or device pointer */
const void *A,
cublasLtMatrixLayout_t Adesc,
const void *B,
cublasLtMatrixLayout_t Bdesc,
const void *beta, /* host or device pointer */
const void *C,
cublasLtMatrixLayout_t Cdesc,
void *D,
cublasLtMatrixLayout_t Ddesc,
const cublasLtMatmulAlgo_t *algo,
void *workspace,
size_t workspaceSizeInBytes,
cudaStream_t stream) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtHandle_t, cublasLtMatmulDesc_t, const void *, const void *, cublasLtMatrixLayout_t, const void *, cublasLtMatrixLayout_t, const void *, const void *, cublasLtMatrixLayout_t, void *, cublasLtMatrixLayout_t, const cublasLtMatmulAlgo_t *, void *, size_t, cudaStream_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmul");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(lightHandle, computeDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, D, Ddesc, algo, workspace, workspaceSizeInBytes, stream);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatrixTransform(cublasLtHandle_t lightHandle,
cublasLtMatrixTransformDesc_t transformDesc,
const void *alpha, /* host or device pointer */
const void *A,
cublasLtMatrixLayout_t Adesc,
const void *beta, /* host or device pointer */
const void *B,
cublasLtMatrixLayout_t Bdesc,
void *C,
cublasLtMatrixLayout_t Cdesc,
cudaStream_t stream) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtHandle_t, cublasLtMatrixTransformDesc_t, const void *, const void *, cublasLtMatrixLayout_t, const void *, const void *, cublasLtMatrixLayout_t, void *, cublasLtMatrixLayout_t, cudaStream_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatrixTransform");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(lightHandle, transformDesc, alpha, A, Adesc, beta, B, Bdesc, C, Cdesc, stream);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatrixLayoutInit_internal( //
cublasLtMatrixLayout_t matLayout,
size_t size,
cudaDataType type,
uint64_t rows,
uint64_t cols,
int64_t ld) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(//
cublasLtMatrixLayout_t, size_t, cudaDataType, uint64_t, uint64_t, int64_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatrixLayoutInit_internal");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(matLayout, size, type, rows, cols, ld);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatrixLayoutCreate( //
cublasLtMatrixLayout_t *matLayout,
cudaDataType type,
uint64_t rows,
uint64_t cols,
int64_t ld) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(//
cublasLtMatrixLayout_t *, cudaDataType, uint64_t, uint64_t, int64_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatrixLayoutCreate");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(matLayout, type, rows, cols, ld);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatrixLayoutDestroy(cublasLtMatrixLayout_t matLayout) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatrixLayout_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatrixLayoutDestroy");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(matLayout);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatrixLayoutSetAttribute( //
cublasLtMatrixLayout_t matLayout,
cublasLtMatrixLayoutAttribute_t attr,
const void *buf,
size_t sizeInBytes) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(//
cublasLtMatrixLayout_t, cublasLtMatrixLayoutAttribute_t, const void *, size_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatrixLayoutSetAttribute");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(matLayout, attr, buf, sizeInBytes);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatrixLayoutGetAttribute( //
cublasLtMatrixLayout_t matLayout,
cublasLtMatrixLayoutAttribute_t attr,
void *buf,
size_t sizeInBytes,
size_t *sizeWritten) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(//
cublasLtMatrixLayout_t, cublasLtMatrixLayoutAttribute_t, void *, size_t, size_t *);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatrixLayoutGetAttribute");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(matLayout, attr, buf, sizeInBytes, sizeWritten);
}
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescInit_internal( //
cublasLtMatmulDesc_t matmulDesc,
size_t size,
cublasComputeType_t computeType,
cudaDataType_t scaleType) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(//
cublasLtMatmulDesc_t, size_t, cublasComputeType_t, cudaDataType_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulDescInit_internal");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(matmulDesc, size, computeType, scaleType);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatmulDescCreate(cublasLtMatmulDesc_t *matmulDesc, cublasComputeType_t computeType, cudaDataType_t scaleType) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatmulDesc_t *, cublasComputeType_t, cudaDataType_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulDescCreate");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(matmulDesc, computeType, scaleType);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatmulDescDestroy(cublasLtMatmulDesc_t matmulDesc) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatmulDesc_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulDescDestroy");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(matmulDesc);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatmulDescSetAttribute( //
cublasLtMatmulDesc_t matmulDesc,
cublasLtMatmulDescAttributes_t attr,
const void *buf,
size_t sizeInBytes) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(//
cublasLtMatmulDesc_t, cublasLtMatmulDescAttributes_t, const void *, size_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulDescSetAttribute");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(matmulDesc, attr, buf, sizeInBytes);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatmulDescGetAttribute( //
cublasLtMatmulDesc_t matmulDesc,
cublasLtMatmulDescAttributes_t attr,
void *buf,
size_t sizeInBytes,
size_t *sizeWritten) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(//
cublasLtMatmulDesc_t, cublasLtMatmulDescAttributes_t, void *, size_t, size_t *);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulDescGetAttribute");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(matmulDesc, attr, buf, sizeInBytes, sizeWritten);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatrixTransformDescInit_internal(cublasLtMatrixTransformDesc_t transformDesc, size_t size, cudaDataType scaleType) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatrixTransformDesc_t, size_t, cudaDataType);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatrixTransformDescInit_internal");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(transformDesc, size, scaleType);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatrixTransformDescCreate(cublasLtMatrixTransformDesc_t *transformDesc, cudaDataType scaleType) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatrixTransformDesc_t *, cudaDataType);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatrixTransformDescCreate");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(transformDesc, scaleType);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatrixTransformDescDestroy(cublasLtMatrixTransformDesc_t transformDesc) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatrixTransformDesc_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatrixTransformDescDestroy");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(transformDesc);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatrixTransformDescSetAttribute( //
cublasLtMatrixTransformDesc_t transformDesc,
cublasLtMatrixTransformDescAttributes_t attr,
const void *buf,
size_t sizeInBytes) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(//
cublasLtMatrixTransformDesc_t, cublasLtMatrixTransformDescAttributes_t, const void *, size_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatrixTransformDescSetAttribute");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(transformDesc, attr, buf, sizeInBytes);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatrixTransformDescGetAttribute( //
cublasLtMatrixTransformDesc_t transformDesc,
cublasLtMatrixTransformDescAttributes_t attr,
void *buf,
size_t sizeInBytes,
size_t *sizeWritten) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(//
cublasLtMatrixTransformDesc_t, cublasLtMatrixTransformDescAttributes_t, void *, size_t, size_t *);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatrixTransformDescGetAttribute");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(transformDesc, attr, buf, sizeInBytes, sizeWritten);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatmulPreferenceInit_internal(cublasLtMatmulPreference_t pref, size_t size) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatmulPreference_t, size_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulPreferenceInit_internal");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(pref, size);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatmulPreferenceCreate(cublasLtMatmulPreference_t *pref) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatmulPreference_t *);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulPreferenceCreate");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(pref);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatmulPreferenceDestroy(cublasLtMatmulPreference_t pref) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatmulPreference_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulPreferenceDestroy");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(pref);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatmulPreferenceSetAttribute( //
cublasLtMatmulPreference_t pref,
cublasLtMatmulPreferenceAttributes_t attr,
const void *buf,
size_t sizeInBytes) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(//
cublasLtMatmulPreference_t, cublasLtMatmulPreferenceAttributes_t, const void *, size_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulPreferenceSetAttribute");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(pref, attr, buf, sizeInBytes);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatmulPreferenceGetAttribute( //
cublasLtMatmulPreference_t pref,
cublasLtMatmulPreferenceAttributes_t attr,
void *buf,
size_t sizeInBytes,
size_t *sizeWritten) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(//
cublasLtMatmulPreference_t, cublasLtMatmulPreferenceAttributes_t, void *, size_t, size_t *);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulPreferenceGetAttribute");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(pref, attr, buf, sizeInBytes, sizeWritten);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatmulAlgoGetHeuristic(
cublasLtHandle_t lightHandle,
cublasLtMatmulDesc_t operationDesc,
cublasLtMatrixLayout_t Adesc,
cublasLtMatrixLayout_t Bdesc,
cublasLtMatrixLayout_t Cdesc,
cublasLtMatrixLayout_t Ddesc,
cublasLtMatmulPreference_t preference,
int requestedAlgoCount,
cublasLtMatmulHeuristicResult_t heuristicResultsArray[],
int *returnAlgoCount) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtHandle_t, cublasLtMatmulDesc_t, cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, cublasLtMatmulPreference_t, int, cublasLtMatmulHeuristicResult_t [], int *);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulAlgoGetHeuristic");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(lightHandle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, requestedAlgoCount, heuristicResultsArray, returnAlgoCount);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatmulAlgoGetIds(
cublasLtHandle_t lightHandle,
cublasComputeType_t computeType,
cudaDataType_t scaleType,
cudaDataType_t Atype,
cudaDataType_t Btype,
cudaDataType_t Ctype,
cudaDataType_t Dtype,
int requestedAlgoCount,
int algoIdsArray[],
int *returnAlgoCount) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtHandle_t, cublasComputeType_t, cudaDataType_t, cudaDataType_t, cudaDataType_t, cudaDataType_t, cudaDataType_t, int, int [], int *);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulAlgoGetIds");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(lightHandle, computeType, scaleType, Atype, Btype, Ctype, Dtype, requestedAlgoCount, algoIdsArray, returnAlgoCount);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatmulAlgoInit ( cublasLtHandle_t lightHandle,
cublasComputeType_t computeType,
cudaDataType_t scaleType,
cudaDataType_t Atype,
cudaDataType_t Btype,
cudaDataType_t Ctype,
cudaDataType_t Dtype,
int algoId,
cublasLtMatmulAlgo_t *algo) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtHandle_t, cublasComputeType_t, cudaDataType_t, cudaDataType_t, cudaDataType_t, cudaDataType_t, cudaDataType_t, int, cublasLtMatmulAlgo_t *);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulAlgoInit");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(lightHandle, computeType, scaleType, Atype, Btype, Ctype, Dtype, algoId, algo);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatmulAlgoCheck( //
cublasLtHandle_t lightHandle,
cublasLtMatmulDesc_t operationDesc,
cublasLtMatrixLayout_t Adesc,
cublasLtMatrixLayout_t Bdesc,
cublasLtMatrixLayout_t Cdesc,
cublasLtMatrixLayout_t Ddesc,
const cublasLtMatmulAlgo_t *algo, ///< may point to result->algo
cublasLtMatmulHeuristicResult_t *result) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(//
cublasLtHandle_t, cublasLtMatmulDesc_t, cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, const cublasLtMatmulAlgo_t *, ///< may point to result->algo
cublasLtMatmulHeuristicResult_t *);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulAlgoCheck");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(lightHandle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, algo, result);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatmulAlgoCapGetAttribute(
const cublasLtMatmulAlgo_t *algo,
cublasLtMatmulAlgoCapAttributes_t attr,
void *buf,
size_t sizeInBytes,
size_t *sizeWritten) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(const cublasLtMatmulAlgo_t *, cublasLtMatmulAlgoCapAttributes_t, void *, size_t, size_t *);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulAlgoCapGetAttribute");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(algo, attr, buf, sizeInBytes, sizeWritten);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatmulAlgoConfigSetAttribute(
cublasLtMatmulAlgo_t *algo,
cublasLtMatmulAlgoConfigAttributes_t attr,
const void *buf,
size_t sizeInBytes) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatmulAlgo_t *, cublasLtMatmulAlgoConfigAttributes_t, const void *, size_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulAlgoConfigSetAttribute");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(algo, attr, buf, sizeInBytes);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatmulAlgoConfigGetAttribute(
const cublasLtMatmulAlgo_t *algo,
cublasLtMatmulAlgoConfigAttributes_t attr,
void *buf,
size_t sizeInBytes,
size_t *sizeWritten) {
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(const cublasLtMatmulAlgo_t *, cublasLtMatmulAlgoConfigAttributes_t, void *, size_t, size_t *);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulAlgoConfigGetAttribute");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(algo, attr, buf, sizeInBytes, sizeWritten);
}
} // extern "C"


@@ -0,0 +1,59 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/gpus/cuda/include/cublasLt.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/platform/dso_loader.h"
// Implements the cuBLASLt API by forwarding to cuBLASLt loaded from the DSO.
namespace {
// Returns DSO handle or null if loading the DSO fails.
void* GetDsoHandle() {
#ifdef PLATFORM_GOOGLE
return nullptr;
#else
static auto handle = []() -> void* {
auto handle_or =
stream_executor::internal::DsoLoader::GetCublasLtDsoHandle();
if (!handle_or.ok()) return nullptr;
return handle_or.ValueOrDie();
}();
return handle;
#endif
}
template <typename T>
T LoadSymbol(const char* symbol_name) {
void* symbol = nullptr;
if (auto handle = GetDsoHandle()) {
stream_executor::port::Env::Default()
->GetSymbolFromLibrary(handle, symbol_name, &symbol)
.IgnoreError();
}
return reinterpret_cast<T>(symbol);
}
void LogFatalSymbolNotFound(const char* symbol_name) {
LOG(FATAL) << symbol_name << " symbol not found.";
}
cublasStatus_t GetSymbolNotFoundError() { return CUBLAS_STATUS_INTERNAL_ERROR; }
} // namespace
// We only use cublasLt from CUDA 11.0 onward.
#if CUDA_VERSION >= 11000
#include "tensorflow/stream_executor/cuda/cublasLt_11_0.inc"
#endif


@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cublasLt.h"
#include "third_party/gpus/cuda/include/cuda.h"
#define SE_CUDA_DATA_HALF CUDA_R_16F
@@ -226,17 +227,38 @@ bool CUDABlas::Init() {
return false;
}
#if CUDA_VERSION >= 11000
ret = cublasLtCreate(&blasLt_);
if (ret != CUBLAS_STATUS_SUCCESS) {
LOG(ERROR) << "failed to create cublasLt handle: " << ToString(ret);
return false;
}
#endif // CUDA_VERSION >= 11000
return true;
}
CUDABlas::CUDABlas(gpu::GpuExecutor* parent)
: parent_(CHECK_NOTNULL(parent)),
blas_(nullptr)
#if CUDA_VERSION >= 11000
,
blasLt_(nullptr)
#endif
{
}
CUDABlas::~CUDABlas() {
if (blas_ != nullptr) {
gpu::ScopedActivateExecutorContext sac{parent_};
cublasDestroy(blas_);
}
#if CUDA_VERSION >= 11000
if (blasLt_ != nullptr) {
gpu::ScopedActivateExecutorContext sac{parent_};
cublasLtDestroy(blasLt_);
}
#endif
}
bool CUDABlas::SetStream(Stream *stream) {
@@ -253,6 +275,13 @@ bool CUDABlas::SetStream(Stream *stream) {
return true;
}
cudaStream_t CUDABlas::CUDAStream(Stream* stream) {
CHECK(stream != nullptr);
CHECK(AsGpuStreamValue(stream) != nullptr);
gpu::ScopedActivateExecutorContext sac{parent_};
return AsGpuStreamValue(stream);
}
namespace {
// Helper functions transforming blas arguments into cuBLAS arguments.
@@ -381,6 +410,82 @@ cudaDataType_t CUDAComputationType(blas::ComputationType ty) {
return CUDA_C_32F;
case blas::ComputationType::kComplexF64:
return CUDA_C_64F;
case blas::ComputationType::kF32FastTF32: // fall-through
case blas::ComputationType::kF32FastBF16:
// These cases are currently only supported in the blasLt routines, which
// use CUBLASComputationType() instead.
LOG(FATAL) << "Invalid value of blas::ComputationType.";
}
}
#if CUDA_VERSION >= 11000
cublasComputeType_t CUBLASComputationType(blas::ComputationType ty) {
switch (ty) {
case blas::ComputationType::kF16:
return CUBLAS_COMPUTE_16F;
case blas::ComputationType::kF32: // fall-through
case blas::ComputationType::kComplexF32:
return CUBLAS_COMPUTE_32F;
case blas::ComputationType::kF64: // fall-through
case blas::ComputationType::kComplexF64:
return CUBLAS_COMPUTE_64F;
case blas::ComputationType::kI32:
return CUBLAS_COMPUTE_32I;
case blas::ComputationType::kF32FastTF32:
return CUBLAS_COMPUTE_32F_FAST_TF32;
case blas::ComputationType::kF32FastBF16:
return CUBLAS_COMPUTE_32F_FAST_16BF;
}
}
#endif // CUDA_VERSION >= 11000
blas::DataType GetScaleType(blas::DataType data_type,
blas::ComputationType compute_type) {
bool is_complex = data_type == blas::DataType::kComplexF32 ||
data_type == blas::DataType::kComplexF64;
switch (compute_type) {
case blas::ComputationType::kF16:
return blas::DataType::kF16;
case blas::ComputationType::kF32: // fall-through
case blas::ComputationType::kComplexF32: // fall-through
case blas::ComputationType::kF32FastTF32: // fall-through
case blas::ComputationType::kF32FastBF16:
return is_complex ? blas::DataType::kComplexF32 : blas::DataType::kF32;
case blas::ComputationType::kF64: // fall-through
case blas::ComputationType::kComplexF64:
return is_complex ? blas::DataType::kComplexF64 : blas::DataType::kF64;
case blas::ComputationType::kI32:
return blas::DataType::kI32;
}
}
#if CUDA_VERSION >= 11000
cublasLtPointerMode_t CUBLASPointerMode(blas::PointerMode pointer_mode) {
switch (pointer_mode) {
case blas::PointerMode::kHost:
return CUBLASLT_POINTER_MODE_HOST;
case blas::PointerMode::kDevice:
return CUBLASLT_POINTER_MODE_DEVICE;
}
}
#endif // CUDA_VERSION >= 11000
cudaDataType_t GetCUDADataType(blas::DataType ty) {
switch (ty) {
case blas::DataType::kF16:
return CUDA_R_16F;
case blas::DataType::kF32:
return CUDA_R_32F;
case blas::DataType::kF64:
return CUDA_R_64F;
case blas::DataType::kI8:
return CUDA_R_8I;
case blas::DataType::kI32:
return CUDA_R_32I;
case blas::DataType::kComplexF32:
return CUDA_C_32F;
case blas::DataType::kComplexF64:
return CUDA_C_64F;
}
}
} // namespace
@@ -2912,6 +3017,577 @@ bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
GpuComplex(GpuMemoryMutable(b)), ldb);
}
// We only use cublasLt from CUDA 11.0 onward.
#if CUDA_VERSION >= 11000
namespace {
template <typename T>
inline bool SetCublasLtAttr(cublasLtMatrixLayout_t handle,
cublasLtMatrixLayoutAttribute_t attr,
const T& value) {
cublasStatus_t status =
cublasLtMatrixLayoutSetAttribute(handle, attr, &value, sizeof(T));
if (status != CUBLAS_STATUS_SUCCESS) {
VLOG(2) << "cublasLtMatrixLayoutSetAttribute(attr=" << attr
<< ", value=" << value << ") failed: " << ToString(status);
return false;
}
return true;
}
template <typename T>
inline bool SetCublasLtAttr(cublasLtMatmulAlgo_t* handle,
cublasLtMatmulAlgoConfigAttributes_t attr,
const T& value) {
cublasStatus_t status =
cublasLtMatmulAlgoConfigSetAttribute(handle, attr, &value, sizeof(T));
if (status != CUBLAS_STATUS_SUCCESS) {
VLOG(2) << "cublasLtMatmulAlgoConfigSetAttribute(attr=" << attr
<< ", value=" << value << ") failed: " << ToString(status);
return false;
}
return true;
}
template <typename T>
inline bool SetCublasLtAttr(cublasLtMatmulPreference_t handle,
cublasLtMatmulPreferenceAttributes_t attr,
const T& value) {
cublasStatus_t status =
cublasLtMatmulPreferenceSetAttribute(handle, attr, &value, sizeof(value));
if (status != CUBLAS_STATUS_SUCCESS) {
VLOG(2) << "cublasLtMatmulPreferenceSetAttribute(attr=" << attr
<< ", value=" << value << ") failed: " << ToString(status);
return false;
}
return true;
}
template <typename T>
inline bool GetCublasLtAttr(const cublasLtMatmulAlgo_t* handle,
cublasLtMatmulAlgoConfigAttributes_t attr,
T* value) {
auto mutable_handle = const_cast<cublasLtMatmulAlgo_t*>(handle);
size_t bytes_written = 0;
return cublasLtMatmulAlgoConfigGetAttribute(mutable_handle, attr, value,
sizeof(T), &bytes_written) ==
CUBLAS_STATUS_SUCCESS &&
bytes_written == sizeof(T);
}
template <typename T>
inline bool SetCublasLtAttr(cublasLtMatmulDesc_t handle,
cublasLtMatmulDescAttributes_t attr,
const T& value) {
cublasStatus_t status =
cublasLtMatmulDescSetAttribute(handle, attr, &value, sizeof(value));
if (status != CUBLAS_STATUS_SUCCESS) {
VLOG(2) << "cublasLtMatmulDescSetAttribute(attr=" << attr
<< ", value=" << value << ") failed: " << ToString(status);
return false;
}
return true;
}
struct MatmulDescDestroyer {
void operator()(cublasLtMatmulDesc_t matmul_desc) const {
cublasLtMatmulDescDestroy(matmul_desc);
}
};
struct LayoutDestroyer {
void operator()(cublasLtMatrixLayout_t layout) const {
cublasLtMatrixLayoutDestroy(layout);
}
};
struct MatmulPreferenceDestroyer {
void operator()(cublasLtMatmulPreference_t matmul_pref) const {
cublasLtMatmulPreferenceDestroy(matmul_pref);
}
};
using UniqueOpDesc =
std::unique_ptr<std::remove_pointer<cublasLtMatmulDesc_t>::type,
MatmulDescDestroyer>;
using UniqueLayoutDesc =
std::unique_ptr<std::remove_pointer<cublasLtMatrixLayout_t>::type,
LayoutDestroyer>;
using UniqueMatmulPreference =
std::unique_ptr<std::remove_pointer<cublasLtMatmulPreference_t>::type,
MatmulPreferenceDestroyer>;
UniqueOpDesc CreateCublasLtOperationDesc(
blas::ComputationType computation_type, blas::DataType scale_type,
blas::PointerMode pointer_mode, blas::Transpose transa,
blas::Transpose transb) {
cublasOperation_t cuda_transa = CUDABlasTranspose(transa);
cublasOperation_t cuda_transb = CUDABlasTranspose(transb);
cublasLtMatmulDesc_t desc;
cublasComputeType_t cublas_compute_type =
CUBLASComputationType(computation_type);
cudaDataType_t cuda_scale_type = GetCUDADataType(scale_type);
cublasStatus_t status =
cublasLtMatmulDescCreate(&desc, cublas_compute_type, cuda_scale_type);
if (status != CUBLAS_STATUS_SUCCESS) {
VLOG(2) << "cublasLtMatmulDescCreate(computation_type=" << computation_type
<< ") failed: " << ToString(status);
return nullptr;
}
UniqueOpDesc unique_desc(desc);
if (!SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_POINTER_MODE,
CUBLASPointerMode(pointer_mode)) ||
!SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSA, cuda_transa) ||
!SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSB, cuda_transb)) {
return nullptr;
}
return unique_desc;
}
UniqueLayoutDesc CreateCublasLtLayoutDesc(blas::DataType data_type, uint64 rows,
uint64 cols, int64 ld, int64 stride,
int batch_count) {
cublasLtMatrixLayout_t desc;
cublasStatus_t status = cublasLtMatrixLayoutCreate(
&desc, GetCUDADataType(data_type), rows, cols, ld);
if (status != CUBLAS_STATUS_SUCCESS) {
VLOG(2) << "cublasLtMatrixLayoutCreate failed: " << ToString(status);
return nullptr;
}
UniqueLayoutDesc unique_desc(desc);
if (!SetCublasLtAttr(desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_count) ||
!SetCublasLtAttr(desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
stride)) {
return nullptr;
}
return unique_desc;
}
UniqueMatmulPreference CreateCublasLtMatmulPreference(
size_t max_workspace_bytes) {
cublasLtMatmulPreference_t preference;
cublasStatus_t status = cublasLtMatmulPreferenceCreate(&preference);
if (status != CUBLAS_STATUS_SUCCESS) {
VLOG(2) << "cublasLtMatmulPreferenceCreate failed: " << ToString(status);
return nullptr;
}
UniqueMatmulPreference unique_preference(preference);
if (!SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
max_workspace_bytes)) {
return nullptr;
}
return unique_preference;
}
// Helper function to allocate workspace.
port::Status AllocateWorkspace(void** workspace,
ScratchAllocator* scratch_allocator,
size_t num_bytes) {
SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace_bytes,
scratch_allocator->AllocateBytes(num_bytes));
*workspace = (void*)GpuMemoryMutable(&workspace_bytes);
return port::Status::OK();
}
template <typename T>
blas::ComputationType ToComputationType();
template <>
blas::ComputationType ToComputationType<Eigen::half>() {
return blas::ComputationType::kF16;
}
template <>
blas::ComputationType ToComputationType<float>() {
return blas::ComputationType::kF32;
}
template <>
blas::ComputationType ToComputationType<double>() {
return blas::ComputationType::kF64;
}
template <>
blas::ComputationType ToComputationType<std::complex<float>>() {
return blas::ComputationType::kComplexF32;
}
template <>
blas::ComputationType ToComputationType<std::complex<double>>() {
return blas::ComputationType::kComplexF64;
}
class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
public:
CUDABlasLtMatmulPlan(blas::DataType ab_type, blas::DataType cd_type,
blas::ComputationType compute_type,
blas::PointerMode pointer_mode, blas::Transpose transa,
blas::Transpose transb, uint64 m, uint64 n, uint64 k,
int batch_count, int64 lda, int64 stride_a, int64 ldb,
int64 stride_b, int64 ldc, int64 stride_c, int64 ldd,
int64 stride_d);
cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); }
cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); }
cublasLtMatrixLayout_t b_desc() const { return b_desc_.get(); }
cublasLtMatrixLayout_t c_desc() const { return c_desc_.get(); }
cublasLtMatrixLayout_t d_desc() const { return d_desc_.get(); }
bool ok() { return op_desc_ && a_desc_ && b_desc_ && c_desc_ && d_desc_; }
blas::DataType ab_type() const { return ab_type_; }
blas::DataType cd_type() const { return cd_type_; }
blas::DataType scale_type() const { return scale_type_; }
blas::PointerMode pointer_mode() const { return pointer_mode_; }
private:
UniqueOpDesc op_desc_;
UniqueLayoutDesc a_desc_;
UniqueLayoutDesc b_desc_;
UniqueLayoutDesc c_desc_;
UniqueLayoutDesc d_desc_;
blas::DataType ab_type_;
blas::DataType cd_type_;
blas::DataType scale_type_;
blas::PointerMode pointer_mode_;
};
CUDABlasLtMatmulPlan::CUDABlasLtMatmulPlan(
blas::DataType ab_type, blas::DataType cd_type,
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, int batch_count, int64 lda, int64 stride_a, int64 ldb,
int64 stride_b, int64 ldc, int64 stride_c, int64 ldd, int64 stride_d)
: op_desc_(CreateCublasLtOperationDesc(
computation_type, GetScaleType(cd_type, computation_type),
pointer_mode, transa, transb)),
a_desc_(nullptr),
b_desc_(nullptr),
c_desc_(
CreateCublasLtLayoutDesc(cd_type, m, n, ldc, stride_c, batch_count)),
d_desc_(
CreateCublasLtLayoutDesc(cd_type, m, n, ldd, stride_d, batch_count)),
ab_type_(ab_type),
cd_type_(cd_type),
scale_type_(GetScaleType(cd_type, computation_type)),
pointer_mode_(pointer_mode) {
uint64 rows_a = transa == blas::Transpose::kNoTranspose ? m : k;
uint64 cols_a = transa == blas::Transpose::kNoTranspose ? k : m;
uint64 rows_b = transb == blas::Transpose::kNoTranspose ? k : n;
uint64 cols_b = transb == blas::Transpose::kNoTranspose ? n : k;
a_desc_ = CreateCublasLtLayoutDesc(ab_type, rows_a, cols_a, lda, stride_a,
batch_count);
b_desc_ = CreateCublasLtLayoutDesc(ab_type, rows_b, cols_b, ldb, stride_b,
batch_count);
}
class CUDABlasLtMatmulAlgorithm final : public blas::IBlasLtMatmulAlgorithm {
public:
CUDABlasLtMatmulAlgorithm(blas::AlgorithmType index,
cublasLtMatmulAlgo_t algo, size_t workspace_size)
: index_(index), algo_(algo), workspace_size_(workspace_size) {}
blas::AlgorithmType index() const override { return index_; }
size_t workspace_size() const override { return workspace_size_; }
const cublasLtMatmulAlgo_t* algo() const { return &algo_; }
int algo_id() const {
int id;
GetCublasLtAttr(&algo_, CUBLASLT_ALGO_CONFIG_ID, &id);
return id;
}
private:
blas::AlgorithmType index_;
cublasLtMatmulAlgo_t algo_;
size_t workspace_size_;
};
} // namespace
#endif // CUDA_VERSION >= 11000
std::unique_ptr<blas::IBlasLtMatmulPlan>
CUDABlas::CreateBlasLtMatmulPlanStridedBatched(
blas::DataType ab_type, blas::DataType cd_type,
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, int batch_count, int64 lda, int64 stride_a, int64 ldb,
int64 stride_b, int64 ldc, int64 stride_c) {
#if CUDA_VERSION >= 11000
auto result = std::make_unique<CUDABlasLtMatmulPlan>(
ab_type, cd_type, computation_type, pointer_mode, transa, transb, m, n, k,
batch_count, lda, stride_a, ldb, stride_b, ldc, stride_c, ldc, stride_c);
if (!result->ok()) {
result.reset();
}
return result;
#else
return nullptr;
#endif
}
bool CUDABlas::GetBlasLtMatmulAlgorithms(
const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size,
int max_algorithm_count,
std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>*
out_algorithms) {
#if CUDA_VERSION >= 11000
UniqueMatmulPreference preference =
CreateCublasLtMatmulPreference(max_workspace_size);
if (!preference) return false;
std::vector<cublasLtMatmulHeuristicResult_t> results(max_algorithm_count);
{
absl::MutexLock lock(&mu_);
CHECK(blasLt_ != nullptr);
gpu::ScopedActivateExecutorContext sac{parent_};
int found_algorithm_count = 0;
const auto& cuda_plan = *static_cast<const CUDABlasLtMatmulPlan*>(plan);
cublasStatus_t status = cublasLtMatmulAlgoGetHeuristic(
blasLt_, cuda_plan.op_desc(), cuda_plan.a_desc(), cuda_plan.b_desc(),
cuda_plan.c_desc(), cuda_plan.d_desc(), preference.get(),
max_algorithm_count, results.data(), &found_algorithm_count);
if (status != CUBLAS_STATUS_SUCCESS) {
VLOG(2) << "cublasLtMatmulAlgoGetHeuristic failed: " << ToString(status);
return false;
}
results.resize(found_algorithm_count);
}
for (size_t i = 0; i < results.size(); ++i) {
const auto& result = results[i];
if (result.state != CUBLAS_STATUS_SUCCESS) continue; // Skip failed algos
out_algorithms->emplace_back(std::make_unique<CUDABlasLtMatmulAlgorithm>(
i, result.algo, result.workspaceSize));
}
return true;
#else // if CUDA_VERSION < 11000
return false;
#endif
}
#if CUDA_VERSION >= 11000
template <typename ABType, typename CDType, typename ScaleType>
bool CUDABlas::DoBlasLtMatmulInternalImpl(
Stream* stream, bool err_on_failure, const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<ScaleType>& alpha, const ABType* a,
const ABType* b, const HostOrDeviceScalar<ScaleType>& beta, const CDType* c,
CDType* d, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm) {
const auto& cuda_plan = *static_cast<const CUDABlasLtMatmulPlan*>(plan);
const auto& cuda_algo =
*static_cast<const CUDABlasLtMatmulAlgorithm*>(algorithm);
if (cuda_plan.ab_type() != blas::ToDataType<ABType>::value) {
VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong ab_type: "
"expected "
<< blas::ToDataType<ABType>::value << ", got "
<< cuda_plan.ab_type();
return false;
}
if (cuda_plan.cd_type() != blas::ToDataType<CDType>::value) {
VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong cd_type: "
"expected "
<< blas::ToDataType<CDType>::value << ", got "
<< cuda_plan.cd_type();
return false;
}
if (cuda_plan.scale_type() != blas::ToDataType<ScaleType>::value) {
VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong "
"scale_type: expected "
<< blas::ToDataType<ScaleType>::value << ", got "
<< cuda_plan.scale_type();
return false;
}
if (alpha.is_pointer() != beta.is_pointer()) {
VLOG(2) << "DoBlasLtMatmul returning false because one of `alpha` "
"and `beta` is a pointer, but the other is not.";
return false;
}
bool is_pointer_mode_host = !alpha.is_pointer();
if ((cuda_plan.pointer_mode() == blas::PointerMode::kHost) !=
is_pointer_mode_host) {
VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong "
"pointer_mode for the given alpha/beta.";
return false;
}
const ScaleType* alpha_ptr =
alpha.is_pointer() ? GpuMemory(alpha.pointer()) : &alpha.value();
const ScaleType* beta_ptr =
beta.is_pointer() ? GpuMemory(beta.pointer()) : &beta.value();
void* workspace = nullptr;
if (cuda_algo.workspace_size()) {
port::Status allocation_status = AllocateWorkspace(
&workspace, scratch_allocator, cuda_algo.workspace_size());
if (!allocation_status.ok()) {
if (err_on_failure || VLOG_IS_ON(3)) {
LOG(ERROR)
<< "Failed to allocate workspace for cublasLtMatmul algo with id: "
<< cuda_algo.algo_id() << " requiring "
<< cuda_algo.workspace_size() << " bytes of workspace";
}
return false;
}
}
cudaStream_t cuda_stream = CUDAStream(stream);
absl::MutexLock lock(&mu_);
CHECK(blasLt_ != nullptr);
gpu::ScopedActivateExecutorContext sac{parent_};
cublasStatus_t ret = cublasLtMatmul(
blasLt_, cuda_plan.op_desc(), alpha_ptr, a, cuda_plan.a_desc(), b,
cuda_plan.b_desc(), beta_ptr, c, cuda_plan.c_desc(), d,
cuda_plan.d_desc(), cuda_algo.algo(), workspace,
cuda_algo.workspace_size(), cuda_stream);
if (ret != CUBLAS_STATUS_SUCCESS) {
if (err_on_failure || VLOG_IS_ON(3)) {
LOG(ERROR) << "failed to run cublasLtMatmul routine: " << ToString(ret);
}
return false;
}
return true;
}
#endif // CUDA_VERSION >= 11000
template <typename ABType, typename CDType, typename ScaleType>
bool CUDABlas::DoBlasLtMatmulInternal(
Stream* stream, const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<ScaleType>& alpha, const DeviceMemory<ABType>& a,
const DeviceMemory<ABType>& b, const HostOrDeviceScalar<ScaleType>& beta,
const DeviceMemory<CDType>& c, DeviceMemory<CDType>* d,
ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
blas::ProfileResult* output_profile_result) {
#if CUDA_VERSION >= 11000
std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
if (output_profile_result) {
timer.reset(new GpuTimer(parent_));
if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
return false;
}
}
bool err_on_failure = timer != nullptr;
bool result = DoBlasLtMatmulInternalImpl(
stream, err_on_failure, plan, alpha, GpuMemory(a), GpuMemory(b), beta,
GpuMemory(c), GpuMemoryMutable(d), scratch_allocator, algorithm);
if (timer && result) {
// GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
// state.
if (!timer->Stop(AsGpuStream(stream))) {
return false;
}
output_profile_result->set_is_valid(true);
output_profile_result->set_algorithm(algorithm->index());
output_profile_result->set_elapsed_time_in_ms(
timer->GetElapsedMilliseconds());
}
return result;
#else // if CUDA_VERSION < 11000
return false;
#endif
}
bool CUDABlas::DoBlasLtMatmul(
Stream* stream, const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<int32>& alpha, const DeviceMemory<int8>& a,
const DeviceMemory<int8>& b, const HostOrDeviceScalar<int32>& beta,
DeviceMemory<int32>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
blas::ProfileResult* output_profile_result) {
return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
scratch_allocator, algorithm,
output_profile_result);
}
bool CUDABlas::DoBlasLtMatmul(Stream* stream,
const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<Eigen::half>& alpha,
const DeviceMemory<Eigen::half>& a,
const DeviceMemory<Eigen::half>& b,
const HostOrDeviceScalar<Eigen::half>& beta,
DeviceMemory<Eigen::half>* c,
ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
blas::ProfileResult* output_profile_result) {
#if CUDA_VERSION >= 11000
const auto& cuda_plan = *static_cast<const CUDABlasLtMatmulPlan*>(plan);
if (cuda_plan.scale_type() == blas::DataType::kF32) {
// F32* computation types require F32 alpha/beta type, so we must cast them.
if (alpha.is_pointer() || beta.is_pointer()) {
// We cannot easily convert a pointer to f16 memory to a pointer to f32
// memory from here, so we don't support this for now.
return false;
}
HostOrDeviceScalar<float> float_alpha(static_cast<float>(alpha.value()));
HostOrDeviceScalar<float> float_beta(static_cast<float>(beta.value()));
return DoBlasLtMatmulInternal(stream, plan, float_alpha, a, b, float_beta,
*c, c, scratch_allocator, algorithm,
output_profile_result);
}
return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
scratch_allocator, algorithm,
output_profile_result);
#else // if CUDA_VERSION < 11000
return false;
#endif
}
bool CUDABlas::DoBlasLtMatmul(
Stream* stream, const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<float>& alpha, const DeviceMemory<float>& a,
const DeviceMemory<float>& b, const HostOrDeviceScalar<float>& beta,
DeviceMemory<float>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
blas::ProfileResult* output_profile_result) {
return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
scratch_allocator, algorithm,
output_profile_result);
}
bool CUDABlas::DoBlasLtMatmul(
Stream* stream, const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<double>& alpha, const DeviceMemory<double>& a,
const DeviceMemory<double>& b, const HostOrDeviceScalar<double>& beta,
DeviceMemory<double>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
blas::ProfileResult* output_profile_result) {
return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
scratch_allocator, algorithm,
output_profile_result);
}
bool CUDABlas::DoBlasLtMatmul(
Stream* stream, const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<std::complex<float>>& alpha,
const DeviceMemory<std::complex<float>>& a,
const DeviceMemory<std::complex<float>>& b,
const HostOrDeviceScalar<std::complex<float>>& beta,
DeviceMemory<std::complex<float>>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
blas::ProfileResult* output_profile_result) {
return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
scratch_allocator, algorithm,
output_profile_result);
}
bool CUDABlas::DoBlasLtMatmul(
Stream* stream, const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<std::complex<double>>& alpha,
const DeviceMemory<std::complex<double>>& a,
const DeviceMemory<std::complex<double>>& b,
const HostOrDeviceScalar<std::complex<double>>& beta,
DeviceMemory<std::complex<double>>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
blas::ProfileResult* output_profile_result) {
return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
scratch_allocator, algorithm,
output_profile_result);
}
port::Status CUDABlas::GetVersion(std::string *version) {
absl::MutexLock lock(&mu_);


@@ -22,6 +22,8 @@ limitations under the License.
#include "absl/synchronization/mutex.h"
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cublasLt.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/host_or_device_scalar.h"
@@ -71,6 +73,9 @@ class CUDABlas : public blas::BlasSupport {
// invoked before calling into cuBLAS.
bool SetStream(Stream *stream) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Returns the underlying CUDA stream.
cudaStream_t CUDAStream(Stream* stream);
// A helper function that calls the real cuBLAS function together with error
// handling.
//
@@ -134,6 +139,26 @@ class CUDABlas : public blas::BlasSupport {
const T &beta, DeviceMemory<T> *y, int incy,
blas::ProfileResult *output_profile_result);
// Helper function for implementing DoBlasLtMatmul.
template <typename ABType, typename CDType, typename ScaleType>
bool DoBlasLtMatmulInternal(
Stream* stream, const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<ScaleType>& alpha, const DeviceMemory<ABType>& a,
const DeviceMemory<ABType>& b, const HostOrDeviceScalar<ScaleType>& beta,
const DeviceMemory<CDType>& c, DeviceMemory<CDType>* d,
ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
blas::ProfileResult* output_profile_result);
// Helper function for implementing DoBlasLtMatmulInternal.
template <typename ABType, typename CDType, typename ScaleType>
bool DoBlasLtMatmulInternalImpl(
Stream* stream, bool err_on_failure, const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<ScaleType>& alpha, const ABType* a,
const ABType* b, const HostOrDeviceScalar<ScaleType>& beta,
const CDType* c, CDType* d, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm);
// Guards the cuBLAS handle for this device.
absl::Mutex mu_;
@ -144,6 +169,11 @@ class CUDABlas : public blas::BlasSupport {
// cuBLAS library handle on the device.
cublasHandle_t blas_ TF_GUARDED_BY(mu_);
#if CUDA_VERSION >= 11000
// cuBLASLt library handle on the device.
cublasLtHandle_t blasLt_ TF_GUARDED_BY(mu_);
#endif
SE_DISALLOW_COPY_AND_ASSIGN(CUDABlas);
};
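For context, a minimal sketch of how this new handle would be initialized (the
initialization itself is outside this hunk; the sketch assumes it happens
alongside the existing cuBLAS handle creation, and cublasLtCreate is the real
cublasLt entry point):

#if CUDA_VERSION >= 11000
  // Create the cublasLt handle next to the regular cuBLAS handle. An
  // unscoped enum streams as an int, so the status is loggable directly.
  cublasStatus_t lt_status = cublasLtCreate(&blasLt_);
  if (lt_status != CUBLAS_STATUS_SUCCESS) {
    LOG(ERROR) << "failed to create cublasLt handle: " << lt_status;
    return false;
  }
#endif  // CUDA_VERSION >= 11000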

View File

@ -23,6 +23,7 @@ namespace DsoLoader {
port::Status TryDlopenCUDALibraries() {
auto cudart_status = GetCudaRuntimeDsoHandle();
auto cublas_status = GetCublasDsoHandle();
auto cublaslt_status = GetCublasLtDsoHandle();
auto cufft_status = GetCufftDsoHandle();
auto curand_status = GetCurandDsoHandle();
auto cusolver_status = GetCusolverDsoHandle();
@ -31,7 +32,7 @@ port::Status TryDlopenCUDALibraries() {
if (!cudart_status.status().ok() || !cublas_status.status().ok() ||
!cufft_status.status().ok() || !curand_status.status().ok() ||
!cusolver_status.status().ok() || !cusparse_status.status().ok() ||
!cudnn_status.status().ok()) {
!cudnn_status.status().ok() || !cublaslt_status.status().ok()) {
return port::Status(port::error::INTERNAL,
absl::StrCat("Cannot dlopen all CUDA libraries."));
} else {

View File

@ -84,6 +84,10 @@ port::StatusOr<void*> GetCublasDsoHandle() {
return GetDsoHandle("cublas", GetCublasVersion());
}
port::StatusOr<void*> GetCublasLtDsoHandle() {
return GetDsoHandle("cublasLt", GetCublasVersion());
}
port::StatusOr<void*> GetCufftDsoHandle() {
return GetDsoHandle("cufft", GetCufftVersion());
}
@ -160,6 +164,11 @@ port::StatusOr<void*> GetCublasDsoHandle() {
return *result;
}
port::StatusOr<void*> GetCublasLtDsoHandle() {
static auto result = new auto(DsoLoader::GetCublasLtDsoHandle());
return *result;
}
port::StatusOr<void*> GetCurandDsoHandle() {
static auto result = new auto(DsoLoader::GetCurandDsoHandle());
return *result;

View File

@ -37,6 +37,7 @@ namespace DsoLoader {
port::StatusOr<void*> GetCudaDriverDsoHandle();
port::StatusOr<void*> GetCudaRuntimeDsoHandle();
port::StatusOr<void*> GetCublasDsoHandle();
port::StatusOr<void*> GetCublasLtDsoHandle();
port::StatusOr<void*> GetCufftDsoHandle();
port::StatusOr<void*> GetCurandDsoHandle();
port::StatusOr<void*> GetCusolverDsoHandle();
@ -72,6 +73,7 @@ namespace CachedDsoLoader {
port::StatusOr<void*> GetCudaDriverDsoHandle();
port::StatusOr<void*> GetCudaRuntimeDsoHandle();
port::StatusOr<void*> GetCublasDsoHandle();
port::StatusOr<void*> GetCublasLtDsoHandle();
port::StatusOr<void*> GetCufftDsoHandle();
port::StatusOr<void*> GetCurandDsoHandle();
port::StatusOr<void*> GetCusolverDsoHandle();

View File

@ -4801,6 +4801,143 @@ Stream &Stream::ThenBlasGemmStridedBatched(
c, ldc, stride_c, batch_count);
}
Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<int32>& alpha,
const DeviceMemory<int8>& a,
const DeviceMemory<int8>& b,
const HostOrDeviceScalar<int32>& beta,
DeviceMemory<int32>* c,
ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
blas::ProfileResult* output_profile_result) {
VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
PARAM(c), PARAM(algorithm));
ThenBlasWithProfileImpl<
const blas::IBlasLtMatmulPlan*, const HostOrDeviceScalar<int32>&,
const DeviceMemory<int8>&, const DeviceMemory<int8>&,
const HostOrDeviceScalar<int32>&, DeviceMemory<int32>*, ScratchAllocator*,
const blas::IBlasLtMatmulAlgorithm*>
impl;
return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
c, scratch_allocator, algorithm, output_profile_result);
}
Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<Eigen::half>& alpha,
const DeviceMemory<Eigen::half>& a,
const DeviceMemory<Eigen::half>& b,
const HostOrDeviceScalar<Eigen::half>& beta,
DeviceMemory<Eigen::half>* c,
ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
blas::ProfileResult* output_profile_result) {
VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
PARAM(c), PARAM(algorithm));
ThenBlasWithProfileImpl<
const blas::IBlasLtMatmulPlan*, const HostOrDeviceScalar<Eigen::half>&,
const DeviceMemory<Eigen::half>&, const DeviceMemory<Eigen::half>&,
const HostOrDeviceScalar<Eigen::half>&, DeviceMemory<Eigen::half>*,
ScratchAllocator*, const blas::IBlasLtMatmulAlgorithm*>
impl;
return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
c, scratch_allocator, algorithm, output_profile_result);
}
Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<float>& alpha,
const DeviceMemory<float>& a,
const DeviceMemory<float>& b,
const HostOrDeviceScalar<float>& beta,
DeviceMemory<float>* c,
ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
blas::ProfileResult* output_profile_result) {
VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
PARAM(c), PARAM(algorithm));
ThenBlasWithProfileImpl<
const blas::IBlasLtMatmulPlan*, const HostOrDeviceScalar<float>&,
const DeviceMemory<float>&, const DeviceMemory<float>&,
const HostOrDeviceScalar<float>&, DeviceMemory<float>*, ScratchAllocator*,
const blas::IBlasLtMatmulAlgorithm*>
impl;
return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
c, scratch_allocator, algorithm, output_profile_result);
}
Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<double>& alpha,
const DeviceMemory<double>& a,
const DeviceMemory<double>& b,
const HostOrDeviceScalar<double>& beta,
DeviceMemory<double>* c,
ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
blas::ProfileResult* output_profile_result) {
VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
PARAM(c), PARAM(algorithm));
ThenBlasWithProfileImpl<
const blas::IBlasLtMatmulPlan*, const HostOrDeviceScalar<double>&,
const DeviceMemory<double>&, const DeviceMemory<double>&,
const HostOrDeviceScalar<double>&, DeviceMemory<double>*,
ScratchAllocator*, const blas::IBlasLtMatmulAlgorithm*>
impl;
return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
c, scratch_allocator, algorithm, output_profile_result);
}
Stream& Stream::ThenBlasLtMatmul(
const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<std::complex<float>>& alpha,
const DeviceMemory<std::complex<float>>& a,
const DeviceMemory<std::complex<float>>& b,
const HostOrDeviceScalar<std::complex<float>>& beta,
DeviceMemory<std::complex<float>>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
blas::ProfileResult* output_profile_result) {
VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
PARAM(c), PARAM(algorithm));
ThenBlasWithProfileImpl<const blas::IBlasLtMatmulPlan*,
const HostOrDeviceScalar<std::complex<float>>&,
const DeviceMemory<std::complex<float>>&,
const DeviceMemory<std::complex<float>>&,
const HostOrDeviceScalar<std::complex<float>>&,
DeviceMemory<std::complex<float>>*, ScratchAllocator*,
const blas::IBlasLtMatmulAlgorithm*>
impl;
return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
c, scratch_allocator, algorithm, output_profile_result);
}
Stream& Stream::ThenBlasLtMatmul(
const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<std::complex<double>>& alpha,
const DeviceMemory<std::complex<double>>& a,
const DeviceMemory<std::complex<double>>& b,
const HostOrDeviceScalar<std::complex<double>>& beta,
DeviceMemory<std::complex<double>>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
blas::ProfileResult* output_profile_result) {
VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
PARAM(c), PARAM(algorithm));
ThenBlasWithProfileImpl<const blas::IBlasLtMatmulPlan*,
const HostOrDeviceScalar<std::complex<double>>&,
const DeviceMemory<std::complex<double>>&,
const DeviceMemory<std::complex<double>>&,
const HostOrDeviceScalar<std::complex<double>>&,
DeviceMemory<std::complex<double>>*,
ScratchAllocator*,
const blas::IBlasLtMatmulAlgorithm*>
impl;
return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
c, scratch_allocator, algorithm, output_profile_result);
}
Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
VLOG_CALL(PARAM(seed), PARAM(seed_bytes));

View File

@ -1665,6 +1665,56 @@ class Stream {
const DeviceMemory<std::complex<double>> &a, int lda,
DeviceMemory<std::complex<double>> *b, int ldb);
// See BlasSupport::DoBlasLtMatmul; a usage sketch follows these declarations.
Stream& ThenBlasLtMatmul(
const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<int32>& alpha, const DeviceMemory<int8>& a,
const DeviceMemory<int8>& b, const HostOrDeviceScalar<int32>& beta,
DeviceMemory<int32>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
blas::ProfileResult* output_profile_result = nullptr);
Stream& ThenBlasLtMatmul(
const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<Eigen::half>& alpha,
const DeviceMemory<Eigen::half>& a, const DeviceMemory<Eigen::half>& b,
const HostOrDeviceScalar<Eigen::half>& beta, DeviceMemory<Eigen::half>* c,
ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
blas::ProfileResult* output_profile_result = nullptr);
Stream& ThenBlasLtMatmul(
const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<float>& alpha, const DeviceMemory<float>& a,
const DeviceMemory<float>& b, const HostOrDeviceScalar<float>& beta,
DeviceMemory<float>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
blas::ProfileResult* output_profile_result = nullptr);
Stream& ThenBlasLtMatmul(
const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<double>& alpha, const DeviceMemory<double>& a,
const DeviceMemory<double>& b, const HostOrDeviceScalar<double>& beta,
DeviceMemory<double>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
blas::ProfileResult* output_profile_result = nullptr);
Stream& ThenBlasLtMatmul(
const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<std::complex<float>>& alpha,
const DeviceMemory<std::complex<float>>& a,
const DeviceMemory<std::complex<float>>& b,
const HostOrDeviceScalar<std::complex<float>>& beta,
DeviceMemory<std::complex<float>>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
blas::ProfileResult* output_profile_result = nullptr);
Stream& ThenBlasLtMatmul(
const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<std::complex<double>>& alpha,
const DeviceMemory<std::complex<double>>& a,
const DeviceMemory<std::complex<double>>& b,
const HostOrDeviceScalar<std::complex<double>>& beta,
DeviceMemory<std::complex<double>>* c,
ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
blas::ProfileResult* output_profile_result = nullptr);
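As a usage note, a minimal sketch of enqueuing one of these calls (the names
stream, plan, algorithms, scratch_allocator, and the DeviceMemory<float>
buffers a, b, c are hypothetical; assumes the plan was created with
blas::PointerMode::kHost):

  // Sketch only: alpha/beta live on the host; c is both input and output
  // (it is scaled by beta before the product is accumulated into it).
  HostOrDeviceScalar<float> alpha(1.0f);
  HostOrDeviceScalar<float> beta(0.0f);
  stream->ThenBlasLtMatmul(plan.get(), alpha, a, b, beta, &c,
                           &scratch_allocator, algorithms[0].get());
  if (!stream->ok()) {
    LOG(ERROR) << "failed to enqueue blaslt matmul";
  }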
// See FftSupport::DoFft.
Stream &ThenFft(fft::Plan *plan,
const DeviceMemory<std::complex<float>> &input,

View File

@ -336,6 +336,49 @@ bool StreamExecutor::GetBlasGemmAlgorithms(
return blas_support->GetBlasGemmAlgorithms(out_algorithms);
}
std::unique_ptr<blas::IBlasLtMatmulPlan> StreamExecutor::CreateBlasLtMatmulPlan(
blas::DataType ab_type, blas::DataType cd_type,
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, int64 lda, int64 ldb, int64 ldc) {
blas::BlasSupport *blas_support = AsBlas();
if (!blas_support) {
return nullptr;
}
return blas_support->CreateBlasLtMatmulPlan(
ab_type, cd_type, computation_type, pointer_mode, transa, transb, m, n, k,
lda, ldb, ldc);
}
std::unique_ptr<blas::IBlasLtMatmulPlan>
StreamExecutor::CreateBlasLtMatmulPlanStridedBatched(
blas::DataType ab_type, blas::DataType cd_type,
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, uint64 batch_count, int64 lda, int64 stride_a, int64 ldb,
int64 stride_b, int64 ldc, int64 stride_c) {
blas::BlasSupport *blas_support = AsBlas();
if (!blas_support) {
return nullptr;
}
return blas_support->CreateBlasLtMatmulPlanStridedBatched(
ab_type, cd_type, computation_type, pointer_mode, transa, transb, m, n, k,
batch_count, lda, stride_a, ldb, stride_b, ldc, stride_c);
}
bool StreamExecutor::GetBlasLtMatmulAlgorithms(
const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size,
int max_algorithm_count,
std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>*
out_algorithms) {
blas::BlasSupport *blas_support = AsBlas();
if (!blas_support) {
return false;
}
return blas_support->GetBlasLtMatmulAlgorithms(
plan, max_workspace_size, max_algorithm_count, out_algorithms);
}
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
StreamExecutor::createRnnDescriptor(
int num_layers, int hidden_size, int input_size, int cell_size,

View File

@ -394,6 +394,35 @@ class StreamExecutor {
// Get the list of supported algorithms for BLAS gemm.
bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms);
// Creates a backend-specific plan object for a blaslt matmul operation, which
// can then be passed to DoBlasLtMatmul(). When possible, plans should be
// created once and reused for multiple calls to DoBlasLtMatmul().
// Returns a null pointer on failure.
std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
blas::DataType ab_type, blas::DataType cd_type,
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, int64 lda, int64 ldb, int64 ldc);
// A more general version of CreateBlasLtMatmulPlan supporting
// batched operations.
std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlanStridedBatched(
blas::DataType ab_type, blas::DataType cd_type,
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, uint64 batch_count, int64 lda, int64 stride_a, int64 ldb,
int64 stride_b, int64 ldc, int64 stride_c);
// Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
// returned in order of increasing estimated compute time according to an
// internal heuristic. The first returned algorithm can be used as the default
// algorithm if no autotuning is to be performed; an end-to-end usage sketch
// follows these declarations.
bool GetBlasLtMatmulAlgorithms(
const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size,
int max_algorithm_count,
std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>*
out_algorithms);
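To make the intended call sequence concrete, an end-to-end autotuning sketch
(the names executor, stream, scratch_allocator, and the DeviceMemory<float>
buffers a, b, c are hypothetical; the workspace limit is arbitrary):

  // Sketch only: create a plan once, enumerate the heuristically-ordered
  // algorithms, then time each candidate via ProfileResult.
  auto plan = executor->CreateBlasLtMatmulPlan(
      blas::DataType::kF32, blas::DataType::kF32, blas::ComputationType::kF32,
      blas::PointerMode::kHost, blas::Transpose::kNoTranspose,
      blas::Transpose::kNoTranspose, /*m=*/1024, /*n=*/1024, /*k=*/1024,
      /*lda=*/1024, /*ldb=*/1024, /*ldc=*/1024);
  std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>> algorithms;
  if (!plan || !executor->GetBlasLtMatmulAlgorithms(
                   plan.get(), /*max_workspace_size=*/1 << 30,
                   /*max_algorithm_count=*/8, &algorithms)) {
    return;  // BlasLt is unavailable (e.g. CUDA < 11.0).
  }
  HostOrDeviceScalar<float> alpha(1.0f);
  HostOrDeviceScalar<float> beta(0.0f);
  size_t best = 0;
  float best_ms = std::numeric_limits<float>::max();
  for (size_t i = 0; i < algorithms.size(); ++i) {
    blas::ProfileResult profile;
    stream->ThenBlasLtMatmul(plan.get(), alpha, a, b, beta, &c,
                             &scratch_allocator, algorithms[i].get(),
                             &profile);
    if (profile.is_valid() && profile.elapsed_time_in_ms() < best_ms) {
      best_ms = profile.elapsed_time_in_ms();
      best = i;
    }
  }
  // algorithms[best] is then reused for subsequent untimed calls.

When autotuning is not desired, the first returned algorithm can simply be
used directly, since the list is already ordered by the heuristic.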
// Create an RNN descriptor based on model shapes and configurations.
// The caller retains the ownership of the descriptor.
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(

View File

@ -127,6 +127,13 @@ cc_library(
linkstatic = 1,
)
cc_library(
name = "cublasLt",
srcs = ["cuda/lib/%{cublasLt_lib}"],
data = ["cuda/lib/%{cublasLt_lib}"],
linkstatic = 1,
)
cc_library(
name = "cusolver",
srcs = ["cuda/lib/%{cusolver_lib}"],
@ -168,6 +175,7 @@ cc_library(
name = "cuda",
deps = [
":cublas",
":cublasLt",
":cuda_headers",
":cudart",
":cudnn",

View File

@ -551,6 +551,13 @@ def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config):
cuda_config.cublas_version,
static = False,
),
"cublasLt": _check_cuda_lib_params(
"cublasLt",
cpu_value,
cuda_config.config["cublas_library_dir"],
cuda_config.cublas_version,
static = False,
),
"cusolver": _check_cuda_lib_params(
"cusolver",
cpu_value,
@ -771,6 +778,7 @@ def _create_dummy_repository(repository_ctx):
"%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value),
"%{cudart_lib}": lib_name("cudart", cpu_value),
"%{cublas_lib}": lib_name("cublas", cpu_value),
"%{cublasLt_lib}": lib_name("cublasLt", cpu_value),
"%{cusolver_lib}": lib_name("cusolver", cpu_value),
"%{cudnn_lib}": lib_name("cudnn", cpu_value),
"%{cufft_lib}": lib_name("cufft", cpu_value),
@ -802,6 +810,7 @@ filegroup(name="cudnn-include")
"cuda/cuda/lib/%s" % lib_name("cudart_static", cpu_value),
)
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublas", cpu_value))
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublasLt", cpu_value))
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusolver", cpu_value))
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudnn", cpu_value))
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("curand", cpu_value))
@ -992,11 +1001,13 @@ def _create_local_cuda_repository(repository_ctx):
cublas_include_path + "/cublas.h",
cublas_include_path + "/cublas_v2.h",
cublas_include_path + "/cublas_api.h",
cublas_include_path + "/cublasLt.h",
],
outs = [
"cublas/include/cublas.h",
"cublas/include/cublas_v2.h",
"cublas/include/cublas_api.h",
"cublas/include/cublasLt.h",
],
))
@ -1137,6 +1148,7 @@ def _create_local_cuda_repository(repository_ctx):
"%{cudart_static_linkopt}": _cudart_static_linkopt(cuda_config.cpu_value),
"%{cudart_lib}": _basename(repository_ctx, cuda_libs["cudart"]),
"%{cublas_lib}": _basename(repository_ctx, cuda_libs["cublas"]),
"%{cublasLt_lib}": _basename(repository_ctx, cuda_libs["cublasLt"]),
"%{cusolver_lib}": _basename(repository_ctx, cuda_libs["cusolver"]),
"%{cudnn_lib}": _basename(repository_ctx, cuda_libs["cudnn"]),
"%{cufft_lib}": _basename(repository_ctx, cuda_libs["cufft"]),