From 897e3c0ecad3b45f5e96615173e7511619eebc93 Mon Sep 17 00:00:00 2001
From: Raman Sarokin <sorokin@google.com>
Date: Tue, 16 Jun 2020 16:39:31 -0700
Subject: [PATCH] Convert Softmax1x1 to the new argument style.

PiperOrigin-RevId: 316782370
Change-Id: I1f7761c0520d72876f352c9f156341b349b90cbe
---
 .../delegates/gpu/cl/kernels/softmax1x1.cc    | 87 ++++++++++---------
 1 file changed, 45 insertions(+), 42 deletions(-)

diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc
index 192bee771d6..fcfe4a1810c 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc
@@ -25,47 +25,45 @@ namespace gpu {
 namespace cl {
 namespace {
 
-std::string GetSoftmaxKernelCode(
-    const OperationDef& op_def,
-    const std::vector<ElementwiseOperation*>& linked_operations) {
-  TensorCodeGenerator src_tensor("src_data",
-                                 WHSBPoint{"tensor_size.x", "tensor_size.y",
-                                           "tensor_size.z", "tensor_size.w"},
-                                 op_def.src_tensors[0]);
-  TensorCodeGenerator dst_tensor("dst_data",
-                                 WHSBPoint{"tensor_size.x", "tensor_size.y",
-                                           "tensor_size.z", "tensor_size.w"},
-                                 op_def.dst_tensors[0]);
+std::string GetSoftmaxKernelCode(const OperationDef& op_def, Arguments* args) {
+  args->AddObjectRef(
+      "src_tensor", AccessType::READ,
+      absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]));
+  args->AddObjectRef(
+      "dst_tensor", AccessType::WRITE,
+      absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
+  args->AddFloat("mask_x");
+  args->AddFloat("mask_y");
+  args->AddFloat("mask_z");
+  args->AddFloat("mask_w");
+  args->AddInt("slices_x32");
 
-  const std::string batch_id = op_def.IsBatchSupported() ? "batch_id" : "";
   std::string c = GetCommonDefines(op_def.precision);
   c += "__kernel void main_function(\n";
-  c += src_tensor.GetDeclaration(AccessType::READ);
-  c += GetArgsDeclaration(linked_operations);
-  c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
-  c += "    int4 tensor_size,\n";
-  c += "    int2 size,\n";
-  c += "    float4 mask\n";
-  c += ") {\n";
+  c += "$0) {\n";
   if (op_def.IsBatchSupported()) {
     c += "  int batch_id = get_global_id(1);\n";
-    c += "  if (batch_id >= tensor_size.w) return;\n";
+    c += "  if (batch_id >= args.dst_tensor.Batch()) return;\n";
+    c += "  args.dst_tensor.SetBatchRef(batch_id);\n";
+    c += "  args.src_tensor.SetBatchRef(batch_id);\n";
   }
+  c += "  float4 mask = (float4)(args.mask_x, args.mask_y, args.mask_z, "
+       "args.mask_w);\n";
   c += "  int offset = 0;\n";
   c += "  float sum = 0.0f;\n";
   c += "  int s = 0;\n";
   c += "  int tid = get_local_id(0);\n";
   c += "  do {\n";
   c += "    int z = offset + tid;\n";
-  c += "    if (z < size.x) {\n";
-  c += "      float4 mask_temp = z == size.x - 1 ? mask : (float4)(1.0f);\n";
-  c += "      float4 src = " +
-       src_tensor.ReadAsFloatWHSB("0", "0", "z", batch_id) + ";\n";
+  c += "    if (z < args.dst_tensor.Slices()) {\n";
+  c += "      float4 mask_temp = z == args.dst_tensor.Slices() - 1 ? mask : "
+       "(float4)(1.0f);\n";
+  c += "      float4 src = args.src_tensor.Read<float>(0, 0, z);\n";
   c += "      sum += dot(mask_temp, exp(src));\n";
   c += "      offset += 32;\n";
   c += "    }\n";
   c += "    s++;\n";
-  c += "  } while (s < size.y);\n";
+  c += "  } while (s < args.slices_x32);\n";
   c += "\n";
   c += "  __local float4 tmp[8];\n";
   c += "  __local float* tmpx1 = (__local float*)tmp;\n";
@@ -89,16 +87,14 @@ std::string GetSoftmaxKernelCode(
   c += "  s = 0;\n";
   c += "  do {\n";
   c += "    int z = offset + tid;\n";
-  c += "    if (z < size.x) {\n";
-  c += "      FLT4 res = TO_FLT4(exp(" +
-       src_tensor.ReadAsFloatWHSB("0", "0", "z", batch_id) + ")*sum);\n";
-  const LinkingContext context{"res", "0", "0", "z"};
-  c += PostProcess(linked_operations, context);
-  c += "    " + dst_tensor.WriteWHSB("res", "0", "0", "z", batch_id);
+  c += "    if (z < args.dst_tensor.Slices()) {\n";
+  c += "      FLT4 res = TO_FLT4(exp(args.src_tensor.Read<float>(0, 0, "
+       "z))*sum);\n";
+  c += "      args.dst_tensor.Write(res, 0, 0, z);\n";
   c += "      offset += 32;\n";
   c += "    }\n";
   c += "    s++;\n";
-  c += "  } while (s < size.y);\n";
+  c += "  } while (s < args.slices_x32);\n";
   c += "}\n";
   return c;
 }
@@ -116,23 +112,30 @@ Softmax1x1& Softmax1x1::operator=(Softmax1x1&& kernel) {
 }
 
 absl::Status Softmax1x1::Compile(const CreationContext& creation_context) {
-  const auto code = GetSoftmaxKernelCode(definition_, linked_operations_);
+  std::string code = GetSoftmaxKernelCode(definition_, &args_);
+  std::string element_wise_code;
+  RETURN_IF_ERROR(
+      MergeOperations(linked_operations_, &args_, &element_wise_code));
+  RETURN_IF_ERROR(args_.TransformToCLCode(creation_context.device->GetInfo(),
+                                          {{"dst_tensor", element_wise_code}},
+                                          &code));
   return creation_context.cache->GetOrCreateCLKernel(
       code, "main_function", *creation_context.context,
       *creation_context.device, &kernel_);
 }
 
 absl::Status Softmax1x1::AddToQueue(CLCommandQueue* queue) {
-  kernel_.ResetBindingCounter();
-  RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
-  RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
-  RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
-  RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB()));
-  const int depth = src_[0]->Slices();
-  RETURN_IF_ERROR(kernel_.SetBytesAuto(int2(depth, DivideRoundUp(depth, 32))));
+  RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0]));
+  RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0]));
+  float4 mask = GetMaskForLastPlane(src_[0]->Channels());
+  RETURN_IF_ERROR(args_.SetFloat("mask_x", mask.x));
+  RETURN_IF_ERROR(args_.SetFloat("mask_y", mask.y));
+  RETURN_IF_ERROR(args_.SetFloat("mask_z", mask.z));
+  RETURN_IF_ERROR(args_.SetFloat("mask_w", mask.w));
   RETURN_IF_ERROR(
-      kernel_.SetBytesAuto(GetMaskForLastPlane(src_[0]->Channels())));
-
+      args_.SetInt("slices_x32", DivideRoundUp(src_[0]->Slices(), 32)));
+  RETURN_IF_ERROR(SetArguments(linked_operations_, &args_));
+  RETURN_IF_ERROR(args_.Bind(kernel_.kernel()));
   return queue->DispatchImplicit(kernel_, {32, dst_[0]->Batch(), 1},
                                  {32, 1, 1});
 }