Make HloInstruction::RemoveUser amortized constant time
Previously it was linear in the number of users because we had to traverse an std::vector to find the element to remove. After this change we store the index inside the vector in the map used for fast lookup instead. PiperOrigin-RevId: 266884015
This commit is contained in:
parent
95c5c99f71
commit
218836e14c
@ -1603,7 +1603,7 @@ HloInstruction::~HloInstruction() {
|
||||
if (operand == nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (operand->user_set_.find(this) != operand->user_set_.end()) {
|
||||
if (operand->user_map_.find(this) != operand->user_map_.end()) {
|
||||
operand->RemoveUser(this);
|
||||
}
|
||||
operands_[operand_num] = nullptr;
|
||||
@ -1786,8 +1786,8 @@ void HloInstruction::RemoveOperandsAtAscendingIndices(
|
||||
}
|
||||
|
||||
void HloInstruction::AddUser(HloInstruction* user) {
|
||||
if (!ContainsKey(user_set_, user)) {
|
||||
user_set_.insert(user);
|
||||
if (!ContainsKey(user_map_, user)) {
|
||||
user_map_.emplace(user, users_.size());
|
||||
users_.push_back(user);
|
||||
}
|
||||
}
|
||||
@ -1965,14 +1965,20 @@ uint64 HloInstruction::Hash() const {
|
||||
uint64 HloInstruction::InnerHash() const { return 13; }
|
||||
|
||||
void HloInstruction::RemoveUser(HloInstruction* user) {
|
||||
auto set_it = user_set_.find(user);
|
||||
CHECK(set_it != user_set_.end());
|
||||
user_set_.erase(set_it);
|
||||
// This is linear in the number of the users, but a vector provides a stable
|
||||
// iteration order and much faster traversal.
|
||||
auto vec_it = absl::c_find(users_, user);
|
||||
CHECK(vec_it != users_.end());
|
||||
users_.erase(vec_it);
|
||||
auto map_it = user_map_.find(user);
|
||||
CHECK(map_it != user_map_.end());
|
||||
|
||||
const int64 index = map_it->second;
|
||||
CHECK_EQ(users_[index], user);
|
||||
|
||||
// Move the last user into the position of the removed user.
|
||||
users_[index] = users_.back();
|
||||
user_map_[users_.back()] = index;
|
||||
|
||||
// Remove the user from the map and drop the last slot from the vector what
|
||||
// have been moved to the position of the original user.
|
||||
user_map_.erase(map_it);
|
||||
users_.pop_back();
|
||||
}
|
||||
|
||||
Status HloInstruction::ReplaceUseWith(HloInstruction* user,
|
||||
@ -2063,7 +2069,7 @@ Status HloInstruction::ReplaceAllUsesWithDifferentShape(
|
||||
}
|
||||
}
|
||||
users_.clear();
|
||||
user_set_.clear();
|
||||
user_map_.clear();
|
||||
if (new_producer_is_user) {
|
||||
AddUser(new_producer);
|
||||
}
|
||||
|
@ -921,7 +921,7 @@ class HloInstruction {
|
||||
|
||||
// Returns true if this instruction is a user of 'instruction'.
|
||||
bool IsUserOf(const HloInstruction* instruction) const {
|
||||
return ContainsKey(instruction->user_set_, this);
|
||||
return ContainsKey(instruction->user_map_, this);
|
||||
}
|
||||
|
||||
// Adds a control dependency from this instruction to the given
|
||||
@ -1853,11 +1853,12 @@ class HloInstruction {
|
||||
std::vector<HloInstruction*> control_predecessors_;
|
||||
|
||||
// The users of this instruction. Users are HLOs where this instruction is an
|
||||
// operand. The vector users_ and the set user_set_ contain identical
|
||||
// members. The set enables fast membership testing and the vector enables
|
||||
// fast, stable iteration.
|
||||
// operand. The vector users_ and the map user_map_ contain identical members.
|
||||
// The map enables fast membership testing and the vector enables fast, stable
|
||||
// iteration. The value in the map contains the index of the instruction in
|
||||
// the vector what enables fast removal.
|
||||
std::vector<HloInstruction*> users_;
|
||||
absl::flat_hash_set<const HloInstruction*> user_set_;
|
||||
absl::flat_hash_map<const HloInstruction*, int64> user_map_;
|
||||
|
||||
// The set of control successors of this instruction.
|
||||
std::vector<HloInstruction*> control_successors_;
|
||||
|
Loading…
Reference in New Issue
Block a user