[XLA] Preserve metadata when sinking broadcasts.

PiperOrigin-RevId: 352796310
Change-Id: I1217a7bcf232de01b51d0fef231df0b0b4742a66
This commit is contained in:
Blake Hechtman 2021-01-20 08:13:20 -08:00 committed by TensorFlower Gardener
parent 973e6773e9
commit 43f5aec62f

View File

@ -3624,6 +3624,7 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
new_operands.push_back( new_operands.push_back(
computation_->AddInstruction(HloInstruction::CreateBroadcast( computation_->AddInstruction(HloInstruction::CreateBroadcast(
changed_shape, user_operand->mutable_operand(0), {}))); changed_shape, user_operand->mutable_operand(0), {})));
user_operand->SetupDerivedInstruction(new_operands.back());
} else { } else {
// For the non-scalar broadcasts we guarantee that the shape of the // For the non-scalar broadcasts we guarantee that the shape of the
// operand of the broadcast needs to be already a compatible shape. // operand of the broadcast needs to be already a compatible shape.
@ -3646,6 +3647,7 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
HloInstruction* new_broadcast = HloInstruction* new_broadcast =
computation_->AddInstruction(HloInstruction::CreateBroadcast( computation_->AddInstruction(HloInstruction::CreateBroadcast(
user->shape(), new_user, broadcast->dimensions())); user->shape(), new_user, broadcast->dimensions()));
broadcast->SetupDerivedInstruction(new_broadcast);
VLOG(4) << " new broadcast: " << new_broadcast->ToString(); VLOG(4) << " new broadcast: " << new_broadcast->ToString();
TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_broadcast)); TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_broadcast));
changed = true; changed = true;