Refactor the code to avoid duplication (NFC).
IsFusedIrEmitterInefficient can reuse the code from FusionNodeIndexingEvaluation.

PiperOrigin-RevId: 334791886
Change-Id: I8bd812913355133bfcc0ea1f85792f47c550fb1c
This commit is contained in:
parent
94e163651a
commit
7ea86e9de8
@ -25,24 +25,33 @@ limitations under the License.
|
|||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
FusionNodeIndexingEvaluation::FusionNodeIndexingEvaluation(
|
FusionNodeIndexingEvaluation::FusionNodeIndexingEvaluation(
|
||||||
const HloInstruction* fusion)
|
const HloInstruction* fusion, int64 root_usage_count)
|
||||||
: fusion_(fusion) {
|
: fusion_(fusion) {
|
||||||
HloInstruction* root = fusion->fused_expression_root();
|
HloInstruction* root = fusion->fused_expression_root();
|
||||||
indexing_users_[root].insert(fusion);
|
indexing_users_[root].insert(fusion);
|
||||||
index_usage_count_[fusion] = 1;
|
index_usage_count_[fusion] = root_usage_count;
|
||||||
RecomputeCache();
|
RecomputeCache();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This constant is arbitrarily chosen. Essentially we don't want to have too
|
||||||
|
// much code duplication, because it slows down the compilation time. There is
|
||||||
|
// a tradeoff between compilation time and runtime here.
|
||||||
|
const int64 FusionNodeIndexingEvaluation::kAllowedCodeDuplication = 15;
|
||||||
|
|
||||||
bool FusionNodeIndexingEvaluation::CodeDuplicationTooHigh(
|
bool FusionNodeIndexingEvaluation::CodeDuplicationTooHigh(
|
||||||
const HloInstruction* producer) const {
|
const HloInstruction* producer) const {
|
||||||
// This constant is arbitrarily chosen. Essentially we don't want to have too
|
|
||||||
// much code duplication, because it slows down the compilation time. There is
|
|
||||||
// a tradeoff between compilation time and runtime here.
|
|
||||||
const int64 kAllowedCodeDuplication = 15;
|
|
||||||
|
|
||||||
return EvaluateEmittedInstructions(producer) > kAllowedCodeDuplication;
|
return EvaluateEmittedInstructions(producer) > kAllowedCodeDuplication;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool FusionNodeIndexingEvaluation::MaxCodeDuplicationTooHigh() const {
|
||||||
|
for (const auto& entry : index_usage_count_) {
|
||||||
|
if (entry.second > kAllowedCodeDuplication) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
int64 FusionNodeIndexingEvaluation::EvaluateEmittedInstructions(
|
int64 FusionNodeIndexingEvaluation::EvaluateEmittedInstructions(
|
||||||
const HloInstruction* producer) const {
|
const HloInstruction* producer) const {
|
||||||
int64 total = 0;
|
int64 total = 0;
|
||||||
|
|||||||
@ -24,13 +24,19 @@ limitations under the License.
|
|||||||
namespace xla {
|
namespace xla {
|
||||||
class FusionNodeIndexingEvaluation {
|
class FusionNodeIndexingEvaluation {
|
||||||
public:
|
public:
|
||||||
explicit FusionNodeIndexingEvaluation(const HloInstruction* fusion);
|
explicit FusionNodeIndexingEvaluation(const HloInstruction* fusion,
|
||||||
|
int64 root_usage_count = 1);
|
||||||
|
|
||||||
// Evaluate the number of times 'producer' would be emitted if it is fused
|
// Evaluate the number of times 'producer' would be emitted if it is fused
|
||||||
// into 'fusion_'. If the duplication is "too high" (some arbitrary chosen
|
// into 'fusion_'. If the duplication is "too high" (some arbitrary chosen
|
||||||
// constant), returns true.
|
// constant), returns true.
|
||||||
bool CodeDuplicationTooHigh(const HloInstruction* producer) const;
|
bool CodeDuplicationTooHigh(const HloInstruction* producer) const;
|
||||||
|
|
||||||
|
// Evaluate the maximum code duplication inside the fusion node. If the
|
||||||
|
// maximum code duplication is "too high" (some arbitrary chosen constant),
|
||||||
|
// returns true.
|
||||||
|
bool MaxCodeDuplicationTooHigh() const;
|
||||||
|
|
||||||
// Evaluate the number of times 'producer' would be emitted if it is fused
|
// Evaluate the number of times 'producer' would be emitted if it is fused
|
||||||
// into 'fusion_'.
|
// into 'fusion_'.
|
||||||
int64 EvaluateEmittedInstructions(const HloInstruction* producer) const;
|
int64 EvaluateEmittedInstructions(const HloInstruction* producer) const;
|
||||||
@ -53,6 +59,8 @@ class FusionNodeIndexingEvaluation {
|
|||||||
HloInstruction* fusion_operand);
|
HloInstruction* fusion_operand);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
static const int64 kAllowedCodeDuplication;
|
||||||
|
|
||||||
// Computes the 'indexing_users_' and 'index_usage_count_' maps based on the
|
// Computes the 'indexing_users_' and 'index_usage_count_' maps based on the
|
||||||
// current instructions inside the fusion node. Also updates
|
// current instructions inside the fusion node. Also updates
|
||||||
// 'total_emitted_instructions_' accordingly.
|
// 'total_emitted_instructions_' accordingly.
|
||||||
|
|||||||
@ -158,10 +158,10 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||||
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
|
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
|
||||||
|
"//tensorflow/compiler/xla/service:fusion_node_indexing_evaluation",
|
||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/container:flat_hash_set",
|
|
||||||
"@com_google_absl//absl/types:optional",
|
"@com_google_absl//absl/types:optional",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
"@llvm-project//llvm:Core",
|
"@llvm-project//llvm:Core",
|
||||||
|
|||||||
@ -18,12 +18,11 @@ limitations under the License.
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
|
||||||
#include "absl/container/flat_hash_map.h"
|
|
||||||
#include "absl/container/flat_hash_set.h"
|
|
||||||
#include "llvm/IR/BasicBlock.h"
|
#include "llvm/IR/BasicBlock.h"
|
||||||
#include "llvm/IR/Value.h"
|
#include "llvm/IR/Value.h"
|
||||||
#include "tensorflow/compiler/xla/map_util.h"
|
#include "tensorflow/compiler/xla/map_util.h"
|
||||||
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
|
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
@ -214,89 +213,15 @@ bool FusedIrEmitter::IsFusedIrEmitterInefficient(
|
|||||||
if (consumer->opcode() != HloOpcode::kFusion) {
|
if (consumer->opcode() != HloOpcode::kFusion) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
// Collects for each instruction in the fusion node from which (indirect)
|
FusionNodeIndexingEvaluation eval_consumer(consumer);
|
||||||
// users newly created index values are passed. Roughly speaking, we reuse
|
if (producer->opcode() != HloOpcode::kFusion) {
|
||||||
// index values if the shapes are equal when ignoring the element type (we may
|
return eval_consumer.CodeDuplicationTooHigh(producer);
|
||||||
// reuse also if the shape change is a bitcast, but we don't consider that
|
|
||||||
// here). By ignoring potential reuses our estimate whether the fusion emitter
|
|
||||||
// is inefficient is a bit more conservative than necessary.
|
|
||||||
absl::flat_hash_map<const HloInstruction*,
|
|
||||||
absl::flat_hash_set<const HloInstruction*>>
|
|
||||||
indexing_users;
|
|
||||||
// Stores the number of different index accesses for each instruction in the
|
|
||||||
// fusion node. The fusion emitter caches access with the same index, so this
|
|
||||||
// value indicates how many times a specific instruction will be emitted.
|
|
||||||
absl::flat_hash_map<const HloInstruction*, int64> index_usage_count;
|
|
||||||
index_usage_count[consumer] = 1;
|
|
||||||
|
|
||||||
auto evaluate_fusion_computation = [&indexing_users, &index_usage_count](
|
|
||||||
const HloInstruction* fusion) {
|
|
||||||
auto postorder =
|
|
||||||
fusion->fused_instructions_computation()->MakeInstructionPostOrder();
|
|
||||||
std::reverse(postorder.begin(), postorder.end());
|
|
||||||
for (const auto* instruction : postorder) {
|
|
||||||
if (instruction->opcode() == HloOpcode::kParameter) {
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
int64& total = index_usage_count[instruction];
|
// If 'producer' is a fusion node as well, also evaluate it. Pass the
|
||||||
if (indexing_users[instruction].empty()) {
|
// evaluated duplication of the fusion node if it is merged into consumer.
|
||||||
total = index_usage_count[fusion];
|
FusionNodeIndexingEvaluation eval_producer(
|
||||||
} else {
|
producer, eval_consumer.EvaluateEmittedInstructions(producer));
|
||||||
total = 0;
|
return eval_producer.MaxCodeDuplicationTooHigh();
|
||||||
for (const auto* user : indexing_users[instruction]) {
|
|
||||||
total += index_usage_count[user];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (const auto* operand : instruction->operands()) {
|
|
||||||
// For simplicity we assume that all shape and layout changing
|
|
||||||
// operations except Transposes invalidate index reuse. Transposes are
|
|
||||||
// special: although they are shape changing, we can reuse the
|
|
||||||
// multi-dimensional index for the operand by permuting it.
|
|
||||||
if (instruction->opcode() == HloOpcode::kTranspose ||
|
|
||||||
Shape::Equal().IgnoreElementType()(operand->shape(),
|
|
||||||
instruction->shape())) {
|
|
||||||
// If the index is reused, it means the operand gets index values
|
|
||||||
// from the same set of (indirect) users as 'instruction' itself.
|
|
||||||
indexing_users[operand].insert(indexing_users[instruction].begin(),
|
|
||||||
indexing_users[instruction].end());
|
|
||||||
} else {
|
|
||||||
// If the index is not reused, it means 'instruction' computes a
|
|
||||||
// new index derived from the index it gets.
|
|
||||||
indexing_users[operand].insert(instruction);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
evaluate_fusion_computation(consumer);
|
|
||||||
|
|
||||||
// Also account for the 'producer' if it would be fused. Find the operand it
|
|
||||||
// corresponds to.
|
|
||||||
for (int64 operand_num = 0; operand_num < consumer->operand_count();
|
|
||||||
++operand_num) {
|
|
||||||
if (consumer->operand(operand_num) == producer) {
|
|
||||||
auto instruction = consumer->fused_parameter(operand_num);
|
|
||||||
int64& total = index_usage_count[producer];
|
|
||||||
total = 0;
|
|
||||||
for (const auto* user : indexing_users[instruction]) {
|
|
||||||
total += index_usage_count[user];
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If 'producer' is a fusion node as well, also evaluate it.
|
|
||||||
if (producer->opcode() == HloOpcode::kFusion) {
|
|
||||||
evaluate_fusion_computation(producer);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (const auto& entry : index_usage_count) {
|
|
||||||
// Check that the code duplication has at most a factor of 15 (where 15 is
|
|
||||||
// an arbitrary constant that seems to work).
|
|
||||||
if (entry.second > 15) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user