[tf.lite] (Re-land) Add FlatBufferModel::BuildFromAllocation() API

This allows more flexibility in providing models from alternative sources.
Also add a `MMAPAllocation(int fd)` API for enabling user-provided file
descriptors as model sources.

Resolves #46593.

PiperOrigin-RevId: 358931162
Change-Id: I6a443f73f7bce2c0c18dc35f20650715170365cb
This commit is contained in:
Jared Duke 2021-02-22 16:33:35 -08:00 committed by TensorFlower Gardener
parent 901d914ab7
commit 7da3e2c1d4
9 changed files with 260 additions and 61 deletions

View File

@ -704,6 +704,23 @@ cc_test(
],
)
cc_test(
name = "allocation_test",
size = "small",
srcs = ["allocation_test.cc"],
data = [
"testdata/empty_model.bin",
],
tags = [
"tflite_smoke_test",
],
deps = [
":allocation",
"//tensorflow/lite/testing:util",
"@com_google_googletest//:gtest",
],
)
# Test OpResolver.
cc_test(
name = "mutable_op_resolver_test",

View File

@ -22,11 +22,8 @@ limitations under the License.
#include <cstdio>
#include <cstdlib>
#include <memory>
#include <vector>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/string_type.h"
namespace tflite {
@ -59,9 +56,18 @@ class Allocation {
const Type type_;
};
// Note that not all platforms support MMAP-based allocation.
// Use `IsSupported()` to check.
class MMAPAllocation : public Allocation {
public:
// Loads and maps the provided file to a memory region.
MMAPAllocation(const char* filename, ErrorReporter* error_reporter);
// Maps the provided file descriptor to a memory region.
// Note: The provided file descriptor will be dup'ed for usage; the caller
// retains ownership of the provided descriptor and should close accordingly.
MMAPAllocation(int fd, ErrorReporter* error_reporter);
virtual ~MMAPAllocation();
const void* base() const override;
size_t bytes() const override;
@ -76,10 +82,15 @@ class MMAPAllocation : public Allocation {
int mmap_fd_ = -1; // mmap file descriptor
const void* mmapped_buffer_;
size_t buffer_size_bytes_ = 0;
private:
// Assumes ownership of the provided `owned_fd` instance.
MMAPAllocation(ErrorReporter* error_reporter, int owned_fd);
};
class FileCopyAllocation : public Allocation {
public:
// Loads the provided file into a heap memory region.
FileCopyAllocation(const char* filename, ErrorReporter* error_reporter);
virtual ~FileCopyAllocation();
const void* base() const override;
@ -87,16 +98,15 @@ class FileCopyAllocation : public Allocation {
bool valid() const override;
private:
// Data required for mmap.
std::unique_ptr<const char[]> copied_buffer_;
size_t buffer_size_bytes_ = 0;
};
class MemoryAllocation : public Allocation {
public:
// Allocates memory with the pointer and the number of bytes of the memory.
// The pointer has to remain alive and unchanged until the destructor is
// called.
// Provides a (read-only) view of the provided buffer region as an allocation.
// Note: The caller retains ownership of `ptr`, and must ensure it remains
// valid for the lifetime of the class instance.
MemoryAllocation(const void* ptr, size_t num_bytes,
ErrorReporter* error_reporter);
virtual ~MemoryAllocation();

View File

@ -0,0 +1,90 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/allocation.h"
#if defined(__linux__)
#include <fcntl.h>
#endif
#include <string>
#include <gtest/gtest.h>
#include "tensorflow/lite/testing/util.h"
namespace tflite {
TEST(MMAPAllocation, TestInvalidFile) {
if (!MMAPAllocation::IsSupported()) {
return;
}
TestErrorReporter error_reporter;
MMAPAllocation allocation("/tmp/tflite_model_1234", &error_reporter);
EXPECT_FALSE(allocation.valid());
}
TEST(MMAPAllocation, TestValidFile) {
if (!MMAPAllocation::IsSupported()) {
return;
}
TestErrorReporter error_reporter;
MMAPAllocation allocation(
"tensorflow/lite/testdata/empty_model.bin", &error_reporter);
ASSERT_TRUE(allocation.valid());
EXPECT_GT(allocation.fd(), 0);
EXPECT_GT(allocation.bytes(), 0);
EXPECT_NE(allocation.base(), nullptr);
}
#if defined(__linux__)
TEST(MMAPAllocation, TestInvalidFileDescriptor) {
if (!MMAPAllocation::IsSupported()) {
return;
}
TestErrorReporter error_reporter;
MMAPAllocation allocation(-1, &error_reporter);
EXPECT_FALSE(allocation.valid());
}
TEST(MMAPAllocation, TestValidFileDescriptor) {
if (!MMAPAllocation::IsSupported()) {
return;
}
int fd =
open("tensorflow/lite/testdata/empty_model.bin", O_RDONLY);
ASSERT_GT(fd, 0);
TestErrorReporter error_reporter;
MMAPAllocation allocation(fd, &error_reporter);
EXPECT_TRUE(allocation.valid());
EXPECT_GT(allocation.fd(), 0);
EXPECT_GT(allocation.bytes(), 0);
EXPECT_NE(allocation.base(), nullptr);
close(fd);
}
#endif
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -81,7 +81,7 @@ public final class NativeInterpreterWrapperTest {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INVALID_MODEL_PATH);
fail();
} catch (IllegalArgumentException e) {
assertThat(e).hasMessageThat().contains("The model is not a valid Flatbuffer file");
assertThat(e).hasMessageThat().contains("The model is not a valid Flatbuffer");
}
}
@ -92,7 +92,6 @@ public final class NativeInterpreterWrapperTest {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(NONEXISTING_MODEL_PATH);
fail();
} catch (IllegalArgumentException e) {
assertThat(e).hasMessageThat().contains("The model is not a valid Flatbuffer file");
assertThat(e).hasMessageThat().contains("Could not open");
}
}

View File

@ -26,11 +26,26 @@ namespace tflite {
MMAPAllocation::MMAPAllocation(const char* filename,
ErrorReporter* error_reporter)
: Allocation(error_reporter, Allocation::Type::kMMap),
mmapped_buffer_(MAP_FAILED) {
mmap_fd_ = open(filename, O_RDONLY);
: MMAPAllocation(error_reporter, open(filename, O_RDONLY)) {
if (mmap_fd_ == -1) {
TF_LITE_REPORT_ERROR(error_reporter, "Could not open '%s'.", filename);
}
}
MMAPAllocation::MMAPAllocation(int fd, ErrorReporter* error_reporter)
: MMAPAllocation(error_reporter, dup(fd)) {
if (mmap_fd_ == -1) {
TF_LITE_REPORT_ERROR(error_reporter, "Failed to dup '%d' file descriptor.",
fd);
}
}
MMAPAllocation::MMAPAllocation(ErrorReporter* error_reporter, int owned_fd)
: Allocation(error_reporter, Allocation::Type::kMMap),
mmap_fd_(owned_fd),
mmapped_buffer_(MAP_FAILED),
buffer_size_bytes_(0) {
if (mmap_fd_ == -1) {
error_reporter_->Report("Could not open '%s'.", filename);
return;
}
struct stat sb;
@ -39,7 +54,7 @@ MMAPAllocation::MMAPAllocation(const char* filename,
mmapped_buffer_ =
mmap(nullptr, buffer_size_bytes_, PROT_READ, MAP_SHARED, mmap_fd_, 0);
if (mmapped_buffer_ == MAP_FAILED) {
error_reporter_->Report("Mmap of '%s' failed.", filename);
TF_LITE_REPORT_ERROR(error_reporter, "Mmap of '%d' failed.", mmap_fd_);
return;
}
}

View File

@ -21,6 +21,12 @@ namespace tflite {
MMAPAllocation::MMAPAllocation(const char* filename,
ErrorReporter* error_reporter)
: MMAPAllocation(error_reporter, -1) {}
MMAPAllocation::MMAPAllocation(int fd, ErrorReporter* error_reporter)
: MMAPAllocation(error_reporter, -1) {}
MMAPAllocation::MMAPAllocation(ErrorReporter* error_reporter, int owned_fd)
: Allocation(error_reporter, Allocation::Type::kMMap),
mmapped_buffer_(nullptr) {
// The disabled variant should never be created.

View File

@ -43,12 +43,10 @@ ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
#ifndef TFLITE_MCU
// Loads a model from `filename`. If `mmap_file` is true then use mmap,
// otherwise make a copy of the model in a buffer.
std::unique_ptr<Allocation> GetAllocationFromFile(const char* filename,
bool mmap_file,
ErrorReporter* error_reporter,
bool use_nnapi) {
std::unique_ptr<Allocation> GetAllocationFromFile(
const char* filename, ErrorReporter* error_reporter) {
std::unique_ptr<Allocation> allocation;
if (mmap_file && MMAPAllocation::IsSupported()) {
if (MMAPAllocation::IsSupported()) {
allocation.reset(new MMAPAllocation(filename, error_reporter));
} else {
allocation.reset(new FileCopyAllocation(filename, error_reporter));
@ -59,41 +57,17 @@ std::unique_ptr<Allocation> GetAllocationFromFile(const char* filename,
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
const char* filename, ErrorReporter* error_reporter) {
error_reporter = ValidateErrorReporter(error_reporter);
std::unique_ptr<FlatBufferModel> model;
auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
error_reporter, /*use_nnapi=*/true);
model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
if (!model->initialized()) model.reset();
return model;
return BuildFromAllocation(GetAllocationFromFile(filename, error_reporter),
error_reporter);
}
std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile(
const char* filename, TfLiteVerifier* extra_verifier,
ErrorReporter* error_reporter) {
error_reporter = ValidateErrorReporter(error_reporter);
std::unique_ptr<FlatBufferModel> model;
auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
error_reporter, /*use_nnapi=*/true);
flatbuffers::Verifier base_verifier(
reinterpret_cast<const uint8_t*>(allocation->base()),
allocation->bytes());
if (!VerifyModelBuffer(base_verifier)) {
TF_LITE_REPORT_ERROR(error_reporter,
"The model is not a valid Flatbuffer file");
return nullptr;
}
if (extra_verifier &&
!extra_verifier->Verify(static_cast<const char*>(allocation->base()),
allocation->bytes(), error_reporter)) {
return model;
}
model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
if (!model->initialized()) model.reset();
return model;
return VerifyAndBuildFromAllocation(
GetAllocationFromFile(filename, error_reporter), extra_verifier,
error_reporter);
}
#endif
@ -101,34 +75,57 @@ std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
const char* caller_owned_buffer, size_t buffer_size,
ErrorReporter* error_reporter) {
error_reporter = ValidateErrorReporter(error_reporter);
std::unique_ptr<FlatBufferModel> model;
std::unique_ptr<Allocation> allocation(
new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
if (!model->initialized()) model.reset();
return model;
return BuildFromAllocation(std::move(allocation), error_reporter);
}
std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromBuffer(
const char* caller_owned_buffer, size_t buffer_size,
TfLiteVerifier* extra_verifier, ErrorReporter* error_reporter) {
error_reporter = ValidateErrorReporter(error_reporter);
std::unique_ptr<Allocation> allocation(
new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
return VerifyAndBuildFromAllocation(std::move(allocation), extra_verifier,
error_reporter);
}
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromAllocation(
std::unique_ptr<Allocation> allocation, ErrorReporter* error_reporter) {
std::unique_ptr<FlatBufferModel> model(new FlatBufferModel(
std::move(allocation), ValidateErrorReporter(error_reporter)));
if (!model->initialized()) {
model.reset();
}
return model;
}
std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromAllocation(
std::unique_ptr<Allocation> allocation, TfLiteVerifier* extra_verifier,
ErrorReporter* error_reporter) {
error_reporter = ValidateErrorReporter(error_reporter);
if (!allocation || !allocation->valid()) {
TF_LITE_REPORT_ERROR(error_reporter, "The model allocation is null/empty");
return nullptr;
}
flatbuffers::Verifier base_verifier(
reinterpret_cast<const uint8_t*>(caller_owned_buffer), buffer_size);
reinterpret_cast<const uint8_t*>(allocation->base()),
allocation->bytes());
if (!VerifyModelBuffer(base_verifier)) {
TF_LITE_REPORT_ERROR(error_reporter,
"The model is not a valid Flatbuffer buffer");
return nullptr;
}
if (extra_verifier && !extra_verifier->Verify(caller_owned_buffer,
buffer_size, error_reporter)) {
if (extra_verifier &&
!extra_verifier->Verify(static_cast<const char*>(allocation->base()),
allocation->bytes(), error_reporter)) {
// The verifier will have already logged an appropriate error message.
return nullptr;
}
return BuildFromBuffer(caller_owned_buffer, buffer_size, error_reporter);
return BuildFromAllocation(std::move(allocation), error_reporter);
}
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
@ -136,9 +133,11 @@ std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
ErrorReporter* error_reporter) {
error_reporter = ValidateErrorReporter(error_reporter);
std::unique_ptr<FlatBufferModel> model;
model.reset(new FlatBufferModel(caller_owned_model_spec, error_reporter));
if (!model->initialized()) model.reset();
std::unique_ptr<FlatBufferModel> model(
new FlatBufferModel(caller_owned_model_spec, error_reporter));
if (!model->initialized()) {
model.reset();
}
return model;
}
@ -189,7 +188,9 @@ FlatBufferModel::FlatBufferModel(std::unique_ptr<Allocation> allocation,
ErrorReporter* error_reporter)
: error_reporter_(ValidateErrorReporter(error_reporter)),
allocation_(std::move(allocation)) {
if (!allocation_->valid() || !CheckModelIdentifier()) return;
if (!allocation_ || !allocation_->valid() || !CheckModelIdentifier()) {
return;
}
model_ = ::tflite::GetModel(allocation_->base());
}

View File

@ -110,6 +110,30 @@ class FlatBufferModel {
TfLiteVerifier* extra_verifier = nullptr,
ErrorReporter* error_reporter = DefaultErrorReporter());
/// Builds a model directly from an allocation.
/// Ownership of the allocation is passed to the model, but the caller
/// retains ownership of `error_reporter` and must ensure its lifetime is
/// longer than the FlatBufferModel instance.
/// Returns a nullptr in case of failure (e.g., the allocation is invalid).
static std::unique_ptr<FlatBufferModel> BuildFromAllocation(
std::unique_ptr<Allocation> allocation,
ErrorReporter* error_reporter = DefaultErrorReporter());
/// Verifies whether the content of the allocation is legit, then builds a
/// model based on the provided allocation.
/// The extra_verifier argument is an additional optional verifier for the
/// buffer. By default, we always check with tflite::VerifyModelBuffer. If
/// extra_verifier is supplied, the buffer is checked against the
/// extra_verifier after the check against tflite::VerifyModelBuilder.
/// Ownership of the allocation is passed to the model, but the caller
/// retains ownership of `error_reporter` and must ensure its lifetime is
/// longer than the FlatBufferModel instance.
/// Returns a nullptr in case of failure.
static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromAllocation(
std::unique_ptr<Allocation> allocation,
TfLiteVerifier* extra_verifier = nullptr,
ErrorReporter* error_reporter = DefaultErrorReporter());
/// Builds a model directly from a flatbuffer pointer
/// Caller retains ownership of the buffer and should keep it alive until the
/// returned object is destroyed. Caller retains ownership of `error_reporter`

View File

@ -384,6 +384,43 @@ TEST(BasicFlatBufferModel, TestBuildFromModel) {
ASSERT_NE(interpreter, nullptr);
}
// Test that loading model directly from an Allocation works.
TEST(BasicFlatBufferModel, TestBuildFromAllocation) {
TestErrorReporter reporter;
std::unique_ptr<Allocation> model_allocation(new FileCopyAllocation(
"tensorflow/lite/testdata/test_model.bin", &reporter));
ASSERT_TRUE(model_allocation->valid());
auto model =
FlatBufferModel::BuildFromAllocation(std::move(model_allocation));
ASSERT_TRUE(model);
std::unique_ptr<Interpreter> interpreter;
ASSERT_EQ(
InterpreterBuilder(*model, TrivialResolver(&dummy_reg))(&interpreter),
kTfLiteOk);
ASSERT_NE(interpreter, nullptr);
}
TEST(BasicFlatBufferModel, TestBuildFromNullAllocation) {
TestErrorReporter reporter;
std::unique_ptr<Allocation> model_allocation;
auto model =
FlatBufferModel::BuildFromAllocation(std::move(model_allocation));
ASSERT_FALSE(model);
}
TEST(BasicFlatBufferModel, TestBuildFromInvalidAllocation) {
TestErrorReporter reporter;
std::unique_ptr<Allocation> model_allocation(
new MemoryAllocation(nullptr, 0, nullptr));
auto model =
FlatBufferModel::BuildFromAllocation(std::move(model_allocation));
ASSERT_FALSE(model);
}
// Test reading the minimum runtime string from metadata in a Model flatbuffer.
TEST(BasicFlatBufferModel, TestReadRuntimeVersionFromModel) {
// First read a model that doesn't have the runtime string.