Implement general dimension reduction optimizations
PiperOrigin-RevId: 276480378 Change-Id: Iaf8f9228c51cf4f39c95c7f66d02abea2ad2b9fa
This commit is contained in:
parent
8598d2c233
commit
37248520e4
tensorflow/lite/kernels
@ -229,6 +229,7 @@ cc_library(
|
||||
":round",
|
||||
":tensor",
|
||||
":tensor_utils",
|
||||
":transpose_utils",
|
||||
"//third_party/eigen3",
|
||||
"@gemmlowp//:fixedpoint",
|
||||
"@gemmlowp//:profiler",
|
||||
@ -271,6 +272,7 @@ cc_library(
|
||||
":strided_slice_logic",
|
||||
":tensor",
|
||||
":tensor_utils",
|
||||
":transpose_utils",
|
||||
":types",
|
||||
":legacy_types",
|
||||
":legacy_reference_base",
|
||||
@ -349,6 +351,28 @@ cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "transpose_utils",
|
||||
srcs = [
|
||||
"transpose_utils.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"transpose_utils.h",
|
||||
],
|
||||
deps = [
|
||||
":types",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "transpose_utils_test",
|
||||
srcs = ["transpose_utils_test.cc"],
|
||||
deps = [
|
||||
":transpose_utils",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "strided_slice_logic",
|
||||
srcs = [],
|
||||
|
@ -51,6 +51,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
|
||||
#include "tensorflow/lite/kernels/internal/transpose_utils.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -7095,15 +7096,10 @@ inline void Logistic16bitPercision(const LogisticParams& params,
|
||||
// Perform transpose by transposing 4x4 blocks of the input, proceeding from
|
||||
// left to right (down the rows) of the input, and then from top to bottom.
|
||||
template <typename T>
|
||||
inline void Transpose2DImpl(const TransposeParams& params,
|
||||
const RuntimeShape& input_shape,
|
||||
const T* input_data,
|
||||
const RuntimeShape& output_shape, T* output_data) {
|
||||
inline void Transpose2D(const RuntimeShape& input_shape, const T* input_data,
|
||||
const RuntimeShape& output_shape, T* output_data) {
|
||||
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 2);
|
||||
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2);
|
||||
TFLITE_DCHECK_EQ(params.perm_count, 2);
|
||||
TFLITE_DCHECK_EQ(params.perm[0], 1);
|
||||
TFLITE_DCHECK_EQ(params.perm[1], 0);
|
||||
|
||||
const int d0 = input_shape.DimsData()[0];
|
||||
const int d1 = input_shape.DimsData()[1];
|
||||
@ -7196,16 +7192,12 @@ inline void Transpose2DImpl(const TransposeParams& params,
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void Transpose2DImpl(const TransposeParams& params,
|
||||
const RuntimeShape& input_shape,
|
||||
const int32_t* input_data,
|
||||
const RuntimeShape& output_shape,
|
||||
int32_t* output_data) {
|
||||
inline void Transpose2D(const RuntimeShape& input_shape,
|
||||
const int32_t* input_data,
|
||||
const RuntimeShape& output_shape,
|
||||
int32_t* output_data) {
|
||||
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 2);
|
||||
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2);
|
||||
TFLITE_DCHECK_EQ(params.perm_count, 2);
|
||||
TFLITE_DCHECK_EQ(params.perm[0], 1);
|
||||
TFLITE_DCHECK_EQ(params.perm[1], 0);
|
||||
|
||||
const int d0 = input_shape.DimsData()[0];
|
||||
const int d1 = input_shape.DimsData()[1];
|
||||
@ -7278,93 +7270,17 @@ inline void Transpose2DImpl(const TransposeParams& params,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Transpose2D(const TransposeParams& params, const RuntimeShape& input_shape,
|
||||
const T* input_data, const RuntimeShape& output_shape,
|
||||
T* output_data) {
|
||||
// Transpose kernel only does rearranging values not numeric evaluations on
|
||||
// each cell. It's safe to implement per size of scalar type and this trick
|
||||
// keeps the total code size in a reasonable range.
|
||||
switch (sizeof(T)) {
|
||||
case 1:
|
||||
Transpose2DImpl(params, input_shape,
|
||||
reinterpret_cast<const int8_t*>(input_data), output_shape,
|
||||
reinterpret_cast<int8_t*>(output_data));
|
||||
break;
|
||||
case 4:
|
||||
Transpose2DImpl(params, input_shape,
|
||||
reinterpret_cast<const int32_t*>(input_data),
|
||||
output_shape, reinterpret_cast<int32_t*>(output_data));
|
||||
break;
|
||||
default:
|
||||
// Reroute to the reference version if an optimized method for the given
|
||||
// data is not available.
|
||||
reference_ops::Transpose(params, input_shape, input_data, output_shape,
|
||||
output_data);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(alanchiao): see if we can reduce the number
|
||||
// of lines of code in branching without affecting latency.
|
||||
template <typename T>
|
||||
inline void Transpose3DImpl(const TransposeParams& params,
|
||||
const RuntimeShape& input_shape,
|
||||
const T* input_data,
|
||||
const RuntimeShape& output_shape, T* output_data) {
|
||||
inline void Transpose3D(const TransposeParams& params,
|
||||
const RuntimeShape& input_shape, const T* input_data,
|
||||
const RuntimeShape& output_shape, T* output_data) {
|
||||
int s1, s2, s3;
|
||||
s1 = input_shape.Dims(0);
|
||||
s2 = input_shape.Dims(1);
|
||||
s3 = input_shape.Dims(2);
|
||||
|
||||
// TODO(b/141169757): generalize the following logics and move to the
|
||||
// Transpose method.
|
||||
const bool hasOneInDimension = (s1 == 1 || s2 == 1 || s3 == 1);
|
||||
// Can fast path as 2D transpose in this case.
|
||||
if (hasOneInDimension) {
|
||||
int d1, d2;
|
||||
bool is_identity = false;
|
||||
// Check for identity to just return.
|
||||
if (s1 == 1) {
|
||||
// (0, 1, 2), (1, 0, 2), (1, 2, 0)
|
||||
if ((params.perm[0] == 0 && params.perm[1] == 1) || params.perm[0] == 1) {
|
||||
is_identity = true;
|
||||
}
|
||||
d1 = s2;
|
||||
d2 = s3;
|
||||
} else if (s2 == 1) {
|
||||
// (0, 1, 2), (0, 2, 1), (1, 0, 2)
|
||||
if ((params.perm[0] == 1 && params.perm[1] == 0) || params.perm[0] == 0) {
|
||||
is_identity = true;
|
||||
}
|
||||
d1 = s1;
|
||||
d2 = s3;
|
||||
} else {
|
||||
// (0, 1, 2), (0, 2, 1), (2, 0, 1)
|
||||
if ((params.perm[0] == 2 && params.perm[1] == 0) || params.perm[0] == 0) {
|
||||
is_identity = true;
|
||||
}
|
||||
d1 = s1;
|
||||
d2 = s2;
|
||||
}
|
||||
|
||||
if (is_identity) {
|
||||
memcpy(output_data, input_data, sizeof(T) * input_shape.FlatSize());
|
||||
return;
|
||||
}
|
||||
|
||||
TransposeParams new_params;
|
||||
new_params.perm_count = 2;
|
||||
new_params.perm[0] = 1;
|
||||
new_params.perm[1] = 0;
|
||||
|
||||
const RuntimeShape new_input_shape({d1, d2});
|
||||
const RuntimeShape new_output_shape({d2, d1});
|
||||
|
||||
Transpose2D(new_params, new_input_shape, input_data, new_output_shape,
|
||||
output_data);
|
||||
return;
|
||||
}
|
||||
|
||||
int p1, p2, p3;
|
||||
if (params.perm[0] == 2) {
|
||||
p1 = 1;
|
||||
@ -7407,44 +7323,16 @@ inline void Transpose3DImpl(const TransposeParams& params,
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Transpose3D(const TransposeParams& params, const RuntimeShape& input_shape,
|
||||
const T* input_data, const RuntimeShape& output_shape,
|
||||
T* output_data) {
|
||||
// Transpose kernel only does rearranging values not numeric evaluations on
|
||||
// each cell. It's safe to implement per size of scalar type and this trick
|
||||
// keeps the total code size in a reasonable range.
|
||||
switch (sizeof(T)) {
|
||||
case 1:
|
||||
Transpose3DImpl(params, input_shape,
|
||||
reinterpret_cast<const int8_t*>(input_data), output_shape,
|
||||
reinterpret_cast<int8_t*>(output_data));
|
||||
break;
|
||||
case 4:
|
||||
Transpose3DImpl(params, input_shape,
|
||||
reinterpret_cast<const int32_t*>(input_data),
|
||||
output_shape, reinterpret_cast<int32_t*>(output_data));
|
||||
break;
|
||||
default:
|
||||
// Reroute to the reference version if an optimized method for the given
|
||||
// data is not available.
|
||||
reference_ops::Transpose(params, input_shape, input_data, output_shape,
|
||||
output_data);
|
||||
}
|
||||
}
|
||||
void TransposeImpl(const TransposeParams& params,
|
||||
const RuntimeShape& input_shape, const T* input_data,
|
||||
const RuntimeShape& output_shape, T* output_data) {
|
||||
const int dims_cnt = input_shape.DimensionsCount();
|
||||
|
||||
template <typename T>
|
||||
void Transpose(const TransposeParams& params, const RuntimeShape& input_shape,
|
||||
const T* input_data, const RuntimeShape& output_shape,
|
||||
T* output_data) {
|
||||
const int output_size = output_shape.DimensionsCount();
|
||||
TFLITE_DCHECK_LE(input_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(output_size, 4);
|
||||
TFLITE_DCHECK_EQ(output_size, params.perm_count);
|
||||
|
||||
// Apply 2-D transpose.
|
||||
if (input_shape.DimensionsCount() == 2 && params.perm[0] == 1 &&
|
||||
params.perm[1] == 0) {
|
||||
Transpose2D(params, input_shape, input_data, output_shape, output_data);
|
||||
int dim0, dim1;
|
||||
if (transpose_utils::IsTranspose2DApplicable(params, input_shape, &dim0,
|
||||
&dim1)) {
|
||||
Transpose2D(RuntimeShape({dim0, dim1}), input_data,
|
||||
RuntimeShape({dim1, dim0}), output_data);
|
||||
return;
|
||||
}
|
||||
|
||||
@ -7458,7 +7346,7 @@ void Transpose(const TransposeParams& params, const RuntimeShape& input_shape,
|
||||
// 96^3 is not mobile-friendly for certain usecases
|
||||
// (e.g. model used in beam search for seq2seq) but is in others.
|
||||
// Consider tradeoffs.
|
||||
if (input_shape.DimensionsCount() == 3) {
|
||||
if (dims_cnt == 3) {
|
||||
Transpose3D(params, input_shape, input_data, output_shape, output_data);
|
||||
return;
|
||||
}
|
||||
@ -7469,6 +7357,66 @@ void Transpose(const TransposeParams& params, const RuntimeShape& input_shape,
|
||||
output_data);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Transpose(const TransposeParams& unshrinked_params,
|
||||
const RuntimeShape& unshrinked_input_shape, const T* input_data,
|
||||
const RuntimeShape& unshrinked_output_shape, T* output_data) {
|
||||
gemmlowp::ScopedProfilingLabel label("Transpose");
|
||||
|
||||
const int output_size = unshrinked_output_shape.DimensionsCount();
|
||||
TFLITE_DCHECK_LE(unshrinked_input_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(output_size, 4);
|
||||
TFLITE_DCHECK_EQ(output_size, unshrinked_params.perm_count);
|
||||
|
||||
RuntimeShape shrinked_input_shape = RuntimeShape(unshrinked_input_shape);
|
||||
RuntimeShape shrinked_output_shape = RuntimeShape(unshrinked_output_shape);
|
||||
TransposeParams shrinked_params = unshrinked_params;
|
||||
|
||||
// Reduce any dimensions that have one size. Lower transpose op usually
|
||||
// performs better since memory access patterns will be improved.
|
||||
transpose_utils::RemoveOneSizeDimensions(
|
||||
&shrinked_input_shape, &shrinked_output_shape, &shrinked_params);
|
||||
|
||||
// Handle identity cases.
|
||||
// TODO(b/140779653): Add an optimization pass in the conversion process to
|
||||
// remove transpose op nodes where they do nothing like the below one.
|
||||
bool identical = true;
|
||||
for (int i = 0; i < shrinked_params.perm_count; ++i) {
|
||||
if (shrinked_params.perm[i] != i) {
|
||||
identical = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (identical) {
|
||||
memcpy(output_data, input_data,
|
||||
unshrinked_input_shape.FlatSize() * sizeof(T));
|
||||
return;
|
||||
}
|
||||
|
||||
// Reduce dimensions by flattening.
|
||||
if (shrinked_params.perm[0] == 0 && output_size >= 3) {
|
||||
RuntimeShape non_flatten_input_shape;
|
||||
RuntimeShape non_flatten_output_shape;
|
||||
TransposeParams non_flatten_params;
|
||||
const int total_size = shrinked_input_shape.FlatSize();
|
||||
const int non_flatten_size = transpose_utils::Flatten(
|
||||
shrinked_input_shape, shrinked_output_shape, shrinked_params,
|
||||
&non_flatten_input_shape, &non_flatten_output_shape,
|
||||
&non_flatten_params);
|
||||
TFLITE_DCHECK_NE(non_flatten_params.perm[0], 0);
|
||||
|
||||
for (int i = 0; i < total_size; i += non_flatten_size) {
|
||||
TransposeImpl(non_flatten_params, non_flatten_input_shape, input_data + i,
|
||||
non_flatten_output_shape, output_data + i);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Call non-flattened case.
|
||||
TransposeImpl(shrinked_params, shrinked_input_shape, input_data,
|
||||
shrinked_output_shape, output_data);
|
||||
}
|
||||
|
||||
} // namespace optimized_ops
|
||||
} // namespace tflite
|
||||
|
||||
|
165
tensorflow/lite/kernels/internal/transpose_utils.cc
Normal file
165
tensorflow/lite/kernels/internal/transpose_utils.cc
Normal file
@ -0,0 +1,165 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/kernels/internal/transpose_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace transpose_utils {
|
||||
|
||||
bool IsTranspose2DApplicable(const TransposeParams& params,
|
||||
const RuntimeShape& input_shape, int* dim0,
|
||||
int* dim1) {
|
||||
const int dims_cnt = input_shape.DimensionsCount();
|
||||
|
||||
if (dims_cnt == 2) {
|
||||
*dim0 = input_shape.Dims(0);
|
||||
*dim1 = input_shape.Dims(1);
|
||||
return true;
|
||||
}
|
||||
|
||||
const int first_perm = params.perm[0];
|
||||
for (int i = 1; i < dims_cnt; ++i) {
|
||||
int rebased = params.perm[i] - first_perm;
|
||||
if (rebased < 0) {
|
||||
rebased += dims_cnt;
|
||||
}
|
||||
if (rebased != i) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
*dim0 = 1;
|
||||
*dim1 = 1;
|
||||
for (int i = 0; i < dims_cnt; ++i) {
|
||||
if (i < first_perm) {
|
||||
*dim0 *= input_shape.Dims(i);
|
||||
} else {
|
||||
*dim1 *= input_shape.Dims(i);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Drops every size-one dimension from `input_shape` and `output_shape` in
// place and rewrites `params` so that its perm refers to the surviving
// dimensions, renumbered 0..k-1 in their original relative order.
//
// Nothing is modified when the input has no size-one dimension. When the input
// holds exactly one element, both shapes become {1} and the perm becomes {0}.
void RemoveOneSizeDimensions(RuntimeShape* input_shape,
                             RuntimeShape* output_shape,
                             TransposeParams* params) {
  const int rank = input_shape->DimensionsCount();
  TFLITE_DCHECK_EQ(params->perm_count, rank);

  // Look for the first size-one dimension; bail out if there is none.
  int first_unit = 0;
  while (first_unit < rank && input_shape->Dims(first_unit) != 1) {
    ++first_unit;
  }
  if (first_unit == rank) return;

  // Single-element tensor: collapse everything to a 1D shape of size one.
  if (input_shape->FlatSize() == 1) {
    input_shape->Resize(1);
    input_shape->SetDim(0, 1);
    output_shape->Resize(1);
    output_shape->SetDim(0, 1);
    params->perm_count = 1;
    params->perm[0] = 0;
    return;
  }

  // Compact the input shape, keeping only dimensions larger than one.
  int kept = 0;
  for (int d = 0; d < rank; ++d) {
    if (input_shape->Dims(d) != 1) {
      input_shape->SetDim(kept, input_shape->Dims(d));
      ++kept;
    }
  }
  input_shape->Resize(kept);

  // Compact the output shape the same way, carrying along the perm entry of
  // every surviving output dimension.
  TransposeParams compacted;
  kept = 0;
  for (int d = 0; d < rank; ++d) {
    if (output_shape->Dims(d) != 1) {
      compacted.perm[kept] = params->perm[d];
      output_shape->SetDim(kept, output_shape->Dims(d));
      ++kept;
    }
  }
  output_shape->Resize(kept);
  compacted.perm_count = kept;

  // The surviving perm entries still use the old numbering. Renumber them to
  // 0..kept-1: on each pass the smallest entry not yet renumbered gets the
  // next index.
  for (int next = 0; next < kept; ++next) {
    int smallest = -1;
    for (int j = 0; j < kept; ++j) {
      if (compacted.perm[j] < next) continue;  // Already renumbered.
      if (smallest == -1 || compacted.perm[j] < compacted.perm[smallest]) {
        smallest = j;
      }
    }
    compacted.perm[smallest] = next;
  }
  *params = compacted;
}
|
||||
|
||||
size_t Flatten(const RuntimeShape& input_shape,
|
||||
const RuntimeShape& output_shape, const TransposeParams& params,
|
||||
RuntimeShape* non_flatten_input_shape,
|
||||
RuntimeShape* non_flatten_output_shape,
|
||||
TransposeParams* non_flatten_params) {
|
||||
// Calculate the total size of non-flatten dimensions.
|
||||
int skip_dims_cnt = 0;
|
||||
size_t flat_size = input_shape.FlatSize();
|
||||
for (int i = 0; i < params.perm_count; ++i) {
|
||||
if (params.perm[i] == i) {
|
||||
flat_size /= input_shape.Dims(i);
|
||||
++skip_dims_cnt;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Shrink the shapes and re-calculate the perm parameter.
|
||||
const int new_dims_cnt = params.perm_count - skip_dims_cnt;
|
||||
non_flatten_input_shape->Resize(new_dims_cnt);
|
||||
non_flatten_output_shape->Resize(new_dims_cnt);
|
||||
non_flatten_params->perm_count = new_dims_cnt;
|
||||
|
||||
for (int i = skip_dims_cnt; i < params.perm_count; ++i) {
|
||||
non_flatten_input_shape->SetDim(i - skip_dims_cnt, input_shape.Dims(i));
|
||||
non_flatten_output_shape->SetDim(i - skip_dims_cnt, output_shape.Dims(i));
|
||||
non_flatten_params->perm[i - skip_dims_cnt] = params.perm[i];
|
||||
}
|
||||
for (int i = 0; i < new_dims_cnt; ++i) {
|
||||
int min_val_idx = -1;
|
||||
for (int j = 0; j < new_dims_cnt; ++j) {
|
||||
if (non_flatten_params->perm[j] >= i &&
|
||||
(min_val_idx == -1 || non_flatten_params->perm[min_val_idx] >
|
||||
non_flatten_params->perm[j])) {
|
||||
min_val_idx = j;
|
||||
}
|
||||
}
|
||||
non_flatten_params->perm[min_val_idx] = i;
|
||||
}
|
||||
|
||||
return flat_size;
|
||||
}
|
||||
|
||||
} // namespace transpose_utils
|
||||
|
||||
} // namespace tflite
|
52
tensorflow/lite/kernels/internal/transpose_utils.h
Normal file
52
tensorflow/lite/kernels/internal/transpose_utils.h
Normal file
@ -0,0 +1,52 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_TRANSPOSE_UTILS_H_
|
||||
#define TENSORFLOW_LITE_KERNELS_INTERNAL_TRANSPOSE_UTILS_H_
|
||||
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace transpose_utils {
|
||||
|
||||
// IsTranspose2DApplicable returns true if the given perm can be lowered to a
|
||||
// 2D transpose op. If possible, it copies the lowered dimension counts to the
|
||||
// given dim0 and dim1 pointers.
|
||||
bool IsTranspose2DApplicable(const TransposeParams& params,
|
||||
const RuntimeShape& input_shape, int* dim0,
|
||||
int* dim1);
|
||||
|
||||
// RemoveOneSizeDimensions removes one size dimensions in the given input/output
|
||||
// shapes and adjusts the parameter values for transpose op.
|
||||
void RemoveOneSizeDimensions(RuntimeShape* input_shape,
|
||||
RuntimeShape* output_shape,
|
||||
TransposeParams* params);
|
||||
|
||||
// Flatten finds the dimensions that can be flattened, shrinks the given shapes
|
||||
// and the given perm parameter to reflect the non-flatten dimensions, and
|
||||
// returns the total size of the non-flatten dimensions.
|
||||
//
|
||||
// E.g., in perm [0, 1, 3, 2] case, the first two dimensions can be flattened and
|
||||
// it returns |Dim Size(2)| x |Dim Size(3)|.
|
||||
size_t Flatten(const RuntimeShape& input_shape,
|
||||
const RuntimeShape& output_shape, const TransposeParams& params,
|
||||
RuntimeShape* non_flatten_input_shape,
|
||||
RuntimeShape* non_flatten_output_shape,
|
||||
TransposeParams* non_flatten_params);
|
||||
|
||||
} // namespace transpose_utils
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_TRANSPOSE_UTILS_H_
|
506
tensorflow/lite/kernels/internal/transpose_utils_test.cc
Normal file
506
tensorflow/lite/kernels/internal/transpose_utils_test.cc
Normal file
@ -0,0 +1,506 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/kernels/internal/transpose_utils.h"
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
// A 1D shape without a size-one dimension must come back untouched.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_1DNoChanges) {
  RuntimeShape input_shape({9});
  RuntimeShape output_shape({9});

  TransposeParams params;
  params.perm_count = 1;
  params.perm[0] = 0;

  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({9}));
  EXPECT_EQ(output_shape, RuntimeShape({9}));

  EXPECT_EQ(params.perm_count, 1);
  EXPECT_EQ(params.perm[0], 0);
}
|
||||
|
||||
// A 2D shape without a size-one dimension must come back untouched.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_2DNoChanges) {
  RuntimeShape input_shape({9, 3});
  RuntimeShape output_shape({3, 9});

  TransposeParams params;
  params.perm_count = 2;
  params.perm[0] = 1;
  params.perm[1] = 0;

  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({9, 3}));
  EXPECT_EQ(output_shape, RuntimeShape({3, 9}));

  EXPECT_EQ(params.perm_count, 2);
  EXPECT_EQ(params.perm[0], 1);
  EXPECT_EQ(params.perm[1], 0);
}
|
||||
|
||||
// Dropping the single size-one dimension of a 2D shape leaves a 1D identity.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_2DShrinking) {
  RuntimeShape input_shape({9, 1});
  RuntimeShape output_shape({1, 9});

  TransposeParams params;
  params.perm_count = 2;
  params.perm[0] = 1;
  params.perm[1] = 0;

  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({9}));
  EXPECT_EQ(output_shape, RuntimeShape({9}));

  EXPECT_EQ(params.perm_count, 1);
  EXPECT_EQ(params.perm[0], 0);
}
|
||||
|
||||
// A 3D shape without a size-one dimension must come back untouched.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_3DNoChanges) {
  RuntimeShape input_shape({4, 3, 8});
  RuntimeShape output_shape({8, 4, 3});

  TransposeParams params;
  params.perm_count = 3;
  params.perm[0] = 2;
  params.perm[1] = 0;
  params.perm[2] = 1;

  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({4, 3, 8}));
  EXPECT_EQ(output_shape, RuntimeShape({8, 4, 3}));

  EXPECT_EQ(params.perm_count, 3);
  EXPECT_EQ(params.perm[0], 2);
  EXPECT_EQ(params.perm[1], 0);
  EXPECT_EQ(params.perm[2], 1);
}
|
||||
|
||||
// One size-one dimension in a 3D shape: the result is a 2D transpose.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_3DShrinkingOnce) {
  RuntimeShape input_shape({4, 1, 8});
  RuntimeShape output_shape({8, 4, 1});

  TransposeParams params;
  params.perm_count = 3;
  params.perm[0] = 2;
  params.perm[1] = 0;
  params.perm[2] = 1;

  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({4, 8}));
  EXPECT_EQ(output_shape, RuntimeShape({8, 4}));
  EXPECT_EQ(output_shape.Dims(1), 4);

  EXPECT_EQ(params.perm_count, 2);
  EXPECT_EQ(params.perm[0], 1);
  EXPECT_EQ(params.perm[1], 0);
}
|
||||
|
||||
// Two size-one dimensions in a 3D shape: the result is a 1D identity.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_3DShrinkingTwice) {
  RuntimeShape input_shape({4, 1, 1});
  RuntimeShape output_shape({1, 4, 1});

  TransposeParams params;
  params.perm_count = 3;
  params.perm[0] = 2;
  params.perm[1] = 0;
  params.perm[2] = 1;

  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({4}));
  EXPECT_EQ(output_shape, RuntimeShape({4}));

  EXPECT_EQ(params.perm_count, 1);
  EXPECT_EQ(params.perm[0], 0);
}
|
||||
|
||||
// A 3D single-element tensor collapses to shape {1} with perm {0}.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_3DAllOnes) {
  RuntimeShape input_shape({1, 1, 1});
  RuntimeShape output_shape({1, 1, 1});

  TransposeParams params;
  params.perm_count = 3;
  params.perm[0] = 2;
  params.perm[1] = 0;
  params.perm[2] = 1;

  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({1}));
  EXPECT_EQ(output_shape, RuntimeShape({1}));

  EXPECT_EQ(params.perm_count, 1);
  EXPECT_EQ(params.perm[0], 0);
}
|
||||
|
||||
// A 4D shape without a size-one dimension must come back untouched.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_4DNoChanges) {
  RuntimeShape input_shape({9, 3, 2, 4});
  RuntimeShape output_shape({3, 9, 4, 2});

  TransposeParams params;
  params.perm_count = 4;
  params.perm[0] = 1;
  params.perm[1] = 0;
  params.perm[2] = 3;
  params.perm[3] = 2;

  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({9, 3, 2, 4}));
  EXPECT_EQ(output_shape, RuntimeShape({3, 9, 4, 2}));

  EXPECT_EQ(params.perm_count, 4);
  EXPECT_EQ(params.perm[0], 1);
  EXPECT_EQ(params.perm[1], 0);
  EXPECT_EQ(params.perm[2], 3);
  EXPECT_EQ(params.perm[3], 2);
}
|
||||
|
||||
// One size-one dimension in a 4D shape: the result is a 3D transpose.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_4DShrinkingOnce) {
  RuntimeShape input_shape({9, 3, 1, 4});
  RuntimeShape output_shape({3, 9, 4, 1});

  TransposeParams params;
  params.perm_count = 4;
  params.perm[0] = 1;
  params.perm[1] = 0;
  params.perm[2] = 3;
  params.perm[3] = 2;

  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({9, 3, 4}));
  EXPECT_EQ(output_shape, RuntimeShape({3, 9, 4}));

  EXPECT_EQ(params.perm_count, 3);
  EXPECT_EQ(params.perm[0], 1);
  EXPECT_EQ(params.perm[1], 0);
  EXPECT_EQ(params.perm[2], 2);
}
|
||||
|
||||
// Two size-one dimensions in a 4D shape: the result is a 2D identity.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_4DShrinkingTwice) {
  RuntimeShape input_shape({1, 3, 1, 4});
  RuntimeShape output_shape({3, 1, 4, 1});

  TransposeParams params;
  params.perm_count = 4;
  params.perm[0] = 1;
  params.perm[1] = 2;
  params.perm[2] = 3;
  params.perm[3] = 0;

  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({3, 4}));
  EXPECT_EQ(output_shape, RuntimeShape({3, 4}));

  EXPECT_EQ(params.perm_count, 2);
  EXPECT_EQ(params.perm[0], 0);
  EXPECT_EQ(params.perm[1], 1);
}
|
||||
|
||||
// Three size-one dimensions in a 4D shape: the result is a 1D identity.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_4DShrinkingThirdTimes) {
  RuntimeShape input_shape({1, 1, 7, 1});
  RuntimeShape output_shape({1, 7, 1, 1});

  TransposeParams params;
  params.perm_count = 4;
  params.perm[0] = 0;
  params.perm[1] = 2;
  params.perm[2] = 1;
  params.perm[3] = 3;

  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({7}));
  EXPECT_EQ(output_shape, RuntimeShape({7}));

  EXPECT_EQ(params.perm_count, 1);
  EXPECT_EQ(params.perm[0], 0);
}
|
||||
|
||||
// A 4D single-element tensor collapses to shape {1} with perm {0}.
TEST(TransposeUtilsTest, RemoveOneSizeDimensions_4DAllOnes) {
  RuntimeShape input_shape({1, 1, 1, 1});
  RuntimeShape output_shape({1, 1, 1, 1});

  TransposeParams params;
  params.perm_count = 4;
  params.perm[0] = 0;
  params.perm[1] = 2;
  params.perm[2] = 1;
  params.perm[3] = 3;

  transpose_utils::RemoveOneSizeDimensions(&input_shape, &output_shape,
                                           &params);

  EXPECT_EQ(input_shape, RuntimeShape({1}));
  EXPECT_EQ(output_shape, RuntimeShape({1}));

  EXPECT_EQ(params.perm_count, 1);
  EXPECT_EQ(params.perm[0], 0);
}
|
||||
|
||||
// perm [0, 2, 1]: the leading dimension is stripped, leaving a 2D transpose
// over the last two dimensions.
TEST(TransposeUtilsTest, Flatten3D) {
  const RuntimeShape input_shape({3, 5, 7});
  const RuntimeShape output_shape({3, 7, 5});

  TransposeParams params;
  params.perm[0] = 0;
  params.perm[1] = 2;
  params.perm[2] = 1;
  params.perm_count = 3;

  RuntimeShape inner_input_shape;
  RuntimeShape inner_output_shape;
  TransposeParams inner_params;
  const size_t inner_size = transpose_utils::Flatten(
      input_shape, output_shape, params, &inner_input_shape,
      &inner_output_shape, &inner_params);

  // Remaining shapes and element count.
  EXPECT_EQ(inner_input_shape, RuntimeShape({5, 7}));
  EXPECT_EQ(inner_output_shape, RuntimeShape({7, 5}));
  EXPECT_EQ(inner_size, 5 * 7);

  // Remaining permutation, renumbered from zero.
  EXPECT_EQ(inner_params.perm_count, 2);
  EXPECT_EQ(inner_params.perm[0], 1);
  EXPECT_EQ(inner_params.perm[1], 0);
}
|
||||
|
||||
// perm [0, 2, 1, 3]: only the first dimension is stripped, leaving a 3D
// transpose.
TEST(TransposeUtilsTest, Flatten4DFlattenOnce) {
  const RuntimeShape input_shape({3, 5, 7, 9});
  const RuntimeShape output_shape({3, 7, 5, 9});

  TransposeParams params;
  params.perm[0] = 0;
  params.perm[1] = 2;
  params.perm[2] = 1;
  params.perm[3] = 3;
  params.perm_count = 4;

  RuntimeShape inner_input_shape;
  RuntimeShape inner_output_shape;
  TransposeParams inner_params;
  const size_t inner_size = transpose_utils::Flatten(
      input_shape, output_shape, params, &inner_input_shape,
      &inner_output_shape, &inner_params);

  // Remaining shapes and element count.
  EXPECT_EQ(inner_input_shape, RuntimeShape({5, 7, 9}));
  EXPECT_EQ(inner_output_shape, RuntimeShape({7, 5, 9}));
  EXPECT_EQ(inner_size, 5 * 7 * 9);

  // Remaining permutation, renumbered from zero.
  EXPECT_EQ(inner_params.perm_count, 3);
  EXPECT_EQ(inner_params.perm[0], 1);
  EXPECT_EQ(inner_params.perm[1], 0);
  EXPECT_EQ(inner_params.perm[2], 2);
}
|
||||
|
||||
// perm [0, 1, 3, 2]: the first two dimensions are stripped, leaving a 2D
// transpose over the last two dimensions.
TEST(TransposeUtilsTest, Flatten4DFlattenTwice) {
  const RuntimeShape input_shape({3, 5, 7, 9});
  const RuntimeShape output_shape({3, 5, 9, 7});

  TransposeParams params;
  params.perm[0] = 0;
  params.perm[1] = 1;
  params.perm[2] = 3;
  params.perm[3] = 2;
  params.perm_count = 4;

  RuntimeShape inner_input_shape;
  RuntimeShape inner_output_shape;
  TransposeParams inner_params;
  const size_t inner_size = transpose_utils::Flatten(
      input_shape, output_shape, params, &inner_input_shape,
      &inner_output_shape, &inner_params);

  // Remaining shapes and element count.
  EXPECT_EQ(inner_input_shape, RuntimeShape({7, 9}));
  EXPECT_EQ(inner_output_shape, RuntimeShape({9, 7}));
  EXPECT_EQ(inner_size, 7 * 9);

  // Remaining permutation, renumbered from zero.
  EXPECT_EQ(inner_params.perm_count, 2);
  EXPECT_EQ(inner_params.perm[0], 1);
  EXPECT_EQ(inner_params.perm[1], 0);
}
|
||||
|
||||
// A 2D input is reported as applicable, with its own extents as dim0/dim1.
TEST(TransposeUtilsTest, IsTranspose2DApplicable2D) {
  const RuntimeShape input_shape({4, 5});

  TransposeParams params;
  params.perm[0] = 1;
  params.perm[1] = 0;
  params.perm_count = 2;

  int rows;
  int cols;
  const bool can_lower = transpose_utils::IsTranspose2DApplicable(
      params, input_shape, &rows, &cols);

  EXPECT_TRUE(can_lower);
  EXPECT_EQ(rows, 4);
  EXPECT_EQ(cols, 5);
}
|
||||
|
||||
// A {4, 5, 6} input with perm {1, 2, 0} must be reported as a 2-D transpose
// with dimensions 4 and 5 * 6 = 30.
TEST(TransposeUtilsTest, IsTranspose2DApplicable3DOne) {
  RuntimeShape src_shape({4, 5, 6});

  const int kPerm[] = {1, 2, 0};
  TransposeParams transpose_params;
  transpose_params.perm_count = 3;
  for (int axis = 0; axis < 3; ++axis) {
    transpose_params.perm[axis] = kPerm[axis];
  }

  int out_dim0;
  int out_dim1;
  const bool is_applicable = transpose_utils::IsTranspose2DApplicable(
      transpose_params, src_shape, &out_dim0, &out_dim1);

  EXPECT_TRUE(is_applicable);
  EXPECT_EQ(out_dim0, 4);
  EXPECT_EQ(out_dim1, 30);
}
|
||||
|
||||
// A {4, 5, 6} input with perm {2, 0, 1} must be reported as a 2-D transpose
// with dimensions 4 * 5 = 20 and 6.
TEST(TransposeUtilsTest, IsTranspose2DApplicable3DTwo) {
  RuntimeShape src_shape({4, 5, 6});

  const int kPerm[] = {2, 0, 1};
  TransposeParams transpose_params;
  transpose_params.perm_count = 3;
  for (int axis = 0; axis < 3; ++axis) {
    transpose_params.perm[axis] = kPerm[axis];
  }

  int out_dim0;
  int out_dim1;
  const bool is_applicable = transpose_utils::IsTranspose2DApplicable(
      transpose_params, src_shape, &out_dim0, &out_dim1);

  EXPECT_TRUE(is_applicable);
  EXPECT_EQ(out_dim0, 20);
  EXPECT_EQ(out_dim1, 6);
}
|
||||
|
||||
// A full axis reversal, perm {2, 1, 0}, of a {4, 5, 6} input must not be
// reported as a 2-D transpose.
TEST(TransposeUtilsTest, IsTranspose2DApplicable3DNotApplicable) {
  RuntimeShape src_shape({4, 5, 6});

  const int kPerm[] = {2, 1, 0};
  TransposeParams transpose_params;
  transpose_params.perm_count = 3;
  for (int axis = 0; axis < 3; ++axis) {
    transpose_params.perm[axis] = kPerm[axis];
  }

  // The output dimensions are not checked when the result is false.
  int out_dim0;
  int out_dim1;
  const bool is_applicable = transpose_utils::IsTranspose2DApplicable(
      transpose_params, src_shape, &out_dim0, &out_dim1);

  EXPECT_FALSE(is_applicable);
}
|
||||
|
||||
// A {4, 5, 6, 7} input with perm {1, 2, 3, 0} must be reported as a 2-D
// transpose with dimensions 4 and 5 * 6 * 7 = 210.
TEST(TransposeUtilsTest, IsTranspose2DApplicable4DOne) {
  RuntimeShape src_shape({4, 5, 6, 7});

  const int kPerm[] = {1, 2, 3, 0};
  TransposeParams transpose_params;
  transpose_params.perm_count = 4;
  for (int axis = 0; axis < 4; ++axis) {
    transpose_params.perm[axis] = kPerm[axis];
  }

  int out_dim0;
  int out_dim1;
  const bool is_applicable = transpose_utils::IsTranspose2DApplicable(
      transpose_params, src_shape, &out_dim0, &out_dim1);

  EXPECT_TRUE(is_applicable);
  EXPECT_EQ(out_dim0, 4);
  EXPECT_EQ(out_dim1, 210);
}
|
||||
|
||||
// A {4, 5, 6, 7} input with perm {2, 3, 0, 1} must be reported as a 2-D
// transpose with dimensions 4 * 5 = 20 and 6 * 7 = 42.
TEST(TransposeUtilsTest, IsTranspose2DApplicable4DTwo) {
  RuntimeShape src_shape({4, 5, 6, 7});

  const int kPerm[] = {2, 3, 0, 1};
  TransposeParams transpose_params;
  transpose_params.perm_count = 4;
  for (int axis = 0; axis < 4; ++axis) {
    transpose_params.perm[axis] = kPerm[axis];
  }

  int out_dim0;
  int out_dim1;
  const bool is_applicable = transpose_utils::IsTranspose2DApplicable(
      transpose_params, src_shape, &out_dim0, &out_dim1);

  EXPECT_TRUE(is_applicable);
  EXPECT_EQ(out_dim0, 20);
  EXPECT_EQ(out_dim1, 42);
}
|
||||
|
||||
// A {4, 5, 6, 7} input with perm {3, 0, 1, 2} must be reported as a 2-D
// transpose with dimensions 4 * 5 * 6 = 120 and 7.
TEST(TransposeUtilsTest, IsTranspose2DApplicable4DThird) {
  RuntimeShape src_shape({4, 5, 6, 7});

  const int kPerm[] = {3, 0, 1, 2};
  TransposeParams transpose_params;
  transpose_params.perm_count = 4;
  for (int axis = 0; axis < 4; ++axis) {
    transpose_params.perm[axis] = kPerm[axis];
  }

  int out_dim0;
  int out_dim1;
  const bool is_applicable = transpose_utils::IsTranspose2DApplicable(
      transpose_params, src_shape, &out_dim0, &out_dim1);

  EXPECT_TRUE(is_applicable);
  EXPECT_EQ(out_dim0, 120);
  EXPECT_EQ(out_dim1, 7);
}
|
||||
|
||||
// A full axis reversal, perm {3, 2, 1, 0}, of a {4, 5, 6, 7} input must not
// be reported as a 2-D transpose.
TEST(TransposeUtilsTest, IsTranspose2DApplicable4DNotApplicable) {
  RuntimeShape src_shape({4, 5, 6, 7});

  const int kPerm[] = {3, 2, 1, 0};
  TransposeParams transpose_params;
  transpose_params.perm_count = 4;
  for (int axis = 0; axis < 4; ++axis) {
    transpose_params.perm[axis] = kPerm[axis];
  }

  // The output dimensions are not checked when the result is false.
  int out_dim0;
  int out_dim1;
  const bool is_applicable = transpose_utils::IsTranspose2DApplicable(
      transpose_params, src_shape, &out_dim0, &out_dim1);

  EXPECT_FALSE(is_applicable);
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
// On Linux, add: tflite::LogToStderr();
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
@ -100,18 +100,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const int size = op_context.perm->dims->data[0];
|
||||
TransposeParams params;
|
||||
params.perm_count = size;
|
||||
bool identical = true;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
params.perm[i] = perm_data[i];
|
||||
if (perm_data[i] != i) identical = false;
|
||||
}
|
||||
|
||||
// TODO(b/140779653): Add an optimization pass in the conversion process to
|
||||
// remove transpose op nodes where they do nothing like the below one.
|
||||
if (identical) {
|
||||
memcpy(op_context.output->data.raw, op_context.input->data.raw,
|
||||
op_context.output->bytes);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
#define TF_LITE_TRANSPOSE(type, scalar) \
|
||||
@ -120,28 +110,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
GetTensorShape(op_context.output), \
|
||||
GetTensorData<scalar>(op_context.output))
|
||||
|
||||
// The transpose kernel only rearranges values and performs no numeric
|
||||
// evaluation on each cell, so it is safe to implement it per scalar size;
|
||||
// this trick keeps the total code size in a reasonable range.
|
||||
switch (op_context.input->type) {
|
||||
case kTfLiteFloat32:
|
||||
if (kernel_type == kGenericOptimized) {
|
||||
TF_LITE_TRANSPOSE(optimized_ops, float);
|
||||
} else {
|
||||
TF_LITE_TRANSPOSE(reference_ops, float);
|
||||
}
|
||||
break;
|
||||
case kTfLiteUInt8:
|
||||
if (kernel_type == kGenericOptimized) {
|
||||
TF_LITE_TRANSPOSE(optimized_ops, uint8_t);
|
||||
} else {
|
||||
TF_LITE_TRANSPOSE(reference_ops, uint8_t);
|
||||
}
|
||||
break;
|
||||
case kTfLiteInt8:
|
||||
if (kernel_type == kGenericOptimized) {
|
||||
TF_LITE_TRANSPOSE(optimized_ops, int8_t);
|
||||
} else {
|
||||
TF_LITE_TRANSPOSE(reference_ops, int8_t);
|
||||
}
|
||||
break;
|
||||
case kTfLiteInt32:
|
||||
if (kernel_type == kGenericOptimized) {
|
||||
TF_LITE_TRANSPOSE(optimized_ops, int32_t);
|
||||
@ -149,16 +122,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_TRANSPOSE(reference_ops, int32_t);
|
||||
}
|
||||
break;
|
||||
case kTfLiteInt64:
|
||||
case kTfLiteUInt8:
|
||||
case kTfLiteInt8:
|
||||
if (kernel_type == kGenericOptimized) {
|
||||
TF_LITE_TRANSPOSE(optimized_ops, int64_t);
|
||||
TF_LITE_TRANSPOSE(optimized_ops, int8_t);
|
||||
} else {
|
||||
TF_LITE_TRANSPOSE(reference_ops, int64_t);
|
||||
TF_LITE_TRANSPOSE(reference_ops, int8_t);
|
||||
}
|
||||
break;
|
||||
case kTfLiteInt64:
|
||||
TF_LITE_TRANSPOSE(reference_ops, int64_t);
|
||||
break;
|
||||
case kTfLiteBool:
|
||||
if (kernel_type == kGenericOptimized) {
|
||||
TF_LITE_TRANSPOSE(optimized_ops, bool);
|
||||
if (sizeof(bool) == 1) {
|
||||
if (kernel_type == kGenericOptimized) {
|
||||
TF_LITE_TRANSPOSE(optimized_ops, int8_t);
|
||||
} else {
|
||||
TF_LITE_TRANSPOSE(reference_ops, int8_t);
|
||||
}
|
||||
} else {
|
||||
TF_LITE_TRANSPOSE(reference_ops, bool);
|
||||
}
|
||||
|
@ -346,6 +346,167 @@ TEST(TransposeTest, Test3DInputDynamicTensor) {
|
||||
2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23}));
|
||||
}
|
||||
|
||||
// A single-element 1-D tensor with the identity permutation must come back
// unchanged in both shape and data.
TEST(TransposeTest, Test1DNotShrinked) {
  TransposeOpConstModel model({1}, {1}, {0});
  model.SetInput({0});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({0}));
}
|
||||
|
||||
// Swapping the axes of a {2, 1} tensor must yield shape {1, 2} with the
// element order unchanged.
TEST(TransposeTest, Test2DShrinkedOneTime) {
  TransposeOpConstModel model({2, 1}, {2}, {1, 0});
  model.SetInput({0, 1});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 1}));
}
|
||||
|
||||
// Swapping the axes of a {1, 1} tensor must leave shape and data unchanged.
TEST(TransposeTest, Test2DShrinkedTwoTimes) {
  TransposeOpConstModel model({1, 1}, {2}, {1, 0});
  model.SetInput({0});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({0}));
}
|
||||
|
||||
// A {2, 1, 3} tensor with perm {0, 2, 1} must yield shape {2, 3, 1} with the
// element order unchanged.
TEST(TransposeTest, Test3DShrinkedOneTime) {
  TransposeOpConstModel model({2, 1, 3}, {3}, {0, 2, 1});
  model.SetInput({0, 1, 2, 3, 4, 5});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 3, 1}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 1, 2, 3, 4, 5}));
}
|
||||
|
||||
// A {1, 1, 3} tensor with perm {1, 2, 0} must yield shape {1, 3, 1} with the
// element order unchanged.
TEST(TransposeTest, Test3DShrinkedTwoTimes) {
  TransposeOpConstModel model({1, 1, 3}, {3}, {1, 2, 0});
  model.SetInput({0, 1, 2});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 3, 1}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 1, 2}));
}
|
||||
|
||||
// A {1, 1, 1} tensor with perm {1, 2, 0} must keep its shape and its single
// element.
TEST(TransposeTest, Test3DShrinkedAll) {
  TransposeOpConstModel model({1, 1, 1}, {3}, {1, 2, 0});
  model.SetInput({0});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({0}));
}
|
||||
|
||||
// A {2, 2, 3, 1} tensor with perm {3, 0, 1, 2} must yield shape {1, 2, 2, 3}
// with the element order unchanged.
TEST(TransposeTest, Test4DShrinkedOneTimes) {
  TransposeOpConstModel m({2, 2, 3, 1}, {4}, {3, 0, 1, 2});
  // The input tensor holds 2 * 2 * 3 * 1 = 12 elements, so supply exactly 12
  // values; providing more would exceed the tensor's element count.
  m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 2, 3}));
  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}));
}
|
||||
|
||||
// A {2, 1, 3, 1} tensor with perm {0, 3, 1, 2} must yield shape {2, 1, 1, 3}
// with the element order unchanged.
TEST(TransposeTest, Test4DShrinkedTwoTimes) {
  TransposeOpConstModel model({2, 1, 3, 1}, {4}, {0, 3, 1, 2});
  model.SetInput({0, 1, 2, 3, 4, 5});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 1, 3}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 1, 2, 3, 4, 5}));
}
|
||||
|
||||
// A {2, 1, 1, 1} tensor with perm {3, 2, 1, 0} must yield shape {1, 1, 1, 2}
// with the element order unchanged.
TEST(TransposeTest, Test4DShrinkedThirdTimes) {
  TransposeOpConstModel model({2, 1, 1, 1}, {4}, {3, 2, 1, 0});
  model.SetInput({0, 1});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 2}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 1}));
}
|
||||
|
||||
// A {1, 1, 1, 1} tensor with perm {2, 3, 1, 0} must keep its shape and its
// single element.
TEST(TransposeTest, Test4DShrinkedFourTimes) {
  TransposeOpConstModel model({1, 1, 1, 1}, {4}, {2, 3, 1, 0});
  model.SetInput({0});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 1}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({0}));
}
|
||||
|
||||
// A {2, 2, 3} tensor with perm {0, 2, 1}: axis 0 stays in place and each 2x3
// slice is transposed into a 3x2 slice.
TEST(TransposeTest, Test3DFlatten) {
  TransposeOpConstModel model({2, 2, 3}, {3}, {0, 2, 1});
  model.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 3, 2}));
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray({0, 3, 1, 4, 2, 5, 6, 9, 7, 10, 8, 11}));
}
|
||||
|
||||
// A {2, 2, 2, 2} tensor with perm {0, 1, 3, 2}: the two leading axes stay in
// place and each trailing 2x2 slice is transposed.
TEST(TransposeTest, Test4DFlattenOne) {
  TransposeOpConstModel model({2, 2, 2, 2}, {4}, {0, 1, 3, 2});
  model.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 2, 2}));
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray({0, 2, 1, 3, 4, 6, 5, 7, 8, 10, 9, 11, 12, 14,
                                13, 15}));
}
|
||||
|
||||
// A {2, 2, 2, 2} tensor with perm {0, 2, 3, 1}: axis 0 stays in place and the
// remaining three axes are rotated.
TEST(TransposeTest, Test4DFlattenTwo) {
  TransposeOpConstModel model({2, 2, 2, 2}, {4}, {0, 2, 3, 1});
  model.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 2, 2}));
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray({0, 4, 1, 5, 2, 6, 3, 7, 8, 12, 9, 13, 10, 14,
                                11, 15}));
}
|
||||
|
||||
// Compares the op's output for a {2, 3, 4} tensor with perm {1, 2, 0} against
// the result produced by RunTestPermutation for the same shape and perm.
TEST(TransposeTest, 3DDividedIntoTwo2DsOne) {
  std::vector<float> expected;
  RunTestPermutation({2, 3, 4}, {1, 2, 0}, &expected);

  TransposeOpConstModel model({2, 3, 4}, {3}, {1, 2, 0});
  model.SetInput({0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
                  12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
  model.Invoke();

  EXPECT_EQ(model.GetOutput(), expected);
}
|
||||
|
||||
// Compares the op's output for a {2, 3, 4} tensor with perm {2, 0, 1} against
// the result produced by RunTestPermutation for the same shape and perm.
TEST(TransposeTest, 3DDividedIntoTwo2DsTwo) {
  std::vector<float> expected;
  RunTestPermutation({2, 3, 4}, {2, 0, 1}, &expected);

  TransposeOpConstModel model({2, 3, 4}, {3}, {2, 0, 1});
  model.SetInput({0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
                  12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
  model.Invoke();

  EXPECT_EQ(model.GetOutput(), expected);
}
|
||||
|
||||
// Compares the op's output for a {2, 3, 4, 2} tensor with perm {1, 2, 3, 0}
// against the result produced by RunTestPermutation for the same shape and
// perm.
TEST(TransposeTest, 4DDividedIntoTwo2DsOne) {
  std::vector<float> expected;
  RunTestPermutation({2, 3, 4, 2}, {1, 2, 3, 0}, &expected);

  TransposeOpConstModel model({2, 3, 4, 2}, {4}, {1, 2, 3, 0});
  model.SetInput(
      {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
       16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
       32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47});
  model.Invoke();

  EXPECT_EQ(model.GetOutput(), expected);
}
|
||||
|
||||
// Compares the op's output for a {2, 3, 4, 2} tensor with perm {2, 3, 0, 1}
// against the result produced by RunTestPermutation for the same shape and
// perm.
TEST(TransposeTest, 4DDividedIntoTwo2DsTwo) {
  std::vector<float> expected;
  RunTestPermutation({2, 3, 4, 2}, {2, 3, 0, 1}, &expected);

  TransposeOpConstModel model({2, 3, 4, 2}, {4}, {2, 3, 0, 1});
  model.SetInput(
      {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
       16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
       32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47});
  model.Invoke();

  EXPECT_EQ(model.GetOutput(), expected);
}
|
||||
|
||||
// Compares the op's output for a {2, 3, 4, 2} tensor with perm {3, 0, 1, 2}
// against the result produced by RunTestPermutation for the same shape and
// perm.
TEST(TransposeTest, 4DDividedIntoTwo2DsThird) {
  std::vector<float> expected;
  RunTestPermutation({2, 3, 4, 2}, {3, 0, 1, 2}, &expected);

  TransposeOpConstModel model({2, 3, 4, 2}, {4}, {3, 0, 1, 2});
  model.SetInput(
      {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
       16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
       32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47});
  model.Invoke();

  EXPECT_EQ(model.GetOutput(), expected);
}
|
||||
|
||||
#ifdef GTEST_HAS_DEATH_TEST
|
||||
TEST(TransposeTest, Test5DInputTensor) {
|
||||
EXPECT_DEATH(TransposeOpConstModel({1, 2, 3, 4, 5}, {5}, {0, 1, 2, 3, 4}),
|
||||
|
Loading…
Reference in New Issue
Block a user