Use internal graph indices as delegate indices.
PiperOrigin-RevId: 308433121 Change-Id: Ifd96ddfaeeea6be25cc4ef40589835f24554a907
This commit is contained in:
parent
7f37206771
commit
d17488bd97
@ -121,7 +121,10 @@ class DelegateKernel {
|
||||
// Extract TFLite delegate execution plan from the context and convert it
|
||||
// into GraphFloat32.
|
||||
GraphFloat32 graph;
|
||||
RETURN_IF_ERROR(InitializeGraph(context, delegate_params, &graph));
|
||||
std::vector<uint32_t> input_refs;
|
||||
std::vector<uint32_t> output_refs;
|
||||
RETURN_IF_ERROR(InitializeGraph(context, delegate_params, &graph,
|
||||
&input_refs, &output_refs));
|
||||
|
||||
std::unique_ptr<InferenceBuilder> builder;
|
||||
bool graph_is_destroyed;
|
||||
@ -142,7 +145,8 @@ class DelegateKernel {
|
||||
// Graph needs to be re-created because it is moved above.
|
||||
GraphFloat32 graph2;
|
||||
if (graph_is_destroyed) {
|
||||
RETURN_IF_ERROR(InitializeGraph(context, delegate_params, &graph2));
|
||||
RETURN_IF_ERROR(InitializeGraph(context, delegate_params, &graph2,
|
||||
&input_refs, &output_refs));
|
||||
}
|
||||
RETURN_IF_ERROR(InitializeOpenGlApi(
|
||||
graph_is_destroyed ? &graph2 : &graph, &builder));
|
||||
@ -151,17 +155,19 @@ class DelegateKernel {
|
||||
|
||||
// At this point tflite didn't allocate tensors yet, therefore, collect
|
||||
// indices and set all input and output tensors from tflite later.
|
||||
input_indices_.resize(graph.inputs().size());
|
||||
for (int i = 0; i < input_indices_.size(); ++i) {
|
||||
const int64_t tflite_tensor_id = graph.inputs()[i]->tensor.ref;
|
||||
input_indices_.push_back(tflite_tensor_id);
|
||||
RETURN_IF_ERROR(builder->SetInputObjectDef(i, GetObjectDef()));
|
||||
input_indices_.reserve(input_refs.size());
|
||||
for (uint32_t tensor_index : input_refs) {
|
||||
const int64_t object_index = input_indices_.size();
|
||||
input_indices_.push_back(tensor_index);
|
||||
RETURN_IF_ERROR(
|
||||
builder->SetInputObjectDef(object_index, GetObjectDef(tensor_index)));
|
||||
}
|
||||
output_indices_.resize(graph.outputs().size());
|
||||
for (int i = 0; i < output_indices_.size(); ++i) {
|
||||
const int64_t tflite_tensor_id = graph.outputs()[i]->tensor.ref;
|
||||
output_indices_.push_back(tflite_tensor_id);
|
||||
RETURN_IF_ERROR(builder->SetOutputObjectDef(i, GetObjectDef()));
|
||||
output_indices_.reserve(output_refs.size());
|
||||
for (uint32_t tensor_index : output_refs) {
|
||||
const int64_t object_index = output_indices_.size();
|
||||
output_indices_.push_back(tensor_index);
|
||||
RETURN_IF_ERROR(builder->SetOutputObjectDef(object_index,
|
||||
GetObjectDef(tensor_index)));
|
||||
}
|
||||
|
||||
return builder->Build(&runner_);
|
||||
@ -227,7 +233,7 @@ class DelegateKernel {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
ObjectDef GetObjectDef() const {
|
||||
ObjectDef GetObjectDef(int index) const {
|
||||
ObjectDef default_object_def;
|
||||
default_object_def.data_type = DataType::FLOAT32;
|
||||
default_object_def.data_layout = DataLayout::BHWC;
|
||||
@ -244,7 +250,9 @@ class DelegateKernel {
|
||||
private:
|
||||
absl::Status InitializeGraph(TfLiteContext* context,
|
||||
const TfLiteDelegateParams* delegate_params,
|
||||
GraphFloat32* graph) {
|
||||
GraphFloat32* graph,
|
||||
std::vector<uint32_t>* input_refs,
|
||||
std::vector<uint32_t>* output_refs) {
|
||||
quant_conversion_map_.clear();
|
||||
if (delegate_->IsQuantOpsAllowed()) {
|
||||
RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, graph,
|
||||
@ -252,6 +260,20 @@ class DelegateKernel {
|
||||
} else {
|
||||
RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, graph));
|
||||
}
|
||||
|
||||
input_refs->clear();
|
||||
output_refs->clear();
|
||||
const auto& inputs = graph->inputs();
|
||||
input_refs->reserve(inputs.size());
|
||||
for (const auto& input : inputs) {
|
||||
input_refs->push_back(input->tensor.ref);
|
||||
}
|
||||
const auto& outputs = graph->outputs();
|
||||
output_refs->reserve(outputs.size());
|
||||
for (const auto& output : outputs) {
|
||||
output_refs->push_back(output->tensor.ref);
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user