diff --git a/tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.cc b/tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.cc index 295ccebd442..64a1403a3f9 100644 --- a/tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.cc +++ b/tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.cc @@ -101,6 +101,7 @@ class EnforceMinorToMajorReduceOpVisitor : public DfsHloRewriteVisitor { new_reduce_shape_layout); HloInstruction *canonical_reduce_input = reduce->parent()->AddInstruction( HloInstruction::CreateBitcast(new_operand_shape, operand)); + canonical_reduce_input->set_metadata(reduce->metadata()); VLOG(5) << "Reduction input: " << canonical_reduce_input->ToString(); std::unique_ptr new_reduce = HloInstruction::CreateReduce(