Hoisting unconditional converts from conditional branch computations.
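
In the default mode, the pass looks for converts that every branch computes
unconditionally and feeds into its root tuple, and rematerializes such a
convert once in the parent computation. An illustrative sketch (hypothetical
HLO, for exposition only):

before:  on_true:  ROOT tuple(convert(x))   on_false: ROOT tuple(convert(y))
         r = get-tuple-element(conditional), index=0
after:   on_true:  ROOT tuple(x)            on_false: ROOT tuple(y)
         r = convert(get-tuple-element(conditional), index=0)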

PiperOrigin-RevId: 317239618
Change-Id: If3b16ff4f2bbcf38ee1ca51f5e8b187c58ab8e91
Author: A. Unique TensorFlower
Date: 2020-06-18 20:51:42 -07:00
Committed by: TensorFlower Gardener
Parent: 2a05589bd4
Commit: 397494a231
3 changed files with 369 additions and 32 deletions

tensorflow/compiler/xla/service/conditional_code_motion.cc

@@ -106,7 +106,6 @@ class BranchVisitor {
boundaries_.emplace_back(operand, i, inst);
continue;
}
worklist_.push_back(operand);
visited_.insert(operand);
}
@@ -197,6 +196,7 @@ bool WorthHoisting(HloInstruction* instruction) {
case HloOpcode::kMultiply:
case HloOpcode::kDivide:
case HloOpcode::kTuple:
+ case HloOpcode::kSqrt:
case HloOpcode::kGetTupleElement:
return true;
default:
@@ -206,10 +206,11 @@ bool WorthHoisting(HloInstruction* instruction) {
// Compare whether the instructions to be visited in each branch are identical.
bool InstructionWithinBranchIdentical(
- const std::vector<HloInstruction*>& instructions, bool is_layout_senstive) {
+ const std::vector<HloInstruction*>& instructions,
+ bool is_layout_sensitive) {
// Identical includes the shape of each operand being equal.
auto eq_operand = [&](const HloInstruction* a, const HloInstruction* b) {
- bool eq_operands = is_layout_senstive
+ bool eq_operands = is_layout_sensitive
? ShapeUtil::Equal(a->shape(), b->shape())
: ShapeUtil::Compatible(a->shape(), b->shape());
return eq_operands;
@@ -233,7 +234,7 @@ bool InstructionWithinBranchIdentical(
auto old_channel_id = instruction->channel_id();
instruction->set_channel_id(instructions[0]->channel_id());
bool eq_instructions = instructions[0]->Identical(
- *instruction, eq_operand, eq_computations, is_layout_senstive);
+ *instruction, eq_operand, eq_computations, is_layout_sensitive);
instruction->set_channel_id(old_channel_id);
return eq_instructions;
});
@@ -243,7 +244,7 @@ bool InstructionWithinBranchIdentical(
[&](HloInstruction* instruction) {
return instructions[0]->Identical(
*instruction, eq_operand, eq_computations,
- is_layout_senstive);
+ is_layout_sensitive);
});
}
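// For intuition on the layout-sensitivity flag used above (an illustrative
// note, not part of this change): given shapes f32[2,3]{1,0} and
// f32[2,3]{0,1}, ShapeUtil::Equal returns false because the layouts differ,
// while ShapeUtil::Compatible returns true because the element types and
// dimensions match.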
@@ -354,12 +355,228 @@ Status RemoveInstructionFromComputation(
return Status::OK();
}
// Identify converts to be hoisted/rematerialized out of the branch
// computations.
absl::flat_hash_set<int64> FindSpecialConverts(HloInstruction* old_root,
int branch_count,
HloInstruction* conditional,
bool is_layout_sensitive) {
absl::flat_hash_set<int64> kspecial_convert;
for (int64 operand_num = 0; operand_num < old_root->operand_count();
++operand_num) {
if (old_root->operand(operand_num)->opcode() != HloOpcode::kConvert) {
continue;
}
bool replica = true;
HloInstruction* kspecial_convert_candidate =
old_root->mutable_operand(operand_num);
// Check whether an identical candidate appears in other branches
for (int others = 1; others < branch_count; ++others) {
HloInstruction* others_root =
conditional->branch_computation(others)->root_instruction();
bool eq_shape =
is_layout_sensitive
? ShapeUtil::Equal(others_root->operand(operand_num)->shape(),
kspecial_convert_candidate->shape())
: ShapeUtil::Compatible(
others_root->operand(operand_num)->shape(),
kspecial_convert_candidate->shape());
if ((others_root->operand(operand_num)->opcode() ==
HloOpcode::kConvert) &&
eq_shape) {
// Nothing to be done.
} else {
replica = false;
break;
}
}
if (replica) {
kspecial_convert.insert(operand_num);
}
}
return kspecial_convert;
}
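// An illustrative example for FindSpecialConverts (hypothetical HLO, for
// exposition only): if branch 0's root is tuple(convert(a), b) and branch
// 1's root is tuple(convert(c), d), and the converts' shapes are equal (or
// compatible, in layout-insensitive mode), the returned set is {0}. Operand
// 1 is not a convert in every branch, so it is not special.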
// Restructure the conditional instruction as follows:
// i.e., %result = conditional() becomes
// x = conditional()
// y.{0..n} = gte(x, {0..n})
// z = tuple(y.0, y.1, ...y.n)
// Doing so ensures that we can accommodate the possible shape-change of the
// conditional when the instructions are hoisted.
Status RestructureConditionalInstruction(HloComputation* computation,
HloInstruction* conditional) {
HloInstruction* old_root = computation->root_instruction();
std::vector<HloInstruction*> new_operands;
int cur_index = 0;
for (; cur_index < ShapeUtil::TupleElementCount(conditional->shape());
++cur_index) {
new_operands.push_back(
computation->AddInstruction(HloInstruction::CreateGetTupleElement(
ShapeUtil::GetTupleElementShape(conditional->shape(), cur_index),
conditional, cur_index)));
}
HloInstruction* new_tuple =
computation->AddInstruction(HloInstruction::CreateTuple(new_operands));
if (old_root == conditional) {
computation->set_root_instruction(new_tuple);
} else {
std::vector<HloInstruction*> new_tuple_users;
for (auto conditional_user : conditional->users()) {
auto is_new_gte = absl::c_find_if(
new_operands,
[&](HloInstruction* instr) { return instr == conditional_user; });
if (is_new_gte == new_operands.end()) {
new_tuple_users.push_back(conditional_user);
}
}
for (auto new_tuple_user : new_tuple_users) {
TF_RETURN_IF_ERROR(
conditional->ReplaceUseWith(new_tuple_user, new_tuple));
}
}
VLOG(2) << "computation after root restructure:\n" << computation->ToString();
return Status::OK();
}
StatusOr<bool> ConvertSpecialMove(HloInstruction* conditional,
bool is_layout_sensitive) {
int branch_count = conditional->branch_count();
if (branch_count <= 0) {
return false;
}
HloInstruction* old_root =
conditional->branch_computation(0)->root_instruction();
if (old_root->opcode() != HloOpcode::kTuple) {
return false;
} else {
VLOG(2) << "BEFORE :" << conditional->parent()->parent()->ToString();
// Identify the gte using `index'.
auto find_gte = [](const HloInstruction* conditional_result,
int64 index) -> HloInstruction* {
for (HloInstruction* instr : conditional_result->users()) {
if (instr->opcode() != HloOpcode::kGetTupleElement) {
return nullptr;
}
if (instr->tuple_index() == index) {
return instr;
}
}
return nullptr;
};
// Captures tuple indices referring to converts to be rematerialized/hoisted.
absl::flat_hash_set<int64> kspecial_convert = FindSpecialConverts(
old_root, branch_count, conditional, is_layout_sensitive);
// Exit if we cannot find any converts to be hoisted.
if (kspecial_convert.empty()) {
return false;
}
TF_RETURN_IF_ERROR(
RestructureConditionalInstruction(conditional->parent(), conditional));
for (int branch = 0; branch < branch_count; branch++) {
old_root = conditional->branch_computation(branch)->root_instruction();
absl::flat_hash_map<HloInstruction*, int64> map_inst_to_tuple_index;
std::vector<HloInstruction*> new_operands(old_root->operand_count());
std::unordered_set<HloInstruction*> to_hoist_set;
for (int64 operand_num = 0; operand_num < old_root->operand_count();
++operand_num) {
map_inst_to_tuple_index[old_root->mutable_operand(operand_num)] =
operand_num;
}
for (int64 operand_num = 0; operand_num < old_root->operand_count();
++operand_num) {
HloInstruction* hoist = old_root->mutable_operand(operand_num);
if (!kspecial_convert.contains(operand_num)) {
new_operands[operand_num] = old_root->mutable_operand(operand_num);
continue;
}
to_hoist_set.insert(hoist);
int64 new_tuple_count = old_root->operand_count();
// Replace the hoisted instr in the tuple with its operand(s).
// At least one of the hoist's operands will take the hoisted instr's
// place in the tuple; the rest will be appended at the end.
bool inplace = true;
CHECK(!hoist->operands().empty());
for (HloInstruction* prod : hoist->operands()) {
if (inplace) {
map_inst_to_tuple_index[prod] = map_inst_to_tuple_index[hoist];
new_operands[map_inst_to_tuple_index[hoist]] = prod;
inplace = false;
} else {
map_inst_to_tuple_index[prod] = new_tuple_count++;
new_operands.push_back(prod);
}
}
}
// Create the new root instruction.
HloComputation* cur_branch = conditional->branch_computation(branch);
HloInstruction* new_branch_root =
cur_branch->AddInstruction(HloInstruction::CreateTuple(new_operands));
// The shape can change since the operands of the convert are now
// returned through the branches' root.
cur_branch->set_root_instruction(new_branch_root, true /*new shape*/);
TF_CHECK_OK(cur_branch->RemoveInstruction(old_root));
// Only one of the branches needs to change the conditional->parent().
if (branch != 0) {
continue;
}
HloComputation* conditional_parent = conditional->parent();
HloInstruction* newconditional =
conditional_parent->AddInstruction(HloInstruction::CreateConditional(
cur_branch->root_instruction()->shape(),
conditional->mutable_operand(0),
absl::MakeSpan(conditional->branch_computations()),
absl::MakeSpan(conditional->operands()).subspan(1)));
// Ensure that all the users of conditional refer to the new one.
TF_RETURN_IF_ERROR(
conditional->ReplaceAllUsesWithDifferentShape(newconditional));
TF_CHECK_OK(conditional_parent->RemoveInstruction(conditional));
conditional = newconditional;
// Add the hoisted instructions to the parent.
for (HloInstruction* hoist : to_hoist_set) {
VLOG(2) << "Hoisting instruction:" << hoist->ToString();
int64 hoist_index = map_inst_to_tuple_index[hoist];
// Find the gte that captures the hoisted instr's result.
HloInstruction* gte_hoist = find_gte(conditional, hoist_index);
CHECK(gte_hoist != nullptr);
std::vector<HloInstruction*> new_operands;
for (HloInstruction* op : hoist->operands()) {
HloInstruction* gte = conditional_parent->AddInstruction(
HloInstruction::CreateGetTupleElement(
op->shape(), conditional, map_inst_to_tuple_index[op]));
new_operands.push_back(gte);
}
HloInstruction* hoisted = conditional_parent->AddInstruction(
hoist->CloneWithNewOperands(hoist->shape(), new_operands));
VLOG(2) << "Hoisted instruction in parent:" << hoisted->ToString();
TF_RETURN_IF_ERROR(gte_hoist->ReplaceAllUsesWith(hoisted));
TF_CHECK_OK(conditional_parent->RemoveInstruction(gte_hoist));
}
// There is no need to explicitly delete a hoisted instruction: if it is
// dead, the subsequent DCE will remove it.
}
}
VLOG(2) << "AFTER :" << conditional->parent()->parent()->ToString();
return true;
}
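// Net effect of ConvertSpecialMove on a two-branch example (hypothetical
// HLO, for exposition only): each branch root tuple(convert(x_i)) becomes
// tuple(x_i), the conditional is rebuilt with the new result shape, and the
// parent computes convert(get-tuple-element(conditional, 0)) once; users of
// the old get-tuple-element are rewired to the hoisted convert.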
// Hoist identical ops out of the conditional. Ops are considered identical
// when the shapes of their operands are identical and their properties are
// identical. Starts from the root instruction of each branch and collects
// the identical ops to hoist.
StatusOr<bool> MergeIdenticalElements(HloInstruction* conditional,
bool is_layout_sensitive) {
VLOG(1) << " visiting conditional:" << conditional->ToString();
int branch_count = conditional->branch_count();
if (branch_count <= 0) {
return false;
@@ -399,7 +616,7 @@ StatusOr<bool> MergeIdenticalElements(HloInstruction* conditional,
}
}
- if (visitors[0].HoistInstructionSize() <= 1) {
+ if (visitors[0].HoistInstructionSize() < 1) {
return false;
}
@@ -442,7 +659,6 @@ StatusOr<bool> MergeIdenticalElements(HloInstruction* conditional,
RemoveInstructionFromComputation(visitors[i].instructions_to_hoist(),
conditional->branch_computation(i)));
}
return true;
}
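// A sketch of the merge on a two-branch example (hypothetical, for
// exposition only): if both branch roots compute an identical add(p, q),
// the add is hoisted into the parent once and fed by get-tuple-elements of
// the conditional, while the branches return the add's operands through
// their root tuples instead.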
@@ -451,26 +667,55 @@ StatusOr<bool> MergeIdenticalElements(HloInstruction* conditional,
StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
bool changed = false;
// Gather all the conditional ops in our module. We do this ahead of time so
// we don't have to worry about mutating the lists of computations or
// instructions as we iterate.
- std::vector<HloInstruction*> conditional_ops;
- for (auto* comp : module->MakeComputationPostOrder()) {
- for (auto* instr : comp->MakeInstructionPostOrder()) {
- if (instr->opcode() == HloOpcode::kConditional) {
- conditional_ops.push_back(instr);
+ if (pursue_full_conditional_code_motion_) {
+ std::vector<HloInstruction*> conditional_ops;
+ for (auto* comp : module->MakeComputationPostOrder()) {
+ for (auto* instr : comp->MakeInstructionPostOrder()) {
+ if (instr->opcode() == HloOpcode::kConditional) {
+ conditional_ops.push_back(instr);
+ }
+ }
+ }
+ for (HloInstruction* conditional_op : conditional_ops) {
+ TF_ASSIGN_OR_RETURN(
+ bool result,
+ MergeIdenticalElements(conditional_op, is_layout_sensitive_));
+ changed |= result;
+ }
+ if (changed) {
+ HloPassPipeline subpipeline("after_conditional_code_motion");
+ subpipeline.AddPass<HloDCE>();
+ subpipeline.AddPass<TupleSimplifier>();
+ subpipeline.AddPass<HloDCE>();
+ TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module));
+ changed |= cleanup_changed;
+ }
+ }
- for (HloInstruction* conditional_op : conditional_ops) {
- TF_ASSIGN_OR_RETURN(bool result, MergeIdenticalElements(
- conditional_op, is_layout_sensitive_));
- changed |= result;
+ // Handle convert rematerialization/hoisting.
+ {
+ std::vector<HloInstruction*> conditional_ops;
+ for (auto* comp : module->MakeComputationPostOrder()) {
+ for (auto* instr : comp->MakeInstructionPostOrder()) {
+ if (instr->opcode() == HloOpcode::kConditional) {
+ conditional_ops.push_back(instr);
+ }
+ }
+ }
+ for (HloInstruction* conditional_op : conditional_ops) {
+ TF_ASSIGN_OR_RETURN(
+ bool convert_result,
+ ConvertSpecialMove(conditional_op, is_layout_sensitive_));
+ changed |= convert_result;
+ }
+ }
if (changed) {
HloPassPipeline subpipeline("after_conditional_code_motion");
HloPassPipeline subpipeline(
"after_conditional_code_motion_after_convert_hoisting");
subpipeline.AddPass<HloDCE>();
subpipeline.AddPass<TupleSimplifier>();
subpipeline.AddPass<HloDCE>();
TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module));

tensorflow/compiler/xla/service/conditional_code_motion.h

@@ -23,7 +23,11 @@ limitations under the License.
namespace xla {
- // HLO pass that moves identical ops out of conditional.
+ // ConditionalCodeMotion specializes in hoisting/rematerializing
+ // unconditional converts in the default mode.
+ // When pursue_full_conditional_code_motion_ is set to true, the
+ // full HLO pass moves identical ops out of a conditional in addition to moving
+ // converts.
// - The definition of identical is that the shapes of the operands are
// identical and their properties are identical.
// - Currently, only some types of instructions are supported.
@@ -35,13 +39,18 @@ class ConditionalCodeMotion : public HloModulePass {
public:
// If is_layout_sensitive is true, then the hoist process preserves layout
// during identical comparison. Otherwise, layout is ignored.
- explicit ConditionalCodeMotion(bool is_layout_sensitive = true)
- : is_layout_sensitive_(is_layout_sensitive) {}
+ explicit ConditionalCodeMotion(
+ bool is_layout_sensitive = true,
+ bool pursue_full_conditional_code_motion = false)
+ : is_layout_sensitive_(is_layout_sensitive),
+ pursue_full_conditional_code_motion_(
+ pursue_full_conditional_code_motion) {}
absl::string_view name() const override { return "conditional-code-motion"; }
StatusOr<bool> Run(HloModule* module) override;
private:
const bool is_layout_sensitive_;
+ const bool pursue_full_conditional_code_motion_;
};
} // namespace xla
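
A minimal usage sketch of the two modes (illustrative only; assumes a caller
that already has an HloModule* module):

// Default mode: hoist/rematerialize unconditional converts only.
ConditionalCodeMotion convert_only(/*is_layout_sensitive=*/true);
// Full mode: additionally hoist identical ops out of conditionals.
ConditionalCodeMotion full(/*is_layout_sensitive=*/true,
                           /*pursue_full_conditional_code_motion=*/true);
TF_ASSIGN_OR_RETURN(bool changed, full.Run(module));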

tensorflow/compiler/xla/service/conditional_code_motion_test.cc

@@ -38,7 +38,86 @@ namespace {
using ConditionalCodeMotionTest = HloTestBase;
namespace op = xla::testing::opcode_matchers;
- TEST_F(ConditionalCodeMotionTest, DoNotMoveConvertOut) {
+ TEST_F(ConditionalCodeMotionTest, MoveSubsetTupleOut) {
absl::string_view hlo_string =
R"(
HloModule RemoveDotOpOut
on_true {
%arg_tuple.1 = (f32[93184,4]{1,0}) parameter(0)
%get-tuple-element.1 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.1), index=0
%reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1)
%convert.2894 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.8493)
ROOT %tuple.1 = ( bf16[2,512,364]{2,1,0}, f32[2,512,364]{2,1,0}) tuple(%convert.2894, %reshape.8493)
}
on_false {
%arg_tuple.2 = (f32[93184,4]{1,0}) parameter(0)
%get-tuple-element.3 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.2), index=0
%reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.3)
%add = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.9717, f32[2,512,364]{2,1,0} %reshape.9717)
%convert.3604 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.9717), metadata={op_type="Cast" op_name="gradients/Cast_125_grad/Cast"}
ROOT %tuple.2 = (bf16[2,512,364]{2,1,0}, f32[2,512,364]{2,1,0}) tuple(%convert.3604, %add)
}
ENTRY main {
pred.1 = pred[] parameter(0)
arg_tuple.11 = (f32[93184,4]{1,0}) parameter(1)
arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2)
conditional = (bf16[2,512,364]{2,1,0}, f32[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false
get-first-index = bf16[2,512,364]{2,1,0} get-tuple-element(conditional), index=0
get-first-index.2 = f32[2,512,364]{2,1,0} get-tuple-element(conditional), index=1
ROOT result = (bf16[2,512,364]{2,1,0}, f32[2,512,364]{2,1,0}) tuple(get-first-index, get-first-index.2)
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Tuple(op::Convert(), op::GetTupleElement())));
}
TEST_F(ConditionalCodeMotionTest, MoveConvertOutConditionalRoot) {
absl::string_view hlo_string =
R"(
HloModule RemoveDotOpOut
on_true {
%arg_tuple.1 = (f32[93184,4]{1,0}) parameter(0)
%get-tuple-element.1 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.1), index=0
%reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1)
%add.8493 = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.8493, f32[2,512,364]{2,1,0} %reshape.8493)
%convert.2894 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %add.8493)
ROOT %tuple.1 = ( bf16[2,512,364]{2,1,0}) tuple(%convert.2894)
}
on_false {
%arg_tuple.2 = (f32[93184,4]{1,0}) parameter(0)
%get-tuple-element.3 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.2), index=0
%reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.3)
%add.8493 = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.9717, f32[2,512,364]{2,1,0} %reshape.9717)
%sub.8493 = f32[2,512,364]{2,1,0} subtract(f32[2,512,364]{2,1,0} %add.8493, f32[2,512,364]{2,1,0} %reshape.9717)
%convert.3604 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.9717), metadata={op_type="Cast" op_name="gradients/Cast_125_grad/Cast"}
ROOT %tuple.2 = (bf16[2,512,364]{2,1,0}) tuple(%convert.3604)
}
ENTRY main {
pred.1 = pred[] parameter(0)
arg_tuple.11 = (f32[93184,4]{1,0}) parameter(1)
arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2)
ROOT conditional = (bf16[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Tuple(op::Convert())));
}
TEST_F(ConditionalCodeMotionTest, MoveConvertOut) {
absl::string_view hlo_string =
R"(
HloModule RemoveDotOpOut
@@ -65,12 +144,16 @@ ENTRY main {
arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2)
conditional = (bf16[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false
get-first-index = bf16[2,512,364]{2,1,0} get-tuple-element(conditional), index=0
- ROOT result = (bf16[2,512,364]{2,1,0}) tuple(get-first-index)
+ add.1 = bf16[2,512,364]{2,1,0} add(bf16[2,512,364]{2,1,0} get-first-index, bf16[2,512,364]{2,1,0} get-first-index)
+ ROOT result = (bf16[2,512,364]{2,1,0}) tuple(add.1)
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
- ConditionalCodeMotion pass;
- ASSERT_FALSE(pass.Run(&*module).ValueOrDie());
+ ConditionalCodeMotion pass(true, true);
+ ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, AllOf(op::Tuple(op::Add(op::Convert(), op::Convert()))));
}
TEST_F(ConditionalCodeMotionTest, UserShareOperandCannotBeMoved) {
@@ -123,7 +206,7 @@ ENTRY main {
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
- ConditionalCodeMotion pass;
+ ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
const HloInstruction* conditional =
@@ -181,7 +264,7 @@ ENTRY main {
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
- ConditionalCodeMotion pass;
+ ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
const HloInstruction* conditional =
FindInstruction(module.get(), "conditional");
@@ -245,7 +328,7 @@ ENTRY main {
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
- ConditionalCodeMotion pass;
+ ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
const HloInstruction* conditional =
@@ -317,7 +400,7 @@ ENTRY main {
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
- ConditionalCodeMotion pass;
+ ConditionalCodeMotion pass(true, true);
ASSERT_FALSE(pass.Run(&*module).ValueOrDie());
}
@@ -390,7 +473,7 @@ ENTRY main {
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
- ConditionalCodeMotion pass;
+ ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
const HloInstruction* conditional =
FindInstruction(module.get(), "conditional");