Merge pull request #46207 from ddavis-2015:AddN-pr4

PiperOrigin-RevId: 361249674
Change-Id: I24ca5102ef2c061bcf946be288dc4c884a037262
TensorFlower Gardener 2021-03-05 16:47:21 -08:00
commit 8b97d84da1
2 changed files with 24 additions and 64 deletions


@@ -22,8 +22,9 @@ limitations under the License.
namespace tflite {
namespace ops {
namespace builtin {
namespace micro {
namespace add_n {
namespace {
constexpr int kInputTensor1 = 0;
constexpr int kOutputTensor = 0;
@@ -49,11 +50,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input->type);
  }
  // Use the first input node's dimension to be the dimension of the output
  // node.
  TfLiteIntArray* input1_dims = input1->dims;
  TfLiteIntArray* output_dims = TfLiteIntArrayCopy(input1_dims);
  return context->ResizeTensor(context, output, output_dims);
  return kTfLiteError;
}
template <typename T>
@@ -88,14 +85,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  return kTfLiteOk;
}
} // namespace
} // namespace add_n
TfLiteRegistration* Register_ADD_N() {
  static TfLiteRegistration r = {/*init*/ nullptr, /*free*/ nullptr,
                                 add_n::Prepare, add_n::Eval};
  return &r;
}
TfLiteRegistration* Register_ADD_N() { return nullptr; }
} // namespace builtin
} // namespace micro
} // namespace ops
} // namespace tflite
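AddN sums any number of identically shaped input tensors element-wise. The hunks above stub that behaviour out for now (Prepare returns kTfLiteError and Register_ADD_N returns nullptr), so as a reference point, here is a minimal plain-C++ sketch of the element-wise semantics the eventual Eval has to reproduce. The AddN helper name, the std::vector buffers, and the assert-based shape check are illustrative assumptions, not the kernel's TfLiteEvalTensor-based code:

#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative only: element-wise AddN over `inputs`, all of which must have
// the same flattened length. This mirrors the summation the kernel's Eval
// performs, but over std::vector buffers instead of TfLite tensors.
template <typename T>
std::vector<T> AddN(const std::vector<std::vector<T>>& inputs) {
  assert(!inputs.empty());
  const std::size_t size = inputs[0].size();
  std::vector<T> out(size, T(0));
  for (const auto& input : inputs) {
    assert(input.size() == size);  // AddN requires identical shapes.
    for (std::size_t i = 0; i < size; ++i) {
      out[i] += input[i];
    }
  }
  return out;
}

Like the sketch, the kernel is templated over the element type (the template <typename T> line in the hunk above), and Prepare additionally checks in a loop that every input shares the type and shape of input 0, which is what the TF_LITE_ENSURE_TYPES_EQ check does.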


@@ -12,59 +12,25 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>
#include <vector>
#include <type_traits>
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
namespace tflite {
namespace {
namespace testing {
namespace {} // namespace
} // namespace testing
} // namespace tflite
using ::testing::ElementsAreArray;
TF_LITE_MICRO_TESTS_BEGIN
class BaseAddNOpModel : public SingleOpModel {
 public:
  BaseAddNOpModel(const std::vector<TensorData>& inputs,
                  const TensorData& output) {
    int num_inputs = inputs.size();
    std::vector<std::vector<int>> input_shapes;
    for (int i = 0; i < num_inputs; ++i) {
      inputs_.push_back(AddInput(inputs[i]));
      input_shapes.push_back(GetShape(inputs_[i]));
    }
    output_ = AddOutput(output);
    SetBuiltinOp(BuiltinOperator_ADD_N, BuiltinOptions_AddNOptions,
                 CreateAddNOptions(builder_).Union());
    BuildInterpreter(input_shapes);
  }
  int input(int i) { return inputs_[i]; }
 protected:
  std::vector<int> inputs_;
  int output_;
};
class FloatAddNOpModel : public BaseAddNOpModel {
 public:
  using BaseAddNOpModel::BaseAddNOpModel;
  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
};
class IntegerAddNOpModel : public BaseAddNOpModel {
 public:
  using BaseAddNOpModel::BaseAddNOpModel;
  std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); }
};
TEST(FloatAddNOpModel, AddMultipleTensors) {
TF_LITE_MICRO_TEST(FloatAddNOpAddMultipleTensors) {
#ifdef notdef
  FloatAddNOpModel m({{TensorType_FLOAT32, {1, 2, 2, 1}},
                      {TensorType_FLOAT32, {1, 2, 2, 1}},
                      {TensorType_FLOAT32, {1, 2, 2, 1}}},
@@ -72,11 +38,12 @@ TEST(FloatAddNOpModel, AddMultipleTensors) {
  m.PopulateTensor<float>(m.input(0), {-2.0, 0.2, 0.7, 0.8});
  m.PopulateTensor<float>(m.input(1), {0.1, 0.2, 0.3, 0.5});
  m.PopulateTensor<float>(m.input(2), {0.5, 0.1, 0.1, 0.2});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.4, 0.5, 1.1, 1.5}));
#endif // notdef
}
TEST(IntegerAddNOpModel, AddMultipleTensors) {
TF_LITE_MICRO_TEST(IntegerAddNOpAddMultipleTensors) {
#ifdef notdef
  IntegerAddNOpModel m({{TensorType_INT32, {1, 2, 2, 1}},
                        {TensorType_INT32, {1, 2, 2, 1}},
                        {TensorType_INT32, {1, 2, 2, 1}}},
@@ -84,9 +51,8 @@ TEST(IntegerAddNOpModel, AddMultipleTensors) {
  m.PopulateTensor<int32_t>(m.input(0), {-20, 2, 7, 8});
  m.PopulateTensor<int32_t>(m.input(1), {1, 2, 3, 5});
  m.PopulateTensor<int32_t>(m.input(2), {10, -5, 1, -2});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({-9, -1, 11, 11}));
#endif // notdef
}
} // namespace
} // namespace tflite
TF_LITE_MICRO_TESTS_END
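Both test bodies above are fenced behind #ifdef notdef, so neither currently executes, but the hard-coded expectations follow directly from summing the three populated tensors element-wise (for example, -2.0 + 0.1 + 0.5 = -1.4 and -20 + 1 + 10 = -9). A small self-contained C++ check that reproduces both expected vectors from the same input data; it is illustrative only and does not go through the TFLite Micro test harness or KernelRunner:

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Float case from FloatAddNOpAddMultipleTensors.
  std::vector<std::vector<float>> f = {{-2.0f, 0.2f, 0.7f, 0.8f},
                                       {0.1f, 0.2f, 0.3f, 0.5f},
                                       {0.5f, 0.1f, 0.1f, 0.2f}};
  const float expected_f[] = {-1.4f, 0.5f, 1.1f, 1.5f};
  for (int i = 0; i < 4; ++i) {
    float sum = f[0][i] + f[1][i] + f[2][i];
    if (std::fabs(sum - expected_f[i]) > 1e-5f) {
      std::printf("float mismatch at %d\n", i);
      return 1;
    }
  }

  // Integer case from IntegerAddNOpAddMultipleTensors.
  std::vector<std::vector<int32_t>> n = {{-20, 2, 7, 8},
                                         {1, 2, 3, 5},
                                         {10, -5, 1, -2}};
  const int32_t expected_i[] = {-9, -1, 11, 11};
  for (int i = 0; i < 4; ++i) {
    if (n[0][i] + n[1][i] + n[2][i] != expected_i[i]) {
      std::printf("int mismatch at %d\n", i);
      return 1;
    }
  }

  std::printf("expected AddN outputs reproduced\n");
  return 0;
}

Once the micro kernel's Prepare and Eval are implemented, the #ifdef notdef guards can presumably be removed so these same expectations run against the real kernel through the TF_LITE_MICRO_TEST macros.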