Replace CreateBlasLtMatmulPlan args with struct

parent 39bf03f083
commit b03ae6de78

Changed directories: tensorflow/core/kernels, tensorflow/stream_executor
@@ -555,23 +555,23 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
         GetBlasComputationType(dtype, allow_tf32, &computation_type),
         errors::Internal("Unsupported dtype for batched matmul"));
     std::unique_ptr<se::blas::IBlasLtMatmulPlan> plan =
-        stream->parent()->CreateBlasLtMatmulPlanStridedBatched(
-            /*ab_type=*/blas_dtype,
-            /*cd_type=*/blas_dtype, computation_type,
-            se::blas::PointerMode::kHost, se::blas::Epilogue::kDefault,
-            blas_transpose_b, blas_transpose_a, n, m, k, batch_size,
-            /*lda=*/in_y.dim_size(2), b_stride,
-            /*ldb=*/in_x.dim_size(2), a_stride, /*ldc=*/n, c_stride);
+        stream->parent()->CreateBlasLtMatmulPlan(
+            {/*ab_type=*/blas_dtype,
+             /*c_type=*/blas_dtype, computation_type,
+             se::blas::PointerMode::kHost, se::blas::Epilogue::kDefault,
+             blas_transpose_b, blas_transpose_a, n, m, k,
+             /*lda=*/in_y.dim_size(2), /*ldb=*/in_x.dim_size(2), /*ldc=*/n,
+             batch_size, b_stride, a_stride, c_stride});
     OP_REQUIRES(
         context, plan,
-        errors::Internal(
-            "CreateBlasLtMatmulPlanStridedBatched failed : a.shape=(",
-            in_x.dim_size(0), ", ", in_x.dim_size(1), ", ",
-            in_x.dim_size(2), "), b.shape=(", in_y.dim_size(0), ", ",
-            in_y.dim_size(1), ", ", in_y.dim_size(2), "), m=", m, ", n=", n,
-            ", k=", k, ", batch_size=", batch_size, ", adjoint_a=", adj_x,
-            ", adjoint_b=", adj_x, ", dtype=", dtype,
-            ", computation_type=", computation_type));
+        errors::Internal("CreateBlasLtMatmulPlan failed : a.shape=(",
+                         in_x.dim_size(0), ", ", in_x.dim_size(1), ", ",
+                         in_x.dim_size(2), "), b.shape=(", in_y.dim_size(0),
+                         ", ", in_y.dim_size(1), ", ", in_y.dim_size(2),
+                         "), m=", m, ", n=", n, ", k=", k,
+                         ", batch_size=", batch_size, ", adjoint_a=", adj_x,
+                         ", adjoint_b=", adj_x, ", dtype=", dtype,
+                         ", computation_type=", computation_type));
     std::vector<std::unique_ptr<se::blas::IBlasLtMatmulAlgorithm>>
         algorithms;
     OP_REQUIRES(
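For context on the argument order above: LaunchBatchMatMul hands its row-major tensors to the column-major BLAS by computing C^T = B^T * A^T, which is why the B-related operands and n/m come first and the leading dimensions are the innermost tensor extents. The following standalone sketch (hypothetical helper and struct names, not TensorFlow code) shows one way the swapped dimensions, leading dimensions, and contiguous batch strides line up under that convention.

#include <cstdint>
#include <iostream>

// Hypothetical illustration only: row-major C[m x n] = A[m x k] * B[k x n],
// with each batch stored contiguously.
struct RowMajorBatchedGemmDims {
  std::int64_t m, n, k, batch_count;
};

// What a column-major plan sees when we compute C^T = B^T * A^T in place.
struct ColMajorPlanArgs {
  std::int64_t rows, cols, depth;             // (n, m, k) after the swap
  std::int64_t lda, ldb, ldc;                 // leading dims of B^T, A^T, C^T views
  std::int64_t batch_count;
  std::int64_t stride_a, stride_b, stride_c;  // per-batch element strides
};

ColMajorPlanArgs MapToColumnMajor(const RowMajorBatchedGemmDims& d) {
  ColMajorPlanArgs out;
  out.rows = d.n;              // C^T is n x m
  out.cols = d.m;
  out.depth = d.k;
  out.lda = d.n;               // B (k x n, row-major) viewed as B^T: ld = n
  out.ldb = d.k;               // A (m x k, row-major) viewed as A^T: ld = k
  out.ldc = d.n;               // C (m x n, row-major) viewed as C^T: ld = n
  out.batch_count = d.batch_count;
  out.stride_a = d.k * d.n;    // elements between consecutive B batches
  out.stride_b = d.m * d.k;    // elements between consecutive A batches
  out.stride_c = d.m * d.n;    // elements between consecutive C batches
  return out;
}

int main() {
  ColMajorPlanArgs args =
      MapToColumnMajor({/*m=*/64, /*n=*/32, /*k=*/128, /*batch_count=*/8});
  std::cout << "plan dims " << args.rows << " x " << args.cols
            << ", ldc=" << args.ldc << ", stride_c=" << args.stride_c << "\n";
}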
|
@@ -242,6 +242,27 @@ struct IBlasLtMatmulAlgorithm {
   virtual size_t workspace_size() const = 0;
 };
 
+// Parameters for the CreateBlasLtMatmulPlan method.
+struct BlasLtMatmulPlanParams {
+  DataType ab_type;
+  DataType c_type;
+  ComputationType computation_type;
+  PointerMode pointer_mode;
+  Epilogue epilogue;
+  Transpose transa;
+  Transpose transb;
+  uint64 m;
+  uint64 n;
+  uint64 k;
+  int64 lda;
+  int64 ldb;
+  int64 ldc;
+  int batch_count = 1;
+  int64 stride_a = 0;
+  int64 stride_b = 0;
+  int64 stride_c = 0;
+};
+
 // BLAS support interface -- this can be derived from a GPU executor when the
 // underlying platform has an BLAS library implementation available. See
 // StreamExecutor::AsBlas().
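Because batch_count defaults to 1 and the strides default to 0, a params struct that only fills the leading members describes a plain, non-batched matmul, which is what lets later hunks delete the separate strided-batched entry point. A minimal stand-in (simplified field set and types, C++14 or later, not the real stream_executor declarations) showing how the trailing defaulted members make that work:

#include <cstdint>
#include <iostream>

// Simplified stand-in for the params struct above: the trailing members carry
// the batch defaults, so callers that omit them get a non-batched problem.
struct PlanParams {
  std::uint64_t m = 0;
  std::uint64_t n = 0;
  std::uint64_t k = 0;
  std::int64_t lda = 0;
  std::int64_t ldb = 0;
  std::int64_t ldc = 0;
  int batch_count = 1;          // defaults describe a single, non-strided GEMM
  std::int64_t stride_a = 0;
  std::int64_t stride_b = 0;
  std::int64_t stride_c = 0;
};

void Describe(const PlanParams& p) {
  std::cout << "m=" << p.m << " n=" << p.n << " k=" << p.k
            << " batch_count=" << p.batch_count << "\n";
}

int main() {
  // Non-batched call sites can simply omit the trailing batch fields...
  Describe({/*m=*/64, /*n=*/32, /*k=*/128, /*lda=*/32, /*ldb=*/128, /*ldc=*/32});
  // ...while batched call sites fill them in explicitly.
  Describe({64, 32, 128, 32, 128, 32, /*batch_count=*/8,
            /*stride_a=*/32 * 128, /*stride_b=*/64 * 128, /*stride_c=*/64 * 32});
}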
@@ -1466,25 +1487,8 @@ class BlasSupport {
   // can then be passed to DoBlasLtMatmul(). When possible, plans should be
   // created once and reused for multiple calls to DoBlasLtMatmul().
   // Returns a null pointer on failure.
-  std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
-      blas::DataType ab_type, blas::DataType c_type,
-      blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
-      uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc) {
-    return CreateBlasLtMatmulPlanStridedBatched(
-        ab_type, c_type, computation_type, pointer_mode, epilogue, transa,
-        transb, m, n, k, 1, lda, 0, ldb, 0, ldc, 0);
-  }
-
-  // A more general version of CreateBlasLtMatmulPlan supporting
-  // batched operations.
-  virtual std::unique_ptr<blas::IBlasLtMatmulPlan>
-  CreateBlasLtMatmulPlanStridedBatched(
-      blas::DataType ab_type, blas::DataType c_type,
-      blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
-      uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
-      int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) = 0;
+  virtual std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
+      const blas::BlasLtMatmulPlanParams& params) = 0;
 
   // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
   // returned in the order of increasing estimated compute time according to an
@@ -2372,14 +2376,8 @@ class BlasSupport {
       uint64 n, std::complex<double> alpha, \
       const DeviceMemory<std::complex<double>> &a, int lda, \
       DeviceMemory<std::complex<double>> *b, int ldb) override; \
-  std::unique_ptr<blas::IBlasLtMatmulPlan> \
-  CreateBlasLtMatmulPlanStridedBatched( \
-      blas::DataType ab_type, blas::DataType cd_type, \
-      blas::ComputationType computation_type, blas::PointerMode pointer_mode, \
-      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb, \
-      uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, \
-      int64 stride_a, int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) \
-      override; \
+  std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan( \
+      const blas::BlasLtMatmulPlanParams& params) override; \
   bool GetBlasLtMatmulAlgorithms( \
       const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size, \
       int max_algorithm_count, \
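The hunk above edits the backend override-declaration macro rather than a class body: each concrete BLAS implementation expands one macro to declare its full set of override signatures, so this signature change only has to be made in the macro and the base interface. A self-contained sketch of that pattern with hypothetical names (not the real stream_executor macro):

#include <memory>

struct Plan {};
struct PlanParams {
  int m = 0;
  int n = 0;
  int k = 0;
};

struct Backend {
  virtual ~Backend() = default;
  virtual std::unique_ptr<Plan> CreatePlan(const PlanParams& params) = 0;
};

// One macro declares every override a concrete backend must provide, so an
// interface change is made in a single place for all backends.
#define DECLARE_BACKEND_OVERRIDES \
  std::unique_ptr<Plan> CreatePlan(const PlanParams& params) override;

struct FakeBackend : Backend {
  DECLARE_BACKEND_OVERRIDES
};

std::unique_ptr<Plan> FakeBackend::CreatePlan(const PlanParams& params) {
  return params.m > 0 ? std::make_unique<Plan>() : nullptr;
}

int main() {
  FakeBackend backend;
  auto plan = backend.CreatePlan({/*m=*/64, /*n=*/32, /*k=*/128});
  return plan ? 0 : 1;
}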
|
@@ -3231,13 +3231,7 @@ blas::ComputationType ToComputationType<std::complex<double>>() {
 
 class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
  public:
-  CUDABlasLtMatmulPlan(blas::DataType ab_type, blas::DataType cd_type,
-                       blas::ComputationType compute_type,
-                       blas::PointerMode pointer_mode, blas::Epilogue epilogue,
-                       blas::Transpose transa, blas::Transpose transb, uint64 m,
-                       uint64 n, uint64 k, int batch_count, int64 lda,
-                       int64 stride_a, int64 ldb, int64 stride_b, int64 ldc,
-                       int64 stride_c, int64 ldd, int64 stride_d);
+  CUDABlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams& params);
 
   cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); }
   cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); }
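The accessors after the constructor (op_desc_.get(), a_desc_.get()) indicate the plan holds its cublasLt descriptors in smart pointers and hands raw handles to the C API on demand. A small self-contained sketch of that ownership pattern using a stand-in handle type (hypothetical names, not the cublasLt API):

#include <iostream>
#include <memory>

// Stand-in for an opaque C descriptor handle plus its create/destroy calls.
struct FakeDescOpaque { int id; };
using FakeDesc = FakeDescOpaque*;

FakeDesc FakeDescCreate(int id) { return new FakeDescOpaque{id}; }
void FakeDescDestroy(FakeDesc d) { delete d; }

// Custom deleter so std::unique_ptr releases the handle automatically.
struct FakeDescDeleter {
  void operator()(FakeDescOpaque* d) const { FakeDescDestroy(d); }
};
using UniqueFakeDesc = std::unique_ptr<FakeDescOpaque, FakeDescDeleter>;

class PlanLike {
 public:
  explicit PlanLike(int id) : op_desc_(FakeDescCreate(id)) {}
  FakeDesc op_desc() const { return op_desc_.get(); }  // raw handle for C calls
 private:
  UniqueFakeDesc op_desc_;
};

int main() {
  PlanLike plan(42);
  std::cout << plan.op_desc()->id << "\n";  // descriptor freed when plan dies
}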
@@ -3280,39 +3274,34 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
 };
 
 CUDABlasLtMatmulPlan::CUDABlasLtMatmulPlan(
-    blas::DataType ab_type, blas::DataType cd_type,
-    blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-    blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
-    uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
-    int64 ldb, int64 stride_b, int64 ldc, int64 stride_c, int64 ldd,
-    int64 stride_d)
+    const blas::BlasLtMatmulPlanParams& p)
     : op_desc_(CreateCublasLtOperationDesc(
-          computation_type, GetScaleType(cd_type, computation_type),
-          pointer_mode, epilogue, transa, transb)),
+          p.computation_type, GetScaleType(p.c_type, p.computation_type),
+          p.pointer_mode, p.epilogue, p.transa, p.transb)),
       a_desc_(nullptr),
       b_desc_(nullptr),
-      c_desc_(
-          CreateCublasLtLayoutDesc(cd_type, m, n, ldc, stride_c, batch_count)),
-      d_desc_(
-          CreateCublasLtLayoutDesc(cd_type, m, n, ldd, stride_d, batch_count)),
-      ab_type_(ab_type),
-      cd_type_(cd_type),
-      scale_type_(GetScaleType(cd_type, computation_type)),
-      pointer_mode_(pointer_mode),
-      epilogue_(epilogue),
-      batch_count_(batch_count),
-      stride_a_(stride_a),
-      stride_b_(stride_b),
-      stride_c_(stride_c),
-      stride_d_(stride_d) {
-  uint64 rows_a = transa == blas::Transpose::kNoTranspose ? m : k;
-  uint64 cols_a = transa == blas::Transpose::kNoTranspose ? k : m;
-  uint64 rows_b = transb == blas::Transpose::kNoTranspose ? k : n;
-  uint64 cols_b = transb == blas::Transpose::kNoTranspose ? n : k;
-  a_desc_ = CreateCublasLtLayoutDesc(ab_type, rows_a, cols_a, lda, stride_a,
-                                     batch_count);
-  b_desc_ = CreateCublasLtLayoutDesc(ab_type, rows_b, cols_b, ldb, stride_b,
-                                     batch_count);
+      c_desc_(CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
+                                       p.batch_count)),
+      d_desc_(CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
+                                       p.batch_count)),
+      ab_type_(p.ab_type),
+      cd_type_(p.c_type),
+      scale_type_(GetScaleType(p.c_type, p.computation_type)),
+      pointer_mode_(p.pointer_mode),
+      epilogue_(p.epilogue),
+      batch_count_(p.batch_count),
+      stride_a_(p.stride_a),
+      stride_b_(p.stride_b),
+      stride_c_(p.stride_c),
+      stride_d_(p.stride_c) {
+  uint64 rows_a = p.transa == blas::Transpose::kNoTranspose ? p.m : p.k;
+  uint64 cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m;
+  uint64 rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n;
+  uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k;
+  a_desc_ = CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda,
+                                     p.stride_a, p.batch_count);
+  b_desc_ = CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb,
+                                     p.stride_b, p.batch_count);
 }
 
 bool CUDABlasLtMatmulPlan::SetBiasPointer(const void* bias) const {
@@ -3395,18 +3384,10 @@ UniqueMatmulPreference CreateCublasLtMatmulPreference(
 
 #endif  // CUDA_VERSION >= 11000
 
-std::unique_ptr<blas::IBlasLtMatmulPlan>
-CUDABlas::CreateBlasLtMatmulPlanStridedBatched(
-    blas::DataType ab_type, blas::DataType cd_type,
-    blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-    blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
-    uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
-    int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) {
+std::unique_ptr<blas::IBlasLtMatmulPlan> CUDABlas::CreateBlasLtMatmulPlan(
+    const blas::BlasLtMatmulPlanParams& params) {
 #if CUDA_VERSION >= 11000
-  auto result = std::make_unique<CUDABlasLtMatmulPlan>(
-      ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
-      transb, m, n, k, batch_count, lda, stride_a, ldb, stride_b, ldc, stride_c,
-      ldc, stride_c);
+  auto result = std::make_unique<CUDABlasLtMatmulPlan>(params);
   if (!result->ok()) {
     result.reset();
   }
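The CUDA implementation constructs the plan unconditionally and then collapses any descriptor-creation failure to a null pointer via result->ok(), matching the "returns a null pointer on failure" contract in blas.h. A minimal sketch of that construct-then-validate factory shape (hypothetical types, not the CUDA sources):

#include <memory>

class Plan {
 public:
  explicit Plan(int m) : ok_(m > 0) {}  // stand-in for descriptor creation
  bool ok() const { return ok_; }

 private:
  bool ok_;
};

std::unique_ptr<Plan> CreatePlan(int m) {
  auto result = std::make_unique<Plan>(m);
  if (!result->ok()) {
    result.reset();  // signal failure with nullptr, as the CUDA path does
  }
  return result;
}

int main() {
  // A valid problem yields a plan; an invalid one yields nullptr.
  return CreatePlan(64) && !CreatePlan(-1) ? 0 : 1;
}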
|
@@ -337,34 +337,12 @@ bool StreamExecutor::GetBlasGemmAlgorithms(
 }
 
 std::unique_ptr<blas::IBlasLtMatmulPlan> StreamExecutor::CreateBlasLtMatmulPlan(
-    blas::DataType ab_type, blas::DataType cd_type,
-    blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-    blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
-    uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc) {
+    const blas::BlasLtMatmulPlanParams& params) {
   blas::BlasSupport *blas_support = AsBlas();
   if (!blas_support) {
     return nullptr;
   }
-  return blas_support->CreateBlasLtMatmulPlan(
-      ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
-      transb, m, n, k, lda, ldb, ldc);
-}
-
-std::unique_ptr<blas::IBlasLtMatmulPlan>
-StreamExecutor::CreateBlasLtMatmulPlanStridedBatched(
-    blas::DataType ab_type, blas::DataType cd_type,
-    blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-    blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
-    uint64 m, uint64 n, uint64 k, uint64 batch_count, int64 lda, int64 stride_a,
-    int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) {
-  blas::BlasSupport *blas_support = AsBlas();
-  if (!blas_support) {
-    return nullptr;
-  }
-  return blas_support->CreateBlasLtMatmulPlanStridedBatched(
-      ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
-      transb, m, n, k, batch_count, lda, stride_a, ldb, stride_b, ldc,
-      stride_c);
+  return blas_support->CreateBlasLtMatmulPlan(params);
 }
 
 bool StreamExecutor::GetBlasLtMatmulAlgorithms(
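StreamExecutor itself only null-checks its BLAS backend and forwards the params struct, so the signature change is purely mechanical here. A compact sketch of that forwarding shape (hypothetical types, not the StreamExecutor sources):

#include <memory>

struct Plan {};
struct PlanParams {
  int m = 0;
};

struct BlasSupport {
  virtual ~BlasSupport() = default;
  virtual std::unique_ptr<Plan> CreateBlasLtMatmulPlan(const PlanParams& p) = 0;
};

class Executor {
 public:
  explicit Executor(BlasSupport* blas) : blas_(blas) {}

  // Facade method: null-check the backend, then pass the struct through.
  std::unique_ptr<Plan> CreateBlasLtMatmulPlan(const PlanParams& params) {
    BlasSupport* blas_support = AsBlas();
    if (!blas_support) {
      return nullptr;
    }
    return blas_support->CreateBlasLtMatmulPlan(params);
  }

 private:
  BlasSupport* AsBlas() { return blas_; }
  BlasSupport* blas_;
};

int main() {
  Executor no_blas(nullptr);
  return no_blas.CreateBlasLtMatmulPlan({/*m=*/64}) ? 1 : 0;  // expect null
}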
|
@@ -399,19 +399,7 @@ class StreamExecutor {
   // created once and reused for multiple calls to DoBlasLtMatmul().
   // Returns a null pointer on failure.
   std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
-      blas::DataType ab_type, blas::DataType cd_type,
-      blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
-      uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc);
-
-  // A more general version of CreateBlasLtMatmulPlan supporting
-  // batched operations.
-  std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlanStridedBatched(
-      blas::DataType ab_type, blas::DataType cd_type,
-      blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
-      uint64 m, uint64 n, uint64 k, uint64 batch_count, int64 lda,
-      int64 stride_a, int64 ldb, int64 stride_b, int64 ldc, int64 stride_c);
+      const blas::BlasLtMatmulPlanParams& params);
 
   // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
   // returned in the order of increasing estimated compute time according to an
|