From 403417fdc85650181f894ee0aa0fb85dfa75fea4 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 10 May 2019 10:51:47 -0700
Subject: [PATCH] TFLite GPU Metal: Removed restriction channels_multiplier ==
 1.

PiperOrigin-RevId: 247638516
---
 .../gpu/metal/kernels/depthwise_conv.cc       | 57 +++++++++++++------
 1 file changed, 40 insertions(+), 17 deletions(-)

diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.cc b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.cc
index cacbb3c8ae3..15b46541562 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.cc
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.cc
@@ -468,6 +468,7 @@ std::vector<ComputeTaskDescriptorPtr> DepthWiseConvolution(
     int id, ValueId input_id, ValueId output_id,
     const DepthwiseConvolution2DAttributes& attr,
     const RuntimeOptions& options) {
+  int channels_multiplier = attr.weights.shape.o;
   auto desc = std::make_shared<ComputeTaskDescriptor>();
   desc->id = id;
   desc->is_linkable = false;
@@ -503,10 +504,44 @@ std::vector<ComputeTaskDescriptorPtr> DepthWiseConvolution(
           const bool outside = coords.x < 0 || coords.y < 0 ||
             coords.x >= params.size.x || coords.y >= params.size.y;
           if (outside) continue;
-
-          const int src_layer = gid.z;
-          const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x;
-          sum0 += float4(src_buffer[src_index]) * float4(temp[ky * kernel_x + kx]);
+)";
+  if (channels_multiplier == 1) {
+    shader_source += R"(
+        const int src_layer = gid.z;
+        const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x;
+        const FLT4 src_modified = src_buffer[src_index];
+)";
+  } else if (channels_multiplier == 2) {
+    shader_source += R"(
+        const int src_layer = gid.z / 2;
+        const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x;
+        const FLT4 src = src_buffer[src_index];
+        const FLT2 t0 = gid.z % 2 == 0 ? src.xy : src.zw;
+        const FLT4 src_modified = FLT4(t0.x, t0.x, t0.y, t0.y);
+)";
+  } else if (channels_multiplier == 4) {
+    shader_source += R"(
+        const int src_layer = gid.z / 4;
+        const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x;
+        const FLT4 src = src_buffer[src_index];
+        const FLT t0 = src[gid.z % 4];
+        const FLT4 src_modified = FLT4(t0, t0, t0, t0);
+)";
+  } else {
+    shader_source += R"(
+        const int src_layer = gid.z / params.channel_multiplier.x;
+        const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x;
+        const FLT4 src = src_buffer[src_index];
+        FLT4 src_modified;
+        const int src_layer_offset = (gid.z % params.channel_multiplier.x) * 4;
+        src_modified.x = src[(src_layer_offset + 0) / params.channel_multiplier.x];
+        src_modified.y = src[(src_layer_offset + 1) / params.channel_multiplier.x];
+        src_modified.z = src[(src_layer_offset + 2) / params.channel_multiplier.x];
+        src_modified.w = src[(src_layer_offset + 3) / params.channel_multiplier.x];
+)";
+  }
+  shader_source += R"(
+          sum0 += float4(src_modified * temp[ky * kernel_x + kx]);
         }
       }
       FLT4 res = FLT4(sum0 + float4(biases[gid.z]));
@@ -531,19 +566,7 @@ std::vector<ComputeTaskDescriptorPtr> DepthWiseConvolution(
         return out_shape;
       }};
 
-  const int num_output_channels = attr.weights.shape.i * attr.weights.shape.o;
-  BHWC reordered_dims{1, attr.weights.shape.h, attr.weights.shape.w,
-                      num_output_channels};
-  std::vector<float> filters_reordered(GetElementsSizeForPHWC4(reordered_dims),
-                                       0.0f);
-  if (!ConvertToPHWC4(
-           absl::MakeConstSpan(attr.weights.data.data(),
-                               attr.weights.data.size()),
-           reordered_dims,
-           absl::MakeSpan(filters_reordered.data(), filters_reordered.size()))
-           .ok()) {
-    return {};
-  }
+  std::vector<float> filters_reordered = ConvertToPIOHW4(attr.weights);
   auto filters = options.storage_precision == RuntimeOptions::Precision::FP32
                      ? VectorToUint8Vector(filters_reordered)
                      : VectorFloatToHalf(filters_reordered);