Extract reference for operator ADD_N to standalone header
Move the reference implementation to its own header so that micro can use it without the unrelated dependencies of reference_ops.h. PR step 2 for issue #46162
This commit is contained in:
parent
ca7c72bbbd
commit
f98d9ecb8e
@ -449,6 +449,7 @@ cc_library(
|
|||||||
srcs = [],
|
srcs = [],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"reference/add.h",
|
"reference/add.h",
|
||||||
|
"reference/add_n.h",
|
||||||
"reference/arg_min_max.h",
|
"reference/arg_min_max.h",
|
||||||
"reference/batch_matmul.h",
|
"reference/batch_matmul.h",
|
||||||
"reference/binary_function.h",
|
"reference/binary_function.h",
|
||||||
@ -558,6 +559,7 @@ cc_library(
|
|||||||
srcs = [],
|
srcs = [],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"reference/add.h",
|
"reference/add.h",
|
||||||
|
"reference/add_n.h",
|
||||||
"reference/arg_min_max.h",
|
"reference/arg_min_max.h",
|
||||||
"reference/binary_function.h",
|
"reference/binary_function.h",
|
||||||
"reference/ceil.h",
|
"reference/ceil.h",
|
||||||
|
42
tensorflow/lite/kernels/internal/reference/add_n.h
Normal file
42
tensorflow/lite/kernels/internal/reference/add_n.h
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ADD_N_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ADD_N_H_

#include <cstddef>

#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace reference_ops {

// Element-wise sum of `num_inputs` tensors into `output_data`.
// `input_data` points to `num_inputs` buffers, each holding
// `input_shape.FlatSize()` elements; `output_data` must have the same size.
// T is expected to be either float or int.
template <typename T>
inline void AddN(const RuntimeShape& input_shape, const size_t num_inputs,
                 T* const* input_data, T* output_data) {
  // All inputs and output should have the same shape, this is checked during
  // Prepare stage.
  const size_t size = input_shape.FlatSize();
  // Use size_t counters to match `size` and `num_inputs` and avoid
  // signed/unsigned comparison warnings.
  for (size_t i = 0; i < size; ++i) {
    T x = 0;
    for (size_t j = 0; j < num_inputs; ++j) {
      x += input_data[j][i];
    }
    output_data[i] = x;
  }
}

}  // namespace reference_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ADD_N_H_
|
@ -33,6 +33,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/kernels/internal/common.h"
|
#include "tensorflow/lite/kernels/internal/common.h"
|
||||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||||
#include "tensorflow/lite/kernels/internal/reference/add.h"
|
#include "tensorflow/lite/kernels/internal/reference/add.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/reference/add_n.h"
|
||||||
#include "tensorflow/lite/kernels/internal/reference/arg_min_max.h"
|
#include "tensorflow/lite/kernels/internal/reference/arg_min_max.h"
|
||||||
#include "tensorflow/lite/kernels/internal/reference/binary_function.h"
|
#include "tensorflow/lite/kernels/internal/reference/binary_function.h"
|
||||||
#include "tensorflow/lite/kernels/internal/reference/ceil.h"
|
#include "tensorflow/lite/kernels/internal/reference/ceil.h"
|
||||||
@ -253,22 +254,6 @@ inline void QuantizeLeakyRelu(const LeakyReluParams& params,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// T is expected to be either float or int.
|
|
||||||
template <typename T>
|
|
||||||
inline void AddN(const RuntimeShape& input_shape, const size_t num_inputs,
|
|
||||||
T* const* input_data, T* output_data) {
|
|
||||||
// All inputs and output should have the same shape, this is checked during
|
|
||||||
// Prepare stage.
|
|
||||||
const size_t size = input_shape.FlatSize();
|
|
||||||
for (int i = 0; i < size; ++i) {
|
|
||||||
T x = 0;
|
|
||||||
for (int j = 0; j < num_inputs; ++j) {
|
|
||||||
x += input_data[j][i];
|
|
||||||
}
|
|
||||||
output_data[i] = x;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary
|
// TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary
|
||||||
// dimensionality if the runtime code does a single loop over one dimension
|
// dimensionality if the runtime code does a single loop over one dimension
|
||||||
// that handles broadcasting as the base case. The code generator would then
|
// that handles broadcasting as the base case. The code generator would then
|
||||||
|
Loading…
x
Reference in New Issue
Block a user