TFLM: Move Init and Prepare into initialization so that they're only run once.

Also move free into destructor.

PiperOrigin-RevId: 289656924
Change-Id: Ib33496cd4a74f3e871d8cf0541b1f34afec72de6
This commit is contained in:
Tiezhen WANG 2020-01-14 08:28:37 -08:00 committed by TensorFlower Gardener
parent b32ee7b120
commit d8da788574
3 changed files with 63 additions and 54 deletions

View File

@ -83,6 +83,16 @@ MicroInterpreter::MicroInterpreter(const Model* model,
initialization_status_ = kTfLiteOk;
}
// Destructor: walks every operator in the model and invokes its registered
// `free` hook so per-node resources acquired in `init` are released exactly
// once (this loop was moved here out of Invoke(), per the commit message).
MicroInterpreter::~MicroInterpreter() {
for (size_t i = 0; i < operators_->size(); ++i) {
auto* node = &(node_and_registrations_[i].node);
auto* registration = node_and_registrations_[i].registration;
// Not every op registers a free callback; skip the ones that don't.
if (registration->free) {
registration->free(&context_, node->user_data);
}
}
// NOTE(review): node_and_registrations_ appears to be populated inside
// AllocateTensors() (see the InitializeRuntimeTensorVariables hunk). If the
// destructor can run before tensors were ever allocated, this loop may read
// uninitialized entries -- confirm against the class invariants.
}
void MicroInterpreter::CorrectTensorEndianness(TfLiteTensor* tensorCorr) {
int32_t tensorSize = 1;
for (int d = 0; d < tensorCorr->dims->size; ++d)
@ -125,22 +135,6 @@ TfLiteStatus MicroInterpreter::AllocateTensors() {
op_resolver_, &node_and_registrations_));
TF_LITE_ENSURE_OK(&context_, allocator_.FinishTensorAllocation());
tensors_allocated_ = true;
return kTfLiteOk;
}
TfLiteStatus MicroInterpreter::Invoke() {
if (initialization_status_ != kTfLiteOk) {
error_reporter_->Report("Invoke() called after initialization failed\n");
return kTfLiteError;
}
// Ensure tensors are allocated before the interpreter is invoked to avoid
// difficult to debug segfaults.
if (!tensors_allocated_) {
AllocateTensors();
}
// Init method is not yet implemented.
for (size_t i = 0; i < operators_->size(); ++i) {
auto* node = &(node_and_registrations_[i].node);
@ -174,6 +168,22 @@ TfLiteStatus MicroInterpreter::Invoke() {
}
}
tensors_allocated_ = true;
return kTfLiteOk;
}
TfLiteStatus MicroInterpreter::Invoke() {
if (initialization_status_ != kTfLiteOk) {
error_reporter_->Report("Invoke() called after initialization failed\n");
return kTfLiteError;
}
// Ensure tensors are allocated before the interpreter is invoked to avoid
// difficult to debug segfaults.
if (!tensors_allocated_) {
AllocateTensors();
}
for (size_t i = 0; i < operators_->size(); ++i) {
auto* node = &(node_and_registrations_[i].node);
auto* registration = node_and_registrations_[i].registration;
@ -188,16 +198,6 @@ TfLiteStatus MicroInterpreter::Invoke() {
}
}
}
// This is actually a no-op.
// TODO(wangtz): Consider removing this code to slightly reduce binary size.
for (size_t i = 0; i < operators_->size(); ++i) {
auto* node = &(node_and_registrations_[i].node);
auto* registration = node_and_registrations_[i].registration;
if (registration->free) {
registration->free(&context_, node->user_data);
}
}
return kTfLiteOk;
}

View File

@ -38,6 +38,7 @@ class MicroInterpreter {
MicroInterpreter(const Model* model, const OpResolver& op_resolver,
uint8_t* tensor_arena, size_t tensor_arena_size,
ErrorReporter* error_reporter);
~MicroInterpreter();
// Runs through the model and allocates all necessary input, output and
// intermediate tensors.

View File

@ -21,6 +21,7 @@ limitations under the License.
namespace tflite {
namespace {
void* MockInit(TfLiteContext* context, const char* buffer, size_t length) {
// We don't support delegate in TFL micro. This is a weak check to test if
// context struct being zero-initialized.
@ -30,9 +31,8 @@ void* MockInit(TfLiteContext* context, const char* buffer, size_t length) {
return nullptr;
}
// Mock `free` hook for the test op resolver: intentionally a no-op.
void MockFree(TfLiteContext* context, void* buffer) {}
// Records whether MockFree has been called. Reset to false at the start of
// the TestInterpreter test and checked after the interpreter is destroyed,
// so the test can verify the destructor runs the ops' free hooks.
bool freed = false;
// Mock `free` hook: flips the tracking flag instead of releasing anything.
void MockFree(TfLiteContext* context, void* buffer) { freed = true; }
TfLiteStatus MockPrepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
@ -72,40 +72,48 @@ class MockOpResolver : public OpResolver {
TF_LITE_MICRO_TESTS_BEGIN
// End-to-end interpreter test on the simple mock model: allocates tensors,
// invokes, checks the mock 21 -> 42 computation, and verifies the ops'
// `free` hook fires only when the interpreter is destroyed (via the
// tflite::freed flag).
//
// NOTE(review): this span looks like a diff rendered without +/- markers --
// the statements before the scoped block duplicate the checks inside it
// (two `interpreter` constructions, repeated input assertions). Confirm
// against the original patch before treating this as one compilable function.
TF_LITE_MICRO_TEST(TestInterpreter) {
// Reset the MockFree tracking flag so this test observes only its own run.
tflite::freed = false;
const tflite::Model* model = tflite::testing::GetSimpleMockModel();
TF_LITE_MICRO_EXPECT_NE(nullptr, model);
tflite::MockOpResolver mock_resolver;
// 1 KiB arena backs all tensor allocations for the mock model.
constexpr size_t allocator_buffer_size = 1024;
uint8_t allocator_buffer[allocator_buffer_size];
tflite::MicroInterpreter interpreter(model, mock_resolver, allocator_buffer,
allocator_buffer_size,
micro_test::reporter);
TF_LITE_MICRO_EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
TF_LITE_MICRO_EXPECT_EQ(1, interpreter.inputs_size());
TF_LITE_MICRO_EXPECT_EQ(1, interpreter.outputs_size());
// Validate the single int32 input tensor: rank 1, one element, 4 bytes.
TfLiteTensor* input = interpreter.input(0);
TF_LITE_MICRO_EXPECT_NE(nullptr, input);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt32, input->type);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->size);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
TF_LITE_MICRO_EXPECT_EQ(4, input->bytes);
TF_LITE_MICRO_EXPECT_NE(nullptr, input->data.i32);
input->data.i32[0] = 21;
// Create a new scope so that we can test the destructor.
{
tflite::MicroInterpreter interpreter(model, mock_resolver, allocator_buffer,
allocator_buffer_size,
micro_test::reporter);
TF_LITE_MICRO_EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
TF_LITE_MICRO_EXPECT_EQ(1, interpreter.inputs_size());
TF_LITE_MICRO_EXPECT_EQ(1, interpreter.outputs_size());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.Invoke());
// Same single-int32-input validation as above, on the scoped interpreter.
TfLiteTensor* input = interpreter.input(0);
TF_LITE_MICRO_EXPECT_NE(nullptr, input);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt32, input->type);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->size);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
TF_LITE_MICRO_EXPECT_EQ(4, input->bytes);
TF_LITE_MICRO_EXPECT_NE(nullptr, input->data.i32);
input->data.i32[0] = 21;
// The mock op doubles the input: 21 in, 42 expected out.
TfLiteTensor* output = interpreter.output(0);
TF_LITE_MICRO_EXPECT_NE(nullptr, output);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt32, output->type);
TF_LITE_MICRO_EXPECT_EQ(1, output->dims->size);
TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]);
TF_LITE_MICRO_EXPECT_EQ(4, output->bytes);
TF_LITE_MICRO_EXPECT_NE(nullptr, output->data.i32);
TF_LITE_MICRO_EXPECT_EQ(42, output->data.i32[0]);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.Invoke());
// Just to make sure that this method works.
tflite::PrintInterpreterState(&interpreter);
TfLiteTensor* output = interpreter.output(0);
TF_LITE_MICRO_EXPECT_NE(nullptr, output);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt32, output->type);
TF_LITE_MICRO_EXPECT_EQ(1, output->dims->size);
TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]);
TF_LITE_MICRO_EXPECT_EQ(4, output->bytes);
TF_LITE_MICRO_EXPECT_NE(nullptr, output->data.i32);
TF_LITE_MICRO_EXPECT_EQ(42, output->data.i32[0]);
// Just to make sure that this method works.
tflite::PrintInterpreterState(&interpreter);
// While the interpreter is alive, MockFree must not have run yet.
TF_LITE_MICRO_EXPECT_EQ(tflite::freed, false);
}
// Leaving the scope destroyed the interpreter; MockFree should have fired.
TF_LITE_MICRO_EXPECT_EQ(tflite::freed, true);
}
TF_LITE_MICRO_TESTS_END