Rollback PR #43237: Integrate cuBLASLt API into backend

Reason: Performance regression on bert_pretraining and bert_squad.

PiperOrigin-RevId: 339565741
Change-Id: I1c6a4bf807b3cb3aa132e6272a2f01f90bdeca6d

This commit is contained in:
parent 5e38165d02
commit af5c3ccabc
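Besides removing the cuBLASLt matmul path, the rollback drops the shared gpu_utils::GetWorkspaceLimit helper and restores per-kernel parsing of the workspace-limit environment variables (e.g. TF_CUBLAS_WORKSPACE_LIMIT_IN_MB), as seen in the conv and FFT hunks further down. A minimal standalone sketch of that MB-to-bytes lookup is shown below; the function name is hypothetical and plain stdlib calls stand in for TensorFlow's strings::safe_strto64 and LOG(WARNING), so treat it as an illustration of the pattern, not the TensorFlow implementation.

#include <cstdint>
#include <cstdlib>
#include <string>

// Sketch only: read an "<envvar> in MB" setting and convert it to bytes,
// falling back to a default when unset or unparsable. This mirrors the
// pattern restored in GetDnnWorkspaceLimit / GetCufftWorkspaceLimit below.
int64_t GetWorkspaceLimitBytes(const std::string& envvar_in_mb,
                               int64_t default_value_in_bytes) {
  const char* limit_str = std::getenv(envvar_in_mb.c_str());
  if (limit_str != nullptr && limit_str[0] != '\0') {
    char* end = nullptr;
    const long long limit_in_mb = std::strtoll(limit_str, &end, 10);
    if (end != limit_str && *end == '\0') {
      return static_cast<int64_t>(limit_in_mb) * (1 << 20);  // MB -> bytes
    }
    // Invalid values fall through to the default; the real code logs a warning.
  }
  return default_value_in_bytes;
}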
@@ -3396,7 +3396,7 @@ tf_kernel_library(
name = "fft_ops",
prefix = "fft_ops",
deps = MATH_DEPS + [
] + if_cuda_or_rocm([":gpu_utils"]) + if_cuda([
] + if_cuda([
"//tensorflow/core/platform/default/build_config:cufft_plugin",
]),
)

@@ -35,7 +35,6 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/matmul_autotune.h"
#include "tensorflow/core/util/matmul_bcast.h"
#include "tensorflow/core/util/work_sharder.h"

@@ -44,13 +43,8 @@ limitations under the License.
#endif

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"  // For CUDA_VERSION
#endif

namespace tensorflow {

@@ -225,8 +219,7 @@ template <typename Scalar>
struct LaunchBatchMatMul<CPUDevice, Scalar> {
static void Launch(OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
bool trans_y, const MatMulBCast& bcast, bool use_autotune,
Tensor* out) {
bool trans_y, const MatMulBCast& bcast, Tensor* out) {
typedef ParallelMatMulKernel<Scalar, Eigen::NumTraits<Scalar>::IsComplex>
ParallelMatMulKernel;
bool conjugate_result = false;
@ -282,212 +275,45 @@ se::DeviceMemory<T> AsDeviceMemory(const T* gpu_memory) {
|
||||
return typed;
|
||||
}
|
||||
|
||||
using BlasScratchAllocator = GpuScratchAllocator;
|
||||
|
||||
int64 GetBlasWorkspaceLimit(const string& envvar_in_mb,
|
||||
int64 default_value_in_bytes) {
|
||||
return gpu_utils::GetWorkspaceLimit(envvar_in_mb, default_value_in_bytes);
|
||||
}
|
||||
|
||||
// Encapsulate all of the shape, dtype etc. information that defines a unique
|
||||
// batched matmul operation.
|
||||
class BatchMatmulParameters {
|
||||
class BlasScratchAllocator : public se::ScratchAllocator {
|
||||
public:
|
||||
BatchMatmulParameters(bool trans_a, bool trans_b, bool adj_a, bool adj_b,
|
||||
uint64 m, uint64 n, uint64 k, uint64 batch_count,
|
||||
bool broadcast_a, bool broadcast_b, DataType dtype_ab,
|
||||
DataType dtype_cd, bool allow_tf32, int device_id)
|
||||
: trans_a_(trans_a),
|
||||
trans_b_(trans_b),
|
||||
adj_a_(adj_a),
|
||||
adj_b_(adj_b),
|
||||
m_(m),
|
||||
n_(n),
|
||||
k_(k),
|
||||
batch_count_(batch_count),
|
||||
broadcast_a_(broadcast_a),
|
||||
broadcast_b_(broadcast_b),
|
||||
dtype_ab_(dtype_ab),
|
||||
dtype_cd_(dtype_cd),
|
||||
allow_tf32_(allow_tf32),
|
||||
device_id_(device_id) {
|
||||
hash_code_ = trans_a;
|
||||
hash_code_ = Hash64Combine(hash_code_, trans_b);
|
||||
hash_code_ = Hash64Combine(hash_code_, adj_a);
|
||||
hash_code_ = Hash64Combine(hash_code_, adj_b);
|
||||
hash_code_ = Hash64Combine(hash_code_, m);
|
||||
hash_code_ = Hash64Combine(hash_code_, n);
|
||||
hash_code_ = Hash64Combine(hash_code_, k);
|
||||
hash_code_ = Hash64Combine(hash_code_, batch_count);
|
||||
hash_code_ = Hash64Combine(hash_code_, broadcast_a);
|
||||
hash_code_ = Hash64Combine(hash_code_, broadcast_b);
|
||||
hash_code_ = Hash64Combine(hash_code_, dtype_ab);
|
||||
hash_code_ = Hash64Combine(hash_code_, dtype_cd);
|
||||
hash_code_ = Hash64Combine(hash_code_, allow_tf32);
|
||||
hash_code_ = Hash64Combine(hash_code_, device_id);
|
||||
}
|
||||
bool operator==(const BatchMatmulParameters& other) const {
|
||||
return this->get_data_as_tuple() == other.get_data_as_tuple();
|
||||
}
|
||||
using Stream = se::Stream;
|
||||
using DeviceMemoryBytes = se::DeviceMemory<uint8>;
|
||||
|
||||
bool operator!=(const BatchMatmulParameters& other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
uint64 hash() const { return hash_code_; }
|
||||
BlasScratchAllocator(OpKernelContext* context) : context_(context) {}
|
||||
|
||||
string ToString() const {
|
||||
// clang-format off
|
||||
return strings::StrCat(
|
||||
trans_a_, ", ", trans_b_, ", ", adj_a_, ", ", adj_b_, ", ",
|
||||
m_, ", ", n_, ", ", k_, ", ", batch_count_, ", ",
|
||||
broadcast_a_, ", ", broadcast_b_, ", ",
|
||||
dtype_ab_, ", ", dtype_cd_, ", ", allow_tf32_, ", ", device_id_);
|
||||
// clang-format on
|
||||
int64 GetMemoryLimitInBytes() override { return -1; }
|
||||
|
||||
se::port::StatusOr<DeviceMemoryBytes> AllocateBytes(
|
||||
int64 byte_size) override {
|
||||
Tensor temporary_memory;
|
||||
|
||||
Status allocation_status(context_->allocate_temp(
|
||||
DT_UINT8, TensorShape({byte_size}), &temporary_memory));
|
||||
if (!allocation_status.ok()) {
|
||||
return se::port::StatusOr<DeviceMemoryBytes>(
|
||||
DeviceMemoryBytes::MakeFromByteSize(nullptr, 0));
|
||||
}
|
||||
// Hold the reference of the allocated tensors until the end of the
|
||||
// allocator.
|
||||
allocated_tensors_.push_back(temporary_memory);
|
||||
return se::port::StatusOr<DeviceMemoryBytes>(
|
||||
DeviceMemoryBytes::MakeFromByteSize(
|
||||
temporary_memory.flat<uint8>().data(),
|
||||
temporary_memory.flat<uint8>().size()));
|
||||
}
|
||||
|
||||
private:
|
||||
typedef std::tuple<bool, bool, bool, bool, int64, int64, int64, int64, bool,
|
||||
bool, DataType, DataType, bool, int>
|
||||
ParameterDataType;
|
||||
|
||||
ParameterDataType get_data_as_tuple() const {
|
||||
return std::make_tuple(trans_a_, trans_b_, adj_a_, adj_b_, m_, n_, k_,
|
||||
batch_count_, broadcast_a_, broadcast_b_, dtype_ab_,
|
||||
dtype_cd_, allow_tf32_, device_id_);
|
||||
}
|
||||
|
||||
bool trans_a_;
|
||||
bool trans_b_;
|
||||
bool adj_a_;
|
||||
bool adj_b_;
|
||||
uint64 m_;
|
||||
uint64 n_;
|
||||
uint64 k_;
|
||||
uint64 batch_count_;
|
||||
bool broadcast_a_;
|
||||
bool broadcast_b_;
|
||||
DataType dtype_ab_;
|
||||
DataType dtype_cd_;
|
||||
bool allow_tf32_;
|
||||
int device_id_;
|
||||
uint64 hash_code_;
|
||||
OpKernelContext* context_;
|
||||
std::vector<Tensor> allocated_tensors_;
|
||||
};
|
||||
|
||||
bool GetBlasComputationType(const DataType& dtype, bool allow_tf32,
|
||||
se::blas::ComputationType* compute_type) {
|
||||
using se::blas::ComputationType;
|
||||
static bool use_f32_for_f16_computation = MatmulDoFP32ComputationFP16Input();
|
||||
ComputationType f32_type =
|
||||
allow_tf32 ? ComputationType::kTF32AsF32 : ComputationType::kF32;
|
||||
switch (dtype) {
|
||||
case DT_HALF:
|
||||
case DT_BFLOAT16:
|
||||
*compute_type =
|
||||
use_f32_for_f16_computation ? f32_type : ComputationType::kF16;
|
||||
return true;
|
||||
case DT_FLOAT:
|
||||
*compute_type = f32_type;
|
||||
return true;
|
||||
case DT_DOUBLE:
|
||||
*compute_type = ComputationType::kF64;
|
||||
return true;
|
||||
case DT_COMPLEX64:
|
||||
*compute_type = f32_type;
|
||||
return true;
|
||||
case DT_COMPLEX128:
|
||||
*compute_type = ComputationType::kComplexF64;
|
||||
return true;
|
||||
default:
|
||||
// Unsupported compute_type, return false.
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Thread-safe map from matmul parameters to their corresponding plan and
|
||||
// algorithms.
|
||||
template <typename Parameters>
|
||||
class BlasLtMatmulPlanMap {
|
||||
public:
|
||||
struct PlanAndAlgorithms {
|
||||
std::unique_ptr<se::blas::IBlasLtMatmulPlan> plan;
|
||||
std::vector<std::unique_ptr<se::blas::IBlasLtMatmulAlgorithm>> algorithms;
|
||||
};
|
||||
|
||||
const PlanAndAlgorithms* Find(const Parameters& params) {
|
||||
mutex_lock lock(mu_);
|
||||
auto iter = params_plan_map_.find(params);
|
||||
if (iter == params_plan_map_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return &iter->second;
|
||||
}
|
||||
const PlanAndAlgorithms* Insert(const Parameters& params,
|
||||
PlanAndAlgorithms value) {
|
||||
mutex_lock lock(mu_);
|
||||
return ¶ms_plan_map_.emplace(params, std::move(value)).first->second;
|
||||
}
|
||||
|
||||
private:
|
||||
struct Hasher {
|
||||
std::size_t operator()(const Parameters& parameter) const {
|
||||
return parameter.hash();
|
||||
}
|
||||
};
|
||||
|
||||
mutable mutex mu_;
|
||||
std::unordered_map<Parameters, PlanAndAlgorithms, Hasher> params_plan_map_
|
||||
GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
template <typename Parameters>
|
||||
struct BlasLtPlanMapSingleton {
|
||||
typedef BlasLtMatmulPlanMap<Parameters> PlanMapType;
|
||||
static PlanMapType* GetInstance() {
|
||||
static PlanMapType* instance = new PlanMapType();
|
||||
return instance;
|
||||
}
|
||||
};
|
||||
|
||||
typedef BlasLtPlanMapSingleton<BatchMatmulParameters>
|
||||
BatchMatmulPlanMapSingleton;
|
||||
|
||||
// A dummy type to group matmul autotune results together.
|
||||
struct BatchMatmulAutoTuneGroup {
|
||||
static string name() { return "MatmulLt"; }
|
||||
};
|
||||
|
||||
typedef AutoTuneSingleton<BatchMatmulAutoTuneGroup, BatchMatmulParameters,
|
||||
se::blas::AlgorithmConfig>
|
||||
AutoTuneBatchMatmul;
|
||||
|
||||
template <typename Scalar>
|
||||
struct CoefficientType {
|
||||
typedef Scalar type;
|
||||
};
|
||||
template <>
|
||||
struct CoefficientType<Eigen::half> {
|
||||
typedef float type;
|
||||
};
|
||||
|
||||
inline Status FromExecutorStatus(const se::port::Status& s) {
|
||||
return s.ok() ? Status::OK()
|
||||
: Status(static_cast<error::Code>(static_cast<int>(s.code())),
|
||||
s.error_message());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline Status FromExecutorStatus(const se::port::StatusOr<T>& s) {
|
||||
return FromExecutorStatus(s.status());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Scalar>
|
||||
struct LaunchBatchMatMul<GPUDevice, Scalar> {
|
||||
static void Launch(OpKernelContext* context, const Tensor& in_x,
|
||||
const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
|
||||
bool trans_y, const MatMulBCast& bcast, bool use_autotune,
|
||||
Tensor* out) {
|
||||
bool trans_y, const MatMulBCast& bcast, Tensor* out) {
|
||||
se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
|
||||
se::blas::Transpose::kTranspose,
|
||||
se::blas::Transpose::kConjugateTranspose};
|
||||
@ -517,191 +343,10 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
|
||||
auto* a_base_ptr = in_x.template flat<Scalar>().data();
|
||||
auto* b_base_ptr = in_y.template flat<Scalar>().data();
|
||||
auto* c_base_ptr = out->template flat<Scalar>().data();
|
||||
int64 a_stride;
|
||||
int64 b_stride;
|
||||
int64 c_stride;
|
||||
uint64 a_stride;
|
||||
uint64 b_stride;
|
||||
uint64 c_stride;
|
||||
|
||||
typedef typename CoefficientType<Scalar>::type Coefficient;
|
||||
|
||||
static const int64 max_scratch_size = GetBlasWorkspaceLimit(
|
||||
"TF_CUBLAS_WORKSPACE_LIMIT_IN_MB", 1LL << 32); // 4GB by default
|
||||
|
||||
// The BlasLtMatmul routines are only supported from CUDA 11.0 onward.
|
||||
#if GOOGLE_CUDA && CUDA_VERSION >= 11000
|
||||
bool is_full_broadcast =
|
||||
std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
|
||||
bool requires_mixed_broadcasting =
|
||||
bcast.IsBroadcastingRequired() && !is_full_broadcast;
|
||||
if (!requires_mixed_broadcasting) {
|
||||
bool broadcast_a = bcast.x_batch_size() == 1;
|
||||
bool broadcast_b = bcast.y_batch_size() == 1;
|
||||
a_stride = broadcast_a ? 0 : m * k;
|
||||
b_stride = broadcast_b ? 0 : k * n;
|
||||
c_stride = m * n;
|
||||
a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
|
||||
b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
|
||||
c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
|
||||
a_ptrs.push_back(&a_device_memory.back());
|
||||
b_ptrs.push_back(&b_device_memory.back());
|
||||
c_ptrs.push_back(&c_device_memory.back());
|
||||
|
||||
DataType dtype = DataTypeToEnum<Scalar>::value;
|
||||
bool allow_tf32 = tensor_float_32_execution_enabled();
|
||||
int device_id = stream->parent()->device_ordinal();
|
||||
BatchMatmulParameters matmul_parameters(
|
||||
trans_x, trans_y, adj_x, adj_y, m, n, k, batch_size, broadcast_a,
|
||||
broadcast_b, dtype, dtype, allow_tf32, device_id);
|
||||
|
||||
static const bool max_autotune_algorithm_count =
|
||||
MatmulMaxAutotuneAlgorithmCount();
|
||||
int max_algorithm_count = use_autotune ? max_autotune_algorithm_count : 1;
|
||||
|
||||
const auto* plan_and_algorithms =
|
||||
BatchMatmulPlanMapSingleton::GetInstance()->Find(matmul_parameters);
|
||||
if (!plan_and_algorithms) {
|
||||
se::blas::DataType blas_dtype = se::blas::ToDataType<Scalar>::value;
|
||||
se::blas::ComputationType computation_type;
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
GetBlasComputationType(dtype, allow_tf32, &computation_type),
|
||||
errors::Internal("Unsupported dtype for batched matmul"));
|
||||
|
||||
auto status_or_plan = stream->parent()->CreateBlasLtMatmulPlan(
|
||||
{/*ab_type=*/blas_dtype,
|
||||
/*c_type=*/blas_dtype, computation_type,
|
||||
se::blas::PointerMode::kHost, se::blas::Epilogue::kDefault,
|
||||
blas_transpose_b, blas_transpose_a, n, m, k,
|
||||
/*lda=*/in_y.dim_size(2), /*ldb=*/in_x.dim_size(2),
|
||||
/*ldc=*/static_cast<int64>(n), static_cast<int>(batch_size),
|
||||
b_stride, a_stride, c_stride});
|
||||
OP_REQUIRES(context, status_or_plan.ok(),
|
||||
FromExecutorStatus(status_or_plan));
|
||||
std::unique_ptr<se::blas::IBlasLtMatmulPlan> plan =
|
||||
status_or_plan.ConsumeValueOrDie();
|
||||
|
||||
auto status_or_algorithms = stream->parent()->GetBlasLtMatmulAlgorithms(
|
||||
plan.get(), max_scratch_size, max_algorithm_count);
|
||||
OP_REQUIRES(context, status_or_algorithms.ok(),
|
||||
FromExecutorStatus(status_or_algorithms));
|
||||
auto algorithms = status_or_algorithms.ConsumeValueOrDie();
|
||||
|
||||
plan_and_algorithms =
|
||||
BatchMatmulPlanMapSingleton::GetInstance()->Insert(
|
||||
matmul_parameters, {std::move(plan), std::move(algorithms)});
|
||||
}
|
||||
const auto& plan = plan_and_algorithms->plan;
|
||||
const auto& algorithms = plan_and_algorithms->algorithms;
|
||||
|
||||
// The BlasLtMatmul routines (unlike BlasGemm, BlasGemmBatched etc.) take
|
||||
// alpha and beta with the same type as the matrices.
|
||||
Scalar alpha(1.0);
|
||||
Scalar beta(0.0);
|
||||
|
||||
// Note that algorithm_config.algorithm() here is used to refer
|
||||
// to the index within the algorithms vector, not the algorithm
|
||||
// itself.
|
||||
se::blas::AlgorithmConfig algorithm_config(se::blas::kNoAlgorithm);
|
||||
if (max_algorithm_count == 1) {
|
||||
algorithm_config.set_algorithm(0);
|
||||
} else if (!AutoTuneBatchMatmul::GetInstance()->Find(matmul_parameters,
|
||||
&algorithm_config)) {
|
||||
VLOG(4) << "Autotuning BlasLtMatmul over " << algorithms.size()
|
||||
<< " algorithms.";
|
||||
se::blas::ProfileResult best_result;
|
||||
se::blas::ProfileResult profile_result;
|
||||
// for (const auto& profile_algorithm : plan_and_algorithms->algorithms)
|
||||
// {
|
||||
for (size_t i = 0; i != algorithms.size(); ++i) {
|
||||
const auto& profile_algorithm = algorithms[i];
|
||||
// Create a new scratch allocator with every autotuning run so that
|
||||
// scratch space is deallocated between runs.
|
||||
BlasScratchAllocator scratch_allocator(max_scratch_size, context);
|
||||
|
||||
bool cublas_launch_status =
|
||||
stream
|
||||
->ThenBlasLtMatmul(plan.get(), alpha, *b_ptrs[0], *a_ptrs[0],
|
||||
beta, c_ptrs[0], &scratch_allocator,
|
||||
profile_algorithm.get(), {},
|
||||
&profile_result)
|
||||
.ok();
|
||||
|
||||
VLOG(4) << " Autotune algorithm " << i
|
||||
<< " result: " << profile_result.elapsed_time_in_ms()
|
||||
<< " ms, valid=" << profile_result.is_valid()
|
||||
<< ", workspace_size=" << profile_algorithm->workspace_size();
|
||||
|
||||
if (cublas_launch_status && profile_result.is_valid() &&
|
||||
profile_result.elapsed_time_in_ms() <
|
||||
best_result.elapsed_time_in_ms()) {
|
||||
best_result = profile_result;
|
||||
}
|
||||
}
|
||||
|
||||
if (best_result.is_valid()) {
|
||||
algorithm_config.set_algorithm(best_result.algorithm());
|
||||
}
|
||||
// We make sure that each matmul parameter set only gets one pass of
|
||||
// autotune. If no algorithms works, we add kNoAlgorithm to the autotune
|
||||
// map.
|
||||
AutoTuneBatchMatmul::GetInstance()->Insert(matmul_parameters,
|
||||
algorithm_config);
|
||||
}
|
||||
se::blas::AlgorithmType algorithm_idx = algorithm_config.algorithm();
|
||||
OP_REQUIRES(context,
|
||||
0 <= algorithm_idx && algorithm_idx < algorithms.size(),
|
||||
errors::Internal("Missing/invalid BatchMatmul algorithm"));
|
||||
const auto& algorithm = algorithms[algorithm_idx];
|
||||
BlasScratchAllocator scratch_allocator(max_scratch_size, context);
|
||||
bool cublas_launch_status =
|
||||
stream
|
||||
->ThenBlasLtMatmul(plan.get(), alpha, *b_ptrs[0], *a_ptrs[0],
|
||||
beta, c_ptrs[0], &scratch_allocator,
|
||||
algorithm.get())
|
||||
.ok();
|
||||
if (!cublas_launch_status) {
|
||||
context->SetStatus(errors::Internal(
|
||||
"Blas batched matmul launch failed : a.shape=(",
|
||||
bcast.x_batch_size(), ", ", in_x.dim_size(0), ", ",
|
||||
in_x.dim_size(1), "), b.shape=(", bcast.y_batch_size(), ", ",
|
||||
in_y.dim_size(0), ", ", in_y.dim_size(1), "), m=", m, ", n=", n,
|
||||
", k=", k, ", batch_size=", batch_size));
|
||||
}
|
||||
} else { // requires mixed broadcasting
|
||||
const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
|
||||
const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
|
||||
for (int64 i = 0; i < bcast.x_batch_size(); ++i) {
|
||||
a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
|
||||
}
|
||||
for (int64 i = 0; i < bcast.y_batch_size(); ++i) {
|
||||
b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
|
||||
}
|
||||
for (int64 i = 0; i < batch_size; ++i) {
|
||||
c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
|
||||
a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
|
||||
b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
|
||||
c_ptrs.push_back(&c_device_memory.back());
|
||||
}
|
||||
|
||||
BlasScratchAllocator scratch_allocator(max_scratch_size, context);
|
||||
bool blas_launch_status =
|
||||
stream
|
||||
->ThenBlasGemmBatchedWithScratch(
|
||||
blas_transpose_b, blas_transpose_a, n, m, k,
|
||||
static_cast<Coefficient>(1.0), b_ptrs,
|
||||
adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
|
||||
static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
|
||||
&scratch_allocator)
|
||||
.ok();
|
||||
if (!blas_launch_status) {
|
||||
context->SetStatus(errors::Internal(
|
||||
"Blas xGEMMBatched launch failed : a.shape=",
|
||||
in_x.shape().DebugString(),
|
||||
", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
|
||||
", k=", k, ", batch_size=", batch_size));
|
||||
}
|
||||
}
|
||||
return;
|
||||
#else // if not GOOGLE_CUDA or CUDA_VERSION < 11000
|
||||
bool is_full_broadcast =
|
||||
std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
|
||||
bool use_strided_batched =
|
||||
@ -743,6 +388,8 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
|
||||
}
|
||||
}
|
||||
|
||||
typedef Scalar Coefficient;
|
||||
|
||||
// Blas does
|
||||
// C = A x B
|
||||
// where A, B and C are assumed to be in column major.
|
||||
@ -752,10 +399,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
|
||||
if (batch_size == 1) {
|
||||
// This is a regular matrix*matrix or matrix*vector multiply. Avoid the
|
||||
// overhead of the scratch allocator and the batch interface.
|
||||
// Note that the GEMV call here does not support Eigen::half, so we do not
|
||||
// use this path in that case. A workaround is applied to the pointers
|
||||
// passed to the call itself to avoid compilation errors.
|
||||
if (!std::is_same<Scalar, Eigen::half>::value && n == 1 &&
|
||||
if (n == 1 &&
|
||||
blas_transpose_b != se::blas::Transpose::kConjugateTranspose &&
|
||||
blas_transpose_a != se::blas::Transpose::kConjugateTranspose) {
|
||||
// This is a matrix*vector multiply so use GEMV to compute A * b.
|
||||
@ -766,19 +410,13 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
|
||||
auto gemv_trans_a = blas_transpose_a == se::blas::Transpose::kTranspose
|
||||
? se::blas::Transpose::kNoTranspose
|
||||
: se::blas::Transpose::kTranspose;
|
||||
// Cast pointers as a workaround for GEMV not supporting Eigen::half
|
||||
// (this will never actually be executed for Eigen::half).
|
||||
typedef se::DeviceMemory<Coefficient> NonHalfDeviceMemoryType;
|
||||
NonHalfDeviceMemoryType a_ptr(*(a_ptrs[0]));
|
||||
NonHalfDeviceMemoryType b_ptr(*(b_ptrs[0]));
|
||||
NonHalfDeviceMemoryType c_ptr(*(c_ptrs[0]));
|
||||
bool blas_launch_status =
|
||||
stream
|
||||
->ThenBlasGemv(gemv_trans_a, adj_x || trans_x ? m : k,
|
||||
adj_x || trans_x ? k : m,
|
||||
static_cast<Coefficient>(1.0), a_ptr,
|
||||
adj_x || trans_x ? m : k, b_ptr, 1,
|
||||
static_cast<Coefficient>(0.0), &c_ptr, 1)
|
||||
static_cast<Coefficient>(1.0), *(a_ptrs[0]),
|
||||
adj_x || trans_x ? m : k, *(b_ptrs[0]), 1,
|
||||
static_cast<Coefficient>(0.0), c_ptrs[0], 1)
|
||||
.ok();
|
||||
if (!blas_launch_status) {
|
||||
context->SetStatus(errors::Internal(
|
||||
@ -821,7 +459,154 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
|
||||
", k=", k, ", batch_size=", batch_size));
|
||||
}
|
||||
} else {
|
||||
BlasScratchAllocator scratch_allocator(max_scratch_size, context);
|
||||
BlasScratchAllocator scratch_allocator(context);
|
||||
bool blas_launch_status =
|
||||
stream
|
||||
->ThenBlasGemmBatchedWithScratch(
|
||||
blas_transpose_b, blas_transpose_a, n, m, k,
|
||||
static_cast<Coefficient>(1.0), b_ptrs,
|
||||
adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
|
||||
static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
|
||||
&scratch_allocator)
|
||||
.ok();
|
||||
if (!blas_launch_status) {
|
||||
context->SetStatus(errors::Internal(
|
||||
"Blas xGEMMBatched launch failed : a.shape=",
|
||||
in_x.shape().DebugString(),
|
||||
", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
|
||||
", k=", k, ", batch_size=", batch_size));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
|
||||
static void Launch(OpKernelContext* context, const Tensor& in_x,
|
||||
const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
|
||||
bool trans_y, const MatMulBCast& bcast, Tensor* out) {
|
||||
typedef Eigen::half Scalar;
|
||||
se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
|
||||
se::blas::Transpose::kTranspose,
|
||||
se::blas::Transpose::kConjugateTranspose};
|
||||
const uint64 m = in_x.dim_size(adj_x || trans_x ? 2 : 1);
|
||||
const uint64 k = in_x.dim_size(adj_x || trans_x ? 1 : 2);
|
||||
const uint64 n = in_y.dim_size(adj_y || trans_y ? 1 : 2);
|
||||
const uint64 batch_size = bcast.output_batch_size();
|
||||
auto blas_transpose_a = trans[adj_x ? 2 : (trans_x ? 1 : 0)];
|
||||
auto blas_transpose_b = trans[adj_y ? 2 : (trans_y ? 1 : 0)];
|
||||
|
||||
auto* stream = context->op_device_context()->stream();
|
||||
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
|
||||
|
||||
typedef perftools::gputools::DeviceMemory<Scalar> DeviceMemoryType;
|
||||
std::vector<DeviceMemoryType> a_device_memory;
|
||||
std::vector<DeviceMemoryType> b_device_memory;
|
||||
std::vector<DeviceMemoryType> c_device_memory;
|
||||
std::vector<DeviceMemoryType*> a_ptrs;
|
||||
std::vector<DeviceMemoryType*> b_ptrs;
|
||||
std::vector<DeviceMemoryType*> c_ptrs;
|
||||
a_device_memory.reserve(bcast.x_batch_size());
|
||||
b_device_memory.reserve(bcast.y_batch_size());
|
||||
c_device_memory.reserve(batch_size);
|
||||
a_ptrs.reserve(batch_size);
|
||||
b_ptrs.reserve(batch_size);
|
||||
c_ptrs.reserve(batch_size);
|
||||
auto* a_base_ptr = in_x.template flat<Scalar>().data();
|
||||
auto* b_base_ptr = in_y.template flat<Scalar>().data();
|
||||
auto* c_base_ptr = out->template flat<Scalar>().data();
|
||||
|
||||
uint64 a_stride;
|
||||
uint64 b_stride;
|
||||
uint64 c_stride;
|
||||
|
||||
bool is_full_broadcast =
|
||||
std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
|
||||
bool use_strided_batched =
|
||||
(!bcast.IsBroadcastingRequired() || is_full_broadcast) &&
|
||||
batch_size > 1;
|
||||
if (use_strided_batched) {
|
||||
a_stride = bcast.x_batch_size() != 1 ? m * k : 0;
|
||||
b_stride = bcast.y_batch_size() != 1 ? k * n : 0;
|
||||
c_stride = m * n;
|
||||
a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
|
||||
b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
|
||||
c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
|
||||
a_ptrs.push_back(&a_device_memory.back());
|
||||
b_ptrs.push_back(&b_device_memory.back());
|
||||
c_ptrs.push_back(&c_device_memory.back());
|
||||
} else if (!bcast.IsBroadcastingRequired()) {
|
||||
for (int64 i = 0; i < batch_size; ++i) {
|
||||
a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
|
||||
b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
|
||||
c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
|
||||
a_ptrs.push_back(&a_device_memory.back());
|
||||
b_ptrs.push_back(&b_device_memory.back());
|
||||
c_ptrs.push_back(&c_device_memory.back());
|
||||
}
|
||||
} else {
|
||||
const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
|
||||
const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
|
||||
for (int64 i = 0; i < bcast.x_batch_size(); ++i) {
|
||||
a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
|
||||
}
|
||||
for (int64 i = 0; i < bcast.y_batch_size(); ++i) {
|
||||
b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
|
||||
}
|
||||
for (int64 i = 0; i < batch_size; ++i) {
|
||||
c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
|
||||
a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
|
||||
b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
|
||||
c_ptrs.push_back(&c_device_memory.back());
|
||||
}
|
||||
}
|
||||
|
||||
typedef float Coefficient;
|
||||
|
||||
// Blas does
|
||||
// C = A x B
|
||||
// where A, B and C are assumed to be in column major.
|
||||
// We want the output to be in row-major, so we can compute
|
||||
// C' = B' x A', where ' stands for transpose (not adjoint).
|
||||
// TODO(yangzihao): Choose the best of the three strategies using autotune.
|
||||
if (batch_size == 1) {
|
||||
// This is a regular matrix*matrix or matrix*vector multiply. Avoid the
|
||||
// overhead of the scratch allocator and the batch interface.
|
||||
// TODO(benbarsdell): Use fp16 Gemv if it becomes supported by CUBLAS
|
||||
bool blas_launch_status =
|
||||
stream
|
||||
->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
|
||||
static_cast<Coefficient>(1.0), *(b_ptrs[0]),
|
||||
adj_y || trans_y ? k : n, *(a_ptrs[0]),
|
||||
adj_x || trans_x ? m : k,
|
||||
static_cast<Coefficient>(0.0), c_ptrs[0], n)
|
||||
.ok();
|
||||
if (!blas_launch_status) {
|
||||
context->SetStatus(errors::Internal(
|
||||
"Blas xGEMM launch failed : a.shape=", in_x.shape().DebugString(),
|
||||
", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
|
||||
", k=", k));
|
||||
}
|
||||
} else if (use_strided_batched) {
|
||||
bool blas_launch_status =
|
||||
stream
|
||||
->ThenBlasGemmStridedBatched(
|
||||
blas_transpose_b, blas_transpose_a, n, m, k,
|
||||
static_cast<Coefficient>(1.0), *b_ptrs[0],
|
||||
adj_y || trans_y ? k : n, b_stride, *a_ptrs[0],
|
||||
adj_x || trans_x ? m : k, a_stride,
|
||||
static_cast<Coefficient>(0.0), c_ptrs[0], n, c_stride,
|
||||
batch_size)
|
||||
.ok();
|
||||
if (!blas_launch_status) {
|
||||
context->SetStatus(errors::Internal(
|
||||
"Blas xGEMMStridedBatched launch failed : a.shape=",
|
||||
in_x.shape().DebugString(),
|
||||
", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
|
||||
", k=", k, ", batch_size=", batch_size));
|
||||
}
|
||||
} else {
|
||||
BlasScratchAllocator scratch_allocator(context);
|
||||
bool blas_launch_status =
|
||||
stream
|
||||
->ThenBlasGemmBatchedWithScratch(
|
||||
@ -839,7 +624,6 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
|
||||
", k=", k, ", batch_size=", batch_size));
|
||||
}
|
||||
}
|
||||
#endif // not GOOGLE_CUDA or CUDA_VERSION < 11000
|
||||
}
|
||||
};
|
||||
|
||||
@ -853,7 +637,6 @@ class BaseBatchMatMulOp : public OpKernel {
|
||||
: OpKernel(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
|
||||
use_autotune_ = MatmulAutotuneEnable();
|
||||
}
|
||||
|
||||
~BaseBatchMatMulOp() override {}
|
||||
@ -915,7 +698,7 @@ class BaseBatchMatMulOp : public OpKernel {
|
||||
out->shape().DebugString()));
|
||||
LaunchBatchMatMul<Device, Scalar>::Launch(
|
||||
ctx, in0_reshaped, in1_reshaped, adj_x_, adj_y_, /*trans_x=*/false,
|
||||
/*trans_y=*/false, bcast, use_autotune_, &out_reshaped);
|
||||
/*trans_y=*/false, bcast, &out_reshaped);
|
||||
}
|
||||
|
||||
protected:
|
||||
@ -925,7 +708,6 @@ class BaseBatchMatMulOp : public OpKernel {
|
||||
private:
|
||||
bool adj_x_;
|
||||
bool adj_y_;
|
||||
bool use_autotune_;
|
||||
};
|
||||
|
||||
// BatchMatMul Op implementation which disallows broadcasting.
|
||||
|
@@ -619,7 +619,19 @@ template struct LaunchConv2DOp<CPUDevice, double>;

int64 GetDnnWorkspaceLimit(const string& envvar_in_mb,
int64 default_value_in_bytes) {
return gpu_utils::GetWorkspaceLimit(envvar_in_mb, default_value_in_bytes);
const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
if (workspace_limit_in_mb_str != nullptr &&
strcmp(workspace_limit_in_mb_str, "") != 0) {
int64 scratch_limit_in_mb = -1;
if (strings::safe_strto64(workspace_limit_in_mb_str,
&scratch_limit_in_mb)) {
return scratch_limit_in_mb * (1 << 20);
} else {
LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
<< workspace_limit_in_mb_str;
}
}
return default_value_in_bytes;
}

// A dummy type to group forward convolution autotune results together.

@@ -48,7 +48,52 @@ int64 GetDnnWorkspaceLimit(const string& envvar_in_mb,
// A class to provide scratch-space allocator for Stream-Executor Cudnn
// callback. TensorFlow is responsible for releasing the temporary buffers after
// the kernel finishes.
using DnnScratchAllocator = GpuScratchAllocator;
class DnnScratchAllocator : public se::ScratchAllocator {
public:
virtual ~DnnScratchAllocator() {}
DnnScratchAllocator(int64 memory_limit, OpKernelContext* context)
: memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}
int64 GetMemoryLimitInBytes() override { return memory_limit_; }
se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes(
int64 byte_size) override {
Tensor temporary_memory;
if (byte_size < 0) {
return se::port::Status{se::port::error::INVALID_ARGUMENT,
"Requested negative byte size!"};
}
if (byte_size > memory_limit_) {
return se::port::Status{se::port::error::UNAVAILABLE,
absl::StrCat("Requested memory size (", byte_size,
") exceeds the max memory limit (",
memory_limit_, ").")};
}
AllocationAttributes allocation_attr;
allocation_attr.retry_on_failure = false;
Status allocation_status(context_->allocate_temp(
DT_UINT8, TensorShape({byte_size}), &temporary_memory,
AllocatorAttributes(), allocation_attr));
if (!allocation_status.ok()) {
return se::port::Status{
se::port::error::UNAVAILABLE,
absl::StrCat("Failed to allocate the requested memory size (",
byte_size, ").")};
}
// Hold the reference of the allocated tensors until the end of the
// allocator.
allocated_tensors_.push_back(temporary_memory);
total_byte_size_ += byte_size;
return se::port::StatusOr<se::DeviceMemory<uint8>>(
AsDeviceMemory(temporary_memory.flat<uint8>().data(),
temporary_memory.flat<uint8>().size()));
}
int64 TotalByteSize() { return total_byte_size_; }

private:
int64 memory_limit_;
int64 total_byte_size_;
OpKernelContext* context_;
std::vector<Tensor> allocated_tensors_;
};

// Encapsulate all the shape information that is used in both forward and
// backward conv operations.

@@ -31,7 +31,6 @@ limitations under the License.

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

@@ -401,7 +400,20 @@ class CufftScratchAllocator : public se::ScratchAllocator {

int64 GetCufftWorkspaceLimit(const string& envvar_in_mb,
int64 default_value_in_bytes) {
return gpu_utils::GetWorkspaceLimit(envvar_in_mb, default_value_in_bytes);
const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
if (workspace_limit_in_mb_str != nullptr &&
strcmp(workspace_limit_in_mb_str, "") != 0) {
int64 scratch_limit_in_mb = -1;
Status status = ReadInt64FromEnvVar(envvar_in_mb, default_value_in_bytes,
&scratch_limit_in_mb);
if (!status.ok()) {
LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
<< workspace_limit_in_mb_str;
} else {
return scratch_limit_in_mb * (1 << 20);
}
}
return default_value_in_bytes;
}

class FFTGPUBase : public FFTBase {

@@ -22,7 +22,6 @@ limitations under the License.
#include "google/protobuf/any.pb.h"
#include "absl/algorithm/container.h"
#include "absl/base/call_once.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/logger.h"
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/core/protobuf/conv_autotuning.pb.h"

@@ -283,64 +282,6 @@ Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
return Status::OK();
}

namespace gpu_utils {
int64 GetWorkspaceLimit(const string& envvar_in_mb,
int64 default_value_in_bytes) {
const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
if (workspace_limit_in_mb_str != nullptr &&
strcmp(workspace_limit_in_mb_str, "") != 0) {
int64 scratch_limit_in_mb = -1;
if (strings::safe_strto64(workspace_limit_in_mb_str,
&scratch_limit_in_mb)) {
return scratch_limit_in_mb * (1 << 20);
} else {
LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
<< workspace_limit_in_mb_str;
}
}
return default_value_in_bytes;
}
}  // namespace gpu_utils

GpuScratchAllocator::GpuScratchAllocator(int64 memory_limit,
OpKernelContext* context)
: memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}

se::port::StatusOr<se::DeviceMemory<uint8>> GpuScratchAllocator::AllocateBytes(
int64 byte_size) {
Tensor temporary_memory;
if (byte_size < 0) {
return se::port::Status{se::port::error::INVALID_ARGUMENT,
"Requested negative byte size!"};
}
if (byte_size > memory_limit_) {
return se::port::Status{
se::port::error::UNAVAILABLE,
absl::StrCat("Requested memory size (", byte_size,
") exceeds the max memory limit (", memory_limit_, ").")};
}
AllocationAttributes allocation_attr;
allocation_attr.retry_on_failure = false;
Status allocation_status(context_->allocate_temp(
DT_UINT8, TensorShape({byte_size}), &temporary_memory,
AllocatorAttributes(), allocation_attr));
if (!allocation_status.ok()) {
return se::port::Status{
se::port::error::UNAVAILABLE,
absl::StrCat("Failed to allocate the requested memory size (",
byte_size, ").")};
}
// Hold the reference of the allocated tensors until the end of the
// allocator.
// NOTE: We expect tensors to be deallocated when this allocator goes out of
// scope when allocated_tensors is destructed.
allocated_tensors_.push_back(temporary_memory);
total_byte_size_ += byte_size;
return se::port::StatusOr<se::DeviceMemory<uint8>>(
AsDeviceMemory(temporary_memory.flat<uint8>().data(),
temporary_memory.flat<uint8>().size()));
}

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

@@ -243,37 +243,6 @@ void LogFusedConvForwardAutotuneResults(
Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
se::dnn::AlgorithmConfig* algo);

namespace gpu_utils {
// Get a workspace limit from the environment variable, which is in MB.
// Return the workspace memory limit in bytes. If no value is set, return the
// default value.
int64 GetWorkspaceLimit(const string& envvar_in_mb,
int64 default_value_in_bytes);
}  // namespace gpu_utils

// A class to provide scratch-space allocator for Stream-Executor callbacks in
// CUDA libraries (CUDNN etc.).
// TensorFlow is responsible for releasing the temporary buffers after
// the kernel finishes.
class GpuScratchAllocator : public se::ScratchAllocator {
public:
virtual ~GpuScratchAllocator() {}

GpuScratchAllocator(int64 memory_limit, OpKernelContext* context);

int64 GetMemoryLimitInBytes() override { return memory_limit_; }

se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes(
int64 byte_size) override;

int64 TotalByteSize() { return total_byte_size_; }

private:
int64 memory_limit_;
int64 total_byte_size_;
OpKernelContext* context_;
std::vector<Tensor> allocated_tensors_;
};
}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

@@ -549,7 +549,7 @@ struct EinsumHelper {
static Status ContractOperands(OpKernelContext* ctx,
absl::Span<const Tensor> inputs,
absl::Span<const bool> swap_free_and_contract,
bool use_autotune, Tensor* output) {
Tensor* output) {
if (inputs.size() == 1)
return CopyFrom(inputs[0], inputs[0].shape(), output);
MatMulBCast bcast(inputs[0].shape().dim_sizes(),

@@ -583,7 +583,7 @@ struct EinsumHelper {
ReshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped));
LaunchBatchMatMul<Device, T>::Launch(ctx, lhs, rhs, /*adj_x=*/false,
/*adj_y=*/false, trans_x, trans_y,
bcast, use_autotune, &output_reshaped);
bcast, &output_reshaped);
return Status::OK();
}
};

@@ -598,7 +598,6 @@ class EinsumOp : public OpKernel {
equation_, &input_labels_, &output_labels_, &label_types_,
&input_label_counts_, &output_label_counts_,
&input_has_ellipsis_, &output_has_ellipsis_));
use_autotune_ = MatmulAutotuneEnable();
}

void Compute(OpKernelContext* ctx) override {

@@ -641,7 +640,7 @@ class EinsumOp : public OpKernel {
Tensor contraction_output_reshaped;
OP_REQUIRES_OK(ctx, EinsumHelper::ContractOperands<Device, T>(
ctx, inputs_reduced, swap_free_and_contract,
use_autotune_, &contraction_output_reshaped));
&contraction_output_reshaped));

// Copy the batch labels from the contraction output. Recover the batch
// shape, which may have been broadcasted.

@@ -739,7 +738,6 @@ class EinsumOp : public OpKernel {
LabelCounts output_label_counts_;
gtl::InlinedVector<bool, 2> input_has_ellipsis_;
bool output_has_ellipsis_ = false;
bool use_autotune_;
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

@@ -48,22 +48,4 @@ bool MatmulDoFP32ComputationFP16Input() {
return value;
}

int MatmulMaxAutotuneAlgorithmCount() {
int64 value;
// In CUDA 11, cublasLtMatmulAlgoGetHeuristic typically returns <= 4
// algorithms for a given configuration, so 10 seems like a reasonable default
// here.
Status status =
ReadInt64FromEnvVar("TF_MATMUL_AUTOTUNE_MAX_ALGORITHMS", 10, &value);
if (!status.ok()) {
LOG(ERROR) << status.error_message();
}
static constexpr const int kMaxValue = std::numeric_limits<int>::max();
if (value < 1 || value > kMaxValue) {
LOG(ERROR) << "Invalid value for TF_MATMUL_AUTOTUNE_MAX_ALGORITHMS: "
<< value << " is not in range [1, " << kMaxValue << "]";
}
return value;
}

}  // namespace tensorflow

@@ -22,7 +22,6 @@ namespace tensorflow {

bool MatmulAutotuneEnable();
bool MatmulDoFP32ComputationFP16Input();
int MatmulMaxAutotuneAlgorithmCount();

}  // namespace tensorflow

third_party/gpus/cuda/BUILD.tpl (vendored, 8 changes)
@@ -127,13 +127,6 @@ cc_library(
linkstatic = 1,
)

cc_library(
name = "cublasLt",
srcs = ["cuda/lib/%{cublasLt_lib}"],
data = ["cuda/lib/%{cublasLt_lib}"],
linkstatic = 1,
)

cc_library(
name = "cusolver",
srcs = ["cuda/lib/%{cusolver_lib}"],

@@ -175,7 +168,6 @@ cc_library(
name = "cuda",
deps = [
":cublas",
":cublasLt",
":cuda_headers",
":cudart",
":cudnn",

third_party/gpus/cuda_configure.bzl (vendored, 12 changes)
@@ -551,13 +551,6 @@ def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config):
cuda_config.cublas_version,
static = False,
),
"cublasLt": _check_cuda_lib_params(
"cublasLt",
cpu_value,
cuda_config.config["cublas_library_dir"],
cuda_config.cublas_version,
static = False,
),
"cusolver": _check_cuda_lib_params(
"cusolver",
cpu_value,

@@ -787,7 +780,6 @@ def _create_dummy_repository(repository_ctx):
"%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value),
"%{cudart_lib}": lib_name("cudart", cpu_value),
"%{cublas_lib}": lib_name("cublas", cpu_value),
"%{cublasLt_lib}": lib_name("cublasLt", cpu_value),
"%{cusolver_lib}": lib_name("cusolver", cpu_value),
"%{cudnn_lib}": lib_name("cudnn", cpu_value),
"%{cufft_lib}": lib_name("cufft", cpu_value),

@@ -819,7 +811,6 @@ filegroup(name="cudnn-include")
"cuda/cuda/lib/%s" % lib_name("cudart_static", cpu_value),
)
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublas", cpu_value))
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublasLt", cpu_value))
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusolver", cpu_value))
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudnn", cpu_value))
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("curand", cpu_value))

@@ -1011,13 +1002,11 @@ def _create_local_cuda_repository(repository_ctx):
cublas_include_path + "/cublas.h",
cublas_include_path + "/cublas_v2.h",
cublas_include_path + "/cublas_api.h",
cublas_include_path + "/cublasLt.h",
],
outs = [
"cublas/include/cublas.h",
"cublas/include/cublas_v2.h",
"cublas/include/cublas_api.h",
"cublas/include/cublasLt.h",
],
))

@@ -1158,7 +1147,6 @@ def _create_local_cuda_repository(repository_ctx):
"%{cudart_static_linkopt}": _cudart_static_linkopt(cuda_config.cpu_value),
"%{cudart_lib}": _basename(repository_ctx, cuda_libs["cudart"]),
"%{cublas_lib}": _basename(repository_ctx, cuda_libs["cublas"]),
"%{cublasLt_lib}": _basename(repository_ctx, cuda_libs["cublasLt"]),
"%{cusolver_lib}": _basename(repository_ctx, cuda_libs["cusolver"]),
"%{cudnn_lib}": _basename(repository_ctx, cuda_libs["cudnn"]),
"%{cufft_lib}": _basename(repository_ctx, cuda_libs["cufft"]),