[XLA] Preserve the layout of the slice shape during algebraic_simplification.
PiperOrigin-RevId: 336151483 Change-Id: I500927df6cf274d2e8f3e82dcd0bd37a8661c8f1
This commit is contained in:
parent
a97f22586d
commit
c51c793a53
@ -4116,8 +4116,10 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
|
||||
new_limits[i] -= low;
|
||||
}
|
||||
if (slice_in_padding) {
|
||||
return ReplaceInstruction(
|
||||
slice, MakeBroadcastHlo(pad->mutable_operand(1), {}, slice->shape()));
|
||||
HloInstruction* broadcast =
|
||||
MakeBroadcastHlo(pad->mutable_operand(1), {}, slice->shape());
|
||||
*(broadcast->mutable_shape()) = slice->shape();
|
||||
return ReplaceInstruction(slice, broadcast);
|
||||
}
|
||||
if (slice_undoes_pad && ReplaceInstructionIfSameShape(slice, pad_operand)) {
|
||||
return Status::OK();
|
||||
@ -4126,6 +4128,7 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
|
||||
TF_ASSIGN_OR_RETURN(HloInstruction * new_slice,
|
||||
MakeSliceHlo(pad_operand, new_starts, new_limits,
|
||||
slice->slice_strides()));
|
||||
*(new_slice->mutable_shape()) = slice->shape();
|
||||
return ReplaceInstruction(slice, new_slice);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user