From 8f700fb2e0da382f1e2e9630f56a7922a8799a59 Mon Sep 17 00:00:00 2001 From: Berkin Ilbeyi Date: Thu, 18 Jun 2020 13:39:10 -0700 Subject: [PATCH] [XLA] Propagate memory spaces recursively inside nested fusions. PiperOrigin-RevId: 317171110 Change-Id: I65004edb7498acb2f3b4238d9afbbb5d3930aab5 --- .../xla/service/memory_space_propagation.cc | 80 +++++++--- .../xla/service/memory_space_propagation.h | 11 +- .../service/memory_space_propagation_test.cc | 148 ++++++++++++++++++ 3 files changed, 214 insertions(+), 25 deletions(-) diff --git a/tensorflow/compiler/xla/service/memory_space_propagation.cc b/tensorflow/compiler/xla/service/memory_space_propagation.cc index 80eb4017477..2eb15b14eaf 100644 --- a/tensorflow/compiler/xla/service/memory_space_propagation.cc +++ b/tensorflow/compiler/xla/service/memory_space_propagation.cc @@ -29,36 +29,78 @@ StatusOr MemorySpacePropagation::Run(HloModule* module) { // Propagate the operand subshapes. for (int operand_idx = 0; operand_idx < instruction->operand_count(); ++operand_idx) { - modified |= - PropagateSubshapes(instruction->operand(operand_idx)->shape(), - instruction->fused_parameter(operand_idx)); + for (const ShapeUtil::IndexedShape& indexed_shape : + ShapeUtil::GetLeafShapes( + instruction->operand(operand_idx)->shape())) { + int64 memory_space = indexed_shape.shape.layout().memory_space(); + modified |= Propagate(indexed_shape.index, + instruction->fused_parameter(operand_idx), + memory_space); + } } // Propagate output subshapes. 
- modified |= PropagateSubshapes(instruction->shape(), - instruction->fused_expression_root()); + for (const ShapeUtil::IndexedShape& indexed_shape : + ShapeUtil::GetLeafShapes(instruction->shape())) { + int64 memory_space = indexed_shape.shape.layout().memory_space(); + modified |= + Propagate(indexed_shape.index, + instruction->fused_expression_root(), memory_space); + } } } } return modified; } -bool MemorySpacePropagation::PropagateSubshapes( - const Shape& caller_shape, const HloInstruction* callee_instruction) const { +bool MemorySpacePropagation::Propagate(ShapeIndexView index, + const HloInstruction* callee_instruction, + int64 memory_space) const { bool modified = false; - for (const ShapeUtil::IndexedShape& indexed_shape : - ShapeUtil::GetLeafShapes(caller_shape)) { - int64 memory_space = indexed_shape.shape.layout().memory_space(); - const HloValue& value = dataflow_analysis_->GetUniqueValueAt( - callee_instruction, indexed_shape.index); + const HloValue& value = dataflow_analysis_->GetUniqueValueAt( + callee_instruction, index.ToShapeIndex()); - for (const HloPosition& position : value.positions()) { - Shape* shape = ShapeUtil::GetMutableSubshape( - position.instruction->mutable_shape(), position.index); - if (shape->layout().memory_space() != memory_space) { - shape->mutable_layout()->set_memory_space(memory_space); - modified = true; - } + for (const HloPosition& position : value.positions()) { + HloInstruction* instruction = position.instruction; + Shape* shape = ShapeUtil::GetMutableSubshape(instruction->mutable_shape(), + position.index); + if (shape->layout().memory_space() == memory_space) { + continue; + } + shape->mutable_layout()->set_memory_space(memory_space); + modified = true; + + // For fusion outputs, propagate the memory space to the fusion root. 
+ if (instruction->opcode() == HloOpcode::kFusion) { + Propagate(position.index, instruction->fused_expression_root(), + memory_space); + } + + const HloInstruction* parent_fusion = + instruction->parent()->FusionInstruction(); + // For nested fusion roots, pop one level up and propagate the memory space + // to the output of the calling fusion instruction. + if (instruction == instruction->parent()->root_instruction() && + parent_fusion->parent()->IsFusionComputation()) { + Propagate(position.index, parent_fusion, memory_space); + } + + // For nested fusion parameters, pop one level up and propagate the memory + // space to the operand of the calling fusion instruction. + if (instruction->opcode() == HloOpcode::kParameter && + parent_fusion->parent()->IsFusionComputation()) { + const HloInstruction* fusion_operand = + parent_fusion->operand(instruction->parameter_number()); + Propagate(position.index, fusion_operand, memory_space); + } + } + + for (const HloUse& use : value.uses()) { + // For fusion uses, propagate the memory space to the fusion parameter. + if (use.instruction->opcode() == HloOpcode::kFusion) { + modified |= Propagate( + use.operand_index, + use.instruction->fused_parameter(use.operand_number), memory_space); } } return modified; diff --git a/tensorflow/compiler/xla/service/memory_space_propagation.h b/tensorflow/compiler/xla/service/memory_space_propagation.h index 65a1dfd14a6..510e9e69f79 100644 --- a/tensorflow/compiler/xla/service/memory_space_propagation.h +++ b/tensorflow/compiler/xla/service/memory_space_propagation.h @@ -31,12 +31,11 @@ class MemorySpacePropagation : public HloModulePass { StatusOr<bool> Run(HloModule* module) override; private: - // Given the caller shape (operand or output) and its corresponding - // insturction in the fused computation (parameter or root), propagates the - // memory space to all the subshapes in the callee side. Returns true if the - // module is modified.
- bool PropagateSubshapes(const Shape& caller_shape, - const HloInstruction* callee_instruction) const; + // Given the shape index (operand or output) and its corresponding instruction + // in the fused computation (parameter or root), propagates the memory space + // in the callee side. Returns true if the module is modified. + bool Propagate(ShapeIndexView index, const HloInstruction* callee_instruction, + int64 memory_space) const; std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_; }; diff --git a/tensorflow/compiler/xla/service/memory_space_propagation_test.cc b/tensorflow/compiler/xla/service/memory_space_propagation_test.cc index 8d74958f6aa..de45af5a190 100644 --- a/tensorflow/compiler/xla/service/memory_space_propagation_test.cc +++ b/tensorflow/compiler/xla/service/memory_space_propagation_test.cc @@ -199,5 +199,153 @@ TEST_F(MemorySpacePropagationTest, TupleOutput) { EXPECT_EQ(module->Hash(), ref->Hash()); } +TEST_F(MemorySpacePropagationTest, NestedInputFusion) { + // Tests propagating the memory space to nested fusions on the input side.
+ absl::string_view hlo_string = R"( + HloModule NestedFusion + + %bitcast_fusion { + %bf_param = s32[3,2]{0,1:T(128)} parameter(0) + ROOT %bitcast = s32[6]{0:T(128)} bitcast(%bf_param) + } + + %fused_computation { + %param_1.3 = s32[1]{0:T(128)} parameter(1) + %constant.2 = s32[]{:T(128)} constant(-2147483648) + %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5 + %param_2.3 = s32[5]{0:T(128)} parameter(2) + %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0 + %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3) + %param_0.1 = s32[3,2]{0,1:T(128)} parameter(0) + %fusion.1 = s32[6]{0:T(128)} fusion(%param_0.1), kind=kLoop, calls=bitcast_fusion + ROOT %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %fusion.1) + } + + ENTRY %entry { + %param0 = s32[3,2]{0,1:T(128)} parameter(0) + %param1 = s32[1]{0:T(128)} parameter(1) + %param2 = s32[5]{0:T(128)} parameter(2) + %arg0 = s32[3,2]{0,1:T(128)S(1)} copy(%param0) + %arg1 = s32[1]{0:T(128)} copy(%param1) + %arg2 = s32[5]{0:T(128)S(1)} copy(%param2) + %fusion = s32[6]{0:T(128)S(1)} fusion(s32[3,2]{0,1:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation + ROOT %root = s32[6]{0:T(128)} copy(%fusion) + } + )"; + absl::string_view expected_hlo_string = R"( + HloModule NestedFusion + + %bitcast_fusion { + %bf_param = s32[3,2]{0,1:T(128)S(1)} parameter(0) + ROOT %bitcast = s32[6]{0:T(128)S(1)} bitcast(%bf_param) + } + + %fused_computation { + %param_1.3 = s32[1]{0:T(128)} parameter(1) + %constant.2 = s32[]{:T(128)} constant(-2147483648) + %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5 + %param_2.3 = s32[5]{0:T(128)S(1)} parameter(2) + %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0 + %maximum.1 = s32[6]{0:T(128)} 
maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3) + %param_0.1 = s32[3,2]{0,1:T(128)S(1)} parameter(0) + %fusion.1 = s32[6]{0:T(128)S(1)} fusion(%param_0.1), kind=kLoop, calls=bitcast_fusion + ROOT %add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)S(1)} %fusion.1) + } + + ENTRY %entry { + %param0 = s32[3,2]{0,1:T(128)} parameter(0) + %param1 = s32[1]{0:T(128)} parameter(1) + %param2 = s32[5]{0:T(128)} parameter(2) + %arg0 = s32[3,2]{0,1:T(128)S(1)} copy(%param0) + %arg1 = s32[1]{0:T(128)} copy(%param1) + %arg2 = s32[5]{0:T(128)S(1)} copy(%param2) + %fusion = s32[6]{0:T(128)S(1)} fusion(s32[3,2]{0,1:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation + ROOT %root = s32[6]{0:T(128)} copy(%fusion) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + MemorySpacePropagation memory_space_propagation; + EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie()); + TF_EXPECT_OK(Verify(module.get())); + TF_ASSERT_OK_AND_ASSIGN(auto ref, + ParseAndReturnVerifiedModule(expected_hlo_string)); + EXPECT_EQ(module->Hash(), ref->Hash()); +} + +TEST_F(MemorySpacePropagationTest, NestedOutputFusion) { + // Tests propagating the memory space to nested fusions on the output side. 
+ absl::string_view hlo_string = R"( + HloModule NestedFusion + + %bitcast_fusion { + %bf_param = s32[6]{0:T(128)} parameter(0) + ROOT %bitcast = s32[3,2]{0,1:T(128)} bitcast(%bf_param) + } + + %fused_computation { + %param_1.3 = s32[1]{0:T(128)} parameter(1) + %constant.2 = s32[]{:T(128)} constant(-2147483648) + %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5 + %param_2.3 = s32[5]{0:T(128)} parameter(2) + %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0 + %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3) + %param_0.1 = s32[6]{0:T(128)} parameter(0) + %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1) + ROOT %fusion.1 = s32[3,2]{0,1:T(128)} fusion(%add.0), kind=kLoop, calls=bitcast_fusion + } + + ENTRY %entry { + %param0 = s32[6]{0:T(128)} parameter(0) + %param1 = s32[1]{0:T(128)} parameter(1) + %param2 = s32[5]{0:T(128)} parameter(2) + %arg0 = s32[6]{0:T(128)S(1)} copy(%param0) + %arg1 = s32[1]{0:T(128)} copy(%param1) + %arg2 = s32[5]{0:T(128)S(1)} copy(%param2) + %fusion = s32[3,2]{0,1:T(128)S(1)} fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation + ROOT %root = s32[3,2]{0,1:T(128)} copy(%fusion) + } + )"; + absl::string_view expected_hlo_string = R"( + HloModule NestedFusion + + %bitcast_fusion { + %bf_param = s32[6]{0:T(128)S(1)} parameter(0) + ROOT %bitcast = s32[3,2]{0,1:T(128)S(1)} bitcast(%bf_param) + } + + %fused_computation { + %param_1.3 = s32[1]{0:T(128)} parameter(1) + %constant.2 = s32[]{:T(128)} constant(-2147483648) + %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5 + %param_2.3 = s32[5]{0:T(128)S(1)} parameter(2) + %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0 + %maximum.1 = s32[6]{0:T(128)} 
maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3) + %param_0.1 = s32[6]{0:T(128)S(1)} parameter(0) + %add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)S(1)} %param_0.1) + ROOT %fusion.1 = s32[3,2]{0,1:T(128)S(1)} fusion(%add.0), kind=kLoop, calls=bitcast_fusion + } + + ENTRY %entry { + %param0 = s32[6]{0:T(128)} parameter(0) + %param1 = s32[1]{0:T(128)} parameter(1) + %param2 = s32[5]{0:T(128)} parameter(2) + %arg0 = s32[6]{0:T(128)S(1)} copy(%param0) + %arg1 = s32[1]{0:T(128)} copy(%param1) + %arg2 = s32[5]{0:T(128)S(1)} copy(%param2) + %fusion = s32[3,2]{0,1:T(128)S(1)} fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation + ROOT %root = s32[3,2]{0,1:T(128)} copy(%fusion) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + MemorySpacePropagation memory_space_propagation; + EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie()); + TF_EXPECT_OK(Verify(module.get())); + TF_ASSERT_OK_AND_ASSIGN(auto ref, + ParseAndReturnVerifiedModule(expected_hlo_string)); + EXPECT_EQ(module->Hash(), ref->Hash()); +} + } // namespace } // namespace xla