diff --git a/tensorflow/stream_executor/blas.cc b/tensorflow/stream_executor/blas.cc
index f55e318e88b..ca597595969 100644
--- a/tensorflow/stream_executor/blas.cc
+++ b/tensorflow/stream_executor/blas.cc
@@ -97,19 +97,19 @@ std::ostream& operator<<(std::ostream& os, ComputationType ty) {
 
 string DataTypeString(DataType ty) {
   switch (ty) {
-    case DataType::kF16:
+    case DataType::kHalf:
       return "f16";
-    case DataType::kF32:
+    case DataType::kFloat:
       return "f32";
-    case DataType::kF64:
+    case DataType::kDouble:
       return "f64";
-    case DataType::kI8:
+    case DataType::kInt8:
       return "i8";
-    case DataType::kI32:
+    case DataType::kInt32:
       return "i32";
-    case DataType::kComplexF32:
+    case DataType::kComplexFloat:
       return "complex f32";
-    case DataType::kComplexF64:
+    case DataType::kComplexDouble:
       return "complex f64";
     default:
       LOG(FATAL) << "Unknown DataType " << static_cast<int32>(ty);
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index 411f6f11275..29fa7dbc68e 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -44,6 +44,7 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/stream_executor/host_or_device_scalar.h"
+#include "tensorflow/stream_executor/dnn.h"  // For DataType, ToDataType
 #include "tensorflow/stream_executor/lib/array_slice.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
 #include "tensorflow/stream_executor/platform/port.h"
@@ -119,16 +120,8 @@ std::string ComputationTypeString(ComputationType ty);
 
 std::ostream &operator<<(std::ostream &os, ComputationType ty);
 
-// Type with which inputs and outputs of a blaslt routine are performed.
-enum class DataType {
-  kF16,         // 16-bit floating-point
-  kF32,         // 32-bit floating-point
-  kF64,         // 64-bit floating-point
-  kI8,          // 8-bit integer
-  kI32,         // 32-bit integer
-  kComplexF32,  // Complex number comprised of two f32s
-  kComplexF64,  // Complex number comprised of two f64s
-};
+using dnn::DataType;
+using dnn::ToDataType;
 
 // Describes the type of pointers for the scaling factors alpha and beta in
 // blaslt routines.
@@ -142,38 +135,6 @@ string DataTypeString(DataType ty);
 
 std::ostream &operator<<(std::ostream &os, DataType ty);
 
-// Converts a compile-time type to a DataType value.
-template <typename T>
-struct ToDataType {};
-template <>
-struct ToDataType<Eigen::half> {
-  static constexpr const DataType value = DataType::kF16;
-};
-template <>
-struct ToDataType<float> {
-  static constexpr const DataType value = DataType::kF32;
-};
-template <>
-struct ToDataType<double> {
-  static constexpr const DataType value = DataType::kF64;
-};
-template <>
-struct ToDataType<int8> {
-  static constexpr const DataType value = DataType::kI8;
-};
-template <>
-struct ToDataType<int32> {
-  static constexpr const DataType value = DataType::kI32;
-};
-template <>
-struct ToDataType<std::complex<float>> {
-  static constexpr const DataType value = DataType::kComplexF32;
-};
-template <>
-struct ToDataType<std::complex<double>> {
-  static constexpr const DataType value = DataType::kComplexF64;
-};
-
 // Opaque identifier for an "algorithm" used by a blas routine.  This functions
 // as a hint to the blas library.
 typedef int64 AlgorithmType;
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index f2bc79e1c29..1c97b6db6a3 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -441,21 +441,21 @@ cublasComputeType_t CUBLASComputationType(blas::ComputationType ty) {
 
 blas::DataType GetScaleType(blas::DataType data_type,
                             blas::ComputationType compute_type) {
-  bool is_complex = data_type == blas::DataType::kComplexF32 ||
-                    data_type == blas::DataType::kComplexF64;
+  bool is_complex = data_type == blas::DataType::kComplexFloat ||
+                    data_type == blas::DataType::kComplexDouble;
   switch (compute_type) {
     case blas::ComputationType::kF16:
-      return blas::DataType::kF16;
+      return blas::DataType::kHalf;
     case blas::ComputationType::kF32:          // fall-through
     case blas::ComputationType::kComplexF32:   // fall-through
     case blas::ComputationType::kF32FastTF32:  // fall-through
     case blas::ComputationType::kF32FastBF16:
-      return is_complex ? blas::DataType::kComplexF32 : blas::DataType::kF32;
+      return is_complex ? blas::DataType::kComplexFloat : blas::DataType::kFloat;
     case blas::ComputationType::kF64:  // fall-through
     case blas::ComputationType::kComplexF64:
-      return is_complex ? blas::DataType::kComplexF64 : blas::DataType::kF64;
+      return is_complex ? blas::DataType::kComplexDouble : blas::DataType::kDouble;
     case blas::ComputationType::kI32:
-      return blas::DataType::kI32;
+      return blas::DataType::kInt32;
   }
 }
 
@@ -484,38 +484,38 @@ cublasLtEpilogue_t CUBLASEpilogue(blas::Epilogue epilogue) {
 
 cudaDataType_t GetCUDADataType(blas::DataType ty) {
   switch (ty) {
-    case blas::DataType::kF16:
+    case blas::DataType::kHalf:
       return CUDA_R_16F;
-    case blas::DataType::kF32:
+    case blas::DataType::kFloat:
       return CUDA_R_32F;
-    case blas::DataType::kF64:
+    case blas::DataType::kDouble:
       return CUDA_R_64F;
-    case blas::DataType::kI8:
+    case blas::DataType::kInt8:
       return CUDA_R_8I;
-    case blas::DataType::kI32:
+    case blas::DataType::kInt32:
       return CUDA_R_32I;
-    case blas::DataType::kComplexF32:
+    case blas::DataType::kComplexFloat:
       return CUDA_C_32F;
-    case blas::DataType::kComplexF64:
+    case blas::DataType::kComplexDouble:
       return CUDA_C_64F;
   }
 }
 
 int GetDataTypeSizeBytes(blas::DataType ty) {
   switch (ty) {
-    case blas::DataType::kF16:
+    case blas::DataType::kHalf:
       return 2;
-    case blas::DataType::kF32:
+    case blas::DataType::kFloat:
       return 4;
-    case blas::DataType::kF64:
+    case blas::DataType::kDouble:
       return 8;
-    case blas::DataType::kI8:
+    case blas::DataType::kInt8:
       return 1;
-    case blas::DataType::kI32:
+    case blas::DataType::kInt32:
       return 4;
-    case blas::DataType::kComplexF32:
+    case blas::DataType::kComplexFloat:
       return 8;
-    case blas::DataType::kComplexF64:
+    case blas::DataType::kComplexDouble:
       return 16;
   }
 }
@@ -3611,7 +3611,7 @@ bool CUDABlas::DoBlasLtMatmul(Stream* stream,
                               blas::ProfileResult* output_profile_result) {
 #if CUDA_VERSION >= 11000
   const auto& cuda_plan = *static_cast<const CUDABlasLtMatmulPlan*>(plan);
-  if (cuda_plan.scale_type() == blas::DataType::kF32) {
+  if (cuda_plan.scale_type() == blas::DataType::kFloat) {
     // F32* computation types require F32 alpha/beta type, so we must cast them.
     if (alpha.is_pointer() || beta.is_pointer()) {
       // We cannot easily convert a pointer to f16 memory to a pointer to f32
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 53cdff8cb7a..fd38efc2537 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -133,6 +133,14 @@ template <>
 struct ToDataType<int32> {
   static constexpr DataType value = DataType::kInt32;
 };
+template <>
+struct ToDataType<std::complex<float>> {
+  static constexpr DataType value = DataType::kComplexFloat;
+};
+template <>
+struct ToDataType<std::complex<double>> {
+  static constexpr DataType value = DataType::kComplexDouble;
+};
 
 // Specifies the types of a RNN model.
 enum class RnnMode {
diff --git a/tensorflow/stream_executor/dnn.proto b/tensorflow/stream_executor/dnn.proto
index 4d09e615e7d..f849b011eb3 100644
--- a/tensorflow/stream_executor/dnn.proto
+++ b/tensorflow/stream_executor/dnn.proto
@@ -12,6 +12,8 @@ enum DataType {
   kHalf = 2;
   kInt8 = 3;
   kInt32 = 4;
+  kComplexFloat = 5;
+  kComplexDouble = 6;
 }
 
 // Describes how a convolution input or output layer's data is formatted.