From cff8012de1e657fd9286121492adcc146e345986 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 14 Jan 2020 11:55:44 -0800
Subject: [PATCH] Support Dequantize to bfloat16. Add a 'dtype' attr to
 Dequantize which allows user to specify the output dtype {float|bfloat16}.

PiperOrigin-RevId: 289699810
Change-Id: Idb12a52b6b9c18d015278b5c9aa4fd347a109b60
---
 .../compiler/tf2xla/kernels/dequantize_op.cc  |   8 +-
 .../api_def/base_api/api_def_Dequantize.pbtxt |   9 +-
 tensorflow/core/kernels/dequantize_op.cc      | 157 +++++++++++++-----
 tensorflow/core/kernels/dequantize_op_test.cc | 105 +++++++++++-
 tensorflow/core/ops/array_ops.cc              |   3 +-
 .../compat/ops_history_v1/Dequantize.pbtxt    |  73 ++++++++
 tensorflow/python/ops/array_ops.py            |  16 +-
 .../tools/api/golden/v1/tensorflow.pbtxt      |   2 +-
 .../golden/v1/tensorflow.quantization.pbtxt   |   2 +-
 .../api/golden/v1/tensorflow.raw_ops.pbtxt    |   2 +-
 .../golden/v2/tensorflow.quantization.pbtxt   |   2 +-
 .../api/golden/v2/tensorflow.raw_ops.pbtxt    |   2 +-
 12 files changed, 318 insertions(+), 63 deletions(-)

diff --git a/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc
index 06614d7b7c5..52509352919 100644
--- a/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc
@@ -55,6 +55,7 @@ class DequantizeOp : public XlaOpKernel {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis));
     OP_REQUIRES(ctx, axis == -1,
                 errors::InvalidArgument("axis must be -1' is ", axis));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
   }
 
   ~DequantizeOp() override = default;
@@ -86,7 +87,6 @@ class DequantizeOp : public XlaOpKernel {
     xla::XlaOp input = ctx->Input(0);
     xla::XlaOp output;
 
-    // TODO(ylc): Support bfloat16.
     output = xla::ConvertElementType(input, xla::F32);
 
     auto scale = ScalarLike(output, scale_factor);
@@ -94,8 +94,14 @@ class DequantizeOp : public XlaOpKernel {
     output = xla::Add(xla::Mul(xla::Add(output, halfrange), scale),
                       ScalarLike(output, min_range));
 
+    if (dtype_ == DT_BFLOAT16) {  // Narrow the dequantized f32 result.
+      output = xla::ConvertElementType(output, xla::BF16);
+    }
     ctx->SetOutput(0, output);
   }
+
+ private:
+  DataType dtype_;
 };
 
 REGISTER_XLA_OP(Name("Dequantize").TypeConstraint("T", kQuantizedType),
diff --git a/tensorflow/core/api_def/base_api/api_def_Dequantize.pbtxt b/tensorflow/core/api_def/base_api/api_def_Dequantize.pbtxt
index 82804e46e0e..030b98c369d 100644
--- a/tensorflow/core/api_def/base_api/api_def_Dequantize.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Dequantize.pbtxt
@@ -12,7 +12,14 @@ END
 The maximum scalar value possibly produced for the input.
 END
   }
-  summary: "Dequantize the \'input\' tensor into a float Tensor."
+  attr {
+    name: "dtype"
+    description: <<END
+Type of the output tensor. Currently Dequantize supports float and bfloat16.
+If 'dtype' is 'bfloat16', it only supports 'MIN_COMBINED' mode.
+END
+  }
+  summary: "Dequantize the \'input\' tensor into a float or bfloat16 Tensor."
   description: <<END
 [min_range, max_range] are scalar floats that specify the range for
 the output. The 'mode' attribute controls exactly which calculations are
diff --git a/tensorflow/core/kernels/dequantize_op.cc b/tensorflow/core/kernels/dequantize_op.cc
index 481909e8420..0f5a7019b1f 100644
--- a/tensorflow/core/kernels/dequantize_op.cc
+++ b/tensorflow/core/kernels/dequantize_op.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/kernels/meta_support.h"
 #include "tensorflow/core/kernels/quantization_utils.h"
+#include "tensorflow/core/lib/bfloat16/bfloat16.h"
 #include "tensorflow/core/lib/core/errors.h"
 
 namespace {
@@ -37,18 +38,44 @@ namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
 
-template <typename Device, typename T>
+template <typename T>  // Generic case: implicit float -> T conversion.
+T Cast(float v) {
+  return v;
+}
+
+template <>  // bfloat16 case: explicit construction from float.
+bfloat16 Cast<bfloat16>(float v) {
+  return bfloat16(v);
+}
+
+template <typename Device, typename T, typename S>
 class DequantizeOp : public OpKernel {
  public:
   explicit DequantizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
     string mode_string;
     OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_string));
-    OP_REQUIRES(ctx,
-                (mode_string == "MIN_COMBINED" || mode_string == "MIN_FIRST" ||
-                 mode_string == "SCALED"),
-                errors::InvalidArgument("Mode string must be 'MIN_COMBINED',"
-                                        " 'MIN_FIRST', or 'SCALED', is '" +
-                                        mode_string + "'"));
+    OP_REQUIRES(
+        ctx,
+        (ctx->output_type(0) == DT_FLOAT || ctx->output_type(0) == DT_BFLOAT16),
+        errors::InvalidArgument("Output type must be bfloat16 or float,"
+                                " is '" +
+                                DataTypeString(ctx->output_type(0)) + "'"));
+
+    if (ctx->output_type(0) == DT_FLOAT) {
+      OP_REQUIRES(ctx,
+                  (mode_string == "MIN_COMBINED" ||
+                   mode_string == "MIN_FIRST" || mode_string == "SCALED"),
+                  errors::InvalidArgument("Mode string must be 'MIN_COMBINED',"
+                                          " 'MIN_FIRST', or 'SCALED', is '" +
+                                          mode_string + "'"));
+    } else {
+      OP_REQUIRES(
+          ctx, (mode_string == "MIN_COMBINED"),
+          errors::InvalidArgument("When output type is bfloat16, Mode"
+                                  " string must be 'MIN_COMBINED', is '" +
+                                  mode_string + "'"));
+    }
+
     if (mode_string == "MIN_COMBINED") {
       mode_ = QUANTIZE_MODE_MIN_COMBINED;
     } else if (mode_string == "MIN_FIRST") {
@@ -71,34 +98,40 @@ class DequantizeOp : public OpKernel {
     }
 
     Tensor* output = nullptr;
+    Tensor float_output = tensorflow::Tensor(DT_FLOAT, input.shape());
     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
     if (num_slices == 1) {
       const float min_range = input_min_tensor.flat<float>()(0);
       const float max_range = input_max_tensor.flat<float>()(0);
-      DequantizeTensor(ctx, input, min_range, max_range, output);
-      return;
-    }
+      DequantizeTensor(ctx, input, min_range, max_range, &float_output);
+    } else {
+      OP_REQUIRES(ctx, mode_ != QUANTIZE_MODE_MIN_FIRST,
+                  errors::Unimplemented("MIN_FIRST mode is not implemented for "
+                                        "Dequantize with axis != -1."));
 
-    OP_REQUIRES(ctx, mode_ != QUANTIZE_MODE_MIN_FIRST,
-                errors::Unimplemented("MIN_FIRST mode is not implemented for "
-                                      "Dequantize with axis != -1."));
-
-    int64 pre_dim = 1, post_dim = 1;
-    for (int i = 0; i < axis_; ++i) {
-      pre_dim *= output->dim_size(i);
+      int64 pre_dim = 1, post_dim = 1;
+      for (int i = 0; i < axis_; ++i) {
+        pre_dim *= float_output.dim_size(i);
+      }
+      for (int i = axis_ + 1; i < float_output.dims(); ++i) {
+        post_dim *= float_output.dim_size(i);
+      }
+      auto input_tensor = input.template bit_casted_shaped<T, 3>(
+          {pre_dim, num_slices, post_dim});
+      auto output_tensor =
+          float_output.flat_inner_outer_dims<float, 3>(axis_ - 1);
+      auto min_ranges = input_min_tensor.vec<float>();
+      auto max_ranges = input_max_tensor.vec<float>();
+      for (int i = 0; i < num_slices; ++i) {
+        DequantizeSlice(ctx->eigen_device<Device>(), ctx,
+                        input_tensor.template chip<1>(i), min_ranges(i),
+                        max_ranges(i), output_tensor.template chip<1>(i));
+      }
     }
-    for (int i = axis_ + 1; i < output->dims(); ++i) {
-      post_dim *= output->dim_size(i);
-    }
-    auto input_tensor =
-        input.template bit_casted_shaped<T, 3>({pre_dim, num_slices, post_dim});
-    auto output_tensor = output->flat_inner_outer_dims<float, 3>(axis_ - 1);
-    auto min_ranges = input_min_tensor.vec<float>();
-    auto max_ranges = input_max_tensor.vec<float>();
-    for (int i = 0; i < num_slices; ++i) {
-      DequantizeSlice(ctx->eigen_device<Device>(), ctx,
-                      input_tensor.template chip<1>(i), min_ranges(i),
-                      max_ranges(i), output_tensor.template chip<1>(i));
+    S* out_ptr = output->flat<S>().data();
+    float* in_ptr = float_output.flat<float>().data();
+    for (int64 i = 0; i < float_output.NumElements(); ++i) {
+      out_ptr[i] = static_cast<S>(in_ptr[i]);
     }
   }
 
@@ -188,21 +221,55 @@ class DequantizeOp : public OpKernel {
   bool narrow_range_;
 };
 
-REGISTER_KERNEL_BUILDER(
-    Name("Dequantize").Device(DEVICE_CPU).TypeConstraint<quint8>("T"),
-    DequantizeOp<CPUDevice, quint8>);
-REGISTER_KERNEL_BUILDER(
-    Name("Dequantize").Device(DEVICE_CPU).TypeConstraint<qint8>("T"),
-    DequantizeOp<CPUDevice, qint8>);
-REGISTER_KERNEL_BUILDER(
-    Name("Dequantize").Device(DEVICE_CPU).TypeConstraint<quint16>("T"),
-    DequantizeOp<CPUDevice, quint16>);
-REGISTER_KERNEL_BUILDER(
-    Name("Dequantize").Device(DEVICE_CPU).TypeConstraint<qint16>("T"),
-    DequantizeOp<CPUDevice, qint16>);
-
-REGISTER_KERNEL_BUILDER(
-    Name("Dequantize").Device(DEVICE_CPU).TypeConstraint<qint32>("T"),
-    DequantizeOp<CPUDevice, qint32>);
+REGISTER_KERNEL_BUILDER(Name("Dequantize")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<quint8>("T")
+                            .TypeConstraint<float>("dtype"),
+                        DequantizeOp<CPUDevice, quint8, float>);
+REGISTER_KERNEL_BUILDER(Name("Dequantize")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<qint8>("T")
+                            .TypeConstraint<float>("dtype"),
+                        DequantizeOp<CPUDevice, qint8, float>);
+REGISTER_KERNEL_BUILDER(Name("Dequantize")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<quint16>("T")
+                            .TypeConstraint<float>("dtype"),
+                        DequantizeOp<CPUDevice, quint16, float>);
+REGISTER_KERNEL_BUILDER(Name("Dequantize")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<qint16>("T")
+                            .TypeConstraint<float>("dtype"),
+                        DequantizeOp<CPUDevice, qint16, float>);
+REGISTER_KERNEL_BUILDER(Name("Dequantize")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<qint32>("T")
+                            .TypeConstraint<float>("dtype"),
+                        DequantizeOp<CPUDevice, qint32, float>);
 
+REGISTER_KERNEL_BUILDER(Name("Dequantize")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<quint8>("T")
+                            .TypeConstraint<bfloat16>("dtype"),
+                        DequantizeOp<CPUDevice, quint8, bfloat16>);
+REGISTER_KERNEL_BUILDER(Name("Dequantize")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<qint8>("T")
+                            .TypeConstraint<bfloat16>("dtype"),
+                        DequantizeOp<CPUDevice, qint8, bfloat16>);
+REGISTER_KERNEL_BUILDER(Name("Dequantize")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<quint16>("T")
+                            .TypeConstraint<bfloat16>("dtype"),
+                        DequantizeOp<CPUDevice, quint16, bfloat16>);
+REGISTER_KERNEL_BUILDER(Name("Dequantize")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<qint16>("T")
+                            .TypeConstraint<bfloat16>("dtype"),
+                        DequantizeOp<CPUDevice, qint16, bfloat16>);
+REGISTER_KERNEL_BUILDER(Name("Dequantize")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<qint32>("T")
+                            .TypeConstraint<bfloat16>("dtype"),
+                        DequantizeOp<CPUDevice, qint32, bfloat16>);
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/dequantize_op_test.cc b/tensorflow/core/kernels/dequantize_op_test.cc
index 30e73caf143..3c9d1790787 100644
--- a/tensorflow/core/kernels/dequantize_op_test.cc
+++ b/tensorflow/core/kernels/dequantize_op_test.cc
@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/kernels/ops_testutil.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test_benchmark.h"
@@ -61,8 +62,9 @@ class DequantizeOpTest : public OpsTestBase {
   // Compares dequantize min vs the same using eigen. This tests that a change
   // to not use eigen gives equivalent results to using eigen.
   template <typename T>
-  void RunDequantizeMinCombinedTest(float min_range, float max_range) {
-    TF_ASSERT_OK(NodeDefBuilder("dequantize_op", "Dequantize")
+  void RunDequantizeMinCombinedTest(float min_range, float max_range,
+                                    const string& op_name) {
+    TF_ASSERT_OK(NodeDefBuilder("dequantize_op", op_name)
                      .Input(FakeInput(DataTypeToEnum<T>::v()))
                      .Input(FakeInput(DT_FLOAT))
                      .Input(FakeInput(DT_FLOAT))
@@ -87,6 +89,40 @@ class DequantizeOpTest : public OpsTestBase {
     test::ExpectTensorEqual<float>(expected, *GetOutput(0));
   }
 
+  // Runs Dequantize in MIN_COMBINED mode with dtype=DT_BFLOAT16 and checks the
+  // output against the Eigen float reference cast element-wise to bfloat16.
+  template <typename T>
+  void RunDequantizeBfloat16MinCombinedTest(float min_range, float max_range) {
+    TF_ASSERT_OK(NodeDefBuilder("dequantize_op_bfloat16", "Dequantize")
+                     .Input(FakeInput(DataTypeToEnum<T>::v()))
+                     .Input(FakeInput(DT_FLOAT))
+                     .Input(FakeInput(DT_FLOAT))
+                     .Attr("T", DataTypeToEnum<T>::v())
+                     .Attr("mode", "MIN_COMBINED")
+                     .Attr("dtype", DT_BFLOAT16)
+                     .Finalize(node_def()));
+    TF_ASSERT_OK(InitOp());
+
+    std::vector<T> input;
+    for (int64 i = std::numeric_limits<T>::min();
+         i < std::numeric_limits<T>::max(); ++i) {
+      input.push_back(static_cast<T>(i));
+    }
+    TensorShape shape({static_cast<int64>(input.size())});
+    AddInputFromArray<T>(shape, input);
+    AddInputFromArray<float>(TensorShape({}), {min_range});
+    AddInputFromArray<float>(TensorShape({}), {max_range});
+    TF_ASSERT_OK(RunOpKernel());
+
+    Tensor expected_float32(allocator(), DT_FLOAT, shape);
+    ComputeDequantizeMinCombinedUsingEigen<T>(GetInput(0), min_range, max_range,
+                                              &expected_float32);
+    Tensor expected(allocator(), DT_BFLOAT16, shape);
+    expected.flat<bfloat16>() = expected_float32.flat<float>().cast<bfloat16>();
+
+    test::ExpectTensorEqual<bfloat16>(expected, *GetOutput(0));
+  }
+
   // Creates a tensor with the specified dims, using values chosen from data,
   // multiplied by (1 + index) along the axis dimension.
   template <typename T>
@@ -151,16 +187,29 @@ struct ParameterizedDequantizeOpTest
       public ::testing::WithParamInterface<int> {};
 
 TEST_F(DequantizeOpTest, DequantizeMinCombinedQuint8) {
-  RunDequantizeMinCombinedTest<quint8>(0, 255.0f);
+  RunDequantizeMinCombinedTest<quint8>(0, 255.0f, "Dequantize");
 }
 TEST_F(DequantizeOpTest, DequantizeMinCombinedQint8) {
-  RunDequantizeMinCombinedTest<qint8>(0, 255.0f);
+  RunDequantizeMinCombinedTest<qint8>(0, 255.0f, "Dequantize");
 }
 TEST_F(DequantizeOpTest, DequantizeMinCombinedQint16) {
-  RunDequantizeMinCombinedTest<qint16>(0, 255.0f);
+  RunDequantizeMinCombinedTest<qint16>(0, 255.0f, "Dequantize");
 }
 TEST_F(DequantizeOpTest, DequantizeMinCombinedQuint16) {
-  RunDequantizeMinCombinedTest<quint16>(0, 255.0f);
+  RunDequantizeMinCombinedTest<quint16>(0, 255.0f, "Dequantize");
+}
+
+TEST_F(DequantizeOpTest, DequantizeBfloat16MinCombinedQuint8) {
+  RunDequantizeBfloat16MinCombinedTest<quint8>(0, 255.0f);
+}
+TEST_F(DequantizeOpTest, DequantizeBfloat16MinCombinedQint8) {
+  RunDequantizeBfloat16MinCombinedTest<qint8>(0, 255.0f);
+}
+TEST_F(DequantizeOpTest, DequantizeBfloat16MinCombinedQint16) {
+  RunDequantizeBfloat16MinCombinedTest<qint16>(0, 255.0f);
+}
+TEST_F(DequantizeOpTest, DequantizeBfloat16MinCombinedQuint16) {
+  RunDequantizeBfloat16MinCombinedTest<quint16>(0, 255.0f);
 }
 
 TEST_F(DequantizeOpTest, DequantizeScaledQuint8Zero) {
@@ -202,8 +251,10 @@ static void BM_DequantizeMinCombinedCpu(int iters) {
   auto root = Scope::NewRootScope().ExitOnError();
   const int64 num_values = 1500 * 250;
   std::vector<T> inputs;
+
   inputs.reserve(num_values);
   for (int i = 0; i < num_values; ++i) inputs.push_back(i);
+
   ops::Dequantize(root, test::AsTensor<T>(inputs), test::AsScalar<float>(-1.5f),
                   test::AsScalar<float>(20.5f),
                   ops::Dequantize::Attrs().Mode("MIN_COMBINED"));
@@ -237,5 +288,47 @@ BENCHMARK(BM_DequantizeMinCombinedCpuQint16);
 BENCHMARK(BM_DequantizeMinCombinedCpuQuint8);
 BENCHMARK(BM_DequantizeMinCombinedCpuQint8);
 
+template <typename T>
+static void BM_DequantizeBfloat16MinCombinedCpu(int iters) {
+  auto root = Scope::NewRootScope().ExitOnError();
+  const int64 num_values = 1500 * 250;
+  std::vector<T> inputs;
+  // Fill the input with num_values consecutive integers converted to T.
+  inputs.reserve(num_values);
+  for (int i = 0; i < num_values; ++i) inputs.push_back(i);
+  // Dequantize node with min_range=-1.5, max_range=20.5 and bfloat16 output.
+  ops::Dequantize(root, test::AsTensor<T>(inputs), test::AsScalar<float>(-1.5f),
+                  test::AsScalar<float>(20.5f),
+                  ops::Dequantize::Attrs().Dtype(DT_BFLOAT16));
+  TF_CHECK_OK(root.status());
+  Graph* g = new Graph(OpRegistry::Global());
+  TF_CHECK_OK(root.ToGraph(g));
+
+  test::Benchmark("cpu", g).Run(iters);
+  testing::BytesProcessed(iters * num_values * (sizeof(bfloat16) + sizeof(T)));
+  testing::ItemsProcessed(iters);
+}
+
+static void BM_DequantizeBfloat16MinCombinedCpuQuint16(int iters) {
+  BM_DequantizeBfloat16MinCombinedCpu<quint16>(iters);
+}
+
+static void BM_DequantizeBfloat16MinCombinedCpuQint16(int iters) {
+  BM_DequantizeBfloat16MinCombinedCpu<qint16>(iters);
+}
+
+static void BM_DequantizeBfloat16MinCombinedCpuQuint8(int iters) {
+  BM_DequantizeBfloat16MinCombinedCpu<quint8>(iters);
+}
+
+static void BM_DequantizeBfloat16MinCombinedCpuQint8(int iters) {
+  BM_DequantizeBfloat16MinCombinedCpu<qint8>(iters);
+}
+
+BENCHMARK(BM_DequantizeBfloat16MinCombinedCpuQuint16);
+BENCHMARK(BM_DequantizeBfloat16MinCombinedCpuQint16);
+BENCHMARK(BM_DequantizeBfloat16MinCombinedCpuQuint8);
+BENCHMARK(BM_DequantizeBfloat16MinCombinedCpuQint8);
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index a427b8b3967..60efdcb7a73 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -2871,11 +2871,12 @@ REGISTER_OP("Dequantize")
     .Input("input: T")
     .Input("min_range: float")
     .Input("max_range: float")
-    .Output("output: float")
+    .Output("output: dtype")
     .Attr("T: quantizedtype")
     .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'MIN_COMBINED'")
     .Attr("narrow_range: bool = false")
     .Attr("axis: int = -1")
+    .Attr("dtype: {bfloat16, float} = DT_FLOAT")
     .SetShapeFn([](InferenceContext* c) {
       int axis = -1;
       Status s = c->GetAttr("axis", &axis);
diff --git a/tensorflow/core/ops/compat/ops_history_v1/Dequantize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Dequantize.pbtxt
index e0a88ff58a2..f8a161433af 100644
--- a/tensorflow/core/ops/compat/ops_history_v1/Dequantize.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history_v1/Dequantize.pbtxt
@@ -248,3 +248,76 @@ op {
     }
   }
 }
+op {
+  name: "Dequantize"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "min_range"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "max_range"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "output"
+    type_attr: "dtype"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_QINT16
+        type: DT_QUINT16
+      }
+    }
+  }
+  attr {
+    name: "mode"
+    type: "string"
+    default_value {
+      s: "MIN_COMBINED"
+    }
+    allowed_values {
+      list {
+        s: "MIN_COMBINED"
+        s: "MIN_FIRST"
+        s: "SCALED"
+      }
+    }
+  }
+  attr {
+    name: "narrow_range"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  attr {
+    name: "axis"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_FLOAT
+      }
+    }
+  }
+}
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 53620a897c4..403ea2aee70 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -4982,7 +4982,8 @@ def dequantize(  # pylint: disable=missing-docstring
     mode="MIN_COMBINED",
     name=None,
     axis=None,
-    narrow_range=False):
+    narrow_range=False,
+    dtype=dtypes.float32):
   if axis is None:
     axis = -1
   elif axis < 0:
@@ -4992,10 +4993,17 @@ def dequantize(  # pylint: disable=missing-docstring
 
   if axis >= 0 or narrow_range:
     return gen_array_ops.dequantize(
-        input, min_range, max_range, mode=mode, name=name,
-        narrow_range=narrow_range, axis=axis)
+        input,
+        min_range,
+        max_range,
+        mode=mode,
+        name=name,
+        narrow_range=narrow_range,
+        axis=axis,
+        dtype=dtype)
   return gen_array_ops.dequantize(
-      input, min_range, max_range, mode=mode, name=name)
+      input, min_range, max_range, mode=mode, name=name, dtype=dtype)
+
 
 dequantize.__doc__ = gen_array_ops.dequantize.__doc__
 
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
index 9abecf88b18..bcefb835e00 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -1110,7 +1110,7 @@ tf_module {
   }
   member_method {
     name: "dequantize"
-    argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\', \'narrow_range\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\', \'False\'], "
+    argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\', \'narrow_range\', \'dtype\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\', \'False\', \"<dtype: \'float32\'>\"], "
   }
   member_method {
     name: "deserialize_many_sparse"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.quantization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.quantization.pbtxt
index 7c3ef6a194a..047fb4deda7 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.quantization.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.quantization.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.quantization"
 tf_module {
   member_method {
     name: "dequantize"
-    argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\', \'narrow_range\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\', \'False\'], "
+    argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\', \'narrow_range\', \'dtype\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\', \'False\', \"<dtype: \'float32\'>\"], "
   }
   member_method {
     name: "fake_quant_with_min_max_args"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index 9791da7c35f..dc4552d62aa 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -1082,7 +1082,7 @@ tf_module {
   }
   member_method {
     name: "Dequantize"
-    argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'narrow_range\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'False\', \'-1\', \'None\'], "
+    argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'narrow_range\', \'axis\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'False\', \'-1\', \"<dtype: \'float32\'>\", \'None\'], "
   }
   member_method {
     name: "DeserializeIterator"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.quantization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.quantization.pbtxt
index 7c3ef6a194a..047fb4deda7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.quantization.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.quantization.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.quantization"
 tf_module {
   member_method {
     name: "dequantize"
-    argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\', \'narrow_range\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\', \'False\'], "
+    argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\', \'narrow_range\', \'dtype\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\', \'False\', \"<dtype: \'float32\'>\"], "
   }
   member_method {
     name: "fake_quant_with_min_max_args"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 9791da7c35f..dc4552d62aa 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -1082,7 +1082,7 @@ tf_module {
   }
   member_method {
     name: "Dequantize"
-    argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'narrow_range\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'False\', \'-1\', \'None\'], "
+    argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'narrow_range\', \'axis\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'False\', \'-1\', \"<dtype: \'float32\'>\", \'None\'], "
   }
   member_method {
     name: "DeserializeIterator"