Merge pull request from benbarsdell:cublaslt

PiperOrigin-RevId: 337382541
Change-Id: I949698ec93cb3c15654857768fcfce53984a97be
TensorFlower Gardener 2020-10-15 14:39:38 -07:00
commit 6859f52a3f
35 changed files with 2361 additions and 297 deletions


@ -3347,6 +3347,9 @@ tf_kernel_library(
prefix = "batch_matmul_op",
deps = MATH_DEPS + [":eigen_contraction_kernel"] + if_mkl_ml([
"//third_party/mkl:intel_binary_blob",
]) + if_cuda_or_rocm([
"//tensorflow/core/kernels:gpu_utils",
"//tensorflow/core/platform:tensor_float_32_utils",
]),
)
@ -3404,7 +3407,7 @@ tf_kernel_library(
name = "fft_ops",
prefix = "fft_ops",
deps = MATH_DEPS + [
] + if_cuda([
] + if_cuda_or_rocm([":gpu_utils"]) + if_cuda([
"//tensorflow/core/platform/default/build_config:cufft_plugin",
]),
)


@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/matmul_autotune.h"
#include "tensorflow/core/util/matmul_bcast.h"
#include "tensorflow/core/util/work_sharder.h"
@ -43,8 +44,13 @@ limitations under the License.
#endif
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h" // For CUDA_VERSION
#endif
namespace tensorflow {
@ -219,7 +225,8 @@ template <typename Scalar>
struct LaunchBatchMatMul<CPUDevice, Scalar> {
static void Launch(OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
bool trans_y, const MatMulBCast& bcast, Tensor* out) {
bool trans_y, const MatMulBCast& bcast, bool use_autotune,
Tensor* out) {
typedef ParallelMatMulKernel<Scalar, Eigen::NumTraits<Scalar>::IsComplex>
ParallelMatMulKernel;
bool conjugate_result = false;
@ -275,45 +282,212 @@ se::DeviceMemory<T> AsDeviceMemory(const T* gpu_memory) {
return typed;
}
class BlasScratchAllocator : public se::ScratchAllocator {
using BlasScratchAllocator = GpuScratchAllocator;
int64 GetBlasWorkspaceLimit(const string& envvar_in_mb,
int64 default_value_in_bytes) {
return gpu_utils::GetWorkspaceLimit(envvar_in_mb, default_value_in_bytes);
}
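// Illustrative sketch (not part of this change): GetBlasWorkspaceLimit reads a
// megabyte count from the named environment variable and converts it to bytes,
// falling back to the default when the variable is unset or unparsable.
//
//   setenv("TF_CUBLAS_WORKSPACE_LIMIT_IN_MB", "64", /*overwrite=*/1);
//   int64 limit = GetBlasWorkspaceLimit("TF_CUBLAS_WORKSPACE_LIMIT_IN_MB",
//                                       1LL << 32);
//   // limit == 64 * (1 << 20); with the variable unset it would be the
//   // 1LL << 32 (4GB) default.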
// Encapsulate all of the shape, dtype etc. information that defines a unique
// batched matmul operation.
class BatchMatmulParameters {
public:
using Stream = se::Stream;
using DeviceMemoryBytes = se::DeviceMemory<uint8>;
BatchMatmulParameters(bool trans_a, bool trans_b, bool adj_a, bool adj_b,
uint64 m, uint64 n, uint64 k, uint64 batch_count,
bool broadcast_a, bool broadcast_b, DataType dtype_ab,
DataType dtype_cd, bool allow_tf32, int device_id)
: trans_a_(trans_a),
trans_b_(trans_b),
adj_a_(adj_a),
adj_b_(adj_b),
m_(m),
n_(n),
k_(k),
batch_count_(batch_count),
broadcast_a_(broadcast_a),
broadcast_b_(broadcast_b),
dtype_ab_(dtype_ab),
dtype_cd_(dtype_cd),
allow_tf32_(allow_tf32),
device_id_(device_id) {
hash_code_ = trans_a;
hash_code_ = Hash64Combine(hash_code_, trans_b);
hash_code_ = Hash64Combine(hash_code_, adj_a);
hash_code_ = Hash64Combine(hash_code_, adj_b);
hash_code_ = Hash64Combine(hash_code_, m);
hash_code_ = Hash64Combine(hash_code_, n);
hash_code_ = Hash64Combine(hash_code_, k);
hash_code_ = Hash64Combine(hash_code_, batch_count);
hash_code_ = Hash64Combine(hash_code_, broadcast_a);
hash_code_ = Hash64Combine(hash_code_, broadcast_b);
hash_code_ = Hash64Combine(hash_code_, dtype_ab);
hash_code_ = Hash64Combine(hash_code_, dtype_cd);
hash_code_ = Hash64Combine(hash_code_, allow_tf32);
hash_code_ = Hash64Combine(hash_code_, device_id);
}
bool operator==(const BatchMatmulParameters& other) const {
return this->get_data_as_tuple() == other.get_data_as_tuple();
}
BlasScratchAllocator(OpKernelContext* context) : context_(context) {}
bool operator!=(const BatchMatmulParameters& other) const {
return !(*this == other);
}
uint64 hash() const { return hash_code_; }
int64 GetMemoryLimitInBytes() override { return -1; }
se::port::StatusOr<DeviceMemoryBytes> AllocateBytes(
int64 byte_size) override {
Tensor temporary_memory;
Status allocation_status(context_->allocate_temp(
DT_UINT8, TensorShape({byte_size}), &temporary_memory));
if (!allocation_status.ok()) {
return se::port::StatusOr<DeviceMemoryBytes>(
DeviceMemoryBytes::MakeFromByteSize(nullptr, 0));
}
// Hold the reference of the allocated tensors until the end of the
// allocator.
allocated_tensors_.push_back(temporary_memory);
return se::port::StatusOr<DeviceMemoryBytes>(
DeviceMemoryBytes::MakeFromByteSize(
temporary_memory.flat<uint8>().data(),
temporary_memory.flat<uint8>().size()));
string ToString() const {
// clang-format off
return strings::StrCat(
trans_a_, ", ", trans_b_, ", ", adj_a_, ", ", adj_b_, ", ",
m_, ", ", n_, ", ", k_, ", ", batch_count_, ", ",
broadcast_a_, ", ", broadcast_b_, ", ",
dtype_ab_, ", ", dtype_cd_, ", ", allow_tf32_, ", ", device_id_);
// clang-format on
}
private:
OpKernelContext* context_;
std::vector<Tensor> allocated_tensors_;
typedef std::tuple<bool, bool, bool, bool, int64, int64, int64, int64, bool,
bool, DataType, DataType, bool, int>
ParameterDataType;
ParameterDataType get_data_as_tuple() const {
return std::make_tuple(trans_a_, trans_b_, adj_a_, adj_b_, m_, n_, k_,
batch_count_, broadcast_a_, broadcast_b_, dtype_ab_,
dtype_cd_, allow_tf32_, device_id_);
}
bool trans_a_;
bool trans_b_;
bool adj_a_;
bool adj_b_;
uint64 m_;
uint64 n_;
uint64 k_;
uint64 batch_count_;
bool broadcast_a_;
bool broadcast_b_;
DataType dtype_ab_;
DataType dtype_cd_;
bool allow_tf32_;
int device_id_;
uint64 hash_code_;
};
bool GetBlasComputationType(const DataType& dtype, bool allow_tf32,
se::blas::ComputationType* compute_type) {
using se::blas::ComputationType;
static bool use_f32_for_f16_computation = MatmulDoFP32ComputationFP16Input();
ComputationType f32_type =
allow_tf32 ? ComputationType::kTF32AsF32 : ComputationType::kF32;
switch (dtype) {
case DT_HALF:
case DT_BFLOAT16:
*compute_type =
use_f32_for_f16_computation ? f32_type : ComputationType::kF16;
return true;
case DT_FLOAT:
*compute_type = f32_type;
return true;
case DT_DOUBLE:
*compute_type = ComputationType::kF64;
return true;
case DT_COMPLEX64:
*compute_type = f32_type;
return true;
case DT_COMPLEX128:
*compute_type = ComputationType::kComplexF64;
return true;
default:
// Unsupported compute_type, return false.
return false;
}
}
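// Illustrative sketch (not part of this change) of the mapping above: with
// TF32 execution enabled, float and complex64 select kTF32AsF32, double always
// selects kF64, and half/bfloat16 select either the f32 type or kF16 depending
// on MatmulDoFP32ComputationFP16Input().
//
//   se::blas::ComputationType compute_type;
//   bool ok = GetBlasComputationType(DT_FLOAT, /*allow_tf32=*/true,
//                                    &compute_type);
//   // ok == true, compute_type == se::blas::ComputationType::kTF32AsF32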
// Thread-safe map from matmul parameters to their corresponding plan and
// algorithms.
template <typename Parameters>
class BlasLtMatmulPlanMap {
public:
struct PlanAndAlgorithms {
std::unique_ptr<se::blas::IBlasLtMatmulPlan> plan;
std::vector<std::unique_ptr<se::blas::IBlasLtMatmulAlgorithm>> algorithms;
};
const PlanAndAlgorithms* Find(const Parameters& params) {
mutex_lock lock(mu_);
auto iter = params_plan_map_.find(params);
if (iter == params_plan_map_.end()) {
return nullptr;
}
return &iter->second;
}
const PlanAndAlgorithms* Insert(const Parameters& params,
PlanAndAlgorithms value) {
mutex_lock lock(mu_);
return &params_plan_map_.emplace(params, std::move(value)).first->second;
}
private:
struct Hasher {
std::size_t operator()(const Parameters& parameter) const {
return parameter.hash();
}
};
mutable mutex mu_;
std::unordered_map<Parameters, PlanAndAlgorithms, Hasher> params_plan_map_
GUARDED_BY(mu_);
};
template <typename Parameters>
struct BlasLtPlanMapSingleton {
typedef BlasLtMatmulPlanMap<Parameters> PlanMapType;
static PlanMapType* GetInstance() {
static PlanMapType* instance = new PlanMapType();
return instance;
}
};
typedef BlasLtPlanMapSingleton<BatchMatmulParameters>
BatchMatmulPlanMapSingleton;
// A dummy type to group matmul autotune results together.
struct BatchMatmulAutoTuneGroup {
static string name() { return "MatmulLt"; }
};
typedef AutoTuneSingleton<BatchMatmulAutoTuneGroup, BatchMatmulParameters,
se::blas::AlgorithmConfig>
AutoTuneBatchMatmul;
template <typename Scalar>
struct CoefficientType {
typedef Scalar type;
};
template <>
struct CoefficientType<Eigen::half> {
typedef float type;
};
inline Status FromExecutorStatus(const se::port::Status& s) {
return s.ok() ? Status::OK()
: Status(static_cast<error::Code>(static_cast<int>(s.code())),
s.error_message());
}
template <typename T>
inline Status FromExecutorStatus(const se::port::StatusOr<T>& s) {
return FromExecutorStatus(s.status());
}
} // namespace
template <typename Scalar>
struct LaunchBatchMatMul<GPUDevice, Scalar> {
static void Launch(OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
bool trans_y, const MatMulBCast& bcast, Tensor* out) {
bool trans_y, const MatMulBCast& bcast, bool use_autotune,
Tensor* out) {
se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
se::blas::Transpose::kTranspose,
se::blas::Transpose::kConjugateTranspose};
@ -343,10 +517,191 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
auto* a_base_ptr = in_x.template flat<Scalar>().data();
auto* b_base_ptr = in_y.template flat<Scalar>().data();
auto* c_base_ptr = out->template flat<Scalar>().data();
uint64 a_stride;
uint64 b_stride;
uint64 c_stride;
int64 a_stride;
int64 b_stride;
int64 c_stride;
typedef typename CoefficientType<Scalar>::type Coefficient;
static const int64 max_scratch_size = GetBlasWorkspaceLimit(
"TF_CUBLAS_WORKSPACE_LIMIT_IN_MB", 1LL << 32); // 4GB by default
// The BlasLtMatmul routines are only supported from CUDA 11.0 onward.
#if GOOGLE_CUDA && CUDA_VERSION >= 11000
bool is_full_broadcast =
std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
bool requires_mixed_broadcasting =
bcast.IsBroadcastingRequired() && !is_full_broadcast;
if (!requires_mixed_broadcasting) {
bool broadcast_a = bcast.x_batch_size() == 1;
bool broadcast_b = bcast.y_batch_size() == 1;
a_stride = broadcast_a ? 0 : m * k;
b_stride = broadcast_b ? 0 : k * n;
c_stride = m * n;
a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
a_ptrs.push_back(&a_device_memory.back());
b_ptrs.push_back(&b_device_memory.back());
c_ptrs.push_back(&c_device_memory.back());
DataType dtype = DataTypeToEnum<Scalar>::value;
bool allow_tf32 = tensor_float_32_execution_enabled();
int device_id = stream->parent()->device_ordinal();
BatchMatmulParameters matmul_parameters(
trans_x, trans_y, adj_x, adj_y, m, n, k, batch_size, broadcast_a,
broadcast_b, dtype, dtype, allow_tf32, device_id);
static const int64 max_autotune_algorithm_count =
MatmulMaxAutotuneAlgorithmCount();
int max_algorithm_count = use_autotune ? max_autotune_algorithm_count : 1;
const auto* plan_and_algorithms =
BatchMatmulPlanMapSingleton::GetInstance()->Find(matmul_parameters);
if (!plan_and_algorithms) {
se::blas::DataType blas_dtype = se::blas::ToDataType<Scalar>::value;
se::blas::ComputationType computation_type;
OP_REQUIRES(
context,
GetBlasComputationType(dtype, allow_tf32, &computation_type),
errors::Internal("Unsupported dtype for batched matmul"));
auto status_or_plan = stream->parent()->CreateBlasLtMatmulPlan(
{/*ab_type=*/blas_dtype,
/*c_type=*/blas_dtype, computation_type,
se::blas::PointerMode::kHost, se::blas::Epilogue::kDefault,
blas_transpose_b, blas_transpose_a, n, m, k,
/*lda=*/in_y.dim_size(2), /*ldb=*/in_x.dim_size(2),
/*ldc=*/static_cast<int64>(n), static_cast<int>(batch_size),
b_stride, a_stride, c_stride});
OP_REQUIRES(context, status_or_plan.ok(),
FromExecutorStatus(status_or_plan));
std::unique_ptr<se::blas::IBlasLtMatmulPlan> plan =
status_or_plan.ConsumeValueOrDie();
auto status_or_algorithms = stream->parent()->GetBlasLtMatmulAlgorithms(
plan.get(), max_scratch_size, max_algorithm_count);
OP_REQUIRES(context, status_or_algorithms.ok(),
FromExecutorStatus(status_or_algorithms));
auto algorithms = status_or_algorithms.ConsumeValueOrDie();
plan_and_algorithms =
BatchMatmulPlanMapSingleton::GetInstance()->Insert(
matmul_parameters, {std::move(plan), std::move(algorithms)});
}
const auto& plan = plan_and_algorithms->plan;
const auto& algorithms = plan_and_algorithms->algorithms;
// The BlasLtMatmul routines (unlike BlasGemm, BlasGemmBatched etc.) take
// alpha and beta with the same type as the matrices.
Scalar alpha(1.0);
Scalar beta(0.0);
// Note that algorithm_config.algorithm() here is used to refer
// to the index within the algorithms vector, not the algorithm
// itself.
se::blas::AlgorithmConfig algorithm_config(se::blas::kNoAlgorithm);
if (max_algorithm_count == 1) {
algorithm_config.set_algorithm(0);
} else if (!AutoTuneBatchMatmul::GetInstance()->Find(matmul_parameters,
&algorithm_config)) {
VLOG(4) << "Autotuning BlasLtMatmul over " << algorithms.size()
<< " algorithms.";
se::blas::ProfileResult best_result;
se::blas::ProfileResult profile_result;
// for (const auto& profile_algorithm : plan_and_algorithms->algorithms)
// {
for (size_t i = 0; i != algorithms.size(); ++i) {
const auto& profile_algorithm = algorithms[i];
// Create a new scratch allocator with every autotuning run so that
// scratch space is deallocated between runs.
BlasScratchAllocator scratch_allocator(max_scratch_size, context);
bool cublas_launch_status =
stream
->ThenBlasLtMatmul(plan.get(), alpha, *b_ptrs[0], *a_ptrs[0],
beta, c_ptrs[0], &scratch_allocator,
profile_algorithm.get(), {},
&profile_result)
.ok();
VLOG(4) << " Autotune algorithm " << i
<< " result: " << profile_result.elapsed_time_in_ms()
<< " ms, valid=" << profile_result.is_valid()
<< ", workspace_size=" << profile_algorithm->workspace_size();
if (cublas_launch_status && profile_result.is_valid() &&
profile_result.elapsed_time_in_ms() <
best_result.elapsed_time_in_ms()) {
best_result = profile_result;
}
}
if (best_result.is_valid()) {
algorithm_config.set_algorithm(best_result.algorithm());
}
// We make sure that each matmul parameter set gets only one pass of
// autotuning. If no algorithm works, we add kNoAlgorithm to the autotune
// map.
AutoTuneBatchMatmul::GetInstance()->Insert(matmul_parameters,
algorithm_config);
}
se::blas::AlgorithmType algorithm_idx = algorithm_config.algorithm();
OP_REQUIRES(context,
0 <= algorithm_idx && algorithm_idx < algorithms.size(),
errors::Internal("Missing/invalid BatchMatmul algorithm"));
const auto& algorithm = algorithms[algorithm_idx];
BlasScratchAllocator scratch_allocator(max_scratch_size, context);
bool cublas_launch_status =
stream
->ThenBlasLtMatmul(plan.get(), alpha, *b_ptrs[0], *a_ptrs[0],
beta, c_ptrs[0], &scratch_allocator,
algorithm.get())
.ok();
if (!cublas_launch_status) {
context->SetStatus(errors::Internal(
"Blas batched matmul launch failed : a.shape=(",
bcast.x_batch_size(), ", ", in_x.dim_size(0), ", ",
in_x.dim_size(1), "), b.shape=(", bcast.y_batch_size(), ", ",
in_y.dim_size(0), ", ", in_y.dim_size(1), "), m=", m, ", n=", n,
", k=", k, ", batch_size=", batch_size));
}
} else { // requires mixed broadcasting
const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
for (int64 i = 0; i < bcast.x_batch_size(); ++i) {
a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
}
for (int64 i = 0; i < bcast.y_batch_size(); ++i) {
b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
}
for (int64 i = 0; i < batch_size; ++i) {
c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
c_ptrs.push_back(&c_device_memory.back());
}
BlasScratchAllocator scratch_allocator(max_scratch_size, context);
bool blas_launch_status =
stream
->ThenBlasGemmBatchedWithScratch(
blas_transpose_b, blas_transpose_a, n, m, k,
static_cast<Coefficient>(1.0), b_ptrs,
adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
&scratch_allocator)
.ok();
if (!blas_launch_status) {
context->SetStatus(errors::Internal(
"Blas xGEMMBatched launch failed : a.shape=",
in_x.shape().DebugString(),
", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
", k=", k, ", batch_size=", batch_size));
}
}
return;
#else // if not GOOGLE_CUDA or CUDA_VERSION < 11000
bool is_full_broadcast =
std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
bool use_strided_batched =
@ -388,8 +743,6 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
}
}
typedef Scalar Coefficient;
// Blas does
// C = A x B
// where A, B and C are assumed to be in column major.
@ -399,7 +752,10 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
if (batch_size == 1) {
// This is a regular matrix*matrix or matrix*vector multiply. Avoid the
// overhead of the scratch allocator and the batch interface.
if (n == 1 &&
// Note that the GEMV call here does not support Eigen::half, so we do not
// use this path in that case. A workaround is applied to the pointers
// passed to the call itself to avoid compilation errors.
if (!std::is_same<Scalar, Eigen::half>::value && n == 1 &&
blas_transpose_b != se::blas::Transpose::kConjugateTranspose &&
blas_transpose_a != se::blas::Transpose::kConjugateTranspose) {
// This is a matrix*vector multiply so use GEMV to compute A * b.
@ -410,13 +766,19 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
auto gemv_trans_a = blas_transpose_a == se::blas::Transpose::kTranspose
? se::blas::Transpose::kNoTranspose
: se::blas::Transpose::kTranspose;
// Cast pointers as a workaround for GEMV not supporting Eigen::half
// (this will never actually be executed for Eigen::half).
typedef se::DeviceMemory<Coefficient> NonHalfDeviceMemoryType;
NonHalfDeviceMemoryType a_ptr(*(a_ptrs[0]));
NonHalfDeviceMemoryType b_ptr(*(b_ptrs[0]));
NonHalfDeviceMemoryType c_ptr(*(c_ptrs[0]));
bool blas_launch_status =
stream
->ThenBlasGemv(gemv_trans_a, adj_x || trans_x ? m : k,
adj_x || trans_x ? k : m,
static_cast<Coefficient>(1.0), *(a_ptrs[0]),
adj_x || trans_x ? m : k, *(b_ptrs[0]), 1,
static_cast<Coefficient>(0.0), c_ptrs[0], 1)
static_cast<Coefficient>(1.0), a_ptr,
adj_x || trans_x ? m : k, b_ptr, 1,
static_cast<Coefficient>(0.0), &c_ptr, 1)
.ok();
if (!blas_launch_status) {
context->SetStatus(errors::Internal(
@ -459,154 +821,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
", k=", k, ", batch_size=", batch_size));
}
} else {
BlasScratchAllocator scratch_allocator(context);
bool blas_launch_status =
stream
->ThenBlasGemmBatchedWithScratch(
blas_transpose_b, blas_transpose_a, n, m, k,
static_cast<Coefficient>(1.0), b_ptrs,
adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
&scratch_allocator)
.ok();
if (!blas_launch_status) {
context->SetStatus(errors::Internal(
"Blas xGEMMBatched launch failed : a.shape=",
in_x.shape().DebugString(),
", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
", k=", k, ", batch_size=", batch_size));
}
}
}
};
template <>
struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
static void Launch(OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
bool trans_y, const MatMulBCast& bcast, Tensor* out) {
typedef Eigen::half Scalar;
se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
se::blas::Transpose::kTranspose,
se::blas::Transpose::kConjugateTranspose};
const uint64 m = in_x.dim_size(adj_x || trans_x ? 2 : 1);
const uint64 k = in_x.dim_size(adj_x || trans_x ? 1 : 2);
const uint64 n = in_y.dim_size(adj_y || trans_y ? 1 : 2);
const uint64 batch_size = bcast.output_batch_size();
auto blas_transpose_a = trans[adj_x ? 2 : (trans_x ? 1 : 0)];
auto blas_transpose_b = trans[adj_y ? 2 : (trans_y ? 1 : 0)];
auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
typedef perftools::gputools::DeviceMemory<Scalar> DeviceMemoryType;
std::vector<DeviceMemoryType> a_device_memory;
std::vector<DeviceMemoryType> b_device_memory;
std::vector<DeviceMemoryType> c_device_memory;
std::vector<DeviceMemoryType*> a_ptrs;
std::vector<DeviceMemoryType*> b_ptrs;
std::vector<DeviceMemoryType*> c_ptrs;
a_device_memory.reserve(bcast.x_batch_size());
b_device_memory.reserve(bcast.y_batch_size());
c_device_memory.reserve(batch_size);
a_ptrs.reserve(batch_size);
b_ptrs.reserve(batch_size);
c_ptrs.reserve(batch_size);
auto* a_base_ptr = in_x.template flat<Scalar>().data();
auto* b_base_ptr = in_y.template flat<Scalar>().data();
auto* c_base_ptr = out->template flat<Scalar>().data();
uint64 a_stride;
uint64 b_stride;
uint64 c_stride;
bool is_full_broadcast =
std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
bool use_strided_batched =
(!bcast.IsBroadcastingRequired() || is_full_broadcast) &&
batch_size > 1;
if (use_strided_batched) {
a_stride = bcast.x_batch_size() != 1 ? m * k : 0;
b_stride = bcast.y_batch_size() != 1 ? k * n : 0;
c_stride = m * n;
a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
a_ptrs.push_back(&a_device_memory.back());
b_ptrs.push_back(&b_device_memory.back());
c_ptrs.push_back(&c_device_memory.back());
} else if (!bcast.IsBroadcastingRequired()) {
for (int64 i = 0; i < batch_size; ++i) {
a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
a_ptrs.push_back(&a_device_memory.back());
b_ptrs.push_back(&b_device_memory.back());
c_ptrs.push_back(&c_device_memory.back());
}
} else {
const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
for (int64 i = 0; i < bcast.x_batch_size(); ++i) {
a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
}
for (int64 i = 0; i < bcast.y_batch_size(); ++i) {
b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
}
for (int64 i = 0; i < batch_size; ++i) {
c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
c_ptrs.push_back(&c_device_memory.back());
}
}
typedef float Coefficient;
// Blas does
// C = A x B
// where A, B and C are assumed to be in column major.
// We want the output to be in row-major, so we can compute
// C' = B' x A', where ' stands for transpose (not adjoint).
// TODO(yangzihao): Choose the best of the three strategies using autotune.
if (batch_size == 1) {
// This is a regular matrix*matrix or matrix*vector multiply. Avoid the
// overhead of the scratch allocator and the batch interface.
// TODO(benbarsdell): Use fp16 Gemv if it becomes supported by CUBLAS
bool blas_launch_status =
stream
->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
static_cast<Coefficient>(1.0), *(b_ptrs[0]),
adj_y || trans_y ? k : n, *(a_ptrs[0]),
adj_x || trans_x ? m : k,
static_cast<Coefficient>(0.0), c_ptrs[0], n)
.ok();
if (!blas_launch_status) {
context->SetStatus(errors::Internal(
"Blas xGEMM launch failed : a.shape=", in_x.shape().DebugString(),
", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
", k=", k));
}
} else if (use_strided_batched) {
bool blas_launch_status =
stream
->ThenBlasGemmStridedBatched(
blas_transpose_b, blas_transpose_a, n, m, k,
static_cast<Coefficient>(1.0), *b_ptrs[0],
adj_y || trans_y ? k : n, b_stride, *a_ptrs[0],
adj_x || trans_x ? m : k, a_stride,
static_cast<Coefficient>(0.0), c_ptrs[0], n, c_stride,
batch_size)
.ok();
if (!blas_launch_status) {
context->SetStatus(errors::Internal(
"Blas xGEMMStridedBatched launch failed : a.shape=",
in_x.shape().DebugString(),
", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
", k=", k, ", batch_size=", batch_size));
}
} else {
BlasScratchAllocator scratch_allocator(context);
BlasScratchAllocator scratch_allocator(max_scratch_size, context);
bool blas_launch_status =
stream
->ThenBlasGemmBatchedWithScratch(
@ -624,6 +839,7 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
", k=", k, ", batch_size=", batch_size));
}
}
#endif // not GOOGLE_CUDA or CUDA_VERSION < 11000
}
};
@ -637,6 +853,7 @@ class BaseBatchMatMulOp : public OpKernel {
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
use_autotune_ = MatmulAutotuneEnable();
}
~BaseBatchMatMulOp() override {}
@ -698,7 +915,7 @@ class BaseBatchMatMulOp : public OpKernel {
out->shape().DebugString()));
LaunchBatchMatMul<Device, Scalar>::Launch(
ctx, in0_reshaped, in1_reshaped, adj_x_, adj_y_, /*trans_x=*/false,
/*trans_y=*/false, bcast, &out_reshaped);
/*trans_y=*/false, bcast, use_autotune_, &out_reshaped);
}
protected:
@ -708,6 +925,7 @@ class BaseBatchMatMulOp : public OpKernel {
private:
bool adj_x_;
bool adj_y_;
bool use_autotune_;
};
// BatchMatMul Op implementation which disallows broadcasting.


@ -619,19 +619,7 @@ template struct LaunchConv2DOp<CPUDevice, double>;
int64 GetDnnWorkspaceLimit(const string& envvar_in_mb,
int64 default_value_in_bytes) {
const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
if (workspace_limit_in_mb_str != nullptr &&
strcmp(workspace_limit_in_mb_str, "") != 0) {
int64 scratch_limit_in_mb = -1;
if (strings::safe_strto64(workspace_limit_in_mb_str,
&scratch_limit_in_mb)) {
return scratch_limit_in_mb * (1 << 20);
} else {
LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
<< workspace_limit_in_mb_str;
}
}
return default_value_in_bytes;
return gpu_utils::GetWorkspaceLimit(envvar_in_mb, default_value_in_bytes);
}
// A dummy type to group forward convolution autotune results together.


@ -48,52 +48,7 @@ int64 GetDnnWorkspaceLimit(const string& envvar_in_mb,
// A class to provide scratch-space allocator for Stream-Executor Cudnn
// callback. TensorFlow is responsible for releasing the temporary buffers after
// the kernel finishes.
class DnnScratchAllocator : public se::ScratchAllocator {
public:
virtual ~DnnScratchAllocator() {}
DnnScratchAllocator(int64 memory_limit, OpKernelContext* context)
: memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}
int64 GetMemoryLimitInBytes() override { return memory_limit_; }
se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes(
int64 byte_size) override {
Tensor temporary_memory;
if (byte_size < 0) {
return se::port::Status{se::port::error::INVALID_ARGUMENT,
"Requested negative byte size!"};
}
if (byte_size > memory_limit_) {
return se::port::Status{se::port::error::UNAVAILABLE,
absl::StrCat("Requested memory size (", byte_size,
") exceeds the max memory limit (",
memory_limit_, ").")};
}
AllocationAttributes allocation_attr;
allocation_attr.retry_on_failure = false;
Status allocation_status(context_->allocate_temp(
DT_UINT8, TensorShape({byte_size}), &temporary_memory,
AllocatorAttributes(), allocation_attr));
if (!allocation_status.ok()) {
return se::port::Status{
se::port::error::UNAVAILABLE,
absl::StrCat("Failed to allocate the requested memory size (",
byte_size, ").")};
}
// Hold the reference of the allocated tensors until the end of the
// allocator.
allocated_tensors_.push_back(temporary_memory);
total_byte_size_ += byte_size;
return se::port::StatusOr<se::DeviceMemory<uint8>>(
AsDeviceMemory(temporary_memory.flat<uint8>().data(),
temporary_memory.flat<uint8>().size()));
}
int64 TotalByteSize() { return total_byte_size_; }
private:
int64 memory_limit_;
int64 total_byte_size_;
OpKernelContext* context_;
std::vector<Tensor> allocated_tensors_;
};
using DnnScratchAllocator = GpuScratchAllocator;
// Encapsulate all the shape information that is used in both forward and
// backward conv operations.


@ -31,6 +31,7 @@ limitations under the License.
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
(defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@ -400,20 +401,7 @@ class CufftScratchAllocator : public se::ScratchAllocator {
int64 GetCufftWorkspaceLimit(const string& envvar_in_mb,
int64 default_value_in_bytes) {
const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
if (workspace_limit_in_mb_str != nullptr &&
strcmp(workspace_limit_in_mb_str, "") != 0) {
int64 scratch_limit_in_mb = -1;
Status status = ReadInt64FromEnvVar(envvar_in_mb, default_value_in_bytes,
&scratch_limit_in_mb);
if (!status.ok()) {
LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
<< workspace_limit_in_mb_str;
} else {
return scratch_limit_in_mb * (1 << 20);
}
}
return default_value_in_bytes;
return gpu_utils::GetWorkspaceLimit(envvar_in_mb, default_value_in_bytes);
}
class FFTGPUBase : public FFTBase {


@ -22,6 +22,7 @@ limitations under the License.
#include "google/protobuf/any.pb.h"
#include "absl/algorithm/container.h"
#include "absl/base/call_once.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/logger.h"
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/core/protobuf/conv_autotuning.pb.h"
@ -282,6 +283,64 @@ Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
return Status::OK();
}
namespace gpu_utils {
int64 GetWorkspaceLimit(const string& envvar_in_mb,
int64 default_value_in_bytes) {
const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
if (workspace_limit_in_mb_str != nullptr &&
strcmp(workspace_limit_in_mb_str, "") != 0) {
int64 scratch_limit_in_mb = -1;
if (strings::safe_strto64(workspace_limit_in_mb_str,
&scratch_limit_in_mb)) {
return scratch_limit_in_mb * (1 << 20);
} else {
LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
<< workspace_limit_in_mb_str;
}
}
return default_value_in_bytes;
}
} // namespace gpu_utils
GpuScratchAllocator::GpuScratchAllocator(int64 memory_limit,
OpKernelContext* context)
: memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}
se::port::StatusOr<se::DeviceMemory<uint8>> GpuScratchAllocator::AllocateBytes(
int64 byte_size) {
Tensor temporary_memory;
if (byte_size < 0) {
return se::port::Status{se::port::error::INVALID_ARGUMENT,
"Requested negative byte size!"};
}
if (byte_size > memory_limit_) {
return se::port::Status{
se::port::error::UNAVAILABLE,
absl::StrCat("Requested memory size (", byte_size,
") exceeds the max memory limit (", memory_limit_, ").")};
}
AllocationAttributes allocation_attr;
allocation_attr.retry_on_failure = false;
Status allocation_status(context_->allocate_temp(
DT_UINT8, TensorShape({byte_size}), &temporary_memory,
AllocatorAttributes(), allocation_attr));
if (!allocation_status.ok()) {
return se::port::Status{
se::port::error::UNAVAILABLE,
absl::StrCat("Failed to allocate the requested memory size (",
byte_size, ").")};
}
// Hold the reference of the allocated tensors until the end of the
// allocator.
// NOTE: We expect the tensors to be deallocated when this allocator goes out
// of scope and allocated_tensors_ is destructed.
allocated_tensors_.push_back(temporary_memory);
total_byte_size_ += byte_size;
return se::port::StatusOr<se::DeviceMemory<uint8>>(
AsDeviceMemory(temporary_memory.flat<uint8>().data(),
temporary_memory.flat<uint8>().size()));
}
} // namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM


@ -243,6 +243,37 @@ void LogFusedConvForwardAutotuneResults(
Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
se::dnn::AlgorithmConfig* algo);
namespace gpu_utils {
// Gets the workspace memory limit in bytes from the given environment
// variable, which is expressed in MB. If the variable is not set, returns the
// given default value.
int64 GetWorkspaceLimit(const string& envvar_in_mb,
int64 default_value_in_bytes);
} // namespace gpu_utils
// A class that provides a scratch-space allocator for Stream-Executor
// callbacks in CUDA libraries (cuDNN, cuBLAS, etc.).
// TensorFlow is responsible for releasing the temporary buffers after
// the kernel finishes.
class GpuScratchAllocator : public se::ScratchAllocator {
public:
virtual ~GpuScratchAllocator() {}
GpuScratchAllocator(int64 memory_limit, OpKernelContext* context);
int64 GetMemoryLimitInBytes() override { return memory_limit_; }
se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes(
int64 byte_size) override;
int64 TotalByteSize() { return total_byte_size_; }
private:
int64 memory_limit_;
int64 total_byte_size_;
OpKernelContext* context_;
std::vector<Tensor> allocated_tensors_;
};
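// Illustrative usage sketch (assumed call site; "ctx" stands for the enclosing
// OpKernelContext*): construct the allocator with a workspace limit and pass it
// to a Stream call that may need temporary device memory.
//
//   int64 limit = gpu_utils::GetWorkspaceLimit(
//       "TF_CUBLAS_WORKSPACE_LIMIT_IN_MB", 1LL << 32);  // 4GB by default
//   GpuScratchAllocator scratch_allocator(limit, ctx);
//   stream->ThenBlasGemmBatchedWithScratch(..., &scratch_allocator);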
} // namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM


@ -549,7 +549,7 @@ struct EinsumHelper {
static Status ContractOperands(OpKernelContext* ctx,
absl::Span<const Tensor> inputs,
absl::Span<const bool> swap_free_and_contract,
Tensor* output) {
bool use_autotune, Tensor* output) {
if (inputs.size() == 1)
return CopyFrom(inputs[0], inputs[0].shape(), output);
MatMulBCast bcast(inputs[0].shape().dim_sizes(),
@ -583,7 +583,7 @@ struct EinsumHelper {
ReshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped));
LaunchBatchMatMul<Device, T>::Launch(ctx, lhs, rhs, /*adj_x=*/false,
/*adj_y=*/false, trans_x, trans_y,
bcast, &output_reshaped);
bcast, use_autotune, &output_reshaped);
return Status::OK();
}
};
@ -598,6 +598,7 @@ class EinsumOp : public OpKernel {
equation_, &input_labels_, &output_labels_, &label_types_,
&input_label_counts_, &output_label_counts_,
&input_has_ellipsis_, &output_has_ellipsis_));
use_autotune_ = MatmulAutotuneEnable();
}
void Compute(OpKernelContext* ctx) override {
@ -640,7 +641,7 @@ class EinsumOp : public OpKernel {
Tensor contraction_output_reshaped;
OP_REQUIRES_OK(ctx, EinsumHelper::ContractOperands<Device, T>(
ctx, inputs_reduced, swap_free_and_contract,
&contraction_output_reshaped));
use_autotune_, &contraction_output_reshaped));
// Copy the batch labels from the contraction output. Recover the batch
// shape, which may have been broadcasted.
@ -738,6 +739,7 @@ class EinsumOp : public OpKernel {
LabelCounts output_label_counts_;
gtl::InlinedVector<bool, 2> input_has_ellipsis_;
bool output_has_ellipsis_ = false;
bool use_autotune_;
};
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM


@ -48,4 +48,22 @@ bool MatmulDoFP32ComputationFP16Input() {
return value;
}
int MatmulMaxAutotuneAlgorithmCount() {
int64 value;
// In CUDA 11, cublasLtMatmulAlgoGetHeuristic typically returns <= 4
// algorithms for a given configuration, so 10 seems like a reasonable default
// here.
Status status =
ReadInt64FromEnvVar("TF_MATMUL_AUTOTUNE_MAX_ALGORITHMS", 10, &value);
if (!status.ok()) {
LOG(ERROR) << status.error_message();
}
static constexpr const int kMaxValue = std::numeric_limits<int>::max();
if (value < 1 || value > kMaxValue) {
LOG(ERROR) << "Invalid value for TF_MATMUL_AUTOTUNE_MAX_ALGORITHMS: "
<< value << " is not in range [1, " << kMaxValue << "]";
}
return value;
}
} // namespace tensorflow


@ -22,6 +22,7 @@ namespace tensorflow {
bool MatmulAutotuneEnable();
bool MatmulDoFP32ComputationFP16Input();
int MatmulMaxAutotuneAlgorithmCount();
} // namespace tensorflow


@ -114,6 +114,10 @@ cuda_py_test(
name = "dirichlet_test",
size = "small",
srcs = ["dirichlet_test.py"],
tags = [
# b/170982175
"no_oss",
],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",


@ -197,7 +197,7 @@ class DirichletTest(test.TestCase):
self.assertAllClose(sample_mean_, analytic_mean, atol=0.04, rtol=0.)
self.assertAllClose(sample_cov_, analytic_cov, atol=0.06, rtol=0.)
self.assertAllClose(sample_var_, analytic_var, atol=0.03, rtol=0.)
self.assertAllClose(sample_var_, analytic_var, atol=0.04, rtol=0.)
self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.)
@test_util.run_without_tensor_float_32(


@ -587,7 +587,7 @@ class CSRSparseMatrixOpsTest(test.TestCase):
c_t_value, c_dense_t_value = self.evaluate((c_t, c_dense_t))
self.assertAllClose(
c_t_value, c_dense_t_value, rtol=1e-6, atol=1e-5)
c_t_value, c_dense_t_value, rtol=1e-6, atol=2e-5)
@test_util.run_in_graph_and_eager_modes
def testLargeBatchSparseMatrixMatMulTransposed(self):
@ -650,7 +650,7 @@ class CSRSparseMatrixOpsTest(test.TestCase):
self.assertAllEqual(c_t.shape, c_dense_t.shape)
c_t_value, c_dense_t_value = self.evaluate((c_t, c_dense_t))
self.assertAllClose(
c_t_value, c_dense_t_value, rtol=1e-6, atol=1e-5)
c_t_value, c_dense_t_value, rtol=1e-6, atol=2e-5)
@test_util.run_in_graph_and_eager_modes
def testLargeBatchSparseMatrixMatMulConjugate(self):


@ -61,7 +61,6 @@ cc_library(
"blas.h",
"device_description.h",
"device_options.h",
"dnn.h",
"event.cc",
"fft.h",
"kernel_cache_config.h",
@ -103,7 +102,6 @@ cc_library(
cc_library(
name = "kernel",
srcs = [
"dnn.h",
"fft.h",
"kernel.cc",
"plugin.h",
@ -300,6 +298,7 @@ cc_library(
name = "host_or_device_scalar",
hdrs = ["host_or_device_scalar.h"],
deps = [
":data_type",
":device_memory",
"//tensorflow/stream_executor/platform",
],
@ -331,7 +330,6 @@ cc_library(
],
hdrs = [
"blas.h",
"dnn.h",
"executor_cache.h",
"fft.h",
"kernel.h",
@ -423,11 +421,21 @@ tf_proto_library(
make_default_target_header_only = True,
)
cc_library(
name = "data_type",
hdrs = ["data_type.h"],
deps = [
":dnn_proto_cc",
"//tensorflow/stream_executor/platform",
],
)
cc_library(
name = "dnn",
srcs = ["dnn.cc"],
hdrs = ["dnn.h"],
deps = [
":data_type",
":device_memory",
":dnn_proto_cc",
":stream_executor_headers",
@ -445,7 +453,6 @@ cc_library(
cc_library(
name = "stream_executor_internal",
srcs = [
"dnn.h",
"stream_executor_internal.cc",
],
hdrs = [
@ -474,7 +481,6 @@ cc_library(
name = "stream_executor_pimpl_header",
hdrs = [
"device_description.h",
"dnn.h",
"kernel.h",
"kernel_cache_config.h",
"stream_executor_pimpl.h",


@ -95,5 +95,30 @@ std::ostream& operator<<(std::ostream& os, ComputationType ty) {
return os << ComputationTypeString(ty);
}
string DataTypeString(DataType ty) {
switch (ty) {
case DataType::kHalf:
return "f16";
case DataType::kFloat:
return "f32";
case DataType::kDouble:
return "f64";
case DataType::kInt8:
return "i8";
case DataType::kInt32:
return "i32";
case DataType::kComplexFloat:
return "complex f32";
case DataType::kComplexDouble:
return "complex f64";
default:
LOG(FATAL) << "Unknown DataType " << static_cast<int32>(ty);
}
}
std::ostream& operator<<(std::ostream& os, DataType ty) {
return os << DataTypeString(ty);
}
} // namespace blas
} // namespace stream_executor


@ -43,7 +43,7 @@ limitations under the License.
#include <complex>
#include <vector>
#include "tensorflow/stream_executor/host_or_device_scalar.h"
#include "tensorflow/stream_executor/dnn.h" // For DataType, ToDataType
#include "tensorflow/stream_executor/lib/array_slice.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/platform/port.h"
@ -60,6 +60,9 @@ class ScratchAllocator;
template <typename ElemT>
class DeviceMemory;
template <typename ElemT>
class HostOrDeviceScalar;
namespace blas {
// Specifies whether the input matrix will be transposed or
@ -101,6 +104,18 @@ enum class ComputationType {
kI32, // 32-bit integer
kComplexF32, // Complex number comprised of two f32s.
kComplexF64, // Complex number comprised of two f64s.
// The below values are only supported for BlasLt routines (both real and
// complex). They use float32 for accumulation but round the input mantissas
// to a smaller number of bits.
kTF32AsF32, // 32-bit floating-point with reduced (>=10-bit) mantissa
kBF16AsF32, // 32-bit floating-point with reduced (7-bit) mantissa
};
enum class Epilogue {
kDefault = 1, // No special postprocessing
kReLU = 2, // Apply ReLU func point-wise to the results
kBias = 4, // Add broadcasted bias vector to the results
kBiasThenReLU = kBias | kReLU, // Apply bias and then ReLU transform
};
// Converts a ComputationType to a string.
@ -108,6 +123,21 @@ std::string ComputationTypeString(ComputationType ty);
std::ostream &operator<<(std::ostream &os, ComputationType ty);
using dnn::DataType;
using dnn::ToDataType;
// Describes the type of pointers for the scaling factors alpha and beta in
// blaslt routines.
enum class PointerMode {
kHost,
kDevice,
};
// Converts a DataType to a string.
string DataTypeString(DataType ty);
std::ostream &operator<<(std::ostream &os, DataType ty);
// Opaque identifier for an "algorithm" used by a blas routine. This functions
// as a hint to the blas library.
typedef int64 AlgorithmType;
@ -163,6 +193,44 @@ class AlgorithmConfig {
AlgorithmType algorithm_;
};
struct IBlasLtMatmulPlan {
// Returns the data type of the A and B (input) matrices.
virtual DataType ab_type() const = 0;
// Returns the data type of the C (input/output) matrix.
virtual DataType c_type() const = 0;
virtual ~IBlasLtMatmulPlan() {}
};
struct IBlasLtMatmulAlgorithm {
virtual ~IBlasLtMatmulAlgorithm() {}
// Returns the index of the algorithm within the list returned by
// GetBlasLtMatmulAlgorithms.
virtual AlgorithmType index() const = 0;
// Returns the workspace size required by the algorithm in bytes.
virtual size_t workspace_size() const = 0;
};
// Parameters for the CreateBlasLtMatmulPlan method.
struct BlasLtMatmulPlanParams {
DataType ab_type;
DataType c_type;
ComputationType computation_type;
PointerMode pointer_mode;
Epilogue epilogue;
Transpose transa;
Transpose transb;
uint64 m;
uint64 n;
uint64 k;
int64 lda;
int64 ldb;
int64 ldc;
int batch_count = 1;
int64 stride_a = 0;
int64 stride_b = 0;
int64 stride_c = 0;
};
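// Illustrative sketch (not part of this interface): filling the parameters for
// a single column-major C = A * B with assumed sizes m, n, k, then creating a
// plan through the StreamExecutor (stream->parent()), as the batched-matmul
// kernel in this change does.
//
//   blas::BlasLtMatmulPlanParams params{
//       /*ab_type=*/blas::DataType::kFloat,
//       /*c_type=*/blas::DataType::kFloat,
//       blas::ComputationType::kF32,
//       blas::PointerMode::kHost,
//       blas::Epilogue::kDefault,
//       /*transa=*/blas::Transpose::kNoTranspose,
//       /*transb=*/blas::Transpose::kNoTranspose,
//       m, n, k,
//       /*lda=*/static_cast<int64>(m), /*ldb=*/static_cast<int64>(k),
//       /*ldc=*/static_cast<int64>(m)};
//   auto plan_or = stream->parent()->CreateBlasLtMatmulPlan(params);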
// BLAS support interface -- this can be derived from a GPU executor when the
// underlying platform has an BLAS library implementation available. See
// StreamExecutor::AsBlas().
@ -1383,6 +1451,71 @@ class BlasSupport {
const DeviceMemory<std::complex<double>> &a, int lda,
DeviceMemory<std::complex<double>> *b, int ldb) = 0;
// Creates a backend-specific plan object for a blaslt matmul operation, which
// can then be passed to DoBlasLtMatmul(). When possible, plans should be
// created once and reused for multiple calls to DoBlasLtMatmul().
virtual port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &params) = 0;
// Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
// returned in the order of increasing estimated compute time according to an
// internal heuristic. The first returned algorithm can be used as the default
// algorithm if no autotuning is to be performed.
virtual port::StatusOr<
std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
size_t max_workspace_size,
int max_algorithm_count) = 0;
// Executes a blaslt matmul operation on the stream. If output_profile_result
// is not nullptr, the operation is profiled, error messages are
// suppressed, and output_profile_result->algorithm() is set to
// algorithm->index(). If epilogue was set to kBias or kBiasThenReLU when
// creating the plan, the bias argument here must refer to a valid device
// vector of length equal to the number of rows in matrix c. If epilogue was
// set to any other value then the bias argument here must be null. The bias
// vector is broadcast across the batch dimension.
// Note that the data types of a and b (c and bias) must match the ab_type
// (c_type) with which the plan was created, and the data types of alpha and
// beta must match the data type of c.
virtual bool DoBlasLtMatmul(
Stream *stream, const blas::IBlasLtMatmulPlan *plan,
const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a,
DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta,
DeviceMemoryBase c, ScratchAllocator *scratch_allocator,
const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias,
blas::ProfileResult *output_profile_result) = 0;
template <typename ABType, typename CType>
bool DoBlasLtMatmul(Stream *stream, const blas::IBlasLtMatmulPlan *plan,
const HostOrDeviceScalar<CType> &alpha,
const DeviceMemory<ABType> &a,
const DeviceMemory<ABType> &b,
const HostOrDeviceScalar<CType> &beta,
DeviceMemory<CType> *c,
ScratchAllocator *scratch_allocator,
const blas::IBlasLtMatmulAlgorithm *algorithm,
const DeviceMemory<CType> &bias = {},
blas::ProfileResult *output_profile_result = nullptr) {
constexpr blas::DataType ab_type = blas::ToDataType<ABType>::value;
if (ab_type != plan->ab_type()) {
VLOG(2) << "DoBlasLtMatmul returning false because a and b type does "
"not match plan: expected "
<< plan->ab_type() << ", got " << ab_type;
return false;
}
constexpr blas::DataType c_type = blas::ToDataType<CType>::value;
if (c_type != plan->c_type()) {
VLOG(2) << "DoBlasLtMatmul returning false because c type does "
"not match plan: expected "
<< plan->c_type() << ", got " << c_type;
return false;
}
return DoBlasLtMatmul(stream, plan, alpha, a, b, beta, *c,
scratch_allocator, algorithm, bias,
output_profile_result);
}
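// Illustrative end-to-end sketch (assumed caller, mirroring the batched-matmul
// kernel in this change): query the heuristic-ordered algorithm list for a
// previously created plan, then launch through Stream::ThenBlasLtMatmul with a
// scratch allocator. "plan", "a", "b", "c", alpha/beta and the allocator are
// assumed to exist with matching types.
//
//   auto algorithms = stream->parent()
//                         ->GetBlasLtMatmulAlgorithms(
//                             plan.get(), max_workspace_size,
//                             /*max_algorithm_count=*/1)
//                         .ConsumeValueOrDie();
//   bool launch_ok = stream
//                        ->ThenBlasLtMatmul(plan.get(), alpha, a, b, beta, &c,
//                                           &scratch_allocator,
//                                           algorithms[0].get())
//                        .ok();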
virtual port::Status GetVersion(std::string *version) = 0;
protected:
@ -2196,6 +2329,19 @@ class BlasSupport {
uint64 n, std::complex<double> alpha, \
const DeviceMemory<std::complex<double>> &a, int lda, \
DeviceMemory<std::complex<double>> *b, int ldb) override; \
port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>> \
CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &params) override; \
port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>> \
GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan, \
size_t max_workspace_size, \
int max_algorithm_count) override; \
bool DoBlasLtMatmul( \
Stream *stream, const blas::IBlasLtMatmulPlan *plan, \
const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a, \
DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta, \
DeviceMemoryBase c, ScratchAllocator *scratch_allocator, \
const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias, \
blas::ProfileResult *output_profile_result) override; \
port::Status GetVersion(std::string *version) override;
} // namespace blas


@ -251,6 +251,31 @@ alias(
visibility = ["//visibility:public"],
)
cc_library(
name = "cublas_lt_stub",
srcs = if_cuda_is_configured(["cublasLt_stub.cc"]),
textual_hdrs = glob(["cublasLt_*.inc"]),
deps = if_cuda_is_configured([
# LINT.IfChange
"@local_config_cuda//cuda:cublas_headers",
# LINT.ThenChange(//tensorflow/copy.bara.sky:cublasLt_headers)
"@local_config_cuda//cuda:cuda_headers",
"//tensorflow/stream_executor/lib",
"//tensorflow/stream_executor/platform:dso_loader",
]),
)
cc_library(name = "empty_lib")
alias(
name = "cublas_lt_lib",
actual = select({
"//tensorflow:oss": ":cublas_lt_stub",
"//conditions:default": ":empty_lib",
}),
visibility = ["//visibility:public"],
)
cc_library(
name = "cublas_plugin",
srcs = if_cuda_is_configured(["cuda_blas.cc"]),
@ -258,6 +283,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = if_cuda_is_configured([
":cublas_lib",
":cublas_lt_lib",
":cuda_activation",
":cuda_gpu_executor",
":cuda_platform_id",


@ -0,0 +1,390 @@
// Auto-generated, do not edit.
extern "C" {
cublasStatus_t CUBLASWINAPI cublasLtCreate(cublasLtHandle_t *lightHandle) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasLtHandle_t *);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtCreate");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(lightHandle);
}
cublasStatus_t CUBLASWINAPI cublasLtDestroy(cublasLtHandle_t lightHandle) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasLtHandle_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtDestroy");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(lightHandle);
}
size_t CUBLASWINAPI cublasLtGetVersion(void) {
using FuncPtr = size_t(CUBLASWINAPI *)();
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtGetVersion");
if (!func_ptr) return 0;
return func_ptr();
}
size_t CUBLASWINAPI cublasLtGetCudartVersion(void) {
using FuncPtr = size_t(CUBLASWINAPI *)();
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtGetCudartVersion");
if (!func_ptr) return 0;
return func_ptr();
}
cublasStatus_t CUBLASWINAPI cublasLtGetProperty(libraryPropertyType type,
int *value) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(libraryPropertyType, int *);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtGetProperty");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(type, value);
}
cublasStatus_t CUBLASWINAPI cublasLtMatmul(
cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t computeDesc,
const void *alpha, /* host or device pointer */
const void *A, cublasLtMatrixLayout_t Adesc, const void *B,
cublasLtMatrixLayout_t Bdesc, const void *beta, /* host or device pointer */
const void *C, cublasLtMatrixLayout_t Cdesc, void *D,
cublasLtMatrixLayout_t Ddesc, const cublasLtMatmulAlgo_t *algo,
void *workspace, size_t workspaceSizeInBytes, cudaStream_t stream) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(
cublasLtHandle_t, cublasLtMatmulDesc_t, const void *, const void *,
cublasLtMatrixLayout_t, const void *, cublasLtMatrixLayout_t,
const void *, const void *, cublasLtMatrixLayout_t, void *,
cublasLtMatrixLayout_t, const cublasLtMatmulAlgo_t *, void *, size_t,
cudaStream_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmul");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(lightHandle, computeDesc, alpha, A, Adesc, B, Bdesc, beta, C,
Cdesc, D, Ddesc, algo, workspace, workspaceSizeInBytes,
stream);
}
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransform(
cublasLtHandle_t lightHandle, cublasLtMatrixTransformDesc_t transformDesc,
const void *alpha, /* host or device pointer */
const void *A, cublasLtMatrixLayout_t Adesc,
const void *beta, /* host or device pointer */
const void *B, cublasLtMatrixLayout_t Bdesc, void *C,
cublasLtMatrixLayout_t Cdesc, cudaStream_t stream) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(
cublasLtHandle_t, cublasLtMatrixTransformDesc_t, const void *,
const void *, cublasLtMatrixLayout_t, const void *, const void *,
cublasLtMatrixLayout_t, void *, cublasLtMatrixLayout_t, cudaStream_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatrixTransform");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(lightHandle, transformDesc, alpha, A, Adesc, beta, B, Bdesc,
C, Cdesc, stream);
}
cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutInit_internal( //
cublasLtMatrixLayout_t matLayout, size_t size, cudaDataType type,
uint64_t rows, uint64_t cols, int64_t ld) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( //
cublasLtMatrixLayout_t, size_t, cudaDataType, uint64_t, uint64_t,
int64_t);
static auto func_ptr =
LoadSymbol<FuncPtr>("cublasLtMatrixLayoutInit_internal");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(matLayout, size, type, rows, cols, ld);
}
cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutCreate( //
cublasLtMatrixLayout_t *matLayout, cudaDataType type, uint64_t rows,
uint64_t cols, int64_t ld) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( //
cublasLtMatrixLayout_t *, cudaDataType, uint64_t, uint64_t, int64_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatrixLayoutCreate");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(matLayout, type, rows, cols, ld);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatrixLayoutDestroy(cublasLtMatrixLayout_t matLayout) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasLtMatrixLayout_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatrixLayoutDestroy");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(matLayout);
}
cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutSetAttribute( //
cublasLtMatrixLayout_t matLayout, cublasLtMatrixLayoutAttribute_t attr,
const void *buf, size_t sizeInBytes) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( //
cublasLtMatrixLayout_t, cublasLtMatrixLayoutAttribute_t, const void *,
size_t);
static auto func_ptr =
LoadSymbol<FuncPtr>("cublasLtMatrixLayoutSetAttribute");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(matLayout, attr, buf, sizeInBytes);
}
cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutGetAttribute( //
cublasLtMatrixLayout_t matLayout, cublasLtMatrixLayoutAttribute_t attr,
void *buf, size_t sizeInBytes, size_t *sizeWritten) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( //
cublasLtMatrixLayout_t, cublasLtMatrixLayoutAttribute_t, void *, size_t,
size_t *);
static auto func_ptr =
LoadSymbol<FuncPtr>("cublasLtMatrixLayoutGetAttribute");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(matLayout, attr, buf, sizeInBytes, sizeWritten);
}
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescInit_internal( //
cublasLtMatmulDesc_t matmulDesc, size_t size,
cublasComputeType_t computeType, cudaDataType_t scaleType) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( //
cublasLtMatmulDesc_t, size_t, cublasComputeType_t, cudaDataType_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulDescInit_internal");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(matmulDesc, size, computeType, scaleType);
}
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescCreate(
cublasLtMatmulDesc_t *matmulDesc, cublasComputeType_t computeType,
cudaDataType_t scaleType) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(
cublasLtMatmulDesc_t *, cublasComputeType_t, cudaDataType_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulDescCreate");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(matmulDesc, computeType, scaleType);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatmulDescDestroy(cublasLtMatmulDesc_t matmulDesc) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasLtMatmulDesc_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulDescDestroy");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(matmulDesc);
}
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescSetAttribute( //
cublasLtMatmulDesc_t matmulDesc, cublasLtMatmulDescAttributes_t attr,
const void *buf, size_t sizeInBytes) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( //
cublasLtMatmulDesc_t, cublasLtMatmulDescAttributes_t, const void *,
size_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulDescSetAttribute");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(matmulDesc, attr, buf, sizeInBytes);
}
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescGetAttribute( //
cublasLtMatmulDesc_t matmulDesc, cublasLtMatmulDescAttributes_t attr,
void *buf, size_t sizeInBytes, size_t *sizeWritten) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( //
cublasLtMatmulDesc_t, cublasLtMatmulDescAttributes_t, void *, size_t,
size_t *);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulDescGetAttribute");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(matmulDesc, attr, buf, sizeInBytes, sizeWritten);
}
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescInit_internal(
cublasLtMatrixTransformDesc_t transformDesc, size_t size,
cudaDataType scaleType) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasLtMatrixTransformDesc_t,
size_t, cudaDataType);
static auto func_ptr =
LoadSymbol<FuncPtr>("cublasLtMatrixTransformDescInit_internal");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(transformDesc, size, scaleType);
}
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescCreate(
cublasLtMatrixTransformDesc_t *transformDesc, cudaDataType scaleType) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(
cublasLtMatrixTransformDesc_t *, cudaDataType);
static auto func_ptr =
LoadSymbol<FuncPtr>("cublasLtMatrixTransformDescCreate");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(transformDesc, scaleType);
}
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescDestroy(
cublasLtMatrixTransformDesc_t transformDesc) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasLtMatrixTransformDesc_t);
static auto func_ptr =
LoadSymbol<FuncPtr>("cublasLtMatrixTransformDescDestroy");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(transformDesc);
}
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescSetAttribute( //
cublasLtMatrixTransformDesc_t transformDesc,
cublasLtMatrixTransformDescAttributes_t attr, const void *buf,
size_t sizeInBytes) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( //
cublasLtMatrixTransformDesc_t, cublasLtMatrixTransformDescAttributes_t,
const void *, size_t);
static auto func_ptr =
LoadSymbol<FuncPtr>("cublasLtMatrixTransformDescSetAttribute");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(transformDesc, attr, buf, sizeInBytes);
}
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescGetAttribute( //
cublasLtMatrixTransformDesc_t transformDesc,
cublasLtMatrixTransformDescAttributes_t attr, void *buf, size_t sizeInBytes,
size_t *sizeWritten) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( //
cublasLtMatrixTransformDesc_t, cublasLtMatrixTransformDescAttributes_t,
void *, size_t, size_t *);
static auto func_ptr =
LoadSymbol<FuncPtr>("cublasLtMatrixTransformDescGetAttribute");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(transformDesc, attr, buf, sizeInBytes, sizeWritten);
}
cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceInit_internal(
cublasLtMatmulPreference_t pref, size_t size) {
using FuncPtr =
cublasStatus_t(CUBLASWINAPI *)(cublasLtMatmulPreference_t, size_t);
static auto func_ptr =
LoadSymbol<FuncPtr>("cublasLtMatmulPreferenceInit_internal");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(pref, size);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatmulPreferenceCreate(cublasLtMatmulPreference_t *pref) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasLtMatmulPreference_t *);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulPreferenceCreate");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(pref);
}
cublasStatus_t CUBLASWINAPI
cublasLtMatmulPreferenceDestroy(cublasLtMatmulPreference_t pref) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasLtMatmulPreference_t);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulPreferenceDestroy");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(pref);
}
cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceSetAttribute( //
cublasLtMatmulPreference_t pref, cublasLtMatmulPreferenceAttributes_t attr,
const void *buf, size_t sizeInBytes) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( //
cublasLtMatmulPreference_t, cublasLtMatmulPreferenceAttributes_t,
const void *, size_t);
static auto func_ptr =
LoadSymbol<FuncPtr>("cublasLtMatmulPreferenceSetAttribute");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(pref, attr, buf, sizeInBytes);
}
cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceGetAttribute( //
cublasLtMatmulPreference_t pref, cublasLtMatmulPreferenceAttributes_t attr,
void *buf, size_t sizeInBytes, size_t *sizeWritten) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( //
cublasLtMatmulPreference_t, cublasLtMatmulPreferenceAttributes_t, void *,
size_t, size_t *);
static auto func_ptr =
LoadSymbol<FuncPtr>("cublasLtMatmulPreferenceGetAttribute");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(pref, attr, buf, sizeInBytes, sizeWritten);
}
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoGetHeuristic(
cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t operationDesc,
cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc,
cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc,
cublasLtMatmulPreference_t preference, int requestedAlgoCount,
cublasLtMatmulHeuristicResult_t heuristicResultsArray[],
int *returnAlgoCount) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(
cublasLtHandle_t, cublasLtMatmulDesc_t, cublasLtMatrixLayout_t,
cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, cublasLtMatrixLayout_t,
cublasLtMatmulPreference_t, int, cublasLtMatmulHeuristicResult_t[],
int *);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulAlgoGetHeuristic");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(lightHandle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc,
preference, requestedAlgoCount, heuristicResultsArray,
returnAlgoCount);
}
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoGetIds(
cublasLtHandle_t lightHandle, cublasComputeType_t computeType,
cudaDataType_t scaleType, cudaDataType_t Atype, cudaDataType_t Btype,
cudaDataType_t Ctype, cudaDataType_t Dtype, int requestedAlgoCount,
int algoIdsArray[], int *returnAlgoCount) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(
cublasLtHandle_t, cublasComputeType_t, cudaDataType_t, cudaDataType_t,
cudaDataType_t, cudaDataType_t, cudaDataType_t, int, int[], int *);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulAlgoGetIds");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(lightHandle, computeType, scaleType, Atype, Btype, Ctype,
Dtype, requestedAlgoCount, algoIdsArray, returnAlgoCount);
}
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoInit(
cublasLtHandle_t lightHandle, cublasComputeType_t computeType,
cudaDataType_t scaleType, cudaDataType_t Atype, cudaDataType_t Btype,
cudaDataType_t Ctype, cudaDataType_t Dtype, int algoId,
cublasLtMatmulAlgo_t *algo) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(
cublasLtHandle_t, cublasComputeType_t, cudaDataType_t, cudaDataType_t,
cudaDataType_t, cudaDataType_t, cudaDataType_t, int,
cublasLtMatmulAlgo_t *);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulAlgoInit");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(lightHandle, computeType, scaleType, Atype, Btype, Ctype,
Dtype, algoId, algo);
}
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoCheck( //
cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t operationDesc,
cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc,
cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc,
const cublasLtMatmulAlgo_t *algo, ///< may point to result->algo
cublasLtMatmulHeuristicResult_t *result) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( //
cublasLtHandle_t, cublasLtMatmulDesc_t, cublasLtMatrixLayout_t,
cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, cublasLtMatrixLayout_t,
const cublasLtMatmulAlgo_t *, ///< may point to result->algo
cublasLtMatmulHeuristicResult_t *);
static auto func_ptr = LoadSymbol<FuncPtr>("cublasLtMatmulAlgoCheck");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(lightHandle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, algo,
result);
}
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoCapGetAttribute(
const cublasLtMatmulAlgo_t *algo, cublasLtMatmulAlgoCapAttributes_t attr,
void *buf, size_t sizeInBytes, size_t *sizeWritten) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(
const cublasLtMatmulAlgo_t *, cublasLtMatmulAlgoCapAttributes_t, void *,
size_t, size_t *);
static auto func_ptr =
LoadSymbol<FuncPtr>("cublasLtMatmulAlgoCapGetAttribute");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(algo, attr, buf, sizeInBytes, sizeWritten);
}
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoConfigSetAttribute(
cublasLtMatmulAlgo_t *algo, cublasLtMatmulAlgoConfigAttributes_t attr,
const void *buf, size_t sizeInBytes) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(
cublasLtMatmulAlgo_t *, cublasLtMatmulAlgoConfigAttributes_t,
const void *, size_t);
static auto func_ptr =
LoadSymbol<FuncPtr>("cublasLtMatmulAlgoConfigSetAttribute");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(algo, attr, buf, sizeInBytes);
}
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoConfigGetAttribute(
const cublasLtMatmulAlgo_t *algo, cublasLtMatmulAlgoConfigAttributes_t attr,
void *buf, size_t sizeInBytes, size_t *sizeWritten) {
using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(
const cublasLtMatmulAlgo_t *, cublasLtMatmulAlgoConfigAttributes_t,
void *, size_t, size_t *);
static auto func_ptr =
LoadSymbol<FuncPtr>("cublasLtMatmulAlgoConfigGetAttribute");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr(algo, attr, buf, sizeInBytes, sizeWritten);
}
} // extern "C"

View File

@@ -0,0 +1,59 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/gpus/cuda/include/cublasLt.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/platform/dso_loader.h"
// Implements the cuBLASLt API by forwarding to cuBLASLt loaded from the DSO.
namespace {
// Returns DSO handle or null if loading the DSO fails.
void* GetDsoHandle() {
#ifdef PLATFORM_GOOGLE
return nullptr;
#else
static auto handle = []() -> void* {
auto handle_or =
stream_executor::internal::DsoLoader::GetCublasLtDsoHandle();
if (!handle_or.ok()) return nullptr;
return handle_or.ValueOrDie();
}();
return handle;
#endif
}
template <typename T>
T LoadSymbol(const char* symbol_name) {
void* symbol = nullptr;
if (auto handle = GetDsoHandle()) {
stream_executor::port::Env::Default()
->GetSymbolFromLibrary(handle, symbol_name, &symbol)
.IgnoreError();
}
return reinterpret_cast<T>(symbol);
}
void LogFatalSymbolNotFound(const char* symbol_name) {
LOG(FATAL) << symbol_name << " symbol not found.";
}
cublasStatus_t GetSymbolNotFoundError() { return CUBLAS_STATUS_INTERNAL_ERROR; }
} // namespace
// We only use cublasLt from CUDA 11.0 onward.
#if CUDA_VERSION >= 11000
#include "tensorflow/stream_executor/cuda/cublasLt_11_0.inc"
#endif
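The stub above resolves every cuBLASLt entry point lazily and degrades to an error status when the shared library is absent. A minimal, self-contained sketch of the same pattern, written against plain dlopen/dlsym rather than the stream_executor loader (the library name and the fooCreate entry point are illustrative stand-ins, not real cuBLASLt symbols):
#include <dlfcn.h>
// Hypothetical signature of one library entry point we want to forward.
using FooCreate = int (*)(void **handle);
// Loads the DSO once; returns nullptr when the library is not installed.
static void *GetLibHandle() {
  static void *handle = dlopen("libcublasLt.so.11", RTLD_LAZY);
  return handle;
}
// Resolves a symbol, or returns nullptr if the DSO or the symbol is missing.
template <typename T>
static T LoadSym(const char *name) {
  void *sym = nullptr;
  if (void *lib = GetLibHandle()) sym = dlsym(lib, name);
  return reinterpret_cast<T>(sym);
}
// Forwarder with the same shape as the stubs above: call through when the
// symbol resolves, otherwise fall back to an error code.
int fooCreate(void **handle) {
  static auto fn = LoadSym<FooCreate>("fooCreate");
  if (!fn) return -1;  // stands in for CUBLAS_STATUS_INTERNAL_ERROR
  return fn(handle);
}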

View File

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/gpus/cuda/include/cublasLt.h"
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cuda.h"
@@ -226,17 +227,38 @@ bool CUDABlas::Init() {
return false;
}
#if CUDA_VERSION >= 11000
ret = cublasLtCreate(&blasLt_);
if (ret != CUBLAS_STATUS_SUCCESS) {
LOG(ERROR) << "failed to create cublasLt handle: " << ToString(ret);
return false;
}
#endif // CUDA_VERSION >= 11000
return true;
}
CUDABlas::CUDABlas(gpu::GpuExecutor *parent)
: parent_(CHECK_NOTNULL(parent)), blas_(nullptr) {}
: parent_(CHECK_NOTNULL(parent)),
blas_(nullptr)
#if CUDA_VERSION >= 11000
,
blasLt_(nullptr)
#endif
{
}
CUDABlas::~CUDABlas() {
if (blas_ != nullptr) {
gpu::ScopedActivateExecutorContext sac{parent_};
cublasDestroy(blas_);
}
#if CUDA_VERSION >= 11000
if (blasLt_ != nullptr) {
gpu::ScopedActivateExecutorContext sac{parent_};
cublasLtDestroy(blasLt_);
}
#endif
}
bool CUDABlas::SetStream(Stream *stream) {
@@ -253,6 +275,13 @@ bool CUDABlas::SetStream(Stream *stream) {
return true;
}
cudaStream_t CUDABlas::CUDAStream(Stream *stream) {
CHECK(stream != nullptr);
CHECK(AsGpuStreamValue(stream) != nullptr);
gpu::ScopedActivateExecutorContext sac{parent_};
return AsGpuStreamValue(stream);
}
namespace {
// Helper functions transforming blas arguments into cuBLAS arguments.
@@ -381,8 +410,122 @@ cudaDataType_t CUDAComputationType(blas::ComputationType ty) {
return CUDA_C_32F;
case blas::ComputationType::kComplexF64:
return CUDA_C_64F;
case blas::ComputationType::kTF32AsF32: // fall-through
case blas::ComputationType::kBF16AsF32:
// These cases are currently only supported in the blasLt routines, which
// use CUBLASComputationType() instead.
LOG(FATAL) << "Invalid value of blas::ComputationType.";
}
}
#if CUDA_VERSION >= 11000
cublasComputeType_t CUBLASComputationType(blas::ComputationType ty) {
switch (ty) {
case blas::ComputationType::kF16:
return CUBLAS_COMPUTE_16F;
case blas::ComputationType::kF32: // fall-through
case blas::ComputationType::kComplexF32:
return CUBLAS_COMPUTE_32F;
case blas::ComputationType::kF64: // fall-through
case blas::ComputationType::kComplexF64:
return CUBLAS_COMPUTE_64F;
case blas::ComputationType::kI32:
return CUBLAS_COMPUTE_32I;
case blas::ComputationType::kTF32AsF32:
return CUBLAS_COMPUTE_32F_FAST_TF32;
case blas::ComputationType::kBF16AsF32:
return CUBLAS_COMPUTE_32F_FAST_16BF;
}
}
#endif // CUDA_VERSION >= 11000
blas::DataType GetScaleType(blas::DataType data_type,
blas::ComputationType compute_type) {
bool is_complex = data_type == blas::DataType::kComplexFloat ||
data_type == blas::DataType::kComplexDouble;
switch (compute_type) {
case blas::ComputationType::kF16:
return blas::DataType::kHalf;
case blas::ComputationType::kF32: // fall-through
case blas::ComputationType::kComplexF32: // fall-through
case blas::ComputationType::kTF32AsF32: // fall-through
case blas::ComputationType::kBF16AsF32:
return is_complex ? blas::DataType::kComplexFloat
: blas::DataType::kFloat;
case blas::ComputationType::kF64: // fall-through
case blas::ComputationType::kComplexF64:
return is_complex ? blas::DataType::kComplexDouble
: blas::DataType::kDouble;
case blas::ComputationType::kI32:
return blas::DataType::kInt32;
}
}
#if CUDA_VERSION >= 11000
cublasLtPointerMode_t CUBLASPointerMode(blas::PointerMode pointer_mode) {
switch (pointer_mode) {
case blas::PointerMode::kHost:
return CUBLASLT_POINTER_MODE_HOST;
case blas::PointerMode::kDevice:
return CUBLASLT_POINTER_MODE_DEVICE;
}
}
cublasLtEpilogue_t CUBLASEpilogue(blas::Epilogue epilogue) {
switch (epilogue) {
case blas::Epilogue::kDefault:
return CUBLASLT_EPILOGUE_DEFAULT;
case blas::Epilogue::kReLU:
return CUBLASLT_EPILOGUE_RELU;
case blas::Epilogue::kBias:
return CUBLASLT_EPILOGUE_BIAS;
case blas::Epilogue::kBiasThenReLU:
return CUBLASLT_EPILOGUE_RELU_BIAS;
}
}
#endif // CUDA_VERSION >= 11000
cudaDataType_t GetCUDADataType(blas::DataType ty) {
switch (ty) {
case blas::DataType::kHalf:
return CUDA_R_16F;
case blas::DataType::kFloat:
return CUDA_R_32F;
case blas::DataType::kDouble:
return CUDA_R_64F;
case blas::DataType::kInt8:
return CUDA_R_8I;
case blas::DataType::kInt32:
return CUDA_R_32I;
case blas::DataType::kComplexFloat:
return CUDA_C_32F;
case blas::DataType::kComplexDouble:
return CUDA_C_64F;
default:
LOG(FATAL) << "Invalid value of blas::DataType in GetCUDADataType";
}
}
int GetDataTypeSizeBytes(blas::DataType ty) {
switch (ty) {
case blas::DataType::kHalf:
return 2;
case blas::DataType::kFloat:
return 4;
case blas::DataType::kDouble:
return 8;
case blas::DataType::kInt8:
return 1;
case blas::DataType::kInt32:
return 4;
case blas::DataType::kComplexFloat:
return 8;
case blas::DataType::kComplexDouble:
return 16;
default:
LOG(FATAL) << "Invalid value of blas::DataType in GetDataTypeSizeBytes";
}
}
} // namespace
template <typename FuncT, typename... Args>
@@ -2921,6 +3064,575 @@ bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
GpuComplex(GpuMemoryMutable(b)), ldb);
}
// We only use cublasLt from CUDA 11.0 onward.
#if CUDA_VERSION >= 11000
namespace {
template <typename T>
inline port::Status SetCublasLtAttr(cublasLtMatrixLayout_t handle,
cublasLtMatrixLayoutAttribute_t attr,
const T &value) {
cublasStatus_t status =
cublasLtMatrixLayoutSetAttribute(handle, attr, &value, sizeof(T));
if (status != CUBLAS_STATUS_SUCCESS) {
return port::Status(
port::error::INTERNAL,
absl::StrCat("cublasLtMatrixLayoutSetAttribute(attr=", attr,
", value=", value, ") failed: ", ToString(status)));
}
return port::Status::OK();
}
template <typename T>
inline port::Status SetCublasLtAttr(cublasLtMatmulAlgo_t *handle,
cublasLtMatmulAlgoConfigAttributes_t attr,
const T &value) {
cublasStatus_t status =
cublasLtMatmulAlgoConfigSetAttribute(handle, attr, &value, sizeof(T));
if (status != CUBLAS_STATUS_SUCCESS) {
return port::Status(
port::error::INTERNAL,
absl::StrCat("cublasLtMatmulAlgoConfigSetAttribute(attr=", attr,
", value=", value, ") failed: ", ToString(status)));
}
return port::Status::OK();
}
template <typename T>
inline port::Status SetCublasLtAttr(cublasLtMatmulPreference_t handle,
cublasLtMatmulPreferenceAttributes_t attr,
const T &value) {
cublasStatus_t status =
cublasLtMatmulPreferenceSetAttribute(handle, attr, &value, sizeof(value));
if (status != CUBLAS_STATUS_SUCCESS) {
return port::Status(
port::error::INTERNAL,
absl::StrCat("cublasLtMatmulPreferenceSetAttribute(attr=", attr,
", value=", value, ") failed: ", ToString(status)));
}
return port::Status::OK();
}
template <typename T>
inline bool GetCublasLtAttr(const cublasLtMatmulAlgo_t *handle,
cublasLtMatmulAlgoConfigAttributes_t attr,
T *value) {
auto mutable_handle = const_cast<cublasLtMatmulAlgo_t *>(handle);
size_t bytes_written = 0;
return cublasLtMatmulAlgoConfigGetAttribute(mutable_handle, attr, value,
sizeof(T), &bytes_written) ==
CUBLAS_STATUS_SUCCESS &&
bytes_written == sizeof(T);
}
template <typename T>
inline const T &ValueForStrCat(const T &value) {
return value;
}
template <typename T>
inline absl::Hex ValueForStrCat(T *ptr) {
return absl::Hex(reinterpret_cast<uintptr_t>(ptr));
}
template <typename T>
inline port::Status SetCublasLtAttr(cublasLtMatmulDesc_t handle,
cublasLtMatmulDescAttributes_t attr,
const T &value) {
cublasStatus_t status =
cublasLtMatmulDescSetAttribute(handle, attr, &value, sizeof(value));
if (status != CUBLAS_STATUS_SUCCESS) {
return port::Status(
port::error::INTERNAL,
absl::StrCat("cublasLtMatmulDescSetAttribute(attr=", attr, ", value=",
ValueForStrCat(value), ") failed: ", ToString(status)));
}
return port::Status::OK();
}
struct MatmulDescDestroyer {
void operator()(cublasLtMatmulDesc_t matmul_desc) const {
cublasLtMatmulDescDestroy(matmul_desc);
}
};
struct LayoutDestroyer {
void operator()(cublasLtMatrixLayout_t layout) const {
cublasLtMatrixLayoutDestroy(layout);
}
};
struct MatmulPreferenceDestroyer {
void operator()(cublasLtMatmulPreference_t matmul_pref) const {
cublasLtMatmulPreferenceDestroy(matmul_pref);
}
};
using UniqueOpDesc =
std::unique_ptr<std::remove_pointer<cublasLtMatmulDesc_t>::type,
MatmulDescDestroyer>;
using UniqueLayoutDesc =
std::unique_ptr<std::remove_pointer<cublasLtMatrixLayout_t>::type,
LayoutDestroyer>;
using UniqueMatmulPreference =
std::unique_ptr<std::remove_pointer<cublasLtMatmulPreference_t>::type,
MatmulPreferenceDestroyer>;
port::StatusOr<UniqueOpDesc> CreateCublasLtOperationDesc(
blas::ComputationType computation_type, blas::DataType scale_type,
blas::PointerMode pointer_mode, blas::Epilogue epilogue,
blas::Transpose transa, blas::Transpose transb) {
cublasLtMatmulDesc_t desc;
cublasComputeType_t cublas_compute_type =
CUBLASComputationType(computation_type);
cudaDataType_t cuda_scale_type = GetCUDADataType(scale_type);
cublasStatus_t status =
cublasLtMatmulDescCreate(&desc, cublas_compute_type, cuda_scale_type);
if (status != CUBLAS_STATUS_SUCCESS) {
return port::Status(
port::error::INTERNAL,
absl::StrCat("cublasLtMatmulDescCreate(computation_type=",
computation_type, ") failed: ", ToString(status)));
}
UniqueOpDesc unique_desc(desc);
SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_POINTER_MODE,
CUBLASPointerMode(pointer_mode)));
SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
CUBLASEpilogue(epilogue)));
SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSA,
CUDABlasTranspose(transa)));
SE_RETURN_IF_ERROR(SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSB,
CUDABlasTranspose(transb)));
return unique_desc;
}
port::StatusOr<UniqueLayoutDesc> CreateCublasLtLayoutDesc(
blas::DataType data_type, uint64 rows, uint64 cols, int64 ld, int64 stride,
int batch_count) {
cublasLtMatrixLayout_t desc;
cublasStatus_t status = cublasLtMatrixLayoutCreate(
&desc, GetCUDADataType(data_type), rows, cols, ld);
if (status != CUBLAS_STATUS_SUCCESS) {
return port::Status(
port::error::INTERNAL,
absl::StrCat("cublasLtMatrixLayoutCreate failed: ", ToString(status)));
}
UniqueLayoutDesc unique_desc(desc);
SE_RETURN_IF_ERROR(
SetCublasLtAttr(desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_count));
SE_RETURN_IF_ERROR(SetCublasLtAttr(
desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride));
return unique_desc;
}
// Helper function to allocate workspace.
port::Status AllocateWorkspace(void **workspace,
ScratchAllocator *scratch_allocator,
size_t num_bytes) {
SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace_bytes,
scratch_allocator->AllocateBytes(num_bytes));
*workspace = (void *)GpuMemoryMutable(&workspace_bytes);
return port::Status::OK();
}
template <typename T>
blas::ComputationType ToComputationType();
template <>
blas::ComputationType ToComputationType<Eigen::half>() {
return blas::ComputationType::kF16;
}
template <>
blas::ComputationType ToComputationType<float>() {
return blas::ComputationType::kF32;
}
template <>
blas::ComputationType ToComputationType<double>() {
return blas::ComputationType::kF64;
}
template <>
blas::ComputationType ToComputationType<std::complex<float>>() {
return blas::ComputationType::kComplexF32;
}
template <>
blas::ComputationType ToComputationType<std::complex<double>>() {
return blas::ComputationType::kComplexF64;
}
class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
public:
CUDABlasLtMatmulPlan(UniqueOpDesc op_desc, UniqueLayoutDesc a_desc,
UniqueLayoutDesc b_desc, UniqueLayoutDesc c_desc,
UniqueLayoutDesc d_desc, blas::DataType ab_type,
blas::DataType c_type, blas::DataType scale_type,
blas::PointerMode pointer_mode, blas::Epilogue epilogue,
int batch_count, int64 stride_a, int64 stride_b,
int64 stride_c, int64 stride_d)
: op_desc_(std::move(op_desc)),
a_desc_(std::move(a_desc)),
b_desc_(std::move(b_desc)),
c_desc_(std::move(c_desc)),
d_desc_(std::move(d_desc)),
ab_type_(ab_type),
c_type_(c_type),
scale_type_(scale_type),
pointer_mode_(pointer_mode),
epilogue_(epilogue),
batch_count_(batch_count),
stride_a_(stride_a),
stride_b_(stride_b),
stride_c_(stride_c),
stride_d_(stride_d) {}
cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); }
cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); }
cublasLtMatrixLayout_t b_desc() const { return b_desc_.get(); }
cublasLtMatrixLayout_t c_desc() const { return c_desc_.get(); }
cublasLtMatrixLayout_t d_desc() const { return d_desc_.get(); }
bool ok() { return op_desc_ && a_desc_ && b_desc_ && c_desc_ && d_desc_; }
blas::DataType ab_type() const override { return ab_type_; }
blas::DataType c_type() const override { return c_type_; }
blas::DataType scale_type() const { return scale_type_; }
blas::PointerMode pointer_mode() const { return pointer_mode_; }
blas::Epilogue epilogue() const { return epilogue_; }
int batch_count() const { return batch_count_; }
int64 stride_a() const { return stride_a_; }
int64 stride_b() const { return stride_b_; }
int64 stride_c() const { return stride_c_; }
int64 stride_d() const { return stride_d_; }
// Note: Must be const to satisfy API. This is always called before the plan
// is executed, so the state change is not observed in subsequent executions.
bool SetBiasPointer(const void *bias) const;
private:
UniqueOpDesc op_desc_;
UniqueLayoutDesc a_desc_;
UniqueLayoutDesc b_desc_;
UniqueLayoutDesc c_desc_;
UniqueLayoutDesc d_desc_;
blas::DataType ab_type_;
blas::DataType c_type_;
blas::DataType scale_type_;
blas::PointerMode pointer_mode_;
blas::Epilogue epilogue_;
int batch_count_;
int64 stride_a_;
int64 stride_b_;
int64 stride_c_;
int64 stride_d_;
};
bool CUDABlasLtMatmulPlan::SetBiasPointer(const void *bias) const {
return SetCublasLtAttr(op_desc_.get(), CUBLASLT_MATMUL_DESC_BIAS_POINTER,
bias)
.ok();
}
class CUDABlasLtMatmulAlgorithm final : public blas::IBlasLtMatmulAlgorithm {
public:
CUDABlasLtMatmulAlgorithm(blas::AlgorithmType index,
cublasLtMatmulAlgo_t algo, size_t workspace_size)
: index_(index), algo_(algo), workspace_size_(workspace_size) {}
blas::AlgorithmType index() const override { return index_; }
size_t workspace_size() const override { return workspace_size_; }
const cublasLtMatmulAlgo_t *algo() const { return &algo_; }
int algo_id() const {
int id;
GetCublasLtAttr(&algo_, CUBLASLT_ALGO_CONFIG_ID, &id);
return id;
}
private:
blas::AlgorithmType index_;
cublasLtMatmulAlgo_t algo_;
size_t workspace_size_;
};
port::StatusOr<UniqueMatmulPreference> CreateCublasLtMatmulPreference(
const blas::IBlasLtMatmulPlan *plan, size_t max_workspace_bytes) {
cublasLtMatmulPreference_t preference;
cublasStatus_t status = cublasLtMatmulPreferenceCreate(&preference);
if (status != CUBLAS_STATUS_SUCCESS) {
return port::Status(port::error::INTERNAL,
absl::StrCat("cublasLtMatmulPreferenceCreate failed: ",
ToString(status)));
}
UniqueMatmulPreference unique_preference(preference);
SE_RETURN_IF_ERROR(SetCublasLtAttr(preference,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
max_workspace_bytes));
const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
if (cuda_plan.batch_count() == 0) {
return unique_preference;
}
// This is a workaround for a known issue in cuBlasLt where the heuristic may
// in rare cases select an algo that does not support the specified stride.
// Specifying the alignment requirements manually like this avoids the issue;
// a worked example follows this function.
auto get_alignment_bytes = [](int64 stride, blas::DataType dtype) {
return (stride & -stride) * GetDataTypeSizeBytes(dtype);
};
if (cuda_plan.stride_a()) {
SE_RETURN_IF_ERROR(
SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES,
(uint32)get_alignment_bytes(cuda_plan.stride_a(),
cuda_plan.ab_type())));
}
if (cuda_plan.stride_b()) {
SE_RETURN_IF_ERROR(
SetCublasLtAttr(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES,
(uint32)get_alignment_bytes(cuda_plan.stride_b(),
cuda_plan.ab_type())));
}
if (cuda_plan.stride_c()) {
SE_RETURN_IF_ERROR(SetCublasLtAttr(
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES,
(uint32)get_alignment_bytes(cuda_plan.stride_c(), cuda_plan.c_type())));
}
if (cuda_plan.stride_d()) {
SE_RETURN_IF_ERROR(SetCublasLtAttr(
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES,
(uint32)get_alignment_bytes(cuda_plan.stride_d(), cuda_plan.c_type())));
}
return unique_preference;
}
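// Worked example of the alignment computation above (values illustrative):
// for a strided batch of floats with stride_a == 384 elements,
//   384 & -384 == 128            (largest power-of-two factor of the stride)
//   128 * sizeof(float) == 512   (bytes)
// so consecutive batch elements are offset from the base pointer by a
// multiple of 512 bytes, and 512 is the value passed for
// CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES.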
} // namespace
#endif // CUDA_VERSION >= 11000
port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
CUDABlas::CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &p) {
#if CUDA_VERSION >= 11000
SE_ASSIGN_OR_RETURN(
auto op_desc,
CreateCublasLtOperationDesc(
p.computation_type, GetScaleType(p.c_type, p.computation_type),
p.pointer_mode, p.epilogue, p.transa, p.transb));
uint64 rows_a = p.transa == blas::Transpose::kNoTranspose ? p.m : p.k;
uint64 cols_a = p.transa == blas::Transpose::kNoTranspose ? p.k : p.m;
uint64 rows_b = p.transb == blas::Transpose::kNoTranspose ? p.k : p.n;
uint64 cols_b = p.transb == blas::Transpose::kNoTranspose ? p.n : p.k;
SE_ASSIGN_OR_RETURN(auto a_desc,
CreateCublasLtLayoutDesc(p.ab_type, rows_a, cols_a, p.lda,
p.stride_a, p.batch_count));
SE_ASSIGN_OR_RETURN(auto b_desc,
CreateCublasLtLayoutDesc(p.ab_type, rows_b, cols_b, p.ldb,
p.stride_b, p.batch_count));
SE_ASSIGN_OR_RETURN(auto c_desc,
CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc,
p.stride_c, p.batch_count));
SE_ASSIGN_OR_RETURN(auto d_desc,
CreateCublasLtLayoutDesc(p.c_type, p.m, p.n, p.ldc,
p.stride_c, p.batch_count));
blas::DataType scale_type = GetScaleType(p.c_type, p.computation_type);
return static_cast<std::unique_ptr<blas::IBlasLtMatmulPlan>>(
std::make_unique<CUDABlasLtMatmulPlan>(
std::move(op_desc), std::move(a_desc), std::move(b_desc),
std::move(c_desc), std::move(d_desc), p.ab_type, p.c_type, scale_type,
p.pointer_mode, p.epilogue, p.batch_count, p.stride_a, p.stride_b,
p.stride_c, p.stride_c));
#else
return port::Status(
port::error::UNIMPLEMENTED,
"CreateBlasLtMatmulPlan is not supported with this version of CUDA");
#endif
}
port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
CUDABlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
size_t max_workspace_size,
int max_algorithm_count) {
#if CUDA_VERSION >= 11000
SE_ASSIGN_OR_RETURN(UniqueMatmulPreference preference,
CreateCublasLtMatmulPreference(plan, max_workspace_size));
std::vector<cublasLtMatmulHeuristicResult_t> results(max_algorithm_count);
{
absl::MutexLock lock(&mu_);
CHECK(blasLt_ != nullptr);
gpu::ScopedActivateExecutorContext sac{parent_};
int found_algorithm_count = 0;
const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
cublasStatus_t status = cublasLtMatmulAlgoGetHeuristic(
blasLt_, cuda_plan.op_desc(), cuda_plan.a_desc(), cuda_plan.b_desc(),
cuda_plan.c_desc(), cuda_plan.d_desc(), preference.get(),
max_algorithm_count, results.data(), &found_algorithm_count);
if (status != CUBLAS_STATUS_SUCCESS) {
return port::Status(
port::error::INTERNAL,
absl::StrCat("cublasLtMatmulAlgoGetHeuristic failed: ",
ToString(status)));
}
results.resize(found_algorithm_count);
}
std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>> out_algorithms;
out_algorithms.reserve(results.size());
for (size_t i = 0; i < results.size(); ++i) {
const auto &result = results[i];
if (result.state != CUBLAS_STATUS_SUCCESS) continue; // Skip failed algos
out_algorithms.emplace_back(std::make_unique<CUDABlasLtMatmulAlgorithm>(
i, result.algo, result.workspaceSize));
}
return out_algorithms;
#else // if CUDA_VERSION < 11000
return port::Status(
port::error::UNIMPLEMENTED,
"GetBlasLtMatmulAlgorithms is not supported with this version of CUDA");
#endif
}
#if CUDA_VERSION >= 11000
bool CUDABlas::DoBlasLtMatmulInternal(
Stream *stream, bool err_on_failure, const blas::IBlasLtMatmulPlan *plan,
const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a,
DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta,
DeviceMemoryBase c, DeviceMemoryBase d, ScratchAllocator *scratch_allocator,
const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias) {
const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
const auto &cuda_algo =
*static_cast<const CUDABlasLtMatmulAlgorithm *>(algorithm);
if (alpha.data_type() != cuda_plan.scale_type() ||
beta.data_type() != cuda_plan.scale_type()) {
VLOG(2) << "DoBlasLtMatmul returning false because alpha and beta types do "
"not match plan: expected "
<< cuda_plan.scale_type() << ", got alpha=" << alpha.data_type()
<< " beta=" << beta.data_type();
return false;
}
if (alpha.is_pointer() != beta.is_pointer()) {
VLOG(2) << "DoBlasLtMatmul returning false because one of `alpha` "
"and `beta` is a pointer, but the other is not.";
return false;
}
bool is_pointer_mode_host = !alpha.is_pointer();
if ((cuda_plan.pointer_mode() == blas::PointerMode::kHost) !=
is_pointer_mode_host) {
VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong "
"pointer_mode for the given alpha/beta.";
return false;
}
if ((cuda_plan.epilogue() == blas::Epilogue::kBias ||
cuda_plan.epilogue() == blas::Epilogue::kBiasThenReLU) !=
(bias != nullptr)) {
VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong "
"epilogue for the given bias pointer.";
return false;
}
if (bias != nullptr) {
if (!cuda_plan.SetBiasPointer(bias.opaque())) {
VLOG(2) << "DoBlasLtMatmul returning false because setting the bias "
"pointer failed.";
return false;
}
}
const void *alpha_ptr = alpha.is_pointer() ? alpha.opaque_pointer().opaque()
: alpha.opaque_value();
const void *beta_ptr =
beta.is_pointer() ? beta.opaque_pointer().opaque() : beta.opaque_value();
void *workspace = nullptr;
if (cuda_algo.workspace_size()) {
port::Status allocation_status = AllocateWorkspace(
&workspace, scratch_allocator, cuda_algo.workspace_size());
if (!allocation_status.ok()) {
if (err_on_failure || VLOG_IS_ON(3)) {
LOG(ERROR)
<< "Failed to allocate workspace for cublasLtMatmul algo with id: "
<< cuda_algo.algo_id() << " requiring "
<< cuda_algo.workspace_size() << " bytes of workspace";
}
return false;
}
}
cudaStream_t cuda_stream = CUDAStream(stream);
absl::MutexLock lock(&mu_);
CHECK(blasLt_ != nullptr);
gpu::ScopedActivateExecutorContext sac{parent_};
cublasStatus_t ret = cublasLtMatmul(
blasLt_, cuda_plan.op_desc(), alpha_ptr, a.opaque(), cuda_plan.a_desc(),
b.opaque(), cuda_plan.b_desc(), beta_ptr, c.opaque(), cuda_plan.c_desc(),
d.opaque(), cuda_plan.d_desc(), cuda_algo.algo(), workspace,
cuda_algo.workspace_size(), cuda_stream);
if (ret != CUBLAS_STATUS_SUCCESS) {
if (err_on_failure || VLOG_IS_ON(3)) {
LOG(ERROR) << "failed to run cublasLtMatmul routine: " << ToString(ret);
}
return false;
}
return true;
}
#endif // CUDA_VERSION >= 11000
bool CUDABlas::DoBlasLtMatmul(
Stream *stream, const blas::IBlasLtMatmulPlan *plan,
const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a,
DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta,
DeviceMemoryBase c, ScratchAllocator *scratch_allocator,
const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias,
blas::ProfileResult *output_profile_result) {
#if CUDA_VERSION >= 11000
const auto &cuda_plan = *static_cast<const CUDABlasLtMatmulPlan *>(plan);
HostOrDeviceScalar<void> alpha_cast = alpha;
HostOrDeviceScalar<void> beta_cast = beta;
if (cuda_plan.c_type() == blas::DataType::kHalf &&
cuda_plan.scale_type() == blas::DataType::kFloat) {
// The given alpha and beta types are F16 (they always match c), but F32*
// computation type requires that they be F32, so we must cast them.
if (alpha.is_pointer() || beta.is_pointer()) {
// We cannot easily convert a pointer to f16 memory to a pointer to f32
// memory from here, so we don't support this for now.
return false;
}
alpha_cast = HostOrDeviceScalar<void>(
static_cast<float>(alpha.value<Eigen::half>()));
beta_cast =
HostOrDeviceScalar<void>(static_cast<float>(beta.value<Eigen::half>()));
}
std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
if (output_profile_result) {
timer.reset(new GpuTimer(parent_));
if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
return false;
}
}
bool err_on_failure = timer != nullptr;
bool result = DoBlasLtMatmulInternal(stream, err_on_failure, plan, alpha_cast,
a, b, beta_cast, c, c, scratch_allocator,
algorithm, bias);
if (timer && result) {
// GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
// state.
if (!timer->Stop(AsGpuStream(stream))) {
return false;
}
output_profile_result->set_is_valid(true);
output_profile_result->set_algorithm(algorithm->index());
output_profile_result->set_elapsed_time_in_ms(
timer->GetElapsedMilliseconds());
}
return result;
#else // if CUDA_VERSION < 11000
return false;
#endif
}
port::Status CUDABlas::GetVersion(std::string *version) {
absl::MutexLock lock(&mu_);

View File

@@ -21,7 +21,9 @@ limitations under the License.
#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
#include "absl/synchronization/mutex.h"
#include "third_party/gpus/cuda/include/cublasLt.h"
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/host_or_device_scalar.h"
@@ -71,6 +73,9 @@ class CUDABlas : public blas::BlasSupport {
// invoked before calling into cuBLAS.
bool SetStream(Stream *stream) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Returns the underlying CUDA stream.
cudaStream_t CUDAStream(Stream *stream);
// A helper function that calls the real cuBLAS function together with error
// handling.
//
@@ -134,6 +139,17 @@ class CUDABlas : public blas::BlasSupport {
const T &beta, DeviceMemory<T> *y, int incy,
blas::ProfileResult *output_profile_result);
// Helper function for implementing DoBlasLtMatmul.
bool DoBlasLtMatmulInternal(Stream *stream, bool err_on_failure,
const blas::IBlasLtMatmulPlan *plan,
const HostOrDeviceScalar<void> &alpha,
DeviceMemoryBase a, DeviceMemoryBase b,
const HostOrDeviceScalar<void> &beta,
DeviceMemoryBase c, DeviceMemoryBase d,
ScratchAllocator *scratch_allocator,
const blas::IBlasLtMatmulAlgorithm *algorithm,
DeviceMemoryBase bias);
// Guards the cuBLAS handle for this device.
absl::Mutex mu_;
@@ -144,6 +160,11 @@ class CUDABlas : public blas::BlasSupport {
// cuBLAS library handle on the device.
cublasHandle_t blas_ TF_GUARDED_BY(mu_);
#if CUDA_VERSION >= 11000
// cuBLASLt library handle on the device.
cublasLtHandle_t blasLt_ TF_GUARDED_BY(mu_);
#endif
SE_DISALLOW_COPY_AND_ASSIGN(CUDABlas);
};

View File

@@ -0,0 +1,66 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_STREAM_EXECUTOR_DATA_TYPE_H_
#define TENSORFLOW_STREAM_EXECUTOR_DATA_TYPE_H_
#include <complex>
#include "tensorflow/stream_executor/dnn.pb.h"
#include "tensorflow/stream_executor/platform/port.h"
namespace Eigen {
struct half;
} // namespace Eigen
namespace stream_executor {
namespace dnn {
// A helper class to convert C/C++ types to the proper enums.
template <typename T>
struct ToDataType;
template <>
struct ToDataType<float> {
static constexpr DataType value = DataType::kFloat;
};
template <>
struct ToDataType<double> {
static constexpr DataType value = DataType::kDouble;
};
template <>
struct ToDataType<Eigen::half> {
static constexpr DataType value = DataType::kHalf;
};
template <>
struct ToDataType<tensorflow::int8> {
static constexpr DataType value = DataType::kInt8;
};
template <>
struct ToDataType<tensorflow::int32> {
static constexpr DataType value = DataType::kInt32;
};
template <>
struct ToDataType<std::complex<float>> {
static constexpr DataType value = DataType::kComplexFloat;
};
template <>
struct ToDataType<std::complex<double>> {
static constexpr DataType value = DataType::kComplexDouble;
};
} // namespace dnn
} // namespace stream_executor
#endif // TENSORFLOW_STREAM_EXECUTOR_DATA_TYPE_H_
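The ToDataType trait above lets templated code recover the dnn::DataType enum for an element type at compile time. A small sketch of how calling code can rely on it (the static_asserts and the ElementTypeOf helper are illustrative, assuming the header above is on the include path):
#include <complex>
#include "tensorflow/stream_executor/data_type.h"
namespace se = stream_executor;
// Compile-time sanity checks of the type-to-enum mapping.
static_assert(se::dnn::ToDataType<float>::value == se::dnn::DataType::kFloat,
              "float maps to kFloat");
static_assert(se::dnn::ToDataType<std::complex<double>>::value ==
                  se::dnn::DataType::kComplexDouble,
              "complex<double> maps to kComplexDouble");
// Typical use: turn a template parameter into the runtime enum expected by
// type-erased APIs such as HostOrDeviceScalar<void>.
template <typename T>
constexpr se::dnn::DataType ElementTypeOf() {
  return se::dnn::ToDataType<T>::value;
}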

View File

@@ -30,6 +30,7 @@ limitations under the License.
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/stream_executor/data_type.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/dnn.pb.h"
#include "tensorflow/stream_executor/lib/array_slice.h"
@@ -110,30 +111,6 @@ enum class QuantizedActivationMode {
k32Bit = 4,
};
// A helper class to convert C/C++ types to the proper enums.
template <typename T>
struct ToDataType;
template <>
struct ToDataType<float> {
static constexpr DataType value = DataType::kFloat;
};
template <>
struct ToDataType<double> {
static constexpr DataType value = DataType::kDouble;
};
template <>
struct ToDataType<Eigen::half> {
static constexpr DataType value = DataType::kHalf;
};
template <>
struct ToDataType<int8> {
static constexpr DataType value = DataType::kInt8;
};
template <>
struct ToDataType<int32> {
static constexpr DataType value = DataType::kInt32;
};
// Specifies the types of a RNN model.
enum class RnnMode {
kRnnRelu = 0,

View File

@@ -12,6 +12,8 @@ enum DataType {
kHalf = 2;
kInt8 = 3;
kInt32 = 4;
kComplexFloat = 5;
kComplexDouble = 6;
}
// Describes how a convolution input or output layer's data is formatted.

View File

@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_STREAM_EXECUTOR_HOST_OR_DEVICE_SCALAR_H_
#define TENSORFLOW_STREAM_EXECUTOR_HOST_OR_DEVICE_SCALAR_H_
#include "tensorflow/stream_executor/data_type.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/platform/logging.h"
@@ -23,6 +24,7 @@ namespace stream_executor {
// Represents a value that is either a host scalar or a scalar stored
// on the GPU device.
// See also the specialization for ElemT=void below.
template <typename ElemT>
class HostOrDeviceScalar {
public:
@@ -52,5 +54,154 @@ class HostOrDeviceScalar {
bool is_pointer_;
};
// Specialization for wrapping a dynamically-typed value (via type erasure).
template <>
class HostOrDeviceScalar<void> {
public:
using DataType = dnn::DataType;
// These constructors are not marked explicit because callers usually pass a
// compile-time constant for the value.
// NOLINTNEXTLINE google-explicit-constructor
HostOrDeviceScalar(float value)
: float_(value), is_pointer_(false), dtype_(DataType::kFloat) {}
// NOLINTNEXTLINE google-explicit-constructor
HostOrDeviceScalar(double value)
: double_(value), is_pointer_(false), dtype_(DataType::kDouble) {}
// NOLINTNEXTLINE google-explicit-constructor
HostOrDeviceScalar(Eigen::half value)
: half_(value), is_pointer_(false), dtype_(DataType::kHalf) {}
// NOLINTNEXTLINE google-explicit-constructor
HostOrDeviceScalar(int8 value)
: int8_(value), is_pointer_(false), dtype_(DataType::kInt8) {}
// NOLINTNEXTLINE google-explicit-constructor
HostOrDeviceScalar(int32 value)
: int32_(value), is_pointer_(false), dtype_(DataType::kInt32) {}
// NOLINTNEXTLINE google-explicit-constructor
HostOrDeviceScalar(std::complex<float> value)
: complex_float_(value),
is_pointer_(false),
dtype_(DataType::kComplexFloat) {}
// NOLINTNEXTLINE google-explicit-constructor
HostOrDeviceScalar(std::complex<double> value)
: complex_double_(value),
is_pointer_(false),
dtype_(DataType::kComplexDouble) {}
template <typename T>
explicit HostOrDeviceScalar(const DeviceMemory<T>& pointer)
: pointer_(pointer),
is_pointer_(true),
dtype_(dnn::ToDataType<T>::value) {
CHECK_EQ(1, pointer.ElementCount());
}
// Construct from statically-typed version.
template <typename T, typename std::enable_if<!std::is_same<T, void>::value,
int>::type = 0>
// NOLINTNEXTLINE google-explicit-constructor
HostOrDeviceScalar(const HostOrDeviceScalar<T>& other) {
if (other.is_pointer()) {
*this = HostOrDeviceScalar(other.pointer());
} else {
*this = HostOrDeviceScalar(other.value());
}
}
bool is_pointer() const { return is_pointer_; }
template <typename T>
const DeviceMemory<T>& pointer() const {
CHECK(is_pointer());
CHECK(dtype_ == dnn::ToDataType<T>::value);
return static_cast<const DeviceMemory<T>&>(pointer_);
}
template <typename T>
const T& value() const {
CHECK(!is_pointer());
CHECK(dtype_ == dnn::ToDataType<T>::value);
return value_impl<T>();
}
const DeviceMemoryBase& opaque_pointer() const {
CHECK(is_pointer());
return pointer_;
}
const void* opaque_value() const {
CHECK(!is_pointer());
switch (dtype_) {
case DataType::kFloat:
return &float_;
case DataType::kDouble:
return &double_;
case DataType::kHalf:
return &half_;
case DataType::kInt8:
return &int8_;
case DataType::kInt32:
return &int32_;
case DataType::kComplexFloat:
return &complex_float_;
case DataType::kComplexDouble:
return &complex_double_;
default:
return nullptr;
}
}
DataType data_type() const { return dtype_; }
private:
template <typename T>
const T& value_impl() const;
union {
float float_;
double double_;
Eigen::half half_;
int8 int8_;
int32 int32_;
std::complex<float> complex_float_;
std::complex<double> complex_double_;
DeviceMemoryBase pointer_;
};
bool is_pointer_;
DataType dtype_;
};
template <>
inline const float& HostOrDeviceScalar<void>::value_impl<float>() const {
return float_;
}
template <>
inline const double& HostOrDeviceScalar<void>::value_impl<double>() const {
return double_;
}
template <>
inline const Eigen::half& HostOrDeviceScalar<void>::value_impl<Eigen::half>()
const {
return half_;
}
template <>
inline const int8& HostOrDeviceScalar<void>::value_impl<int8>() const {
return int8_;
}
template <>
inline const int32& HostOrDeviceScalar<void>::value_impl<int32>() const {
return int32_;
}
template <>
inline const std::complex<float>&
HostOrDeviceScalar<void>::value_impl<std::complex<float>>() const {
return complex_float_;
}
template <>
inline const std::complex<double>&
HostOrDeviceScalar<void>::value_impl<std::complex<double>>() const {
return complex_double_;
}
} // namespace stream_executor
#endif // TENSORFLOW_STREAM_EXECUTOR_HOST_OR_DEVICE_SCALAR_H_
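A short sketch of how the type-erased specialization above is meant to be driven by callers such as DoBlasLtMatmul (ScalarExample and its argument are illustrative; the device buffer is assumed to hold exactly one element):
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/host_or_device_scalar.h"
namespace se = stream_executor;
void ScalarExample(const se::DeviceMemory<float> &scale_on_device) {
  // Host-side scalar: the value and its dtype travel together.
  se::HostOrDeviceScalar<void> alpha(1.0f);
  CHECK(!alpha.is_pointer());
  CHECK(alpha.data_type() == se::dnn::DataType::kFloat);
  const void *alpha_ptr = alpha.opaque_value();  // points at the stored float
  // Device-side scalar: wraps a single-element DeviceMemory<float>.
  se::HostOrDeviceScalar<void> beta(scale_on_device);
  CHECK(beta.is_pointer());
  const void *beta_ptr = beta.opaque_pointer().opaque();
  (void)alpha_ptr;
  (void)beta_ptr;
}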

View File

@@ -23,6 +23,7 @@ namespace DsoLoader {
port::Status TryDlopenCUDALibraries() {
auto cudart_status = GetCudaRuntimeDsoHandle();
auto cublas_status = GetCublasDsoHandle();
auto cublaslt_status = GetCublasLtDsoHandle();
auto cufft_status = GetCufftDsoHandle();
auto curand_status = GetCurandDsoHandle();
auto cusolver_status = GetCusolverDsoHandle();
@@ -31,7 +32,7 @@ port::Status TryDlopenCUDALibraries() {
if (!cudart_status.status().ok() || !cublas_status.status().ok() ||
!cufft_status.status().ok() || !curand_status.status().ok() ||
!cusolver_status.status().ok() || !cusparse_status.status().ok() ||
!cudnn_status.status().ok()) {
!cudnn_status.status().ok() || !cublaslt_status.status().ok()) {
return port::Status(port::error::INTERNAL,
absl::StrCat("Cannot dlopen all CUDA libraries."));
} else {

View File

@@ -85,6 +85,10 @@ port::StatusOr<void*> GetCublasDsoHandle() {
return GetDsoHandle("cublas", GetCublasVersion());
}
port::StatusOr<void*> GetCublasLtDsoHandle() {
return GetDsoHandle("cublasLt", GetCublasVersion());
}
port::StatusOr<void*> GetCufftDsoHandle() {
return GetDsoHandle("cufft", GetCufftVersion());
}
@@ -161,6 +165,11 @@ port::StatusOr<void*> GetCublasDsoHandle() {
return *result;
}
port::StatusOr<void*> GetCublasLtDsoHandle() {
static auto result = new auto(DsoLoader::GetCublasLtDsoHandle());
return *result;
}
port::StatusOr<void*> GetCurandDsoHandle() {
static auto result = new auto(DsoLoader::GetCurandDsoHandle());
return *result;

View File

@@ -37,6 +37,7 @@ namespace DsoLoader {
port::StatusOr<void*> GetCudaDriverDsoHandle();
port::StatusOr<void*> GetCudaRuntimeDsoHandle();
port::StatusOr<void*> GetCublasDsoHandle();
port::StatusOr<void*> GetCublasLtDsoHandle();
port::StatusOr<void*> GetCufftDsoHandle();
port::StatusOr<void*> GetCurandDsoHandle();
port::StatusOr<void*> GetCusolverDsoHandle();
@@ -72,6 +73,7 @@ namespace CachedDsoLoader {
port::StatusOr<void*> GetCudaDriverDsoHandle();
port::StatusOr<void*> GetCudaRuntimeDsoHandle();
port::StatusOr<void*> GetCublasDsoHandle();
port::StatusOr<void*> GetCublasLtDsoHandle();
port::StatusOr<void*> GetCufftDsoHandle();
port::StatusOr<void*> GetCurandDsoHandle();
port::StatusOr<void*> GetCusolverDsoHandle();

View File

@@ -2540,6 +2540,32 @@ port::Status ROCMBlas::GetVersion(string *version) {
return port::UnimplementedError("");
}
port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
ROCMBlas::CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &p) {
return port::Status(
port::error::UNIMPLEMENTED,
"CreateBlasLtMatmulPlan is not supported with this version of ROCM");
}
port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
ROCMBlas::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
size_t max_workspace_size,
int max_algorithm_count) {
return port::Status(
port::error::UNIMPLEMENTED,
"GetBlasLtMatmulAlgorithms is not supported with this version of ROCM");
}
bool ROCMBlas::DoBlasLtMatmul(
Stream *stream, const blas::IBlasLtMatmulPlan *plan,
const HostOrDeviceScalar<void> &alpha, DeviceMemoryBase a,
DeviceMemoryBase b, const HostOrDeviceScalar<void> &beta,
DeviceMemoryBase c, ScratchAllocator *scratch_allocator,
const blas::IBlasLtMatmulAlgorithm *algorithm, DeviceMemoryBase bias,
blas::ProfileResult *output_profile_result) {
return false;
}
} // namespace gpu
void initialize_rocblas() {

View File

@@ -4322,6 +4322,80 @@ Stream &Stream::ThenBlasGemmStridedBatched(
c, ldc, stride_c, batch_count);
}
template <typename ABType, typename CType>
Stream &Stream::ThenBlasLtMatmulImpl(
const blas::IBlasLtMatmulPlan *plan, const HostOrDeviceScalar<CType> &alpha,
const DeviceMemory<ABType> &a, const DeviceMemory<ABType> &b,
const HostOrDeviceScalar<CType> &beta, DeviceMemory<CType> *c,
ScratchAllocator *scratch_allocator,
const blas::IBlasLtMatmulAlgorithm *algorithm,
const DeviceMemory<CType> &bias,
blas::ProfileResult *output_profile_result) {
VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
PARAM(c), PARAM(algorithm), PARAM(bias));
ThenBlasWithProfileImpl<
const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<CType> &,
const DeviceMemory<ABType> &, const DeviceMemory<ABType> &,
const HostOrDeviceScalar<CType> &, DeviceMemory<CType> *,
ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
const DeviceMemory<CType> &>
impl;
return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
c, scratch_allocator, algorithm, bias, output_profile_result);
}
// Explicit template instantiations for each supported type combination.
template Stream &Stream::ThenBlasLtMatmulImpl<int8, int32>(
const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<int32> &,
const DeviceMemory<int8> &, const DeviceMemory<int8> &,
const HostOrDeviceScalar<int32> &, DeviceMemory<int32> *,
ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
const DeviceMemory<int32> &, blas::ProfileResult *);
template Stream &Stream::ThenBlasLtMatmulImpl<Eigen::half, Eigen::half>(
const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<Eigen::half> &,
const DeviceMemory<Eigen::half> &, const DeviceMemory<Eigen::half> &,
const HostOrDeviceScalar<Eigen::half> &, DeviceMemory<Eigen::half> *,
ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
const DeviceMemory<Eigen::half> &, blas::ProfileResult *);
template Stream &Stream::ThenBlasLtMatmulImpl<float, float>(
const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<float> &,
const DeviceMemory<float> &, const DeviceMemory<float> &,
const HostOrDeviceScalar<float> &, DeviceMemory<float> *,
ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
const DeviceMemory<float> &, blas::ProfileResult *);
template Stream &Stream::ThenBlasLtMatmulImpl<double, double>(
const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<double> &,
const DeviceMemory<double> &, const DeviceMemory<double> &,
const HostOrDeviceScalar<double> &, DeviceMemory<double> *,
ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
const DeviceMemory<double> &, blas::ProfileResult *);
template Stream &
Stream::ThenBlasLtMatmulImpl<std::complex<float>, std::complex<float>>(
const blas::IBlasLtMatmulPlan *,
const HostOrDeviceScalar<std::complex<float>> &,
const DeviceMemory<std::complex<float>> &,
const DeviceMemory<std::complex<float>> &,
const HostOrDeviceScalar<std::complex<float>> &,
DeviceMemory<std::complex<float>> *, ScratchAllocator *,
const blas::IBlasLtMatmulAlgorithm *,
const DeviceMemory<std::complex<float>> &, blas::ProfileResult *);
template Stream &
Stream::ThenBlasLtMatmulImpl<std::complex<double>, std::complex<double>>(
const blas::IBlasLtMatmulPlan *,
const HostOrDeviceScalar<std::complex<double>> &,
const DeviceMemory<std::complex<double>> &,
const DeviceMemory<std::complex<double>> &,
const HostOrDeviceScalar<std::complex<double>> &,
DeviceMemory<std::complex<double>> *, ScratchAllocator *,
const blas::IBlasLtMatmulAlgorithm *,
const DeviceMemory<std::complex<double>> &, blas::ProfileResult *);
Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
VLOG_CALL(PARAM(seed), PARAM(seed_bytes));

View File

@@ -75,6 +75,19 @@ class AlgorithmDesc;
class StreamExecutor;
class ScratchAllocator;
namespace detail {
// Helper class to prevent a template function argument from being deduced. This
// is identical to std::type_identity in C++20.
template <typename T>
struct NonDeduced {
using type = T;
};
template <typename T>
using NonDeducedType = typename NonDeduced<T>::type;
} // namespace detail
// Convert a type to the corresponding QuantizedActivationMode.
template <typename ElementType>
struct Quantization;
@@ -1632,6 +1645,25 @@ class Stream {
const DeviceMemory<std::complex<double>> &a, int lda,
DeviceMemory<std::complex<double>> *b, int ldb);
// See BlasSupport::DoBlasLtMatmul.
// Note that we prevent alpha and beta from being used to deduce CType so that
// they can be constructed implicitly from values of type CType. Without this,
// type deduction would fail when this function is called with a value of type
// CType for alpha or beta.
template <typename ABType, typename CType>
Stream &ThenBlasLtMatmul(
const blas::IBlasLtMatmulPlan *plan,
const detail::NonDeducedType<HostOrDeviceScalar<CType>> &alpha,
const DeviceMemory<ABType> &a, const DeviceMemory<ABType> &b,
const detail::NonDeducedType<HostOrDeviceScalar<CType>> &beta,
DeviceMemory<CType> *c, ScratchAllocator *scratch_allocator,
const blas::IBlasLtMatmulAlgorithm *algorithm,
const DeviceMemory<CType> &bias = {},
blas::ProfileResult *output_profile_result = nullptr) {
return ThenBlasLtMatmulImpl(plan, alpha, a, b, beta, c, scratch_allocator,
algorithm, bias, output_profile_result);
}
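// Why the scalars use NonDeducedType (the call shown is hypothetical): with
// ordinary deduction, a call such as
//   stream->ThenBlasLtMatmul(plan, 1.0f, a, b, 0.0f, &c, allocator, algo);
// would fail to compile, because deducing CType from `1.0f` sees a plain
// float rather than a HostOrDeviceScalar<float>, and implicit conversions
// are not considered during template argument deduction. With the wrapper,
// CType is deduced from `c` alone and the scalar literals convert implicitly.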
// See FftSupport::DoFft.
Stream &ThenFft(fft::Plan *plan,
const DeviceMemory<std::complex<float>> &input,
@@ -2064,6 +2096,19 @@ class Stream {
const dnn::BatchDescriptor &bias_descriptor,
DeviceMemory<T> *backward_bias_data);
// Implementation of ThenBlasLtMatmul that is shared by all types.
template <typename ABType, typename CType>
Stream &ThenBlasLtMatmulImpl(const blas::IBlasLtMatmulPlan *plan,
const HostOrDeviceScalar<CType> &alpha,
const DeviceMemory<ABType> &a,
const DeviceMemory<ABType> &b,
const HostOrDeviceScalar<CType> &beta,
DeviceMemory<CType> *c,
ScratchAllocator *scratch_allocator,
const blas::IBlasLtMatmulAlgorithm *algorithm,
const DeviceMemory<CType> &bias,
blas::ProfileResult *output_profile_result);
SE_DISALLOW_COPY_AND_ASSIGN(Stream);
};

View File

@@ -337,6 +337,30 @@ bool StreamExecutor::GetBlasGemmAlgorithms(
return blas_support->GetBlasGemmAlgorithms(out_algorithms);
}
port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
StreamExecutor::CreateBlasLtMatmulPlan(
const blas::BlasLtMatmulPlanParams &params) {
blas::BlasSupport *blas_support = AsBlas();
if (!blas_support) {
return port::Status(port::error::UNKNOWN,
"Fail to find the blas implementation.");
}
return blas_support->CreateBlasLtMatmulPlan(params);
}
port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
StreamExecutor::GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
size_t max_workspace_size,
int max_algorithm_count) {
blas::BlasSupport *blas_support = AsBlas();
if (!blas_support) {
return port::Status(port::error::UNKNOWN,
"Fail to find the blas implementation.");
}
return blas_support->GetBlasLtMatmulAlgorithms(plan, max_workspace_size,
max_algorithm_count);
}
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
StreamExecutor::createRnnDescriptor(
int num_layers, int hidden_size, int input_size, int cell_size,

View File

@@ -395,6 +395,21 @@ class StreamExecutor {
// Get the list of supported algorithms for BLAS gemm.
bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms);
// Creates a backend-specific plan object for a blasLt matmul operation, which
// can then be passed to DoBlasLtMatmul(). When possible, plans should be
// created once and reused for multiple calls to DoBlasLtMatmul().
// Returns an error status on failure.
port::StatusOr<std::unique_ptr<blas::IBlasLtMatmulPlan>>
CreateBlasLtMatmulPlan(const blas::BlasLtMatmulPlanParams &params);
// Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
// returned in the order of increasing estimated compute time according to an
// internal heuristic. The first returned algorithm can be used as the default
// algorithm if no autotuning is to be performed.
port::StatusOr<std::vector<std::unique_ptr<blas::IBlasLtMatmulAlgorithm>>>
GetBlasLtMatmulAlgorithms(const blas::IBlasLtMatmulPlan *plan,
size_t max_workspace_size, int max_algorithm_count);
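// Illustrative call sequence for the two methods above together with
// Stream::ThenBlasLtMatmul (the parameter-struct defaults, buffer setup, and
// workspace limit below are assumptions, not guarantees of this interface):
//
//   blas::BlasLtMatmulPlanParams params;
//   params.ab_type = blas::DataType::kFloat;
//   params.c_type = blas::DataType::kFloat;
//   params.computation_type = blas::ComputationType::kF32;
//   params.pointer_mode = blas::PointerMode::kHost;
//   params.epilogue = blas::Epilogue::kDefault;
//   params.transa = blas::Transpose::kNoTranspose;
//   params.transb = blas::Transpose::kNoTranspose;
//   params.m = m; params.n = n; params.k = k;
//   params.lda = m; params.ldb = k; params.ldc = m;
//   params.stride_a = 0; params.stride_b = 0; params.stride_c = 0;
//   params.batch_count = 1;
//   auto plan = executor->CreateBlasLtMatmulPlan(params).ValueOrDie();
//   auto algorithms = executor->GetBlasLtMatmulAlgorithms(
//       plan.get(), /*max_workspace_size=*/1 << 30,
//       /*max_algorithm_count=*/1).ValueOrDie();
//   stream->ThenBlasLtMatmul(plan.get(), 1.0f, a, b, 0.0f, &c,
//                            scratch_allocator, algorithms[0].get());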
// Create an RNN descriptor based on model shapes and configurations.
// The caller retains the ownership of the descriptor.
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(

View File

@@ -127,6 +127,13 @@ cc_library(
linkstatic = 1,
)
cc_library(
name = "cublasLt",
srcs = ["cuda/lib/%{cublasLt_lib}"],
data = ["cuda/lib/%{cublasLt_lib}"],
linkstatic = 1,
)
cc_library(
name = "cusolver",
srcs = ["cuda/lib/%{cusolver_lib}"],
@@ -168,6 +175,7 @@ cc_library(
name = "cuda",
deps = [
":cublas",
":cublasLt",
":cuda_headers",
":cudart",
":cudnn",

View File

@@ -551,6 +551,13 @@ def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config):
cuda_config.cublas_version,
static = False,
),
"cublasLt": _check_cuda_lib_params(
"cublasLt",
cpu_value,
cuda_config.config["cublas_library_dir"],
cuda_config.cublas_version,
static = False,
),
"cusolver": _check_cuda_lib_params(
"cusolver",
cpu_value,
@@ -780,6 +787,7 @@ def _create_dummy_repository(repository_ctx):
"%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value),
"%{cudart_lib}": lib_name("cudart", cpu_value),
"%{cublas_lib}": lib_name("cublas", cpu_value),
"%{cublasLt_lib}": lib_name("cublasLt", cpu_value),
"%{cusolver_lib}": lib_name("cusolver", cpu_value),
"%{cudnn_lib}": lib_name("cudnn", cpu_value),
"%{cufft_lib}": lib_name("cufft", cpu_value),
@@ -811,6 +819,7 @@ filegroup(name="cudnn-include")
"cuda/cuda/lib/%s" % lib_name("cudart_static", cpu_value),
)
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublas", cpu_value))
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublasLt", cpu_value))
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusolver", cpu_value))
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudnn", cpu_value))
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("curand", cpu_value))
@@ -1002,11 +1011,13 @@ def _create_local_cuda_repository(repository_ctx):
cublas_include_path + "/cublas.h",
cublas_include_path + "/cublas_v2.h",
cublas_include_path + "/cublas_api.h",
cublas_include_path + "/cublasLt.h",
],
outs = [
"cublas/include/cublas.h",
"cublas/include/cublas_v2.h",
"cublas/include/cublas_api.h",
"cublas/include/cublasLt.h",
],
))
@@ -1147,6 +1158,7 @@ def _create_local_cuda_repository(repository_ctx):
"%{cudart_static_linkopt}": _cudart_static_linkopt(cuda_config.cpu_value),
"%{cudart_lib}": _basename(repository_ctx, cuda_libs["cudart"]),
"%{cublas_lib}": _basename(repository_ctx, cuda_libs["cublas"]),
"%{cublasLt_lib}": _basename(repository_ctx, cuda_libs["cublasLt"]),
"%{cusolver_lib}": _basename(repository_ctx, cuda_libs["cusolver"]),
"%{cudnn_lib}": _basename(repository_ctx, cuda_libs["cudnn"]),
"%{cufft_lib}": _basename(repository_ctx, cuda_libs["cufft"]),