|
|
|
@@ -22,7 +22,6 @@ limitations under the License.
|
|
|
|
|
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
|
|
|
|
#include "tensorflow/core/framework/op.h"
|
|
|
|
|
#include "tensorflow/core/framework/op_kernel.h"
|
|
|
|
|
#include "tensorflow/core/framework/register_types.h"
|
|
|
|
@@ -34,17 +33,24 @@ limitations under the License.
|
|
|
|
|
#include "tensorflow/core/lib/core/errors.h"
|
|
|
|
|
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
|
|
|
|
#include "tensorflow/core/platform/logging.h"
|
|
|
|
|
#include "tensorflow/core/platform/tensor_float_32_utils.h"
|
|
|
|
|
#include "tensorflow/core/platform/types.h"
|
|
|
|
|
#include "tensorflow/core/util/matmul_autotune.h"
|
|
|
|
|
#include "tensorflow/core/util/matmul_bcast.h"
|
|
|
|
|
#include "tensorflow/core/util/work_sharder.h"
|
|
|
|
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
|
|
|
|
|
|
|
|
|
#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
|
|
|
|
|
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
|
|
|
|
#include "tensorflow/core/kernels/gpu_utils.h"
|
|
|
|
|
#include "tensorflow/core/platform/stream_executor.h"
|
|
|
|
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
|
|
|
|
#if GOOGLE_CUDA
|
|
|
|
|
#include "third_party/gpus/cuda/include/cuda.h" // For CUDA_VERSION
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
namespace tensorflow {
|
|
|
|
|
|
|
|
|
@@ -219,7 +225,8 @@ template <typename Scalar>
|
|
|
|
|
struct LaunchBatchMatMul<CPUDevice, Scalar> {
|
|
|
|
|
static void Launch(OpKernelContext* context, const Tensor& in_x,
|
|
|
|
|
const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
|
|
|
|
|
bool trans_y, const MatMulBCast& bcast, Tensor* out) {
|
|
|
|
|
bool trans_y, const MatMulBCast& bcast, bool use_autotune,
|
|
|
|
|
Tensor* out) {
|
|
|
|
|
typedef ParallelMatMulKernel<Scalar, Eigen::NumTraits<Scalar>::IsComplex>
|
|
|
|
|
ParallelMatMulKernel;
|
|
|
|
|
bool conjugate_result = false;
|
|
|
|
@@ -275,45 +282,201 @@ se::DeviceMemory<T> AsDeviceMemory(const T* gpu_memory) {
|
|
|
|
|
return typed;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
class BlasScratchAllocator : public se::ScratchAllocator {
|
|
|
|
|
using BlasScratchAllocator = GpuScratchAllocator;
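// Workspace limit (in bytes) for cuBLAS scratch allocations; the environment
// variable value is interpreted in MiB, with the given default (in bytes) used
// when it is not set.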
|
|
|
|
|
|
|
|
|
|
int64 GetBlasWorkspaceLimit(const string& envvar_in_mb,
|
|
|
|
|
int64 default_value_in_bytes) {
|
|
|
|
|
return GetWorkspaceLimit(envvar_in_mb, default_value_in_bytes);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Encapsulate all of the shape, dtype etc. information that defines a unique
|
|
|
|
|
// batched matmul operation.
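// Used as the lookup key for both the BlasLt plan cache and the autotune
// results cache defined below.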
|
|
|
|
|
class BatchMatmulParameters {
|
|
|
|
|
public:
|
|
|
|
|
using Stream = se::Stream;
|
|
|
|
|
using DeviceMemoryBytes = se::DeviceMemory<uint8>;
|
|
|
|
|
BatchMatmulParameters(bool trans_a, bool trans_b, bool adj_a, bool adj_b,
|
|
|
|
|
uint64 m, uint64 n, uint64 k, uint64 batch_count,
|
|
|
|
|
bool broadcast_a, bool broadcast_b, DataType dtype_ab,
|
|
|
|
|
DataType dtype_cd, bool allow_tf32, int device_id)
|
|
|
|
|
: trans_a_(trans_a),
|
|
|
|
|
trans_b_(trans_b),
|
|
|
|
|
adj_a_(adj_a),
|
|
|
|
|
adj_b_(adj_b),
|
|
|
|
|
m_(m),
|
|
|
|
|
n_(n),
|
|
|
|
|
k_(k),
|
|
|
|
|
batch_count_(batch_count),
|
|
|
|
|
broadcast_a_(broadcast_a),
|
|
|
|
|
broadcast_b_(broadcast_b),
|
|
|
|
|
dtype_ab_(dtype_ab),
|
|
|
|
|
dtype_cd_(dtype_cd),
|
|
|
|
|
allow_tf32_(allow_tf32),
|
|
|
|
|
device_id_(device_id) {
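// Fold every field into the hash so any change in the matmul signature maps
// to a distinct cache key.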
|
|
|
|
|
hash_code_ = trans_a;
|
|
|
|
|
hash_code_ = Hash64Combine(hash_code_, trans_b);
|
|
|
|
|
hash_code_ = Hash64Combine(hash_code_, adj_a);
|
|
|
|
|
hash_code_ = Hash64Combine(hash_code_, adj_b);
|
|
|
|
|
hash_code_ = Hash64Combine(hash_code_, m);
|
|
|
|
|
hash_code_ = Hash64Combine(hash_code_, n);
|
|
|
|
|
hash_code_ = Hash64Combine(hash_code_, k);
|
|
|
|
|
hash_code_ = Hash64Combine(hash_code_, batch_count);
|
|
|
|
|
hash_code_ = Hash64Combine(hash_code_, broadcast_a);
|
|
|
|
|
hash_code_ = Hash64Combine(hash_code_, broadcast_b);
|
|
|
|
|
hash_code_ = Hash64Combine(hash_code_, dtype_ab);
|
|
|
|
|
hash_code_ = Hash64Combine(hash_code_, dtype_cd);
|
|
|
|
|
hash_code_ = Hash64Combine(hash_code_, allow_tf32);
|
|
|
|
|
hash_code_ = Hash64Combine(hash_code_, device_id);
|
|
|
|
|
}
|
|
|
|
|
bool operator==(const BatchMatmulParameters& other) const {
|
|
|
|
|
return this->get_data_as_tuple() == other.get_data_as_tuple();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
BlasScratchAllocator(OpKernelContext* context) : context_(context) {}
|
|
|
|
|
bool operator!=(const BatchMatmulParameters& other) const {
|
|
|
|
|
return !(*this == other);
|
|
|
|
|
}
|
|
|
|
|
uint64 hash() const { return hash_code_; }
|
|
|
|
|
|
|
|
|
|
int64 GetMemoryLimitInBytes() override { return -1; }
|
|
|
|
|
|
|
|
|
|
se::port::StatusOr<DeviceMemoryBytes> AllocateBytes(
|
|
|
|
|
int64 byte_size) override {
|
|
|
|
|
Tensor temporary_memory;
|
|
|
|
|
|
|
|
|
|
Status allocation_status(context_->allocate_temp(
|
|
|
|
|
DT_UINT8, TensorShape({byte_size}), &temporary_memory));
|
|
|
|
|
if (!allocation_status.ok()) {
|
|
|
|
|
return se::port::StatusOr<DeviceMemoryBytes>(
|
|
|
|
|
DeviceMemoryBytes::MakeFromByteSize(nullptr, 0));
|
|
|
|
|
}
|
|
|
|
|
// Hold the reference of the allocated tensors until the end of the
|
|
|
|
|
// allocator.
|
|
|
|
|
allocated_tensors_.push_back(temporary_memory);
|
|
|
|
|
return se::port::StatusOr<DeviceMemoryBytes>(
|
|
|
|
|
DeviceMemoryBytes::MakeFromByteSize(
|
|
|
|
|
temporary_memory.flat<uint8>().data(),
|
|
|
|
|
temporary_memory.flat<uint8>().size()));
|
|
|
|
|
string ToString() const {
|
|
|
|
|
// clang-format off
|
|
|
|
|
return strings::StrCat(
|
|
|
|
|
trans_a_, ", ", trans_b_, ", ", adj_a_, ", ", adj_b_, ", ",
|
|
|
|
|
m_, ", ", n_, ", ", k_, ", ", batch_count_, ", ",
|
|
|
|
|
broadcast_a_, ", ", broadcast_b_, ", ",
|
|
|
|
|
dtype_ab_, ", ", dtype_cd_, ", ", allow_tf32_, ", ", device_id_);
|
|
|
|
|
// clang-format on
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
OpKernelContext* context_;
|
|
|
|
|
std::vector<Tensor> allocated_tensors_;
|
|
|
|
|
typedef std::tuple<bool, bool, bool, bool, int64, int64, int64, int64, bool,
|
|
|
|
|
bool, DataType, DataType, bool, int>
|
|
|
|
|
ParameterDataType;
|
|
|
|
|
|
|
|
|
|
ParameterDataType get_data_as_tuple() const {
|
|
|
|
|
return std::make_tuple(trans_a_, trans_b_, adj_a_, adj_b_, m_, n_, k_,
|
|
|
|
|
batch_count_, broadcast_a_, broadcast_b_, dtype_ab_,
|
|
|
|
|
dtype_cd_, allow_tf32_, device_id_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool trans_a_;
|
|
|
|
|
bool trans_b_;
|
|
|
|
|
bool adj_a_;
|
|
|
|
|
bool adj_b_;
|
|
|
|
|
uint64 m_;
|
|
|
|
|
uint64 n_;
|
|
|
|
|
uint64 k_;
|
|
|
|
|
uint64 batch_count_;
|
|
|
|
|
bool broadcast_a_;
|
|
|
|
|
bool broadcast_b_;
|
|
|
|
|
DataType dtype_ab_;
|
|
|
|
|
DataType dtype_cd_;
|
|
|
|
|
bool allow_tf32_;
|
|
|
|
|
int device_id_;
|
|
|
|
|
uint64 hash_code_;
|
|
|
|
|
};
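// Maps a TensorFlow DataType to the cuBLAS computation type, honoring TF32
// and the FP32-for-FP16 compute setting; returns false for unsupported dtypes.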
|
|
|
|
|
|
|
|
|
|
bool GetBlasComputationType(const DataType& dtype, bool allow_tf32,
|
|
|
|
|
se::blas::ComputationType* compute_type) {
|
|
|
|
|
using se::blas::ComputationType;
|
|
|
|
|
static bool use_f32_for_f16_computation = MatmulDoFP32ComputationFP16Input();
|
|
|
|
|
ComputationType f32_type =
|
|
|
|
|
allow_tf32 ? ComputationType::kF32FastTF32 : ComputationType::kF32;
|
|
|
|
|
switch (dtype) {
|
|
|
|
|
case DT_HALF:
|
|
|
|
|
case DT_BFLOAT16:
|
|
|
|
|
*compute_type =
|
|
|
|
|
use_f32_for_f16_computation ? f32_type : ComputationType::kF16;
|
|
|
|
|
return true;
|
|
|
|
|
case DT_FLOAT:
|
|
|
|
|
*compute_type = f32_type;
|
|
|
|
|
return true;
|
|
|
|
|
case DT_DOUBLE:
|
|
|
|
|
*compute_type = ComputationType::kF64;
|
|
|
|
|
return true;
|
|
|
|
|
case DT_COMPLEX64:
|
|
|
|
|
*compute_type = f32_type;
|
|
|
|
|
return true;
|
|
|
|
|
case DT_COMPLEX128:
|
|
|
|
|
*compute_type = ComputationType::kComplexF64;
|
|
|
|
|
return true;
|
|
|
|
|
default:
|
|
|
|
|
// Unsupported compute_type, return false.
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Thread-safe map from matmul parameters to their corresponding plan and
|
|
|
|
|
// algorithms.
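// Lookups hash the Parameters via their hash() method (see the Hasher functor
// below).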
|
|
|
|
|
template <typename Parameters>
|
|
|
|
|
class BlasLtMatmulPlanMap {
|
|
|
|
|
public:
|
|
|
|
|
struct PlanAndAlgorithms {
|
|
|
|
|
std::unique_ptr<se::blas::IBlasLtMatmulPlan> plan;
|
|
|
|
|
std::vector<std::unique_ptr<se::blas::IBlasLtMatmulAlgorithm>> algorithms;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
const PlanAndAlgorithms* Find(const Parameters& params) {
|
|
|
|
|
mutex_lock lock(mu_);
|
|
|
|
|
auto iter = params_plan_map_.find(params);
|
|
|
|
|
if (iter == params_plan_map_.end()) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
return &iter->second;
|
|
|
|
|
}
|
|
|
|
|
const PlanAndAlgorithms* Insert(const Parameters& params,
|
|
|
|
|
PlanAndAlgorithms value) {
|
|
|
|
|
mutex_lock lock(mu_);
|
|
|
|
|
return &params_plan_map_.emplace(params, std::move(value)).first->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
struct Hasher {
|
|
|
|
|
std::size_t operator()(const Parameters& parameter) const {
|
|
|
|
|
return parameter.hash();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
mutable mutex mu_;
|
|
|
|
|
std::unordered_map<Parameters, PlanAndAlgorithms, Hasher> params_plan_map_
|
|
|
|
|
GUARDED_BY(mu_);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename Parameters>
|
|
|
|
|
struct BlasLtPlanMapSingleton {
|
|
|
|
|
typedef BlasLtMatmulPlanMap<Parameters> PlanMapType;
|
|
|
|
|
static PlanMapType* GetInstance() {
|
|
|
|
|
static PlanMapType* instance = new PlanMapType();
|
|
|
|
|
return instance;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
typedef BlasLtPlanMapSingleton<BatchMatmulParameters>
|
|
|
|
|
BatchMatmulPlanMapSingleton;
|
|
|
|
|
|
|
|
|
|
// A dummy type to group matmul autotune results together.
|
|
|
|
|
struct BatchMatmulAutoTuneGroup {
|
|
|
|
|
static string name() { return "MatmulLt"; }
|
|
|
|
|
};
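// Caches the best algorithm index found by autotuning, keyed by
// BatchMatmulParameters.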
|
|
|
|
|
|
|
|
|
|
typedef AutoTuneSingleton<BatchMatmulAutoTuneGroup, BatchMatmulParameters,
|
|
|
|
|
se::blas::AlgorithmConfig>
|
|
|
|
|
AutoTuneBatchMatmul;
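// Type used for the alpha/beta coefficients in the Blas calls: float for
// Eigen::half, otherwise the input scalar type itself.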
|
|
|
|
|
|
|
|
|
|
template <typename Scalar>
|
|
|
|
|
struct CoefficientType {
|
|
|
|
|
typedef Scalar type;
|
|
|
|
|
};
|
|
|
|
|
template <>
|
|
|
|
|
struct CoefficientType<Eigen::half> {
|
|
|
|
|
typedef float type;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
template <typename Scalar>
|
|
|
|
|
struct LaunchBatchMatMul<GPUDevice, Scalar> {
|
|
|
|
|
static void Launch(OpKernelContext* context, const Tensor& in_x,
|
|
|
|
|
const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
|
|
|
|
|
bool trans_y, const MatMulBCast& bcast, Tensor* out) {
|
|
|
|
|
bool trans_y, const MatMulBCast& bcast, bool use_autotune,
|
|
|
|
|
Tensor* out) {
|
|
|
|
|
se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
|
|
|
|
|
se::blas::Transpose::kTranspose,
|
|
|
|
|
se::blas::Transpose::kConjugateTranspose};
|
|
|
|
@@ -347,6 +510,198 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
|
|
|
|
|
uint64 b_stride;
|
|
|
|
|
uint64 c_stride;
|
|
|
|
|
|
|
|
|
|
typedef typename CoefficientType<Scalar>::type Coefficient;
|
|
|
|
|
|
|
|
|
|
static const int64 max_scratch_size = GetBlasWorkspaceLimit(
|
|
|
|
|
"TF_CUBLAS_WORKSPACE_LIMIT_IN_MB", 1LL << 32); // 4GB by default
|
|
|
|
|
|
|
|
|
|
// The BlasLtMatmul routines are only supported from CUDA 11.0 onward.
|
|
|
|
|
#if GOOGLE_CUDA && CUDA_VERSION >= 11000
|
|
|
|
|
bool is_full_broadcast =
|
|
|
|
|
std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
|
|
|
|
|
bool requires_mixed_broadcasting =
|
|
|
|
|
bcast.IsBroadcastingRequired() && !is_full_broadcast;
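// Mixed broadcasting: broadcasting is needed but neither operand is a single
// matrix, so per-batch pointer arrays are required instead of a strided batch.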
|
|
|
|
|
if (!requires_mixed_broadcasting) {
|
|
|
|
|
bool broadcast_a = bcast.x_batch_size() == 1;
|
|
|
|
|
bool broadcast_b = bcast.y_batch_size() == 1;
|
|
|
|
|
a_stride = broadcast_a ? 0 : m * k;
|
|
|
|
|
b_stride = broadcast_b ? 0 : k * n;
|
|
|
|
|
c_stride = m * n;
|
|
|
|
|
a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
|
|
|
|
|
b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
|
|
|
|
|
c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
|
|
|
|
|
a_ptrs.push_back(&a_device_memory.back());
|
|
|
|
|
b_ptrs.push_back(&b_device_memory.back());
|
|
|
|
|
c_ptrs.push_back(&c_device_memory.back());
|
|
|
|
|
|
|
|
|
|
DataType dtype = DataTypeToEnum<Scalar>::value;
|
|
|
|
|
bool allow_tf32 = tensor_float_32_execution_enabled();
|
|
|
|
|
int device_id = stream->parent()->device_ordinal();
|
|
|
|
|
BatchMatmulParameters matmul_parameters(
|
|
|
|
|
trans_x, trans_y, adj_x, adj_y, m, n, k, batch_size, broadcast_a,
|
|
|
|
|
broadcast_b, dtype, dtype, allow_tf32, device_id);
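// When autotuning is disabled, only one algorithm is requested and it is used
// unconditionally.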
|
|
|
|
|
|
|
|
|
|
static const int64 max_autotune_algorithm_count =
|
|
|
|
|
MatmulMaxAutotuneAlgorithmCount();
|
|
|
|
|
int max_algorithm_count = use_autotune ? max_autotune_algorithm_count : 1;
|
|
|
|
|
|
|
|
|
|
const auto* plan_and_algorithms =
|
|
|
|
|
BatchMatmulPlanMapSingleton::GetInstance()->Find(matmul_parameters);
|
|
|
|
|
if (!plan_and_algorithms) {
|
|
|
|
|
se::blas::DataType blas_dtype = se::blas::ToDataType<Scalar>::value;
|
|
|
|
|
se::blas::ComputationType computation_type;
|
|
|
|
|
OP_REQUIRES(
|
|
|
|
|
context,
|
|
|
|
|
GetBlasComputationType(dtype, allow_tf32, &computation_type),
|
|
|
|
|
errors::Internal("Unsupported dtype for batched matmul"));
|
|
|
|
|
std::unique_ptr<se::blas::IBlasLtMatmulPlan> plan =
|
|
|
|
|
stream->parent()->CreateBlasLtMatmulPlanStridedBatched(
|
|
|
|
|
/*ab_type=*/blas_dtype,
|
|
|
|
|
/*cd_type=*/blas_dtype, computation_type,
|
|
|
|
|
se::blas::PointerMode::kHost, blas_transpose_b,
|
|
|
|
|
blas_transpose_a, n, m, k, batch_size,
|
|
|
|
|
/*lda=*/in_y.dim_size(2), b_stride,
|
|
|
|
|
/*ldb=*/in_x.dim_size(2), a_stride, /*ldc=*/n, c_stride);
|
|
|
|
|
OP_REQUIRES(
|
|
|
|
|
context, plan,
|
|
|
|
|
errors::Internal(
|
|
|
|
|
"CreateBlasLtMatmulPlanStridedBatched failed : a.shape=(",
|
|
|
|
|
in_x.dim_size(0), ", ", in_x.dim_size(1), ", ",
|
|
|
|
|
in_x.dim_size(2), "), b.shape=(", in_y.dim_size(0), ", ",
|
|
|
|
|
in_y.dim_size(1), ", ", in_y.dim_size(2), "), m=", m, ", n=", n,
|
|
|
|
|
", k=", k, ", batch_size=", batch_size, ", adjoint_a=", adj_x,
|
|
|
|
|
", adjoint_b=", adj_y, ", dtype=", dtype,
|
|
|
|
|
", computation_type=", computation_type));
|
|
|
|
|
std::vector<std::unique_ptr<se::blas::IBlasLtMatmulAlgorithm>>
|
|
|
|
|
algorithms;
|
|
|
|
|
OP_REQUIRES(
|
|
|
|
|
context,
|
|
|
|
|
stream->parent()->GetBlasLtMatmulAlgorithms(
|
|
|
|
|
plan.get(), max_scratch_size, max_algorithm_count, &algorithms),
|
|
|
|
|
errors::Internal("GetBlasLtMatmulAlgorithms failed: a.shape=(",
|
|
|
|
|
in_x.dim_size(0), ", ", in_x.dim_size(1), ", ",
|
|
|
|
|
in_x.dim_size(2), "), b.shape=(", in_y.dim_size(0),
|
|
|
|
|
", ", in_y.dim_size(1), ", ", in_y.dim_size(2),
|
|
|
|
|
"), m=", m, ", n=", n, ", k=", k,
|
|
|
|
|
", batch_size=", batch_size, ", adjoint_a=", adj_x,
|
|
|
|
|
", adjoint_b=", adj_y, ", dtype=", dtype,
|
|
|
|
|
", computation_type=", computation_type));
|
|
|
|
|
plan_and_algorithms =
|
|
|
|
|
BatchMatmulPlanMapSingleton::GetInstance()->Insert(
|
|
|
|
|
matmul_parameters, {std::move(plan), std::move(algorithms)});
|
|
|
|
|
}
|
|
|
|
|
const auto& plan = plan_and_algorithms->plan;
|
|
|
|
|
const auto& algorithms = plan_and_algorithms->algorithms;
|
|
|
|
|
|
|
|
|
|
// The BlasLtMatmul routines (unlike BlasGemm, BlasGemmBatched etc.) take
|
|
|
|
|
// alpha and beta with the same type as the matrices.
|
|
|
|
|
Scalar alpha(1.0);
|
|
|
|
|
Scalar beta(0.0);
|
|
|
|
|
|
|
|
|
|
// Note that algorithm_config.algorithm() here is used to refer
|
|
|
|
|
// to the index within the algorithms vector, not the algorithm
|
|
|
|
|
// itself.
|
|
|
|
|
se::blas::AlgorithmConfig algorithm_config(se::blas::kNoAlgorithm);
|
|
|
|
|
if (max_algorithm_count == 1) {
|
|
|
|
|
algorithm_config.set_algorithm(0);
|
|
|
|
|
} else if (!AutoTuneBatchMatmul::GetInstance()->Find(matmul_parameters,
|
|
|
|
|
&algorithm_config)) {
|
|
|
|
|
VLOG(4) << "Autotuning BlasLtMatmul over " << algorithms.size()
|
|
|
|
|
<< " algorithms.";
|
|
|
|
|
se::blas::ProfileResult best_result;
|
|
|
|
|
se::blas::ProfileResult profile_result;
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i != algorithms.size(); ++i) {
|
|
|
|
|
const auto& profile_algorithm = algorithms[i];
|
|
|
|
|
// Create a new scratch allocator with every autotuning run so that
|
|
|
|
|
// scratch space is deallocated between runs.
|
|
|
|
|
BlasScratchAllocator scratch_allocator(max_scratch_size, context);
|
|
|
|
|
|
|
|
|
|
bool cublas_launch_status =
|
|
|
|
|
stream
|
|
|
|
|
->ThenBlasLtMatmul(plan.get(), alpha, *b_ptrs[0], *a_ptrs[0],
|
|
|
|
|
beta, c_ptrs[0], &scratch_allocator,
|
|
|
|
|
profile_algorithm.get(), &profile_result)
|
|
|
|
|
.ok();
|
|
|
|
|
|
|
|
|
|
VLOG(4) << " Autotune algorithm " << i
|
|
|
|
|
<< " result: " << profile_result.elapsed_time_in_ms()
|
|
|
|
|
<< " ms, valid=" << profile_result.is_valid()
|
|
|
|
|
<< ", workspace_size="
|
|
|
|
|
<< profile_algorithm->workspace_size();
|
|
|
|
|
|
|
|
|
|
if (cublas_launch_status && profile_result.is_valid() &&
|
|
|
|
|
profile_result.elapsed_time_in_ms() <
|
|
|
|
|
best_result.elapsed_time_in_ms()) {
|
|
|
|
|
best_result = profile_result;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (best_result.is_valid()) {
|
|
|
|
|
algorithm_config.set_algorithm(best_result.algorithm());
|
|
|
|
|
}
|
|
|
|
|
// We make sure that each matmul parameter set only gets one pass of
|
|
|
|
|
// autotune. If no algorithm works, we add kNoAlgorithm to the autotune
|
|
|
|
|
// map.
|
|
|
|
|
AutoTuneBatchMatmul::GetInstance()->Insert(matmul_parameters,
|
|
|
|
|
algorithm_config);
|
|
|
|
|
}
|
|
|
|
|
se::blas::AlgorithmType algorithm_idx = algorithm_config.algorithm();
|
|
|
|
|
OP_REQUIRES(context,
|
|
|
|
|
0 <= algorithm_idx && algorithm_idx < algorithms.size(),
|
|
|
|
|
errors::Internal("Missing/invalid BatchMatmul algorithm"));
|
|
|
|
|
const auto& algorithm = algorithms[algorithm_idx];
|
|
|
|
|
BlasScratchAllocator scratch_allocator(max_scratch_size, context);
|
|
|
|
|
bool cublas_launch_status =
|
|
|
|
|
stream
|
|
|
|
|
->ThenBlasLtMatmul(plan.get(), alpha, *b_ptrs[0], *a_ptrs[0],
|
|
|
|
|
beta, c_ptrs[0], &scratch_allocator,
|
|
|
|
|
algorithm.get())
|
|
|
|
|
.ok();
|
|
|
|
|
if (!cublas_launch_status) {
|
|
|
|
|
context->SetStatus(errors::Internal(
|
|
|
|
|
"Blas batched matmul launch failed : a.shape=(",
|
|
|
|
|
bcast.x_batch_size(), ", ", in_x.dim_size(0), ", ",
|
|
|
|
|
in_x.dim_size(1), "), b.shape=(", bcast.y_batch_size(), ", ",
|
|
|
|
|
in_y.dim_size(0), ", ", in_y.dim_size(1), "), m=", m, ", n=", n,
|
|
|
|
|
", k=", k, ", batch_size=", batch_size));
|
|
|
|
|
}
|
|
|
|
|
} else { // requires mixed broadcasting
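// Build per-output-batch pointer arrays, mapping each output entry to its
// (possibly shared) A and B operands, and use the pointer-based batched GEMM.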
|
|
|
|
|
const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
|
|
|
|
|
const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
|
|
|
|
|
for (int64 i = 0; i < bcast.x_batch_size(); ++i) {
|
|
|
|
|
a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
|
|
|
|
|
}
|
|
|
|
|
for (int64 i = 0; i < bcast.y_batch_size(); ++i) {
|
|
|
|
|
b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
|
|
|
|
|
}
|
|
|
|
|
for (int64 i = 0; i < batch_size; ++i) {
|
|
|
|
|
c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
|
|
|
|
|
a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
|
|
|
|
|
b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
|
|
|
|
|
c_ptrs.push_back(&c_device_memory.back());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
BlasScratchAllocator scratch_allocator(max_scratch_size, context);
|
|
|
|
|
bool blas_launch_status =
|
|
|
|
|
stream
|
|
|
|
|
->ThenBlasGemmBatchedWithScratch(
|
|
|
|
|
blas_transpose_b, blas_transpose_a, n, m, k,
|
|
|
|
|
static_cast<Coefficient>(1.0), b_ptrs,
|
|
|
|
|
adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
|
|
|
|
|
static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
|
|
|
|
|
&scratch_allocator)
|
|
|
|
|
.ok();
|
|
|
|
|
if (!blas_launch_status) {
|
|
|
|
|
context->SetStatus(errors::Internal(
|
|
|
|
|
"Blas xGEMMBatched launch failed : a.shape=",
|
|
|
|
|
in_x.shape().DebugString(),
|
|
|
|
|
", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
|
|
|
|
|
", k=", k, ", batch_size=", batch_size));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return;
|
|
|
|
|
#else // if not GOOGLE_CUDA or CUDA_VERSION < 11000
|
|
|
|
|
bool is_full_broadcast =
|
|
|
|
|
std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
|
|
|
|
|
bool use_strided_batched =
|
|
|
|
@@ -388,8 +743,6 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
typedef Scalar Coefficient;
|
|
|
|
|
|
|
|
|
|
// Blas does
|
|
|
|
|
// C = A x B
|
|
|
|
|
// where A, B and C are assumed to be in column major.
|
|
|
|
@@ -399,7 +752,10 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
|
|
|
|
|
if (batch_size == 1) {
|
|
|
|
|
// This is a regular matrix*matrix or matrix*vector multiply. Avoid the
|
|
|
|
|
// overhead of the scratch allocator and the batch interface.
|
|
|
|
|
if (n == 1 &&
|
|
|
|
|
// Note that the GEMV call here does not support Eigen::half, so we do not
|
|
|
|
|
// use this path in that case. A workaround is applied to the pointers
|
|
|
|
|
// passed to the call itself to avoid compilation errors.
|
|
|
|
|
if (!std::is_same<Scalar, Eigen::half>::value && n == 1 &&
|
|
|
|
|
blas_transpose_b != se::blas::Transpose::kConjugateTranspose &&
|
|
|
|
|
blas_transpose_a != se::blas::Transpose::kConjugateTranspose) {
|
|
|
|
|
// This is a matrix*vector multiply so use GEMV to compute A * b.
|
|
|
|
@@ -410,13 +766,19 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
|
|
|
|
|
auto gemv_trans_a = blas_transpose_a == se::blas::Transpose::kTranspose
|
|
|
|
|
? se::blas::Transpose::kNoTranspose
|
|
|
|
|
: se::blas::Transpose::kTranspose;
|
|
|
|
|
// Cast pointers as a workaround for GEMV not supporting Eigen::half
|
|
|
|
|
// (this will never actually be executed for Eigen::half).
|
|
|
|
|
typedef se::DeviceMemory<Coefficient> NonHalfDeviceMemoryType;
|
|
|
|
|
NonHalfDeviceMemoryType a_ptr(*(a_ptrs[0]));
|
|
|
|
|
NonHalfDeviceMemoryType b_ptr(*(b_ptrs[0]));
|
|
|
|
|
NonHalfDeviceMemoryType c_ptr(*(c_ptrs[0]));
|
|
|
|
|
bool blas_launch_status =
|
|
|
|
|
stream
|
|
|
|
|
->ThenBlasGemv(gemv_trans_a, adj_x || trans_x ? m : k,
|
|
|
|
|
adj_x || trans_x ? k : m,
|
|
|
|
|
static_cast<Coefficient>(1.0), *(a_ptrs[0]),
|
|
|
|
|
adj_x || trans_x ? m : k, *(b_ptrs[0]), 1,
|
|
|
|
|
static_cast<Coefficient>(0.0), c_ptrs[0], 1)
|
|
|
|
|
static_cast<Coefficient>(1.0), a_ptr,
|
|
|
|
|
adj_x || trans_x ? m : k, b_ptr, 1,
|
|
|
|
|
static_cast<Coefficient>(0.0), &c_ptr, 1)
|
|
|
|
|
.ok();
|
|
|
|
|
if (!blas_launch_status) {
|
|
|
|
|
context->SetStatus(errors::Internal(
|
|
|
|
@@ -459,154 +821,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
|
|
|
|
|
", k=", k, ", batch_size=", batch_size));
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
BlasScratchAllocator scratch_allocator(context);
|
|
|
|
|
bool blas_launch_status =
|
|
|
|
|
stream
|
|
|
|
|
->ThenBlasGemmBatchedWithScratch(
|
|
|
|
|
blas_transpose_b, blas_transpose_a, n, m, k,
|
|
|
|
|
static_cast<Coefficient>(1.0), b_ptrs,
|
|
|
|
|
adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
|
|
|
|
|
static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
|
|
|
|
|
&scratch_allocator)
|
|
|
|
|
.ok();
|
|
|
|
|
if (!blas_launch_status) {
|
|
|
|
|
context->SetStatus(errors::Internal(
|
|
|
|
|
"Blas xGEMMBatched launch failed : a.shape=",
|
|
|
|
|
in_x.shape().DebugString(),
|
|
|
|
|
", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
|
|
|
|
|
", k=", k, ", batch_size=", batch_size));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
|
|
|
|
|
static void Launch(OpKernelContext* context, const Tensor& in_x,
|
|
|
|
|
const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
|
|
|
|
|
bool trans_y, const MatMulBCast& bcast, Tensor* out) {
|
|
|
|
|
typedef Eigen::half Scalar;
|
|
|
|
|
se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
|
|
|
|
|
se::blas::Transpose::kTranspose,
|
|
|
|
|
se::blas::Transpose::kConjugateTranspose};
|
|
|
|
|
const uint64 m = in_x.dim_size(adj_x || trans_x ? 2 : 1);
|
|
|
|
|
const uint64 k = in_x.dim_size(adj_x || trans_x ? 1 : 2);
|
|
|
|
|
const uint64 n = in_y.dim_size(adj_y || trans_y ? 1 : 2);
|
|
|
|
|
const uint64 batch_size = bcast.output_batch_size();
|
|
|
|
|
auto blas_transpose_a = trans[adj_x ? 2 : (trans_x ? 1 : 0)];
|
|
|
|
|
auto blas_transpose_b = trans[adj_y ? 2 : (trans_y ? 1 : 0)];
|
|
|
|
|
|
|
|
|
|
auto* stream = context->op_device_context()->stream();
|
|
|
|
|
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
|
|
|
|
|
|
|
|
|
|
typedef perftools::gputools::DeviceMemory<Scalar> DeviceMemoryType;
|
|
|
|
|
std::vector<DeviceMemoryType> a_device_memory;
|
|
|
|
|
std::vector<DeviceMemoryType> b_device_memory;
|
|
|
|
|
std::vector<DeviceMemoryType> c_device_memory;
|
|
|
|
|
std::vector<DeviceMemoryType*> a_ptrs;
|
|
|
|
|
std::vector<DeviceMemoryType*> b_ptrs;
|
|
|
|
|
std::vector<DeviceMemoryType*> c_ptrs;
|
|
|
|
|
a_device_memory.reserve(bcast.x_batch_size());
|
|
|
|
|
b_device_memory.reserve(bcast.y_batch_size());
|
|
|
|
|
c_device_memory.reserve(batch_size);
|
|
|
|
|
a_ptrs.reserve(batch_size);
|
|
|
|
|
b_ptrs.reserve(batch_size);
|
|
|
|
|
c_ptrs.reserve(batch_size);
|
|
|
|
|
auto* a_base_ptr = in_x.template flat<Scalar>().data();
|
|
|
|
|
auto* b_base_ptr = in_y.template flat<Scalar>().data();
|
|
|
|
|
auto* c_base_ptr = out->template flat<Scalar>().data();
|
|
|
|
|
|
|
|
|
|
uint64 a_stride;
|
|
|
|
|
uint64 b_stride;
|
|
|
|
|
uint64 c_stride;
|
|
|
|
|
|
|
|
|
|
bool is_full_broadcast =
|
|
|
|
|
std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;
|
|
|
|
|
bool use_strided_batched =
|
|
|
|
|
(!bcast.IsBroadcastingRequired() || is_full_broadcast) &&
|
|
|
|
|
batch_size > 1;
|
|
|
|
|
if (use_strided_batched) {
|
|
|
|
|
a_stride = bcast.x_batch_size() != 1 ? m * k : 0;
|
|
|
|
|
b_stride = bcast.y_batch_size() != 1 ? k * n : 0;
|
|
|
|
|
c_stride = m * n;
|
|
|
|
|
a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
|
|
|
|
|
b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
|
|
|
|
|
c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
|
|
|
|
|
a_ptrs.push_back(&a_device_memory.back());
|
|
|
|
|
b_ptrs.push_back(&b_device_memory.back());
|
|
|
|
|
c_ptrs.push_back(&c_device_memory.back());
|
|
|
|
|
} else if (!bcast.IsBroadcastingRequired()) {
|
|
|
|
|
for (int64 i = 0; i < batch_size; ++i) {
|
|
|
|
|
a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
|
|
|
|
|
b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
|
|
|
|
|
c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
|
|
|
|
|
a_ptrs.push_back(&a_device_memory.back());
|
|
|
|
|
b_ptrs.push_back(&b_device_memory.back());
|
|
|
|
|
c_ptrs.push_back(&c_device_memory.back());
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
|
|
|
|
|
const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();
|
|
|
|
|
for (int64 i = 0; i < bcast.x_batch_size(); ++i) {
|
|
|
|
|
a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
|
|
|
|
|
}
|
|
|
|
|
for (int64 i = 0; i < bcast.y_batch_size(); ++i) {
|
|
|
|
|
b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
|
|
|
|
|
}
|
|
|
|
|
for (int64 i = 0; i < batch_size; ++i) {
|
|
|
|
|
c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
|
|
|
|
|
a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
|
|
|
|
|
b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
|
|
|
|
|
c_ptrs.push_back(&c_device_memory.back());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
typedef float Coefficient;
|
|
|
|
|
|
|
|
|
|
// Blas does
|
|
|
|
|
// C = A x B
|
|
|
|
|
// where A, B and C are assumed to be in column major.
|
|
|
|
|
// We want the output to be in row-major, so we can compute
|
|
|
|
|
// C' = B' x A', where ' stands for transpose (not adjoint).
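// Swapping the operand order (B before A) and the m/n dimensions therefore
// produces the row-major C directly, with no explicit transposition pass.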
|
|
|
|
|
// TODO(yangzihao): Choose the best of the three strategies using autotune.
|
|
|
|
|
if (batch_size == 1) {
|
|
|
|
|
// This is a regular matrix*matrix or matrix*vector multiply. Avoid the
|
|
|
|
|
// overhead of the scratch allocator and the batch interface.
|
|
|
|
|
// TODO(benbarsdell): Use fp16 Gemv if it becomes supported by CUBLAS
|
|
|
|
|
bool blas_launch_status =
|
|
|
|
|
stream
|
|
|
|
|
->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
|
|
|
|
|
static_cast<Coefficient>(1.0), *(b_ptrs[0]),
|
|
|
|
|
adj_y || trans_y ? k : n, *(a_ptrs[0]),
|
|
|
|
|
adj_x || trans_x ? m : k,
|
|
|
|
|
static_cast<Coefficient>(0.0), c_ptrs[0], n)
|
|
|
|
|
.ok();
|
|
|
|
|
if (!blas_launch_status) {
|
|
|
|
|
context->SetStatus(errors::Internal(
|
|
|
|
|
"Blas xGEMM launch failed : a.shape=", in_x.shape().DebugString(),
|
|
|
|
|
", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
|
|
|
|
|
", k=", k));
|
|
|
|
|
}
|
|
|
|
|
} else if (use_strided_batched) {
|
|
|
|
|
bool blas_launch_status =
|
|
|
|
|
stream
|
|
|
|
|
->ThenBlasGemmStridedBatched(
|
|
|
|
|
blas_transpose_b, blas_transpose_a, n, m, k,
|
|
|
|
|
static_cast<Coefficient>(1.0), *b_ptrs[0],
|
|
|
|
|
adj_y || trans_y ? k : n, b_stride, *a_ptrs[0],
|
|
|
|
|
adj_x || trans_x ? m : k, a_stride,
|
|
|
|
|
static_cast<Coefficient>(0.0), c_ptrs[0], n, c_stride,
|
|
|
|
|
batch_size)
|
|
|
|
|
.ok();
|
|
|
|
|
if (!blas_launch_status) {
|
|
|
|
|
context->SetStatus(errors::Internal(
|
|
|
|
|
"Blas xGEMMStridedBatched launch failed : a.shape=",
|
|
|
|
|
in_x.shape().DebugString(),
|
|
|
|
|
", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
|
|
|
|
|
", k=", k, ", batch_size=", batch_size));
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
BlasScratchAllocator scratch_allocator(context);
|
|
|
|
|
BlasScratchAllocator scratch_allocator(max_scratch_size, context);
|
|
|
|
|
bool blas_launch_status =
|
|
|
|
|
stream
|
|
|
|
|
->ThenBlasGemmBatchedWithScratch(
|
|
|
|
@@ -624,6 +839,7 @@ struct LaunchBatchMatMul<GPUDevice, Eigen::half> {
|
|
|
|
|
", k=", k, ", batch_size=", batch_size));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#endif // not GOOGLE_CUDA or CUDA_VERSION < 11000
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@@ -637,6 +853,7 @@ class BaseBatchMatMulOp : public OpKernel {
|
|
|
|
|
: OpKernel(context) {
|
|
|
|
|
OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
|
|
|
|
|
OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
|
|
|
|
|
use_autotune_ = MatmulAutotuneEnable();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
~BaseBatchMatMulOp() override {}
|
|
|
|
@@ -698,7 +915,7 @@ class BaseBatchMatMulOp : public OpKernel {
|
|
|
|
|
out->shape().DebugString()));
|
|
|
|
|
LaunchBatchMatMul<Device, Scalar>::Launch(
|
|
|
|
|
ctx, in0_reshaped, in1_reshaped, adj_x_, adj_y_, /*trans_x=*/false,
|
|
|
|
|
/*trans_y=*/false, bcast, &out_reshaped);
|
|
|
|
|
/*trans_y=*/false, bcast, use_autotune_, &out_reshaped);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
@@ -708,6 +925,7 @@ class BaseBatchMatMulOp : public OpKernel {
|
|
|
|
|
private:
|
|
|
|
|
bool adj_x_;
|
|
|
|
|
bool adj_y_;
|
|
|
|
|
bool use_autotune_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// BatchMatMul Op implementation which disallows broadcasting.
|
|
|
|
|