From eabb1453f2efc62a648096c21b2ed993079d8c51 Mon Sep 17 00:00:00 2001 From: nammbash Date: Sun, 14 Jun 2020 11:47:44 -0700 Subject: [PATCH] Fix and Refactor NonAVX512 CPU platform --- tensorflow/core/graph/mkl_graph_util.h | 49 ++++++++++++++++++-------- 1 file changed, 34 insertions(+), 15 deletions(-) diff --git a/tensorflow/core/graph/mkl_graph_util.h b/tensorflow/core/graph/mkl_graph_util.h index cd09ac522d7..3c4c186b791 100644 --- a/tensorflow/core/graph/mkl_graph_util.h +++ b/tensorflow/core/graph/mkl_graph_util.h @@ -126,22 +126,19 @@ inline string GetMklEagerOpName(const string& name) { } #ifdef ENABLE_INTEL_MKL_BFLOAT16 -static inline bool CheckBfloat16Support(DataType T) { - static absl::once_flag cpu_bfloat16_warn_once_flag; - // Restrict bfloat16 ops to platforms with at least AVX512 support, fall back - // to Eigen implementation otherwise. - if (!(port::TestCPUFeature(port::CPUFeature::AVX512F)) && T == DT_BFLOAT16) { - absl::call_once(cpu_bfloat16_warn_once_flag, [] { - LOG(ERROR) - << "oneDNN BFloat16 support are only on platforms with AVX512. " - "Falling back to default implementation if present."; - }); - return false; - } - return true; +static inline bool IsBF16SupportedByOneDNNOnThisCPU() { + return port::TestCPUFeature(port::CPUFeature::AVX512F); } #endif +static inline void BF16UnsupportedWarning() { + static absl::once_flag cpu_bfloat16_warn_once_flag; + absl::call_once(cpu_bfloat16_warn_once_flag, [] { + LOG(ERROR) << "oneDNN BFloat16 support are only on platforms with AVX512. " + "Falling back to default implementation if present."; + }); +} + // Check whether opname with type T is registered as MKL operator // that can accept input tensors in MKL layout. // @@ -159,7 +156,18 @@ static inline bool IsMklLayoutDependentOp(const string& op_name, DataType T) { #ifdef ENABLE_INTEL_MKL_BFLOAT16 // Restrict regular ops to FLOAT and BFLOAT16 if (kernel.find(kMklLayoutDependentOpLabelPattern) != string::npos) { - return (T == DT_FLOAT || CheckBfloat16Support(T)); + if (T == DT_FLOAT) return true; + if (T == DT_BFLOAT16) { + if (IsBF16SupportedByOneDNNOnThisCPU()) { + return true; + } else { + // Restrict bfloat16 ops to platforms with at least AVX512 support, fall + // back to Eigen implementation otherwise. + BF16UnsupportedWarning(); + return false; + } + } + return false; } #else // Restrict regular ops to FLOAT @@ -216,7 +224,18 @@ static inline bool IsMklNameChangeOp(const string& op_name, DataType T) { isTypeAllowed = (T == DT_COMPLEX128 || T == DT_COMPLEX64 || T == DT_DOUBLE || T == DT_FLOAT); #ifdef ENABLE_INTEL_MKL_BFLOAT16 - isTypeAllowed = (isTypeAllowed || CheckBfloat16Support(T)); + if (!isTypeAllowed) { + if (T == DT_BFLOAT16) { + if (IsBF16SupportedByOneDNNOnThisCPU()) { + isTypeAllowed = true; + } else { + // Restrict bfloat16 ops to platforms with at least AVX512 support, + // fall back to Eigen implementation otherwise. + BF16UnsupportedWarning(); + isTypeAllowed = false; + } + } + } #endif return isTypeAllowed; }