Merge pull request #44175 from benbarsdell:cublaslt-large-batch-workaround-proper
PiperOrigin-RevId: 338087096 Change-Id: Iad9dd0c7f6a6d38043958eea83b0c1b1ad93b287
This commit is contained in:
commit
7e5dc369b8
@ -3257,67 +3257,103 @@ blas::ComputationType ToComputationType<std::complex<double>>() {
|
||||
|
||||
class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
|
||||
public:
|
||||
CUDABlasLtMatmulPlan(UniqueOpDesc op_desc, UniqueLayoutDesc a_desc,
|
||||
UniqueLayoutDesc b_desc, UniqueLayoutDesc c_desc,
|
||||
UniqueLayoutDesc d_desc, blas::DataType ab_type,
|
||||
blas::DataType c_type, blas::DataType scale_type,
|
||||
blas::PointerMode pointer_mode, blas::Epilogue epilogue,
|
||||
int batch_count, int64 stride_a, int64 stride_b,
|
||||
int64 stride_c, int64 stride_d)
|
||||
: op_desc_(std::move(op_desc)),
|
||||
a_desc_(std::move(a_desc)),
|
||||
b_desc_(std::move(b_desc)),
|
||||
c_desc_(std::move(c_desc)),
|
||||
d_desc_(std::move(d_desc)),
|
||||
ab_type_(ab_type),
|
||||
c_type_(c_type),
|
||||
scale_type_(scale_type),
|
||||
pointer_mode_(pointer_mode),
|
||||
epilogue_(epilogue),
|
||||
batch_count_(batch_count),
|
||||
stride_a_(stride_a),
|
||||
stride_b_(stride_b),
|
||||
stride_c_(stride_c),
|
||||
stride_d_(stride_d) {}
|
||||
port::Status init(const blas::BlasLtMatmulPlanParams &p) {
|
||||
params_ = p;
|
||||
scale_type_ = GetScaleType(p.c_type, p.computation_type);
|
||||
SE_ASSIGN_OR_RETURN(
|
||||
op_desc_,
|
||||
CreateCublasLtOperationDesc(
|
||||
p.computation_type, GetScaleType(p.c_type, p.computation_type),
|
||||
p.pointer_mode, p.epilogue, p.transa, p.transb));
|
||||
uint64 rows_a = p.transa == blas::Transpose::kNoTranspose ? p.m : p.k;
|
||||
uint64 cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m;
|
||||
uint64 rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n;
|
||||
uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k;
|
||||
SE_ASSIGN_OR_RETURN(
|
||||
a_desc_, CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda,
|
||||
p.stride_a, capped_batch_count()));
|
||||
SE_ASSIGN_OR_RETURN(
|
||||
b_desc_, CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb,
|
||||
p.stride_b, capped_batch_count()));
|
||||
SE_ASSIGN_OR_RETURN(
|
||||
c_desc_, CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
|
||||
capped_batch_count()));
|
||||
SE_ASSIGN_OR_RETURN(
|
||||
d_desc_, CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
|
||||
capped_batch_count()));
|
||||
remainder_batch_count_ =
|
||||
p.batch_count > kMaxBatchCount ? p.batch_count % kMaxBatchCount : 0;
|
||||
if (remainder_batch_count_) {
|
||||
SE_ASSIGN_OR_RETURN(
|
||||
a_remainder_desc_,
|
||||
CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda, p.stride_a,
|
||||
remainder_batch_count_));
|
||||
SE_ASSIGN_OR_RETURN(
|
||||
b_remainder_desc_,
|
||||
CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb, p.stride_b,
|
||||
remainder_batch_count_));
|
||||
SE_ASSIGN_OR_RETURN(
|
||||
c_remainder_desc_,
|
||||
CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
|
||||
remainder_batch_count_));
|
||||
SE_ASSIGN_OR_RETURN(
|
||||
d_remainder_desc_,
|
||||
CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
|
||||
remainder_batch_count_));
|
||||
}
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
||||
cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); }
|
||||
cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); }
|
||||
cublasLtMatrixLayout_t b_desc() const { return b_desc_.get(); }
|
||||
cublasLtMatrixLayout_t c_desc() const { return c_desc_.get(); }
|
||||
cublasLtMatrixLayout_t d_desc() const { return d_desc_.get(); }
|
||||
bool ok() { return op_desc_ && a_desc_ && b_desc_ && c_desc_ && d_desc_; }
|
||||
cublasLtMatrixLayout_t a_remainder_desc() const {
|
||||
return a_remainder_desc_.get();
|
||||
}
|
||||
cublasLtMatrixLayout_t b_remainder_desc() const {
|
||||
return b_remainder_desc_.get();
|
||||
}
|
||||
cublasLtMatrixLayout_t c_remainder_desc() const {
|
||||
return c_remainder_desc_.get();
|
||||
}
|
||||
cublasLtMatrixLayout_t d_remainder_desc() const {
|
||||
return d_remainder_desc_.get();
|
||||
}
|
||||
|
||||
blas::DataType ab_type() const override { return ab_type_; }
|
||||
blas::DataType c_type() const override { return c_type_; }
|
||||
const blas::BlasLtMatmulPlanParams ¶ms() const { return params_; }
|
||||
blas::DataType scale_type() const { return scale_type_; }
|
||||
blas::PointerMode pointer_mode() const { return pointer_mode_; }
|
||||
blas::Epilogue epilogue() const { return epilogue_; }
|
||||
int batch_count() const { return batch_count_; }
|
||||
int64 stride_a() const { return stride_a_; }
|
||||
int64 stride_b() const { return stride_b_; }
|
||||
int64 stride_c() const { return stride_c_; }
|
||||
int64 stride_d() const { return stride_d_; }
|
||||
blas::DataType ab_type() const override { return params_.ab_type; }
|
||||
blas::DataType c_type() const override { return params_.c_type; }
|
||||
int capped_batch_count() const {
|
||||
return std::min(params_.batch_count, kMaxBatchCount);
|
||||
}
|
||||
int remainder_batch_count() const { return remainder_batch_count_; }
|
||||
|
||||
// Note: Must be const to satisfy API. This is always called before the plan
|
||||
// is executed, so the state change is not observed in subsequent executions.
|
||||
bool SetBiasPointer(const void *bias) const;
|
||||
|
||||
private:
|
||||
// In some cases cublasLt does not support large batch sizes, so we need to
|
||||
// split up such cases into multiple calls.
|
||||
static constexpr const int kMaxBatchCount = 65535;
|
||||
blas::BlasLtMatmulPlanParams params_;
|
||||
blas::DataType scale_type_;
|
||||
UniqueOpDesc op_desc_;
|
||||
// These have batch count set to capped_batch_count().
|
||||
UniqueLayoutDesc a_desc_;
|
||||
UniqueLayoutDesc b_desc_;
|
||||
UniqueLayoutDesc c_desc_;
|
||||
UniqueLayoutDesc d_desc_;
|
||||
blas::DataType ab_type_;
|
||||
blas::DataType c_type_;
|
||||
blas::DataType scale_type_;
|
||||
blas::PointerMode pointer_mode_;
|
||||
blas::Epilogue epilogue_;
|
||||
int batch_count_;
|
||||
int64 stride_a_;
|
||||
int64 stride_b_;
|
||||
int64 stride_c_;
|
||||
int64 stride_d_;
|
||||
int remainder_batch_count_;
|
||||
// These have batch count set to remainder_batch_count_, and are only created
|
||||
// if params_.batch_count > kMaxBatchSize.
|
||||
UniqueLayoutDesc a_remainder_desc_;
|
||||
UniqueLayoutDesc b_remainder_desc_;
|
||||
UniqueLayoutDesc c_remainder_desc_;
|
||||
UniqueLayoutDesc d_remainder_desc_;
|
||||
};
|
||||
|
||||
bool CUDABlasLtMatmulPlan::SetBiasPointer(const void *bias) const {
|
||||
@ -3365,7 +3401,7 @@ port::StatusOr<UniqueMatmulPreference> CreateCublasLtMatmulPreference(
|
||||
max_workspace_bytes));
|
||||
|
||||
const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
|
||||
if (cuda_plan.batch_count() == 0) {
|
||||
if (cuda_plan.params().batch_count == 0) {
|
||||
return unique_preference;
|
||||
}
|
||||
// This is a workaround for a known issue in cuBlasLt where the heuristic may
|
||||
@ -3374,27 +3410,29 @@ port::StatusOr<UniqueMatmulPreference> CreateCublasLtMatmulPreference(
|
||||
auto get_alignment_bytes = [](int64 stride, blas::DataType dtype) {
|
||||
return (stride & -stride) * GetDataTypeSizeBytes(dtype);
|
||||
};
|
||||
if (cuda_plan.stride_a()) {
|
||||
SE_RETURN_IF_ERROR(
|
||||
SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES,
|
||||
(uint32)get_alignment_bytes(cuda_plan.stride_a(),
|
||||
cuda_plan.ab_type())));
|
||||
if (cuda_plan.params().stride_a) {
|
||||
SE_RETURN_IF_ERROR(SetCublasLtAttr(
|
||||
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES,
|
||||
(uint32)get_alignment_bytes(cuda_plan.params().stride_a,
|
||||
cuda_plan.params().ab_type)));
|
||||
}
|
||||
if (cuda_plan.stride_b()) {
|
||||
SE_RETURN_IF_ERROR(
|
||||
SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES,
|
||||
(uint32)get_alignment_bytes(cuda_plan.stride_b(),
|
||||
cuda_plan.ab_type())));
|
||||
if (cuda_plan.params().stride_b) {
|
||||
SE_RETURN_IF_ERROR(SetCublasLtAttr(
|
||||
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES,
|
||||
(uint32)get_alignment_bytes(cuda_plan.params().stride_b,
|
||||
cuda_plan.params().ab_type)));
|
||||
}
|
||||
if (cuda_plan.stride_c()) {
|
||||
if (cuda_plan.params().stride_c) {
|
||||
SE_RETURN_IF_ERROR(SetCublasLtAttr(
|
||||
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES,
|
||||
(uint32)get_alignment_bytes(cuda_plan.stride_c(), cuda_plan.c_type())));
|
||||
(uint32)get_alignment_bytes(cuda_plan.params().stride_c,
|
||||
cuda_plan.params().c_type)));
|
||||
}
|
||||
if (cuda_plan.stride_d()) {
|
||||
if (cuda_plan.params().stride_c) {
|
||||
SE_RETURN_IF_ERROR(SetCublasLtAttr(
|
||||
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES,
|
||||
(uint32)get_alignment_bytes(cuda_plan.stride_d(), cuda_plan.c_type())));
|
||||
(uint32)get_alignment_bytes(cuda_plan.params().stride_c,
|
||||
cuda_plan.params().c_type)));
|
||||
}
|
||||
return unique_preference;
|
||||
}
|
||||
@ -3406,35 +3444,10 @@ port::StatusOr<UniqueMatmulPreference> CreateCublasLtMatmulPreference(
|
||||
port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
|
||||
CUDABlas::CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &p) {
|
||||
#if CUDA_VERSION >= 11000
|
||||
SE_ASSIGN_OR_RETURN(
|
||||
auto op_desc,
|
||||
CreateCublasLtOperationDesc(
|
||||
p.computation_type, GetScaleType(p.c_type, p.computation_type),
|
||||
p.pointer_mode, p.epilogue, p.transa, p.transb));
|
||||
uint64 rows_a = p.transa == blas::Transpose::kNoTranspose ? p.m : p.k;
|
||||
uint64 cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m;
|
||||
uint64 rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n;
|
||||
uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k;
|
||||
SE_ASSIGN_OR_RETURN(auto a_desc,
|
||||
CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda,
|
||||
p.stride_a, p.batch_count));
|
||||
SE_ASSIGN_OR_RETURN(auto b_desc,
|
||||
CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb,
|
||||
p.stride_b, p.batch_count));
|
||||
SE_ASSIGN_OR_RETURN(auto c_desc,
|
||||
CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc,
|
||||
p.stride_c, p.batch_count));
|
||||
SE_ASSIGN_OR_RETURN(auto d_desc,
|
||||
CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc,
|
||||
p.stride_c, p.batch_count));
|
||||
blas::DataType scale_type = GetScaleType(p.c_type, p.computation_type);
|
||||
|
||||
auto cuda_plan = std::make_unique<CUDABlasLtMatmulPlan>();
|
||||
SE_RETURN_IF_ERROR(cuda_plan->init(p));
|
||||
return static_cast<std::unique_ptr<blas::IBlasLtMatmulPlan>>(
|
||||
std::make_unique<CUDABlasLtMatmulPlan>(
|
||||
std::move(op_desc), std::move(a_desc), std::move(b_desc),
|
||||
std::move(c_desc), std::move(d_desc), p.ab_type, p.c_type, scale_type,
|
||||
p.pointer_mode, p.epilogue, p.batch_count, p.stride_a, p.stride_b,
|
||||
p.stride_c, p.stride_c));
|
||||
std::move(cuda_plan));
|
||||
#else
|
||||
return port::Status(
|
||||
port::error::UNIMPLEMENTED,
|
||||
@ -3443,9 +3456,10 @@ CUDABlas::CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &p) {
|
||||
}
|
||||
|
||||
port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
|
||||
CUDABlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
|
||||
size_t max_workspace_size,
|
||||
int max_algorithm_count) {
|
||||
CUDABlas::GetBlasLtMatmulAlgorithmsInternal(const blas::IBlasLtMatmulPlan *plan,
|
||||
size_t max_workspace_size,
|
||||
int max_algorithm_count,
|
||||
bool for_remainder_batch) {
|
||||
#if CUDA_VERSION >= 11000
|
||||
SE_ASSIGN_OR_RETURN(UniqueMatmulPreference preference,
|
||||
CreateCublasLtMatmulPreference(plan, max_workspace_size));
|
||||
@ -3460,10 +3474,18 @@ CUDABlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
|
||||
|
||||
int found_algorithm_count = 0;
|
||||
const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
|
||||
const auto &a_desc =
|
||||
for_remainder_batch ? cuda_plan.a_remainder_desc() : cuda_plan.a_desc();
|
||||
const auto &b_desc =
|
||||
for_remainder_batch ? cuda_plan.b_remainder_desc() : cuda_plan.b_desc();
|
||||
const auto &c_desc =
|
||||
for_remainder_batch ? cuda_plan.c_remainder_desc() : cuda_plan.c_desc();
|
||||
const auto &d_desc =
|
||||
for_remainder_batch ? cuda_plan.d_remainder_desc() : cuda_plan.d_desc();
|
||||
cublasStatus_t status = cublasLtMatmulAlgoGetHeuristic(
|
||||
blasLt_, cuda_plan.op_desc(), cuda_plan.a_desc(), cuda_plan.b_desc(),
|
||||
cuda_plan.c_desc(), cuda_plan.d_desc(), preference.get(),
|
||||
max_algorithm_count, results.data(), &found_algorithm_count);
|
||||
blasLt_, cuda_plan.op_desc(), a_desc, b_desc, c_desc, d_desc,
|
||||
preference.get(), max_algorithm_count, results.data(),
|
||||
&found_algorithm_count);
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
return port::Status(
|
||||
port::error::INTERNAL,
|
||||
@ -3489,6 +3511,14 @@ CUDABlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
|
||||
#endif
|
||||
}
|
||||
|
||||
port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
|
||||
CUDABlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
|
||||
size_t max_workspace_size,
|
||||
int max_algorithm_count) {
|
||||
return GetBlasLtMatmulAlgorithmsInternal(plan, max_workspace_size,
|
||||
max_algorithm_count);
|
||||
}
|
||||
|
||||
#if CUDA_VERSION >= 11000
|
||||
bool CUDABlas::DoBlasLtMatmulInternal(
|
||||
Stream *stream, bool err_on_failure, const blas::IBlasLtMatmulPlan *plan,
|
||||
@ -3514,26 +3544,19 @@ bool CUDABlas::DoBlasLtMatmulInternal(
|
||||
return false;
|
||||
}
|
||||
bool is_pointer_mode_host = !alpha.is_pointer();
|
||||
if ((cuda_plan.pointer_mode() == blas::PointerMode::kHost) !=
|
||||
if ((cuda_plan.params().pointer_mode == blas::PointerMode::kHost) !=
|
||||
is_pointer_mode_host) {
|
||||
VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong "
|
||||
"pointer_mode for the given alpha/beta.";
|
||||
return false;
|
||||
}
|
||||
if ((cuda_plan.epilogue() == blas::Epilogue::kBias ||
|
||||
cuda_plan.epilogue() == blas::Epilogue::kBiasThenReLU) !=
|
||||
if ((cuda_plan.params().epilogue == blas::Epilogue::kBias ||
|
||||
cuda_plan.params().epilogue == blas::Epilogue::kBiasThenReLU) !=
|
||||
(bias != nullptr)) {
|
||||
VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong "
|
||||
"epilogue for the given bias pointer.";
|
||||
return false;
|
||||
}
|
||||
if (bias != nullptr) {
|
||||
if (!cuda_plan.SetBiasPointer(bias.opaque())) {
|
||||
VLOG(2) << "DoBlasLtMatmul returning false because setting the bias "
|
||||
"pointer failed.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
const void *alpha_ptr = alpha.is_pointer() ? alpha.opaque_pointer().opaque()
|
||||
: alpha.opaque_value();
|
||||
const void *beta_ptr =
|
||||
@ -3554,24 +3577,106 @@ bool CUDABlas::DoBlasLtMatmulInternal(
|
||||
}
|
||||
}
|
||||
|
||||
// This is only used when batch_count > kMaxBatchCount.
|
||||
std::unique_ptr<blas::IBlasLtMatmulAlgorithm> unique_remainder_algo;
|
||||
if (cuda_plan.remainder_batch_count()) {
|
||||
// There is no easy way to get the user-specified max workspace size here,
|
||||
// so we just allow a very small amount and don't worry too much about
|
||||
// performance because this is only used in rare cases. The same reasoning
|
||||
// applies to selection of the algorithm.
|
||||
size_t max_workspace_size = 4 * 1024 * 1024; // 4 MiB
|
||||
auto status_or_algorithms =
|
||||
GetBlasLtMatmulAlgorithmsInternal(plan, max_workspace_size,
|
||||
/* max_algorithm_count = */ 1,
|
||||
/* for_remainder_batch = */ true);
|
||||
if (!status_or_algorithms.ok()) {
|
||||
if (err_on_failure || VLOG_IS_ON(3)) {
|
||||
LOG(ERROR) << "Failed to get algorithms for blasLt remainder batch.";
|
||||
}
|
||||
return false;
|
||||
}
|
||||
auto algorithms = status_or_algorithms.ConsumeValueOrDie();
|
||||
unique_remainder_algo = std::move(algorithms.front());
|
||||
}
|
||||
|
||||
cudaStream_t cuda_stream = CUDAStream(stream);
|
||||
|
||||
absl::MutexLock lock(&mu_);
|
||||
|
||||
if (bias != nullptr) {
|
||||
if (!cuda_plan.SetBiasPointer(bias.opaque())) {
|
||||
VLOG(2) << "DoBlasLtMatmul returning false because setting the bias "
|
||||
"pointer failed.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
CHECK(blasLt_ != nullptr);
|
||||
|
||||
gpu::ScopedActivateExecutorContext sac{parent_};
|
||||
|
||||
cublasStatus_t ret = cublasLtMatmul(
|
||||
blasLt_, cuda_plan.op_desc(), alpha_ptr, a.opaque(), cuda_plan.a_desc(),
|
||||
b.opaque(), cuda_plan.b_desc(), beta_ptr, c.opaque(), cuda_plan.c_desc(),
|
||||
d.opaque(), cuda_plan.d_desc(), cuda_algo.algo(), workspace,
|
||||
cuda_algo.workspace_size(), cuda_stream);
|
||||
if (ret != CUBLAS_STATUS_SUCCESS) {
|
||||
if (err_on_failure || VLOG_IS_ON(3)) {
|
||||
LOG(ERROR) << "failed to run cublasLtMatmul routine: " << ToString(ret);
|
||||
// Plan execution is broken down into repeat calls with capped_batch_count,
|
||||
// followed by a final call with remainder_batch_count.
|
||||
// Cases where batch_count <= kMaxBatchCount require only a single call (a
|
||||
// single loop iteration and no remainder).
|
||||
int ab_type_size = GetDataTypeSizeBytes(cuda_plan.params().ab_type);
|
||||
int c_type_size = GetDataTypeSizeBytes(cuda_plan.params().c_type);
|
||||
const char *a_ptr = static_cast<const char *>(a.opaque());
|
||||
const char *b_ptr = static_cast<const char *>(b.opaque());
|
||||
const char *c_ptr = static_cast<const char *>(c.opaque());
|
||||
char *d_ptr = static_cast<char *>(d.opaque());
|
||||
int capped_batch_count = cuda_plan.capped_batch_count();
|
||||
for (int batch = 0;
|
||||
batch + capped_batch_count <= cuda_plan.params().batch_count;
|
||||
batch += capped_batch_count) {
|
||||
cublasStatus_t ret = cublasLtMatmul(
|
||||
blasLt_, cuda_plan.op_desc(), alpha_ptr, a_ptr, cuda_plan.a_desc(),
|
||||
b_ptr, cuda_plan.b_desc(), beta_ptr, c_ptr, cuda_plan.c_desc(), d_ptr,
|
||||
cuda_plan.d_desc(), cuda_algo.algo(), workspace,
|
||||
cuda_algo.workspace_size(), cuda_stream);
|
||||
if (ret != CUBLAS_STATUS_SUCCESS) {
|
||||
if (err_on_failure || VLOG_IS_ON(3)) {
|
||||
LOG(ERROR) << "failed to run cublasLtMatmul routine: " << ToString(ret);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
a_ptr += capped_batch_count * cuda_plan.params().stride_a * ab_type_size;
|
||||
b_ptr += capped_batch_count * cuda_plan.params().stride_b * ab_type_size;
|
||||
c_ptr += capped_batch_count * cuda_plan.params().stride_c * c_type_size;
|
||||
d_ptr += capped_batch_count * cuda_plan.params().stride_c * c_type_size;
|
||||
}
|
||||
// This is only used when batch_count > kMaxBatchCount.
|
||||
if (cuda_plan.remainder_batch_count()) {
|
||||
const auto &remainder_algo =
|
||||
*static_cast<const CUDABlasLtMatmulAlgorithm *>(
|
||||
unique_remainder_algo.get());
|
||||
if (remainder_algo.workspace_size()) {
|
||||
port::Status allocation_status = AllocateWorkspace(
|
||||
&workspace, scratch_allocator, remainder_algo.workspace_size());
|
||||
if (!allocation_status.ok()) {
|
||||
if (err_on_failure || VLOG_IS_ON(3)) {
|
||||
LOG(ERROR) << "Failed to allocate workspace for cublasLtMatmul algo "
|
||||
"with id: "
|
||||
<< remainder_algo.algo_id() << " requiring "
|
||||
<< remainder_algo.workspace_size()
|
||||
<< " bytes of workspace";
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
cublasStatus_t ret = cublasLtMatmul(
|
||||
blasLt_, cuda_plan.op_desc(), alpha_ptr, a_ptr,
|
||||
cuda_plan.a_remainder_desc(), b_ptr, cuda_plan.b_remainder_desc(),
|
||||
beta_ptr, c_ptr, cuda_plan.c_remainder_desc(), d_ptr,
|
||||
cuda_plan.d_remainder_desc(), remainder_algo.algo(), workspace,
|
||||
remainder_algo.workspace_size(), cuda_stream);
|
||||
if (ret != CUBLAS_STATUS_SUCCESS) {
|
||||
if (err_on_failure || VLOG_IS_ON(3)) {
|
||||
LOG(ERROR) << "failed to run remainder cublasLtMatmul routine: "
|
||||
<< ToString(ret);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -150,6 +150,13 @@ class CUDABlas : public blas::BlasSupport {
|
||||
const blas::IBlasLtMatmulAlgorithm *algorithm,
|
||||
DeviceMemoryBase bias);
|
||||
|
||||
// Helper function for implementing GetBlasLtMatmulAlgorithms.
|
||||
port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
|
||||
GetBlasLtMatmulAlgorithmsInternal(const blas::IBlasLtMatmulPlan *plan,
|
||||
size_t max_workspace_size,
|
||||
int max_algorithm_count,
|
||||
bool for_remainder_batch = false);
|
||||
|
||||
// Guards the cuBLAS handle for this device.
|
||||
absl::Mutex mu_;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user