[XLA] Adds transformation from collective-permute to the asynchronous version, and adds emitter for the asynchronous version.
PiperOrigin-RevId: 314246879 Change-Id: I11865391efb6f3925b02e7b1617ba3cfedcd7701
This commit is contained in:
parent
9aca1f01b9
commit
857a3bc0e1
@ -676,6 +676,39 @@ bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) {
|
||||
}
|
||||
}
|
||||
|
||||
bool HloDataflowAnalysis::UpdateCollectivePermuteStartValueSet(
|
||||
HloInstruction* collective_permute_start) {
|
||||
CHECK_EQ(collective_permute_start->opcode(),
|
||||
HloOpcode::kCollectivePermuteStart);
|
||||
bool changed = false;
|
||||
// CollectivePermuteStart forwards the operand value to element {0} of its
|
||||
// output.
|
||||
const HloValueSet& operand_value_set =
|
||||
GetValueSet(collective_permute_start->operand(0));
|
||||
HloValueSet& value_set = GetValueSet(collective_permute_start, {0});
|
||||
if (value_set != operand_value_set) {
|
||||
value_set = operand_value_set;
|
||||
changed = true;
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool HloDataflowAnalysis::UpdateCollectivePermuteDoneValueSet(
|
||||
HloInstruction* collective_permute_done) {
|
||||
CHECK_EQ(collective_permute_done->opcode(),
|
||||
HloOpcode::kCollectivePermuteDone);
|
||||
bool changed = false;
|
||||
// CollectivePermuteDone forwards the operand value at {0} to its output.
|
||||
const HloValueSet& operand_value_set =
|
||||
GetValueSet(collective_permute_done->operand(0), {1});
|
||||
HloValueSet& value_set = GetValueSet(collective_permute_done);
|
||||
if (value_set != operand_value_set) {
|
||||
value_set = operand_value_set;
|
||||
changed = true;
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool HloDataflowAnalysis::UpdateInstructionValueSet(
|
||||
HloInstruction* instruction) {
|
||||
// Recompute from operands.
|
||||
@ -712,6 +745,10 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(
|
||||
return UpdateCopyDoneValueSet(instruction);
|
||||
case HloOpcode::kConditional:
|
||||
return UpdateConditionalValueSet(instruction);
|
||||
case HloOpcode::kCollectivePermuteStart:
|
||||
return UpdateCollectivePermuteStartValueSet(instruction);
|
||||
case HloOpcode::kCollectivePermuteDone:
|
||||
return UpdateCollectivePermuteDoneValueSet(instruction);
|
||||
default:
|
||||
// Instruction does not forward HloValues (it defines all values in its
|
||||
// output). No update is necessary.
|
||||
|
||||
@ -216,6 +216,10 @@ class HloDataflowAnalysis {
|
||||
bool UpdateTupleValueSet(HloInstruction* tuple);
|
||||
bool UpdateWhileValueSet(HloInstruction* xla_while);
|
||||
bool UpdateAddDependencyValueSet(HloInstruction* add_dependency);
|
||||
bool UpdateCollectivePermuteStartValueSet(
|
||||
HloInstruction* collective_permute_start);
|
||||
bool UpdateCollectivePermuteDoneValueSet(
|
||||
HloInstruction* collective_permute_done);
|
||||
|
||||
// Propagates the dataflow through the module. In particular, it propagates
|
||||
// the HloValueSet from its defining instruction to the users of the
|
||||
|
||||
@ -1201,7 +1201,8 @@ TEST_F(HloVerifierTest, CollectivePermuteDoneNoCollectivePermuteStart) {
|
||||
p0 = f32[2,3]{1,0:S(1)} parameter(0)
|
||||
p1 = f32[2,3]{1,0:S(1)} parameter(1)
|
||||
p2 = u32[] parameter(2)
|
||||
tuple.1 = (f32[2,3], f32[2,3], u32[], u32[]) tuple(p0, p1, p2)
|
||||
p3 = u32[] parameter(3)
|
||||
tuple.1 = (f32[2,3], f32[2,3], u32[], u32[]) tuple(p0, p1, p2, p3)
|
||||
ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(tuple.1)
|
||||
}
|
||||
)";
|
||||
|
||||
@ -521,6 +521,20 @@ bool AlternateMemoryBestFitHeap::IsIntervalAllowedInAlternateMemory(
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if ((position.instruction->opcode() == HloOpcode::kCollectivePermuteStart ||
|
||||
position.instruction->opcode() == HloOpcode::kCollectivePermuteDone)) {
|
||||
// Disable memory space allocation for these for now.
|
||||
if (position.index == ShapeIndex({0})) {
|
||||
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
|
||||
<< " in default mem because it is a collective-permute buffer.";
|
||||
return false;
|
||||
} else if (position.index == ShapeIndex({1})) {
|
||||
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
|
||||
<< " in default mem because it is a collective-permute buffer.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user