Batch support for MaxUnpooling.
PiperOrigin-RevId: 272992496
This commit is contained in:
		
							parent
							
								
									a59e1600e6
								
							
						
					
					
						commit
						8c650bc738
					
				@ -28,65 +28,79 @@ namespace {
 | 
			
		||||
std::string GetMaxUnoolingKernelCode(
 | 
			
		||||
    const OperationDef& op_def, const CLDevice& device,
 | 
			
		||||
    const std::vector<ElementwiseOperation*>& linked_operations) {
 | 
			
		||||
  TensorCodeGenerator src("src_data", "src_size", op_def.src_tensors[0]);
 | 
			
		||||
  TensorCodeGenerator src_ind("src_data_indices", "src_size",
 | 
			
		||||
  TensorCodeGenerator src("src_data",
 | 
			
		||||
                          {"src_size.x", "src_size.y", "src_size.z"},
 | 
			
		||||
                          op_def.src_tensors[0]);
 | 
			
		||||
  TensorCodeGenerator src_ind("src_data_indices",
 | 
			
		||||
                              {"src_size.x", "src_size.y", "src_size.z"},
 | 
			
		||||
                              op_def.src_tensors[1]);
 | 
			
		||||
  TensorCodeGenerator dst("dst_data", "dst_size", op_def.dst_tensors[0]);
 | 
			
		||||
  TensorCodeGenerator dst("dst_data",
 | 
			
		||||
                          {"dst_size.x", "dst_size.y", "dst_size.z"},
 | 
			
		||||
                          op_def.dst_tensors[0]);
 | 
			
		||||
 | 
			
		||||
  const auto address_mode = GetFastestZeroMode(device);
 | 
			
		||||
 | 
			
		||||
  std::string code = GetCommonDefines(op_def.precision);
 | 
			
		||||
  std::string c = GetCommonDefines(op_def.precision);
 | 
			
		||||
 | 
			
		||||
  code += "__kernel void main_function(\n";
 | 
			
		||||
  code += src.GetDeclaration(AccessType::READ) + ",\n";
 | 
			
		||||
  code += src_ind.GetDeclaration(AccessType::READ);
 | 
			
		||||
  code += GetArgsDeclaration(linked_operations);
 | 
			
		||||
  code += dst.GetDeclaration(AccessType::WRITE) + ",\n";
 | 
			
		||||
  code += "    int4 src_size,      \n";
 | 
			
		||||
  code += "    int4 dst_size,      \n";
 | 
			
		||||
  code += "    int2 kernel_size,   \n";
 | 
			
		||||
  code += "    int2 padding,       \n";
 | 
			
		||||
  code += "    int2 stride         \n";
 | 
			
		||||
  code += ") {\n";
 | 
			
		||||
  code += "  int X = get_global_id(0);\n";
 | 
			
		||||
  code += "  int Y = get_global_id(1);\n";
 | 
			
		||||
  code += "  int Z = get_global_id(2);\n";
 | 
			
		||||
  code +=
 | 
			
		||||
      "  if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.w) return; \n";
 | 
			
		||||
  code += "  int src_x = (X + padding.x) / stride.x;\n";
 | 
			
		||||
  code += "  int src_y = (Y + padding.y) / stride.y;\n";
 | 
			
		||||
  code += "  " + src.GetAddress("src_adr", "src_x", "src_y", "Z") + "\n";
 | 
			
		||||
  if (op_def.src_tensors[0].storage_type == TensorStorageType::BUFFER) {
 | 
			
		||||
    code += "  bool outside = src_x < 0 || src_y < 0 ||";
 | 
			
		||||
    code += "  src_x >= src_size.x || src_y >= src_size.y;\n";
 | 
			
		||||
    code += "  FLT4 src = (FLT4)(0.0f);\n";
 | 
			
		||||
    code += "  int4 ind = (int4)(0);\n";
 | 
			
		||||
    code += "  if (!outside) {\n";
 | 
			
		||||
    code += "    src = " + src.Read("src_adr", TextureAddressMode::DONT_CARE) +
 | 
			
		||||
            ";\n";
 | 
			
		||||
    code += "    ind = convert_int4(" +
 | 
			
		||||
            src_ind.Read("src_adr", TextureAddressMode::DONT_CARE) + ");\n";
 | 
			
		||||
    code += "  }\n";
 | 
			
		||||
  c += "__kernel void main_function(\n";
 | 
			
		||||
  c += src.GetDeclaration(AccessType::READ) + ",\n";
 | 
			
		||||
  c += src_ind.GetDeclaration(AccessType::READ);
 | 
			
		||||
  c += GetArgsDeclaration(linked_operations);
 | 
			
		||||
  c += dst.GetDeclaration(AccessType::WRITE) + ",\n";
 | 
			
		||||
  c += "    int4 src_size,      \n";
 | 
			
		||||
  c += "    int4 dst_size,      \n";
 | 
			
		||||
  c += "    int2 kernel_size,   \n";
 | 
			
		||||
  c += "    int2 padding,       \n";
 | 
			
		||||
  c += "    int2 stride         \n";
 | 
			
		||||
  c += ") {\n";
 | 
			
		||||
  c += "  int X = get_global_id(0);\n";
 | 
			
		||||
  c += "  int Y = get_global_id(1);\n";
 | 
			
		||||
  c += "  int Z = get_global_id(2);\n";
 | 
			
		||||
  c += "  if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;\n";
 | 
			
		||||
  if (op_def.batch_support) {
 | 
			
		||||
    c += "  int B = get_global_id(0) % dst_size.w;\n";
 | 
			
		||||
    c += "  int X0 = get_global_id(0) / dst_size.w;\n";
 | 
			
		||||
    c += "  int src_x0 = (X0 + padding.x) / stride.x;\n";
 | 
			
		||||
    c += "  int src_x = src_x0 * dst_size.w + B;\n";
 | 
			
		||||
  } else {
 | 
			
		||||
    code += "  FLT4 src = " + src.Read("src_adr", address_mode) + ";\n";
 | 
			
		||||
    code += "  int4 ind = convert_int4(" +
 | 
			
		||||
            src_ind.Read("src_adr", address_mode) + ");\n";
 | 
			
		||||
    c += "  int src_x = (X + padding.x) / stride.x;\n";
 | 
			
		||||
  }
 | 
			
		||||
  code += "  int t_x = X - (src_x * stride.x - padding.x);\n";
 | 
			
		||||
  code += "  int t_y = Y - (src_y * stride.y - padding.y);\n";
 | 
			
		||||
  code += "  int t_index = t_y * kernel_size.x + t_x;\n";
 | 
			
		||||
  code += "  FLT4 result;\n";
 | 
			
		||||
  c += "  int src_y = (Y + padding.y) / stride.y;\n";
 | 
			
		||||
  c += "  " + src.GetAddress("src_adr", "src_x", "src_y", "Z") + "\n";
 | 
			
		||||
  if (op_def.src_tensors[0].storage_type == TensorStorageType::BUFFER) {
 | 
			
		||||
    c += "  bool outside = src_x < 0 || src_y < 0 ||";
 | 
			
		||||
    c += "  src_x >= src_size.x || src_y >= src_size.y;\n";
 | 
			
		||||
    c += "  FLT4 src = (FLT4)(0.0f);\n";
 | 
			
		||||
    c += "  int4 ind = (int4)(0);\n";
 | 
			
		||||
    c += "  if (!outside) {\n";
 | 
			
		||||
    c += "    src = " + src.Read("src_adr", TextureAddressMode::DONT_CARE) +
 | 
			
		||||
         ";\n";
 | 
			
		||||
    c += "    ind = convert_int4(" +
 | 
			
		||||
         src_ind.Read("src_adr", TextureAddressMode::DONT_CARE) + ");\n";
 | 
			
		||||
    c += "  }\n";
 | 
			
		||||
  } else {
 | 
			
		||||
    c += "  FLT4 src = " + src.Read("src_adr", address_mode) + ";\n";
 | 
			
		||||
    c += "  int4 ind = convert_int4(" + src_ind.Read("src_adr", address_mode) +
 | 
			
		||||
         ");\n";
 | 
			
		||||
  }
 | 
			
		||||
  if (op_def.batch_support) {
 | 
			
		||||
    c += "  int t_x = X0 - (src_x0 * stride.x - padding.x);\n";
 | 
			
		||||
  } else {
 | 
			
		||||
    c += "  int t_x = X - (src_x * stride.x - padding.x);\n";
 | 
			
		||||
  }
 | 
			
		||||
  c += "  int t_y = Y - (src_y * stride.y - padding.y);\n";
 | 
			
		||||
  c += "  int t_index = t_y * kernel_size.x + t_x;\n";
 | 
			
		||||
  c += "  FLT4 result;\n";
 | 
			
		||||
  const std::string channels[] = {".x", ".y", ".z", ".w"};
 | 
			
		||||
  for (int i = 0; i < 4; ++i) {
 | 
			
		||||
    const auto& s = channels[i];
 | 
			
		||||
    code += "  result" + s + "= t_index == ind" + s + "? src" + s + ": 0.0f;\n";
 | 
			
		||||
    c += "  result" + s + "= t_index == ind" + s + "? src" + s + ": 0.0f;\n";
 | 
			
		||||
  }
 | 
			
		||||
  const LinkingContext context{"result", "X", "Y", "Z"};
 | 
			
		||||
  code += PostProcess(linked_operations, context);
 | 
			
		||||
  code += "  " + dst.Write3D("result", "X", "Y", "Z");
 | 
			
		||||
  code += "}\n";
 | 
			
		||||
  c += PostProcess(linked_operations, {"result", "X", "Y", "Z"});
 | 
			
		||||
  c += "  " + dst.Write3D("result", "X", "Y", "Z");
 | 
			
		||||
  c += "}\n";
 | 
			
		||||
 | 
			
		||||
  return code;
 | 
			
		||||
  return c;
 | 
			
		||||
}
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
@ -131,8 +145,8 @@ Status MaxUnpooling::BindArguments() {
 | 
			
		||||
  RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[1]->GetMemoryPtr()));
 | 
			
		||||
  RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
 | 
			
		||||
  RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
 | 
			
		||||
  RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetSizeWithDepth()));
 | 
			
		||||
  RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetSizeWithDepth()));
 | 
			
		||||
  RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHDB()));
 | 
			
		||||
  RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHDB()));
 | 
			
		||||
  RETURN_IF_ERROR(kernel_.SetBytesAuto(kernel_size_));
 | 
			
		||||
  RETURN_IF_ERROR(kernel_.SetBytesAuto(padding_));
 | 
			
		||||
  RETURN_IF_ERROR(kernel_.SetBytesAuto(stride_));
 | 
			
		||||
@ -141,7 +155,7 @@ Status MaxUnpooling::BindArguments() {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int3 MaxUnpooling::GetGridSize() const {
 | 
			
		||||
  const int grid_x = dst_[0]->Width();
 | 
			
		||||
  const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
 | 
			
		||||
  const int grid_y = dst_[0]->Height();
 | 
			
		||||
  const int grid_z = dst_[0]->Depth();
 | 
			
		||||
  return int3(grid_x, grid_y, grid_z);
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user