diff --git a/tensorflow/lite/kernels/floor_div.cc b/tensorflow/lite/kernels/floor_div.cc index 19b806c91d5..2146a8723d7 100644 --- a/tensorflow/lite/kernels/floor_div.cc +++ b/tensorflow/lite/kernels/floor_div.cc @@ -41,12 +41,6 @@ struct OpData { bool requires_broadcast; }; -template -T FloorDiv(T input1, T input2) { - return std::floor(std::divides()(static_cast(input1), - static_cast(input2))); -} - void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* data = new OpData; data->requires_broadcast = false; @@ -118,12 +112,13 @@ TfLiteStatus EvalImpl(TfLiteContext* context, bool requires_broadcast, reference_ops::BroadcastBinaryFunction4DSlow( GetTensorShape(input1), GetTensorData(input1), GetTensorShape(input2), denominator_data, GetTensorShape(output), - GetTensorData(output), FloorDiv); + GetTensorData(output), reference_ops::FloorDiv); } else { reference_ops::BinaryFunction( GetTensorShape(input1), GetTensorData(input1), GetTensorShape(input2), GetTensorData(input2), - GetTensorShape(output), GetTensorData(output), FloorDiv); + GetTensorShape(output), GetTensorData(output), + reference_ops::FloorDiv); } return kTfLiteOk; diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index b4410812df4..075150879de 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -469,6 +469,7 @@ cc_library( "reference/exp.h", "reference/fill.h", "reference/floor.h", + "reference/floor_div.h", "reference/floor_mod.h", "reference/fully_connected.h", "reference/hard_swish.h", @@ -583,6 +584,7 @@ cc_library( "reference/exp.h", "reference/fill.h", "reference/floor.h", + "reference/floor_div.h", "reference/floor_mod.h", "reference/fully_connected.h", "reference/hard_swish.h", diff --git a/tensorflow/lite/kernels/internal/reference/floor_div.h b/tensorflow/lite/kernels/internal/reference/floor_div.h new file mode 100644 index 00000000000..e75d473cf0b --- /dev/null +++ b/tensorflow/lite/kernels/internal/reference/floor_div.h @@ -0,0 +1,35 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_FLOOR_DIV_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_FLOOR_DIV_H_ + +#include +#include + +#include "tensorflow/lite/kernels/internal/types.h" + +namespace tflite { +namespace reference_ops { + +template +T FloorDiv(T input1, T input2) { + return std::floor(std::divides()(static_cast(input1), + static_cast(input2))); +} + +} // namespace reference_ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_FLOOR_DIV_H_ diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h index e69982625d0..08a259e9d63 100644 --- a/tensorflow/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -48,6 +48,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/reference/exp.h" #include "tensorflow/lite/kernels/internal/reference/fill.h" #include "tensorflow/lite/kernels/internal/reference/floor.h" +#include "tensorflow/lite/kernels/internal/reference/floor_div.h" #include "tensorflow/lite/kernels/internal/reference/floor_mod.h" #include "tensorflow/lite/kernels/internal/reference/fully_connected.h" #include "tensorflow/lite/kernels/internal/reference/hard_swish.h"