Merge pull request #38881 from Intel-tensorflow:sriniva2/remapper_test
PiperOrigin-RevId: 308965176 Change-Id: Ifd4d80a1e475b6858ba31db72a55271279ab5321
This commit is contained in:
commit
af52ea9135
|
@ -219,19 +219,22 @@ bool HasDataType(const NodeDef* node, const DataType& expected,
|
|||
bool IsCpuCompatibleDataType(const NodeDef* contraction,
|
||||
const string& type_attr = "T") {
|
||||
DataType dtype = GetDataTypeFromAttr(*contraction, type_attr);
|
||||
#if defined(INTEL_MKL) && defined(ENABLE_INTEL_MKL_BFLOAT16)
|
||||
if (IsConv2D(*contraction)) {
|
||||
return dtype == DT_FLOAT || dtype == DT_BFLOAT16;
|
||||
} else if (IsDepthwiseConv2dNative(*contraction)) {
|
||||
return dtype == DT_FLOAT || dtype == DT_BFLOAT16;
|
||||
} else if (IsMatMul(*contraction)) {
|
||||
#if defined(INTEL_MKL)
|
||||
#if defined(ENABLE_INTEL_MKL_BFLOAT16)
|
||||
if (IsConv2D(*contraction) || IsDepthwiseConv2dNative(*contraction) ||
|
||||
IsMatMul(*contraction)) {
|
||||
return dtype == DT_FLOAT || dtype == DT_BFLOAT16;
|
||||
#else
|
||||
if (IsConv2D(*contraction) || IsDepthwiseConv2dNative(*contraction) ||
|
||||
IsMatMul(*contraction)) {
|
||||
return dtype == DT_FLOAT;
|
||||
#endif // ENABLE_INTEL_MKL_BFLOAT16
|
||||
#else
|
||||
if (IsConv2D(*contraction)) {
|
||||
return dtype == DT_FLOAT || dtype == DT_DOUBLE;
|
||||
} else if (IsMatMul(*contraction)) {
|
||||
return dtype == DT_FLOAT;
|
||||
#endif // INTEL_MKL && ENABLE_INTEL_MKL_BFLOAT16
|
||||
#endif // INTEL_MKL
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue