[tf.lite] (Re-land) Add FlatBufferModel::BuildFromAllocation() API
This allows more flexibility in providing models from alternative sources. Also add a `MMAPAllocation(int fd)` API for enabling user-provided file descriptors as model sources. Resolves #46593. PiperOrigin-RevId: 358931162 Change-Id: I6a443f73f7bce2c0c18dc35f20650715170365cb
This commit is contained in:
parent
901d914ab7
commit
7da3e2c1d4
@ -704,6 +704,23 @@ cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "allocation_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["allocation_test.cc"],
|
||||||
|
data = [
|
||||||
|
"testdata/empty_model.bin",
|
||||||
|
],
|
||||||
|
tags = [
|
||||||
|
"tflite_smoke_test",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":allocation",
|
||||||
|
"//tensorflow/lite/testing:util",
|
||||||
|
"@com_google_googletest//:gtest",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# Test OpResolver.
|
# Test OpResolver.
|
||||||
cc_test(
|
cc_test(
|
||||||
name = "mutable_op_resolver_test",
|
name = "mutable_op_resolver_test",
|
||||||
|
@ -22,11 +22,8 @@ limitations under the License.
|
|||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "tensorflow/lite/c/common.h"
|
|
||||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||||
#include "tensorflow/lite/string_type.h"
|
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
|
|
||||||
@ -59,9 +56,18 @@ class Allocation {
|
|||||||
const Type type_;
|
const Type type_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Note that not all platforms support MMAP-based allocation.
|
||||||
|
// Use `IsSupported()` to check.
|
||||||
class MMAPAllocation : public Allocation {
|
class MMAPAllocation : public Allocation {
|
||||||
public:
|
public:
|
||||||
|
// Loads and maps the provided file to a memory region.
|
||||||
MMAPAllocation(const char* filename, ErrorReporter* error_reporter);
|
MMAPAllocation(const char* filename, ErrorReporter* error_reporter);
|
||||||
|
|
||||||
|
// Maps the provided file descriptor to a memory region.
|
||||||
|
// Note: The provided file descriptor will be dup'ed for usage; the caller
|
||||||
|
// retains ownership of the provided descriptor and should close accordingly.
|
||||||
|
MMAPAllocation(int fd, ErrorReporter* error_reporter);
|
||||||
|
|
||||||
virtual ~MMAPAllocation();
|
virtual ~MMAPAllocation();
|
||||||
const void* base() const override;
|
const void* base() const override;
|
||||||
size_t bytes() const override;
|
size_t bytes() const override;
|
||||||
@ -76,10 +82,15 @@ class MMAPAllocation : public Allocation {
|
|||||||
int mmap_fd_ = -1; // mmap file descriptor
|
int mmap_fd_ = -1; // mmap file descriptor
|
||||||
const void* mmapped_buffer_;
|
const void* mmapped_buffer_;
|
||||||
size_t buffer_size_bytes_ = 0;
|
size_t buffer_size_bytes_ = 0;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Assumes ownership of the provided `owned_fd` instance.
|
||||||
|
MMAPAllocation(ErrorReporter* error_reporter, int owned_fd);
|
||||||
};
|
};
|
||||||
|
|
||||||
class FileCopyAllocation : public Allocation {
|
class FileCopyAllocation : public Allocation {
|
||||||
public:
|
public:
|
||||||
|
// Loads the provided file into a heap memory region.
|
||||||
FileCopyAllocation(const char* filename, ErrorReporter* error_reporter);
|
FileCopyAllocation(const char* filename, ErrorReporter* error_reporter);
|
||||||
virtual ~FileCopyAllocation();
|
virtual ~FileCopyAllocation();
|
||||||
const void* base() const override;
|
const void* base() const override;
|
||||||
@ -87,16 +98,15 @@ class FileCopyAllocation : public Allocation {
|
|||||||
bool valid() const override;
|
bool valid() const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Data required for mmap.
|
|
||||||
std::unique_ptr<const char[]> copied_buffer_;
|
std::unique_ptr<const char[]> copied_buffer_;
|
||||||
size_t buffer_size_bytes_ = 0;
|
size_t buffer_size_bytes_ = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
class MemoryAllocation : public Allocation {
|
class MemoryAllocation : public Allocation {
|
||||||
public:
|
public:
|
||||||
// Allocates memory with the pointer and the number of bytes of the memory.
|
// Provides a (read-only) view of the provided buffer region as an allocation.
|
||||||
// The pointer has to remain alive and unchanged until the destructor is
|
// Note: The caller retains ownership of `ptr`, and must ensure it remains
|
||||||
// called.
|
// valid for the lifetime of the class instance.
|
||||||
MemoryAllocation(const void* ptr, size_t num_bytes,
|
MemoryAllocation(const void* ptr, size_t num_bytes,
|
||||||
ErrorReporter* error_reporter);
|
ErrorReporter* error_reporter);
|
||||||
virtual ~MemoryAllocation();
|
virtual ~MemoryAllocation();
|
||||||
|
90
tensorflow/lite/allocation_test.cc
Normal file
90
tensorflow/lite/allocation_test.cc
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/lite/allocation.h"
|
||||||
|
|
||||||
|
#if defined(__linux__)
|
||||||
|
#include <fcntl.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include "tensorflow/lite/testing/util.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
|
||||||
|
TEST(MMAPAllocation, TestInvalidFile) {
|
||||||
|
if (!MMAPAllocation::IsSupported()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
TestErrorReporter error_reporter;
|
||||||
|
MMAPAllocation allocation("/tmp/tflite_model_1234", &error_reporter);
|
||||||
|
EXPECT_FALSE(allocation.valid());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(MMAPAllocation, TestValidFile) {
|
||||||
|
if (!MMAPAllocation::IsSupported()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
TestErrorReporter error_reporter;
|
||||||
|
MMAPAllocation allocation(
|
||||||
|
"tensorflow/lite/testdata/empty_model.bin", &error_reporter);
|
||||||
|
|
||||||
|
ASSERT_TRUE(allocation.valid());
|
||||||
|
EXPECT_GT(allocation.fd(), 0);
|
||||||
|
EXPECT_GT(allocation.bytes(), 0);
|
||||||
|
EXPECT_NE(allocation.base(), nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
#if defined(__linux__)
|
||||||
|
TEST(MMAPAllocation, TestInvalidFileDescriptor) {
|
||||||
|
if (!MMAPAllocation::IsSupported()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
TestErrorReporter error_reporter;
|
||||||
|
MMAPAllocation allocation(-1, &error_reporter);
|
||||||
|
EXPECT_FALSE(allocation.valid());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(MMAPAllocation, TestValidFileDescriptor) {
|
||||||
|
if (!MMAPAllocation::IsSupported()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int fd =
|
||||||
|
open("tensorflow/lite/testdata/empty_model.bin", O_RDONLY);
|
||||||
|
ASSERT_GT(fd, 0);
|
||||||
|
|
||||||
|
TestErrorReporter error_reporter;
|
||||||
|
MMAPAllocation allocation(fd, &error_reporter);
|
||||||
|
EXPECT_TRUE(allocation.valid());
|
||||||
|
EXPECT_GT(allocation.fd(), 0);
|
||||||
|
EXPECT_GT(allocation.bytes(), 0);
|
||||||
|
EXPECT_NE(allocation.base(), nullptr);
|
||||||
|
|
||||||
|
close(fd);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
int main(int argc, char** argv) {
|
||||||
|
::tflite::LogToStderr();
|
||||||
|
::testing::InitGoogleTest(&argc, argv);
|
||||||
|
return RUN_ALL_TESTS();
|
||||||
|
}
|
@ -81,7 +81,7 @@ public final class NativeInterpreterWrapperTest {
|
|||||||
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INVALID_MODEL_PATH);
|
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INVALID_MODEL_PATH);
|
||||||
fail();
|
fail();
|
||||||
} catch (IllegalArgumentException e) {
|
} catch (IllegalArgumentException e) {
|
||||||
assertThat(e).hasMessageThat().contains("The model is not a valid Flatbuffer file");
|
assertThat(e).hasMessageThat().contains("The model is not a valid Flatbuffer");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -92,7 +92,6 @@ public final class NativeInterpreterWrapperTest {
|
|||||||
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(NONEXISTING_MODEL_PATH);
|
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(NONEXISTING_MODEL_PATH);
|
||||||
fail();
|
fail();
|
||||||
} catch (IllegalArgumentException e) {
|
} catch (IllegalArgumentException e) {
|
||||||
assertThat(e).hasMessageThat().contains("The model is not a valid Flatbuffer file");
|
|
||||||
assertThat(e).hasMessageThat().contains("Could not open");
|
assertThat(e).hasMessageThat().contains("Could not open");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -26,11 +26,26 @@ namespace tflite {
|
|||||||
|
|
||||||
MMAPAllocation::MMAPAllocation(const char* filename,
|
MMAPAllocation::MMAPAllocation(const char* filename,
|
||||||
ErrorReporter* error_reporter)
|
ErrorReporter* error_reporter)
|
||||||
: Allocation(error_reporter, Allocation::Type::kMMap),
|
: MMAPAllocation(error_reporter, open(filename, O_RDONLY)) {
|
||||||
mmapped_buffer_(MAP_FAILED) {
|
if (mmap_fd_ == -1) {
|
||||||
mmap_fd_ = open(filename, O_RDONLY);
|
TF_LITE_REPORT_ERROR(error_reporter, "Could not open '%s'.", filename);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
MMAPAllocation::MMAPAllocation(int fd, ErrorReporter* error_reporter)
|
||||||
|
: MMAPAllocation(error_reporter, dup(fd)) {
|
||||||
|
if (mmap_fd_ == -1) {
|
||||||
|
TF_LITE_REPORT_ERROR(error_reporter, "Failed to dup '%d' file descriptor.",
|
||||||
|
fd);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
MMAPAllocation::MMAPAllocation(ErrorReporter* error_reporter, int owned_fd)
|
||||||
|
: Allocation(error_reporter, Allocation::Type::kMMap),
|
||||||
|
mmap_fd_(owned_fd),
|
||||||
|
mmapped_buffer_(MAP_FAILED),
|
||||||
|
buffer_size_bytes_(0) {
|
||||||
if (mmap_fd_ == -1) {
|
if (mmap_fd_ == -1) {
|
||||||
error_reporter_->Report("Could not open '%s'.", filename);
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
struct stat sb;
|
struct stat sb;
|
||||||
@ -39,7 +54,7 @@ MMAPAllocation::MMAPAllocation(const char* filename,
|
|||||||
mmapped_buffer_ =
|
mmapped_buffer_ =
|
||||||
mmap(nullptr, buffer_size_bytes_, PROT_READ, MAP_SHARED, mmap_fd_, 0);
|
mmap(nullptr, buffer_size_bytes_, PROT_READ, MAP_SHARED, mmap_fd_, 0);
|
||||||
if (mmapped_buffer_ == MAP_FAILED) {
|
if (mmapped_buffer_ == MAP_FAILED) {
|
||||||
error_reporter_->Report("Mmap of '%s' failed.", filename);
|
TF_LITE_REPORT_ERROR(error_reporter, "Mmap of '%d' failed.", mmap_fd_);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -21,6 +21,12 @@ namespace tflite {
|
|||||||
|
|
||||||
MMAPAllocation::MMAPAllocation(const char* filename,
|
MMAPAllocation::MMAPAllocation(const char* filename,
|
||||||
ErrorReporter* error_reporter)
|
ErrorReporter* error_reporter)
|
||||||
|
: MMAPAllocation(error_reporter, -1) {}
|
||||||
|
|
||||||
|
MMAPAllocation::MMAPAllocation(int fd, ErrorReporter* error_reporter)
|
||||||
|
: MMAPAllocation(error_reporter, -1) {}
|
||||||
|
|
||||||
|
MMAPAllocation::MMAPAllocation(ErrorReporter* error_reporter, int owned_fd)
|
||||||
: Allocation(error_reporter, Allocation::Type::kMMap),
|
: Allocation(error_reporter, Allocation::Type::kMMap),
|
||||||
mmapped_buffer_(nullptr) {
|
mmapped_buffer_(nullptr) {
|
||||||
// The disabled variant should never be created.
|
// The disabled variant should never be created.
|
||||||
|
@ -43,12 +43,10 @@ ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
|
|||||||
#ifndef TFLITE_MCU
|
#ifndef TFLITE_MCU
|
||||||
// Loads a model from `filename`. If `mmap_file` is true then use mmap,
|
// Loads a model from `filename`. If `mmap_file` is true then use mmap,
|
||||||
// otherwise make a copy of the model in a buffer.
|
// otherwise make a copy of the model in a buffer.
|
||||||
std::unique_ptr<Allocation> GetAllocationFromFile(const char* filename,
|
std::unique_ptr<Allocation> GetAllocationFromFile(
|
||||||
bool mmap_file,
|
const char* filename, ErrorReporter* error_reporter) {
|
||||||
ErrorReporter* error_reporter,
|
|
||||||
bool use_nnapi) {
|
|
||||||
std::unique_ptr<Allocation> allocation;
|
std::unique_ptr<Allocation> allocation;
|
||||||
if (mmap_file && MMAPAllocation::IsSupported()) {
|
if (MMAPAllocation::IsSupported()) {
|
||||||
allocation.reset(new MMAPAllocation(filename, error_reporter));
|
allocation.reset(new MMAPAllocation(filename, error_reporter));
|
||||||
} else {
|
} else {
|
||||||
allocation.reset(new FileCopyAllocation(filename, error_reporter));
|
allocation.reset(new FileCopyAllocation(filename, error_reporter));
|
||||||
@ -59,41 +57,17 @@ std::unique_ptr<Allocation> GetAllocationFromFile(const char* filename,
|
|||||||
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
|
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
|
||||||
const char* filename, ErrorReporter* error_reporter) {
|
const char* filename, ErrorReporter* error_reporter) {
|
||||||
error_reporter = ValidateErrorReporter(error_reporter);
|
error_reporter = ValidateErrorReporter(error_reporter);
|
||||||
|
return BuildFromAllocation(GetAllocationFromFile(filename, error_reporter),
|
||||||
std::unique_ptr<FlatBufferModel> model;
|
error_reporter);
|
||||||
auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
|
|
||||||
error_reporter, /*use_nnapi=*/true);
|
|
||||||
model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
|
|
||||||
if (!model->initialized()) model.reset();
|
|
||||||
return model;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile(
|
std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile(
|
||||||
const char* filename, TfLiteVerifier* extra_verifier,
|
const char* filename, TfLiteVerifier* extra_verifier,
|
||||||
ErrorReporter* error_reporter) {
|
ErrorReporter* error_reporter) {
|
||||||
error_reporter = ValidateErrorReporter(error_reporter);
|
error_reporter = ValidateErrorReporter(error_reporter);
|
||||||
|
return VerifyAndBuildFromAllocation(
|
||||||
std::unique_ptr<FlatBufferModel> model;
|
GetAllocationFromFile(filename, error_reporter), extra_verifier,
|
||||||
auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
|
error_reporter);
|
||||||
error_reporter, /*use_nnapi=*/true);
|
|
||||||
|
|
||||||
flatbuffers::Verifier base_verifier(
|
|
||||||
reinterpret_cast<const uint8_t*>(allocation->base()),
|
|
||||||
allocation->bytes());
|
|
||||||
if (!VerifyModelBuffer(base_verifier)) {
|
|
||||||
TF_LITE_REPORT_ERROR(error_reporter,
|
|
||||||
"The model is not a valid Flatbuffer file");
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (extra_verifier &&
|
|
||||||
!extra_verifier->Verify(static_cast<const char*>(allocation->base()),
|
|
||||||
allocation->bytes(), error_reporter)) {
|
|
||||||
return model;
|
|
||||||
}
|
|
||||||
model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
|
|
||||||
if (!model->initialized()) model.reset();
|
|
||||||
return model;
|
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@ -101,34 +75,57 @@ std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
|
|||||||
const char* caller_owned_buffer, size_t buffer_size,
|
const char* caller_owned_buffer, size_t buffer_size,
|
||||||
ErrorReporter* error_reporter) {
|
ErrorReporter* error_reporter) {
|
||||||
error_reporter = ValidateErrorReporter(error_reporter);
|
error_reporter = ValidateErrorReporter(error_reporter);
|
||||||
|
|
||||||
std::unique_ptr<FlatBufferModel> model;
|
|
||||||
std::unique_ptr<Allocation> allocation(
|
std::unique_ptr<Allocation> allocation(
|
||||||
new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
|
new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
|
||||||
model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
|
return BuildFromAllocation(std::move(allocation), error_reporter);
|
||||||
if (!model->initialized()) model.reset();
|
|
||||||
return model;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromBuffer(
|
std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromBuffer(
|
||||||
const char* caller_owned_buffer, size_t buffer_size,
|
const char* caller_owned_buffer, size_t buffer_size,
|
||||||
TfLiteVerifier* extra_verifier, ErrorReporter* error_reporter) {
|
TfLiteVerifier* extra_verifier, ErrorReporter* error_reporter) {
|
||||||
error_reporter = ValidateErrorReporter(error_reporter);
|
error_reporter = ValidateErrorReporter(error_reporter);
|
||||||
|
std::unique_ptr<Allocation> allocation(
|
||||||
|
new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
|
||||||
|
return VerifyAndBuildFromAllocation(std::move(allocation), extra_verifier,
|
||||||
|
error_reporter);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromAllocation(
|
||||||
|
std::unique_ptr<Allocation> allocation, ErrorReporter* error_reporter) {
|
||||||
|
std::unique_ptr<FlatBufferModel> model(new FlatBufferModel(
|
||||||
|
std::move(allocation), ValidateErrorReporter(error_reporter)));
|
||||||
|
if (!model->initialized()) {
|
||||||
|
model.reset();
|
||||||
|
}
|
||||||
|
return model;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromAllocation(
|
||||||
|
std::unique_ptr<Allocation> allocation, TfLiteVerifier* extra_verifier,
|
||||||
|
ErrorReporter* error_reporter) {
|
||||||
|
error_reporter = ValidateErrorReporter(error_reporter);
|
||||||
|
if (!allocation || !allocation->valid()) {
|
||||||
|
TF_LITE_REPORT_ERROR(error_reporter, "The model allocation is null/empty");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
flatbuffers::Verifier base_verifier(
|
flatbuffers::Verifier base_verifier(
|
||||||
reinterpret_cast<const uint8_t*>(caller_owned_buffer), buffer_size);
|
reinterpret_cast<const uint8_t*>(allocation->base()),
|
||||||
|
allocation->bytes());
|
||||||
if (!VerifyModelBuffer(base_verifier)) {
|
if (!VerifyModelBuffer(base_verifier)) {
|
||||||
TF_LITE_REPORT_ERROR(error_reporter,
|
TF_LITE_REPORT_ERROR(error_reporter,
|
||||||
"The model is not a valid Flatbuffer buffer");
|
"The model is not a valid Flatbuffer buffer");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (extra_verifier && !extra_verifier->Verify(caller_owned_buffer,
|
if (extra_verifier &&
|
||||||
buffer_size, error_reporter)) {
|
!extra_verifier->Verify(static_cast<const char*>(allocation->base()),
|
||||||
|
allocation->bytes(), error_reporter)) {
|
||||||
|
// The verifier will have already logged an appropriate error message.
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
return BuildFromBuffer(caller_owned_buffer, buffer_size, error_reporter);
|
return BuildFromAllocation(std::move(allocation), error_reporter);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
|
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
|
||||||
@ -136,9 +133,11 @@ std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
|
|||||||
ErrorReporter* error_reporter) {
|
ErrorReporter* error_reporter) {
|
||||||
error_reporter = ValidateErrorReporter(error_reporter);
|
error_reporter = ValidateErrorReporter(error_reporter);
|
||||||
|
|
||||||
std::unique_ptr<FlatBufferModel> model;
|
std::unique_ptr<FlatBufferModel> model(
|
||||||
model.reset(new FlatBufferModel(caller_owned_model_spec, error_reporter));
|
new FlatBufferModel(caller_owned_model_spec, error_reporter));
|
||||||
if (!model->initialized()) model.reset();
|
if (!model->initialized()) {
|
||||||
|
model.reset();
|
||||||
|
}
|
||||||
return model;
|
return model;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -189,7 +188,9 @@ FlatBufferModel::FlatBufferModel(std::unique_ptr<Allocation> allocation,
|
|||||||
ErrorReporter* error_reporter)
|
ErrorReporter* error_reporter)
|
||||||
: error_reporter_(ValidateErrorReporter(error_reporter)),
|
: error_reporter_(ValidateErrorReporter(error_reporter)),
|
||||||
allocation_(std::move(allocation)) {
|
allocation_(std::move(allocation)) {
|
||||||
if (!allocation_->valid() || !CheckModelIdentifier()) return;
|
if (!allocation_ || !allocation_->valid() || !CheckModelIdentifier()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
model_ = ::tflite::GetModel(allocation_->base());
|
model_ = ::tflite::GetModel(allocation_->base());
|
||||||
}
|
}
|
||||||
|
@ -110,6 +110,30 @@ class FlatBufferModel {
|
|||||||
TfLiteVerifier* extra_verifier = nullptr,
|
TfLiteVerifier* extra_verifier = nullptr,
|
||||||
ErrorReporter* error_reporter = DefaultErrorReporter());
|
ErrorReporter* error_reporter = DefaultErrorReporter());
|
||||||
|
|
||||||
|
/// Builds a model directly from an allocation.
|
||||||
|
/// Ownership of the allocation is passed to the model, but the caller
|
||||||
|
/// retains ownership of `error_reporter` and must ensure its lifetime is
|
||||||
|
/// longer than the FlatBufferModel instance.
|
||||||
|
/// Returns a nullptr in case of failure (e.g., the allocation is invalid).
|
||||||
|
static std::unique_ptr<FlatBufferModel> BuildFromAllocation(
|
||||||
|
std::unique_ptr<Allocation> allocation,
|
||||||
|
ErrorReporter* error_reporter = DefaultErrorReporter());
|
||||||
|
|
||||||
|
/// Verifies whether the content of the allocation is legit, then builds a
|
||||||
|
/// model based on the provided allocation.
|
||||||
|
/// The extra_verifier argument is an additional optional verifier for the
|
||||||
|
/// buffer. By default, we always check with tflite::VerifyModelBuffer. If
|
||||||
|
/// extra_verifier is supplied, the buffer is checked against the
|
||||||
|
/// extra_verifier after the check against tflite::VerifyModelBuilder.
|
||||||
|
/// Ownership of the allocation is passed to the model, but the caller
|
||||||
|
/// retains ownership of `error_reporter` and must ensure its lifetime is
|
||||||
|
/// longer than the FlatBufferModel instance.
|
||||||
|
/// Returns a nullptr in case of failure.
|
||||||
|
static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromAllocation(
|
||||||
|
std::unique_ptr<Allocation> allocation,
|
||||||
|
TfLiteVerifier* extra_verifier = nullptr,
|
||||||
|
ErrorReporter* error_reporter = DefaultErrorReporter());
|
||||||
|
|
||||||
/// Builds a model directly from a flatbuffer pointer
|
/// Builds a model directly from a flatbuffer pointer
|
||||||
/// Caller retains ownership of the buffer and should keep it alive until the
|
/// Caller retains ownership of the buffer and should keep it alive until the
|
||||||
/// returned object is destroyed. Caller retains ownership of `error_reporter`
|
/// returned object is destroyed. Caller retains ownership of `error_reporter`
|
||||||
|
@ -384,6 +384,43 @@ TEST(BasicFlatBufferModel, TestBuildFromModel) {
|
|||||||
ASSERT_NE(interpreter, nullptr);
|
ASSERT_NE(interpreter, nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that loading model directly from an Allocation works.
|
||||||
|
TEST(BasicFlatBufferModel, TestBuildFromAllocation) {
|
||||||
|
TestErrorReporter reporter;
|
||||||
|
std::unique_ptr<Allocation> model_allocation(new FileCopyAllocation(
|
||||||
|
"tensorflow/lite/testdata/test_model.bin", &reporter));
|
||||||
|
ASSERT_TRUE(model_allocation->valid());
|
||||||
|
|
||||||
|
auto model =
|
||||||
|
FlatBufferModel::BuildFromAllocation(std::move(model_allocation));
|
||||||
|
ASSERT_TRUE(model);
|
||||||
|
|
||||||
|
std::unique_ptr<Interpreter> interpreter;
|
||||||
|
ASSERT_EQ(
|
||||||
|
InterpreterBuilder(*model, TrivialResolver(&dummy_reg))(&interpreter),
|
||||||
|
kTfLiteOk);
|
||||||
|
ASSERT_NE(interpreter, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(BasicFlatBufferModel, TestBuildFromNullAllocation) {
|
||||||
|
TestErrorReporter reporter;
|
||||||
|
std::unique_ptr<Allocation> model_allocation;
|
||||||
|
|
||||||
|
auto model =
|
||||||
|
FlatBufferModel::BuildFromAllocation(std::move(model_allocation));
|
||||||
|
ASSERT_FALSE(model);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(BasicFlatBufferModel, TestBuildFromInvalidAllocation) {
|
||||||
|
TestErrorReporter reporter;
|
||||||
|
std::unique_ptr<Allocation> model_allocation(
|
||||||
|
new MemoryAllocation(nullptr, 0, nullptr));
|
||||||
|
|
||||||
|
auto model =
|
||||||
|
FlatBufferModel::BuildFromAllocation(std::move(model_allocation));
|
||||||
|
ASSERT_FALSE(model);
|
||||||
|
}
|
||||||
|
|
||||||
// Test reading the minimum runtime string from metadata in a Model flatbuffer.
|
// Test reading the minimum runtime string from metadata in a Model flatbuffer.
|
||||||
TEST(BasicFlatBufferModel, TestReadRuntimeVersionFromModel) {
|
TEST(BasicFlatBufferModel, TestReadRuntimeVersionFromModel) {
|
||||||
// First read a model that doesn't have the runtime string.
|
// First read a model that doesn't have the runtime string.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user