Replace CreateBlasLtMatmulPlan args with struct
commit b03ae6de78
parent 39bf03f083
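The change collapses the long positional argument lists of CreateBlasLtMatmulPlan / CreateBlasLtMatmulPlanStridedBatched into a single BlasLtMatmulPlanParams struct. As a rough usage sketch (not part of this commit), a caller of the new API might fill in the struct and request a plan roughly as below; the helper name MakeExamplePlan, the enum values (kFloat, kF32), and the leading-dimension choices are illustrative assumptions only, and the snippet presumes the stream_executor headers are already included.

    // Illustrative sketch: build the new parameter struct and create a plan
    // for a single, non-batched GEMM. Values here are assumptions, not part
    // of the commit.
    namespace se = stream_executor;

    std::unique_ptr<se::blas::IBlasLtMatmulPlan> MakeExamplePlan(
        se::Stream* stream, uint64 m, uint64 n, uint64 k) {
      se::blas::BlasLtMatmulPlanParams params{
          /*ab_type=*/se::blas::DataType::kFloat,
          /*c_type=*/se::blas::DataType::kFloat,
          /*computation_type=*/se::blas::ComputationType::kF32,
          /*pointer_mode=*/se::blas::PointerMode::kHost,
          /*epilogue=*/se::blas::Epilogue::kDefault,
          /*transa=*/se::blas::Transpose::kNoTranspose,
          /*transb=*/se::blas::Transpose::kNoTranspose,
          m, n, k,
          /*lda=*/static_cast<int64>(m),
          /*ldb=*/static_cast<int64>(k),
          /*ldc=*/static_cast<int64>(m)};
      // batch_count and stride_a/b/c keep their defaults (1 and 0) for a
      // plain GEMM.
      return stream->parent()->CreateBlasLtMatmulPlan(params);  // null on failure
    }

For the strided-batched case in the batch_matmul kernel below, the same struct is used: batch_count and the stride_* fields are simply set on it instead of calling a separate *StridedBatched entry point.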
@@ -555,23 +555,23 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
           GetBlasComputationType(dtype, allow_tf32, &computation_type),
           errors::Internal("Unsupported dtype for batched matmul"));
       std::unique_ptr<se::blas::IBlasLtMatmulPlan> plan =
-          stream->parent()->CreateBlasLtMatmulPlanStridedBatched(
-              /*ab_type=*/blas_dtype,
-              /*cd_type=*/blas_dtype, computation_type,
-              se::blas::PointerMode::kHost, se::blas::Epilogue::kDefault,
-              blas_transpose_b, blas_transpose_a, n, m, k, batch_size,
-              /*lda=*/in_y.dim_size(2), b_stride,
-              /*ldb=*/in_x.dim_size(2), a_stride, /*ldc=*/n, c_stride);
+          stream->parent()->CreateBlasLtMatmulPlan(
+              {/*ab_type=*/blas_dtype,
+               /*c_type=*/blas_dtype, computation_type,
+               se::blas::PointerMode::kHost, se::blas::Epilogue::kDefault,
+               blas_transpose_b, blas_transpose_a, n, m, k,
+               /*lda=*/in_y.dim_size(2), /*ldb=*/in_x.dim_size(2), /*ldc=*/n,
+               batch_size, b_stride, a_stride, c_stride});
       OP_REQUIRES(
           context, plan,
-          errors::Internal(
-              "CreateBlasLtMatmulPlanStridedBatched failed : a.shape=(",
-              in_x.dim_size(0), ", ", in_x.dim_size(1), ", ",
-              in_x.dim_size(2), "), b.shape=(", in_y.dim_size(0), ", ",
-              in_y.dim_size(1), ", ", in_y.dim_size(2), "), m=", m, ", n=", n,
-              ", k=", k, ", batch_size=", batch_size, ", adjoint_a=", adj_x,
-              ", adjoint_b=", adj_x, ", dtype=", dtype,
-              ", computation_type=", computation_type));
+          errors::Internal("CreateBlasLtMatmulPlan failed : a.shape=(",
+                           in_x.dim_size(0), ", ", in_x.dim_size(1), ", ",
+                           in_x.dim_size(2), "), b.shape=(", in_y.dim_size(0),
+                           ", ", in_y.dim_size(1), ", ", in_y.dim_size(2),
+                           "), m=", m, ", n=", n, ", k=", k,
+                           ", batch_size=", batch_size, ", adjoint_a=", adj_x,
+                           ", adjoint_b=", adj_x, ", dtype=", dtype,
+                           ", computation_type=", computation_type));
       std::vector<std::unique_ptr<se::blas::IBlasLtMatmulAlgorithm>>
           algorithms;
       OP_REQUIRES(
@@ -242,6 +242,27 @@ struct IBlasLtMatmulAlgorithm {
   virtual size_t workspace_size() const = 0;
 };
 
+// Parameters for the CreateBlasLtMatmulPlan method.
+struct BlasLtMatmulPlanParams {
+  DataType ab_type;
+  DataType c_type;
+  ComputationType computation_type;
+  PointerMode pointer_mode;
+  Epilogue epilogue;
+  Transpose transa;
+  Transpose transb;
+  uint64 m;
+  uint64 n;
+  uint64 k;
+  int64 lda;
+  int64 ldb;
+  int64 ldc;
+  int batch_count = 1;
+  int64 stride_a = 0;
+  int64 stride_b = 0;
+  int64 stride_c = 0;
+};
+
 // BLAS support interface -- this can be derived from a GPU executor when the
 // underlying platform has an BLAS library implementation available. See
 // StreamExecutor::AsBlas().
@@ -1466,25 +1487,8 @@ class BlasSupport {
   // can then be passed to DoBlasLtMatmul(). When possible, plans should be
   // created once and reused for multiple calls to DoBlasLtMatmul().
   // Returns a null pointer on failure.
-  std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
-      blas::DataType ab_type, blas::DataType c_type,
-      blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
-      uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc) {
-    return CreateBlasLtMatmulPlanStridedBatched(
-        ab_type, c_type, computation_type, pointer_mode, epilogue, transa,
-        transb, m, n, k, 1, lda, 0, ldb, 0, ldc, 0);
-  }
-
-  // A more general version of CreateBlasLtMatmulPlan supporting
-  // batched operations.
-  virtual std::unique_ptr<blas::IBlasLtMatmulPlan>
-  CreateBlasLtMatmulPlanStridedBatched(
-      blas::DataType ab_type, blas::DataType c_type,
-      blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
-      uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
-      int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) = 0;
+  virtual std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
+      const blas::BlasLtMatmulPlanParams& params) = 0;
 
   // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
   // returned in the order of increasing estimated compute time according to an
@@ -2372,14 +2376,8 @@ class BlasSupport {
       uint64 n, std::complex<double> alpha,                                   \
       const DeviceMemory<std::complex<double>> &a, int lda,                   \
       DeviceMemory<std::complex<double>> *b, int ldb) override;               \
-  std::unique_ptr<blas::IBlasLtMatmulPlan>                                    \
-  CreateBlasLtMatmulPlanStridedBatched(                                       \
-      blas::DataType ab_type, blas::DataType cd_type,                         \
-      blas::ComputationType computation_type, blas::PointerMode pointer_mode, \
-      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,\
-      uint64 m, uint64 n, uint64 k, int batch_count, int64 lda,               \
-      int64 stride_a, int64 ldb, int64 stride_b, int64 ldc, int64 stride_c)   \
-      override;                                                               \
+  std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(            \
+      const blas::BlasLtMatmulPlanParams& params) override;                   \
   bool GetBlasLtMatmulAlgorithms(                                             \
       const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size,         \
       int max_algorithm_count,                                                \
@@ -3231,13 +3231,7 @@ blas::ComputationType ToComputationType<std::complex<double>>() {
 
 class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
  public:
-  CUDABlasLtMatmulPlan(blas::DataType ab_type, blas::DataType cd_type,
-                       blas::ComputationType compute_type,
-                       blas::PointerMode pointer_mode, blas::Epilogue epilogue,
-                       blas::Transpose transa, blas::Transpose transb, uint64 m,
-                       uint64 n, uint64 k, int batch_count, int64 lda,
-                       int64 stride_a, int64 ldb, int64 stride_b, int64 ldc,
-                       int64 stride_c, int64 ldd, int64 stride_d);
+  CUDABlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams& params);
 
   cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); }
   cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); }
@@ -3280,39 +3274,34 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
 };
 
 CUDABlasLtMatmulPlan::CUDABlasLtMatmulPlan(
-    blas::DataType ab_type, blas::DataType cd_type,
-    blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-    blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
-    uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
-    int64 ldb, int64 stride_b, int64 ldc, int64 stride_c, int64 ldd,
-    int64 stride_d)
+    const blas::BlasLtMatmulPlanParams& p)
     : op_desc_(CreateCublasLtOperationDesc(
-          computation_type, GetScaleType(cd_type, computation_type),
-          pointer_mode, epilogue, transa, transb)),
+          p.computation_type, GetScaleType(p.c_type, p.computation_type),
+          p.pointer_mode, p.epilogue, p.transa, p.transb)),
       a_desc_(nullptr),
       b_desc_(nullptr),
-      c_desc_(
-          CreateCublasLtLayoutDesc(cd_type, m, n, ldc, stride_c, batch_count)),
-      d_desc_(
-          CreateCublasLtLayoutDesc(cd_type, m, n, ldd, stride_d, batch_count)),
-      ab_type_(ab_type),
-      cd_type_(cd_type),
-      scale_type_(GetScaleType(cd_type, computation_type)),
-      pointer_mode_(pointer_mode),
-      epilogue_(epilogue),
-      batch_count_(batch_count),
-      stride_a_(stride_a),
-      stride_b_(stride_b),
-      stride_c_(stride_c),
-      stride_d_(stride_d) {
-  uint64 rows_a = transa == blas::Transpose::kNoTranspose ? m : k;
-  uint64 cols_a = transa == blas::Transpose::kNoTranspose ? k : m;
-  uint64 rows_b = transb == blas::Transpose::kNoTranspose ? k : n;
-  uint64 cols_b = transb == blas::Transpose::kNoTranspose ? n : k;
-  a_desc_ = CreateCublasLtLayoutDesc(ab_type, rows_a, cols_a, lda, stride_a,
-                                     batch_count);
-  b_desc_ = CreateCublasLtLayoutDesc(ab_type, rows_b, cols_b, ldb, stride_b,
-                                     batch_count);
+      c_desc_(CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
+                                       p.batch_count)),
+      d_desc_(CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
+                                       p.batch_count)),
+      ab_type_(p.ab_type),
+      cd_type_(p.c_type),
+      scale_type_(GetScaleType(p.c_type, p.computation_type)),
+      pointer_mode_(p.pointer_mode),
+      epilogue_(p.epilogue),
+      batch_count_(p.batch_count),
+      stride_a_(p.stride_a),
+      stride_b_(p.stride_b),
+      stride_c_(p.stride_c),
+      stride_d_(p.stride_c) {
+  uint64 rows_a = p.transa == blas::Transpose::kNoTranspose ? p.m : p.k;
+  uint64 cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m;
+  uint64 rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n;
+  uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k;
+  a_desc_ = CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda,
+                                     p.stride_a, p.batch_count);
+  b_desc_ = CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb,
+                                     p.stride_b, p.batch_count);
 }
 
 bool CUDABlasLtMatmulPlan::SetBiasPointer(const void* bias) const {
@@ -3395,18 +3384,10 @@ UniqueMatmulPreference CreateCublasLtMatmulPreference(
 
 #endif  // CUDA_VERSION >= 11000
 
-std::unique_ptr<blas::IBlasLtMatmulPlan>
-CUDABlas::CreateBlasLtMatmulPlanStridedBatched(
-    blas::DataType ab_type, blas::DataType cd_type,
-    blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-    blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
-    uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
-    int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) {
+std::unique_ptr<blas::IBlasLtMatmulPlan> CUDABlas::CreateBlasLtMatmulPlan(
+    const blas::BlasLtMatmulPlanParams& params) {
 #if CUDA_VERSION >= 11000
-  auto result = std::make_unique<CUDABlasLtMatmulPlan>(
-      ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
-      transb, m, n, k, batch_count, lda, stride_a, ldb, stride_b, ldc, stride_c,
-      ldc, stride_c);
+  auto result = std::make_unique<CUDABlasLtMatmulPlan>(params);
   if (!result->ok()) {
     result.reset();
   }
@@ -337,34 +337,12 @@ bool StreamExecutor::GetBlasGemmAlgorithms(
 }
 
 std::unique_ptr<blas::IBlasLtMatmulPlan> StreamExecutor::CreateBlasLtMatmulPlan(
-    blas::DataType ab_type, blas::DataType cd_type,
-    blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-    blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
-    uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc) {
+    const blas::BlasLtMatmulPlanParams& params) {
   blas::BlasSupport *blas_support = AsBlas();
   if (!blas_support) {
     return nullptr;
   }
-  return blas_support->CreateBlasLtMatmulPlan(
-      ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
-      transb, m, n, k, lda, ldb, ldc);
-}
-
-std::unique_ptr<blas::IBlasLtMatmulPlan>
-StreamExecutor::CreateBlasLtMatmulPlanStridedBatched(
-    blas::DataType ab_type, blas::DataType cd_type,
-    blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-    blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
-    uint64 m, uint64 n, uint64 k, uint64 batch_count, int64 lda, int64 stride_a,
-    int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) {
-  blas::BlasSupport *blas_support = AsBlas();
-  if (!blas_support) {
-    return nullptr;
-  }
-  return blas_support->CreateBlasLtMatmulPlanStridedBatched(
-      ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
-      transb, m, n, k, batch_count, lda, stride_a, ldb, stride_b, ldc,
-      stride_c);
+  return blas_support->CreateBlasLtMatmulPlan(params);
 }
 
 bool StreamExecutor::GetBlasLtMatmulAlgorithms(
@@ -399,19 +399,7 @@ class StreamExecutor {
   // created once and reused for multiple calls to DoBlasLtMatmul().
   // Returns a null pointer on failure.
   std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
-      blas::DataType ab_type, blas::DataType cd_type,
-      blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
-      uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc);
-
-  // A more general version of CreateBlasLtMatmulPlan supporting
-  // batched operations.
-  std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlanStridedBatched(
-      blas::DataType ab_type, blas::DataType cd_type,
-      blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
-      uint64 m, uint64 n, uint64 k, uint64 batch_count, int64 lda,
-      int64 stride_a, int64 ldb, int64 stride_b, int64 ldc, int64 stride_c);
+      const blas::BlasLtMatmulPlanParams& params);
 
   // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
   // returned in the order of increasing estimated compute time according to an