diff --git a/tensorflow/lite/micro/kernels/fully_connected.cc b/tensorflow/lite/micro/kernels/fully_connected.cc index 64bf788f538..91df80b328c 100644 --- a/tensorflow/lite/micro/kernels/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/fully_connected.cc @@ -71,18 +71,35 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, } // namespace void* Init(TfLiteContext* context, const char* buffer, size_t length) { - return nullptr; + OpData* data = nullptr; + TfLiteStatus status = context->AllocatePersistentBuffer( + context, sizeof(OpData), reinterpret_cast<void**>(&data)); + if (status != kTfLiteOk || data == nullptr) { + return nullptr; + } + return data; } void Free(TfLiteContext* context, void* buffer) {} TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + OpData* data = reinterpret_cast<OpData*>(node->user_data); + auto* params = + reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); + const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TF_LITE_ENSURE_EQ(context, input->type, output->type); TF_LITE_ENSURE_MSG(context, input->type == filter->type, "Hybrid models are not supported on TFLite Micro."); + + TfLiteType data_type = input->type; + TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input, + filter, bias, output, data)); + return kTfLiteOk; } @@ -178,11 +195,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - TfLiteType data_type = input->type; - OpData local_data_object; - OpData* data = &local_data_object; - TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input, - filter, bias, output, data)); + OpData* data = reinterpret_cast<OpData*>(node->user_data); // 
Checks in Prepare ensure input, output and filter types are all the same. switch (input->type) { diff --git a/tensorflow/lite/micro/kernels/fully_connected_test.cc b/tensorflow/lite/micro/kernels/fully_connected_test.cc index 0859e4af591..4687ae89108 100644 --- a/tensorflow/lite/micro/kernels/fully_connected_test.cc +++ b/tensorflow/lite/micro/kernels/fully_connected_test.cc @@ -49,7 +49,6 @@ void TestFullyConnectedFloat( TfLiteContext context; PopulateContext(tensors, tensors_size, micro_test::reporter, &context); - ::tflite::ops::micro::AllOpsResolver resolver; const TfLiteRegistration* registration = resolver.FindOp(tflite::BuiltinOperator_FULLY_CONNECTED, 1); diff --git a/tensorflow/lite/micro/test_helpers.h b/tensorflow/lite/micro/test_helpers.h index 010e1f9e336..f4e7fa8dfba 100644 --- a/tensorflow/lite/micro/test_helpers.h +++ b/tensorflow/lite/micro/test_helpers.h @@ -58,9 +58,6 @@ CreateFlatbufferBuffers(); // Performs a simple string comparison without requiring standard C library. int TestStrcmp(const char* a, const char* b); -// Wrapper to forward kernel errors to the interpreter's error reporter. 
-void ReportOpError(struct TfLiteContext* context, const char* format, ...); - void PopulateContext(TfLiteTensor* tensors, int tensors_size, TfLiteContext* context); diff --git a/tensorflow/lite/micro/testing/BUILD b/tensorflow/lite/micro/testing/BUILD index 01bdffc6892..42f25f0e8b0 100644 --- a/tensorflow/lite/micro/testing/BUILD +++ b/tensorflow/lite/micro/testing/BUILD @@ -17,6 +17,7 @@ cc_library( deps = [ "//tensorflow/lite/c:common", "//tensorflow/lite/core/api", + "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro:micro_utils", ], diff --git a/tensorflow/lite/micro/testing/test_utils.cc b/tensorflow/lite/micro/testing/test_utils.cc index 9f7803fcf62..5fd0161d621 100644 --- a/tensorflow/lite/micro/testing/test_utils.cc +++ b/tensorflow/lite/micro/testing/test_utils.cc @@ -15,24 +15,107 @@ limitations under the License. #include "tensorflow/lite/micro/testing/test_utils.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" + namespace tflite { namespace testing { +TfLiteStatus FakeAllocator::AllocatePersistentBuffer(size_t bytes, void** ptr) { + uint8_t* addr = memory_allocator_->AllocateFromTail(bytes, kBufferAlignment); + *ptr = addr; + return kTfLiteOk; +} + +TfLiteStatus FakeAllocator::RequestScratchBufferInArena(int node_idx, + size_t bytes, + int* buffer_idx) { + if (scratch_buffers_count_ >= max_scratch_buffers_count_) { + return kTfLiteError; + } + uint8_t* ptr = memory_allocator_->AllocateFromTail(bytes, kBufferAlignment); + scratch_buffers_[scratch_buffers_count_] = ptr; + *buffer_idx = scratch_buffers_count_; + scratch_buffers_count_++; + return kTfLiteOk; + } — +void FakeAllocator::Reset() { + // Get a fresh memory allocator. + memory_allocator_ = CreateInPlaceSimpleMemoryAllocator(arena_, arena_size_); + TFLITE_DCHECK_NE(memory_allocator_, nullptr); + + // Allocate enough space holding pointers to the scratch buffers. 
+ scratch_buffers_ = + reinterpret_cast<uint8_t**>(memory_allocator_->AllocateFromTail( + sizeof(uint8_t*) * max_scratch_buffers_count_, alignof(uint8_t*))); + TFLITE_DCHECK_NE(scratch_buffers_, nullptr); + + scratch_buffers_count_ = 0; +} + +void* FakeAllocator::GetScratchBuffer(int buffer_idx) { + if (buffer_idx < 0 || buffer_idx >= scratch_buffers_count_) { + return nullptr; + } + return scratch_buffers_[buffer_idx]; +} + +TfLiteStatus FakeContextHelper::AllocatePersistentBuffer(TfLiteContext* ctx, + size_t bytes, + void** ptr) { + return reinterpret_cast<FakeContextHelper*>(ctx->impl_) + ->allocator_->AllocatePersistentBuffer(bytes, ptr); +} + +TfLiteStatus FakeContextHelper::RequestScratchBufferInArena(TfLiteContext* ctx, + size_t bytes, + int* buffer_idx) { + FakeContextHelper* helper = reinterpret_cast<FakeContextHelper*>(ctx->impl_); + // FakeAllocator doesn't do memory reusing so it doesn't need node_idx to + // calculate the lifetime of the scratch buffer. + int node_idx = -1; + return helper->allocator_->RequestScratchBufferInArena(node_idx, bytes, + buffer_idx); +} + +void* FakeContextHelper::GetScratchBuffer(TfLiteContext* ctx, int buffer_idx) { + return reinterpret_cast<FakeContextHelper*>(ctx->impl_) + ->allocator_->GetScratchBuffer(buffer_idx); +} + +void FakeContextHelper::ReportOpError(struct TfLiteContext* context, + const char* format, ...) { + FakeContextHelper* helper = static_cast<FakeContextHelper*>(context->impl_); + va_list args; + va_start(args, format); + TF_LITE_REPORT_ERROR(helper->error_reporter_, format, args); + va_end(args); +} + +namespace { +constexpr size_t kArenaSize = 10000; +constexpr int kMaxScratchBufferCount = 32; +uint8_t arena[kArenaSize]; +} // namespace + // TODO(b/141330728): Move this method elsewhere as part clean up. void PopulateContext(TfLiteTensor* tensors, int tensors_size, ErrorReporter* error_reporter, TfLiteContext* context) { + // This should be a large enough arena for each test case. 
+ static FakeAllocator allocator(arena, kArenaSize, kMaxScratchBufferCount); + static FakeContextHelper helper(error_reporter, &allocator); + // Reset the allocator so that it's ready for another test. + allocator.Reset(); + + *context = {}; + context->recommended_num_threads = 1; context->tensors_size = tensors_size; context->tensors = tensors; - context->impl_ = static_cast<void*>(error_reporter); - context->GetExecutionPlan = nullptr; - context->ResizeTensor = nullptr; - context->ReportError = ReportOpError; - context->AddTensors = nullptr; - context->GetNodeAndRegistration = nullptr; - context->ReplaceNodeSubsetsWithDelegateKernels = nullptr; - context->recommended_num_threads = 1; - context->GetExternalContext = nullptr; - context->SetExternalContext = nullptr; + context->impl_ = static_cast<void*>(&helper); + context->AllocatePersistentBuffer = helper.AllocatePersistentBuffer; + context->RequestScratchBufferInArena = helper.RequestScratchBufferInArena; + context->GetScratchBuffer = helper.GetScratchBuffer; + context->ReportError = helper.ReportOpError; for (int i = 0; i < tensors_size; ++i) { if (context->tensors[i].is_variable) { diff --git a/tensorflow/lite/micro/testing/test_utils.h b/tensorflow/lite/micro/testing/test_utils.h index 7aa1e9d488f..f7f5dff6bb1 100644 --- a/tensorflow/lite/micro/testing/test_utils.h +++ b/tensorflow/lite/micro/testing/test_utils.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/tensor_utils.h" #include "tensorflow/lite/micro/micro_utils.h" +#include "tensorflow/lite/micro/simple_memory_allocator.h" #include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" @@ -95,7 +96,67 @@ inline int32_t F2Q32(const float value, const float scale) { return static_cast<int32_t>(quantized); } -// TODO(b/141330728): Move this method elsewhere as part clean up. 
+// A fake version of MemoryAllocator that allocates everything from the tail +// without static memory planning or reusing. +// TODO(b/150260678): Consider splitting this into its own file and inheriting +// from the same public interface as MicroAllocator. +class FakeAllocator { + public: + FakeAllocator(uint8_t* arena, size_t arena_size, + size_t max_scratch_buffers_count) + : arena_(arena), + arena_size_(arena_size), + max_scratch_buffers_count_(max_scratch_buffers_count) { + Reset(); + } + + TfLiteStatus AllocatePersistentBuffer(size_t bytes, void** ptr); + TfLiteStatus RequestScratchBufferInArena(int node_idx, size_t bytes, + int* buffer_idx); + void* GetScratchBuffer(int buffer_idx); + + // Reset the allocator to the initial state. + void Reset(); + + private: + uint8_t* arena_; + size_t arena_size_; + size_t max_scratch_buffers_count_; + + SimpleMemoryAllocator* memory_allocator_; + // An array of buffer pointers. + uint8_t** scratch_buffers_; + size_t scratch_buffers_count_ = 0; + static constexpr size_t kBufferAlignment = 16; +}; + +// A fake implementation of ContextHelper. Instead of forwarding requests to +// MicroAllocator, it calls into FakeAllocator. +// PopulateContext will point context->impl_ to an instance of this class. +// TODO(b/150260678): Consider moving this into the same file as FakeAllocator. 
+class FakeContextHelper { + public: + explicit FakeContextHelper(ErrorReporter* error_reporter, + FakeAllocator* allocator) + : allocator_(allocator), error_reporter_(error_reporter) {} + + static TfLiteStatus AllocatePersistentBuffer(TfLiteContext* ctx, size_t bytes, + void** ptr); + + static TfLiteStatus RequestScratchBufferInArena(TfLiteContext* ctx, + size_t bytes, + int* buffer_idx); + + static void* GetScratchBuffer(TfLiteContext* ctx, int buffer_idx); + + static void ReportOpError(struct TfLiteContext* context, const char* format, + ...); + + private: + FakeAllocator* allocator_; + ErrorReporter* error_reporter_; +}; + void PopulateContext(TfLiteTensor* tensors, int tensors_size, ErrorReporter* error_reporter, TfLiteContext* context); diff --git a/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc b/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc index 878067cf083..29a49288081 100644 --- a/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc @@ -38,6 +38,7 @@ ifeq ($(TARGET), bluepill) -Wno-unused-parameter \ -Wno-write-strings \ -fno-delete-null-pointer-checks \ + -fno-threadsafe-statics \ -fomit-frame-pointer \ -fpermissive \ -nostdlib \