[XLA] Propagate memory spaces recursively inside nested fusions.

PiperOrigin-RevId: 317171110
Change-Id: I65004edb7498acb2f3b4238d9afbbb5d3930aab5
This commit is contained in:
Berkin Ilbeyi 2020-06-18 13:39:10 -07:00 committed by TensorFlower Gardener
parent e0962f4c37
commit 8f700fb2e0
3 changed files with 214 additions and 25 deletions

View File

@@ -29,36 +29,78 @@ StatusOr<bool> MemorySpacePropagation::Run(HloModule* module) {
// Propagate the operand subshapes.
for (int operand_idx = 0; operand_idx < instruction->operand_count();
++operand_idx) {
modified |=
PropagateSubshapes(instruction->operand(operand_idx)->shape(),
instruction->fused_parameter(operand_idx));
for (const ShapeUtil::IndexedShape& indexed_shape :
ShapeUtil::GetLeafShapes(
instruction->operand(operand_idx)->shape())) {
int64 memory_space = indexed_shape.shape.layout().memory_space();
modified |= Propagate(indexed_shape.index,
instruction->fused_parameter(operand_idx),
memory_space);
}
}
// Propagate output subshapes.
modified |= PropagateSubshapes(instruction->shape(),
instruction->fused_expression_root());
for (const ShapeUtil::IndexedShape& indexed_shape :
ShapeUtil::GetLeafShapes(instruction->shape())) {
int64 memory_space = indexed_shape.shape.layout().memory_space();
modified |=
Propagate(indexed_shape.index,
instruction->fused_expression_root(), memory_space);
}
}
}
}
return modified;
}
bool MemorySpacePropagation::PropagateSubshapes(
const Shape& caller_shape, const HloInstruction* callee_instruction) const {
bool MemorySpacePropagation::Propagate(ShapeIndexView index,
const HloInstruction* callee_instruction,
int64 memory_space) const {
bool modified = false;
for (const ShapeUtil::IndexedShape& indexed_shape :
ShapeUtil::GetLeafShapes(caller_shape)) {
int64 memory_space = indexed_shape.shape.layout().memory_space();
const HloValue& value = dataflow_analysis_->GetUniqueValueAt(
callee_instruction, indexed_shape.index);
const HloValue& value = dataflow_analysis_->GetUniqueValueAt(
callee_instruction, index.ToShapeIndex());
for (const HloPosition& position : value.positions()) {
Shape* shape = ShapeUtil::GetMutableSubshape(
position.instruction->mutable_shape(), position.index);
if (shape->layout().memory_space() != memory_space) {
shape->mutable_layout()->set_memory_space(memory_space);
modified = true;
}
for (const HloPosition& position : value.positions()) {
HloInstruction* instruction = position.instruction;
Shape* shape = ShapeUtil::GetMutableSubshape(instruction->mutable_shape(),
position.index);
if (shape->layout().memory_space() == memory_space) {
continue;
}
shape->mutable_layout()->set_memory_space(memory_space);
modified = true;
// For fusion outputs, propagate the memory space to the fusion root.
if (instruction->opcode() == HloOpcode::kFusion) {
Propagate(position.index, instruction->fused_expression_root(),
memory_space);
}
const HloInstruction* parent_fusion =
instruction->parent()->FusionInstruction();
// For nested fusion roots, pop one level up and propagate the memory space
// to the output of the calling fusion instruction.
if (instruction == instruction->parent()->root_instruction() &&
parent_fusion->parent()->IsFusionComputation()) {
Propagate(position.index, parent_fusion, memory_space);
}
// For nested fusion parameters, pop one level up and propagate the memory
// space to the operand of the calling fusion instruction.
if (instruction->opcode() == HloOpcode::kParameter &&
parent_fusion->parent()->IsFusionComputation()) {
const HloInstruction* fusion_operand =
parent_fusion->operand(instruction->parameter_number());
Propagate(position.index, fusion_operand, memory_space);
}
}
for (const HloUse& use : value.uses()) {
// For fusion uses, propagate the memory space to the fusion parameter.
if (use.instruction->opcode() == HloOpcode::kFusion) {
modified |= Propagate(
use.operand_index,
use.instruction->fused_parameter(use.operand_number), memory_space);
}
}
return modified;

View File

@@ -31,12 +31,11 @@ class MemorySpacePropagation : public HloModulePass {
StatusOr<bool> Run(HloModule* module) override;

 private:
  // Given the shape index (operand or output) and its corresponding instruction
  // in the fused computation (parameter or root), propagates the memory space
  // in the callee side. Returns true if the module is modified.
  bool Propagate(ShapeIndexView index, const HloInstruction* callee_instruction,
                 int64 memory_space) const;

  // Dataflow analysis over the module being processed; used to find every
  // position and use aliased with a propagated value.
  std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
};

View File

@@ -199,5 +199,153 @@ TEST_F(MemorySpacePropagationTest, TupleOutput) {
EXPECT_EQ(module->Hash(), ref->Hash());
}
TEST_F(MemorySpacePropagationTest, NestedInputFusion) {
// Tests propagating the memory space to nested fusions on the input side.
// Input module: %fusion's operands 0 and 2 carry the S(1) layout annotation,
// and inside %fused_computation operand 0 feeds %fusion.1, a nested fusion
// calling %bitcast_fusion.
absl::string_view hlo_string = R"(
HloModule NestedFusion
%bitcast_fusion {
%bf_param = s32[3,2]{0,1:T(128)} parameter(0)
ROOT %bitcast = s32[6]{0:T(128)} bitcast(%bf_param)
}
%fused_computation {
%param_1.3 = s32[1]{0:T(128)} parameter(1)
%constant.2 = s32[]{:T(128)} constant(-2147483648)
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
%param_2.3 = s32[5]{0:T(128)} parameter(2)
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
%param_0.1 = s32[3,2]{0,1:T(128)} parameter(0)
%fusion.1 = s32[6]{0:T(128)} fusion(%param_0.1), kind=kLoop, calls=bitcast_fusion
ROOT %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %fusion.1)
}
ENTRY %entry {
%param0 = s32[3,2]{0,1:T(128)} parameter(0)
%param1 = s32[1]{0:T(128)} parameter(1)
%param2 = s32[5]{0:T(128)} parameter(2)
%arg0 = s32[3,2]{0,1:T(128)S(1)} copy(%param0)
%arg1 = s32[1]{0:T(128)} copy(%param1)
%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
%fusion = s32[6]{0:T(128)S(1)} fusion(s32[3,2]{0,1:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
ROOT %root = s32[6]{0:T(128)} copy(%fusion)
}
)";
// Expected module: identical to the input except that the S(1) annotation has
// been pushed through %param_0.1 and %param_2.3 in %fused_computation, into
// the nested %bitcast_fusion computation (%bf_param, %bitcast), and onto
// %fusion.1 and the fusion-output operand of %add.0.
absl::string_view expected_hlo_string = R"(
HloModule NestedFusion
%bitcast_fusion {
%bf_param = s32[3,2]{0,1:T(128)S(1)} parameter(0)
ROOT %bitcast = s32[6]{0:T(128)S(1)} bitcast(%bf_param)
}
%fused_computation {
%param_1.3 = s32[1]{0:T(128)} parameter(1)
%constant.2 = s32[]{:T(128)} constant(-2147483648)
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
%param_2.3 = s32[5]{0:T(128)S(1)} parameter(2)
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)S(1)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
%param_0.1 = s32[3,2]{0,1:T(128)S(1)} parameter(0)
%fusion.1 = s32[6]{0:T(128)S(1)} fusion(%param_0.1), kind=kLoop, calls=bitcast_fusion
ROOT %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)S(1)} %fusion.1)
}
ENTRY %entry {
%param0 = s32[3,2]{0,1:T(128)} parameter(0)
%param1 = s32[1]{0:T(128)} parameter(1)
%param2 = s32[5]{0:T(128)} parameter(2)
%arg0 = s32[3,2]{0,1:T(128)S(1)} copy(%param0)
%arg1 = s32[1]{0:T(128)} copy(%param1)
%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
%fusion = s32[6]{0:T(128)S(1)} fusion(s32[3,2]{0,1:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
ROOT %root = s32[6]{0:T(128)} copy(%fusion)
}
)";
// The input is parsed unverified because its memory spaces are deliberately
// inconsistent across the fusion boundaries before the pass runs.
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));
MemorySpacePropagation memory_space_propagation;
// Run() must report that it modified the module.
EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie());
TF_EXPECT_OK(Verify(module.get()));
TF_ASSERT_OK_AND_ASSIGN(auto ref,
ParseAndReturnVerifiedModule(expected_hlo_string));
// Compare the propagated module against the expected one via module hashes.
EXPECT_EQ(module->Hash(), ref->Hash());
}
TEST_F(MemorySpacePropagationTest, NestedOutputFusion) {
// Tests propagating the memory space to nested fusions on the output side.
// Input module: %fusion's output carries the S(1) layout annotation, and the
// root of %fused_computation is %fusion.1, a nested fusion calling
// %bitcast_fusion.
absl::string_view hlo_string = R"(
HloModule NestedFusion
%bitcast_fusion {
%bf_param = s32[6]{0:T(128)} parameter(0)
ROOT %bitcast = s32[3,2]{0,1:T(128)} bitcast(%bf_param)
}
%fused_computation {
%param_1.3 = s32[1]{0:T(128)} parameter(1)
%constant.2 = s32[]{:T(128)} constant(-2147483648)
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
%param_2.3 = s32[5]{0:T(128)} parameter(2)
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
%param_0.1 = s32[6]{0:T(128)} parameter(0)
%add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1)
ROOT %fusion.1 = s32[3,2]{0,1:T(128)} fusion(%add.0), kind=kLoop, calls=bitcast_fusion
}
ENTRY %entry {
%param0 = s32[6]{0:T(128)} parameter(0)
%param1 = s32[1]{0:T(128)} parameter(1)
%param2 = s32[5]{0:T(128)} parameter(2)
%arg0 = s32[6]{0:T(128)S(1)} copy(%param0)
%arg1 = s32[1]{0:T(128)} copy(%param1)
%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
%fusion = s32[3,2]{0,1:T(128)S(1)} fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
ROOT %root = s32[3,2]{0,1:T(128)} copy(%fusion)
}
)";
// Expected module: identical to the input except that the S(1) annotation of
// the fusion output has been pushed onto the nested root %fusion.1, into
// %bitcast_fusion (%bf_param, %bitcast), and through %add.0 / %param_0.1 /
// %param_2.3 on the operand side.
absl::string_view expected_hlo_string = R"(
HloModule NestedFusion
%bitcast_fusion {
%bf_param = s32[6]{0:T(128)S(1)} parameter(0)
ROOT %bitcast = s32[3,2]{0,1:T(128)S(1)} bitcast(%bf_param)
}
%fused_computation {
%param_1.3 = s32[1]{0:T(128)} parameter(1)
%constant.2 = s32[]{:T(128)} constant(-2147483648)
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
%param_2.3 = s32[5]{0:T(128)S(1)} parameter(2)
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)S(1)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
%param_0.1 = s32[6]{0:T(128)S(1)} parameter(0)
%add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)S(1)} %param_0.1)
ROOT %fusion.1 = s32[3,2]{0,1:T(128)S(1)} fusion(%add.0), kind=kLoop, calls=bitcast_fusion
}
ENTRY %entry {
%param0 = s32[6]{0:T(128)} parameter(0)
%param1 = s32[1]{0:T(128)} parameter(1)
%param2 = s32[5]{0:T(128)} parameter(2)
%arg0 = s32[6]{0:T(128)S(1)} copy(%param0)
%arg1 = s32[1]{0:T(128)} copy(%param1)
%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
%fusion = s32[3,2]{0,1:T(128)S(1)} fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
ROOT %root = s32[3,2]{0,1:T(128)} copy(%fusion)
}
)";
// The input is parsed unverified because its memory spaces are deliberately
// inconsistent across the fusion boundaries before the pass runs.
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));
MemorySpacePropagation memory_space_propagation;
// Run() must report that it modified the module.
EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie());
TF_EXPECT_OK(Verify(module.get()));
TF_ASSERT_OK_AND_ASSIGN(auto ref,
ParseAndReturnVerifiedModule(expected_hlo_string));
// Compare the propagated module against the expected one via module hashes.
EXPECT_EQ(module->Hash(), ref->Hash());
}
} // namespace
} // namespace xla