[XLA] Preserve metadata when sinking broadcasts.
PiperOrigin-RevId: 352796310 Change-Id: I1217a7bcf232de01b51d0fef231df0b0b4742a66
This commit is contained in:
parent
973e6773e9
commit
43f5aec62f
@ -3624,6 +3624,7 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
|
||||
new_operands.push_back(
|
||||
computation_->AddInstruction(HloInstruction::CreateBroadcast(
|
||||
changed_shape, user_operand->mutable_operand(0), {})));
|
||||
user_operand->SetupDerivedInstruction(new_operands.back());
|
||||
} else {
|
||||
// For the non-scalar broadcasts we guarantee that the shape of the
|
||||
// operand of the broadcast needs to be already a compatible shape.
|
||||
@ -3646,6 +3647,7 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
|
||||
HloInstruction* new_broadcast =
|
||||
computation_->AddInstruction(HloInstruction::CreateBroadcast(
|
||||
user->shape(), new_user, broadcast->dimensions()));
|
||||
broadcast->SetupDerivedInstruction(new_broadcast);
|
||||
VLOG(4) << " new broadcast: " << new_broadcast->ToString();
|
||||
TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_broadcast));
|
||||
changed = true;
|
||||
|
Loading…
x
Reference in New Issue
Block a user