From edd248c42a2c2cbde01deb0ca38efdc131c24793 Mon Sep 17 00:00:00 2001
From: Niranjan Hasabnis
Date: Thu, 16 Jul 2020 17:28:22 -0700
Subject: [PATCH 1/2] [Intel MKL] Enabling BatchNorm test in AutoMPMkl

This PR makes two fixes: one in BatchNormGrad and one in ReLU.

The BatchNormGrad fix ensures that both inputs of BatchNormGrad use the
same memory layout. Per the DNNL documentation, mismatched layouts lead
to sub-optimal BatchNormGrad performance.

The ReLU fix ensures that when the input of ReLU is reordered from the
native layout into a blocked layout, the output of ReLU carries the
correct meta tensor.

The last change enables one previously disabled unit test in
AutoMixedPrecisionMkl.
---
 .../core/kernels/mkl_fused_batch_norm_op.cc  |  45 +++-
 .../kernels/mkl_fused_batch_norm_op_test.cc  | 235 +++++++++++++++++-
 tensorflow/core/kernels/mkl_relu_op.cc       |  52 ++--
 .../grappler/auto_mixed_precision_test.py    |   5 +-
 4 files changed, 306 insertions(+), 31 deletions(-)

diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
index 3b2c4f84039..ca15f25a594 100644
--- a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
@@ -14,7 +14,6 @@ limitations under the License.
 ==============================================================================*/
 #ifdef INTEL_MKL
 #include "mkldnn.hpp"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -24,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/util/mkl_types.h"
 #include "tensorflow/core/util/mkl_util.h"
 #include "tensorflow/core/util/tensor_format.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 #define GET_FLAG(bn_flag) static_cast<int>(BN_FLAGS::bn_flag)
 #define IS_SET(cflag) (context_.flags & GET_FLAG(cflag))
@@ -1327,6 +1327,33 @@ class MklFusedBatchNormGradOp : public OpKernel {
               ? dnn_shape_diff_dst.GetMklLayout()
               : memory::desc(diff_dst_dims, MklDnnType<T>(), dnn_fmt);
 
+      MklDnnData<T> reorder_src(&cpu_engine_);
+      MklDnnData<T> reorder_diff_dst(&cpu_engine_);
+      T* diff_dst_data =
+          static_cast<T*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
+      T* src_data =
+          static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
+
+#ifdef ENABLE_MKLDNN_V1
+      // Ensure that src and diff_dst are in same blocked memory layout.
+      // As per MKL-DNN doc, this will lead to faster perf.
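+      // (Editor's note, not part of the original patch.) The block below
+      // covers the two mixed-layout cases: if only src is in an MKL blocked
+      // layout, diff_dst is reordered to src's layout; if only diff_dst is
+      // blocked, src is reordered to diff_dst's layout. In both branches the
+      // corresponding memory::desc and data pointer are updated so that the
+      // backward primitive sees matching descriptors.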
+ if (dnn_shape_src.IsMklTensor() && !dnn_shape_diff_dst.IsMklTensor()) { + reorder_diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor); + reorder_diff_dst.CheckReorderToOpMem( + MEMORY_PD_WITHOUT_DATA(src_md, cpu_engine_), context); + diff_dst_md = src_md; + diff_dst_data = + static_cast(reorder_diff_dst.GetOpMem().get_data_handle()); + } else if (!dnn_shape_src.IsMklTensor() && + dnn_shape_diff_dst.IsMklTensor()) { + reorder_src.SetUsrMem(src_md, &src_tensor); + reorder_src.CheckReorderToOpMem( + MEMORY_PD_WITHOUT_DATA(diff_dst_md, cpu_engine_), context); + src_md = diff_dst_md; + src_data = static_cast(reorder_src.GetOpMem().get_data_handle()); + } +#endif // ENABLE_MKLDNN_V1 + // weights -- MKL DNN packs scales/ shifts as weights in order // of scale, ..., scale, shift, ...., shift weights.AllocateBuffer(2 * depth_ * sizeof(U)); @@ -1350,20 +1377,24 @@ class MklFusedBatchNormGradOp : public OpKernel { MklFusedBatchNormBwdPrimitive* bn_bwd = MklFusedBatchNormBwdPrimitiveFactory::Get(bwdParams); - const T* src_data = src_tensor.flat().data(); - const T* diff_dst_data = diff_dst_tensor.flat().data(); // Check if diff_dst input needs to be reordered std::shared_ptr bn_bwd_pd = bn_bwd->GetBatchNormBwdPd(); if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, bn_bwd_pd, bn_bwd)) { - diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor); + diff_dst.SetUsrMem(diff_dst_md, diff_dst_data); diff_dst.CheckReorderToOpMem( MEMORY_PD_WITHOUT_DATA(GET_DIFF_DST_DESC_FROM_OP_PD(bn_bwd_pd), cpu_engine_), context); diff_dst_data = static_cast(diff_dst.GetOpMem().get_data_handle()); - } else { - diff_dst_data = - static_cast(const_cast(diff_dst_tensor.flat().data())); + } + + if (IS_SRC_REORDER_NEEDED(src_md, bn_bwd_pd, bn_bwd)) { + src.SetUsrMem(src_md, src_data); + src.CheckReorderToOpMem( + MEMORY_PD_WITHOUT_DATA(GET_SRC_DESC_FROM_OP_PD(bn_bwd_pd), + cpu_engine_), + context); + src_data = static_cast(src.GetOpMem().get_data_handle()); } // Indices of output tensors diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op_test.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op_test.cc index d97d70fdd81..f151cee53a2 100644 --- a/tensorflow/core/kernels/mkl_fused_batch_norm_op_test.cc +++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op_test.cc @@ -46,6 +46,12 @@ using GraphRunner = std::function; +using GraphRunnerGrad = std::function; + template class CommonTestUtilities : public OpsTestBase { public: @@ -118,10 +124,94 @@ class CommonTestUtilities : public OpsTestBase { test::ExpectClose(batch_var, mkl_batch_var, 1e-5); } + static void VerifyTensorsCloseForGrad(const float epsilon, + const GraphRunnerGrad& run, + const GraphRunnerGrad& run_mkl) { + int batch = 2; + int height = 8; + int width = 8; + int depth = 1; + DataType dtype = DataTypeToEnum::v(); + + Tensor input(dtype, {batch, height, width, depth}); + input.flat() = input.flat().template setRandom(); + Tensor filter(dtype, {3, 3, 1, 6}); + filter.flat() = filter.flat().template setRandom(); + + Tensor y_backprop(dtype, {batch, 8, 8, 6}); + y_backprop.flat() = + y_backprop.flat().template setRandom(); + Tensor scale(dtype, {6}); + scale.flat() = scale.flat().template setRandom(); + Tensor mean(dtype, {6}); + mean.flat() = mean.flat().template setRandom(); + Tensor variance(dtype, {6}); + variance.flat() = + variance.flat().template setRandom().abs(); + Tensor res_sp3(dtype, {6}); + res_sp3.flat() = + res_sp3.flat().template setRandom().abs(); + + Tensor output; + Tensor scale_backprop; + Tensor offset_backprop; + Tensor mkl_output; + Tensor 
mkl_scale_backprop; + Tensor mkl_offset_backprop; + + run(input, filter, y_backprop, scale, mean, variance, res_sp3, &output, + &scale_backprop, &offset_backprop, epsilon); + + run_mkl(input, filter, y_backprop, scale, mean, variance, res_sp3, + &mkl_output, &mkl_scale_backprop, &mkl_offset_backprop, epsilon); + + ASSERT_EQ(output.dtype(), mkl_output.dtype()); + ASSERT_EQ(output.shape(), mkl_output.shape()); + ASSERT_EQ(scale_backprop.dtype(), mkl_scale_backprop.dtype()); + ASSERT_EQ(scale_backprop.shape(), mkl_scale_backprop.shape()); + ASSERT_EQ(offset_backprop.dtype(), mkl_offset_backprop.dtype()); + ASSERT_EQ(offset_backprop.shape(), mkl_offset_backprop.shape()); + + test::ExpectClose(output, mkl_output, 1e-5); + test::ExpectClose(scale_backprop, mkl_scale_backprop, 1e-5); + test::ExpectClose(offset_backprop, mkl_offset_backprop, 1e-5); + } + private: using random_gen_ = Eigen::internal::NormalRandomGenerator; }; +template +class Conv2DOpTest : public OpsTestBase { + void TestBody() {} + + public: + void RunConv2D(const Tensor& input, const Tensor& filter, Tensor* output, + Tensor* meta_output) { + DataType dtype = DataTypeToEnum::v(); + + TF_EXPECT_OK(NodeDefBuilder("MklConv2D", "_MklConv2D") + .Input(FakeInput(dtype)) + .Input(FakeInput(dtype)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Attr("strides", {1, 1, 1, 1}) + .Attr("padding", "SAME") + .Attr("data_format", "NHWC") + .Attr("_kernel", "MklLayoutDependentOp") + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + AddInputFromArray(input.shape(), input.flat()); + AddInputFromArray(filter.shape(), filter.flat()); + for (int i = 0; i < 2; ++i) + AddInputFromArray(dummy_shape, dummy_tensor); + TF_ASSERT_OK(RunOpKernel()); + + *output = *GetOutput(0); + *meta_output = *GetOutput(2); + } +}; + template class FusedBatchNormOpTest : public OpsTestBase { protected: @@ -198,11 +288,11 @@ class FusedBatchNormOpTest : public OpsTestBase { .Finalize(node_def())); TF_EXPECT_OK(InitOp()); - AddInputFromArray(input.shape(), input.flat()); - AddInputFromArray(scale.shape(), scale.flat()); - AddInputFromArray(offset.shape(), offset.flat()); - AddInputFromArray(mean.shape(), mean.flat()); - AddInputFromArray(variance.shape(), variance.flat()); + AddInputFromArray(input.shape(), input.flat()); + AddInputFromArray(scale.shape(), scale.flat()); + AddInputFromArray(offset.shape(), offset.flat()); + AddInputFromArray(mean.shape(), mean.flat()); + AddInputFromArray(variance.shape(), variance.flat()); for (int i = 0; i < 5; ++i) AddInputFromArray(dummy_shape, dummy_tensor); TF_ASSERT_OK(RunOpKernel()); @@ -222,6 +312,133 @@ class FusedBatchNormOpTest : public OpsTestBase { CommonTestUtilities::VerifyTensorsClose(exponential_avg_factor, is_training, run, run_mkl); } + + void VerifyFusedBatchNormGradWithConv2D(const float epsilon) { + const GraphRunnerGrad run = + [this](const Tensor& input, const Tensor& filter, + const Tensor& y_backprop, const Tensor& scale, + const Tensor& mean, const Tensor& variance, + const Tensor& res_sp3, Tensor* x_backprop_tensor, + Tensor* scale_backprop_tensor, Tensor* offset_backprop_tensor, + const float epsilon) { + auto root = tensorflow::Scope::NewRootScope(); + + auto input_op = + ops::Const(root.WithOpName("input"), Input::Initializer(input)); + auto filter_op = + ops::Const(root.WithOpName("filter"), Input::Initializer(filter)); + ops::Conv2D::Attrs conv_attr; + conv_attr = conv_attr.DataFormat("NHWC"); + auto conv = ops::Conv2D(root.WithOpName("Conv"), input_op, filter_op, + {1, 1, 1, 1}, "SAME", 
conv_attr); + // ------------------------------------------------------------- + auto y_backprop_op = ops::Const(root.WithOpName("y_backprop"), + Input::Initializer(y_backprop)); + auto scale_op = + ops::Const(root.WithOpName("scale"), Input::Initializer(scale)); + auto mean_op = + ops::Const(root.WithOpName("mean"), Input::Initializer(mean)); + auto var_op = ops::Const(root.WithOpName("variance"), + Input::Initializer(variance)); + auto res_sp3_op = ops::Const(root.WithOpName("reserve_space_3"), + Input::Initializer(res_sp3)); + ops::FusedBatchNormGradV3::Attrs bn_attr; + bn_attr = bn_attr.IsTraining(true); + bn_attr = bn_attr.Epsilon(epsilon); + bn_attr = bn_attr.DataFormat("NHWC"); + auto bn = ops::FusedBatchNormGradV3( + root.WithOpName("FusedBatchNormGrad"), y_backprop_op, conv, + scale_op, mean_op, var_op, res_sp3_op, bn_attr); + + auto x_backprop = + ops::Identity(root.WithOpName("x_backprop"), bn.x_backprop); + auto scale_backprop = ops::Identity(root.WithOpName("scale_backprop"), + bn.scale_backprop); + auto offset_backprop = ops::Identity( + root.WithOpName("offset_backprop"), bn.offset_backprop); + + tensorflow::GraphDef graph; + TF_ASSERT_OK(root.ToGraphDef(&graph)); + + tensorflow::SessionOptions session_options; + std::unique_ptr session( + tensorflow::NewSession(session_options)); + TF_ASSERT_OK(session->Create(graph)); + + std::vector output_tensors; + TF_ASSERT_OK(session->Run( + {}, {"x_backprop", "scale_backprop", "offset_backprop"}, {}, + &output_tensors)); + + *x_backprop_tensor = output_tensors[0]; + *scale_backprop_tensor = output_tensors[1]; + *offset_backprop_tensor = output_tensors[2]; + }; + + const GraphRunnerGrad run_mkl = + [this](const Tensor& input, const Tensor& filter, + const Tensor& y_backprop, const Tensor& scale, + const Tensor& mean, const Tensor& variance, + const Tensor& res_sp3, Tensor* x_backprop_tensor, + Tensor* scale_backprop_tensor, Tensor* offset_backprop_tensor, + const float epsilon) { + Tensor conv2d_output, conv2d_meta_output; + Conv2DOpTest conv2d_test; + conv2d_test.RunConv2D(input, filter, &conv2d_output, + &conv2d_meta_output); + + DataType dtype = DataTypeToEnum::v(); + TF_EXPECT_OK( + NodeDefBuilder("MklFusedBatchNorm", "_MklFusedBatchNormGradV3") + .Input(FakeInput(dtype)) + .Input(FakeInput(dtype)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Attr("epsilon", epsilon) + .Attr("is_training", true) + .Attr("data_format", "NHWC") + .Attr("_kernel", "MklLayoutDependentOp") + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + + AddInputFromArray(y_backprop.shape(), y_backprop.flat()); + AddInputFromArray(conv2d_output.shape(), conv2d_output.flat()); + AddInputFromArray(scale.shape(), scale.flat()); + AddInputFromArray(mean.shape(), mean.flat()); + AddInputFromArray(variance.shape(), variance.flat()); + AddInputFromArray(res_sp3.shape(), res_sp3.flat()); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(conv2d_meta_output.shape(), + conv2d_meta_output.flat()); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + TF_ASSERT_OK(RunOpKernel()); + + CommonTestUtilities test_util; + test_util.PerformConversion(dtype, *GetOutput(0), 
*GetOutput(5), + x_backprop_tensor); + + CommonTestUtilities test_util_mean; + test_util_mean.PerformConversion(dtype, *GetOutput(1), *GetOutput(6), + scale_backprop_tensor); + + CommonTestUtilities test_util_var; + test_util_var.PerformConversion(dtype, *GetOutput(2), *GetOutput(7), + offset_backprop_tensor); + }; + + CommonTestUtilities::VerifyTensorsCloseForGrad(epsilon, run, run_mkl); + } }; TYPED_TEST_SUITE_P(FusedBatchNormOpTest); @@ -250,8 +467,14 @@ TYPED_TEST_P(FusedBatchNormOpTest, InferenceIgnoreAvgFactor) { this->VerifyFusedBatchNorm(exponential_avg_factor, is_training); } +TYPED_TEST_P(FusedBatchNormOpTest, FusedBatchNormGradV3) { + const float epsilon = 0.001; + this->VerifyFusedBatchNormGradWithConv2D(epsilon); +} + REGISTER_TYPED_TEST_SUITE_P(FusedBatchNormOpTest, Training, TrainingRunningMean, - Inference, InferenceIgnoreAvgFactor); + Inference, InferenceIgnoreAvgFactor, + FusedBatchNormGradV3); using FusedBatchNormDataTypes = ::testing::Types; INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedBatchNormOpTest, diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc index 5d52742d558..e090d6e62f7 100644 --- a/tensorflow/core/kernels/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl_relu_op.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include "mkldnn.hpp" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -27,6 +26,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/util/mkl_types.h" #include "tensorflow/core/util/mkl_util.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" using mkldnn::algorithm; using mkldnn::eltwise_forward; @@ -566,6 +566,7 @@ class MklReluOpBase : public OpKernel { std::shared_ptr fwd_cpu_stream; fwd_cpu_stream.reset(CreateStream(context, eltwise_fwd->GetEngine())); // Check if src needs to be reordered + bool is_src_reordered = false; const T* src_data = src_tensor.flat().data(); if (IS_SRC_REORDER_NEEDED(src_md, eltwise_fwd_pd, eltwise_fwd)) { src.SetUsrMem(src_md, &src_tensor); @@ -575,27 +576,48 @@ class MklReluOpBase : public OpKernel { context); src_data = const_cast( reinterpret_cast(src.GetOpMem().get_data_handle())); + is_src_reordered = true; } - // Allocate dst tensor, always set it as MKL-DNN layout - if (dnn_shape_src.IsMklTensor()) { + + // If src is reordered, then dst tensor would be in blocked layout. + // So we propagate this blocked layout on the output. We follow same + // logic when src is in blocked (MKL) layout to start of with also. + if (is_src_reordered || dnn_shape_src.IsMklTensor()) { dnn_shape_dst.SetMklTensor(true); auto dst_pd = eltwise_fwd_pd->PRIMITIVE_DESC_DST; dnn_shape_dst.SetMklLayout(&dst_pd); dnn_shape_dst.SetElemType(MklDnnType()); - dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(), - dnn_shape_src.GetSizesAsMklDnnDims(), - dnn_shape_src.GetTfDataFormat()); + if (dnn_shape_src.IsMklTensor()) { + dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(), + dnn_shape_src.GetSizesAsMklDnnDims(), + dnn_shape_src.GetTfDataFormat()); + } else { + dnn_shape_dst.SetTfLayout(src_tensor.dims(), + TFShapeToMklDnnDims(src_tensor.shape()), + MKL_TENSOR_FORMAT_BLOCKED); + } tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T)); } else { + // If src is not in blocked layout or it is not reordered, then dst is + // in native layout. 
dnn_shape_dst.SetMklTensor(false); tf_shape_dst = src_tensor.shape(); } - OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( - {static_cast(src_index)}, - static_cast(dst_index), - tf_shape_dst, &dst_tensor)); - AllocateOutputSetMklShape(context, dst_index, dnn_shape_dst); + if (is_src_reordered) { + // If src is reordered, then src and dst would be in different layouts. + AllocateOutputSetMklShape(context, dst_index, &dst_tensor, tf_shape_dst, + dnn_shape_dst); + } else { + // forwarding input to output works only when layouts of src and + // dst tensor remains same -- either both of them are in native layout + // or in blocked (MKL) layout. + OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( + {static_cast(src_index)}, + static_cast(dst_index), + tf_shape_dst, &dst_tensor)); + AllocateOutputSetMklShape(context, dst_index, dnn_shape_dst); + } T* dst_data = dst_tensor->flat().data(); // execute eltwise @@ -924,7 +946,7 @@ class MklEluOp : public MklReluOpBase { // return exp(feature) - 1 if feature > 0; feature otherwise T feature = (static_cast(user_i))[0]; if (feature < static_cast(0)) - (static_cast(out_o))[0] = std::exp(feature); + (static_cast(out_o))[0] = Eigen::numext::exp(feature); else (static_cast(out_o))[0] = feature; return; @@ -966,7 +988,7 @@ class MklEluGradOp if (feature > static_cast(0)) { (static_cast(out_o))[0] = (static_cast(user_g))[0]; } else { - T elu = std::exp(feature) - static_cast(1); + T elu = Eigen::numext::exp(feature) - static_cast(1); (static_cast(out_o))[0] = (static_cast(user_g))[0] * (elu + static_cast(1)); } @@ -1004,8 +1026,8 @@ class MklTanhOp : public MklReluOpBase { void* out_o = static_cast(dst_tensor->flat().data()); // tanh(x) = (e^x - e^(-x))/ (e^x + e^(-x)) T feature = (static_cast(user_i))[0]; - T e1 = std::exp(feature); - T e2 = std::exp(-feature); + T e1 = Eigen::numext::exp(feature); + T e2 = Eigen::numext::exp(-feature); (static_cast(out_o))[0] = (e1 - e2) / (e1 + e2); return; } diff --git a/tensorflow/python/grappler/auto_mixed_precision_test.py b/tensorflow/python/grappler/auto_mixed_precision_test.py index 539c2bca9f3..3762435c597 100644 --- a/tensorflow/python/grappler/auto_mixed_precision_test.py +++ b/tensorflow/python/grappler/auto_mixed_precision_test.py @@ -506,9 +506,7 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase): tol = 5e-2 if mode == 'mkl' else 1e-3 self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol) - # TODO(reedwm): Fix and enable this test with MKL. Currently this crashes with - # MKL - @parameterized.parameters(['cuda']) + @parameterized.parameters(['cuda', 'mkl']) @test_util.run_deprecated_v1 @test_util.disable_xla('This test does not pass with XLA') def test_conv_bn_dropout(self, mode): @@ -539,6 +537,7 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase): # The default tolerance (1e-3) results in a tiny fraction (<1%) of # miscompares on ROCm platform, and hence the tolerance bump tol = 2e-3 if test.is_built_with_rocm else 1e-3 + tol = 5e-2 if mode == 'mkl' else tol self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol) # TODO(reedwm): Fix and enable this test with MKL. 
Currently this crashes with From 78ce403921649b9f98f3172d47196c638c8aef0b Mon Sep 17 00:00:00 2001 From: Yimei Sun Date: Tue, 22 Sep 2020 16:08:28 -0700 Subject: [PATCH 2/2] Address the review comments --- tensorflow/core/kernels/mkl/BUILD | 1 + .../core/kernels/mkl/mkl_fused_batch_norm_op.cc | 6 ++++-- .../kernels/mkl/mkl_fused_batch_norm_op_test.cc | 17 +++++++++++------ 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/tensorflow/core/kernels/mkl/BUILD b/tensorflow/core/kernels/mkl/BUILD index 16180a5b7bd..515db2595ef 100644 --- a/tensorflow/core/kernels/mkl/BUILD +++ b/tensorflow/core/kernels/mkl/BUILD @@ -318,6 +318,7 @@ tf_cc_test_mkl( srcs = ["mkl_fused_batch_norm_op_test.cc"], linkstatic = 1, deps = [ + ":mkl_conv_op", ":mkl_fused_batch_norm_op", "//tensorflow/core:direct_session", "//tensorflow/core/kernels:conv_ops_gpu_hdrs", diff --git a/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc index ca15f25a594..7a4931c1e83 100644 --- a/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc @@ -1335,8 +1335,10 @@ class MklFusedBatchNormGradOp : public OpKernel { static_cast(const_cast(src_tensor.flat().data())); #ifdef ENABLE_MKLDNN_V1 - // Ensure that src and diff_dst are in same blocked memory layout. - // As per MKL-DNN doc, this will lead to faster perf. + // MKL-DNN requires src and diff_dst to be in same memory layout, either + // blocked or native format. If these inputs are in different formats, + // convert the one in native format to blocked format as MKL-DNN gives + // better performance for blocked format. if (dnn_shape_src.IsMklTensor() && !dnn_shape_diff_dst.IsMklTensor()) { reorder_diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor); reorder_diff_dst.CheckReorderToOpMem( diff --git a/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op_test.cc b/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op_test.cc index f151cee53a2..c9b52f3da94 100644 --- a/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op_test.cc +++ b/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op_test.cc @@ -131,24 +131,29 @@ class CommonTestUtilities : public OpsTestBase { int height = 8; int width = 8; int depth = 1; + int filter_height = 3; + int filter_width = 3; + int in_channels = 1; + int out_channels = 6; DataType dtype = DataTypeToEnum::v(); Tensor input(dtype, {batch, height, width, depth}); input.flat() = input.flat().template setRandom(); - Tensor filter(dtype, {3, 3, 1, 6}); + Tensor filter(dtype, + {filter_height, filter_width, in_channels, out_channels}); filter.flat() = filter.flat().template setRandom(); - Tensor y_backprop(dtype, {batch, 8, 8, 6}); + Tensor y_backprop(dtype, {batch, height, width, out_channels}); y_backprop.flat() = y_backprop.flat().template setRandom(); - Tensor scale(dtype, {6}); + Tensor scale(dtype, {out_channels}); scale.flat() = scale.flat().template setRandom(); - Tensor mean(dtype, {6}); + Tensor mean(dtype, {out_channels}); mean.flat() = mean.flat().template setRandom(); - Tensor variance(dtype, {6}); + Tensor variance(dtype, {out_channels}); variance.flat() = variance.flat().template setRandom().abs(); - Tensor res_sp3(dtype, {6}); + Tensor res_sp3(dtype, {out_channels}); res_sp3.flat() = res_sp3.flat().template setRandom().abs();
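
Editor's aside (illustrative only, not part of either patch): the layout
matching that patch 1 adds through MklDnnData<T>::CheckReorderToOpMem comes
down to reordering whichever of the two inputs is still in native format into
the blocked memory::desc of the other before the backward primitive is built.
A minimal standalone sketch of that idea against the public mkldnn (oneDNN)
1.x reorder API could look as follows; the helper name ReorderToSameLayout and
its signature are hypothetical and are not defined anywhere in TensorFlow.

    // Illustrative sketch only; assumes the mkldnn (oneDNN) 1.x headers.
    #include "mkldnn.hpp"

    using mkldnn::engine;
    using mkldnn::memory;
    using mkldnn::reorder;
    using mkldnn::stream;

    // Wraps data (described by from_md) in a memory object and, if the two
    // descriptors differ, reorders it into the layout described by to_md.
    // Returns a memory object whose descriptor equals to_md, so both
    // batch-norm inputs can be made to share one layout.
    inline memory ReorderToSameLayout(const memory::desc& from_md,
                                      const memory::desc& to_md, void* data,
                                      const engine& eng, stream& strm) {
      memory src_mem(from_md, eng, data);
      if (from_md == to_md) return src_mem;  // Already in the requested layout.
      memory dst_mem(to_md, eng);            // Library-allocated destination.
      reorder(src_mem, dst_mem).execute(strm, src_mem, dst_mem);
      strm.wait();  // Ensure the reordered buffer is ready before it is read.
      return dst_mem;
    }

Patch 1 achieves the same effect with CheckReorderToOpMem, which additionally
keeps the reordered buffer alive inside the surrounding MklDnnData<T> objects
(reorder_src / reorder_diff_dst) for the rest of the op invocation.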