STT-tensorflow/tensorflow/lite/interpreter_builder.h
Fergus Henderson 172bbf2f2c Document that for an InterpreterBuilder constructed from a FlatBufferModel,
the ErrorReporter for the FlatBufferModel will be used for the
generated Interpreter(s).

PiperOrigin-RevId: 351809813
Change-Id: I03f0e524638b6ed0d5bbb5527a376aab2e366e2e
2021-01-14 09:03:33 -08:00

106 lines
4.7 KiB
C++

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
/// \file
/// Provides functionality to construct an interpreter for a model.
///
#ifndef TENSORFLOW_LITE_INTERPRETER_BUILDER_H_
#define TENSORFLOW_LITE_INTERPRETER_BUILDER_H_
#include <memory>
#include <vector>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/model_builder.h"
#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace tflite {
/// Builds interpreters capable of interpreting `model`.
///
/// Constructor arguments and their lifetime requirements:
///
/// `model`: A model whose lifetime must be at least as long as any
///   interpreter(s) created by the builder. In principle multiple interpreters
///   can be made from a single model.
/// `op_resolver`: An instance that implements the `OpResolver` interface, which
///   maps custom op names and builtin op codes to op registrations. The
///   lifetime of the provided `op_resolver` object must be at least as long as
///   the `InterpreterBuilder`; unlike `model` and `error_reporter`, the
///   `op_resolver` does not need to exist for the duration of any created
///   `Interpreter` objects.
/// `error_reporter`: a functor that is called to report errors and that
///   handles printf var arg semantics. The lifetime of the `error_reporter`
///   object must be greater than or equal to that of the `Interpreter` created
///   by `operator()`.
///
/// Note: The user must ensure the lifetime of the model (and error reporter,
/// if provided) is at least as long as the interpreter's lifetime.
class InterpreterBuilder {
 public:
  /// For this constructor, the ErrorReporter will be extracted from the
  /// FlatBufferModel.
  InterpreterBuilder(const FlatBufferModel& model,
                     const OpResolver& op_resolver);

  /// Builds an interpreter given only the raw flatbuffer Model object (instead
  /// of a FlatBufferModel). Mostly used for testing.
  /// If `error_reporter` is null, then DefaultErrorReporter() is used.
  InterpreterBuilder(const ::tflite::Model* model,
                     const OpResolver& op_resolver,
                     ErrorReporter* error_reporter = DefaultErrorReporter());

  ~InterpreterBuilder();

  /// Builders are not copyable.
  InterpreterBuilder(const InterpreterBuilder&) = delete;
  InterpreterBuilder& operator=(const InterpreterBuilder&) = delete;

  /// Builds an interpreter and stores it in `*interpreter`.
  ///
  /// Returns kTfLiteOk when successful, in which case `*interpreter` is set to
  /// a valid Interpreter.
  TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter);

  /// Same contract as the single-argument overload, with an additional
  /// `num_threads` argument.
  /// NOTE(review): `num_threads` is presumably the thread count the created
  /// interpreter should use -- confirm against the implementation.
  TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter,
                          int num_threads);

 private:
  // Implementation helpers. They are only declared here; each reports its
  // outcome as a TfLiteStatus.
  TfLiteStatus BuildLocalIndexToRegistrationMapping();
  TfLiteStatus ParseNodes(
      const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
      Subgraph* subgraph);
  TfLiteStatus ParseTensors(
      const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
      const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
      Subgraph* subgraph);
  TfLiteStatus ApplyDelegates(Interpreter* interpreter, int num_threads);
  TfLiteStatus ParseQuantization(const QuantizationParameters* src_quantization,
                                 TfLiteQuantization* quantization,
                                 const std::vector<int>& dims);
  TfLiteStatus ParseSparsity(const SparsityParameters* src_sparsity,
                             TfLiteSparsity** sparsity);
  TfLiteStatus ParseSignatureDefs(
      const flatbuffers::Vector<flatbuffers::Offset<SignatureDef>>*
          signature_def_list,
      Interpreter* interpreter);

  // The model, op resolver and error reporter the builder works with; see the
  // class comment for the lifetime each of them must satisfy.
  const ::tflite::Model* model_;
  const OpResolver& op_resolver_;
  ErrorReporter* error_reporter_;

  // Per-operator-code bookkeeping.
  // NOTE(review): presumably filled in by
  // BuildLocalIndexToRegistrationMapping() and indexed by the flatbuffer
  // operator-code index -- confirm against the implementation.
  std::vector<const TfLiteRegistration*> flatbuffer_op_index_to_registration_;
  std::vector<TfLiteRegistration> unresolved_custom_ops_;
  std::vector<BuiltinOperator> flatbuffer_op_index_to_registration_types_;

  // Defaults to null.
  // NOTE(review): presumably the allocation backing the model when the
  // builder is constructed from a FlatBufferModel -- confirm.
  const Allocation* allocation_ = nullptr;

  // Build-time state, initially false / zero.
  bool has_flex_op_ = false;
  int num_fp32_tensors_ = 0;
};
} // namespace tflite
#endif // TENSORFLOW_LITE_INTERPRETER_BUILDER_H_