Add fast path to 2D pooling.
PiperOrigin-RevId: 345623618 Change-Id: I1be8a30d33a669cc0e7ae08ce1bd1140f760dcb4
This commit is contained in:
parent
ba8003e12b
commit
41a09d7b77
@ -107,7 +107,30 @@ absl::Status GenerateAveragePoolingCode(
|
||||
{"window_w", attr.kernel.w},
|
||||
};
|
||||
|
||||
std::string source = R"(
|
||||
// Bounds checking helper functions.
|
||||
auto x_in_bounds = [input_width = ctx.input_shapes[0][2],
|
||||
kernel_width = attr.kernel.w](int64_t x) -> bool {
|
||||
return 0 <= x && x + kernel_width <= input_width;
|
||||
};
|
||||
auto y_in_bounds = [input_height = ctx.input_shapes[0][1],
|
||||
kernel_height = attr.kernel.h](int64_t y) -> bool {
|
||||
return 0 <= y && y + kernel_height <= input_height;
|
||||
};
|
||||
|
||||
// Only include a bounds check in the shader if it will actually be necessary
|
||||
// at run time.
|
||||
const int64_t output_shape_max_y = ctx.output_shapes[0][1] - 1;
|
||||
const int64_t output_shape_max_x = ctx.output_shapes[0][2] - 1;
|
||||
const int64_t base_x = -attr.padding.prepended.w;
|
||||
const int64_t base_y = -attr.padding.prepended.h;
|
||||
const bool bounds_check_necessary =
|
||||
!(x_in_bounds(base_x) &&
|
||||
x_in_bounds(base_x + output_shape_max_x * attr.strides.w) &&
|
||||
y_in_bounds(base_y) &&
|
||||
y_in_bounds(base_y + output_shape_max_y * attr.strides.h));
|
||||
|
||||
std::string source = bounds_check_necessary ?
|
||||
R"(
|
||||
int window_size = 0;
|
||||
for (int a = 0; a < $window_h$; ++a) {
|
||||
for (int b = 0; b < $window_w$; ++b) {
|
||||
@ -121,7 +144,20 @@ absl::Status GenerateAveragePoolingCode(
|
||||
// If window_size==0, window covered nothing. This situation is a sign of
|
||||
// incorrectly constructed operation. NaNs are expected as output.
|
||||
value_0 /= float(window_size);
|
||||
)"
|
||||
:
|
||||
R"(
|
||||
for (int a = 0; a < $window_h$; ++a) {
|
||||
for (int b = 0; b < $window_w$; ++b) {
|
||||
ivec2 coord = gid.xy * $stride$ - $offset$ + ivec2(b, a);
|
||||
value_0 += $input_data_0[coord.x, coord.y, gid.z]$;
|
||||
}
|
||||
}
|
||||
// If the denominator is 0, that is a sign of an incorrectly constructed
|
||||
// operation. NaNs are expected as output.
|
||||
value_0 /= float($window_h$ * $window_w$);
|
||||
)";
|
||||
|
||||
*generated_code = {
|
||||
/*parameters=*/std::move(parameters),
|
||||
/*objects=*/{},
|
||||
|
Loading…
x
Reference in New Issue
Block a user