From 39bf03f083bc78812eaef8dc7e9b274110b923ee Mon Sep 17 00:00:00 2001
From: Ben Barsdell
Date: Wed, 5 Aug 2020 09:36:23 +1000
Subject: [PATCH] Add support for blasLt epilogue fn and bias vector

- Changes the backend APIs to allow an epilogue function (default, ReLU,
  bias, or bias then ReLU) to be specified and a bias vector to be
  provided.
- This is expected to be useful for XLA to perform fusions.
- This functionality is not currently tested, because the BatchMatMulOp
  does not expose relu/bias fusion.
---
 .../core/kernels/batch_matmul_op_impl.h       |   7 +-
 tensorflow/stream_executor/blas.h             |  58 ++++++---
 tensorflow/stream_executor/cuda/cuda_blas.cc  | 117 +++++++++++++-----
 tensorflow/stream_executor/cuda/cuda_blas.h   |   3 +-
 tensorflow/stream_executor/stream.cc          |  46 ++++---
 tensorflow/stream_executor/stream.h           |   6 +
 .../stream_executor/stream_executor_pimpl.cc  |  19 +--
 .../stream_executor/stream_executor_pimpl.h   |  10 +-
 8 files changed, 182 insertions(+), 84 deletions(-)

diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/batch_matmul_op_impl.h
index 5ca85c00835..456b4beff1e 100644
--- a/tensorflow/core/kernels/batch_matmul_op_impl.h
+++ b/tensorflow/core/kernels/batch_matmul_op_impl.h
@@ -558,8 +558,8 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
         stream->parent()->CreateBlasLtMatmulPlanStridedBatched(
             /*ab_type=*/blas_dtype, /*cd_type=*/blas_dtype, computation_type,
-            se::blas::PointerMode::kHost, blas_transpose_b,
-            blas_transpose_a, n, m, k, batch_size,
+            se::blas::PointerMode::kHost, se::blas::Epilogue::kDefault,
+            blas_transpose_b, blas_transpose_a, n, m, k, batch_size,
             /*lda=*/in_y.dim_size(2), b_stride, /*ldb=*/in_x.dim_size(2),
             a_stride, /*ldc=*/n, c_stride);
     OP_REQUIRES(
@@ -621,7 +621,8 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
             stream
                 ->ThenBlasLtMatmul(plan.get(), alpha, *b_ptrs[0], *a_ptrs[0],
                                    beta, c_ptrs[0], &scratch_allocator,
-                                   profile_algorithm.get(), &profile_result)
+                                   profile_algorithm.get(), {},
+                                   &profile_result)
                 .ok();
 
             VLOG(4) << "  Autotune algorithm " << i
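
For context before the stream_executor changes below: the `{}` threaded into
the autotuning call above is an empty DeviceMemory, which the new API reads as
"no bias". A minimal sketch of the same call with the defaulted bias argument
written out explicitly (the `no_bias` local is illustrative, not part of the
patch):

    // An empty DeviceMemory wraps a null device pointer; the plan created
    // above uses Epilogue::kDefault, which requires exactly that.
    se::DeviceMemory<Scalar> no_bias;
    stream
        ->ThenBlasLtMatmul(plan.get(), alpha, *b_ptrs[0], *a_ptrs[0], beta,
                           c_ptrs[0], &scratch_allocator,
                           profile_algorithm.get(), no_bias, &profile_result)
        .ok();
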
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index 583fba2a505..ae5b4853d05 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -107,6 +107,13 @@ enum class ComputationType {
   kF32FastBF16,  // 32-bit floating-point with reduced (7-bit) mantissa
 };
 
+enum class Epilogue {
+  kDefault = 1,                   // No special postprocessing
+  kReLU = 2,                      // Apply ReLU func point-wise to the results
+  kBias = 4,                      // Add broadcasted bias vector to the results
+  kBiasThenReLU = kBias | kReLU,  // Apply bias and then ReLU transform
+};
+
 // Converts a ComputationType to a string.
 std::string ComputationTypeString(ComputationType ty);
 
@@ -1462,11 +1469,11 @@ class BlasSupport {
   std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
       blas::DataType ab_type, blas::DataType c_type,
       blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
-      uint64 k, int64 lda, int64 ldb, int64 ldc) {
+      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
+      uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc) {
     return CreateBlasLtMatmulPlanStridedBatched(
-        ab_type, c_type, computation_type, pointer_mode, transa, transb, m, n,
-        k, 1, lda, 0, ldb, 0, ldc, 0);
+        ab_type, c_type, computation_type, pointer_mode, epilogue, transa,
+        transb, m, n, k, 1, lda, 0, ldb, 0, ldc, 0);
   }
 
   // A more general version of CreateBlasLtMatmulPlan supporting
@@ -1475,9 +1482,9 @@ class BlasSupport {
   CreateBlasLtMatmulPlanStridedBatched(
       blas::DataType ab_type, blas::DataType c_type,
       blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
-      uint64 k, int batch_count, int64 lda, int64 stride_a, int64 ldb,
-      int64 stride_b, int64 ldc, int64 stride_c) = 0;
+      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
+      uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
+      int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) = 0;
 
   // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
   // returned in the order of increasing estimated compute time according to an
@@ -1492,13 +1499,18 @@ class BlasSupport {
   // Executes a blaslt matmul operation on the stream. If output_profile_result
   // is not nullptr, the operation is profiled, error messages are
   // suppressed, and output_profile_result->algorithm() is set to
-  // algorithm->index().
+  // algorithm->index(). If epilogue was set to kBias or kBiasThenReLU when
+  // creating the plan, the bias argument here must refer to a valid device
+  // vector of length equal to the number of rows in matrix c. If epilogue was
+  // set to any other value then the bias argument here must be null. The bias
+  // vector is broadcast across the batch dimension.
   virtual bool DoBlasLtMatmul(
       Stream* stream, const blas::IBlasLtMatmulPlan* plan,
       const HostOrDeviceScalar<int>& alpha, const DeviceMemory<int8>& a,
       const DeviceMemory<int8>& b, const HostOrDeviceScalar<int>& beta,
       DeviceMemory<int32>* c, ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<int32>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr) = 0;
   virtual bool DoBlasLtMatmul(
       Stream* stream, const blas::IBlasLtMatmulPlan* plan,
@@ -1507,6 +1519,7 @@ class BlasSupport {
       const HostOrDeviceScalar<Eigen::half>& beta, DeviceMemory<Eigen::half>* c,
       ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<Eigen::half>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr) = 0;
   virtual bool DoBlasLtMatmul(
       Stream* stream, const blas::IBlasLtMatmulPlan* plan,
@@ -1514,6 +1527,7 @@ class BlasSupport {
       const DeviceMemory<float>& b, const HostOrDeviceScalar<float>& beta,
       DeviceMemory<float>* c, ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<float>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr) = 0;
   virtual bool DoBlasLtMatmul(
       Stream* stream, const blas::IBlasLtMatmulPlan* plan,
@@ -1521,6 +1535,7 @@ class BlasSupport {
       const DeviceMemory<double>& b, const HostOrDeviceScalar<double>& beta,
       DeviceMemory<double>* c, ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<double>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr) = 0;
   virtual bool DoBlasLtMatmul(
       Stream* stream, const blas::IBlasLtMatmulPlan* plan,
@@ -1530,6 +1545,7 @@ class BlasSupport {
       const HostOrDeviceScalar<std::complex<float>>& beta,
       DeviceMemory<std::complex<float>>* c, ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<std::complex<float>>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr) = 0;
   virtual bool DoBlasLtMatmul(
       Stream* stream, const blas::IBlasLtMatmulPlan* plan,
@@ -1540,6 +1556,7 @@ class BlasSupport {
       DeviceMemory<std::complex<double>>* c,
       ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<std::complex<double>>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr) = 0;
 
   virtual port::Status GetVersion(std::string *version) = 0;
@@ -2359,9 +2376,10 @@ class BlasSupport {
   CreateBlasLtMatmulPlanStridedBatched( \
       blas::DataType ab_type, blas::DataType cd_type, \
       blas::ComputationType computation_type, blas::PointerMode pointer_mode, \
-      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, \
-      uint64 k, int batch_count, int64 lda, int64 stride_a, int64 ldb, \
-      int64 stride_b, int64 ldc, int64 stride_c) override; \
+      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb, \
+      uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, \
+      int64 stride_a, int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) \
+      override; \
   bool GetBlasLtMatmulAlgorithms( \
       const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size, \
       int max_algorithm_count, \
@@ -2373,6 +2391,7 @@ class BlasSupport {
       const DeviceMemory<int8>& b, const HostOrDeviceScalar<int>& beta, \
      DeviceMemory<int32>* c, ScratchAllocator* scratch_allocator, \
       const blas::IBlasLtMatmulAlgorithm* algorithm, \
+      const DeviceMemory<int32>& bias = {}, \
       blas::ProfileResult* output_profile_result = nullptr) override; \
   bool DoBlasLtMatmul( \
       Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
       const HostOrDeviceScalar<Eigen::half>& alpha, \
       const DeviceMemory<Eigen::half>& a, \
@@ -2381,21 +2400,24 @@ class BlasSupport {
       const HostOrDeviceScalar<Eigen::half>& beta, \
       DeviceMemory<Eigen::half>* c, ScratchAllocator* scratch_allocator, \
       const blas::IBlasLtMatmulAlgorithm* algorithm, \
-      blas::ProfileResult* output_profile_result) override; \
+      const DeviceMemory<Eigen::half>& bias = {}, \
+      blas::ProfileResult* output_profile_result = nullptr) override; \
   bool DoBlasLtMatmul( \
       Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
       const HostOrDeviceScalar<float>& alpha, const DeviceMemory<float>& a, \
       const DeviceMemory<float>& b, const HostOrDeviceScalar<float>& beta, \
       DeviceMemory<float>* c, ScratchAllocator* scratch_allocator, \
       const blas::IBlasLtMatmulAlgorithm* algorithm, \
-      blas::ProfileResult* output_profile_result) override; \
+      const DeviceMemory<float>& bias = {}, \
+      blas::ProfileResult* output_profile_result = nullptr) override; \
   bool DoBlasLtMatmul( \
       Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
       const HostOrDeviceScalar<double>& alpha, const DeviceMemory<double>& a, \
       const DeviceMemory<double>& b, const HostOrDeviceScalar<double>& beta, \
       DeviceMemory<double>* c, ScratchAllocator* scratch_allocator, \
       const blas::IBlasLtMatmulAlgorithm* algorithm, \
-      blas::ProfileResult* output_profile_result) override; \
+      const DeviceMemory<double>& bias = {}, \
+      blas::ProfileResult* output_profile_result = nullptr) override; \
   bool DoBlasLtMatmul(Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
                       const HostOrDeviceScalar<std::complex<float>>& alpha, \
                       const DeviceMemory<std::complex<float>>& a, \
@@ -2404,7 +2426,9 @@ class BlasSupport {
                       DeviceMemory<std::complex<float>>* c, \
                       ScratchAllocator* scratch_allocator, \
                       const blas::IBlasLtMatmulAlgorithm* algorithm, \
-                      blas::ProfileResult* output_profile_result) override; \
+                      const DeviceMemory<std::complex<float>>& bias = {}, \
+                      blas::ProfileResult* output_profile_result = nullptr) \
+                      override; \
   bool DoBlasLtMatmul(Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
                       const HostOrDeviceScalar<std::complex<double>>& alpha, \
                       const DeviceMemory<std::complex<double>>& a, \
@@ -2413,7 +2437,9 @@ class BlasSupport {
                       DeviceMemory<std::complex<double>>* c, \
                       ScratchAllocator* scratch_allocator, \
                       const blas::IBlasLtMatmulAlgorithm* algorithm, \
-                      blas::ProfileResult* output_profile_result) override; \
+                      const DeviceMemory<std::complex<double>>& bias = {}, \
+                      blas::ProfileResult* output_profile_result = nullptr) \
+                      override; \
   port::Status GetVersion(std::string *version) override;
 
 }  // namespace blas
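
One detail worth noting about the new enum: the values are chosen as bit
flags, so `kBiasThenReLU` is literally `kBias | kReLU`. A hypothetical helper
(not part of this patch) illustrates how a backend could test for the bias
component without enumerating every combination:

    // Illustrative only: relies on the flag-style enum values defined in
    // blas.h above (kBias == 4 is set in both kBias and kBiasThenReLU).
    inline bool EpilogueAddsBias(blas::Epilogue e) {
      return (static_cast<int>(e) & static_cast<int>(blas::Epilogue::kBias)) != 0;
    }
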
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index ba833e562e2..1d95b00ce7e 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -468,6 +468,18 @@ cublasLtPointerMode_t CUBLASPointerMode(blas::PointerMode pointer_mode) {
       return CUBLASLT_POINTER_MODE_DEVICE;
   }
 }
+cublasLtEpilogue_t CUBLASEpilogue(blas::Epilogue epilogue) {
+  switch (epilogue) {
+    case blas::Epilogue::kDefault:
+      return CUBLASLT_EPILOGUE_DEFAULT;
+    case blas::Epilogue::kReLU:
+      return CUBLASLT_EPILOGUE_RELU;
+    case blas::Epilogue::kBias:
+      return CUBLASLT_EPILOGUE_BIAS;
+    case blas::Epilogue::kBiasThenReLU:
+      return CUBLASLT_EPILOGUE_RELU_BIAS;
+  }
+}
 #endif  // CUDA_VERSION >= 11000
 
 cudaDataType_t GetCUDADataType(blas::DataType ty) {
@@ -3135,12 +3147,12 @@ using UniqueMatmulPreference =
     std::unique_ptr<std::remove_pointer<cublasLtMatmulPreference_t>::type,
                     MatmulPreferenceDestroyer>;
 
-UniqueOpDesc CreateCublasLtOperationDesc(
-    blas::ComputationType computation_type, blas::DataType scale_type,
-    blas::PointerMode pointer_mode, blas::Transpose transa,
-    blas::Transpose transb) {
-  cublasOperation_t cuda_transa = CUDABlasTranspose(transa);
-  cublasOperation_t cuda_transb = CUDABlasTranspose(transb);
+UniqueOpDesc CreateCublasLtOperationDesc(blas::ComputationType computation_type,
+                                         blas::DataType scale_type,
+                                         blas::PointerMode pointer_mode,
+                                         blas::Epilogue epilogue,
+                                         blas::Transpose transa,
+                                         blas::Transpose transb) {
   cublasLtMatmulDesc_t desc;
   cublasComputeType_t cublas_compute_type =
       CUBLASComputationType(computation_type);
@@ -3154,9 +3166,13 @@ UniqueOpDesc CreateCublasLtOperationDesc(
   }
   UniqueOpDesc unique_desc(desc);
   if (!SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_POINTER_MODE,
-                       CUBLASPointerMode(pointer_mode)) ||
-      !SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSA, cuda_transa) ||
-      !SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSB, cuda_transb)) {
+                       CUBLASPointerMode(pointer_mode)) ||
+      !SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
+                       CUBLASEpilogue(epilogue)) ||
+      !SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSA,
+                       CUDABlasTranspose(transa)) ||
+      !SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSB,
+                       CUDABlasTranspose(transb))) {
     return nullptr;
   }
   return unique_desc;
@@ -3217,11 +3233,11 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
  public:
   CUDABlasLtMatmulPlan(blas::DataType ab_type, blas::DataType cd_type,
                        blas::ComputationType compute_type,
-                       blas::PointerMode pointer_mode, blas::Transpose transa,
-                       blas::Transpose transb, uint64 m, uint64 n, uint64 k,
-                       int batch_count, int64 lda, int64 stride_a, int64 ldb,
-                       int64 stride_b, int64 ldc, int64 stride_c, int64 ldd,
-                       int64 stride_d);
+                       blas::PointerMode pointer_mode, blas::Epilogue epilogue,
+                       blas::Transpose transa, blas::Transpose transb, uint64 m,
+                       uint64 n, uint64 k, int batch_count, int64 lda,
+                       int64 stride_a, int64 ldb, int64 stride_b, int64 ldc,
+                       int64 stride_c, int64 ldd, int64 stride_d);
 
   cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); }
   cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); }
@@ -3234,12 +3250,17 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
   blas::DataType cd_type() const { return cd_type_; }
   blas::DataType scale_type() const { return scale_type_; }
   blas::PointerMode pointer_mode() const { return pointer_mode_; }
+  blas::Epilogue epilogue() const { return epilogue_; }
   int batch_count() const { return batch_count_; }
   int64 stride_a() const { return stride_a_; }
   int64 stride_b() const { return stride_b_; }
   int64 stride_c() const { return stride_c_; }
   int64 stride_d() const { return stride_d_; }
 
+  // Note: Must be const to satisfy API. This is always called before the plan
+  // is executed, so the state change is not observed in subsequent executions.
+  bool SetBiasPointer(const void* bias) const;
+
  private:
   UniqueOpDesc op_desc_;
   UniqueLayoutDesc a_desc_;
@@ -3250,6 +3271,7 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
   blas::DataType cd_type_;
   blas::DataType scale_type_;
   blas::PointerMode pointer_mode_;
+  blas::Epilogue epilogue_;
   int batch_count_;
   int64 stride_a_;
   int64 stride_b_;
@@ -3260,12 +3282,13 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
 CUDABlasLtMatmulPlan::CUDABlasLtMatmulPlan(
     blas::DataType ab_type, blas::DataType cd_type,
     blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
-    uint64 k, int batch_count, int64 lda, int64 stride_a, int64 ldb,
-    int64 stride_b, int64 ldc, int64 stride_c, int64 ldd, int64 stride_d)
+    blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
+    uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
+    int64 ldb, int64 stride_b, int64 ldc, int64 stride_c, int64 ldd,
+    int64 stride_d)
     : op_desc_(CreateCublasLtOperationDesc(
           computation_type, GetScaleType(cd_type, computation_type),
-          pointer_mode, transa, transb)),
+          pointer_mode, epilogue, transa, transb)),
       a_desc_(nullptr),
       b_desc_(nullptr),
       c_desc_(
@@ -3276,6 +3299,7 @@ CUDABlasLtMatmulPlan::CUDABlasLtMatmulPlan(
       cd_type_(cd_type),
       scale_type_(GetScaleType(cd_type, computation_type)),
       pointer_mode_(pointer_mode),
+      epilogue_(epilogue),
       batch_count_(batch_count),
       stride_a_(stride_a),
       stride_b_(stride_b),
@@ -3291,6 +3315,11 @@ CUDABlasLtMatmulPlan::CUDABlasLtMatmulPlan(
       batch_count);
 }
 
+bool CUDABlasLtMatmulPlan::SetBiasPointer(const void* bias) const {
+  return SetCublasLtAttr(op_desc_.get(), CUBLASLT_MATMUL_DESC_BIAS_POINTER,
+                         bias);
+}
+
 class CUDABlasLtMatmulAlgorithm final : public blas::IBlasLtMatmulAlgorithm {
  public:
   CUDABlasLtMatmulAlgorithm(blas::AlgorithmType index,
@@ -3370,13 +3399,14 @@ std::unique_ptr<blas::IBlasLtMatmulPlan>
 CUDABlas::CreateBlasLtMatmulPlanStridedBatched(
     blas::DataType ab_type, blas::DataType cd_type,
     blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
-    uint64 k, int batch_count, int64 lda, int64 stride_a, int64 ldb,
-    int64 stride_b, int64 ldc, int64 stride_c) {
+    blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
+    uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
+    int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) {
 #if CUDA_VERSION >= 11000
   auto result = std::make_unique<CUDABlasLtMatmulPlan>(
-      ab_type, cd_type, computation_type, pointer_mode, transa, transb, m, n, k,
-      batch_count, lda, stride_a, ldb, stride_b, ldc, stride_c, ldc, stride_c);
+      ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
+      transb, m, n, k, batch_count, lda, stride_a, ldb, stride_b, ldc, stride_c,
+      ldc, stride_c);
   if (!result->ok()) {
     result.reset();
   }
@@ -3436,7 +3466,8 @@ bool CUDABlas::DoBlasLtMatmulInternalImpl(
     const HostOrDeviceScalar<ScaleType>& alpha, const ABType* a,
     const ABType* b, const HostOrDeviceScalar<ScaleType>& beta,
     const CDType* c, CDType* d, ScratchAllocator* scratch_allocator,
-    const blas::IBlasLtMatmulAlgorithm* algorithm) {
+    const blas::IBlasLtMatmulAlgorithm* algorithm,
+    const CDType* bias) {
   const auto& cuda_plan = *static_cast<const CUDABlasLtMatmulPlan*>(plan);
   const auto& cuda_algo =
       *static_cast<const CUDABlasLtMatmulAlgorithm*>(algorithm);
@@ -3474,6 +3505,20 @@ bool CUDABlas::DoBlasLtMatmulInternalImpl(
                "pointer_mode for the given alpha/beta.";
     return false;
   }
+  if ((cuda_plan.epilogue() == blas::Epilogue::kBias ||
+       cuda_plan.epilogue() == blas::Epilogue::kBiasThenReLU) !=
+      (bias != nullptr)) {
+    VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong "
+               "epilogue for the given bias pointer.";
+    return false;
+  }
+  if (bias != nullptr) {
+    if (!cuda_plan.SetBiasPointer(bias)) {
+      VLOG(2) << "DoBlasLtMatmul returning false because setting the bias "
+                 "pointer failed.";
+      return false;
+    }
+  }
   const ScaleType* alpha_ptr =
       alpha.is_pointer() ? GpuMemory(alpha.pointer()) : &alpha.value();
   const ScaleType* beta_ptr =
@@ -3525,6 +3570,7 @@ bool CUDABlas::DoBlasLtMatmulInternal(
     const DeviceMemory<CDType>& c, DeviceMemory<CDType>* d,
     ScratchAllocator* scratch_allocator,
     const blas::IBlasLtMatmulAlgorithm* algorithm,
+    const DeviceMemory<CDType>& bias,
     blas::ProfileResult* output_profile_result) {
 #if CUDA_VERSION >= 11000
   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
@@ -3538,7 +3584,8 @@ bool CUDABlas::DoBlasLtMatmulInternal(
   bool err_on_failure = timer != nullptr;
   bool result = DoBlasLtMatmulInternalImpl(
       stream, err_on_failure, plan, alpha, GpuMemory(a), GpuMemory(b), beta,
-      GpuMemory(c), GpuMemoryMutable(d), scratch_allocator, algorithm);
+      GpuMemory(c), GpuMemoryMutable(d), scratch_allocator, algorithm,
+      GpuMemory(bias));
 
   if (timer && result) {
     // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
@@ -3563,9 +3610,10 @@ bool CUDABlas::DoBlasLtMatmul(
     const DeviceMemory<int8>& b, const HostOrDeviceScalar<int>& beta,
     DeviceMemory<int32>* c, ScratchAllocator* scratch_allocator,
     const blas::IBlasLtMatmulAlgorithm* algorithm,
+    const DeviceMemory<int32>& bias,
     blas::ProfileResult* output_profile_result) {
   return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
-                                scratch_allocator, algorithm,
+                                scratch_allocator, algorithm, bias,
                                 output_profile_result);
 }
 
@@ -3578,6 +3626,7 @@ bool CUDABlas::DoBlasLtMatmul(Stream* stream,
                               DeviceMemory<Eigen::half>* c,
                               ScratchAllocator* scratch_allocator,
                               const blas::IBlasLtMatmulAlgorithm* algorithm,
+                              const DeviceMemory<Eigen::half>& bias,
                               blas::ProfileResult* output_profile_result) {
 #if CUDA_VERSION >= 11000
   const auto& cuda_plan = *static_cast<const CUDABlasLtMatmulPlan*>(plan);
@@ -3591,11 +3640,11 @@ bool CUDABlas::DoBlasLtMatmul(Stream* stream,
     HostOrDeviceScalar<float> float_alpha(static_cast<float>(alpha.value()));
     HostOrDeviceScalar<float> float_beta(static_cast<float>(beta.value()));
     return DoBlasLtMatmulInternal(stream, plan, float_alpha, a, b, float_beta,
-                                  *c, c, scratch_allocator, algorithm,
+                                  *c, c, scratch_allocator, algorithm, bias,
                                   output_profile_result);
   }
   return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
-                                scratch_allocator, algorithm,
+                                scratch_allocator, algorithm, bias,
                                 output_profile_result);
 #else   // if CUDA_VERSION < 11000
   return false;
@@ -3608,9 +3657,10 @@ bool CUDABlas::DoBlasLtMatmul(
     const DeviceMemory<float>& b, const HostOrDeviceScalar<float>& beta,
     DeviceMemory<float>* c, ScratchAllocator* scratch_allocator,
     const blas::IBlasLtMatmulAlgorithm* algorithm,
+    const DeviceMemory<float>& bias,
     blas::ProfileResult* output_profile_result) {
   return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
-                                scratch_allocator, algorithm,
+                                scratch_allocator, algorithm, bias,
                                 output_profile_result);
 }
 
@@ -3620,9 +3670,10 @@ bool CUDABlas::DoBlasLtMatmul(
     const DeviceMemory<double>& b, const HostOrDeviceScalar<double>& beta,
     DeviceMemory<double>* c, ScratchAllocator* scratch_allocator,
     const blas::IBlasLtMatmulAlgorithm* algorithm,
+    const DeviceMemory<double>& bias,
     blas::ProfileResult* output_profile_result) {
   return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
-                                scratch_allocator, algorithm,
+                                scratch_allocator, algorithm, bias,
                                 output_profile_result);
 }
 
@@ -3634,9 +3685,10 @@ bool CUDABlas::DoBlasLtMatmul(
     const HostOrDeviceScalar<std::complex<float>>& beta,
     DeviceMemory<std::complex<float>>* c, ScratchAllocator* scratch_allocator,
     const blas::IBlasLtMatmulAlgorithm* algorithm,
+    const DeviceMemory<std::complex<float>>& bias,
     blas::ProfileResult* output_profile_result) {
   return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
-                                scratch_allocator, algorithm,
+                                scratch_allocator, algorithm, bias,
                                 output_profile_result);
 }
 
@@ -3648,9 +3700,10 @@ bool CUDABlas::DoBlasLtMatmul(
     const HostOrDeviceScalar<std::complex<double>>& beta,
     DeviceMemory<std::complex<double>>* c, ScratchAllocator* scratch_allocator,
     const blas::IBlasLtMatmulAlgorithm* algorithm,
+    const DeviceMemory<std::complex<double>>& bias,
     blas::ProfileResult* output_profile_result) {
   return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
-                                scratch_allocator, algorithm,
+                                scratch_allocator, algorithm, bias,
                                 output_profile_result);
 }
 
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.h b/tensorflow/stream_executor/cuda/cuda_blas.h
index 351a7778c01..3fdfcb0a50c 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.h
+++ b/tensorflow/stream_executor/cuda/cuda_blas.h
@@ -148,6 +148,7 @@ class CUDABlas : public blas::BlasSupport {
                             const DeviceMemory<CDType>& c, DeviceMemory<CDType>* d,
                             ScratchAllocator* scratch_allocator,
                             const blas::IBlasLtMatmulAlgorithm* algorithm,
+                            const DeviceMemory<CDType>& bias,
                             blas::ProfileResult* output_profile_result);
 
   // Helper function for implementing DoBlasLtMatmulInternal.
@@ -157,7 +158,7 @@ class CUDABlas : public blas::BlasSupport {
       const HostOrDeviceScalar<ScaleType>& alpha, const ABType* a,
       const ABType* b, const HostOrDeviceScalar<ScaleType>& beta,
       const CDType* c, CDType* d, ScratchAllocator* scratch_allocator,
-      const blas::IBlasLtMatmulAlgorithm* algorithm);
+      const blas::IBlasLtMatmulAlgorithm* algorithm, const CDType* bias);
 
   // Guards the cuBLAS handle for this device.
   absl::Mutex mu_;
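
The cuda_blas.cc changes above enforce a pairing invariant: a bias pointer
must be supplied if and only if the plan was created with a bias-bearing
epilogue. The arithmetic itself happens inside cublasLtMatmul; as a reference
for what the fused stage computes, under the semantics documented in blas.h
(one bias entry per row of c, broadcast across columns and batches), an
illustrative CPU analogue, not part of the patch, might be:

    // Illustrative only. Mirrors the epilogue applied to each element of
    // acc = alpha*A*B + beta*C at the given output row.
    template <typename T>
    T ApplyEpilogueReference(stream_executor::blas::Epilogue epilogue, T acc,
                             const T* bias, int row) {
      using stream_executor::blas::Epilogue;
      if (epilogue == Epilogue::kBias || epilogue == Epilogue::kBiasThenReLU) {
        acc += bias[row];  // bias is broadcast along columns and batch
      }
      if (epilogue == Epilogue::kReLU || epilogue == Epilogue::kBiasThenReLU) {
        acc = acc > T(0) ? acc : T(0);  // point-wise ReLU
      }
      return acc;
    }
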
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 144af92185c..66728c94821 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -4809,18 +4809,19 @@ Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
                                  DeviceMemory<int32>* c,
                                  ScratchAllocator* scratch_allocator,
                                  const blas::IBlasLtMatmulAlgorithm* algorithm,
+                                 const DeviceMemory<int32>& bias,
                                  blas::ProfileResult* output_profile_result) {
   VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
-            PARAM(c), PARAM(algorithm));
+            PARAM(c), PARAM(algorithm), PARAM(bias));
   ThenBlasWithProfileImpl<
       const blas::IBlasLtMatmulPlan*, const HostOrDeviceScalar<int>&,
       const DeviceMemory<int8>&, const DeviceMemory<int8>&,
       const HostOrDeviceScalar<int>&, DeviceMemory<int32>*, ScratchAllocator*,
-      const blas::IBlasLtMatmulAlgorithm*>
+      const blas::IBlasLtMatmulAlgorithm*, const DeviceMemory<int32>&>
       impl;
   return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
-              c, scratch_allocator, algorithm, output_profile_result);
+              c, scratch_allocator, algorithm, bias, output_profile_result);
 }
 
 Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
@@ -4831,18 +4832,20 @@ Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
                                  DeviceMemory<Eigen::half>* c,
                                  ScratchAllocator* scratch_allocator,
                                  const blas::IBlasLtMatmulAlgorithm* algorithm,
+                                 const DeviceMemory<Eigen::half>& bias,
                                  blas::ProfileResult* output_profile_result) {
   VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
-            PARAM(c), PARAM(algorithm));
+            PARAM(c), PARAM(algorithm), PARAM(bias));
   ThenBlasWithProfileImpl<
       const blas::IBlasLtMatmulPlan*, const HostOrDeviceScalar<Eigen::half>&,
       const DeviceMemory<Eigen::half>&, const DeviceMemory<Eigen::half>&,
       const HostOrDeviceScalar<Eigen::half>&, DeviceMemory<Eigen::half>*,
-      ScratchAllocator*, const blas::IBlasLtMatmulAlgorithm*>
+      ScratchAllocator*, const blas::IBlasLtMatmulAlgorithm*,
+      const DeviceMemory<Eigen::half>&>
       impl;
   return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
-              c, scratch_allocator, algorithm, output_profile_result);
+              c, scratch_allocator, algorithm, bias, output_profile_result);
 }
 
 Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
@@ -4853,18 +4856,19 @@ Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
                                  DeviceMemory<float>* c,
                                  ScratchAllocator* scratch_allocator,
                                  const blas::IBlasLtMatmulAlgorithm* algorithm,
+                                 const DeviceMemory<float>& bias,
                                  blas::ProfileResult* output_profile_result) {
   VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
-            PARAM(c), PARAM(algorithm));
+            PARAM(c), PARAM(algorithm), PARAM(bias));
   ThenBlasWithProfileImpl<
       const blas::IBlasLtMatmulPlan*, const HostOrDeviceScalar<float>&,
       const DeviceMemory<float>&, const DeviceMemory<float>&,
       const HostOrDeviceScalar<float>&, DeviceMemory<float>*, ScratchAllocator*,
-      const blas::IBlasLtMatmulAlgorithm*>
+      const blas::IBlasLtMatmulAlgorithm*, const DeviceMemory<float>&>
       impl;
   return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
-              c, scratch_allocator, algorithm, output_profile_result);
+              c, scratch_allocator, algorithm, bias, output_profile_result);
 }
 
 Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
@@ -4875,18 +4879,20 @@ Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
                                  DeviceMemory<double>* c,
                                  ScratchAllocator* scratch_allocator,
                                  const blas::IBlasLtMatmulAlgorithm* algorithm,
+                                 const DeviceMemory<double>& bias,
                                  blas::ProfileResult* output_profile_result) {
   VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
-            PARAM(c), PARAM(algorithm));
+            PARAM(c), PARAM(algorithm), PARAM(bias));
   ThenBlasWithProfileImpl<
       const blas::IBlasLtMatmulPlan*, const HostOrDeviceScalar<double>&,
       const DeviceMemory<double>&, const DeviceMemory<double>&,
       const HostOrDeviceScalar<double>&, DeviceMemory<double>*,
-      ScratchAllocator*, const blas::IBlasLtMatmulAlgorithm*>
+      ScratchAllocator*, const blas::IBlasLtMatmulAlgorithm*,
+      const DeviceMemory<double>&>
       impl;
   return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
-              c, scratch_allocator, algorithm, output_profile_result);
+              c, scratch_allocator, algorithm, bias, output_profile_result);
 }
 
 Stream& Stream::ThenBlasLtMatmul(
@@ -4897,9 +4903,10 @@ Stream& Stream::ThenBlasLtMatmul(
     const HostOrDeviceScalar<std::complex<float>>& beta,
     DeviceMemory<std::complex<float>>* c, ScratchAllocator* scratch_allocator,
     const blas::IBlasLtMatmulAlgorithm* algorithm,
+    const DeviceMemory<std::complex<float>>& bias,
    blas::ProfileResult* output_profile_result) {
   VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
-            PARAM(c), PARAM(algorithm));
+            PARAM(c), PARAM(algorithm), PARAM(bias));
   ThenBlasWithProfileImpl<const blas::IBlasLtMatmulPlan*,
                           const HostOrDeviceScalar<std::complex<float>>&,
@@ -4907,10 +4914,11 @@ Stream& Stream::ThenBlasLtMatmul(
                           const DeviceMemory<std::complex<float>>&,
                           const HostOrDeviceScalar<std::complex<float>>&,
                           DeviceMemory<std::complex<float>>*, ScratchAllocator*,
-                          const blas::IBlasLtMatmulAlgorithm*>
+                          const blas::IBlasLtMatmulAlgorithm*,
+                          const DeviceMemory<std::complex<float>>&>
       impl;
   return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
-              c, scratch_allocator, algorithm, output_profile_result);
+              c, scratch_allocator, algorithm, bias, output_profile_result);
 }
 
 Stream& Stream::ThenBlasLtMatmul(
@@ -4921,9 +4929,10 @@ Stream& Stream::ThenBlasLtMatmul(
     const HostOrDeviceScalar<std::complex<double>>& beta,
     DeviceMemory<std::complex<double>>* c, ScratchAllocator* scratch_allocator,
     const blas::IBlasLtMatmulAlgorithm* algorithm,
+    const DeviceMemory<std::complex<double>>& bias,
     blas::ProfileResult* output_profile_result) {
   VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
-            PARAM(c), PARAM(algorithm));
+            PARAM(c), PARAM(algorithm), PARAM(bias));
   ThenBlasWithProfileImpl<const blas::IBlasLtMatmulPlan*,
                           const HostOrDeviceScalar<std::complex<double>>&,
@@ -4932,10 +4941,11 @@ Stream& Stream::ThenBlasLtMatmul(
                           const DeviceMemory<std::complex<double>>&,
                           const HostOrDeviceScalar<std::complex<double>>&,
                           DeviceMemory<std::complex<double>>*,
                           ScratchAllocator*,
-                          const blas::IBlasLtMatmulAlgorithm*>
+                          const blas::IBlasLtMatmulAlgorithm*,
+                          const DeviceMemory<std::complex<double>>&>
       impl;
   return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
-              c, scratch_allocator, algorithm, output_profile_result);
+              c, scratch_allocator, algorithm, bias, output_profile_result);
 }
 
 Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
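
Note an ergonomic consequence of the parameter order in these Stream entry
points: `bias` is inserted before `output_profile_result` and both are
defaulted, so a caller that wants profiling must also spell out a bias
argument. That is exactly what the autotuning call site in
batch_matmul_op_impl.h does by passing `{}` (an empty, i.e. null,
DeviceMemory). A sketch with hypothetical locals:

    // Sketch: profiling without a bias vector requires an explicit empty
    // DeviceMemory so that &profile_result binds to the last parameter.
    stream->ThenBlasLtMatmul(plan.get(), alpha, a_mem, b_mem, beta, &c_mem,
                             &scratch_allocator, algorithm.get(),
                             /*bias=*/{}, &profile_result);
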
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index 15f5dfc936f..91a80331f8e 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -1672,6 +1672,7 @@ class Stream {
       const DeviceMemory<int8>& b, const HostOrDeviceScalar<int>& beta,
       DeviceMemory<int32>* c, ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<int32>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr);
   Stream& ThenBlasLtMatmul(
       const blas::IBlasLtMatmulPlan* plan,
@@ -1680,6 +1681,7 @@ class Stream {
       const HostOrDeviceScalar<Eigen::half>& beta, DeviceMemory<Eigen::half>* c,
       ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<Eigen::half>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr);
   Stream& ThenBlasLtMatmul(
       const blas::IBlasLtMatmulPlan* plan,
@@ -1687,6 +1689,7 @@ class Stream {
       const DeviceMemory<float>& b, const HostOrDeviceScalar<float>& beta,
       DeviceMemory<float>* c, ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<float>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr);
   Stream& ThenBlasLtMatmul(
       const blas::IBlasLtMatmulPlan* plan,
@@ -1694,6 +1697,7 @@ class Stream {
       const DeviceMemory<double>& b, const HostOrDeviceScalar<double>& beta,
       DeviceMemory<double>* c, ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<double>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr);
   Stream& ThenBlasLtMatmul(
       const blas::IBlasLtMatmulPlan* plan,
@@ -1703,6 +1707,7 @@ class Stream {
       const HostOrDeviceScalar<std::complex<float>>& beta,
       DeviceMemory<std::complex<float>>* c, ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<std::complex<float>>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr);
   Stream& ThenBlasLtMatmul(
       const blas::IBlasLtMatmulPlan* plan,
@@ -1713,6 +1718,7 @@ class Stream {
       DeviceMemory<std::complex<double>>* c,
       ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<std::complex<double>>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr);
 
   // See FftSupport::DoFft.
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index 3fbbc3f2aac..d75c1bc65c5 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -339,31 +339,32 @@ bool StreamExecutor::GetBlasGemmAlgorithms(
 std::unique_ptr<blas::IBlasLtMatmulPlan>
 StreamExecutor::CreateBlasLtMatmulPlan(
     blas::DataType ab_type, blas::DataType cd_type,
     blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
-    uint64 k, int64 lda, int64 ldb, int64 ldc) {
+    blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
+    uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc) {
   blas::BlasSupport *blas_support = AsBlas();
   if (!blas_support) {
     return nullptr;
   }
   return blas_support->CreateBlasLtMatmulPlan(
-      ab_type, cd_type, computation_type, pointer_mode, transa, transb, m, n, k,
-      lda, ldb, ldc);
+      ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
+      transb, m, n, k, lda, ldb, ldc);
 }
 
 std::unique_ptr<blas::IBlasLtMatmulPlan>
 StreamExecutor::CreateBlasLtMatmulPlanStridedBatched(
     blas::DataType ab_type, blas::DataType cd_type,
     blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
-    uint64 k, uint64 batch_count, int64 lda, int64 stride_a, int64 ldb,
-    int64 stride_b, int64 ldc, int64 stride_c) {
+    blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
+    uint64 m, uint64 n, uint64 k, uint64 batch_count, int64 lda, int64 stride_a,
+    int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) {
   blas::BlasSupport *blas_support = AsBlas();
   if (!blas_support) {
     return nullptr;
   }
   return blas_support->CreateBlasLtMatmulPlanStridedBatched(
-      ab_type, cd_type, computation_type, pointer_mode, transa, transb, m, n, k,
-      batch_count, lda, stride_a, ldb, stride_b, ldc, stride_c);
+      ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
+      transb, m, n, k, batch_count, lda, stride_a, ldb, stride_b, ldc,
+      stride_c);
 }
 
 bool StreamExecutor::GetBlasLtMatmulAlgorithms(
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index 90137417250..b40c0c23c05 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -401,17 +401,17 @@ class StreamExecutor {
   std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
       blas::DataType ab_type, blas::DataType cd_type,
       blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
-      uint64 k, int64 lda, int64 ldb, int64 ldc);
+      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
+      uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc);
 
   // A more general version of CreateBlasLtMatmulPlan supporting
   // batched operations.
   std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlanStridedBatched(
       blas::DataType ab_type, blas::DataType cd_type,
      blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
-      uint64 k, uint64 batch_count, int64 lda, int64 stride_a, int64 ldb,
-      int64 stride_b, int64 ldc, int64 stride_c);
+      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
+      uint64 m, uint64 n, uint64 k, uint64 batch_count, int64 lda,
+      int64 stride_a, int64 ldb, int64 stride_b, int64 ldc, int64 stride_c);
 
   // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
   // returned in the order of increasing estimated compute time according to an
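
Finally, for reference, here is how the pieces are intended to fit together
once a caller (for example, a future XLA fusion) actually requests a fused
epilogue. This is a hypothetical sketch, not code from the patch: the locals
(a_mem, b_mem, c_mem, bias_mem, scratch_allocator, and the dimensions) are
placeholders, error handling is elided, and it assumes the pre-existing
stream_executor helpers (ToDataType, GetBlasLtMatmulAlgorithms) behave as
their names suggest.

    namespace se = stream_executor;

    // 1. Bake the epilogue into the plan. bias_mem must hold one element per
    //    row of the output matrix; it is broadcast across columns (and across
    //    the batch dimension for the strided-batched variant).
    auto plan = stream->parent()->CreateBlasLtMatmulPlan(
        /*ab_type=*/se::blas::ToDataType<float>::value,
        /*cd_type=*/se::blas::ToDataType<float>::value,
        se::blas::ComputationType::kF32, se::blas::PointerMode::kHost,
        se::blas::Epilogue::kBiasThenReLU, se::blas::Transpose::kNoTranspose,
        se::blas::Transpose::kNoTranspose, m, n, k,
        /*lda=*/m, /*ldb=*/k, /*ldc=*/m);

    // 2. Pick an algorithm for the plan.
    std::vector<std::unique_ptr<se::blas::IBlasLtMatmulAlgorithm>> algorithms;
    stream->parent()->GetBlasLtMatmulAlgorithms(
        plan.get(), /*max_workspace_size=*/1 << 30,
        /*max_algorithm_count=*/1, &algorithms);

    // 3. Run the matmul; the non-null bias must match the plan's epilogue,
    //    per the invariant checked in cuda_blas.cc.
    stream->ThenBlasLtMatmul(plan.get(), se::HostOrDeviceScalar<float>(1.0f),
                             a_mem, b_mem, se::HostOrDeviceScalar<float>(0.0f),
                             &c_mem, &scratch_allocator, algorithms[0].get(),
                             /*bias=*/bias_mem);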