Reorganised the code for FloorMod.
This was one of the TODO items.
This commit is contained in:
parent
b4ecc7875f
commit
75d6fc0259
@ -37,24 +37,7 @@ struct OpData {
|
||||
bool requires_broadcast;
|
||||
};
|
||||
|
||||
struct FloatMod {
|
||||
float operator()(const float lhs, const float rhs) const {
|
||||
return std::fmod(lhs, rhs);
|
||||
}
|
||||
};
|
||||
|
||||
// TODO(b/117912007): Move the implementation to reference_ops.h
|
||||
// TODO(b/117912880): Support quantization.
|
||||
template <typename T>
|
||||
T FloorMod(T input1, T input2) {
|
||||
using ModFunc = typename std::conditional<std::is_integral<T>::value,
|
||||
std::modulus<T>, FloatMod>::type;
|
||||
|
||||
ModFunc mod_func;
|
||||
T trunc_mod = mod_func(input1, input2);
|
||||
return trunc_mod != 0 && (input2 < 0 != trunc_mod < 0) ? trunc_mod + input2
|
||||
: trunc_mod;
|
||||
}
|
||||
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
auto* data = new OpData;
|
||||
@ -120,12 +103,13 @@ TfLiteStatus EvalImpl(TfLiteContext* context, bool requires_broadcast,
|
||||
reference_ops::BroadcastBinaryFunction4DSlow<T, T, T>(
|
||||
GetTensorShape(input1), GetTensorData<T>(input1),
|
||||
GetTensorShape(input2), denominator_data, GetTensorShape(output),
|
||||
GetTensorData<T>(output), FloorMod<T>);
|
||||
GetTensorData<T>(output), reference_ops::FloorMod<T>);
|
||||
} else {
|
||||
reference_ops::BinaryFunction<T, T, T>(
|
||||
GetTensorShape(input1), GetTensorData<T>(input1),
|
||||
GetTensorShape(input2), GetTensorData<T>(input2),
|
||||
GetTensorShape(output), GetTensorData<T>(output), FloorMod<T>);
|
||||
GetTensorShape(output), GetTensorData<T>(output),
|
||||
reference_ops::FloorMod<T>);
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
|
@ -2933,6 +2933,21 @@ inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T FloorMod(T input1, T input2) {
|
||||
struct FloatMod {
|
||||
float operator()(const float lhs, const float rhs) const {
|
||||
return std::fmod(lhs, rhs);
|
||||
}
|
||||
};
|
||||
using ModFunc = typename std::conditional<std::is_integral<T>::value,
|
||||
std::modulus<T>, FloatMod>::type;
|
||||
ModFunc mod_func;
|
||||
T trunc_mod = mod_func(input1, input2);
|
||||
return trunc_mod != 0 && (input2 < 0 != trunc_mod < 0) ? trunc_mod + input2
|
||||
: trunc_mod;
|
||||
}
|
||||
|
||||
inline void Floor(const RuntimeShape& input_shape, const float* input_data,
|
||||
const RuntimeShape& output_shape, float* output_data) {
|
||||
const int flat_size = MatchingFlatSize(input_shape, output_shape);
|
||||
|
Loading…
Reference in New Issue
Block a user