[XLA] Adds transformation from collective-permute to the asynchronous version, and adds emitter for the asynchronous version.
PiperOrigin-RevId: 314246879 Change-Id: I11865391efb6f3925b02e7b1617ba3cfedcd7701
This commit is contained in:
parent
9aca1f01b9
commit
857a3bc0e1
@ -676,6 +676,39 @@ bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool HloDataflowAnalysis::UpdateCollectivePermuteStartValueSet(
|
||||||
|
HloInstruction* collective_permute_start) {
|
||||||
|
CHECK_EQ(collective_permute_start->opcode(),
|
||||||
|
HloOpcode::kCollectivePermuteStart);
|
||||||
|
bool changed = false;
|
||||||
|
// CollectivePermuteStart forwards the operand value to element {0} of its
|
||||||
|
// output.
|
||||||
|
const HloValueSet& operand_value_set =
|
||||||
|
GetValueSet(collective_permute_start->operand(0));
|
||||||
|
HloValueSet& value_set = GetValueSet(collective_permute_start, {0});
|
||||||
|
if (value_set != operand_value_set) {
|
||||||
|
value_set = operand_value_set;
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
return changed;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool HloDataflowAnalysis::UpdateCollectivePermuteDoneValueSet(
|
||||||
|
HloInstruction* collective_permute_done) {
|
||||||
|
CHECK_EQ(collective_permute_done->opcode(),
|
||||||
|
HloOpcode::kCollectivePermuteDone);
|
||||||
|
bool changed = false;
|
||||||
|
// CollectivePermuteDone forwards the operand value at {0} to its output.
|
||||||
|
const HloValueSet& operand_value_set =
|
||||||
|
GetValueSet(collective_permute_done->operand(0), {1});
|
||||||
|
HloValueSet& value_set = GetValueSet(collective_permute_done);
|
||||||
|
if (value_set != operand_value_set) {
|
||||||
|
value_set = operand_value_set;
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
return changed;
|
||||||
|
}
|
||||||
|
|
||||||
bool HloDataflowAnalysis::UpdateInstructionValueSet(
|
bool HloDataflowAnalysis::UpdateInstructionValueSet(
|
||||||
HloInstruction* instruction) {
|
HloInstruction* instruction) {
|
||||||
// Recompute from operands.
|
// Recompute from operands.
|
||||||
@ -712,6 +745,10 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(
|
|||||||
return UpdateCopyDoneValueSet(instruction);
|
return UpdateCopyDoneValueSet(instruction);
|
||||||
case HloOpcode::kConditional:
|
case HloOpcode::kConditional:
|
||||||
return UpdateConditionalValueSet(instruction);
|
return UpdateConditionalValueSet(instruction);
|
||||||
|
case HloOpcode::kCollectivePermuteStart:
|
||||||
|
return UpdateCollectivePermuteStartValueSet(instruction);
|
||||||
|
case HloOpcode::kCollectivePermuteDone:
|
||||||
|
return UpdateCollectivePermuteDoneValueSet(instruction);
|
||||||
default:
|
default:
|
||||||
// Instruction does not forward HloValues (it defines all values in its
|
// Instruction does not forward HloValues (it defines all values in its
|
||||||
// output). No update is necessary.
|
// output). No update is necessary.
|
||||||
|
|||||||
@ -216,6 +216,10 @@ class HloDataflowAnalysis {
|
|||||||
bool UpdateTupleValueSet(HloInstruction* tuple);
|
bool UpdateTupleValueSet(HloInstruction* tuple);
|
||||||
bool UpdateWhileValueSet(HloInstruction* xla_while);
|
bool UpdateWhileValueSet(HloInstruction* xla_while);
|
||||||
bool UpdateAddDependencyValueSet(HloInstruction* add_dependency);
|
bool UpdateAddDependencyValueSet(HloInstruction* add_dependency);
|
||||||
|
bool UpdateCollectivePermuteStartValueSet(
|
||||||
|
HloInstruction* collective_permute_start);
|
||||||
|
bool UpdateCollectivePermuteDoneValueSet(
|
||||||
|
HloInstruction* collective_permute_done);
|
||||||
|
|
||||||
// Propagates the dataflow through the module. In particular, it propagates
|
// Propagates the dataflow through the module. In particular, it propagates
|
||||||
// the HloValueSet from its defining instruction to the users of the
|
// the HloValueSet from its defining instruction to the users of the
|
||||||
|
|||||||
@ -1201,7 +1201,8 @@ TEST_F(HloVerifierTest, CollectivePermuteDoneNoCollectivePermuteStart) {
|
|||||||
p0 = f32[2,3]{1,0:S(1)} parameter(0)
|
p0 = f32[2,3]{1,0:S(1)} parameter(0)
|
||||||
p1 = f32[2,3]{1,0:S(1)} parameter(1)
|
p1 = f32[2,3]{1,0:S(1)} parameter(1)
|
||||||
p2 = u32[] parameter(2)
|
p2 = u32[] parameter(2)
|
||||||
tuple.1 = (f32[2,3], f32[2,3], u32[], u32[]) tuple(p0, p1, p2)
|
p3 = u32[] parameter(3)
|
||||||
|
tuple.1 = (f32[2,3], f32[2,3], u32[], u32[]) tuple(p0, p1, p2, p3)
|
||||||
ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(tuple.1)
|
ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(tuple.1)
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|||||||
@ -521,6 +521,20 @@ bool AlternateMemoryBestFitHeap::IsIntervalAllowedInAlternateMemory(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ((position.instruction->opcode() == HloOpcode::kCollectivePermuteStart ||
|
||||||
|
position.instruction->opcode() == HloOpcode::kCollectivePermuteDone)) {
|
||||||
|
// Disable memory space allocation for these for now.
|
||||||
|
if (position.index == ShapeIndex({0})) {
|
||||||
|
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
|
||||||
|
<< " in default mem because it is a collective-permute buffer.";
|
||||||
|
return false;
|
||||||
|
} else if (position.index == ShapeIndex({1})) {
|
||||||
|
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
|
||||||
|
<< " in default mem because it is a collective-permute buffer.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user