SLICE, TRANSPOSE operators are fixed for the given network.

Change-Id: Ia33e6d0bb55273a52e8b7ac6b9226b7175cb8010
This commit is contained in:
Elena Zhelezina 2020-07-27 12:02:34 +01:00
parent f295633406
commit 9ca8d0cdc1
5 changed files with 31 additions and 4 deletions

View File

@ -360,6 +360,7 @@ SelectOpTest/.+,29
-SliceOpTest/SliceOpTest/SliceString/.+
-SliceOpTest/SliceOpTest/SliceInt64/.+
-SliceOpTest/SliceOpTest/SliceBool/.+
-SliceOpTest/SliceOpTest/SliceInt16/.+
# Only constant tensors
SliceOpTest/SliceOpTest/.+/0,29
@ -409,6 +410,7 @@ TopKV2OpTest/TopKV2OpTest/.+/0,29
-TransposeTest/5DDividedIntoTwo2Ds.*
-TransposeTest/Complex5DTest.*
-TransposeTest/.+DynamicTensor
-TransposeTest/TestRefOps4DInt16
TransposeTest/.+
# transpose_conv_test

View File

@ -214,6 +214,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt8:
TF_LITE_SLICE(int8_t, kernel_type);
break;
case kTfLiteInt16:
TF_LITE_SLICE(int16_t, kernel_type);
break;
case kTfLiteUInt8:
TF_LITE_SLICE(uint8_t, kernel_type);
break;

View File

@ -226,6 +226,16 @@ TEST_P(SliceOpTest, SliceInt8) {
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
}
TEST_P(SliceOpTest, SliceInt16) {
SliceOpModel<int16_t, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
{2, 1, -1, 1}, TensorType_INT32,
TensorType_INT16, GetParam());
m.SetInput({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 3, 1}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 3, 3, 5, 5, 5}));
}
TEST_P(SliceOpTest, SliceString) {
SliceOpModel<string, int32_t> m({3, 2, 3, 1}, {4}, {1, 0, 0, 0}, {4},
{2, 1, -1, 1}, TensorType_INT32,

View File

@ -130,6 +130,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_TRANSPOSE(reference_ops, int8_t);
}
break;
case kTfLiteInt16:
TF_LITE_TRANSPOSE(reference_ops, int16_t);
break;
case kTfLiteInt64:
TF_LITE_TRANSPOSE(reference_ops, int64_t);
break;

View File

@ -180,13 +180,14 @@ TEST(TransposeTest, TestRefOps4D) {
ASSERT_EQ(out, ref);
}
TEST(TransposeTest, TestRefOps4DInt8) {
std::vector<int8_t> out;
template<typename T>
void TransposeTestTestRefOps4D() {
std::vector<T> out;
// Basic 4d.
RunTestPermutation({2, 3, 4, 5}, {2, 0, 1, 3}, &out);
ASSERT_EQ(
out,
std::vector<int8_t>(
std::vector<T>(
{0, 1, 2, 3, 4, 20, 21, 22, 23, 24, 40, 41, 42, 43, 44,
60, 61, 62, 63, 64, 80, 81, 82, 83, 84, 100, 101, 102, 103, 104,
5, 6, 7, 8, 9, 25, 26, 27, 28, 29, 45, 46, 47, 48, 49,
@ -197,11 +198,19 @@ TEST(TransposeTest, TestRefOps4DInt8) {
75, 76, 77, 78, 79, 95, 96, 97, 98, 99, 115, 116, 117, 118, 119}));
RunTestPermutation({2, 3, 4, 5}, {0, 1, 2, 3}, &out);
// Basic identity.
std::vector<int8_t> ref(out.size());
std::vector<T> ref(out.size());
for (int k = 0; k < ref.size(); k++) ref[k] = k;
ASSERT_EQ(out, ref);
}
TEST(TransposeTest, TestRefOps4DInt8) {
TransposeTestTestRefOps4D<int8_t>();
}
TEST(TransposeTest, TestRefOps4DInt16) {
TransposeTestTestRefOps4D<int16_t>();
}
class TransposeOpModel : public SingleOpModel {
public:
void SetInput(std::initializer_list<float> data) {