Fix and Refactor NonAVX512 CPU platform
This commit is contained in:
parent
5a6c606736
commit
eabb1453f2
|
@ -126,22 +126,19 @@ inline string GetMklEagerOpName(const string& name) {
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef ENABLE_INTEL_MKL_BFLOAT16
|
#ifdef ENABLE_INTEL_MKL_BFLOAT16
|
||||||
static inline bool CheckBfloat16Support(DataType T) {
|
static inline bool IsBF16SupportedByOneDNNOnThisCPU() {
|
||||||
static absl::once_flag cpu_bfloat16_warn_once_flag;
|
return port::TestCPUFeature(port::CPUFeature::AVX512F);
|
||||||
// Restrict bfloat16 ops to platforms with at least AVX512 support, fall back
|
|
||||||
// to Eigen implementation otherwise.
|
|
||||||
if (!(port::TestCPUFeature(port::CPUFeature::AVX512F)) && T == DT_BFLOAT16) {
|
|
||||||
absl::call_once(cpu_bfloat16_warn_once_flag, [] {
|
|
||||||
LOG(ERROR)
|
|
||||||
<< "oneDNN BFloat16 support are only on platforms with AVX512. "
|
|
||||||
"Falling back to default implementation if present.";
|
|
||||||
});
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
static inline void BF16UnsupportedWarning() {
|
||||||
|
static absl::once_flag cpu_bfloat16_warn_once_flag;
|
||||||
|
absl::call_once(cpu_bfloat16_warn_once_flag, [] {
|
||||||
|
LOG(ERROR) << "oneDNN BFloat16 support are only on platforms with AVX512. "
|
||||||
|
"Falling back to default implementation if present.";
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// Check whether opname with type T is registered as MKL operator
|
// Check whether opname with type T is registered as MKL operator
|
||||||
// that can accept input tensors in MKL layout.
|
// that can accept input tensors in MKL layout.
|
||||||
//
|
//
|
||||||
|
@ -159,7 +156,18 @@ static inline bool IsMklLayoutDependentOp(const string& op_name, DataType T) {
|
||||||
#ifdef ENABLE_INTEL_MKL_BFLOAT16
|
#ifdef ENABLE_INTEL_MKL_BFLOAT16
|
||||||
// Restrict regular ops to FLOAT and BFLOAT16
|
// Restrict regular ops to FLOAT and BFLOAT16
|
||||||
if (kernel.find(kMklLayoutDependentOpLabelPattern) != string::npos) {
|
if (kernel.find(kMklLayoutDependentOpLabelPattern) != string::npos) {
|
||||||
return (T == DT_FLOAT || CheckBfloat16Support(T));
|
if (T == DT_FLOAT) return true;
|
||||||
|
if (T == DT_BFLOAT16) {
|
||||||
|
if (IsBF16SupportedByOneDNNOnThisCPU()) {
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
// Restrict bfloat16 ops to platforms with at least AVX512 support, fall
|
||||||
|
// back to Eigen implementation otherwise.
|
||||||
|
BF16UnsupportedWarning();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
// Restrict regular ops to FLOAT
|
// Restrict regular ops to FLOAT
|
||||||
|
@ -216,7 +224,18 @@ static inline bool IsMklNameChangeOp(const string& op_name, DataType T) {
|
||||||
isTypeAllowed = (T == DT_COMPLEX128 || T == DT_COMPLEX64 ||
|
isTypeAllowed = (T == DT_COMPLEX128 || T == DT_COMPLEX64 ||
|
||||||
T == DT_DOUBLE || T == DT_FLOAT);
|
T == DT_DOUBLE || T == DT_FLOAT);
|
||||||
#ifdef ENABLE_INTEL_MKL_BFLOAT16
|
#ifdef ENABLE_INTEL_MKL_BFLOAT16
|
||||||
isTypeAllowed = (isTypeAllowed || CheckBfloat16Support(T));
|
if (!isTypeAllowed) {
|
||||||
|
if (T == DT_BFLOAT16) {
|
||||||
|
if (IsBF16SupportedByOneDNNOnThisCPU()) {
|
||||||
|
isTypeAllowed = true;
|
||||||
|
} else {
|
||||||
|
// Restrict bfloat16 ops to platforms with at least AVX512 support,
|
||||||
|
// fall back to Eigen implementation otherwise.
|
||||||
|
BF16UnsupportedWarning();
|
||||||
|
isTypeAllowed = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
return isTypeAllowed;
|
return isTypeAllowed;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue