Fix conditional canonicalizer to use clone instead of mutating existing conditional's shape.
Today I learned that it's not a good idea to mess up with existing instruction's shape.. Because: 1. It makes the code more confusing. 2. It may mess up with existing computation's root shape, and if we change root shape, we discard all aliasing info. See tensorflow/compiler/xla/service/hlo_computation.cc:347 PiperOrigin-RevId: 323431445 Change-Id: I64f0bd68a4e9832aa5d2aa22959a4083d9ae23f4
This commit is contained in:
parent
b21b471326
commit
24d4ff08e9
@ -31,11 +31,14 @@ Status CanonicalizeNonTupleConditional(HloInstruction* conditional) {
|
||||
branch->AddInstruction(HloInstruction::CreateTuple({root}));
|
||||
branch->set_root_instruction(tuple, /*accept_different_shape=*/true);
|
||||
}
|
||||
auto parent = conditional->parent();
|
||||
auto root_shape = conditional->shape();
|
||||
*conditional->mutable_shape() = ShapeUtil::MakeTupleShape({root_shape});
|
||||
auto gte = conditional->parent()->AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(root_shape, conditional, 0));
|
||||
TF_RETURN_IF_ERROR(conditional->ReplaceAllUsesWithDifferentShape(gte));
|
||||
auto new_shape = ShapeUtil::MakeTupleShape({root_shape});
|
||||
auto new_conditional =
|
||||
parent->AddInstruction(conditional->CloneWithNewShape(new_shape));
|
||||
auto gte = parent->AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(root_shape, new_conditional, 0));
|
||||
TF_RETURN_IF_ERROR(parent->ReplaceInstruction(conditional, gte));
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace
|
||||
|
@ -1750,10 +1750,10 @@ void HloInstruction::DetachFromOperandsAndUsers() {
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<HloInstruction> HloInstruction::Clone(
|
||||
const string& suffix, HloCloneContext* context) const {
|
||||
std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewShape(
|
||||
const Shape& shape, const string& suffix, HloCloneContext* context) const {
|
||||
std::unique_ptr<HloInstruction> clone =
|
||||
CloneWithNewOperands(shape_, operands_, context);
|
||||
CloneWithNewOperands(shape, operands_, context);
|
||||
if (suffix.empty()) {
|
||||
clone->name_ = name();
|
||||
} else {
|
||||
@ -1790,6 +1790,13 @@ std::unique_ptr<HloInstruction> HloInstruction::Clone(
|
||||
return clone;
|
||||
}
|
||||
|
||||
std::unique_ptr<HloInstruction> HloInstruction::Clone(
|
||||
const string& suffix, HloCloneContext* context) const {
|
||||
std::unique_ptr<HloInstruction> clone =
|
||||
CloneWithNewShape(shape_, suffix, context);
|
||||
return clone;
|
||||
}
|
||||
|
||||
std::pair<const HloInstruction*, ShapeIndex>
|
||||
HloInstruction::LatestNonGteAncestorAndIndex() const {
|
||||
const HloInstruction* hlo = this;
|
||||
|
@ -1419,6 +1419,11 @@ class HloInstruction {
|
||||
std::unique_ptr<HloInstruction> Clone(
|
||||
const string& suffix = "clone", HloCloneContext* context = nullptr) const;
|
||||
|
||||
// Clones the HLO instruction as above but with new shape.
|
||||
std::unique_ptr<HloInstruction> CloneWithNewShape(
|
||||
const Shape& shape, const string& suffix = "clone",
|
||||
HloCloneContext* context = nullptr) const;
|
||||
|
||||
// Clones the HLO instruction as above but with new shape and operands.
|
||||
std::unique_ptr<HloInstruction> CloneWithNewOperands(
|
||||
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
||||
|
Loading…
Reference in New Issue
Block a user