diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/batch_matmul_op_impl.h index 456b4beff1e..ac5a45b99ba 100644 --- a/tensorflow/core/kernels/batch_matmul_op_impl.h +++ b/tensorflow/core/kernels/batch_matmul_op_impl.h @@ -555,23 +555,23 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> { GetBlasComputationType(dtype, allow_tf32, &computation_type), errors::Internal("Unsupported dtype for batched matmul")); std::unique_ptr<se::blas::IBlasLtMatmulPlan> plan = - stream->parent()->CreateBlasLtMatmulPlanStridedBatched( - /*ab_type=*/blas_dtype, - /*cd_type=*/blas_dtype, computation_type, - se::blas::PointerMode::kHost, se::blas::Epilogue::kDefault, - blas_transpose_b, blas_transpose_a, n, m, k, batch_size, - /*lda=*/in_y.dim_size(2), b_stride, - /*ldb=*/in_x.dim_size(2), a_stride, /*ldc=*/n, c_stride); + stream->parent()->CreateBlasLtMatmulPlan( + {/*ab_type=*/blas_dtype, + /*c_type=*/blas_dtype, computation_type, + se::blas::PointerMode::kHost, se::blas::Epilogue::kDefault, + blas_transpose_b, blas_transpose_a, n, m, k, + /*lda=*/in_y.dim_size(2), /*ldb=*/in_x.dim_size(2), /*ldc=*/n, + batch_size, b_stride, a_stride, c_stride}); OP_REQUIRES( context, plan, - errors::Internal( - "CreateBlasLtMatmulPlanStridedBatched failed : a.shape=(", - in_x.dim_size(0), ", ", in_x.dim_size(1), ", ", - in_x.dim_size(2), "), b.shape=(", in_y.dim_size(0), ", ", - in_y.dim_size(1), ", ", in_y.dim_size(2), "), m=", m, ", n=", n, - ", k=", k, ", batch_size=", batch_size, ", adjoint_a=", adj_x, - ", adjoint_b=", adj_x, ", dtype=", dtype, - ", computation_type=", computation_type)); + errors::Internal("CreateBlasLtMatmulPlan failed : a.shape=(", + in_x.dim_size(0), ", ", in_x.dim_size(1), ", ", + in_x.dim_size(2), "), b.shape=(", in_y.dim_size(0), + ", ", in_y.dim_size(1), ", ", in_y.dim_size(2), + "), m=", m, ", n=", n, ", k=", k, + ", batch_size=", batch_size, ", adjoint_a=", adj_x, + ", adjoint_b=", adj_y, ", dtype=", dtype, + ", computation_type=", 
computation_type)); std::vector<std::unique_ptr<se::blas::IBlasLtMatmulAlgorithm>> algorithms; OP_REQUIRES( diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h index ae5b4853d05..411f6f11275 100644 --- a/tensorflow/stream_executor/blas.h +++ b/tensorflow/stream_executor/blas.h @@ -242,6 +242,27 @@ struct IBlasLtMatmulAlgorithm { virtual size_t workspace_size() const = 0; }; +// Parameters for the CreateBlasLtMatmulPlan method. +struct BlasLtMatmulPlanParams { + DataType ab_type; + DataType c_type; + ComputationType computation_type; + PointerMode pointer_mode; + Epilogue epilogue; + Transpose transa; + Transpose transb; + uint64 m; + uint64 n; + uint64 k; + int64 lda; + int64 ldb; + int64 ldc; + int batch_count = 1; + int64 stride_a = 0; + int64 stride_b = 0; + int64 stride_c = 0; +}; + // BLAS support interface -- this can be derived from a GPU executor when the // underlying platform has an BLAS library implementation available. See // StreamExecutor::AsBlas(). @@ -1466,25 +1487,8 @@ class BlasSupport { // can then be passed to DoBlasLtMatmul(). When possible, plans should be // created once and reused for multiple calls to DoBlasLtMatmul(). // Returns a null pointer on failure. - std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan( - blas::DataType ab_type, blas::DataType c_type, - blas::ComputationType computation_type, blas::PointerMode pointer_mode, - blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb, - uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc) { - return CreateBlasLtMatmulPlanStridedBatched( - ab_type, c_type, computation_type, pointer_mode, epilogue, transa, - transb, m, n, k, 1, lda, 0, ldb, 0, ldc, 0); - } - - // A more general version of CreateBlasLtMatmulPlan supporting - // batched operations. 
- virtual std::unique_ptr<blas::IBlasLtMatmulPlan> - CreateBlasLtMatmulPlanStridedBatched( - blas::DataType ab_type, blas::DataType c_type, - blas::ComputationType computation_type, blas::PointerMode pointer_mode, - blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb, - uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a, - int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) = 0; + virtual std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan( + const blas::BlasLtMatmulPlanParams& params) = 0; // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are // returned in the order of increasing estimated compute time according to an @@ -2372,14 +2376,8 @@ class BlasSupport { uint64 n, std::complex<double> alpha, \ const DeviceMemory<std::complex<double>> &a, int lda, \ DeviceMemory<std::complex<double>> *b, int ldb) override; \ - std::unique_ptr<blas::IBlasLtMatmulPlan> \ - CreateBlasLtMatmulPlanStridedBatched( \ - blas::DataType ab_type, blas::DataType cd_type, \ - blas::ComputationType computation_type, blas::PointerMode pointer_mode, \ - blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb, \ - uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, \ - int64 stride_a, int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) \ - override; \ + std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan( \ + const blas::BlasLtMatmulPlanParams& params) override; \ bool GetBlasLtMatmulAlgorithms( \ const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size, \ int max_algorithm_count, \ diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc index 1d95b00ce7e..f2bc79e1c29 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.cc +++ b/tensorflow/stream_executor/cuda/cuda_blas.cc @@ -3231,13 +3231,7 @@ blas::ComputationType ToComputationType<std::complex<double>>() { class CUDABlasLtMatmulPlan final : public 
blas::IBlasLtMatmulPlan { public: - CUDABlasLtMatmulPlan(blas::DataType ab_type, blas::DataType cd_type, - blas::ComputationType compute_type, - blas::PointerMode pointer_mode, blas::Epilogue epilogue, - blas::Transpose transa, blas::Transpose transb, uint64 m, - uint64 n, uint64 k, int batch_count, int64 lda, - int64 stride_a, int64 ldb, int64 stride_b, int64 ldc, - int64 stride_c, int64 ldd, int64 stride_d); + CUDABlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams& params); cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); } cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); } @@ -3280,39 +3274,34 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan { }; CUDABlasLtMatmulPlan::CUDABlasLtMatmulPlan( - blas::DataType ab_type, blas::DataType cd_type, - blas::ComputationType computation_type, blas::PointerMode pointer_mode, - blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb, - uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a, - int64 ldb, int64 stride_b, int64 ldc, int64 stride_c, int64 ldd, - int64 stride_d) + const blas::BlasLtMatmulPlanParams& p) : op_desc_(CreateCublasLtOperationDesc( - computation_type, GetScaleType(cd_type, computation_type), - pointer_mode, epilogue, transa, transb)), + p.computation_type, GetScaleType(p.c_type, p.computation_type), + p.pointer_mode, p.epilogue, p.transa, p.transb)), a_desc_(nullptr), b_desc_(nullptr), - c_desc_( - CreateCublasLtLayoutDesc(cd_type, m, n, ldc, stride_c, batch_count)), - d_desc_( - CreateCublasLtLayoutDesc(cd_type, m, n, ldd, stride_d, batch_count)), - ab_type_(ab_type), - cd_type_(cd_type), - scale_type_(GetScaleType(cd_type, computation_type)), - pointer_mode_(pointer_mode), - epilogue_(epilogue), - batch_count_(batch_count), - stride_a_(stride_a), - stride_b_(stride_b), - stride_c_(stride_c), - stride_d_(stride_d) { - uint64 rows_a = transa == blas::Transpose::kNoTranspose ? 
m : k; - uint64 cols_a = transa == blas::Transpose::kNoTranspose ? k : m; - uint64 rows_b = transb == blas::Transpose::kNoTranspose ? k : n; - uint64 cols_b = transb == blas::Transpose::kNoTranspose ? n : k; - a_desc_ = CreateCublasLtLayoutDesc(ab_type, rows_a, cols_a, lda, stride_a, - batch_count); - b_desc_ = CreateCublasLtLayoutDesc(ab_type, rows_b, cols_b, ldb, stride_b, - batch_count); + c_desc_(CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c, + p.batch_count)), + d_desc_(CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c, + p.batch_count)), + ab_type_(p.ab_type), + cd_type_(p.c_type), + scale_type_(GetScaleType(p.c_type, p.computation_type)), + pointer_mode_(p.pointer_mode), + epilogue_(p.epilogue), + batch_count_(p.batch_count), + stride_a_(p.stride_a), + stride_b_(p.stride_b), + stride_c_(p.stride_c), + stride_d_(p.stride_c) { + uint64 rows_a = p.transa == blas::Transpose::kNoTranspose ? p.m : p.k; + uint64 cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m; + uint64 rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n; + uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? 
p.n : p.k; + a_desc_ = CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda, + p.stride_a, p.batch_count); + b_desc_ = CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb, + p.stride_b, p.batch_count); } bool CUDABlasLtMatmulPlan::SetBiasPointer(const void* bias) const { @@ -3395,18 +3384,10 @@ UniqueMatmulPreference CreateCublasLtMatmulPreference( #endif // CUDA_VERSION >= 11000 -std::unique_ptr<blas::IBlasLtMatmulPlan> -CUDABlas::CreateBlasLtMatmulPlanStridedBatched( - blas::DataType ab_type, blas::DataType cd_type, - blas::ComputationType computation_type, blas::PointerMode pointer_mode, - blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb, - uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a, - int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) { +std::unique_ptr<blas::IBlasLtMatmulPlan> CUDABlas::CreateBlasLtMatmulPlan( + const blas::BlasLtMatmulPlanParams& params) { #if CUDA_VERSION >= 11000 - auto result = std::make_unique<CUDABlasLtMatmulPlan>( - ab_type, cd_type, computation_type, pointer_mode, epilogue, transa, - transb, m, n, k, batch_count, lda, stride_a, ldb, stride_b, ldc, stride_c, - ldc, stride_c); + auto result = std::make_unique<CUDABlasLtMatmulPlan>(params); if (!result->ok()) { result.reset(); } diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc index d75c1bc65c5..d40b6adc285 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.cc +++ b/tensorflow/stream_executor/stream_executor_pimpl.cc @@ -337,34 +337,12 @@ bool StreamExecutor::GetBlasGemmAlgorithms( } std::unique_ptr<blas::IBlasLtMatmulPlan> StreamExecutor::CreateBlasLtMatmulPlan( - blas::DataType ab_type, blas::DataType cd_type, - blas::ComputationType computation_type, blas::PointerMode pointer_mode, - blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb, - uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc) { + const 
blas::BlasLtMatmulPlanParams& params) { blas::BlasSupport *blas_support = AsBlas(); if (!blas_support) { return nullptr; } - return blas_support->CreateBlasLtMatmulPlan( - ab_type, cd_type, computation_type, pointer_mode, epilogue, transa, - transb, m, n, k, lda, ldb, ldc); -} - -std::unique_ptr<blas::IBlasLtMatmulPlan> -StreamExecutor::CreateBlasLtMatmulPlanStridedBatched( - blas::DataType ab_type, blas::DataType cd_type, - blas::ComputationType computation_type, blas::PointerMode pointer_mode, - blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb, - uint64 m, uint64 n, uint64 k, uint64 batch_count, int64 lda, int64 stride_a, - int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) { - blas::BlasSupport *blas_support = AsBlas(); - if (!blas_support) { - return nullptr; - } - return blas_support->CreateBlasLtMatmulPlanStridedBatched( - ab_type, cd_type, computation_type, pointer_mode, epilogue, transa, - transb, m, n, k, batch_count, lda, stride_a, ldb, stride_b, ldc, - stride_c); + return blas_support->CreateBlasLtMatmulPlan(params); } bool StreamExecutor::GetBlasLtMatmulAlgorithms( diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h index b40c0c23c05..ce801bf0f28 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.h +++ b/tensorflow/stream_executor/stream_executor_pimpl.h @@ -399,19 +399,7 @@ class StreamExecutor { // created once and reused for multiple calls to DoBlasLtMatmul(). // Returns a null pointer on failure. std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan( - blas::DataType ab_type, blas::DataType cd_type, - blas::ComputationType computation_type, blas::PointerMode pointer_mode, - blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb, - uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc); - - // A more general version of CreateBlasLtMatmulPlan supporting - // batched operations. 
- std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlanStridedBatched( - blas::DataType ab_type, blas::DataType cd_type, - blas::ComputationType computation_type, blas::PointerMode pointer_mode, - blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb, - uint64 m, uint64 n, uint64 k, uint64 batch_count, int64 lda, - int64 stride_a, int64 ldb, int64 stride_b, int64 ldc, int64 stride_c); + const blas::BlasLtMatmulPlanParams& params); // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are // returned in the order of increasing estimated compute time according to an