Make HloInstruction::RemoveUser amortized constant time

Previously it was linear in the number of users because we had to traverse
an std::vector to find the element to remove. After this change we store
the index inside the vector in the map used for fast lookup instead.

PiperOrigin-RevId: 266884015
This commit is contained in:
A. Unique TensorFlower 2019-09-03 02:16:29 -07:00 committed by TensorFlower Gardener
parent 95c5c99f71
commit 218836e14c
2 changed files with 24 additions and 17 deletions

View File

@ -1603,7 +1603,7 @@ HloInstruction::~HloInstruction() {
if (operand == nullptr) { if (operand == nullptr) {
continue; continue;
} }
if (operand->user_set_.find(this) != operand->user_set_.end()) { if (operand->user_map_.find(this) != operand->user_map_.end()) {
operand->RemoveUser(this); operand->RemoveUser(this);
} }
operands_[operand_num] = nullptr; operands_[operand_num] = nullptr;
@ -1786,8 +1786,8 @@ void HloInstruction::RemoveOperandsAtAscendingIndices(
} }
void HloInstruction::AddUser(HloInstruction* user) { void HloInstruction::AddUser(HloInstruction* user) {
if (!ContainsKey(user_set_, user)) { if (!ContainsKey(user_map_, user)) {
user_set_.insert(user); user_map_.emplace(user, users_.size());
users_.push_back(user); users_.push_back(user);
} }
} }
@ -1965,14 +1965,20 @@ uint64 HloInstruction::Hash() const {
uint64 HloInstruction::InnerHash() const { return 13; } uint64 HloInstruction::InnerHash() const { return 13; }
void HloInstruction::RemoveUser(HloInstruction* user) { void HloInstruction::RemoveUser(HloInstruction* user) {
auto set_it = user_set_.find(user); auto map_it = user_map_.find(user);
CHECK(set_it != user_set_.end()); CHECK(map_it != user_map_.end());
user_set_.erase(set_it);
// This is linear in the number of the users, but a vector provides a stable const int64 index = map_it->second;
// iteration order and much faster traversal. CHECK_EQ(users_[index], user);
auto vec_it = absl::c_find(users_, user);
CHECK(vec_it != users_.end()); // Move the last user into the position of the removed user.
users_.erase(vec_it); users_[index] = users_.back();
user_map_[users_.back()] = index;
// Remove the user from the map and drop the last slot from the vector what
// have been moved to the position of the original user.
user_map_.erase(map_it);
users_.pop_back();
} }
Status HloInstruction::ReplaceUseWith(HloInstruction* user, Status HloInstruction::ReplaceUseWith(HloInstruction* user,
@ -2063,7 +2069,7 @@ Status HloInstruction::ReplaceAllUsesWithDifferentShape(
} }
} }
users_.clear(); users_.clear();
user_set_.clear(); user_map_.clear();
if (new_producer_is_user) { if (new_producer_is_user) {
AddUser(new_producer); AddUser(new_producer);
} }

View File

@ -921,7 +921,7 @@ class HloInstruction {
// Returns true if this instruction is a user of 'instruction'. // Returns true if this instruction is a user of 'instruction'.
bool IsUserOf(const HloInstruction* instruction) const { bool IsUserOf(const HloInstruction* instruction) const {
return ContainsKey(instruction->user_set_, this); return ContainsKey(instruction->user_map_, this);
} }
// Adds a control dependency from this instruction to the given // Adds a control dependency from this instruction to the given
@ -1853,11 +1853,12 @@ class HloInstruction {
std::vector<HloInstruction*> control_predecessors_; std::vector<HloInstruction*> control_predecessors_;
// The users of this instruction. Users are HLOs where this instruction is an // The users of this instruction. Users are HLOs where this instruction is an
// operand. The vector users_ and the set user_set_ contain identical // operand. The vector users_ and the map user_map_ contain identical members.
// members. The set enables fast membership testing and the vector enables // The map enables fast membership testing and the vector enables fast, stable
// fast, stable iteration. // iteration. The value in the map contains the index of the instruction in
// the vector what enables fast removal.
std::vector<HloInstruction*> users_; std::vector<HloInstruction*> users_;
absl::flat_hash_set<const HloInstruction*> user_set_; absl::flat_hash_map<const HloInstruction*, int64> user_map_;
// The set of control successors of this instruction. // The set of control successors of this instruction.
std::vector<HloInstruction*> control_successors_; std::vector<HloInstruction*> control_successors_;