[XLA] Preserve the layout of the slice shape during algebraic_simplification.

PiperOrigin-RevId: 336151483
Change-Id: I500927df6cf274d2e8f3e82dcd0bd37a8661c8f1
This commit is contained in:
Blake Hechtman 2020-10-08 13:20:37 -07:00 committed by TensorFlower Gardener
parent a97f22586d
commit c51c793a53

View File

@ -4116,8 +4116,10 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
new_limits[i] -= low;
}
if (slice_in_padding) {
return ReplaceInstruction(
slice, MakeBroadcastHlo(pad->mutable_operand(1), {}, slice->shape()));
HloInstruction* broadcast =
MakeBroadcastHlo(pad->mutable_operand(1), {}, slice->shape());
*(broadcast->mutable_shape()) = slice->shape();
return ReplaceInstruction(slice, broadcast);
}
if (slice_undoes_pad && ReplaceInstructionIfSameShape(slice, pad_operand)) {
return Status::OK();
@ -4126,6 +4128,7 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
TF_ASSIGN_OR_RETURN(HloInstruction * new_slice,
MakeSliceHlo(pad_operand, new_starts, new_limits,
slice->slice_strides()));
*(new_slice->mutable_shape()) = slice->shape();
return ReplaceInstruction(slice, new_slice);
}
}