From df211323b834ead969130146297c9416853f1e65 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 4 Feb 2019 14:19:54 -0800 Subject: [PATCH] Add built-in operator AddN. It only supports float and int_32 now. PiperOrigin-RevId: 232365698 --- tensorflow/lite/BUILD | 4 +- tensorflow/lite/build_def.bzl | 1 + tensorflow/lite/g3doc/tf_ops_compatibility.md | 11 +++ tensorflow/lite/kernels/BUILD | 13 +++ tensorflow/lite/kernels/add_n.cc | 88 +++++++++++++++++ tensorflow/lite/kernels/add_n_test.cc | 98 +++++++++++++++++++ .../internal/reference/reference_ops.h | 16 +++ tensorflow/lite/kernels/register.cc | 2 + tensorflow/lite/testing/generate_examples.py | 45 +++++++++ tensorflow/lite/toco/import_tensorflow.cc | 2 +- tensorflow/lite/toco/tflite/operator.cc | 21 ++++ tensorflow/lite/toco/tflite/operator_test.cc | 7 ++ 12 files changed, 305 insertions(+), 3 deletions(-) create mode 100644 tensorflow/lite/kernels/add_n.cc create mode 100644 tensorflow/lite/kernels/add_n_test.cc diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index 7d71ca32d50..8cc91a1a24f 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -4,7 +4,7 @@ package( licenses(["notice"]) # Apache 2.0 -load("//tensorflow:tensorflow.bzl", "tf_cc_test", "if_not_windows") +load("//tensorflow:tensorflow.bzl", "if_not_windows", "tf_cc_test") load("//tensorflow/lite:build_def.bzl", "tflite_copts") load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") @@ -190,7 +190,7 @@ cc_library( ":string", ":util", "//tensorflow/lite/c:c_api_internal", - "//tensorflow/lite/core/api:api", + "//tensorflow/lite/core/api", "//tensorflow/lite/nnapi:nnapi_implementation", "//tensorflow/lite/profiling:profiler", "//tensorflow/lite/schema:schema_fbs", diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl index 181590e5a94..88a8faf02e0 100644 --- a/tensorflow/lite/build_def.bzl +++ b/tensorflow/lite/build_def.bzl @@ -226,6 +226,7 @@ def generated_test_models(): return [ "abs", "add", + "add_n", "arg_min_max", "avg_pool", "batch_to_space_nd", diff --git a/tensorflow/lite/g3doc/tf_ops_compatibility.md b/tensorflow/lite/g3doc/tf_ops_compatibility.md index cff4afc2508..d7c71df9d8f 100644 --- a/tensorflow/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/lite/g3doc/tf_ops_compatibility.md @@ -165,6 +165,17 @@ Options { } ``` +**ADD_N** + +``` +Inputs { + 0-N: any number of tensors (must have same size and shape) +} +Outputs { + 0: elementwise sum of the input tensors +} +``` + **ARG_MAX** ``` diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 826bb9a2585..90e4d82618f 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -152,6 +152,7 @@ cc_library( srcs = [ "activations.cc", "add.cc", + "add_n.cc", "arg_min_max.cc", "audio_spectrogram.cc", "basic_rnn.cc", @@ -355,6 +356,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "add_n_test", + size = "small", + srcs = ["add_n_test.cc"], + deps = [ + ":builtin_ops", + ":test_util", + "//tensorflow/lite:framework", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "arg_min_max_test", size = "small", diff --git a/tensorflow/lite/kernels/add_n.cc b/tensorflow/lite/kernels/add_n.cc new file mode 100644 index 00000000000..3e9b2ea24af --- /dev/null +++ b/tensorflow/lite/kernels/add_n.cc @@ -0,0 +1,88 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace add_n { + +constexpr int kInputTensor1 = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + int num_inputs = NumInputs(node); + TF_LITE_ENSURE(context, num_inputs >= 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + output->type = input1->type; + + // Check that all input tensors have the same shape and type. + for (int i = kInputTensor1 + 1; i < num_inputs; ++i) { + const TfLiteTensor* input = GetInput(context, node, i); + TF_LITE_ENSURE(context, HaveSameShapes(input1, input)); + TF_LITE_ENSURE_EQ(context, input1->type, input->type); + } + + // Use the first input node's dimension to be the dimension of the output + // node. + TfLiteIntArray* input1_dims = input1->dims; + TfLiteIntArray* output_dims = TfLiteIntArrayCopy(input1_dims); + return context->ResizeTensor(context, output, output_dims); +} + +template +void EvalAddN(TfLiteContext* context, TfLiteNode* node) { + // TODO(haoliang): Initialize all_inputs only once during init. + VectorOfTensors all_inputs(*context, *node->inputs); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + int num_inputs = NumInputs(node); + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + reference_ops::AddN(GetTensorShape(input1), num_inputs, all_inputs.data(), + GetTensorData(output)); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + if (output->type == kTfLiteFloat32) { + EvalAddN(context, node); + } else if (output->type == kTfLiteInt32) { + EvalAddN(context, node); + } else { + context->ReportError(context, + "AddN only supports FLOAT32|INT32 now, got %s.", + TfLiteTypeGetName(output->type)); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace add_n + +TfLiteRegistration* Register_ADD_N() { + static TfLiteRegistration r = {/*init*/ nullptr, /*free*/ nullptr, + add_n::Prepare, add_n::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/kernels/add_n_test.cc b/tensorflow/lite/kernels/add_n_test.cc new file mode 100644 index 00000000000..ee9477d2ff1 --- /dev/null +++ b/tensorflow/lite/kernels/add_n_test.cc @@ -0,0 +1,98 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseAddNOpModel : public SingleOpModel { + public: + BaseAddNOpModel(const std::vector& inputs, + const TensorData& output) { + int num_inputs = inputs.size(); + std::vector> input_shapes; + + for (int i = 0; i < num_inputs; ++i) { + inputs_.push_back(AddInput(inputs[i])); + input_shapes.push_back(GetShape(inputs_[i])); + } + + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_ADD_N, BuiltinOptions_AddNOptions, + CreateAddNOptions(builder_).Union()); + BuildInterpreter(input_shapes); + } + + int input(int i) { return inputs_[i]; } + + protected: + std::vector inputs_; + int output_; +}; + +class FloatAddNOpModel : public BaseAddNOpModel { + public: + using BaseAddNOpModel::BaseAddNOpModel; + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +class IntegerAddNOpModel : public BaseAddNOpModel { + public: + using BaseAddNOpModel::BaseAddNOpModel; + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +TEST(FloatAddNOpModel, AddMultipleTensors) { + FloatAddNOpModel m({{TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}}, + {TensorType_FLOAT32, {}}); + m.PopulateTensor(m.input(0), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input(1), {0.1, 0.2, 0.3, 0.5}); + m.PopulateTensor(m.input(2), {0.5, 0.1, 0.1, 0.2}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.4, 0.5, 1.1, 1.5})); +} + +TEST(IntegerAddNOpModel, AddMultipleTensors) { + IntegerAddNOpModel m({{TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}}, + {TensorType_INT32, {}}); + m.PopulateTensor(m.input(0), {-20, 2, 7, 8}); + m.PopulateTensor(m.input(1), {1, 2, 3, 5}); + m.PopulateTensor(m.input(2), {10, -5, 1, -2}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-9, -1, 11, 11})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h index ac6905fcd7b..84f62b1c97a 100644 --- a/tensorflow/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -702,6 +702,22 @@ inline void Add(const ArithmeticParams& params, } } +// T is expected to be either float or int. +template +inline void AddN(const RuntimeShape& input_shape, const size_t num_inputs, + T* const* input_data, T* output_data) { + // All inputs and output should have the same shape, this is checked during + // Prepare stage. + const size_t size = input_shape.FlatSize(); + for (int i = 0; i < size; ++i) { + T x = 0; + for (int j = 0; j < num_inputs; ++j) { + x += input_data[j][i]; + } + output_data[i] = x; + } +} + // Element-wise add that can often be used for inner loop of broadcast add as // well as the non-broadcast add. inline void AddElementwise(int size, const ArithmeticParams& params, diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index 12a53f93b81..aad6deb4d8f 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -131,6 +131,7 @@ TfLiteRegistration* Register_FILL(); TfLiteRegistration* Register_MIRROR_PAD(); TfLiteRegistration* Register_UNIQUE(); TfLiteRegistration* Register_REVERSE_V2(); +TfLiteRegistration* Register_ADD_N(); TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) { context->ReportError( @@ -295,6 +296,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_MIRROR_PAD, Register_MIRROR_PAD()); AddBuiltin(BuiltinOperator_UNIQUE, Register_UNIQUE()); AddBuiltin(BuiltinOperator_REVERSE_V2, Register_REVERSE_V2()); + AddBuiltin(BuiltinOperator_ADD_N, Register_ADD_N()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. diff --git a/tensorflow/lite/testing/generate_examples.py b/tensorflow/lite/testing/generate_examples.py index ee68607a0fe..7ebc5950b26 100644 --- a/tensorflow/lite/testing/generate_examples.py +++ b/tensorflow/lite/testing/generate_examples.py @@ -1184,6 +1184,51 @@ def make_add_tests(zip_path): make_binary_op_tests(zip_path, tf.add) +def make_add_n_tests(zip_path): + """Make a set of tests for AddN op.""" + + test_parameters = [ + { + "dtype": [tf.float32, tf.int32], + "input_shape": [[2, 5, 3, 1]], + "num_inputs": [2, 3, 4, 5], + }, + { + "dtype": [tf.float32, tf.int32], + "input_shape": [[5]], + "num_inputs": [2, 3, 4, 5], + }, + { + "dtype": [tf.float32, tf.int32], + "input_shape": [[]], + "num_inputs": [2, 3, 4, 5], + }, + ] + + def build_graph(parameters): + """Builds the graph given the current parameters.""" + input_tensors = [] + for i in range(parameters["num_inputs"]): + input_tensors.append( + tf.placeholder( + dtype=parameters["dtype"], + name="input_{}".format(i), + shape=parameters["input_shape"])) + out = tf.add_n(input_tensors) + return input_tensors, [out] + + def build_inputs(parameters, sess, inputs, outputs): + """Builds operand inputs for op.""" + input_data = [] + for i in range(parameters["num_inputs"]): + input_data.append( + create_tensor_data(parameters["dtype"], parameters["input_shape"])) + return input_data, sess.run( + outputs, feed_dict={i: d for i, d in zip(inputs, input_data)}) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_div_tests(zip_path): make_binary_op_tests(zip_path, tf.div) diff --git a/tensorflow/lite/toco/import_tensorflow.cc b/tensorflow/lite/toco/import_tensorflow.cc index b1b04949bad..813e439995e 100644 --- a/tensorflow/lite/toco/import_tensorflow.cc +++ b/tensorflow/lite/toco/import_tensorflow.cc @@ -2375,7 +2375,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() { return std::unordered_map({ {"Abs", ConvertSimpleOperator}, {"Add", ConvertSimpleOperator}, - {"AddN", ConvertSimpleOperatorFlexOk}, + {"AddN", ConvertSimpleOperator}, {"All", ConvertSimpleOperator}, {"Any", ConvertReduceOperator}, {"ArgMax", ConvertArgMaxOperator}, diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc index 81203dfa040..4ce3aa9218a 100644 --- a/tensorflow/lite/toco/tflite/operator.cc +++ b/tensorflow/lite/toco/tflite/operator.cc @@ -203,6 +203,25 @@ class Add : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateAddNOptions(*builder); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override {} + + int GetVersion(const OperatorSignature& op_signature) const override { + return 1; + } +}; + class SpaceToBatchND : public BuiltinOperator> BuildOperatorList( // Builtin Operators. ops.push_back( MakeUnique(::tflite::BuiltinOperator_ADD, OperatorType::kAdd)); + ops.push_back( + MakeUnique(::tflite::BuiltinOperator_ADD_N, OperatorType::kAddN)); ops.push_back( MakeUnique
(::tflite::BuiltinOperator_DIV, OperatorType::kDiv)); ops.push_back( diff --git a/tensorflow/lite/toco/tflite/operator_test.cc b/tensorflow/lite/toco/tflite/operator_test.cc index 88f68f7ebf9..43b52c4e930 100644 --- a/tensorflow/lite/toco/tflite/operator_test.cc +++ b/tensorflow/lite/toco/tflite/operator_test.cc @@ -164,6 +164,13 @@ TEST_F(OperatorTest, BuiltinAdd) { output_toco_op->fused_activation_function); } +TEST_F(OperatorTest, BuiltinAddN) { + AddNOperator op; + auto output_toco_op = + SerializeAndDeserialize(GetOperator("ADD_N", OperatorType::kAddN), op); + ASSERT_NE(output_toco_op.get(), nullptr); +} + TEST_F(OperatorTest, BuiltinReducerOps) { CheckReducerOperator("MEAN", OperatorType::kMean); CheckReducerOperator("SUM", OperatorType::kSum);