[XLA/GPU] Fix some tree reduction rewriter bugs, add some TODOs

Main bug fixed: row reduction should be rewritten into a row reduction.

PiperOrigin-RevId: 295256755
Change-Id: I79606680a0c3b05cf0f21d5b9f3f9e4f01e6426b
This commit is contained in:
George Karpenkov 2020-02-14 16:55:19 -08:00 committed by TensorFlower Gardener
parent 30e3943cb2
commit a755fe8236
2 changed files with 79 additions and 59 deletions

View File

@ -73,17 +73,18 @@ ENTRY main {
// TODO(cheshire): a more generic check, do not hardcode the names. // TODO(cheshire): a more generic check, do not hardcode the names.
MatchOptimizedHloWithShapes(hlo_text, MatchOptimizedHloWithShapes(hlo_text,
R"( R"(
// CHECK: %fused_computation (param_0.2: f32[50000]) -> f32[] { // CHECK: %fused_computation (param_0.2: f32[50000]) -> f32[7] {
// CHECK: %param_0.2 = f32[50000]{0} parameter(0) // CHECK: %param_0.2 = f32[50000]{0} parameter(0)
// CHECK: %zero_1 = f32[] constant(0) // CHECK: %zero_1 = f32[] constant(0)
// CHECK: %pad.1 = f32[65536]{0} pad(f32[50000]{0} %param_0.2, f32[] %zero_1), padding=0_15536 // CHECK: %pad.1 = f32[57344]{0} pad(f32[50000]{0} %param_0.2, f32[] %zero_1), padding=0_7344
// CHECK: %bitcast.1 = f32[4,16384]{1,0} bitcast(f32[65536]{0} %pad.1) // CHECK: %bitcast.1 = f32[7,8192]{1,0} bitcast(f32[57344]{0} %pad.1)
// CHECK: %reduce.3 = f32[16384]{0} reduce(f32[4,16384]{1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add // CHECK: ROOT %reduce.2 = f32[7]{0} reduce(f32[7,8192]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add
// CHECK: ROOT %reduce.2 = f32[] reduce(f32[16384]{0} %reduce.3, f32[] %zero_1), dimensions={0}, to_apply=%add
// CHECK: } // CHECK: }
// CHECK: ENTRY %main (input: f32[50000]) -> f32[] { // CHECK: ENTRY %main (input: f32[50000]) -> f32[] {
// CHECK: %input = f32[50000]{0} parameter(0) // CHECK: %input = f32[50000]{0} parameter(0)
// CHECK: ROOT %fusion = f32[] fusion(f32[50000]{0} %input), kind=kInput, calls=%fused_computation // CHECK: %fusion = f32[7]{0} fusion(f32[50000]{0} %input), kind=kInput, calls=%fused_computation
// CHECK: %zero = f32[] constant(0)
// CHECK: ROOT %reduce.1 = f32[] reduce(f32[7]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
// CHECK: } // CHECK: }
)"); )");
@ -113,18 +114,20 @@ ENTRY main {
MatchOptimizedHloWithShapes(hlo_text, MatchOptimizedHloWithShapes(hlo_text,
R"( R"(
// CHECK: %fused_computation (param_0.2: f32[100,100,10000]) -> f32[100,100] { // CHECK: %fused_computation (param_0.2: f32[100,100,10000]) -> f32[100,100,2] {
// CHECK: %param_0.2 = f32[100,100,10000]{2,1,0} parameter(0) // CHECK: %param_0.2 = f32[100,100,10000]{2,1,0} parameter(0)
// CHECK: %zero_1 = f32[] constant(0) // CHECK: %zero_1 = f32[] constant(0)
// CHECK: %pad.1 = f32[100,100,16384]{2,1,0} pad(f32[100,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_6384 // CHECK: %pad.1 = f32[100,100,16384]{2,1,0} pad(f32[100,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_6384
// CHECK: %bitcast.1 = f32[100,100,2,8192]{3,2,1,0} bitcast(f32[100,100,16384]{2,1,0} %pad.1) // CHECK: %bitcast.1 = f32[100,100,2,8192]{3,2,1,0} bitcast(f32[100,100,16384]{2,1,0} %pad.1)
// CHECK: %reduce.3 = f32[100,100,8192]{2,1,0} reduce(f32[100,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={2}, to_apply=%add // CHECK: ROOT %reduce.2 = f32[100,100,2]{2,1,0} reduce(f32[100,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3}, to_apply=%add
// CHECK: ROOT %reduce.2 = f32[100,100]{1,0} reduce(f32[100,100,8192]{2,1,0} %reduce.3, f32[] %zero_1), dimensions={2}, to_apply=%add
// CHECK: } // CHECK: }
// CHECK: ENTRY %main (input: f32[100,100,10000]) -> f32[100,100] { // CHECK: ENTRY %main (input: f32[100,100,10000]) -> f32[100,100] {
// CHECK: %input = f32[100,100,10000]{2,1,0} parameter(0) // CHECK: %input = f32[100,100,10000]{2,1,0} parameter(0)
// CHECK: ROOT %fusion = f32[100,100]{1,0} fusion(f32[100,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation // CHECK: %fusion = f32[100,100,2]{2,1,0} fusion(f32[100,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
// CHECK: %zero = f32[] constant(0)
// CHECK: ROOT %reduce.1 = f32[100,100]{1,0} reduce(f32[100,100,2]{2,1,0} %fusion, f32[] %zero), dimensions={2}, to_apply=%add
// CHECK: } // CHECK: }
)"); )");
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
@ -151,18 +154,18 @@ ENTRY main {
MatchOptimizedHloWithShapes(hlo_text, MatchOptimizedHloWithShapes(hlo_text,
R"( R"(
// CHECK: %fused_computation (param_0.2: f32[1000000]) -> f32[16384] { // CHECK: %fused_computation (param_0.2: f32[1000000]) -> f32[123] {
// CHECK: %param_0.2 = f32[1000000]{0} parameter(0) // CHECK: %param_0.2 = f32[1000000]{0} parameter(0)
// CHECK: %zero_1 = f32[] constant(0) // CHECK: %zero_1 = f32[] constant(0)
// CHECK: %pad.1 = f32[1015808]{0} pad(f32[1000000]{0} %param_0.2, f32[] %zero_1), padding=0_15808 // CHECK: %pad.1 = f32[1007616]{0} pad(f32[1000000]{0} %param_0.2, f32[] %zero_1), padding=0_7616
// CHECK: %bitcast.1 = f32[62,16384]{1,0} bitcast(f32[1015808]{0} %pad.1) // CHECK: %bitcast.1 = f32[123,8192]{1,0} bitcast(f32[1007616]{0} %pad.1)
// CHECK: ROOT %reduce.2 = f32[16384]{0} reduce(f32[62,16384]{1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add // CHECK: ROOT %reduce.2 = f32[123]{0} reduce(f32[123,8192]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add
// CHECK: } // CHECK: }
// CHECK: ENTRY %main (input: f32[1000000]) -> f32[] { // CHECK: ENTRY %main (input: f32[1000000]) -> f32[] {
// CHECK: %input = f32[1000000]{0} parameter(0) // CHECK: %input = f32[1000000]{0} parameter(0)
// CHECK: %fusion = f32[16384]{0} fusion(f32[1000000]{0} %input), kind=kInput, calls=%fused_computation // CHECK: %fusion = f32[123]{0} fusion(f32[1000000]{0} %input), kind=kInput, calls=%fused_computation
// CHECK: %zero = f32[] constant(0) // CHECK: %zero = f32[] constant(0)
// CHECK: ROOT %reduce.1 = f32[] reduce(f32[16384]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add // CHECK: ROOT %reduce.1 = f32[] reduce(f32[123]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
// CHECK: } // CHECK: }
)"); )");
@ -192,17 +195,18 @@ ENTRY main {
MatchOptimizedHloWithShapes(hlo_text, MatchOptimizedHloWithShapes(hlo_text,
R"( R"(
// CHECK: %fused_computation (param_0.2: f32[8,100,10000]) -> f32[100] { // CHECK: %fused_computation (param_0.2: f32[8,100,10000]) -> f32[100,2] {
// CHECK: %param_0.2 = f32[8,100,10000]{2,1,0} parameter(0) // CHECK: %param_0.2 = f32[8,100,10000]{2,1,0} parameter(0)
// CHECK: %zero_1 = f32[] constant(0) // CHECK: %zero_1 = f32[] constant(0)
// CHECK: %pad.1 = f32[8,100,16384]{2,1,0} pad(f32[8,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_6384 // CHECK: %pad.1 = f32[8,100,16384]{2,1,0} pad(f32[8,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_6384
// CHECK: %bitcast.1 = f32[8,100,2,8192]{3,2,1,0} bitcast(f32[8,100,16384]{2,1,0} %pad.1) // CHECK: %bitcast.1 = f32[8,100,2,8192]{3,2,1,0} bitcast(f32[8,100,16384]{2,1,0} %pad.1)
// CHECK: %reduce.3 = f32[100,8192]{1,0} reduce(f32[8,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={2,0}, to_apply=%add // CHECK: ROOT %reduce.2 = f32[100,2]{1,0} reduce(f32[8,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3,0}, to_apply=%add
// CHECK: ROOT %reduce.2 = f32[100]{0} reduce(f32[100,8192]{1,0} %reduce.3, f32[] %zero_1), dimensions={1}, to_apply=%add
// CHECK: } // CHECK: }
// CHECK: ENTRY %main (input: f32[8,100,10000]) -> f32[100] { // CHECK: ENTRY %main (input: f32[8,100,10000]) -> f32[100] {
// CHECK: %input = f32[8,100,10000]{2,1,0} parameter(0) // CHECK: %input = f32[8,100,10000]{2,1,0} parameter(0)
// CHECK: ROOT %fusion = f32[100]{0} fusion(f32[8,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation // CHECK: %fusion = f32[100,2]{1,0} fusion(f32[8,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
// CHECK: %zero = f32[] constant(0)
// CHECK: ROOT %reduce.1 = f32[100]{0} reduce(f32[100,2]{1,0} %fusion, f32[] %zero), dimensions={1}, to_apply=%add
// CHECK: } // CHECK: }
)"); )");
@ -210,7 +214,6 @@ ENTRY main {
} }
TEST_F(TreeReductionRewriterTest, RowReductionBatchedDimensionDoesNotFit) { TEST_F(TreeReductionRewriterTest, RowReductionBatchedDimensionDoesNotFit) {
// Note: this could be too slow without shared memory optimization.
const char* hlo_text = R"( const char* hlo_text = R"(
HloModule ReduceWithPadding HloModule ReduceWithPadding
@ -225,26 +228,29 @@ ENTRY main {
zero = f32[] constant(0) zero = f32[] constant(0)
ROOT out = f32[100] reduce(input, zero), dimensions={0,2}, to_apply=add ROOT out = f32[100] reduce(input, zero), dimensions={0,2}, to_apply=add
} }
)"; )";
EnsureDeterminism(hlo_text); EnsureDeterminism(hlo_text);
MatchOptimizedHloWithShapes(hlo_text, MatchOptimizedHloWithShapes(hlo_text,
R"( R"(
// CHECK: %fused_computation (param_0.2: f32[32,100,10000]) -> f32[32,100] { // CHECK: %fused_computation (param_0.4: f32[32,100,2]) -> f32[100] {
// CHECK: %param_0.2 = f32[32,100,10000]{2,1,0} parameter(0) // CHECK: %param_0.4 = f32[32,100,2]{2,1,0} parameter(0)
// CHECK: %zero_1 = f32[] constant(0) // CHECK: %zero_1 = f32[] constant(0)
// CHECK: %pad.1 = f32[32,100,16384]{2,1,0} pad(f32[32,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_6384 // CHECK: %reduce.5 = f32[32,100]{1,0} reduce(f32[32,100,2]{2,1,0} %param_0.4, f32[] %zero_1), dimensions={2}, to_apply=%add
// CHECK: ROOT %reduce.4 = f32[100]{0} reduce(f32[32,100]{1,0} %reduce.5, f32[] %zero_1), dimensions={0}, to_apply=%add
// CHECK: }
// CHECK: %fused_computation.1 (param_0.5: f32[32,100,10000]) -> f32[32,100,2] {
// CHECK: %param_0.5 = f32[32,100,10000]{2,1,0} parameter(0)
// CHECK: %zero_2 = f32[] constant(0)
// CHECK: %pad.1 = f32[32,100,16384]{2,1,0} pad(f32[32,100,10000]{2,1,0} %param_0.5, f32[] %zero_2), padding=0_0x0_0x0_6384
// CHECK: %bitcast.1 = f32[32,100,2,8192]{3,2,1,0} bitcast(f32[32,100,16384]{2,1,0} %pad.1) // CHECK: %bitcast.1 = f32[32,100,2,8192]{3,2,1,0} bitcast(f32[32,100,16384]{2,1,0} %pad.1)
// CHECK: %reduce.5 = f32[32,100,8192]{2,1,0} reduce(f32[32,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={2}, to_apply=%add // CHECK: ROOT %reduce.6 = f32[32,100,2]{2,1,0} reduce(f32[32,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_2), dimensions={3}, to_apply=%add
// CHECK: ROOT %reduce.4 = f32[32,100]{1,0} reduce(f32[32,100,8192]{2,1,0} %reduce.5, f32[] %zero_1), dimensions={2}, to_apply=%add
// CHECK: } // CHECK: }
// CHECK: ENTRY %main (input: f32[32,100,10000]) -> f32[100] { // CHECK: ENTRY %main (input: f32[32,100,10000]) -> f32[100] {
// CHECK: %input = f32[32,100,10000]{2,1,0} parameter(0) // CHECK: %input = f32[32,100,10000]{2,1,0} parameter(0)
// CHECK: %fusion = f32[32,100]{1,0} fusion(f32[32,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation // CHECK: %fusion.1 = f32[32,100,2]{2,1,0} fusion(f32[32,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation.1
// CHECK: %zero = f32[] constant(0) // CHECK: ROOT %fusion = f32[100]{0} fusion(f32[32,100,2]{2,1,0} %fusion.1), kind=kInput, calls=%fused_computation
// CHECK: ROOT %reduce.1 = f32[100]{0} reduce(f32[32,100]{1,0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
// CHECK: } // CHECK: }
)"); )");
@ -306,7 +312,6 @@ ENTRY main {
zero = f32[] constant(0) zero = f32[] constant(0)
ROOT out = f32[2,2,2] reduce(input, zero), dimensions={0}, to_apply=add ROOT out = f32[2,2,2] reduce(input, zero), dimensions={0}, to_apply=add
} }
)"; )";
MatchOptimizedHloWithShapes(hlo_text, MatchOptimizedHloWithShapes(hlo_text,
@ -346,7 +351,6 @@ ENTRY main {
zero = f32[] constant(0) zero = f32[] constant(0)
ROOT out = f32[5] reduce(input, zero), dimensions={0}, to_apply=add ROOT out = f32[5] reduce(input, zero), dimensions={0}, to_apply=add
} }
)"; )";
MatchOptimizedHloWithShapes(hlo_text, MatchOptimizedHloWithShapes(hlo_text,

View File

@ -38,6 +38,14 @@ limitations under the License.
namespace xla { namespace xla {
namespace gpu { namespace gpu {
// TODO(cheshire): duplication w/ GetReductionTiling, but we need to get a
// minimum possible tiling, regardless of the input.
static constexpr int64 kRowAtomicFreeBound = kWarpSize * kWarpSize * 8;
static constexpr int64 kColumnAtomicFreeBound = kWarpSize * 128;
// TODO(cheshire): This is very small, we could increase it at the cost of
// decreased column/row tiling.
static constexpr int64 kBatchedAtomicFreeBound = 8;
class ReductionRewriterVisitor : public DfsHloRewriteVisitor { class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
public: public:
explicit ReductionRewriterVisitor() {} explicit ReductionRewriterVisitor() {}
@ -65,13 +73,8 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
Shape input_shape = input->shape(); Shape input_shape = input->shape();
VLOG(3) << "Input shape: " << input_shape.ToString(); VLOG(3) << "Input shape: " << input_shape.ToString();
std::array<int64, 3> reduction_tiling =
GetReductionTiling(reduction_dimensions);
int64 batched_atomic_free_bound = reduction_tiling[0];
bool reduce_batch_dimension = hlo->dimensions().size() > 1; bool reduce_batch_dimension = hlo->dimensions().size() > 1;
VLOG(3) << "reduce_batch_dimension = " << reduce_batch_dimension; VLOG(3) << "reduce_batch_dimension = " << reduce_batch_dimension;
VLOG(3) << "batched atomic free: " << batched_atomic_free_bound;
std::vector<int64> reduced_dimensions = hlo->dimensions(); std::vector<int64> reduced_dimensions = hlo->dimensions();
absl::c_sort(reduced_dimensions); absl::c_sort(reduced_dimensions);
@ -82,19 +85,16 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
// Case (1): batched dimension does not fit. // Case (1): batched dimension does not fit.
if (reduce_batch_dimension && if (reduce_batch_dimension &&
input_shape.dimensions(0) > batched_atomic_free_bound) { input_shape.dimensions(0) > kBatchedAtomicFreeBound) {
VLOG(1) << "Splitting batched dimension reduce into a separate reduction"; VLOG(1) << "Splitting batched dimension reduce into a separate reduction";
return RewriteBatchDimensionLargerThanTile(hlo, reduction_dimensions, return RewriteBatchDimensionLargerThanTile(hlo, reduction_dimensions,
reduced_input_dimension, reduced_input_dimension,
input_shape, input); input_shape, input);
} }
bool is_row_reduction = reduction_dimensions.is_row_reduction;
int64 atomic_free_bound = [&] { int64 atomic_free_bound =
if (reduction_dimensions.is_row_reduction) { is_row_reduction ? kRowAtomicFreeBound : kColumnAtomicFreeBound;
return reduction_tiling[2] * kWarpSize * kWarpSize;
}
return reduction_tiling[1] * kWarpSize;
}();
VLOG(3) << "atomic_free_bound: " << atomic_free_bound; VLOG(3) << "atomic_free_bound: " << atomic_free_bound;
// Base case: everything fits. // Base case: everything fits.
@ -105,10 +105,24 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
int64 reduced_dim_size = input_shape.dimensions(reduced_input_dimension); int64 reduced_dim_size = input_shape.dimensions(reduced_input_dimension);
VLOG(3) << "reduced_dim_size = " << reduced_dim_size; VLOG(3) << "reduced_dim_size = " << reduced_dim_size;
// TODO(cheshire): if atomic_free_bound is large, num_fit is likely to be
// small. Generating a reduction with very small reduced dimension is not
// efficient, it would be better to split the dimension sizes more evenly.
//
// One possible idea is to pad to a nearest square (ceil(sqrt(x)))^2.
// Given that:
//
// (n + 1)^2 = n^2 + (2n+1)
//
// it can be seen that the distance to the nearest square is at most twice
// the square root of the input number.
int64 num_fit = CeilOfRatio(reduced_dim_size, atomic_free_bound); int64 num_fit = CeilOfRatio(reduced_dim_size, atomic_free_bound);
// Pad reduced dimension to the required number of elements. // Pad reduced dimension to the required number of elements.
HloInstruction *padded = [&] { HloInstruction *padded = [&] {
// TODO(cheshire): if atomic_free_bound is very large, padding all the way
// up to to atomic_free_bound is wasteful, we could pad to a much smaller
// value.
if (reduced_dim_size % atomic_free_bound != 0) { if (reduced_dim_size % atomic_free_bound != 0) {
int64 padded_num_elements = num_fit * atomic_free_bound; int64 padded_num_elements = num_fit * atomic_free_bound;
PaddingConfig padding_config = MakeNoPaddingConfig(input_shape.rank()); PaddingConfig padding_config = MakeNoPaddingConfig(input_shape.rank());
@ -145,20 +159,21 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
VLOG(1) << "Generated reshape: " << reshaped_padded_input->ToString(); VLOG(1) << "Generated reshape: " << reshaped_padded_input->ToString();
std::vector<int64> inner_reduce_dimensions = reshaped_dimensions; std::vector<int64> inner_reduce_dimensions = reshaped_dimensions;
int64 inner_reduced_dimension = is_row_reduction
? inner_reduce_dimensions.size() - 1
: reduced_input_dimension;
VLOG(2) << "inner_reduced_dimension = " << inner_reduced_dimension;
inner_reduce_dimensions.erase(inner_reduce_dimensions.begin() + inner_reduce_dimensions.erase(inner_reduce_dimensions.begin() +
reduced_input_dimension); inner_reduced_dimension);
if (reduce_batch_dimension) { if (reduce_batch_dimension) {
inner_reduce_dimensions.erase(inner_reduce_dimensions.begin()); inner_reduce_dimensions.erase(inner_reduce_dimensions.begin());
} }
Shape inner_reduce_shape = ShapeUtil::MakeShape(input_shape.element_type(), Shape inner_reduce_shape = ShapeUtil::MakeShape(input_shape.element_type(),
inner_reduce_dimensions); inner_reduce_dimensions);
std::vector<int64> dims_to_reduce = {reduced_input_dimension}; std::vector<int64> dims_to_reduce = {inner_reduced_dimension};
int64 reduced_inner_dimension = reduced_input_dimension;
if (reduce_batch_dimension) { if (reduce_batch_dimension) {
dims_to_reduce.push_back(0); dims_to_reduce.push_back(0);
reduced_inner_dimension -= 1; inner_reduced_dimension -= 1;
} }
HloInstruction *inner_reduce = HloInstruction *inner_reduce =
@ -166,20 +181,21 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
inner_reduce_shape, reshaped_padded_input, initial_value, inner_reduce_shape, reshaped_padded_input, initial_value,
dims_to_reduce, hlo->to_apply())); dims_to_reduce, hlo->to_apply()));
VLOG(1) << "Generated inner reduction: " << inner_reduce->ToString(); VLOG(1) << "Generated inner reduction: " << inner_reduce->ToString();
std::vector<int64> outer_reduce_dimensions = inner_reduce_dimensions; std::vector<int64> outer_reduce_dimensions = inner_reduce_dimensions;
VLOG(3) << "outer_reduce_dimensions = " VLOG(3) << "outer_reduce_dimensions = "
<< absl::StrJoin(outer_reduce_dimensions, ", "); << absl::StrJoin(outer_reduce_dimensions, ", ");
VLOG(3) << "reduced_inner_dimension = " << reduced_inner_dimension; int64 outer_reduced_dimension = is_row_reduction
? outer_reduce_dimensions.size() - 1
: reduced_input_dimension;
// Remove reduced dimension. // Remove reduced dimension.
outer_reduce_dimensions.erase(outer_reduce_dimensions.begin() + outer_reduce_dimensions.erase(outer_reduce_dimensions.begin() +
reduced_inner_dimension); outer_reduced_dimension);
Shape outer_reduce_shape = ShapeUtil::MakeShape(input_shape.element_type(), Shape outer_reduce_shape = ShapeUtil::MakeShape(input_shape.element_type(),
outer_reduce_dimensions); outer_reduce_dimensions);
std::unique_ptr<HloInstruction> outer_reduce = HloInstruction::CreateReduce( std::unique_ptr<HloInstruction> outer_reduce = HloInstruction::CreateReduce(
outer_reduce_shape, inner_reduce, initial_value, outer_reduce_shape, inner_reduce, initial_value,
{reduced_inner_dimension}, hlo->to_apply()); {outer_reduced_dimension}, hlo->to_apply());
VLOG(1) << "Generated outer reduction: " << outer_reduce->ToString(); VLOG(1) << "Generated outer reduction: " << outer_reduce->ToString();
return ReplaceWithNewInstruction(hlo, std::move(outer_reduce)); return ReplaceWithNewInstruction(hlo, std::move(outer_reduce));