Transpose op optimization

PiperOrigin-RevId: 269470475
This commit is contained in:
Jaesung Chung 2019-09-16 18:52:00 -07:00 committed by TensorFlower Gardener
parent e95490d0e1
commit 376e283836
2 changed files with 11 additions and 19 deletions

View File

@ -6767,11 +6767,11 @@ inline void Transpose2DOn32bitMatrix(const TransposeParams& params,
}
template <typename T>
inline void TransposeImpl(const TransposeParams& params,
const RuntimeShape& unextended_input_shape,
const T* input_data,
const RuntimeShape& unextended_output_shape,
T* output_data) {
void TransposeImpl(const TransposeParams& params,
const RuntimeShape& unextended_input_shape,
const T* input_data,
const RuntimeShape& unextended_output_shape,
T* output_data) {
const int unextended_output_size = unextended_input_shape.DimensionsCount();
const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
@ -6819,14 +6819,6 @@ void Transpose(const TransposeParams& params,
// each cell. It's safe to implement per size of scalar type and this trick
// keeps the total code size in a reasonable range.
switch (sizeof(T)) {
case 1:
// TODO(jaesung): Find a good 2d transpose implementation for 8-bit
// matrices.
TransposeImpl(params, unextended_input_shape,
reinterpret_cast<const int8_t*>(input_data),
unextended_output_shape,
reinterpret_cast<int8_t*>(output_data));
break;
case 4:
if (unextended_input_shape.DimensionsCount() == 2 &&
params.perm[0] == 1 && params.perm[1] == 0) {
@ -6842,7 +6834,7 @@ void Transpose(const TransposeParams& params,
reinterpret_cast<int32_t*>(output_data));
break;
default:
// Reroute to the reference version if the given size is not common.
// Reroute to the reference version if the given size is not available.
reference_ops::Transpose(params, unextended_input_shape, input_data,
unextended_output_shape, output_data);
}

View File

@ -3008,11 +3008,11 @@ inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
}
template <typename T>
inline void TransposeImpl(const TransposeParams& params,
const RuntimeShape& unextended_input_shape,
const T* input_data,
const RuntimeShape& unextended_output_shape,
T* output_data) {
void TransposeImpl(const TransposeParams& params,
const RuntimeShape& unextended_input_shape,
const T* input_data,
const RuntimeShape& unextended_output_shape,
T* output_data) {
const int unextended_output_size = unextended_output_shape.DimensionsCount();
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_size, 4);