Port pad op to micro.

PiperOrigin-RevId: 284832820
Change-Id: Id7fcff7cf8f5c748aa835932405fa461ecec8b6f
Author: Nat Jeffries 2019-12-10 13:09:52 -08:00, committed by TensorFlower Gardener
Parent: 635a3a0f82
Commit: 840cabea48
11 changed files with 890 additions and 167 deletions

View File

@@ -439,6 +439,7 @@ cc_library(
"reference/mul.h",
"reference/neg.h",
"reference/non_max_suppression.h",
"reference/pad.h",
"reference/pooling.h",
"reference/prelu.h",
"reference/process_broadcast_shapes.h",
@@ -499,6 +500,7 @@ cc_library(
"reference/maximum_minimum.h",
"reference/mul.h",
"reference/neg.h",
"reference/pad.h",
"reference/pooling.h",
"reference/prelu.h",
"reference/process_broadcast_shapes.h",

View File

@@ -0,0 +1,184 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PAD_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PAD_H_
#include <vector>
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace reference_ops {
// TFLite Pad supports activation tensors with up to 4 dimensions.
constexpr int PadKernelMaxDimensionCount() { return 4; }
// There are two versions of pad: Pad and PadV2. In PadV2 there is a second
// scalar input that provides the padding value. Therefore pad_value_ptr can be
// equivalent to a simple input1_data. For Pad, it should point to a zero
// value.
//
// Note that two typenames are required, so that T=P=int32 is considered a
// specialization distinct from P=int32.
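// For example, a uint8 input with an int32 pad value selects the converting
// Pad overload below, while an int32 input with an int32 pad value selects
// the explicit specialization.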
template <typename T, typename P>
inline void PadImpl(const tflite::PadParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
const P* pad_value_ptr, const RuntimeShape& output_shape,
T* output_data) {
const RuntimeShape ext_input_shape =
RuntimeShape::ExtendedShape(PadKernelMaxDimensionCount(), input_shape);
const RuntimeShape ext_output_shape =
RuntimeShape::ExtendedShape(PadKernelMaxDimensionCount(), output_shape);
TFLITE_DCHECK_LE(op_params.left_padding_count, PadKernelMaxDimensionCount());
TFLITE_DCHECK_LE(op_params.right_padding_count, PadKernelMaxDimensionCount());
// Runtime calls are currently fixed at 4 dimensions. Copy inputs so we can
// pad them to 4 dims (yes, we are "padding the padding").
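// For example, left_padding = {1, 2} with left_padding_count = 2 becomes
// left_padding_copy = {0, 0, 1, 2}.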
int left_padding_copy[PadKernelMaxDimensionCount()];
for (int i = 0; i < PadKernelMaxDimensionCount(); i++) {
left_padding_copy[i] = 0;
}
for (int i = 0; i < op_params.left_padding_count; ++i) {
left_padding_copy[i + PadKernelMaxDimensionCount() -
op_params.left_padding_count] = op_params.left_padding[i];
}
int right_padding_copy[PadKernelMaxDimensionCount()];
for (int i = 0; i < PadKernelMaxDimensionCount(); i++) {
right_padding_copy[i] = 0;
}
for (int i = 0; i < op_params.right_padding_count; ++i) {
right_padding_copy[i + PadKernelMaxDimensionCount() -
op_params.right_padding_count] =
op_params.right_padding[i];
}
const int output_batch = ext_output_shape.Dims(0);
const int output_height = ext_output_shape.Dims(1);
const int output_width = ext_output_shape.Dims(2);
const int output_depth = ext_output_shape.Dims(3);
const int left_b_padding = left_padding_copy[0];
const int left_h_padding = left_padding_copy[1];
const int left_w_padding = left_padding_copy[2];
const int left_d_padding = left_padding_copy[3];
const int right_b_padding = right_padding_copy[0];
const int right_h_padding = right_padding_copy[1];
const int right_w_padding = right_padding_copy[2];
const int right_d_padding = right_padding_copy[3];
const T pad_value = *pad_value_ptr;
const T* in_ptr = input_data;
T* out_ptr = output_data;
for (int out_b = 0; out_b < output_batch; ++out_b) {
for (int out_h = 0; out_h < output_height; ++out_h) {
for (int out_w = 0; out_w < output_width; ++out_w) {
for (int out_d = 0; out_d < output_depth; ++out_d) {
if (out_b < left_b_padding ||
out_b >= output_batch - right_b_padding ||
out_h < left_h_padding ||
out_h >= output_height - right_h_padding ||
out_w < left_w_padding ||
out_w >= output_width - right_w_padding ||
out_d < left_d_padding ||
out_d >= output_depth - right_d_padding) {
*out_ptr++ = pad_value;
} else {
*out_ptr++ = *in_ptr++;
}
}
}
}
}
}
template <typename T, typename P>
inline void Pad(const tflite::PadParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
const P* pad_value_ptr, const RuntimeShape& output_shape,
T* output_data) {
PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
output_data);
}
// The second (pad-value) input can be int32 when, say, the first is uint8.
template <typename T>
inline void Pad(const tflite::PadParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
const int32* pad_value_ptr, const RuntimeShape& output_shape,
T* output_data) {
const T converted_pad_value = static_cast<T>(*pad_value_ptr);
PadImpl(op_params, input_shape, input_data, &converted_pad_value,
output_shape, output_data);
}
// This version avoids conflicting template matching.
template <>
inline void Pad(const tflite::PadParams& op_params,
const RuntimeShape& input_shape, const int32* input_data,
const int32* pad_value_ptr, const RuntimeShape& output_shape,
int32* output_data) {
PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
output_data);
}
// One could make all PadImageStyle calls simply delegate the work to the
// ordinary Pad. However, it is better that the reference code asserts false in
// similar cases.
template <typename T, typename P>
inline void PadImageStyle(const tflite::PadParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
const P* pad_value_ptr,
const RuntimeShape& output_shape, T* output_data) {
TFLITE_ASSERT_FALSE;
}
template <typename P>
inline void PadImageStyle(const tflite::PadParams& op_params,
const RuntimeShape& input_shape,
const uint8* input_data, const P* pad_value_ptr,
const RuntimeShape& output_shape,
uint8* output_data) {
Pad(op_params, input_shape, input_data, pad_value_ptr, output_shape,
output_data);
}
template <typename P>
inline void PadImageStyle(const tflite::PadParams& op_params,
const RuntimeShape& input_shape,
const int8_t* input_data, const P* pad_value_ptr,
const RuntimeShape& output_shape,
int8_t* output_data) {
Pad(op_params, input_shape, input_data, pad_value_ptr, output_shape,
output_data);
}
template <typename P>
inline void PadImageStyle(const tflite::PadParams& op_params,
const RuntimeShape& input_shape,
const float* input_data, const P* pad_value_ptr,
const RuntimeShape& output_shape,
float* output_data) {
Pad(op_params, input_shape, input_data, pad_value_ptr, output_shape,
output_data);
}
} // namespace reference_ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PAD_H_
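For orientation, here is a minimal sketch (not part of the commit; PadExample is a hypothetical function) of how a caller might drive the header above. The shapes and values mirror the Test2DFloat case in the micro pad_test.cc added later in this commit:

// Sketch only: pad a {1, 2, 2, 1} float tensor by one unit on each side of
// the batch and width dimensions, producing a {3, 2, 4, 1} output of zeros
// with the input embedded in the middle.
#include <cstdint>

#include "tensorflow/lite/kernels/internal/reference/pad.h"

void PadExample() {
  tflite::PadParams op_params;
  op_params.left_padding_count = 4;
  op_params.right_padding_count = 4;
  const int left[4] = {1, 0, 1, 0};   // batch, height, width, depth
  const int right[4] = {1, 0, 1, 0};
  for (int i = 0; i < 4; ++i) {
    op_params.left_padding[i] = left[i];
    op_params.right_padding[i] = right[i];
  }
  const float input[4] = {1, 2, 3, 4};  // shape {1, 2, 2, 1}
  float output[24];                     // shape {3, 2, 4, 1}
  const float pad_value = 0.f;          // Pad (not PadV2) pads with zero.
  const int32_t in_dims[4] = {1, 2, 2, 1};
  const int32_t out_dims[4] = {3, 2, 4, 1};
  tflite::reference_ops::Pad(op_params, tflite::RuntimeShape(4, in_dims),
                             input, &pad_value,
                             tflite::RuntimeShape(4, out_dims), output);
}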

View File

@@ -45,6 +45,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/reference/maximum_minimum.h"
#include "tensorflow/lite/kernels/internal/reference/mul.h"
#include "tensorflow/lite/kernels/internal/reference/neg.h"
#include "tensorflow/lite/kernels/internal/reference/pad.h"
#include "tensorflow/lite/kernels/internal/reference/pooling.h"
#include "tensorflow/lite/kernels/internal/reference/prelu.h"
#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
@@ -2148,150 +2149,6 @@ inline void BatchToSpaceND(
}
}
// There are two versions of pad: Pad and PadV2. In PadV2 there is a second
// scalar input that provides the padding value. Therefore pad_value_ptr can be
// equivalent to a simple input1_data. For Pad, it should point to a zero
// value.
//
// Note that two typenames are required, so that T=P=int32 is considered a
// specialization distinct from P=int32.
template <typename T, typename P>
inline void PadImpl(const tflite::PadParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
const P* pad_value_ptr, const RuntimeShape& output_shape,
T* output_data) {
const RuntimeShape ext_input_shape =
RuntimeShape::ExtendedShape(4, input_shape);
const RuntimeShape ext_output_shape =
RuntimeShape::ExtendedShape(4, output_shape);
TFLITE_DCHECK_LE(op_params.left_padding_count, 4);
TFLITE_DCHECK_LE(op_params.right_padding_count, 4);
// Runtime calls are currently fixed at 4 dimensions. Copy inputs so
// we can pad them to 4 dims (yes, we are "padding the padding").
std::vector<int> left_padding_copy(4, 0);
for (int i = 0; i < op_params.left_padding_count; ++i) {
left_padding_copy[i + 4 - op_params.left_padding_count] =
op_params.left_padding[i];
}
std::vector<int> right_padding_copy(4, 0);
for (int i = 0; i < op_params.right_padding_count; ++i) {
right_padding_copy[i + 4 - op_params.right_padding_count] =
op_params.right_padding[i];
}
const int output_batch = ext_output_shape.Dims(0);
const int output_height = ext_output_shape.Dims(1);
const int output_width = ext_output_shape.Dims(2);
const int output_depth = ext_output_shape.Dims(3);
const int left_b_padding = left_padding_copy[0];
const int left_h_padding = left_padding_copy[1];
const int left_w_padding = left_padding_copy[2];
const int left_d_padding = left_padding_copy[3];
const int right_b_padding = right_padding_copy[0];
const int right_h_padding = right_padding_copy[1];
const int right_w_padding = right_padding_copy[2];
const int right_d_padding = right_padding_copy[3];
const T pad_value = *pad_value_ptr;
const T* in_ptr = input_data;
T* out_ptr = output_data;
for (int out_b = 0; out_b < output_batch; ++out_b) {
for (int out_h = 0; out_h < output_height; ++out_h) {
for (int out_w = 0; out_w < output_width; ++out_w) {
for (int out_d = 0; out_d < output_depth; ++out_d) {
if (out_b < left_b_padding ||
out_b >= output_batch - right_b_padding ||
out_h < left_h_padding ||
out_h >= output_height - right_h_padding ||
out_w < left_w_padding ||
out_w >= output_width - right_w_padding ||
out_d < left_d_padding ||
out_d >= output_depth - right_d_padding) {
*out_ptr++ = pad_value;
} else {
*out_ptr++ = *in_ptr++;
}
}
}
}
}
}
template <typename T, typename P>
inline void Pad(const tflite::PadParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
const P* pad_value_ptr, const RuntimeShape& output_shape,
T* output_data) {
PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
output_data);
}
// The second (pad-value) input can be int32 when, say, the first is uint8.
template <typename T>
inline void Pad(const tflite::PadParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
const int32* pad_value_ptr, const RuntimeShape& output_shape,
T* output_data) {
const T converted_pad_value = static_cast<T>(*pad_value_ptr);
PadImpl(op_params, input_shape, input_data, &converted_pad_value,
output_shape, output_data);
}
// This version avoids conflicting template matching.
template <>
inline void Pad(const tflite::PadParams& op_params,
const RuntimeShape& input_shape, const int32* input_data,
const int32* pad_value_ptr, const RuntimeShape& output_shape,
int32* output_data) {
PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
output_data);
}
// One could make all PadImageStyle calls simply delegate the work to the
// ordinary Pad. However, it is better that the reference code asserts false in
// similar cases.
template <typename T, typename P>
inline void PadImageStyle(const tflite::PadParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
const P* pad_value_ptr,
const RuntimeShape& output_shape, T* output_data) {
TFLITE_ASSERT_FALSE;
}
template <typename P>
inline void PadImageStyle(const tflite::PadParams& op_params,
const RuntimeShape& input_shape,
const uint8* input_data, const P* pad_value_ptr,
const RuntimeShape& output_shape,
uint8* output_data) {
Pad(op_params, input_shape, input_data, pad_value_ptr, output_shape,
output_data);
}
template <typename P>
inline void PadImageStyle(const tflite::PadParams& op_params,
const RuntimeShape& input_shape,
const int8_t* input_data, const P* pad_value_ptr,
const RuntimeShape& output_shape,
int8_t* output_data) {
Pad(op_params, input_shape, input_data, pad_value_ptr, output_shape,
output_data);
}
template <typename P>
inline void PadImageStyle(const tflite::PadParams& op_params,
const RuntimeShape& input_shape,
const float* input_data, const P* pad_value_ptr,
const RuntimeShape& output_shape,
float* output_data) {
Pad(op_params, input_shape, input_data, pad_value_ptr, output_shape,
output_data);
}
template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape,

View File

@@ -107,7 +107,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
// TODO(nupurgarg): Current implementations rely on the inputs being <= 4D.
TF_LITE_ENSURE(context, op_context.dims <= 4);
TF_LITE_ENSURE(
context, op_context.dims <= reference_ops::PadKernelMaxDimensionCount());
// Exit early if paddings is a non-const tensor. Set output tensor to
// dynamic so output size can be determined in Eval.
@@ -132,32 +133,22 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
}
// TODO(nupurgarg): Change kernel implementation to take in int* instead of
// vector<int> to remove malloc from Eval().
// Create before and after padding arrays that are accepted by the kernel.
std::vector<int> before_padding;
std::vector<int> after_padding;
const int32* paddings_data = GetTensorData<int32>(op_context.paddings);
// TODO(nupurgarg): Change kernel implementation to use padding arrays in
// forward order (depth, width, height, batch).
// Build paddings in order of int[] = {batch, height, width, depth} to match
// kernel implementation of Pad in reference_ops.h and optimized_ops.h.
TF_LITE_ENSURE(
context, op_context.dims <= reference_ops::PadKernelMaxDimensionCount());
tflite::PadParams op_params;
op_params.left_padding_count = op_context.dims;
op_params.right_padding_count = op_context.dims;
for (int idx = op_context.dims - 1; idx >= 0; --idx) {
before_padding.push_back(paddings_data[idx * 2]);
after_padding.push_back(paddings_data[idx * 2 + 1]);
op_params.left_padding[idx] = paddings_data[idx * 2];
op_params.right_padding[idx] = paddings_data[idx * 2 + 1];
}
#define TF_LITE_PAD(type, op_name, scalar, pad_value) \
TF_LITE_ENSURE(context, before_padding.size() <= 4); \
TF_LITE_ENSURE(context, after_padding.size() <= 4); \
tflite::PadParams op_params; \
op_params.left_padding_count = before_padding.size(); \
op_params.right_padding_count = after_padding.size(); \
for (int i = 0; i < op_context.dims; ++i) { \
op_params.left_padding[i] = before_padding[op_context.dims - 1 - i]; \
op_params.right_padding[i] = after_padding[op_context.dims - 1 - i]; \
} \
const scalar pad_value_copy = pad_value; \
\
type::op_name(op_params, GetTensorShape(op_context.input), \

View File

@@ -189,7 +189,7 @@ TEST(PadOpTest, TooManyDimensions) {
PadOpConstModel({TensorType_FLOAT32, {1, 2, 3, 4, 5, 6, 7, 8, 9}}, {9, 2},
{1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9},
{TensorType_FLOAT32}),
"dims <= 4");
"dims <= reference_ops::PadKernelMaxDimensionCount()");
}
TEST(PadOpTest, UnequalDimensions) {
@@ -426,7 +426,7 @@ TEST(PadV2OpTest, TooManyDimensions) {
EXPECT_DEATH(f({TensorType_FLOAT32, {1, 2, 3, 4, 5, 6, 7, 8, 9}}, {9, 2},
{1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9}, 0.0,
{TensorType_FLOAT32}),
"dims <= 4");
"dims <= reference_ops::PadKernelMaxDimensionCount()");
}
TEST(PadV2OpTest, UnequalDimensions) {

View File

@@ -11,6 +11,7 @@ package(
licenses = ["notice"], # Apache 2.0
)
# LINT.IfChange(micro_ops)
cc_library(
name = "micro_ops",
srcs = [
@@ -32,6 +33,7 @@ cc_library(
"mul.cc",
"neg.cc",
"pack.cc",
"pad.cc",
"pooling.cc",
"prelu.cc",
"quantize.cc",
@@ -60,6 +62,7 @@ cc_library(
"//tensorflow/lite/micro:micro_utils",
],
)
# LINT.ThenChange(//tensorflow/lite/micro/kernels/BUILD:portable_optimized_micro_ops)
cc_library(
name = "all_ops_resolver",
@@ -76,6 +79,7 @@ cc_library(
],
)
# LINT.IfChange(portable_optimized_micro_ops)
cc_library(
name = "portable_optimized_micro_ops",
srcs = [
@@ -96,6 +100,7 @@ cc_library(
"mul.cc",
"neg.cc",
"pack.cc",
"pad.cc",
"pooling.cc",
"portable_optimized/depthwise_conv.cc",
"prelu.cc",
@@ -120,13 +125,13 @@ cc_library(
"//tensorflow/lite/kernels/internal:common",
"//tensorflow/lite/kernels/internal:quantization_util",
"//tensorflow/lite/kernels/internal:reference_base",
"//tensorflow/lite/kernels/internal:strided_slice_logic",
"//tensorflow/lite/kernels/internal:tensor",
"//tensorflow/lite/kernels/internal:types",
"//tensorflow/lite/micro:micro_utils",
],
)
# LINT.ThenChange(//tensorflow/lite/micro/kernels/BUILD:micro_ops)
cc_library(
name = "portable_optimized_ops_resolver",
srcs = [
@@ -550,3 +555,16 @@ tflite_micro_cc_test(
"//tensorflow/lite/micro/testing:micro_test",
],
)
tflite_micro_cc_test(
name = "pad_test",
srcs = [
"pad_test.cc",
],
deps = [
":all_ops_resolver",
"//tensorflow/lite/c:common",
"//tensorflow/lite/micro:micro_framework",
"//tensorflow/lite/micro/testing:micro_test",
],
)
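With the pad_test target above in place, the new test can be run through Bazel (assuming a configured TensorFlow workspace):

bazel test //tensorflow/lite/micro/kernels:pad_test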

View File

@@ -58,6 +58,8 @@ AllOpsResolver::AllOpsResolver() {
AddBuiltin(BuiltinOperator_ROUND, Register_ROUND());
AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE());
AddBuiltin(BuiltinOperator_PACK, Register_PACK());
AddBuiltin(BuiltinOperator_PAD, Register_PAD());
AddBuiltin(BuiltinOperator_PADV2, Register_PADV2());
AddBuiltin(BuiltinOperator_SPLIT, Register_SPLIT(), 1, 3);
AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK());
AddBuiltin(BuiltinOperator_NEG, Register_NEG());

View File

@@ -59,6 +59,7 @@ TfLiteRegistration* Register_MUL();
TfLiteRegistration* Register_NEG();
TfLiteRegistration* Register_NOT_EQUAL();
TfLiteRegistration* Register_PACK();
TfLiteRegistration* Register_PAD();
TfLiteRegistration* Register_PADV2();
TfLiteRegistration* Register_PRELU();
TfLiteRegistration* Register_QUANTIZE();
TfLiteRegistration* Register_RELU();

View File

@@ -0,0 +1,222 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/pad.h"
#include <string.h>
#include <limits>
#include "tensorflow/lite/kernels/internal/types.h"
#ifdef MEMORY_SANITIZER
#include <sanitizer/msan_interface.h>
#else
#define __msan_check_mem_is_initialized(ptr, size)
#endif
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
namespace tflite {
namespace ops {
namespace micro {
namespace pad {
struct PadContext {
PadContext(TfLiteContext* context, TfLiteNode* node) {
input = GetInput(context, node, 0);
paddings = GetInput(context, node, 1);
if (NumInputs(node) == 3) {
constant_values = GetOptionalInputTensor(context, node, 2);
} else {
constant_values = nullptr;
}
output = GetOutput(context, node, 0);
dims = NumDimensions(input);
resizing_category = ResizingCategory::kGenericResize;
const int paddings_total = GetTensorShape(paddings).FlatSize();
const int32* paddings_data = GetTensorData<int32>(paddings);
// Paddings will be an n,2 array, and we need to detect 4D arrays with the
// pattern { {0,0}, {a, b}, {c, d}, {0,0} }.
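// For example, paddings_data = {0, 0, 1, 1, 2, 2, 0, 0} encodes
// { {0,0}, {1,1}, {2,2}, {0,0} } and is classified as image-style.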
if (IsConstantTensor(paddings) && paddings_total == 8 &&
(paddings_data[0] == 0 && paddings_data[1] == 0) &&
(paddings_data[6] == 0 && paddings_data[7] == 0)) {
resizing_category = ResizingCategory::kImageStyle;
}
}
const TfLiteTensor* constant_values;
const TfLiteTensor* input;
const TfLiteTensor* paddings;
TfLiteTensor* output;
int dims;
ResizingCategory resizing_category;
};
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
PadContext op_context(context, node);
TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
if (op_context.constant_values != nullptr) {
TF_LITE_ENSURE_EQ(context, op_context.input->type,
op_context.constant_values->type);
}
// There must be a pair of paddings for each output dimension.
TF_LITE_ENSURE_EQ(context, GetTensorShape(op_context.paddings).FlatSize(),
op_context.output->dims->size * 2);
// On Micro, outputs must be properly sized by the converter.
const int32* paddings_data = GetTensorData<int32>(op_context.paddings);
for (int i = 0; i < op_context.output->dims->size; i++) {
int output_dim = op_context.output->dims->data[i];
int expected_dim = op_context.input->dims->data[i] + paddings_data[i * 2] +
paddings_data[i * 2 + 1];
TF_LITE_ENSURE_EQ(context, output_dim, expected_dim);
}
// Current implementations rely on the inputs being <= 4D.
TF_LITE_ENSURE(
context, op_context.dims <= reference_ops::PadKernelMaxDimensionCount());
TF_LITE_ENSURE(context, IsConstantTensor(op_context.paddings));
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
PadContext op_context(context, node);
if (op_context.constant_values != nullptr) {
// Ensure that constant_values is a scalar.
TF_LITE_ENSURE_EQ(context, NumElements(op_context.constant_values), 1);
}
// Create before and after padding arrays that are accepted by the kernel.
const int32* paddings_data = GetTensorData<int32>(op_context.paddings);
tflite::PadParams op_params;
memset(&op_params, 0, sizeof(PadParams));
op_params.left_padding_count = op_context.dims;
op_params.right_padding_count = op_context.dims;
for (int idx = op_context.dims - 1; idx >= 0; --idx) {
op_params.left_padding[idx] = paddings_data[idx * 2];
op_params.right_padding[idx] = paddings_data[idx * 2 + 1];
}
#define TF_LITE_PAD(type, op_name, scalar, pad_value) \
const scalar pad_value_copy = pad_value; \
\
type::op_name(op_params, GetTensorShape(op_context.input), \
GetTensorData<scalar>(op_context.input), &pad_value_copy, \
GetTensorShape(op_context.output), \
GetTensorData<scalar>(op_context.output))
switch (op_context.input->type) {
case kTfLiteFloat32: {
float pad_value = op_context.constant_values == nullptr
? 0.f
: *GetTensorData<float>(op_context.constant_values);
if (op_context.resizing_category == ResizingCategory::kImageStyle) {
TF_LITE_PAD(reference_ops, PadImageStyle, float, pad_value);
} else {
TF_LITE_PAD(reference_ops, Pad, float, pad_value);
}
} break;
case kTfLiteUInt8: {
uint8_t pad_value;
if (op_context.constant_values == nullptr) {
// Quantized Pad requires that 0 is represented in the quantized
// range.
TF_LITE_ENSURE(context, op_context.output->params.zero_point >=
std::numeric_limits<uint8_t>::min());
TF_LITE_ENSURE(context, op_context.output->params.zero_point <=
std::numeric_limits<uint8_t>::max());
pad_value = static_cast<uint8_t>(op_context.output->params.zero_point);
} else {
// Quantized Pad requires that 'constant_values' is represented in the
// same quantized range as the input and output tensors.
TF_LITE_ENSURE_EQ(context, op_context.output->params.zero_point,
op_context.constant_values->params.zero_point);
TF_LITE_ENSURE_EQ(context, op_context.output->params.scale,
op_context.constant_values->params.scale);
pad_value = *GetTensorData<uint8_t>(op_context.constant_values);
}
if (op_context.resizing_category == ResizingCategory::kImageStyle) {
TF_LITE_PAD(reference_ops, PadImageStyle, uint8_t, pad_value);
} else {
TF_LITE_PAD(reference_ops, Pad, uint8_t, pad_value);
}
} break;
case kTfLiteInt8: {
int8_t pad_value;
if (op_context.constant_values == nullptr) {
// Quantized Pad requires that 0 is represented in the quantized
// range.
TF_LITE_ENSURE(context, op_context.output->params.zero_point >=
std::numeric_limits<int8_t>::min());
TF_LITE_ENSURE(context, op_context.output->params.zero_point <=
std::numeric_limits<int8_t>::max());
pad_value = static_cast<int8_t>(op_context.output->params.zero_point);
} else {
// Quantized Pad requires that 'constant_values' is represented in the
// same quantized range as the input and output tensors.
TF_LITE_ENSURE_EQ(context, op_context.output->params.zero_point,
op_context.constant_values->params.zero_point);
TF_LITE_ENSURE(context, op_context.output->params.scale ==
op_context.constant_values->params.scale);
pad_value = *GetTensorData<int8_t>(op_context.constant_values);
}
if (op_context.resizing_category == ResizingCategory::kImageStyle) {
TF_LITE_PAD(reference_ops, PadImageStyle, int8_t, pad_value);
} else {
TF_LITE_PAD(reference_ops, Pad, int8_t, pad_value);
}
} break;
case kTfLiteInt32: {
int32_t pad_value =
op_context.constant_values == nullptr
? 0
: *GetTensorData<int32_t>(op_context.constant_values);
TF_LITE_PAD(reference_ops, Pad, int32_t, pad_value);
} break;
default:
context->ReportError(context, "Type %s not currently supported by Pad.",
TfLiteTypeGetName(op_context.input->type));
return kTfLiteError;
}
#undef TF_LITE_PAD
return kTfLiteOk;
}
} // namespace pad
TfLiteRegistration* Register_PAD() {
static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare, pad::Eval};
return &r;
}
// Also register Pad as PadV2.
TfLiteRegistration* Register_PADV2() {
static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare, pad::Eval};
return &r;
}
} // namespace micro
} // namespace ops
} // namespace tflite
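As a worked illustration of the quantized checks in Eval above, here is a sketch (not part of the commit; QuantizeUint8 is a hypothetical helper) of the asymmetric scheme q = round(r / scale) + zero_point, using the scale of 1.0 and zero_point of 127 that Test2DUInt8V2 below uses:

#include <cmath>
#include <cstdint>

// Hypothetical helper: asymmetric quantization, q = round(r / scale) + zp.
uint8_t QuantizeUint8(float real, float scale, int zero_point) {
  return static_cast<uint8_t>(std::lround(real / scale) + zero_point);
}

// QuantizeUint8(0.f, 1.f, 127) == 127: zero is representable, so Pad without
// constant_values can use the output zero_point directly as the pad byte.
// QuantizeUint8(42.f, 1.f, 127) == 169: PadV2's constant_values tensor must
// share the output's scale and zero_point for 169 to mean the same real 42.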

View File

@@ -0,0 +1,444 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/all_ops_resolver.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
#include "tensorflow/lite/micro/testing/test_utils.h"
namespace tflite {
namespace testing {
namespace {
template <typename T>
TfLiteStatus ValidatePadGoldens(TfLiteTensor* tensors, int tensors_size,
const T* golden, T* output_data,
int output_length) {
TfLiteContext context;
PopulateContext(tensors, tensors_size, &context);
::tflite::ops::micro::AllOpsResolver resolver;
const TfLiteRegistration* registration =
resolver.FindOp(tflite::BuiltinOperator_PAD, 1);
TF_LITE_ENSURE(&context, registration != nullptr);
int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->prepare);
TF_LITE_ENSURE_EQ(&context, kTfLiteOk,
registration->prepare(&context, &node));
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_ENSURE_EQ(&context, kTfLiteOk, registration->invoke(&context, &node));
for (int i = 0; i < output_length; ++i) {
TF_LITE_MICRO_EXPECT_EQ(golden[i], output_data[i]);
}
return kTfLiteOk;
}
template <typename T>
TfLiteStatus ValidatePadV2Goldens(TfLiteTensor* tensors, int tensors_size,
const T* golden, T* output_data,
int output_length) {
TfLiteContext context;
PopulateContext(tensors, tensors_size, &context);
::tflite::ops::micro::AllOpsResolver resolver;
const TfLiteRegistration* registration =
resolver.FindOp(tflite::BuiltinOperator_PADV2, 1);
TF_LITE_ENSURE(&context, registration != nullptr);
int inputs_array_data[] = {3, 0, 1, 2};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->prepare);
// Prepare should catch dimension mismatches.
TfLiteStatus prepare_status = registration->prepare(&context, &node);
if (prepare_status != kTfLiteOk) {
return prepare_status;
}
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
// Eval should catch quantization mismatches.
TfLiteStatus invoke_status = registration->invoke(&context, &node);
if (invoke_status != kTfLiteOk) {
return invoke_status;
}
for (int i = 0; i < output_length; ++i) {
TF_LITE_MICRO_EXPECT_EQ(golden[i], output_data[i]);
}
return kTfLiteOk;
}
// Output data and golden must be shaped correctly.
void TestPadFloat(const int* input_dims_data, const float* input_data,
const int* pad_dims_data, const int* pad_data,
const int* output_dims_data, const float* golden,
float* output_data,
TfLiteStatus expected_status = kTfLiteOk) {
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
TfLiteIntArray* pad_dims = IntArrayFromInts(pad_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
const int output_dims_count = ElementCount(*output_dims);
constexpr int inputs_size = 2;
constexpr int outputs_size = 1;
constexpr int tensors_size = inputs_size + outputs_size;
TfLiteTensor tensors[tensors_size] = {
CreateFloatTensor(input_data, input_dims, "input_tensor"),
CreateInt32Tensor(pad_data, pad_dims, "padding tensor"),
CreateFloatTensor(output_data, output_dims, "output_tensor")};
// Pad tensor must be constant.
tensors[1].allocation_type = kTfLiteMmapRo;
TF_LITE_MICRO_EXPECT_EQ(expected_status,
ValidatePadGoldens(tensors, tensors_size, golden,
output_data, output_dims_count));
}
// Output data and golden must be shaped correctly.
void TestPadV2Float(const int* input_dims_data, const float* input_data,
const int* pad_dims_data, const int* pad_data,
const float pad_value, const int* output_dims_data,
const float* golden, float* output_data,
TfLiteStatus expected_status = kTfLiteOk) {
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
TfLiteIntArray* pad_dims = IntArrayFromInts(pad_dims_data);
const int pad_value_dims_data[] = {1, 1}; // Only one padding value allowed.
TfLiteIntArray* pad_value_dims = IntArrayFromInts(pad_value_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
const int output_dims_count = ElementCount(*output_dims);
constexpr int inputs_size = 3;
constexpr int outputs_size = 1;
constexpr int tensors_size = inputs_size + outputs_size;
TfLiteTensor tensors[tensors_size] = {
CreateFloatTensor(input_data, input_dims, "input_tensor"),
CreateInt32Tensor(pad_data, pad_dims, "padding tensor"),
CreateFloatTensor(&pad_value, pad_value_dims, "pad value tensor"),
CreateFloatTensor(output_data, output_dims, "output_tensor")};
// Pad tensor must be constant.
tensors[1].allocation_type = kTfLiteMmapRo;
TF_LITE_MICRO_EXPECT_EQ(expected_status,
ValidatePadV2Goldens(tensors, tensors_size, golden,
output_data, output_dims_count));
}
template <typename T>
void TestPadQuantized(const int* input_dims_data, const float* input_data,
T* input_quantized, float input_scale,
int input_zero_point, const int* pad_dims_data,
const int* pad_data, const int* output_dims_data,
const float* golden, T* golden_quantized,
float output_scale, int output_zero_point, T* output_data,
TfLiteStatus expected_status = kTfLiteOk) {
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
TfLiteIntArray* pad_dims = IntArrayFromInts(pad_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
const int output_dims_count = ElementCount(*output_dims);
constexpr int inputs_size = 2;
constexpr int outputs_size = 1;
constexpr int tensors_size = inputs_size + outputs_size;
TfLiteTensor tensors[tensors_size] = {
CreateQuantizedTensor(input_data, input_quantized, input_dims,
input_scale, input_zero_point, "input_tensor"),
CreateInt32Tensor(pad_data, pad_dims, "padding tensor"),
CreateQuantizedTensor(output_data, output_dims, output_scale,
output_zero_point, "output_tensor")};
// Pad tensor must be constant.
tensors[1].allocation_type = kTfLiteMmapRo;
tflite::AsymmetricQuantize(golden, golden_quantized, output_dims_count,
output_scale, output_zero_point);
TF_LITE_MICRO_EXPECT_EQ(
expected_status,
ValidatePadGoldens(tensors, tensors_size, golden_quantized, output_data,
output_dims_count));
}
template <typename T>
void TestPadV2Quantized(const int* input_dims_data, const float* input_data,
T* input_quantized, float input_scale,
int input_zero_point, const int* pad_dims_data,
const int* pad_data, const float pad_value,
const float pad_value_scale,
const int pad_value_zero_point,
const int* output_dims_data, const float* golden,
T* golden_quantized, float output_scale,
int output_zero_point, T* output_data,
TfLiteStatus expected_status = kTfLiteOk) {
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
TfLiteIntArray* pad_dims = IntArrayFromInts(pad_dims_data);
const int pad_value_dims_data[] = {1, 1}; // Only one padding value allowed.
TfLiteIntArray* pad_value_dims = IntArrayFromInts(pad_value_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
T pad_value_quantized;
const int output_dims_count = ElementCount(*output_dims);
constexpr int inputs_size = 3;
constexpr int outputs_size = 1;
constexpr int tensors_size = inputs_size + outputs_size;
TfLiteTensor tensors[tensors_size] = {
CreateQuantizedTensor(input_data, input_quantized, input_dims,
input_scale, input_zero_point, "input_tensor"),
CreateInt32Tensor(pad_data, pad_dims, "padding tensor"),
CreateQuantizedTensor(&pad_value, &pad_value_quantized, pad_value_dims,
pad_value_scale, pad_value_zero_point,
"pad value tensor"),
CreateQuantizedTensor(output_data, output_dims, output_scale,
output_zero_point, "output_tensor")};
// Pad tensor must be constant.
tensors[1].allocation_type = kTfLiteMmapRo;
tensors[2].params.scale = pad_value_scale;
tensors[3].params.scale = output_scale;
tflite::AsymmetricQuantize(golden, golden_quantized, output_dims_count,
output_scale, output_zero_point);
TF_LITE_MICRO_EXPECT_EQ(
expected_status,
ValidatePadV2Goldens(tensors, tensors_size, golden_quantized, output_data,
output_dims_count));
}
} // namespace
} // namespace testing
} // namespace tflite
TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(Test2DFloat) {
const int input_dims[] = {4, 1, 2, 2, 1};
const float input_values[] = {1, 2, 3, 4};
const int pad_dims[] = {2, 4, 2};
const int pad_values[] = {1, 1, 0, 0, 1, 1, 0, 0};
const int output_dims[] = {4, 3, 2, 4, 1};
const float golden[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0,
0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0};
float output_data[24];
tflite::testing::TestPadFloat(input_dims, input_values, pad_dims, pad_values,
output_dims, golden, output_data);
}
TF_LITE_MICRO_TEST(Test4DFloat) {
const int input_dims[] = {4, 1, 1, 1, 1};
const float input_values[] = {42};
const int pad_dims[] = {2, 4, 2};
const int pad_values[] = {1, 1, 1, 1, 1, 1, 1, 1};
const int output_dims[] = {4, 3, 3, 3, 3};
const int kOutputLen = 81; // 3 * 3 * 3 * 3
float golden[kOutputLen];
for (int i = 0; i < kOutputLen; i++) {
golden[i] = 0;
}
golden[40] = 42;
float output_data[kOutputLen];
tflite::testing::TestPadFloat(input_dims, input_values, pad_dims, pad_values,
output_dims, golden, output_data);
}
TF_LITE_MICRO_TEST(Test2DFloatV2) {
const int input_dims[] = {4, 1, 2, 2, 1};
const float input_values[] = {1, 2, 3, 4};
const int pad_dims[] = {2, 4, 2};
const int pad_values[] = {1, 1, 0, 0, 1, 1, 0, 0};
const float pad_value = 42;
const int output_dims[] = {4, 3, 2, 4, 1};
const float golden[] = {42, 42, 42, 42, 42, 42, 42, 42, 42, 1, 2, 42,
42, 3, 4, 42, 42, 42, 42, 42, 42, 42, 42, 42};
float output_data[24];
tflite::testing::TestPadV2Float(input_dims, input_values, pad_dims,
pad_values, pad_value, output_dims, golden,
output_data);
}
TF_LITE_MICRO_TEST(Test2DUInt8) {
const int input_dims[] = {4, 1, 2, 2, 1};
const float input_values[] = {1, 2, 3, 4};
const float input_scale = 1.0f;
const int input_zero_point = 127;
const int pad_dims[] = {2, 4, 2};
const int pad_values[] = {1, 1, 0, 0, 1, 1, 0, 0};
const int output_dims[] = {4, 3, 2, 4, 1};
const float golden[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0,
0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0};
const float output_scale = 1.0f;
const int output_zero_point = 127;
uint8_t output_data[24];
uint8_t input_quantized[4];
uint8_t golden_quantized[24];
tflite::testing::TestPadQuantized(
input_dims, input_values, input_quantized, input_scale, input_zero_point,
pad_dims, pad_values, output_dims, golden, golden_quantized, output_scale,
output_zero_point, output_data);
}
TF_LITE_MICRO_TEST(Test2DUInt8V2) {
const int input_dims[] = {4, 1, 2, 2, 1};
const float input_values[] = {1, 2, 3, 4};
const float input_scale = 1.0f;
const int input_zero_point = 127;
const int pad_dims[] = {2, 4, 2};
const int pad_values[] = {1, 1, 0, 0, 1, 1, 0, 0};
const float pad_value = 42;
const float pad_value_scale = 1.0f;
const int pad_value_zero_point = 127;
const int output_dims[] = {4, 3, 2, 4, 1};
const float golden[] = {42, 42, 42, 42, 42, 42, 42, 42, 42, 1, 2, 42,
42, 3, 4, 42, 42, 42, 42, 42, 42, 42, 42, 42};
const float output_scale = 1.0f;
const int output_zero_point = 127;
uint8_t output_data[24];
uint8_t input_quantized[4];
uint8_t golden_quantized[24];
tflite::testing::TestPadV2Quantized(
input_dims, input_values, input_quantized, input_scale, input_zero_point,
pad_dims, pad_values, pad_value, pad_value_scale, pad_value_zero_point,
output_dims, golden, golden_quantized, output_scale, output_zero_point,
output_data);
}
TF_LITE_MICRO_TEST(Test2DInt8) {
const int input_dims[] = {4, 1, 2, 2, 1};
const float input_values[] = {1, 2, 3, 4};
const float input_scale = 1.0f;
const int input_zero_point = 0;
const int pad_dims[] = {2, 4, 2};
const int pad_values[] = {1, 1, 0, 0, 1, 1, 0, 0};
const int output_dims[] = {4, 3, 2, 4, 1};
const float golden[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0,
0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0};
const float output_scale = 1.0f;
const int output_zero_point = 0;
int8_t output_data[24];
int8_t input_quantized[4];
int8_t golden_quantized[24];
tflite::testing::TestPadQuantized(
input_dims, input_values, input_quantized, input_scale, input_zero_point,
pad_dims, pad_values, output_dims, golden, golden_quantized, output_scale,
output_zero_point, output_data);
}
TF_LITE_MICRO_TEST(Test2DInt8V2) {
const int input_dims[] = {4, 1, 2, 2, 1};
const float input_values[] = {1, 2, 3, 4};
const float input_scale = 1.0f;
const int input_zero_point = 0;
const int pad_dims[] = {2, 4, 2};
const int pad_values[] = {1, 1, 0, 0, 1, 1, 0, 0};
const float pad_value = 42;
const float pad_value_scale = 1.0f;
const int pad_value_zero_point = 0;
const int output_dims[] = {4, 3, 2, 4, 1};
const float golden[] = {42, 42, 42, 42, 42, 42, 42, 42, 42, 1, 2, 42,
42, 3, 4, 42, 42, 42, 42, 42, 42, 42, 42, 42};
const float output_scale = 1.0f;
const int output_zero_point = 0;
int8_t output_data[24];
int8_t input_quantized[4];
int8_t golden_quantized[24];
tflite::testing::TestPadV2Quantized(
input_dims, input_values, input_quantized, input_scale, input_zero_point,
pad_dims, pad_values, pad_value, pad_value_scale, pad_value_zero_point,
output_dims, golden, golden_quantized, output_scale, output_zero_point,
output_data);
}
TF_LITE_MICRO_TEST(Test2DInt8V2ExpectFailurePadValueQuantizationMismatch) {
const int input_dims[] = {4, 1, 2, 2, 1};
const float input_values[] = {1, 2, 3, 4};
const float input_scale = 1.0f;
const int input_zero_point = 0;
const int pad_dims[] = {2, 4, 2};
const int pad_values[] = {1, 1, 0, 0, 1, 1, 0, 0};
const float pad_value = 42;
// Causes failure since this is in a different quantization space than input.
const float pad_value_scale = 0.5f;
const int pad_value_zero_point = 0;
const int output_dims[] = {4, 3, 2, 4, 1};
const float golden[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
const float output_scale = 1.0f;
const int output_zero_point = 0;
int8_t output_data[24] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
int8_t input_quantized[4];
int8_t golden_quantized[24];
tflite::testing::TestPadV2Quantized(
input_dims, input_values, input_quantized, input_scale, input_zero_point,
pad_dims, pad_values, pad_value, pad_value_scale, pad_value_zero_point,
output_dims, golden, golden_quantized, output_scale, output_zero_point,
output_data, kTfLiteError);
}
TF_LITE_MICRO_TEST(Test2DInt8ExpectFailureQuantizationRangeExcludesZero) {
const int input_dims[] = {4, 1, 2, 2, 1};
const float input_values[] = {1, 2, 3, 4};
const float input_scale = 1.0f;
const int input_zero_point = 0;
const int pad_dims[] = {2, 4, 2};
const int pad_values[] = {1, 1, 0, 0, 1, 1, 0, 0};
const int output_dims[] = {4, 3, 2, 4, 1};
const float golden[] = {42, 42, 42, 42, 42, 42, 42, 42, 42, 1, 2, 42,
42, 3, 4, 42, 42, 42, 42, 42, 42, 42, 42, 42};
// Causes failure since this quantization zero point excludes zero.
const float output_scale = 1.0f;
const int output_zero_point = 129;
int8_t output_data[24];
int8_t input_quantized[4];
int8_t golden_quantized[24];
tflite::testing::TestPadQuantized(
input_dims, input_values, input_quantized, input_scale, input_zero_point,
pad_dims, pad_values, output_dims, golden, golden_quantized, output_scale,
output_zero_point, output_data, kTfLiteError);
}
TF_LITE_MICRO_TESTS_END

View File

@@ -140,6 +140,7 @@ tensorflow/lite/kernels/internal/reference/integer_ops/softmax.h \
tensorflow/lite/kernels/internal/reference/maximum_minimum.h \
tensorflow/lite/kernels/internal/reference/mul.h \
tensorflow/lite/kernels/internal/reference/neg.h \
tensorflow/lite/kernels/internal/reference/pad.h \
tensorflow/lite/kernels/internal/reference/pooling.h \
tensorflow/lite/kernels/internal/reference/prelu.h \
tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h \