Use internal graph indices as delegate indices.

PiperOrigin-RevId: 308433121
Change-Id: Ifd96ddfaeeea6be25cc4ef40589835f24554a907
Author: A. Unique TensorFlower
Date: 2020-04-25 12:46:05 -07:00
Committed by: TensorFlower Gardener
Parent: 7f37206771
Commit: d17488bd97

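In outline, the change makes InitializeGraph report the TFLite tensor references (tensor.ref) of the graph's inputs and outputs, and Prepare then uses each reference's position in input_indices_ / output_indices_ as the object index it registers with the InferenceBuilder. A minimal sketch of that mapping as a free-standing helper; MapRefsToObjectIndices is an illustrative stand-in, not part of the change:

#include <cstdint>
#include <vector>

// Sketch only: the position in the returned vector is the delegate-side
// object index handed to the inference builder; the stored value is the
// TFLite tensor id (tensor.ref) collected from the graph inputs or outputs.
std::vector<int64_t> MapRefsToObjectIndices(
    const std::vector<uint32_t>& tensor_refs) {
  std::vector<int64_t> indices;
  indices.reserve(tensor_refs.size());
  for (uint32_t tensor_ref : tensor_refs) {
    indices.push_back(tensor_ref);  // indices[object_index] == TFLite tensor id
  }
  return indices;
}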

@@ -121,7 +121,10 @@ class DelegateKernel {
     // Extract TFLite delegate execution plan from the context and convert it
     // into GraphFloat32.
     GraphFloat32 graph;
-    RETURN_IF_ERROR(InitializeGraph(context, delegate_params, &graph));
+    std::vector<uint32_t> input_refs;
+    std::vector<uint32_t> output_refs;
+    RETURN_IF_ERROR(InitializeGraph(context, delegate_params, &graph,
+                                    &input_refs, &output_refs));
     std::unique_ptr<InferenceBuilder> builder;
     bool graph_is_destroyed;
@@ -142,7 +145,8 @@ class DelegateKernel {
     // Graph needs to be re-created because it is moved above.
     GraphFloat32 graph2;
     if (graph_is_destroyed) {
-      RETURN_IF_ERROR(InitializeGraph(context, delegate_params, &graph2));
+      RETURN_IF_ERROR(InitializeGraph(context, delegate_params, &graph2,
+                                      &input_refs, &output_refs));
     }
     RETURN_IF_ERROR(InitializeOpenGlApi(
         graph_is_destroyed ? &graph2 : &graph, &builder));
@@ -151,17 +155,19 @@ class DelegateKernel {
     // At this point tflite didn't allocate tensors yet, therefore, collect
     // indices and set all input and output tensors from tflite later.
-    input_indices_.resize(graph.inputs().size());
-    for (int i = 0; i < input_indices_.size(); ++i) {
-      const int64_t tflite_tensor_id = graph.inputs()[i]->tensor.ref;
-      input_indices_.push_back(tflite_tensor_id);
-      RETURN_IF_ERROR(builder->SetInputObjectDef(i, GetObjectDef()));
+    input_indices_.reserve(input_refs.size());
+    for (uint32_t tensor_index : input_refs) {
+      const int64_t object_index = input_indices_.size();
+      input_indices_.push_back(tensor_index);
+      RETURN_IF_ERROR(
+          builder->SetInputObjectDef(object_index, GetObjectDef(tensor_index)));
     }
-    output_indices_.resize(graph.outputs().size());
-    for (int i = 0; i < output_indices_.size(); ++i) {
-      const int64_t tflite_tensor_id = graph.outputs()[i]->tensor.ref;
-      output_indices_.push_back(tflite_tensor_id);
-      RETURN_IF_ERROR(builder->SetOutputObjectDef(i, GetObjectDef()));
+    output_indices_.reserve(output_refs.size());
+    for (uint32_t tensor_index : output_refs) {
+      const int64_t object_index = output_indices_.size();
+      output_indices_.push_back(tensor_index);
+      RETURN_IF_ERROR(builder->SetOutputObjectDef(object_index,
+                                                  GetObjectDef(tensor_index)));
     }
     return builder->Build(&runner_);
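As the comment in this hunk says, TFLite has not allocated tensors at this point; the stored ids are resolved later, after allocation, and the new loops reserve exactly one entry per input or output instead of resizing first. A hedged sketch of that later lookup, with hypothetical BindInputs and FakeTfLiteTensor stand-ins rather than code from this commit:

#include <cstdint>
#include <vector>

// Hypothetical illustration: once TFLite has allocated its tensors, each
// stored tensor id resolves to a buffer that is bound under the same object
// index that was registered with the builder above.
struct FakeTfLiteTensor { void* data; };

void BindInputs(const std::vector<int64_t>& input_indices,
                const std::vector<FakeTfLiteTensor>& tflite_tensors) {
  for (int64_t object_index = 0;
       object_index < static_cast<int64_t>(input_indices.size());
       ++object_index) {
    const int64_t tensor_id = input_indices[object_index];
    void* buffer = tflite_tensors[tensor_id].data;
    // A real delegate would wrap `buffer` in a CPU-memory object and hand it
    // to the runner under object_index; here it is only looked up.
    (void)buffer;
  }
}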
@@ -227,7 +233,7 @@ class DelegateKernel {
     return absl::OkStatus();
   }
 
-  ObjectDef GetObjectDef() const {
+  ObjectDef GetObjectDef(int index) const {
     ObjectDef default_object_def;
     default_object_def.data_type = DataType::FLOAT32;
     default_object_def.data_layout = DataLayout::BHWC;
@@ -244,7 +250,9 @@ class DelegateKernel {
  private:
   absl::Status InitializeGraph(TfLiteContext* context,
                                const TfLiteDelegateParams* delegate_params,
-                               GraphFloat32* graph) {
+                               GraphFloat32* graph,
+                               std::vector<uint32_t>* input_refs,
+                               std::vector<uint32_t>* output_refs) {
     quant_conversion_map_.clear();
     if (delegate_->IsQuantOpsAllowed()) {
       RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, graph,
@@ -252,6 +260,20 @@ class DelegateKernel {
     } else {
       RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, graph));
     }
+
+    input_refs->clear();
+    output_refs->clear();
+    const auto& inputs = graph->inputs();
+    input_refs->reserve(inputs.size());
+    for (const auto& input : inputs) {
+      input_refs->push_back(input->tensor.ref);
+    }
+    const auto& outputs = graph->outputs();
+    output_refs->reserve(outputs.size());
+    for (const auto& output : outputs) {
+      output_refs->push_back(output->tensor.ref);
+    }
+
     return absl::OkStatus();
   }
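The extraction step added to InitializeGraph applies the same pattern to inputs and outputs. A compact, hedged sketch of it as a stand-alone helper, where CollectTensorRefs, FakeValue, and FakeTensor are illustrative stand-ins for the GraphFloat32 value type rather than TFLite API:

#include <cstdint>
#include <vector>

// Stand-in for the relevant part of a graph value: it owns a tensor whose
// `ref` field is the TFLite tensor id.
struct FakeTensor { uint32_t ref; };
struct FakeValue { FakeTensor tensor; };

// Copy the TFLite tensor ids of the given graph values into a flat vector,
// mirroring what the new code does for graph->inputs() and graph->outputs().
std::vector<uint32_t> CollectTensorRefs(const std::vector<FakeValue*>& values) {
  std::vector<uint32_t> refs;
  refs.reserve(values.size());
  for (const FakeValue* value : values) {
    refs.push_back(value->tensor.ref);
  }
  return refs;
}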