Separate model.cc
into model_builder.cc
and interpreter_builder.cc
.
PiperOrigin-RevId: 302570223 Change-Id: I9b26c21dc7db0a1fb225986db20b1ed66f6fcc82
This commit is contained in:
parent
e918c6e6fa
commit
dcf212b2a5
@ -72,6 +72,8 @@ FRAMEWORK_LIB_HDRS = [
|
||||
"graph_info.h",
|
||||
"interpreter.h",
|
||||
"model.h",
|
||||
"model_builder.h",
|
||||
"interpreter_builder.h",
|
||||
"mutable_op_resolver.h",
|
||||
"op_resolver.h",
|
||||
"optional_debug_tools.h",
|
||||
@ -222,7 +224,8 @@ cc_library(
|
||||
"core/subgraph.cc",
|
||||
"graph_info.cc",
|
||||
"interpreter.cc",
|
||||
"model.cc",
|
||||
"interpreter_builder.cc",
|
||||
"model_builder.cc",
|
||||
"mutable_op_resolver.cc",
|
||||
"optional_debug_tools.cc",
|
||||
"stderr_reporter.cc",
|
||||
|
@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/interpreter_builder.h"
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <stdint.h>
|
||||
@ -37,6 +37,7 @@ limitations under the License.
|
||||
namespace tflite {
|
||||
|
||||
namespace {
|
||||
|
||||
// Ensure that ErrorReporter is non-null.
|
||||
ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
|
||||
return e ? e : DefaultErrorReporter();
|
||||
@ -91,6 +92,7 @@ TfLiteStatus ParseSparseIndexVector(const DimensionMetadata* src,
|
||||
}
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
const char* kEmptyTensorName = "";
|
||||
@ -114,162 +116,6 @@ __attribute__((weak)) Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() {
|
||||
Interpreter::TfLiteDelegatePtr (*AcquireFlexDelegate)() = nullptr;
|
||||
#endif
|
||||
|
||||
#ifndef TFLITE_MCU
|
||||
// Loads a model from `filename`. If `mmap_file` is true then use mmap,
|
||||
// otherwise make a copy of the model in a buffer.
|
||||
std::unique_ptr<Allocation> GetAllocationFromFile(const char* filename,
|
||||
bool mmap_file,
|
||||
ErrorReporter* error_reporter,
|
||||
bool use_nnapi) {
|
||||
std::unique_ptr<Allocation> allocation;
|
||||
if (mmap_file && MMAPAllocation::IsSupported()) {
|
||||
allocation.reset(new MMAPAllocation(filename, error_reporter));
|
||||
} else {
|
||||
allocation.reset(new FileCopyAllocation(filename, error_reporter));
|
||||
}
|
||||
return allocation;
|
||||
}
|
||||
|
||||
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
|
||||
const char* filename, ErrorReporter* error_reporter) {
|
||||
error_reporter = ValidateErrorReporter(error_reporter);
|
||||
|
||||
std::unique_ptr<FlatBufferModel> model;
|
||||
auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
|
||||
error_reporter, /*use_nnapi=*/true);
|
||||
model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
|
||||
if (!model->initialized()) model.reset();
|
||||
return model;
|
||||
}
|
||||
|
||||
std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile(
|
||||
const char* filename, TfLiteVerifier* extra_verifier,
|
||||
ErrorReporter* error_reporter) {
|
||||
error_reporter = ValidateErrorReporter(error_reporter);
|
||||
|
||||
std::unique_ptr<FlatBufferModel> model;
|
||||
auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
|
||||
error_reporter, /*use_nnapi=*/true);
|
||||
|
||||
flatbuffers::Verifier base_verifier(
|
||||
reinterpret_cast<const uint8_t*>(allocation->base()),
|
||||
allocation->bytes());
|
||||
if (!VerifyModelBuffer(base_verifier)) {
|
||||
TF_LITE_REPORT_ERROR(error_reporter,
|
||||
"The model is not a valid Flatbuffer file");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (extra_verifier &&
|
||||
!extra_verifier->Verify(static_cast<const char*>(allocation->base()),
|
||||
allocation->bytes(), error_reporter)) {
|
||||
return model;
|
||||
}
|
||||
model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
|
||||
if (!model->initialized()) model.reset();
|
||||
return model;
|
||||
}
|
||||
#endif
|
||||
|
||||
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
|
||||
const char* caller_owned_buffer, size_t buffer_size,
|
||||
ErrorReporter* error_reporter) {
|
||||
error_reporter = ValidateErrorReporter(error_reporter);
|
||||
|
||||
std::unique_ptr<FlatBufferModel> model;
|
||||
std::unique_ptr<Allocation> allocation(
|
||||
new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
|
||||
model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
|
||||
if (!model->initialized()) model.reset();
|
||||
return model;
|
||||
}
|
||||
|
||||
std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromBuffer(
|
||||
const char* caller_owned_buffer, size_t buffer_size,
|
||||
TfLiteVerifier* extra_verifier, ErrorReporter* error_reporter) {
|
||||
error_reporter = ValidateErrorReporter(error_reporter);
|
||||
|
||||
flatbuffers::Verifier base_verifier(
|
||||
reinterpret_cast<const uint8_t*>(caller_owned_buffer), buffer_size);
|
||||
if (!VerifyModelBuffer(base_verifier)) {
|
||||
TF_LITE_REPORT_ERROR(error_reporter,
|
||||
"The model is not a valid Flatbuffer buffer");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (extra_verifier && !extra_verifier->Verify(caller_owned_buffer,
|
||||
buffer_size, error_reporter)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return BuildFromBuffer(caller_owned_buffer, buffer_size, error_reporter);
|
||||
}
|
||||
|
||||
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
|
||||
const tflite::Model* caller_owned_model_spec,
|
||||
ErrorReporter* error_reporter) {
|
||||
error_reporter = ValidateErrorReporter(error_reporter);
|
||||
|
||||
std::unique_ptr<FlatBufferModel> model;
|
||||
model.reset(new FlatBufferModel(caller_owned_model_spec, error_reporter));
|
||||
if (!model->initialized()) model.reset();
|
||||
return model;
|
||||
}
|
||||
|
||||
string FlatBufferModel::GetMinimumRuntime() const {
|
||||
if (!model_ || !model_->metadata()) return "";
|
||||
|
||||
for (int i = 0; i < model_->metadata()->size(); ++i) {
|
||||
auto metadata = model_->metadata()->Get(i);
|
||||
if (metadata->name()->str() == "min_runtime_version") {
|
||||
auto buf = metadata->buffer();
|
||||
auto* buffer = (*model_->buffers())[buf];
|
||||
auto* array = buffer->data();
|
||||
// Get the real length of the runtime string, since there might be
|
||||
// trailing
|
||||
// '\0's in the buffer.
|
||||
for (int len = 0; len < array->size(); ++len) {
|
||||
if (array->data()[len] == '\0') {
|
||||
return string(reinterpret_cast<const char*>(array->data()), len);
|
||||
}
|
||||
}
|
||||
// If there is no '\0' in the buffer, this indicates that the flatbuffer
|
||||
// is malformed.
|
||||
TF_LITE_REPORT_ERROR(
|
||||
error_reporter_,
|
||||
"Min_runtime_version in model metadata is malformed");
|
||||
break;
|
||||
}
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
bool FlatBufferModel::CheckModelIdentifier() const {
|
||||
if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
|
||||
const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());
|
||||
error_reporter_->Report(
|
||||
"Model provided has model identifier '%c%c%c%c', should be '%s'\n",
|
||||
ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier());
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
FlatBufferModel::FlatBufferModel(const Model* model,
|
||||
ErrorReporter* error_reporter)
|
||||
: model_(model), error_reporter_(ValidateErrorReporter(error_reporter)) {}
|
||||
|
||||
FlatBufferModel::FlatBufferModel(std::unique_ptr<Allocation> allocation,
|
||||
ErrorReporter* error_reporter)
|
||||
: error_reporter_(ValidateErrorReporter(error_reporter)),
|
||||
allocation_(std::move(allocation)) {
|
||||
if (!allocation_->valid() || !CheckModelIdentifier()) return;
|
||||
|
||||
model_ = ::tflite::GetModel(allocation_->base());
|
||||
}
|
||||
|
||||
FlatBufferModel::~FlatBufferModel() {}
|
||||
|
||||
namespace impl {
|
||||
|
||||
InterpreterBuilder::InterpreterBuilder(const FlatBufferModel& model,
|
104
tensorflow/lite/interpreter_builder.h
Normal file
104
tensorflow/lite/interpreter_builder.h
Normal file
@ -0,0 +1,104 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
/// Deserialization infrastructure for tflite. Provides functionality
|
||||
/// to go from a serialized tflite model in flatbuffer format to an
|
||||
/// interpreter.
|
||||
///
|
||||
#ifndef TENSORFLOW_LITE_INTERPRETER_BUILDER_H_
|
||||
#define TENSORFLOW_LITE_INTERPRETER_BUILDER_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/model_builder.h"
|
||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
namespace impl {
|
||||
|
||||
/// Build an interpreter capable of interpreting `model`.
|
||||
///
|
||||
/// model: A model whose lifetime must be at least as long as any
|
||||
/// interpreter(s) created by the builder. In principle multiple interpreters
|
||||
/// can be made from a single model.
|
||||
/// op_resolver: An instance that implements the OpResolver interface, which
|
||||
/// maps
|
||||
/// custom op names and builtin op codes to op registrations. The lifetime
|
||||
/// of the provided `op_resolver` object must be at least as long as the
|
||||
/// InterpreterBuilder; unlike `model` and `error_reporter`, the `op_resolver`
|
||||
/// does not need to exist for the duration of any created Interpreter
|
||||
/// objects.
|
||||
/// error_reporter: a functor that is called to report errors that handles
|
||||
/// printf var arg semantics. The lifetime of the `error_reporter` object must
|
||||
/// be greater than or equal to the Interpreter created by operator().
|
||||
///
|
||||
/// Returns a kTfLiteOk when successful and sets interpreter to a valid
|
||||
/// Interpreter. Note: The user must ensure the model lifetime (and error
|
||||
/// reporter, if provided) is at least as long as interpreter's lifetime.
|
||||
class InterpreterBuilder {
|
||||
public:
|
||||
InterpreterBuilder(const FlatBufferModel& model,
|
||||
const OpResolver& op_resolver);
|
||||
/// Builds an interpreter given only the raw flatbuffer Model object (instead
|
||||
/// of a FlatBufferModel). Mostly used for testing.
|
||||
/// If `error_reporter` is null, then DefaultErrorReporter() is used.
|
||||
InterpreterBuilder(const ::tflite::Model* model,
|
||||
const OpResolver& op_resolver,
|
||||
ErrorReporter* error_reporter = DefaultErrorReporter());
|
||||
~InterpreterBuilder();
|
||||
InterpreterBuilder(const InterpreterBuilder&) = delete;
|
||||
InterpreterBuilder& operator=(const InterpreterBuilder&) = delete;
|
||||
TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter);
|
||||
TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter,
|
||||
int num_threads);
|
||||
|
||||
private:
|
||||
TfLiteStatus BuildLocalIndexToRegistrationMapping();
|
||||
TfLiteStatus ParseNodes(
|
||||
const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
|
||||
Subgraph* subgraph);
|
||||
TfLiteStatus ParseTensors(
|
||||
const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
|
||||
const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
|
||||
Subgraph* subgraph);
|
||||
TfLiteStatus ApplyDelegates(Interpreter* interpreter);
|
||||
TfLiteStatus ParseQuantization(const QuantizationParameters* src_quantization,
|
||||
TfLiteQuantization* quantization,
|
||||
const std::vector<int>& dims);
|
||||
TfLiteStatus ParseSparsity(const SparsityParameters* src_sparsity,
|
||||
TfLiteSparsity** sparsity);
|
||||
|
||||
const ::tflite::Model* model_;
|
||||
const OpResolver& op_resolver_;
|
||||
ErrorReporter* error_reporter_;
|
||||
|
||||
std::vector<const TfLiteRegistration*> flatbuffer_op_index_to_registration_;
|
||||
std::vector<TfLiteRegistration> unresolved_custom_ops_;
|
||||
std::vector<BuiltinOperator> flatbuffer_op_index_to_registration_types_;
|
||||
const Allocation* allocation_ = nullptr;
|
||||
|
||||
bool has_flex_op_ = false;
|
||||
};
|
||||
|
||||
} // namespace impl
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_INTERPRETER_BUILDER_H_
|
@ -19,229 +19,11 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_MODEL_H_
|
||||
#define TENSORFLOW_LITE_MODEL_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/interpreter_builder.h"
|
||||
#include "tensorflow/lite/model_builder.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
/// Abstract interface that verifies whether a given model is legit.
|
||||
/// It facilitates the use-case to verify and build a model without loading it
|
||||
/// twice.
|
||||
class TfLiteVerifier {
|
||||
public:
|
||||
/// Returns true if the model is legit.
|
||||
virtual bool Verify(const char* data, int length,
|
||||
ErrorReporter* reporter) = 0;
|
||||
virtual ~TfLiteVerifier() {}
|
||||
};
|
||||
|
||||
/// An RAII object that represents a read-only tflite model, copied from disk,
|
||||
/// or mmapped. This uses flatbuffers as the serialization format.
|
||||
///
|
||||
/// NOTE: The current API requires that a FlatBufferModel instance be kept alive
|
||||
/// by the client as long as it is in use by any dependent Interpreter
|
||||
/// instances.
|
||||
/// <pre><code>
|
||||
/// using namespace tflite;
|
||||
/// StderrReporter error_reporter;
|
||||
/// auto model = FlatBufferModel::BuildFromFile("interesting_model.tflite",
|
||||
/// &error_reporter);
|
||||
/// MyOpResolver resolver; // You need to subclass OpResolver to provide
|
||||
/// // implementations.
|
||||
/// InterpreterBuilder builder(*model, resolver);
|
||||
/// std::unique_ptr<Interpreter> interpreter;
|
||||
/// if(builder(&interpreter) == kTfLiteOk) {
|
||||
/// .. run model inference with interpreter
|
||||
/// }
|
||||
/// </code></pre>
|
||||
///
|
||||
/// OpResolver must be defined to provide your kernel implementations to the
|
||||
/// interpreter. This is environment specific and may consist of just the
|
||||
/// builtin ops, or some custom operators you defined to extend tflite.
|
||||
class FlatBufferModel {
|
||||
public:
|
||||
/// Builds a model based on a file.
|
||||
/// Caller retains ownership of `error_reporter` and must ensure its lifetime
|
||||
/// is longer than the FlatBufferModel instance.
|
||||
/// Returns a nullptr in case of failure.
|
||||
static std::unique_ptr<FlatBufferModel> BuildFromFile(
|
||||
const char* filename,
|
||||
ErrorReporter* error_reporter = DefaultErrorReporter());
|
||||
|
||||
/// Verifies whether the content of the file is legit, then builds a model
|
||||
/// based on the file.
|
||||
/// The extra_verifier argument is an additional optional verifier for the
|
||||
/// file contents. By default, we always check with tflite::VerifyModelBuffer.
|
||||
/// If extra_verifier is supplied, the file contents is also checked against
|
||||
/// the extra_verifier after the check against tflite::VerifyModelBuilder.
|
||||
/// Caller retains ownership of `error_reporter` and must ensure its lifetime
|
||||
/// is longer than the FlatBufferModel instance.
|
||||
/// Returns a nullptr in case of failure.
|
||||
static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromFile(
|
||||
const char* filename, TfLiteVerifier* extra_verifier = nullptr,
|
||||
ErrorReporter* error_reporter = DefaultErrorReporter());
|
||||
|
||||
/// Builds a model based on a pre-loaded flatbuffer.
|
||||
/// Caller retains ownership of the buffer and should keep it alive until
|
||||
/// the returned object is destroyed. Caller also retains ownership of
|
||||
/// `error_reporter` and must ensure its lifetime is longer than the
|
||||
/// FlatBufferModel instance.
|
||||
/// Returns a nullptr in case of failure.
|
||||
/// NOTE: this does NOT validate the buffer so it should NOT be called on
|
||||
/// invalid/untrusted input. Use VerifyAndBuildFromBuffer in that case
|
||||
static std::unique_ptr<FlatBufferModel> BuildFromBuffer(
|
||||
const char* caller_owned_buffer, size_t buffer_size,
|
||||
ErrorReporter* error_reporter = DefaultErrorReporter());
|
||||
|
||||
/// Verifies whether the content of the buffer is legit, then builds a model
|
||||
/// based on the pre-loaded flatbuffer.
|
||||
/// The extra_verifier argument is an additional optional verifier for the
|
||||
/// buffer. By default, we always check with tflite::VerifyModelBuffer. If
|
||||
/// extra_verifier is supplied, the buffer is checked against the
|
||||
/// extra_verifier after the check against tflite::VerifyModelBuilder. The
|
||||
/// caller retains ownership of the buffer and should keep it alive until the
|
||||
/// returned object is destroyed. Caller retains ownership of `error_reporter`
|
||||
/// and must ensure its lifetime is longer than the FlatBufferModel instance.
|
||||
/// Returns a nullptr in case of failure.
|
||||
static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromBuffer(
|
||||
const char* caller_owned_buffer, size_t buffer_size,
|
||||
TfLiteVerifier* extra_verifier = nullptr,
|
||||
ErrorReporter* error_reporter = DefaultErrorReporter());
|
||||
|
||||
/// Builds a model directly from a flatbuffer pointer
|
||||
/// Caller retains ownership of the buffer and should keep it alive until the
|
||||
/// returned object is destroyed. Caller retains ownership of `error_reporter`
|
||||
/// and must ensure its lifetime is longer than the FlatBufferModel instance.
|
||||
/// Returns a nullptr in case of failure.
|
||||
static std::unique_ptr<FlatBufferModel> BuildFromModel(
|
||||
const tflite::Model* caller_owned_model_spec,
|
||||
ErrorReporter* error_reporter = DefaultErrorReporter());
|
||||
|
||||
// Releases memory or unmaps mmaped memory.
|
||||
~FlatBufferModel();
|
||||
|
||||
// Copying or assignment is disallowed to simplify ownership semantics.
|
||||
FlatBufferModel(const FlatBufferModel&) = delete;
|
||||
FlatBufferModel& operator=(const FlatBufferModel&) = delete;
|
||||
|
||||
bool initialized() const { return model_ != nullptr; }
|
||||
const tflite::Model* operator->() const { return model_; }
|
||||
const tflite::Model* GetModel() const { return model_; }
|
||||
ErrorReporter* error_reporter() const { return error_reporter_; }
|
||||
const Allocation* allocation() const { return allocation_.get(); }
|
||||
|
||||
// Returns the minimum runtime version from the flatbuffer. This runtime
|
||||
// version encodes the minimum required interpreter version to run the
|
||||
// flatbuffer model. If the minimum version can't be determined, an empty
|
||||
// string will be returned.
|
||||
// Note that the returned minimum version is a lower-bound but not a strict
|
||||
// lower-bound; ops in the graph may not have an associated runtime version,
|
||||
// in which case the actual required runtime might be greater than the
|
||||
// reported minimum.
|
||||
string GetMinimumRuntime() const;
|
||||
|
||||
/// Returns true if the model identifier is correct (otherwise false and
|
||||
/// reports an error).
|
||||
bool CheckModelIdentifier() const;
|
||||
|
||||
private:
|
||||
/// Loads a model from a given allocation. FlatBufferModel will take over the
|
||||
/// ownership of `allocation`, and delete it in destructor. The ownership of
|
||||
/// `error_reporter`remains with the caller and must have lifetime at least
|
||||
/// as much as FlatBufferModel. This is to allow multiple models to use the
|
||||
/// same ErrorReporter instance.
|
||||
FlatBufferModel(std::unique_ptr<Allocation> allocation,
|
||||
ErrorReporter* error_reporter = DefaultErrorReporter());
|
||||
|
||||
/// Loads a model from Model flatbuffer. The `model` has to remain alive and
|
||||
/// unchanged until the end of this flatbuffermodel's lifetime.
|
||||
FlatBufferModel(const Model* model, ErrorReporter* error_reporter);
|
||||
|
||||
/// Flatbuffer traverser pointer. (Model* is a pointer that is within the
|
||||
/// allocated memory of the data allocated by allocation's internals.
|
||||
const tflite::Model* model_ = nullptr;
|
||||
/// The error reporter to use for model errors and subsequent errors when
|
||||
/// the interpreter is created
|
||||
ErrorReporter* error_reporter_;
|
||||
/// The allocator used for holding memory of the model. Note that this will
|
||||
/// be null if the client provides a tflite::Model directly.
|
||||
std::unique_ptr<Allocation> allocation_;
|
||||
};
|
||||
|
||||
namespace impl {
|
||||
|
||||
/// Build an interpreter capable of interpreting `model`.
|
||||
///
|
||||
/// model: A model whose lifetime must be at least as long as any
|
||||
/// interpreter(s) created by the builder. In principle multiple interpreters
|
||||
/// can be made from a single model.
|
||||
/// op_resolver: An instance that implements the OpResolver interface, which
|
||||
/// maps
|
||||
/// custom op names and builtin op codes to op registrations. The lifetime
|
||||
/// of the provided `op_resolver` object must be at least as long as the
|
||||
/// InterpreterBuilder; unlike `model` and `error_reporter`, the `op_resolver`
|
||||
/// does not need to exist for the duration of any created Interpreter
|
||||
/// objects.
|
||||
/// error_reporter: a functor that is called to report errors that handles
|
||||
/// printf var arg semantics. The lifetime of the `error_reporter` object must
|
||||
/// be greater than or equal to the Interpreter created by operator().
|
||||
///
|
||||
/// Returns a kTfLiteOk when successful and sets interpreter to a valid
|
||||
/// Interpreter. Note: The user must ensure the model lifetime (and error
|
||||
/// reporter, if provided) is at least as long as interpreter's lifetime.
|
||||
class InterpreterBuilder {
|
||||
public:
|
||||
InterpreterBuilder(const FlatBufferModel& model,
|
||||
const OpResolver& op_resolver);
|
||||
/// Builds an interpreter given only the raw flatbuffer Model object (instead
|
||||
/// of a FlatBufferModel). Mostly used for testing.
|
||||
/// If `error_reporter` is null, then DefaultErrorReporter() is used.
|
||||
InterpreterBuilder(const ::tflite::Model* model,
|
||||
const OpResolver& op_resolver,
|
||||
ErrorReporter* error_reporter = DefaultErrorReporter());
|
||||
~InterpreterBuilder();
|
||||
InterpreterBuilder(const InterpreterBuilder&) = delete;
|
||||
InterpreterBuilder& operator=(const InterpreterBuilder&) = delete;
|
||||
TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter);
|
||||
TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter,
|
||||
int num_threads);
|
||||
|
||||
private:
|
||||
TfLiteStatus BuildLocalIndexToRegistrationMapping();
|
||||
TfLiteStatus ParseNodes(
|
||||
const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
|
||||
Subgraph* subgraph);
|
||||
TfLiteStatus ParseTensors(
|
||||
const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
|
||||
const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
|
||||
Subgraph* subgraph);
|
||||
TfLiteStatus ApplyDelegates(Interpreter* interpreter);
|
||||
TfLiteStatus ParseQuantization(const QuantizationParameters* src_quantization,
|
||||
TfLiteQuantization* quantization,
|
||||
const std::vector<int>& dims);
|
||||
TfLiteStatus ParseSparsity(const SparsityParameters* src_sparsity,
|
||||
TfLiteSparsity** sparsity);
|
||||
|
||||
const ::tflite::Model* model_;
|
||||
const OpResolver& op_resolver_;
|
||||
ErrorReporter* error_reporter_;
|
||||
|
||||
std::vector<const TfLiteRegistration*> flatbuffer_op_index_to_registration_;
|
||||
std::vector<TfLiteRegistration> unresolved_custom_ops_;
|
||||
std::vector<BuiltinOperator> flatbuffer_op_index_to_registration_types_;
|
||||
const Allocation* allocation_ = nullptr;
|
||||
|
||||
bool has_flex_op_ = false;
|
||||
};
|
||||
|
||||
} // namespace impl
|
||||
|
||||
using InterpreterBuilder = impl::InterpreterBuilder;
|
||||
|
||||
} // namespace tflite
|
||||
|
204
tensorflow/lite/model_builder.cc
Normal file
204
tensorflow/lite/model_builder.cc
Normal file
@ -0,0 +1,204 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/model_builder.h"
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
|
||||
#include "tensorflow/lite/allocation.h"
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/util.h"
|
||||
#include "tensorflow/lite/version.h"
|
||||
|
||||
#if defined(TFLITE_ENABLE_DEFAULT_PROFILER)
|
||||
#include "tensorflow/lite/profiling/platform_profiler.h"
|
||||
#endif
|
||||
|
||||
namespace tflite {
|
||||
|
||||
namespace {
|
||||
|
||||
// Ensure that ErrorReporter is non-null.
|
||||
ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
|
||||
return e ? e : DefaultErrorReporter();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
#ifndef TFLITE_MCU
|
||||
// Loads a model from `filename`. If `mmap_file` is true then use mmap,
|
||||
// otherwise make a copy of the model in a buffer.
|
||||
std::unique_ptr<Allocation> GetAllocationFromFile(const char* filename,
|
||||
bool mmap_file,
|
||||
ErrorReporter* error_reporter,
|
||||
bool use_nnapi) {
|
||||
std::unique_ptr<Allocation> allocation;
|
||||
if (mmap_file && MMAPAllocation::IsSupported()) {
|
||||
allocation.reset(new MMAPAllocation(filename, error_reporter));
|
||||
} else {
|
||||
allocation.reset(new FileCopyAllocation(filename, error_reporter));
|
||||
}
|
||||
return allocation;
|
||||
}
|
||||
|
||||
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
|
||||
const char* filename, ErrorReporter* error_reporter) {
|
||||
error_reporter = ValidateErrorReporter(error_reporter);
|
||||
|
||||
std::unique_ptr<FlatBufferModel> model;
|
||||
auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
|
||||
error_reporter, /*use_nnapi=*/true);
|
||||
model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
|
||||
if (!model->initialized()) model.reset();
|
||||
return model;
|
||||
}
|
||||
|
||||
std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile(
|
||||
const char* filename, TfLiteVerifier* extra_verifier,
|
||||
ErrorReporter* error_reporter) {
|
||||
error_reporter = ValidateErrorReporter(error_reporter);
|
||||
|
||||
std::unique_ptr<FlatBufferModel> model;
|
||||
auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
|
||||
error_reporter, /*use_nnapi=*/true);
|
||||
|
||||
flatbuffers::Verifier base_verifier(
|
||||
reinterpret_cast<const uint8_t*>(allocation->base()),
|
||||
allocation->bytes());
|
||||
if (!VerifyModelBuffer(base_verifier)) {
|
||||
TF_LITE_REPORT_ERROR(error_reporter,
|
||||
"The model is not a valid Flatbuffer file");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (extra_verifier &&
|
||||
!extra_verifier->Verify(static_cast<const char*>(allocation->base()),
|
||||
allocation->bytes(), error_reporter)) {
|
||||
return model;
|
||||
}
|
||||
model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
|
||||
if (!model->initialized()) model.reset();
|
||||
return model;
|
||||
}
|
||||
#endif
|
||||
|
||||
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
|
||||
const char* caller_owned_buffer, size_t buffer_size,
|
||||
ErrorReporter* error_reporter) {
|
||||
error_reporter = ValidateErrorReporter(error_reporter);
|
||||
|
||||
std::unique_ptr<FlatBufferModel> model;
|
||||
std::unique_ptr<Allocation> allocation(
|
||||
new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
|
||||
model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
|
||||
if (!model->initialized()) model.reset();
|
||||
return model;
|
||||
}
|
||||
|
||||
std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromBuffer(
|
||||
const char* caller_owned_buffer, size_t buffer_size,
|
||||
TfLiteVerifier* extra_verifier, ErrorReporter* error_reporter) {
|
||||
error_reporter = ValidateErrorReporter(error_reporter);
|
||||
|
||||
flatbuffers::Verifier base_verifier(
|
||||
reinterpret_cast<const uint8_t*>(caller_owned_buffer), buffer_size);
|
||||
if (!VerifyModelBuffer(base_verifier)) {
|
||||
TF_LITE_REPORT_ERROR(error_reporter,
|
||||
"The model is not a valid Flatbuffer buffer");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (extra_verifier && !extra_verifier->Verify(caller_owned_buffer,
|
||||
buffer_size, error_reporter)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return BuildFromBuffer(caller_owned_buffer, buffer_size, error_reporter);
|
||||
}
|
||||
|
||||
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
|
||||
const tflite::Model* caller_owned_model_spec,
|
||||
ErrorReporter* error_reporter) {
|
||||
error_reporter = ValidateErrorReporter(error_reporter);
|
||||
|
||||
std::unique_ptr<FlatBufferModel> model;
|
||||
model.reset(new FlatBufferModel(caller_owned_model_spec, error_reporter));
|
||||
if (!model->initialized()) model.reset();
|
||||
return model;
|
||||
}
|
||||
|
||||
string FlatBufferModel::GetMinimumRuntime() const {
|
||||
if (!model_ || !model_->metadata()) return "";
|
||||
|
||||
for (int i = 0; i < model_->metadata()->size(); ++i) {
|
||||
auto metadata = model_->metadata()->Get(i);
|
||||
if (metadata->name()->str() == "min_runtime_version") {
|
||||
auto buf = metadata->buffer();
|
||||
auto* buffer = (*model_->buffers())[buf];
|
||||
auto* array = buffer->data();
|
||||
// Get the real length of the runtime string, since there might be
|
||||
// trailing
|
||||
// '\0's in the buffer.
|
||||
for (int len = 0; len < array->size(); ++len) {
|
||||
if (array->data()[len] == '\0') {
|
||||
return string(reinterpret_cast<const char*>(array->data()), len);
|
||||
}
|
||||
}
|
||||
// If there is no '\0' in the buffer, this indicates that the flatbuffer
|
||||
// is malformed.
|
||||
TF_LITE_REPORT_ERROR(
|
||||
error_reporter_,
|
||||
"Min_runtime_version in model metadata is malformed");
|
||||
break;
|
||||
}
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
bool FlatBufferModel::CheckModelIdentifier() const {
|
||||
if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
|
||||
const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());
|
||||
error_reporter_->Report(
|
||||
"Model provided has model identifier '%c%c%c%c', should be '%s'\n",
|
||||
ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier());
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
FlatBufferModel::FlatBufferModel(const Model* model,
|
||||
ErrorReporter* error_reporter)
|
||||
: model_(model), error_reporter_(ValidateErrorReporter(error_reporter)) {}
|
||||
|
||||
FlatBufferModel::FlatBufferModel(std::unique_ptr<Allocation> allocation,
|
||||
ErrorReporter* error_reporter)
|
||||
: error_reporter_(ValidateErrorReporter(error_reporter)),
|
||||
allocation_(std::move(allocation)) {
|
||||
if (!allocation_->valid() || !CheckModelIdentifier()) return;
|
||||
|
||||
model_ = ::tflite::GetModel(allocation_->base());
|
||||
}
|
||||
|
||||
FlatBufferModel::~FlatBufferModel() {}
|
||||
|
||||
} // namespace tflite
|
179
tensorflow/lite/model_builder.h
Normal file
179
tensorflow/lite/model_builder.h
Normal file
@ -0,0 +1,179 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
/// Deserialization infrastructure for tflite. Provides functionality
|
||||
/// to go from a serialized tflite model in flatbuffer format to an
|
||||
/// interpreter.
|
||||
///
|
||||
#ifndef TENSORFLOW_LITE_MODEL_BUILDER_H_
|
||||
#define TENSORFLOW_LITE_MODEL_BUILDER_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
/// Abstract interface that verifies whether a given model is legit.
|
||||
/// It facilitates the use-case to verify and build a model without loading it
|
||||
/// twice.
|
||||
class TfLiteVerifier {
|
||||
public:
|
||||
/// Returns true if the model is legit.
|
||||
virtual bool Verify(const char* data, int length,
|
||||
ErrorReporter* reporter) = 0;
|
||||
virtual ~TfLiteVerifier() {}
|
||||
};
|
||||
|
||||
/// An RAII object that represents a read-only tflite model, copied from disk,
|
||||
/// or mmapped. This uses flatbuffers as the serialization format.
|
||||
///
|
||||
/// NOTE: The current API requires that a FlatBufferModel instance be kept alive
|
||||
/// by the client as long as it is in use by any dependent Interpreter
|
||||
/// instances.
|
||||
/// <pre><code>
|
||||
/// using namespace tflite;
|
||||
/// StderrReporter error_reporter;
|
||||
/// auto model = FlatBufferModel::BuildFromFile("interesting_model.tflite",
|
||||
/// &error_reporter);
|
||||
/// MyOpResolver resolver; // You need to subclass OpResolver to provide
|
||||
/// // implementations.
|
||||
/// InterpreterBuilder builder(*model, resolver);
|
||||
/// std::unique_ptr<Interpreter> interpreter;
|
||||
/// if(builder(&interpreter) == kTfLiteOk) {
|
||||
/// .. run model inference with interpreter
|
||||
/// }
|
||||
/// </code></pre>
|
||||
///
|
||||
/// OpResolver must be defined to provide your kernel implementations to the
|
||||
/// interpreter. This is environment specific and may consist of just the
|
||||
/// builtin ops, or some custom operators you defined to extend tflite.
|
||||
class FlatBufferModel {
|
||||
public:
|
||||
/// Builds a model based on a file.
|
||||
/// Caller retains ownership of `error_reporter` and must ensure its lifetime
|
||||
/// is longer than the FlatBufferModel instance.
|
||||
/// Returns a nullptr in case of failure.
|
||||
static std::unique_ptr<FlatBufferModel> BuildFromFile(
|
||||
const char* filename,
|
||||
ErrorReporter* error_reporter = DefaultErrorReporter());
|
||||
|
||||
/// Verifies whether the content of the file is legit, then builds a model
|
||||
/// based on the file.
|
||||
/// The extra_verifier argument is an additional optional verifier for the
|
||||
/// file contents. By default, we always check with tflite::VerifyModelBuffer.
|
||||
/// If extra_verifier is supplied, the file contents is also checked against
|
||||
/// the extra_verifier after the check against tflite::VerifyModelBuilder.
|
||||
/// Caller retains ownership of `error_reporter` and must ensure its lifetime
|
||||
/// is longer than the FlatBufferModel instance.
|
||||
/// Returns a nullptr in case of failure.
|
||||
static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromFile(
|
||||
const char* filename, TfLiteVerifier* extra_verifier = nullptr,
|
||||
ErrorReporter* error_reporter = DefaultErrorReporter());
|
||||
|
||||
/// Builds a model based on a pre-loaded flatbuffer.
|
||||
/// Caller retains ownership of the buffer and should keep it alive until
|
||||
/// the returned object is destroyed. Caller also retains ownership of
|
||||
/// `error_reporter` and must ensure its lifetime is longer than the
|
||||
/// FlatBufferModel instance.
|
||||
/// Returns a nullptr in case of failure.
|
||||
/// NOTE: this does NOT validate the buffer so it should NOT be called on
|
||||
/// invalid/untrusted input. Use VerifyAndBuildFromBuffer in that case
|
||||
static std::unique_ptr<FlatBufferModel> BuildFromBuffer(
|
||||
const char* caller_owned_buffer, size_t buffer_size,
|
||||
ErrorReporter* error_reporter = DefaultErrorReporter());
|
||||
|
||||
/// Verifies whether the content of the buffer is legit, then builds a model
|
||||
/// based on the pre-loaded flatbuffer.
|
||||
/// The extra_verifier argument is an additional optional verifier for the
|
||||
/// buffer. By default, we always check with tflite::VerifyModelBuffer. If
|
||||
/// extra_verifier is supplied, the buffer is checked against the
|
||||
/// extra_verifier after the check against tflite::VerifyModelBuilder. The
|
||||
/// caller retains ownership of the buffer and should keep it alive until the
|
||||
/// returned object is destroyed. Caller retains ownership of `error_reporter`
|
||||
/// and must ensure its lifetime is longer than the FlatBufferModel instance.
|
||||
/// Returns a nullptr in case of failure.
|
||||
static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromBuffer(
|
||||
const char* caller_owned_buffer, size_t buffer_size,
|
||||
TfLiteVerifier* extra_verifier = nullptr,
|
||||
ErrorReporter* error_reporter = DefaultErrorReporter());
|
||||
|
||||
/// Builds a model directly from a flatbuffer pointer
|
||||
/// Caller retains ownership of the buffer and should keep it alive until the
|
||||
/// returned object is destroyed. Caller retains ownership of `error_reporter`
|
||||
/// and must ensure its lifetime is longer than the FlatBufferModel instance.
|
||||
/// Returns a nullptr in case of failure.
|
||||
static std::unique_ptr<FlatBufferModel> BuildFromModel(
|
||||
const tflite::Model* caller_owned_model_spec,
|
||||
ErrorReporter* error_reporter = DefaultErrorReporter());
|
||||
|
||||
// Releases memory or unmaps mmaped memory.
|
||||
~FlatBufferModel();
|
||||
|
||||
// Copying or assignment is disallowed to simplify ownership semantics.
|
||||
FlatBufferModel(const FlatBufferModel&) = delete;
|
||||
FlatBufferModel& operator=(const FlatBufferModel&) = delete;
|
||||
|
||||
bool initialized() const { return model_ != nullptr; }
|
||||
const tflite::Model* operator->() const { return model_; }
|
||||
const tflite::Model* GetModel() const { return model_; }
|
||||
ErrorReporter* error_reporter() const { return error_reporter_; }
|
||||
const Allocation* allocation() const { return allocation_.get(); }
|
||||
|
||||
// Returns the minimum runtime version from the flatbuffer. This runtime
|
||||
// version encodes the minimum required interpreter version to run the
|
||||
// flatbuffer model. If the minimum version can't be determined, an empty
|
||||
// string will be returned.
|
||||
// Note that the returned minimum version is a lower-bound but not a strict
|
||||
// lower-bound; ops in the graph may not have an associated runtime version,
|
||||
// in which case the actual required runtime might be greater than the
|
||||
// reported minimum.
|
||||
string GetMinimumRuntime() const;
|
||||
|
||||
/// Returns true if the model identifier is correct (otherwise false and
|
||||
/// reports an error).
|
||||
bool CheckModelIdentifier() const;
|
||||
|
||||
private:
|
||||
/// Loads a model from a given allocation. FlatBufferModel will take over the
|
||||
/// ownership of `allocation`, and delete it in destructor. The ownership of
|
||||
/// `error_reporter`remains with the caller and must have lifetime at least
|
||||
/// as much as FlatBufferModel. This is to allow multiple models to use the
|
||||
/// same ErrorReporter instance.
|
||||
FlatBufferModel(std::unique_ptr<Allocation> allocation,
|
||||
ErrorReporter* error_reporter = DefaultErrorReporter());
|
||||
|
||||
/// Loads a model from Model flatbuffer. The `model` has to remain alive and
|
||||
/// unchanged until the end of this flatbuffermodel's lifetime.
|
||||
FlatBufferModel(const Model* model, ErrorReporter* error_reporter);
|
||||
|
||||
/// Flatbuffer traverser pointer. (Model* is a pointer that is within the
|
||||
/// allocated memory of the data allocated by allocation's internals.
|
||||
const tflite::Model* model_ = nullptr;
|
||||
/// The error reporter to use for model errors and subsequent errors when
|
||||
/// the interpreter is created
|
||||
ErrorReporter* error_reporter_;
|
||||
/// The allocator used for holding memory of the model. Note that this will
|
||||
/// be null if the client provides a tflite::Model directly.
|
||||
std::unique_ptr<Allocation> allocation_;
|
||||
};
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_MODEL_BUILDER_H_
|
Loading…
x
Reference in New Issue
Block a user