Allow split-phase operations for HloReachabilityMap, where we can
get the bit vector index assigned to an instruction as an opaque HloReachabilityMap::Index that can then be passed to various other APIs in HloReachabilityMap. Use this in MultiOutputFusion::UpdateReachability to avoid multiple hash table lookups per loop iteration and instead mostly just do bit vector operations. This reduces the time spent in MultiOutputFusion::UpdateReachability from 1.92s to 1.53s for one compilation-intensive workload. PiperOrigin-RevId: 316362463 Change-Id: Ic55889a70e5ae6fa9cc95690171cbebc0654dbcf
This commit is contained in:
parent
fbe175972d
commit
4021c9a091
@ -52,7 +52,7 @@ void HloReachabilityMap::SetReachabilityToUnionHelper(
|
||||
if (!absl::c_linear_search(inputs, instruction)) {
|
||||
bit_vector->SetToZero();
|
||||
}
|
||||
bit_vector->Set(GetIndex(instruction));
|
||||
bit_vector->Set(GetIndex(instruction).v);
|
||||
for (const HloInstruction* input : inputs) {
|
||||
if (input != instruction) {
|
||||
bit_vector->OrWith(GetBitVector(input));
|
||||
@ -65,23 +65,12 @@ void HloReachabilityMap::Replace(const HloInstruction* original,
|
||||
if (GetKey(original) == GetKey(replacement)) {
|
||||
return;
|
||||
}
|
||||
indices_[GetKey(replacement)] = GetIndex(original);
|
||||
indices_[GetKey(replacement)] = GetIndex(original).v;
|
||||
indices_.erase(GetKey(original));
|
||||
}
|
||||
|
||||
void HloReachabilityMap::SetReachable(const HloInstruction* a,
|
||||
const HloInstruction* b) {
|
||||
GetBitVector(b).Set(GetIndex(a));
|
||||
}
|
||||
|
||||
bool HloReachabilityMap::IsReachable(const HloInstruction* a,
|
||||
const HloInstruction* b) const {
|
||||
return GetBitVector(b).Get(GetIndex(a));
|
||||
}
|
||||
|
||||
bool HloReachabilityMap::IsConnected(const HloInstruction* a,
|
||||
const HloInstruction* b) const {
|
||||
return IsReachable(a, b) || IsReachable(b, a);
|
||||
void HloReachabilityMap::SetReachable(Index a, Index b) {
|
||||
GetBitVector(b).Set(a.v);
|
||||
}
|
||||
|
||||
std::unique_ptr<HloReachabilityMap> HloReachabilityMap::Build(
|
||||
|
@ -74,13 +74,34 @@ class HloReachabilityMap {
|
||||
absl::Span<const HloInstruction* const> inputs,
|
||||
const HloInstruction* instruction);
|
||||
|
||||
// An opaque index that clients can use to make repeated operations for the
|
||||
// same instruction faster, by calling GetIndex once for the instruction,
|
||||
// and then calling the variants of other interfaces that take Index arguments
|
||||
// rather than HloInstruction* arguments.
|
||||
struct Index {
|
||||
private:
|
||||
friend class HloReachabilityMap;
|
||||
|
||||
// Index assigned for a particular instruction. The value is used to index
|
||||
// into the vector of BitVectors and the BitVectors themselves.
|
||||
int v;
|
||||
};
|
||||
Index GetIndex(const HloInstruction* instruction) const {
|
||||
Index i;
|
||||
i.v = FindOrDie(indices_, GetKey(instruction));
|
||||
return i;
|
||||
}
|
||||
|
||||
// Sets entry so that IsReachable(a, b) will return true
|
||||
//
|
||||
// !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency
|
||||
// matrix in the internal graph of this HloReachabilityMap to have an edge
|
||||
// from a to b and does not transitively update any other part of the
|
||||
// adjacency matrix.
|
||||
void SetReachable(const HloInstruction* a, const HloInstruction* b);
|
||||
void SetReachable(const HloInstruction* a, const HloInstruction* b) {
|
||||
SetReachable(GetIndex(a), GetIndex(b));
|
||||
}
|
||||
void SetReachable(Index a, Index b);
|
||||
|
||||
// Updates the given reachability map after the immediate predecessor set
|
||||
// (operands and control predecessors) of 'instruction' has changed.
|
||||
@ -90,13 +111,21 @@ class HloReachabilityMap {
|
||||
//
|
||||
// Note that this function only correctly answers queries about reachability
|
||||
// if the set of edges that have been provided to this class are transitive.
|
||||
bool IsReachable(const HloInstruction* a, const HloInstruction* b) const;
|
||||
bool IsReachable(const HloInstruction* a, const HloInstruction* b) const {
|
||||
return IsReachable(GetIndex(a), GetIndex(b));
|
||||
}
|
||||
bool IsReachable(Index a, Index b) const { return GetBitVector(b).Get(a.v); }
|
||||
|
||||
// Returns true if "b" is reachable from "a" or "a" is reachable from "b"
|
||||
//
|
||||
// Note that this function only correctly answers queries about reachability
|
||||
// if the set of edges that have been provided to this class are transitive.
|
||||
bool IsConnected(const HloInstruction* a, const HloInstruction* b) const;
|
||||
bool IsConnected(const HloInstruction* a, const HloInstruction* b) const {
|
||||
return IsConnected(GetIndex(a), GetIndex(b));
|
||||
}
|
||||
bool IsConnected(Index a, Index b) const {
|
||||
return IsReachable(a, b) || IsReachable(b, a);
|
||||
}
|
||||
|
||||
// Checks if an instruction is in the Reachability map.
|
||||
bool IsPresent(const HloInstruction* a) const {
|
||||
@ -158,12 +187,17 @@ class HloReachabilityMap {
|
||||
|
||||
// Return the bitvector storing the reachability-to of the given instruction.
|
||||
const BitVector& GetBitVector(const HloInstruction* instruction) const {
|
||||
return bit_vectors_[GetIndex(instruction)];
|
||||
return GetBitVector(GetIndex(instruction));
|
||||
}
|
||||
BitVector& GetBitVector(const HloInstruction* instruction) {
|
||||
return bit_vectors_[GetIndex(instruction)];
|
||||
return GetBitVector(GetIndex(instruction));
|
||||
}
|
||||
|
||||
const BitVector& GetBitVector(Index index) const {
|
||||
return bit_vectors_[index.v];
|
||||
}
|
||||
BitVector& GetBitVector(Index index) { return bit_vectors_[index.v]; }
|
||||
|
||||
// Helper for SetReachabilityToUnion/FastSetReachabilityToUnion.
|
||||
void SetReachabilityToUnionHelper(
|
||||
absl::Span<const HloInstruction* const> inputs,
|
||||
@ -175,9 +209,8 @@ class HloReachabilityMap {
|
||||
absl::bit_cast<uint32>(instruction->parent()->parent()->unique_id());
|
||||
return (module_id << 32) | unique_id;
|
||||
}
|
||||
// Return the index of the given instruction. The value is used to index into
|
||||
// the vector of BitVectors and the BitVectors themselves.
|
||||
int GetIndex(const HloInstruction* instruction) const {
|
||||
// Return the index of the given instruction.
|
||||
int GetIndexInternal(const HloInstruction* instruction) const {
|
||||
return FindOrDie(indices_, GetKey(instruction));
|
||||
}
|
||||
|
||||
|
@ -346,12 +346,15 @@ void MultiOutputFusion::UpdateReachability(
|
||||
HloInstruction* instr1, HloInstruction* instr2,
|
||||
absl::Span<HloInstruction* const> instrs_to_update,
|
||||
const std::function<bool(HloInstruction*)>& skip) {
|
||||
auto instr1_i = reachability_->GetIndex(instr1);
|
||||
auto instr2_i = reachability_->GetIndex(instr2);
|
||||
for (auto instr : instrs_to_update) {
|
||||
if (skip != nullptr && skip(instr)) {
|
||||
continue;
|
||||
}
|
||||
bool instr2_instr = reachability_->IsReachable(instr2, instr);
|
||||
bool instr1_instr = reachability_->IsReachable(instr1, instr);
|
||||
auto instr_i = reachability_->GetIndex(instr);
|
||||
bool instr2_instr = reachability_->IsReachable(instr2_i, instr_i);
|
||||
bool instr1_instr = reachability_->IsReachable(instr1_i, instr_i);
|
||||
if (instr2_instr && instr1_instr) {
|
||||
// If a candidate was already reachable by both, no update needed.
|
||||
continue;
|
||||
@ -359,7 +362,7 @@ void MultiOutputFusion::UpdateReachability(
|
||||
if (instr2_instr) {
|
||||
reachability_->FastSetReachabilityToUnion({instr, instr1}, instr);
|
||||
}
|
||||
if (reachability_->IsReachable(instr1, instr)) {
|
||||
if (reachability_->IsReachable(instr1_i, instr_i)) {
|
||||
reachability_->FastSetReachabilityToUnion({instr, instr2}, instr);
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user