From ee4736e5bf74251f7b871351a9e430420a530f6c Mon Sep 17 00:00:00 2001 From: Yunxing Dai Date: Fri, 28 Aug 2020 15:37:35 -0700 Subject: [PATCH] [XLA] Support hoist copy in code motion. PiperOrigin-RevId: 329021839 Change-Id: I699c547f462466508b90289468f46e400008b4c6 --- tensorflow/compiler/xla/service/BUILD | 1 + .../xla/service/conditional_code_motion.cc | 17 +++++- .../service/conditional_code_motion_test.cc | 60 +++++++++++++++++++ tensorflow/compiler/xla/shape_tree.h | 4 +- 4 files changed, 80 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index dd16bd32dd1..a1d6959ed37 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -2343,6 +2343,7 @@ cc_library( ":hlo_dce", ":hlo_pass", ":hlo_pass_pipeline", + ":hlo_verifier", ":tuple_simplifier", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/xla/service/conditional_code_motion.cc b/tensorflow/compiler/xla/service/conditional_code_motion.cc index ce80b4cfc15..42caf20ff80 100644 --- a/tensorflow/compiler/xla/service/conditional_code_motion.cc +++ b/tensorflow/compiler/xla/service/conditional_code_motion.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -114,6 +115,8 @@ int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) { case HloOpcode::kConstant: case HloOpcode::kGetTupleElement: return 0; + case HloOpcode::kConditional: + return 10; default: // Assume fusion will not happen anyway if user count > 1) if (op->user_count() > 1) { @@ -582,6 +585,7 @@ StatusOr ConditionalCodeMotion::MoveInstructionIn( // to replace the conditional directly in the new computation. b_opd_use.mutable_operands().push_back(conditional); } + HloInstruction* new_root = computation->AddInstruction(HloInstruction::CreateTuple(operands)); VLOG(2) << "setting new root: " << new_root->ToString() << "\n"; @@ -592,6 +596,15 @@ StatusOr ConditionalCodeMotion::MoveInstructionIn( } VLOG(2) << "new branch computation: " << computation->ToString() << "\n"; } + // Update get tuple element index of the conditional. + if (use_index != -1) { + for (auto* user : conditional->users()) { + if (user->opcode() == HloOpcode::kGetTupleElement && + user->tuple_index() > use_index) { + user->set_tuple_index(user->tuple_index() - 1); + } + } + } hoisted_instructions[conditional] = b_old_root; int64 cp_start = 0; if (use_index >= 0) { @@ -677,7 +690,7 @@ class GroupConnectedBoundaries { : conditional_(conditional), conditional_parent_(conditional->parent()), is_layout_sensitive_(is_layout_sensitive) {} - // Returns true if `instruction` is worth hoisting out. + // Returns true if `instruction` is worth hoisting. bool WorthHoisting(HloInstruction* instruction) { // This is needed for the "moving-in" transformation, to prevent the root // of the parent computation (which contains the conditional) to be moved @@ -708,6 +721,7 @@ class GroupConnectedBoundaries { case HloOpcode::kAllReduce: case HloOpcode::kAdd: case HloOpcode::kPower: + case HloOpcode::kCopy: case HloOpcode::kConstant: case HloOpcode::kSubtract: case HloOpcode::kMultiply: @@ -1070,6 +1084,7 @@ StatusOr ConditionalCodeMotion::Run(HloModule* module) { subpipeline.AddPass(); subpipeline.AddPass(); subpipeline.AddPass(); + subpipeline.AddPass(false, true); TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module)); changed |= cleanup_changed; } diff --git a/tensorflow/compiler/xla/service/conditional_code_motion_test.cc b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc index b91f3813980..e5e3873cc66 100644 --- a/tensorflow/compiler/xla/service/conditional_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc @@ -728,6 +728,66 @@ ENTRY main { EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional()))); } +TEST_F(ConditionalCodeMotionTest, MoveCopyInBranch) { + absl::string_view hlo_string = + R"( +HloModule RemoveIdenticalInstruction + +branch1 { + arg_tuple.1 = (s32[], f32[10,3]{0,1}) parameter(0) + constant.1 = s32[] constant(4) + get-tuple-element.1 = s32[] get-tuple-element(arg_tuple.1), index=0 + add.1 = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = f32[10,3]{0,1} get-tuple-element(arg_tuple.1), index=1 + slice.1 = f32[4,3]{0,1} slice(get-tuple-element.2), + slice={[0:4:1], [0:3:1]} + constant.2 = f32[] constant(0.0) + ROOT tuple.1 = (f32[4,3]{0,1}, s32[],f32[]) tuple(slice.1, add.1, constant.2) +} + +branch2 { + arg_tuple.2 = (s32[], f32[4,3]{1,0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(arg_tuple.2), index=0 + copy.1 = s32[] copy(get-tuple-element.3) + get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element(arg_tuple.2), index=1 + copy.2 = f32[4,3]{0,1} copy(get-tuple-element.4) + constant.2 = f32[] constant(0.0) + ROOT tuple.2 = (f32[4,3]{0,1}, s32[], f32[]) tuple(copy.2, copy.1, constant.2) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + tuple.3 = (s32[], f32[10,3]{0,1}) parameter(1) + tuple.4 = (s32[], f32[4,3]{1,0}) parameter(2) + conditional = (f32[4,3]{0,1}, s32[], f32[]) + conditional(pred.1, tuple.3, tuple.4), true_computation=branch1, + false_computation=branch2 + get-zero-index = f32[4,3]{0,1} get-tuple-element(conditional), index=0 + get-first-index = s32[] get-tuple-element(conditional), index=1 + get-second-index = f32[] get-tuple-element(conditional), index=2 + copy.3 = f32[4,3]{1,0} copy(get-zero-index) + ROOT tuple.5 = (f32[4,3]{0,1}, s32[], f32[]) tuple(copy.3, get-first-index, + get-second-index) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + VLOG(1) << module->ToString(); + + const HloInstruction* conditional = + FindInstruction(module.get(), "conditional"); + const HloComputation* on_true = conditional->branch_computation(0); + ASSERT_EQ(on_true->instruction_count(), 9); + const HloComputation* on_false = conditional->branch_computation(1); + ASSERT_EQ(on_false->instruction_count(), 8); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + AllOf(op::Tuple(op::GetTupleElement(op::Conditional(), 2), + op::GetTupleElement(op::Conditional(), 0), + op::GetTupleElement(op::Conditional(), 1)))); +} + } // namespace conditional_opt } // namespace xla diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index 73bb3327784..bc48a9c94d1 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -648,7 +648,9 @@ void ShapeTree::CopySubtreeFrom(const ShapeTree& other, const ShapeIndex& target_base_index) { CHECK(ShapeUtil::Compatible( ShapeUtil::GetSubshape(shape(), target_base_index), - ShapeUtil::GetSubshape(other.shape(), source_base_index))); + ShapeUtil::GetSubshape(other.shape(), source_base_index))) + << ShapeUtil::GetSubshape(shape(), target_base_index) << " vs " + << ShapeUtil::GetSubshape(other.shape(), source_base_index); ForEachMutableElement([this, &other, &source_base_index, &target_base_index]( const ShapeIndex& index, T* data) { // Copy the data element only if index is in the