[XLA] Adds transformation from collective-permute to the asynchronous version, and adds emitter for the asynchronous version.

PiperOrigin-RevId: 314246879
Change-Id: I11865391efb6f3925b02e7b1617ba3cfedcd7701
This commit is contained in:
Jinliang Wei 2020-06-01 18:51:16 -07:00 committed by TensorFlower Gardener
parent 9aca1f01b9
commit 857a3bc0e1
4 changed files with 57 additions and 1 deletions

View File

@ -676,6 +676,39 @@ bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) {
}
}
// Refreshes the value set stored at output index {0} of a
// collective-permute-start instruction so that it mirrors the value set of
// the instruction's first operand.
//
// CHECK-fails if `collective_permute_start` is not a kCollectivePermuteStart.
// Returns true if the stored value set was modified, false if it already
// matched the operand's value set.
bool HloDataflowAnalysis::UpdateCollectivePermuteStartValueSet(
    HloInstruction* collective_permute_start) {
  CHECK_EQ(collective_permute_start->opcode(),
           HloOpcode::kCollectivePermuteStart);
  // Element {0} of the start op's output is a pass-through of the operand, so
  // it carries exactly the operand's values rather than defining new ones.
  const HloValueSet& source_values =
      GetValueSet(collective_permute_start->operand(0));
  HloValueSet& forwarded_values = GetValueSet(collective_permute_start, {0});
  if (forwarded_values != source_values) {
    forwarded_values = source_values;
    return true;
  }
  return false;
}
// Refreshes the (top-level) value set of a collective-permute-done
// instruction so that it mirrors the value set found at index {1} of its
// first operand.
//
// CHECK-fails if `collective_permute_done` is not a kCollectivePermuteDone.
// Returns true if the stored value set was modified, false if it already
// matched.
bool HloDataflowAnalysis::UpdateCollectivePermuteDoneValueSet(
    HloInstruction* collective_permute_done) {
  CHECK_EQ(collective_permute_done->opcode(),
           HloOpcode::kCollectivePermuteDone);
  // CollectivePermuteDone forwards the operand value at {1} to its output.
  // (Index {0} of the operand is the pass-through of the start op's own
  // input; see UpdateCollectivePermuteStartValueSet.)
  const HloValueSet& operand_value_set =
      GetValueSet(collective_permute_done->operand(0), {1});
  HloValueSet& value_set = GetValueSet(collective_permute_done);
  if (value_set != operand_value_set) {
    value_set = operand_value_set;
    return true;
  }
  return false;
}
bool HloDataflowAnalysis::UpdateInstructionValueSet(
HloInstruction* instruction) {
// Recompute from operands.
@ -712,6 +745,10 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(
return UpdateCopyDoneValueSet(instruction);
case HloOpcode::kConditional:
return UpdateConditionalValueSet(instruction);
case HloOpcode::kCollectivePermuteStart:
return UpdateCollectivePermuteStartValueSet(instruction);
case HloOpcode::kCollectivePermuteDone:
return UpdateCollectivePermuteDoneValueSet(instruction);
default:
// Instruction does not forward HloValues (it defines all values in its
// output). No update is necessary.

View File

@ -216,6 +216,10 @@ class HloDataflowAnalysis {
bool UpdateTupleValueSet(HloInstruction* tuple);
bool UpdateWhileValueSet(HloInstruction* xla_while);
bool UpdateAddDependencyValueSet(HloInstruction* add_dependency);
// Copies the value set of operand 0 into output index {0} of the given
// collective-permute-start instruction. Returns true if the value set
// changed.
bool UpdateCollectivePermuteStartValueSet(
HloInstruction* collective_permute_start);
// Copies the value set at index {1} of operand 0 into the output of the given
// collective-permute-done instruction. Returns true if the value set
// changed.
bool UpdateCollectivePermuteDoneValueSet(
HloInstruction* collective_permute_done);
// Propagates the dataflow through the module. In particular, it propagates
// the HloValueSet from its defining instruction to the users of the

View File

@ -1201,7 +1201,8 @@ TEST_F(HloVerifierTest, CollectivePermuteDoneNoCollectivePermuteStart) {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
p1 = f32[2,3]{1,0:S(1)} parameter(1)
p2 = u32[] parameter(2)
tuple.1 = (f32[2,3], f32[2,3], u32[], u32[]) tuple(p0, p1, p2)
p3 = u32[] parameter(3)
tuple.1 = (f32[2,3], f32[2,3], u32[], u32[]) tuple(p0, p1, p2, p3)
ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(tuple.1)
}
)";

View File

@ -521,6 +521,20 @@ bool AlternateMemoryBestFitHeap::IsIntervalAllowedInAlternateMemory(
return false;
}
}
if ((position.instruction->opcode() == HloOpcode::kCollectivePermuteStart ||
position.instruction->opcode() == HloOpcode::kCollectivePermuteDone)) {
// Disable memory space allocation for these for now.
if (position.index == ShapeIndex({0})) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
<< " in default mem because it is a collective-permute buffer.";
return false;
} else if (position.index == ShapeIndex({1})) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
<< " in default mem because it is a collective-permute buffer.";
return false;
}
}
}
return true;