[XLA] Add auto-tuning support for conditional code motion.

In particular, extend the command-line parameter for search to take a string as
value, to support tuning. Also extended the flag tuner parameters with a conditional code motion configuration.

PiperOrigin-RevId: 358902635
Change-Id: I6c8e8f3412725f6976e74e476a41dc738778f12c
This commit is contained in:
A. Unique TensorFlower 2021-02-22 14:21:11 -08:00 committed by TensorFlower Gardener
parent c0d79e99e8
commit d6325d93b1
3 changed files with 324 additions and 37 deletions

View File

@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal.h"
@ -822,7 +823,9 @@ class GroupConnectedBoundaries {
// search/tuning process.
std::vector<std::vector<int64>>& move_config_;
std::vector<std::vector<int64>>& reuse_config_;
int& search_config_;
std::vector<int64>& search_config_vec_;
int64* search_config_;
int64 search_subscript_;
absl::flat_hash_map<const int64*, int64> flipped_;
// The FlipMutation function serves to implement the search of alternative
@ -834,29 +837,35 @@ class GroupConnectedBoundaries {
VLOG(2) << "Configured not to search or loc is already flipped.";
return *loc;
}
// The 8-16 digits control the maximum number of times to flip a config.
int flip_count = (search_config_ >> 8) & 255;
if (flip_count == 0) {
VLOG(2) << "Maximum flip count has reached. ";
return *loc;
}
// The last 8 digits control when to start the first flip.
int c = search_config_ & 255;
int c = ConditionalCodeMotion::flip_start(*search_config_);
VLOG(2) << "flip start index = " << c << "\n";
// Only flip the decision if c reaches 0.
if (c > 0) {
search_config_--;
(*search_config_)--;
return *loc;
}
// Decrement flip count so we can stop if it reaches 0.
search_config_ -= 256;
// The 8-16 digits control the maximum number of times to flip a config.
auto flip_count = ConditionalCodeMotion::DecrementMaxFlip(search_config_);
VLOG(2) << "max flip count = " << flip_count << "\n";
VLOG(2) << "Updating max Flipping configuration = " << *search_config_
<< "\n";
if (flip_count == 0) {
VLOG(2) << "Maximum flip count has reached. ";
if (search_subscript_ + 1 < search_config_vec_.size()) {
VLOG(2) << "search_subscript_ = " << search_subscript_;
VLOG(2) << "search config vec size = " << search_config_vec_.size();
search_config_ = &search_config_vec_[++search_subscript_];
} else {
return *loc;
}
}
// Reload the 16-23 digits of the configuration, which controls how
// frequently a configuration should be flipped.
search_config_ += (search_config_ >> 16) & 255;
VLOG(2) << "Updating Flipping configuration = " << search_config_ << "\n";
auto flip_stride = ConditionalCodeMotion::flip_stride(*search_config_);
*search_config_ += flip_stride;
VLOG(2) << "flip stride = " << flip_stride << "\n";
VLOG(2) << "Updating Flipping Stride = " << *search_config_ << "\n";
flipped_[loc] = *loc;
// Copy the last 8 bits back to the first 8 bits of configuration.
@ -878,14 +887,23 @@ class GroupConnectedBoundaries {
HloInstruction* conditional, bool is_layout_sensitive,
absl::flat_hash_map<HloInstruction*, int>& visited_count,
std::vector<std::vector<int64>>* move_config,
std::vector<std::vector<int64>>* reuse_config, int* search_config)
std::vector<std::vector<int64>>* reuse_config,
std::vector<int64>* search_config)
: conditional_(conditional),
conditional_parent_(conditional->parent()),
is_layout_sensitive_(is_layout_sensitive),
visited_count_(visited_count),
move_config_(*move_config),
reuse_config_(*reuse_config),
search_config_(*search_config) {}
search_config_vec_(*search_config),
search_subscript_(0) {
VLOG(2) << "Initializing Group Connected Boundaries\n";
CHECK_NE(search_config, nullptr);
if (search_config_vec_.empty()) {
search_config_vec_.push_back(0);
}
search_config_ = &search_config_vec_[0];
}
// Returns estimation of potential reuses carried by a given pair of
// instructions. Use different integers to classify different levels
// of reuses. Assume all instructions can be fused to enable data reuses.
@ -896,7 +914,7 @@ class GroupConnectedBoundaries {
// When flipping, use -10 if flipping to the default reuse model. Other
// values can be specified if needed to fine-control the decision making.
int64 config =
(search_config_ < 0)
((*search_config_) < 0)
? FlipMutation(&curconfig[static_cast<uint32>(user->opcode())], -10,
HloOpcodeString(op->opcode()) + "->" +
HloOpcodeString(user->opcode()))
@ -947,13 +965,14 @@ class GroupConnectedBoundaries {
: 0;
VLOG(2) << "column = " << col << "\n";
VLOG(2) << "config size = " << curconfig.size() << "\n";
VLOG(2) << "search_config = " << search_config_ << "\n";
VLOG(2) << "search_config = " << *search_config_ << "\n";
CHECK(col < curconfig.size());
uint32 config = (search_config_ > 0)
uint32 config = ((*search_config_) > 0)
? FlipMutation(&curconfig[col], 1,
"Move-" + HloOpcodeString(opcode))
: curconfig[col];
VLOG(2) << "Checking instruction is worth moving: " << config << "\n";
VLOG(2) << "after checking search_config = " << *search_config_ << "\n";
return (config != 0);
}
@ -1026,7 +1045,8 @@ class GroupConnectedBoundaries {
return 0;
}
int64 BenefitForMovingBoundaries(const std::vector<Boundary>& boundaries) {
int64 BenefitForMovingBoundaries(const std::vector<Boundary>& boundaries,
bool perform_reuse_analysis = true) {
int64 reuses_before = 0, reuses_after = 0;
if (boundaries.size() == 1) {
if (boundaries[0].IsOutsideBranch() &&
@ -1042,7 +1062,7 @@ class GroupConnectedBoundaries {
}
}
// If trying alternative moving configurations, turn off reuse analysis.
if (search_config_ > 0) {
if (!perform_reuse_analysis) {
return 1;
}
// For cases like :
@ -1247,7 +1267,8 @@ ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion(
auto move_in_or_out =
connect.BoundariesToMoveInOrOut(conditional, cur_boundary);
if (!move_in_or_out.empty()) {
auto benefit = connect.BenefitForMovingBoundaries(move_in_or_out);
auto benefit = connect.BenefitForMovingBoundaries(
move_in_or_out, search_config_map_.empty());
VLOG(2) << "benefit of moving in or out "
<< cur_boundary.operands()[0]->ToString() << ":" << benefit << "\n";
if (benefit >= 0) {
@ -1289,10 +1310,6 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
TF_ASSIGN_OR_RETURN(auto cleanup_changed_now, subpipeline.Run(module));
cleanup_changed |= cleanup_changed_now;
}
// set the default configuration
VLOG(2) << "Obtaining default configuration\n";
SetDefaultMoveConfig();
VLOG(2) << "Done obtaining default configuration\n";
// Gather all the conditional ops in the module ahead of time, to avoid
// potential complications of modifying the code that affecting traversal.
std::vector<HloInstruction*> conditional_ops;
@ -1328,9 +1345,23 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
}
}
int64 conditional_index = 0;
// Use to collect mappings between cloned instructions.
HloCloneContext clone_context(module);
for (HloInstruction* conditional : conditional_ops) {
if (conditional_index == 0 || !search_config_map_.empty()) {
auto config_entry = search_config_map_.find(conditional_index);
if (config_entry != search_config_map_.end()) {
search_config_ = (*config_entry).second;
VLOG(2) << "config entry value extracted:" << search_config_.size();
search_config_index_ = 0;
}
VLOG(2) << "Obtaining default configuration for conditional "
<< conditional_index << "\n";
SetDefaultMoveConfig();
VLOG(2) << "Done obtaining default configuration\n";
conditional_index++;
}
int branch_count = conditional->branch_count();
// check for shared conditional computations
bool conditional_is_shared = false;
@ -1518,11 +1549,27 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
}
void ConditionalCodeMotion::SetDefaultMoveConfig() {
int tuning_option = (search_config_ == 0) ? 0 : (search_config_ > 0) ? 1 : 2;
VLOG(2) << "search_config_index = " << search_config_index_ << "\n";
VLOG(2) << "search_config_ size = " << search_config_.size() << "\n";
int64 cur_search_config = (search_config_index_ < 0 ||
search_config_index_ >= search_config_.size())
? 0
: search_config_[search_config_index_];
enum class TuningOption {
kDoNotTune = 0,
kTuneTransformationDecision = 1,
kTuneReuseModel = 2,
};
TuningOption tuning_option =
(cur_search_config == 0) ? TuningOption::kDoNotTune
: (cur_search_config > 0) ? TuningOption::kTuneTransformationDecision
: TuningOption::kTuneReuseModel;
auto row = HloOpcodeCount();
auto col = row;
VLOG(2) << "Start setting default configuration\n";
reuse_config_.clear();
move_config_.clear();
reuse_config_.reserve(row);
move_config_.reserve(row);
for (int64 opcode = 0; opcode < row; ++opcode) {
@ -1535,14 +1582,15 @@ void ConditionalCodeMotion::SetDefaultMoveConfig() {
reuse_config_.push_back(reuse_vec);
std::vector<int64> move_vec;
switch (tuning_option) {
case 1:
case TuningOption::kTuneTransformationDecision:
// Tuning transformation decision --- start with all yes.
// Only a single entry is needed if we don't consider operands of an op
// when searching/tuning transformation decisions.
move_vec.push_back(1);
break;
case 2: // Tune the ReusesCarriedBy results only.
case 0:
// Tune the ReusesCarriedBy results only.
case TuningOption::kTuneReuseModel:
case TuningOption::kDoNotTune:
// No tuning --- use the default configuration.
// Use the opcode of first operand to configure default.
move_vec.reserve(col);
@ -1556,6 +1604,36 @@ void ConditionalCodeMotion::SetDefaultMoveConfig() {
}
}
// The search configuration is specified using a string in the format of
// 'config1;config2; ...;config_n', where each config_i is in the format of
// 'index,start,max,stride' (four integers separated by comma), which specify
// the index number of the conditional being configured, the index of the first
// transformation decision to flip for the conditional, the max number of
// decisions to flip, and how many decisions to skip in between the flips.
void ConditionalCodeMotion::ParseSearchConfiguration(
const std::string& search_config) {
if (search_config.empty()) {
return;
}
search_config_index_ = 0;
std::vector<std::string> configs = absl::StrSplit(search_config, ';');
for (const std::string& config : configs) {
std::vector<std::string> specs = absl::StrSplit(config, ',');
CHECK_EQ(specs.size(), 4);
int64 condition_index;
CHECK(absl::SimpleAtoi(specs[0], &condition_index));
auto& cur_config_entry = search_config_map_[condition_index];
int64 flip_start, max_flip, flip_stride;
CHECK(absl::SimpleAtoi(specs[1], &flip_start));
CHECK(absl::SimpleAtoi(specs[2], &max_flip));
CHECK(absl::SimpleAtoi(specs[3], &flip_stride));
int64 cur_config = MakeSearchConfig(flip_start, max_flip, flip_stride);
cur_config_entry.push_back(cur_config);
VLOG(2) << "Setting search config " << condition_index << "->" << cur_config
<< "\n";
}
}
} // namespace conditional_opt
} // namespace xla

View File

@ -95,12 +95,69 @@ class ConditionalCodeMotion : public HloModulePass {
// start over when optimizing a new model.
  // Constructs the pass with a single packed search configuration value.
  // A nonzero search_config enables tuning and is applied to the first
  // conditional (index 0) encountered in the module; it also disables the
  // full-code-motion special case.
  explicit ConditionalCodeMotion(bool is_layout_sensitive,
                                 bool pursue_full_conditional_code_motion,
                                 int64 search_config = 0)
      : is_layout_sensitive_(is_layout_sensitive),
        pursue_full_conditional_code_motion_(
            /*turn off special case if tuning*/
            pursue_full_conditional_code_motion && search_config == 0),
        search_config_index_(0) {
    search_config_.push_back(search_config);
    if (search_config != 0) {
      // Register the single configuration for conditional index 0.
      search_config_map_[0] = search_config_;
    }
  }
explicit ConditionalCodeMotion(bool is_layout_sensitive,
bool pursue_full_conditional_code_motion,
std::string search_config)
: is_layout_sensitive_(is_layout_sensitive),
pursue_full_conditional_code_motion_(
/*turn off special case if tuning*/
pursue_full_conditional_code_motion && search_config.empty()),
search_config_index_(-1) {
ParseSearchConfiguration(search_config);
}
  // Parses a string consisting of ';'-separated entries, each of the form
  // 'i,s,m,t', into a list of transformation search configurations. Each
  // entry is converted via MakeSearchConfig(s, m, t) and is used for the
  // i-th conditional encountered when optimizing a given module.
  void ParseSearchConfiguration(const std::string& search_config);
// Make a single search configuration for changing transformation decisions:
// flip the decisions at position n = flip_start + flip_stride * m, and
// m = 0..max_flip.
// The following defines how the int64 search configuration is composed, as
// flip_start + (flip_max << kMaxPos) + (flip_stride << kStridePos).
// Position (digit) for maximum number of flips.
static constexpr int kMaxPos = 16;
// Position (digit) for the count-down to the first flip.
static constexpr int kStartPos = 0;
// Position (digit) for the count-down to the next flip.
static constexpr int kStridePos = 32;
// Bit mask for extracting the last digits of value.
static constexpr int kValueMask = 0xffff;
static int64 MakeSearchConfig(int64 start, int64 max, int64 stride) {
const int64 config =
(max << kMaxPos) + (start << kStartPos) + (stride << kStridePos);
VLOG(2) << "flip stride = " << flip_stride(config) << "\n";
VLOG(2) << "flig config = " << config << "\n";
return config;
}
  // Extracts the count-down to the first flip (bits [0, 16)) from a packed
  // search configuration.
  // NOTE(review): the 16-bit field is narrowed to signed int16, so raw field
  // values >= 1 << 15 would wrap negative — assumes start counts stay small;
  // confirm against the tuner's input range.
  static int16 flip_start(int64 search_config) {
    return (search_config >> kStartPos) & kValueMask;
  }
  // Extracts the flip stride (bits [32, 48)) from a packed search
  // configuration.
  // NOTE(review): as with flip_start, the 16-bit field is narrowed to signed
  // int16 — values >= 1 << 15 would wrap negative; confirm strides stay small.
  static int16 flip_stride(int64 search_config) {
    return (search_config >> kStridePos) & kValueMask;
  }
  // Returns the current maximum-flip count (bits [16, 32)) of *search_config
  // and, when it is still positive, decrements the stored count in place.
  // A returned value of 0 means the flip budget is exhausted.
  static int16 DecrementMaxFlip(int64* search_config) {
    const int16 max_flip = ((*search_config) >> kMaxPos) & kValueMask;
    // Decrement flip count so we can stop if it reaches 0.
    if (max_flip > 0) {
      *search_config -= (1 << kMaxPos);
    }
    return max_flip;
  }
  // HloModulePass interface: the name this pass reports in pipeline dumps.
  absl::string_view name() const override { return "conditional-code-motion"; }
  // Runs conditional code motion over `module` (standard HloModulePass entry
  // point; presumably returns whether the module changed — per pass contract).
  StatusOr<bool> Run(HloModule* module) override;
@ -134,7 +191,14 @@ class ConditionalCodeMotion : public HloModulePass {
 private:
  const bool is_layout_sensitive_;
  const bool pursue_full_conditional_code_motion_;
  // The following parameterizes the transformation decisions and cost model.
  // Sequence of packed search configurations (see MakeSearchConfig) for the
  // conditional currently being tuned.
  std::vector<int64> search_config_;
  // Index into search_config_ of the configuration currently applied; a
  // negative or out-of-range value selects the default (no-tuning) config.
  int64 search_config_index_;
  // Map each conditional to a vector of its search configurations. The key of
  // the map is the index number of the conditional in a module when traversed
  // in post order, and the value of the map is the sequence of search
  // configurations specified with the same index number for the conditional.
  absl::flat_hash_map<int64, std::vector<int64>> search_config_map_;
  // Per-opcode tables parameterizing the move decisions and the reuse model.
  std::vector<std::vector<int64>> move_config_, reuse_config_;
StatusOr<bool> MoveInstructionOut(HloInstruction* conditional,

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/conditional_code_motion.h"
#include <sstream>
#include <string>
#include <utility>
@ -1279,8 +1280,8 @@ ENTRY main {
for (int flip_start = 0; flip_start < 7; ++flip_start) {
// Start flipping at index config, repeat thereafter, until reaching
// max.
uint32 search_config =
(max_flip << 8) + flip_start + (flip_stride << 16);
int64 search_config = ConditionalCodeMotion::MakeSearchConfig(
flip_start, max_flip, flip_stride);
ConditionalCodeMotion pass(true, true, search_config);
VLOG(1) << "Testing max_flip=" << max_flip
<< "; flip_start = " << flip_start
@ -1362,6 +1363,150 @@ ENTRY main {
}
}
TEST_F(ConditionalCodeMotionTest, TestMultipleConfigurationFlags) {
  absl::string_view hlo_string =
      R"(
HloModule RemoveDotOpOut
on_true {
%arg_tuple.1 = (f32[93184,4]{1,0}) parameter(0)
%get-tuple-element.1 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.1), index=0
%reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1)
%convert.2894 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.8493)
ROOT %tuple.1 = ( bf16[2,512,364]{2,1,0}) tuple(%convert.2894)
}
on_false {
%arg_tuple.2 = (f32[93184,4]{1,0}) parameter(0)
%get-tuple-element.3 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.2), index=0
%reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.3)
%convert.3604 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.9717), metadata={op_type="Cast" op_name="gradients/Cast_125_grad/Cast"}
ROOT %tuple.2 = (bf16[2,512,364]{2,1,0}) tuple(%convert.3604)
}
ENTRY main {
pred.1 = pred[] parameter(0)
arg_tuple.11 = (f32[93184,4]{1,0}) parameter(1)
arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2)
pred.2 = pred[] parameter(3)
conditional = (bf16[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false
get-first-index = bf16[2,512,364]{2,1,0} get-tuple-element(conditional), index=0
add.1 = bf16[2,512,364]{2,1,0} add(bf16[2,512,364]{2,1,0} get-first-index, bf16[2,512,364]{2,1,0} get-first-index)
conditional.2 = (bf16[2,512,364]{2,1,0}) conditional(pred.2, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false
get-first-index.2 = bf16[2,512,364]{2,1,0} get-tuple-element(conditional.2), index=0
add.2 = bf16[2,512,364]{2,1,0} add(bf16[2,512,364]{2,1,0} get-first-index.2, bf16[2,512,364]{2,1,0} get-first-index.2)
ROOT result = (bf16[2,512,364]{2,1,0}) tuple(add.1, add.2)
}
)";
  // Use a config loop to tune which instructions should be moved/not_moved.
  for (int max_flip = 1; max_flip < 3; ++max_flip) {
    for (int flip_stride = 1; flip_stride < ((max_flip > 1) ? 7 : 2);
         ++flip_stride) {
      for (int flip_start = 0; flip_start < 7; ++flip_start) {
        // Generate two search strings separated by ';', one per conditional.
        std::stringstream config_stream;
        config_stream << 0 << "," << flip_start << "," << max_flip << ","
                      << flip_stride << ";";
        config_stream << 1 << "," << flip_start << "," << max_flip << ","
                      << flip_stride;
        auto search_config = config_stream.str();
        ConditionalCodeMotion pass(true, true, search_config);
        VLOG(1) << "Testing max_flip=" << max_flip
                << "; flip_start = " << flip_start
                << "; flip_stride = " << flip_stride
                << "; search_config=" << search_config;
        auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
        bool opt_result = pass.Run(module.get()).ValueOrDie();
        // Turning off the first/second decision will disable moving out;
        // Turning off the following decision will again disable moving in.
        if (flip_start < 2 && max_flip > 1 && flip_stride == 1) {
          // If the next decision is false, no moving in is allowed either.
          // Use gtest assertions (rather than CHECK) so a mismatch is
          // reported as a test failure instead of crashing the test binary.
          ASSERT_FALSE(opt_result);
          continue;
        }
        ASSERT_TRUE(opt_result);
        const HloInstruction* conditional =
            FindInstruction(module.get(), "conditional");
        ASSERT_NE(conditional, nullptr);
        const HloComputation* on_true = conditional->branch_computation(0);
        const HloComputation* on_false = conditional->branch_computation(1);
        HloInstruction* root = module->entry_computation()->root_instruction();
        switch (flip_start) {
          case 0:
            TF_FALLTHROUGH_INTENDED;
          case 1:
            // After flipping the corresponding decisions,
            // instructions have been moved inside the conditionals.
            ASSERT_EQ(on_true->instruction_count(), 6);
            ASSERT_EQ(on_false->instruction_count(), 6);
            EXPECT_THAT(
                root, AllOf(op::Tuple(op::GetTupleElement(op::Conditional()),
                                      op::GetTupleElement(op::Conditional()))));
            break;
          case 2:
            // The 2nd decision has been flipped. Reshape was not moved out.
            ASSERT_EQ(on_true->instruction_count(), 4);
            ASSERT_EQ(on_false->instruction_count(), 4);
            EXPECT_THAT(
                root,
                AllOf(op::Tuple(
                    op::Add(
                        op::Convert(op::GetTupleElement(op::Conditional())),
                        op::Convert(op::GetTupleElement(op::Conditional()))),
                    op::Add(
                        op::Convert(op::GetTupleElement(op::Conditional())),
                        op::Convert(op::GetTupleElement(op::Conditional()))))));
            break;
          case 3:
            // The 3rd decision has been flipped. GTE was not moved out. The
            // GTE is then merged with the tuple op of the new root in later
            // cleanup.
            ASSERT_EQ(on_true->instruction_count(), 1);
            ASSERT_EQ(on_false->instruction_count(), 1);
            EXPECT_THAT(
                root, AllOf(op::Tuple(
                          op::Add(op::Convert(op::Reshape(
                                      op::GetTupleElement(op::Conditional()))),
                                  op::Convert(op::Reshape(
                                      op::GetTupleElement(op::Conditional())))),
                          op::Add(op::Convert(op::Reshape(
                                      op::GetTupleElement(op::Conditional()))),
                                  op::Convert(op::Reshape(op::GetTupleElement(
                                      op::Conditional())))))));
            break;
          case 4:
          case 5:
          case 6:
            // The 4th decision has been flipped. Parameter was not moved out.
            // Each conditional has the parameter and the new root.
            ASSERT_EQ(on_true->instruction_count(), 2);
            ASSERT_EQ(on_false->instruction_count(), 2);
            EXPECT_THAT(
                root,
                AllOf(op::Tuple(
                    op::Add(op::Convert(op::Reshape(op::GetTupleElement(
                                op::GetTupleElement(op::Conditional())))),
                            op::Convert(op::Reshape(op::GetTupleElement(
                                op::GetTupleElement(op::Conditional()))))),
                    op::Add(op::Convert(op::Reshape(op::GetTupleElement(
                                op::GetTupleElement(op::Conditional())))),
                            op::Convert(op::Reshape(op::GetTupleElement(
                                op::GetTupleElement(op::Conditional()))))))));
            break;
          default:  // The default cost model is used.
            ASSERT_EQ(on_true->instruction_count(), 1);
            ASSERT_EQ(on_false->instruction_count(), 1);
            EXPECT_THAT(root, AllOf(op::Tuple(op::Add(
                                  op::Convert(op::Reshape(
                                      op::GetTupleElement(op::Conditional()))),
                                  op::Convert(op::Reshape(op::GetTupleElement(
                                      op::Conditional())))))));
            break;
        }
      }
    }
  }
}
} // namespace conditional_opt
} // namespace xla