Small refactor of cublasLt wrapper code

- Makes the CUDABlasLtMatmulPlan class less verbose and more flexible.
- Makes no functional change.
This commit is contained in:
Ben Barsdell 2020-10-20 11:53:30 +11:00
parent dc82710fec
commit 0d71ed6157

View File

@@ -3257,46 +3257,43 @@ blas::ComputationType ToComputationType<std::complex<double>>() {
class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
public:
CUDABlasLtMatmulPlan(UniqueOpDesc op_desc, UniqueLayoutDesc a_desc,
UniqueLayoutDesc b_desc, UniqueLayoutDesc c_desc,
UniqueLayoutDesc d_desc, blas::DataType ab_type,
blas::DataType c_type, blas::DataType scale_type,
blas::PointerMode pointer_mode, blas::Epilogue epilogue,
int batch_count, int64 stride_a, int64 stride_b,
int64 stride_c, int64 stride_d)
: op_desc_(std::move(op_desc)),
a_desc_(std::move(a_desc)),
b_desc_(std::move(b_desc)),
c_desc_(std::move(c_desc)),
d_desc_(std::move(d_desc)),
ab_type_(ab_type),
c_type_(c_type),
scale_type_(scale_type),
pointer_mode_(pointer_mode),
epilogue_(epilogue),
batch_count_(batch_count),
stride_a_(stride_a),
stride_b_(stride_b),
stride_c_(stride_c),
stride_d_(stride_d) {}
port::Status init(const blas::BlasLtMatmulPlanParams& p) {
SE_ASSIGN_OR_RETURN(
op_desc_,
CreateCublasLtOperationDesc(
p.computation_type, GetScaleType(p.c_type, p.computation_type),
p.pointer_mode, p.epilogue, p.transa, p.transb));
uint64 rows_a = p.transa == blas::Transpose::kNoTranspose ? p.m : p.k;
uint64 cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m;
uint64 rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n;
uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k;
SE_ASSIGN_OR_RETURN(
a_desc_, CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda,
p.stride_a, p.batch_count));
SE_ASSIGN_OR_RETURN(
b_desc_, CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb,
p.stride_b, p.batch_count));
SE_ASSIGN_OR_RETURN(
c_desc_, CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
p.batch_count));
SE_ASSIGN_OR_RETURN(
d_desc_, CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
p.batch_count));
params_ = p;
scale_type_ = GetScaleType(p.c_type, p.computation_type);
return port::Status::OK();
}
cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); }
cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); }
cublasLtMatrixLayout_t b_desc() const { return b_desc_.get(); }
cublasLtMatrixLayout_t c_desc() const { return c_desc_.get(); }
cublasLtMatrixLayout_t d_desc() const { return d_desc_.get(); }
bool ok() { return op_desc_ && a_desc_ && b_desc_ && c_desc_ && d_desc_; }
blas::DataType ab_type() const override { return ab_type_; }
blas::DataType c_type() const override { return c_type_; }
const blas::BlasLtMatmulPlanParams& params() const { return params_; }
blas::DataType scale_type() const { return scale_type_; }
blas::PointerMode pointer_mode() const { return pointer_mode_; }
blas::Epilogue epilogue() const { return epilogue_; }
int batch_count() const { return batch_count_; }
int64 stride_a() const { return stride_a_; }
int64 stride_b() const { return stride_b_; }
int64 stride_c() const { return stride_c_; }
int64 stride_d() const { return stride_d_; }
blas::DataType ab_type() const override { return params_.ab_type; }
blas::DataType c_type() const override { return params_.c_type; }
// Note: Must be const to satisfy API. This is always called before the plan
// is executed, so the state change is not observed in subsequent executions.
@@ -3308,16 +3305,8 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
UniqueLayoutDesc b_desc_;
UniqueLayoutDesc c_desc_;
UniqueLayoutDesc d_desc_;
blas::DataType ab_type_;
blas::DataType c_type_;
blas::BlasLtMatmulPlanParams params_;
blas::DataType scale_type_;
blas::PointerMode pointer_mode_;
blas::Epilogue epilogue_;
int batch_count_;
int64 stride_a_;
int64 stride_b_;
int64 stride_c_;
int64 stride_d_;
};
bool CUDABlasLtMatmulPlan::SetBiasPointer(const void *bias) const {
@@ -3365,7 +3354,7 @@ port::StatusOr<UniqueMatmulPreference> CreateCublasLtMatmulPreference(
max_workspace_bytes));
const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
if (cuda_plan.batch_count() == 0) {
if (cuda_plan.params().batch_count == 0) {
return unique_preference;
}
// This is a workaround for a known issue in cuBlasLt where the heuristic may
@@ -3374,27 +3363,29 @@ port::StatusOr<UniqueMatmulPreference> CreateCublasLtMatmulPreference(
auto get_alignment_bytes = [](int64 stride, blas::DataType dtype) {
return (stride & -stride) * GetDataTypeSizeBytes(dtype);
};
if (cuda_plan.stride_a()) {
SE_RETURN_IF_ERROR(
SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES,
(uint32)get_alignment_bytes(cuda_plan.stride_a(),
cuda_plan.ab_type())));
if (cuda_plan.params().stride_a) {
SE_RETURN_IF_ERROR(SetCublasLtAttr(
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES,
(uint32)get_alignment_bytes(cuda_plan.params().stride_a,
cuda_plan.params().ab_type)));
}
if (cuda_plan.stride_b()) {
SE_RETURN_IF_ERROR(
SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES,
(uint32)get_alignment_bytes(cuda_plan.stride_b(),
cuda_plan.ab_type())));
if (cuda_plan.params().stride_b) {
SE_RETURN_IF_ERROR(SetCublasLtAttr(
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES,
(uint32)get_alignment_bytes(cuda_plan.params().stride_b,
cuda_plan.params().ab_type)));
}
if (cuda_plan.stride_c()) {
if (cuda_plan.params().stride_c) {
SE_RETURN_IF_ERROR(SetCublasLtAttr(
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES,
(uint32)get_alignment_bytes(cuda_plan.stride_c(), cuda_plan.c_type())));
(uint32)get_alignment_bytes(cuda_plan.params().stride_c,
cuda_plan.params().c_type)));
}
if (cuda_plan.stride_d()) {
if (cuda_plan.params().stride_c) {
SE_RETURN_IF_ERROR(SetCublasLtAttr(
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES,
(uint32)get_alignment_bytes(cuda_plan.stride_d(), cuda_plan.c_type())));
(uint32)get_alignment_bytes(cuda_plan.params().stride_c,
cuda_plan.params().c_type)));
}
return unique_preference;
}
@@ -3406,35 +3397,10 @@ port::StatusOr<UniqueMatmulPreference> CreateCublasLtMatmulPreference(
port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
CUDABlas::CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &p) {
#if CUDA_VERSION >= 11000
SE_ASSIGN_OR_RETURN(
auto op_desc,
CreateCublasLtOperationDesc(
p.computation_type, GetScaleType(p.c_type, p.computation_type),
p.pointer_mode, p.epilogue, p.transa, p.transb));
uint64 rows_a = p.transa == blas::Transpose::kNoTranspose ? p.m : p.k;
uint64 cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m;
uint64 rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n;
uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k;
SE_ASSIGN_OR_RETURN(auto a_desc,
CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda,
p.stride_a, p.batch_count));
SE_ASSIGN_OR_RETURN(auto b_desc,
CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb,
p.stride_b, p.batch_count));
SE_ASSIGN_OR_RETURN(auto c_desc,
CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc,
p.stride_c, p.batch_count));
SE_ASSIGN_OR_RETURN(auto d_desc,
CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc,
p.stride_c, p.batch_count));
blas::DataType scale_type = GetScaleType(p.c_type, p.computation_type);
auto cuda_plan = std::make_unique<CUDABlasLtMatmulPlan>();
SE_RETURN_IF_ERROR(cuda_plan->init(p));
return static_cast<std::unique_ptr<blas::IBlasLtMatmulPlan>>(
std::make_unique<CUDABlasLtMatmulPlan>(
std::move(op_desc), std::move(a_desc), std::move(b_desc),
std::move(c_desc), std::move(d_desc), p.ab_type, p.c_type, scale_type,
p.pointer_mode, p.epilogue, p.batch_count, p.stride_a, p.stride_b,
p.stride_c, p.stride_c));
std::move(cuda_plan));
#else
return port::Status(
port::error::UNIMPLEMENTED,
@@ -3514,14 +3480,14 @@ bool CUDABlas::DoBlasLtMatmulInternal(
return false;
}
bool is_pointer_mode_host = !alpha.is_pointer();
if ((cuda_plan.pointer_mode() == blas::PointerMode::kHost) !=
if ((cuda_plan.params().pointer_mode == blas::PointerMode::kHost) !=
is_pointer_mode_host) {
VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong "
"pointer_mode for the given alpha/beta.";
return false;
}
if ((cuda_plan.epilogue() == blas::Epilogue::kBias ||
cuda_plan.epilogue() == blas::Epilogue::kBiasThenReLU) !=
if ((cuda_plan.params().epilogue == blas::Epilogue::kBias ||
cuda_plan.params().epilogue == blas::Epilogue::kBiasThenReLU) !=
(bias != nullptr)) {
VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong "
"epilogue for the given bias pointer.";