[XLA/GPU] Fix some tree reduction rewriter bugs, add some TODOs
Main bug fixed: row reduction should be rewritten into a row reduction.

PiperOrigin-RevId: 295256755
Change-Id: I79606680a0c3b05cf0f21d5b9f3f9e4f01e6426b
Commit: a755fe8236 (parent: 30e3943cb2)
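Concretely, in the first test below: before the fix, the f32[50000] row reduction was padded to 65536, bitcast to f32[4,16384], and the inner reduce ran over dimension 0, i.e. the rewriter turned a row reduction into a column reduction. After the fix, the input is padded to 57344 = 7 * 8192, bitcast to f32[7,8192], and the inner reduce runs over the minor-most dimension 1, so it stays a row reduction. A minimal self-contained sketch of the splitting arithmetic, assuming only the kRowAtomicFreeBound definition this commit adds and the usual GPU warp size of 32 (CeilOfRatio is restated here so the sketch compiles on its own):

    #include <cstdint>
    #include <cstdio>

    // Same rounding helper the pass uses (xla::CeilOfRatio), restated so the
    // sketch is standalone.
    int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }

    int main() {
      constexpr int64_t kWarpSize = 32;
      constexpr int64_t kRowAtomicFreeBound = kWarpSize * kWarpSize * 8;  // 8192
      const int64_t n = 50000;
      const int64_t num_fit = CeilOfRatio(n, kRowAtomicFreeBound);  // 7
      const int64_t padded = num_fit * kRowAtomicFreeBound;         // 57344
      // Inner reduce f32[7,8192] -> f32[7] over the minor dimension, then
      // outer reduce f32[7] -> f32[].
      std::printf("num_fit=%lld padded=%lld\n",
                  static_cast<long long>(num_fit),
                  static_cast<long long>(padded));
      return 0;
    }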
@@ -73,17 +73,18 @@ ENTRY main {
   // TODO(cheshire): a more generic check, do not hardcode the names.
   MatchOptimizedHloWithShapes(hlo_text,
                               R"(
-// CHECK: %fused_computation (param_0.2: f32[50000]) -> f32[] {
+// CHECK: %fused_computation (param_0.2: f32[50000]) -> f32[7] {
 // CHECK: %param_0.2 = f32[50000]{0} parameter(0)
 // CHECK: %zero_1 = f32[] constant(0)
-// CHECK: %pad.1 = f32[65536]{0} pad(f32[50000]{0} %param_0.2, f32[] %zero_1), padding=0_15536
+// CHECK: %pad.1 = f32[57344]{0} pad(f32[50000]{0} %param_0.2, f32[] %zero_1), padding=0_7344
-// CHECK: %bitcast.1 = f32[4,16384]{1,0} bitcast(f32[65536]{0} %pad.1)
+// CHECK: %bitcast.1 = f32[7,8192]{1,0} bitcast(f32[57344]{0} %pad.1)
-// CHECK: %reduce.3 = f32[16384]{0} reduce(f32[4,16384]{1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add
-// CHECK: ROOT %reduce.2 = f32[] reduce(f32[16384]{0} %reduce.3, f32[] %zero_1), dimensions={0}, to_apply=%add
+// CHECK: ROOT %reduce.2 = f32[7]{0} reduce(f32[7,8192]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add
 // CHECK: }
 // CHECK: ENTRY %main (input: f32[50000]) -> f32[] {
 // CHECK: %input = f32[50000]{0} parameter(0)
-// CHECK: ROOT %fusion = f32[] fusion(f32[50000]{0} %input), kind=kInput, calls=%fused_computation
+// CHECK: %fusion = f32[7]{0} fusion(f32[50000]{0} %input), kind=kInput, calls=%fused_computation
+// CHECK: %zero = f32[] constant(0)
+// CHECK: ROOT %reduce.1 = f32[] reduce(f32[7]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
 // CHECK: }
 )");
 
@@ -113,18 +114,20 @@ ENTRY main {
 
   MatchOptimizedHloWithShapes(hlo_text,
                               R"(
-// CHECK: %fused_computation (param_0.2: f32[100,100,10000]) -> f32[100,100] {
+// CHECK: %fused_computation (param_0.2: f32[100,100,10000]) -> f32[100,100,2] {
 // CHECK: %param_0.2 = f32[100,100,10000]{2,1,0} parameter(0)
 // CHECK: %zero_1 = f32[] constant(0)
 // CHECK: %pad.1 = f32[100,100,16384]{2,1,0} pad(f32[100,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_6384
 // CHECK: %bitcast.1 = f32[100,100,2,8192]{3,2,1,0} bitcast(f32[100,100,16384]{2,1,0} %pad.1)
-// CHECK: %reduce.3 = f32[100,100,8192]{2,1,0} reduce(f32[100,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={2}, to_apply=%add
-// CHECK: ROOT %reduce.2 = f32[100,100]{1,0} reduce(f32[100,100,8192]{2,1,0} %reduce.3, f32[] %zero_1), dimensions={2}, to_apply=%add
+// CHECK: ROOT %reduce.2 = f32[100,100,2]{2,1,0} reduce(f32[100,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3}, to_apply=%add
 // CHECK: }
 // CHECK: ENTRY %main (input: f32[100,100,10000]) -> f32[100,100] {
 // CHECK: %input = f32[100,100,10000]{2,1,0} parameter(0)
-// CHECK: ROOT %fusion = f32[100,100]{1,0} fusion(f32[100,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
+// CHECK: %fusion = f32[100,100,2]{2,1,0} fusion(f32[100,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
+// CHECK: %zero = f32[] constant(0)
+// CHECK: ROOT %reduce.1 = f32[100,100]{1,0} reduce(f32[100,100,2]{2,1,0} %fusion, f32[] %zero), dimensions={2}, to_apply=%add
 // CHECK: }
+
 )");
 
   EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
@@ -151,18 +154,18 @@ ENTRY main {
 
   MatchOptimizedHloWithShapes(hlo_text,
                               R"(
-// CHECK: %fused_computation (param_0.2: f32[1000000]) -> f32[16384] {
+// CHECK: %fused_computation (param_0.2: f32[1000000]) -> f32[123] {
 // CHECK: %param_0.2 = f32[1000000]{0} parameter(0)
 // CHECK: %zero_1 = f32[] constant(0)
-// CHECK: %pad.1 = f32[1015808]{0} pad(f32[1000000]{0} %param_0.2, f32[] %zero_1), padding=0_15808
+// CHECK: %pad.1 = f32[1007616]{0} pad(f32[1000000]{0} %param_0.2, f32[] %zero_1), padding=0_7616
-// CHECK: %bitcast.1 = f32[62,16384]{1,0} bitcast(f32[1015808]{0} %pad.1)
+// CHECK: %bitcast.1 = f32[123,8192]{1,0} bitcast(f32[1007616]{0} %pad.1)
-// CHECK: ROOT %reduce.2 = f32[16384]{0} reduce(f32[62,16384]{1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add
+// CHECK: ROOT %reduce.2 = f32[123]{0} reduce(f32[123,8192]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add
 // CHECK: }
 // CHECK: ENTRY %main (input: f32[1000000]) -> f32[] {
 // CHECK: %input = f32[1000000]{0} parameter(0)
-// CHECK: %fusion = f32[16384]{0} fusion(f32[1000000]{0} %input), kind=kInput, calls=%fused_computation
+// CHECK: %fusion = f32[123]{0} fusion(f32[1000000]{0} %input), kind=kInput, calls=%fused_computation
 // CHECK: %zero = f32[] constant(0)
-// CHECK: ROOT %reduce.1 = f32[] reduce(f32[16384]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
+// CHECK: ROOT %reduce.1 = f32[] reduce(f32[123]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
 // CHECK: }
 )");
 
@@ -192,17 +195,18 @@ ENTRY main {
 
   MatchOptimizedHloWithShapes(hlo_text,
                               R"(
-// CHECK: %fused_computation (param_0.2: f32[8,100,10000]) -> f32[100] {
+// CHECK: %fused_computation (param_0.2: f32[8,100,10000]) -> f32[100,2] {
 // CHECK: %param_0.2 = f32[8,100,10000]{2,1,0} parameter(0)
 // CHECK: %zero_1 = f32[] constant(0)
 // CHECK: %pad.1 = f32[8,100,16384]{2,1,0} pad(f32[8,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_6384
 // CHECK: %bitcast.1 = f32[8,100,2,8192]{3,2,1,0} bitcast(f32[8,100,16384]{2,1,0} %pad.1)
-// CHECK: %reduce.3 = f32[100,8192]{1,0} reduce(f32[8,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={2,0}, to_apply=%add
-// CHECK: ROOT %reduce.2 = f32[100]{0} reduce(f32[100,8192]{1,0} %reduce.3, f32[] %zero_1), dimensions={1}, to_apply=%add
+// CHECK: ROOT %reduce.2 = f32[100,2]{1,0} reduce(f32[8,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3,0}, to_apply=%add
 // CHECK: }
 // CHECK: ENTRY %main (input: f32[8,100,10000]) -> f32[100] {
 // CHECK: %input = f32[8,100,10000]{2,1,0} parameter(0)
-// CHECK: ROOT %fusion = f32[100]{0} fusion(f32[8,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
+// CHECK: %fusion = f32[100,2]{1,0} fusion(f32[8,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
+// CHECK: %zero = f32[] constant(0)
+// CHECK: ROOT %reduce.1 = f32[100]{0} reduce(f32[100,2]{1,0} %fusion, f32[] %zero), dimensions={1}, to_apply=%add
 // CHECK: }
 )");
 
@@ -210,7 +214,6 @@ ENTRY main {
 }
 
 TEST_F(TreeReductionRewriterTest, RowReductionBatchedDimensionDoesNotFit) {
-  // Note: this could be too slow without shared memory optimization.
   const char* hlo_text = R"(
 HloModule ReduceWithPadding
 
@@ -225,26 +228,29 @@ ENTRY main {
   zero = f32[] constant(0)
   ROOT out = f32[100] reduce(input, zero), dimensions={0,2}, to_apply=add
 }
 
 )";
-
   EnsureDeterminism(hlo_text);
 
   MatchOptimizedHloWithShapes(hlo_text,
                               R"(
-// CHECK: %fused_computation (param_0.2: f32[32,100,10000]) -> f32[32,100] {
+// CHECK: %fused_computation (param_0.4: f32[32,100,2]) -> f32[100] {
-// CHECK: %param_0.2 = f32[32,100,10000]{2,1,0} parameter(0)
+// CHECK: %param_0.4 = f32[32,100,2]{2,1,0} parameter(0)
 // CHECK: %zero_1 = f32[] constant(0)
-// CHECK: %pad.1 = f32[32,100,16384]{2,1,0} pad(f32[32,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_6384
+// CHECK: %reduce.5 = f32[32,100]{1,0} reduce(f32[32,100,2]{2,1,0} %param_0.4, f32[] %zero_1), dimensions={2}, to_apply=%add
+// CHECK: ROOT %reduce.4 = f32[100]{0} reduce(f32[32,100]{1,0} %reduce.5, f32[] %zero_1), dimensions={0}, to_apply=%add
+// CHECK: }
+// CHECK: %fused_computation.1 (param_0.5: f32[32,100,10000]) -> f32[32,100,2] {
+// CHECK: %param_0.5 = f32[32,100,10000]{2,1,0} parameter(0)
+// CHECK: %zero_2 = f32[] constant(0)
+// CHECK: %pad.1 = f32[32,100,16384]{2,1,0} pad(f32[32,100,10000]{2,1,0} %param_0.5, f32[] %zero_2), padding=0_0x0_0x0_6384
 // CHECK: %bitcast.1 = f32[32,100,2,8192]{3,2,1,0} bitcast(f32[32,100,16384]{2,1,0} %pad.1)
-// CHECK: %reduce.5 = f32[32,100,8192]{2,1,0} reduce(f32[32,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={2}, to_apply=%add
-// CHECK: ROOT %reduce.4 = f32[32,100]{1,0} reduce(f32[32,100,8192]{2,1,0} %reduce.5, f32[] %zero_1), dimensions={2}, to_apply=%add
+// CHECK: ROOT %reduce.6 = f32[32,100,2]{2,1,0} reduce(f32[32,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_2), dimensions={3}, to_apply=%add
 // CHECK: }
 // CHECK: ENTRY %main (input: f32[32,100,10000]) -> f32[100] {
 // CHECK: %input = f32[32,100,10000]{2,1,0} parameter(0)
-// CHECK: %fusion = f32[32,100]{1,0} fusion(f32[32,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
+// CHECK: %fusion.1 = f32[32,100,2]{2,1,0} fusion(f32[32,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation.1
-// CHECK: %zero = f32[] constant(0)
-// CHECK: ROOT %reduce.1 = f32[100]{0} reduce(f32[32,100]{1,0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
+// CHECK: ROOT %fusion = f32[100]{0} fusion(f32[32,100,2]{2,1,0} %fusion.1), kind=kInput, calls=%fused_computation
 // CHECK: }
 )");
 
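Worked numbers for the test above, under the constants this commit introduces: the batch dimension 32 exceeds kBatchedAtomicFreeBound = 8, so the batch reduce is split out into its own fusion; the 10000-element row dimension rounds up to 2 * 8192 = 16384 (hence padding=0_0x0_0x0_6384), and the rewritten program reduces f32[32,100,2,8192] -> f32[32,100,2] in %fused_computation.1, then f32[32,100,2] -> f32[32,100] -> f32[100] in %fused_computation.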
@@ -306,7 +312,6 @@ ENTRY main {
   zero = f32[] constant(0)
   ROOT out = f32[2,2,2] reduce(input, zero), dimensions={0}, to_apply=add
 }
 
 )";
-
   MatchOptimizedHloWithShapes(hlo_text,
@@ -346,7 +351,6 @@ ENTRY main {
   zero = f32[] constant(0)
   ROOT out = f32[5] reduce(input, zero), dimensions={0}, to_apply=add
 }
 
 )";
-
   MatchOptimizedHloWithShapes(hlo_text,
@@ -38,6 +38,14 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
+// TODO(cheshire): duplication w/ GetReductionTiling, but we need to get a
+// minimum possible tiling, regardless of the input.
+static constexpr int64 kRowAtomicFreeBound = kWarpSize * kWarpSize * 8;
+static constexpr int64 kColumnAtomicFreeBound = kWarpSize * 128;
+// TODO(cheshire): This is very small, we could increase it at the cost of
+// decreased column/row tiling.
+static constexpr int64 kBatchedAtomicFreeBound = 8;
+
 class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
  public:
   explicit ReductionRewriterVisitor() {}
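For reference, the three new bounds evaluate as follows with kWarpSize = 32 (the warp size XLA's GPU backend assumes); this is a standalone sketch restating the constants, not the pass itself:

    #include <cstdint>
    #include <cstdio>

    int main() {
      constexpr int64_t kWarpSize = 32;
      constexpr int64_t kRowAtomicFreeBound = kWarpSize * kWarpSize * 8;  // 8192
      constexpr int64_t kColumnAtomicFreeBound = kWarpSize * 128;         // 4096
      constexpr int64_t kBatchedAtomicFreeBound = 8;
      // These are the values the updated tests encode: row reductions pad to a
      // multiple of 8192 (57344 = 7 * 8192, 1007616 = 123 * 8192), and a batch
      // dimension of 32 > 8 forces the batched split.
      std::printf("row=%lld column=%lld batched=%lld\n",
                  static_cast<long long>(kRowAtomicFreeBound),
                  static_cast<long long>(kColumnAtomicFreeBound),
                  static_cast<long long>(kBatchedAtomicFreeBound));
      return 0;
    }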
@@ -65,13 +73,8 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
     Shape input_shape = input->shape();
     VLOG(3) << "Input shape: " << input_shape.ToString();
 
-    std::array<int64, 3> reduction_tiling =
-        GetReductionTiling(reduction_dimensions);
-
-    int64 batched_atomic_free_bound = reduction_tiling[0];
     bool reduce_batch_dimension = hlo->dimensions().size() > 1;
     VLOG(3) << "reduce_batch_dimension = " << reduce_batch_dimension;
-    VLOG(3) << "batched atomic free: " << batched_atomic_free_bound;
 
     std::vector<int64> reduced_dimensions = hlo->dimensions();
     absl::c_sort(reduced_dimensions);
@@ -82,19 +85,16 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
 
     // Case (1): batched dimension does not fit.
     if (reduce_batch_dimension &&
-        input_shape.dimensions(0) > batched_atomic_free_bound) {
+        input_shape.dimensions(0) > kBatchedAtomicFreeBound) {
       VLOG(1) << "Splitting batched dimension reduce into a separate reduction";
       return RewriteBatchDimensionLargerThanTile(hlo, reduction_dimensions,
                                                  reduced_input_dimension,
                                                  input_shape, input);
     }
+    bool is_row_reduction = reduction_dimensions.is_row_reduction;
 
-    int64 atomic_free_bound = [&] {
-      if (reduction_dimensions.is_row_reduction) {
-        return reduction_tiling[2] * kWarpSize * kWarpSize;
-      }
-      return reduction_tiling[1] * kWarpSize;
-    }();
+    int64 atomic_free_bound =
+        is_row_reduction ? kRowAtomicFreeBound : kColumnAtomicFreeBound;
     VLOG(3) << "atomic_free_bound: " << atomic_free_bound;
 
     // Base case: everything fits.
@@ -105,10 +105,24 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
 
     int64 reduced_dim_size = input_shape.dimensions(reduced_input_dimension);
     VLOG(3) << "reduced_dim_size = " << reduced_dim_size;
+    // TODO(cheshire): if atomic_free_bound is large, num_fit is likely to be
+    // small. Generating a reduction with a very small reduced dimension is not
+    // efficient; it would be better to split the dimension sizes more evenly.
+    //
+    // One possible idea is to pad to the nearest square (ceil(sqrt(x)))^2.
+    // Given that:
+    //
+    //   (n + 1)^2 = n^2 + (2n + 1)
+    //
+    // it can be seen that the distance to the nearest square is at most twice
+    // the square root of the input number.
     int64 num_fit = CeilOfRatio(reduced_dim_size, atomic_free_bound);
 
     // Pad reduced dimension to the required number of elements.
     HloInstruction *padded = [&] {
+      // TODO(cheshire): if atomic_free_bound is very large, padding all the
+      // way up to atomic_free_bound is wasteful; we could pad to a much
+      // smaller value.
       if (reduced_dim_size % atomic_free_bound != 0) {
         int64 padded_num_elements = num_fit * atomic_free_bound;
         PaddingConfig padding_config = MakeNoPaddingConfig(input_shape.rank());
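The nearest-square idea from the TODO above, worked through in a standalone sketch (an illustration of the comment's math; this commit does not implement it). With m = ceil(sqrt(n)) we have (m-1)^2 < n, so the overhead m*m - n is below m^2 - (m-1)^2 = 2m - 1, i.e. roughly 2*sqrt(n), and both split factors equal m:

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    int main() {
      const int64_t n = 50000;  // reduced_dim_size from the first test
      const int64_t m =
          static_cast<int64_t>(std::ceil(std::sqrt(static_cast<double>(n))));
      const int64_t padded = m * m;
      // Prints m=224 padded=50176 overhead=176, versus padding to 57344 with
      // the current multiple-of-8192 scheme.
      std::printf("m=%lld padded=%lld overhead=%lld\n",
                  static_cast<long long>(m), static_cast<long long>(padded),
                  static_cast<long long>(padded - n));
      return 0;
    }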
@@ -145,20 +159,21 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
     VLOG(1) << "Generated reshape: " << reshaped_padded_input->ToString();
 
     std::vector<int64> inner_reduce_dimensions = reshaped_dimensions;
+    int64 inner_reduced_dimension = is_row_reduction
+                                        ? inner_reduce_dimensions.size() - 1
+                                        : reduced_input_dimension;
+    VLOG(2) << "inner_reduced_dimension = " << inner_reduced_dimension;
     inner_reduce_dimensions.erase(inner_reduce_dimensions.begin() +
-                                  reduced_input_dimension);
+                                  inner_reduced_dimension);
     if (reduce_batch_dimension) {
       inner_reduce_dimensions.erase(inner_reduce_dimensions.begin());
     }
 
     Shape inner_reduce_shape = ShapeUtil::MakeShape(input_shape.element_type(),
                                                     inner_reduce_dimensions);
-    std::vector<int64> dims_to_reduce = {reduced_input_dimension};
-
-    int64 reduced_inner_dimension = reduced_input_dimension;
+    std::vector<int64> dims_to_reduce = {inner_reduced_dimension};
     if (reduce_batch_dimension) {
       dims_to_reduce.push_back(0);
-      reduced_inner_dimension -= 1;
+      inner_reduced_dimension -= 1;
     }
-
     HloInstruction *inner_reduce =
@@ -166,20 +181,21 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
             inner_reduce_shape, reshaped_padded_input, initial_value,
             dims_to_reduce, hlo->to_apply()));
     VLOG(1) << "Generated inner reduction: " << inner_reduce->ToString();
 
     std::vector<int64> outer_reduce_dimensions = inner_reduce_dimensions;
     VLOG(3) << "outer_reduce_dimensions = "
             << absl::StrJoin(outer_reduce_dimensions, ", ");
-    VLOG(3) << "reduced_inner_dimension = " << reduced_inner_dimension;
-
+    int64 outer_reduced_dimension = is_row_reduction
+                                        ? outer_reduce_dimensions.size() - 1
+                                        : reduced_input_dimension;
     // Remove reduced dimension.
     outer_reduce_dimensions.erase(outer_reduce_dimensions.begin() +
-                                  reduced_inner_dimension);
+                                  outer_reduced_dimension);
     Shape outer_reduce_shape = ShapeUtil::MakeShape(input_shape.element_type(),
                                                     outer_reduce_dimensions);
     std::unique_ptr<HloInstruction> outer_reduce = HloInstruction::CreateReduce(
         outer_reduce_shape, inner_reduce, initial_value,
-        {reduced_inner_dimension}, hlo->to_apply());
+        {outer_reduced_dimension}, hlo->to_apply());
 
     VLOG(1) << "Generated outer reduction: " << outer_reduce->ToString();
     return ReplaceWithNewInstruction(hlo, std::move(outer_reduce));
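To see the fixed index bookkeeping end to end, here is a hedged standalone trace (not the pass itself) of how inner_reduced_dimension and outer_reduced_dimension fall out for the f32[8,100,10000], dimensions={0,2} test above, whose expected output reduces over {3,0} and then {1}:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
      const bool is_row_reduction = true;
      const bool reduce_batch_dimension = true;
      // Dimensions after pad + bitcast: f32[8,100,2,8192].
      std::vector<int64_t> dims = {8, 100, 2, 8192};
      // Fixed logic: a row reduction always consumes the minor-most dimension.
      int64_t inner_reduced_dimension =
          is_row_reduction ? static_cast<int64_t>(dims.size()) - 1
                           : /*reduced_input_dimension=*/2;
      // Together with the batch dimension this gives dims_to_reduce = {3, 0},
      // matching the CHECK line dimensions={3,0}.
      dims.erase(dims.begin() + inner_reduced_dimension);    // {8, 100, 2}
      if (reduce_batch_dimension) dims.erase(dims.begin());  // {100, 2}
      // The outer reduce again takes the minor-most dimension of {100, 2},
      // matching dimensions={1} in the entry computation.
      int64_t outer_reduced_dimension =
          is_row_reduction ? static_cast<int64_t>(dims.size()) - 1
                           : /*reduced_input_dimension=*/2;
      std::printf("inner=%lld outer=%lld\n",  // prints inner=3 outer=1
                  static_cast<long long>(inner_reduced_dimension),
                  static_cast<long long>(outer_reduced_dimension));
      return 0;
    }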