Add support for int8 x int8 -> int32 matrix multiplication via cublasGemmEx to stream_executor.

PiperOrigin-RevId: 161137741
Authored by A. Unique TensorFlower on 2017-07-06 15:15:27 -07:00; committed by TensorFlower Gardener
parent 755fa7b501
commit a2ee8bca3f
5 changed files with 91 additions and 18 deletions
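For orientation, here is a minimal sketch of how a client might drive the new int8 path through the Stream API. The function name, buffer arguments, and the perftools::gputools namespace alias are illustrative assumptions, not part of this commit; passing a ProfileResult is optional (nullptr skips profiling).

namespace se = perftools::gputools;

// Computes c = a * b on int8 inputs with 32-bit integer accumulation/output.
// Buffers are assumed to be already allocated and populated on the device.
bool Int8GemmExample(se::Stream *stream, uint64 m, uint64 n, uint64 k,
                     const se::DeviceMemory<int8> &a, int lda,
                     const se::DeviceMemory<int8> &b, int ldb,
                     se::DeviceMemory<int> *c, int ldc) {
  se::blas::ProfileResult profile_result;
  stream->ThenBlasGemmWithAlgorithm(
      se::blas::Transpose::kNoTranspose, se::blas::Transpose::kNoTranspose,
      m, n, k, /*alpha=*/1, a, lda, b, ldb, /*beta=*/0, c, ldc,
      se::blas::ComputationType::kI32, se::blas::kDefaultGemmAlgo,
      &profile_result);
  return stream->BlockHostUntilDone() && profile_result.is_valid();
}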

tensorflow/stream_executor/blas.h

@@ -97,8 +97,9 @@ enum class ComputationType {
   kF16,  // 16-bit floating-point
   kF32,  // 32-bit floating-point
   kF64,  // 64-bit floating-point
+  kI32,  // 32-bit integer
   kComplexF32,  // Complex number comprised of two f32s.
-  kComplexF64   // Complex number comprised of two f64s.
+  kComplexF64,  // Complex number comprised of two f64s.
 };
 
 // Converts a ComputationType to a string.
@@ -108,6 +109,15 @@ string ComputationTypeString(ComputationType ty);
 // as a hint to the blas library.
 typedef int64 AlgorithmType;
 
+// blas uses -1 to represent the default algorithm. This happens to match up
+// with the CUBLAS_GEMM_DFALT constant, so cuda_blas.cc is using static_cast
+// to convert from AlgorithmType to cublasGemmAlgo_t, and uses a static_assert
+// to ensure that this assumption does not break.
+// If another blas implementation uses a different value for the default
+// algorithm, then it needs to convert kDefaultGemmAlgo to that value
+// (e.g. via a function called ToWhateverGemmAlgo).
+constexpr AlgorithmType kDefaultGemmAlgo = -1;
+
 // Describes the result of a performance experiment, usually timing the speed of
 // a particular AlgorithmType.
 //
@@ -944,6 +954,12 @@ class BlasSupport {
   // output_profile_result->is_valid(). This lets you use this function for
   // choosing the best algorithm among many (some of which may fail) without
   // creating a new Stream for each attempt.
+  virtual bool DoBlasGemmWithAlgorithm(
+      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+      uint64 n, uint64 k, int alpha, const DeviceMemory<int8> &a, int lda,
+      const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int32> *c,
+      int ldc, ComputationType computation_type, AlgorithmType algorithm,
+      ProfileResult *output_profile_result) = 0;
   virtual bool DoBlasGemmWithAlgorithm(
       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
       uint64 n, uint64 k, const Eigen::half &alpha,
@@ -1737,6 +1753,13 @@ class BlasSupport {
       DeviceMemory<std::complex<double>> *c, int ldc) override; \
   bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms) \
       override; \
+  bool DoBlasGemmWithAlgorithm( \
+      Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+      uint64 m, uint64 n, uint64 k, int alpha, const DeviceMemory<int8> &a, \
+      int lda, const DeviceMemory<int8> &b, int ldb, int beta, \
+      DeviceMemory<int> *c, int ldc, blas::ComputationType computation_type, \
+      blas::AlgorithmType algorithm, \
+      blas::ProfileResult *output_profile_result) override; \
   bool DoBlasGemmWithAlgorithm( \
       Stream *stream, blas::Transpose transa, blas::Transpose transb, \
      uint64 m, uint64 n, uint64 k, const Eigen::half &alpha, \

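The header comment above DoBlasGemmWithAlgorithm describes the intended autotuning pattern: run every candidate algorithm on the same stream and keep the fastest one whose ProfileResult comes back valid. A sketch of that loop, under the same illustrative assumptions as the example above (the se namespace alias, pre-allocated buffers, and the stream->parent()->GetBlasGemmAlgorithms(...) forwarding on StreamExecutor):

#include <limits>
#include <vector>

namespace se = perftools::gputools;

se::blas::AlgorithmType PickBestInt8GemmAlgorithm(
    se::Stream *stream, uint64 m, uint64 n, uint64 k,
    const se::DeviceMemory<int8> &a, int lda,
    const se::DeviceMemory<int8> &b, int ldb,
    se::DeviceMemory<int> *c, int ldc) {
  std::vector<se::blas::AlgorithmType> algorithms;
  stream->parent()->GetBlasGemmAlgorithms(&algorithms);

  se::blas::AlgorithmType best = se::blas::kDefaultGemmAlgo;
  float best_ms = std::numeric_limits<float>::infinity();
  for (se::blas::AlgorithmType algorithm : algorithms) {
    se::blas::ProfileResult profile;
    stream->ThenBlasGemmWithAlgorithm(
        se::blas::Transpose::kNoTranspose, se::blas::Transpose::kNoTranspose,
        m, n, k, /*alpha=*/1, a, lda, b, ldb, /*beta=*/0, c, ldc,
        se::blas::ComputationType::kI32, algorithm, &profile);
    // A failing algorithm only leaves 'profile' invalid; the stream stays
    // usable, so no new Stream is needed per attempt (per the blas.h comment).
    if (profile.is_valid() && profile.elapsed_time_in_ms() < best_ms) {
      best_ms = profile.elapsed_time_in_ms();
      best = profile.algorithm();
    }
  }
  return best;
}
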
tensorflow/stream_executor/cuda/cuda_blas.cc

@@ -33,6 +33,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/cuda/cuda_blas.h"
 
+#include <assert.h>
 #include <complex>
 
 #include "tensorflow/stream_executor/cuda/cuda_activation.h"
@@ -483,6 +484,11 @@ struct CUDADataType<std::complex<double>> {
   static constexpr cudaDataType_t type = CUDA_C_64F;
 };
 
+template <>
+struct CUDADataType<int> {
+  static constexpr cudaDataType_t type = CUDA_R_32I;
+};
+
 template <>
 struct CUDADataType<int8> {
   static constexpr cudaDataType_t type = CUDA_R_8I;
@@ -511,6 +517,8 @@ cudaDataType_t CUDAComputationType(blas::ComputationType ty) {
       return CUDA_R_32F;
     case blas::ComputationType::kF64:
       return CUDA_R_64F;
+    case blas::ComputationType::kI32:
+      return CUDA_R_32I;
     case blas::ComputationType::kComplexF32:
       return CUDA_C_32F;
     case blas::ComputationType::kComplexF64:
@@ -1849,12 +1857,12 @@ bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
                         CUDAComplex(CUDAMemoryMutable(c)), ldc);
 }
 
-template <typename T>
+template <typename InT, typename OutT, typename CompT>
 bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
-    uint64 n, uint64 k, const T &alpha, const DeviceMemory<T> &a, int lda,
-    const DeviceMemory<T> &b, int ldb, const T &beta, DeviceMemory<T> *c,
-    int ldc, blas::ComputationType computation_type,
+    uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a, int lda,
+    const DeviceMemory<InT> &b, int ldb, const CompT &beta,
+    DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type,
     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
   // CUDA < version 8 and GPUs < sm_50 don't support cublasGemmEx.
 #if CUDA_VERSION < 8000
@@ -1881,12 +1889,15 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
     }
   }
 
-  cudaDataType_t data_type = CUDADataType<T>::type;
+  cudaDataType_t cuda_in_type = CUDADataType<InT>::type;
+  // Since we are converting 'algorithm' to cublasGemmAlgo_t by static_cast,
+  // we do the following compile-time check on the default value:
+  static_assert(blas::kDefaultGemmAlgo == CUBLAS_GEMM_DFALT, "");
   bool result = DoBlasInternalFailureOK(
       wrap::cublasGemmEx, stream, /* pointer_mode_host = */ true,
       CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
-      CUDAMemory(a), data_type, lda, CUDAMemory(b), data_type, ldb, &beta,
-      CUDAMemoryMutable(c), data_type, ldc,
+      CUDAMemory(a), cuda_in_type, lda, CUDAMemory(b), cuda_in_type, ldb, &beta,
+      CUDAMemoryMutable(c), CUDADataType<OutT>::type, ldc,
       CUDAComputationType(computation_type),
       static_cast<cublasGemmAlgo_t>(algorithm));
@@ -1920,6 +1931,17 @@ bool CUDABlas::GetBlasGemmAlgorithms(
   return true;
 }
 
+bool CUDABlas::DoBlasGemmWithAlgorithm(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, int alpha, const DeviceMemory<int8> &a, int lda,
+    const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int> *c,
+    int ldc, blas::ComputationType computation_type,
+    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+  return DoBlasGemmWithAlgorithmImpl(
+      stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
+      computation_type, algorithm, output_profile_result);
+}
+
 bool CUDABlas::DoBlasGemmWithAlgorithm(
     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
     uint64 n, uint64 k, const Eigen::half &alpha,

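For reference, the cuBLAS call that the int8 instantiation above boils down to looks roughly like the following. This is a sketch against the CUDA 8 cublasGemmEx signature; handle setup and status checking are elided, and cuBLAS imposes additional alignment restrictions on the int8 path that a real caller must respect.

#include <cstdint>
#include <cublas_v2.h>

// d_a, d_b: device pointers to int8 operands; d_c: device pointer to int32.
cublasStatus_t Int8GemmRaw(cublasHandle_t handle, int m, int n, int k,
                           const int8_t *d_a, int lda, const int8_t *d_b,
                           int ldb, int32_t *d_c, int ldc) {
  const int alpha = 1, beta = 0;  // Host-side scale factors.
  return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha,
                      d_a, CUDA_R_8I, lda,          // A: 8-bit signed integer.
                      d_b, CUDA_R_8I, ldb,          // B: 8-bit signed integer.
                      &beta, d_c, CUDA_R_32I, ldc,  // C: 32-bit integer.
                      CUDA_R_32I,           // Accumulate in 32-bit integers.
                      CUBLAS_GEMM_DFALT);   // Matches blas::kDefaultGemmAlgo.
}
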
tensorflow/stream_executor/cuda/cuda_blas.h

@@ -118,14 +118,12 @@ class CUDABlas : public blas::BlasSupport {
   // and we want to avoid pulling in a dependency on Eigen. When we pass the
   // references to cublas, we essentially reinterpret_cast to __half, which is
   // safe because Eigen::half inherits from __half.
-  template <typename T>
-  bool DoBlasGemmWithAlgorithmImpl(Stream *stream, blas::Transpose transa,
-                                   blas::Transpose transb, uint64 m, uint64 n,
-                                   uint64 k, const T &alpha,
-                                   const DeviceMemory<T> &a, int lda,
-                                   const DeviceMemory<T> &b, int ldb,
-                                   const T &beta, DeviceMemory<T> *c, int ldc,
-                                   blas::ComputationType computation_type,
+  template <typename InT, typename OutT, typename CompT>
+  bool DoBlasGemmWithAlgorithmImpl(
+      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+      uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a,
+      int lda, const DeviceMemory<InT> &b, int ldb, const CompT &beta,
+      DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type,
       blas::AlgorithmType algorithm,
       blas::ProfileResult *output_profile_result);

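One subtlety worth spelling out: the Eigen::half comment retained above is what lets a single template serve both the fp16 and int8 paths. Because Eigen::half derives from the CUDA __half type (in a CUDA-enabled build), references to it can be handed to cuBLAS directly. A compile-time sketch of that claim and of the instantiations the new template admits; the include paths are assumptions about TensorFlow's conventions, not taken from this commit:

#include <type_traits>
#include <cuda_fp16.h>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"  // Eigen::half

static_assert(std::is_base_of<__half, Eigen::half>::value,
              "Eigen::half must extend __half for the reinterpret to be safe");

// Template-argument deduction at the call sites then picks:
//   int8 path:  InT = int8,        OutT = int,         CompT = int
//   fp16 path:  InT = Eigen::half, OutT = Eigen::half, CompT = Eigen::half
// The pre-existing float/double/complex paths keep InT == OutT == CompT.
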
tensorflow/stream_executor/stream.cc

@@ -3482,6 +3482,27 @@ Stream &Stream::ThenBlasGemmWithAlgorithm(
               algorithm, output_profile_result);
 }
 
+Stream &Stream::ThenBlasGemmWithAlgorithm(
+    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+    uint64 k, int alpha, const DeviceMemory<int8> &a, int lda,
+    const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int> *c,
+    int ldc, blas::ComputationType computation_type,
+    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+  VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+            PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+            PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
+            PARAM(algorithm));
+
+  ThenBlasWithProfileImpl<
+      blas::Transpose, blas::Transpose, uint64, uint64, uint64, int,
+      const DeviceMemory<int8> &, int, const DeviceMemory<int8> &, int, int,
+      DeviceMemory<int> *, int, blas::ComputationType, blas::AlgorithmType>
+      impl;
+  return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
+              m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
+              algorithm, output_profile_result);
+}
+
 Stream &Stream::ThenBlasGemmWithAlgorithm(
     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
     uint64 k, float alpha, const DeviceMemory<float> &a, int lda,

tensorflow/stream_executor/stream.h

@@ -1257,6 +1257,15 @@ class Stream {
       const Eigen::half &beta, DeviceMemory<Eigen::half> *c, int ldc,
       blas::ComputationType computation_type, blas::AlgorithmType algorithm,
       blas::ProfileResult *output_profile_result);
+  Stream &ThenBlasGemmWithAlgorithm(blas::Transpose transa,
+                                    blas::Transpose transb, uint64 m, uint64 n,
+                                    uint64 k, int alpha,
+                                    const DeviceMemory<int8> &a, int lda,
+                                    const DeviceMemory<int8> &b, int ldb,
+                                    int beta, DeviceMemory<int> *c, int ldc,
+                                    blas::ComputationType computation_type,
+                                    blas::AlgorithmType algorithm,
+                                    blas::ProfileResult *output_profile_result);
   Stream &ThenBlasGemmWithAlgorithm(blas::Transpose transa,
                                     blas::Transpose transb, uint64 m, uint64 n,
                                     uint64 k, float alpha,