/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/interpreter.h"
|
|
|
|
#include <cassert>
|
|
#include <cstdarg>
|
|
#include <cstdint>
|
|
#include <cstring>
|
|
#include <utility>
|
|
|
|
#include "tensorflow/lite/c/common.h"
|
|
#include "tensorflow/lite/context_util.h"
|
|
#include "tensorflow/lite/core/api/error_reporter.h"
|
|
#include "tensorflow/lite/delegates/status.h"
|
|
#include "tensorflow/lite/graph_info.h"
|
|
#include "tensorflow/lite/memory_planner.h"
|
|
#include "tensorflow/lite/minimal_logging.h"
|
|
#include "tensorflow/lite/schema/schema_generated.h"
|
|
#include "tensorflow/lite/util.h"
|
|
|
|
// TODO(b/139446230): Move to portable platform header.
#if defined(__ANDROID__)
#define TFLITE_IS_MOBILE_PLATFORM
#endif  // defined(__ANDROID__)

#if defined(__APPLE__)
#include "TargetConditionals.h"
#if TARGET_IPHONE_SIMULATOR
#define TFLITE_IS_MOBILE_PLATFORM
#elif TARGET_OS_IPHONE
#define TFLITE_IS_MOBILE_PLATFORM
#endif
#endif  // defined(__APPLE__)

// TODO(b/132087118): move static_assert to c_api_internal when compiled with
// C++.
static_assert(sizeof(TfLiteFloat16) == sizeof(uint16_t),
              "Float 16 type must be 16 bits.");

namespace tflite {

namespace impl {

namespace {

// Gets the current TfLiteQuantization from the legacy
// TfLiteQuantizationParams.
TfLiteQuantization GetQuantizationFromLegacy(
    const TfLiteQuantizationParams& legacy_quantization) {
  TfLiteQuantization quantization;
  quantization.type = kTfLiteAffineQuantization;
  auto* affine_quantization = reinterpret_cast<TfLiteAffineQuantization*>(
      malloc(sizeof(TfLiteAffineQuantization)));
  affine_quantization->scale = TfLiteFloatArrayCreate(1);
  affine_quantization->zero_point = TfLiteIntArrayCreate(1);
  affine_quantization->scale->data[0] = legacy_quantization.scale;
  affine_quantization->zero_point->data[0] = legacy_quantization.zero_point;
  quantization.params = affine_quantization;

  return quantization;
}

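// Example (editor's sketch, not from the original file): converting legacy
// quantization params with the helper above; the scale/zero_point values are
// hypothetical.
//
//   TfLiteQuantizationParams legacy;
//   legacy.scale = 0.5f;
//   legacy.zero_point = 127;
//   TfLiteQuantization q = GetQuantizationFromLegacy(legacy);
//   // q.type is kTfLiteAffineQuantization and q.params holds the malloc'd
//   // TfLiteAffineQuantization filled in above.
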
// TODO(b/153131797): We have put 'delegate_status' to 0 in the following macro
// temporarily because delegate-specific error codes are not retrievable at the
// moment; we will add them later.
#define TF_LITE_ENSURE_STATUS_WITH_SCOPED_INSTRUMENTATION(runtime_event, a) \
  do {                                                                      \
    TfLiteStatus status = (a);                                              \
    runtime_event.set_runtime_status(/*delegate_status=*/0,                 \
                                     static_cast<int64_t>(status));         \
    TF_LITE_ENSURE_STATUS(status);                                          \
  } while (0)
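
// Example (editor's sketch): how this macro is used later in this file, e.g.
// in Interpreter::Invoke():
//
//   ScopedRuntimeInstrumentationProfile scoped_runtime_event(
//       installed_profiler_, "invoke");
//   TF_LITE_ENSURE_STATUS_WITH_SCOPED_INSTRUMENTATION(
//       scoped_runtime_event, primary_subgraph().Invoke());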

}  // namespace

Interpreter::Interpreter(ErrorReporter* error_reporter)
    : error_reporter_(error_reporter ? error_reporter : DefaultErrorReporter()),
      lazy_delegate_provider_(
          TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {})) {
  // TODO(b/128420794): Include the TFLite runtime version in the log.
  // Prod logging is useful for mobile platforms where scraping console logs is
  // critical for debugging.
#if defined(TFLITE_IS_MOBILE_PLATFORM)
  TFLITE_LOG_PROD_ONCE(TFLITE_LOG_INFO, "Initialized TensorFlow Lite runtime.");
#else
  TFLITE_LOG_ONCE(TFLITE_LOG_INFO, "Initialized TensorFlow Lite runtime.");
#endif

  // There's always at least 1 subgraph, which is the primary subgraph.
  AddSubgraphs(1);
  context_ = primary_subgraph().context();

  // Initialize all external contexts to null.
  for (int i = 0; i < kTfLiteMaxExternalContexts; ++i) {
    external_contexts_[i] = nullptr;
  }

  // This operation is cheap because we allocate the CPU context resources
  // (i.e. threads) lazily.
  own_external_cpu_backend_context_.reset(new ExternalCpuBackendContext());
  external_contexts_[kTfLiteCpuBackendContext] =
      own_external_cpu_backend_context_.get();

  UseNNAPI(false);
}
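
// Example (editor's sketch): Interpreter is usually constructed indirectly
// via InterpreterBuilder rather than through this constructor;
// "model.tflite" is a hypothetical path.
//
//   auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
//   tflite::ops::builtin::BuiltinOpResolver resolver;
//   std::unique_ptr<tflite::Interpreter> interpreter;
//   tflite::InterpreterBuilder(*model, resolver)(&interpreter);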

Interpreter::~Interpreter() {
  // The owned external CPU backend context will go out of scope with this
  // interpreter. If we have an external backend context that is not
  // owned, we need to clear the cache for other interpreters that may
  // use the context.
  if (external_contexts_[kTfLiteCpuBackendContext] &&
      (external_contexts_[kTfLiteCpuBackendContext] !=
       own_external_cpu_backend_context_.get())) {
    ExternalCpuBackendContext* external_context =
        static_cast<ExternalCpuBackendContext*>(
            external_contexts_[kTfLiteCpuBackendContext]);
    TfLiteInternalBackendContext* internal_context =
        external_context->internal_backend_context();
    if (internal_context) {
      // This call may have negative performance impacts on the next inference
      // for any interpreter using this context. The cache will be refreshed
      // by the next inference.
      internal_context->ClearCaches();
    }
  }
}

void Interpreter::SetExternalContext(TfLiteExternalContextType type,
                                     TfLiteExternalContext* ctx) {
  if (ctx == own_external_cpu_backend_context_.get()) {
    error_reporter_->Report(
        "WARNING: The passed external context is identical to the internally "
        "owned one.");
    return;
  }

  // We have an internally owned external context of kTfLiteCpuBackendContext.
  // If it's overwritten here, we will release the resources of the internally
  // owned external context.
  // Note: the 'max thread count' info associated with the overwritten context
  // is lost here; that info is now determined by the new context, which
  // affects how much parallelism a TFLite op can use.
  if (kTfLiteCpuBackendContext == type &&
      external_contexts_[kTfLiteCpuBackendContext] ==
          own_external_cpu_backend_context_.get()) {
    own_external_cpu_backend_context_.reset();
  }

  // This essentially changes the "external_contexts_[type]".
  primary_subgraph().SetExternalContext(type, ctx);
}

TfLiteStatus Interpreter::SetInputs(std::vector<int> inputs) {
  return primary_subgraph().SetInputs(std::move(inputs));
}

TfLiteStatus Interpreter::SetOutputs(std::vector<int> outputs) {
  return primary_subgraph().SetOutputs(std::move(outputs));
}

TfLiteStatus Interpreter::SetVariables(std::vector<int> variables) {
  return primary_subgraph().SetVariables(std::move(variables));
}

TfLiteStatus Interpreter::AllocateTensors() {
  // Apply the default delegate that TFLite will enable at this point to allow
  // other user-level delegates to be applied first.
  if (lazy_delegate_provider_) {
    // The execution will fall back to the default implementation if the
    // XNNPACK delegate fails to be applied. Therefore, we ignore the return
    // status here and let it fall through to the rest of the code.
    ModifyGraphWithDelegate(std::move(lazy_delegate_provider_));
    lazy_delegate_provider_.reset();
  }

  return primary_subgraph().AllocateTensors();
}
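
// Example (editor's sketch): user-level delegates should be applied before
// AllocateTensors() so that they take precedence over the lazily applied
// default delegate; `my_delegate` is hypothetical.
//
//   interpreter->ModifyGraphWithDelegate(my_delegate);
//   interpreter->AllocateTensors();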

void Interpreter::ReserveNodes(int count) {
  primary_subgraph().ReserveNodes(count);
}

void Interpreter::AddSubgraphs(int subgraphs_to_add,
                               int* first_new_subgraph_index) {
  const size_t base_index = subgraphs_.size();
  if (first_new_subgraph_index) *first_new_subgraph_index = base_index;

  subgraphs_.reserve(base_index + subgraphs_to_add);
  for (int i = 0; i < subgraphs_to_add; ++i) {
    Subgraph* subgraph = new Subgraph(error_reporter_, external_contexts_,
                                      &subgraphs_, &resources_);
    subgraphs_.emplace_back(subgraph);
  }
}

TfLiteStatus Interpreter::AddNodeWithParameters(
    const std::vector<int>& inputs, const std::vector<int>& outputs,
    const char* init_data, size_t init_data_size, void* builtin_data,
    const TfLiteRegistration* registration, int* node_index) {
  return primary_subgraph().AddNodeWithParameters(
      inputs, outputs, {}, init_data, init_data_size, builtin_data,
      registration, node_index);
}
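
// Example (editor's sketch): manually adding a single node to a hand-built
// graph; `registration` would come from an op resolver and the tensor indices
// are hypothetical.
//
//   int node_index;
//   interpreter->AddNodeWithParameters(
//       /*inputs=*/{0, 1}, /*outputs=*/{2}, /*init_data=*/nullptr,
//       /*init_data_size=*/0, /*builtin_data=*/nullptr, registration,
//       &node_index);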

TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index,
                                            const std::vector<int>& dims) {
  return primary_subgraph().ResizeInputTensor(tensor_index, dims);
}

TfLiteStatus Interpreter::ResizeInputTensorStrict(
    int tensor_index, const std::vector<int>& dims) {
  return primary_subgraph().ResizeInputTensorStrict(tensor_index, dims);
}

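// Example (editor's sketch): resizing the first input to a hypothetical shape.
// AllocateTensors() must be called again before the next Invoke().
//
//   interpreter->ResizeInputTensor(interpreter->inputs()[0], {1, 224, 224, 3});
//   interpreter->AllocateTensors();
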
TfLiteStatus Interpreter::ReleaseNonPersistentMemory() {
  // TODO(b/138790287): We could do this for all subgraphs whose tensors have
  // been allocated. However, AllocateTensors() relies on Control Flow ops to
  // allocate tensors on 'children' subgraphs. Revisit this if required.
  return primary_subgraph().ReleaseNonPersistentMemory();
}

TfLiteStatus Interpreter::Invoke() {
  ScopedRuntimeInstrumentationProfile scoped_runtime_event(installed_profiler_,
                                                           "invoke");
  TF_LITE_ENSURE_STATUS_WITH_SCOPED_INSTRUMENTATION(
      scoped_runtime_event, primary_subgraph().Invoke());

  if (!allow_buffer_handle_output_) {
    for (int tensor_index : outputs()) {
      TF_LITE_ENSURE_STATUS_WITH_SCOPED_INSTRUMENTATION(
          scoped_runtime_event,
          primary_subgraph().EnsureTensorDataIsReadable(tensor_index));
    }
  }

  return kTfLiteOk;
}
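
// Example (editor's sketch): the canonical inference sequence around Invoke().
//
//   interpreter->AllocateTensors();
//   float* input = interpreter->typed_input_tensor<float>(0);
//   // ... fill `input` ...
//   if (interpreter->Invoke() == kTfLiteOk) {
//     const float* output = interpreter->typed_output_tensor<float>(0);
//     // ... read `output` ...
//   }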

TfLiteStatus Interpreter::AddTensors(int tensors_to_add,
                                     int* first_new_tensor_index) {
  return primary_subgraph().AddTensors(tensors_to_add, first_new_tensor_index);
}

TfLiteStatus Interpreter::ResetVariableTensors() {
  return primary_subgraph().ResetVariableTensors();
}

TfLiteStatus Interpreter::SetTensorParametersReadOnly(
    int tensor_index, TfLiteType type, const char* name,
    const std::vector<int>& dims, TfLiteQuantization quantization,
    const char* buffer, size_t bytes, const Allocation* allocation) {
  return primary_subgraph().SetTensorParametersReadOnly(
      tensor_index, type, name, dims.size(), dims.data(), quantization, buffer,
      bytes, allocation);
}

TfLiteStatus Interpreter::SetTensorParametersReadWrite(
    int tensor_index, TfLiteType type, const char* name,
    const std::vector<int>& dims, TfLiteQuantization quantization,
    bool is_variable) {
  return primary_subgraph().SetTensorParametersReadWrite(
      tensor_index, type, name, dims.size(), dims.data(), quantization,
      is_variable);
}

TfLiteStatus Interpreter::SetTensorParametersReadOnly(
    int tensor_index, TfLiteType type, const char* name, const size_t rank,
    const int* dims, TfLiteQuantizationParams quantization, const char* buffer,
    size_t bytes, const Allocation* allocation) {
  TfLiteQuantization new_quantization = GetQuantizationFromLegacy(quantization);
  return primary_subgraph().SetTensorParametersReadOnly(
      tensor_index, type, name, rank, dims, new_quantization, buffer, bytes,
      allocation);
}

TfLiteStatus Interpreter::SetTensorParametersReadWrite(
    int tensor_index, TfLiteType type, const char* name, const size_t rank,
    const int* dims, TfLiteQuantizationParams quantization, bool is_variable,
    const size_t rank_dims_signature, const int* dims_signature) {
  TfLiteQuantization new_quantization = GetQuantizationFromLegacy(quantization);
  return primary_subgraph().SetTensorParametersReadWrite(
      tensor_index, type, name, rank, dims, new_quantization, is_variable,
      rank_dims_signature, dims_signature);
}
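
// Example (editor's sketch): declaring a read/write float tensor on a manually
// built graph; the index, name, and shape are hypothetical.
//
//   interpreter->SetTensorParametersReadWrite(
//       /*tensor_index=*/0, kTfLiteFloat32, "input", {1, 4},
//       TfLiteQuantization());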

TfLiteStatus Interpreter::SetExecutionPlan(const std::vector<int>& new_plan) {
  return primary_subgraph().SetExecutionPlan(new_plan);
}

void Interpreter::UseNNAPI(bool enable) { primary_subgraph().UseNNAPI(enable); }

void Interpreter::SetNumThreads(int num_threads) {
  if (num_threads < -1) {
    context_->ReportError(context_,
                          "num_threads should be >=0 or just -1 to let TFLite "
                          "runtime set the value.");
    return;
  }

  for (auto& subgraph : subgraphs_) {
    subgraph->context()->recommended_num_threads = num_threads;
  }

  for (int i = 0; i < kTfLiteMaxExternalContexts; ++i) {
    auto* c = external_contexts_[i];
    if (c && c->Refresh) {
      c->Refresh(context_);
    }
  }
}
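
// Example (editor's sketch): four threads for CPU kernels; pass -1 to let the
// runtime choose.
//
//   interpreter->SetNumThreads(4);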

void Interpreter::SetAllowFp16PrecisionForFp32(bool allow) {
  for (auto& subgraph : subgraphs_) {
    subgraph->context()->allow_fp32_relax_to_fp16 = allow;
  }
}

// TODO(b/121264966): Subgraphs added after cancellation is set will not get
// the cancellation function added to their context.
void Interpreter::SetCancellationFunction(void* data,
                                          bool (*check_cancelled_func)(void*)) {
  for (auto& subgraph : subgraphs_) {
    subgraph->SetCancellationFunction(data, check_cancelled_func);
  }
}
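
// Example (editor's sketch): a cancellation hook driven by a caller-owned
// atomic flag; `cancelled` is hypothetical and must outlive the interpreter.
//
//   static std::atomic<bool> cancelled{false};
//   interpreter->SetCancellationFunction(&cancelled, [](void* data) {
//     return static_cast<std::atomic<bool>*>(data)->load();
//   });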

bool Interpreter::IsCancelled() { return primary_subgraph().IsCancelled(); }

TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
  TfLiteStatus status = kTfLiteOk;
  for (auto& subgraph : subgraphs_) {
    status = subgraph->ModifyGraphWithDelegate(delegate);
    if (status != kTfLiteOk) {
      break;
    }
  }
  // Delegate-specific errors can be recovered from by restoring the
  // Interpreter to its original state.
  if (status == kTfLiteDelegateError) {
    TF_LITE_ENSURE_STATUS(RemoveAllDelegates());
  }
  return status;
}
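
// Example (editor's sketch): applying a caller-owned delegate and recovering
// from a delegate-specific failure; `delegate` is hypothetical.
//
//   if (interpreter->ModifyGraphWithDelegate(delegate) ==
//       kTfLiteDelegateError) {
//     // All delegates were removed, so plain CPU execution still works.
//   }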

TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegatePtr delegate) {
  // Note that we retain ownership of the delegate even if graph modification
  // fails, as delegate use will be in an indeterminate state at that point.
  owned_delegates_.push_back(std::move(delegate));
  return ModifyGraphWithDelegate(owned_delegates_.back().get());
}

TfLiteStatus Interpreter::RemoveAllDelegates() {
  for (auto& subgraph : subgraphs_) {
    TF_LITE_ENSURE_STATUS(subgraph->RemoveAllDelegates());
  }
  return kTfLiteOk;
}

bool Interpreter::HasDelegates() { return primary_subgraph().HasDelegates(); }

TfLiteStatus Interpreter::SetBufferHandle(int tensor_index,
                                          TfLiteBufferHandle buffer_handle,
                                          TfLiteDelegate* delegate) {
  TF_LITE_ENSURE(context_, tensor_index < tensors_size());
  std::vector<TfLiteTensor>& tensors = primary_subgraph().tensors();
  TfLiteTensor* tensor = &tensors[tensor_index];

  TF_LITE_ENSURE(context_,
                 tensor->delegate == nullptr || tensor->delegate == delegate);
  tensor->delegate = delegate;
  if (tensor->buffer_handle != kTfLiteNullBufferHandle) {
    TF_LITE_ENSURE(context_, tensor->delegate->FreeBufferHandle != nullptr);
    tensor->delegate->FreeBufferHandle(context_, tensor->delegate,
                                       &tensor->buffer_handle);
  }
  tensor->buffer_handle = buffer_handle;

  return kTfLiteOk;
}

TfLiteStatus Interpreter::GetBufferHandle(int tensor_index,
                                          TfLiteBufferHandle* buffer_handle,
                                          TfLiteDelegate** delegate) {
  TF_LITE_ENSURE(context_, tensor_index < tensors_size());
  std::vector<TfLiteTensor>& tensors = primary_subgraph().tensors();
  TfLiteTensor* tensor = &tensors[tensor_index];

  *delegate = tensor->delegate;
  *buffer_handle = tensor->buffer_handle;

  return kTfLiteOk;
}
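
// Example (editor's sketch): reading back the delegate buffer handle bound to
// the first output tensor.
//
//   TfLiteBufferHandle handle;
//   TfLiteDelegate* delegate = nullptr;
//   interpreter->GetBufferHandle(interpreter->outputs()[0], &handle,
//                                &delegate);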

void Interpreter::SetProfiler(Profiler* profiler) {
  // Release resources occupied by owned_profiler_, which is replaced by the
  // caller-owned profiler.
  owned_profiler_.reset(nullptr);
  installed_profiler_ = profiler;
  SetSubgraphProfiler();
}

void Interpreter::SetProfiler(std::unique_ptr<Profiler> profiler) {
  owned_profiler_ = std::move(profiler);
  installed_profiler_ = owned_profiler_.get();
  SetSubgraphProfiler();
}

void Interpreter::SetSubgraphProfiler() {
  for (int subgraph_index = 0; subgraph_index < subgraphs_.size();
       ++subgraph_index) {
    subgraphs_[subgraph_index]->SetProfiler(installed_profiler_,
                                            subgraph_index);
  }
}

Profiler* Interpreter::GetProfiler() {
  return primary_subgraph().GetProfiler();
}

}  // namespace impl

}  // namespace tflite