Fix and Refactor NonAVX512 CPU platform

This commit is contained in:
nammbash 2020-06-14 11:47:44 -07:00
parent 5a6c606736
commit eabb1453f2
1 changed files with 34 additions and 15 deletions

View File

@ -126,22 +126,19 @@ inline string GetMklEagerOpName(const string& name) {
}
#ifdef ENABLE_INTEL_MKL_BFLOAT16
static inline bool CheckBfloat16Support(DataType T) {
static absl::once_flag cpu_bfloat16_warn_once_flag;
// Restrict bfloat16 ops to platforms with at least AVX512 support, fall back
// to Eigen implementation otherwise.
if (!(port::TestCPUFeature(port::CPUFeature::AVX512F)) && T == DT_BFLOAT16) {
absl::call_once(cpu_bfloat16_warn_once_flag, [] {
LOG(ERROR)
<< "oneDNN BFloat16 support are only on platforms with AVX512. "
"Falling back to default implementation if present.";
});
return false;
}
return true;
static inline bool IsBF16SupportedByOneDNNOnThisCPU() {
return port::TestCPUFeature(port::CPUFeature::AVX512F);
}
#endif
static inline void BF16UnsupportedWarning() {
static absl::once_flag cpu_bfloat16_warn_once_flag;
absl::call_once(cpu_bfloat16_warn_once_flag, [] {
LOG(ERROR) << "oneDNN BFloat16 support are only on platforms with AVX512. "
"Falling back to default implementation if present.";
});
}
// Check whether opname with type T is registered as MKL operator
// that can accept input tensors in MKL layout.
//
@ -159,7 +156,18 @@ static inline bool IsMklLayoutDependentOp(const string& op_name, DataType T) {
#ifdef ENABLE_INTEL_MKL_BFLOAT16
// Restrict regular ops to FLOAT and BFLOAT16
if (kernel.find(kMklLayoutDependentOpLabelPattern) != string::npos) {
return (T == DT_FLOAT || CheckBfloat16Support(T));
if (T == DT_FLOAT) return true;
if (T == DT_BFLOAT16) {
if (IsBF16SupportedByOneDNNOnThisCPU()) {
return true;
} else {
// Restrict bfloat16 ops to platforms with at least AVX512 support, fall
// back to Eigen implementation otherwise.
BF16UnsupportedWarning();
return false;
}
}
return false;
}
#else
// Restrict regular ops to FLOAT
@ -216,7 +224,18 @@ static inline bool IsMklNameChangeOp(const string& op_name, DataType T) {
isTypeAllowed = (T == DT_COMPLEX128 || T == DT_COMPLEX64 ||
T == DT_DOUBLE || T == DT_FLOAT);
#ifdef ENABLE_INTEL_MKL_BFLOAT16
isTypeAllowed = (isTypeAllowed || CheckBfloat16Support(T));
if (!isTypeAllowed) {
if (T == DT_BFLOAT16) {
if (IsBF16SupportedByOneDNNOnThisCPU()) {
isTypeAllowed = true;
} else {
// Restrict bfloat16 ops to platforms with at least AVX512 support,
// fall back to Eigen implementation otherwise.
BF16UnsupportedWarning();
isTypeAllowed = false;
}
}
}
#endif
return isTypeAllowed;
}