Use BlasLtMatmul APIs in batch_matmul_op_impl
- Integrates BlasLtMatmul with autotuning into the implementation of the BatchMatMul and Einsum ops.
- This integration is only used when the CUDA version is >= 11.0.
This commit is contained in: parent aaea82e6bc, commit 0d172940c1
@@ -3334,6 +3334,9 @@ tf_kernel_library(
     prefix = "batch_matmul_op",
     deps = MATH_DEPS + [":eigen_contraction_kernel"] + if_mkl_ml([
         "//third_party/mkl:intel_binary_blob",
+    ]) + if_cuda([
+        "//tensorflow/core/kernels:gpu_utils",
+        "//tensorflow/core/platform:tensor_float_32_utils",
     ]),
 )
 
@@ -3392,6 +3395,7 @@ tf_kernel_library(
     prefix = "fft_ops",
     deps = MATH_DEPS + [
     ] + if_cuda([
+        "//tensorflow/core/kernels:gpu_utils",
         "//tensorflow/core/platform/default/build_config:cufft_plugin",
     ]),
 )
@@ -22,7 +22,6 @@ limitations under the License.
 
 #include <vector>
 
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -34,17 +33,24 @@ limitations under the License.
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/tensor_float_32_utils.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/matmul_autotune.h"
 #include "tensorflow/core/util/matmul_bcast.h"
 #include "tensorflow/core/util/work_sharder.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
 #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
 #endif
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#include "tensorflow/core/kernels/gpu_utils.h"
 #include "tensorflow/core/platform/stream_executor.h"
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if GOOGLE_CUDA
+#include "third_party/gpus/cuda/include/cuda.h"  // For CUDA_VERSION
+#endif
 
 namespace tensorflow {
 
@@ -219,7 +225,8 @@ template <typename Scalar>
 struct LaunchBatchMatMul<CPUDevice, Scalar> {
   static void Launch(OpKernelContext* context, const Tensor& in_x,
                      const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
-                     bool trans_y, const MatMulBCast& bcast, Tensor* out) {
+                     bool trans_y, const MatMulBCast& bcast, bool use_autotune,
+                     Tensor* out) {
     typedef ParallelMatMulKernel<Scalar, Eigen::NumTraits<Scalar>::IsComplex>
         ParallelMatMulKernel;
     bool conjugate_result = false;
@@ -275,45 +282,201 @@ se::DeviceMemory<T> AsDeviceMemory(const T* gpu_memory) {
   return typed;
 }
 
-class BlasScratchAllocator : public se::ScratchAllocator {
- public:
-  using Stream = se::Stream;
-  using DeviceMemoryBytes = se::DeviceMemory<uint8>;
-
-  BlasScratchAllocator(OpKernelContext* context) : context_(context) {}
-
-  int64 GetMemoryLimitInBytes() override { return -1; }
-
-  se::port::StatusOr<DeviceMemoryBytes> AllocateBytes(
-      int64 byte_size) override {
-    Tensor temporary_memory;
-
-    Status allocation_status(context_->allocate_temp(
-        DT_UINT8, TensorShape({byte_size}), &temporary_memory));
-    if (!allocation_status.ok()) {
-      return se::port::StatusOr<DeviceMemoryBytes>(
-          DeviceMemoryBytes::MakeFromByteSize(nullptr, 0));
-    }
-    // Hold the reference of the allocated tensors until the end of the
-    // allocator.
-    allocated_tensors_.push_back(temporary_memory);
-    return se::port::StatusOr<DeviceMemoryBytes>(
-        DeviceMemoryBytes::MakeFromByteSize(
-            temporary_memory.flat<uint8>().data(),
-            temporary_memory.flat<uint8>().size()));
-  }
-
- private:
-  OpKernelContext* context_;
-  std::vector<Tensor> allocated_tensors_;
-};
+using BlasScratchAllocator = GpuScratchAllocator;
+
+int64 GetBlasWorkspaceLimit(const string& envvar_in_mb,
+                            int64 default_value_in_bytes) {
+  return GetWorkspaceLimit(envvar_in_mb, default_value_in_bytes);
+}
+
+// Encapsulate all of the shape, dtype etc. information that defines a unique
+// batched matmul operation.
+class BatchMatmulParameters {
+ public:
+  BatchMatmulParameters(bool trans_a, bool trans_b, bool adj_a, bool adj_b,
+                        uint64 m, uint64 n, uint64 k, uint64 batch_count,
+                        bool broadcast_a, bool broadcast_b, DataType dtype_ab,
+                        DataType dtype_cd, bool allow_tf32, int device_id)
+      : trans_a_(trans_a),
+        trans_b_(trans_b),
+        adj_a_(adj_a),
+        adj_b_(adj_b),
+        m_(m),
+        n_(n),
+        k_(k),
+        batch_count_(batch_count),
+        broadcast_a_(broadcast_a),
+        broadcast_b_(broadcast_b),
+        dtype_ab_(dtype_ab),
+        dtype_cd_(dtype_cd),
+        allow_tf32_(allow_tf32),
+        device_id_(device_id) {
+    hash_code_ = trans_a;
+    hash_code_ = Hash64Combine(hash_code_, trans_b);
+    hash_code_ = Hash64Combine(hash_code_, adj_a);
+    hash_code_ = Hash64Combine(hash_code_, adj_b);
+    hash_code_ = Hash64Combine(hash_code_, m);
+    hash_code_ = Hash64Combine(hash_code_, n);
+    hash_code_ = Hash64Combine(hash_code_, k);
+    hash_code_ = Hash64Combine(hash_code_, batch_count);
+    hash_code_ = Hash64Combine(hash_code_, broadcast_a);
+    hash_code_ = Hash64Combine(hash_code_, broadcast_b);
+    hash_code_ = Hash64Combine(hash_code_, dtype_ab);
+    hash_code_ = Hash64Combine(hash_code_, dtype_cd);
+    hash_code_ = Hash64Combine(hash_code_, allow_tf32);
+    hash_code_ = Hash64Combine(hash_code_, device_id);
+  }
+  bool operator==(const BatchMatmulParameters& other) const {
+    return this->get_data_as_tuple() == other.get_data_as_tuple();
+  }
+
+  bool operator!=(const BatchMatmulParameters& other) const {
+    return !(*this == other);
+  }
+  uint64 hash() const { return hash_code_; }
+
+  string ToString() const {
+    // clang-format off
+    return strings::StrCat(
+        trans_a_, ", ", trans_b_, ", ", adj_a_, ", ", adj_b_, ", ",
+        m_, ", ", n_, ", ", k_, ", ", batch_count_, ", ",
+        broadcast_a_, ", ", broadcast_b_, ", ",
+        dtype_ab_, ", ", dtype_cd_, ", ", allow_tf32_, ", ", device_id_);
+    // clang-format on
+  }
+
+ private:
+  typedef std::tuple<bool, bool, bool, bool, int64, int64, int64, int64, bool,
+                     bool, DataType, DataType, bool, int>
+      ParameterDataType;
+
+  ParameterDataType get_data_as_tuple() const {
+    return std::make_tuple(trans_a_, trans_b_, adj_a_, adj_b_, m_, n_, k_,
+                           batch_count_, broadcast_a_, broadcast_b_, dtype_ab_,
+                           dtype_cd_, allow_tf32_, device_id_);
+  }
+
+  bool trans_a_;
+  bool trans_b_;
+  bool adj_a_;
+  bool adj_b_;
+  uint64 m_;
+  uint64 n_;
+  uint64 k_;
+  uint64 batch_count_;
+  bool broadcast_a_;
+  bool broadcast_b_;
+  DataType dtype_ab_;
+  DataType dtype_cd_;
+  bool allow_tf32_;
+  int device_id_;
+  uint64 hash_code_;
+};
+
+bool GetBlasComputationType(const DataType& dtype, bool allow_tf32,
+                            se::blas::ComputationType* compute_type) {
+  using se::blas::ComputationType;
+  static bool use_f32_for_f16_computation = MatmulDoFP32ComputationFP16Input();
+  ComputationType f32_type =
+      allow_tf32 ? ComputationType::kF32FastTF32 : ComputationType::kF32;
+  switch (dtype) {
+    case DT_HALF:
+    case DT_BFLOAT16:
+      *compute_type =
+          use_f32_for_f16_computation ? f32_type : ComputationType::kF16;
+      return true;
+    case DT_FLOAT:
+      *compute_type = f32_type;
+      return true;
+    case DT_DOUBLE:
+      *compute_type = ComputationType::kF64;
+      return true;
+    case DT_COMPLEX64:
+      *compute_type = f32_type;
+      return true;
+    case DT_COMPLEX128:
+      *compute_type = ComputationType::kComplexF64;
+      return true;
+    default:
+      // Unsupported compute_type, return false.
+      return false;
+  }
+}
+
+// Thread-safe map from matmul parameters to their corresponding plan and
+// algorithms.
+template <typename Parameters>
+class BlasLtMatmulPlanMap {
+ public:
+  struct PlanAndAlgorithms {
+    std::unique_ptr<se::blas::IBlasLtMatmulPlan> plan;
+    std::vector<std::unique_ptr<se::blas::IBlasLtMatmulAlgorithm>> algorithms;
+  };
+
+  const PlanAndAlgorithms* Find(const Parameters& params) {
+    mutex_lock lock(mu_);
+    auto iter = params_plan_map_.find(params);
+    if (iter == params_plan_map_.end()) {
+      return nullptr;
+    }
+    return &iter->second;
+  }
+  const PlanAndAlgorithms* Insert(const Parameters& params,
+                                  PlanAndAlgorithms value) {
+    mutex_lock lock(mu_);
+    return &params_plan_map_.emplace(params, std::move(value)).first->second;
+  }
+
+ private:
+  struct Hasher {
+    std::size_t operator()(const Parameters& parameter) const {
+      return parameter.hash();
+    }
+  };
+
+  mutable mutex mu_;
+  std::unordered_map<Parameters, PlanAndAlgorithms, Hasher> params_plan_map_
+      GUARDED_BY(mu_);
+};
+
+template <typename Parameters>
+struct BlasLtPlanMapSingleton {
+  typedef BlasLtMatmulPlanMap<Parameters> PlanMapType;
+  static PlanMapType* GetInstance() {
+    static PlanMapType* instance = new PlanMapType();
+    return instance;
+  }
+};
+
+typedef BlasLtPlanMapSingleton<BatchMatmulParameters>
+    BatchMatmulPlanMapSingleton;
+
+// A dummy type to group matmul autotune results together.
+struct BatchMatmulAutoTuneGroup {
+  static string name() { return "MatmulLt"; }
+};
+
+typedef AutoTuneSingleton<BatchMatmulAutoTuneGroup, BatchMatmulParameters,
+                          se::blas::AlgorithmConfig>
+    AutoTuneBatchMatmul;
+
+template <typename Scalar>
+struct CoefficientType {
+  typedef Scalar type;
+};
+template <>
+struct CoefficientType<Eigen::half> {
+  typedef float type;
+};
 
 }  // namespace
 
 template <typename Scalar>
 struct LaunchBatchMatMul<GPUDevice, Scalar> {
   static void Launch(OpKernelContext* context, const Tensor& in_x,
                      const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
-                     bool trans_y, const MatMulBCast& bcast, Tensor* out) {
+                     bool trans_y, const MatMulBCast& bcast, bool use_autotune,
+                     Tensor* out) {
     se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
                                    se::blas::Transpose::kTranspose,
                                    se::blas::Transpose::kConjugateTranspose};
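Note: the BatchMatmulParameters type introduced above follows a common cache-key pattern, i.e. equality through a tuple of every field plus a pre-combined hash, so the parameters can index an unordered_map of cuBLASLt plans and an autotune-result map. The following is a stripped-down, self-contained sketch of that same pattern, not TensorFlow code; HashCombine and the int payload are stand-ins for Hash64Combine and the cached plan.

#include <cstdint>
#include <tuple>
#include <unordered_map>

// Stand-in for tensorflow::Hash64Combine.
inline uint64_t HashCombine(uint64_t seed, uint64_t value) {
  return seed ^ (value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2));
}

struct MatmulKey {
  bool trans_a, trans_b;
  uint64_t m, n, k, batch_count;

  // Pre-combined hash over all fields, mirroring BatchMatmulParameters::hash().
  uint64_t hash() const {
    uint64_t h = trans_a;
    h = HashCombine(h, trans_b);
    h = HashCombine(h, m);
    h = HashCombine(h, n);
    h = HashCombine(h, k);
    return HashCombine(h, batch_count);
  }
  // Equality via a tuple of the fields, mirroring get_data_as_tuple().
  bool operator==(const MatmulKey& o) const {
    return std::tie(trans_a, trans_b, m, n, k, batch_count) ==
           std::tie(o.trans_a, o.trans_b, o.m, o.n, o.k, o.batch_count);
  }
};

struct MatmulKeyHasher {
  std::size_t operator()(const MatmulKey& key) const { return key.hash(); }
};

// The plan cache in the commit is essentially this map, guarded by a mutex.
using PlanCache = std::unordered_map<MatmulKey, int, MatmulKeyHasher>;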
@@ -347,6 +510,198 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
     uint64 b_stride;
     uint64 c_stride;
 
+    typedef typename CoefficientType<Scalar>::type Coefficient;
+
+    static const int64 max_scratch_size = GetBlasWorkspaceLimit(
+        "TF_CUBLAS_WORKSPACE_LIMIT_IN_MB", 1LL << 32);  // 4GB by default
+
+    // The BlasLtMatmul routines are only supported from CUDA 11.0 onward.
+#if GOOGLE_CUDA && CUDA_VERSION >= 11000
+    bool is_full_broadcast =
+        std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
+    bool requires_mixed_broadcasting =
+        bcast.IsBroadcastingRequired() && !is_full_broadcast;
+    if (!requires_mixed_broadcasting) {
+      bool broadcast_a = bcast.x_batch_size() == 1;
+      bool broadcast_b = bcast.y_batch_size() == 1;
+      a_stride = broadcast_a ? 0 : m * k;
+      b_stride = broadcast_b ? 0 : k * n;
+      c_stride = m * n;
+      a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
+      b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
+      c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
+      a_ptrs.push_back(&a_device_memory.back());
+      b_ptrs.push_back(&b_device_memory.back());
+      c_ptrs.push_back(&c_device_memory.back());
+
+      DataType dtype = DataTypeToEnum<Scalar>::value;
+      bool allow_tf32 = tensor_float_32_execution_enabled();
+      int device_id = stream->parent()->device_ordinal();
+      BatchMatmulParameters matmul_parameters(
+          trans_x, trans_y, adj_x, adj_y, m, n, k, batch_size, broadcast_a,
+          broadcast_b, dtype, dtype, allow_tf32, device_id);
+
+      static const bool max_autotune_algorithm_count =
+          MatmulMaxAutotuneAlgorithmCount();
+      int max_algorithm_count = use_autotune ? max_autotune_algorithm_count : 1;
+
+      const auto* plan_and_algorithms =
+          BatchMatmulPlanMapSingleton::GetInstance()->Find(matmul_parameters);
+      if (!plan_and_algorithms) {
+        se::blas::DataType blas_dtype = se::blas::ToDataType<Scalar>::value;
+        se::blas::ComputationType computation_type;
+        OP_REQUIRES(
+            context,
+            GetBlasComputationType(dtype, allow_tf32, &computation_type),
+            errors::Internal("Unsupported dtype for batched matmul"));
+        std::unique_ptr<se::blas::IBlasLtMatmulPlan> plan =
+            stream->parent()->CreateBlasLtMatmulPlanStridedBatched(
+                /*ab_type=*/blas_dtype,
+                /*cd_type=*/blas_dtype, computation_type,
+                se::blas::PointerMode::kHost, blas_transpose_b,
+                blas_transpose_a, n, m, k, batch_size,
+                /*lda=*/in_y.dim_size(2), b_stride,
+                /*ldb=*/in_x.dim_size(2), a_stride, /*ldc=*/n, c_stride);
+        OP_REQUIRES(
+            context, plan,
+            errors::Internal(
+                "CreateBlasLtMatmulPlanStridedBatched failed : a.shape=(",
+                in_x.dim_size(0), ", ", in_x.dim_size(1), ", ",
+                in_x.dim_size(2), "), b.shape=(", in_y.dim_size(0), ", ",
+                in_y.dim_size(1), ", ", in_y.dim_size(2), "), m=", m, ", n=", n,
+                ", k=", k, ", batch_size=", batch_size, ", adjoint_a=", adj_x,
+                ", adjoint_b=", adj_x, ", dtype=", dtype,
+                ", computation_type=", computation_type));
+        std::vector<std::unique_ptr<se::blas::IBlasLtMatmulAlgorithm>>
+            algorithms;
+        OP_REQUIRES(
+            context,
+            stream->parent()->GetBlasLtMatmulAlgorithms(
+                plan.get(), max_scratch_size, max_algorithm_count, &algorithms),
+            errors::Internal("GetBlasLtMatmulAlgorithms failed: a.shape=(",
+                             in_x.dim_size(0), ", ", in_x.dim_size(1), ", ",
+                             in_x.dim_size(2), "), b.shape=(", in_y.dim_size(0),
+                             ", ", in_y.dim_size(1), ", ", in_y.dim_size(2),
+                             "), m=", m, ", n=", n, ", k=", k,
+                             ", batch_size=", batch_size, ", adjoint_a=", adj_x,
+                             ", adjoint_b=", adj_x, ", dtype=", dtype,
+                             ", computation_type=", computation_type));
+        plan_and_algorithms =
+            BatchMatmulPlanMapSingleton::GetInstance()->Insert(
+                matmul_parameters, {std::move(plan), std::move(algorithms)});
+      }
+      const auto& plan = plan_and_algorithms->plan;
+      const auto& algorithms = plan_and_algorithms->algorithms;
+
+      // The BlasLtMatmul routines (unlike BlasGemm, BlasGemmBatched etc.) take
+      // alpha and beta with the same type as the matrices.
+      Scalar alpha(1.0);
+      Scalar beta(0.0);
+
+      // Note that algorithm_config.algorithm() here is used to refer
+      // to the index within the algorithms vector, not the algorithm
+      // itself.
+      se::blas::AlgorithmConfig algorithm_config(se::blas::kNoAlgorithm);
+      if (max_algorithm_count == 1) {
+        algorithm_config.set_algorithm(0);
+      } else if (!AutoTuneBatchMatmul::GetInstance()->Find(matmul_parameters,
+                                                           &algorithm_config)) {
+        VLOG(4) << "Autotuning BlasLtMatmul over " << algorithms.size()
+                << " algorithms.";
+        se::blas::ProfileResult best_result;
+        se::blas::ProfileResult profile_result;
+        //for (const auto& profile_algorithm : plan_and_algorithms->algorithms) {
+        for (size_t i = 0; i != algorithms.size(); ++i) {
+          const auto& profile_algorithm = algorithms[i];
+          // Create a new scratch allocator with every autotuning run so that
+          // scratch space is deallocated between runs.
+          BlasScratchAllocator scratch_allocator(max_scratch_size, context);
+
+          bool cublas_launch_status =
+              stream
+                  ->ThenBlasLtMatmul(plan.get(), alpha, *b_ptrs[0], *a_ptrs[0],
+                                     beta, c_ptrs[0], &scratch_allocator,
+                                     profile_algorithm.get(), &profile_result)
+                  .ok();
+
+          VLOG(4) << " Autotune algorithm " << i
+                  << " result: " << profile_result.elapsed_time_in_ms()
+                  << " ms, valid=" << profile_result.is_valid()
+                  << ", workspace_size="
+                  << profile_algorithm->workspace_size();
+
+          if (cublas_launch_status && profile_result.is_valid() &&
+              profile_result.elapsed_time_in_ms() <
+                  best_result.elapsed_time_in_ms()) {
+            best_result = profile_result;
+          }
+        }
+
+        if (best_result.is_valid()) {
+          algorithm_config.set_algorithm(best_result.algorithm());
+        }
+        // We make sure that each matmul parameter set only gets one pass of
+        // autotune. If no algorithms works, we add kNoAlgorithm to the autotune
+        // map.
+        AutoTuneBatchMatmul::GetInstance()->Insert(matmul_parameters,
+                                                   algorithm_config);
+      }
+      se::blas::AlgorithmType algorithm_idx = algorithm_config.algorithm();
+      OP_REQUIRES(context,
+                  0 <= algorithm_idx && algorithm_idx < algorithms.size(),
+                  errors::Internal("Missing/invalid BatchMatmul algorithm"));
+      const auto& algorithm = algorithms[algorithm_idx];
+      BlasScratchAllocator scratch_allocator(max_scratch_size, context);
+      bool cublas_launch_status =
+          stream
+              ->ThenBlasLtMatmul(plan.get(), alpha, *b_ptrs[0], *a_ptrs[0],
+                                 beta, c_ptrs[0], &scratch_allocator,
+                                 algorithm.get())
+              .ok();
+      if (!cublas_launch_status) {
+        context->SetStatus(errors::Internal(
+            "Blas batched matmul launch failed : a.shape=(",
+            bcast.x_batch_size(), ", ", in_x.dim_size(0), ", ",
+            in_x.dim_size(1), "), b.shape=(", bcast.y_batch_size(), ", ",
+            in_y.dim_size(0), ", ", in_y.dim_size(1), "), m=", m, ", n=", n,
+            ", k=", k, ", batch_size=", batch_size));
+      }
+    } else {  // requires mixed broadcasting
+      const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
+      const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
+      for (int64 i = 0; i < bcast.x_batch_size(); ++i) {
+        a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
+      }
+      for (int64 i = 0; i < bcast.y_batch_size(); ++i) {
+        b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
+      }
+      for (int64 i = 0; i < batch_size; ++i) {
+        c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
+        a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
+        b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
+        c_ptrs.push_back(&c_device_memory.back());
+      }
+
+      BlasScratchAllocator scratch_allocator(max_scratch_size, context);
+      bool blas_launch_status =
+          stream
+              ->ThenBlasGemmBatchedWithScratch(
+                  blas_transpose_b, blas_transpose_a, n, m, k,
+                  static_cast<Coefficient>(1.0), b_ptrs,
+                  adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
+                  static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
+                  &scratch_allocator)
+              .ok();
+      if (!blas_launch_status) {
+        context->SetStatus(errors::Internal(
+            "Blas xGEMMBatched launch failed : a.shape=",
+            in_x.shape().DebugString(),
+            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
+            ", k=", k, ", batch_size=", batch_size));
+      }
+    }
+    return;
+#else  // if not GOOGLE_CUDA or CUDA_VERSION < 11000
     bool is_full_broadcast =
         std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
     bool use_strided_batched =
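Note: the autotuning branch above reduces to "run every candidate once with a fresh scratch allocator, keep the fastest valid result, and cache the winning index". A minimal self-contained sketch of that selection step is shown below, using plain C++ stand-ins rather than se::blas::ProfileResult.

#include <cstddef>
#include <vector>

struct ProfileOutcome {
  bool valid = false;     // stand-in for ProfileResult::is_valid()
  float elapsed_ms = 0.0f;  // stand-in for elapsed_time_in_ms()
};

// Returns the index of the fastest valid candidate, or -1 if none succeeded.
// results[i] is assumed to hold the timing of candidate algorithm i.
int PickBestAlgorithm(const std::vector<ProfileOutcome>& results) {
  int best = -1;
  float best_ms = 0.0f;
  for (size_t i = 0; i < results.size(); ++i) {
    if (!results[i].valid) continue;
    if (best == -1 || results[i].elapsed_ms < best_ms) {
      best = static_cast<int>(i);
      best_ms = results[i].elapsed_ms;
    }
  }
  return best;
}

In the kernel the winning index (not the algorithm object itself) is stored via AutoTuneBatchMatmul, so the timing loop runs at most once per BatchMatmulParameters key.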
@@ -388,8 +743,6 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
       }
     }
 
-    typedef Scalar Coefficient;
-
     // Blas does
     // C = A x B
     // where A, B and C are assumed to be in column major.
@@ -399,7 +752,10 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
     if (batch_size == 1) {
       // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
      // overhead of the scratch allocator and the batch interface.
-      if (n == 1 &&
+      // Note that the GEMV call here does not support Eigen::half, so we do not
+      // use this path in that case. A workaround is applied to the pointers
+      // passed to the call itself to avoid compilation errors.
+      if (!std::is_same<Scalar, Eigen::half>::value && n == 1 &&
           blas_transpose_b != se::blas::Transpose::kConjugateTranspose &&
           blas_transpose_a != se::blas::Transpose::kConjugateTranspose) {
         // This is a matrix*vector multiply so use GEMV to compute A * b.
@@ -410,13 +766,19 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
         auto gemv_trans_a = blas_transpose_a == se::blas::Transpose::kTranspose
                                 ? se::blas::Transpose::kNoTranspose
                                 : se::blas::Transpose::kTranspose;
+        // Cast pointers as a workaround for GEMV not supporting Eigen::half
+        // (this will never actually be executed for Eigen::half).
+        typedef se::DeviceMemory<Coefficient> NonHalfDeviceMemoryType;
+        NonHalfDeviceMemoryType a_ptr(*(a_ptrs[0]));
+        NonHalfDeviceMemoryType b_ptr(*(b_ptrs[0]));
+        NonHalfDeviceMemoryType c_ptr(*(c_ptrs[0]));
         bool blas_launch_status =
             stream
                 ->ThenBlasGemv(gemv_trans_a, adj_x || trans_x ? m : k,
                                adj_x || trans_x ? k : m,
-                               static_cast<Coefficient>(1.0), *(a_ptrs[0]),
-                               adj_x || trans_x ? m : k, *(b_ptrs[0]), 1,
-                               static_cast<Coefficient>(0.0), c_ptrs[0], 1)
+                               static_cast<Coefficient>(1.0), a_ptr,
+                               adj_x || trans_x ? m : k, b_ptr, 1,
+                               static_cast<Coefficient>(0.0), &c_ptr, 1)
                 .ok();
         if (!blas_launch_status) {
           context->SetStatus(errors::Internal(
@@ -459,154 +821,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
             ", k=", k, ", batch_size=", batch_size));
       }
     } else {
-      BlasScratchAllocator scratch_allocator(context);
-      bool blas_launch_status =
-          stream
-              ->ThenBlasGemmBatchedWithScratch(
-                  blas_transpose_b, blas_transpose_a, n, m, k,
-                  static_cast<Coefficient>(1.0), b_ptrs,
-                  adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
-                  static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
-                  &scratch_allocator)
-              .ok();
-      if (!blas_launch_status) {
-        context->SetStatus(errors::Internal(
-            "Blas xGEMMBatched launch failed : a.shape=",
-            in_x.shape().DebugString(),
-            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
-            ", k=", k, ", batch_size=", batch_size));
-      }
-    }
-  }
-};
-
-template <>
-struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
-  static void Launch(OpKernelContext* context, const Tensor& in_x,
-                     const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
-                     bool trans_y, const MatMulBCast& bcast, Tensor* out) {
-    typedef Eigen::half Scalar;
-    se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
-                                   se::blas::Transpose::kTranspose,
-                                   se::blas::Transpose::kConjugateTranspose};
-    const uint64 m = in_x.dim_size(adj_x || trans_x ? 2 : 1);
-    const uint64 k = in_x.dim_size(adj_x || trans_x ? 1 : 2);
-    const uint64 n = in_y.dim_size(adj_y || trans_y ? 1 : 2);
-    const uint64 batch_size = bcast.output_batch_size();
-    auto blas_transpose_a = trans[adj_x ? 2 : (trans_x ? 1 : 0)];
-    auto blas_transpose_b = trans[adj_y ? 2 : (trans_y ? 1 : 0)];
-
-    auto* stream = context->op_device_context()->stream();
-    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
-
-    typedef perftools::gputools::DeviceMemory<Scalar> DeviceMemoryType;
-    std::vector<DeviceMemoryType> a_device_memory;
-    std::vector<DeviceMemoryType> b_device_memory;
-    std::vector<DeviceMemoryType> c_device_memory;
-    std::vector<DeviceMemoryType*> a_ptrs;
-    std::vector<DeviceMemoryType*> b_ptrs;
-    std::vector<DeviceMemoryType*> c_ptrs;
-    a_device_memory.reserve(bcast.x_batch_size());
-    b_device_memory.reserve(bcast.y_batch_size());
-    c_device_memory.reserve(batch_size);
-    a_ptrs.reserve(batch_size);
-    b_ptrs.reserve(batch_size);
-    c_ptrs.reserve(batch_size);
-    auto* a_base_ptr = in_x.template flat<Scalar>().data();
-    auto* b_base_ptr = in_y.template flat<Scalar>().data();
-    auto* c_base_ptr = out->template flat<Scalar>().data();
-
-    uint64 a_stride;
-    uint64 b_stride;
-    uint64 c_stride;
-
-    bool is_full_broadcast =
-        std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
-    bool use_strided_batched =
-        (!bcast.IsBroadcastingRequired() || is_full_broadcast) &&
-        batch_size > 1;
-    if (use_strided_batched) {
-      a_stride = bcast.x_batch_size() != 1 ? m * k : 0;
-      b_stride = bcast.y_batch_size() != 1 ? k * n : 0;
-      c_stride = m * n;
-      a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
-      b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
-      c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
-      a_ptrs.push_back(&a_device_memory.back());
-      b_ptrs.push_back(&b_device_memory.back());
-      c_ptrs.push_back(&c_device_memory.back());
-    } else if (!bcast.IsBroadcastingRequired()) {
-      for (int64 i = 0; i < batch_size; ++i) {
-        a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
-        b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
-        c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
-        a_ptrs.push_back(&a_device_memory.back());
-        b_ptrs.push_back(&b_device_memory.back());
-        c_ptrs.push_back(&c_device_memory.back());
-      }
-    } else {
-      const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
-      const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
-      for (int64 i = 0; i < bcast.x_batch_size(); ++i) {
-        a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
-      }
-      for (int64 i = 0; i < bcast.y_batch_size(); ++i) {
-        b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
-      }
-      for (int64 i = 0; i < batch_size; ++i) {
-        c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
-        a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
-        b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
-        c_ptrs.push_back(&c_device_memory.back());
-      }
-    }
-
-    typedef float Coefficient;
-
-    // Blas does
-    // C = A x B
-    // where A, B and C are assumed to be in column major.
-    // We want the output to be in row-major, so we can compute
-    // C' = B' x A', where ' stands for transpose (not adjoint).
-    // TODO(yangzihao): Choose the best of the three strategies using autotune.
-    if (batch_size == 1) {
-      // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
-      // overhead of the scratch allocator and the batch interface.
-      // TODO(benbarsdell): Use fp16 Gemv if it becomes supported by CUBLAS
-      bool blas_launch_status =
-          stream
-              ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
-                             static_cast<Coefficient>(1.0), *(b_ptrs[0]),
-                             adj_y || trans_y ? k : n, *(a_ptrs[0]),
-                             adj_x || trans_x ? m : k,
-                             static_cast<Coefficient>(0.0), c_ptrs[0], n)
-              .ok();
-      if (!blas_launch_status) {
-        context->SetStatus(errors::Internal(
-            "Blas xGEMM launch failed : a.shape=", in_x.shape().DebugString(),
-            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
-            ", k=", k));
-      }
-    } else if (use_strided_batched) {
-      bool blas_launch_status =
-          stream
-              ->ThenBlasGemmStridedBatched(
-                  blas_transpose_b, blas_transpose_a, n, m, k,
-                  static_cast<Coefficient>(1.0), *b_ptrs[0],
-                  adj_y || trans_y ? k : n, b_stride, *a_ptrs[0],
-                  adj_x || trans_x ? m : k, a_stride,
-                  static_cast<Coefficient>(0.0), c_ptrs[0], n, c_stride,
-                  batch_size)
-              .ok();
-      if (!blas_launch_status) {
-        context->SetStatus(errors::Internal(
-            "Blas xGEMMStridedBatched launch failed : a.shape=",
-            in_x.shape().DebugString(),
-            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
-            ", k=", k, ", batch_size=", batch_size));
-      }
-    } else {
-      BlasScratchAllocator scratch_allocator(context);
+      BlasScratchAllocator scratch_allocator(max_scratch_size, context);
       bool blas_launch_status =
           stream
               ->ThenBlasGemmBatchedWithScratch(
@@ -624,6 +839,7 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
             ", k=", k, ", batch_size=", batch_size));
       }
     }
+#endif  // not GOOGLE_CUDA or CUDA_VERSION < 11000
   }
 };
 
@@ -637,6 +853,7 @@ class BaseBatchMatMulOp : public OpKernel {
       : OpKernel(context) {
     OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
     OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
+    use_autotune_ = MatmulAutotuneEnable();
   }
 
   ~BaseBatchMatMulOp() override {}
@@ -698,7 +915,7 @@ class BaseBatchMatMulOp : public OpKernel {
                                 out->shape().DebugString()));
     LaunchBatchMatMul<Device, Scalar>::Launch(
         ctx, in0_reshaped, in1_reshaped, adj_x_, adj_y_, /*trans_x=*/false,
-        /*trans_y=*/false, bcast, &out_reshaped);
+        /*trans_y=*/false, bcast, use_autotune_, &out_reshaped);
   }
 
  protected:
@@ -708,6 +925,7 @@ class BaseBatchMatMulOp : public OpKernel {
  private:
   bool adj_x_;
   bool adj_y_;
+  bool use_autotune_;
 };
 
 // BatchMatMul Op implementation which disallows broadcasting.
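Note: taken together, the BaseBatchMatMulOp hunks wire the autotune switch from op construction through to the kernel launcher; the relevant lines, reassembled from the hunks above, are:

// In the BaseBatchMatMulOp constructor:
use_autotune_ = MatmulAutotuneEnable();
// In Compute(), forwarded to the launcher:
LaunchBatchMatMul<Device, Scalar>::Launch(
    ctx, in0_reshaped, in1_reshaped, adj_x_, adj_y_, /*trans_x=*/false,
    /*trans_y=*/false, bcast, use_autotune_, &out_reshaped);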
@@ -22,6 +22,7 @@ limitations under the License.
 #include "google/protobuf/any.pb.h"
 #include "absl/algorithm/container.h"
 #include "absl/base/call_once.h"
+#include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/platform/logger.h"
 #include "tensorflow/core/protobuf/autotuning.pb.h"
 #include "tensorflow/core/protobuf/conv_autotuning.pb.h"
@@ -282,6 +283,62 @@ Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
   return Status::OK();
 }
 
+int64 GetWorkspaceLimit(const string& envvar_in_mb,
+                        int64 default_value_in_bytes) {
+  const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
+  if (workspace_limit_in_mb_str != nullptr &&
+      strcmp(workspace_limit_in_mb_str, "") != 0) {
+    int64 scratch_limit_in_mb = -1;
+    if (strings::safe_strto64(workspace_limit_in_mb_str,
+                              &scratch_limit_in_mb)) {
+      return scratch_limit_in_mb * (1 << 20);
+    } else {
+      LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
+                   << workspace_limit_in_mb_str;
+    }
+  }
+  return default_value_in_bytes;
+}
+
+GpuScratchAllocator::GpuScratchAllocator(int64 memory_limit,
+                                         OpKernelContext* context)
+    : memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}
+
+se::port::StatusOr<se::DeviceMemory<uint8>> GpuScratchAllocator::AllocateBytes(
+    int64 byte_size) {
+  Tensor temporary_memory;
+  if (byte_size < 0) {
+    return se::port::Status{se::port::error::INVALID_ARGUMENT,
+                            "Requested negative byte size!"};
+  }
+  if (byte_size > memory_limit_) {
+    return se::port::Status{
+        se::port::error::UNAVAILABLE,
+        absl::StrCat("Requested memory size (", byte_size,
+                     ") exceeds the max memory limit (", memory_limit_, ").")};
+  }
+  AllocationAttributes allocation_attr;
+  allocation_attr.retry_on_failure = false;
+  Status allocation_status(context_->allocate_temp(
+      DT_UINT8, TensorShape({byte_size}), &temporary_memory,
+      AllocatorAttributes(), allocation_attr));
+  if (!allocation_status.ok()) {
+    return se::port::Status{
+        se::port::error::UNAVAILABLE,
+        absl::StrCat("Failed to allocate the requested memory size (",
+                     byte_size, ").")};
+  }
+  // Hold the reference of the allocated tensors until the end of the
+  // allocator.
+  // NOTE: We expect tensors to be deallocated when this allocator goes out of
+  // scope when allocated_tensors is destructed.
+  allocated_tensors_.push_back(temporary_memory);
+  total_byte_size_ += byte_size;
+  return se::port::StatusOr<se::DeviceMemory<uint8>>(
+      AsDeviceMemory(temporary_memory.flat<uint8>().data(),
+                     temporary_memory.flat<uint8>().size()));
+}
+
 }  // namespace tensorflow
 
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
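Note: the environment variable consulted by GetWorkspaceLimit is interpreted in megabytes and converted to bytes with * (1 << 20), falling back to the default when unset or unparsable. A standalone sketch of that conversion (plain C++; strtoll stands in for strings::safe_strto64):

#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <string>

// Standalone sketch of the MB -> bytes conversion used by GetWorkspaceLimit.
int64_t WorkspaceLimitFromEnv(const std::string& envvar_in_mb,
                              int64_t default_value_in_bytes) {
  const char* str = std::getenv(envvar_in_mb.c_str());
  if (str != nullptr && std::strcmp(str, "") != 0) {
    char* end = nullptr;
    long long mb = std::strtoll(str, &end, 10);
    if (end != str && *end == '\0') {
      return static_cast<int64_t>(mb) * (1 << 20);  // e.g. "256" -> 268435456
    }
  }
  return default_value_in_bytes;
}

With TF_CUBLAS_WORKSPACE_LIMIT_IN_MB unset, the batch matmul kernel above therefore uses the 1LL << 32 (4 GB) default it passes in.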
@@ -243,6 +243,42 @@ void LogFusedConvForwardAutotuneResults(
 Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
                               se::dnn::AlgorithmConfig* algo);
 
+// Get a workspace limit from the environment variable, which is in MB.
+// Return the workspace memory limit in bytes. If no value is set, return the
+// default value.
+int64 GetWorkspaceLimit(const string& envvar_in_mb,
+                        int64 default_value_in_bytes);
+
+// Get the Dnn workspace limit from the environment variable, which is in MB.
+// Return the workspace memory limit in bytes. If no value is set, return the
+// default value.
+int64 GetDnnWorkspaceLimit(const string& envvar_in_mb,
+                           int64 default_value_in_bytes);
+
+// A class to provide scratch-space allocator for Stream-Executor callbacks in
+// CUDA libraries (CUDNN etc.).
+// TensorFlow is responsible for releasing the temporary buffers after
+// the kernel finishes.
+class GpuScratchAllocator : public se::ScratchAllocator {
+ public:
+  virtual ~GpuScratchAllocator() {}
+
+  GpuScratchAllocator(int64 memory_limit, OpKernelContext* context);
+
+  int64 GetMemoryLimitInBytes() override { return memory_limit_; }
+
+  se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes(
+      int64 byte_size) override;
+
+  int64 TotalByteSize() { return total_byte_size_; }
+
+ private:
+  int64 memory_limit_;
+  int64 total_byte_size_;
+  OpKernelContext* context_;
+  std::vector<Tensor> allocated_tensors_;
+};
+
 }  // namespace tensorflow
 
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -549,6 +549,7 @@ struct EinsumHelper {
   static Status ContractOperands(OpKernelContext* ctx,
                                  absl::Span<const Tensor> inputs,
                                  absl::Span<const bool> swap_free_and_contract,
+                                 bool use_autotune,
                                  Tensor* output) {
     if (inputs.size() == 1)
       return CopyFrom(inputs[0], inputs[0].shape(), output);
@@ -583,7 +584,7 @@ struct EinsumHelper {
         ReshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped));
     LaunchBatchMatMul<Device, T>::Launch(ctx, lhs, rhs, /*adj_x=*/false,
                                          /*adj_y=*/false, trans_x, trans_y,
-                                         bcast, &output_reshaped);
+                                         bcast, use_autotune, &output_reshaped);
     return Status::OK();
   }
 };
@@ -598,6 +599,7 @@ class EinsumOp : public OpKernel {
                         equation_, &input_labels_, &output_labels_, &label_types_,
                         &input_label_counts_, &output_label_counts_,
                         &input_has_ellipsis_, &output_has_ellipsis_));
+    use_autotune_ = MatmulAutotuneEnable();
   }
 
   void Compute(OpKernelContext* ctx) override {
@@ -640,7 +642,7 @@ class EinsumOp : public OpKernel {
     Tensor contraction_output_reshaped;
     OP_REQUIRES_OK(ctx, EinsumHelper::ContractOperands<Device, T>(
                             ctx, inputs_reduced, swap_free_and_contract,
-                            &contraction_output_reshaped));
+                            use_autotune_, &contraction_output_reshaped));
 
     // Copy the batch labels from the contraction output. Recover the batch
     // shape, which may have been broadcasted.
|
|||||||
LabelCounts output_label_counts_;
|
LabelCounts output_label_counts_;
|
||||||
gtl::InlinedVector<bool, 2> input_has_ellipsis_;
|
gtl::InlinedVector<bool, 2> input_has_ellipsis_;
|
||||||
bool output_has_ellipsis_ = false;
|
bool output_has_ellipsis_ = false;
|
||||||
|
bool use_autotune_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
@@ -48,4 +48,22 @@ bool MatmulDoFP32ComputationFP16Input() {
   return value;
 }
 
+int MatmulMaxAutotuneAlgorithmCount() {
+  int64 value;
+  // In CUDA 11, cublasLtMatmulAlgoGetHeuristic typically returns <= 4
+  // algorithms for a given configuration, so 10 seems like a reasonable default
+  // here.
+  Status status =
+      ReadInt64FromEnvVar("TF_MATMUL_AUTOTUNE_MAX_ALGORITHMS", 10, &value);
+  if (!status.ok()) {
+    LOG(ERROR) << status.error_message();
+  }
+  static constexpr const int kMaxValue = std::numeric_limits<int>::max();
+  if (value < 1 || value > kMaxValue) {
+    LOG(ERROR) << "Invalid value for TF_MATMUL_AUTOTUNE_MAX_ALGORITHMS: "
+               << value << " is not in range [1, " << kMaxValue << "]";
+  }
+  return value;
+}
+
 }  // namespace tensorflow
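Note: the value returned by MatmulMaxAutotuneAlgorithmCount caps the candidate list requested from cuBLASLt in the batch matmul kernel earlier in this commit; when autotuning is off only one algorithm is requested (lines reassembled from that hunk, error handling omitted):

int max_algorithm_count = use_autotune ? max_autotune_algorithm_count : 1;
stream->parent()->GetBlasLtMatmulAlgorithms(
    plan.get(), max_scratch_size, max_algorithm_count, &algorithms);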
@@ -22,6 +22,7 @@ namespace tensorflow {
 
 bool MatmulAutotuneEnable();
 bool MatmulDoFP32ComputationFP16Input();
+int MatmulMaxAutotuneAlgorithmCount();
 
 }  // namespace tensorflow
 