TFLM: Move forward with the original CL and fix bluepill test by adding -fno-threadsafe-statics flag.

PiperOrigin-RevId: 302606302
Change-Id: Iaff9548f5aa7bbdfc81b981b300098f5c3ed8dea
This commit is contained in:
Tiezhen WANG 2020-03-24 00:09:09 -07:00 committed by TensorFlower Gardener
parent e9db6486b3
commit be9d33754d
7 changed files with 176 additions and 21 deletions

View File

@ -71,18 +71,35 @@ TfLiteStatus CalculateOpData(TfLiteContext* context,
} // namespace
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
return nullptr;
OpData* data = nullptr;
TfLiteStatus status = context->AllocatePersistentBuffer(
context, sizeof(OpData), reinterpret_cast<void**>(&data));
if (status != kTfLiteOk || data == nullptr) {
return nullptr;
}
return data;
}
void Free(TfLiteContext* context, void* buffer) {}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
auto* params =
reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_MSG(context, input->type == filter->type,
"Hybrid models are not supported on TFLite Micro.");
TfLiteType data_type = input->type;
TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input,
filter, bias, output, data));
return kTfLiteOk;
}
@ -178,11 +195,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteType data_type = input->type;
OpData local_data_object;
OpData* data = &local_data_object;
TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input,
filter, bias, output, data));
OpData* data = reinterpret_cast<OpData*>(node->user_data);
// Checks in Prepare ensure input, output and filter types are all the same.
switch (input->type) {

View File

@ -49,7 +49,6 @@ void TestFullyConnectedFloat(
TfLiteContext context;
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
::tflite::ops::micro::AllOpsResolver resolver;
const TfLiteRegistration* registration =
resolver.FindOp(tflite::BuiltinOperator_FULLY_CONNECTED, 1);

View File

@ -58,9 +58,6 @@ CreateFlatbufferBuffers();
// Performs a simple string comparison without requiring standard C library.
int TestStrcmp(const char* a, const char* b);
// Wrapper to forward kernel errors to the interpreter's error reporter.
void ReportOpError(struct TfLiteContext* context, const char* format, ...);
void PopulateContext(TfLiteTensor* tensors, int tensors_size,
TfLiteContext* context);

View File

@ -17,6 +17,7 @@ cc_library(
deps = [
"//tensorflow/lite/c:common",
"//tensorflow/lite/core/api",
"//tensorflow/lite/kernels/internal:compatibility",
"//tensorflow/lite/micro:micro_framework",
"//tensorflow/lite/micro:micro_utils",
],

View File

@ -15,24 +15,107 @@ limitations under the License.
#include "tensorflow/lite/micro/testing/test_utils.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
namespace tflite {
namespace testing {
TfLiteStatus FakeAllocator::AllocatePersistentBuffer(size_t bytes, void** ptr) {
uint8_t* addr = memory_allocator_->AllocateFromTail(bytes, kBufferAlignment);
*ptr = addr;
return kTfLiteOk;
}
TfLiteStatus FakeAllocator::RequestScratchBufferInArena(int node_idx,
size_t bytes,
int* buffer_idx) {
if (scratch_buffers_count_ >= max_scratch_buffers_count_) {
return kTfLiteError;
}
uint8_t* ptr = memory_allocator_->AllocateFromTail(bytes, kBufferAlignment);
scratch_buffers_[scratch_buffers_count_] = ptr;
*buffer_idx = scratch_buffers_count_;
scratch_buffers_count_++;
return kTfLiteOk;
}
void FakeAllocator::Reset() {
// Get A fresh memory allocator.
memory_allocator_ = CreateInPlaceSimpleMemoryAllocator(arena_, arena_size_);
TFLITE_DCHECK_NE(memory_allocator_, nullptr);
// Allocate enough space holding pointers to the scrtach buffers.
scratch_buffers_ =
reinterpret_cast<uint8_t**>(memory_allocator_->AllocateFromTail(
sizeof(uint8_t*) * max_scratch_buffers_count_, alignof(uint8_t*)));
TFLITE_DCHECK_NE(scratch_buffers_, nullptr);
scratch_buffers_count_ = 0;
}
void* FakeAllocator::GetScratchBuffer(int buffer_idx) {
if (buffer_idx < 0 || buffer_idx >= scratch_buffers_count_) {
return nullptr;
}
return scratch_buffers_[buffer_idx];
}
TfLiteStatus FakeContextHelper::AllocatePersistentBuffer(TfLiteContext* ctx,
size_t bytes,
void** ptr) {
return reinterpret_cast<FakeContextHelper*>(ctx->impl_)
->allocator_->AllocatePersistentBuffer(bytes, ptr);
}
TfLiteStatus FakeContextHelper::RequestScratchBufferInArena(TfLiteContext* ctx,
size_t bytes,
int* buffer_idx) {
FakeContextHelper* helper = reinterpret_cast<FakeContextHelper*>(ctx->impl_);
// FakeAllocator doesn't do memory reusing so it doesn't need node_idx to
// calculate the lifetime of the scratch buffer.
int node_idx = -1;
return helper->allocator_->RequestScratchBufferInArena(node_idx, bytes,
buffer_idx);
}
void* FakeContextHelper::GetScratchBuffer(TfLiteContext* ctx, int buffer_idx) {
return reinterpret_cast<FakeContextHelper*>(ctx->impl_)
->allocator_->GetScratchBuffer(buffer_idx);
}
void FakeContextHelper::ReportOpError(struct TfLiteContext* context,
const char* format, ...) {
FakeContextHelper* helper = static_cast<FakeContextHelper*>(context->impl_);
va_list args;
va_start(args, format);
TF_LITE_REPORT_ERROR(helper->error_reporter_, format, args);
va_end(args);
}
namespace {
constexpr size_t kArenaSize = 10000;
constexpr int kMaxScratchBufferCount = 32;
uint8_t arena[kArenaSize];
} // namespace
// TODO(b/141330728): Move this method elsewhere as part clean up.
void PopulateContext(TfLiteTensor* tensors, int tensors_size,
ErrorReporter* error_reporter, TfLiteContext* context) {
// This should be a large enough arena for each test cases.
static FakeAllocator allocator(arena, kArenaSize, kMaxScratchBufferCount);
static FakeContextHelper helper(error_reporter, &allocator);
// Reset the allocator so that it's ready for another test.
allocator.Reset();
*context = {};
context->recommended_num_threads = 1;
context->tensors_size = tensors_size;
context->tensors = tensors;
context->impl_ = static_cast<void*>(error_reporter);
context->GetExecutionPlan = nullptr;
context->ResizeTensor = nullptr;
context->ReportError = ReportOpError;
context->AddTensors = nullptr;
context->GetNodeAndRegistration = nullptr;
context->ReplaceNodeSubsetsWithDelegateKernels = nullptr;
context->recommended_num_threads = 1;
context->GetExternalContext = nullptr;
context->SetExternalContext = nullptr;
context->impl_ = static_cast<void*>(&helper);
context->AllocatePersistentBuffer = helper.AllocatePersistentBuffer;
context->RequestScratchBufferInArena = helper.RequestScratchBufferInArena;
context->GetScratchBuffer = helper.GetScratchBuffer;
context->ReportError = helper.ReportOpError;
for (int i = 0; i < tensors_size; ++i) {
if (context->tensors[i].is_variable) {

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/tensor_utils.h"
#include "tensorflow/lite/micro/micro_utils.h"
#include "tensorflow/lite/micro/simple_memory_allocator.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
@ -95,7 +96,67 @@ inline int32_t F2Q32(const float value, const float scale) {
return static_cast<int>(quantized);
}
// TODO(b/141330728): Move this method elsewhere as part clean up.
// A fake version of MemoryAllocator that allocates everything from the tail
// without static memory planning or reusing.
// TODO(b/150260678): Consider splitting this into its own file and inherit from
// the same public interface as MicroAllocator.
class FakeAllocator {
public:
FakeAllocator(uint8_t* arena, size_t arena_size,
size_t max_scratch_buffers_count)
: arena_(arena),
arena_size_(arena_size),
max_scratch_buffers_count_(max_scratch_buffers_count) {
Reset();
}
TfLiteStatus AllocatePersistentBuffer(size_t bytes, void** ptr);
TfLiteStatus RequestScratchBufferInArena(int node_idx, size_t bytes,
int* buffer_idx);
void* GetScratchBuffer(int buffer_idx);
// Reset the allocator to the intial state.
void Reset();
private:
uint8_t* arena_;
size_t arena_size_;
size_t max_scratch_buffers_count_;
SimpleMemoryAllocator* memory_allocator_;
// An array of buffer pointers.
uint8_t** scratch_buffers_;
size_t scratch_buffers_count_ = 0;
static constexpr size_t kBufferAlignment = 16;
};
// A fake implementation of ContextHelper. Instead of forwarding requests to
// MicroAllocator, it calls into FakeAllocator.
// PopulateContext will point context->impl_ to an instance of this class.
// TODO(b/150260678): Consider moving this into the same file as FakeAllocator.
class FakeContextHelper {
public:
explicit FakeContextHelper(ErrorReporter* error_reporter,
FakeAllocator* allocator)
: allocator_(allocator), error_reporter_(error_reporter) {}
static TfLiteStatus AllocatePersistentBuffer(TfLiteContext* ctx, size_t bytes,
void** ptr);
static TfLiteStatus RequestScratchBufferInArena(TfLiteContext* ctx,
size_t bytes,
int* buffer_idx);
static void* GetScratchBuffer(TfLiteContext* ctx, int buffer_idx);
static void ReportOpError(struct TfLiteContext* context, const char* format,
...);
private:
FakeAllocator* allocator_;
ErrorReporter* error_reporter_;
};
void PopulateContext(TfLiteTensor* tensors, int tensors_size,
ErrorReporter* error_reporter, TfLiteContext* context);

View File

@ -38,6 +38,7 @@ ifeq ($(TARGET), bluepill)
-Wno-unused-parameter \
-Wno-write-strings \
-fno-delete-null-pointer-checks \
-fno-threadsafe-statics \
-fomit-frame-pointer \
-fpermissive \
-nostdlib \