Improve layout related multi-output fusion heuristic.
We should only look at the relative order of the non-trivial dimensions. Make sure that the relative order of non-trivial dimensions is the same for all shapes with maximum true rank. PiperOrigin-RevId: 303311472 Change-Id: I54e2a8f404b78da874dabfd6078bba8a8d4f2fa0
This commit is contained in:
parent
411203b713
commit
d81a767881
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <iterator>
|
||||
#include <stack>
|
||||
#include <vector>
|
||||
|
@ -66,6 +67,25 @@ bool IfFusedReadsElementsMultipleTimes(const HloInstruction& instr) {
|
|||
return false;
|
||||
}
|
||||
|
||||
std::vector<int64> ExtractRelativeOrderOfNontrivialDims(const Shape& shape) {
|
||||
std::vector<int64> relative_order;
|
||||
for (int64 dim : LayoutUtil::MinorToMajor(shape)) {
|
||||
if (shape.dimensions(dim) > 1) {
|
||||
relative_order.push_back(dim);
|
||||
}
|
||||
}
|
||||
// Now normalize the dimensions to values between 0 and true rank - 1.
|
||||
std::vector<int64> sorted_dims = relative_order;
|
||||
std::sort(sorted_dims.begin(), sorted_dims.end());
|
||||
for (int64& dim : relative_order) {
|
||||
int64 sorted_index = std::distance(
|
||||
sorted_dims.begin(),
|
||||
std::lower_bound(sorted_dims.begin(), sorted_dims.end(), dim));
|
||||
dim = sorted_index;
|
||||
}
|
||||
return relative_order;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer,
|
||||
|
@ -73,17 +93,20 @@ bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer,
|
|||
std::vector<HloInstruction*> params;
|
||||
AppendParams(producer, ¶ms);
|
||||
AppendParams(reduce, ¶ms);
|
||||
int64 max_rank = -1;
|
||||
const Layout* max_rank_layout;
|
||||
int64 max_true_rank = -1;
|
||||
std::vector<int64> max_rank_order;
|
||||
for (HloInstruction* param : params) {
|
||||
if (param->shape().IsArray() && param->shape().rank() > max_rank) {
|
||||
max_rank = param->shape().rank();
|
||||
max_rank_layout = ¶m->shape().layout();
|
||||
if (param->shape().IsArray() &&
|
||||
ShapeUtil::TrueRank(param->shape()) > max_true_rank) {
|
||||
max_true_rank = ShapeUtil::TrueRank(param->shape());
|
||||
max_rank_order = ExtractRelativeOrderOfNontrivialDims(param->shape());
|
||||
}
|
||||
}
|
||||
return absl::c_all_of(params, [&](HloInstruction* param) {
|
||||
return (!param->shape().IsArray()) || (param->shape().rank() < max_rank) ||
|
||||
(LayoutUtil::Equal(param->shape().layout(), *max_rank_layout));
|
||||
return !param->shape().IsArray() ||
|
||||
ShapeUtil::TrueRank(param->shape()) < max_true_rank ||
|
||||
ExtractRelativeOrderOfNontrivialDims(param->shape()) ==
|
||||
max_rank_order;
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
@ -91,6 +91,44 @@ TEST_F(GpuFusibleTest,
|
|||
LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce_fusion));
|
||||
}
|
||||
|
||||
TEST_F(GpuFusibleTest,
|
||||
LayoutsAreReduceInputFusionFriendly_MixedLayoutProducerWithTrivialDim) {
|
||||
auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
|
||||
mixed_input_layouts_computation {
|
||||
p0.1 = f16[128,1,32,32]{1,3,2,0} parameter(0)
|
||||
p1.1 = f16[128,1,32,32]{3,2,1,0} parameter(1)
|
||||
copy = f16[128,1,32,32]{1,3,2,0} copy(p1.1)
|
||||
c0 = f16[] constant(0)
|
||||
broadcast = f16[128,1,32,32]{1,3,2,0} broadcast(c0), dimensions={}
|
||||
greater-than = pred[128,1,32,32]{1,3,2,0} compare(copy, broadcast), direction=GT
|
||||
ROOT root = f16[128,1,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast)
|
||||
}
|
||||
fused_reduce {
|
||||
p0.2 = f16[128,1,32,32]{1,3,2,0} parameter(0)
|
||||
convert = f32[128,1,32,32]{1,3,2,0} convert(p0.2)
|
||||
c0.2 = f32[] constant(0)
|
||||
ROOT reduce = f32[1]{0} reduce(convert, c0.2), dimensions={0,2,3}, to_apply=scalar_add
|
||||
}
|
||||
ENTRY entry {
|
||||
p0 = f16[128,1,32,32]{1,3,2,0} parameter(0)
|
||||
p1 = f16[128,1,32,32]{3,2,1,0} parameter(1)
|
||||
loop_fusion = f16[128,1,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=mixed_input_layouts_computation
|
||||
reduce_fusion = f32[1]{0} fusion(loop_fusion), kind=kInput, calls=fused_reduce
|
||||
ROOT root = (f32[1]{0}, f16[128,1,32,32]{1,3,2,0}) tuple(reduce_fusion, loop_fusion)
|
||||
})"))
|
||||
.ValueOrDie();
|
||||
SCOPED_TRACE(module->ToString());
|
||||
const HloInstruction* reduce_fusion =
|
||||
module->entry_computation()->root_instruction()->operand(0);
|
||||
ASSERT_EQ(reduce_fusion->fused_expression_root()->opcode(),
|
||||
HloOpcode::kReduce);
|
||||
const HloInstruction* loop_fusion =
|
||||
module->entry_computation()->root_instruction()->operand(1);
|
||||
ASSERT_EQ(loop_fusion->fused_expression_root()->opcode(), HloOpcode::kSelect);
|
||||
EXPECT_TRUE(
|
||||
LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce_fusion));
|
||||
}
|
||||
|
||||
TEST_F(GpuFusibleTest, LayoutsAreReduceInputFusionFriendly_CopyProducer) {
|
||||
auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
|
||||
fused_reduce {
|
||||
|
@ -152,17 +190,18 @@ TEST_F(GpuFusibleTest,
|
|||
}
|
||||
|
||||
TEST_F(GpuFusibleTest,
|
||||
LayoutsAreReduceInputFusionFriendly_ConsiderMaximumRanksParamsOnly) {
|
||||
LayoutsAreReduceInputFusionFriendly_ConsiderMaximumTrueRanksParamsOnly) {
|
||||
auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
|
||||
broadcasting_computation {
|
||||
p0.1 = f32[128,1024,32,32]{1,3,2,0} parameter(0)
|
||||
p1.1 = f32[128]{0} parameter(1)
|
||||
broadcast = f32[128,1024,32,32]{1,3,2,0} broadcast(p1.1), dimensions={0}
|
||||
p1.1 = f32[1,128,1,1]{3,2,1,0} parameter(1)
|
||||
reshape = f32[128]{0} reshape(p1.1)
|
||||
broadcast = f32[128,1024,32,32]{1,3,2,0} broadcast(reshape), dimensions={0}
|
||||
ROOT add = f32[128,1024,32,32]{1,3,2,0} add(p0.1, broadcast)
|
||||
}
|
||||
ENTRY entry {
|
||||
p0 = f32[128,1024,32,32]{1,3,2,0} parameter(0)
|
||||
p1 = f32[128]{0} parameter(1)
|
||||
p1 = f32[1,128,1,1]{3,2,1,0} parameter(1)
|
||||
loop_fusion = f32[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=broadcasting_computation
|
||||
c0.2 = f32[] constant(0)
|
||||
ROOT reduce = f32[1024]{0} reduce(loop_fusion, c0.2), dimensions={0,2,3}, to_apply=scalar_add
|
||||
|
|
Loading…
Reference in New Issue