diff --git a/tensorflow/compiler/xla/service/gpu/tests/tree_reduction_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/tests/tree_reduction_rewriter_test.cc index 96560cd86aa..c0210ff941d 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/tree_reduction_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/tree_reduction_rewriter_test.cc @@ -73,17 +73,18 @@ ENTRY main { // TODO(cheshire): a more generic check, do not hardcode the names. MatchOptimizedHloWithShapes(hlo_text, R"( -// CHECK: %fused_computation (param_0.2: f32[50000]) -> f32[] { +// CHECK: %fused_computation (param_0.2: f32[50000]) -> f32[7] { // CHECK: %param_0.2 = f32[50000]{0} parameter(0) // CHECK: %zero_1 = f32[] constant(0) -// CHECK: %pad.1 = f32[65536]{0} pad(f32[50000]{0} %param_0.2, f32[] %zero_1), padding=0_15536 -// CHECK: %bitcast.1 = f32[4,16384]{1,0} bitcast(f32[65536]{0} %pad.1) -// CHECK: %reduce.3 = f32[16384]{0} reduce(f32[4,16384]{1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add -// CHECK: ROOT %reduce.2 = f32[] reduce(f32[16384]{0} %reduce.3, f32[] %zero_1), dimensions={0}, to_apply=%add +// CHECK: %pad.1 = f32[57344]{0} pad(f32[50000]{0} %param_0.2, f32[] %zero_1), padding=0_7344 +// CHECK: %bitcast.1 = f32[7,8192]{1,0} bitcast(f32[57344]{0} %pad.1) +// CHECK: ROOT %reduce.2 = f32[7]{0} reduce(f32[7,8192]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add // CHECK: } // CHECK: ENTRY %main (input: f32[50000]) -> f32[] { // CHECK: %input = f32[50000]{0} parameter(0) -// CHECK: ROOT %fusion = f32[] fusion(f32[50000]{0} %input), kind=kInput, calls=%fused_computation +// CHECK: %fusion = f32[7]{0} fusion(f32[50000]{0} %input), kind=kInput, calls=%fused_computation +// CHECK: %zero = f32[] constant(0) +// CHECK: ROOT %reduce.1 = f32[] reduce(f32[7]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add // CHECK: } )"); @@ -113,18 +114,20 @@ ENTRY main { MatchOptimizedHloWithShapes(hlo_text, R"( -// CHECK: %fused_computation (param_0.2: f32[100,100,10000]) -> f32[100,100] { +// CHECK: %fused_computation (param_0.2: f32[100,100,10000]) -> f32[100,100,2] { // CHECK: %param_0.2 = f32[100,100,10000]{2,1,0} parameter(0) // CHECK: %zero_1 = f32[] constant(0) // CHECK: %pad.1 = f32[100,100,16384]{2,1,0} pad(f32[100,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_6384 // CHECK: %bitcast.1 = f32[100,100,2,8192]{3,2,1,0} bitcast(f32[100,100,16384]{2,1,0} %pad.1) -// CHECK: %reduce.3 = f32[100,100,8192]{2,1,0} reduce(f32[100,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={2}, to_apply=%add -// CHECK: ROOT %reduce.2 = f32[100,100]{1,0} reduce(f32[100,100,8192]{2,1,0} %reduce.3, f32[] %zero_1), dimensions={2}, to_apply=%add +// CHECK: ROOT %reduce.2 = f32[100,100,2]{2,1,0} reduce(f32[100,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3}, to_apply=%add // CHECK: } // CHECK: ENTRY %main (input: f32[100,100,10000]) -> f32[100,100] { // CHECK: %input = f32[100,100,10000]{2,1,0} parameter(0) -// CHECK: ROOT %fusion = f32[100,100]{1,0} fusion(f32[100,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation +// CHECK: %fusion = f32[100,100,2]{2,1,0} fusion(f32[100,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation +// CHECK: %zero = f32[] constant(0) +// CHECK: ROOT %reduce.1 = f32[100,100]{1,0} reduce(f32[100,100,2]{2,1,0} %fusion, f32[] %zero), dimensions={2}, to_apply=%add // CHECK: } + )"); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); @@ -151,18 +154,18 @@ ENTRY main { MatchOptimizedHloWithShapes(hlo_text, R"( -// CHECK: %fused_computation (param_0.2: f32[1000000]) -> f32[16384] { +// CHECK: %fused_computation (param_0.2: f32[1000000]) -> f32[123] { // CHECK: %param_0.2 = f32[1000000]{0} parameter(0) // CHECK: %zero_1 = f32[] constant(0) -// CHECK: %pad.1 = f32[1015808]{0} pad(f32[1000000]{0} %param_0.2, f32[] %zero_1), padding=0_15808 -// CHECK: %bitcast.1 = f32[62,16384]{1,0} bitcast(f32[1015808]{0} %pad.1) -// CHECK: ROOT %reduce.2 = f32[16384]{0} reduce(f32[62,16384]{1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add +// CHECK: %pad.1 = f32[1007616]{0} pad(f32[1000000]{0} %param_0.2, f32[] %zero_1), padding=0_7616 +// CHECK: %bitcast.1 = f32[123,8192]{1,0} bitcast(f32[1007616]{0} %pad.1) +// CHECK: ROOT %reduce.2 = f32[123]{0} reduce(f32[123,8192]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add // CHECK: } // CHECK: ENTRY %main (input: f32[1000000]) -> f32[] { // CHECK: %input = f32[1000000]{0} parameter(0) -// CHECK: %fusion = f32[16384]{0} fusion(f32[1000000]{0} %input), kind=kInput, calls=%fused_computation +// CHECK: %fusion = f32[123]{0} fusion(f32[1000000]{0} %input), kind=kInput, calls=%fused_computation // CHECK: %zero = f32[] constant(0) -// CHECK: ROOT %reduce.1 = f32[] reduce(f32[16384]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add +// CHECK: ROOT %reduce.1 = f32[] reduce(f32[123]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add // CHECK: } )"); @@ -192,17 +195,18 @@ ENTRY main { MatchOptimizedHloWithShapes(hlo_text, R"( -// CHECK: %fused_computation (param_0.2: f32[8,100,10000]) -> f32[100] { -// CHECK: %param_0.2 = f32[8,100,10000]{2,1,0} parameter(0) -// CHECK: %zero_1 = f32[] constant(0) -// CHECK: %pad.1 = f32[8,100,16384]{2,1,0} pad(f32[8,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_6384 -// CHECK: %bitcast.1 = f32[8,100,2,8192]{3,2,1,0} bitcast(f32[8,100,16384]{2,1,0} %pad.1) -// CHECK: %reduce.3 = f32[100,8192]{1,0} reduce(f32[8,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={2,0}, to_apply=%add -// CHECK: ROOT %reduce.2 = f32[100]{0} reduce(f32[100,8192]{1,0} %reduce.3, f32[] %zero_1), dimensions={1}, to_apply=%add +// CHECK: %fused_computation (param_0.2: f32[8,100,10000]) -> f32[100,2] { +// CHECK: %param_0.2 = f32[8,100,10000]{2,1,0} parameter(0) +// CHECK: %zero_1 = f32[] constant(0) +// CHECK: %pad.1 = f32[8,100,16384]{2,1,0} pad(f32[8,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_6384 +// CHECK: %bitcast.1 = f32[8,100,2,8192]{3,2,1,0} bitcast(f32[8,100,16384]{2,1,0} %pad.1) +// CHECK: ROOT %reduce.2 = f32[100,2]{1,0} reduce(f32[8,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3,0}, to_apply=%add // CHECK: } // CHECK: ENTRY %main (input: f32[8,100,10000]) -> f32[100] { -// CHECK: %input = f32[8,100,10000]{2,1,0} parameter(0) -// CHECK: ROOT %fusion = f32[100]{0} fusion(f32[8,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation +// CHECK: %input = f32[8,100,10000]{2,1,0} parameter(0) +// CHECK: %fusion = f32[100,2]{1,0} fusion(f32[8,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation +// CHECK: %zero = f32[] constant(0) +// CHECK: ROOT %reduce.1 = f32[100]{0} reduce(f32[100,2]{1,0} %fusion, f32[] %zero), dimensions={1}, to_apply=%add // CHECK: } )"); @@ -210,7 +214,6 @@ ENTRY main { } TEST_F(TreeReductionRewriterTest, RowReductionBatchedDimensionDoesNotFit) { - // Note: this could be too slow without shared memory optimization. const char* hlo_text = R"( HloModule ReduceWithPadding @@ -225,26 +228,29 @@ ENTRY main { zero = f32[] constant(0) ROOT out = f32[100] reduce(input, zero), dimensions={0,2}, to_apply=add } - )"; EnsureDeterminism(hlo_text); MatchOptimizedHloWithShapes(hlo_text, R"( -// CHECK: %fused_computation (param_0.2: f32[32,100,10000]) -> f32[32,100] { -// CHECK: %param_0.2 = f32[32,100,10000]{2,1,0} parameter(0) +// CHECK: %fused_computation (param_0.4: f32[32,100,2]) -> f32[100] { +// CHECK: %param_0.4 = f32[32,100,2]{2,1,0} parameter(0) // CHECK: %zero_1 = f32[] constant(0) -// CHECK: %pad.1 = f32[32,100,16384]{2,1,0} pad(f32[32,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_6384 +// CHECK: %reduce.5 = f32[32,100]{1,0} reduce(f32[32,100,2]{2,1,0} %param_0.4, f32[] %zero_1), dimensions={2}, to_apply=%add +// CHECK: ROOT %reduce.4 = f32[100]{0} reduce(f32[32,100]{1,0} %reduce.5, f32[] %zero_1), dimensions={0}, to_apply=%add +// CHECK: } +// CHECK: %fused_computation.1 (param_0.5: f32[32,100,10000]) -> f32[32,100,2] { +// CHECK: %param_0.5 = f32[32,100,10000]{2,1,0} parameter(0) +// CHECK: %zero_2 = f32[] constant(0) +// CHECK: %pad.1 = f32[32,100,16384]{2,1,0} pad(f32[32,100,10000]{2,1,0} %param_0.5, f32[] %zero_2), padding=0_0x0_0x0_6384 // CHECK: %bitcast.1 = f32[32,100,2,8192]{3,2,1,0} bitcast(f32[32,100,16384]{2,1,0} %pad.1) -// CHECK: %reduce.5 = f32[32,100,8192]{2,1,0} reduce(f32[32,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={2}, to_apply=%add -// CHECK: ROOT %reduce.4 = f32[32,100]{1,0} reduce(f32[32,100,8192]{2,1,0} %reduce.5, f32[] %zero_1), dimensions={2}, to_apply=%add +// CHECK: ROOT %reduce.6 = f32[32,100,2]{2,1,0} reduce(f32[32,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_2), dimensions={3}, to_apply=%add // CHECK: } // CHECK: ENTRY %main (input: f32[32,100,10000]) -> f32[100] { // CHECK: %input = f32[32,100,10000]{2,1,0} parameter(0) -// CHECK: %fusion = f32[32,100]{1,0} fusion(f32[32,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation -// CHECK: %zero = f32[] constant(0) -// CHECK: ROOT %reduce.1 = f32[100]{0} reduce(f32[32,100]{1,0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add +// CHECK: %fusion.1 = f32[32,100,2]{2,1,0} fusion(f32[32,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation.1 +// CHECK: ROOT %fusion = f32[100]{0} fusion(f32[32,100,2]{2,1,0} %fusion.1), kind=kInput, calls=%fused_computation // CHECK: } )"); @@ -306,7 +312,6 @@ ENTRY main { zero = f32[] constant(0) ROOT out = f32[2,2,2] reduce(input, zero), dimensions={0}, to_apply=add } - )"; MatchOptimizedHloWithShapes(hlo_text, @@ -346,7 +351,6 @@ ENTRY main { zero = f32[] constant(0) ROOT out = f32[5] reduce(input, zero), dimensions={0}, to_apply=add } - )"; MatchOptimizedHloWithShapes(hlo_text, diff --git a/tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.cc b/tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.cc index c3d70f70a9f..5dad97dab39 100644 --- a/tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.cc @@ -38,6 +38,14 @@ limitations under the License. namespace xla { namespace gpu { +// TODO(cheshire): duplication w/ GetReductionTiling, but we need to get a +// minimum possible tiling, regardless of the input. +static constexpr int64 kRowAtomicFreeBound = kWarpSize * kWarpSize * 8; +static constexpr int64 kColumnAtomicFreeBound = kWarpSize * 128; +// TODO(cheshire): This is very small, we could increase it at the cost of +// decreased column/row tiling. +static constexpr int64 kBatchedAtomicFreeBound = 8; + class ReductionRewriterVisitor : public DfsHloRewriteVisitor { public: explicit ReductionRewriterVisitor() {} @@ -65,13 +73,8 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { Shape input_shape = input->shape(); VLOG(3) << "Input shape: " << input_shape.ToString(); - std::array reduction_tiling = - GetReductionTiling(reduction_dimensions); - - int64 batched_atomic_free_bound = reduction_tiling[0]; bool reduce_batch_dimension = hlo->dimensions().size() > 1; VLOG(3) << "reduce_batch_dimension = " << reduce_batch_dimension; - VLOG(3) << "batched atomic free: " << batched_atomic_free_bound; std::vector reduced_dimensions = hlo->dimensions(); absl::c_sort(reduced_dimensions); @@ -82,19 +85,16 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { // Case (1): batched dimension does not fit. if (reduce_batch_dimension && - input_shape.dimensions(0) > batched_atomic_free_bound) { + input_shape.dimensions(0) > kBatchedAtomicFreeBound) { VLOG(1) << "Splitting batched dimension reduce into a separate reduction"; return RewriteBatchDimensionLargerThanTile(hlo, reduction_dimensions, reduced_input_dimension, input_shape, input); } + bool is_row_reduction = reduction_dimensions.is_row_reduction; - int64 atomic_free_bound = [&] { - if (reduction_dimensions.is_row_reduction) { - return reduction_tiling[2] * kWarpSize * kWarpSize; - } - return reduction_tiling[1] * kWarpSize; - }(); + int64 atomic_free_bound = + is_row_reduction ? kRowAtomicFreeBound : kColumnAtomicFreeBound; VLOG(3) << "atomic_free_bound: " << atomic_free_bound; // Base case: everything fits. @@ -105,10 +105,24 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { int64 reduced_dim_size = input_shape.dimensions(reduced_input_dimension); VLOG(3) << "reduced_dim_size = " << reduced_dim_size; + // TODO(cheshire): if atomic_free_bound is large, num_fit is likely to be + // small. Generating a reduction with very small reduced dimension is not + // efficient, it would be better to split the dimension sizes more evenly. + // + // One possible idea is to pad to a nearest square (ceil(sqrt(x)))^2. + // Given that: + // + // (n + 1)^2 = n^2 + (2n+1) + // + // it can be seen that the distance to the nearest square is at most twice + // the square root of the input number. int64 num_fit = CeilOfRatio(reduced_dim_size, atomic_free_bound); // Pad reduced dimension to the required number of elements. HloInstruction *padded = [&] { + // TODO(cheshire): if atomic_free_bound is very large, padding all the way + // up to to atomic_free_bound is wasteful, we could pad to a much smaller + // value. if (reduced_dim_size % atomic_free_bound != 0) { int64 padded_num_elements = num_fit * atomic_free_bound; PaddingConfig padding_config = MakeNoPaddingConfig(input_shape.rank()); @@ -145,20 +159,21 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { VLOG(1) << "Generated reshape: " << reshaped_padded_input->ToString(); std::vector inner_reduce_dimensions = reshaped_dimensions; + int64 inner_reduced_dimension = is_row_reduction + ? inner_reduce_dimensions.size() - 1 + : reduced_input_dimension; + VLOG(2) << "inner_reduced_dimension = " << inner_reduced_dimension; inner_reduce_dimensions.erase(inner_reduce_dimensions.begin() + - reduced_input_dimension); + inner_reduced_dimension); if (reduce_batch_dimension) { inner_reduce_dimensions.erase(inner_reduce_dimensions.begin()); } - Shape inner_reduce_shape = ShapeUtil::MakeShape(input_shape.element_type(), inner_reduce_dimensions); - std::vector dims_to_reduce = {reduced_input_dimension}; - - int64 reduced_inner_dimension = reduced_input_dimension; + std::vector dims_to_reduce = {inner_reduced_dimension}; if (reduce_batch_dimension) { dims_to_reduce.push_back(0); - reduced_inner_dimension -= 1; + inner_reduced_dimension -= 1; } HloInstruction *inner_reduce = @@ -166,20 +181,21 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { inner_reduce_shape, reshaped_padded_input, initial_value, dims_to_reduce, hlo->to_apply())); VLOG(1) << "Generated inner reduction: " << inner_reduce->ToString(); - std::vector outer_reduce_dimensions = inner_reduce_dimensions; VLOG(3) << "outer_reduce_dimensions = " << absl::StrJoin(outer_reduce_dimensions, ", "); - VLOG(3) << "reduced_inner_dimension = " << reduced_inner_dimension; + int64 outer_reduced_dimension = is_row_reduction + ? outer_reduce_dimensions.size() - 1 + : reduced_input_dimension; // Remove reduced dimension. outer_reduce_dimensions.erase(outer_reduce_dimensions.begin() + - reduced_inner_dimension); + outer_reduced_dimension); Shape outer_reduce_shape = ShapeUtil::MakeShape(input_shape.element_type(), outer_reduce_dimensions); std::unique_ptr outer_reduce = HloInstruction::CreateReduce( outer_reduce_shape, inner_reduce, initial_value, - {reduced_inner_dimension}, hlo->to_apply()); + {outer_reduced_dimension}, hlo->to_apply()); VLOG(1) << "Generated outer reduction: " << outer_reduce->ToString(); return ReplaceWithNewInstruction(hlo, std::move(outer_reduce));