Transpose op optimization
PiperOrigin-RevId: 268194603
parent b266a849d8
commit 2a7d2c4295
@@ -72,7 +72,6 @@ using reference_ops::SpaceToBatchND;
 using reference_ops::Split;
 using reference_ops::StridedSlice;
 using reference_ops::TensorFlowSplit;
-using reference_ops::Transpose;
 
 static constexpr int kDepthwiseReverseShift = -1;
 
@@ -4918,6 +4917,18 @@ inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
              DimsToShape(output_dims), output_data);
 }
 
+template <typename T>
+void Transpose(const T* input, const Dims<4>& input_dims, T* output,
+               const Dims<4>& output_dims, const int* permuted_axes) {
+  TransposeParams params;
+  params.perm_count = 4;
+  for (int i = 0; i < 4; ++i) {
+    params.perm[i] = 3 - permuted_axes[3 - i];
+  }
+  Transpose(params, DimsToShape(input_dims), input, DimsToShape(output_dims),
+            output);
+}
+
 }  // namespace optimized_ops
 }  // namespace tflite
 
 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_
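
The wrapper above bridges the legacy Dims<4> convention, where dimensions and axes are listed innermost-first, to the TransposeParams/RuntimeShape convention used by the new kernel, which lists them outermost-first; hence each axis index is mirrored with params.perm[i] = 3 - permuted_axes[3 - i]. A minimal standalone sketch of that mirroring, using plain arrays and an illustrative function name rather than anything from the TFLite headers:

#include <cstdio>

// Mirrors a legacy (innermost-first) 4-D permutation into the
// outermost-first order expected by the new transpose params.
void MirrorPerm(const int legacy_perm[4], int perm[4]) {
  for (int i = 0; i < 4; ++i) {
    perm[i] = 3 - legacy_perm[3 - i];
  }
}

int main() {
  // In the legacy innermost-first numbering, {0, 1, 3, 2} swaps the two
  // outermost dimensions; mirrored, that becomes {1, 0, 2, 3}.
  const int legacy_perm[4] = {0, 1, 3, 2};
  int perm[4];
  MirrorPerm(legacy_perm, perm);
  for (int i = 0; i < 4; ++i) printf("%d ", perm[i]);  // prints: 1 0 2 3
  return 0;
}
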
@@ -98,7 +98,6 @@ using reference_ops::SpaceToBatchND;
 using reference_ops::Split;
 using reference_ops::StridedSlice;
 using reference_ops::Sub16;
-using reference_ops::Transpose;
 
 // TODO(b/80247582) Remove this constant.
 // This will be phased out as the shifts are revised with more thought. Use of a
@@ -180,6 +179,12 @@ struct TTypes {
   typedef Eigen::TensorMap<
       Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>>
       UnalignedConstMatrix;
+  typedef Eigen::TensorMap<
+      Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, IndexType>, Eigen::Aligned>
+      ConstTensor;
+  typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType>,
+                           Eigen::Aligned>
+      Tensor;
 };
 
 // TODO(b/62193649): this function is only needed as long
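
The new ConstTensor/Tensor aliases wrap raw buffers as aligned 4-D Eigen tensor maps so that the generic transpose path added further down can be written as a single Eigen shuffle. A standalone sketch of the same pattern using plain Eigen types rather than the TTypes helper, assuming Eigen's unsupported Tensor module is on the include path; the sizes and variable names here are illustrative:

#include <unsupported/Eigen/CXX11/Tensor>
#include <cstdio>

int main() {
  // View a 1x1x2x3 row-major buffer through TensorMap without copying.
  float in[6] = {0, 1, 2, 3, 4, 5};
  float out[6] = {};
  Eigen::TensorMap<Eigen::Tensor<const float, 4, Eigen::RowMajor>> x(in, 1, 1, 2, 3);
  Eigen::TensorMap<Eigen::Tensor<float, 4, Eigen::RowMajor>> y(out, 1, 1, 3, 2);

  // Swapping the last two axes transposes the 2x3 matrix into 3x2.
  Eigen::array<int, 4> perm{{0, 1, 3, 2}};
  y = x.shuffle(perm);

  for (int i = 0; i < 6; ++i) printf("%g ", out[i]);  // prints: 0 3 1 4 2 5
  return 0;
}
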
@@ -6694,6 +6699,171 @@ inline void Logistic16bitPercision(const LogisticParams& params,
   }
 }
 
+// Transpose2DOn32bitMatrix only deals with typical 2D matrix transpose ops.
+inline void Transpose2DOn32bitMatrix(const TransposeParams& params,
+                                     const RuntimeShape& input_shape,
+                                     const int32_t* input_data,
+                                     const RuntimeShape& output_shape,
+                                     int32_t* output_data) {
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 2);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2);
+  TFLITE_DCHECK_EQ(params.perm_count, 2);
+  TFLITE_DCHECK_EQ(params.perm[0], 1);
+  TFLITE_DCHECK_EQ(params.perm[1], 0);
+
+  const int d0 = input_shape.DimsData()[0];
+  const int d1 = input_shape.DimsData()[1];
+#ifdef USE_NEON
+  const int kLines = 4;
+  const int kSkipSize = (kLines - 1) * d1;
+#endif
+
+  const int32_t* input = input_data;
+
+  int i = 0;
+#ifdef USE_NEON
+  for (; i <= d0 - kLines; i += kLines) {
+    int32_t* output = output_data + i;
+
+    const int32_t* input_ptr = input;
+    __builtin_prefetch(input_ptr, 0, 3);
+    input_ptr += d1;
+    __builtin_prefetch(input_ptr, 0, 3);
+    input_ptr += d1;
+    __builtin_prefetch(input_ptr, 0, 3);
+    input_ptr += d1;
+    __builtin_prefetch(input_ptr, 0, 3);
+
+    int j = 0;
+    for (; j <= d1 - kLines; j += kLines) {
+      input_ptr = input;
+      int32x4_t a0 = vld1q_s32(input);
+      input_ptr += d1;
+      int32x4_t a1 = vld1q_s32(input_ptr);
+      input_ptr += d1;
+      int32x4_t a2 = vld1q_s32(input_ptr);
+      input_ptr += d1;
+      int32x4_t a3 = vld1q_s32(input_ptr);
+
+      int32x4x2_t tmp1 = vuzpq_s32(a0, a2);
+      int32x4x2_t tmp2 = vuzpq_s32(a1, a3);
+      int32x4x2_t tmp3 = vtrnq_s32(tmp1.val[0], tmp2.val[0]);
+      int32x4x2_t tmp4 = vtrnq_s32(tmp1.val[1], tmp2.val[1]);
+
+      vst1q_s32(output, tmp3.val[0]);
+      output += d0;
+      vst1q_s32(output, tmp4.val[0]);
+      output += d0;
+      vst1q_s32(output, tmp3.val[1]);
+      output += d0;
+      vst1q_s32(output, tmp4.val[1]);
+      output += d0;
+      input += kLines;
+    }
+    if (j == d1) {
+      input += kSkipSize;
+    } else {
+      for (int p = 0; p < kLines; ++p) {
+        for (int q = 0; q < d1 - j; ++q) {
+          *(output + q * d0 + p) = *(input + p * d1 + q);
+        }
+      }
+      input += (d1 - j) + kSkipSize;
+    }
+  }
+#endif
+  for (; i < d0; ++i) {
+    int32_t* output = output_data + i;
+    for (int j = 0; j < d1; ++j) {
+      *output = *input;
+      output += d0;
+      ++input;
+    }
+  }
+}
+
+template <typename T>
+inline void TransposeImpl(const TransposeParams& params,
+                          const RuntimeShape& unextended_input_shape,
+                          const T* input_data,
+                          const RuntimeShape& unextended_output_shape,
+                          T* output_data) {
+  const int unextended_output_size = unextended_input_shape.DimensionsCount();
+  const RuntimeShape input_shape =
+      RuntimeShape::ExtendedShape(4, unextended_input_shape);
+  const RuntimeShape output_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_shape);
+  const int input_ext_size = 4 - unextended_input_shape.DimensionsCount();
+  const int output_ext_size = 4 - unextended_output_size;
+
+  // The perm data is extended to match the output, each index incremented by
+  // the amount of front padding of the input shape.
+  int extended_perm[4];
+  for (int i = 0; i < output_ext_size; ++i) {
+    extended_perm[i] = i;
+  }
+  for (int i = 0; i < unextended_output_size; ++i) {
+    extended_perm[i + output_ext_size] = params.perm[i] + input_ext_size;
+  }
+
+  Eigen::array<int, 4> p;
+  for (int i = 0; i < 4; ++i) p[i] = extended_perm[i];
+  Eigen::DSizes<Eigen::DenseIndex, 4> input_dsizes;
+  for (int d = 0; d < 4; d++) {
+    input_dsizes[d] = static_cast<Eigen::DenseIndex>(input_shape.Dims(d));
+  }
+  Eigen::DSizes<Eigen::DenseIndex, 4> output_dsizes;
+  for (int d = 0; d < 4; d++) {
+    output_dsizes[d] = static_cast<Eigen::DenseIndex>(output_shape.Dims(d));
+  }
+
+  auto x = typename TTypes<T, 4>::ConstTensor(input_data, input_dsizes);
+  auto y = typename TTypes<T, 4>::Tensor(output_data, output_dsizes);
+  y = x.shuffle(p);
+}
+
+template <typename T>
+void Transpose(const TransposeParams& params,
+               const RuntimeShape& unextended_input_shape, const T* input_data,
+               const RuntimeShape& unextended_output_shape, T* output_data) {
+  const int unextended_output_size = unextended_output_shape.DimensionsCount();
+  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_size, 4);
+  TFLITE_DCHECK_EQ(unextended_output_size, params.perm_count);
+
+  // Transpose kernel only does rearranging values not numeric evaluations on
+  // each cell. It's safe to implement per size of scalar type and this trick
+  // keeps the total code size in a reasonable range.
+  switch (sizeof(T)) {
+    case 1:
+      // TODO(jaesung): Find a good 2d transpose implementation for 8-bit
+      // matrices.
+      TransposeImpl(params, unextended_input_shape,
+                    reinterpret_cast<const int8_t*>(input_data),
+                    unextended_output_shape,
+                    reinterpret_cast<int8_t*>(output_data));
+      break;
+    case 4:
+      if (unextended_input_shape.DimensionsCount() == 2 &&
+          params.perm[0] == 1 && params.perm[1] == 0) {
+        Transpose2DOn32bitMatrix(params, unextended_input_shape,
+                                 reinterpret_cast<const int32_t*>(input_data),
+                                 unextended_output_shape,
+                                 reinterpret_cast<int32_t*>(output_data));
+        return;
+      }
+      TransposeImpl(params, unextended_input_shape,
+                    reinterpret_cast<const int32_t*>(input_data),
+                    unextended_output_shape,
+                    reinterpret_cast<int32_t*>(output_data));
+      break;
+    default:
+      // Reroute to the reference version if the given size is not common.
+      reference_ops::Transpose(params, unextended_input_shape, input_data,
+                               unextended_output_shape, output_data);
+  }
+}
+
 }  // namespace optimized_ops
 }  // namespace tflite
 
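
The NEON branch of Transpose2DOn32bitMatrix walks the matrix in 4x4 blocks of 32-bit values, transposing each block in registers with vuzpq_s32/vtrnq_s32 and falling back to scalar copies for leftover columns, with the trailing loop handling leftover rows; the tests added at the bottom of this commit exercise both leftover cases.

TransposeImpl, in turn, always shuffles a 4-D view: lower-rank inputs are front-padded with size-1 dimensions and the permutation is extended to match, with the padded leading axes mapping to themselves and every original perm entry shifted by the input's padding. A small standalone sketch of that extension in plain C++ (no TFLite types; the function name is illustrative, and it assumes input and output have the same rank, which makes the two padding amounts in the real code coincide):

#include <cstdio>

// Extends a rank-`rank` permutation to 4-D, assuming the shape was
// front-padded with (4 - rank) size-1 dimensions.
void ExtendPerm(const int* perm, int rank, int extended_perm[4]) {
  const int ext = 4 - rank;
  for (int i = 0; i < ext; ++i) extended_perm[i] = i;  // padded axes stay put
  for (int i = 0; i < rank; ++i) extended_perm[ext + i] = perm[i] + ext;
}

int main() {
  const int perm2d[2] = {1, 0};  // plain 2-D transpose
  int extended[4];
  ExtendPerm(perm2d, 2, extended);
  // A [H, W] input becomes [1, 1, H, W]; the extended perm {0, 1, 3, 2}
  // swaps only the two real axes.
  for (int i = 0; i < 4; ++i) printf("%d ", extended[i]);  // prints: 0 1 3 2
  return 0;
}
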
@@ -3046,9 +3046,11 @@ inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
 }
 
 template <typename T>
-void Transpose(const TransposeParams& params,
-               const RuntimeShape& unextended_input_shape, const T* input_data,
-               const RuntimeShape& unextended_output_shape, T* output_data) {
+inline void TransposeImpl(const TransposeParams& params,
+                          const RuntimeShape& unextended_input_shape,
+                          const T* input_data,
+                          const RuntimeShape& unextended_output_shape,
+                          T* output_data) {
   const int unextended_output_size = unextended_output_shape.DimensionsCount();
   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_LE(unextended_output_size, 4);
@@ -3096,6 +3098,42 @@ void Transpose(const TransposeParams& params,
   }
 }
 
+template <typename T>
+void Transpose(const TransposeParams& params,
+               const RuntimeShape& unextended_input_shape, const T* input_data,
+               const RuntimeShape& unextended_output_shape, T* output_data) {
+  // Transpose kernel only does rearranging values not numeric evaluations on
+  // each cell. It's safe to implement per size of scalar type and this trick
+  // keeps the total code size in a reasonable range.
+  switch (sizeof(T)) {
+    case 1:
+      TransposeImpl<int8_t>(params, unextended_input_shape,
+                            reinterpret_cast<const int8_t*>(input_data),
+                            unextended_output_shape,
+                            reinterpret_cast<int8_t*>(output_data));
+      break;
+    case 2:
+      TransposeImpl<int16_t>(params, unextended_input_shape,
+                             reinterpret_cast<const int16_t*>(input_data),
+                             unextended_output_shape,
+                             reinterpret_cast<int16_t*>(output_data));
+      break;
+
+    case 4:
+      TransposeImpl<int32_t>(params, unextended_input_shape,
+                             reinterpret_cast<const int32_t*>(input_data),
+                             unextended_output_shape,
+                             reinterpret_cast<int32_t*>(output_data));
+      break;
+    case 8:
+      TransposeImpl<int64_t>(params, unextended_input_shape,
+                             reinterpret_cast<const int64_t*>(input_data),
+                             unextended_output_shape,
+                             reinterpret_cast<int64_t*>(output_data));
+      break;
+  }
+}
+
 inline void TransposeConv(
     const ConvParams& params, const RuntimeShape& input_shape,
     const float* input_data, const RuntimeShape& filter_shape,
@@ -13,9 +13,12 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 #include <string.h>
+
 #include <vector>
+
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/lite/kernels/internal/tensor.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
@@ -29,6 +32,7 @@ namespace transpose {
 // This file has two implementations of Transpose.
 enum KernelType {
   kReference,
+  kGenericOptimized,
 };
 
 struct TransposeContext {
@@ -96,8 +100,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const int size = op_context.perm->dims->data[0];
   TransposeParams params;
   params.perm_count = size;
+  bool identical = true;
   for (int i = 0; i < size; ++i) {
     params.perm[i] = perm_data[i];
+    if (perm_data[i] != i) identical = false;
+  }
+
+  // TODO(b/140779653): Add an optimization pass in the conversion process to
+  // remove transpose op nodes where they do nothing like the below one.
+  if (identical) {
+    memcpy(op_context.output->data.raw, op_context.input->data.raw,
+           op_context.output->bytes);
+    return kTfLiteOk;
   }
 
 #define TF_LITE_TRANSPOSE(type, scalar)                     \
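
Eval now also detects the identity permutation while filling in TransposeParams: if perm is {0, 1, ..., size - 1} the op would leave the data untouched, so it short-circuits with a raw memcpy instead of invoking either kernel. A trivial sketch of that check in isolation (plain C++, illustrative names):

#include <cstring>

// Returns true when perm is {0, 1, ..., size - 1}, i.e. the transpose
// would leave the data untouched.
bool IsIdentityPerm(const int* perm, int size) {
  for (int i = 0; i < size; ++i) {
    if (perm[i] != i) return false;
  }
  return true;
}

// Usage, assuming contiguous buffers of `bytes` bytes:
//   if (IsIdentityPerm(perm, size)) memcpy(output, input, bytes);
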
@@ -108,32 +122,44 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 
   switch (op_context.input->type) {
     case kTfLiteFloat32:
-      if (kernel_type == kReference) {
+      if (kernel_type == kGenericOptimized) {
+        TF_LITE_TRANSPOSE(optimized_ops, float);
+      } else {
         TF_LITE_TRANSPOSE(reference_ops, float);
       }
       break;
     case kTfLiteUInt8:
-      if (kernel_type == kReference) {
+      if (kernel_type == kGenericOptimized) {
+        TF_LITE_TRANSPOSE(optimized_ops, uint8_t);
+      } else {
         TF_LITE_TRANSPOSE(reference_ops, uint8_t);
       }
       break;
     case kTfLiteInt8:
-      if (kernel_type == kReference) {
+      if (kernel_type == kGenericOptimized) {
+        TF_LITE_TRANSPOSE(optimized_ops, int8_t);
+      } else {
         TF_LITE_TRANSPOSE(reference_ops, int8_t);
       }
       break;
     case kTfLiteInt32:
-      if (kernel_type == kReference) {
+      if (kernel_type == kGenericOptimized) {
+        TF_LITE_TRANSPOSE(optimized_ops, int32_t);
+      } else {
         TF_LITE_TRANSPOSE(reference_ops, int32_t);
       }
       break;
     case kTfLiteInt64:
-      if (kernel_type == kReference) {
+      if (kernel_type == kGenericOptimized) {
+        TF_LITE_TRANSPOSE(optimized_ops, int64_t);
+      } else {
         TF_LITE_TRANSPOSE(reference_ops, int64_t);
       }
       break;
     case kTfLiteBool:
-      if (kernel_type == kReference) {
+      if (kernel_type == kGenericOptimized) {
+        TF_LITE_TRANSPOSE(optimized_ops, bool);
+      } else {
         TF_LITE_TRANSPOSE(reference_ops, bool);
       }
       break;
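
Each case above funnels through the TF_LITE_TRANSPOSE macro, whose definition sits just above this hunk and is not part of the diff. Presumably it forwards the computed params together with the input and output shapes and data to the chosen namespace's Transpose; the following is a hedged reconstruction of that pattern using the standard GetTensorShape/GetTensorData helpers, not the literal macro body from the file:

// Hypothetical reconstruction of the dispatch pattern; the actual macro in
// transpose.cc takes (type, scalar) and is defined outside this hunk.
#define TF_LITE_TRANSPOSE(type, scalar)                     \
  type::Transpose(params, GetTensorShape(op_context.input), \
                  GetTensorData<scalar>(op_context.input),  \
                  GetTensorShape(op_context.output),        \
                  GetTensorData<scalar>(op_context.output))
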
@@ -156,7 +182,15 @@ TfLiteRegistration* Register_TRANSPOSE_REF() {
   return &r;
 }
 
-TfLiteRegistration* Register_TRANSPOSE() { return Register_TRANSPOSE_REF(); }
+TfLiteRegistration* Register_TRANSPOSE_GENERIC_OPTIMIZED() {
+  static TfLiteRegistration r = {nullptr, nullptr, transpose::Prepare,
+                                 transpose::Eval<transpose::kGenericOptimized>};
+  return &r;
+}
+
+TfLiteRegistration* Register_TRANSPOSE() {
+  return Register_TRANSPOSE_GENERIC_OPTIMIZED();
+}
 
 }  // namespace builtin
 }  // namespace ops
@@ -233,6 +233,28 @@ TEST(TransposeTest, Test2DInputConstTensor) {
   EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2, 4, 1, 3, 5}));
 }
 
+TEST(TransposeTest, Test2D4x4KernelTestLeftOverRightSide) {
+  TransposeOpConstModel m({4, 6}, {2}, {1, 0});
+  m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+              12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 4}));
+  EXPECT_THAT(m.GetOutput(),
+              ElementsAreArray({0, 6, 12, 18, 1, 7, 13, 19, 2, 8, 14, 20,
+                                3, 9, 15, 21, 4, 10, 16, 22, 5, 11, 17, 23}));
+}
+
+TEST(TransposeTest, Test2D4x4KernelTest2LeftOverBottomSide) {
+  TransposeOpConstModel m({6, 4}, {2}, {1, 0});
+  m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+              12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 6}));
+  EXPECT_THAT(m.GetOutput(),
+              ElementsAreArray({0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21,
+                                2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23}));
+}
+
 TEST(TransposeTest, Test2DInputDynamicTensor) {
   TransposeOpDynamicModel m({3, 2}, {2});
   m.SetInput({0, 1, 2, 3, 4, 5});