[XLA/GPU] Change rounding scheme for tree reduction to round up to nearest square

Previously, we rounded the reduced dimension up to the nearest multiple of the
largest batch we could handle without introducing atomics (the atomic-free
bound). That led to:
 - very large padding, e.g. rounding 8193 up to 16384;
 - very small dimensions for the extra reduction kernel, e.g. 2.

Instead, this CL uses a more "even" rounding scheme: we round the reduced
dimension up to the nearest square, so both dimensions of the rewritten
reduction become ceil(sqrt(N)). Since (n + 1)^2 = n^2 + (2n + 1), the nearest
square is guaranteed to be within 2 * sqrt(N) of N, so the required padding is
fairly small even in the worst case.
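
For illustration, a minimal standalone sketch of the new rounding (the helper
mirrors the SqrtOfRoundUpToNearestSquare function added in this CL; main(),
the variable names, and the printed output exist only for this example):

#include <cmath>
#include <cstdint>
#include <cstdio>

// Smallest n such that n * n >= input, i.e. ceil(sqrt(input)).
static int64_t SqrtOfRoundUpToNearestSquare(int64_t input) {
  return static_cast<int64_t>(std::ceil(std::sqrt(input)));
}

int main() {
  // Reduced dimension of 50000 elements, as in the updated test below:
  int64_t n = 50000;
  int64_t num_fit = SqrtOfRoundUpToNearestSquare(n);  // 224
  int64_t padded = num_fit * num_fit;                 // 50176 -> only 176 elements of padding
  // The old scheme padded 50000 up to 57344 (7 * 8192) and reshaped to [7, 8192];
  // the new scheme reshapes to [224, 224].
  std::printf("num_fit=%lld padded=%lld\n",
              static_cast<long long>(num_fit), static_cast<long long>(padded));
  return 0;
}

Splitting the reduced dimension into two roughly equal factors of about
sqrt(N) avoids both failure modes above: the padding stays small and neither
reduction dimension becomes tiny.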

PiperOrigin-RevId: 296086172
Change-Id: I7bfa72b2309fd1e3c596d6e028a9468660f84879
Author: George Karpenkov
Date: 2020-02-19 16:37:12 -08:00
Committed by: TensorFlower Gardener
Parent: 9c7537daae
Commit: 7f48bded8a
2 changed files with 84 additions and 97 deletions

File 1 of 2: updated test expectations (HLO CHECK patterns)

@@ -67,24 +67,23 @@ ENTRY main {
 zero = f32[] constant(0)
 ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add
 }
 )";
 // TODO(cheshire): a more generic check, do not hardcode the names.
 MatchOptimizedHloWithShapes(hlo_text,
 R"(
-// CHECK: %fused_computation (param_0.2: f32[50000]) -> f32[7] {
+// CHECK: %fused_computation (param_0.2: f32[50000]) -> f32[224] {
 // CHECK: %param_0.2 = f32[50000]{0} parameter(0)
 // CHECK: %zero_1 = f32[] constant(0)
-// CHECK: %pad.1 = f32[57344]{0} pad(f32[50000]{0} %param_0.2, f32[] %zero_1), padding=0_7344
-// CHECK: %bitcast.1 = f32[7,8192]{1,0} bitcast(f32[57344]{0} %pad.1)
-// CHECK: ROOT %reduce.2 = f32[7]{0} reduce(f32[7,8192]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add
+// CHECK: %pad.1 = f32[50176]{0} pad(f32[50000]{0} %param_0.2, f32[] %zero_1), padding=0_176
+// CHECK: %bitcast.1 = f32[224,224]{1,0} bitcast(f32[50176]{0} %pad.1)
+// CHECK: ROOT %reduce.2 = f32[224]{0} reduce(f32[224,224]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add
 // CHECK: }
 // CHECK: ENTRY %main (input: f32[50000]) -> f32[] {
 // CHECK: %input = f32[50000]{0} parameter(0)
-// CHECK: %fusion = f32[7]{0} fusion(f32[50000]{0} %input), kind=kInput, calls=%fused_computation
+// CHECK: %fusion = f32[224]{0} fusion(f32[50000]{0} %input), kind=kInput, calls=%fused_computation
 // CHECK: %zero = f32[] constant(0)
-// CHECK: ROOT %reduce.1 = f32[] reduce(f32[7]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
+// CHECK: ROOT %reduce.1 = f32[] reduce(f32[224]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
 // CHECK: }
 )");
@@ -107,27 +106,25 @@ ENTRY main {
 zero = f32[] constant(0)
 ROOT out = f32[100,100] reduce(input, zero), dimensions={2}, to_apply=add
 }
 )";
 EnsureDeterminism(hlo_text);
 MatchOptimizedHloWithShapes(hlo_text,
 R"(
-// CHECK: %fused_computation (param_0.2: f32[100,100,10000]) -> f32[100,100,2] {
+// CHECK: %fused_computation (param_0.2: f32[100,100,10000]) -> f32[100,100,100] {
 // CHECK: %param_0.2 = f32[100,100,10000]{2,1,0} parameter(0)
 // CHECK: %zero_1 = f32[] constant(0)
-// CHECK: %pad.1 = f32[100,100,16384]{2,1,0} pad(f32[100,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_6384
-// CHECK: %bitcast.1 = f32[100,100,2,8192]{3,2,1,0} bitcast(f32[100,100,16384]{2,1,0} %pad.1)
-// CHECK: ROOT %reduce.2 = f32[100,100,2]{2,1,0} reduce(f32[100,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3}, to_apply=%add
+// CHECK: %pad.1 = f32[100,100,10000]{2,1,0} pad(f32[100,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_0
+// CHECK: %bitcast.1 = f32[100,100,100,100]{3,2,1,0} bitcast(f32[100,100,10000]{2,1,0} %pad.1)
+// CHECK: ROOT %reduce.2 = f32[100,100,100]{2,1,0} reduce(f32[100,100,100,100]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3}, to_apply=%add
 // CHECK: }
 // CHECK: ENTRY %main (input: f32[100,100,10000]) -> f32[100,100] {
 // CHECK: %input = f32[100,100,10000]{2,1,0} parameter(0)
-// CHECK: %fusion = f32[100,100,2]{2,1,0} fusion(f32[100,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
+// CHECK: %fusion = f32[100,100,100]{2,1,0} fusion(f32[100,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
 // CHECK: %zero = f32[] constant(0)
-// CHECK: ROOT %reduce.1 = f32[100,100]{1,0} reduce(f32[100,100,2]{2,1,0} %fusion, f32[] %zero), dimensions={2}, to_apply=%add
+// CHECK: ROOT %reduce.1 = f32[100,100]{1,0} reduce(f32[100,100,100]{2,1,0} %fusion, f32[] %zero), dimensions={2}, to_apply=%add
 // CHECK: }
 )");
 EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
@@ -149,23 +146,22 @@ ENTRY main {
 zero = f32[] constant(0)
 ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add
 }
 )";
 MatchOptimizedHloWithShapes(hlo_text,
 R"(
-// CHECK: %fused_computation (param_0.2: f32[1000000]) -> f32[123] {
+// CHECK: %fused_computation (param_0.2: f32[1000000]) -> f32[1000] {
 // CHECK: %param_0.2 = f32[1000000]{0} parameter(0)
 // CHECK: %zero_1 = f32[] constant(0)
-// CHECK: %pad.1 = f32[1007616]{0} pad(f32[1000000]{0} %param_0.2, f32[] %zero_1), padding=0_7616
-// CHECK: %bitcast.1 = f32[123,8192]{1,0} bitcast(f32[1007616]{0} %pad.1)
-// CHECK: ROOT %reduce.2 = f32[123]{0} reduce(f32[123,8192]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add
+// CHECK: %pad.1 = f32[1000000]{0} pad(f32[1000000]{0} %param_0.2, f32[] %zero_1), padding=0_0
+// CHECK: %bitcast.1 = f32[1000,1000]{1,0} bitcast(f32[1000000]{0} %pad.1)
+// CHECK: ROOT %reduce.2 = f32[1000]{0} reduce(f32[1000,1000]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add
 // CHECK: }
 // CHECK: ENTRY %main (input: f32[1000000]) -> f32[] {
 // CHECK: %input = f32[1000000]{0} parameter(0)
-// CHECK: %fusion = f32[123]{0} fusion(f32[1000000]{0} %input), kind=kInput, calls=%fused_computation
+// CHECK: %fusion = f32[1000]{0} fusion(f32[1000000]{0} %input), kind=kInput, calls=%fused_computation
 // CHECK: %zero = f32[] constant(0)
-// CHECK: ROOT %reduce.1 = f32[] reduce(f32[123]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
+// CHECK: ROOT %reduce.1 = f32[] reduce(f32[1000]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
 // CHECK: }
 )");
@@ -188,25 +184,24 @@ ENTRY main {
 zero = f32[] constant(0)
 ROOT out = f32[100] reduce(input, zero), dimensions={0,2}, to_apply=add
 }
 )";
 EnsureDeterminism(hlo_text);
 MatchOptimizedHloWithShapes(hlo_text,
 R"(
-// CHECK: %fused_computation (param_0.2: f32[8,100,10000]) -> f32[100,2] {
+// CHECK: %fused_computation (param_0.2: f32[8,100,10000]) -> f32[100,100] {
 // CHECK: %param_0.2 = f32[8,100,10000]{2,1,0} parameter(0)
 // CHECK: %zero_1 = f32[] constant(0)
-// CHECK: %pad.1 = f32[8,100,16384]{2,1,0} pad(f32[8,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_6384
-// CHECK: %bitcast.1 = f32[8,100,2,8192]{3,2,1,0} bitcast(f32[8,100,16384]{2,1,0} %pad.1)
-// CHECK: ROOT %reduce.2 = f32[100,2]{1,0} reduce(f32[8,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3,0}, to_apply=%add
+// CHECK: %pad.1 = f32[8,100,10000]{2,1,0} pad(f32[8,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_0
+// CHECK: %bitcast.1 = f32[8,100,100,100]{3,2,1,0} bitcast(f32[8,100,10000]{2,1,0} %pad.1)
+// CHECK: ROOT %reduce.2 = f32[100,100]{1,0} reduce(f32[8,100,100,100]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3,0}, to_apply=%add
 // CHECK: }
 // CHECK: ENTRY %main (input: f32[8,100,10000]) -> f32[100] {
 // CHECK: %input = f32[8,100,10000]{2,1,0} parameter(0)
-// CHECK: %fusion = f32[100,2]{1,0} fusion(f32[8,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
+// CHECK: %fusion = f32[100,100]{1,0} fusion(f32[8,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
 // CHECK: %zero = f32[] constant(0)
-// CHECK: ROOT %reduce.1 = f32[100]{0} reduce(f32[100,2]{1,0} %fusion, f32[] %zero), dimensions={1}, to_apply=%add
+// CHECK: ROOT %reduce.1 = f32[100]{0} reduce(f32[100,100]{1,0} %fusion, f32[] %zero), dimensions={1}, to_apply=%add
 // CHECK: }
 )");
@@ -234,23 +229,19 @@ ENTRY main {
 MatchOptimizedHloWithShapes(hlo_text,
 R"(
-// CHECK: %fused_computation (param_0.4: f32[32,100,2]) -> f32[100] {
-// CHECK: %param_0.4 = f32[32,100,2]{2,1,0} parameter(0)
+// CHECK: %fused_computation (param_0.2: f32[32,100,10000]) -> f32[32,100,100] {
+// CHECK: %param_0.2 = f32[32,100,10000]{2,1,0} parameter(0)
 // CHECK: %zero_1 = f32[] constant(0)
-// CHECK: %reduce.5 = f32[32,100]{1,0} reduce(f32[32,100,2]{2,1,0} %param_0.4, f32[] %zero_1), dimensions={2}, to_apply=%add
-// CHECK: ROOT %reduce.4 = f32[100]{0} reduce(f32[32,100]{1,0} %reduce.5, f32[] %zero_1), dimensions={0}, to_apply=%add
-// CHECK: }
-// CHECK: %fused_computation.1 (param_0.5: f32[32,100,10000]) -> f32[32,100,2] {
-// CHECK: %param_0.5 = f32[32,100,10000]{2,1,0} parameter(0)
-// CHECK: %zero_2 = f32[] constant(0)
-// CHECK: %pad.1 = f32[32,100,16384]{2,1,0} pad(f32[32,100,10000]{2,1,0} %param_0.5, f32[] %zero_2), padding=0_0x0_0x0_6384
-// CHECK: %bitcast.1 = f32[32,100,2,8192]{3,2,1,0} bitcast(f32[32,100,16384]{2,1,0} %pad.1)
-// CHECK: ROOT %reduce.6 = f32[32,100,2]{2,1,0} reduce(f32[32,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_2), dimensions={3}, to_apply=%add
+// CHECK: %pad.1 = f32[32,100,10000]{2,1,0} pad(f32[32,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_0
+// CHECK: %bitcast.1 = f32[32,100,100,100]{3,2,1,0} bitcast(f32[32,100,10000]{2,1,0} %pad.1)
+// CHECK: ROOT %reduce.4 = f32[32,100,100]{2,1,0} reduce(f32[32,100,100,100]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3}, to_apply=%add
 // CHECK: }
 // CHECK: ENTRY %main (input: f32[32,100,10000]) -> f32[100] {
 // CHECK: %input = f32[32,100,10000]{2,1,0} parameter(0)
-// CHECK: %fusion.1 = f32[32,100,2]{2,1,0} fusion(f32[32,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation.1
-// CHECK: ROOT %fusion = f32[100]{0} fusion(f32[32,100,2]{2,1,0} %fusion.1), kind=kInput, calls=%fused_computation
+// CHECK: %fusion = f32[32,100,100]{2,1,0} fusion(f32[32,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
+// CHECK: %zero = f32[] constant(0)
+// CHECK: %reduce.3 = f32[32,100]{1,0} reduce(f32[32,100,100]{2,1,0} %fusion, f32[] %zero), dimensions={2}, to_apply=%add
+// CHECK: ROOT %reduce.1 = f32[100]{0} reduce(f32[32,100]{1,0} %reduce.3, f32[] %zero), dimensions={0}, to_apply=%add
 // CHECK: }
 )");
@@ -274,22 +265,22 @@ ENTRY main {
 zero = f32[] constant(0)
 ROOT out = f32[100] reduce(input, zero), dimensions={0}, to_apply=add
 }
 )";
 MatchOptimizedHloWithShapes(hlo_text,
 R"(
-// CHECK: %fused_computation (param_0.2: f32[10000,100]) -> f32[100] {
+// CHECK: %fused_computation (param_0.2: f32[10000,100]) -> f32[100,100] {
 // CHECK: %param_0.2 = f32[10000,100]{1,0} parameter(0)
 // CHECK: %zero_1 = f32[] constant(0)
-// CHECK: %pad.1 = f32[12288,100]{1,0} pad(f32[10000,100]{1,0} %param_0.2, f32[] %zero_1), padding=0_2288x0_0
-// CHECK: %bitcast.1 = f32[3,4096,100]{2,1,0} bitcast(f32[12288,100]{1,0} %pad.1)
-// CHECK: %reduce.3 = f32[4096,100]{1,0} reduce(f32[3,4096,100]{2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add
-// CHECK: ROOT %reduce.2 = f32[100]{0} reduce(f32[4096,100]{1,0} %reduce.3, f32[] %zero_1), dimensions={0}, to_apply=%add
+// CHECK: %pad.1 = f32[10000,100]{1,0} pad(f32[10000,100]{1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0
+// CHECK: %bitcast.1 = f32[100,100,100]{2,1,0} bitcast(f32[10000,100]{1,0} %pad.1)
+// CHECK: ROOT %reduce.2 = f32[100,100]{1,0} reduce(f32[100,100,100]{2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add
 // CHECK: }
 // CHECK: ENTRY %main (input: f32[10000,100]) -> f32[100] {
 // CHECK: %input = f32[10000,100]{1,0} parameter(0)
-// CHECK: ROOT %fusion = f32[100]{0} fusion(f32[10000,100]{1,0} %input), kind=kInput, calls=%fused_computation
+// CHECK: %fusion = f32[100,100]{1,0} fusion(f32[10000,100]{1,0} %input), kind=kInput, calls=%fused_computation
+// CHECK: %zero = f32[] constant(0)
+// CHECK: ROOT %reduce.1 = f32[100]{0} reduce(f32[100,100]{1,0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
 // CHECK: }
 )");
@@ -316,17 +307,18 @@ ENTRY main {
 MatchOptimizedHloWithShapes(hlo_text,
 R"(
-// CHECK: %fused_computation (param_0.2: f32[10000,2,2,2]) -> f32[2,2,2] {
+// CHECK: %fused_computation (param_0.2: f32[10000,2,2,2]) -> f32[100,2,2,2] {
 // CHECK: %param_0.2 = f32[10000,2,2,2]{3,2,1,0} parameter(0)
 // CHECK: %zero_1 = f32[] constant(0)
-// CHECK: %pad.1 = f32[12288,2,2,2]{3,2,1,0} pad(f32[10000,2,2,2]{3,2,1,0} %param_0.2, f32[] %zero_1), padding=0_2288x0_0x0_0x0_0
-// CHECK: %bitcast.1 = f32[3,4096,2,2,2]{4,3,2,1,0} bitcast(f32[12288,2,2,2]{3,2,1,0} %pad.1)
-// CHECK: %reduce.3 = f32[4096,2,2,2]{3,2,1,0} reduce(f32[3,4096,2,2,2]{4,3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add
-// CHECK: ROOT %reduce.2 = f32[2,2,2]{2,1,0} reduce(f32[4096,2,2,2]{3,2,1,0} %reduce.3, f32[] %zero_1), dimensions={0}, to_apply=%add
+// CHECK: %pad.1 = f32[10000,2,2,2]{3,2,1,0} pad(f32[10000,2,2,2]{3,2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_0x0_0
+// CHECK: %bitcast.1 = f32[100,100,2,2,2]{4,3,2,1,0} bitcast(f32[10000,2,2,2]{3,2,1,0} %pad.1)
+// CHECK: ROOT %reduce.2 = f32[100,2,2,2]{3,2,1,0} reduce(f32[100,100,2,2,2]{4,3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add
 // CHECK: }
 // CHECK: ENTRY %main (input: f32[10000,2,2,2]) -> f32[2,2,2] {
 // CHECK: %input = f32[10000,2,2,2]{3,2,1,0} parameter(0)
-// CHECK: ROOT %fusion = f32[2,2,2]{2,1,0} fusion(f32[10000,2,2,2]{3,2,1,0} %input), kind=kInput, calls=%fused_computation
+// CHECK: %fusion = f32[100,2,2,2]{3,2,1,0} fusion(f32[10000,2,2,2]{3,2,1,0} %input), kind=kInput, calls=%fused_computation
+// CHECK: %zero = f32[] constant(0)
+// CHECK: ROOT %reduce.1 = f32[2,2,2]{2,1,0} reduce(f32[100,2,2,2]{3,2,1,0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
 // CHECK: }
 )");
@@ -355,18 +347,18 @@ ENTRY main {
 MatchOptimizedHloWithShapes(hlo_text,
 R"(
-// CHECK: %fused_computation (param_0.2: f32[1000000,5]) -> f32[4096,5] {
+// CHECK: %fused_computation (param_0.2: f32[1000000,5]) -> f32[1000,5] {
 // CHECK: %param_0.2 = f32[1000000,5]{1,0} parameter(0)
 // CHECK: %zero_1 = f32[] constant(0)
-// CHECK: %pad.1 = f32[1003520,5]{1,0} pad(f32[1000000,5]{1,0} %param_0.2, f32[] %zero_1), padding=0_3520x0_0
-// CHECK: %bitcast.1 = f32[245,4096,5]{2,1,0} bitcast(f32[1003520,5]{1,0} %pad.1)
-// CHECK: ROOT %reduce.2 = f32[4096,5]{1,0} reduce(f32[245,4096,5]{2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add
+// CHECK: %pad.1 = f32[1000000,5]{1,0} pad(f32[1000000,5]{1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0
+// CHECK: %bitcast.1 = f32[1000,1000,5]{2,1,0} bitcast(f32[1000000,5]{1,0} %pad.1)
+// CHECK: ROOT %reduce.2 = f32[1000,5]{1,0} reduce(f32[1000,1000,5]{2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add
 // CHECK: }
 // CHECK: ENTRY %main (input: f32[1000000,5]) -> f32[5] {
 // CHECK: %input = f32[1000000,5]{1,0} parameter(0)
-// CHECK: %fusion = f32[4096,5]{1,0} fusion(f32[1000000,5]{1,0} %input), kind=kInput, calls=%fused_computation
+// CHECK: %fusion = f32[1000,5]{1,0} fusion(f32[1000000,5]{1,0} %input), kind=kInput, calls=%fused_computation
 // CHECK: %zero = f32[] constant(0)
-// CHECK: ROOT %reduce.1 = f32[5]{0} reduce(f32[4096,5]{1,0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
+// CHECK: ROOT %reduce.1 = f32[5]{0} reduce(f32[1000,5]{1,0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
 // CHECK: }
 )");

File 2 of 2: the reduction rewriter pass

@@ -46,6 +46,11 @@ static constexpr int64 kColumnAtomicFreeBound = kWarpSize * 128;
 // decreased column/row tiling.
 static constexpr int64 kBatchedAtomicFreeBound = 8;
+// Returns the square root of the input rounded up to the nearest square.
+static int64 SqrtOfRoundUpToNearestSquare(int64 input) {
+  return static_cast<int64>(std::ceil(std::sqrt(input)));
+}
 class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
  public:
  explicit ReductionRewriterVisitor() {}
@@ -105,26 +110,18 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
 int64 reduced_dim_size = input_shape.dimensions(reduced_input_dimension);
 VLOG(3) << "reduced_dim_size = " << reduced_dim_size;
-// TODO(cheshire): if atomic_free_bound is large, num_fit is likely to be
-// small. Generating a reduction with very small reduced dimension is not
-// efficient, it would be better to split the dimension sizes more evenly.
-//
-// One possible idea is to pad to a nearest square (ceil(sqrt(x)))^2.
-// Given that:
+// We pad to a nearest square (ceil(sqrt(x)))^2. Given that:
 //
 // (n + 1)^2 = n^2 + (2n+1)
 //
 // it can be seen that the distance to the nearest square is at most twice
 // the square root of the input number.
-int64 num_fit = CeilOfRatio(reduced_dim_size, atomic_free_bound);
+int64 num_fit = SqrtOfRoundUpToNearestSquare(reduced_dim_size);
 // Pad reduced dimension to the required number of elements.
 HloInstruction *padded = [&] {
-// TODO(cheshire): if atomic_free_bound is very large, padding all the way
-// up to to atomic_free_bound is wasteful, we could pad to a much smaller
-// value.
-if (reduced_dim_size % atomic_free_bound != 0) {
-int64 padded_num_elements = num_fit * atomic_free_bound;
+int64 padded_num_elements = num_fit * num_fit;
 PaddingConfig padding_config = MakeNoPaddingConfig(input_shape.rank());
 padding_config.mutable_dimensions(reduced_input_dimension)
 ->set_edge_padding_high(padded_num_elements - reduced_dim_size);
@@ -136,8 +133,6 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
 VLOG(3) << "Generated padded shape: " << padded_shape.ToString();
 return hlo->parent()->AddInstruction(HloInstruction::CreatePad(
 padded_shape, input, initial_value, padding_config));
-}
-return input;
 }();
 VLOG(1) << "Generated padding: " << padded->ToString();
@@ -146,7 +141,7 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
 dim_idx++) {
 if (dim_idx == reduced_input_dimension) {
 reshaped_dimensions.push_back(num_fit);
-reshaped_dimensions.push_back(atomic_free_bound);
+reshaped_dimensions.push_back(num_fit);
 } else {
 reshaped_dimensions.push_back(padded->shape().dimensions(dim_idx));
 }
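
For reference, a minimal standalone sketch of how the reshaped dimensions are
built after this change. BuildReshapedDimensions, padded_dimensions, and the
example values are hypothetical names used only for illustration; the pass
itself builds the equivalent vector inline, as in the final hunk above.

#include <cstdint>
#include <vector>

// Replace the (already padded) reduced dimension with [num_fit, num_fit].
std::vector<int64_t> BuildReshapedDimensions(
    const std::vector<int64_t>& padded_dimensions,
    int64_t reduced_input_dimension, int64_t num_fit) {
  std::vector<int64_t> reshaped_dimensions;
  for (int64_t dim_idx = 0;
       dim_idx < static_cast<int64_t>(padded_dimensions.size()); dim_idx++) {
    if (dim_idx == reduced_input_dimension) {
      reshaped_dimensions.push_back(num_fit);
      reshaped_dimensions.push_back(num_fit);
    } else {
      reshaped_dimensions.push_back(padded_dimensions[dim_idx]);
    }
  }
  return reshaped_dimensions;
}

int main() {
  // f32[100,100,10000] reduced over dimension 2: num_fit = ceil(sqrt(10000)) = 100,
  // so the reshape target is [100,100,100,100], matching the updated CHECK lines.
  std::vector<int64_t> dims = BuildReshapedDimensions({100, 100, 10000}, 2, 100);
  (void)dims;  // {100, 100, 100, 100}
  return 0;
}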