From 2a150a026a04f7a66eefb5fae4b1587d4665f4e0 Mon Sep 17 00:00:00 2001
From: Raman Sarokin <sorokin@google.com>
Date: Tue, 21 Jul 2020 16:20:04 -0700
Subject: [PATCH] Experimental feature. Using special kernels for inference
 improvements. Added new hint to control behavior.

PiperOrigin-RevId: 322465118
Change-Id: I62c2a3ddc75907f2d9e455b7454e1de8c54a9881
---
 tensorflow/lite/delegates/gpu/cl/BUILD        |   1 +
 .../delegates/gpu/cl/inference_context.cc     |  95 +++++++++------
 .../lite/delegates/gpu/cl/model_hints.h       |  11 +-
 .../lite/delegates/gpu/cl/selectors/BUILD     |  20 ++++
 .../gpu/cl/selectors/special_selector.cc      | 111 ++++++++++++++++++
 .../gpu/cl/selectors/special_selector.h       |  43 +++++++
 .../delegates/gpu/cl/selectors/subgraph.cc    |   4 +-
 .../gpu/cl/testing/performance_profiling.cc   |   1 +
 8 files changed, 242 insertions(+), 44 deletions(-)
 create mode 100644 tensorflow/lite/delegates/gpu/cl/selectors/special_selector.cc
 create mode 100644 tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h

diff --git a/tensorflow/lite/delegates/gpu/cl/BUILD b/tensorflow/lite/delegates/gpu/cl/BUILD
index 9155bc1166a..36cafdb4d3b 100644
--- a/tensorflow/lite/delegates/gpu/cl/BUILD
+++ b/tensorflow/lite/delegates/gpu/cl/BUILD
@@ -366,6 +366,7 @@ cc_library(
         ":tensor_type",
         "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
         "//tensorflow/lite/delegates/gpu/cl/selectors:operation_selector",
+        "//tensorflow/lite/delegates/gpu/cl/selectors:special_selector",
         "//tensorflow/lite/delegates/gpu/common:data_type",
         "//tensorflow/lite/delegates/gpu/common:memory_management",
         "//tensorflow/lite/delegates/gpu/common:model",
diff --git a/tensorflow/lite/delegates/gpu/cl/inference_context.cc b/tensorflow/lite/delegates/gpu/cl/inference_context.cc
index 9e57dd175bc..3067c81ec94 100644
--- a/tensorflow/lite/delegates/gpu/cl/inference_context.cc
+++ b/tensorflow/lite/delegates/gpu/cl/inference_context.cc
@@ -30,6 +30,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/cl/model_hints.h"
 #include "tensorflow/lite/delegates/gpu/cl/precision.h"
 #include "tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h"
+#include "tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h"
 #include "tensorflow/lite/delegates/gpu/cl/storage_type_util.h"
 #include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
@@ -261,6 +262,12 @@ void InferenceContext::ReserveGraphTensors(
 absl::Status InferenceContext::ConvertOperations(
     const CreationContext& creation_context, const GraphFloat32& graph,
     ModelHints hints) {
+  std::map<ValueId, TensorDescriptor> tensor_descriptors;
+  const auto values = graph.values();
+  for (auto value : values) {
+    tensor_descriptors[value->id] = tensor_reserver_.Get(value->id).descriptor;
+  }
+  std::set<NodeId> consumed_nodes;
   std::vector<Node*> graph_nodes = graph.nodes();
   std::map<ValueId, int>
       tensor_usages;  // keeps latest index of operation that updated tensor
@@ -270,45 +277,54 @@ absl::Status InferenceContext::ConvertOperations(
   }
   for (int i = 0; i < graph_nodes.size(); ++i) {
     const Node& node = *graph_nodes[i];
-    auto inputs = graph.FindInputs(node.id);
-    auto outputs = graph.FindOutputs(node.id);
-
-    // Reordering of input ids and updating of temporary tensors_usage struct.
-    // This stage is necessary because we are building OperationDef that rely on
-    // order of input ids. But we also should have input id on first position
-    // that potentially can be "linking" tensor and as result eliminated(unused)
-    // We apply it only for ADD operation, because of ADD associativity and
-    // ADD can be linked.
-    // In current approach "linking" tensor can be only latest written
-    // tensor(during linear order of execution) among input tensors.
-    if (IsGenericAdd(node, inputs, outputs)) {
-      int latest_written_tensor_index = 0;
-      int last_usage = tensor_usages[inputs[0]->id];
-      for (int j = 1; j < inputs.size(); ++j) {
-        if (tensor_usages[inputs[j]->id] > last_usage) {
-          last_usage = tensor_usages[inputs[j]->id];
-          latest_written_tensor_index = j;
-        }
-      }
-      std::swap(inputs[0], inputs[latest_written_tensor_index]);
-    }
-    for (const auto& out_id : outputs) {
-      tensor_usages[out_id->id] = i;
-    }
-
-    OperationDef op_def;
-    op_def.precision = precision_;
-    for (int j = 0; j < inputs.size(); ++j) {
-      op_def.src_tensors.push_back(
-          tensor_reserver_.Get(inputs[j]->id).descriptor);
-    }
-    for (int j = 0; j < outputs.size(); ++j) {
-      op_def.dst_tensors.push_back(
-          tensor_reserver_.Get(outputs[j]->id).descriptor);
+    if (consumed_nodes.find(node.id) != consumed_nodes.end()) {
+      continue;
     }
     GPUOperationsSubgraph gpu_subgraph;
-    RETURN_IF_ERROR(GPUOperationFromNode(creation_context, op_def, hints,
-                                         inputs, outputs, node, &gpu_subgraph));
+    if (hints.Check(ModelHints::kAllowSpecialKernels) &&
+        GPUSubgraphFromGraph(creation_context, precision_, graph, node.id,
+                             tensor_descriptors, &consumed_nodes, &gpu_subgraph)
+            .ok()) {
+      // Mapping of subgraph (set of nodes) to GPU operations. Should happen
+      // before straigtforward mapping.
+    } else {
+      // Straigtforward mapping of one graph node to GPU operations.
+      auto inputs = graph.FindInputs(node.id);
+      auto outputs = graph.FindOutputs(node.id);
+      // Reordering of input ids and updating of temporary tensors_usage struct.
+      // This stage is necessary because we are building OperationDef that rely
+      // on order of input ids. But we also should have input id on first
+      // position that potentially can be "linking" tensor and as result
+      // eliminated(unused) We apply it only for ADD operation, because of ADD
+      // associativity and ADD can be linked. In current approach "linking"
+      // tensor can be only latest written tensor(during linear order of
+      // execution) among input tensors.
+      if (IsGenericAdd(node, inputs, outputs)) {
+        int latest_written_tensor_index = 0;
+        int last_usage = tensor_usages[inputs[0]->id];
+        for (int j = 1; j < inputs.size(); ++j) {
+          if (tensor_usages[inputs[j]->id] > last_usage) {
+            last_usage = tensor_usages[inputs[j]->id];
+            latest_written_tensor_index = j;
+          }
+        }
+        std::swap(inputs[0], inputs[latest_written_tensor_index]);
+      }
+      consumed_nodes.insert(node.id);
+      OperationDef op_def;
+      op_def.precision = precision_;
+      for (int j = 0; j < inputs.size(); ++j) {
+        op_def.src_tensors.push_back(
+            tensor_reserver_.Get(inputs[j]->id).descriptor);
+      }
+      for (int j = 0; j < outputs.size(); ++j) {
+        op_def.dst_tensors.push_back(
+            tensor_reserver_.Get(outputs[j]->id).descriptor);
+      }
+      RETURN_IF_ERROR(GPUOperationFromNode(creation_context, op_def, hints,
+                                           inputs, outputs, node,
+                                           &gpu_subgraph));
+    }
     std::unordered_map<int, ValueId> mapping_to_global_ids;
     for (int j = 0; j < gpu_subgraph.new_tensors.size(); ++j) {
       const auto& t = gpu_subgraph.new_tensors[j];
@@ -324,7 +340,7 @@ absl::Status InferenceContext::ConvertOperations(
       for (int j = 0; j < gpu_op.input_ids.size(); ++j) {
         int id = gpu_op.input_ids[j];
         if (id >= 0) {
-          cl_node.inputs[j] = inputs[id]->id;
+          cl_node.inputs[j] = id;
         } else {
           cl_node.inputs[j] = mapping_to_global_ids[-(id + 1)];
         }
@@ -333,7 +349,8 @@ absl::Status InferenceContext::ConvertOperations(
       for (int j = 0; j < gpu_op.output_ids.size(); ++j) {
         int id = gpu_op.output_ids[j];
         if (id >= 0) {
-          cl_node.outputs[j] = outputs[id]->id;
+          cl_node.outputs[j] = id;
+          tensor_usages[id] = i;
         } else {
           cl_node.outputs[j] = mapping_to_global_ids[-(id + 1)];
         }
diff --git a/tensorflow/lite/delegates/gpu/cl/model_hints.h b/tensorflow/lite/delegates/gpu/cl/model_hints.h
index 7661cc0dacb..7c0f4b55b1d 100644
--- a/tensorflow/lite/delegates/gpu/cl/model_hints.h
+++ b/tensorflow/lite/delegates/gpu/cl/model_hints.h
@@ -25,13 +25,18 @@ namespace cl {
 struct ModelHints {
   using ModelHint = uint64_t;
 
-  // By default we want the fastest inference
+  // By default we want the fastest inference.
   static constexpr ModelHint kFastestInference = 0x00000000;
-  // Can improve compilation time, but inference can be slower
+  // Can improve compilation time, but inference can be slower.
   static constexpr ModelHint kReduceKernelsCount = 0x00000001;
-  // Can improve tuning time, but inference can be slower
+  // Can improve tuning time, but inference can be slower.
   static constexpr ModelHint kFastTuning = 0x00000002;
 
+  // Experimental.
+  // Can improve performance and memory consumption, but slow down
+  // initialization a lot and create more kernels.
+  static constexpr ModelHint kAllowSpecialKernels = 0x00000004;
+
   void Add(ModelHint hint) {
     if (hint == kFastestInference) {
       hints = kFastestInference;
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/BUILD b/tensorflow/lite/delegates/gpu/cl/selectors/BUILD
index ff196cfaf71..bf4c7df8651 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/BUILD
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/BUILD
@@ -152,6 +152,26 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "special_selector",
+    srcs = ["special_selector.cc"],
+    hdrs = ["special_selector.h"],
+    deps = [
+        ":subgraph",
+        "//tensorflow/lite/delegates/gpu/cl:cl_device",
+        "//tensorflow/lite/delegates/gpu/cl:tensor_type",
+        "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/cl/kernels/special:depthwise_conv_plus_1x1_conv",
+        "//tensorflow/lite/delegates/gpu/common:data_type",
+        "//tensorflow/lite/delegates/gpu/common:model",
+        "//tensorflow/lite/delegates/gpu/common:operations",
+        "//tensorflow/lite/delegates/gpu/common:shape",
+        "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common:tensor",
+        "@com_google_absl//absl/types:any",
+    ],
+)
+
 cc_library(
     name = "subgraph",
     srcs = ["subgraph.cc"],
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.cc
new file mode 100644
index 00000000000..8a801b460d1
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.cc
@@ -0,0 +1,111 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h"
+
+#include "absl/types/any.h"
+#include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
+#include "tensorflow/lite/delegates/gpu/cl/kernels/special/depthwise_conv_plus_1x1_conv.h"
+#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
+#include "tensorflow/lite/delegates/gpu/common/data_type.h"
+#include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/tensor.h"
+
+namespace tflite {
+namespace gpu {
+namespace cl {
+namespace {
+absl::Status TryDepthwiseConvPlus1x1Conv(
+    const CreationContext& creation_context, CalculationsPrecision precision,
+    const GraphFloat32& graph, NodeId first_node_id,
+    const std::map<ValueId, TensorDescriptor>& tensor_descriptors,
+    std::set<NodeId>* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph) {
+  auto* dw_node = graph.GetNode(first_node_id);
+  if (OperationTypeFromString(dw_node->operation.type) !=
+      OperationType::DEPTHWISE_CONVOLUTION) {
+    return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable.");
+  }
+  auto dw_outputs = graph.FindOutputs(dw_node->id);
+  auto consumers = graph.FindConsumers(dw_outputs[0]->id);
+  if (consumers.size() != 1) {
+    return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable.");
+  }
+  auto* conv_node = consumers[0];
+  if (consumed_nodes->find(conv_node->id) != consumed_nodes->end()) {
+    return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable.");
+  }
+  if (OperationTypeFromString(conv_node->operation.type) !=
+      OperationType::CONVOLUTION_2D) {
+    return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable.");
+  }
+  if (graph.FindInputs(conv_node->id).size() != 1) {
+    return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable.");
+  }
+  auto dw_attr = absl::any_cast<DepthwiseConvolution2DAttributes>(
+      dw_node->operation.attributes);
+  auto conv_attr =
+      absl::any_cast<Convolution2DAttributes>(conv_node->operation.attributes);
+  auto dw_inputs = graph.FindInputs(dw_node->id);
+  auto conv_outputs = graph.FindOutputs(conv_node->id);
+  OperationDef op_def;
+  op_def.precision = precision;
+  auto it = tensor_descriptors.find(dw_inputs[0]->id);
+  if (it != tensor_descriptors.end()) {
+    op_def.src_tensors.push_back(it->second);
+  }
+  it = tensor_descriptors.find(conv_outputs[0]->id);
+  if (it != tensor_descriptors.end()) {
+    op_def.dst_tensors.push_back(it->second);
+  }
+  if (!IsDepthwiseConvPlus1x1ConvSupported(*creation_context.device, op_def,
+                                           dw_attr, conv_attr)) {
+    return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable.");
+  }
+  std::unique_ptr<GPUOperation>* gpu_op =
+      InitSingleOpSubgraph(dw_inputs, conv_outputs, gpu_subgraph);
+  DepthwiseConvPlus1x1Conv operation;
+  RETURN_IF_ERROR(CreateDepthwiseConvPlus1x1Conv(
+      creation_context, op_def, dw_attr, conv_attr, &operation));
+  *gpu_op = absl::make_unique<DepthwiseConvPlus1x1Conv>(std::move(operation));
+  consumed_nodes->insert(dw_node->id);
+  consumed_nodes->insert(conv_node->id);
+  return absl::OkStatus();
+}
+}  // namespace
+
+absl::Status GPUSubgraphFromGraph(
+    const CreationContext& creation_context, CalculationsPrecision precision,
+    const GraphFloat32& graph, NodeId first_node_id,
+    const std::map<ValueId, TensorDescriptor>& tensor_descriptors,
+    std::set<NodeId>* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph) {
+  if (!creation_context.device->IsNvidia()) {
+    return absl::NotFoundError(
+        "Experimental feature, enabled for NVidia only, but device is not "
+        "nvidia gpu.");
+  }
+  if (TryDepthwiseConvPlus1x1Conv(creation_context, precision, graph,
+                                  first_node_id, tensor_descriptors,
+                                  consumed_nodes, gpu_subgraph)
+          .ok()) {
+    return absl::OkStatus();
+  }
+  return absl::NotFoundError("No special combination.");
+}
+
+}  // namespace cl
+}  // namespace gpu
+}  // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h
new file mode 100644
index 00000000000..687d221aac6
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h
@@ -0,0 +1,43 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SPECIAL_SELECTOR_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SPECIAL_SELECTOR_H_
+
+#include <map>
+#include <set>
+#include <vector>
+
+#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
+#include "tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h"
+#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
+#include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+
+namespace tflite {
+namespace gpu {
+namespace cl {
+
+absl::Status GPUSubgraphFromGraph(
+    const CreationContext& creation_context, CalculationsPrecision precision,
+    const GraphFloat32& graph, NodeId first_node_id,
+    const std::map<ValueId, TensorDescriptor>& tensor_descriptors,
+    std::set<NodeId>* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph);
+
+}  // namespace cl
+}  // namespace gpu
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SPECIAL_SELECTOR_H_
diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.cc b/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.cc
index 0f18a4b7be5..27a40886497 100644
--- a/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.cc
+++ b/tensorflow/lite/delegates/gpu/cl/selectors/subgraph.cc
@@ -32,10 +32,10 @@ std::unique_ptr<GPUOperation>* InitSingleOpSubgraph(
   gpu_subgraph->new_tensors.clear();
   gpu_subgraph->operations.push_back({});
   for (int i = 0; i < inputs.size(); ++i) {
-    gpu_subgraph->operations[0].input_ids.push_back(i);
+    gpu_subgraph->operations[0].input_ids.push_back(inputs[i]->id);
   }
   for (int i = 0; i < outputs.size(); ++i) {
-    gpu_subgraph->operations[0].output_ids.push_back(i);
+    gpu_subgraph->operations[0].output_ids.push_back(outputs[i]->id);
   }
 
   return &gpu_subgraph->operations[0].operation;
diff --git a/tensorflow/lite/delegates/gpu/cl/testing/performance_profiling.cc b/tensorflow/lite/delegates/gpu/cl/testing/performance_profiling.cc
index 0c500cd0bbe..ab2e52f14ed 100644
--- a/tensorflow/lite/delegates/gpu/cl/testing/performance_profiling.cc
+++ b/tensorflow/lite/delegates/gpu/cl/testing/performance_profiling.cc
@@ -44,6 +44,7 @@ absl::Status RunModelSample(const std::string& model_name) {
                               ? CalculationsPrecision::F16
                               : CalculationsPrecision::F32;
   create_info.storage_type = GetFastestStorageType(env.device());
+  create_info.hints.Add(ModelHints::kAllowSpecialKernels);
   std::cout << "Precision: " << ToString(create_info.precision) << std::endl;
   std::cout << "Storage type: " << ToString(create_info.storage_type)
             << std::endl;