Transpose op optimization
PiperOrigin-RevId: 269470475
This commit is contained in:
parent
e95490d0e1
commit
376e283836
@ -6767,11 +6767,11 @@ inline void Transpose2DOn32bitMatrix(const TransposeParams& params,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline void TransposeImpl(const TransposeParams& params,
|
void TransposeImpl(const TransposeParams& params,
|
||||||
const RuntimeShape& unextended_input_shape,
|
const RuntimeShape& unextended_input_shape,
|
||||||
const T* input_data,
|
const T* input_data,
|
||||||
const RuntimeShape& unextended_output_shape,
|
const RuntimeShape& unextended_output_shape,
|
||||||
T* output_data) {
|
T* output_data) {
|
||||||
const int unextended_output_size = unextended_input_shape.DimensionsCount();
|
const int unextended_output_size = unextended_input_shape.DimensionsCount();
|
||||||
const RuntimeShape input_shape =
|
const RuntimeShape input_shape =
|
||||||
RuntimeShape::ExtendedShape(4, unextended_input_shape);
|
RuntimeShape::ExtendedShape(4, unextended_input_shape);
|
||||||
@ -6819,14 +6819,6 @@ void Transpose(const TransposeParams& params,
|
|||||||
// each cell. It's safe to implement per size of scalar type and this trick
|
// each cell. It's safe to implement per size of scalar type and this trick
|
||||||
// keeps the total code size in a reasonable range.
|
// keeps the total code size in a reasonable range.
|
||||||
switch (sizeof(T)) {
|
switch (sizeof(T)) {
|
||||||
case 1:
|
|
||||||
// TODO(jaesung): Find a good 2d transpose implementation for 8-bit
|
|
||||||
// matrices.
|
|
||||||
TransposeImpl(params, unextended_input_shape,
|
|
||||||
reinterpret_cast<const int8_t*>(input_data),
|
|
||||||
unextended_output_shape,
|
|
||||||
reinterpret_cast<int8_t*>(output_data));
|
|
||||||
break;
|
|
||||||
case 4:
|
case 4:
|
||||||
if (unextended_input_shape.DimensionsCount() == 2 &&
|
if (unextended_input_shape.DimensionsCount() == 2 &&
|
||||||
params.perm[0] == 1 && params.perm[1] == 0) {
|
params.perm[0] == 1 && params.perm[1] == 0) {
|
||||||
@ -6842,7 +6834,7 @@ void Transpose(const TransposeParams& params,
|
|||||||
reinterpret_cast<int32_t*>(output_data));
|
reinterpret_cast<int32_t*>(output_data));
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
// Reroute to the reference version if the given size is not common.
|
// Reroute to the reference version if the given size is not available.
|
||||||
reference_ops::Transpose(params, unextended_input_shape, input_data,
|
reference_ops::Transpose(params, unextended_input_shape, input_data,
|
||||||
unextended_output_shape, output_data);
|
unextended_output_shape, output_data);
|
||||||
}
|
}
|
||||||
|
@ -3008,11 +3008,11 @@ inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline void TransposeImpl(const TransposeParams& params,
|
void TransposeImpl(const TransposeParams& params,
|
||||||
const RuntimeShape& unextended_input_shape,
|
const RuntimeShape& unextended_input_shape,
|
||||||
const T* input_data,
|
const T* input_data,
|
||||||
const RuntimeShape& unextended_output_shape,
|
const RuntimeShape& unextended_output_shape,
|
||||||
T* output_data) {
|
T* output_data) {
|
||||||
const int unextended_output_size = unextended_output_shape.DimensionsCount();
|
const int unextended_output_size = unextended_output_shape.DimensionsCount();
|
||||||
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
|
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
|
||||||
TFLITE_DCHECK_LE(unextended_output_size, 4);
|
TFLITE_DCHECK_LE(unextended_output_size, 4);
|
||||||
|
Loading…
Reference in New Issue
Block a user