Refactor blasLt APIs to return Status, not bool
- This does not include DoBlasLtMatmul because the helpers in stream.cc require it to return bool.
This commit is contained in:
parent
a3dfb6f366
commit
c491ca455c
@ -469,6 +469,17 @@ struct CoefficientType<Eigen::half> {
|
||||
typedef float type;
|
||||
};
|
||||
|
||||
inline Status FromExecutorStatus(const se::port::Status& s) {
|
||||
return s.ok() ? Status::OK()
|
||||
: Status(static_cast<error::Code>(static_cast<int>(s.code())),
|
||||
s.error_message());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline Status FromExecutorStatus(const se::port::StatusOr<T>& s) {
|
||||
return FromExecutorStatus(s.status());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Scalar>
|
||||
@ -554,38 +565,25 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
|
||||
context,
|
||||
GetBlasComputationType(dtype, allow_tf32, &computation_type),
|
||||
errors::Internal("Unsupported dtype for batched matmul"));
|
||||
|
||||
auto status_or_plan = stream->parent()->CreateBlasLtMatmulPlan(
|
||||
{/*ab_type=*/blas_dtype,
|
||||
/*c_type=*/blas_dtype, computation_type,
|
||||
se::blas::PointerMode::kHost, se::blas::Epilogue::kDefault,
|
||||
blas_transpose_b, blas_transpose_a, n, m, k,
|
||||
/*lda=*/in_y.dim_size(2), /*ldb=*/in_x.dim_size(2), /*ldc=*/n,
|
||||
batch_size, b_stride, a_stride, c_stride});
|
||||
OP_REQUIRES(context, status_or_plan.ok(),
|
||||
FromExecutorStatus(status_or_plan));
|
||||
std::unique_ptr<se::blas::IBlasLtMatmulPlan> plan =
|
||||
stream->parent()->CreateBlasLtMatmulPlan(
|
||||
{/*ab_type=*/blas_dtype,
|
||||
/*c_type=*/blas_dtype, computation_type,
|
||||
se::blas::PointerMode::kHost, se::blas::Epilogue::kDefault,
|
||||
blas_transpose_b, blas_transpose_a, n, m, k,
|
||||
/*lda=*/in_y.dim_size(2), /*ldb=*/in_x.dim_size(2), /*ldc=*/n,
|
||||
batch_size, b_stride, a_stride, c_stride});
|
||||
OP_REQUIRES(
|
||||
context, plan,
|
||||
errors::Internal("CreateBlasLtMatmulPlan failed : a.shape=(",
|
||||
in_x.dim_size(0), ", ", in_x.dim_size(1), ", ",
|
||||
in_x.dim_size(2), "), b.shape=(", in_y.dim_size(0),
|
||||
", ", in_y.dim_size(1), ", ", in_y.dim_size(2),
|
||||
"), m=", m, ", n=", n, ", k=", k,
|
||||
", batch_size=", batch_size, ", adjoint_a=", adj_x,
|
||||
", adjoint_b=", adj_x, ", dtype=", dtype,
|
||||
", computation_type=", computation_type));
|
||||
std::vector<std::unique_ptr<se::blas::IBlasLtMatmulAlgorithm>>
|
||||
algorithms;
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
stream->parent()->GetBlasLtMatmulAlgorithms(
|
||||
plan.get(), max_scratch_size, max_algorithm_count, &algorithms),
|
||||
errors::Internal("GetBlasLtMatmulAlgorithms failed: a.shape=(",
|
||||
in_x.dim_size(0), ", ", in_x.dim_size(1), ", ",
|
||||
in_x.dim_size(2), "), b.shape=(", in_y.dim_size(0),
|
||||
", ", in_y.dim_size(1), ", ", in_y.dim_size(2),
|
||||
"), m=", m, ", n=", n, ", k=", k,
|
||||
", batch_size=", batch_size, ", adjoint_a=", adj_x,
|
||||
", adjoint_b=", adj_x, ", dtype=", dtype,
|
||||
", computation_type=", computation_type));
|
||||
status_or_plan.ConsumeValueOrDie();
|
||||
|
||||
auto status_or_algorithms = stream->parent()->GetBlasLtMatmulAlgorithms(
|
||||
plan.get(), max_scratch_size, max_algorithm_count);
|
||||
OP_REQUIRES(context, status_or_algorithms.ok(),
|
||||
FromExecutorStatus(status_or_algorithms));
|
||||
auto algorithms = status_or_algorithms.ConsumeValueOrDie();
|
||||
|
||||
plan_and_algorithms =
|
||||
BatchMatmulPlanMapSingleton::GetInstance()->Insert(
|
||||
matmul_parameters, {std::move(plan), std::move(algorithms)});
|
||||
|
@ -1454,19 +1454,18 @@ class BlasSupport {
|
||||
// Creates a backend-specific plan object for a blaslt matmul operation, which
|
||||
// can then be passed to DoBlasLtMatmul(). When possible, plans should be
|
||||
// created once and reused for multiple calls to DoBlasLtMatmul().
|
||||
// Returns a null pointer on failure.
|
||||
virtual std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
|
||||
const blas::BlasLtMatmulPlanParams& params) = 0;
|
||||
virtual port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
|
||||
CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams& params) = 0;
|
||||
|
||||
// Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
|
||||
// returned in the order of increasing estimated compute time according to an
|
||||
// internal heuristic. The first returned algorithm can be used as the default
|
||||
// algorithm if no autotuning is to be performed.
|
||||
virtual bool GetBlasLtMatmulAlgorithms(
|
||||
const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size,
|
||||
int max_algorithm_count,
|
||||
std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>*
|
||||
out_algorithms) = 0;
|
||||
virtual port::StatusOr<
|
||||
std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
|
||||
GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan* plan,
|
||||
size_t max_workspace_size,
|
||||
int max_algorithm_count) = 0;
|
||||
|
||||
// Executes a blaslt matmul operation on the stream. If output_profile_result
|
||||
// is not nullptr, the operation is profiled, error messages are
|
||||
@ -2330,13 +2329,12 @@ class BlasSupport {
|
||||
uint64 n, std::complex<double> alpha, \
|
||||
const DeviceMemory<std::complex<double>> &a, int lda, \
|
||||
DeviceMemory<std::complex<double>> *b, int ldb) override; \
|
||||
std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan( \
|
||||
const blas::BlasLtMatmulPlanParams& params) override; \
|
||||
bool GetBlasLtMatmulAlgorithms( \
|
||||
const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size, \
|
||||
int max_algorithm_count, \
|
||||
std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>* \
|
||||
out_algorithms) override; \
|
||||
port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>> \
|
||||
CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams& params) override; \
|
||||
port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>> \
|
||||
GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan* plan, \
|
||||
size_t max_workspace_size, \
|
||||
int max_algorithm_count) override; \
|
||||
bool DoBlasLtMatmul( \
|
||||
Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
|
||||
const HostOrDeviceScalar<void>& alpha, DeviceMemoryBase a, \
|
||||
|
@ -3057,45 +3057,48 @@ bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
inline bool SetCublasLtAttr(cublasLtMatrixLayout_t handle,
|
||||
cublasLtMatrixLayoutAttribute_t attr,
|
||||
const T& value) {
|
||||
inline port::Status SetCublasLtAttr(cublasLtMatrixLayout_t handle,
|
||||
cublasLtMatrixLayoutAttribute_t attr,
|
||||
const T& value) {
|
||||
cublasStatus_t status =
|
||||
cublasLtMatrixLayoutSetAttribute(handle, attr, &value, sizeof(T));
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
VLOG(2) << "cublasLtMatrixLayoutSetAttribute(attr=" << attr
|
||||
<< ", value=" << value << ") failed: " << ToString(status);
|
||||
return false;
|
||||
return port::Status(
|
||||
port::error::INTERNAL,
|
||||
absl::StrCat("cublasLtMatrixLayoutSetAttribute(attr=", attr,
|
||||
", value=", value, ") failed: ", ToString(status)));
|
||||
}
|
||||
return true;
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool SetCublasLtAttr(cublasLtMatmulAlgo_t* handle,
|
||||
cublasLtMatmulAlgoConfigAttributes_t attr,
|
||||
const T& value) {
|
||||
inline port::Status SetCublasLtAttr(cublasLtMatmulAlgo_t* handle,
|
||||
cublasLtMatmulAlgoConfigAttributes_t attr,
|
||||
const T& value) {
|
||||
cublasStatus_t status =
|
||||
cublasLtMatmulAlgoConfigSetAttribute(handle, attr, &value, sizeof(T));
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
VLOG(2) << "cublasLtMatmulAlgoConfigSetAttribute(attr=" << attr
|
||||
<< ", value=" << value << ") failed: " << ToString(status);
|
||||
return false;
|
||||
return port::Status(
|
||||
port::error::INTERNAL,
|
||||
absl::StrCat("cublasLtMatmulAlgoConfigSetAttribute(attr=", attr,
|
||||
", value=", value, ") failed: ", ToString(status)));
|
||||
}
|
||||
return true;
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool SetCublasLtAttr(cublasLtMatmulPreference_t handle,
|
||||
cublasLtMatmulPreferenceAttributes_t attr,
|
||||
const T& value) {
|
||||
inline port::Status SetCublasLtAttr(cublasLtMatmulPreference_t handle,
|
||||
cublasLtMatmulPreferenceAttributes_t attr,
|
||||
const T& value) {
|
||||
cublasStatus_t status =
|
||||
cublasLtMatmulPreferenceSetAttribute(handle, attr, &value, sizeof(value));
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
VLOG(2) << "cublasLtMatmulPreferenceSetAttribute(attr=" << attr
|
||||
<< ", value=" << value << ") failed: " << ToString(status);
|
||||
return false;
|
||||
return port::Status(
|
||||
port::error::INTERNAL,
|
||||
absl::StrCat("cublasLtMatmulPreferenceSetAttribute(attr=", attr,
|
||||
", value=", value, ") failed: ", ToString(status)));
|
||||
}
|
||||
return true;
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@ -3111,17 +3114,27 @@ inline bool GetCublasLtAttr(const cublasLtMatmulAlgo_t* handle,
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool SetCublasLtAttr(cublasLtMatmulDesc_t handle,
|
||||
cublasLtMatmulDescAttributes_t attr,
|
||||
const T& value) {
|
||||
inline const T& ValueForStrCat(const T& value) {
|
||||
return value;
|
||||
}
|
||||
template <typename T>
|
||||
inline absl::Hex ValueForStrCat(T* ptr) {
|
||||
return absl::Hex(reinterpret_cast<uintptr_t>(ptr));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline port::Status SetCublasLtAttr(cublasLtMatmulDesc_t handle,
|
||||
cublasLtMatmulDescAttributes_t attr,
|
||||
const T& value) {
|
||||
cublasStatus_t status =
|
||||
cublasLtMatmulDescSetAttribute(handle, attr, &value, sizeof(value));
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
VLOG(2) << "cublasLtMatmulDescSetAttribute(attr=" << attr
|
||||
<< ", value=" << value << ") failed: " << ToString(status);
|
||||
return false;
|
||||
return port::Status(
|
||||
port::error::INTERNAL,
|
||||
absl::StrCat("cublasLtMatmulDescSetAttribute(attr=", attr, ", value=",
|
||||
ValueForStrCat(value), ") failed: ", ToString(status)));
|
||||
}
|
||||
return true;
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
||||
struct MatmulDescDestroyer {
|
||||
@ -3149,12 +3162,10 @@ using UniqueMatmulPreference =
|
||||
std::unique_ptr<std::remove_pointer<cublasLtMatmulPreference_t>::type,
|
||||
MatmulPreferenceDestroyer>;
|
||||
|
||||
UniqueOpDesc CreateCublasLtOperationDesc(blas::ComputationType computation_type,
|
||||
blas::DataType scale_type,
|
||||
blas::PointerMode pointer_mode,
|
||||
blas::Epilogue epilogue,
|
||||
blas::Transpose transa,
|
||||
blas::Transpose transb) {
|
||||
port::StatusOr<UniqueOpDesc> CreateCublasLtOperationDesc(
|
||||
blas::ComputationType computation_type, blas::DataType scale_type,
|
||||
blas::PointerMode pointer_mode, blas::Epilogue epilogue,
|
||||
blas::Transpose transa, blas::Transpose transb) {
|
||||
cublasLtMatmulDesc_t desc;
|
||||
cublasComputeType_t cublas_compute_type =
|
||||
CUBLASComputationType(computation_type);
|
||||
@ -3162,40 +3173,39 @@ UniqueOpDesc CreateCublasLtOperationDesc(blas::ComputationType computation_type,
|
||||
cublasStatus_t status =
|
||||
cublasLtMatmulDescCreate(&desc, cublas_compute_type, cuda_scale_type);
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
VLOG(2) << "cublasLtMatmulDescCreate(computation_type=" << computation_type
|
||||
<< ") failed: " << ToString(status);
|
||||
return nullptr;
|
||||
return port::Status(
|
||||
port::error::INTERNAL,
|
||||
absl::StrCat("cublasLtMatmulDescCreate(computation_type=",
|
||||
computation_type, ") failed: ", ToString(status)));
|
||||
}
|
||||
UniqueOpDesc unique_desc(desc);
|
||||
if (!SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_POINTER_MODE,
|
||||
CUBLASPointerMode(pointer_mode)) ||
|
||||
!SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
|
||||
CUBLASEpilogue(epilogue)) ||
|
||||
!SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSA,
|
||||
CUDABlasTranspose(transa)) ||
|
||||
!SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSB,
|
||||
CUDABlasTranspose(transb))) {
|
||||
return nullptr;
|
||||
}
|
||||
SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_POINTER_MODE,
|
||||
CUBLASPointerMode(pointer_mode)));
|
||||
SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
|
||||
CUBLASEpilogue(epilogue)));
|
||||
SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSA,
|
||||
CUDABlasTranspose(transa)));
|
||||
SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSB,
|
||||
CUDABlasTranspose(transb)));
|
||||
return unique_desc;
|
||||
}
|
||||
|
||||
UniqueLayoutDesc CreateCublasLtLayoutDesc(blas::DataType data_type, uint64 rows,
|
||||
uint64 cols, int64 ld, int64 stride,
|
||||
int batch_count) {
|
||||
port::StatusOr<UniqueLayoutDesc> CreateCublasLtLayoutDesc(
|
||||
blas::DataType data_type, uint64 rows, uint64 cols, int64 ld, int64 stride,
|
||||
int batch_count) {
|
||||
cublasLtMatrixLayout_t desc;
|
||||
cublasStatus_t status = cublasLtMatrixLayoutCreate(
|
||||
&desc, GetCUDADataType(data_type), rows, cols, ld);
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
VLOG(2) << "cublasLtMatrixLayoutCreate failed: " << ToString(status);
|
||||
return nullptr;
|
||||
return port::Status(
|
||||
port::error::INTERNAL,
|
||||
absl::StrCat("cublasLtMatrixLayoutCreate failed: ", ToString(status)));
|
||||
}
|
||||
UniqueLayoutDesc unique_desc(desc);
|
||||
if (!SetCublasLtAttr(desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_count) ||
|
||||
!SetCublasLtAttr(desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
|
||||
stride)) {
|
||||
return nullptr;
|
||||
}
|
||||
SE_RETURN_IF_ERROR(
|
||||
SetCublasLtAttr(desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_count));
|
||||
SE_RETURN_IF_ERROR(SetCublasLtAttr(
|
||||
desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride));
|
||||
return unique_desc;
|
||||
}
|
||||
|
||||
@ -3234,7 +3244,28 @@ blas::ComputationType ToComputationType<std::complex<double>>() {
|
||||
|
||||
class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
|
||||
public:
|
||||
CUDABlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams& params);
|
||||
CUDABlasLtMatmulPlan(UniqueOpDesc op_desc, UniqueLayoutDesc a_desc,
|
||||
UniqueLayoutDesc b_desc, UniqueLayoutDesc c_desc,
|
||||
UniqueLayoutDesc d_desc, blas::DataType ab_type,
|
||||
blas::DataType c_type, blas::DataType scale_type,
|
||||
blas::PointerMode pointer_mode, blas::Epilogue epilogue,
|
||||
int batch_count, int64 stride_a, int64 stride_b,
|
||||
int64 stride_c, int64 stride_d)
|
||||
: op_desc_(std::move(op_desc)),
|
||||
a_desc_(std::move(a_desc)),
|
||||
b_desc_(std::move(b_desc)),
|
||||
c_desc_(std::move(c_desc)),
|
||||
d_desc_(std::move(d_desc)),
|
||||
ab_type_(ab_type),
|
||||
c_type_(c_type),
|
||||
scale_type_(scale_type),
|
||||
pointer_mode_(pointer_mode),
|
||||
epilogue_(epilogue),
|
||||
batch_count_(batch_count),
|
||||
stride_a_(stride_a),
|
||||
stride_b_(stride_b),
|
||||
stride_c_(stride_c),
|
||||
stride_d_(stride_d) {}
|
||||
|
||||
cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); }
|
||||
cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); }
|
||||
@ -3276,40 +3307,9 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
|
||||
int64 stride_d_;
|
||||
};
|
||||
|
||||
CUDABlasLtMatmulPlan::CUDABlasLtMatmulPlan(
|
||||
const blas::BlasLtMatmulPlanParams& p)
|
||||
: op_desc_(CreateCublasLtOperationDesc(
|
||||
p.computation_type, GetScaleType(p.c_type, p.computation_type),
|
||||
p.pointer_mode, p.epilogue, p.transa, p.transb)),
|
||||
a_desc_(nullptr),
|
||||
b_desc_(nullptr),
|
||||
c_desc_(CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
|
||||
p.batch_count)),
|
||||
d_desc_(CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
|
||||
p.batch_count)),
|
||||
ab_type_(p.ab_type),
|
||||
cd_type_(p.c_type),
|
||||
scale_type_(GetScaleType(p.c_type, p.computation_type)),
|
||||
pointer_mode_(p.pointer_mode),
|
||||
epilogue_(p.epilogue),
|
||||
batch_count_(p.batch_count),
|
||||
stride_a_(p.stride_a),
|
||||
stride_b_(p.stride_b),
|
||||
stride_c_(p.stride_c),
|
||||
stride_d_(p.stride_c) {
|
||||
uint64 rows_a = p.transa == blas::Transpose::kNoTranspose ? p.m : p.k;
|
||||
uint64 cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m;
|
||||
uint64 rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n;
|
||||
uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k;
|
||||
a_desc_ = CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda,
|
||||
p.stride_a, p.batch_count);
|
||||
b_desc_ = CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb,
|
||||
p.stride_b, p.batch_count);
|
||||
}
|
||||
|
||||
bool CUDABlasLtMatmulPlan::SetBiasPointer(const void* bias) const {
|
||||
return SetCublasLtAttr(op_desc_.get(), CUBLASLT_MATMUL_DESC_BIAS_POINTER,
|
||||
bias);
|
||||
bias).ok();
|
||||
}
|
||||
|
||||
class CUDABlasLtMatmulAlgorithm final : public blas::IBlasLtMatmulAlgorithm {
|
||||
@ -3336,20 +3336,19 @@ class CUDABlasLtMatmulAlgorithm final : public blas::IBlasLtMatmulAlgorithm {
|
||||
size_t workspace_size_;
|
||||
};
|
||||
|
||||
UniqueMatmulPreference CreateCublasLtMatmulPreference(
|
||||
const blas::IBlasLtMatmulPlan* plan,
|
||||
size_t max_workspace_bytes) {
|
||||
port::StatusOr<UniqueMatmulPreference> CreateCublasLtMatmulPreference(
|
||||
const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_bytes) {
|
||||
cublasLtMatmulPreference_t preference;
|
||||
cublasStatus_t status = cublasLtMatmulPreferenceCreate(&preference);
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
VLOG(2) << "cublasLtMatmulPreferenceCreate failed: " << ToString(status);
|
||||
return nullptr;
|
||||
return port::Status(port::error::INTERNAL,
|
||||
absl::StrCat("cublasLtMatmulPreferenceCreate failed: ",
|
||||
ToString(status)));
|
||||
}
|
||||
UniqueMatmulPreference unique_preference(preference);
|
||||
if (!SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||
max_workspace_bytes)) {
|
||||
return nullptr;
|
||||
}
|
||||
SE_RETURN_IF_ERROR(SetCublasLtAttr(preference,
|
||||
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||
max_workspace_bytes));
|
||||
|
||||
const auto& cuda_plan = *static_cast<const CUDABlasLtMatmulPlan*>(plan);
|
||||
if (cuda_plan.batch_count() == 0) {
|
||||
@ -3361,25 +3360,28 @@ UniqueMatmulPreference CreateCublasLtMatmulPreference(
|
||||
auto get_alignment_bytes = [](int64 stride, blas::DataType dtype) {
|
||||
return (stride & -stride) * GetDataTypeSizeBytes(dtype);
|
||||
};
|
||||
if ((cuda_plan.stride_a() &&
|
||||
!SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES,
|
||||
if (cuda_plan.stride_a()) {
|
||||
SE_RETURN_IF_ERROR(
|
||||
SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES,
|
||||
(uint32)get_alignment_bytes(cuda_plan.stride_a(),
|
||||
cuda_plan.ab_type()))) ||
|
||||
(cuda_plan.stride_b() &&
|
||||
!SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES,
|
||||
(uint32)get_alignment_bytes(cuda_plan.stride_b(),
|
||||
cuda_plan.ab_type()))) ||
|
||||
(cuda_plan.stride_c() &&
|
||||
!SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES,
|
||||
(uint32)get_alignment_bytes(cuda_plan.stride_c(),
|
||||
cuda_plan.cd_type()))) ||
|
||||
(cuda_plan.stride_d() &&
|
||||
!SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES,
|
||||
(uint32)get_alignment_bytes(cuda_plan.stride_d(),
|
||||
cuda_plan.cd_type())))) {
|
||||
return nullptr;
|
||||
cuda_plan.ab_type())));
|
||||
}
|
||||
if (cuda_plan.stride_b()) {
|
||||
SE_RETURN_IF_ERROR(
|
||||
SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES,
|
||||
(uint32)get_alignment_bytes(cuda_plan.stride_b(),
|
||||
cuda_plan.ab_type())));
|
||||
}
|
||||
if (cuda_plan.stride_c()) {
|
||||
SE_RETURN_IF_ERROR(SetCublasLtAttr(
|
||||
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES,
|
||||
(uint32)get_alignment_bytes(cuda_plan.stride_c(), cuda_plan.c_type())));
|
||||
}
|
||||
if (cuda_plan.stride_d()) {
|
||||
SE_RETURN_IF_ERROR(SetCublasLtAttr(
|
||||
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES,
|
||||
(uint32)get_alignment_bytes(cuda_plan.stride_d(), cuda_plan.c_type())));
|
||||
}
|
||||
|
||||
return unique_preference;
|
||||
}
|
||||
|
||||
@ -3387,28 +3389,50 @@ UniqueMatmulPreference CreateCublasLtMatmulPreference(
|
||||
|
||||
#endif // CUDA_VERSION >= 11000
|
||||
|
||||
std::unique_ptr<blas::IBlasLtMatmulPlan> CUDABlas::CreateBlasLtMatmulPlan(
|
||||
const blas::BlasLtMatmulPlanParams& params) {
|
||||
port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
|
||||
CUDABlas::CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams& p) {
|
||||
#if CUDA_VERSION >= 11000
|
||||
auto result = std::make_unique<CUDABlasLtMatmulPlan>(params);
|
||||
if (!result->ok()) {
|
||||
result.reset();
|
||||
}
|
||||
return result;
|
||||
SE_ASSIGN_OR_RETURN(
|
||||
auto op_desc,
|
||||
CreateCublasLtOperationDesc(
|
||||
p.computation_type, GetScaleType(p.c_type, p.computation_type),
|
||||
p.pointer_mode, p.epilogue, p.transa, p.transb));
|
||||
uint64 rows_a = p.transa == blas::Transpose::kNoTranspose ? p.m : p.k;
|
||||
uint64 cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m;
|
||||
uint64 rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n;
|
||||
uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k;
|
||||
SE_ASSIGN_OR_RETURN(auto a_desc,
|
||||
CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda,
|
||||
p.stride_a, p.batch_count));
|
||||
SE_ASSIGN_OR_RETURN(auto b_desc,
|
||||
CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb,
|
||||
p.stride_b, p.batch_count));
|
||||
SE_ASSIGN_OR_RETURN(auto c_desc,
|
||||
CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc,
|
||||
p.stride_c, p.batch_count));
|
||||
SE_ASSIGN_OR_RETURN(auto d_desc,
|
||||
CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc,
|
||||
p.stride_c, p.batch_count));
|
||||
blas::DataType scale_type = GetScaleType(p.c_type, p.computation_type);
|
||||
|
||||
return static_cast<std::unique_ptr<blas::IBlasLtMatmulPlan>>(
|
||||
std::make_unique<CUDABlasLtMatmulPlan>(
|
||||
std::move(op_desc), std::move(a_desc), std::move(b_desc),
|
||||
std::move(c_desc), std::move(d_desc), p.ab_type, p.c_type, scale_type,
|
||||
p.pointer_mode, p.epilogue, p.batch_count, p.stride_a, p.stride_b,
|
||||
p.stride_c, p.stride_c));
|
||||
#else
|
||||
return nullptr;
|
||||
#endif
|
||||
}
|
||||
|
||||
bool CUDABlas::GetBlasLtMatmulAlgorithms(
|
||||
const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size,
|
||||
int max_algorithm_count,
|
||||
std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>*
|
||||
out_algorithms) {
|
||||
port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
|
||||
CUDABlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan* plan,
|
||||
size_t max_workspace_size,
|
||||
int max_algorithm_count) {
|
||||
#if CUDA_VERSION >= 11000
|
||||
UniqueMatmulPreference preference =
|
||||
CreateCublasLtMatmulPreference(plan, max_workspace_size);
|
||||
if (!preference) return false;
|
||||
SE_ASSIGN_OR_RETURN(UniqueMatmulPreference preference,
|
||||
CreateCublasLtMatmulPreference(plan, max_workspace_size));
|
||||
|
||||
std::vector<cublasLtMatmulHeuristicResult_t> results(max_algorithm_count);
|
||||
{
|
||||
@ -3425,21 +3449,27 @@ bool CUDABlas::GetBlasLtMatmulAlgorithms(
|
||||
cuda_plan.c_desc(), cuda_plan.d_desc(), preference.get(),
|
||||
max_algorithm_count, results.data(), &found_algorithm_count);
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
VLOG(2) << "cublasLtMatmulAlgoGetHeuristic failed: " << ToString(status);
|
||||
return false;
|
||||
return port::Status(
|
||||
port::error::INTERNAL,
|
||||
absl::StrCat("cublasLtMatmulAlgoGetHeuristic failed: ",
|
||||
ToString(status)));
|
||||
}
|
||||
results.resize(found_algorithm_count);
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>> out_algorithms;
|
||||
out_algorithms.reserve(results.size());
|
||||
for (size_t i = 0; i < results.size(); ++i) {
|
||||
const auto& result = results[i];
|
||||
if (result.state != CUBLAS_STATUS_SUCCESS) continue; // Skip failed algos
|
||||
out_algorithms->emplace_back(std::make_unique<CUDABlasLtMatmulAlgorithm>(
|
||||
out_algorithms.emplace_back(std::make_unique<CUDABlasLtMatmulAlgorithm>(
|
||||
i, result.algo, result.workspaceSize));
|
||||
}
|
||||
return true;
|
||||
return out_algorithms;
|
||||
#else // if CUDA_VERSION < 11000
|
||||
return false;
|
||||
return port::Status(
|
||||
port::error::UNIMPLEMENTED,
|
||||
"GetBlasLtMatmulAlgorithms is not supported with this version of CUDA");
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -336,26 +336,28 @@ bool StreamExecutor::GetBlasGemmAlgorithms(
|
||||
return blas_support->GetBlasGemmAlgorithms(out_algorithms);
|
||||
}
|
||||
|
||||
std::unique_ptr<blas::IBlasLtMatmulPlan> StreamExecutor::CreateBlasLtMatmulPlan(
|
||||
port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
|
||||
StreamExecutor::CreateBlasLtMatmulPlan(
|
||||
const blas::BlasLtMatmulPlanParams& params) {
|
||||
blas::BlasSupport *blas_support = AsBlas();
|
||||
blas::BlasSupport* blas_support = AsBlas();
|
||||
if (!blas_support) {
|
||||
return nullptr;
|
||||
return port::Status(port::error::UNKNOWN,
|
||||
"Fail to find the blas implementation.");
|
||||
}
|
||||
return blas_support->CreateBlasLtMatmulPlan(params);
|
||||
}
|
||||
|
||||
bool StreamExecutor::GetBlasLtMatmulAlgorithms(
|
||||
const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size,
|
||||
int max_algorithm_count,
|
||||
std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>*
|
||||
out_algorithms) {
|
||||
blas::BlasSupport *blas_support = AsBlas();
|
||||
port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
|
||||
StreamExecutor::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan* plan,
|
||||
size_t max_workspace_size,
|
||||
int max_algorithm_count) {
|
||||
blas::BlasSupport* blas_support = AsBlas();
|
||||
if (!blas_support) {
|
||||
return false;
|
||||
return port::Status(port::error::UNKNOWN,
|
||||
"Fail to find the blas implementation.");
|
||||
}
|
||||
return blas_support->GetBlasLtMatmulAlgorithms(
|
||||
plan, max_workspace_size, max_algorithm_count, out_algorithms);
|
||||
return blas_support->GetBlasLtMatmulAlgorithms(plan, max_workspace_size,
|
||||
max_algorithm_count);
|
||||
}
|
||||
|
||||
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
|
||||
|
@ -398,18 +398,16 @@ class StreamExecutor {
|
||||
// can then be passed to DoBlasLtMatmul(). When possible, plans should be
|
||||
// created once and reused for multiple calls to DoBlasLtMatmul().
|
||||
// Returns a null pointer on failure.
|
||||
std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
|
||||
const blas::BlasLtMatmulPlanParams& params);
|
||||
port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
|
||||
CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams& params);
|
||||
|
||||
// Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
|
||||
// returned in the order of increasing estimated compute time according to an
|
||||
// internal heuristic. The first returned algorithm can be used as the default
|
||||
// algorithm if no autotuning is to be performed.
|
||||
bool GetBlasLtMatmulAlgorithms(
|
||||
const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size,
|
||||
int max_algorithm_count,
|
||||
std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>*
|
||||
out_algorithms);
|
||||
port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
|
||||
GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan* plan,
|
||||
size_t max_workspace_size, int max_algorithm_count);
|
||||
|
||||
// Create an RNN descriptor based on model shapes and configurations.
|
||||
// The caller retains the ownership of the descriptor.
|
||||
|
Loading…
Reference in New Issue
Block a user