[XLA] Propagate memory spaces recursively inside nested fusions.
PiperOrigin-RevId: 317171110 Change-Id: I65004edb7498acb2f3b4238d9afbbb5d3930aab5
This commit is contained in:
parent
e0962f4c37
commit
8f700fb2e0
@ -29,36 +29,78 @@ StatusOr<bool> MemorySpacePropagation::Run(HloModule* module) {
|
||||
// Propagate the operand subshapes.
|
||||
for (int operand_idx = 0; operand_idx < instruction->operand_count();
|
||||
++operand_idx) {
|
||||
modified |=
|
||||
PropagateSubshapes(instruction->operand(operand_idx)->shape(),
|
||||
instruction->fused_parameter(operand_idx));
|
||||
for (const ShapeUtil::IndexedShape& indexed_shape :
|
||||
ShapeUtil::GetLeafShapes(
|
||||
instruction->operand(operand_idx)->shape())) {
|
||||
int64 memory_space = indexed_shape.shape.layout().memory_space();
|
||||
modified |= Propagate(indexed_shape.index,
|
||||
instruction->fused_parameter(operand_idx),
|
||||
memory_space);
|
||||
}
|
||||
}
|
||||
|
||||
// Propagate output subshapes.
|
||||
modified |= PropagateSubshapes(instruction->shape(),
|
||||
instruction->fused_expression_root());
|
||||
for (const ShapeUtil::IndexedShape& indexed_shape :
|
||||
ShapeUtil::GetLeafShapes(instruction->shape())) {
|
||||
int64 memory_space = indexed_shape.shape.layout().memory_space();
|
||||
modified |=
|
||||
Propagate(indexed_shape.index,
|
||||
instruction->fused_expression_root(), memory_space);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return modified;
|
||||
}
|
||||
|
||||
bool MemorySpacePropagation::PropagateSubshapes(
|
||||
const Shape& caller_shape, const HloInstruction* callee_instruction) const {
|
||||
bool MemorySpacePropagation::Propagate(ShapeIndexView index,
|
||||
const HloInstruction* callee_instruction,
|
||||
int64 memory_space) const {
|
||||
bool modified = false;
|
||||
for (const ShapeUtil::IndexedShape& indexed_shape :
|
||||
ShapeUtil::GetLeafShapes(caller_shape)) {
|
||||
int64 memory_space = indexed_shape.shape.layout().memory_space();
|
||||
const HloValue& value = dataflow_analysis_->GetUniqueValueAt(
|
||||
callee_instruction, indexed_shape.index);
|
||||
const HloValue& value = dataflow_analysis_->GetUniqueValueAt(
|
||||
callee_instruction, index.ToShapeIndex());
|
||||
|
||||
for (const HloPosition& position : value.positions()) {
|
||||
Shape* shape = ShapeUtil::GetMutableSubshape(
|
||||
position.instruction->mutable_shape(), position.index);
|
||||
if (shape->layout().memory_space() != memory_space) {
|
||||
shape->mutable_layout()->set_memory_space(memory_space);
|
||||
modified = true;
|
||||
}
|
||||
for (const HloPosition& position : value.positions()) {
|
||||
HloInstruction* instruction = position.instruction;
|
||||
Shape* shape = ShapeUtil::GetMutableSubshape(instruction->mutable_shape(),
|
||||
position.index);
|
||||
if (shape->layout().memory_space() == memory_space) {
|
||||
continue;
|
||||
}
|
||||
shape->mutable_layout()->set_memory_space(memory_space);
|
||||
modified = true;
|
||||
|
||||
// For fusion outputs, propagate the memory space to the fusion root.
|
||||
if (instruction->opcode() == HloOpcode::kFusion) {
|
||||
Propagate(position.index, instruction->fused_expression_root(),
|
||||
memory_space);
|
||||
}
|
||||
|
||||
const HloInstruction* parent_fusion =
|
||||
instruction->parent()->FusionInstruction();
|
||||
// For nested fusion roots, pop one level up and propagate the memory space
|
||||
// to the output of the calling fusion instruction.
|
||||
if (instruction == instruction->parent()->root_instruction() &&
|
||||
parent_fusion->parent()->IsFusionComputation()) {
|
||||
Propagate(position.index, parent_fusion, memory_space);
|
||||
}
|
||||
|
||||
// For nested fusion parameters, pop one level up and propagate the memory
|
||||
// space to the operand of the calling fusion instruction.
|
||||
if (instruction->opcode() == HloOpcode::kParameter &&
|
||||
parent_fusion->parent()->IsFusionComputation()) {
|
||||
const HloInstruction* fusion_operand =
|
||||
parent_fusion->operand(instruction->parameter_number());
|
||||
Propagate(position.index, fusion_operand, memory_space);
|
||||
}
|
||||
}
|
||||
|
||||
for (const HloUse& use : value.uses()) {
|
||||
// For fusion uses, propagate the memory space to the fusion parameter.
|
||||
if (use.instruction->opcode() == HloOpcode::kFusion) {
|
||||
modified |= Propagate(
|
||||
use.operand_index,
|
||||
use.instruction->fused_parameter(use.operand_number), memory_space);
|
||||
}
|
||||
}
|
||||
return modified;
|
||||
|
@ -31,12 +31,11 @@ class MemorySpacePropagation : public HloModulePass {
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
|
||||
private:
|
||||
// Given the caller shape (operand or output) and its corresponding
|
||||
// instruction in the fused computation (parameter or root), propagates the
|
||||
// memory space to all the subshapes in the callee side. Returns true if the
|
||||
// module is modified.
|
||||
bool PropagateSubshapes(const Shape& caller_shape,
|
||||
const HloInstruction* callee_instruction) const;
|
||||
// Given the shape index (operand or output) and its corresponding instruction
|
||||
// in the fused computation (parameter or root), propagates the memory space
|
||||
// in the callee side. Returns true if the module is modified.
|
||||
bool Propagate(ShapeIndexView index, const HloInstruction* callee_instruction,
|
||||
int64 memory_space) const;
|
||||
|
||||
std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
|
||||
};
|
||||
|
@ -199,5 +199,153 @@ TEST_F(MemorySpacePropagationTest, TupleOutput) {
|
||||
EXPECT_EQ(module->Hash(), ref->Hash());
|
||||
}
|
||||
|
||||
TEST_F(MemorySpacePropagationTest, NestedInputFusion) {
|
||||
// Tests propagating the memory space to nested fusions on the input side.
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule NestedFusion
|
||||
|
||||
%bitcast_fusion {
|
||||
%bf_param = s32[3,2]{0,1:T(128)} parameter(0)
|
||||
ROOT %bitcast = s32[6]{0:T(128)} bitcast(%bf_param)
|
||||
}
|
||||
|
||||
%fused_computation {
|
||||
%param_1.3 = s32[1]{0:T(128)} parameter(1)
|
||||
%constant.2 = s32[]{:T(128)} constant(-2147483648)
|
||||
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
|
||||
%param_2.3 = s32[5]{0:T(128)} parameter(2)
|
||||
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
|
||||
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
|
||||
%param_0.1 = s32[3,2]{0,1:T(128)} parameter(0)
|
||||
%fusion.1 = s32[6]{0:T(128)} fusion(%param_0.1), kind=kLoop, calls=bitcast_fusion
|
||||
ROOT %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %fusion.1)
|
||||
}
|
||||
|
||||
ENTRY %entry {
|
||||
%param0 = s32[3,2]{0,1:T(128)} parameter(0)
|
||||
%param1 = s32[1]{0:T(128)} parameter(1)
|
||||
%param2 = s32[5]{0:T(128)} parameter(2)
|
||||
%arg0 = s32[3,2]{0,1:T(128)S(1)} copy(%param0)
|
||||
%arg1 = s32[1]{0:T(128)} copy(%param1)
|
||||
%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
|
||||
%fusion = s32[6]{0:T(128)S(1)} fusion(s32[3,2]{0,1:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
|
||||
ROOT %root = s32[6]{0:T(128)} copy(%fusion)
|
||||
}
|
||||
)";
|
||||
absl::string_view expected_hlo_string = R"(
|
||||
HloModule NestedFusion
|
||||
|
||||
%bitcast_fusion {
|
||||
%bf_param = s32[3,2]{0,1:T(128)S(1)} parameter(0)
|
||||
ROOT %bitcast = s32[6]{0:T(128)S(1)} bitcast(%bf_param)
|
||||
}
|
||||
|
||||
%fused_computation {
|
||||
%param_1.3 = s32[1]{0:T(128)} parameter(1)
|
||||
%constant.2 = s32[]{:T(128)} constant(-2147483648)
|
||||
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
|
||||
%param_2.3 = s32[5]{0:T(128)S(1)} parameter(2)
|
||||
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
|
||||
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
|
||||
%param_0.1 = s32[3,2]{0,1:T(128)S(1)} parameter(0)
|
||||
%fusion.1 = s32[6]{0:T(128)S(1)} fusion(%param_0.1), kind=kLoop, calls=bitcast_fusion
|
||||
ROOT %add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)S(1)} %fusion.1)
|
||||
}
|
||||
|
||||
ENTRY %entry {
|
||||
%param0 = s32[3,2]{0,1:T(128)} parameter(0)
|
||||
%param1 = s32[1]{0:T(128)} parameter(1)
|
||||
%param2 = s32[5]{0:T(128)} parameter(2)
|
||||
%arg0 = s32[3,2]{0,1:T(128)S(1)} copy(%param0)
|
||||
%arg1 = s32[1]{0:T(128)} copy(%param1)
|
||||
%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
|
||||
%fusion = s32[6]{0:T(128)S(1)} fusion(s32[3,2]{0,1:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
|
||||
ROOT %root = s32[6]{0:T(128)} copy(%fusion)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnUnverifiedModule(hlo_string));
|
||||
MemorySpacePropagation memory_space_propagation;
|
||||
EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie());
|
||||
TF_EXPECT_OK(Verify(module.get()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto ref,
|
||||
ParseAndReturnVerifiedModule(expected_hlo_string));
|
||||
EXPECT_EQ(module->Hash(), ref->Hash());
|
||||
}
|
||||
|
||||
TEST_F(MemorySpacePropagationTest, NestedOutputFusion) {
|
||||
// Tests propagating the memory space to nested fusions on the output side.
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule NestedFusion
|
||||
|
||||
%bitcast_fusion {
|
||||
%bf_param = s32[6]{0:T(128)} parameter(0)
|
||||
ROOT %bitcast = s32[3,2]{0,1:T(128)} bitcast(%bf_param)
|
||||
}
|
||||
|
||||
%fused_computation {
|
||||
%param_1.3 = s32[1]{0:T(128)} parameter(1)
|
||||
%constant.2 = s32[]{:T(128)} constant(-2147483648)
|
||||
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
|
||||
%param_2.3 = s32[5]{0:T(128)} parameter(2)
|
||||
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
|
||||
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
|
||||
%param_0.1 = s32[6]{0:T(128)} parameter(0)
|
||||
%add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1)
|
||||
ROOT %fusion.1 = s32[3,2]{0,1:T(128)} fusion(%add.0), kind=kLoop, calls=bitcast_fusion
|
||||
}
|
||||
|
||||
ENTRY %entry {
|
||||
%param0 = s32[6]{0:T(128)} parameter(0)
|
||||
%param1 = s32[1]{0:T(128)} parameter(1)
|
||||
%param2 = s32[5]{0:T(128)} parameter(2)
|
||||
%arg0 = s32[6]{0:T(128)S(1)} copy(%param0)
|
||||
%arg1 = s32[1]{0:T(128)} copy(%param1)
|
||||
%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
|
||||
%fusion = s32[3,2]{0,1:T(128)S(1)} fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
|
||||
ROOT %root = s32[3,2]{0,1:T(128)} copy(%fusion)
|
||||
}
|
||||
)";
|
||||
absl::string_view expected_hlo_string = R"(
|
||||
HloModule NestedFusion
|
||||
|
||||
%bitcast_fusion {
|
||||
%bf_param = s32[6]{0:T(128)S(1)} parameter(0)
|
||||
ROOT %bitcast = s32[3,2]{0,1:T(128)S(1)} bitcast(%bf_param)
|
||||
}
|
||||
|
||||
%fused_computation {
|
||||
%param_1.3 = s32[1]{0:T(128)} parameter(1)
|
||||
%constant.2 = s32[]{:T(128)} constant(-2147483648)
|
||||
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
|
||||
%param_2.3 = s32[5]{0:T(128)S(1)} parameter(2)
|
||||
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
|
||||
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
|
||||
%param_0.1 = s32[6]{0:T(128)S(1)} parameter(0)
|
||||
%add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)S(1)} %param_0.1)
|
||||
ROOT %fusion.1 = s32[3,2]{0,1:T(128)S(1)} fusion(%add.0), kind=kLoop, calls=bitcast_fusion
|
||||
}
|
||||
|
||||
ENTRY %entry {
|
||||
%param0 = s32[6]{0:T(128)} parameter(0)
|
||||
%param1 = s32[1]{0:T(128)} parameter(1)
|
||||
%param2 = s32[5]{0:T(128)} parameter(2)
|
||||
%arg0 = s32[6]{0:T(128)S(1)} copy(%param0)
|
||||
%arg1 = s32[1]{0:T(128)} copy(%param1)
|
||||
%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
|
||||
%fusion = s32[3,2]{0,1:T(128)S(1)} fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
|
||||
ROOT %root = s32[3,2]{0,1:T(128)} copy(%fusion)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnUnverifiedModule(hlo_string));
|
||||
MemorySpacePropagation memory_space_propagation;
|
||||
EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie());
|
||||
TF_EXPECT_OK(Verify(module.get()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto ref,
|
||||
ParseAndReturnVerifiedModule(expected_hlo_string));
|
||||
EXPECT_EQ(module->Hash(), ref->Hash());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
x
Reference in New Issue
Block a user