Fix and Refactor NonAVX512 CPU platform
This commit is contained in:
parent
5a6c606736
commit
eabb1453f2
|
@ -126,22 +126,19 @@ inline string GetMklEagerOpName(const string& name) {
|
|||
}
|
||||
|
||||
#ifdef ENABLE_INTEL_MKL_BFLOAT16
|
||||
static inline bool CheckBfloat16Support(DataType T) {
|
||||
static absl::once_flag cpu_bfloat16_warn_once_flag;
|
||||
// Restrict bfloat16 ops to platforms with at least AVX512 support, fall back
|
||||
// to Eigen implementation otherwise.
|
||||
if (!(port::TestCPUFeature(port::CPUFeature::AVX512F)) && T == DT_BFLOAT16) {
|
||||
absl::call_once(cpu_bfloat16_warn_once_flag, [] {
|
||||
LOG(ERROR)
|
||||
<< "oneDNN BFloat16 support are only on platforms with AVX512. "
|
||||
"Falling back to default implementation if present.";
|
||||
});
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
static inline bool IsBF16SupportedByOneDNNOnThisCPU() {
|
||||
return port::TestCPUFeature(port::CPUFeature::AVX512F);
|
||||
}
|
||||
#endif
|
||||
|
||||
static inline void BF16UnsupportedWarning() {
|
||||
static absl::once_flag cpu_bfloat16_warn_once_flag;
|
||||
absl::call_once(cpu_bfloat16_warn_once_flag, [] {
|
||||
LOG(ERROR) << "oneDNN BFloat16 support are only on platforms with AVX512. "
|
||||
"Falling back to default implementation if present.";
|
||||
});
|
||||
}
|
||||
|
||||
// Check whether opname with type T is registered as MKL operator
|
||||
// that can accept input tensors in MKL layout.
|
||||
//
|
||||
|
@ -159,7 +156,18 @@ static inline bool IsMklLayoutDependentOp(const string& op_name, DataType T) {
|
|||
#ifdef ENABLE_INTEL_MKL_BFLOAT16
|
||||
// Restrict regular ops to FLOAT and BFLOAT16
|
||||
if (kernel.find(kMklLayoutDependentOpLabelPattern) != string::npos) {
|
||||
return (T == DT_FLOAT || CheckBfloat16Support(T));
|
||||
if (T == DT_FLOAT) return true;
|
||||
if (T == DT_BFLOAT16) {
|
||||
if (IsBF16SupportedByOneDNNOnThisCPU()) {
|
||||
return true;
|
||||
} else {
|
||||
// Restrict bfloat16 ops to platforms with at least AVX512 support, fall
|
||||
// back to Eigen implementation otherwise.
|
||||
BF16UnsupportedWarning();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
#else
|
||||
// Restrict regular ops to FLOAT
|
||||
|
@ -216,7 +224,18 @@ static inline bool IsMklNameChangeOp(const string& op_name, DataType T) {
|
|||
isTypeAllowed = (T == DT_COMPLEX128 || T == DT_COMPLEX64 ||
|
||||
T == DT_DOUBLE || T == DT_FLOAT);
|
||||
#ifdef ENABLE_INTEL_MKL_BFLOAT16
|
||||
isTypeAllowed = (isTypeAllowed || CheckBfloat16Support(T));
|
||||
if (!isTypeAllowed) {
|
||||
if (T == DT_BFLOAT16) {
|
||||
if (IsBF16SupportedByOneDNNOnThisCPU()) {
|
||||
isTypeAllowed = true;
|
||||
} else {
|
||||
// Restrict bfloat16 ops to platforms with at least AVX512 support,
|
||||
// fall back to Eigen implementation otherwise.
|
||||
BF16UnsupportedWarning();
|
||||
isTypeAllowed = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return isTypeAllowed;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue