Replace CreateBlasLtMatmulPlan args with struct

This commit is contained in:
Ben Barsdell 2020-09-29 16:02:09 +10:00
parent 39bf03f083
commit b03ae6de78
5 changed files with 72 additions and 127 deletions

View File

@ -555,23 +555,23 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
GetBlasComputationType(dtype, allow_tf32, &computation_type), GetBlasComputationType(dtype, allow_tf32, &computation_type),
errors::Internal("Unsupported dtype for batched matmul")); errors::Internal("Unsupported dtype for batched matmul"));
std::unique_ptr<se::blas::IBlasLtMatmulPlan> plan = std::unique_ptr<se::blas::IBlasLtMatmulPlan> plan =
stream->parent()->CreateBlasLtMatmulPlanStridedBatched( stream->parent()->CreateBlasLtMatmulPlan(
/*ab_type=*/blas_dtype, {/*ab_type=*/blas_dtype,
/*cd_type=*/blas_dtype, computation_type, /*c_type=*/blas_dtype, computation_type,
se::blas::PointerMode::kHost, se::blas::Epilogue::kDefault, se::blas::PointerMode::kHost, se::blas::Epilogue::kDefault,
blas_transpose_b, blas_transpose_a, n, m, k, batch_size, blas_transpose_b, blas_transpose_a, n, m, k,
/*lda=*/in_y.dim_size(2), b_stride, /*lda=*/in_y.dim_size(2), /*ldb=*/in_x.dim_size(2), /*ldc=*/n,
/*ldb=*/in_x.dim_size(2), a_stride, /*ldc=*/n, c_stride); batch_size, b_stride, a_stride, c_stride});
OP_REQUIRES( OP_REQUIRES(
context, plan, context, plan,
errors::Internal( errors::Internal("CreateBlasLtMatmulPlan failed : a.shape=(",
"CreateBlasLtMatmulPlanStridedBatched failed : a.shape=(", in_x.dim_size(0), ", ", in_x.dim_size(1), ", ",
in_x.dim_size(0), ", ", in_x.dim_size(1), ", ", in_x.dim_size(2), "), b.shape=(", in_y.dim_size(0),
in_x.dim_size(2), "), b.shape=(", in_y.dim_size(0), ", ", ", ", in_y.dim_size(1), ", ", in_y.dim_size(2),
in_y.dim_size(1), ", ", in_y.dim_size(2), "), m=", m, ", n=", n, "), m=", m, ", n=", n, ", k=", k,
", k=", k, ", batch_size=", batch_size, ", adjoint_a=", adj_x, ", batch_size=", batch_size, ", adjoint_a=", adj_x,
", adjoint_b=", adj_x, ", dtype=", dtype, ", adjoint_b=", adj_x, ", dtype=", dtype,
", computation_type=", computation_type)); ", computation_type=", computation_type));
std::vector<std::unique_ptr<se::blas::IBlasLtMatmulAlgorithm>> std::vector<std::unique_ptr<se::blas::IBlasLtMatmulAlgorithm>>
algorithms; algorithms;
OP_REQUIRES( OP_REQUIRES(

View File

@ -242,6 +242,27 @@ struct IBlasLtMatmulAlgorithm {
virtual size_t workspace_size() const = 0; virtual size_t workspace_size() const = 0;
}; };
// Parameters for the CreateBlasLtMatmulPlan method.
struct BlasLtMatmulPlanParams {
DataType ab_type;
DataType c_type;
ComputationType computation_type;
PointerMode pointer_mode;
Epilogue epilogue;
Transpose transa;
Transpose transb;
uint64 m;
uint64 n;
uint64 k;
int64 lda;
int64 ldb;
int64 ldc;
int batch_count = 1;
int64 stride_a = 0;
int64 stride_b = 0;
int64 stride_c = 0;
};
// BLAS support interface -- this can be derived from a GPU executor when the // BLAS support interface -- this can be derived from a GPU executor when the
// underlying platform has an BLAS library implementation available. See // underlying platform has an BLAS library implementation available. See
// StreamExecutor::AsBlas(). // StreamExecutor::AsBlas().
@ -1466,25 +1487,8 @@ class BlasSupport {
// can then be passed to DoBlasLtMatmul(). When possible, plans should be // can then be passed to DoBlasLtMatmul(). When possible, plans should be
// created once and reused for multiple calls to DoBlasLtMatmul(). // created once and reused for multiple calls to DoBlasLtMatmul().
// Returns a null pointer on failure. // Returns a null pointer on failure.
std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan( virtual std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
blas::DataType ab_type, blas::DataType c_type, const blas::BlasLtMatmulPlanParams& params) = 0;
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc) {
return CreateBlasLtMatmulPlanStridedBatched(
ab_type, c_type, computation_type, pointer_mode, epilogue, transa,
transb, m, n, k, 1, lda, 0, ldb, 0, ldc, 0);
}
// A more general version of CreateBlasLtMatmulPlan supporting
// batched operations.
virtual std::unique_ptr<blas::IBlasLtMatmulPlan>
CreateBlasLtMatmulPlanStridedBatched(
blas::DataType ab_type, blas::DataType c_type,
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) = 0;
// Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
// returned in the order of increasing estimated compute time according to an // returned in the order of increasing estimated compute time according to an
@ -2372,14 +2376,8 @@ class BlasSupport {
uint64 n, std::complex<double> alpha, \ uint64 n, std::complex<double> alpha, \
const DeviceMemory<std::complex<double>> &a, int lda, \ const DeviceMemory<std::complex<double>> &a, int lda, \
DeviceMemory<std::complex<double>> *b, int ldb) override; \ DeviceMemory<std::complex<double>> *b, int ldb) override; \
std::unique_ptr<blas::IBlasLtMatmulPlan> \ std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan( \
CreateBlasLtMatmulPlanStridedBatched( \ const blas::BlasLtMatmulPlanParams& params) override; \
blas::DataType ab_type, blas::DataType cd_type, \
blas::ComputationType computation_type, blas::PointerMode pointer_mode, \
blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb, \
uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, \
int64 stride_a, int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) \
override; \
bool GetBlasLtMatmulAlgorithms( \ bool GetBlasLtMatmulAlgorithms( \
const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size, \ const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size, \
int max_algorithm_count, \ int max_algorithm_count, \

View File

@ -3231,13 +3231,7 @@ blas::ComputationType ToComputationType<std::complex<double>>() {
class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan { class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
public: public:
CUDABlasLtMatmulPlan(blas::DataType ab_type, blas::DataType cd_type, CUDABlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams& params);
blas::ComputationType compute_type,
blas::PointerMode pointer_mode, blas::Epilogue epilogue,
blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, int batch_count, int64 lda,
int64 stride_a, int64 ldb, int64 stride_b, int64 ldc,
int64 stride_c, int64 ldd, int64 stride_d);
cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); } cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); }
cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); } cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); }
@ -3280,39 +3274,34 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
}; };
CUDABlasLtMatmulPlan::CUDABlasLtMatmulPlan( CUDABlasLtMatmulPlan::CUDABlasLtMatmulPlan(
blas::DataType ab_type, blas::DataType cd_type, const blas::BlasLtMatmulPlanParams& p)
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
int64 ldb, int64 stride_b, int64 ldc, int64 stride_c, int64 ldd,
int64 stride_d)
: op_desc_(CreateCublasLtOperationDesc( : op_desc_(CreateCublasLtOperationDesc(
computation_type, GetScaleType(cd_type, computation_type), p.computation_type, GetScaleType(p.c_type, p.computation_type),
pointer_mode, epilogue, transa, transb)), p.pointer_mode, p.epilogue, p.transa, p.transb)),
a_desc_(nullptr), a_desc_(nullptr),
b_desc_(nullptr), b_desc_(nullptr),
c_desc_( c_desc_(CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
CreateCublasLtLayoutDesc(cd_type, m, n, ldc, stride_c, batch_count)), p.batch_count)),
d_desc_( d_desc_(CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
CreateCublasLtLayoutDesc(cd_type, m, n, ldd, stride_d, batch_count)), p.batch_count)),
ab_type_(ab_type), ab_type_(p.ab_type),
cd_type_(cd_type), cd_type_(p.c_type),
scale_type_(GetScaleType(cd_type, computation_type)), scale_type_(GetScaleType(p.c_type, p.computation_type)),
pointer_mode_(pointer_mode), pointer_mode_(p.pointer_mode),
epilogue_(epilogue), epilogue_(p.epilogue),
batch_count_(batch_count), batch_count_(p.batch_count),
stride_a_(stride_a), stride_a_(p.stride_a),
stride_b_(stride_b), stride_b_(p.stride_b),
stride_c_(stride_c), stride_c_(p.stride_c),
stride_d_(stride_d) { stride_d_(p.stride_c) {
uint64 rows_a = transa == blas::Transpose::kNoTranspose ? m : k; uint64 rows_a = p.transa == blas::Transpose::kNoTranspose ? p.m : p.k;
uint64 cols_a = transa == blas::Transpose::kNoTranspose ? k : m; uint64 cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m;
uint64 rows_b = transb == blas::Transpose::kNoTranspose ? k : n; uint64 rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n;
uint64 cols_b = transb == blas::Transpose::kNoTranspose ? n : k; uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k;
a_desc_ = CreateCublasLtLayoutDesc(ab_type, rows_a, cols_a, lda, stride_a, a_desc_ = CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda,
batch_count); p.stride_a, p.batch_count);
b_desc_ = CreateCublasLtLayoutDesc(ab_type, rows_b, cols_b, ldb, stride_b, b_desc_ = CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb,
batch_count); p.stride_b, p.batch_count);
} }
bool CUDABlasLtMatmulPlan::SetBiasPointer(const void* bias) const { bool CUDABlasLtMatmulPlan::SetBiasPointer(const void* bias) const {
@ -3395,18 +3384,10 @@ UniqueMatmulPreference CreateCublasLtMatmulPreference(
#endif // CUDA_VERSION >= 11000 #endif // CUDA_VERSION >= 11000
std::unique_ptr<blas::IBlasLtMatmulPlan> std::unique_ptr<blas::IBlasLtMatmulPlan> CUDABlas::CreateBlasLtMatmulPlan(
CUDABlas::CreateBlasLtMatmulPlanStridedBatched( const blas::BlasLtMatmulPlanParams& params) {
blas::DataType ab_type, blas::DataType cd_type,
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) {
#if CUDA_VERSION >= 11000 #if CUDA_VERSION >= 11000
auto result = std::make_unique<CUDABlasLtMatmulPlan>( auto result = std::make_unique<CUDABlasLtMatmulPlan>(params);
ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
transb, m, n, k, batch_count, lda, stride_a, ldb, stride_b, ldc, stride_c,
ldc, stride_c);
if (!result->ok()) { if (!result->ok()) {
result.reset(); result.reset();
} }

View File

@ -337,34 +337,12 @@ bool StreamExecutor::GetBlasGemmAlgorithms(
} }
std::unique_ptr<blas::IBlasLtMatmulPlan> StreamExecutor::CreateBlasLtMatmulPlan( std::unique_ptr<blas::IBlasLtMatmulPlan> StreamExecutor::CreateBlasLtMatmulPlan(
blas::DataType ab_type, blas::DataType cd_type, const blas::BlasLtMatmulPlanParams& params) {
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc) {
blas::BlasSupport *blas_support = AsBlas(); blas::BlasSupport *blas_support = AsBlas();
if (!blas_support) { if (!blas_support) {
return nullptr; return nullptr;
} }
return blas_support->CreateBlasLtMatmulPlan( return blas_support->CreateBlasLtMatmulPlan(params);
ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
transb, m, n, k, lda, ldb, ldc);
}
std::unique_ptr<blas::IBlasLtMatmulPlan>
StreamExecutor::CreateBlasLtMatmulPlanStridedBatched(
blas::DataType ab_type, blas::DataType cd_type,
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
uint64 m, uint64 n, uint64 k, uint64 batch_count, int64 lda, int64 stride_a,
int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) {
blas::BlasSupport *blas_support = AsBlas();
if (!blas_support) {
return nullptr;
}
return blas_support->CreateBlasLtMatmulPlanStridedBatched(
ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
transb, m, n, k, batch_count, lda, stride_a, ldb, stride_b, ldc,
stride_c);
} }
bool StreamExecutor::GetBlasLtMatmulAlgorithms( bool StreamExecutor::GetBlasLtMatmulAlgorithms(

View File

@ -399,19 +399,7 @@ class StreamExecutor {
// created once and reused for multiple calls to DoBlasLtMatmul(). // created once and reused for multiple calls to DoBlasLtMatmul().
// Returns a null pointer on failure. // Returns a null pointer on failure.
std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan( std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
blas::DataType ab_type, blas::DataType cd_type, const blas::BlasLtMatmulPlanParams& params);
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc);
// A more general version of CreateBlasLtMatmulPlan supporting
// batched operations.
std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlanStridedBatched(
blas::DataType ab_type, blas::DataType cd_type,
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
uint64 m, uint64 n, uint64 k, uint64 batch_count, int64 lda,
int64 stride_a, int64 ldb, int64 stride_b, int64 ldc, int64 stride_c);
// Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
// returned in the order of increasing estimated compute time according to an // returned in the order of increasing estimated compute time according to an