Transpose op optimization

PiperOrigin-RevId: 269470475
This commit is contained in:
Jaesung Chung 2019-09-16 18:52:00 -07:00 committed by TensorFlower Gardener
parent e95490d0e1
commit 376e283836
2 changed files with 11 additions and 19 deletions

View File

@ -6767,7 +6767,7 @@ inline void Transpose2DOn32bitMatrix(const TransposeParams& params,
} }
template <typename T> template <typename T>
inline void TransposeImpl(const TransposeParams& params, void TransposeImpl(const TransposeParams& params,
const RuntimeShape& unextended_input_shape, const RuntimeShape& unextended_input_shape,
const T* input_data, const T* input_data,
const RuntimeShape& unextended_output_shape, const RuntimeShape& unextended_output_shape,
@ -6819,14 +6819,6 @@ void Transpose(const TransposeParams& params,
// each cell. It's safe to implement per size of scalar type and this trick // each cell. It's safe to implement per size of scalar type and this trick
// keeps the total code size in a reasonable range. // keeps the total code size in a reasonable range.
switch (sizeof(T)) { switch (sizeof(T)) {
case 1:
// TODO(jaesung): Find a good 2d transpose implementation for 8-bit
// matrices.
TransposeImpl(params, unextended_input_shape,
reinterpret_cast<const int8_t*>(input_data),
unextended_output_shape,
reinterpret_cast<int8_t*>(output_data));
break;
case 4: case 4:
if (unextended_input_shape.DimensionsCount() == 2 && if (unextended_input_shape.DimensionsCount() == 2 &&
params.perm[0] == 1 && params.perm[1] == 0) { params.perm[0] == 1 && params.perm[1] == 0) {
@ -6842,7 +6834,7 @@ void Transpose(const TransposeParams& params,
reinterpret_cast<int32_t*>(output_data)); reinterpret_cast<int32_t*>(output_data));
break; break;
default: default:
// Reroute to the reference version if the given size is not common. // Reroute to the reference version if the given size is not available.
reference_ops::Transpose(params, unextended_input_shape, input_data, reference_ops::Transpose(params, unextended_input_shape, input_data,
unextended_output_shape, output_data); unextended_output_shape, output_data);
} }

View File

@ -3008,7 +3008,7 @@ inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
} }
template <typename T> template <typename T>
inline void TransposeImpl(const TransposeParams& params, void TransposeImpl(const TransposeParams& params,
const RuntimeShape& unextended_input_shape, const RuntimeShape& unextended_input_shape,
const T* input_data, const T* input_data,
const RuntimeShape& unextended_output_shape, const RuntimeShape& unextended_output_shape,