/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/micro/micro_interpreter.h"
|
|
|
|
#include "tensorflow/lite/c/c_api_internal.h"
|
|
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
|
|
#include "tensorflow/lite/experimental/micro/compatibility.h"
|
|
|
|
namespace tflite {
|
|
namespace {
|
|
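
// Fixed-size allocator handed to ParseOpData while builtin operator options
// are unpacked from the flatbuffer. Allocations are served from a small
// on-stack buffer, so no heap is needed; requests larger than
// kStackDataAllocatorSize fail by returning nullptr.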
const int kStackDataAllocatorSize = 128;
class StackDataAllocator : public BuiltinDataAllocator {
 public:
  void* Allocate(size_t size) override {
    if (size > kStackDataAllocatorSize) {
      return nullptr;
    } else {
      return data_;
    }
  }
  void Deallocate(void* data) override {
    // Do nothing.
  }

 private:
  uint8_t data_[kStackDataAllocatorSize];

  TF_LITE_REMOVE_VIRTUAL_DELETE
};

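// Returns a human-readable name for an operator, used in error messages: the
// registered custom name for custom ops, otherwise the builtin enum name.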
const char* OpNameFromRegistration(const TfLiteRegistration* registration) {
  if (registration->builtin_code == BuiltinOperator_CUSTOM) {
    return registration->custom_name;
  } else {
    return EnumNameBuiltinOperator(BuiltinOperator(registration->builtin_code));
  }
}

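// Bridges TfLiteContext::ReportError to the interpreter's ErrorReporter so
// that kernels can report errors through the context they are given.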
void ReportOpError(struct TfLiteContext* context, const char* format, ...) {
  MicroInterpreter* interpreter =
      static_cast<MicroInterpreter*>(context->impl_);
  va_list args;
  va_start(args, format);
  interpreter->error_reporter()->Report(format, args);
  va_end(args);
}

}  // namespace

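// Wires up the interpreter for a single-subgraph model: caches the subgraph's
// tensors and operators and points the TfLiteContext back at this object so
// kernels can report errors. Models with more than one subgraph are rejected
// and initialization_status_ is set to kTfLiteError.
//
// A minimal usage sketch (the resolver, reporter, and arena size below are
// assumptions for illustration, not part of this file):
//
//   tflite::MicroErrorReporter micro_error_reporter;
//   const tflite::Model* model = tflite::GetModel(model_data);
//   tflite::ops::micro::AllOpsResolver resolver;
//   uint8_t tensor_arena[2 * 1024];
//   tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
//                                        sizeof(tensor_arena),
//                                        &micro_error_reporter);
//   interpreter.Invoke();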
MicroInterpreter::MicroInterpreter(const Model* model,
                                   const OpResolver& op_resolver,
                                   uint8_t* tensor_arena,
                                   size_t tensor_arena_size,
                                   ErrorReporter* error_reporter)
    : model_(model),
      op_resolver_(op_resolver),
      error_reporter_(error_reporter),
      context_(),
      allocator_(&context_, model_, tensor_arena, tensor_arena_size,
                 error_reporter_),
      tensors_allocated_(false) {
  auto* subgraphs = model->subgraphs();
  if (subgraphs->size() != 1) {
    error_reporter->Report("Only 1 subgraph is currently supported.\n");
    initialization_status_ = kTfLiteError;
    return;
  }
  subgraph_ = (*subgraphs)[0];
  tensors_ = subgraph_->tensors();
  operators_ = subgraph_->operators();

  context_.impl_ = static_cast<void*>(this);
  context_.ReportError = ReportOpError;
  context_.recommended_num_threads = 1;
  initialization_status_ = kTfLiteOk;
}

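// Forwards to the tensor allocator, registering an externally owned buffer
// for the input tensor at input_index.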
TfLiteStatus MicroInterpreter::RegisterPreallocatedInput(uint8_t* buffer,
                                                         size_t input_index) {
  return allocator_.RegisterPreallocatedInput(buffer, input_index);
}

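// Allocates storage for all tensors from the arena passed to the constructor
// and marks the interpreter ready to run.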
TfLiteStatus MicroInterpreter::AllocateTensors() {
  TfLiteStatus status = allocator_.AllocateTensors();
  TF_LITE_ENSURE_OK(&context_, status);
  tensors_allocated_ = true;
  return kTfLiteOk;
}

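// Runs the model once: walks the subgraph's operators in order and, for each
// one, resolves its registration, parses its builtin or custom options, then
// calls the kernel's init, prepare, invoke, and free entry points.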
TfLiteStatus MicroInterpreter::Invoke() {
  if (initialization_status_ != kTfLiteOk) {
    error_reporter_->Report("Invoke() called after initialization failed\n");
    return kTfLiteError;
  }

  // Ensure tensors are allocated before the interpreter is invoked to avoid
  // difficult to debug segfaults.
  if (!tensors_allocated_) {
    AllocateTensors();
  }
  TfLiteStatus status = kTfLiteOk;
  auto opcodes = model_->operator_codes();
  for (int i = 0; i < operators_->size(); ++i) {
    const auto* op = operators_->Get(i);
    int index = op->opcode_index();
    if (index < 0 || index >= opcodes->size()) {
      error_reporter_->Report("Missing registration for opcode_index %d\n",
                              index);
      return kTfLiteError;
    }
    auto opcode = (*opcodes)[index];
    const TfLiteRegistration* registration = nullptr;
    status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_,
                                       &registration);
    if (status != kTfLiteOk) {
      return status;
    }
    if (registration == nullptr) {
      error_reporter_->Report("Skipping op for opcode_index %d\n", index);
      return kTfLiteError;
    }
    BuiltinOperator op_type =
        static_cast<BuiltinOperator>(registration->builtin_code);

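    // Builtin operators carry their parameters in the flatbuffer's builtin
    // options table, so custom options on a builtin op are rejected. Custom
    // ops receive their raw custom_options buffer; builtin ops have their
    // options parsed into a C struct via ParseOpData, backed by the small
    // stack allocator above.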
    if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) {
      error_reporter_->Report(
          "Unsupported behavior: found builtin operator %s with custom "
          "options.\n",
          EnumNameBuiltinOperator(op_type));
      return kTfLiteError;
    }
    StackDataAllocator stack_data_allocator;
    const char* custom_data = nullptr;
    size_t custom_data_size = 0;
    unsigned char* builtin_data = nullptr;
    if (op->custom_options()) {
      custom_data = reinterpret_cast<const char*>(op->custom_options()->data());
      custom_data_size = op->custom_options()->size();
    } else {
      TF_LITE_ENSURE_STATUS(ParseOpData(op, op_type, error_reporter_,
                                        &stack_data_allocator,
                                        (void**)(&builtin_data)));
    }

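    // The init() callback receives the raw custom data for custom ops, or the
    // parsed builtin struct (with a reported size of 0) for builtin ops.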
    const char* init_data;
    size_t init_data_size;
    if (registration->builtin_code == BuiltinOperator_CUSTOM) {
      init_data = custom_data;
      init_data_size = custom_data_size;
    } else {
      init_data = reinterpret_cast<const char*>(builtin_data);
      init_data_size = 0;
    }
    void* user_data = nullptr;
    if (registration->init) {
      user_data = registration->init(&context_, init_data, init_data_size);
    }

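    // The flatbuffer's input and output index vectors are reused directly as
    // TfLiteIntArrays rather than copied, relying on both having a
    // length-prefixed array-of-int layout.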
    // Disregard the const qualifier to work around the existing API.
    TfLiteIntArray* inputs_array = const_cast<TfLiteIntArray*>(
        reinterpret_cast<const TfLiteIntArray*>(op->inputs()));
    TfLiteIntArray* outputs_array = const_cast<TfLiteIntArray*>(
        reinterpret_cast<const TfLiteIntArray*>(op->outputs()));

    const int kMaxTemporaries = 16;
    int temporaries_data[kMaxTemporaries + 1];
    TfLiteIntArray* temporaries_array =
        reinterpret_cast<TfLiteIntArray*>(temporaries_data);
    temporaries_array->size = 0;

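    // A temporary TfLiteNode is built on the stack for this operator; it only
    // needs to live for the prepare/invoke/free calls below.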
    TfLiteNode node;
    node.inputs = inputs_array;
    node.outputs = outputs_array;
    node.temporaries = temporaries_array;
    node.user_data = user_data;
    node.builtin_data = reinterpret_cast<void*>(builtin_data);
    node.custom_initial_data = custom_data;
    node.custom_initial_data_size = custom_data_size;
    node.delegate = nullptr;
    if (registration->prepare) {
      TfLiteStatus prepare_status = registration->prepare(&context_, &node);
      if (prepare_status != kTfLiteOk) {
        error_reporter_->Report(
            "Node %s (number %d) failed to prepare with status %d",
            OpNameFromRegistration(registration), i, prepare_status);
        return kTfLiteError;
      }
    }

    if (registration->invoke) {
      TfLiteStatus invoke_status = registration->invoke(&context_, &node);
      if (invoke_status != kTfLiteOk) {
        error_reporter_->Report(
            "Node %s (number %d) failed to invoke with status %d",
            OpNameFromRegistration(registration), i, invoke_status);
        return kTfLiteError;
      }
    }

    if (registration->free) {
      registration->free(&context_, user_data);
    }
  }
  return status;
}

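// Returns the tensor backing the subgraph input at the given position, or
// nullptr (after reporting an error) if the index is out of range.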
TfLiteTensor* MicroInterpreter::input(int index) {
  const flatbuffers::Vector<int32_t>* inputs = subgraph_->inputs();
  const size_t length = inputs->size();
  if ((index < 0) || (index >= static_cast<int>(length))) {
    error_reporter_->Report("Input index %d out of range (length is %d)",
                            index, static_cast<int>(length));
    return nullptr;
  }
  return &(context_.tensors[inputs->Get(index)]);
}

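// Returns the tensor backing the subgraph output at the given position, or
// nullptr (after reporting an error) if the index is out of range.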
TfLiteTensor* MicroInterpreter::output(int index) {
  const flatbuffers::Vector<int32_t>* outputs = subgraph_->outputs();
  const size_t length = outputs->size();
  if ((index < 0) || (index >= static_cast<int>(length))) {
    error_reporter_->Report("Output index %d out of range (length is %d)",
                            index, static_cast<int>(length));
    return nullptr;
  }
  return &(context_.tensors[outputs->Get(index)]);
}

}  // namespace tflite