[XLA] Fix the cost model of conditional code motion to disallow moving AllReduce from outside conditional branches to inside them. Also add debugging support for isolating the impact of the optimization's individual transformations.
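The debugging support uses XLA's compiler-fuel helper, ConsumeFuel (declared in debug_options_flags.h, hence the new BUILD dependency below): each applied transformation spends one unit of fuel, and the pass stops rewriting once the fuel runs out, so bisecting on the fuel budget isolates the transformation responsible for a behavior change. A minimal standalone sketch of the pattern follows; the fuel counter and ConsumeFuel here are simplified stand-ins for illustration, not XLA's implementation.

#include <functional>
#include <iostream>
#include <string>

// Simplified stand-in for XLA's fuel budget, e.g. as set by a flag such as
// --xla_fuel=conditional_code_motion=2 (assuming the standard fuel syntax).
static int fuel = 2;

// Burns one unit of fuel and returns true if any remains; otherwise logs the
// lazily built message and returns false so the caller stops transforming.
bool ConsumeFuel(const std::string& pass,
                 const std::function<std::string()>& ran_out_msg) {
  if (fuel <= 0) {
    std::cerr << pass << ": " << ran_out_msg();
    return false;
  }
  --fuel;
  return true;
}

int main() {
  bool changed = false;
  for (int i = 0; i < 5; ++i) {  // five candidate code-motion transformations
    if (!ConsumeFuel("conditional_code_motion", [] {
          return std::string(
              "Skipping conditional opt after allowed limit reached 0.\n");
        })) {
      break;  // fuel exhausted: leave the remaining candidates untouched
    }
    changed = true;  // one transformation would be applied here
  }
  std::cout << (changed ? "module changed\n" : "module unchanged\n");
}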

PiperOrigin-RevId: 334044092
Change-Id: I93cd566a3c26630f18ae572110d303aae8a91f0d
A. Unique TensorFlower 2020-09-27 14:23:26 -07:00 committed by TensorFlower Gardener
parent 517e851ee1
commit 46c9cb1ecb
3 changed files with 132 additions and 8 deletions

@@ -2374,6 +2374,7 @@ cc_library(
":hlo_pass_pipeline",
":hlo_verifier",
":tuple_simplifier",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",

@@ -23,6 +23,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
@@ -498,7 +499,7 @@ StatusOr<bool> ConditionalCodeMotion::MoveInstructionOut(
<< conditional_parent->ToString(HloPrintOptions::Fingerprint())
<< "\n";
int64 op_index = 0;
for (Boundary b : new_boundaries) {
for (const Boundary& b : new_boundaries) {
HloInstruction* op = b.operands()[0];
CHECK(op != nullptr);
VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n";
@@ -545,7 +546,7 @@ StatusOr<bool> ConditionalCodeMotion::MoveInstructionOut(
for (int i = 0; i < branch_count; i++) {
auto computation = conditional->branch_computation(i);
std::vector<HloInstruction*> elements;
for (auto b1 : new_boundaries) {
for (const auto& b1 : new_boundaries) {
HloInstruction* op = b1.operands()[i];
CHECK(op != nullptr);
VLOG(2) << "Adding to root " << i << " with " << op->ToString() << "\n";
@@ -556,7 +557,7 @@ StatusOr<bool> ConditionalCodeMotion::MoveInstructionOut(
computation->set_root_instruction(tuple, true);
VLOG(2) << "computation is :" << computation->ToString() << "\n";
// Remove hoisted instructions from the branches.
for (auto b2 : to_move_out) {
for (const auto& b2 : to_move_out) {
auto instr_to_remove = b2.operands()[i];
// Double check to make sure it is safe to delete the instruction.
// Complications may arise due to some operations in the alternative
@@ -781,7 +782,7 @@ class GroupConnectedBoundaries {
}
}
// Returns true if `instruction` is worth hoisting.
bool WorthHoisting(HloInstruction* instruction) {
bool WorthHoisting(HloInstruction* instruction, bool is_inside_branch) {
// This is needed for the "moving-in" transformation, to prevent the root
// of the parent computation (which contains the conditional) from being
// moved inside the conditional.
@@ -789,6 +790,8 @@ class GroupConnectedBoundaries {
instruction == conditional_parent_->root_instruction()) {
return false;
}
// TODO(b/169182921): The following cost model is rather incomplete. It will
// need to be extended to cover most element-wise ops.
switch (instruction->opcode()) {
case HloOpcode::kConvert:
// If Convert is after AllReduce, it is worth moving out AllReduce
@@ -815,6 +818,11 @@ class GroupConnectedBoundaries {
return true;
}
case HloOpcode::kAllReduce:
// It is not safe to move collective ops from outside conditional branches
// to inside, as that may cause synchronization problems when different
// layouts are assigned to different branches. Moving an all-reduce out of
// a branch is still allowed.
return is_inside_branch;
case HloOpcode::kAbs:
case HloOpcode::kReduce:
case HloOpcode::kAdd:
case HloOpcode::kPower:
@@ -999,7 +1007,8 @@ class GroupConnectedBoundaries {
VLOG(2) << "visiting boundary " << b.ToString() << "\n";
if ((b.IsOutsideBranch() || InstructionWithinBranchIdentical(
b.operands(), is_layout_sensitive_)) &&
IsSafeToMoveBoundary(b) && WorthHoisting(b.operands()[0])) {
IsSafeToMoveBoundary(b) &&
WorthHoisting(b.operands()[0], b.IsInsideBranch())) {
connected_boundaries_.push_back(b);
VLOG(2) << "boundary can be moved\n";
int64 operand_count = (b.IsInsideBranch())
@@ -1087,6 +1096,13 @@ ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion(
StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
VLOG(2) << "Begin a new pass of conditional code motion optimization.\n";
// Used to support debugging of the optimization: disable the pass after it
// has been applied a predetermined number of times, to isolate the impact
// of individual transformations.
if (!ConsumeFuel("conditional_code_motion", [&] {
return "Skipping conditional opt after allowed limit reaching 0.\n";
})) {
return false;
}
bool changed = false;
bool cleanup_changed = false;
{
@@ -1177,7 +1193,8 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
benefit_move_out += d.GetBenefit();
if (benefit_move_out >= benefit_move_in) {
final_d = Decision::Direction::kMoveOutOfBranch;
VLOG(2) << "Current Decision is move out of branch\n";
VLOG(2) << "Current Decision is move out of branch ("
<< to_move_out.size() << ")\n";
} else {
VLOG(2) << "Current Decision remains move into branch\n";
}
@@ -1191,7 +1208,8 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
VLOG(2) << "Current Decision remains move out of branch\n";
} else {
final_d = Decision::Direction::kMoveIntoBranch;
VLOG(2) << "Current Decision is move into branch\n";
VLOG(2) << "Current Decision is move into branch ("
<< to_move_in.size() << ")\n";
}
break;
case Decision::Direction::kNoChange:
@@ -1260,6 +1278,13 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
new_boundaries_for_moveout[i]));
changed |= result;
}
VLOG(2) << "Done moving out of branches " << to_move_out.size()
<< " times. \n";
if (!ConsumeFuel("conditional_code_motion", [&] {
return "Skipping conditional opt after allowed limit reaching 0.\n";
})) {
break;
}
} else if (final_d == Decision::Direction::kMoveIntoBranch) {
CHECK(to_move_in.size() == new_boundaries_for_movein.size());
for (int i = 0; i < to_move_in.size(); ++i) {
@@ -1268,6 +1293,13 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
new_boundaries_for_movein[i]));
changed |= result;
}
VLOG(2) << "Done moving into branches " << to_move_in.size()
<< " times. \n";
if (!ConsumeFuel("conditional_code_motion", [&] {
return "Skipping conditional opt after allowed limit reaching 0.\n";
})) {
break;
}
} else if (pursue_full_conditional_code_motion_ && !conditional_is_shared) {
// Invoke special handling for convert rematerialization/hoisting
// We need to make sure no sharing is present in the branches because no
@@ -1276,8 +1308,16 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
TF_ASSIGN_OR_RETURN(
bool convert_result,
ConvertSpecialMove(conditional, is_layout_sensitive_));
changed |= convert_result;
if (convert_result) {
VLOG(2) << "Done special moving of convert\n";
if (!ConsumeFuel("conditional_code_motion", [&] {
return "Skipping conditional opt after allowed limit reaching "
"0.\n";
})) {
break;
}
}
changed |= convert_result;
}
}
if (changed) {

@@ -576,6 +576,89 @@ ENTRY main {
op::AllReduce(op::GetTupleElement(op::Conditional())))))));
}
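// Regression test for the cost-model fix: the all-reduce that consumes the
// conditional's result must stay outside the branches, because moving a
// collective into a conditional is unsafe when the branches can be assigned
// different layouts.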
TEST_F(ConditionalCodeMotionTest, DoNotMoveAllReduceIn) {
absl::string_view hlo_string =
R"(
HloModule RemoveIdenticalInstruction
%add.64 (x.139: bf16[], y.139: bf16[]) -> bf16[] {
%x.139 = bf16[]{:T(512)} parameter(0)
%y.139 = bf16[]{:T(512)} parameter(1)
ROOT %add.44073 = bf16[]{:T(512)} add(bf16[]{:T(512)} %x.139, bf16[]{:T(512)} %y.139)
}
%add.181 (x.256: bf16[], y.256: bf16[]) -> bf16[] {
%x.256 = bf16[]{:T(512)} parameter(0)
%y.256 = bf16[]{:T(512)} parameter(1)
ROOT %add.44842 = bf16[]{:T(512)} add(bf16[]{:T(512)} %x.256, bf16[]{:T(512)} %y.256)
}
on_true {
arg_tuple.1 = (bf16[2,54,168,128], bf16[2,52,168,128]) parameter(0)
get-tuple-element.11 = bf16[2,54,168,128] get-tuple-element(arg_tuple.1), index=0
get-tuple-element.12 = bf16[2,52,168,128] get-tuple-element(arg_tuple.1), index=1
convolution.1 = bf16[3,3,128,128] convolution(bf16[2,54,168,128]
get-tuple-element.11, bf16[2,52,168,128]
get-tuple-element.12), window={size=52x168 pad=0_0x1_1},
dim_labels=f01b_i01o->01bf
add.1 = bf16[3,3,128,128] add(bf16[3,3,128,128] convolution.1, bf16[3,3,128,128] convolution.1)
ROOT tuple.1 = (bf16[3,3,128,128]) tuple(add.1)
}
on_false {
arg_tuple.2 = (bf16[2,86,104,128], bf16[2,84,104,128]) parameter(0)
get-tuple-element.21 = bf16[2,86,104,128]
get-tuple-element(arg_tuple.2), index=0
get-tuple-element.22 = bf16[2,84,104,128]
get-tuple-element(arg_tuple.2), index=1
convolution.2 = bf16[3,3,128,128]
convolution(bf16[2,86,104,128] get-tuple-element.21, bf16[2,84,104,128]
get-tuple-element.22), window={size=84x104 pad=0_0x1_1},
dim_labels=f01b_i01o->01bf
add.2 = bf16[3,3,128,128] add(bf16[3,3,128,128] convolution.2, bf16[3,3,128,128] convolution.2)
ROOT tuple.2 = (bf16[3,3,128,128]) tuple(add.2)
}
ENTRY main {
pred.1 = pred[] parameter(0)
arg_tuple.3 = (bf16[2,54,168,128], bf16[2,52,168,128]) parameter(1)
arg_tuple.4 = (bf16[2,86,104,128], bf16[2,84,104,128]) parameter(2)
arg_tuple.5 = f32[3,3,128,128] parameter(3)
conditional = (bf16[3,3,128,128])
conditional(pred.1, arg_tuple.3, arg_tuple.4), true_computation=on_true,
false_computation=on_false
get-first-index = bf16[3,3,128,128] get-tuple-element(conditional), index=0
all-reduce.2 = bf16[3,3,128,128]
all-reduce(bf16[3,3,128,128] %get-first-index),
channel_id=485, replica_groups={{0,1}}, use_global_device_ids=true,
to_apply=%add.181, metadata={op_type="Conv2DBackpropFilter"
op_name="gradients/resnet50/conv2d_22/Conv2D_grad/Conv2DBackpropFilter"}
convert.2 = f32[3,3,128,128]
convert(bf16[3,3,128,128] %all-reduce.2),
metadata={op_type="Cast" op_name="Cast_15"}
ROOT result = (f32[3,3,128,128]) tuple(convert.2)
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass(true, true);
ASSERT_FALSE(pass.Run(&*module).ValueOrDie());
const HloInstruction* conditional =
FindInstruction(module.get(), "conditional");
CHECK(conditional != nullptr);
const HloComputation* on_true = conditional->branch_computation(0);
ASSERT_EQ(on_true->instruction_count(), 6);
const HloComputation* on_false = conditional->branch_computation(1);
ASSERT_EQ(on_false->instruction_count(), 6);
// Check that the shape of the conditional has not changed.
ASSERT_TRUE(ShapeUtil::Compatible(
conditional->shape(), ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(
BF16, {3, 3, 128, 128})})));
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Tuple(op::Convert(op::AllReduce(
op::GetTupleElement(op::Conditional()))))));
}
TEST_F(ConditionalCodeMotionTest, MovePowOpIn) {
absl::string_view hlo_string =
R"(