Fix problem in quantized version of Comparison op handler
PiperOrigin-RevId: 215801773
This commit is contained in:
parent
4c1da53840
commit
a2e48d849f
@ -66,31 +66,25 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
if (input1->type == kTfLiteUInt8) { \
|
||||
auto input1_offset = -input1->params.zero_point; \
|
||||
auto input2_offset = -input2->params.zero_point; \
|
||||
const int left_shift = 20; \
|
||||
const double twice_max_input_scale = \
|
||||
2 * std::max(input1->params.scale, input2->params.scale); \
|
||||
const double real_input1_multiplier = \
|
||||
input1->params.scale / twice_max_input_scale; \
|
||||
const double real_input2_multiplier = \
|
||||
input2->params.scale / twice_max_input_scale; \
|
||||
const int left_shift = 8; \
|
||||
\
|
||||
int32 input1_multiplier; \
|
||||
int input1_shift; \
|
||||
QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, \
|
||||
QuantizeMultiplierSmallerThanOneExp(input1->params.scale, \
|
||||
&input1_multiplier, &input1_shift); \
|
||||
int32 input2_multiplier; \
|
||||
int input2_shift; \
|
||||
QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, \
|
||||
QuantizeMultiplierSmallerThanOneExp(input2->params.scale, \
|
||||
&input2_multiplier, &input2_shift); \
|
||||
\
|
||||
ComparisonParams op_params; \
|
||||
op_params.left_shift = left_shift; \
|
||||
op_params.input1_offset = input1_offset; \
|
||||
op_params.input1_multiplier = input1_multiplier; \
|
||||
op_params.input1_shift = -input1_shift; \
|
||||
op_params.input1_shift = input1_shift; \
|
||||
op_params.input2_offset = input2_offset; \
|
||||
op_params.input2_multiplier = input2_multiplier; \
|
||||
op_params.input2_shift = -input2_shift; \
|
||||
op_params.input2_shift = input2_shift; \
|
||||
if (requires_broadcast) { \
|
||||
reference_ops::Broadcast4DSlow##opname##WithScaling( \
|
||||
op_params, GetTensorShape(input1), GetTensorData<uint8_t>(input1), \
|
||||
|
@ -402,6 +402,17 @@ TEST(ComparisonsTest, GreaterQuantized) {
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
|
||||
}
|
||||
|
||||
TEST(ComparisonsTest, GreaterQuantizedSmallRange) {
|
||||
ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, 0.0, 1.0},
|
||||
{TensorType_UINT8, {1, 2, 2, 1}, 0.0, 2.0},
|
||||
TensorType_UINT8, BuiltinOperator_GREATER);
|
||||
model.QuantizeAndPopulate<uint8_t>(model.input1(), {1.0, 0.5, 0.35, 0.1});
|
||||
model.QuantizeAndPopulate<uint8_t>(model.input2(), {1.01, 0.25, 0.3, 0.4});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
|
||||
}
|
||||
|
||||
TEST(ComparisonsTest, GreaterEqualQuantized) {
|
||||
const float kMin = -1.f;
|
||||
const float kMax = 128.f;
|
||||
|
Loading…
x
Reference in New Issue
Block a user