diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/batch_matmul_op_impl.h
index 5ca85c00835..456b4beff1e 100644
--- a/tensorflow/core/kernels/batch_matmul_op_impl.h
+++ b/tensorflow/core/kernels/batch_matmul_op_impl.h
@@ -558,8 +558,8 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
             stream->parent()->CreateBlasLtMatmulPlanStridedBatched(
                 /*ab_type=*/blas_dtype,
                 /*cd_type=*/blas_dtype, computation_type,
-                se::blas::PointerMode::kHost, blas_transpose_b,
-                blas_transpose_a, n, m, k, batch_size,
+                se::blas::PointerMode::kHost, se::blas::Epilogue::kDefault,
+                blas_transpose_b, blas_transpose_a, n, m, k, batch_size,
                 /*lda=*/in_y.dim_size(2), b_stride,
                 /*ldb=*/in_x.dim_size(2), a_stride, /*ldc=*/n, c_stride);
         OP_REQUIRES(
@@ -621,7 +621,8 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
               stream
                   ->ThenBlasLtMatmul(plan.get(), alpha, *b_ptrs[0], *a_ptrs[0],
                                      beta, c_ptrs[0], &scratch_allocator,
-                                     profile_algorithm.get(), &profile_result)
+                                     profile_algorithm.get(), {},
+                                     &profile_result)
                   .ok();
 
           VLOG(4) << "  Autotune algorithm " << i
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index 583fba2a505..ae5b4853d05 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -107,6 +107,13 @@ enum class ComputationType {
   kF32FastBF16, // 32-bit floating-point with reduced (7-bit) mantissa
 };
 
+enum class Epilogue {
+  kDefault = 1,                   // No special postprocessing
+  kReLU = 2,                      // Apply ReLU func point-wise to the results
+  kBias = 4,                      // Add broadcasted bias vector to the results
+  kBiasThenReLU = kBias | kReLU,  // Apply bias and then ReLU transform
+};
+
 // Converts a ComputationType to a string.
 std::string ComputationTypeString(ComputationType ty);
 
@@ -1462,11 +1469,11 @@ class BlasSupport {
   std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
       blas::DataType ab_type, blas::DataType c_type,
       blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
-      uint64 k, int64 lda, int64 ldb, int64 ldc) {
+      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
+      uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc) {
     return CreateBlasLtMatmulPlanStridedBatched(
-        ab_type, c_type, computation_type, pointer_mode, transa, transb, m, n,
-        k, 1, lda, 0, ldb, 0, ldc, 0);
+        ab_type, c_type, computation_type, pointer_mode, epilogue, transa,
+        transb, m, n, k, 1, lda, 0, ldb, 0, ldc, 0);
   }
 
   // A more general version of CreateBlasLtMatmulPlan supporting
@@ -1475,9 +1482,9 @@ class BlasSupport {
   CreateBlasLtMatmulPlanStridedBatched(
       blas::DataType ab_type, blas::DataType c_type,
       blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
-      uint64 k, int batch_count, int64 lda, int64 stride_a, int64 ldb,
-      int64 stride_b, int64 ldc, int64 stride_c) = 0;
+      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
+      uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
+      int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) = 0;
 
   // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
   // returned in the order of increasing estimated compute time according to an
@@ -1492,13 +1499,18 @@ class BlasSupport {
   // Executes a blaslt matmul operation on the stream. If output_profile_result
   // is not nullptr, the operation is profiled, error messages are
   // suppressed, and output_profile_result->algorithm() is set to
-  // algorithm->index().
+  // algorithm->index(). If epilogue was set to kBias or kBiasThenReLU when
+  // creating the plan, the bias argument here must refer to a valid device
+  // vector of length equal to the number of rows in matrix c. If epilogue was
+  // set to any other value then the bias argument here must be null. The bias
+  // vector is broadcast across the batch dimension.
   virtual bool DoBlasLtMatmul(
       Stream* stream, const blas::IBlasLtMatmulPlan* plan,
       const HostOrDeviceScalar<int32>& alpha, const DeviceMemory<int8>& a,
       const DeviceMemory<int8>& b, const HostOrDeviceScalar<int32>& beta,
       DeviceMemory<int32>* c, ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<int32>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr) = 0;
   virtual bool DoBlasLtMatmul(
       Stream* stream, const blas::IBlasLtMatmulPlan* plan,
@@ -1507,6 +1519,7 @@ class BlasSupport {
       const HostOrDeviceScalar<Eigen::half>& beta, DeviceMemory<Eigen::half>* c,
       ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<Eigen::half>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr) = 0;
   virtual bool DoBlasLtMatmul(
       Stream* stream, const blas::IBlasLtMatmulPlan* plan,
@@ -1514,6 +1527,7 @@ class BlasSupport {
       const DeviceMemory<float>& b, const HostOrDeviceScalar<float>& beta,
       DeviceMemory<float>* c, ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<float>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr) = 0;
   virtual bool DoBlasLtMatmul(
       Stream* stream, const blas::IBlasLtMatmulPlan* plan,
@@ -1521,6 +1535,7 @@ class BlasSupport {
       const DeviceMemory<double>& b, const HostOrDeviceScalar<double>& beta,
       DeviceMemory<double>* c, ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<double>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr) = 0;
   virtual bool DoBlasLtMatmul(
       Stream* stream, const blas::IBlasLtMatmulPlan* plan,
@@ -1530,6 +1545,7 @@ class BlasSupport {
       const HostOrDeviceScalar<std::complex<float>>& beta,
       DeviceMemory<std::complex<float>>* c, ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<std::complex<float>>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr) = 0;
   virtual bool DoBlasLtMatmul(
       Stream* stream, const blas::IBlasLtMatmulPlan* plan,
@@ -1540,6 +1556,7 @@ class BlasSupport {
       DeviceMemory<std::complex<double>>* c,
       ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<std::complex<double>>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr) = 0;
 
   virtual port::Status GetVersion(std::string *version) = 0;
@@ -2359,9 +2376,10 @@ class BlasSupport {
   CreateBlasLtMatmulPlanStridedBatched(                                        \
       blas::DataType ab_type, blas::DataType cd_type,                          \
       blas::ComputationType computation_type, blas::PointerMode pointer_mode,  \
-      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,      \
-      uint64 k, int batch_count, int64 lda, int64 stride_a, int64 ldb,         \
-      int64 stride_b, int64 ldc, int64 stride_c) override;                     \
+      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb, \
+      uint64 m, uint64 n, uint64 k, int batch_count, int64 lda,                \
+      int64 stride_a, int64 ldb, int64 stride_b, int64 ldc, int64 stride_c)    \
+      override;                                                                \
   bool GetBlasLtMatmulAlgorithms(                                              \
       const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size,          \
       int max_algorithm_count,                                                 \
@@ -2373,6 +2391,7 @@ class BlasSupport {
       const DeviceMemory<int8>& b, const HostOrDeviceScalar<int32>& beta,      \
       DeviceMemory<int32>* c, ScratchAllocator* scratch_allocator,             \
       const blas::IBlasLtMatmulAlgorithm* algorithm,                           \
+      const DeviceMemory<int32>& bias = {},                                    \
       blas::ProfileResult* output_profile_result = nullptr) override;          \
   bool DoBlasLtMatmul(                                                         \
       Stream* stream, const blas::IBlasLtMatmulPlan* plan,                     \
@@ -2381,21 +2400,24 @@ class BlasSupport {
       const HostOrDeviceScalar<Eigen::half>& beta,                             \
       DeviceMemory<Eigen::half>* c, ScratchAllocator* scratch_allocator,       \
       const blas::IBlasLtMatmulAlgorithm* algorithm,                           \
-      blas::ProfileResult* output_profile_result) override;                    \
+      const DeviceMemory<Eigen::half>& bias = {},                              \
+      blas::ProfileResult* output_profile_result = nullptr) override;          \
   bool DoBlasLtMatmul(                                                         \
       Stream* stream, const blas::IBlasLtMatmulPlan* plan,                     \
       const HostOrDeviceScalar<float>& alpha, const DeviceMemory<float>& a,    \
       const DeviceMemory<float>& b, const HostOrDeviceScalar<float>& beta,     \
       DeviceMemory<float>* c, ScratchAllocator* scratch_allocator,             \
       const blas::IBlasLtMatmulAlgorithm* algorithm,                           \
-      blas::ProfileResult* output_profile_result) override;                    \
+      const DeviceMemory<float>& bias = {},                                    \
+      blas::ProfileResult* output_profile_result = nullptr) override;          \
   bool DoBlasLtMatmul(                                                         \
       Stream* stream, const blas::IBlasLtMatmulPlan* plan,                     \
       const HostOrDeviceScalar<double>& alpha, const DeviceMemory<double>& a,  \
       const DeviceMemory<double>& b, const HostOrDeviceScalar<double>& beta,   \
       DeviceMemory<double>* c, ScratchAllocator* scratch_allocator,            \
       const blas::IBlasLtMatmulAlgorithm* algorithm,                           \
-      blas::ProfileResult* output_profile_result) override;                    \
+      const DeviceMemory<double>& bias = {},                                   \
+      blas::ProfileResult* output_profile_result = nullptr) override;          \
   bool DoBlasLtMatmul(Stream* stream, const blas::IBlasLtMatmulPlan* plan,     \
                       const HostOrDeviceScalar<std::complex<float>>& alpha,    \
                       const DeviceMemory<std::complex<float>>& a,              \
@@ -2404,7 +2426,9 @@ class BlasSupport {
                       DeviceMemory<std::complex<float>>* c,                    \
                       ScratchAllocator* scratch_allocator,                     \
                       const blas::IBlasLtMatmulAlgorithm* algorithm,           \
-                      blas::ProfileResult* output_profile_result) override;    \
+                      const DeviceMemory<std::complex<float>>& bias = {},      \
+                      blas::ProfileResult* output_profile_result = nullptr)    \
+      override;                                                                \
   bool DoBlasLtMatmul(Stream* stream, const blas::IBlasLtMatmulPlan* plan,     \
                       const HostOrDeviceScalar<std::complex<double>>& alpha,   \
                       const DeviceMemory<std::complex<double>>& a,             \
@@ -2413,7 +2437,9 @@ class BlasSupport {
                       DeviceMemory<std::complex<double>>* c,                   \
                       ScratchAllocator* scratch_allocator,                     \
                       const blas::IBlasLtMatmulAlgorithm* algorithm,           \
-                      blas::ProfileResult* output_profile_result) override;    \
+                      const DeviceMemory<std::complex<double>>& bias = {},     \
+                      blas::ProfileResult* output_profile_result = nullptr)    \
+      override;                                                                \
   port::Status GetVersion(std::string *version) override;
 
 }  // namespace blas
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index ba833e562e2..1d95b00ce7e 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -468,6 +468,18 @@ cublasLtPointerMode_t CUBLASPointerMode(blas::PointerMode pointer_mode) {
     return CUBLASLT_POINTER_MODE_DEVICE;
   }
 }
+cublasLtEpilogue_t CUBLASEpilogue(blas::Epilogue epilogue) {
+  switch (epilogue) {
+  case blas::Epilogue::kDefault:
+    return CUBLASLT_EPILOGUE_DEFAULT;
+  case blas::Epilogue::kReLU:
+    return CUBLASLT_EPILOGUE_RELU;
+  case blas::Epilogue::kBias:
+    return CUBLASLT_EPILOGUE_BIAS;
+  case blas::Epilogue::kBiasThenReLU:
+    return CUBLASLT_EPILOGUE_RELU_BIAS;
+  }
+}
 #endif  // CUDA_VERSION >= 11000
 
 cudaDataType_t GetCUDADataType(blas::DataType ty) {
@@ -3135,12 +3147,12 @@ using UniqueMatmulPreference =
     std::unique_ptr<std::remove_pointer<cublasLtMatmulPreference_t>::type,
                     MatmulPreferenceDestroyer>;
 
-UniqueOpDesc CreateCublasLtOperationDesc(
-    blas::ComputationType computation_type, blas::DataType scale_type,
-    blas::PointerMode pointer_mode, blas::Transpose transa,
-    blas::Transpose transb) {
-  cublasOperation_t cuda_transa = CUDABlasTranspose(transa);
-  cublasOperation_t cuda_transb = CUDABlasTranspose(transb);
+UniqueOpDesc CreateCublasLtOperationDesc(blas::ComputationType computation_type,
+                                         blas::DataType scale_type,
+                                         blas::PointerMode pointer_mode,
+                                         blas::Epilogue epilogue,
+                                         blas::Transpose transa,
+                                         blas::Transpose transb) {
   cublasLtMatmulDesc_t desc;
   cublasComputeType_t cublas_compute_type =
       CUBLASComputationType(computation_type);
@@ -3154,9 +3166,13 @@ UniqueOpDesc CreateCublasLtOperationDesc(
   }
   UniqueOpDesc unique_desc(desc);
   if (!SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_POINTER_MODE,
-                         CUBLASPointerMode(pointer_mode)) ||
-      !SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSA, cuda_transa) ||
-      !SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSB, cuda_transb)) {
+                       CUBLASPointerMode(pointer_mode)) ||
+      !SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
+                       CUBLASEpilogue(epilogue)) ||
+      !SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSA,
+                       CUDABlasTranspose(transa)) ||
+      !SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSB,
+                       CUDABlasTranspose(transb))) {
     return nullptr;
   }
   return unique_desc;
@@ -3217,11 +3233,11 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
  public:
   CUDABlasLtMatmulPlan(blas::DataType ab_type, blas::DataType cd_type,
                        blas::ComputationType compute_type,
-                       blas::PointerMode pointer_mode, blas::Transpose transa,
-                       blas::Transpose transb, uint64 m, uint64 n, uint64 k,
-                       int batch_count, int64 lda, int64 stride_a, int64 ldb,
-                       int64 stride_b, int64 ldc, int64 stride_c, int64 ldd,
-                       int64 stride_d);
+                       blas::PointerMode pointer_mode, blas::Epilogue epilogue,
+                       blas::Transpose transa, blas::Transpose transb, uint64 m,
+                       uint64 n, uint64 k, int batch_count, int64 lda,
+                       int64 stride_a, int64 ldb, int64 stride_b, int64 ldc,
+                       int64 stride_c, int64 ldd, int64 stride_d);
 
   cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); }
   cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); }
@@ -3234,12 +3250,17 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
   blas::DataType cd_type() const { return cd_type_; }
   blas::DataType scale_type() const { return scale_type_; }
   blas::PointerMode pointer_mode() const { return pointer_mode_; }
+  blas::Epilogue epilogue() const { return epilogue_; }
   int batch_count() const { return batch_count_; }
   int64 stride_a() const { return stride_a_; }
   int64 stride_b() const { return stride_b_; }
   int64 stride_c() const { return stride_c_; }
   int64 stride_d() const { return stride_d_; }
 
+  // Note: Must be const to satisfy API. This is always called before the plan
+  // is executed, so the state change is not observed in subsequent executions.
+  bool SetBiasPointer(const void* bias) const;
+
  private:
   UniqueOpDesc op_desc_;
   UniqueLayoutDesc a_desc_;
@@ -3250,6 +3271,7 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
   blas::DataType cd_type_;
   blas::DataType scale_type_;
   blas::PointerMode pointer_mode_;
+  blas::Epilogue epilogue_;
   int batch_count_;
   int64 stride_a_;
   int64 stride_b_;
@@ -3260,12 +3282,13 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
 CUDABlasLtMatmulPlan::CUDABlasLtMatmulPlan(
     blas::DataType ab_type, blas::DataType cd_type,
     blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
-    uint64 k, int batch_count, int64 lda, int64 stride_a, int64 ldb,
-    int64 stride_b, int64 ldc, int64 stride_c, int64 ldd, int64 stride_d)
+    blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
+    uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
+    int64 ldb, int64 stride_b, int64 ldc, int64 stride_c, int64 ldd,
+    int64 stride_d)
     : op_desc_(CreateCublasLtOperationDesc(
           computation_type, GetScaleType(cd_type, computation_type),
-          pointer_mode, transa, transb)),
+          pointer_mode, epilogue, transa, transb)),
       a_desc_(nullptr),
       b_desc_(nullptr),
       c_desc_(
@@ -3276,6 +3299,7 @@ CUDABlasLtMatmulPlan::CUDABlasLtMatmulPlan(
       cd_type_(cd_type),
       scale_type_(GetScaleType(cd_type, computation_type)),
       pointer_mode_(pointer_mode),
+      epilogue_(epilogue),
       batch_count_(batch_count),
       stride_a_(stride_a),
       stride_b_(stride_b),
@@ -3291,6 +3315,11 @@ CUDABlasLtMatmulPlan::CUDABlasLtMatmulPlan(
                                      batch_count);
 }
 
+bool CUDABlasLtMatmulPlan::SetBiasPointer(const void* bias) const {
+  return SetCublasLtAttr(op_desc_.get(), CUBLASLT_MATMUL_DESC_BIAS_POINTER,
+                         bias);
+}
+
 class CUDABlasLtMatmulAlgorithm final : public blas::IBlasLtMatmulAlgorithm {
  public:
   CUDABlasLtMatmulAlgorithm(blas::AlgorithmType index,
@@ -3370,13 +3399,14 @@ std::unique_ptr<blas::IBlasLtMatmulPlan>
 CUDABlas::CreateBlasLtMatmulPlanStridedBatched(
     blas::DataType ab_type, blas::DataType cd_type,
     blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
-    uint64 k, int batch_count, int64 lda, int64 stride_a, int64 ldb,
-    int64 stride_b, int64 ldc, int64 stride_c) {
+    blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
+    uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
+    int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) {
 #if CUDA_VERSION >= 11000
   auto result = std::make_unique<CUDABlasLtMatmulPlan>(
-      ab_type, cd_type, computation_type, pointer_mode, transa, transb, m, n, k,
-      batch_count, lda, stride_a, ldb, stride_b, ldc, stride_c, ldc, stride_c);
+      ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
+      transb, m, n, k, batch_count, lda, stride_a, ldb, stride_b, ldc, stride_c,
+      ldc, stride_c);
   if (!result->ok()) {
     result.reset();
   }
@@ -3436,7 +3466,8 @@ bool CUDABlas::DoBlasLtMatmulInternalImpl(
     const HostOrDeviceScalar<ScaleType>& alpha, const ABType* a,
     const ABType* b, const HostOrDeviceScalar<ScaleType>& beta, const CDType* c,
     CDType* d, ScratchAllocator* scratch_allocator,
-    const blas::IBlasLtMatmulAlgorithm* algorithm) {
+    const blas::IBlasLtMatmulAlgorithm* algorithm,
+    const CDType* bias) {
   const auto& cuda_plan = *static_cast<const CUDABlasLtMatmulPlan*>(plan);
   const auto& cuda_algo =
       *static_cast<const CUDABlasLtMatmulAlgorithm*>(algorithm);
@@ -3474,6 +3505,20 @@ bool CUDABlas::DoBlasLtMatmulInternalImpl(
                "pointer_mode for the given alpha/beta.";
     return false;
   }
+  if ((cuda_plan.epilogue() == blas::Epilogue::kBias ||
+       cuda_plan.epilogue() == blas::Epilogue::kBiasThenReLU) !=
+      (bias != nullptr)) {
+    VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong "
+               "epilogue for the given bias pointer.";
+    return false;
+  }
+  if (bias != nullptr) {
+    if (!cuda_plan.SetBiasPointer(bias)) {
+      VLOG(2) << "DoBlasLtMatmul returning false because setting the bias "
+                 "pointer failed.";
+      return false;
+    }
+  }
   const ScaleType* alpha_ptr =
       alpha.is_pointer() ? GpuMemory(alpha.pointer()) : &alpha.value();
   const ScaleType* beta_ptr =
@@ -3525,6 +3570,7 @@ bool CUDABlas::DoBlasLtMatmulInternal(
     const DeviceMemory<CDType>& c, DeviceMemory<CDType>* d,
     ScratchAllocator* scratch_allocator,
     const blas::IBlasLtMatmulAlgorithm* algorithm,
+    const DeviceMemory<CDType>& bias,
     blas::ProfileResult* output_profile_result) {
 #if CUDA_VERSION >= 11000
   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
@@ -3538,7 +3584,8 @@ bool CUDABlas::DoBlasLtMatmulInternal(
   bool err_on_failure = timer != nullptr;
   bool result = DoBlasLtMatmulInternalImpl(
       stream, err_on_failure, plan, alpha, GpuMemory(a), GpuMemory(b), beta,
-      GpuMemory(c), GpuMemoryMutable(d), scratch_allocator, algorithm);
+      GpuMemory(c), GpuMemoryMutable(d), scratch_allocator, algorithm,
+      GpuMemory(bias));
 
   if (timer && result) {
     // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
@@ -3563,9 +3610,10 @@ bool CUDABlas::DoBlasLtMatmul(
     const DeviceMemory<int8>& b, const HostOrDeviceScalar<int32>& beta,
     DeviceMemory<int32>* c, ScratchAllocator* scratch_allocator,
     const blas::IBlasLtMatmulAlgorithm* algorithm,
+    const DeviceMemory<int32>& bias,
     blas::ProfileResult* output_profile_result) {
   return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
-                                scratch_allocator, algorithm,
+                                scratch_allocator, algorithm, bias,
                                 output_profile_result);
 }
 
@@ -3578,6 +3626,7 @@ bool CUDABlas::DoBlasLtMatmul(Stream* stream,
                               DeviceMemory<Eigen::half>* c,
                               ScratchAllocator* scratch_allocator,
                               const blas::IBlasLtMatmulAlgorithm* algorithm,
+                              const DeviceMemory<Eigen::half>& bias,
                               blas::ProfileResult* output_profile_result) {
 #if CUDA_VERSION >= 11000
   const auto& cuda_plan = *static_cast<const CUDABlasLtMatmulPlan*>(plan);
@@ -3591,11 +3640,11 @@ bool CUDABlas::DoBlasLtMatmul(Stream* stream,
     HostOrDeviceScalar<float> float_alpha(static_cast<float>(alpha.value()));
     HostOrDeviceScalar<float> float_beta(static_cast<float>(beta.value()));
     return DoBlasLtMatmulInternal(stream, plan, float_alpha, a, b, float_beta,
-                                  *c, c, scratch_allocator, algorithm,
+                                  *c, c, scratch_allocator, algorithm, bias,
                                   output_profile_result);
   }
   return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
-                                scratch_allocator, algorithm,
+                                scratch_allocator, algorithm, bias,
                                 output_profile_result);
 #else  // if CUDA_VERSION < 11000
   return false;
@@ -3608,9 +3657,10 @@ bool CUDABlas::DoBlasLtMatmul(
     const DeviceMemory<float>& b, const HostOrDeviceScalar<float>& beta,
     DeviceMemory<float>* c, ScratchAllocator* scratch_allocator,
     const blas::IBlasLtMatmulAlgorithm* algorithm,
+    const DeviceMemory<float>& bias,
     blas::ProfileResult* output_profile_result) {
   return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
-                                scratch_allocator, algorithm,
+                                scratch_allocator, algorithm, bias,
                                 output_profile_result);
 }
 
@@ -3620,9 +3670,10 @@ bool CUDABlas::DoBlasLtMatmul(
     const DeviceMemory<double>& b, const HostOrDeviceScalar<double>& beta,
     DeviceMemory<double>* c, ScratchAllocator* scratch_allocator,
     const blas::IBlasLtMatmulAlgorithm* algorithm,
+    const DeviceMemory<double>& bias,
     blas::ProfileResult* output_profile_result) {
   return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
-                                scratch_allocator, algorithm,
+                                scratch_allocator, algorithm, bias,
                                 output_profile_result);
 }
 
@@ -3634,9 +3685,10 @@ bool CUDABlas::DoBlasLtMatmul(
     const HostOrDeviceScalar<std::complex<float>>& beta,
     DeviceMemory<std::complex<float>>* c, ScratchAllocator* scratch_allocator,
     const blas::IBlasLtMatmulAlgorithm* algorithm,
+    const DeviceMemory<std::complex<float>>& bias,
     blas::ProfileResult* output_profile_result) {
   return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
-                                scratch_allocator, algorithm,
+                                scratch_allocator, algorithm, bias,
                                 output_profile_result);
 }
 
@@ -3648,9 +3700,10 @@ bool CUDABlas::DoBlasLtMatmul(
     const HostOrDeviceScalar<std::complex<double>>& beta,
     DeviceMemory<std::complex<double>>* c, ScratchAllocator* scratch_allocator,
     const blas::IBlasLtMatmulAlgorithm* algorithm,
+    const DeviceMemory<std::complex<double>>& bias,
     blas::ProfileResult* output_profile_result) {
   return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
-                                scratch_allocator, algorithm,
+                                scratch_allocator, algorithm, bias,
                                 output_profile_result);
 }
 
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.h b/tensorflow/stream_executor/cuda/cuda_blas.h
index 351a7778c01..3fdfcb0a50c 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.h
+++ b/tensorflow/stream_executor/cuda/cuda_blas.h
@@ -148,6 +148,7 @@ class CUDABlas : public blas::BlasSupport {
       const DeviceMemory<CDType>& c, DeviceMemory<CDType>* d,
       ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<CDType>& bias,
       blas::ProfileResult* output_profile_result);
 
   // Helper function for implementing DoBlasLtMatmulInternal.
@@ -157,7 +158,7 @@ class CUDABlas : public blas::BlasSupport {
       const HostOrDeviceScalar<ScaleType>& alpha, const ABType* a,
       const ABType* b, const HostOrDeviceScalar<ScaleType>& beta,
       const CDType* c, CDType* d, ScratchAllocator* scratch_allocator,
-      const blas::IBlasLtMatmulAlgorithm* algorithm);
+      const blas::IBlasLtMatmulAlgorithm* algorithm, const CDType* bias);
 
   // Guards the cuBLAS handle for this device.
   absl::Mutex mu_;
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 144af92185c..66728c94821 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -4809,18 +4809,19 @@ Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
                                  DeviceMemory<int32>* c,
                                  ScratchAllocator* scratch_allocator,
                                  const blas::IBlasLtMatmulAlgorithm* algorithm,
+                                 const DeviceMemory<int32>& bias,
                                  blas::ProfileResult* output_profile_result) {
   VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
-            PARAM(c), PARAM(algorithm));
+            PARAM(c), PARAM(algorithm), PARAM(bias));
 
   ThenBlasWithProfileImpl<
       const blas::IBlasLtMatmulPlan*, const HostOrDeviceScalar<int32>&,
       const DeviceMemory<int8>&, const DeviceMemory<int8>&,
       const HostOrDeviceScalar<int32>&, DeviceMemory<int32>*, ScratchAllocator*,
-      const blas::IBlasLtMatmulAlgorithm*>
+      const blas::IBlasLtMatmulAlgorithm*, const DeviceMemory<int32>&>
       impl;
   return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
-              c, scratch_allocator, algorithm, output_profile_result);
+              c, scratch_allocator, algorithm, bias, output_profile_result);
 }
 
 Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
@@ -4831,18 +4832,20 @@ Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
                                  DeviceMemory<Eigen::half>* c,
                                  ScratchAllocator* scratch_allocator,
                                  const blas::IBlasLtMatmulAlgorithm* algorithm,
+                                 const DeviceMemory<Eigen::half>& bias,
                                  blas::ProfileResult* output_profile_result) {
   VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
-            PARAM(c), PARAM(algorithm));
+            PARAM(c), PARAM(algorithm), PARAM(bias));
 
   ThenBlasWithProfileImpl<
       const blas::IBlasLtMatmulPlan*, const HostOrDeviceScalar<Eigen::half>&,
       const DeviceMemory<Eigen::half>&, const DeviceMemory<Eigen::half>&,
       const HostOrDeviceScalar<Eigen::half>&, DeviceMemory<Eigen::half>*,
-      ScratchAllocator*, const blas::IBlasLtMatmulAlgorithm*>
+      ScratchAllocator*, const blas::IBlasLtMatmulAlgorithm*,
+      const DeviceMemory<Eigen::half>&>
       impl;
   return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
-              c, scratch_allocator, algorithm, output_profile_result);
+              c, scratch_allocator, algorithm, bias, output_profile_result);
 }
 
 Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
@@ -4853,18 +4856,19 @@ Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
                                  DeviceMemory<float>* c,
                                  ScratchAllocator* scratch_allocator,
                                  const blas::IBlasLtMatmulAlgorithm* algorithm,
+                                 const DeviceMemory<float>& bias,
                                  blas::ProfileResult* output_profile_result) {
   VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
-            PARAM(c), PARAM(algorithm));
+            PARAM(c), PARAM(algorithm), PARAM(bias));
 
   ThenBlasWithProfileImpl<
       const blas::IBlasLtMatmulPlan*, const HostOrDeviceScalar<float>&,
       const DeviceMemory<float>&, const DeviceMemory<float>&,
       const HostOrDeviceScalar<float>&, DeviceMemory<float>*, ScratchAllocator*,
-      const blas::IBlasLtMatmulAlgorithm*>
+      const blas::IBlasLtMatmulAlgorithm*, const DeviceMemory<float>&>
       impl;
   return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
-              c, scratch_allocator, algorithm, output_profile_result);
+              c, scratch_allocator, algorithm, bias, output_profile_result);
 }
 
 Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
@@ -4875,18 +4879,20 @@ Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
                                  DeviceMemory<double>* c,
                                  ScratchAllocator* scratch_allocator,
                                  const blas::IBlasLtMatmulAlgorithm* algorithm,
+                                 const DeviceMemory<double>& bias,
                                  blas::ProfileResult* output_profile_result) {
   VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
-            PARAM(c), PARAM(algorithm));
+            PARAM(c), PARAM(algorithm), PARAM(bias));
 
   ThenBlasWithProfileImpl<
       const blas::IBlasLtMatmulPlan*, const HostOrDeviceScalar<double>&,
       const DeviceMemory<double>&, const DeviceMemory<double>&,
       const HostOrDeviceScalar<double>&, DeviceMemory<double>*,
-      ScratchAllocator*, const blas::IBlasLtMatmulAlgorithm*>
+      ScratchAllocator*, const blas::IBlasLtMatmulAlgorithm*,
+      const DeviceMemory<double>&>
       impl;
   return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
-              c, scratch_allocator, algorithm, output_profile_result);
+              c, scratch_allocator, algorithm, bias, output_profile_result);
 }
 
 Stream& Stream::ThenBlasLtMatmul(
@@ -4897,9 +4903,10 @@ Stream& Stream::ThenBlasLtMatmul(
     const HostOrDeviceScalar<std::complex<float>>& beta,
     DeviceMemory<std::complex<float>>* c, ScratchAllocator* scratch_allocator,
     const blas::IBlasLtMatmulAlgorithm* algorithm,
+    const DeviceMemory<std::complex<float>>& bias,
     blas::ProfileResult* output_profile_result) {
   VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
-            PARAM(c), PARAM(algorithm));
+            PARAM(c), PARAM(algorithm), PARAM(bias));
 
   ThenBlasWithProfileImpl<const blas::IBlasLtMatmulPlan*,
                           const HostOrDeviceScalar<std::complex<float>>&,
@@ -4907,10 +4914,11 @@ Stream& Stream::ThenBlasLtMatmul(
                           const DeviceMemory<std::complex<float>>&,
                           const HostOrDeviceScalar<std::complex<float>>&,
                           DeviceMemory<std::complex<float>>*, ScratchAllocator*,
-                          const blas::IBlasLtMatmulAlgorithm*>
+                          const blas::IBlasLtMatmulAlgorithm*,
+                          const DeviceMemory<std::complex<float>>&>
       impl;
   return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
-              c, scratch_allocator, algorithm, output_profile_result);
+              c, scratch_allocator, algorithm, bias, output_profile_result);
 }
 
 Stream& Stream::ThenBlasLtMatmul(
@@ -4921,9 +4929,10 @@ Stream& Stream::ThenBlasLtMatmul(
     const HostOrDeviceScalar<std::complex<double>>& beta,
     DeviceMemory<std::complex<double>>* c, ScratchAllocator* scratch_allocator,
     const blas::IBlasLtMatmulAlgorithm* algorithm,
+    const DeviceMemory<std::complex<double>>& bias,
     blas::ProfileResult* output_profile_result) {
   VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
-            PARAM(c), PARAM(algorithm));
+            PARAM(c), PARAM(algorithm), PARAM(bias));
 
   ThenBlasWithProfileImpl<const blas::IBlasLtMatmulPlan*,
                           const HostOrDeviceScalar<std::complex<double>>&,
@@ -4932,10 +4941,11 @@ Stream& Stream::ThenBlasLtMatmul(
                           const HostOrDeviceScalar<std::complex<double>>&,
                           DeviceMemory<std::complex<double>>*,
                           ScratchAllocator*,
-                          const blas::IBlasLtMatmulAlgorithm*>
+                          const blas::IBlasLtMatmulAlgorithm*,
+                          const DeviceMemory<std::complex<double>>&>
       impl;
   return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
-              c, scratch_allocator, algorithm, output_profile_result);
+              c, scratch_allocator, algorithm, bias, output_profile_result);
 }
 
 Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index 15f5dfc936f..91a80331f8e 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -1672,6 +1672,7 @@ class Stream {
       const DeviceMemory<int8>& b, const HostOrDeviceScalar<int32>& beta,
       DeviceMemory<int32>* c, ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<int32>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr);
   Stream& ThenBlasLtMatmul(
       const blas::IBlasLtMatmulPlan* plan,
@@ -1680,6 +1681,7 @@ class Stream {
       const HostOrDeviceScalar<Eigen::half>& beta, DeviceMemory<Eigen::half>* c,
       ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<Eigen::half>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr);
   Stream& ThenBlasLtMatmul(
       const blas::IBlasLtMatmulPlan* plan,
@@ -1687,6 +1689,7 @@ class Stream {
       const DeviceMemory<float>& b, const HostOrDeviceScalar<float>& beta,
       DeviceMemory<float>* c, ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<float>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr);
   Stream& ThenBlasLtMatmul(
       const blas::IBlasLtMatmulPlan* plan,
@@ -1694,6 +1697,7 @@ class Stream {
       const DeviceMemory<double>& b, const HostOrDeviceScalar<double>& beta,
       DeviceMemory<double>* c, ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<double>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr);
   Stream& ThenBlasLtMatmul(
       const blas::IBlasLtMatmulPlan* plan,
@@ -1703,6 +1707,7 @@ class Stream {
       const HostOrDeviceScalar<std::complex<float>>& beta,
       DeviceMemory<std::complex<float>>* c, ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<std::complex<float>>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr);
   Stream& ThenBlasLtMatmul(
       const blas::IBlasLtMatmulPlan* plan,
@@ -1713,6 +1718,7 @@ class Stream {
       DeviceMemory<std::complex<double>>* c,
       ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<std::complex<double>>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr);
 
   // See FftSupport::DoFft.
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index 3fbbc3f2aac..d75c1bc65c5 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -339,31 +339,32 @@ bool StreamExecutor::GetBlasGemmAlgorithms(
 std::unique_ptr<blas::IBlasLtMatmulPlan> StreamExecutor::CreateBlasLtMatmulPlan(
     blas::DataType ab_type, blas::DataType cd_type,
     blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
-    uint64 k, int64 lda, int64 ldb, int64 ldc) {
+    blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
+    uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc) {
   blas::BlasSupport *blas_support = AsBlas();
   if (!blas_support) {
     return nullptr;
   }
   return blas_support->CreateBlasLtMatmulPlan(
-      ab_type, cd_type, computation_type, pointer_mode, transa, transb, m, n, k,
-      lda, ldb, ldc);
+      ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
+      transb, m, n, k, lda, ldb, ldc);
 }
 
 std::unique_ptr<blas::IBlasLtMatmulPlan>
 StreamExecutor::CreateBlasLtMatmulPlanStridedBatched(
     blas::DataType ab_type, blas::DataType cd_type,
     blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
-    uint64 k, uint64 batch_count, int64 lda, int64 stride_a, int64 ldb,
-    int64 stride_b, int64 ldc, int64 stride_c) {
+    blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
+    uint64 m, uint64 n, uint64 k, uint64 batch_count, int64 lda, int64 stride_a,
+    int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) {
   blas::BlasSupport *blas_support = AsBlas();
   if (!blas_support) {
     return nullptr;
   }
   return blas_support->CreateBlasLtMatmulPlanStridedBatched(
-      ab_type, cd_type, computation_type, pointer_mode, transa, transb, m, n, k,
-      batch_count, lda, stride_a, ldb, stride_b, ldc, stride_c);
+      ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
+      transb, m, n, k, batch_count, lda, stride_a, ldb, stride_b, ldc,
+      stride_c);
 }
 
 bool StreamExecutor::GetBlasLtMatmulAlgorithms(
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index 90137417250..b40c0c23c05 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -401,17 +401,17 @@ class StreamExecutor {
   std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
       blas::DataType ab_type, blas::DataType cd_type,
       blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
-      uint64 k, int64 lda, int64 ldb, int64 ldc);
+      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
+      uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc);
 
   // A more general version of CreateBlasLtMatmulPlan supporting
   // batched operations.
   std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlanStridedBatched(
       blas::DataType ab_type, blas::DataType cd_type,
       blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
-      uint64 k, uint64 batch_count, int64 lda, int64 stride_a, int64 ldb,
-      int64 stride_b, int64 ldc, int64 stride_c);
+      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
+      uint64 m, uint64 n, uint64 k, uint64 batch_count, int64 lda,
+      int64 stride_a, int64 ldb, int64 stride_b, int64 ldc, int64 stride_c);
 
   // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
   // returned in the order of increasing estimated compute time according to an