Use internal graph indices as delegate indices.

PiperOrigin-RevId: 308433121
Change-Id: Ifd96ddfaeeea6be25cc4ef40589835f24554a907
This commit is contained in:
A. Unique TensorFlower 2020-04-25 12:46:05 -07:00 committed by TensorFlower Gardener
parent 7f37206771
commit d17488bd97

View File

@ -121,7 +121,10 @@ class DelegateKernel {
// Extract TFLite delegate execution plan from the context and convert it // Extract TFLite delegate execution plan from the context and convert it
// into GraphFloat32. // into GraphFloat32.
GraphFloat32 graph; GraphFloat32 graph;
RETURN_IF_ERROR(InitializeGraph(context, delegate_params, &graph)); std::vector<uint32_t> input_refs;
std::vector<uint32_t> output_refs;
RETURN_IF_ERROR(InitializeGraph(context, delegate_params, &graph,
&input_refs, &output_refs));
std::unique_ptr<InferenceBuilder> builder; std::unique_ptr<InferenceBuilder> builder;
bool graph_is_destroyed; bool graph_is_destroyed;
@ -142,7 +145,8 @@ class DelegateKernel {
// Graph needs to be re-created because it is moved above. // Graph needs to be re-created because it is moved above.
GraphFloat32 graph2; GraphFloat32 graph2;
if (graph_is_destroyed) { if (graph_is_destroyed) {
RETURN_IF_ERROR(InitializeGraph(context, delegate_params, &graph2)); RETURN_IF_ERROR(InitializeGraph(context, delegate_params, &graph2,
&input_refs, &output_refs));
} }
RETURN_IF_ERROR(InitializeOpenGlApi( RETURN_IF_ERROR(InitializeOpenGlApi(
graph_is_destroyed ? &graph2 : &graph, &builder)); graph_is_destroyed ? &graph2 : &graph, &builder));
@ -151,17 +155,19 @@ class DelegateKernel {
// At this point tflite didn't allocate tensors yet, therefore, collect // At this point tflite didn't allocate tensors yet, therefore, collect
// indices and set all input and output tensors from tflite later. // indices and set all input and output tensors from tflite later.
input_indices_.resize(graph.inputs().size()); input_indices_.reserve(input_refs.size());
for (int i = 0; i < input_indices_.size(); ++i) { for (uint32_t tensor_index : input_refs) {
const int64_t tflite_tensor_id = graph.inputs()[i]->tensor.ref; const int64_t object_index = input_indices_.size();
input_indices_.push_back(tflite_tensor_id); input_indices_.push_back(tensor_index);
RETURN_IF_ERROR(builder->SetInputObjectDef(i, GetObjectDef())); RETURN_IF_ERROR(
builder->SetInputObjectDef(object_index, GetObjectDef(tensor_index)));
} }
output_indices_.resize(graph.outputs().size()); output_indices_.reserve(output_refs.size());
for (int i = 0; i < output_indices_.size(); ++i) { for (uint32_t tensor_index : output_refs) {
const int64_t tflite_tensor_id = graph.outputs()[i]->tensor.ref; const int64_t object_index = output_indices_.size();
output_indices_.push_back(tflite_tensor_id); output_indices_.push_back(tensor_index);
RETURN_IF_ERROR(builder->SetOutputObjectDef(i, GetObjectDef())); RETURN_IF_ERROR(builder->SetOutputObjectDef(object_index,
GetObjectDef(tensor_index)));
} }
return builder->Build(&runner_); return builder->Build(&runner_);
@ -227,7 +233,7 @@ class DelegateKernel {
return absl::OkStatus(); return absl::OkStatus();
} }
ObjectDef GetObjectDef() const { ObjectDef GetObjectDef(int index) const {
ObjectDef default_object_def; ObjectDef default_object_def;
default_object_def.data_type = DataType::FLOAT32; default_object_def.data_type = DataType::FLOAT32;
default_object_def.data_layout = DataLayout::BHWC; default_object_def.data_layout = DataLayout::BHWC;
@ -244,7 +250,9 @@ class DelegateKernel {
private: private:
absl::Status InitializeGraph(TfLiteContext* context, absl::Status InitializeGraph(TfLiteContext* context,
const TfLiteDelegateParams* delegate_params, const TfLiteDelegateParams* delegate_params,
GraphFloat32* graph) { GraphFloat32* graph,
std::vector<uint32_t>* input_refs,
std::vector<uint32_t>* output_refs) {
quant_conversion_map_.clear(); quant_conversion_map_.clear();
if (delegate_->IsQuantOpsAllowed()) { if (delegate_->IsQuantOpsAllowed()) {
RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, graph, RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, graph,
@ -252,6 +260,20 @@ class DelegateKernel {
} else { } else {
RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, graph)); RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, graph));
} }
input_refs->clear();
output_refs->clear();
const auto& inputs = graph->inputs();
input_refs->reserve(inputs.size());
for (const auto& input : inputs) {
input_refs->push_back(input->tensor.ref);
}
const auto& outputs = graph->outputs();
output_refs->reserve(outputs.size());
for (const auto& output : outputs) {
output_refs->push_back(output->tensor.ref);
}
return absl::OkStatus(); return absl::OkStatus();
} }