Merge pull request #47511 from Intel-tensorflow:sbin_fix_bf16

PiperOrigin-RevId: 361208129
Change-Id: I76f3f531102e9b361ba98b95bf54e8335689aafc
This commit is contained in:
TensorFlower Gardener 2021-03-05 13:10:08 -08:00
commit e0a81a4ee7
7 changed files with 4 additions and 44 deletions

View File

@ -1123,13 +1123,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
DataType T_m; DataType T_m;
TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m)); TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
#ifndef ENABLE_INTEL_MKL_BFLOAT16
// Don't try to merge if datatype is not DT_FLOAT
if (T_m != DT_FLOAT) return n;
#else
// Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16 // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16
if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n; if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n;
#endif
if (m->type_string() == csinfo_.bias_add) { if (m->type_string() == csinfo_.bias_add) {
// If a is BiasAdd, then Conv2D is 0th input of BiasAdd. // If a is BiasAdd, then Conv2D is 0th input of BiasAdd.
@ -1168,13 +1163,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
DataType T_m; DataType T_m;
TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m)); TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
#ifndef ENABLE_INTEL_MKL_BFLOAT16
// Don't try to merge if datatype is not DT_FLOAT
if (T_m != DT_FLOAT) return n;
#else
// Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16 // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16
if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n; if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n;
#endif
const Node* conv_node; const Node* conv_node;
if (m->type_string() == csinfo_.pad) { if (m->type_string() == csinfo_.pad) {
@ -1291,13 +1281,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
DataType T_m; DataType T_m;
TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m)); TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
#ifndef ENABLE_INTEL_MKL_BFLOAT16
// Don't try to merge if datatype is not DT_FLOAT
if (T_m != DT_FLOAT) return n;
#else
// Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16 // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16
if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n; if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n;
#endif
if (m->type_string() == csinfo_.bias_add_grad) { if (m->type_string() == csinfo_.bias_add_grad) {
// Get 1st input 'g' of BiasAddGrad. // Get 1st input 'g' of BiasAddGrad.

View File

@ -175,7 +175,6 @@ REGISTER_OP("QInt8Input").Output("o: qint8").SetIsStateful();
REGISTER_OP("QUInt8Input").Output("o: quint8").SetIsStateful(); REGISTER_OP("QUInt8Input").Output("o: quint8").SetIsStateful();
REGISTER_OP("QInt32Input").Output("o: qint32").SetIsStateful(); REGISTER_OP("QInt32Input").Output("o: qint32").SetIsStateful();
#ifdef ENABLE_INTEL_MKL_BFLOAT16
REGISTER_OP("BFloat16Input").Output("o: bfloat16").SetIsStateful(); REGISTER_OP("BFloat16Input").Output("o: bfloat16").SetIsStateful();
REGISTER_OP("BFloat16InputList") REGISTER_OP("BFloat16InputList")
.Output("o: N * bfloat16") .Output("o: N * bfloat16")
@ -185,7 +184,6 @@ REGISTER_OP("BFloat16Output2")
.Input("i: bfloat16") .Input("i: bfloat16")
.Input("i1: bfloat16") .Input("i1: bfloat16")
.SetIsStateful(); .SetIsStateful();
#endif // ENABLE_INTEL_MKL_BFLOAT16
///////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////
// Unit tests related to node merge optimization // Unit tests related to node merge optimization
@ -797,11 +795,9 @@ REGISTER_TEST_ALL_TYPES(NodeMerge_PadWithConv2D_Common_Input);
} }
REGISTER_TEST(NodeMerge_PadWithConv2D_Common_InOutput, DT_FLOAT, Float32Input, REGISTER_TEST(NodeMerge_PadWithConv2D_Common_InOutput, DT_FLOAT, Float32Input,
Float32Output2); Float32Output2);
#ifdef ENABLE_INTEL_MKL_BFLOAT16
// TODO(nhasabni): Enable bfloat16 test when we enable the operator. // TODO(nhasabni): Enable bfloat16 test when we enable the operator.
REGISTER_TEST(NodeMerge_PadWithConv2D_Common_InOutput, DT_BFLOAT16, REGISTER_TEST(NodeMerge_PadWithConv2D_Common_InOutput, DT_BFLOAT16,
BFloat16Input, BFloat16Output2); BFloat16Input, BFloat16Output2);
#endif
#undef REGISTER_TEST #undef REGISTER_TEST
// Pad + Conv2D; padding is SAME // Pad + Conv2D; padding is SAME
@ -2451,11 +2447,9 @@ REGISTER_TEST_ALL_TYPES(Output_ControlEdge_PadWithFusedConv2D_Positive);
} }
REGISTER_TEST(NodeMerge_PadWithFusedConv2D_Common_InOutput, DT_FLOAT, REGISTER_TEST(NodeMerge_PadWithFusedConv2D_Common_InOutput, DT_FLOAT,
Float32Input, Float32Output2); Float32Input, Float32Output2);
#ifdef ENABLE_INTEL_MKL_BFLOAT16
// TODO(nhasabni): Enable bfloat16 test when we enable the operator. // TODO(nhasabni): Enable bfloat16 test when we enable the operator.
REGISTER_TEST(NodeMerge_PadWithFusedConv2D_Common_InOutput, DT_BFLOAT16, REGISTER_TEST(NodeMerge_PadWithFusedConv2D_Common_InOutput, DT_BFLOAT16,
BFloat16Input, BFloat16Output2); BFloat16Input, BFloat16Output2);
#endif
#undef REGISTER_TEST #undef REGISTER_TEST
#define REGISTER_TEST(NAME, T, INPUT) \ #define REGISTER_TEST(NAME, T, INPUT) \

View File

@ -176,11 +176,9 @@ inline string GetMklEagerOpName(const string& name) {
return string(kMklEagerOpPrefix) + name; return string(kMklEagerOpPrefix) + name;
} }
#ifdef ENABLE_INTEL_MKL_BFLOAT16
static inline bool IsBF16SupportedByOneDNNOnThisCPU() { static inline bool IsBF16SupportedByOneDNNOnThisCPU() {
return port::TestCPUFeature(port::CPUFeature::AVX512F); return port::TestCPUFeature(port::CPUFeature::AVX512F);
} }
#endif
static inline void BF16UnsupportedWarning() { static inline void BF16UnsupportedWarning() {
static absl::once_flag cpu_bfloat16_warn_once_flag; static absl::once_flag cpu_bfloat16_warn_once_flag;
@ -204,7 +202,6 @@ static inline bool IsMklLayoutDependentOp(const string& op_name, DataType T) {
if (kernel.find(kMklQuantizedOpLabelPattern) != string::npos) { if (kernel.find(kMklQuantizedOpLabelPattern) != string::npos) {
return (T == DT_QUINT8 || T == DT_QINT8 || T == DT_QINT32); return (T == DT_QUINT8 || T == DT_QINT8 || T == DT_QINT32);
} }
#ifdef ENABLE_INTEL_MKL_BFLOAT16
// Restrict regular ops to FLOAT and BFLOAT16 // Restrict regular ops to FLOAT and BFLOAT16
if (kernel.find(kMklLayoutDependentOpLabelPattern) != string::npos) { if (kernel.find(kMklLayoutDependentOpLabelPattern) != string::npos) {
if (T == DT_FLOAT) return true; if (T == DT_FLOAT) return true;
@ -220,12 +217,6 @@ static inline bool IsMklLayoutDependentOp(const string& op_name, DataType T) {
} }
return false; return false;
} }
#else
// Restrict regular ops to FLOAT
if (kernel.find(kMklLayoutDependentOpLabelPattern) != string::npos) {
return (T == DT_FLOAT);
}
#endif // ENABLE_INTEL_MKL_BFLOAT16
return false; return false;
} }
@ -274,7 +265,6 @@ static inline bool IsMklNameChangeOp(const string& op_name, DataType T) {
if (kernel.find(search_string) != string::npos) { if (kernel.find(search_string) != string::npos) {
isTypeAllowed = (T == DT_COMPLEX128 || T == DT_COMPLEX64 || isTypeAllowed = (T == DT_COMPLEX128 || T == DT_COMPLEX64 ||
T == DT_DOUBLE || T == DT_FLOAT); T == DT_DOUBLE || T == DT_FLOAT);
#ifdef ENABLE_INTEL_MKL_BFLOAT16
if (!isTypeAllowed) { if (!isTypeAllowed) {
if (T == DT_BFLOAT16) { if (T == DT_BFLOAT16) {
if (IsBF16SupportedByOneDNNOnThisCPU()) { if (IsBF16SupportedByOneDNNOnThisCPU()) {
@ -287,7 +277,6 @@ static inline bool IsMklNameChangeOp(const string& op_name, DataType T) {
} }
} }
} }
#endif
return isTypeAllowed; return isTypeAllowed;
} }

View File

@ -309,7 +309,6 @@ CREATE_CONV2DFUSION_ADD_BCAST_TEST(AddV2);
REGISTER_TEST_ALL_TYPES(FuseDepthwiseConv2DWithBiasAndActivation); REGISTER_TEST_ALL_TYPES(FuseDepthwiseConv2DWithBiasAndActivation);
#undef REGISTER_TEST #undef REGISTER_TEST
#ifdef ENABLE_MKLDNN_V1
TEST_F(MklRemapperTest, FuseBatchNormWithRelu) { TEST_F(MklRemapperTest, FuseBatchNormWithRelu) {
using ::tensorflow::ops::Placeholder; using ::tensorflow::ops::Placeholder;
@ -524,7 +523,6 @@ TEST_F(MklRemapperTest, FuseMatMulWithBiasAddAndAdd) {
EXPECT_EQ(1, tensors.size()); EXPECT_EQ(1, tensors.size());
test::ExpectClose(tensors_expected[0], tensors[0], 0, 1e-6); test::ExpectClose(tensors_expected[0], tensors[0], 0, 1e-6);
} }
#endif // ENABLE_MKLDNN_V1
} // namespace grappler } // namespace grappler
} // namespace tensorflow } // namespace tensorflow

View File

@ -18,6 +18,7 @@ limitations under the License.
#define EIGEN_USE_THREADS #define EIGEN_USE_THREADS
#include "tensorflow/core/kernels/l2loss_op.h" #include "tensorflow/core/kernels/l2loss_op.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
@ -55,11 +56,11 @@ class L2LossOp<CPUDevice, T> : public OpKernel {
REGISTER_KERNEL(float); REGISTER_KERNEL(float);
REGISTER_KERNEL(double); REGISTER_KERNEL(double);
REGISTER_KERNEL(Eigen::half); REGISTER_KERNEL(Eigen::half);
#ifdef ENABLE_INTEL_MKL_BFLOAT16 #ifdef INTEL_MKL
// Since Eigen backend does not support bfloat16 ops, we are selectively // Since Eigen backend does not support bfloat16 ops, we are selectively
// enabling them for MKL backend. // enabling them for MKL backend.
REGISTER_KERNEL(bfloat16); REGISTER_KERNEL(bfloat16);
#endif #endif // INTEL_MKL
#undef REGISTER_KERNEL #undef REGISTER_KERNEL
} // namespace tensorflow } // namespace tensorflow

View File

@ -996,11 +996,7 @@ memory::data_type MklDnnType<qint32>() {
} }
template <> template <>
memory::data_type MklDnnType<bfloat16>() { memory::data_type MklDnnType<bfloat16>() {
#ifdef ENABLE_INTEL_MKL_BFLOAT16
return memory::data_type::bf16; return memory::data_type::bf16;
#else
return memory::data_type::f32;
#endif
} }
// Map MklTensorFormat to MKL-DNN format tag // Map MklTensorFormat to MKL-DNN format tag
@ -2034,7 +2030,6 @@ inline bool IsConv1x1StrideNot1(memory::dims filter_dims,
#define REGISTER_TEST_FLOAT32(TEST) REGISTER_TEST(TEST, DT_FLOAT, Float32Input); #define REGISTER_TEST_FLOAT32(TEST) REGISTER_TEST(TEST, DT_FLOAT, Float32Input);
#ifdef ENABLE_INTEL_MKL_BFLOAT16
#define REGISTER_TEST_BFLOAT16(TEST) \ #define REGISTER_TEST_BFLOAT16(TEST) \
REGISTER_TEST(TEST, DT_BFLOAT16, BFloat16Input); REGISTER_TEST(TEST, DT_BFLOAT16, BFloat16Input);
@ -2043,7 +2038,6 @@ inline bool IsConv1x1StrideNot1(memory::dims filter_dims,
REGISTER_TEST_BFLOAT16(TEST); REGISTER_TEST_BFLOAT16(TEST);
#else #else
#define REGISTER_TEST_ALL_TYPES(TEST) REGISTER_TEST_FLOAT32(TEST); #define REGISTER_TEST_ALL_TYPES(TEST) REGISTER_TEST_FLOAT32(TEST);
#endif // ENABLE_INTEL_MKL_BFLOAT16
#endif // INTEL_MKL #endif // INTEL_MKL
#endif // TENSORFLOW_CORE_UTIL_MKL_UTIL_H_ #endif // TENSORFLOW_CORE_UTIL_MKL_UTIL_H_

View File

@ -38,7 +38,6 @@ load(
) )
load( load(
"//third_party/mkl_dnn:build_defs.bzl", "//third_party/mkl_dnn:build_defs.bzl",
"if_mkl_open_source_only",
"if_mkldnn_openmp", "if_mkldnn_openmp",
) )
load("@bazel_skylib//lib:new_sets.bzl", "sets") load("@bazel_skylib//lib:new_sets.bzl", "sets")
@ -1449,7 +1448,7 @@ def tf_gpu_library(deps = None, cuda_deps = None, copts = tf_copts(), **kwargs):
]) + if_rocm_is_configured([ ]) + if_rocm_is_configured([
"@local_config_rocm//rocm:rocm_headers", "@local_config_rocm//rocm:rocm_headers",
]), ]),
copts = (copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_rocm(["-DTENSORFLOW_USE_ROCM=1"]) + if_xla_available(["-DTENSORFLOW_USE_XLA=1"]) + if_mkl(["-DINTEL_MKL=1"]) + if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) + if_enable_mkl(["-DENABLE_MKL"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"])), copts = (copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_rocm(["-DTENSORFLOW_USE_ROCM=1"]) + if_xla_available(["-DTENSORFLOW_USE_XLA=1"]) + if_mkl(["-DINTEL_MKL=1"]) + if_enable_mkl(["-DENABLE_MKL"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"])),
**kwargs **kwargs
) )