Hoisting unconditional converts from conditional branch computations.
PiperOrigin-RevId: 317239618
Change-Id: If3b16ff4f2bbcf38ee1ca51f5e8b187c58ab8e91
parent 2a05589bd4 · commit 397494a231
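Illustrative summary (hypothetical names and shapes, not taken from the diff below): when every branch root of a conditional returns, at the same tuple index, a convert whose shape matches across the branches, the branches are rewritten to return the convert's operand instead, and the convert is re-created in the parent computation on a get-tuple-element of the conditional's result.

// Before (each branch root returns the converted value):
//   cvt = bf16[2] convert(f32[2] x)
//   ROOT t = (bf16[2]) tuple(cvt)
// After (branches return the convert operand; the convert is hoisted to the parent):
//   (branch)  ROOT t = (f32[2]) tuple(x)
//   (parent)  gte = f32[2] get-tuple-element(conditional), index=0
//             cvt = bf16[2] convert(gte)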
@@ -106,7 +106,6 @@ class BranchVisitor {
boundaries_.emplace_back(operand, i, inst);
continue;
}

worklist_.push_back(operand);
visited_.insert(operand);
}
@@ -197,6 +196,7 @@ bool WorthHoisting(HloInstruction* instruction) {
case HloOpcode::kMultiply:
case HloOpcode::kDivide:
case HloOpcode::kTuple:
case HloOpcode::kSqrt:
case HloOpcode::kGetTupleElement:
return true;
default:
@@ -206,10 +206,11 @@ bool WorthHoisting(HloInstruction* instruction) {

// Check whether the instructions to be visited in each branch are identical.
bool InstructionWithinBranchIdentical(
const std::vector<HloInstruction*>& instructions, bool is_layout_senstive) {
const std::vector<HloInstruction*>& instructions,
bool is_layout_sensitive) {
// Identical requires that the shapes of corresponding operands be equal.
auto eq_operand = [&](const HloInstruction* a, const HloInstruction* b) {
bool eq_operands = is_layout_senstive
bool eq_operands = is_layout_sensitive
? ShapeUtil::Equal(a->shape(), b->shape())
: ShapeUtil::Compatible(a->shape(), b->shape());
return eq_operands;
@@ -233,7 +234,7 @@ bool InstructionWithinBranchIdentical(
auto old_channel_id = instruction->channel_id();
instruction->set_channel_id(instructions[0]->channel_id());
bool eq_instructions = instructions[0]->Identical(
*instruction, eq_operand, eq_computations, is_layout_senstive);
*instruction, eq_operand, eq_computations, is_layout_sensitive);
instruction->set_channel_id(old_channel_id);
return eq_instructions;
});
@@ -243,7 +244,7 @@ bool InstructionWithinBranchIdentical(
[&](HloInstruction* instruction) {
return instructions[0]->Identical(
*instruction, eq_operand, eq_computations,
is_layout_senstive);
is_layout_sensitive);
});
}

@@ -354,12 +355,228 @@ Status RemoveInstructionFromComputation(
return Status::OK();
}

// Identify converts to be hoisted/rematerialized out of the branch
// computations.
absl::flat_hash_set<int64> FindSpecialConverts(HloInstruction* old_root,
int branch_count,
HloInstruction* conditional,
bool is_layout_sensitive) {
absl::flat_hash_set<int64> kspecial_convert;
for (int64 operand_num = 0; operand_num < old_root->operand_count();
++operand_num) {
if (old_root->operand(operand_num)->opcode() != HloOpcode::kConvert) {
continue;
}
bool replica = true;
HloInstruction* kspecial_convert_candidate =
old_root->mutable_operand(operand_num);
// Check whether an identical candidate appears in the other branches.
for (int others = 1; others < branch_count; ++others) {
HloInstruction* others_root =
conditional->branch_computation(others)->root_instruction();
bool eq_shape =
is_layout_sensitive
? ShapeUtil::Equal(others_root->operand(operand_num)->shape(),
kspecial_convert_candidate->shape())
: ShapeUtil::Compatible(
others_root->operand(operand_num)->shape(),
kspecial_convert_candidate->shape());
if ((others_root->operand(operand_num)->opcode() ==
HloOpcode::kConvert) &&
eq_shape) {
// Nothing to be done.
} else {
replica = false;
break;
}
}
if (replica) {
kspecial_convert.insert(operand_num);
}
}
return kspecial_convert;
}

// Restructure the conditional instruction as follows:
// i.e., %result = conditional() becomes
// x = conditional()
// y.{0..n} = gte(x, {0..n})
// z = tuple(y.0, y.1, ...y.n)
// Doing so ensures that we can accommodate the possible shape change of the
// conditional when the instructions are hoisted.
Status RestructureConditionalInstruction(HloComputation* computation,
HloInstruction* conditional) {
HloInstruction* old_root = computation->root_instruction();
std::vector<HloInstruction*> new_operands;
int cur_index = 0;
for (; cur_index < ShapeUtil::TupleElementCount(conditional->shape());
++cur_index) {
new_operands.push_back(
computation->AddInstruction(HloInstruction::CreateGetTupleElement(
ShapeUtil::GetTupleElementShape(conditional->shape(), cur_index),
conditional, cur_index)));
}
HloInstruction* new_tuple =
computation->AddInstruction(HloInstruction::CreateTuple(new_operands));
if (old_root == conditional) {
computation->set_root_instruction(new_tuple);
} else {
std::vector<HloInstruction*> new_tuple_users;
for (auto conditional_user : conditional->users()) {
auto is_new_gte = absl::c_find_if(
new_operands,
[&](HloInstruction* instr) { return instr == conditional_user; });
if (is_new_gte == new_operands.end()) {
new_tuple_users.push_back(conditional_user);
}
}
for (auto new_tuple_user : new_tuple_users) {
TF_RETURN_IF_ERROR(
conditional->ReplaceUseWith(new_tuple_user, new_tuple));
}
}
VLOG(2) << "computation after root restructure:\n" << computation->ToString();
return Status::OK();
}

StatusOr<bool> ConvertSpecialMove(HloInstruction* conditional,
bool is_layout_sensitive) {
int branch_count = conditional->branch_count();
if (branch_count <= 0) {
return false;
}

HloInstruction* old_root =
conditional->branch_computation(0)->root_instruction();
if (old_root->opcode() != HloOpcode::kTuple) {
return false;
} else {
VLOG(2) << "BEFORE :" << conditional->parent()->parent()->ToString();
// Identify the gte using `index'.
auto find_gte = [](const HloInstruction* conditional_result,
int64 index) -> HloInstruction* {
for (HloInstruction* instr : conditional_result->users()) {
if (instr->opcode() != HloOpcode::kGetTupleElement) {
return nullptr;
}
if (instr->tuple_index() == index) {
return instr;
}
}
return nullptr;
};

// Captures tuple indices referring to converts to be rematerialized/hoisted.
absl::flat_hash_set<int64> kspecial_convert = FindSpecialConverts(
old_root, branch_count, conditional, is_layout_sensitive);

// Exit if we cannot find any converts to be hoisted.
if (kspecial_convert.empty()) {
return false;
}

TF_RETURN_IF_ERROR(
RestructureConditionalInstruction(conditional->parent(), conditional));

for (int branch = 0; branch < branch_count; branch++) {
old_root = conditional->branch_computation(branch)->root_instruction();
absl::flat_hash_map<HloInstruction*, int64> map_inst_to_tuple_index;
std::vector<HloInstruction*> new_operands(old_root->operand_count());
std::unordered_set<HloInstruction*> to_hoist_set;

for (int64 operand_num = 0; operand_num < old_root->operand_count();
++operand_num) {
map_inst_to_tuple_index[old_root->mutable_operand(operand_num)] =
operand_num;
}
for (int64 operand_num = 0; operand_num < old_root->operand_count();
++operand_num) {
HloInstruction* hoist = old_root->mutable_operand(operand_num);
if (!kspecial_convert.contains(operand_num)) {
new_operands[operand_num] = old_root->mutable_operand(operand_num);
continue;
}

to_hoist_set.insert(hoist);
int64 new_tuple_count = old_root->operand_count();

// Replace the hoisted instr in the tuple with its operand(s).
// We will replace at least one of the operands of the hoist at the
// tuple place; the rest will be added at the end.
bool inplace = true;
CHECK(!hoist->operands().empty());
for (HloInstruction* prod : hoist->operands()) {
if (inplace) {
map_inst_to_tuple_index[prod] = map_inst_to_tuple_index[hoist];
new_operands[map_inst_to_tuple_index[hoist]] = prod;
inplace = false;
} else {
map_inst_to_tuple_index[prod] = new_tuple_count++;
new_operands.push_back(prod);
}
}
}

// Create the new root instruction.
HloComputation* cur_branch = conditional->branch_computation(branch);
HloInstruction* new_branch_root =
cur_branch->AddInstruction(HloInstruction::CreateTuple(new_operands));
// The shape can vary since the operands to convert are now
// being returned through the branches' root.
cur_branch->set_root_instruction(new_branch_root, true /*new shape*/);
TF_CHECK_OK(cur_branch->RemoveInstruction(old_root));

// Only one of the branches needs to change the conditional->parent().
if (branch != 0) {
continue;
}
HloComputation* conditional_parent = conditional->parent();
HloInstruction* newconditional =
conditional_parent->AddInstruction(HloInstruction::CreateConditional(
cur_branch->root_instruction()->shape(),
conditional->mutable_operand(0),
absl::MakeSpan(conditional->branch_computations()),
absl::MakeSpan(conditional->operands()).subspan(1)));
// Ensure that all the users of conditional refer to the new one.
TF_RETURN_IF_ERROR(
conditional->ReplaceAllUsesWithDifferentShape(newconditional));
TF_CHECK_OK(conditional_parent->RemoveInstruction(conditional));
conditional = newconditional;
// Add the hoisted instructions in the parent.
for (HloInstruction* hoist : to_hoist_set) {
VLOG(2) << "Hoisting instruction:" << hoist->ToString();
int64 hoist_index = map_inst_to_tuple_index[hoist];
// Find the gte that captures the hoisted instr's result.
HloInstruction* gte_hoist = find_gte(conditional, hoist_index);
CHECK(gte_hoist != nullptr);
std::vector<HloInstruction*> new_operands;
for (HloInstruction* op : hoist->operands()) {
HloInstruction* gte = conditional_parent->AddInstruction(
HloInstruction::CreateGetTupleElement(
op->shape(), conditional, map_inst_to_tuple_index[op]));
new_operands.push_back(gte);
}
HloInstruction* hoisted = conditional_parent->AddInstruction(
hoist->CloneWithNewOperands(hoist->shape(), new_operands));
VLOG(2) << "Hoisted instruction in parent:" << hoisted->ToString();
TF_RETURN_IF_ERROR(gte_hoist->ReplaceAllUsesWith(hoisted));
TF_CHECK_OK(conditional_parent->RemoveInstruction(gte_hoist));
}
// No need to explicitly delete a hoisted instruction: if it is dead,
// the subsequent DCE will remove it.
}
}
VLOG(2) << "AFTER :" << conditional->parent()->parent()->ToString();
return true;
}

// Hoist identical ops out of the conditional. "Identical" means that the
// operands' shapes and properties are identical. Starts from the root
// instruction of each branch and collects the identical ops to hoist.
StatusOr<bool> MergeIdenticalElements(HloInstruction* conditional,
bool is_layout_sensitive) {
VLOG(1) << " visiting conditional:" << conditional->ToString();
int branch_count = conditional->branch_count();
if (branch_count <= 0) {
return false;
@@ -399,7 +616,7 @@ StatusOr<bool> MergeIdenticalElements(HloInstruction* conditional,
}
}

if (visitors[0].HoistInstructionSize() <= 1) {
if (visitors[0].HoistInstructionSize() < 1) {
return false;
}

@@ -442,7 +659,6 @@ StatusOr<bool> MergeIdenticalElements(HloInstruction* conditional,
RemoveInstructionFromComputation(visitors[i].instructions_to_hoist(),
conditional->branch_computation(i)));
}

return true;
}

@@ -451,26 +667,55 @@ StatusOr<bool> MergeIdenticalElements(HloInstruction* conditional,
StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
bool changed = false;

// Gather all the conditional ops in our module. We do this ahead of time so
// we don't have to worry about mutating the lists of computations or
// instructions as we iterate.
std::vector<HloInstruction*> conditional_ops;
for (auto* comp : module->MakeComputationPostOrder()) {
for (auto* instr : comp->MakeInstructionPostOrder()) {
if (instr->opcode() == HloOpcode::kConditional) {
conditional_ops.push_back(instr);
if (pursue_full_conditional_code_motion_) {
std::vector<HloInstruction*> conditional_ops;
for (auto* comp : module->MakeComputationPostOrder()) {
for (auto* instr : comp->MakeInstructionPostOrder()) {
if (instr->opcode() == HloOpcode::kConditional) {
conditional_ops.push_back(instr);
}
}
}

for (HloInstruction* conditional_op : conditional_ops) {
TF_ASSIGN_OR_RETURN(
bool result,
MergeIdenticalElements(conditional_op, is_layout_sensitive_));
changed |= result;
}

if (changed) {
HloPassPipeline subpipeline("after_conditional_code_motion");
subpipeline.AddPass<HloDCE>();
subpipeline.AddPass<TupleSimplifier>();
subpipeline.AddPass<HloDCE>();
TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module));
changed |= cleanup_changed;
}
}

for (HloInstruction* conditional_op : conditional_ops) {
TF_ASSIGN_OR_RETURN(bool result, MergeIdenticalElements(
conditional_op, is_layout_sensitive_));
changed |= result;
// Handle convert rematerialization/hoisting.
{
std::vector<HloInstruction*> conditional_ops;
for (auto* comp : module->MakeComputationPostOrder()) {
for (auto* instr : comp->MakeInstructionPostOrder()) {
if (instr->opcode() == HloOpcode::kConditional) {
conditional_ops.push_back(instr);
}
}
}
for (HloInstruction* conditional_op : conditional_ops) {
TF_ASSIGN_OR_RETURN(
bool convert_result,
ConvertSpecialMove(conditional_op, is_layout_sensitive_));
changed |= convert_result;
}
}

if (changed) {
HloPassPipeline subpipeline("after_conditional_code_motion");
HloPassPipeline subpipeline(
"after_conditional_code_motion_after_convert_hoisting");
subpipeline.AddPass<HloDCE>();
subpipeline.AddPass<TupleSimplifier>();
subpipeline.AddPass<HloDCE>();
TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module));

@@ -23,7 +23,11 @@ limitations under the License.

namespace xla {

// HLO pass that moves identical ops out of conditionals.
// ConditionalCodeMotion specializes in hoisting/rematerializing
// unconditional converts in the default mode.
// When pursue_full_conditional_code_motion_ is set to true, the
// full HLO pass moves identical ops out of a conditional in addition to moving
// converts.
// - "Identical" means that the operands' shapes and properties are identical.
// - Currently, only some types of instructions are supported.
@@ -35,13 +39,18 @@ class ConditionalCodeMotion : public HloModulePass {
public:
// If is_layout_sensitive is true, then the hoist process preserves layout
// during identical comparison. Otherwise, layout is ignored.
explicit ConditionalCodeMotion(bool is_layout_sensitive = true)
: is_layout_sensitive_(is_layout_sensitive) {}
explicit ConditionalCodeMotion(
bool is_layout_sensitive = true,
bool pursue_full_conditional_code_motion = false)
: is_layout_sensitive_(is_layout_sensitive),
pursue_full_conditional_code_motion_(
pursue_full_conditional_code_motion) {}
absl::string_view name() const override { return "conditional-code-motion"; }
StatusOr<bool> Run(HloModule* module) override;

private:
const bool is_layout_sensitive_;
const bool pursue_full_conditional_code_motion_;
};

} // namespace xla

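For orientation, a minimal usage sketch of the extended constructor follows. It mirrors how the tests below drive the pass, but the wrapper function, pipeline name, and include paths are assumptions rather than part of this change.

// Sketch only: run convert hoisting plus full conditional code motion on a module.
#include "tensorflow/compiler/xla/service/conditional_code_motion.h"  // assumed path
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"        // assumed path

xla::StatusOr<bool> RunConditionalCodeMotion(xla::HloModule* module) {
  xla::HloPassPipeline pipeline("conditional-code-motion-sketch");  // hypothetical name
  pipeline.AddPass<xla::ConditionalCodeMotion>(
      /*is_layout_sensitive=*/true,
      /*pursue_full_conditional_code_motion=*/true);
  return pipeline.Run(module);
}
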
@@ -38,7 +38,86 @@ namespace {
using ConditionalCodeMotionTest = HloTestBase;
namespace op = xla::testing::opcode_matchers;

TEST_F(ConditionalCodeMotionTest, DoNotMoveConvertOut) {
TEST_F(ConditionalCodeMotionTest, MoveSubsetTupleOut) {
absl::string_view hlo_string =
R"(
HloModule RemoveDotOpOut

on_true {
%arg_tuple.1 = (f32[93184,4]{1,0}) parameter(0)
%get-tuple-element.1 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.1), index=0
%reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1)
%convert.2894 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.8493)
ROOT %tuple.1 = ( bf16[2,512,364]{2,1,0}, f32[2,512,364]{2,1,0}) tuple(%convert.2894, %reshape.8493)
}

on_false {
%arg_tuple.2 = (f32[93184,4]{1,0}) parameter(0)
%get-tuple-element.3 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.2), index=0
%reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.3)
%add = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.9717, f32[2,512,364]{2,1,0} %reshape.9717)
%convert.3604 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.9717), metadata={op_type="Cast" op_name="gradients/Cast_125_grad/Cast"}
ROOT %tuple.2 = (bf16[2,512,364]{2,1,0}, f32[2,512,364]{2,1,0}) tuple(%convert.3604, %add)
}

ENTRY main {
pred.1 = pred[] parameter(0)
arg_tuple.11 = (f32[93184,4]{1,0}) parameter(1)
arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2)
conditional = (bf16[2,512,364]{2,1,0}, f32[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false
get-first-index = bf16[2,512,364]{2,1,0} get-tuple-element(conditional), index=0
get-first-index.2 = f32[2,512,364]{2,1,0} get-tuple-element(conditional), index=1
ROOT result = (bf16[2,512,364]{2,1,0}, f32[2,512,364]{2,1,0}) tuple(get-first-index, get-first-index.2)
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());

HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Tuple(op::Convert(), op::GetTupleElement())));
}

TEST_F(ConditionalCodeMotionTest, MoveConvertOutConditionalRoot) {
absl::string_view hlo_string =
R"(
HloModule RemoveDotOpOut

on_true {
%arg_tuple.1 = (f32[93184,4]{1,0}) parameter(0)
%get-tuple-element.1 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.1), index=0
%reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1)
%add.8493 = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.8493, f32[2,512,364]{2,1,0} %reshape.8493)
%convert.2894 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %add.8493)
ROOT %tuple.1 = ( bf16[2,512,364]{2,1,0}) tuple(%convert.2894)
}

on_false {
%arg_tuple.2 = (f32[93184,4]{1,0}) parameter(0)
%get-tuple-element.3 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.2), index=0
%reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.3)
%add.8493 = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.9717, f32[2,512,364]{2,1,0} %reshape.9717)
%sub.8493 = f32[2,512,364]{2,1,0} subtract(f32[2,512,364]{2,1,0} %add.8493, f32[2,512,364]{2,1,0} %reshape.9717)
%convert.3604 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.9717), metadata={op_type="Cast" op_name="gradients/Cast_125_grad/Cast"}
ROOT %tuple.2 = (bf16[2,512,364]{2,1,0}) tuple(%convert.3604)
}

ENTRY main {
pred.1 = pred[] parameter(0)
arg_tuple.11 = (f32[93184,4]{1,0}) parameter(1)
arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2)
ROOT conditional = (bf16[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());

HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Tuple(op::Convert())));
}

TEST_F(ConditionalCodeMotionTest, MoveConvertOut) {
absl::string_view hlo_string =
R"(
HloModule RemoveDotOpOut
@@ -65,12 +144,16 @@ ENTRY main {
arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2)
conditional = (bf16[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false
get-first-index = bf16[2,512,364]{2,1,0} get-tuple-element(conditional), index=0
ROOT result = (bf16[2,512,364]{2,1,0}) tuple(get-first-index)
add.1 = bf16[2,512,364]{2,1,0} add(bf16[2,512,364]{2,1,0} get-first-index, bf16[2,512,364]{2,1,0} get-first-index)
ROOT result = (bf16[2,512,364]{2,1,0}) tuple(add.1)
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass;
ASSERT_FALSE(pass.Run(&*module).ValueOrDie());
ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());

HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Tuple(op::Add(op::Convert(), op::Convert()))));
}

TEST_F(ConditionalCodeMotionTest, UserShareOperandCannotBeMoved) {
@@ -123,7 +206,7 @@ ENTRY main {
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass;
ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());

const HloInstruction* conditional =
@@ -181,7 +264,7 @@ ENTRY main {
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass;
ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
const HloInstruction* conditional =
FindInstruction(module.get(), "conditional");
@@ -245,7 +328,7 @@ ENTRY main {
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass;
ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());

const HloInstruction* conditional =
@@ -317,7 +400,7 @@ ENTRY main {
)";

auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass;
ConditionalCodeMotion pass(true, true);
ASSERT_FALSE(pass.Run(&*module).ValueOrDie());
}

@@ -390,7 +473,7 @@ ENTRY main {
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass;
ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
const HloInstruction* conditional =
FindInstruction(module.get(), "conditional");