Implement general dimension reduction optimizations

PiperOrigin-RevId: 276480378
Change-Id: Iaf8f9228c51cf4f39c95c7f66d02abea2ad2b9fa
This commit is contained in:
Jaesung Chung 2019-10-24 07:04:56 -07:00 committed by TensorFlower Gardener
parent 8598d2c233
commit 37248520e4
7 changed files with 1004 additions and 167 deletions

View File

@ -229,6 +229,7 @@ cc_library(
":round",
":tensor",
":tensor_utils",
":transpose_utils",
"//third_party/eigen3",
"@gemmlowp//:fixedpoint",
"@gemmlowp//:profiler",
@ -271,6 +272,7 @@ cc_library(
":strided_slice_logic",
":tensor",
":tensor_utils",
":transpose_utils",
":types",
":legacy_types",
":legacy_reference_base",
@ -349,6 +351,28 @@ cc_test(
],
)
# Shared helpers (shape shrinking, flattening, 2D-lowering checks) used by
# both the optimized and reference Transpose kernel implementations.
cc_library(
    name = "transpose_utils",
    srcs = [
        "transpose_utils.cc",
    ],
    hdrs = [
        "transpose_utils.h",
    ],
    deps = [
        # Only depends on the lightweight shape/params types, keeping this
        # library usable from both kernel variants without cycles.
        ":types",
    ],
)
# Unit tests for the transpose dimension-reduction helpers.
cc_test(
    name = "transpose_utils_test",
    srcs = ["transpose_utils_test.cc"],
    deps = [
        ":transpose_utils",
        "@com_google_googletest//:gtest",
    ],
)
cc_library(
name = "strided_slice_logic",
srcs = [],

View File

@ -51,6 +51,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/transpose_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
@ -7095,15 +7096,10 @@ inline void Logistic16bitPercision(const LogisticParams& params,
// Perform transpose by transposing 4x4 blocks of the input, proceeding from
// left to right (down the rows) of the input, and then from top to bottom.
template <typename T>
inline void Transpose2DImpl(const TransposeParams& params,
const RuntimeShape& input_shape,
const T* input_data,
const RuntimeShape& output_shape, T* output_data) {
inline void Transpose2D(const RuntimeShape& input_shape, const T* input_data,
const RuntimeShape& output_shape, T* output_data) {
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 2);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2);
TFLITE_DCHECK_EQ(params.perm_count, 2);
TFLITE_DCHECK_EQ(params.perm[0], 1);
TFLITE_DCHECK_EQ(params.perm[1], 0);
const int d0 = input_shape.DimsData()[0];
const int d1 = input_shape.DimsData()[1];
@ -7196,16 +7192,12 @@ inline void Transpose2DImpl(const TransposeParams& params,
}
template <>
inline void Transpose2DImpl(const TransposeParams& params,
const RuntimeShape& input_shape,
const int32_t* input_data,
const RuntimeShape& output_shape,
int32_t* output_data) {
inline void Transpose2D(const RuntimeShape& input_shape,
const int32_t* input_data,
const RuntimeShape& output_shape,
int32_t* output_data) {
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 2);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2);
TFLITE_DCHECK_EQ(params.perm_count, 2);
TFLITE_DCHECK_EQ(params.perm[0], 1);
TFLITE_DCHECK_EQ(params.perm[1], 0);
const int d0 = input_shape.DimsData()[0];
const int d1 = input_shape.DimsData()[1];
@ -7278,93 +7270,17 @@ inline void Transpose2DImpl(const TransposeParams& params,
}
}
template <typename T>
void Transpose2D(const TransposeParams& params, const RuntimeShape& input_shape,
const T* input_data, const RuntimeShape& output_shape,
T* output_data) {
// Transpose kernel only does rearranging values not numeric evaluations on
// each cell. It's safe to implement per size of scalar type and this trick
// keeps the total code size in a reasonable range.
switch (sizeof(T)) {
case 1:
Transpose2DImpl(params, input_shape,
reinterpret_cast<const int8_t*>(input_data), output_shape,
reinterpret_cast<int8_t*>(output_data));
break;
case 4:
Transpose2DImpl(params, input_shape,
reinterpret_cast<const int32_t*>(input_data),
output_shape, reinterpret_cast<int32_t*>(output_data));
break;
default:
// Reroute to the reference version if an optimized method for the given
// data is not available.
reference_ops::Transpose(params, input_shape, input_data, output_shape,
output_data);
}
}
// TODO(alanchiao): see if we can reduce the number
// of lines of code in branching without affecting latency.
template <typename T>
inline void Transpose3DImpl(const TransposeParams& params,
const RuntimeShape& input_shape,
const T* input_data,
const RuntimeShape& output_shape, T* output_data) {
inline void Transpose3D(const TransposeParams& params,
const RuntimeShape& input_shape, const T* input_data,
const RuntimeShape& output_shape, T* output_data) {
int s1, s2, s3;
s1 = input_shape.Dims(0);
s2 = input_shape.Dims(1);
s3 = input_shape.Dims(2);
// TODO(b/141169757): generalize the following logics and move to the
// Transpose method.
const bool hasOneInDimension = (s1 == 1 || s2 == 1 || s3 == 1);
// Can fast path as 2D transpose in this case.
if (hasOneInDimension) {
int d1, d2;
bool is_identity = false;
// Check for identity to just return.
if (s1 == 1) {
// (0, 1, 2), (1, 0, 2), (1, 2, 0)
if ((params.perm[0] == 0 && params.perm[1] == 1) || params.perm[0] == 1) {
is_identity = true;
}
d1 = s2;
d2 = s3;
} else if (s2 == 1) {
// (0, 1, 2), (0, 2, 1), (1, 0, 2)
if ((params.perm[0] == 1 && params.perm[1] == 0) || params.perm[0] == 0) {
is_identity = true;
}
d1 = s1;
d2 = s3;
} else {
// (0, 1, 2), (0, 2, 1), (2, 0, 1)
if ((params.perm[0] == 2 && params.perm[1] == 0) || params.perm[0] == 0) {
is_identity = true;
}
d1 = s1;
d2 = s2;
}
if (is_identity) {
memcpy(output_data, input_data, sizeof(T) * input_shape.FlatSize());
return;
}
TransposeParams new_params;
new_params.perm_count = 2;
new_params.perm[0] = 1;
new_params.perm[1] = 0;
const RuntimeShape new_input_shape({d1, d2});
const RuntimeShape new_output_shape({d2, d1});
Transpose2D(new_params, new_input_shape, input_data, new_output_shape,
output_data);
return;
}
int p1, p2, p3;
if (params.perm[0] == 2) {
p1 = 1;
@ -7407,44 +7323,16 @@ inline void Transpose3DImpl(const TransposeParams& params,
}
template <typename T>
void Transpose3D(const TransposeParams& params, const RuntimeShape& input_shape,
const T* input_data, const RuntimeShape& output_shape,
T* output_data) {
// Transpose kernel only does rearranging values not numeric evaluations on
// each cell. It's safe to implement per size of scalar type and this trick
// keeps the total code size in a reasonable range.
switch (sizeof(T)) {
case 1:
Transpose3DImpl(params, input_shape,
reinterpret_cast<const int8_t*>(input_data), output_shape,
reinterpret_cast<int8_t*>(output_data));
break;
case 4:
Transpose3DImpl(params, input_shape,
reinterpret_cast<const int32_t*>(input_data),
output_shape, reinterpret_cast<int32_t*>(output_data));
break;
default:
// Reroute to the reference version if an optimized method for the given
// data is not available.
reference_ops::Transpose(params, input_shape, input_data, output_shape,
output_data);
}
}
void TransposeImpl(const TransposeParams& params,
const RuntimeShape& input_shape, const T* input_data,
const RuntimeShape& output_shape, T* output_data) {
const int dims_cnt = input_shape.DimensionsCount();
template <typename T>
void Transpose(const TransposeParams& params, const RuntimeShape& input_shape,
const T* input_data, const RuntimeShape& output_shape,
T* output_data) {
const int output_size = output_shape.DimensionsCount();
TFLITE_DCHECK_LE(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(output_size, 4);
TFLITE_DCHECK_EQ(output_size, params.perm_count);
// Apply 2-D transpose.
if (input_shape.DimensionsCount() == 2 && params.perm[0] == 1 &&
params.perm[1] == 0) {
Transpose2D(params, input_shape, input_data, output_shape, output_data);
int dim0, dim1;
if (transpose_utils::IsTranspose2DApplicable(params, input_shape, &dim0,
&dim1)) {
Transpose2D(RuntimeShape({dim0, dim1}), input_data,
RuntimeShape({dim1, dim0}), output_data);
return;
}
@ -7458,7 +7346,7 @@ void Transpose(const TransposeParams& params, const RuntimeShape& input_shape,
// 96^3 is not mobile-friendly for certain usecases
// (e.g. model used in beam search for seq2seq) but is in others.
// Consider tradeoffs.
if (input_shape.DimensionsCount() == 3) {
if (dims_cnt == 3) {
Transpose3D(params, input_shape, input_data, output_shape, output_data);
return;
}
@ -7469,6 +7357,66 @@ void Transpose(const TransposeParams& params, const RuntimeShape& input_shape,
output_data);
}
// Entry point for the optimized Transpose kernel. Before dispatching to the
// rank-specific implementations it normalizes the problem: size-1 dimensions
// are stripped, identity permutations become a memcpy, and any identity
// prefix of the permutation is peeled off into an outer batch loop. The
// "unshrinked_*" names refer to the caller-provided (un-normalized) values.
template <typename T>
void Transpose(const TransposeParams& unshrinked_params,
               const RuntimeShape& unshrinked_input_shape, const T* input_data,
               const RuntimeShape& unshrinked_output_shape, T* output_data) {
  gemmlowp::ScopedProfilingLabel label("Transpose");

  const int output_size = unshrinked_output_shape.DimensionsCount();
  // Only ranks up to 4 are supported by the implementations below.
  TFLITE_DCHECK_LE(unshrinked_input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(output_size, 4);
  TFLITE_DCHECK_EQ(output_size, unshrinked_params.perm_count);

  // Work on mutable copies; the normalization passes rewrite them in place.
  RuntimeShape shrinked_input_shape = RuntimeShape(unshrinked_input_shape);
  RuntimeShape shrinked_output_shape = RuntimeShape(unshrinked_output_shape);
  TransposeParams shrinked_params = unshrinked_params;

  // Reduce any dimensions that have one size. Lower transpose op usually
  // performs better since memory access patterns will be improved.
  transpose_utils::RemoveOneSizeDimensions(
      &shrinked_input_shape, &shrinked_output_shape, &shrinked_params);

  // Handle identity cases.
  // TODO(b/140779653): Add an optimization pass in the conversion process to
  // remove transpose op nodes where they do nothing like the below one.
  bool identical = true;
  for (int i = 0; i < shrinked_params.perm_count; ++i) {
    if (shrinked_params.perm[i] != i) {
      identical = false;
      break;
    }
  }
  if (identical) {
    // An identity permutation is a straight copy of the whole buffer.
    memcpy(output_data, input_data,
           unshrinked_input_shape.FlatSize() * sizeof(T));
    return;
  }

  // Reduce dimensions by flattening. When the permutation starts with 0 the
  // leading identity dimensions act as a batch: transpose each slice
  // independently with a lower-rank kernel.
  if (shrinked_params.perm[0] == 0 && output_size >= 3) {
    RuntimeShape non_flatten_input_shape;
    RuntimeShape non_flatten_output_shape;
    TransposeParams non_flatten_params;
    const int total_size = shrinked_input_shape.FlatSize();
    const int non_flatten_size = transpose_utils::Flatten(
        shrinked_input_shape, shrinked_output_shape, shrinked_params,
        &non_flatten_input_shape, &non_flatten_output_shape,
        &non_flatten_params);
    TFLITE_DCHECK_NE(non_flatten_params.perm[0], 0);

    // Each slice of `non_flatten_size` elements is transposed in place at the
    // matching offset of the output buffer.
    for (int i = 0; i < total_size; i += non_flatten_size) {
      TransposeImpl(non_flatten_params, non_flatten_input_shape, input_data + i,
                    non_flatten_output_shape, output_data + i);
    }
    return;
  }

  // Call non-flattened case.
  TransposeImpl(shrinked_params, shrinked_input_shape, input_data,
                shrinked_output_shape, output_data);
}
} // namespace optimized_ops
} // namespace tflite

View File

@ -0,0 +1,165 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/transpose_utils.h"
namespace tflite {
namespace transpose_utils {
// Returns true when the requested permutation is a cyclic rotation of the
// dimensions, i.e. the whole transpose can be lowered to one 2D transpose of
// shape [*dim0, *dim1]. On success the two flattened group sizes are written
// through dim0/dim1; on failure the outputs are left untouched.
bool IsTranspose2DApplicable(const TransposeParams& params,
                             const RuntimeShape& input_shape, int* dim0,
                             int* dim1) {
  const int rank = input_shape.DimensionsCount();

  // A genuine 2D input is trivially expressible as a 2D transpose.
  if (rank == 2) {
    *dim0 = input_shape.Dims(0);
    *dim1 = input_shape.Dims(1);
    return true;
  }

  // Every perm entry must advance by exactly one (modulo rank) relative to
  // the first entry; otherwise the permutation is not a rotation.
  const int pivot = params.perm[0];
  for (int i = 1; i < rank; ++i) {
    int shifted = params.perm[i] - pivot;
    if (shifted < 0) shifted += rank;
    if (shifted != i) return false;
  }

  // Dimensions in front of the pivot collapse into dim0, the rest into dim1.
  *dim0 = 1;
  *dim1 = 1;
  for (int i = 0; i < rank; ++i) {
    *(i < pivot ? dim0 : dim1) *= input_shape.Dims(i);
  }
  return true;
}
// Strips every size-1 dimension from the input/output shapes in place and
// rewrites `params` so the permutation refers to the surviving dimensions
// only. If no dimension has size 1 the arguments are left unchanged.
void RemoveOneSizeDimensions(RuntimeShape* input_shape,
                             RuntimeShape* output_shape,
                             TransposeParams* params) {
  const int dims_cnt = input_shape->DimensionsCount();
  TFLITE_DCHECK_EQ(params->perm_count, dims_cnt);

  bool foundOneSizeDim = false;
  for (int i = 0; i < dims_cnt; ++i) {
    if (input_shape->Dims(i) == 1) {
      foundOneSizeDim = true;
      break;
    }
  }

  // Return here if there is no one size dimension.
  if (!foundOneSizeDim) return;

  // Handle the case where all the dimension size is one: collapse to a
  // canonical 1-element, rank-1 identity problem.
  if (input_shape->FlatSize() == 1) {
    input_shape->Resize(1);
    input_shape->SetDim(0, 1);
    output_shape->Resize(1);
    output_shape->SetDim(0, 1);
    params->perm_count = 1;
    params->perm[0] = 0;
    return;
  }

  // Resize input shape: compact the non-1 dimensions to the front.
  int new_dims_cnt = 0;
  for (int i = 0; i < dims_cnt; ++i) {
    if (input_shape->Dims(i) == 1) {
      continue;
    }
    input_shape->SetDim(new_dims_cnt, input_shape->Dims(i));
    ++new_dims_cnt;
  }
  input_shape->Resize(new_dims_cnt);

  // Resize output shape and re-calculate the perm parameter. Keeping the
  // perm entries of surviving output dimensions preserves their relative
  // order; they are renumbered to a dense range below.
  TransposeParams new_params;
  new_dims_cnt = 0;
  for (int i = 0; i < dims_cnt; ++i) {
    if (output_shape->Dims(i) == 1) {
      continue;
    }
    new_params.perm[new_dims_cnt] = params->perm[i];
    output_shape->SetDim(new_dims_cnt, output_shape->Dims(i));
    ++new_dims_cnt;
  }
  output_shape->Resize(new_dims_cnt);
  new_params.perm_count = new_dims_cnt;

  // Renumber the surviving perm values onto 0..new_dims_cnt-1 while keeping
  // their relative order: pass i assigns value i to the smallest remaining
  // original value (already-assigned values are < i and thus skipped).
  for (int i = 0; i < new_dims_cnt; ++i) {
    int min_val_idx = -1;
    for (int j = 0; j < new_dims_cnt; ++j) {
      if (new_params.perm[j] >= i &&
          (min_val_idx == -1 ||
           new_params.perm[min_val_idx] > new_params.perm[j])) {
        min_val_idx = j;
      }
    }
    new_params.perm[min_val_idx] = i;
  }

  *params = new_params;
}
// Peels off the leading dimensions that the permutation leaves in place
// (perm[i] == i) and reports shrunken shapes/params that cover only the
// remaining, truly transposed dimensions. Returns the element count of one
// such non-flattened slice; the caller batches over the peeled prefix.
size_t Flatten(const RuntimeShape& input_shape,
               const RuntimeShape& output_shape, const TransposeParams& params,
               RuntimeShape* non_flatten_input_shape,
               RuntimeShape* non_flatten_output_shape,
               TransposeParams* non_flatten_params) {
  // Measure the identity prefix of the permutation and divide it out of the
  // total element count to get the per-slice size.
  int skip = 0;
  size_t slice_size = input_shape.FlatSize();
  while (skip < params.perm_count && params.perm[skip] == skip) {
    slice_size /= input_shape.Dims(skip);
    ++skip;
  }

  // Copy the surviving dimensions and re-base the permutation. The surviving
  // perm entries are exactly the set {skip, ..., perm_count - 1}, so shifting
  // them down by `skip` yields the dense 0-based permutation directly.
  const int rest = params.perm_count - skip;
  non_flatten_input_shape->Resize(rest);
  non_flatten_output_shape->Resize(rest);
  non_flatten_params->perm_count = rest;
  for (int i = 0; i < rest; ++i) {
    non_flatten_input_shape->SetDim(i, input_shape.Dims(i + skip));
    non_flatten_output_shape->SetDim(i, output_shape.Dims(i + skip));
    non_flatten_params->perm[i] = params.perm[i + skip] - skip;
  }

  return slice_size;
}
} // namespace transpose_utils
} // namespace tflite

View File

@ -0,0 +1,52 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_TRANSPOSE_UTILS_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_TRANSPOSE_UTILS_H_

#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace transpose_utils {

// IsTranspose2DApplicable returns true if the given perm can be lowered to a
// 2D transpose op. If possible, it writes the two lowered (flattened)
// dimension sizes to the given dim0 and dim1 pointers.
bool IsTranspose2DApplicable(const TransposeParams& params,
                             const RuntimeShape& input_shape, int* dim0,
                             int* dim1);

// RemoveOneSizeDimensions removes size-1 dimensions from the given
// input/output shapes and adjusts the perm parameter of the transpose op
// accordingly. Shapes and params are modified in place.
void RemoveOneSizeDimensions(RuntimeShape* input_shape,
                             RuntimeShape* output_shape,
                             TransposeParams* params);

// Flatten finds the leading dimensions that can be flattened (left untouched
// by the permutation), shrinks the given shapes and the given perm parameter
// to reflect the non-flattened dimensions, and returns the total element
// count of the non-flattened dimensions.
//
// E.g., in the perm [0, 1, 3, 2] case, the first two dimensions can be
// flattened and it returns |Dim Size(2)| x |Dim Size(3)|.
size_t Flatten(const RuntimeShape& input_shape,
               const RuntimeShape& output_shape, const TransposeParams& params,
               RuntimeShape* non_flatten_input_shape,
               RuntimeShape* non_flatten_output_shape,
               TransposeParams* non_flatten_params);

}  // namespace transpose_utils
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_TRANSPOSE_UTILS_H_

View File

@ -0,0 +1,506 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/transpose_utils.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
namespace tflite {
namespace {
// A 1D identity with no size-1 dimension must pass through untouched.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_1DNoChanges) {
  RuntimeShape input_shape({9});
  RuntimeShape output_shape({9});
  TransposeParams params;
  params.perm_count = 1;
  params.perm[0] = 0;

  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({9}));
  EXPECT_EQ(output_shape, RuntimeShape({9}));
  EXPECT_EQ(params.perm_count, 1);
  EXPECT_EQ(params.perm[0], 0);
}
// A 2D swap with no size-1 dimension keeps shapes and perm unchanged.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_2DNoChanges) {
  RuntimeShape input_shape({9, 3});
  RuntimeShape output_shape({3, 9});
  TransposeParams params;
  params.perm_count = 2;
  params.perm[0] = 1;
  params.perm[1] = 0;

  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({9, 3}));
  EXPECT_EQ(output_shape, RuntimeShape({3, 9}));
  EXPECT_EQ(params.perm_count, 2);
  EXPECT_EQ(params.perm[0], 1);
  EXPECT_EQ(params.perm[1], 0);
}
// {9, 1} transposed is just a copy; the size-1 dim collapses to 1D identity.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_2DShrinking) {
  RuntimeShape input_shape({9, 1});
  RuntimeShape output_shape({1, 9});
  TransposeParams params;
  params.perm_count = 2;
  params.perm[0] = 1;
  params.perm[1] = 0;

  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({9}));
  EXPECT_EQ(output_shape, RuntimeShape({9}));
  EXPECT_EQ(params.perm_count, 1);
  EXPECT_EQ(params.perm[0], 0);
}
// A full 3D rotation with no size-1 dimension passes through unchanged.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_3DNoChanges) {
  RuntimeShape input_shape({4, 3, 8});
  RuntimeShape output_shape({8, 4, 3});
  TransposeParams params;
  params.perm_count = 3;
  params.perm[0] = 2;
  params.perm[1] = 0;
  params.perm[2] = 1;

  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({4, 3, 8}));
  EXPECT_EQ(output_shape, RuntimeShape({8, 4, 3}));
  EXPECT_EQ(params.perm_count, 3);
  EXPECT_EQ(params.perm[0], 2);
  EXPECT_EQ(params.perm[1], 0);
  EXPECT_EQ(params.perm[2], 1);
}
// Stripping the single size-1 dimension from {4, 1, 8} must leave a plain 2D
// transpose: shapes shrink to {4, 8}/{8, 4} and perm (2, 0, 1) renumbers to
// (1, 0).
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_3DShrinkingOnce) {
  RuntimeShape input_shape({4, 1, 8});
  RuntimeShape output_shape({8, 4, 1});
  TransposeParams params;
  params.perm_count = 3;
  params.perm[0] = 2;
  params.perm[1] = 0;
  params.perm[2] = 1;

  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({4, 8}));
  EXPECT_EQ(output_shape, RuntimeShape({8, 4}));
  // Dropped the redundant `EXPECT_EQ(output_shape.Dims(1), 4);` -- it is
  // already implied by the whole-shape equality check above.
  EXPECT_EQ(params.perm_count, 2);
  EXPECT_EQ(params.perm[0], 1);
  EXPECT_EQ(params.perm[1], 0);
}
// Two size-1 dims removed: {4, 1, 1} collapses to a 1D identity of {4}.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_3DShrinkingTwice) {
  RuntimeShape input_shape({4, 1, 1});
  RuntimeShape output_shape({1, 4, 1});
  TransposeParams params;
  params.perm_count = 3;
  params.perm[0] = 2;
  params.perm[1] = 0;
  params.perm[2] = 1;

  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({4}));
  EXPECT_EQ(output_shape, RuntimeShape({4}));
  EXPECT_EQ(params.perm_count, 1);
  EXPECT_EQ(params.perm[0], 0);
}
// All-ones shape hits the FlatSize()==1 fast path: canonical rank-1 identity.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_3DAllOnes) {
  RuntimeShape input_shape({1, 1, 1});
  RuntimeShape output_shape({1, 1, 1});
  TransposeParams params;
  params.perm_count = 3;
  params.perm[0] = 2;
  params.perm[1] = 0;
  params.perm[2] = 1;

  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({1}));
  EXPECT_EQ(output_shape, RuntimeShape({1}));
  EXPECT_EQ(params.perm_count, 1);
  EXPECT_EQ(params.perm[0], 0);
}
// A 4D pairwise-swap perm with no size-1 dimension passes through unchanged.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_4DNoChanges) {
  RuntimeShape input_shape({9, 3, 2, 4});
  RuntimeShape output_shape({3, 9, 4, 2});
  TransposeParams params;
  params.perm_count = 4;
  params.perm[0] = 1;
  params.perm[1] = 0;
  params.perm[2] = 3;
  params.perm[3] = 2;

  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({9, 3, 2, 4}));
  EXPECT_EQ(output_shape, RuntimeShape({3, 9, 4, 2}));
  EXPECT_EQ(params.perm_count, 4);
  EXPECT_EQ(params.perm[0], 1);
  EXPECT_EQ(params.perm[1], 0);
  EXPECT_EQ(params.perm[2], 3);
  EXPECT_EQ(params.perm[3], 2);
}
// One size-1 dim removed from 4D: the surviving perm is renumbered densely.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_4DShrinkingOnce) {
  RuntimeShape input_shape({9, 3, 1, 4});
  RuntimeShape output_shape({3, 9, 4, 1});
  TransposeParams params;
  params.perm_count = 4;
  params.perm[0] = 1;
  params.perm[1] = 0;
  params.perm[2] = 3;
  params.perm[3] = 2;

  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({9, 3, 4}));
  EXPECT_EQ(output_shape, RuntimeShape({3, 9, 4}));
  EXPECT_EQ(params.perm_count, 3);
  EXPECT_EQ(params.perm[0], 1);
  EXPECT_EQ(params.perm[1], 0);
  EXPECT_EQ(params.perm[2], 2);
}
// Two size-1 dims removed: what remains is an identity over {3, 4}.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_4DShrinkingTwice) {
  RuntimeShape input_shape({1, 3, 1, 4});
  RuntimeShape output_shape({3, 1, 4, 1});
  TransposeParams params;
  params.perm_count = 4;
  params.perm[0] = 1;
  params.perm[1] = 2;
  params.perm[2] = 3;
  params.perm[3] = 0;

  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({3, 4}));
  EXPECT_EQ(output_shape, RuntimeShape({3, 4}));
  EXPECT_EQ(params.perm_count, 2);
  EXPECT_EQ(params.perm[0], 0);
  EXPECT_EQ(params.perm[1], 1);
}
// Three size-1 dims removed: only {7} survives, as a 1D identity.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_4DShrinkingThirdTimes) {
  RuntimeShape input_shape({1, 1, 7, 1});
  RuntimeShape output_shape({1, 7, 1, 1});
  TransposeParams params;
  params.perm_count = 4;
  params.perm[0] = 0;
  params.perm[1] = 2;
  params.perm[2] = 1;
  params.perm[3] = 3;

  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({7}));
  EXPECT_EQ(output_shape, RuntimeShape({7}));
  EXPECT_EQ(params.perm_count, 1);
  EXPECT_EQ(params.perm[0], 0);
}
// All-ones 4D shape collapses to the canonical rank-1 identity.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_4DAllOnes) {
  RuntimeShape input_shape({1, 1, 1, 1});
  RuntimeShape output_shape({1, 1, 1, 1});
  TransposeParams params;
  params.perm_count = 4;
  params.perm[0] = 0;
  params.perm[1] = 2;
  params.perm[2] = 1;
  params.perm[3] = 3;

  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({1}));
  EXPECT_EQ(output_shape, RuntimeShape({1}));
  EXPECT_EQ(params.perm_count, 1);
  EXPECT_EQ(params.perm[0], 0);
}
// perm (0, 2, 1): dim 0 is an identity prefix, so only a 5x7 slice remains.
TEST(TransposeUtilsTest, Flatten3D) {
  RuntimeShape input_shape({3, 5, 7});
  RuntimeShape output_shape({3, 7, 5});
  TransposeParams params;
  params.perm_count = 3;
  params.perm[0] = 0;
  params.perm[1] = 2;
  params.perm[2] = 1;

  RuntimeShape non_flatten_input_shape;
  RuntimeShape non_flatten_output_shape;
  TransposeParams non_flatten_params;
  size_t non_flatten_size = transpose_utils::Flatten(
      input_shape, output_shape, params, &non_flatten_input_shape,
      &non_flatten_output_shape, &non_flatten_params);

  EXPECT_EQ(non_flatten_input_shape, RuntimeShape({5, 7}));
  EXPECT_EQ(non_flatten_output_shape, RuntimeShape({7, 5}));
  EXPECT_EQ(non_flatten_size, 5 * 7);
  EXPECT_EQ(non_flatten_params.perm_count, 2);
  EXPECT_EQ(non_flatten_params.perm[0], 1);
  EXPECT_EQ(non_flatten_params.perm[1], 0);
}
// Only the leading identity dim (perm[0]==0) flattens; a 3D slice remains.
TEST(TransposeUtilsTest, Flatten4DFlattenOnce) {
  RuntimeShape input_shape({3, 5, 7, 9});
  RuntimeShape output_shape({3, 7, 5, 9});
  TransposeParams params;
  params.perm_count = 4;
  params.perm[0] = 0;
  params.perm[1] = 2;
  params.perm[2] = 1;
  params.perm[3] = 3;

  RuntimeShape non_flatten_input_shape;
  RuntimeShape non_flatten_output_shape;
  TransposeParams non_flatten_params;
  size_t non_flatten_size = transpose_utils::Flatten(
      input_shape, output_shape, params, &non_flatten_input_shape,
      &non_flatten_output_shape, &non_flatten_params);

  EXPECT_EQ(non_flatten_input_shape, RuntimeShape({5, 7, 9}));
  EXPECT_EQ(non_flatten_output_shape, RuntimeShape({7, 5, 9}));
  EXPECT_EQ(non_flatten_size, 5 * 7 * 9);
  EXPECT_EQ(non_flatten_params.perm_count, 3);
  EXPECT_EQ(non_flatten_params.perm[0], 1);
  EXPECT_EQ(non_flatten_params.perm[1], 0);
  EXPECT_EQ(non_flatten_params.perm[2], 2);
}
// An identity prefix of length two (perm starts 0, 1) leaves a 2D swap slice.
TEST(TransposeUtilsTest, Flatten4DFlattenTwice) {
  RuntimeShape input_shape({3, 5, 7, 9});
  RuntimeShape output_shape({3, 5, 9, 7});
  TransposeParams params;
  params.perm_count = 4;
  params.perm[0] = 0;
  params.perm[1] = 1;
  params.perm[2] = 3;
  params.perm[3] = 2;

  RuntimeShape non_flatten_input_shape;
  RuntimeShape non_flatten_output_shape;
  TransposeParams non_flatten_params;
  size_t non_flatten_size = transpose_utils::Flatten(
      input_shape, output_shape, params, &non_flatten_input_shape,
      &non_flatten_output_shape, &non_flatten_params);

  EXPECT_EQ(non_flatten_input_shape, RuntimeShape({7, 9}));
  EXPECT_EQ(non_flatten_output_shape, RuntimeShape({9, 7}));
  EXPECT_EQ(non_flatten_size, 7 * 9);
  EXPECT_EQ(non_flatten_params.perm_count, 2);
  EXPECT_EQ(non_flatten_params.perm[0], 1);
  EXPECT_EQ(non_flatten_params.perm[1], 0);
}
// A rank-2 input is always lowerable; dims are returned verbatim.
TEST(TransposeUtilsTest, IsTranspose2DApplicable2D) {
  RuntimeShape input_shape({4, 5});
  TransposeParams params;
  params.perm_count = 2;
  params.perm[0] = 1;
  params.perm[1] = 0;

  int dim0, dim1;
  bool applicable = transpose_utils::IsTranspose2DApplicable(
      params, input_shape, &dim0, &dim1);

  EXPECT_TRUE(applicable);
  EXPECT_EQ(dim0, 4);
  EXPECT_EQ(dim1, 5);
}
// Rotation (1, 2, 0): group split before index 1 -> dims 4 and 5*6=30.
TEST(TransposeUtilsTest, IsTranspose2DApplicable3DOne) {
  RuntimeShape input_shape({4, 5, 6});
  TransposeParams params;
  params.perm_count = 3;
  params.perm[0] = 1;
  params.perm[1] = 2;
  params.perm[2] = 0;

  int dim0, dim1;
  bool applicable = transpose_utils::IsTranspose2DApplicable(
      params, input_shape, &dim0, &dim1);

  EXPECT_TRUE(applicable);
  EXPECT_EQ(dim0, 4);
  EXPECT_EQ(dim1, 30);
}
// Rotation (2, 0, 1): group split before index 2 -> dims 4*5=20 and 6.
TEST(TransposeUtilsTest, IsTranspose2DApplicable3DTwo) {
  RuntimeShape input_shape({4, 5, 6});
  TransposeParams params;
  params.perm_count = 3;
  params.perm[0] = 2;
  params.perm[1] = 0;
  params.perm[2] = 1;

  int dim0, dim1;
  bool applicable = transpose_utils::IsTranspose2DApplicable(
      params, input_shape, &dim0, &dim1);

  EXPECT_TRUE(applicable);
  EXPECT_EQ(dim0, 20);
  EXPECT_EQ(dim1, 6);
}
// A full reversal (2, 1, 0) is not a rotation, so 2D lowering must fail.
TEST(TransposeUtilsTest, IsTranspose2DApplicable3DNotApplicable) {
  RuntimeShape input_shape({4, 5, 6});
  TransposeParams params;
  params.perm_count = 3;
  params.perm[0] = 2;
  params.perm[1] = 1;
  params.perm[2] = 0;

  int dim0, dim1;
  bool applicable = transpose_utils::IsTranspose2DApplicable(
      params, input_shape, &dim0, &dim1);

  EXPECT_FALSE(applicable);
}
// 4D rotation (1, 2, 3, 0): split before index 1 -> dims 4 and 5*6*7=210.
TEST(TransposeUtilsTest, IsTranspose2DApplicable4DOne) {
  RuntimeShape input_shape({4, 5, 6, 7});
  TransposeParams params;
  params.perm_count = 4;
  params.perm[0] = 1;
  params.perm[1] = 2;
  params.perm[2] = 3;
  params.perm[3] = 0;

  int dim0, dim1;
  bool applicable = transpose_utils::IsTranspose2DApplicable(
      params, input_shape, &dim0, &dim1);

  EXPECT_TRUE(applicable);
  EXPECT_EQ(dim0, 4);
  EXPECT_EQ(dim1, 210);
}
// 4D rotation (2, 3, 0, 1): split before index 2 -> dims 4*5=20 and 6*7=42.
TEST(TransposeUtilsTest, IsTranspose2DApplicable4DTwo) {
  RuntimeShape input_shape({4, 5, 6, 7});
  TransposeParams params;
  params.perm_count = 4;
  params.perm[0] = 2;
  params.perm[1] = 3;
  params.perm[2] = 0;
  params.perm[3] = 1;

  int dim0, dim1;
  bool applicable = transpose_utils::IsTranspose2DApplicable(
      params, input_shape, &dim0, &dim1);

  EXPECT_TRUE(applicable);
  EXPECT_EQ(dim0, 20);
  EXPECT_EQ(dim1, 42);
}
// 4D rotation (3, 0, 1, 2): split before index 3 -> dims 4*5*6=120 and 7.
TEST(TransposeUtilsTest, IsTranspose2DApplicable4DThird) {
  RuntimeShape input_shape({4, 5, 6, 7});
  TransposeParams params;
  params.perm_count = 4;
  params.perm[0] = 3;
  params.perm[1] = 0;
  params.perm[2] = 1;
  params.perm[3] = 2;

  int dim0, dim1;
  bool applicable = transpose_utils::IsTranspose2DApplicable(
      params, input_shape, &dim0, &dim1);

  EXPECT_TRUE(applicable);
  EXPECT_EQ(dim0, 120);
  EXPECT_EQ(dim1, 7);
}
// Full 4D reversal (3, 2, 1, 0) is not a rotation; lowering must fail.
TEST(TransposeUtilsTest, IsTranspose2DApplicable4DNotApplicable) {
  RuntimeShape input_shape({4, 5, 6, 7});
  TransposeParams params;
  params.perm_count = 4;
  params.perm[0] = 3;
  params.perm[1] = 2;
  params.perm[2] = 1;
  params.perm[3] = 0;

  int dim0, dim1;
  bool applicable = transpose_utils::IsTranspose2DApplicable(
      params, input_shape, &dim0, &dim1);

  EXPECT_FALSE(applicable);
}
} // namespace
} // namespace tflite
// Standard gtest entry point: parse gtest flags, then run all registered
// tests in this binary.
int main(int argc, char** argv) {
  // On Linux, add: tflite::LogToStderr();
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}

View File

@ -100,18 +100,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const int size = op_context.perm->dims->data[0];
TransposeParams params;
params.perm_count = size;
bool identical = true;
for (int i = 0; i < size; ++i) {
params.perm[i] = perm_data[i];
if (perm_data[i] != i) identical = false;
}
// TODO(b/140779653): Add an optimization pass in the conversion process to
// remove transpose op nodes where they do nothing like the below one.
if (identical) {
memcpy(op_context.output->data.raw, op_context.input->data.raw,
op_context.output->bytes);
return kTfLiteOk;
}
#define TF_LITE_TRANSPOSE(type, scalar) \
@ -120,28 +110,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTensorShape(op_context.output), \
GetTensorData<scalar>(op_context.output))
// Transpose kernel only does rearranging values not numeric evaluations on
// each cell. It's safe to implement per size of scalar type and this trick
// keeps the total code size in a reasonable range.
switch (op_context.input->type) {
case kTfLiteFloat32:
if (kernel_type == kGenericOptimized) {
TF_LITE_TRANSPOSE(optimized_ops, float);
} else {
TF_LITE_TRANSPOSE(reference_ops, float);
}
break;
case kTfLiteUInt8:
if (kernel_type == kGenericOptimized) {
TF_LITE_TRANSPOSE(optimized_ops, uint8_t);
} else {
TF_LITE_TRANSPOSE(reference_ops, uint8_t);
}
break;
case kTfLiteInt8:
if (kernel_type == kGenericOptimized) {
TF_LITE_TRANSPOSE(optimized_ops, int8_t);
} else {
TF_LITE_TRANSPOSE(reference_ops, int8_t);
}
break;
case kTfLiteInt32:
if (kernel_type == kGenericOptimized) {
TF_LITE_TRANSPOSE(optimized_ops, int32_t);
@ -149,16 +122,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_TRANSPOSE(reference_ops, int32_t);
}
break;
case kTfLiteInt64:
case kTfLiteUInt8:
case kTfLiteInt8:
if (kernel_type == kGenericOptimized) {
TF_LITE_TRANSPOSE(optimized_ops, int64_t);
TF_LITE_TRANSPOSE(optimized_ops, int8_t);
} else {
TF_LITE_TRANSPOSE(reference_ops, int64_t);
TF_LITE_TRANSPOSE(reference_ops, int8_t);
}
break;
case kTfLiteInt64:
TF_LITE_TRANSPOSE(reference_ops, int64_t);
break;
case kTfLiteBool:
if (kernel_type == kGenericOptimized) {
TF_LITE_TRANSPOSE(optimized_ops, bool);
if (sizeof(bool) == 1) {
if (kernel_type == kGenericOptimized) {
TF_LITE_TRANSPOSE(optimized_ops, int8_t);
} else {
TF_LITE_TRANSPOSE(reference_ops, int8_t);
}
} else {
TF_LITE_TRANSPOSE(reference_ops, bool);
}

View File

@ -346,6 +346,167 @@ TEST(TransposeTest, Test3DInputDynamicTensor) {
2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23}));
}
// A 1-D tensor admits only the identity permutation {0}: the output shape
// and the single element must pass through unchanged (no size-1 dimension
// reduction is possible on rank-1 input).
TEST(TransposeTest, Test1DNotShrinked) {
  TransposeOpConstModel m({1}, {1}, {0});
  m.SetInput({0});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({0}));
}
// Shape {2, 1} with permutation {1, 0}: the size-1 axis means the transpose
// is a pure reshape, so the flat data order {0, 1} must be unchanged while
// the shape becomes {1, 2}. Exercises removal of one size-1 dimension.
TEST(TransposeTest, Test2DShrinkedOneTime) {
  TransposeOpConstModel m({2, 1}, {2}, {1, 0});
  m.SetInput({0, 1});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 1}));
}
// Shape {1, 1}: both dimensions are size 1, so any permutation is a no-op on
// both shape and data. Exercises removal of two size-1 dimensions at once.
TEST(TransposeTest, Test2DShrinkedTwoTimes) {
  TransposeOpConstModel m({1, 1}, {2}, {1, 0});
  m.SetInput({0});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1}));
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({0}));
}
// Shape {2, 1, 3} with permutation {0, 2, 1}: swapping the size-1 middle
// axis with the last axis is a pure reshape, so the flat data must be
// unchanged and only the shape changes to {2, 3, 1}.
TEST(TransposeTest, Test3DShrinkedOneTime) {
  TransposeOpConstModel m({2, 1, 3}, {3}, {0, 2, 1});
  m.SetInput({0, 1, 2, 3, 4, 5});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 1}));
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 1, 2, 3, 4, 5}));
}
// Shape {1, 1, 3} with permutation {1, 2, 0}: two size-1 axes can be
// eliminated, reducing the problem to copying the 3-element run; the flat
// data stays {0, 1, 2} with shape {1, 3, 1}.
TEST(TransposeTest, Test3DShrinkedTwoTimes) {
  TransposeOpConstModel m({1, 1, 3}, {3}, {1, 2, 0});
  m.SetInput({0, 1, 2});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 1, 2}));
}
// Degenerate case: every dimension is size 1, so the transpose reduces to a
// single-element copy regardless of the permutation.
TEST(TransposeTest, Test3DShrinkedAll) {
  TransposeOpConstModel m({1, 1, 1}, {3}, {1, 2, 0});
  m.SetInput({0});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1}));
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({0}));
}
// Shape {2, 2, 3, 1} with permutation {3, 0, 1, 2}: moving the trailing
// size-1 axis to the front is a pure reshape, so the 12 flat values must
// come out unchanged with shape {1, 2, 2, 3}.
//
// Fix: the tensor holds 2*2*3*1 = 12 elements, but the original test passed
// 16 input values, overrunning the tensor buffer; the expected output below
// (12 values, 0..11) shows 12 was intended.
TEST(TransposeTest, Test4DShrinkedOneTimes) {
  TransposeOpConstModel m({2, 2, 3, 1}, {4}, {3, 0, 1, 2});
  m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 2, 3}));
  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}));
}
// Shape {2, 1, 3, 1} with permutation {0, 3, 1, 2}: only size-1 axes move,
// so the transpose is a reshape — flat data {0..5} unchanged, shape becomes
// {2, 1, 1, 3}. Exercises removal of two size-1 dimensions in 4-D.
TEST(TransposeTest, Test4DShrinkedTwoTimes) {
  TransposeOpConstModel m({2, 1, 3, 1}, {4}, {0, 3, 1, 2});
  m.SetInput({0, 1, 2, 3, 4, 5});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 1, 3}));
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 1, 2, 3, 4, 5}));
}
// Shape {2, 1, 1, 1} fully reversed by {3, 2, 1, 0}: three size-1 axes can
// be eliminated, so the operation degenerates to copying the 2-element run.
TEST(TransposeTest, Test4DShrinkedThirdTimes) {
  TransposeOpConstModel m({2, 1, 1, 1}, {4}, {3, 2, 1, 0});
  m.SetInput({0, 1});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 2}));
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 1}));
}
// Degenerate 4-D case: all four dimensions are size 1, so any permutation
// is a single-element copy.
TEST(TransposeTest, Test4DShrinkedFourTimes) {
  TransposeOpConstModel m({1, 1, 1, 1}, {4}, {2, 3, 1, 0});
  m.SetInput({0});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 1}));
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({0}));
}
// Shape {2, 2, 3} with permutation {0, 2, 1}: the leading axis is fixed, so
// the op is a batch of two independent 2x3 -> 3x2 matrix transposes
// ([[0,1,2],[3,4,5]] -> 0,3,1,4,2,5 and likewise for 6..11).
TEST(TransposeTest, Test3DFlatten) {
  TransposeOpConstModel m({2, 2, 3}, {3}, {0, 2, 1});
  m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 2}));
  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray({0, 3, 1, 4, 2, 5, 6, 9, 7, 10, 8, 11}));
}
// Shape {2, 2, 2, 2} with permutation {0, 1, 3, 2}: the leading two axes are
// fixed, so the op flattens to four independent 2x2 transposes
// ([[0,1],[2,3]] -> 0,2,1,3, etc.).
TEST(TransposeTest, Test4DFlattenOne) {
  TransposeOpConstModel m({2, 2, 2, 2}, {4}, {0, 1, 3, 2});
  m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2, 2}));
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2, 1, 3, 4, 6, 5, 7, 8, 10, 9,
                                               11, 12, 14, 13, 15}));
}
// Shape {2, 2, 2, 2} with permutation {0, 2, 3, 1}: only the leading axis is
// fixed; the remaining three axes are rotated, so each batch of 8 values is
// transposed as a 2x4 block (0..7 -> 0,4,1,5,2,6,3,7).
TEST(TransposeTest, Test4DFlattenTwo) {
  TransposeOpConstModel m({2, 2, 2, 2}, {4}, {0, 2, 3, 1});
  m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2, 2}));
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 4, 1, 5, 2, 6, 3, 7, 8, 12, 9,
                                               13, 10, 14, 11, 15}));
}
// Permutation {1, 2, 0} on shape {2, 3, 4}: the kernel's output for this
// 3-D case must match the reference produced by RunTestPermutation.
TEST(TransposeTest, 3DDividedIntoTwo2DsOne) {
  std::vector<float> expected;
  RunTestPermutation({2, 3, 4}, {1, 2, 0}, &expected);
  TransposeOpConstModel m({2, 3, 4}, {3}, {1, 2, 0});
  m.SetInput({0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
              12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
  m.Invoke();
  EXPECT_EQ(m.GetOutput(), expected);
}
// Permutation {2, 0, 1} on shape {2, 3, 4}: the kernel's output for this
// 3-D case must match the reference produced by RunTestPermutation.
TEST(TransposeTest, 3DDividedIntoTwo2DsTwo) {
  std::vector<float> expected;
  RunTestPermutation({2, 3, 4}, {2, 0, 1}, &expected);
  TransposeOpConstModel m({2, 3, 4}, {3}, {2, 0, 1});
  m.SetInput({0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
              12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
  m.Invoke();
  EXPECT_EQ(m.GetOutput(), expected);
}
// Permutation {1, 2, 3, 0} on shape {2, 3, 4, 2}: the kernel's output for
// this 4-D case must match the reference produced by RunTestPermutation.
TEST(TransposeTest, 4DDividedIntoTwo2DsOne) {
  std::vector<float> expected;
  RunTestPermutation({2, 3, 4, 2}, {1, 2, 3, 0}, &expected);
  TransposeOpConstModel m({2, 3, 4, 2}, {4}, {1, 2, 3, 0});
  m.SetInput({0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
              12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
              24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
              36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47});
  m.Invoke();
  EXPECT_EQ(m.GetOutput(), expected);
}
// Permutation {2, 3, 0, 1} on shape {2, 3, 4, 2}: the kernel's output for
// this 4-D case must match the reference produced by RunTestPermutation.
TEST(TransposeTest, 4DDividedIntoTwo2DsTwo) {
  std::vector<float> expected;
  RunTestPermutation({2, 3, 4, 2}, {2, 3, 0, 1}, &expected);
  TransposeOpConstModel m({2, 3, 4, 2}, {4}, {2, 3, 0, 1});
  m.SetInput({0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
              12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
              24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
              36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47});
  m.Invoke();
  EXPECT_EQ(m.GetOutput(), expected);
}
// Permutation {3, 0, 1, 2} on shape {2, 3, 4, 2}: the kernel's output for
// this 4-D case must match the reference produced by RunTestPermutation.
TEST(TransposeTest, 4DDividedIntoTwo2DsThird) {
  std::vector<float> expected;
  RunTestPermutation({2, 3, 4, 2}, {3, 0, 1, 2}, &expected);
  TransposeOpConstModel m({2, 3, 4, 2}, {4}, {3, 0, 1, 2});
  m.SetInput({0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
              12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
              24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
              36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47});
  m.Invoke();
  EXPECT_EQ(m.GetOutput(), expected);
}
#ifdef GTEST_HAS_DEATH_TEST
TEST(TransposeTest, Test5DInputTensor) {
EXPECT_DEATH(TransposeOpConstModel({1, 2, 3, 4, 5}, {5}, {0, 1, 2, 3, 4}),