[XLA] Preserve replication info when cloning a parameter

PiperOrigin-RevId: 313319423
Change-Id: Ic92f71d5bc78e0b0ab04264ba1ea0b4416c24159
This commit is contained in:
Yuanzhong Xu 2020-05-26 20:39:18 -07:00 committed by TensorFlower Gardener
parent 3d333927a3
commit fa0a9c876a

View File

@ -1867,8 +1867,14 @@ std::unique_ptr<HloInstruction>
HloParameterInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
return absl::make_unique<HloParameterInstruction>(parameter_number_, shape,
name());
auto clone = absl::make_unique<HloParameterInstruction>(parameter_number_,
shape, name());
if (parameter_replicated_at_leaf_buffers_ &&
ShapeUtil::Equal(shape, this->shape())) {
clone->set_parameter_replicated_at_leaf_buffers(
*parameter_replicated_at_leaf_buffers_);
}
return clone;
}
HloGetTupleElementInstruction::HloGetTupleElementInstruction(