Reorganised the code for FloorMod.

This was one of the TODO items.
This commit is contained in:
Amit Srivastava 2019-03-21 03:09:57 +05:30
parent b4ecc7875f
commit 75d6fc0259
2 changed files with 18 additions and 19 deletions

View File

@ -37,24 +37,7 @@ struct OpData {
bool requires_broadcast;
};
struct FloatMod {
float operator()(const float lhs, const float rhs) const {
return std::fmod(lhs, rhs);
}
};
// TODO(b/117912007): Move the implementation to reference_ops.h
// TODO(b/117912880): Support quantization.
template <typename T>
T FloorMod(T input1, T input2) {
using ModFunc = typename std::conditional<std::is_integral<T>::value,
std::modulus<T>, FloatMod>::type;
ModFunc mod_func;
T trunc_mod = mod_func(input1, input2);
return trunc_mod != 0 && (input2 < 0 != trunc_mod < 0) ? trunc_mod + input2
: trunc_mod;
}
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* data = new OpData;
@ -120,12 +103,13 @@ TfLiteStatus EvalImpl(TfLiteContext* context, bool requires_broadcast,
reference_ops::BroadcastBinaryFunction4DSlow<T, T, T>(
GetTensorShape(input1), GetTensorData<T>(input1),
GetTensorShape(input2), denominator_data, GetTensorShape(output),
GetTensorData<T>(output), FloorMod<T>);
GetTensorData<T>(output), reference_ops::FloorMod<T>);
} else {
reference_ops::BinaryFunction<T, T, T>(
GetTensorShape(input1), GetTensorData<T>(input1),
GetTensorShape(input2), GetTensorData<T>(input2),
GetTensorShape(output), GetTensorData<T>(output), FloorMod<T>);
GetTensorShape(output), GetTensorData<T>(output),
reference_ops::FloorMod<T>);
}
return kTfLiteOk;

View File

@ -2933,6 +2933,21 @@ inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
}
}
template <typename T>
T FloorMod(T input1, T input2) {
struct FloatMod {
float operator()(const float lhs, const float rhs) const {
return std::fmod(lhs, rhs);
}
};
using ModFunc = typename std::conditional<std::is_integral<T>::value,
std::modulus<T>, FloatMod>::type;
ModFunc mod_func;
T trunc_mod = mod_func(input1, input2);
return trunc_mod != 0 && (input2 < 0 != trunc_mod < 0) ? trunc_mod + input2
: trunc_mod;
}
inline void Floor(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);