Create int8 strided slice.
PiperOrigin-RevId: 233487928
This commit is contained in:
parent
c1043a02f9
commit
3a475eb69d
@ -253,7 +253,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
/* max_version */ 3);
|
||||
AddBuiltin(BuiltinOperator_SPLIT_V, Register_SPLIT_V());
|
||||
AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE());
|
||||
AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE());
|
||||
AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE(),
|
||||
/* min_version */ 1,
|
||||
/* max_version */ 2);
|
||||
AddBuiltin(BuiltinOperator_EXP, Register_EXP());
|
||||
AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2(),
|
||||
/* min_version */ 1,
|
||||
|
@ -234,6 +234,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_STRIDED_SLICE(reference_ops, uint8_t);
|
||||
}
|
||||
break;
|
||||
case kTfLiteInt8:
|
||||
if (kernel_type == kReference) {
|
||||
TF_LITE_STRIDED_SLICE(reference_ops, int8_t);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
context->ReportError(context,
|
||||
"Type %d is currently not supported "
|
||||
|
@ -577,6 +577,18 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1Uint8) {
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
|
||||
}
|
||||
|
||||
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1int8) {
|
||||
StridedSliceOpModel<int8_t, TensorType_INT8> m({2, 3, 2}, {3}, {3}, {3}, 0, 0,
|
||||
0, 0, 1);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
m.SetBegin({0, 0, 0});
|
||||
m.SetEnd({1, 3, 2});
|
||||
m.SetStrides({1, 1, 1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
|
||||
}
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
|
@ -1400,6 +1400,12 @@ class StridedSlice
|
||||
}
|
||||
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
const string& input_name = op_signature.op->inputs[0];
|
||||
const Array& input_array = op_signature.model->GetArray(input_name);
|
||||
// If the op take int8 input, it is version 2.
|
||||
if (input_array.data_type == ArrayDataType::kInt8) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
@ -808,6 +808,10 @@ TEST_F(OperatorTest, VersioningPackTest) {
|
||||
SimpleVersioningTest<PackOperator>();
|
||||
}
|
||||
|
||||
TEST_F(OperatorTest, VersioningStridedSliceTest) {
|
||||
SimpleVersioningTest<StridedSliceOperator>();
|
||||
}
|
||||
|
||||
TEST_F(OperatorTest, VersioningSelectTest) {
|
||||
SelectOperator select_op;
|
||||
select_op.inputs = {"input1"};
|
||||
|
Loading…
x
Reference in New Issue
Block a user