Merge pull request #47511 from Intel-tensorflow:sbin_fix_bf16
PiperOrigin-RevId: 361208129
Change-Id: I76f3f531102e9b361ba98b95bf54e8335689aafc
commit e0a81a4ee7
@@ -1123,13 +1123,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
    DataType T_m;
    TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));

#ifndef ENABLE_INTEL_MKL_BFLOAT16
    // Don't try to merge if datatype is not DT_FLOAT
    if (T_m != DT_FLOAT) return n;
#else
    // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16
    if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n;
#endif

    if (m->type_string() == csinfo_.bias_add) {
      // If a is BiasAdd, then Conv2D is 0th input of BiasAdd.
@@ -1168,13 +1163,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
    DataType T_m;
    TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));

#ifndef ENABLE_INTEL_MKL_BFLOAT16
    // Don't try to merge if datatype is not DT_FLOAT
    if (T_m != DT_FLOAT) return n;
#else
    // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16
    if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n;
#endif

    const Node* conv_node;
    if (m->type_string() == csinfo_.pad) {
@@ -1291,13 +1281,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
    DataType T_m;
    TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));

#ifndef ENABLE_INTEL_MKL_BFLOAT16
    // Don't try to merge if datatype is not DT_FLOAT
    if (T_m != DT_FLOAT) return n;
#else
    // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16
    if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n;
#endif

    if (m->type_string() == csinfo_.bias_add_grad) {
      // Get 1st input 'g' of BiasAddGrad.
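All three hunks above touch the same data-type gate in the node-merge pass: a merge candidate is skipped unless its "T" attribute is DT_FLOAT or DT_BFLOAT16, and the compile-time ENABLE_INTEL_MKL_BFLOAT16 split collapses into that single runtime check. A minimal standalone sketch of the pattern; the DataType enum and the IsMergeCandidateType helper are stand-ins for illustration, not TensorFlow's definitions:

#include <iostream>

// Stand-in for tensorflow::DataType; only the values the gate cares about.
enum class DataType { DT_FLOAT, DT_BFLOAT16, DT_INT32 };

// One runtime check replaces the old #ifndef/#else compile-time split:
// merging is attempted only for float and bfloat16 nodes.
bool IsMergeCandidateType(DataType t) {
  return t == DataType::DT_FLOAT || t == DataType::DT_BFLOAT16;
}

int main() {
  std::cout << IsMergeCandidateType(DataType::DT_FLOAT) << "\n";     // 1
  std::cout << IsMergeCandidateType(DataType::DT_BFLOAT16) << "\n";  // 1
  std::cout << IsMergeCandidateType(DataType::DT_INT32) << "\n";     // 0
  return 0;
}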
@@ -175,7 +175,6 @@ REGISTER_OP("QInt8Input").Output("o: qint8").SetIsStateful();
REGISTER_OP("QUInt8Input").Output("o: quint8").SetIsStateful();
REGISTER_OP("QInt32Input").Output("o: qint32").SetIsStateful();

#ifdef ENABLE_INTEL_MKL_BFLOAT16
REGISTER_OP("BFloat16Input").Output("o: bfloat16").SetIsStateful();
REGISTER_OP("BFloat16InputList")
    .Output("o: N * bfloat16")
@@ -185,7 +184,6 @@ REGISTER_OP("BFloat16Output2")
    .Input("i: bfloat16")
    .Input("i1: bfloat16")
    .SetIsStateful();
#endif // ENABLE_INTEL_MKL_BFLOAT16

/////////////////////////////////////////////////////////////////////
// Unit tests related to node merge optimization
@@ -797,11 +795,9 @@ REGISTER_TEST_ALL_TYPES(NodeMerge_PadWithConv2D_Common_Input);
}
REGISTER_TEST(NodeMerge_PadWithConv2D_Common_InOutput, DT_FLOAT, Float32Input,
              Float32Output2);
#ifdef ENABLE_INTEL_MKL_BFLOAT16
// TODO(nhasabni): Enable bfloat16 test when we enable the operator.
REGISTER_TEST(NodeMerge_PadWithConv2D_Common_InOutput, DT_BFLOAT16,
              BFloat16Input, BFloat16Output2);
#endif
#undef REGISTER_TEST

// Pad + Conv2D; padding is SAME
@@ -2451,11 +2447,9 @@ REGISTER_TEST_ALL_TYPES(Output_ControlEdge_PadWithFusedConv2D_Positive);
}
REGISTER_TEST(NodeMerge_PadWithFusedConv2D_Common_InOutput, DT_FLOAT,
              Float32Input, Float32Output2);
#ifdef ENABLE_INTEL_MKL_BFLOAT16
// TODO(nhasabni): Enable bfloat16 test when we enable the operator.
REGISTER_TEST(NodeMerge_PadWithFusedConv2D_Common_InOutput, DT_BFLOAT16,
              BFloat16Input, BFloat16Output2);
#endif
#undef REGISTER_TEST

#define REGISTER_TEST(NAME, T, INPUT) \
@@ -176,11 +176,9 @@ inline string GetMklEagerOpName(const string& name) {
  return string(kMklEagerOpPrefix) + name;
}

#ifdef ENABLE_INTEL_MKL_BFLOAT16
static inline bool IsBF16SupportedByOneDNNOnThisCPU() {
  return port::TestCPUFeature(port::CPUFeature::AVX512F);
}
#endif

static inline void BF16UnsupportedWarning() {
  static absl::once_flag cpu_bfloat16_warn_once_flag;
@@ -204,7 +202,6 @@ static inline bool IsMklLayoutDependentOp(const string& op_name, DataType T) {
  if (kernel.find(kMklQuantizedOpLabelPattern) != string::npos) {
    return (T == DT_QUINT8 || T == DT_QINT8 || T == DT_QINT32);
  }
#ifdef ENABLE_INTEL_MKL_BFLOAT16
  // Restrict regular ops to FLOAT and BFLOAT16
  if (kernel.find(kMklLayoutDependentOpLabelPattern) != string::npos) {
    if (T == DT_FLOAT) return true;
@@ -220,12 +217,6 @@ static inline bool IsMklLayoutDependentOp(const string& op_name, DataType T) {
    }
    return false;
  }
#else
  // Restrict regular ops to FLOAT
  if (kernel.find(kMklLayoutDependentOpLabelPattern) != string::npos) {
    return (T == DT_FLOAT);
  }
#endif // ENABLE_INTEL_MKL_BFLOAT16
  return false;
}

@@ -274,7 +265,6 @@ static inline bool IsMklNameChangeOp(const string& op_name, DataType T) {
  if (kernel.find(search_string) != string::npos) {
    isTypeAllowed = (T == DT_COMPLEX128 || T == DT_COMPLEX64 ||
                     T == DT_DOUBLE || T == DT_FLOAT);
#ifdef ENABLE_INTEL_MKL_BFLOAT16
    if (!isTypeAllowed) {
      if (T == DT_BFLOAT16) {
        if (IsBF16SupportedByOneDNNOnThisCPU()) {
@@ -287,7 +277,6 @@ static inline bool IsMklNameChangeOp(const string& op_name, DataType T) {
        }
      }
    }
#endif
    return isTypeAllowed;
  }

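In this header, IsBF16SupportedByOneDNNOnThisCPU() reduces bfloat16 eligibility to a CPU-feature probe (AVX-512F), and BF16UnsupportedWarning() uses a function-local absl::once_flag so the fallback warning is logged at most once. A self-contained sketch of that gate-plus-warn-once pattern; CpuSupportsAvx512F() is a hypothetical stand-in for port::TestCPUFeature, and std::call_once stands in for the absl equivalent:

#include <iostream>
#include <mutex>  // std::once_flag / std::call_once

// Hypothetical stand-in for port::TestCPUFeature(port::CPUFeature::AVX512F).
bool CpuSupportsAvx512F() { return false; }

// Warn only once, no matter how many ops ask about bfloat16 support.
void BF16UnsupportedWarning() {
  static std::once_flag warn_once;
  std::call_once(warn_once, [] {
    std::cerr << "oneDNN bfloat16 kernels need AVX-512; falling back.\n";
  });
}

// Allow a bfloat16 rewrite only when the CPU can actually run it.
bool AllowBfloat16Rewrite() {
  if (CpuSupportsAvx512F()) return true;
  BF16UnsupportedWarning();
  return false;
}

int main() {
  for (int i = 0; i < 3; ++i) {
    std::cout << "rewrite allowed: " << AllowBfloat16Rewrite() << "\n";
  }
  return 0;  // the warning above prints exactly once
}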
@@ -309,7 +309,6 @@ CREATE_CONV2DFUSION_ADD_BCAST_TEST(AddV2);
REGISTER_TEST_ALL_TYPES(FuseDepthwiseConv2DWithBiasAndActivation);
#undef REGISTER_TEST

#ifdef ENABLE_MKLDNN_V1
TEST_F(MklRemapperTest, FuseBatchNormWithRelu) {
  using ::tensorflow::ops::Placeholder;

@@ -524,7 +523,6 @@ TEST_F(MklRemapperTest, FuseMatMulWithBiasAddAndAdd) {
  EXPECT_EQ(1, tensors.size());
  test::ExpectClose(tensors_expected[0], tensors[0], 0, 1e-6);
}
#endif // ENABLE_MKLDNN_V1

} // namespace grappler
} // namespace tensorflow
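The remapper tests compare the fused graph's output against the reference output with test::ExpectClose, passing tolerances of 0 and 1e-6. A standalone sketch of that kind of mixed absolute/relative tolerance check; AllClose below is an illustration, not TensorFlow's implementation:

#include <cmath>
#include <cstdio>
#include <vector>

// Illustrative elementwise check: values match if the difference stays within
// atol + rtol * |expected|, the usual mixed absolute/relative criterion.
bool AllClose(const std::vector<double>& expected,
              const std::vector<double>& actual, double atol, double rtol) {
  if (expected.size() != actual.size()) return false;
  for (std::size_t i = 0; i < expected.size(); ++i) {
    if (std::fabs(expected[i] - actual[i]) >
        atol + rtol * std::fabs(expected[i])) {
      return false;
    }
  }
  return true;
}

int main() {
  std::vector<double> reference = {1.0, 2.0, 3.0};
  std::vector<double> fused = {1.0000005, 2.000001, 3.0000008};
  // Mirrors the spirit of a close-comparison with atol = 0 and rtol = 1e-6.
  std::printf("close: %d\n", AllClose(reference, fused, 0.0, 1e-6));
  return 0;
}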
@@ -18,6 +18,7 @@ limitations under the License.
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/l2loss_op.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -55,11 +56,11 @@ class L2LossOp<CPUDevice, T> : public OpKernel {
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);
REGISTER_KERNEL(Eigen::half);
#ifdef ENABLE_INTEL_MKL_BFLOAT16
#ifdef INTEL_MKL
// Since Eigen backend does not support bfloat16 ops, we are selectively
// enabling them for MKL backend.
REGISTER_KERNEL(bfloat16);
#endif
#endif // INTEL_MKL
#undef REGISTER_KERNEL

} // namespace tensorflow
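REGISTER_KERNEL stamps out one L2Loss CPU kernel registration per element type, and the surrounding guard decides at build time whether the bfloat16 instantiation (oneDNN-only, since the Eigen backend lacks bfloat16 kernels) is compiled in. A self-contained sketch of the macro-per-type pattern; the L2Loss function and the MYLIB_HAS_BF16 flag are illustrative stand-ins:

#include <cstdio>

// Toggle at build time, e.g. -DMYLIB_HAS_BF16 (illustrative stand-in for the
// guard around the bfloat16 registration).
// #define MYLIB_HAS_BF16

// Stand-in "kernel": half the sum of squares, the L2 loss, per element type.
template <typename T>
double L2Loss(const T* data, int n) {
  double acc = 0.0;
  for (int i = 0; i < n; ++i) acc += static_cast<double>(data[i]) * data[i];
  return acc / 2.0;
}

// One explicit instantiation per supported type, mirroring REGISTER_KERNEL(T).
#define REGISTER_KERNEL(T) template double L2Loss<T>(const T*, int);
REGISTER_KERNEL(float)
REGISTER_KERNEL(double)
#ifdef MYLIB_HAS_BF16
// A real build would instantiate the bfloat16 kernel here.
#endif
#undef REGISTER_KERNEL

int main() {
  float x[] = {1.0f, 2.0f, 3.0f};
  std::printf("l2 loss: %f\n", L2Loss(x, 3));  // (1 + 4 + 9) / 2 = 7
  return 0;
}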
@@ -996,11 +996,7 @@ memory::data_type MklDnnType<qint32>() {
}
template <>
memory::data_type MklDnnType<bfloat16>() {
#ifdef ENABLE_INTEL_MKL_BFLOAT16
  return memory::data_type::bf16;
#else
  return memory::data_type::f32;
#endif
}

// Map MklTensorFormat to MKL-DNN format tag
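MklDnnType<T>() maps a C++ element type to a oneDNN memory::data_type through template specialization; this hunk drops the guarded f32 fallback so bfloat16 maps straight to bf16. A self-contained sketch of the same specialization pattern; the DnnType enum and the bfloat16 struct are stand-ins, not oneDNN's or TensorFlow's types:

#include <cstdio>

// Stand-ins for oneDNN's memory::data_type and TensorFlow's bfloat16.
enum class DnnType { f32, bf16, s32 };
struct bfloat16 { unsigned short value; };

// Primary template is declared but never defined: unsupported element types
// fail at link time instead of silently mapping to something wrong.
template <typename T>
DnnType MklDnnType();

// One specialization per supported element type.
template <>
DnnType MklDnnType<float>() { return DnnType::f32; }
template <>
DnnType MklDnnType<int>() { return DnnType::s32; }
template <>
DnnType MklDnnType<bfloat16>() { return DnnType::bf16; }  // no f32 fallback

int main() {
  std::printf("float    -> %d\n", static_cast<int>(MklDnnType<float>()));
  std::printf("bfloat16 -> %d\n", static_cast<int>(MklDnnType<bfloat16>()));
  return 0;
}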
@@ -2034,7 +2030,6 @@ inline bool IsConv1x1StrideNot1(memory::dims filter_dims,

#define REGISTER_TEST_FLOAT32(TEST) REGISTER_TEST(TEST, DT_FLOAT, Float32Input);

#ifdef ENABLE_INTEL_MKL_BFLOAT16
#define REGISTER_TEST_BFLOAT16(TEST) \
  REGISTER_TEST(TEST, DT_BFLOAT16, BFloat16Input);

@@ -2043,7 +2038,6 @@ inline bool IsConv1x1StrideNot1(memory::dims filter_dims,
  REGISTER_TEST_BFLOAT16(TEST);
#else
#define REGISTER_TEST_ALL_TYPES(TEST) REGISTER_TEST_FLOAT32(TEST);
#endif // ENABLE_INTEL_MKL_BFLOAT16

#endif // INTEL_MKL
#endif // TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
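The REGISTER_TEST_* helpers fan one test name out to a float32 and, when bfloat16 is available, a bfloat16 instantiation. A minimal expansion sketch; the printf-based REGISTER_TEST stands in for the real macro, which defines a typed graph-rewrite test:

#include <cstdio>

// Stand-in for the real REGISTER_TEST(NAME, T, INPUT); here it just records
// which instantiation would be generated.
#define REGISTER_TEST(NAME, T, INPUT) \
  std::printf("registered %s for %s using %s\n", #NAME, #T, #INPUT);

#define REGISTER_TEST_FLOAT32(TEST) REGISTER_TEST(TEST, DT_FLOAT, Float32Input)
#define REGISTER_TEST_BFLOAT16(TEST) \
  REGISTER_TEST(TEST, DT_BFLOAT16, BFloat16Input)

// One test name fans out to both element types.
#define REGISTER_TEST_ALL_TYPES(TEST) \
  REGISTER_TEST_FLOAT32(TEST)         \
  REGISTER_TEST_BFLOAT16(TEST)

int main() {
  REGISTER_TEST_ALL_TYPES(NodeMerge_Conv2DWithBias);
  return 0;
}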
@@ -38,7 +38,6 @@ load(
)
load(
    "//third_party/mkl_dnn:build_defs.bzl",
    "if_mkl_open_source_only",
    "if_mkldnn_openmp",
)
load("@bazel_skylib//lib:new_sets.bzl", "sets")
@@ -1449,7 +1448,7 @@ def tf_gpu_library(deps = None, cuda_deps = None, copts = tf_copts(), **kwargs):
    ]) + if_rocm_is_configured([
        "@local_config_rocm//rocm:rocm_headers",
    ]),
    copts = (copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_rocm(["-DTENSORFLOW_USE_ROCM=1"]) + if_xla_available(["-DTENSORFLOW_USE_XLA=1"]) + if_mkl(["-DINTEL_MKL=1"]) + if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) + if_enable_mkl(["-DENABLE_MKL"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"])),
    copts = (copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_rocm(["-DTENSORFLOW_USE_ROCM=1"]) + if_xla_available(["-DTENSORFLOW_USE_XLA=1"]) + if_mkl(["-DINTEL_MKL=1"]) + if_enable_mkl(["-DENABLE_MKL"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"])),
    **kwargs
)