Improve layout-related multi-output fusion heuristic.

We should only look at the relative order of the non-trivial dimensions.
Make sure that the relative order of non-trivial dimensions is the same
for all shapes with maximum true rank.

PiperOrigin-RevId: 303311472
Change-Id: I54e2a8f404b78da874dabfd6078bba8a8d4f2fa0
parent 411203b713
commit d81a767881
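In other words, two parameter shapes are now considered layout-compatible for reduce input fusion when, after dropping size-1 dimensions, the remaining dimensions appear in the same relative minor-to-major order; parameters with less than the maximum true rank are not compared at all. A worked example follows the new ExtractRelativeOrderOfNontrivialDims helper below.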
--- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
 
+#include <algorithm>
 #include <iterator>
 #include <stack>
 #include <vector>
@@ -66,6 +67,25 @@ bool IfFusedReadsElementsMultipleTimes(const HloInstruction& instr) {
   return false;
 }
 
+std::vector<int64> ExtractRelativeOrderOfNontrivialDims(const Shape& shape) {
+  std::vector<int64> relative_order;
+  for (int64 dim : LayoutUtil::MinorToMajor(shape)) {
+    if (shape.dimensions(dim) > 1) {
+      relative_order.push_back(dim);
+    }
+  }
+  // Now normalize the dimensions to values between 0 and true rank - 1.
+  std::vector<int64> sorted_dims = relative_order;
+  std::sort(sorted_dims.begin(), sorted_dims.end());
+  for (int64& dim : relative_order) {
+    int64 sorted_index = std::distance(
+        sorted_dims.begin(),
+        std::lower_bound(sorted_dims.begin(), sorted_dims.end(), dim));
+    dim = sorted_index;
+  }
+  return relative_order;
+}
+
 }  // namespace
 
 bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer,
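For intuition, here is a minimal standalone sketch of the normalization the new helper performs. RelativeOrderOfNontrivialDims is a hypothetical stand-in for ExtractRelativeOrderOfNontrivialDims above: it takes the dimension sizes and the minor-to-major order as plain vectors instead of an XLA Shape, but implements the same two steps (filter out size-1 dimensions, then rank-compress the surviving dimension numbers):

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Keep only dimensions of size > 1, in minor-to-major order, then
    // rank-compress the dimension numbers into the range [0, true_rank).
    std::vector<int64_t> RelativeOrderOfNontrivialDims(
        const std::vector<int64_t>& dim_sizes,
        const std::vector<int64_t>& minor_to_major) {
      std::vector<int64_t> relative_order;
      for (int64_t dim : minor_to_major) {
        if (dim_sizes[dim] > 1) relative_order.push_back(dim);
      }
      std::vector<int64_t> sorted_dims = relative_order;
      std::sort(sorted_dims.begin(), sorted_dims.end());
      for (int64_t& dim : relative_order) {
        dim = std::distance(
            sorted_dims.begin(),
            std::lower_bound(sorted_dims.begin(), sorted_dims.end(), dim));
      }
      return relative_order;
    }

    int main() {
      // f16[128,1,32,32]{1,3,2,0}: dim 1 is trivial, so the non-trivial
      // minor-to-major order is (3, 2, 0), which normalizes to (2, 1, 0).
      for (int64_t d :
           RelativeOrderOfNontrivialDims({128, 1, 32, 32}, {1, 3, 2, 0}))
        std::cout << d << " ";  // prints: 2 1 0
      std::cout << "\n";
      // f16[128,1,32,32]{3,2,1,0}: non-trivial order (3, 2, 0) -> (2, 1, 0)
      // as well, so these two layouts compare equal under the new heuristic
      // even though LayoutUtil::Equal would reject them.
      for (int64_t d :
           RelativeOrderOfNontrivialDims({128, 1, 32, 32}, {3, 2, 1, 0}))
        std::cout << d << " ";  // prints: 2 1 0
      std::cout << "\n";
      return 0;
    }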
@@ -73,17 +93,20 @@ bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer,
   std::vector<HloInstruction*> params;
   AppendParams(producer, &params);
   AppendParams(reduce, &params);
-  int64 max_rank = -1;
-  const Layout* max_rank_layout;
+  int64 max_true_rank = -1;
+  std::vector<int64> max_rank_order;
   for (HloInstruction* param : params) {
-    if (param->shape().IsArray() && param->shape().rank() > max_rank) {
-      max_rank = param->shape().rank();
-      max_rank_layout = &param->shape().layout();
+    if (param->shape().IsArray() &&
+        ShapeUtil::TrueRank(param->shape()) > max_true_rank) {
+      max_true_rank = ShapeUtil::TrueRank(param->shape());
+      max_rank_order = ExtractRelativeOrderOfNontrivialDims(param->shape());
     }
   }
   return absl::c_all_of(params, [&](HloInstruction* param) {
-    return (!param->shape().IsArray()) || (param->shape().rank() < max_rank) ||
-           (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout));
+    return !param->shape().IsArray() ||
+           ShapeUtil::TrueRank(param->shape()) < max_true_rank ||
+           ExtractRelativeOrderOfNontrivialDims(param->shape()) ==
+               max_rank_order;
   });
 }
 
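Design note: the previous check required all maximum-rank parameters to have bitwise-equal layouts (LayoutUtil::Equal), which rejected pairs of shapes whose layouts differ only in where size-1 dimensions sit. The new check ranks parameters by ShapeUtil::TrueRank (the number of dimensions of size > 1) and compares the normalized relative order of the non-trivial dimensions instead; since size-1 dimensions cannot affect the physical element order, this is the weaker condition the fusion heuristic actually needs.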
--- a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc
@@ -91,6 +91,44 @@ TEST_F(GpuFusibleTest,
       LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce_fusion));
 }
 
+TEST_F(GpuFusibleTest,
+       LayoutsAreReduceInputFusionFriendly_MixedLayoutProducerWithTrivialDim) {
+  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
+    mixed_input_layouts_computation {
+      p0.1 = f16[128,1,32,32]{1,3,2,0} parameter(0)
+      p1.1 = f16[128,1,32,32]{3,2,1,0} parameter(1)
+      copy = f16[128,1,32,32]{1,3,2,0} copy(p1.1)
+      c0 = f16[] constant(0)
+      broadcast = f16[128,1,32,32]{1,3,2,0} broadcast(c0), dimensions={}
+      greater-than = pred[128,1,32,32]{1,3,2,0} compare(copy, broadcast), direction=GT
+      ROOT root = f16[128,1,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast)
+    }
+    fused_reduce {
+      p0.2 = f16[128,1,32,32]{1,3,2,0} parameter(0)
+      convert = f32[128,1,32,32]{1,3,2,0} convert(p0.2)
+      c0.2 = f32[] constant(0)
+      ROOT reduce = f32[1]{0} reduce(convert, c0.2), dimensions={0,2,3}, to_apply=scalar_add
+    }
+    ENTRY entry {
+      p0 = f16[128,1,32,32]{1,3,2,0} parameter(0)
+      p1 = f16[128,1,32,32]{3,2,1,0} parameter(1)
+      loop_fusion = f16[128,1,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=mixed_input_layouts_computation
+      reduce_fusion = f32[1]{0} fusion(loop_fusion), kind=kInput, calls=fused_reduce
+      ROOT root = (f32[1]{0}, f16[128,1,32,32]{1,3,2,0}) tuple(reduce_fusion, loop_fusion)
+    })"))
+                    .ValueOrDie();
+  SCOPED_TRACE(module->ToString());
+  const HloInstruction* reduce_fusion =
+      module->entry_computation()->root_instruction()->operand(0);
+  ASSERT_EQ(reduce_fusion->fused_expression_root()->opcode(),
+            HloOpcode::kReduce);
+  const HloInstruction* loop_fusion =
+      module->entry_computation()->root_instruction()->operand(1);
+  ASSERT_EQ(loop_fusion->fused_expression_root()->opcode(), HloOpcode::kSelect);
+  EXPECT_TRUE(
+      LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce_fusion));
+}
+
 TEST_F(GpuFusibleTest, LayoutsAreReduceInputFusionFriendly_CopyProducer) {
   auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
     fused_reduce {
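The new MixedLayoutProducerWithTrivialDim test covers exactly the case motivating this change: the producer mixes layouts {1,3,2,0} and {3,2,1,0} over f16[128,1,32,32], which LayoutUtil::Equal treats as different, but dimension 1 is trivial, so both layouts normalize to the same relative order (2, 1, 0) and the fusion is now judged friendly (EXPECT_TRUE).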
@@ -152,17 +190,18 @@ TEST_F(GpuFusibleTest,
 }
 
 TEST_F(GpuFusibleTest,
-       LayoutsAreReduceInputFusionFriendly_ConsiderMaximumRanksParamsOnly) {
+       LayoutsAreReduceInputFusionFriendly_ConsiderMaximumTrueRanksParamsOnly) {
   auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
     broadcasting_computation {
       p0.1 = f32[128,1024,32,32]{1,3,2,0} parameter(0)
-      p1.1 = f32[128]{0} parameter(1)
-      broadcast = f32[128,1024,32,32]{1,3,2,0} broadcast(p1.1), dimensions={0}
+      p1.1 = f32[1,128,1,1]{3,2,1,0} parameter(1)
+      reshape = f32[128]{0} reshape(p1.1)
+      broadcast = f32[128,1024,32,32]{1,3,2,0} broadcast(reshape), dimensions={0}
       ROOT add = f32[128,1024,32,32]{1,3,2,0} add(p0.1, broadcast)
     }
     ENTRY entry {
       p0 = f32[128,1024,32,32]{1,3,2,0} parameter(0)
-      p1 = f32[128]{0} parameter(1)
+      p1 = f32[1,128,1,1]{3,2,1,0} parameter(1)
       loop_fusion = f32[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=broadcasting_computation
       c0.2 = f32[] constant(0)
       ROOT reduce = f32[1024]{0} reduce(loop_fusion, c0.2), dimensions={0,2,3}, to_apply=scalar_add
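The updated ConsiderMaximumTrueRanksParamsOnly test makes p1 a rank-4 shape f32[1,128,1,1]{3,2,1,0} whose true rank is still 1: under the old rank-based comparison its layout would now be compared against p0's {1,3,2,0} and the fusion rejected, whereas the true-rank-based check still ignores it, because only parameters of maximum true rank (here p0, with true rank 4) constrain the layout.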