From a2e48d849f5c7a97b788ba8d2499e95aaef95945 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 4 Oct 2018 14:18:22 -0700 Subject: [PATCH] Fix problem in quantized version of Comparison op handler PiperOrigin-RevId: 215801773 --- tensorflow/contrib/lite/kernels/comparisons.cc | 16 +++++----------- .../contrib/lite/kernels/comparisons_test.cc | 11 +++++++++++ 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc index f765235e040..3926af5b973 100644 --- a/tensorflow/contrib/lite/kernels/comparisons.cc +++ b/tensorflow/contrib/lite/kernels/comparisons.cc @@ -66,31 +66,25 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) { if (input1->type == kTfLiteUInt8) { \ auto input1_offset = -input1->params.zero_point; \ auto input2_offset = -input2->params.zero_point; \ - const int left_shift = 20; \ - const double twice_max_input_scale = \ - 2 * std::max(input1->params.scale, input2->params.scale); \ - const double real_input1_multiplier = \ - input1->params.scale / twice_max_input_scale; \ - const double real_input2_multiplier = \ - input2->params.scale / twice_max_input_scale; \ + const int left_shift = 8; \ \ int32 input1_multiplier; \ int input1_shift; \ - QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, \ + QuantizeMultiplierSmallerThanOneExp(input1->params.scale, \ &input1_multiplier, &input1_shift); \ int32 input2_multiplier; \ int input2_shift; \ - QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, \ + QuantizeMultiplierSmallerThanOneExp(input2->params.scale, \ &input2_multiplier, &input2_shift); \ \ ComparisonParams op_params; \ op_params.left_shift = left_shift; \ op_params.input1_offset = input1_offset; \ op_params.input1_multiplier = input1_multiplier; \ - op_params.input1_shift = -input1_shift; \ + op_params.input1_shift = input1_shift; \ op_params.input2_offset = input2_offset; \ op_params.input2_multiplier = input2_multiplier; \ - op_params.input2_shift = -input2_shift; \ + op_params.input2_shift = input2_shift; \ if (requires_broadcast) { \ reference_ops::Broadcast4DSlow##opname##WithScaling( \ op_params, GetTensorShape(input1), GetTensorData(input1), \ diff --git a/tensorflow/contrib/lite/kernels/comparisons_test.cc b/tensorflow/contrib/lite/kernels/comparisons_test.cc index 67a91c17fd4..04c8bf2e301 100644 --- a/tensorflow/contrib/lite/kernels/comparisons_test.cc +++ b/tensorflow/contrib/lite/kernels/comparisons_test.cc @@ -402,6 +402,17 @@ TEST(ComparisonsTest, GreaterQuantized) { EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false)); } +TEST(ComparisonsTest, GreaterQuantizedSmallRange) { + ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, 0.0, 1.0}, + {TensorType_UINT8, {1, 2, 2, 1}, 0.0, 2.0}, + TensorType_UINT8, BuiltinOperator_GREATER); + model.QuantizeAndPopulate(model.input1(), {1.0, 0.5, 0.35, 0.1}); + model.QuantizeAndPopulate(model.input2(), {1.01, 0.25, 0.3, 0.4}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false)); +} + TEST(ComparisonsTest, GreaterEqualQuantized) { const float kMin = -1.f; const float kMax = 128.f;