Add support for blasLt epilogue fn and bias vector
- Changes the backend APIs to allow an epilogue function (default, ReLU, bias, or bias then ReLU) to be specified and a bias vector to be provided.
- This is expected to be useful for XLA to perform fusions.
- This functionality is not currently tested, because the BatchMatMulOp does not expose relu/bias fusion.
parent 8c0eb4b35b
commit 39bf03f083
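Not part of the commit, but a minimal caller-side sketch of the new parameters, modelled on the BatchMatMul call sites below (the plan/bias variable names and the kBiasThenReLU choice are illustrative assumptions, not code from this change):

// Hypothetical usage: request a bias-then-ReLU epilogue when building the plan
// (same argument order as in the first hunk below, with the new epilogue
// argument inserted after the pointer mode) ...
auto plan = stream->parent()->CreateBlasLtMatmulPlanStridedBatched(
    /*ab_type=*/blas_dtype,
    /*cd_type=*/blas_dtype, computation_type,
    se::blas::PointerMode::kHost, se::blas::Epilogue::kBiasThenReLU,
    blas_transpose_b, blas_transpose_a, n, m, k, batch_size,
    /*lda=*/in_y.dim_size(2), b_stride,
    /*ldb=*/in_x.dim_size(2), a_stride, /*ldc=*/n, c_stride);

// ... and pass the bias vector (one element per row of the output matrix,
// broadcast across the batch) in the new slot between the algorithm and the
// profile result.
stream->ThenBlasLtMatmul(plan.get(), alpha, *b_ptrs[0], *a_ptrs[0], beta,
                         c_ptrs[0], &scratch_allocator,
                         profile_algorithm.get(), bias_vector);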
@@ -558,8 +558,8 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
           stream->parent()->CreateBlasLtMatmulPlanStridedBatched(
               /*ab_type=*/blas_dtype,
               /*cd_type=*/blas_dtype, computation_type,
-              se::blas::PointerMode::kHost, blas_transpose_b,
-              blas_transpose_a, n, m, k, batch_size,
+              se::blas::PointerMode::kHost, se::blas::Epilogue::kDefault,
+              blas_transpose_b, blas_transpose_a, n, m, k, batch_size,
               /*lda=*/in_y.dim_size(2), b_stride,
               /*ldb=*/in_x.dim_size(2), a_stride, /*ldc=*/n, c_stride);
       OP_REQUIRES(
@@ -621,7 +621,8 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
               stream
                   ->ThenBlasLtMatmul(plan.get(), alpha, *b_ptrs[0], *a_ptrs[0],
                                      beta, c_ptrs[0], &scratch_allocator,
-                                     profile_algorithm.get(), &profile_result)
+                                     profile_algorithm.get(), {},
+                                     &profile_result)
                   .ok();
 
           VLOG(4) << " Autotune algorithm " << i
@@ -107,6 +107,13 @@ enum class ComputationType {
   kF32FastBF16,  // 32-bit floating-point with reduced (7-bit) mantissa
 };
 
+enum class Epilogue {
+  kDefault = 1,                   // No special postprocessing
+  kReLU = 2,                      // Apply ReLU func point-wise to the results
+  kBias = 4,                      // Add broadcasted bias vector to the results
+  kBiasThenReLU = kBias | kReLU,  // Apply bias and then ReLU transform
+};
+
 // Converts a ComputationType to a string.
 std::string ComputationTypeString(ComputationType ty);
 
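The enumerator values are chosen so that kBiasThenReLU is the bitwise OR of kBias and kReLU. An illustrative helper (not part of this change) could test for the bias bit like this:

// Hypothetical helper: true for kBias and kBiasThenReLU, false otherwise.
inline bool EpilogueAddsBias(Epilogue e) {
  return (static_cast<int>(e) & static_cast<int>(Epilogue::kBias)) != 0;
}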
@@ -1462,11 +1469,11 @@ class BlasSupport {
   std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
       blas::DataType ab_type, blas::DataType c_type,
       blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
-      uint64 k, int64 lda, int64 ldb, int64 ldc) {
+      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
+      uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc) {
     return CreateBlasLtMatmulPlanStridedBatched(
-        ab_type, c_type, computation_type, pointer_mode, transa, transb, m, n,
-        k, 1, lda, 0, ldb, 0, ldc, 0);
+        ab_type, c_type, computation_type, pointer_mode, epilogue, transa,
+        transb, m, n, k, 1, lda, 0, ldb, 0, ldc, 0);
   }
 
   // A more general version of CreateBlasLtMatmulPlan supporting
@@ -1475,9 +1482,9 @@ class BlasSupport {
   CreateBlasLtMatmulPlanStridedBatched(
       blas::DataType ab_type, blas::DataType c_type,
       blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
-      uint64 k, int batch_count, int64 lda, int64 stride_a, int64 ldb,
-      int64 stride_b, int64 ldc, int64 stride_c) = 0;
+      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
+      uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
+      int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) = 0;
 
   // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
   // returned in the order of increasing estimated compute time according to an
@@ -1492,13 +1499,18 @@ class BlasSupport {
   // Executes a blaslt matmul operation on the stream. If output_profile_result
   // is not nullptr, the operation is profiled, error messages are
   // suppressed, and output_profile_result->algorithm() is set to
-  // algorithm->index().
+  // algorithm->index(). If epilogue was set to kBias or kBiasThenReLU when
+  // creating the plan, the bias argument here must refer to a valid device
+  // vector of length equal to the number of rows in matrix c. If epilogue was
+  // set to any other value then the bias argument here must be null. The bias
+  // vector is broadcast across the batch dimension.
   virtual bool DoBlasLtMatmul(
       Stream* stream, const blas::IBlasLtMatmulPlan* plan,
      const HostOrDeviceScalar<int32>& alpha, const DeviceMemory<int8>& a,
       const DeviceMemory<int8>& b, const HostOrDeviceScalar<int32>& beta,
       DeviceMemory<int32>* c, ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<int32>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr) = 0;
   virtual bool DoBlasLtMatmul(
       Stream* stream, const blas::IBlasLtMatmulPlan* plan,
@@ -1507,6 +1519,7 @@ class BlasSupport {
       const HostOrDeviceScalar<Eigen::half>& beta, DeviceMemory<Eigen::half>* c,
       ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<Eigen::half>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr) = 0;
   virtual bool DoBlasLtMatmul(
       Stream* stream, const blas::IBlasLtMatmulPlan* plan,
@@ -1514,6 +1527,7 @@ class BlasSupport {
       const DeviceMemory<float>& b, const HostOrDeviceScalar<float>& beta,
       DeviceMemory<float>* c, ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<float>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr) = 0;
   virtual bool DoBlasLtMatmul(
       Stream* stream, const blas::IBlasLtMatmulPlan* plan,
@@ -1521,6 +1535,7 @@ class BlasSupport {
       const DeviceMemory<double>& b, const HostOrDeviceScalar<double>& beta,
       DeviceMemory<double>* c, ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<double>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr) = 0;
   virtual bool DoBlasLtMatmul(
       Stream* stream, const blas::IBlasLtMatmulPlan* plan,
@@ -1530,6 +1545,7 @@ class BlasSupport {
       const HostOrDeviceScalar<std::complex<float>>& beta,
       DeviceMemory<std::complex<float>>* c, ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<std::complex<float>>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr) = 0;
   virtual bool DoBlasLtMatmul(
       Stream* stream, const blas::IBlasLtMatmulPlan* plan,
@@ -1540,6 +1556,7 @@ class BlasSupport {
       DeviceMemory<std::complex<double>>* c,
       ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<std::complex<double>>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr) = 0;
 
   virtual port::Status GetVersion(std::string *version) = 0;
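Because the new bias parameter defaults to an empty DeviceMemory handle, call sites that stop at the algorithm argument keep compiling and get the old behaviour; only callers that passed a profile result positionally after the algorithm (as in the BatchMatMul autotune hunk above) need updating. An illustrative call without a bias, not taken from the change:

// Hypothetical no-bias call: valid as long as the plan was created with
// Epilogue::kDefault or Epilogue::kReLU, since bias defaults to {}.
blas->DoBlasLtMatmul(stream, plan, alpha, a, b, beta, &c, &scratch_allocator,
                     algorithm);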
@@ -2359,9 +2376,10 @@ class BlasSupport {
   CreateBlasLtMatmulPlanStridedBatched( \
       blas::DataType ab_type, blas::DataType cd_type, \
       blas::ComputationType computation_type, blas::PointerMode pointer_mode, \
-      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, \
-      uint64 k, int batch_count, int64 lda, int64 stride_a, int64 ldb, \
-      int64 stride_b, int64 ldc, int64 stride_c) override; \
+      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb, \
+      uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, \
+      int64 stride_a, int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) \
+      override; \
   bool GetBlasLtMatmulAlgorithms( \
       const blas::IBlasLtMatmulPlan* plan, size_t max_workspace_size, \
       int max_algorithm_count, \
@@ -2373,6 +2391,7 @@ class BlasSupport {
       const DeviceMemory<int8>& b, const HostOrDeviceScalar<int32>& beta, \
       DeviceMemory<int32>* c, ScratchAllocator* scratch_allocator, \
       const blas::IBlasLtMatmulAlgorithm* algorithm, \
+      const DeviceMemory<int32>& bias = {}, \
       blas::ProfileResult* output_profile_result = nullptr) override; \
   bool DoBlasLtMatmul( \
       Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
@@ -2381,21 +2400,24 @@ class BlasSupport {
       const HostOrDeviceScalar<Eigen::half>& beta, \
       DeviceMemory<Eigen::half>* c, ScratchAllocator* scratch_allocator, \
       const blas::IBlasLtMatmulAlgorithm* algorithm, \
-      blas::ProfileResult* output_profile_result) override; \
+      const DeviceMemory<Eigen::half>& bias = {}, \
+      blas::ProfileResult* output_profile_result = nullptr) override; \
   bool DoBlasLtMatmul( \
       Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
       const HostOrDeviceScalar<float>& alpha, const DeviceMemory<float>& a, \
       const DeviceMemory<float>& b, const HostOrDeviceScalar<float>& beta, \
       DeviceMemory<float>* c, ScratchAllocator* scratch_allocator, \
       const blas::IBlasLtMatmulAlgorithm* algorithm, \
-      blas::ProfileResult* output_profile_result) override; \
+      const DeviceMemory<float>& bias = {}, \
+      blas::ProfileResult* output_profile_result = nullptr) override; \
   bool DoBlasLtMatmul( \
       Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
       const HostOrDeviceScalar<double>& alpha, const DeviceMemory<double>& a, \
       const DeviceMemory<double>& b, const HostOrDeviceScalar<double>& beta, \
       DeviceMemory<double>* c, ScratchAllocator* scratch_allocator, \
       const blas::IBlasLtMatmulAlgorithm* algorithm, \
-      blas::ProfileResult* output_profile_result) override; \
+      const DeviceMemory<double>& bias = {}, \
+      blas::ProfileResult* output_profile_result = nullptr) override; \
   bool DoBlasLtMatmul(Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
                       const HostOrDeviceScalar<std::complex<float>>& alpha, \
                       const DeviceMemory<std::complex<float>>& a, \
@@ -2404,7 +2426,9 @@ class BlasSupport {
                       DeviceMemory<std::complex<float>>* c, \
                       ScratchAllocator* scratch_allocator, \
                       const blas::IBlasLtMatmulAlgorithm* algorithm, \
-                      blas::ProfileResult* output_profile_result) override; \
+                      const DeviceMemory<std::complex<float>>& bias = {}, \
+                      blas::ProfileResult* output_profile_result = nullptr) \
+      override; \
   bool DoBlasLtMatmul(Stream* stream, const blas::IBlasLtMatmulPlan* plan, \
                       const HostOrDeviceScalar<std::complex<double>>& alpha, \
                       const DeviceMemory<std::complex<double>>& a, \
@@ -2413,7 +2437,9 @@ class BlasSupport {
                       DeviceMemory<std::complex<double>>* c, \
                       ScratchAllocator* scratch_allocator, \
                       const blas::IBlasLtMatmulAlgorithm* algorithm, \
-                      blas::ProfileResult* output_profile_result) override; \
+                      const DeviceMemory<std::complex<double>>& bias = {}, \
+                      blas::ProfileResult* output_profile_result = nullptr) \
+      override; \
   port::Status GetVersion(std::string *version) override;
 
 }  // namespace blas
@@ -468,6 +468,18 @@ cublasLtPointerMode_t CUBLASPointerMode(blas::PointerMode pointer_mode) {
       return CUBLASLT_POINTER_MODE_DEVICE;
   }
 }
+cublasLtEpilogue_t CUBLASEpilogue(blas::Epilogue epilogue) {
+  switch (epilogue) {
+    case blas::Epilogue::kDefault:
+      return CUBLASLT_EPILOGUE_DEFAULT;
+    case blas::Epilogue::kReLU:
+      return CUBLASLT_EPILOGUE_RELU;
+    case blas::Epilogue::kBias:
+      return CUBLASLT_EPILOGUE_BIAS;
+    case blas::Epilogue::kBiasThenReLU:
+      return CUBLASLT_EPILOGUE_RELU_BIAS;
+  }
+}
 #endif  // CUDA_VERSION >= 11000
 
 cudaDataType_t GetCUDADataType(blas::DataType ty) {
@@ -3135,12 +3147,12 @@ using UniqueMatmulPreference =
     std::unique_ptr<std::remove_pointer<cublasLtMatmulPreference_t>::type,
                     MatmulPreferenceDestroyer>;
 
-UniqueOpDesc CreateCublasLtOperationDesc(
-    blas::ComputationType computation_type, blas::DataType scale_type,
-    blas::PointerMode pointer_mode, blas::Transpose transa,
-    blas::Transpose transb) {
-  cublasOperation_t cuda_transa = CUDABlasTranspose(transa);
-  cublasOperation_t cuda_transb = CUDABlasTranspose(transb);
+UniqueOpDesc CreateCublasLtOperationDesc(blas::ComputationType computation_type,
+                                         blas::DataType scale_type,
+                                         blas::PointerMode pointer_mode,
+                                         blas::Epilogue epilogue,
+                                         blas::Transpose transa,
+                                         blas::Transpose transb) {
   cublasLtMatmulDesc_t desc;
   cublasComputeType_t cublas_compute_type =
       CUBLASComputationType(computation_type);
@@ -3154,9 +3166,13 @@ UniqueOpDesc CreateCublasLtOperationDesc(
   }
   UniqueOpDesc unique_desc(desc);
   if (!SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_POINTER_MODE,
-                       CUBLASPointerMode(pointer_mode)) ||
-      !SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSA, cuda_transa) ||
-      !SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSB, cuda_transb)) {
+                       CUBLASPointerMode(pointer_mode)) ||
+      !SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
+                       CUBLASEpilogue(epilogue)) ||
+      !SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSA,
+                       CUDABlasTranspose(transa)) ||
+      !SetCublasLtAttr(desc, CUBLASLT_MATMUL_DESC_TRANSB,
+                       CUDABlasTranspose(transb))) {
     return nullptr;
   }
   return unique_desc;
@@ -3217,11 +3233,11 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
  public:
   CUDABlasLtMatmulPlan(blas::DataType ab_type, blas::DataType cd_type,
                        blas::ComputationType compute_type,
-                       blas::PointerMode pointer_mode, blas::Transpose transa,
-                       blas::Transpose transb, uint64 m, uint64 n, uint64 k,
-                       int batch_count, int64 lda, int64 stride_a, int64 ldb,
-                       int64 stride_b, int64 ldc, int64 stride_c, int64 ldd,
-                       int64 stride_d);
+                       blas::PointerMode pointer_mode, blas::Epilogue epilogue,
+                       blas::Transpose transa, blas::Transpose transb, uint64 m,
+                       uint64 n, uint64 k, int batch_count, int64 lda,
+                       int64 stride_a, int64 ldb, int64 stride_b, int64 ldc,
+                       int64 stride_c, int64 ldd, int64 stride_d);
 
   cublasLtMatmulDesc_t op_desc() const { return op_desc_.get(); }
   cublasLtMatrixLayout_t a_desc() const { return a_desc_.get(); }
@@ -3234,12 +3250,17 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
   blas::DataType cd_type() const { return cd_type_; }
   blas::DataType scale_type() const { return scale_type_; }
   blas::PointerMode pointer_mode() const { return pointer_mode_; }
+  blas::Epilogue epilogue() const { return epilogue_; }
   int batch_count() const { return batch_count_; }
   int64 stride_a() const { return stride_a_; }
   int64 stride_b() const { return stride_b_; }
   int64 stride_c() const { return stride_c_; }
   int64 stride_d() const { return stride_d_; }
 
+  // Note: Must be const to satisfy API. This is always called before the plan
+  // is executed, so the state change is not observed in subsequent executions.
+  bool SetBiasPointer(const void* bias) const;
+
  private:
   UniqueOpDesc op_desc_;
   UniqueLayoutDesc a_desc_;
@@ -3250,6 +3271,7 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
   blas::DataType cd_type_;
   blas::DataType scale_type_;
   blas::PointerMode pointer_mode_;
+  blas::Epilogue epilogue_;
   int batch_count_;
   int64 stride_a_;
   int64 stride_b_;
@@ -3260,12 +3282,13 @@ class CUDABlasLtMatmulPlan final : public blas::IBlasLtMatmulPlan {
 CUDABlasLtMatmulPlan::CUDABlasLtMatmulPlan(
     blas::DataType ab_type, blas::DataType cd_type,
     blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
-    uint64 k, int batch_count, int64 lda, int64 stride_a, int64 ldb,
-    int64 stride_b, int64 ldc, int64 stride_c, int64 ldd, int64 stride_d)
+    blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
+    uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
+    int64 ldb, int64 stride_b, int64 ldc, int64 stride_c, int64 ldd,
+    int64 stride_d)
     : op_desc_(CreateCublasLtOperationDesc(
           computation_type, GetScaleType(cd_type, computation_type),
-          pointer_mode, transa, transb)),
+          pointer_mode, epilogue, transa, transb)),
       a_desc_(nullptr),
       b_desc_(nullptr),
       c_desc_(
@@ -3276,6 +3299,7 @@ CUDABlasLtMatmulPlan::CUDABlasLtMatmulPlan(
       cd_type_(cd_type),
       scale_type_(GetScaleType(cd_type, computation_type)),
       pointer_mode_(pointer_mode),
+      epilogue_(epilogue),
       batch_count_(batch_count),
       stride_a_(stride_a),
       stride_b_(stride_b),
@@ -3291,6 +3315,11 @@ CUDABlasLtMatmulPlan::CUDABlasLtMatmulPlan(
                             batch_count);
 }
 
+bool CUDABlasLtMatmulPlan::SetBiasPointer(const void* bias) const {
+  return SetCublasLtAttr(op_desc_.get(), CUBLASLT_MATMUL_DESC_BIAS_POINTER,
+                         bias);
+}
+
 class CUDABlasLtMatmulAlgorithm final : public blas::IBlasLtMatmulAlgorithm {
  public:
   CUDABlasLtMatmulAlgorithm(blas::AlgorithmType index,
@@ -3370,13 +3399,14 @@ std::unique_ptr<blas::IBlasLtMatmulPlan>
 CUDABlas::CreateBlasLtMatmulPlanStridedBatched(
     blas::DataType ab_type, blas::DataType cd_type,
     blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
-    uint64 k, int batch_count, int64 lda, int64 stride_a, int64 ldb,
-    int64 stride_b, int64 ldc, int64 stride_c) {
+    blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
+    uint64 m, uint64 n, uint64 k, int batch_count, int64 lda, int64 stride_a,
+    int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) {
 #if CUDA_VERSION >= 11000
   auto result = std::make_unique<CUDABlasLtMatmulPlan>(
-      ab_type, cd_type, computation_type, pointer_mode, transa, transb, m, n, k,
-      batch_count, lda, stride_a, ldb, stride_b, ldc, stride_c, ldc, stride_c);
+      ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
+      transb, m, n, k, batch_count, lda, stride_a, ldb, stride_b, ldc, stride_c,
+      ldc, stride_c);
   if (!result->ok()) {
     result.reset();
   }
@@ -3436,7 +3466,8 @@ bool CUDABlas::DoBlasLtMatmulInternalImpl(
     const HostOrDeviceScalar<ScaleType>& alpha, const ABType* a,
     const ABType* b, const HostOrDeviceScalar<ScaleType>& beta, const CDType* c,
     CDType* d, ScratchAllocator* scratch_allocator,
-    const blas::IBlasLtMatmulAlgorithm* algorithm) {
+    const blas::IBlasLtMatmulAlgorithm* algorithm,
+    const CDType* bias) {
   const auto& cuda_plan = *static_cast<const CUDABlasLtMatmulPlan*>(plan);
   const auto& cuda_algo =
       *static_cast<const CUDABlasLtMatmulAlgorithm*>(algorithm);
@@ -3474,6 +3505,20 @@ bool CUDABlas::DoBlasLtMatmulInternalImpl(
                "pointer_mode for the given alpha/beta.";
     return false;
   }
+  if ((cuda_plan.epilogue() == blas::Epilogue::kBias ||
+       cuda_plan.epilogue() == blas::Epilogue::kBiasThenReLU) !=
+      (bias != nullptr)) {
+    VLOG(2) << "DoBlasLtMatmul returning false because plan has wrong "
+               "epilogue for the given bias pointer.";
+    return false;
+  }
+  if (bias != nullptr) {
+    if (!cuda_plan.SetBiasPointer(bias)) {
+      VLOG(2) << "DoBlasLtMatmul returning false because setting the bias "
+                 "pointer failed.";
+      return false;
+    }
+  }
   const ScaleType* alpha_ptr =
       alpha.is_pointer() ? GpuMemory(alpha.pointer()) : &alpha.value();
   const ScaleType* beta_ptr =
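In other words, the implementation treats a mismatch between the plan's epilogue and the bias argument as an error in either direction: a null bias with kBias/kBiasThenReLU, or a non-null bias with kDefault/kReLU, makes DoBlasLtMatmul return false rather than silently ignoring the bias.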
@@ -3525,6 +3570,7 @@ bool CUDABlas::DoBlasLtMatmulInternal(
     const DeviceMemory<CDType>& c, DeviceMemory<CDType>* d,
     ScratchAllocator* scratch_allocator,
     const blas::IBlasLtMatmulAlgorithm* algorithm,
+    const DeviceMemory<CDType>& bias,
     blas::ProfileResult* output_profile_result) {
 #if CUDA_VERSION >= 11000
   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
@@ -3538,7 +3584,8 @@ bool CUDABlas::DoBlasLtMatmulInternal(
   bool err_on_failure = timer != nullptr;
   bool result = DoBlasLtMatmulInternalImpl(
       stream, err_on_failure, plan, alpha, GpuMemory(a), GpuMemory(b), beta,
-      GpuMemory(c), GpuMemoryMutable(d), scratch_allocator, algorithm);
+      GpuMemory(c), GpuMemoryMutable(d), scratch_allocator, algorithm,
+      GpuMemory(bias));
 
   if (timer && result) {
     // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
@@ -3563,9 +3610,10 @@ bool CUDABlas::DoBlasLtMatmul(
     const DeviceMemory<int8>& b, const HostOrDeviceScalar<int32>& beta,
     DeviceMemory<int32>* c, ScratchAllocator* scratch_allocator,
     const blas::IBlasLtMatmulAlgorithm* algorithm,
+    const DeviceMemory<int32>& bias,
     blas::ProfileResult* output_profile_result) {
   return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
-                                scratch_allocator, algorithm,
+                                scratch_allocator, algorithm, bias,
                                 output_profile_result);
 }
 
@@ -3578,6 +3626,7 @@ bool CUDABlas::DoBlasLtMatmul(Stream* stream,
                               DeviceMemory<Eigen::half>* c,
                               ScratchAllocator* scratch_allocator,
                               const blas::IBlasLtMatmulAlgorithm* algorithm,
+                              const DeviceMemory<Eigen::half>& bias,
                               blas::ProfileResult* output_profile_result) {
 #if CUDA_VERSION >= 11000
   const auto& cuda_plan = *static_cast<const CUDABlasLtMatmulPlan*>(plan);
@@ -3591,11 +3640,11 @@ bool CUDABlas::DoBlasLtMatmul(Stream* stream,
     HostOrDeviceScalar<float> float_alpha(static_cast<float>(alpha.value()));
     HostOrDeviceScalar<float> float_beta(static_cast<float>(beta.value()));
     return DoBlasLtMatmulInternal(stream, plan, float_alpha, a, b, float_beta,
-                                  *c, c, scratch_allocator, algorithm,
+                                  *c, c, scratch_allocator, algorithm, bias,
                                   output_profile_result);
   }
   return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
-                                scratch_allocator, algorithm,
+                                scratch_allocator, algorithm, bias,
                                 output_profile_result);
 #else  // if CUDA_VERSION < 11000
   return false;
@@ -3608,9 +3657,10 @@ bool CUDABlas::DoBlasLtMatmul(
     const DeviceMemory<float>& b, const HostOrDeviceScalar<float>& beta,
     DeviceMemory<float>* c, ScratchAllocator* scratch_allocator,
     const blas::IBlasLtMatmulAlgorithm* algorithm,
+    const DeviceMemory<float>& bias,
     blas::ProfileResult* output_profile_result) {
   return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
-                                scratch_allocator, algorithm,
+                                scratch_allocator, algorithm, bias,
                                 output_profile_result);
 }
 
@@ -3620,9 +3670,10 @@ bool CUDABlas::DoBlasLtMatmul(
     const DeviceMemory<double>& b, const HostOrDeviceScalar<double>& beta,
     DeviceMemory<double>* c, ScratchAllocator* scratch_allocator,
     const blas::IBlasLtMatmulAlgorithm* algorithm,
+    const DeviceMemory<double>& bias,
     blas::ProfileResult* output_profile_result) {
   return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
-                                scratch_allocator, algorithm,
+                                scratch_allocator, algorithm, bias,
                                 output_profile_result);
 }
 
@@ -3634,9 +3685,10 @@ bool CUDABlas::DoBlasLtMatmul(
     const HostOrDeviceScalar<std::complex<float>>& beta,
     DeviceMemory<std::complex<float>>* c, ScratchAllocator* scratch_allocator,
     const blas::IBlasLtMatmulAlgorithm* algorithm,
+    const DeviceMemory<std::complex<float>>& bias,
     blas::ProfileResult* output_profile_result) {
   return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
-                                scratch_allocator, algorithm,
+                                scratch_allocator, algorithm, bias,
                                 output_profile_result);
 }
 
@@ -3648,9 +3700,10 @@ bool CUDABlas::DoBlasLtMatmul(
     const HostOrDeviceScalar<std::complex<double>>& beta,
     DeviceMemory<std::complex<double>>* c, ScratchAllocator* scratch_allocator,
     const blas::IBlasLtMatmulAlgorithm* algorithm,
+    const DeviceMemory<std::complex<double>>& bias,
     blas::ProfileResult* output_profile_result) {
   return DoBlasLtMatmulInternal(stream, plan, alpha, a, b, beta, *c, c,
-                                scratch_allocator, algorithm,
+                                scratch_allocator, algorithm, bias,
                                 output_profile_result);
 }
 
@@ -148,6 +148,7 @@ class CUDABlas : public blas::BlasSupport {
       const DeviceMemory<CDType>& c, DeviceMemory<CDType>* d,
       ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<CDType>& bias,
       blas::ProfileResult* output_profile_result);
 
   // Helper function for implementing DoBlasLtMatmulInternal.
@@ -157,7 +158,7 @@ class CUDABlas : public blas::BlasSupport {
       const HostOrDeviceScalar<ScaleType>& alpha, const ABType* a,
       const ABType* b, const HostOrDeviceScalar<ScaleType>& beta,
       const CDType* c, CDType* d, ScratchAllocator* scratch_allocator,
-      const blas::IBlasLtMatmulAlgorithm* algorithm);
+      const blas::IBlasLtMatmulAlgorithm* algorithm, const CDType* bias);
 
   // Guards the cuBLAS handle for this device.
   absl::Mutex mu_;
@@ -4809,18 +4809,19 @@ Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
                                  DeviceMemory<int32>* c,
                                  ScratchAllocator* scratch_allocator,
                                  const blas::IBlasLtMatmulAlgorithm* algorithm,
+                                 const DeviceMemory<int32>& bias,
                                  blas::ProfileResult* output_profile_result) {
   VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
-            PARAM(c), PARAM(algorithm));
+            PARAM(c), PARAM(algorithm), PARAM(bias));
 
   ThenBlasWithProfileImpl<
       const blas::IBlasLtMatmulPlan*, const HostOrDeviceScalar<int32>&,
       const DeviceMemory<int8>&, const DeviceMemory<int8>&,
       const HostOrDeviceScalar<int32>&, DeviceMemory<int32>*, ScratchAllocator*,
-      const blas::IBlasLtMatmulAlgorithm*>
+      const blas::IBlasLtMatmulAlgorithm*, const DeviceMemory<int32>&>
       impl;
   return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
-              c, scratch_allocator, algorithm, output_profile_result);
+              c, scratch_allocator, algorithm, bias, output_profile_result);
 }
 
 Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
@@ -4831,18 +4832,20 @@ Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
                                  DeviceMemory<Eigen::half>* c,
                                  ScratchAllocator* scratch_allocator,
                                  const blas::IBlasLtMatmulAlgorithm* algorithm,
+                                 const DeviceMemory<Eigen::half>& bias,
                                  blas::ProfileResult* output_profile_result) {
   VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
-            PARAM(c), PARAM(algorithm));
+            PARAM(c), PARAM(algorithm), PARAM(bias));
 
   ThenBlasWithProfileImpl<
       const blas::IBlasLtMatmulPlan*, const HostOrDeviceScalar<Eigen::half>&,
       const DeviceMemory<Eigen::half>&, const DeviceMemory<Eigen::half>&,
       const HostOrDeviceScalar<Eigen::half>&, DeviceMemory<Eigen::half>*,
-      ScratchAllocator*, const blas::IBlasLtMatmulAlgorithm*>
+      ScratchAllocator*, const blas::IBlasLtMatmulAlgorithm*,
+      const DeviceMemory<Eigen::half>&>
      impl;
   return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
-              c, scratch_allocator, algorithm, output_profile_result);
+              c, scratch_allocator, algorithm, bias, output_profile_result);
 }
 
 Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
@@ -4853,18 +4856,19 @@ Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
                                  DeviceMemory<float>* c,
                                  ScratchAllocator* scratch_allocator,
                                  const blas::IBlasLtMatmulAlgorithm* algorithm,
+                                 const DeviceMemory<float>& bias,
                                  blas::ProfileResult* output_profile_result) {
   VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
-            PARAM(c), PARAM(algorithm));
+            PARAM(c), PARAM(algorithm), PARAM(bias));
 
   ThenBlasWithProfileImpl<
      const blas::IBlasLtMatmulPlan*, const HostOrDeviceScalar<float>&,
      const DeviceMemory<float>&, const DeviceMemory<float>&,
      const HostOrDeviceScalar<float>&, DeviceMemory<float>*, ScratchAllocator*,
-     const blas::IBlasLtMatmulAlgorithm*>
+     const blas::IBlasLtMatmulAlgorithm*, const DeviceMemory<float>&>
      impl;
   return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
-              c, scratch_allocator, algorithm, output_profile_result);
+              c, scratch_allocator, algorithm, bias, output_profile_result);
 }
 
 Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
@@ -4875,18 +4879,20 @@ Stream& Stream::ThenBlasLtMatmul(const blas::IBlasLtMatmulPlan* plan,
                                  DeviceMemory<double>* c,
                                  ScratchAllocator* scratch_allocator,
                                  const blas::IBlasLtMatmulAlgorithm* algorithm,
+                                 const DeviceMemory<double>& bias,
                                  blas::ProfileResult* output_profile_result) {
   VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
-            PARAM(c), PARAM(algorithm));
+            PARAM(c), PARAM(algorithm), PARAM(bias));
 
   ThenBlasWithProfileImpl<
      const blas::IBlasLtMatmulPlan*, const HostOrDeviceScalar<double>&,
      const DeviceMemory<double>&, const DeviceMemory<double>&,
      const HostOrDeviceScalar<double>&, DeviceMemory<double>*,
-     ScratchAllocator*, const blas::IBlasLtMatmulAlgorithm*>
+     ScratchAllocator*, const blas::IBlasLtMatmulAlgorithm*,
+     const DeviceMemory<double>&>
      impl;
   return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
-              c, scratch_allocator, algorithm, output_profile_result);
+              c, scratch_allocator, algorithm, bias, output_profile_result);
 }
 
 Stream& Stream::ThenBlasLtMatmul(
@@ -4897,9 +4903,10 @@ Stream& Stream::ThenBlasLtMatmul(
     const HostOrDeviceScalar<std::complex<float>>& beta,
     DeviceMemory<std::complex<float>>* c, ScratchAllocator* scratch_allocator,
     const blas::IBlasLtMatmulAlgorithm* algorithm,
+    const DeviceMemory<std::complex<float>>& bias,
     blas::ProfileResult* output_profile_result) {
   VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
-            PARAM(c), PARAM(algorithm));
+            PARAM(c), PARAM(algorithm), PARAM(bias));
 
   ThenBlasWithProfileImpl<const blas::IBlasLtMatmulPlan*,
                           const HostOrDeviceScalar<std::complex<float>>&,
@@ -4907,10 +4914,11 @@ Stream& Stream::ThenBlasLtMatmul(
                           const DeviceMemory<std::complex<float>>&,
                           const HostOrDeviceScalar<std::complex<float>>&,
                           DeviceMemory<std::complex<float>>*, ScratchAllocator*,
-                          const blas::IBlasLtMatmulAlgorithm*>
+                          const blas::IBlasLtMatmulAlgorithm*,
+                          const DeviceMemory<std::complex<float>>&>
       impl;
   return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
-              c, scratch_allocator, algorithm, output_profile_result);
+              c, scratch_allocator, algorithm, bias, output_profile_result);
 }
 
 Stream& Stream::ThenBlasLtMatmul(
@@ -4921,9 +4929,10 @@ Stream& Stream::ThenBlasLtMatmul(
     const HostOrDeviceScalar<std::complex<double>>& beta,
     DeviceMemory<std::complex<double>>* c, ScratchAllocator* scratch_allocator,
     const blas::IBlasLtMatmulAlgorithm* algorithm,
+    const DeviceMemory<std::complex<double>>& bias,
     blas::ProfileResult* output_profile_result) {
   VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
-            PARAM(c), PARAM(algorithm));
+            PARAM(c), PARAM(algorithm), PARAM(bias));
 
   ThenBlasWithProfileImpl<const blas::IBlasLtMatmulPlan*,
                           const HostOrDeviceScalar<std::complex<double>>&,
@@ -4932,10 +4941,11 @@ Stream& Stream::ThenBlasLtMatmul(
                           const HostOrDeviceScalar<std::complex<double>>&,
                           DeviceMemory<std::complex<double>>*,
                           ScratchAllocator*,
-                          const blas::IBlasLtMatmulAlgorithm*>
+                          const blas::IBlasLtMatmulAlgorithm*,
+                          const DeviceMemory<std::complex<double>>&>
      impl;
   return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
-              c, scratch_allocator, algorithm, output_profile_result);
+              c, scratch_allocator, algorithm, bias, output_profile_result);
 }
 
 Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
@@ -1672,6 +1672,7 @@ class Stream {
       const DeviceMemory<int8>& b, const HostOrDeviceScalar<int32>& beta,
       DeviceMemory<int32>* c, ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<int32>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr);
   Stream& ThenBlasLtMatmul(
       const blas::IBlasLtMatmulPlan* plan,
@@ -1680,6 +1681,7 @@ class Stream {
       const HostOrDeviceScalar<Eigen::half>& beta, DeviceMemory<Eigen::half>* c,
       ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<Eigen::half>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr);
   Stream& ThenBlasLtMatmul(
       const blas::IBlasLtMatmulPlan* plan,
@@ -1687,6 +1689,7 @@ class Stream {
       const DeviceMemory<float>& b, const HostOrDeviceScalar<float>& beta,
       DeviceMemory<float>* c, ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<float>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr);
   Stream& ThenBlasLtMatmul(
       const blas::IBlasLtMatmulPlan* plan,
@@ -1694,6 +1697,7 @@ class Stream {
      const DeviceMemory<double>& b, const HostOrDeviceScalar<double>& beta,
      DeviceMemory<double>* c, ScratchAllocator* scratch_allocator,
      const blas::IBlasLtMatmulAlgorithm* algorithm,
+     const DeviceMemory<double>& bias = {},
      blas::ProfileResult* output_profile_result = nullptr);
   Stream& ThenBlasLtMatmul(
       const blas::IBlasLtMatmulPlan* plan,
@@ -1703,6 +1707,7 @@ class Stream {
       const HostOrDeviceScalar<std::complex<float>>& beta,
       DeviceMemory<std::complex<float>>* c, ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<std::complex<float>>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr);
   Stream& ThenBlasLtMatmul(
       const blas::IBlasLtMatmulPlan* plan,
@@ -1713,6 +1718,7 @@ class Stream {
       DeviceMemory<std::complex<double>>* c,
       ScratchAllocator* scratch_allocator,
       const blas::IBlasLtMatmulAlgorithm* algorithm,
+      const DeviceMemory<std::complex<double>>& bias = {},
       blas::ProfileResult* output_profile_result = nullptr);
 
   // See FftSupport::DoFft.
@@ -339,31 +339,32 @@ bool StreamExecutor::GetBlasGemmAlgorithms(
 std::unique_ptr<blas::IBlasLtMatmulPlan> StreamExecutor::CreateBlasLtMatmulPlan(
     blas::DataType ab_type, blas::DataType cd_type,
     blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
-    uint64 k, int64 lda, int64 ldb, int64 ldc) {
+    blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
+    uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc) {
   blas::BlasSupport *blas_support = AsBlas();
   if (!blas_support) {
     return nullptr;
   }
   return blas_support->CreateBlasLtMatmulPlan(
-      ab_type, cd_type, computation_type, pointer_mode, transa, transb, m, n, k,
-      lda, ldb, ldc);
+      ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
+      transb, m, n, k, lda, ldb, ldc);
 }
 
 std::unique_ptr<blas::IBlasLtMatmulPlan>
 StreamExecutor::CreateBlasLtMatmulPlanStridedBatched(
     blas::DataType ab_type, blas::DataType cd_type,
     blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
-    uint64 k, uint64 batch_count, int64 lda, int64 stride_a, int64 ldb,
-    int64 stride_b, int64 ldc, int64 stride_c) {
+    blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
+    uint64 m, uint64 n, uint64 k, uint64 batch_count, int64 lda, int64 stride_a,
+    int64 ldb, int64 stride_b, int64 ldc, int64 stride_c) {
   blas::BlasSupport *blas_support = AsBlas();
   if (!blas_support) {
     return nullptr;
   }
   return blas_support->CreateBlasLtMatmulPlanStridedBatched(
-      ab_type, cd_type, computation_type, pointer_mode, transa, transb, m, n, k,
-      batch_count, lda, stride_a, ldb, stride_b, ldc, stride_c);
+      ab_type, cd_type, computation_type, pointer_mode, epilogue, transa,
+      transb, m, n, k, batch_count, lda, stride_a, ldb, stride_b, ldc,
+      stride_c);
 }
 
 bool StreamExecutor::GetBlasLtMatmulAlgorithms(
@@ -401,17 +401,17 @@ class StreamExecutor {
   std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlan(
       blas::DataType ab_type, blas::DataType cd_type,
       blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
-      uint64 k, int64 lda, int64 ldb, int64 ldc);
+      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
+      uint64 m, uint64 n, uint64 k, int64 lda, int64 ldb, int64 ldc);
 
   // A more general version of CreateBlasLtMatmulPlan supporting
   // batched operations.
   std::unique_ptr<blas::IBlasLtMatmulPlan> CreateBlasLtMatmulPlanStridedBatched(
       blas::DataType ab_type, blas::DataType cd_type,
       blas::ComputationType computation_type, blas::PointerMode pointer_mode,
-      blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
-      uint64 k, uint64 batch_count, int64 lda, int64 stride_a, int64 ldb,
-      int64 stride_b, int64 ldc, int64 stride_c);
+      blas::Epilogue epilogue, blas::Transpose transa, blas::Transpose transb,
+      uint64 m, uint64 n, uint64 k, uint64 batch_count, int64 lda,
+      int64 stride_a, int64 ldb, int64 stride_b, int64 ldc, int64 stride_c);
 
   // Gets a list of supported algorithms for DoBlasLtMatmul. The algorithms are
   // returned in the order of increasing estimated compute time according to an