Refactor blasLt APIs to return Status, not bool

- This does not include DoBlasLtMatmul because the helpers in
  stream.cc require it to return bool.
This commit is contained in:
Ben Barsdell 2020-10-05 21:17:18 +11:00
parent a3dfb6f366
commit c491ca455c
5 changed files with 227 additions and 201 deletions

View File

@ -469,6 +469,17 @@ struct CoefficientType<Eigen::half> {
typedef float type;
};
// Converts a StreamExecutor status (se::port::Status) into a TensorFlow
// Status. The error code is mapped by its underlying integer value (the two
// enums are assumed to share numeric values -- TODO confirm) and the error
// message is copied through unchanged. Ok statuses map to Status::OK().
inline Status FromExecutorStatus(const se::port::Status& s) {
return s.ok() ? Status::OK()
: Status(static_cast<error::Code>(static_cast<int>(s.code())),
s.error_message());
}
// Overload for StatusOr<T>: converts only the embedded status. Any contained
// value is ignored, so this is intended for the error-reporting path.
template <typename T>
inline Status FromExecutorStatus(const se::port::StatusOr<T>& s) {
return FromExecutorStatus(s.status());
}
} // namespace
template <typename Scalar>
@ -554,38 +565,25 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
context,
GetBlasComputationType(dtype, allow_tf32, &computation_type),
errors::Internal("Unsupported dtype for batched matmul"));
auto status_or_plan = stream->parent()->CreateBlasLtMatmulPlan(
{/*ab_type=*/blas_dtype,
/*c_type=*/blas_dtype, computation_type,
se::blas::PointerMode::kHost, se::blas::Epilogue::kDefault,
blas_transpose_b, blas_transpose_a, n, m, k,
/*lda=*/in_y.dim_size(2), /*ldb=*/in_x.dim_size(2), /*ldc=*/n,
batch_size, b_stride, a_stride, c_stride});
OP_REQUIRES(context, status_or_plan.ok(),
FromExecutorStatus(status_or_plan));
std::unique_ptr<se::blas::IBlasLtMatmulPlan> plan =
stream->parent()->CreateBlasLtMatmulPlan(
{/*ab_type=*/blas_dtype,
/*c_type=*/blas_dtype, computation_type,
se::blas::PointerMode::kHost, se::blas::Epilogue::kDefault,
blas_transpose_b, blas_transpose_a, n, m, k,
/*lda=*/in_y.dim_size(2), /*ldb=*/in_x.dim_size(2), /*ldc=*/n,
batch_size, b_stride, a_stride, c_stride});
OP_REQUIRES(
context, plan,
errors::Internal("CreateBlasLtMatmulPlan failed : a.shape=(",
in_x.dim_size(0), ", ", in_x.dim_size(1), ", ",
in_x.dim_size(2), "), b.shape=(", in_y.dim_size(0),
", ", in_y.dim_size(1), ", ", in_y.dim_size(2),
"), m=", m, ", n=", n, ", k=", k,
", batch_size=", batch_size, ", adjoint_a=", adj_x,
", adjoint_b=", adj_x, ", dtype=", dtype,
", computation_type=", computation_type));
std::vector<std::unique_ptr<se::blas::IBlasLtMatmulAlgorithm>>
algorithms;
OP_REQUIRES(
context,
stream->parent()->GetBlasLtMatmulAlgorithms(
plan.get(), max_scratch_size, max_algorithm_count, &algorithms),
errors::Internal("GetBlasLtMatmulAlgorithms failed: a.shape=(",
in_x.dim_size(0), ", ", in_x.dim_size(1), ", ",
in_x.dim_size(2), "), b.shape=(", in_y.dim_size(0),
", ", in_y.dim_size(1), ", ", in_y.dim_size(2),
"), m=", m, ", n=", n, ", k=", k,
", batch_size=", batch_size, ", adjoint_a=", adj_x,
", adjoint_b=", adj_x, ", dtype=", dtype,
", computation_type=", computation_type));
status_or_plan.ConsumeValueOrDie();
auto status_or_algorithms = stream->parent()->GetBlasLtMatmulAlgorithms(
plan.get(), max_scratch_size, max_algorithm_count);
OP_REQUIRES(context, status_or_algorithms.ok(),
FromExecutorStatus(status_or_algorithms));
auto algorithms = status_or_algorithms.ConsumeValueOrDie();
plan_and_algorithms =
BatchMatmulPlanMapSingleton::GetInstance()->Insert(
matmul_parameters, {std::move(plan), std::move(algorithms)});

View File

@ -1454,19 +1454,18 @@ class BlasSupport {
// Creates a backend-specific plan object for a blaslt matmul operation, which
// can then be passed to DoBlasLtMatmul(). When possible, plans should be
// created once and reused for multiple calls to DoBlasLtMatmul().
// Returns a null pointer on failure.
virtual std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
const blas::BlasLtMatmulPlanParams& params) = 0;
virtual port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams& params) = 0;
// Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
// returned in the order of increasing estimated compute time according to an
// internal heuristic. The first returned algorithm can be used as the default
// algorithm if no autotuning is to be performed.
virtual bool GetBlasLtMatmulAlgorithms(
const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size,
int max_algorithm_count,
std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>*
out_algorithms) = 0;
virtual port::StatusOr<
std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan* plan,
size_t max_workspace_size,
int max_algorithm_count) = 0;
// Executes a blaslt matmul operation on the stream. If output_profile_result
// is not nullptr, the operation is profiled, error messages are
@ -2330,13 +2329,12 @@ class BlasSupport {
uint64 n, std::complex<double> alpha, \
const DeviceMemory<std::complex<double>> &a, int lda, \
DeviceMemory<std::complex<double>> *b, int ldb) override; \
std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan( \
const blas::BlasLtMatmulPlanParams& params) override; \
bool GetBlasLtMatmulAlgorithms( \
const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size, \
int max_algorithm_count, \
std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>* \
out_algorithms) override; \
port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>> \
CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams& params) override; \
port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>> \
GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan* plan, \
size_t max_workspace_size, \
int max_algorithm_count) override; \
bool DoBlasLtMatmul( \
Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
const HostOrDeviceScalar<void>& alpha, DeviceMemoryBase a, \

View File

@ -3057,45 +3057,48 @@ bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
namespace {
template <typename T>
inline bool SetCublasLtAttr(cublasLtMatrixLayout_t handle,
cublasLtMatrixLayoutAttribute_t attr,
const T& value) {
inline port::Status SetCublasLtAttr(cublasLtMatrixLayout_t handle,
cublasLtMatrixLayoutAttribute_t attr,
const T& value) {
cublasStatus_t status =
cublasLtMatrixLayoutSetAttribute(handle, attr, &value, sizeof(T));
if (status != CUBLAS_STATUS_SUCCESS) {
VLOG(2) << "cublasLtMatrixLayoutSetAttribute(attr=" << attr
<< ", value=" << value << ") failed: " << ToString(status);
return false;
return port::Status(
port::error::INTERNAL,
absl::StrCat("cublasLtMatrixLayoutSetAttribute(attr=", attr,
", value=", value, ") failed: ", ToString(status)));
}
return true;
return port::Status::OK();
}
template <typename T>
inline bool SetCublasLtAttr(cublasLtMatmulAlgo_t* handle,
cublasLtMatmulAlgoConfigAttributes_t attr,
const T& value) {
inline port::Status SetCublasLtAttr(cublasLtMatmulAlgo_t* handle,
cublasLtMatmulAlgoConfigAttributes_t attr,
const T& value) {
cublasStatus_t status =
cublasLtMatmulAlgoConfigSetAttribute(handle, attr, &value, sizeof(T));
if (status != CUBLAS_STATUS_SUCCESS) {
VLOG(2) << "cublasLtMatmulAlgoConfigSetAttribute(attr=" << attr
<< ", value=" << value << ") failed: " << ToString(status);
return false;
return port::Status(
port::error::INTERNAL,
absl::StrCat("cublasLtMatmulAlgoConfigSetAttribute(attr=", attr,
", value=", value, ") failed: ", ToString(status)));
}
return true;
return port::Status::OK();
}
template <typename T>
inline bool SetCublasLtAttr(cublasLtMatmulPreference_t handle,
cublasLtMatmulPreferenceAttributes_t attr,
const T& value) {
inline port::Status SetCublasLtAttr(cublasLtMatmulPreference_t handle,
cublasLtMatmulPreferenceAttributes_t attr,
const T& value) {
cublasStatus_t status =
cublasLtMatmulPreferenceSetAttribute(handle, attr, &value, sizeof(value));
if (status != CUBLAS_STATUS_SUCCESS) {
VLOG(2) << "cublasLtMatmulPreferenceSetAttribute(attr=" << attr
<< ", value=" << value << ") failed: " << ToString(status);
return false;
return port::Status(
port::error::INTERNAL,
absl::StrCat("cublasLtMatmulPreferenceSetAttribute(attr=", attr,
", value=", value, ") failed: ", ToString(status)));
}
return true;
return port::Status::OK();
}
template <typename T>
@ -3111,17 +3114,27 @@ inline bool GetCublasLtAttr(const cublasLtMatmulAlgo_t* handle,
}
template <typename T>
inline bool SetCublasLtAttr(cublasLtMatmulDesc_t handle,
cublasLtMatmulDescAttributes_t attr,
const T& value) {
inline const T& ValueForStrCat(const T& value) {
return value;
}
template <typename T>
inline absl::Hex ValueForStrCat(T* ptr) {
return absl::Hex(reinterpret_cast<uintptr_t>(ptr));
}
template <typename T>
inline port::Status SetCublasLtAttr(cublasLtMatmulDesc_t handle,
cublasLtMatmulDescAttributes_t attr,
const T& value) {
cublasStatus_t status =
cublasLtMatmulDescSetAttribute(handle, attr, &value, sizeof(value));
if (status != CUBLAS_STATUS_SUCCESS) {
VLOG(2) << "cublasLtMatmulDescSetAttribute(attr=" << attr
<< ", value=" << value << ") failed: " << ToString(status);
return false;
return port::Status(
port::error::INTERNAL,
absl::StrCat("cublasLtMatmulDescSetAttribute(attr=", attr, ", value=",
ValueForStrCat(value), ") failed: ", ToString(status)));
}
return true;
return port::Status::OK();
}
struct MatmulDescDestroyer {
@ -3149,12 +3162,10 @@ using UniqueMatmulPreference =
std::unique_ptr<std::remove_pointer<cublasLtMatmulPreference_t>::type,
MatmulPreferenceDestroyer>;
UniqueOpDesc CreateCublasLtOperationDesc(blas::ComputationType computation_type,
blas::DataType scale_type,
blas::PointerMode pointer_mode,
blas::Epilogue epilogue,
blas::Transpose transa,
blas::Transpose transb) {
port::StatusOr<UniqueOpDesc> CreateCublasLtOperationDesc(
blas::ComputationType computation_type, blas::DataType scale_type,
blas::PointerMode pointer_mode, blas::Epilogue epilogue,
blas::Transpose transa, blas::Transpose transb) {
cublasLtMatmulDesc_t desc;
cublasComputeType_t cublas_compute_type =
CUBLASComputationType(computation_type);
@ -3162,40 +3173,39 @@ UniqueOpDesc CreateCublasLtOperationDesc(blas::ComputationType computation_type,
cublasStatus_t status =
cublasLtMatmulDescCreate(&desc, cublas_compute_type, cuda_scale_type);
if (status != CUBLAS_STATUS_SUCCESS) {
VLOG(2) << "cublasLtMatmulDescCreate(computation_type=" << computation_type
<< ") failed: " << ToString(status);
return nullptr;
return port::Status(
port::error::INTERNAL,
absl::StrCat("cublasLtMatmulDescCreate(computation_type=",
computation_type, ") failed: ", ToString(status)));
}
UniqueOpDesc unique_desc(desc);
if (!SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_POINTER_MODE,
CUBLASPointerMode(pointer_mode)) ||
!SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
CUBLASEpilogue(epilogue)) ||
!SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSA,
CUDABlasTranspose(transa)) ||
!SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSB,
CUDABlasTranspose(transb))) {
return nullptr;
}
SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_POINTER_MODE,
CUBLASPointerMode(pointer_mode)));
SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
CUBLASEpilogue(epilogue)));
SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSA,
CUDABlasTranspose(transa)));
SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSB,
CUDABlasTranspose(transb)));
return unique_desc;
}
UniqueLayoutDesc CreateCublasLtLayoutDesc(blas::DataType data_type, uint64 rows,
uint64 cols, int64 ld, int64 stride,
int batch_count) {
port::StatusOr<UniqueLayoutDesc> CreateCublasLtLayoutDesc(
blas::DataType data_type, uint64 rows, uint64 cols, int64 ld, int64 stride,
int batch_count) {
cublasLtMatrixLayout_t desc;
cublasStatus_t status = cublasLtMatrixLayoutCreate(
&desc, GetCUDADataType(data_type), rows, cols, ld);
if (status != CUBLAS_STATUS_SUCCESS) {
VLOG(2) << "cublasLtMatrixLayoutCreate failed: " << ToString(status);
return nullptr;
return port::Status(
port::error::INTERNAL,
absl::StrCat("cublasLtMatrixLayoutCreate failed: ", ToString(status)));
}
UniqueLayoutDesc unique_desc(desc);
if (!SetCublasLtAttr(desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_count) ||
!SetCublasLtAttr(desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
stride)) {
return nullptr;
}
SE_RETURN_IF_ERROR(
SetCublasLtAttr(desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_count));
SE_RETURN_IF_ERROR(SetCublasLtAttr(
desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride));
return unique_desc;
}
@ -3234,7 +3244,28 @@ blas::ComputationType ToComputationType<std::complex<double>>() {
class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
public:
CUDABlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams& params);
// Constructs a plan from pre-built cuBLASLt descriptors plus the dtype,
// pointer-mode, epilogue, batch, and stride parameters the matmul will use.
// Takes ownership of the operation descriptor (op_desc) and the four matrix
// layout descriptors (a/b/c/d); makes no cuBLASLt calls itself -- all
// descriptor creation/validation is expected to happen before this point,
// so construction cannot fail.
CUDABlasLtMatmulPlan(UniqueOpDesc op_desc, UniqueLayoutDesc a_desc,
UniqueLayoutDesc b_desc, UniqueLayoutDesc c_desc,
UniqueLayoutDesc d_desc, blas::DataType ab_type,
blas::DataType c_type, blas::DataType scale_type,
blas::PointerMode pointer_mode, blas::Epilogue epilogue,
int batch_count, int64 stride_a, int64 stride_b,
int64 stride_c, int64 stride_d)
: op_desc_(std::move(op_desc)),
a_desc_(std::move(a_desc)),
b_desc_(std::move(b_desc)),
c_desc_(std::move(c_desc)),
d_desc_(std::move(d_desc)),
ab_type_(ab_type),
c_type_(c_type),
scale_type_(scale_type),
pointer_mode_(pointer_mode),
epilogue_(epilogue),
batch_count_(batch_count),
stride_a_(stride_a),
stride_b_(stride_b),
stride_c_(stride_c),
stride_d_(stride_d) {}
cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); }
cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); }
@ -3276,40 +3307,9 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
int64 stride_d_;
};
CUDABlasLtMatmulPlan::CUDABlasLtMatmulPlan(
const blas::BlasLtMatmulPlanParams& p)
: op_desc_(CreateCublasLtOperationDesc(
p.computation_type, GetScaleType(p.c_type, p.computation_type),
p.pointer_mode, p.epilogue, p.transa, p.transb)),
a_desc_(nullptr),
b_desc_(nullptr),
c_desc_(CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
p.batch_count)),
d_desc_(CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
p.batch_count)),
ab_type_(p.ab_type),
cd_type_(p.c_type),
scale_type_(GetScaleType(p.c_type, p.computation_type)),
pointer_mode_(p.pointer_mode),
epilogue_(p.epilogue),
batch_count_(p.batch_count),
stride_a_(p.stride_a),
stride_b_(p.stride_b),
stride_c_(p.stride_c),
stride_d_(p.stride_c) {
uint64 rows_a = p.transa == blas::Transpose::kNoTranspose ? p.m : p.k;
uint64 cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m;
uint64 rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n;
uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k;
a_desc_ = CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda,
p.stride_a, p.batch_count);
b_desc_ = CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb,
p.stride_b, p.batch_count);
}
bool CUDABlasLtMatmulPlan::SetBiasPointer(const void* bias) const {
return SetCublasLtAttr(op_desc_.get(), CUBLASLT_MATMUL_DESC_BIAS_POINTER,
bias);
bias).ok();
}
class CUDABlasLtMatmulAlgorithm final : public blas::IBlasLtMatmulAlgorithm {
@ -3336,20 +3336,19 @@ class CUDABlasLtMatmulAlgorithm final : public blas::IBlasLtMatmulAlgorithm {
size_t workspace_size_;
};
UniqueMatmulPreference CreateCublasLtMatmulPreference(
const blas::IBlasLtMatmulPlan* plan,
size_t max_workspace_bytes) {
port::StatusOr<UniqueMatmulPreference> CreateCublasLtMatmulPreference(
const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_bytes) {
cublasLtMatmulPreference_t preference;
cublasStatus_t status = cublasLtMatmulPreferenceCreate(&preference);
if (status != CUBLAS_STATUS_SUCCESS) {
VLOG(2) << "cublasLtMatmulPreferenceCreate failed: " << ToString(status);
return nullptr;
return port::Status(port::error::INTERNAL,
absl::StrCat("cublasLtMatmulPreferenceCreate failed: ",
ToString(status)));
}
UniqueMatmulPreference unique_preference(preference);
if (!SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
max_workspace_bytes)) {
return nullptr;
}
SE_RETURN_IF_ERROR(SetCublasLtAttr(preference,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
max_workspace_bytes));
const auto& cuda_plan = *static_cast<const CUDABlasLtMatmulPlan*>(plan);
if (cuda_plan.batch_count() == 0) {
@ -3361,25 +3360,28 @@ UniqueMatmulPreference CreateCublasLtMatmulPreference(
auto get_alignment_bytes = [](int64 stride, blas::DataType dtype) {
return (stride & -stride) * GetDataTypeSizeBytes(dtype);
};
if ((cuda_plan.stride_a() &&
!SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES,
if (cuda_plan.stride_a()) {
SE_RETURN_IF_ERROR(
SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES,
(uint32)get_alignment_bytes(cuda_plan.stride_a(),
cuda_plan.ab_type()))) ||
(cuda_plan.stride_b() &&
!SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES,
(uint32)get_alignment_bytes(cuda_plan.stride_b(),
cuda_plan.ab_type()))) ||
(cuda_plan.stride_c() &&
!SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES,
(uint32)get_alignment_bytes(cuda_plan.stride_c(),
cuda_plan.cd_type()))) ||
(cuda_plan.stride_d() &&
!SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES,
(uint32)get_alignment_bytes(cuda_plan.stride_d(),
cuda_plan.cd_type())))) {
return nullptr;
cuda_plan.ab_type())));
}
if (cuda_plan.stride_b()) {
SE_RETURN_IF_ERROR(
SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES,
(uint32)get_alignment_bytes(cuda_plan.stride_b(),
cuda_plan.ab_type())));
}
if (cuda_plan.stride_c()) {
SE_RETURN_IF_ERROR(SetCublasLtAttr(
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES,
(uint32)get_alignment_bytes(cuda_plan.stride_c(), cuda_plan.c_type())));
}
if (cuda_plan.stride_d()) {
SE_RETURN_IF_ERROR(SetCublasLtAttr(
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES,
(uint32)get_alignment_bytes(cuda_plan.stride_d(), cuda_plan.c_type())));
}
return unique_preference;
}
@ -3387,28 +3389,50 @@ UniqueMatmulPreference CreateCublasLtMatmulPreference(
#endif // CUDA_VERSION >= 11000
std::unique_ptr<blas::IBlasLtMatmulPlan> CUDABlas::CreateBlasLtMatmulPlan(
const blas::BlasLtMatmulPlanParams& params) {
port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
CUDABlas::CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams& p) {
#if CUDA_VERSION >= 11000
auto result = std::make_unique<CUDABlasLtMatmulPlan>(params);
if (!result->ok()) {
result.reset();
}
return result;
SE_ASSIGN_OR_RETURN(
auto op_desc,
CreateCublasLtOperationDesc(
p.computation_type, GetScaleType(p.c_type, p.computation_type),
p.pointer_mode, p.epilogue, p.transa, p.transb));
uint64 rows_a = p.transa == blas::Transpose::kNoTranspose ? p.m : p.k;
uint64 cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m;
uint64 rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n;
uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k;
SE_ASSIGN_OR_RETURN(auto a_desc,
CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda,
p.stride_a, p.batch_count));
SE_ASSIGN_OR_RETURN(auto b_desc,
CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb,
p.stride_b, p.batch_count));
SE_ASSIGN_OR_RETURN(auto c_desc,
CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc,
p.stride_c, p.batch_count));
SE_ASSIGN_OR_RETURN(auto d_desc,
CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc,
p.stride_c, p.batch_count));
blas::DataType scale_type = GetScaleType(p.c_type, p.computation_type);
return static_cast<std::unique_ptr<blas::IBlasLtMatmulPlan>>(
std::make_unique<CUDABlasLtMatmulPlan>(
std::move(op_desc), std::move(a_desc), std::move(b_desc),
std::move(c_desc), std::move(d_desc), p.ab_type, p.c_type, scale_type,
p.pointer_mode, p.epilogue, p.batch_count, p.stride_a, p.stride_b,
p.stride_c, p.stride_c));
#else
return nullptr;
#endif
}
bool CUDABlas::GetBlasLtMatmulAlgorithms(
const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size,
int max_algorithm_count,
std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>*
out_algorithms) {
port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
CUDABlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan* plan,
size_t max_workspace_size,
int max_algorithm_count) {
#if CUDA_VERSION >= 11000
UniqueMatmulPreference preference =
CreateCublasLtMatmulPreference(plan, max_workspace_size);
if (!preference) return false;
SE_ASSIGN_OR_RETURN(UniqueMatmulPreference preference,
CreateCublasLtMatmulPreference(plan, max_workspace_size));
std::vector<cublasLtMatmulHeuristicResult_t> results(max_algorithm_count);
{
@ -3425,21 +3449,27 @@ bool CUDABlas::GetBlasLtMatmulAlgorithms(
cuda_plan.c_desc(), cuda_plan.d_desc(), preference.get(),
max_algorithm_count, results.data(), &found_algorithm_count);
if (status != CUBLAS_STATUS_SUCCESS) {
VLOG(2) << "cublasLtMatmulAlgoGetHeuristic failed: " << ToString(status);
return false;
return port::Status(
port::error::INTERNAL,
absl::StrCat("cublasLtMatmulAlgoGetHeuristic failed: ",
ToString(status)));
}
results.resize(found_algorithm_count);
}
std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>> out_algorithms;
out_algorithms.reserve(results.size());
for (size_t i = 0; i < results.size(); ++i) {
const auto& result = results[i];
if (result.state != CUBLAS_STATUS_SUCCESS) continue; // Skip failed algos
out_algorithms->emplace_back(std::make_unique<CUDABlasLtMatmulAlgorithm>(
out_algorithms.emplace_back(std::make_unique<CUDABlasLtMatmulAlgorithm>(
i, result.algo, result.workspaceSize));
}
return true;
return out_algorithms;
#else // if CUDA_VERSION < 11000
return false;
return port::Status(
port::error::UNIMPLEMENTED,
"GetBlasLtMatmulAlgorithms is not supported with this version of CUDA");
#endif
}

View File

@ -336,26 +336,28 @@ bool StreamExecutor::GetBlasGemmAlgorithms(
return blas_support->GetBlasGemmAlgorithms(out_algorithms);
}
std::unique_ptr<blas::IBlasLtMatmulPlan> StreamExecutor::CreateBlasLtMatmulPlan(
port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
StreamExecutor::CreateBlasLtMatmulPlan(
const blas::BlasLtMatmulPlanParams& params) {
blas::BlasSupport *blas_support = AsBlas();
blas::BlasSupport* blas_support = AsBlas();
if (!blas_support) {
return nullptr;
return port::Status(port::error::UNKNOWN,
"Fail to find the blas implementation.");
}
return blas_support->CreateBlasLtMatmulPlan(params);
}
bool StreamExecutor::GetBlasLtMatmulAlgorithms(
const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size,
int max_algorithm_count,
std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>*
out_algorithms) {
blas::BlasSupport *blas_support = AsBlas();
port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
StreamExecutor::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan* plan,
size_t max_workspace_size,
int max_algorithm_count) {
blas::BlasSupport* blas_support = AsBlas();
if (!blas_support) {
return false;
return port::Status(port::error::UNKNOWN,
"Fail to find the blas implementation.");
}
return blas_support->GetBlasLtMatmulAlgorithms(
plan, max_workspace_size, max_algorithm_count, out_algorithms);
return blas_support->GetBlasLtMatmulAlgorithms(plan, max_workspace_size,
max_algorithm_count);
}
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>

View File

@ -398,18 +398,16 @@ class StreamExecutor {
// can then be passed to DoBlasLtMatmul(). When possible, plans should be
// created once and reused for multiple calls to DoBlasLtMatmul().
// Returns a null pointer on failure.
std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
const blas::BlasLtMatmulPlanParams& params);
port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams& params);
// Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
// returned in the order of increasing estimated compute time according to an
// internal heuristic. The first returned algorithm can be used as the default
// algorithm if no autotuning is to be performed.
bool GetBlasLtMatmulAlgorithms(
const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size,
int max_algorithm_count,
std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>*
out_algorithms);
port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan* plan,
size_t max_workspace_size, int max_algorithm_count);
// Create an RNN descriptor based on model shapes and configurations.
// The caller retains the ownership of the descriptor.