[XLA] Expose configuration space of conditional code motion to support auto-tuning of its cost model.

PiperOrigin-RevId: 351858804
Change-Id: Iaa4abf304af21cba88c09d1887ec93671a570b8a
A. Unique TensorFlower 2021-01-14 12:54:20 -08:00 committed by TensorFlower Gardener
parent 5b5396b4b2
commit c1336b952d
4 changed files with 392 additions and 75 deletions
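For reference, the search_config parameter introduced by this change packs all tuning knobs into a single int32, as documented in the ConditionalCodeMotion header comment below: the sign selects which table is tuned, the low 8 bits give the index of the first decision to flip, bits 8-15 the maximum number of flips, and bits 16-23 the flip stride. The standalone sketch below (illustrative helper names, not part of this commit) shows one way to encode and decode such a value; it mirrors the (flip_stride << 16) + (max_flip << 8) + flip_start expression used by the new test at the end of this diff.

// Standalone sketch of the search_config bit layout. EncodeSearchConfig is an
// illustrative helper, not an XLA API.
#include <cstdint>
#include <cstdlib>
#include <iostream>

// Packs (start, max_flips, stride) into the low 24 bits. Negating the result
// requests tuning of reuse_config_ instead of move_config_.
int32_t EncodeSearchConfig(int start, int max_flips, int stride,
                           bool tune_reuse_table) {
  const int32_t config = (stride << 16) | (max_flips << 8) | start;
  return tune_reuse_table ? -config : config;
}

int main() {
  const int32_t config = EncodeSearchConfig(/*start=*/2, /*max_flips=*/1,
                                            /*stride=*/1,
                                            /*tune_reuse_table=*/false);
  const int32_t magnitude = std::abs(config);
  std::cout << "search_config = " << config << "\n"
            << "first flip at index " << (magnitude & 255) << "\n"
            << "max flips           " << ((magnitude >> 8) & 255) << "\n"
            << "flip stride         " << ((magnitude >> 16) & 255) << "\n";
  return 0;
}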


@@ -2396,6 +2396,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"@com_google_absl//absl/flags:flag",
],
)


@@ -113,16 +113,14 @@ int64 CountNonLeafOps(const OpCollection& ops) {
// instructions. Use different integers to classify different levels
// of reuses. This is used as a placeholder only, assuming all
// instructions can be fused to enable data reuses.
int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) {
int64 ReusesCarriedBy(HloOpcode op, HloOpcode user) {
// Reuses in some way work like forces that pull instructions
// towards each other. We use a number 0-10 to classify how strong the force
// is between a pair of operations. Given a group of instructions that can be
// moved together, if the forces inside a conditional are stronger, the group
will be moved inside or remain inside the conditional; otherwise, it will
// be moved outside to or remain outside of the conditional.
VLOG(2) << "ConditionalCodeMotion: Add reuses carried by instr: "
<< op->ToString() << "=>" << user->ToString() << "\n";
switch (user->opcode()) {
switch (user) {
case HloOpcode::kGetTupleElement:
return 0;
case HloOpcode::kConvert:
@@ -130,7 +128,7 @@ int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) {
// convolution, here if op is dot or convolution, they must be separated
// by a conditional boundary. Here we do not try to pull convert inside
// conditionals to be together with the dot or convolution.
switch (op->opcode()) {
switch (op) {
case HloOpcode::kConvolution:
case HloOpcode::kDot:
return 0;
@@ -141,7 +139,7 @@ int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) {
default:
break;
}
switch (op->opcode()) {
switch (op) {
// These instructions do not carry weight of reuse themselves.
case HloOpcode::kParameter:
case HloOpcode::kConstant:
@@ -149,12 +147,57 @@ int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) {
return 0;
case HloOpcode::kConditional:
return 10;
default: {
// Assume the reuse decreases with increasing user count.
int count1 = CountNonLeafOps(op->users());
int count2 = CountNonLeafOps(user->operands());
return 10 / count1 / count2;
default:
return -10;
}
}
// Returns true if `op` is worth hoisting.
bool WorthHoisting(HloOpcode op, HloOpcode child_op) {
// TODO(b/169182921): The following cost model is rather incomplete and will
// need to be extended to cover most element-wise ops.
switch (op) {
case HloOpcode::kConvert:
// If Convert is after AllReduce, it is worth moving the AllReduce out of
// the conditional for AR/CRS combine. If Convert is after other ops such
// as Dot or Convolution, it is better to keep the convert inside the
// conditional so that it can be fused with the Dot or Convolution.
switch (child_op) {
case HloOpcode::kAllReduce:
case HloOpcode::kReshape:
case HloOpcode::kGetTupleElement:
return true;
default:
return false;
}
case HloOpcode::kGetTupleElement:
switch (child_op) {
// do not move GTE if its operand is a parameter
case HloOpcode::kParameter:
return false;
default:
return true;
}
case HloOpcode::kAllReduce:
case HloOpcode::kAbs:
case HloOpcode::kReduce:
case HloOpcode::kAdd:
case HloOpcode::kPower:
case HloOpcode::kCopy:
case HloOpcode::kConstant:
case HloOpcode::kSubtract:
case HloOpcode::kMultiply:
case HloOpcode::kDivide:
case HloOpcode::kTuple:
case HloOpcode::kSqrt:
case HloOpcode::kRsqrt:
case HloOpcode::kReshape:
case HloOpcode::kMinimum:
case HloOpcode::kMaximum:
return true;
default:
return false;
}
}
@@ -579,8 +622,6 @@ StatusOr<bool> ConditionalCodeMotion::MoveInstructionOut(
HloInstruction* new_root =
conditional->branch_computation(0)->root_instruction();
*conditional->mutable_shape() = new_root->shape();
//
VLOG(1) << "done moving instructions out of branches\n"
<< conditional_parent->ToString(HloPrintOptions::Fingerprint())
<< "\n";
@@ -772,15 +813,105 @@ class GroupConnectedBoundaries {
bool is_layout_sensitive_;
// Instructions that have been visited but are not going to be moved.
absl::flat_hash_map<HloInstruction*, int>& visited_count_;
// The following four lines are configurations of the cost model, which will
// be used to determine whether to move an instruction (move_config_) and how
// strongly preferred it is to keep a pair of ops together (reuse_config_).
// The search_config_ is used to control how to navigate the search space of
// the cost model in the context of auto/manual tuning. The flipped_ map
// records which entries in the configuration have been changed during the
// search/tuning process.
std::vector<std::vector<int64>>& move_config_;
std::vector<std::vector<int64>>& reuse_config_;
int& search_config_;
absl::flat_hash_map<const int64*, int64> flipped_;
// The FlipMutation function implements the search over alternative cost
// models by deciding whether to flip a given configuration entry, pointed to
// by the loc parameter. The non_zero parameter provides the new value to use
// when flipping a zero. The msg parameter is only used for debugging purposes.
int64 FlipMutation(int64* loc, const int64 non_zero, const std::string& msg) {
if (search_config_ == 0 || ContainsKey(flipped_, loc)) {
VLOG(2) << "Configured not to search or loc is already flipped.";
return *loc;
}
// Bits 8-15 control the maximum number of times to flip a config.
int flip_count = (search_config_ >> 8) & 255;
if (flip_count == 0) {
VLOG(2) << "Maximum flip count has reached. ";
return *loc;
}
// The low 8 bits control when to perform the first flip.
int c = search_config_ & 255;
VLOG(2) << "flip start index = " << c << "\n";
// Only flip the decision if c reaches 0.
if (c > 0) {
search_config_--;
return *loc;
}
// Decrement flip count so we can stop if it reaches 0.
search_config_ -= 256;
// Reload bits 16-23 of the configuration, which control how
// frequently a configuration should be flipped.
search_config_ += (search_config_ >> 16) & 255;
VLOG(2) << "Updating Flipping configuration = " << search_config_ << "\n";
flipped_[loc] = *loc;
// Flip the entry pointed to by loc: a zero becomes non_zero; anything else becomes 0.
switch (*loc) {
case 0:
*loc = non_zero;
break;
default:
*loc = 0;
break;
}
VLOG(2) << "Flipping decision for: " << msg << ": from " << flipped_[loc]
<< " to " << *loc << "\n";
return *loc;
}
public:
explicit GroupConnectedBoundaries(
HloInstruction* conditional, bool is_layout_sensitive,
absl::flat_hash_map<HloInstruction*, int>& visited_count)
absl::flat_hash_map<HloInstruction*, int>& visited_count,
std::vector<std::vector<int64>>* move_config,
std::vector<std::vector<int64>>* reuse_config, int* search_config)
: conditional_(conditional),
conditional_parent_(conditional->parent()),
is_layout_sensitive_(is_layout_sensitive),
visited_count_(visited_count) {}
visited_count_(visited_count),
move_config_(*move_config),
reuse_config_(*reuse_config),
search_config_(*search_config) {}
// Returns an estimate of the potential reuses carried by a given pair of
// instructions. Use different integers to classify different levels
// of reuses. Assume all instructions can be fused to enable data reuses.
int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) {
std::vector<int64>& curconfig =
reuse_config_[static_cast<uint32>(op->opcode())];
// Flip the reuse configuration if tuning the cost model.
// When flipping, use -10 if flipping to the default reuse model. Other
// values can be specified if needed to fine-control the decision making.
int64 config =
(search_config_ < 0)
? FlipMutation(&curconfig[static_cast<uint32>(user->opcode())], -10,
HloOpcodeString(op->opcode()) + "->" +
HloOpcodeString(user->opcode()))
: curconfig[static_cast<uint32>(user->opcode())];
VLOG(2) << "ConditionalCodeMotion: Add reuses carried by instr: "
<< op->ToString() << "=>" << user->ToString() << " : " << config
<< "\n";
if (config < 0) {
// Assume the reuse decreases with increasing user count.
int count1 = CountNonLeafOps(op->users());
int count2 = CountNonLeafOps(user->operands());
return (-config) / count1 / count2;
}
return config;
}
void clear_recently_visited() {
for (const auto& boundary : new_boundaries_) {
visited_count_.erase(boundary.operands()[0]);
@@ -791,63 +922,41 @@ class GroupConnectedBoundaries {
// This is needed for the "moving-in" transformation, to prevent the root
// of the parent computation (which contains the conditional) from being moved
// inside the conditional.
if (instruction->opcode() == HloOpcode::kTuple &&
HloOpcode opcode = instruction->opcode();
if (opcode == HloOpcode::kTuple &&
instruction == conditional_parent_->root_instruction()) {
return false;
}
// TOOD[b/169182921] The following cost model is rather incomplete. Will
// need to extend to cover most of element-wise ops.
switch (instruction->opcode()) {
case HloOpcode::kConvert:
// If Convert is after AllReduce, it is worth moving out AllReduce
// out of conditional for AR/CRS combine. If Convert is after other
// ops such as Dot or Convolutional, it is better to keep convert
// within conditional so that convert can be fused with Dot or
// Convolutional.
switch (instruction->operand(0)->opcode()) {
case HloOpcode::kAllReduce:
case HloOpcode::kReshape:
case HloOpcode::kGetTupleElement:
return true;
default:
VLOG(2) << "Instruction is convert and its operand is not known to "
"be worth hoisting\n";
return false;
}
case HloOpcode::kGetTupleElement:
switch (instruction->operand(0)->opcode()) {
// do not move GTE if its operand is a parameter
case HloOpcode::kParameter:
return false;
default:
return true;
}
case HloOpcode::kAllReduce:
// It is not safe to move collective ops from outside to inside
// conditional branches, as it may cause synchronization problems,
// when different layouts are assigned to different branches.
return is_inside_branch;
case HloOpcode::kAbs:
case HloOpcode::kReduce:
case HloOpcode::kAdd:
case HloOpcode::kPower:
case HloOpcode::kCopy:
case HloOpcode::kConstant:
case HloOpcode::kSubtract:
case HloOpcode::kMultiply:
case HloOpcode::kDivide:
case HloOpcode::kTuple:
case HloOpcode::kSqrt:
case HloOpcode::kRsqrt:
case HloOpcode::kReshape:
case HloOpcode::kMinimum:
case HloOpcode::kMaximum:
return true;
default:
VLOG(2) << "Instruction is not known to be worth hoisting\n";
if (opcode == HloOpcode::kAllReduce && !is_inside_branch) {
return false;
}
// It is not legal to move the parameter instructions.
if (opcode == HloOpcode::kParameter) {
return false;
}
// Use configuration given from outside (e.g., by autotuner).
std::vector<int64>& curconfig = move_config_[static_cast<uint32>(opcode)];
auto col = (curconfig.size() == 1) ? 0
: (instruction->operand_count() > 0)
? static_cast<uint32>(instruction->operand(0)->opcode())
: 0;
VLOG(2) << "column = " << col << "\n";
VLOG(2) << "config size = " << curconfig.size() << "\n";
VLOG(2) << "search_config = " << search_config_ << "\n";
CHECK(col < curconfig.size());
uint32 config = (search_config_ > 0)
? FlipMutation(&curconfig[col], 1,
"Move-" + HloOpcodeString(opcode))
: curconfig[col];
VLOG(2) << "Checking instruction is worth moving: " << config << "\n";
return (config != 0);
}
int64 ReusesBeforeBoundary(HloInstruction* user) {
int64 reuses = 0;
for (auto op : user->operands()) {
@@ -919,11 +1028,23 @@ class GroupConnectedBoundaries {
int64 BenefitForMovingBoundaries(const std::vector<Boundary>& boundaries) {
int64 reuses_before = 0, reuses_after = 0;
if (boundaries.size() == 1 && boundaries[0].IsOutsideBranch() &&
boundaries[0].operands()[0]->opcode() == HloOpcode::kGetTupleElement) {
if (boundaries.size() == 1) {
if (boundaries[0].IsOutsideBranch() &&
boundaries[0].operands()[0]->opcode() ==
HloOpcode::kGetTupleElement) {
// The only boundary of moving-in is the get_tuple_element op.
return -1;
}
if (boundaries[0].IsInsideBranch() &&
boundaries[0].operands()[0]->opcode() == HloOpcode::kTuple) {
// The only boundary of moving-out is the tuple op inside branches.
return -1;
}
}
// If trying alternative moving configurations, turn off reuse analysis.
if (search_config_ > 0) {
return 1;
}
// For cases like :
// branch0 {
// ROOT copy
@@ -1121,7 +1242,8 @@ ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion(
std::vector<Boundary>& to_move, std::vector<Boundary>& new_boundaries,
absl::flat_hash_map<HloInstruction*, int>& visited_count) {
GroupConnectedBoundaries connect(conditional, is_layout_sensitive_,
visited_count);
visited_count, &move_config_, &reuse_config_,
&search_config_);
auto move_in_or_out =
connect.BoundariesToMoveInOrOut(conditional, cur_boundary);
if (!move_in_or_out.empty()) {
@@ -1167,6 +1289,10 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
TF_ASSIGN_OR_RETURN(auto cleanup_changed_now, subpipeline.Run(module));
cleanup_changed |= cleanup_changed_now;
}
// set the default configuration
VLOG(2) << "Obtaining default configuration\n";
SetDefaultMoveConfig();
VLOG(2) << "Done obtaining default configuration\n";
// Gather all the conditional ops in the module ahead of time, to avoid
// potential complications of modifying the code in ways that affect the traversal.
std::vector<HloInstruction*> conditional_ops;
@@ -1390,6 +1516,46 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
}
return changed;
}
void ConditionalCodeMotion::SetDefaultMoveConfig() {
int tuning_option = (search_config_ == 0) ? 0 : (search_config_ > 0) ? 1 : 2;
auto row = HloOpcodeCount();
auto col = row;
VLOG(2) << "Start setting default configuration\n";
reuse_config_.reserve(row);
move_config_.reserve(row);
for (int64 opcode = 0; opcode < row; ++opcode) {
// Default reuse values between this opcode and every possible user opcode.
std::vector<int64> reuse_vec(col, 0);
for (uint32 j = 0; j < col; ++j) {
reuse_vec[j] = ReusesCarriedBy(static_cast<HloOpcode>(opcode),
static_cast<HloOpcode>(j));
}
reuse_config_.push_back(reuse_vec);
std::vector<int64> move_vec;
switch (tuning_option) {
case 1:
// Tuning transformation decision --- start with all yes.
// Only a single entry is needed if we don't consider operands of an op
// when searching/tuning transformation decisions.
move_vec.push_back(1);
break;
case 2: // Tune the ReusesCarriedBy results only.
case 0:
// No tuning --- use the default configuration.
// Use the opcode of first operand to configure default.
move_vec.reserve(col);
for (uint32 j = 0; j < col; ++j) {
move_vec.push_back(WorthHoisting(static_cast<HloOpcode>(opcode),
static_cast<HloOpcode>(j)));
}
break;
}
move_config_.push_back(move_vec);
}
}
} // namespace conditional_opt
} // namespace xla
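
To make the table layout concrete: SetDefaultMoveConfig above builds one row per HLO opcode, where move_config_[op][first_operand_op] holds a 0/1 move decision (defaulted from WorthHoisting) and reuse_config_[op][user_op] holds a reuse weight (defaulted from ReusesCarriedBy, with negative entries meaning "apply the user/operand-count heuristic"). The rough sketch below re-creates that layout with plain integers; every name and default value here is a placeholder, not XLA code.

// Self-contained model of the two configuration tables; DefaultMove and
// DefaultReuse are placeholders standing in for WorthHoisting/ReusesCarriedBy.
#include <cstdint>
#include <utility>
#include <vector>

constexpr int kNumOpcodes = 8;  // Stands in for HloOpcodeCount().

int64_t DefaultMove(int op, int first_operand_op) {
  return (op + first_operand_op) % 2;  // Arbitrary placeholder decision.
}
int64_t DefaultReuse(int op, int user_op) { return op == user_op ? 10 : -10; }

struct CostModelConfig {
  std::vector<std::vector<int64_t>> move, reuse;

  CostModelConfig() {
    move.reserve(kNumOpcodes);
    reuse.reserve(kNumOpcodes);
    for (int op = 0; op < kNumOpcodes; ++op) {
      std::vector<int64_t> move_row(kNumOpcodes), reuse_row(kNumOpcodes);
      for (int j = 0; j < kNumOpcodes; ++j) {
        move_row[j] = DefaultMove(op, j);    // Column: opcode of first operand.
        reuse_row[j] = DefaultReuse(op, j);  // Column: opcode of the user.
      }
      move.push_back(std::move(move_row));
      reuse.push_back(std::move(reuse_row));
    }
  }

  // A nonzero entry means the instruction may be moved across the boundary.
  bool WorthMoving(int op, int first_operand_op) const {
    return move[op][first_operand_op] != 0;
  }

  // Negative entries request the default heuristic: reuse decreases with the
  // number of non-leaf users and operands.
  int64_t Reuses(int op, int user_op, int user_count, int operand_count) const {
    const int64_t c = reuse[op][user_op];
    return c < 0 ? (-c) / user_count / operand_count : c;
  }
};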


@@ -68,15 +68,40 @@ class Boundary {
// and their properties are identical.
// - Only the identical ops that won't share operands with other ops will
// be moved out of conditional.
// The cost model of the code motion optimization consists of two components,
// represented by the move_config_ and reuse_config_ arrays of the optimization.
// The move_config_ array uses 1 vs 0 to dictate whether each HLO opcode, given
// the HLO opcode of its first operand, is allowed to move across any
// conditional boundary; the reuse_config_ array uses an integer to represent
// the force between each pair of HLO opcodes regarding how attractive it is to
// place these instructions together (both inside or both outside of a
// conditional). Both arrays are driven by HLO opcodes alone, regardless of
// where the operations are located in the module.
class ConditionalCodeMotion : public HloModulePass {
public:
// If is_layout_sensitive is true, then the hoist process preserves layout
// during identical comparison. Otherwise, layout is ignored.
// The search configuration is a single integer but is split into four parts:
// (sign, n, m, p), where n,m,p each occupy 8 bits and together make the 24
// bits at the end of the int32. For the sign part, if search_config is <0,
// the reuse_config_ cost model is modified (tuned); if search_config is >0,
// the move_config_ cost model is modified (tuned); if search_config == 0,
// the default cost model is used with no tuning. When tuning, the entries in
// the designated configuration array (move_config_ or reuse_config_) are
// flipped between 0 and another default integer, starting from the pth entry
// being queried by the optimization and repeated every nth time a new entry
// is visited, until a maximum of m entries have been changed. Tuning starts
// over when optimizing a new model.
explicit ConditionalCodeMotion(bool is_layout_sensitive,
bool pursue_full_conditional_code_motion)
bool pursue_full_conditional_code_motion,
int32 search_config = 0)
: is_layout_sensitive_(is_layout_sensitive),
pursue_full_conditional_code_motion_(
pursue_full_conditional_code_motion) {}
/*turn off special case if tuning*/
pursue_full_conditional_code_motion && search_config == 0),
search_config_(search_config) {}
absl::string_view name() const override { return "conditional-code-motion"; }
StatusOr<bool> Run(HloModule* module) override;
@@ -109,6 +134,8 @@ class ConditionalCodeMotion : public HloModulePass {
private:
const bool is_layout_sensitive_;
const bool pursue_full_conditional_code_motion_;
int32 search_config_;
std::vector<std::vector<int64>> move_config_, reuse_config_;
StatusOr<bool> MoveInstructionOut(HloInstruction* conditional,
std::vector<Boundary>& to_move_out,
@@ -116,6 +143,7 @@ class ConditionalCodeMotion : public HloModulePass {
StatusOr<bool> MoveInstructionIn(HloInstruction* conditional,
std::vector<Boundary>& to_move_in,
std::vector<Boundary>& new_boundaries);
void SetDefaultMoveConfig();
};
} // namespace conditional_opt
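
A minimal usage sketch of the extended constructor, mirroring the new unit test at the end of this diff (it assumes the conditional_code_motion header is on the include path and that an HloModule has already been parsed and verified; RunOneTrial is an illustrative wrapper, not part of this commit):

#include "tensorflow/compiler/xla/service/conditional_code_motion.h"

// One tuning trial over the move_config_ decisions: start flipping at entry
// `flip_start`, flip at most `max_flip` entries, re-flipping every `stride`
// queries. Pass a negative search_config to tune reuse_config_ instead.
xla::StatusOr<bool> RunOneTrial(xla::HloModule* module, int flip_start,
                                int max_flip, int stride) {
  const int32_t search_config = (stride << 16) | (max_flip << 8) | flip_start;
  xla::conditional_opt::ConditionalCodeMotion pass(
      /*is_layout_sensitive=*/true,
      /*pursue_full_conditional_code_motion=*/true, search_config);
  return pass.Run(module);
}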


@@ -279,6 +279,7 @@ ENTRY main {
const HloInstruction* conditional =
FindInstruction(module.get(), "conditional");
CHECK_NE(conditional, nullptr);
const HloComputation* on_true = conditional->branch_computation(0);
ASSERT_EQ(on_true->instruction_count(), 1);
const HloComputation* on_false = conditional->branch_computation(1);
@@ -1240,6 +1241,127 @@ ENTRY main {
op::Parameter())));
}
TEST_F(ConditionalCodeMotionTest, TestConfigurationFlag) {
absl::string_view hlo_string =
R"(
HloModule RemoveDotOpOut
on_true {
%arg_tuple.1 = (f32[93184,4]{1,0}) parameter(0)
%get-tuple-element.1 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.1), index=0
%reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1)
%convert.2894 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.8493)
ROOT %tuple.1 = ( bf16[2,512,364]{2,1,0}) tuple(%convert.2894)
}
on_false {
%arg_tuple.2 = (f32[93184,4]{1,0}) parameter(0)
%get-tuple-element.3 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.2), index=0
%reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.3)
%convert.3604 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.9717), metadata={op_type="Cast" op_name="gradients/Cast_125_grad/Cast"}
ROOT %tuple.2 = (bf16[2,512,364]{2,1,0}) tuple(%convert.3604)
}
ENTRY main {
pred.1 = pred[] parameter(0)
arg_tuple.11 = (f32[93184,4]{1,0}) parameter(1)
arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2)
conditional = (bf16[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false
get-first-index = bf16[2,512,364]{2,1,0} get-tuple-element(conditional), index=0
add.1 = bf16[2,512,364]{2,1,0} add(bf16[2,512,364]{2,1,0} get-first-index, bf16[2,512,364]{2,1,0} get-first-index)
ROOT result = (bf16[2,512,364]{2,1,0}) tuple(add.1)
}
)";
// Use a config loop to tune which instructions should be moved/not_moved.
for (int max_flip = 1; max_flip < 3; ++max_flip) {
for (int flip_stride = 1; flip_stride < ((max_flip > 1) ? 7 : 2);
++flip_stride) {
for (int flip_start = 0; flip_start < 7; ++flip_start) {
// Start flipping at index flip_start, repeating at the given stride,
// until max_flip decisions have been flipped.
uint32 search_config =
(max_flip << 8) + flip_start + (flip_stride << 16);
ConditionalCodeMotion pass(true, true, search_config);
VLOG(1) << "Testing max_flip=" << max_flip
<< "; flip_start = " << flip_start
<< "; flip_stride = " << flip_stride
<< "; search_config=" << search_config;
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
bool opt_result = pass.Run(&*module).ValueOrDie();
// Turning off the first/second decision will disable moving out;
// Turning off the following decision will in turn disable moving in.
if (flip_start < 2 && max_flip > 1 && flip_stride == 1) {
// If the next decision is false, no moving in is allowed either.
CHECK_EQ(opt_result, false);
continue;
}
CHECK_EQ(opt_result, true);
const HloInstruction* conditional =
FindInstruction(module.get(), "conditional");
const HloComputation* on_true = conditional->branch_computation(0);
const HloComputation* on_false = conditional->branch_computation(1);
HloInstruction* root = module->entry_computation()->root_instruction();
switch (flip_start) {
case 0:
TF_FALLTHROUGH_INTENDED;
case 1:
// After flipping the corresponding decisions,
// instructions have been moved inside the conditionals.
ASSERT_EQ(on_true->instruction_count(), 6);
ASSERT_EQ(on_false->instruction_count(), 6);
EXPECT_THAT(root, AllOf(op::Conditional()));
break;
case 2:
// The 2nd decision has been flipped. Reshape was not moved out.
ASSERT_EQ(on_true->instruction_count(), 4);
ASSERT_EQ(on_false->instruction_count(), 4);
EXPECT_THAT(
root,
AllOf(op::Tuple(op::Add(
op::Convert(op::GetTupleElement(op::Conditional())),
op::Convert(op::GetTupleElement(op::Conditional()))))));
break;
case 3:
// The 3rd decision has been flipped. GTE was not moved out. The
// GTE is then merged with the tuple op of the new root in later
// cleanup.
ASSERT_EQ(on_true->instruction_count(), 1);
ASSERT_EQ(on_false->instruction_count(), 1);
EXPECT_THAT(root, AllOf(op::Tuple(op::Add(
op::Convert(op::Reshape(
op::GetTupleElement(op::Conditional()))),
op::Convert(op::Reshape(op::GetTupleElement(
op::Conditional())))))));
break;
case 4:
case 5:
case 6:
// The 4th decision has been flipped. Parameter was not moved out.
// Each conditional has the parameter and the new root.
ASSERT_EQ(on_true->instruction_count(), 2);
ASSERT_EQ(on_false->instruction_count(), 2);
EXPECT_THAT(root,
AllOf(op::Tuple(op::Add(
op::Convert(op::Reshape(op::GetTupleElement(
op::GetTupleElement(op::Conditional())))),
op::Convert(op::Reshape(op::GetTupleElement(
op::GetTupleElement(op::Conditional()))))))));
break;
default: // The default cost model is used.
ASSERT_EQ(on_true->instruction_count(), 1);
ASSERT_EQ(on_false->instruction_count(), 1);
EXPECT_THAT(root, AllOf(op::Tuple(op::Add(
op::Convert(op::Reshape(
op::GetTupleElement(op::Conditional()))),
op::Convert(op::Reshape(op::GetTupleElement(
op::Conditional())))))));
break;
}
}
}
}
}
} // namespace conditional_opt
} // namespace xla