Merge pull request #44175 from benbarsdell:cublaslt-large-batch-workaround-proper

PiperOrigin-RevId: 338087096
Change-Id: Iad9dd0c7f6a6d38043958eea83b0c1b1ad93b287
This commit is contained in:
TensorFlower Gardener 2020-10-20 10:42:41 -07:00
commit 7e5dc369b8
2 changed files with 222 additions and 110 deletions

View File

@ -3257,67 +3257,103 @@ blas::ComputationType ToComputationType<std::complex<double>>() {
class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
public:
CUDABlasLtMatmulPlan(UniqueOpDesc op_desc, UniqueLayoutDesc a_desc,
UniqueLayoutDesc b_desc, UniqueLayoutDesc c_desc,
UniqueLayoutDesc d_desc, blas::DataType ab_type,
blas::DataType c_type, blas::DataType scale_type,
blas::PointerMode pointer_mode, blas::Epilogue epilogue,
int batch_count, int64 stride_a, int64 stride_b,
int64 stride_c, int64 stride_d)
: op_desc_(std::move(op_desc)),
a_desc_(std::move(a_desc)),
b_desc_(std::move(b_desc)),
c_desc_(std::move(c_desc)),
d_desc_(std::move(d_desc)),
ab_type_(ab_type),
c_type_(c_type),
scale_type_(scale_type),
pointer_mode_(pointer_mode),
epilogue_(epilogue),
batch_count_(batch_count),
stride_a_(stride_a),
stride_b_(stride_b),
stride_c_(stride_c),
stride_d_(stride_d) {}
port::Status init(const blas::BlasLtMatmulPlanParams &p) {
params_ = p;
scale_type_ = GetScaleType(p.c_type, p.computation_type);
SE_ASSIGN_OR_RETURN(
op_desc_,
CreateCublasLtOperationDesc(
p.computation_type, GetScaleType(p.c_type, p.computation_type),
p.pointer_mode, p.epilogue, p.transa, p.transb));
uint64 rows_a = p.transa == blas::Transpose::kNoTranspose ? p.m : p.k;
uint64 cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m;
uint64 rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n;
uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k;
SE_ASSIGN_OR_RETURN(
a_desc_, CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda,
p.stride_a, capped_batch_count()));
SE_ASSIGN_OR_RETURN(
b_desc_, CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb,
p.stride_b, capped_batch_count()));
SE_ASSIGN_OR_RETURN(
c_desc_, CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
capped_batch_count()));
SE_ASSIGN_OR_RETURN(
d_desc_, CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
capped_batch_count()));
remainder_batch_count_ =
p.batch_count > kMaxBatchCount ? p.batch_count % kMaxBatchCount : 0;
if (remainder_batch_count_) {
SE_ASSIGN_OR_RETURN(
a_remainder_desc_,
CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda, p.stride_a,
remainder_batch_count_));
SE_ASSIGN_OR_RETURN(
b_remainder_desc_,
CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb, p.stride_b,
remainder_batch_count_));
SE_ASSIGN_OR_RETURN(
c_remainder_desc_,
CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
remainder_batch_count_));
SE_ASSIGN_OR_RETURN(
d_remainder_desc_,
CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
remainder_batch_count_));
}
return port::Status::OK();
}
cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); }
cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); }
cublasLtMatrixLayout_t b_desc() const { return b_desc_.get(); }
cublasLtMatrixLayout_t c_desc() const { return c_desc_.get(); }
cublasLtMatrixLayout_t d_desc() const { return d_desc_.get(); }
bool ok() { return op_desc_ && a_desc_ && b_desc_ && c_desc_ && d_desc_; }
cublasLtMatrixLayout_t a_remainder_desc() const {
return a_remainder_desc_.get();
}
cublasLtMatrixLayout_t b_remainder_desc() const {
return b_remainder_desc_.get();
}
cublasLtMatrixLayout_t c_remainder_desc() const {
return c_remainder_desc_.get();
}
cublasLtMatrixLayout_t d_remainder_desc() const {
return d_remainder_desc_.get();
}
blas::DataType ab_type() const override { return ab_type_; }
blas::DataType c_type() const override { return c_type_; }
const blas::BlasLtMatmulPlanParams &params() const { return params_; }
blas::DataType scale_type() const { return scale_type_; }
blas::PointerMode pointer_mode() const { return pointer_mode_; }
blas::Epilogue epilogue() const { return epilogue_; }
int batch_count() const { return batch_count_; }
int64 stride_a() const { return stride_a_; }
int64 stride_b() const { return stride_b_; }
int64 stride_c() const { return stride_c_; }
int64 stride_d() const { return stride_d_; }
blas::DataType ab_type() const override { return params_.ab_type; }
blas::DataType c_type() const override { return params_.c_type; }
int capped_batch_count() const {
return std::min(params_.batch_count, kMaxBatchCount);
}
int remainder_batch_count() const { return remainder_batch_count_; }
// Note: Must be const to satisfy API. This is always called before the plan
// is executed, so the state change is not observed in subsequent executions.
bool SetBiasPointer(const void *bias) const;
private:
// In some cases cublasLt does not support large batch sizes, so we need to
// split up such cases into multiple calls.
static constexpr const int kMaxBatchCount = 65535;
blas::BlasLtMatmulPlanParams params_;
blas::DataType scale_type_;
UniqueOpDesc op_desc_;
// These have batch count set to capped_batch_count().
UniqueLayoutDesc a_desc_;
UniqueLayoutDesc b_desc_;
UniqueLayoutDesc c_desc_;
UniqueLayoutDesc d_desc_;
blas::DataType ab_type_;
blas::DataType c_type_;
blas::DataType scale_type_;
blas::PointerMode pointer_mode_;
blas::Epilogue epilogue_;
int batch_count_;
int64 stride_a_;
int64 stride_b_;
int64 stride_c_;
int64 stride_d_;
int remainder_batch_count_;
// These have batch count set to remainder_batch_count_, and are only created
// if params_.batch_count > kMaxBatchSize.
UniqueLayoutDesc a_remainder_desc_;
UniqueLayoutDesc b_remainder_desc_;
UniqueLayoutDesc c_remainder_desc_;
UniqueLayoutDesc d_remainder_desc_;
};
bool CUDABlasLtMatmulPlan::SetBiasPointer(const void *bias) const {
@ -3365,7 +3401,7 @@ port::StatusOr<UniqueMatmulPreference> CreateCublasLtMatmulPreference(
max_workspace_bytes));
const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
if (cuda_plan.batch_count() == 0) {
if (cuda_plan.params().batch_count == 0) {
return unique_preference;
}
// This is a workaround for a known issue in cuBlasLt where the heuristic may
@ -3374,27 +3410,29 @@ port::StatusOr<UniqueMatmulPreference> CreateCublasLtMatmulPreference(
auto get_alignment_bytes = [](int64 stride, blas::DataType dtype) {
return (stride & -stride) * GetDataTypeSizeBytes(dtype);
};
if (cuda_plan.stride_a()) {
SE_RETURN_IF_ERROR(
SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES,
(uint32)get_alignment_bytes(cuda_plan.stride_a(),
cuda_plan.ab_type())));
if (cuda_plan.params().stride_a) {
SE_RETURN_IF_ERROR(SetCublasLtAttr(
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES,
(uint32)get_alignment_bytes(cuda_plan.params().stride_a,
cuda_plan.params().ab_type)));
}
if (cuda_plan.stride_b()) {
SE_RETURN_IF_ERROR(
SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES,
(uint32)get_alignment_bytes(cuda_plan.stride_b(),
cuda_plan.ab_type())));
if (cuda_plan.params().stride_b) {
SE_RETURN_IF_ERROR(SetCublasLtAttr(
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES,
(uint32)get_alignment_bytes(cuda_plan.params().stride_b,
cuda_plan.params().ab_type)));
}
if (cuda_plan.stride_c()) {
if (cuda_plan.params().stride_c) {
SE_RETURN_IF_ERROR(SetCublasLtAttr(
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES,
(uint32)get_alignment_bytes(cuda_plan.stride_c(), cuda_plan.c_type())));
(uint32)get_alignment_bytes(cuda_plan.params().stride_c,
cuda_plan.params().c_type)));
}
if (cuda_plan.stride_d()) {
if (cuda_plan.params().stride_c) {
SE_RETURN_IF_ERROR(SetCublasLtAttr(
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES,
(uint32)get_alignment_bytes(cuda_plan.stride_d(), cuda_plan.c_type())));
(uint32)get_alignment_bytes(cuda_plan.params().stride_c,
cuda_plan.params().c_type)));
}
return unique_preference;
}
@ -3406,35 +3444,10 @@ port::StatusOr<UniqueMatmulPreference> CreateCublasLtMatmulPreference(
port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
CUDABlas::CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &p) {
#if CUDA_VERSION >= 11000
SE_ASSIGN_OR_RETURN(
auto op_desc,
CreateCublasLtOperationDesc(
p.computation_type, GetScaleType(p.c_type, p.computation_type),
p.pointer_mode, p.epilogue, p.transa, p.transb));
uint64 rows_a = p.transa == blas::Transpose::kNoTranspose ? p.m : p.k;
uint64 cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m;
uint64 rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n;
uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k;
SE_ASSIGN_OR_RETURN(auto a_desc,
CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda,
p.stride_a, p.batch_count));
SE_ASSIGN_OR_RETURN(auto b_desc,
CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb,
p.stride_b, p.batch_count));
SE_ASSIGN_OR_RETURN(auto c_desc,
CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc,
p.stride_c, p.batch_count));
SE_ASSIGN_OR_RETURN(auto d_desc,
CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc,
p.stride_c, p.batch_count));
blas::DataType scale_type = GetScaleType(p.c_type, p.computation_type);
auto cuda_plan = std::make_unique<CUDABlasLtMatmulPlan>();
SE_RETURN_IF_ERROR(cuda_plan->init(p));
return static_cast<std::unique_ptr<blas::IBlasLtMatmulPlan>>(
std::make_unique<CUDABlasLtMatmulPlan>(
std::move(op_desc), std::move(a_desc), std::move(b_desc),
std::move(c_desc), std::move(d_desc), p.ab_type, p.c_type, scale_type,
p.pointer_mode, p.epilogue, p.batch_count, p.stride_a, p.stride_b,
p.stride_c, p.stride_c));
std::move(cuda_plan));
#else
return port::Status(
port::error::UNIMPLEMENTED,
@ -3443,9 +3456,10 @@ CUDABlas::CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &p) {
}
port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
CUDABlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
size_t max_workspace_size,
int max_algorithm_count) {
CUDABlas::GetBlasLtMatmulAlgorithmsInternal(const blas::IBlasLtMatmulPlan *plan,
size_t max_workspace_size,
int max_algorithm_count,
bool for_remainder_batch) {
#if CUDA_VERSION >= 11000
SE_ASSIGN_OR_RETURN(UniqueMatmulPreference preference,
CreateCublasLtMatmulPreference(plan, max_workspace_size));
@ -3460,10 +3474,18 @@ CUDABlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
int found_algorithm_count = 0;
const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
const auto &a_desc =
for_remainder_batch ? cuda_plan.a_remainder_desc() : cuda_plan.a_desc();
const auto &b_desc =
for_remainder_batch ? cuda_plan.b_remainder_desc() : cuda_plan.b_desc();
const auto &c_desc =
for_remainder_batch ? cuda_plan.c_remainder_desc() : cuda_plan.c_desc();
const auto &d_desc =
for_remainder_batch ? cuda_plan.d_remainder_desc() : cuda_plan.d_desc();
cublasStatus_t status = cublasLtMatmulAlgoGetHeuristic(
blasLt_, cuda_plan.op_desc(), cuda_plan.a_desc(), cuda_plan.b_desc(),
cuda_plan.c_desc(), cuda_plan.d_desc(), preference.get(),
max_algorithm_count, results.data(), &found_algorithm_count);
blasLt_, cuda_plan.op_desc(), a_desc, b_desc, c_desc, d_desc,
preference.get(), max_algorithm_count, results.data(),
&found_algorithm_count);
if (status != CUBLAS_STATUS_SUCCESS) {
return port::Status(
port::error::INTERNAL,
@ -3489,6 +3511,14 @@ CUDABlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
#endif
}
port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
CUDABlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
size_t max_workspace_size,
int max_algorithm_count) {
return GetBlasLtMatmulAlgorithmsInternal(plan, max_workspace_size,
max_algorithm_count);
}
#if CUDA_VERSION >= 11000
bool CUDABlas::DoBlasLtMatmulInternal(
Stream *stream, bool err_on_failure, const blas::IBlasLtMatmulPlan *plan,
@ -3514,26 +3544,19 @@ bool CUDABlas::DoBlasLtMatmulInternal(
return false;
}
bool is_pointer_mode_host = !alpha.is_pointer();
if ((cuda_plan.pointer_mode() == blas::PointerMode::kHost) !=
if ((cuda_plan.params().pointer_mode == blas::PointerMode::kHost) !=
is_pointer_mode_host) {
VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong "
"pointer_mode for the given alpha/beta.";
return false;
}
if ((cuda_plan.epilogue() == blas::Epilogue::kBias ||
cuda_plan.epilogue() == blas::Epilogue::kBiasThenReLU) !=
if ((cuda_plan.params().epilogue == blas::Epilogue::kBias ||
cuda_plan.params().epilogue == blas::Epilogue::kBiasThenReLU) !=
(bias != nullptr)) {
VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong "
"epilogue for the given bias pointer.";
return false;
}
if (bias != nullptr) {
if (!cuda_plan.SetBiasPointer(bias.opaque())) {
VLOG(2) << "DoBlasLtMatmul returning false because setting the bias "
"pointer failed.";
return false;
}
}
const void *alpha_ptr = alpha.is_pointer() ? alpha.opaque_pointer().opaque()
: alpha.opaque_value();
const void *beta_ptr =
@ -3554,24 +3577,106 @@ bool CUDABlas::DoBlasLtMatmulInternal(
}
}
// This is only used when batch_count > kMaxBatchCount.
std::unique_ptr<blas::IBlasLtMatmulAlgorithm> unique_remainder_algo;
if (cuda_plan.remainder_batch_count()) {
// There is no easy way to get the user-specified max workspace size here,
// so we just allow a very small amount and don't worry too much about
// performance because this is only used in rare cases. The same reasoning
// applies to selection of the algorithm.
size_t max_workspace_size = 4 * 1024 * 1024; // 4 MiB
auto status_or_algorithms =
GetBlasLtMatmulAlgorithmsInternal(plan, max_workspace_size,
/* max_algorithm_count = */ 1,
/* for_remainder_batch = */ true);
if (!status_or_algorithms.ok()) {
if (err_on_failure || VLOG_IS_ON(3)) {
LOG(ERROR) << "Failed to get algorithms for blasLt remainder batch.";
}
return false;
}
auto algorithms = status_or_algorithms.ConsumeValueOrDie();
unique_remainder_algo = std::move(algorithms.front());
}
cudaStream_t cuda_stream = CUDAStream(stream);
absl::MutexLock lock(&mu_);
if (bias != nullptr) {
if (!cuda_plan.SetBiasPointer(bias.opaque())) {
VLOG(2) << "DoBlasLtMatmul returning false because setting the bias "
"pointer failed.";
return false;
}
}
CHECK(blasLt_ != nullptr);
gpu::ScopedActivateExecutorContext sac{parent_};
cublasStatus_t ret = cublasLtMatmul(
blasLt_, cuda_plan.op_desc(), alpha_ptr, a.opaque(), cuda_plan.a_desc(),
b.opaque(), cuda_plan.b_desc(), beta_ptr, c.opaque(), cuda_plan.c_desc(),
d.opaque(), cuda_plan.d_desc(), cuda_algo.algo(), workspace,
cuda_algo.workspace_size(), cuda_stream);
if (ret != CUBLAS_STATUS_SUCCESS) {
if (err_on_failure || VLOG_IS_ON(3)) {
LOG(ERROR) << "failed to run cublasLtMatmul routine: " << ToString(ret);
// Plan execution is broken down into repeat calls with capped_batch_count,
// followed by a final call with remainder_batch_count.
// Cases where batch_count <= kMaxBatchCount require only a single call (a
// single loop iteration and no remainder).
int ab_type_size = GetDataTypeSizeBytes(cuda_plan.params().ab_type);
int c_type_size = GetDataTypeSizeBytes(cuda_plan.params().c_type);
const char *a_ptr = static_cast<const char *>(a.opaque());
const char *b_ptr = static_cast<const char *>(b.opaque());
const char *c_ptr = static_cast<const char *>(c.opaque());
char *d_ptr = static_cast<char *>(d.opaque());
int capped_batch_count = cuda_plan.capped_batch_count();
for (int batch = 0;
batch + capped_batch_count <= cuda_plan.params().batch_count;
batch += capped_batch_count) {
cublasStatus_t ret = cublasLtMatmul(
blasLt_, cuda_plan.op_desc(), alpha_ptr, a_ptr, cuda_plan.a_desc(),
b_ptr, cuda_plan.b_desc(), beta_ptr, c_ptr, cuda_plan.c_desc(), d_ptr,
cuda_plan.d_desc(), cuda_algo.algo(), workspace,
cuda_algo.workspace_size(), cuda_stream);
if (ret != CUBLAS_STATUS_SUCCESS) {
if (err_on_failure || VLOG_IS_ON(3)) {
LOG(ERROR) << "failed to run cublasLtMatmul routine: " << ToString(ret);
}
return false;
}
a_ptr += capped_batch_count * cuda_plan.params().stride_a * ab_type_size;
b_ptr += capped_batch_count * cuda_plan.params().stride_b * ab_type_size;
c_ptr += capped_batch_count * cuda_plan.params().stride_c * c_type_size;
d_ptr += capped_batch_count * cuda_plan.params().stride_c * c_type_size;
}
// This is only used when batch_count > kMaxBatchCount.
if (cuda_plan.remainder_batch_count()) {
const auto &remainder_algo =
*static_cast<const CUDABlasLtMatmulAlgorithm *>(
unique_remainder_algo.get());
if (remainder_algo.workspace_size()) {
port::Status allocation_status = AllocateWorkspace(
&workspace, scratch_allocator, remainder_algo.workspace_size());
if (!allocation_status.ok()) {
if (err_on_failure || VLOG_IS_ON(3)) {
LOG(ERROR) << "Failed to allocate workspace for cublasLtMatmul algo "
"with id: "
<< remainder_algo.algo_id() << " requiring "
<< remainder_algo.workspace_size()
<< " bytes of workspace";
}
return false;
}
}
cublasStatus_t ret = cublasLtMatmul(
blasLt_, cuda_plan.op_desc(), alpha_ptr, a_ptr,
cuda_plan.a_remainder_desc(), b_ptr, cuda_plan.b_remainder_desc(),
beta_ptr, c_ptr, cuda_plan.c_remainder_desc(), d_ptr,
cuda_plan.d_remainder_desc(), remainder_algo.algo(), workspace,
remainder_algo.workspace_size(), cuda_stream);
if (ret != CUBLAS_STATUS_SUCCESS) {
if (err_on_failure || VLOG_IS_ON(3)) {
LOG(ERROR) << "failed to run remainder cublasLtMatmul routine: "
<< ToString(ret);
}
return false;
}
return false;
}
return true;
}

View File

@ -150,6 +150,13 @@ class CUDABlas : public blas::BlasSupport {
const blas::IBlasLtMatmulAlgorithm *algorithm,
DeviceMemoryBase bias);
// Helper function for implementing GetBlasLtMatmulAlgorithms.
port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
GetBlasLtMatmulAlgorithmsInternal(const blas::IBlasLtMatmulPlan *plan,
size_t max_workspace_size,
int max_algorithm_count,
bool for_remainder_batch = false);
// Guards the cuBLAS handle for this device.
absl::Mutex mu_;