Add 5D support for TFLite Transpose

PiperOrigin-RevId: 304139938
Change-Id: Id938689744d1f6b3623fd8d1ed63594c4afec950
Thai Nguyen 2020-04-01 01:59:31 -07:00 committed by TensorFlower Gardener
parent 031804922d
commit 97b63ed876
13 changed files with 202 additions and 95 deletions


@@ -357,7 +357,9 @@ TopKV2OpTest/TopKV2OpTest/.+/0,29
# transpose_test
# death test
-TransposeTest/Test5DInputTensor
-TransposeTest/Test6DInputTensor
-TransposeTest/5DDividedIntoTwo2Ds.*
-TransposeTest/Complex5DTest.*
-TransposeTest/.+DynamicTensor
TransposeTest/.+


@@ -7694,7 +7694,7 @@ inline void Transpose3D(const TransposeParams& params,
}
}
template <typename T>
template <typename T, int N>
void TransposeImpl(const TransposeParams& params,
const RuntimeShape& input_shape, const T* input_data,
const RuntimeShape& output_shape, T* output_data) {
@@ -7725,19 +7725,19 @@ void TransposeImpl(const TransposeParams& params,
// Reroute to the reference version if an optimized method for the given data
// is not available.
reference_ops::Transpose(params, input_shape, input_data, output_shape,
output_data);
reference_ops::Transpose<T, N>(params, input_shape, input_data, output_shape,
output_data);
}
template <typename T>
template <typename T, int N = 5>
void Transpose(const TransposeParams& unshrinked_params,
const RuntimeShape& unshrinked_input_shape, const T* input_data,
const RuntimeShape& unshrinked_output_shape, T* output_data) {
ruy::profiler::ScopeLabel label("Transpose");
const int output_size = unshrinked_output_shape.DimensionsCount();
TFLITE_DCHECK_LE(unshrinked_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(output_size, 4);
TFLITE_DCHECK_LE(unshrinked_input_shape.DimensionsCount(), N);
TFLITE_DCHECK_LE(output_size, N);
TFLITE_DCHECK_EQ(output_size, unshrinked_params.perm_count);
RuntimeShape shrinked_input_shape = RuntimeShape(unshrinked_input_shape);
@@ -7778,15 +7778,16 @@ void Transpose(const TransposeParams& unshrinked_params,
TFLITE_DCHECK_NE(non_flatten_params.perm[0], 0);
for (int i = 0; i < total_size; i += non_flatten_size) {
TransposeImpl(non_flatten_params, non_flatten_input_shape, input_data + i,
non_flatten_output_shape, output_data + i);
TransposeImpl<T, N>(non_flatten_params, non_flatten_input_shape,
input_data + i, non_flatten_output_shape,
output_data + i);
}
return;
}
// Call non-flattened case.
TransposeImpl(shrinked_params, shrinked_input_shape, input_data,
shrinked_output_shape, output_data);
TransposeImpl<T, N>(shrinked_params, shrinked_input_shape, input_data,
shrinked_output_shape, output_data);
}
} // namespace optimized_ops
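
The optimized path threads a compile-time rank N (now defaulting to 5) through Transpose and TransposeImpl, but before dispatching it shrinks the problem by dropping size-1 dimensions and remapping the permutation, which is why the DCHECKs above compare against N rather than a literal 4. The following standalone sketch illustrates only that shrinking step; ShrinkShapeAndPerm is a hypothetical helper for illustration, not the actual code inside optimized_ops::Transpose.

#include <cstdio>
#include <vector>

// Drops size-1 dimensions from a shape and remaps the permutation to the
// surviving axes, so a nominally 5D transpose with trivial axes can be
// handled by a lower-rank kernel (e.g. Transpose3D above).
void ShrinkShapeAndPerm(const std::vector<int>& shape,
                        const std::vector<int>& perm,
                        std::vector<int>* out_shape,
                        std::vector<int>* out_perm) {
  std::vector<int> new_axis(shape.size(), -1);
  int kept = 0;
  for (int d = 0; d < static_cast<int>(shape.size()); ++d) {
    if (shape[d] != 1) {
      new_axis[d] = kept++;
      out_shape->push_back(shape[d]);
    }
  }
  for (int p : perm) {
    if (new_axis[p] != -1) out_perm->push_back(new_axis[p]);
  }
}

int main() {
  const std::vector<int> shape = {2, 1, 3, 1, 4};
  const std::vector<int> perm = {4, 1, 0, 3, 2};
  std::vector<int> s, p;
  ShrinkShapeAndPerm(shape, perm, &s, &p);
  // Prints "shape: 2 3 4 perm: 2 0 1": the 5D transpose collapses to 3D.
  printf("shape:");
  for (int v : s) printf(" %d", v);
  printf(" perm:");
  for (int v : p) printf(" %d", v);
  printf("\n");
}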


@@ -60,7 +60,6 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace reference_ops {
@@ -1991,60 +1990,54 @@ inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data);
}
template <typename T>
template <typename T, int N>
void TransposeImpl(const TransposeParams& params,
const RuntimeShape& unextended_input_shape,
const T* input_data,
const RuntimeShape& unextended_output_shape,
T* output_data) {
const int unextended_input_size = unextended_input_shape.DimensionsCount();
const int unextended_output_size = unextended_output_shape.DimensionsCount();
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_size, 4);
TFLITE_DCHECK_LE(unextended_input_size, N);
TFLITE_DCHECK_LE(unextended_output_size, N);
TFLITE_DCHECK_EQ(unextended_output_size, params.perm_count);
const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
const int input_ext_size = 4 - unextended_input_shape.DimensionsCount();
const int output_ext_size = 4 - unextended_output_size;
const int input_ext_size = N - unextended_input_size;
const int output_ext_size = N - unextended_output_size;
NdArrayDesc<N> input_desc;
NdArrayDesc<N> output_desc;
CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_input_shape),
&input_desc);
CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
&output_desc);
// The perm data is extended to match the output, each index incremented by
// the amount of front padding of the input shape.
int extended_perm[4];
for (int i = 0; i < output_ext_size; ++i) {
extended_perm[i] = i;
}
for (int i = 0; i < unextended_output_size; ++i) {
extended_perm[i + output_ext_size] = params.perm[i] + input_ext_size;
int extended_perm[N];
for (int i = 0; i < N; ++i) {
extended_perm[i] = i < output_ext_size
? i
: params.perm[i - output_ext_size] + input_ext_size;
}
int out_sizes[4];
// Compute the inverse permutation array so we can do an output centered
// transpose. Also, check to make sure output_dims is matching input_dims.
for (int k = 0; k < 4; k++) {
out_sizes[k] = MatchingDim(input_shape, extended_perm[k], output_shape, k);
// Permutes the input shape so we don't need to permute the indexes inside
// the loop. Check to make sure output_dims is matching input_dims.
NdArrayDesc<N> perm_input_desc;
for (int k = 0; k < N; ++k) {
TFLITE_DCHECK_EQ(input_desc.extents[extended_perm[k]],
output_desc.extents[k]);
perm_input_desc.extents[k] = input_desc.extents[extended_perm[k]];
perm_input_desc.strides[k] = input_desc.strides[extended_perm[k]];
}
// Naive transpose loop (iterate on output index and compute input index).
int o[4]; // loop index (on output).
int i[4];
for (o[3] = 0; o[3] < out_sizes[3]; o[3]++) {
i[extended_perm[3]] = o[3];
for (o[2] = 0; o[2] < out_sizes[2]; o[2]++) {
i[extended_perm[2]] = o[2];
for (o[1] = 0; o[1] < out_sizes[1]; o[1]++) {
i[extended_perm[1]] = o[1];
for (o[0] = 0; o[0] < out_sizes[0]; o[0]++) {
i[extended_perm[0]] = o[0];
output_data[Offset(output_shape, o)] =
input_data[Offset(input_shape, i)];
}
}
}
}
auto transpose_func = [&](int indexes[N]) {
output_data[SubscriptToIndex(output_desc, indexes)] =
input_data[SubscriptToIndex(perm_input_desc, indexes)];
};
NDOpsHelper<N>(output_desc, transpose_func);
}
template <typename T>
template <typename T, int N = 5>
void Transpose(const TransposeParams& params,
const RuntimeShape& unextended_input_shape, const T* input_data,
const RuntimeShape& unextended_output_shape, T* output_data) {
@@ -2053,29 +2046,29 @@ void Transpose(const TransposeParams& params,
// keeps the total code size in a reasonable range.
switch (sizeof(T)) {
case 1:
TransposeImpl<int8_t>(params, unextended_input_shape,
reinterpret_cast<const int8_t*>(input_data),
unextended_output_shape,
reinterpret_cast<int8_t*>(output_data));
TransposeImpl<int8_t, N>(params, unextended_input_shape,
reinterpret_cast<const int8_t*>(input_data),
unextended_output_shape,
reinterpret_cast<int8_t*>(output_data));
break;
case 2:
TransposeImpl<int16_t>(params, unextended_input_shape,
reinterpret_cast<const int16_t*>(input_data),
unextended_output_shape,
reinterpret_cast<int16_t*>(output_data));
TransposeImpl<int16_t, N>(params, unextended_input_shape,
reinterpret_cast<const int16_t*>(input_data),
unextended_output_shape,
reinterpret_cast<int16_t*>(output_data));
break;
case 4:
TransposeImpl<int32_t>(params, unextended_input_shape,
reinterpret_cast<const int32_t*>(input_data),
unextended_output_shape,
reinterpret_cast<int32_t*>(output_data));
TransposeImpl<int32_t, N>(params, unextended_input_shape,
reinterpret_cast<const int32_t*>(input_data),
unextended_output_shape,
reinterpret_cast<int32_t*>(output_data));
break;
case 8:
TransposeImpl<int64_t>(params, unextended_input_shape,
reinterpret_cast<const int64_t*>(input_data),
unextended_output_shape,
reinterpret_cast<int64_t*>(output_data));
TransposeImpl<int64_t, N>(params, unextended_input_shape,
reinterpret_cast<const int64_t*>(input_data),
unextended_output_shape,
reinterpret_cast<int64_t*>(output_data));
break;
}
}
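
The rewritten reference kernel drops the fixed four-deep loop nest in favor of a rank-generic walk: both shapes are extended to rank N, the permutation is extended to match, and the input descriptor's extents and strides are pre-permuted so that a single pass over output coordinates addresses both tensors through SubscriptToIndex. A self-contained sketch of the same stride-permutation technique follows; TransposeND is illustrative and does not use the TFLite descriptor types.

#include <cstdio>
#include <vector>

// Rank-generic transpose: permute the input's row-major strides once up
// front, then walk output coordinates odometer-style, so the inner loop
// needs no per-element index permutation (the role of perm_input_desc
// in the diff above).
template <typename T>
void TransposeND(const std::vector<int>& in_shape,
                 const std::vector<int>& perm, const T* in, T* out) {
  const int n = static_cast<int>(in_shape.size());
  std::vector<int> in_strides(n, 1);
  for (int d = n - 2; d >= 0; --d)
    in_strides[d] = in_strides[d + 1] * in_shape[d + 1];
  std::vector<int> out_shape(n), perm_strides(n);
  int total = 1;
  for (int d = 0; d < n; ++d) {
    out_shape[d] = in_shape[perm[d]];
    perm_strides[d] = in_strides[perm[d]];
    total *= in_shape[d];
  }
  std::vector<int> idx(n, 0);  // Current output coordinate.
  for (int linear = 0; linear < total; ++linear) {
    int src = 0;
    for (int d = 0; d < n; ++d) src += idx[d] * perm_strides[d];
    out[linear] = in[src];  // Output is written in row-major order.
    for (int d = n - 1; d >= 0; --d) {  // Advance the odometer.
      if (++idx[d] < out_shape[d]) break;
      idx[d] = 0;
    }
  }
}

int main() {
  // Same 5D case as the tests below: shape {2,3,2,2,2}, perm {1,4,2,3,0}.
  const std::vector<int> shape = {2, 3, 2, 2, 2};
  const std::vector<int> perm = {1, 4, 2, 3, 0};
  std::vector<float> in(48), out(48);
  for (int i = 0; i < 48; ++i) in[i] = static_cast<float>(i);
  TransposeND(shape, perm, in.data(), out.data());
  printf("out[0..3]: %g %g %g %g\n", out[0], out[1], out[2], out[3]);
}

The sizeof(T) switch above works for a related reason: transposition only rearranges values, so one instantiation per element width serves every type of that size and keeps code size down.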


@@ -1077,7 +1077,7 @@ struct TanhParams {
struct TransposeParams {
int8 perm_count;
int32 perm[4];
int32 perm[5];
};
struct UnpackParams {


@@ -132,7 +132,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
/* max_version */ 3);
AddBuiltin(BuiltinOperator_TRANSPOSE, Register_TRANSPOSE(),
/* min_version */ 1,
/* max_version */ 3);
/* max_version */ 4);
AddBuiltin(BuiltinOperator_MEAN, Register_MEAN(),
/* min_version */ 1,
/* max_version */ 2);


@@ -76,8 +76,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TransposeContext op_context(context, node);
// Ensure validity of input tensor.
TF_LITE_ENSURE_MSG(context, NumDimensions(op_context.input) <= 4,
"Transpose op only supports 1D-4D input arrays.");
TF_LITE_ENSURE_MSG(context, NumDimensions(op_context.input) <= 5,
"Transpose op only supports 1D-5D input arrays.");
TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
if (!IsConstantTensor(op_context.perm)) {


@@ -225,7 +225,7 @@ class TransposeOpConstModel : public TransposeOpModel {
TransposeOpConstModel(std::initializer_list<int> input_shape,
std::initializer_list<int> perm_shape,
std::initializer_list<int> perm) {
input_ = AddInput(TensorType_FLOAT32);
input_ = AddInput({TensorType_FLOAT32, input_shape});
perm_ = AddConstInput(TensorType_INT32, perm, perm_shape);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions,
@@ -507,10 +507,43 @@ TEST(TransposeTest, 4DDividedIntoTwo2DsThird) {
EXPECT_EQ(m.GetOutput(), out);
}
TEST(TransposeTest, 5DDividedIntoTwo2DsOne) {
std::vector<float> out;
RunTestPermutation({2, 3, 2, 2, 2}, {1, 4, 2, 3, 0}, &out);
TransposeOpConstModel m({2, 3, 2, 2, 2}, {5}, {1, 4, 2, 3, 0});
m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47});
m.Invoke();
EXPECT_EQ(m.GetOutput(), out);
}
TEST(TransposeTest, 5DDividedIntoTwo2DsTwo) {
std::vector<float> out;
RunTestPermutation({2, 3, 2, 2, 2}, {2, 3, 0, 4, 1}, &out);
TransposeOpConstModel m({2, 3, 2, 2, 2}, {5}, {2, 3, 0, 4, 1});
m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47});
m.Invoke();
EXPECT_EQ(m.GetOutput(), out);
}
TEST(TransposeTest, 5DDividedIntoTwo2DsThird) {
std::vector<float> out;
RunTestPermutation({2, 3, 2, 2, 2}, {3, 0, 4, 1, 2}, &out);
TransposeOpConstModel m({2, 3, 2, 2, 2}, {5}, {3, 0, 4, 1, 2});
m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47});
m.Invoke();
EXPECT_EQ(m.GetOutput(), out);
}
#ifdef GTEST_HAS_DEATH_TEST
TEST(TransposeTest, Test5DInputTensor) {
EXPECT_DEATH(TransposeOpConstModel({1, 2, 3, 4, 5}, {5}, {0, 1, 2, 3, 4}),
"Transpose op only supports 1D-4D input arrays.");
TEST(TransposeTest, Test6DInputTensor) {
EXPECT_DEATH(TransposeOpConstModel({1, 2, 3, 4, 5, 6}, {5}, {0, 1, 2, 3, 4}),
"Transpose op only supports 1D-5D input arrays.");
}
#endif
@@ -594,5 +627,60 @@ TEST(TransposeTest, ComplexTestWithReorderDynamicTensor) {
EXPECT_THAT(m.GetOutput(), result);
}
TEST(TransposeTest, Complex5DTestWithReorderConstTensor) {
TransposeOpConstModel m({2, 3, 2, 2, 5}, {5}, {2, 0, 1, 4, 3});
m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95,
96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107,
108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 3, 5, 2}));
auto result = ElementsAreArray(
{0, 5, 1, 6, 2, 7, 3, 8, 4, 9, 20, 25, 21, 26, 22,
27, 23, 28, 24, 29, 40, 45, 41, 46, 42, 47, 43, 48, 44, 49,
60, 65, 61, 66, 62, 67, 63, 68, 64, 69, 80, 85, 81, 86, 82,
87, 83, 88, 84, 89, 100, 105, 101, 106, 102, 107, 103, 108, 104, 109,
10, 15, 11, 16, 12, 17, 13, 18, 14, 19, 30, 35, 31, 36, 32,
37, 33, 38, 34, 39, 50, 55, 51, 56, 52, 57, 53, 58, 54, 59,
70, 75, 71, 76, 72, 77, 73, 78, 74, 79, 90, 95, 91, 96, 92,
97, 93, 98, 94, 99, 110, 115, 111, 116, 112, 117, 113, 118, 114, 119});
EXPECT_THAT(m.GetOutput(), result);
}
TEST(TransposeTest, Complex5DTestWithReorderDynamicTensor) {
TransposeOpDynamicModel m({2, 3, 2, 2, 5}, {5});
m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95,
96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107,
108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119});
m.SetPerm({2, 0, 1, 4, 3});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 3, 5, 2}));
auto result = ElementsAreArray(
{0, 5, 1, 6, 2, 7, 3, 8, 4, 9, 20, 25, 21, 26, 22,
27, 23, 28, 24, 29, 40, 45, 41, 46, 42, 47, 43, 48, 44, 49,
60, 65, 61, 66, 62, 67, 63, 68, 64, 69, 80, 85, 81, 86, 82,
87, 83, 88, 84, 89, 100, 105, 101, 106, 102, 107, 103, 108, 104, 109,
10, 15, 11, 16, 12, 17, 13, 18, 14, 19, 30, 35, 31, 36, 32,
37, 33, 38, 34, 39, 50, 55, 51, 56, 52, 57, 53, 58, 54, 59,
70, 75, 71, 76, 72, 77, 73, 78, 74, 79, 90, 95, 91, 96, 92,
97, 93, 98, 94, 99, 110, 115, 111, 116, 112, 117, 113, 118, 114, 119});
EXPECT_THAT(m.GetOutput(), result);
}
} // namespace
} // namespace tflite
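
Each expected shape in the new 5D tests follows from output_shape[i] = input_shape[perm[i]]. A few lines reproduce the {2, 2, 3, 5, 2} shape asserted by the Complex5DTest cases; the snippet is a cross-check of the arithmetic, not part of the test suite.

#include <cstdio>

int main() {
  const int input_shape[5] = {2, 3, 2, 2, 5};
  const int perm[5] = {2, 0, 1, 4, 3};
  // output_shape[i] = input_shape[perm[i]] yields 2 2 3 5 2.
  printf("output shape:");
  for (int i = 0; i < 5; ++i) printf(" %d", input_shape[perm[i]]);
  printf("\n");
}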


@@ -90,9 +90,6 @@ const std::map<string, string>& GetKnownBrokenTests() {
// ResizeBilinear looks completely incompatible with Tensorflow
{R"(^\/resize_bilinear.*dtype=tf.int32)", "72401107"},
// Transpose only supports 1D-4D input tensors.
{R"(^\/transpose.*input_shape=\[.,.,.,.,.\])", "71545879"},
// Relu does not support int32.
// These test cases appends a Relu after the tested ops when
// activation=True. The tests are failing since Relu doesn't support


@@ -59,6 +59,12 @@ def make_transpose_tests(options):
"perm": [[0, 1, 2, 3], [3, 0, 1, 2]],
"constant_perm": [True],
"fully_quantize": [True],
}, {
"dtype": [tf.float32],
"input_shape": [[1, 2, 3, 4, 5]],
"perm": [[0, 1, 2, 3, 4], [3, 4, 0, 1, 2]],
"constant_perm": [True],
"fully_quantize": [True, False],
}]
def build_graph(parameters):


@@ -241,7 +241,7 @@ class SpaceToBatchND
const Array& input_array = op_signature.model->GetArray(input_name);
::tflite::OpSignature op_sig =
GetVersioningOpSig(builtin_op(), op_signature);
op_sig.options.space_batch.num_dims =
op_sig.options.single_input_op.num_dims =
input_array.shape().dimensions_count();
return ::tflite::GetBuiltinOperatorVersion(op_sig);
}
@@ -325,7 +325,7 @@ class BatchToSpaceND
const Array& input_array = op_signature.model->GetArray(input_name);
::tflite::OpSignature op_sig =
GetVersioningOpSig(builtin_op(), op_signature);
op_sig.options.space_batch.num_dims =
op_sig.options.single_input_op.num_dims =
input_array.shape().dimensions_count();
return ::tflite::GetBuiltinOperatorVersion(op_sig);
}
@@ -1208,7 +1208,7 @@ class StridedSlice
static_cast<const StridedSliceOperator&>(*op_signature.op);
::tflite::OpSignature op_sig =
GetVersioningOpSig(builtin_op(), op_signature);
op_sig.options.strided_slice.num_dims = ss_op.start_indices.size();
op_sig.options.single_input_op.num_dims = ss_op.start_indices.size();
return ::tflite::GetBuiltinOperatorVersion(op_sig);
}
};


@@ -173,6 +173,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
return 1;
case BuiltinOperator_TRANSPOSE:
if (op_sig.options.single_input_op.num_dims > 4) {
return 4;
}
// If the op takes bool input, it is version 3.
if (op_sig.input_types.at(0) == TensorType_BOOL) {
return 3;
@@ -293,7 +296,7 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
}
return 1;
case BuiltinOperator_STRIDED_SLICE:
if (op_sig.options.strided_slice.num_dims > 4) {
if (op_sig.options.single_input_op.num_dims > 4) {
return 4;
}
// If the op takes bool input, it is version 3.
@@ -352,7 +355,7 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
case BuiltinOperator_SPACE_TO_BATCH_ND:
case BuiltinOperator_BATCH_TO_SPACE_ND:
if (op_sig.options.space_batch.num_dims != 4) {
if (op_sig.options.single_input_op.num_dims != 4) {
return 3;
}
if (op_sig.input_types.at(0) == TensorType_INT8) {
@@ -504,13 +507,11 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op,
}
} break;
// TODO(b/150176627): Add tests for GetOpSignature.
case BuiltinOperator_STRIDED_SLICE: {
op_sig.options.strided_slice.num_dims = GetNumDims(subgraph, op, 0);
} break;
case BuiltinOperator_STRIDED_SLICE:
case BuiltinOperator_SPACE_TO_BATCH_ND:
case BuiltinOperator_BATCH_TO_SPACE_ND: {
op_sig.options.space_batch.num_dims = GetNumDims(subgraph, op, 0);
case BuiltinOperator_BATCH_TO_SPACE_ND:
case BuiltinOperator_TRANSPOSE: {
op_sig.options.single_input_op.num_dims = GetNumDims(subgraph, op, 0);
} break;
case BuiltinOperator_SUB:
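
With strided_slice and space_batch folded into the shared single_input_op field, the TRANSPOSE versioning rule reads: more than four dimensions selects the new version 4, bool input selects 3, int8 selects 2, and everything else stays at 1. A simplified standalone sketch of that ladder, using an ad-hoc enum rather than the real TensorType:

#include <cassert>

enum class FakeTensorType { kFloat32, kUint8, kInt8, kBool };

// Mirrors the BuiltinOperator_TRANSPOSE branch of
// GetBuiltinOperatorVersion after this change; standalone sketch only.
int TransposeVersion(int num_dims, FakeTensorType input_type) {
  if (num_dims > 4) return 4;  // New in this commit: 5D inputs.
  if (input_type == FakeTensorType::kBool) return 3;
  if (input_type == FakeTensorType::kInt8) return 2;
  return 1;
}

int main() {
  assert(TransposeVersion(5, FakeTensorType::kBool) == 4);
  assert(TransposeVersion(4, FakeTensorType::kBool) == 3);
  assert(TransposeVersion(4, FakeTensorType::kInt8) == 2);
  assert(TransposeVersion(4, FakeTensorType::kUint8) == 1);
}

The VersioningTransposeTest added in the test file below exercises exactly these four branches.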


@@ -51,10 +51,7 @@ typedef struct {
} resize_bilinear;
struct {
int32_t num_dims;
} strided_slice;
struct {
int32_t num_dims;
} space_batch;
} single_input_op;
struct {
int32_t num_dims;
bool need_broadcast;


@@ -165,18 +165,18 @@ TEST(OpVersionTest, VersioningBatchToSpaceNDTest) {
.op = BuiltinOperator_BATCH_TO_SPACE_ND,
.input_types = std::vector<TensorType>{TensorType_INT8},
};
fake_op_sig.options.space_batch.num_dims = 3;
fake_op_sig.options.single_input_op.num_dims = 3;
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
fake_op_sig.options.space_batch.num_dims = 4;
fake_op_sig.options.single_input_op.num_dims = 4;
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
fake_op_sig = {
.op = BuiltinOperator_BATCH_TO_SPACE_ND,
.input_types = std::vector<TensorType>{TensorType_UINT8},
};
fake_op_sig.options.space_batch.num_dims = 3;
fake_op_sig.options.single_input_op.num_dims = 3;
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
fake_op_sig.options.space_batch.num_dims = 4;
fake_op_sig.options.single_input_op.num_dims = 4;
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
}
@@ -503,4 +503,26 @@ TEST(OpVersionTest, VersioningTileOperatorTest) {
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
}
TEST(OpVersionTest, VersioningTransposeTest) {
OpSignature fake_op_sig = {
.op = BuiltinOperator_TRANSPOSE,
.input_types = std::vector<TensorType>{TensorType_BOOL},
};
fake_op_sig.options.single_input_op.num_dims = 5;
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4);
fake_op_sig.options.single_input_op.num_dims = 4;
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
fake_op_sig = {
.op = BuiltinOperator_TRANSPOSE,
.input_types = std::vector<TensorType>{TensorType_INT8},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
fake_op_sig = {
.op = BuiltinOperator_TRANSPOSE,
.input_types = std::vector<TensorType>{TensorType_UINT8},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
}
} // namespace tflite