Add check for correct memory alignment to MemoryAllocation::MemoryAllocation() on 32-bit ARM

This will give a reasonable error message at model build time, rather than a SIGBUS later.

PiperOrigin-RevId: 290002381
Change-Id: I4126c4bcfdcee3c7e962a838ff4838e5c59d48f6
This commit is contained in:
Terry Heo 2020-01-15 22:18:00 -08:00 committed by TensorFlower Gardener
parent 6ecf5b0767
commit b37904edb5
2 changed files with 63 additions and 7 deletions

View File

@ -87,6 +87,24 @@ bool FileCopyAllocation::valid() const { return copied_buffer_ != nullptr; }
MemoryAllocation::MemoryAllocation(const void* ptr, size_t num_bytes, MemoryAllocation::MemoryAllocation(const void* ptr, size_t num_bytes,
ErrorReporter* error_reporter) ErrorReporter* error_reporter)
: Allocation(error_reporter, Allocation::Type::kMemory) { : Allocation(error_reporter, Allocation::Type::kMemory) {
#ifdef __arm__
if ((reinterpret_cast<uintptr_t>(ptr) & 0x3) != 0) {
// The flatbuffer schema has alignment requirements of up to 16 bytes to
// guarantee that data can be correctly accesses by various backends.
// Therefore, model pointer should also be 16-bytes aligned to preserve this
// requirement. But this condition only checks 4-bytes alignment which is
// the mininum requirement to prevent SIGBUS fault on 32bit ARM. Some models
// could require 8 or 16 bytes alignment which is not checked yet.
//
// Note that 64-bit ARM may also suffer a performance impact, but no crash -
// that case is not checked.
error_reporter->Report("The supplied buffer is not 4-bytes aligned");
buffer_ = nullptr;
buffer_size_bytes_ = 0;
return;
}
#endif // __arm__
buffer_ = ptr; buffer_ = ptr;
buffer_size_bytes_ = num_bytes; buffer_size_bytes_ = num_bytes;
} }

View File

@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/lite/model.h"
#include <fcntl.h> #include <fcntl.h>
#include <stdint.h> #include <stdint.h>
#include <stdio.h> #include <stdio.h>
@ -20,7 +22,8 @@ limitations under the License.
#include <sys/stat.h> #include <sys/stat.h>
#include <sys/types.h> #include <sys/types.h>
#include "tensorflow/lite/model.h" #include <fstream>
#include <iostream>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/error_reporter.h"
@ -72,6 +75,44 @@ TEST(BasicFlatBufferModel, TestNonExistantFiles) {
ASSERT_TRUE(!FlatBufferModel::BuildFromFile("/tmp/tflite_model_1234")); ASSERT_TRUE(!FlatBufferModel::BuildFromFile("/tmp/tflite_model_1234"));
} }
TEST(BasicFlatBufferModel, TestBufferAlignment) {
// On 32-bit ARM buffers are required to be 4-bytes aligned, on other
// platforms there is no alignment requirement.
const uintptr_t kAlignment = 4;
const uintptr_t kAlignmentBits = kAlignment - 1;
// Use real model data so that we can be sure error is only from the
// alignment requirement and not from bad data.
std::ifstream fp("tensorflow/lite/testdata/empty_model.bin");
ASSERT_TRUE(fp.good());
std::string empty_model_data((std::istreambuf_iterator<char>(fp)),
std::istreambuf_iterator<char>());
auto free_chars = [](char* p) { free(p); };
std::unique_ptr<char, decltype(free_chars)> buffer(
reinterpret_cast<char*>(malloc(empty_model_data.size() + kAlignment)),
free_chars);
// Check that aligned buffer works (no other errors in the test).
char* aligned = reinterpret_cast<char*>(
(reinterpret_cast<uintptr_t>(buffer.get()) + kAlignment) &
~kAlignmentBits);
memcpy(aligned, empty_model_data.c_str(), empty_model_data.size());
EXPECT_TRUE(
FlatBufferModel::BuildFromBuffer(aligned, empty_model_data.size()));
// Check unaligned buffer handling.
char* unaligned =
reinterpret_cast<char*>(reinterpret_cast<uintptr_t>(buffer.get()) | 0x1);
memcpy(unaligned, empty_model_data.c_str(), empty_model_data.size());
#ifdef __arm__
EXPECT_FALSE(
FlatBufferModel::BuildFromBuffer(unaligned, empty_model_data.size()));
#else // !__arm__
EXPECT_TRUE(
FlatBufferModel::BuildFromBuffer(unaligned, empty_model_data.size()));
#endif // __arm__
}
// Make sure a model with nothing in it loads properly. // Make sure a model with nothing in it loads properly.
TEST(BasicFlatBufferModel, TestEmptyModelsAndNullDestination) { TEST(BasicFlatBufferModel, TestEmptyModelsAndNullDestination) {
auto model = FlatBufferModel::BuildFromFile( auto model = FlatBufferModel::BuildFromFile(
@ -248,15 +289,13 @@ class FakeVerifier : public tflite::TfLiteVerifier {
TEST(BasicFlatBufferModel, TestWithTrueVerifier) { TEST(BasicFlatBufferModel, TestWithTrueVerifier) {
FakeVerifier verifier(true); FakeVerifier verifier(true);
ASSERT_TRUE(FlatBufferModel::VerifyAndBuildFromFile( ASSERT_TRUE(FlatBufferModel::VerifyAndBuildFromFile(
"tensorflow/lite/testdata/test_model.bin", "tensorflow/lite/testdata/test_model.bin", &verifier));
&verifier));
} }
TEST(BasicFlatBufferModel, TestWithFalseVerifier) { TEST(BasicFlatBufferModel, TestWithFalseVerifier) {
FakeVerifier verifier(false); FakeVerifier verifier(false);
ASSERT_FALSE(FlatBufferModel::VerifyAndBuildFromFile( ASSERT_FALSE(FlatBufferModel::VerifyAndBuildFromFile(
"tensorflow/lite/testdata/test_model.bin", "tensorflow/lite/testdata/test_model.bin", &verifier));
&verifier));
} }
TEST(BasicFlatBufferModel, TestWithNullVerifier) { TEST(BasicFlatBufferModel, TestWithNullVerifier) {
@ -269,8 +308,7 @@ TEST(BasicFlatBufferModel, TestWithNullVerifier) {
TEST(BasicFlatBufferModel, TestCustomErrorReporter) { TEST(BasicFlatBufferModel, TestCustomErrorReporter) {
TestErrorReporter reporter; TestErrorReporter reporter;
auto model = FlatBufferModel::BuildFromFile( auto model = FlatBufferModel::BuildFromFile(
"tensorflow/lite/testdata/empty_model.bin", "tensorflow/lite/testdata/empty_model.bin", &reporter);
&reporter);
ASSERT_TRUE(model); ASSERT_TRUE(model);
std::unique_ptr<Interpreter> interpreter; std::unique_ptr<Interpreter> interpreter;