/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/micro/micro_interpreter.h"
|
|
|
|
#include <cstdarg>
|
|
#include <cstddef>
|
|
#include <cstdint>
|
|
|
|
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
|
|
#include "tensorflow/lite/c/common.h"
|
|
#include "tensorflow/lite/core/api/error_reporter.h"
|
|
#include "tensorflow/lite/core/api/tensor_utils.h"
|
|
#include "tensorflow/lite/micro/micro_allocator.h"
|
|
#include "tensorflow/lite/micro/micro_op_resolver.h"
|
|
#include "tensorflow/lite/micro/micro_profiler.h"
|
|
#include "tensorflow/lite/schema/schema_generated.h"
|
|
|
|
namespace tflite {
namespace {

const char* OpNameFromRegistration(const TfLiteRegistration* registration) {
  if (registration->builtin_code == BuiltinOperator_CUSTOM) {
    return registration->custom_name;
  } else {
    return EnumNameBuiltinOperator(BuiltinOperator(registration->builtin_code));
  }
}

}  // namespace

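// The static ContextHelper methods below back the C-style function pointers on
// TfLiteContext: each one recovers the ContextHelper instance from
// context->impl_ and forwards the call to the MicroAllocator (or, for
// ReportOpError, to the ErrorReporter).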
namespace internal {

TfLiteStatus ContextHelper::AllocatePersistentBuffer(TfLiteContext* ctx,
                                                     size_t bytes, void** ptr) {
  return reinterpret_cast<ContextHelper*>(ctx->impl_)
      ->allocator_->AllocatePersistentBuffer(bytes, ptr);
}

TfLiteStatus ContextHelper::RequestScratchBufferInArena(TfLiteContext* ctx,
                                                        size_t bytes,
                                                        int* buffer_idx) {
  ContextHelper* helper = reinterpret_cast<ContextHelper*>(ctx->impl_);
  return helper->allocator_->RequestScratchBufferInArena(
      helper->current_node_idx_, bytes, buffer_idx);
}

void* ContextHelper::GetScratchBuffer(TfLiteContext* ctx, int buffer_idx) {
  return reinterpret_cast<ContextHelper*>(ctx->impl_)
      ->allocator_->GetScratchBuffer(buffer_idx);
}

void ContextHelper::ReportOpError(struct TfLiteContext* context,
                                  const char* format, ...) {
  ContextHelper* helper = static_cast<ContextHelper*>(context->impl_);
  va_list args;
  va_start(args, format);
  TF_LITE_REPORT_ERROR(helper->error_reporter_, format, args);
  va_end(args);
}

}  // namespace internal

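// Typical usage, as a minimal sketch (the model, op resolver, error reporter,
// and kTensorArenaSize below are placeholders supplied by the application,
// not defined in this file):
//
//   static uint8_t tensor_arena[kTensorArenaSize];
//   tflite::MicroInterpreter interpreter(model, op_resolver, tensor_arena,
//                                        kTensorArenaSize, error_reporter);
//   interpreter.AllocateTensors();           // runs every op's init + prepare
//   interpreter.input(0)->data.f[0] = 1.0f;  // fill the input tensor(s)
//   interpreter.Invoke();                    // runs every op's invoke
//   float y = interpreter.output(0)->data.f[0];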
MicroInterpreter::MicroInterpreter(const Model* model,
                                   const MicroOpResolver& op_resolver,
                                   uint8_t* tensor_arena,
                                   size_t tensor_arena_size,
                                   ErrorReporter* error_reporter,
                                   tflite::Profiler* profiler)
    : model_(model),
      op_resolver_(op_resolver),
      error_reporter_(error_reporter),
      allocator_(*MicroAllocator::Create(tensor_arena, tensor_arena_size,
                                         error_reporter)),
      tensors_allocated_(false),
      initialization_status_(kTfLiteError),
      context_helper_(error_reporter_, &allocator_) {
  Init(profiler);
}

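// This overload reuses a MicroAllocator supplied by the caller (for example
// one shared across interpreters) instead of creating one from a raw arena.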
MicroInterpreter::MicroInterpreter(const Model* model,
                                   const MicroOpResolver& op_resolver,
                                   MicroAllocator* allocator,
                                   ErrorReporter* error_reporter,
                                   tflite::Profiler* profiler)
    : model_(model),
      op_resolver_(op_resolver),
      error_reporter_(error_reporter),
      allocator_(*allocator),
      tensors_allocated_(false),
      initialization_status_(kTfLiteError),
      context_helper_(error_reporter_, &allocator_) {
  Init(profiler);
}

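// Gives every kernel whose registration provides a free() callback a chance
// to release the user_data it allocated in init().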
MicroInterpreter::~MicroInterpreter() {
  if (node_and_registrations_ != nullptr) {
    for (size_t i = 0; i < subgraph_->operators()->size(); ++i) {
      TfLiteNode* node = &(node_and_registrations_[i].node);
      const TfLiteRegistration* registration =
          node_and_registrations_[i].registration;
      // registration is allocated outside the interpreter, so double check to
      // make sure it's not nullptr;
      if (registration != nullptr && registration->free != nullptr) {
        registration->free(&context_, node->user_data);
      }
    }
  }
}

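// Caches the model's single subgraph and wires the error-reporting callback
// and optional profiler into the TfLiteContext; models with more than one
// subgraph are rejected.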
void MicroInterpreter::Init(tflite::Profiler* profiler) {
  const flatbuffers::Vector<flatbuffers::Offset<SubGraph>>* subgraphs =
      model_->subgraphs();
  if (subgraphs->size() != 1) {
    TF_LITE_REPORT_ERROR(error_reporter_,
                         "Only 1 subgraph is currently supported.\n");
    initialization_status_ = kTfLiteError;
    return;
  }
  subgraph_ = (*subgraphs)[0];

  context_.impl_ = static_cast<void*>(&context_helper_);
  context_.ReportError = context_helper_.ReportOpError;
  context_.recommended_num_threads = 1;
  context_.profiler = profiler;

  initialization_status_ = kTfLiteOk;
}

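// Flatbuffer weights are stored little-endian. On big-endian targets the
// read-only weight buffers are byte-swapped in place once at startup;
// flatbuffers::EndianScalar() is a no-op when host and file byte order
// already match.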
void MicroInterpreter::CorrectTensorEndianness(TfLiteTensor* tensorCorr) {
  int32_t tensorSize = 1;
  for (int d = 0; d < tensorCorr->dims->size; ++d)
    tensorSize *= reinterpret_cast<const int32_t*>(tensorCorr->dims->data)[d];

  switch (tensorCorr->type) {
    case TfLiteType::kTfLiteFloat32:
      CorrectTensorDataEndianness(tensorCorr->data.f, tensorSize);
      break;
    case TfLiteType::kTfLiteFloat16:
      CorrectTensorDataEndianness(tensorCorr->data.f16, tensorSize);
      break;
    case TfLiteType::kTfLiteInt64:
      CorrectTensorDataEndianness(tensorCorr->data.i64, tensorSize);
      break;
    case TfLiteType::kTfLiteInt32:
      CorrectTensorDataEndianness(tensorCorr->data.i32, tensorSize);
      break;
    case TfLiteType::kTfLiteInt16:
      CorrectTensorDataEndianness(tensorCorr->data.i16, tensorSize);
      break;
    case TfLiteType::kTfLiteComplex64:
      CorrectTensorDataEndianness(tensorCorr->data.c64, tensorSize);
      break;
    default:
      // Do nothing for other data types.
      break;
  }
}

template <class T>
void MicroInterpreter::CorrectTensorDataEndianness(T* data, int32_t size) {
  for (int32_t i = 0; i < size; ++i) {
    data[i] = flatbuffers::EndianScalar(data[i]);
  }
}

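// Walks the graph in two passes, init then prepare. The TfLiteContext exposes
// only the allocation callbacks that are legal in each phase:
// AllocatePersistentBuffer during init, additionally
// RequestScratchBufferInArena during prepare, and only GetScratchBuffer once
// preparation has finished.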
TfLiteStatus MicroInterpreter::AllocateTensors() {
  if (allocator_.StartModelAllocation(model_, &context_, op_resolver_,
                                      &node_and_registrations_) != kTfLiteOk) {
    TF_LITE_REPORT_ERROR(error_reporter_,
                         "Failed starting model allocation.\n");
    initialization_status_ = kTfLiteError;
    return kTfLiteError;
  }

  // If the system is big endian then convert weights from the flatbuffer from
  // little to big endian on startup so that it does not need to be done during
  // inference.
  // NOTE: This requires that the flatbuffer is held in memory which can be
  // modified by this process.
  if (!FLATBUFFERS_LITTLEENDIAN) {
    for (size_t t = 0; t < tensors_size(); ++t) {
      TfLiteTensor* thisTensor = &context_.tensors[t];
      if (thisTensor->allocation_type == kTfLiteMmapRo)
        CorrectTensorEndianness(thisTensor);
    }
  }

  // Only allow AllocatePersistentBuffer in Init stage.
  context_.AllocatePersistentBuffer = context_helper_.AllocatePersistentBuffer;
  context_.RequestScratchBufferInArena = nullptr;
  context_.GetScratchBuffer = nullptr;

  for (size_t i = 0; i < subgraph_->operators()->size(); ++i) {
    context_helper_.SetNodeIndex(i);
    auto* node = &(node_and_registrations_[i].node);
    auto* registration = node_and_registrations_[i].registration;
    size_t init_data_size;
    const char* init_data;
    if (registration->builtin_code == BuiltinOperator_CUSTOM) {
      init_data = reinterpret_cast<const char*>(node->custom_initial_data);
      init_data_size = node->custom_initial_data_size;
    } else {
      init_data = reinterpret_cast<const char*>(node->builtin_data);
      init_data_size = 0;
    }
    if (registration->init) {
      node->user_data =
          registration->init(&context_, init_data, init_data_size);
    }
  }
  context_helper_.SetNodeIndex(-1);

  // Both AllocatePersistentBuffer and RequestScratchBufferInArena are
  // available in the Prepare stage.
  context_.RequestScratchBufferInArena =
      context_helper_.RequestScratchBufferInArena;
  for (size_t i = 0; i < subgraph_->operators()->size(); ++i) {
    // Set node idx to annotate the lifetime for scratch buffers.
    context_helper_.SetNodeIndex(i);
    auto* node = &(node_and_registrations_[i].node);
    auto* registration = node_and_registrations_[i].registration;
    if (registration->prepare) {
      TfLiteStatus prepare_status = registration->prepare(&context_, node);
      if (prepare_status != kTfLiteOk) {
        TF_LITE_REPORT_ERROR(
            error_reporter_,
            "Node %s (number %d) failed to prepare with status %d",
            OpNameFromRegistration(registration), i, prepare_status);
        return kTfLiteError;
      }
    }
  }
  context_helper_.SetNodeIndex(-1);

  // Prepare is done, we're ready for Invoke. Memory allocation is no longer
  // allowed. Kernels can only fetch scratch buffers via GetScratchBuffer.
  context_.AllocatePersistentBuffer = nullptr;
  context_.RequestScratchBufferInArena = nullptr;
  context_.GetScratchBuffer = context_helper_.GetScratchBuffer;

  TF_LITE_ENSURE_OK(&context_,
                    allocator_.FinishModelAllocation(model_, &context_));
  tensors_allocated_ = true;
  return kTfLiteOk;
}

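// Runs every operator in execution order. If AllocateTensors() has not been
// called yet it is run implicitly first; in debug builds each op is wrapped
// in a ScopedOperatorProfile so a profiler attached to the context can time
// it.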
TfLiteStatus MicroInterpreter::Invoke() {
  if (initialization_status_ != kTfLiteOk) {
    TF_LITE_REPORT_ERROR(error_reporter_,
                         "Invoke() called after initialization failed\n");
    return kTfLiteError;
  }

  // Ensure tensors are allocated before the interpreter is invoked to avoid
  // difficult to debug segfaults.
  if (!tensors_allocated_) {
    TF_LITE_ENSURE_OK(&context_, AllocateTensors());
  }

  for (size_t i = 0; i < subgraph_->operators()->size(); ++i) {
    auto* node = &(node_and_registrations_[i].node);
    auto* registration = node_and_registrations_[i].registration;

    if (registration->invoke) {
      TfLiteStatus invoke_status;
#ifndef NDEBUG  // Omit profiler overhead from release builds.
      // The case where profiler == nullptr is handled by ScopedOperatorProfile.
      tflite::Profiler* profiler =
          reinterpret_cast<tflite::Profiler*>(context_.profiler);
      ScopedOperatorProfile scoped_profiler(
          profiler, OpNameFromRegistration(registration), i);
#endif
      invoke_status = registration->invoke(&context_, node);

      if (invoke_status == kTfLiteError) {
        TF_LITE_REPORT_ERROR(
            error_reporter_,
            "Node %s (number %d) failed to invoke with status %d",
            OpNameFromRegistration(registration), i, invoke_status);
        return kTfLiteError;
      } else if (invoke_status != kTfLiteOk) {
        return invoke_status;
      }
    }
  }
  return kTfLiteOk;
}

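// The accessors below bounds-check the requested index and return nullptr
// (after reporting an error) when it is out of range.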
TfLiteTensor* MicroInterpreter::input(size_t index) {
  const size_t length = inputs_size();
  if ((index < 0) || (index >= length)) {
    TF_LITE_REPORT_ERROR(error_reporter_,
                         "Input index %d out of range (length is %d)", index,
                         length);
    return nullptr;
  }
  return &(context_.tensors[inputs().Get(index)]);
}

TfLiteTensor* MicroInterpreter::output(size_t index) {
  const size_t length = outputs_size();
  if ((index < 0) || (index >= length)) {
    TF_LITE_REPORT_ERROR(error_reporter_,
                         "Output index %d out of range (length is %d)", index,
                         length);
    return nullptr;
  }
  return &(context_.tensors[outputs().Get(index)]);
}

TfLiteTensor* MicroInterpreter::tensor(size_t index) {
  const size_t length = tensors_size();
  if ((index < 0) || (index >= length)) {
    TF_LITE_REPORT_ERROR(error_reporter_,
                         "Tensor index %d out of range (length is %d)", index,
                         length);
    return nullptr;
  }
  return &context_.tensors[index];
}

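// Resets every tensor marked is_variable back to its default value via
// tflite::ResetVariableTensor, e.g. to clear recurrent state between
// independent inference sequences.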
TfLiteStatus MicroInterpreter::ResetVariableTensors() {
  const size_t length = tensors_size();
  for (size_t i = 0; i < length; ++i) {
    TfLiteTensor* cur_tensor = tensor(i);
    if (cur_tensor->is_variable) {
      TfLiteStatus status = tflite::ResetVariableTensor(cur_tensor);
      if (status != kTfLiteOk) {
        TF_LITE_REPORT_ERROR(error_reporter_,
                             "Failed to reset variable tensor at index: %d", i);
        return status;
      }
    }
  }
  return kTfLiteOk;
}

}  // namespace tflite