the ErrorReporter for the FlatBufferModel will be used for the generated Interpreter(s). PiperOrigin-RevId: 351809813 Change-Id: I03f0e524638b6ed0d5bbb5527a376aab2e366e2e
106 lines
4.7 KiB
C++
106 lines
4.7 KiB
C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
/// \file
|
|
/// Provides functionality to construct an interpreter for a model.
|
|
///
|
|
#ifndef TENSORFLOW_LITE_INTERPRETER_BUILDER_H_
|
|
#define TENSORFLOW_LITE_INTERPRETER_BUILDER_H_
|
|
|
|
#include <memory>
#include <vector>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/model_builder.h"
#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"
|
|
|
|
namespace tflite {
|
|
|
|
/// Build an interpreter capable of interpreting `model`.
|
|
///
|
|
/// `model`: A model whose lifetime must be at least as long as any
|
|
/// interpreter(s) created by the builder. In principle multiple interpreters
|
|
/// can be made from a single model.
|
|
/// `op_resolver`: An instance that implements the `OpResolver` interface, which
|
|
/// maps custom op names and builtin op codes to op registrations. The
|
|
/// lifetime of the provided `op_resolver` object must be at least as long as
|
|
/// the `InterpreterBuilder`; unlike `model` and `error_reporter`, the
|
|
/// `op_resolver` does not need to exist for the duration of any created
|
|
/// `Interpreter` objects.
|
|
/// `error_reporter`: a functor that is called to report errors that handles
|
|
/// printf var arg semantics. The lifetime of the `error_reporter` object must
|
|
/// be greater than or equal to the `Interpreter` created by `operator()`.
|
|
///
|
|
/// Returns a kTfLiteOk when successful and sets interpreter to a valid
|
|
/// Interpreter. Note: The user must ensure the lifetime of the model (and error
|
|
/// reporter, if provided) is at least as long as interpreter's lifetime.
|
|
class InterpreterBuilder {
|
|
public:
|
|
/// For this constructor, the ErrorReporter will be extracted from the
|
|
/// FlatBufferModel.
|
|
InterpreterBuilder(const FlatBufferModel& model,
|
|
const OpResolver& op_resolver);
|
|
/// Builds an interpreter given only the raw flatbuffer Model object (instead
|
|
/// of a FlatBufferModel). Mostly used for testing.
|
|
/// If `error_reporter` is null, then DefaultErrorReporter() is used.
|
|
InterpreterBuilder(const ::tflite::Model* model,
|
|
const OpResolver& op_resolver,
|
|
ErrorReporter* error_reporter = DefaultErrorReporter());
|
|
~InterpreterBuilder();
|
|
InterpreterBuilder(const InterpreterBuilder&) = delete;
|
|
InterpreterBuilder& operator=(const InterpreterBuilder&) = delete;
|
|
TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter);
|
|
TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter,
|
|
int num_threads);
|
|
|
|
private:
|
|
TfLiteStatus BuildLocalIndexToRegistrationMapping();
|
|
TfLiteStatus ParseNodes(
|
|
const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
|
|
Subgraph* subgraph);
|
|
TfLiteStatus ParseTensors(
|
|
const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
|
|
const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
|
|
Subgraph* subgraph);
|
|
TfLiteStatus ApplyDelegates(Interpreter* interpreter, int num_threads);
|
|
TfLiteStatus ParseQuantization(const QuantizationParameters* src_quantization,
|
|
TfLiteQuantization* quantization,
|
|
const std::vector<int>& dims);
|
|
TfLiteStatus ParseSparsity(const SparsityParameters* src_sparsity,
|
|
TfLiteSparsity** sparsity);
|
|
TfLiteStatus ParseSignatureDefs(
|
|
const flatbuffers::Vector<flatbuffers::Offset<SignatureDef>>*
|
|
signature_def_list,
|
|
Interpreter* interpreter);
|
|
|
|
const ::tflite::Model* model_;
|
|
const OpResolver& op_resolver_;
|
|
ErrorReporter* error_reporter_;
|
|
|
|
std::vector<const TfLiteRegistration*> flatbuffer_op_index_to_registration_;
|
|
std::vector<TfLiteRegistration> unresolved_custom_ops_;
|
|
std::vector<BuiltinOperator> flatbuffer_op_index_to_registration_types_;
|
|
const Allocation* allocation_ = nullptr;
|
|
|
|
bool has_flex_op_ = false;
|
|
int num_fp32_tensors_ = 0;
|
|
};
|
|
|
|
} // namespace tflite
|
|
|
|
#endif // TENSORFLOW_LITE_INTERPRETER_BUILDER_H_
|