Add Int8 support in TFLite Fill op

PiperOrigin-RevId: 351352604
Change-Id: Ie215ab1ade893ab3963934c4de7c26017fb1b00e
This commit is contained in:
Thai Nguyen 2021-01-12 05:41:11 -08:00 committed by TensorFlower Gardener
parent 915c0b2bbc
commit 49752e0f82
4 changed files with 21 additions and 2 deletions

View File

@ -1604,9 +1604,9 @@ def TFL_FillOp: TFL_Op<"fill", [
}];
let arguments = (ins TFL_I32OrI64Tensor:$dims,
TFL_TensorOf<[F32, I32, I64, I1, TFL_Str]>:$input);
TFL_TensorOf<[F32, I32, I64, I1, QI8, TFL_Str]>:$input);
let results = (outs TFL_TensorOf<[F32, I32, I64, I1, TFL_Str]>:$result);
let results = (outs TFL_TensorOf<[F32, I32, I64, I1, QI8, TFL_Str]>:$result);
let hasOptions = 0;
}

View File

@ -2505,3 +2505,11 @@ func @testBroadcastToWithI64ShapeTensor(tensor<?x?x?x?x?x?xf32>, tensor<8xi64>)
%0 = "tfl.broadcast_to"(%arg0, %arg1): (tensor<?x?x?x?x?x?xf32>, tensor<8xi64>) -> tensor<?x?x?x?x?x?x?x?xf32>
return %0 : tensor<?x?x?x?x?x?x?x?xf32>
}
// -----
// CHECK-LABEL: testFillWithQI8
func @testFillWithQI8(%arg0: tensor<1x4xi32>, %arg1: tensor<? x !quant.uniform<i8:f32, 0.1>>) -> tensor<? x !quant.uniform<i8:f32, 0.1>> {
%0 = "tfl.fill"(%arg0, %arg1): (tensor<1x4xi32>, tensor<? x !quant.uniform<i8:f32, 0.1>>) -> tensor<? x !quant.uniform<i8:f32, 0.1>>
return %0 : tensor<? x !quant.uniform<i8:f32, 0.1>>
}

View File

@ -147,6 +147,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteString:
FillString(value, output);
break;
case kTfLiteInt8:
TF_LITE_FILL(int8_t);
break;
default:
context->ReportError(
context,

View File

@ -136,6 +136,14 @@ TEST(FillOpTest, FillString) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2}));
}
TEST_P(FillOpTest, FillInt8) {
FillOpModel<int64_t, int8_t> m(TensorType_INT64, {3}, {2, 2, 2}, 5,
GetParam());
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5}));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2}));
}
INSTANTIATE_TEST_SUITE_P(FillOpTest, FillOpTest,
::testing::Values(TestType::kConst,
TestType::kDynamic));