diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index 20a58c337fc..3f02658e21e 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -229,6 +229,7 @@ cc_library( ":round", ":tensor", ":tensor_utils", + ":transpose_utils", "//third_party/eigen3", "@gemmlowp//:fixedpoint", "@gemmlowp//:profiler", @@ -271,6 +272,7 @@ cc_library( ":strided_slice_logic", ":tensor", ":tensor_utils", + ":transpose_utils", ":types", ":legacy_types", ":legacy_reference_base", @@ -349,6 +351,28 @@ cc_test( ], ) +cc_library( + name = "transpose_utils", + srcs = [ + "transpose_utils.cc", + ], + hdrs = [ + "transpose_utils.h", + ], + deps = [ + ":types", + ], +) + +cc_test( + name = "transpose_utils_test", + srcs = ["transpose_utils_test.cc"], + deps = [ + ":transpose_utils", + "@com_google_googletest//:gtest", + ], +) + cc_library( name = "strided_slice_logic", srcs = [], diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index 8ff502d2449..f6b623cc6b7 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -51,6 +51,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/strided_slice_logic.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/lite/kernels/internal/transpose_utils.h" #include "tensorflow/lite/kernels/internal/types.h" namespace tflite { @@ -7095,15 +7096,10 @@ inline void Logistic16bitPercision(const LogisticParams& params, // Perform transpose by transposing 4x4 blocks of the input, proceeding from // left to right (down the rows) of the input, and then from top to bottom. 
template -inline void Transpose2DImpl(const TransposeParams& params, - const RuntimeShape& input_shape, - const T* input_data, - const RuntimeShape& output_shape, T* output_data) { +inline void Transpose2D(const RuntimeShape& input_shape, const T* input_data, + const RuntimeShape& output_shape, T* output_data) { TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 2); TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2); - TFLITE_DCHECK_EQ(params.perm_count, 2); - TFLITE_DCHECK_EQ(params.perm[0], 1); - TFLITE_DCHECK_EQ(params.perm[1], 0); const int d0 = input_shape.DimsData()[0]; const int d1 = input_shape.DimsData()[1]; @@ -7196,16 +7192,12 @@ inline void Transpose2DImpl(const TransposeParams& params, } template <> -inline void Transpose2DImpl(const TransposeParams& params, - const RuntimeShape& input_shape, - const int32_t* input_data, - const RuntimeShape& output_shape, - int32_t* output_data) { +inline void Transpose2D(const RuntimeShape& input_shape, + const int32_t* input_data, + const RuntimeShape& output_shape, + int32_t* output_data) { TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 2); TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2); - TFLITE_DCHECK_EQ(params.perm_count, 2); - TFLITE_DCHECK_EQ(params.perm[0], 1); - TFLITE_DCHECK_EQ(params.perm[1], 0); const int d0 = input_shape.DimsData()[0]; const int d1 = input_shape.DimsData()[1]; @@ -7278,93 +7270,17 @@ inline void Transpose2DImpl(const TransposeParams& params, } } -template -void Transpose2D(const TransposeParams& params, const RuntimeShape& input_shape, - const T* input_data, const RuntimeShape& output_shape, - T* output_data) { - // Transpose kernel only does rearranging values not numeric evaluations on - // each cell. It's safe to implement per size of scalar type and this trick - // keeps the total code size in a reasonable range. 
- switch (sizeof(T)) { - case 1: - Transpose2DImpl(params, input_shape, - reinterpret_cast(input_data), output_shape, - reinterpret_cast(output_data)); - break; - case 4: - Transpose2DImpl(params, input_shape, - reinterpret_cast(input_data), - output_shape, reinterpret_cast(output_data)); - break; - default: - // Reroute to the reference version if an optimized method for the given - // data is not available. - reference_ops::Transpose(params, input_shape, input_data, output_shape, - output_data); - } -} - // TODO(alanchiao): see if we can reduce the number // of lines of code in branching without affecting latency. template -inline void Transpose3DImpl(const TransposeParams& params, - const RuntimeShape& input_shape, - const T* input_data, - const RuntimeShape& output_shape, T* output_data) { +inline void Transpose3D(const TransposeParams& params, + const RuntimeShape& input_shape, const T* input_data, + const RuntimeShape& output_shape, T* output_data) { int s1, s2, s3; s1 = input_shape.Dims(0); s2 = input_shape.Dims(1); s3 = input_shape.Dims(2); - // TODO(b/141169757): generalize the following logics and move to the - // Transpose method. - const bool hasOneInDimension = (s1 == 1 || s2 == 1 || s3 == 1); - // Can fast path as 2D transpose in this case. - if (hasOneInDimension) { - int d1, d2; - bool is_identity = false; - // Check for identity to just return. 
- if (s1 == 1) { - // (0, 1, 2), (1, 0, 2), (1, 2, 0) - if ((params.perm[0] == 0 && params.perm[1] == 1) || params.perm[0] == 1) { - is_identity = true; - } - d1 = s2; - d2 = s3; - } else if (s2 == 1) { - // (0, 1, 2), (0, 2, 1), (1, 0, 2) - if ((params.perm[0] == 1 && params.perm[1] == 0) || params.perm[0] == 0) { - is_identity = true; - } - d1 = s1; - d2 = s3; - } else { - // (0, 1, 2), (0, 2, 1), (2, 0, 1) - if ((params.perm[0] == 2 && params.perm[1] == 0) || params.perm[0] == 0) { - is_identity = true; - } - d1 = s1; - d2 = s2; - } - - if (is_identity) { - memcpy(output_data, input_data, sizeof(T) * input_shape.FlatSize()); - return; - } - - TransposeParams new_params; - new_params.perm_count = 2; - new_params.perm[0] = 1; - new_params.perm[1] = 0; - - const RuntimeShape new_input_shape({d1, d2}); - const RuntimeShape new_output_shape({d2, d1}); - - Transpose2D(new_params, new_input_shape, input_data, new_output_shape, - output_data); - return; - } - int p1, p2, p3; if (params.perm[0] == 2) { p1 = 1; @@ -7407,44 +7323,16 @@ inline void Transpose3DImpl(const TransposeParams& params, } template -void Transpose3D(const TransposeParams& params, const RuntimeShape& input_shape, - const T* input_data, const RuntimeShape& output_shape, - T* output_data) { - // Transpose kernel only does rearranging values not numeric evaluations on - // each cell. It's safe to implement per size of scalar type and this trick - // keeps the total code size in a reasonable range. - switch (sizeof(T)) { - case 1: - Transpose3DImpl(params, input_shape, - reinterpret_cast(input_data), output_shape, - reinterpret_cast(output_data)); - break; - case 4: - Transpose3DImpl(params, input_shape, - reinterpret_cast(input_data), - output_shape, reinterpret_cast(output_data)); - break; - default: - // Reroute to the reference version if an optimized method for the given - // data is not available. 
- reference_ops::Transpose(params, input_shape, input_data, output_shape, - output_data); - } -} +void TransposeImpl(const TransposeParams& params, + const RuntimeShape& input_shape, const T* input_data, + const RuntimeShape& output_shape, T* output_data) { + const int dims_cnt = input_shape.DimensionsCount(); -template -void Transpose(const TransposeParams& params, const RuntimeShape& input_shape, - const T* input_data, const RuntimeShape& output_shape, - T* output_data) { - const int output_size = output_shape.DimensionsCount(); - TFLITE_DCHECK_LE(input_shape.DimensionsCount(), 4); - TFLITE_DCHECK_LE(output_size, 4); - TFLITE_DCHECK_EQ(output_size, params.perm_count); - - // Apply 2-D transpose. - if (input_shape.DimensionsCount() == 2 && params.perm[0] == 1 && - params.perm[1] == 0) { - Transpose2D(params, input_shape, input_data, output_shape, output_data); + int dim0, dim1; + if (transpose_utils::IsTranspose2DApplicable(params, input_shape, &dim0, + &dim1)) { + Transpose2D(RuntimeShape({dim0, dim1}), input_data, + RuntimeShape({dim1, dim0}), output_data); return; } @@ -7458,7 +7346,7 @@ void Transpose(const TransposeParams& params, const RuntimeShape& input_shape, // 96^3 is not mobile-friendly for certain usecases // (e.g. model used in beam search for seq2seq) but is in others. // Consider tradeoffs. 
- if (input_shape.DimensionsCount() == 3) { + if (dims_cnt == 3) { Transpose3D(params, input_shape, input_data, output_shape, output_data); return; } @@ -7469,6 +7357,66 @@ void Transpose(const TransposeParams& params, const RuntimeShape& input_shape, output_data); } +template +void Transpose(const TransposeParams& unshrinked_params, + const RuntimeShape& unshrinked_input_shape, const T* input_data, + const RuntimeShape& unshrinked_output_shape, T* output_data) { + gemmlowp::ScopedProfilingLabel label("Transpose"); + + const int output_size = unshrinked_output_shape.DimensionsCount(); + TFLITE_DCHECK_LE(unshrinked_input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(output_size, 4); + TFLITE_DCHECK_EQ(output_size, unshrinked_params.perm_count); + + RuntimeShape shrinked_input_shape = RuntimeShape(unshrinked_input_shape); + RuntimeShape shrinked_output_shape = RuntimeShape(unshrinked_output_shape); + TransposeParams shrinked_params = unshrinked_params; + + // Reduce any dimensions that have one size. Lower transpose op usually + // performs better since memory access patterns will be improved. + transpose_utils::RemoveOneSizeDimensions( + &shrinked_input_shape, &shrinked_output_shape, &shrinked_params); + + // Handle identity cases. + // TODO(b/140779653): Add an optimization pass in the conversion process to + // remove transpose op nodes where they do nothing like the below one. + bool identical = true; + for (int i = 0; i < shrinked_params.perm_count; ++i) { + if (shrinked_params.perm[i] != i) { + identical = false; + break; + } + } + if (identical) { + memcpy(output_data, input_data, + unshrinked_input_shape.FlatSize() * sizeof(T)); + return; + } + + // Reduce dimensions by flattening. 
+ if (shrinked_params.perm[0] == 0 && output_size >= 3) { + RuntimeShape non_flatten_input_shape; + RuntimeShape non_flatten_output_shape; + TransposeParams non_flatten_params; + const int total_size = shrinked_input_shape.FlatSize(); + const int non_flatten_size = transpose_utils::Flatten( + shrinked_input_shape, shrinked_output_shape, shrinked_params, + &non_flatten_input_shape, &non_flatten_output_shape, + &non_flatten_params); + TFLITE_DCHECK_NE(non_flatten_params.perm[0], 0); + + for (int i = 0; i < total_size; i += non_flatten_size) { + TransposeImpl(non_flatten_params, non_flatten_input_shape, input_data + i, + non_flatten_output_shape, output_data + i); + } + return; + } + + // Call non-flattened case. + TransposeImpl(shrinked_params, shrinked_input_shape, input_data, + shrinked_output_shape, output_data); +} + } // namespace optimized_ops } // namespace tflite diff --git a/tensorflow/lite/kernels/internal/transpose_utils.cc b/tensorflow/lite/kernels/internal/transpose_utils.cc new file mode 100644 index 00000000000..76808020853 --- /dev/null +++ b/tensorflow/lite/kernels/internal/transpose_utils.cc @@ -0,0 +1,165 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/lite/kernels/internal/transpose_utils.h" + +namespace tflite { +namespace transpose_utils { + +bool IsTranspose2DApplicable(const TransposeParams& params, + const RuntimeShape& input_shape, int* dim0, + int* dim1) { + const int dims_cnt = input_shape.DimensionsCount(); + + if (dims_cnt == 2) { + *dim0 = input_shape.Dims(0); + *dim1 = input_shape.Dims(1); + return true; + } + + const int first_perm = params.perm[0]; + for (int i = 1; i < dims_cnt; ++i) { + int rebased = params.perm[i] - first_perm; + if (rebased < 0) { + rebased += dims_cnt; + } + if (rebased != i) { + return false; + } + } + *dim0 = 1; + *dim1 = 1; + for (int i = 0; i < dims_cnt; ++i) { + if (i < first_perm) { + *dim0 *= input_shape.Dims(i); + } else { + *dim1 *= input_shape.Dims(i); + } + } + return true; +} + +void RemoveOneSizeDimensions(RuntimeShape* input_shape, + RuntimeShape* output_shape, + TransposeParams* params) { + const int dims_cnt = input_shape->DimensionsCount(); + TFLITE_DCHECK_EQ(params->perm_count, dims_cnt); + + bool foundOneSizeDim = false; + for (int i = 0; i < dims_cnt; ++i) { + if (input_shape->Dims(i) == 1) { + foundOneSizeDim = true; + break; + } + } + + // Return here if there is no one size dimension. + if (!foundOneSizeDim) return; + + // Handle the case where all the dimension size is one. + if (input_shape->FlatSize() == 1) { + input_shape->Resize(1); + input_shape->SetDim(0, 1); + output_shape->Resize(1); + output_shape->SetDim(0, 1); + params->perm_count = 1; + params->perm[0] = 0; + return; + } + + // Resize input shape. + int new_dims_cnt = 0; + for (int i = 0; i < dims_cnt; ++i) { + if (input_shape->Dims(i) == 1) { + continue; + } + input_shape->SetDim(new_dims_cnt, input_shape->Dims(i)); + ++new_dims_cnt; + } + input_shape->Resize(new_dims_cnt); + + // Resize output shape and re-calculate the perm parameter. 
+ TransposeParams new_params; + new_dims_cnt = 0; + for (int i = 0; i < dims_cnt; ++i) { + if (output_shape->Dims(i) == 1) { + continue; + } + new_params.perm[new_dims_cnt] = params->perm[i]; + output_shape->SetDim(new_dims_cnt, output_shape->Dims(i)); + ++new_dims_cnt; + } + output_shape->Resize(new_dims_cnt); + new_params.perm_count = new_dims_cnt; + + for (int i = 0; i < new_dims_cnt; ++i) { + int min_val_idx = -1; + for (int j = 0; j < new_dims_cnt; ++j) { + if (new_params.perm[j] >= i && + (min_val_idx == -1 || + new_params.perm[min_val_idx] > new_params.perm[j])) { + min_val_idx = j; + } + } + new_params.perm[min_val_idx] = i; + } + *params = new_params; +} + +size_t Flatten(const RuntimeShape& input_shape, + const RuntimeShape& output_shape, const TransposeParams& params, + RuntimeShape* non_flatten_input_shape, + RuntimeShape* non_flatten_output_shape, + TransposeParams* non_flatten_params) { + // Calculate the total size of non-flatten dimensions. + int skip_dims_cnt = 0; + size_t flat_size = input_shape.FlatSize(); + for (int i = 0; i < params.perm_count; ++i) { + if (params.perm[i] == i) { + flat_size /= input_shape.Dims(i); + ++skip_dims_cnt; + } else { + break; + } + } + + // Shrink the shapes and re-calculate the perm parameter. 
+ const int new_dims_cnt = params.perm_count - skip_dims_cnt; + non_flatten_input_shape->Resize(new_dims_cnt); + non_flatten_output_shape->Resize(new_dims_cnt); + non_flatten_params->perm_count = new_dims_cnt; + + for (int i = skip_dims_cnt; i < params.perm_count; ++i) { + non_flatten_input_shape->SetDim(i - skip_dims_cnt, input_shape.Dims(i)); + non_flatten_output_shape->SetDim(i - skip_dims_cnt, output_shape.Dims(i)); + non_flatten_params->perm[i - skip_dims_cnt] = params.perm[i]; + } + for (int i = 0; i < new_dims_cnt; ++i) { + int min_val_idx = -1; + for (int j = 0; j < new_dims_cnt; ++j) { + if (non_flatten_params->perm[j] >= i && + (min_val_idx == -1 || non_flatten_params->perm[min_val_idx] > + non_flatten_params->perm[j])) { + min_val_idx = j; + } + } + non_flatten_params->perm[min_val_idx] = i; + } + + return flat_size; +} + +} // namespace transpose_utils + +} // namespace tflite diff --git a/tensorflow/lite/kernels/internal/transpose_utils.h b/tensorflow/lite/kernels/internal/transpose_utils.h new file mode 100644 index 00000000000..b7fee18852e --- /dev/null +++ b/tensorflow/lite/kernels/internal/transpose_utils.h @@ -0,0 +1,52 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_TRANSPOSE_UTILS_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_TRANSPOSE_UTILS_H_ + +#include "tensorflow/lite/kernels/internal/types.h" + +namespace tflite { +namespace transpose_utils { + +// IsTranspose2DApplicable returns true if the given perm can be lowered to a +// 2D transpose op. If possible, it copies the lowered dimension counts to the +// given dim0 and dim1 pointers. +bool IsTranspose2DApplicable(const TransposeParams& params, + const RuntimeShape& input_shape, int* dim0, + int* dim1); + +// RemoveOneSizeDimensions removes one size dimensions in the given input/output +// shapes and adjusts the parameter values for transpose op. +void RemoveOneSizeDimensions(RuntimeShape* input_shape, + RuntimeShape* output_shape, + TransposeParams* params); + +// Flatten finds the dimensions that can be flatten, shrinks the given shapes +// and the given perm parameter to reflect the non-flatten dimensions, and +// returns the total size of the non-flatten dimensions. +// +// E.g, in perm [0, 1, 3, 2] case, the first two dimensions can be flatten and +// it returns |Dim Size(2)| x |Dim Size(3)|. +size_t Flatten(const RuntimeShape& input_shape, + const RuntimeShape& output_shape, const TransposeParams& params, + RuntimeShape* non_flatten_input_shape, + RuntimeShape* non_flatten_output_shape, + TransposeParams* non_flatten_params); + +} // namespace transpose_utils + +} // namespace tflite + +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_TRANSPOSE_UTILS_H_ diff --git a/tensorflow/lite/kernels/internal/transpose_utils_test.cc b/tensorflow/lite/kernels/internal/transpose_utils_test.cc new file mode 100644 index 00000000000..b55519f6a03 --- /dev/null +++ b/tensorflow/lite/kernels/internal/transpose_utils_test.cc @@ -0,0 +1,506 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/kernels/internal/transpose_utils.h" + +#include +#include + +namespace tflite { +namespace { + +TEST(TransposeUtilsTest, RemoveOneSizeDimensions_1DNoChanges) { + RuntimeShape input_shape({9}); + RuntimeShape output_shape({9}); + + TransposeParams params; + params.perm_count = 1; + params.perm[0] = 0; + + transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape, + ¶ms); + + EXPECT_EQ(input_shape, RuntimeShape({9})); + EXPECT_EQ(output_shape, RuntimeShape({9})); + + EXPECT_EQ(params.perm_count, 1); + EXPECT_EQ(params.perm[0], 0); +} + +TEST(TransposeUtilsTest, RemoveOneSizeDimensions_2DNoChanges) { + RuntimeShape input_shape({9, 3}); + RuntimeShape output_shape({3, 9}); + + TransposeParams params; + params.perm_count = 2; + params.perm[0] = 1; + params.perm[1] = 0; + + transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape, + ¶ms); + + EXPECT_EQ(input_shape, RuntimeShape({9, 3})); + EXPECT_EQ(output_shape, RuntimeShape({3, 9})); + + EXPECT_EQ(params.perm_count, 2); + EXPECT_EQ(params.perm[0], 1); + EXPECT_EQ(params.perm[1], 0); +} + +TEST(TransposeUtilsTest, RemoveOneSizeDimensions_2DShrinking) { + RuntimeShape input_shape({9, 1}); + RuntimeShape output_shape({1, 9}); + + TransposeParams params; + params.perm_count = 2; + params.perm[0] = 1; + params.perm[1] = 0; + + 
transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape, + ¶ms); + + EXPECT_EQ(input_shape, RuntimeShape({9})); + EXPECT_EQ(output_shape, RuntimeShape({9})); + + EXPECT_EQ(params.perm_count, 1); + EXPECT_EQ(params.perm[0], 0); +} + +TEST(TransposeUtilsTest, RemoveOneSizeDimensions_3DNoChanges) { + RuntimeShape input_shape({4, 3, 8}); + RuntimeShape output_shape({8, 4, 3}); + + TransposeParams params; + params.perm_count = 3; + params.perm[0] = 2; + params.perm[1] = 0; + params.perm[2] = 1; + + transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape, + ¶ms); + + EXPECT_EQ(input_shape, RuntimeShape({4, 3, 8})); + EXPECT_EQ(output_shape, RuntimeShape({8, 4, 3})); + + EXPECT_EQ(params.perm_count, 3); + EXPECT_EQ(params.perm[0], 2); + EXPECT_EQ(params.perm[1], 0); + EXPECT_EQ(params.perm[2], 1); +} + +TEST(TransposeUtilsTest, RemoveOneSizeDimensions_3DShrinkingOnce) { + RuntimeShape input_shape({4, 1, 8}); + RuntimeShape output_shape({8, 4, 1}); + + TransposeParams params; + params.perm_count = 3; + params.perm[0] = 2; + params.perm[1] = 0; + params.perm[2] = 1; + + transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape, + ¶ms); + + EXPECT_EQ(input_shape, RuntimeShape({4, 8})); + EXPECT_EQ(output_shape, RuntimeShape({8, 4})); + EXPECT_EQ(output_shape.Dims(1), 4); + + EXPECT_EQ(params.perm_count, 2); + EXPECT_EQ(params.perm[0], 1); + EXPECT_EQ(params.perm[1], 0); +} + +TEST(TransposeUtilsTest, RemoveOneSizeDimensions_3DShrinkingTwice) { + RuntimeShape input_shape({4, 1, 1}); + RuntimeShape output_shape({1, 4, 1}); + + TransposeParams params; + params.perm_count = 3; + params.perm[0] = 2; + params.perm[1] = 0; + params.perm[2] = 1; + + transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape, + ¶ms); + + EXPECT_EQ(input_shape, RuntimeShape({4})); + EXPECT_EQ(output_shape, RuntimeShape({4})); + + EXPECT_EQ(params.perm_count, 1); + EXPECT_EQ(params.perm[0], 0); +} + +TEST(TransposeUtilsTest, 
RemoveOneSizeDimensions_3DAllOnes) { + RuntimeShape input_shape({1, 1, 1}); + RuntimeShape output_shape({1, 1, 1}); + + TransposeParams params; + params.perm_count = 3; + params.perm[0] = 2; + params.perm[1] = 0; + params.perm[2] = 1; + + transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape, + ¶ms); + + EXPECT_EQ(input_shape, RuntimeShape({1})); + EXPECT_EQ(output_shape, RuntimeShape({1})); + + EXPECT_EQ(params.perm_count, 1); + EXPECT_EQ(params.perm[0], 0); +} + +TEST(TransposeUtilsTest, RemoveOneSizeDimensions_4DNoChanges) { + RuntimeShape input_shape({9, 3, 2, 4}); + RuntimeShape output_shape({3, 9, 4, 2}); + + TransposeParams params; + params.perm_count = 4; + params.perm[0] = 1; + params.perm[1] = 0; + params.perm[2] = 3; + params.perm[3] = 2; + + transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape, + ¶ms); + + EXPECT_EQ(input_shape, RuntimeShape({9, 3, 2, 4})); + EXPECT_EQ(output_shape, RuntimeShape({3, 9, 4, 2})); + + EXPECT_EQ(params.perm_count, 4); + EXPECT_EQ(params.perm[0], 1); + EXPECT_EQ(params.perm[1], 0); + EXPECT_EQ(params.perm[2], 3); + EXPECT_EQ(params.perm[3], 2); +} + +TEST(TransposeUtilsTest, RemoveOneSizeDimensions_4DShrinkingOnce) { + RuntimeShape input_shape({9, 3, 1, 4}); + RuntimeShape output_shape({3, 9, 4, 1}); + + TransposeParams params; + params.perm_count = 4; + params.perm[0] = 1; + params.perm[1] = 0; + params.perm[2] = 3; + params.perm[3] = 2; + + transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape, + ¶ms); + + EXPECT_EQ(input_shape, RuntimeShape({9, 3, 4})); + EXPECT_EQ(output_shape, RuntimeShape({3, 9, 4})); + + EXPECT_EQ(params.perm_count, 3); + EXPECT_EQ(params.perm[0], 1); + EXPECT_EQ(params.perm[1], 0); + EXPECT_EQ(params.perm[2], 2); +} + +TEST(TransposeUtilsTest, RemoveOneSizeDimensions_4DShrinkingTwice) { + RuntimeShape input_shape({1, 3, 1, 4}); + RuntimeShape output_shape({3, 1, 4, 1}); + + TransposeParams params; + params.perm_count = 4; + params.perm[0] = 1; + 
params.perm[1] = 2; + params.perm[2] = 3; + params.perm[3] = 0; + + transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape, + ¶ms); + + EXPECT_EQ(input_shape, RuntimeShape({3, 4})); + EXPECT_EQ(output_shape, RuntimeShape({3, 4})); + + EXPECT_EQ(params.perm_count, 2); + EXPECT_EQ(params.perm[0], 0); + EXPECT_EQ(params.perm[1], 1); +} + +TEST(TransposeUtilsTest, RemoveOneSizeDimensions_4DShrinkingThirdTimes) { + RuntimeShape input_shape({1, 1, 7, 1}); + RuntimeShape output_shape({1, 7, 1, 1}); + + TransposeParams params; + params.perm_count = 4; + params.perm[0] = 0; + params.perm[1] = 2; + params.perm[2] = 1; + params.perm[3] = 3; + + transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape, + ¶ms); + + EXPECT_EQ(input_shape, RuntimeShape({7})); + EXPECT_EQ(output_shape, RuntimeShape({7})); + + EXPECT_EQ(params.perm_count, 1); + EXPECT_EQ(params.perm[0], 0); +} + +TEST(TransposeUtilsTest, RemoveOneSizeDimensions_4DAllOnes) { + RuntimeShape input_shape({1, 1, 1, 1}); + RuntimeShape output_shape({1, 1, 1, 1}); + + TransposeParams params; + params.perm_count = 4; + params.perm[0] = 0; + params.perm[1] = 2; + params.perm[2] = 1; + params.perm[3] = 3; + + transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape, + ¶ms); + + EXPECT_EQ(input_shape, RuntimeShape({1})); + EXPECT_EQ(output_shape, RuntimeShape({1})); + + EXPECT_EQ(params.perm_count, 1); + EXPECT_EQ(params.perm[0], 0); +} + +TEST(TransposeUtilsTest, Flatten3D) { + RuntimeShape input_shape({3, 5, 7}); + RuntimeShape output_shape({3, 7, 5}); + + TransposeParams params; + params.perm_count = 3; + params.perm[0] = 0; + params.perm[1] = 2; + params.perm[2] = 1; + + RuntimeShape non_flatten_input_shape; + RuntimeShape non_flatten_output_shape; + TransposeParams non_flatten_params; + size_t non_flatten_size = transpose_utils::Flatten( + input_shape, output_shape, params, &non_flatten_input_shape, + &non_flatten_output_shape, &non_flatten_params); + + 
EXPECT_EQ(non_flatten_input_shape, RuntimeShape({5, 7})); + EXPECT_EQ(non_flatten_output_shape, RuntimeShape({7, 5})); + EXPECT_EQ(non_flatten_size, 5 * 7); + + EXPECT_EQ(non_flatten_params.perm_count, 2); + EXPECT_EQ(non_flatten_params.perm[0], 1); + EXPECT_EQ(non_flatten_params.perm[1], 0); +} + +TEST(TransposeUtilsTest, Flatten4DFlattenOnce) { + RuntimeShape input_shape({3, 5, 7, 9}); + RuntimeShape output_shape({3, 7, 5, 9}); + + TransposeParams params; + params.perm_count = 4; + params.perm[0] = 0; + params.perm[1] = 2; + params.perm[2] = 1; + params.perm[3] = 3; + + RuntimeShape non_flatten_input_shape; + RuntimeShape non_flatten_output_shape; + TransposeParams non_flatten_params; + size_t non_flatten_size = transpose_utils::Flatten( + input_shape, output_shape, params, &non_flatten_input_shape, + &non_flatten_output_shape, &non_flatten_params); + + EXPECT_EQ(non_flatten_input_shape, RuntimeShape({5, 7, 9})); + EXPECT_EQ(non_flatten_output_shape, RuntimeShape({7, 5, 9})); + EXPECT_EQ(non_flatten_size, 5 * 7 * 9); + + EXPECT_EQ(non_flatten_params.perm_count, 3); + EXPECT_EQ(non_flatten_params.perm[0], 1); + EXPECT_EQ(non_flatten_params.perm[1], 0); + EXPECT_EQ(non_flatten_params.perm[2], 2); +} + +TEST(TransposeUtilsTest, Flatten4DFlattenTwice) { + RuntimeShape input_shape({3, 5, 7, 9}); + RuntimeShape output_shape({3, 5, 9, 7}); + + TransposeParams params; + params.perm_count = 4; + params.perm[0] = 0; + params.perm[1] = 1; + params.perm[2] = 3; + params.perm[3] = 2; + + RuntimeShape non_flatten_input_shape; + RuntimeShape non_flatten_output_shape; + TransposeParams non_flatten_params; + size_t non_flatten_size = transpose_utils::Flatten( + input_shape, output_shape, params, &non_flatten_input_shape, + &non_flatten_output_shape, &non_flatten_params); + + EXPECT_EQ(non_flatten_input_shape, RuntimeShape({7, 9})); + EXPECT_EQ(non_flatten_output_shape, RuntimeShape({9, 7})); + EXPECT_EQ(non_flatten_size, 7 * 9); + + EXPECT_EQ(non_flatten_params.perm_count, 2); + 
EXPECT_EQ(non_flatten_params.perm[0], 1); + EXPECT_EQ(non_flatten_params.perm[1], 0); +} + +TEST(TransposeUtilsTest, IsTranspose2DApplicable2D) { + RuntimeShape input_shape({4, 5}); + + TransposeParams params; + params.perm_count = 2; + params.perm[0] = 1; + params.perm[1] = 0; + + int dim0, dim1; + bool applicable = transpose_utils::IsTranspose2DApplicable( + params, input_shape, &dim0, &dim1); + + EXPECT_TRUE(applicable); + EXPECT_EQ(dim0, 4); + EXPECT_EQ(dim1, 5); +} + +TEST(TransposeUtilsTest, IsTranspose2DApplicable3DOne) { + RuntimeShape input_shape({4, 5, 6}); + + TransposeParams params; + params.perm_count = 3; + params.perm[0] = 1; + params.perm[1] = 2; + params.perm[2] = 0; + + int dim0, dim1; + bool applicable = transpose_utils::IsTranspose2DApplicable( + params, input_shape, &dim0, &dim1); + + EXPECT_TRUE(applicable); + EXPECT_EQ(dim0, 4); + EXPECT_EQ(dim1, 30); +} + +TEST(TransposeUtilsTest, IsTranspose2DApplicable3DTwo) { + RuntimeShape input_shape({4, 5, 6}); + + TransposeParams params; + params.perm_count = 3; + params.perm[0] = 2; + params.perm[1] = 0; + params.perm[2] = 1; + + int dim0, dim1; + bool applicable = transpose_utils::IsTranspose2DApplicable( + params, input_shape, &dim0, &dim1); + + EXPECT_TRUE(applicable); + EXPECT_EQ(dim0, 20); + EXPECT_EQ(dim1, 6); +} + +TEST(TransposeUtilsTest, IsTranspose2DApplicable3DNotApplicable) { + RuntimeShape input_shape({4, 5, 6}); + + TransposeParams params; + params.perm_count = 3; + params.perm[0] = 2; + params.perm[1] = 1; + params.perm[2] = 0; + + int dim0, dim1; + bool applicable = transpose_utils::IsTranspose2DApplicable( + params, input_shape, &dim0, &dim1); + + EXPECT_FALSE(applicable); +} + +TEST(TransposeUtilsTest, IsTranspose2DApplicable4DOne) { + RuntimeShape input_shape({4, 5, 6, 7}); + + TransposeParams params; + params.perm_count = 4; + params.perm[0] = 1; + params.perm[1] = 2; + params.perm[2] = 3; + params.perm[3] = 0; + + int dim0, dim1; + bool applicable = 
transpose_utils::IsTranspose2DApplicable( + params, input_shape, &dim0, &dim1); + + EXPECT_TRUE(applicable); + EXPECT_EQ(dim0, 4); + EXPECT_EQ(dim1, 210); +} + +TEST(TransposeUtilsTest, IsTranspose2DApplicable4DTwo) { + RuntimeShape input_shape({4, 5, 6, 7}); + + TransposeParams params; + params.perm_count = 4; + params.perm[0] = 2; + params.perm[1] = 3; + params.perm[2] = 0; + params.perm[3] = 1; + + int dim0, dim1; + bool applicable = transpose_utils::IsTranspose2DApplicable( + params, input_shape, &dim0, &dim1); + + EXPECT_TRUE(applicable); + EXPECT_EQ(dim0, 20); + EXPECT_EQ(dim1, 42); +} + +TEST(TransposeUtilsTest, IsTranspose2DApplicable4DThird) { + RuntimeShape input_shape({4, 5, 6, 7}); + + TransposeParams params; + params.perm_count = 4; + params.perm[0] = 3; + params.perm[1] = 0; + params.perm[2] = 1; + params.perm[3] = 2; + + int dim0, dim1; + bool applicable = transpose_utils::IsTranspose2DApplicable( + params, input_shape, &dim0, &dim1); + + EXPECT_TRUE(applicable); + EXPECT_EQ(dim0, 120); + EXPECT_EQ(dim1, 7); +} + +TEST(TransposeUtilsTest, IsTranspose2DApplicable4DNotApplicable) { + RuntimeShape input_shape({4, 5, 6, 7}); + + TransposeParams params; + params.perm_count = 4; + params.perm[0] = 3; + params.perm[1] = 2; + params.perm[2] = 1; + params.perm[3] = 0; + + int dim0, dim1; + bool applicable = transpose_utils::IsTranspose2DApplicable( + params, input_shape, &dim0, &dim1); + + EXPECT_FALSE(applicable); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/lite/kernels/transpose.cc b/tensorflow/lite/kernels/transpose.cc index c2d2cde2a68..f3e00c24d49 100644 --- a/tensorflow/lite/kernels/transpose.cc +++ b/tensorflow/lite/kernels/transpose.cc @@ -100,18 +100,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const int size = op_context.perm->dims->data[0]; 
TransposeParams params; params.perm_count = size; - bool identical = true; for (int i = 0; i < size; ++i) { params.perm[i] = perm_data[i]; - if (perm_data[i] != i) identical = false; - } - - // TODO(b/140779653): Add an optimization pass in the conversion process to - // remove transpose op nodes where they do nothing like the below one. - if (identical) { - memcpy(op_context.output->data.raw, op_context.input->data.raw, - op_context.output->bytes); - return kTfLiteOk; } #define TF_LITE_TRANSPOSE(type, scalar) \ @@ -120,28 +110,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTensorShape(op_context.output), \ GetTensorData(op_context.output)) + // Transpose kernel only does rearranging values not numeric evaluations on + // each cell. It's safe to implement per size of scalar type and this trick + // keeps the total code size in a reasonable range. switch (op_context.input->type) { case kTfLiteFloat32: - if (kernel_type == kGenericOptimized) { - TF_LITE_TRANSPOSE(optimized_ops, float); - } else { - TF_LITE_TRANSPOSE(reference_ops, float); - } - break; - case kTfLiteUInt8: - if (kernel_type == kGenericOptimized) { - TF_LITE_TRANSPOSE(optimized_ops, uint8_t); - } else { - TF_LITE_TRANSPOSE(reference_ops, uint8_t); - } - break; - case kTfLiteInt8: - if (kernel_type == kGenericOptimized) { - TF_LITE_TRANSPOSE(optimized_ops, int8_t); - } else { - TF_LITE_TRANSPOSE(reference_ops, int8_t); - } - break; case kTfLiteInt32: if (kernel_type == kGenericOptimized) { TF_LITE_TRANSPOSE(optimized_ops, int32_t); @@ -149,16 +122,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_TRANSPOSE(reference_ops, int32_t); } break; - case kTfLiteInt64: + case kTfLiteUInt8: + case kTfLiteInt8: if (kernel_type == kGenericOptimized) { - TF_LITE_TRANSPOSE(optimized_ops, int64_t); + TF_LITE_TRANSPOSE(optimized_ops, int8_t); } else { - TF_LITE_TRANSPOSE(reference_ops, int64_t); + TF_LITE_TRANSPOSE(reference_ops, int8_t); } break; + case kTfLiteInt64: + 
TF_LITE_TRANSPOSE(reference_ops, int64_t); + break; case kTfLiteBool: - if (kernel_type == kGenericOptimized) { - TF_LITE_TRANSPOSE(optimized_ops, bool); + if (sizeof(bool) == 1) { + if (kernel_type == kGenericOptimized) { + TF_LITE_TRANSPOSE(optimized_ops, int8_t); + } else { + TF_LITE_TRANSPOSE(reference_ops, int8_t); + } } else { TF_LITE_TRANSPOSE(reference_ops, bool); } diff --git a/tensorflow/lite/kernels/transpose_test.cc b/tensorflow/lite/kernels/transpose_test.cc index 32227c32496..327692bb8eb 100644 --- a/tensorflow/lite/kernels/transpose_test.cc +++ b/tensorflow/lite/kernels/transpose_test.cc @@ -346,6 +346,167 @@ TEST(TransposeTest, Test3DInputDynamicTensor) { 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23})); } +TEST(TransposeTest, Test1DNotShrinked) { + TransposeOpConstModel m({1}, {1}, {0}); + m.SetInput({0}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0})); +} + +TEST(TransposeTest, Test2DShrinkedOneTime) { + TransposeOpConstModel m({2, 1}, {2}, {1, 0}); + m.SetInput({0, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 1})); +} + +TEST(TransposeTest, Test2DShrinkedTwoTimes) { + TransposeOpConstModel m({1, 1}, {2}, {1, 0}); + m.SetInput({0}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0})); +} + +TEST(TransposeTest, Test3DShrinkedOneTime) { + TransposeOpConstModel m({2, 1, 3}, {3}, {0, 2, 1}); + m.SetInput({0, 1, 2, 3, 4, 5}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 1, 2, 3, 4, 5})); +} + +TEST(TransposeTest, Test3DShrinkedTwoTimes) { + TransposeOpConstModel m({1, 1, 3}, {3}, {1, 2, 0}); + m.SetInput({0, 1, 2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 
1, 2})); +} + +TEST(TransposeTest, Test3DShrinkedAll) { + TransposeOpConstModel m({1, 1, 1}, {3}, {1, 2, 0}); + m.SetInput({0}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0})); +} + +TEST(TransposeTest, Test4DShrinkedOneTimes) { + TransposeOpConstModel m({2, 2, 3, 1}, {4}, {3, 0, 1, 2}); + m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 2, 3})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11})); +} + +TEST(TransposeTest, Test4DShrinkedTwoTimes) { + TransposeOpConstModel m({2, 1, 3, 1}, {4}, {0, 3, 1, 2}); + m.SetInput({0, 1, 2, 3, 4, 5}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 1, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 1, 2, 3, 4, 5})); +} + +TEST(TransposeTest, Test4DShrinkedThirdTimes) { + TransposeOpConstModel m({2, 1, 1, 1}, {4}, {3, 2, 1, 0}); + m.SetInput({0, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 1})); +} + +TEST(TransposeTest, Test4DShrinkedFourTimes) { + TransposeOpConstModel m({1, 1, 1, 1}, {4}, {2, 3, 1, 0}); + m.SetInput({0}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0})); +} + +TEST(TransposeTest, Test3DFlatten) { + TransposeOpConstModel m({2, 2, 3}, {3}, {0, 2, 1}); + m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({0, 3, 1, 4, 2, 5, 6, 9, 7, 10, 8, 11})); +} + +TEST(TransposeTest, Test4DFlattenOne) { + TransposeOpConstModel m({2, 2, 2, 2}, {4}, {0, 1, 3, 2}); + m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), 
ElementsAreArray({2, 2, 2, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2, 1, 3, 4, 6, 5, 7, 8, 10, 9, + 11, 12, 14, 13, 15})); +} + +TEST(TransposeTest, Test4DFlattenTwo) { + TransposeOpConstModel m({2, 2, 2, 2}, {4}, {0, 2, 3, 1}); + m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 4, 1, 5, 2, 6, 3, 7, 8, 12, 9, + 13, 10, 14, 11, 15})); +} + +TEST(TransposeTest, 3DDividedIntoTwo2DsOne) { + std::vector out; + RunTestPermutation({2, 3, 4}, {1, 2, 0}, &out); + TransposeOpConstModel m({2, 3, 4}, {3}, {1, 2, 0}); + m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); + m.Invoke(); + EXPECT_EQ(m.GetOutput(), out); +} + +TEST(TransposeTest, 3DDividedIntoTwo2DsTwo) { + std::vector out; + RunTestPermutation({2, 3, 4}, {2, 0, 1}, &out); + TransposeOpConstModel m({2, 3, 4}, {3}, {2, 0, 1}); + m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); + m.Invoke(); + EXPECT_EQ(m.GetOutput(), out); +} + +TEST(TransposeTest, 4DDividedIntoTwo2DsOne) { + std::vector out; + RunTestPermutation({2, 3, 4, 2}, {1, 2, 3, 0}, &out); + TransposeOpConstModel m({2, 3, 4, 2}, {4}, {1, 2, 3, 0}); + m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}); + m.Invoke(); + EXPECT_EQ(m.GetOutput(), out); +} + +TEST(TransposeTest, 4DDividedIntoTwo2DsTwo) { + std::vector out; + RunTestPermutation({2, 3, 4, 2}, {2, 3, 0, 1}, &out); + TransposeOpConstModel m({2, 3, 4, 2}, {4}, {2, 3, 0, 1}); + m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}); + m.Invoke(); + 
EXPECT_EQ(m.GetOutput(), out); +} + +TEST(TransposeTest, 4DDividedIntoTwo2DsThird) { + std::vector out; + RunTestPermutation({2, 3, 4, 2}, {3, 0, 1, 2}, &out); + TransposeOpConstModel m({2, 3, 4, 2}, {4}, {3, 0, 1, 2}); + m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}); + m.Invoke(); + EXPECT_EQ(m.GetOutput(), out); +} + #ifdef GTEST_HAS_DEATH_TEST TEST(TransposeTest, Test5DInputTensor) { EXPECT_DEATH(TransposeOpConstModel({1, 2, 3, 4, 5}, {5}, {0, 1, 2, 3, 4}),