While loop simplification to remove repeated parameters.
PiperOrigin-RevId: 328948578 Change-Id: Ia42de967da824ec94b608154373ea281dadbef4d
This commit is contained in:
parent
59d177d9ac
commit
d3785a2f11
@ -37,23 +37,15 @@ namespace m = match;
|
||||
using absl::optional;
|
||||
using hlo_query::ContainsInstrWithOpcode;
|
||||
|
||||
// Tries to remove elements in a while loop's tuple that aren't used within the
|
||||
// loop.
|
||||
//
|
||||
// Specifically, if a loop is tuple-shaped, and there exists some element of
|
||||
// that tuple that is not used by the loop condition and is not used by the loop
|
||||
// body except to pass it to the next iteration of the loop, then we can remove
|
||||
// that element from the loop's tuples.
|
||||
static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
|
||||
CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
|
||||
|
||||
// Don't try this transformation if the while loop isn't removable, since if
|
||||
// it succeeds ultimately we're going to have to replace the old while loop
|
||||
// with a new one.
|
||||
if (!while_op->parent()->IsSafelyRemovable(while_op)) {
|
||||
VLOG(2) << "Can't remove dead parameters from non-removable while op.";
|
||||
return false;
|
||||
}
|
||||
// This is a utility function that removes the given tuple indices from the
|
||||
// while loop init, body, and condition. The final shape returned is still the
|
||||
// same as before.
|
||||
static StatusOr<HloInstruction*> RemoveDeadTupleIndices(
|
||||
HloInstruction* while_op, absl::flat_hash_set<int64>& used_tuple_indices) {
|
||||
// Build up maps from the old/new to the new/old tuple indices.
|
||||
std::vector<int64> new_to_old_tuple_idx(used_tuple_indices.begin(),
|
||||
used_tuple_indices.end());
|
||||
absl::c_sort(new_to_old_tuple_idx);
|
||||
|
||||
HloModule* module = while_op->GetModule();
|
||||
HloComputation* computation = while_op->parent();
|
||||
@ -62,107 +54,8 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
|
||||
HloComputation* while_body = while_op->while_body();
|
||||
HloInstruction* while_body_root = while_body->root_instruction();
|
||||
|
||||
if (!while_init->shape().IsTuple()) {
|
||||
VLOG(2) << "While op's carried value isn't tuple shaped.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (while_body_root->opcode() != HloOpcode::kTuple) {
|
||||
VLOG(2) << "While body's root is not a tuple(...) instruction.";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto print_no_metadata = HloPrintOptions().set_print_metadata(false);
|
||||
|
||||
// Bail if param0 of while_cond or while_body has users which aren't of type
|
||||
// get-tuple-element.
|
||||
for (const HloInstruction* instr : {while_body->parameter_instruction(0),
|
||||
while_cond->parameter_instruction(0)}) {
|
||||
for (const HloInstruction* user : instr->users()) {
|
||||
if (user->opcode() != HloOpcode::kGetTupleElement) {
|
||||
VLOG(2) << "Cowardly refusing to analyze while loop with "
|
||||
<< instr->ToString(print_no_metadata)
|
||||
<< " used by non-GTE instruction "
|
||||
<< user->ToString(print_no_metadata) << " in computation "
|
||||
<< instr->parent()->name();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape());
|
||||
if (tuple_size == 0) {
|
||||
VLOG(2) << "Can't remove elements from while loop's tuple -- it's already "
|
||||
"empty.";
|
||||
return false;
|
||||
}
|
||||
|
||||
absl::flat_hash_set<int64> used_tuple_indices;
|
||||
for (HloComputation* comp : {while_body, while_cond}) {
|
||||
// The HLO verifier ensures that while_input's shape matches while_init's
|
||||
// shape, which we verified above is a tuple.
|
||||
HloInstruction* while_input = comp->parameter_instruction(0);
|
||||
|
||||
for (const HloInstruction* user : while_input->users()) {
|
||||
// This user doesn't count if it's only used by the while body's root, and
|
||||
// the root places the tuple element into the same index of the tuple as
|
||||
// it came from. That just amounts to us carrying the variable through
|
||||
// the loop.
|
||||
//
|
||||
// Careful: HloInstruction::operand_index returns the first index the
|
||||
// operand appears in, but it may appear more than once!
|
||||
if (user->user_count() == 1 && user->users().front() == while_body_root &&
|
||||
while_body_root->operand_index(user) == user->tuple_index() &&
|
||||
absl::c_count(while_body_root->operands(), user) == 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
used_tuple_indices.insert(user->tuple_index());
|
||||
if (used_tuple_indices.size() == tuple_size) {
|
||||
VLOG(2) << "Loop " << while_op->ToString(print_no_metadata)
|
||||
<< " uses all of its inputs; no simplification possible.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If a tuple element is not passed unmodified from the while body's param0
|
||||
// through to the while body's root, count that element as "used", since
|
||||
// removing that element would be observable.
|
||||
for (int64 i = 0; i < while_body_root->operand_count(); ++i) {
|
||||
if (used_tuple_indices.contains(i)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto* operand = while_body_root->operand(i);
|
||||
if (operand->opcode() != HloOpcode::kGetTupleElement ||
|
||||
operand->operand(0) != while_body->parameter_instruction(0) ||
|
||||
operand->tuple_index() != i) {
|
||||
VLOG(2) << "Tuple index " << i
|
||||
<< " is not passed through loop body unmodified.";
|
||||
used_tuple_indices.insert(i);
|
||||
|
||||
if (used_tuple_indices.size() == tuple_size) {
|
||||
VLOG(2) << "Loop " << while_op->ToString(print_no_metadata)
|
||||
<< " uses all of its inputs; no simplification possible.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we got here, used_tuple_indices.size() < tuple_size, meaning some
|
||||
// elements of the loop's tuple aren't used by while_body or while_cond.
|
||||
CHECK_LT(used_tuple_indices.size(), tuple_size);
|
||||
|
||||
VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size()
|
||||
<< " elements from tuple of "
|
||||
<< while_op->ToString(print_no_metadata);
|
||||
|
||||
// Build up maps from the old/new to the new/old tuple indices.
|
||||
std::vector<int64> new_to_old_tuple_idx(used_tuple_indices.begin(),
|
||||
used_tuple_indices.end());
|
||||
absl::c_sort(new_to_old_tuple_idx);
|
||||
|
||||
absl::flat_hash_map<int64, int64> old_to_new_tuple_idx;
|
||||
for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) {
|
||||
int64 old_idx = new_to_old_tuple_idx[new_idx];
|
||||
@ -288,6 +181,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
|
||||
// The tuple simplifier will then simplify this if possible, removing
|
||||
// new_tuple and while_init.
|
||||
std::vector<HloInstruction*> new_tuple_elems;
|
||||
const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape());
|
||||
for (int64 old_idx = 0; old_idx < tuple_size; ++old_idx) {
|
||||
auto new_tuple_idx_it = old_to_new_tuple_idx.find(old_idx);
|
||||
if (new_tuple_idx_it != old_to_new_tuple_idx.end()) {
|
||||
@ -305,9 +199,293 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
|
||||
HloInstruction* new_tuple =
|
||||
computation->AddInstruction(HloInstruction::CreateTuple(new_tuple_elems));
|
||||
TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, new_tuple));
|
||||
|
||||
return new_while_op;
|
||||
}
|
||||
|
||||
// Tries to remove elements in a while loop's tuple that aren't used within the
|
||||
// loop.
|
||||
//
|
||||
// Specifically, if a loop is tuple-shaped, and there exists some element of
|
||||
// that tuple that is not used by the loop condition and is not used by the loop
|
||||
// body except to pass it to the next iteration of the loop, then we can remove
|
||||
// that element from the loop's tuples.
|
||||
static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
|
||||
CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
|
||||
|
||||
// Don't try this transformation if the while loop isn't removable, since if
|
||||
// it succeeds ultimately we're going to have to replace the old while loop
|
||||
// with a new one.
|
||||
if (!while_op->parent()->IsSafelyRemovable(while_op)) {
|
||||
VLOG(2) << "Can't remove dead parameters from non-removable while op.";
|
||||
return false;
|
||||
}
|
||||
|
||||
HloInstruction* while_init = while_op->mutable_operand(0);
|
||||
HloComputation* while_cond = while_op->while_condition();
|
||||
HloComputation* while_body = while_op->while_body();
|
||||
HloInstruction* while_body_root = while_body->root_instruction();
|
||||
|
||||
if (!while_init->shape().IsTuple()) {
|
||||
VLOG(2) << "While op's carried value isn't tuple shaped.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (while_body_root->opcode() != HloOpcode::kTuple) {
|
||||
VLOG(2) << "While body's root is not a tuple(...) instruction.";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto print_no_metadata = HloPrintOptions().set_print_metadata(false);
|
||||
|
||||
// Bail if param0 of while_cond or while_body has users which aren't of type
|
||||
// get-tuple-element.
|
||||
for (const HloInstruction* instr : {while_body->parameter_instruction(0),
|
||||
while_cond->parameter_instruction(0)}) {
|
||||
for (const HloInstruction* user : instr->users()) {
|
||||
if (user->opcode() != HloOpcode::kGetTupleElement) {
|
||||
VLOG(2) << "Cowardly refusing to analyze while loop with "
|
||||
<< instr->ToString(print_no_metadata)
|
||||
<< " used by non-GTE instruction "
|
||||
<< user->ToString(print_no_metadata) << " in computation "
|
||||
<< instr->parent()->name();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape());
|
||||
if (tuple_size == 0) {
|
||||
VLOG(2) << "Can't remove elements from while loop's tuple -- it's already "
|
||||
"empty.";
|
||||
return false;
|
||||
}
|
||||
|
||||
absl::flat_hash_set<int64> used_tuple_indices;
|
||||
for (HloComputation* comp : {while_body, while_cond}) {
|
||||
// The HLO verifier ensures that while_input's shape matches while_init's
|
||||
// shape, which we verified above is a tuple.
|
||||
HloInstruction* while_input = comp->parameter_instruction(0);
|
||||
|
||||
for (const HloInstruction* user : while_input->users()) {
|
||||
// This user doesn't count if it's only used by the while body's root, and
|
||||
// the root places the tuple element into the same index of the tuple as
|
||||
// it came from. That just amounts to us carrying the variable through
|
||||
// the loop.
|
||||
//
|
||||
// Careful: HloInstruction::operand_index returns the first index the
|
||||
// operand appears in, but it may appear more than once!
|
||||
if (user->user_count() == 1 && user->users().front() == while_body_root &&
|
||||
while_body_root->operand_index(user) == user->tuple_index() &&
|
||||
absl::c_count(while_body_root->operands(), user) == 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
used_tuple_indices.insert(user->tuple_index());
|
||||
if (used_tuple_indices.size() == tuple_size) {
|
||||
VLOG(2) << "Loop " << while_op->ToString(print_no_metadata)
|
||||
<< " uses all of its inputs; no simplification possible.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If a tuple element is not passed unmodified from the while body's param0
|
||||
// through to the while body's root, count that element as "used", since
|
||||
// removing that element would be observable.
|
||||
for (int64 i = 0; i < while_body_root->operand_count(); ++i) {
|
||||
if (used_tuple_indices.contains(i)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto* operand = while_body_root->operand(i);
|
||||
if (operand->opcode() != HloOpcode::kGetTupleElement ||
|
||||
operand->operand(0) != while_body->parameter_instruction(0) ||
|
||||
operand->tuple_index() != i) {
|
||||
VLOG(2) << "Tuple index " << i
|
||||
<< " is not passed through loop body unmodified.";
|
||||
used_tuple_indices.insert(i);
|
||||
|
||||
if (used_tuple_indices.size() == tuple_size) {
|
||||
VLOG(2) << "Loop " << while_op->ToString(print_no_metadata)
|
||||
<< " uses all of its inputs; no simplification possible.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we got here, used_tuple_indices.size() < tuple_size, meaning some
|
||||
// elements of the loop's tuple aren't used by while_body or while_cond.
|
||||
CHECK_LT(used_tuple_indices.size(), tuple_size);
|
||||
|
||||
VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size()
|
||||
<< " elements from tuple of "
|
||||
<< while_op->ToString(print_no_metadata);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(while_op,
|
||||
RemoveDeadTupleIndices(while_op, used_tuple_indices));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// This is a helper function for TryRemoveRepeatedWhileTupleIndices. It removes
|
||||
// duplicates by replacing them with tuple_index, followed by a call to
|
||||
// RemoveDeadTupleIndices.
|
||||
static StatusOr<HloInstruction*> TryRemoveRepeatedWhileTupleIndicesHelper(
|
||||
HloInstruction* while_op, const int64 tuple_index,
|
||||
absl::flat_hash_set<int64>& duplicates) {
|
||||
HloComputation* while_cond = while_op->while_condition();
|
||||
HloComputation* while_body = while_op->while_body();
|
||||
HloInstruction* while_init = while_op->mutable_operand(0);
|
||||
|
||||
VLOG(2) << "while_init " << while_init->ToString() << " operands "
|
||||
<< while_init->operand_count();
|
||||
VLOG(2) << "while_body_root " << while_body->root_instruction()->ToString()
|
||||
<< " operands " << while_body->root_instruction()->operand_count();
|
||||
|
||||
// Change the loop body and condition such that uses of the duplicates are
|
||||
// replaced with the original tuple element.
|
||||
for (HloComputation* comp : {while_body, while_cond}) {
|
||||
auto new_get = comp->AddInstruction(HloInstruction::CreateGetTupleElement(
|
||||
comp->parameter_instruction(0)->shape().tuple_shapes(tuple_index),
|
||||
comp->parameter_instruction(0), tuple_index));
|
||||
|
||||
std::vector<HloInstruction*> instrs_to_replace;
|
||||
for (auto* instr : comp->instructions()) {
|
||||
if (instr->opcode() == HloOpcode::kGetTupleElement &&
|
||||
duplicates.contains(instr->tuple_index()) &&
|
||||
instr->operand(0) == comp->parameter_instruction(0)) {
|
||||
instrs_to_replace.push_back(instr);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto instr : instrs_to_replace) {
|
||||
TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_get));
|
||||
}
|
||||
}
|
||||
|
||||
// We know which tuple indices are useful; i.e, those which aren't duplicates.
|
||||
absl::flat_hash_set<int64> used_tuple_indices;
|
||||
for (int index = 0; index < while_init->shape().tuple_shapes_size();
|
||||
++index) {
|
||||
if (!duplicates.count(index)) {
|
||||
used_tuple_indices.insert(index);
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the duplicate tuple elements.
|
||||
TF_ASSIGN_OR_RETURN(while_op,
|
||||
RemoveDeadTupleIndices(while_op, used_tuple_indices));
|
||||
|
||||
return while_op;
|
||||
}
|
||||
|
||||
// If the while loop init passes the same values to several tuple indices, and
|
||||
// if the body keeps on passing them through, we can remove the duplicates.
|
||||
static StatusOr<bool> TryRemoveRepeatedWhileTupleIndices(
|
||||
HloInstruction* while_op) {
|
||||
CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
|
||||
|
||||
int index_to_investigate = 0;
|
||||
// Don't try this transformation if the while loop isn't removable, since if
|
||||
// it succeeds ultimately we're going to have to replace the old while loop
|
||||
// with a new one.
|
||||
if (!while_op->parent()->IsSafelyRemovable(while_op)) {
|
||||
VLOG(2) << "Can't remove dead parameters from non-removable while op.";
|
||||
return false;
|
||||
}
|
||||
|
||||
HloInstruction* while_init = while_op->mutable_operand(0);
|
||||
HloComputation* while_cond = while_op->while_condition();
|
||||
HloComputation* while_body = while_op->while_body();
|
||||
HloInstruction* while_body_root = while_body->root_instruction();
|
||||
|
||||
if (!while_init->shape().IsTuple()) {
|
||||
VLOG(2) << "While op's carried value isn't tuple shaped.";
|
||||
return false;
|
||||
}
|
||||
|
||||
bool changed = false;
|
||||
while (index_to_investigate < while_init->shape().tuple_shapes_size()) {
|
||||
if (!while_init->shape().IsTuple() ||
|
||||
while_init->opcode() != HloOpcode::kTuple) {
|
||||
VLOG(2) << "While op's carried value isn't tuple shaped.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (while_body_root->opcode() != HloOpcode::kTuple) {
|
||||
VLOG(2) << "While body's root is not a tuple(...) instruction.";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto& while_shape = while_init->shape();
|
||||
VLOG(2) << "Iterating " << index_to_investigate;
|
||||
|
||||
absl::flat_hash_set<int64> duplicates;
|
||||
auto* pivot_init_elem = while_init->operand(index_to_investigate);
|
||||
auto* pivot_body_elem = while_body_root->operand(index_to_investigate);
|
||||
if (pivot_body_elem->opcode() == HloOpcode::kGetTupleElement &&
|
||||
pivot_body_elem->operand(0) == while_body->parameter_instruction(0)) {
|
||||
if (pivot_body_elem->tuple_index() != index_to_investigate) {
|
||||
VLOG(2) << "Mismatch between pivot_body_elem->tuple_index() "
|
||||
<< pivot_body_elem->tuple_index() << " index_to_investigate "
|
||||
<< index_to_investigate;
|
||||
index_to_investigate++;
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
index_to_investigate++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Look from index_to_investigate onwards to see if it is repeated.
|
||||
for (int64 i = index_to_investigate + 1;
|
||||
i < while_shape.tuple_shapes_size(); ++i) {
|
||||
auto* init_elem = while_init->operand(i);
|
||||
auto* body_elem = while_body_root->operand(i);
|
||||
if (body_elem->opcode() == HloOpcode::kGetTupleElement &&
|
||||
body_elem->operand(0) == while_body->parameter_instruction(0)) {
|
||||
if (body_elem->tuple_index() != i) {
|
||||
VLOG(2) << "Mismatch between body_elem->tuple_index() "
|
||||
<< body_elem->tuple_index() << " i " << i;
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (pivot_init_elem == init_elem) {
|
||||
VLOG(2) << "init_elem " << init_elem->ToString() << " pivot_init_elem "
|
||||
<< pivot_init_elem->ToString();
|
||||
VLOG(2) << "body_elem " << body_elem->ToString() << " pivot_body_elem "
|
||||
<< pivot_body_elem->ToString();
|
||||
duplicates.insert(i);
|
||||
}
|
||||
}
|
||||
|
||||
// If duplicates are found, call the helper to remove them.
|
||||
if (!duplicates.empty()) {
|
||||
VLOG(2) << "Duplicate found " << duplicates.size() << " pivot_init "
|
||||
<< pivot_init_elem->ToString();
|
||||
TF_ASSIGN_OR_RETURN(while_op,
|
||||
TryRemoveRepeatedWhileTupleIndicesHelper(
|
||||
while_op, index_to_investigate, duplicates));
|
||||
changed = true;
|
||||
VLOG(2) << "Changed while_op " << while_op->ToString()
|
||||
<< " while_op operand count " << while_op->operand_count();
|
||||
// Update the while loop variables so we can continue looking for
|
||||
// duplicates of a different index.
|
||||
while_init = while_op->mutable_operand(0);
|
||||
while_cond = while_op->while_condition();
|
||||
while_body = while_op->while_body();
|
||||
while_body_root = while_body->root_instruction();
|
||||
}
|
||||
index_to_investigate++;
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
// Removes each loop parameter (i.e. member of the while loop tuple) that is a
|
||||
// constant and is the same in the while loop body and the while loop init.
|
||||
static StatusOr<bool> TryRemoveConstantParams(HloInstruction* while_op) {
|
||||
@ -1048,6 +1226,7 @@ StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) {
|
||||
|
||||
TF_ASSIGN_OR_RETURN(result, TryRemoveWhileLoop(while_op));
|
||||
changed |= result;
|
||||
|
||||
if (result) {
|
||||
// Don't continue simplifying after successfully removing the while loop
|
||||
// -- that would result in use-after-free nastiness.
|
||||
@ -1067,6 +1246,12 @@ StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) {
|
||||
// successful, meaning that `while_op` is no longer valid after one of these
|
||||
// transformations returns true.
|
||||
|
||||
TF_ASSIGN_OR_RETURN(result, TryRemoveRepeatedWhileTupleIndices(while_op));
|
||||
changed |= result;
|
||||
if (result) {
|
||||
continue;
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(result, TryFlattenNestedTuples(while_op));
|
||||
changed |= result;
|
||||
if (result) {
|
||||
@ -1074,6 +1259,7 @@ StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) {
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(result, TryRemoveDeadWhileParams(while_op));
|
||||
|
||||
changed |= result;
|
||||
if (result) {
|
||||
continue;
|
||||
|
||||
@ -794,5 +794,51 @@ TEST_F(WhileLoopSimplifierTest, MergeInductionVariables_SkipS16) {
|
||||
.ValueOrDie());
|
||||
}
|
||||
|
||||
TEST_F(WhileLoopSimplifierTest, RemoveRepeatedParams) {
|
||||
const string hlo_string = R"(
|
||||
HloModule SwappingTupleElements
|
||||
|
||||
SwappingTupleElements.body {
|
||||
loop_var = (s32[], s32[], s32[]) parameter(0)
|
||||
get-tuple-element = s32[] get-tuple-element(loop_var), index=0
|
||||
get-tuple-element.1 = s32[] get-tuple-element(loop_var), index=1
|
||||
get-tuple-element.2 = s32[] get-tuple-element(loop_var), index=2
|
||||
y = s32[] add(get-tuple-element.1, get-tuple-element.2)
|
||||
ROOT tuple = (s32[], s32[], s32[]) tuple(s32[] get-tuple-element, y,
|
||||
s32[] get-tuple-element.2)
|
||||
}
|
||||
|
||||
SwappingTupleElements.always_true {
|
||||
param = (s32[], s32[], s32[]) parameter(0)
|
||||
get-tuple-element = s32[] get-tuple-element(param), index=0
|
||||
get-tuple-element.1 = s32[] get-tuple-element(param), index=1
|
||||
ROOT less-than = pred[] compare(get-tuple-element, get-tuple-element.1), direction=LT
|
||||
}
|
||||
|
||||
ENTRY SwappingTupleElements {
|
||||
x = s32[] parameter(0)
|
||||
y = s32[] parameter(1)
|
||||
tuple.1 = (s32[], s32[], s32[]) tuple(s32[] x, s32[] y, s32[] x)
|
||||
ROOT while = (s32[], s32[], s32[]) while(tuple.1),
|
||||
condition=SwappingTupleElements.always_true,
|
||||
body=SwappingTupleElements.body
|
||||
}
|
||||
)";
|
||||
|
||||
auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
|
||||
EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
|
||||
HloInstruction* new_while = FindFirstWhile(m.get());
|
||||
Shape new_while_shape = ParseShape("(s32[], s32[])").ValueOrDie();
|
||||
EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape));
|
||||
EXPECT_TRUE(ShapeUtil::Equal(
|
||||
new_while->while_body()->root_instruction()->shape(), new_while_shape));
|
||||
EXPECT_TRUE(ShapeUtil::Equal(
|
||||
new_while->while_body()->parameter_instruction(0)->shape(),
|
||||
new_while_shape));
|
||||
EXPECT_TRUE(ShapeUtil::Equal(
|
||||
new_while->while_condition()->parameter_instruction(0)->shape(),
|
||||
new_while_shape));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user