Transpose op optimization
PiperOrigin-RevId: 269470475
This commit is contained in:
parent
e95490d0e1
commit
376e283836
@ -6767,11 +6767,11 @@ inline void Transpose2DOn32bitMatrix(const TransposeParams& params,
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void TransposeImpl(const TransposeParams& params,
|
||||
const RuntimeShape& unextended_input_shape,
|
||||
const T* input_data,
|
||||
const RuntimeShape& unextended_output_shape,
|
||||
T* output_data) {
|
||||
void TransposeImpl(const TransposeParams& params,
|
||||
const RuntimeShape& unextended_input_shape,
|
||||
const T* input_data,
|
||||
const RuntimeShape& unextended_output_shape,
|
||||
T* output_data) {
|
||||
const int unextended_output_size = unextended_input_shape.DimensionsCount();
|
||||
const RuntimeShape input_shape =
|
||||
RuntimeShape::ExtendedShape(4, unextended_input_shape);
|
||||
@ -6819,14 +6819,6 @@ void Transpose(const TransposeParams& params,
|
||||
// each cell. It's safe to implement per size of scalar type and this trick
|
||||
// keeps the total code size in a reasonable range.
|
||||
switch (sizeof(T)) {
|
||||
case 1:
|
||||
// TODO(jaesung): Find a good 2d transpose implementation for 8-bit
|
||||
// matrices.
|
||||
TransposeImpl(params, unextended_input_shape,
|
||||
reinterpret_cast<const int8_t*>(input_data),
|
||||
unextended_output_shape,
|
||||
reinterpret_cast<int8_t*>(output_data));
|
||||
break;
|
||||
case 4:
|
||||
if (unextended_input_shape.DimensionsCount() == 2 &&
|
||||
params.perm[0] == 1 && params.perm[1] == 0) {
|
||||
@ -6842,7 +6834,7 @@ void Transpose(const TransposeParams& params,
|
||||
reinterpret_cast<int32_t*>(output_data));
|
||||
break;
|
||||
default:
|
||||
// Reroute to the reference version if the given size is not common.
|
||||
// Reroute to the reference version if the given size is not available.
|
||||
reference_ops::Transpose(params, unextended_input_shape, input_data,
|
||||
unextended_output_shape, output_data);
|
||||
}
|
||||
|
@ -3008,11 +3008,11 @@ inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void TransposeImpl(const TransposeParams& params,
|
||||
const RuntimeShape& unextended_input_shape,
|
||||
const T* input_data,
|
||||
const RuntimeShape& unextended_output_shape,
|
||||
T* output_data) {
|
||||
void TransposeImpl(const TransposeParams& params,
|
||||
const RuntimeShape& unextended_input_shape,
|
||||
const T* input_data,
|
||||
const RuntimeShape& unextended_output_shape,
|
||||
T* output_data) {
|
||||
const int unextended_output_size = unextended_output_shape.DimensionsCount();
|
||||
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_output_size, 4);
|
||||
|
Loading…
Reference in New Issue
Block a user