Add int8 and int16 support for FILL operator
commit 607c596d95
parent df87bbd3bb
@@ -1596,17 +1596,16 @@ shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1]
 def TFL_FillOp: TFL_Op<"fill", [
     NoSideEffect,
     PredOpTrait<"input and result must have same element type",
-      TFL_TCresVTEtIsSameAsOp<0, 1>>,
-    NoQuantizableResult]> {
+      TFL_TCresVTEtIsSameAsOp<0, 1>>]> {
   let summary = "Fill the tensor with given value.";
   let description = [{
 Fill the tensor with given value.
   }];

   let arguments = (ins TFL_I32OrI64Tensor:$dims,
-                   TFL_TensorOf<[F32, I32, I64, I1, QI8, TFL_Str]>:$input);
+                   TFL_TensorOf<[F32, I32, I64, I1, QI8, QI16, TFL_Str]>:$input);

-  let results = (outs TFL_TensorOf<[F32, I32, I64, I1, QI8, TFL_Str]>:$result);
+  let results = (outs TFL_TensorOf<[F32, I32, I64, I1, QI8, QI16, TFL_Str]>:$result);

   let hasOptions = 0;
 }
@@ -92,6 +92,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                     GetOutputSafe(context, node, kOutputTensor, &output));
   output->type = value->type;

+  TF_LITE_ENSURE_EQ(context, output->params.scale, value->params.scale);
+  TF_LITE_ENSURE_EQ(context, output->params.zero_point,
+                    value->params.zero_point);
+
+  if (value->type == kTfLiteInt16) {
+    TF_LITE_ENSURE_EQ(context, value->params.zero_point, 0);
+  }
+
   if (IsConstantTensor(dims)) {
     TF_LITE_ENSURE_OK(context, ResizeOutput(context, dims, output));
   } else {
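The new Prepare checks pin down the quantization contract: the fill value and the output must carry identical scale and zero point, because the kernel copies the raw quantized payload without requantizing, and int16 must additionally be symmetrically quantized (zero point 0), matching TFLite's int16 convention. A minimal sketch of the invariant those TF_LITE_ENSURE_EQ calls protect (standalone C++ with hypothetical helper names, not part of the commit):

    #include <cstdint>

    struct QuantParams { float scale; int32_t zero_point; };

    // real = scale * (q - zero_point); if both tensors share the same
    // params, copying the stored integer q preserves the real value exactly.
    inline float Dequantize(int8_t q, const QuantParams& p) {
      return p.scale * (static_cast<int32_t>(q) - p.zero_point);
    }
    // With value_params == output_params, Dequantize(q, value_params) ==
    // Dequantize(q, output_params) for every q the kernel writes, so Fill
    // can safely replicate the raw value.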
@@ -132,6 +140,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
                    GetTensorShape(output), \
                    GetTensorData<data_type>(output))
   switch (output->type) {
+    case kTfLiteInt8:
+      TF_LITE_FILL(int8_t);
+      break;
+    case kTfLiteInt16:
+      TF_LITE_FILL(int16_t);
+      break;
     case kTfLiteInt32:
       TF_LITE_FILL(int32_t);
       break;
@@ -147,14 +161,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
     case kTfLiteString:
       FillString(value, output);
       break;
-    case kTfLiteInt8:
-      TF_LITE_FILL(int8_t);
-      break;
     default:
       context->ReportError(
           context,
-          "Fill only currently supports int32, int64, float32, bool, string "
-          "for input 1, got %d.",
+          "Fill only currently supports int8, int16, int32, int64, float32, "
+          "bool, string for input 1, got %d.",
           value->type);
       return kTfLiteError;
   }
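In Eval, TF_LITE_FILL expands to a call into the reference fill routine for the given element type; int8 and int16 now dispatch through the same macro as int32, the stale int8 case further down the switch is dropped, and the error message is updated. The operation behind the macro is a scalar broadcast over the output buffer; a simplified sketch (my own paraphrase, not the literal reference_ops implementation):

    // Write a single scalar into every element of a flat output buffer.
    // For quantized types the stored integer payload is copied verbatim;
    // Prepare already guaranteed input/output quantization params match.
    template <typename T>
    void FillScalar(const T& value, T* output, int flat_size) {
      for (int i = 0; i < flat_size; ++i) {
        output[i] = value;
      }
    }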
@@ -73,6 +73,42 @@ class FillOpModel : public SingleOpModel {
   int output_;
 };

+template <typename dims_type, typename quant_type>
+class QuantizedFillOpModel : public SingleOpModel {
+ public:
+  explicit QuantizedFillOpModel(TensorType dims_tensor_type,
+                                std::initializer_list<int> dims_shape,
+                                std::initializer_list<dims_type> dims_data,
+                                const TensorData& tensor_data,
+                                float value) {
+    dims_ = AddInput(dims_tensor_type);
+    value_ = AddInput(tensor_data);
+    output_ = AddOutput(tensor_data);
+    SetBuiltinOp(BuiltinOperator_FILL, BuiltinOptions_FillOptions,
+                 CreateFillOptions(builder_).Union());
+    BuildInterpreter({dims_shape, {}});
+
+    if (dims_data.size() > 0) {
+      PopulateTensor<dims_type>(dims_, dims_data);
+    }
+    QuantizeAndPopulate<quant_type>(value_, {value});
+  }
+
+  std::vector<quant_type> GetOutput() {
+    return ExtractVector<quant_type>(output_);
+  }
+  std::vector<float> GetDequantizedOutput() {
+    TfLiteTensor* t = interpreter_->tensor(output_);
+    return Dequantize(GetOutput(), t->params.scale, t->params.zero_point);
+  }
+  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ protected:
+  int dims_;
+  int value_;
+  int output_;
+};
+
 class FillOpTest : public ::testing::TestWithParam<TestType> {};

 TEST_P(FillOpTest, FillInt32) {
@@ -144,6 +180,42 @@ TEST_P(FillOpTest, FillInt8) {
   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2}));
 }

+template <typename quant_type>
+void QuantizedFill(float value) {
+  // Prepare TensorData for quantization of value
+  const float kMin = -1;
+  // Workaround to get a zero-point of 0
+  const float kMax =
+      std::numeric_limits<quant_type>::max() /
+      static_cast<float>(std::numeric_limits<quant_type>::max() + 1);
+  const TensorData tensor_data(GetTensorType<quant_type>(), {},
+                               std::abs(value) * kMin, std::abs(value) * kMax);
+
+  QuantizedFillOpModel<int32_t, quant_type> m(TensorType_INT32, {2}, {2, 3},
+                                              tensor_data, value);
+  m.Invoke();
+
+  constexpr float epsilon = 0.01f;
+  const float min_value = tensor_data.min - epsilon;
+  const float max_value = tensor_data.max + epsilon;
+  const float kQuantizedTolerance =
+      (max_value - min_value) / (std::numeric_limits<quant_type>::max() -
+                                 std::numeric_limits<quant_type>::min());
+  EXPECT_THAT(
+      m.GetDequantizedOutput(),
+      ElementsAreArray(ArrayFloatNear(
+          {value, value, value, value, value, value}, kQuantizedTolerance)));
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3}));
+}
+
+TEST(FillOpTest, QuantizedFillInt8) {
+  QuantizedFill<int8_t>(3.14f);
+}
+
+TEST(FillOpTest, QuantizedFillInt16) {
+  QuantizedFill<int16_t>(3.14f);
+}
+
 INSTANTIATE_TEST_SUITE_P(FillOpTest, FillOpTest,
                          ::testing::Values(TestType::kConst,
                                            TestType::kDynamic));
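The kMin/kMax choice in QuantizedFill is the "workaround to get a zero-point of 0" mentioned in the comment: for a signed N-bit type with qmin = -2^(N-1) and qmax = 2^(N-1) - 1, taking min = -|v| and max = |v| * qmax / (qmax + 1) yields scale = (max - min) / (qmax - qmin) = |v| / 2^(N-1), and therefore zero_point = qmin - min / scale = -2^(N-1) + 2^(N-1) = 0, which is exactly what the int16 path requires. A quick standalone check of that arithmetic for int8 (my own verification code, not from the commit):

    #include <cstdio>

    int main() {
      const float v = 3.14f;                     // test fill value
      const float min = -v;                      // kMin * |v|
      const float max = v * 127.0f / 128.0f;     // kMax * |v| for int8
      const float scale = (max - min) / 255.0f;  // qmax - qmin = 255
      const float zero_point = -128.0f - min / scale;
      std::printf("scale=%f zero_point=%f\n", scale, zero_point);  // prints 0
      return 0;
    }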
@@ -270,7 +270,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
              /* max_version = */ 2);
   AddBuiltin(BuiltinOperator_FILL, Register_FILL(),
              /* min_version = */ 1,
-             /* max_version = */ 2);
+             /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_MIRROR_PAD, Register_MIRROR_PAD(),
              /* min_version = */ 1,
              /* max_version = */ 2);
@@ -432,7 +432,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
              /* max_version = */ 2);
   AddBuiltin(BuiltinOperator_FILL, Register_FILL(),
              /* min_version = */ 1,
-             /* max_version = */ 2);
+             /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_MIRROR_PAD, Register_MIRROR_PAD(),
              /* min_version = */ 1,
              /* max_version = */ 2);
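Both resolver hunks (the default BuiltinOpResolver and the reference-kernel BuiltinRefOpResolver) bump FILL's max_version from 2 to 3 so that models serialized with the new int8/int16 FILL can resolve to a kernel. Conceptually, AddBuiltin registers the same kernel for every version in [min_version, max_version]; a toy sketch of that idea (hypothetical simplified registry, not the real OpResolver API):

    #include <map>
    #include <utility>

    struct Registration {};  // stand-in for TfLiteRegistration

    // Keyed by (builtin op code, op version).
    std::map<std::pair<int, int>, Registration> registry;

    void AddBuiltinSketch(int op, const Registration& reg,
                          int min_version, int max_version) {
      for (int v = min_version; v <= max_version; ++v) {
        registry[{op, v}] = reg;  // FILL now resolves for versions 1..3
      }
    }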
@@ -192,6 +192,13 @@ OperatorProperty GetOperatorProperty(OpVariant op_variant) {
       property.outputs = {{0, {}}};
       property.version = 1;
       break;
+    case BuiltinOperator_FILL: {
+      property.inputs = {{1, {}}};
+      property.outputs = {{0, {}}};
+      property.restrict_same_input_output_scale = true;
+      property.version = 3;
+      break;
+    }
     case BuiltinOperator_FULLY_CONNECTED: {
       TensorProperty tensor_property;
       tensor_property.symmetric = true;
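For the quantization tooling, this property entry declares that only input index 1 (the fill value; index 0 is the integer dims tensor) participates in quantization, that the output must share the input's scale, and that the quantized op is written out as version 3. The scale restriction is the tooling-side twin of the TF_LITE_ENSURE_EQ checks in Prepare; a rough sketch of what it implies for a quantizer (hypothetical helper, field name taken from the diff):

    struct QParams { float scale; int zero_point; };

    // If restrict_same_input_output_scale is set, the output tensor should
    // inherit the value tensor's parameters rather than deriving its own
    // from recorded output statistics.
    QParams ChooseFillOutputParams(const QParams& value_params) {
      return value_params;
    }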
@@ -538,10 +538,14 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
       return 1;

     case BuiltinOperator_FILL:
-      if (op_sig.input_types.size() >= 2 &&
-          (op_sig.input_types.at(1) == TensorType_BOOL ||
-           op_sig.input_types.at(1) == TensorType_STRING)) {
-        return 2;
+      if (op_sig.input_types.size() >= 2) {
+        if (op_sig.input_types.at(1) == TensorType_INT8 ||
+            op_sig.input_types.at(1) == TensorType_INT16) {
+          return 3;
+        } else if ((op_sig.input_types.at(1) == TensorType_BOOL ||
+                    op_sig.input_types.at(1) == TensorType_STRING)) {
+          return 2;
+        }
       }
       return 1;

@@ -700,7 +700,15 @@ TEST(OpVersionTest, VersioningDivTest) {
 TEST(OpVersionTest, VersioningFillTest) {
   OpSignature fake_op_sig = {.op = BuiltinOperator_FILL,
                              .input_types = std::vector<TensorType>{
-                                 TensorType_INT32, TensorType_BOOL}};
+                                 TensorType_INT32, TensorType_INT8}};
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+  fake_op_sig = {.op = BuiltinOperator_FILL,
+                 .input_types = std::vector<TensorType>{TensorType_INT64,
+                                                        TensorType_INT16}};
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+  fake_op_sig = {.op = BuiltinOperator_FILL,
+                 .input_types = std::vector<TensorType>{TensorType_INT32,
+                                                        TensorType_BOOL}};
   EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
   fake_op_sig = {.op = BuiltinOperator_FILL,
                  .input_types = std::vector<TensorType>{TensorType_INT32,
@@ -332,6 +332,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
           {{BuiltinOperator_HARD_SWISH, 1}, "1.15.0"},
           {{BuiltinOperator_FILL, 1}, "1.13.0"},
           {{BuiltinOperator_FILL, 2}, "2.3.0"},
+          {{BuiltinOperator_FILL, 3}, kPendingReleaseVersion},
           {{BuiltinOperator_REVERSE_V2, 1}, "1.14.0"},
           {{BuiltinOperator_REVERSE_V2, 2}, "2.2.0"},
           {{BuiltinOperator_REVERSE_V2, 3}, kPendingReleaseVersion},