Allow split-phase operations for HloReachabilityMap where we can

get the bit vector index assigned to an instruction into an opaque
HloReachabilityMap::Index that can then be passed to various
other APIs in HloReachabilityMap.

Use this in MultiOutputFusion::UpdateReachability to avoid multiple
hash table lookups per loop iteration and instead mostly just
do bit vector operations.

Improves time spent in MultiOutputFusion::UpdateReachability from
1.92s to 1.53s for one compilation intensive workload.

PiperOrigin-RevId: 316362463
Change-Id: Ic55889a70e5ae6fa9cc95690171cbebc0654dbcf
This commit is contained in:
Jeffrey A. Dean 2020-06-14 14:40:38 -07:00 committed by TensorFlower Gardener
parent fbe175972d
commit 4021c9a091
3 changed files with 51 additions and 26 deletions

View File

@ -52,7 +52,7 @@ void HloReachabilityMap::SetReachabilityToUnionHelper(
if (!absl::c_linear_search(inputs, instruction)) {
bit_vector->SetToZero();
}
bit_vector->Set(GetIndex(instruction));
bit_vector->Set(GetIndex(instruction).v);
for (const HloInstruction* input : inputs) {
if (input != instruction) {
bit_vector->OrWith(GetBitVector(input));
@ -65,23 +65,12 @@ void HloReachabilityMap::Replace(const HloInstruction* original,
if (GetKey(original) == GetKey(replacement)) {
return;
}
indices_[GetKey(replacement)] = GetIndex(original);
indices_[GetKey(replacement)] = GetIndex(original).v;
indices_.erase(GetKey(original));
}
void HloReachabilityMap::SetReachable(const HloInstruction* a,
const HloInstruction* b) {
GetBitVector(b).Set(GetIndex(a));
}
bool HloReachabilityMap::IsReachable(const HloInstruction* a,
const HloInstruction* b) const {
return GetBitVector(b).Get(GetIndex(a));
}
bool HloReachabilityMap::IsConnected(const HloInstruction* a,
const HloInstruction* b) const {
return IsReachable(a, b) || IsReachable(b, a);
void HloReachabilityMap::SetReachable(Index a, Index b) {
GetBitVector(b).Set(a.v);
}
std::unique_ptr<HloReachabilityMap> HloReachabilityMap::Build(

View File

@ -74,13 +74,34 @@ class HloReachabilityMap {
absl::Span<const HloInstruction* const> inputs,
const HloInstruction* instruction);
// An opaque index that clients can use to make repeated operations for the
// same instruction faster, by calling GetIndex once for the instruction,
// and then calling the variants of other interfaces that take Index arguments
// rather than HloInstruction* arguments.
struct Index {
private:
friend class HloReachabilityMap;
// Index assigned for a particular instruction. The value is used to index
// into the vector of BitVectors and the BitVectors themselves.
int v;
};
Index GetIndex(const HloInstruction* instruction) const {
Index i;
i.v = FindOrDie(indices_, GetKey(instruction));
return i;
}
// Sets entry so that IsReachable(a, b) will return true
//
// !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency
// matrix in the internal graph of this HloReachabilityMap to have an edge
// from a to b and does not transitively update any other part of the
// adjacency matrix.
void SetReachable(const HloInstruction* a, const HloInstruction* b);
void SetReachable(const HloInstruction* a, const HloInstruction* b) {
SetReachable(GetIndex(a), GetIndex(b));
}
void SetReachable(Index a, Index b);
// Updates the given reachability map after the immediate predecessor set
// (operands and control predecessors) of 'instruction' has changed.
@ -90,13 +111,21 @@ class HloReachabilityMap {
//
// Note that this function only correctly answers queries about reachability
// if the set of edges that have been provided to this class are transitive.
bool IsReachable(const HloInstruction* a, const HloInstruction* b) const;
bool IsReachable(const HloInstruction* a, const HloInstruction* b) const {
return IsReachable(GetIndex(a), GetIndex(b));
}
bool IsReachable(Index a, Index b) const { return GetBitVector(b).Get(a.v); }
// Returns true if "b" is reachable from "a" or "a" is reachable from "b"
//
// Note that this function only correctly answers queries about reachability
// if the set of edges that have been provided to this class are transitive.
bool IsConnected(const HloInstruction* a, const HloInstruction* b) const;
bool IsConnected(const HloInstruction* a, const HloInstruction* b) const {
return IsConnected(GetIndex(a), GetIndex(b));
}
bool IsConnected(Index a, Index b) const {
return IsReachable(a, b) || IsReachable(b, a);
}
// Checks if an instruction is in the Reachability map.
bool IsPresent(const HloInstruction* a) const {
@ -158,12 +187,17 @@ class HloReachabilityMap {
// Return the bitvector storing the reachability-to of the given instruction.
const BitVector& GetBitVector(const HloInstruction* instruction) const {
return bit_vectors_[GetIndex(instruction)];
return GetBitVector(GetIndex(instruction));
}
BitVector& GetBitVector(const HloInstruction* instruction) {
return bit_vectors_[GetIndex(instruction)];
return GetBitVector(GetIndex(instruction));
}
const BitVector& GetBitVector(Index index) const {
return bit_vectors_[index.v];
}
BitVector& GetBitVector(Index index) { return bit_vectors_[index.v]; }
// Helper for SetReachabilityToUnion/FastSetReachabilityToUnion.
void SetReachabilityToUnionHelper(
absl::Span<const HloInstruction* const> inputs,
@ -175,9 +209,8 @@ class HloReachabilityMap {
absl::bit_cast<uint32>(instruction->parent()->parent()->unique_id());
return (module_id << 32) | unique_id;
}
// Return the index of the given instruction. The value is used to index into
// the vector of BitVectors and the BitVectors themselves.
int GetIndex(const HloInstruction* instruction) const {
// Return the index of the given instruction.
int GetIndexInternal(const HloInstruction* instruction) const {
return FindOrDie(indices_, GetKey(instruction));
}

View File

@ -346,12 +346,15 @@ void MultiOutputFusion::UpdateReachability(
HloInstruction* instr1, HloInstruction* instr2,
absl::Span<HloInstruction* const> instrs_to_update,
const std::function<bool(HloInstruction*)>& skip) {
auto instr1_i = reachability_->GetIndex(instr1);
auto instr2_i = reachability_->GetIndex(instr2);
for (auto instr : instrs_to_update) {
if (skip != nullptr && skip(instr)) {
continue;
}
bool instr2_instr = reachability_->IsReachable(instr2, instr);
bool instr1_instr = reachability_->IsReachable(instr1, instr);
auto instr_i = reachability_->GetIndex(instr);
bool instr2_instr = reachability_->IsReachable(instr2_i, instr_i);
bool instr1_instr = reachability_->IsReachable(instr1_i, instr_i);
if (instr2_instr && instr1_instr) {
// If a candidate was already reachable by both, no update needed.
continue;
@ -359,7 +362,7 @@ void MultiOutputFusion::UpdateReachability(
if (instr2_instr) {
reachability_->FastSetReachabilityToUnion({instr, instr1}, instr);
}
if (reachability_->IsReachable(instr1, instr)) {
if (reachability_->IsReachable(instr1_i, instr_i)) {
reachability_->FastSetReachabilityToUnion({instr, instr2}, instr);
}
}