diff --git a/tensorflow/lite/micro/micro_interpreter.cc b/tensorflow/lite/micro/micro_interpreter.cc
index a9286e88a27..f6f8127f467 100644
--- a/tensorflow/lite/micro/micro_interpreter.cc
+++ b/tensorflow/lite/micro/micro_interpreter.cc
@@ -52,7 +52,8 @@ MicroInterpreter::MicroInterpreter(const Model* model,
       error_reporter_(error_reporter),
       allocator_(&context_, model_, tensor_arena, tensor_arena_size,
                  error_reporter_),
-      tensors_allocated_(false) {
+      tensors_allocated_(false),
+      tensors_prepared_(false) {
   const flatbuffers::Vector<flatbuffers::Offset<SubGraph>>* subgraphs =
       model->subgraphs();
   if (subgraphs->size() != 1) {
@@ -155,24 +156,30 @@ TfLiteStatus MicroInterpreter::Invoke() {
       init_data = reinterpret_cast<const char*>(node->builtin_data);
       init_data_size = 0;
     }
-    if (registration->init) {
+    if (!tensors_prepared_ && registration->init) {
       node->user_data =
           registration->init(&context_, init_data, init_data_size);
     }
   }
 
-  for (size_t i = 0; i < operators_->size(); ++i) {
-    auto* node = &(node_and_registrations_[i].node);
-    auto* registration = node_and_registrations_[i].registration;
-    if (registration->prepare) {
-      TfLiteStatus prepare_status = registration->prepare(&context_, node);
-      if (prepare_status != kTfLiteOk) {
-        error_reporter_->Report(
-            "Node %s (number %d) failed to prepare with status %d",
-            OpNameFromRegistration(registration), i, prepare_status);
-        return kTfLiteError;
+  if (!tensors_prepared_) {
+    for (size_t i = 0; i < operators_->size(); ++i) {
+      auto* node = &(node_and_registrations_[i].node);
+      auto* registration = node_and_registrations_[i].registration;
+      if (registration->prepare) {
+        TfLiteStatus prepare_status = registration->prepare(&context_, node);
+        if (prepare_status != kTfLiteOk) {
+          error_reporter_->Report(
+              "Node %s (number %d) failed to prepare with status %d",
+              OpNameFromRegistration(registration), i, prepare_status);
+          return kTfLiteError;
+        }
       }
     }
+#ifdef TF_LITE_MICRO_TENSORS_PREPARED
+    // TODO(b/148085107): Turn this value on by default.
+    tensors_prepared_ = true;
+#endif
   }
 
   for (size_t i = 0; i < operators_->size(); ++i) {
diff --git a/tensorflow/lite/micro/micro_interpreter.h b/tensorflow/lite/micro/micro_interpreter.h
index e7d0c897c8b..941960a5116 100644
--- a/tensorflow/lite/micro/micro_interpreter.h
+++ b/tensorflow/lite/micro/micro_interpreter.h
@@ -117,6 +117,7 @@ class MicroInterpreter {
   TfLiteContext context_ = {};
   MicroAllocator allocator_;
   bool tensors_allocated_;
+  bool tensors_prepared_;
 
   TfLiteStatus initialization_status_;
   const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors_;
diff --git a/tensorflow/lite/micro/tools/make/targets/xtensa_xpg_makefile.inc b/tensorflow/lite/micro/tools/make/targets/xtensa_xpg_makefile.inc
index fee3855ba6c..22b013a7dfe 100644
--- a/tensorflow/lite/micro/tools/make/targets/xtensa_xpg_makefile.inc
+++ b/tensorflow/lite/micro/tools/make/targets/xtensa_xpg_makefile.inc
@@ -7,6 +7,7 @@ ifeq ($(TARGET), xtensa-xpg)
   TARGET_ARCH := xtensa-xpg
 
   PLATFORM_ARGS = \
+    -DTF_LITE_MICRO_TENSORS_PREPARED \
    -DTF_LITE_STATIC_MEMORY \
    -DNDEBUG \
    -DTF_LITE_MCU_DEBUG_LOG \
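
Usage sketch of the intended effect (not part of the patch; RunTwice, model_data, and kArenaSize are illustrative names, and op registrations are elided):

#include <cstddef>

#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"

constexpr size_t kArenaSize = 16 * 1024;  // assumption: large enough for the model
uint8_t tensor_arena[kArenaSize];

TfLiteStatus RunTwice(const void* model_data) {
  tflite::MicroErrorReporter error_reporter;
  const tflite::Model* model = tflite::GetModel(model_data);
  tflite::MicroMutableOpResolver resolver;  // register the model's ops here
  tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
                                       kArenaSize, &error_reporter);
  if (interpreter.AllocateTensors() != kTfLiteOk) return kTfLiteError;

  // First Invoke(): each registration's init() and prepare() run, as before.
  if (interpreter.Invoke() != kTfLiteOk) return kTfLiteError;

  // Second Invoke(): in a build with -DTF_LITE_MICRO_TENSORS_PREPARED,
  // tensors_prepared_ is now true, so init() and prepare() are skipped and
  // only each op's invoke() runs. Without the define, behavior is unchanged
  // and both phases run again on every call.
  return interpreter.Invoke();
}

Gating the skip behind a compile-time define keeps the default behavior identical while letting individual targets (here, xtensa-xpg via PLATFORM_ARGS) opt in until b/148085107 turns it on by default.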