diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/batch_matmul_op_impl.h
index 456b4beff1e..ac5a45b99ba 100644
--- a/tensorflow/core/kernels/batch_matmul_op_impl.h
+++ b/tensorflow/core/kernels/batch_matmul_op_impl.h
@@ -555,23 +555,23 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
             GetBlasComputationType(dtype, allow_tf32, &computation_type),
             errors::Internal("Unsupported dtype for batched matmul"));
         std::unique_ptr<se::blas::IBlasLtMatmulPlan> plan =
-            stream->parent()->CreateBlasLtMatmulPlanStridedBatched(
-                /*ab_type=*/blas_dtype,
-                /*cd_type=*/blas_dtype, computation_type,
-                se::blas::PointerMode::kHost, se::blas::Epilogue::kDefault,
-                blas_transpose_b, blas_transpose_a, n, m, k, batch_size,
-                /*lda=*/in_y.dim_size(2), b_stride,
-                /*ldb=*/in_x.dim_size(2), a_stride, /*ldc=*/n, c_stride);
+            stream->parent()->CreateBlasLtMatmulPlan(
+                {/*ab_type=*/blas_dtype,
+                 /*c_type=*/blas_dtype, computation_type,
+                 se::blas::PointerMode::kHost, se::blas::Epilogue::kDefault,
+                 blas_transpose_b, blas_transpose_a, n, m, k,
+                 /*lda=*/in_y.dim_size(2), /*ldb=*/in_x.dim_size(2), /*ldc=*/n,
+                 batch_size, b_stride, a_stride, c_stride});
         OP_REQUIRES(
             context, plan,
-            errors::Internal(
-                "CreateBlasLtMatmulPlanStridedBatched failed : a.shape=(",
-                in_x.dim_size(0), ", ", in_x.dim_size(1), ", ",
-                in_x.dim_size(2), "), b.shape=(", in_y.dim_size(0), ", ",
-                in_y.dim_size(1), ", ", in_y.dim_size(2), "), m=", m, ", n=", n,
-                ", k=", k, ", batch_size=", batch_size, ", adjoint_a=", adj_x,
-                ", adjoint_b=", adj_x, ", dtype=", dtype,
-                ", computation_type=", computation_type));
+            errors::Internal("CreateBlasLtMatmulPlan failed : a.shape=(",
+                             in_x.dim_size(0), ", ", in_x.dim_size(1), ", ",
+                             in_x.dim_size(2), "), b.shape=(", in_y.dim_size(0),
+                             ", ", in_y.dim_size(1), ", ", in_y.dim_size(2),
+                             "), m=", m, ", n=", n, ", k=", k,
+                             ", batch_size=", batch_size, ", adjoint_a=", adj_x,
+                             ", adjoint_b=", adj_x, ", dtype=", dtype,
+                             ", computation_type=", computation_type));
         std::vector<std::unique_ptr<se::blas::IBlasLtMatmulAlgorithm>>
             algorithms;
         OP_REQUIRES(
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index ae5b4853d05..411f6f11275 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -242,6 +242,27 @@ struct IBlasLtMatmulAlgorithm {
   virtual size_t workspace_size() const = 0;
 };
 
+// Parameters for the CreateBlasLtMatmulPlan method.
+struct BlasLtMatmulPlanParams {
+  DataType ab_type;
+  DataType c_type;
+  ComputationType computation_type;
+  PointerMode pointer_mode;
+  Epilogue epilogue;
+  Transpose transa;
+  Transpose transb;
+  uint64 m;
+  uint64 n;
+  uint64 k;
+  int64 lda;
+  int64 ldb;
+  int64 ldc;
+  int batch_count = 1;
+  int64 stride_a = 0;
+  int64 stride_b = 0;
+  int64 stride_c = 0;
+};
+
 // BLAS support interface -- this can be derived from a GPU executor when the
 // underlying platform has an BLAS library implementation available. See
 // StreamExecutor::AsBlas().
@@ -1466,25 +1487,8 @@ class BlasSupport {
   // can then be passed to DoBlasLtMatmul(). When possible, plans should be
   // created once and reused for multiple calls to DoBlasLtMatmul().
   // Returns a null pointer on failure.
-  std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
-      blas::DataType ab_type, blas::DataType c_type,
-      blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
-      uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc) {
-    return CreateBlasLtMatmulPlanStridedBatched(
-        ab_type, c_type, computation_type, pointer_mode, epilogue, transa,
-        transb, m, n, k, 1, lda, 0, ldb, 0, ldc, 0);
-  }
-
-  // A more general version of CreateBlasLtMatmulPlan supporting
-  // batched operations.
-  virtual std::unique_ptr<blas::IBlasLtMatmulPlan>
-  CreateBlasLtMatmulPlanStridedBatched(
-      blas::DataType ab_type, blas::DataType c_type,
-      blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
-      uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
-      int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) = 0;
+  virtual std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
+      const blas::BlasLtMatmulPlanParams& params) = 0;
 
   // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
   // returned in the order of increasing estimated compute time according to an
@@ -2372,14 +2376,8 @@ class BlasSupport {
                   uint64 n, std::complex<double> alpha,                        \
                   const DeviceMemory<std::complex<double>> &a, int lda,        \
                   DeviceMemory<std::complex<double>> *b, int ldb) override;    \
-  std::unique_ptr<blas::IBlasLtMatmulPlan>                                     \
-  CreateBlasLtMatmulPlanStridedBatched(                                        \
-      blas::DataType ab_type, blas::DataType cd_type,                          \
-      blas::ComputationType computation_type, blas::PointerMode pointer_mode,  \
-      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb, \
-      uint64 m, uint64 n, uint64 k, int batch_count, int64 lda,                \
-      int64 stride_a, int64 ldb, int64 stride_b, int64 ldc, int64 stride_c)    \
-      override;                                                                \
+  std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(             \
+      const blas::BlasLtMatmulPlanParams& params) override;                    \
   bool GetBlasLtMatmulAlgorithms(                                              \
       const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size,          \
       int max_algorithm_count,                                                 \
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index 1d95b00ce7e..f2bc79e1c29 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -3231,13 +3231,7 @@ blas::ComputationType ToComputationType<std::complex<double>>() {
 
 class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
  public:
-  CUDABlasLtMatmulPlan(blas::DataType ab_type, blas::DataType cd_type,
-                       blas::ComputationType compute_type,
-                       blas::PointerMode pointer_mode, blas::Epilogue epilogue,
-                       blas::Transpose transa, blas::Transpose transb, uint64 m,
-                       uint64 n, uint64 k, int batch_count, int64 lda,
-                       int64 stride_a, int64 ldb, int64 stride_b, int64 ldc,
-                       int64 stride_c, int64 ldd, int64 stride_d);
+  CUDABlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams& params);
 
   cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); }
   cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); }
@@ -3280,39 +3274,34 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
 };
 
 CUDABlasLtMatmulPlan::CUDABlasLtMatmulPlan(
-    blas::DataType ab_type, blas::DataType cd_type,
-    blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-    blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
-    uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
-    int64 ldb, int64 stride_b, int64 ldc, int64 stride_c, int64 ldd,
-    int64 stride_d)
+    const blas::BlasLtMatmulPlanParams& p)
     : op_desc_(CreateCublasLtOperationDesc(
-          computation_type, GetScaleType(cd_type, computation_type),
-          pointer_mode, epilogue, transa, transb)),
+          p.computation_type, GetScaleType(p.c_type, p.computation_type),
+          p.pointer_mode, p.epilogue, p.transa, p.transb)),
       a_desc_(nullptr),
       b_desc_(nullptr),
-      c_desc_(
-          CreateCublasLtLayoutDesc(cd_type, m, n, ldc, stride_c, batch_count)),
-      d_desc_(
-          CreateCublasLtLayoutDesc(cd_type, m, n, ldd, stride_d, batch_count)),
-      ab_type_(ab_type),
-      cd_type_(cd_type),
-      scale_type_(GetScaleType(cd_type, computation_type)),
-      pointer_mode_(pointer_mode),
-      epilogue_(epilogue),
-      batch_count_(batch_count),
-      stride_a_(stride_a),
-      stride_b_(stride_b),
-      stride_c_(stride_c),
-      stride_d_(stride_d) {
-  uint64 rows_a = transa == blas::Transpose::kNoTranspose ? m : k;
-  uint64 cols_a = transa == blas::Transpose::kNoTranspose ? k : m;
-  uint64 rows_b = transb == blas::Transpose::kNoTranspose ? k : n;
-  uint64 cols_b = transb == blas::Transpose::kNoTranspose ? n : k;
-  a_desc_ = CreateCublasLtLayoutDesc(ab_type, rows_a, cols_a, lda, stride_a,
-                                     batch_count);
-  b_desc_ = CreateCublasLtLayoutDesc(ab_type, rows_b, cols_b, ldb, stride_b,
-                                     batch_count);
+      c_desc_(CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
+                                       p.batch_count)),
+      d_desc_(CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc, p.stride_c,
+                                       p.batch_count)),
+      ab_type_(p.ab_type),
+      cd_type_(p.c_type),
+      scale_type_(GetScaleType(p.c_type, p.computation_type)),
+      pointer_mode_(p.pointer_mode),
+      epilogue_(p.epilogue),
+      batch_count_(p.batch_count),
+      stride_a_(p.stride_a),
+      stride_b_(p.stride_b),
+      stride_c_(p.stride_c),
+      stride_d_(p.stride_c) {
+  uint64 rows_a = p.transa == blas::Transpose::kNoTranspose ? p.m : p.k;
+  uint64 cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m;
+  uint64 rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n;
+  uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k;
+  a_desc_ = CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda,
+                                     p.stride_a, p.batch_count);
+  b_desc_ = CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb,
+                                     p.stride_b, p.batch_count);
 }
 
 bool CUDABlasLtMatmulPlan::SetBiasPointer(const void* bias) const {
@@ -3395,18 +3384,10 @@ UniqueMatmulPreference CreateCublasLtMatmulPreference(
 
 #endif  // CUDA_VERSION >= 11000
 
-std::unique_ptr<blas::IBlasLtMatmulPlan>
-CUDABlas::CreateBlasLtMatmulPlanStridedBatched(
-    blas::DataType ab_type, blas::DataType cd_type,
-    blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-    blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
-    uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
-    int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) {
+std::unique_ptr<blas::IBlasLtMatmulPlan> CUDABlas::CreateBlasLtMatmulPlan(
+    const blas::BlasLtMatmulPlanParams& params) {
 #if CUDA_VERSION >= 11000
-  auto result = std::make_unique<CUDABlasLtMatmulPlan>(
-      ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
-      transb, m, n, k, batch_count, lda, stride_a, ldb, stride_b, ldc, stride_c,
-      ldc, stride_c);
+  auto result = std::make_unique<CUDABlasLtMatmulPlan>(params);
   if (!result->ok()) {
     result.reset();
   }
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index d75c1bc65c5..d40b6adc285 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -337,34 +337,12 @@ bool StreamExecutor::GetBlasGemmAlgorithms(
 }
 
 std::unique_ptr<blas::IBlasLtMatmulPlan> StreamExecutor::CreateBlasLtMatmulPlan(
-    blas::DataType ab_type, blas::DataType cd_type,
-    blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-    blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
-    uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc) {
+    const blas::BlasLtMatmulPlanParams& params) {
   blas::BlasSupport *blas_support = AsBlas();
   if (!blas_support) {
     return nullptr;
   }
-  return blas_support->CreateBlasLtMatmulPlan(
-      ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
-      transb, m, n, k, lda, ldb, ldc);
-}
-
-std::unique_ptr<blas::IBlasLtMatmulPlan>
-StreamExecutor::CreateBlasLtMatmulPlanStridedBatched(
-    blas::DataType ab_type, blas::DataType cd_type,
-    blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-    blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
-    uint64 m, uint64 n, uint64 k, uint64 batch_count, int64 lda, int64 stride_a,
-    int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) {
-  blas::BlasSupport *blas_support = AsBlas();
-  if (!blas_support) {
-    return nullptr;
-  }
-  return blas_support->CreateBlasLtMatmulPlanStridedBatched(
-      ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
-      transb, m, n, k, batch_count, lda, stride_a, ldb, stride_b, ldc,
-      stride_c);
+  return blas_support->CreateBlasLtMatmulPlan(params);
 }
 
 bool StreamExecutor::GetBlasLtMatmulAlgorithms(
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index b40c0c23c05..ce801bf0f28 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -399,19 +399,7 @@ class StreamExecutor {
   // created once and reused for multiple calls to DoBlasLtMatmul().
   // Returns a null pointer on failure.
   std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
-      blas::DataType ab_type, blas::DataType cd_type,
-      blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
-      uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc);
-
-  // A more general version of CreateBlasLtMatmulPlan supporting
-  // batched operations.
-  std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlanStridedBatched(
-      blas::DataType ab_type, blas::DataType cd_type,
-      blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
-      uint64 m, uint64 n, uint64 k, uint64 batch_count, int64 lda,
-      int64 stride_a, int64 ldb, int64 stride_b, int64 ldc, int64 stride_c);
+      const blas::BlasLtMatmulPlanParams& params);
 
   // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
   // returned in the order of increasing estimated compute time according to an