TFLM: Allow interleaving RequestScratchBuffer and AllocatePersistentBuffer in kernels.
Major changes:
- Scratch buffers are placed in the head during the prepare stage, then moved to the tail once their total length is known, before the static memory plan is committed.
- ContextHelper batches RequestScratchBuffer requests and forwards them to the allocator together, to work around a limitation with temp allocations during the Prepare stage.

PiperOrigin-RevId: 328945674
Change-Id: I09db5c1be0e225904f1c4bf3a5a4a2831a5db438
This commit is contained in:
parent 757befb73d
commit 59d177d9ac
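
For context, here is a minimal sketch of the kernel-side pattern this change enables, loosely modeled on the SimpleStatefulOp test op updated further down in this diff. The op structure, tensor index, and the omitted Init (which would allocate OpData) are illustrative; only the two TfLiteContext callbacks are the real API.

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace {

struct OpData {
  int* invoke_count = nullptr;  // Lives in a persistent (tail) buffer.
  int sorting_buffer = -1;      // Index returned by RequestScratchBufferInArena.
};

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  OpData* data = reinterpret_cast<OpData*>(node->user_data);
  const TfLiteTensor* input = tflite::GetInput(context, node, /*index=*/0);

  // Request a scratch buffer. Only an index is handed back here; the actual
  // pointer becomes valid after memory planning and is fetched in Invoke via
  // context->GetScratchBuffer.
  TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
      context, sizeof(uint8_t) * tflite::NumElements(input->dims),
      &data->sorting_buffer));

  // Immediately allocate persistent memory from the tail. Before this change,
  // a tail allocation between two scratch buffer requests was an error.
  data->invoke_count = reinterpret_cast<int*>(
      context->AllocatePersistentBuffer(context, sizeof(int)));
  *data->invoke_count = 0;
  return kTfLiteOk;
}

}  // namespace
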
@@ -337,8 +337,8 @@ TfLiteStatus AllocationInfoBuilder::AddScratchBuffers(
    current->bytes = handle->bytes;
    current->first_created = handle->node_idx;
    current->last_used = handle->node_idx;
    current->needs_allocating = true;
    current->offline_offset = kOnlinePlannedBuffer;
    current->needs_allocating = true;
  }
  return kTfLiteOk;
}
@@ -655,6 +655,7 @@ TfLiteStatus MicroAllocator::StartModelAllocation(

  model_is_allocating_ = true;

  TF_LITE_ENSURE_STATUS(InitScratchBufferHandles());
  TF_LITE_ENSURE_STATUS(AllocateTfLiteEvalTensors(model, eval_tensors));
  TF_LITE_ENSURE_STATUS(
      AllocateNodeAndRegistrations(model, node_and_registrations));
@@ -665,7 +666,8 @@ TfLiteStatus MicroAllocator::StartModelAllocation(
}

TfLiteStatus MicroAllocator::FinishModelAllocation(
    const Model* model, TfLiteEvalTensor* eval_tensors) {
    const Model* model, TfLiteEvalTensor* eval_tensors,
    void** scratch_buffer_handles) {
  if (!model_is_allocating_) {
    TF_LITE_REPORT_ERROR(error_reporter_,
                         "MicroAllocator: Model allocation finished before "
@@ -676,9 +678,13 @@ TfLiteStatus MicroAllocator::FinishModelAllocation(
  const SubGraph* subgraph = GetSubGraphFromModel(model);
  TFLITE_DCHECK(subgraph != nullptr);

  TF_LITE_ENSURE_STATUS(MoveScratchBufferHandlesToTail());
  TF_LITE_ENSURE_STATUS(CommitStaticMemoryPlan(model, subgraph, eval_tensors));
  TF_LITE_ENSURE_STATUS(AllocateVariables(subgraph, eval_tensors));

  if (scratch_buffer_handles != nullptr) {
    *scratch_buffer_handles = scratch_buffer_handles_;
  }
  model_is_allocating_ = false;
  return kTfLiteOk;
}
@@ -690,49 +696,39 @@ void* MicroAllocator::AllocatePersistentBuffer(size_t bytes) {

TfLiteStatus MicroAllocator::RequestScratchBufferInArena(int node_id,
                                                         size_t bytes,
                                                         int* buffer_idx) {
  // A consistency check to make sure scratch_buffer_handles_ is contiguous,
  // i.e. scratch_buffer_handles_ is pointing to the last allocation from the
  // memory allocator.
  if (scratch_buffer_handles_ != nullptr &&
      reinterpret_cast<uint8_t*>(scratch_buffer_handles_) !=
          memory_allocator_->GetTail()) {
    TF_LITE_REPORT_ERROR(error_reporter_,
                         "Internal error: AllocateFromTail can not be called "
                         "between two RequestScratchBufferInArena calls.");
    return kTfLiteError;
  // This method is only called during the Prepare stage, when the scratch
  // buffer handles are placed in the head.

  // Allocate space for the new scratch buffer handle.
  TF_LITE_ENSURE_STATUS(memory_allocator_->EnsureHeadSize(
      sizeof(internal::ScratchBufferHandle) * (scratch_buffer_count_ + 1),
      alignof(internal::ScratchBufferHandle)));

  if (scratch_buffer_handles_ == nullptr) {
    // If this is the first scratch buffer handle, place it in the buffer head.
    scratch_buffer_handles_ = reinterpret_cast<internal::ScratchBufferHandle*>(
        memory_allocator_->GetBufferHead());
  }

  // Initialize the handle. The `data` field will be set during memory planning.
  internal::ScratchBufferHandle* handle =
      reinterpret_cast<internal::ScratchBufferHandle*>(
          memory_allocator_->AllocateFromTail(
              sizeof(internal::ScratchBufferHandle),
              alignof(internal::ScratchBufferHandle)));
  if (handle == nullptr) {
    TF_LITE_REPORT_ERROR(error_reporter_,
                         "Failed to register scratch buffer handle for node %d",
                         node_id);
    return kTfLiteError;
  }
      scratch_buffer_handles_ + scratch_buffer_count_;
  *handle = {};
  handle->bytes = bytes;
  handle->node_idx = node_id;

  // Buffer idx starts from 0 in this implementation.
  *buffer_idx = scratch_buffer_count_;
  scratch_buffer_count_ += 1;
  // scratch_buffer_handles_ is in reverse order. The following code ensures
  // that scratch_buffers[0] is pointing to the newly allocated handle.
  scratch_buffer_handles_ = handle;
  return kTfLiteOk;
}

void* MicroAllocator::GetScratchBuffer(int buffer_idx) const {
  if (static_cast<size_t>(buffer_idx) >= scratch_buffer_count_) {
    TF_LITE_REPORT_ERROR(error_reporter_,
                         "Buffer %d not found. %d buffers available.",
                         buffer_idx, scratch_buffer_count_);
    return nullptr;
  }
  // scratch_buffer_handles_ is in reverse order.
  return scratch_buffer_handles_[scratch_buffer_count_ - buffer_idx - 1].data;
void* MicroAllocator::GetScratchBuffer(void* scratch_buffer_handles,
                                       int buffer_idx) {
  internal::ScratchBufferHandle* handle =
      reinterpret_cast<internal::ScratchBufferHandle*>(scratch_buffer_handles) +
      buffer_idx;
  return handle->data;
}

size_t MicroAllocator::used_bytes() const {
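
The indexing change in the hunk above is easier to see in isolation. Below is a standalone sketch with plain structs and no TFLM dependencies (all names here are illustrative, not part of the change): the old allocator handed out handles newest-first from the tail, so lookups were reversed; the new allocator keeps handles in request order, so a buffer index maps straight to its handle.

#include <cassert>
#include <cstddef>

namespace {

struct ScratchBufferHandle {
  void* data;
  size_t bytes;
  int node_idx;
};

// Old scheme: handles[0] was the most recent request, so lookups reversed.
void* OldLookup(const ScratchBufferHandle* handles, size_t count, int idx) {
  return handles[count - idx - 1].data;
}

// New scheme: handles are stored in request order in a contiguous array.
void* NewLookup(const ScratchBufferHandle* handles, int idx) {
  return handles[idx].data;
}

}  // namespace

int main() {
  int a = 0, b = 0, c = 0;
  // Three requests in order: buffer 0 -> &a, buffer 1 -> &b, buffer 2 -> &c.
  const ScratchBufferHandle in_order[3] = {{&a, 4, 0}, {&b, 4, 0}, {&c, 4, 1}};
  // The old allocator would have stored the same handles newest-first.
  const ScratchBufferHandle reversed[3] = {{&c, 4, 1}, {&b, 4, 0}, {&a, 4, 0}};

  for (int idx = 0; idx < 3; ++idx) {
    assert(OldLookup(reversed, 3, idx) == NewLookup(in_order, idx));
  }
  return 0;
}
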
@@ -1035,7 +1031,6 @@ TfLiteStatus MicroAllocator::CommitStaticMemoryPlan(
      builder.GetOfflinePlannedOffsets(model, &offline_planner_offsets));
  TF_LITE_ENSURE_STATUS(
      builder.AddTensors(subgraph, offline_planner_offsets, eval_tensors));

  TF_LITE_ENSURE_STATUS(builder.AddScratchBuffers(scratch_buffer_handles_));
  const AllocationInfo* allocation_info = builder.Finish();

@@ -1051,16 +1046,16 @@ TfLiteStatus MicroAllocator::CommitStaticMemoryPlan(

  size_t actual_available_arena_size =
      memory_allocator_->GetAvailableMemory(kBufferAlignment);

  // Make sure we have enough arena size.
  if (planner.GetMaximumMemorySize() > actual_available_arena_size) {
    TF_LITE_REPORT_ERROR(
        error_reporter_,
        "Arena size is too small for activation buffers. Needed %d but only "
        "%d was available.",
        "Arena size is too small for all buffers. Needed %u but only "
        "%u was available.",
        planner.GetMaximumMemorySize(), actual_available_arena_size);
    return kTfLiteError;
  }

  // Commit the plan.
  TF_LITE_ENSURE_STATUS(CommitPlan(error_reporter_, &planner,
                                   memory_allocator_->GetBufferHead(),
@@ -1073,4 +1068,27 @@ TfLiteStatus MicroAllocator::CommitStaticMemoryPlan(
  return kTfLiteOk;
}

TfLiteStatus MicroAllocator::InitScratchBufferHandles() {
  scratch_buffer_count_ = 0;
  scratch_buffer_handles_ = nullptr;
  return kTfLiteOk;
}

TfLiteStatus MicroAllocator::MoveScratchBufferHandlesToTail() {
  if (scratch_buffer_count_ == 0) {
    return kTfLiteOk;
  }
  auto src = scratch_buffer_handles_;
  internal::ScratchBufferHandle* dest =
      reinterpret_cast<internal::ScratchBufferHandle*>(
          memory_allocator_->AllocateFromTail(
              sizeof(internal::ScratchBufferHandle) * scratch_buffer_count_,
              alignof(internal::ScratchBufferHandle)));
  for (size_t i = 0; i < scratch_buffer_count_; i++) {
    *(dest + i) = *(src + i);
  }
  scratch_buffer_handles_ = dest;
  return kTfLiteOk;
}

}  // namespace tflite

@@ -123,9 +123,12 @@ class MicroAllocator {
  // the 'head' section of the memory arena. All variable tensor data will also
  // be allocated. This method should be called after assigning model resources
  // in StartModelAllocation(). The eval_tensors pointer should be the value
  // passed into this class during StartModelAllocation().
  // passed into this class during StartModelAllocation(). Scratch buffer
  // handles are stored in the out-param `scratch_buffer_handles`. This value
  // will be used in the `GetScratchBuffer` call to retrieve scratch buffers.
  TfLiteStatus FinishModelAllocation(const Model* model,
                                     TfLiteEvalTensor* eval_tensors);
                                     TfLiteEvalTensor* eval_tensors,
                                     void** scratch_buffer_handles = nullptr);

  // Allocates a TfLiteTensor struct and populates the returned value with
  // properties from the model flatbuffer. This struct is allocated from
@@ -160,12 +163,18 @@ class MicroAllocator {
  // This method only allocates a BufferHandle holding information for memory
  // planning. The buffer ptr is ready after `FinishModelAllocation` and can
  // be retrieved by the `GetScratchBuffer` method using the returned buffer_idx.
  // Note that there should be no tail allocation between two consecutive
  // `RequestScratchBufferInArena` calls.
  // Note that this method should only be called in the Prepare stage.
  TfLiteStatus RequestScratchBufferInArena(int node_id, size_t bytes,
                                           int* buffer_idx);
  // Returns the pointer to the planned scratch buffer.
  void* GetScratchBuffer(int buffer_idx) const;

  // Returns the number of scratch buffers in the allocator.
  size_t GetScratchBufferCount() const { return scratch_buffer_count_; }

  // Returns the pointer to the planned scratch buffer. `scratch_buffer_handles`
  // should be the corresponding value returned by `FinishModelAllocation`.
  // `scratch_buffer_handles` is intentionally designed as a void*. The actual
  // data type is an implementation detail, and is only visible in this class.
  static void* GetScratchBuffer(void* scratch_buffer_handles, int buffer_idx);

  // Returns the arena usage in bytes, only available after
  // `FinishModelAllocation`. Otherwise, it will return 0.
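
Taken together, the declarations above change how a caller retrieves scratch buffers. A sketch of the expected flow follows; MicroInterpreter::AllocateTensors() does the equivalent further down in this diff, and the wrapper function and its parameter names here are illustrative, not part of the change.

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/micro_allocator.h"

// Illustrative helper only: finishes allocation and resolves one scratch
// buffer index that a kernel obtained from RequestScratchBufferInArena().
TfLiteStatus FinishAndFetchScratch(tflite::MicroAllocator* allocator,
                                   const tflite::Model* model,
                                   TfLiteEvalTensor* eval_tensors,
                                   int buffer_idx, void** out_scratch) {
  // The handle list comes back as an opaque void*; its layout is private to
  // MicroAllocator.
  void* scratch_buffer_handles = nullptr;
  TF_LITE_ENSURE_STATUS(allocator->FinishModelAllocation(
      model, eval_tensors, &scratch_buffer_handles));

  // After the static memory plan is committed, the index resolves to a real
  // pointer through the static GetScratchBuffer overload.
  *out_scratch = tflite::MicroAllocator::GetScratchBuffer(
      scratch_buffer_handles, buffer_idx);
  return kTfLiteOk;
}
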
@@ -236,13 +245,16 @@ class MicroAllocator {
  ErrorReporter* error_reporter_;
  bool model_is_allocating_;

  // In reverse order for efficiency.
  // i.e. scratch_buffer_handles_[0] is the handle for the last buffer,
  // corresponding to the last RequestScratchBufferInArena call.
  // Points to the first allocated scratch buffer handle.
  // Scratch buffer handles are placed in the head during the `Prepare` stage
  // and then moved to the tail for the static memory plan.
  internal::ScratchBufferHandle* scratch_buffer_handles_ = nullptr;
  // How many scratch buffers have been allocated.
  size_t scratch_buffer_count_ = 0;

  virtual TfLiteStatus InitScratchBufferHandles();
  virtual TfLiteStatus MoveScratchBufferHandlesToTail();

  TF_LITE_REMOVE_VIRTUAL_DELETE
};

@@ -59,13 +59,31 @@ TfLiteStatus ContextHelper::RequestScratchBufferInArena(TfLiteContext* ctx,
                                                        size_t bytes,
                                                        int* buffer_idx) {
  ContextHelper* helper = reinterpret_cast<ContextHelper*>(ctx->impl_);
  return helper->allocator_->RequestScratchBufferInArena(
      helper->current_node_idx_, bytes, buffer_idx);

  // We can not forward the scratch buffer request to the allocator yet;
  // otherwise the scratch buffer handles would ruin the data in the `temp`
  // section. These requests will be processed once the `temp` section is
  // deallocated, i.e. after a node has been prepared.

  if (helper->scratch_buffer_count_ >= kMaxScratchBuffersPerOp) {
    TF_LITE_REPORT_ERROR(
        helper->error_reporter_,
        "Node %d is allocating too many scratch buffers per op, max=%d",
        helper->current_node_idx_, helper->scratch_buffer_count_);
  }
  helper->scrach_buffer_sizes_[helper->scratch_buffer_count_] = bytes;
  // buffer_idx is 0 indexed.
  *buffer_idx = helper->scratch_buffer_count_ +
                helper->allocator_->GetScratchBufferCount();
  helper->scratch_buffer_count_++;
  return kTfLiteOk;
}

void* ContextHelper::GetScratchBuffer(TfLiteContext* ctx, int buffer_idx) {
  return reinterpret_cast<ContextHelper*>(ctx->impl_)
      ->allocator_->GetScratchBuffer(buffer_idx);
  ContextHelper* helper = reinterpret_cast<ContextHelper*>(ctx->impl_);

  return helper->allocator_->GetScratchBuffer(helper->scratch_buffer_handles_,
                                              buffer_idx);
}

void ContextHelper::ReportOpError(struct TfLiteContext* context,
@@ -92,12 +110,39 @@ TfLiteEvalTensor* ContextHelper::GetEvalTensor(
  return &helper->eval_tensors_[tensor_idx];
}

void ContextHelper::SetNodeIndex(int idx) { current_node_idx_ = idx; }
void ContextHelper::SetNodeIndex(int idx) {
  if (scratch_buffer_count_ != 0) {
    TF_LITE_REPORT_ERROR(error_reporter_,
                         "Internal error: Please commit scratch buffers "
                         "before moving to the next node");
  }
  current_node_idx_ = idx;
}

void ContextHelper::SetTfLiteEvalTensors(TfLiteEvalTensor* eval_tensors) {
  eval_tensors_ = eval_tensors;
}

void ContextHelper::SetScratchBufferHandles(void* scratch_buffer_handle) {
  scratch_buffer_handles_ = scratch_buffer_handle;
}

TfLiteStatus ContextHelper::CommitScratchBuffers() {
  size_t initial_buffer_count = allocator_->GetScratchBufferCount();
  for (size_t i = 0; i < scratch_buffer_count_; i++) {
    int buffer_id;
    allocator_->RequestScratchBufferInArena(
        current_node_idx_, scrach_buffer_sizes_[i], &buffer_id);
    if (static_cast<size_t>(buffer_id) != initial_buffer_count + i) {
      TF_LITE_REPORT_ERROR(
          error_reporter_,
          "Internal error. Scratch buffers are not contiguous.\n");
    }
  }
  scratch_buffer_count_ = 0;
  return kTfLiteOk;
}

}  // namespace internal

MicroInterpreter::MicroInterpreter(const Model* model,
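
The index bookkeeping in RequestScratchBufferInArena and CommitScratchBuffers above deserves a small worked example. The following standalone sketch uses plain counters (no TFLM types; purely illustrative) to show why the provisional indices a kernel receives during Prepare remain valid once the batch is committed:

#include <cassert>
#include <cstddef>

int main() {
  // Suppose earlier nodes already registered 3 scratch buffers, i.e.
  // MicroAllocator::GetScratchBufferCount() would return 3.
  const size_t allocator_count = 3;

  // During the current node's Prepare, ContextHelper only records the sizes
  // and hands out provisional indices: pending count + allocator count.
  size_t pending = 0;
  const int first_idx = static_cast<int>(pending + allocator_count);   // 3
  ++pending;
  const int second_idx = static_cast<int>(pending + allocator_count);  // 4
  ++pending;

  // CommitScratchBuffers() later forwards the recorded sizes in order, so the
  // allocator assigns ids allocator_count, allocator_count + 1, ... and the
  // indices the kernel already stored stay correct.
  for (size_t i = 0; i < pending; ++i) {
    const int committed_id = static_cast<int>(allocator_count + i);
    assert(committed_id == first_idx + static_cast<int>(i));
  }
  assert(first_idx == 3 && second_idx == 4);
  return 0;
}
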
@@ -297,6 +342,7 @@ TfLiteStatus MicroInterpreter::AllocateTensors() {
      }
    }
    allocator_.ResetTempAllocations();
    context_helper_.CommitScratchBuffers();
  }
  context_helper_.SetNodeIndex(-1);

@@ -306,8 +352,12 @@ TfLiteStatus MicroInterpreter::AllocateTensors() {
  context_.RequestScratchBufferInArena = nullptr;
  context_.GetScratchBuffer = context_helper_.GetScratchBuffer;

  void* scratch_buffer_handles = nullptr;

  TF_LITE_ENSURE_OK(&context_,
                    allocator_.FinishModelAllocation(model_, eval_tensors_));
                    allocator_.FinishModelAllocation(model_, eval_tensors_,
                                                     &scratch_buffer_handles));
  context_helper_.SetScratchBufferHandles(scratch_buffer_handles);
  TF_LITE_ENSURE_STATUS(ResetVariableTensors());

  tensors_allocated_ = true;
@@ -32,6 +32,8 @@ namespace tflite {

namespace internal {

constexpr size_t kMaxScratchBuffersPerOp = 8;

// A helper class to encapsulate the implementation of APIs in Context.
// context->impl_ points to an instance of this class.
// Check tensorflow/lite/c/common.h for detailed descriptions.
@@ -53,19 +55,28 @@ class ContextHelper {
                                 int tensor_idx);
  static TfLiteEvalTensor* GetEvalTensor(const struct TfLiteContext* context,
                                         int tensor_idx);
  // Commits all scratch buffer allocations to the MicroAllocator.
  TfLiteStatus CommitScratchBuffers();

  // Sets the current node index to assist with scratch buffer allocations.
  void SetNodeIndex(int idx);

  // Sets the pointer to a list of TfLiteEvalTensor instances.
  void SetTfLiteEvalTensors(TfLiteEvalTensor* eval_tensors);
  // Sets the pointer to the scratch buffer handles, which is needed by
  // `GetScratchBuffer`.
  void SetScratchBufferHandles(void* scratch_buffer_handle);

 private:
  MicroAllocator* allocator_;
  ErrorReporter* error_reporter_;
  const Model* model_;
  TfLiteEvalTensor* eval_tensors_;
  MicroAllocator* allocator_ = nullptr;
  ErrorReporter* error_reporter_ = nullptr;
  const Model* model_ = nullptr;
  TfLiteEvalTensor* eval_tensors_ = nullptr;
  void* scratch_buffer_handles_ = nullptr;
  int current_node_idx_ = -1;

  size_t scrach_buffer_sizes_[kMaxScratchBuffersPerOp];
  size_t scratch_buffer_count_ = 0;
};

}  // namespace internal

@@ -220,38 +220,45 @@ TF_LITE_MICRO_TEST(TestKernelMemoryPlanning) {

  tflite::AllOpsResolver op_resolver = tflite::testing::GetOpResolver();

  constexpr size_t allocator_buffer_size = 2048;
  constexpr size_t allocator_buffer_size = 4096;
  uint8_t allocator_buffer[allocator_buffer_size];
  tflite::MicroInterpreter interpreter(model, op_resolver, allocator_buffer,
                                       allocator_buffer_size,
                                       micro_test::reporter);
  TF_LITE_MICRO_EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
  TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(1), interpreter.inputs_size());
  TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(2), interpreter.outputs_size());

  TfLiteTensor* input = interpreter.input(0);
  TF_LITE_MICRO_EXPECT_EQ(1, input->dims->size);
  TF_LITE_MICRO_EXPECT_EQ(3, input->dims->data[0]);
  input->data.uint8[0] = 2;
  input->data.uint8[1] = 3;
  input->data.uint8[2] = 1;
  tflite::RecordingMicroAllocator* allocator =
      tflite::RecordingMicroAllocator::Create(
          allocator_buffer, allocator_buffer_size, micro_test::reporter);

  uint8_t expected_median = 2;
  // Make sure kernel memory planning works in multi-tenant context.
  for (int i = 0; i < 3; i++) {
    tflite::MicroInterpreter interpreter(model, op_resolver, allocator,
                                         micro_test::reporter);
    TF_LITE_MICRO_EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
    TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(1), interpreter.inputs_size());
    TF_LITE_MICRO_EXPECT_EQ(static_cast<size_t>(2), interpreter.outputs_size());

    {
      TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.Invoke());
      TfLiteTensor* median = interpreter.output(0);
      TF_LITE_MICRO_EXPECT_EQ(expected_median, median->data.uint8[0]);
      TfLiteTensor* invoke_count = interpreter.output(1);
      TF_LITE_MICRO_EXPECT_EQ(1, invoke_count->data.i32[0]);
    }
    TfLiteTensor* input = interpreter.input(0);
    TF_LITE_MICRO_EXPECT_EQ(1, input->dims->size);
    TF_LITE_MICRO_EXPECT_EQ(3, input->dims->data[0]);
    input->data.uint8[0] = 2;
    input->data.uint8[1] = 3;
    input->data.uint8[2] = 1;

    {
      TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.Invoke());
      TfLiteTensor* median = interpreter.output(0);
      TF_LITE_MICRO_EXPECT_EQ(expected_median, median->data.uint8[0]);
      TfLiteTensor* invoke_count = interpreter.output(1);
      TF_LITE_MICRO_EXPECT_EQ(2, invoke_count->data.i32[0]);
      uint8_t expected_median = 2;

      {
        TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.Invoke());
        TfLiteTensor* median = interpreter.output(0);
        TF_LITE_MICRO_EXPECT_EQ(expected_median, median->data.uint8[0]);
        TfLiteTensor* invoke_count = interpreter.output(1);
        TF_LITE_MICRO_EXPECT_EQ(1, invoke_count->data.i32[0]);
      }

      {
        TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.Invoke());
        TfLiteTensor* median = interpreter.output(0);
        TF_LITE_MICRO_EXPECT_EQ(expected_median, median->data.uint8[0]);
        TfLiteTensor* invoke_count = interpreter.output(1);
        TF_LITE_MICRO_EXPECT_EQ(2, invoke_count->data.i32[0]);
      }
    }
  }
}

@@ -593,13 +593,18 @@ TfLiteStatus SimpleStatefulOp::Prepare(TfLiteContext* context,
  TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
      context, sizeof(uint8_t) * NumElements(input->dims),
      &data->sorting_buffer));
  // We can interleave scratch / persistent buffer allocation.
  data->invoke_count = reinterpret_cast<int*>(
      context->AllocatePersistentBuffer(context, sizeof(int)));
  *data->invoke_count = 0;

  return kTfLiteOk;
}

TfLiteStatus SimpleStatefulOp::Invoke(TfLiteContext* context,
                                      TfLiteNode* node) {
  OpData* data = reinterpret_cast<OpData*>(node->user_data);
  data->invoke_count += 1;
  *data->invoke_count += 1;

  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  const uint8_t* input_data = GetTensorData<uint8_t>(input);
@@ -626,7 +631,7 @@ TfLiteStatus SimpleStatefulOp::Invoke(TfLiteContext* context,
  int32_t* invoke_count_data = GetTensorData<int32_t>(invoke_count);

  median_data[0] = sorting_buffer[size / 2];
  invoke_count_data[0] = data->invoke_count;
  invoke_count_data[0] = *data->invoke_count;
  return kTfLiteOk;
}

@@ -49,7 +49,7 @@ class SimpleStatefulOp {
  static constexpr int kMedianTensor = 0;
  static constexpr int kInvokeCount = 1;
  struct OpData {
    int invoke_count = 0;
    int* invoke_count = nullptr;
    int sorting_buffer = kBufferNotAllocated;
  };