[Resubmit] Fix several issues of multi output fusion.

PiperOrigin-RevId: 295885785
Change-Id: I1e6350437987f63843181d704fe86660fd9cfb8c
This commit is contained in:
Yunxing Dai 2020-02-18 20:44:59 -08:00 committed by TensorFlower Gardener
parent 9771b11027
commit 823384e08f
9 changed files with 77 additions and 10 deletions

View File

@ -309,6 +309,8 @@ Status HloComputation::RemoveInstructionImpl(HloInstruction* instruction,
auto inst_it = instruction_iterators_.find(instruction);
TF_RET_CHECK(inst_it != instruction_iterators_.end());
(*inst_it->second)->set_parent(nullptr);
to_be_deleted_.emplace_back(inst_it->second->release());
to_be_deleted_.back()->DetachFromOperandsAndUsers();
instructions_.erase(inst_it->second);
instruction_iterators_.erase(inst_it);
return Status::OK();

View File

@ -469,6 +469,12 @@ class HloComputation {
int64 unique_id() const { return unique_id_; }
// Frees the instructions that were previously marked by "RemoveInstruction".
// Removal is split into two stages (mark first, deallocate here) so that an
// HloPass can hold stable internal pointers to HloInstructions while it
// creates and removes HloInstructions.
void Cleanup() {
  to_be_deleted_.clear();
}
private:
explicit HloComputation(
const string& name, int parameter_count,
@ -527,6 +533,10 @@ class HloComputation {
absl::flat_hash_map<const HloInstruction*, InstructionList::iterator>
instruction_iterators_;
// Removed instructions are moved into to_be_deleted_ first and then
// deallocated when Cleanup is called.
std::vector<std::unique_ptr<HloInstruction>> to_be_deleted_;
std::vector<HloInstruction*> param_instructions_;
TF_DISALLOW_COPY_AND_ASSIGN(HloComputation);

View File

@ -1661,7 +1661,11 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
return clone;
}
HloInstruction::~HloInstruction() {
void HloInstruction::DetachFromOperandsAndUsers() {
if (cleaned_up_) {
return;
}
cleaned_up_ = true;
// Detach from operands. An instruction may be repeated as an operand. To
// avoid calling RemoveUser twice on the same operand, check before remove.
for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {

View File

@ -480,7 +480,11 @@ class HloInstruction {
kCustom,
};
virtual ~HloInstruction();
virtual ~HloInstruction() { DetachFromOperandsAndUsers(); }
// Detaches an instruction from its operands and users. That is, removes the
// instruction from each operand's user set and each user's operand set.
void DetachFromOperandsAndUsers();
// Creates an instruction from the given proto. Arguments:
//
@ -2025,6 +2029,10 @@ class HloInstruction {
// a default configuration.
bool is_default_config_ = false;
// True if this instruction has already been detached from its users and
// operands.
bool cleaned_up_ = false;
// String identifier for instruction.
string name_;

View File

@ -184,6 +184,13 @@ class HloModule {
// Gets the number of instructions in this module.
int64 instruction_count() const;
// Deallocate removed instructions in each computation.
void Cleanup() {
for (auto& comp : computations_) {
comp->Cleanup();
}
}
// Compute and return a post order of all computations in the module. The sort
// is defined like so: if computation A has an instruction which calls
// computation B, then A will appear after B in the sort.

View File

@ -64,6 +64,13 @@ class HloModuleGroup {
string ToString() const;
// Deallocate removed instructions in each module.
void Cleanup() {
for (auto& module : modules_) {
module->Cleanup();
}
}
// Serialize the module group to/from a proto.
HloModuleGroupProto ToProto() const;
static StatusOr<HloModuleGroup> CreateFromProto(

View File

@ -104,11 +104,15 @@ class HloPassPipeline : public HloPassInterface {
// helpers enable templating of the core of the pipeline logic by providing
// HloModule and HloModuleGroup specific methods with the same name.
static StatusOr<bool> RunHelper(HloPassInterface* pass, HloModule* module) {
return pass->Run(module);
TF_ASSIGN_OR_RETURN(bool changed, pass->Run(module));
module->Cleanup();
return changed;
}
static StatusOr<bool> RunHelper(HloPassInterface* pass,
HloModuleGroup* module_group) {
return pass->RunOnModuleGroup(module_group);
TF_ASSIGN_OR_RETURN(bool changed, pass->RunOnModuleGroup(module_group));
module_group->Cleanup();
return changed;
}
const string name_;

View File

@ -368,12 +368,12 @@ bool MultiOutputFusion::Perform() {
int changed = false;
// Pick the top candidate from queue and try to merge.
while (!worklist_.empty()) {
ToBeFused candidate = worklist_.top();
worklist_.pop();
ToBeFused candidate = worklist_.pop();
HloInstruction* instr1 = candidate.instr1;
HloInstruction* instr2 = candidate.instr2;
// Candidates are already fused.
if (is_fused(instr1) || is_fused(instr2)) {
continue;
}

View File

@ -136,9 +136,34 @@ class MultiOutputFusion : public HloModulePass {
HloInstruction* instr1;
HloInstruction* instr2;
int64 score;
ToBeFused(HloInstruction* instr1, HloInstruction* instr2, int64 score)
: instr1(instr1), instr2(instr2), score(score) {}
bool operator<(const ToBeFused& rhs) const { return score < rhs.score; }
int64 timestamp;
ToBeFused(HloInstruction* instr1, HloInstruction* instr2, int64 score,
int64 timestamp)
: instr1(instr1), instr2(instr2), score(score), timestamp(timestamp) {}
bool operator<(const ToBeFused& rhs) const {
return std::pair<int64, int64>(score, timestamp) <
std::pair<int64, int64>(rhs.score, rhs.timestamp);
}
};
// Stable priority queue where each insertion has a timestamp for
// deterministic popping.
class WorkList {
public:
bool empty() { return worklist_.empty(); }
ToBeFused pop() {
ToBeFused tmp = worklist_.top();
worklist_.pop();
return tmp;
}
template <class... Args>
void emplace(Args&&... args) {
worklist_.emplace(std::forward<Args>(args)..., timestamp_++);
}
private:
std::priority_queue<ToBeFused> worklist_;
int64 timestamp_ = 0;
};
// Update the internal data structures before instr1 and instr2 are fused into
@ -169,7 +194,7 @@ class MultiOutputFusion : public HloModulePass {
}
std::vector<FusionCandidate> candidates_;
std::priority_queue<ToBeFused> worklist_;
WorkList worklist_;
// A map that maps an instruction to the index_.
absl::flat_hash_map<HloInstruction*, int> candidates_index_;