diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 4894e566393..d0d533e0b06 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -676,6 +676,39 @@ bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) { } } +bool HloDataflowAnalysis::UpdateCollectivePermuteStartValueSet( + HloInstruction* collective_permute_start) { + CHECK_EQ(collective_permute_start->opcode(), + HloOpcode::kCollectivePermuteStart); + bool changed = false; + // CollectivePermuteStart forwards the operand value to element {0} of its + // output. + const HloValueSet& operand_value_set = + GetValueSet(collective_permute_start->operand(0)); + HloValueSet& value_set = GetValueSet(collective_permute_start, {0}); + if (value_set != operand_value_set) { + value_set = operand_value_set; + changed = true; + } + return changed; +} + +bool HloDataflowAnalysis::UpdateCollectivePermuteDoneValueSet( + HloInstruction* collective_permute_done) { + CHECK_EQ(collective_permute_done->opcode(), + HloOpcode::kCollectivePermuteDone); + bool changed = false; + // CollectivePermuteDone forwards the operand value at {1} to its output. + const HloValueSet& operand_value_set = + GetValueSet(collective_permute_done->operand(0), {1}); + HloValueSet& value_set = GetValueSet(collective_permute_done); + if (value_set != operand_value_set) { + value_set = operand_value_set; + changed = true; + } + return changed; +} + bool HloDataflowAnalysis::UpdateInstructionValueSet( HloInstruction* instruction) { // Recompute from operands. 
@@ -712,6 +745,10 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( return UpdateCopyDoneValueSet(instruction); case HloOpcode::kConditional: return UpdateConditionalValueSet(instruction); + case HloOpcode::kCollectivePermuteStart: + return UpdateCollectivePermuteStartValueSet(instruction); + case HloOpcode::kCollectivePermuteDone: + return UpdateCollectivePermuteDoneValueSet(instruction); default: // Instruction does not forward HloValues (it defines all values in its // output). No update is necessary. diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index 75bcf7ea318..bec592aeb20 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -216,6 +216,10 @@ class HloDataflowAnalysis { bool UpdateTupleValueSet(HloInstruction* tuple); bool UpdateWhileValueSet(HloInstruction* xla_while); bool UpdateAddDependencyValueSet(HloInstruction* add_dependency); + bool UpdateCollectivePermuteStartValueSet( + HloInstruction* collective_permute_start); + bool UpdateCollectivePermuteDoneValueSet( + HloInstruction* collective_permute_done); // Propagates the dataflow through the module. 
In particular, it propagates // the HloValueSet from its defining instruction to the users of the diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 294dfbf66fa..d9709c50df9 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -1201,7 +1201,8 @@ TEST_F(HloVerifierTest, CollectivePermuteDoneNoCollectivePermuteStart) { p0 = f32[2,3]{1,0:S(1)} parameter(0) p1 = f32[2,3]{1,0:S(1)} parameter(1) p2 = u32[] parameter(2) - tuple.1 = (f32[2,3], f32[2,3], u32[], u32[]) tuple(p0, p1, p2) + p3 = u32[] parameter(3) + tuple.1 = (f32[2,3], f32[2,3], u32[], u32[]) tuple(p0, p1, p2, p3) ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(tuple.1) } )"; diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index e4ee79e9f4c..0ed72f51754 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -521,6 +521,20 @@ bool AlternateMemoryBestFitHeap::IsIntervalAllowedInAlternateMemory( return false; } } + + if ((position.instruction->opcode() == HloOpcode::kCollectivePermuteStart || + position.instruction->opcode() == HloOpcode::kCollectivePermuteDone)) { + // Disable memory space allocation for these for now. + if (position.index == ShapeIndex({0})) { + VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + << " in default mem because it is a collective-permute buffer."; + return false; + } else if (position.index == ShapeIndex({1})) { + VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + << " in default mem because it is a collective-permute buffer."; + return false; + } + } } return true;