Resolve conflicts in legacy ops

This commit is contained in:
Tzu-Wei Sung 2021-03-05 09:04:14 -08:00
parent 9f9e52b94f
commit 52ccf9d95c
2 changed files with 33 additions and 0 deletions

View File

@ -4964,6 +4964,31 @@ inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims), output_data);
}
template <typename T1, typename T2, typename T3>
void ArgMax(const T3* axis, const T1* input_data,
const tflite::Dims<4>& input_dims, T2* output_data,
const tflite::Dims<4>& output_dims) {
// Assumes the input always has 4 dimensions, and therefore,
// output always has three dimensions.
auto output_shape = RuntimeShape(
{output_dims.sizes[2], output_dims.sizes[1], output_dims.sizes[0]});
// Another way to interpret this is that output_dims.sizes[4] is always 1.
TFLITE_DCHECK_EQ(output_shape.FlatSize(),
DimsToShape(output_dims).FlatSize());
// Legacy path only supported this.
TFLITE_DCHECK_EQ(axis[0], 3);
ArgMinMax(DimsToShape(input_dims), input_data, axis, output_shape,
output_data, /*is_arg_max=*/true);
}
template <typename T1, typename T2, typename T3>
void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
T2* output_data, const Dims<4>& output_dims,
const bool is_arg_max) {
ArgMinMax(axis, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
output_data, is_arg_max);
}
} // namespace optimized_ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_

View File

@ -2146,6 +2146,14 @@ void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
output_data, cmp);
}
template <typename T1, typename T2, typename T3>
void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
T2* output_data, const Dims<4>& output_dims,
const bool is_arg_max) {
ArgMinMax(axis, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
output_data, is_arg_max);
}
template <typename T>
inline void Pow(const T* input1_data, const Dims<4>& input1_dims,
const T* input2_data, const Dims<4>& input2_dims,