Use (+generalize) existing dnn::DataType in blas::
parent b03ae6de78
commit 6186c09367
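Summary (inferred from the diff below; not part of the original commit message): blas:: previously defined its own DataType enum and ToDataType traits. This commit replaces them with the existing dnn::DataType and dnn::ToDataType, renaming kF16/kF32/kF64/kI8/kI32/kComplexF32/kComplexF64 to kHalf/kFloat/kDouble/kInt8/kInt32/kComplexFloat/kComplexDouble throughout, and generalizes dnn::DataType by adding the two complex values (new ToDataType specializations in dnn.h and new enum values in the proto). The resulting aliasing, as a minimal sketch:

    namespace blas {
    using dnn::DataType;    // kHalf, kFloat, kDouble, kInt8, kInt32,
    using dnn::ToDataType;  // kComplexFloat, kComplexDouble
    }  // namespace blas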
@@ -97,19 +97,19 @@ std::ostream& operator<<(std::ostream& os, ComputationType ty) {
 string DataTypeString(DataType ty) {
   switch (ty) {
-    case DataType::kF16:
+    case DataType::kHalf:
       return "f16";
-    case DataType::kF32:
+    case DataType::kFloat:
       return "f32";
-    case DataType::kF64:
+    case DataType::kDouble:
       return "f64";
-    case DataType::kI8:
+    case DataType::kInt8:
       return "i8";
-    case DataType::kI32:
+    case DataType::kInt32:
       return "i32";
-    case DataType::kComplexF32:
+    case DataType::kComplexFloat:
       return "complex f32";
-    case DataType::kComplexF64:
+    case DataType::kComplexDouble:
       return "complex f64";
     default:
       LOG(FATAL) << "Unknown DataType " << static_cast<int32>(ty);
@@ -44,6 +44,7 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/stream_executor/host_or_device_scalar.h"
+#include "tensorflow/stream_executor/dnn.h"  // For DataType, ToDataType
 #include "tensorflow/stream_executor/lib/array_slice.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
 #include "tensorflow/stream_executor/platform/port.h"
@@ -119,16 +120,8 @@ std::string ComputationTypeString(ComputationType ty);
 
 std::ostream &operator<<(std::ostream &os, ComputationType ty);
 
-// Type with which inputs and outputs of a blaslt routine are performed.
-enum class DataType {
-  kF16,         // 16-bit floating-point
-  kF32,         // 32-bit floating-point
-  kF64,         // 64-bit floating-point
-  kI8,          // 8-bit integer
-  kI32,         // 32-bit integer
-  kComplexF32,  // Complex number comprised of two f32s
-  kComplexF64,  // Complex number comprised of two f64s
-};
+using dnn::DataType;
+using dnn::ToDataType;
 
 // Describes the type of pointers for the scaling factors alpha and beta in
 // blaslt routines.
@@ -142,38 +135,6 @@ string DataTypeString(DataType ty);
 
 std::ostream &operator<<(std::ostream &os, DataType ty);
 
-// Converts a compile-time type to a DataType value.
-template <typename T>
-struct ToDataType {};
-template <>
-struct ToDataType<Eigen::half> {
-  static constexpr const DataType value = DataType::kF16;
-};
-template <>
-struct ToDataType<float> {
-  static constexpr const DataType value = DataType::kF32;
-};
-template <>
-struct ToDataType<double> {
-  static constexpr const DataType value = DataType::kF64;
-};
-template <>
-struct ToDataType<int8> {
-  static constexpr const DataType value = DataType::kI8;
-};
-template <>
-struct ToDataType<int32> {
-  static constexpr const DataType value = DataType::kI32;
-};
-template <>
-struct ToDataType<std::complex<float>> {
-  static constexpr const DataType value = DataType::kComplexF32;
-};
-template <>
-struct ToDataType<std::complex<double>> {
-  static constexpr const DataType value = DataType::kComplexF64;
-};
-
 // Opaque identifier for an "algorithm" used by a blas routine. This functions
 // as a hint to the blas library.
 typedef int64 AlgorithmType;
@@ -441,21 +441,21 @@ cublasComputeType_t CUBLASComputationType(blas::ComputationType ty) {
 
 blas::DataType GetScaleType(blas::DataType data_type,
                             blas::ComputationType compute_type) {
-  bool is_complex = data_type == blas::DataType::kComplexF32 ||
-                    data_type == blas::DataType::kComplexF64;
+  bool is_complex = data_type == blas::DataType::kComplexFloat ||
+                    data_type == blas::DataType::kComplexDouble;
   switch (compute_type) {
     case blas::ComputationType::kF16:
-      return blas::DataType::kF16;
+      return blas::DataType::kHalf;
     case blas::ComputationType::kF32:          // fall-through
     case blas::ComputationType::kComplexF32:   // fall-through
     case blas::ComputationType::kF32FastTF32:  // fall-through
     case blas::ComputationType::kF32FastBF16:
-      return is_complex ? blas::DataType::kComplexF32 : blas::DataType::kF32;
+      return is_complex ? blas::DataType::kComplexFloat : blas::DataType::kFloat;
     case blas::ComputationType::kF64:  // fall-through
     case blas::ComputationType::kComplexF64:
-      return is_complex ? blas::DataType::kComplexF64 : blas::DataType::kF64;
+      return is_complex ? blas::DataType::kComplexDouble : blas::DataType::kDouble;
     case blas::ComputationType::kI32:
-      return blas::DataType::kI32;
+      return blas::DataType::kInt32;
   }
 }
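(Illustrative example, not in the commit: per the logic above, complex input data keeps a complex scale type under the F32-class computation types.)

    // Hypothetical calls demonstrating GetScaleType after the rename:
    blas::DataType s1 = GetScaleType(blas::DataType::kComplexFloat,
                                     blas::ComputationType::kComplexF32);
    // s1 == blas::DataType::kComplexFloat
    blas::DataType s2 = GetScaleType(blas::DataType::kHalf,
                                     blas::ComputationType::kF16);
    // s2 == blas::DataType::kHalf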
@@ -484,38 +484,38 @@ cublasLtEpilogue_t CUBLASEpilogue(blas::Epilogue epilogue) {
 
 cudaDataType_t GetCUDADataType(blas::DataType ty) {
   switch (ty) {
-    case blas::DataType::kF16:
+    case blas::DataType::kHalf:
       return CUDA_R_16F;
-    case blas::DataType::kF32:
+    case blas::DataType::kFloat:
       return CUDA_R_32F;
-    case blas::DataType::kF64:
+    case blas::DataType::kDouble:
       return CUDA_R_64F;
-    case blas::DataType::kI8:
+    case blas::DataType::kInt8:
       return CUDA_R_8I;
-    case blas::DataType::kI32:
+    case blas::DataType::kInt32:
       return CUDA_R_32I;
-    case blas::DataType::kComplexF32:
+    case blas::DataType::kComplexFloat:
       return CUDA_C_32F;
-    case blas::DataType::kComplexF64:
+    case blas::DataType::kComplexDouble:
       return CUDA_C_64F;
   }
 }
 
 int GetDataTypeSizeBytes(blas::DataType ty) {
   switch (ty) {
-    case blas::DataType::kF16:
+    case blas::DataType::kHalf:
       return 2;
-    case blas::DataType::kF32:
+    case blas::DataType::kFloat:
       return 4;
-    case blas::DataType::kF64:
+    case blas::DataType::kDouble:
       return 8;
-    case blas::DataType::kI8:
+    case blas::DataType::kInt8:
       return 1;
-    case blas::DataType::kI32:
+    case blas::DataType::kInt32:
       return 4;
-    case blas::DataType::kComplexF32:
+    case blas::DataType::kComplexFloat:
       return 8;
-    case blas::DataType::kComplexF64:
+    case blas::DataType::kComplexDouble:
       return 16;
   }
 }
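(Illustrative example, not in the commit: the byte sizes returned above match sizeof for the corresponding C++ types.)

    #include <complex>
    // e.g. a complex<float> is two 4-byte floats:
    int n = GetDataTypeSizeBytes(blas::DataType::kComplexFloat);  // n == 8
    static_assert(sizeof(std::complex<float>) == 8, "");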
@@ -3611,7 +3611,7 @@ bool CUDABlas::DoBlasLtMatmul(Stream* stream,
                               blas::ProfileResult* output_profile_result) {
 #if CUDA_VERSION >= 11000
   const auto& cuda_plan = *static_cast<const CUDABlasLtMatmulPlan*>(plan);
-  if (cuda_plan.scale_type() == blas::DataType::kF32) {
+  if (cuda_plan.scale_type() == blas::DataType::kFloat) {
     // F32* computation types require F32 alpha/beta type, so we must cast them.
     if (alpha.is_pointer() || beta.is_pointer()) {
       // We cannot easily convert a pointer to f16 memory to a pointer to f32
@@ -133,6 +133,14 @@ template <>
 struct ToDataType<int32> {
   static constexpr DataType value = DataType::kInt32;
 };
+template <>
+struct ToDataType<std::complex<float>> {
+  static constexpr DataType value = DataType::kComplexFloat;
+};
+template <>
+struct ToDataType<std::complex<double>> {
+  static constexpr DataType value = DataType::kComplexDouble;
+};
 
 // Specifies the types of a RNN model.
 enum class RnnMode {
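(Illustrative example, not in the commit: with the new specializations, the compile-time mapping now covers complex types.)

    static_assert(dnn::ToDataType<std::complex<float>>::value ==
                      dnn::DataType::kComplexFloat,
                  "");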
@@ -12,6 +12,8 @@ enum DataType {
   kHalf = 2;
   kInt8 = 3;
   kInt32 = 4;
+  kComplexFloat = 5;
+  kComplexDouble = 6;
 }
 
 // Describes how a convolution input or output layer's data is formatted.