Replace CreateBlasLtMatmulPlan args with struct

This commit is contained in:
Ben Barsdell 2020-09-29 16:02:09 +10:00
parent 39bf03f083
commit b03ae6de78
5 changed files with 72 additions and 127 deletions

View File

@ -555,23 +555,23 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
GetBlasComputationType(dtype, allow_tf32, &computation_type),
errors::Internal("Unsupported dtype for batched matmul"));
std::unique_ptr<se::blas::IBlasLtMatmulPlan> plan =
stream->parent()->CreateBlasLtMatmulPlanStridedBatched(
/*ab_type=*/blas_dtype,
/*cd_type=*/blas_dtype, computation_type,
se::blas::PointerMode::kHost, se::blas::Epilogue::kDefault,
blas_transpose_b, blas_transpose_a, n, m, k, batch_size,
/*lda=*/in_y.dim_size(2), b_stride,
/*ldb=*/in_x.dim_size(2), a_stride, /*ldc=*/n, c_stride);
stream->parent()->CreateBlasLtMatmulPlan(
{/*ab_type=*/blas_dtype,
/*c_type=*/blas_dtype, computation_type,
se::blas::PointerMode::kHost, se::blas::Epilogue::kDefault,
blas_transpose_b, blas_transpose_a, n, m, k,
/*lda=*/in_y.dim_size(2), /*ldb=*/in_x.dim_size(2), /*ldc=*/n,
batch_size, b_stride, a_stride, c_stride});
OP_REQUIRES(
context, plan,
errors::Internal(
"CreateBlasLtMatmulPlanStridedBatched failed : a.shape=(",
in_x.dim_size(0), ", ", in_x.dim_size(1), ", ",
in_x.dim_size(2), "), b.shape=(", in_y.dim_size(0), ", ",
in_y.dim_size(1), ", ", in_y.dim_size(2), "), m=", m, ", n=", n,
", k=", k, ", batch_size=", batch_size, ", adjoint_a=", adj_x,
", adjoint_b=", adj_x, ", dtype=", dtype,
", computation_type=", computation_type));
errors::Internal("CreateBlasLtMatmulPlan failed : a.shape=(",
in_x.dim_size(0), ", ", in_x.dim_size(1), ", ",
in_x.dim_size(2), "), b.shape=(", in_y.dim_size(0),
", ", in_y.dim_size(1), ", ", in_y.dim_size(2),
"), m=", m, ", n=", n, ", k=", k,
", batch_size=", batch_size, ", adjoint_a=", adj_x,
", adjoint_b=", adj_x, ", dtype=", dtype,
", computation_type=", computation_type));
std::vector<std::unique_ptr<se::blas::IBlasLtMatmulAlgorithm>>
algorithms;
OP_REQUIRES(

View File

@ -242,6 +242,27 @@ struct IBlasLtMatmulAlgorithm {
virtual size_t workspace_size() const = 0;
};
// Parameters for the CreateBlasLtMatmulPlan method.
struct BlasLtMatmulPlanParams {
DataType ab_type;
DataType c_type;
ComputationType computation_type;
PointerMode pointer_mode;
Epilogue epilogue;
Transpose transa;
Transpose transb;
uint64 m;
uint64 n;
uint64 k;
int64 lda;
int64 ldb;
int64 ldc;
int batch_count = 1;
int64 stride_a = 0;
int64 stride_b = 0;
int64 stride_c = 0;
};
// BLAS support interface -- this can be derived from a GPU executor when the
// underlying platform has an BLAS library implementation available. See
// StreamExecutor::AsBlas().
@ -1466,25 +1487,8 @@ class BlasSupport {
// can then be passed to DoBlasLtMatmul(). When possible, plans should be
// created once and reused for multiple calls to DoBlasLtMatmul().
// Returns a null pointer on failure.
std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
blas::DataType ab_type, blas::DataType c_type,
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc) {
return CreateBlasLtMatmulPlanStridedBatched(
ab_type, c_type, computation_type, pointer_mode, epilogue, transa,
transb, m, n, k, 1, lda, 0, ldb, 0, ldc, 0);
}
// A more general version of CreateBlasLtMatmulPlan supporting
// batched operations.
virtual std::unique_ptr<blas::IBlasLtMatmulPlan>
CreateBlasLtMatmulPlanStridedBatched(
blas::DataType ab_type, blas::DataType c_type,
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) = 0;
virtual std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
const blas::BlasLtMatmulPlanParams& params) = 0;
// Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
// returned in the order of increasing estimated compute time according to an
@ -2372,14 +2376,8 @@ class BlasSupport {
uint64 n, std::complex<double> alpha, \
const DeviceMemory<std::complex<double>> &a, int lda, \
DeviceMemory<std::complex<double>> *b, int ldb) override; \
std::unique_ptr<blas::IBlasLtMatmulPlan> \
CreateBlasLtMatmulPlanStridedBatched( \
blas::DataType ab_type, blas::DataType cd_type, \
blas::ComputationType computation_type, blas::PointerMode pointer_mode, \
blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb, \
uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, \
int64 stride_a, int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) \
override; \
std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan( \
const blas::BlasLtMatmulPlanParams& params) override; \
bool GetBlasLtMatmulAlgorithms( \
const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size, \
int max_algorithm_count, \

View File

@ -3231,13 +3231,7 @@ blas::ComputationType ToComputationType<std::complex<double>>() {
class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
public:
CUDABlasLtMatmulPlan(blas::DataType ab_type, blas::DataType cd_type,
blas::ComputationType compute_type,
blas::PointerMode pointer_mode, blas::Epilogue epilogue,
blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, int batch_count, int64 lda,
int64 stride_a, int64 ldb, int64 stride_b, int64 ldc,
int64 stride_c, int64 ldd, int64 stride_d);
CUDABlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams& params);
cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); }
cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); }
@ -3280,39 +3274,34 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
};
CUDABlasLtMatmulPlan::CUDABlasLtMatmulPlan(
blas::DataType ab_type, blas::DataType cd_type,
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
int64 ldb, int64 stride_b, int64 ldc, int64 stride_c, int64 ldd,
int64 stride_d)
const blas::BlasLtMatmulPlanParams& p)
: op_desc_(CreateCublasLtOperationDesc(
computation_type, GetScaleType(cd_type, computation_type),
pointer_mode, epilogue, transa, transb)),
p.computation_type, GetScaleType(p.c_type, p.computation_type),
p.pointer_mode, p.epilogue, p.transa, p.transb)),
a_desc_(nullptr),
b_desc_(nullptr),
c_desc_(
CreateCublasLtLayoutDesc(cd_type, m, n, ldc, stride_c, batch_count)),
d_desc_(
CreateCublasLtLayoutDesc(cd_type, m, n, ldd, stride_d, batch_count)),
ab_type_(ab_type),
cd_type_(cd_type),
scale_type_(GetScaleType(cd_type, computation_type)),
pointer_mode_(pointer_mode),
epilogue_(epilogue),
batch_count_(batch_count),
stride_a_(stride_a),
stride_b_(stride_b),
stride_c_(stride_c),
stride_d_(stride_d) {
uint64 rows_a = transa == blas::Transpose::kNoTranspose ? m : k;
uint64 cols_a = transa == blas::Transpose::kNoTranspose ? k : m;
uint64 rows_b = transb == blas::Transpose::kNoTranspose ? k : n;
uint64 cols_b = transb == blas::Transpose::kNoTranspose ? n : k;
a_desc_ = CreateCublasLtLayoutDesc(ab_type, rows_a, cols_a, lda, stride_a,
batch_count);
b_desc_ = CreateCublasLtLayoutDesc(ab_type, rows_b, cols_b, ldb, stride_b,
batch_count);
c_desc_(CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
p.batch_count)),
d_desc_(CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
p.batch_count)),
ab_type_(p.ab_type),
cd_type_(p.c_type),
scale_type_(GetScaleType(p.c_type, p.computation_type)),
pointer_mode_(p.pointer_mode),
epilogue_(p.epilogue),
batch_count_(p.batch_count),
stride_a_(p.stride_a),
stride_b_(p.stride_b),
stride_c_(p.stride_c),
stride_d_(p.stride_c) {
uint64 rows_a = p.transa == blas::Transpose::kNoTranspose ? p.m : p.k;
uint64 cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m;
uint64 rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n;
uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k;
a_desc_ = CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda,
p.stride_a, p.batch_count);
b_desc_ = CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb,
p.stride_b, p.batch_count);
}
bool CUDABlasLtMatmulPlan::SetBiasPointer(const void* bias) const {
@ -3395,18 +3384,10 @@ UniqueMatmulPreference CreateCublasLtMatmulPreference(
#endif // CUDA_VERSION >= 11000
std::unique_ptr<blas::IBlasLtMatmulPlan>
CUDABlas::CreateBlasLtMatmulPlanStridedBatched(
blas::DataType ab_type, blas::DataType cd_type,
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) {
std::unique_ptr<blas::IBlasLtMatmulPlan> CUDABlas::CreateBlasLtMatmulPlan(
const blas::BlasLtMatmulPlanParams& params) {
#if CUDA_VERSION >= 11000
auto result = std::make_unique<CUDABlasLtMatmulPlan>(
ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
transb, m, n, k, batch_count, lda, stride_a, ldb, stride_b, ldc, stride_c,
ldc, stride_c);
auto result = std::make_unique<CUDABlasLtMatmulPlan>(params);
if (!result->ok()) {
result.reset();
}

View File

@ -337,34 +337,12 @@ bool StreamExecutor::GetBlasGemmAlgorithms(
}
std::unique_ptr<blas::IBlasLtMatmulPlan> StreamExecutor::CreateBlasLtMatmulPlan(
blas::DataType ab_type, blas::DataType cd_type,
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc) {
const blas::BlasLtMatmulPlanParams& params) {
blas::BlasSupport *blas_support = AsBlas();
if (!blas_support) {
return nullptr;
}
return blas_support->CreateBlasLtMatmulPlan(
ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
transb, m, n, k, lda, ldb, ldc);
}
std::unique_ptr<blas::IBlasLtMatmulPlan>
StreamExecutor::CreateBlasLtMatmulPlanStridedBatched(
blas::DataType ab_type, blas::DataType cd_type,
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
uint64 m, uint64 n, uint64 k, uint64 batch_count, int64 lda, int64 stride_a,
int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) {
blas::BlasSupport *blas_support = AsBlas();
if (!blas_support) {
return nullptr;
}
return blas_support->CreateBlasLtMatmulPlanStridedBatched(
ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
transb, m, n, k, batch_count, lda, stride_a, ldb, stride_b, ldc,
stride_c);
return blas_support->CreateBlasLtMatmulPlan(params);
}
bool StreamExecutor::GetBlasLtMatmulAlgorithms(

View File

@ -399,19 +399,7 @@ class StreamExecutor {
// created once and reused for multiple calls to DoBlasLtMatmul().
// Returns a null pointer on failure.
std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
blas::DataType ab_type, blas::DataType cd_type,
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc);
// A more general version of CreateBlasLtMatmulPlan supporting
// batched operations.
std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlanStridedBatched(
blas::DataType ab_type, blas::DataType cd_type,
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
uint64 m, uint64 n, uint64 k, uint64 batch_count, int64 lda,
int64 stride_a, int64 ldb, int64 stride_b, int64 ldc, int64 stride_c);
const blas::BlasLtMatmulPlanParams& params);
// Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
// returned in the order of increasing estimated compute time according to an