Port ArgMax for backward compatibility

This commit is contained in:
Tzu-Wei Sung 2021-03-05 09:05:14 -08:00
parent 52ccf9d95c
commit d41c879e5c

View File

@ -7941,6 +7941,25 @@ inline void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data,
output_data, is_arg_max);
}
template <typename T1, typename T2, typename T3>
void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
const T3* input2_data, const RuntimeShape& output_shape,
T2* output_data) {
ArgMinMax(input1_shape, input1_data, input2_data, output_shape, output_data,
/*is_arg_max=*/true);
}
// Convenience version that allows, for example, generated-code calls to be
// the same as other binary ops.
// For backward compatibility, reference_ops has ArgMax function.
template <typename T1, typename T2, typename T3>
inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
const RuntimeShape& input2_shape, const T3* input2_data,
const RuntimeShape& output_shape, T2* output_data) {
// Drop shape of second input: not needed.
ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data);
}
} // namespace optimized_ops
} // namespace tflite