diff --git a/tensorflow/compiler/xla/service/gpu/tests/tree_reduction_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/tests/tree_reduction_rewriter_test.cc index c0210ff941d..eb821c36fae 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/tree_reduction_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/tree_reduction_rewriter_test.cc @@ -67,24 +67,23 @@ ENTRY main { zero = f32[] constant(0) ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add } - )"; // TODO(cheshire): a more generic check, do not hardcode the names. MatchOptimizedHloWithShapes(hlo_text, R"( -// CHECK: %fused_computation (param_0.2: f32[50000]) -> f32[7] { +// CHECK: %fused_computation (param_0.2: f32[50000]) -> f32[224] { // CHECK: %param_0.2 = f32[50000]{0} parameter(0) // CHECK: %zero_1 = f32[] constant(0) -// CHECK: %pad.1 = f32[57344]{0} pad(f32[50000]{0} %param_0.2, f32[] %zero_1), padding=0_7344 -// CHECK: %bitcast.1 = f32[7,8192]{1,0} bitcast(f32[57344]{0} %pad.1) -// CHECK: ROOT %reduce.2 = f32[7]{0} reduce(f32[7,8192]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add +// CHECK: %pad.1 = f32[50176]{0} pad(f32[50000]{0} %param_0.2, f32[] %zero_1), padding=0_176 +// CHECK: %bitcast.1 = f32[224,224]{1,0} bitcast(f32[50176]{0} %pad.1) +// CHECK: ROOT %reduce.2 = f32[224]{0} reduce(f32[224,224]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add // CHECK: } // CHECK: ENTRY %main (input: f32[50000]) -> f32[] { // CHECK: %input = f32[50000]{0} parameter(0) -// CHECK: %fusion = f32[7]{0} fusion(f32[50000]{0} %input), kind=kInput, calls=%fused_computation +// CHECK: %fusion = f32[224]{0} fusion(f32[50000]{0} %input), kind=kInput, calls=%fused_computation // CHECK: %zero = f32[] constant(0) -// CHECK: ROOT %reduce.1 = f32[] reduce(f32[7]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add +// CHECK: ROOT %reduce.1 = f32[] reduce(f32[224]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add // CHECK: } )"); @@ -107,27 +106,25 @@ 
ENTRY main { zero = f32[] constant(0) ROOT out = f32[100,100] reduce(input, zero), dimensions={2}, to_apply=add } - )"; EnsureDeterminism(hlo_text); MatchOptimizedHloWithShapes(hlo_text, R"( -// CHECK: %fused_computation (param_0.2: f32[100,100,10000]) -> f32[100,100,2] { +// CHECK: %fused_computation (param_0.2: f32[100,100,10000]) -> f32[100,100,100] { // CHECK: %param_0.2 = f32[100,100,10000]{2,1,0} parameter(0) // CHECK: %zero_1 = f32[] constant(0) -// CHECK: %pad.1 = f32[100,100,16384]{2,1,0} pad(f32[100,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_6384 -// CHECK: %bitcast.1 = f32[100,100,2,8192]{3,2,1,0} bitcast(f32[100,100,16384]{2,1,0} %pad.1) -// CHECK: ROOT %reduce.2 = f32[100,100,2]{2,1,0} reduce(f32[100,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3}, to_apply=%add +// CHECK: %pad.1 = f32[100,100,10000]{2,1,0} pad(f32[100,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_0 +// CHECK: %bitcast.1 = f32[100,100,100,100]{3,2,1,0} bitcast(f32[100,100,10000]{2,1,0} %pad.1) +// CHECK: ROOT %reduce.2 = f32[100,100,100]{2,1,0} reduce(f32[100,100,100,100]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3}, to_apply=%add // CHECK: } // CHECK: ENTRY %main (input: f32[100,100,10000]) -> f32[100,100] { // CHECK: %input = f32[100,100,10000]{2,1,0} parameter(0) -// CHECK: %fusion = f32[100,100,2]{2,1,0} fusion(f32[100,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation +// CHECK: %fusion = f32[100,100,100]{2,1,0} fusion(f32[100,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation // CHECK: %zero = f32[] constant(0) -// CHECK: ROOT %reduce.1 = f32[100,100]{1,0} reduce(f32[100,100,2]{2,1,0} %fusion, f32[] %zero), dimensions={2}, to_apply=%add +// CHECK: ROOT %reduce.1 = f32[100,100]{1,0} reduce(f32[100,100,100]{2,1,0} %fusion, f32[] %zero), dimensions={2}, to_apply=%add // CHECK: } - )"); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); @@ -149,23 +146,22 @@ ENTRY main { zero = f32[] 
constant(0) ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add } - )"; MatchOptimizedHloWithShapes(hlo_text, R"( -// CHECK: %fused_computation (param_0.2: f32[1000000]) -> f32[123] { +// CHECK: %fused_computation (param_0.2: f32[1000000]) -> f32[1000] { // CHECK: %param_0.2 = f32[1000000]{0} parameter(0) // CHECK: %zero_1 = f32[] constant(0) -// CHECK: %pad.1 = f32[1007616]{0} pad(f32[1000000]{0} %param_0.2, f32[] %zero_1), padding=0_7616 -// CHECK: %bitcast.1 = f32[123,8192]{1,0} bitcast(f32[1007616]{0} %pad.1) -// CHECK: ROOT %reduce.2 = f32[123]{0} reduce(f32[123,8192]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add +// CHECK: %pad.1 = f32[1000000]{0} pad(f32[1000000]{0} %param_0.2, f32[] %zero_1), padding=0_0 +// CHECK: %bitcast.1 = f32[1000,1000]{1,0} bitcast(f32[1000000]{0} %pad.1) +// CHECK: ROOT %reduce.2 = f32[1000]{0} reduce(f32[1000,1000]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add // CHECK: } // CHECK: ENTRY %main (input: f32[1000000]) -> f32[] { // CHECK: %input = f32[1000000]{0} parameter(0) -// CHECK: %fusion = f32[123]{0} fusion(f32[1000000]{0} %input), kind=kInput, calls=%fused_computation +// CHECK: %fusion = f32[1000]{0} fusion(f32[1000000]{0} %input), kind=kInput, calls=%fused_computation // CHECK: %zero = f32[] constant(0) -// CHECK: ROOT %reduce.1 = f32[] reduce(f32[123]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add +// CHECK: ROOT %reduce.1 = f32[] reduce(f32[1000]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add // CHECK: } )"); @@ -188,25 +184,24 @@ ENTRY main { zero = f32[] constant(0) ROOT out = f32[100] reduce(input, zero), dimensions={0,2}, to_apply=add } - )"; EnsureDeterminism(hlo_text); MatchOptimizedHloWithShapes(hlo_text, R"( -// CHECK: %fused_computation (param_0.2: f32[8,100,10000]) -> f32[100,2] { +// CHECK: %fused_computation (param_0.2: f32[8,100,10000]) -> f32[100,100] { // CHECK: %param_0.2 = f32[8,100,10000]{2,1,0} parameter(0) // CHECK: %zero_1 = f32[] 
constant(0) -// CHECK: %pad.1 = f32[8,100,16384]{2,1,0} pad(f32[8,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_6384 -// CHECK: %bitcast.1 = f32[8,100,2,8192]{3,2,1,0} bitcast(f32[8,100,16384]{2,1,0} %pad.1) -// CHECK: ROOT %reduce.2 = f32[100,2]{1,0} reduce(f32[8,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3,0}, to_apply=%add +// CHECK: %pad.1 = f32[8,100,10000]{2,1,0} pad(f32[8,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_0 +// CHECK: %bitcast.1 = f32[8,100,100,100]{3,2,1,0} bitcast(f32[8,100,10000]{2,1,0} %pad.1) +// CHECK: ROOT %reduce.2 = f32[100,100]{1,0} reduce(f32[8,100,100,100]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3,0}, to_apply=%add // CHECK: } // CHECK: ENTRY %main (input: f32[8,100,10000]) -> f32[100] { // CHECK: %input = f32[8,100,10000]{2,1,0} parameter(0) -// CHECK: %fusion = f32[100,2]{1,0} fusion(f32[8,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation +// CHECK: %fusion = f32[100,100]{1,0} fusion(f32[8,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation // CHECK: %zero = f32[] constant(0) -// CHECK: ROOT %reduce.1 = f32[100]{0} reduce(f32[100,2]{1,0} %fusion, f32[] %zero), dimensions={1}, to_apply=%add +// CHECK: ROOT %reduce.1 = f32[100]{0} reduce(f32[100,100]{1,0} %fusion, f32[] %zero), dimensions={1}, to_apply=%add // CHECK: } )"); @@ -234,23 +229,19 @@ ENTRY main { MatchOptimizedHloWithShapes(hlo_text, R"( -// CHECK: %fused_computation (param_0.4: f32[32,100,2]) -> f32[100] { -// CHECK: %param_0.4 = f32[32,100,2]{2,1,0} parameter(0) +// CHECK: %fused_computation (param_0.2: f32[32,100,10000]) -> f32[32,100,100] { +// CHECK: %param_0.2 = f32[32,100,10000]{2,1,0} parameter(0) // CHECK: %zero_1 = f32[] constant(0) -// CHECK: %reduce.5 = f32[32,100]{1,0} reduce(f32[32,100,2]{2,1,0} %param_0.4, f32[] %zero_1), dimensions={2}, to_apply=%add -// CHECK: ROOT %reduce.4 = f32[100]{0} reduce(f32[32,100]{1,0} %reduce.5, f32[] %zero_1), dimensions={0}, 
to_apply=%add -// CHECK: } -// CHECK: %fused_computation.1 (param_0.5: f32[32,100,10000]) -> f32[32,100,2] { -// CHECK: %param_0.5 = f32[32,100,10000]{2,1,0} parameter(0) -// CHECK: %zero_2 = f32[] constant(0) -// CHECK: %pad.1 = f32[32,100,16384]{2,1,0} pad(f32[32,100,10000]{2,1,0} %param_0.5, f32[] %zero_2), padding=0_0x0_0x0_6384 -// CHECK: %bitcast.1 = f32[32,100,2,8192]{3,2,1,0} bitcast(f32[32,100,16384]{2,1,0} %pad.1) -// CHECK: ROOT %reduce.6 = f32[32,100,2]{2,1,0} reduce(f32[32,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_2), dimensions={3}, to_apply=%add +// CHECK: %pad.1 = f32[32,100,10000]{2,1,0} pad(f32[32,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_0 +// CHECK: %bitcast.1 = f32[32,100,100,100]{3,2,1,0} bitcast(f32[32,100,10000]{2,1,0} %pad.1) +// CHECK: ROOT %reduce.4 = f32[32,100,100]{2,1,0} reduce(f32[32,100,100,100]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3}, to_apply=%add // CHECK: } // CHECK: ENTRY %main (input: f32[32,100,10000]) -> f32[100] { // CHECK: %input = f32[32,100,10000]{2,1,0} parameter(0) -// CHECK: %fusion.1 = f32[32,100,2]{2,1,0} fusion(f32[32,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation.1 -// CHECK: ROOT %fusion = f32[100]{0} fusion(f32[32,100,2]{2,1,0} %fusion.1), kind=kInput, calls=%fused_computation +// CHECK: %fusion = f32[32,100,100]{2,1,0} fusion(f32[32,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation +// CHECK: %zero = f32[] constant(0) +// CHECK: %reduce.3 = f32[32,100]{1,0} reduce(f32[32,100,100]{2,1,0} %fusion, f32[] %zero), dimensions={2}, to_apply=%add +// CHECK: ROOT %reduce.1 = f32[100]{0} reduce(f32[32,100]{1,0} %reduce.3, f32[] %zero), dimensions={0}, to_apply=%add // CHECK: } )"); @@ -274,22 +265,22 @@ ENTRY main { zero = f32[] constant(0) ROOT out = f32[100] reduce(input, zero), dimensions={0}, to_apply=add } - )"; MatchOptimizedHloWithShapes(hlo_text, R"( -// CHECK: %fused_computation (param_0.2: f32[10000,100]) -> f32[100] { -// CHECK: %param_0.2 = 
f32[10000,100]{1,0} parameter(0) -// CHECK: %zero_1 = f32[] constant(0) -// CHECK: %pad.1 = f32[12288,100]{1,0} pad(f32[10000,100]{1,0} %param_0.2, f32[] %zero_1), padding=0_2288x0_0 -// CHECK: %bitcast.1 = f32[3,4096,100]{2,1,0} bitcast(f32[12288,100]{1,0} %pad.1) -// CHECK: %reduce.3 = f32[4096,100]{1,0} reduce(f32[3,4096,100]{2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add -// CHECK: ROOT %reduce.2 = f32[100]{0} reduce(f32[4096,100]{1,0} %reduce.3, f32[] %zero_1), dimensions={0}, to_apply=%add +// CHECK: %fused_computation (param_0.2: f32[10000,100]) -> f32[100,100] { +// CHECK: %param_0.2 = f32[10000,100]{1,0} parameter(0) +// CHECK: %zero_1 = f32[] constant(0) +// CHECK: %pad.1 = f32[10000,100]{1,0} pad(f32[10000,100]{1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0 +// CHECK: %bitcast.1 = f32[100,100,100]{2,1,0} bitcast(f32[10000,100]{1,0} %pad.1) +// CHECK: ROOT %reduce.2 = f32[100,100]{1,0} reduce(f32[100,100,100]{2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add // CHECK: } // CHECK: ENTRY %main (input: f32[10000,100]) -> f32[100] { -// CHECK: %input = f32[10000,100]{1,0} parameter(0) -// CHECK: ROOT %fusion = f32[100]{0} fusion(f32[10000,100]{1,0} %input), kind=kInput, calls=%fused_computation +// CHECK: %input = f32[10000,100]{1,0} parameter(0) +// CHECK: %fusion = f32[100,100]{1,0} fusion(f32[10000,100]{1,0} %input), kind=kInput, calls=%fused_computation +// CHECK: %zero = f32[] constant(0) +// CHECK: ROOT %reduce.1 = f32[100]{0} reduce(f32[100,100]{1,0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add // CHECK: } )"); @@ -316,17 +307,18 @@ ENTRY main { MatchOptimizedHloWithShapes(hlo_text, R"( -// CHECK: %fused_computation (param_0.2: f32[10000,2,2,2]) -> f32[2,2,2] { -// CHECK: %param_0.2 = f32[10000,2,2,2]{3,2,1,0} parameter(0) -// CHECK: %zero_1 = f32[] constant(0) -// CHECK: %pad.1 = f32[12288,2,2,2]{3,2,1,0} pad(f32[10000,2,2,2]{3,2,1,0} %param_0.2, f32[] %zero_1), padding=0_2288x0_0x0_0x0_0 -// CHECK: 
%bitcast.1 = f32[3,4096,2,2,2]{4,3,2,1,0} bitcast(f32[12288,2,2,2]{3,2,1,0} %pad.1) -// CHECK: %reduce.3 = f32[4096,2,2,2]{3,2,1,0} reduce(f32[3,4096,2,2,2]{4,3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add -// CHECK: ROOT %reduce.2 = f32[2,2,2]{2,1,0} reduce(f32[4096,2,2,2]{3,2,1,0} %reduce.3, f32[] %zero_1), dimensions={0}, to_apply=%add +// CHECK: %fused_computation (param_0.2: f32[10000,2,2,2]) -> f32[100,2,2,2] { +// CHECK: %param_0.2 = f32[10000,2,2,2]{3,2,1,0} parameter(0) +// CHECK: %zero_1 = f32[] constant(0) +// CHECK: %pad.1 = f32[10000,2,2,2]{3,2,1,0} pad(f32[10000,2,2,2]{3,2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_0x0_0 +// CHECK: %bitcast.1 = f32[100,100,2,2,2]{4,3,2,1,0} bitcast(f32[10000,2,2,2]{3,2,1,0} %pad.1) +// CHECK: ROOT %reduce.2 = f32[100,2,2,2]{3,2,1,0} reduce(f32[100,100,2,2,2]{4,3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add // CHECK: } // CHECK: ENTRY %main (input: f32[10000,2,2,2]) -> f32[2,2,2] { -// CHECK: %input = f32[10000,2,2,2]{3,2,1,0} parameter(0) -// CHECK: ROOT %fusion = f32[2,2,2]{2,1,0} fusion(f32[10000,2,2,2]{3,2,1,0} %input), kind=kInput, calls=%fused_computation +// CHECK: %input = f32[10000,2,2,2]{3,2,1,0} parameter(0) +// CHECK: %fusion = f32[100,2,2,2]{3,2,1,0} fusion(f32[10000,2,2,2]{3,2,1,0} %input), kind=kInput, calls=%fused_computation +// CHECK: %zero = f32[] constant(0) +// CHECK: ROOT %reduce.1 = f32[2,2,2]{2,1,0} reduce(f32[100,2,2,2]{3,2,1,0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add // CHECK: } )"); @@ -355,18 +347,18 @@ ENTRY main { MatchOptimizedHloWithShapes(hlo_text, R"( -// CHECK: %fused_computation (param_0.2: f32[1000000,5]) -> f32[4096,5] { -// CHECK: %param_0.2 = f32[1000000,5]{1,0} parameter(0) -// CHECK: %zero_1 = f32[] constant(0) -// CHECK: %pad.1 = f32[1003520,5]{1,0} pad(f32[1000000,5]{1,0} %param_0.2, f32[] %zero_1), padding=0_3520x0_0 -// CHECK: %bitcast.1 = f32[245,4096,5]{2,1,0} bitcast(f32[1003520,5]{1,0} %pad.1) -// CHECK: ROOT 
%reduce.2 = f32[4096,5]{1,0} reduce(f32[245,4096,5]{2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add +// CHECK: %fused_computation (param_0.2: f32[1000000,5]) -> f32[1000,5] { +// CHECK: %param_0.2 = f32[1000000,5]{1,0} parameter(0) +// CHECK: %zero_1 = f32[] constant(0) +// CHECK: %pad.1 = f32[1000000,5]{1,0} pad(f32[1000000,5]{1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0 +// CHECK: %bitcast.1 = f32[1000,1000,5]{2,1,0} bitcast(f32[1000000,5]{1,0} %pad.1) +// CHECK: ROOT %reduce.2 = f32[1000,5]{1,0} reduce(f32[1000,1000,5]{2,1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add // CHECK: } // CHECK: ENTRY %main (input: f32[1000000,5]) -> f32[5] { -// CHECK: %input = f32[1000000,5]{1,0} parameter(0) -// CHECK: %fusion = f32[4096,5]{1,0} fusion(f32[1000000,5]{1,0} %input), kind=kInput, calls=%fused_computation -// CHECK: %zero = f32[] constant(0) -// CHECK: ROOT %reduce.1 = f32[5]{0} reduce(f32[4096,5]{1,0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add +// CHECK: %input = f32[1000000,5]{1,0} parameter(0) +// CHECK: %fusion = f32[1000,5]{1,0} fusion(f32[1000000,5]{1,0} %input), kind=kInput, calls=%fused_computation +// CHECK: %zero = f32[] constant(0) +// CHECK: ROOT %reduce.1 = f32[5]{0} reduce(f32[1000,5]{1,0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add // CHECK: } )"); diff --git a/tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.cc b/tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.cc index 5dad97dab39..e6d4569478c 100644 --- a/tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.cc @@ -46,6 +46,11 @@ static constexpr int64 kColumnAtomicFreeBound = kWarpSize * 128; // decreased column/row tiling. static constexpr int64 kBatchedAtomicFreeBound = 8; +// Returns the square root of the input rounded up to the nearest square. 
+static int64 SqrtOfRoundUpToNearestSquare(int64 input) { + return static_cast<int64>(std::ceil(std::sqrt(input))); +} + class ReductionRewriterVisitor : public DfsHloRewriteVisitor { public: explicit ReductionRewriterVisitor() {} @@ -105,39 +110,29 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { int64 reduced_dim_size = input_shape.dimensions(reduced_input_dimension); VLOG(3) << "reduced_dim_size = " << reduced_dim_size; - // TODO(cheshire): if atomic_free_bound is large, num_fit is likely to be - // small. Generating a reduction with very small reduced dimension is not - // efficient, it would be better to split the dimension sizes more evenly. - // - // One possible idea is to pad to a nearest square (ceil(sqrt(x)))^2. - // Given that: + + // We pad to a nearest square (ceil(sqrt(x)))^2. Given that: // // (n + 1)^2 = n^2 + (2n+1) // // it can be seen that the distance to the nearest square is at most twice // the square root of the input number. - int64 num_fit = CeilOfRatio(reduced_dim_size, atomic_free_bound); + int64 num_fit = SqrtOfRoundUpToNearestSquare(reduced_dim_size); // Pad reduced dimension to the required number of elements. HloInstruction *padded = [&] { - // TODO(cheshire): if atomic_free_bound is very large, padding all the way - // up to to atomic_free_bound is wasteful, we could pad to a much smaller - // value. 
- if (reduced_dim_size % atomic_free_bound != 0) { - int64 padded_num_elements = num_fit * atomic_free_bound; - PaddingConfig padding_config = MakeNoPaddingConfig(input_shape.rank()); - padding_config.mutable_dimensions(reduced_input_dimension) - ->set_edge_padding_high(padded_num_elements - reduced_dim_size); - std::vector<int64> padded_dimensions(input_shape.dimensions().begin(), - input_shape.dimensions().end()); - padded_dimensions[reduced_input_dimension] = padded_num_elements; - Shape padded_shape = - ShapeUtil::MakeShape(input_shape.element_type(), padded_dimensions); - VLOG(3) << "Generated padded shape: " << padded_shape.ToString(); - return hlo->parent()->AddInstruction(HloInstruction::CreatePad( - padded_shape, input, initial_value, padding_config)); - } - return input; + int64 padded_num_elements = num_fit * num_fit; + PaddingConfig padding_config = MakeNoPaddingConfig(input_shape.rank()); + padding_config.mutable_dimensions(reduced_input_dimension) + ->set_edge_padding_high(padded_num_elements - reduced_dim_size); + std::vector<int64> padded_dimensions(input_shape.dimensions().begin(), + input_shape.dimensions().end()); + padded_dimensions[reduced_input_dimension] = padded_num_elements; + Shape padded_shape = - ShapeUtil::MakeShape(input_shape.element_type(), padded_dimensions); + VLOG(3) << "Generated padded shape: " << padded_shape.ToString(); + return hlo->parent()->AddInstruction(HloInstruction::CreatePad( + padded_shape, input, initial_value, padding_config)); }(); VLOG(1) << "Generated padding: " << padded->ToString(); @@ -146,7 +141,7 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { dim_idx++) { if (dim_idx == reduced_input_dimension) { reshaped_dimensions.push_back(num_fit); - reshaped_dimensions.push_back(atomic_free_bound); + reshaped_dimensions.push_back(num_fit); } else { reshaped_dimensions.push_back(padded->shape().dimensions(dim_idx)); }