Properly determine the shared memory requirements of a fusion node.

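So far, we assumed that every fusion node is a multi-output fusion and that
each reduce in the fusion is an operand of the tuple root. This may not always
be true, so the shared memory usage is now summed over all instructions of the
fused computation instead of only over the operands of the root.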

PiperOrigin-RevId: 340391529
Change-Id: Ie193d0828e911183feb5e650224b171f98ceaa0a
This commit is contained in:
Adrian Kuegel 2020-11-03 00:11:13 -08:00 committed by TensorFlower Gardener
parent cc53d3a3e0
commit 8d9f9141f4
2 changed files with 60 additions and 4 deletions

View File

@ -334,9 +334,9 @@ static int64 SharedMemoryUsage(const HloInstruction& instr) {
}
} else if (instr.opcode() == HloOpcode::kFusion) {
int64 sum = 0;
for (const HloInstruction* operand :
instr.fused_expression_root()->operands()) {
sum += SharedMemoryUsage(*operand);
for (const HloInstruction* hlo :
instr.fused_instructions_computation()->MakeInstructionPostOrder()) {
sum += SharedMemoryUsage(*hlo);
}
return sum;
}

View File

@ -906,7 +906,63 @@ TEST_F(MultiOutputFusionTest, SharedMemoryBudget) {
.ValueOrDie();
ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ConsumeValueOrDie());
EXPECT_EQ(2, CountMultiOutputFusions(module.get()));
EXPECT_EQ(3, CountMultiOutputFusions(module.get()));
}
// Runs GpuMultiOutputFusion on a module whose entry computation contains two
// kInput fusions that read the same parameters: %fusion.1, which already
// produces a tuple of two reduces over dimension 0 of f64[64,64] inputs, and
// %fusion.2, which produces a single such reduce. The pass is expected to
// report that it changed nothing; per the test name, the intent is that
// combining the two fusions would need too much shared memory.
TEST_F(MultiOutputFusionTest, NoFusionToAvoidUsingTooMuchSharedMemory) {
  const char* const hlo_text = R"(
HloModule xla_computation_update_step.10931
%scalar_add_computation.1 (scalar_lhs.1: f64[], scalar_rhs.1: f64[]) -> f64[] {
%scalar_lhs.1 = f64[] parameter(0)
%scalar_rhs.1 = f64[] parameter(1)
ROOT %add.1257 = f64[] add(f64[] %scalar_lhs.1, f64[] %scalar_rhs.1)
}
%fused_computation.1 (param_0.8: f64[64,64], param_1.11: f64[64,64], param_2.9: f64[64,64]) -> (f64[64], f64[64]) {
%param_0.8 = f64[64,64]{1,0} parameter(0)
%param_1.11 = f64[64,64]{1,0} parameter(1)
%multiply.2 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %param_0.8, f64[64,64]{1,0} %param_1.11)
%constant_5217.3 = f64[] constant(0)
%broadcast.1 = f64[64,64]{1,0} broadcast(f64[] %constant_5217.3), dimensions={}
%multiply.0 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %multiply.2, f64[64,64]{1,0} %broadcast.1)
%reduce.0 = f64[64]{0} reduce(f64[64,64]{1,0} %multiply.0, f64[] %constant_5217.3), dimensions={0}, to_apply=%scalar_add_computation.1
%param_2.9 = f64[64,64]{1,0} parameter(2)
%multiply.1514.clone.0.clone.1 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %param_2.9, f64[64,64]{1,0} %param_1.11)
%constant_5217.1.clone.1 = f64[] constant(0)
%broadcast.0.clone.1 = f64[64,64]{1,0} broadcast(f64[] %constant_5217.1.clone.1), dimensions={}
%multiply.1341.clone.0.clone.1 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %multiply.1514.clone.0.clone.1, f64[64,64]{1,0} %broadcast.0.clone.1)
%reduce.630.clone.0.clone.1 = f64[64]{0} reduce(f64[64,64]{1,0} %multiply.1341.clone.0.clone.1, f64[] %constant_5217.1.clone.1), dimensions={0}, to_apply=%scalar_add_computation.1
ROOT %tuple = (f64[64]{0}, f64[64]{0}) tuple(f64[64]{0} %reduce.0, f64[64]{0} %reduce.630.clone.0.clone.1)
}
%primitive_computation_add__1.6426 (parameter.6427: f64[], parameter.6428: f64[]) -> f64[] {
%parameter.6427 = f64[] parameter(0)
%parameter.6428 = f64[] parameter(1)
ROOT %add.6429 = f64[] add(f64[] %parameter.6427, f64[] %parameter.6428)
}
%fused_computation.2 (param_0.7: f64[64,64], param_1.9: f64[64,64]) -> f64[64] {
%param_0.7 = f64[64,64]{1,0} parameter(0)
%param_1.9 = f64[64,64]{1,0} parameter(1)
%multiply.1 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %param_0.7, f64[64,64]{1,0} %param_1.9)
%constant_5217.2 = f64[] constant(0)
ROOT %reduce.740.clone.0 = f64[64]{0} reduce(f64[64,64]{1,0} %multiply.1, f64[] %constant_5217.2), dimensions={0}, to_apply=%primitive_computation_add__1.6426
}
ENTRY %reproducer (param_0.1090: f64[64,64], param_1.1377: f64[64,64], param_2.1948: f64[64,64]) -> (f64[64], f64[64], f64[64]) {
%param_0.1090 = f64[64,64]{1,0} parameter(0)
%param_1.1377 = f64[64,64]{1,0} parameter(1)
%param_2.1948 = f64[64,64]{1,0} parameter(2)
%fusion.1 = (f64[64]{0}, f64[64]{0}) fusion(f64[64,64]{1,0} %param_0.1090, f64[64,64]{1,0} %param_1.1377, f64[64,64]{1,0} %param_2.1948), kind=kInput, calls=%fused_computation.1
%get-tuple-element = f64[64]{0} get-tuple-element((f64[64]{0}, f64[64]{0}) %fusion.1), index=0
%fusion.2 = f64[64]{0} fusion(f64[64,64]{1,0} %param_0.1090, f64[64,64]{1,0} %param_1.1377), kind=kInput, calls=%fused_computation.2
%get-tuple-element.1 = f64[64]{0} get-tuple-element((f64[64]{0}, f64[64]{0}) %fusion.1), index=1
ROOT %tuple.428 = (f64[64]{0}, f64[64]{0}, f64[64]{0}) tuple(f64[64]{0} %get-tuple-element, f64[64]{0} %fusion.2, f64[64]{0} %get-tuple-element.1)
}
)";
  // Parsing must succeed; ValueOrDie aborts the test otherwise.
  auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
  // Run() yields whether the pass modified the module.
  const bool module_changed =
      GpuMultiOutputFusion().Run(module.get()).ConsumeValueOrDie();
  EXPECT_FALSE(module_changed);
}
TEST_F(MultiOutputFusionTest, NoFusionToAvoidCodeDuplication) {