Fix and Refactor NonAVX512 CPU platform

This commit is contained in:
nammbash 2020-06-14 11:47:44 -07:00
parent 5a6c606736
commit eabb1453f2
1 changed files with 34 additions and 15 deletions

View File

@ -126,22 +126,19 @@ inline string GetMklEagerOpName(const string& name) {
} }
#ifdef ENABLE_INTEL_MKL_BFLOAT16 #ifdef ENABLE_INTEL_MKL_BFLOAT16
static inline bool CheckBfloat16Support(DataType T) { static inline bool IsBF16SupportedByOneDNNOnThisCPU() {
static absl::once_flag cpu_bfloat16_warn_once_flag; return port::TestCPUFeature(port::CPUFeature::AVX512F);
// Restrict bfloat16 ops to platforms with at least AVX512 support, fall back
// to Eigen implementation otherwise.
if (!(port::TestCPUFeature(port::CPUFeature::AVX512F)) && T == DT_BFLOAT16) {
absl::call_once(cpu_bfloat16_warn_once_flag, [] {
LOG(ERROR)
<< "oneDNN BFloat16 support are only on platforms with AVX512. "
"Falling back to default implementation if present.";
});
return false;
}
return true;
} }
#endif #endif
static inline void BF16UnsupportedWarning() {
static absl::once_flag cpu_bfloat16_warn_once_flag;
absl::call_once(cpu_bfloat16_warn_once_flag, [] {
LOG(ERROR) << "oneDNN BFloat16 support are only on platforms with AVX512. "
"Falling back to default implementation if present.";
});
}
// Check whether opname with type T is registered as MKL operator // Check whether opname with type T is registered as MKL operator
// that can accept input tensors in MKL layout. // that can accept input tensors in MKL layout.
// //
@ -159,7 +156,18 @@ static inline bool IsMklLayoutDependentOp(const string& op_name, DataType T) {
#ifdef ENABLE_INTEL_MKL_BFLOAT16 #ifdef ENABLE_INTEL_MKL_BFLOAT16
// Restrict regular ops to FLOAT and BFLOAT16 // Restrict regular ops to FLOAT and BFLOAT16
if (kernel.find(kMklLayoutDependentOpLabelPattern) != string::npos) { if (kernel.find(kMklLayoutDependentOpLabelPattern) != string::npos) {
return (T == DT_FLOAT || CheckBfloat16Support(T)); if (T == DT_FLOAT) return true;
if (T == DT_BFLOAT16) {
if (IsBF16SupportedByOneDNNOnThisCPU()) {
return true;
} else {
// Restrict bfloat16 ops to platforms with at least AVX512 support, fall
// back to Eigen implementation otherwise.
BF16UnsupportedWarning();
return false;
}
}
return false;
} }
#else #else
// Restrict regular ops to FLOAT // Restrict regular ops to FLOAT
@ -216,7 +224,18 @@ static inline bool IsMklNameChangeOp(const string& op_name, DataType T) {
isTypeAllowed = (T == DT_COMPLEX128 || T == DT_COMPLEX64 || isTypeAllowed = (T == DT_COMPLEX128 || T == DT_COMPLEX64 ||
T == DT_DOUBLE || T == DT_FLOAT); T == DT_DOUBLE || T == DT_FLOAT);
#ifdef ENABLE_INTEL_MKL_BFLOAT16 #ifdef ENABLE_INTEL_MKL_BFLOAT16
isTypeAllowed = (isTypeAllowed || CheckBfloat16Support(T)); if (!isTypeAllowed) {
if (T == DT_BFLOAT16) {
if (IsBF16SupportedByOneDNNOnThisCPU()) {
isTypeAllowed = true;
} else {
// Restrict bfloat16 ops to platforms with at least AVX512 support,
// fall back to Eigen implementation otherwise.
BF16UnsupportedWarning();
isTypeAllowed = false;
}
}
}
#endif #endif
return isTypeAllowed; return isTypeAllowed;
} }