TFLu: Support operator-only input tensors.

Ethos-U relies on being able to differentiate between operator input tensors
that are also subgraph inputs and those that are not.
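
For context, the sketch below is illustrative only and is not code from TFLM or the Ethos-U driver; the Lifetime struct, the tensor numbering, and op_index are invented for illustration. It shows the lifetime bookkeeping that this commit adds for such tensors: an input tensor that never appears in the subgraph input list still gets a creation point at the operator that consumes it, so the memory planner reserves space for it.

// Illustrative sketch only (not the TFLM implementation): lifetime
// bookkeeping for tensors that appear only as operator inputs.
#include <cstdio>
#include <vector>

struct Lifetime {
  int first_created = -1;  // operator index at which the tensor must exist
  int last_used = -1;      // last operator index that reads the tensor
};

int main() {
  // t0 is a subgraph input; t1 and t2 are operator-only inputs (for example
  // scratch tensors); t3 is the operator's output.
  std::vector<Lifetime> info(4);
  info[0].first_created = 0;  // subgraph inputs exist before any operator runs

  const int op_index = 0;             // single operator in this sketch
  const int op_inputs[] = {0, 1, 2};  // operator input tensor indices
  const int op_output = 3;

  for (int t : op_inputs) {
    if (info[t].first_created == -1) {
      // Operator-only input: give it a creation point at this operator so
      // the memory planner reserves space for it.
      info[t].first_created = op_index;
    }
    if (info[t].last_used < op_index) {
      info[t].last_used = op_index;
    }
  }
  info[op_output].first_created = op_index;

  for (size_t t = 0; t < info.size(); ++t) {
    std::printf("t%zu: first_created=%d last_used=%d\n", t,
                info[t].first_created, info[t].last_used);
  }
  return 0;
}

In the real allocator the same idea is applied inside AllocationInfoBuilder::AddTensors, as the first diff hunk below shows.
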
commit 82a899657a (parent 2303ed4bdb)
Måns Nilsson, 2020-08-12 12:15:28 +02:00
4 changed files with 93 additions and 8 deletions

@@ -242,6 +242,18 @@ TfLiteStatus AllocationInfoBuilder::AddTensors(const SubGraph* subgraph,
    for (size_t n = 0; n < op->inputs()->size(); ++n) {
      const int tensor_index = op->inputs()->Get(n);
      AllocationInfo* current = &info_[tensor_index];

      // In case operator inputs are not in the subgraph inputs, initialize them.
      if (current->first_created == 0) {
        for (size_t op_input = 0; op_input < op->inputs()->size(); ++op_input) {
          const int op_tensor_index = op->inputs()->Get(op_input);
          AllocationInfo* op_current = &info_[op_tensor_index];
          if (op_current->needs_allocating && op_current->first_created == -1) {
            op_current->first_created = i;
          }
        }
      }

      if (((current->last_used == -1) || (current->last_used < i))) {
        current->last_used = i;
      }

@@ -735,4 +735,64 @@ TF_LITE_MICRO_TEST(TestAllocateTfLiteTensorWithReset) {
  TF_LITE_MICRO_EXPECT(tensor2 == tensor1);
}

TF_LITE_MICRO_TEST(TestOperatorInputsNotInSubgraphInputs) {
  constexpr int nbr_tensors = 5;
  tflite::AllOpsResolver op_resolver = tflite::testing::GetOpResolver();
  tflite::NodeAndRegistration* node_and_registration;

  const int32_t metadata_buffer[tflite::testing::kOfflinePlannerHeaderSize +
                                nbr_tensors] = {
      1, 0, nbr_tensors,  // header: version, subgraph, nbr tensors
      // memory offsets:
      0,    // t0
      0,    // t1
      0,    // t2
      48,   // t3
      -1};  // t4

  int t0 = 0;
  int t1 = 1;
  int t2 = 2;
  int t3 = 3;
  int t4 = 4;

  int num_conns = 2;
  tflite::testing::NodeConnection node_list[2] = {
      {
          {t0, t1, t2},  // t0: input (part of the subgraph inputs as well as
                         //     the operator inputs)
                         // t1: scratch1 (only in operator inputs)
                         // t2: scratch2 (only in operator inputs)
          {t3}           // output
      },
      {
          {t3},  // input
          {t4}   // output
      },
  };

  const tflite::Model* model = tflite::testing::GetModelWithOfflinePlanning(
      nbr_tensors, metadata_buffer, node_list, num_conns,
      1 /* only the first tensor (t0) is in the subgraph input list */);

  TfLiteEvalTensor* eval_tensors = nullptr;
  constexpr size_t arena_size = 4096;
  uint8_t arena[arena_size];
  tflite::MicroAllocator* allocator =
      tflite::MicroAllocator::Create(arena, arena_size, micro_test::reporter);

  TF_LITE_MICRO_EXPECT_EQ(
      kTfLiteOk,
      allocator->StartModelAllocation(model, op_resolver,
                                      &node_and_registration, &eval_tensors));
  TF_LITE_MICRO_EXPECT_EQ(
      kTfLiteOk, allocator->FinishModelAllocation(model, eval_tensors));

  uint8_t* start = eval_tensors[0].data.uint8;
  TF_LITE_MICRO_EXPECT_EQ(0, eval_tensors[0].data.uint8 - start);
  TF_LITE_MICRO_EXPECT_EQ(0, eval_tensors[1].data.uint8 - start);
  TF_LITE_MICRO_EXPECT_EQ(0, eval_tensors[2].data.uint8 - start);
  TF_LITE_MICRO_EXPECT_EQ(48, eval_tensors[3].data.uint8 - start);
  TF_LITE_MICRO_EXPECT_EQ(0, eval_tensors[4].data.uint8 - start);
}

TF_LITE_MICRO_TESTS_END

@@ -108,7 +108,8 @@ class ModelBuilder {
  // Constructs the flatbuffer model using `builder_` and returns a pointer to
  // it. The returned model has the same lifetime as `builder_`.
  const Model* BuildModel(std::initializer_list<Tensor> inputs,
-                          std::initializer_list<Tensor> outputs);
+                          std::initializer_list<Tensor> outputs,
+                          size_t num_subgraph_inputs = 0);

 private:
  // Adds a tensor to the model.
@@ -179,7 +180,8 @@ void ModelBuilder::AddMetadata(const char* description_string,
const Model* ModelBuilder::BuildModel(
    std::initializer_list<ModelBuilder::Tensor> inputs,
-    std::initializer_list<ModelBuilder::Tensor> outputs) {
+    std::initializer_list<ModelBuilder::Tensor> outputs,
+    size_t num_subgraph_inputs) {
  // Model schema requires an empty buffer at idx 0.
  size_t buffer_size = 1 + ModelBuilder::nbr_of_metadata_buffers_;
  flatbuffers::Offset<Buffer> buffers[kMaxMetadataBuffers];
@@ -193,10 +195,17 @@ const Model* ModelBuilder::BuildModel(
  // TFLM only supports single subgraph.
  constexpr size_t subgraphs_size = 1;

+  // Find out number of subgraph inputs.
+  int num_tensors_in_subgraph_inputs = inputs.size();  // Default case = all inputs
+  if (num_subgraph_inputs > 0 && num_subgraph_inputs < inputs.size()) {
+    num_tensors_in_subgraph_inputs = num_subgraph_inputs;
+  }
+
  const flatbuffers::Offset<SubGraph> subgraphs[subgraphs_size] = {
      tflite::CreateSubGraph(
          *builder_, builder_->CreateVector(tensors_, next_tensor_id_),
-          builder_->CreateVector(inputs.begin(), inputs.size()),
+          builder_->CreateVector(inputs.begin(), num_tensors_in_subgraph_inputs),
          builder_->CreateVector(outputs.begin(), outputs.size()),
          builder_->CreateVector(operators_, next_operator_id_),
          builder_->CreateString("test_subgraph"))};
@@ -301,7 +310,8 @@ const Model* BuildSimpleModelWithBranch() {
const Model* BuildModelWithOfflinePlanning(int number_of_tensors,
                                           const int32_t* metadata_buffer,
                                           NodeConnection* node_conn,
-                                           int num_conns) {
+                                           int num_conns,
+                                           int num_subgraph_inputs) {
  using flatbuffers::Offset;
  flatbuffers::FlatBufferBuilder* fb_builder = BuilderInstance();
@@ -324,7 +334,8 @@ const Model* BuildModelWithOfflinePlanning(int number_of_tensors,
      number_of_tensors + tflite::testing::kOfflinePlannerHeaderSize);

  return model_builder.BuildModel(node_conn[0].input,
-                                  node_conn[num_conns - 1].output);
+                                  node_conn[num_conns - 1].output,
+                                  num_subgraph_inputs);
}

const Model* BuildSimpleMockModel() {
@@ -710,9 +721,10 @@ const Model* GetSimpleModelWithBranch() {
const Model* GetModelWithOfflinePlanning(int num_tensors,
                                         const int32_t* metadata_buffer,
                                         NodeConnection* node_conn,
-                                         int num_conns) {
+                                         int num_conns,
+                                         int num_subgraph_inputs) {
  const Model* model = BuildModelWithOfflinePlanning(
-      num_tensors, metadata_buffer, node_conn, num_conns);
+      num_tensors, metadata_buffer, node_conn, num_conns, num_subgraph_inputs);
  return model;
}

@@ -91,7 +91,8 @@ const Model* GetSimpleModelWithBranch();
const Model* GetModelWithOfflinePlanning(int num_tensors,
                                         const int32_t* metadata_buffer,
                                         NodeConnection* node_conn,
-                                         int num_conns);
+                                         int num_conns,
+                                         int num_subgraph_inputs = 0 /* 0 means all */);
// Returns a flatbuffer model with `simple_stateful_op`
const Model* GetSimpleStatefulModel();