[XLA] Add auto-tuning support for conditional code motion.
In particular, extend the command-line parameter for search to take a string as value, to support tuning. Also extended the flag tuner parameters with conditional code motion configuration. PiperOrigin-RevId: 358902635 Change-Id: I6c8e8f3412725f6976e74e476a41dc738778f12c
This commit is contained in:
parent
c0d79e99e8
commit
d6325d93b1
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/strings/numbers.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
@ -822,7 +823,9 @@ class GroupConnectedBoundaries {
|
||||
// search/tuning process.
|
||||
std::vector<std::vector<int64>>& move_config_;
|
||||
std::vector<std::vector<int64>>& reuse_config_;
|
||||
int& search_config_;
|
||||
std::vector<int64>& search_config_vec_;
|
||||
int64* search_config_;
|
||||
int64 search_subscript_;
|
||||
absl::flat_hash_map<const int64*, int64> flipped_;
|
||||
|
||||
// The FlipMutation function serves to implement the search of alternative
|
||||
@ -834,29 +837,35 @@ class GroupConnectedBoundaries {
|
||||
VLOG(2) << "Configured not to search or loc is already flipped.";
|
||||
return *loc;
|
||||
}
|
||||
|
||||
// The 8-16 digits control the maximum number of times to flip a config.
|
||||
int flip_count = (search_config_ >> 8) & 255;
|
||||
if (flip_count == 0) {
|
||||
VLOG(2) << "Maximum flip count has reached. ";
|
||||
return *loc;
|
||||
}
|
||||
|
||||
// The lowest 16 bits (see kStartPos and kValueMask) control when to start the first flip.
|
||||
int c = search_config_ & 255;
|
||||
int c = ConditionalCodeMotion::flip_start(*search_config_);
|
||||
VLOG(2) << "flip start index = " << c << "\n";
|
||||
// Only flip the decision if c reaches 0.
|
||||
if (c > 0) {
|
||||
search_config_--;
|
||||
(*search_config_)--;
|
||||
return *loc;
|
||||
}
|
||||
|
||||
// Decrement flip count so we can stop if it reaches 0.
|
||||
search_config_ -= 256;
|
||||
// Bits 16-31 (see kMaxPos and kValueMask) control the maximum number of times to flip a config.
|
||||
auto flip_count = ConditionalCodeMotion::DecrementMaxFlip(search_config_);
|
||||
VLOG(2) << "max flip count = " << flip_count << "\n";
|
||||
VLOG(2) << "Updating max Flipping configuration = " << *search_config_
|
||||
<< "\n";
|
||||
if (flip_count == 0) {
|
||||
VLOG(2) << "Maximum flip count has reached. ";
|
||||
if (search_subscript_ + 1 < search_config_vec_.size()) {
|
||||
VLOG(2) << "search_subscript_ = " << search_subscript_;
|
||||
VLOG(2) << "search config vec size = " << search_config_vec_.size();
|
||||
search_config_ = &search_config_vec_[++search_subscript_];
|
||||
} else {
|
||||
return *loc;
|
||||
}
|
||||
}
|
||||
// Reload the stride field (bits 32-47, see kStridePos) of the configuration, which controls how
|
||||
// frequently a configuration should be flipped.
|
||||
search_config_ += (search_config_ >> 16) & 255;
|
||||
VLOG(2) << "Updating Flipping configuration = " << search_config_ << "\n";
|
||||
auto flip_stride = ConditionalCodeMotion::flip_stride(*search_config_);
|
||||
*search_config_ += flip_stride;
|
||||
VLOG(2) << "flip stride = " << flip_stride << "\n";
|
||||
VLOG(2) << "Updating Flipping Stride = " << *search_config_ << "\n";
|
||||
|
||||
flipped_[loc] = *loc;
|
||||
// Copy the last 8 bits back to the first 8 bits of configuration.
|
||||
@ -878,14 +887,23 @@ class GroupConnectedBoundaries {
|
||||
HloInstruction* conditional, bool is_layout_sensitive,
|
||||
absl::flat_hash_map<HloInstruction*, int>& visited_count,
|
||||
std::vector<std::vector<int64>>* move_config,
|
||||
std::vector<std::vector<int64>>* reuse_config, int* search_config)
|
||||
std::vector<std::vector<int64>>* reuse_config,
|
||||
std::vector<int64>* search_config)
|
||||
: conditional_(conditional),
|
||||
conditional_parent_(conditional->parent()),
|
||||
is_layout_sensitive_(is_layout_sensitive),
|
||||
visited_count_(visited_count),
|
||||
move_config_(*move_config),
|
||||
reuse_config_(*reuse_config),
|
||||
search_config_(*search_config) {}
|
||||
search_config_vec_(*search_config),
|
||||
search_subscript_(0) {
|
||||
VLOG(2) << "Initializing Group Connected Boundaries\n";
|
||||
CHECK_NE(search_config, nullptr);
|
||||
if (search_config_vec_.empty()) {
|
||||
search_config_vec_.push_back(0);
|
||||
}
|
||||
search_config_ = &search_config_vec_[0];
|
||||
}
|
||||
// Returns estimation of potential reuses carried by a given pair of
|
||||
// instructions. Use different integers to classify different levels
|
||||
// of reuses. Assume all instructions can be fused to enable data reuses.
|
||||
@ -896,7 +914,7 @@ class GroupConnectedBoundaries {
|
||||
// When flipping, use -10 if flipping to the default reuse model. Other
|
||||
// values can be specified if needed to fine-control the decision making.
|
||||
int64 config =
|
||||
(search_config_ < 0)
|
||||
((*search_config_) < 0)
|
||||
? FlipMutation(&curconfig[static_cast<uint32>(user->opcode())], -10,
|
||||
HloOpcodeString(op->opcode()) + "->" +
|
||||
HloOpcodeString(user->opcode()))
|
||||
@ -947,13 +965,14 @@ class GroupConnectedBoundaries {
|
||||
: 0;
|
||||
VLOG(2) << "column = " << col << "\n";
|
||||
VLOG(2) << "config size = " << curconfig.size() << "\n";
|
||||
VLOG(2) << "search_config = " << search_config_ << "\n";
|
||||
VLOG(2) << "search_config = " << *search_config_ << "\n";
|
||||
CHECK(col < curconfig.size());
|
||||
uint32 config = (search_config_ > 0)
|
||||
uint32 config = ((*search_config_) > 0)
|
||||
? FlipMutation(&curconfig[col], 1,
|
||||
"Move-" + HloOpcodeString(opcode))
|
||||
: curconfig[col];
|
||||
VLOG(2) << "Checking instruction is worth moving: " << config << "\n";
|
||||
VLOG(2) << "after checking search_config = " << *search_config_ << "\n";
|
||||
return (config != 0);
|
||||
}
|
||||
|
||||
@ -1026,7 +1045,8 @@ class GroupConnectedBoundaries {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int64 BenefitForMovingBoundaries(const std::vector<Boundary>& boundaries) {
|
||||
int64 BenefitForMovingBoundaries(const std::vector<Boundary>& boundaries,
|
||||
bool perform_reuse_analysis = true) {
|
||||
int64 reuses_before = 0, reuses_after = 0;
|
||||
if (boundaries.size() == 1) {
|
||||
if (boundaries[0].IsOutsideBranch() &&
|
||||
@ -1042,7 +1062,7 @@ class GroupConnectedBoundaries {
|
||||
}
|
||||
}
|
||||
// If trying alternative moving configurations, turn off reuse analysis.
|
||||
if (search_config_ > 0) {
|
||||
if (!perform_reuse_analysis) {
|
||||
return 1;
|
||||
}
|
||||
// For cases like :
|
||||
@ -1247,7 +1267,8 @@ ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion(
|
||||
auto move_in_or_out =
|
||||
connect.BoundariesToMoveInOrOut(conditional, cur_boundary);
|
||||
if (!move_in_or_out.empty()) {
|
||||
auto benefit = connect.BenefitForMovingBoundaries(move_in_or_out);
|
||||
auto benefit = connect.BenefitForMovingBoundaries(
|
||||
move_in_or_out, search_config_map_.empty());
|
||||
VLOG(2) << "benefit of moving in or out "
|
||||
<< cur_boundary.operands()[0]->ToString() << ":" << benefit << "\n";
|
||||
if (benefit >= 0) {
|
||||
@ -1289,10 +1310,6 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
|
||||
TF_ASSIGN_OR_RETURN(auto cleanup_changed_now, subpipeline.Run(module));
|
||||
cleanup_changed |= cleanup_changed_now;
|
||||
}
|
||||
// set the default configuration
|
||||
VLOG(2) << "Obtaining default configuration\n";
|
||||
SetDefaultMoveConfig();
|
||||
VLOG(2) << "Done obtaining default configuration\n";
|
||||
// Gather all the conditional ops in the module ahead of time, to avoid
|
||||
// potential complications of modifying the code that affect traversal.
|
||||
std::vector<HloInstruction*> conditional_ops;
|
||||
@ -1328,9 +1345,23 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
|
||||
}
|
||||
}
|
||||
|
||||
int64 conditional_index = 0;
|
||||
// Use to collect mappings between cloned instructions.
|
||||
HloCloneContext clone_context(module);
|
||||
for (HloInstruction* conditional : conditional_ops) {
|
||||
if (conditional_index == 0 || !search_config_map_.empty()) {
|
||||
auto config_entry = search_config_map_.find(conditional_index);
|
||||
if (config_entry != search_config_map_.end()) {
|
||||
search_config_ = (*config_entry).second;
|
||||
VLOG(2) << "config entry value extracted:" << search_config_.size();
|
||||
search_config_index_ = 0;
|
||||
}
|
||||
VLOG(2) << "Obtaining default configuration for conditional "
|
||||
<< conditional_index << "\n";
|
||||
SetDefaultMoveConfig();
|
||||
VLOG(2) << "Done obtaining default configuration\n";
|
||||
conditional_index++;
|
||||
}
|
||||
int branch_count = conditional->branch_count();
|
||||
// check for shared conditional computations
|
||||
bool conditional_is_shared = false;
|
||||
@ -1518,11 +1549,27 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
|
||||
}
|
||||
|
||||
void ConditionalCodeMotion::SetDefaultMoveConfig() {
|
||||
int tuning_option = (search_config_ == 0) ? 0 : (search_config_ > 0) ? 1 : 2;
|
||||
VLOG(2) << "search_config_index = " << search_config_index_ << "\n";
|
||||
VLOG(2) << "search_config_ size = " << search_config_.size() << "\n";
|
||||
int64 cur_search_config = (search_config_index_ < 0 ||
|
||||
search_config_index_ >= search_config_.size())
|
||||
? 0
|
||||
: search_config_[search_config_index_];
|
||||
enum class TuningOption {
|
||||
kDoNotTune = 0,
|
||||
kTuneTransformationDecision = 1,
|
||||
kTuneReuseModel = 2,
|
||||
};
|
||||
TuningOption tuning_option =
|
||||
(cur_search_config == 0) ? TuningOption::kDoNotTune
|
||||
: (cur_search_config > 0) ? TuningOption::kTuneTransformationDecision
|
||||
: TuningOption::kTuneReuseModel;
|
||||
|
||||
auto row = HloOpcodeCount();
|
||||
auto col = row;
|
||||
VLOG(2) << "Start setting default configuration\n";
|
||||
reuse_config_.clear();
|
||||
move_config_.clear();
|
||||
reuse_config_.reserve(row);
|
||||
move_config_.reserve(row);
|
||||
for (int64 opcode = 0; opcode < row; ++opcode) {
|
||||
@ -1535,14 +1582,15 @@ void ConditionalCodeMotion::SetDefaultMoveConfig() {
|
||||
reuse_config_.push_back(reuse_vec);
|
||||
std::vector<int64> move_vec;
|
||||
switch (tuning_option) {
|
||||
case 1:
|
||||
case TuningOption::kTuneTransformationDecision:
|
||||
// Tuning transformation decision --- start with all yes.
|
||||
// Only a single entry is needed if we don't consider operands of an op
|
||||
// when searching/tuning transformation decisions.
|
||||
move_vec.push_back(1);
|
||||
break;
|
||||
case 2: // Tune the ReusesCarriedBy results only.
|
||||
case 0:
|
||||
// Tune the ReusesCarriedBy results only.
|
||||
case TuningOption::kTuneReuseModel:
|
||||
case TuningOption::kDoNotTune:
|
||||
// No tuning --- use the default configuration.
|
||||
// Use the opcode of first operand to configure default.
|
||||
move_vec.reserve(col);
|
||||
@ -1556,6 +1604,36 @@ void ConditionalCodeMotion::SetDefaultMoveConfig() {
|
||||
}
|
||||
}
|
||||
|
||||
// The search configuration is specified using a string in the format of
|
||||
// 'config1;config2; ...;config_n', where each config_i is in the format of
|
||||
// 'index,start,max,stride' (four integers separated by comma), which specify
|
||||
// the index number of the conditional being configured, the index of the first
|
||||
// transformation decision to flip for the conditional, the max number of
|
||||
// decisions to flip, and how many decisions to skip in between the flips.
|
||||
void ConditionalCodeMotion::ParseSearchConfiguration(
|
||||
const std::string& search_config) {
|
||||
if (search_config.empty()) {
|
||||
return;
|
||||
}
|
||||
search_config_index_ = 0;
|
||||
std::vector<std::string> configs = absl::StrSplit(search_config, ';');
|
||||
for (const std::string& config : configs) {
|
||||
std::vector<std::string> specs = absl::StrSplit(config, ',');
|
||||
CHECK_EQ(specs.size(), 4);
|
||||
int64 condition_index;
|
||||
CHECK(absl::SimpleAtoi(specs[0], &condition_index));
|
||||
auto& cur_config_entry = search_config_map_[condition_index];
|
||||
int64 flip_start, max_flip, flip_stride;
|
||||
CHECK(absl::SimpleAtoi(specs[1], &flip_start));
|
||||
CHECK(absl::SimpleAtoi(specs[2], &max_flip));
|
||||
CHECK(absl::SimpleAtoi(specs[3], &flip_stride));
|
||||
int64 cur_config = MakeSearchConfig(flip_start, max_flip, flip_stride);
|
||||
cur_config_entry.push_back(cur_config);
|
||||
VLOG(2) << "Setting search config " << condition_index << "->" << cur_config
|
||||
<< "\n";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace conditional_opt
|
||||
|
||||
} // namespace xla
|
||||
|
@ -95,12 +95,69 @@ class ConditionalCodeMotion : public HloModulePass {
|
||||
// start over when optimizing a new model.
|
||||
explicit ConditionalCodeMotion(bool is_layout_sensitive,
|
||||
bool pursue_full_conditional_code_motion,
|
||||
int32 search_config = 0)
|
||||
int64 search_config = 0)
|
||||
: is_layout_sensitive_(is_layout_sensitive),
|
||||
pursue_full_conditional_code_motion_(
|
||||
/*turn off special case if tuning*/
|
||||
pursue_full_conditional_code_motion && search_config == 0),
|
||||
search_config_(search_config) {}
|
||||
search_config_index_(0) {
|
||||
search_config_.push_back(search_config);
|
||||
if (search_config != 0) {
|
||||
search_config_map_[0] = search_config_;
|
||||
}
|
||||
}
|
||||
// Constructs the pass from a textual search configuration, which is handed to
// ParseSearchConfiguration. pursue_full_conditional_code_motion_ is only kept
// on when the caller requested it and search_config is empty (no tuning).
// search_config_index_ starts at -1; ParseSearchConfiguration resets it to 0
// when search_config is non-empty.
explicit ConditionalCodeMotion(bool is_layout_sensitive,
|
||||
bool pursue_full_conditional_code_motion,
|
||||
std::string search_config)
|
||||
: is_layout_sensitive_(is_layout_sensitive),
|
||||
pursue_full_conditional_code_motion_(
|
||||
/*turn off special case if tuning*/
|
||||
pursue_full_conditional_code_motion && search_config.empty()),
|
||||
search_config_index_(-1) {
|
||||
ParseSearchConfiguration(search_config);
|
||||
}
|
||||
// Parse a given string in the format of a sequence of i,s,m,t into a
|
||||
// list of transformation search configurations, each configuration generated
|
||||
// by invoking MakeSearchConfig(s,m,t) and will be used for the ith
|
||||
// conditional encountered when optimizing a given module.
|
||||
void ParseSearchConfiguration(const std::string& search_config);
|
||||
// Make a single search configuration for changing transformation decisions:
|
||||
// flip the decisions at position n = flip_start + flip_stride * m, and
|
||||
// m = 0..max_flip.
|
||||
// The following defines how the int64 search configuration is composed, as
|
||||
// flip_start + (flip_max << kMaxPos) + (flip_stride << kStridePos).
|
||||
// Position (digit) for maximum number of flips.
|
||||
static constexpr int kMaxPos = 16;
|
||||
// Position (digit) for the count-down to the first flip.
|
||||
static constexpr int kStartPos = 0;
|
||||
// Position (digit) for the count-down to the next flip.
|
||||
static constexpr int kStridePos = 32;
|
||||
// Bit mask for extracting the last digits of value.
|
||||
static constexpr int kValueMask = 0xffff;
|
||||
// Packs the three search parameters into one int64 search configuration:
//   start + (max << kMaxPos) + (stride << kStridePos).
// The fields are read back through kValueMask (0xffff), so each argument is
// expected to fit in 16 bits; no range check is performed here and a larger
// value spills into the neighbouring field.
static int64 MakeSearchConfig(int64 start, int64 max, int64 stride) {
  const int64 config =
      (max << kMaxPos) + (start << kStartPos) + (stride << kStridePos);
  VLOG(2) << "flip stride = " << flip_stride(config) << "\n";
  VLOG(2) << "flip config = " << config << "\n";
  return config;
}
|
||||
|
||||
// Returns the "start" field of search_config: the bits at kStartPos selected
// by kValueMask, i.e. the count-down to the first flip.
// NOTE(review): kValueMask is 0xffff while the return type is int16, so a
// field value above 0x7fff comes back negative -- confirm callers expect that.
static int16 flip_start(int64 search_config) {
  const int64 shifted = search_config >> kStartPos;
  return shifted & kValueMask;
}
|
||||
|
||||
// Returns the "stride" field of search_config: the bits at kStridePos selected
// by kValueMask, i.e. the count-down reloaded between consecutive flips.
// NOTE(review): kValueMask is 0xffff while the return type is int16, so a
// field value above 0x7fff comes back negative -- confirm callers expect that.
static int16 flip_stride(int64 search_config) {
  const int64 shifted = search_config >> kStridePos;
  return shifted & kValueMask;
}
|
||||
|
||||
// Reads the "max flip" field (bits at kMaxPos selected by kValueMask) of
// *search_config and, when it is positive, decrements that field in place.
// Returns the field value as it was BEFORE the decrement, so a result of 0
// means no flips remain.
// NOTE(review): the field is held in an int16 although kValueMask is 0xffff;
// a value above 0x7fff would read as negative and never be decremented --
// confirm such values cannot occur.
static int16 DecrementMaxFlip(int64* search_config) {
  const int16 remaining_flips = ((*search_config) >> kMaxPos) & kValueMask;
  if (remaining_flips <= 0) {
    // Nothing left to count down; leave the configuration untouched.
    return remaining_flips;
  }
  // Decrement flip count so we can stop if it reaches 0.
  *search_config -= (1 << kMaxPos);
  return remaining_flips;
}
|
||||
|
||||
absl::string_view name() const override { return "conditional-code-motion"; }
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
@ -134,7 +191,14 @@ class ConditionalCodeMotion : public HloModulePass {
|
||||
private:
|
||||
const bool is_layout_sensitive_;
|
||||
const bool pursue_full_conditional_code_motion_;
|
||||
int32 search_config_;
|
||||
// The following parameterizes the transformation decisions and cost model.
|
||||
std::vector<int64> search_config_;
|
||||
int64 search_config_index_;
|
||||
// Map each conditional to a vector of its search configurations. The key of
|
||||
// the map is the index number of the conditional in a module when traversed
|
||||
// in post order, and the value of the map is the sequence of search
|
||||
// configurations specified with the same index number for the conditional.
|
||||
absl::flat_hash_map<int64, std::vector<int64>> search_config_map_;
|
||||
std::vector<std::vector<int64>> move_config_, reuse_config_;
|
||||
|
||||
StatusOr<bool> MoveInstructionOut(HloInstruction* conditional,
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/conditional_code_motion.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
@ -1279,8 +1280,8 @@ ENTRY main {
|
||||
for (int flip_start = 0; flip_start < 7; ++flip_start) {
|
||||
// Start flipping at index config, repeat thereafter, until reaching
|
||||
// max.
|
||||
uint32 search_config =
|
||||
(max_flip << 8) + flip_start + (flip_stride << 16);
|
||||
int64 search_config = ConditionalCodeMotion::MakeSearchConfig(
|
||||
flip_start, max_flip, flip_stride);
|
||||
ConditionalCodeMotion pass(true, true, search_config);
|
||||
VLOG(1) << "Testing max_flip=" << max_flip
|
||||
<< "; flip_start = " << flip_start
|
||||
@ -1362,6 +1363,150 @@ ENTRY main {
|
||||
}
|
||||
}
|
||||
|
||||
// Exercises the string form of the search configuration with two entries, one
// per conditional in the module ("0,start,max,stride;1,start,max,stride"), and
// checks the resulting code motion for every (max_flip, flip_stride,
// flip_start) combination in the loops below.
//
// Pass results are checked with gtest assertions rather than CHECK so that a
// mismatch is reported as a test failure instead of aborting the test binary.
TEST_F(ConditionalCodeMotionTest, TestMultipleConfigurationFlags) {
  absl::string_view hlo_string =
      R"(
HloModule RemoveDotOpOut

on_true {
%arg_tuple.1 = (f32[93184,4]{1,0}) parameter(0)
%get-tuple-element.1 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.1), index=0
%reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1)
%convert.2894 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.8493)
ROOT %tuple.1 = ( bf16[2,512,364]{2,1,0}) tuple(%convert.2894)
}

on_false {
%arg_tuple.2 = (f32[93184,4]{1,0}) parameter(0)
%get-tuple-element.3 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.2), index=0
%reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.3)
%convert.3604 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.9717), metadata={op_type="Cast" op_name="gradients/Cast_125_grad/Cast"}
ROOT %tuple.2 = (bf16[2,512,364]{2,1,0}) tuple(%convert.3604)
}

ENTRY main {
pred.1 = pred[] parameter(0)
arg_tuple.11 = (f32[93184,4]{1,0}) parameter(1)
arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2)
pred.2 = pred[] parameter(3)
conditional = (bf16[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false
get-first-index = bf16[2,512,364]{2,1,0} get-tuple-element(conditional), index=0
add.1 = bf16[2,512,364]{2,1,0} add(bf16[2,512,364]{2,1,0} get-first-index, bf16[2,512,364]{2,1,0} get-first-index)
conditional.2 = (bf16[2,512,364]{2,1,0}) conditional(pred.2, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false
get-first-index.2 = bf16[2,512,364]{2,1,0} get-tuple-element(conditional.2), index=0
add.2 = bf16[2,512,364]{2,1,0} add(bf16[2,512,364]{2,1,0} get-first-index.2, bf16[2,512,364]{2,1,0} get-first-index.2)
ROOT result = (bf16[2,512,364]{2,1,0}) tuple(add.1, add.2)
}
)";
  // Use a config loop to tune which instructions should be moved/not_moved.
  for (int max_flip = 1; max_flip < 3; ++max_flip) {
    for (int flip_stride = 1; flip_stride < ((max_flip > 1) ? 7 : 2);
         ++flip_stride) {
      for (int flip_start = 0; flip_start < 7; ++flip_start) {
        // Generate two search strings separated by ';', one for each of the
        // two conditionals (index 0 and index 1).
        std::stringstream config_stream;
        config_stream << 0 << "," << flip_start << "," << max_flip << ","
                      << flip_stride << ";";
        config_stream << 1 << "," << flip_start << "," << max_flip << ","
                      << flip_stride;
        auto search_config = config_stream.str();
        ConditionalCodeMotion pass(true, true, search_config);
        VLOG(1) << "Testing max_flip=" << max_flip
                << "; flip_start = " << flip_start
                << "; flip_stride = " << flip_stride
                << "; search_config=" << search_config;
        auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
        bool opt_result = pass.Run(&*module).ValueOrDie();
        // Turning off the first/second decision will disable moving out;
        // Turning off the following decision will again disable moving in.
        if (flip_start < 2 && max_flip > 1 && flip_stride == 1) {
          // If the next decision is false, no moving in is allowed either.
          EXPECT_FALSE(opt_result);
          continue;
        }
        // The structural checks below are meaningless if nothing changed.
        ASSERT_TRUE(opt_result);
        const HloInstruction* conditional =
            FindInstruction(module.get(), "conditional");
        ASSERT_NE(conditional, nullptr);
        const HloComputation* on_true = conditional->branch_computation(0);
        const HloComputation* on_false = conditional->branch_computation(1);
        HloInstruction* root = module->entry_computation()->root_instruction();
        switch (flip_start) {
          case 0:
            TF_FALLTHROUGH_INTENDED;
          case 1:
            // After flipping the corresponding decisions,
            // instructions have been moved inside the conditionals.
            ASSERT_EQ(on_true->instruction_count(), 6);
            ASSERT_EQ(on_false->instruction_count(), 6);
            EXPECT_THAT(
                root, AllOf(op::Tuple(op::GetTupleElement(op::Conditional()),
                                      op::GetTupleElement(op::Conditional()))));
            break;
          case 2:
            // The 2nd decision has been flipped. Reshape was not moved out.
            ASSERT_EQ(on_true->instruction_count(), 4);
            ASSERT_EQ(on_false->instruction_count(), 4);
            EXPECT_THAT(
                root,
                AllOf(op::Tuple(
                    op::Add(
                        op::Convert(op::GetTupleElement(op::Conditional())),
                        op::Convert(op::GetTupleElement(op::Conditional()))),
                    op::Add(
                        op::Convert(op::GetTupleElement(op::Conditional())),
                        op::Convert(op::GetTupleElement(op::Conditional()))))));
            break;
          case 3:
            // The 3rd decision has been flipped. GTE was not moved out. The
            // GTE is then merged with the tuple op of the new root in later
            // cleanup.
            ASSERT_EQ(on_true->instruction_count(), 1);
            ASSERT_EQ(on_false->instruction_count(), 1);
            EXPECT_THAT(
                root, AllOf(op::Tuple(
                          op::Add(op::Convert(op::Reshape(
                                      op::GetTupleElement(op::Conditional()))),
                                  op::Convert(op::Reshape(
                                      op::GetTupleElement(op::Conditional())))),
                          op::Add(op::Convert(op::Reshape(
                                      op::GetTupleElement(op::Conditional()))),
                                  op::Convert(op::Reshape(op::GetTupleElement(
                                      op::Conditional())))))));
            break;
          case 4:
          case 5:
          case 6:
            // The 4th decision has been flipped. Parameter was not moved out.
            // Each conditional has the parameter and the new root.
            ASSERT_EQ(on_true->instruction_count(), 2);
            ASSERT_EQ(on_false->instruction_count(), 2);
            EXPECT_THAT(
                root,
                AllOf(op::Tuple(
                    op::Add(op::Convert(op::Reshape(op::GetTupleElement(
                                op::GetTupleElement(op::Conditional())))),
                            op::Convert(op::Reshape(op::GetTupleElement(
                                op::GetTupleElement(op::Conditional()))))),
                    op::Add(op::Convert(op::Reshape(op::GetTupleElement(
                                op::GetTupleElement(op::Conditional())))),
                            op::Convert(op::Reshape(op::GetTupleElement(
                                op::GetTupleElement(op::Conditional()))))))));
            break;
          default:
            // flip_start only ranges over 0..6 above; reaching this arm means
            // the loop bounds changed without a matching expectation.
            FAIL() << "No expectation for flip_start = " << flip_start;
        }
      }
    }
  }
}
|
||||
|
||||
} // namespace conditional_opt
|
||||
|
||||
} // namespace xla
|
||||
|
Loading…
x
Reference in New Issue
Block a user