Use (and generalize) the existing dnn::DataType enum in the blas:: namespace instead of a duplicate blas::DataType

This commit is contained in:
Ben Barsdell 2020-09-29 16:07:57 +10:00
parent b03ae6de78
commit 6186c09367
5 changed files with 41 additions and 70 deletions

View File

@ -97,19 +97,19 @@ std::ostream& operator<<(std::ostream& os, ComputationType ty) {
string DataTypeString(DataType ty) {
switch (ty) {
case DataType::kF16:
case DataType::kHalf:
return "f16";
case DataType::kF32:
case DataType::kFloat:
return "f32";
case DataType::kF64:
case DataType::kDouble:
return "f64";
case DataType::kI8:
case DataType::kInt8:
return "i8";
case DataType::kI32:
case DataType::kInt32:
return "i32";
case DataType::kComplexF32:
case DataType::kComplexFloat:
return "complex f32";
case DataType::kComplexF64:
case DataType::kComplexDouble:
return "complex f64";
default:
LOG(FATAL) << "Unknown DataType " << static_cast<int32>(ty);

View File

@ -44,6 +44,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/stream_executor/host_or_device_scalar.h"
#include "tensorflow/stream_executor/dnn.h" // For DataType, ToDataType
#include "tensorflow/stream_executor/lib/array_slice.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/platform/port.h"
@ -119,16 +120,8 @@ std::string ComputationTypeString(ComputationType ty);
std::ostream &operator<<(std::ostream &os, ComputationType ty);
// Type with which inputs and outputs of a blaslt routine are performed.
enum class DataType {
kF16, // 16-bit floating-point
kF32, // 32-bit floating-point
kF64, // 64-bit floating-point
kI8, // 8-bit integer
kI32, // 32-bit integer
kComplexF32, // Complex number comprised of two f32s
kComplexF64, // Complex number comprised of two f64s
};
using dnn::DataType;
using dnn::ToDataType;
// Describes the type of pointers for the scaling factors alpha and beta in
// blaslt routines.
@ -142,38 +135,6 @@ string DataTypeString(DataType ty);
std::ostream &operator<<(std::ostream &os, DataType ty);
// Converts a compile-time type to a DataType value.
template <typename T>
struct ToDataType {};
template <>
struct ToDataType<Eigen::half> {
static constexpr const DataType value = DataType::kF16;
};
template <>
struct ToDataType<float> {
static constexpr const DataType value = DataType::kF32;
};
template <>
struct ToDataType<double> {
static constexpr const DataType value = DataType::kF64;
};
template <>
struct ToDataType<int8> {
static constexpr const DataType value = DataType::kI8;
};
template <>
struct ToDataType<int32> {
static constexpr const DataType value = DataType::kI32;
};
template <>
struct ToDataType<std::complex<float>> {
static constexpr const DataType value = DataType::kComplexF32;
};
template <>
struct ToDataType<std::complex<double>> {
static constexpr const DataType value = DataType::kComplexF64;
};
// Opaque identifier for an "algorithm" used by a blas routine. This functions
// as a hint to the blas library.
typedef int64 AlgorithmType;

View File

@ -441,21 +441,21 @@ cublasComputeType_t CUBLASComputationType(blas::ComputationType ty) {
// Returns the data type of the alpha/beta scaling factors required by a
// blaslt matmul with the given input/output `data_type` and `compute_type`.
// Complex inputs require a complex scale factor of matching precision.
// NOTE(review): this diff fragment contained both pre- and post-rename
// enumerator lines; only the post-commit names (kHalf, kFloat, kDouble,
// kInt32, kComplexFloat, kComplexDouble) are kept here.
blas::DataType GetScaleType(blas::DataType data_type,
                            blas::ComputationType compute_type) {
  bool is_complex = data_type == blas::DataType::kComplexFloat ||
                    data_type == blas::DataType::kComplexDouble;
  switch (compute_type) {
    case blas::ComputationType::kF16:
      return blas::DataType::kHalf;
    case blas::ComputationType::kF32:          // fall-through
    case blas::ComputationType::kComplexF32:   // fall-through
    case blas::ComputationType::kF32FastTF32:  // fall-through
    case blas::ComputationType::kF32FastBF16:
      return is_complex ? blas::DataType::kComplexFloat
                        : blas::DataType::kFloat;
    case blas::ComputationType::kF64:  // fall-through
    case blas::ComputationType::kComplexF64:
      return is_complex ? blas::DataType::kComplexDouble
                        : blas::DataType::kDouble;
    case blas::ComputationType::kI32:
      return blas::DataType::kInt32;
  }
}
@ -484,38 +484,38 @@ cublasLtEpilogue_t CUBLASEpilogue(blas::Epilogue epilogue) {
// Maps a blas::DataType enumerator to the corresponding cuBLAS
// cudaDataType_t value (real and complex, 16/32/64-bit float, 8/32-bit int).
// NOTE(review): the diff fragment carried both the old (kF16, kF32, ...) and
// renamed (kHalf, kFloat, ...) case labels; the old enumerators no longer
// exist after this commit, so only the new names are kept.
cudaDataType_t GetCUDADataType(blas::DataType ty) {
  switch (ty) {
    case blas::DataType::kHalf:
      return CUDA_R_16F;
    case blas::DataType::kFloat:
      return CUDA_R_32F;
    case blas::DataType::kDouble:
      return CUDA_R_64F;
    case blas::DataType::kInt8:
      return CUDA_R_8I;
    case blas::DataType::kInt32:
      return CUDA_R_32I;
    case blas::DataType::kComplexFloat:
      return CUDA_C_32F;
    case blas::DataType::kComplexDouble:
      return CUDA_C_64F;
  }
}
// Returns the size in bytes of one element of the given blas::DataType
// (e.g. kHalf -> 2, kComplexDouble -> 16).
// NOTE(review): duplicated pre-rename case labels from the mangled diff
// (kF16, kF32, ...) are dropped; only the post-commit enumerators remain.
int GetDataTypeSizeBytes(blas::DataType ty) {
  switch (ty) {
    case blas::DataType::kHalf:
      return 2;
    case blas::DataType::kFloat:
      return 4;
    case blas::DataType::kDouble:
      return 8;
    case blas::DataType::kInt8:
      return 1;
    case blas::DataType::kInt32:
      return 4;
    case blas::DataType::kComplexFloat:
      return 8;   // two 32-bit floats
    case blas::DataType::kComplexDouble:
      return 16;  // two 64-bit floats
  }
}
@ -3611,7 +3611,7 @@ bool CUDABlas::DoBlasLtMatmul(Stream* stream,
blas::ProfileResult* output_profile_result) {
#if CUDA_VERSION >= 11000
const auto& cuda_plan = *static_cast<const CUDABlasLtMatmulPlan*>(plan);
if (cuda_plan.scale_type() == blas::DataType::kF32) {
if (cuda_plan.scale_type() == blas::DataType::kFloat) {
// F32* computation types require F32 alpha/beta type, so we must cast them.
if (alpha.is_pointer() || beta.is_pointer()) {
// We cannot easily convert a pointer to f16 memory to a pointer to f32

View File

@ -133,6 +133,14 @@ template <>
struct ToDataType<int32> {
static constexpr DataType value = DataType::kInt32;
};
// Compile-time mapping from std::complex<float> to its DataType enumerator,
// mirroring the existing ToDataType specializations for scalar types.
template <>
struct ToDataType<std::complex<float>> {
static constexpr DataType value = DataType::kComplexFloat;
};
// Compile-time mapping from std::complex<double> to its DataType enumerator.
template <>
struct ToDataType<std::complex<double>> {
static constexpr DataType value = DataType::kComplexDouble;
};
// Specifies the types of a RNN model.
enum class RnnMode {

View File

@ -12,6 +12,8 @@ enum DataType {
kHalf = 2;
kInt8 = 3;
kInt32 = 4;
kComplexFloat = 5;
kComplexDouble = 6;
}
// Describes how a convolution input or output layer's data is formatted.