Add cublasLt wrappers to stream_executor
- Adds ThenBlasLtMatmul routines that behave similarly to ThenBlasGemmWithAlgorithm, but call into the cublasLt library and allow separation of plan creation and execution.
- A list of heuristically-prioritized opaque algorithm objects can be obtained via GetBlasLtMatmulAlgorithms.
- These routines are only supported when the CUDA version is >= 11.0.
This commit is contained in:
parent
8ee3640e16
commit
aaea82e6bc
tensorflow/stream_executor
third_party/gpus
@ -95,5 +95,30 @@ std::ostream& operator<<(std::ostream& os, ComputationType ty) {
|
||||
return os << ComputationTypeString(ty);
|
||||
}
|
||||
|
||||
// Returns a short human-readable name for the given blaslt element type.
// Crashes (LOG(FATAL)) on a value outside the DataType enumeration.
string DataTypeString(DataType ty) {
  if (ty == DataType::kF16) return "f16";
  if (ty == DataType::kF32) return "f32";
  if (ty == DataType::kF64) return "f64";
  if (ty == DataType::kI8) return "i8";
  if (ty == DataType::kI32) return "i32";
  if (ty == DataType::kComplexF32) return "complex f32";
  if (ty == DataType::kComplexF64) return "complex f64";
  // LOG(FATAL) aborts the process, so no return value is needed past here.
  LOG(FATAL) << "Unknown DataType " << static_cast<int32>(ty);
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, DataType ty) {
|
||||
return os << DataTypeString(ty);
|
||||
}
|
||||
|
||||
} // namespace blas
|
||||
} // namespace stream_executor
|
||||
|
@ -101,6 +101,10 @@ enum class ComputationType {
|
||||
kI32, // 32-bit integer
|
||||
kComplexF32, // Complex number comprised of two f32s.
|
||||
kComplexF64, // Complex number comprised of two f64s.
|
||||
// The below values are only supported for BlasLt routines (both real and
|
||||
// complex).
|
||||
kF32FastTF32, // 32-bit floating-point with reduced (>=10-bit) mantissa
|
||||
kF32FastBF16, // 32-bit floating-point with reduced (7-bit) mantissa
|
||||
};
|
||||
|
||||
// Converts a ComputationType to a string.
|
||||
@ -108,6 +112,61 @@ std::string ComputationTypeString(ComputationType ty);
|
||||
|
||||
std::ostream &operator<<(std::ostream &os, ComputationType ty);
|
||||
|
||||
// Type with which inputs and outputs of a blaslt routine are performed, i.e.
// the storage type of the matrix elements (as opposed to ComputationType,
// which selects how the computation itself is carried out).
enum class DataType {
  kF16,         // 16-bit floating-point (half precision)
  kF32,         // 32-bit floating-point
  kF64,         // 64-bit floating-point
  kI8,          // 8-bit signed integer (maps from int8, see ToDataType)
  kI32,         // 32-bit signed integer (maps from int32, see ToDataType)
  kComplexF32,  // Complex number comprised of two f32s
  kComplexF64,  // Complex number comprised of two f64s
};
|
||||
|
||||
// Describes the type of pointers for the scaling factors alpha and beta in
// blaslt routines: whether they point into host or device memory. The mode
// is fixed when a plan is created (see CreateBlasLtMatmulPlan).
enum class PointerMode {
  kHost,    // alpha/beta reside in host memory.
  kDevice,  // alpha/beta reside in device memory.
};
|
||||
|
||||
// Converts a DataType to a string.
|
||||
string DataTypeString(DataType ty);
|
||||
|
||||
std::ostream &operator<<(std::ostream &os, DataType ty);
|
||||
|
||||
// Converts a compile-time type to a DataType value, e.g.
// ToDataType<float>::value == DataType::kF32. The primary template is
// intentionally empty so that using an unsupported element type fails to
// compile rather than silently picking a wrong enumerator.
template <typename T>
struct ToDataType {};
template <>
struct ToDataType<Eigen::half> {
  static constexpr const DataType value = DataType::kF16;
};
template <>
struct ToDataType<float> {
  static constexpr const DataType value = DataType::kF32;
};
template <>
struct ToDataType<double> {
  static constexpr const DataType value = DataType::kF64;
};
template <>
struct ToDataType<int8> {
  static constexpr const DataType value = DataType::kI8;
};
template <>
struct ToDataType<int32> {
  static constexpr const DataType value = DataType::kI32;
};
template <>
struct ToDataType<std::complex<float>> {
  static constexpr const DataType value = DataType::kComplexF32;
};
template <>
struct ToDataType<std::complex<double>> {
  static constexpr const DataType value = DataType::kComplexF64;
};
|
||||
|
||||
// Opaque identifier for an "algorithm" used by a blas routine. This functions
|
||||
// as a hint to the blas library.
|
||||
typedef int64 AlgorithmType;
|
||||
@ -163,6 +222,19 @@ class AlgorithmConfig {
|
||||
AlgorithmType algorithm_;
|
||||
};
|
||||
|
||||
// Opaque handle to a backend-specific blaslt matmul plan, created by
// BlasSupport::CreateBlasLtMatmulPlan* and consumed by DoBlasLtMatmul /
// GetBlasLtMatmulAlgorithms. Carries no interface of its own; the virtual
// destructor only enables destruction through a base-class pointer.
struct IBlasLtMatmulPlan {
  virtual ~IBlasLtMatmulPlan() = default;
};
|
||||
|
||||
struct IBlasLtMatmulAlgorithm {
|
||||
virtual ~IBlasLtMatmulAlgorithm() {}
|
||||
// Returns the index of the algorithm within the list returned by
|
||||
// GetBlasLtMatmulAlgorithms.
|
||||
virtual AlgorithmType index() const = 0;
|
||||
// Returns the workspace size required by the algorithm in bytes.
|
||||
virtual size_t workspace_size() const = 0;
|
||||
};
|
||||
|
||||
// BLAS support interface -- this can be derived from a GPU executor when the
|
||||
// underlying platform has an BLAS library implementation available. See
|
||||
// StreamExecutor::AsBlas().
|
||||
@ -1383,6 +1455,93 @@ class BlasSupport {
|
||||
const DeviceMemory<std::complex<double>> &a, int lda,
|
||||
DeviceMemory<std::complex<double>> *b, int ldb) = 0;
|
||||
|
||||
// Creates a backend-specific plan object for a blaslt matmul operation, which
// can then be passed to DoBlasLtMatmul(). When possible, plans should be
// created once and reused for multiple calls to DoBlasLtMatmul().
// Returns a null pointer on failure.
std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
    blas::DataType ab_type, blas::DataType c_type,
    blas::ComputationType computation_type, blas::PointerMode pointer_mode,
    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
    uint64 k, int64 lda, int64 ldb, int64 ldc) {
  // A non-batched matmul is expressed as the degenerate strided-batched case:
  // batch_count = 1 and all batch strides (stride_a/b/c) = 0.
  return CreateBlasLtMatmulPlanStridedBatched(
      ab_type, c_type, computation_type, pointer_mode, transa, transb, m, n,
      k, 1, lda, 0, ldb, 0, ldc, 0);
}
|
||||
|
||||
// A more general version of CreateBlasLtMatmulPlan supporting
|
||||
// batched operations.
|
||||
virtual std::unique_ptr<blas::IBlasLtMatmulPlan>
|
||||
CreateBlasLtMatmulPlanStridedBatched(
|
||||
blas::DataType ab_type, blas::DataType c_type,
|
||||
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
|
||||
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
|
||||
uint64 k, int batch_count, int64 lda, int64 stride_a, int64 ldb,
|
||||
int64 stride_b, int64 ldc, int64 stride_c) = 0;
|
||||
|
||||
// Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
|
||||
// returned in the order of increasing estimated compute time according to an
|
||||
// internal heuristic. The first returned algorithm can be used as the default
|
||||
// algorithm if no autotuning is to be performed.
|
||||
virtual bool GetBlasLtMatmulAlgorithms(
|
||||
const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size,
|
||||
int max_algorithm_count,
|
||||
std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>*
|
||||
out_algorithms) = 0;
|
||||
|
||||
// Executes a blaslt matmul operation on the stream. If output_profile_result
|
||||
// is not nullptr, the operation is profiled, error messages are
|
||||
// suppressed, and output_profile_result->algorithm() is set to
|
||||
// algorithm->index().
|
||||
virtual bool DoBlasLtMatmul(
|
||||
Stream* stream, const blas::IBlasLtMatmulPlan* plan,
|
||||
const HostOrDeviceScalar<int32>& alpha, const DeviceMemory<int8>& a,
|
||||
const DeviceMemory<int8>& b, const HostOrDeviceScalar<int32>& beta,
|
||||
DeviceMemory<int32>* c, ScratchAllocator* scratch_allocator,
|
||||
const blas::IBlasLtMatmulAlgorithm* algorithm,
|
||||
blas::ProfileResult* output_profile_result = nullptr) = 0;
|
||||
virtual bool DoBlasLtMatmul(
|
||||
Stream* stream, const blas::IBlasLtMatmulPlan* plan,
|
||||
const HostOrDeviceScalar<Eigen::half>& alpha,
|
||||
const DeviceMemory<Eigen::half>& a, const DeviceMemory<Eigen::half>& b,
|
||||
const HostOrDeviceScalar<Eigen::half>& beta, DeviceMemory<Eigen::half>* c,
|
||||
ScratchAllocator* scratch_allocator,
|
||||
const blas::IBlasLtMatmulAlgorithm* algorithm,
|
||||
blas::ProfileResult* output_profile_result = nullptr) = 0;
|
||||
virtual bool DoBlasLtMatmul(
|
||||
Stream* stream, const blas::IBlasLtMatmulPlan* plan,
|
||||
const HostOrDeviceScalar<float>& alpha, const DeviceMemory<float>& a,
|
||||
const DeviceMemory<float>& b, const HostOrDeviceScalar<float>& beta,
|
||||
DeviceMemory<float>* c, ScratchAllocator* scratch_allocator,
|
||||
const blas::IBlasLtMatmulAlgorithm* algorithm,
|
||||
blas::ProfileResult* output_profile_result = nullptr) = 0;
|
||||
virtual bool DoBlasLtMatmul(
|
||||
Stream* stream, const blas::IBlasLtMatmulPlan* plan,
|
||||
const HostOrDeviceScalar<double>& alpha, const DeviceMemory<double>& a,
|
||||
const DeviceMemory<double>& b, const HostOrDeviceScalar<double>& beta,
|
||||
DeviceMemory<double>* c, ScratchAllocator* scratch_allocator,
|
||||
const blas::IBlasLtMatmulAlgorithm* algorithm,
|
||||
blas::ProfileResult* output_profile_result = nullptr) = 0;
|
||||
virtual bool DoBlasLtMatmul(
|
||||
Stream* stream, const blas::IBlasLtMatmulPlan* plan,
|
||||
const HostOrDeviceScalar<std::complex<float>>& alpha,
|
||||
const DeviceMemory<std::complex<float>>& a,
|
||||
const DeviceMemory<std::complex<float>>& b,
|
||||
const HostOrDeviceScalar<std::complex<float>>& beta,
|
||||
DeviceMemory<std::complex<float>>* c, ScratchAllocator* scratch_allocator,
|
||||
const blas::IBlasLtMatmulAlgorithm* algorithm,
|
||||
blas::ProfileResult* output_profile_result = nullptr) = 0;
|
||||
virtual bool DoBlasLtMatmul(
|
||||
Stream* stream, const blas::IBlasLtMatmulPlan* plan,
|
||||
const HostOrDeviceScalar<std::complex<double>>& alpha,
|
||||
const DeviceMemory<std::complex<double>>& a,
|
||||
const DeviceMemory<std::complex<double>>& b,
|
||||
const HostOrDeviceScalar<std::complex<double>>& beta,
|
||||
DeviceMemory<std::complex<double>>* c,
|
||||
ScratchAllocator* scratch_allocator,
|
||||
const blas::IBlasLtMatmulAlgorithm* algorithm,
|
||||
blas::ProfileResult* output_profile_result = nullptr) = 0;
|
||||
|
||||
virtual port::Status GetVersion(std::string *version) = 0;
|
||||
|
||||
protected:
|
||||
@ -2196,6 +2355,65 @@ class BlasSupport {
|
||||
uint64 n, std::complex<double> alpha, \
|
||||
const DeviceMemory<std::complex<double>> &a, int lda, \
|
||||
DeviceMemory<std::complex<double>> *b, int ldb) override; \
|
||||
std::unique_ptr<blas::IBlasLtMatmulPlan> \
|
||||
CreateBlasLtMatmulPlanStridedBatched( \
|
||||
blas::DataType ab_type, blas::DataType cd_type, \
|
||||
blas::ComputationType computation_type, blas::PointerMode pointer_mode, \
|
||||
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, \
|
||||
uint64 k, int batch_count, int64 lda, int64 stride_a, int64 ldb, \
|
||||
int64 stride_b, int64 ldc, int64 stride_c) override; \
|
||||
bool GetBlasLtMatmulAlgorithms( \
|
||||
const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size, \
|
||||
int max_algorithm_count, \
|
||||
std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>* \
|
||||
out_algorithms) override; \
|
||||
bool DoBlasLtMatmul( \
|
||||
Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
|
||||
const HostOrDeviceScalar<int32>& alpha, const DeviceMemory<int8>& a, \
|
||||
const DeviceMemory<int8>& b, const HostOrDeviceScalar<int32>& beta, \
|
||||
DeviceMemory<int32>* c, ScratchAllocator* scratch_allocator, \
|
||||
const blas::IBlasLtMatmulAlgorithm* algorithm, \
|
||||
blas::ProfileResult* output_profile_result = nullptr) override; \
|
||||
bool DoBlasLtMatmul( \
|
||||
Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
|
||||
const HostOrDeviceScalar<Eigen::half>& alpha, \
|
||||
const DeviceMemory<Eigen::half>& a, const DeviceMemory<Eigen::half>& b, \
|
||||
const HostOrDeviceScalar<Eigen::half>& beta, \
|
||||
DeviceMemory<Eigen::half>* c, ScratchAllocator* scratch_allocator, \
|
||||
const blas::IBlasLtMatmulAlgorithm* algorithm, \
|
||||
blas::ProfileResult* output_profile_result) override; \
|
||||
bool DoBlasLtMatmul( \
|
||||
Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
|
||||
const HostOrDeviceScalar<float>& alpha, const DeviceMemory<float>& a, \
|
||||
const DeviceMemory<float>& b, const HostOrDeviceScalar<float>& beta, \
|
||||
DeviceMemory<float>* c, ScratchAllocator* scratch_allocator, \
|
||||
const blas::IBlasLtMatmulAlgorithm* algorithm, \
|
||||
blas::ProfileResult* output_profile_result) override; \
|
||||
bool DoBlasLtMatmul( \
|
||||
Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
|
||||
const HostOrDeviceScalar<double>& alpha, const DeviceMemory<double>& a, \
|
||||
const DeviceMemory<double>& b, const HostOrDeviceScalar<double>& beta, \
|
||||
DeviceMemory<double>* c, ScratchAllocator* scratch_allocator, \
|
||||
const blas::IBlasLtMatmulAlgorithm* algorithm, \
|
||||
blas::ProfileResult* output_profile_result) override; \
|
||||
bool DoBlasLtMatmul(Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
|
||||
const HostOrDeviceScalar<std::complex<float>>& alpha, \
|
||||
const DeviceMemory<std::complex<float>>& a, \
|
||||
const DeviceMemory<std::complex<float>>& b, \
|
||||
const HostOrDeviceScalar<std::complex<float>>& beta, \
|
||||
DeviceMemory<std::complex<float>>* c, \
|
||||
ScratchAllocator* scratch_allocator, \
|
||||
const blas::IBlasLtMatmulAlgorithm* algorithm, \
|
||||
blas::ProfileResult* output_profile_result) override; \
|
||||
bool DoBlasLtMatmul(Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
|
||||
const HostOrDeviceScalar<std::complex<double>>& alpha, \
|
||||
const DeviceMemory<std::complex<double>>& a, \
|
||||
const DeviceMemory<std::complex<double>>& b, \
|
||||
const HostOrDeviceScalar<std::complex<double>>& beta, \
|
||||
DeviceMemory<std::complex<double>>* c, \
|
||||
ScratchAllocator* scratch_allocator, \
|
||||
const blas::IBlasLtMatmulAlgorithm* algorithm, \
|
||||
blas::ProfileResult* output_profile_result) override; \
|
||||
port::Status GetVersion(std::string *version) override;
|
||||
|
||||
} // namespace blas
|
||||
|
@ -242,6 +242,29 @@ alias(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# Stub library that resolves cublasLt symbols at runtime via the dso_loader
# (see the generated cublasLt_*.inc trampolines) instead of linking against
# libcublasLt directly, so binaries build without the library installed.
cc_library(
    name = "cublasLt_stub",
    srcs = if_cuda_is_configured(["cublasLt_stub.cc"]),
    textual_hdrs = glob(["cublasLt_*.inc"]),
    deps = if_cuda_is_configured([
        # LINT.IfChange
        "@local_config_cuda//cuda:cublas_headers",
        # LINT.ThenChange(//tensorflow/copy.bara.sky:cublasLt_headers)
        "@local_config_cuda//cuda:cuda_headers",
        "//tensorflow/stream_executor/lib",
        "//tensorflow/stream_executor/platform:dso_loader",
    ]),
)
|
||||
|
||||
# In OSS builds, use the dlopen-based stub above; in all other configurations
# link the real cublasLt library from the configured CUDA toolkit.
alias(
    name = "cublasLt_lib",
    actual = select({
        "//tensorflow:oss": ":cublasLt_stub",
        "//conditions:default": "@local_config_cuda//cuda:cublasLt",
    }),
    visibility = ["//visibility:public"],
)
|
||||
|
||||
cc_library(
|
||||
name = "cublas_plugin",
|
||||
srcs = if_cuda_is_configured(["cuda_blas.cc"]),
|
||||
@ -249,6 +272,7 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = if_cuda_is_configured([
|
||||
":cublas_lib",
|
||||
":cublasLt_lib",
|
||||
":cuda_activation",
|
||||
":cuda_gpu_executor",
|
||||
":cuda_platform_id",
|
||||
|
415
tensorflow/stream_executor/cuda/cublasLt_11_0.inc
Normal file
415
tensorflow/stream_executor/cuda/cublasLt_11_0.inc
Normal file
@ -0,0 +1,415 @@
|
||||
// Auto-generated, do not edit.
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Trampoline for cublasLtCreate: resolves the symbol from the dynamically
// loaded cublasLt library on first use and forwards the call. Returns
// GetSymbolNotFoundError() when the library or symbol is unavailable.
cublasStatus_t CUBLASWINAPI
cublasLtCreate(cublasLtHandle_t *lightHandle) {
  using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtHandle_t *);
  // Resolved once; C++11 function-local static initialization is thread-safe.
  static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtCreate");
  if (!func_ptr) return GetSymbolNotFoundError();
  return func_ptr(lightHandle);
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI
|
||||
cublasLtDestroy(cublasLtHandle_t lightHandle) {
|
||||
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtHandle_t);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtDestroy");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(lightHandle);
|
||||
}
|
||||
|
||||
// Trampoline for cublasLtGetVersion. Unlike the cublasStatus_t-returning
// wrappers, there is no error channel here: if the symbol cannot be
// resolved, 0 is returned to signal "library not available".
size_t CUBLASWINAPI
cublasLtGetVersion(void) {
  using FuncPtr = size_t (CUBLASWINAPI *)();
  static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtGetVersion");
  if (!func_ptr) return 0;
  return func_ptr();
}
|
||||
|
||||
size_t CUBLASWINAPI
|
||||
cublasLtGetCudartVersion(void) {
|
||||
using FuncPtr = size_t (CUBLASWINAPI *)();
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtGetCudartVersion");
|
||||
if (!func_ptr) return 0;
|
||||
return func_ptr();
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI
|
||||
cublasLtGetProperty(libraryPropertyType type, int *value) {
|
||||
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(libraryPropertyType, int *);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtGetProperty");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(type, value);
|
||||
}
|
||||
|
||||
// Trampoline for cublasLtMatmul, the core D = alpha*op(A)*op(B) + beta*C
// entry point of the cublasLt library. Resolves the symbol on first use and
// forwards all arguments unchanged; returns GetSymbolNotFoundError() when
// the library or symbol is unavailable.
cublasStatus_t CUBLASWINAPI
cublasLtMatmul(cublasLtHandle_t lightHandle,
               cublasLtMatmulDesc_t computeDesc,
               const void *alpha, /* host or device pointer */
               const void *A,
               cublasLtMatrixLayout_t Adesc,
               const void *B,
               cublasLtMatrixLayout_t Bdesc,
               const void *beta, /* host or device pointer */
               const void *C,
               cublasLtMatrixLayout_t Cdesc,
               void *D,
               cublasLtMatrixLayout_t Ddesc,
               const cublasLtMatmulAlgo_t *algo,
               void *workspace,
               size_t workspaceSizeInBytes,
               cudaStream_t stream) {
  using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtHandle_t, cublasLtMatmulDesc_t, const void *, const void *, cublasLtMatrixLayout_t, const void *, cublasLtMatrixLayout_t, const void *, const void *, cublasLtMatrixLayout_t, void *, cublasLtMatrixLayout_t, const cublasLtMatmulAlgo_t *, void *, size_t, cudaStream_t);
  static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmul");
  if (!func_ptr) return GetSymbolNotFoundError();
  return func_ptr(lightHandle, computeDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, D, Ddesc, algo, workspace, workspaceSizeInBytes, stream);
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI
|
||||
cublasLtMatrixTransform(cublasLtHandle_t lightHandle,
|
||||
cublasLtMatrixTransformDesc_t transformDesc,
|
||||
const void *alpha, /* host or device pointer */
|
||||
const void *A,
|
||||
cublasLtMatrixLayout_t Adesc,
|
||||
const void *beta, /* host or device pointer */
|
||||
const void *B,
|
||||
cublasLtMatrixLayout_t Bdesc,
|
||||
void *C,
|
||||
cublasLtMatrixLayout_t Cdesc,
|
||||
cudaStream_t stream) {
|
||||
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtHandle_t, cublasLtMatrixTransformDesc_t, const void *, const void *, cublasLtMatrixLayout_t, const void *, const void *, cublasLtMatrixLayout_t, void *, cublasLtMatrixLayout_t, cudaStream_t);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatrixTransform");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(lightHandle, transformDesc, alpha, A, Adesc, beta, B, Bdesc, C, Cdesc, stream);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI
|
||||
cublasLtMatrixLayoutInit_internal( //
|
||||
cublasLtMatrixLayout_t matLayout,
|
||||
size_t size,
|
||||
cudaDataType type,
|
||||
uint64_t rows,
|
||||
uint64_t cols,
|
||||
int64_t ld) {
|
||||
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(//
|
||||
cublasLtMatrixLayout_t, size_t, cudaDataType, uint64_t, uint64_t, int64_t);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatrixLayoutInit_internal");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(matLayout, size, type, rows, cols, ld);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI
|
||||
cublasLtMatrixLayoutCreate( //
|
||||
cublasLtMatrixLayout_t *matLayout,
|
||||
cudaDataType type,
|
||||
uint64_t rows,
|
||||
uint64_t cols,
|
||||
int64_t ld) {
|
||||
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(//
|
||||
cublasLtMatrixLayout_t *, cudaDataType, uint64_t, uint64_t, int64_t);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatrixLayoutCreate");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(matLayout, type, rows, cols, ld);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI
|
||||
cublasLtMatrixLayoutDestroy(cublasLtMatrixLayout_t matLayout) {
|
||||
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatrixLayout_t);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatrixLayoutDestroy");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(matLayout);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI
|
||||
cublasLtMatrixLayoutSetAttribute( //
|
||||
cublasLtMatrixLayout_t matLayout,
|
||||
cublasLtMatrixLayoutAttribute_t attr,
|
||||
const void *buf,
|
||||
size_t sizeInBytes) {
|
||||
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(//
|
||||
cublasLtMatrixLayout_t, cublasLtMatrixLayoutAttribute_t, const void *, size_t);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatrixLayoutSetAttribute");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(matLayout, attr, buf, sizeInBytes);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI
|
||||
cublasLtMatrixLayoutGetAttribute( //
|
||||
cublasLtMatrixLayout_t matLayout,
|
||||
cublasLtMatrixLayoutAttribute_t attr,
|
||||
void *buf,
|
||||
size_t sizeInBytes,
|
||||
size_t *sizeWritten) {
|
||||
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(//
|
||||
cublasLtMatrixLayout_t, cublasLtMatrixLayoutAttribute_t, void *, size_t, size_t *);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatrixLayoutGetAttribute");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(matLayout, attr, buf, sizeInBytes, sizeWritten);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescInit_internal( //
|
||||
cublasLtMatmulDesc_t matmulDesc,
|
||||
size_t size,
|
||||
cublasComputeType_t computeType,
|
||||
cudaDataType_t scaleType) {
|
||||
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(//
|
||||
cublasLtMatmulDesc_t, size_t, cublasComputeType_t, cudaDataType_t);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulDescInit_internal");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(matmulDesc, size, computeType, scaleType);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI
|
||||
cublasLtMatmulDescCreate(cublasLtMatmulDesc_t *matmulDesc, cublasComputeType_t computeType, cudaDataType_t scaleType) {
|
||||
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatmulDesc_t *, cublasComputeType_t, cudaDataType_t);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulDescCreate");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(matmulDesc, computeType, scaleType);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI
|
||||
cublasLtMatmulDescDestroy(cublasLtMatmulDesc_t matmulDesc) {
|
||||
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatmulDesc_t);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulDescDestroy");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(matmulDesc);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI
|
||||
cublasLtMatmulDescSetAttribute( //
|
||||
cublasLtMatmulDesc_t matmulDesc,
|
||||
cublasLtMatmulDescAttributes_t attr,
|
||||
const void *buf,
|
||||
size_t sizeInBytes) {
|
||||
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(//
|
||||
cublasLtMatmulDesc_t, cublasLtMatmulDescAttributes_t, const void *, size_t);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulDescSetAttribute");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(matmulDesc, attr, buf, sizeInBytes);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI
|
||||
cublasLtMatmulDescGetAttribute( //
|
||||
cublasLtMatmulDesc_t matmulDesc,
|
||||
cublasLtMatmulDescAttributes_t attr,
|
||||
void *buf,
|
||||
size_t sizeInBytes,
|
||||
size_t *sizeWritten) {
|
||||
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(//
|
||||
cublasLtMatmulDesc_t, cublasLtMatmulDescAttributes_t, void *, size_t, size_t *);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulDescGetAttribute");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(matmulDesc, attr, buf, sizeInBytes, sizeWritten);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI
|
||||
cublasLtMatrixTransformDescInit_internal(cublasLtMatrixTransformDesc_t transformDesc, size_t size, cudaDataType scaleType) {
|
||||
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatrixTransformDesc_t, size_t, cudaDataType);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatrixTransformDescInit_internal");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(transformDesc, size, scaleType);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI
|
||||
cublasLtMatrixTransformDescCreate(cublasLtMatrixTransformDesc_t *transformDesc, cudaDataType scaleType) {
|
||||
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatrixTransformDesc_t *, cudaDataType);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatrixTransformDescCreate");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(transformDesc, scaleType);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI
|
||||
cublasLtMatrixTransformDescDestroy(cublasLtMatrixTransformDesc_t transformDesc) {
|
||||
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatrixTransformDesc_t);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatrixTransformDescDestroy");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(transformDesc);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI
|
||||
cublasLtMatrixTransformDescSetAttribute( //
|
||||
cublasLtMatrixTransformDesc_t transformDesc,
|
||||
cublasLtMatrixTransformDescAttributes_t attr,
|
||||
const void *buf,
|
||||
size_t sizeInBytes) {
|
||||
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(//
|
||||
cublasLtMatrixTransformDesc_t, cublasLtMatrixTransformDescAttributes_t, const void *, size_t);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatrixTransformDescSetAttribute");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(transformDesc, attr, buf, sizeInBytes);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI
|
||||
cublasLtMatrixTransformDescGetAttribute( //
|
||||
cublasLtMatrixTransformDesc_t transformDesc,
|
||||
cublasLtMatrixTransformDescAttributes_t attr,
|
||||
void *buf,
|
||||
size_t sizeInBytes,
|
||||
size_t *sizeWritten) {
|
||||
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(//
|
||||
cublasLtMatrixTransformDesc_t, cublasLtMatrixTransformDescAttributes_t, void *, size_t, size_t *);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatrixTransformDescGetAttribute");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(transformDesc, attr, buf, sizeInBytes, sizeWritten);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI
|
||||
cublasLtMatmulPreferenceInit_internal(cublasLtMatmulPreference_t pref, size_t size) {
|
||||
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatmulPreference_t, size_t);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulPreferenceInit_internal");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(pref, size);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI
|
||||
cublasLtMatmulPreferenceCreate(cublasLtMatmulPreference_t *pref) {
|
||||
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatmulPreference_t *);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulPreferenceCreate");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(pref);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI
|
||||
cublasLtMatmulPreferenceDestroy(cublasLtMatmulPreference_t pref) {
|
||||
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatmulPreference_t);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulPreferenceDestroy");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(pref);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI
|
||||
cublasLtMatmulPreferenceSetAttribute( //
|
||||
cublasLtMatmulPreference_t pref,
|
||||
cublasLtMatmulPreferenceAttributes_t attr,
|
||||
const void *buf,
|
||||
size_t sizeInBytes) {
|
||||
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(//
|
||||
cublasLtMatmulPreference_t, cublasLtMatmulPreferenceAttributes_t, const void *, size_t);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulPreferenceSetAttribute");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(pref, attr, buf, sizeInBytes);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI
|
||||
cublasLtMatmulPreferenceGetAttribute( //
|
||||
cublasLtMatmulPreference_t pref,
|
||||
cublasLtMatmulPreferenceAttributes_t attr,
|
||||
void *buf,
|
||||
size_t sizeInBytes,
|
||||
size_t *sizeWritten) {
|
||||
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(//
|
||||
cublasLtMatmulPreference_t, cublasLtMatmulPreferenceAttributes_t, void *, size_t, size_t *);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulPreferenceGetAttribute");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(pref, attr, buf, sizeInBytes, sizeWritten);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI
|
||||
cublasLtMatmulAlgoGetHeuristic(
|
||||
cublasLtHandle_t lightHandle,
|
||||
cublasLtMatmulDesc_t operationDesc,
|
||||
cublasLtMatrixLayout_t Adesc,
|
||||
cublasLtMatrixLayout_t Bdesc,
|
||||
cublasLtMatrixLayout_t Cdesc,
|
||||
cublasLtMatrixLayout_t Ddesc,
|
||||
cublasLtMatmulPreference_t preference,
|
||||
int requestedAlgoCount,
|
||||
cublasLtMatmulHeuristicResult_t heuristicResultsArray[],
|
||||
int *returnAlgoCount) {
|
||||
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtHandle_t, cublasLtMatmulDesc_t, cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, cublasLtMatmulPreference_t, int, cublasLtMatmulHeuristicResult_t [], int *);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulAlgoGetHeuristic");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(lightHandle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, requestedAlgoCount, heuristicResultsArray, returnAlgoCount);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI
|
||||
cublasLtMatmulAlgoGetIds(
|
||||
cublasLtHandle_t lightHandle,
|
||||
cublasComputeType_t computeType,
|
||||
cudaDataType_t scaleType,
|
||||
cudaDataType_t Atype,
|
||||
cudaDataType_t Btype,
|
||||
cudaDataType_t Ctype,
|
||||
cudaDataType_t Dtype,
|
||||
int requestedAlgoCount,
|
||||
int algoIdsArray[],
|
||||
int *returnAlgoCount) {
|
||||
using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtHandle_t, cublasComputeType_t, cudaDataType_t, cudaDataType_t, cudaDataType_t, cudaDataType_t, cudaDataType_t, int, int [], int *);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulAlgoGetIds");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(lightHandle, computeType, scaleType, Atype, Btype, Ctype, Dtype, requestedAlgoCount, algoIdsArray, returnAlgoCount);
|
||||
}
|
||||
|
||||
// Auto-generated DSO trampoline for cublasLtMatmulAlgoInit; see the other
// stubs in this file for the shared lookup-and-forward pattern.
cublasStatus_t CUBLASWINAPI
cublasLtMatmulAlgoInit ( cublasLtHandle_t lightHandle,
    cublasComputeType_t computeType,
    cudaDataType_t scaleType,
    cudaDataType_t Atype,
    cudaDataType_t Btype,
    cudaDataType_t Ctype,
    cudaDataType_t Dtype,
    int algoId,
    cublasLtMatmulAlgo_t *algo) {
  using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtHandle_t, cublasComputeType_t, cudaDataType_t, cudaDataType_t, cudaDataType_t, cudaDataType_t, cudaDataType_t, int, cublasLtMatmulAlgo_t *);
  static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulAlgoInit");
  if (!func_ptr) return GetSymbolNotFoundError();
  return func_ptr(lightHandle, computeType, scaleType, Atype, Btype, Ctype, Dtype, algoId, algo);
}
|
||||
|
||||
// Auto-generated DSO trampoline for cublasLtMatmulAlgoCheck.
cublasStatus_t CUBLASWINAPI
cublasLtMatmulAlgoCheck(  //
    cublasLtHandle_t lightHandle,
    cublasLtMatmulDesc_t operationDesc,
    cublasLtMatrixLayout_t Adesc,
    cublasLtMatrixLayout_t Bdesc,
    cublasLtMatrixLayout_t Cdesc,
    cublasLtMatrixLayout_t Ddesc,
    const cublasLtMatmulAlgo_t *algo,  ///< may point to result->algo
    cublasLtMatmulHeuristicResult_t *result) {
  using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(  //
      cublasLtHandle_t, cublasLtMatmulDesc_t, cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, const cublasLtMatmulAlgo_t *,  ///< may point to result->algo
      cublasLtMatmulHeuristicResult_t *);
  static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulAlgoCheck");
  if (!func_ptr) return GetSymbolNotFoundError();
  return func_ptr(lightHandle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, algo, result);
}
|
||||
|
||||
// Auto-generated DSO trampoline for cublasLtMatmulAlgoCapGetAttribute
// (queries a capability attribute of an algorithm into caller-owned `buf`).
cublasStatus_t CUBLASWINAPI
cublasLtMatmulAlgoCapGetAttribute(
    const cublasLtMatmulAlgo_t *algo,
    cublasLtMatmulAlgoCapAttributes_t attr,
    void *buf,
    size_t sizeInBytes,
    size_t *sizeWritten) {
  using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(const cublasLtMatmulAlgo_t *, cublasLtMatmulAlgoCapAttributes_t, void *, size_t, size_t *);
  static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulAlgoCapGetAttribute");
  if (!func_ptr) return GetSymbolNotFoundError();
  return func_ptr(algo, attr, buf, sizeInBytes, sizeWritten);
}
|
||||
|
||||
// Auto-generated DSO trampoline for cublasLtMatmulAlgoConfigSetAttribute.
cublasStatus_t CUBLASWINAPI
cublasLtMatmulAlgoConfigSetAttribute(
    cublasLtMatmulAlgo_t *algo,
    cublasLtMatmulAlgoConfigAttributes_t attr,
    const void *buf,
    size_t sizeInBytes) {
  using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasLtMatmulAlgo_t *, cublasLtMatmulAlgoConfigAttributes_t, const void *, size_t);
  static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulAlgoConfigSetAttribute");
  if (!func_ptr) return GetSymbolNotFoundError();
  return func_ptr(algo, attr, buf, sizeInBytes);
}
|
||||
|
||||
// Auto-generated DSO trampoline for cublasLtMatmulAlgoConfigGetAttribute.
cublasStatus_t CUBLASWINAPI
cublasLtMatmulAlgoConfigGetAttribute(
    const cublasLtMatmulAlgo_t *algo,
    cublasLtMatmulAlgoConfigAttributes_t attr,
    void *buf,
    size_t sizeInBytes,
    size_t *sizeWritten) {
  using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(const cublasLtMatmulAlgo_t *, cublasLtMatmulAlgoConfigAttributes_t, void *, size_t, size_t *);
  static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulAlgoConfigGetAttribute");
  if (!func_ptr) return GetSymbolNotFoundError();
  return func_ptr(algo, attr, buf, sizeInBytes, sizeWritten);
}
|
||||
|
||||
} // extern "C"
|
59
tensorflow/stream_executor/cuda/cublasLt_stub.cc
Normal file
59
tensorflow/stream_executor/cuda/cublasLt_stub.cc
Normal file
@ -0,0 +1,59 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "third_party/gpus/cuda/include/cublasLt.h"
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#include "tensorflow/stream_executor/lib/env.h"
|
||||
#include "tensorflow/stream_executor/platform/dso_loader.h"
|
||||
|
||||
// Implements the cuBLASLt API by forwarding to cuBLASLt loaded from the DSO.
|
||||
|
||||
namespace {
// Returns the cuBLASLt DSO handle, or null if loading it fails.
// The handle is resolved once and cached for the lifetime of the process.
void* GetDsoHandle() {
#ifdef PLATFORM_GOOGLE
  // NOTE(review): google3 builds return null here (no dynamic loading);
  // presumably cuBLASLt is linked in directly — confirm with the build rules.
  return nullptr;
#else
  static auto handle = []() -> void* {
    auto handle_or =
        stream_executor::internal::DsoLoader::GetCublasLtDsoHandle();
    if (!handle_or.ok()) return nullptr;
    return handle_or.ValueOrDie();
  }();
  return handle;
#endif
}

// Resolves `symbol_name` from the cuBLASLt DSO; returns null on failure
// (including when the DSO itself could not be loaded).
template <typename T>
T LoadSymbol(const char* symbol_name) {
  void* symbol = nullptr;
  if (auto handle = GetDsoHandle()) {
    stream_executor::port::Env::Default()
        ->GetSymbolFromLibrary(handle, symbol_name, &symbol)
        .IgnoreError();  // Missing symbols are reported by the caller.
  }
  return reinterpret_cast<T>(symbol);
}

// Crash-on-missing-symbol helper used by stubs that cannot report an error
// through their return type.
void LogFatalSymbolNotFound(const char* symbol_name) {
  LOG(FATAL) << symbol_name << " symbol not found.";
}

// Status returned by the generated stubs when the real symbol is unavailable.
cublasStatus_t GetSymbolNotFoundError() { return CUBLAS_STATUS_INTERNAL_ERROR; }
}  // namespace
|
||||
|
||||
// We only use cublasLt from CUDA 11.0 onward.
|
||||
#if CUDA_VERSION >= 11000
|
||||
#include "tensorflow/stream_executor/cuda/cublasLt_11_0.inc"
|
||||
#endif
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/gpus/cuda/include/cublas_v2.h"
|
||||
#include "third_party/gpus/cuda/include/cublasLt.h"
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
|
||||
#define SE_CUDA_DATA_HALF CUDA_R_16F
|
||||
@ -226,17 +227,38 @@ bool CUDABlas::Init() {
|
||||
return false;
|
||||
}
|
||||
|
||||
#if CUDA_VERSION >= 11000
|
||||
ret = cublasLtCreate(&blasLt_);
|
||||
if (ret != CUBLAS_STATUS_SUCCESS) {
|
||||
LOG(ERROR) << "failed to create cublasLt handle: " << ToString(ret);
|
||||
return false;
|
||||
}
|
||||
#endif // CUDA_VERSION >= 11000
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
CUDABlas::CUDABlas(gpu::GpuExecutor *parent)
|
||||
: parent_(CHECK_NOTNULL(parent)), blas_(nullptr) {}
|
||||
CUDABlas::CUDABlas(gpu::GpuExecutor* parent)
|
||||
: parent_(CHECK_NOTNULL(parent)),
|
||||
blas_(nullptr)
|
||||
#if CUDA_VERSION >= 11000
|
||||
,
|
||||
blasLt_(nullptr)
|
||||
#endif
|
||||
{
|
||||
}
|
||||
|
||||
// Destroys the cuBLAS (and, on CUDA 11+, cuBLASLt) handles. The executor's
// CUDA context must be active while the library handles are destroyed, hence
// the ScopedActivateExecutorContext around each call.
CUDABlas::~CUDABlas() {
  if (blas_ != nullptr) {
    gpu::ScopedActivateExecutorContext sac{parent_};
    cublasDestroy(blas_);
  }
#if CUDA_VERSION >= 11000
  if (blasLt_ != nullptr) {
    gpu::ScopedActivateExecutorContext sac{parent_};
    cublasLtDestroy(blasLt_);
  }
#endif
}
|
||||
|
||||
bool CUDABlas::SetStream(Stream *stream) {
|
||||
@ -253,6 +275,13 @@ bool CUDABlas::SetStream(Stream *stream) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns the raw cudaStream_t underlying `stream`. Activates the executor's
// context first so the returned stream is valid to use immediately.
cudaStream_t CUDABlas::CUDAStream(Stream* stream) {
  CHECK(stream != nullptr);
  CHECK(AsGpuStreamValue(stream) != nullptr);
  gpu::ScopedActivateExecutorContext sac{parent_};
  return AsGpuStreamValue(stream);
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Helper functions transforming blas arguments into cuBLAS arguments.
|
||||
@ -381,6 +410,82 @@ cudaDataType_t CUDAComputationType(blas::ComputationType ty) {
|
||||
return CUDA_C_32F;
|
||||
case blas::ComputationType::kComplexF64:
|
||||
return CUDA_C_64F;
|
||||
case blas::ComputationType::kF32FastTF32: // fall-through
|
||||
case blas::ComputationType::kF32FastBF16:
|
||||
// These cases are currently only supported in the blasLt routines, which
|
||||
// use CUBLASComputationType() instead.
|
||||
LOG(FATAL) << "Invalid value of blas::ComputationType.";
|
||||
}
|
||||
}
|
||||
|
||||
#if CUDA_VERSION >= 11000
|
||||
cublasComputeType_t CUBLASComputationType(blas::ComputationType ty) {
|
||||
switch (ty) {
|
||||
case blas::ComputationType::kF16:
|
||||
return CUBLAS_COMPUTE_16F;
|
||||
case blas::ComputationType::kF32: // fall-through
|
||||
case blas::ComputationType::kComplexF32:
|
||||
return CUBLAS_COMPUTE_32F;
|
||||
case blas::ComputationType::kF64: // fall-through
|
||||
case blas::ComputationType::kComplexF64:
|
||||
return CUBLAS_COMPUTE_64F;
|
||||
case blas::ComputationType::kI32:
|
||||
return CUBLAS_COMPUTE_32I;
|
||||
case blas::ComputationType::kF32FastTF32:
|
||||
return CUBLAS_COMPUTE_32F_FAST_TF32;
|
||||
case blas::ComputationType::kF32FastBF16:
|
||||
return CUBLAS_COMPUTE_32F_FAST_16BF;
|
||||
}
|
||||
}
|
||||
#endif // CUDA_VERSION >= 11000
|
||||
|
||||
// Returns the data type expected for the alpha/beta scaling factors, given
// the output data type and the computation type. Complex outputs require a
// complex scale of matching precision; all f32-family compute modes
// (including TF32/BF16 fast paths) use an f32-precision scale.
blas::DataType GetScaleType(blas::DataType data_type,
                            blas::ComputationType compute_type) {
  const bool complex_output = data_type == blas::DataType::kComplexF32 ||
                              data_type == blas::DataType::kComplexF64;
  switch (compute_type) {
    case blas::ComputationType::kF16:
      return blas::DataType::kF16;
    case blas::ComputationType::kI32:
      return blas::DataType::kI32;
    case blas::ComputationType::kF64:  // fall-through
    case blas::ComputationType::kComplexF64:
      return complex_output ? blas::DataType::kComplexF64
                            : blas::DataType::kF64;
    default:
      // kF32, kComplexF32, kF32FastTF32 and kF32FastBF16 all scale in f32.
      return complex_output ? blas::DataType::kComplexF32
                            : blas::DataType::kF32;
  }
}
|
||||
|
||||
#if CUDA_VERSION >= 11000
|
||||
cublasLtPointerMode_t CUBLASPointerMode(blas::PointerMode pointer_mode) {
|
||||
switch (pointer_mode) {
|
||||
case blas::PointerMode::kHost:
|
||||
return CUBLASLT_POINTER_MODE_HOST;
|
||||
case blas::PointerMode::kDevice:
|
||||
return CUBLASLT_POINTER_MODE_DEVICE;
|
||||
}
|
||||
}
|
||||
#endif // CUDA_VERSION >= 11000
|
||||
|
||||
// Maps a blas::DataType onto the corresponding cudaDataType_t enum value.
// The switch is exhaustive over the blas::DataType enum, so there is no
// default case (the compiler flags any newly added enumerator).
cudaDataType_t GetCUDADataType(blas::DataType ty) {
  switch (ty) {
    case blas::DataType::kF16:
      return CUDA_R_16F;
    case blas::DataType::kF32:
      return CUDA_R_32F;
    case blas::DataType::kF64:
      return CUDA_R_64F;
    case blas::DataType::kI8:
      return CUDA_R_8I;
    case blas::DataType::kI32:
      return CUDA_R_32I;
    case blas::DataType::kComplexF32:
      return CUDA_C_32F;
    case blas::DataType::kComplexF64:
      return CUDA_C_64F;
  }
}
|
||||
} // namespace
|
||||
@ -2912,6 +3017,577 @@ bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
|
||||
GpuComplex(GpuMemoryMutable(b)), ldb);
|
||||
}
|
||||
|
||||
// We only use cublasLt from CUDA 11.0 onward.
|
||||
#if CUDA_VERSION >= 11000
|
||||
|
||||
namespace {
|
||||
|
||||
// Sets an attribute on a cublasLt matrix-layout descriptor. Returns false
// (after VLOG-ing the failure) instead of propagating the cublas status, so
// callers can chain checks with &&/||.
template <typename T>
inline bool SetCublasLtAttr(cublasLtMatrixLayout_t handle,
                            cublasLtMatrixLayoutAttribute_t attr,
                            const T& value) {
  cublasStatus_t status =
      cublasLtMatrixLayoutSetAttribute(handle, attr, &value, sizeof(T));
  if (status != CUBLAS_STATUS_SUCCESS) {
    VLOG(2) << "cublasLtMatrixLayoutSetAttribute(attr=" << attr
            << ", value=" << value << ") failed: " << ToString(status);
    return false;
  }
  return true;
}

// Sets a configuration attribute on a cublasLt matmul algorithm object.
template <typename T>
inline bool SetCublasLtAttr(cublasLtMatmulAlgo_t* handle,
                            cublasLtMatmulAlgoConfigAttributes_t attr,
                            const T& value) {
  cublasStatus_t status =
      cublasLtMatmulAlgoConfigSetAttribute(handle, attr, &value, sizeof(T));
  if (status != CUBLAS_STATUS_SUCCESS) {
    VLOG(2) << "cublasLtMatmulAlgoConfigSetAttribute(attr=" << attr
            << ", value=" << value << ") failed: " << ToString(status);
    return false;
  }
  return true;
}

// Sets an attribute on a cublasLt matmul-preference object.
template <typename T>
inline bool SetCublasLtAttr(cublasLtMatmulPreference_t handle,
                            cublasLtMatmulPreferenceAttributes_t attr,
                            const T& value) {
  cublasStatus_t status =
      cublasLtMatmulPreferenceSetAttribute(handle, attr, &value, sizeof(value));
  if (status != CUBLAS_STATUS_SUCCESS) {
    VLOG(2) << "cublasLtMatmulPreferenceSetAttribute(attr=" << attr
            << ", value=" << value << ") failed: " << ToString(status);
    return false;
  }
  return true;
}

// Reads a config attribute from an algorithm object. Returns true only if
// the query succeeded AND wrote exactly sizeof(T) bytes. The const_cast is
// needed because the cublasLt getter takes a non-const handle even though it
// does not modify the algorithm.
template <typename T>
inline bool GetCublasLtAttr(const cublasLtMatmulAlgo_t* handle,
                            cublasLtMatmulAlgoConfigAttributes_t attr,
                            T* value) {
  auto mutable_handle = const_cast<cublasLtMatmulAlgo_t*>(handle);
  size_t bytes_written = 0;
  return cublasLtMatmulAlgoConfigGetAttribute(mutable_handle, attr, value,
                                              sizeof(T), &bytes_written) ==
             CUBLAS_STATUS_SUCCESS &&
         bytes_written == sizeof(T);
}

// Sets an attribute on a cublasLt matmul-operation descriptor.
template <typename T>
inline bool SetCublasLtAttr(cublasLtMatmulDesc_t handle,
                            cublasLtMatmulDescAttributes_t attr,
                            const T& value) {
  cublasStatus_t status =
      cublasLtMatmulDescSetAttribute(handle, attr, &value, sizeof(value));
  if (status != CUBLAS_STATUS_SUCCESS) {
    VLOG(2) << "cublasLtMatmulDescSetAttribute(attr=" << attr
            << ", value=" << value << ") failed: " << ToString(status);
    return false;
  }
  return true;
}
|
||||
|
||||
// Custom deleters + unique_ptr aliases giving RAII ownership of the opaque
// cublasLt descriptor handles (which are pointers to incomplete types).
struct MatmulDescDestroyer {
  void operator()(cublasLtMatmulDesc_t matmul_desc) const {
    cublasLtMatmulDescDestroy(matmul_desc);
  }
};
struct LayoutDestroyer {
  void operator()(cublasLtMatrixLayout_t layout) const {
    cublasLtMatrixLayoutDestroy(layout);
  }
};
struct MatmulPreferenceDestroyer {
  void operator()(cublasLtMatmulPreference_t matmul_pref) const {
    cublasLtMatmulPreferenceDestroy(matmul_pref);
  }
};
// Owning handle to a cublasLt matmul-operation descriptor.
using UniqueOpDesc =
    std::unique_ptr<std::remove_pointer<cublasLtMatmulDesc_t>::type,
                    MatmulDescDestroyer>;
// Owning handle to a cublasLt matrix-layout descriptor.
using UniqueLayoutDesc =
    std::unique_ptr<std::remove_pointer<cublasLtMatrixLayout_t>::type,
                    LayoutDestroyer>;
// Owning handle to a cublasLt matmul-preference object.
using UniqueMatmulPreference =
    std::unique_ptr<std::remove_pointer<cublasLtMatmulPreference_t>::type,
                    MatmulPreferenceDestroyer>;
|
||||
|
||||
// Creates a cublasLt matmul-operation descriptor configured with the given
// compute/scale types, pointer mode and transpose flags. Returns null on any
// failure (the partially configured descriptor is released by UniqueOpDesc).
UniqueOpDesc CreateCublasLtOperationDesc(
    blas::ComputationType computation_type, blas::DataType scale_type,
    blas::PointerMode pointer_mode, blas::Transpose transa,
    blas::Transpose transb) {
  cublasOperation_t cuda_transa = CUDABlasTranspose(transa);
  cublasOperation_t cuda_transb = CUDABlasTranspose(transb);
  cublasLtMatmulDesc_t desc;
  cublasComputeType_t cublas_compute_type =
      CUBLASComputationType(computation_type);
  cudaDataType_t cuda_scale_type = GetCUDADataType(scale_type);
  cublasStatus_t status =
      cublasLtMatmulDescCreate(&desc, cublas_compute_type, cuda_scale_type);
  if (status != CUBLAS_STATUS_SUCCESS) {
    VLOG(2) << "cublasLtMatmulDescCreate(computation_type=" << computation_type
            << ") failed: " << ToString(status);
    return nullptr;
  }
  // Take ownership before configuring, so the descriptor is destroyed even
  // if one of the attribute-setting calls below fails.
  UniqueOpDesc unique_desc(desc);
  if (!SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_POINTER_MODE,
                       CUBLASPointerMode(pointer_mode)) ||
      !SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSA, cuda_transa) ||
      !SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSB, cuda_transb)) {
    return nullptr;
  }
  return unique_desc;
}
|
||||
|
||||
// Creates a cublasLt matrix-layout descriptor for a (possibly strided-
// batched) matrix. `stride` is the element offset between consecutive
// batch matrices. Returns null on failure.
UniqueLayoutDesc CreateCublasLtLayoutDesc(blas::DataType data_type, uint64 rows,
                                          uint64 cols, int64 ld, int64 stride,
                                          int batch_count) {
  cublasLtMatrixLayout_t desc;
  cublasStatus_t status = cublasLtMatrixLayoutCreate(
      &desc, GetCUDADataType(data_type), rows, cols, ld);
  if (status != CUBLAS_STATUS_SUCCESS) {
    VLOG(2) << "cublasLtMatrixLayoutCreate failed: " << ToString(status);
    return nullptr;
  }
  // Own the handle before configuring so it is released on any failure below.
  UniqueLayoutDesc unique_desc(desc);
  if (!SetCublasLtAttr(desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_count) ||
      !SetCublasLtAttr(desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
                       stride)) {
    return nullptr;
  }
  return unique_desc;
}
|
||||
|
||||
// Creates a cublasLt matmul-preference object that caps the workspace the
// heuristic may assume at `max_workspace_bytes`. Returns null on failure.
UniqueMatmulPreference CreateCublasLtMatmulPreference(
    size_t max_workspace_bytes) {
  cublasLtMatmulPreference_t preference;
  cublasStatus_t status = cublasLtMatmulPreferenceCreate(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) {
    VLOG(2) << "cublasLtMatmulPreferenceCreate failed: " << ToString(status);
    return nullptr;
  }
  UniqueMatmulPreference unique_preference(preference);
  if (!SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
                       max_workspace_bytes)) {
    return nullptr;
  }
  return unique_preference;
}
|
||||
|
||||
// Helper function to allocate workspace.
|
||||
port::Status AllocateWorkspace(void** workspace,
|
||||
ScratchAllocator* scratch_allocator,
|
||||
size_t num_bytes) {
|
||||
SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace_bytes,
|
||||
scratch_allocator->AllocateBytes(num_bytes));
|
||||
*workspace = (void*)GpuMemoryMutable(&workspace_bytes);
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
||||
// Maps an element type to the default blas::ComputationType used when the
// caller does not specify one explicitly. Only the listed specializations
// are valid; using an unlisted type fails to link (primary is undefined).
template <typename T>
blas::ComputationType ToComputationType();
template <>
blas::ComputationType ToComputationType<Eigen::half>() {
  return blas::ComputationType::kF16;
}
template <>
blas::ComputationType ToComputationType<float>() {
  return blas::ComputationType::kF32;
}
template <>
blas::ComputationType ToComputationType<double>() {
  return blas::ComputationType::kF64;
}
template <>
blas::ComputationType ToComputationType<std::complex<float>>() {
  return blas::ComputationType::kComplexF32;
}
template <>
blas::ComputationType ToComputationType<std::complex<double>>() {
  return blas::ComputationType::kComplexF64;
}
|
||||
|
||||
class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
|
||||
public:
|
||||
CUDABlasLtMatmulPlan(blas::DataType ab_type, blas::DataType cd_type,
|
||||
blas::ComputationType compute_type,
|
||||
blas::PointerMode pointer_mode, blas::Transpose transa,
|
||||
blas::Transpose transb, uint64 m, uint64 n, uint64 k,
|
||||
int batch_count, int64 lda, int64 stride_a, int64 ldb,
|
||||
int64 stride_b, int64 ldc, int64 stride_c, int64 ldd,
|
||||
int64 stride_d);
|
||||
|
||||
cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); }
|
||||
cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); }
|
||||
cublasLtMatrixLayout_t b_desc() const { return b_desc_.get(); }
|
||||
cublasLtMatrixLayout_t c_desc() const { return c_desc_.get(); }
|
||||
cublasLtMatrixLayout_t d_desc() const { return d_desc_.get(); }
|
||||
bool ok() { return op_desc_ && a_desc_ && b_desc_ && c_desc_ && d_desc_; }
|
||||
|
||||
blas::DataType ab_type() const { return ab_type_; }
|
||||
blas::DataType cd_type() const { return cd_type_; }
|
||||
blas::DataType scale_type() const { return scale_type_; }
|
||||
blas::PointerMode pointer_mode() const { return pointer_mode_; }
|
||||
|
||||
private:
|
||||
UniqueOpDesc op_desc_;
|
||||
UniqueLayoutDesc a_desc_;
|
||||
UniqueLayoutDesc b_desc_;
|
||||
UniqueLayoutDesc c_desc_;
|
||||
UniqueLayoutDesc d_desc_;
|
||||
blas::DataType ab_type_;
|
||||
blas::DataType cd_type_;
|
||||
blas::DataType scale_type_;
|
||||
blas::PointerMode pointer_mode_;
|
||||
};
|
||||
|
||||
// Builds the operation and layout descriptors. A/B layouts are created in
// the body (not the init list) because their logical dimensions depend on
// the transpose flags. Any descriptor may end up null; callers must check
// ok() before using the plan.
CUDABlasLtMatmulPlan::CUDABlasLtMatmulPlan(
    blas::DataType ab_type, blas::DataType cd_type,
    blas::ComputationType computation_type, blas::PointerMode pointer_mode,
    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
    uint64 k, int batch_count, int64 lda, int64 stride_a, int64 ldb,
    int64 stride_b, int64 ldc, int64 stride_c, int64 ldd, int64 stride_d)
    : op_desc_(CreateCublasLtOperationDesc(
          computation_type, GetScaleType(cd_type, computation_type),
          pointer_mode, transa, transb)),
      a_desc_(nullptr),
      b_desc_(nullptr),
      c_desc_(
          CreateCublasLtLayoutDesc(cd_type, m, n, ldc, stride_c, batch_count)),
      d_desc_(
          CreateCublasLtLayoutDesc(cd_type, m, n, ldd, stride_d, batch_count)),
      ab_type_(ab_type),
      cd_type_(cd_type),
      scale_type_(GetScaleType(cd_type, computation_type)),
      pointer_mode_(pointer_mode) {
  // A is m x k (k x m when transposed); B is k x n (n x k when transposed).
  uint64 rows_a = transa == blas::Transpose::kNoTranspose ? m : k;
  uint64 cols_a = transa == blas::Transpose::kNoTranspose ? k : m;
  uint64 rows_b = transb == blas::Transpose::kNoTranspose ? k : n;
  uint64 cols_b = transb == blas::Transpose::kNoTranspose ? n : k;
  a_desc_ = CreateCublasLtLayoutDesc(ab_type, rows_a, cols_a, lda, stride_a,
                                     batch_count);
  b_desc_ = CreateCublasLtLayoutDesc(ab_type, rows_b, cols_b, ldb, stride_b,
                                     batch_count);
}
|
||||
|
||||
class CUDABlasLtMatmulAlgorithm final : public blas::IBlasLtMatmulAlgorithm {
|
||||
public:
|
||||
CUDABlasLtMatmulAlgorithm(blas::AlgorithmType index,
|
||||
cublasLtMatmulAlgo_t algo, size_t workspace_size)
|
||||
: index_(index), algo_(algo), workspace_size_(workspace_size) {}
|
||||
|
||||
blas::AlgorithmType index() const override { return index_; }
|
||||
|
||||
size_t workspace_size() const override { return workspace_size_; }
|
||||
|
||||
const cublasLtMatmulAlgo_t* algo() const { return &algo_; }
|
||||
|
||||
int algo_id() const {
|
||||
int id;
|
||||
GetCublasLtAttr(&algo_, CUBLASLT_ALGO_CONFIG_ID, &id);
|
||||
return id;
|
||||
}
|
||||
|
||||
private:
|
||||
blas::AlgorithmType index_;
|
||||
cublasLtMatmulAlgo_t algo_;
|
||||
size_t workspace_size_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
#endif // CUDA_VERSION >= 11000
|
||||
|
||||
// Creates a (strided-batched) blasLt matmul plan. Note that D reuses C's
// leading dimension and stride, i.e. the result is written with C's layout.
// Returns null if any cublasLt descriptor fails to build, or unconditionally
// when compiled against CUDA < 11.0 (cublasLt is not used there).
std::unique_ptr<blas::IBlasLtMatmulPlan>
CUDABlas::CreateBlasLtMatmulPlanStridedBatched(
    blas::DataType ab_type, blas::DataType cd_type,
    blas::ComputationType computation_type, blas::PointerMode pointer_mode,
    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
    uint64 k, int batch_count, int64 lda, int64 stride_a, int64 ldb,
    int64 stride_b, int64 ldc, int64 stride_c) {
#if CUDA_VERSION >= 11000
  auto result = std::make_unique<CUDABlasLtMatmulPlan>(
      ab_type, cd_type, computation_type, pointer_mode, transa, transb, m, n, k,
      batch_count, lda, stride_a, ldb, stride_b, ldc, stride_c, ldc, stride_c);
  if (!result->ok()) {
    result.reset();
  }
  return result;
#else
  return nullptr;
#endif
}
|
||||
|
||||
// Queries cublasLt for up to `max_algorithm_count` heuristically ranked
// algorithms applicable to `plan`, assuming at most `max_workspace_size`
// bytes of scratch. Appends the usable ones to *out_algorithms (indexed by
// their position in the heuristic ranking). Returns false on query failure
// or when built against CUDA < 11.0.
bool CUDABlas::GetBlasLtMatmulAlgorithms(
    const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size,
    int max_algorithm_count,
    std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>*
        out_algorithms) {
#if CUDA_VERSION >= 11000
  UniqueMatmulPreference preference =
      CreateCublasLtMatmulPreference(max_workspace_size);
  if (!preference) return false;

  std::vector<cublasLtMatmulHeuristicResult_t> results(max_algorithm_count);
  // The handle access below is serialized by mu_; the lock is scoped so that
  // the result post-processing happens outside the critical section.
  {
    absl::MutexLock lock(&mu_);

    CHECK(blasLt_ != nullptr);

    gpu::ScopedActivateExecutorContext sac{parent_};

    int found_algorithm_count = 0;
    const auto& cuda_plan = *static_cast<const CUDABlasLtMatmulPlan*>(plan);
    cublasStatus_t status = cublasLtMatmulAlgoGetHeuristic(
        blasLt_, cuda_plan.op_desc(), cuda_plan.a_desc(), cuda_plan.b_desc(),
        cuda_plan.c_desc(), cuda_plan.d_desc(), preference.get(),
        max_algorithm_count, results.data(), &found_algorithm_count);
    if (status != CUBLAS_STATUS_SUCCESS) {
      VLOG(2) << "cublasLtMatmulAlgoGetHeuristic failed: " << ToString(status);
      return false;
    }
    results.resize(found_algorithm_count);
  }

  for (size_t i = 0; i < results.size(); ++i) {
    const auto& result = results[i];
    if (result.state != CUBLAS_STATUS_SUCCESS) continue;  // Skip failed algos
    out_algorithms->emplace_back(std::make_unique<CUDABlasLtMatmulAlgorithm>(
        i, result.algo, result.workspaceSize));
  }
  return true;
#else  // if CUDA_VERSION < 11000
  return false;
#endif
}
|
||||
|
||||
#if CUDA_VERSION >= 11000
|
||||
template <typename ABType, typename CDType, typename ScaleType>
|
||||
bool CUDABlas::DoBlasLtMatmulInternalImpl(
|
||||
Stream* stream, bool err_on_failure, const blas::IBlasLtMatmulPlan* plan,
|
||||
const HostOrDeviceScalar<ScaleType>& alpha, const ABType* a,
|
||||
const ABType* b, const HostOrDeviceScalar<ScaleType>& beta, const CDType* c,
|
||||
CDType* d, ScratchAllocator* scratch_allocator,
|
||||
const blas::IBlasLtMatmulAlgorithm* algorithm) {
|
||||
const auto& cuda_plan = *static_cast<const CUDABlasLtMatmulPlan*>(plan);
|
||||
const auto& cuda_algo =
|
||||
*static_cast<const CUDABlasLtMatmulAlgorithm*>(algorithm);
|
||||
|
||||
if (cuda_plan.ab_type() != blas::ToDataType<ABType>::value) {
|
||||
VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong ab_type: "
|
||||
"expected "
|
||||
<< blas::ToDataType<ABType>::value << ", got "
|
||||
<< cuda_plan.ab_type();
|
||||
return false;
|
||||
}
|
||||
if (cuda_plan.cd_type() != blas::ToDataType<CDType>::value) {
|
||||
VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong cd_type: "
|
||||
"expected "
|
||||
<< blas::ToDataType<CDType>::value << ", got "
|
||||
<< cuda_plan.cd_type();
|
||||
return false;
|
||||
}
|
||||
if (cuda_plan.scale_type() != blas::ToDataType<ScaleType>::value) {
|
||||
VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong "
|
||||
"scale_type: expected "
|
||||
<< blas::ToDataType<ScaleType>::value << ", got "
|
||||
<< cuda_plan.cd_type();
|
||||
return false;
|
||||
}
|
||||
if (alpha.is_pointer() != beta.is_pointer()) {
|
||||
VLOG(2) << "DoBlasLtMatmul returning false because one of `alpha` "
|
||||
"and `beta` is a pointer, but the other is not.";
|
||||
return false;
|
||||
}
|
||||
bool is_pointer_mode_host = !alpha.is_pointer();
|
||||
if ((cuda_plan.pointer_mode() == blas::PointerMode::kHost) !=
|
||||
is_pointer_mode_host) {
|
||||
VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong "
|
||||
"pointer_mode for the given alpha/beta.";
|
||||
return false;
|
||||
}
|
||||
const ScaleType* alpha_ptr =
|
||||
alpha.is_pointer() ? GpuMemory(alpha.pointer()) : &alpha.value();
|
||||
const ScaleType* beta_ptr =
|
||||
beta.is_pointer() ? GpuMemory(beta.pointer()) : &beta.value();
|
||||
|
||||
void* workspace = nullptr;
|
||||
if (cuda_algo.workspace_size()) {
|
||||
port::Status allocation_status = AllocateWorkspace(
|
||||
&workspace, scratch_allocator, cuda_algo.workspace_size());
|
||||
if (!allocation_status.ok()) {
|
||||
if (err_on_failure || VLOG_IS_ON(3)) {
|
||||
LOG(ERROR)
|
||||
<< "Failed to allocate workspace for cublasLtMatmul algo with id: "
|
||||
<< cuda_algo.algo_id() << " requiring "
|
||||
<< cuda_algo.workspace_size() << " bytes of workspace";
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
cudaStream_t cuda_stream = CUDAStream(stream);
|
||||
|
||||
absl::MutexLock lock(&mu_);
|
||||
|
||||
CHECK(blasLt_ != nullptr);
|
||||
|
||||
gpu::ScopedActivateExecutorContext sac{parent_};
|
||||
|
||||
cublasStatus_t ret = cublasLtMatmul(
|
||||
blasLt_, cuda_plan.op_desc(), alpha_ptr, a, cuda_plan.a_desc(), b,
|
||||
cuda_plan.b_desc(), beta_ptr, c, cuda_plan.c_desc(), d,
|
||||
cuda_plan.d_desc(), cuda_algo.algo(), workspace,
|
||||
cuda_algo.workspace_size(), cuda_stream);
|
||||
if (ret != CUBLAS_STATUS_SUCCESS) {
|
||||
if (err_on_failure || VLOG_IS_ON(3)) {
|
||||
LOG(ERROR) << "failed to run cublasLtMatmul routine: " << ToString(ret);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
#endif // CUDA_VERSION >= 11000
|
||||
|
||||
// Typed wrapper around DoBlasLtMatmulInternalImpl that unwraps the device
// memory arguments and optionally wraps the launch in a GPU timer when the
// caller requests profiling via `output_profile_result`.
template <typename ABType, typename CDType, typename ScaleType>
bool CUDABlas::DoBlasLtMatmulInternal(
    Stream* stream, const blas::IBlasLtMatmulPlan* plan,
    const HostOrDeviceScalar<ScaleType>& alpha, const DeviceMemory<ABType>& a,
    const DeviceMemory<ABType>& b, const HostOrDeviceScalar<ScaleType>& beta,
    const DeviceMemory<CDType>& c, DeviceMemory<CDType>* d,
    ScratchAllocator* scratch_allocator,
    const blas::IBlasLtMatmulAlgorithm* algorithm,
    blas::ProfileResult* output_profile_result) {
#if CUDA_VERSION >= 11000
  std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
  if (output_profile_result) {
    timer.reset(new GpuTimer(parent_));
    if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
      return false;
    }
  }

  // When profiling, a failure is unexpected and worth logging loudly.
  bool err_on_failure = timer != nullptr;
  bool result = DoBlasLtMatmulInternalImpl(
      stream, err_on_failure, plan, alpha, GpuMemory(a), GpuMemory(b), beta,
      GpuMemory(c), GpuMemoryMutable(d), scratch_allocator, algorithm);

  if (timer && result) {
    // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
    // state.
    if (!timer->Stop(AsGpuStream(stream))) {
      return false;
    }
    output_profile_result->set_is_valid(true);
    output_profile_result->set_algorithm(algorithm->index());
    output_profile_result->set_elapsed_time_in_ms(
        timer->GetElapsedMilliseconds());
  }
  return result;
#else  // if CUDA_VERSION < 11000
  return false;
#endif
}
|
||||
|
||||
// int8 inputs with int32 accumulation/output. `c` serves as both the C input
// and the D output of cublasLtMatmul (in-place update).
bool CUDABlas::DoBlasLtMatmul(
    Stream* stream, const blas::IBlasLtMatmulPlan* plan,
    const HostOrDeviceScalar<int32>& alpha, const DeviceMemory<int8>& a,
    const DeviceMemory<int8>& b, const HostOrDeviceScalar<int32>& beta,
    DeviceMemory<int32>* c, ScratchAllocator* scratch_allocator,
    const blas::IBlasLtMatmulAlgorithm* algorithm,
    blas::ProfileResult* output_profile_result) {
  return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
                                scratch_allocator, algorithm,
                                output_profile_result);
}
|
||||
|
||||
// fp16 overload. When the plan's scale type is f32 (e.g. f16 data with f32
// computation), the half alpha/beta must be widened to float before the
// internal call; device-resident scalars cannot be widened from here, so
// that combination is rejected.
bool CUDABlas::DoBlasLtMatmul(Stream* stream,
                              const blas::IBlasLtMatmulPlan* plan,
                              const HostOrDeviceScalar<Eigen::half>& alpha,
                              const DeviceMemory<Eigen::half>& a,
                              const DeviceMemory<Eigen::half>& b,
                              const HostOrDeviceScalar<Eigen::half>& beta,
                              DeviceMemory<Eigen::half>* c,
                              ScratchAllocator* scratch_allocator,
                              const blas::IBlasLtMatmulAlgorithm* algorithm,
                              blas::ProfileResult* output_profile_result) {
#if CUDA_VERSION >= 11000
  const auto& cuda_plan = *static_cast<const CUDABlasLtMatmulPlan*>(plan);
  if (cuda_plan.scale_type() == blas::DataType::kF32) {
    // F32* computation types require F32 alpha/beta type, so we must cast them.
    if (alpha.is_pointer() || beta.is_pointer()) {
      // We cannot easily convert a pointer to f16 memory to a pointer to f32
      // memory from here, so we don't support this for now.
      return false;
    }
    HostOrDeviceScalar<float> float_alpha(static_cast<float>(alpha.value()));
    HostOrDeviceScalar<float> float_beta(static_cast<float>(beta.value()));
    return DoBlasLtMatmulInternal(stream, plan, float_alpha, a, b, float_beta,
                                  *c, c, scratch_allocator, algorithm,
                                  output_profile_result);
  }
  return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
                                scratch_allocator, algorithm,
                                output_profile_result);
#else  // if CUDA_VERSION < 11000
  return false;
#endif
}
|
||||
|
||||
// Single-precision BlasLt matmul: delegates to the shared implementation with
// `c` aliased as both the C input and D output operands.
bool CUDABlas::DoBlasLtMatmul(
    Stream* stream, const blas::IBlasLtMatmulPlan* plan,
    const HostOrDeviceScalar<float>& alpha, const DeviceMemory<float>& a,
    const DeviceMemory<float>& b, const HostOrDeviceScalar<float>& beta,
    DeviceMemory<float>* c, ScratchAllocator* scratch_allocator,
    const blas::IBlasLtMatmulAlgorithm* algorithm,
    blas::ProfileResult* output_profile_result) {
  const DeviceMemory<float>& c_in = *c;
  return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, c_in, c,
                                scratch_allocator, algorithm,
                                output_profile_result);
}
|
||||
|
||||
// Double-precision BlasLt matmul: delegates to the shared implementation with
// `c` aliased as both the C input and D output operands.
bool CUDABlas::DoBlasLtMatmul(
    Stream* stream, const blas::IBlasLtMatmulPlan* plan,
    const HostOrDeviceScalar<double>& alpha, const DeviceMemory<double>& a,
    const DeviceMemory<double>& b, const HostOrDeviceScalar<double>& beta,
    DeviceMemory<double>* c, ScratchAllocator* scratch_allocator,
    const blas::IBlasLtMatmulAlgorithm* algorithm,
    blas::ProfileResult* output_profile_result) {
  const DeviceMemory<double>& c_in = *c;
  return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, c_in, c,
                                scratch_allocator, algorithm,
                                output_profile_result);
}
|
||||
|
||||
// complex<float> BlasLt matmul: delegates to the shared implementation with
// `c` aliased as both the C input and D output operands.
bool CUDABlas::DoBlasLtMatmul(
    Stream* stream, const blas::IBlasLtMatmulPlan* plan,
    const HostOrDeviceScalar<std::complex<float>>& alpha,
    const DeviceMemory<std::complex<float>>& a,
    const DeviceMemory<std::complex<float>>& b,
    const HostOrDeviceScalar<std::complex<float>>& beta,
    DeviceMemory<std::complex<float>>* c, ScratchAllocator* scratch_allocator,
    const blas::IBlasLtMatmulAlgorithm* algorithm,
    blas::ProfileResult* output_profile_result) {
  const DeviceMemory<std::complex<float>>& c_in = *c;
  return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, c_in, c,
                                scratch_allocator, algorithm,
                                output_profile_result);
}
|
||||
|
||||
// complex<double> BlasLt matmul: delegates to the shared implementation with
// `c` aliased as both the C input and D output operands.
bool CUDABlas::DoBlasLtMatmul(
    Stream* stream, const blas::IBlasLtMatmulPlan* plan,
    const HostOrDeviceScalar<std::complex<double>>& alpha,
    const DeviceMemory<std::complex<double>>& a,
    const DeviceMemory<std::complex<double>>& b,
    const HostOrDeviceScalar<std::complex<double>>& beta,
    DeviceMemory<std::complex<double>>* c, ScratchAllocator* scratch_allocator,
    const blas::IBlasLtMatmulAlgorithm* algorithm,
    blas::ProfileResult* output_profile_result) {
  const DeviceMemory<std::complex<double>>& c_in = *c;
  return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, c_in, c,
                                scratch_allocator, algorithm,
                                output_profile_result);
}
|
||||
|
||||
port::Status CUDABlas::GetVersion(std::string *version) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
|
||||
|
@ -22,6 +22,8 @@ limitations under the License.
|
||||
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "third_party/gpus/cuda/include/cublas_v2.h"
|
||||
#include "third_party/gpus/cuda/include/cublasLt.h"
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/stream_executor/blas.h"
|
||||
#include "tensorflow/stream_executor/host_or_device_scalar.h"
|
||||
@ -71,6 +73,9 @@ class CUDABlas : public blas::BlasSupport {
|
||||
// invoked before calling into cuBLAS.
|
||||
bool SetStream(Stream *stream) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
// Returns the underlying CUDA stream.
|
||||
cudaStream_t CUDAStream(Stream* stream);
|
||||
|
||||
// A helper function that calls the real cuBLAS function together with error
|
||||
// handling.
|
||||
//
|
||||
@ -134,6 +139,26 @@ class CUDABlas : public blas::BlasSupport {
|
||||
const T &beta, DeviceMemory<T> *y, int incy,
|
||||
blas::ProfileResult *output_profile_result);
|
||||
|
||||
// Helper function for implementing DoBlasLtMatmul.
|
||||
template <typename ABType, typename CDType, typename ScaleType>
|
||||
bool DoBlasLtMatmulInternal(
|
||||
Stream* stream, const blas::IBlasLtMatmulPlan* plan,
|
||||
const HostOrDeviceScalar<ScaleType>& alpha, const DeviceMemory<ABType>& a,
|
||||
const DeviceMemory<ABType>& b, const HostOrDeviceScalar<ScaleType>& beta,
|
||||
const DeviceMemory<CDType>& c, DeviceMemory<CDType>* d,
|
||||
ScratchAllocator* scratch_allocator,
|
||||
const blas::IBlasLtMatmulAlgorithm* algorithm,
|
||||
blas::ProfileResult* output_profile_result);
|
||||
|
||||
// Helper function for implementing DoBlasLtMatmulInternal.
|
||||
template <typename ABType, typename CDType, typename ScaleType>
|
||||
bool DoBlasLtMatmulInternalImpl(
|
||||
Stream* stream, bool err_on_failure, const blas::IBlasLtMatmulPlan* plan,
|
||||
const HostOrDeviceScalar<ScaleType>& alpha, const ABType* a,
|
||||
const ABType* b, const HostOrDeviceScalar<ScaleType>& beta,
|
||||
const CDType* c, CDType* d, ScratchAllocator* scratch_allocator,
|
||||
const blas::IBlasLtMatmulAlgorithm* algorithm);
|
||||
|
||||
// Guards the cuBLAS handle for this device.
|
||||
absl::Mutex mu_;
|
||||
|
||||
@ -144,6 +169,11 @@ class CUDABlas : public blas::BlasSupport {
|
||||
// cuBLAS library handle on the device.
|
||||
cublasHandle_t blas_ TF_GUARDED_BY(mu_);
|
||||
|
||||
#if CUDA_VERSION >= 11000
|
||||
// cuBLASLt library handle on the device.
|
||||
cublasLtHandle_t blasLt_ GUARDED_BY(mu_);
|
||||
#endif
|
||||
|
||||
SE_DISALLOW_COPY_AND_ASSIGN(CUDABlas);
|
||||
};
|
||||
|
||||
|
@ -23,6 +23,7 @@ namespace DsoLoader {
|
||||
port::Status TryDlopenCUDALibraries() {
|
||||
auto cudart_status = GetCudaRuntimeDsoHandle();
|
||||
auto cublas_status = GetCublasDsoHandle();
|
||||
auto cublaslt_status = GetCublasLtDsoHandle();
|
||||
auto cufft_status = GetCufftDsoHandle();
|
||||
auto curand_status = GetCurandDsoHandle();
|
||||
auto cusolver_status = GetCusolverDsoHandle();
|
||||
@ -31,7 +32,7 @@ port::Status TryDlopenCUDALibraries() {
|
||||
if (!cudart_status.status().ok() || !cublas_status.status().ok() ||
|
||||
!cufft_status.status().ok() || !curand_status.status().ok() ||
|
||||
!cusolver_status.status().ok() || !cusparse_status.status().ok() ||
|
||||
!cudnn_status.status().ok()) {
|
||||
!cudnn_status.status().ok() || !cublaslt_status.status().ok()) {
|
||||
return port::Status(port::error::INTERNAL,
|
||||
absl::StrCat("Cannot dlopen all CUDA libraries."));
|
||||
} else {
|
||||
|
@ -84,6 +84,10 @@ port::StatusOr<void*> GetCublasDsoHandle() {
|
||||
return GetDsoHandle("cublas", GetCublasVersion());
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetCublasLtDsoHandle() {
|
||||
return GetDsoHandle("cublasLt", GetCublasVersion());
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetCufftDsoHandle() {
|
||||
return GetDsoHandle("cufft", GetCufftVersion());
|
||||
}
|
||||
@ -160,6 +164,11 @@ port::StatusOr<void*> GetCublasDsoHandle() {
|
||||
return *result;
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetCublasLtDsoHandle() {
|
||||
static auto result = new auto(DsoLoader::GetCublasLtDsoHandle());
|
||||
return *result;
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetCurandDsoHandle() {
|
||||
static auto result = new auto(DsoLoader::GetCurandDsoHandle());
|
||||
return *result;
|
||||
|
@ -37,6 +37,7 @@ namespace DsoLoader {
|
||||
port::StatusOr<void*> GetCudaDriverDsoHandle();
|
||||
port::StatusOr<void*> GetCudaRuntimeDsoHandle();
|
||||
port::StatusOr<void*> GetCublasDsoHandle();
|
||||
port::StatusOr<void*> GetCublasLtDsoHandle();
|
||||
port::StatusOr<void*> GetCufftDsoHandle();
|
||||
port::StatusOr<void*> GetCurandDsoHandle();
|
||||
port::StatusOr<void*> GetCusolverDsoHandle();
|
||||
@ -72,6 +73,7 @@ namespace CachedDsoLoader {
|
||||
port::StatusOr<void*> GetCudaDriverDsoHandle();
|
||||
port::StatusOr<void*> GetCudaRuntimeDsoHandle();
|
||||
port::StatusOr<void*> GetCublasDsoHandle();
|
||||
port::StatusOr<void*> GetCublasLtDsoHandle();
|
||||
port::StatusOr<void*> GetCufftDsoHandle();
|
||||
port::StatusOr<void*> GetCurandDsoHandle();
|
||||
port::StatusOr<void*> GetCusolverDsoHandle();
|
||||
|
@ -4801,6 +4801,143 @@ Stream &Stream::ThenBlasGemmStridedBatched(
|
||||
c, ldc, stride_c, batch_count);
|
||||
}
|
||||
|
||||
// Enqueues an int8/int32 BlasLt matmul on this stream. See
// BlasSupport::DoBlasLtMatmul.
Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
                                 const HostOrDeviceScalar<int32>& alpha,
                                 const DeviceMemory<int8>& a,
                                 const DeviceMemory<int8>& b,
                                 const HostOrDeviceScalar<int32>& beta,
                                 DeviceMemory<int32>* c,
                                 ScratchAllocator* scratch_allocator,
                                 const blas::IBlasLtMatmulAlgorithm* algorithm,
                                 blas::ProfileResult* output_profile_result) {
  VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
            PARAM(c), PARAM(algorithm));

  // Dispatch through the profiling wrapper so timing is captured when
  // output_profile_result is requested.
  using Impl = ThenBlasWithProfileImpl<
      const blas::IBlasLtMatmulPlan*, const HostOrDeviceScalar<int32>&,
      const DeviceMemory<int8>&, const DeviceMemory<int8>&,
      const HostOrDeviceScalar<int32>&, DeviceMemory<int32>*, ScratchAllocator*,
      const blas::IBlasLtMatmulAlgorithm*>;
  Impl run;
  return run(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
             c, scratch_allocator, algorithm, output_profile_result);
}
|
||||
|
||||
// Enqueues a half-precision BlasLt matmul on this stream. See
// BlasSupport::DoBlasLtMatmul.
Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
                                 const HostOrDeviceScalar<Eigen::half>& alpha,
                                 const DeviceMemory<Eigen::half>& a,
                                 const DeviceMemory<Eigen::half>& b,
                                 const HostOrDeviceScalar<Eigen::half>& beta,
                                 DeviceMemory<Eigen::half>* c,
                                 ScratchAllocator* scratch_allocator,
                                 const blas::IBlasLtMatmulAlgorithm* algorithm,
                                 blas::ProfileResult* output_profile_result) {
  VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
            PARAM(c), PARAM(algorithm));

  using Impl = ThenBlasWithProfileImpl<
      const blas::IBlasLtMatmulPlan*, const HostOrDeviceScalar<Eigen::half>&,
      const DeviceMemory<Eigen::half>&, const DeviceMemory<Eigen::half>&,
      const HostOrDeviceScalar<Eigen::half>&, DeviceMemory<Eigen::half>*,
      ScratchAllocator*, const blas::IBlasLtMatmulAlgorithm*>;
  Impl run;
  return run(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
             c, scratch_allocator, algorithm, output_profile_result);
}
|
||||
|
||||
// Enqueues a single-precision BlasLt matmul on this stream. See
// BlasSupport::DoBlasLtMatmul.
Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
                                 const HostOrDeviceScalar<float>& alpha,
                                 const DeviceMemory<float>& a,
                                 const DeviceMemory<float>& b,
                                 const HostOrDeviceScalar<float>& beta,
                                 DeviceMemory<float>* c,
                                 ScratchAllocator* scratch_allocator,
                                 const blas::IBlasLtMatmulAlgorithm* algorithm,
                                 blas::ProfileResult* output_profile_result) {
  VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
            PARAM(c), PARAM(algorithm));

  using Impl = ThenBlasWithProfileImpl<
      const blas::IBlasLtMatmulPlan*, const HostOrDeviceScalar<float>&,
      const DeviceMemory<float>&, const DeviceMemory<float>&,
      const HostOrDeviceScalar<float>&, DeviceMemory<float>*, ScratchAllocator*,
      const blas::IBlasLtMatmulAlgorithm*>;
  Impl run;
  return run(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
             c, scratch_allocator, algorithm, output_profile_result);
}
|
||||
|
||||
// Enqueues a double-precision BlasLt matmul on this stream. See
// BlasSupport::DoBlasLtMatmul.
Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
                                 const HostOrDeviceScalar<double>& alpha,
                                 const DeviceMemory<double>& a,
                                 const DeviceMemory<double>& b,
                                 const HostOrDeviceScalar<double>& beta,
                                 DeviceMemory<double>* c,
                                 ScratchAllocator* scratch_allocator,
                                 const blas::IBlasLtMatmulAlgorithm* algorithm,
                                 blas::ProfileResult* output_profile_result) {
  VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
            PARAM(c), PARAM(algorithm));

  using Impl = ThenBlasWithProfileImpl<
      const blas::IBlasLtMatmulPlan*, const HostOrDeviceScalar<double>&,
      const DeviceMemory<double>&, const DeviceMemory<double>&,
      const HostOrDeviceScalar<double>&, DeviceMemory<double>*,
      ScratchAllocator*, const blas::IBlasLtMatmulAlgorithm*>;
  Impl run;
  return run(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
             c, scratch_allocator, algorithm, output_profile_result);
}
|
||||
|
||||
// Enqueues a complex<float> BlasLt matmul on this stream. See
// BlasSupport::DoBlasLtMatmul.
Stream& Stream::ThenBlasLtMatmul(
    const blas::IBlasLtMatmulPlan* plan,
    const HostOrDeviceScalar<std::complex<float>>& alpha,
    const DeviceMemory<std::complex<float>>& a,
    const DeviceMemory<std::complex<float>>& b,
    const HostOrDeviceScalar<std::complex<float>>& beta,
    DeviceMemory<std::complex<float>>* c, ScratchAllocator* scratch_allocator,
    const blas::IBlasLtMatmulAlgorithm* algorithm,
    blas::ProfileResult* output_profile_result) {
  VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
            PARAM(c), PARAM(algorithm));

  using Impl = ThenBlasWithProfileImpl<
      const blas::IBlasLtMatmulPlan*,
      const HostOrDeviceScalar<std::complex<float>>&,
      const DeviceMemory<std::complex<float>>&,
      const DeviceMemory<std::complex<float>>&,
      const HostOrDeviceScalar<std::complex<float>>&,
      DeviceMemory<std::complex<float>>*, ScratchAllocator*,
      const blas::IBlasLtMatmulAlgorithm*>;
  Impl run;
  return run(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
             c, scratch_allocator, algorithm, output_profile_result);
}
|
||||
|
||||
// Enqueues a complex<double> BlasLt matmul on this stream. See
// BlasSupport::DoBlasLtMatmul.
Stream& Stream::ThenBlasLtMatmul(
    const blas::IBlasLtMatmulPlan* plan,
    const HostOrDeviceScalar<std::complex<double>>& alpha,
    const DeviceMemory<std::complex<double>>& a,
    const DeviceMemory<std::complex<double>>& b,
    const HostOrDeviceScalar<std::complex<double>>& beta,
    DeviceMemory<std::complex<double>>* c, ScratchAllocator* scratch_allocator,
    const blas::IBlasLtMatmulAlgorithm* algorithm,
    blas::ProfileResult* output_profile_result) {
  VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
            PARAM(c), PARAM(algorithm));

  using Impl = ThenBlasWithProfileImpl<
      const blas::IBlasLtMatmulPlan*,
      const HostOrDeviceScalar<std::complex<double>>&,
      const DeviceMemory<std::complex<double>>&,
      const DeviceMemory<std::complex<double>>&,
      const HostOrDeviceScalar<std::complex<double>>&,
      DeviceMemory<std::complex<double>>*, ScratchAllocator*,
      const blas::IBlasLtMatmulAlgorithm*>;
  Impl run;
  return run(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
             c, scratch_allocator, algorithm, output_profile_result);
}
|
||||
|
||||
Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
|
||||
VLOG_CALL(PARAM(seed), PARAM(seed_bytes));
|
||||
|
||||
|
@ -1665,6 +1665,56 @@ class Stream {
|
||||
const DeviceMemory<std::complex<double>> &a, int lda,
|
||||
DeviceMemory<std::complex<double>> *b, int ldb);
|
||||
|
||||
// See BlasSupport::DoBlasLtMatmul.
|
||||
Stream& ThenBlasLtMatmul(
|
||||
const blas::IBlasLtMatmulPlan* plan,
|
||||
const HostOrDeviceScalar<int32>& alpha, const DeviceMemory<int8>& a,
|
||||
const DeviceMemory<int8>& b, const HostOrDeviceScalar<int32>& beta,
|
||||
DeviceMemory<int32>* c, ScratchAllocator* scratch_allocator,
|
||||
const blas::IBlasLtMatmulAlgorithm* algorithm,
|
||||
blas::ProfileResult* output_profile_result = nullptr);
|
||||
Stream& ThenBlasLtMatmul(
|
||||
const blas::IBlasLtMatmulPlan* plan,
|
||||
const HostOrDeviceScalar<Eigen::half>& alpha,
|
||||
const DeviceMemory<Eigen::half>& a, const DeviceMemory<Eigen::half>& b,
|
||||
const HostOrDeviceScalar<Eigen::half>& beta, DeviceMemory<Eigen::half>* c,
|
||||
ScratchAllocator* scratch_allocator,
|
||||
const blas::IBlasLtMatmulAlgorithm* algorithm,
|
||||
blas::ProfileResult* output_profile_result = nullptr);
|
||||
Stream& ThenBlasLtMatmul(
|
||||
const blas::IBlasLtMatmulPlan* plan,
|
||||
const HostOrDeviceScalar<float>& alpha, const DeviceMemory<float>& a,
|
||||
const DeviceMemory<float>& b, const HostOrDeviceScalar<float>& beta,
|
||||
DeviceMemory<float>* c, ScratchAllocator* scratch_allocator,
|
||||
const blas::IBlasLtMatmulAlgorithm* algorithm,
|
||||
blas::ProfileResult* output_profile_result = nullptr);
|
||||
Stream& ThenBlasLtMatmul(
|
||||
const blas::IBlasLtMatmulPlan* plan,
|
||||
const HostOrDeviceScalar<double>& alpha, const DeviceMemory<double>& a,
|
||||
const DeviceMemory<double>& b, const HostOrDeviceScalar<double>& beta,
|
||||
DeviceMemory<double>* c, ScratchAllocator* scratch_allocator,
|
||||
const blas::IBlasLtMatmulAlgorithm* algorithm,
|
||||
blas::ProfileResult* output_profile_result = nullptr);
|
||||
Stream& ThenBlasLtMatmul(
|
||||
const blas::IBlasLtMatmulPlan* plan,
|
||||
const HostOrDeviceScalar<std::complex<float>>& alpha,
|
||||
const DeviceMemory<std::complex<float>>& a,
|
||||
const DeviceMemory<std::complex<float>>& b,
|
||||
const HostOrDeviceScalar<std::complex<float>>& beta,
|
||||
DeviceMemory<std::complex<float>>* c, ScratchAllocator* scratch_allocator,
|
||||
const blas::IBlasLtMatmulAlgorithm* algorithm,
|
||||
blas::ProfileResult* output_profile_result = nullptr);
|
||||
Stream& ThenBlasLtMatmul(
|
||||
const blas::IBlasLtMatmulPlan* plan,
|
||||
const HostOrDeviceScalar<std::complex<double>>& alpha,
|
||||
const DeviceMemory<std::complex<double>>& a,
|
||||
const DeviceMemory<std::complex<double>>& b,
|
||||
const HostOrDeviceScalar<std::complex<double>>& beta,
|
||||
DeviceMemory<std::complex<double>>* c,
|
||||
ScratchAllocator* scratch_allocator,
|
||||
const blas::IBlasLtMatmulAlgorithm* algorithm,
|
||||
blas::ProfileResult* output_profile_result = nullptr);
|
||||
|
||||
// See FftSupport::DoFft.
|
||||
Stream &ThenFft(fft::Plan *plan,
|
||||
const DeviceMemory<std::complex<float>> &input,
|
||||
|
@ -336,6 +336,49 @@ bool StreamExecutor::GetBlasGemmAlgorithms(
|
||||
return blas_support->GetBlasGemmAlgorithms(out_algorithms);
|
||||
}
|
||||
|
||||
// Builds a backend-specific BlasLt matmul plan for later use with
// DoBlasLtMatmul. Returns nullptr when this executor has no BLAS support.
std::unique_ptr<blas::IBlasLtMatmulPlan> StreamExecutor::CreateBlasLtMatmulPlan(
    blas::DataType ab_type, blas::DataType cd_type,
    blas::ComputationType computation_type, blas::PointerMode pointer_mode,
    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
    uint64 k, int64 lda, int64 ldb, int64 ldc) {
  blas::BlasSupport *blas = AsBlas();
  return blas != nullptr
             ? blas->CreateBlasLtMatmulPlan(ab_type, cd_type, computation_type,
                                            pointer_mode, transa, transb, m, n,
                                            k, lda, ldb, ldc)
             : nullptr;
}
|
||||
|
||||
std::unique_ptr<blas::IBlasLtMatmulPlan>
|
||||
StreamExecutor::CreateBlasLtMatmulPlanStridedBatched(
|
||||
blas::DataType ab_type, blas::DataType cd_type,
|
||||
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
|
||||
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
|
||||
uint64 k, uint64 batch_count, int64 lda, int64 stride_a, int64 ldb,
|
||||
int64 stride_b, int64 ldc, int64 stride_c) {
|
||||
blas::BlasSupport *blas_support = AsBlas();
|
||||
if (!blas_support) {
|
||||
return nullptr;
|
||||
}
|
||||
return blas_support->CreateBlasLtMatmulPlanStridedBatched(
|
||||
ab_type, cd_type, computation_type, pointer_mode, transa, transb, m, n, k,
|
||||
batch_count, lda, stride_a, ldb, stride_b, ldc, stride_c);
|
||||
}
|
||||
|
||||
bool StreamExecutor::GetBlasLtMatmulAlgorithms(
|
||||
const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size,
|
||||
int max_algorithm_count,
|
||||
std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>*
|
||||
out_algorithms) {
|
||||
blas::BlasSupport *blas_support = AsBlas();
|
||||
if (!blas_support) {
|
||||
return false;
|
||||
}
|
||||
return blas_support->GetBlasLtMatmulAlgorithms(
|
||||
plan, max_workspace_size, max_algorithm_count, out_algorithms);
|
||||
}
|
||||
|
||||
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
|
||||
StreamExecutor::createRnnDescriptor(
|
||||
int num_layers, int hidden_size, int input_size, int cell_size,
|
||||
|
@ -394,6 +394,35 @@ class StreamExecutor {
|
||||
// Get the list of supported algorithms for BLAS gemm.
|
||||
bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms);
|
||||
|
||||
// Creates a backend-specific plan object for a blaslt matmul operation, which
|
||||
// can then be passed to DoBlasLtMatmul(). When possible, plans should be
|
||||
// created once and reused for multiple calls to DoBlasLtMatmul().
|
||||
// Returns a null pointer on failure.
|
||||
std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
|
||||
blas::DataType ab_type, blas::DataType cd_type,
|
||||
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
|
||||
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
|
||||
uint64 k, int64 lda, int64 ldb, int64 ldc);
|
||||
|
||||
// A more general version of CreateBlasLtMatmulPlan supporting
|
||||
// batched operations.
|
||||
std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlanStridedBatched(
|
||||
blas::DataType ab_type, blas::DataType cd_type,
|
||||
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
|
||||
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
|
||||
uint64 k, uint64 batch_count, int64 lda, int64 stride_a, int64 ldb,
|
||||
int64 stride_b, int64 ldc, int64 stride_c);
|
||||
|
||||
// Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
|
||||
// returned in the order of increasing estimated compute time according to an
|
||||
// internal heuristic. The first returned algorithm can be used as the default
|
||||
// algorithm if no autotuning is to be performed.
|
||||
bool GetBlasLtMatmulAlgorithms(
|
||||
const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size,
|
||||
int max_algorithm_count,
|
||||
std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>*
|
||||
out_algorithms);
|
||||
|
||||
// Create an RNN descriptor based on model shapes and configurations.
|
||||
// The caller retains the ownership of the descriptor.
|
||||
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
|
||||
|
8
third_party/gpus/cuda/BUILD.tpl
vendored
8
third_party/gpus/cuda/BUILD.tpl
vendored
@ -127,6 +127,13 @@ cc_library(
|
||||
linkstatic = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cublasLt",
|
||||
srcs = ["cuda/lib/%{cublasLt_lib}"],
|
||||
data = ["cuda/lib/%{cublasLt_lib}"],
|
||||
linkstatic = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cusolver",
|
||||
srcs = ["cuda/lib/%{cusolver_lib}"],
|
||||
@ -168,6 +175,7 @@ cc_library(
|
||||
name = "cuda",
|
||||
deps = [
|
||||
":cublas",
|
||||
":cublasLt",
|
||||
":cuda_headers",
|
||||
":cudart",
|
||||
":cudnn",
|
||||
|
12
third_party/gpus/cuda_configure.bzl
vendored
12
third_party/gpus/cuda_configure.bzl
vendored
@ -551,6 +551,13 @@ def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config):
|
||||
cuda_config.cublas_version,
|
||||
static = False,
|
||||
),
|
||||
"cublasLt": _check_cuda_lib_params(
|
||||
"cublasLt",
|
||||
cpu_value,
|
||||
cuda_config.config["cublas_library_dir"],
|
||||
cuda_config.cublas_version,
|
||||
static = False,
|
||||
),
|
||||
"cusolver": _check_cuda_lib_params(
|
||||
"cusolver",
|
||||
cpu_value,
|
||||
@ -771,6 +778,7 @@ def _create_dummy_repository(repository_ctx):
|
||||
"%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value),
|
||||
"%{cudart_lib}": lib_name("cudart", cpu_value),
|
||||
"%{cublas_lib}": lib_name("cublas", cpu_value),
|
||||
"%{cublasLt_lib}": lib_name("cublasLt", cpu_value),
|
||||
"%{cusolver_lib}": lib_name("cusolver", cpu_value),
|
||||
"%{cudnn_lib}": lib_name("cudnn", cpu_value),
|
||||
"%{cufft_lib}": lib_name("cufft", cpu_value),
|
||||
@ -802,6 +810,7 @@ filegroup(name="cudnn-include")
|
||||
"cuda/cuda/lib/%s" % lib_name("cudart_static", cpu_value),
|
||||
)
|
||||
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublas", cpu_value))
|
||||
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublasLt", cpu_value))
|
||||
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusolver", cpu_value))
|
||||
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudnn", cpu_value))
|
||||
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("curand", cpu_value))
|
||||
@ -992,11 +1001,13 @@ def _create_local_cuda_repository(repository_ctx):
|
||||
cublas_include_path + "/cublas.h",
|
||||
cublas_include_path + "/cublas_v2.h",
|
||||
cublas_include_path + "/cublas_api.h",
|
||||
cublas_include_path + "/cublasLt.h",
|
||||
],
|
||||
outs = [
|
||||
"cublas/include/cublas.h",
|
||||
"cublas/include/cublas_v2.h",
|
||||
"cublas/include/cublas_api.h",
|
||||
"cublas/include/cublasLt.h",
|
||||
],
|
||||
))
|
||||
|
||||
@ -1137,6 +1148,7 @@ def _create_local_cuda_repository(repository_ctx):
|
||||
"%{cudart_static_linkopt}": _cudart_static_linkopt(cuda_config.cpu_value),
|
||||
"%{cudart_lib}": _basename(repository_ctx, cuda_libs["cudart"]),
|
||||
"%{cublas_lib}": _basename(repository_ctx, cuda_libs["cublas"]),
|
||||
"%{cublasLt_lib}": _basename(repository_ctx, cuda_libs["cublasLt"]),
|
||||
"%{cusolver_lib}": _basename(repository_ctx, cuda_libs["cusolver"]),
|
||||
"%{cudnn_lib}": _basename(repository_ctx, cuda_libs["cudnn"]),
|
||||
"%{cufft_lib}": _basename(repository_ctx, cuda_libs["cufft"]),
|
||||
|
Loading…
Reference in New Issue
Block a user