From c5a0825115b0dfd1f3e2e979b42fc5aa2847aae8 Mon Sep 17 00:00:00 2001
From: R Gomathi <gomathi.ramamurthy@intel.com>
Date: Fri, 13 Sep 2019 10:55:53 +0530
Subject: [PATCH] [INTEL MKL] Enabled MIN_FIRST support and primitive caching
 for MKL-DNN Quantize OP

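In MIN_FIRST mode the input is first shifted by min_range and then
scaled, so that min_range maps to 0:

  q = round((x - min_range) * (2^b - 1) / (max_range - min_range))

where b is the bit width of the output type T. For example, with
T = quint8 and an effective range of [0.0, 0.8], the scale factor is
255 / 0.8 = 318.75, so an input of 0.1 quantizes to round(31.875) = 32
(exercised by the small_minfirst_uint test below). The shift is written
into a scratch buffer and the scale is applied as an output-scale
attribute of an MKL-DNN reorder.

The reorder primitives are now cached in a
MklReorderWithScalePrimitiveFactory keyed on the source/destination
memory formats, data types, dimensions and the post-op scale, so
repeated executions of the op reuse the same primitive and only rebind
data handles. A condensed sketch of the new MIN_FIRST path in
Compute():

  MklReorderWithScaleFwdParams fwdParams(src_dims, src_md, dst_md);
  fwdParams.post_op_params.name = "scale";
  fwdParams.post_op_params.param.push_back(scale_factor);
  std::vector<primitive> net;
  net.push_back(
      FindOrCreateReorder<T>(src.GetUsrMem(), dst.GetUsrMem(), fwdParams));
  stream(stream::kind::eager).submit(net).wait();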
---
 tensorflow/core/graph/mkl_layout_pass.cc      |   7 +-
 tensorflow/core/kernels/mkl_quantize_op.cc    | 325 ++++++++++++++++--
 .../core/kernels/mkl_quantize_op_test.cc      |  97 +++++
 tensorflow/core/ops/mkl_array_ops.cc          |   2 +-
 4 files changed, 394 insertions(+), 37 deletions(-)

diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 41f6bb92ac8..f96acc10e10 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -1580,6 +1580,7 @@ rinfo_.push_back({csinfo_.tanh_grad,
     TryGetNodeAttr(n->def(), "axis", &axis);
     TF_CHECK_OK(GetNodeAttr(n->def(), "mode", &mode_string));
     TF_CHECK_OK(GetNodeAttr(n->def(), "round_mode", &round_mode_string));
+
     if (narrow_range) {
       VLOG(1) << "QuantizeOpRewrite: narrow range is enabled for quantization."
               << "This case is not optimized by Intel MKL, "
@@ -1593,8 +1594,9 @@ rinfo_.push_back({csinfo_.tanh_grad,
               << "thus using Eigen op for Quantize op ";
       return false;
     }
-    if (mode_string != "SCALED" || round_mode_string != "HALF_TO_EVEN") {
-      VLOG(1) << "QuantizeOpRewrite: Mode is not SCALED and/or"
+    if (!((mode_string == "SCALED" && round_mode_string == "HALF_TO_EVEN") ||
+          (mode_string == "MIN_FIRST"))) {
+      VLOG(1) << "QuantizeOpRewrite: Mode is not SCALED or MIN_FIRST and/or"
               << "rounding mode is not HALF_TO_EVEN. "
               << "This case is not optimized by Intel MKL, thus using Eigen op"
               << "for Quantize op ";
@@ -1849,6 +1851,7 @@ rinfo_.push_back({csinfo_.tanh_grad,
   // NOTE: names are alphabetically sorted.
   static void CopyAttrsAll(const Node* orig_node, NodeBuilder* nb,
                            bool change_format = false);
+
   static void CopyAttrsConv(const Node* orig_node, NodeBuilder* nb,
                             bool change_format = false);
   static void CopyAttrsConv2DDepthwiseCheckConstFilter(
diff --git a/tensorflow/core/kernels/mkl_quantize_op.cc b/tensorflow/core/kernels/mkl_quantize_op.cc
index 4f7d054f724..465d860db9c 100644
--- a/tensorflow/core/kernels/mkl_quantize_op.cc
+++ b/tensorflow/core/kernels/mkl_quantize_op.cc
@@ -59,6 +59,204 @@ namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
 
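+// Bundle of parameters describing a reorder with a "scale" post-op:
+// source dims plus source/destination memory descriptors.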
+struct MklReorderWithScaleFwdParams {
+  memory::dims src_dims;
+  memory::desc src_md;
+  memory::desc dst_md;
+  string dtypes = string("");
+  struct PostOpParam {
+    string name;
+    std::vector<float> param;
+  };
+  PostOpParam post_op_params;
+
+  MklReorderWithScaleFwdParams(memory::dims src_dims, memory::desc src_md,
+                               memory::desc dst_md)
+      : src_dims(src_dims), src_md(src_md), dst_md(dst_md) {}
+};
+
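+// Wraps an MKL-DNN reorder-with-scale primitive together with its
+// memory objects, so that a cached instance can be reused by simply
+// rebinding the data handles (see SetMemory).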
+class MklReorderWithScalePrimitive : public MklPrimitive {
+ public:
+  explicit MklReorderWithScalePrimitive(
+      const memory* from, const memory* to,
+      const MklReorderWithScaleFwdParams& fwdParams) {
+    // create reorder primitive
+    Setup(from, to, fwdParams);
+  }
+
+  ~MklReorderWithScalePrimitive() {}
+
+  std::shared_ptr<primitive> GetPrimitive() { return context_.reorder_prim; }
+
+  // set data handles
+  void SetMemory(const memory* from, const memory* to) {
+    context_.src_mem->set_data_handle(from->get_data_handle());
+    context_.dst_mem->set_data_handle(to->get_data_handle());
+  }
+
+ private:
+  // Primitive reuse context for reorder
+  struct ReorderContext {
+    // MKLDNN memory
+    std::shared_ptr<mkldnn::memory> src_mem;
+    std::shared_ptr<mkldnn::memory> dst_mem;
+
+    // Memory desc
+    std::shared_ptr<mkldnn::memory::desc> src_md;
+    std::shared_ptr<mkldnn::memory::desc> dst_md;
+
+    // Memory primitive desc
+    std::shared_ptr<mkldnn::memory::primitive_desc> src_mpd;
+    std::shared_ptr<mkldnn::memory::primitive_desc> dst_mpd;
+
+    // Reorder primitive descriptor and primitive
+    std::shared_ptr<reorder::primitive_desc> reorder_pd;
+    std::shared_ptr<primitive> reorder_prim;
+
+    ReorderContext()
+        : src_mem(nullptr),
+          dst_mem(nullptr),
+          src_md(nullptr),
+          dst_md(nullptr),
+          src_mpd(nullptr),
+          dst_mpd(nullptr),
+          reorder_pd(nullptr),
+          reorder_prim(nullptr) {}
+  } context_;
+
+  engine cpu_engine_ = engine(engine::cpu, 0);
+
+  // Reorder primitive setup
+  void Setup(const memory* from, const memory* to,
+             const MklReorderWithScaleFwdParams& fwdParams) {
+    // Create memory descriptors for reorder data with specified format
+    context_.src_md.reset(new memory::desc(fwdParams.src_md.data));
+    context_.dst_md.reset(new memory::desc(fwdParams.dst_md.data));
+    context_.src_mpd.reset(
+        new memory::primitive_desc(*context_.src_md, cpu_engine_));
+    context_.dst_mpd.reset(
+        new memory::primitive_desc(*context_.dst_md, cpu_engine_));
+
+    // Check if there is any fusion as post-ops
+    auto const& post_op_params = fwdParams.post_op_params;
+    mkldnn::primitive_attr post_ops_attr;
+
+    if (post_op_params.name == "scale") {
+      DCHECK_EQ(post_op_params.param.size(), 1);
+      std::vector<float> scales;
+      scales.push_back(post_op_params.param[0]);
+      post_ops_attr.set_output_scales(0, scales);
+    } else {
+      DCHECK(false) << "Unsupported post-op: " << post_op_params.name;
+    }
+
+    // Create a reorder
+    context_.reorder_pd =
+        std::make_shared<reorder::primitive_desc>(reorder::primitive_desc(
+            *context_.src_mpd, *context_.dst_mpd, post_ops_attr));
+
+    // Create memory primitive based on dummy data
+    context_.src_mem.reset(new memory(*context_.src_mpd, DummyData));
+    context_.dst_mem.reset(new memory(*context_.dst_mpd, DummyData));
+
+    // Create reorder primitive
+    context_.reorder_prim = std::make_shared<reorder>(
+        reorder(*context_.reorder_pd, *context_.src_mem, *context_.dst_mem));
+  }
+};
+
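+// Factory that caches MklReorderWithScalePrimitive objects so they can
+// be reused across invocations. Primitives are keyed on memory formats,
+// data types, dimensions and the post-op scale (see CreateKey).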
+template <typename T>
+class MklReorderWithScalePrimitiveFactory : public MklPrimitiveFactory<T> {
+ public:
+  static MklReorderWithScalePrimitive* Get(
+      const memory* from, const memory* to,
+      const MklReorderWithScaleFwdParams& fwdParams) {
+    // Try to find a suitable primitive from the cached pool
+    auto reorderPrim = static_cast<MklReorderWithScalePrimitive*>(
+        MklReorderWithScalePrimitiveFactory<T>::GetInstance().GetReorder(
+            from, to, fwdParams));
+    if (reorderPrim == nullptr) {
+      reorderPrim = new MklReorderWithScalePrimitive(from, to, fwdParams);
+      MklReorderWithScalePrimitiveFactory<T>::GetInstance().SetReorder(
+          from, to, reorderPrim, fwdParams);
+    }
+    reorderPrim->SetMemory(from, to);
+    return reorderPrim;
+  }
+
+  static MklReorderWithScalePrimitiveFactory& GetInstance() {
+    static MklReorderWithScalePrimitiveFactory instance_;
+    return instance_;
+  }
+
+ private:
+  MklReorderWithScalePrimitiveFactory() {}
+  ~MklReorderWithScalePrimitiveFactory() {}
+
+  static string CreateKey(const memory* from, const memory* to,
+                          const MklReorderWithScaleFwdParams& fwdParams) {
+    string dtypes = string("");
+    string prefix = "reorder";
+    FactoryKeyCreator key_creator;
+    auto const& from_desc = from->get_primitive_desc().desc().data;
+    auto const& to_desc = to->get_primitive_desc().desc().data;
+
+    key_creator.AddAsKey(prefix);
+    key_creator.AddAsKey(static_cast<int>(from_desc.format));
+    key_creator.AddAsKey(static_cast<int>(from_desc.data_type));
+    key_creator.AddAsKey(fwdParams.src_dims);
+    key_creator.AddAsKey(static_cast<int>(to_desc.format));
+    key_creator.AddAsKey(static_cast<int>(to_desc.data_type));
+    key_creator.AddAsKey(fwdParams.dtypes);
+
+    // Generate key for post-op scale
+    if (fwdParams.post_op_params.name == "scale") {
+      DCHECK_EQ(fwdParams.post_op_params.param.size(), 1);
+      key_creator.AddAsKey(fwdParams.post_op_params.name);
+      key_creator.AddAsKey(fwdParams.post_op_params.param[0]);
+    } else {
+      return string("not_a_key");
+    }
+
+    return key_creator.GetKey();
+  }
+
+  MklPrimitive* GetReorder(const memory* from, const memory* to,
+                           const MklReorderWithScaleFwdParams& fwdParams) {
+    string key = CreateKey(from, to, fwdParams);
+    return this->GetOp(key);
+  }
+
+  void SetReorder(const memory* from, const memory* to, MklPrimitive* op,
+                  const MklReorderWithScaleFwdParams& fwdParams) {
+    string key = CreateKey(from, to, fwdParams);
+    this->SetOp(key, op);
+  }
+};
+
+/// Function to find (or create) a reorder from the memory pointed to
+/// by `from` to the memory pointed to by `to`. It returns a cached
+/// primitive from the pool if one exists, creating it otherwise.
+/// Returns the primitive.
+template <typename T>
+inline primitive FindOrCreateReorder(
+    const memory* from, const memory* to,
+    const MklReorderWithScaleFwdParams& fwdParams) {
+  CHECK_NOTNULL(from);
+  CHECK_NOTNULL(to);
+  MklReorderWithScalePrimitive* reorder_prim =
+      MklReorderWithScalePrimitiveFactory<T>::Get(from, to, fwdParams);
+  return *reorder_prim->GetPrimitive();
+}
+
 // Quantizes a tensor from float to T, with user-specified min_range and
 // max_range.
 template <typename Device, typename T>
@@ -106,9 +304,25 @@ class MklQuantizeV2Op : public OpKernel {
         ctx, ctx->GetAttr("ensure_minimum_range", &ensure_minimum_range_));
   }
 
-  ~MklQuantizeV2Op() {}
+  ~MklQuantizeV2Op() {
+    if (this->minfirst_input_ != nullptr) {
+      delete[] this->minfirst_input_;
+      minfirst_input_ = nullptr;
+    }
+  }
+
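+  // Lazily allocates a scratch buffer that holds the min_range-shifted
+  // input in MIN_FIRST mode. The buffer is allocated on first use,
+  // reused across Compute() calls, and released in the destructor.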
+  float* GetMinfirstInputBuf(int size) {
+    if (!minfirst_input_) {
+      minfirst_input_ = new float[size];
+    }
+    return minfirst_input_;
+  }
 
   void Compute(OpKernelContext* ctx) override {
+    const Tensor& input = ctx->input(0);
     const float input_min_range = ctx->input(1).flat<float>()(0);
     const float input_max_range = ctx->input(2).flat<float>()(0);
     float min_range = std::min(0.0f, input_min_range);
@@ -174,7 +388,20 @@ class MklQuantizeV2Op : public OpKernel {
         src_mkl_shape.IsMklTensor()
             ? src_mkl_shape.GetMklLayout()
             : memory::desc(src_dims, MklDnnType<float>(), dst_layout_type);
-    src.SetUsrMem(src_md, &src_tensor);
+
+    // If the mode is min_first, min_range has to be subtracted from the
+    // input data before it is scaled
+    auto flat_input = input.flat<float>().data();
+    if (mode_ == QUANTIZE_MODE_MIN_FIRST) {
+      float* minfirst_input = GetMinfirstInputBuf(input.NumElements());
+#pragma omp parallel for schedule(static)
+      for (int i = 0; i < input.NumElements(); i++) {
+        minfirst_input[i] = flat_input[i] - min_range;
+      }
+      src.SetUsrMem(src_md, minfirst_input);
+    } else {
+      src.SetUsrMem(src_md, &src_tensor);
+    }
 
     memory::desc dst_md =
         memory::desc(src_dims, MklDnnType<T>(), dst_layout_type);
@@ -212,39 +439,68 @@
                               max_mkl_shape);
 
     dst.SetUsrMem(dst_md, output_tensor);
-    // Estimating scales for quantization.
-    const int num_bits = sizeof(T) * 8;
-    const float max_abs = std::max(std::abs(min_range), std::abs(max_range));
-    const bool is_signed = std::is_signed<T>::value;
-    float target_range;
-    if (is_signed) {
-      max_range = max_abs;
-      min_range = -max_abs;
-      // If it is signed, we try to keep 0.0 being 0 and drop one bucket. For
-      // example, if it is 8 bits, we have the range [-127, 127]. So for input
-      // range of [-x, x], the scale should be 254/(2*x).
-      target_range = static_cast<float>((uint64_t{1} << (num_bits - 1)) - 1);
-    } else {
-      max_range = max_abs;
-      min_range = 0.0;
-      // If it is unsigned and num_bits == 8, the range with 8 bits is [0,
-      // 255].  If the input range is [0, x], then the scale is 255/x instead
-      // of 254 as in the case above.
-      target_range = static_cast<float>((uint64_t{1} << num_bits) - 1);
+
+    float scale_factor = 0;
+    if (mode_ == QUANTIZE_MODE_SCALED) {
+      // Estimating scales for quantization.
+      const int num_bits = sizeof(T) * 8;
+      const float max_abs = std::max(std::abs(min_range), std::abs(max_range));
+      const bool is_signed = std::is_signed<T>::value;
+      float target_range;
+      if (is_signed) {
+        max_range = max_abs;
+        min_range = -max_abs;
+        // If it is signed, we try to keep 0.0 being 0 and drop one bucket. For
+        // example, if it is 8 bits, we have the range [-127, 127]. So for input
+        // range of [-x, x], the scale should be 254/(2*x).
+        target_range = static_cast<float>((uint64_t{1} << (num_bits - 1)) - 1);
+      } else {
+        max_range = max_abs;
+        min_range = 0.0;
+        // If it is unsigned and num_bits == 8, the range with 8 bits is [0,
+        // 255].  If the input range is [0, x], then the scale is 255/x instead
+        // of 254 as in the case above.
+        target_range = static_cast<float>((uint64_t{1} << num_bits) - 1);
+      }
+      scale_factor = target_range / max_abs;
+
+      output_min_tensor->flat<float>()(0) = min_range;
+      output_max_tensor->flat<float>()(0) = max_range;
+
+      // Primitive creation and stream submit
+      std::vector<float> scales{scale_factor};
+      mkldnn::primitive_attr attr;
+      attr.set_output_scales(0, scales);
+      auto reorder_desc = reorder::primitive_desc(
+          src.GetUsrMemPrimDesc(), dst.GetUsrMemPrimDesc(), attr);
+      reorder my_reorder = reorder(
+          reorder_desc, primitive::at(*src.GetUsrMem()), *dst.GetUsrMem());
+      std::vector<primitive> net{my_reorder};
+      stream(stream::kind::eager).submit(net).wait();
+    } else if (mode_ == QUANTIZE_MODE_MIN_FIRST) {
+      // Estimate the scale for quantization
+      const int number_of_bits = sizeof(T) * 8;
+      const int64 number_of_steps = static_cast<int64>(1) << number_of_bits;
+      scale_factor = (number_of_steps - 1.0) / (max_range - min_range);
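+      // Illustration: for T = quint8 and a range of [0.0, 255.0],
+      // number_of_steps = 256 and scale_factor = 255 / 255 = 1, so
+      // inputs are rounded to the nearest integer (see the tests).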
+
+      output_min_tensor->flat<float>()(0) = min_range;
+      output_max_tensor->flat<float>()(0) = max_range;
+
+      MklReorderWithScaleFwdParams fwdParams(src_dims, src_md, dst_md);
+      fwdParams.dtypes.append(typeid(T).name());
+
+      fwdParams.post_op_params.name = "scale";
+      fwdParams.post_op_params.param.push_back(scale_factor);
+
+      // Get primitive from pool or create one and submit
+      std::vector<primitive> net;
+      net.push_back(
+          FindOrCreateReorder<T>(src.GetUsrMem(), dst.GetUsrMem(), fwdParams));
+      stream(stream::kind::eager).submit(net).wait();
     }
-    output_min_tensor->flat<float>()(0) = min_range;
-    output_max_tensor->flat<float>()(0) = max_range;
-    const float scale_factor = target_range / max_abs;
-    // Primitive creation and stream submit
-    std::vector<float> scales{scale_factor};
-    mkldnn::primitive_attr attr;
-    attr.set_output_scales(0, scales);
-    auto reorder_desc = reorder::primitive_desc(src.GetUsrMemPrimDesc(),
-                                                dst.GetUsrMemPrimDesc(), attr);
-    reorder my_reorder = reorder(reorder_desc, primitive::at(*src.GetUsrMem()),
-                                 *dst.GetUsrMem());
-    std::vector<primitive> net{my_reorder};
-    stream(stream::kind::eager).submit(net).wait();
   }
 
  private:
@@ -253,6 +509,7 @@ class MklQuantizeV2Op : public OpKernel {
   int round_mode_;
   int axis_;
   bool narrow_range_;
+  float* minfirst_input_ = nullptr;
 };
 
 REGISTER_KERNEL_BUILDER(Name("_MklQuantizeV2")
diff --git a/tensorflow/core/kernels/mkl_quantize_op_test.cc b/tensorflow/core/kernels/mkl_quantize_op_test.cc
index cb53411ee6c..289bb00a26e 100644
--- a/tensorflow/core/kernels/mkl_quantize_op_test.cc
+++ b/tensorflow/core/kernels/mkl_quantize_op_test.cc
@@ -95,4 +95,101 @@ TEST_F(MklQuantizeV2OpTest, small_int8) {
   test::ExpectTensorEqual<float>(expected_min, *GetOutput(1));
   test::ExpectTensorEqual<float>(expected_max, *GetOutput(2));
 }
+
+TEST_F(MklQuantizeV2OpTest, small_minfirst) {
+  TF_ASSERT_OK(NodeDefBuilder("quantize_op", "_MklQuantizeV2")
+                   .Input(FakeInput(DT_FLOAT))
+                   .Input(FakeInput(DT_FLOAT))
+                   .Input(FakeInput(DT_FLOAT))
+                   .Input(FakeInput(DT_UINT8))  // MKL second tensor
+                   .Input(FakeInput(DT_UINT8))  // MKL second tensor
+                   .Input(FakeInput(DT_UINT8))  // MKL second tensor
+                   .Attr("T", DataTypeToEnum<quint8>::v())
+                   .Attr("mode", "MIN_FIRST")
+                   .Attr("_kernel", "QuantizedMklOp")
+                   .Finalize(node_def()));
+  TF_ASSERT_OK(InitOp());
+  AddInputFromArray<float>(TensorShape({8}),
+                           {1.0, 1.25, 1.75, 2, 3.15, 127.0, 255.0, 500.0});
+  AddInputFromArray<float>(TensorShape({1}), {0});
+  AddInputFromArray<float>(TensorShape({1}), {255.0f});
+  AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+  AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+  AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+  TF_ASSERT_OK(RunOpKernel());
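+  // With the range [0, 255], scale_factor = 255 / 255 = 1, so inputs
+  // are rounded to the nearest integer and clamped to [0, 255].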
+  Tensor expected(allocator(), DT_QUINT8, TensorShape({8}));
+  test::FillValues<quint8>(&expected, {1, 1, 2, 2, 3, 127, 255, 255});
+  test::ExpectTensorEqual<quint8>(expected, *GetOutput(0));
+  const float output_min = GetOutput(1)->flat<float>()(0);
+  const float output_max = GetOutput(2)->flat<float>()(0);
+  EXPECT_NEAR(0.0f, output_min, 1e-5f);
+  EXPECT_NEAR(255.0f, output_max, 1e-5f);
+}
+
+TEST_F(MklQuantizeV2OpTest, small_minfirst_uint) {
+  TF_ASSERT_OK(NodeDefBuilder("quantize_op", "_MklQuantizeV2")
+                   .Input(FakeInput(DT_FLOAT))
+                   .Input(FakeInput(DT_FLOAT))
+                   .Input(FakeInput(DT_FLOAT))
+                   .Input(FakeInput(DT_UINT8))  // MKL second tensor
+                   .Input(FakeInput(DT_UINT8))  // MKL second tensor
+                   .Input(FakeInput(DT_UINT8))  // MKL second tensor
+                   .Attr("T", DataTypeToEnum<quint8>::v())
+                   .Attr("mode", "MIN_FIRST")
+                   .Attr("_kernel", "QuantizedMklOp")
+                   .Finalize(node_def()));
+  TF_ASSERT_OK(InitOp());
+  AddInputFromArray<float>(TensorShape({8}),
+                           {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
+  AddInputFromArray<float>(TensorShape({1}), {0.1});
+  AddInputFromArray<float>(TensorShape({1}), {0.8});
+  AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+  AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+  AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+  TF_ASSERT_OK(RunOpKernel());
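+  // The range is extended to include 0, so [0.1, 0.8] becomes [0, 0.8]
+  // and scale_factor = 255 / 0.8 = 318.75; e.g. 0.1 -> round(31.875) = 32.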
+  Tensor expected(allocator(), DT_QUINT8, TensorShape({8}));
+  test::FillValues<quint8>(&expected, {32, 64, 96, 128, 159, 191, 223, 255});
+  test::ExpectTensorEqual<quint8>(expected, *GetOutput(0));
+  const float output_min = GetOutput(1)->flat<float>()(0);
+  const float output_max = GetOutput(2)->flat<float>()(0);
+  EXPECT_NEAR(0.0f, output_min, 1e-5f);
+  EXPECT_NEAR(0.8f, output_max, 1e-5f);
+}
+
+TEST_F(MklQuantizeV2OpTest, small_minfirst_int) {
+  TF_ASSERT_OK(NodeDefBuilder("quantize_op", "_MklQuantizeV2")
+                   .Input(FakeInput(DT_FLOAT))
+                   .Input(FakeInput(DT_FLOAT))
+                   .Input(FakeInput(DT_FLOAT))
+                   .Input(FakeInput(DT_UINT8))  // MKL second tensor
+                   .Input(FakeInput(DT_UINT8))  // MKL second tensor
+                   .Input(FakeInput(DT_UINT8))  // MKL second tensor
+                   .Attr("T", DataTypeToEnum<quint8>::v())
+                   .Attr("mode", "MIN_FIRST")
+                   .Attr("_kernel", "QuantizedMklOp")
+                   .Finalize(node_def()));
+  TF_ASSERT_OK(InitOp());
+  AddInputFromArray<float>(TensorShape({8}),
+                           {-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7, -0.8});
+  AddInputFromArray<float>(TensorShape({1}), {-0.8});
+  AddInputFromArray<float>(TensorShape({1}), {-0.1});
+  AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+  AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+  AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+  TF_ASSERT_OK(RunOpKernel());
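+  // The range is extended to include 0, so [-0.8, -0.1] becomes [-0.8, 0];
+  // e.g. -0.1 -> round((-0.1 + 0.8) * (255 / 0.8)) = round(223.125) = 223.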
+  Tensor expected(allocator(), DT_QUINT8, TensorShape({8}));
+  test::FillValues<quint8>(&expected, {223, 191, 159, 128, 96, 64, 32, 0});
+  test::ExpectTensorEqual<quint8>(expected, *GetOutput(0));
+  const float output_min = GetOutput(1)->flat<float>()(0);
+  const float output_max = GetOutput(2)->flat<float>()(0);
+  EXPECT_NEAR(-0.8f, output_min, 1e-5f);
+  EXPECT_NEAR(0.0f, output_max, 1e-5f);
+}
+
 }  // end namespace tensorflow
diff --git a/tensorflow/core/ops/mkl_array_ops.cc b/tensorflow/core/ops/mkl_array_ops.cc
index 5e847ecb22e..d4908f881e9 100644
--- a/tensorflow/core/ops/mkl_array_ops.cc
+++ b/tensorflow/core/ops/mkl_array_ops.cc
@@ -105,7 +105,7 @@ REGISTER_OP("_MklQuantizeV2")
     .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'SCALED'")
     .Attr(
         "round_mode: {'HALF_AWAY_FROM_ZERO', 'HALF_TO_EVEN'} = "
-        "'HALF_TO_EVEN'")
+        "'HALF_AWAY_FROM_ZERO'")
     .Attr("narrow_range: bool = false")
     .Attr("axis: int = -1")
     .Attr("ensure_minimum_range: float = 0.01")