Keep the old instruction alive when replacing instructions in GetDimensionSizeRewriter.

Otherwise it could create use-after-free if we delete and iterate
through instructions at the same time.

PiperOrigin-RevId: 223544664
This commit is contained in:
Yunxing Dai 2018-11-30 11:13:46 -08:00 committed by TensorFlower Gardener
parent 92e236dd24
commit f23efde510

View File

@ -39,7 +39,7 @@ StatusOr<bool> ReplaceGetSize(HloInstruction* instr) {
uint32 size = instr->operand(0)->shape().dimensions(instr->dimension()); uint32 size = instr->operand(0)->shape().dimensions(instr->dimension());
HloInstruction* new_instr = computation->AddInstruction( HloInstruction* new_instr = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(size))); HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(size)));
TF_RETURN_IF_ERROR(computation->ReplaceInstruction(instr, new_instr)); TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr));
return true; return true;
} }
@ -50,12 +50,7 @@ StatusOr<bool> HloGetDimensionSizeRewriter::Run(HloModule* module) {
HloProto proto; HloProto proto;
*proto.mutable_hlo_module() = module->ToProto(); *proto.mutable_hlo_module() = module->ToProto();
for (auto* computation : module->computations()) { for (auto* computation : module->computations()) {
// Replacing instructions will change the instruction list in the for (auto instruction : computation->instructions()) {
// computation. So instead of iterating computation->instructions()
// directly, we make a copy of the list to avoid use-after-free.
std::vector<HloInstruction*> instrs(computation->instruction_count());
absl::c_copy(computation->instructions(), instrs.begin());
for (auto instruction : instrs) {
TF_ASSIGN_OR_RETURN(bool replaced, ReplaceGetSize(instruction)); TF_ASSIGN_OR_RETURN(bool replaced, ReplaceGetSize(instruction));
changed = changed || replaced; changed = changed || replaced;
} }