Create int8 strided slice.

PiperOrigin-RevId: 233487928
This commit is contained in:
Jian Li 2019-02-11 15:11:04 -08:00 committed by TensorFlower Gardener
parent c1043a02f9
commit 3a475eb69d
5 changed files with 30 additions and 1 deletions

View File

@ -253,7 +253,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
/* max_version */ 3);
AddBuiltin(BuiltinOperator_SPLIT_V, Register_SPLIT_V());
AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE());
AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE());
AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE(),
/* min_version */ 1,
/* max_version */ 2);
AddBuiltin(BuiltinOperator_EXP, Register_EXP());
AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2(),
/* min_version */ 1,

View File

@ -234,6 +234,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_STRIDED_SLICE(reference_ops, uint8_t);
}
break;
case kTfLiteInt8:
if (kernel_type == kReference) {
TF_LITE_STRIDED_SLICE(reference_ops, int8_t);
}
break;
default:
context->ReportError(context,
"Type %d is currently not supported "

View File

@ -577,6 +577,18 @@ TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1Uint8) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
}
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1int8) {
StridedSliceOpModel<int8_t, TensorType_INT8> m({2, 3, 2}, {3}, {3}, {3}, 0, 0,
0, 0, 1);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetBegin({0, 0, 0});
m.SetEnd({1, 3, 2});
m.SetStrides({1, 1, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
}
} // namespace
} // namespace tflite

View File

@ -1400,6 +1400,12 @@ class StridedSlice
}
int GetVersion(const OperatorSignature& op_signature) const override {
const string& input_name = op_signature.op->inputs[0];
const Array& input_array = op_signature.model->GetArray(input_name);
// If the op take int8 input, it is version 2.
if (input_array.data_type == ArrayDataType::kInt8) {
return 2;
}
return 1;
}
};

View File

@ -808,6 +808,10 @@ TEST_F(OperatorTest, VersioningPackTest) {
SimpleVersioningTest<PackOperator>();
}
TEST_F(OperatorTest, VersioningStridedSliceTest) {
SimpleVersioningTest<StridedSliceOperator>();
}
TEST_F(OperatorTest, VersioningSelectTest) {
SelectOperator select_op;
select_op.inputs = {"input1"};