diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index d0d533e0b06..f19882c9347 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -698,7 +698,7 @@ bool HloDataflowAnalysis::UpdateCollectivePermuteDoneValueSet( CHECK_EQ(collective_permute_done->opcode(), HloOpcode::kCollectivePermuteDone); bool changed = false; - // CollectivePermuteDone forwards the operand value at {0} to its output. + // CollectivePermuteDone forwards the operand value at {1} to its output. const HloValueSet& operand_value_set = GetValueSet(collective_permute_done->operand(0), {1}); HloValueSet& value_set = GetValueSet(collective_permute_done); @@ -945,6 +945,17 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { // CopyDone consumes a tuple produced by CopyStart and produces an // element. Its output aliases its input tuple element {0}. break; + case HloOpcode::kCollectivePermuteStart: + // CollectivePermuteStart produces a tuple of + // {aliased operand, destination buffer, U32 context, U32 context}. + define_value_at(/*index=*/{}); + define_value_at(/*index=*/{1}); + define_value_at(/*index=*/{2}); + define_value_at(/*index=*/{3}); + break; + case HloOpcode::kCollectivePermuteDone: + // CollectivePermuteDone's output aliases its input tuple element {1}. + break; case HloOpcode::kRecvDone: // RecvDone produces a two-element tuple. Element zero aliases its // input tuple element {0}; element one is a token.