Fix review comments, 24/4

This commit is contained in:
Fredrik Knutsson 2020-04-24 15:19:39 +02:00 committed by Jens Elofsson
parent 4703dd2ae4
commit e9b965a29c
5 changed files with 54 additions and 40 deletions

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/lite/micro/micro_error_reporter.h" #include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h" #include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h" #include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/micro_optional_debug_tools.h"
#include "tensorflow/lite/micro/testing/micro_test.h" #include "tensorflow/lite/micro/testing/micro_test.h"
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/version.h" #include "tensorflow/lite/version.h"
@ -46,6 +47,7 @@ TF_LITE_MICRO_TEST(TestInvoke) {
"to supported version %d.\n", "to supported version %d.\n",
model->version(), TFLITE_SCHEMA_VERSION); model->version(), TFLITE_SCHEMA_VERSION);
} }
PrintModelData(model, error_reporter);
// Pull in only the operation implementations we need. // Pull in only the operation implementations we need.
// This relies on a complete list of all the ops needed by this graph. // This relies on a complete list of all the ops needed by this graph.

View File

@ -144,8 +144,10 @@ class GreedyMemoryPlanner : public MemoryPlanner {
int* buffer_sizes_sorted_; int* buffer_sizes_sorted_;
int* buffer_ids_sorted_; int* buffer_ids_sorted_;
ListEntry* buffers_sorted_by_offset_; ListEntry* buffers_sorted_by_offset_;
int next_free_entry_; int next_free_entry_; // Index of the next free entry of
int first_entry_index_; // buffers_sorted_by_offset_
int first_entry_index_; // Index of the first entry (smallest offset) of
// buffers_sorted_by_offset_
// Stores the outcome of the plan, the location of each buffer in the arena. // Stores the outcome of the plan, the location of each buffer in the arena.
int* buffer_offsets_; int* buffer_offsets_;

View File

@ -45,6 +45,8 @@ struct AllocationInfo {
// requirement for SIMD extensions. // requirement for SIMD extensions.
constexpr int kBufferAlignment = 16; constexpr int kBufferAlignment = 16;
constexpr char kOfflineMemAllocMetadata[] = "OfflineMemoryAllocation";
class MicroBuiltinDataAllocator : public BuiltinDataAllocator { class MicroBuiltinDataAllocator : public BuiltinDataAllocator {
public: public:
explicit MicroBuiltinDataAllocator(SimpleMemoryAllocator* memory_allocator) explicit MicroBuiltinDataAllocator(SimpleMemoryAllocator* memory_allocator)
@ -81,33 +83,6 @@ TfLiteStatus AllocateVariables(
return kTfLiteOk; return kTfLiteOk;
} }
// Helper function to print model flatbuffer data. This function is not called
// by default. Hence it's not linked in to the final binary code.
void PrintModelData(const Model* model, ErrorReporter* error_reporter) {
auto* subgraphs = model->subgraphs();
const SubGraph* subgraph = (*subgraphs)[0];
const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors =
subgraph->tensors();
const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers =
model->buffers();
TF_LITE_REPORT_ERROR(error_reporter, "==== Model info: =====");
for (int i = 0; i < tensors->size(); ++i) {
const tflite::Tensor& flatbuffer_tensor = *tensors->Get(i);
auto* quantization = flatbuffer_tensor.quantization();
size_t type_size, tensor_size;
auto* buffer = (*buffers)[flatbuffer_tensor.buffer()];
auto* array = buffer->data();
int array_size = 0;
if (array) {
array_size = array->size();
}
BytesRequiredForTensor(flatbuffer_tensor, &tensor_size, &type_size,
error_reporter);
TF_LITE_REPORT_ERROR(
error_reporter, "Tensor index: %d arena tensor %d size %d", i,
!array_size && !flatbuffer_tensor.is_variable(), tensor_size);
}
}
// Helper function to check flatbuffer metadata correctness. This function is // Helper function to check flatbuffer metadata correctness. This function is
// not called by default. Hence it's not linked in to the final binary code. // not called by default. Hence it's not linked in to the final binary code.
@ -116,8 +91,8 @@ TfLiteStatus CheckOfflinePlannedOffsets(const Model* model,
if (model->metadata()) { if (model->metadata()) {
for (int i = 0; i < model->metadata()->size(); ++i) { for (int i = 0; i < model->metadata()->size(); ++i) {
auto metadata = model->metadata()->Get(i); auto metadata = model->metadata()->Get(i);
if (strncmp(metadata->name()->c_str(), "OfflineMemoryAllocation", if (strncmp(metadata->name()->c_str(), kOfflineMemAllocMetadata,
strlen("OfflineMemoryAllocation")) == 0) { strlen(kOfflineMemAllocMetadata)) == 0) {
auto* subgraphs = model->subgraphs(); auto* subgraphs = model->subgraphs();
const SubGraph* subgraph = (*subgraphs)[0]; const SubGraph* subgraph = (*subgraphs)[0];
const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors = const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors =
@ -311,8 +286,9 @@ TfLiteStatus AllocationInfoBuilder::AddTensors(const SubGraph* subgraph,
// | name:string | “OfflineMemoryAllocation” | // | name:string | “OfflineMemoryAllocation” |
// | buffer:unit | Index of buffer containing memory allocation data | // | buffer:unit | Index of buffer containing memory allocation data |
// //
// The buffer contents for the memory allocation is a list of 32-bit integers of // The buffer contents for the memory allocation is a list of 32-bit integers.
// the following format: // The number of tensors, n, must be equal to the number of tensors defined in
// the model. The following encoding applies:
// //
// | Offset | Value | // | Offset | Value |
// | 0 | Offline allocation format version set to 0 | // | 0 | Offline allocation format version set to 0 |
@ -326,8 +302,8 @@ TfLiteStatus AllocationInfoBuilder::GetOfflinePlannedOffsets(
if (model->metadata()) { if (model->metadata()) {
for (int i = 0; i < model->metadata()->size(); ++i) { for (int i = 0; i < model->metadata()->size(); ++i) {
auto metadata = model->metadata()->Get(i); auto metadata = model->metadata()->Get(i);
if (strncmp(metadata->name()->c_str(), "OfflineMemoryAllocation", if (strncmp(metadata->name()->c_str(), kOfflineMemAllocMetadata,
strlen("OfflineMemoryAllocation")) == 0) { strlen(kOfflineMemAllocMetadata)) == 0) {
const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers = const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers =
model->buffers(); model->buffers();
auto* buffer = (*buffers)[metadata->buffer()]; auto* buffer = (*buffers)[metadata->buffer()];
@ -365,7 +341,8 @@ TfLiteStatus AllocationInfoBuilder::AddScratchBuffers(
return kTfLiteOk; return kTfLiteOk;
} }
TfLiteStatus CreatePlan(ErrorReporter* error_reporter, MemoryPlanner* planner, TfLiteStatus CreatePlan(ErrorReporter* error_reporter,
GreedyMemoryPlanner* planner,
const AllocationInfo* allocation_info, const AllocationInfo* allocation_info,
size_t allocation_info_size) { size_t allocation_info_size) {
// Add the tensors to our allocation plan. // Add the tensors to our allocation plan.
@ -380,10 +357,9 @@ TfLiteStatus CreatePlan(ErrorReporter* error_reporter, MemoryPlanner* planner,
current->first_created, current->last_used)); current->first_created, current->last_used));
} else { } else {
TF_LITE_ENSURE_STATUS( TF_LITE_ENSURE_STATUS(
(static_cast<GreedyMemoryPlanner*>(planner)) planner->AddBuffer(error_reporter, aligned_bytes_required,
->AddBuffer(error_reporter, aligned_bytes_required, current->first_created, current->last_used,
current->first_created, current->last_used, current->offline_offset));
current->offline_offset));
} }
} }
} }

View File

@ -22,6 +22,8 @@ limitations under the License.
#include <cinttypes> #include <cinttypes>
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/micro/memory_helpers.h"
namespace tflite { namespace tflite {
namespace { namespace {
@ -100,6 +102,35 @@ const char* AllocTypeName(TfLiteAllocationType type) {
} }
} // namespace } // namespace
// Helper function to print model flatbuffer data. This function is not called
// by default. Hence it's not linked in to the final binary code.
void PrintModelData(const Model* model, ErrorReporter* error_reporter) {
auto* subgraphs = model->subgraphs();
const SubGraph* subgraph = (*subgraphs)[0];
const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors =
subgraph->tensors();
const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers =
model->buffers();
TF_LITE_REPORT_ERROR(error_reporter, "==== Model info: =====");
for (int i = 0; i < tensors->size(); ++i) {
const tflite::Tensor& flatbuffer_tensor = *tensors->Get(i);
auto* quantization = flatbuffer_tensor.quantization();
size_t type_size, tensor_size;
auto* buffer = (*buffers)[flatbuffer_tensor.buffer()];
auto* array = buffer->data();
int array_size = 0;
if (array) {
array_size = array->size();
}
BytesRequiredForTensor(flatbuffer_tensor, &tensor_size, &type_size,
error_reporter);
TF_LITE_REPORT_ERROR(
error_reporter,
"Tensor index: %d arena tensor %d size %d ",
i, !array_size && !flatbuffer_tensor.is_variable(), tensor_size);
}
}
// Prints a dump of what tensors and what nodes are in the interpreter. // Prints a dump of what tensors and what nodes are in the interpreter.
void PrintInterpreterState(MicroInterpreter* interpreter) { void PrintInterpreterState(MicroInterpreter* interpreter) {
printf("Interpreter has %zu tensors and %zu nodes\n", printf("Interpreter has %zu tensors and %zu nodes\n",

View File

@ -20,6 +20,9 @@ limitations under the License.
#include "tensorflow/lite/micro/micro_interpreter.h" #include "tensorflow/lite/micro/micro_interpreter.h"
namespace tflite { namespace tflite {
// Helper function to print model flatbuffer data. This function is not called
// by default. Hence it's not linked in to the final binary code.
void PrintModelData(const Model* model, ErrorReporter* error_reporter);
// Prints a dump of what tensors and what nodes are in the interpreter. // Prints a dump of what tensors and what nodes are in the interpreter.
void PrintInterpreterState(MicroInterpreter* interpreter); void PrintInterpreterState(MicroInterpreter* interpreter);
} // namespace tflite } // namespace tflite