Mul int8 support.
PiperOrigin-RevId: 234382662
Commit 130aebaa1e (parent 355cc566ef)
@@ -312,6 +312,7 @@ cc_library(
         "reference/integer_ops/dequantize.h",
         "reference/integer_ops/fully_connected.h",
         "reference/integer_ops/logistic.h",
+        "reference/integer_ops/mul.h",
         "reference/integer_ops/pooling.h",
         "reference/integer_ops/softmax.h",
         "reference/integer_ops/tanh.h",
tensorflow/lite/kernels/internal/reference/integer_ops/mul.h (new file, 130 lines)
@@ -0,0 +1,130 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MUL_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MUL_H_

#include "public/gemmlowp.h"
#include "tensorflow/lite/kernels/internal/common.h"

namespace tflite {
namespace reference_integer_ops {

inline void MulElementwise(int size, const ArithmeticParams& params,
                           const int8_t* input1_data,
                           const int8_t* input2_data, int8_t* output_data) {
  for (int i = 0; i < size; ++i) {
    const int32 input1_val = params.input1_offset + input1_data[i];
    const int32 input2_val = params.input2_offset + input2_data[i];
    const int32 unclamped_result =
        params.output_offset +
        MultiplyByQuantizedMultiplierSmallerThanOneExp(input1_val * input2_val,
                                                       params.output_multiplier,
                                                       params.output_shift);
    const int32 clamped_output =
        std::min(params.quantized_activation_max,
                 std::max(params.quantized_activation_min, unclamped_result));
    output_data[i] = static_cast<int8_t>(clamped_output);
  }
}

inline void Mul(const ArithmeticParams& params,
                const RuntimeShape& input1_shape, const int8_t* input1_data,
                const RuntimeShape& input2_shape, const int8_t* input2_data,
                const RuntimeShape& output_shape, int8_t* output_data) {
  TFLITE_DCHECK_LE(params.quantized_activation_min,
                   params.quantized_activation_max);
  gemmlowp::ScopedProfilingLabel label("Mul/8bit");
  const int flat_size =
      MatchingFlatSize(input1_shape, input2_shape, output_shape);

  MulElementwise(flat_size, params, input1_data, input2_data, output_data);
}

// Mul with 16 bit inputs and int8_t outputs.
inline void Mul(const ArithmeticParams& params,
                const RuntimeShape& input1_shape, const int16* input1_data,
                const RuntimeShape& input2_shape, const int16* input2_data,
                const RuntimeShape& output_shape, int8_t* output_data) {
  gemmlowp::ScopedProfilingLabel label("Mul/Int16Int8");
  int32 output_offset = params.output_offset;
  int32 output_activation_min = params.quantized_activation_min;
  int32 output_activation_max = params.quantized_activation_max;
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);

  const int flat_size =
      MatchingFlatSize(input1_shape, input2_shape, output_shape);

  for (int i = 0; i < flat_size; i++) {
    // F0 uses 0 integer bits, range [-1, 1].
    using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;

    F0 unclamped_result =
        F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
    int16 rescaled_result =
        gemmlowp::RoundingDivideByPOT(unclamped_result.raw(), 8);
    int16 clamped_result =
        std::min<int16>(output_activation_max - output_offset, rescaled_result);
    clamped_result =
        std::max<int16>(output_activation_min - output_offset, clamped_result);
    output_data[i] = output_offset + clamped_result;
  }
}

inline void BroadcastMul4DSlow(const ArithmeticParams& params,
                               const RuntimeShape& input1_shape,
                               const int8_t* input1_data,
                               const RuntimeShape& input2_shape,
                               const int8_t* input2_data,
                               const RuntimeShape& output_shape,
                               int8_t* output_data) {
  gemmlowp::ScopedProfilingLabel label("BroadcastMul4DSlow/8bit");

  NdArrayDesc<4> desc1;
  NdArrayDesc<4> desc2;
  // The input shapes are extended as part of NdArrayDesc initialization.
  NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
                                      &desc2);
  const RuntimeShape extended_output_shape =
      RuntimeShape::ExtendedShape(4, output_shape);

  for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
    for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
      for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
        for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
          const int32 input1_val =
              params.input1_offset +
              input1_data[SubscriptToIndex(desc1, b, y, x, c)];
          const int32 input2_val =
              params.input2_offset +
              input2_data[SubscriptToIndex(desc2, b, y, x, c)];
          const int32 unclamped_result =
              params.output_offset +
              MultiplyByQuantizedMultiplierSmallerThanOneExp(
                  input1_val * input2_val, params.output_multiplier,
                  params.output_shift);
          const int32 clamped_output = std::min(
              params.quantized_activation_max,
              std::max(params.quantized_activation_min, unclamped_result));
          output_data[Offset(extended_output_shape, b, y, x, c)] =
              static_cast<int8_t>(clamped_output);
        }
      }
    }
  }
}

}  // namespace reference_integer_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MUL_H_
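Annotation (illustration only, not part of the commit): the int8 element-wise path follows the usual affine-quantization recipe: subtract the zero points, multiply in 32-bit, rescale by input1_scale * input2_scale / output_scale, add back the output zero point, and clamp to [-128, 127]. The self-contained sketch below mirrors those steps with a plain double rescale in place of the fixed-point MultiplyByQuantizedMultiplierSmallerThanOneExp helper; all scales, zero points, and input values here are made-up assumptions.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical quantization parameters, chosen only for this example.
  const double input1_scale = 0.008, input2_scale = 0.008, output_scale = 0.008;
  const int32_t input1_zero_point = 0, input2_zero_point = 0, output_zero_point = 0;
  const int8_t q1 = 100, q2 = 25;  // roughly 0.8 and 0.2 in real terms

  // Same steps as MulElementwise: remove zero points (in the kernel,
  // params.input1_offset holds the negated zero point), multiply in 32 bits,
  // rescale, re-add the output zero point, and clamp to the int8 range.
  const int32_t v1 = q1 - input1_zero_point;
  const int32_t v2 = q2 - input2_zero_point;
  const double real_multiplier = input1_scale * input2_scale / output_scale;
  const int32_t rescaled =
      static_cast<int32_t>(std::lround(v1 * v2 * real_multiplier));
  const int32_t clamped =
      std::min<int32_t>(127, std::max<int32_t>(-128, rescaled + output_zero_point));
  // Prints 20, i.e. about 0.16 real, matching 0.8 * 0.2.
  std::printf("quantized product = %d (about %.3f real)\n", clamped,
              clamped * output_scale);
  return 0;
}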
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/mul.h"
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/c_api_internal.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
@@ -87,8 +88,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                                   &data->output_activation_min,
                                   &data->output_activation_max);
   }
+  if (output->type == kTfLiteInt8) {
+    CalculateActivationRangeInt8(params->activation, output,
+                                 &data->output_activation_min,
+                                 &data->output_activation_max);
+  }

-  if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) {
+  if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 ||
+      output->type == kTfLiteInt16) {
     double real_multiplier =
         input1->params.scale * input2->params.scale / output->params.scale;
     QuantizeMultiplierSmallerThanOneExp(
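Annotation (illustration only): QuantizeMultiplierSmallerThanOneExp encodes the real multiplier above as a normalized 32-bit fixed-point value plus a shift, so the kernel needs no floating point at run time. A rough standalone sketch of that decomposition, an approximation and not the library's actual implementation, could look like this:

#include <cmath>
#include <cstdint>
#include <cstdio>

// Decompose real_multiplier (assumed to be in (0, 1)) into a Q31 fixed-point
// multiplier and a left shift so that
//   real_multiplier ~= quantized_multiplier * 2^(-31) * 2^(left_shift).
void DecomposeMultiplierSketch(double real_multiplier,
                               int32_t* quantized_multiplier, int* left_shift) {
  const double q = std::frexp(real_multiplier, left_shift);  // q in [0.5, 1)
  int64_t q_fixed = static_cast<int64_t>(std::llround(q * (1LL << 31)));
  if (q_fixed == (1LL << 31)) {  // rounding pushed q up to exactly 1.0
    q_fixed /= 2;
    ++*left_shift;
  }
  *quantized_multiplier = static_cast<int32_t>(q_fixed);
}

int main() {
  int32_t multiplier = 0;
  int shift = 0;
  // Assumed example value: 0.008 decomposes into 0.512 * 2^-6,
  // i.e. multiplier 1099511628 with shift -6.
  DecomposeMultiplierSketch(0.008, &multiplier, &shift);
  std::printf("multiplier=%d shift=%d\n", static_cast<int>(multiplier), shift);
  return 0;
}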
@@ -151,8 +158,8 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                            TfLiteMulParams* params, const OpData* data,
                            const TfLiteTensor* input1,
                            const TfLiteTensor* input2, TfLiteTensor* output) {
-  if (input1->type == kTfLiteUInt8 && input2->type == kTfLiteUInt8 &&
-      output->type == kTfLiteUInt8) {
+  if (input1->type == input2->type && input1->type == output->type &&
+      (input1->type == kTfLiteUInt8 || input1->type == kTfLiteInt8)) {
     tflite::ArithmeticParams op_params;
     SetActivationParams(data->output_activation_min,
                         data->output_activation_max, &op_params);
@@ -163,23 +170,31 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
     op_params.output_shift = data->output_shift;
     bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
         GetTensorShape(input1), GetTensorShape(input2), &op_params);
-#define TF_LITE_MUL(type, opname)                                      \
+#define TF_LITE_MUL(type, opname, dtype)                               \
   type::opname(op_params, GetTensorShape(input1),                      \
-               GetTensorData<uint8_t>(input1), GetTensorShape(input2), \
-               GetTensorData<uint8_t>(input2), GetTensorShape(output), \
-               GetTensorData<uint8_t>(output))
+               GetTensorData<dtype>(input1), GetTensorShape(input2),   \
+               GetTensorData<dtype>(input2), GetTensorShape(output),   \
+               GetTensorData<dtype>(output))
+    if (input1->type == kTfLiteInt8) {
+      if (need_broadcast) {
+        TF_LITE_MUL(reference_integer_ops, BroadcastMul4DSlow, int8_t);
+      } else {
+        TF_LITE_MUL(reference_integer_ops, Mul, int8_t);
+      }
+    } else {
+      // type == kTfLiteUInt8
       if (kernel_type == kReference) {
         if (need_broadcast) {
-          TF_LITE_MUL(reference_ops, BroadcastMul4DSlow);
+          TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, uint8_t);
         } else {
-          TF_LITE_MUL(reference_ops, Mul);
+          TF_LITE_MUL(reference_ops, Mul, uint8_t);
         }
       } else {
         if (need_broadcast) {
-          TF_LITE_MUL(optimized_ops, BroadcastMulFivefold);
+          TF_LITE_MUL(optimized_ops, BroadcastMulFivefold, uint8_t);
         } else {
-          TF_LITE_MUL(optimized_ops, Mul);
+          TF_LITE_MUL(optimized_ops, Mul, uint8_t);
         }
       }
+    }
 #undef TF_LITE_MUL
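Annotation (not part of the diff): with the extra dtype argument, the new int8 branch reaches the kernel added in mul.h; for example, TF_LITE_MUL(reference_integer_ops, Mul, int8_t) expands to roughly the following call (a fragment shown for readability, modulo macro whitespace):

reference_integer_ops::Mul(op_params, GetTensorShape(input1),
                           GetTensorData<int8_t>(input1), GetTensorShape(input2),
                           GetTensorData<int8_t>(input2), GetTensorShape(output),
                           GetTensorData<int8_t>(output));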
@@ -198,8 +213,8 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
     }
 #undef TF_LITE_MUL
   } else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 &&
-             output->type == kTfLiteUInt8) {
-#define TF_LITE_MUL(type, opname)                                      \
+             (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8)) {
+#define TF_LITE_MUL(type, opname, output_dtype)                        \
   tflite::ArithmeticParams op_params;                                  \
   SetActivationParams(data->output_activation_min,                     \
                       data->output_activation_max, &op_params);        \
@@ -207,11 +222,15 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
   type::opname(op_params, GetTensorShape(input1),                      \
                GetTensorData<int16_t>(input1), GetTensorShape(input2), \
                GetTensorData<int16_t>(input2), GetTensorShape(output), \
-               GetTensorData<uint8_t>(output))
-    if (kernel_type == kReference) {
-      TF_LITE_MUL(reference_ops, Mul);
+               GetTensorData<output_dtype>(output))
+    if (output->type == kTfLiteInt8) {
+      TF_LITE_MUL(reference_integer_ops, Mul, int8_t);
     } else {
-      TF_LITE_MUL(optimized_ops, Mul);
+      if (kernel_type == kReference) {
+        TF_LITE_MUL(reference_ops, Mul, uint8_t);
+      } else {
+        TF_LITE_MUL(optimized_ops, Mul, uint8_t);
+      }
     }
 #undef TF_LITE_MUL
   } else {
@@ -233,14 +252,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {

   if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
     EvalMul<kernel_type>(context, node, params, data, input1, input2, output);
-  } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) {
+  } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 ||
+             output->type == kTfLiteInt16) {
     TF_LITE_ENSURE_OK(
         context, EvalQuantized<kernel_type>(context, node, params, data, input1,
                                             input2, output));
   } else {
     context->ReportError(context,
-                         "Mul only supports FLOAT32, INT32 and quantized UINT8 "
-                         "and INT16 now, got %d.",
+                         "Mul only supports FLOAT32, INT32 and quantized UINT8,"
+                         " INT8 and INT16 now, got %d.",
                          output->type);
     return kTfLiteError;
   }
@@ -73,8 +73,9 @@ class QuantizedMulOpModel : public BaseMulOpModel {
  public:
   using BaseMulOpModel::BaseMulOpModel;

+  template <typename integer_dtype>
   std::vector<float> GetDequantizedOutput() {
-    return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
-                               GetScale(output_), GetZeroPoint(output_));
+    return Dequantize<integer_dtype>(ExtractVector<integer_dtype>(output_),
+                                     GetScale(output_), GetZeroPoint(output_));
   }

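Annotation (illustration only): GetDequantizedOutput relies on the test utility Dequantize<integer_dtype>, which maps a quantized value back to a real number as scale * (q - zero_point). A minimal standalone equivalent, with made-up quantized values that only roughly correspond to the expected outputs in the tests below:

#include <cstdint>
#include <cstdio>
#include <vector>

// Minimal stand-in for the test's Dequantize<integer_dtype> helper:
// real_value = scale * (quantized_value - zero_point).
template <typename T>
std::vector<float> DequantizeSketch(const std::vector<T>& quantized, float scale,
                                    int32_t zero_point) {
  std::vector<float> real;
  real.reserve(quantized.size());
  for (T v : quantized) {
    real.push_back(scale * (static_cast<int32_t>(v) - zero_point));
  }
  return real;
}

int main() {
  // Assumed int8 quantization of roughly [-1, 1): scale ~ 1/128, zero point 0.
  const std::vector<int8_t> quantized = {-61, 10, 104, 72};
  for (float r : DequantizeSketch(quantized, 1.0f / 128.0f, 0)) {
    std::printf("%f ", r);  // about -0.48, 0.08, 0.81, 0.56
  }
  std::printf("\n");
  return 0;
}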
@@ -191,19 +192,28 @@ TEST(IntegerMulOpTest, WithBroadcast) {
   }
 }

-TEST(QuantizedMulOpTest, NoActivation) {
-  QuantizedMulOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0},
-                        {TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0},
-                        {TensorType_UINT8, {}, -1.0, 1.0},
+template <TensorType tensor_type, typename integer_dtype>
+void NoActivation() {
+  QuantizedMulOpModel m({tensor_type, {1, 2, 2, 1}, -1.0, 1.0},
+                        {tensor_type, {1, 2, 2, 1}, -1.0, 1.0},
+                        {tensor_type, {}, -1.0, 1.0},
                         ActivationFunctionType_NONE);
-  m.QuantizeAndPopulate<uint8_t>(m.input1(), {-0.8, 0.2, 0.9, 0.7});
-  m.QuantizeAndPopulate<uint8_t>(m.input2(), {0.6, 0.4, 0.9, 0.8});
+  m.QuantizeAndPopulate<integer_dtype>(m.input1(), {-0.8, 0.2, 0.9, 0.7});
+  m.QuantizeAndPopulate<integer_dtype>(m.input2(), {0.6, 0.4, 0.9, 0.8});
   m.Invoke();
-  EXPECT_THAT(m.GetDequantizedOutput(),
+  EXPECT_THAT(m.GetDequantizedOutput<integer_dtype>(),
              ElementsAreArray(ArrayFloatNear({-0.48, 0.08, 0.81, 0.56},
                                              kQuantizedTolerance)));
 }

+TEST(QuantizedMulOpTest, NoActivationUInt8) {
+  NoActivation<TensorType_UINT8, uint8_t>();
+}
+
+TEST(QuantizedMulOpTest, NoActivationInt8) {
+  NoActivation<TensorType_INT8, int8_t>();
+}
+
 TEST(QuantizedMulOpTest, NoActivationInt16) {
   const float kMin = -1.f;
   const float kMax = 32767.f / 32768.f;
@@ -219,23 +229,32 @@ TEST(QuantizedMulOpTest, NoActivationInt16) {
                                               kQuantizedToleranceInt16)));
 }

-TEST(QuantizedMulOpTest, NoActivationInt16WithUint8Output) {
+template <TensorType tensor_type, typename integer_dtype>
+void NoActivationInt16With8BitOutput() {
   const float kMinInt16 = -1.f;
   const float kMaxInt16 = 32767.f / 32768.f;
   const float kMinUint8 = -1.f;
   const float kMaxUint8 = 127.f / 128.f;
   QuantizedMulOpModel m({TensorType_INT16, {1, 2, 2, 1}, kMinInt16, kMaxInt16},
                         {TensorType_INT16, {1, 2, 2, 1}, kMinInt16, kMaxInt16},
-                        {TensorType_UINT8, {}, kMinUint8, kMaxUint8},
+                        {tensor_type, {}, kMinUint8, kMaxUint8},
                         ActivationFunctionType_NONE);
   m.QuantizeAndPopulate<int16_t>(m.input1(), {-0.8, 0.2, 0.9, 0.7});
   m.QuantizeAndPopulate<int16_t>(m.input2(), {0.6, 0.4, 0.9, 0.8});
   m.Invoke();
-  EXPECT_THAT(m.GetDequantizedOutput(),
+  EXPECT_THAT(m.GetDequantizedOutput<integer_dtype>(),
              ElementsAreArray(ArrayFloatNear({-0.48, 0.08, 0.81, 0.56},
                                              kQuantizedTolerance)));
 }

+TEST(QuantizedMulOpTest, NoActivationInt16WithUint8Output) {
+  NoActivationInt16With8BitOutput<TensorType_UINT8, uint8_t>();
+}
+
+TEST(QuantizedMulOpTest, NoActivationInt16Withint8Output) {
+  NoActivationInt16With8BitOutput<TensorType_INT8, int8_t>();
+}
+
 // for quantized Mul, the error shouldn't exceed 2*step
 float GetTolerance(int min, int max) {
   float kQuantizedStep = (max - min) / 255.0;
@@ -243,25 +262,35 @@ float GetTolerance(int min, int max) {
   return kQuantizedTolerance;
 }

-TEST(QuantizedMulOpTest, WithBroadcast) {
+template <TensorType tensor_type, typename integer_dtype>
+void WithBroadcast() {
   float kQuantizedTolerance = GetTolerance(-3.0, 3.0);
   std::vector<std::vector<int>> test_shapes = {
       {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
   for (int i = 0; i < test_shapes.size(); ++i) {
-    QuantizedMulOpModel m({TensorType_UINT8, test_shapes[i], -3.0, 3.0},
-                          {TensorType_UINT8, {}, -3.0, 3.0},  // always a scalar
-                          {TensorType_UINT8, {}, -3.0, 3.0},
+    QuantizedMulOpModel m({tensor_type, test_shapes[i], -3.0, 3.0},
+                          {tensor_type, {}, -3.0, 3.0},  // always a scalar
+                          {tensor_type, {}, -3.0, 3.0},
                           ActivationFunctionType_NONE);
-    m.QuantizeAndPopulate<uint8_t>(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0});
-    m.QuantizeAndPopulate<uint8_t>(m.input2(), {0.1});
+    m.QuantizeAndPopulate<integer_dtype>(m.input1(),
+                                         {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0});
+    m.QuantizeAndPopulate<integer_dtype>(m.input2(), {0.1});
     m.Invoke();
-    EXPECT_THAT(m.GetDequantizedOutput(),
+    EXPECT_THAT(m.GetDequantizedOutput<integer_dtype>(),
                ElementsAreArray(ArrayFloatNear(
                    {-0.2, 0.02, 0.07, 0.08, 0.11, 0.2}, kQuantizedTolerance)))
        << "With shape number " << i;
   }
 }

+TEST(QuantizedMulOpTest, WithBroadcastUInt8) {
+  WithBroadcast<TensorType_UINT8, uint8_t>();
+}
+
+TEST(QuantizedMulOpTest, WithBroadcastInt8) {
+  WithBroadcast<TensorType_INT8, int8_t>();
+}
+
 }  // namespace
 }  // namespace tflite

@@ -764,6 +764,12 @@ class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
   }

   int GetVersion(const OperatorSignature& op_signature) const override {
+    const string& input_name = op_signature.op->inputs[0];
+    const Array& input_array = op_signature.model->GetArray(input_name);
+    // Version 2 supports signed int8 input types.
+    if (input_array.data_type == ArrayDataType::kInt8) {
+      return 2;
+    }
     return 1;
   }
 };
@@ -841,6 +841,8 @@ TEST_F(OperatorTest, VersioningAddTest) { SimpleVersioningTest<AddOperator>(); }

 TEST_F(OperatorTest, VersioningSubTest) { SimpleVersioningTest<SubOperator>(); }

+TEST_F(OperatorTest, VersioningMulTest) { SimpleVersioningTest<MulOperator>(); }
+
 TEST_F(OperatorTest, VersioningPadTest) { SimpleVersioningTest<PadOperator>(); }

 TEST_F(OperatorTest, VersioningPadV2Test) {