Add support for blasLt epilogue fn and bias vector

- Changes the backend APIs to allow an epilogue function (default, ReLU,
  bias, or bias then ReLU) to be specified and a bias vector to be
  provided.
- This is expected to be useful for XLA to perform fusions.
- This functionality is not currently tested, because the BatchMatMulOp
  does not expose relu/bias fusion.
This commit is contained in:
Ben Barsdell 2020-08-05 09:36:23 +10:00
parent 8c0eb4b35b
commit 39bf03f083
8 changed files with 182 additions and 84 deletions

View File

@ -558,8 +558,8 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
stream->parent()->CreateBlasLtMatmulPlanStridedBatched(
/*ab_type=*/blas_dtype,
/*cd_type=*/blas_dtype, computation_type,
se::blas::PointerMode::kHost, blas_transpose_b,
blas_transpose_a, n, m, k, batch_size,
se::blas::PointerMode::kHost, se::blas::Epilogue::kDefault,
blas_transpose_b, blas_transpose_a, n, m, k, batch_size,
/*lda=*/in_y.dim_size(2), b_stride,
/*ldb=*/in_x.dim_size(2), a_stride, /*ldc=*/n, c_stride);
OP_REQUIRES(
@ -621,7 +621,8 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
stream
->ThenBlasLtMatmul(plan.get(), alpha, *b_ptrs[0], *a_ptrs[0],
beta, c_ptrs[0], &scratch_allocator,
profile_algorithm.get(), &profile_result)
profile_algorithm.get(), {},
&profile_result)
.ok();
VLOG(4) << " Autotune algorithm " << i

View File

@ -107,6 +107,13 @@ enum class ComputationType {
kF32FastBF16, // 32-bit floating-point with reduced (7-bit) mantissa
};
enum class Epilogue {
kDefault = 1, // No special postprocessing
kReLU = 2, // Apply ReLU func point-wise to the results
kBias = 4, // Add broadcasted bias vector to the results
kBiasThenReLU = kBias | kReLU, // Apply bias and then ReLU transform
};
// Converts a ComputationType to a string.
std::string ComputationTypeString(ComputationType ty);
@ -1462,11 +1469,11 @@ class BlasSupport {
std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
blas::DataType ab_type, blas::DataType c_type,
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, int64 lda, int64 ldb, int64 ldc) {
blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc) {
return CreateBlasLtMatmulPlanStridedBatched(
ab_type, c_type, computation_type, pointer_mode, transa, transb, m, n,
k, 1, lda, 0, ldb, 0, ldc, 0);
ab_type, c_type, computation_type, pointer_mode, epilogue, transa,
transb, m, n, k, 1, lda, 0, ldb, 0, ldc, 0);
}
// A more general version of CreateBlasLtMatmulPlan supporting
@ -1475,9 +1482,9 @@ class BlasSupport {
CreateBlasLtMatmulPlanStridedBatched(
blas::DataType ab_type, blas::DataType c_type,
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, int batch_count, int64 lda, int64 stride_a, int64 ldb,
int64 stride_b, int64 ldc, int64 stride_c) = 0;
blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) = 0;
// Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
// returned in the order of increasing estimated compute time according to an
@ -1492,13 +1499,18 @@ class BlasSupport {
// Executes a blaslt matmul operation on the stream. If output_profile_result
// is not nullptr, the operation is profiled, error messages are
// suppressed, and output_profile_result->algorithm() is set to
// algorithm->index().
// algorithm->index(). If epilogue was set to kBias or kBiasThenReLU when
// creating the plan, the bias argument here must refer to a valid device
// vector of length equal to the number of rows in matrix c. If epilogue was
// set to any other value then the bias argument here must be null. The bias
// vector is broadcast across the batch dimension.
virtual bool DoBlasLtMatmul(
Stream* stream, const blas::IBlasLtMatmulPlan* plan,
const HostOrDeviceScalar<int32>& alpha, const DeviceMemory<int8>& a,
const DeviceMemory<int8>& b, const HostOrDeviceScalar<int32>& beta,
DeviceMemory<int32>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
const DeviceMemory<int32>& bias = {},
blas::ProfileResult* output_profile_result = nullptr) = 0;
virtual bool DoBlasLtMatmul(
Stream* stream, const blas::IBlasLtMatmulPlan* plan,
@ -1507,6 +1519,7 @@ class BlasSupport {
const HostOrDeviceScalar<Eigen::half>& beta, DeviceMemory<Eigen::half>* c,
ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
const DeviceMemory<Eigen::half>& bias = {},
blas::ProfileResult* output_profile_result = nullptr) = 0;
virtual bool DoBlasLtMatmul(
Stream* stream, const blas::IBlasLtMatmulPlan* plan,
@ -1514,6 +1527,7 @@ class BlasSupport {
const DeviceMemory<float>& b, const HostOrDeviceScalar<float>& beta,
DeviceMemory<float>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
const DeviceMemory<float>& bias = {},
blas::ProfileResult* output_profile_result = nullptr) = 0;
virtual bool DoBlasLtMatmul(
Stream* stream, const blas::IBlasLtMatmulPlan* plan,
@ -1521,6 +1535,7 @@ class BlasSupport {
const DeviceMemory<double>& b, const HostOrDeviceScalar<double>& beta,
DeviceMemory<double>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
const DeviceMemory<double>& bias = {},
blas::ProfileResult* output_profile_result = nullptr) = 0;
virtual bool DoBlasLtMatmul(
Stream* stream, const blas::IBlasLtMatmulPlan* plan,
@ -1530,6 +1545,7 @@ class BlasSupport {
const HostOrDeviceScalar<std::complex<float>>& beta,
DeviceMemory<std::complex<float>>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
const DeviceMemory<std::complex<float>>& bias = {},
blas::ProfileResult* output_profile_result = nullptr) = 0;
virtual bool DoBlasLtMatmul(
Stream* stream, const blas::IBlasLtMatmulPlan* plan,
@ -1540,6 +1556,7 @@ class BlasSupport {
DeviceMemory<std::complex<double>>* c,
ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
const DeviceMemory<std::complex<double>>& bias = {},
blas::ProfileResult* output_profile_result = nullptr) = 0;
virtual port::Status GetVersion(std::string *version) = 0;
@ -2359,9 +2376,10 @@ class BlasSupport {
CreateBlasLtMatmulPlanStridedBatched( \
blas::DataType ab_type, blas::DataType cd_type, \
blas::ComputationType computation_type, blas::PointerMode pointer_mode, \
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, \
uint64 k, int batch_count, int64 lda, int64 stride_a, int64 ldb, \
int64 stride_b, int64 ldc, int64 stride_c) override; \
blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb, \
uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, \
int64 stride_a, int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) \
override; \
bool GetBlasLtMatmulAlgorithms( \
const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size, \
int max_algorithm_count, \
@ -2373,6 +2391,7 @@ class BlasSupport {
const DeviceMemory<int8>& b, const HostOrDeviceScalar<int32>& beta, \
DeviceMemory<int32>* c, ScratchAllocator* scratch_allocator, \
const blas::IBlasLtMatmulAlgorithm* algorithm, \
const DeviceMemory<int32>& bias = {}, \
blas::ProfileResult* output_profile_result = nullptr) override; \
bool DoBlasLtMatmul( \
Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
@ -2381,21 +2400,24 @@ class BlasSupport {
const HostOrDeviceScalar<Eigen::half>& beta, \
DeviceMemory<Eigen::half>* c, ScratchAllocator* scratch_allocator, \
const blas::IBlasLtMatmulAlgorithm* algorithm, \
blas::ProfileResult* output_profile_result) override; \
const DeviceMemory<Eigen::half>& bias = {}, \
blas::ProfileResult* output_profile_result = nullptr) override; \
bool DoBlasLtMatmul( \
Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
const HostOrDeviceScalar<float>& alpha, const DeviceMemory<float>& a, \
const DeviceMemory<float>& b, const HostOrDeviceScalar<float>& beta, \
DeviceMemory<float>* c, ScratchAllocator* scratch_allocator, \
const blas::IBlasLtMatmulAlgorithm* algorithm, \
blas::ProfileResult* output_profile_result) override; \
const DeviceMemory<float>& bias = {}, \
blas::ProfileResult* output_profile_result = nullptr) override; \
bool DoBlasLtMatmul( \
Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
const HostOrDeviceScalar<double>& alpha, const DeviceMemory<double>& a, \
const DeviceMemory<double>& b, const HostOrDeviceScalar<double>& beta, \
DeviceMemory<double>* c, ScratchAllocator* scratch_allocator, \
const blas::IBlasLtMatmulAlgorithm* algorithm, \
blas::ProfileResult* output_profile_result) override; \
const DeviceMemory<double>& bias = {}, \
blas::ProfileResult* output_profile_result = nullptr) override; \
bool DoBlasLtMatmul(Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
const HostOrDeviceScalar<std::complex<float>>& alpha, \
const DeviceMemory<std::complex<float>>& a, \
@ -2404,7 +2426,9 @@ class BlasSupport {
DeviceMemory<std::complex<float>>* c, \
ScratchAllocator* scratch_allocator, \
const blas::IBlasLtMatmulAlgorithm* algorithm, \
blas::ProfileResult* output_profile_result) override; \
const DeviceMemory<std::complex<float>>& bias = {}, \
blas::ProfileResult* output_profile_result = nullptr) \
override; \
bool DoBlasLtMatmul(Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
const HostOrDeviceScalar<std::complex<double>>& alpha, \
const DeviceMemory<std::complex<double>>& a, \
@ -2413,7 +2437,9 @@ class BlasSupport {
DeviceMemory<std::complex<double>>* c, \
ScratchAllocator* scratch_allocator, \
const blas::IBlasLtMatmulAlgorithm* algorithm, \
blas::ProfileResult* output_profile_result) override; \
const DeviceMemory<std::complex<double>>& bias = {}, \
blas::ProfileResult* output_profile_result = nullptr) \
override; \
port::Status GetVersion(std::string *version) override;
} // namespace blas

View File

@ -468,6 +468,18 @@ cublasLtPointerMode_t CUBLASPointerMode(blas::PointerMode pointer_mode) {
return CUBLASLT_POINTER_MODE_DEVICE;
}
}
cublasLtEpilogue_t CUBLASEpilogue(blas::Epilogue epilogue) {
switch (epilogue) {
case blas::Epilogue::kDefault:
return CUBLASLT_EPILOGUE_DEFAULT;
case blas::Epilogue::kReLU:
return CUBLASLT_EPILOGUE_RELU;
case blas::Epilogue::kBias:
return CUBLASLT_EPILOGUE_BIAS;
case blas::Epilogue::kBiasThenReLU:
return CUBLASLT_EPILOGUE_RELU_BIAS;
}
}
#endif // CUDA_VERSION >= 11000
cudaDataType_t GetCUDADataType(blas::DataType ty) {
@ -3135,12 +3147,12 @@ using UniqueMatmulPreference =
std::unique_ptr<std::remove_pointer<cublasLtMatmulPreference_t>::type,
MatmulPreferenceDestroyer>;
UniqueOpDesc CreateCublasLtOperationDesc(
blas::ComputationType computation_type, blas::DataType scale_type,
blas::PointerMode pointer_mode, blas::Transpose transa,
blas::Transpose transb) {
cublasOperation_t cuda_transa = CUDABlasTranspose(transa);
cublasOperation_t cuda_transb = CUDABlasTranspose(transb);
UniqueOpDesc CreateCublasLtOperationDesc(blas::ComputationType computation_type,
blas::DataType scale_type,
blas::PointerMode pointer_mode,
blas::Epilogue epilogue,
blas::Transpose transa,
blas::Transpose transb) {
cublasLtMatmulDesc_t desc;
cublasComputeType_t cublas_compute_type =
CUBLASComputationType(computation_type);
@ -3154,9 +3166,13 @@ UniqueOpDesc CreateCublasLtOperationDesc(
}
UniqueOpDesc unique_desc(desc);
if (!SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_POINTER_MODE,
CUBLASPointerMode(pointer_mode)) ||
!SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSA, cuda_transa) ||
!SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSB, cuda_transb)) {
CUBLASPointerMode(pointer_mode)) ||
!SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
CUBLASEpilogue(epilogue)) ||
!SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSA,
CUDABlasTranspose(transa)) ||
!SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSB,
CUDABlasTranspose(transb))) {
return nullptr;
}
return unique_desc;
@ -3217,11 +3233,11 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
public:
CUDABlasLtMatmulPlan(blas::DataType ab_type, blas::DataType cd_type,
blas::ComputationType compute_type,
blas::PointerMode pointer_mode, blas::Transpose transa,
blas::Transpose transb, uint64 m, uint64 n, uint64 k,
int batch_count, int64 lda, int64 stride_a, int64 ldb,
int64 stride_b, int64 ldc, int64 stride_c, int64 ldd,
int64 stride_d);
blas::PointerMode pointer_mode, blas::Epilogue epilogue,
blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, int batch_count, int64 lda,
int64 stride_a, int64 ldb, int64 stride_b, int64 ldc,
int64 stride_c, int64 ldd, int64 stride_d);
cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); }
cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); }
@ -3234,12 +3250,17 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
blas::DataType cd_type() const { return cd_type_; }
blas::DataType scale_type() const { return scale_type_; }
blas::PointerMode pointer_mode() const { return pointer_mode_; }
blas::Epilogue epilogue() const { return epilogue_; }
int batch_count() const { return batch_count_; }
int64 stride_a() const { return stride_a_; }
int64 stride_b() const { return stride_b_; }
int64 stride_c() const { return stride_c_; }
int64 stride_d() const { return stride_d_; }
// Note: Must be const to satisfy API. This is always called before the plan
// is executed, so the state change is not observed in subsequent executions.
bool SetBiasPointer(const void* bias) const;
private:
UniqueOpDesc op_desc_;
UniqueLayoutDesc a_desc_;
@ -3250,6 +3271,7 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
blas::DataType cd_type_;
blas::DataType scale_type_;
blas::PointerMode pointer_mode_;
blas::Epilogue epilogue_;
int batch_count_;
int64 stride_a_;
int64 stride_b_;
@ -3260,12 +3282,13 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
CUDABlasLtMatmulPlan::CUDABlasLtMatmulPlan(
blas::DataType ab_type, blas::DataType cd_type,
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, int batch_count, int64 lda, int64 stride_a, int64 ldb,
int64 stride_b, int64 ldc, int64 stride_c, int64 ldd, int64 stride_d)
blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
int64 ldb, int64 stride_b, int64 ldc, int64 stride_c, int64 ldd,
int64 stride_d)
: op_desc_(CreateCublasLtOperationDesc(
computation_type, GetScaleType(cd_type, computation_type),
pointer_mode, transa, transb)),
pointer_mode, epilogue, transa, transb)),
a_desc_(nullptr),
b_desc_(nullptr),
c_desc_(
@ -3276,6 +3299,7 @@ CUDABlasLtMatmulPlan::CUDABlasLtMatmulPlan(
cd_type_(cd_type),
scale_type_(GetScaleType(cd_type, computation_type)),
pointer_mode_(pointer_mode),
epilogue_(epilogue),
batch_count_(batch_count),
stride_a_(stride_a),
stride_b_(stride_b),
@ -3291,6 +3315,11 @@ CUDABlasLtMatmulPlan::CUDABlasLtMatmulPlan(
batch_count);
}
bool CUDABlasLtMatmulPlan::SetBiasPointer(const void* bias) const {
return SetCublasLtAttr(op_desc_.get(), CUBLASLT_MATMUL_DESC_BIAS_POINTER,
bias);
}
class CUDABlasLtMatmulAlgorithm final : public blas::IBlasLtMatmulAlgorithm {
public:
CUDABlasLtMatmulAlgorithm(blas::AlgorithmType index,
@ -3370,13 +3399,14 @@ std::unique_ptr<blas::IBlasLtMatmulPlan>
CUDABlas::CreateBlasLtMatmulPlanStridedBatched(
blas::DataType ab_type, blas::DataType cd_type,
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, int batch_count, int64 lda, int64 stride_a, int64 ldb,
int64 stride_b, int64 ldc, int64 stride_c) {
blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) {
#if CUDA_VERSION >= 11000
auto result = std::make_unique<CUDABlasLtMatmulPlan>(
ab_type, cd_type, computation_type, pointer_mode, transa, transb, m, n, k,
batch_count, lda, stride_a, ldb, stride_b, ldc, stride_c, ldc, stride_c);
ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
transb, m, n, k, batch_count, lda, stride_a, ldb, stride_b, ldc, stride_c,
ldc, stride_c);
if (!result->ok()) {
result.reset();
}
@ -3436,7 +3466,8 @@ bool CUDABlas::DoBlasLtMatmulInternalImpl(
const HostOrDeviceScalar<ScaleType>& alpha, const ABType* a,
const ABType* b, const HostOrDeviceScalar<ScaleType>& beta, const CDType* c,
CDType* d, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm) {
const blas::IBlasLtMatmulAlgorithm* algorithm,
const CDType* bias) {
const auto& cuda_plan = *static_cast<const CUDABlasLtMatmulPlan*>(plan);
const auto& cuda_algo =
*static_cast<const CUDABlasLtMatmulAlgorithm*>(algorithm);
@ -3474,6 +3505,20 @@ bool CUDABlas::DoBlasLtMatmulInternalImpl(
"pointer_mode for the given alpha/beta.";
return false;
}
if ((cuda_plan.epilogue() == blas::Epilogue::kBias ||
cuda_plan.epilogue() == blas::Epilogue::kBiasThenReLU) !=
(bias != nullptr)) {
VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong "
"epilogue for the given bias pointer.";
return false;
}
if (bias != nullptr) {
if (!cuda_plan.SetBiasPointer(bias)) {
VLOG(2) << "DoBlasLtMatmul returning false because setting the bias "
"pointer failed.";
return false;
}
}
const ScaleType* alpha_ptr =
alpha.is_pointer() ? GpuMemory(alpha.pointer()) : &alpha.value();
const ScaleType* beta_ptr =
@ -3525,6 +3570,7 @@ bool CUDABlas::DoBlasLtMatmulInternal(
const DeviceMemory<CDType>& c, DeviceMemory<CDType>* d,
ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
const DeviceMemory<CDType>& bias,
blas::ProfileResult* output_profile_result) {
#if CUDA_VERSION >= 11000
std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
@ -3538,7 +3584,8 @@ bool CUDABlas::DoBlasLtMatmulInternal(
bool err_on_failure = timer != nullptr;
bool result = DoBlasLtMatmulInternalImpl(
stream, err_on_failure, plan, alpha, GpuMemory(a), GpuMemory(b), beta,
GpuMemory(c), GpuMemoryMutable(d), scratch_allocator, algorithm);
GpuMemory(c), GpuMemoryMutable(d), scratch_allocator, algorithm,
GpuMemory(bias));
if (timer && result) {
// GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
@ -3563,9 +3610,10 @@ bool CUDABlas::DoBlasLtMatmul(
const DeviceMemory<int8>& b, const HostOrDeviceScalar<int32>& beta,
DeviceMemory<int32>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
const DeviceMemory<int32>& bias,
blas::ProfileResult* output_profile_result) {
return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
scratch_allocator, algorithm,
scratch_allocator, algorithm, bias,
output_profile_result);
}
@ -3578,6 +3626,7 @@ bool CUDABlas::DoBlasLtMatmul(Stream* stream,
DeviceMemory<Eigen::half>* c,
ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
const DeviceMemory<Eigen::half>& bias,
blas::ProfileResult* output_profile_result) {
#if CUDA_VERSION >= 11000
const auto& cuda_plan = *static_cast<const CUDABlasLtMatmulPlan*>(plan);
@ -3591,11 +3640,11 @@ bool CUDABlas::DoBlasLtMatmul(Stream* stream,
HostOrDeviceScalar<float> float_alpha(static_cast<float>(alpha.value()));
HostOrDeviceScalar<float> float_beta(static_cast<float>(beta.value()));
return DoBlasLtMatmulInternal(stream, plan, float_alpha, a, b, float_beta,
*c, c, scratch_allocator, algorithm,
*c, c, scratch_allocator, algorithm, bias,
output_profile_result);
}
return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
scratch_allocator, algorithm,
scratch_allocator, algorithm, bias,
output_profile_result);
#else // if CUDA_VERSION < 11000
return false;
@ -3608,9 +3657,10 @@ bool CUDABlas::DoBlasLtMatmul(
const DeviceMemory<float>& b, const HostOrDeviceScalar<float>& beta,
DeviceMemory<float>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
const DeviceMemory<float>& bias,
blas::ProfileResult* output_profile_result) {
return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
scratch_allocator, algorithm,
scratch_allocator, algorithm, bias,
output_profile_result);
}
@ -3620,9 +3670,10 @@ bool CUDABlas::DoBlasLtMatmul(
const DeviceMemory<double>& b, const HostOrDeviceScalar<double>& beta,
DeviceMemory<double>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
const DeviceMemory<double>& bias,
blas::ProfileResult* output_profile_result) {
return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
scratch_allocator, algorithm,
scratch_allocator, algorithm, bias,
output_profile_result);
}
@ -3634,9 +3685,10 @@ bool CUDABlas::DoBlasLtMatmul(
const HostOrDeviceScalar<std::complex<float>>& beta,
DeviceMemory<std::complex<float>>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
const DeviceMemory<std::complex<float>>& bias,
blas::ProfileResult* output_profile_result) {
return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
scratch_allocator, algorithm,
scratch_allocator, algorithm, bias,
output_profile_result);
}
@ -3648,9 +3700,10 @@ bool CUDABlas::DoBlasLtMatmul(
const HostOrDeviceScalar<std::complex<double>>& beta,
DeviceMemory<std::complex<double>>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
const DeviceMemory<std::complex<double>>& bias,
blas::ProfileResult* output_profile_result) {
return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
scratch_allocator, algorithm,
scratch_allocator, algorithm, bias,
output_profile_result);
}

View File

@ -148,6 +148,7 @@ class CUDABlas : public blas::BlasSupport {
const DeviceMemory<CDType>& c, DeviceMemory<CDType>* d,
ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
const DeviceMemory<CDType>& bias,
blas::ProfileResult* output_profile_result);
// Helper function for implementing DoBlasLtMatmulInternal.
@ -157,7 +158,7 @@ class CUDABlas : public blas::BlasSupport {
const HostOrDeviceScalar<ScaleType>& alpha, const ABType* a,
const ABType* b, const HostOrDeviceScalar<ScaleType>& beta,
const CDType* c, CDType* d, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm);
const blas::IBlasLtMatmulAlgorithm* algorithm, const CDType* bias);
// Guards the cuBLAS handle for this device.
absl::Mutex mu_;

View File

@ -4809,18 +4809,19 @@ Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
DeviceMemory<int32>* c,
ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
const DeviceMemory<int32>& bias,
blas::ProfileResult* output_profile_result) {
VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
PARAM(c), PARAM(algorithm));
PARAM(c), PARAM(algorithm), PARAM(bias));
ThenBlasWithProfileImpl<
const blas::IBlasLtMatmulPlan*, const HostOrDeviceScalar<int32>&,
const DeviceMemory<int8>&, const DeviceMemory<int8>&,
const HostOrDeviceScalar<int32>&, DeviceMemory<int32>*, ScratchAllocator*,
const blas::IBlasLtMatmulAlgorithm*>
const blas::IBlasLtMatmulAlgorithm*, const DeviceMemory<int32>&>
impl;
return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
c, scratch_allocator, algorithm, output_profile_result);
c, scratch_allocator, algorithm, bias, output_profile_result);
}
Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
@ -4831,18 +4832,20 @@ Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
DeviceMemory<Eigen::half>* c,
ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
const DeviceMemory<Eigen::half>& bias,
blas::ProfileResult* output_profile_result) {
VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
PARAM(c), PARAM(algorithm));
PARAM(c), PARAM(algorithm), PARAM(bias));
ThenBlasWithProfileImpl<
const blas::IBlasLtMatmulPlan*, const HostOrDeviceScalar<Eigen::half>&,
const DeviceMemory<Eigen::half>&, const DeviceMemory<Eigen::half>&,
const HostOrDeviceScalar<Eigen::half>&, DeviceMemory<Eigen::half>*,
ScratchAllocator*, const blas::IBlasLtMatmulAlgorithm*>
ScratchAllocator*, const blas::IBlasLtMatmulAlgorithm*,
const DeviceMemory<Eigen::half>&>
impl;
return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
c, scratch_allocator, algorithm, output_profile_result);
c, scratch_allocator, algorithm, bias, output_profile_result);
}
Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
@ -4853,18 +4856,19 @@ Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
DeviceMemory<float>* c,
ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
const DeviceMemory<float>& bias,
blas::ProfileResult* output_profile_result) {
VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
PARAM(c), PARAM(algorithm));
PARAM(c), PARAM(algorithm), PARAM(bias));
ThenBlasWithProfileImpl<
const blas::IBlasLtMatmulPlan*, const HostOrDeviceScalar<float>&,
const DeviceMemory<float>&, const DeviceMemory<float>&,
const HostOrDeviceScalar<float>&, DeviceMemory<float>*, ScratchAllocator*,
const blas::IBlasLtMatmulAlgorithm*>
const blas::IBlasLtMatmulAlgorithm*, const DeviceMemory<float>&>
impl;
return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
c, scratch_allocator, algorithm, output_profile_result);
c, scratch_allocator, algorithm, bias, output_profile_result);
}
Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
@ -4875,18 +4879,20 @@ Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
DeviceMemory<double>* c,
ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
const DeviceMemory<double>& bias,
blas::ProfileResult* output_profile_result) {
VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
PARAM(c), PARAM(algorithm));
PARAM(c), PARAM(algorithm), PARAM(bias));
ThenBlasWithProfileImpl<
const blas::IBlasLtMatmulPlan*, const HostOrDeviceScalar<double>&,
const DeviceMemory<double>&, const DeviceMemory<double>&,
const HostOrDeviceScalar<double>&, DeviceMemory<double>*,
ScratchAllocator*, const blas::IBlasLtMatmulAlgorithm*>
ScratchAllocator*, const blas::IBlasLtMatmulAlgorithm*,
const DeviceMemory<double>&>
impl;
return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
c, scratch_allocator, algorithm, output_profile_result);
c, scratch_allocator, algorithm, bias, output_profile_result);
}
Stream& Stream::ThenBlasLtMatmul(
@ -4897,9 +4903,10 @@ Stream& Stream::ThenBlasLtMatmul(
const HostOrDeviceScalar<std::complex<float>>& beta,
DeviceMemory<std::complex<float>>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
const DeviceMemory<std::complex<float>>& bias,
blas::ProfileResult* output_profile_result) {
VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
PARAM(c), PARAM(algorithm));
PARAM(c), PARAM(algorithm), PARAM(bias));
ThenBlasWithProfileImpl<const blas::IBlasLtMatmulPlan*,
const HostOrDeviceScalar<std::complex<float>>&,
@ -4907,10 +4914,11 @@ Stream& Stream::ThenBlasLtMatmul(
const DeviceMemory<std::complex<float>>&,
const HostOrDeviceScalar<std::complex<float>>&,
DeviceMemory<std::complex<float>>*, ScratchAllocator*,
const blas::IBlasLtMatmulAlgorithm*>
const blas::IBlasLtMatmulAlgorithm*,
const DeviceMemory<std::complex<float>>&>
impl;
return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
c, scratch_allocator, algorithm, output_profile_result);
c, scratch_allocator, algorithm, bias, output_profile_result);
}
Stream& Stream::ThenBlasLtMatmul(
@ -4921,9 +4929,10 @@ Stream& Stream::ThenBlasLtMatmul(
const HostOrDeviceScalar<std::complex<double>>& beta,
DeviceMemory<std::complex<double>>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
const DeviceMemory<std::complex<double>>& bias,
blas::ProfileResult* output_profile_result) {
VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
PARAM(c), PARAM(algorithm));
PARAM(c), PARAM(algorithm), PARAM(bias));
ThenBlasWithProfileImpl<const blas::IBlasLtMatmulPlan*,
const HostOrDeviceScalar<std::complex<double>>&,
@ -4932,10 +4941,11 @@ Stream& Stream::ThenBlasLtMatmul(
const HostOrDeviceScalar<std::complex<double>>&,
DeviceMemory<std::complex<double>>*,
ScratchAllocator*,
const blas::IBlasLtMatmulAlgorithm*>
const blas::IBlasLtMatmulAlgorithm*,
const DeviceMemory<std::complex<double>>&>
impl;
return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
c, scratch_allocator, algorithm, output_profile_result);
c, scratch_allocator, algorithm, bias, output_profile_result);
}
Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {

View File

@ -1672,6 +1672,7 @@ class Stream {
const DeviceMemory<int8>& b, const HostOrDeviceScalar<int32>& beta,
DeviceMemory<int32>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
const DeviceMemory<int32>& bias = {},
blas::ProfileResult* output_profile_result = nullptr);
Stream& ThenBlasLtMatmul(
const blas::IBlasLtMatmulPlan* plan,
@ -1680,6 +1681,7 @@ class Stream {
const HostOrDeviceScalar<Eigen::half>& beta, DeviceMemory<Eigen::half>* c,
ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
const DeviceMemory<Eigen::half>& bias = {},
blas::ProfileResult* output_profile_result = nullptr);
Stream& ThenBlasLtMatmul(
const blas::IBlasLtMatmulPlan* plan,
@ -1687,6 +1689,7 @@ class Stream {
const DeviceMemory<float>& b, const HostOrDeviceScalar<float>& beta,
DeviceMemory<float>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
const DeviceMemory<float>& bias = {},
blas::ProfileResult* output_profile_result = nullptr);
Stream& ThenBlasLtMatmul(
const blas::IBlasLtMatmulPlan* plan,
@ -1694,6 +1697,7 @@ class Stream {
const DeviceMemory<double>& b, const HostOrDeviceScalar<double>& beta,
DeviceMemory<double>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
const DeviceMemory<double>& bias = {},
blas::ProfileResult* output_profile_result = nullptr);
Stream& ThenBlasLtMatmul(
const blas::IBlasLtMatmulPlan* plan,
@ -1703,6 +1707,7 @@ class Stream {
const HostOrDeviceScalar<std::complex<float>>& beta,
DeviceMemory<std::complex<float>>* c, ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
const DeviceMemory<std::complex<float>>& bias = {},
blas::ProfileResult* output_profile_result = nullptr);
Stream& ThenBlasLtMatmul(
const blas::IBlasLtMatmulPlan* plan,
@ -1713,6 +1718,7 @@ class Stream {
DeviceMemory<std::complex<double>>* c,
ScratchAllocator* scratch_allocator,
const blas::IBlasLtMatmulAlgorithm* algorithm,
const DeviceMemory<std::complex<double>>& bias = {},
blas::ProfileResult* output_profile_result = nullptr);
// See FftSupport::DoFft.

View File

@ -339,31 +339,32 @@ bool StreamExecutor::GetBlasGemmAlgorithms(
std::unique_ptr<blas::IBlasLtMatmulPlan> StreamExecutor::CreateBlasLtMatmulPlan(
blas::DataType ab_type, blas::DataType cd_type,
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, int64 lda, int64 ldb, int64 ldc) {
blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc) {
blas::BlasSupport *blas_support = AsBlas();
if (!blas_support) {
return nullptr;
}
return blas_support->CreateBlasLtMatmulPlan(
ab_type, cd_type, computation_type, pointer_mode, transa, transb, m, n, k,
lda, ldb, ldc);
ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
transb, m, n, k, lda, ldb, ldc);
}
std::unique_ptr<blas::IBlasLtMatmulPlan>
StreamExecutor::CreateBlasLtMatmulPlanStridedBatched(
blas::DataType ab_type, blas::DataType cd_type,
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, uint64 batch_count, int64 lda, int64 stride_a, int64 ldb,
int64 stride_b, int64 ldc, int64 stride_c) {
blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
uint64 m, uint64 n, uint64 k, uint64 batch_count, int64 lda, int64 stride_a,
int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) {
blas::BlasSupport *blas_support = AsBlas();
if (!blas_support) {
return nullptr;
}
return blas_support->CreateBlasLtMatmulPlanStridedBatched(
ab_type, cd_type, computation_type, pointer_mode, transa, transb, m, n, k,
batch_count, lda, stride_a, ldb, stride_b, ldc, stride_c);
ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
transb, m, n, k, batch_count, lda, stride_a, ldb, stride_b, ldc,
stride_c);
}
bool StreamExecutor::GetBlasLtMatmulAlgorithms(

View File

@ -401,17 +401,17 @@ class StreamExecutor {
std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
blas::DataType ab_type, blas::DataType cd_type,
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, int64 lda, int64 ldb, int64 ldc);
blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc);
// A more general version of CreateBlasLtMatmulPlan supporting
// batched operations.
std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlanStridedBatched(
blas::DataType ab_type, blas::DataType cd_type,
blas::ComputationType computation_type, blas::PointerMode pointer_mode,
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, uint64 batch_count, int64 lda, int64 stride_a, int64 ldb,
int64 stride_b, int64 ldc, int64 stride_c);
blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
uint64 m, uint64 n, uint64 k, uint64 batch_count, int64 lda,
int64 stride_a, int64 ldb, int64 stride_b, int64 ldc, int64 stride_c);
// Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
// returned in the order of increasing estimated compute time according to an