SpaceToDepth converted to new style.
PiperOrigin-RevId: 316775897 Change-Id: I2715cadbf112dffc93b2b45570d2220444af31a4
This commit is contained in:
		
							parent
							
								
									6b126156d9
								
							
						
					
					
						commit
						a6945b9b0f
					
				@ -27,28 +27,33 @@ namespace gpu {
 | 
			
		||||
namespace cl {
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
std::string GetSpaceToDepthCode(
 | 
			
		||||
    const OperationDef& op_def,
 | 
			
		||||
    const std::vector<ElementwiseOperation*>& linked_operations) {
 | 
			
		||||
  TensorCodeGenerator src_tensor(
 | 
			
		||||
      "src_data", WHSPoint{"src_size.x", "src_size.y", "src_size.z"},
 | 
			
		||||
      op_def.src_tensors[0]);
 | 
			
		||||
  TensorCodeGenerator dst_tensor(
 | 
			
		||||
      "dst_data", WHSPoint{"dst_size.x", "dst_size.y", "dst_size.z"},
 | 
			
		||||
      op_def.dst_tensors[0]);
 | 
			
		||||
std::string GetSpaceToDepthCode(const OperationDef& op_def, Arguments* args) {
 | 
			
		||||
  args->AddObjectRef(
 | 
			
		||||
      "src_tensor", AccessType::READ,
 | 
			
		||||
      absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]));
 | 
			
		||||
  args->AddObjectRef(
 | 
			
		||||
      "dst_tensor", AccessType::WRITE,
 | 
			
		||||
      absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
 | 
			
		||||
  args->AddInt("block_size");
 | 
			
		||||
 | 
			
		||||
  std::string c = GetCommonDefines(op_def.precision);
 | 
			
		||||
  c += "__kernel void main_function(\n";
 | 
			
		||||
  c += src_tensor.GetDeclaration(AccessType::READ);
 | 
			
		||||
  c += GetArgsDeclaration(linked_operations);
 | 
			
		||||
  c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
 | 
			
		||||
  c += "    int4 src_size,\n";
 | 
			
		||||
  c += "    int4 dst_size,\n";
 | 
			
		||||
  c += "    int src_channels,\n";
 | 
			
		||||
  c += "    int block_size) {\n";
 | 
			
		||||
  c += "  int X = get_global_id(0);\n";
 | 
			
		||||
  c += "$0) {\n";
 | 
			
		||||
  if (op_def.IsBatchSupported()) {
 | 
			
		||||
    c += "  int linear_id = get_global_id(0);\n";
 | 
			
		||||
    c += "  int X = linear_id / args.dst_tensor.Batch();\n";
 | 
			
		||||
    c += "  int B = linear_id % args.dst_tensor.Batch();\n";
 | 
			
		||||
    c += "  args.dst_tensor.SetBatchRef(B);\n";
 | 
			
		||||
    c += "  args.src_tensor.SetBatchRef(B);\n";
 | 
			
		||||
  } else {
 | 
			
		||||
    c += "  int X = get_global_id(0);\n";
 | 
			
		||||
  }
 | 
			
		||||
  c += "  int Y = get_global_id(1);\n";
 | 
			
		||||
  c += "  int Z = get_global_id(2);\n";
 | 
			
		||||
  c += "  if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;\n";
 | 
			
		||||
  c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
 | 
			
		||||
       "Z >= args.dst_tensor.Slices()) { \n";
 | 
			
		||||
  c += "    return; \n";
 | 
			
		||||
  c += "  } \n";
 | 
			
		||||
  c += "  FLT tmp[4];\n";
 | 
			
		||||
  c += "  tmp[0] = (FLT)(0.0f);\n";
 | 
			
		||||
  c += "  tmp[1] = (FLT)(0.0f);\n";
 | 
			
		||||
@ -56,19 +61,17 @@ std::string GetSpaceToDepthCode(
 | 
			
		||||
  c += "  tmp[3] = (FLT)(0.0f);\n";
 | 
			
		||||
  c += "  for (int i = 0; i < 4; ++i) {\n";
 | 
			
		||||
  c += "    int dst_c = 4 * Z + i;\n";
 | 
			
		||||
  c += "    int block_id = dst_c / src_channels;\n";
 | 
			
		||||
  c += "    int src_x = X * block_size + block_id % block_size;\n";
 | 
			
		||||
  c += "    int src_y = Y * block_size + block_id / block_size;\n";
 | 
			
		||||
  c += "    int src_c = dst_c % src_channels;\n";
 | 
			
		||||
  c += "    int block_id = dst_c / args.src_tensor.Channels();\n";
 | 
			
		||||
  c += "    int src_x = X * args.block_size + block_id % args.block_size;\n";
 | 
			
		||||
  c += "    int src_y = Y * args.block_size + block_id / args.block_size;\n";
 | 
			
		||||
  c += "    int src_c = dst_c % args.src_tensor.Channels();\n";
 | 
			
		||||
  c += "    int src_z = src_c / 4;\n";
 | 
			
		||||
  c += "    FLT4 t = " + src_tensor.ReadWHS("src_x", "src_y", "src_z") + ";\n";
 | 
			
		||||
  c += "    FLT4 t =  args.src_tensor.Read(src_x, src_y, src_z);\n";
 | 
			
		||||
  c += "    FLT t_ar[4] = {t.x, t.y, t.z, t.w};\n";
 | 
			
		||||
  c += "    tmp[i] = t_ar[src_c % 4];\n";
 | 
			
		||||
  c += "  }\n";
 | 
			
		||||
  c += "  FLT4 result = (FLT4)(tmp[0], tmp[1], tmp[2], tmp[3]);\n";
 | 
			
		||||
  const LinkingContext context{"result", "X", "Y", "Z"};
 | 
			
		||||
  c += PostProcess(linked_operations, context);
 | 
			
		||||
  c += "  " + dst_tensor.WriteWHS("result", "X", "Y", "Z");
 | 
			
		||||
  c += "  args.dst_tensor.Write(result, X, Y, Z);\n";
 | 
			
		||||
  c += "}\n";
 | 
			
		||||
  return c;
 | 
			
		||||
}
 | 
			
		||||
@ -92,21 +95,24 @@ SpaceToDepth& SpaceToDepth::operator=(SpaceToDepth&& operation) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
absl::Status SpaceToDepth::Compile(const CreationContext& creation_context) {
 | 
			
		||||
  const auto code = GetSpaceToDepthCode(definition_, linked_operations_);
 | 
			
		||||
  std::string code = GetSpaceToDepthCode(definition_, &args_);
 | 
			
		||||
  std::string element_wise_code;
 | 
			
		||||
  RETURN_IF_ERROR(
 | 
			
		||||
      MergeOperations(linked_operations_, &args_, &element_wise_code));
 | 
			
		||||
  RETURN_IF_ERROR(args_.TransformToCLCode(creation_context.device->GetInfo(),
 | 
			
		||||
                                          {{"dst_tensor", element_wise_code}},
 | 
			
		||||
                                          &code));
 | 
			
		||||
  return creation_context.cache->GetOrCreateCLKernel(
 | 
			
		||||
      code, "main_function", *creation_context.context,
 | 
			
		||||
      *creation_context.device, &kernel_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
absl::Status SpaceToDepth::BindArguments() {
 | 
			
		||||
  kernel_.ResetBindingCounter();
 | 
			
		||||
  RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
 | 
			
		||||
  RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
 | 
			
		||||
  RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
 | 
			
		||||
  RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB()));
 | 
			
		||||
  RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
 | 
			
		||||
  RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->Channels()));
 | 
			
		||||
  return kernel_.SetBytesAuto(attr_.block_size);
 | 
			
		||||
  RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0]));
 | 
			
		||||
  RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0]));
 | 
			
		||||
  RETURN_IF_ERROR(args_.SetInt("block_size", attr_.block_size));
 | 
			
		||||
  RETURN_IF_ERROR(SetArguments(linked_operations_, &args_));
 | 
			
		||||
  return args_.Bind(kernel_.kernel());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int3 SpaceToDepth::GetGridSize() const {
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user