Add support for int8 x int8 -> int32 matrix multiplication via cublasGemmEx to stream_executor.
PiperOrigin-RevId: 161137741
This commit is contained in:
parent
755fa7b501
commit
a2ee8bca3f
@ -97,8 +97,9 @@ enum class ComputationType {
|
||||
kF16, // 16-bit floating-point
|
||||
kF32, // 32-bit floating-point
|
||||
kF64, // 64-bit floating-point
|
||||
kI32, // 32-bit integer
|
||||
kComplexF32, // Complex number comprised of two f32s.
|
||||
kComplexF64 // Complex number comprised of two f64s.
|
||||
kComplexF64, // Complex number comprised of two f64s.
|
||||
};
|
||||
|
||||
// Converts a ComputationType to a string.
|
||||
@ -108,6 +109,15 @@ string ComputationTypeString(ComputationType ty);
|
||||
// as a hint to the blas library.
|
||||
typedef int64 AlgorithmType;
|
||||
|
||||
// blas uses -1 to represent the default algorithm. This happens to match up
|
||||
// with the CUBLAS_GEMM_DFALT constant, so cuda_blas.cc is using static_cast
|
||||
// to convert from AlgorithmType to cublasGemmAlgo_t, and uses a static_assert
|
||||
// to ensure that this assumption does not break.
|
||||
// If another blas implementation uses a different value for the default
|
||||
// algorithm, then it needs to convert kDefaultGemmAlgo to that value
|
||||
// (e.g. via a function called ToWhateverGemmAlgo).
|
||||
constexpr AlgorithmType kDefaultGemmAlgo = -1;
|
||||
|
||||
// Describes the result of a performance experiment, usually timing the speed of
|
||||
// a particular AlgorithmType.
|
||||
//
|
||||
@ -944,6 +954,12 @@ class BlasSupport {
|
||||
// output_profile_result->is_valid(). This lets you use this function for
|
||||
// choosing the best algorithm among many (some of which may fail) without
|
||||
// creating a new Stream for each attempt.
|
||||
virtual bool DoBlasGemmWithAlgorithm(
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||
uint64 n, uint64 k, int alpha, const DeviceMemory<int8> &a, int lda,
|
||||
const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int32> *c,
|
||||
int ldc, ComputationType computation_type, AlgorithmType algorithm,
|
||||
ProfileResult *output_profile_result) = 0;
|
||||
virtual bool DoBlasGemmWithAlgorithm(
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||
uint64 n, uint64 k, const Eigen::half &alpha,
|
||||
@ -1737,6 +1753,13 @@ class BlasSupport {
|
||||
DeviceMemory<std::complex<double>> *c, int ldc) override; \
|
||||
bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms) \
|
||||
override; \
|
||||
bool DoBlasGemmWithAlgorithm( \
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
|
||||
uint64 m, uint64 n, uint64 k, int alpha, const DeviceMemory<int8> &a, \
|
||||
int lda, const DeviceMemory<int8> &b, int ldb, int beta, \
|
||||
DeviceMemory<int> *c, int ldc, blas::ComputationType computation_type, \
|
||||
blas::AlgorithmType algorithm, \
|
||||
blas::ProfileResult *output_profile_result) override; \
|
||||
bool DoBlasGemmWithAlgorithm( \
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
|
||||
uint64 m, uint64 n, uint64 k, const Eigen::half &alpha, \
|
||||
|
@ -33,6 +33,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/stream_executor/cuda/cuda_blas.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <complex>
|
||||
|
||||
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
|
||||
@ -483,6 +484,11 @@ struct CUDADataType<std::complex<double>> {
|
||||
static constexpr cudaDataType_t type = CUDA_C_64F;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CUDADataType<int> {
|
||||
static constexpr cudaDataType_t type = CUDA_R_32I;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CUDADataType<int8> {
|
||||
static constexpr cudaDataType_t type = CUDA_R_8I;
|
||||
@ -511,6 +517,8 @@ cudaDataType_t CUDAComputationType(blas::ComputationType ty) {
|
||||
return CUDA_R_32F;
|
||||
case blas::ComputationType::kF64:
|
||||
return CUDA_R_64F;
|
||||
case blas::ComputationType::kI32:
|
||||
return CUDA_R_32I;
|
||||
case blas::ComputationType::kComplexF32:
|
||||
return CUDA_C_32F;
|
||||
case blas::ComputationType::kComplexF64:
|
||||
@ -1849,12 +1857,12 @@ bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
||||
CUDAComplex(CUDAMemoryMutable(c)), ldc);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename InT, typename OutT, typename CompT>
|
||||
bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||
uint64 n, uint64 k, const T &alpha, const DeviceMemory<T> &a, int lda,
|
||||
const DeviceMemory<T> &b, int ldb, const T &beta, DeviceMemory<T> *c,
|
||||
int ldc, blas::ComputationType computation_type,
|
||||
uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a, int lda,
|
||||
const DeviceMemory<InT> &b, int ldb, const CompT &beta,
|
||||
DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type,
|
||||
blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
|
||||
// CUDA < version 8 and GPUs < sm_50 don't support cublasGemmEx.
|
||||
#if CUDA_VERSION < 8000
|
||||
@ -1881,12 +1889,15 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
|
||||
}
|
||||
}
|
||||
|
||||
cudaDataType_t data_type = CUDADataType<T>::type;
|
||||
cudaDataType_t cuda_in_type = CUDADataType<InT>::type;
|
||||
// Since we are converting 'algorithm' to cublasGemmAlgo_t by static_cast,
|
||||
// we do the following compile-time check on the default value:
|
||||
static_assert(blas::kDefaultGemmAlgo == CUBLAS_GEMM_DFALT, "");
|
||||
bool result = DoBlasInternalFailureOK(
|
||||
wrap::cublasGemmEx, stream, /* pointer_mode_host = */ true,
|
||||
CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
|
||||
CUDAMemory(a), data_type, lda, CUDAMemory(b), data_type, ldb, &beta,
|
||||
CUDAMemoryMutable(c), data_type, ldc,
|
||||
CUDAMemory(a), cuda_in_type, lda, CUDAMemory(b), cuda_in_type, ldb, &beta,
|
||||
CUDAMemoryMutable(c), CUDADataType<OutT>::type, ldc,
|
||||
CUDAComputationType(computation_type),
|
||||
static_cast<cublasGemmAlgo_t>(algorithm));
|
||||
|
||||
@ -1920,6 +1931,17 @@ bool CUDABlas::GetBlasGemmAlgorithms(
|
||||
return true;
|
||||
}
|
||||
|
||||
bool CUDABlas::DoBlasGemmWithAlgorithm(
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||
uint64 n, uint64 k, int alpha, const DeviceMemory<int8> &a, int lda,
|
||||
const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int> *c,
|
||||
int ldc, blas::ComputationType computation_type,
|
||||
blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
|
||||
return DoBlasGemmWithAlgorithmImpl(
|
||||
stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
||||
computation_type, algorithm, output_profile_result);
|
||||
}
|
||||
|
||||
bool CUDABlas::DoBlasGemmWithAlgorithm(
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||
uint64 n, uint64 k, const Eigen::half &alpha,
|
||||
|
@ -118,14 +118,12 @@ class CUDABlas : public blas::BlasSupport {
|
||||
// and we want to avoid pulling in a dependency on Eigen. When we pass the
|
||||
// references to cublas, we essentially reinterpret_cast to __half, which is
|
||||
// safe because Eigen::half inherits from __half.
|
||||
template <typename T>
|
||||
bool DoBlasGemmWithAlgorithmImpl(Stream *stream, blas::Transpose transa,
|
||||
blas::Transpose transb, uint64 m, uint64 n,
|
||||
uint64 k, const T &alpha,
|
||||
const DeviceMemory<T> &a, int lda,
|
||||
const DeviceMemory<T> &b, int ldb,
|
||||
const T &beta, DeviceMemory<T> *c, int ldc,
|
||||
blas::ComputationType computation_type,
|
||||
template <typename InT, typename OutT, typename CompT>
|
||||
bool DoBlasGemmWithAlgorithmImpl(
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||
uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a,
|
||||
int lda, const DeviceMemory<InT> &b, int ldb, const CompT &beta,
|
||||
DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type,
|
||||
blas::AlgorithmType algorithm,
|
||||
blas::ProfileResult *output_profile_result);
|
||||
|
||||
|
@ -3482,6 +3482,27 @@ Stream &Stream::ThenBlasGemmWithAlgorithm(
|
||||
algorithm, output_profile_result);
|
||||
}
|
||||
|
||||
Stream &Stream::ThenBlasGemmWithAlgorithm(
|
||||
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
|
||||
uint64 k, int alpha, const DeviceMemory<int8> &a, int lda,
|
||||
const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int> *c,
|
||||
int ldc, blas::ComputationType computation_type,
|
||||
blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
|
||||
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
|
||||
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
|
||||
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
|
||||
PARAM(algorithm));
|
||||
|
||||
ThenBlasWithProfileImpl<
|
||||
blas::Transpose, blas::Transpose, uint64, uint64, uint64, int,
|
||||
const DeviceMemory<int8> &, int, const DeviceMemory<int8> &, int, int,
|
||||
DeviceMemory<int> *, int, blas::ComputationType, blas::AlgorithmType>
|
||||
impl;
|
||||
return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
|
||||
m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
|
||||
algorithm, output_profile_result);
|
||||
}
|
||||
|
||||
Stream &Stream::ThenBlasGemmWithAlgorithm(
|
||||
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
|
||||
uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
|
||||
|
@ -1257,6 +1257,15 @@ class Stream {
|
||||
const Eigen::half &beta, DeviceMemory<Eigen::half> *c, int ldc,
|
||||
blas::ComputationType computation_type, blas::AlgorithmType algorithm,
|
||||
blas::ProfileResult *output_profile_result);
|
||||
Stream &ThenBlasGemmWithAlgorithm(blas::Transpose transa,
|
||||
blas::Transpose transb, uint64 m, uint64 n,
|
||||
uint64 k, int alpha,
|
||||
const DeviceMemory<int8> &a, int lda,
|
||||
const DeviceMemory<int8> &b, int ldb,
|
||||
int beta, DeviceMemory<int> *c, int ldc,
|
||||
blas::ComputationType computation_type,
|
||||
blas::AlgorithmType algorithm,
|
||||
blas::ProfileResult *output_profile_result);
|
||||
Stream &ThenBlasGemmWithAlgorithm(blas::Transpose transa,
|
||||
blas::Transpose transb, uint64 m, uint64 n,
|
||||
uint64 k, float alpha,
|
||||
|
Loading…
Reference in New Issue
Block a user