From 777b6ad484d1647a1b7a64ab862b1cb3ff706a4f Mon Sep 17 00:00:00 2001
From: Raman Sarokin <sorokin@google.com>
Date: Mon, 29 Jun 2020 10:42:03 -0700
Subject: [PATCH] Improved AddBias transformation.

PiperOrigin-RevId: 318845057
Change-Id: I41321cdea9d8c605fa77dcff4f962a891536d985
---
 .../gpu/common/transformations/add_bias.cc    | 39 ++++++++++++-------
 1 file changed, 24 insertions(+), 15 deletions(-)

diff --git a/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc b/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc
index 7feac824ef7..ec2474138a3 100644
--- a/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc
+++ b/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc
@@ -27,38 +27,47 @@ namespace tflite {
 namespace gpu {
 namespace {
 
-template <typename T>
-TransformResult FillBias(Node* node) {
-  auto& attr = absl::any_cast<T&>(node->operation.attributes);
-  if (attr.bias.data.empty()) {
-    const int dst_channels = attr.weights.shape.o;
-    attr.bias = MakeZeroTensor<Linear, DataType::FLOAT32>(Linear(dst_channels));
+TransformResult FillBias(
+    int output_channels,
+    tflite::gpu::Tensor<Linear, DataType::FLOAT32>* biases) {
+  if (biases->data.empty()) {
+    *biases =
+        MakeZeroTensor<Linear, DataType::FLOAT32>(Linear(output_channels));
     return {TransformStatus::APPLIED, "Added bias"};
   }
+  if (biases->shape.v != output_channels) {
+    float last_value = biases->data.back();
+    biases->shape.v = output_channels;
+    biases->data.resize(output_channels, last_value);
+    return {TransformStatus::APPLIED, "Bias extended"};
+  }
   return {TransformStatus::SKIPPED, ""};
 }
 
-template TransformResult FillBias<Convolution2DAttributes>(Node* node);
-template TransformResult FillBias<ConvolutionTransposedAttributes>(Node* node);
-template TransformResult FillBias<DepthwiseConvolution2DAttributes>(Node* node);
-template TransformResult FillBias<FullyConnectedAttributes>(Node* node);
-
 class AddBias : public NodeTransformation {
  public:
   TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final {
     if (node->operation.type == ToString(OperationType::CONVOLUTION_2D)) {
-      return FillBias<Convolution2DAttributes>(node);
+      auto& attr =
+          absl::any_cast<Convolution2DAttributes&>(node->operation.attributes);
+      return FillBias(attr.weights.shape.o, &attr.bias);
     }
     if (node->operation.type ==
         ToString(OperationType::CONVOLUTION_TRANSPOSED)) {
-      return FillBias<ConvolutionTransposedAttributes>(node);
+      auto& attr = absl::any_cast<ConvolutionTransposedAttributes&>(
+          node->operation.attributes);
+      return FillBias(attr.weights.shape.o, &attr.bias);
     }
     if (node->operation.type ==
         ToString(OperationType::DEPTHWISE_CONVOLUTION)) {
-      return FillBias<DepthwiseConvolution2DAttributes>(node);
+      auto& attr = absl::any_cast<DepthwiseConvolution2DAttributes&>(
+          node->operation.attributes);
+      return FillBias(attr.weights.shape.o * attr.weights.shape.i, &attr.bias);
     }
     if (node->operation.type == ToString(OperationType::FULLY_CONNECTED)) {
-      return FillBias<FullyConnectedAttributes>(node);
+      auto& attr =
+          absl::any_cast<FullyConnectedAttributes&>(node->operation.attributes);
+      return FillBias(attr.weights.shape.o, &attr.bias);
     }
     return {TransformStatus::SKIPPED, ""};
   }