[XLA] Fix cost model of conditional code motion to disallow AllReduce moving from outside conditional branches to inside. Also add debugging support for isolating the impact of different transformations by the optimization.
PiperOrigin-RevId: 334044092 Change-Id: I93cd566a3c26630f18ae572110d303aae8a91f0d
This commit is contained in:
parent
517e851ee1
commit
46c9cb1ecb
@ -2374,6 +2374,7 @@ cc_library(
|
|||||||
":hlo_pass_pipeline",
|
":hlo_pass_pipeline",
|
||||||
":hlo_verifier",
|
":hlo_verifier",
|
||||||
":tuple_simplifier",
|
":tuple_simplifier",
|
||||||
|
"//tensorflow/compiler/xla:debug_options_flags",
|
||||||
"//tensorflow/compiler/xla:literal",
|
"//tensorflow/compiler/xla:literal",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:status_macros",
|
"//tensorflow/compiler/xla:status_macros",
|
||||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "absl/algorithm/container.h"
|
#include "absl/algorithm/container.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
||||||
#include "tensorflow/compiler/xla/literal.h"
|
#include "tensorflow/compiler/xla/literal.h"
|
||||||
#include "tensorflow/compiler/xla/map_util.h"
|
#include "tensorflow/compiler/xla/map_util.h"
|
||||||
#include "tensorflow/compiler/xla/service/call_graph.h"
|
#include "tensorflow/compiler/xla/service/call_graph.h"
|
||||||
@ -498,7 +499,7 @@ StatusOr<bool> ConditionalCodeMotion::MoveInstructionOut(
|
|||||||
<< conditional_parent->ToString(HloPrintOptions::Fingerprint())
|
<< conditional_parent->ToString(HloPrintOptions::Fingerprint())
|
||||||
<< "\n";
|
<< "\n";
|
||||||
int64 op_index = 0;
|
int64 op_index = 0;
|
||||||
for (Boundary b : new_boundaries) {
|
for (const Boundary& b : new_boundaries) {
|
||||||
HloInstruction* op = b.operands()[0];
|
HloInstruction* op = b.operands()[0];
|
||||||
CHECK(op != nullptr);
|
CHECK(op != nullptr);
|
||||||
VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n";
|
VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n";
|
||||||
@ -545,7 +546,7 @@ StatusOr<bool> ConditionalCodeMotion::MoveInstructionOut(
|
|||||||
for (int i = 0; i < branch_count; i++) {
|
for (int i = 0; i < branch_count; i++) {
|
||||||
auto computation = conditional->branch_computation(i);
|
auto computation = conditional->branch_computation(i);
|
||||||
std::vector<HloInstruction*> elements;
|
std::vector<HloInstruction*> elements;
|
||||||
for (auto b1 : new_boundaries) {
|
for (const auto& b1 : new_boundaries) {
|
||||||
HloInstruction* op = b1.operands()[i];
|
HloInstruction* op = b1.operands()[i];
|
||||||
CHECK(op != nullptr);
|
CHECK(op != nullptr);
|
||||||
VLOG(2) << "Adding to root " << i << " with " << op->ToString() << "\n";
|
VLOG(2) << "Adding to root " << i << " with " << op->ToString() << "\n";
|
||||||
@ -556,7 +557,7 @@ StatusOr<bool> ConditionalCodeMotion::MoveInstructionOut(
|
|||||||
computation->set_root_instruction(tuple, true);
|
computation->set_root_instruction(tuple, true);
|
||||||
VLOG(2) << "computation is :" << computation->ToString() << "\n";
|
VLOG(2) << "computation is :" << computation->ToString() << "\n";
|
||||||
// Remove hoisted instructions from the branches.
|
// Remove hoisted instructions from the branches.
|
||||||
for (auto b2 : to_move_out) {
|
for (const auto& b2 : to_move_out) {
|
||||||
auto instr_to_remove = b2.operands()[i];
|
auto instr_to_remove = b2.operands()[i];
|
||||||
// Double check to make sure it is safe to delete the instruction.
|
// Double check to make sure it is safe to delete the instruction.
|
||||||
// Complications may arise due to some operations in the alternative
|
// Complications may arise due to some operations in the alternative
|
||||||
@ -781,7 +782,7 @@ class GroupConnectedBoundaries {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Returns true if `instruction` is worth hoisting.
|
// Returns true if `instruction` is worth hoisting.
|
||||||
bool WorthHoisting(HloInstruction* instruction) {
|
bool WorthHoisting(HloInstruction* instruction, bool is_inside_branch) {
|
||||||
// This is needed for the "moving-in" transformation, to prevent the root
|
// This is needed for the "moving-in" transformation, to prevent the root
|
||||||
// of the parent computation (which contains the conditional) to be moved
|
// of the parent computation (which contains the conditional) to be moved
|
||||||
// inside the conditional.
|
// inside the conditional.
|
||||||
@ -789,6 +790,8 @@ class GroupConnectedBoundaries {
|
|||||||
instruction == conditional_parent_->root_instruction()) {
|
instruction == conditional_parent_->root_instruction()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
// TODO(b/169182921): The following cost model is rather incomplete. Will
|
||||||
|
// need to extend to cover most of element-wise ops.
|
||||||
switch (instruction->opcode()) {
|
switch (instruction->opcode()) {
|
||||||
case HloOpcode::kConvert:
|
case HloOpcode::kConvert:
|
||||||
// If Convert is after AllReduce, it is worth moving out AllReduce
|
// If Convert is after AllReduce, it is worth moving out AllReduce
|
||||||
@ -815,6 +818,11 @@ class GroupConnectedBoundaries {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
case HloOpcode::kAllReduce:
|
case HloOpcode::kAllReduce:
|
||||||
|
// It is not safe to move collective ops from outside to inside
|
||||||
|
// conditional branches, as it may cause synchronization problems,
|
||||||
|
// when different layouts are assigned to different branches.
|
||||||
|
return is_inside_branch;
|
||||||
|
case HloOpcode::kAbs:
|
||||||
case HloOpcode::kReduce:
|
case HloOpcode::kReduce:
|
||||||
case HloOpcode::kAdd:
|
case HloOpcode::kAdd:
|
||||||
case HloOpcode::kPower:
|
case HloOpcode::kPower:
|
||||||
@ -999,7 +1007,8 @@ class GroupConnectedBoundaries {
|
|||||||
VLOG(2) << "visiting boundary " << b.ToString() << "\n";
|
VLOG(2) << "visiting boundary " << b.ToString() << "\n";
|
||||||
if ((b.IsOutsideBranch() || InstructionWithinBranchIdentical(
|
if ((b.IsOutsideBranch() || InstructionWithinBranchIdentical(
|
||||||
b.operands(), is_layout_sensitive_)) &&
|
b.operands(), is_layout_sensitive_)) &&
|
||||||
IsSafeToMoveBoundary(b) && WorthHoisting(b.operands()[0])) {
|
IsSafeToMoveBoundary(b) &&
|
||||||
|
WorthHoisting(b.operands()[0], b.IsInsideBranch())) {
|
||||||
connected_boundaries_.push_back(b);
|
connected_boundaries_.push_back(b);
|
||||||
VLOG(2) << "boundary can be moved\n";
|
VLOG(2) << "boundary can be moved\n";
|
||||||
int64 operand_count = (b.IsInsideBranch())
|
int64 operand_count = (b.IsInsideBranch())
|
||||||
@ -1087,6 +1096,13 @@ ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion(
|
|||||||
|
|
||||||
StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
|
StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
|
||||||
VLOG(2) << "Begin a new pass of conditional code motion optimization.\n";
|
VLOG(2) << "Begin a new pass of conditional code motion optimization.\n";
|
||||||
|
// Used to support debugging of the optimization, by disabling the opt after it has
|
||||||
|
// been applied a predetermined number of times (to isolate impact of transformations).
|
||||||
|
if (!ConsumeFuel("conditional_code_motion", [&] {
|
||||||
|
return "Skipping conditional opt after allowed limit reaching 0.\n";
|
||||||
|
})) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
bool cleanup_changed = false;
|
bool cleanup_changed = false;
|
||||||
{
|
{
|
||||||
@ -1177,7 +1193,8 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
|
|||||||
benefit_move_out += d.GetBenefit();
|
benefit_move_out += d.GetBenefit();
|
||||||
if (benefit_move_out >= benefit_move_in) {
|
if (benefit_move_out >= benefit_move_in) {
|
||||||
final_d = Decision::Direction::kMoveOutOfBranch;
|
final_d = Decision::Direction::kMoveOutOfBranch;
|
||||||
VLOG(2) << "Current Decision is move out of branch\n";
|
VLOG(2) << "Current Decision is move out of branch ("
|
||||||
|
<< to_move_out.size() << ")\n";
|
||||||
} else {
|
} else {
|
||||||
VLOG(2) << "Current Decision remains move into branch\n";
|
VLOG(2) << "Current Decision remains move into branch\n";
|
||||||
}
|
}
|
||||||
@ -1191,7 +1208,8 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
|
|||||||
VLOG(2) << "Current Decision remains move out of branch\n";
|
VLOG(2) << "Current Decision remains move out of branch\n";
|
||||||
} else {
|
} else {
|
||||||
final_d = Decision::Direction::kMoveIntoBranch;
|
final_d = Decision::Direction::kMoveIntoBranch;
|
||||||
VLOG(2) << "Current Decision is move into branch\n";
|
VLOG(2) << "Current Decision is move into branch ("
|
||||||
|
<< to_move_in.size() << ")\n";
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case Decision::Direction::kNoChange:
|
case Decision::Direction::kNoChange:
|
||||||
@ -1260,6 +1278,13 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
|
|||||||
new_boundaries_for_moveout[i]));
|
new_boundaries_for_moveout[i]));
|
||||||
changed |= result;
|
changed |= result;
|
||||||
}
|
}
|
||||||
|
VLOG(2) << "Done moving out of branches " << to_move_out.size()
|
||||||
|
<< " times. \n";
|
||||||
|
if (!ConsumeFuel("conditional_code_motion", [&] {
|
||||||
|
return "Skipping conditional opt after allowed limit reaching 0.\n";
|
||||||
|
})) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
} else if (final_d == Decision::Direction::kMoveIntoBranch) {
|
} else if (final_d == Decision::Direction::kMoveIntoBranch) {
|
||||||
CHECK(to_move_in.size() == new_boundaries_for_movein.size());
|
CHECK(to_move_in.size() == new_boundaries_for_movein.size());
|
||||||
for (int i = 0; i < to_move_in.size(); ++i) {
|
for (int i = 0; i < to_move_in.size(); ++i) {
|
||||||
@ -1268,6 +1293,13 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
|
|||||||
new_boundaries_for_movein[i]));
|
new_boundaries_for_movein[i]));
|
||||||
changed |= result;
|
changed |= result;
|
||||||
}
|
}
|
||||||
|
VLOG(2) << "Done moving into branches " << to_move_in.size()
|
||||||
|
<< " times. \n";
|
||||||
|
if (!ConsumeFuel("conditional_code_motion", [&] {
|
||||||
|
return "Skipping conditional opt after allowed limit reaching 0.\n";
|
||||||
|
})) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
} else if (pursue_full_conditional_code_motion_ && !conditional_is_shared) {
|
} else if (pursue_full_conditional_code_motion_ && !conditional_is_shared) {
|
||||||
// Invoke special handling for convert rematerialization/hoisting
|
// Invoke special handling for convert rematerialization/hoisting
|
||||||
// We need to make sure no sharing is present in the branches because no
|
// We need to make sure no sharing is present in the branches because no
|
||||||
@ -1276,8 +1308,16 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
|
|||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
bool convert_result,
|
bool convert_result,
|
||||||
ConvertSpecialMove(conditional, is_layout_sensitive_));
|
ConvertSpecialMove(conditional, is_layout_sensitive_));
|
||||||
changed |= convert_result;
|
if (convert_result) {
|
||||||
VLOG(2) << "Done special moving of convert\n";
|
VLOG(2) << "Done special moving of convert\n";
|
||||||
|
if (!ConsumeFuel("conditional_code_motion", [&] {
|
||||||
|
return "Skipping conditional opt after allowed limit reaching "
|
||||||
|
"0.\n";
|
||||||
|
})) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
changed |= convert_result;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (changed) {
|
if (changed) {
|
||||||
|
@ -576,6 +576,89 @@ ENTRY main {
|
|||||||
op::AllReduce(op::GetTupleElement(op::Conditional())))))));
|
op::AllReduce(op::GetTupleElement(op::Conditional())))))));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ConditionalCodeMotionTest, DoNotMoveAllReduceIn) {
|
||||||
|
absl::string_view hlo_string =
|
||||||
|
R"(
|
||||||
|
HloModule RemoveIdenticalInstruction
|
||||||
|
|
||||||
|
%add.64 (x.139: bf16[], y.139: bf16[]) -> bf16[] {
|
||||||
|
%x.139 = bf16[]{:T(512)} parameter(0)
|
||||||
|
%y.139 = bf16[]{:T(512)} parameter(1)
|
||||||
|
ROOT %add.44073 = bf16[]{:T(512)} add(bf16[]{:T(512)} %x.139, bf16[]{:T(512)} %y.139)
|
||||||
|
}
|
||||||
|
|
||||||
|
%add.181 (x.256: bf16[], y.256: bf16[]) -> bf16[] {
|
||||||
|
%x.256 = bf16[]{:T(512)} parameter(0)
|
||||||
|
%y.256 = bf16[]{:T(512)} parameter(1)
|
||||||
|
ROOT %add.44842 = bf16[]{:T(512)} add(bf16[]{:T(512)} %x.256, bf16[]{:T(512)} %y.256)
|
||||||
|
}
|
||||||
|
|
||||||
|
on_true {
|
||||||
|
arg_tuple.1 = (bf16[2,54,168,128], bf16[2,52,168,128]) parameter(0)
|
||||||
|
get-tuple-element.11 = bf16[2,54,168,128] get-tuple-element(arg_tuple.1), index=0
|
||||||
|
get-tuple-element.12 = bf16[2,52,168,128] get-tuple-element(arg_tuple.1), index=1
|
||||||
|
convolution.1 = bf16[3,3,128,128] convolution(bf16[2,54,168,128]
|
||||||
|
get-tuple-element.11, bf16[2,52,168,128]
|
||||||
|
get-tuple-element.12), window={size=52x168 pad=0_0x1_1},
|
||||||
|
dim_labels=f01b_i01o->01bf
|
||||||
|
add.1 = bf16[3,3,128,128] add(bf16[3,3,128,128] convolution.1, bf16[3,3,128,128] convolution.1)
|
||||||
|
ROOT tuple.1 = (bf16[3,3,128,128]) tuple(add.1)
|
||||||
|
}
|
||||||
|
|
||||||
|
on_false {
|
||||||
|
arg_tuple.2 = (bf16[2,86,104,128], bf16[2,84,104,128]) parameter(0)
|
||||||
|
get-tuple-element.21 = bf16[2,86,104,128]
|
||||||
|
get-tuple-element(arg_tuple.2), index=0
|
||||||
|
get-tuple-element.22 = bf16[2,84,104,128]
|
||||||
|
get-tuple-element(arg_tuple.2), index=1
|
||||||
|
convolution.2 = bf16[3,3,128,128]
|
||||||
|
convolution(bf16[2,86,104,128] get-tuple-element.21, bf16[2,84,104,128]
|
||||||
|
get-tuple-element.22), window={size=84x104 pad=0_0x1_1},
|
||||||
|
dim_labels=f01b_i01o->01bf
|
||||||
|
add.2 = bf16[3,3,128,128] add(bf16[3,3,128,128] convolution.2, bf16[3,3,128,128] convolution.2)
|
||||||
|
ROOT tuple.2 = (bf16[3,3,128,128]) tuple(add.2)
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY main {
|
||||||
|
pred.1 = pred[] parameter(0)
|
||||||
|
arg_tuple.3 = (bf16[2,54,168,128], bf16[2,52,168,128]) parameter(1)
|
||||||
|
arg_tuple.4 = (bf16[2,86,104,128], bf16[2,84,104,128]) parameter(2)
|
||||||
|
arg_tuple.5 = f32[3,3,128,128] parameter(3)
|
||||||
|
conditional = (bf16[3,3,128,128])
|
||||||
|
conditional(pred.1, arg_tuple.3, arg_tuple.4), true_computation=on_true,
|
||||||
|
false_computation=on_false
|
||||||
|
get-first-index = bf16[3,3,128,128] get-tuple-element(conditional), index=0
|
||||||
|
all-reduce.2 = bf16[3,3,128,128]
|
||||||
|
all-reduce(bf16[3,3,128,128] %get-first-index),
|
||||||
|
channel_id=485, replica_groups={{0,1}}, use_global_device_ids=true,
|
||||||
|
to_apply=%add.181, metadata={op_type="Conv2DBackpropFilter"
|
||||||
|
op_name="gradients/resnet50/conv2d_22/Conv2D_grad/Conv2DBackpropFilter"}
|
||||||
|
convert.2 = f32[3,3,128,128]
|
||||||
|
convert(bf16[3,3,128,128] %all-reduce.2),
|
||||||
|
metadata={op_type="Cast" op_name="Cast_15"}
|
||||||
|
ROOT result = (f32[3,3,128,128]) tuple(convert.2)
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
|
||||||
|
ConditionalCodeMotion pass(true, true);
|
||||||
|
ASSERT_FALSE(pass.Run(&*module).ValueOrDie());
|
||||||
|
const HloInstruction* conditional =
|
||||||
|
FindInstruction(module.get(), "conditional");
|
||||||
|
CHECK(conditional != nullptr);
|
||||||
|
const HloComputation* on_true = conditional->branch_computation(0);
|
||||||
|
ASSERT_EQ(on_true->instruction_count(), 6);
|
||||||
|
const HloComputation* on_false = conditional->branch_computation(1);
|
||||||
|
ASSERT_EQ(on_false->instruction_count(), 6);
|
||||||
|
|
||||||
|
// Checks that the conditional shape has not changed.
|
||||||
|
ASSERT_TRUE(ShapeUtil::Compatible(
|
||||||
|
conditional->shape(), ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(
|
||||||
|
BF16, {3, 3, 128, 128})})));
|
||||||
|
HloInstruction* root = module->entry_computation()->root_instruction();
|
||||||
|
EXPECT_THAT(root, AllOf(op::Tuple(op::Convert(op::AllReduce(
|
||||||
|
op::GetTupleElement(op::Conditional()))))));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(ConditionalCodeMotionTest, MovePowOpIn) {
|
TEST_F(ConditionalCodeMotionTest, MovePowOpIn) {
|
||||||
absl::string_view hlo_string =
|
absl::string_view hlo_string =
|
||||||
R"(
|
R"(
|
||||||
|
Loading…
Reference in New Issue
Block a user