Merge pull request #46207 from ddavis-2015:AddN-pr4
PiperOrigin-RevId: 361249674 Change-Id: I24ca5102ef2c061bcf946be288dc4c884a037262
This commit is contained in:
commit
8b97d84da1
@ -22,8 +22,9 @@ limitations under the License.
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace builtin {
|
||||
namespace micro {
|
||||
namespace add_n {
|
||||
namespace {
|
||||
|
||||
constexpr int kInputTensor1 = 0;
|
||||
constexpr int kOutputTensor = 0;
|
||||
@ -49,11 +50,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input->type);
|
||||
}
|
||||
|
||||
// Use the first input node's dimension to be the dimension of the output
|
||||
// node.
|
||||
TfLiteIntArray* input1_dims = input1->dims;
|
||||
TfLiteIntArray* output_dims = TfLiteIntArrayCopy(input1_dims);
|
||||
return context->ResizeTensor(context, output, output_dims);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@ -88,14 +85,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace add_n
|
||||
|
||||
TfLiteRegistration* Register_ADD_N() {
|
||||
static TfLiteRegistration r = {/*init*/ nullptr, /*free*/ nullptr,
|
||||
add_n::Prepare, add_n::Eval};
|
||||
return &r;
|
||||
}
|
||||
TfLiteRegistration* Register_ADD_N() { return nullptr; }
|
||||
|
||||
} // namespace builtin
|
||||
} // namespace micro
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
||||
|
@ -12,59 +12,25 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <stdint.h>
|
||||
|
||||
#include <vector>
|
||||
#include <type_traits>
|
||||
|
||||
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
|
||||
#include "tensorflow/lite/micro/test_helpers.h"
|
||||
#include "tensorflow/lite/micro/testing/micro_test.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
namespace testing {
|
||||
namespace {} // namespace
|
||||
} // namespace testing
|
||||
} // namespace tflite
|
||||
|
||||
using ::testing::ElementsAreArray;
|
||||
TF_LITE_MICRO_TESTS_BEGIN
|
||||
|
||||
class BaseAddNOpModel : public SingleOpModel {
|
||||
public:
|
||||
BaseAddNOpModel(const std::vector<TensorData>& inputs,
|
||||
const TensorData& output) {
|
||||
int num_inputs = inputs.size();
|
||||
std::vector<std::vector<int>> input_shapes;
|
||||
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
inputs_.push_back(AddInput(inputs[i]));
|
||||
input_shapes.push_back(GetShape(inputs_[i]));
|
||||
}
|
||||
|
||||
output_ = AddOutput(output);
|
||||
SetBuiltinOp(BuiltinOperator_ADD_N, BuiltinOptions_AddNOptions,
|
||||
CreateAddNOptions(builder_).Union());
|
||||
BuildInterpreter(input_shapes);
|
||||
}
|
||||
|
||||
int input(int i) { return inputs_[i]; }
|
||||
|
||||
protected:
|
||||
std::vector<int> inputs_;
|
||||
int output_;
|
||||
};
|
||||
|
||||
class FloatAddNOpModel : public BaseAddNOpModel {
|
||||
public:
|
||||
using BaseAddNOpModel::BaseAddNOpModel;
|
||||
|
||||
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
|
||||
};
|
||||
|
||||
class IntegerAddNOpModel : public BaseAddNOpModel {
|
||||
public:
|
||||
using BaseAddNOpModel::BaseAddNOpModel;
|
||||
|
||||
std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); }
|
||||
};
|
||||
|
||||
TEST(FloatAddNOpModel, AddMultipleTensors) {
|
||||
TF_LITE_MICRO_TEST(FloatAddNOpAddMultipleTensors) {
|
||||
#ifdef notdef
|
||||
FloatAddNOpModel m({{TensorType_FLOAT32, {1, 2, 2, 1}},
|
||||
{TensorType_FLOAT32, {1, 2, 2, 1}},
|
||||
{TensorType_FLOAT32, {1, 2, 2, 1}}},
|
||||
@ -72,11 +38,12 @@ TEST(FloatAddNOpModel, AddMultipleTensors) {
|
||||
m.PopulateTensor<float>(m.input(0), {-2.0, 0.2, 0.7, 0.8});
|
||||
m.PopulateTensor<float>(m.input(1), {0.1, 0.2, 0.3, 0.5});
|
||||
m.PopulateTensor<float>(m.input(2), {0.5, 0.1, 0.1, 0.2});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.4, 0.5, 1.1, 1.5}));
|
||||
#endif // notdef
|
||||
}
|
||||
|
||||
TEST(IntegerAddNOpModel, AddMultipleTensors) {
|
||||
TF_LITE_MICRO_TEST(IntegerAddNOpAddMultipleTensors) {
|
||||
#ifdef notdef
|
||||
IntegerAddNOpModel m({{TensorType_INT32, {1, 2, 2, 1}},
|
||||
{TensorType_INT32, {1, 2, 2, 1}},
|
||||
{TensorType_INT32, {1, 2, 2, 1}}},
|
||||
@ -84,9 +51,8 @@ TEST(IntegerAddNOpModel, AddMultipleTensors) {
|
||||
m.PopulateTensor<int32_t>(m.input(0), {-20, 2, 7, 8});
|
||||
m.PopulateTensor<int32_t>(m.input(1), {1, 2, 3, 5});
|
||||
m.PopulateTensor<int32_t>(m.input(2), {10, -5, 1, -2});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({-9, -1, 11, 11}));
|
||||
#endif // notdef
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
TF_LITE_MICRO_TESTS_END
|
||||
|
Loading…
Reference in New Issue
Block a user