From 0d71ed6157e2922fa88d40195a6104b6d5b0ebe4 Mon Sep 17 00:00:00 2001 From: Ben Barsdell Date: Tue, 20 Oct 2020 11:53:30 +1100 Subject: [PATCH] Small refactor of cublasLt wrapper code - Makes the CUDABlasLtMatmulPlan class less verbose and more flexible. - Makes no functional change. --- tensorflow/stream_executor/cuda/cuda_blas.cc | 140 +++++++------------ 1 file changed, 53 insertions(+), 87 deletions(-) diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc index 7d606d44ec3..8b9d9174f22 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.cc +++ b/tensorflow/stream_executor/cuda/cuda_blas.cc @@ -3257,46 +3257,43 @@ blas::ComputationType ToComputationType>() { class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan { public: - CUDABlasLtMatmulPlan(UniqueOpDesc op_desc, UniqueLayoutDesc a_desc, - UniqueLayoutDesc b_desc, UniqueLayoutDesc c_desc, - UniqueLayoutDesc d_desc, blas::DataType ab_type, - blas::DataType c_type, blas::DataType scale_type, - blas::PointerMode pointer_mode, blas::Epilogue epilogue, - int batch_count, int64 stride_a, int64 stride_b, - int64 stride_c, int64 stride_d) - : op_desc_(std::move(op_desc)), - a_desc_(std::move(a_desc)), - b_desc_(std::move(b_desc)), - c_desc_(std::move(c_desc)), - d_desc_(std::move(d_desc)), - ab_type_(ab_type), - c_type_(c_type), - scale_type_(scale_type), - pointer_mode_(pointer_mode), - epilogue_(epilogue), - batch_count_(batch_count), - stride_a_(stride_a), - stride_b_(stride_b), - stride_c_(stride_c), - stride_d_(stride_d) {} + port::Status init(const blas::BlasLtMatmulPlanParams& p) { + SE_ASSIGN_OR_RETURN( + op_desc_, + CreateCublasLtOperationDesc( + p.computation_type, GetScaleType(p.c_type, p.computation_type), + p.pointer_mode, p.epilogue, p.transa, p.transb)); + uint64 rows_a = p.transa == blas::Transpose::kNoTranspose ? p.m : p.k; + uint64 cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m; + uint64 rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n; + uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k; + SE_ASSIGN_OR_RETURN( + a_desc_, CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda, + p.stride_a, p.batch_count)); + SE_ASSIGN_OR_RETURN( + b_desc_, CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb, + p.stride_b, p.batch_count)); + SE_ASSIGN_OR_RETURN( + c_desc_, CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c, + p.batch_count)); + SE_ASSIGN_OR_RETURN( + d_desc_, CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c, + p.batch_count)); + params_ = p; + scale_type_ = GetScaleType(p.c_type, p.computation_type); + return port::Status::OK(); + } cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); } cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); } cublasLtMatrixLayout_t b_desc() const { return b_desc_.get(); } cublasLtMatrixLayout_t c_desc() const { return c_desc_.get(); } cublasLtMatrixLayout_t d_desc() const { return d_desc_.get(); } - bool ok() { return op_desc_ && a_desc_ && b_desc_ && c_desc_ && d_desc_; } - blas::DataType ab_type() const override { return ab_type_; } - blas::DataType c_type() const override { return c_type_; } + const blas::BlasLtMatmulPlanParams& params() const { return params_; } blas::DataType scale_type() const { return scale_type_; } - blas::PointerMode pointer_mode() const { return pointer_mode_; } - blas::Epilogue epilogue() const { return epilogue_; } - int batch_count() const { return batch_count_; } - int64 stride_a() const { return stride_a_; } - int64 stride_b() const { return stride_b_; } - int64 stride_c() const { return stride_c_; } - int64 stride_d() const { return stride_d_; } + blas::DataType ab_type() const override { return params_.ab_type; } + blas::DataType c_type() const override { return params_.c_type; } // Note: Must be const to satisfy API. This is always called before the plan // is executed, so the state change is not observed in subsequent executions. @@ -3308,16 +3305,8 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan { UniqueLayoutDesc b_desc_; UniqueLayoutDesc c_desc_; UniqueLayoutDesc d_desc_; - blas::DataType ab_type_; - blas::DataType c_type_; + blas::BlasLtMatmulPlanParams params_; blas::DataType scale_type_; - blas::PointerMode pointer_mode_; - blas::Epilogue epilogue_; - int batch_count_; - int64 stride_a_; - int64 stride_b_; - int64 stride_c_; - int64 stride_d_; }; bool CUDABlasLtMatmulPlan::SetBiasPointer(const void *bias) const { @@ -3365,7 +3354,7 @@ port::StatusOr CreateCublasLtMatmulPreference( max_workspace_bytes)); const auto &cuda_plan = *static_cast(plan); - if (cuda_plan.batch_count() == 0) { + if (cuda_plan.params().batch_count == 0) { return unique_preference; } // This is a workaround for a known issue in cuBlasLt where the heuristic may @@ -3374,27 +3363,29 @@ port::StatusOr CreateCublasLtMatmulPreference( auto get_alignment_bytes = [](int64 stride, blas::DataType dtype) { return (stride & -stride) * GetDataTypeSizeBytes(dtype); }; - if (cuda_plan.stride_a()) { - SE_RETURN_IF_ERROR( - SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, - (uint32)get_alignment_bytes(cuda_plan.stride_a(), - cuda_plan.ab_type()))); + if (cuda_plan.params().stride_a) { + SE_RETURN_IF_ERROR(SetCublasLtAttr( + preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, + (uint32)get_alignment_bytes(cuda_plan.params().stride_a, + cuda_plan.params().ab_type))); } - if (cuda_plan.stride_b()) { - SE_RETURN_IF_ERROR( - SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, - (uint32)get_alignment_bytes(cuda_plan.stride_b(), - cuda_plan.ab_type()))); + if (cuda_plan.params().stride_b) { + SE_RETURN_IF_ERROR(SetCublasLtAttr( + preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, + (uint32)get_alignment_bytes(cuda_plan.params().stride_b, + cuda_plan.params().ab_type))); } - if (cuda_plan.stride_c()) { + if (cuda_plan.params().stride_c) { SE_RETURN_IF_ERROR(SetCublasLtAttr( preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, - (uint32)get_alignment_bytes(cuda_plan.stride_c(), cuda_plan.c_type()))); + (uint32)get_alignment_bytes(cuda_plan.params().stride_c, + cuda_plan.params().c_type))); } - if (cuda_plan.stride_d()) { + if (cuda_plan.params().stride_c) { SE_RETURN_IF_ERROR(SetCublasLtAttr( preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, - (uint32)get_alignment_bytes(cuda_plan.stride_d(), cuda_plan.c_type()))); + (uint32)get_alignment_bytes(cuda_plan.params().stride_c, + cuda_plan.params().c_type))); } return unique_preference; } @@ -3406,35 +3397,10 @@ port::StatusOr CreateCublasLtMatmulPreference( port::StatusOr> CUDABlas::CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &p) { #if CUDA_VERSION >= 11000 - SE_ASSIGN_OR_RETURN( - auto op_desc, - CreateCublasLtOperationDesc( - p.computation_type, GetScaleType(p.c_type, p.computation_type), - p.pointer_mode, p.epilogue, p.transa, p.transb)); - uint64 rows_a = p.transa == blas::Transpose::kNoTranspose ? p.m : p.k; - uint64 cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m; - uint64 rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n; - uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k; - SE_ASSIGN_OR_RETURN(auto a_desc, - CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda, - p.stride_a, p.batch_count)); - SE_ASSIGN_OR_RETURN(auto b_desc, - CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb, - p.stride_b, p.batch_count)); - SE_ASSIGN_OR_RETURN(auto c_desc, - CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, - p.stride_c, p.batch_count)); - SE_ASSIGN_OR_RETURN(auto d_desc, - CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, - p.stride_c, p.batch_count)); - blas::DataType scale_type = GetScaleType(p.c_type, p.computation_type); - + auto cuda_plan = std::make_unique(); + SE_RETURN_IF_ERROR(cuda_plan->init(p)); return static_cast>( - std::make_unique( - std::move(op_desc), std::move(a_desc), std::move(b_desc), - std::move(c_desc), std::move(d_desc), p.ab_type, p.c_type, scale_type, - p.pointer_mode, p.epilogue, p.batch_count, p.stride_a, p.stride_b, - p.stride_c, p.stride_c)); + std::move(cuda_plan)); #else return port::Status( port::error::UNIMPLEMENTED, @@ -3514,14 +3480,14 @@ bool CUDABlas::DoBlasLtMatmulInternal( return false; } bool is_pointer_mode_host = !alpha.is_pointer(); - if ((cuda_plan.pointer_mode() == blas::PointerMode::kHost) != + if ((cuda_plan.params().pointer_mode == blas::PointerMode::kHost) != is_pointer_mode_host) { VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong " "pointer_mode for the given alpha/beta."; return false; } - if ((cuda_plan.epilogue() == blas::Epilogue::kBias || - cuda_plan.epilogue() == blas::Epilogue::kBiasThenReLU) != + if ((cuda_plan.params().epilogue == blas::Epilogue::kBias || + cuda_plan.params().epilogue == blas::Epilogue::kBiasThenReLU) != (bias != nullptr)) { VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong " "epilogue for the given bias pointer.";