TFL MCU: Move reference ResizeNearestNeighbor implementation into its own file.

so that we won't need to import all the dependencies.

This CL simply copies the existing code into the new file.

PiperOrigin-RevId: 306480426
Change-Id: I672a14c3edeeab975b5e75c18d632c8d71ecead4
This commit is contained in:
A. Unique TensorFlower 2020-04-14 11:32:46 -07:00 committed by TensorFlower Gardener
parent 3cfc29aa36
commit d58a4dfc54
3 changed files with 86 additions and 54 deletions
tensorflow/lite/kernels/internal

View File

@ -472,6 +472,7 @@ cc_library(
"reference/reduce.h",
"reference/reference_ops.h",
"reference/requantize.h",
"reference/resize_nearest_neighbor.h",
"reference/round.h",
"reference/softmax.h",
"reference/sparse_ops/fully_connected.h",
@ -541,6 +542,7 @@ cc_library(
"reference/reduce.h",
"reference/reference_ops.h",
"reference/requantize.h",
"reference/resize_nearest_neighbor.h",
"reference/round.h",
"reference/softmax.h",
"reference/strided_slice.h",

View File

@ -53,6 +53,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/reference/quantize.h"
#include "tensorflow/lite/kernels/internal/reference/reduce.h"
#include "tensorflow/lite/kernels/internal/reference/requantize.h"
#include "tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h"
#include "tensorflow/lite/kernels/internal/reference/round.h"
#include "tensorflow/lite/kernels/internal/reference/softmax.h"
#include "tensorflow/lite/kernels/internal/reference/strided_slice.h"
@ -2459,60 +2460,6 @@ inline void BroadcastPow4DSlow(const RuntimeShape& unextended_input1_shape,
}
}
template <typename T>
inline void ResizeNearestNeighbor(
const tflite::ResizeNearestNeighborParams& op_params,
const RuntimeShape& unextended_input_shape, const T* input_data,
const RuntimeShape& output_size_shape, const int32* output_size_data,
const RuntimeShape& unextended_output_shape, T* output_data) {
// Align corners = true is not supported.
TFLITE_DCHECK(!op_params.align_corners);
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
int32 input_height = input_shape.Dims(1);
int32 input_width = input_shape.Dims(2);
int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
// The Tensorflow version of this op allows resize on the width and height
// axis only.
TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
int32 output_height = output_size_data[0];
int32 output_width = output_size_data[1];
// We use float to ensure agreement with the Tensorflow implementation.
const float height_scale = static_cast<float>(input_height) / output_height;
const float width_scale = static_cast<float>(input_width) / output_width;
const int col_offset = input_shape.Dims(3);
const int row_offset = input_shape.Dims(2) * col_offset;
const int batch_offset = input_shape.Dims(1) * row_offset;
const T* input_ptr = input_data;
T* output_ptr = output_data;
for (int b = 0; b < batches; ++b) {
for (int y = 0; y < output_height; ++y) {
int32 in_y = std::min(static_cast<int32>(std::floor(y * height_scale)),
input_height - 1);
const T* y_input_ptr = input_ptr + in_y * row_offset;
for (int x = 0; x < output_width; ++x) {
int32 in_x = std::min(static_cast<int32>(std::floor(x * width_scale)),
input_width - 1);
const T* x_input_ptr = y_input_ptr + in_x * col_offset;
memcpy(output_ptr, x_input_ptr, depth * sizeof(T));
output_ptr += depth;
}
}
input_ptr += batch_offset;
}
}
template <typename T>
void Fill(const RuntimeShape& value_shape, const T* value_data,
const RuntimeShape& output_shape, T* output_data) {

View File

@ -0,0 +1,83 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_RESIZE_NEAREST_NEIGHBOR_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_RESIZE_NEAREST_NEIGHBOR_H_
#include <cmath>
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace reference_ops {
template <typename T>
inline void ResizeNearestNeighbor(
const tflite::ResizeNearestNeighborParams& op_params,
const RuntimeShape& unextended_input_shape, const T* input_data,
const RuntimeShape& output_size_shape, const int32* output_size_data,
const RuntimeShape& unextended_output_shape, T* output_data) {
// Align corners = true is not supported.
TFLITE_DCHECK(!op_params.align_corners);
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
int32 input_height = input_shape.Dims(1);
int32 input_width = input_shape.Dims(2);
int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
// The Tensorflow version of this op allows resize on the width and height
// axis only.
TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
int32 output_height = output_size_data[0];
int32 output_width = output_size_data[1];
// We use float to ensure agreement with the Tensorflow implementation.
const float height_scale = static_cast<float>(input_height) / output_height;
const float width_scale = static_cast<float>(input_width) / output_width;
const int col_offset = input_shape.Dims(3);
const int row_offset = input_shape.Dims(2) * col_offset;
const int batch_offset = input_shape.Dims(1) * row_offset;
const T* input_ptr = input_data;
T* output_ptr = output_data;
for (int b = 0; b < batches; ++b) {
for (int y = 0; y < output_height; ++y) {
int32 in_y = std::min(static_cast<int32>(std::floor(y * height_scale)),
input_height - 1);
const T* y_input_ptr = input_ptr + in_y * row_offset;
for (int x = 0; x < output_width; ++x) {
int32 in_x = std::min(static_cast<int32>(std::floor(x * width_scale)),
input_width - 1);
const T* x_input_ptr = y_input_ptr + in_x * col_offset;
memcpy(output_ptr, x_input_ptr, depth * sizeof(T));
output_ptr += depth;
}
}
input_ptr += batch_offset;
}
}
} // namespace reference_ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_RESIZE_NEAREST_NEIGHBOR_H_