Merge pull request #43237 from benbarsdell:cublaslt
PiperOrigin-RevId: 337382541 Change-Id: I949698ec93cb3c15654857768fcfce53984a97be
This commit is contained in: 6859f52a3f

Changed paths: tensorflow/core/{kernels,util}, tensorflow/python/kernel_tests, tensorflow/stream_executor, third_party/gpus
@@ -3347,6 +3347,9 @@ tf_kernel_library(
prefix = "batch_matmul_op",
deps = MATH_DEPS + [":eigen_contraction_kernel"] + if_mkl_ml([
"//third_party/mkl:intel_binary_blob",
]) + if_cuda_or_rocm([
"//tensorflow/core/kernels:gpu_utils",
"//tensorflow/core/platform:tensor_float_32_utils",
]),
)

@@ -3404,7 +3407,7 @@ tf_kernel_library(
name = "fft_ops",
prefix = "fft_ops",
deps = MATH_DEPS + [
] + if_cuda([
] + if_cuda_or_rocm([":gpu_utils"]) + if_cuda([
"//tensorflow/core/platform/default/build_config:cufft_plugin",
]),
)
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/matmul_autotune.h"
#include "tensorflow/core/util/matmul_bcast.h"
#include "tensorflow/core/util/work_sharder.h"

@@ -43,8 +44,13 @@ limitations under the License.
#endif

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"  // For CUDA_VERSION
#endif

namespace tensorflow {

@@ -219,7 +225,8 @@ template <typename Scalar>
struct LaunchBatchMatMul<CPUDevice, Scalar> {
static void Launch(OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
bool trans_y, const MatMulBCast& bcast, Tensor* out) {
bool trans_y, const MatMulBCast& bcast, bool use_autotune,
Tensor* out) {
typedef ParallelMatMulKernel<Scalar, Eigen::NumTraits<Scalar>::IsComplex>
ParallelMatMulKernel;
bool conjugate_result = false;
@@ -275,45 +282,212 @@ se::DeviceMemory<T> AsDeviceMemory(const T* gpu_memory) {
return typed;
}

class BlasScratchAllocator : public se::ScratchAllocator {
using BlasScratchAllocator = GpuScratchAllocator;

int64 GetBlasWorkspaceLimit(const string& envvar_in_mb,
int64 default_value_in_bytes) {
return gpu_utils::GetWorkspaceLimit(envvar_in_mb, default_value_in_bytes);
}

// Encapsulate all of the shape, dtype etc. information that defines a unique
// batched matmul operation.
class BatchMatmulParameters {
public:
using Stream = se::Stream;
using DeviceMemoryBytes = se::DeviceMemory<uint8>;
BatchMatmulParameters(bool trans_a, bool trans_b, bool adj_a, bool adj_b,
uint64 m, uint64 n, uint64 k, uint64 batch_count,
bool broadcast_a, bool broadcast_b, DataType dtype_ab,
DataType dtype_cd, bool allow_tf32, int device_id)
: trans_a_(trans_a),
trans_b_(trans_b),
adj_a_(adj_a),
adj_b_(adj_b),
m_(m),
n_(n),
k_(k),
batch_count_(batch_count),
broadcast_a_(broadcast_a),
broadcast_b_(broadcast_b),
dtype_ab_(dtype_ab),
dtype_cd_(dtype_cd),
allow_tf32_(allow_tf32),
device_id_(device_id) {
hash_code_ = trans_a;
hash_code_ = Hash64Combine(hash_code_, trans_b);
hash_code_ = Hash64Combine(hash_code_, adj_a);
hash_code_ = Hash64Combine(hash_code_, adj_b);
hash_code_ = Hash64Combine(hash_code_, m);
hash_code_ = Hash64Combine(hash_code_, n);
hash_code_ = Hash64Combine(hash_code_, k);
hash_code_ = Hash64Combine(hash_code_, batch_count);
hash_code_ = Hash64Combine(hash_code_, broadcast_a);
hash_code_ = Hash64Combine(hash_code_, broadcast_b);
hash_code_ = Hash64Combine(hash_code_, dtype_ab);
hash_code_ = Hash64Combine(hash_code_, dtype_cd);
hash_code_ = Hash64Combine(hash_code_, allow_tf32);
hash_code_ = Hash64Combine(hash_code_, device_id);
}
bool operator==(const BatchMatmulParameters& other) const {
return this->get_data_as_tuple() == other.get_data_as_tuple();
}

BlasScratchAllocator(OpKernelContext* context) : context_(context) {}
bool operator!=(const BatchMatmulParameters& other) const {
return !(*this == other);
}
uint64 hash() const { return hash_code_; }

int64 GetMemoryLimitInBytes() override { return -1; }

se::port::StatusOr<DeviceMemoryBytes> AllocateBytes(
int64 byte_size) override {
Tensor temporary_memory;

Status allocation_status(context_->allocate_temp(
DT_UINT8, TensorShape({byte_size}), &temporary_memory));
if (!allocation_status.ok()) {
return se::port::StatusOr<DeviceMemoryBytes>(
DeviceMemoryBytes::MakeFromByteSize(nullptr, 0));
}
// Hold the reference of the allocated tensors until the end of the
// allocator.
allocated_tensors_.push_back(temporary_memory);
return se::port::StatusOr<DeviceMemoryBytes>(
DeviceMemoryBytes::MakeFromByteSize(
temporary_memory.flat<uint8>().data(),
temporary_memory.flat<uint8>().size()));
string ToString() const {
// clang-format off
return strings::StrCat(
trans_a_, ", ", trans_b_, ", ", adj_a_, ", ", adj_b_, ", ",
m_, ", ", n_, ", ", k_, ", ", batch_count_, ", ",
broadcast_a_, ", ", broadcast_b_, ", ",
dtype_ab_, ", ", dtype_cd_, ", ", allow_tf32_, ", ", device_id_);
// clang-format on
}

private:
OpKernelContext* context_;
std::vector<Tensor> allocated_tensors_;
typedef std::tuple<bool, bool, bool, bool, int64, int64, int64, int64, bool,
bool, DataType, DataType, bool, int>
ParameterDataType;

ParameterDataType get_data_as_tuple() const {
return std::make_tuple(trans_a_, trans_b_, adj_a_, adj_b_, m_, n_, k_,
batch_count_, broadcast_a_, broadcast_b_, dtype_ab_,
dtype_cd_, allow_tf32_, device_id_);
}

bool trans_a_;
bool trans_b_;
bool adj_a_;
bool adj_b_;
uint64 m_;
uint64 n_;
uint64 k_;
uint64 batch_count_;
bool broadcast_a_;
bool broadcast_b_;
DataType dtype_ab_;
DataType dtype_cd_;
bool allow_tf32_;
int device_id_;
uint64 hash_code_;
};

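// A minimal sketch of how these parameters are meant to be consumed (the
// names below are illustrative, not part of this change): hash() gives a fast
// bucket index while operator== resolves collisions exactly, so the object
// can key an unordered_map such as the plan/autotune caches added further
// down:
//
//   struct ParamsHasher {
//     std::size_t operator()(const BatchMatmulParameters& p) const {
//       return p.hash();
//     }
//   };
//   std::unordered_map<BatchMatmulParameters, se::blas::AlgorithmConfig,
//                      ParamsHasher> example_cache;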
bool GetBlasComputationType(const DataType& dtype, bool allow_tf32,
se::blas::ComputationType* compute_type) {
using se::blas::ComputationType;
static bool use_f32_for_f16_computation = MatmulDoFP32ComputationFP16Input();
ComputationType f32_type =
allow_tf32 ? ComputationType::kTF32AsF32 : ComputationType::kF32;
switch (dtype) {
case DT_HALF:
case DT_BFLOAT16:
*compute_type =
use_f32_for_f16_computation ? f32_type : ComputationType::kF16;
return true;
case DT_FLOAT:
*compute_type = f32_type;
return true;
case DT_DOUBLE:
*compute_type = ComputationType::kF64;
return true;
case DT_COMPLEX64:
*compute_type = f32_type;
return true;
case DT_COMPLEX128:
*compute_type = ComputationType::kComplexF64;
return true;
default:
// Unsupported compute_type, return false.
return false;
}
}

// Thread-safe map from matmul parameters to their corresponding plan and
// algorithms.
template <typename Parameters>
class BlasLtMatmulPlanMap {
public:
struct PlanAndAlgorithms {
std::unique_ptr<se::blas::IBlasLtMatmulPlan> plan;
std::vector<std::unique_ptr<se::blas::IBlasLtMatmulAlgorithm>> algorithms;
};

const PlanAndAlgorithms* Find(const Parameters& params) {
mutex_lock lock(mu_);
auto iter = params_plan_map_.find(params);
if (iter == params_plan_map_.end()) {
return nullptr;
}
return &iter->second;
}
const PlanAndAlgorithms* Insert(const Parameters& params,
PlanAndAlgorithms value) {
mutex_lock lock(mu_);
return &params_plan_map_.emplace(params, std::move(value)).first->second;
}

private:
struct Hasher {
std::size_t operator()(const Parameters& parameter) const {
return parameter.hash();
}
};

mutable mutex mu_;
std::unordered_map<Parameters, PlanAndAlgorithms, Hasher> params_plan_map_
GUARDED_BY(mu_);
};

template <typename Parameters>
struct BlasLtPlanMapSingleton {
typedef BlasLtMatmulPlanMap<Parameters> PlanMapType;
static PlanMapType* GetInstance() {
static PlanMapType* instance = new PlanMapType();
return instance;
}
};

typedef BlasLtPlanMapSingleton<BatchMatmulParameters>
BatchMatmulPlanMapSingleton;

// A dummy type to group matmul autotune results together.
struct BatchMatmulAutoTuneGroup {
static string name() { return "MatmulLt"; }
};

typedef AutoTuneSingleton<BatchMatmulAutoTuneGroup, BatchMatmulParameters,
se::blas::AlgorithmConfig>
AutoTuneBatchMatmul;

template <typename Scalar>
struct CoefficientType {
typedef Scalar type;
};
template <>
struct CoefficientType<Eigen::half> {
typedef float type;
};

inline Status FromExecutorStatus(const se::port::Status& s) {
return s.ok() ? Status::OK()
: Status(static_cast<error::Code>(static_cast<int>(s.code())),
s.error_message());
}

template <typename T>
inline Status FromExecutorStatus(const se::port::StatusOr<T>& s) {
return FromExecutorStatus(s.status());
}

}  // namespace

template <typename Scalar>
struct LaunchBatchMatMul<GPUDevice, Scalar> {
static void Launch(OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
bool trans_y, const MatMulBCast& bcast, Tensor* out) {
bool trans_y, const MatMulBCast& bcast, bool use_autotune,
Tensor* out) {
se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
se::blas::Transpose::kTranspose,
se::blas::Transpose::kConjugateTranspose};
@@ -343,10 +517,191 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
auto* a_base_ptr = in_x.template flat<Scalar>().data();
auto* b_base_ptr = in_y.template flat<Scalar>().data();
auto* c_base_ptr = out->template flat<Scalar>().data();
uint64 a_stride;
uint64 b_stride;
uint64 c_stride;
int64 a_stride;
int64 b_stride;
int64 c_stride;

typedef typename CoefficientType<Scalar>::type Coefficient;

static const int64 max_scratch_size = GetBlasWorkspaceLimit(
"TF_CUBLAS_WORKSPACE_LIMIT_IN_MB", 1LL << 32);  // 4GB by default

// The BlasLtMatmul routines are only supported from CUDA 11.0 onward.
#if GOOGLE_CUDA && CUDA_VERSION >= 11000
bool is_full_broadcast =
std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
bool requires_mixed_broadcasting =
bcast.IsBroadcastingRequired() && !is_full_broadcast;
if (!requires_mixed_broadcasting) {
bool broadcast_a = bcast.x_batch_size() == 1;
bool broadcast_b = bcast.y_batch_size() == 1;
a_stride = broadcast_a ? 0 : m * k;
b_stride = broadcast_b ? 0 : k * n;
c_stride = m * n;
a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
a_ptrs.push_back(&a_device_memory.back());
b_ptrs.push_back(&b_device_memory.back());
c_ptrs.push_back(&c_device_memory.back());

DataType dtype = DataTypeToEnum<Scalar>::value;
bool allow_tf32 = tensor_float_32_execution_enabled();
int device_id = stream->parent()->device_ordinal();
BatchMatmulParameters matmul_parameters(
trans_x, trans_y, adj_x, adj_y, m, n, k, batch_size, broadcast_a,
broadcast_b, dtype, dtype, allow_tf32, device_id);

static const bool max_autotune_algorithm_count =
MatmulMaxAutotuneAlgorithmCount();
int max_algorithm_count = use_autotune ? max_autotune_algorithm_count : 1;

const auto* plan_and_algorithms =
BatchMatmulPlanMapSingleton::GetInstance()->Find(matmul_parameters);
if (!plan_and_algorithms) {
se::blas::DataType blas_dtype = se::blas::ToDataType<Scalar>::value;
se::blas::ComputationType computation_type;
OP_REQUIRES(
context,
GetBlasComputationType(dtype, allow_tf32, &computation_type),
errors::Internal("Unsupported dtype for batched matmul"));

auto status_or_plan = stream->parent()->CreateBlasLtMatmulPlan(
{/*ab_type=*/blas_dtype,
/*c_type=*/blas_dtype, computation_type,
se::blas::PointerMode::kHost, se::blas::Epilogue::kDefault,
blas_transpose_b, blas_transpose_a, n, m, k,
/*lda=*/in_y.dim_size(2), /*ldb=*/in_x.dim_size(2),
/*ldc=*/static_cast<int64>(n), static_cast<int>(batch_size),
b_stride, a_stride, c_stride});
OP_REQUIRES(context, status_or_plan.ok(),
FromExecutorStatus(status_or_plan));
std::unique_ptr<se::blas::IBlasLtMatmulPlan> plan =
status_or_plan.ConsumeValueOrDie();

auto status_or_algorithms = stream->parent()->GetBlasLtMatmulAlgorithms(
plan.get(), max_scratch_size, max_algorithm_count);
OP_REQUIRES(context, status_or_algorithms.ok(),
FromExecutorStatus(status_or_algorithms));
auto algorithms = status_or_algorithms.ConsumeValueOrDie();

plan_and_algorithms =
BatchMatmulPlanMapSingleton::GetInstance()->Insert(
matmul_parameters, {std::move(plan), std::move(algorithms)});
}
const auto& plan = plan_and_algorithms->plan;
const auto& algorithms = plan_and_algorithms->algorithms;

// The BlasLtMatmul routines (unlike BlasGemm, BlasGemmBatched etc.) take
// alpha and beta with the same type as the matrices.
Scalar alpha(1.0);
Scalar beta(0.0);

// Note that algorithm_config.algorithm() here is used to refer
// to the index within the algorithms vector, not the algorithm
// itself.
se::blas::AlgorithmConfig algorithm_config(se::blas::kNoAlgorithm);
if (max_algorithm_count == 1) {
algorithm_config.set_algorithm(0);
} else if (!AutoTuneBatchMatmul::GetInstance()->Find(matmul_parameters,
&algorithm_config)) {
VLOG(4) << "Autotuning BlasLtMatmul over " << algorithms.size()
<< " algorithms.";
se::blas::ProfileResult best_result;
se::blas::ProfileResult profile_result;
// for (const auto& profile_algorithm : plan_and_algorithms->algorithms)
// {
for (size_t i = 0; i != algorithms.size(); ++i) {
const auto& profile_algorithm = algorithms[i];
// Create a new scratch allocator with every autotuning run so that
// scratch space is deallocated between runs.
BlasScratchAllocator scratch_allocator(max_scratch_size, context);

bool cublas_launch_status =
stream
->ThenBlasLtMatmul(plan.get(), alpha, *b_ptrs[0], *a_ptrs[0],
beta, c_ptrs[0], &scratch_allocator,
profile_algorithm.get(), {},
&profile_result)
.ok();

VLOG(4) << "  Autotune algorithm " << i
<< " result: " << profile_result.elapsed_time_in_ms()
<< " ms, valid=" << profile_result.is_valid()
<< ", workspace_size=" << profile_algorithm->workspace_size();

if (cublas_launch_status && profile_result.is_valid() &&
profile_result.elapsed_time_in_ms() <
best_result.elapsed_time_in_ms()) {
best_result = profile_result;
}
}

if (best_result.is_valid()) {
algorithm_config.set_algorithm(best_result.algorithm());
}
// We make sure that each matmul parameter set only gets one pass of
// autotune. If no algorithms works, we add kNoAlgorithm to the autotune
// map.
AutoTuneBatchMatmul::GetInstance()->Insert(matmul_parameters,
algorithm_config);
}
se::blas::AlgorithmType algorithm_idx = algorithm_config.algorithm();
OP_REQUIRES(context,
0 <= algorithm_idx && algorithm_idx < algorithms.size(),
errors::Internal("Missing/invalid BatchMatmul algorithm"));
const auto& algorithm = algorithms[algorithm_idx];
BlasScratchAllocator scratch_allocator(max_scratch_size, context);
bool cublas_launch_status =
stream
->ThenBlasLtMatmul(plan.get(), alpha, *b_ptrs[0], *a_ptrs[0],
beta, c_ptrs[0], &scratch_allocator,
algorithm.get())
.ok();
if (!cublas_launch_status) {
context->SetStatus(errors::Internal(
"Blas batched matmul launch failed : a.shape=(",
bcast.x_batch_size(), ", ", in_x.dim_size(0), ", ",
in_x.dim_size(1), "), b.shape=(", bcast.y_batch_size(), ", ",
in_y.dim_size(0), ", ", in_y.dim_size(1), "), m=", m, ", n=", n,
", k=", k, ", batch_size=", batch_size));
}
} else {  // requires mixed broadcasting
const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
for (int64 i = 0; i < bcast.x_batch_size(); ++i) {
a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
}
for (int64 i = 0; i < bcast.y_batch_size(); ++i) {
b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
}
for (int64 i = 0; i < batch_size; ++i) {
c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
c_ptrs.push_back(&c_device_memory.back());
}

BlasScratchAllocator scratch_allocator(max_scratch_size, context);
bool blas_launch_status =
stream
->ThenBlasGemmBatchedWithScratch(
blas_transpose_b, blas_transpose_a, n, m, k,
static_cast<Coefficient>(1.0), b_ptrs,
adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
&scratch_allocator)
.ok();
if (!blas_launch_status) {
context->SetStatus(errors::Internal(
"Blas xGEMMBatched launch failed : a.shape=",
in_x.shape().DebugString(),
", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
", k=", k, ", batch_size=", batch_size));
}
}
return;
#else  // if not GOOGLE_CUDA or CUDA_VERSION < 11000
bool is_full_broadcast =
std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
bool use_strided_batched =
@@ -388,8 +743,6 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
}
}

typedef Scalar Coefficient;

// Blas does
// C = A x B
// where A, B and C are assumed to be in column major.
@@ -399,7 +752,10 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
if (batch_size == 1) {
// This is a regular matrix*matrix or matrix*vector multiply. Avoid the
// overhead of the scratch allocator and the batch interface.
if (n == 1 &&
// Note that the GEMV call here does not support Eigen::half, so we do not
// use this path in that case. A workaround is applied to the pointers
// passed to the call itself to avoid compilation errors.
if (!std::is_same<Scalar, Eigen::half>::value && n == 1 &&
blas_transpose_b != se::blas::Transpose::kConjugateTranspose &&
blas_transpose_a != se::blas::Transpose::kConjugateTranspose) {
// This is a matrix*vector multiply so use GEMV to compute A * b.
@@ -410,13 +766,19 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
auto gemv_trans_a = blas_transpose_a == se::blas::Transpose::kTranspose
? se::blas::Transpose::kNoTranspose
: se::blas::Transpose::kTranspose;
// Cast pointers as a workaround for GEMV not supporting Eigen::half
// (this will never actually be executed for Eigen::half).
typedef se::DeviceMemory<Coefficient> NonHalfDeviceMemoryType;
NonHalfDeviceMemoryType a_ptr(*(a_ptrs[0]));
NonHalfDeviceMemoryType b_ptr(*(b_ptrs[0]));
NonHalfDeviceMemoryType c_ptr(*(c_ptrs[0]));
bool blas_launch_status =
stream
->ThenBlasGemv(gemv_trans_a, adj_x || trans_x ? m : k,
adj_x || trans_x ? k : m,
static_cast<Coefficient>(1.0), *(a_ptrs[0]),
adj_x || trans_x ? m : k, *(b_ptrs[0]), 1,
static_cast<Coefficient>(0.0), c_ptrs[0], 1)
static_cast<Coefficient>(1.0), a_ptr,
adj_x || trans_x ? m : k, b_ptr, 1,
static_cast<Coefficient>(0.0), &c_ptr, 1)
.ok();
if (!blas_launch_status) {
context->SetStatus(errors::Internal(
@@ -459,154 +821,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
", k=", k, ", batch_size=", batch_size));
}
} else {
BlasScratchAllocator scratch_allocator(context);
bool blas_launch_status =
stream
->ThenBlasGemmBatchedWithScratch(
blas_transpose_b, blas_transpose_a, n, m, k,
static_cast<Coefficient>(1.0), b_ptrs,
adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
&scratch_allocator)
.ok();
if (!blas_launch_status) {
context->SetStatus(errors::Internal(
"Blas xGEMMBatched launch failed : a.shape=",
in_x.shape().DebugString(),
", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
", k=", k, ", batch_size=", batch_size));
}
}
}
};

template <>
struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
static void Launch(OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
bool trans_y, const MatMulBCast& bcast, Tensor* out) {
typedef Eigen::half Scalar;
se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
se::blas::Transpose::kTranspose,
se::blas::Transpose::kConjugateTranspose};
const uint64 m = in_x.dim_size(adj_x || trans_x ? 2 : 1);
const uint64 k = in_x.dim_size(adj_x || trans_x ? 1 : 2);
const uint64 n = in_y.dim_size(adj_y || trans_y ? 1 : 2);
const uint64 batch_size = bcast.output_batch_size();
auto blas_transpose_a = trans[adj_x ? 2 : (trans_x ? 1 : 0)];
auto blas_transpose_b = trans[adj_y ? 2 : (trans_y ? 1 : 0)];

auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

typedef perftools::gputools::DeviceMemory<Scalar> DeviceMemoryType;
std::vector<DeviceMemoryType> a_device_memory;
std::vector<DeviceMemoryType> b_device_memory;
std::vector<DeviceMemoryType> c_device_memory;
std::vector<DeviceMemoryType*> a_ptrs;
std::vector<DeviceMemoryType*> b_ptrs;
std::vector<DeviceMemoryType*> c_ptrs;
a_device_memory.reserve(bcast.x_batch_size());
b_device_memory.reserve(bcast.y_batch_size());
c_device_memory.reserve(batch_size);
a_ptrs.reserve(batch_size);
b_ptrs.reserve(batch_size);
c_ptrs.reserve(batch_size);
auto* a_base_ptr = in_x.template flat<Scalar>().data();
auto* b_base_ptr = in_y.template flat<Scalar>().data();
auto* c_base_ptr = out->template flat<Scalar>().data();

uint64 a_stride;
uint64 b_stride;
uint64 c_stride;

bool is_full_broadcast =
std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
bool use_strided_batched =
(!bcast.IsBroadcastingRequired() || is_full_broadcast) &&
batch_size > 1;
if (use_strided_batched) {
a_stride = bcast.x_batch_size() != 1 ? m * k : 0;
b_stride = bcast.y_batch_size() != 1 ? k * n : 0;
c_stride = m * n;
a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
a_ptrs.push_back(&a_device_memory.back());
b_ptrs.push_back(&b_device_memory.back());
c_ptrs.push_back(&c_device_memory.back());
} else if (!bcast.IsBroadcastingRequired()) {
for (int64 i = 0; i < batch_size; ++i) {
a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
a_ptrs.push_back(&a_device_memory.back());
b_ptrs.push_back(&b_device_memory.back());
c_ptrs.push_back(&c_device_memory.back());
}
} else {
const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
for (int64 i = 0; i < bcast.x_batch_size(); ++i) {
a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
}
for (int64 i = 0; i < bcast.y_batch_size(); ++i) {
b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
}
for (int64 i = 0; i < batch_size; ++i) {
c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
c_ptrs.push_back(&c_device_memory.back());
}
}

typedef float Coefficient;

// Blas does
// C = A x B
// where A, B and C are assumed to be in column major.
// We want the output to be in row-major, so we can compute
// C' = B' x A', where ' stands for transpose (not adjoint).
// TODO(yangzihao): Choose the best of the three strategies using autotune.
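// A concrete illustration of the identity above (illustrative only): a
// row-major [2, 3] A occupies the same bytes as a column-major 3x2 matrix,
// i.e. A'; similarly B ([3, 4] row-major == 4x3 column-major) and C ([2, 4]
// row-major == 4x2 column-major). Asking BLAS for the column-major product
// B' x A' (4x3 times 3x2) therefore writes exactly the bytes of the
// row-major C = A x B, which is why the calls below pass the B operand and
// dimension n first and the A operand and dimension m second.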
if (batch_size == 1) {
// This is a regular matrix*matrix or matrix*vector multiply. Avoid the
// overhead of the scratch allocator and the batch interface.
// TODO(benbarsdell): Use fp16 Gemv if it becomes supported by CUBLAS
bool blas_launch_status =
stream
->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
static_cast<Coefficient>(1.0), *(b_ptrs[0]),
adj_y || trans_y ? k : n, *(a_ptrs[0]),
adj_x || trans_x ? m : k,
static_cast<Coefficient>(0.0), c_ptrs[0], n)
.ok();
if (!blas_launch_status) {
context->SetStatus(errors::Internal(
"Blas xGEMM launch failed : a.shape=", in_x.shape().DebugString(),
", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
", k=", k));
}
} else if (use_strided_batched) {
bool blas_launch_status =
stream
->ThenBlasGemmStridedBatched(
blas_transpose_b, blas_transpose_a, n, m, k,
static_cast<Coefficient>(1.0), *b_ptrs[0],
adj_y || trans_y ? k : n, b_stride, *a_ptrs[0],
adj_x || trans_x ? m : k, a_stride,
static_cast<Coefficient>(0.0), c_ptrs[0], n, c_stride,
batch_size)
.ok();
if (!blas_launch_status) {
context->SetStatus(errors::Internal(
"Blas xGEMMStridedBatched launch failed : a.shape=",
in_x.shape().DebugString(),
", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
", k=", k, ", batch_size=", batch_size));
}
} else {
BlasScratchAllocator scratch_allocator(context);
BlasScratchAllocator scratch_allocator(max_scratch_size, context);
bool blas_launch_status =
stream
->ThenBlasGemmBatchedWithScratch(
@@ -624,6 +839,7 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
", k=", k, ", batch_size=", batch_size));
}
}
#endif  // not GOOGLE_CUDA or CUDA_VERSION < 11000
}
};

@@ -637,6 +853,7 @@ class BaseBatchMatMulOp : public OpKernel {
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
use_autotune_ = MatmulAutotuneEnable();
}

~BaseBatchMatMulOp() override {}
@@ -698,7 +915,7 @@ class BaseBatchMatMulOp : public OpKernel {
out->shape().DebugString()));
LaunchBatchMatMul<Device, Scalar>::Launch(
ctx, in0_reshaped, in1_reshaped, adj_x_, adj_y_, /*trans_x=*/false,
/*trans_y=*/false, bcast, &out_reshaped);
/*trans_y=*/false, bcast, use_autotune_, &out_reshaped);
}

protected:
@@ -708,6 +925,7 @@ class BaseBatchMatMulOp : public OpKernel {
private:
bool adj_x_;
bool adj_y_;
bool use_autotune_;
};

// BatchMatMul Op implementation which disallows broadcasting.
@@ -619,19 +619,7 @@ template struct LaunchConv2DOp<CPUDevice, double>;

int64 GetDnnWorkspaceLimit(const string& envvar_in_mb,
int64 default_value_in_bytes) {
const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
if (workspace_limit_in_mb_str != nullptr &&
strcmp(workspace_limit_in_mb_str, "") != 0) {
int64 scratch_limit_in_mb = -1;
if (strings::safe_strto64(workspace_limit_in_mb_str,
&scratch_limit_in_mb)) {
return scratch_limit_in_mb * (1 << 20);
} else {
LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
<< workspace_limit_in_mb_str;
}
}
return default_value_in_bytes;
return gpu_utils::GetWorkspaceLimit(envvar_in_mb, default_value_in_bytes);
}

// A dummy type to group forward convolution autotune results together.

@@ -48,52 +48,7 @@ int64 GetDnnWorkspaceLimit(const string& envvar_in_mb,
// A class to provide scratch-space allocator for Stream-Executor Cudnn
// callback. TensorFlow is responsible for releasing the temporary buffers after
// the kernel finishes.
class DnnScratchAllocator : public se::ScratchAllocator {
public:
virtual ~DnnScratchAllocator() {}
DnnScratchAllocator(int64 memory_limit, OpKernelContext* context)
: memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}
int64 GetMemoryLimitInBytes() override { return memory_limit_; }
se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes(
int64 byte_size) override {
Tensor temporary_memory;
if (byte_size < 0) {
return se::port::Status{se::port::error::INVALID_ARGUMENT,
"Requested negative byte size!"};
}
if (byte_size > memory_limit_) {
return se::port::Status{se::port::error::UNAVAILABLE,
absl::StrCat("Requested memory size (", byte_size,
") exceeds the max memory limit (",
memory_limit_, ").")};
}
AllocationAttributes allocation_attr;
allocation_attr.retry_on_failure = false;
Status allocation_status(context_->allocate_temp(
DT_UINT8, TensorShape({byte_size}), &temporary_memory,
AllocatorAttributes(), allocation_attr));
if (!allocation_status.ok()) {
return se::port::Status{
se::port::error::UNAVAILABLE,
absl::StrCat("Failed to allocate the requested memory size (",
byte_size, ").")};
}
// Hold the reference of the allocated tensors until the end of the
// allocator.
allocated_tensors_.push_back(temporary_memory);
total_byte_size_ += byte_size;
return se::port::StatusOr<se::DeviceMemory<uint8>>(
AsDeviceMemory(temporary_memory.flat<uint8>().data(),
temporary_memory.flat<uint8>().size()));
}
int64 TotalByteSize() { return total_byte_size_; }

private:
int64 memory_limit_;
int64 total_byte_size_;
OpKernelContext* context_;
std::vector<Tensor> allocated_tensors_;
};
using DnnScratchAllocator = GpuScratchAllocator;

// Encapsulate all the shape information that is used in both forward and
// backward conv operations.
@@ -31,6 +31,7 @@ limitations under the License.

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
(defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

@@ -400,20 +401,7 @@ class CufftScratchAllocator : public se::ScratchAllocator {

int64 GetCufftWorkspaceLimit(const string& envvar_in_mb,
int64 default_value_in_bytes) {
const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
if (workspace_limit_in_mb_str != nullptr &&
strcmp(workspace_limit_in_mb_str, "") != 0) {
int64 scratch_limit_in_mb = -1;
Status status = ReadInt64FromEnvVar(envvar_in_mb, default_value_in_bytes,
&scratch_limit_in_mb);
if (!status.ok()) {
LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
<< workspace_limit_in_mb_str;
} else {
return scratch_limit_in_mb * (1 << 20);
}
}
return default_value_in_bytes;
return gpu_utils::GetWorkspaceLimit(envvar_in_mb, default_value_in_bytes);
}

class FFTGPUBase : public FFTBase {
@@ -22,6 +22,7 @@ limitations under the License.
#include "google/protobuf/any.pb.h"
#include "absl/algorithm/container.h"
#include "absl/base/call_once.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/logger.h"
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/core/protobuf/conv_autotuning.pb.h"
@@ -282,6 +283,64 @@ Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
return Status::OK();
}

namespace gpu_utils {
int64 GetWorkspaceLimit(const string& envvar_in_mb,
int64 default_value_in_bytes) {
const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
if (workspace_limit_in_mb_str != nullptr &&
strcmp(workspace_limit_in_mb_str, "") != 0) {
int64 scratch_limit_in_mb = -1;
if (strings::safe_strto64(workspace_limit_in_mb_str,
&scratch_limit_in_mb)) {
return scratch_limit_in_mb * (1 << 20);
} else {
LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
<< workspace_limit_in_mb_str;
}
}
return default_value_in_bytes;
}
}  // namespace gpu_utils
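// Worked example of the conversion above (values illustrative): with
// TF_CUBLAS_WORKSPACE_LIMIT_IN_MB=64 set in the environment,
// GetWorkspaceLimit("TF_CUBLAS_WORKSPACE_LIMIT_IN_MB", 1LL << 32) returns
// 64 * (1 << 20) = 67108864 bytes; if the variable is unset or does not parse
// as an integer, the caller-supplied default (here 1LL << 32, i.e. 4 GiB) is
// returned unchanged.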

GpuScratchAllocator::GpuScratchAllocator(int64 memory_limit,
OpKernelContext* context)
: memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}

se::port::StatusOr<se::DeviceMemory<uint8>> GpuScratchAllocator::AllocateBytes(
int64 byte_size) {
Tensor temporary_memory;
if (byte_size < 0) {
return se::port::Status{se::port::error::INVALID_ARGUMENT,
"Requested negative byte size!"};
}
if (byte_size > memory_limit_) {
return se::port::Status{
se::port::error::UNAVAILABLE,
absl::StrCat("Requested memory size (", byte_size,
") exceeds the max memory limit (", memory_limit_, ").")};
}
AllocationAttributes allocation_attr;
allocation_attr.retry_on_failure = false;
Status allocation_status(context_->allocate_temp(
DT_UINT8, TensorShape({byte_size}), &temporary_memory,
AllocatorAttributes(), allocation_attr));
if (!allocation_status.ok()) {
return se::port::Status{
se::port::error::UNAVAILABLE,
absl::StrCat("Failed to allocate the requested memory size (",
byte_size, ").")};
}
// Hold the reference of the allocated tensors until the end of the
// allocator.
// NOTE: We expect tensors to be deallocated when this allocator goes out of
// scope when allocated_tensors is destructed.
allocated_tensors_.push_back(temporary_memory);
total_byte_size_ += byte_size;
return se::port::StatusOr<se::DeviceMemory<uint8>>(
AsDeviceMemory(temporary_memory.flat<uint8>().data(),
temporary_memory.flat<uint8>().size()));
}

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -243,6 +243,37 @@ void LogFusedConvForwardAutotuneResults(
Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
se::dnn::AlgorithmConfig* algo);

namespace gpu_utils {
// Get a workspace limit from the environment variable, which is in MB.
// Return the workspace memory limit in bytes. If no value is set, return the
// default value.
int64 GetWorkspaceLimit(const string& envvar_in_mb,
int64 default_value_in_bytes);
}  // namespace gpu_utils

// A class to provide scratch-space allocator for Stream-Executor callbacks in
// CUDA libraries (CUDNN etc.).
// TensorFlow is responsible for releasing the temporary buffers after
// the kernel finishes.
class GpuScratchAllocator : public se::ScratchAllocator {
public:
virtual ~GpuScratchAllocator() {}

GpuScratchAllocator(int64 memory_limit, OpKernelContext* context);

int64 GetMemoryLimitInBytes() override { return memory_limit_; }

se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes(
int64 byte_size) override;

int64 TotalByteSize() { return total_byte_size_; }

private:
int64 memory_limit_;
int64 total_byte_size_;
OpKernelContext* context_;
std::vector<Tensor> allocated_tensors_;
};
}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
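A minimal usage sketch for the relocated allocator follows; the op context, environment variable and stream call are assumptions for illustration, not part of this change:

// Inside some GPU kernel's Compute(), with `context` an OpKernelContext*:
int64 limit = gpu_utils::GetWorkspaceLimit("TF_CUDNN_WORKSPACE_LIMIT_IN_MB",
1LL << 32);  // 4 GiB default
GpuScratchAllocator scratch_allocator(limit, context);
// scratch_allocator is then passed to a StreamExecutor call that may need
// temporary workspace (e.g. a batched GEMM or a cuDNN convolution); the
// temporaries it hands out stay alive until the allocator is destroyed.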
@@ -549,7 +549,7 @@ struct EinsumHelper {
static Status ContractOperands(OpKernelContext* ctx,
absl::Span<const Tensor> inputs,
absl::Span<const bool> swap_free_and_contract,
Tensor* output) {
bool use_autotune, Tensor* output) {
if (inputs.size() == 1)
return CopyFrom(inputs[0], inputs[0].shape(), output);
MatMulBCast bcast(inputs[0].shape().dim_sizes(),
@@ -583,7 +583,7 @@ struct EinsumHelper {
ReshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped));
LaunchBatchMatMul<Device, T>::Launch(ctx, lhs, rhs, /*adj_x=*/false,
/*adj_y=*/false, trans_x, trans_y,
bcast, &output_reshaped);
bcast, use_autotune, &output_reshaped);
return Status::OK();
}
};
@@ -598,6 +598,7 @@ class EinsumOp : public OpKernel {
equation_, &input_labels_, &output_labels_, &label_types_,
&input_label_counts_, &output_label_counts_,
&input_has_ellipsis_, &output_has_ellipsis_));
use_autotune_ = MatmulAutotuneEnable();
}

void Compute(OpKernelContext* ctx) override {
@@ -640,7 +641,7 @@ class EinsumOp : public OpKernel {
Tensor contraction_output_reshaped;
OP_REQUIRES_OK(ctx, EinsumHelper::ContractOperands<Device, T>(
ctx, inputs_reduced, swap_free_and_contract,
&contraction_output_reshaped));
use_autotune_, &contraction_output_reshaped));

// Copy the batch labels from the contraction output. Recover the batch
// shape, which may have been broadcasted.
@@ -738,6 +739,7 @@ class EinsumOp : public OpKernel {
LabelCounts output_label_counts_;
gtl::InlinedVector<bool, 2> input_has_ellipsis_;
bool output_has_ellipsis_ = false;
bool use_autotune_;
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -48,4 +48,22 @@ bool MatmulDoFP32ComputationFP16Input() {
return value;
}

int MatmulMaxAutotuneAlgorithmCount() {
int64 value;
// In CUDA 11, cublasLtMatmulAlgoGetHeuristic typically returns <= 4
// algorithms for a given configuration, so 10 seems like a reasonable default
// here.
Status status =
ReadInt64FromEnvVar("TF_MATMUL_AUTOTUNE_MAX_ALGORITHMS", 10, &value);
if (!status.ok()) {
LOG(ERROR) << status.error_message();
}
static constexpr const int kMaxValue = std::numeric_limits<int>::max();
if (value < 1 || value > kMaxValue) {
LOG(ERROR) << "Invalid value for TF_MATMUL_AUTOTUNE_MAX_ALGORITHMS: "
<< value << " is not in range [1, " << kMaxValue << "]";
}
return value;
}

}  // namespace tensorflow
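For reference, a sketch of how the new knob behaves, derived from the function above (values illustrative):

// TF_MATMUL_AUTOTUNE_MAX_ALGORITHMS unset -> MatmulMaxAutotuneAlgorithmCount() returns the default of 10
// TF_MATMUL_AUTOTUNE_MAX_ALGORITHMS=4     -> at most 4 BlasLtMatmul algorithms are profiled per parameter set
// TF_MATMUL_AUTOTUNE_MAX_ALGORITHMS=0     -> an error is logged (value outside [1, INT_MAX]) and 0 is returned as-is
// Autotuning itself only runs when MatmulAutotuneEnable() is true; otherwise callers request a single algorithm.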
@@ -22,6 +22,7 @@ namespace tensorflow {

bool MatmulAutotuneEnable();
bool MatmulDoFP32ComputationFP16Input();
int MatmulMaxAutotuneAlgorithmCount();

}  // namespace tensorflow


@@ -114,6 +114,10 @@ cuda_py_test(
name = "dirichlet_test",
size = "small",
srcs = ["dirichlet_test.py"],
tags = [
# b/170982175
"no_oss",
],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",

@@ -197,7 +197,7 @@ class DirichletTest(test.TestCase):

self.assertAllClose(sample_mean_, analytic_mean, atol=0.04, rtol=0.)
self.assertAllClose(sample_cov_, analytic_cov, atol=0.06, rtol=0.)
self.assertAllClose(sample_var_, analytic_var, atol=0.03, rtol=0.)
self.assertAllClose(sample_var_, analytic_var, atol=0.04, rtol=0.)
self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.)

@test_util.run_without_tensor_float_32(

@@ -587,7 +587,7 @@ class CSRSparseMatrixOpsTest(test.TestCase):
c_t_value, c_dense_t_value = self.evaluate((c_t, c_dense_t))

self.assertAllClose(
c_t_value, c_dense_t_value, rtol=1e-6, atol=1e-5)
c_t_value, c_dense_t_value, rtol=1e-6, atol=2e-5)

@test_util.run_in_graph_and_eager_modes
def testLargeBatchSparseMatrixMatMulTransposed(self):
@@ -650,7 +650,7 @@ class CSRSparseMatrixOpsTest(test.TestCase):
self.assertAllEqual(c_t.shape, c_dense_t.shape)
c_t_value, c_dense_t_value = self.evaluate((c_t, c_dense_t))
self.assertAllClose(
c_t_value, c_dense_t_value, rtol=1e-6, atol=1e-5)
c_t_value, c_dense_t_value, rtol=1e-6, atol=2e-5)

@test_util.run_in_graph_and_eager_modes
def testLargeBatchSparseMatrixMatMulConjugate(self):
@@ -61,7 +61,6 @@ cc_library(
"blas.h",
"device_description.h",
"device_options.h",
"dnn.h",
"event.cc",
"fft.h",
"kernel_cache_config.h",
@@ -103,7 +102,6 @@ cc_library(
cc_library(
name = "kernel",
srcs = [
"dnn.h",
"fft.h",
"kernel.cc",
"plugin.h",
@@ -300,6 +298,7 @@ cc_library(
name = "host_or_device_scalar",
hdrs = ["host_or_device_scalar.h"],
deps = [
":data_type",
":device_memory",
"//tensorflow/stream_executor/platform",
],
@@ -331,7 +330,6 @@ cc_library(
],
hdrs = [
"blas.h",
"dnn.h",
"executor_cache.h",
"fft.h",
"kernel.h",
@@ -423,11 +421,21 @@ tf_proto_library(
make_default_target_header_only = True,
)

cc_library(
name = "data_type",
hdrs = ["data_type.h"],
deps = [
":dnn_proto_cc",
"//tensorflow/stream_executor/platform",
],
)

cc_library(
name = "dnn",
srcs = ["dnn.cc"],
hdrs = ["dnn.h"],
deps = [
":data_type",
":device_memory",
":dnn_proto_cc",
":stream_executor_headers",
@@ -445,7 +453,6 @@ cc_library(
cc_library(
name = "stream_executor_internal",
srcs = [
"dnn.h",
"stream_executor_internal.cc",
],
hdrs = [
@@ -474,7 +481,6 @@ cc_library(
name = "stream_executor_pimpl_header",
hdrs = [
"device_description.h",
"dnn.h",
"kernel.h",
"kernel_cache_config.h",
"stream_executor_pimpl.h",
|
||||
return os << ComputationTypeString(ty);
|
||||
}
|
||||
|
||||
string DataTypeString(DataType ty) {
|
||||
switch (ty) {
|
||||
case DataType::kHalf:
|
||||
return "f16";
|
||||
case DataType::kFloat:
|
||||
return "f32";
|
||||
case DataType::kDouble:
|
||||
return "f64";
|
||||
case DataType::kInt8:
|
||||
return "i8";
|
||||
case DataType::kInt32:
|
||||
return "i32";
|
||||
case DataType::kComplexFloat:
|
||||
return "complex f32";
|
||||
case DataType::kComplexDouble:
|
||||
return "complex f64";
|
||||
default:
|
||||
LOG(FATAL) << "Unknown DataType " << static_cast<int32>(ty);
|
||||
}
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, DataType ty) {
|
||||
return os << DataTypeString(ty);
|
||||
}
|
||||
|
||||
} // namespace blas
|
||||
} // namespace stream_executor
|
||||
|
@@ -43,7 +43,7 @@ limitations under the License.
#include <complex>
#include <vector>

#include "tensorflow/stream_executor/host_or_device_scalar.h"
#include "tensorflow/stream_executor/dnn.h"  // For DataType, ToDataType
#include "tensorflow/stream_executor/lib/array_slice.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/platform/port.h"
@@ -60,6 +60,9 @@ class ScratchAllocator;
template <typename ElemT>
class DeviceMemory;

template <typename ElemT>
class HostOrDeviceScalar;

namespace blas {

// Specifies whether the input matrix will be transposed or
@@ -101,6 +104,18 @@ enum class ComputationType {
kI32,  // 32-bit integer
kComplexF32,  // Complex number comprised of two f32s.
kComplexF64,  // Complex number comprised of two f64s.
// The below values are only supported for BlasLt routines (both real and
// complex). They use float32 for accumulation but round the input mantissas
// to a smaller number of bits.
kTF32AsF32,  // 32-bit floating-point with reduced (>=10-bit) mantissa
kBF16AsF32,  // 32-bit floating-point with reduced (7-bit) mantissa
};

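// Informational note (not part of the API): kTF32AsF32 keeps float32's 8-bit
// exponent but rounds the stored mantissa to 10 bits, and kBF16AsF32 rounds
// it to bfloat16's 7 bits; in both cases products are still accumulated in
// full float32, which is why they sit alongside the F32 computation types.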
enum class Epilogue {
kDefault = 1,                   // No special postprocessing
kReLU = 2,                      // Apply ReLU func point-wise to the results
kBias = 4,                      // Add broadcasted bias vector to the results
kBiasThenReLU = kBias | kReLU,  // Apply bias and then ReLU transform
};

// Converts a ComputationType to a string.
@@ -108,6 +123,21 @@ std::string ComputationTypeString(ComputationType ty);

std::ostream &operator<<(std::ostream &os, ComputationType ty);

using dnn::DataType;
using dnn::ToDataType;

// Describes the type of pointers for the scaling factors alpha and beta in
// blaslt routines.
enum class PointerMode {
kHost,
kDevice,
};

// Converts a ComputationType to a string.
string DataTypeString(DataType ty);

std::ostream &operator<<(std::ostream &os, DataType ty);

// Opaque identifier for an "algorithm" used by a blas routine. This functions
// as a hint to the blas library.
typedef int64 AlgorithmType;
@@ -163,6 +193,44 @@ class AlgorithmConfig {
AlgorithmType algorithm_;
};

struct IBlasLtMatmulPlan {
// Returns the data type of the A and B (input) matrices.
virtual DataType ab_type() const = 0;
// Returns the data type of the C (input/output) matrix.
virtual DataType c_type() const = 0;
virtual ~IBlasLtMatmulPlan() {}
};

struct IBlasLtMatmulAlgorithm {
virtual ~IBlasLtMatmulAlgorithm() {}
// Returns the index of the algorithm within the list returned by
// GetBlasLtMatmulAlgorithms.
virtual AlgorithmType index() const = 0;
// Returns the workspace size required by the algorithm in bytes.
virtual size_t workspace_size() const = 0;
};

// Parameters for the CreateBlasLtMatmulPlan method.
struct BlasLtMatmulPlanParams {
DataType ab_type;
DataType c_type;
ComputationType computation_type;
PointerMode pointer_mode;
Epilogue epilogue;
Transpose transa;
Transpose transb;
uint64 m;
uint64 n;
uint64 k;
int64 lda;
int64 ldb;
int64 ldc;
int batch_count = 1;
int64 stride_a = 0;
int64 stride_b = 0;
int64 stride_c = 0;
};
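// Example of filling these parameters for a single (non-batched) fp16 matmul
// with fp32 accumulation; the concrete sizes are illustrative only:
//
//   BlasLtMatmulPlanParams p;
//   p.ab_type = DataType::kHalf;
//   p.c_type = DataType::kHalf;
//   p.computation_type = ComputationType::kF32;
//   p.pointer_mode = PointerMode::kHost;
//   p.epilogue = Epilogue::kDefault;
//   p.transa = Transpose::kNoTranspose;
//   p.transb = Transpose::kNoTranspose;
//   p.m = 128; p.n = 64; p.k = 256;         // column-major shapes
//   p.lda = 128; p.ldb = 256; p.ldc = 128;  // leading dims >= m, k, m
//   // batch_count (1) and stride_a/b/c (0) keep their defaults.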

// BLAS support interface -- this can be derived from a GPU executor when the
// underlying platform has an BLAS library implementation available. See
// StreamExecutor::AsBlas().
@@ -1383,6 +1451,71 @@ class BlasSupport {
const DeviceMemory<std::complex<double>> &a, int lda,
DeviceMemory<std::complex<double>> *b, int ldb) = 0;

// Creates a backend-specific plan object for a blaslt matmul operation, which
// can then be passed to DoBlasLtMatmul(). When possible, plans should be
// created once and reused for multiple calls to DoBlasLtMatmul().
virtual port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &params) = 0;

// Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
// returned in the order of increasing estimated compute time according to an
// internal heuristic. The first returned algorithm can be used as the default
// algorithm if no autotuning is to be performed.
virtual port::StatusOr<
std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
size_t max_workspace_size,
int max_algorithm_count) = 0;

// Executes a blaslt matmul operation on the stream. If output_profile_result
// is not nullptr, the operation is profiled, error messages are
// suppressed, and output_profile_result->algorithm() is set to
// algorithm->index(). If epilogue was set to kBias or kBiasThenReLU when
// creating the plan, the bias argument here must refer to a valid device
// vector of length equal to the number of rows in matrix c. If epilogue was
// set to any other value then the bias argument here must be null. The bias
// vector is broadcast across the batch dimension.
// Note that the data types of a and b (c and bias) must match the ab_type
// (c_type) with which the plan was created, and the data types of alpha and
// beta must match the data type of c.
virtual bool DoBlasLtMatmul(
Stream *stream, const blas::IBlasLtMatmulPlan *plan,
const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a,
DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta,
DeviceMemoryBase c, ScratchAllocator *scratch_allocator,
const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias,
blas::ProfileResult *output_profile_result) = 0;

template <typename ABType, typename CType>
bool DoBlasLtMatmul(Stream *stream, const blas::IBlasLtMatmulPlan *plan,
const HostOrDeviceScalar<CType> &alpha,
const DeviceMemory<ABType> &a,
const DeviceMemory<ABType> &b,
const HostOrDeviceScalar<CType> &beta,
DeviceMemory<CType> *c,
ScratchAllocator *scratch_allocator,
const blas::IBlasLtMatmulAlgorithm *algorithm,
const DeviceMemory<CType> &bias = {},
blas::ProfileResult *output_profile_result = nullptr) {
constexpr blas::DataType ab_type = blas::ToDataType<ABType>::value;
if (ab_type != plan->ab_type()) {
VLOG(2) << "DoBlasLtMatmul returning false because a and b type does "
"not match plan: expected "
<< plan->ab_type() << ", got " << ab_type;
return false;
}
constexpr blas::DataType c_type = blas::ToDataType<CType>::value;
if (c_type != plan->c_type()) {
VLOG(2) << "DoBlasLtMatmul returning false because c type does "
"not match plan: expected "
<< plan->c_type() << ", got " << c_type;
return false;
}
return DoBlasLtMatmul(stream, plan, alpha, a, b, beta, *c,
scratch_allocator, algorithm, bias,
output_profile_result);
}

virtual port::Status GetVersion(std::string *version) = 0;

protected:
@@ -2196,6 +2329,19 @@ class BlasSupport {
uint64 n, std::complex<double> alpha, \
const DeviceMemory<std::complex<double>> &a, int lda, \
DeviceMemory<std::complex<double>> *b, int ldb) override; \
port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>> \
CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &params) override; \
port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>> \
GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan, \
size_t max_workspace_size, \
int max_algorithm_count) override; \
bool DoBlasLtMatmul( \
Stream *stream, const blas::IBlasLtMatmulPlan *plan, \
const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a, \
DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta, \
DeviceMemoryBase c, ScratchAllocator *scratch_allocator, \
const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias, \
blas::ProfileResult *output_profile_result) override; \
port::Status GetVersion(std::string *version) override;

}  // namespace blas
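A rough end-to-end sketch of the new interface, mirroring how batch_matmul_op_impl.h drives it above; the surrounding variables and sizes are assumptions for illustration, and namespace qualifiers are abbreviated:

// Assuming `blas` is a BlasSupport*, `stream` a Stream*, `scratch` a
// ScratchAllocator*, a/b/c DeviceMemory<Eigen::half> buffers, and alpha/beta
// HostOrDeviceScalar<Eigen::half> scalars:
auto plan_or = blas->CreateBlasLtMatmulPlan(params);  // params as in the sketch above
if (!plan_or.ok()) { /* handle error */ }
std::unique_ptr<blas::IBlasLtMatmulPlan> plan = plan_or.ConsumeValueOrDie();

auto algorithms_or = blas->GetBlasLtMatmulAlgorithms(
plan.get(), /*max_workspace_size=*/1LL << 30, /*max_algorithm_count=*/10);
if (!algorithms_or.ok()) { /* handle error */ }
auto algorithms = algorithms_or.ConsumeValueOrDie();

// The list is ordered by the heuristic's estimate, so algorithms[0] is a
// reasonable default; autotuning callers profile each entry instead.
bool launched = blas->DoBlasLtMatmul(stream, plan.get(), alpha, a, b, beta,
&c, scratch, algorithms[0].get());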
@@ -251,6 +251,31 @@ alias(
visibility = ["//visibility:public"],
)

cc_library(
name = "cublas_lt_stub",
srcs = if_cuda_is_configured(["cublasLt_stub.cc"]),
textual_hdrs = glob(["cublasLt_*.inc"]),
deps = if_cuda_is_configured([
# LINT.IfChange
"@local_config_cuda//cuda:cublas_headers",
# LINT.ThenChange(//tensorflow/copy.bara.sky:cublasLt_headers)
"@local_config_cuda//cuda:cuda_headers",
"//tensorflow/stream_executor/lib",
"//tensorflow/stream_executor/platform:dso_loader",
]),
)

cc_library(name = "empty_lib")

alias(
name = "cublas_lt_lib",
actual = select({
"//tensorflow:oss": ":cublas_lt_stub",
"//conditions:default": ":empty_lib",
}),
visibility = ["//visibility:public"],
)

cc_library(
name = "cublas_plugin",
srcs = if_cuda_is_configured(["cuda_blas.cc"]),
@@ -258,6 +283,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = if_cuda_is_configured([
":cublas_lib",
":cublas_lt_lib",
":cuda_activation",
":cuda_gpu_executor",
":cuda_platform_id",
tensorflow/stream_executor/cuda/cublasLt_11_0.inc (new file, 390 lines)
@@ -0,0 +1,390 @@
// Auto-generated, do not edit.

extern "C" {

cublasStatus_t CUBLASWINAPI cublasLtCreate(cublasLtHandle_t *lightHandle) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasLtHandle_t *);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtCreate");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(lightHandle);
}

cublasStatus_t CUBLASWINAPI cublasLtDestroy(cublasLtHandle_t lightHandle) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasLtHandle_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtDestroy");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(lightHandle);
}

size_t CUBLASWINAPI cublasLtGetVersion(void) {
using FuncPtr = size_t(CUBLASWINAPI *)();
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtGetVersion");
if (!func_ptr) return 0;
return func_ptr();
}

size_t CUBLASWINAPI cublasLtGetCudartVersion(void) {
using FuncPtr = size_t(CUBLASWINAPI *)();
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtGetCudartVersion");
if (!func_ptr) return 0;
return func_ptr();
}

cublasStatus_t CUBLASWINAPI cublasLtGetProperty(libraryPropertyType type,
int *value) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(libraryPropertyType, int *);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtGetProperty");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(type, value);
}

cublasStatus_t CUBLASWINAPI cublasLtMatmul(
cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t computeDesc,
const void *alpha, /* host or device pointer */
const void *A, cublasLtMatrixLayout_t Adesc, const void *B,
cublasLtMatrixLayout_t Bdesc, const void *beta, /* host or device pointer */
const void *C, cublasLtMatrixLayout_t Cdesc, void *D,
cublasLtMatrixLayout_t Ddesc, const cublasLtMatmulAlgo_t *algo,
void *workspace, size_t workspaceSizeInBytes, cudaStream_t stream) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(
cublasLtHandle_t, cublasLtMatmulDesc_t, const void *, const void *,
cublasLtMatrixLayout_t, const void *, cublasLtMatrixLayout_t,
const void *, const void *, cublasLtMatrixLayout_t, void *,
cublasLtMatrixLayout_t, const cublasLtMatmulAlgo_t *, void *, size_t,
cudaStream_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmul");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(lightHandle, computeDesc, alpha, A, Adesc, B, Bdesc, beta, C,
Cdesc, D, Ddesc, algo, workspace, workspaceSizeInBytes,
stream);
}

cublasStatus_t CUBLASWINAPI cublasLtMatrixTransform(
cublasLtHandle_t lightHandle, cublasLtMatrixTransformDesc_t transformDesc,
const void *alpha, /* host or device pointer */
const void *A, cublasLtMatrixLayout_t Adesc,
const void *beta, /* host or device pointer */
const void *B, cublasLtMatrixLayout_t Bdesc, void *C,
cublasLtMatrixLayout_t Cdesc, cudaStream_t stream) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(
cublasLtHandle_t, cublasLtMatrixTransformDesc_t, const void *,
const void *, cublasLtMatrixLayout_t, const void *, const void *,
cublasLtMatrixLayout_t, void *, cublasLtMatrixLayout_t, cudaStream_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatrixTransform");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(lightHandle, transformDesc, alpha, A, Adesc, beta, B, Bdesc,
|
||||
C, Cdesc, stream);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutInit_internal( //
|
||||
cublasLtMatrixLayout_t matLayout, size_t size, cudaDataType type,
|
||||
uint64_t rows, uint64_t cols, int64_t ld) {
|
||||
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( //
|
||||
cublasLtMatrixLayout_t, size_t, cudaDataType, uint64_t, uint64_t,
|
||||
int64_t);
|
||||
static auto func_ptr =
|
||||
LoadSymbol<FuncPtr>("cublasLtMatrixLayoutInit_internal");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(matLayout, size, type, rows, cols, ld);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutCreate( //
|
||||
cublasLtMatrixLayout_t *matLayout, cudaDataType type, uint64_t rows,
|
||||
uint64_t cols, int64_t ld) {
|
||||
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( //
|
||||
cublasLtMatrixLayout_t *, cudaDataType, uint64_t, uint64_t, int64_t);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatrixLayoutCreate");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(matLayout, type, rows, cols, ld);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI
|
||||
cublasLtMatrixLayoutDestroy(cublasLtMatrixLayout_t matLayout) {
|
||||
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasLtMatrixLayout_t);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatrixLayoutDestroy");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(matLayout);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutSetAttribute( //
|
||||
cublasLtMatrixLayout_t matLayout, cublasLtMatrixLayoutAttribute_t attr,
|
||||
const void *buf, size_t sizeInBytes) {
|
||||
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( //
|
||||
cublasLtMatrixLayout_t, cublasLtMatrixLayoutAttribute_t, const void *,
|
||||
size_t);
|
||||
static auto func_ptr =
|
||||
LoadSymbol<FuncPtr>("cublasLtMatrixLayoutSetAttribute");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(matLayout, attr, buf, sizeInBytes);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutGetAttribute( //
|
||||
cublasLtMatrixLayout_t matLayout, cublasLtMatrixLayoutAttribute_t attr,
|
||||
void *buf, size_t sizeInBytes, size_t *sizeWritten) {
|
||||
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( //
|
||||
cublasLtMatrixLayout_t, cublasLtMatrixLayoutAttribute_t, void *, size_t,
|
||||
size_t *);
|
||||
static auto func_ptr =
|
||||
LoadSymbol<FuncPtr>("cublasLtMatrixLayoutGetAttribute");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(matLayout, attr, buf, sizeInBytes, sizeWritten);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescInit_internal( //
|
||||
cublasLtMatmulDesc_t matmulDesc, size_t size,
|
||||
cublasComputeType_t computeType, cudaDataType_t scaleType) {
|
||||
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( //
|
||||
cublasLtMatmulDesc_t, size_t, cublasComputeType_t, cudaDataType_t);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulDescInit_internal");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(matmulDesc, size, computeType, scaleType);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescCreate(
|
||||
cublasLtMatmulDesc_t *matmulDesc, cublasComputeType_t computeType,
|
||||
cudaDataType_t scaleType) {
|
||||
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(
|
||||
cublasLtMatmulDesc_t *, cublasComputeType_t, cudaDataType_t);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulDescCreate");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(matmulDesc, computeType, scaleType);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI
|
||||
cublasLtMatmulDescDestroy(cublasLtMatmulDesc_t matmulDesc) {
|
||||
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasLtMatmulDesc_t);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulDescDestroy");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(matmulDesc);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescSetAttribute( //
|
||||
cublasLtMatmulDesc_t matmulDesc, cublasLtMatmulDescAttributes_t attr,
|
||||
const void *buf, size_t sizeInBytes) {
|
||||
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( //
|
||||
cublasLtMatmulDesc_t, cublasLtMatmulDescAttributes_t, const void *,
|
||||
size_t);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulDescSetAttribute");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(matmulDesc, attr, buf, sizeInBytes);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescGetAttribute( //
|
||||
cublasLtMatmulDesc_t matmulDesc, cublasLtMatmulDescAttributes_t attr,
|
||||
void *buf, size_t sizeInBytes, size_t *sizeWritten) {
|
||||
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( //
|
||||
cublasLtMatmulDesc_t, cublasLtMatmulDescAttributes_t, void *, size_t,
|
||||
size_t *);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulDescGetAttribute");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(matmulDesc, attr, buf, sizeInBytes, sizeWritten);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescInit_internal(
|
||||
cublasLtMatrixTransformDesc_t transformDesc, size_t size,
|
||||
cudaDataType scaleType) {
|
||||
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasLtMatrixTransformDesc_t,
|
||||
size_t, cudaDataType);
|
||||
static auto func_ptr =
|
||||
LoadSymbol<FuncPtr>("cublasLtMatrixTransformDescInit_internal");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(transformDesc, size, scaleType);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescCreate(
|
||||
cublasLtMatrixTransformDesc_t *transformDesc, cudaDataType scaleType) {
|
||||
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(
|
||||
cublasLtMatrixTransformDesc_t *, cudaDataType);
|
||||
static auto func_ptr =
|
||||
LoadSymbol<FuncPtr>("cublasLtMatrixTransformDescCreate");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(transformDesc, scaleType);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescDestroy(
|
||||
cublasLtMatrixTransformDesc_t transformDesc) {
|
||||
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasLtMatrixTransformDesc_t);
|
||||
static auto func_ptr =
|
||||
LoadSymbol<FuncPtr>("cublasLtMatrixTransformDescDestroy");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(transformDesc);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescSetAttribute( //
|
||||
cublasLtMatrixTransformDesc_t transformDesc,
|
||||
cublasLtMatrixTransformDescAttributes_t attr, const void *buf,
|
||||
size_t sizeInBytes) {
|
||||
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( //
|
||||
cublasLtMatrixTransformDesc_t, cublasLtMatrixTransformDescAttributes_t,
|
||||
const void *, size_t);
|
||||
static auto func_ptr =
|
||||
LoadSymbol<FuncPtr>("cublasLtMatrixTransformDescSetAttribute");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(transformDesc, attr, buf, sizeInBytes);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescGetAttribute( //
|
||||
cublasLtMatrixTransformDesc_t transformDesc,
|
||||
cublasLtMatrixTransformDescAttributes_t attr, void *buf, size_t sizeInBytes,
|
||||
size_t *sizeWritten) {
|
||||
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( //
|
||||
cublasLtMatrixTransformDesc_t, cublasLtMatrixTransformDescAttributes_t,
|
||||
void *, size_t, size_t *);
|
||||
static auto func_ptr =
|
||||
LoadSymbol<FuncPtr>("cublasLtMatrixTransformDescGetAttribute");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(transformDesc, attr, buf, sizeInBytes, sizeWritten);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceInit_internal(
|
||||
cublasLtMatmulPreference_t pref, size_t size) {
|
||||
using FuncPtr =
|
||||
cublasStatus_t(CUBLASWINAPI *)(cublasLtMatmulPreference_t, size_t);
|
||||
static auto func_ptr =
|
||||
LoadSymbol<FuncPtr>("cublasLtMatmulPreferenceInit_internal");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(pref, size);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI
|
||||
cublasLtMatmulPreferenceCreate(cublasLtMatmulPreference_t *pref) {
|
||||
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasLtMatmulPreference_t *);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulPreferenceCreate");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(pref);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI
|
||||
cublasLtMatmulPreferenceDestroy(cublasLtMatmulPreference_t pref) {
|
||||
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasLtMatmulPreference_t);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulPreferenceDestroy");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(pref);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceSetAttribute( //
|
||||
cublasLtMatmulPreference_t pref, cublasLtMatmulPreferenceAttributes_t attr,
|
||||
const void *buf, size_t sizeInBytes) {
|
||||
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( //
|
||||
cublasLtMatmulPreference_t, cublasLtMatmulPreferenceAttributes_t,
|
||||
const void *, size_t);
|
||||
static auto func_ptr =
|
||||
LoadSymbol<FuncPtr>("cublasLtMatmulPreferenceSetAttribute");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(pref, attr, buf, sizeInBytes);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceGetAttribute( //
|
||||
cublasLtMatmulPreference_t pref, cublasLtMatmulPreferenceAttributes_t attr,
|
||||
void *buf, size_t sizeInBytes, size_t *sizeWritten) {
|
||||
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( //
|
||||
cublasLtMatmulPreference_t, cublasLtMatmulPreferenceAttributes_t, void *,
|
||||
size_t, size_t *);
|
||||
static auto func_ptr =
|
||||
LoadSymbol<FuncPtr>("cublasLtMatmulPreferenceGetAttribute");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(pref, attr, buf, sizeInBytes, sizeWritten);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoGetHeuristic(
|
||||
cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t operationDesc,
|
||||
cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc,
|
||||
cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc,
|
||||
cublasLtMatmulPreference_t preference, int requestedAlgoCount,
|
||||
cublasLtMatmulHeuristicResult_t heuristicResultsArray[],
|
||||
int *returnAlgoCount) {
|
||||
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(
|
||||
cublasLtHandle_t, cublasLtMatmulDesc_t, cublasLtMatrixLayout_t,
|
||||
cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, cublasLtMatrixLayout_t,
|
||||
cublasLtMatmulPreference_t, int, cublasLtMatmulHeuristicResult_t[],
|
||||
int *);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulAlgoGetHeuristic");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(lightHandle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc,
|
||||
preference, requestedAlgoCount, heuristicResultsArray,
|
||||
returnAlgoCount);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoGetIds(
|
||||
cublasLtHandle_t lightHandle, cublasComputeType_t computeType,
|
||||
cudaDataType_t scaleType, cudaDataType_t Atype, cudaDataType_t Btype,
|
||||
cudaDataType_t Ctype, cudaDataType_t Dtype, int requestedAlgoCount,
|
||||
int algoIdsArray[], int *returnAlgoCount) {
|
||||
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(
|
||||
cublasLtHandle_t, cublasComputeType_t, cudaDataType_t, cudaDataType_t,
|
||||
cudaDataType_t, cudaDataType_t, cudaDataType_t, int, int[], int *);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulAlgoGetIds");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(lightHandle, computeType, scaleType, Atype, Btype, Ctype,
|
||||
Dtype, requestedAlgoCount, algoIdsArray, returnAlgoCount);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoInit(
|
||||
cublasLtHandle_t lightHandle, cublasComputeType_t computeType,
|
||||
cudaDataType_t scaleType, cudaDataType_t Atype, cudaDataType_t Btype,
|
||||
cudaDataType_t Ctype, cudaDataType_t Dtype, int algoId,
|
||||
cublasLtMatmulAlgo_t *algo) {
|
||||
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(
|
||||
cublasLtHandle_t, cublasComputeType_t, cudaDataType_t, cudaDataType_t,
|
||||
cudaDataType_t, cudaDataType_t, cudaDataType_t, int,
|
||||
cublasLtMatmulAlgo_t *);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulAlgoInit");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(lightHandle, computeType, scaleType, Atype, Btype, Ctype,
|
||||
Dtype, algoId, algo);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoCheck( //
|
||||
cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t operationDesc,
|
||||
cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc,
|
||||
cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc,
|
||||
const cublasLtMatmulAlgo_t *algo, ///< may point to result->algo
|
||||
cublasLtMatmulHeuristicResult_t *result) {
|
||||
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( //
|
||||
cublasLtHandle_t, cublasLtMatmulDesc_t, cublasLtMatrixLayout_t,
|
||||
cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, cublasLtMatrixLayout_t,
|
||||
const cublasLtMatmulAlgo_t *, ///< may point to result->algo
|
||||
cublasLtMatmulHeuristicResult_t *);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulAlgoCheck");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(lightHandle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, algo,
|
||||
result);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoCapGetAttribute(
|
||||
const cublasLtMatmulAlgo_t *algo, cublasLtMatmulAlgoCapAttributes_t attr,
|
||||
void *buf, size_t sizeInBytes, size_t *sizeWritten) {
|
||||
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(
|
||||
const cublasLtMatmulAlgo_t *, cublasLtMatmulAlgoCapAttributes_t, void *,
|
||||
size_t, size_t *);
|
||||
static auto func_ptr =
|
||||
LoadSymbol<FuncPtr>("cublasLtMatmulAlgoCapGetAttribute");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(algo, attr, buf, sizeInBytes, sizeWritten);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoConfigSetAttribute(
|
||||
cublasLtMatmulAlgo_t *algo, cublasLtMatmulAlgoConfigAttributes_t attr,
|
||||
const void *buf, size_t sizeInBytes) {
|
||||
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(
|
||||
cublasLtMatmulAlgo_t *, cublasLtMatmulAlgoConfigAttributes_t,
|
||||
const void *, size_t);
|
||||
static auto func_ptr =
|
||||
LoadSymbol<FuncPtr>("cublasLtMatmulAlgoConfigSetAttribute");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(algo, attr, buf, sizeInBytes);
|
||||
}
|
||||
|
||||
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoConfigGetAttribute(
|
||||
const cublasLtMatmulAlgo_t *algo, cublasLtMatmulAlgoConfigAttributes_t attr,
|
||||
void *buf, size_t sizeInBytes, size_t *sizeWritten) {
|
||||
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(
|
||||
const cublasLtMatmulAlgo_t *, cublasLtMatmulAlgoConfigAttributes_t,
|
||||
void *, size_t, size_t *);
|
||||
static auto func_ptr =
|
||||
LoadSymbol<FuncPtr>("cublasLtMatmulAlgoConfigGetAttribute");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(algo, attr, buf, sizeInBytes, sizeWritten);
|
||||
}
|
||||
|
||||
} // extern "C"
|
59
tensorflow/stream_executor/cuda/cublasLt_stub.cc
Normal file
59
tensorflow/stream_executor/cuda/cublasLt_stub.cc
Normal file
@ -0,0 +1,59 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "third_party/gpus/cuda/include/cublasLt.h"
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#include "tensorflow/stream_executor/lib/env.h"
|
||||
#include "tensorflow/stream_executor/platform/dso_loader.h"
|
||||
|
||||
// Implements the cuBLASLt API by forwarding to cuBLASLt loaded from the DSO.
|
||||
|
||||
namespace {
|
||||
// Returns DSO handle or null if loading the DSO fails.
|
||||
void* GetDsoHandle() {
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
return nullptr;
|
||||
#else
|
||||
static auto handle = []() -> void* {
|
||||
auto handle_or =
|
||||
stream_executor::internal::DsoLoader::GetCublasLtDsoHandle();
|
||||
if (!handle_or.ok()) return nullptr;
|
||||
return handle_or.ValueOrDie();
|
||||
}();
|
||||
return handle;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T LoadSymbol(const char* symbol_name) {
|
||||
void* symbol = nullptr;
|
||||
if (auto handle = GetDsoHandle()) {
|
||||
stream_executor::port::Env::Default()
|
||||
->GetSymbolFromLibrary(handle, symbol_name, &symbol)
|
||||
.IgnoreError();
|
||||
}
|
||||
return reinterpret_cast<T>(symbol);
|
||||
}
|
||||
|
||||
void LogFatalSymbolNotFound(const char* symbol_name) {
|
||||
LOG(FATAL) << symbol_name << " symbol not found.";
|
||||
}
|
||||
|
||||
cublasStatus_t GetSymbolNotFoundError() { return CUBLAS_STATUS_INTERNAL_ERROR; }
|
||||
} // namespace
|
||||
|
||||
// We only use cublasLt from CUDA 11.0 onward.
|
||||
#if CUDA_VERSION >= 11000
|
||||
#include "tensorflow/stream_executor/cuda/cublasLt_11_0.inc"
|
||||
#endif
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/gpus/cuda/include/cublasLt.h"
|
||||
#include "third_party/gpus/cuda/include/cublas_v2.h"
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
|
||||
@ -226,17 +227,38 @@ bool CUDABlas::Init() {
|
||||
return false;
|
||||
}
|
||||
|
||||
#if CUDA_VERSION >= 11000
|
||||
ret = cublasLtCreate(&blasLt_);
|
||||
if (ret != CUBLAS_STATUS_SUCCESS) {
|
||||
LOG(ERROR) << "failed to create cublasLt handle: " << ToString(ret);
|
||||
return false;
|
||||
}
|
||||
#endif // CUDA_VERSION >= 11000
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
CUDABlas::CUDABlas(gpu::GpuExecutor *parent)
|
||||
: parent_(CHECK_NOTNULL(parent)), blas_(nullptr) {}
|
||||
: parent_(CHECK_NOTNULL(parent)),
|
||||
blas_(nullptr)
|
||||
#if CUDA_VERSION >= 11000
|
||||
,
|
||||
blasLt_(nullptr)
|
||||
#endif
|
||||
{
|
||||
}
|
||||
|
||||
CUDABlas::~CUDABlas() {
|
||||
if (blas_ != nullptr) {
|
||||
gpu::ScopedActivateExecutorContext sac{parent_};
|
||||
cublasDestroy(blas_);
|
||||
}
|
||||
#if CUDA_VERSION >= 11000
|
||||
if (blasLt_ != nullptr) {
|
||||
gpu::ScopedActivateExecutorContext sac{parent_};
|
||||
cublasLtDestroy(blasLt_);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
bool CUDABlas::SetStream(Stream *stream) {
|
||||
@ -253,6 +275,13 @@ bool CUDABlas::SetStream(Stream *stream) {
|
||||
return true;
|
||||
}
|
||||
|
||||
cudaStream_t CUDABlas::CUDAStream(Stream *stream) {
|
||||
CHECK(stream != nullptr);
|
||||
CHECK(AsGpuStreamValue(stream) != nullptr);
|
||||
gpu::ScopedActivateExecutorContext sac{parent_};
|
||||
return AsGpuStreamValue(stream);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Helper functions transforming blas arguments into cuBLAS arguments.
|
||||
@ -381,8 +410,122 @@ cudaDataType_t CUDAComputationType(blas::ComputationType ty) {
|
||||
return CUDA_C_32F;
|
||||
case blas::ComputationType::kComplexF64:
|
||||
return CUDA_C_64F;
|
||||
case blas::ComputationType::kTF32AsF32: // fall-through
|
||||
case blas::ComputationType::kBF16AsF32:
|
||||
// These cases are currently only supported in the blasLt routines, which
|
||||
// use CUBLASComputationType() instead.
|
||||
LOG(FATAL) << "Invalid value of blas::ComputationType.";
|
||||
}
|
||||
}
|
||||
|
||||
#if CUDA_VERSION >= 11000
|
||||
cublasComputeType_t CUBLASComputationType(blas::ComputationType ty) {
|
||||
switch (ty) {
|
||||
case blas::ComputationType::kF16:
|
||||
return CUBLAS_COMPUTE_16F;
|
||||
case blas::ComputationType::kF32: // fall-through
|
||||
case blas::ComputationType::kComplexF32:
|
||||
return CUBLAS_COMPUTE_32F;
|
||||
case blas::ComputationType::kF64: // fall-through
|
||||
case blas::ComputationType::kComplexF64:
|
||||
return CUBLAS_COMPUTE_64F;
|
||||
case blas::ComputationType::kI32:
|
||||
return CUBLAS_COMPUTE_32I;
|
||||
case blas::ComputationType::kTF32AsF32:
|
||||
return CUBLAS_COMPUTE_32F_FAST_TF32;
|
||||
case blas::ComputationType::kBF16AsF32:
|
||||
return CUBLAS_COMPUTE_32F_FAST_16BF;
|
||||
}
|
||||
}
|
||||
#endif // CUDA_VERSION >= 11000
|
||||
|
||||
blas::DataType GetScaleType(blas::DataType data_type,
|
||||
blas::ComputationType compute_type) {
|
||||
bool is_complex = data_type == blas::DataType::kComplexFloat ||
|
||||
data_type == blas::DataType::kComplexDouble;
|
||||
switch (compute_type) {
|
||||
case blas::ComputationType::kF16:
|
||||
return blas::DataType::kHalf;
|
||||
case blas::ComputationType::kF32: // fall-through
|
||||
case blas::ComputationType::kComplexF32: // fall-through
|
||||
case blas::ComputationType::kTF32AsF32: // fall-through
|
||||
case blas::ComputationType::kBF16AsF32:
|
||||
return is_complex ? blas::DataType::kComplexFloat
|
||||
: blas::DataType::kFloat;
|
||||
case blas::ComputationType::kF64: // fall-through
|
||||
case blas::ComputationType::kComplexF64:
|
||||
return is_complex ? blas::DataType::kComplexDouble
|
||||
: blas::DataType::kDouble;
|
||||
case blas::ComputationType::kI32:
|
||||
return blas::DataType::kInt32;
|
||||
}
|
||||
}
|
||||
|
||||
#if CUDA_VERSION >= 11000
|
||||
cublasLtPointerMode_t CUBLASPointerMode(blas::PointerMode pointer_mode) {
|
||||
switch (pointer_mode) {
|
||||
case blas::PointerMode::kHost:
|
||||
return CUBLASLT_POINTER_MODE_HOST;
|
||||
case blas::PointerMode::kDevice:
|
||||
return CUBLASLT_POINTER_MODE_DEVICE;
|
||||
}
|
||||
}
|
||||
cublasLtEpilogue_t CUBLASEpilogue(blas::Epilogue epilogue) {
|
||||
switch (epilogue) {
|
||||
case blas::Epilogue::kDefault:
|
||||
return CUBLASLT_EPILOGUE_DEFAULT;
|
||||
case blas::Epilogue::kReLU:
|
||||
return CUBLASLT_EPILOGUE_RELU;
|
||||
case blas::Epilogue::kBias:
|
||||
return CUBLASLT_EPILOGUE_BIAS;
|
||||
case blas::Epilogue::kBiasThenReLU:
|
||||
return CUBLASLT_EPILOGUE_RELU_BIAS;
|
||||
}
|
||||
}
|
||||
#endif // CUDA_VERSION >= 11000
|
||||
|
||||
cudaDataType_t GetCUDADataType(blas::DataType ty) {
|
||||
switch (ty) {
|
||||
case blas::DataType::kHalf:
|
||||
return CUDA_R_16F;
|
||||
case blas::DataType::kFloat:
|
||||
return CUDA_R_32F;
|
||||
case blas::DataType::kDouble:
|
||||
return CUDA_R_64F;
|
||||
case blas::DataType::kInt8:
|
||||
return CUDA_R_8I;
|
||||
case blas::DataType::kInt32:
|
||||
return CUDA_R_32I;
|
||||
case blas::DataType::kComplexFloat:
|
||||
return CUDA_C_32F;
|
||||
case blas::DataType::kComplexDouble:
|
||||
return CUDA_C_64F;
|
||||
default:
|
||||
LOG(FATAL) << "Invalid value of blas::DataType in GetCUDADataType";
|
||||
}
|
||||
}
|
||||
|
||||
int GetDataTypeSizeBytes(blas::DataType ty) {
|
||||
switch (ty) {
|
||||
case blas::DataType::kHalf:
|
||||
return 2;
|
||||
case blas::DataType::kFloat:
|
||||
return 4;
|
||||
case blas::DataType::kDouble:
|
||||
return 8;
|
||||
case blas::DataType::kInt8:
|
||||
return 1;
|
||||
case blas::DataType::kInt32:
|
||||
return 4;
|
||||
case blas::DataType::kComplexFloat:
|
||||
return 8;
|
||||
case blas::DataType::kComplexDouble:
|
||||
return 16;
|
||||
default:
|
||||
LOG(FATAL) << "Invalid value of blas::DataType in GetDataTypeSizeBytes";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename FuncT, typename... Args>
|
||||
@ -2921,6 +3064,575 @@ bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
|
||||
GpuComplex(GpuMemoryMutable(b)), ldb);
|
||||
}
|
||||
|
||||
// We only use cublasLt from CUDA 11.0 onward.
|
||||
#if CUDA_VERSION >= 11000
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
inline port::Status SetCublasLtAttr(cublasLtMatrixLayout_t handle,
|
||||
cublasLtMatrixLayoutAttribute_t attr,
|
||||
const T &value) {
|
||||
cublasStatus_t status =
|
||||
cublasLtMatrixLayoutSetAttribute(handle, attr, &value, sizeof(T));
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
return port::Status(
|
||||
port::error::INTERNAL,
|
||||
absl::StrCat("cublasLtMatrixLayoutSetAttribute(attr=", attr,
|
||||
", value=", value, ") failed: ", ToString(status)));
|
||||
}
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline port::Status SetCublasLtAttr(cublasLtMatmulAlgo_t *handle,
|
||||
cublasLtMatmulAlgoConfigAttributes_t attr,
|
||||
const T &value) {
|
||||
cublasStatus_t status =
|
||||
cublasLtMatmulAlgoConfigSetAttribute(handle, attr, &value, sizeof(T));
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
return port::Status(
|
||||
port::error::INTERNAL,
|
||||
absl::StrCat("cublasLtMatmulAlgoConfigSetAttribute(attr=", attr,
|
||||
", value=", value, ") failed: ", ToString(status)));
|
||||
}
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline port::Status SetCublasLtAttr(cublasLtMatmulPreference_t handle,
|
||||
cublasLtMatmulPreferenceAttributes_t attr,
|
||||
const T &value) {
|
||||
cublasStatus_t status =
|
||||
cublasLtMatmulPreferenceSetAttribute(handle, attr, &value, sizeof(value));
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
return port::Status(
|
||||
port::error::INTERNAL,
|
||||
absl::StrCat("cublasLtMatmulPreferenceSetAttribute(attr=", attr,
|
||||
", value=", value, ") failed: ", ToString(status)));
|
||||
}
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool GetCublasLtAttr(const cublasLtMatmulAlgo_t *handle,
|
||||
cublasLtMatmulAlgoConfigAttributes_t attr,
|
||||
T *value) {
|
||||
auto mutable_handle = const_cast<cublasLtMatmulAlgo_t *>(handle);
|
||||
size_t bytes_written = 0;
|
||||
return cublasLtMatmulAlgoConfigGetAttribute(mutable_handle, attr, value,
|
||||
sizeof(T), &bytes_written) ==
|
||||
CUBLAS_STATUS_SUCCESS &&
|
||||
bytes_written == sizeof(T);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline const T &ValueForStrCat(const T &value) {
|
||||
return value;
|
||||
}
|
||||
template <typename T>
|
||||
inline absl::Hex ValueForStrCat(T *ptr) {
|
||||
return absl::Hex(reinterpret_cast<uintptr_t>(ptr));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline port::Status SetCublasLtAttr(cublasLtMatmulDesc_t handle,
|
||||
cublasLtMatmulDescAttributes_t attr,
|
||||
const T &value) {
|
||||
cublasStatus_t status =
|
||||
cublasLtMatmulDescSetAttribute(handle, attr, &value, sizeof(value));
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
return port::Status(
|
||||
port::error::INTERNAL,
|
||||
absl::StrCat("cublasLtMatmulDescSetAttribute(attr=", attr, ", value=",
|
||||
ValueForStrCat(value), ") failed: ", ToString(status)));
|
||||
}
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
||||
struct MatmulDescDestroyer {
|
||||
void operator()(cublasLtMatmulDesc_t matmul_desc) const {
|
||||
cublasLtMatmulDescDestroy(matmul_desc);
|
||||
}
|
||||
};
|
||||
struct LayoutDestroyer {
|
||||
void operator()(cublasLtMatrixLayout_t layout) const {
|
||||
cublasLtMatrixLayoutDestroy(layout);
|
||||
}
|
||||
};
|
||||
struct MatmulPreferenceDestroyer {
|
||||
void operator()(cublasLtMatmulPreference_t matmul_pref) const {
|
||||
cublasLtMatmulPreferenceDestroy(matmul_pref);
|
||||
}
|
||||
};
|
||||
using UniqueOpDesc =
|
||||
std::unique_ptr<std::remove_pointer<cublasLtMatmulDesc_t>::type,
|
||||
MatmulDescDestroyer>;
|
||||
using UniqueLayoutDesc =
|
||||
std::unique_ptr<std::remove_pointer<cublasLtMatrixLayout_t>::type,
|
||||
LayoutDestroyer>;
|
||||
using UniqueMatmulPreference =
|
||||
std::unique_ptr<std::remove_pointer<cublasLtMatmulPreference_t>::type,
|
||||
MatmulPreferenceDestroyer>;
|
||||
|
||||
port::StatusOr<UniqueOpDesc> CreateCublasLtOperationDesc(
|
||||
blas::ComputationType computation_type, blas::DataType scale_type,
|
||||
blas::PointerMode pointer_mode, blas::Epilogue epilogue,
|
||||
blas::Transpose transa, blas::Transpose transb) {
|
||||
cublasLtMatmulDesc_t desc;
|
||||
cublasComputeType_t cublas_compute_type =
|
||||
CUBLASComputationType(computation_type);
|
||||
cudaDataType_t cuda_scale_type = GetCUDADataType(scale_type);
|
||||
cublasStatus_t status =
|
||||
cublasLtMatmulDescCreate(&desc, cublas_compute_type, cuda_scale_type);
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
return port::Status(
|
||||
port::error::INTERNAL,
|
||||
absl::StrCat("cublasLtMatmulDescCreate(computation_type=",
|
||||
computation_type, ") failed: ", ToString(status)));
|
||||
}
|
||||
UniqueOpDesc unique_desc(desc);
|
||||
SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_POINTER_MODE,
|
||||
CUBLASPointerMode(pointer_mode)));
|
||||
SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
|
||||
CUBLASEpilogue(epilogue)));
|
||||
SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSA,
|
||||
CUDABlasTranspose(transa)));
|
||||
SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSB,
|
||||
CUDABlasTranspose(transb)));
|
||||
return unique_desc;
|
||||
}
|
||||
|
||||
port::StatusOr<UniqueLayoutDesc> CreateCublasLtLayoutDesc(
|
||||
blas::DataType data_type, uint64 rows, uint64 cols, int64 ld, int64 stride,
|
||||
int batch_count) {
|
||||
cublasLtMatrixLayout_t desc;
|
||||
cublasStatus_t status = cublasLtMatrixLayoutCreate(
|
||||
&desc, GetCUDADataType(data_type), rows, cols, ld);
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
return port::Status(
|
||||
port::error::INTERNAL,
|
||||
absl::StrCat("cublasLtMatrixLayoutCreate failed: ", ToString(status)));
|
||||
}
|
||||
UniqueLayoutDesc unique_desc(desc);
|
||||
SE_RETURN_IF_ERROR(
|
||||
SetCublasLtAttr(desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_count));
|
||||
SE_RETURN_IF_ERROR(SetCublasLtAttr(
|
||||
desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride));
|
||||
return unique_desc;
|
||||
}
|
||||
|
||||
// Helper function to allocate workspace.
|
||||
port::Status AllocateWorkspace(void **workspace,
|
||||
ScratchAllocator *scratch_allocator,
|
||||
size_t num_bytes) {
|
||||
SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace_bytes,
|
||||
scratch_allocator->AllocateBytes(num_bytes));
|
||||
*workspace = (void *)GpuMemoryMutable(&workspace_bytes);
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
blas::ComputationType ToComputationType();
|
||||
template <>
|
||||
blas::ComputationType ToComputationType<Eigen::half>() {
|
||||
return blas::ComputationType::kF16;
|
||||
}
|
||||
template <>
|
||||
blas::ComputationType ToComputationType<float>() {
|
||||
return blas::ComputationType::kF32;
|
||||
}
|
||||
template <>
|
||||
blas::ComputationType ToComputationType<double>() {
|
||||
return blas::ComputationType::kF64;
|
||||
}
|
||||
template <>
|
||||
blas::ComputationType ToComputationType<std::complex<float>>() {
|
||||
return blas::ComputationType::kComplexF32;
|
||||
}
|
||||
template <>
|
||||
blas::ComputationType ToComputationType<std::complex<double>>() {
|
||||
return blas::ComputationType::kComplexF64;
|
||||
}
|
||||
|
||||
class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
|
||||
public:
|
||||
CUDABlasLtMatmulPlan(UniqueOpDesc op_desc, UniqueLayoutDesc a_desc,
|
||||
UniqueLayoutDesc b_desc, UniqueLayoutDesc c_desc,
|
||||
UniqueLayoutDesc d_desc, blas::DataType ab_type,
|
||||
blas::DataType c_type, blas::DataType scale_type,
|
||||
blas::PointerMode pointer_mode, blas::Epilogue epilogue,
|
||||
int batch_count, int64 stride_a, int64 stride_b,
|
||||
int64 stride_c, int64 stride_d)
|
||||
: op_desc_(std::move(op_desc)),
|
||||
a_desc_(std::move(a_desc)),
|
||||
b_desc_(std::move(b_desc)),
|
||||
c_desc_(std::move(c_desc)),
|
||||
d_desc_(std::move(d_desc)),
|
||||
ab_type_(ab_type),
|
||||
c_type_(c_type),
|
||||
scale_type_(scale_type),
|
||||
pointer_mode_(pointer_mode),
|
||||
epilogue_(epilogue),
|
||||
batch_count_(batch_count),
|
||||
stride_a_(stride_a),
|
||||
stride_b_(stride_b),
|
||||
stride_c_(stride_c),
|
||||
stride_d_(stride_d) {}
|
||||
|
||||
cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); }
|
||||
cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); }
|
||||
cublasLtMatrixLayout_t b_desc() const { return b_desc_.get(); }
|
||||
cublasLtMatrixLayout_t c_desc() const { return c_desc_.get(); }
|
||||
cublasLtMatrixLayout_t d_desc() const { return d_desc_.get(); }
|
||||
bool ok() { return op_desc_ && a_desc_ && b_desc_ && c_desc_ && d_desc_; }
|
||||
|
||||
blas::DataType ab_type() const override { return ab_type_; }
|
||||
blas::DataType c_type() const override { return c_type_; }
|
||||
blas::DataType scale_type() const { return scale_type_; }
|
||||
blas::PointerMode pointer_mode() const { return pointer_mode_; }
|
||||
blas::Epilogue epilogue() const { return epilogue_; }
|
||||
int batch_count() const { return batch_count_; }
|
||||
int64 stride_a() const { return stride_a_; }
|
||||
int64 stride_b() const { return stride_b_; }
|
||||
int64 stride_c() const { return stride_c_; }
|
||||
int64 stride_d() const { return stride_d_; }
|
||||
|
||||
// Note: Must be const to satisfy API. This is always called before the plan
|
||||
// is executed, so the state change is not observed in subsequent executions.
|
||||
bool SetBiasPointer(const void *bias) const;
|
||||
|
||||
private:
|
||||
UniqueOpDesc op_desc_;
|
||||
UniqueLayoutDesc a_desc_;
|
||||
UniqueLayoutDesc b_desc_;
|
||||
UniqueLayoutDesc c_desc_;
|
||||
UniqueLayoutDesc d_desc_;
|
||||
blas::DataType ab_type_;
|
||||
blas::DataType c_type_;
|
||||
blas::DataType scale_type_;
|
||||
blas::PointerMode pointer_mode_;
|
||||
blas::Epilogue epilogue_;
|
||||
int batch_count_;
|
||||
int64 stride_a_;
|
||||
int64 stride_b_;
|
||||
int64 stride_c_;
|
||||
int64 stride_d_;
|
||||
};
|
||||
|
||||
bool CUDABlasLtMatmulPlan::SetBiasPointer(const void *bias) const {
|
||||
return SetCublasLtAttr(op_desc_.get(), CUBLASLT_MATMUL_DESC_BIAS_POINTER,
|
||||
bias)
|
||||
.ok();
|
||||
}
|
||||
|
||||
class CUDABlasLtMatmulAlgorithm final : public blas::IBlasLtMatmulAlgorithm {
|
||||
public:
|
||||
CUDABlasLtMatmulAlgorithm(blas::AlgorithmType index,
|
||||
cublasLtMatmulAlgo_t algo, size_t workspace_size)
|
||||
: index_(index), algo_(algo), workspace_size_(workspace_size) {}
|
||||
|
||||
blas::AlgorithmType index() const override { return index_; }
|
||||
|
||||
size_t workspace_size() const override { return workspace_size_; }
|
||||
|
||||
const cublasLtMatmulAlgo_t *algo() const { return &algo_; }
|
||||
|
||||
int algo_id() const {
|
||||
int id;
|
||||
GetCublasLtAttr(&algo_, CUBLASLT_ALGO_CONFIG_ID, &id);
|
||||
return id;
|
||||
}
|
||||
|
||||
private:
|
||||
blas::AlgorithmType index_;
|
||||
cublasLtMatmulAlgo_t algo_;
|
||||
size_t workspace_size_;
|
||||
};
|
||||
|
||||
port::StatusOr<UniqueMatmulPreference> CreateCublasLtMatmulPreference(
|
||||
const blas::IBlasLtMatmulPlan *plan, size_t max_workspace_bytes) {
|
||||
cublasLtMatmulPreference_t preference;
|
||||
cublasStatus_t status = cublasLtMatmulPreferenceCreate(&preference);
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
return port::Status(port::error::INTERNAL,
|
||||
absl::StrCat("cublasLtMatmulPreferenceCreate failed: ",
|
||||
ToString(status)));
|
||||
}
|
||||
UniqueMatmulPreference unique_preference(preference);
|
||||
SE_RETURN_IF_ERROR(SetCublasLtAttr(preference,
|
||||
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||
max_workspace_bytes));
|
||||
|
||||
const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
|
||||
if (cuda_plan.batch_count() == 0) {
|
||||
return unique_preference;
|
||||
}
|
||||
// This is a workaround for a known issue in cuBlasLt where the heuristic may
|
||||
// in rare cases select an algo that does not support the specified stride.
|
||||
// Specifying the alignment requirements manually like this avoids the issue.
|
||||
auto get_alignment_bytes = [](int64 stride, blas::DataType dtype) {
|
||||
return (stride & -stride) * GetDataTypeSizeBytes(dtype);
|
||||
};
|
||||
if (cuda_plan.stride_a()) {
|
||||
SE_RETURN_IF_ERROR(
|
||||
SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES,
|
||||
(uint32)get_alignment_bytes(cuda_plan.stride_a(),
|
||||
cuda_plan.ab_type())));
|
||||
}
|
||||
if (cuda_plan.stride_b()) {
|
||||
SE_RETURN_IF_ERROR(
|
||||
SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES,
|
||||
(uint32)get_alignment_bytes(cuda_plan.stride_b(),
|
||||
cuda_plan.ab_type())));
|
||||
}
|
||||
if (cuda_plan.stride_c()) {
|
||||
SE_RETURN_IF_ERROR(SetCublasLtAttr(
|
||||
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES,
|
||||
(uint32)get_alignment_bytes(cuda_plan.stride_c(), cuda_plan.c_type())));
|
||||
}
|
||||
if (cuda_plan.stride_d()) {
|
||||
SE_RETURN_IF_ERROR(SetCublasLtAttr(
|
||||
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES,
|
||||
(uint32)get_alignment_bytes(cuda_plan.stride_d(), cuda_plan.c_type())));
|
||||
}
|
||||
return unique_preference;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
#endif // CUDA_VERSION >= 11000
|
||||
|
||||
port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
|
||||
CUDABlas::CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &p) {
|
||||
#if CUDA_VERSION >= 11000
|
||||
SE_ASSIGN_OR_RETURN(
|
||||
auto op_desc,
|
||||
CreateCublasLtOperationDesc(
|
||||
p.computation_type, GetScaleType(p.c_type, p.computation_type),
|
||||
p.pointer_mode, p.epilogue, p.transa, p.transb));
|
||||
uint64 rows_a = p.transa == blas::Transpose::kNoTranspose ? p.m : p.k;
|
||||
uint64 cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m;
|
||||
uint64 rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n;
|
||||
uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k;
|
||||
SE_ASSIGN_OR_RETURN(auto a_desc,
|
||||
CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda,
|
||||
p.stride_a, p.batch_count));
|
||||
SE_ASSIGN_OR_RETURN(auto b_desc,
|
||||
CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb,
|
||||
p.stride_b, p.batch_count));
|
||||
SE_ASSIGN_OR_RETURN(auto c_desc,
|
||||
CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc,
|
||||
p.stride_c, p.batch_count));
|
||||
SE_ASSIGN_OR_RETURN(auto d_desc,
|
||||
CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc,
|
||||
p.stride_c, p.batch_count));
|
||||
blas::DataType scale_type = GetScaleType(p.c_type, p.computation_type);
|
||||
|
||||
return static_cast<std::unique_ptr<blas::IBlasLtMatmulPlan>>(
|
||||
std::make_unique<CUDABlasLtMatmulPlan>(
|
||||
std::move(op_desc), std::move(a_desc), std::move(b_desc),
|
||||
std::move(c_desc), std::move(d_desc), p.ab_type, p.c_type, scale_type,
|
||||
p.pointer_mode, p.epilogue, p.batch_count, p.stride_a, p.stride_b,
|
||||
p.stride_c, p.stride_c));
|
||||
#else
|
||||
return port::Status(
|
||||
port::error::UNIMPLEMENTED,
|
||||
"CreateBlasLtMatmulPlan is not supported with this version of CUDA");
|
||||
#endif
|
||||
}
|
||||
|
||||
port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
|
||||
CUDABlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
|
||||
size_t max_workspace_size,
|
||||
int max_algorithm_count) {
|
||||
#if CUDA_VERSION >= 11000
|
||||
SE_ASSIGN_OR_RETURN(UniqueMatmulPreference preference,
|
||||
CreateCublasLtMatmulPreference(plan, max_workspace_size));
|
||||
|
||||
std::vector<cublasLtMatmulHeuristicResult_t> results(max_algorithm_count);
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
|
||||
CHECK(blasLt_ != nullptr);
|
||||
|
||||
gpu::ScopedActivateExecutorContext sac{parent_};
|
||||
|
||||
int found_algorithm_count = 0;
|
||||
const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
|
||||
cublasStatus_t status = cublasLtMatmulAlgoGetHeuristic(
|
||||
blasLt_, cuda_plan.op_desc(), cuda_plan.a_desc(), cuda_plan.b_desc(),
|
||||
cuda_plan.c_desc(), cuda_plan.d_desc(), preference.get(),
|
||||
max_algorithm_count, results.data(), &found_algorithm_count);
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
return port::Status(
|
||||
port::error::INTERNAL,
|
||||
absl::StrCat("cublasLtMatmulAlgoGetHeuristic failed: ",
|
||||
ToString(status)));
|
||||
}
|
||||
results.resize(found_algorithm_count);
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>> out_algorithms;
|
||||
out_algorithms.reserve(results.size());
|
||||
for (size_t i = 0; i < results.size(); ++i) {
|
||||
const auto &result = results[i];
|
||||
if (result.state != CUBLAS_STATUS_SUCCESS) continue; // Skip failed algos
|
||||
out_algorithms.emplace_back(std::make_unique<CUDABlasLtMatmulAlgorithm>(
|
||||
i, result.algo, result.workspaceSize));
|
||||
}
|
||||
return out_algorithms;
|
||||
#else // if CUDA_VERSION < 11000
|
||||
return port::Status(
|
||||
port::error::UNIMPLEMENTED,
|
||||
"GetBlasLtMatmulAlgorithms is not supported with this version of CUDA");
|
||||
#endif
|
||||
}
|
||||
|
||||
#if CUDA_VERSION >= 11000
|
||||
bool CUDABlas::DoBlasLtMatmulInternal(
|
||||
Stream *stream, bool err_on_failure, const blas::IBlasLtMatmulPlan *plan,
|
||||
const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a,
|
||||
DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta,
|
||||
DeviceMemoryBase c, DeviceMemoryBase d, ScratchAllocator *scratch_allocator,
|
||||
const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias) {
|
||||
const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
|
||||
const auto &cuda_algo =
|
||||
*static_cast<const CUDABlasLtMatmulAlgorithm *>(algorithm);
|
||||
|
||||
if (alpha.data_type() != cuda_plan.scale_type() ||
|
||||
beta.data_type() != cuda_plan.scale_type()) {
|
||||
VLOG(2) << "DoBlasLtMatmul returning false because alpha and beta types do "
|
||||
"not match plan: expected "
|
||||
<< cuda_plan.c_type() << ", got alpha=" << alpha.data_type()
|
||||
<< " beta=" << beta.data_type();
|
||||
return false;
|
||||
}
|
||||
if (alpha.is_pointer() != beta.is_pointer()) {
|
||||
VLOG(2) << "DoBlasLtMatmul returning false because one of `alpha` "
|
||||
"and `beta` is a pointer, but the other is not.";
|
||||
return false;
|
||||
}
|
||||
bool is_pointer_mode_host = !alpha.is_pointer();
|
||||
if ((cuda_plan.pointer_mode() == blas::PointerMode::kHost) !=
|
||||
is_pointer_mode_host) {
|
||||
VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong "
|
||||
"pointer_mode for the given alpha/beta.";
|
||||
return false;
|
||||
}
|
||||
if ((cuda_plan.epilogue() == blas::Epilogue::kBias ||
|
||||
cuda_plan.epilogue() == blas::Epilogue::kBiasThenReLU) !=
|
||||
(bias != nullptr)) {
|
||||
VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong "
|
||||
"epilogue for the given bias pointer.";
|
||||
return false;
|
||||
}
|
||||
if (bias != nullptr) {
|
||||
if (!cuda_plan.SetBiasPointer(bias.opaque())) {
|
||||
VLOG(2) << "DoBlasLtMatmul returning false because setting the bias "
|
||||
"pointer failed.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
const void *alpha_ptr = alpha.is_pointer() ? alpha.opaque_pointer().opaque()
|
||||
: alpha.opaque_value();
|
||||
const void *beta_ptr =
|
||||
beta.is_pointer() ? beta.opaque_pointer().opaque() : beta.opaque_value();
|
||||
|
||||
void *workspace = nullptr;
|
||||
if (cuda_algo.workspace_size()) {
|
||||
port::Status allocation_status = AllocateWorkspace(
|
||||
&workspace, scratch_allocator, cuda_algo.workspace_size());
|
||||
if (!allocation_status.ok()) {
|
||||
if (err_on_failure || VLOG_IS_ON(3)) {
|
||||
LOG(ERROR)
|
||||
<< "Failed to allocate workspace for cublasLtMatmul algo with id: "
|
||||
<< cuda_algo.algo_id() << " requiring "
|
||||
<< cuda_algo.workspace_size() << " bytes of workspace";
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
cudaStream_t cuda_stream = CUDAStream(stream);
|
||||
|
||||
absl::MutexLock lock(&mu_);
|
||||
|
||||
CHECK(blasLt_ != nullptr);
|
||||
|
||||
gpu::ScopedActivateExecutorContext sac{parent_};
|
||||
|
||||
cublasStatus_t ret = cublasLtMatmul(
|
||||
blasLt_, cuda_plan.op_desc(), alpha_ptr, a.opaque(), cuda_plan.a_desc(),
|
||||
b.opaque(), cuda_plan.b_desc(), beta_ptr, c.opaque(), cuda_plan.c_desc(),
|
||||
d.opaque(), cuda_plan.d_desc(), cuda_algo.algo(), workspace,
|
||||
cuda_algo.workspace_size(), cuda_stream);
|
||||
if (ret != CUBLAS_STATUS_SUCCESS) {
|
||||
if (err_on_failure || VLOG_IS_ON(3)) {
|
||||
LOG(ERROR) << "failed to run cublasLtMatmul routine: " << ToString(ret);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
#endif // CUDA_VERSION >= 11000
|
||||
|
||||
bool CUDABlas::DoBlasLtMatmul(
|
||||
Stream *stream, const blas::IBlasLtMatmulPlan *plan,
|
||||
const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a,
|
||||
DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta,
|
||||
DeviceMemoryBase c, ScratchAllocator *scratch_allocator,
|
||||
const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias,
|
||||
blas::ProfileResult *output_profile_result) {
|
||||
#if CUDA_VERSION >= 11000
|
||||
const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
|
||||
HostOrDeviceScalar<void> alpha_cast = alpha;
|
||||
HostOrDeviceScalar<void> beta_cast = beta;
|
||||
if (cuda_plan.c_type() == blas::DataType::kHalf &&
|
||||
cuda_plan.scale_type() == blas::DataType::kFloat) {
|
||||
// The given alpha and beta types are F16 (they always match c), but F32*
|
||||
// computation type requires that they be F32, so we must cast them.
|
||||
if (alpha.is_pointer() || beta.is_pointer()) {
|
||||
// We cannot easily convert a pointer to f16 memory to a pointer to f32
|
||||
// memory from here, so we don't support this for now.
|
||||
return false;
|
||||
}
|
||||
alpha_cast = HostOrDeviceScalar<void>(
|
||||
static_cast<float>(alpha.value<Eigen::half>()));
|
||||
beta_cast =
|
||||
HostOrDeviceScalar<void>(static_cast<float>(beta.value<Eigen::half>()));
|
||||
}
|
||||
|
||||
std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
|
||||
if (output_profile_result) {
|
||||
timer.reset(new GpuTimer(parent_));
|
||||
if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool err_on_failure = timer != nullptr;
|
||||
bool result = DoBlasLtMatmulInternal(stream, err_on_failure, plan, alpha_cast,
|
||||
a, b, beta_cast, c, c, scratch_allocator,
|
||||
algorithm, bias);
|
||||
|
||||
if (timer && result) {
|
||||
// GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
|
||||
// state.
|
||||
if (!timer->Stop(AsGpuStream(stream))) {
|
||||
return false;
|
||||
}
|
||||
output_profile_result->set_is_valid(true);
|
||||
output_profile_result->set_algorithm(algorithm->index());
|
||||
output_profile_result->set_elapsed_time_in_ms(
|
||||
timer->GetElapsedMilliseconds());
|
||||
}
|
||||
return result;
|
||||
#else // if CUDA_VERSION < 11000
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
port::Status CUDABlas::GetVersion(std::string *version) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
|
||||
|
@ -21,7 +21,9 @@ limitations under the License.
|
||||
#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
|
||||
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "third_party/gpus/cuda/include/cublasLt.h"
|
||||
#include "third_party/gpus/cuda/include/cublas_v2.h"
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/stream_executor/blas.h"
|
||||
#include "tensorflow/stream_executor/host_or_device_scalar.h"
|
||||
@ -71,6 +73,9 @@ class CUDABlas : public blas::BlasSupport {
|
||||
// invoked before calling into cuBLAS.
|
||||
bool SetStream(Stream *stream) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
// Returns the underlying CUDA stream.
|
||||
cudaStream_t CUDAStream(Stream *stream);
|
||||
|
||||
// A helper function that calls the real cuBLAS function together with error
|
||||
// handling.
|
||||
//
|
||||
@ -134,6 +139,17 @@ class CUDABlas : public blas::BlasSupport {
|
||||
const T &beta, DeviceMemory<T> *y, int incy,
|
||||
blas::ProfileResult *output_profile_result);
|
||||
|
||||
// Helper function for implementing DoBlasLtMatmul.
|
||||
bool DoBlasLtMatmulInternal(Stream *stream, bool err_on_failure,
|
||||
const blas::IBlasLtMatmulPlan *plan,
|
||||
const HostOrDeviceScalar<void> &alpha,
|
||||
DeviceMemoryBase a, DeviceMemoryBase b,
|
||||
const HostOrDeviceScalar<void> &beta,
|
||||
DeviceMemoryBase c, DeviceMemoryBase d,
|
||||
ScratchAllocator *scratch_allocator,
|
||||
const blas::IBlasLtMatmulAlgorithm *algorithm,
|
||||
DeviceMemoryBase bias);
|
||||
|
||||
// Guards the cuBLAS handle for this device.
|
||||
absl::Mutex mu_;
|
||||
|
||||
@ -144,6 +160,11 @@ class CUDABlas : public blas::BlasSupport {
|
||||
// cuBLAS library handle on the device.
|
||||
cublasHandle_t blas_ TF_GUARDED_BY(mu_);
|
||||
|
||||
#if CUDA_VERSION >= 11000
|
||||
// cuBLASLt library handle on the device.
|
||||
cublasLtHandle_t blasLt_ GUARDED_BY(mu_);
|
||||
#endif
|
||||
|
||||
SE_DISALLOW_COPY_AND_ASSIGN(CUDABlas);
|
||||
};
|
||||
|
||||
|
66
tensorflow/stream_executor/data_type.h
Normal file
66
tensorflow/stream_executor/data_type.h
Normal file
@ -0,0 +1,66 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_STREAM_EXECUTOR_DATA_TYPE_H_
|
||||
#define TENSORFLOW_STREAM_EXECUTOR_DATA_TYPE_H_
|
||||
|
||||
#include <complex>
|
||||
|
||||
#include "tensorflow/stream_executor/dnn.pb.h"
|
||||
#include "tensorflow/stream_executor/platform/port.h"
|
||||
|
||||
namespace Eigen {
|
||||
struct half;
|
||||
} // namespace Eigen
|
||||
|
||||
namespace stream_executor {
|
||||
namespace dnn {
|
||||
|
||||
// A helper class to convert C/C++ types to the proper enums.
|
||||
template <typename T>
|
||||
struct ToDataType;
|
||||
template <>
|
||||
struct ToDataType<float> {
|
||||
static constexpr DataType value = DataType::kFloat;
|
||||
};
|
||||
template <>
|
||||
struct ToDataType<double> {
|
||||
static constexpr DataType value = DataType::kDouble;
|
||||
};
|
||||
template <>
|
||||
struct ToDataType<Eigen::half> {
|
||||
static constexpr DataType value = DataType::kHalf;
|
||||
};
|
||||
template <>
|
||||
struct ToDataType<tensorflow::int8> {
|
||||
static constexpr DataType value = DataType::kInt8;
|
||||
};
|
||||
template <>
|
||||
struct ToDataType<tensorflow::int32> {
|
||||
static constexpr DataType value = DataType::kInt32;
|
||||
};
|
||||
template <>
|
||||
struct ToDataType<std::complex<float>> {
|
||||
static constexpr DataType value = DataType::kComplexFloat;
|
||||
};
|
||||
template <>
|
||||
struct ToDataType<std::complex<double>> {
|
||||
static constexpr DataType value = DataType::kComplexDouble;
|
||||
};
|
||||
|
||||
} // namespace dnn
|
||||
} // namespace stream_executor
|
||||
|
||||
#endif // TENSORFLOW_STREAM_EXECUTOR_DATA_TYPE_H_
|
@ -30,6 +30,7 @@ limitations under the License.

#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/stream_executor/data_type.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/dnn.pb.h"
#include "tensorflow/stream_executor/lib/array_slice.h"
@ -110,30 +111,6 @@ enum class QuantizedActivationMode {
  k32Bit = 4,
};

// A helper class to convert C/C++ types to the proper enums.
template <typename T>
struct ToDataType;
template <>
struct ToDataType<float> {
  static constexpr DataType value = DataType::kFloat;
};
template <>
struct ToDataType<double> {
  static constexpr DataType value = DataType::kDouble;
};
template <>
struct ToDataType<Eigen::half> {
  static constexpr DataType value = DataType::kHalf;
};
template <>
struct ToDataType<int8> {
  static constexpr DataType value = DataType::kInt8;
};
template <>
struct ToDataType<int32> {
  static constexpr DataType value = DataType::kInt32;
};

// Specifies the types of a RNN model.
enum class RnnMode {
  kRnnRelu = 0,

@ -12,6 +12,8 @@ enum DataType {
  kHalf = 2;
  kInt8 = 3;
  kInt32 = 4;
  kComplexFloat = 5;
  kComplexDouble = 6;
}

// Describes how a convolution input or output layer's data is formatted.

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_STREAM_EXECUTOR_HOST_OR_DEVICE_SCALAR_H_
#define TENSORFLOW_STREAM_EXECUTOR_HOST_OR_DEVICE_SCALAR_H_

#include "tensorflow/stream_executor/data_type.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/platform/logging.h"

@ -23,6 +24,7 @@ namespace stream_executor {

// Allows to represent a value that is either a host scalar or a scalar stored
// on the GPU device.
// See also the specialization for ElemT=void below.
template <typename ElemT>
class HostOrDeviceScalar {
 public:
@ -52,5 +54,154 @@ class HostOrDeviceScalar {
  bool is_pointer_;
};

// Specialization for wrapping a dynamically-typed value (via type erasure).
template <>
class HostOrDeviceScalar<void> {
 public:
  using DataType = dnn::DataType;

  // Constructors not marked as explicit because when using this constructor,
  // we usually want to set this to a compile-time constant.

  // NOLINTNEXTLINE google-explicit-constructor
  HostOrDeviceScalar(float value)
      : float_(value), is_pointer_(false), dtype_(DataType::kFloat) {}
  // NOLINTNEXTLINE google-explicit-constructor
  HostOrDeviceScalar(double value)
      : double_(value), is_pointer_(false), dtype_(DataType::kDouble) {}
  // NOLINTNEXTLINE google-explicit-constructor
  HostOrDeviceScalar(Eigen::half value)
      : half_(value), is_pointer_(false), dtype_(DataType::kHalf) {}
  // NOLINTNEXTLINE google-explicit-constructor
  HostOrDeviceScalar(int8 value)
      : int8_(value), is_pointer_(false), dtype_(DataType::kInt8) {}
  // NOLINTNEXTLINE google-explicit-constructor
  HostOrDeviceScalar(int32 value)
      : int32_(value), is_pointer_(false), dtype_(DataType::kInt32) {}
  // NOLINTNEXTLINE google-explicit-constructor
  HostOrDeviceScalar(std::complex<float> value)
      : complex_float_(value),
        is_pointer_(false),
        dtype_(DataType::kComplexFloat) {}
  // NOLINTNEXTLINE google-explicit-constructor
  HostOrDeviceScalar(std::complex<double> value)
      : complex_double_(value),
        is_pointer_(false),
        dtype_(DataType::kComplexDouble) {}
  template <typename T>
  explicit HostOrDeviceScalar(const DeviceMemory<T>& pointer)
      : pointer_(pointer),
        is_pointer_(true),
        dtype_(dnn::ToDataType<T>::value) {
    CHECK_EQ(1, pointer.ElementCount());
  }
  // Construct from statically-typed version.
  template <typename T, typename std::enable_if<!std::is_same<T, void>::value,
                                                int>::type = 0>
  // NOLINTNEXTLINE google-explicit-constructor
  HostOrDeviceScalar(const HostOrDeviceScalar<T>& other) {
    if (other.is_pointer()) {
      *this = HostOrDeviceScalar(other.pointer());
    } else {
      *this = HostOrDeviceScalar(other.value());
    }
  }

  bool is_pointer() const { return is_pointer_; }
  template <typename T>
  const DeviceMemory<T>& pointer() const {
    CHECK(is_pointer());
    CHECK(dtype_ == dnn::ToDataType<T>::value);
    return pointer_;
  }
  template <typename T>
  const T& value() const {
    CHECK(!is_pointer());
    CHECK(dtype_ == dnn::ToDataType<T>::value);
    return value_impl<T>();
  }
  const DeviceMemoryBase& opaque_pointer() const {
    CHECK(is_pointer());
    return pointer_;
  }
  const void* opaque_value() const {
    CHECK(!is_pointer());
    switch (dtype_) {
      case DataType::kFloat:
        return &float_;
      case DataType::kDouble:
        return &double_;
      case DataType::kHalf:
        return &half_;
      case DataType::kInt8:
        return &int8_;
      case DataType::kInt32:
        return &int32_;
      case DataType::kComplexFloat:
        return &complex_float_;
      case DataType::kComplexDouble:
        return &complex_double_;
      default:
        return nullptr;
    }
  }
  DataType data_type() const { return dtype_; }

 private:
  template <typename T>
  const T& value_impl() const;

  union {
    float float_;
    double double_;
    Eigen::half half_;
    int8 int8_;
    int32 int32_;
    std::complex<float> complex_float_;
    std::complex<double> complex_double_;
    DeviceMemoryBase pointer_;
  };
  bool is_pointer_;
  DataType dtype_;
};

template <>
inline const float& HostOrDeviceScalar<void>::value_impl<float>() const {
  return float_;
}

template <>
inline const double& HostOrDeviceScalar<void>::value_impl<double>() const {
  return double_;
}

template <>
inline const Eigen::half& HostOrDeviceScalar<void>::value_impl<Eigen::half>()
    const {
  return half_;
}

template <>
inline const int8& HostOrDeviceScalar<void>::value_impl<int8>() const {
  return int8_;
}

template <>
inline const int32& HostOrDeviceScalar<void>::value_impl<int32>() const {
  return int32_;
}

template <>
inline const std::complex<float>&
HostOrDeviceScalar<void>::value_impl<std::complex<float>>() const {
  return complex_float_;
}

template <>
inline const std::complex<double>&
HostOrDeviceScalar<void>::value_impl<std::complex<double>>() const {
  return complex_double_;
}

}  // namespace stream_executor
#endif  // TENSORFLOW_STREAM_EXECUTOR_HOST_OR_DEVICE_SCALAR_H_

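The type-erased HostOrDeviceScalar<void> specialization is what lets the blasLt entry points accept alpha/beta scalars of any supported dtype through a single signature. A small illustrative sketch of how such a value might be built and read back; the Consume and Example functions are assumptions for the example, not TensorFlow code:

#include <complex>

#include "tensorflow/stream_executor/host_or_device_scalar.h"

namespace se = stream_executor;

// Hypothetical consumer: dispatches on the runtime dtype tag and reads the
// scalar back out through the opaque accessor.
void Consume(const se::HostOrDeviceScalar<void>& alpha) {
  if (!alpha.is_pointer()) {
    const void* raw = alpha.opaque_value();  // points at the active union member
    if (alpha.data_type() == se::dnn::DataType::kFloat) {
      float host_alpha = *static_cast<const float*>(raw);
      (void)host_alpha;
    }
  }
}

void Example() {
  se::HostOrDeviceScalar<void> alpha(1.0f);       // host float, dtype kFloat
  se::HostOrDeviceScalar<float> typed_beta(0.0f);
  se::HostOrDeviceScalar<void> beta(typed_beta);  // erased from the typed form
  Consume(alpha);
  Consume(beta);
}
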
@ -23,6 +23,7 @@ namespace DsoLoader {
port::Status TryDlopenCUDALibraries() {
  auto cudart_status = GetCudaRuntimeDsoHandle();
  auto cublas_status = GetCublasDsoHandle();
  auto cublaslt_status = GetCublasLtDsoHandle();
  auto cufft_status = GetCufftDsoHandle();
  auto curand_status = GetCurandDsoHandle();
  auto cusolver_status = GetCusolverDsoHandle();
@ -31,7 +32,7 @@ port::Status TryDlopenCUDALibraries() {
  if (!cudart_status.status().ok() || !cublas_status.status().ok() ||
      !cufft_status.status().ok() || !curand_status.status().ok() ||
      !cusolver_status.status().ok() || !cusparse_status.status().ok() ||
      !cudnn_status.status().ok()) {
      !cudnn_status.status().ok() || !cublaslt_status.status().ok()) {
    return port::Status(port::error::INTERNAL,
                        absl::StrCat("Cannot dlopen all CUDA libraries."));
  } else {

@ -85,6 +85,10 @@ port::StatusOr<void*> GetCublasDsoHandle() {
  return GetDsoHandle("cublas", GetCublasVersion());
}

port::StatusOr<void*> GetCublasLtDsoHandle() {
  return GetDsoHandle("cublasLt", GetCublasVersion());
}

port::StatusOr<void*> GetCufftDsoHandle() {
  return GetDsoHandle("cufft", GetCufftVersion());
}
@ -161,6 +165,11 @@ port::StatusOr<void*> GetCublasDsoHandle() {
  return *result;
}

port::StatusOr<void*> GetCublasLtDsoHandle() {
  static auto result = new auto(DsoLoader::GetCublasLtDsoHandle());
  return *result;
}

port::StatusOr<void*> GetCurandDsoHandle() {
  static auto result = new auto(DsoLoader::GetCurandDsoHandle());
  return *result;

@ -37,6 +37,7 @@ namespace DsoLoader {
port::StatusOr<void*> GetCudaDriverDsoHandle();
port::StatusOr<void*> GetCudaRuntimeDsoHandle();
port::StatusOr<void*> GetCublasDsoHandle();
port::StatusOr<void*> GetCublasLtDsoHandle();
port::StatusOr<void*> GetCufftDsoHandle();
port::StatusOr<void*> GetCurandDsoHandle();
port::StatusOr<void*> GetCusolverDsoHandle();
@ -72,6 +73,7 @@ namespace CachedDsoLoader {
port::StatusOr<void*> GetCudaDriverDsoHandle();
port::StatusOr<void*> GetCudaRuntimeDsoHandle();
port::StatusOr<void*> GetCublasDsoHandle();
port::StatusOr<void*> GetCublasLtDsoHandle();
port::StatusOr<void*> GetCufftDsoHandle();
port::StatusOr<void*> GetCurandDsoHandle();
port::StatusOr<void*> GetCusolverDsoHandle();

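The CachedDsoLoader definition above relies on a small idiom: a function-local `static auto result = new auto(...)` performs the dlopen exactly once (thread-safely, per C++11 static initialization) and deliberately never frees the handle, avoiding destruction-order problems at exit. A generic sketch of the same idiom under an assumed ExpensiveLookup function, not TensorFlow code:

#include <iostream>
#include <string>

// Hypothetical expensive call we only ever want to perform once.
std::string ExpensiveLookup() {
  std::cout << "performing lookup" << std::endl;
  return "handle";
}

const std::string& CachedLookup() {
  // Initialized on the first call only; the heap allocation is intentionally
  // leaked so the cached value outlives static destruction, mirroring the
  // DsoLoader code above.
  static auto result = new auto(ExpensiveLookup());
  return *result;
}

int main() {
  CachedLookup();  // prints "performing lookup"
  CachedLookup();  // cached; prints nothing
  return 0;
}
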
@ -2540,6 +2540,32 @@ port::Status ROCMBlas::GetVersion(string *version) {
  return port::UnimplementedError("");
}

port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
ROCMBlas::CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &p) {
  return port::Status(
      port::error::UNIMPLEMENTED,
      "CreateBlasLtMatmulPlan is not supported with this version of ROCM");
}

port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
ROCMBlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
                                    size_t max_workspace_size,
                                    int max_algorithm_count) {
  return port::Status(
      port::error::UNIMPLEMENTED,
      "GetBlasLtMatmulAlgorithms is not supported with this version of ROCM");
}

bool ROCMBlas::DoBlasLtMatmul(
    Stream *stream, const blas::IBlasLtMatmulPlan *plan,
    const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a,
    DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta,
    DeviceMemoryBase c, ScratchAllocator *scratch_allocator,
    const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias,
    blas::ProfileResult *output_profile_result) {
  return false;
}

}  // namespace gpu

void initialize_rocblas() {

@ -4322,6 +4322,80 @@ Stream &Stream::ThenBlasGemmStridedBatched(
                 c, ldc, stride_c, batch_count);
}

template <typename ABType, typename CType>
Stream &Stream::ThenBlasLtMatmulImpl(
    const blas::IBlasLtMatmulPlan *plan, const HostOrDeviceScalar<CType> &alpha,
    const DeviceMemory<ABType> &a, const DeviceMemory<ABType> &b,
    const HostOrDeviceScalar<CType> &beta, DeviceMemory<CType> *c,
    ScratchAllocator *scratch_allocator,
    const blas::IBlasLtMatmulAlgorithm *algorithm,
    const DeviceMemory<CType> &bias,
    blas::ProfileResult *output_profile_result) {
  VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
            PARAM(c), PARAM(algorithm), PARAM(bias));

  ThenBlasWithProfileImpl<
      const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<CType> &,
      const DeviceMemory<ABType> &, const DeviceMemory<ABType> &,
      const HostOrDeviceScalar<CType> &, DeviceMemory<CType> *,
      ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
      const DeviceMemory<CType> &>
      impl;
  return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
              c, scratch_allocator, algorithm, bias, output_profile_result);
}

// Explicit template instantiations for each supported type combination.
template Stream &Stream::ThenBlasLtMatmulImpl<int8, int32>(
    const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<int32> &,
    const DeviceMemory<int8> &, const DeviceMemory<int8> &,
    const HostOrDeviceScalar<int32> &, DeviceMemory<int32> *,
    ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
    const DeviceMemory<int32> &, blas::ProfileResult *);

template Stream &Stream::ThenBlasLtMatmulImpl<Eigen::half, Eigen::half>(
    const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<Eigen::half> &,
    const DeviceMemory<Eigen::half> &, const DeviceMemory<Eigen::half> &,
    const HostOrDeviceScalar<Eigen::half> &, DeviceMemory<Eigen::half> *,
    ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
    const DeviceMemory<Eigen::half> &, blas::ProfileResult *);

template Stream &Stream::ThenBlasLtMatmulImpl<float, float>(
    const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<float> &,
    const DeviceMemory<float> &, const DeviceMemory<float> &,
    const HostOrDeviceScalar<float> &, DeviceMemory<float> *,
    ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
    const DeviceMemory<float> &, blas::ProfileResult *);

template Stream &Stream::ThenBlasLtMatmulImpl<double, double>(
    const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<double> &,
    const DeviceMemory<double> &, const DeviceMemory<double> &,
    const HostOrDeviceScalar<double> &, DeviceMemory<double> *,
    ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
    const DeviceMemory<double> &, blas::ProfileResult *);

template Stream &
Stream::ThenBlasLtMatmulImpl<std::complex<float>, std::complex<float>>(
    const blas::IBlasLtMatmulPlan *,
    const HostOrDeviceScalar<std::complex<float>> &,
    const DeviceMemory<std::complex<float>> &,
    const DeviceMemory<std::complex<float>> &,
    const HostOrDeviceScalar<std::complex<float>> &,
    DeviceMemory<std::complex<float>> *, ScratchAllocator *,
    const blas::IBlasLtMatmulAlgorithm *,
    const DeviceMemory<std::complex<float>> &, blas::ProfileResult *);

template Stream &
Stream::ThenBlasLtMatmulImpl<std::complex<double>, std::complex<double>>(
    const blas::IBlasLtMatmulPlan *,
    const HostOrDeviceScalar<std::complex<double>> &,
    const DeviceMemory<std::complex<double>> &,
    const DeviceMemory<std::complex<double>> &,
    const HostOrDeviceScalar<std::complex<double>> &,
    DeviceMemory<std::complex<double>> *, ScratchAllocator *,
    const blas::IBlasLtMatmulAlgorithm *,
    const DeviceMemory<std::complex<double>> &, blas::ProfileResult *);

Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
  VLOG_CALL(PARAM(seed), PARAM(seed_bytes));

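The explicit instantiations above exist because ThenBlasLtMatmulImpl is declared in stream.h but defined only in stream.cc; without them, callers in other translation units would compile but fail to link for the listed (ABType, CType) pairs. A minimal sketch of the general pattern, using assumed names (Scaler, Scale) rather than the TensorFlow classes:

// In a header: declare the member template, but do not define it.
class Scaler {
 public:
  template <typename T>
  T Scale(T value) const;  // definition lives in the .cc file
};

// In the matching .cc file: define it once...
template <typename T>
T Scaler::Scale(T value) const {
  return value * T(2);
}

// ...and explicitly instantiate every type callers are allowed to use.
// Any other T compiles at the call site but fails at link time.
template float Scaler::Scale<float>(float) const;
template double Scaler::Scale<double>(double) const;
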
@ -75,6 +75,19 @@ class AlgorithmDesc;
class StreamExecutor;
class ScratchAllocator;

namespace detail {

// Helper class to prevent a template function argument from being deduced.
// This is identical to std::type_identity in C++20.
template <typename T>
struct NonDeduced {
  using type = T;
};
template <typename T>
using NonDeducedType = typename NonDeduced<T>::type;

}  // namespace detail

// Convert a type to the corresponding QuantizedActivationMode.
template <typename ElementType>
struct Quantization;
@ -1632,6 +1645,25 @@ class Stream {
                  const DeviceMemory<std::complex<double>> &a, int lda,
                  DeviceMemory<std::complex<double>> *b, int ldb);

  // See BlasSupport::DoBlasLtMatmul.
  // Note that we prevent alpha and beta from being used to deduce CType so
  // that they can be constructed implicitly from values of type CType.
  // Without this, type deduction would fail when this function is called with
  // a value of type CType for alpha or beta.
  template <typename ABType, typename CType>
  Stream &ThenBlasLtMatmul(
      const blas::IBlasLtMatmulPlan *plan,
      const detail::NonDeducedType<HostOrDeviceScalar<CType>> &alpha,
      const DeviceMemory<ABType> &a, const DeviceMemory<ABType> &b,
      const detail::NonDeducedType<HostOrDeviceScalar<CType>> &beta,
      DeviceMemory<CType> *c, ScratchAllocator *scratch_allocator,
      const blas::IBlasLtMatmulAlgorithm *algorithm,
      const DeviceMemory<CType> &bias = {},
      blas::ProfileResult *output_profile_result = nullptr) {
    return ThenBlasLtMatmulImpl(plan, alpha, a, b, beta, c, scratch_allocator,
                                algorithm, bias, output_profile_result);
  }

  // See FftSupport::DoFft.
  Stream &ThenFft(fft::Plan *plan,
                  const DeviceMemory<std::complex<float>> &input,
@ -2064,6 +2096,19 @@ class Stream {
                                    const dnn::BatchDescriptor &bias_descriptor,
                                    DeviceMemory<T> *backward_bias_data);

  // Implementation of ThenBlasLtMatmul that is shared by all types.
  template <typename ABType, typename CType>
  Stream &ThenBlasLtMatmulImpl(const blas::IBlasLtMatmulPlan *plan,
                               const HostOrDeviceScalar<CType> &alpha,
                               const DeviceMemory<ABType> &a,
                               const DeviceMemory<ABType> &b,
                               const HostOrDeviceScalar<CType> &beta,
                               DeviceMemory<CType> *c,
                               ScratchAllocator *scratch_allocator,
                               const blas::IBlasLtMatmulAlgorithm *algorithm,
                               const DeviceMemory<CType> &bias,
                               blas::ProfileResult *output_profile_result);

  SE_DISALLOW_COPY_AND_ASSIGN(Stream);
};

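Why NonDeducedType matters here: if HostOrDeviceScalar<CType> appeared in a deduced position, passing a plain scalar such as 1.0f for alpha would make CType deduction fail instead of triggering the implicit conversion. A minimal self-contained sketch of the mechanism; Wrapper, Axpy and the local NonDeduced alias are illustrative stand-ins, not the real Stream API:

#include <iostream>

template <typename T>
struct NonDeduced { using type = T; };
template <typename T>
using NonDeducedType = typename NonDeduced<T>::type;

template <typename T>
struct Wrapper {
  // Implicit so that a bare T (e.g. 1.5f) converts to Wrapper<T> at call sites.
  Wrapper(T v) : value(v) {}
  T value;
};

// T is deduced only from `x` and `y`; `alpha` sits in a non-deduced context,
// so the caller may pass either a Wrapper<T> or a plain T that converts
// implicitly.
template <typename T>
T Axpy(const NonDeducedType<Wrapper<T>>& alpha, T x, T y) {
  return alpha.value * x + y;
}

int main() {
  // Without NonDeducedType, the compiler would try to deduce T for Wrapper<T>
  // from the literal 1.5f and fail; here T = float comes from x and y alone.
  std::cout << Axpy(1.5f, 2.0f, 3.0f) << std::endl;  // prints 6
  return 0;
}
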
@ -337,6 +337,30 @@ bool StreamExecutor::GetBlasGemmAlgorithms(
  return blas_support->GetBlasGemmAlgorithms(out_algorithms);
}

port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
StreamExecutor::CreateBlasLtMatmulPlan(
    const blas::BlasLtMatmulPlanParams &params) {
  blas::BlasSupport *blas_support = AsBlas();
  if (!blas_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the blas implementation.");
  }
  return blas_support->CreateBlasLtMatmulPlan(params);
}

port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
StreamExecutor::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
                                          size_t max_workspace_size,
                                          int max_algorithm_count) {
  blas::BlasSupport *blas_support = AsBlas();
  if (!blas_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the blas implementation.");
  }
  return blas_support->GetBlasLtMatmulAlgorithms(plan, max_workspace_size,
                                                 max_algorithm_count);
}

port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
StreamExecutor::createRnnDescriptor(
    int num_layers, int hidden_size, int input_size, int cell_size,

@ -395,6 +395,21 @@ class StreamExecutor {
  // Get the list of supported algorithms for BLAS gemm.
  bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms);

  // Creates a backend-specific plan object for a blaslt matmul operation,
  // which can then be passed to DoBlasLtMatmul(). When possible, plans should
  // be created once and reused for multiple calls to DoBlasLtMatmul().
  // Returns a null pointer on failure.
  port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
  CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &params);

  // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms
  // are returned in the order of increasing estimated compute time according
  // to an internal heuristic. The first returned algorithm can be used as the
  // default algorithm if no autotuning is to be performed.
  port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
  GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
                            size_t max_workspace_size, int max_algorithm_count);

  // Create an RNN descriptor based on model shapes and configurations.
  // The caller retains the ownership of the descriptor.
  port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(

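Taken together, the two StreamExecutor methods above plus Stream::ThenBlasLtMatmul form a create-plan, pick-algorithm, enqueue workflow. A hedged sketch of a caller, assuming the caller already owns the executor, stream, device buffers and a filled-in BlasLtMatmulPlanParams; the RunPlannedMatmul helper and the workspace/algorithm limits are illustrative, not part of this change:

#include <memory>
#include <vector>

#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor.h"

namespace se = stream_executor;

// Hypothetical helper: enqueues one plan-based matmul using the heuristically
// fastest algorithm (index 0). Real code would cache the plan and autotune.
template <typename T>
se::port::Status RunPlannedMatmul(
    se::StreamExecutor* executor, se::Stream* stream,
    const se::blas::BlasLtMatmulPlanParams& params,
    const se::DeviceMemory<T>& a, const se::DeviceMemory<T>& b,
    se::DeviceMemory<T>* c, se::ScratchAllocator* scratch_allocator) {
  // Plans are meant to be created once and reused across calls.
  auto plan_or = executor->CreateBlasLtMatmulPlan(params);
  if (!plan_or.ok()) return plan_or.status();
  std::unique_ptr<se::blas::IBlasLtMatmulPlan> plan =
      plan_or.ConsumeValueOrDie();

  // Algorithms come back ordered by estimated compute time, so [0] is a sane
  // default when autotuning is disabled.
  auto algorithms_or = executor->GetBlasLtMatmulAlgorithms(
      plan.get(), /*max_workspace_size=*/1 << 30, /*max_algorithm_count=*/8);
  if (!algorithms_or.ok()) return algorithms_or.status();
  auto algorithms = algorithms_or.ConsumeValueOrDie();

  stream->ThenBlasLtMatmul(plan.get(), /*alpha=*/T(1), a, b, /*beta=*/T(0), c,
                           scratch_allocator, algorithms[0].get());
  return se::port::Status::OK();
}
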
8  third_party/gpus/cuda/BUILD.tpl  vendored
@ -127,6 +127,13 @@ cc_library(
    linkstatic = 1,
)

cc_library(
    name = "cublasLt",
    srcs = ["cuda/lib/%{cublasLt_lib}"],
    data = ["cuda/lib/%{cublasLt_lib}"],
    linkstatic = 1,
)

cc_library(
    name = "cusolver",
    srcs = ["cuda/lib/%{cusolver_lib}"],
@ -168,6 +175,7 @@ cc_library(
    name = "cuda",
    deps = [
        ":cublas",
        ":cublasLt",
        ":cuda_headers",
        ":cudart",
        ":cudnn",

12  third_party/gpus/cuda_configure.bzl  vendored
@ -551,6 +551,13 @@ def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config):
            cuda_config.cublas_version,
            static = False,
        ),
        "cublasLt": _check_cuda_lib_params(
            "cublasLt",
            cpu_value,
            cuda_config.config["cublas_library_dir"],
            cuda_config.cublas_version,
            static = False,
        ),
        "cusolver": _check_cuda_lib_params(
            "cusolver",
            cpu_value,
@ -780,6 +787,7 @@ def _create_dummy_repository(repository_ctx):
        "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value),
        "%{cudart_lib}": lib_name("cudart", cpu_value),
        "%{cublas_lib}": lib_name("cublas", cpu_value),
        "%{cublasLt_lib}": lib_name("cublasLt", cpu_value),
        "%{cusolver_lib}": lib_name("cusolver", cpu_value),
        "%{cudnn_lib}": lib_name("cudnn", cpu_value),
        "%{cufft_lib}": lib_name("cufft", cpu_value),
@ -811,6 +819,7 @@ filegroup(name="cudnn-include")
        "cuda/cuda/lib/%s" % lib_name("cudart_static", cpu_value),
    )
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublas", cpu_value))
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublasLt", cpu_value))
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusolver", cpu_value))
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudnn", cpu_value))
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("curand", cpu_value))
@ -1002,11 +1011,13 @@ def _create_local_cuda_repository(repository_ctx):
            cublas_include_path + "/cublas.h",
            cublas_include_path + "/cublas_v2.h",
            cublas_include_path + "/cublas_api.h",
            cublas_include_path + "/cublasLt.h",
        ],
        outs = [
            "cublas/include/cublas.h",
            "cublas/include/cublas_v2.h",
            "cublas/include/cublas_api.h",
            "cublas/include/cublasLt.h",
        ],
    ))

@ -1147,6 +1158,7 @@ def _create_local_cuda_repository(repository_ctx):
        "%{cudart_static_linkopt}": _cudart_static_linkopt(cuda_config.cpu_value),
        "%{cudart_lib}": _basename(repository_ctx, cuda_libs["cudart"]),
        "%{cublas_lib}": _basename(repository_ctx, cuda_libs["cublas"]),
        "%{cublasLt_lib}": _basename(repository_ctx, cuda_libs["cublasLt"]),
        "%{cusolver_lib}": _basename(repository_ctx, cuda_libs["cusolver"]),
        "%{cudnn_lib}": _basename(repository_ctx, cuda_libs["cudnn"]),
        "%{cufft_lib}": _basename(repository_ctx, cuda_libs["cufft"]),