Fixed accumulator precision for generic DepthWise implementation.
Removed inlined constants for kernel sizes.

PiperOrigin-RevId: 303445102
Change-Id: If7de0468fc5554c85165f1a7020b54af512a7426
This commit is contained in:
		
							parent
							
								
									f0d9ae52dd
								
							
						
					
					
						commit
						86f5733999
					
				@ -475,91 +475,93 @@ std::vector<ComputeTaskDescriptorPtr> DepthWiseConvolution(
 | 
				
			|||||||
  std::string shader_source = R"(
 | 
					  std::string shader_source = R"(
 | 
				
			||||||
    #include <metal_stdlib>
 | 
					    #include <metal_stdlib>
 | 
				
			||||||
    using namespace metal;
 | 
					    using namespace metal;
 | 
				
			||||||
    constant int kernel_x = $0;
 | 
					 | 
				
			||||||
    constant int kernel_y = $1;
 | 
					 | 
				
			||||||
    struct uniforms {
 | 
					    struct uniforms {
 | 
				
			||||||
      int4 stride;
 | 
					      int4 src_size;
 | 
				
			||||||
      int4 padding;
 | 
					      int4 dst_size;
 | 
				
			||||||
      int4 dilation;
 | 
					      int2 stride;
 | 
				
			||||||
      int4 size;
 | 
					      int2 padding;
 | 
				
			||||||
 | 
					      int2 dilation;
 | 
				
			||||||
 | 
					      int2 kernel_size;
 | 
				
			||||||
      int4 channel_multiplier;
 | 
					      int4 channel_multiplier;
 | 
				
			||||||
    };
 | 
					    };
 | 
				
			||||||
    $$0
 | 
					    $0
 | 
				
			||||||
    kernel void ComputeFunction(
 | 
					    kernel void ComputeFunction(
 | 
				
			||||||
                                $$1
 | 
					                                $1
 | 
				
			||||||
                                uint tid[[thread_index_in_threadgroup]],
 | 
					                                uint tid[[thread_index_in_threadgroup]],
 | 
				
			||||||
                                uint3 gid[[thread_position_in_grid]]) {
 | 
					                                uint3 gid[[thread_position_in_grid]]) {
 | 
				
			||||||
      const bool outside = static_cast<int>(gid.x) >= params.size.z ||
 | 
					      int dst_x = static_cast<int>(gid.x);
 | 
				
			||||||
        static_cast<int>(gid.y) >= params.size.w;
 | 
					      int dst_y = static_cast<int>(gid.y);
 | 
				
			||||||
      if (outside) {
 | 
					      int dst_z = static_cast<int>(gid.z);
 | 
				
			||||||
        return;
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
      device FLT4* temp = filters + gid.z * kernel_y * kernel_x;
 | 
					 | 
				
			||||||
      float4 sum0 = float4(0.0f, 0.0f, 0.0f, 0.0f);
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
      for(int ky = 0; ky < kernel_y; ++ky) {
 | 
					      if (dst_x >= U.dst_size.x || dst_y >= U.dst_size.y) return;
 | 
				
			||||||
        for(int kx = 0; kx < kernel_x; ++kx) {
 | 
					
 | 
				
			||||||
          int2 coords  = int2(gid.xy) * params.stride.xy + int2(kx, ky) * params.dilation.xy -
 | 
					      device FLT4* temp = filters + dst_z * U.kernel_size.x * U.kernel_size.y;
 | 
				
			||||||
            params.padding.xy;
 | 
					      ACCUM_FLT4 sum0 = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);
 | 
				
			||||||
          const bool outside = coords.x < 0 || coords.y < 0 ||
 | 
					
 | 
				
			||||||
            coords.x >= params.size.x || coords.y >= params.size.y;
 | 
					      int src_x = dst_x * U.stride.x + U.padding.x;
 | 
				
			||||||
          if (outside) continue;
 | 
					      int src_y = dst_y * U.stride.y + U.padding.y;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      for(int ky = 0; ky < U.kernel_size.y; ++ky) {
 | 
				
			||||||
 | 
					        int yc = ky * U.dilation.y + src_x;
 | 
				
			||||||
 | 
					        if (yc < 0 || yc >= U.src_size.y) continue;
 | 
				
			||||||
 | 
					        for(int kx = 0; kx < U.kernel_size.x; ++kx) {
 | 
				
			||||||
 | 
					          int xc = kx * U.dilation.x + src_x;
 | 
				
			||||||
 | 
					          if (xc < 0 || xc >= U.src_size.x) continue;
 | 
				
			||||||
)";
 | 
					)";
 | 
				
			||||||
  if (channels_multiplier == 1) {
 | 
					  if (channels_multiplier == 1) {
 | 
				
			||||||
    shader_source += R"(
 | 
					    shader_source += R"(
 | 
				
			||||||
        const int src_layer = gid.z;
 | 
					        int src_layer = dst_z;
 | 
				
			||||||
        const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x;
 | 
					        int src_index = (src_layer * U.src_size.y + yc) * U.src_size.x + xc;
 | 
				
			||||||
        const FLT4 src_modified = src_buffer[src_index];
 | 
					        FLT4 src_modified = src_buffer[src_index];
 | 
				
			||||||
)";
 | 
					)";
 | 
				
			||||||
  } else if (channels_multiplier == 2) {
 | 
					  } else if (channels_multiplier == 2) {
 | 
				
			||||||
    shader_source += R"(
 | 
					    shader_source += R"(
 | 
				
			||||||
        const int src_layer = gid.z / 2;
 | 
					        int src_layer = dst_z / 2;
 | 
				
			||||||
        const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x;
 | 
					        int src_index = (src_layer * U.src_size.y + yc) * U.src_size.x + xc;
 | 
				
			||||||
        const FLT4 src = src_buffer[src_index];
 | 
					        FLT4 src = src_buffer[src_index];
 | 
				
			||||||
        const FLT2 t0 = gid.z % 2 == 0 ? src.xy : src.zw;
 | 
					        FLT2 t0 = dst_z % 2 == 0 ? src.xy : src.zw;
 | 
				
			||||||
        const FLT4 src_modified = FLT4(t0.x, t0.x, t0.y, t0.y);
 | 
					        FLT4 src_modified = FLT4(t0.x, t0.x, t0.y, t0.y);
 | 
				
			||||||
)";
 | 
					)";
 | 
				
			||||||
  } else if (channels_multiplier == 4) {
 | 
					  } else if (channels_multiplier == 4) {
 | 
				
			||||||
    shader_source += R"(
 | 
					    shader_source += R"(
 | 
				
			||||||
        const int src_layer = gid.z / 4;
 | 
					        int src_layer = dst_z / 4;
 | 
				
			||||||
        const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x;
 | 
					        int src_index = (src_layer * U.src_size.y + yc) * U.src_size.x + xc;
 | 
				
			||||||
        const FLT4 src = src_buffer[src_index];
 | 
					        FLT4 src = src_buffer[src_index];
 | 
				
			||||||
        const FLT t0 = src[gid.z % 4];
 | 
					        FLT t0 = src[dst_z % 4];
 | 
				
			||||||
        const FLT4 src_modified = FLT4(t0, t0, t0, t0);
 | 
					        FLT4 src_modified = FLT4(t0, t0, t0, t0);
 | 
				
			||||||
)";
 | 
					)";
 | 
				
			||||||
  } else {
 | 
					  } else {
 | 
				
			||||||
    shader_source += R"(
 | 
					    shader_source += R"(
 | 
				
			||||||
        const int src_layer = gid.z / params.channel_multiplier.x;
 | 
					        int src_layer = dst_z / U.channel_multiplier.x;
 | 
				
			||||||
        const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x;
 | 
					        int src_index = (src_layer * U.src_size.y + yc) * U.src_size.x + xc;
 | 
				
			||||||
        const FLT4 src = src_buffer[src_index];
 | 
					        FLT4 src = src_buffer[src_index];
 | 
				
			||||||
        FLT4 src_modified;
 | 
					        FLT4 src_modified;
 | 
				
			||||||
        const int src_layer_offset = (gid.z % params.channel_multiplier.x) * 4;
 | 
					        const int src_layer_offset = (dst_z % U.channel_multiplier.x) * 4;
 | 
				
			||||||
        src_modified.x = src[(src_layer_offset + 0) / params.channel_multiplier.x];
 | 
					        src_modified.x = src[(src_layer_offset + 0) / U.channel_multiplier.x];
 | 
				
			||||||
        src_modified.y = src[(src_layer_offset + 1) / params.channel_multiplier.x];
 | 
					        src_modified.y = src[(src_layer_offset + 1) / U.channel_multiplier.x];
 | 
				
			||||||
        src_modified.z = src[(src_layer_offset + 2) / params.channel_multiplier.x];
 | 
					        src_modified.z = src[(src_layer_offset + 2) / U.channel_multiplier.x];
 | 
				
			||||||
        src_modified.w = src[(src_layer_offset + 3) / params.channel_multiplier.x];
 | 
					        src_modified.w = src[(src_layer_offset + 3) / U.channel_multiplier.x];
 | 
				
			||||||
)";
 | 
					)";
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  shader_source += R"(
 | 
					  shader_source += R"(
 | 
				
			||||||
          sum0 += float4(src_modified * temp[ky * kernel_x + kx]);
 | 
					          sum0 += TO_ACCUM4_TYPE(src_modified * temp[ky * U.kernel_size.x + kx]);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
      FLT4 res = FLT4(sum0 + float4(biases[gid.z]));
 | 
					      FLT4 res = FLT4(sum0) + biases[dst_z];
 | 
				
			||||||
      const int linear_index = (gid.z * params.size.w + int(gid.y)) * params.size.z + int(gid.x);
 | 
					      const int linear_index = (dst_z * U.dst_size.y + dst_y) * U.dst_size.x + dst_x;
 | 
				
			||||||
      FLT4 value = res;
 | 
					      FLT4 value = res;
 | 
				
			||||||
      $$2
 | 
					      $2
 | 
				
			||||||
      output_buffer[linear_index] = value;
 | 
					      dst_buffer[linear_index] = value;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  )";
 | 
					  )";
 | 
				
			||||||
  desc->shader_source = absl::Substitute(shader_source, attr.weights.shape.w,
 | 
					  desc->shader_source = shader_source;
 | 
				
			||||||
                                         attr.weights.shape.h);
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
  desc->input_buffers = {
 | 
					  desc->input_buffers = {
 | 
				
			||||||
      {input_id, "device FLT4* const src_buffer"},
 | 
					      {input_id, "device FLT4* const src_buffer"},
 | 
				
			||||||
  };
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  desc->output_buffer = {
 | 
					  desc->output_buffer = {
 | 
				
			||||||
      output_id, "device FLT4* output_buffer",
 | 
					      output_id, "device FLT4* dst_buffer",
 | 
				
			||||||
      [input_id, attr](const std::map<ValueId, BHWC>& buffers) {
 | 
					      [input_id, attr](const std::map<ValueId, BHWC>& buffers) {
 | 
				
			||||||
        auto out_shape =
 | 
					        auto out_shape =
 | 
				
			||||||
            CalculateOutputShape(buffers.find(input_id)->second, attr);
 | 
					            CalculateOutputShape(buffers.find(input_id)->second, attr);
 | 
				
			||||||
@ -577,27 +579,27 @@ std::vector<ComputeTaskDescriptorPtr> DepthWiseConvolution(
 | 
				
			|||||||
  };
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  desc->uniform_buffers = {
 | 
					  desc->uniform_buffers = {
 | 
				
			||||||
      {"constant uniforms& params",
 | 
					      {"constant uniforms& U",
 | 
				
			||||||
       [input_id, output_id, attr](const std::map<ValueId, BHWC>& buffers) {
 | 
					       [input_id, output_id, attr](const std::map<ValueId, BHWC>& buffers) {
 | 
				
			||||||
         const auto& dimension = buffers.find(input_id)->second;
 | 
					         const auto& dimension = buffers.find(input_id)->second;
 | 
				
			||||||
         const auto& output_dimension = buffers.find(output_id)->second;
 | 
					         const auto& output_dimension = buffers.find(output_id)->second;
 | 
				
			||||||
         std::vector<int> uniform_params{
 | 
					         std::vector<int> uniform_params{
 | 
				
			||||||
             attr.strides.w,
 | 
					 | 
				
			||||||
             attr.strides.h,
 | 
					 | 
				
			||||||
             1,
 | 
					 | 
				
			||||||
             1,
 | 
					 | 
				
			||||||
             attr.padding.prepended.w,
 | 
					 | 
				
			||||||
             attr.padding.prepended.h,
 | 
					 | 
				
			||||||
             1,
 | 
					 | 
				
			||||||
             1,
 | 
					 | 
				
			||||||
             attr.dilations.w,
 | 
					 | 
				
			||||||
             attr.dilations.h,
 | 
					 | 
				
			||||||
             1,
 | 
					 | 
				
			||||||
             1,
 | 
					 | 
				
			||||||
             dimension.w,
 | 
					             dimension.w,
 | 
				
			||||||
             dimension.h,
 | 
					             dimension.h,
 | 
				
			||||||
 | 
					             IntegralDivideRoundUp(dimension.c, 4),
 | 
				
			||||||
 | 
					             0,
 | 
				
			||||||
             output_dimension.w,
 | 
					             output_dimension.w,
 | 
				
			||||||
             output_dimension.h,
 | 
					             output_dimension.h,
 | 
				
			||||||
 | 
					             IntegralDivideRoundUp(output_dimension.c, 4),
 | 
				
			||||||
 | 
					             0,
 | 
				
			||||||
 | 
					             attr.strides.w,
 | 
				
			||||||
 | 
					             attr.strides.h,
 | 
				
			||||||
 | 
					             -attr.padding.prepended.w,
 | 
				
			||||||
 | 
					             -attr.padding.prepended.h,
 | 
				
			||||||
 | 
					             attr.dilations.w,
 | 
				
			||||||
 | 
					             attr.dilations.h,
 | 
				
			||||||
 | 
					             attr.weights.shape.w,
 | 
				
			||||||
 | 
					             attr.weights.shape.h,
 | 
				
			||||||
             attr.weights.shape.o,
 | 
					             attr.weights.shape.o,
 | 
				
			||||||
             0,
 | 
					             0,
 | 
				
			||||||
             0,
 | 
					             0,
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user