From 60da3fbda7e6a0c0a84b6bac168c3b06ced04d01 Mon Sep 17 00:00:00 2001 From: "Li, Guizi" Date: Fri, 14 Feb 2020 13:47:29 +0800 Subject: [PATCH] [Intel MKL] Fix dequantize accuracy issue and re-enable this OP --- tensorflow/core/graph/mkl_layout_pass.cc | 32 ++++---- tensorflow/core/kernels/BUILD | 2 + tensorflow/core/kernels/mkl_dequantize_op.cc | 16 ++-- .../core/kernels/mkl_dequantize_op_test.cc | 81 +++++++++++++++++++ tensorflow/core/kernels/mkl_reshape_op.cc | 81 ++++++------------- tensorflow/core/ops/mkl_array_ops.cc | 3 + tensorflow/core/util/mkl_util.h | 15 ++-- 7 files changed, 146 insertions(+), 84 deletions(-) diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 33b66848081..0b765e22d38 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -359,9 +359,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.mul = "Mul"; csinfo_.squared_difference = "SquaredDifference"; csinfo_.sub = "Sub"; -// End - element-wise ops. See note above. + // End - element-wise ops. See note above. -// NOTE: names are alphabetically sorted. + // NOTE: names are alphabetically sorted. rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn), CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation}); @@ -671,18 +671,18 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back( {csinfo_.requantize, mkl_op_registry::GetMklOpName(csinfo_.requantize), CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation}); -// Disable these two MKL operators for now due to some test failures caused -// by these two ops -/* -rinfo_.push_back({csinfo_.tanh, - mkl_op_registry::GetMklOpName(csinfo_.tanh), - CopyAttrsAll, AlwaysRewrite, - kRewriteForLayoutPropagation}); -rinfo_.push_back({csinfo_.tanh_grad, - mkl_op_registry::GetMklOpName(csinfo_.tanh_grad), - CopyAttrsAll, AlwaysRewrite, - kRewriteForLayoutPropagation}); -*/ + // Disable these two MKL operators for now due to some test failures caused + // by these two ops + /* + rinfo_.push_back({csinfo_.tanh, + mkl_op_registry::GetMklOpName(csinfo_.tanh), + CopyAttrsAll, AlwaysRewrite, + kRewriteForLayoutPropagation}); + rinfo_.push_back({csinfo_.tanh_grad, + mkl_op_registry::GetMklOpName(csinfo_.tanh_grad), + CopyAttrsAll, AlwaysRewrite, + kRewriteForLayoutPropagation}); + */ rinfo_.push_back( {csinfo_.reshape, mkl_op_registry::GetMklOpName(csinfo_.reshape), CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation}); @@ -1478,9 +1478,7 @@ rinfo_.push_back({csinfo_.tanh_grad, "Eigen op for Dequantize op."; return false; } - // TODO(sriniva2/mabuzain) Enable the op after verifying support for - // object detection models - return false; + return true; } // Rewrite rule for _FusedMatMul. diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 409f52db948..f72236e07a1 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -7976,6 +7976,7 @@ tf_cc_test_mkl( srcs = ["mkl_dequantize_op_test.cc"], deps = [ ":mkl_dequantize_op", + ":mkl_tfconv_op", ":ops_testutil", ":ops_util", "//tensorflow/core:array_ops_op_lib", @@ -7984,6 +7985,7 @@ tf_cc_test_mkl( "//tensorflow/core:mkl_array_ops_op_lib", "//tensorflow/core:nn_ops_op_lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", diff --git a/tensorflow/core/kernels/mkl_dequantize_op.cc b/tensorflow/core/kernels/mkl_dequantize_op.cc index 4c9dbf4274a..02aaf9ee798 100644 --- a/tensorflow/core/kernels/mkl_dequantize_op.cc +++ b/tensorflow/core/kernels/mkl_dequantize_op.cc @@ -92,10 +92,12 @@ class MklDequantizeOp : public OpKernel { memory::primitive_desc src_pd = memory::primitive_desc(src_md, cpu_engine); - memory::desc dst_md = src_mkl_shape.IsMklTensor() - ? src_md - : memory::desc(src_dims, MklDnnType(), - memory::format::nhwc); + memory::desc dst_md = + src_mkl_shape.IsMklTensor() + ? memory::desc(src_dims, MklDnnType(), + static_cast(src_md.data.format)) + : memory::desc(src_dims, MklDnnType(), + memory::format::nhwc); memory::primitive_desc dst_pd = memory::primitive_desc(dst_md, cpu_engine); @@ -150,9 +152,9 @@ class MklDequantizeOp : public OpKernel { mkldnn::reorder(reorder_pd, *src.GetUsrMem(), *dst.GetUsrMem())); stream(stream::kind::eager).submit(net).wait(); } catch (mkldnn::error& e) { - string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + ", in file " + - string(__FILE__) + ":" + std::to_string(__LINE__); + string error_msg = "Status: " + std::to_string(e.status) + ", message: " + + string(e.message) + ", in file " + string(__FILE__) + + ":" + std::to_string(__LINE__); OP_REQUIRES_OK( ctx, errors::Aborted("Operation received an exception:", error_msg)); } diff --git a/tensorflow/core/kernels/mkl_dequantize_op_test.cc b/tensorflow/core/kernels/mkl_dequantize_op_test.cc index 23d59ef7ab6..3093b87fb95 100644 --- a/tensorflow/core/kernels/mkl_dequantize_op_test.cc +++ b/tensorflow/core/kernels/mkl_dequantize_op_test.cc @@ -22,6 +22,8 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/util/mkl_util.h" + namespace tensorflow { class MklDequantizeOpTest : public OpsTestBase {}; @@ -59,4 +61,83 @@ TEST_F(MklDequantizeOpTest, small) { test::ExpectTensorNear(expected, output, 0.1); } +Tensor CreateMklInput() { + MklDnnShape mkl_shape; + memory::desc md = + memory::desc({1, 2, 2, 2}, MklDnnType(), memory::format::nhwc); + mkl_shape.SetMklTensor(true); + mkl_shape.SetMklLayout(&md); + mkl_shape.SetElemType(MklDnnType()); + mkl_shape.SetTfLayout(4, {1, 2, 2, 2}, memory::format::nhwc); + + DataType dtype = DataTypeToEnum::v(); + Tensor mkl_tensor(dtype, {mkl_shape.GetSerializeBufferSize()}); + mkl_shape.SerializeMklDnnShape( + mkl_tensor.flat().data(), + mkl_tensor.flat().size() * sizeof(uint8)); + return mkl_tensor; +} + +template +class CommonTestUtilities : public OpsTestBase { + public: + void MklToTF(const Tensor& tensor, const Tensor& mkl_meta_tensor, + Tensor* output) { + // Create an MKL to TF conversion node and execute it + TF_ASSERT_OK(NodeDefBuilder("mkl_to_tf_op", "_MklToTf") + .Input(FakeInput(DataTypeToEnum::v())) + .Input(FakeInput(DT_UINT8)) // MKL second tensor + .Attr("T", DataTypeToEnum::v()) + .Attr("_kernel", "MklLayoutDependentOp") + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + AddInputFromArray(tensor.shape(), tensor.flat()); + AddInputFromArray(mkl_meta_tensor.shape(), + mkl_meta_tensor.flat()); + TF_ASSERT_OK(RunOpKernel()); + + *output = *GetOutput(0); + } + + void ConvertAndCompare(const Tensor& tensor, const Tensor& mkl_meta_tensor, + const Tensor& expected) { + Tensor output; + MklToTF(tensor, mkl_meta_tensor, &output); + test::ExpectTensorNear(expected, output, 0.1); + } + + void TestBody() {} +}; + +TEST_F(MklDequantizeOpTest, MKLInput) { + TF_ASSERT_OK(NodeDefBuilder("dequantize_op", "_MklDequantize") + .Input(FakeInput(DT_QUINT8)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_UINT8)) // MKL second tensor + .Input(FakeInput(DT_UINT8)) // MKL second tensor + .Input(FakeInput(DT_UINT8)) // MKL second tensor + .Attr("T", DataTypeToEnum::v()) + .Attr("mode", "SCALED") + .Attr("_kernel", "QuantizedMklOp") + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + AddInputFromArray(TensorShape({1, 2, 2, 2}), + {0, 10, 50, 40, 25, 115, 190, 255}); + // min_range = 0 + AddInputFromArray(TensorShape({1}), {0}); + // max_range = 200 + AddInputFromArray(TensorShape({1}), {200.0f}); + auto mkl_tensor = CreateMklInput(); + AddInputFromArray(mkl_tensor.shape(), mkl_tensor.flat()); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + TF_ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 2, 2, 2})); + test::FillValues(&expected, + {0.0, 7.84, 39.21, 31.37, 19.6, 90.2, 149.0, 200}); + CommonTestUtilities test_util; + test_util.ConvertAndCompare(*GetOutput(0), *GetOutput(1), expected); +} + } // namespace tensorflow diff --git a/tensorflow/core/kernels/mkl_reshape_op.cc b/tensorflow/core/kernels/mkl_reshape_op.cc index 3c95a37ecfd..ddb2548b99b 100644 --- a/tensorflow/core/kernels/mkl_reshape_op.cc +++ b/tensorflow/core/kernels/mkl_reshape_op.cc @@ -132,7 +132,7 @@ class MklReshapeOp : public OpKernel { " values, but the requested shape has ", shape.num_elements())); - if (input_in_mkl_format) { + if (input_in_mkl_format && !SkipReorder(mkl_shape_input, shape)) { TensorShape& shape_to = shape; TensorShape shape_from = mkl_shape_input.GetTfShape(); if (shape_from == shape_to) { @@ -152,65 +152,36 @@ class MklReshapeOp : public OpKernel { // Tensorflow, we don't need to reorder tensor contents, we just // need to update MklDnnShape object associated with the input // tensor to reflect the shape change expected by reshape. - if (!SkipReorder(mkl_shape_input, shape_to)) { - // If dimensions that are being expanded or collapsed are not - // maintained contiguously by MKLDNN, then we use reorder. + // If dimensions that are being expanded or collapsed are not + // maintained contiguously by MKLDNN, then we use reorder. - // Get Mkl layout of input tensor. - auto input_mkl_md = mkl_shape_input.GetMklLayout(); - // Set input Mkl layout as the user layout. - dnn_data_input.SetUsrMem(input_mkl_md, &input_tensor); - // Get expected Tensorflow layout of input tensor. - auto output_tf_md = mkl_shape_input.GetTfLayout(); - auto output_tf_pd = - memory::primitive_desc(output_tf_md, cpu_engine); + // Get Mkl layout of input tensor. + auto input_mkl_md = mkl_shape_input.GetMklLayout(); + // Set input Mkl layout as the user layout. + dnn_data_input.SetUsrMem(input_mkl_md, &input_tensor); + // Get expected Tensorflow layout of input tensor. + auto output_tf_md = mkl_shape_input.GetTfLayout(); + auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine); - Tensor* output_tensor = nullptr; - MklDnnShape mkl_shape_output; - mkl_shape_output.SetMklTensor(false); - // We allocate output tensor in the shape expected by Reshape. - AllocateOutputSetMklShape(context, kOutputSlotIdx, &output_tensor, - shape_to, mkl_shape_output); + Tensor* output_tensor = nullptr; + MklDnnShape mkl_shape_output; + mkl_shape_output.SetMklTensor(false); + // We allocate output tensor in the shape expected by Reshape. + AllocateOutputSetMklShape(context, kOutputSlotIdx, &output_tensor, + shape_to, mkl_shape_output); - // Insert reorder between Mkl layout and TensorFlow layout if - // needed. If reorder is not needed but reshape is needed (since - // shape_from != shape_to), then we just copy input tensor to - // output tensor with target shape (we cannot forward Mkl layout - // in such case because shape has changed.) - if (dnn_data_input.CheckReorderToOpMem(output_tf_pd, - output_tensor)) { - } else { - OP_REQUIRES( - context, output_tensor->CopyFrom(input_tensor, shape_to), - errors::InvalidArgument("invalid input tensor shape")); - } - return; + // Insert reorder between Mkl layout and TensorFlow layout if + // needed. If reorder is not needed but reshape is needed (since + // shape_from != shape_to), then we just copy input tensor to + // output tensor with target shape (we cannot forward Mkl layout + // in such case because shape has changed.) + if (dnn_data_input.CheckReorderToOpMem(output_tf_pd, output_tensor)) { } else { - // If dimensions that are being expanded or collapsed are - // maintained contiguously by MKLDNN, then we skip reorder, just - // update MklDnnShape object for the tensorflow tensor, and forward - // Tensorflow tensor as it is to the output. - auto output_dims = TFShapeToMklDnnDims(shape_to); - auto output_strides = CalculateTFStrides(output_dims); - auto output_tf_md = MklDnnData::CreateBlockedMemDesc( - output_dims, output_strides); - auto output_tf_pd = - memory::primitive_desc(output_tf_md, cpu_engine); - - // Set MklDnnShape - MklDnnShape mkl_shape_output; - mkl_shape_output.SetMklTensor(true); - mkl_shape_output.SetMklLayout(&output_tf_pd); - mkl_shape_output.SetElemType(MklDnnType()); - mkl_shape_output.SetTfLayout(output_dims.size(), output_dims, - memory::format::blocked); - - // We now simply forward input Mkl tensor to output and change its - // output MklDnnShape object. - ForwardMklTensorInToOutWithMklShape( - context, kInputSlotIdx, kOutputSlotIdx, mkl_shape_output); - return; + OP_REQUIRES(context, + output_tensor->CopyFrom(input_tensor, shape_to), + errors::InvalidArgument("invalid input tensor shape")); } + return; } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + ", message: " + string(e.message) + ", in file " + diff --git a/tensorflow/core/ops/mkl_array_ops.cc b/tensorflow/core/ops/mkl_array_ops.cc index d4908f881e9..4e58711ccad 100644 --- a/tensorflow/core/ops/mkl_array_ops.cc +++ b/tensorflow/core/ops/mkl_array_ops.cc @@ -142,7 +142,10 @@ REGISTER_OP("_MklDequantize") .Output("output: float") .Output("mkl_output: uint8") .Attr("T: quantizedtype") + .Attr("narrow_range: bool = false") + .Attr("axis: int = -1") .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'SCALED'") + .Attr("dtype: {bfloat16, float} = DT_FLOAT") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); ShapeHandle unused; diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index e4450ee8a56..34183e48a6d 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -728,9 +728,9 @@ inline Status ConvertMklToTF(OpKernelContext* context, } return Status::OK(); } catch (mkldnn::error& e) { - string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + ", in file " + - string(__FILE__) + ":" + std::to_string(__LINE__); + string error_msg = "Status: " + std::to_string(e.status) + ", message: " + + string(e.message) + ", in file " + string(__FILE__) + + ":" + std::to_string(__LINE__); LOG(FATAL) << "Operation received an exception: " << error_msg; } } @@ -1011,6 +1011,11 @@ memory::data_type MklDnnType() { return memory::data_type::u8; } +template <> +memory::data_type MklDnnType() { + return memory::data_type::u8; +} + template <> memory::data_type MklDnnType() { return memory::data_type::s8; @@ -1250,8 +1255,8 @@ inline Status CreateBlockedMemDescHelper(const memory::dims& dim, } catch (mkldnn::error& e) { return Status(error::Code::INTERNAL, tensorflow::strings::StrCat( - "Failed to create blocked memory descriptor.", - "Status: ", e.status, ", message: ", e.message)); + "Failed to create blocked memory descriptor.", "Status: ", + e.status, ", message: ", e.message)); } #else // We have to construct memory descriptor in a C style. This is not at all