diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index b0a5cb9e8db..728d386058f 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -465,6 +465,7 @@ cc_library( "reference/exp.h", "reference/fill.h", "reference/floor.h", + "reference/floor_mod.h", "reference/fully_connected.h", "reference/hard_swish.h", "reference/integer_ops/add.h", @@ -571,6 +572,7 @@ cc_library( "reference/exp.h", "reference/fill.h", "reference/floor.h", + "reference/floor_mod.h", "reference/fully_connected.h", "reference/hard_swish.h", "reference/l2normalization.h", diff --git a/tensorflow/lite/kernels/internal/reference/floor_mod.h b/tensorflow/lite/kernels/internal/reference/floor_mod.h new file mode 100644 index 00000000000..b1fe1705011 --- /dev/null +++ b/tensorflow/lite/kernels/internal/reference/floor_mod.h @@ -0,0 +1,43 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_FLOOR_MOD_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_FLOOR_MOD_H_ + +#include + +namespace tflite { + +namespace reference_ops { + +template +T FloorMod(T input1, T input2) { + struct FloatMod { + float operator()(const float lhs, const float rhs) const { + return std::fmod(lhs, rhs); + } + }; + using ModFunc = typename std::conditional::value, + std::modulus, FloatMod>::type; + ModFunc mod_func; + T trunc_mod = mod_func(input1, input2); + return (trunc_mod != 0) && ((input2 < 0) != (trunc_mod < 0)) + ? (trunc_mod + input2) + : trunc_mod; +} + +} // namespace reference_ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_FLOOR_MOD_H_ diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h index 9b05a4fef0e..a7270e876a0 100644 --- a/tensorflow/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -44,6 +44,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/reference/exp.h" #include "tensorflow/lite/kernels/internal/reference/fill.h" #include "tensorflow/lite/kernels/internal/reference/floor.h" +#include "tensorflow/lite/kernels/internal/reference/floor_mod.h" #include "tensorflow/lite/kernels/internal/reference/fully_connected.h" #include "tensorflow/lite/kernels/internal/reference/hard_swish.h" #include "tensorflow/lite/kernels/internal/reference/l2normalization.h" @@ -1174,22 +1175,6 @@ inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data, } } -template -T FloorMod(T input1, T input2) { - struct FloatMod { - float operator()(const float lhs, const float rhs) const { - return std::fmod(lhs, rhs); - } - }; - using ModFunc = typename std::conditional::value, - std::modulus, FloatMod>::type; - ModFunc mod_func; - T trunc_mod = mod_func(input1, input2); - return (trunc_mod != 0) && ((input2 < 0) != (trunc_mod < 0)) - ? (trunc_mod + input2) - : trunc_mod; -} - template inline void Gather(const tflite::GatherParams& op_params, const RuntimeShape& input_shape, const T* input_data,