[XLA/GPU] Change rounding scheme for tree reduction to round up to nearest square
Previously, we were rounding up to the nearest divisor of the largest batch we could handle without introducing atomics. That leads to: - Very large padding, e.g. rounding up 8193 to 16384 - Very small dimensions of extra reduction kernels, e.g. 2 Instead, this CL uses a more "even" rounding scheme, where we round up the number to the nearest square. Nearest square is guaranteed to be within 2 * sqrt(N) of a number N, so required padding is fairly small even in the worst case. PiperOrigin-RevId: 296086172 Change-Id: I7bfa72b2309fd1e3c596d6e028a9468660f84879
This commit is contained in:
parent
9c7537daae
commit
7f48bded8a
@ -67,24 +67,23 @@ ENTRY main {
|
||||
zero = f32[] constant(0)
|
||||
ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add
|
||||
}
|
||||
|
||||
)";
|
||||
|
||||
// TODO(cheshire): a more generic check, do not hardcode the names.
|
||||
MatchOptimizedHloWithShapes(hlo_text,
|
||||
R"(
|
||||
// CHECK: %fused_computation (param_0.2: f32[50000]) -> f32[7] {
|
||||
// CHECK: %fused_computation (param_0.2: f32[50000]) -> f32[224] {
|
||||
// CHECK: %param_0.2 = f32[50000]{0} parameter(0)
|
||||
// CHECK: %zero_1 = f32[] constant(0)
|
||||
// CHECK: %pad.1 = f32[57344]{0} pad(f32[50000]{0} %param_0.2, f32[] %zero_1), padding=0_7344
|
||||
// CHECK: %bitcast.1 = f32[7,8192]{1,0} bitcast(f32[57344]{0} %pad.1)
|
||||
// CHECK: ROOT %reduce.2 = f32[7]{0} reduce(f32[7,8192]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add
|
||||
// CHECK: %pad.1 = f32[50176]{0} pad(f32[50000]{0} %param_0.2, f32[] %zero_1), padding=0_176
|
||||
// CHECK: %bitcast.1 = f32[224,224]{1,0} bitcast(f32[50176]{0} %pad.1)
|
||||
// CHECK: ROOT %reduce.2 = f32[224]{0} reduce(f32[224,224]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add
|
||||
// CHECK: }
|
||||
// CHECK: ENTRY %main (input: f32[50000]) -> f32[] {
|
||||
// CHECK: %input = f32[50000]{0} parameter(0)
|
||||
// CHECK: %fusion = f32[7]{0} fusion(f32[50000]{0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %fusion = f32[224]{0} fusion(f32[50000]{0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %zero = f32[] constant(0)
|
||||
// CHECK: ROOT %reduce.1 = f32[] reduce(f32[7]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
|
||||
// CHECK: ROOT %reduce.1 = f32[] reduce(f32[224]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
|
||||
// CHECK: }
|
||||
)");
|
||||
|
||||
@ -107,27 +106,25 @@ ENTRY main {
|
||||
zero = f32[] constant(0)
|
||||
ROOT out = f32[100,100] reduce(input, zero), dimensions={2}, to_apply=add
|
||||
}
|
||||
|
||||
)";
|
||||
|
||||
EnsureDeterminism(hlo_text);
|
||||
|
||||
MatchOptimizedHloWithShapes(hlo_text,
|
||||
R"(
|
||||
// CHECK: %fused_computation (param_0.2: f32[100,100,10000]) -> f32[100,100,2] {
|
||||
// CHECK: %fused_computation (param_0.2: f32[100,100,10000]) -> f32[100,100,100] {
|
||||
// CHECK: %param_0.2 = f32[100,100,10000]{2,1,0} parameter(0)
|
||||
// CHECK: %zero_1 = f32[] constant(0)
|
||||
// CHECK: %pad.1 = f32[100,100,16384]{2,1,0} pad(f32[100,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_6384
|
||||
// CHECK: %bitcast.1 = f32[100,100,2,8192]{3,2,1,0} bitcast(f32[100,100,16384]{2,1,0} %pad.1)
|
||||
// CHECK: ROOT %reduce.2 = f32[100,100,2]{2,1,0} reduce(f32[100,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3}, to_apply=%add
|
||||
// CHECK: %pad.1 = f32[100,100,10000]{2,1,0} pad(f32[100,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_0
|
||||
// CHECK: %bitcast.1 = f32[100,100,100,100]{3,2,1,0} bitcast(f32[100,100,10000]{2,1,0} %pad.1)
|
||||
// CHECK: ROOT %reduce.2 = f32[100,100,100]{2,1,0} reduce(f32[100,100,100,100]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3}, to_apply=%add
|
||||
// CHECK: }
|
||||
// CHECK: ENTRY %main (input: f32[100,100,10000]) -> f32[100,100] {
|
||||
// CHECK: %input = f32[100,100,10000]{2,1,0} parameter(0)
|
||||
// CHECK: %fusion = f32[100,100,2]{2,1,0} fusion(f32[100,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %fusion = f32[100,100,100]{2,1,0} fusion(f32[100,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %zero = f32[] constant(0)
|
||||
// CHECK: ROOT %reduce.1 = f32[100,100]{1,0} reduce(f32[100,100,2]{2,1,0} %fusion, f32[] %zero), dimensions={2}, to_apply=%add
|
||||
// CHECK: ROOT %reduce.1 = f32[100,100]{1,0} reduce(f32[100,100,100]{2,1,0} %fusion, f32[] %zero), dimensions={2}, to_apply=%add
|
||||
// CHECK: }
|
||||
|
||||
)");
|
||||
|
||||
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
|
||||
@ -149,23 +146,22 @@ ENTRY main {
|
||||
zero = f32[] constant(0)
|
||||
ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add
|
||||
}
|
||||
|
||||
)";
|
||||
|
||||
MatchOptimizedHloWithShapes(hlo_text,
|
||||
R"(
|
||||
// CHECK: %fused_computation (param_0.2: f32[1000000]) -> f32[123] {
|
||||
// CHECK: %fused_computation (param_0.2: f32[1000000]) -> f32[1000] {
|
||||
// CHECK: %param_0.2 = f32[1000000]{0} parameter(0)
|
||||
// CHECK: %zero_1 = f32[] constant(0)
|
||||
// CHECK: %pad.1 = f32[1007616]{0} pad(f32[1000000]{0} %param_0.2, f32[] %zero_1), padding=0_7616
|
||||
// CHECK: %bitcast.1 = f32[123,8192]{1,0} bitcast(f32[1007616]{0} %pad.1)
|
||||
// CHECK: ROOT %reduce.2 = f32[123]{0} reduce(f32[123,8192]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add
|
||||
// CHECK: %pad.1 = f32[1000000]{0} pad(f32[1000000]{0} %param_0.2, f32[] %zero_1), padding=0_0
|
||||
// CHECK: %bitcast.1 = f32[1000,1000]{1,0} bitcast(f32[1000000]{0} %pad.1)
|
||||
// CHECK: ROOT %reduce.2 = f32[1000]{0} reduce(f32[1000,1000]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add
|
||||
// CHECK: }
|
||||
// CHECK: ENTRY %main (input: f32[1000000]) -> f32[] {
|
||||
// CHECK: %input = f32[1000000]{0} parameter(0)
|
||||
// CHECK: %fusion = f32[123]{0} fusion(f32[1000000]{0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %fusion = f32[1000]{0} fusion(f32[1000000]{0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %zero = f32[] constant(0)
|
||||
// CHECK: ROOT %reduce.1 = f32[] reduce(f32[123]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
|
||||
// CHECK: ROOT %reduce.1 = f32[] reduce(f32[1000]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
|
||||
// CHECK: }
|
||||
)");
|
||||
|
||||
@ -188,25 +184,24 @@ ENTRY main {
|
||||
zero = f32[] constant(0)
|
||||
ROOT out = f32[100] reduce(input, zero), dimensions={0,2}, to_apply=add
|
||||
}
|
||||
|
||||
)";
|
||||
|
||||
EnsureDeterminism(hlo_text);
|
||||
|
||||
MatchOptimizedHloWithShapes(hlo_text,
|
||||
R"(
|
||||
// CHECK: %fused_computation (param_0.2: f32[8,100,10000]) -> f32[100,2] {
|
||||
// CHECK: %fused_computation (param_0.2: f32[8,100,10000]) -> f32[100,100] {
|
||||
// CHECK: %param_0.2 = f32[8,100,10000]{2,1,0} parameter(0)
|
||||
// CHECK: %zero_1 = f32[] constant(0)
|
||||
// CHECK: %pad.1 = f32[8,100,16384]{2,1,0} pad(f32[8,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_6384
|
||||
// CHECK: %bitcast.1 = f32[8,100,2,8192]{3,2,1,0} bitcast(f32[8,100,16384]{2,1,0} %pad.1)
|
||||
// CHECK: ROOT %reduce.2 = f32[100,2]{1,0} reduce(f32[8,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3,0}, to_apply=%add
|
||||
// CHECK: %pad.1 = f32[8,100,10000]{2,1,0} pad(f32[8,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_0
|
||||
// CHECK: %bitcast.1 = f32[8,100,100,100]{3,2,1,0} bitcast(f32[8,100,10000]{2,1,0} %pad.1)
|
||||
// CHECK: ROOT %reduce.2 = f32[100,100]{1,0} reduce(f32[8,100,100,100]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3,0}, to_apply=%add
|
||||
// CHECK: }
|
||||
// CHECK: ENTRY %main (input: f32[8,100,10000]) -> f32[100] {
|
||||
// CHECK: %input = f32[8,100,10000]{2,1,0} parameter(0)
|
||||
// CHECK: %fusion = f32[100,2]{1,0} fusion(f32[8,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %fusion = f32[100,100]{1,0} fusion(f32[8,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %zero = f32[] constant(0)
|
||||
// CHECK: ROOT %reduce.1 = f32[100]{0} reduce(f32[100,2]{1,0} %fusion, f32[] %zero), dimensions={1}, to_apply=%add
|
||||
// CHECK: ROOT %reduce.1 = f32[100]{0} reduce(f32[100,100]{1,0} %fusion, f32[] %zero), dimensions={1}, to_apply=%add
|
||||
// CHECK: }
|
||||
)");
|
||||
|
||||
@ -234,23 +229,19 @@ ENTRY main {
|
||||
|
||||
MatchOptimizedHloWithShapes(hlo_text,
|
||||
R"(
|
||||
// CHECK: %fused_computation (param_0.4: f32[32,100,2]) -> f32[100] {
|
||||
// CHECK: %param_0.4 = f32[32,100,2]{2,1,0} parameter(0)
|
||||
// CHECK: %fused_computation (param_0.2: f32[32,100,10000]) -> f32[32,100,100] {
|
||||
// CHECK: %param_0.2 = f32[32,100,10000]{2,1,0} parameter(0)
|
||||
// CHECK: %zero_1 = f32[] constant(0)
|
||||
// CHECK: %reduce.5 = f32[32,100]{1,0} reduce(f32[32,100,2]{2,1,0} %param_0.4, f32[] %zero_1), dimensions={2}, to_apply=%add
|
||||
// CHECK: ROOT %reduce.4 = f32[100]{0} reduce(f32[32,100]{1,0} %reduce.5, f32[] %zero_1), dimensions={0}, to_apply=%add
|
||||
// CHECK: }
|
||||
// CHECK: %fused_computation.1 (param_0.5: f32[32,100,10000]) -> f32[32,100,2] {
|
||||
// CHECK: %param_0.5 = f32[32,100,10000]{2,1,0} parameter(0)
|
||||
// CHECK: %zero_2 = f32[] constant(0)
|
||||
// CHECK: %pad.1 = f32[32,100,16384]{2,1,0} pad(f32[32,100,10000]{2,1,0} %param_0.5, f32[] %zero_2), padding=0_0x0_0x0_6384
|
||||
// CHECK: %bitcast.1 = f32[32,100,2,8192]{3,2,1,0} bitcast(f32[32,100,16384]{2,1,0} %pad.1)
|
||||
// CHECK: ROOT %reduce.6 = f32[32,100,2]{2,1,0} reduce(f32[32,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_2), dimensions={3}, to_apply=%add
|
||||
// CHECK: %pad.1 = f32[32,100,10000]{2,1,0} pad(f32[32,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_0
|
||||
// CHECK: %bitcast.1 = f32[32,100,100,100]{3,2,1,0} bitcast(f32[32,100,10000]{2,1,0} %pad.1)
|
||||
// CHECK: ROOT %reduce.4 = f32[32,100,100]{2,1,0} reduce(f32[32,100,100,100]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3}, to_apply=%add
|
||||
// CHECK: }
|
||||
// CHECK: ENTRY %main (input: f32[32,100,10000]) -> f32[100] {
|
||||
// CHECK: %input = f32[32,100,10000]{2,1,0} parameter(0)
|
||||
// CHECK: %fusion.1 = f32[32,100,2]{2,1,0} fusion(f32[32,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation.1
|
||||
// CHECK: ROOT %fusion = f32[100]{0} fusion(f32[32,100,2]{2,1,0} %fusion.1), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %fusion = f32[32,100,100]{2,1,0} fusion(f32[32,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %zero = f32[] constant(0)
|
||||
// CHECK: %reduce.3 = f32[32,100]{1,0} reduce(f32[32,100,100]{2,1,0} %fusion, f32[] %zero), dimensions={2}, to_apply=%add
|
||||
// CHECK: ROOT %reduce.1 = f32[100]{0} reduce(f32[32,100]{1,0} %reduce.3, f32[] %zero), dimensions={0}, to_apply=%add
|
||||
// CHECK: }
|
||||
)");
|
||||
|
||||
@ -274,22 +265,22 @@ ENTRY main {
|
||||
zero = f32[] constant(0)
|
||||
ROOT out = f32[100] reduce(input, zero), dimensions={0}, to_apply=add
|
||||
}
|
||||
|
||||
)";
|
||||
|
||||
MatchOptimizedHloWithShapes(hlo_text,
|
||||
R"(
|
||||
// CHECK: %fused_computation (param_0.2: f32[10000,100]) -> f32[100] {
|
||||
// CHECK: %fused_computation (param_0.2: f32[10000,100]) -> f32[100,100] {
|
||||
// CHECK: %param_0.2 = f32[10000,100]{1,0} parameter(0)
|
||||
// CHECK: %zero_1 = f32[] constant(0)
|
||||
// CHECK: %pad.1 = f32[12288,100]{1,0} pad(f32[10000,100]{1,0} %param_0.2, f32[] %zero_1), padding=0_2288x0_0
|
||||
// CHECK: %bitcast.1 = f32[3,4096,100]{2,1,0} bitcast(f32[12288,100]{1,0} %pad.1)
|
||||
// CHECK: %reduce.3 = f32[4096,100]{1,0} reduce(f32[3,4096,100]{2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add
|
||||
// CHECK: ROOT %reduce.2 = f32[100]{0} reduce(f32[4096,100]{1,0} %reduce.3, f32[] %zero_1), dimensions={0}, to_apply=%add
|
||||
// CHECK: %pad.1 = f32[10000,100]{1,0} pad(f32[10000,100]{1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0
|
||||
// CHECK: %bitcast.1 = f32[100,100,100]{2,1,0} bitcast(f32[10000,100]{1,0} %pad.1)
|
||||
// CHECK: ROOT %reduce.2 = f32[100,100]{1,0} reduce(f32[100,100,100]{2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add
|
||||
// CHECK: }
|
||||
// CHECK: ENTRY %main (input: f32[10000,100]) -> f32[100] {
|
||||
// CHECK: %input = f32[10000,100]{1,0} parameter(0)
|
||||
// CHECK: ROOT %fusion = f32[100]{0} fusion(f32[10000,100]{1,0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %fusion = f32[100,100]{1,0} fusion(f32[10000,100]{1,0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %zero = f32[] constant(0)
|
||||
// CHECK: ROOT %reduce.1 = f32[100]{0} reduce(f32[100,100]{1,0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
|
||||
// CHECK: }
|
||||
)");
|
||||
|
||||
@ -316,17 +307,18 @@ ENTRY main {
|
||||
|
||||
MatchOptimizedHloWithShapes(hlo_text,
|
||||
R"(
|
||||
// CHECK: %fused_computation (param_0.2: f32[10000,2,2,2]) -> f32[2,2,2] {
|
||||
// CHECK: %fused_computation (param_0.2: f32[10000,2,2,2]) -> f32[100,2,2,2] {
|
||||
// CHECK: %param_0.2 = f32[10000,2,2,2]{3,2,1,0} parameter(0)
|
||||
// CHECK: %zero_1 = f32[] constant(0)
|
||||
// CHECK: %pad.1 = f32[12288,2,2,2]{3,2,1,0} pad(f32[10000,2,2,2]{3,2,1,0} %param_0.2, f32[] %zero_1), padding=0_2288x0_0x0_0x0_0
|
||||
// CHECK: %bitcast.1 = f32[3,4096,2,2,2]{4,3,2,1,0} bitcast(f32[12288,2,2,2]{3,2,1,0} %pad.1)
|
||||
// CHECK: %reduce.3 = f32[4096,2,2,2]{3,2,1,0} reduce(f32[3,4096,2,2,2]{4,3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add
|
||||
// CHECK: ROOT %reduce.2 = f32[2,2,2]{2,1,0} reduce(f32[4096,2,2,2]{3,2,1,0} %reduce.3, f32[] %zero_1), dimensions={0}, to_apply=%add
|
||||
// CHECK: %pad.1 = f32[10000,2,2,2]{3,2,1,0} pad(f32[10000,2,2,2]{3,2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_0x0_0
|
||||
// CHECK: %bitcast.1 = f32[100,100,2,2,2]{4,3,2,1,0} bitcast(f32[10000,2,2,2]{3,2,1,0} %pad.1)
|
||||
// CHECK: ROOT %reduce.2 = f32[100,2,2,2]{3,2,1,0} reduce(f32[100,100,2,2,2]{4,3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add
|
||||
// CHECK: }
|
||||
// CHECK: ENTRY %main (input: f32[10000,2,2,2]) -> f32[2,2,2] {
|
||||
// CHECK: %input = f32[10000,2,2,2]{3,2,1,0} parameter(0)
|
||||
// CHECK: ROOT %fusion = f32[2,2,2]{2,1,0} fusion(f32[10000,2,2,2]{3,2,1,0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %fusion = f32[100,2,2,2]{3,2,1,0} fusion(f32[10000,2,2,2]{3,2,1,0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %zero = f32[] constant(0)
|
||||
// CHECK: ROOT %reduce.1 = f32[2,2,2]{2,1,0} reduce(f32[100,2,2,2]{3,2,1,0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
|
||||
// CHECK: }
|
||||
)");
|
||||
|
||||
@ -355,18 +347,18 @@ ENTRY main {
|
||||
|
||||
MatchOptimizedHloWithShapes(hlo_text,
|
||||
R"(
|
||||
// CHECK: %fused_computation (param_0.2: f32[1000000,5]) -> f32[4096,5] {
|
||||
// CHECK: %fused_computation (param_0.2: f32[1000000,5]) -> f32[1000,5] {
|
||||
// CHECK: %param_0.2 = f32[1000000,5]{1,0} parameter(0)
|
||||
// CHECK: %zero_1 = f32[] constant(0)
|
||||
// CHECK: %pad.1 = f32[1003520,5]{1,0} pad(f32[1000000,5]{1,0} %param_0.2, f32[] %zero_1), padding=0_3520x0_0
|
||||
// CHECK: %bitcast.1 = f32[245,4096,5]{2,1,0} bitcast(f32[1003520,5]{1,0} %pad.1)
|
||||
// CHECK: ROOT %reduce.2 = f32[4096,5]{1,0} reduce(f32[245,4096,5]{2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add
|
||||
// CHECK: %pad.1 = f32[1000000,5]{1,0} pad(f32[1000000,5]{1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0
|
||||
// CHECK: %bitcast.1 = f32[1000,1000,5]{2,1,0} bitcast(f32[1000000,5]{1,0} %pad.1)
|
||||
// CHECK: ROOT %reduce.2 = f32[1000,5]{1,0} reduce(f32[1000,1000,5]{2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add
|
||||
// CHECK: }
|
||||
// CHECK: ENTRY %main (input: f32[1000000,5]) -> f32[5] {
|
||||
// CHECK: %input = f32[1000000,5]{1,0} parameter(0)
|
||||
// CHECK: %fusion = f32[4096,5]{1,0} fusion(f32[1000000,5]{1,0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %fusion = f32[1000,5]{1,0} fusion(f32[1000000,5]{1,0} %input), kind=kInput, calls=%fused_computation
|
||||
// CHECK: %zero = f32[] constant(0)
|
||||
// CHECK: ROOT %reduce.1 = f32[5]{0} reduce(f32[4096,5]{1,0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
|
||||
// CHECK: ROOT %reduce.1 = f32[5]{0} reduce(f32[1000,5]{1,0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
|
||||
// CHECK: }
|
||||
)");
|
||||
|
||||
|
@ -46,6 +46,11 @@ static constexpr int64 kColumnAtomicFreeBound = kWarpSize * 128;
|
||||
// decreased column/row tiling.
|
||||
static constexpr int64 kBatchedAtomicFreeBound = 8;
|
||||
|
||||
// Returns the square root of the input rounded up to the nearest square.
|
||||
static int64 SqrtOfRoundUpToNearestSquare(int64 input) {
|
||||
return static_cast<int64>(std::ceil(std::sqrt(input)));
|
||||
}
|
||||
|
||||
class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
|
||||
public:
|
||||
explicit ReductionRewriterVisitor() {}
|
||||
@ -105,26 +110,18 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
|
||||
|
||||
int64 reduced_dim_size = input_shape.dimensions(reduced_input_dimension);
|
||||
VLOG(3) << "reduced_dim_size = " << reduced_dim_size;
|
||||
// TODO(cheshire): if atomic_free_bound is large, num_fit is likely to be
|
||||
// small. Generating a reduction with very small reduced dimension is not
|
||||
// efficient, it would be better to split the dimension sizes more evenly.
|
||||
//
|
||||
// One possible idea is to pad to a nearest square (ceil(sqrt(x)))^2.
|
||||
// Given that:
|
||||
|
||||
// We pad to a nearest square (ceil(sqrt(x)))^2. Given that:
|
||||
//
|
||||
// (n + 1)^2 = n^2 + (2n+1)
|
||||
//
|
||||
// it can be seen that the distance to the nearest square is at most twice
|
||||
// the square root of the input number.
|
||||
int64 num_fit = CeilOfRatio(reduced_dim_size, atomic_free_bound);
|
||||
int64 num_fit = SqrtOfRoundUpToNearestSquare(reduced_dim_size);
|
||||
|
||||
// Pad reduced dimension to the required number of elements.
|
||||
HloInstruction *padded = [&] {
|
||||
// TODO(cheshire): if atomic_free_bound is very large, padding all the way
|
||||
// up to to atomic_free_bound is wasteful, we could pad to a much smaller
|
||||
// value.
|
||||
if (reduced_dim_size % atomic_free_bound != 0) {
|
||||
int64 padded_num_elements = num_fit * atomic_free_bound;
|
||||
int64 padded_num_elements = num_fit * num_fit;
|
||||
PaddingConfig padding_config = MakeNoPaddingConfig(input_shape.rank());
|
||||
padding_config.mutable_dimensions(reduced_input_dimension)
|
||||
->set_edge_padding_high(padded_num_elements - reduced_dim_size);
|
||||
@ -136,8 +133,6 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
|
||||
VLOG(3) << "Generated padded shape: " << padded_shape.ToString();
|
||||
return hlo->parent()->AddInstruction(HloInstruction::CreatePad(
|
||||
padded_shape, input, initial_value, padding_config));
|
||||
}
|
||||
return input;
|
||||
}();
|
||||
|
||||
VLOG(1) << "Generated padding: " << padded->ToString();
|
||||
@ -146,7 +141,7 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
|
||||
dim_idx++) {
|
||||
if (dim_idx == reduced_input_dimension) {
|
||||
reshaped_dimensions.push_back(num_fit);
|
||||
reshaped_dimensions.push_back(atomic_free_bound);
|
||||
reshaped_dimensions.push_back(num_fit);
|
||||
} else {
|
||||
reshaped_dimensions.push_back(padded->shape().dimensions(dim_idx));
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user