Refactor the code to avoid duplication (NFC).

IsFusedIrEmitterInefficient can reuse the code from
FusionNodeIndexingEvaluation.

PiperOrigin-RevId: 334791886
Change-Id: I8bd812913355133bfcc0ea1f85792f47c550fb1c
Adrian Kuegel, 2020-10-01 05:12:01 -07:00, committed by TensorFlower Gardener
parent 94e163651a
commit 7ea86e9de8
4 changed files with 35 additions and 93 deletions
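
In short, IsFusedIrEmitterInefficient no longer maintains its own indexing_users / index_usage_count maps; it constructs FusionNodeIndexingEvaluation objects instead. Condensed from the last hunk below, the rewritten body is:

  FusionNodeIndexingEvaluation eval_consumer(consumer);
  if (producer->opcode() != HloOpcode::kFusion) {
    return eval_consumer.CodeDuplicationTooHigh(producer);
  }
  // If 'producer' is a fusion node as well, also evaluate it, seeded with how
  // often the consumer would emit it if the two were merged.
  FusionNodeIndexingEvaluation eval_producer(
      producer, eval_consumer.EvaluateEmittedInstructions(producer));
  return eval_producer.MaxCodeDuplicationTooHigh();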


@@ -25,24 +25,33 @@ limitations under the License.
 namespace xla {
 FusionNodeIndexingEvaluation::FusionNodeIndexingEvaluation(
-    const HloInstruction* fusion)
+    const HloInstruction* fusion, int64 root_usage_count)
     : fusion_(fusion) {
   HloInstruction* root = fusion->fused_expression_root();
   indexing_users_[root].insert(fusion);
-  index_usage_count_[fusion] = 1;
+  index_usage_count_[fusion] = root_usage_count;
   RecomputeCache();
 }
+// This constant is arbitrarily chosen. Essentially we don't want to have too
+// much code duplication, because it slows down the compilation time. There is
+// a tradeoff between compilation time and runtime here.
+const int64 FusionNodeIndexingEvaluation::kAllowedCodeDuplication = 15;
 bool FusionNodeIndexingEvaluation::CodeDuplicationTooHigh(
     const HloInstruction* producer) const {
-  // This constant is arbitrarily chosen. Essentially we don't want to have too
-  // much code duplication, because it slows down the compilation time. There is
-  // a tradeoff between compilation time and runtime here.
-  const int64 kAllowedCodeDuplication = 15;
   return EvaluateEmittedInstructions(producer) > kAllowedCodeDuplication;
 }
+bool FusionNodeIndexingEvaluation::MaxCodeDuplicationTooHigh() const {
+  for (const auto& entry : index_usage_count_) {
+    if (entry.second > kAllowedCodeDuplication) {
+      return true;
+    }
+  }
+  return false;
+}
 int64 FusionNodeIndexingEvaluation::EvaluateEmittedInstructions(
     const HloInstruction* producer) const {
   int64 total = 0;


@@ -24,13 +24,19 @@ limitations under the License.
 namespace xla {
 class FusionNodeIndexingEvaluation {
  public:
-  explicit FusionNodeIndexingEvaluation(const HloInstruction* fusion);
+  explicit FusionNodeIndexingEvaluation(const HloInstruction* fusion,
+                                        int64 root_usage_count = 1);
   // Evaluate the number of times 'producer' would be emitted if it is fused
   // into 'fusion_'. If the duplication is "too high" (some arbitrary chosen
   // constant), returns true.
   bool CodeDuplicationTooHigh(const HloInstruction* producer) const;
+  // Evaluate the maximum code duplication inside the fusion node. If the
+  // maximum code duplication is "too high" (some arbitrary chosen constant),
+  // returns true.
+  bool MaxCodeDuplicationTooHigh() const;
   // Evaluate the number of times 'producer' would be emitted if it is fused
   // into 'fusion_'.
   int64 EvaluateEmittedInstructions(const HloInstruction* producer) const;
@@ -53,6 +59,8 @@ class FusionNodeIndexingEvaluation {
       HloInstruction* fusion_operand);
  private:
+  static const int64 kAllowedCodeDuplication;
   // Computes the 'indexing_users_' and 'index_usage_count_' maps based on the
   // current instructions inside the fusion node. Also updates
   // 'total_emitted_instructions_' accordingly.
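
The new root_usage_count parameter is what lets a producer fusion be evaluated as if it were already merged into the consumer: seeding index_usage_count_[fusion] with the number of times the consumer would emit the producer scales every cached count inside the producer by that factor before it is compared against kAllowedCodeDuplication. A small sketch with hypothetical counts (the 4 and 5 are made up for illustration; the calls are the ones declared above):

  // Assume the consumer would emit 'producer' 4 times if the two were fused,
  // i.e. eval_consumer.EvaluateEmittedInstructions(producer) == 4.
  FusionNodeIndexingEvaluation eval_producer(producer, /*root_usage_count=*/4);
  // An instruction inside 'producer' that is emitted 5 times per producer
  // invocation now has a cached usage count of 4 * 5 = 20, which exceeds
  // kAllowedCodeDuplication (15), so the fusion is flagged as inefficient:
  bool too_high = eval_producer.MaxCodeDuplicationTooHigh();  // returns true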


@@ -158,10 +158,10 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/service:elemental_ir_emitter",
+        "//tensorflow/compiler/xla/service:fusion_node_indexing_evaluation",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
         "@llvm-project//llvm:Core",


@@ -18,12 +18,11 @@ limitations under the License.
 #include <algorithm>
 #include <functional>
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Value.h"
 #include "tensorflow/compiler/xla/map_util.h"
 #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
+#include "tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -214,89 +213,15 @@ bool FusedIrEmitter::IsFusedIrEmitterInefficient(
   if (consumer->opcode() != HloOpcode::kFusion) {
     return false;
   }
-  // Collects for each instruction in the fusion node from which (indirect)
-  // users newly created index values are passed. Roughly speaking, we reuse
-  // index values if the shapes are equal when ignoring the element type (we may
-  // reuse also if the shape change is a bitcast, but we don't consider that
-  // here). By ignoring potential reuses our estimate whether the fusion emitter
-  // is inefficient is a bit more conservative than necessary.
-  absl::flat_hash_map<const HloInstruction*,
-                      absl::flat_hash_set<const HloInstruction*>>
-      indexing_users;
-  // Stores the number of different index accesses for each instruction in the
-  // fusion node. The fusion emitter caches access with the same index, so this
-  // value indicates how many times a specific instruction will be emitted.
-  absl::flat_hash_map<const HloInstruction*, int64> index_usage_count;
-  index_usage_count[consumer] = 1;
-  auto evaluate_fusion_computation = [&indexing_users, &index_usage_count](
-                                         const HloInstruction* fusion) {
-    auto postorder =
-        fusion->fused_instructions_computation()->MakeInstructionPostOrder();
-    std::reverse(postorder.begin(), postorder.end());
-    for (const auto* instruction : postorder) {
-      if (instruction->opcode() == HloOpcode::kParameter) {
-        continue;
-      }
-      int64& total = index_usage_count[instruction];
-      if (indexing_users[instruction].empty()) {
-        total = index_usage_count[fusion];
-      } else {
-        total = 0;
-        for (const auto* user : indexing_users[instruction]) {
-          total += index_usage_count[user];
-        }
-      }
-      for (const auto* operand : instruction->operands()) {
-        // For simplicity we assume that all shape and layout changing
-        // operations except Transposes invalidate index reuse. Transposes are
-        // special: although they are shape changing, we can reuse the
-        // multi-dimensional index for the operand by permuting it.
-        if (instruction->opcode() == HloOpcode::kTranspose ||
-            Shape::Equal().IgnoreElementType()(operand->shape(),
-                                               instruction->shape())) {
-          // If the index is reused, it means the operand gets index values
-          // from the same set of (indirect) users as 'instruction' itself.
-          indexing_users[operand].insert(indexing_users[instruction].begin(),
-                                         indexing_users[instruction].end());
-        } else {
-          // If the index is not reused, it means 'instruction' computes a
-          // new index derived from the index it gets.
-          indexing_users[operand].insert(instruction);
-        }
-      }
-    }
-  };
-  evaluate_fusion_computation(consumer);
-  // Also account for the 'producer' if it would be fused. Find the operand it
-  // corresponds to.
-  for (int64 operand_num = 0; operand_num < consumer->operand_count();
-       ++operand_num) {
-    if (consumer->operand(operand_num) == producer) {
-      auto instruction = consumer->fused_parameter(operand_num);
-      int64& total = index_usage_count[producer];
-      total = 0;
-      for (const auto* user : indexing_users[instruction]) {
-        total += index_usage_count[user];
-      }
-      break;
-    }
+  FusionNodeIndexingEvaluation eval_consumer(consumer);
+  if (producer->opcode() != HloOpcode::kFusion) {
+    return eval_consumer.CodeDuplicationTooHigh(producer);
   }
-  // If 'producer' is a fusion node as well, also evaluate it.
-  if (producer->opcode() == HloOpcode::kFusion) {
-    evaluate_fusion_computation(producer);
-  }
-  for (const auto& entry : index_usage_count) {
-    // Check that the code duplication has at most a factor of 15 (where 15 is
-    // an arbitrary constant that seems to work).
-    if (entry.second > 15) {
-      return true;
-    }
-  }
-  return false;
+  // If 'producer' is a fusion node as well, also evaluate it. Pass the
+  // evaluated duplication of the fusion node if it is merged into consumer.
+  FusionNodeIndexingEvaluation eval_producer(
+      producer, eval_consumer.EvaluateEmittedInstructions(producer));
+  return eval_producer.MaxCodeDuplicationTooHigh();
 }
 } // namespace xla