Fix problem in quantized version of Comparison op handler

PiperOrigin-RevId: 215801773
This commit is contained in:
A. Unique TensorFlower 2018-10-04 14:18:22 -07:00 committed by TensorFlower Gardener
parent 4c1da53840
commit a2e48d849f
2 changed files with 16 additions and 11 deletions

View File

@ -66,31 +66,25 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
if (input1->type == kTfLiteUInt8) { \
auto input1_offset = -input1->params.zero_point; \
auto input2_offset = -input2->params.zero_point; \
const int left_shift = 20; \
const double twice_max_input_scale = \
2 * std::max(input1->params.scale, input2->params.scale); \
const double real_input1_multiplier = \
input1->params.scale / twice_max_input_scale; \
const double real_input2_multiplier = \
input2->params.scale / twice_max_input_scale; \
const int left_shift = 8; \
\
int32 input1_multiplier; \
int input1_shift; \
QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, \
QuantizeMultiplierSmallerThanOneExp(input1->params.scale, \
&input1_multiplier, &input1_shift); \
int32 input2_multiplier; \
int input2_shift; \
QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, \
QuantizeMultiplierSmallerThanOneExp(input2->params.scale, \
&input2_multiplier, &input2_shift); \
\
ComparisonParams op_params; \
op_params.left_shift = left_shift; \
op_params.input1_offset = input1_offset; \
op_params.input1_multiplier = input1_multiplier; \
op_params.input1_shift = -input1_shift; \
op_params.input1_shift = input1_shift; \
op_params.input2_offset = input2_offset; \
op_params.input2_multiplier = input2_multiplier; \
op_params.input2_shift = -input2_shift; \
op_params.input2_shift = input2_shift; \
if (requires_broadcast) { \
reference_ops::Broadcast4DSlow##opname##WithScaling( \
op_params, GetTensorShape(input1), GetTensorData<uint8_t>(input1), \

View File

@ -402,6 +402,17 @@ TEST(ComparisonsTest, GreaterQuantized) {
EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
}
TEST(ComparisonsTest, GreaterQuantizedSmallRange) {
ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, 0.0, 1.0},
{TensorType_UINT8, {1, 2, 2, 1}, 0.0, 2.0},
TensorType_UINT8, BuiltinOperator_GREATER);
model.QuantizeAndPopulate<uint8_t>(model.input1(), {1.0, 0.5, 0.35, 0.1});
model.QuantizeAndPopulate<uint8_t>(model.input2(), {1.01, 0.25, 0.3, 0.4});
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
}
TEST(ComparisonsTest, GreaterEqualQuantized) {
const float kMin = -1.f;
const float kMax = 128.f;