[XLA] Add a variant of WhileLoopInvariantCodeMotion pass to hoist groups of loop expensive invariant HLOs out of while if they are not size-inflating as a whole.
PiperOrigin-RevId: 356334931 Change-Id: I9b933654926897fa12a99b8e9050c57db60fdd23
This commit is contained in:
parent
5ac570fa43
commit
bdd28bfb0c
@ -4820,6 +4820,40 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "while_loop_expensive_invariant_code_motion",
|
||||
srcs = ["while_loop_expensive_invariant_code_motion.cc"],
|
||||
hdrs = ["while_loop_expensive_invariant_code_motion.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_pass",
|
||||
":tuple_util",
|
||||
":while_loop_analysis",
|
||||
":while_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "while_loop_expensive_invariant_code_motion_test",
|
||||
srcs = ["while_loop_expensive_invariant_code_motion_test.cc"],
|
||||
deps = [
|
||||
":hlo_matchers",
|
||||
":hlo_parser",
|
||||
":while_loop_expensive_invariant_code_motion",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "while_loop_constant_sinking",
|
||||
srcs = ["while_loop_constant_sinking.cc"],
|
||||
|
@ -417,6 +417,11 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> NoSharding() {
|
||||
new ::xla::testing::HloShardingMatcher(absl::nullopt));
|
||||
}
|
||||
|
||||
inline ::testing::Matcher<const ::xla::HloInstruction*> Dot() {
|
||||
return ::testing::MakeMatcher(
|
||||
new ::xla::testing::HloMatcher(::xla::HloOpcode::kDot, {}));
|
||||
}
|
||||
|
||||
inline ::testing::Matcher<const ::xla::HloInstruction*> Dot(
|
||||
::testing::Matcher<const HloInstruction*> lhs_matcher,
|
||||
::testing::Matcher<const HloInstruction*> rhs_matcher) {
|
||||
|
@ -0,0 +1,376 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/while_loop_expensive_invariant_code_motion.h"
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "tensorflow/compiler/xla/service/tuple_util.h"
|
||||
#include "tensorflow/compiler/xla/service/while_loop_analysis.h"
|
||||
#include "tensorflow/compiler/xla/service/while_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
using absl::flat_hash_map;
|
||||
using absl::flat_hash_set;
|
||||
using absl::InlinedVector;
|
||||
|
||||
struct InvariantInfo {
|
||||
explicit InvariantInfo(int64 user_count) : remaining_user_count(user_count) {}
|
||||
// The transitive input size of all input operands, traced up to the while
|
||||
// loop parameter or leaf invariant ops.
|
||||
int64 transitive_input_size = 0;
|
||||
// The remaining users count that remain in the body after all hoistable
|
||||
// invariant users are hoisted. This number excludes the root instruction.
|
||||
int64 remaining_user_count;
|
||||
// If this instruction is hoisted, this stores the copy outside the body.
|
||||
HloInstruction* hoisted_copy = nullptr;
|
||||
// Hoistable instructions depending on this op to be hoisted.
|
||||
InlinedVector<HloInstruction*, 2> blocked_users;
|
||||
};
|
||||
|
||||
// Copies `to_hoist` to the computation containing `while_instr`, hoisting its
|
||||
// operands as needed. All of its transitive operands are expected to be in
|
||||
// `invariant_instructions`. This function hoists the operands in
|
||||
// `invariant_instructions` and sets the entry's hoisted_copy to the hoisted
|
||||
// instruction.
|
||||
static void CreateLoopInvariantCopy(
|
||||
flat_hash_map<HloInstruction*, InvariantInfo>* invariant_instructions,
|
||||
HloInstruction* while_instr, HloInstruction* to_hoist) {
|
||||
HloComputation* parent_of_while = while_instr->parent();
|
||||
HloComputation* while_body = while_instr->while_body();
|
||||
|
||||
struct DFSFrame {
|
||||
HloInstruction* instruction;
|
||||
int64 operand_index;
|
||||
};
|
||||
|
||||
InlinedVector<DFSFrame, 8> dfs_stack;
|
||||
dfs_stack.push_back({to_hoist, 0});
|
||||
|
||||
HloInstruction* while_body_param = while_body->parameter_instruction(0);
|
||||
HloInstruction* while_operand = while_instr->mutable_operand(0);
|
||||
|
||||
do {
|
||||
DFSFrame* frame = &dfs_stack.back();
|
||||
// All of the operands for old_instruction have been cloned, so it is time
|
||||
// to clone old_instruction itself.
|
||||
if (frame->operand_index == frame->instruction->operand_count()) {
|
||||
HloInstruction* old_instruction = frame->instruction;
|
||||
InvariantInfo& info = FindOrDie(*invariant_instructions, old_instruction);
|
||||
|
||||
// Check if this instruction might have already been hoisted.
|
||||
if (info.hoisted_copy == nullptr) {
|
||||
auto get_new_operand = [&](HloInstruction* old_operand) {
|
||||
return old_operand == while_body_param
|
||||
? while_operand
|
||||
: FindOrDie(*invariant_instructions, old_operand)
|
||||
.hoisted_copy;
|
||||
};
|
||||
|
||||
InlinedVector<HloInstruction*, 4> new_operands;
|
||||
absl::c_transform(old_instruction->operands(),
|
||||
std::back_inserter(new_operands), get_new_operand);
|
||||
|
||||
HloInstruction* new_instruction = parent_of_while->AddInstruction(
|
||||
old_instruction->CloneWithNewOperands(old_instruction->shape(),
|
||||
new_operands));
|
||||
|
||||
info.hoisted_copy = new_instruction;
|
||||
}
|
||||
|
||||
dfs_stack.pop_back();
|
||||
continue;
|
||||
}
|
||||
|
||||
HloInstruction* next_operand =
|
||||
frame->instruction->mutable_operand(frame->operand_index++);
|
||||
if (next_operand == while_body_param ||
|
||||
FindOrDie(*invariant_instructions, next_operand).hoisted_copy !=
|
||||
nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
dfs_stack.push_back({next_operand, 0});
|
||||
} while (!dfs_stack.empty());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
StatusOr<bool> WhileLoopExpensiveInvariantCodeMotion::
|
||||
TryHoistingInvariantInstructionsFromWhileBody(HloInstruction* while_instr) {
|
||||
auto print_no_metadata = HloPrintOptions{}.set_print_metadata(false);
|
||||
|
||||
if (!while_instr->shape().IsTuple()) {
|
||||
// This restriction leaves one interesting pattern on the table:
|
||||
//
|
||||
// while_body(f32[1024, 1024] %param) {
|
||||
// %value = expensive_op(%param)
|
||||
// outfeed(%value)
|
||||
// ROOT = %param
|
||||
// }
|
||||
//
|
||||
// If we see that pattern in the while, instead of generalizing this
|
||||
// algorithm to work with non-tuples, we should instead add a pass that
|
||||
// canonicalizes while loops like the above to use a tuple state.
|
||||
return false;
|
||||
}
|
||||
|
||||
string while_instr_name = while_instr->ToString(print_no_metadata);
|
||||
VLOG(2) << "Trying to hoist from " << while_instr_name;
|
||||
|
||||
auto maybe_upper_bound = ComputeWhileLoopTripCountUpperBound(while_instr);
|
||||
if (maybe_upper_bound && *maybe_upper_bound <= 1) {
|
||||
VLOG(2) << "Loop has a trip count of at most 1, skipping.";
|
||||
return false;
|
||||
}
|
||||
|
||||
HloComputation* while_body = while_instr->while_body();
|
||||
|
||||
// Contains the information for all invariant instructions that can be legally
|
||||
// hoisted. When we hoist an instruction in this set, we set its hoisted_copy
|
||||
// field to the hoisted instruction.
|
||||
flat_hash_map<HloInstruction*, InvariantInfo> invariant_instructions;
|
||||
|
||||
// Map from an invariant instruction to the number of remaining unresolved
|
||||
// operands, i.e. operands used by unvisited instructions. If all these
|
||||
// operands are used by other invariant instructions, then hoisting out that
|
||||
// operand won't leave a copy of itself in the body and it's free to hoist.
|
||||
flat_hash_map<HloInstruction*, int64> to_hoist_when_ready;
|
||||
|
||||
// Identify invariant GTE instructions so that we can identify its users that
|
||||
// are also invariants.
|
||||
for (auto* instr : WhileUtil::GetInvariantGTEsForWhileBody(*while_body)) {
|
||||
// TODO(b/79147885): We should try to generalize this to tuples for
|
||||
// uniformity's sake, if nothing else.
|
||||
if (instr->shape().IsArray()) {
|
||||
// We subtract 1 from user_count because we know one of the users is root.
|
||||
auto emplace_result = invariant_instructions.emplace(
|
||||
instr, InvariantInfo(/*user_count=*/instr->user_count() - 1));
|
||||
CHECK(emplace_result.second);
|
||||
InvariantInfo& info = emplace_result.first->second;
|
||||
info.transitive_input_size = shape_size_function_(instr->shape());
|
||||
}
|
||||
}
|
||||
|
||||
// LICM in the presence of domain instructions is complex, bail.
|
||||
for (auto* instruction : while_body->MakeInstructionPostOrder()) {
|
||||
if (instruction->opcode() == HloOpcode::kDomain) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// instructions_to_replace[i] is hoisted into a loop invariant instruction
|
||||
// replacement_instructions[i].
|
||||
std::vector<HloInstruction*> instructions_to_replace;
|
||||
std::vector<HloInstruction*> replacement_instructions;
|
||||
|
||||
auto hoist = [&](HloInstruction* instruction, const InvariantInfo& info) {
|
||||
if (info.hoisted_copy) {
|
||||
// Already hoisted.
|
||||
return;
|
||||
}
|
||||
VLOG(2) << "Hoisting " << instruction->ToString(print_no_metadata);
|
||||
|
||||
CreateLoopInvariantCopy(&invariant_instructions, while_instr, instruction);
|
||||
|
||||
instructions_to_replace.push_back(instruction);
|
||||
replacement_instructions.push_back(info.hoisted_copy);
|
||||
};
|
||||
|
||||
// Temporary helper container for marking a operand as checked when
|
||||
// decrementing its remaining_user_count counter. Cleared after each
|
||||
// iteration.
|
||||
flat_hash_set<HloInstruction*> checked_operands;
|
||||
|
||||
for (auto* instruction : while_body->MakeInstructionPostOrder()) {
|
||||
if (instruction->HasSideEffect() ||
|
||||
instruction->opcode() == HloOpcode::kParameter ||
|
||||
!instruction->control_predecessors().empty() ||
|
||||
!instruction->control_successors().empty() ||
|
||||
instruction == while_body->root_instruction()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto is_invariant = [&](HloInstruction* op) {
|
||||
return invariant_instructions.find(op) != invariant_instructions.end();
|
||||
};
|
||||
|
||||
if (!absl::c_all_of(instruction->operands(), is_invariant)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto emplace_result = invariant_instructions.emplace(
|
||||
instruction, InvariantInfo(/*user_count=*/instruction->user_count()));
|
||||
CHECK(emplace_result.second);
|
||||
InvariantInfo& instr_info = emplace_result.first->second;
|
||||
// If root is a users of it, substract 1 from remaining user count as we
|
||||
// don't want root to be blocking other users from being hoisted. Note that
|
||||
// for invariant parameter GTEs, they will skip the iteration because their
|
||||
// operand parameter(0) is not invariant, and they are put into
|
||||
// invariant_instructions before this loop.
|
||||
for (auto* user : instruction->users()) {
|
||||
if (user == while_body->root_instruction()) {
|
||||
--instr_info.remaining_user_count;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
int64 num_blocking_operands = 0;
|
||||
// Check that hoisting the instruction doesn't cause a significant memory
|
||||
// blow-up. LICM extends the live-range of the output of the hoisted
|
||||
// instruction to be the entire while loop, which may be problematic on
|
||||
// platforms where memory is limited. This can be especially harmful if
|
||||
// the instruction has a significantly larger output than its input, e.g.
|
||||
// kIota, kBroadcast or kConstant.
|
||||
int64 output_size = 0;
|
||||
|
||||
for (auto* operand : instruction->operands()) {
|
||||
auto& operand_info = invariant_instructions.at(operand);
|
||||
if (!checked_operands.contains(operand)) {
|
||||
instr_info.transitive_input_size += operand_info.transitive_input_size;
|
||||
--operand_info.remaining_user_count;
|
||||
checked_operands.insert(operand);
|
||||
}
|
||||
if (operand_info.remaining_user_count == 0) {
|
||||
// All users are hoistable invariants, unblock held off users.
|
||||
for (auto* user : operand_info.blocked_users) {
|
||||
auto it = to_hoist_when_ready.find(user);
|
||||
if (it != to_hoist_when_ready.end()) {
|
||||
auto& num_blocking = it->second;
|
||||
CHECK_GT(num_blocking, 0);
|
||||
--num_blocking;
|
||||
// Hoist a previously held off instruction now that there are no
|
||||
// more blocking operands.
|
||||
if (num_blocking == 0) {
|
||||
hoist(user, invariant_instructions.at(user));
|
||||
to_hoist_when_ready.erase(it);
|
||||
}
|
||||
}
|
||||
}
|
||||
operand_info.blocked_users.clear();
|
||||
} else if (operand_info.remaining_user_count > 0) {
|
||||
++num_blocking_operands;
|
||||
if (operand_info.blocked_users.empty() ||
|
||||
operand_info.blocked_users.back() != instruction) {
|
||||
operand_info.blocked_users.push_back(instruction);
|
||||
}
|
||||
} else {
|
||||
LOG(FATAL)
|
||||
<< "An instruction should not have number of negative users.";
|
||||
}
|
||||
}
|
||||
checked_operands.erase(checked_operands.begin(), checked_operands.end());
|
||||
ShapeUtil::ForEachSubshape(
|
||||
instruction->shape(),
|
||||
[&output_size, this](const Shape& subshape,
|
||||
const ShapeIndex& /*index*/) {
|
||||
if (subshape.IsArray()) {
|
||||
output_size += shape_size_function_(subshape);
|
||||
}
|
||||
});
|
||||
// If it is size-inflating, we leave it as is and potentially will still
|
||||
// hoist it out if we later found a group of ops that are worth hoisting
|
||||
// as a whole.
|
||||
if (output_size > instr_info.transitive_input_size) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!worth_hoisting_individually_(*instruction)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Need to wait until we inspected the users of some operands until we can
|
||||
// finally decide whether to hoist this instruction.
|
||||
if (num_blocking_operands > 0) {
|
||||
to_hoist_when_ready.emplace(instruction, num_blocking_operands);
|
||||
continue;
|
||||
}
|
||||
|
||||
hoist(instruction, instr_info);
|
||||
}
|
||||
|
||||
if (instructions_to_replace.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
WhileUtil::MakeInstructionsLiveInResult live_in_instructions_result,
|
||||
WhileUtil::MakeInstructionsLiveIn(while_instr, replacement_instructions));
|
||||
|
||||
HloComputation* new_while_body =
|
||||
live_in_instructions_result.new_while_instr->while_body();
|
||||
|
||||
for (int i = 0; i < instructions_to_replace.size(); i++) {
|
||||
HloInstruction* instruction_to_replace_in_new_while =
|
||||
FindOrDie(live_in_instructions_result.while_body_instruction_map,
|
||||
instructions_to_replace[i]);
|
||||
TF_RETURN_IF_ERROR(new_while_body->ReplaceInstruction(
|
||||
instruction_to_replace_in_new_while,
|
||||
live_in_instructions_result.while_body_live_in_values[i]));
|
||||
}
|
||||
|
||||
VLOG(1) << "Hoisted " << instructions_to_replace.size()
|
||||
<< " instructions from " << while_instr_name;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
StatusOr<bool> WhileLoopExpensiveInvariantCodeMotion::Run(HloModule* module) {
|
||||
VLOG(2) << "HLO module before WhileLoopExpensiveInvariantCodeMotion:";
|
||||
XLA_VLOG_LINES(2, module->ToString());
|
||||
|
||||
bool changed = false;
|
||||
std::vector<HloInstruction*> while_instrs;
|
||||
for (auto* comp : module->computations()) {
|
||||
absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs),
|
||||
[](const HloInstruction* instr) {
|
||||
return instr->opcode() == HloOpcode::kWhile;
|
||||
});
|
||||
}
|
||||
|
||||
for (HloInstruction* while_instr : while_instrs) {
|
||||
// Right now we only hoist computations from the while body, but
|
||||
// TryHoistingInvariantInstructionsFromWhileBody can be generalized to
|
||||
// optimize the condition computation too, if needed.
|
||||
//
|
||||
// The transform we do here is a pessimization for while loops that execute
|
||||
// zero times*, but at this time we expect those to be rare. If this
|
||||
// becomes a problem we can consider using the conditional HLO to avoid
|
||||
// doing extra work for while loops with zero trip count.
|
||||
//
|
||||
// * We delete while loops that have a zero trip count, so this would have
|
||||
// to be a while loop with a somewhat opaque condition expression.
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
bool result,
|
||||
TryHoistingInvariantInstructionsFromWhileBody(while_instr));
|
||||
changed |= result;
|
||||
}
|
||||
|
||||
if (changed) {
|
||||
VLOG(2) << "HLO module after WhileLoopExpensiveInvariantCodeMotion:";
|
||||
XLA_VLOG_LINES(2, module->ToString());
|
||||
} else {
|
||||
VLOG(2)
|
||||
<< "HLO module unchanged after WhileLoopExpensiveInvariantCodeMotion";
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
} // namespace xla
|
@ -0,0 +1,54 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_EXPENSIVE_INVARIANT_CODE_MOTION_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_EXPENSIVE_INVARIANT_CODE_MOTION_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// HLO pass that rewrites while loops to hoist expensive and non-size-inflating
|
||||
// groups of loop invariant instructions in the while body into the computation
|
||||
// that contains the while instruction.
|
||||
// Users can specify worth_hoisting_individually, and only the groups
|
||||
// instructions with a root that returns true with it will be hoisted out.
|
||||
class WhileLoopExpensiveInvariantCodeMotion : public HloModulePass {
|
||||
public:
|
||||
using ShapeSizeFunction = std::function<int64(const Shape&)>;
|
||||
explicit WhileLoopExpensiveInvariantCodeMotion(
|
||||
std::function<bool(const HloInstruction&)> worth_hoisting_individually,
|
||||
ShapeSizeFunction shape_size_function = ShapeUtil::ByteSizeOfElements)
|
||||
: shape_size_function_(std::move(shape_size_function)),
|
||||
worth_hoisting_individually_(std::move(worth_hoisting_individually)) {}
|
||||
~WhileLoopExpensiveInvariantCodeMotion() override = default;
|
||||
|
||||
absl::string_view name() const override {
|
||||
return "while-loop-expensive-invariant-code-motion";
|
||||
}
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
|
||||
private:
|
||||
StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody(
|
||||
HloInstruction* while_instr);
|
||||
|
||||
ShapeSizeFunction shape_size_function_;
|
||||
std::function<bool(const HloInstruction&)> worth_hoisting_individually_;
|
||||
};
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_EXPENSIVE_INVARIANT_CODE_MOTION_H_
|
@ -0,0 +1,240 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/while_loop_expensive_invariant_code_motion.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
using WhileLoopExpensiveInvariantCodeMotionTest = HloTestBase;
|
||||
namespace op = xla::testing::opcode_matchers;
|
||||
|
||||
constexpr char kModuleWithNonInflatingInvariantDot[] = R"(
|
||||
HloModule ModuleWithWhile
|
||||
|
||||
mul {
|
||||
lhs = f32[] parameter(0)
|
||||
rhs = f32[] parameter(1)
|
||||
ROOT mul = f32[] multiply(lhs, rhs)
|
||||
}
|
||||
|
||||
body {
|
||||
p_body = (f32[], f32[16, 8]) parameter(0)
|
||||
b = get-tuple-element(p_body), index=1
|
||||
const = f32[] constant(1.0)
|
||||
lhs = f32[8, 16] broadcast(const), dimensions={}
|
||||
dot = dot(lhs, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||
reduced = reduce(dot, const), dimensions={0, 1}, to_apply=mul
|
||||
a = get-tuple-element(p_body), index=0
|
||||
add = add(reduced, a)
|
||||
ROOT root = tuple(add, b)
|
||||
}
|
||||
|
||||
condition {
|
||||
p_cond = (f32[], f32[16, 8]) parameter(0)
|
||||
ROOT result = pred[] constant(true)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
param0 = f32[] parameter(0)
|
||||
param1 = f32[16, 8] parameter(1)
|
||||
while_init = tuple(param0, param1)
|
||||
ROOT while = while(while_init), condition=condition, body=body
|
||||
}
|
||||
)";
|
||||
|
||||
TEST_F(WhileLoopExpensiveInvariantCodeMotionTest,
|
||||
HoistsGroupOfAllowedNonInflating) {
|
||||
auto m = ParseAndReturnVerifiedModule(kModuleWithNonInflatingInvariantDot)
|
||||
.ValueOrDie();
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool simplified_loop,
|
||||
WhileLoopExpensiveInvariantCodeMotion(
|
||||
/*worth_hoisting_individually=*/[](const HloInstruction& instr) {
|
||||
return instr.opcode() == HloOpcode::kDot;
|
||||
})
|
||||
.Run(m.get()));
|
||||
EXPECT_TRUE(simplified_loop);
|
||||
|
||||
HloComputation* while_body = m->GetComputationWithName("wide.body");
|
||||
ASSERT_NE(while_body, nullptr);
|
||||
EXPECT_THAT(while_body->instructions(), Not(Contains(op::Dot())));
|
||||
// kReduce not in the allow list.
|
||||
EXPECT_THAT(while_body->instructions(), Contains(op::Reduce()));
|
||||
}
|
||||
|
||||
TEST_F(WhileLoopExpensiveInvariantCodeMotionTest,
|
||||
HoistsGroupOfAllNonInflating) {
|
||||
auto m = ParseAndReturnVerifiedModule(kModuleWithNonInflatingInvariantDot)
|
||||
.ValueOrDie();
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool simplified_loop,
|
||||
WhileLoopExpensiveInvariantCodeMotion(
|
||||
/*worth_hoisting_individually=*/[](const HloInstruction& instr) {
|
||||
return instr.opcode() == HloOpcode::kDot ||
|
||||
instr.opcode() == HloOpcode::kReduce;
|
||||
})
|
||||
.Run(m.get()));
|
||||
EXPECT_TRUE(simplified_loop);
|
||||
|
||||
HloComputation* while_body = m->GetComputationWithName("wide.body");
|
||||
ASSERT_NE(while_body, nullptr);
|
||||
EXPECT_THAT(while_body->instructions(), Not(Contains(op::Dot())));
|
||||
EXPECT_THAT(while_body->instructions(), Not(Contains(op::Reduce())));
|
||||
}
|
||||
|
||||
TEST_F(WhileLoopExpensiveInvariantCodeMotionTest,
|
||||
DoesNotHoistsUnallowedInstructions) {
|
||||
auto m = ParseAndReturnVerifiedModule(kModuleWithNonInflatingInvariantDot)
|
||||
.ValueOrDie();
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool simplified_loop,
|
||||
WhileLoopExpensiveInvariantCodeMotion(
|
||||
/*worth_hoisting_individually=*/[](const HloInstruction& instr) {
|
||||
return false;
|
||||
})
|
||||
.Run(m.get()));
|
||||
EXPECT_FALSE(simplified_loop);
|
||||
}
|
||||
|
||||
constexpr char kModuleWithInflatingInvariantDot[] = R"(
|
||||
HloModule ModuleWithWhile
|
||||
|
||||
mul {
|
||||
lhs = f32[] parameter(0)
|
||||
rhs = f32[] parameter(1)
|
||||
ROOT mul = f32[] multiply(lhs, rhs)
|
||||
}
|
||||
|
||||
body {
|
||||
p_body = (f32[], f32[16, 4]) parameter(0)
|
||||
b = get-tuple-element(p_body), index=1
|
||||
const = f32[] constant(1.0)
|
||||
lhs = f32[4, 16] broadcast(const), dimensions={}
|
||||
dot = dot(lhs, b), lhs_contracting_dims={0}, rhs_contracting_dims={1}
|
||||
reduced = reduce(dot, const), dimensions={0, 1}, to_apply=mul
|
||||
a = get-tuple-element(p_body), index=0
|
||||
add = add(reduced, a)
|
||||
ROOT root = tuple(add, b)
|
||||
}
|
||||
|
||||
condition {
|
||||
p_cond = (f32[], f32[16, 4]) parameter(0)
|
||||
ROOT result = pred[] constant(true)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
param0 = f32[] parameter(0)
|
||||
param1 = f32[16, 4] parameter(1)
|
||||
while_init = tuple(param0, param1)
|
||||
ROOT while = while(while_init), condition=condition, body=body
|
||||
}
|
||||
)";
|
||||
|
||||
TEST_F(WhileLoopExpensiveInvariantCodeMotionTest, DoesNotHoistsInflating) {
|
||||
auto m = ParseAndReturnVerifiedModule(kModuleWithInflatingInvariantDot)
|
||||
.ValueOrDie();
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool simplified_loop,
|
||||
WhileLoopExpensiveInvariantCodeMotion(
|
||||
/*worth_hoisting_individually=*/[](const HloInstruction& instr) {
|
||||
return instr.opcode() == HloOpcode::kDot;
|
||||
})
|
||||
.Run(m.get()));
|
||||
EXPECT_FALSE(simplified_loop);
|
||||
}
|
||||
|
||||
TEST_F(WhileLoopExpensiveInvariantCodeMotionTest,
|
||||
HoistsGroupOfNonInflatingWithInflatingIntermediate) {
|
||||
auto m = ParseAndReturnVerifiedModule(kModuleWithInflatingInvariantDot)
|
||||
.ValueOrDie();
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool simplified_loop,
|
||||
WhileLoopExpensiveInvariantCodeMotion(
|
||||
/*worth_hoisting_individually=*/[](const HloInstruction& instr) {
|
||||
return instr.opcode() == HloOpcode::kDot ||
|
||||
instr.opcode() == HloOpcode::kReduce;
|
||||
})
|
||||
.Run(m.get()));
|
||||
EXPECT_TRUE(simplified_loop);
|
||||
|
||||
HloComputation* while_body = m->GetComputationWithName("wide.body");
|
||||
ASSERT_NE(while_body, nullptr);
|
||||
EXPECT_THAT(while_body->instructions(), Not(Contains(op::Dot())));
|
||||
EXPECT_THAT(while_body->instructions(), Not(Contains(op::Reduce())));
|
||||
}
|
||||
|
||||
TEST_F(WhileLoopExpensiveInvariantCodeMotionTest,
|
||||
HoistsOpWithDuplicateOperands) {
|
||||
constexpr char kModuleWithDuplicateOperands[] = R"(
|
||||
HloModule ModuleWithWhile
|
||||
|
||||
mul {
|
||||
lhs = f32[] parameter(0)
|
||||
rhs = f32[] parameter(1)
|
||||
ROOT mul = f32[] multiply(lhs, rhs)
|
||||
}
|
||||
|
||||
body {
|
||||
p_body = (f32[4, 4], f32[4, 4]) parameter(0)
|
||||
a = get-tuple-element(p_body), index=0
|
||||
dot = dot(a, a), lhs_contracting_dims={0}, rhs_contracting_dims={1}
|
||||
b = get-tuple-element(p_body), index=1
|
||||
add = add(b, dot)
|
||||
ROOT root = tuple(a, add)
|
||||
}
|
||||
|
||||
condition {
|
||||
p_cond = (f32[4, 4], f32[4, 4]) parameter(0)
|
||||
ROOT result = pred[] constant(true)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
param0 = f32[4, 4] parameter(0)
|
||||
param1 = f32[4, 4] parameter(1)
|
||||
while_init = tuple(param0, param1)
|
||||
ROOT while = while(while_init), condition=condition, body=body
|
||||
}
|
||||
)";
|
||||
auto m =
|
||||
ParseAndReturnVerifiedModule(kModuleWithDuplicateOperands).ValueOrDie();
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool simplified_loop,
|
||||
WhileLoopExpensiveInvariantCodeMotion(
|
||||
/*worth_hoisting_individually=*/[](const HloInstruction& instr) {
|
||||
return instr.opcode() == HloOpcode::kDot;
|
||||
})
|
||||
.Run(m.get()));
|
||||
EXPECT_TRUE(simplified_loop);
|
||||
|
||||
HloComputation* while_body = m->GetComputationWithName("wide.body");
|
||||
ASSERT_NE(while_body, nullptr);
|
||||
EXPECT_THAT(while_body->instructions(), Not(Contains(op::Dot())));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
Loading…
Reference in New Issue
Block a user