TFLM: Move forward with the original CL and fix bluepill test by adding -fno-threadsafe-statics flag.
PiperOrigin-RevId: 302606302 Change-Id: Iaff9548f5aa7bbdfc81b981b300098f5c3ed8dea
This commit is contained in:
parent
e9db6486b3
commit
be9d33754d
@ -71,18 +71,35 @@ TfLiteStatus CalculateOpData(TfLiteContext* context,
|
||||
} // namespace
|
||||
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
OpData* data = nullptr;
|
||||
TfLiteStatus status = context->AllocatePersistentBuffer(
|
||||
context, sizeof(OpData), reinterpret_cast<void**>(&data));
|
||||
if (status != kTfLiteOk || data == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
void Free(TfLiteContext* context, void* buffer) {}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
OpData* data = reinterpret_cast<OpData*>(node->user_data);
|
||||
auto* params =
|
||||
reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
|
||||
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, input->type, output->type);
|
||||
TF_LITE_ENSURE_MSG(context, input->type == filter->type,
|
||||
"Hybrid models are not supported on TFLite Micro.");
|
||||
|
||||
TfLiteType data_type = input->type;
|
||||
TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input,
|
||||
filter, bias, output, data));
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
@ -178,11 +195,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
TfLiteType data_type = input->type;
|
||||
OpData local_data_object;
|
||||
OpData* data = &local_data_object;
|
||||
TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input,
|
||||
filter, bias, output, data));
|
||||
OpData* data = reinterpret_cast<OpData*>(node->user_data);
|
||||
|
||||
// Checks in Prepare ensure input, output and filter types are all the same.
|
||||
switch (input->type) {
|
||||
|
@ -49,7 +49,6 @@ void TestFullyConnectedFloat(
|
||||
|
||||
TfLiteContext context;
|
||||
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
|
||||
|
||||
::tflite::ops::micro::AllOpsResolver resolver;
|
||||
const TfLiteRegistration* registration =
|
||||
resolver.FindOp(tflite::BuiltinOperator_FULLY_CONNECTED, 1);
|
||||
|
@ -58,9 +58,6 @@ CreateFlatbufferBuffers();
|
||||
// Performs a simple string comparison without requiring standard C library.
|
||||
int TestStrcmp(const char* a, const char* b);
|
||||
|
||||
// Wrapper to forward kernel errors to the interpreter's error reporter.
|
||||
void ReportOpError(struct TfLiteContext* context, const char* format, ...);
|
||||
|
||||
void PopulateContext(TfLiteTensor* tensors, int tensors_size,
|
||||
TfLiteContext* context);
|
||||
|
||||
|
@ -17,6 +17,7 @@ cc_library(
|
||||
deps = [
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/core/api",
|
||||
"//tensorflow/lite/kernels/internal:compatibility",
|
||||
"//tensorflow/lite/micro:micro_framework",
|
||||
"//tensorflow/lite/micro:micro_utils",
|
||||
],
|
||||
|
@ -15,24 +15,107 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/micro/testing/test_utils.h"
|
||||
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace testing {
|
||||
|
||||
TfLiteStatus FakeAllocator::AllocatePersistentBuffer(size_t bytes, void** ptr) {
|
||||
uint8_t* addr = memory_allocator_->AllocateFromTail(bytes, kBufferAlignment);
|
||||
*ptr = addr;
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus FakeAllocator::RequestScratchBufferInArena(int node_idx,
|
||||
size_t bytes,
|
||||
int* buffer_idx) {
|
||||
if (scratch_buffers_count_ >= max_scratch_buffers_count_) {
|
||||
return kTfLiteError;
|
||||
}
|
||||
uint8_t* ptr = memory_allocator_->AllocateFromTail(bytes, kBufferAlignment);
|
||||
scratch_buffers_[scratch_buffers_count_] = ptr;
|
||||
*buffer_idx = scratch_buffers_count_;
|
||||
scratch_buffers_count_++;
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
void FakeAllocator::Reset() {
|
||||
// Get A fresh memory allocator.
|
||||
memory_allocator_ = CreateInPlaceSimpleMemoryAllocator(arena_, arena_size_);
|
||||
TFLITE_DCHECK_NE(memory_allocator_, nullptr);
|
||||
|
||||
// Allocate enough space holding pointers to the scrtach buffers.
|
||||
scratch_buffers_ =
|
||||
reinterpret_cast<uint8_t**>(memory_allocator_->AllocateFromTail(
|
||||
sizeof(uint8_t*) * max_scratch_buffers_count_, alignof(uint8_t*)));
|
||||
TFLITE_DCHECK_NE(scratch_buffers_, nullptr);
|
||||
|
||||
scratch_buffers_count_ = 0;
|
||||
}
|
||||
|
||||
void* FakeAllocator::GetScratchBuffer(int buffer_idx) {
|
||||
if (buffer_idx < 0 || buffer_idx >= scratch_buffers_count_) {
|
||||
return nullptr;
|
||||
}
|
||||
return scratch_buffers_[buffer_idx];
|
||||
}
|
||||
|
||||
TfLiteStatus FakeContextHelper::AllocatePersistentBuffer(TfLiteContext* ctx,
|
||||
size_t bytes,
|
||||
void** ptr) {
|
||||
return reinterpret_cast<FakeContextHelper*>(ctx->impl_)
|
||||
->allocator_->AllocatePersistentBuffer(bytes, ptr);
|
||||
}
|
||||
|
||||
TfLiteStatus FakeContextHelper::RequestScratchBufferInArena(TfLiteContext* ctx,
|
||||
size_t bytes,
|
||||
int* buffer_idx) {
|
||||
FakeContextHelper* helper = reinterpret_cast<FakeContextHelper*>(ctx->impl_);
|
||||
// FakeAllocator doesn't do memory reusing so it doesn't need node_idx to
|
||||
// calculate the lifetime of the scratch buffer.
|
||||
int node_idx = -1;
|
||||
return helper->allocator_->RequestScratchBufferInArena(node_idx, bytes,
|
||||
buffer_idx);
|
||||
}
|
||||
|
||||
void* FakeContextHelper::GetScratchBuffer(TfLiteContext* ctx, int buffer_idx) {
|
||||
return reinterpret_cast<FakeContextHelper*>(ctx->impl_)
|
||||
->allocator_->GetScratchBuffer(buffer_idx);
|
||||
}
|
||||
|
||||
void FakeContextHelper::ReportOpError(struct TfLiteContext* context,
|
||||
const char* format, ...) {
|
||||
FakeContextHelper* helper = static_cast<FakeContextHelper*>(context->impl_);
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
TF_LITE_REPORT_ERROR(helper->error_reporter_, format, args);
|
||||
va_end(args);
|
||||
}
|
||||
|
||||
namespace {
|
||||
constexpr size_t kArenaSize = 10000;
|
||||
constexpr int kMaxScratchBufferCount = 32;
|
||||
uint8_t arena[kArenaSize];
|
||||
} // namespace
|
||||
|
||||
// TODO(b/141330728): Move this method elsewhere as part clean up.
|
||||
void PopulateContext(TfLiteTensor* tensors, int tensors_size,
|
||||
ErrorReporter* error_reporter, TfLiteContext* context) {
|
||||
// This should be a large enough arena for each test cases.
|
||||
static FakeAllocator allocator(arena, kArenaSize, kMaxScratchBufferCount);
|
||||
static FakeContextHelper helper(error_reporter, &allocator);
|
||||
// Reset the allocator so that it's ready for another test.
|
||||
allocator.Reset();
|
||||
|
||||
*context = {};
|
||||
context->recommended_num_threads = 1;
|
||||
context->tensors_size = tensors_size;
|
||||
context->tensors = tensors;
|
||||
context->impl_ = static_cast<void*>(error_reporter);
|
||||
context->GetExecutionPlan = nullptr;
|
||||
context->ResizeTensor = nullptr;
|
||||
context->ReportError = ReportOpError;
|
||||
context->AddTensors = nullptr;
|
||||
context->GetNodeAndRegistration = nullptr;
|
||||
context->ReplaceNodeSubsetsWithDelegateKernels = nullptr;
|
||||
context->recommended_num_threads = 1;
|
||||
context->GetExternalContext = nullptr;
|
||||
context->SetExternalContext = nullptr;
|
||||
context->impl_ = static_cast<void*>(&helper);
|
||||
context->AllocatePersistentBuffer = helper.AllocatePersistentBuffer;
|
||||
context->RequestScratchBufferInArena = helper.RequestScratchBufferInArena;
|
||||
context->GetScratchBuffer = helper.GetScratchBuffer;
|
||||
context->ReportError = helper.ReportOpError;
|
||||
|
||||
for (int i = 0; i < tensors_size; ++i) {
|
||||
if (context->tensors[i].is_variable) {
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/core/api/tensor_utils.h"
|
||||
#include "tensorflow/lite/micro/micro_utils.h"
|
||||
#include "tensorflow/lite/micro/simple_memory_allocator.h"
|
||||
#include "tensorflow/lite/micro/test_helpers.h"
|
||||
#include "tensorflow/lite/micro/testing/micro_test.h"
|
||||
|
||||
@ -95,7 +96,67 @@ inline int32_t F2Q32(const float value, const float scale) {
|
||||
return static_cast<int>(quantized);
|
||||
}
|
||||
|
||||
// TODO(b/141330728): Move this method elsewhere as part clean up.
|
||||
// A fake version of MemoryAllocator that allocates everything from the tail
|
||||
// without static memory planning or reusing.
|
||||
// TODO(b/150260678): Consider splitting this into its own file and inherit from
|
||||
// the same public interface as MicroAllocator.
|
||||
class FakeAllocator {
|
||||
public:
|
||||
FakeAllocator(uint8_t* arena, size_t arena_size,
|
||||
size_t max_scratch_buffers_count)
|
||||
: arena_(arena),
|
||||
arena_size_(arena_size),
|
||||
max_scratch_buffers_count_(max_scratch_buffers_count) {
|
||||
Reset();
|
||||
}
|
||||
|
||||
TfLiteStatus AllocatePersistentBuffer(size_t bytes, void** ptr);
|
||||
TfLiteStatus RequestScratchBufferInArena(int node_idx, size_t bytes,
|
||||
int* buffer_idx);
|
||||
void* GetScratchBuffer(int buffer_idx);
|
||||
|
||||
// Reset the allocator to the intial state.
|
||||
void Reset();
|
||||
|
||||
private:
|
||||
uint8_t* arena_;
|
||||
size_t arena_size_;
|
||||
size_t max_scratch_buffers_count_;
|
||||
|
||||
SimpleMemoryAllocator* memory_allocator_;
|
||||
// An array of buffer pointers.
|
||||
uint8_t** scratch_buffers_;
|
||||
size_t scratch_buffers_count_ = 0;
|
||||
static constexpr size_t kBufferAlignment = 16;
|
||||
};
|
||||
|
||||
// A fake implementation of ContextHelper. Instead of forwarding requests to
|
||||
// MicroAllocator, it calls into FakeAllocator.
|
||||
// PopulateContext will point context->impl_ to an instance of this class.
|
||||
// TODO(b/150260678): Consider moving this into the same file as FakeAllocator.
|
||||
class FakeContextHelper {
|
||||
public:
|
||||
explicit FakeContextHelper(ErrorReporter* error_reporter,
|
||||
FakeAllocator* allocator)
|
||||
: allocator_(allocator), error_reporter_(error_reporter) {}
|
||||
|
||||
static TfLiteStatus AllocatePersistentBuffer(TfLiteContext* ctx, size_t bytes,
|
||||
void** ptr);
|
||||
|
||||
static TfLiteStatus RequestScratchBufferInArena(TfLiteContext* ctx,
|
||||
size_t bytes,
|
||||
int* buffer_idx);
|
||||
|
||||
static void* GetScratchBuffer(TfLiteContext* ctx, int buffer_idx);
|
||||
|
||||
static void ReportOpError(struct TfLiteContext* context, const char* format,
|
||||
...);
|
||||
|
||||
private:
|
||||
FakeAllocator* allocator_;
|
||||
ErrorReporter* error_reporter_;
|
||||
};
|
||||
|
||||
void PopulateContext(TfLiteTensor* tensors, int tensors_size,
|
||||
ErrorReporter* error_reporter, TfLiteContext* context);
|
||||
|
||||
|
@ -38,6 +38,7 @@ ifeq ($(TARGET), bluepill)
|
||||
-Wno-unused-parameter \
|
||||
-Wno-write-strings \
|
||||
-fno-delete-null-pointer-checks \
|
||||
-fno-threadsafe-statics \
|
||||
-fomit-frame-pointer \
|
||||
-fpermissive \
|
||||
-nostdlib \
|
||||
|
Loading…
Reference in New Issue
Block a user