[tf.lite] Add FlatBufferModel::BuildFromAllocation() API

This allows more flexibility in providing models from alternative sources.
Also add a `MMAPAllocation(int fd)` API for enabling user-provided file
descriptors as model sources.

Resolves .

PiperOrigin-RevId: 356738903
Change-Id: Ibbbf774b866f5b733f2b6751c800a07695553339
This commit is contained in:
Mihai Maruseac 2021-02-10 08:25:01 -08:00 committed by TensorFlower Gardener
parent 03bd8a34ed
commit 119ade110d
9 changed files with 60 additions and 259 deletions

View File

@ -703,23 +703,6 @@ cc_test(
],
)
cc_test(
name = "allocation_test",
size = "small",
srcs = ["allocation_test.cc"],
data = [
"testdata/empty_model.bin",
],
tags = [
"tflite_smoke_test",
],
deps = [
":allocation",
"//tensorflow/lite/testing:util",
"@com_google_googletest//:gtest",
],
)
# Test OpResolver.
cc_test(
name = "mutable_op_resolver_test",

View File

@ -22,8 +22,11 @@ limitations under the License.
#include <cstdio>
#include <cstdlib>
#include <memory>
#include <vector>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/string_type.h"
namespace tflite {
@ -56,18 +59,9 @@ class Allocation {
const Type type_;
};
// Note that not all platforms support MMAP-based allocation.
// Use `IsSupported()` to check.
class MMAPAllocation : public Allocation {
public:
// Loads and maps the provided file to a memory region.
MMAPAllocation(const char* filename, ErrorReporter* error_reporter);
// Maps the provided file descriptor to a memory region.
// Note: The provided file descriptor will be dup'ed for usage; the caller
// retains ownership of the provided descriptor and should close accordingly.
MMAPAllocation(int fd, ErrorReporter* error_reporter);
virtual ~MMAPAllocation();
const void* base() const override;
size_t bytes() const override;
@ -82,15 +76,10 @@ class MMAPAllocation : public Allocation {
int mmap_fd_ = -1; // mmap file descriptor
const void* mmapped_buffer_;
size_t buffer_size_bytes_ = 0;
private:
// Assumes ownership of the provided `owned_fd` instance.
MMAPAllocation(ErrorReporter* error_reporter, int owned_fd);
};
class FileCopyAllocation : public Allocation {
public:
// Loads the provided file into a heap memory region.
FileCopyAllocation(const char* filename, ErrorReporter* error_reporter);
virtual ~FileCopyAllocation();
const void* base() const override;
@ -98,15 +87,16 @@ class FileCopyAllocation : public Allocation {
bool valid() const override;
private:
// Data required for mmap.
std::unique_ptr<const char[]> copied_buffer_;
size_t buffer_size_bytes_ = 0;
};
class MemoryAllocation : public Allocation {
public:
// Provides a (read-only) view of the provided buffer region as an allocation.
// Note: The caller retains ownership of `ptr`, and must ensure it remains
// valid for the lifetime of the class instance.
// Allocates memory with the pointer and the number of bytes of the memory.
// The pointer has to remain alive and unchanged until the destructor is
// called.
MemoryAllocation(const void* ptr, size_t num_bytes,
ErrorReporter* error_reporter);
virtual ~MemoryAllocation();

View File

@ -1,90 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/allocation.h"
#if defined(__linux__)
#include <fcntl.h>
#endif
#include <string>
#include <gtest/gtest.h>
#include "tensorflow/lite/testing/util.h"
namespace tflite {
TEST(MMAPAllocation, TestInvalidFile) {
if (!MMAPAllocation::IsSupported()) {
return;
}
TestErrorReporter error_reporter;
MMAPAllocation allocation("/tmp/tflite_model_1234", &error_reporter);
EXPECT_FALSE(allocation.valid());
}
TEST(MMAPAllocation, TestValidFile) {
if (!MMAPAllocation::IsSupported()) {
return;
}
TestErrorReporter error_reporter;
MMAPAllocation allocation(
"third_party/tensorflow/lite/testdata/empty_model.bin", &error_reporter);
ASSERT_TRUE(allocation.valid());
EXPECT_GT(allocation.fd(), 0);
EXPECT_GT(allocation.bytes(), 0);
EXPECT_NE(allocation.base(), nullptr);
}
#if defined(__linux__)
TEST(MMAPAllocation, TestInvalidFileDescriptor) {
if (!MMAPAllocation::IsSupported()) {
return;
}
TestErrorReporter error_reporter;
MMAPAllocation allocation(-1, &error_reporter);
EXPECT_FALSE(allocation.valid());
}
TEST(MMAPAllocation, TestValidFileDescriptor) {
if (!MMAPAllocation::IsSupported()) {
return;
}
int fd =
open("third_party/tensorflow/lite/testdata/empty_model.bin", O_RDONLY);
ASSERT_GT(fd, 0);
TestErrorReporter error_reporter;
MMAPAllocation allocation(fd, &error_reporter);
EXPECT_TRUE(allocation.valid());
EXPECT_GT(allocation.fd(), 0);
EXPECT_GT(allocation.bytes(), 0);
EXPECT_NE(allocation.base(), nullptr);
close(fd);
}
#endif
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -81,7 +81,7 @@ public final class NativeInterpreterWrapperTest {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INVALID_MODEL_PATH);
fail();
} catch (IllegalArgumentException e) {
assertThat(e).hasMessageThat().contains("The model is not a valid Flatbuffer");
assertThat(e).hasMessageThat().contains("The model is not a valid Flatbuffer file");
}
}
@ -92,6 +92,7 @@ public final class NativeInterpreterWrapperTest {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(NONEXISTING_MODEL_PATH);
fail();
} catch (IllegalArgumentException e) {
assertThat(e).hasMessageThat().contains("The model is not a valid Flatbuffer file");
assertThat(e).hasMessageThat().contains("Could not open");
}
}

View File

@ -26,26 +26,11 @@ namespace tflite {
MMAPAllocation::MMAPAllocation(const char* filename,
ErrorReporter* error_reporter)
: MMAPAllocation(error_reporter, open(filename, O_RDONLY)) {
if (mmap_fd_ == -1) {
TF_LITE_REPORT_ERROR(error_reporter, "Could not open '%s'.", filename);
}
}
MMAPAllocation::MMAPAllocation(int fd, ErrorReporter* error_reporter)
: MMAPAllocation(error_reporter, dup(fd)) {
if (mmap_fd_ == -1) {
TF_LITE_REPORT_ERROR(error_reporter, "Failed to dup '%d' file descriptor.",
fd);
}
}
MMAPAllocation::MMAPAllocation(ErrorReporter* error_reporter, int owned_fd)
: Allocation(error_reporter, Allocation::Type::kMMap),
mmap_fd_(owned_fd),
mmapped_buffer_(MAP_FAILED),
buffer_size_bytes_(0) {
mmapped_buffer_(MAP_FAILED) {
mmap_fd_ = open(filename, O_RDONLY);
if (mmap_fd_ == -1) {
error_reporter_->Report("Could not open '%s'.", filename);
return;
}
struct stat sb;
@ -54,7 +39,7 @@ MMAPAllocation::MMAPAllocation(ErrorReporter* error_reporter, int owned_fd)
mmapped_buffer_ =
mmap(nullptr, buffer_size_bytes_, PROT_READ, MAP_SHARED, mmap_fd_, 0);
if (mmapped_buffer_ == MAP_FAILED) {
TF_LITE_REPORT_ERROR(error_reporter, "Mmap of '%d' failed.", mmap_fd_);
error_reporter_->Report("Mmap of '%s' failed.", filename);
return;
}
}

View File

@ -21,12 +21,6 @@ namespace tflite {
MMAPAllocation::MMAPAllocation(const char* filename,
ErrorReporter* error_reporter)
: MMAPAllocation(error_reporter, -1) {}
MMAPAllocation::MMAPAllocation(int fd, ErrorReporter* error_reporter)
: MMAPAllocation(error_reporter, -1) {}
MMAPAllocation::MMAPAllocation(ErrorReporter* error_reporter, int owned_fd)
: Allocation(error_reporter, Allocation::Type::kMMap),
mmapped_buffer_(nullptr) {
// The disabled variant should never be created.

View File

@ -43,10 +43,12 @@ ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
#ifndef TFLITE_MCU
// Loads a model from `filename`. If `mmap_file` is true then use mmap,
// otherwise make a copy of the model in a buffer.
std::unique_ptr<Allocation> GetAllocationFromFile(
const char* filename, ErrorReporter* error_reporter) {
std::unique_ptr<Allocation> GetAllocationFromFile(const char* filename,
bool mmap_file,
ErrorReporter* error_reporter,
bool use_nnapi) {
std::unique_ptr<Allocation> allocation;
if (MMAPAllocation::IsSupported()) {
if (mmap_file && MMAPAllocation::IsSupported()) {
allocation.reset(new MMAPAllocation(filename, error_reporter));
} else {
allocation.reset(new FileCopyAllocation(filename, error_reporter));
@ -57,17 +59,41 @@ std::unique_ptr<Allocation> GetAllocationFromFile(
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
const char* filename, ErrorReporter* error_reporter) {
error_reporter = ValidateErrorReporter(error_reporter);
return BuildFromAllocation(GetAllocationFromFile(filename, error_reporter),
error_reporter);
std::unique_ptr<FlatBufferModel> model;
auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
error_reporter, /*use_nnapi=*/true);
model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
if (!model->initialized()) model.reset();
return model;
}
std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile(
const char* filename, TfLiteVerifier* extra_verifier,
ErrorReporter* error_reporter) {
error_reporter = ValidateErrorReporter(error_reporter);
return VerifyAndBuildFromAllocation(
GetAllocationFromFile(filename, error_reporter), extra_verifier,
error_reporter);
std::unique_ptr<FlatBufferModel> model;
auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
error_reporter, /*use_nnapi=*/true);
flatbuffers::Verifier base_verifier(
reinterpret_cast<const uint8_t*>(allocation->base()),
allocation->bytes());
if (!VerifyModelBuffer(base_verifier)) {
TF_LITE_REPORT_ERROR(error_reporter,
"The model is not a valid Flatbuffer file");
return nullptr;
}
if (extra_verifier &&
!extra_verifier->Verify(static_cast<const char*>(allocation->base()),
allocation->bytes(), error_reporter)) {
return model;
}
model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
if (!model->initialized()) model.reset();
return model;
}
#endif
@ -75,57 +101,34 @@ std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
const char* caller_owned_buffer, size_t buffer_size,
ErrorReporter* error_reporter) {
error_reporter = ValidateErrorReporter(error_reporter);
std::unique_ptr<FlatBufferModel> model;
std::unique_ptr<Allocation> allocation(
new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
return BuildFromAllocation(std::move(allocation), error_reporter);
model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
if (!model->initialized()) model.reset();
return model;
}
std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromBuffer(
const char* caller_owned_buffer, size_t buffer_size,
TfLiteVerifier* extra_verifier, ErrorReporter* error_reporter) {
error_reporter = ValidateErrorReporter(error_reporter);
std::unique_ptr<Allocation> allocation(
new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
return VerifyAndBuildFromAllocation(std::move(allocation), extra_verifier,
error_reporter);
}
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromAllocation(
std::unique_ptr<Allocation> allocation, ErrorReporter* error_reporter) {
std::unique_ptr<FlatBufferModel> model(new FlatBufferModel(
std::move(allocation), ValidateErrorReporter(error_reporter)));
if (!model->initialized()) {
model.reset();
}
return model;
}
std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromAllocation(
std::unique_ptr<Allocation> allocation, TfLiteVerifier* extra_verifier,
ErrorReporter* error_reporter) {
error_reporter = ValidateErrorReporter(error_reporter);
if (!allocation || !allocation->valid()) {
TF_LITE_REPORT_ERROR(error_reporter, "The model allocation is null/empty");
return nullptr;
}
flatbuffers::Verifier base_verifier(
reinterpret_cast<const uint8_t*>(allocation->base()),
allocation->bytes());
reinterpret_cast<const uint8_t*>(caller_owned_buffer), buffer_size);
if (!VerifyModelBuffer(base_verifier)) {
TF_LITE_REPORT_ERROR(error_reporter,
"The model is not a valid Flatbuffer buffer");
return nullptr;
}
if (extra_verifier &&
!extra_verifier->Verify(static_cast<const char*>(allocation->base()),
allocation->bytes(), error_reporter)) {
// The verifier will have already logged an appropriate error message.
if (extra_verifier && !extra_verifier->Verify(caller_owned_buffer,
buffer_size, error_reporter)) {
return nullptr;
}
return BuildFromAllocation(std::move(allocation), error_reporter);
return BuildFromBuffer(caller_owned_buffer, buffer_size, error_reporter);
}
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
@ -133,11 +136,9 @@ std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
ErrorReporter* error_reporter) {
error_reporter = ValidateErrorReporter(error_reporter);
std::unique_ptr<FlatBufferModel> model(
new FlatBufferModel(caller_owned_model_spec, error_reporter));
if (!model->initialized()) {
model.reset();
}
std::unique_ptr<FlatBufferModel> model;
model.reset(new FlatBufferModel(caller_owned_model_spec, error_reporter));
if (!model->initialized()) model.reset();
return model;
}
@ -188,9 +189,7 @@ FlatBufferModel::FlatBufferModel(std::unique_ptr<Allocation> allocation,
ErrorReporter* error_reporter)
: error_reporter_(ValidateErrorReporter(error_reporter)),
allocation_(std::move(allocation)) {
if (!allocation_ || !allocation_->valid() || !CheckModelIdentifier()) {
return;
}
if (!allocation_->valid() || !CheckModelIdentifier()) return;
model_ = ::tflite::GetModel(allocation_->base());
}

View File

@ -110,30 +110,6 @@ class FlatBufferModel {
TfLiteVerifier* extra_verifier = nullptr,
ErrorReporter* error_reporter = DefaultErrorReporter());
/// Builds a model directly from an allocation.
/// Ownership of the allocation is passed to the model, but the caller
/// retains ownership of `error_reporter` and must ensure its lifetime is
/// longer than the FlatBufferModel instance.
/// Returns a nullptr in case of failure (e.g., the allocation is invalid).
static std::unique_ptr<FlatBufferModel> BuildFromAllocation(
std::unique_ptr<Allocation> allocation,
ErrorReporter* error_reporter = DefaultErrorReporter());
/// Verifies whether the content of the allocation is legit, then builds a
/// model based on the provided allocation.
/// The extra_verifier argument is an additional optional verifier for the
/// buffer. By default, we always check with tflite::VerifyModelBuffer. If
/// extra_verifier is supplied, the buffer is checked against the
/// extra_verifier after the check against tflite::VerifyModelBuilder.
/// Ownership of the allocation is passed to the model, but the caller
/// retains ownership of `error_reporter` and must ensure its lifetime is
/// longer than the FlatBufferModel instance.
/// Returns a nullptr in case of failure.
static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromAllocation(
std::unique_ptr<Allocation> allocation,
TfLiteVerifier* extra_verifier = nullptr,
ErrorReporter* error_reporter = DefaultErrorReporter());
/// Builds a model directly from a flatbuffer pointer
/// Caller retains ownership of the buffer and should keep it alive until the
/// returned object is destroyed. Caller retains ownership of `error_reporter`

View File

@ -372,43 +372,6 @@ TEST(BasicFlatBufferModel, TestBuildFromModel) {
ASSERT_NE(interpreter, nullptr);
}
// Test that loading model directly from an Allocation works.
TEST(BasicFlatBufferModel, TestBuildFromAllocation) {
TestErrorReporter reporter;
std::unique_ptr<Allocation> model_allocation(new FileCopyAllocation(
"tensorflow/lite/testdata/test_model.bin", &reporter));
ASSERT_TRUE(model_allocation->valid());
auto model =
FlatBufferModel::BuildFromAllocation(std::move(model_allocation));
ASSERT_TRUE(model);
std::unique_ptr<Interpreter> interpreter;
ASSERT_EQ(
InterpreterBuilder(*model, TrivialResolver(&dummy_reg))(&interpreter),
kTfLiteOk);
ASSERT_NE(interpreter, nullptr);
}
TEST(BasicFlatBufferModel, TestBuildFromNullAllocation) {
TestErrorReporter reporter;
std::unique_ptr<Allocation> model_allocation;
auto model =
FlatBufferModel::BuildFromAllocation(std::move(model_allocation));
ASSERT_FALSE(model);
}
TEST(BasicFlatBufferModel, TestBuildFromInvalidAllocation) {
TestErrorReporter reporter;
std::unique_ptr<Allocation> model_allocation(
new MemoryAllocation(nullptr, 0, nullptr));
auto model =
FlatBufferModel::BuildFromAllocation(std::move(model_allocation));
ASSERT_FALSE(model);
}
// Test reading the minimum runtime string from metadata in a Model flatbuffer.
TEST(BasicFlatBufferModel, TestReadRuntimeVersionFromModel) {
// First read a model that doesn't have the runtime string.