Refactor the code to avoid duplication (NFC).
IsFusedIrEmitterInefficient can reuse the code from FusionNodeIndexingEvaluation.
PiperOrigin-RevId: 334791886
Change-Id: I8bd812913355133bfcc0ea1f85792f47c550fb1c
This commit is contained in:
parent
94e163651a
commit
7ea86e9de8
@ -25,24 +25,33 @@ limitations under the License.
|
||||
namespace xla {
|
||||
|
||||
// Sets up the evaluation for the fusion node 'fusion'.
//
// 'root_usage_count' is the index usage count recorded for the fusion node
// itself (the header declaration defaults it to 1). A caller that evaluates a
// producer fusion which would be merged into a consumer can pass the
// producer's evaluated duplication here instead.
FusionNodeIndexingEvaluation::FusionNodeIndexingEvaluation(
    const HloInstruction* fusion, int64 root_usage_count)
    : fusion_(fusion) {
  // Register the fusion node as the indexing user of its fused root.
  HloInstruction* root = fusion->fused_expression_root();
  indexing_users_[root].insert(fusion);
  // Seed the fusion node's own count exactly once, before the cache is
  // recomputed.
  index_usage_count_[fusion] = root_usage_count;
  RecomputeCache();
}
|
||||
|
||||
// This constant is arbitrarily chosen. Essentially we don't want to have too
|
||||
// much code duplication, because it slows down the compilation time. There is
|
||||
// a tradeoff between compilation time and runtime here.
|
||||
const int64 FusionNodeIndexingEvaluation::kAllowedCodeDuplication = 15;
|
||||
|
||||
bool FusionNodeIndexingEvaluation::CodeDuplicationTooHigh(
|
||||
const HloInstruction* producer) const {
|
||||
// This constant is arbitrarily chosen. Essentially we don't want to have too
|
||||
// much code duplication, because it slows down the compilation time. There is
|
||||
// a tradeoff between compilation time and runtime here.
|
||||
const int64 kAllowedCodeDuplication = 15;
|
||||
|
||||
return EvaluateEmittedInstructions(producer) > kAllowedCodeDuplication;
|
||||
}
|
||||
|
||||
bool FusionNodeIndexingEvaluation::MaxCodeDuplicationTooHigh() const {
|
||||
for (const auto& entry : index_usage_count_) {
|
||||
if (entry.second > kAllowedCodeDuplication) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
int64 FusionNodeIndexingEvaluation::EvaluateEmittedInstructions(
|
||||
const HloInstruction* producer) const {
|
||||
int64 total = 0;
|
||||
|
||||
@ -24,13 +24,19 @@ limitations under the License.
|
||||
namespace xla {
|
||||
class FusionNodeIndexingEvaluation {
|
||||
public:
|
||||
explicit FusionNodeIndexingEvaluation(const HloInstruction* fusion);
|
||||
explicit FusionNodeIndexingEvaluation(const HloInstruction* fusion,
|
||||
int64 root_usage_count = 1);
|
||||
|
||||
// Evaluate the number of times 'producer' would be emitted if it is fused
|
||||
// into 'fusion_'. If the duplication is "too high" (some arbitrarily chosen
|
||||
// constant), returns true.
|
||||
bool CodeDuplicationTooHigh(const HloInstruction* producer) const;
|
||||
|
||||
// Evaluate the maximum code duplication inside the fusion node. If the
|
||||
// maximum code duplication is "too high" (some arbitrarily chosen constant),
|
||||
// returns true.
|
||||
bool MaxCodeDuplicationTooHigh() const;
|
||||
|
||||
// Evaluate the number of times 'producer' would be emitted if it is fused
|
||||
// into 'fusion_'.
|
||||
int64 EvaluateEmittedInstructions(const HloInstruction* producer) const;
|
||||
@ -53,6 +59,8 @@ class FusionNodeIndexingEvaluation {
|
||||
HloInstruction* fusion_operand);
|
||||
|
||||
private:
|
||||
static const int64 kAllowedCodeDuplication;
|
||||
|
||||
// Computes the 'indexing_users_' and 'index_usage_count_' maps based on the
|
||||
// current instructions inside the fusion node. Also updates
|
||||
// 'total_emitted_instructions_' accordingly.
|
||||
|
||||
@ -158,10 +158,10 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
|
||||
"//tensorflow/compiler/xla/service:fusion_node_indexing_evaluation",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@llvm-project//llvm:Core",
|
||||
|
||||
@ -18,12 +18,11 @@ limitations under the License.
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "llvm/IR/BasicBlock.h"
|
||||
#include "llvm/IR/Value.h"
|
||||
#include "tensorflow/compiler/xla/map_util.h"
|
||||
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
|
||||
#include "tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
@ -214,89 +213,15 @@ bool FusedIrEmitter::IsFusedIrEmitterInefficient(
|
||||
if (consumer->opcode() != HloOpcode::kFusion) {
|
||||
return false;
|
||||
}
|
||||
// Collects for each instruction in the fusion node from which (indirect)
|
||||
// users newly created index values are passed. Roughly speaking, we reuse
|
||||
// index values if the shapes are equal when ignoring the element type (we may
|
||||
// reuse also if the shape change is a bitcast, but we don't consider that
|
||||
// here). By ignoring potential reuses our estimate whether the fusion emitter
|
||||
// is inefficient is a bit more conservative than necessary.
|
||||
absl::flat_hash_map<const HloInstruction*,
|
||||
absl::flat_hash_set<const HloInstruction*>>
|
||||
indexing_users;
|
||||
// Stores the number of different index accesses for each instruction in the
|
||||
// fusion node. The fusion emitter caches access with the same index, so this
|
||||
// value indicates how many times a specific instruction will be emitted.
|
||||
absl::flat_hash_map<const HloInstruction*, int64> index_usage_count;
|
||||
index_usage_count[consumer] = 1;
|
||||
|
||||
auto evaluate_fusion_computation = [&indexing_users, &index_usage_count](
|
||||
const HloInstruction* fusion) {
|
||||
auto postorder =
|
||||
fusion->fused_instructions_computation()->MakeInstructionPostOrder();
|
||||
std::reverse(postorder.begin(), postorder.end());
|
||||
for (const auto* instruction : postorder) {
|
||||
if (instruction->opcode() == HloOpcode::kParameter) {
|
||||
continue;
|
||||
}
|
||||
int64& total = index_usage_count[instruction];
|
||||
if (indexing_users[instruction].empty()) {
|
||||
total = index_usage_count[fusion];
|
||||
} else {
|
||||
total = 0;
|
||||
for (const auto* user : indexing_users[instruction]) {
|
||||
total += index_usage_count[user];
|
||||
}
|
||||
}
|
||||
for (const auto* operand : instruction->operands()) {
|
||||
// For simplicity we assume that all shape and layout changing
|
||||
// operations except Transposes invalidate index reuse. Transposes are
|
||||
// special: although they are shape changing, we can reuse the
|
||||
// multi-dimensional index for the operand by permuting it.
|
||||
if (instruction->opcode() == HloOpcode::kTranspose ||
|
||||
Shape::Equal().IgnoreElementType()(operand->shape(),
|
||||
instruction->shape())) {
|
||||
// If the index is reused, it means the operand gets index values
|
||||
// from the same set of (indirect) users as 'instruction' itself.
|
||||
indexing_users[operand].insert(indexing_users[instruction].begin(),
|
||||
indexing_users[instruction].end());
|
||||
} else {
|
||||
// If the index is not reused, it means 'instruction' computes a
|
||||
// new index derived from the index it gets.
|
||||
indexing_users[operand].insert(instruction);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
evaluate_fusion_computation(consumer);
|
||||
|
||||
// Also account for the 'producer' if it would be fused. Find the operand it
|
||||
// corresponds to.
|
||||
for (int64 operand_num = 0; operand_num < consumer->operand_count();
|
||||
++operand_num) {
|
||||
if (consumer->operand(operand_num) == producer) {
|
||||
auto instruction = consumer->fused_parameter(operand_num);
|
||||
int64& total = index_usage_count[producer];
|
||||
total = 0;
|
||||
for (const auto* user : indexing_users[instruction]) {
|
||||
total += index_usage_count[user];
|
||||
}
|
||||
break;
|
||||
}
|
||||
FusionNodeIndexingEvaluation eval_consumer(consumer);
|
||||
if (producer->opcode() != HloOpcode::kFusion) {
|
||||
return eval_consumer.CodeDuplicationTooHigh(producer);
|
||||
}
|
||||
|
||||
// If 'producer' is a fusion node as well, also evaluate it.
|
||||
if (producer->opcode() == HloOpcode::kFusion) {
|
||||
evaluate_fusion_computation(producer);
|
||||
}
|
||||
|
||||
for (const auto& entry : index_usage_count) {
|
||||
// Check that the code duplication has at most a factor of 15 (where 15 is
|
||||
// an arbitrary constant that seems to work).
|
||||
if (entry.second > 15) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
// If 'producer' is a fusion node as well, also evaluate it. Pass the
|
||||
// evaluated duplication of the fusion node if it is merged into consumer.
|
||||
FusionNodeIndexingEvaluation eval_producer(
|
||||
producer, eval_consumer.EvaluateEmittedInstructions(producer));
|
||||
return eval_producer.MaxCodeDuplicationTooHigh();
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user