From af5c3ccabcacc39c95096022ef20b8089cc56f5e Mon Sep 17 00:00:00 2001 From: Tim Shen Date: Wed, 28 Oct 2020 16:54:29 -0700 Subject: [PATCH] Rollback PR #43237: Integrate cuBLASLt API into backend Reason: Performance regression on bert_pretraining and bert_squad. PiperOrigin-RevId: 339565741 Change-Id: I1c6a4bf807b3cb3aa132e6272a2f01f90bdeca6d --- tensorflow/core/kernels/BUILD | 2 +- .../core/kernels/batch_matmul_op_impl.h | 588 ++++++------------ tensorflow/core/kernels/conv_ops.cc | 14 +- tensorflow/core/kernels/conv_ops_gpu.h | 47 +- tensorflow/core/kernels/fft_ops.cc | 16 +- tensorflow/core/kernels/gpu_utils.cc | 59 -- tensorflow/core/kernels/gpu_utils.h | 31 - .../core/kernels/linalg/einsum_op_impl.h | 8 +- tensorflow/core/util/matmul_autotune.cc | 18 - tensorflow/core/util/matmul_autotune.h | 1 - third_party/gpus/cuda/BUILD.tpl | 8 - third_party/gpus/cuda_configure.bzl | 12 - 12 files changed, 262 insertions(+), 542 deletions(-) diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index c094a8a6046..53c53ac3ff6 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -3396,7 +3396,7 @@ tf_kernel_library( name = "fft_ops", prefix = "fft_ops", deps = MATH_DEPS + [ - ] + if_cuda_or_rocm([":gpu_utils"]) + if_cuda([ + ] + if_cuda([ "//tensorflow/core/platform/default/build_config:cufft_plugin", ]), ) diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/batch_matmul_op_impl.h index 5c1e0cbe6e4..d6cc980633f 100644 --- a/tensorflow/core/kernels/batch_matmul_op_impl.h +++ b/tensorflow/core/kernels/batch_matmul_op_impl.h @@ -35,7 +35,6 @@ limitations under the License. #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/matmul_autotune.h" #include "tensorflow/core/util/matmul_bcast.h" #include "tensorflow/core/util/work_sharder.h" @@ -44,13 +43,8 @@ limitations under the License. #endif #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "tensorflow/core/kernels/gpu_utils.h" #include "tensorflow/core/platform/stream_executor.h" -#include "tensorflow/core/platform/tensor_float_32_utils.h" #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#if GOOGLE_CUDA -#include "third_party/gpus/cuda/include/cuda.h" // For CUDA_VERSION -#endif namespace tensorflow { @@ -225,8 +219,7 @@ template struct LaunchBatchMatMul { static void Launch(OpKernelContext* context, const Tensor& in_x, const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x, - bool trans_y, const MatMulBCast& bcast, bool use_autotune, - Tensor* out) { + bool trans_y, const MatMulBCast& bcast, Tensor* out) { typedef ParallelMatMulKernel::IsComplex> ParallelMatMulKernel; bool conjugate_result = false; @@ -282,212 +275,45 @@ se::DeviceMemory AsDeviceMemory(const T* gpu_memory) { return typed; } -using BlasScratchAllocator = GpuScratchAllocator; - -int64 GetBlasWorkspaceLimit(const string& envvar_in_mb, - int64 default_value_in_bytes) { - return gpu_utils::GetWorkspaceLimit(envvar_in_mb, default_value_in_bytes); -} - -// Encapsulate all of the shape, dtype etc. information that defines a unique -// batched matmul operation. 
-class BatchMatmulParameters { +class BlasScratchAllocator : public se::ScratchAllocator { public: - BatchMatmulParameters(bool trans_a, bool trans_b, bool adj_a, bool adj_b, - uint64 m, uint64 n, uint64 k, uint64 batch_count, - bool broadcast_a, bool broadcast_b, DataType dtype_ab, - DataType dtype_cd, bool allow_tf32, int device_id) - : trans_a_(trans_a), - trans_b_(trans_b), - adj_a_(adj_a), - adj_b_(adj_b), - m_(m), - n_(n), - k_(k), - batch_count_(batch_count), - broadcast_a_(broadcast_a), - broadcast_b_(broadcast_b), - dtype_ab_(dtype_ab), - dtype_cd_(dtype_cd), - allow_tf32_(allow_tf32), - device_id_(device_id) { - hash_code_ = trans_a; - hash_code_ = Hash64Combine(hash_code_, trans_b); - hash_code_ = Hash64Combine(hash_code_, adj_a); - hash_code_ = Hash64Combine(hash_code_, adj_b); - hash_code_ = Hash64Combine(hash_code_, m); - hash_code_ = Hash64Combine(hash_code_, n); - hash_code_ = Hash64Combine(hash_code_, k); - hash_code_ = Hash64Combine(hash_code_, batch_count); - hash_code_ = Hash64Combine(hash_code_, broadcast_a); - hash_code_ = Hash64Combine(hash_code_, broadcast_b); - hash_code_ = Hash64Combine(hash_code_, dtype_ab); - hash_code_ = Hash64Combine(hash_code_, dtype_cd); - hash_code_ = Hash64Combine(hash_code_, allow_tf32); - hash_code_ = Hash64Combine(hash_code_, device_id); - } - bool operator==(const BatchMatmulParameters& other) const { - return this->get_data_as_tuple() == other.get_data_as_tuple(); - } + using Stream = se::Stream; + using DeviceMemoryBytes = se::DeviceMemory; - bool operator!=(const BatchMatmulParameters& other) const { - return !(*this == other); - } - uint64 hash() const { return hash_code_; } + BlasScratchAllocator(OpKernelContext* context) : context_(context) {} - string ToString() const { - // clang-format off - return strings::StrCat( - trans_a_, ", ", trans_b_, ", ", adj_a_, ", ", adj_b_, ", ", - m_, ", ", n_, ", ", k_, ", ", batch_count_, ", ", - broadcast_a_, ", ", broadcast_b_, ", ", - dtype_ab_, ", ", dtype_cd_, ", ", allow_tf32_, ", ", device_id_); - // clang-format on + int64 GetMemoryLimitInBytes() override { return -1; } + + se::port::StatusOr AllocateBytes( + int64 byte_size) override { + Tensor temporary_memory; + + Status allocation_status(context_->allocate_temp( + DT_UINT8, TensorShape({byte_size}), &temporary_memory)); + if (!allocation_status.ok()) { + return se::port::StatusOr( + DeviceMemoryBytes::MakeFromByteSize(nullptr, 0)); + } + // Hold the reference of the allocated tensors until the end of the + // allocator. 
+ allocated_tensors_.push_back(temporary_memory); + return se::port::StatusOr( + DeviceMemoryBytes::MakeFromByteSize( + temporary_memory.flat().data(), + temporary_memory.flat().size())); } private: - typedef std::tuple - ParameterDataType; - - ParameterDataType get_data_as_tuple() const { - return std::make_tuple(trans_a_, trans_b_, adj_a_, adj_b_, m_, n_, k_, - batch_count_, broadcast_a_, broadcast_b_, dtype_ab_, - dtype_cd_, allow_tf32_, device_id_); - } - - bool trans_a_; - bool trans_b_; - bool adj_a_; - bool adj_b_; - uint64 m_; - uint64 n_; - uint64 k_; - uint64 batch_count_; - bool broadcast_a_; - bool broadcast_b_; - DataType dtype_ab_; - DataType dtype_cd_; - bool allow_tf32_; - int device_id_; - uint64 hash_code_; + OpKernelContext* context_; + std::vector allocated_tensors_; }; - -bool GetBlasComputationType(const DataType& dtype, bool allow_tf32, - se::blas::ComputationType* compute_type) { - using se::blas::ComputationType; - static bool use_f32_for_f16_computation = MatmulDoFP32ComputationFP16Input(); - ComputationType f32_type = - allow_tf32 ? ComputationType::kTF32AsF32 : ComputationType::kF32; - switch (dtype) { - case DT_HALF: - case DT_BFLOAT16: - *compute_type = - use_f32_for_f16_computation ? f32_type : ComputationType::kF16; - return true; - case DT_FLOAT: - *compute_type = f32_type; - return true; - case DT_DOUBLE: - *compute_type = ComputationType::kF64; - return true; - case DT_COMPLEX64: - *compute_type = f32_type; - return true; - case DT_COMPLEX128: - *compute_type = ComputationType::kComplexF64; - return true; - default: - // Unsupported compute_type, return false. - return false; - } -} - -// Thread-safe map from matmul parameters to their corresponding plan and -// algorithms. -template -class BlasLtMatmulPlanMap { - public: - struct PlanAndAlgorithms { - std::unique_ptr plan; - std::vector> algorithms; - }; - - const PlanAndAlgorithms* Find(const Parameters& params) { - mutex_lock lock(mu_); - auto iter = params_plan_map_.find(params); - if (iter == params_plan_map_.end()) { - return nullptr; - } - return &iter->second; - } - const PlanAndAlgorithms* Insert(const Parameters& params, - PlanAndAlgorithms value) { - mutex_lock lock(mu_); - return ¶ms_plan_map_.emplace(params, std::move(value)).first->second; - } - - private: - struct Hasher { - std::size_t operator()(const Parameters& parameter) const { - return parameter.hash(); - } - }; - - mutable mutex mu_; - std::unordered_map params_plan_map_ - GUARDED_BY(mu_); -}; - -template -struct BlasLtPlanMapSingleton { - typedef BlasLtMatmulPlanMap PlanMapType; - static PlanMapType* GetInstance() { - static PlanMapType* instance = new PlanMapType(); - return instance; - } -}; - -typedef BlasLtPlanMapSingleton - BatchMatmulPlanMapSingleton; - -// A dummy type to group matmul autotune results together. -struct BatchMatmulAutoTuneGroup { - static string name() { return "MatmulLt"; } -}; - -typedef AutoTuneSingleton - AutoTuneBatchMatmul; - -template -struct CoefficientType { - typedef Scalar type; -}; -template <> -struct CoefficientType { - typedef float type; -}; - -inline Status FromExecutorStatus(const se::port::Status& s) { - return s.ok() ? 
Status::OK() - : Status(static_cast(static_cast(s.code())), - s.error_message()); -} - -template -inline Status FromExecutorStatus(const se::port::StatusOr& s) { - return FromExecutorStatus(s.status()); -} - } // namespace template struct LaunchBatchMatMul { static void Launch(OpKernelContext* context, const Tensor& in_x, const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x, - bool trans_y, const MatMulBCast& bcast, bool use_autotune, - Tensor* out) { + bool trans_y, const MatMulBCast& bcast, Tensor* out) { se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose, se::blas::Transpose::kTranspose, se::blas::Transpose::kConjugateTranspose}; @@ -517,191 +343,10 @@ struct LaunchBatchMatMul { auto* a_base_ptr = in_x.template flat().data(); auto* b_base_ptr = in_y.template flat().data(); auto* c_base_ptr = out->template flat().data(); - int64 a_stride; - int64 b_stride; - int64 c_stride; + uint64 a_stride; + uint64 b_stride; + uint64 c_stride; - typedef typename CoefficientType::type Coefficient; - - static const int64 max_scratch_size = GetBlasWorkspaceLimit( - "TF_CUBLAS_WORKSPACE_LIMIT_IN_MB", 1LL << 32); // 4GB by default - - // The BlasLtMatmul routines are only supported from CUDA 11.0 onward. -#if GOOGLE_CUDA && CUDA_VERSION >= 11000 - bool is_full_broadcast = - std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1; - bool requires_mixed_broadcasting = - bcast.IsBroadcastingRequired() && !is_full_broadcast; - if (!requires_mixed_broadcasting) { - bool broadcast_a = bcast.x_batch_size() == 1; - bool broadcast_b = bcast.y_batch_size() == 1; - a_stride = broadcast_a ? 0 : m * k; - b_stride = broadcast_b ? 0 : k * n; - c_stride = m * n; - a_device_memory.push_back(AsDeviceMemory(a_base_ptr)); - b_device_memory.push_back(AsDeviceMemory(b_base_ptr)); - c_device_memory.push_back(AsDeviceMemory(c_base_ptr)); - a_ptrs.push_back(&a_device_memory.back()); - b_ptrs.push_back(&b_device_memory.back()); - c_ptrs.push_back(&c_device_memory.back()); - - DataType dtype = DataTypeToEnum::value; - bool allow_tf32 = tensor_float_32_execution_enabled(); - int device_id = stream->parent()->device_ordinal(); - BatchMatmulParameters matmul_parameters( - trans_x, trans_y, adj_x, adj_y, m, n, k, batch_size, broadcast_a, - broadcast_b, dtype, dtype, allow_tf32, device_id); - - static const bool max_autotune_algorithm_count = - MatmulMaxAutotuneAlgorithmCount(); - int max_algorithm_count = use_autotune ? 
max_autotune_algorithm_count : 1; - - const auto* plan_and_algorithms = - BatchMatmulPlanMapSingleton::GetInstance()->Find(matmul_parameters); - if (!plan_and_algorithms) { - se::blas::DataType blas_dtype = se::blas::ToDataType::value; - se::blas::ComputationType computation_type; - OP_REQUIRES( - context, - GetBlasComputationType(dtype, allow_tf32, &computation_type), - errors::Internal("Unsupported dtype for batched matmul")); - - auto status_or_plan = stream->parent()->CreateBlasLtMatmulPlan( - {/*ab_type=*/blas_dtype, - /*c_type=*/blas_dtype, computation_type, - se::blas::PointerMode::kHost, se::blas::Epilogue::kDefault, - blas_transpose_b, blas_transpose_a, n, m, k, - /*lda=*/in_y.dim_size(2), /*ldb=*/in_x.dim_size(2), - /*ldc=*/static_cast(n), static_cast(batch_size), - b_stride, a_stride, c_stride}); - OP_REQUIRES(context, status_or_plan.ok(), - FromExecutorStatus(status_or_plan)); - std::unique_ptr plan = - status_or_plan.ConsumeValueOrDie(); - - auto status_or_algorithms = stream->parent()->GetBlasLtMatmulAlgorithms( - plan.get(), max_scratch_size, max_algorithm_count); - OP_REQUIRES(context, status_or_algorithms.ok(), - FromExecutorStatus(status_or_algorithms)); - auto algorithms = status_or_algorithms.ConsumeValueOrDie(); - - plan_and_algorithms = - BatchMatmulPlanMapSingleton::GetInstance()->Insert( - matmul_parameters, {std::move(plan), std::move(algorithms)}); - } - const auto& plan = plan_and_algorithms->plan; - const auto& algorithms = plan_and_algorithms->algorithms; - - // The BlasLtMatmul routines (unlike BlasGemm, BlasGemmBatched etc.) take - // alpha and beta with the same type as the matrices. - Scalar alpha(1.0); - Scalar beta(0.0); - - // Note that algorithm_config.algorithm() here is used to refer - // to the index within the algorithms vector, not the algorithm - // itself. - se::blas::AlgorithmConfig algorithm_config(se::blas::kNoAlgorithm); - if (max_algorithm_count == 1) { - algorithm_config.set_algorithm(0); - } else if (!AutoTuneBatchMatmul::GetInstance()->Find(matmul_parameters, - &algorithm_config)) { - VLOG(4) << "Autotuning BlasLtMatmul over " << algorithms.size() - << " algorithms."; - se::blas::ProfileResult best_result; - se::blas::ProfileResult profile_result; - // for (const auto& profile_algorithm : plan_and_algorithms->algorithms) - // { - for (size_t i = 0; i != algorithms.size(); ++i) { - const auto& profile_algorithm = algorithms[i]; - // Create a new scratch allocator with every autotuning run so that - // scratch space is deallocated between runs. - BlasScratchAllocator scratch_allocator(max_scratch_size, context); - - bool cublas_launch_status = - stream - ->ThenBlasLtMatmul(plan.get(), alpha, *b_ptrs[0], *a_ptrs[0], - beta, c_ptrs[0], &scratch_allocator, - profile_algorithm.get(), {}, - &profile_result) - .ok(); - - VLOG(4) << " Autotune algorithm " << i - << " result: " << profile_result.elapsed_time_in_ms() - << " ms, valid=" << profile_result.is_valid() - << ", workspace_size=" << profile_algorithm->workspace_size(); - - if (cublas_launch_status && profile_result.is_valid() && - profile_result.elapsed_time_in_ms() < - best_result.elapsed_time_in_ms()) { - best_result = profile_result; - } - } - - if (best_result.is_valid()) { - algorithm_config.set_algorithm(best_result.algorithm()); - } - // We make sure that each matmul parameter set only gets one pass of - // autotune. If no algorithms works, we add kNoAlgorithm to the autotune - // map. 
- AutoTuneBatchMatmul::GetInstance()->Insert(matmul_parameters, - algorithm_config); - } - se::blas::AlgorithmType algorithm_idx = algorithm_config.algorithm(); - OP_REQUIRES(context, - 0 <= algorithm_idx && algorithm_idx < algorithms.size(), - errors::Internal("Missing/invalid BatchMatmul algorithm")); - const auto& algorithm = algorithms[algorithm_idx]; - BlasScratchAllocator scratch_allocator(max_scratch_size, context); - bool cublas_launch_status = - stream - ->ThenBlasLtMatmul(plan.get(), alpha, *b_ptrs[0], *a_ptrs[0], - beta, c_ptrs[0], &scratch_allocator, - algorithm.get()) - .ok(); - if (!cublas_launch_status) { - context->SetStatus(errors::Internal( - "Blas batched matmul launch failed : a.shape=(", - bcast.x_batch_size(), ", ", in_x.dim_size(0), ", ", - in_x.dim_size(1), "), b.shape=(", bcast.y_batch_size(), ", ", - in_y.dim_size(0), ", ", in_y.dim_size(1), "), m=", m, ", n=", n, - ", k=", k, ", batch_size=", batch_size)); - } - } else { // requires mixed broadcasting - const std::vector& a_batch_indices = bcast.x_batch_indices(); - const std::vector& b_batch_indices = bcast.y_batch_indices(); - for (int64 i = 0; i < bcast.x_batch_size(); ++i) { - a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k)); - } - for (int64 i = 0; i < bcast.y_batch_size(); ++i) { - b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n)); - } - for (int64 i = 0; i < batch_size; ++i) { - c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n)); - a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]); - b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]); - c_ptrs.push_back(&c_device_memory.back()); - } - - BlasScratchAllocator scratch_allocator(max_scratch_size, context); - bool blas_launch_status = - stream - ->ThenBlasGemmBatchedWithScratch( - blas_transpose_b, blas_transpose_a, n, m, k, - static_cast(1.0), b_ptrs, - adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k, - static_cast(0.0), c_ptrs, n, batch_size, - &scratch_allocator) - .ok(); - if (!blas_launch_status) { - context->SetStatus(errors::Internal( - "Blas xGEMMBatched launch failed : a.shape=", - in_x.shape().DebugString(), - ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n, - ", k=", k, ", batch_size=", batch_size)); - } - } - return; -#else // if not GOOGLE_CUDA or CUDA_VERSION < 11000 bool is_full_broadcast = std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1; bool use_strided_batched = @@ -743,6 +388,8 @@ struct LaunchBatchMatMul { } } + typedef Scalar Coefficient; + // Blas does // C = A x B // where A, B and C are assumed to be in column major. @@ -752,10 +399,7 @@ struct LaunchBatchMatMul { if (batch_size == 1) { // This is a regular matrix*matrix or matrix*vector multiply. Avoid the // overhead of the scratch allocator and the batch interface. - // Note that the GEMV call here does not support Eigen::half, so we do not - // use this path in that case. A workaround is applied to the pointers - // passed to the call itself to avoid compilation errors. - if (!std::is_same::value && n == 1 && + if (n == 1 && blas_transpose_b != se::blas::Transpose::kConjugateTranspose && blas_transpose_a != se::blas::Transpose::kConjugateTranspose) { // This is a matrix*vector multiply so use GEMV to compute A * b. @@ -766,19 +410,13 @@ struct LaunchBatchMatMul { auto gemv_trans_a = blas_transpose_a == se::blas::Transpose::kTranspose ? 
se::blas::Transpose::kNoTranspose : se::blas::Transpose::kTranspose; - // Cast pointers as a workaround for GEMV not supporting Eigen::half - // (this will never actually be executed for Eigen::half). - typedef se::DeviceMemory NonHalfDeviceMemoryType; - NonHalfDeviceMemoryType a_ptr(*(a_ptrs[0])); - NonHalfDeviceMemoryType b_ptr(*(b_ptrs[0])); - NonHalfDeviceMemoryType c_ptr(*(c_ptrs[0])); bool blas_launch_status = stream ->ThenBlasGemv(gemv_trans_a, adj_x || trans_x ? m : k, adj_x || trans_x ? k : m, - static_cast(1.0), a_ptr, - adj_x || trans_x ? m : k, b_ptr, 1, - static_cast(0.0), &c_ptr, 1) + static_cast(1.0), *(a_ptrs[0]), + adj_x || trans_x ? m : k, *(b_ptrs[0]), 1, + static_cast(0.0), c_ptrs[0], 1) .ok(); if (!blas_launch_status) { context->SetStatus(errors::Internal( @@ -821,7 +459,154 @@ struct LaunchBatchMatMul { ", k=", k, ", batch_size=", batch_size)); } } else { - BlasScratchAllocator scratch_allocator(max_scratch_size, context); + BlasScratchAllocator scratch_allocator(context); + bool blas_launch_status = + stream + ->ThenBlasGemmBatchedWithScratch( + blas_transpose_b, blas_transpose_a, n, m, k, + static_cast(1.0), b_ptrs, + adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k, + static_cast(0.0), c_ptrs, n, batch_size, + &scratch_allocator) + .ok(); + if (!blas_launch_status) { + context->SetStatus(errors::Internal( + "Blas xGEMMBatched launch failed : a.shape=", + in_x.shape().DebugString(), + ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n, + ", k=", k, ", batch_size=", batch_size)); + } + } + } +}; + +template <> +struct LaunchBatchMatMul { + static void Launch(OpKernelContext* context, const Tensor& in_x, + const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x, + bool trans_y, const MatMulBCast& bcast, Tensor* out) { + typedef Eigen::half Scalar; + se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose, + se::blas::Transpose::kTranspose, + se::blas::Transpose::kConjugateTranspose}; + const uint64 m = in_x.dim_size(adj_x || trans_x ? 2 : 1); + const uint64 k = in_x.dim_size(adj_x || trans_x ? 1 : 2); + const uint64 n = in_y.dim_size(adj_y || trans_y ? 1 : 2); + const uint64 batch_size = bcast.output_batch_size(); + auto blas_transpose_a = trans[adj_x ? 2 : (trans_x ? 1 : 0)]; + auto blas_transpose_b = trans[adj_y ? 2 : (trans_y ? 1 : 0)]; + + auto* stream = context->op_device_context()->stream(); + OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); + + typedef perftools::gputools::DeviceMemory DeviceMemoryType; + std::vector a_device_memory; + std::vector b_device_memory; + std::vector c_device_memory; + std::vector a_ptrs; + std::vector b_ptrs; + std::vector c_ptrs; + a_device_memory.reserve(bcast.x_batch_size()); + b_device_memory.reserve(bcast.y_batch_size()); + c_device_memory.reserve(batch_size); + a_ptrs.reserve(batch_size); + b_ptrs.reserve(batch_size); + c_ptrs.reserve(batch_size); + auto* a_base_ptr = in_x.template flat().data(); + auto* b_base_ptr = in_y.template flat().data(); + auto* c_base_ptr = out->template flat().data(); + + uint64 a_stride; + uint64 b_stride; + uint64 c_stride; + + bool is_full_broadcast = + std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1; + bool use_strided_batched = + (!bcast.IsBroadcastingRequired() || is_full_broadcast) && + batch_size > 1; + if (use_strided_batched) { + a_stride = bcast.x_batch_size() != 1 ? m * k : 0; + b_stride = bcast.y_batch_size() != 1 ? 
k * n : 0; + c_stride = m * n; + a_device_memory.push_back(AsDeviceMemory(a_base_ptr)); + b_device_memory.push_back(AsDeviceMemory(b_base_ptr)); + c_device_memory.push_back(AsDeviceMemory(c_base_ptr)); + a_ptrs.push_back(&a_device_memory.back()); + b_ptrs.push_back(&b_device_memory.back()); + c_ptrs.push_back(&c_device_memory.back()); + } else if (!bcast.IsBroadcastingRequired()) { + for (int64 i = 0; i < batch_size; ++i) { + a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k)); + b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n)); + c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n)); + a_ptrs.push_back(&a_device_memory.back()); + b_ptrs.push_back(&b_device_memory.back()); + c_ptrs.push_back(&c_device_memory.back()); + } + } else { + const std::vector& a_batch_indices = bcast.x_batch_indices(); + const std::vector& b_batch_indices = bcast.y_batch_indices(); + for (int64 i = 0; i < bcast.x_batch_size(); ++i) { + a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k)); + } + for (int64 i = 0; i < bcast.y_batch_size(); ++i) { + b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n)); + } + for (int64 i = 0; i < batch_size; ++i) { + c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n)); + a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]); + b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]); + c_ptrs.push_back(&c_device_memory.back()); + } + } + + typedef float Coefficient; + + // Blas does + // C = A x B + // where A, B and C are assumed to be in column major. + // We want the output to be in row-major, so we can compute + // C' = B' x A', where ' stands for transpose (not adjoint). + // TODO(yangzihao): Choose the best of the three strategies using autotune. + if (batch_size == 1) { + // This is a regular matrix*matrix or matrix*vector multiply. Avoid the + // overhead of the scratch allocator and the batch interface. + // TODO(benbarsdell): Use fp16 Gemv if it becomes supported by CUBLAS + bool blas_launch_status = + stream + ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k, + static_cast(1.0), *(b_ptrs[0]), + adj_y || trans_y ? k : n, *(a_ptrs[0]), + adj_x || trans_x ? m : k, + static_cast(0.0), c_ptrs[0], n) + .ok(); + if (!blas_launch_status) { + context->SetStatus(errors::Internal( + "Blas xGEMM launch failed : a.shape=", in_x.shape().DebugString(), + ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n, + ", k=", k)); + } + } else if (use_strided_batched) { + bool blas_launch_status = + stream + ->ThenBlasGemmStridedBatched( + blas_transpose_b, blas_transpose_a, n, m, k, + static_cast(1.0), *b_ptrs[0], + adj_y || trans_y ? k : n, b_stride, *a_ptrs[0], + adj_x || trans_x ? 
m : k, a_stride, + static_cast(0.0), c_ptrs[0], n, c_stride, + batch_size) + .ok(); + if (!blas_launch_status) { + context->SetStatus(errors::Internal( + "Blas xGEMMStridedBatched launch failed : a.shape=", + in_x.shape().DebugString(), + ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n, + ", k=", k, ", batch_size=", batch_size)); + } + } else { + BlasScratchAllocator scratch_allocator(context); bool blas_launch_status = stream ->ThenBlasGemmBatchedWithScratch( @@ -839,7 +624,6 @@ struct LaunchBatchMatMul { ", k=", k, ", batch_size=", batch_size)); } } -#endif // not GOOGLE_CUDA or CUDA_VERSION < 11000 } }; @@ -853,7 +637,6 @@ class BaseBatchMatMulOp : public OpKernel { : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_)); OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_)); - use_autotune_ = MatmulAutotuneEnable(); } ~BaseBatchMatMulOp() override {} @@ -915,7 +698,7 @@ class BaseBatchMatMulOp : public OpKernel { out->shape().DebugString())); LaunchBatchMatMul::Launch( ctx, in0_reshaped, in1_reshaped, adj_x_, adj_y_, /*trans_x=*/false, - /*trans_y=*/false, bcast, use_autotune_, &out_reshaped); + /*trans_y=*/false, bcast, &out_reshaped); } protected: @@ -925,7 +708,6 @@ class BaseBatchMatMulOp : public OpKernel { private: bool adj_x_; bool adj_y_; - bool use_autotune_; }; // BatchMatMul Op implementation which disallows broadcasting. diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc index f5b9e79fb54..025a8e37a94 100644 --- a/tensorflow/core/kernels/conv_ops.cc +++ b/tensorflow/core/kernels/conv_ops.cc @@ -619,7 +619,19 @@ template struct LaunchConv2DOp; int64 GetDnnWorkspaceLimit(const string& envvar_in_mb, int64 default_value_in_bytes) { - return gpu_utils::GetWorkspaceLimit(envvar_in_mb, default_value_in_bytes); + const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str()); + if (workspace_limit_in_mb_str != nullptr && + strcmp(workspace_limit_in_mb_str, "") != 0) { + int64 scratch_limit_in_mb = -1; + if (strings::safe_strto64(workspace_limit_in_mb_str, + &scratch_limit_in_mb)) { + return scratch_limit_in_mb * (1 << 20); + } else { + LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": " + << workspace_limit_in_mb_str; + } + } + return default_value_in_bytes; } // A dummy type to group forward convolution autotune results together. diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h index 8beab722a64..2e97d486b54 100644 --- a/tensorflow/core/kernels/conv_ops_gpu.h +++ b/tensorflow/core/kernels/conv_ops_gpu.h @@ -48,7 +48,52 @@ int64 GetDnnWorkspaceLimit(const string& envvar_in_mb, // A class to provide scratch-space allocator for Stream-Executor Cudnn // callback. TensorFlow is responsible for releasing the temporary buffers after // the kernel finishes. 
-using DnnScratchAllocator = GpuScratchAllocator; +class DnnScratchAllocator : public se::ScratchAllocator { + public: + virtual ~DnnScratchAllocator() {} + DnnScratchAllocator(int64 memory_limit, OpKernelContext* context) + : memory_limit_(memory_limit), total_byte_size_(0), context_(context) {} + int64 GetMemoryLimitInBytes() override { return memory_limit_; } + se::port::StatusOr> AllocateBytes( + int64 byte_size) override { + Tensor temporary_memory; + if (byte_size < 0) { + return se::port::Status{se::port::error::INVALID_ARGUMENT, + "Requested negative byte size!"}; + } + if (byte_size > memory_limit_) { + return se::port::Status{se::port::error::UNAVAILABLE, + absl::StrCat("Requested memory size (", byte_size, + ") exceeds the max memory limit (", + memory_limit_, ").")}; + } + AllocationAttributes allocation_attr; + allocation_attr.retry_on_failure = false; + Status allocation_status(context_->allocate_temp( + DT_UINT8, TensorShape({byte_size}), &temporary_memory, + AllocatorAttributes(), allocation_attr)); + if (!allocation_status.ok()) { + return se::port::Status{ + se::port::error::UNAVAILABLE, + absl::StrCat("Failed to allocate the requested memory size (", + byte_size, ").")}; + } + // Hold the reference of the allocated tensors until the end of the + // allocator. + allocated_tensors_.push_back(temporary_memory); + total_byte_size_ += byte_size; + return se::port::StatusOr>( + AsDeviceMemory(temporary_memory.flat().data(), + temporary_memory.flat().size())); + } + int64 TotalByteSize() { return total_byte_size_; } + + private: + int64 memory_limit_; + int64 total_byte_size_; + OpKernelContext* context_; + std::vector allocated_tensors_; +}; // Encapsulate all the shape information that is used in both forward and // backward conv operations. diff --git a/tensorflow/core/kernels/fft_ops.cc b/tensorflow/core/kernels/fft_ops.cc index 9b625c256a5..050b83980c6 100644 --- a/tensorflow/core/kernels/fft_ops.cc +++ b/tensorflow/core/kernels/fft_ops.cc @@ -31,7 +31,6 @@ limitations under the License. #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) -#include "tensorflow/core/kernels/gpu_utils.h" #include "tensorflow/core/platform/stream_executor.h" #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM @@ -401,7 +400,20 @@ class CufftScratchAllocator : public se::ScratchAllocator { int64 GetCufftWorkspaceLimit(const string& envvar_in_mb, int64 default_value_in_bytes) { - return gpu_utils::GetWorkspaceLimit(envvar_in_mb, default_value_in_bytes); + const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str()); + if (workspace_limit_in_mb_str != nullptr && + strcmp(workspace_limit_in_mb_str, "") != 0) { + int64 scratch_limit_in_mb = -1; + Status status = ReadInt64FromEnvVar(envvar_in_mb, default_value_in_bytes, + &scratch_limit_in_mb); + if (!status.ok()) { + LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": " + << workspace_limit_in_mb_str; + } else { + return scratch_limit_in_mb * (1 << 20); + } + } + return default_value_in_bytes; } class FFTGPUBase : public FFTBase { diff --git a/tensorflow/core/kernels/gpu_utils.cc b/tensorflow/core/kernels/gpu_utils.cc index 1a14768f487..7da1963c676 100644 --- a/tensorflow/core/kernels/gpu_utils.cc +++ b/tensorflow/core/kernels/gpu_utils.cc @@ -22,7 +22,6 @@ limitations under the License. 
#include "google/protobuf/any.pb.h" #include "absl/algorithm/container.h" #include "absl/base/call_once.h" -#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/logger.h" #include "tensorflow/core/protobuf/autotuning.pb.h" #include "tensorflow/core/protobuf/conv_autotuning.pb.h" @@ -283,64 +282,6 @@ Status BestCudnnConvAlgorithm(absl::Span results, return Status::OK(); } -namespace gpu_utils { -int64 GetWorkspaceLimit(const string& envvar_in_mb, - int64 default_value_in_bytes) { - const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str()); - if (workspace_limit_in_mb_str != nullptr && - strcmp(workspace_limit_in_mb_str, "") != 0) { - int64 scratch_limit_in_mb = -1; - if (strings::safe_strto64(workspace_limit_in_mb_str, - &scratch_limit_in_mb)) { - return scratch_limit_in_mb * (1 << 20); - } else { - LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": " - << workspace_limit_in_mb_str; - } - } - return default_value_in_bytes; -} -} // namespace gpu_utils - -GpuScratchAllocator::GpuScratchAllocator(int64 memory_limit, - OpKernelContext* context) - : memory_limit_(memory_limit), total_byte_size_(0), context_(context) {} - -se::port::StatusOr> GpuScratchAllocator::AllocateBytes( - int64 byte_size) { - Tensor temporary_memory; - if (byte_size < 0) { - return se::port::Status{se::port::error::INVALID_ARGUMENT, - "Requested negative byte size!"}; - } - if (byte_size > memory_limit_) { - return se::port::Status{ - se::port::error::UNAVAILABLE, - absl::StrCat("Requested memory size (", byte_size, - ") exceeds the max memory limit (", memory_limit_, ").")}; - } - AllocationAttributes allocation_attr; - allocation_attr.retry_on_failure = false; - Status allocation_status(context_->allocate_temp( - DT_UINT8, TensorShape({byte_size}), &temporary_memory, - AllocatorAttributes(), allocation_attr)); - if (!allocation_status.ok()) { - return se::port::Status{ - se::port::error::UNAVAILABLE, - absl::StrCat("Failed to allocate the requested memory size (", - byte_size, ").")}; - } - // Hold the reference of the allocated tensors until the end of the - // allocator. - // NOTE: We expect tensors to be deallocated when this allocator goes out of - // scope when allocated_tensors is destructed. - allocated_tensors_.push_back(temporary_memory); - total_byte_size_ += byte_size; - return se::port::StatusOr>( - AsDeviceMemory(temporary_memory.flat().data(), - temporary_memory.flat().size())); -} - } // namespace tensorflow #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/gpu_utils.h b/tensorflow/core/kernels/gpu_utils.h index 62db406513b..a1589db3b5b 100644 --- a/tensorflow/core/kernels/gpu_utils.h +++ b/tensorflow/core/kernels/gpu_utils.h @@ -243,37 +243,6 @@ void LogFusedConvForwardAutotuneResults( Status BestCudnnConvAlgorithm(absl::Span results, se::dnn::AlgorithmConfig* algo); -namespace gpu_utils { -// Get a workspace limit from the environment variable, which is in MB. -// Return the workspace memory limit in bytes. If no value is set, return the -// default value. -int64 GetWorkspaceLimit(const string& envvar_in_mb, - int64 default_value_in_bytes); -} // namespace gpu_utils - -// A class to provide scratch-space allocator for Stream-Executor callbacks in -// CUDA libraries (CUDNN etc.). -// TensorFlow is responsible for releasing the temporary buffers after -// the kernel finishes. 
-class GpuScratchAllocator : public se::ScratchAllocator { - public: - virtual ~GpuScratchAllocator() {} - - GpuScratchAllocator(int64 memory_limit, OpKernelContext* context); - - int64 GetMemoryLimitInBytes() override { return memory_limit_; } - - se::port::StatusOr> AllocateBytes( - int64 byte_size) override; - - int64 TotalByteSize() { return total_byte_size_; } - - private: - int64 memory_limit_; - int64 total_byte_size_; - OpKernelContext* context_; - std::vector allocated_tensors_; -}; } // namespace tensorflow #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/linalg/einsum_op_impl.h b/tensorflow/core/kernels/linalg/einsum_op_impl.h index 1fe9a34e67d..b9b2d1f0eae 100644 --- a/tensorflow/core/kernels/linalg/einsum_op_impl.h +++ b/tensorflow/core/kernels/linalg/einsum_op_impl.h @@ -549,7 +549,7 @@ struct EinsumHelper { static Status ContractOperands(OpKernelContext* ctx, absl::Span inputs, absl::Span swap_free_and_contract, - bool use_autotune, Tensor* output) { + Tensor* output) { if (inputs.size() == 1) return CopyFrom(inputs[0], inputs[0].shape(), output); MatMulBCast bcast(inputs[0].shape().dim_sizes(), @@ -583,7 +583,7 @@ struct EinsumHelper { ReshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped)); LaunchBatchMatMul::Launch(ctx, lhs, rhs, /*adj_x=*/false, /*adj_y=*/false, trans_x, trans_y, - bcast, use_autotune, &output_reshaped); + bcast, &output_reshaped); return Status::OK(); } }; @@ -598,7 +598,6 @@ class EinsumOp : public OpKernel { equation_, &input_labels_, &output_labels_, &label_types_, &input_label_counts_, &output_label_counts_, &input_has_ellipsis_, &output_has_ellipsis_)); - use_autotune_ = MatmulAutotuneEnable(); } void Compute(OpKernelContext* ctx) override { @@ -641,7 +640,7 @@ class EinsumOp : public OpKernel { Tensor contraction_output_reshaped; OP_REQUIRES_OK(ctx, EinsumHelper::ContractOperands( ctx, inputs_reduced, swap_free_and_contract, - use_autotune_, &contraction_output_reshaped)); + &contraction_output_reshaped)); // Copy the batch labels from the contraction output. Recover the batch // shape, which may have been broadcasted. @@ -739,7 +738,6 @@ class EinsumOp : public OpKernel { LabelCounts output_label_counts_; gtl::InlinedVector input_has_ellipsis_; bool output_has_ellipsis_ = false; - bool use_autotune_; }; #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/util/matmul_autotune.cc b/tensorflow/core/util/matmul_autotune.cc index c30a5d930e7..741a78a193f 100644 --- a/tensorflow/core/util/matmul_autotune.cc +++ b/tensorflow/core/util/matmul_autotune.cc @@ -48,22 +48,4 @@ bool MatmulDoFP32ComputationFP16Input() { return value; } -int MatmulMaxAutotuneAlgorithmCount() { - int64 value; - // In CUDA 11, cublasLtMatmulAlgoGetHeuristic typically returns <= 4 - // algorithms for a given configuration, so 10 seems like a reasonable default - // here. 
- Status status = - ReadInt64FromEnvVar("TF_MATMUL_AUTOTUNE_MAX_ALGORITHMS", 10, &value); - if (!status.ok()) { - LOG(ERROR) << status.error_message(); - } - static constexpr const int kMaxValue = std::numeric_limits::max(); - if (value < 1 || value > kMaxValue) { - LOG(ERROR) << "Invalid value for TF_MATMUL_AUTOTUNE_MAX_ALGORITHMS: " - << value << " is not in range [1, " << kMaxValue << "]"; - } - return value; -} - } // namespace tensorflow diff --git a/tensorflow/core/util/matmul_autotune.h b/tensorflow/core/util/matmul_autotune.h index c77d274e781..5846cae2fc7 100644 --- a/tensorflow/core/util/matmul_autotune.h +++ b/tensorflow/core/util/matmul_autotune.h @@ -22,7 +22,6 @@ namespace tensorflow { bool MatmulAutotuneEnable(); bool MatmulDoFP32ComputationFP16Input(); -int MatmulMaxAutotuneAlgorithmCount(); } // namespace tensorflow diff --git a/third_party/gpus/cuda/BUILD.tpl b/third_party/gpus/cuda/BUILD.tpl index 70eacf82883..a4a21abc367 100644 --- a/third_party/gpus/cuda/BUILD.tpl +++ b/third_party/gpus/cuda/BUILD.tpl @@ -127,13 +127,6 @@ cc_library( linkstatic = 1, ) -cc_library( - name = "cublasLt", - srcs = ["cuda/lib/%{cublasLt_lib}"], - data = ["cuda/lib/%{cublasLt_lib}"], - linkstatic = 1, -) - cc_library( name = "cusolver", srcs = ["cuda/lib/%{cusolver_lib}"], @@ -175,7 +168,6 @@ cc_library( name = "cuda", deps = [ ":cublas", - ":cublasLt", ":cuda_headers", ":cudart", ":cudnn", diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index 3ba34470b93..704003b7f63 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -551,13 +551,6 @@ def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config): cuda_config.cublas_version, static = False, ), - "cublasLt": _check_cuda_lib_params( - "cublasLt", - cpu_value, - cuda_config.config["cublas_library_dir"], - cuda_config.cublas_version, - static = False, - ), "cusolver": _check_cuda_lib_params( "cusolver", cpu_value, @@ -787,7 +780,6 @@ def _create_dummy_repository(repository_ctx): "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value), "%{cudart_lib}": lib_name("cudart", cpu_value), "%{cublas_lib}": lib_name("cublas", cpu_value), - "%{cublasLt_lib}": lib_name("cublasLt", cpu_value), "%{cusolver_lib}": lib_name("cusolver", cpu_value), "%{cudnn_lib}": lib_name("cudnn", cpu_value), "%{cufft_lib}": lib_name("cufft", cpu_value), @@ -819,7 +811,6 @@ filegroup(name="cudnn-include") "cuda/cuda/lib/%s" % lib_name("cudart_static", cpu_value), ) repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublas", cpu_value)) - repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublasLt", cpu_value)) repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusolver", cpu_value)) repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudnn", cpu_value)) repository_ctx.file("cuda/cuda/lib/%s" % lib_name("curand", cpu_value)) @@ -1011,13 +1002,11 @@ def _create_local_cuda_repository(repository_ctx): cublas_include_path + "/cublas.h", cublas_include_path + "/cublas_v2.h", cublas_include_path + "/cublas_api.h", - cublas_include_path + "/cublasLt.h", ], outs = [ "cublas/include/cublas.h", "cublas/include/cublas_v2.h", "cublas/include/cublas_api.h", - "cublas/include/cublasLt.h", ], )) @@ -1158,7 +1147,6 @@ def _create_local_cuda_repository(repository_ctx): "%{cudart_static_linkopt}": _cudart_static_linkopt(cuda_config.cpu_value), "%{cudart_lib}": _basename(repository_ctx, cuda_libs["cudart"]), "%{cublas_lib}": _basename(repository_ctx, cuda_libs["cublas"]), - "%{cublasLt_lib}": 
_basename(repository_ctx, cuda_libs["cublasLt"]), "%{cusolver_lib}": _basename(repository_ctx, cuda_libs["cusolver"]), "%{cudnn_lib}": _basename(repository_ctx, cuda_libs["cudnn"]), "%{cufft_lib}": _basename(repository_ctx, cuda_libs["cufft"]),
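
For reference, a minimal standalone sketch (not TensorFlow code) of the env-var workspace-limit pattern that this rollback copies back into conv_ops.cc and fft_ops.cc in place of the shared gpu_utils::GetWorkspaceLimit helper. The function name GetWorkspaceLimitInBytes, the example environment variable, and the default below are illustrative only; the real kernels use their own per-file helpers (GetDnnWorkspaceLimit, GetCufftWorkspaceLimit) and TensorFlow's string utilities.

    // Sketch: parse a megabyte limit from an environment variable, falling
    // back to a default (already in bytes) when unset or unparsable.
    #include <cstdint>
    #include <cstdlib>
    #include <cstring>
    #include <iostream>

    int64_t GetWorkspaceLimitInBytes(const char* envvar_in_mb,
                                     int64_t default_value_in_bytes) {
      const char* str = std::getenv(envvar_in_mb);
      if (str != nullptr && std::strcmp(str, "") != 0) {
        char* end = nullptr;
        long long limit_in_mb = std::strtoll(str, &end, 10);
        if (end != str && *end == '\0' && limit_in_mb >= 0) {
          return static_cast<int64_t>(limit_in_mb) * (1 << 20);  // MB -> bytes
        }
        std::cerr << "Invalid value for env-var " << envvar_in_mb << ": "
                  << str << "\n";
      }
      return default_value_in_bytes;
    }

    int main() {
      // Example call; the variable name and 4GB default are placeholders.
      std::cout << GetWorkspaceLimitInBytes("TF_CUBLAS_WORKSPACE_LIMIT_IN_MB",
                                            1LL << 32)
                << " bytes\n";
      return 0;
    }

A second sketch, also standalone and simplified, shows the ownership idea behind the per-kernel scratch allocators this rollback restores (BlasScratchAllocator in batch_matmul_op_impl.h, DnnScratchAllocator in conv_ops_gpu.h): buffers handed out stay alive until the allocator object itself is destroyed, so enqueued GPU work can keep using them and TensorFlow releases them afterwards. The class name ScratchArena and the raw-pointer interface are assumptions for illustration; the real allocators hold Tensors allocated via OpKernelContext::allocate_temp and implement se::ScratchAllocator.

    // Sketch: an arena that keeps every allocation alive until it goes out
    // of scope, mirroring the allocated_tensors_ member of the TF allocators.
    #include <cstdint>
    #include <memory>
    #include <vector>

    class ScratchArena {
     public:
      // Returns raw scratch memory; the arena retains ownership.
      uint8_t* AllocateBytes(int64_t byte_size) {
        buffers_.push_back(std::make_unique<uint8_t[]>(byte_size));
        return buffers_.back().get();
      }

     private:
      std::vector<std::unique_ptr<uint8_t[]>> buffers_;  // freed in ~ScratchArena
    };

    int main() {
      ScratchArena arena;                      // lives for the whole "launch"
      uint8_t* scratch = arena.AllocateBytes(1 << 20);
      scratch[0] = 0;                          // pretend a kernel used it
      return 0;                                // buffers released here
    }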