Add int8 and int16 support for FILL operator

Benjamin Klimczak 2020-11-10 12:07:46 +00:00
parent df87bbd3bb
commit 607c596d95
9 changed files with 118 additions and 16 deletions

View File

@@ -1596,17 +1596,16 @@ shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1]
def TFL_FillOp: TFL_Op<"fill", [
NoSideEffect,
PredOpTrait<"input and result must have same element type",
- TFL_TCresVTEtIsSameAsOp<0, 1>>,
- NoQuantizableResult]> {
+ TFL_TCresVTEtIsSameAsOp<0, 1>>]> {
let summary = "Fill the tensor with given value.";
let description = [{
Fill the tensor with given value.
}];
let arguments = (ins TFL_I32OrI64Tensor:$dims,
- TFL_TensorOf<[F32, I32, I64, I1, QI8, TFL_Str]>:$input);
+ TFL_TensorOf<[F32, I32, I64, I1, QI8, QI16, TFL_Str]>:$input);
- let results = (outs TFL_TensorOf<[F32, I32, I64, I1, QI8, TFL_Str]>:$result);
+ let results = (outs TFL_TensorOf<[F32, I32, I64, I1, QI8, QI16, TFL_Str]>:$result);
let hasOptions = 0;
}

View File

@@ -92,6 +92,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
GetOutputSafe(context, node, kOutputTensor, &output));
output->type = value->type;
TF_LITE_ENSURE_EQ(context, output->params.scale, value->params.scale);
TF_LITE_ENSURE_EQ(context, output->params.zero_point,
value->params.zero_point);
if (value->type == kTfLiteInt16) {
TF_LITE_ENSURE_EQ(context, value->params.zero_point, 0);
}
if (IsConstantTensor(dims)) {
TF_LITE_ENSURE_OK(context, ResizeOutput(context, dims, output));
} else {
@@ -132,6 +140,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTensorShape(output), \
GetTensorData<data_type>(output))
switch (output->type) {
case kTfLiteInt8:
TF_LITE_FILL(int8_t);
break;
case kTfLiteInt16:
TF_LITE_FILL(int16_t);
break;
case kTfLiteInt32:
TF_LITE_FILL(int32_t);
break;
@@ -147,14 +161,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteString:
FillString(value, output);
break;
- case kTfLiteInt8:
- TF_LITE_FILL(int8_t);
- break;
default:
context->ReportError(
context,
"Fill only currently supports int32, int64, float32, bool, string "
"for input 1, got %d.",
"Fill only currently supports int8, int16, int32, int64, float32, "
"bool, string for input 1, got %d.",
value->type);
return kTfLiteError;
}
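For reference, the two new switch cases dispatch through the same TF_LITE_FILL macro as the existing integer types, so the shared reference fill implementation simply copies the already-quantized scalar from the value tensor into every output element. That is only valid because the new Prepare() checks require the output to carry the same scale and zero point as the value tensor, and because int16 values must be symmetrically quantized (zero point 0). A rough standalone sketch of the idea, not the actual kernel code:

  #include <cstdint>
  #include <vector>

  // Illustrative sketch: once the output shares the value tensor's scale and
  // zero point, filling a quantized tensor is a plain element copy of the raw
  // int8/int16 value; no requantization is needed. Every output element then
  // dequantizes to scale * (q - zero_point), i.e. exactly the fill value.
  template <typename T>
  void FillQuantized(T quantized_value, std::vector<T>& output) {
    for (T& element : output) {
      element = quantized_value;
    }
  }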

View File

@@ -73,6 +73,42 @@ class FillOpModel : public SingleOpModel {
int output_;
};
template <typename dims_type, typename quant_type>
class QuantizedFillOpModel : public SingleOpModel {
public:
explicit QuantizedFillOpModel(TensorType dims_tensor_type,
std::initializer_list<int> dims_shape,
std::initializer_list<dims_type> dims_data,
const TensorData& tensor_data,
float value) {
dims_ = AddInput(dims_tensor_type);
value_ = AddInput(tensor_data);
output_ = AddOutput(tensor_data);
SetBuiltinOp(BuiltinOperator_FILL, BuiltinOptions_FillOptions,
CreateFillOptions(builder_).Union());
BuildInterpreter({dims_shape, {}});
if (dims_data.size() > 0) {
PopulateTensor<dims_type>(dims_, dims_data);
}
QuantizeAndPopulate<quant_type>(value_, {value});
}
std::vector<quant_type> GetOutput() {
return ExtractVector<quant_type>(output_);
}
std::vector<float> GetDequantizedOutput() {
TfLiteTensor* t = interpreter_->tensor(output_);
return Dequantize(GetOutput(), t->params.scale, t->params.zero_point);
}
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
protected:
int dims_;
int value_;
int output_;
};
class FillOpTest : public ::testing::TestWithParam<TestType> {};
TEST_P(FillOpTest, FillInt32) {
@@ -144,6 +180,42 @@ TEST_P(FillOpTest, FillInt8) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2}));
}
template <typename quant_type>
void QuantizedFill(float value) {
// Prepare TensorData for quantization of value
const float kMin = -1;
// Workaround to get a zero-point of 0
const float kMax =
std::numeric_limits<quant_type>::max() /
static_cast<float>(std::numeric_limits<quant_type>::max() + 1);
const TensorData tensor_data(GetTensorType<quant_type>(), {},
std::abs(value) * kMin, std::abs(value) * kMax);
QuantizedFillOpModel<int32_t, quant_type> m(TensorType_INT32, {2}, {2, 3},
tensor_data, value);
m.Invoke();
constexpr float epsilon = 0.01f;
const float min_value = tensor_data.min - epsilon;
const float max_value = tensor_data.max + epsilon;
const float kQuantizedTolerance =
(max_value - min_value) / (std::numeric_limits<quant_type>::max() -
std::numeric_limits<quant_type>::min());
EXPECT_THAT(
m.GetDequantizedOutput(),
ElementsAreArray(ArrayFloatNear(
{value, value, value, value, value, value}, kQuantizedTolerance)));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3}));
}
TEST(FillOpTest, QuantizedFillInt8) {
QuantizedFill<int8_t>(3.14f);
}
TEST(FillOpTest, QuantizedFillInt16) {
QuantizedFill<int16_t>(3.14f);
}
INSTANTIATE_TEST_SUITE_P(FillOpTest, FillOpTest,
::testing::Values(TestType::kConst,
TestType::kDynamic));
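The kMin/kMax workaround in QuantizedFill above picks the range [-|value|, |value| * qmax / (qmax + 1)] so that the usual min/max-to-parameters mapping yields a zero point of exactly 0, which the int16 path requires. A small sketch of that arithmetic for int8, assuming the common formula scale = (max - min) / (qmax - qmin), zero_point = qmin - round(min / scale) (the test harness's own parameter derivation may differ in details):

  #include <cmath>
  #include <cstdio>

  int main() {
    const float value = 3.14f;
    const float qmin = -128.0f, qmax = 127.0f;      // int8 limits
    const float min = -value;                       // value * kMin, kMin = -1
    const float max = value * (qmax / (qmax + 1));  // value * kMax, kMax = 127/128
    const float scale = (max - min) / (qmax - qmin);              // == value / 128
    const int zero_point =
        static_cast<int>(qmin - std::round(min / scale));         // == -128 - (-128) == 0
    std::printf("scale=%f zero_point=%d\n", scale, zero_point);
    return 0;
  }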

View File

@@ -270,7 +270,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_FILL, Register_FILL(),
/* min_version = */ 1,
- /* max_version = */ 2);
+ /* max_version = */ 3);
AddBuiltin(BuiltinOperator_MIRROR_PAD, Register_MIRROR_PAD(),
/* min_version = */ 1,
/* max_version = */ 2);

View File

@@ -432,7 +432,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_FILL, Register_FILL(),
/* min_version = */ 1,
- /* max_version = */ 2);
+ /* max_version = */ 3);
AddBuiltin(BuiltinOperator_MIRROR_PAD, Register_MIRROR_PAD(),
/* min_version = */ 1,
/* max_version = */ 2);

View File

@@ -192,6 +192,13 @@ OperatorProperty GetOperatorProperty(OpVariant op_variant) {
property.outputs = {{0, {}}};
property.version = 1;
break;
case BuiltinOperator_FILL: {
property.inputs = {{1, {}}};
property.outputs = {{0, {}}};
property.restrict_same_input_output_scale = true;
property.version = 3;
break;
}
case BuiltinOperator_FULLY_CONNECTED: {
TensorProperty tensor_property;
tensor_property.symmetric = true;
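The restrict_same_input_output_scale flag added for BuiltinOperator_FILL above tells the quantization tooling that input 1 (the fill value) and the output must end up with identical quantization parameters, which is exactly what the kernel's new Prepare() checks assume. A hypothetical illustration of the constraint, not the actual tooling API:

  struct QuantParams {
    float scale;
    int zero_point;
  };

  // Hypothetical helper: under restrict_same_input_output_scale, the output
  // tensor of FILL simply reuses the fill value's quantization parameters.
  QuantParams OutputParamsForFill(const QuantParams& value_params) {
    return value_params;
  }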

View File

@@ -538,10 +538,14 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
return 1;
case BuiltinOperator_FILL:
- if (op_sig.input_types.size() >= 2 &&
- (op_sig.input_types.at(1) == TensorType_BOOL ||
- op_sig.input_types.at(1) == TensorType_STRING)) {
- return 2;
+ if (op_sig.input_types.size() >= 2) {
+ if (op_sig.input_types.at(1) == TensorType_INT8 ||
+ op_sig.input_types.at(1) == TensorType_INT16) {
+ return 3;
+ } else if ((op_sig.input_types.at(1) == TensorType_BOOL ||
+ op_sig.input_types.at(1) == TensorType_STRING)) {
+ return 2;
+ }
}
return 1;

View File

@@ -700,7 +700,15 @@ TEST(OpVersionTest, VersioningDivTest) {
TEST(OpVersionTest, VersioningFillTest) {
OpSignature fake_op_sig = {.op = BuiltinOperator_FILL,
.input_types = std::vector<TensorType>{
- TensorType_INT32, TensorType_BOOL}};
+ TensorType_INT32, TensorType_INT8}};
+ EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+ fake_op_sig = {.op = BuiltinOperator_FILL,
+ .input_types = std::vector<TensorType>{TensorType_INT64,
+ TensorType_INT16}};
+ EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+ fake_op_sig = {.op = BuiltinOperator_FILL,
+ .input_types = std::vector<TensorType>{TensorType_INT32,
+ TensorType_BOOL}};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
fake_op_sig = {.op = BuiltinOperator_FILL,
.input_types = std::vector<TensorType>{TensorType_INT32,

View File

@@ -332,6 +332,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_HARD_SWISH, 1}, "1.15.0"},
{{BuiltinOperator_FILL, 1}, "1.13.0"},
{{BuiltinOperator_FILL, 2}, "2.3.0"},
{{BuiltinOperator_FILL, 3}, kPendingReleaseVersion},
{{BuiltinOperator_REVERSE_V2, 1}, "1.14.0"},
{{BuiltinOperator_REVERSE_V2, 2}, "2.2.0"},
{{BuiltinOperator_REVERSE_V2, 3}, kPendingReleaseVersion},