From 0d172940c102b37300c3ddb7d8dbd3835382a474 Mon Sep 17 00:00:00 2001 From: Ben Barsdell Date: Mon, 6 Jul 2020 21:25:04 +1000 Subject: [PATCH] Use BlasLtMatmul APIs in batch_matmul_op_impl - Integrates BlasLtMatmul with autotuning into the implementation of the BatchMatMul and Einsum ops. - This integration is only used when the CUDA version is >= 11.0. --- tensorflow/core/kernels/BUILD | 4 + .../core/kernels/batch_matmul_op_impl.h | 584 ++++++++++++------ tensorflow/core/kernels/gpu_utils.cc | 57 ++ tensorflow/core/kernels/gpu_utils.h | 36 ++ .../core/kernels/linalg/einsum_op_impl.h | 7 +- tensorflow/core/util/matmul_autotune.cc | 18 + tensorflow/core/util/matmul_autotune.h | 1 + 7 files changed, 522 insertions(+), 185 deletions(-) diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 9917b8e5c95..1eec9056040 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -3334,6 +3334,9 @@ tf_kernel_library( prefix = "batch_matmul_op", deps = MATH_DEPS + [":eigen_contraction_kernel"] + if_mkl_ml([ "//third_party/mkl:intel_binary_blob", + ]) + if_cuda([ + "//tensorflow/core/kernels:gpu_utils", + "//tensorflow/core/platform:tensor_float_32_utils", ]), ) @@ -3392,6 +3395,7 @@ tf_kernel_library( prefix = "fft_ops", deps = MATH_DEPS + [ ] + if_cuda([ + "//tensorflow/core/kernels:gpu_utils", "//tensorflow/core/platform/default/build_config:cufft_plugin", ]), ) diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/batch_matmul_op_impl.h index d6cc980633f..5ca85c00835 100644 --- a/tensorflow/core/kernels/batch_matmul_op_impl.h +++ b/tensorflow/core/kernels/batch_matmul_op_impl.h @@ -22,7 +22,6 @@ limitations under the License. #include -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -34,17 +33,24 @@ limitations under the License. 
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/tensor_float_32_utils.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/matmul_autotune.h" #include "tensorflow/core/util/matmul_bcast.h" #include "tensorflow/core/util/work_sharder.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) #include "tensorflow/core/kernels/eigen_contraction_kernel.h" #endif #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "tensorflow/core/kernels/gpu_utils.h" #include "tensorflow/core/platform/stream_executor.h" #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#if GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuda.h" // For CUDA_VERSION +#endif namespace tensorflow { @@ -219,7 +225,8 @@ template struct LaunchBatchMatMul { static void Launch(OpKernelContext* context, const Tensor& in_x, const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x, - bool trans_y, const MatMulBCast& bcast, Tensor* out) { + bool trans_y, const MatMulBCast& bcast, bool use_autotune, + Tensor* out) { typedef ParallelMatMulKernel::IsComplex> ParallelMatMulKernel; bool conjugate_result = false; @@ -275,45 +282,201 @@ se::DeviceMemory AsDeviceMemory(const T* gpu_memory) { return typed; } -class BlasScratchAllocator : public se::ScratchAllocator { +using BlasScratchAllocator = GpuScratchAllocator; + +int64 GetBlasWorkspaceLimit(const string& envvar_in_mb, + int64 default_value_in_bytes) { + return GetWorkspaceLimit(envvar_in_mb, default_value_in_bytes); +} + +// Encapsulate all of the shape, dtype etc. information that defines a unique +// batched matmul operation. +class BatchMatmulParameters { public: - using Stream = se::Stream; - using DeviceMemoryBytes = se::DeviceMemory; + BatchMatmulParameters(bool trans_a, bool trans_b, bool adj_a, bool adj_b, + uint64 m, uint64 n, uint64 k, uint64 batch_count, + bool broadcast_a, bool broadcast_b, DataType dtype_ab, + DataType dtype_cd, bool allow_tf32, int device_id) + : trans_a_(trans_a), + trans_b_(trans_b), + adj_a_(adj_a), + adj_b_(adj_b), + m_(m), + n_(n), + k_(k), + batch_count_(batch_count), + broadcast_a_(broadcast_a), + broadcast_b_(broadcast_b), + dtype_ab_(dtype_ab), + dtype_cd_(dtype_cd), + allow_tf32_(allow_tf32), + device_id_(device_id) { + hash_code_ = trans_a; + hash_code_ = Hash64Combine(hash_code_, trans_b); + hash_code_ = Hash64Combine(hash_code_, adj_a); + hash_code_ = Hash64Combine(hash_code_, adj_b); + hash_code_ = Hash64Combine(hash_code_, m); + hash_code_ = Hash64Combine(hash_code_, n); + hash_code_ = Hash64Combine(hash_code_, k); + hash_code_ = Hash64Combine(hash_code_, batch_count); + hash_code_ = Hash64Combine(hash_code_, broadcast_a); + hash_code_ = Hash64Combine(hash_code_, broadcast_b); + hash_code_ = Hash64Combine(hash_code_, dtype_ab); + hash_code_ = Hash64Combine(hash_code_, dtype_cd); + hash_code_ = Hash64Combine(hash_code_, allow_tf32); + hash_code_ = Hash64Combine(hash_code_, device_id); + } + bool operator==(const BatchMatmulParameters& other) const { + return this->get_data_as_tuple() == other.get_data_as_tuple(); + } - BlasScratchAllocator(OpKernelContext* context) : context_(context) {} + bool operator!=(const BatchMatmulParameters& other) const { + return !(*this == other); + } + uint64 hash() const { return hash_code_; } - int64 GetMemoryLimitInBytes() override { return -1; } - - se::port::StatusOr AllocateBytes( - int64 byte_size) 
override { - Tensor temporary_memory; - - Status allocation_status(context_->allocate_temp( - DT_UINT8, TensorShape({byte_size}), &temporary_memory)); - if (!allocation_status.ok()) { - return se::port::StatusOr( - DeviceMemoryBytes::MakeFromByteSize(nullptr, 0)); - } - // Hold the reference of the allocated tensors until the end of the - // allocator. - allocated_tensors_.push_back(temporary_memory); - return se::port::StatusOr( - DeviceMemoryBytes::MakeFromByteSize( - temporary_memory.flat().data(), - temporary_memory.flat().size())); + string ToString() const { + // clang-format off + return strings::StrCat( + trans_a_, ", ", trans_b_, ", ", adj_a_, ", ", adj_b_, ", ", + m_, ", ", n_, ", ", k_, ", ", batch_count_, ", ", + broadcast_a_, ", ", broadcast_b_, ", ", + dtype_ab_, ", ", dtype_cd_, ", ", allow_tf32_, ", ", device_id_); + // clang-format on } private: - OpKernelContext* context_; - std::vector allocated_tensors_; + typedef std::tuple + ParameterDataType; + + ParameterDataType get_data_as_tuple() const { + return std::make_tuple(trans_a_, trans_b_, adj_a_, adj_b_, m_, n_, k_, + batch_count_, broadcast_a_, broadcast_b_, dtype_ab_, + dtype_cd_, allow_tf32_, device_id_); + } + + bool trans_a_; + bool trans_b_; + bool adj_a_; + bool adj_b_; + uint64 m_; + uint64 n_; + uint64 k_; + uint64 batch_count_; + bool broadcast_a_; + bool broadcast_b_; + DataType dtype_ab_; + DataType dtype_cd_; + bool allow_tf32_; + int device_id_; + uint64 hash_code_; }; + +bool GetBlasComputationType(const DataType& dtype, bool allow_tf32, + se::blas::ComputationType* compute_type) { + using se::blas::ComputationType; + static bool use_f32_for_f16_computation = MatmulDoFP32ComputationFP16Input(); + ComputationType f32_type = + allow_tf32 ? ComputationType::kF32FastTF32 : ComputationType::kF32; + switch (dtype) { + case DT_HALF: + case DT_BFLOAT16: + *compute_type = + use_f32_for_f16_computation ? f32_type : ComputationType::kF16; + return true; + case DT_FLOAT: + *compute_type = f32_type; + return true; + case DT_DOUBLE: + *compute_type = ComputationType::kF64; + return true; + case DT_COMPLEX64: + *compute_type = f32_type; + return true; + case DT_COMPLEX128: + *compute_type = ComputationType::kComplexF64; + return true; + default: + // Unsupported compute_type, return false. + return false; + } +} + +// Thread-safe map from matmul parameters to their corresponding plan and +// algorithms. +template +class BlasLtMatmulPlanMap { + public: + struct PlanAndAlgorithms { + std::unique_ptr plan; + std::vector> algorithms; + }; + + const PlanAndAlgorithms* Find(const Parameters& params) { + mutex_lock lock(mu_); + auto iter = params_plan_map_.find(params); + if (iter == params_plan_map_.end()) { + return nullptr; + } + return &iter->second; + } + const PlanAndAlgorithms* Insert(const Parameters& params, + PlanAndAlgorithms value) { + mutex_lock lock(mu_); + return ¶ms_plan_map_.emplace(params, std::move(value)).first->second; + } + + private: + struct Hasher { + std::size_t operator()(const Parameters& parameter) const { + return parameter.hash(); + } + }; + + mutable mutex mu_; + std::unordered_map params_plan_map_ + GUARDED_BY(mu_); +}; + +template +struct BlasLtPlanMapSingleton { + typedef BlasLtMatmulPlanMap PlanMapType; + static PlanMapType* GetInstance() { + static PlanMapType* instance = new PlanMapType(); + return instance; + } +}; + +typedef BlasLtPlanMapSingleton + BatchMatmulPlanMapSingleton; + +// A dummy type to group matmul autotune results together. 
+struct BatchMatmulAutoTuneGroup { + static string name() { return "MatmulLt"; } +}; + +typedef AutoTuneSingleton<BatchMatmulAutoTuneGroup, BatchMatmulParameters, se::blas::AlgorithmConfig> + AutoTuneBatchMatmul; + +template <typename Scalar> +struct CoefficientType { + typedef Scalar type; +}; +template <> +struct CoefficientType<Eigen::half> { + typedef float type; +}; + } // namespace template <typename Scalar> struct LaunchBatchMatMul<GPUDevice, Scalar> { static void Launch(OpKernelContext* context, const Tensor& in_x, const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x, - bool trans_y, const MatMulBCast& bcast, Tensor* out) { + bool trans_y, const MatMulBCast& bcast, bool use_autotune, + Tensor* out) { se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose, se::blas::Transpose::kTranspose, se::blas::Transpose::kConjugateTranspose}; @@ -347,6 +510,198 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> { uint64 b_stride; uint64 c_stride; + typedef typename CoefficientType<Scalar>::type Coefficient; + + static const int64 max_scratch_size = GetBlasWorkspaceLimit( + "TF_CUBLAS_WORKSPACE_LIMIT_IN_MB", 1LL << 32); // 4GB by default + + // The BlasLtMatmul routines are only supported from CUDA 11.0 onward. +#if GOOGLE_CUDA && CUDA_VERSION >= 11000 + bool is_full_broadcast = + std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1; + bool requires_mixed_broadcasting = + bcast.IsBroadcastingRequired() && !is_full_broadcast; + if (!requires_mixed_broadcasting) { + bool broadcast_a = bcast.x_batch_size() == 1; + bool broadcast_b = bcast.y_batch_size() == 1; + a_stride = broadcast_a ? 0 : m * k; + b_stride = broadcast_b ? 0 : k * n; + c_stride = m * n; + a_device_memory.push_back(AsDeviceMemory(a_base_ptr)); + b_device_memory.push_back(AsDeviceMemory(b_base_ptr)); + c_device_memory.push_back(AsDeviceMemory(c_base_ptr)); + a_ptrs.push_back(&a_device_memory.back()); + b_ptrs.push_back(&b_device_memory.back()); + c_ptrs.push_back(&c_device_memory.back()); + + DataType dtype = DataTypeToEnum<Scalar>::value; + bool allow_tf32 = tensor_float_32_execution_enabled(); + int device_id = stream->parent()->device_ordinal(); + BatchMatmulParameters matmul_parameters( + trans_x, trans_y, adj_x, adj_y, m, n, k, batch_size, broadcast_a, + broadcast_b, dtype, dtype, allow_tf32, device_id); + + static const int64 max_autotune_algorithm_count = + MatmulMaxAutotuneAlgorithmCount(); + int max_algorithm_count = use_autotune ? 
max_autotune_algorithm_count : 1; + + const auto* plan_and_algorithms = + BatchMatmulPlanMapSingleton::GetInstance()->Find(matmul_parameters); + if (!plan_and_algorithms) { + se::blas::DataType blas_dtype = se::blas::ToDataType<Scalar>::value; + se::blas::ComputationType computation_type; + OP_REQUIRES( + context, + GetBlasComputationType(dtype, allow_tf32, &computation_type), + errors::Internal("Unsupported dtype for batched matmul")); + std::unique_ptr<se::blas::IBlasLtMatmulPlan> plan = + stream->parent()->CreateBlasLtMatmulPlanStridedBatched( + /*ab_type=*/blas_dtype, + /*cd_type=*/blas_dtype, computation_type, + se::blas::PointerMode::kHost, blas_transpose_b, + blas_transpose_a, n, m, k, batch_size, + /*lda=*/in_y.dim_size(2), b_stride, + /*ldb=*/in_x.dim_size(2), a_stride, /*ldc=*/n, c_stride); + OP_REQUIRES( + context, plan, + errors::Internal( + "CreateBlasLtMatmulPlanStridedBatched failed : a.shape=(", + in_x.dim_size(0), ", ", in_x.dim_size(1), ", ", + in_x.dim_size(2), "), b.shape=(", in_y.dim_size(0), ", ", + in_y.dim_size(1), ", ", in_y.dim_size(2), "), m=", m, ", n=", n, + ", k=", k, ", batch_size=", batch_size, ", adjoint_a=", adj_x, + ", adjoint_b=", adj_y, ", dtype=", dtype, + ", computation_type=", computation_type)); + std::vector<std::unique_ptr<se::blas::IBlasLtMatmulAlgorithm>> + algorithms; + OP_REQUIRES( + context, + stream->parent()->GetBlasLtMatmulAlgorithms( + plan.get(), max_scratch_size, max_algorithm_count, &algorithms), + errors::Internal("GetBlasLtMatmulAlgorithms failed: a.shape=(", + in_x.dim_size(0), ", ", in_x.dim_size(1), ", ", + in_x.dim_size(2), "), b.shape=(", in_y.dim_size(0), + ", ", in_y.dim_size(1), ", ", in_y.dim_size(2), + "), m=", m, ", n=", n, ", k=", k, + ", batch_size=", batch_size, ", adjoint_a=", adj_x, + ", adjoint_b=", adj_y, ", dtype=", dtype, + ", computation_type=", computation_type)); + plan_and_algorithms = + BatchMatmulPlanMapSingleton::GetInstance()->Insert( + matmul_parameters, {std::move(plan), std::move(algorithms)}); + } + const auto& plan = plan_and_algorithms->plan; + const auto& algorithms = plan_and_algorithms->algorithms; + + // The BlasLtMatmul routines (unlike BlasGemm, BlasGemmBatched etc.) take + // alpha and beta with the same type as the matrices. + Scalar alpha(1.0); + Scalar beta(0.0); + + // Note that algorithm_config.algorithm() here is used to refer + // to the index within the algorithms vector, not the algorithm + // itself. + se::blas::AlgorithmConfig algorithm_config(se::blas::kNoAlgorithm); + if (max_algorithm_count == 1) { + algorithm_config.set_algorithm(0); + } else if (!AutoTuneBatchMatmul::GetInstance()->Find(matmul_parameters, + &algorithm_config)) { + VLOG(4) << "Autotuning BlasLtMatmul over " << algorithms.size() + << " algorithms."; + se::blas::ProfileResult best_result; + se::blas::ProfileResult profile_result; + for (size_t i = 0; i != algorithms.size(); ++i) { + const auto& profile_algorithm = algorithms[i]; + // Create a new scratch allocator with every autotuning run so that + // scratch space is deallocated between runs. 
+ BlasScratchAllocator scratch_allocator(max_scratch_size, context); + + bool cublas_launch_status = + stream + ->ThenBlasLtMatmul(plan.get(), alpha, *b_ptrs[0], *a_ptrs[0], + beta, c_ptrs[0], &scratch_allocator, + profile_algorithm.get(), &profile_result) + .ok(); + + VLOG(4) << " Autotune algorithm " << i + << " result: " << profile_result.elapsed_time_in_ms() + << " ms, valid=" << profile_result.is_valid() + << ", workspace_size=" + << profile_algorithm->workspace_size(); + + if (cublas_launch_status && profile_result.is_valid() && + profile_result.elapsed_time_in_ms() < + best_result.elapsed_time_in_ms()) { + best_result = profile_result; + } + } + + if (best_result.is_valid()) { + algorithm_config.set_algorithm(best_result.algorithm()); + } + // We make sure that each matmul parameter set only gets one pass of + // autotune. If no algorithms works, we add kNoAlgorithm to the autotune + // map. + AutoTuneBatchMatmul::GetInstance()->Insert(matmul_parameters, + algorithm_config); + } + se::blas::AlgorithmType algorithm_idx = algorithm_config.algorithm(); + OP_REQUIRES(context, + 0 <= algorithm_idx && algorithm_idx < algorithms.size(), + errors::Internal("Missing/invalid BatchMatmul algorithm")); + const auto& algorithm = algorithms[algorithm_idx]; + BlasScratchAllocator scratch_allocator(max_scratch_size, context); + bool cublas_launch_status = + stream + ->ThenBlasLtMatmul(plan.get(), alpha, *b_ptrs[0], *a_ptrs[0], + beta, c_ptrs[0], &scratch_allocator, + algorithm.get()) + .ok(); + if (!cublas_launch_status) { + context->SetStatus(errors::Internal( + "Blas batched matmul launch failed : a.shape=(", + bcast.x_batch_size(), ", ", in_x.dim_size(0), ", ", + in_x.dim_size(1), "), b.shape=(", bcast.y_batch_size(), ", ", + in_y.dim_size(0), ", ", in_y.dim_size(1), "), m=", m, ", n=", n, + ", k=", k, ", batch_size=", batch_size)); + } + } else { // requires mixed broadcasting + const std::vector& a_batch_indices = bcast.x_batch_indices(); + const std::vector& b_batch_indices = bcast.y_batch_indices(); + for (int64 i = 0; i < bcast.x_batch_size(); ++i) { + a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k)); + } + for (int64 i = 0; i < bcast.y_batch_size(); ++i) { + b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n)); + } + for (int64 i = 0; i < batch_size; ++i) { + c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n)); + a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]); + b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]); + c_ptrs.push_back(&c_device_memory.back()); + } + + BlasScratchAllocator scratch_allocator(max_scratch_size, context); + bool blas_launch_status = + stream + ->ThenBlasGemmBatchedWithScratch( + blas_transpose_b, blas_transpose_a, n, m, k, + static_cast(1.0), b_ptrs, + adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k, + static_cast(0.0), c_ptrs, n, batch_size, + &scratch_allocator) + .ok(); + if (!blas_launch_status) { + context->SetStatus(errors::Internal( + "Blas xGEMMBatched launch failed : a.shape=", + in_x.shape().DebugString(), + ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n, + ", k=", k, ", batch_size=", batch_size)); + } + } + return; +#else // if not GOOGLE_CUDA or CUDA_VERSION < 11000 bool is_full_broadcast = std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1; bool use_strided_batched = @@ -388,8 +743,6 @@ struct LaunchBatchMatMul { } } - typedef Scalar Coefficient; - // Blas does // C = A x B // where A, B and C are assumed to be in column major. 
@@ -399,7 +752,10 @@ struct LaunchBatchMatMul { if (batch_size == 1) { // This is a regular matrix*matrix or matrix*vector multiply. Avoid the // overhead of the scratch allocator and the batch interface. - if (n == 1 && + // Note that the GEMV call here does not support Eigen::half, so we do not + // use this path in that case. A workaround is applied to the pointers + // passed to the call itself to avoid compilation errors. + if (!std::is_same::value && n == 1 && blas_transpose_b != se::blas::Transpose::kConjugateTranspose && blas_transpose_a != se::blas::Transpose::kConjugateTranspose) { // This is a matrix*vector multiply so use GEMV to compute A * b. @@ -410,13 +766,19 @@ struct LaunchBatchMatMul { auto gemv_trans_a = blas_transpose_a == se::blas::Transpose::kTranspose ? se::blas::Transpose::kNoTranspose : se::blas::Transpose::kTranspose; + // Cast pointers as a workaround for GEMV not supporting Eigen::half + // (this will never actually be executed for Eigen::half). + typedef se::DeviceMemory NonHalfDeviceMemoryType; + NonHalfDeviceMemoryType a_ptr(*(a_ptrs[0])); + NonHalfDeviceMemoryType b_ptr(*(b_ptrs[0])); + NonHalfDeviceMemoryType c_ptr(*(c_ptrs[0])); bool blas_launch_status = stream ->ThenBlasGemv(gemv_trans_a, adj_x || trans_x ? m : k, adj_x || trans_x ? k : m, - static_cast(1.0), *(a_ptrs[0]), - adj_x || trans_x ? m : k, *(b_ptrs[0]), 1, - static_cast(0.0), c_ptrs[0], 1) + static_cast(1.0), a_ptr, + adj_x || trans_x ? m : k, b_ptr, 1, + static_cast(0.0), &c_ptr, 1) .ok(); if (!blas_launch_status) { context->SetStatus(errors::Internal( @@ -459,154 +821,7 @@ struct LaunchBatchMatMul { ", k=", k, ", batch_size=", batch_size)); } } else { - BlasScratchAllocator scratch_allocator(context); - bool blas_launch_status = - stream - ->ThenBlasGemmBatchedWithScratch( - blas_transpose_b, blas_transpose_a, n, m, k, - static_cast(1.0), b_ptrs, - adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k, - static_cast(0.0), c_ptrs, n, batch_size, - &scratch_allocator) - .ok(); - if (!blas_launch_status) { - context->SetStatus(errors::Internal( - "Blas xGEMMBatched launch failed : a.shape=", - in_x.shape().DebugString(), - ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n, - ", k=", k, ", batch_size=", batch_size)); - } - } - } -}; - -template <> -struct LaunchBatchMatMul { - static void Launch(OpKernelContext* context, const Tensor& in_x, - const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x, - bool trans_y, const MatMulBCast& bcast, Tensor* out) { - typedef Eigen::half Scalar; - se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose, - se::blas::Transpose::kTranspose, - se::blas::Transpose::kConjugateTranspose}; - const uint64 m = in_x.dim_size(adj_x || trans_x ? 2 : 1); - const uint64 k = in_x.dim_size(adj_x || trans_x ? 1 : 2); - const uint64 n = in_y.dim_size(adj_y || trans_y ? 1 : 2); - const uint64 batch_size = bcast.output_batch_size(); - auto blas_transpose_a = trans[adj_x ? 2 : (trans_x ? 1 : 0)]; - auto blas_transpose_b = trans[adj_y ? 2 : (trans_y ? 
1 : 0)]; - - auto* stream = context->op_device_context()->stream(); - OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); - - typedef perftools::gputools::DeviceMemory DeviceMemoryType; - std::vector a_device_memory; - std::vector b_device_memory; - std::vector c_device_memory; - std::vector a_ptrs; - std::vector b_ptrs; - std::vector c_ptrs; - a_device_memory.reserve(bcast.x_batch_size()); - b_device_memory.reserve(bcast.y_batch_size()); - c_device_memory.reserve(batch_size); - a_ptrs.reserve(batch_size); - b_ptrs.reserve(batch_size); - c_ptrs.reserve(batch_size); - auto* a_base_ptr = in_x.template flat().data(); - auto* b_base_ptr = in_y.template flat().data(); - auto* c_base_ptr = out->template flat().data(); - - uint64 a_stride; - uint64 b_stride; - uint64 c_stride; - - bool is_full_broadcast = - std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1; - bool use_strided_batched = - (!bcast.IsBroadcastingRequired() || is_full_broadcast) && - batch_size > 1; - if (use_strided_batched) { - a_stride = bcast.x_batch_size() != 1 ? m * k : 0; - b_stride = bcast.y_batch_size() != 1 ? k * n : 0; - c_stride = m * n; - a_device_memory.push_back(AsDeviceMemory(a_base_ptr)); - b_device_memory.push_back(AsDeviceMemory(b_base_ptr)); - c_device_memory.push_back(AsDeviceMemory(c_base_ptr)); - a_ptrs.push_back(&a_device_memory.back()); - b_ptrs.push_back(&b_device_memory.back()); - c_ptrs.push_back(&c_device_memory.back()); - } else if (!bcast.IsBroadcastingRequired()) { - for (int64 i = 0; i < batch_size; ++i) { - a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k)); - b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n)); - c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n)); - a_ptrs.push_back(&a_device_memory.back()); - b_ptrs.push_back(&b_device_memory.back()); - c_ptrs.push_back(&c_device_memory.back()); - } - } else { - const std::vector& a_batch_indices = bcast.x_batch_indices(); - const std::vector& b_batch_indices = bcast.y_batch_indices(); - for (int64 i = 0; i < bcast.x_batch_size(); ++i) { - a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k)); - } - for (int64 i = 0; i < bcast.y_batch_size(); ++i) { - b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n)); - } - for (int64 i = 0; i < batch_size; ++i) { - c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n)); - a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]); - b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]); - c_ptrs.push_back(&c_device_memory.back()); - } - } - - typedef float Coefficient; - - // Blas does - // C = A x B - // where A, B and C are assumed to be in column major. - // We want the output to be in row-major, so we can compute - // C' = B' x A', where ' stands for transpose (not adjoint). - // TODO(yangzihao): Choose the best of the three strategies using autotune. - if (batch_size == 1) { - // This is a regular matrix*matrix or matrix*vector multiply. Avoid the - // overhead of the scratch allocator and the batch interface. - // TODO(benbarsdell): Use fp16 Gemv if it becomes supported by CUBLAS - bool blas_launch_status = - stream - ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k, - static_cast(1.0), *(b_ptrs[0]), - adj_y || trans_y ? k : n, *(a_ptrs[0]), - adj_x || trans_x ? 
m : k, - static_cast(0.0), c_ptrs[0], n) - .ok(); - if (!blas_launch_status) { - context->SetStatus(errors::Internal( - "Blas xGEMM launch failed : a.shape=", in_x.shape().DebugString(), - ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n, - ", k=", k)); - } - } else if (use_strided_batched) { - bool blas_launch_status = - stream - ->ThenBlasGemmStridedBatched( - blas_transpose_b, blas_transpose_a, n, m, k, - static_cast(1.0), *b_ptrs[0], - adj_y || trans_y ? k : n, b_stride, *a_ptrs[0], - adj_x || trans_x ? m : k, a_stride, - static_cast(0.0), c_ptrs[0], n, c_stride, - batch_size) - .ok(); - if (!blas_launch_status) { - context->SetStatus(errors::Internal( - "Blas xGEMMStridedBatched launch failed : a.shape=", - in_x.shape().DebugString(), - ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n, - ", k=", k, ", batch_size=", batch_size)); - } - } else { - BlasScratchAllocator scratch_allocator(context); + BlasScratchAllocator scratch_allocator(max_scratch_size, context); bool blas_launch_status = stream ->ThenBlasGemmBatchedWithScratch( @@ -624,6 +839,7 @@ struct LaunchBatchMatMul { ", k=", k, ", batch_size=", batch_size)); } } +#endif // not GOOGLE_CUDA or CUDA_VERSION < 11000 } }; @@ -637,6 +853,7 @@ class BaseBatchMatMulOp : public OpKernel { : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_)); OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_)); + use_autotune_ = MatmulAutotuneEnable(); } ~BaseBatchMatMulOp() override {} @@ -698,7 +915,7 @@ class BaseBatchMatMulOp : public OpKernel { out->shape().DebugString())); LaunchBatchMatMul::Launch( ctx, in0_reshaped, in1_reshaped, adj_x_, adj_y_, /*trans_x=*/false, - /*trans_y=*/false, bcast, &out_reshaped); + /*trans_y=*/false, bcast, use_autotune_, &out_reshaped); } protected: @@ -708,6 +925,7 @@ class BaseBatchMatMulOp : public OpKernel { private: bool adj_x_; bool adj_y_; + bool use_autotune_; }; // BatchMatMul Op implementation which disallows broadcasting. diff --git a/tensorflow/core/kernels/gpu_utils.cc b/tensorflow/core/kernels/gpu_utils.cc index 7da1963c676..171a26e5b78 100644 --- a/tensorflow/core/kernels/gpu_utils.cc +++ b/tensorflow/core/kernels/gpu_utils.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "google/protobuf/any.pb.h" #include "absl/algorithm/container.h" #include "absl/base/call_once.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/logger.h" #include "tensorflow/core/protobuf/autotuning.pb.h" #include "tensorflow/core/protobuf/conv_autotuning.pb.h" @@ -282,6 +283,62 @@ Status BestCudnnConvAlgorithm(absl::Span results, return Status::OK(); } +int64 GetWorkspaceLimit(const string& envvar_in_mb, + int64 default_value_in_bytes) { + const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str()); + if (workspace_limit_in_mb_str != nullptr && + strcmp(workspace_limit_in_mb_str, "") != 0) { + int64 scratch_limit_in_mb = -1; + if (strings::safe_strto64(workspace_limit_in_mb_str, + &scratch_limit_in_mb)) { + return scratch_limit_in_mb * (1 << 20); + } else { + LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": " + << workspace_limit_in_mb_str; + } + } + return default_value_in_bytes; +} + +GpuScratchAllocator::GpuScratchAllocator(int64 memory_limit, + OpKernelContext* context) + : memory_limit_(memory_limit), total_byte_size_(0), context_(context) {} + +se::port::StatusOr> GpuScratchAllocator::AllocateBytes( + int64 byte_size) { + Tensor temporary_memory; + if (byte_size < 0) { + return se::port::Status{se::port::error::INVALID_ARGUMENT, + "Requested negative byte size!"}; + } + if (byte_size > memory_limit_) { + return se::port::Status{ + se::port::error::UNAVAILABLE, + absl::StrCat("Requested memory size (", byte_size, + ") exceeds the max memory limit (", memory_limit_, ").")}; + } + AllocationAttributes allocation_attr; + allocation_attr.retry_on_failure = false; + Status allocation_status(context_->allocate_temp( + DT_UINT8, TensorShape({byte_size}), &temporary_memory, + AllocatorAttributes(), allocation_attr)); + if (!allocation_status.ok()) { + return se::port::Status{ + se::port::error::UNAVAILABLE, + absl::StrCat("Failed to allocate the requested memory size (", + byte_size, ").")}; + } + // Hold the reference of the allocated tensors until the end of the + // allocator. + // NOTE: We expect tensors to be deallocated when this allocator goes out of + // scope when allocated_tensors is destructed. + allocated_tensors_.push_back(temporary_memory); + total_byte_size_ += byte_size; + return se::port::StatusOr>( + AsDeviceMemory(temporary_memory.flat().data(), + temporary_memory.flat().size())); +} + } // namespace tensorflow #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/gpu_utils.h b/tensorflow/core/kernels/gpu_utils.h index a1589db3b5b..f97aa182fbd 100644 --- a/tensorflow/core/kernels/gpu_utils.h +++ b/tensorflow/core/kernels/gpu_utils.h @@ -243,6 +243,42 @@ void LogFusedConvForwardAutotuneResults( Status BestCudnnConvAlgorithm(absl::Span results, se::dnn::AlgorithmConfig* algo); +// Get a workspace limit from the environment variable, which is in MB. +// Return the workspace memory limit in bytes. If no value is set, return the +// default value. +int64 GetWorkspaceLimit(const string& envvar_in_mb, + int64 default_value_in_bytes); + +// Get the Dnn workspace limit from the environment variable, which is in MB. +// Return the workspace memory limit in bytes. If no value is set, return the +// default value. +int64 GetDnnWorkspaceLimit(const string& envvar_in_mb, + int64 default_value_in_bytes); + +// A class to provide scratch-space allocator for Stream-Executor callbacks in +// CUDA libraries (CUDNN etc.). 
+// TensorFlow is responsible for releasing the temporary buffers after +// the kernel finishes. +class GpuScratchAllocator : public se::ScratchAllocator { + public: + virtual ~GpuScratchAllocator() {} + + GpuScratchAllocator(int64 memory_limit, OpKernelContext* context); + + int64 GetMemoryLimitInBytes() override { return memory_limit_; } + + se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes( + int64 byte_size) override; + + int64 TotalByteSize() { return total_byte_size_; } + + private: + int64 memory_limit_; + int64 total_byte_size_; + OpKernelContext* context_; + std::vector<Tensor> allocated_tensors_; +}; + } // namespace tensorflow #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/linalg/einsum_op_impl.h b/tensorflow/core/kernels/linalg/einsum_op_impl.h index b9b2d1f0eae..e10322a88e1 100644 --- a/tensorflow/core/kernels/linalg/einsum_op_impl.h +++ b/tensorflow/core/kernels/linalg/einsum_op_impl.h @@ -549,6 +549,7 @@ struct EinsumHelper { static Status ContractOperands(OpKernelContext* ctx, absl::Span<const Tensor> inputs, absl::Span<const bool> swap_free_and_contract, + bool use_autotune, Tensor* output) { if (inputs.size() == 1) return CopyFrom(inputs[0], inputs[0].shape(), output); @@ -583,7 +584,7 @@ struct EinsumHelper { ReshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped)); LaunchBatchMatMul<Device, T>::Launch(ctx, lhs, rhs, /*adj_x=*/false, /*adj_y=*/false, trans_x, trans_y, - bcast, &output_reshaped); + bcast, use_autotune, &output_reshaped); return Status::OK(); } }; @@ -598,6 +599,7 @@ class EinsumOp : public OpKernel { equation_, &input_labels_, &output_labels_, &label_types_, &input_label_counts_, &output_label_counts_, &input_has_ellipsis_, &output_has_ellipsis_)); + use_autotune_ = MatmulAutotuneEnable(); } void Compute(OpKernelContext* ctx) override { @@ -640,7 +642,7 @@ class EinsumOp : public OpKernel { Tensor contraction_output_reshaped; OP_REQUIRES_OK(ctx, EinsumHelper::ContractOperands( ctx, inputs_reduced, swap_free_and_contract, - &contraction_output_reshaped)); + use_autotune_, &contraction_output_reshaped)); // Copy the batch labels from the contraction output. Recover the batch // shape, which may have been broadcasted. @@ -738,6 +740,7 @@ class EinsumOp : public OpKernel { LabelCounts output_label_counts_; gtl::InlinedVector<bool, 2> input_has_ellipsis_; bool output_has_ellipsis_ = false; + bool use_autotune_; }; #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/util/matmul_autotune.cc b/tensorflow/core/util/matmul_autotune.cc index 741a78a193f..c30a5d930e7 100644 --- a/tensorflow/core/util/matmul_autotune.cc +++ b/tensorflow/core/util/matmul_autotune.cc @@ -48,4 +48,22 @@ bool MatmulDoFP32ComputationFP16Input() { return value; } +int MatmulMaxAutotuneAlgorithmCount() { + int64 value; + // In CUDA 11, cublasLtMatmulAlgoGetHeuristic typically returns <= 4 + // algorithms for a given configuration, so 10 seems like a reasonable default + // here. 
+ Status status = + ReadInt64FromEnvVar("TF_MATMUL_AUTOTUNE_MAX_ALGORITHMS", 10, &value); + if (!status.ok()) { + LOG(ERROR) << status.error_message(); + } + static constexpr const int kMaxValue = std::numeric_limits<int>::max(); + if (value < 1 || value > kMaxValue) { + LOG(ERROR) << "Invalid value for TF_MATMUL_AUTOTUNE_MAX_ALGORITHMS: " + << value << " is not in range [1, " << kMaxValue << "]"; + } + return value; +} + } // namespace tensorflow diff --git a/tensorflow/core/util/matmul_autotune.h b/tensorflow/core/util/matmul_autotune.h index 5846cae2fc7..c77d274e781 100644 --- a/tensorflow/core/util/matmul_autotune.h +++ b/tensorflow/core/util/matmul_autotune.h @@ -22,6 +22,7 @@ namespace tensorflow { bool MatmulAutotuneEnable(); bool MatmulDoFP32ComputationFP16Input(); +int MatmulMaxAutotuneAlgorithmCount(); } // namespace tensorflow
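Note on the autotuning flow in the GPU path above: the kernel hashes the matmul configuration into BatchMatmulParameters, looks up a cached plan and algorithm list, profiles each candidate algorithm once, and stores the winning index via AutoTuneBatchMatmul so later calls with the same shape skip profiling. The following is a minimal, self-contained sketch of that cache pattern in plain C++. It is illustrative only: MatmulKey, KeyHasher, AlgoCache, and the fake profiler are hypothetical names, not the TensorFlow or cuBLASLt APIs used in this patch.

// Illustrative sketch only (not TensorFlow code): cache the fastest algorithm
// index per unique matmul configuration, mirroring the Find/Insert pattern of
// AutoTuneBatchMatmul above. All names here are hypothetical.
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <limits>
#include <mutex>
#include <unordered_map>

struct MatmulKey {
  bool trans_a, trans_b;
  std::uint64_t m, n, k, batch;
  bool operator==(const MatmulKey& o) const {
    return trans_a == o.trans_a && trans_b == o.trans_b && m == o.m &&
           n == o.n && k == o.k && batch == o.batch;
  }
};

struct KeyHasher {
  std::size_t operator()(const MatmulKey& key) const {
    // Simple combine, analogous to the Hash64Combine chain in
    // BatchMatmulParameters.
    std::size_t h = std::hash<bool>()(key.trans_a);
    auto mix = [&h](std::size_t v) { h = h * 31 + v; };
    mix(std::hash<bool>()(key.trans_b));
    mix(std::hash<std::uint64_t>()(key.m));
    mix(std::hash<std::uint64_t>()(key.n));
    mix(std::hash<std::uint64_t>()(key.k));
    mix(std::hash<std::uint64_t>()(key.batch));
    return h;
  }
};

class AlgoCache {
 public:
  // profile(i) returns the measured time in ms, or a negative value if the
  // launch failed (mirrors the is_valid() check in the patch).
  using ProfileFn = std::function<double(int)>;

  int PickAlgorithm(const MatmulKey& key, int num_algorithms,
                    const ProfileFn& profile) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      auto it = best_.find(key);
      if (it != best_.end()) return it->second;  // already autotuned
    }
    int best_index = 0;  // fall back to algorithm 0 if nothing profiles faster
    double best_ms = std::numeric_limits<double>::infinity();
    for (int i = 0; i < num_algorithms; ++i) {
      double ms = profile(i);
      if (ms >= 0.0 && ms < best_ms) {
        best_ms = ms;
        best_index = i;
      }
    }
    std::lock_guard<std::mutex> lock(mu_);
    return best_.emplace(key, best_index).first->second;
  }

 private:
  std::mutex mu_;
  std::unordered_map<MatmulKey, int, KeyHasher> best_;
};

int main() {
  AlgoCache cache;
  MatmulKey key{false, false, 128, 128, 64, 8};
  auto profile = [](int i) { return i == 2 ? 0.5 : 1.0 + i; };  // fake timings
  std::cout << "first call profiles and picks "
            << cache.PickAlgorithm(key, 4, profile) << "\n";
  std::cout << "second call reuses cached index "
            << cache.PickAlgorithm(key, 4, profile) << "\n";
  return 0;
}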
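Likewise, GetWorkspaceLimit reads an environment variable expressed in MB and converts it to a byte limit, falling back to a default (the batched-matmul path above passes TF_CUBLAS_WORKSPACE_LIMIT_IN_MB with a 4 GiB default). Below is a rough standalone approximation of that behaviour; parse_workspace_limit_bytes is a hypothetical helper name, not the function added in gpu_utils.cc.

// Illustrative sketch only (not the TensorFlow implementation): read an
// environment variable given in MB and convert it to a byte limit, with a
// default fallback on unset or invalid values.
#include <cstdint>
#include <cstdlib>
#include <iostream>

std::int64_t parse_workspace_limit_bytes(const char* envvar_in_mb,
                                         std::int64_t default_bytes) {
  const char* value = std::getenv(envvar_in_mb);
  if (value == nullptr || *value == '\0') return default_bytes;
  char* end = nullptr;
  long long mb = std::strtoll(value, &end, 10);
  if (end == value || *end != '\0' || mb < 0) {
    std::cerr << "Invalid value for " << envvar_in_mb << ": " << value << "\n";
    return default_bytes;
  }
  return static_cast<std::int64_t>(mb) << 20;  // MB -> bytes
}

int main() {
  // With the variable unset this prints 4294967296, matching the 1LL << 32
  // (4 GiB) default passed to GetBlasWorkspaceLimit in the patch.
  std::cout << parse_workspace_limit_bytes("TF_CUBLAS_WORKSPACE_LIMIT_IN_MB",
                                           1LL << 32)
            << "\n";
  return 0;
}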