diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 122122aae55..22d9f1bc648 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -309,6 +309,8 @@ Status HloComputation::RemoveInstructionImpl(HloInstruction* instruction,
   auto inst_it = instruction_iterators_.find(instruction);
   TF_RET_CHECK(inst_it != instruction_iterators_.end());
   (*inst_it->second)->set_parent(nullptr);
+  to_be_deleted_.emplace_back(inst_it->second->release());
+  to_be_deleted_.back()->DetachFromOperandsAndUsers();
   instructions_.erase(inst_it->second);
   instruction_iterators_.erase(inst_it);
   return Status::OK();
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 9ca60403929..f1568858d9f 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -469,6 +469,12 @@ class HloComputation {
 
   int64 unique_id() const { return unique_id_; }
 
+  // Deallocate instructions that are marked by "RemoveInstruction". The two
+  // stage clean up process is designed such that HloPass can have stable
+  // internal pointers to HloInstructions while we create and remove
+  // HloInstructions in a pass.
+  void Cleanup() { to_be_deleted_.clear(); }
+
  private:
   explicit HloComputation(
       const string& name, int parameter_count,
@@ -527,6 +533,10 @@ class HloComputation {
   absl::flat_hash_map<const HloInstruction*, InstructionList::iterator>
       instruction_iterators_;
 
+  // Removed instructions are moved into to_be_deleted_ first and then
+  // deallocated when Cleanup is called.
+  std::vector<std::unique_ptr<HloInstruction>> to_be_deleted_;
+
   std::vector<HloInstruction*> param_instructions_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(HloComputation);
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 9f45cac028c..8aeb92b40de 100755
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -1661,7 +1661,11 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
   return clone;
 }
 
-HloInstruction::~HloInstruction() {
+void HloInstruction::DetachFromOperandsAndUsers() {
+  if (cleaned_up_) {
+    return;
+  }
+  cleaned_up_ = true;
   // Detach from operands. An instruction may be repeated as an operand. To
   // avoid calling RemoveUser twice on the same operand, check before remove.
   for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index a108a91d5f9..33c0daca686 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -480,7 +480,11 @@ class HloInstruction {
     kCustom,
   };
 
-  virtual ~HloInstruction();
+  virtual ~HloInstruction() { DetachFromOperandsAndUsers(); }
+
+  // Detaches an instruction from its operands and users. That is, remove the
+  // instruction from each operand's user set and user's operand set.
+  void DetachFromOperandsAndUsers();
 
   // Creates an instruction from the given proto. Arguments:
   //
@@ -2025,6 +2029,10 @@ class HloInstruction {
   // a default configuration.
   bool is_default_config_ = false;
 
+  // True if this instruction has already been detached from its user and
+  // operands.
+  bool cleaned_up_ = false;
+
   // String identifier for instruction.
   string name_;
 
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 5e662e0bebc..f25f4694f21 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -184,6 +184,13 @@ class HloModule {
   // Gets the number of instructions in this module.
   int64 instruction_count() const;
 
+  // Deallocate removed instructions in each computation.
+  void Cleanup() {
+    for (auto& comp : computations_) {
+      comp->Cleanup();
+    }
+  }
+
   // Compute and return a post order of all computations in the module. The sort
   // is defined like so: if computation A has an instruction which calls
   // computation B, then A will appear after B in the sort.
diff --git a/tensorflow/compiler/xla/service/hlo_module_group.h b/tensorflow/compiler/xla/service/hlo_module_group.h
index c4b10f3b22a..217f65b4a75 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group.h
@@ -64,6 +64,13 @@ class HloModuleGroup {
 
   string ToString() const;
 
+  // Deallocate removed instructions in each module.
+  void Cleanup() {
+    for (auto& module : modules_) {
+      module->Cleanup();
+    }
+  }
+
   // Serialize the module group to/from a proto.
   HloModuleGroupProto ToProto() const;
   static StatusOr<HloModuleGroup> CreateFromProto(
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
index ad4070e3e23..16fad113b0d 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
@@ -104,11 +104,15 @@ class HloPassPipeline : public HloPassInterface {
   // helpers enable templating of the core of the pipeline logic by providing
   // HloModule and HloModuleGroup specific methods with the same name.
   static StatusOr<bool> RunHelper(HloPassInterface* pass, HloModule* module) {
-    return pass->Run(module);
+    TF_ASSIGN_OR_RETURN(bool changed, pass->Run(module));
+    module->Cleanup();
+    return changed;
   }
   static StatusOr<bool> RunHelper(HloPassInterface* pass,
                                   HloModuleGroup* module_group) {
-    return pass->RunOnModuleGroup(module_group);
+    TF_ASSIGN_OR_RETURN(bool changed, pass->RunOnModuleGroup(module_group));
+    module_group->Cleanup();
+    return changed;
   }
 
   const string name_;
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc
index a8a4b7ef872..d97893b6d04 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc
@@ -368,12 +368,12 @@ bool MultiOutputFusion::Perform() {
   int changed = false;
   // Pick the top candidate from queue and try to merge.
   while (!worklist_.empty()) {
-    ToBeFused candidate = worklist_.top();
-    worklist_.pop();
+    ToBeFused candidate = worklist_.pop();
     HloInstruction* instr1 = candidate.instr1;
     HloInstruction* instr2 = candidate.instr2;
 
+    // Candidates are already fused.
     if (is_fused(instr1) || is_fused(instr2)) {
       continue;
     }
 
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h
index 18069e2f76c..f0b56eeff90 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.h
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.h
@@ -136,9 +136,34 @@ class MultiOutputFusion : public HloModulePass {
     HloInstruction* instr1;
     HloInstruction* instr2;
     int64 score;
-    ToBeFused(HloInstruction* instr1, HloInstruction* instr2, int64 score)
-        : instr1(instr1), instr2(instr2), score(score) {}
-    bool operator<(const ToBeFused& rhs) const { return score < rhs.score; }
+    int64 timestamp;
+    ToBeFused(HloInstruction* instr1, HloInstruction* instr2, int64 score,
+              int64 timestamp)
+        : instr1(instr1), instr2(instr2), score(score), timestamp(timestamp) {}
+    bool operator<(const ToBeFused& rhs) const {
+      return std::pair<int64, int64>(score, timestamp) <
+             std::pair<int64, int64>(rhs.score, rhs.timestamp);
+    }
+  };
+
+  // Stable priority queue where each insertion has a timestamp for
+  // deterministic popping.
+  class WorkList {
+   public:
+    bool empty() { return worklist_.empty(); }
+    ToBeFused pop() {
+      ToBeFused tmp = worklist_.top();
+      worklist_.pop();
+      return tmp;
+    }
+    template <typename... Args>
+    void emplace(Args&&... args) {
+      worklist_.emplace(std::forward<Args>(args)..., timestamp_++);
+    }
+
+   private:
+    std::priority_queue<ToBeFused> worklist_;
+    int64 timestamp_ = 0;
   };
 
   // Update the internal data structures before instr1 and instr2 are fused into
@@ -169,7 +194,7 @@
   }
 
   std::vector<FusionCandidate> candidates_;
-  std::priority_queue<ToBeFused> worklist_;
+  WorkList worklist_;
 
   // A map that maps an instruction to the index_.
   absl::flat_hash_map<HloInstruction*, int> candidates_index_;
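The central mechanism of the patch, two-stage instruction deletion, is easiest to see in isolation. The sketch below is a minimal standalone analogue, not the XLA API: the Graph and Node types and the main() driver are invented for illustration. Remove() only unlinks the object and parks its unique_ptr in to_be_deleted_, so raw pointers that a pass still holds keep pointing at live memory until Cleanup() runs; this is why HloPassPipeline::RunHelper calls Cleanup() only after pass->Run(module) returns.

#include <cassert>
#include <list>
#include <memory>
#include <vector>

struct Node {
  int id;
  explicit Node(int id) : id(id) {}
};

class Graph {
 public:
  Node* Add(int id) {
    nodes_.push_back(std::make_unique<Node>(id));
    return nodes_.back().get();
  }

  // Stage 1: unlink the node from the graph, but keep the allocation alive so
  // that raw Node* pointers held elsewhere do not dangle.
  void Remove(Node* node) {
    for (auto it = nodes_.begin(); it != nodes_.end(); ++it) {
      if (it->get() == node) {
        to_be_deleted_.push_back(std::move(*it));
        nodes_.erase(it);
        return;
      }
    }
  }

  // Stage 2: actually deallocate, called once the current pass has finished.
  void Cleanup() { to_be_deleted_.clear(); }

 private:
  std::list<std::unique_ptr<Node>> nodes_;
  std::vector<std::unique_ptr<Node>> to_be_deleted_;
};

int main() {
  Graph g;
  Node* a = g.Add(1);
  g.Remove(a);
  assert(a->id == 1);  // `a` is unlinked but still allocated
  g.Cleanup();         // now the storage is released
  return 0;
}

The WorkList change is independent of the deletion change: std::priority_queue leaves the pop order of equal-priority elements unspecified, so tie-breaking between equally scored fusion candidates could vary between runs. Tagging every insertion with an increasing timestamp and comparing (score, timestamp) lexicographically, as the patch's operator< does, makes the order deterministic; among equal scores the most recently inserted candidate pops first. Below is a simplified standalone version of the same idea; the Item struct and the string labels are illustrative only.

#include <cstdint>
#include <iostream>
#include <queue>
#include <string>
#include <utility>

struct Item {
  std::string name;
  int64_t score;
  int64_t timestamp;
  bool operator<(const Item& rhs) const {
    // Max-heap order: higher score first; on ties, higher (newer) timestamp
    // first, mirroring the pair comparison in the patch.
    return std::pair<int64_t, int64_t>(score, timestamp) <
           std::pair<int64_t, int64_t>(rhs.score, rhs.timestamp);
  }
};

class WorkList {
 public:
  bool empty() const { return queue_.empty(); }
  void emplace(std::string name, int64_t score) {
    queue_.push(Item{std::move(name), score, timestamp_++});
  }
  Item pop() {
    Item top = queue_.top();
    queue_.pop();
    return top;
  }

 private:
  std::priority_queue<Item> queue_;
  int64_t timestamp_ = 0;
};

int main() {
  WorkList wl;
  wl.emplace("a", 5);
  wl.emplace("b", 5);  // ties with "a" on score
  wl.emplace("c", 7);
  while (!wl.empty()) std::cout << wl.pop().name << "\n";  // prints c, b, a
  return 0;
}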