Add workaround for cublasLt known issue
- Avoids a heuristic alignment issue noted in the CUDA Release Notes.
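The alignment values come from the lowest set bit of each stride: (stride & -stride) is the largest power of two dividing the stride, so multiplying it by the element size gives a byte alignment that every batch pointer is guaranteed to satisfy. A minimal standalone sketch of that calculation (illustrative names, not part of this commit; the element size stands in for GetDataTypeSizeBytes):

#include <cstdint>
#include <cstdio>

// Mirrors the get_alignment_bytes lambda added in this commit.
int64_t GetAlignmentBytes(int64_t stride, int64_t elem_size_bytes) {
  // stride & -stride keeps only the lowest set bit, i.e. the largest
  // power of two that divides the stride.
  return (stride & -stride) * elem_size_bytes;
}

int main() {
  // Batched fp32 matrices with a stride of 12 elements: the largest power of
  // two dividing 12 is 4, so batch pointers are only guaranteed to be
  // 4 * 4 = 16-byte aligned.
  std::printf("%lld\n", static_cast<long long>(GetAlignmentBytes(12, 4)));
  return 0;
}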
parent f7d29c94df
commit 8c0eb4b35b
@@ -488,6 +488,26 @@ cudaDataType_t GetCUDADataType(blas::DataType ty) {
       return CUDA_C_64F;
   }
 }
+
+int GetDataTypeSizeBytes(blas::DataType ty) {
+  switch (ty) {
+    case blas::DataType::kF16:
+      return 2;
+    case blas::DataType::kF32:
+      return 4;
+    case blas::DataType::kF64:
+      return 8;
+    case blas::DataType::kI8:
+      return 1;
+    case blas::DataType::kI32:
+      return 4;
+    case blas::DataType::kComplexF32:
+      return 8;
+    case blas::DataType::kComplexF64:
+      return 16;
+  }
+}
+
 }  // namespace
 
 template <typename FuncT, typename... Args>
@@ -3161,22 +3181,6 @@ UniqueLayoutDesc CreateCublasLtLayoutDesc(blas::DataType data_type, uint64 rows,
   return unique_desc;
 }
 
-UniqueMatmulPreference CreateCublasLtMatmulPreference(
-    size_t max_workspace_bytes) {
-  cublasLtMatmulPreference_t preference;
-  cublasStatus_t status = cublasLtMatmulPreferenceCreate(&preference);
-  if (status != CUBLAS_STATUS_SUCCESS) {
-    VLOG(2) << "cublasLtMatmulPreferenceCreate failed: " << ToString(status);
-    return nullptr;
-  }
-  UniqueMatmulPreference unique_preference(preference);
-  if (!SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
-                       max_workspace_bytes)) {
-    return nullptr;
-  }
-  return unique_preference;
-}
-
 // Helper function to allocate workspace.
 port::Status AllocateWorkspace(void** workspace,
                                ScratchAllocator* scratch_allocator,
@@ -3230,6 +3234,11 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
   blas::DataType cd_type() const { return cd_type_; }
   blas::DataType scale_type() const { return scale_type_; }
   blas::PointerMode pointer_mode() const { return pointer_mode_; }
+  int batch_count() const { return batch_count_; }
+  int64 stride_a() const { return stride_a_; }
+  int64 stride_b() const { return stride_b_; }
+  int64 stride_c() const { return stride_c_; }
+  int64 stride_d() const { return stride_d_; }
 
  private:
   UniqueOpDesc op_desc_;
@@ -3241,6 +3250,11 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
   blas::DataType cd_type_;
   blas::DataType scale_type_;
   blas::PointerMode pointer_mode_;
+  int batch_count_;
+  int64 stride_a_;
+  int64 stride_b_;
+  int64 stride_c_;
+  int64 stride_d_;
 };
 
 CUDABlasLtMatmulPlan::CUDABlasLtMatmulPlan(
@@ -3261,7 +3275,12 @@ CUDABlasLtMatmulPlan::CUDABlasLtMatmulPlan(
       ab_type_(ab_type),
       cd_type_(cd_type),
       scale_type_(GetScaleType(cd_type, computation_type)),
-      pointer_mode_(pointer_mode) {
+      pointer_mode_(pointer_mode),
+      batch_count_(batch_count),
+      stride_a_(stride_a),
+      stride_b_(stride_b),
+      stride_c_(stride_c),
+      stride_d_(stride_d) {
   uint64 rows_a = transa == blas::Transpose::kNoTranspose ? m : k;
   uint64 cols_a = transa == blas::Transpose::kNoTranspose ? k : m;
   uint64 rows_b = transb == blas::Transpose::kNoTranspose ? k : n;
@@ -3296,6 +3315,53 @@ class CUDABlasLtMatmulAlgorithm final : public blas::IBlasLtMatmulAlgorithm {
   size_t workspace_size_;
 };
 
+UniqueMatmulPreference CreateCublasLtMatmulPreference(
+    const blas::IBlasLtMatmulPlan* plan,
+    size_t max_workspace_bytes) {
+  cublasLtMatmulPreference_t preference;
+  cublasStatus_t status = cublasLtMatmulPreferenceCreate(&preference);
+  if (status != CUBLAS_STATUS_SUCCESS) {
+    VLOG(2) << "cublasLtMatmulPreferenceCreate failed: " << ToString(status);
+    return nullptr;
+  }
+  UniqueMatmulPreference unique_preference(preference);
+  if (!SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
+                       max_workspace_bytes)) {
+    return nullptr;
+  }
+
+  const auto& cuda_plan = *static_cast<const CUDABlasLtMatmulPlan*>(plan);
+  if (cuda_plan.batch_count() == 0) {
+    return unique_preference;
+  }
+  // This is a workaround for a known issue in cuBlasLt where the heuristic may
+  // in rare cases select an algo that does not support the specified stride.
+  // Specifying the alignment requirements manually like this avoids the issue.
+  auto get_alignment_bytes = [](int64 stride, blas::DataType dtype) {
+    return (stride & -stride) * GetDataTypeSizeBytes(dtype);
+  };
+  if ((cuda_plan.stride_a() &&
+       !SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES,
+                        (uint32)get_alignment_bytes(cuda_plan.stride_a(),
+                                                    cuda_plan.ab_type()))) ||
+      (cuda_plan.stride_b() &&
+       !SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES,
+                        (uint32)get_alignment_bytes(cuda_plan.stride_b(),
+                                                    cuda_plan.ab_type()))) ||
+      (cuda_plan.stride_c() &&
+       !SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES,
+                        (uint32)get_alignment_bytes(cuda_plan.stride_c(),
+                                                    cuda_plan.cd_type()))) ||
+      (cuda_plan.stride_d() &&
+       !SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES,
+                        (uint32)get_alignment_bytes(cuda_plan.stride_d(),
+                                                    cuda_plan.cd_type())))) {
+    return nullptr;
+  }
+
+  return unique_preference;
+}
+
 }  // namespace
 
 #endif  // CUDA_VERSION >= 11000
@@ -3327,7 +3393,7 @@ bool CUDABlas::GetBlasLtMatmulAlgorithms(
     out_algorithms) {
 #if CUDA_VERSION >= 11000
   UniqueMatmulPreference preference =
-      CreateCublasLtMatmulPreference(max_workspace_size);
+      CreateCublasLtMatmulPreference(plan, max_workspace_size);
   if (!preference) return false;
 
   std::vector<cublasLtMatmulHeuristicResult_t> results(max_algorithm_count);