[XLA/GPU] Fix some tree reduction rewriter bugs, add some TODOs
Main bug fixed: row reduction should be rewritten into a row reduction.

PiperOrigin-RevId: 295256755
Change-Id: I79606680a0c3b05cf0f21d5b9f3f9e4f01e6426b
commit a755fe8236 (parent 30e3943cb2)
tensorflow/compiler/xla/service/gpu
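Before the diff itself, here is a small standalone sketch (not part of the commit) of the arithmetic behind the updated f32[50000] row-reduction expectations below. It assumes kWarpSize == 32, as in XLA's GPU backend, and CeilOfRatio is a local stand-in for XLA's helper of the same name.

#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for xla::CeilOfRatio.
int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  constexpr int64_t kWarpSize = 32;                                   // assumed value
  constexpr int64_t kRowAtomicFreeBound = kWarpSize * kWarpSize * 8;  // 8192

  const int64_t reduced_dim_size = 50000;
  // Number of atomic-free chunks the reduced dimension is split into.
  const int64_t num_fit = CeilOfRatio(reduced_dim_size, kRowAtomicFreeBound);
  const int64_t padded_num_elements = num_fit * kRowAtomicFreeBound;

  // Prints: num_fit=7 padded=57344 -- i.e. pad f32[50000] to f32[57344],
  // bitcast to f32[7,8192], reduce dimension 1 inside the fusion (a row
  // reduction, which is the point of the fix), then reduce the leftover
  // f32[7] outside the fusion.
  std::printf("num_fit=%lld padded=%lld\n",
              static_cast<long long>(num_fit),
              static_cast<long long>(padded_num_elements));
  return 0;
}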
@@ -73,17 +73,18 @@ ENTRY main {
   // TODO(cheshire): a more generic check, do not hardcode the names.
   MatchOptimizedHloWithShapes(hlo_text,
                               R"(
-// CHECK: %fused_computation (param_0.2: f32[50000]) -> f32[] {
+// CHECK: %fused_computation (param_0.2: f32[50000]) -> f32[7] {
 // CHECK: %param_0.2 = f32[50000]{0} parameter(0)
 // CHECK: %zero_1 = f32[] constant(0)
-// CHECK: %pad.1 = f32[65536]{0} pad(f32[50000]{0} %param_0.2, f32[] %zero_1), padding=0_15536
-// CHECK: %bitcast.1 = f32[4,16384]{1,0} bitcast(f32[65536]{0} %pad.1)
-// CHECK: %reduce.3 = f32[16384]{0} reduce(f32[4,16384]{1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add
-// CHECK: ROOT %reduce.2 = f32[] reduce(f32[16384]{0} %reduce.3, f32[] %zero_1), dimensions={0}, to_apply=%add
+// CHECK: %pad.1 = f32[57344]{0} pad(f32[50000]{0} %param_0.2, f32[] %zero_1), padding=0_7344
+// CHECK: %bitcast.1 = f32[7,8192]{1,0} bitcast(f32[57344]{0} %pad.1)
+// CHECK: ROOT %reduce.2 = f32[7]{0} reduce(f32[7,8192]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add
 // CHECK: }
 // CHECK: ENTRY %main (input: f32[50000]) -> f32[] {
 // CHECK: %input = f32[50000]{0} parameter(0)
-// CHECK: ROOT %fusion = f32[] fusion(f32[50000]{0} %input), kind=kInput, calls=%fused_computation
+// CHECK: %fusion = f32[7]{0} fusion(f32[50000]{0} %input), kind=kInput, calls=%fused_computation
+// CHECK: %zero = f32[] constant(0)
+// CHECK: ROOT %reduce.1 = f32[] reduce(f32[7]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
 // CHECK: }
 )");
@@ -113,18 +114,20 @@ ENTRY main {

   MatchOptimizedHloWithShapes(hlo_text,
                               R"(
-// CHECK: %fused_computation (param_0.2: f32[100,100,10000]) -> f32[100,100] {
+// CHECK: %fused_computation (param_0.2: f32[100,100,10000]) -> f32[100,100,2] {
 // CHECK: %param_0.2 = f32[100,100,10000]{2,1,0} parameter(0)
 // CHECK: %zero_1 = f32[] constant(0)
 // CHECK: %pad.1 = f32[100,100,16384]{2,1,0} pad(f32[100,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_6384
 // CHECK: %bitcast.1 = f32[100,100,2,8192]{3,2,1,0} bitcast(f32[100,100,16384]{2,1,0} %pad.1)
-// CHECK: %reduce.3 = f32[100,100,8192]{2,1,0} reduce(f32[100,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={2}, to_apply=%add
-// CHECK: ROOT %reduce.2 = f32[100,100]{1,0} reduce(f32[100,100,8192]{2,1,0} %reduce.3, f32[] %zero_1), dimensions={2}, to_apply=%add
+// CHECK: ROOT %reduce.2 = f32[100,100,2]{2,1,0} reduce(f32[100,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3}, to_apply=%add
 // CHECK: }
 // CHECK: ENTRY %main (input: f32[100,100,10000]) -> f32[100,100] {
 // CHECK: %input = f32[100,100,10000]{2,1,0} parameter(0)
-// CHECK: ROOT %fusion = f32[100,100]{1,0} fusion(f32[100,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
+// CHECK: %fusion = f32[100,100,2]{2,1,0} fusion(f32[100,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
+// CHECK: %zero = f32[] constant(0)
+// CHECK: ROOT %reduce.1 = f32[100,100]{1,0} reduce(f32[100,100,2]{2,1,0} %fusion, f32[] %zero), dimensions={2}, to_apply=%add
 // CHECK: }
 )");

   EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
@@ -151,18 +154,18 @@ ENTRY main {

   MatchOptimizedHloWithShapes(hlo_text,
                               R"(
-// CHECK: %fused_computation (param_0.2: f32[1000000]) -> f32[16384] {
+// CHECK: %fused_computation (param_0.2: f32[1000000]) -> f32[123] {
 // CHECK: %param_0.2 = f32[1000000]{0} parameter(0)
 // CHECK: %zero_1 = f32[] constant(0)
-// CHECK: %pad.1 = f32[1015808]{0} pad(f32[1000000]{0} %param_0.2, f32[] %zero_1), padding=0_15808
-// CHECK: %bitcast.1 = f32[62,16384]{1,0} bitcast(f32[1015808]{0} %pad.1)
-// CHECK: ROOT %reduce.2 = f32[16384]{0} reduce(f32[62,16384]{1,0} %bitcast.1, f32[] %zero_1), dimensions={0}, to_apply=%add
+// CHECK: %pad.1 = f32[1007616]{0} pad(f32[1000000]{0} %param_0.2, f32[] %zero_1), padding=0_7616
+// CHECK: %bitcast.1 = f32[123,8192]{1,0} bitcast(f32[1007616]{0} %pad.1)
+// CHECK: ROOT %reduce.2 = f32[123]{0} reduce(f32[123,8192]{1,0} %bitcast.1, f32[] %zero_1), dimensions={1}, to_apply=%add
 // CHECK: }
 // CHECK: ENTRY %main (input: f32[1000000]) -> f32[] {
 // CHECK: %input = f32[1000000]{0} parameter(0)
-// CHECK: %fusion = f32[16384]{0} fusion(f32[1000000]{0} %input), kind=kInput, calls=%fused_computation
+// CHECK: %fusion = f32[123]{0} fusion(f32[1000000]{0} %input), kind=kInput, calls=%fused_computation
 // CHECK: %zero = f32[] constant(0)
-// CHECK: ROOT %reduce.1 = f32[] reduce(f32[16384]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
+// CHECK: ROOT %reduce.1 = f32[] reduce(f32[123]{0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
 // CHECK: }
 )");
@@ -192,17 +195,18 @@ ENTRY main {

   MatchOptimizedHloWithShapes(hlo_text,
                               R"(
-// CHECK: %fused_computation (param_0.2: f32[8,100,10000]) -> f32[100] {
-// CHECK: %param_0.2 = f32[8,100,10000]{2,1,0} parameter(0)
-// CHECK: %zero_1 = f32[] constant(0)
-// CHECK: %pad.1 = f32[8,100,16384]{2,1,0} pad(f32[8,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_6384
-// CHECK: %bitcast.1 = f32[8,100,2,8192]{3,2,1,0} bitcast(f32[8,100,16384]{2,1,0} %pad.1)
-// CHECK: %reduce.3 = f32[100,8192]{1,0} reduce(f32[8,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={2,0}, to_apply=%add
-// CHECK: ROOT %reduce.2 = f32[100]{0} reduce(f32[100,8192]{1,0} %reduce.3, f32[] %zero_1), dimensions={1}, to_apply=%add
+// CHECK: %fused_computation (param_0.2: f32[8,100,10000]) -> f32[100,2] {
+// CHECK: %param_0.2 = f32[8,100,10000]{2,1,0} parameter(0)
+// CHECK: %zero_1 = f32[] constant(0)
+// CHECK: %pad.1 = f32[8,100,16384]{2,1,0} pad(f32[8,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_6384
+// CHECK: %bitcast.1 = f32[8,100,2,8192]{3,2,1,0} bitcast(f32[8,100,16384]{2,1,0} %pad.1)
+// CHECK: ROOT %reduce.2 = f32[100,2]{1,0} reduce(f32[8,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={3,0}, to_apply=%add
 // CHECK: }
 // CHECK: ENTRY %main (input: f32[8,100,10000]) -> f32[100] {
-// CHECK: %input = f32[8,100,10000]{2,1,0} parameter(0)
-// CHECK: ROOT %fusion = f32[100]{0} fusion(f32[8,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
+// CHECK: %input = f32[8,100,10000]{2,1,0} parameter(0)
+// CHECK: %fusion = f32[100,2]{1,0} fusion(f32[8,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
+// CHECK: %zero = f32[] constant(0)
+// CHECK: ROOT %reduce.1 = f32[100]{0} reduce(f32[100,2]{1,0} %fusion, f32[] %zero), dimensions={1}, to_apply=%add
 // CHECK: }
 )");
@@ -210,7 +214,6 @@ ENTRY main {
 }

 TEST_F(TreeReductionRewriterTest, RowReductionBatchedDimensionDoesNotFit) {
   // Note: this could be too slow without shared memory optimization.
   const char* hlo_text = R"(
 HloModule ReduceWithPadding
@@ -225,26 +228,29 @@ ENTRY main {
   zero = f32[] constant(0)
   ROOT out = f32[100] reduce(input, zero), dimensions={0,2}, to_apply=add
 }
 )";

   EnsureDeterminism(hlo_text);

   MatchOptimizedHloWithShapes(hlo_text,
                               R"(
-// CHECK: %fused_computation (param_0.2: f32[32,100,10000]) -> f32[32,100] {
-// CHECK: %param_0.2 = f32[32,100,10000]{2,1,0} parameter(0)
+// CHECK: %fused_computation (param_0.4: f32[32,100,2]) -> f32[100] {
+// CHECK: %param_0.4 = f32[32,100,2]{2,1,0} parameter(0)
 // CHECK: %zero_1 = f32[] constant(0)
-// CHECK: %pad.1 = f32[32,100,16384]{2,1,0} pad(f32[32,100,10000]{2,1,0} %param_0.2, f32[] %zero_1), padding=0_0x0_0x0_6384
+// CHECK: %reduce.5 = f32[32,100]{1,0} reduce(f32[32,100,2]{2,1,0} %param_0.4, f32[] %zero_1), dimensions={2}, to_apply=%add
+// CHECK: ROOT %reduce.4 = f32[100]{0} reduce(f32[32,100]{1,0} %reduce.5, f32[] %zero_1), dimensions={0}, to_apply=%add
+// CHECK: }
+// CHECK: %fused_computation.1 (param_0.5: f32[32,100,10000]) -> f32[32,100,2] {
+// CHECK: %param_0.5 = f32[32,100,10000]{2,1,0} parameter(0)
+// CHECK: %zero_2 = f32[] constant(0)
+// CHECK: %pad.1 = f32[32,100,16384]{2,1,0} pad(f32[32,100,10000]{2,1,0} %param_0.5, f32[] %zero_2), padding=0_0x0_0x0_6384
 // CHECK: %bitcast.1 = f32[32,100,2,8192]{3,2,1,0} bitcast(f32[32,100,16384]{2,1,0} %pad.1)
-// CHECK: %reduce.5 = f32[32,100,8192]{2,1,0} reduce(f32[32,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_1), dimensions={2}, to_apply=%add
-// CHECK: ROOT %reduce.4 = f32[32,100]{1,0} reduce(f32[32,100,8192]{2,1,0} %reduce.5, f32[] %zero_1), dimensions={2}, to_apply=%add
+// CHECK: ROOT %reduce.6 = f32[32,100,2]{2,1,0} reduce(f32[32,100,2,8192]{3,2,1,0} %bitcast.1, f32[] %zero_2), dimensions={3}, to_apply=%add
 // CHECK: }
 // CHECK: ENTRY %main (input: f32[32,100,10000]) -> f32[100] {
 // CHECK: %input = f32[32,100,10000]{2,1,0} parameter(0)
-// CHECK: %fusion = f32[32,100]{1,0} fusion(f32[32,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation
-// CHECK: %zero = f32[] constant(0)
-// CHECK: ROOT %reduce.1 = f32[100]{0} reduce(f32[32,100]{1,0} %fusion, f32[] %zero), dimensions={0}, to_apply=%add
+// CHECK: %fusion.1 = f32[32,100,2]{2,1,0} fusion(f32[32,100,10000]{2,1,0} %input), kind=kInput, calls=%fused_computation.1
+// CHECK: ROOT %fusion = f32[100]{0} fusion(f32[32,100,2]{2,1,0} %fusion.1), kind=kInput, calls=%fused_computation
 // CHECK: }
 )");
@@ -306,7 +312,6 @@ ENTRY main {
   zero = f32[] constant(0)
   ROOT out = f32[2,2,2] reduce(input, zero), dimensions={0}, to_apply=add
 }
 )";

   MatchOptimizedHloWithShapes(hlo_text,
@@ -346,7 +351,6 @@ ENTRY main {
   zero = f32[] constant(0)
   ROOT out = f32[5] reduce(input, zero), dimensions={0}, to_apply=add
 }
 )";

   MatchOptimizedHloWithShapes(hlo_text,
@@ -38,6 +38,14 @@ limitations under the License.
 namespace xla {
 namespace gpu {

+// TODO(cheshire): duplication w/ GetReductionTiling, but we need to get a
+// minimum possible tiling, regardless of the input.
+static constexpr int64 kRowAtomicFreeBound = kWarpSize * kWarpSize * 8;
+static constexpr int64 kColumnAtomicFreeBound = kWarpSize * 128;
+
+// TODO(cheshire): This is very small, we could increase it at the cost of
+// decreased column/row tiling.
+static constexpr int64 kBatchedAtomicFreeBound = 8;
+
 class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
  public:
   explicit ReductionRewriterVisitor() {}
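As a rough numeric cross-check of the constants added above (again assuming kWarpSize == 32, which this hunk does not spell out), the bounds evaluate to 8192, 4096, and 8; the 8192-wide chunks are exactly what the updated row-reduction CHECK lines use (57344 = 7 * 8192, 1007616 = 123 * 8192).

#include <cstdint>

constexpr int64_t kWarpSize = 32;  // assumed value
constexpr int64_t kRowAtomicFreeBound = kWarpSize * kWarpSize * 8;
constexpr int64_t kColumnAtomicFreeBound = kWarpSize * 128;
constexpr int64_t kBatchedAtomicFreeBound = 8;

static_assert(kRowAtomicFreeBound == 8192, "row bound");
static_assert(kColumnAtomicFreeBound == 4096, "column bound");
static_assert(kBatchedAtomicFreeBound == 8, "batched bound");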
@@ -65,13 +73,8 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
     Shape input_shape = input->shape();
     VLOG(3) << "Input shape: " << input_shape.ToString();

-    std::array<int64, 3> reduction_tiling =
-        GetReductionTiling(reduction_dimensions);
-
-    int64 batched_atomic_free_bound = reduction_tiling[0];
     bool reduce_batch_dimension = hlo->dimensions().size() > 1;
     VLOG(3) << "reduce_batch_dimension = " << reduce_batch_dimension;
-    VLOG(3) << "batched atomic free: " << batched_atomic_free_bound;

     std::vector<int64> reduced_dimensions = hlo->dimensions();
     absl::c_sort(reduced_dimensions);
@@ -82,19 +85,16 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor {

     // Case (1): batched dimension does not fit.
     if (reduce_batch_dimension &&
-        input_shape.dimensions(0) > batched_atomic_free_bound) {
+        input_shape.dimensions(0) > kBatchedAtomicFreeBound) {
       VLOG(1) << "Splitting batched dimension reduce into a separate reduction";
       return RewriteBatchDimensionLargerThanTile(hlo, reduction_dimensions,
                                                  reduced_input_dimension,
                                                  input_shape, input);
     }
+    bool is_row_reduction = reduction_dimensions.is_row_reduction;

-    int64 atomic_free_bound = [&] {
-      if (reduction_dimensions.is_row_reduction) {
-        return reduction_tiling[2] * kWarpSize * kWarpSize;
-      }
-      return reduction_tiling[1] * kWarpSize;
-    }();
+    int64 atomic_free_bound =
+        is_row_reduction ? kRowAtomicFreeBound : kColumnAtomicFreeBound;
     VLOG(3) << "atomic_free_bound: " << atomic_free_bound;

     // Base case: everything fits.
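The hunk above boils down to a three-way decision. The function below is a simplified, hypothetical rendering of that control flow (the enum, signature, and constant values are stand-ins, not XLA API), shown only to make the order of the checks explicit.

#include <cstdint>

enum class Split { kNone, kBatchedDimension, kReducedDimension };

Split ChooseSplit(bool reduce_batch_dimension, int64_t batch_dim_size,
                  bool is_row_reduction, int64_t reduced_dim_size) {
  constexpr int64_t kRowAtomicFreeBound = 32 * 32 * 8;  // 8192, assuming kWarpSize == 32
  constexpr int64_t kColumnAtomicFreeBound = 32 * 128;  // 4096
  constexpr int64_t kBatchedAtomicFreeBound = 8;

  // Case (1): the batched dimension does not fit; split it off first.
  if (reduce_batch_dimension && batch_dim_size > kBatchedAtomicFreeBound) {
    return Split::kBatchedDimension;
  }
  const int64_t atomic_free_bound =
      is_row_reduction ? kRowAtomicFreeBound : kColumnAtomicFreeBound;
  // Base case: everything fits, no rewrite needed.
  if (reduced_dim_size <= atomic_free_bound) {
    return Split::kNone;
  }
  // Otherwise the reduced dimension itself is padded and split.
  return Split::kReducedDimension;
}

For the RowReductionBatchedDimensionDoesNotFit test above, ChooseSplit(true, 32, true, 10000) returns kBatchedDimension, since the batch of 32 exceeds the bound of 8.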
@@ -105,10 +105,24 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor {

     int64 reduced_dim_size = input_shape.dimensions(reduced_input_dimension);
     VLOG(3) << "reduced_dim_size = " << reduced_dim_size;
+    // TODO(cheshire): if atomic_free_bound is large, num_fit is likely to be
+    // small. Generating a reduction with very small reduced dimension is not
+    // efficient, it would be better to split the dimension sizes more evenly.
+    //
+    // One possible idea is to pad to a nearest square (ceil(sqrt(x)))^2.
+    // Given that:
+    //
+    // (n + 1)^2 = n^2 + (2n+1)
+    //
+    // it can be seen that the distance to the nearest square is at most twice
+    // the square root of the input number.
     int64 num_fit = CeilOfRatio(reduced_dim_size, atomic_free_bound);

     // Pad reduced dimension to the required number of elements.
     HloInstruction *padded = [&] {
+      // TODO(cheshire): if atomic_free_bound is very large, padding all the way
+      // up to to atomic_free_bound is wasteful, we could pad to a much smaller
+      // value.
       if (reduced_dim_size % atomic_free_bound != 0) {
         int64 padded_num_elements = num_fit * atomic_free_bound;
         PaddingConfig padding_config = MakeNoPaddingConfig(input_shape.rank());
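To make the nearest-square TODO above concrete (an illustration of the idea only, not something this commit implements): padding x up to (ceil(sqrt(x)))^2 costs at most about 2*sqrt(x) extra elements, because (n+1)^2 - n^2 = 2n + 1.

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t x = 50000;  // reduced_dim_size from the first test case
  const int64_t n =
      static_cast<int64_t>(std::ceil(std::sqrt(static_cast<double>(x))));
  const int64_t padded = n * n;
  // Prints: n=224 padded=50176 overhead=176, well under the ~2*sqrt(50000)+1
  // worst case. This would split f32[50000] into 224 chunks of 224 instead of
  // the current 7 chunks of 8192.
  std::printf("n=%lld padded=%lld overhead=%lld\n",
              static_cast<long long>(n), static_cast<long long>(padded),
              static_cast<long long>(padded - x));
  return 0;
}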
@@ -145,20 +159,21 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
     VLOG(1) << "Generated reshape: " << reshaped_padded_input->ToString();

     std::vector<int64> inner_reduce_dimensions = reshaped_dimensions;
+    int64 inner_reduced_dimension = is_row_reduction
+                                        ? inner_reduce_dimensions.size() - 1
+                                        : reduced_input_dimension;
+    VLOG(2) << "inner_reduced_dimension = " << inner_reduced_dimension;
     inner_reduce_dimensions.erase(inner_reduce_dimensions.begin() +
-                                  reduced_input_dimension);
+                                  inner_reduced_dimension);
     if (reduce_batch_dimension) {
       inner_reduce_dimensions.erase(inner_reduce_dimensions.begin());
     }

     Shape inner_reduce_shape = ShapeUtil::MakeShape(input_shape.element_type(),
                                                     inner_reduce_dimensions);
-    std::vector<int64> dims_to_reduce = {reduced_input_dimension};
-
-    int64 reduced_inner_dimension = reduced_input_dimension;
+    std::vector<int64> dims_to_reduce = {inner_reduced_dimension};
     if (reduce_batch_dimension) {
       dims_to_reduce.push_back(0);
-      reduced_inner_dimension -= 1;
+      inner_reduced_dimension -= 1;
     }

     HloInstruction *inner_reduce =
@@ -166,20 +181,21 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor {
             inner_reduce_shape, reshaped_padded_input, initial_value,
             dims_to_reduce, hlo->to_apply()));
     VLOG(1) << "Generated inner reduction: " << inner_reduce->ToString();

     std::vector<int64> outer_reduce_dimensions = inner_reduce_dimensions;
     VLOG(3) << "outer_reduce_dimensions = "
             << absl::StrJoin(outer_reduce_dimensions, ", ");
-    VLOG(3) << "reduced_inner_dimension = " << reduced_inner_dimension;
+    int64 outer_reduced_dimension = is_row_reduction
+                                        ? outer_reduce_dimensions.size() - 1
+                                        : reduced_input_dimension;

     // Remove reduced dimension.
     outer_reduce_dimensions.erase(outer_reduce_dimensions.begin() +
-                                  reduced_inner_dimension);
+                                  outer_reduced_dimension);
     Shape outer_reduce_shape = ShapeUtil::MakeShape(input_shape.element_type(),
                                                     outer_reduce_dimensions);
     std::unique_ptr<HloInstruction> outer_reduce = HloInstruction::CreateReduce(
         outer_reduce_shape, inner_reduce, initial_value,
-        {reduced_inner_dimension}, hlo->to_apply());
+        {outer_reduced_dimension}, hlo->to_apply());

     VLOG(1) << "Generated outer reduction: " << outer_reduce->ToString();
     return ReplaceWithNewInstruction(hlo, std::move(outer_reduce));
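Finally, a hand-traced sketch of the fixed dimension bookkeeping for the f32[8,100,10000] dims={0,2} test case above (a standalone illustration, not the rewriter itself; the column-reduction branch is only stubbed out).

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // After padding 10000 -> 16384 and splitting the minor dimension into
  // 2 x 8192 chunks, the reshaped input is f32[8,100,2,8192].
  std::vector<int64_t> inner_reduce_dimensions = {8, 100, 2, 8192};
  const bool is_row_reduction = true;
  const bool reduce_batch_dimension = true;

  // The fix: for row reductions the inner reduce targets the last dimension
  // (plus batch dimension 0), i.e. dimensions={3,0} in the CHECK lines.
  // The value 2 below stands in for reduced_input_dimension (column case).
  const int64_t inner_reduced_dimension =
      is_row_reduction
          ? static_cast<int64_t>(inner_reduce_dimensions.size()) - 1
          : 2;
  inner_reduce_dimensions.erase(inner_reduce_dimensions.begin() +
                                inner_reduced_dimension);
  if (reduce_batch_dimension) {
    inner_reduce_dimensions.erase(inner_reduce_dimensions.begin());
  }
  // inner_reduce_dimensions is now {100, 2}: the f32[100,2] fusion output.

  std::vector<int64_t> outer_reduce_dimensions = inner_reduce_dimensions;
  const int64_t outer_reduced_dimension =
      static_cast<int64_t>(outer_reduce_dimensions.size()) - 1;  // 1
  outer_reduce_dimensions.erase(outer_reduce_dimensions.begin() +
                                outer_reduced_dimension);
  // outer_reduce_dimensions is now {100}: the final f32[100] result of
  // reducing dimension 1 of the f32[100,2] fusion output.

  std::printf("inner output dims: {%lld,%lld}, outer output dims: {%lld}\n",
              static_cast<long long>(inner_reduce_dimensions[0]),
              static_cast<long long>(inner_reduce_dimensions[1]),
              static_cast<long long>(outer_reduce_dimensions[0]));
  return 0;
}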