[Resubmit] Fix several issues of multi output fusion.
PiperOrigin-RevId: 295885785 Change-Id: I1e6350437987f63843181d704fe86660fd9cfb8c
This commit is contained in:
parent
9771b11027
commit
823384e08f
@ -309,6 +309,8 @@ Status HloComputation::RemoveInstructionImpl(HloInstruction* instruction,
|
||||
auto inst_it = instruction_iterators_.find(instruction);
|
||||
TF_RET_CHECK(inst_it != instruction_iterators_.end());
|
||||
(*inst_it->second)->set_parent(nullptr);
|
||||
to_be_deleted_.emplace_back(inst_it->second->release());
|
||||
to_be_deleted_.back()->DetachFromOperandsAndUsers();
|
||||
instructions_.erase(inst_it->second);
|
||||
instruction_iterators_.erase(inst_it);
|
||||
return Status::OK();
|
||||
|
@ -469,6 +469,12 @@ class HloComputation {
|
||||
|
||||
int64 unique_id() const { return unique_id_; }
|
||||
|
||||
// Deallocate instructions that are marked by "RemoveInstruction". The two
|
||||
// stage clean up process is designed such that HloPass can have stable
|
||||
// internal pointers to HloInstructions while we create and remove
|
||||
// HloInstructions in a pass.
|
||||
void Cleanup() { to_be_deleted_.clear(); }
|
||||
|
||||
private:
|
||||
explicit HloComputation(
|
||||
const string& name, int parameter_count,
|
||||
@ -527,6 +533,10 @@ class HloComputation {
|
||||
absl::flat_hash_map<const HloInstruction*, InstructionList::iterator>
|
||||
instruction_iterators_;
|
||||
|
||||
// Removed instructions are moved into to_be_deleted_ first and then
|
||||
// deallocated when Cleanup is called.
|
||||
std::vector<std::unique_ptr<HloInstruction>> to_be_deleted_;
|
||||
|
||||
std::vector<HloInstruction*> param_instructions_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(HloComputation);
|
||||
|
@ -1661,7 +1661,11 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
|
||||
return clone;
|
||||
}
|
||||
|
||||
HloInstruction::~HloInstruction() {
|
||||
void HloInstruction::DetachFromOperandsAndUsers() {
|
||||
if (cleaned_up_) {
|
||||
return;
|
||||
}
|
||||
cleaned_up_ = true;
|
||||
// Detach from operands. An instruction may be repeated as an operand. To
|
||||
// avoid calling RemoveUser twice on the same operand, check before remove.
|
||||
for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
|
||||
|
@ -480,7 +480,11 @@ class HloInstruction {
|
||||
kCustom,
|
||||
};
|
||||
|
||||
virtual ~HloInstruction();
|
||||
virtual ~HloInstruction() { DetachFromOperandsAndUsers(); }
|
||||
|
||||
// Detaches an instruction from its operands and users. That is, remove the
|
||||
// instruction from each operand's user set and user's operand set.
|
||||
void DetachFromOperandsAndUsers();
|
||||
|
||||
// Creates an instruction from the given proto. Arguments:
|
||||
//
|
||||
@ -2025,6 +2029,10 @@ class HloInstruction {
|
||||
// a default configuration.
|
||||
bool is_default_config_ = false;
|
||||
|
||||
// True if this instruction has already been detached from its user and
|
||||
// operands.
|
||||
bool cleaned_up_ = false;
|
||||
|
||||
// String identifier for instruction.
|
||||
string name_;
|
||||
|
||||
|
@ -184,6 +184,13 @@ class HloModule {
|
||||
// Gets the number of instructions in this module.
|
||||
int64 instruction_count() const;
|
||||
|
||||
// Deallocate removed instructions in each computation.
|
||||
void Cleanup() {
|
||||
for (auto& comp : computations_) {
|
||||
comp->Cleanup();
|
||||
}
|
||||
}
|
||||
|
||||
// Compute and return a post order of all computations in the module. The sort
|
||||
// is defined like so: if computation A has an instruction which calls
|
||||
// computation B, then A will appear after B in the sort.
|
||||
|
@ -64,6 +64,13 @@ class HloModuleGroup {
|
||||
|
||||
string ToString() const;
|
||||
|
||||
// Deallocate removed instructions in each module.
|
||||
void Cleanup() {
|
||||
for (auto& module : modules_) {
|
||||
module->Cleanup();
|
||||
}
|
||||
}
|
||||
|
||||
// Serialize the module group to/from a proto.
|
||||
HloModuleGroupProto ToProto() const;
|
||||
static StatusOr<HloModuleGroup> CreateFromProto(
|
||||
|
@ -104,11 +104,15 @@ class HloPassPipeline : public HloPassInterface {
|
||||
// helpers enable templating of the core of the pipeline logic by providing
|
||||
// HloModule and HloModuleGroup specific methods with the same name.
|
||||
static StatusOr<bool> RunHelper(HloPassInterface* pass, HloModule* module) {
|
||||
return pass->Run(module);
|
||||
TF_ASSIGN_OR_RETURN(bool changed, pass->Run(module));
|
||||
module->Cleanup();
|
||||
return changed;
|
||||
}
|
||||
static StatusOr<bool> RunHelper(HloPassInterface* pass,
|
||||
HloModuleGroup* module_group) {
|
||||
return pass->RunOnModuleGroup(module_group);
|
||||
TF_ASSIGN_OR_RETURN(bool changed, pass->RunOnModuleGroup(module_group));
|
||||
module_group->Cleanup();
|
||||
return changed;
|
||||
}
|
||||
|
||||
const string name_;
|
||||
|
@ -368,12 +368,12 @@ bool MultiOutputFusion::Perform() {
|
||||
int changed = false;
|
||||
// Pick the top candidate from queue and try to merge.
|
||||
while (!worklist_.empty()) {
|
||||
ToBeFused candidate = worklist_.top();
|
||||
worklist_.pop();
|
||||
ToBeFused candidate = worklist_.pop();
|
||||
|
||||
HloInstruction* instr1 = candidate.instr1;
|
||||
HloInstruction* instr2 = candidate.instr2;
|
||||
|
||||
// Candidates are already fused.
|
||||
if (is_fused(instr1) || is_fused(instr2)) {
|
||||
continue;
|
||||
}
|
||||
|
@ -136,9 +136,34 @@ class MultiOutputFusion : public HloModulePass {
|
||||
HloInstruction* instr1;
|
||||
HloInstruction* instr2;
|
||||
int64 score;
|
||||
ToBeFused(HloInstruction* instr1, HloInstruction* instr2, int64 score)
|
||||
: instr1(instr1), instr2(instr2), score(score) {}
|
||||
bool operator<(const ToBeFused& rhs) const { return score < rhs.score; }
|
||||
int64 timestamp;
|
||||
ToBeFused(HloInstruction* instr1, HloInstruction* instr2, int64 score,
|
||||
int64 timestamp)
|
||||
: instr1(instr1), instr2(instr2), score(score), timestamp(timestamp) {}
|
||||
bool operator<(const ToBeFused& rhs) const {
|
||||
return std::pair<int64, int64>(score, timestamp) <
|
||||
std::pair<int64, int64>(rhs.score, rhs.timestamp);
|
||||
}
|
||||
};
|
||||
|
||||
// Stable priority queue where each insertion has a timestamp for
|
||||
// deterministic popping.
|
||||
class WorkList {
|
||||
public:
|
||||
bool empty() { return worklist_.empty(); }
|
||||
ToBeFused pop() {
|
||||
ToBeFused tmp = worklist_.top();
|
||||
worklist_.pop();
|
||||
return tmp;
|
||||
}
|
||||
template <class... Args>
|
||||
void emplace(Args&&... args) {
|
||||
worklist_.emplace(std::forward<Args>(args)..., timestamp_++);
|
||||
}
|
||||
|
||||
private:
|
||||
std::priority_queue<ToBeFused> worklist_;
|
||||
int64 timestamp_ = 0;
|
||||
};
|
||||
|
||||
// Update the internal data structures before instr1 and instr2 are fused into
|
||||
@ -169,7 +194,7 @@ class MultiOutputFusion : public HloModulePass {
|
||||
}
|
||||
|
||||
std::vector<FusionCandidate> candidates_;
|
||||
std::priority_queue<ToBeFused> worklist_;
|
||||
WorkList worklist_;
|
||||
|
||||
// A map that maps an instruction to the index_.
|
||||
absl::flat_hash_map<HloInstruction*, int> candidates_index_;
|
||||
|
Loading…
x
Reference in New Issue
Block a user