diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index fc57210359f..67a8292d6cc 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -4820,6 +4820,40 @@ tf_cc_test(
     ],
 )
 
+cc_library(
+    name = "while_loop_expensive_invariant_code_motion",
+    srcs = ["while_loop_expensive_invariant_code_motion.cc"],
+    hdrs = ["while_loop_expensive_invariant_code_motion.h"],
+    deps = [
+        ":hlo",
+        ":hlo_pass",
+        ":tuple_util",
+        ":while_loop_analysis",
+        ":while_util",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:util",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/container:inlined_vector",
+    ],
+)
+
+tf_cc_test(
+    name = "while_loop_expensive_invariant_code_motion_test",
+    srcs = ["while_loop_expensive_invariant_code_motion_test.cc"],
+    deps = [
+        ":hlo_matchers",
+        ":hlo_parser",
+        ":while_loop_expensive_invariant_code_motion",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:test",
+    ],
+)
+
 cc_library(
     name = "while_loop_constant_sinking",
     srcs = ["while_loop_constant_sinking.cc"],
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h
index cb1b1d0dae4..ae9bdc1999f 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.h
+++ b/tensorflow/compiler/xla/service/hlo_matchers.h
@@ -417,6 +417,11 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> NoSharding() {
   return ::testing::MakeMatcher(
       new ::xla::testing::HloShardingMatcher(absl::nullopt));
 }
+inline ::testing::Matcher<const ::xla::HloInstruction*> Dot() {
+  return ::testing::MakeMatcher(
+      new ::xla::testing::HloMatcher(::xla::HloOpcode::kDot, {}));
+}
+
 inline ::testing::Matcher<const ::xla::HloInstruction*> Dot(
     ::testing::Matcher<const HloInstruction*> lhs_matcher,
     ::testing::Matcher<const HloInstruction*> rhs_matcher) {
diff --git a/tensorflow/compiler/xla/service/while_loop_expensive_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_expensive_invariant_code_motion.cc
new file mode 100644
index 00000000000..7e809a4976e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/while_loop_expensive_invariant_code_motion.cc
@@ -0,0 +1,376 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/while_loop_expensive_invariant_code_motion.h"
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
+#include "tensorflow/compiler/xla/service/tuple_util.h"
+#include "tensorflow/compiler/xla/service/while_loop_analysis.h"
+#include "tensorflow/compiler/xla/service/while_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/util.h"
+
+namespace xla {
+namespace {
+using absl::flat_hash_map;
+using absl::flat_hash_set;
+using absl::InlinedVector;
+
+struct InvariantInfo {
+  explicit InvariantInfo(int64 user_count)
+      : remaining_user_count(user_count) {}
+  // The transitive input size of all input operands, traced up to the while
+  // loop parameter or leaf invariant ops.
+  int64 transitive_input_size = 0;
+  // The number of users that remain in the body after all hoistable invariant
+  // users are hoisted. This number excludes the root instruction.
+  int64 remaining_user_count;
+  // If this instruction is hoisted, this stores the copy outside the body.
+  HloInstruction* hoisted_copy = nullptr;
+  // Hoistable instructions depending on this op to be hoisted.
+  InlinedVector<HloInstruction*, 2> blocked_users;
+};
+
+// Copies `to_hoist` to the computation containing `while_instr`, hoisting its
+// operands as needed. All of its transitive operands are expected to be in
+// `invariant_instructions`. This function hoists the operands in
+// `invariant_instructions` and sets the entry's hoisted_copy to the hoisted
+// instruction.
+static void CreateLoopInvariantCopy(
+    flat_hash_map<HloInstruction*, InvariantInfo>* invariant_instructions,
+    HloInstruction* while_instr, HloInstruction* to_hoist) {
+  HloComputation* parent_of_while = while_instr->parent();
+  HloComputation* while_body = while_instr->while_body();
+
+  struct DFSFrame {
+    HloInstruction* instruction;
+    int64 operand_index;
+  };
+
+  InlinedVector<DFSFrame, 8> dfs_stack;
+  dfs_stack.push_back({to_hoist, 0});
+
+  HloInstruction* while_body_param = while_body->parameter_instruction(0);
+  HloInstruction* while_operand = while_instr->mutable_operand(0);
+
+  do {
+    DFSFrame* frame = &dfs_stack.back();
+    // All of the operands for old_instruction have been cloned, so it is time
+    // to clone old_instruction itself.
+    if (frame->operand_index == frame->instruction->operand_count()) {
+      HloInstruction* old_instruction = frame->instruction;
+      InvariantInfo& info =
+          FindOrDie(*invariant_instructions, old_instruction);
+
+      // Check if this instruction might have already been hoisted.
+      if (info.hoisted_copy == nullptr) {
+        auto get_new_operand = [&](HloInstruction* old_operand) {
+          return old_operand == while_body_param
+                     ? while_operand
+                     : FindOrDie(*invariant_instructions, old_operand)
+                           .hoisted_copy;
+        };
+
+        InlinedVector<HloInstruction*, 4> new_operands;
+        absl::c_transform(old_instruction->operands(),
+                          std::back_inserter(new_operands), get_new_operand);
+
+        HloInstruction* new_instruction = parent_of_while->AddInstruction(
+            old_instruction->CloneWithNewOperands(old_instruction->shape(),
+                                                  new_operands));
+
+        info.hoisted_copy = new_instruction;
+      }
+
+      dfs_stack.pop_back();
+      continue;
+    }
+
+    HloInstruction* next_operand =
+        frame->instruction->mutable_operand(frame->operand_index++);
+    if (next_operand == while_body_param ||
+        FindOrDie(*invariant_instructions, next_operand).hoisted_copy !=
+            nullptr) {
+      continue;
+    }
+
+    dfs_stack.push_back({next_operand, 0});
+  } while (!dfs_stack.empty());
+}
+}  // namespace
+
+StatusOr<bool> WhileLoopExpensiveInvariantCodeMotion::
+    TryHoistingInvariantInstructionsFromWhileBody(HloInstruction* while_instr) {
+  auto print_no_metadata = HloPrintOptions{}.set_print_metadata(false);
+
+  if (!while_instr->shape().IsTuple()) {
+    // This restriction leaves one interesting pattern on the table:
+    //
+    //  while_body(f32[1024, 1024] %param) {
+    //    %value = expensive_op(%param)
+    //    outfeed(%value)
+    //    ROOT = %param
+    //  }
+    //
+    // If we see that pattern in the while, instead of generalizing this
+    // algorithm to work with non-tuples, we should instead add a pass that
+    // canonicalizes while loops like the above to use a tuple state.
+    return false;
+  }
+
+  string while_instr_name = while_instr->ToString(print_no_metadata);
+  VLOG(2) << "Trying to hoist from " << while_instr_name;
+
+  auto maybe_upper_bound = ComputeWhileLoopTripCountUpperBound(while_instr);
+  if (maybe_upper_bound && *maybe_upper_bound <= 1) {
+    VLOG(2) << "Loop has a trip count of at most 1, skipping.";
+    return false;
+  }
+
+  HloComputation* while_body = while_instr->while_body();
+
+  // Contains the information for all invariant instructions that can be
+  // legally hoisted. When we hoist an instruction in this set, we set its
+  // hoisted_copy field to the hoisted instruction.
+  flat_hash_map<HloInstruction*, InvariantInfo> invariant_instructions;
+
+  // Map from an invariant instruction to the number of remaining unresolved
+  // operands, i.e. operands used by unvisited instructions. If all these
+  // operands are used by other invariant instructions, then hoisting out that
+  // operand won't leave a copy of itself in the body and it's free to hoist.
+  flat_hash_map<HloInstruction*, int64> to_hoist_when_ready;
+
+  // Identify invariant GTE instructions so that we can identify their users
+  // that are also invariant.
+  for (auto* instr : WhileUtil::GetInvariantGTEsForWhileBody(*while_body)) {
+    // TODO(b/79147885): We should try to generalize this to tuples for
+    // uniformity's sake, if nothing else.
+    if (instr->shape().IsArray()) {
+      // We subtract 1 from user_count because we know one of the users is
+      // root.
+      auto emplace_result = invariant_instructions.emplace(
+          instr, InvariantInfo(/*user_count=*/instr->user_count() - 1));
+      CHECK(emplace_result.second);
+      InvariantInfo& info = emplace_result.first->second;
+      info.transitive_input_size = shape_size_function_(instr->shape());
+    }
+  }
+
+  // LICM in the presence of domain instructions is complex, bail.
+  for (auto* instruction : while_body->MakeInstructionPostOrder()) {
+    if (instruction->opcode() == HloOpcode::kDomain) {
+      return false;
+    }
+  }
+
+  // instructions_to_replace[i] is hoisted into a loop invariant instruction
+  // replacement_instructions[i].
+  std::vector<HloInstruction*> instructions_to_replace;
+  std::vector<HloInstruction*> replacement_instructions;
+
+  auto hoist = [&](HloInstruction* instruction, const InvariantInfo& info) {
+    if (info.hoisted_copy) {
+      // Already hoisted.
+      return;
+    }
+    VLOG(2) << "Hoisting " << instruction->ToString(print_no_metadata);
+
+    CreateLoopInvariantCopy(&invariant_instructions, while_instr, instruction);
+
+    instructions_to_replace.push_back(instruction);
+    replacement_instructions.push_back(info.hoisted_copy);
+  };
+
+  // Temporary helper container for marking an operand as checked when
+  // decrementing its remaining_user_count counter. Cleared after each
+  // iteration.
+  flat_hash_set<HloInstruction*> checked_operands;
+
+  for (auto* instruction : while_body->MakeInstructionPostOrder()) {
+    if (instruction->HasSideEffect() ||
+        instruction->opcode() == HloOpcode::kParameter ||
+        !instruction->control_predecessors().empty() ||
+        !instruction->control_successors().empty() ||
+        instruction == while_body->root_instruction()) {
+      continue;
+    }
+
+    auto is_invariant = [&](HloInstruction* op) {
+      return invariant_instructions.find(op) != invariant_instructions.end();
+    };
+
+    if (!absl::c_all_of(instruction->operands(), is_invariant)) {
+      continue;
+    }
+
+    auto emplace_result = invariant_instructions.emplace(
+        instruction, InvariantInfo(/*user_count=*/instruction->user_count()));
+    CHECK(emplace_result.second);
+    InvariantInfo& instr_info = emplace_result.first->second;
+    // If root is a user of it, subtract 1 from the remaining user count, as
+    // we don't want root to block other users from being hoisted. Note that
+    // invariant parameter GTEs will skip this iteration because their operand,
+    // parameter(0), is not invariant; they were put into
+    // invariant_instructions before this loop.
+    for (auto* user : instruction->users()) {
+      if (user == while_body->root_instruction()) {
+        --instr_info.remaining_user_count;
+        break;
+      }
+    }
+
+    int64 num_blocking_operands = 0;
+    // Check that hoisting the instruction doesn't cause a significant memory
+    // blow-up. LICM extends the live-range of the output of the hoisted
+    // instruction to be the entire while loop, which may be problematic on
+    // platforms where memory is limited. This can be especially harmful if
+    // the instruction has a significantly larger output than its input, e.g.
+    // kIota, kBroadcast or kConstant.
+    int64 output_size = 0;
+
+    for (auto* operand : instruction->operands()) {
+      auto& operand_info = invariant_instructions.at(operand);
+      if (!checked_operands.contains(operand)) {
+        instr_info.transitive_input_size +=
+            operand_info.transitive_input_size;
+        --operand_info.remaining_user_count;
+        checked_operands.insert(operand);
+      }
+      if (operand_info.remaining_user_count == 0) {
+        // All users are hoistable invariants, unblock held off users.
+        for (auto* user : operand_info.blocked_users) {
+          auto it = to_hoist_when_ready.find(user);
+          if (it != to_hoist_when_ready.end()) {
+            auto& num_blocking = it->second;
+            CHECK_GT(num_blocking, 0);
+            --num_blocking;
+            // Hoist a previously held off instruction now that there are no
+            // more blocking operands.
+            if (num_blocking == 0) {
+              hoist(user, invariant_instructions.at(user));
+              to_hoist_when_ready.erase(it);
+            }
+          }
+        }
+        operand_info.blocked_users.clear();
+      } else if (operand_info.remaining_user_count > 0) {
+        ++num_blocking_operands;
+        if (operand_info.blocked_users.empty() ||
+            operand_info.blocked_users.back() != instruction) {
+          operand_info.blocked_users.push_back(instruction);
+        }
+      } else {
+        LOG(FATAL)
+            << "An instruction should not have a negative number of users.";
+      }
+    }
+    checked_operands.erase(checked_operands.begin(), checked_operands.end());
+    ShapeUtil::ForEachSubshape(
+        instruction->shape(),
+        [&output_size, this](const Shape& subshape,
+                             const ShapeIndex& /*index*/) {
+          if (subshape.IsArray()) {
+            output_size += shape_size_function_(subshape);
+          }
+        });
+    // If it is size-inflating, we leave it as is and potentially will still
+    // hoist it out if we later find a group of ops that are worth hoisting
+    // as a whole.
+    if (output_size > instr_info.transitive_input_size) {
+      continue;
+    }
+
+    if (!worth_hoisting_individually_(*instruction)) {
+      continue;
+    }
+
+    // We need to wait until we have inspected the users of some operands
+    // before we can finally decide whether to hoist this instruction.
+    if (num_blocking_operands > 0) {
+      to_hoist_when_ready.emplace(instruction, num_blocking_operands);
+      continue;
+    }
+
+    hoist(instruction, instr_info);
+  }
+
+  if (instructions_to_replace.empty()) {
+    return false;
+  }
+
+  TF_ASSIGN_OR_RETURN(
+      WhileUtil::MakeInstructionsLiveInResult live_in_instructions_result,
+      WhileUtil::MakeInstructionsLiveIn(while_instr, replacement_instructions));
+
+  HloComputation* new_while_body =
+      live_in_instructions_result.new_while_instr->while_body();
+
+  for (int i = 0; i < instructions_to_replace.size(); i++) {
+    HloInstruction* instruction_to_replace_in_new_while =
+        FindOrDie(live_in_instructions_result.while_body_instruction_map,
+                  instructions_to_replace[i]);
+    TF_RETURN_IF_ERROR(new_while_body->ReplaceInstruction(
+        instruction_to_replace_in_new_while,
+        live_in_instructions_result.while_body_live_in_values[i]));
+  }
+
+  VLOG(1) << "Hoisted " << instructions_to_replace.size()
+          << " instructions from " << while_instr_name;
+
+  return true;
+}
+
+StatusOr<bool> WhileLoopExpensiveInvariantCodeMotion::Run(HloModule* module) {
+  VLOG(2) << "HLO module before WhileLoopExpensiveInvariantCodeMotion:";
+  XLA_VLOG_LINES(2, module->ToString());
+
+  bool changed = false;
+  std::vector<HloInstruction*> while_instrs;
+  for (auto* comp : module->computations()) {
+    absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs),
+                    [](const HloInstruction* instr) {
+                      return instr->opcode() == HloOpcode::kWhile;
+                    });
+  }
+
+  for (HloInstruction* while_instr : while_instrs) {
+    // Right now we only hoist computations from the while body, but
+    // TryHoistingInvariantInstructionsFromWhileBody can be generalized to
+    // optimize the condition computation too, if needed.
+    //
+    // The transform we do here is a pessimization for while loops that
+    // execute zero times*, but at this time we expect those to be rare. If
+    // this becomes a problem we can consider using the conditional HLO to
+    // avoid doing extra work for while loops with zero trip count.
+    //
+    // * We delete while loops that have a zero trip count, so this would have
+    //   to be a while loop with a somewhat opaque condition expression.
+
+    TF_ASSIGN_OR_RETURN(
+        bool result,
+        TryHoistingInvariantInstructionsFromWhileBody(while_instr));
+    changed |= result;
+  }
+
+  if (changed) {
+    VLOG(2) << "HLO module after WhileLoopExpensiveInvariantCodeMotion:";
+    XLA_VLOG_LINES(2, module->ToString());
+  } else {
+    VLOG(2)
+        << "HLO module unchanged after WhileLoopExpensiveInvariantCodeMotion";
+  }
+
+  return changed;
+}
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/while_loop_expensive_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_expensive_invariant_code_motion.h
new file mode 100644
index 00000000000..d26824d7ddf
--- /dev/null
+++ b/tensorflow/compiler/xla/service/while_loop_expensive_invariant_code_motion.h
@@ -0,0 +1,54 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_EXPENSIVE_INVARIANT_CODE_MOTION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_EXPENSIVE_INVARIANT_CODE_MOTION_H_
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+
+// HLO pass that rewrites while loops to hoist expensive and non-size-inflating
+// groups of loop invariant instructions in the while body into the computation
+// that contains the while instruction.
+// Users can specify worth_hoisting_individually; only groups of instructions
+// whose root returns true for it will be hoisted out.
+class WhileLoopExpensiveInvariantCodeMotion : public HloModulePass {
+ public:
+  using ShapeSizeFunction = std::function<int64(const Shape&)>;
+  explicit WhileLoopExpensiveInvariantCodeMotion(
+      std::function<bool(const HloInstruction&)> worth_hoisting_individually,
+      ShapeSizeFunction shape_size_function = ShapeUtil::ByteSizeOfElements)
+      : shape_size_function_(std::move(shape_size_function)),
+        worth_hoisting_individually_(std::move(worth_hoisting_individually)) {}
+  ~WhileLoopExpensiveInvariantCodeMotion() override = default;
+
+  absl::string_view name() const override {
+    return "while-loop-expensive-invariant-code-motion";
+  }
+  StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+  StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody(
+      HloInstruction* while_instr);
+
+  ShapeSizeFunction shape_size_function_;
+  std::function<bool(const HloInstruction&)> worth_hoisting_individually_;
+};
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_EXPENSIVE_INVARIANT_CODE_MOTION_H_
diff --git a/tensorflow/compiler/xla/service/while_loop_expensive_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_expensive_invariant_code_motion_test.cc
new file mode 100644
index 00000000000..fb1d2ea1a89
--- /dev/null
+++ b/tensorflow/compiler/xla/service/while_loop_expensive_invariant_code_motion_test.cc
@@ -0,0 +1,240 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/while_loop_expensive_invariant_code_motion.h"
+
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+using WhileLoopExpensiveInvariantCodeMotionTest = HloTestBase;
+namespace op = xla::testing::opcode_matchers;
+
+constexpr char kModuleWithNonInflatingInvariantDot[] = R"(
+HloModule ModuleWithWhile
+
+mul {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT mul = f32[] multiply(lhs, rhs)
+}
+
+body {
+  p_body = (f32[], f32[16, 8]) parameter(0)
+  b = get-tuple-element(p_body), index=1
+  const = f32[] constant(1.0)
+  lhs = f32[8, 16] broadcast(const), dimensions={}
+  dot = dot(lhs, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  reduced = reduce(dot, const), dimensions={0, 1}, to_apply=mul
+  a = get-tuple-element(p_body), index=0
+  add = add(reduced, a)
+  ROOT root = tuple(add, b)
+}
+
+condition {
+  p_cond = (f32[], f32[16, 8]) parameter(0)
+  ROOT result = pred[] constant(true)
+}
+
+ENTRY entry {
+  param0 = f32[] parameter(0)
+  param1 = f32[16, 8] parameter(1)
+  while_init = tuple(param0, param1)
+  ROOT while = while(while_init), condition=condition, body=body
+}
+)";
+
+TEST_F(WhileLoopExpensiveInvariantCodeMotionTest,
+       HoistsGroupOfAllowedNonInflating) {
+  auto m = ParseAndReturnVerifiedModule(kModuleWithNonInflatingInvariantDot)
+               .ValueOrDie();
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool simplified_loop,
+      WhileLoopExpensiveInvariantCodeMotion(
+          /*worth_hoisting_individually=*/[](const HloInstruction& instr) {
+            return instr.opcode() == HloOpcode::kDot;
+          })
+          .Run(m.get()));
+  EXPECT_TRUE(simplified_loop);
+
+  HloComputation* while_body = m->GetComputationWithName("wide.body");
+  ASSERT_NE(while_body, nullptr);
+
+  EXPECT_THAT(while_body->instructions(), Not(Contains(op::Dot())));
+  // kReduce not in the allow list.
+  EXPECT_THAT(while_body->instructions(), Contains(op::Reduce()));
+}
+
+TEST_F(WhileLoopExpensiveInvariantCodeMotionTest,
+       HoistsGroupOfAllNonInflating) {
+  auto m = ParseAndReturnVerifiedModule(kModuleWithNonInflatingInvariantDot)
+               .ValueOrDie();
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool simplified_loop,
+      WhileLoopExpensiveInvariantCodeMotion(
+          /*worth_hoisting_individually=*/[](const HloInstruction& instr) {
+            return instr.opcode() == HloOpcode::kDot ||
+                   instr.opcode() == HloOpcode::kReduce;
+          })
+          .Run(m.get()));
+  EXPECT_TRUE(simplified_loop);
+
+  HloComputation* while_body = m->GetComputationWithName("wide.body");
+  ASSERT_NE(while_body, nullptr);
+
+  EXPECT_THAT(while_body->instructions(), Not(Contains(op::Dot())));
+  EXPECT_THAT(while_body->instructions(), Not(Contains(op::Reduce())));
+}
+
+TEST_F(WhileLoopExpensiveInvariantCodeMotionTest,
+       DoesNotHoistsUnallowedInstructions) {
+  auto m = ParseAndReturnVerifiedModule(kModuleWithNonInflatingInvariantDot)
+               .ValueOrDie();
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool simplified_loop,
+      WhileLoopExpensiveInvariantCodeMotion(
+          /*worth_hoisting_individually=*/[](const HloInstruction& instr) {
+            return false;
+          })
+          .Run(m.get()));
+  EXPECT_FALSE(simplified_loop);
+}
+
+constexpr char kModuleWithInflatingInvariantDot[] = R"(
+HloModule ModuleWithWhile
+
+mul {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT mul = f32[] multiply(lhs, rhs)
+}
+
+body {
+  p_body = (f32[], f32[16, 4]) parameter(0)
+  b = get-tuple-element(p_body), index=1
+  const = f32[] constant(1.0)
+  lhs = f32[4, 16] broadcast(const), dimensions={}
+  dot = dot(lhs, b), lhs_contracting_dims={0}, rhs_contracting_dims={1}
+  reduced = reduce(dot, const), dimensions={0, 1}, to_apply=mul
+  a = get-tuple-element(p_body), index=0
+  add = add(reduced, a)
+  ROOT root = tuple(add, b)
+}
+
+condition {
+  p_cond = (f32[], f32[16, 4]) parameter(0)
+  ROOT result = pred[] constant(true)
+}
+
+ENTRY entry {
+  param0 = f32[] parameter(0)
+  param1 = f32[16, 4] parameter(1)
+  while_init = tuple(param0, param1)
+  ROOT while = while(while_init), condition=condition, body=body
+}
+)";
+
+TEST_F(WhileLoopExpensiveInvariantCodeMotionTest, DoesNotHoistsInflating) {
+  auto m = ParseAndReturnVerifiedModule(kModuleWithInflatingInvariantDot)
+               .ValueOrDie();
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool simplified_loop,
+      WhileLoopExpensiveInvariantCodeMotion(
+          /*worth_hoisting_individually=*/[](const HloInstruction& instr) {
+            return instr.opcode() == HloOpcode::kDot;
+          })
+          .Run(m.get()));
+  EXPECT_FALSE(simplified_loop);
+}
+
+TEST_F(WhileLoopExpensiveInvariantCodeMotionTest,
+       HoistsGroupOfNonInflatingWithInflatingIntermediate) {
+  auto m = ParseAndReturnVerifiedModule(kModuleWithInflatingInvariantDot)
+               .ValueOrDie();
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool simplified_loop,
+      WhileLoopExpensiveInvariantCodeMotion(
+          /*worth_hoisting_individually=*/[](const HloInstruction& instr) {
+            return instr.opcode() == HloOpcode::kDot ||
+                   instr.opcode() == HloOpcode::kReduce;
+          })
+          .Run(m.get()));
+  EXPECT_TRUE(simplified_loop);
+
+  HloComputation* while_body = m->GetComputationWithName("wide.body");
+  ASSERT_NE(while_body, nullptr);
+
+  EXPECT_THAT(while_body->instructions(), Not(Contains(op::Dot())));
+  EXPECT_THAT(while_body->instructions(), Not(Contains(op::Reduce())));
+}
+
+TEST_F(WhileLoopExpensiveInvariantCodeMotionTest,
+       HoistsOpWithDuplicateOperands) {
+  constexpr char kModuleWithDuplicateOperands[] = R"(
+HloModule ModuleWithWhile
+
+mul {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT mul = f32[] multiply(lhs, rhs)
+}
+
+body {
+  p_body = (f32[4, 4], f32[4, 4]) parameter(0)
+  a = get-tuple-element(p_body), index=0
+  dot = dot(a, a), lhs_contracting_dims={0}, rhs_contracting_dims={1}
+  b = get-tuple-element(p_body), index=1
+  add = add(b, dot)
+  ROOT root = tuple(a, add)
+}
+
+condition {
+  p_cond = (f32[4, 4], f32[4, 4]) parameter(0)
+  ROOT result = pred[] constant(true)
+}
+
+ENTRY entry {
+  param0 = f32[4, 4] parameter(0)
+  param1 = f32[4, 4] parameter(1)
+  while_init = tuple(param0, param1)
+  ROOT while = while(while_init), condition=condition, body=body
+}
+)";
+  auto m =
+      ParseAndReturnVerifiedModule(kModuleWithDuplicateOperands).ValueOrDie();
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool simplified_loop,
+      WhileLoopExpensiveInvariantCodeMotion(
+          /*worth_hoisting_individually=*/[](const HloInstruction& instr) {
+            return instr.opcode() == HloOpcode::kDot;
+          })
+          .Run(m.get()));
+  EXPECT_TRUE(simplified_loop);
+
+  HloComputation* while_body = m->GetComputationWithName("wide.body");
+  ASSERT_NE(while_body, nullptr);
+
+  EXPECT_THAT(while_body->instructions(), Not(Contains(op::Dot())));
+}
+
+}  // namespace
+}  // namespace xla