From 897e3c0ecad3b45f5e96615173e7511619eebc93 Mon Sep 17 00:00:00 2001 From: Raman Sarokin <sorokin@google.com> Date: Tue, 16 Jun 2020 16:39:31 -0700 Subject: [PATCH] Softmax1x1 converted to new style. PiperOrigin-RevId: 316782370 Change-Id: I1f7761c0520d72876f352c9f156341b349b90cbe --- .../delegates/gpu/cl/kernels/softmax1x1.cc | 87 ++++++++++--------- 1 file changed, 45 insertions(+), 42 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc index 192bee771d6..fcfe4a1810c 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc @@ -25,47 +25,45 @@ namespace gpu { namespace cl { namespace { -std::string GetSoftmaxKernelCode( - const OperationDef& op_def, - const std::vector<ElementwiseOperation*>& linked_operations) { - TensorCodeGenerator src_tensor("src_data", - WHSBPoint{"tensor_size.x", "tensor_size.y", - "tensor_size.z", "tensor_size.w"}, - op_def.src_tensors[0]); - TensorCodeGenerator dst_tensor("dst_data", - WHSBPoint{"tensor_size.x", "tensor_size.y", - "tensor_size.z", "tensor_size.w"}, - op_def.dst_tensors[0]); +std::string GetSoftmaxKernelCode(const OperationDef& op_def, Arguments* args) { + args->AddObjectRef( + "src_tensor", AccessType::READ, + absl::make_unique<TensorDescriptor>(op_def.src_tensors[0])); + args->AddObjectRef( + "dst_tensor", AccessType::WRITE, + absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0])); + args->AddFloat("mask_x"); + args->AddFloat("mask_y"); + args->AddFloat("mask_z"); + args->AddFloat("mask_w"); + args->AddInt("slices_x32"); - const std::string batch_id = op_def.IsBatchSupported() ? "batch_id" : ""; std::string c = GetCommonDefines(op_def.precision); c += "__kernel void main_function(\n"; - c += src_tensor.GetDeclaration(AccessType::READ); - c += GetArgsDeclaration(linked_operations); - c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n"; - c += " int4 tensor_size,\n"; - c += " int2 size,\n"; - c += " float4 mask\n"; - c += ") {\n"; + c += "$0) {\n"; if (op_def.IsBatchSupported()) { c += " int batch_id = get_global_id(1);\n"; - c += " if (batch_id >= tensor_size.w) return;\n"; + c += " if (batch_id >= args.dst_tensor.Batch()) return;\n"; + c += " args.dst_tensor.SetBatchRef(batch_id);\n"; + c += " args.src_tensor.SetBatchRef(batch_id);\n"; } + c += " float4 mask = (float4)(args.mask_x, args.mask_y, args.mask_z, " + "args.mask_w);\n"; c += " int offset = 0;\n"; c += " float sum = 0.0f;\n"; c += " int s = 0;\n"; c += " int tid = get_local_id(0);\n"; c += " do {\n"; c += " int z = offset + tid;\n"; - c += " if (z < size.x) {\n"; - c += " float4 mask_temp = z == size.x - 1 ? mask : (float4)(1.0f);\n"; - c += " float4 src = " + - src_tensor.ReadAsFloatWHSB("0", "0", "z", batch_id) + ";\n"; + c += " if (z < args.dst_tensor.Slices()) {\n"; + c += " float4 mask_temp = z == args.dst_tensor.Slices() - 1 ? mask : " + "(float4)(1.0f);\n"; + c += " float4 src = args.src_tensor.Read<float>(0, 0, z);\n"; c += " sum += dot(mask_temp, exp(src));\n"; c += " offset += 32;\n"; c += " }\n"; c += " s++;\n"; - c += " } while (s < size.y);\n"; + c += " } while (s < args.slices_x32);\n"; c += "\n"; c += " __local float4 tmp[8];\n"; c += " __local float* tmpx1 = (__local float*)tmp;\n"; @@ -89,16 +87,14 @@ std::string GetSoftmaxKernelCode( c += " s = 0;\n"; c += " do {\n"; c += " int z = offset + tid;\n"; - c += " if (z < size.x) {\n"; - c += " FLT4 res = TO_FLT4(exp(" + - src_tensor.ReadAsFloatWHSB("0", "0", "z", batch_id) + ")*sum);\n"; - const LinkingContext context{"res", "0", "0", "z"}; - c += PostProcess(linked_operations, context); - c += " " + dst_tensor.WriteWHSB("res", "0", "0", "z", batch_id); + c += " if (z < args.dst_tensor.Slices()) {\n"; + c += " FLT4 res = TO_FLT4(exp(args.src_tensor.Read<float>(0, 0, " + "z))*sum);\n"; + c += " args.dst_tensor.Write(res, 0, 0, z);\n"; c += " offset += 32;\n"; c += " }\n"; c += " s++;\n"; - c += " } while (s < size.y);\n"; + c += " } while (s < args.slices_x32);\n"; c += "}\n"; return c; } @@ -116,23 +112,30 @@ Softmax1x1& Softmax1x1::operator=(Softmax1x1&& kernel) { } absl::Status Softmax1x1::Compile(const CreationContext& creation_context) { - const auto code = GetSoftmaxKernelCode(definition_, linked_operations_); + std::string code = GetSoftmaxKernelCode(definition_, &args_); + std::string element_wise_code; + RETURN_IF_ERROR( + MergeOperations(linked_operations_, &args_, &element_wise_code)); + RETURN_IF_ERROR(args_.TransformToCLCode(creation_context.device->GetInfo(), + {{"dst_tensor", element_wise_code}}, + &code)); return creation_context.cache->GetOrCreateCLKernel( code, "main_function", *creation_context.context, *creation_context.device, &kernel_); } absl::Status Softmax1x1::AddToQueue(CLCommandQueue* queue) { - kernel_.ResetBindingCounter(); - RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); - RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); - RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting())); - RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB())); - const int depth = src_[0]->Slices(); - RETURN_IF_ERROR(kernel_.SetBytesAuto(int2(depth, DivideRoundUp(depth, 32)))); + RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0])); + RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0])); + float4 mask = GetMaskForLastPlane(src_[0]->Channels()); + RETURN_IF_ERROR(args_.SetFloat("mask_x", mask.x)); + RETURN_IF_ERROR(args_.SetFloat("mask_y", mask.y)); + RETURN_IF_ERROR(args_.SetFloat("mask_z", mask.z)); + RETURN_IF_ERROR(args_.SetFloat("mask_w", mask.w)); RETURN_IF_ERROR( - kernel_.SetBytesAuto(GetMaskForLastPlane(src_[0]->Channels()))); - + args_.SetInt("slices_x32", DivideRoundUp(src_[0]->Slices(), 32))); + RETURN_IF_ERROR(SetArguments(linked_operations_, &args_)); + RETURN_IF_ERROR(args_.Bind(kernel_.kernel())); return queue->DispatchImplicit(kernel_, {32, dst_[0]->Batch(), 1}, {32, 1, 1}); }