[XLA] Expose configuration space of conditional code motion to support auto-tuning of its cost model.
PiperOrigin-RevId: 351858804 Change-Id: Iaa4abf304af21cba88c09d1887ec93671a570b8a
This commit is contained in:
parent
5b5396b4b2
commit
c1336b952d
@ -2396,6 +2396,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
|
"@com_google_absl//absl/flags:flag",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -113,16 +113,14 @@ int64 CountNonLeafOps(const OpCollection& ops) {
|
|||||||
// instructions. Use different integers to classify different levels
|
// instructions. Use different integers to classify different levels
|
||||||
// of reuses. This is used as a placeholder only, assuming all
|
// of reuses. This is used as a placeholder only, assuming all
|
||||||
// instructions can be fused to enable data reuses
|
// instructions can be fused to enable data reuses
|
||||||
int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) {
|
int64 ReusesCarriedBy(HloOpcode op, HloOpcode user) {
|
||||||
// Reuses in some way work like forces that pull instructions
|
// Reuses in some way work like forces that pull instructions
|
||||||
// towards each other. We use a number 0-10 to classify how strong the force
|
// towards each other. We use a number 0-10 to classify how strong the force
|
||||||
// is between a pair of operations. Given a group of instructions that can be
|
// is between a pair of operations. Given a group of instructions that can be
|
||||||
// moved together, if the forces inside a conditional are stronger, the group
|
// moved together, if the forces inside a conditional are stronger, the group
|
||||||
// will be moved inside or remain inside the conditional; otherwise, it will
|
// will be moved inside or remain inside the conditional; otherwise, it will
|
||||||
// be moved outside to or remain outside of the conditional.
|
// be moved outside to or remain outside of the conditional.
|
||||||
VLOG(2) << "ConditionalCodeMotion: Add reuses carried by instr: "
|
switch (user) {
|
||||||
<< op->ToString() << "=>" << user->ToString() << "\n";
|
|
||||||
switch (user->opcode()) {
|
|
||||||
case HloOpcode::kGetTupleElement:
|
case HloOpcode::kGetTupleElement:
|
||||||
return 0;
|
return 0;
|
||||||
case HloOpcode::kConvert:
|
case HloOpcode::kConvert:
|
||||||
@ -130,7 +128,7 @@ int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) {
|
|||||||
// convolution, here if op is dot or convolution, they must be separated
|
// convolution, here if op is dot or convolution, they must be separated
|
||||||
// by a conditional boundary. Here we do not try to pull convert inside
|
// by a conditional boundary. Here we do not try to pull convert inside
|
||||||
// conditionals to be together with the dot or convolution.
|
// conditionals to be together with the dot or convolution.
|
||||||
switch (op->opcode()) {
|
switch (op) {
|
||||||
case HloOpcode::kConvolution:
|
case HloOpcode::kConvolution:
|
||||||
case HloOpcode::kDot:
|
case HloOpcode::kDot:
|
||||||
return 0;
|
return 0;
|
||||||
@ -141,7 +139,7 @@ int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) {
|
|||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
switch (op->opcode()) {
|
switch (op) {
|
||||||
// These instructions do not carry weight of reuse themselves.
|
// These instructions do not carry weight of reuse themselves.
|
||||||
case HloOpcode::kParameter:
|
case HloOpcode::kParameter:
|
||||||
case HloOpcode::kConstant:
|
case HloOpcode::kConstant:
|
||||||
@ -149,12 +147,57 @@ int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) {
|
|||||||
return 0;
|
return 0;
|
||||||
case HloOpcode::kConditional:
|
case HloOpcode::kConditional:
|
||||||
return 10;
|
return 10;
|
||||||
default: {
|
default:
|
||||||
// Assume the reuse decreases with increasing user count.
|
return -10;
|
||||||
int count1 = CountNonLeafOps(op->users());
|
|
||||||
int count2 = CountNonLeafOps(user->operands());
|
|
||||||
return 10 / count1 / count2;
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns true if `op` is worth hoisting.
|
||||||
|
bool WorthHoisting(HloOpcode op, HloOpcode child_op) {
|
||||||
|
// TODO[b/169182921] The following cost model is rather incomplete. Will
|
||||||
|
// need to extend to cover most of element-wise ops.
|
||||||
|
switch (op) {
|
||||||
|
case HloOpcode::kConvert:
|
||||||
|
// If Convert is after AllReduce, it is worth moving AllReduce
|
||||||
|
// out of conditional for AR/CRS combine. If Convert is after other
|
||||||
|
// ops such as Dot or Convolutional, it is better to keep convert
|
||||||
|
// within conditional so that convert can be fused with Dot or
|
||||||
|
// Convolutional.
|
||||||
|
switch (child_op) {
|
||||||
|
case HloOpcode::kAllReduce:
|
||||||
|
case HloOpcode::kReshape:
|
||||||
|
case HloOpcode::kGetTupleElement:
|
||||||
|
return true;
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
case HloOpcode::kGetTupleElement:
|
||||||
|
switch (child_op) {
|
||||||
|
// do not move GTE if its operand is a parameter
|
||||||
|
case HloOpcode::kParameter:
|
||||||
|
return false;
|
||||||
|
default:
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
case HloOpcode::kAllReduce:
|
||||||
|
case HloOpcode::kAbs:
|
||||||
|
case HloOpcode::kReduce:
|
||||||
|
case HloOpcode::kAdd:
|
||||||
|
case HloOpcode::kPower:
|
||||||
|
case HloOpcode::kCopy:
|
||||||
|
case HloOpcode::kConstant:
|
||||||
|
case HloOpcode::kSubtract:
|
||||||
|
case HloOpcode::kMultiply:
|
||||||
|
case HloOpcode::kDivide:
|
||||||
|
case HloOpcode::kTuple:
|
||||||
|
case HloOpcode::kSqrt:
|
||||||
|
case HloOpcode::kRsqrt:
|
||||||
|
case HloOpcode::kReshape:
|
||||||
|
case HloOpcode::kMinimum:
|
||||||
|
case HloOpcode::kMaximum:
|
||||||
|
return true;
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -579,8 +622,6 @@ StatusOr<bool> ConditionalCodeMotion::MoveInstructionOut(
|
|||||||
HloInstruction* new_root =
|
HloInstruction* new_root =
|
||||||
conditional->branch_computation(0)->root_instruction();
|
conditional->branch_computation(0)->root_instruction();
|
||||||
*conditional->mutable_shape() = new_root->shape();
|
*conditional->mutable_shape() = new_root->shape();
|
||||||
|
|
||||||
//
|
|
||||||
VLOG(1) << "done moving instructions out of branches\n"
|
VLOG(1) << "done moving instructions out of branches\n"
|
||||||
<< conditional_parent->ToString(HloPrintOptions::Fingerprint())
|
<< conditional_parent->ToString(HloPrintOptions::Fingerprint())
|
||||||
<< "\n";
|
<< "\n";
|
||||||
@ -772,15 +813,105 @@ class GroupConnectedBoundaries {
|
|||||||
bool is_layout_sensitive_;
|
bool is_layout_sensitive_;
|
||||||
// Instructions that have been visited but are not going to be moved.
|
// Instructions that have been visited but are not going to be moved.
|
||||||
absl::flat_hash_map<HloInstruction*, int>& visited_count_;
|
absl::flat_hash_map<HloInstruction*, int>& visited_count_;
|
||||||
|
// The following four lines are configurations of the cost model, which will
|
||||||
|
// be used to determine whether to move an instruction (move_config_) and how
|
||||||
|
// strongly preferred it is to keep a pair of ops together (reuse_config_).
|
||||||
|
// The search_config_ is used to control how to navigate the search space of
|
||||||
|
// the cost model in the context of auto/manual tuning. The flipped array is
|
||||||
|
// used to save which entries in the configuration have been changed in the
|
||||||
|
// search/tuning process.
|
||||||
|
std::vector<std::vector<int64>>& move_config_;
|
||||||
|
std::vector<std::vector<int64>>& reuse_config_;
|
||||||
|
int& search_config_;
|
||||||
|
absl::flat_hash_map<const int64*, int64> flipped_;
|
||||||
|
|
||||||
|
// The FlipMutation function serves to implement the search of alternative
|
||||||
|
// cost models by deciding whether to flip a given configuration, saved in
|
||||||
|
// the loc parameter. The non_zero parameter provides the new value to use
|
||||||
|
// to flip a zero. The msg parameter is only used for debugging purposes.
|
||||||
|
int64 FlipMutation(int64* loc, const int64 non_zero, const std::string& msg) {
|
||||||
|
if (search_config_ == 0 || ContainsKey(flipped_, loc)) {
|
||||||
|
VLOG(2) << "Configured not to search or loc is already flipped.";
|
||||||
|
return *loc;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bits 8-15 control the maximum number of times to flip a config.
|
||||||
|
int flip_count = (search_config_ >> 8) & 255;
|
||||||
|
if (flip_count == 0) {
|
||||||
|
VLOG(2) << "Maximum flip count has reached. ";
|
||||||
|
return *loc;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The lowest 8 bits control when to start the first flip.
|
||||||
|
int c = search_config_ & 255;
|
||||||
|
VLOG(2) << "flip start index = " << c << "\n";
|
||||||
|
// Only flip the decision if c reaches 0.
|
||||||
|
if (c > 0) {
|
||||||
|
search_config_--;
|
||||||
|
return *loc;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrement flip count so we can stop if it reaches 0.
|
||||||
|
search_config_ -= 256;
|
||||||
|
// Reload bits 16-23 of the configuration, which control how
|
||||||
|
// frequently a configuration should be flipped.
|
||||||
|
search_config_ += (search_config_ >> 16) & 255;
|
||||||
|
VLOG(2) << "Updating Flipping configuration = " << search_config_ << "\n";
|
||||||
|
|
||||||
|
flipped_[loc] = *loc;
|
||||||
|
// Copy the last 8 bits back to the first 8 bits of configuration.
|
||||||
|
switch (*loc) {
|
||||||
|
case 0:
|
||||||
|
*loc = non_zero;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
*loc = 0;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
VLOG(2) << "Flipping decision for: " << msg << ": from " << flipped_[loc]
|
||||||
|
<< " to " << *loc << "\n";
|
||||||
|
return *loc;
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit GroupConnectedBoundaries(
|
explicit GroupConnectedBoundaries(
|
||||||
HloInstruction* conditional, bool is_layout_sensitive,
|
HloInstruction* conditional, bool is_layout_sensitive,
|
||||||
absl::flat_hash_map<HloInstruction*, int>& visited_count)
|
absl::flat_hash_map<HloInstruction*, int>& visited_count,
|
||||||
|
std::vector<std::vector<int64>>* move_config,
|
||||||
|
std::vector<std::vector<int64>>* reuse_config, int* search_config)
|
||||||
: conditional_(conditional),
|
: conditional_(conditional),
|
||||||
conditional_parent_(conditional->parent()),
|
conditional_parent_(conditional->parent()),
|
||||||
is_layout_sensitive_(is_layout_sensitive),
|
is_layout_sensitive_(is_layout_sensitive),
|
||||||
visited_count_(visited_count) {}
|
visited_count_(visited_count),
|
||||||
|
move_config_(*move_config),
|
||||||
|
reuse_config_(*reuse_config),
|
||||||
|
search_config_(*search_config) {}
|
||||||
|
// Returns estimation of potential reuses carried by a given pair of
|
||||||
|
// instructions. Use different integers to classify different levels
|
||||||
|
// of reuses. Assume all instructions can be fused to enable data reuses.
|
||||||
|
int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) {
|
||||||
|
std::vector<int64>& curconfig =
|
||||||
|
reuse_config_[static_cast<uint32>(op->opcode())];
|
||||||
|
// Flip the reuse configuration if tuning the cost model.
|
||||||
|
// When flipping, use -10 if flipping to the default reuse model. Other
|
||||||
|
// values can be specified if needed to fine-control the decision making.
|
||||||
|
int64 config =
|
||||||
|
(search_config_ < 0)
|
||||||
|
? FlipMutation(&curconfig[static_cast<uint32>(user->opcode())], -10,
|
||||||
|
HloOpcodeString(op->opcode()) + "->" +
|
||||||
|
HloOpcodeString(user->opcode()))
|
||||||
|
: curconfig[static_cast<uint32>(user->opcode())];
|
||||||
|
VLOG(2) << "ConditionalCodeMotion: Add reuses carried by instr: "
|
||||||
|
<< op->ToString() << "=>" << user->ToString() << " : " << config
|
||||||
|
<< "\n";
|
||||||
|
if (config < 0) {
|
||||||
|
// Assume the reuse decreases with increasing user count.
|
||||||
|
int count1 = CountNonLeafOps(op->users());
|
||||||
|
int count2 = CountNonLeafOps(user->operands());
|
||||||
|
return (-config) / count1 / count2;
|
||||||
|
}
|
||||||
|
return config;
|
||||||
|
}
|
||||||
void clear_recently_visited() {
|
void clear_recently_visited() {
|
||||||
for (const auto& boundary : new_boundaries_) {
|
for (const auto& boundary : new_boundaries_) {
|
||||||
visited_count_.erase(boundary.operands()[0]);
|
visited_count_.erase(boundary.operands()[0]);
|
||||||
@ -791,63 +922,41 @@ class GroupConnectedBoundaries {
|
|||||||
// This is needed for the "moving-in" transformation, to prevent the root
|
// This is needed for the "moving-in" transformation, to prevent the root
|
||||||
// of the parent computation (which contains the conditional) to be moved
|
// of the parent computation (which contains the conditional) to be moved
|
||||||
// inside the conditional.
|
// inside the conditional.
|
||||||
if (instruction->opcode() == HloOpcode::kTuple &&
|
HloOpcode opcode = instruction->opcode();
|
||||||
|
if (opcode == HloOpcode::kTuple &&
|
||||||
instruction == conditional_parent_->root_instruction()) {
|
instruction == conditional_parent_->root_instruction()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
// TOOD[b/169182921] The following cost model is rather incomplete. Will
|
|
||||||
// need to extend to cover most of element-wise ops.
|
|
||||||
switch (instruction->opcode()) {
|
|
||||||
case HloOpcode::kConvert:
|
|
||||||
// If Convert is after AllReduce, it is worth moving out AllReduce
|
|
||||||
// out of conditional for AR/CRS combine. If Convert is after other
|
|
||||||
// ops such as Dot or Convolutional, it is better to keep convert
|
|
||||||
// within conditional so that convert can be fused with Dot or
|
|
||||||
// Convolutional.
|
|
||||||
switch (instruction->operand(0)->opcode()) {
|
|
||||||
case HloOpcode::kAllReduce:
|
|
||||||
case HloOpcode::kReshape:
|
|
||||||
case HloOpcode::kGetTupleElement:
|
|
||||||
return true;
|
|
||||||
default:
|
|
||||||
VLOG(2) << "Instruction is convert and its operand is not known to "
|
|
||||||
"be worth hoisting\n";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
case HloOpcode::kGetTupleElement:
|
|
||||||
switch (instruction->operand(0)->opcode()) {
|
|
||||||
// do not move GTE if its operand is a parameter
|
|
||||||
case HloOpcode::kParameter:
|
|
||||||
return false;
|
|
||||||
default:
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
case HloOpcode::kAllReduce:
|
|
||||||
// It is not safe to move collective ops from outside to inside
|
// It is not safe to move collective ops from outside to inside
|
||||||
// conditional branches, as it may cause synchronization problems,
|
// conditional branches, as it may cause synchronization problems,
|
||||||
// when different layouts are assigned to different branches.
|
// when different layouts are assigned to different branches.
|
||||||
return is_inside_branch;
|
if (opcode == HloOpcode::kAllReduce && !is_inside_branch) {
|
||||||
case HloOpcode::kAbs:
|
|
||||||
case HloOpcode::kReduce:
|
|
||||||
case HloOpcode::kAdd:
|
|
||||||
case HloOpcode::kPower:
|
|
||||||
case HloOpcode::kCopy:
|
|
||||||
case HloOpcode::kConstant:
|
|
||||||
case HloOpcode::kSubtract:
|
|
||||||
case HloOpcode::kMultiply:
|
|
||||||
case HloOpcode::kDivide:
|
|
||||||
case HloOpcode::kTuple:
|
|
||||||
case HloOpcode::kSqrt:
|
|
||||||
case HloOpcode::kRsqrt:
|
|
||||||
case HloOpcode::kReshape:
|
|
||||||
case HloOpcode::kMinimum:
|
|
||||||
case HloOpcode::kMaximum:
|
|
||||||
return true;
|
|
||||||
default:
|
|
||||||
VLOG(2) << "Instruction is not known to be worth hoisting\n";
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// It is not legal to move the parameter instructions.
|
||||||
|
if (opcode == HloOpcode::kParameter) {
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Use configuration given from outside (e.g., by autotuner).
|
||||||
|
std::vector<int64>& curconfig = move_config_[static_cast<uint32>(opcode)];
|
||||||
|
auto col = (curconfig.size() == 1) ? 0
|
||||||
|
: (instruction->operand_count() > 0)
|
||||||
|
? static_cast<uint32>(instruction->operand(0)->opcode())
|
||||||
|
: 0;
|
||||||
|
VLOG(2) << "column = " << col << "\n";
|
||||||
|
VLOG(2) << "config size = " << curconfig.size() << "\n";
|
||||||
|
VLOG(2) << "search_config = " << search_config_ << "\n";
|
||||||
|
CHECK(col < curconfig.size());
|
||||||
|
uint32 config = (search_config_ > 0)
|
||||||
|
? FlipMutation(&curconfig[col], 1,
|
||||||
|
"Move-" + HloOpcodeString(opcode))
|
||||||
|
: curconfig[col];
|
||||||
|
VLOG(2) << "Checking instruction is worth moving: " << config << "\n";
|
||||||
|
return (config != 0);
|
||||||
|
}
|
||||||
|
|
||||||
int64 ReusesBeforeBoundary(HloInstruction* user) {
|
int64 ReusesBeforeBoundary(HloInstruction* user) {
|
||||||
int64 reuses = 0;
|
int64 reuses = 0;
|
||||||
for (auto op : user->operands()) {
|
for (auto op : user->operands()) {
|
||||||
@ -919,11 +1028,23 @@ class GroupConnectedBoundaries {
|
|||||||
|
|
||||||
int64 BenefitForMovingBoundaries(const std::vector<Boundary>& boundaries) {
|
int64 BenefitForMovingBoundaries(const std::vector<Boundary>& boundaries) {
|
||||||
int64 reuses_before = 0, reuses_after = 0;
|
int64 reuses_before = 0, reuses_after = 0;
|
||||||
if (boundaries.size() == 1 && boundaries[0].IsOutsideBranch() &&
|
if (boundaries.size() == 1) {
|
||||||
boundaries[0].operands()[0]->opcode() == HloOpcode::kGetTupleElement) {
|
if (boundaries[0].IsOutsideBranch() &&
|
||||||
|
boundaries[0].operands()[0]->opcode() ==
|
||||||
|
HloOpcode::kGetTupleElement) {
|
||||||
// The only boundary of moving-in is the get_tuple_element op.
|
// The only boundary of moving-in is the get_tuple_element op.
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
if (boundaries[0].IsInsideBranch() &&
|
||||||
|
boundaries[0].operands()[0]->opcode() == HloOpcode::kTuple) {
|
||||||
|
// The only boundary of moving-out is the tuple op inside branches.
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If trying alternative moving configurations, turn off reuse analysis.
|
||||||
|
if (search_config_ > 0) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
// For cases like :
|
// For cases like :
|
||||||
// branch0 {
|
// branch0 {
|
||||||
// ROOT copy
|
// ROOT copy
|
||||||
@ -1121,7 +1242,8 @@ ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion(
|
|||||||
std::vector<Boundary>& to_move, std::vector<Boundary>& new_boundaries,
|
std::vector<Boundary>& to_move, std::vector<Boundary>& new_boundaries,
|
||||||
absl::flat_hash_map<HloInstruction*, int>& visited_count) {
|
absl::flat_hash_map<HloInstruction*, int>& visited_count) {
|
||||||
GroupConnectedBoundaries connect(conditional, is_layout_sensitive_,
|
GroupConnectedBoundaries connect(conditional, is_layout_sensitive_,
|
||||||
visited_count);
|
visited_count, &move_config_, &reuse_config_,
|
||||||
|
&search_config_);
|
||||||
auto move_in_or_out =
|
auto move_in_or_out =
|
||||||
connect.BoundariesToMoveInOrOut(conditional, cur_boundary);
|
connect.BoundariesToMoveInOrOut(conditional, cur_boundary);
|
||||||
if (!move_in_or_out.empty()) {
|
if (!move_in_or_out.empty()) {
|
||||||
@ -1167,6 +1289,10 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
|
|||||||
TF_ASSIGN_OR_RETURN(auto cleanup_changed_now, subpipeline.Run(module));
|
TF_ASSIGN_OR_RETURN(auto cleanup_changed_now, subpipeline.Run(module));
|
||||||
cleanup_changed |= cleanup_changed_now;
|
cleanup_changed |= cleanup_changed_now;
|
||||||
}
|
}
|
||||||
|
// set the default configuration
|
||||||
|
VLOG(2) << "Obtaining default configuration\n";
|
||||||
|
SetDefaultMoveConfig();
|
||||||
|
VLOG(2) << "Done obtaining default configuration\n";
|
||||||
// Gather all the conditional ops in the module ahead of time, to avoid
|
// Gather all the conditional ops in the module ahead of time, to avoid
|
||||||
// potential complications of modifying the code that affect traversal.
|
// potential complications of modifying the code that affect traversal.
|
||||||
std::vector<HloInstruction*> conditional_ops;
|
std::vector<HloInstruction*> conditional_ops;
|
||||||
@ -1390,6 +1516,46 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
|
|||||||
}
|
}
|
||||||
return changed;
|
return changed;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ConditionalCodeMotion::SetDefaultMoveConfig() {
|
||||||
|
int tuning_option = (search_config_ == 0) ? 0 : (search_config_ > 0) ? 1 : 2;
|
||||||
|
|
||||||
|
auto row = HloOpcodeCount();
|
||||||
|
auto col = row;
|
||||||
|
VLOG(2) << "Start setting default configuration\n";
|
||||||
|
reuse_config_.reserve(row);
|
||||||
|
move_config_.reserve(row);
|
||||||
|
for (int64 opcode = 0; opcode < row; ++opcode) {
|
||||||
|
// To save whether an instruction is preferred to be moved.
|
||||||
|
std::vector<int64> reuse_vec(col, 0);
|
||||||
|
for (uint32 j = 0; j < col; ++j) {
|
||||||
|
reuse_vec[j] = ReusesCarriedBy(static_cast<HloOpcode>(opcode),
|
||||||
|
static_cast<HloOpcode>(j));
|
||||||
|
}
|
||||||
|
reuse_config_.push_back(reuse_vec);
|
||||||
|
std::vector<int64> move_vec;
|
||||||
|
switch (tuning_option) {
|
||||||
|
case 1:
|
||||||
|
// Tuning transformation decision --- start with all yes.
|
||||||
|
// Only a single entry is needed if we don't consider operands of an op
|
||||||
|
// when searching/tuning transformation decisions.
|
||||||
|
move_vec.push_back(1);
|
||||||
|
break;
|
||||||
|
case 2: // Tune the ReusesCarriedBy results only.
|
||||||
|
case 0:
|
||||||
|
// No tuning --- use the default configuration.
|
||||||
|
// Use the opcode of first operand to configure default.
|
||||||
|
move_vec.reserve(col);
|
||||||
|
for (uint32 j = 0; j < col; ++j) {
|
||||||
|
move_vec.push_back(WorthHoisting(static_cast<HloOpcode>(opcode),
|
||||||
|
static_cast<HloOpcode>(j)));
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
move_config_.push_back(move_vec);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace conditional_opt
|
} // namespace conditional_opt
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -68,15 +68,40 @@ class Boundary {
|
|||||||
// and their properties are identical.
|
// and their properties are identical.
|
||||||
// - Only the identical ops that won't share operands with other ops will
|
// - Only the identical ops that won't share operands with other ops will
|
||||||
// be moved out of conditional.
|
// be moved out of conditional.
|
||||||
|
// The cost model of the code motion optimization includes two components,
|
||||||
|
// represented by the move_config_ and reuse_config_ arrays of the optimization.
|
||||||
|
// The move_config_ array uses 1 vs 0 to dictate whether each Hlo Opcode, when
|
||||||
|
// used with its first operand being another given Hlo Opcode, is allowed to
|
||||||
|
// move across any conditional boundary; the reuse_config_ array uses an integer
|
||||||
|
// to represent the force between each pair of HloOpcode regarding how
|
||||||
|
// attractive it is to place these instructions together (both inside or outside
|
||||||
|
// of a conditional). Both arrays use Hlo Opcode only to drive the
|
||||||
|
// configuration, regardless of where the operations are located in the
|
||||||
|
// module.
|
||||||
class ConditionalCodeMotion : public HloModulePass {
|
class ConditionalCodeMotion : public HloModulePass {
|
||||||
public:
|
public:
|
||||||
// If is_layout_sensitive is true, then the hoist process preserves layout
|
// If is_layout_sensitive is true, then the hoist process preserves layout
|
||||||
// during identical comparison. Otherwise, layout is ignored.
|
// during identical comparison. Otherwise, layout is ignored.
|
||||||
|
// The search configuration is a single integer but is split into four parts:
|
||||||
|
// (sign, n, m, p), where n,m,p each occupy 8 bits and together make the 24
|
||||||
|
// bits at the end of the int32. For the sign part, if search_config is <0,
|
||||||
|
// the reuse_config_ cost model is modified (tuned); if search_config is >0,
|
||||||
|
// the move_config_ cost model is modified (tuned); if search_config == 0,
|
||||||
|
// the default cost model is used with no tuning. When tuning, the entries in
|
||||||
|
// the designated configuration array (move_config_ or reuse_config_) are
|
||||||
|
// flipped between 0 and another default integer, starting from the pth entry
|
||||||
|
// being queried by the optimization and repeated every nth time a new entry
|
||||||
|
// is visited, until a maximum of m entries have been changed. The tuning
|
||||||
|
// starts over when optimizing a new model.
|
||||||
explicit ConditionalCodeMotion(bool is_layout_sensitive,
|
explicit ConditionalCodeMotion(bool is_layout_sensitive,
|
||||||
bool pursue_full_conditional_code_motion)
|
bool pursue_full_conditional_code_motion,
|
||||||
|
int32 search_config = 0)
|
||||||
: is_layout_sensitive_(is_layout_sensitive),
|
: is_layout_sensitive_(is_layout_sensitive),
|
||||||
pursue_full_conditional_code_motion_(
|
pursue_full_conditional_code_motion_(
|
||||||
pursue_full_conditional_code_motion) {}
|
/*turn off special case if tuning*/
|
||||||
|
pursue_full_conditional_code_motion && search_config == 0),
|
||||||
|
search_config_(search_config) {}
|
||||||
|
|
||||||
absl::string_view name() const override { return "conditional-code-motion"; }
|
absl::string_view name() const override { return "conditional-code-motion"; }
|
||||||
StatusOr<bool> Run(HloModule* module) override;
|
StatusOr<bool> Run(HloModule* module) override;
|
||||||
|
|
||||||
@ -109,6 +134,8 @@ class ConditionalCodeMotion : public HloModulePass {
|
|||||||
private:
|
private:
|
||||||
const bool is_layout_sensitive_;
|
const bool is_layout_sensitive_;
|
||||||
const bool pursue_full_conditional_code_motion_;
|
const bool pursue_full_conditional_code_motion_;
|
||||||
|
int32 search_config_;
|
||||||
|
std::vector<std::vector<int64>> move_config_, reuse_config_;
|
||||||
|
|
||||||
StatusOr<bool> MoveInstructionOut(HloInstruction* conditional,
|
StatusOr<bool> MoveInstructionOut(HloInstruction* conditional,
|
||||||
std::vector<Boundary>& to_move_out,
|
std::vector<Boundary>& to_move_out,
|
||||||
@ -116,6 +143,7 @@ class ConditionalCodeMotion : public HloModulePass {
|
|||||||
StatusOr<bool> MoveInstructionIn(HloInstruction* conditional,
|
StatusOr<bool> MoveInstructionIn(HloInstruction* conditional,
|
||||||
std::vector<Boundary>& to_move_in,
|
std::vector<Boundary>& to_move_in,
|
||||||
std::vector<Boundary>& new_boundaries);
|
std::vector<Boundary>& new_boundaries);
|
||||||
|
void SetDefaultMoveConfig();
|
||||||
};
|
};
|
||||||
} // namespace conditional_opt
|
} // namespace conditional_opt
|
||||||
|
|
||||||
|
@ -279,6 +279,7 @@ ENTRY main {
|
|||||||
|
|
||||||
const HloInstruction* conditional =
|
const HloInstruction* conditional =
|
||||||
FindInstruction(module.get(), "conditional");
|
FindInstruction(module.get(), "conditional");
|
||||||
|
CHECK_NE(conditional, nullptr);
|
||||||
const HloComputation* on_true = conditional->branch_computation(0);
|
const HloComputation* on_true = conditional->branch_computation(0);
|
||||||
ASSERT_EQ(on_true->instruction_count(), 1);
|
ASSERT_EQ(on_true->instruction_count(), 1);
|
||||||
const HloComputation* on_false = conditional->branch_computation(1);
|
const HloComputation* on_false = conditional->branch_computation(1);
|
||||||
@ -1240,6 +1241,127 @@ ENTRY main {
|
|||||||
op::Parameter())));
|
op::Parameter())));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ConditionalCodeMotionTest, TestConfigurationFlag) {
|
||||||
|
absl::string_view hlo_string =
|
||||||
|
R"(
|
||||||
|
HloModule RemoveDotOpOut
|
||||||
|
|
||||||
|
on_true {
|
||||||
|
%arg_tuple.1 = (f32[93184,4]{1,0}) parameter(0)
|
||||||
|
%get-tuple-element.1 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.1), index=0
|
||||||
|
%reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1)
|
||||||
|
%convert.2894 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.8493)
|
||||||
|
ROOT %tuple.1 = ( bf16[2,512,364]{2,1,0}) tuple(%convert.2894)
|
||||||
|
}
|
||||||
|
|
||||||
|
on_false {
|
||||||
|
%arg_tuple.2 = (f32[93184,4]{1,0}) parameter(0)
|
||||||
|
%get-tuple-element.3 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.2), index=0
|
||||||
|
%reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.3)
|
||||||
|
%convert.3604 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.9717), metadata={op_type="Cast" op_name="gradients/Cast_125_grad/Cast"}
|
||||||
|
ROOT %tuple.2 = (bf16[2,512,364]{2,1,0}) tuple(%convert.3604)
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY main {
|
||||||
|
pred.1 = pred[] parameter(0)
|
||||||
|
arg_tuple.11 = (f32[93184,4]{1,0}) parameter(1)
|
||||||
|
arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2)
|
||||||
|
conditional = (bf16[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false
|
||||||
|
get-first-index = bf16[2,512,364]{2,1,0} get-tuple-element(conditional), index=0
|
||||||
|
add.1 = bf16[2,512,364]{2,1,0} add(bf16[2,512,364]{2,1,0} get-first-index, bf16[2,512,364]{2,1,0} get-first-index)
|
||||||
|
ROOT result = (bf16[2,512,364]{2,1,0}) tuple(add.1)
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
// Use a config loop to tune which instructions should be moved/not_moved.
|
||||||
|
for (int max_flip = 1; max_flip < 3; ++max_flip) {
|
||||||
|
for (int flip_stride = 1; flip_stride < ((max_flip > 1) ? 7 : 2);
|
||||||
|
++flip_stride) {
|
||||||
|
for (int flip_start = 0; flip_start < 7; ++flip_start) {
|
||||||
|
// Start flipping at index config, repeat thereafter, until reaching
|
||||||
|
// max.
|
||||||
|
uint32 search_config =
|
||||||
|
(max_flip << 8) + flip_start + (flip_stride << 16);
|
||||||
|
ConditionalCodeMotion pass(true, true, search_config);
|
||||||
|
VLOG(1) << "Testing max_flip=" << max_flip
|
||||||
|
<< "; flip_start = " << flip_start
|
||||||
|
<< "; flip_stride = " << flip_stride
|
||||||
|
<< "; search_config=" << search_config;
|
||||||
|
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
|
||||||
|
bool opt_result = pass.Run(&*module).ValueOrDie();
|
||||||
|
// Turning off the first/second decision will disable moving out;
|
||||||
|
// Turning off the following decision will again disable moving in.
|
||||||
|
if (flip_start < 2 && max_flip > 1 && flip_stride == 1) {
|
||||||
|
// If the next decision is false, no moving in is allowed either.
|
||||||
|
CHECK_EQ(opt_result, false);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
CHECK_EQ(opt_result, true);
|
||||||
|
const HloInstruction* conditional =
|
||||||
|
FindInstruction(module.get(), "conditional");
|
||||||
|
const HloComputation* on_true = conditional->branch_computation(0);
|
||||||
|
const HloComputation* on_false = conditional->branch_computation(1);
|
||||||
|
HloInstruction* root = module->entry_computation()->root_instruction();
|
||||||
|
switch (flip_start) {
|
||||||
|
case 0:
|
||||||
|
TF_FALLTHROUGH_INTENDED;
|
||||||
|
case 1:
|
||||||
|
// After flipping the corresponding decisions,
|
||||||
|
// instructions have been moved inside the conditionals.
|
||||||
|
ASSERT_EQ(on_true->instruction_count(), 6);
|
||||||
|
ASSERT_EQ(on_false->instruction_count(), 6);
|
||||||
|
EXPECT_THAT(root, AllOf(op::Conditional()));
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
// The 2nd decision has been flipped. Reshape was not moved out.
|
||||||
|
ASSERT_EQ(on_true->instruction_count(), 4);
|
||||||
|
ASSERT_EQ(on_false->instruction_count(), 4);
|
||||||
|
EXPECT_THAT(
|
||||||
|
root,
|
||||||
|
AllOf(op::Tuple(op::Add(
|
||||||
|
op::Convert(op::GetTupleElement(op::Conditional())),
|
||||||
|
op::Convert(op::GetTupleElement(op::Conditional()))))));
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
// The 3rd decision has been flipped. GTE was not moved out. The
|
||||||
|
// GTE is then merged with the tuple op of the new root in later
|
||||||
|
// cleanup.
|
||||||
|
ASSERT_EQ(on_true->instruction_count(), 1);
|
||||||
|
ASSERT_EQ(on_false->instruction_count(), 1);
|
||||||
|
EXPECT_THAT(root, AllOf(op::Tuple(op::Add(
|
||||||
|
op::Convert(op::Reshape(
|
||||||
|
op::GetTupleElement(op::Conditional()))),
|
||||||
|
op::Convert(op::Reshape(op::GetTupleElement(
|
||||||
|
op::Conditional())))))));
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
case 5:
|
||||||
|
case 6:
|
||||||
|
// The 4th decision has been flipped. Parameter was not moved out.
|
||||||
|
// Each conditional has the parameter and the new root.
|
||||||
|
ASSERT_EQ(on_true->instruction_count(), 2);
|
||||||
|
ASSERT_EQ(on_false->instruction_count(), 2);
|
||||||
|
EXPECT_THAT(root,
|
||||||
|
AllOf(op::Tuple(op::Add(
|
||||||
|
op::Convert(op::Reshape(op::GetTupleElement(
|
||||||
|
op::GetTupleElement(op::Conditional())))),
|
||||||
|
op::Convert(op::Reshape(op::GetTupleElement(
|
||||||
|
op::GetTupleElement(op::Conditional()))))))));
|
||||||
|
break;
|
||||||
|
default: // The default cost model is used.
|
||||||
|
ASSERT_EQ(on_true->instruction_count(), 1);
|
||||||
|
ASSERT_EQ(on_false->instruction_count(), 1);
|
||||||
|
EXPECT_THAT(root, AllOf(op::Tuple(op::Add(
|
||||||
|
op::Convert(op::Reshape(
|
||||||
|
op::GetTupleElement(op::Conditional()))),
|
||||||
|
op::Convert(op::Reshape(op::GetTupleElement(
|
||||||
|
op::Conditional())))))));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace conditional_opt
|
} // namespace conditional_opt
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
Loading…
Reference in New Issue
Block a user