Add fast path to 2D pooling.

PiperOrigin-RevId: 345623618
Change-Id: I1be8a30d33a669cc0e7ae08ce1bd1140f760dcb4
This commit is contained in:
A. Unique TensorFlower 2020-12-04 00:39:09 -08:00 committed by TensorFlower Gardener
parent ba8003e12b
commit 41a09d7b77

View File

@ -107,7 +107,30 @@ absl::Status GenerateAveragePoolingCode(
{"window_w", attr.kernel.w},
};
std::string source = R"(
// Bounds checking helper functions.
auto x_in_bounds = [input_width = ctx.input_shapes[0][2],
kernel_width = attr.kernel.w](int64_t x) -> bool {
return 0 <= x && x + kernel_width <= input_width;
};
auto y_in_bounds = [input_height = ctx.input_shapes[0][1],
kernel_height = attr.kernel.h](int64_t y) -> bool {
return 0 <= y && y + kernel_height <= input_height;
};
// Decide up front whether the emitted shader needs bounds checks at all.
// Only the first and the last window start position along each axis are
// tested; if any of those four windows is not fully in bounds, the checked
// variant of the shader is required.
const int64_t base_x = -attr.padding.prepended.w;
const int64_t output_shape_max_x = ctx.output_shapes[0][2] - 1;
const int64_t base_y = -attr.padding.prepended.h;
const int64_t output_shape_max_y = ctx.output_shapes[0][1] - 1;
const bool bounds_check_necessary =
    !x_in_bounds(base_x) ||
    !x_in_bounds(base_x + output_shape_max_x * attr.strides.w) ||
    !y_in_bounds(base_y) ||
    !y_in_bounds(base_y + output_shape_max_y * attr.strides.h);
std::string source = bounds_check_necessary ?
R"(
int window_size = 0;
for (int a = 0; a < $window_h$; ++a) {
for (int b = 0; b < $window_w$; ++b) {
@ -121,7 +144,20 @@ absl::Status GenerateAveragePoolingCode(
// If window_size==0, window covered nothing. This situation is a sign of
// incorrectly constructed operation. NaNs are expected as output.
value_0 /= float(window_size);
)"
:
R"(
for (int a = 0; a < $window_h$; ++a) {
for (int b = 0; b < $window_w$; ++b) {
ivec2 coord = gid.xy * $stride$ - $offset$ + ivec2(b, a);
value_0 += $input_data_0[coord.x, coord.y, gid.z]$;
}
}
// If the denominator is 0, that is a sign of an incorrectly constructed
// operation. NaNs are expected as output.
value_0 /= float($window_h$ * $window_w$);
)";
*generated_code = {
/*parameters=*/std::move(parameters),
/*objects=*/{},