diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index 68be484e6c2..9ea9a585465 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -52,7 +52,7 @@ void HloReachabilityMap::SetReachabilityToUnionHelper( if (!absl::c_linear_search(inputs, instruction)) { bit_vector->SetToZero(); } - bit_vector->Set(GetIndex(instruction)); + bit_vector->Set(GetIndex(instruction).v); for (const HloInstruction* input : inputs) { if (input != instruction) { bit_vector->OrWith(GetBitVector(input)); @@ -65,23 +65,12 @@ void HloReachabilityMap::Replace(const HloInstruction* original, if (GetKey(original) == GetKey(replacement)) { return; } - indices_[GetKey(replacement)] = GetIndex(original); + indices_[GetKey(replacement)] = GetIndex(original).v; indices_.erase(GetKey(original)); } -void HloReachabilityMap::SetReachable(const HloInstruction* a, - const HloInstruction* b) { - GetBitVector(b).Set(GetIndex(a)); -} - -bool HloReachabilityMap::IsReachable(const HloInstruction* a, - const HloInstruction* b) const { - return GetBitVector(b).Get(GetIndex(a)); -} - -bool HloReachabilityMap::IsConnected(const HloInstruction* a, - const HloInstruction* b) const { - return IsReachable(a, b) || IsReachable(b, a); +void HloReachabilityMap::SetReachable(Index a, Index b) { + GetBitVector(b).Set(a.v); } std::unique_ptr HloReachabilityMap::Build( diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index 1d089333ef0..15edf315560 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -74,13 +74,34 @@ class HloReachabilityMap { absl::Span inputs, const HloInstruction* instruction); + // An opaque index that clients can use to make repeated operations for the + // same instruction faster, by calling GetIndex once for the instruction, + // and then calling the variants of other interfaces that take Index arguments + // rather than HloInstruction* arguments. + struct Index { + private: + friend class HloReachabilityMap; + + // Index assigned for a particular instruction. The value is used to index + // into the vector of BitVectors and the BitVectors themselves. + int v; + }; + Index GetIndex(const HloInstruction* instruction) const { + Index i; + i.v = FindOrDie(indices_, GetKey(instruction)); + return i; + } + // Sets entry so that IsReachable(a, b) will return true // // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency // matrix in the internal graph of this HloReachabilityMap to have an edge // from a to b and does not transitively update any other part of the // adjacency matrix. - void SetReachable(const HloInstruction* a, const HloInstruction* b); + void SetReachable(const HloInstruction* a, const HloInstruction* b) { + SetReachable(GetIndex(a), GetIndex(b)); + } + void SetReachable(Index a, Index b); // Updates the given reachability map after the immediate predecessor set // (operands and control predecessors) of 'instruction' has changed. @@ -90,13 +111,21 @@ class HloReachabilityMap { // // Note that this function only correctly answers queries about reachability // if the set of edges that have been provided to this class are transitive. - bool IsReachable(const HloInstruction* a, const HloInstruction* b) const; + bool IsReachable(const HloInstruction* a, const HloInstruction* b) const { + return IsReachable(GetIndex(a), GetIndex(b)); + } + bool IsReachable(Index a, Index b) const { return GetBitVector(b).Get(a.v); } // Returns true if "b" is reachable from "a" or "a" is reachable from "b" // // Note that this function only correctly answers queries about reachability // if the set of edges that have been provided to this class are transitive. - bool IsConnected(const HloInstruction* a, const HloInstruction* b) const; + bool IsConnected(const HloInstruction* a, const HloInstruction* b) const { + return IsConnected(GetIndex(a), GetIndex(b)); + } + bool IsConnected(Index a, Index b) const { + return IsReachable(a, b) || IsReachable(b, a); + } // Checks if an instruction is in the Reachability map. bool IsPresent(const HloInstruction* a) const { @@ -158,12 +187,17 @@ class HloReachabilityMap { // Return the bitvector storing the reachability-to of the given instruction. const BitVector& GetBitVector(const HloInstruction* instruction) const { - return bit_vectors_[GetIndex(instruction)]; + return GetBitVector(GetIndex(instruction)); } BitVector& GetBitVector(const HloInstruction* instruction) { - return bit_vectors_[GetIndex(instruction)]; + return GetBitVector(GetIndex(instruction)); } + const BitVector& GetBitVector(Index index) const { + return bit_vectors_[index.v]; + } + BitVector& GetBitVector(Index index) { return bit_vectors_[index.v]; } + // Helper for SetReachabilityToUnion/FastSetReachabilityToUnion. void SetReachabilityToUnionHelper( absl::Span inputs, @@ -175,9 +209,8 @@ class HloReachabilityMap { absl::bit_cast(instruction->parent()->parent()->unique_id()); return (module_id << 32) | unique_id; } - // Return the index of the given instruction. The value is used to index into - // the vector of BitVectors and the BitVectors themselves. - int GetIndex(const HloInstruction* instruction) const { + // Return the index of the given instruction. + int GetIndexInternal(const HloInstruction* instruction) const { return FindOrDie(indices_, GetKey(instruction)); } diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index 7b4819d8c03..b95b27d6291 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -346,12 +346,15 @@ void MultiOutputFusion::UpdateReachability( HloInstruction* instr1, HloInstruction* instr2, absl::Span instrs_to_update, const std::function& skip) { + auto instr1_i = reachability_->GetIndex(instr1); + auto instr2_i = reachability_->GetIndex(instr2); for (auto instr : instrs_to_update) { if (skip != nullptr && skip(instr)) { continue; } - bool instr2_instr = reachability_->IsReachable(instr2, instr); - bool instr1_instr = reachability_->IsReachable(instr1, instr); + auto instr_i = reachability_->GetIndex(instr); + bool instr2_instr = reachability_->IsReachable(instr2_i, instr_i); + bool instr1_instr = reachability_->IsReachable(instr1_i, instr_i); if (instr2_instr && instr1_instr) { // If a candidate was already reachable by both, no update needed. continue; @@ -359,7 +362,7 @@ void MultiOutputFusion::UpdateReachability( if (instr2_instr) { reachability_->FastSetReachabilityToUnion({instr, instr1}, instr); } - if (reachability_->IsReachable(instr1, instr)) { + if (reachability_->IsReachable(instr1_i, instr_i)) { reachability_->FastSetReachabilityToUnion({instr, instr2}, instr); } }