[XLA] Preserve metadata when sinking broadcasts.
PiperOrigin-RevId: 352796310 Change-Id: I1217a7bcf232de01b51d0fef231df0b0b4742a66
This commit is contained in:
parent
973e6773e9
commit
43f5aec62f
@ -3624,6 +3624,7 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
|
|||||||
new_operands.push_back(
|
new_operands.push_back(
|
||||||
computation_->AddInstruction(HloInstruction::CreateBroadcast(
|
computation_->AddInstruction(HloInstruction::CreateBroadcast(
|
||||||
changed_shape, user_operand->mutable_operand(0), {})));
|
changed_shape, user_operand->mutable_operand(0), {})));
|
||||||
|
user_operand->SetupDerivedInstruction(new_operands.back());
|
||||||
} else {
|
} else {
|
||||||
// For the non-scalar broadcasts we guarantee that the shape of the
|
// For the non-scalar broadcasts we guarantee that the shape of the
|
||||||
// operand of the broadcast needs to be already a compatible shape.
|
// operand of the broadcast needs to be already a compatible shape.
|
||||||
@ -3646,6 +3647,7 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
|
|||||||
HloInstruction* new_broadcast =
|
HloInstruction* new_broadcast =
|
||||||
computation_->AddInstruction(HloInstruction::CreateBroadcast(
|
computation_->AddInstruction(HloInstruction::CreateBroadcast(
|
||||||
user->shape(), new_user, broadcast->dimensions()));
|
user->shape(), new_user, broadcast->dimensions()));
|
||||||
|
broadcast->SetupDerivedInstruction(new_broadcast);
|
||||||
VLOG(4) << " new broadcast: " << new_broadcast->ToString();
|
VLOG(4) << " new broadcast: " << new_broadcast->ToString();
|
||||||
TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_broadcast));
|
TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_broadcast));
|
||||||
changed = true;
|
changed = true;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user