TFLu: Update Ethos-U related test with comments from review
This commit is contained in:
parent
82a899657a
commit
2411dcbbf7
@ -243,7 +243,10 @@ TfLiteStatus AllocationInfoBuilder::AddTensors(const SubGraph* subgraph,
|
|||||||
const int tensor_index = op->inputs()->Get(n);
|
const int tensor_index = op->inputs()->Get(n);
|
||||||
AllocationInfo* current = &info_[tensor_index];
|
AllocationInfo* current = &info_[tensor_index];
|
||||||
|
|
||||||
// In case operator input are not in subgraph inputs initialize them.
|
// TODO(b/166484865): Figure out a more general solution.
|
||||||
|
// This workaround is needed to handle situations where subgraph input !=
|
||||||
|
// operator input.
|
||||||
|
// In case operator input(s) are not in subgraph inputs initialize them.
|
||||||
if (current->first_created == 0) {
|
if (current->first_created == 0) {
|
||||||
for (size_t op_input = 0; op_input < op->inputs()->size(); ++op_input) {
|
for (size_t op_input = 0; op_input < op->inputs()->size(); ++op_input) {
|
||||||
const int op_tensor_index = op->inputs()->Get(op_input);
|
const int op_tensor_index = op->inputs()->Get(op_input);
|
||||||
|
@ -758,8 +758,8 @@ TF_LITE_MICRO_TEST(TestOperatorInputsNotInSubgraphInputs) {
|
|||||||
int num_conns = 2;
|
int num_conns = 2;
|
||||||
tflite::testing::NodeConnection node_list[2] = {
|
tflite::testing::NodeConnection node_list[2] = {
|
||||||
{
|
{
|
||||||
{t0, t1, t2}, // t0: input (actual input part of subgraph inputs as well as
|
{t0, t1, t2}, // t0: input (actual input part of subgraph inputs as
|
||||||
// operator inputs)
|
// well as operator inputs)
|
||||||
// t1: scratch1 (only in operator inputs)
|
// t1: scratch1 (only in operator inputs)
|
||||||
// t2: scratch2 (only in operator inputs)
|
// t2: scratch2 (only in operator inputs)
|
||||||
{t3} // output
|
{t3} // output
|
||||||
|
@ -107,6 +107,8 @@ class ModelBuilder {
|
|||||||
|
|
||||||
// Constructs the flatbuffer model using `builder_` and return a pointer to
|
// Constructs the flatbuffer model using `builder_` and return a pointer to
|
||||||
// it. The returned model has the same lifetime as `builder_`.
|
// it. The returned model has the same lifetime as `builder_`.
|
||||||
|
// Note the default value of 0 for num_subgraph_inputs means all tensor inputs
|
||||||
|
// are in subgraph input list.
|
||||||
const Model* BuildModel(std::initializer_list<Tensor> inputs,
|
const Model* BuildModel(std::initializer_list<Tensor> inputs,
|
||||||
std::initializer_list<Tensor> outputs,
|
std::initializer_list<Tensor> outputs,
|
||||||
size_t num_subgraph_inputs = 0);
|
size_t num_subgraph_inputs = 0);
|
||||||
@ -197,15 +199,19 @@ const Model* ModelBuilder::BuildModel(
|
|||||||
constexpr size_t subgraphs_size = 1;
|
constexpr size_t subgraphs_size = 1;
|
||||||
|
|
||||||
// Find out number of subgraph inputs.
|
// Find out number of subgraph inputs.
|
||||||
int num_tensors_in_subgraph_inputs = inputs.size(); // Default case = all inputs
|
if (num_subgraph_inputs == 0) {
|
||||||
if (num_subgraph_inputs > 0 && num_subgraph_inputs < inputs.size()) {
|
// This is the default case.
|
||||||
num_tensors_in_subgraph_inputs = num_subgraph_inputs;
|
num_subgraph_inputs = inputs.size();
|
||||||
|
} else {
|
||||||
|
// A non-zero value of num_subgraph_inputs means that some of
|
||||||
|
// the operator input tensors are not subgraph inputs.
|
||||||
|
TFLITE_DCHECK(num_subgraph_inputs < inputs.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
const flatbuffers::Offset<SubGraph> subgraphs[subgraphs_size] = {
|
const flatbuffers::Offset<SubGraph> subgraphs[subgraphs_size] = {
|
||||||
tflite::CreateSubGraph(
|
tflite::CreateSubGraph(
|
||||||
*builder_, builder_->CreateVector(tensors_, next_tensor_id_),
|
*builder_, builder_->CreateVector(tensors_, next_tensor_id_),
|
||||||
builder_->CreateVector(inputs.begin(), num_tensors_in_subgraph_inputs),
|
builder_->CreateVector(inputs.begin(), num_subgraph_inputs),
|
||||||
builder_->CreateVector(outputs.begin(), outputs.size()),
|
builder_->CreateVector(outputs.begin(), outputs.size()),
|
||||||
builder_->CreateVector(operators_, next_operator_id_),
|
builder_->CreateVector(operators_, next_operator_id_),
|
||||||
builder_->CreateString("test_subgraph"))};
|
builder_->CreateString("test_subgraph"))};
|
||||||
|
@ -88,11 +88,22 @@ const Model* GetComplexMockModel();
|
|||||||
const Model* GetSimpleModelWithBranch();
|
const Model* GetSimpleModelWithBranch();
|
||||||
|
|
||||||
// Returns a simple flatbuffer model with offline planned tensors
|
// Returns a simple flatbuffer model with offline planned tensors
|
||||||
|
// @param[in] num_tensors Number of tensors in the model.
|
||||||
|
// @param[in] metadata_buffer Metadata for offline planner.
|
||||||
|
// @param[in] node_con List of connections, i.e. operators
|
||||||
|
// in the model.
|
||||||
|
// @param[in] num_conns Number of connections.
|
||||||
|
// @param[in] num_subgraph_inputs How many of the input tensors are in
|
||||||
|
// the subgraph inputs. The default value
|
||||||
|
// of 0 means all of the input tensors
|
||||||
|
// are in the subgraph input list. There
|
||||||
|
// must be at least 1 input tensor in the
|
||||||
|
// subgraph input list.
|
||||||
const Model* GetModelWithOfflinePlanning(int num_tensors,
|
const Model* GetModelWithOfflinePlanning(int num_tensors,
|
||||||
const int32_t* metadata_buffer,
|
const int32_t* metadata_buffer,
|
||||||
NodeConnection* node_conn,
|
NodeConnection* node_conn,
|
||||||
int num_conns,
|
int num_conns,
|
||||||
int num_subgraph_inputs = 0/*0 means all*/);
|
int num_subgraph_inputs = 0);
|
||||||
|
|
||||||
// Returns a flatbuffer model with `simple_stateful_op`
|
// Returns a flatbuffer model with `simple_stateful_op`
|
||||||
const Model* GetSimpleStatefulModel();
|
const Model* GetSimpleStatefulModel();
|
||||||
|
Loading…
Reference in New Issue
Block a user