Fix conditional canonicalizer to use clone instead of mutating existing conditional's shape.

Today I learned that it's not a good idea to mess up with existing instruction's shape.. Because:
1. It makes the code more confusing.
2. It may mess up with existing computation's root shape, and if we change root shape, we discard all aliasing info. See tensorflow/compiler/xla/service/hlo_computation.cc:347

PiperOrigin-RevId: 323431445
Change-Id: I64f0bd68a4e9832aa5d2aa22959a4083d9ae23f4
This commit is contained in:
Yunxing Dai 2020-07-27 13:40:17 -07:00 committed by TensorFlower Gardener
parent b21b471326
commit 24d4ff08e9
3 changed files with 22 additions and 7 deletions

View File

@ -31,11 +31,14 @@ Status CanonicalizeNonTupleConditional(HloInstruction* conditional) {
branch->AddInstruction(HloInstruction::CreateTuple({root}));
branch->set_root_instruction(tuple, /*accept_different_shape=*/true);
}
auto parent = conditional->parent();
auto root_shape = conditional->shape();
*conditional->mutable_shape() = ShapeUtil::MakeTupleShape({root_shape});
auto gte = conditional->parent()->AddInstruction(
HloInstruction::CreateGetTupleElement(root_shape, conditional, 0));
TF_RETURN_IF_ERROR(conditional->ReplaceAllUsesWithDifferentShape(gte));
auto new_shape = ShapeUtil::MakeTupleShape({root_shape});
auto new_conditional =
parent->AddInstruction(conditional->CloneWithNewShape(new_shape));
auto gte = parent->AddInstruction(
HloInstruction::CreateGetTupleElement(root_shape, new_conditional, 0));
TF_RETURN_IF_ERROR(parent->ReplaceInstruction(conditional, gte));
return Status::OK();
}
} // namespace

View File

@ -1750,10 +1750,10 @@ void HloInstruction::DetachFromOperandsAndUsers() {
}
}
std::unique_ptr<HloInstruction> HloInstruction::Clone(
const string& suffix, HloCloneContext* context) const {
std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewShape(
const Shape& shape, const string& suffix, HloCloneContext* context) const {
std::unique_ptr<HloInstruction> clone =
CloneWithNewOperands(shape_, operands_, context);
CloneWithNewOperands(shape, operands_, context);
if (suffix.empty()) {
clone->name_ = name();
} else {
@ -1790,6 +1790,13 @@ std::unique_ptr<HloInstruction> HloInstruction::Clone(
return clone;
}
std::unique_ptr<HloInstruction> HloInstruction::Clone(
const string& suffix, HloCloneContext* context) const {
std::unique_ptr<HloInstruction> clone =
CloneWithNewShape(shape_, suffix, context);
return clone;
}
std::pair<const HloInstruction*, ShapeIndex>
HloInstruction::LatestNonGteAncestorAndIndex() const {
const HloInstruction* hlo = this;

View File

@ -1419,6 +1419,11 @@ class HloInstruction {
std::unique_ptr<HloInstruction> Clone(
const string& suffix = "clone", HloCloneContext* context = nullptr) const;
// Clones the HLO instruction as above but with new shape.
std::unique_ptr<HloInstruction> CloneWithNewShape(
const Shape& shape, const string& suffix = "clone",
HloCloneContext* context = nullptr) const;
// Clones the HLO instruction as above but with new shape and operands.
std::unique_ptr<HloInstruction> CloneWithNewOperands(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,