From 01cfc8a8a3d6176b1f028886087de4eaaa64ce2f Mon Sep 17 00:00:00 2001 From: Raman Sarokin Date: Tue, 16 Jun 2020 17:07:23 -0700 Subject: [PATCH] Resize converted to new style. PiperOrigin-RevId: 316787130 Change-Id: I67db63fa6eaec2bccc87031f2e202da65a2ce439 --- .../lite/delegates/gpu/cl/kernels/resize.cc | 287 +++++++++--------- 1 file changed, 147 insertions(+), 140 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/resize.cc b/tensorflow/lite/delegates/gpu/cl/kernels/resize.cc index 5d578fe6e09..6aa2d1d2570 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/resize.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/resize.cc @@ -25,168 +25,166 @@ namespace gpu { namespace cl { namespace { -std::string GetResizeCode( - const OperationDef& op_def, SamplingType sampling_type, - bool half_pixel_centers, - const std::vector<ElementwiseOperation*>& linked_operations) { - TensorCodeGenerator src_tensor( - "src_data", WHSPoint{"src_size.x", "src_size.y", "src_size.z"}, - op_def.src_tensors[0]); - TensorCodeGenerator dst_tensor( - "dst_data", WHSPoint{"dst_size.x", "dst_size.y", "dst_size.z"}, - op_def.dst_tensors[0]); +std::string GetResizeCode(const OperationDef& op_def, + SamplingType sampling_type, bool half_pixel_centers, + Arguments* args) { + auto src_desc = absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]); + if (op_def.IsBatchSupported()) { + src_desc->SetStateVar("BatchedWidth", "true"); + } + args->AddObjectRef("src_tensor", AccessType::READ, std::move(src_desc)); + auto dst_desc = absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]); + if (op_def.IsBatchSupported()) { + dst_desc->SetStateVar("BatchedWidth", "true"); + } + args->AddObjectRef("dst_tensor", AccessType::WRITE, std::move(dst_desc)); + args->AddInt("border_x"); + args->AddInt("border_y"); + args->AddFloat("scale_factor_x"); + args->AddFloat("scale_factor_y"); std::string c = GetCommonDefines(op_def.precision); c += "__kernel void main_function(\n"; - c += src_tensor.GetDeclaration(AccessType::READ); - c += 
GetArgsDeclaration(linked_operations); - c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n"; - c += " int4 src_size, \n"; - c += " int4 dst_size, \n"; - c += " int2 border, \n"; - c += " float2 scale_factor \n"; - c += ") {\n"; + c += "$0) {\n"; c += " int Y = get_global_id(1);\n"; c += " int Z = get_global_id(2);\n"; if (op_def.IsBatchSupported()) { c += " int linear_id = get_global_id(0);\n"; - c += " int X = linear_id / dst_size.w;\n"; - c += " int B = linear_id % dst_size.w;\n"; - c += " if (get_global_id(0) >= dst_size.x || Y >= dst_size.y || Z >= " - "dst_size.z) return;\n"; + c += " int X = linear_id / args.dst_tensor.Batch();\n"; + c += " int B = linear_id % args.dst_tensor.Batch();\n"; + c += " if (linear_id >= args.dst_tensor.Width() || Y >= " + "args.dst_tensor.Height() || Z >= args.dst_tensor.Slices()) return;\n"; } else { c += " int X = get_global_id(0);\n"; - c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) " - "return;\n"; + c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() " + "|| Z >= args.dst_tensor.Slices()) return;\n"; } if (sampling_type == SamplingType::NEAREST) { - c += " int2 coord = (int2)(X * scale_factor.x, Y * scale_factor.y);\n"; + c += " int2 coord = (int2)(X * args.scale_factor_x, Y * " + "args.scale_factor_y);\n"; if (op_def.IsBatchSupported()) { - c += " coord.x = coord.x * src_size.w + B;\n"; - c += " X = X * src_size.w + B;\n"; + c += " coord.x = coord.x * args.src_tensor.Batch() + B;\n"; + c += " X = X * args.src_tensor.Batch() + B;\n"; } - c += " FLT4 r0 = " + src_tensor.ReadWHS("coord.x", "coord.y", "Z") + ";\n"; + c += " FLT4 r0 = args.src_tensor.Read(coord.x, coord.y, Z);\n"; } else { if (half_pixel_centers) { - c += " float2 f_coords = ((float2)(X, Y) + 0.5f) * scale_factor - " + c += " float2 f_coords = ((float2)(X, Y) + 0.5f) * " + "(float2)(args.scale_factor_x, args.scale_factor_y) - " "0.5f;\n"; } else { - c += " float2 f_coords = (float2)(X, Y) * scale_factor;\n"; + c += " 
float2 f_coords = (float2)(X, Y) * (float2)(args.scale_factor_x, " + "args.scale_factor_y);\n"; } c += " float2 f_coords_floor = floor(f_coords);\n"; c += " int2 coords_floor = (int2)(f_coords_floor.x, f_coords_floor.y);\n"; c += " int4 st;\n"; c += " st.xy = max(coords_floor, (int2)(0, 0));\n"; - c += " st.zw = min(coords_floor + (int2)(1, 1), border);\n"; + c += " st.zw = min(coords_floor + (int2)(1, 1), (int2)(args.border_x, " + "args.border_y));\n"; c += " float2 t = f_coords - f_coords_floor;\n"; if (op_def.IsBatchSupported()) { - c += " st.x = st.x * src_size.w + B;\n"; - c += " st.z = st.z * src_size.w + B;\n"; - c += " X = X * src_size.w + B;\n"; + c += " st.x = st.x * args.src_tensor.Batch() + B;\n"; + c += " st.z = st.z * args.src_tensor.Batch() + B;\n"; + c += " X = X * args.src_tensor.Batch() + B;\n"; } - c += " float4 src0 = " + src_tensor.ReadAsFloatWHS("st.x", "st.y", "Z") + - ";\n"; - c += " float4 src1 = " + src_tensor.ReadAsFloatWHS("st.z", "st.y", "Z") + - ";\n"; - c += " float4 src2 = " + src_tensor.ReadAsFloatWHS("st.x", "st.w", "Z") + - ";\n"; - c += " float4 src3 = " + src_tensor.ReadAsFloatWHS("st.z", "st.w", "Z") + - ";\n"; + c += " float4 src0 = args.src_tensor.Read<float>(st.x, st.y, Z);\n"; + c += " float4 src1 = args.src_tensor.Read<float>(st.z, st.y, Z);\n"; + c += " float4 src2 = args.src_tensor.Read<float>(st.x, st.w, Z);\n"; + c += " float4 src3 = args.src_tensor.Read<float>(st.z, st.w, Z);\n"; c += " FLT4 r0 = TO_FLT4(mix(mix(src0, src1, t.x), mix(src2, src3, t.x), " "t.y));\n"; } - const LinkingContext context{"r0", "X", "Y", "Z"}; - c += PostProcess(linked_operations, context); - c += " " + dst_tensor.WriteWHS("r0", "X", "Y", "Z"); + c += " args.dst_tensor.Write(r0, X, Y, Z);\n"; c += "}\n"; return c; } -std::string GetResize3DCode( - const OperationDef& op_def, SamplingType sampling_type, - const std::vector<ElementwiseOperation*>& linked_operations) { - TensorCodeGenerator src_tensor( - "src_data", - WHDSPoint{"src_size.x", "src_size.y", "src_size.z", "src_size.w"}, - 
op_def.src_tensors[0]); - TensorCodeGenerator dst_tensor( - "dst_data", - WHDSPoint{"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w"}, - op_def.dst_tensors[0]); +std::string GetResize3DCode(const OperationDef& op_def, + SamplingType sampling_type, Arguments* args) { + auto src_desc = absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]); + if (op_def.IsBatchSupported()) { + src_desc->SetStateVar("BatchedWidth", "true"); + } + args->AddObjectRef("src_tensor", AccessType::READ, std::move(src_desc)); + auto dst_desc = absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]); + if (op_def.IsBatchSupported()) { + dst_desc->SetStateVar("BatchedWidth", "true"); + } + args->AddObjectRef("dst_tensor", AccessType::WRITE, std::move(dst_desc)); + args->AddInt("border_x"); + args->AddInt("border_y"); + args->AddInt("border_z"); + args->AddFloat("scale_factor_x"); + args->AddFloat("scale_factor_y"); + args->AddFloat("scale_factor_z"); std::string c = GetCommonDefines(op_def.precision); c += "__kernel void main_function(\n"; - c += src_tensor.GetDeclaration(AccessType::READ); - c += GetArgsDeclaration(linked_operations); - c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n"; - c += " int4 src_size, \n"; - c += " int4 dst_size, \n"; - if (op_def.IsBatchSupported()) { - c += " int batch_size, \n"; - } - c += " int4 border, \n"; - c += " float4 scale_factor \n"; - c += ") {\n"; + c += "$0) {\n"; c += " int Y = get_global_id(1);\n"; c += " int linear_id_z = get_global_id(2);\n"; - c += " int S = linear_id_z % dst_size.w;\n"; - c += " int Z = linear_id_z / dst_size.w;\n"; + c += " int S = linear_id_z % args.dst_tensor.Slices();\n"; + c += " int Z = linear_id_z / args.dst_tensor.Slices();\n"; if (op_def.IsBatchSupported()) { c += " int linear_id = get_global_id(0);\n"; - c += " int X = linear_id / batch_size;\n"; - c += " int B = linear_id % batch_size;\n"; - c += " if (linear_id >= dst_size.x || Y >= dst_size.y || Z >= " - "dst_size.z) return;\n"; + c += " int X = linear_id / 
args.dst_tensor.Batch();\n"; + c += " int B = linear_id % args.dst_tensor.Batch();\n"; + c += " if (linear_id >= args.dst_tensor.Width() || Y >= " + "args.dst_tensor.Height() || Z >= args.dst_tensor.Depth()) return;\n"; } else { c += " int X = get_global_id(0);\n"; - c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) " - "return;\n"; + c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() " + "|| Z >= args.dst_tensor.Depth()) return;\n"; } if (sampling_type == SamplingType::NEAREST) { - c += " int4 coord = (int4)(X * scale_factor.x, Y * scale_factor.y, Z * " - "scale_factor.z, 0);\n"; + c += " int4 coord = (int4)(X * args.scale_factor_x, Y * " + "args.scale_factor_y, Z * " + "args.scale_factor_z, 0);\n"; if (op_def.IsBatchSupported()) { - c += " coord.x = coord.x * batch_size + B;\n"; - c += " X = X * batch_size + B;\n"; + c += " coord.x = coord.x * args.src_tensor.Batch() + B;\n"; + c += " X = X * args.src_tensor.Batch() + B;\n"; } - c += " FLT4 r0 = " + - src_tensor.ReadWHDS("coord.x", "coord.y", "coord.z", "S") + ";\n"; + c += " FLT4 r0 = args.src_tensor.Read(coord.x, coord.y, coord.z, S);\n"; } else { - c += " float4 f_coords = (float4)(X, Y, Z, 0) * scale_factor;\n"; + c += " float4 f_coords;\n"; + c += " f_coords.x = (float)(X) * args.scale_factor_x;\n"; + c += " f_coords.y = (float)(Y) * args.scale_factor_y;\n"; + c += " f_coords.z = (float)(Z) * args.scale_factor_z;\n"; c += " int4 start = (int4)(f_coords.x, f_coords.y, f_coords.z, 0);\n"; - c += " int4 end = min(start + (int4)(1, 1, 1, 0), border);\n"; + c += " int4 end;\n"; + c += " end.x = min(start.x + 1, args.border_x);\n"; + c += " end.y = min(start.y + 1, args.border_y);\n"; + c += " end.z = min(start.z + 1, args.border_z);\n"; c += " float4 t = f_coords - (float4)(start.x, start.y, start.z, 0.0f);\n"; if (op_def.IsBatchSupported()) { - c += " start.x = start.x * batch_size + B;\n"; - c += " end.x = end.x * batch_size + B;\n"; - c += " X = X * batch_size + B;\n"; + c 
+= " start.x = start.x * args.src_tensor.Batch() + B;\n"; + c += " end.x = end.x * args.src_tensor.Batch() + B;\n"; + c += " X = X * args.src_tensor.Batch() + B;\n"; } - c += " float4 src0 = " + - src_tensor.ReadAsFloatWHDS("start.x", "start.y", "start.z", "S") + - ";\n"; - c += " float4 src1 = " + - src_tensor.ReadAsFloatWHDS("end.x", "start.y", "start.z", "S") + ";\n"; - c += " float4 src2 = " + - src_tensor.ReadAsFloatWHDS("start.x", "end.y", "start.z", "S") + ";\n"; - c += " float4 src3 = " + - src_tensor.ReadAsFloatWHDS("end.x", "end.y", "start.z", "S") + ";\n"; - c += " float4 src4 = " + - src_tensor.ReadAsFloatWHDS("start.x", "start.y", "end.z", "S") + ";\n"; - c += " float4 src5 = " + - src_tensor.ReadAsFloatWHDS("end.x", "start.y", "end.z", "S") + ";\n"; - c += " float4 src6 = " + - src_tensor.ReadAsFloatWHDS("start.x", "end.y", "end.z", "S") + ";\n"; - c += " float4 src7 = " + - src_tensor.ReadAsFloatWHDS("end.x", "end.y", "end.z", "S") + ";\n"; + c += " float4 src0 = args.src_tensor.Read<float>(start.x, start.y, " + "start.z, S);\n"; + c += " float4 src1 = args.src_tensor.Read<float>(end.x, start.y, start.z, " + "S);\n"; + c += " float4 src2 = args.src_tensor.Read<float>(start.x, end.y, start.z, " + "S);\n"; + c += " float4 src3 = args.src_tensor.Read<float>(end.x, end.y, start.z, " + "S);\n"; + c += " float4 src4 = args.src_tensor.Read<float>(start.x, start.y, end.z, " + "S);\n"; + c += " float4 src5 = args.src_tensor.Read<float>(end.x, start.y, end.z, " + "S);\n"; + c += " float4 src6 = args.src_tensor.Read<float>(start.x, end.y, end.z, " + "S);\n"; + c += " float4 src7 = args.src_tensor.Read<float>(end.x, end.y, end.z, " + "S);\n"; c += " float4 t0 = mix(mix(src0, src1, t.x), mix(src2, src3, t.x), t.y);\n"; c += " float4 t1 = mix(mix(src4, src5, t.x), mix(src6, src7, t.x), t.y);\n"; c += " FLT4 r0 = TO_FLT4(mix(t0, t1, t.z));\n"; } - const LinkingContext context{"r0", "X", "Y", "S"}; - c += PostProcess(linked_operations, context); - c += " " + dst_tensor.WriteWHDS("r0", "X", "Y", "Z", "S"); + c += " 
args.dst_tensor.Write(r0, X, Y, Z, S);\n"; c += "}\n"; return c; } @@ -210,27 +208,32 @@ Resize& Resize::operator=(Resize&& operation) { } absl::Status Resize::Compile(const CreationContext& creation_context) { - const auto code = GetResizeCode(definition_, attr_.type, - attr_.half_pixel_centers, linked_operations_); + std::string code = + GetResizeCode(definition_, attr_.type, attr_.half_pixel_centers, &args_); + std::string element_wise_code; + RETURN_IF_ERROR( + MergeOperations(linked_operations_, &args_, &element_wise_code)); + RETURN_IF_ERROR(args_.TransformToCLCode(creation_context.device->GetInfo(), + {{"dst_tensor", element_wise_code}}, + &code)); return creation_context.cache->GetOrCreateCLKernel( code, "main_function", *creation_context.context, *creation_context.device, &kernel_); } absl::Status Resize::BindArguments() { - kernel_.ResetBindingCounter(); - RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); - RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); - RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting())); - RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHSB())); - RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB())); - RETURN_IF_ERROR( - kernel_.SetBytesAuto(int2(src_[0]->Width() - 1, src_[0]->Height() - 1))); - float2 scale_factor = - float2(CalculateResizeScale(src_[0]->Width(), dst_[0]->Width(), attr_), - CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_)); - RETURN_IF_ERROR(kernel_.SetBytesAuto(scale_factor)); - return absl::OkStatus(); + RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0])); + RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0])); + RETURN_IF_ERROR(args_.SetInt("border_x", src_[0]->Width() - 1)); + RETURN_IF_ERROR(args_.SetInt("border_y", src_[0]->Height() - 1)); + RETURN_IF_ERROR(args_.SetFloat( + "scale_factor_x", + CalculateResizeScale(src_[0]->Width(), dst_[0]->Width(), attr_))); + RETURN_IF_ERROR(args_.SetFloat( + "scale_factor_y", + 
CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_))); + RETURN_IF_ERROR(SetArguments(linked_operations_, &args_)); + return args_.Bind(kernel_.kernel()); } int3 Resize::GetGridSize() const { @@ -272,31 +275,35 @@ Resize3D& Resize3D::operator=(Resize3D&& operation) { } absl::Status Resize3D::Compile(const CreationContext& creation_context) { - const auto code = - GetResize3DCode(definition_, attr_.type, linked_operations_); + std::string code = GetResize3DCode(definition_, attr_.type, &args_); + std::string element_wise_code; + RETURN_IF_ERROR( + MergeOperations(linked_operations_, &args_, &element_wise_code)); + RETURN_IF_ERROR(args_.TransformToCLCode(creation_context.device->GetInfo(), + {{"dst_tensor", element_wise_code}}, + &code)); return creation_context.cache->GetOrCreateCLKernel( code, "main_function", *creation_context.context, *creation_context.device, &kernel_); } absl::Status Resize3D::BindArguments() { - kernel_.ResetBindingCounter(); - RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); - RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); - RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting())); - RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHDS())); - RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHDS())); - if (definition_.IsBatchSupported()) { - RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->Batch())); - } - RETURN_IF_ERROR(kernel_.SetBytesAuto(int4( - src_[0]->Width() - 1, src_[0]->Height() - 1, src_[0]->Depth() - 1, 0))); - float4 scale_factor = float4( - CalculateResizeScale(src_[0]->Width(), dst_[0]->Width(), attr_), - CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_), - CalculateResizeScale(src_[0]->Depth(), dst_[0]->Depth(), attr_), 1.0f); - RETURN_IF_ERROR(kernel_.SetBytesAuto(scale_factor)); - return absl::OkStatus(); + RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0])); + RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0])); + 
RETURN_IF_ERROR(args_.SetInt("border_x", src_[0]->Width() - 1)); + RETURN_IF_ERROR(args_.SetInt("border_y", src_[0]->Height() - 1)); + RETURN_IF_ERROR(args_.SetInt("border_z", src_[0]->Depth() - 1)); + RETURN_IF_ERROR(args_.SetFloat( + "scale_factor_x", + CalculateResizeScale(src_[0]->Width(), dst_[0]->Width(), attr_))); + RETURN_IF_ERROR(args_.SetFloat( + "scale_factor_y", + CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_))); + RETURN_IF_ERROR(args_.SetFloat( + "scale_factor_z", + CalculateResizeScale(src_[0]->Depth(), dst_[0]->Depth(), attr_))); + RETURN_IF_ERROR(SetArguments(linked_operations_, &args_)); + return args_.Bind(kernel_.kernel()); } int3 Resize3D::GetGridSize() const {