[XLA/GPU] Change rounding scheme for tree reduction to round up to nearest square
Previously, we were rounding up to the nearest divisor of the largest batch we could handle without introducing atomics. That leads to: - Very large padding, e.g. rounding up 8193 to 16384 - Very small dimensions of extra reduction kernels, e.g. 2 Instead, this CL uses a more "even" rounding scheme, where we round up the number to the nearest square. Nearest square is guaranteed to be within 2 * sqrt(N) of a number N, so required padding is fairly small even in the worst case. PiperOrigin-RevId: 296086172 Change-Id: I7bfa72b2309fd1e3c596d6e028a9468660f84879
This commit is contained in:
parent
9c7537daae
commit
7f48bded8a
@ -67,24 +67,23 @@ ENTRY main {
|
||||
zero = f32[] constant(0)
|
||||
ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add
|
||||
}
|
||||
|
||||
)";
|
||||
|
||||
// TODO(cheshire): a more generic check, do not hardcode the names.
|
||||
MatchOptimizedHloWithShapes(hlo_text,
|
||||
R"(
|
||||
// CHECK: %fused_computation (param_0.2: f32[50000]) -> f32[7] {
|
||||
// CHECK: %fused_computation (param_0.2: f32[50000]) -> f32[224] {
|
||||
// CHECK: %param_0.2 = f32[50000]{0} parameter(0)
|
||||
// CHECK: %zero_1 = f32[] constant(0)
|
||||
// CHECK: %pad.1 = f32[57344]{0} pad(f32[50000]{0} %param_0.2, f32[] %zero_1), padding=0_7344
|
||||
// CHECK: %bitcast.1 = f32[7,8192]{1,0} bitcast(f32[57344]{0} %pad.1)
|
||||
// CHECK: ROOT %reduce.2 = f32[7]{0} reduce(f32[7,8192]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add
|
||||
// CHECK: %pad.1 = f32[50176]{0} pad(f32[50000]{0} %param_0.2, f32[] %zero_1), padding=0_176
|
||||
// CHECK: %bitcast.1 = f32[224,224]{1,0} bitcast(f32[50176]{0} %pad.1)
|
||||
// CHECK: ROOT %reduce.2 = f32[224]{0} reduce(f32[224,224]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add
|
||||
// CHECK: }
|
||||
// CHECK: ENTRY %main (input: f32[50000]) -> f32[] {
|
||||
// CHECK: %input = f32[50000]{0} parameter(0)
|
||||
// CHECK: %fusion = f32[7]{0} fusion(f32[50000]{0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %fusion = f32[224]{0} fusion(f32[50000]{0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %zero = f32[] constant(0)
|
||||
// CHECK: ROOT %reduce.1 = f32[] reduce(f32[7]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
|
||||
// CHECK: ROOT %reduce.1 = f32[] reduce(f32[224]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
|
||||
// CHECK: }
|
||||
)");
|
||||
|
||||
@ -107,27 +106,25 @@ ENTRY main {
|
||||
zero = f32[] constant(0)
|
||||
ROOT out = f32[100,100] reduce(input, zero), dimensions={2}, to_apply=add
|
||||
}
|
||||
|
||||
)";
|
||||
|
||||
EnsureDeterminism(hlo_text);
|
||||
|
||||
MatchOptimizedHloWithShapes(hlo_text,
|
||||
R"(
|
||||
// CHECK: %fused_computation (param_0.2: f32[100,100,10000]) -> f32[100,100,2] {
|
||||
// CHECK: %fused_computation (param_0.2: f32[100,100,10000]) -> f32[100,100,100] {
|
||||
// CHECK: %param_0.2 = f32[100,100,10000]{2,1,0} parameter(0)
|
||||
// CHECK: %zero_1 = f32[] constant(0)
|
||||
// CHECK: %pad.1 = f32[100,100,16384]{2,1,0} pad(f32[100,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_6384
|
||||
// CHECK: %bitcast.1 = f32[100,100,2,8192]{3,2,1,0} bitcast(f32[100,100,16384]{2,1,0} %pad.1)
|
||||
// CHECK: ROOT %reduce.2 = f32[100,100,2]{2,1,0} reduce(f32[100,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3}, to_apply=%add
|
||||
// CHECK: %pad.1 = f32[100,100,10000]{2,1,0} pad(f32[100,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_0
|
||||
// CHECK: %bitcast.1 = f32[100,100,100,100]{3,2,1,0} bitcast(f32[100,100,10000]{2,1,0} %pad.1)
|
||||
// CHECK: ROOT %reduce.2 = f32[100,100,100]{2,1,0} reduce(f32[100,100,100,100]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3}, to_apply=%add
|
||||
// CHECK: }
|
||||
// CHECK: ENTRY %main (input: f32[100,100,10000]) -> f32[100,100] {
|
||||
// CHECK: %input = f32[100,100,10000]{2,1,0} parameter(0)
|
||||
// CHECK: %fusion = f32[100,100,2]{2,1,0} fusion(f32[100,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %fusion = f32[100,100,100]{2,1,0} fusion(f32[100,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %zero = f32[] constant(0)
|
||||
// CHECK: ROOT %reduce.1 = f32[100,100]{1,0} reduce(f32[100,100,2]{2,1,0} %fusion, f32[] %zero), dimensions={2}, to_apply=%add
|
||||
// CHECK: ROOT %reduce.1 = f32[100,100]{1,0} reduce(f32[100,100,100]{2,1,0} %fusion, f32[] %zero), dimensions={2}, to_apply=%add
|
||||
// CHECK: }
|
||||
|
||||
)");
|
||||
|
||||
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
|
||||
@ -149,23 +146,22 @@ ENTRY main {
|
||||
zero = f32[] constant(0)
|
||||
ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add
|
||||
}
|
||||
|
||||
)";
|
||||
|
||||
MatchOptimizedHloWithShapes(hlo_text,
|
||||
R"(
|
||||
// CHECK: %fused_computation (param_0.2: f32[1000000]) -> f32[123] {
|
||||
// CHECK: %fused_computation (param_0.2: f32[1000000]) -> f32[1000] {
|
||||
// CHECK: %param_0.2 = f32[1000000]{0} parameter(0)
|
||||
// CHECK: %zero_1 = f32[] constant(0)
|
||||
// CHECK: %pad.1 = f32[1007616]{0} pad(f32[1000000]{0} %param_0.2, f32[] %zero_1), padding=0_7616
|
||||
// CHECK: %bitcast.1 = f32[123,8192]{1,0} bitcast(f32[1007616]{0} %pad.1)
|
||||
// CHECK: ROOT %reduce.2 = f32[123]{0} reduce(f32[123,8192]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add
|
||||
// CHECK: %pad.1 = f32[1000000]{0} pad(f32[1000000]{0} %param_0.2, f32[] %zero_1), padding=0_0
|
||||
// CHECK: %bitcast.1 = f32[1000,1000]{1,0} bitcast(f32[1000000]{0} %pad.1)
|
||||
// CHECK: ROOT %reduce.2 = f32[1000]{0} reduce(f32[1000,1000]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add
|
||||
// CHECK: }
|
||||
// CHECK: ENTRY %main (input: f32[1000000]) -> f32[] {
|
||||
// CHECK: %input = f32[1000000]{0} parameter(0)
|
||||
// CHECK: %fusion = f32[123]{0} fusion(f32[1000000]{0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %fusion = f32[1000]{0} fusion(f32[1000000]{0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %zero = f32[] constant(0)
|
||||
// CHECK: ROOT %reduce.1 = f32[] reduce(f32[123]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
|
||||
// CHECK: ROOT %reduce.1 = f32[] reduce(f32[1000]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
|
||||
// CHECK: }
|
||||
)");
|
||||
|
||||
@ -188,25 +184,24 @@ ENTRY main {
|
||||
zero = f32[] constant(0)
|
||||
ROOT out = f32[100] reduce(input, zero), dimensions={0,2}, to_apply=add
|
||||
}
|
||||
|
||||
)";
|
||||
|
||||
EnsureDeterminism(hlo_text);
|
||||
|
||||
MatchOptimizedHloWithShapes(hlo_text,
|
||||
R"(
|
||||
// CHECK: %fused_computation (param_0.2: f32[8,100,10000]) -> f32[100,2] {
|
||||
// CHECK: %fused_computation (param_0.2: f32[8,100,10000]) -> f32[100,100] {
|
||||
// CHECK: %param_0.2 = f32[8,100,10000]{2,1,0} parameter(0)
|
||||
// CHECK: %zero_1 = f32[] constant(0)
|
||||
// CHECK: %pad.1 = f32[8,100,16384]{2,1,0} pad(f32[8,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_6384
|
||||
// CHECK: %bitcast.1 = f32[8,100,2,8192]{3,2,1,0} bitcast(f32[8,100,16384]{2,1,0} %pad.1)
|
||||
// CHECK: ROOT %reduce.2 = f32[100,2]{1,0} reduce(f32[8,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3,0}, to_apply=%add
|
||||
// CHECK: %pad.1 = f32[8,100,10000]{2,1,0} pad(f32[8,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_0
|
||||
// CHECK: %bitcast.1 = f32[8,100,100,100]{3,2,1,0} bitcast(f32[8,100,10000]{2,1,0} %pad.1)
|
||||
// CHECK: ROOT %reduce.2 = f32[100,100]{1,0} reduce(f32[8,100,100,100]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3,0}, to_apply=%add
|
||||
// CHECK: }
|
||||
// CHECK: ENTRY %main (input: f32[8,100,10000]) -> f32[100] {
|
||||
// CHECK: %input = f32[8,100,10000]{2,1,0} parameter(0)
|
||||
// CHECK: %fusion = f32[100,2]{1,0} fusion(f32[8,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %fusion = f32[100,100]{1,0} fusion(f32[8,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %zero = f32[] constant(0)
|
||||
// CHECK: ROOT %reduce.1 = f32[100]{0} reduce(f32[100,2]{1,0} %fusion, f32[] %zero), dimensions={1}, to_apply=%add
|
||||
// CHECK: ROOT %reduce.1 = f32[100]{0} reduce(f32[100,100]{1,0} %fusion, f32[] %zero), dimensions={1}, to_apply=%add
|
||||
// CHECK: }
|
||||
)");
|
||||
|
||||
@ -234,23 +229,19 @@ ENTRY main {
|
||||
|
||||
MatchOptimizedHloWithShapes(hlo_text,
|
||||
R"(
|
||||
// CHECK: %fused_computation (param_0.4: f32[32,100,2]) -> f32[100] {
|
||||
// CHECK: %param_0.4 = f32[32,100,2]{2,1,0} parameter(0)
|
||||
// CHECK: %fused_computation (param_0.2: f32[32,100,10000]) -> f32[32,100,100] {
|
||||
// CHECK: %param_0.2 = f32[32,100,10000]{2,1,0} parameter(0)
|
||||
// CHECK: %zero_1 = f32[] constant(0)
|
||||
// CHECK: %reduce.5 = f32[32,100]{1,0} reduce(f32[32,100,2]{2,1,0} %param_0.4, f32[] %zero_1), dimensions={2}, to_apply=%add
|
||||
// CHECK: ROOT %reduce.4 = f32[100]{0} reduce(f32[32,100]{1,0} %reduce.5, f32[] %zero_1), dimensions={0}, to_apply=%add
|
||||
// CHECK: }
|
||||
// CHECK: %fused_computation.1 (param_0.5: f32[32,100,10000]) -> f32[32,100,2] {
|
||||
// CHECK: %param_0.5 = f32[32,100,10000]{2,1,0} parameter(0)
|
||||
// CHECK: %zero_2 = f32[] constant(0)
|
||||
// CHECK: %pad.1 = f32[32,100,16384]{2,1,0} pad(f32[32,100,10000]{2,1,0} %param_0.5, f32[] %zero_2), padding=0_0x0_0x0_6384
|
||||
// CHECK: %bitcast.1 = f32[32,100,2,8192]{3,2,1,0} bitcast(f32[32,100,16384]{2,1,0} %pad.1)
|
||||
// CHECK: ROOT %reduce.6 = f32[32,100,2]{2,1,0} reduce(f32[32,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_2), dimensions={3}, to_apply=%add
|
||||
// CHECK: %pad.1 = f32[32,100,10000]{2,1,0} pad(f32[32,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_0
|
||||
// CHECK: %bitcast.1 = f32[32,100,100,100]{3,2,1,0} bitcast(f32[32,100,10000]{2,1,0} %pad.1)
|
||||
// CHECK: ROOT %reduce.4 = f32[32,100,100]{2,1,0} reduce(f32[32,100,100,100]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3}, to_apply=%add
|
||||
// CHECK: }
|
||||
// CHECK: ENTRY %main (input: f32[32,100,10000]) -> f32[100] {
|
||||
// CHECK: %input = f32[32,100,10000]{2,1,0} parameter(0)
|
||||
// CHECK: %fusion.1 = f32[32,100,2]{2,1,0} fusion(f32[32,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation.1
|
||||
// CHECK: ROOT %fusion = f32[100]{0} fusion(f32[32,100,2]{2,1,0} %fusion.1), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %fusion = f32[32,100,100]{2,1,0} fusion(f32[32,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %zero = f32[] constant(0)
|
||||
// CHECK: %reduce.3 = f32[32,100]{1,0} reduce(f32[32,100,100]{2,1,0} %fusion, f32[] %zero), dimensions={2}, to_apply=%add
|
||||
// CHECK: ROOT %reduce.1 = f32[100]{0} reduce(f32[32,100]{1,0} %reduce.3, f32[] %zero), dimensions={0}, to_apply=%add
|
||||
// CHECK: }
|
||||
)");
|
||||
|
||||
@ -274,22 +265,22 @@ ENTRY main {
|
||||
zero = f32[] constant(0)
|
||||
ROOT out = f32[100] reduce(input, zero), dimensions={0}, to_apply=add
|
||||
}
|
||||
|
||||
)";
|
||||
|
||||
MatchOptimizedHloWithShapes(hlo_text,
|
||||
R"(
|
||||
// CHECK: %fused_computation (param_0.2: f32[10000,100]) -> f32[100] {
|
||||
// CHECK: %param_0.2 = f32[10000,100]{1,0} parameter(0)
|
||||
// CHECK: %zero_1 = f32[] constant(0)
|
||||
// CHECK: %pad.1 = f32[12288,100]{1,0} pad(f32[10000,100]{1,0} %param_0.2, f32[] %zero_1), padding=0_2288x0_0
|
||||
// CHECK: %bitcast.1 = f32[3,4096,100]{2,1,0} bitcast(f32[12288,100]{1,0} %pad.1)
|
||||
// CHECK: %reduce.3 = f32[4096,100]{1,0} reduce(f32[3,4096,100]{2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add
|
||||
// CHECK: ROOT %reduce.2 = f32[100]{0} reduce(f32[4096,100]{1,0} %reduce.3, f32[] %zero_1), dimensions={0}, to_apply=%add
|
||||
// CHECK: %fused_computation (param_0.2: f32[10000,100]) -> f32[100,100] {
|
||||
// CHECK: %param_0.2 = f32[10000,100]{1,0} parameter(0)
|
||||
// CHECK: %zero_1 = f32[] constant(0)
|
||||
// CHECK: %pad.1 = f32[10000,100]{1,0} pad(f32[10000,100]{1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0
|
||||
// CHECK: %bitcast.1 = f32[100,100,100]{2,1,0} bitcast(f32[10000,100]{1,0} %pad.1)
|
||||
// CHECK: ROOT %reduce.2 = f32[100,100]{1,0} reduce(f32[100,100,100]{2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add
|
||||
// CHECK: }
|
||||
// CHECK: ENTRY %main (input: f32[10000,100]) -> f32[100] {
|
||||
// CHECK: %input = f32[10000,100]{1,0} parameter(0)
|
||||
// CHECK: ROOT %fusion = f32[100]{0} fusion(f32[10000,100]{1,0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %input = f32[10000,100]{1,0} parameter(0)
|
||||
// CHECK: %fusion = f32[100,100]{1,0} fusion(f32[10000,100]{1,0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %zero = f32[] constant(0)
|
||||
// CHECK: ROOT %reduce.1 = f32[100]{0} reduce(f32[100,100]{1,0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
|
||||
// CHECK: }
|
||||
)");
|
||||
|
||||
@ -316,17 +307,18 @@ ENTRY main {
|
||||
|
||||
MatchOptimizedHloWithShapes(hlo_text,
|
||||
R"(
|
||||
// CHECK: %fused_computation (param_0.2: f32[10000,2,2,2]) -> f32[2,2,2] {
|
||||
// CHECK: %param_0.2 = f32[10000,2,2,2]{3,2,1,0} parameter(0)
|
||||
// CHECK: %zero_1 = f32[] constant(0)
|
||||
// CHECK: %pad.1 = f32[12288,2,2,2]{3,2,1,0} pad(f32[10000,2,2,2]{3,2,1,0} %param_0.2, f32[] %zero_1), padding=0_2288x0_0x0_0x0_0
|
||||
// CHECK: %bitcast.1 = f32[3,4096,2,2,2]{4,3,2,1,0} bitcast(f32[12288,2,2,2]{3,2,1,0} %pad.1)
|
||||
// CHECK: %reduce.3 = f32[4096,2,2,2]{3,2,1,0} reduce(f32[3,4096,2,2,2]{4,3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add
|
||||
// CHECK: ROOT %reduce.2 = f32[2,2,2]{2,1,0} reduce(f32[4096,2,2,2]{3,2,1,0} %reduce.3, f32[] %zero_1), dimensions={0}, to_apply=%add
|
||||
// CHECK: %fused_computation (param_0.2: f32[10000,2,2,2]) -> f32[100,2,2,2] {
|
||||
// CHECK: %param_0.2 = f32[10000,2,2,2]{3,2,1,0} parameter(0)
|
||||
// CHECK: %zero_1 = f32[] constant(0)
|
||||
// CHECK: %pad.1 = f32[10000,2,2,2]{3,2,1,0} pad(f32[10000,2,2,2]{3,2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_0x0_0
|
||||
// CHECK: %bitcast.1 = f32[100,100,2,2,2]{4,3,2,1,0} bitcast(f32[10000,2,2,2]{3,2,1,0} %pad.1)
|
||||
// CHECK: ROOT %reduce.2 = f32[100,2,2,2]{3,2,1,0} reduce(f32[100,100,2,2,2]{4,3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add
|
||||
// CHECK: }
|
||||
// CHECK: ENTRY %main (input: f32[10000,2,2,2]) -> f32[2,2,2] {
|
||||
// CHECK: %input = f32[10000,2,2,2]{3,2,1,0} parameter(0)
|
||||
// CHECK: ROOT %fusion = f32[2,2,2]{2,1,0} fusion(f32[10000,2,2,2]{3,2,1,0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %input = f32[10000,2,2,2]{3,2,1,0} parameter(0)
|
||||
// CHECK: %fusion = f32[100,2,2,2]{3,2,1,0} fusion(f32[10000,2,2,2]{3,2,1,0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %zero = f32[] constant(0)
|
||||
// CHECK: ROOT %reduce.1 = f32[2,2,2]{2,1,0} reduce(f32[100,2,2,2]{3,2,1,0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
|
||||
// CHECK: }
|
||||
)");
|
||||
|
||||
@ -355,18 +347,18 @@ ENTRY main {
|
||||
|
||||
MatchOptimizedHloWithShapes(hlo_text,
|
||||
R"(
|
||||
// CHECK: %fused_computation (param_0.2: f32[1000000,5]) -> f32[4096,5] {
|
||||
// CHECK: %param_0.2 = f32[1000000,5]{1,0} parameter(0)
|
||||
// CHECK: %zero_1 = f32[] constant(0)
|
||||
// CHECK: %pad.1 = f32[1003520,5]{1,0} pad(f32[1000000,5]{1,0} %param_0.2, f32[] %zero_1), padding=0_3520x0_0
|
||||
// CHECK: %bitcast.1 = f32[245,4096,5]{2,1,0} bitcast(f32[1003520,5]{1,0} %pad.1)
|
||||
// CHECK: ROOT %reduce.2 = f32[4096,5]{1,0} reduce(f32[245,4096,5]{2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add
|
||||
// CHECK: %fused_computation (param_0.2: f32[1000000,5]) -> f32[1000,5] {
|
||||
// CHECK: %param_0.2 = f32[1000000,5]{1,0} parameter(0)
|
||||
// CHECK: %zero_1 = f32[] constant(0)
|
||||
// CHECK: %pad.1 = f32[1000000,5]{1,0} pad(f32[1000000,5]{1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0
|
||||
// CHECK: %bitcast.1 = f32[1000,1000,5]{2,1,0} bitcast(f32[1000000,5]{1,0} %pad.1)
|
||||
// CHECK: ROOT %reduce.2 = f32[1000,5]{1,0} reduce(f32[1000,1000,5]{2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add
|
||||
// CHECK: }
|
||||
// CHECK: ENTRY %main (input: f32[1000000,5]) -> f32[5] {
|
||||
// CHECK: %input = f32[1000000,5]{1,0} parameter(0)
|
||||
// CHECK: %fusion = f32[4096,5]{1,0} fusion(f32[1000000,5]{1,0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %zero = f32[] constant(0)
|
||||
// CHECK: ROOT %reduce.1 = f32[5]{0} reduce(f32[4096,5]{1,0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
|
||||
// CHECK: %input = f32[1000000,5]{1,0} parameter(0)
|
||||
// CHECK: %fusion = f32[1000,5]{1,0} fusion(f32[1000000,5]{1,0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %zero = f32[] constant(0)
|
||||
// CHECK: ROOT %reduce.1 = f32[5]{0} reduce(f32[1000,5]{1,0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
|
||||
// CHECK: }
|
||||
)");
|
||||
|
||||
|
@ -46,6 +46,11 @@ static constexpr int64 kColumnAtomicFreeBound = kWarpSize * 128;
|
||||
// decreased column/row tiling.
|
||||
static constexpr int64 kBatchedAtomicFreeBound = 8;
|
||||
|
||||
// Returns the square root of the input rounded up to the nearest square.
|
||||
static int64 SqrtOfRoundUpToNearestSquare(int64 input) {
|
||||
return static_cast<int64>(std::ceil(std::sqrt(input)));
|
||||
}
|
||||
|
||||
class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
|
||||
public:
|
||||
explicit ReductionRewriterVisitor() {}
|
||||
@ -105,39 +110,29 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
|
||||
|
||||
int64 reduced_dim_size = input_shape.dimensions(reduced_input_dimension);
|
||||
VLOG(3) << "reduced_dim_size = " << reduced_dim_size;
|
||||
// TODO(cheshire): if atomic_free_bound is large, num_fit is likely to be
|
||||
// small. Generating a reduction with very small reduced dimension is not
|
||||
// efficient, it would be better to split the dimension sizes more evenly.
|
||||
//
|
||||
// One possible idea is to pad to a nearest square (ceil(sqrt(x)))^2.
|
||||
// Given that:
|
||||
|
||||
// We pad to a nearest square (ceil(sqrt(x)))^2. Given that:
|
||||
//
|
||||
// (n + 1)^2 = n^2 + (2n+1)
|
||||
//
|
||||
// it can be seen that the distance to the nearest square is at most twice
|
||||
// the square root of the input number.
|
||||
int64 num_fit = CeilOfRatio(reduced_dim_size, atomic_free_bound);
|
||||
int64 num_fit = SqrtOfRoundUpToNearestSquare(reduced_dim_size);
|
||||
|
||||
// Pad reduced dimension to the required number of elements.
|
||||
HloInstruction *padded = [&] {
|
||||
// TODO(cheshire): if atomic_free_bound is very large, padding all the way
|
||||
// up to to atomic_free_bound is wasteful, we could pad to a much smaller
|
||||
// value.
|
||||
if (reduced_dim_size % atomic_free_bound != 0) {
|
||||
int64 padded_num_elements = num_fit * atomic_free_bound;
|
||||
PaddingConfig padding_config = MakeNoPaddingConfig(input_shape.rank());
|
||||
padding_config.mutable_dimensions(reduced_input_dimension)
|
||||
->set_edge_padding_high(padded_num_elements - reduced_dim_size);
|
||||
std::vector<int64> padded_dimensions(input_shape.dimensions().begin(),
|
||||
input_shape.dimensions().end());
|
||||
padded_dimensions[reduced_input_dimension] = padded_num_elements;
|
||||
Shape padded_shape =
|
||||
ShapeUtil::MakeShape(input_shape.element_type(), padded_dimensions);
|
||||
VLOG(3) << "Generated padded shape: " << padded_shape.ToString();
|
||||
return hlo->parent()->AddInstruction(HloInstruction::CreatePad(
|
||||
padded_shape, input, initial_value, padding_config));
|
||||
}
|
||||
return input;
|
||||
int64 padded_num_elements = num_fit * num_fit;
|
||||
PaddingConfig padding_config = MakeNoPaddingConfig(input_shape.rank());
|
||||
padding_config.mutable_dimensions(reduced_input_dimension)
|
||||
->set_edge_padding_high(padded_num_elements - reduced_dim_size);
|
||||
std::vector<int64> padded_dimensions(input_shape.dimensions().begin(),
|
||||
input_shape.dimensions().end());
|
||||
padded_dimensions[reduced_input_dimension] = padded_num_elements;
|
||||
Shape padded_shape =
|
||||
ShapeUtil::MakeShape(input_shape.element_type(), padded_dimensions);
|
||||
VLOG(3) << "Generated padded shape: " << padded_shape.ToString();
|
||||
return hlo->parent()->AddInstruction(HloInstruction::CreatePad(
|
||||
padded_shape, input, initial_value, padding_config));
|
||||
}();
|
||||
|
||||
VLOG(1) << "Generated padding: " << padded->ToString();
|
||||
@ -146,7 +141,7 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
|
||||
dim_idx++) {
|
||||
if (dim_idx == reduced_input_dimension) {
|
||||
reshaped_dimensions.push_back(num_fit);
|
||||
reshaped_dimensions.push_back(atomic_free_bound);
|
||||
reshaped_dimensions.push_back(num_fit);
|
||||
} else {
|
||||
reshaped_dimensions.push_back(padded->shape().dimensions(dim_idx));
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user