Add 5D support for TFLite Transpose
PiperOrigin-RevId: 304139938 Change-Id: Id938689744d1f6b3623fd8d1ed63594c4afec950
This commit is contained in:
parent
031804922d
commit
97b63ed876
@ -357,7 +357,9 @@ TopKV2OpTest/TopKV2OpTest/.+/0,29
|
||||
|
||||
# transpose_test
|
||||
# death test
|
||||
-TransposeTest/Test5DInputTensor
|
||||
-TransposeTest/Test6DInputTensor
|
||||
-TransposeTest/5DDividedIntoTwo2Ds.*
|
||||
-TransposeTest/Complex5DTest.*
|
||||
-TransposeTest/.+DynamicTensor
|
||||
TransposeTest/.+
|
||||
|
||||
|
@ -7694,7 +7694,7 @@ inline void Transpose3D(const TransposeParams& params,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, int N>
|
||||
void TransposeImpl(const TransposeParams& params,
|
||||
const RuntimeShape& input_shape, const T* input_data,
|
||||
const RuntimeShape& output_shape, T* output_data) {
|
||||
@ -7725,19 +7725,19 @@ void TransposeImpl(const TransposeParams& params,
|
||||
|
||||
// Reroute to the reference version if an optimized method for the given data
|
||||
// is not available.
|
||||
reference_ops::Transpose(params, input_shape, input_data, output_shape,
|
||||
output_data);
|
||||
reference_ops::Transpose<T, N>(params, input_shape, input_data, output_shape,
|
||||
output_data);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, int N = 5>
|
||||
void Transpose(const TransposeParams& unshrinked_params,
|
||||
const RuntimeShape& unshrinked_input_shape, const T* input_data,
|
||||
const RuntimeShape& unshrinked_output_shape, T* output_data) {
|
||||
ruy::profiler::ScopeLabel label("Transpose");
|
||||
|
||||
const int output_size = unshrinked_output_shape.DimensionsCount();
|
||||
TFLITE_DCHECK_LE(unshrinked_input_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(output_size, 4);
|
||||
TFLITE_DCHECK_LE(unshrinked_input_shape.DimensionsCount(), N);
|
||||
TFLITE_DCHECK_LE(output_size, N);
|
||||
TFLITE_DCHECK_EQ(output_size, unshrinked_params.perm_count);
|
||||
|
||||
RuntimeShape shrinked_input_shape = RuntimeShape(unshrinked_input_shape);
|
||||
@ -7778,15 +7778,16 @@ void Transpose(const TransposeParams& unshrinked_params,
|
||||
TFLITE_DCHECK_NE(non_flatten_params.perm[0], 0);
|
||||
|
||||
for (int i = 0; i < total_size; i += non_flatten_size) {
|
||||
TransposeImpl(non_flatten_params, non_flatten_input_shape, input_data + i,
|
||||
non_flatten_output_shape, output_data + i);
|
||||
TransposeImpl<T, N>(non_flatten_params, non_flatten_input_shape,
|
||||
input_data + i, non_flatten_output_shape,
|
||||
output_data + i);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Call non-flattened case.
|
||||
TransposeImpl(shrinked_params, shrinked_input_shape, input_data,
|
||||
shrinked_output_shape, output_data);
|
||||
TransposeImpl<T, N>(shrinked_params, shrinked_input_shape, input_data,
|
||||
shrinked_output_shape, output_data);
|
||||
}
|
||||
|
||||
} // namespace optimized_ops
|
||||
|
@ -60,7 +60,6 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
namespace reference_ops {
|
||||
@ -1991,60 +1990,54 @@ inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
|
||||
ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, int N>
|
||||
void TransposeImpl(const TransposeParams& params,
|
||||
const RuntimeShape& unextended_input_shape,
|
||||
const T* input_data,
|
||||
const RuntimeShape& unextended_output_shape,
|
||||
T* output_data) {
|
||||
const int unextended_input_size = unextended_input_shape.DimensionsCount();
|
||||
const int unextended_output_size = unextended_output_shape.DimensionsCount();
|
||||
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_output_size, 4);
|
||||
TFLITE_DCHECK_LE(unextended_input_size, N);
|
||||
TFLITE_DCHECK_LE(unextended_output_size, N);
|
||||
TFLITE_DCHECK_EQ(unextended_output_size, params.perm_count);
|
||||
const RuntimeShape input_shape =
|
||||
RuntimeShape::ExtendedShape(4, unextended_input_shape);
|
||||
const RuntimeShape output_shape =
|
||||
RuntimeShape::ExtendedShape(4, unextended_output_shape);
|
||||
const int input_ext_size = 4 - unextended_input_shape.DimensionsCount();
|
||||
const int output_ext_size = 4 - unextended_output_size;
|
||||
const int input_ext_size = N - unextended_input_size;
|
||||
const int output_ext_size = N - unextended_output_size;
|
||||
NdArrayDesc<N> input_desc;
|
||||
NdArrayDesc<N> output_desc;
|
||||
CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_input_shape),
|
||||
&input_desc);
|
||||
CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
|
||||
&output_desc);
|
||||
|
||||
// The perm data is extended to match the output, each index incremented by
|
||||
// the amount of front padding of the input shape.
|
||||
int extended_perm[4];
|
||||
for (int i = 0; i < output_ext_size; ++i) {
|
||||
extended_perm[i] = i;
|
||||
}
|
||||
for (int i = 0; i < unextended_output_size; ++i) {
|
||||
extended_perm[i + output_ext_size] = params.perm[i] + input_ext_size;
|
||||
int extended_perm[N];
|
||||
for (int i = 0; i < N; ++i) {
|
||||
extended_perm[i] = i < output_ext_size
|
||||
? i
|
||||
: params.perm[i - output_ext_size] + input_ext_size;
|
||||
}
|
||||
|
||||
int out_sizes[4];
|
||||
// Compute the inverse permutation array so we can do an output centered
|
||||
// transpose. Also, check to make sure output_dims is matching input_dims.
|
||||
for (int k = 0; k < 4; k++) {
|
||||
out_sizes[k] = MatchingDim(input_shape, extended_perm[k], output_shape, k);
|
||||
// Permutes the input shape so we don't need to permute the indexes inside
|
||||
// the loop. Check to make sure output_dims is matching input_dims.
|
||||
NdArrayDesc<N> perm_input_desc;
|
||||
for (int k = 0; k < N; ++k) {
|
||||
TFLITE_DCHECK_EQ(input_desc.extents[extended_perm[k]],
|
||||
output_desc.extents[k]);
|
||||
perm_input_desc.extents[k] = input_desc.extents[extended_perm[k]];
|
||||
perm_input_desc.strides[k] = input_desc.strides[extended_perm[k]];
|
||||
}
|
||||
|
||||
// Naive transpose loop (iterate on output index and compute input index).
|
||||
int o[4]; // loop index (on output).
|
||||
int i[4];
|
||||
for (o[3] = 0; o[3] < out_sizes[3]; o[3]++) {
|
||||
i[extended_perm[3]] = o[3];
|
||||
for (o[2] = 0; o[2] < out_sizes[2]; o[2]++) {
|
||||
i[extended_perm[2]] = o[2];
|
||||
for (o[1] = 0; o[1] < out_sizes[1]; o[1]++) {
|
||||
i[extended_perm[1]] = o[1];
|
||||
for (o[0] = 0; o[0] < out_sizes[0]; o[0]++) {
|
||||
i[extended_perm[0]] = o[0];
|
||||
output_data[Offset(output_shape, o)] =
|
||||
input_data[Offset(input_shape, i)];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
auto tranpose_func = [&](int indexes[N]) {
|
||||
output_data[SubscriptToIndex(output_desc, indexes)] =
|
||||
input_data[SubscriptToIndex(perm_input_desc, indexes)];
|
||||
};
|
||||
NDOpsHelper<N>(output_desc, tranpose_func);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, int N = 5>
|
||||
void Transpose(const TransposeParams& params,
|
||||
const RuntimeShape& unextended_input_shape, const T* input_data,
|
||||
const RuntimeShape& unextended_output_shape, T* output_data) {
|
||||
@ -2053,29 +2046,29 @@ void Transpose(const TransposeParams& params,
|
||||
// keeps the total code size in a reasonable range.
|
||||
switch (sizeof(T)) {
|
||||
case 1:
|
||||
TransposeImpl<int8_t>(params, unextended_input_shape,
|
||||
reinterpret_cast<const int8_t*>(input_data),
|
||||
unextended_output_shape,
|
||||
reinterpret_cast<int8_t*>(output_data));
|
||||
TransposeImpl<int8_t, N>(params, unextended_input_shape,
|
||||
reinterpret_cast<const int8_t*>(input_data),
|
||||
unextended_output_shape,
|
||||
reinterpret_cast<int8_t*>(output_data));
|
||||
break;
|
||||
case 2:
|
||||
TransposeImpl<int16_t>(params, unextended_input_shape,
|
||||
reinterpret_cast<const int16_t*>(input_data),
|
||||
unextended_output_shape,
|
||||
reinterpret_cast<int16_t*>(output_data));
|
||||
TransposeImpl<int16_t, N>(params, unextended_input_shape,
|
||||
reinterpret_cast<const int16_t*>(input_data),
|
||||
unextended_output_shape,
|
||||
reinterpret_cast<int16_t*>(output_data));
|
||||
break;
|
||||
|
||||
case 4:
|
||||
TransposeImpl<int32_t>(params, unextended_input_shape,
|
||||
reinterpret_cast<const int32_t*>(input_data),
|
||||
unextended_output_shape,
|
||||
reinterpret_cast<int32_t*>(output_data));
|
||||
TransposeImpl<int32_t, N>(params, unextended_input_shape,
|
||||
reinterpret_cast<const int32_t*>(input_data),
|
||||
unextended_output_shape,
|
||||
reinterpret_cast<int32_t*>(output_data));
|
||||
break;
|
||||
case 8:
|
||||
TransposeImpl<int64_t>(params, unextended_input_shape,
|
||||
reinterpret_cast<const int64_t*>(input_data),
|
||||
unextended_output_shape,
|
||||
reinterpret_cast<int64_t*>(output_data));
|
||||
TransposeImpl<int64_t, N>(params, unextended_input_shape,
|
||||
reinterpret_cast<const int64_t*>(input_data),
|
||||
unextended_output_shape,
|
||||
reinterpret_cast<int64_t*>(output_data));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -1077,7 +1077,7 @@ struct TanhParams {
|
||||
|
||||
struct TransposeParams {
|
||||
int8 perm_count;
|
||||
int32 perm[4];
|
||||
int32 perm[5];
|
||||
};
|
||||
|
||||
struct UnpackParams {
|
||||
|
@ -132,7 +132,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
/* max_version */ 3);
|
||||
AddBuiltin(BuiltinOperator_TRANSPOSE, Register_TRANSPOSE(),
|
||||
/* min_version */ 1,
|
||||
/* max_version */ 3);
|
||||
/* max_version */ 4);
|
||||
AddBuiltin(BuiltinOperator_MEAN, Register_MEAN(),
|
||||
/* min_version */ 1,
|
||||
/* max_version */ 2);
|
||||
|
@ -76,8 +76,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TransposeContext op_context(context, node);
|
||||
|
||||
// Ensure validity of input tensor.
|
||||
TF_LITE_ENSURE_MSG(context, NumDimensions(op_context.input) <= 4,
|
||||
"Transpose op only supports 1D-4D input arrays.");
|
||||
TF_LITE_ENSURE_MSG(context, NumDimensions(op_context.input) <= 5,
|
||||
"Transpose op only supports 1D-5D input arrays.");
|
||||
TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
|
||||
|
||||
if (!IsConstantTensor(op_context.perm)) {
|
||||
|
@ -225,7 +225,7 @@ class TransposeOpConstModel : public TransposeOpModel {
|
||||
TransposeOpConstModel(std::initializer_list<int> input_shape,
|
||||
std::initializer_list<int> perm_shape,
|
||||
std::initializer_list<int> perm) {
|
||||
input_ = AddInput(TensorType_FLOAT32);
|
||||
input_ = AddInput({TensorType_FLOAT32, input_shape});
|
||||
perm_ = AddConstInput(TensorType_INT32, perm, perm_shape);
|
||||
output_ = AddOutput(TensorType_FLOAT32);
|
||||
SetBuiltinOp(BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions,
|
||||
@ -507,10 +507,43 @@ TEST(TransposeTest, 4DDividedIntoTwo2DsThird) {
|
||||
EXPECT_EQ(m.GetOutput(), out);
|
||||
}
|
||||
|
||||
TEST(TransposeTest, 5DDividedIntoTwo2DsOne) {
|
||||
std::vector<float> out;
|
||||
RunTestPermutation({2, 3, 2, 2, 2}, {1, 4, 2, 3, 0}, &out);
|
||||
TransposeOpConstModel m({2, 3, 2, 2, 2}, {5}, {1, 4, 2, 3, 0});
|
||||
m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
|
||||
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
|
||||
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47});
|
||||
m.Invoke();
|
||||
EXPECT_EQ(m.GetOutput(), out);
|
||||
}
|
||||
|
||||
TEST(TransposeTest, 5DDividedIntoTwo2DsTwo) {
|
||||
std::vector<float> out;
|
||||
RunTestPermutation({2, 3, 2, 2, 2}, {2, 3, 0, 4, 1}, &out);
|
||||
TransposeOpConstModel m({2, 3, 2, 2, 2}, {5}, {2, 3, 0, 4, 1});
|
||||
m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
|
||||
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
|
||||
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47});
|
||||
m.Invoke();
|
||||
EXPECT_EQ(m.GetOutput(), out);
|
||||
}
|
||||
|
||||
TEST(TransposeTest, 5DDividedIntoTwo2DsThird) {
|
||||
std::vector<float> out;
|
||||
RunTestPermutation({2, 3, 2, 2, 2}, {3, 0, 4, 1, 2}, &out);
|
||||
TransposeOpConstModel m({2, 3, 2, 2, 2}, {5}, {3, 0, 4, 1, 2});
|
||||
m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
|
||||
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
|
||||
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47});
|
||||
m.Invoke();
|
||||
EXPECT_EQ(m.GetOutput(), out);
|
||||
}
|
||||
|
||||
#ifdef GTEST_HAS_DEATH_TEST
|
||||
TEST(TransposeTest, Test5DInputTensor) {
|
||||
EXPECT_DEATH(TransposeOpConstModel({1, 2, 3, 4, 5}, {5}, {0, 1, 2, 3, 4}),
|
||||
"Transpose op only supports 1D-4D input arrays.");
|
||||
TEST(TransposeTest, Test6DInputTensor) {
|
||||
EXPECT_DEATH(TransposeOpConstModel({1, 2, 3, 4, 5, 6}, {5}, {0, 1, 2, 3, 4}),
|
||||
"Transpose op only supports 1D-5D input arrays.");
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -594,5 +627,60 @@ TEST(TransposeTest, ComplexTestWithReorderDynamicTensor) {
|
||||
EXPECT_THAT(m.GetOutput(), result);
|
||||
}
|
||||
|
||||
TEST(TransposeTest, Complex5DTestWithReorderConstTensor) {
|
||||
TransposeOpConstModel m({2, 3, 2, 2, 5}, {5}, {2, 0, 1, 4, 3});
|
||||
m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
|
||||
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
|
||||
36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
|
||||
48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
|
||||
60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
|
||||
72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
|
||||
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95,
|
||||
96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107,
|
||||
108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 3, 5, 2}));
|
||||
auto result = ElementsAreArray(
|
||||
{0, 5, 1, 6, 2, 7, 3, 8, 4, 9, 20, 25, 21, 26, 22,
|
||||
27, 23, 28, 24, 29, 40, 45, 41, 46, 42, 47, 43, 48, 44, 49,
|
||||
60, 65, 61, 66, 62, 67, 63, 68, 64, 69, 80, 85, 81, 86, 82,
|
||||
87, 83, 88, 84, 89, 100, 105, 101, 106, 102, 107, 103, 108, 104, 109,
|
||||
10, 15, 11, 16, 12, 17, 13, 18, 14, 19, 30, 35, 31, 36, 32,
|
||||
37, 33, 38, 34, 39, 50, 55, 51, 56, 52, 57, 53, 58, 54, 59,
|
||||
70, 75, 71, 76, 72, 77, 73, 78, 74, 79, 90, 95, 91, 96, 92,
|
||||
97, 93, 98, 94, 99, 110, 115, 111, 116, 112, 117, 113, 118, 114, 119});
|
||||
EXPECT_THAT(m.GetOutput(), result);
|
||||
}
|
||||
|
||||
TEST(TransposeTest, Complex5DTestWithReorderDynamicTensor) {
|
||||
TransposeOpDynamicModel m({2, 3, 2, 2, 5}, {5});
|
||||
m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
|
||||
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
|
||||
36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
|
||||
48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
|
||||
60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
|
||||
72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
|
||||
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95,
|
||||
96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107,
|
||||
108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119});
|
||||
m.SetPerm({2, 0, 1, 4, 3});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 3, 5, 2}));
|
||||
auto result = ElementsAreArray(
|
||||
{0, 5, 1, 6, 2, 7, 3, 8, 4, 9, 20, 25, 21, 26, 22,
|
||||
27, 23, 28, 24, 29, 40, 45, 41, 46, 42, 47, 43, 48, 44, 49,
|
||||
60, 65, 61, 66, 62, 67, 63, 68, 64, 69, 80, 85, 81, 86, 82,
|
||||
87, 83, 88, 84, 89, 100, 105, 101, 106, 102, 107, 103, 108, 104, 109,
|
||||
10, 15, 11, 16, 12, 17, 13, 18, 14, 19, 30, 35, 31, 36, 32,
|
||||
37, 33, 38, 34, 39, 50, 55, 51, 56, 52, 57, 53, 58, 54, 59,
|
||||
70, 75, 71, 76, 72, 77, 73, 78, 74, 79, 90, 95, 91, 96, 92,
|
||||
97, 93, 98, 94, 99, 110, 115, 111, 116, 112, 117, 113, 118, 114, 119});
|
||||
EXPECT_THAT(m.GetOutput(), result);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
@ -90,9 +90,6 @@ const std::map<string, string>& GetKnownBrokenTests() {
|
||||
// ResizeBilinear looks completely incompatible with Tensorflow
|
||||
{R"(^\/resize_bilinear.*dtype=tf.int32)", "72401107"},
|
||||
|
||||
// Transpose only supports 1D-4D input tensors.
|
||||
{R"(^\/transpose.*input_shape=\[.,.,.,.,.\])", "71545879"},
|
||||
|
||||
// Relu does not support int32.
|
||||
// These test cases appends a Relu after the tested ops when
|
||||
// activation=True. The tests are failing since Relu doesn't support
|
||||
|
@ -59,6 +59,12 @@ def make_transpose_tests(options):
|
||||
"perm": [[0, 1, 2, 3], [3, 0, 1, 2]],
|
||||
"constant_perm": [True],
|
||||
"fully_quantize": [True],
|
||||
}, {
|
||||
"dtype": [tf.float32],
|
||||
"input_shape": [[1, 2, 3, 4, 5]],
|
||||
"perm": [[0, 1, 2, 3, 4], [3, 4, 0, 1, 2]],
|
||||
"constant_perm": [True],
|
||||
"fully_quantize": [True, False],
|
||||
}]
|
||||
|
||||
def build_graph(parameters):
|
||||
|
@ -241,7 +241,7 @@ class SpaceToBatchND
|
||||
const Array& input_array = op_signature.model->GetArray(input_name);
|
||||
::tflite::OpSignature op_sig =
|
||||
GetVersioningOpSig(builtin_op(), op_signature);
|
||||
op_sig.options.space_batch.num_dims =
|
||||
op_sig.options.single_input_op.num_dims =
|
||||
input_array.shape().dimensions_count();
|
||||
return ::tflite::GetBuiltinOperatorVersion(op_sig);
|
||||
}
|
||||
@ -325,7 +325,7 @@ class BatchToSpaceND
|
||||
const Array& input_array = op_signature.model->GetArray(input_name);
|
||||
::tflite::OpSignature op_sig =
|
||||
GetVersioningOpSig(builtin_op(), op_signature);
|
||||
op_sig.options.space_batch.num_dims =
|
||||
op_sig.options.single_input_op.num_dims =
|
||||
input_array.shape().dimensions_count();
|
||||
return ::tflite::GetBuiltinOperatorVersion(op_sig);
|
||||
}
|
||||
@ -1208,7 +1208,7 @@ class StridedSlice
|
||||
static_cast<const StridedSliceOperator&>(*op_signature.op);
|
||||
::tflite::OpSignature op_sig =
|
||||
GetVersioningOpSig(builtin_op(), op_signature);
|
||||
op_sig.options.strided_slice.num_dims = ss_op.start_indices.size();
|
||||
op_sig.options.single_input_op.num_dims = ss_op.start_indices.size();
|
||||
return ::tflite::GetBuiltinOperatorVersion(op_sig);
|
||||
}
|
||||
};
|
||||
|
@ -173,6 +173,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
|
||||
return 1;
|
||||
|
||||
case BuiltinOperator_TRANSPOSE:
|
||||
if (op_sig.options.single_input_op.num_dims > 4) {
|
||||
return 4;
|
||||
}
|
||||
// If the op takes bool input, it is version 3.
|
||||
if (op_sig.input_types.at(0) == TensorType_BOOL) {
|
||||
return 3;
|
||||
@ -293,7 +296,7 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
|
||||
}
|
||||
return 1;
|
||||
case BuiltinOperator_STRIDED_SLICE:
|
||||
if (op_sig.options.strided_slice.num_dims > 4) {
|
||||
if (op_sig.options.single_input_op.num_dims > 4) {
|
||||
return 4;
|
||||
}
|
||||
// If the op takes bool input, it is version 3.
|
||||
@ -352,7 +355,7 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
|
||||
|
||||
case BuiltinOperator_SPACE_TO_BATCH_ND:
|
||||
case BuiltinOperator_BATCH_TO_SPACE_ND:
|
||||
if (op_sig.options.space_batch.num_dims != 4) {
|
||||
if (op_sig.options.single_input_op.num_dims != 4) {
|
||||
return 3;
|
||||
}
|
||||
if (op_sig.input_types.at(0) == TensorType_INT8) {
|
||||
@ -504,13 +507,11 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op,
|
||||
}
|
||||
} break;
|
||||
// TODO(b/150176627): Add tests for GetOpSignature.
|
||||
case BuiltinOperator_STRIDED_SLICE: {
|
||||
op_sig.options.strided_slice.num_dims = GetNumDims(subgraph, op, 0);
|
||||
} break;
|
||||
|
||||
case BuiltinOperator_STRIDED_SLICE:
|
||||
case BuiltinOperator_SPACE_TO_BATCH_ND:
|
||||
case BuiltinOperator_BATCH_TO_SPACE_ND: {
|
||||
op_sig.options.space_batch.num_dims = GetNumDims(subgraph, op, 0);
|
||||
case BuiltinOperator_BATCH_TO_SPACE_ND:
|
||||
case BuiltinOperator_TRANSPOSE: {
|
||||
op_sig.options.single_input_op.num_dims = GetNumDims(subgraph, op, 0);
|
||||
} break;
|
||||
|
||||
case BuiltinOperator_SUB:
|
||||
|
@ -51,10 +51,7 @@ typedef struct {
|
||||
} resize_bilinear;
|
||||
struct {
|
||||
int32_t num_dims;
|
||||
} strided_slice;
|
||||
struct {
|
||||
int32_t num_dims;
|
||||
} space_batch;
|
||||
} single_input_op;
|
||||
struct {
|
||||
int32_t num_dims;
|
||||
bool need_broadcast;
|
||||
|
@ -165,18 +165,18 @@ TEST(OpVersionTest, VersioningBatchToSpaceNDTest) {
|
||||
.op = BuiltinOperator_BATCH_TO_SPACE_ND,
|
||||
.input_types = std::vector<TensorType>{TensorType_INT8},
|
||||
};
|
||||
fake_op_sig.options.space_batch.num_dims = 3;
|
||||
fake_op_sig.options.single_input_op.num_dims = 3;
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
|
||||
fake_op_sig.options.space_batch.num_dims = 4;
|
||||
fake_op_sig.options.single_input_op.num_dims = 4;
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
|
||||
|
||||
fake_op_sig = {
|
||||
.op = BuiltinOperator_BATCH_TO_SPACE_ND,
|
||||
.input_types = std::vector<TensorType>{TensorType_UINT8},
|
||||
};
|
||||
fake_op_sig.options.space_batch.num_dims = 3;
|
||||
fake_op_sig.options.single_input_op.num_dims = 3;
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
|
||||
fake_op_sig.options.space_batch.num_dims = 4;
|
||||
fake_op_sig.options.single_input_op.num_dims = 4;
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
|
||||
}
|
||||
|
||||
@ -503,4 +503,26 @@ TEST(OpVersionTest, VersioningTileOperatorTest) {
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
|
||||
}
|
||||
TEST(OpVersionTest, VersioningTransposeTest) {
|
||||
OpSignature fake_op_sig = {
|
||||
.op = BuiltinOperator_TRANSPOSE,
|
||||
.input_types = std::vector<TensorType>{TensorType_BOOL},
|
||||
};
|
||||
fake_op_sig.options.single_input_op.num_dims = 5;
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4);
|
||||
fake_op_sig.options.single_input_op.num_dims = 4;
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
|
||||
|
||||
fake_op_sig = {
|
||||
.op = BuiltinOperator_TRANSPOSE,
|
||||
.input_types = std::vector<TensorType>{TensorType_INT8},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
|
||||
|
||||
fake_op_sig = {
|
||||
.op = BuiltinOperator_TRANSPOSE,
|
||||
.input_types = std::vector<TensorType>{TensorType_UINT8},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
|
||||
}
|
||||
} // namespace tflite
|
||||
|
Loading…
Reference in New Issue
Block a user