From 6186c0936704539e2d4ac4d2f216799be21ea997 Mon Sep 17 00:00:00 2001
From: Ben Barsdell
Date: Tue, 29 Sep 2020 16:07:57 +1000
Subject: [PATCH] Use (+generalize) existing dnn::DataType in blas::

---
 tensorflow/stream_executor/blas.cc           | 14 +++---
 tensorflow/stream_executor/blas.h            | 45 ++------------------
 tensorflow/stream_executor/cuda/cuda_blas.cc | 42 +++++++++---------
 tensorflow/stream_executor/dnn.h             |  8 ++++
 tensorflow/stream_executor/dnn.proto         |  2 +
 5 files changed, 41 insertions(+), 70 deletions(-)

diff --git a/tensorflow/stream_executor/blas.cc b/tensorflow/stream_executor/blas.cc
index f55e318e88b..ca597595969 100644
--- a/tensorflow/stream_executor/blas.cc
+++ b/tensorflow/stream_executor/blas.cc
@@ -97,19 +97,19 @@ std::ostream& operator<<(std::ostream& os, ComputationType ty) {
 
 string DataTypeString(DataType ty) {
   switch (ty) {
-    case DataType::kF16:
+    case DataType::kHalf:
       return "f16";
-    case DataType::kF32:
+    case DataType::kFloat:
       return "f32";
-    case DataType::kF64:
+    case DataType::kDouble:
       return "f64";
-    case DataType::kI8:
+    case DataType::kInt8:
       return "i8";
-    case DataType::kI32:
+    case DataType::kInt32:
       return "i32";
-    case DataType::kComplexF32:
+    case DataType::kComplexFloat:
       return "complex f32";
-    case DataType::kComplexF64:
+    case DataType::kComplexDouble:
       return "complex f64";
     default:
       LOG(FATAL) << "Unknown DataType " << static_cast<int>(ty);
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index 411f6f11275..29fa7dbc68e 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -44,6 +44,7 @@ limitations under the License.
 #include <complex>
 
 #include "tensorflow/stream_executor/host_or_device_scalar.h"
+#include "tensorflow/stream_executor/dnn.h"  // For DataType, ToDataType
 #include "tensorflow/stream_executor/lib/array_slice.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
 #include "tensorflow/stream_executor/platform/port.h"
@@ -119,16 +120,8 @@ std::string ComputationTypeString(ComputationType ty);
 
 std::ostream &operator<<(std::ostream &os, ComputationType ty);
 
-// Type with which inputs and outputs of a blaslt routine are performed.
-enum class DataType {
-  kF16,         // 16-bit floating-point
-  kF32,         // 32-bit floating-point
-  kF64,         // 64-bit floating-point
-  kI8,          // 8-bit integer
-  kI32,         // 32-bit integer
-  kComplexF32,  // Complex number comprised of two f32s
-  kComplexF64,  // Complex number comprised of two f64s
-};
+using dnn::DataType;
+using dnn::ToDataType;
 
 // Describes the type of pointers for the scaling factors alpha and beta in
 // blaslt routines.
@@ -142,38 +135,6 @@ string DataTypeString(DataType ty);
 
 std::ostream &operator<<(std::ostream &os, DataType ty);
 
-// Converts a compile-time type to a DataType value.
-template <typename T>
-struct ToDataType {};
-template <>
-struct ToDataType<Eigen::half> {
-  static constexpr const DataType value = DataType::kF16;
-};
-template <>
-struct ToDataType<float> {
-  static constexpr const DataType value = DataType::kF32;
-};
-template <>
-struct ToDataType<double> {
-  static constexpr const DataType value = DataType::kF64;
-};
-template <>
-struct ToDataType<int8> {
-  static constexpr const DataType value = DataType::kI8;
-};
-template <>
-struct ToDataType<int32> {
-  static constexpr const DataType value = DataType::kI32;
-};
-template <>
-struct ToDataType<std::complex<float>> {
-  static constexpr const DataType value = DataType::kComplexF32;
-};
-template <>
-struct ToDataType<std::complex<double>> {
-  static constexpr const DataType value = DataType::kComplexF64;
-};
-
 // Opaque identifier for an "algorithm" used by a blas routine.  This functions
 // as a hint to the blas library.
 typedef int64 AlgorithmType;
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index f2bc79e1c29..1c97b6db6a3 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -441,21 +441,21 @@ cublasComputeType_t CUBLASComputationType(blas::ComputationType ty) {
 
 blas::DataType GetScaleType(blas::DataType data_type,
                             blas::ComputationType compute_type) {
-  bool is_complex = data_type == blas::DataType::kComplexF32 ||
-                    data_type == blas::DataType::kComplexF64;
+  bool is_complex = data_type == blas::DataType::kComplexFloat ||
+                    data_type == blas::DataType::kComplexDouble;
   switch (compute_type) {
     case blas::ComputationType::kF16:
-      return blas::DataType::kF16;
+      return blas::DataType::kHalf;
     case blas::ComputationType::kF32:          // fall-through
     case blas::ComputationType::kComplexF32:   // fall-through
     case blas::ComputationType::kF32FastTF32:  // fall-through
     case blas::ComputationType::kF32FastBF16:
-      return is_complex ? blas::DataType::kComplexF32 : blas::DataType::kF32;
+      return is_complex ? blas::DataType::kComplexFloat : blas::DataType::kFloat;
     case blas::ComputationType::kF64:  // fall-through
     case blas::ComputationType::kComplexF64:
-      return is_complex ? blas::DataType::kComplexF64 : blas::DataType::kF64;
+      return is_complex ? blas::DataType::kComplexDouble : blas::DataType::kDouble;
     case blas::ComputationType::kI32:
-      return blas::DataType::kI32;
+      return blas::DataType::kInt32;
   }
 }
 
@@ -484,38 +484,38 @@ cublasLtEpilogue_t CUBLASEpilogue(blas::Epilogue epilogue) {
 
 cudaDataType_t GetCUDADataType(blas::DataType ty) {
   switch (ty) {
-    case blas::DataType::kF16:
+    case blas::DataType::kHalf:
       return CUDA_R_16F;
-    case blas::DataType::kF32:
+    case blas::DataType::kFloat:
       return CUDA_R_32F;
-    case blas::DataType::kF64:
+    case blas::DataType::kDouble:
       return CUDA_R_64F;
-    case blas::DataType::kI8:
+    case blas::DataType::kInt8:
       return CUDA_R_8I;
-    case blas::DataType::kI32:
+    case blas::DataType::kInt32:
       return CUDA_R_32I;
-    case blas::DataType::kComplexF32:
+    case blas::DataType::kComplexFloat:
       return CUDA_C_32F;
-    case blas::DataType::kComplexF64:
+    case blas::DataType::kComplexDouble:
       return CUDA_C_64F;
   }
 }
 
 int GetDataTypeSizeBytes(blas::DataType ty) {
   switch (ty) {
-    case blas::DataType::kF16:
+    case blas::DataType::kHalf:
       return 2;
-    case blas::DataType::kF32:
+    case blas::DataType::kFloat:
       return 4;
-    case blas::DataType::kF64:
+    case blas::DataType::kDouble:
       return 8;
-    case blas::DataType::kI8:
+    case blas::DataType::kInt8:
       return 1;
-    case blas::DataType::kI32:
+    case blas::DataType::kInt32:
       return 4;
-    case blas::DataType::kComplexF32:
+    case blas::DataType::kComplexFloat:
       return 8;
-    case blas::DataType::kComplexF64:
+    case blas::DataType::kComplexDouble:
       return 16;
   }
 }
@@ -3611,7 +3611,7 @@ bool CUDABlas::DoBlasLtMatmul(Stream* stream,
                               blas::ProfileResult* output_profile_result) {
 #if CUDA_VERSION >= 11000
   const auto& cuda_plan = *static_cast<const CUDABlasLtMatmulPlan*>(plan);
-  if (cuda_plan.scale_type() == blas::DataType::kF32) {
+  if (cuda_plan.scale_type() == blas::DataType::kFloat) {
     // F32* computation types require F32 alpha/beta type, so we must cast them.
     if (alpha.is_pointer() || beta.is_pointer()) {
       // We cannot easily convert a pointer to f16 memory to a pointer to f32
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 53cdff8cb7a..fd38efc2537 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -133,6 +133,14 @@ template <>
 struct ToDataType<int32> {
   static constexpr DataType value = DataType::kInt32;
 };
+template <>
+struct ToDataType<std::complex<float>> {
+  static constexpr DataType value = DataType::kComplexFloat;
+};
+template <>
+struct ToDataType<std::complex<double>> {
+  static constexpr DataType value = DataType::kComplexDouble;
+};
 
 // Specifies the types of a RNN model.
 enum class RnnMode {
diff --git a/tensorflow/stream_executor/dnn.proto b/tensorflow/stream_executor/dnn.proto
index 4d09e615e7d..f849b011eb3 100644
--- a/tensorflow/stream_executor/dnn.proto
+++ b/tensorflow/stream_executor/dnn.proto
@@ -12,6 +12,8 @@ enum DataType {
   kHalf = 2;
   kInt8 = 3;
   kInt32 = 4;
+  kComplexFloat = 5;
+  kComplexDouble = 6;
 }
 
 // Describes how a convolution input or output layer's data is formatted.
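
Note: after this patch, blas::DataType and blas::ToDataType are aliases for the
dnn:: versions, so the complex specializations added to dnn.h are visible
through both namespaces. The compile-time sketch below (illustrative only, not
part of the patch; it assumes the stream_executor headers are on the include
path) shows the unified mapping:

    #include <complex>

    #include "tensorflow/stream_executor/blas.h"

    namespace se = stream_executor;

    // blas::ToDataType now resolves to dnn::ToDataType, so both namespaces
    // agree on the enum value for every supported type, including the new
    // complex entries added by this patch.
    static_assert(se::blas::ToDataType<float>::value ==
                      se::dnn::DataType::kFloat,
                  "blas:: and dnn:: share a single DataType enum");
    static_assert(se::blas::ToDataType<std::complex<float>>::value ==
                      se::dnn::DataType::kComplexFloat,
                  "std::complex<float> maps to the new kComplexFloat");
    static_assert(se::blas::ToDataType<std::complex<double>>::value ==
                      se::dnn::DataType::kComplexDouble,
                  "std::complex<double> maps to the new kComplexDouble");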