[tf.lite] (Re-land) Add FlatBufferModel::BuildFromAllocation() API

This allows more flexibility in providing models from alternative sources.
Also add a `MMAPAllocation(int fd)` API for enabling user-provided file
descriptors as model sources.

Resolves #46593.

PiperOrigin-RevId: 358931162
Change-Id: I6a443f73f7bce2c0c18dc35f20650715170365cb
This commit is contained in:
Jared Duke 2021-02-22 16:33:35 -08:00 committed by TensorFlower Gardener
parent 901d914ab7
commit 7da3e2c1d4
9 changed files with 260 additions and 61 deletions

View File

@ -704,6 +704,23 @@ cc_test(
], ],
) )
cc_test(
name = "allocation_test",
size = "small",
srcs = ["allocation_test.cc"],
data = [
"testdata/empty_model.bin",
],
tags = [
"tflite_smoke_test",
],
deps = [
":allocation",
"//tensorflow/lite/testing:util",
"@com_google_googletest//:gtest",
],
)
# Test OpResolver. # Test OpResolver.
cc_test( cc_test(
name = "mutable_op_resolver_test", name = "mutable_op_resolver_test",

View File

@ -22,11 +22,8 @@ limitations under the License.
#include <cstdio> #include <cstdio>
#include <cstdlib> #include <cstdlib>
#include <memory> #include <memory>
#include <vector>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/string_type.h"
namespace tflite { namespace tflite {
@ -59,9 +56,18 @@ class Allocation {
const Type type_; const Type type_;
}; };
// Note that not all platforms support MMAP-based allocation.
// Use `IsSupported()` to check.
class MMAPAllocation : public Allocation { class MMAPAllocation : public Allocation {
public: public:
// Loads and maps the provided file to a memory region.
MMAPAllocation(const char* filename, ErrorReporter* error_reporter); MMAPAllocation(const char* filename, ErrorReporter* error_reporter);
// Maps the provided file descriptor to a memory region.
// Note: The provided file descriptor will be dup'ed for usage; the caller
// retains ownership of the provided descriptor and should close accordingly.
MMAPAllocation(int fd, ErrorReporter* error_reporter);
virtual ~MMAPAllocation(); virtual ~MMAPAllocation();
const void* base() const override; const void* base() const override;
size_t bytes() const override; size_t bytes() const override;
@ -76,10 +82,15 @@ class MMAPAllocation : public Allocation {
int mmap_fd_ = -1; // mmap file descriptor int mmap_fd_ = -1; // mmap file descriptor
const void* mmapped_buffer_; const void* mmapped_buffer_;
size_t buffer_size_bytes_ = 0; size_t buffer_size_bytes_ = 0;
private:
// Assumes ownership of the provided `owned_fd` instance.
MMAPAllocation(ErrorReporter* error_reporter, int owned_fd);
}; };
class FileCopyAllocation : public Allocation { class FileCopyAllocation : public Allocation {
public: public:
// Loads the provided file into a heap memory region.
FileCopyAllocation(const char* filename, ErrorReporter* error_reporter); FileCopyAllocation(const char* filename, ErrorReporter* error_reporter);
virtual ~FileCopyAllocation(); virtual ~FileCopyAllocation();
const void* base() const override; const void* base() const override;
@ -87,16 +98,15 @@ class FileCopyAllocation : public Allocation {
bool valid() const override; bool valid() const override;
private: private:
// Data required for mmap.
std::unique_ptr<const char[]> copied_buffer_; std::unique_ptr<const char[]> copied_buffer_;
size_t buffer_size_bytes_ = 0; size_t buffer_size_bytes_ = 0;
}; };
class MemoryAllocation : public Allocation { class MemoryAllocation : public Allocation {
public: public:
// Allocates memory with the pointer and the number of bytes of the memory. // Provides a (read-only) view of the provided buffer region as an allocation.
// The pointer has to remain alive and unchanged until the destructor is // Note: The caller retains ownership of `ptr`, and must ensure it remains
// called. // valid for the lifetime of the class instance.
MemoryAllocation(const void* ptr, size_t num_bytes, MemoryAllocation(const void* ptr, size_t num_bytes,
ErrorReporter* error_reporter); ErrorReporter* error_reporter);
virtual ~MemoryAllocation(); virtual ~MemoryAllocation();

View File

@ -0,0 +1,90 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/allocation.h"
#if defined(__linux__)
#include <fcntl.h>
#endif
#include <string>
#include <gtest/gtest.h>
#include "tensorflow/lite/testing/util.h"
namespace tflite {
TEST(MMAPAllocation, TestInvalidFile) {
if (!MMAPAllocation::IsSupported()) {
return;
}
TestErrorReporter error_reporter;
MMAPAllocation allocation("/tmp/tflite_model_1234", &error_reporter);
EXPECT_FALSE(allocation.valid());
}
TEST(MMAPAllocation, TestValidFile) {
if (!MMAPAllocation::IsSupported()) {
return;
}
TestErrorReporter error_reporter;
MMAPAllocation allocation(
"tensorflow/lite/testdata/empty_model.bin", &error_reporter);
ASSERT_TRUE(allocation.valid());
EXPECT_GT(allocation.fd(), 0);
EXPECT_GT(allocation.bytes(), 0);
EXPECT_NE(allocation.base(), nullptr);
}
#if defined(__linux__)
TEST(MMAPAllocation, TestInvalidFileDescriptor) {
if (!MMAPAllocation::IsSupported()) {
return;
}
TestErrorReporter error_reporter;
MMAPAllocation allocation(-1, &error_reporter);
EXPECT_FALSE(allocation.valid());
}
TEST(MMAPAllocation, TestValidFileDescriptor) {
if (!MMAPAllocation::IsSupported()) {
return;
}
int fd =
open("tensorflow/lite/testdata/empty_model.bin", O_RDONLY);
ASSERT_GT(fd, 0);
TestErrorReporter error_reporter;
MMAPAllocation allocation(fd, &error_reporter);
EXPECT_TRUE(allocation.valid());
EXPECT_GT(allocation.fd(), 0);
EXPECT_GT(allocation.bytes(), 0);
EXPECT_NE(allocation.base(), nullptr);
close(fd);
}
#endif
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -81,7 +81,7 @@ public final class NativeInterpreterWrapperTest {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INVALID_MODEL_PATH); NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INVALID_MODEL_PATH);
fail(); fail();
} catch (IllegalArgumentException e) { } catch (IllegalArgumentException e) {
assertThat(e).hasMessageThat().contains("The model is not a valid Flatbuffer file"); assertThat(e).hasMessageThat().contains("The model is not a valid Flatbuffer");
} }
} }
@ -92,7 +92,6 @@ public final class NativeInterpreterWrapperTest {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(NONEXISTING_MODEL_PATH); NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(NONEXISTING_MODEL_PATH);
fail(); fail();
} catch (IllegalArgumentException e) { } catch (IllegalArgumentException e) {
assertThat(e).hasMessageThat().contains("The model is not a valid Flatbuffer file");
assertThat(e).hasMessageThat().contains("Could not open"); assertThat(e).hasMessageThat().contains("Could not open");
} }
} }

View File

@ -26,11 +26,26 @@ namespace tflite {
MMAPAllocation::MMAPAllocation(const char* filename, MMAPAllocation::MMAPAllocation(const char* filename,
ErrorReporter* error_reporter) ErrorReporter* error_reporter)
: Allocation(error_reporter, Allocation::Type::kMMap), : MMAPAllocation(error_reporter, open(filename, O_RDONLY)) {
mmapped_buffer_(MAP_FAILED) { if (mmap_fd_ == -1) {
mmap_fd_ = open(filename, O_RDONLY); TF_LITE_REPORT_ERROR(error_reporter, "Could not open '%s'.", filename);
}
}
MMAPAllocation::MMAPAllocation(int fd, ErrorReporter* error_reporter)
: MMAPAllocation(error_reporter, dup(fd)) {
if (mmap_fd_ == -1) {
TF_LITE_REPORT_ERROR(error_reporter, "Failed to dup '%d' file descriptor.",
fd);
}
}
MMAPAllocation::MMAPAllocation(ErrorReporter* error_reporter, int owned_fd)
: Allocation(error_reporter, Allocation::Type::kMMap),
mmap_fd_(owned_fd),
mmapped_buffer_(MAP_FAILED),
buffer_size_bytes_(0) {
if (mmap_fd_ == -1) { if (mmap_fd_ == -1) {
error_reporter_->Report("Could not open '%s'.", filename);
return; return;
} }
struct stat sb; struct stat sb;
@ -39,7 +54,7 @@ MMAPAllocation::MMAPAllocation(const char* filename,
mmapped_buffer_ = mmapped_buffer_ =
mmap(nullptr, buffer_size_bytes_, PROT_READ, MAP_SHARED, mmap_fd_, 0); mmap(nullptr, buffer_size_bytes_, PROT_READ, MAP_SHARED, mmap_fd_, 0);
if (mmapped_buffer_ == MAP_FAILED) { if (mmapped_buffer_ == MAP_FAILED) {
error_reporter_->Report("Mmap of '%s' failed.", filename); TF_LITE_REPORT_ERROR(error_reporter, "Mmap of '%d' failed.", mmap_fd_);
return; return;
} }
} }

View File

@ -21,6 +21,12 @@ namespace tflite {
MMAPAllocation::MMAPAllocation(const char* filename, MMAPAllocation::MMAPAllocation(const char* filename,
ErrorReporter* error_reporter) ErrorReporter* error_reporter)
: MMAPAllocation(error_reporter, -1) {}
MMAPAllocation::MMAPAllocation(int fd, ErrorReporter* error_reporter)
: MMAPAllocation(error_reporter, -1) {}
MMAPAllocation::MMAPAllocation(ErrorReporter* error_reporter, int owned_fd)
: Allocation(error_reporter, Allocation::Type::kMMap), : Allocation(error_reporter, Allocation::Type::kMMap),
mmapped_buffer_(nullptr) { mmapped_buffer_(nullptr) {
// The disabled variant should never be created. // The disabled variant should never be created.

View File

@ -43,12 +43,10 @@ ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
#ifndef TFLITE_MCU #ifndef TFLITE_MCU
// Loads a model from `filename`. If `mmap_file` is true then use mmap, // Loads a model from `filename`. If `mmap_file` is true then use mmap,
// otherwise make a copy of the model in a buffer. // otherwise make a copy of the model in a buffer.
std::unique_ptr<Allocation> GetAllocationFromFile(const char* filename, std::unique_ptr<Allocation> GetAllocationFromFile(
bool mmap_file, const char* filename, ErrorReporter* error_reporter) {
ErrorReporter* error_reporter,
bool use_nnapi) {
std::unique_ptr<Allocation> allocation; std::unique_ptr<Allocation> allocation;
if (mmap_file && MMAPAllocation::IsSupported()) { if (MMAPAllocation::IsSupported()) {
allocation.reset(new MMAPAllocation(filename, error_reporter)); allocation.reset(new MMAPAllocation(filename, error_reporter));
} else { } else {
allocation.reset(new FileCopyAllocation(filename, error_reporter)); allocation.reset(new FileCopyAllocation(filename, error_reporter));
@ -59,41 +57,17 @@ std::unique_ptr<Allocation> GetAllocationFromFile(const char* filename,
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile( std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
const char* filename, ErrorReporter* error_reporter) { const char* filename, ErrorReporter* error_reporter) {
error_reporter = ValidateErrorReporter(error_reporter); error_reporter = ValidateErrorReporter(error_reporter);
return BuildFromAllocation(GetAllocationFromFile(filename, error_reporter),
std::unique_ptr<FlatBufferModel> model; error_reporter);
auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
error_reporter, /*use_nnapi=*/true);
model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
if (!model->initialized()) model.reset();
return model;
} }
std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile( std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile(
const char* filename, TfLiteVerifier* extra_verifier, const char* filename, TfLiteVerifier* extra_verifier,
ErrorReporter* error_reporter) { ErrorReporter* error_reporter) {
error_reporter = ValidateErrorReporter(error_reporter); error_reporter = ValidateErrorReporter(error_reporter);
return VerifyAndBuildFromAllocation(
std::unique_ptr<FlatBufferModel> model; GetAllocationFromFile(filename, error_reporter), extra_verifier,
auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true, error_reporter);
error_reporter, /*use_nnapi=*/true);
flatbuffers::Verifier base_verifier(
reinterpret_cast<const uint8_t*>(allocation->base()),
allocation->bytes());
if (!VerifyModelBuffer(base_verifier)) {
TF_LITE_REPORT_ERROR(error_reporter,
"The model is not a valid Flatbuffer file");
return nullptr;
}
if (extra_verifier &&
!extra_verifier->Verify(static_cast<const char*>(allocation->base()),
allocation->bytes(), error_reporter)) {
return model;
}
model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
if (!model->initialized()) model.reset();
return model;
} }
#endif #endif
@ -101,34 +75,57 @@ std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
const char* caller_owned_buffer, size_t buffer_size, const char* caller_owned_buffer, size_t buffer_size,
ErrorReporter* error_reporter) { ErrorReporter* error_reporter) {
error_reporter = ValidateErrorReporter(error_reporter); error_reporter = ValidateErrorReporter(error_reporter);
std::unique_ptr<FlatBufferModel> model;
std::unique_ptr<Allocation> allocation( std::unique_ptr<Allocation> allocation(
new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter)); new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
model.reset(new FlatBufferModel(std::move(allocation), error_reporter)); return BuildFromAllocation(std::move(allocation), error_reporter);
if (!model->initialized()) model.reset();
return model;
} }
std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromBuffer( std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromBuffer(
const char* caller_owned_buffer, size_t buffer_size, const char* caller_owned_buffer, size_t buffer_size,
TfLiteVerifier* extra_verifier, ErrorReporter* error_reporter) { TfLiteVerifier* extra_verifier, ErrorReporter* error_reporter) {
error_reporter = ValidateErrorReporter(error_reporter); error_reporter = ValidateErrorReporter(error_reporter);
std::unique_ptr<Allocation> allocation(
new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
return VerifyAndBuildFromAllocation(std::move(allocation), extra_verifier,
error_reporter);
}
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromAllocation(
std::unique_ptr<Allocation> allocation, ErrorReporter* error_reporter) {
std::unique_ptr<FlatBufferModel> model(new FlatBufferModel(
std::move(allocation), ValidateErrorReporter(error_reporter)));
if (!model->initialized()) {
model.reset();
}
return model;
}
std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromAllocation(
std::unique_ptr<Allocation> allocation, TfLiteVerifier* extra_verifier,
ErrorReporter* error_reporter) {
error_reporter = ValidateErrorReporter(error_reporter);
if (!allocation || !allocation->valid()) {
TF_LITE_REPORT_ERROR(error_reporter, "The model allocation is null/empty");
return nullptr;
}
flatbuffers::Verifier base_verifier( flatbuffers::Verifier base_verifier(
reinterpret_cast<const uint8_t*>(caller_owned_buffer), buffer_size); reinterpret_cast<const uint8_t*>(allocation->base()),
allocation->bytes());
if (!VerifyModelBuffer(base_verifier)) { if (!VerifyModelBuffer(base_verifier)) {
TF_LITE_REPORT_ERROR(error_reporter, TF_LITE_REPORT_ERROR(error_reporter,
"The model is not a valid Flatbuffer buffer"); "The model is not a valid Flatbuffer buffer");
return nullptr; return nullptr;
} }
if (extra_verifier && !extra_verifier->Verify(caller_owned_buffer, if (extra_verifier &&
buffer_size, error_reporter)) { !extra_verifier->Verify(static_cast<const char*>(allocation->base()),
allocation->bytes(), error_reporter)) {
// The verifier will have already logged an appropriate error message.
return nullptr; return nullptr;
} }
return BuildFromBuffer(caller_owned_buffer, buffer_size, error_reporter); return BuildFromAllocation(std::move(allocation), error_reporter);
} }
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel( std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
@ -136,9 +133,11 @@ std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
ErrorReporter* error_reporter) { ErrorReporter* error_reporter) {
error_reporter = ValidateErrorReporter(error_reporter); error_reporter = ValidateErrorReporter(error_reporter);
std::unique_ptr<FlatBufferModel> model; std::unique_ptr<FlatBufferModel> model(
model.reset(new FlatBufferModel(caller_owned_model_spec, error_reporter)); new FlatBufferModel(caller_owned_model_spec, error_reporter));
if (!model->initialized()) model.reset(); if (!model->initialized()) {
model.reset();
}
return model; return model;
} }
@ -189,7 +188,9 @@ FlatBufferModel::FlatBufferModel(std::unique_ptr<Allocation> allocation,
ErrorReporter* error_reporter) ErrorReporter* error_reporter)
: error_reporter_(ValidateErrorReporter(error_reporter)), : error_reporter_(ValidateErrorReporter(error_reporter)),
allocation_(std::move(allocation)) { allocation_(std::move(allocation)) {
if (!allocation_->valid() || !CheckModelIdentifier()) return; if (!allocation_ || !allocation_->valid() || !CheckModelIdentifier()) {
return;
}
model_ = ::tflite::GetModel(allocation_->base()); model_ = ::tflite::GetModel(allocation_->base());
} }

View File

@ -110,6 +110,30 @@ class FlatBufferModel {
TfLiteVerifier* extra_verifier = nullptr, TfLiteVerifier* extra_verifier = nullptr,
ErrorReporter* error_reporter = DefaultErrorReporter()); ErrorReporter* error_reporter = DefaultErrorReporter());
/// Builds a model directly from an allocation.
/// Ownership of the allocation is passed to the model, but the caller
/// retains ownership of `error_reporter` and must ensure its lifetime is
/// longer than the FlatBufferModel instance.
/// Returns a nullptr in case of failure (e.g., the allocation is invalid).
static std::unique_ptr<FlatBufferModel> BuildFromAllocation(
std::unique_ptr<Allocation> allocation,
ErrorReporter* error_reporter = DefaultErrorReporter());
/// Verifies whether the content of the allocation is legit, then builds a
/// model based on the provided allocation.
/// The extra_verifier argument is an additional optional verifier for the
/// buffer. By default, we always check with tflite::VerifyModelBuffer. If
/// extra_verifier is supplied, the buffer is checked against the
/// extra_verifier after the check against tflite::VerifyModelBuilder.
/// Ownership of the allocation is passed to the model, but the caller
/// retains ownership of `error_reporter` and must ensure its lifetime is
/// longer than the FlatBufferModel instance.
/// Returns a nullptr in case of failure.
static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromAllocation(
std::unique_ptr<Allocation> allocation,
TfLiteVerifier* extra_verifier = nullptr,
ErrorReporter* error_reporter = DefaultErrorReporter());
/// Builds a model directly from a flatbuffer pointer /// Builds a model directly from a flatbuffer pointer
/// Caller retains ownership of the buffer and should keep it alive until the /// Caller retains ownership of the buffer and should keep it alive until the
/// returned object is destroyed. Caller retains ownership of `error_reporter` /// returned object is destroyed. Caller retains ownership of `error_reporter`

View File

@ -384,6 +384,43 @@ TEST(BasicFlatBufferModel, TestBuildFromModel) {
ASSERT_NE(interpreter, nullptr); ASSERT_NE(interpreter, nullptr);
} }
// Test that loading model directly from an Allocation works.
TEST(BasicFlatBufferModel, TestBuildFromAllocation) {
TestErrorReporter reporter;
std::unique_ptr<Allocation> model_allocation(new FileCopyAllocation(
"tensorflow/lite/testdata/test_model.bin", &reporter));
ASSERT_TRUE(model_allocation->valid());
auto model =
FlatBufferModel::BuildFromAllocation(std::move(model_allocation));
ASSERT_TRUE(model);
std::unique_ptr<Interpreter> interpreter;
ASSERT_EQ(
InterpreterBuilder(*model, TrivialResolver(&dummy_reg))(&interpreter),
kTfLiteOk);
ASSERT_NE(interpreter, nullptr);
}
TEST(BasicFlatBufferModel, TestBuildFromNullAllocation) {
TestErrorReporter reporter;
std::unique_ptr<Allocation> model_allocation;
auto model =
FlatBufferModel::BuildFromAllocation(std::move(model_allocation));
ASSERT_FALSE(model);
}
TEST(BasicFlatBufferModel, TestBuildFromInvalidAllocation) {
TestErrorReporter reporter;
std::unique_ptr<Allocation> model_allocation(
new MemoryAllocation(nullptr, 0, nullptr));
auto model =
FlatBufferModel::BuildFromAllocation(std::move(model_allocation));
ASSERT_FALSE(model);
}
// Test reading the minimum runtime string from metadata in a Model flatbuffer. // Test reading the minimum runtime string from metadata in a Model flatbuffer.
TEST(BasicFlatBufferModel, TestReadRuntimeVersionFromModel) { TEST(BasicFlatBufferModel, TestReadRuntimeVersionFromModel) {
// First read a model that doesn't have the runtime string. // First read a model that doesn't have the runtime string.