Add caching for reused parameters.

Since a fusion node is never split, we can assume that if a parameter was
reused, it will stay that way even if additional instructions are fused in.

PiperOrigin-RevId: 329497992
Change-Id: Ie0cdf7161d01704139783c329f88f7312bf7edf8
This commit is contained in:
Adrian Kuegel 2020-09-01 06:46:12 -07:00 committed by TensorFlower Gardener
parent 6293abccc7
commit adcee3ccc4
4 changed files with 34 additions and 5 deletions

View File

@ -95,7 +95,7 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
// Cost condition: not fuse (simple, expensive producers) and (consumers who
// reuse operand elements).
if (producer->opcode() != HloOpcode::kFusion && is_expensive(*producer) &&
consumer->ReusesOperandElements(operand_index)) {
ReusesOperandElements(consumer, operand_index)) {
VLOG(2) << "Fusion is not profitable.";
return false;
}

View File

@ -65,9 +65,8 @@ bool GpuInstructionFusion::ShouldFuseInexpensiveChecks(HloInstruction* consumer,
}
// Cost condition: not fuse (simple, expensive producers) and (consumers who
// reuse operand elements).
if (producer->opcode() != HloOpcode::kFusion &&
consumer->ReusesOperandElements(operand_index) &&
is_expensive(*producer)) {
if (producer->opcode() != HloOpcode::kFusion && is_expensive(*producer) &&
ReusesOperandElements(consumer, operand_index)) {
VLOG(4) << "Do not fuse simple, expensive producer " << producer->name()
<< " and consumer which reuses operand elements.";
return false;

View File

@ -711,4 +711,23 @@ HloInstruction::FusionKind InstructionFusion::ChooseKind(
return HloInstruction::FusionKind::kLoop;
}
bool InstructionFusion::ReusesOperandElements(const HloInstruction* consumer,
                                              int64 operand_index) {
  const auto* operand = consumer->operand(operand_index);
  // Fast path: consult the cache. A previously recorded "reused" answer for
  // this (consumer, operand) pair is still valid, because fusion computations
  // only ever grow and a reused parameter never becomes *not* reused.
  auto cache_entry = reused_fusion_operands_.find(consumer);
  if (cache_entry != reused_fusion_operands_.end() &&
      cache_entry->second.contains(operand)) {
    return true;
  }
  const bool result = consumer->ReusesOperandElements(operand_index);
  if (result) {
    // Only positive answers are cached (a "not reused" answer can flip once
    // more instructions are fused in). The key is the operand rather than the
    // fusion parameter, since parameter pointers are invalidated by the next
    // fusion.
    reused_fusion_operands_[consumer].insert(operand);
  }
  return result;
}
} // namespace xla

View File

@ -1,4 +1,3 @@
#include "absl/container/flat_hash_map.h"
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
@ -20,6 +19,8 @@ limitations under the License.
#include <functional>
#include <utility>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/fusion_queue.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@ -138,6 +139,11 @@ class InstructionFusion : public HloModulePass {
return config_collection_mode_;
}
// Returns whether 'consumer' may reuse elements of its `operand_index`th
// operand.
bool ReusesOperandElements(const HloInstruction* consumer,
int64 operand_index);
private:
// The set of producers whose consumers we cannot fuse into.
using HloInstructionSet = std::unordered_set<HloInstruction*>;
@ -172,6 +178,11 @@ class InstructionFusion : public HloModulePass {
// Configuration mode.
FusionConfigCollection config_collection_mode_;
// Caches which operands are reused inside fusion computations.
absl::flat_hash_map<const HloInstruction*,
absl::flat_hash_set<const HloInstruction*>>
reused_fusion_operands_;
TF_DISALLOW_COPY_AND_ASSIGN(InstructionFusion);
};