[XLA/GPU] Change rounding scheme for tree reduction to round up to nearest square
Previously, we were rounding up to the nearest divisor of the largest batch we could handle without introducing atomics. That leads to: - Very large padding, e.g. rounding up 8193 to 16384 - Very small dimensions of extra reduction kernels, e.g. 2 Instead, this CL uses a more "even" rounding scheme, where we round up the number to the nearest square. Nearest square is guaranteed to be within 2 * sqrt(N) of a number N, so required padding is fairly small even in the worst case. PiperOrigin-RevId: 296086172 Change-Id: I7bfa72b2309fd1e3c596d6e028a9468660f84879
This commit is contained in:
		
							parent
							
								
									9c7537daae
								
							
						
					
					
						commit
						7f48bded8a
					
				| @ -67,24 +67,23 @@ ENTRY main { | |||||||
|   zero = f32[] constant(0) |   zero = f32[] constant(0) | ||||||
|   ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add |   ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add | ||||||
| } | } | ||||||
| 
 |  | ||||||
| )"; | )"; | ||||||
| 
 | 
 | ||||||
|   // TODO(cheshire): a more generic check, do not hardcode the names.
 |   // TODO(cheshire): a more generic check, do not hardcode the names.
 | ||||||
|   MatchOptimizedHloWithShapes(hlo_text, |   MatchOptimizedHloWithShapes(hlo_text, | ||||||
|                               R"( |                               R"( | ||||||
| // CHECK: %fused_computation (param_0.2: f32[50000]) -> f32[7] {
 | // CHECK: %fused_computation (param_0.2: f32[50000]) -> f32[224] {
 | ||||||
| // CHECK:   %param_0.2 = f32[50000]{0} parameter(0)
 | // CHECK:   %param_0.2 = f32[50000]{0} parameter(0)
 | ||||||
| // CHECK:   %zero_1 = f32[] constant(0)
 | // CHECK:   %zero_1 = f32[] constant(0)
 | ||||||
| // CHECK:   %pad.1 = f32[57344]{0} pad(f32[50000]{0} %param_0.2, f32[] %zero_1), padding=0_7344
 | // CHECK:   %pad.1 = f32[50176]{0} pad(f32[50000]{0} %param_0.2, f32[] %zero_1), padding=0_176
 | ||||||
| // CHECK:   %bitcast.1 = f32[7,8192]{1,0} bitcast(f32[57344]{0} %pad.1)
 | // CHECK:   %bitcast.1 = f32[224,224]{1,0} bitcast(f32[50176]{0} %pad.1)
 | ||||||
| // CHECK:   ROOT %reduce.2 = f32[7]{0} reduce(f32[7,8192]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add
 | // CHECK:   ROOT %reduce.2 = f32[224]{0} reduce(f32[224,224]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add
 | ||||||
| // CHECK: }
 | // CHECK: }
 | ||||||
| // CHECK: ENTRY %main (input: f32[50000]) -> f32[] {
 | // CHECK: ENTRY %main (input: f32[50000]) -> f32[] {
 | ||||||
| // CHECK:   %input = f32[50000]{0} parameter(0)
 | // CHECK:   %input = f32[50000]{0} parameter(0)
 | ||||||
| // CHECK:   %fusion = f32[7]{0} fusion(f32[50000]{0} %input), kind=kInput, calls=%fused_computation
 | // CHECK:   %fusion = f32[224]{0} fusion(f32[50000]{0} %input), kind=kInput, calls=%fused_computation
 | ||||||
| // CHECK:   %zero = f32[] constant(0)
 | // CHECK:   %zero = f32[] constant(0)
 | ||||||
| // CHECK:   ROOT %reduce.1 = f32[] reduce(f32[7]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
 | // CHECK:   ROOT %reduce.1 = f32[] reduce(f32[224]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
 | ||||||
| // CHECK: }
 | // CHECK: }
 | ||||||
|       )"); |       )"); | ||||||
| 
 | 
 | ||||||
| @ -107,27 +106,25 @@ ENTRY main { | |||||||
|   zero = f32[] constant(0) |   zero = f32[] constant(0) | ||||||
|   ROOT out = f32[100,100] reduce(input, zero), dimensions={2}, to_apply=add |   ROOT out = f32[100,100] reduce(input, zero), dimensions={2}, to_apply=add | ||||||
| } | } | ||||||
| 
 |  | ||||||
| )"; | )"; | ||||||
| 
 | 
 | ||||||
|   EnsureDeterminism(hlo_text); |   EnsureDeterminism(hlo_text); | ||||||
| 
 | 
 | ||||||
|   MatchOptimizedHloWithShapes(hlo_text, |   MatchOptimizedHloWithShapes(hlo_text, | ||||||
|                               R"( |                               R"( | ||||||
| // CHECK: %fused_computation (param_0.2: f32[100,100,10000]) -> f32[100,100,2] {
 | // CHECK: %fused_computation (param_0.2: f32[100,100,10000]) -> f32[100,100,100] {
 | ||||||
| // CHECK:   %param_0.2 = f32[100,100,10000]{2,1,0} parameter(0)
 | // CHECK:   %param_0.2 = f32[100,100,10000]{2,1,0} parameter(0)
 | ||||||
| // CHECK:   %zero_1 = f32[] constant(0)
 | // CHECK:   %zero_1 = f32[] constant(0)
 | ||||||
| // CHECK:   %pad.1 = f32[100,100,16384]{2,1,0} pad(f32[100,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_6384
 | // CHECK:   %pad.1 = f32[100,100,10000]{2,1,0} pad(f32[100,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_0
 | ||||||
| // CHECK:   %bitcast.1 = f32[100,100,2,8192]{3,2,1,0} bitcast(f32[100,100,16384]{2,1,0} %pad.1)
 | // CHECK:   %bitcast.1 = f32[100,100,100,100]{3,2,1,0} bitcast(f32[100,100,10000]{2,1,0} %pad.1)
 | ||||||
| // CHECK:   ROOT %reduce.2 = f32[100,100,2]{2,1,0} reduce(f32[100,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3}, to_apply=%add
 | // CHECK:   ROOT %reduce.2 = f32[100,100,100]{2,1,0} reduce(f32[100,100,100,100]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3}, to_apply=%add
 | ||||||
| // CHECK: }
 | // CHECK: }
 | ||||||
| // CHECK: ENTRY %main (input: f32[100,100,10000]) -> f32[100,100] {
 | // CHECK: ENTRY %main (input: f32[100,100,10000]) -> f32[100,100] {
 | ||||||
| // CHECK:   %input = f32[100,100,10000]{2,1,0} parameter(0)
 | // CHECK:   %input = f32[100,100,10000]{2,1,0} parameter(0)
 | ||||||
| // CHECK:   %fusion = f32[100,100,2]{2,1,0} fusion(f32[100,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
 | // CHECK:   %fusion = f32[100,100,100]{2,1,0} fusion(f32[100,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
 | ||||||
| // CHECK:   %zero = f32[] constant(0)
 | // CHECK:   %zero = f32[] constant(0)
 | ||||||
| // CHECK:   ROOT %reduce.1 = f32[100,100]{1,0} reduce(f32[100,100,2]{2,1,0} %fusion, f32[] %zero), dimensions={2}, to_apply=%add
 | // CHECK:   ROOT %reduce.1 = f32[100,100]{1,0} reduce(f32[100,100,100]{2,1,0} %fusion, f32[] %zero), dimensions={2}, to_apply=%add
 | ||||||
| // CHECK: }
 | // CHECK: }
 | ||||||
| 
 |  | ||||||
|       )"); |       )"); | ||||||
| 
 | 
 | ||||||
|   EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); |   EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); | ||||||
| @ -149,23 +146,22 @@ ENTRY main { | |||||||
|   zero = f32[] constant(0) |   zero = f32[] constant(0) | ||||||
|   ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add |   ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add | ||||||
| } | } | ||||||
| 
 |  | ||||||
| )"; | )"; | ||||||
| 
 | 
 | ||||||
|   MatchOptimizedHloWithShapes(hlo_text, |   MatchOptimizedHloWithShapes(hlo_text, | ||||||
|                               R"( |                               R"( | ||||||
| // CHECK: %fused_computation (param_0.2: f32[1000000]) -> f32[123] {
 | // CHECK: %fused_computation (param_0.2: f32[1000000]) -> f32[1000] {
 | ||||||
| // CHECK:   %param_0.2 = f32[1000000]{0} parameter(0)
 | // CHECK:   %param_0.2 = f32[1000000]{0} parameter(0)
 | ||||||
| // CHECK:   %zero_1 = f32[] constant(0)
 | // CHECK:   %zero_1 = f32[] constant(0)
 | ||||||
| // CHECK:   %pad.1 = f32[1007616]{0} pad(f32[1000000]{0} %param_0.2, f32[] %zero_1), padding=0_7616
 | // CHECK:   %pad.1 = f32[1000000]{0} pad(f32[1000000]{0} %param_0.2, f32[] %zero_1), padding=0_0
 | ||||||
| // CHECK:   %bitcast.1 = f32[123,8192]{1,0} bitcast(f32[1007616]{0} %pad.1)
 | // CHECK:   %bitcast.1 = f32[1000,1000]{1,0} bitcast(f32[1000000]{0} %pad.1)
 | ||||||
| // CHECK:   ROOT %reduce.2 = f32[123]{0} reduce(f32[123,8192]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add
 | // CHECK:   ROOT %reduce.2 = f32[1000]{0} reduce(f32[1000,1000]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add
 | ||||||
| // CHECK: }
 | // CHECK: }
 | ||||||
| // CHECK: ENTRY %main (input: f32[1000000]) -> f32[] {
 | // CHECK: ENTRY %main (input: f32[1000000]) -> f32[] {
 | ||||||
| // CHECK:   %input = f32[1000000]{0} parameter(0)
 | // CHECK:   %input = f32[1000000]{0} parameter(0)
 | ||||||
| // CHECK:   %fusion = f32[123]{0} fusion(f32[1000000]{0} %input), kind=kInput, calls=%fused_computation
 | // CHECK:   %fusion = f32[1000]{0} fusion(f32[1000000]{0} %input), kind=kInput, calls=%fused_computation
 | ||||||
| // CHECK:   %zero = f32[] constant(0)
 | // CHECK:   %zero = f32[] constant(0)
 | ||||||
| // CHECK:   ROOT %reduce.1 = f32[] reduce(f32[123]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
 | // CHECK:   ROOT %reduce.1 = f32[] reduce(f32[1000]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
 | ||||||
| // CHECK: }
 | // CHECK: }
 | ||||||
|       )"); |       )"); | ||||||
| 
 | 
 | ||||||
| @ -188,25 +184,24 @@ ENTRY main { | |||||||
|   zero = f32[] constant(0) |   zero = f32[] constant(0) | ||||||
|   ROOT out = f32[100] reduce(input, zero), dimensions={0,2}, to_apply=add |   ROOT out = f32[100] reduce(input, zero), dimensions={0,2}, to_apply=add | ||||||
| } | } | ||||||
| 
 |  | ||||||
| )"; | )"; | ||||||
| 
 | 
 | ||||||
|   EnsureDeterminism(hlo_text); |   EnsureDeterminism(hlo_text); | ||||||
| 
 | 
 | ||||||
|   MatchOptimizedHloWithShapes(hlo_text, |   MatchOptimizedHloWithShapes(hlo_text, | ||||||
|                               R"( |                               R"( | ||||||
| // CHECK: %fused_computation (param_0.2: f32[8,100,10000]) -> f32[100,2] {
 | // CHECK: %fused_computation (param_0.2: f32[8,100,10000]) -> f32[100,100] {
 | ||||||
| // CHECK:   %param_0.2 = f32[8,100,10000]{2,1,0} parameter(0)
 | // CHECK:   %param_0.2 = f32[8,100,10000]{2,1,0} parameter(0)
 | ||||||
| // CHECK:   %zero_1 = f32[] constant(0)
 | // CHECK:   %zero_1 = f32[] constant(0)
 | ||||||
| // CHECK:   %pad.1 = f32[8,100,16384]{2,1,0} pad(f32[8,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_6384
 | // CHECK:   %pad.1 = f32[8,100,10000]{2,1,0} pad(f32[8,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_0
 | ||||||
| // CHECK:   %bitcast.1 = f32[8,100,2,8192]{3,2,1,0} bitcast(f32[8,100,16384]{2,1,0} %pad.1)
 | // CHECK:   %bitcast.1 = f32[8,100,100,100]{3,2,1,0} bitcast(f32[8,100,10000]{2,1,0} %pad.1)
 | ||||||
| // CHECK:   ROOT %reduce.2 = f32[100,2]{1,0} reduce(f32[8,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3,0}, to_apply=%add
 | // CHECK:   ROOT %reduce.2 = f32[100,100]{1,0} reduce(f32[8,100,100,100]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3,0}, to_apply=%add
 | ||||||
| // CHECK: }
 | // CHECK: }
 | ||||||
| // CHECK: ENTRY %main (input: f32[8,100,10000]) -> f32[100] {
 | // CHECK: ENTRY %main (input: f32[8,100,10000]) -> f32[100] {
 | ||||||
| // CHECK:   %input = f32[8,100,10000]{2,1,0} parameter(0)
 | // CHECK:   %input = f32[8,100,10000]{2,1,0} parameter(0)
 | ||||||
| // CHECK:   %fusion = f32[100,2]{1,0} fusion(f32[8,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
 | // CHECK:   %fusion = f32[100,100]{1,0} fusion(f32[8,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
 | ||||||
| // CHECK:   %zero = f32[] constant(0)
 | // CHECK:   %zero = f32[] constant(0)
 | ||||||
| // CHECK:   ROOT %reduce.1 = f32[100]{0} reduce(f32[100,2]{1,0} %fusion, f32[] %zero), dimensions={1}, to_apply=%add
 | // CHECK:   ROOT %reduce.1 = f32[100]{0} reduce(f32[100,100]{1,0} %fusion, f32[] %zero), dimensions={1}, to_apply=%add
 | ||||||
| // CHECK: }
 | // CHECK: }
 | ||||||
|       )"); |       )"); | ||||||
| 
 | 
 | ||||||
| @ -234,23 +229,19 @@ ENTRY main { | |||||||
| 
 | 
 | ||||||
|   MatchOptimizedHloWithShapes(hlo_text, |   MatchOptimizedHloWithShapes(hlo_text, | ||||||
|                               R"( |                               R"( | ||||||
| // CHECK: %fused_computation (param_0.4: f32[32,100,2]) -> f32[100] {
 | // CHECK: %fused_computation (param_0.2: f32[32,100,10000]) -> f32[32,100,100] {
 | ||||||
| // CHECK:   %param_0.4 = f32[32,100,2]{2,1,0} parameter(0)
 | // CHECK:   %param_0.2 = f32[32,100,10000]{2,1,0} parameter(0)
 | ||||||
| // CHECK:   %zero_1 = f32[] constant(0)
 | // CHECK:   %zero_1 = f32[] constant(0)
 | ||||||
| // CHECK:   %reduce.5 = f32[32,100]{1,0} reduce(f32[32,100,2]{2,1,0} %param_0.4, f32[] %zero_1), dimensions={2}, to_apply=%add
 | // CHECK:   %pad.1 = f32[32,100,10000]{2,1,0} pad(f32[32,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_0
 | ||||||
| // CHECK:   ROOT %reduce.4 = f32[100]{0} reduce(f32[32,100]{1,0} %reduce.5, f32[] %zero_1), dimensions={0}, to_apply=%add
 | // CHECK:   %bitcast.1 = f32[32,100,100,100]{3,2,1,0} bitcast(f32[32,100,10000]{2,1,0} %pad.1)
 | ||||||
| // CHECK: }
 | // CHECK:   ROOT %reduce.4 = f32[32,100,100]{2,1,0} reduce(f32[32,100,100,100]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3}, to_apply=%add
 | ||||||
| // CHECK: %fused_computation.1 (param_0.5: f32[32,100,10000]) -> f32[32,100,2] {
 |  | ||||||
| // CHECK:   %param_0.5 = f32[32,100,10000]{2,1,0} parameter(0)
 |  | ||||||
| // CHECK:   %zero_2 = f32[] constant(0)
 |  | ||||||
| // CHECK:   %pad.1 = f32[32,100,16384]{2,1,0} pad(f32[32,100,10000]{2,1,0} %param_0.5, f32[] %zero_2), padding=0_0x0_0x0_6384
 |  | ||||||
| // CHECK:   %bitcast.1 = f32[32,100,2,8192]{3,2,1,0} bitcast(f32[32,100,16384]{2,1,0} %pad.1)
 |  | ||||||
| // CHECK:   ROOT %reduce.6 = f32[32,100,2]{2,1,0} reduce(f32[32,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_2), dimensions={3}, to_apply=%add
 |  | ||||||
| // CHECK: }
 | // CHECK: }
 | ||||||
| // CHECK: ENTRY %main (input: f32[32,100,10000]) -> f32[100] {
 | // CHECK: ENTRY %main (input: f32[32,100,10000]) -> f32[100] {
 | ||||||
| // CHECK:   %input = f32[32,100,10000]{2,1,0} parameter(0)
 | // CHECK:   %input = f32[32,100,10000]{2,1,0} parameter(0)
 | ||||||
| // CHECK:   %fusion.1 = f32[32,100,2]{2,1,0} fusion(f32[32,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation.1
 | // CHECK:   %fusion = f32[32,100,100]{2,1,0} fusion(f32[32,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
 | ||||||
| // CHECK:   ROOT %fusion = f32[100]{0} fusion(f32[32,100,2]{2,1,0} %fusion.1), kind=kInput, calls=%fused_computation
 | // CHECK:   %zero = f32[] constant(0)
 | ||||||
|  | // CHECK:   %reduce.3 = f32[32,100]{1,0} reduce(f32[32,100,100]{2,1,0} %fusion, f32[] %zero), dimensions={2}, to_apply=%add
 | ||||||
|  | // CHECK:   ROOT %reduce.1 = f32[100]{0} reduce(f32[32,100]{1,0} %reduce.3, f32[] %zero), dimensions={0}, to_apply=%add
 | ||||||
| // CHECK: }
 | // CHECK: }
 | ||||||
|       )"); |       )"); | ||||||
| 
 | 
 | ||||||
| @ -274,22 +265,22 @@ ENTRY main { | |||||||
|   zero = f32[] constant(0) |   zero = f32[] constant(0) | ||||||
|   ROOT out = f32[100] reduce(input, zero), dimensions={0}, to_apply=add |   ROOT out = f32[100] reduce(input, zero), dimensions={0}, to_apply=add | ||||||
| } | } | ||||||
| 
 |  | ||||||
| )"; | )"; | ||||||
| 
 | 
 | ||||||
|   MatchOptimizedHloWithShapes(hlo_text, |   MatchOptimizedHloWithShapes(hlo_text, | ||||||
|                               R"( |                               R"( | ||||||
| // CHECK: %fused_computation (param_0.2: f32[10000,100]) -> f32[100] {
 | // CHECK: %fused_computation (param_0.2: f32[10000,100]) -> f32[100,100] {
 | ||||||
| // CHECK:  %param_0.2 = f32[10000,100]{1,0} parameter(0)
 | // CHECK:   %param_0.2 = f32[10000,100]{1,0} parameter(0)
 | ||||||
| // CHECK:  %zero_1 = f32[] constant(0)
 | // CHECK:   %zero_1 = f32[] constant(0)
 | ||||||
| // CHECK:  %pad.1 = f32[12288,100]{1,0} pad(f32[10000,100]{1,0} %param_0.2, f32[] %zero_1), padding=0_2288x0_0
 | // CHECK:   %pad.1 = f32[10000,100]{1,0} pad(f32[10000,100]{1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0
 | ||||||
| // CHECK:  %bitcast.1 = f32[3,4096,100]{2,1,0} bitcast(f32[12288,100]{1,0} %pad.1)
 | // CHECK:   %bitcast.1 = f32[100,100,100]{2,1,0} bitcast(f32[10000,100]{1,0} %pad.1)
 | ||||||
| // CHECK:  %reduce.3 = f32[4096,100]{1,0} reduce(f32[3,4096,100]{2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add
 | // CHECK:   ROOT %reduce.2 = f32[100,100]{1,0} reduce(f32[100,100,100]{2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add
 | ||||||
| // CHECK:  ROOT %reduce.2 = f32[100]{0} reduce(f32[4096,100]{1,0} %reduce.3, f32[] %zero_1), dimensions={0}, to_apply=%add
 |  | ||||||
| // CHECK: }
 | // CHECK: }
 | ||||||
| // CHECK: ENTRY %main (input: f32[10000,100]) -> f32[100] {
 | // CHECK: ENTRY %main (input: f32[10000,100]) -> f32[100] {
 | ||||||
| // CHECK:  %input = f32[10000,100]{1,0} parameter(0)
 | // CHECK:   %input = f32[10000,100]{1,0} parameter(0)
 | ||||||
| // CHECK:  ROOT %fusion = f32[100]{0} fusion(f32[10000,100]{1,0} %input), kind=kInput, calls=%fused_computation
 | // CHECK:   %fusion = f32[100,100]{1,0} fusion(f32[10000,100]{1,0} %input), kind=kInput, calls=%fused_computation
 | ||||||
|  | // CHECK:   %zero = f32[] constant(0)
 | ||||||
|  | // CHECK:   ROOT %reduce.1 = f32[100]{0} reduce(f32[100,100]{1,0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
 | ||||||
| // CHECK: }
 | // CHECK: }
 | ||||||
|       )"); |       )"); | ||||||
| 
 | 
 | ||||||
| @ -316,17 +307,18 @@ ENTRY main { | |||||||
| 
 | 
 | ||||||
|   MatchOptimizedHloWithShapes(hlo_text, |   MatchOptimizedHloWithShapes(hlo_text, | ||||||
|                               R"( |                               R"( | ||||||
| // CHECK: %fused_computation (param_0.2: f32[10000,2,2,2]) -> f32[2,2,2] {
 | // CHECK: %fused_computation (param_0.2: f32[10000,2,2,2]) -> f32[100,2,2,2] {
 | ||||||
| // CHECK:  %param_0.2 = f32[10000,2,2,2]{3,2,1,0} parameter(0)
 | // CHECK:   %param_0.2 = f32[10000,2,2,2]{3,2,1,0} parameter(0)
 | ||||||
| // CHECK:  %zero_1 = f32[] constant(0)
 | // CHECK:   %zero_1 = f32[] constant(0)
 | ||||||
| // CHECK:  %pad.1 = f32[12288,2,2,2]{3,2,1,0} pad(f32[10000,2,2,2]{3,2,1,0} %param_0.2, f32[] %zero_1), padding=0_2288x0_0x0_0x0_0
 | // CHECK:   %pad.1 = f32[10000,2,2,2]{3,2,1,0} pad(f32[10000,2,2,2]{3,2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_0x0_0
 | ||||||
| // CHECK:  %bitcast.1 = f32[3,4096,2,2,2]{4,3,2,1,0} bitcast(f32[12288,2,2,2]{3,2,1,0} %pad.1)
 | // CHECK:   %bitcast.1 = f32[100,100,2,2,2]{4,3,2,1,0} bitcast(f32[10000,2,2,2]{3,2,1,0} %pad.1)
 | ||||||
| // CHECK:  %reduce.3 = f32[4096,2,2,2]{3,2,1,0} reduce(f32[3,4096,2,2,2]{4,3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add
 | // CHECK:   ROOT %reduce.2 = f32[100,2,2,2]{3,2,1,0} reduce(f32[100,100,2,2,2]{4,3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add
 | ||||||
| // CHECK:  ROOT %reduce.2 = f32[2,2,2]{2,1,0} reduce(f32[4096,2,2,2]{3,2,1,0} %reduce.3, f32[] %zero_1), dimensions={0}, to_apply=%add
 |  | ||||||
| // CHECK: }
 | // CHECK: }
 | ||||||
| // CHECK: ENTRY %main (input: f32[10000,2,2,2]) -> f32[2,2,2] {
 | // CHECK: ENTRY %main (input: f32[10000,2,2,2]) -> f32[2,2,2] {
 | ||||||
| // CHECK:  %input = f32[10000,2,2,2]{3,2,1,0} parameter(0)
 | // CHECK:   %input = f32[10000,2,2,2]{3,2,1,0} parameter(0)
 | ||||||
| // CHECK:  ROOT %fusion = f32[2,2,2]{2,1,0} fusion(f32[10000,2,2,2]{3,2,1,0} %input), kind=kInput, calls=%fused_computation
 | // CHECK:   %fusion = f32[100,2,2,2]{3,2,1,0} fusion(f32[10000,2,2,2]{3,2,1,0} %input), kind=kInput, calls=%fused_computation
 | ||||||
|  | // CHECK:   %zero = f32[] constant(0)
 | ||||||
|  | // CHECK:   ROOT %reduce.1 = f32[2,2,2]{2,1,0} reduce(f32[100,2,2,2]{3,2,1,0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
 | ||||||
| // CHECK: }
 | // CHECK: }
 | ||||||
|       )"); |       )"); | ||||||
| 
 | 
 | ||||||
| @ -355,18 +347,18 @@ ENTRY main { | |||||||
| 
 | 
 | ||||||
|   MatchOptimizedHloWithShapes(hlo_text, |   MatchOptimizedHloWithShapes(hlo_text, | ||||||
|                               R"( |                               R"( | ||||||
| // CHECK: %fused_computation (param_0.2: f32[1000000,5]) -> f32[4096,5] {
 | // CHECK: %fused_computation (param_0.2: f32[1000000,5]) -> f32[1000,5] {
 | ||||||
| // CHECK:  %param_0.2 = f32[1000000,5]{1,0} parameter(0)
 | // CHECK:   %param_0.2 = f32[1000000,5]{1,0} parameter(0)
 | ||||||
| // CHECK:  %zero_1 = f32[] constant(0)
 | // CHECK:   %zero_1 = f32[] constant(0)
 | ||||||
| // CHECK:  %pad.1 = f32[1003520,5]{1,0} pad(f32[1000000,5]{1,0} %param_0.2, f32[] %zero_1), padding=0_3520x0_0
 | // CHECK:   %pad.1 = f32[1000000,5]{1,0} pad(f32[1000000,5]{1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0
 | ||||||
| // CHECK:  %bitcast.1 = f32[245,4096,5]{2,1,0} bitcast(f32[1003520,5]{1,0} %pad.1)
 | // CHECK:   %bitcast.1 = f32[1000,1000,5]{2,1,0} bitcast(f32[1000000,5]{1,0} %pad.1)
 | ||||||
| // CHECK:  ROOT %reduce.2 = f32[4096,5]{1,0} reduce(f32[245,4096,5]{2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add
 | // CHECK:   ROOT %reduce.2 = f32[1000,5]{1,0} reduce(f32[1000,1000,5]{2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add
 | ||||||
| // CHECK: }
 | // CHECK: }
 | ||||||
| // CHECK: ENTRY %main (input: f32[1000000,5]) -> f32[5] {
 | // CHECK: ENTRY %main (input: f32[1000000,5]) -> f32[5] {
 | ||||||
| // CHECK:  %input = f32[1000000,5]{1,0} parameter(0)
 | // CHECK:   %input = f32[1000000,5]{1,0} parameter(0)
 | ||||||
| // CHECK:  %fusion = f32[4096,5]{1,0} fusion(f32[1000000,5]{1,0} %input), kind=kInput, calls=%fused_computation
 | // CHECK:   %fusion = f32[1000,5]{1,0} fusion(f32[1000000,5]{1,0} %input), kind=kInput, calls=%fused_computation
 | ||||||
| // CHECK:  %zero = f32[] constant(0)
 | // CHECK:   %zero = f32[] constant(0)
 | ||||||
| // CHECK:  ROOT %reduce.1 = f32[5]{0} reduce(f32[4096,5]{1,0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
 | // CHECK:   ROOT %reduce.1 = f32[5]{0} reduce(f32[1000,5]{1,0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
 | ||||||
| // CHECK: }
 | // CHECK: }
 | ||||||
|       )"); |       )"); | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -46,6 +46,11 @@ static constexpr int64 kColumnAtomicFreeBound = kWarpSize * 128; | |||||||
| // decreased column/row tiling.
 | // decreased column/row tiling.
 | ||||||
| static constexpr int64 kBatchedAtomicFreeBound = 8; | static constexpr int64 kBatchedAtomicFreeBound = 8; | ||||||
| 
 | 
 | ||||||
|  | // Returns the square root of the input rounded up to the nearest square.
 | ||||||
|  | static int64 SqrtOfRoundUpToNearestSquare(int64 input) { | ||||||
|  |   return static_cast<int64>(std::ceil(std::sqrt(input))); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| class ReductionRewriterVisitor : public DfsHloRewriteVisitor { | class ReductionRewriterVisitor : public DfsHloRewriteVisitor { | ||||||
|  public: |  public: | ||||||
|   explicit ReductionRewriterVisitor() {} |   explicit ReductionRewriterVisitor() {} | ||||||
| @ -105,39 +110,29 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { | |||||||
| 
 | 
 | ||||||
|     int64 reduced_dim_size = input_shape.dimensions(reduced_input_dimension); |     int64 reduced_dim_size = input_shape.dimensions(reduced_input_dimension); | ||||||
|     VLOG(3) << "reduced_dim_size = " << reduced_dim_size; |     VLOG(3) << "reduced_dim_size = " << reduced_dim_size; | ||||||
|     // TODO(cheshire): if atomic_free_bound is large, num_fit is likely to be
 | 
 | ||||||
|     // small. Generating a reduction with very small reduced dimension is not
 |     // We pad to a nearest square (ceil(sqrt(x)))^2.  Given that:
 | ||||||
|     // efficient, it would be better to split the dimension sizes more evenly.
 |  | ||||||
|     //
 |  | ||||||
|     // One possible idea is to pad to a nearest square (ceil(sqrt(x)))^2.
 |  | ||||||
|     // Given that:
 |  | ||||||
|     //
 |     //
 | ||||||
|     // (n + 1)^2 = n^2 + (2n+1)
 |     // (n + 1)^2 = n^2 + (2n+1)
 | ||||||
|     //
 |     //
 | ||||||
|     // it can be seen that the distance to the nearest square is at most twice
 |     // it can be seen that the distance to the nearest square is at most twice
 | ||||||
|     // the square root of the input number.
 |     // the square root of the input number.
 | ||||||
|     int64 num_fit = CeilOfRatio(reduced_dim_size, atomic_free_bound); |     int64 num_fit = SqrtOfRoundUpToNearestSquare(reduced_dim_size); | ||||||
| 
 | 
 | ||||||
|     // Pad reduced dimension to the required number of elements.
 |     // Pad reduced dimension to the required number of elements.
 | ||||||
|     HloInstruction *padded = [&] { |     HloInstruction *padded = [&] { | ||||||
|       // TODO(cheshire): if atomic_free_bound is very large, padding all the way
 |       int64 padded_num_elements = num_fit * num_fit; | ||||||
|       // up to to atomic_free_bound is wasteful, we could pad to a much smaller
 |       PaddingConfig padding_config = MakeNoPaddingConfig(input_shape.rank()); | ||||||
|       // value.
 |       padding_config.mutable_dimensions(reduced_input_dimension) | ||||||
|       if (reduced_dim_size % atomic_free_bound != 0) { |           ->set_edge_padding_high(padded_num_elements - reduced_dim_size); | ||||||
|         int64 padded_num_elements = num_fit * atomic_free_bound; |       std::vector<int64> padded_dimensions(input_shape.dimensions().begin(), | ||||||
|         PaddingConfig padding_config = MakeNoPaddingConfig(input_shape.rank()); |                                            input_shape.dimensions().end()); | ||||||
|         padding_config.mutable_dimensions(reduced_input_dimension) |       padded_dimensions[reduced_input_dimension] = padded_num_elements; | ||||||
|             ->set_edge_padding_high(padded_num_elements - reduced_dim_size); |       Shape padded_shape = | ||||||
|         std::vector<int64> padded_dimensions(input_shape.dimensions().begin(), |           ShapeUtil::MakeShape(input_shape.element_type(), padded_dimensions); | ||||||
|                                              input_shape.dimensions().end()); |       VLOG(3) << "Generated padded shape: " << padded_shape.ToString(); | ||||||
|         padded_dimensions[reduced_input_dimension] = padded_num_elements; |       return hlo->parent()->AddInstruction(HloInstruction::CreatePad( | ||||||
|         Shape padded_shape = |           padded_shape, input, initial_value, padding_config)); | ||||||
|             ShapeUtil::MakeShape(input_shape.element_type(), padded_dimensions); |  | ||||||
|         VLOG(3) << "Generated padded shape: " << padded_shape.ToString(); |  | ||||||
|         return hlo->parent()->AddInstruction(HloInstruction::CreatePad( |  | ||||||
|             padded_shape, input, initial_value, padding_config)); |  | ||||||
|       } |  | ||||||
|       return input; |  | ||||||
|     }(); |     }(); | ||||||
| 
 | 
 | ||||||
|     VLOG(1) << "Generated padding: " << padded->ToString(); |     VLOG(1) << "Generated padding: " << padded->ToString(); | ||||||
| @ -146,7 +141,7 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { | |||||||
|          dim_idx++) { |          dim_idx++) { | ||||||
|       if (dim_idx == reduced_input_dimension) { |       if (dim_idx == reduced_input_dimension) { | ||||||
|         reshaped_dimensions.push_back(num_fit); |         reshaped_dimensions.push_back(num_fit); | ||||||
|         reshaped_dimensions.push_back(atomic_free_bound); |         reshaped_dimensions.push_back(num_fit); | ||||||
|       } else { |       } else { | ||||||
|         reshaped_dimensions.push_back(padded->shape().dimensions(dim_idx)); |         reshaped_dimensions.push_back(padded->shape().dimensions(dim_idx)); | ||||||
|       } |       } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user