diff --git a/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc index 52509352919..06614d7b7c5 100644 --- a/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc @@ -55,7 +55,6 @@ class DequantizeOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis)); OP_REQUIRES(ctx, axis == -1, errors::InvalidArgument("axis must be -1' is ", axis)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); } ~DequantizeOp() override = default; @@ -87,6 +86,7 @@ class DequantizeOp : public XlaOpKernel { xla::XlaOp input = ctx->Input(0); xla::XlaOp output; + // TODO(ylc): Support bfloat16. output = xla::ConvertElementType(input, xla::F32); auto scale = ScalarLike(output, scale_factor); @@ -94,14 +94,8 @@ class DequantizeOp : public XlaOpKernel { output = xla::Add(xla::Mul(xla::Add(output, halfrange), scale), ScalarLike(output, min_range)); - if (dtype_ == DT_BFLOAT16) { - output = xla::ConvertElementType(input, xla::BF16); - } ctx->SetOutput(0, output); } - - private: - DataType dtype_; }; REGISTER_XLA_OP(Name("Dequantize").TypeConstraint("T", kQuantizedType), diff --git a/tensorflow/core/api_def/base_api/api_def_Dequantize.pbtxt b/tensorflow/core/api_def/base_api/api_def_Dequantize.pbtxt index 030b98c369d..82804e46e0e 100644 --- a/tensorflow/core/api_def/base_api/api_def_Dequantize.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_Dequantize.pbtxt @@ -12,14 +12,7 @@ END The maximum scalar value possibly produced for the input. END } - attr { - name: "dtype" - description: < -T Cast(float v) { - return v; -} - -template <> -bfloat16 Cast(float v) { - return bfloat16(v); -} - -template +template class DequantizeOp : public OpKernel { public: explicit DequantizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) { string mode_string; OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_string)); - OP_REQUIRES( - ctx, - (ctx->output_type(0) == DT_FLOAT || ctx->output_type(0) == DT_BFLOAT16), - errors::InvalidArgument("Output type must be bfloat16 or float," - " is '" + - DataTypeString(ctx->output_type(0)) + "'")); - - if (ctx->output_type(0) == DT_FLOAT) { - OP_REQUIRES(ctx, - (mode_string == "MIN_COMBINED" || - mode_string == "MIN_FIRST" || mode_string == "SCALED"), - errors::InvalidArgument("Mode string must be 'MIN_COMBINED'," - " 'MIN_FIRST', or 'SCALED', is '" + - mode_string + "'")); - } else { - OP_REQUIRES( - ctx, (mode_string == "MIN_COMBINED"), - errors::InvalidArgument("When output type is bfloat16, Mode" - " string must be 'MIN_COMBINED', is '" + - mode_string + "'")); - } - + OP_REQUIRES(ctx, + (mode_string == "MIN_COMBINED" || mode_string == "MIN_FIRST" || + mode_string == "SCALED"), + errors::InvalidArgument("Mode string must be 'MIN_COMBINED'," + " 'MIN_FIRST', or 'SCALED', is '" + + mode_string + "'")); if (mode_string == "MIN_COMBINED") { mode_ = QUANTIZE_MODE_MIN_COMBINED; } else if (mode_string == "MIN_FIRST") { @@ -98,40 +71,34 @@ class DequantizeOp : public OpKernel { } Tensor* output = nullptr; - Tensor float_output = tensorflow::Tensor(DT_FLOAT, input.shape()); OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output)); if (num_slices == 1) { const float min_range = input_min_tensor.flat()(0); const float max_range = input_max_tensor.flat()(0); - DequantizeTensor(ctx, input, min_range, max_range, &float_output); - } else { - OP_REQUIRES(ctx, mode_ != QUANTIZE_MODE_MIN_FIRST, - errors::Unimplemented("MIN_FIRST mode is not implemented for " - "Dequantize with axis != -1.")); - - int64 pre_dim = 1, post_dim = 1; - for (int i = 0; i < axis_; ++i) { - pre_dim *= float_output.dim_size(i); - } - for (int i = axis_ + 1; i < float_output.dims(); ++i) { - post_dim *= float_output.dim_size(i); - } - auto input_tensor = input.template bit_casted_shaped( - {pre_dim, num_slices, post_dim}); - auto output_tensor = - float_output.flat_inner_outer_dims(axis_ - 1); - auto min_ranges = input_min_tensor.vec(); - auto max_ranges = input_max_tensor.vec(); - for (int i = 0; i < num_slices; ++i) { - DequantizeSlice(ctx->eigen_device(), ctx, - input_tensor.template chip<1>(i), min_ranges(i), - max_ranges(i), output_tensor.template chip<1>(i)); - } + DequantizeTensor(ctx, input, min_range, max_range, output); + return; } - S* out_ptr = output->flat().data(); - float* in_ptr = float_output.flat().data(); - for (int64 i = 0; i < float_output.NumElements(); ++i) { - out_ptr[i] = static_cast(in_ptr[i]); + + OP_REQUIRES(ctx, mode_ != QUANTIZE_MODE_MIN_FIRST, + errors::Unimplemented("MIN_FIRST mode is not implemented for " + "Dequantize with axis != -1.")); + + int64 pre_dim = 1, post_dim = 1; + for (int i = 0; i < axis_; ++i) { + pre_dim *= output->dim_size(i); + } + for (int i = axis_ + 1; i < output->dims(); ++i) { + post_dim *= output->dim_size(i); + } + auto input_tensor = + input.template bit_casted_shaped({pre_dim, num_slices, post_dim}); + auto output_tensor = output->flat_inner_outer_dims(axis_ - 1); + auto min_ranges = input_min_tensor.vec(); + auto max_ranges = input_max_tensor.vec(); + for (int i = 0; i < num_slices; ++i) { + DequantizeSlice(ctx->eigen_device(), ctx, + input_tensor.template chip<1>(i), min_ranges(i), + max_ranges(i), output_tensor.template chip<1>(i)); } } @@ -221,55 +188,21 @@ class DequantizeOp : public OpKernel { bool narrow_range_; }; -REGISTER_KERNEL_BUILDER(Name("Dequantize") - .Device(DEVICE_CPU) - .TypeConstraint("T") - .TypeConstraint("dtype"), - DequantizeOp); -REGISTER_KERNEL_BUILDER(Name("Dequantize") - .Device(DEVICE_CPU) - .TypeConstraint("T") - .TypeConstraint("dtype"), - DequantizeOp); -REGISTER_KERNEL_BUILDER(Name("Dequantize") - .Device(DEVICE_CPU) - .TypeConstraint("T") - .TypeConstraint("dtype"), - DequantizeOp); -REGISTER_KERNEL_BUILDER(Name("Dequantize") - .Device(DEVICE_CPU) - .TypeConstraint("T") - .TypeConstraint("dtype"), - DequantizeOp); -REGISTER_KERNEL_BUILDER(Name("Dequantize") - .Device(DEVICE_CPU) - .TypeConstraint("T") - .TypeConstraint("dtype"), - DequantizeOp); +REGISTER_KERNEL_BUILDER( + Name("Dequantize").Device(DEVICE_CPU).TypeConstraint("T"), + DequantizeOp); +REGISTER_KERNEL_BUILDER( + Name("Dequantize").Device(DEVICE_CPU).TypeConstraint("T"), + DequantizeOp); +REGISTER_KERNEL_BUILDER( + Name("Dequantize").Device(DEVICE_CPU).TypeConstraint("T"), + DequantizeOp); +REGISTER_KERNEL_BUILDER( + Name("Dequantize").Device(DEVICE_CPU).TypeConstraint("T"), + DequantizeOp); + +REGISTER_KERNEL_BUILDER( + Name("Dequantize").Device(DEVICE_CPU).TypeConstraint("T"), + DequantizeOp); -REGISTER_KERNEL_BUILDER(Name("Dequantize") - .Device(DEVICE_CPU) - .TypeConstraint("T") - .TypeConstraint("dtype"), - DequantizeOp); -REGISTER_KERNEL_BUILDER(Name("Dequantize") - .Device(DEVICE_CPU) - .TypeConstraint("T") - .TypeConstraint("dtype"), - DequantizeOp); -REGISTER_KERNEL_BUILDER(Name("Dequantize") - .Device(DEVICE_CPU) - .TypeConstraint("T") - .TypeConstraint("dtype"), - DequantizeOp); -REGISTER_KERNEL_BUILDER(Name("Dequantize") - .Device(DEVICE_CPU) - .TypeConstraint("T") - .TypeConstraint("dtype"), - DequantizeOp); -REGISTER_KERNEL_BUILDER(Name("Dequantize") - .Device(DEVICE_CPU) - .TypeConstraint("T") - .TypeConstraint("dtype"), - DequantizeOp); } // namespace tensorflow diff --git a/tensorflow/core/kernels/dequantize_op_test.cc b/tensorflow/core/kernels/dequantize_op_test.cc index 3c9d1790787..30e73caf143 100644 --- a/tensorflow/core/kernels/dequantize_op_test.cc +++ b/tensorflow/core/kernels/dequantize_op_test.cc @@ -28,7 +28,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/ops_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -62,9 +61,8 @@ class DequantizeOpTest : public OpsTestBase { // Compares dequantize min vs the same using eigen. This tests that a change // to not use eigen gives equivalent results to using eigen. template - void RunDequantizeMinCombinedTest(float min_range, float max_range, - const string& op_name) { - TF_ASSERT_OK(NodeDefBuilder("dequantize_op", op_name) + void RunDequantizeMinCombinedTest(float min_range, float max_range) { + TF_ASSERT_OK(NodeDefBuilder("dequantize_op", "Dequantize") .Input(FakeInput(DataTypeToEnum::v())) .Input(FakeInput(DT_FLOAT)) .Input(FakeInput(DT_FLOAT)) @@ -89,40 +87,6 @@ class DequantizeOpTest : public OpsTestBase { test::ExpectTensorEqual(expected, *GetOutput(0)); } - // Compares dequantize min vs the same using eigen. This tests that a change - // to not use eigen gives equivalent results to using eigen. - template - void RunDequantizeBfloat16MinCombinedTest(float min_range, float max_range) { - TF_ASSERT_OK(NodeDefBuilder("dequantize_op_bfloat16", "Dequantize") - .Input(FakeInput(DataTypeToEnum::v())) - .Input(FakeInput(DT_FLOAT)) - .Input(FakeInput(DT_FLOAT)) - .Attr("T", DataTypeToEnum::v()) - .Attr("mode", "MIN_COMBINED") - .Attr("dtype", DT_BFLOAT16) - .Finalize(node_def())); - TF_ASSERT_OK(InitOp()); - - std::vector input; - for (int64 i = std::numeric_limits::min(); - i < std::numeric_limits::max(); ++i) { - input.push_back(static_cast(i)); - } - TensorShape shape({static_cast(input.size())}); - AddInputFromArray(shape, input); - AddInputFromArray(TensorShape({}), {min_range}); - AddInputFromArray(TensorShape({}), {max_range}); - TF_ASSERT_OK(RunOpKernel()); - - Tensor expected_float32(allocator(), DT_FLOAT, shape); - ComputeDequantizeMinCombinedUsingEigen(GetInput(0), min_range, max_range, - &expected_float32); - Tensor expected(allocator(), DT_BFLOAT16, shape); - expected.flat() = expected_float32.flat().cast(); - - test::ExpectTensorEqual(expected, *GetOutput(0)); - } - // Creates a tensor with the specified dims, using values chosen from data, // multiplied by (1 + index) along the axis dimension. template @@ -187,29 +151,16 @@ struct ParameterizedDequantizeOpTest public ::testing::WithParamInterface {}; TEST_F(DequantizeOpTest, DequantizeMinCombinedQuint8) { - RunDequantizeMinCombinedTest(0, 255.0f, "Dequantize"); + RunDequantizeMinCombinedTest(0, 255.0f); } TEST_F(DequantizeOpTest, DequantizeMinCombinedQint8) { - RunDequantizeMinCombinedTest(0, 255.0f, "Dequantize"); + RunDequantizeMinCombinedTest(0, 255.0f); } TEST_F(DequantizeOpTest, DequantizeMinCombinedQint16) { - RunDequantizeMinCombinedTest(0, 255.0f, "Dequantize"); + RunDequantizeMinCombinedTest(0, 255.0f); } TEST_F(DequantizeOpTest, DequantizeMinCombinedQuint16) { - RunDequantizeMinCombinedTest(0, 255.0f, "Dequantize"); -} - -TEST_F(DequantizeOpTest, DequantizeBfloat16MinCombinedQuint8) { - RunDequantizeBfloat16MinCombinedTest(0, 255.0f); -} -TEST_F(DequantizeOpTest, DequantizeBfloat16MinCombinedQint8) { - RunDequantizeBfloat16MinCombinedTest(0, 255.0f); -} -TEST_F(DequantizeOpTest, DequantizeBfloat16MinCombinedQint16) { - RunDequantizeBfloat16MinCombinedTest(0, 255.0f); -} -TEST_F(DequantizeOpTest, DequantizeBfloat16MinCombinedQuint16) { - RunDequantizeBfloat16MinCombinedTest(0, 255.0f); + RunDequantizeMinCombinedTest(0, 255.0f); } TEST_F(DequantizeOpTest, DequantizeScaledQuint8Zero) { @@ -251,10 +202,8 @@ static void BM_DequantizeMinCombinedCpu(int iters) { auto root = Scope::NewRootScope().ExitOnError(); const int64 num_values = 1500 * 250; std::vector inputs; - inputs.reserve(num_values); for (int i = 0; i < num_values; ++i) inputs.push_back(i); - ops::Dequantize(root, test::AsTensor(inputs), test::AsScalar(-1.5f), test::AsScalar(20.5f), ops::Dequantize::Attrs().Mode("MIN_COMBINED")); @@ -288,47 +237,5 @@ BENCHMARK(BM_DequantizeMinCombinedCpuQint16); BENCHMARK(BM_DequantizeMinCombinedCpuQuint8); BENCHMARK(BM_DequantizeMinCombinedCpuQint8); -template -static void BM_DequantizeBfloat16MinCombinedCpu(int iters) { - auto root = Scope::NewRootScope().ExitOnError(); - const int64 num_values = 1500 * 250; - std::vector inputs; - - inputs.reserve(num_values); - for (int i = 0; i < num_values; ++i) inputs.push_back(i); - - ops::Dequantize(root, test::AsTensor(inputs), test::AsScalar(-1.5f), - test::AsScalar(20.5f), - ops::Dequantize::Attrs().Dtype(DT_BFLOAT16)); - TF_CHECK_OK(root.status()); - Graph* g = new Graph(OpRegistry::Global()); - TF_CHECK_OK(root.ToGraph(g)); - - test::Benchmark("cpu", g).Run(iters); - testing::BytesProcessed(iters * num_values * (sizeof(bfloat16) + sizeof(T))); - testing::ItemsProcessed(iters); -} - -static void BM_DequantizeBfloat16MinCombinedCpuQuint16(int iters) { - BM_DequantizeBfloat16MinCombinedCpu(iters); -} - -static void BM_DequantizeBfloat16MinCombinedCpuQint16(int iters) { - BM_DequantizeBfloat16MinCombinedCpu(iters); -} - -static void BM_DequantizeBfloat16MinCombinedCpuQuint8(int iters) { - BM_DequantizeBfloat16MinCombinedCpu(iters); -} - -static void BM_DequantizeBfloat16MinCombinedCpuQint8(int iters) { - BM_DequantizeBfloat16MinCombinedCpu(iters); -} - -BENCHMARK(BM_DequantizeBfloat16MinCombinedCpuQuint16); -BENCHMARK(BM_DequantizeBfloat16MinCombinedCpuQint16); -BENCHMARK(BM_DequantizeBfloat16MinCombinedCpuQuint8); -BENCHMARK(BM_DequantizeBfloat16MinCombinedCpuQint8); - } // namespace } // namespace tensorflow diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 60efdcb7a73..a427b8b3967 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -2871,12 +2871,11 @@ REGISTER_OP("Dequantize") .Input("input: T") .Input("min_range: float") .Input("max_range: float") - .Output("output: dtype") + .Output("output: float") .Attr("T: quantizedtype") .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'MIN_COMBINED'") .Attr("narrow_range: bool = false") .Attr("axis: int = -1") - .Attr("dtype: {bfloat16, float} = DT_FLOAT") .SetShapeFn([](InferenceContext* c) { int axis = -1; Status s = c->GetAttr("axis", &axis); diff --git a/tensorflow/core/ops/compat/ops_history_v1/Dequantize.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/Dequantize.pbtxt index f8a161433af..e0a88ff58a2 100644 --- a/tensorflow/core/ops/compat/ops_history_v1/Dequantize.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v1/Dequantize.pbtxt @@ -248,76 +248,3 @@ op { } } } -op { - name: "Dequantize" - input_arg { - name: "input" - type_attr: "T" - } - input_arg { - name: "min_range" - type: DT_FLOAT - } - input_arg { - name: "max_range" - type: DT_FLOAT - } - output_arg { - name: "output" - type_attr: "dtype" - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_QINT8 - type: DT_QUINT8 - type: DT_QINT32 - type: DT_QINT16 - type: DT_QUINT16 - } - } - } - attr { - name: "mode" - type: "string" - default_value { - s: "MIN_COMBINED" - } - allowed_values { - list { - s: "MIN_COMBINED" - s: "MIN_FIRST" - s: "SCALED" - } - } - } - attr { - name: "narrow_range" - type: "bool" - default_value { - b: false - } - } - attr { - name: "axis" - type: "int" - default_value { - i: -1 - } - } - attr { - name: "dtype" - type: "type" - default_value { - type: DT_FLOAT - } - allowed_values { - list { - type: DT_BFLOAT16 - type: DT_FLOAT - } - } - } -} diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 403ea2aee70..53620a897c4 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -4982,8 +4982,7 @@ def dequantize( # pylint: disable=missing-docstring mode="MIN_COMBINED", name=None, axis=None, - narrow_range=False, - dtype=dtypes.float32): + narrow_range=False): if axis is None: axis = -1 elif axis < 0: @@ -4993,17 +4992,10 @@ def dequantize( # pylint: disable=missing-docstring if axis >= 0 or narrow_range: return gen_array_ops.dequantize( - input, - min_range, - max_range, - mode=mode, - name=name, - narrow_range=narrow_range, - axis=axis, - dtype=dtype) + input, min_range, max_range, mode=mode, name=name, + narrow_range=narrow_range, axis=axis) return gen_array_ops.dequantize( - input, min_range, max_range, mode=mode, name=name, dtype=dtype) - + input, min_range, max_range, mode=mode, name=name) dequantize.__doc__ = gen_array_ops.dequantize.__doc__ diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index bcefb835e00..9abecf88b18 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1110,7 +1110,7 @@ tf_module { } member_method { name: "dequantize" - argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\', \'narrow_range\', \'dtype\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\', \'False\', \"\"], " + argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\', \'narrow_range\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\', \'False\'], " } member_method { name: "deserialize_many_sparse" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.quantization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.quantization.pbtxt index 047fb4deda7..7c3ef6a194a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.quantization.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.quantization.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.quantization" tf_module { member_method { name: "dequantize" - argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\', \'narrow_range\', \'dtype\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\', \'False\', \"\"], " + argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\', \'narrow_range\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\', \'False\'], " } member_method { name: "fake_quant_with_min_max_args" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index dc4552d62aa..9791da7c35f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -1082,7 +1082,7 @@ tf_module { } member_method { name: "Dequantize" - argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'narrow_range\', \'axis\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'False\', \'-1\', \"\", \'None\'], " + argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'narrow_range\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'False\', \'-1\', \'None\'], " } member_method { name: "DeserializeIterator" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.quantization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.quantization.pbtxt index 047fb4deda7..7c3ef6a194a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.quantization.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.quantization.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.quantization" tf_module { member_method { name: "dequantize" - argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\', \'narrow_range\', \'dtype\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\', \'False\', \"\"], " + argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\', \'axis\', \'narrow_range\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'None\', \'False\'], " } member_method { name: "fake_quant_with_min_max_args" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index dc4552d62aa..9791da7c35f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -1082,7 +1082,7 @@ tf_module { } member_method { name: "Dequantize" - argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'narrow_range\', \'axis\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'False\', \'-1\', \"\", \'None\'], " + argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'narrow_range\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'False\', \'-1\', \'None\'], " } member_method { name: "DeserializeIterator"