Add Int8 support in TFLite Fill op
PiperOrigin-RevId: 351352604 Change-Id: Ie215ab1ade893ab3963934c4de7c26017fb1b00e
This commit is contained in:
parent
915c0b2bbc
commit
49752e0f82
@ -1604,9 +1604,9 @@ def TFL_FillOp: TFL_Op<"fill", [
|
||||
}];
|
||||
|
||||
let arguments = (ins TFL_I32OrI64Tensor:$dims,
|
||||
TFL_TensorOf<[F32, I32, I64, I1, TFL_Str]>:$input);
|
||||
TFL_TensorOf<[F32, I32, I64, I1, QI8, TFL_Str]>:$input);
|
||||
|
||||
let results = (outs TFL_TensorOf<[F32, I32, I64, I1, TFL_Str]>:$result);
|
||||
let results = (outs TFL_TensorOf<[F32, I32, I64, I1, QI8, TFL_Str]>:$result);
|
||||
|
||||
let hasOptions = 0;
|
||||
}
|
||||
|
@ -2505,3 +2505,11 @@ func @testBroadcastToWithI64ShapeTensor(tensor<?x?x?x?x?x?xf32>, tensor<8xi64>)
|
||||
%0 = "tfl.broadcast_to"(%arg0, %arg1): (tensor<?x?x?x?x?x?xf32>, tensor<8xi64>) -> tensor<?x?x?x?x?x?x?x?xf32>
|
||||
return %0 : tensor<?x?x?x?x?x?x?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testFillWithQI8
|
||||
func @testFillWithQI8(%arg0: tensor<1x4xi32>, %arg1: tensor<? x !quant.uniform<i8:f32, 0.1>>) -> tensor<? x !quant.uniform<i8:f32, 0.1>> {
|
||||
%0 = "tfl.fill"(%arg0, %arg1): (tensor<1x4xi32>, tensor<? x !quant.uniform<i8:f32, 0.1>>) -> tensor<? x !quant.uniform<i8:f32, 0.1>>
|
||||
return %0 : tensor<? x !quant.uniform<i8:f32, 0.1>>
|
||||
}
|
||||
|
@ -147,6 +147,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
case kTfLiteString:
|
||||
FillString(value, output);
|
||||
break;
|
||||
case kTfLiteInt8:
|
||||
TF_LITE_FILL(int8_t);
|
||||
break;
|
||||
default:
|
||||
context->ReportError(
|
||||
context,
|
||||
|
@ -136,6 +136,14 @@ TEST(FillOpTest, FillString) {
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2}));
|
||||
}
|
||||
|
||||
TEST_P(FillOpTest, FillInt8) {
|
||||
FillOpModel<int64_t, int8_t> m(TensorType_INT64, {3}, {2, 2, 2}, 5,
|
||||
GetParam());
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5}));
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2}));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(FillOpTest, FillOpTest,
|
||||
::testing::Values(TestType::kConst,
|
||||
TestType::kDynamic));
|
||||
|
Loading…
x
Reference in New Issue
Block a user