Properly determine the shared memory requirements of a fusion node.
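Previously, SharedMemoryUsage assumed that every fusion node is a multi-output fusion and that each reduce inside it is a direct operand of the tuple root, so it only inspected the operands of the fused expression root. That assumption does not always hold (for example, a single-output fusion whose root is itself a reduce), so the function now sums the shared memory usage over every instruction in the fused computation. PiperOrigin-RevId: 340391529 Change-Id: Ie193d0828e911183feb5e650224b171f98ceaa0a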
This commit is contained in:
parent
cc53d3a3e0
commit
8d9f9141f4
@ -334,9 +334,9 @@ static int64 SharedMemoryUsage(const HloInstruction& instr) {
|
||||
}
|
||||
} else if (instr.opcode() == HloOpcode::kFusion) {
|
||||
int64 sum = 0;
|
||||
for (const HloInstruction* operand :
|
||||
instr.fused_expression_root()->operands()) {
|
||||
sum += SharedMemoryUsage(*operand);
|
||||
for (const HloInstruction* hlo :
|
||||
instr.fused_instructions_computation()->MakeInstructionPostOrder()) {
|
||||
sum += SharedMemoryUsage(*hlo);
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
@ -906,7 +906,63 @@ TEST_F(MultiOutputFusionTest, SharedMemoryBudget) {
|
||||
.ValueOrDie();
|
||||
ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ConsumeValueOrDie());
|
||||
|
||||
EXPECT_EQ(2, CountMultiOutputFusions(module.get()));
|
||||
EXPECT_EQ(3, CountMultiOutputFusions(module.get()));
|
||||
}
|
||||
|
||||
// Verifies that GpuMultiOutputFusion does not change the HLO module below.
//
// The module contains:
//   * %fusion.1 - an existing multi-output fusion whose tuple root collects two
//     reduces, each reducing an f64[64,64] operand over dimension 0.
//   * %fusion.2 - a single-output fusion whose root is itself a reduce (not a
//     tuple), consuming the same %param_0.1090 / %param_1.1377 as %fusion.1.
//
// NOTE(review): per the test name, the pass is expected to decline merging
// these fusions because the combined shared memory usage would be too large;
// the budget itself is not visible in this test - confirm against
// SharedMemoryUsage and its callers.
TEST_F(MultiOutputFusionTest, NoFusionToAvoidUsingTooMuchSharedMemory) {
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule xla_computation_update_step.10931
|
||||
|
||||
%scalar_add_computation.1 (scalar_lhs.1: f64[], scalar_rhs.1: f64[]) -> f64[] {
|
||||
%scalar_lhs.1 = f64[] parameter(0)
|
||||
%scalar_rhs.1 = f64[] parameter(1)
|
||||
ROOT %add.1257 = f64[] add(f64[] %scalar_lhs.1, f64[] %scalar_rhs.1)
|
||||
}
|
||||
|
||||
%fused_computation.1 (param_0.8: f64[64,64], param_1.11: f64[64,64], param_2.9: f64[64,64]) -> (f64[64], f64[64]) {
|
||||
%param_0.8 = f64[64,64]{1,0} parameter(0)
|
||||
%param_1.11 = f64[64,64]{1,0} parameter(1)
|
||||
%multiply.2 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %param_0.8, f64[64,64]{1,0} %param_1.11)
|
||||
%constant_5217.3 = f64[] constant(0)
|
||||
%broadcast.1 = f64[64,64]{1,0} broadcast(f64[] %constant_5217.3), dimensions={}
|
||||
%multiply.0 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %multiply.2, f64[64,64]{1,0} %broadcast.1)
|
||||
%reduce.0 = f64[64]{0} reduce(f64[64,64]{1,0} %multiply.0, f64[] %constant_5217.3), dimensions={0}, to_apply=%scalar_add_computation.1
|
||||
%param_2.9 = f64[64,64]{1,0} parameter(2)
|
||||
%multiply.1514.clone.0.clone.1 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %param_2.9, f64[64,64]{1,0} %param_1.11)
|
||||
%constant_5217.1.clone.1 = f64[] constant(0)
|
||||
%broadcast.0.clone.1 = f64[64,64]{1,0} broadcast(f64[] %constant_5217.1.clone.1), dimensions={}
|
||||
%multiply.1341.clone.0.clone.1 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %multiply.1514.clone.0.clone.1, f64[64,64]{1,0} %broadcast.0.clone.1)
|
||||
%reduce.630.clone.0.clone.1 = f64[64]{0} reduce(f64[64,64]{1,0} %multiply.1341.clone.0.clone.1, f64[] %constant_5217.1.clone.1), dimensions={0}, to_apply=%scalar_add_computation.1
|
||||
ROOT %tuple = (f64[64]{0}, f64[64]{0}) tuple(f64[64]{0} %reduce.0, f64[64]{0} %reduce.630.clone.0.clone.1)
|
||||
}
|
||||
|
||||
%primitive_computation_add__1.6426 (parameter.6427: f64[], parameter.6428: f64[]) -> f64[] {
|
||||
%parameter.6427 = f64[] parameter(0)
|
||||
%parameter.6428 = f64[] parameter(1)
|
||||
ROOT %add.6429 = f64[] add(f64[] %parameter.6427, f64[] %parameter.6428)
|
||||
}
|
||||
|
||||
%fused_computation.2 (param_0.7: f64[64,64], param_1.9: f64[64,64]) -> f64[64] {
|
||||
%param_0.7 = f64[64,64]{1,0} parameter(0)
|
||||
%param_1.9 = f64[64,64]{1,0} parameter(1)
|
||||
%multiply.1 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %param_0.7, f64[64,64]{1,0} %param_1.9)
|
||||
%constant_5217.2 = f64[] constant(0)
|
||||
ROOT %reduce.740.clone.0 = f64[64]{0} reduce(f64[64,64]{1,0} %multiply.1, f64[] %constant_5217.2), dimensions={0}, to_apply=%primitive_computation_add__1.6426
|
||||
}
|
||||
|
||||
ENTRY %reproducer (param_0.1090: f64[64,64], param_1.1377: f64[64,64], param_2.1948: f64[64,64]) -> (f64[64], f64[64], f64[64]) {
|
||||
%param_0.1090 = f64[64,64]{1,0} parameter(0)
|
||||
%param_1.1377 = f64[64,64]{1,0} parameter(1)
|
||||
%param_2.1948 = f64[64,64]{1,0} parameter(2)
|
||||
%fusion.1 = (f64[64]{0}, f64[64]{0}) fusion(f64[64,64]{1,0} %param_0.1090, f64[64,64]{1,0} %param_1.1377, f64[64,64]{1,0} %param_2.1948), kind=kInput, calls=%fused_computation.1
|
||||
%get-tuple-element = f64[64]{0} get-tuple-element((f64[64]{0}, f64[64]{0}) %fusion.1), index=0
|
||||
%fusion.2 = f64[64]{0} fusion(f64[64,64]{1,0} %param_0.1090, f64[64,64]{1,0} %param_1.1377), kind=kInput, calls=%fused_computation.2
|
||||
%get-tuple-element.1 = f64[64]{0} get-tuple-element((f64[64]{0}, f64[64]{0}) %fusion.1), index=1
|
||||
ROOT %tuple.428 = (f64[64]{0}, f64[64]{0}, f64[64]{0}) tuple(f64[64]{0} %get-tuple-element, f64[64]{0} %fusion.2, f64[64]{0} %get-tuple-element.1)
|
||||
}
|
||||
)")
|
||||
.ValueOrDie();
|
||||
// The pass result must be false; presumably this means "module unchanged",
// i.e. %fusion.2 was not merged into %fusion.1 - TODO confirm Run() semantics.
EXPECT_FALSE(GpuMultiOutputFusion().Run(module.get()).ConsumeValueOrDie());
|
||||
}
|
||||
|
||||
TEST_F(MultiOutputFusionTest, NoFusionToAvoidCodeDuplication) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user