diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index 48e13d52986..b59dc0a9fff 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -457,6 +457,7 @@ cc_library( "reference/depthwiseconv_float.h", "reference/depthwiseconv_uint8.h", "reference/dequantize.h", + "reference/fill.h", "reference/floor.h", "reference/fully_connected.h", "reference/hard_swish.h", @@ -550,6 +551,7 @@ cc_library( "reference/depthwiseconv_float.h", "reference/depthwiseconv_uint8.h", "reference/dequantize.h", + "reference/fill.h", "reference/floor.h", "reference/fully_connected.h", "reference/hard_swish.h", diff --git a/tensorflow/lite/kernels/internal/reference/fill.h b/tensorflow/lite/kernels/internal/reference/fill.h new file mode 100644 index 00000000000..16630e617d1 --- /dev/null +++ b/tensorflow/lite/kernels/internal/reference/fill.h @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_FILL_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_FILL_H_ + +#include + +#include "tensorflow/lite/kernels/internal/types.h" + +namespace tflite { +namespace reference_ops { + +template +void Fill(const RuntimeShape& value_shape, const T* value_data, + const RuntimeShape& output_shape, T* output_data) { + TFLITE_DCHECK_EQ(value_shape.DimensionsCount(), 0); + const int flat_size = output_shape.FlatSize(); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = *value_data; + } +} + +} // namespace reference_ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_FILL_H_ diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h index fb2b42457e7..afbb717bc9f 100644 --- a/tensorflow/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/reference/concatenation.h" #include "tensorflow/lite/kernels/internal/reference/conv.h" #include "tensorflow/lite/kernels/internal/reference/dequantize.h" +#include "tensorflow/lite/kernels/internal/reference/fill.h" #include "tensorflow/lite/kernels/internal/reference/floor.h" #include "tensorflow/lite/kernels/internal/reference/fully_connected.h" #include "tensorflow/lite/kernels/internal/reference/hard_swish.h" @@ -2496,16 +2497,6 @@ inline void BroadcastPow4DSlow(const RuntimeShape& unextended_input1_shape, } } -template -void Fill(const RuntimeShape& value_shape, const T* value_data, - const RuntimeShape& output_shape, T* output_data) { - TFLITE_DCHECK_EQ(value_shape.DimensionsCount(), 0); - const int flat_size = output_shape.FlatSize(); - for (int i = 0; i < flat_size; ++i) { - output_data[i] = *value_data; - } -} - template void Reverse(int axis, const RuntimeShape& input_shape, const Scalar* input_data, const RuntimeShape& output_shape,