Experimental feature.

Use special fused kernels to improve inference performance.
Added a new hint, ModelHints::kAllowSpecialKernels, to control the behavior.

PiperOrigin-RevId: 322465118
Change-Id: I62c2a3ddc75907f2d9e455b7454e1de8c54a9881
Raman Sarokin 2020-07-21 16:20:04 -07:00 committed by TensorFlower Gardener
parent 777bdc36a2
commit 2a150a026a
8 changed files with 242 additions and 44 deletions
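In practice, a client opts in through the new hint before initializing the inference context. A minimal sketch, modeled on the performance_profiling change at the bottom of this diff (the `env` and `graph` setup and the exact init call are assumptions here, not part of this commit):

    // Sketch: enabling the experimental special-kernel selection.
    // Assumes an initialized Environment `env` and a loaded GraphFloat32 `graph`.
    InferenceContext::CreateInferenceInfo create_info;
    create_info.precision = CalculationsPrecision::F16;  // or F32
    create_info.storage_type = GetFastestStorageType(env.device());
    create_info.hints.Add(ModelHints::kAllowSpecialKernels);  // the new hint
    InferenceContext context;
    RETURN_IF_ERROR(
        context.InitFromGraphWithTransforms(create_info, &graph, &env));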

tensorflow/lite/delegates/gpu/cl/BUILD

@@ -366,6 +366,7 @@ cc_library(
         ":tensor_type",
         "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
         "//tensorflow/lite/delegates/gpu/cl/selectors:operation_selector",
+        "//tensorflow/lite/delegates/gpu/cl/selectors:special_selector",
         "//tensorflow/lite/delegates/gpu/common:data_type",
         "//tensorflow/lite/delegates/gpu/common:memory_management",
         "//tensorflow/lite/delegates/gpu/common:model",

tensorflow/lite/delegates/gpu/cl/inference_context.cc

@@ -30,6 +30,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/cl/model_hints.h"
 #include "tensorflow/lite/delegates/gpu/cl/precision.h"
 #include "tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h"
+#include "tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h"
 #include "tensorflow/lite/delegates/gpu/cl/storage_type_util.h"
 #include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
@@ -261,6 +262,12 @@ void InferenceContext::ReserveGraphTensors(
 absl::Status InferenceContext::ConvertOperations(
     const CreationContext& creation_context, const GraphFloat32& graph,
     ModelHints hints) {
+  std::map<ValueId, TensorDescriptor> tensor_descriptors;
+  const auto values = graph.values();
+  for (auto value : values) {
+    tensor_descriptors[value->id] = tensor_reserver_.Get(value->id).descriptor;
+  }
+  std::set<NodeId> consumed_nodes;
   std::vector<Node*> graph_nodes = graph.nodes();
   std::map<ValueId, int>
       tensor_usages;  // keeps latest index of operation that updated tensor
@@ -270,45 +277,54 @@ absl::Status InferenceContext::ConvertOperations(
   }
   for (int i = 0; i < graph_nodes.size(); ++i) {
     const Node& node = *graph_nodes[i];
-    auto inputs = graph.FindInputs(node.id);
-    auto outputs = graph.FindOutputs(node.id);
-    // Reordering of input ids and updating of temporary tensors_usage struct.
-    // This stage is necessary because we are building OperationDef that rely on
-    // order of input ids. But we also should have input id on first position
-    // that potentially can be "linking" tensor and as result eliminated(unused)
-    // We apply it only for ADD operation, because of ADD associativity and
-    // ADD can be linked.
-    // In current approach "linking" tensor can be only latest written
-    // tensor(during linear order of execution) among input tensors.
-    if (IsGenericAdd(node, inputs, outputs)) {
-      int latest_written_tensor_index = 0;
-      int last_usage = tensor_usages[inputs[0]->id];
-      for (int j = 1; j < inputs.size(); ++j) {
-        if (tensor_usages[inputs[j]->id] > last_usage) {
-          last_usage = tensor_usages[inputs[j]->id];
-          latest_written_tensor_index = j;
-        }
-      }
-      std::swap(inputs[0], inputs[latest_written_tensor_index]);
-    }
-    for (const auto& out_id : outputs) {
-      tensor_usages[out_id->id] = i;
-    }
-    OperationDef op_def;
-    op_def.precision = precision_;
-    for (int j = 0; j < inputs.size(); ++j) {
-      op_def.src_tensors.push_back(
-          tensor_reserver_.Get(inputs[j]->id).descriptor);
-    }
-    for (int j = 0; j < outputs.size(); ++j) {
-      op_def.dst_tensors.push_back(
-          tensor_reserver_.Get(outputs[j]->id).descriptor);
+    if (consumed_nodes.find(node.id) != consumed_nodes.end()) {
+      continue;
     }
     GPUOperationsSubgraph gpu_subgraph;
-    RETURN_IF_ERROR(GPUOperationFromNode(creation_context, op_def, hints,
-                                         inputs, outputs, node, &gpu_subgraph));
+    if (hints.Check(ModelHints::kAllowSpecialKernels) &&
+        GPUSubgraphFromGraph(creation_context, precision_, graph, node.id,
+                             tensor_descriptors, &consumed_nodes, &gpu_subgraph)
+            .ok()) {
+      // Mapping of a subgraph (set of nodes) to GPU operations. Should happen
+      // before the straightforward mapping.
+    } else {
+      // Straightforward mapping of one graph node to GPU operations.
+      auto inputs = graph.FindInputs(node.id);
+      auto outputs = graph.FindOutputs(node.id);
+      // Reorder input ids and update the temporary tensor_usages map. This
+      // stage is necessary because the OperationDef we build relies on the
+      // order of input ids, and the first input should be the one that can
+      // potentially become a "linking" tensor and thus be eliminated (unused).
+      // We apply this only to ADD, because ADD is associative and can be
+      // linked. In the current approach the "linking" tensor can only be the
+      // latest-written tensor (in linear execution order) among the inputs.
+      if (IsGenericAdd(node, inputs, outputs)) {
+        int latest_written_tensor_index = 0;
+        int last_usage = tensor_usages[inputs[0]->id];
+        for (int j = 1; j < inputs.size(); ++j) {
+          if (tensor_usages[inputs[j]->id] > last_usage) {
+            last_usage = tensor_usages[inputs[j]->id];
+            latest_written_tensor_index = j;
+          }
+        }
+        std::swap(inputs[0], inputs[latest_written_tensor_index]);
+      }
+      consumed_nodes.insert(node.id);
+      OperationDef op_def;
+      op_def.precision = precision_;
+      for (int j = 0; j < inputs.size(); ++j) {
+        op_def.src_tensors.push_back(
+            tensor_reserver_.Get(inputs[j]->id).descriptor);
+      }
+      for (int j = 0; j < outputs.size(); ++j) {
+        op_def.dst_tensors.push_back(
+            tensor_reserver_.Get(outputs[j]->id).descriptor);
+      }
+      RETURN_IF_ERROR(GPUOperationFromNode(creation_context, op_def, hints,
+                                           inputs, outputs, node,
+                                           &gpu_subgraph));
+    }
     std::unordered_map<int, ValueId> mapping_to_global_ids;
     for (int j = 0; j < gpu_subgraph.new_tensors.size(); ++j) {
       const auto& t = gpu_subgraph.new_tensors[j];
@@ -324,7 +340,7 @@ absl::Status InferenceContext::ConvertOperations(
     for (int j = 0; j < gpu_op.input_ids.size(); ++j) {
       int id = gpu_op.input_ids[j];
       if (id >= 0) {
-        cl_node.inputs[j] = inputs[id]->id;
+        cl_node.inputs[j] = id;
       } else {
         cl_node.inputs[j] = mapping_to_global_ids[-(id + 1)];
       }
@@ -333,7 +349,8 @@ absl::Status InferenceContext::ConvertOperations(
     for (int j = 0; j < gpu_op.output_ids.size(); ++j) {
       int id = gpu_op.output_ids[j];
       if (id >= 0) {
-        cl_node.outputs[j] = outputs[id]->id;
+        cl_node.outputs[j] = id;
+        tensor_usages[id] = i;
       } else {
         cl_node.outputs[j] = mapping_to_global_ids[-(id + 1)];
       }
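Note the id convention both loops above now rely on: a nonnegative id stored in a GPUOperationsSubgraph is a global ValueId taken straight from the graph (before this change it was an index into the node's inputs/outputs), while a negative id still refers to a tensor freshly created for the subgraph. A sketch of the decoding rule (the helper name is illustrative, not part of the codebase):

    // Decodes an operation's tensor id the same way ConvertOperations does.
    ValueId ResolveId(
        int id, const std::unordered_map<int, ValueId>& mapping_to_global_ids) {
      if (id >= 0) return static_cast<ValueId>(id);  // global graph tensor
      return mapping_to_global_ids.at(-(id + 1));    // subgraph-internal tensor
    }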

tensorflow/lite/delegates/gpu/cl/model_hints.h

@@ -25,13 +25,18 @@ namespace cl {
 struct ModelHints {
   using ModelHint = uint64_t;
-  // By default we want the fastest inference
+  // By default we want the fastest inference.
   static constexpr ModelHint kFastestInference = 0x00000000;
-  // Can improve compilation time, but inference can be slower
+  // Can improve compilation time, but inference can be slower.
   static constexpr ModelHint kReduceKernelsCount = 0x00000001;
-  // Can improve tuning time, but inference can be slower
+  // Can improve tuning time, but inference can be slower.
   static constexpr ModelHint kFastTuning = 0x00000002;
+  // Experimental.
+  // Can improve performance and memory consumption, but slows down
+  // initialization a lot and creates more kernels.
+  static constexpr ModelHint kAllowSpecialKernels = 0x00000004;
   void Add(ModelHint hint) {
     if (hint == kFastestInference) {
       hints = kFastestInference;
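Since ModelHint is a plain uint64_t bitmask, non-default hints compose; a short sketch of the intended use, assuming Add() ORs the bit in and Check() tests it (consistent with the hints.Check(ModelHints::kAllowSpecialKernels) call in inference_context.cc above):

    ModelHints hints;                             // defaults to kFastestInference
    hints.Add(ModelHints::kFastTuning);           // hints == 0x2
    hints.Add(ModelHints::kAllowSpecialKernels);  // hints == 0x2 | 0x4
    bool special = hints.Check(ModelHints::kAllowSpecialKernels);  // true
    hints.Add(ModelHints::kFastestInference);     // resets to 0x0, per Add() above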

tensorflow/lite/delegates/gpu/cl/selectors/BUILD

@@ -152,6 +152,26 @@ cc_library(
     ],
 )
+cc_library(
+    name = "special_selector",
+    srcs = ["special_selector.cc"],
+    hdrs = ["special_selector.h"],
+    deps = [
+        ":subgraph",
+        "//tensorflow/lite/delegates/gpu/cl:cl_device",
+        "//tensorflow/lite/delegates/gpu/cl:tensor_type",
+        "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/cl/kernels/special:depthwise_conv_plus_1x1_conv",
+        "//tensorflow/lite/delegates/gpu/common:data_type",
+        "//tensorflow/lite/delegates/gpu/common:model",
+        "//tensorflow/lite/delegates/gpu/common:operations",
+        "//tensorflow/lite/delegates/gpu/common:shape",
+        "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common:tensor",
+        "@com_google_absl//absl/types:any",
+    ],
+)
+
 cc_library(
     name = "subgraph",
     srcs = ["subgraph.cc"],

tensorflow/lite/delegates/gpu/cl/selectors/special_selector.cc (new file)

@@ -0,0 +1,111 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h"
#include "absl/types/any.h"
#include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/special/depthwise_conv_plus_1x1_conv.h"
#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
namespace tflite {
namespace gpu {
namespace cl {
namespace {
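// Tries to fuse a DEPTHWISE_CONVOLUTION node with the CONVOLUTION_2D node that
// consumes its only output. The match succeeds only if that output has exactly
// one consumer, the consumer has a single input and has not been consumed
// already, and the fused kernel supports this attribute/device combination;
// on success both node ids are added to consumed_nodes.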
absl::Status TryDepthwiseConvPlus1x1Conv(
    const CreationContext& creation_context, CalculationsPrecision precision,
    const GraphFloat32& graph, NodeId first_node_id,
    const std::map<ValueId, TensorDescriptor>& tensor_descriptors,
    std::set<NodeId>* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph) {
  auto* dw_node = graph.GetNode(first_node_id);
  if (OperationTypeFromString(dw_node->operation.type) !=
      OperationType::DEPTHWISE_CONVOLUTION) {
    return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable.");
  }
  auto dw_outputs = graph.FindOutputs(dw_node->id);
  auto consumers = graph.FindConsumers(dw_outputs[0]->id);
  if (consumers.size() != 1) {
    return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable.");
  }
  auto* conv_node = consumers[0];
  if (consumed_nodes->find(conv_node->id) != consumed_nodes->end()) {
    return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable.");
  }
  if (OperationTypeFromString(conv_node->operation.type) !=
      OperationType::CONVOLUTION_2D) {
    return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable.");
  }
  if (graph.FindInputs(conv_node->id).size() != 1) {
    return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable.");
  }
  auto dw_attr = absl::any_cast<DepthwiseConvolution2DAttributes>(
      dw_node->operation.attributes);
  auto conv_attr =
      absl::any_cast<Convolution2DAttributes>(conv_node->operation.attributes);
  auto dw_inputs = graph.FindInputs(dw_node->id);
  auto conv_outputs = graph.FindOutputs(conv_node->id);
  OperationDef op_def;
  op_def.precision = precision;
  auto it = tensor_descriptors.find(dw_inputs[0]->id);
  if (it != tensor_descriptors.end()) {
    op_def.src_tensors.push_back(it->second);
  }
  it = tensor_descriptors.find(conv_outputs[0]->id);
  if (it != tensor_descriptors.end()) {
    op_def.dst_tensors.push_back(it->second);
  }
  if (!IsDepthwiseConvPlus1x1ConvSupported(*creation_context.device, op_def,
                                           dw_attr, conv_attr)) {
    return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable.");
  }
  std::unique_ptr<GPUOperation>* gpu_op =
      InitSingleOpSubgraph(dw_inputs, conv_outputs, gpu_subgraph);
  DepthwiseConvPlus1x1Conv operation;
  RETURN_IF_ERROR(CreateDepthwiseConvPlus1x1Conv(
      creation_context, op_def, dw_attr, conv_attr, &operation));
  *gpu_op = absl::make_unique<DepthwiseConvPlus1x1Conv>(std::move(operation));
  consumed_nodes->insert(dw_node->id);
  consumed_nodes->insert(conv_node->id);
  return absl::OkStatus();
}
}  // namespace
absl::Status GPUSubgraphFromGraph(
    const CreationContext& creation_context, CalculationsPrecision precision,
    const GraphFloat32& graph, NodeId first_node_id,
    const std::map<ValueId, TensorDescriptor>& tensor_descriptors,
    std::set<NodeId>* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph) {
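  // Special fusions are currently gated to NVIDIA devices; on any failure the
  // caller (ConvertOperations above) falls back to regular per-node selection.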
  if (!creation_context.device->IsNvidia()) {
    return absl::NotFoundError(
        "Experimental feature, enabled for NVIDIA GPUs only; the current "
        "device is not an NVIDIA GPU.");
  }
  if (TryDepthwiseConvPlus1x1Conv(creation_context, precision, graph,
                                  first_node_id, tensor_descriptors,
                                  consumed_nodes, gpu_subgraph)
          .ok()) {
    return absl::OkStatus();
  }
  return absl::NotFoundError("No special combination.");
}
} // namespace cl
} // namespace gpu
} // namespace tflite

tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h (new file)

@@ -0,0 +1,43 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SPECIAL_SELECTOR_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SPECIAL_SELECTOR_H_
#include <map>
#include <set>
#include <vector>
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
#include "tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h"
#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
namespace tflite {
namespace gpu {
namespace cl {
absl::Status GPUSubgraphFromGraph(
    const CreationContext& creation_context, CalculationsPrecision precision,
    const GraphFloat32& graph, NodeId first_node_id,
    const std::map<ValueId, TensorDescriptor>& tensor_descriptors,
    std::set<NodeId>* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph);
} // namespace cl
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SPECIAL_SELECTOR_H_

tensorflow/lite/delegates/gpu/cl/selectors/subgraph.cc

@@ -32,10 +32,10 @@ std::unique_ptr<GPUOperation>* InitSingleOpSubgraph(
   gpu_subgraph->new_tensors.clear();
   gpu_subgraph->operations.push_back({});
   for (int i = 0; i < inputs.size(); ++i) {
-    gpu_subgraph->operations[0].input_ids.push_back(i);
+    gpu_subgraph->operations[0].input_ids.push_back(inputs[i]->id);
   }
   for (int i = 0; i < outputs.size(); ++i) {
-    gpu_subgraph->operations[0].output_ids.push_back(i);
+    gpu_subgraph->operations[0].output_ids.push_back(outputs[i]->id);
   }
   return &gpu_subgraph->operations[0].operation;
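This mirrors the inference_context.cc change above (cl_node.inputs[j] = id;): subgraph ids are now global ValueIds rather than per-node indices. A hedged illustration from a caller's perspective (the concrete ids are invented for the example):

    // If `inputs` holds graph tensors with global ids {5, 9} and `outputs`
    // holds {12}, the subgraph now records those ids directly:
    std::unique_ptr<GPUOperation>* gpu_op =
        InitSingleOpSubgraph(inputs, outputs, gpu_subgraph);
    // gpu_subgraph->operations[0].input_ids  == {5, 9}  (was {0, 1})
    // gpu_subgraph->operations[0].output_ids == {12}    (was {0})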

tensorflow/lite/delegates/gpu/cl/testing/performance_profiling.cc

@@ -44,6 +44,7 @@ absl::Status RunModelSample(const std::string& model_name) {
                               ? CalculationsPrecision::F16
                               : CalculationsPrecision::F32;
   create_info.storage_type = GetFastestStorageType(env.device());
+  create_info.hints.Add(ModelHints::kAllowSpecialKernels);
   std::cout << "Precision: " << ToString(create_info.precision) << std::endl;
   std::cout << "Storage type: " << ToString(create_info.storage_type)
             << std::endl;