Port ArgMax for backward compatibility
This commit is contained in:
parent
52ccf9d95c
commit
d41c879e5c
@ -7941,6 +7941,25 @@ inline void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data,
|
||||
output_data, is_arg_max);
|
||||
}
|
||||
|
||||
template <typename T1, typename T2, typename T3>
|
||||
void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
|
||||
const T3* input2_data, const RuntimeShape& output_shape,
|
||||
T2* output_data) {
|
||||
ArgMinMax(input1_shape, input1_data, input2_data, output_shape, output_data,
|
||||
/*is_arg_max=*/true);
|
||||
}
|
||||
|
||||
// Convenience version that allows, for example, generated-code calls to be
|
||||
// the same as other binary ops.
|
||||
// For backward compatibility, reference_ops has ArgMax function.
|
||||
template <typename T1, typename T2, typename T3>
|
||||
inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
|
||||
const RuntimeShape& input2_shape, const T3* input2_data,
|
||||
const RuntimeShape& output_shape, T2* output_data) {
|
||||
// Drop shape of second input: not needed.
|
||||
ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data);
|
||||
}
|
||||
|
||||
} // namespace optimized_ops
|
||||
} // namespace tflite
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user