[XLA] Identical check shouldn't skip operands for all-reduces.

PiperOrigin-RevId: 348200745
Change-Id: I8c4ca10a9d8da484ceb527ca239a682afff6c038
This commit is contained in:
Berkin Ilbeyi 2020-12-18 10:26:29 -08:00 committed by TensorFlower Gardener
parent e8f1a46dc1
commit 9ddbada1e0
2 changed files with 11 additions and 12 deletions

View File

@ -365,10 +365,13 @@ bool ArCrsCombiner::InstructionsComputeSameValue(
auto eq_computations = [](const HloComputation* a, const HloComputation* b) {
return *a == *b;
};
// Two MPMD AllReduces are identical if they have the same channel_id. Their
// operands don't have to be identical.
auto eq_operands = [](const HloInstruction*, const HloInstruction*) {
return true;
};
if (i1->IsCrossModuleAllReduce()) {
return i1->Identical(*i2,
/*eq_operands=*/std::equal_to<const HloInstruction*>(),
eq_computations,
return i1->Identical(*i2, eq_operands, eq_computations,
/*layout_sensitive=*/false);
}
visited_pairs->emplace(min_uid, max_uid);

View File

@ -2024,15 +2024,11 @@ bool HloInstruction::IdenticalInternal(
return false;
}
// Two AllReduces are Identical if they have the same channel_id.
// Their operands don't have to be Identical.
if (!IsCrossModuleAllReduce()) {
// Use an explicit loop rather than ContainerEquals, because copying
// around std::functions may be too expensive in some cases.
for (size_t i = 0; i < operands().size(); ++i) {
if (!eq_operands(operand(i), other.operand(i))) {
return false;
}
// Use an explicit loop rather than ContainerEquals, because copying around
// std::functions may be too expensive in some cases.
for (size_t i = 0; i < operands().size(); ++i) {
if (!eq_operands(operand(i), other.operand(i))) {
return false;
}
}