From 0d71ed6157e2922fa88d40195a6104b6d5b0ebe4 Mon Sep 17 00:00:00 2001
From: Ben Barsdell
Date: Tue, 20 Oct 2020 11:53:30 +1100
Subject: [PATCH 1/3] Small refactor of cublasLt wrapper code

- Makes the CUDABlasLtMatmulPlan class less verbose and more flexible.
- Makes no functional change.
---
 tensorflow/stream_executor/cuda/cuda_blas.cc | 140 +++++++------------
 1 file changed, 53 insertions(+), 87 deletions(-)

diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index 7d606d44ec3..8b9d9174f22 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -3257,46 +3257,43 @@ blas::ComputationType ToComputationType<std::complex<double>>() {
 
 class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
  public:
-  CUDABlasLtMatmulPlan(UniqueOpDesc op_desc, UniqueLayoutDesc a_desc,
-                       UniqueLayoutDesc b_desc, UniqueLayoutDesc c_desc,
-                       UniqueLayoutDesc d_desc, blas::DataType ab_type,
-                       blas::DataType c_type, blas::DataType scale_type,
-                       blas::PointerMode pointer_mode, blas::Epilogue epilogue,
-                       int batch_count, int64 stride_a, int64 stride_b,
-                       int64 stride_c, int64 stride_d)
-      : op_desc_(std::move(op_desc)),
-        a_desc_(std::move(a_desc)),
-        b_desc_(std::move(b_desc)),
-        c_desc_(std::move(c_desc)),
-        d_desc_(std::move(d_desc)),
-        ab_type_(ab_type),
-        c_type_(c_type),
-        scale_type_(scale_type),
-        pointer_mode_(pointer_mode),
-        epilogue_(epilogue),
-        batch_count_(batch_count),
-        stride_a_(stride_a),
-        stride_b_(stride_b),
-        stride_c_(stride_c),
-        stride_d_(stride_d) {}
+  port::Status init(const blas::BlasLtMatmulPlanParams& p) {
+    SE_ASSIGN_OR_RETURN(
+        op_desc_,
+        CreateCublasLtOperationDesc(
+            p.computation_type, GetScaleType(p.c_type, p.computation_type),
+            p.pointer_mode, p.epilogue, p.transa, p.transb));
+    uint64 rows_a = p.transa == blas::Transpose::kNoTranspose ? p.m : p.k;
+    uint64 cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m;
+    uint64 rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n;
+    uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k;
+    SE_ASSIGN_OR_RETURN(
+        a_desc_, CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda,
+                                          p.stride_a, p.batch_count));
+    SE_ASSIGN_OR_RETURN(
+        b_desc_, CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb,
+                                          p.stride_b, p.batch_count));
+    SE_ASSIGN_OR_RETURN(
+        c_desc_, CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
+                                          p.batch_count));
+    SE_ASSIGN_OR_RETURN(
+        d_desc_, CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
+                                          p.batch_count));
+    params_ = p;
+    scale_type_ = GetScaleType(p.c_type, p.computation_type);
+    return port::Status::OK();
+  }
 
   cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); }
   cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); }
   cublasLtMatrixLayout_t b_desc() const { return b_desc_.get(); }
   cublasLtMatrixLayout_t c_desc() const { return c_desc_.get(); }
   cublasLtMatrixLayout_t d_desc() const { return d_desc_.get(); }
-  bool ok() { return op_desc_ && a_desc_ && b_desc_ && c_desc_ && d_desc_; }
-  blas::DataType ab_type() const override { return ab_type_; }
-  blas::DataType c_type() const override { return c_type_; }
+  const blas::BlasLtMatmulPlanParams& params() const { return params_; }
   blas::DataType scale_type() const { return scale_type_; }
-  blas::PointerMode pointer_mode() const { return pointer_mode_; }
-  blas::Epilogue epilogue() const { return epilogue_; }
-  int batch_count() const { return batch_count_; }
-  int64 stride_a() const { return stride_a_; }
-  int64 stride_b() const { return stride_b_; }
-  int64 stride_c() const { return stride_c_; }
-  int64 stride_d() const { return stride_d_; }
+  blas::DataType ab_type() const override { return params_.ab_type; }
+  blas::DataType c_type() const override { return params_.c_type; }
 
   // Note: Must be const to satisfy API. This is always called before the plan
   // is executed, so the state change is not observed in subsequent executions.
@@ -3308,16 +3305,8 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
   UniqueLayoutDesc b_desc_;
   UniqueLayoutDesc c_desc_;
   UniqueLayoutDesc d_desc_;
-  blas::DataType ab_type_;
-  blas::DataType c_type_;
+  blas::BlasLtMatmulPlanParams params_;
   blas::DataType scale_type_;
-  blas::PointerMode pointer_mode_;
-  blas::Epilogue epilogue_;
-  int batch_count_;
-  int64 stride_a_;
-  int64 stride_b_;
-  int64 stride_c_;
-  int64 stride_d_;
 };
 
 bool CUDABlasLtMatmulPlan::SetBiasPointer(const void *bias) const {
@@ -3365,7 +3354,7 @@ port::StatusOr<UniqueMatmulPreference> CreateCublasLtMatmulPreference(
                                      max_workspace_bytes));
 
   const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
-  if (cuda_plan.batch_count() == 0) {
+  if (cuda_plan.params().batch_count == 0) {
     return unique_preference;
   }
   // This is a workaround for a known issue in cuBlasLt where the heuristic may
@@ -3374,27 +3363,29 @@ port::StatusOr<UniqueMatmulPreference> CreateCublasLtMatmulPreference(
   auto get_alignment_bytes = [](int64 stride, blas::DataType dtype) {
     return (stride & -stride) * GetDataTypeSizeBytes(dtype);
   };
-  if (cuda_plan.stride_a()) {
-    SE_RETURN_IF_ERROR(
-        SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES,
-                        (uint32)get_alignment_bytes(cuda_plan.stride_a(),
-                                                    cuda_plan.ab_type())));
+  if (cuda_plan.params().stride_a) {
+    SE_RETURN_IF_ERROR(SetCublasLtAttr(
+        preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES,
+        (uint32)get_alignment_bytes(cuda_plan.params().stride_a,
+                                    cuda_plan.params().ab_type)));
   }
-  if (cuda_plan.stride_b()) {
-    SE_RETURN_IF_ERROR(
-        SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES,
-                        (uint32)get_alignment_bytes(cuda_plan.stride_b(),
-                                                    cuda_plan.ab_type())));
+  if (cuda_plan.params().stride_b) {
+    SE_RETURN_IF_ERROR(SetCublasLtAttr(
+        preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES,
+        (uint32)get_alignment_bytes(cuda_plan.params().stride_b,
+                                    cuda_plan.params().ab_type)));
   }
-  if (cuda_plan.stride_c()) {
+  if (cuda_plan.params().stride_c) {
     SE_RETURN_IF_ERROR(SetCublasLtAttr(
         preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES,
-        (uint32)get_alignment_bytes(cuda_plan.stride_c(), cuda_plan.c_type())));
+        (uint32)get_alignment_bytes(cuda_plan.params().stride_c,
+                                    cuda_plan.params().c_type)));
   }
-  if (cuda_plan.stride_d()) {
+  if (cuda_plan.params().stride_c) {
     SE_RETURN_IF_ERROR(SetCublasLtAttr(
         preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES,
-        (uint32)get_alignment_bytes(cuda_plan.stride_d(), cuda_plan.c_type())));
+        (uint32)get_alignment_bytes(cuda_plan.params().stride_c,
+                                    cuda_plan.params().c_type)));
   }
   return unique_preference;
 }
@@ -3406,35 +3397,10 @@ port::StatusOr<UniqueMatmulPreference> CreateCublasLtMatmulPreference(
 port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
 CUDABlas::CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &p) {
 #if CUDA_VERSION >= 11000
-  SE_ASSIGN_OR_RETURN(
-      auto op_desc,
-      CreateCublasLtOperationDesc(
-          p.computation_type, GetScaleType(p.c_type, p.computation_type),
-          p.pointer_mode, p.epilogue, p.transa, p.transb));
-  uint64 rows_a = p.transa == blas::Transpose::kNoTranspose ? p.m : p.k;
-  uint64 cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m;
-  uint64 rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n;
-  uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k;
-  SE_ASSIGN_OR_RETURN(auto a_desc,
-                      CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda,
-                                               p.stride_a, p.batch_count));
-  SE_ASSIGN_OR_RETURN(auto b_desc,
-                      CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb,
-                                               p.stride_b, p.batch_count));
-  SE_ASSIGN_OR_RETURN(auto c_desc,
-                      CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc,
-                                               p.stride_c, p.batch_count));
-  SE_ASSIGN_OR_RETURN(auto d_desc,
-                      CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc,
-                                               p.stride_c, p.batch_count));
-  blas::DataType scale_type = GetScaleType(p.c_type, p.computation_type);
-
+  auto cuda_plan = std::make_unique<CUDABlasLtMatmulPlan>();
+  SE_RETURN_IF_ERROR(cuda_plan->init(p));
   return static_cast<port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>>(
-      std::make_unique<CUDABlasLtMatmulPlan>(
-          std::move(op_desc), std::move(a_desc), std::move(b_desc),
-          std::move(c_desc), std::move(d_desc), p.ab_type, p.c_type, scale_type,
-          p.pointer_mode, p.epilogue, p.batch_count, p.stride_a, p.stride_b,
-          p.stride_c, p.stride_c));
+      std::move(cuda_plan));
 #else
   return port::Status(
       port::error::UNIMPLEMENTED,
@@ -3514,14 +3480,14 @@ bool CUDABlas::DoBlasLtMatmulInternal(
     return false;
   }
   bool is_pointer_mode_host = !alpha.is_pointer();
-  if ((cuda_plan.pointer_mode() == blas::PointerMode::kHost) !=
+  if ((cuda_plan.params().pointer_mode == blas::PointerMode::kHost) !=
      is_pointer_mode_host) {
    VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong "
               "pointer_mode for the given alpha/beta.";
    return false;
  }
-  if ((cuda_plan.epilogue() == blas::Epilogue::kBias ||
-       cuda_plan.epilogue() == blas::Epilogue::kBiasThenReLU) !=
+  if ((cuda_plan.params().epilogue == blas::Epilogue::kBias ||
+       cuda_plan.params().epilogue == blas::Epilogue::kBiasThenReLU) !=
      (bias != nullptr)) {
    VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong "
               "epilogue for the given bias pointer.";

From 7b477b8752d51f08ab599652cd02c8dd938d2e3e Mon Sep 17 00:00:00 2001
From: Ben Barsdell
Date: Tue, 20 Oct 2020 12:39:47 +1100
Subject: [PATCH 2/3] Move blasLt SetBiasPointer call inside mutex lock

- This was not thread-safe before.
---
 tensorflow/stream_executor/cuda/cuda_blas.cc | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index 8b9d9174f22..afea6790187 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -3493,13 +3493,6 @@ bool CUDABlas::DoBlasLtMatmulInternal(
               "epilogue for the given bias pointer.";
     return false;
   }
-  if (bias != nullptr) {
-    if (!cuda_plan.SetBiasPointer(bias.opaque())) {
-      VLOG(2) << "DoBlasLtMatmul returning false because setting the bias "
-                "pointer failed.";
-      return false;
-    }
-  }
   const void *alpha_ptr = alpha.is_pointer() ? alpha.opaque_pointer().opaque()
                                              : alpha.opaque_value();
   const void *beta_ptr =
@@ -3524,6 +3517,14 @@
 
   absl::MutexLock lock(&mu_);
 
+  if (bias != nullptr) {
+    if (!cuda_plan.SetBiasPointer(bias.opaque())) {
+      VLOG(2) << "DoBlasLtMatmul returning false because setting the bias "
+                "pointer failed.";
+      return false;
+    }
+  }
+
   CHECK(blasLt_ != nullptr);
 
   gpu::ScopedActivateExecutorContext sac{parent_};

From aae59c53f4fb698e34d640c0ec6e90ddd826ff88 Mon Sep 17 00:00:00 2001
From: Ben Barsdell
Date: Tue, 20 Oct 2020 19:51:20 +1100
Subject: [PATCH 3/3] Workaround blasLt batch size limitation

- cublasLtMatmul does not always support batch sizes > 65535.
- This commit breaks plan execution into repeated calls with up to the max
  batch size, followed by a remainder call.
---
 tensorflow/stream_executor/cuda/cuda_blas.cc | 183 ++++++++++++++++---
 tensorflow/stream_executor/cuda/cuda_blas.h  |   7 +
 2 files changed, 167 insertions(+), 23 deletions(-)

diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index afea6790187..6e6244b5da3 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -3258,6 +3258,8 @@ blas::ComputationType ToComputationType<std::complex<double>>() {
 class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
  public:
   port::Status init(const blas::BlasLtMatmulPlanParams& p) {
+    params_ = p;
+    scale_type_ = GetScaleType(p.c_type, p.computation_type);
     SE_ASSIGN_OR_RETURN(
         op_desc_,
         CreateCublasLtOperationDesc(
@@ -3269,18 +3271,36 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
     uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k;
     SE_ASSIGN_OR_RETURN(
         a_desc_, CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda,
-                                          p.stride_a, p.batch_count));
+                                          p.stride_a, capped_batch_count()));
     SE_ASSIGN_OR_RETURN(
         b_desc_, CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb,
-                                          p.stride_b, p.batch_count));
+                                          p.stride_b, capped_batch_count()));
     SE_ASSIGN_OR_RETURN(
         c_desc_, CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
-                                          p.batch_count));
+                                          capped_batch_count()));
     SE_ASSIGN_OR_RETURN(
         d_desc_, CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
-                                          p.batch_count));
-    params_ = p;
-    scale_type_ = GetScaleType(p.c_type, p.computation_type);
+                                          capped_batch_count()));
+    remainder_batch_count_ =
+        p.batch_count > kMaxBatchCount ? p.batch_count % kMaxBatchCount : 0;
+    if (remainder_batch_count_) {
+      SE_ASSIGN_OR_RETURN(
+          a_remainder_desc_,
+          CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda, p.stride_a,
+                                   remainder_batch_count_));
+      SE_ASSIGN_OR_RETURN(
+          b_remainder_desc_,
+          CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb, p.stride_b,
+                                   remainder_batch_count_));
+      SE_ASSIGN_OR_RETURN(
+          c_remainder_desc_,
+          CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
+                                   remainder_batch_count_));
+      SE_ASSIGN_OR_RETURN(
+          d_remainder_desc_,
+          CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
+                                   remainder_batch_count_));
+    }
     return port::Status::OK();
   }
 
@@ -3289,24 +3309,51 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
   cublasLtMatrixLayout_t b_desc() const { return b_desc_.get(); }
   cublasLtMatrixLayout_t c_desc() const { return c_desc_.get(); }
   cublasLtMatrixLayout_t d_desc() const { return d_desc_.get(); }
+  cublasLtMatrixLayout_t a_remainder_desc() const {
+    return a_remainder_desc_.get();
+  }
+  cublasLtMatrixLayout_t b_remainder_desc() const {
+    return b_remainder_desc_.get();
+  }
+  cublasLtMatrixLayout_t c_remainder_desc() const {
+    return c_remainder_desc_.get();
+  }
+  cublasLtMatrixLayout_t d_remainder_desc() const {
+    return d_remainder_desc_.get();
+  }
   const blas::BlasLtMatmulPlanParams& params() const { return params_; }
   blas::DataType scale_type() const { return scale_type_; }
   blas::DataType ab_type() const override { return params_.ab_type; }
   blas::DataType c_type() const override { return params_.c_type; }
+  int capped_batch_count() const {
+    return std::min(params_.batch_count, kMaxBatchCount);
+  }
+  int remainder_batch_count() const { return remainder_batch_count_; }
 
   // Note: Must be const to satisfy API. This is always called before the plan
   // is executed, so the state change is not observed in subsequent executions.
   bool SetBiasPointer(const void *bias) const;
 
  private:
+  // In some cases cublasLt does not support large batch sizes, so we need to
+  // split up such cases into multiple calls.
+  static constexpr const int kMaxBatchCount = 65535;
+
+  blas::BlasLtMatmulPlanParams params_;
+  blas::DataType scale_type_;
   UniqueOpDesc op_desc_;
+  // These have batch count set to capped_batch_count().
   UniqueLayoutDesc a_desc_;
   UniqueLayoutDesc b_desc_;
   UniqueLayoutDesc c_desc_;
   UniqueLayoutDesc d_desc_;
-  blas::BlasLtMatmulPlanParams params_;
-  blas::DataType scale_type_;
+  int remainder_batch_count_;
+  // These have batch count set to remainder_batch_count_, and are only created
+  // if params_.batch_count > kMaxBatchCount.
+  UniqueLayoutDesc a_remainder_desc_;
+  UniqueLayoutDesc b_remainder_desc_;
+  UniqueLayoutDesc c_remainder_desc_;
+  UniqueLayoutDesc d_remainder_desc_;
 };
 
 bool CUDABlasLtMatmulPlan::SetBiasPointer(const void *bias) const {
@@ -3409,9 +3456,10 @@ CUDABlas::CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &p) {
 }
 
 port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
-CUDABlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
-                                    size_t max_workspace_size,
-                                    int max_algorithm_count) {
+CUDABlas::GetBlasLtMatmulAlgorithmsInternal(const blas::IBlasLtMatmulPlan* plan,
+                                            size_t max_workspace_size,
+                                            int max_algorithm_count,
+                                            bool for_remainder_batch) {
 #if CUDA_VERSION >= 11000
   SE_ASSIGN_OR_RETURN(UniqueMatmulPreference preference,
                       CreateCublasLtMatmulPreference(plan, max_workspace_size));
@@ -3426,10 +3474,18 @@ CUDABlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
   int found_algorithm_count = 0;
 
   const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
+  const auto& a_desc =
+      for_remainder_batch ? cuda_plan.a_remainder_desc() : cuda_plan.a_desc();
+  const auto& b_desc =
+      for_remainder_batch ? cuda_plan.b_remainder_desc() : cuda_plan.b_desc();
+  const auto& c_desc =
+      for_remainder_batch ? cuda_plan.c_remainder_desc() : cuda_plan.c_desc();
+  const auto& d_desc =
+      for_remainder_batch ? cuda_plan.d_remainder_desc() : cuda_plan.d_desc();
   cublasStatus_t status = cublasLtMatmulAlgoGetHeuristic(
-      blasLt_, cuda_plan.op_desc(), cuda_plan.a_desc(), cuda_plan.b_desc(),
-      cuda_plan.c_desc(), cuda_plan.d_desc(), preference.get(),
-      max_algorithm_count, results.data(), &found_algorithm_count);
+      blasLt_, cuda_plan.op_desc(), a_desc, b_desc, c_desc, d_desc,
+      preference.get(), max_algorithm_count, results.data(),
+      &found_algorithm_count);
   if (status != CUBLAS_STATUS_SUCCESS) {
     return port::Status(
         port::error::INTERNAL,
@@ -3455,6 +3511,14 @@ CUDABlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
 #endif
 }
 
+port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
+CUDABlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan* plan,
+                                    size_t max_workspace_size,
+                                    int max_algorithm_count) {
+  return GetBlasLtMatmulAlgorithmsInternal(plan, max_workspace_size,
+                                           max_algorithm_count);
+}
+
 #if CUDA_VERSION >= 11000
 bool CUDABlas::DoBlasLtMatmulInternal(
     Stream *stream, bool err_on_failure, const blas::IBlasLtMatmulPlan *plan,
@@ -3513,6 +3577,28 @@ bool CUDABlas::DoBlasLtMatmulInternal(
     }
   }
 
+  // This is only used when batch_count > kMaxBatchCount.
+  std::unique_ptr<blas::IBlasLtMatmulAlgorithm> unique_remainder_algo;
+  if (cuda_plan.remainder_batch_count()) {
+    // There is no easy way to get the user-specified max workspace size here,
+    // so we just allow a very small amount and don't worry too much about
+    // performance because this is only used in rare cases. The same reasoning
+    // applies to selection of the algorithm.
+    size_t max_workspace_size = 4 * 1024 * 1024;  // 4 MiB
+    auto status_or_algorithms =
+        GetBlasLtMatmulAlgorithmsInternal(plan, max_workspace_size,
+                                          /* max_algorithm_count = */ 1,
+                                          /* for_remainder_batch = */ true);
+    if (!status_or_algorithms.ok()) {
+      if (err_on_failure || VLOG_IS_ON(3)) {
+        LOG(ERROR) << "Failed to get algorithms for blasLt remainder batch.";
+      }
+      return false;
+    }
+    auto algorithms = status_or_algorithms.ConsumeValueOrDie();
+    unique_remainder_algo = std::move(algorithms.front());
+  }
+
   cudaStream_t cuda_stream = CUDAStream(stream);
 
   absl::MutexLock lock(&mu_);
@@ -3529,16 +3615,67 @@
 
   gpu::ScopedActivateExecutorContext sac{parent_};
 
-  cublasStatus_t ret = cublasLtMatmul(
-      blasLt_, cuda_plan.op_desc(), alpha_ptr, a.opaque(), cuda_plan.a_desc(),
-      b.opaque(), cuda_plan.b_desc(), beta_ptr, c.opaque(), cuda_plan.c_desc(),
-      d.opaque(), cuda_plan.d_desc(), cuda_algo.algo(), workspace,
-      cuda_algo.workspace_size(), cuda_stream);
-  if (ret != CUBLAS_STATUS_SUCCESS) {
-    if (err_on_failure || VLOG_IS_ON(3)) {
-      LOG(ERROR) << "failed to run cublasLtMatmul routine: " << ToString(ret);
+  // Plan execution is broken down into repeated calls with capped_batch_count,
+  // followed by a final call with remainder_batch_count.
+  // Cases where batch_count <= kMaxBatchCount require only a single call (a
+  // single loop iteration and no remainder).
+  int ab_type_size = GetDataTypeSizeBytes(cuda_plan.params().ab_type);
+  int c_type_size = GetDataTypeSizeBytes(cuda_plan.params().c_type);
+  const char* a_ptr = static_cast<const char*>(a.opaque());
+  const char* b_ptr = static_cast<const char*>(b.opaque());
+  const char* c_ptr = static_cast<const char*>(c.opaque());
+  char* d_ptr = static_cast<char*>(d.opaque());
+  int capped_batch_count = cuda_plan.capped_batch_count();
+  for (int batch = 0;
+       batch + capped_batch_count <= cuda_plan.params().batch_count;
+       batch += capped_batch_count) {
+    cublasStatus_t ret = cublasLtMatmul(
+        blasLt_, cuda_plan.op_desc(), alpha_ptr, a_ptr, cuda_plan.a_desc(),
+        b_ptr, cuda_plan.b_desc(), beta_ptr, c_ptr, cuda_plan.c_desc(), d_ptr,
+        cuda_plan.d_desc(), cuda_algo.algo(), workspace,
+        cuda_algo.workspace_size(), cuda_stream);
+    if (ret != CUBLAS_STATUS_SUCCESS) {
+      if (err_on_failure || VLOG_IS_ON(3)) {
+        LOG(ERROR) << "failed to run cublasLtMatmul routine: " << ToString(ret);
+      }
+      return false;
+    }
+    a_ptr += capped_batch_count * cuda_plan.params().stride_a * ab_type_size;
+    b_ptr += capped_batch_count * cuda_plan.params().stride_b * ab_type_size;
+    c_ptr += capped_batch_count * cuda_plan.params().stride_c * c_type_size;
+    d_ptr += capped_batch_count * cuda_plan.params().stride_c * c_type_size;
+  }
+  // This is only used when batch_count > kMaxBatchCount.
+  if (cuda_plan.remainder_batch_count()) {
+    const auto& remainder_algo = *static_cast<const CUDABlasLtMatmulAlgorithm*>(
+        unique_remainder_algo.get());
+    if (remainder_algo.workspace_size()) {
+      port::Status allocation_status = AllocateWorkspace(
+          &workspace, scratch_allocator, remainder_algo.workspace_size());
+      if (!allocation_status.ok()) {
+        if (err_on_failure || VLOG_IS_ON(3)) {
+          LOG(ERROR) << "Failed to allocate workspace for cublasLtMatmul algo "
+                        "with id: "
+                     << remainder_algo.algo_id() << " requiring "
+                     << remainder_algo.workspace_size()
+                     << " bytes of workspace";
+        }
+        return false;
+      }
+    }
+    cublasStatus_t ret = cublasLtMatmul(
+        blasLt_, cuda_plan.op_desc(), alpha_ptr, a_ptr,
+        cuda_plan.a_remainder_desc(), b_ptr, cuda_plan.b_remainder_desc(),
+        beta_ptr, c_ptr, cuda_plan.c_remainder_desc(), d_ptr,
+        cuda_plan.d_remainder_desc(), remainder_algo.algo(), workspace,
+        remainder_algo.workspace_size(), cuda_stream);
+    if (ret != CUBLAS_STATUS_SUCCESS) {
+      if (err_on_failure || VLOG_IS_ON(3)) {
+        LOG(ERROR) << "failed to run remainder cublasLtMatmul routine: "
+                   << ToString(ret);
+      }
+      return false;
     }
-    return false;
   }
   return true;
 }
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.h b/tensorflow/stream_executor/cuda/cuda_blas.h
index 33d841b2c52..fa47dc98432 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.h
+++ b/tensorflow/stream_executor/cuda/cuda_blas.h
@@ -150,6 +150,13 @@ class CUDABlas : public blas::BlasSupport {
                               const blas::IBlasLtMatmulAlgorithm *algorithm,
                               DeviceMemoryBase bias);
 
+  // Helper function for implementing GetBlasLtMatmulAlgorithms.
+  port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
+  GetBlasLtMatmulAlgorithmsInternal(const blas::IBlasLtMatmulPlan* plan,
+                                    size_t max_workspace_size,
+                                    int max_algorithm_count,
+                                    bool for_remainder_batch = false);
+
   // Guards the cuBLAS handle for this device.
   absl::Mutex mu_;
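
Editor's note (illustrative, not part of the patches): the workaround in PATCH 3/3
amounts to walking the batch in chunks of at most kMaxBatchCount problems and
advancing the A/B/C/D pointers by chunk * stride * element_size bytes between
cublasLtMatmul calls, with one final call covering batch_count % kMaxBatchCount.
The standalone C++ sketch below shows only that chunk/pointer arithmetic;
kMaxBatchCount mirrors the patch, while LaunchChunk and the element sizes are
hypothetical stand-ins rather than StreamExecutor or cuBLASLt APIs.

#include <algorithm>
#include <cstdint>
#include <cstdio>

constexpr int kMaxBatchCount = 65535;  // Same cap as in the patch.

// Stand-in for one cublasLtMatmul call covering `batch` problems whose A/B/C
// operands start at the given byte offsets.
void LaunchChunk(int batch, int64_t a_off, int64_t b_off, int64_t c_off) {
  std::printf("launch batch=%d a_off=%lld b_off=%lld c_off=%lld\n", batch,
              static_cast<long long>(a_off), static_cast<long long>(b_off),
              static_cast<long long>(c_off));
}

// Walks `batch_count` problems in chunks of at most kMaxBatchCount, advancing
// each operand by stride * element_size bytes per problem. This produces the
// same schedule as the patch's full-chunk loop followed by a remainder call.
void RunBatched(int batch_count, int64_t stride_a, int64_t stride_b,
                int64_t stride_c, int ab_elem_size, int c_elem_size) {
  int64_t a_off = 0, b_off = 0, c_off = 0;
  int remaining = batch_count;
  while (remaining > 0) {
    int chunk = std::min(remaining, kMaxBatchCount);
    LaunchChunk(chunk, a_off, b_off, c_off);
    a_off += int64_t{chunk} * stride_a * ab_elem_size;
    b_off += int64_t{chunk} * stride_b * ab_elem_size;
    c_off += int64_t{chunk} * stride_c * c_elem_size;
    remaining -= chunk;
  }
}

int main() {
  // 150000 problems = 2 full chunks of 65535 plus a remainder of 18930.
  RunBatched(/*batch_count=*/150000, /*stride_a=*/64 * 64, /*stride_b=*/64 * 64,
             /*stride_c=*/64 * 64, /*ab_elem_size=*/2, /*c_elem_size=*/2);
  return 0;
}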