diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index 07a0f7ccd61..cfff3649c8d 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -97,8 +97,9 @@ enum class ComputationType {
   kF16,         // 16-bit floating-point
   kF32,         // 32-bit floating-point
   kF64,         // 64-bit floating-point
+  kI32,         // 32-bit integer
   kComplexF32,  // Complex number comprised of two f32s.
-  kComplexF64   // Complex number comprised of two f64s.
+  kComplexF64,  // Complex number comprised of two f64s.
 };
 
 // Converts a ComputationType to a string.
@@ -108,6 +109,15 @@ string ComputationTypeString(ComputationType ty);
 // as a hint to the blas library.
 typedef int64 AlgorithmType;
 
+// blas uses -1 to represent the default algorithm. This happens to match up
+// with the CUBLAS_GEMM_DFALT constant, so cuda_blas.cc is using static_cast
+// to convert from AlgorithmType to cublasGemmAlgo_t, and uses a static_assert
+// to ensure that this assumption does not break.
+// If another blas implementation uses a different value for the default
+// algorithm, then it needs to convert kDefaultGemmAlgo to that value
+// (e.g. via a function called ToWhateverGemmAlgo).
+constexpr AlgorithmType kDefaultGemmAlgo = -1;
+
 // Describes the result of a performance experiment, usually timing the speed of
 // a particular AlgorithmType.
 //
@@ -944,6 +954,12 @@ class BlasSupport {
   // output_profile_result->is_valid(). This lets you use this function for
   // choosing the best algorithm among many (some of which may fail) without
   // creating a new Stream for each attempt.
+  virtual bool DoBlasGemmWithAlgorithm(
+      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+      uint64 n, uint64 k, int alpha, const DeviceMemory<int8> &a, int lda,
+      const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int> *c,
+      int ldc, ComputationType computation_type, AlgorithmType algorithm,
+      ProfileResult *output_profile_result) = 0;
   virtual bool DoBlasGemmWithAlgorithm(
       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
       uint64 n, uint64 k, const Eigen::half &alpha,
@@ -1737,6 +1753,13 @@ class BlasSupport {
                   DeviceMemory<std::complex<double>> *c, int ldc) override;    \
   bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms) \
       override;                                                                \
+  bool DoBlasGemmWithAlgorithm(                                                \
+      Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
+      uint64 m, uint64 n, uint64 k, int alpha, const DeviceMemory<int8> &a,    \
+      int lda, const DeviceMemory<int8> &b, int ldb, int beta,                 \
+      DeviceMemory<int> *c, int ldc, blas::ComputationType computation_type,   \
+      blas::AlgorithmType algorithm,                                           \
+      blas::ProfileResult *output_profile_result) override;                    \
   bool DoBlasGemmWithAlgorithm(                                                \
       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
       uint64 m, uint64 n, uint64 k, const Eigen::half &alpha,                  \
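Note: the comment on kDefaultGemmAlgo above anticipates backends whose default-algorithm constant is not -1. A minimal sketch of the suggested translation function, for a hypothetical non-cuBLAS backend (the OtherGemmAlgo_t enum and all names below are made up for illustration; only blas::AlgorithmType and blas::kDefaultGemmAlgo come from this patch):

// Hypothetical backend enum whose default-algorithm value is not -1, so the
// static_cast shortcut used for cuBLAS would be wrong for the default case.
enum OtherGemmAlgo_t {
  OTHER_GEMM_DEFAULT = 0,
  OTHER_GEMM_ALGO0 = 1,
};

OtherGemmAlgo_t ToOtherGemmAlgo(blas::AlgorithmType algorithm) {
  if (algorithm == blas::kDefaultGemmAlgo) {
    return OTHER_GEMM_DEFAULT;  // -1 must be remapped for this backend.
  }
  return static_cast<OtherGemmAlgo_t>(algorithm);
}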
#include "tensorflow/stream_executor/cuda/cuda_blas.h" +#include #include #include "tensorflow/stream_executor/cuda/cuda_activation.h" @@ -483,6 +484,11 @@ struct CUDADataType> { static constexpr cudaDataType_t type = CUDA_C_64F; }; +template <> +struct CUDADataType { + static constexpr cudaDataType_t type = CUDA_R_32I; +}; + template <> struct CUDADataType { static constexpr cudaDataType_t type = CUDA_R_8I; @@ -511,6 +517,8 @@ cudaDataType_t CUDAComputationType(blas::ComputationType ty) { return CUDA_R_32F; case blas::ComputationType::kF64: return CUDA_R_64F; + case blas::ComputationType::kI32: + return CUDA_R_32I; case blas::ComputationType::kComplexF32: return CUDA_C_32F; case blas::ComputationType::kComplexF64: @@ -1849,12 +1857,12 @@ bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa, CUDAComplex(CUDAMemoryMutable(c)), ldc); } -template +template bool CUDABlas::DoBlasGemmWithAlgorithmImpl( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, - uint64 n, uint64 k, const T &alpha, const DeviceMemory &a, int lda, - const DeviceMemory &b, int ldb, const T &beta, DeviceMemory *c, - int ldc, blas::ComputationType computation_type, + uint64 n, uint64 k, const CompT &alpha, const DeviceMemory &a, int lda, + const DeviceMemory &b, int ldb, const CompT &beta, + DeviceMemory *c, int ldc, blas::ComputationType computation_type, blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { // CUDA < version 8 and GPUs < sm_50 don't support cublasGemmEx. #if CUDA_VERSION < 8000 @@ -1881,12 +1889,15 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl( } } - cudaDataType_t data_type = CUDADataType::type; + cudaDataType_t cuda_in_type = CUDADataType::type; + // Since we are converting 'algorithm' to cublasGemmAlgo_t by static_cast, + // we do the following compile-time check on the default value: + static_assert(blas::kDefaultGemmAlgo == CUBLAS_GEMM_DFALT, ""); bool result = DoBlasInternalFailureOK( wrap::cublasGemmEx, stream, /* pointer_mode_host = */ true, CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha, - CUDAMemory(a), data_type, lda, CUDAMemory(b), data_type, ldb, &beta, - CUDAMemoryMutable(c), data_type, ldc, + CUDAMemory(a), cuda_in_type, lda, CUDAMemory(b), cuda_in_type, ldb, &beta, + CUDAMemoryMutable(c), CUDADataType::type, ldc, CUDAComputationType(computation_type), static_cast(algorithm)); @@ -1920,6 +1931,17 @@ bool CUDABlas::GetBlasGemmAlgorithms( return true; } +bool CUDABlas::DoBlasGemmWithAlgorithm( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, int alpha, const DeviceMemory &a, int lda, + const DeviceMemory &b, int ldb, int beta, DeviceMemory *c, + int ldc, blas::ComputationType computation_type, + blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { + return DoBlasGemmWithAlgorithmImpl( + stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + computation_type, algorithm, output_profile_result); +} + bool CUDABlas::DoBlasGemmWithAlgorithm( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, const Eigen::half &alpha, diff --git a/tensorflow/stream_executor/cuda/cuda_blas.h b/tensorflow/stream_executor/cuda/cuda_blas.h index 6a33cd746b3..4a8641b300d 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.h +++ b/tensorflow/stream_executor/cuda/cuda_blas.h @@ -118,16 +118,14 @@ class CUDABlas : public blas::BlasSupport { // and we want to avoid pulling in a dependency on Eigen. 
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.h b/tensorflow/stream_executor/cuda/cuda_blas.h
index 6a33cd746b3..4a8641b300d 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.h
+++ b/tensorflow/stream_executor/cuda/cuda_blas.h
@@ -118,16 +118,14 @@ class CUDABlas : public blas::BlasSupport {
   // and we want to avoid pulling in a dependency on Eigen.  When we pass the
   // references to cublas, we essentially reinterpret_cast to __half, which is
   // safe because Eigen::half inherits from __half.
-  template <typename T>
-  bool DoBlasGemmWithAlgorithmImpl(Stream *stream, blas::Transpose transa,
-                                   blas::Transpose transb, uint64 m, uint64 n,
-                                   uint64 k, const T &alpha,
-                                   const DeviceMemory<T> &a, int lda,
-                                   const DeviceMemory<T> &b, int ldb,
-                                   const T &beta, DeviceMemory<T> *c, int ldc,
-                                   blas::ComputationType computation_type,
-                                   blas::AlgorithmType algorithm,
-                                   blas::ProfileResult *output_profile_result);
+  template <typename InT, typename OutT, typename CompT>
+  bool DoBlasGemmWithAlgorithmImpl(
+      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+      uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a,
+      int lda, const DeviceMemory<InT> &b, int ldb, const CompT &beta,
+      DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type,
+      blas::AlgorithmType algorithm,
+      blas::ProfileResult *output_profile_result);
 
   // mutex that guards the cuBLAS handle for this device.
   mutex mu_;
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 97193af7771..9b4a4c4fb18 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -3482,6 +3482,27 @@ Stream &Stream::ThenBlasGemmWithAlgorithm(
               algorithm, output_profile_result);
 }
 
+Stream &Stream::ThenBlasGemmWithAlgorithm(
+    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+    uint64 k, int alpha, const DeviceMemory<int8> &a, int lda,
+    const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int> *c,
+    int ldc, blas::ComputationType computation_type,
+    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+  VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+            PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+            PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
+            PARAM(algorithm));
+
+  ThenBlasWithProfileImpl<
+      blas::Transpose, blas::Transpose, uint64, uint64, uint64, int,
+      const DeviceMemory<int8> &, int, const DeviceMemory<int8> &, int, int,
+      DeviceMemory<int> *, int, blas::ComputationType, blas::AlgorithmType>
+      impl;
+  return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
+              m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
+              algorithm, output_profile_result);
+}
+
 Stream &Stream::ThenBlasGemmWithAlgorithm(
     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
     uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index 2ab3f44af51..ab6b866744d 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -1257,6 +1257,15 @@ class Stream {
                                     const Eigen::half &beta,
                                     DeviceMemory<Eigen::half> *c, int ldc,
                                     blas::ComputationType computation_type,
                                     blas::AlgorithmType algorithm,
                                     blas::ProfileResult *output_profile_result);
+  Stream &ThenBlasGemmWithAlgorithm(blas::Transpose transa,
+                                    blas::Transpose transb, uint64 m, uint64 n,
+                                    uint64 k, int alpha,
+                                    const DeviceMemory<int8> &a, int lda,
+                                    const DeviceMemory<int8> &b, int ldb,
+                                    int beta, DeviceMemory<int> *c, int ldc,
+                                    blas::ComputationType computation_type,
+                                    blas::AlgorithmType algorithm,
+                                    blas::ProfileResult *output_profile_result);
   Stream &ThenBlasGemmWithAlgorithm(blas::Transpose transa,
                                     blas::Transpose transb, uint64 m, uint64 n,
                                     uint64 k, float alpha,
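End-to-end, a caller could exercise the new overloads roughly as follows (illustrative only: the Stream pointer and the DeviceMemory<int8>/DeviceMemory<int> buffers are assumed to be set up elsewhere, the dimensions are arbitrary, and kDefaultGemmAlgo could be replaced by each candidate from GetBlasGemmAlgorithms when searching for the fastest algorithm):

// Sketch: run one int8 x int8 -> int32 GEMM and profile it.
blas::ProfileResult profile;
stream->ThenBlasGemmWithAlgorithm(
    blas::Transpose::kNoTranspose, blas::Transpose::kNoTranspose,
    /*m=*/128, /*n=*/128, /*k=*/128, /*alpha=*/1, a8, /*lda=*/128, b8,
    /*ldb=*/128, /*beta=*/0, &c32, /*ldc=*/128,
    blas::ComputationType::kI32, blas::kDefaultGemmAlgo, &profile);
if (profile.is_valid()) {
  // This algorithm ran successfully; profile.elapsed_time_in_ms() can be
  // compared across algorithms to pick the best one.
}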