diff --git a/tensorflow/core/common_runtime/mkl_layout_pass.cc b/tensorflow/core/common_runtime/mkl_layout_pass.cc
index f6e42fc7e8c..778d5445cb2 100644
--- a/tensorflow/core/common_runtime/mkl_layout_pass.cc
+++ b/tensorflow/core/common_runtime/mkl_layout_pass.cc
@@ -682,12 +682,15 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     rinfo_.push_back(
         {csinfo_.requantize, mkl_op_registry::GetMklOpName(csinfo_.requantize),
          CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
+#ifdef ENABLE_MKLDNN_V1
+    // Optimized TanhGrad support exists only in DNNL 1.x.
     rinfo_.push_back({csinfo_.tanh, mkl_op_registry::GetMklOpName(csinfo_.tanh),
                       CopyAttrsAll, AlwaysRewrite,
                       kRewriteForLayoutPropagation});
     rinfo_.push_back(
         {csinfo_.tanh_grad, mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
          CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
+#endif  // ENABLE_MKLDNN_V1
     rinfo_.push_back(
         {csinfo_.reshape, mkl_op_registry::GetMklOpName(csinfo_.reshape),
          CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
diff --git a/tensorflow/core/common_runtime/mkl_layout_pass_test.cc b/tensorflow/core/common_runtime/mkl_layout_pass_test.cc
index 9971f6c5d7e..d480c0a49ce 100644
--- a/tensorflow/core/common_runtime/mkl_layout_pass_test.cc
+++ b/tensorflow/core/common_runtime/mkl_layout_pass_test.cc
@@ -3024,6 +3024,8 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_LeakyReluLeakyReluGrad_Positive);
 // clang-format on
 
 // clang-format off
+#ifdef ENABLE_MKLDNN_V1
+
 #define REGISTER_TEST(NAME, T, INPUT)                                      \
   TEST_F(MklLayoutPassTest, NAME##_##T) {                                  \
     DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);   \
@@ -3081,6 +3083,7 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_TanhGrad_Positive);
   }
 REGISTER_TEST_ALL_TYPES(NodeRewrite_TanhTanhGrad_Positive);
 #undef REGISTER_TEST
+#endif  // ENABLE_MKLDNN_V1
 // clang-format on
 
 TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_Positive) {
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc
index 6d79b8f3282..70aa1e937d3 100644
--- a/tensorflow/core/kernels/mkl_relu_op.cc
+++ b/tensorflow/core/kernels/mkl_relu_op.cc
@@ -19,7 +19,6 @@ limitations under the License.
 #include <unordered_map>
 
 #include "mkldnn.hpp"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -27,6 +26,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/util/mkl_types.h"
 #include "tensorflow/core/util/mkl_util.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 using mkldnn::algorithm;
 using mkldnn::eltwise_forward;
@@ -269,7 +269,7 @@ class MklEltwiseBwdParams {
 
   MklEltwiseBwdParams(const memory::dims& src_dims,
                       const memory::desc& common_md, algorithm alg_kind,
-                      float alpha, float beta, int forward_input_type)
+                      float alpha, float beta, int forward_input_type = -1)
       : src_dims(src_dims),
         common_md(common_md),
         alg_kind(alg_kind),
@@ -644,7 +644,10 @@ class MklReluGradOpBase : public OpKernel {
   virtual int GetDiffSrcIndex() const { return 0; }
   // What is the type of input tensor that grad op receives from forward op --
   // is it 'x' (SRC) or 'y' (DST). For Relu-family, it is 'x', so fwd op SRC.
+
+#ifdef ENABLE_MKLDNN_V1
   virtual int GetTypeOfInputTensorFromFwdOp() const { return MKLDNN_ARG_SRC; }
+#endif
 
   void Compute(OpKernelContext* context) {
     try {
@@ -736,8 +739,16 @@ class MklReluGradOpBase : public OpKernel {
         common_md = src_md;
       }
 
+#ifdef ENABLE_MKLDNN_V1
       MklEltwiseBwdParams<T> bwdParams(src_dims, common_md, alg_kind, alpha_,
                                        beta_, GetTypeOfInputTensorFromFwdOp());
+#else
+      // MKLDNN v0 does not support reusing the output of the forward op in
+      // the backward pass, so this optimization works only with MKLDNN v1.
+      MklEltwiseBwdParams<T> bwdParams(src_dims, common_md, alg_kind, alpha_,
+                                       beta_);
+#endif  // ENABLE_MKLDNN_V1
+
       MklEltwiseBwdPrimitive<T>* eltwise_bwd =
           MklEltwiseBwdPrimitiveFactory<T>::Get(bwdParams);
 
@@ -962,6 +973,11 @@ class MklEluGradOp
   }
 };
 
+#ifdef ENABLE_MKLDNN_V1
+// Optimized TanhGrad support exists only in DNNL 1.x
+// (eltwise_tanh_use_dst_for_bwd). TanhGrad could still be supported with
+// DNNL 0.x, but it would not be optimized, so it is disabled there.
+
 template <typename Device, typename T>
 class MklTanhOp : public MklReluOpBase<Device, T, ALGORITHM::eltwise_tanh> {
  public:
@@ -1043,6 +1059,7 @@ class MklTanhGradOp
         (static_cast<T*>(user_g))[0] * (static_cast<T>(1) - tanh * tanh);
   }
 };
+#endif  // ENABLE_MKLDNN_V1
 
 #define RELU6_UPPER_BOUND 6.0f
 template <typename Device, typename T>
@@ -1227,6 +1244,7 @@ TF_CALL_bfloat16(REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES);
 TF_CALL_float(REGISTER_ELU_MKL_SUPPORTED_KERNELS_TYPES);
 TF_CALL_bfloat16(REGISTER_ELU_MKL_SUPPORTED_KERNELS_TYPES);
 
+#ifdef ENABLE_MKLDNN_V1
 #define REGISTER_TANH_MKL_SUPPORTED_KERNELS_TYPES(type) \
   REGISTER_KERNEL_BUILDER(                              \
       Name("_MklTanh")                                  \
@@ -1242,6 +1260,7 @@ TF_CALL_bfloat16(REGISTER_ELU_MKL_SUPPORTED_KERNELS_TYPES);
       MklTanhGradOp<CPUDevice, type>);
 TF_CALL_float(REGISTER_TANH_MKL_SUPPORTED_KERNELS_TYPES);
 TF_CALL_bfloat16(REGISTER_TANH_MKL_SUPPORTED_KERNELS_TYPES);
+#endif
 
 #define REGISTER_RELU6_MKL_SUPPORTED_KERNELS_TYPES(type) \
   REGISTER_KERNEL_BUILDER(                               \
diff --git a/tensorflow/core/kernels/mkl_relu_op_test.cc b/tensorflow/core/kernels/mkl_relu_op_test.cc
index d1fdf7ab4ae..86d7f979c1f 100644
--- a/tensorflow/core/kernels/mkl_relu_op_test.cc
+++ b/tensorflow/core/kernels/mkl_relu_op_test.cc
@@ -15,8 +15,8 @@ limitations under the License.
 
 #ifdef INTEL_MKL
 
-#include "mkldnn.hpp"
 #include "absl/strings/match.h"
+#include "mkldnn.hpp"
 #include "tensorflow/cc/ops/const_op.h"
 #include "tensorflow/cc/ops/nn_ops.h"
 #include "tensorflow/cc/ops/standard_ops.h"
@@ -121,8 +121,11 @@ static Graph* Activation(const string& op_name, const string& kind,
   BM(OP, 32, 64, 128, 256, cpu);  \
   BM(OP, 33, 65, 129, 257, cpu);
 
+#ifdef ENABLE_MKLDNN_V1
+// Optimized MKLDNN TanhGrad support exists only in DNNL 1.x.
 TEST_ALL_SIZES(Tanh)
 TEST_ALL_SIZES(TanhGrad)
+#endif  // ENABLE_MKLDNN_V1
 TEST_ALL_SIZES(Relu)
 TEST_ALL_SIZES(ReluGrad)
 TEST_ALL_SIZES(Elu)
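Note (editorial, not part of the patch): the DNNL 1.x-only optimization this patch guards with ENABLE_MKLDNN_V1 rests on eltwise_tanh_use_dst_for_bwd, which lets the backward primitive consume the forward output y = tanh(x) (DST) rather than the forward input x (SRC), since d/dx tanh(x) = 1 - y^2. The standalone sketch below is illustration only, not TensorFlow or oneDNN code; it mirrors the (1 - tanh * tanh) factor visible in the MklTanhGradOp fallback above and shows that the gradient computed from DST matches the one recomputed from SRC.

// Standalone illustration; not TensorFlow or oneDNN code.
#include <cmath>
#include <cstdio>

int main() {
  const float x = 0.7f;          // forward input (SRC)
  const float y = std::tanh(x);  // forward output (DST)
  const float g = 1.5f;          // incoming gradient dL/dy

  // Backward from SRC: recompute tanh(x) before forming the derivative,
  // as a backward pass without use_dst_for_bwd would have to do.
  const float grad_from_src = g * (1.0f - std::tanh(x) * std::tanh(x));

  // Backward from DST: reuse y directly, avoiding the extra tanh evaluation.
  const float grad_from_dst = g * (1.0f - y * y);

  // Both expressions yield the same value.
  std::printf("from SRC: %f, from DST: %f\n", grad_from_src, grad_from_dst);
  return 0;
}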