Workaround blasLt batch size limitation

- cublasLtMatmul does not always support batch sizes > 65535.
- This commit breaks plan execution into repeated calls of up to the
  max batch size, followed by a remainder call (see the sketch below).
Ben Barsdell 2020-10-20 19:51:20 +11:00
parent 7b477b8752
commit aae59c53f4
2 changed files with 167 additions and 23 deletions
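To make the splitting concrete, here is a minimal standalone sketch of the chunk/remainder arithmetic, assuming only the 65535 limit described above; BatchSplit and SplitBatch are hypothetical names used for illustration and do not appear in the change:

#include <algorithm>

constexpr int kMaxBatchCount = 65535;  // largest batch count issued per call

struct BatchSplit {
  int full_calls;  // calls issued with `capped` batches each
  int capped;      // std::min(batch_count, kMaxBatchCount)
  int remainder;   // batch count of the final remainder call (0 if none)
};

inline BatchSplit SplitBatch(int batch_count) {
  int capped = std::min(batch_count, kMaxBatchCount);
  int remainder =
      batch_count > kMaxBatchCount ? batch_count % kMaxBatchCount : 0;
  return BatchSplit{batch_count / capped, capped, remainder};
}

// Example: batch_count = 200000 -> three full calls of 65535 batches
// (196605 in total) plus a remainder call of 200000 % 65535 = 3395.
// batch_count = 131070 -> exactly two full calls and no remainder.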


@@ -3258,6 +3258,8 @@ blas::ComputationType ToComputationType<std::complex<double>>() {
 class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
  public:
   port::Status init(const blas::BlasLtMatmulPlanParams& p) {
+    params_ = p;
+    scale_type_ = GetScaleType(p.c_type, p.computation_type);
     SE_ASSIGN_OR_RETURN(
         op_desc_,
         CreateCublasLtOperationDesc(
@@ -3269,18 +3271,36 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
     uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k;
     SE_ASSIGN_OR_RETURN(
         a_desc_, CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda,
-                                          p.stride_a, p.batch_count));
+                                          p.stride_a, capped_batch_count()));
     SE_ASSIGN_OR_RETURN(
         b_desc_, CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb,
-                                          p.stride_b, p.batch_count));
+                                          p.stride_b, capped_batch_count()));
     SE_ASSIGN_OR_RETURN(
         c_desc_, CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
-                                          p.batch_count));
+                                          capped_batch_count()));
     SE_ASSIGN_OR_RETURN(
         d_desc_, CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
-                                          p.batch_count));
-    params_ = p;
-    scale_type_ = GetScaleType(p.c_type, p.computation_type);
+                                          capped_batch_count()));
+    remainder_batch_count_ =
+        p.batch_count > kMaxBatchCount ? p.batch_count % kMaxBatchCount : 0;
+    if (remainder_batch_count_) {
+      SE_ASSIGN_OR_RETURN(
+          a_remainder_desc_,
+          CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda, p.stride_a,
+                                   remainder_batch_count_));
+      SE_ASSIGN_OR_RETURN(
+          b_remainder_desc_,
+          CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb, p.stride_b,
+                                   remainder_batch_count_));
+      SE_ASSIGN_OR_RETURN(
+          c_remainder_desc_,
+          CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
+                                   remainder_batch_count_));
+      SE_ASSIGN_OR_RETURN(
+          d_remainder_desc_,
+          CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
+                                   remainder_batch_count_));
+    }
     return port::Status::OK();
   }
@@ -3289,24 +3309,51 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
   cublasLtMatrixLayout_t b_desc() const { return b_desc_.get(); }
   cublasLtMatrixLayout_t c_desc() const { return c_desc_.get(); }
   cublasLtMatrixLayout_t d_desc() const { return d_desc_.get(); }
+  cublasLtMatrixLayout_t a_remainder_desc() const {
+    return a_remainder_desc_.get();
+  }
+  cublasLtMatrixLayout_t b_remainder_desc() const {
+    return b_remainder_desc_.get();
+  }
+  cublasLtMatrixLayout_t c_remainder_desc() const {
+    return c_remainder_desc_.get();
+  }
+  cublasLtMatrixLayout_t d_remainder_desc() const {
+    return d_remainder_desc_.get();
+  }
   const blas::BlasLtMatmulPlanParams& params() const { return params_; }
   blas::DataType scale_type() const { return scale_type_; }
   blas::DataType ab_type() const override { return params_.ab_type; }
   blas::DataType c_type() const override { return params_.c_type; }
+  int capped_batch_count() const {
+    return std::min(params_.batch_count, kMaxBatchCount);
+  }
+  int remainder_batch_count() const { return remainder_batch_count_; }
 
   // Note: Must be const to satisfy API. This is always called before the plan
   // is executed, so the state change is not observed in subsequent executions.
   bool SetBiasPointer(const void *bias) const;
 
  private:
+  // In some cases cublasLt does not support large batch sizes, so we need to
+  // split up such cases into multiple calls.
+  static constexpr const int kMaxBatchCount = 65535;
+  blas::BlasLtMatmulPlanParams params_;
+  blas::DataType scale_type_;
   UniqueOpDesc op_desc_;
+  // These have batch count set to capped_batch_count().
   UniqueLayoutDesc a_desc_;
   UniqueLayoutDesc b_desc_;
   UniqueLayoutDesc c_desc_;
   UniqueLayoutDesc d_desc_;
-  blas::BlasLtMatmulPlanParams params_;
-  blas::DataType scale_type_;
+  int remainder_batch_count_;
+  // These have batch count set to remainder_batch_count_, and are only created
+  // if params_.batch_count > kMaxBatchCount.
+  UniqueLayoutDesc a_remainder_desc_;
+  UniqueLayoutDesc b_remainder_desc_;
+  UniqueLayoutDesc c_remainder_desc_;
+  UniqueLayoutDesc d_remainder_desc_;
 };
 
 bool CUDABlasLtMatmulPlan::SetBiasPointer(const void *bias) const {
@@ -3409,9 +3456,10 @@ CUDABlas::CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &p) {
 }
 
 port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
-CUDABlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
-                                    size_t max_workspace_size,
-                                    int max_algorithm_count) {
+CUDABlas::GetBlasLtMatmulAlgorithmsInternal(const blas::IBlasLtMatmulPlan* plan,
+                                            size_t max_workspace_size,
+                                            int max_algorithm_count,
+                                            bool for_remainder_batch) {
 #if CUDA_VERSION >= 11000
   SE_ASSIGN_OR_RETURN(UniqueMatmulPreference preference,
                       CreateCublasLtMatmulPreference(plan, max_workspace_size));
@@ -3426,10 +3474,18 @@ CUDABlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
   int found_algorithm_count = 0;
   const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
 
+  const auto& a_desc =
+      for_remainder_batch ? cuda_plan.a_remainder_desc() : cuda_plan.a_desc();
+  const auto& b_desc =
+      for_remainder_batch ? cuda_plan.b_remainder_desc() : cuda_plan.b_desc();
+  const auto& c_desc =
+      for_remainder_batch ? cuda_plan.c_remainder_desc() : cuda_plan.c_desc();
+  const auto& d_desc =
+      for_remainder_batch ? cuda_plan.d_remainder_desc() : cuda_plan.d_desc();
   cublasStatus_t status = cublasLtMatmulAlgoGetHeuristic(
-      blasLt_, cuda_plan.op_desc(), cuda_plan.a_desc(), cuda_plan.b_desc(),
-      cuda_plan.c_desc(), cuda_plan.d_desc(), preference.get(),
-      max_algorithm_count, results.data(), &found_algorithm_count);
+      blasLt_, cuda_plan.op_desc(), a_desc, b_desc, c_desc, d_desc,
+      preference.get(), max_algorithm_count, results.data(),
+      &found_algorithm_count);
   if (status != CUBLAS_STATUS_SUCCESS) {
     return port::Status(
         port::error::INTERNAL,
@@ -3455,6 +3511,14 @@ CUDABlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
 #endif
 }
 
+port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
+CUDABlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan* plan,
+                                    size_t max_workspace_size,
+                                    int max_algorithm_count) {
+  return GetBlasLtMatmulAlgorithmsInternal(plan, max_workspace_size,
+                                           max_algorithm_count);
+}
+
 #if CUDA_VERSION >= 11000
 bool CUDABlas::DoBlasLtMatmulInternal(
     Stream *stream, bool err_on_failure, const blas::IBlasLtMatmulPlan *plan,
@@ -3513,6 +3577,28 @@ bool CUDABlas::DoBlasLtMatmulInternal(
     }
   }
 
+  // This is only used when batch_count > kMaxBatchCount.
+  std::unique_ptr<blas::IBlasLtMatmulAlgorithm> unique_remainder_algo;
+  if (cuda_plan.remainder_batch_count()) {
+    // There is no easy way to get the user-specified max workspace size here,
+    // so we just allow a very small amount and don't worry too much about
+    // performance because this is only used in rare cases. The same reasoning
+    // applies to selection of the algorithm.
+    size_t max_workspace_size = 4 * 1024 * 1024;  // 4 MiB
+    auto status_or_algorithms =
+        GetBlasLtMatmulAlgorithmsInternal(plan, max_workspace_size,
+                                          /* max_algorithm_count = */ 1,
+                                          /* for_remainder_batch = */ true);
+    if (!status_or_algorithms.ok()) {
+      if (err_on_failure || VLOG_IS_ON(3)) {
+        LOG(ERROR) << "Failed to get algorithms for blasLt remainder batch.";
+      }
+      return false;
+    }
+    auto algorithms = status_or_algorithms.ConsumeValueOrDie();
+    unique_remainder_algo = std::move(algorithms.front());
+  }
+
   cudaStream_t cuda_stream = CUDAStream(stream);
 
   absl::MutexLock lock(&mu_);
@@ -3529,16 +3615,67 @@ bool CUDABlas::DoBlasLtMatmulInternal(
   gpu::ScopedActivateExecutorContext sac{parent_};
 
-  cublasStatus_t ret = cublasLtMatmul(
-      blasLt_, cuda_plan.op_desc(), alpha_ptr, a.opaque(), cuda_plan.a_desc(),
-      b.opaque(), cuda_plan.b_desc(), beta_ptr, c.opaque(), cuda_plan.c_desc(),
-      d.opaque(), cuda_plan.d_desc(), cuda_algo.algo(), workspace,
-      cuda_algo.workspace_size(), cuda_stream);
-  if (ret != CUBLAS_STATUS_SUCCESS) {
-    if (err_on_failure || VLOG_IS_ON(3)) {
-      LOG(ERROR) << "failed to run cublasLtMatmul routine: " << ToString(ret);
+  // Plan execution is broken down into repeated calls with capped_batch_count,
+  // followed by a final call with remainder_batch_count.
+  // Cases where batch_count <= kMaxBatchCount require only a single call (a
+  // single loop iteration and no remainder).
+  int ab_type_size = GetDataTypeSizeBytes(cuda_plan.params().ab_type);
+  int c_type_size = GetDataTypeSizeBytes(cuda_plan.params().c_type);
+  const char* a_ptr = static_cast<const char*>(a.opaque());
+  const char* b_ptr = static_cast<const char*>(b.opaque());
+  const char* c_ptr = static_cast<const char*>(c.opaque());
+  char* d_ptr = static_cast<char*>(d.opaque());
+  int capped_batch_count = cuda_plan.capped_batch_count();
+  for (int batch = 0;
+       batch + capped_batch_count <= cuda_plan.params().batch_count;
+       batch += capped_batch_count) {
+    cublasStatus_t ret = cublasLtMatmul(
+        blasLt_, cuda_plan.op_desc(), alpha_ptr, a_ptr, cuda_plan.a_desc(),
+        b_ptr, cuda_plan.b_desc(), beta_ptr, c_ptr, cuda_plan.c_desc(), d_ptr,
+        cuda_plan.d_desc(), cuda_algo.algo(), workspace,
+        cuda_algo.workspace_size(), cuda_stream);
+    if (ret != CUBLAS_STATUS_SUCCESS) {
+      if (err_on_failure || VLOG_IS_ON(3)) {
+        LOG(ERROR) << "failed to run cublasLtMatmul routine: " << ToString(ret);
+      }
+      return false;
+    }
+    a_ptr += capped_batch_count * cuda_plan.params().stride_a * ab_type_size;
+    b_ptr += capped_batch_count * cuda_plan.params().stride_b * ab_type_size;
+    c_ptr += capped_batch_count * cuda_plan.params().stride_c * c_type_size;
+    d_ptr += capped_batch_count * cuda_plan.params().stride_c * c_type_size;
+  }
+
+  // This is only used when batch_count > kMaxBatchCount.
+  if (cuda_plan.remainder_batch_count()) {
+    const auto& remainder_algo = *static_cast<const CUDABlasLtMatmulAlgorithm*>(
+        unique_remainder_algo.get());
+    if (remainder_algo.workspace_size()) {
+      port::Status allocation_status = AllocateWorkspace(
+          &workspace, scratch_allocator, remainder_algo.workspace_size());
+      if (!allocation_status.ok()) {
+        if (err_on_failure || VLOG_IS_ON(3)) {
+          LOG(ERROR) << "Failed to allocate workspace for cublasLtMatmul algo "
+                        "with id: "
+                     << remainder_algo.algo_id() << " requiring "
+                     << remainder_algo.workspace_size()
+                     << " bytes of workspace";
+        }
+        return false;
+      }
+    }
+    cublasStatus_t ret = cublasLtMatmul(
+        blasLt_, cuda_plan.op_desc(), alpha_ptr, a_ptr,
+        cuda_plan.a_remainder_desc(), b_ptr, cuda_plan.b_remainder_desc(),
+        beta_ptr, c_ptr, cuda_plan.c_remainder_desc(), d_ptr,
+        cuda_plan.d_remainder_desc(), remainder_algo.algo(), workspace,
+        remainder_algo.workspace_size(), cuda_stream);
+    if (ret != CUBLAS_STATUS_SUCCESS) {
+      if (err_on_failure || VLOG_IS_ON(3)) {
+        LOG(ERROR) << "failed to run remainder cublasLtMatmul routine: "
+                   << ToString(ret);
+      }
+      return false;
     }
-    return false;
   }
   return true;
 }
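The loop above advances each operand pointer by the bytes spanned by one full chunk: batches * stride * element_size, with strides measured in elements as in blas::BlasLtMatmulPlanParams. A standalone sketch of that offset step (AdvanceBatches is a hypothetical helper, not part of the change):

#include <cstdint>

// Returns `base` advanced past `batches` strided batches. `stride_elems` is
// measured in elements, so scale by the element size to get a byte offset.
inline const char* AdvanceBatches(const char* base, int batches,
                                  std::int64_t stride_elems,
                                  int elem_size_bytes) {
  return base +
         static_cast<std::int64_t>(batches) * stride_elems * elem_size_bytes;
}

// E.g. for fp16 A matrices with stride_a = m * k elements, one full chunk:
//   a_ptr = AdvanceBatches(a_ptr, 65535, std::int64_t{m} * k, /*fp16=*/2);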


@@ -150,6 +150,13 @@ class CUDABlas : public blas::BlasSupport {
                                 const blas::IBlasLtMatmulAlgorithm *algorithm,
                                 DeviceMemoryBase bias);
 
+  // Helper function for implementing GetBlasLtMatmulAlgorithms.
+  port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
+  GetBlasLtMatmulAlgorithmsInternal(const blas::IBlasLtMatmulPlan* plan,
+                                    size_t max_workspace_size,
+                                    int max_algorithm_count,
+                                    bool for_remainder_batch = false);
+
   // Guards the cuBLAS handle for this device.
   absl::Mutex mu_;