From 7ea86e9de8a5b6426e7291e0e5477ddaee83ba88 Mon Sep 17 00:00:00 2001
From: Adrian Kuegel
Date: Thu, 1 Oct 2020 05:12:01 -0700
Subject: [PATCH] Refactor the code to avoid duplication (NFC).

IsFusedIrEmitterInefficient can reuse the code from FusionNodeIndexingEvaluation.

PiperOrigin-RevId: 334791886
Change-Id: I8bd812913355133bfcc0ea1f85792f47c550fb1c
---
 .../fusion_node_indexing_evaluation.cc        | 23 +++--
 .../service/fusion_node_indexing_evaluation.h | 10 +-
 tensorflow/compiler/xla/service/llvm_ir/BUILD |  2 +-
 .../xla/service/llvm_ir/fused_ir_emitter.cc   | 93 ++-----------------
 4 files changed, 35 insertions(+), 93 deletions(-)

diff --git a/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc b/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc
index ab6a3d01d21..17d3fb2b3d6 100644
--- a/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc
+++ b/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc
@@ -25,24 +25,33 @@ limitations under the License.
 namespace xla {
 
 FusionNodeIndexingEvaluation::FusionNodeIndexingEvaluation(
-    const HloInstruction* fusion)
+    const HloInstruction* fusion, int64 root_usage_count)
     : fusion_(fusion) {
   HloInstruction* root = fusion->fused_expression_root();
   indexing_users_[root].insert(fusion);
-  index_usage_count_[fusion] = 1;
+  index_usage_count_[fusion] = root_usage_count;
   RecomputeCache();
 }
 
+// This constant is arbitrarily chosen. Essentially we don't want to have too
+// much code duplication, because it slows down the compilation time. There is
+// a tradeoff between compilation time and runtime here.
+const int64 FusionNodeIndexingEvaluation::kAllowedCodeDuplication = 15;
+
 bool FusionNodeIndexingEvaluation::CodeDuplicationTooHigh(
     const HloInstruction* producer) const {
-  // This constant is arbitrarily chosen. Essentially we don't want to have too
-  // much code duplication, because it slows down the compilation time. There is
-  // a tradeoff between compilation time and runtime here.
-  const int64 kAllowedCodeDuplication = 15;
-
   return EvaluateEmittedInstructions(producer) > kAllowedCodeDuplication;
 }
 
+bool FusionNodeIndexingEvaluation::MaxCodeDuplicationTooHigh() const {
+  for (const auto& entry : index_usage_count_) {
+    if (entry.second > kAllowedCodeDuplication) {
+      return true;
+    }
+  }
+  return false;
+}
+
 int64 FusionNodeIndexingEvaluation::EvaluateEmittedInstructions(
     const HloInstruction* producer) const {
   int64 total = 0;
diff --git a/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h b/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h
index b85bf9104c7..abe154a5149 100644
--- a/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h
+++ b/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h
@@ -24,13 +24,19 @@ limitations under the License.
 namespace xla {
 class FusionNodeIndexingEvaluation {
  public:
-  explicit FusionNodeIndexingEvaluation(const HloInstruction* fusion);
+  explicit FusionNodeIndexingEvaluation(const HloInstruction* fusion,
+                                        int64 root_usage_count = 1);
 
   // Evaluate the number of times 'producer' would be emitted if it is fused
   // into 'fusion_'. If the duplication is "too high" (some arbitrary chosen
   // constant), returns true.
   bool CodeDuplicationTooHigh(const HloInstruction* producer) const;
 
+  // Evaluate the maximum code duplication inside the fusion node. If the
+  // maximum code duplication is "too high" (some arbitrary chosen constant),
+  // returns true.
+  bool MaxCodeDuplicationTooHigh() const;
+
   // Evaluate the number of times 'producer' would be emitted if it is fused
   // into 'fusion_'.
   int64 EvaluateEmittedInstructions(const HloInstruction* producer) const;
@@ -53,6 +59,8 @@ class FusionNodeIndexingEvaluation {
       HloInstruction* fusion_operand);
 
  private:
+  static const int64 kAllowedCodeDuplication;
+
   // Computes the 'indexing_users_' and 'index_usage_count_' maps based on the
   // current instructions inside the fusion node. Also updates
   // 'total_emitted_instructions_' accordingly.
diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD
index 59f4466980f..9940b032558 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/BUILD
+++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD
@@ -158,10 +158,10 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/service:elemental_ir_emitter",
+        "//tensorflow/compiler/xla/service:fusion_node_indexing_evaluation",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/core:lib",
        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Core",
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
index fbffacc3b26..164c8f7e1c8 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
@@ -18,12 +18,11 @@ limitations under the License.
 #include <algorithm>
 #include <functional>
 
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Value.h"
 #include "tensorflow/compiler/xla/map_util.h"
 #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
+#include "tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -214,89 +213,15 @@ bool FusedIrEmitter::IsFusedIrEmitterInefficient(
   if (consumer->opcode() != HloOpcode::kFusion) {
     return false;
   }
-  // Collects for each instruction in the fusion node from which (indirect)
-  // users newly created index values are passed. Roughly speaking, we reuse
-  // index values if the shapes are equal when ignoring the element type (we may
-  // reuse also if the shape change is a bitcast, but we don't consider that
-  // here). By ignoring potential reuses our estimate whether the fusion emitter
-  // is inefficient is a bit more conservative than necessary.
-  absl::flat_hash_map<const HloInstruction*,
-                      absl::flat_hash_set<const HloInstruction*>>
-      indexing_users;
-  // Stores the number of different index accesses for each instruction in the
-  // fusion node. The fusion emitter caches access with the same index, so this
-  // value indicates how many times a specific instruction will be emitted.
-  absl::flat_hash_map<const HloInstruction*, int64> index_usage_count;
-  index_usage_count[consumer] = 1;
-
-  auto evaluate_fusion_computation = [&indexing_users, &index_usage_count](
-                                         const HloInstruction* fusion) {
-    auto postorder =
-        fusion->fused_instructions_computation()->MakeInstructionPostOrder();
-    std::reverse(postorder.begin(), postorder.end());
-    for (const auto* instruction : postorder) {
-      if (instruction->opcode() == HloOpcode::kParameter) {
-        continue;
-      }
-      int64& total = index_usage_count[instruction];
-      if (indexing_users[instruction].empty()) {
-        total = index_usage_count[fusion];
-      } else {
-        total = 0;
-        for (const auto* user : indexing_users[instruction]) {
-          total += index_usage_count[user];
-        }
-      }
-      for (const auto* operand : instruction->operands()) {
-        // For simplicity we assume that all shape and layout changing
-        // operations except Transposes invalidate index reuse. Transposes are
-        // special: although they are shape changing, we can reuse the
-        // multi-dimensional index for the operand by permuting it.
-        if (instruction->opcode() == HloOpcode::kTranspose ||
-            Shape::Equal().IgnoreElementType()(operand->shape(),
-                                               instruction->shape())) {
-          // If the index is reused, it means the operand gets index values
-          // from the same set of (indirect) users as 'instruction' itself.
-          indexing_users[operand].insert(indexing_users[instruction].begin(),
-                                         indexing_users[instruction].end());
-        } else {
-          // If the index is not reused, it means 'instruction' computes a
-          // new index derived from the index it gets.
-          indexing_users[operand].insert(instruction);
-        }
-      }
-    }
-  };
-  evaluate_fusion_computation(consumer);
-
-  // Also account for the 'producer' if it would be fused. Find the operand it
-  // corresponds to.
-  for (int64 operand_num = 0; operand_num < consumer->operand_count();
-       ++operand_num) {
-    if (consumer->operand(operand_num) == producer) {
-      auto instruction = consumer->fused_parameter(operand_num);
-      int64& total = index_usage_count[producer];
-      total = 0;
-      for (const auto* user : indexing_users[instruction]) {
-        total += index_usage_count[user];
-      }
-      break;
-    }
+  FusionNodeIndexingEvaluation eval_consumer(consumer);
+  if (producer->opcode() != HloOpcode::kFusion) {
+    return eval_consumer.CodeDuplicationTooHigh(producer);
   }
-
-  // If 'producer' is a fusion node as well, also evaluate it.
-  if (producer->opcode() == HloOpcode::kFusion) {
-    evaluate_fusion_computation(producer);
-  }
-
-  for (const auto& entry : index_usage_count) {
-    // Check that the code duplication has at most a factor of 15 (where 15 is
-    // an arbitrary constant that seems to work).
-    if (entry.second > 15) {
-      return true;
-    }
-  }
-  return false;
+  // If 'producer' is a fusion node as well, also evaluate it. Pass the
+  // evaluated duplication of the fusion node if it is merged into consumer.
+  FusionNodeIndexingEvaluation eval_producer(
+      producer, eval_consumer.EvaluateEmittedInstructions(producer));
+  return eval_producer.MaxCodeDuplicationTooHigh();
 }
 
 }  // namespace xla
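
For readers unfamiliar with FusionNodeIndexingEvaluation, the sketch below is a standalone toy program (not XLA code and not part of the patch) of the estimate that both the old inline code and the refactored class compute: walking the fused expression in reverse post-order, a node's emit count is the sum of the emit counts of the users that pass it a distinct index, and an operand whose index can be reused inherits its user's indexing users instead. Node, the string shapes, EstimateEmitCounts, and the sample graph are illustrative stand-ins for HloInstruction, Shape, and the real class; the transpose special case and the bitcast caveat mentioned in the patch are omitted, and the threshold of 15 mirrors kAllowedCodeDuplication.

// Illustrative sketch only; not part of the patch or of XLA.
#include <cstdint>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

// Toy stand-in for HloInstruction: a name, a shape (as a string), and the
// operands it reads.
struct Node {
  std::string name;
  std::string shape;
  std::vector<const Node*> operands;
};

// Estimates how many times each node would be emitted. `post_order` must list
// operands before users (root last); `fusion` is a sentinel for the enclosing
// fusion instruction, which indexes the root `root_usage_count` times.
std::map<const Node*, int64_t> EstimateEmitCounts(
    const std::vector<const Node*>& post_order, const Node* fusion,
    int64_t root_usage_count) {
  std::map<const Node*, std::set<const Node*>> indexing_users;
  std::map<const Node*, int64_t> emit_count;
  indexing_users[post_order.back()].insert(fusion);
  emit_count[fusion] = root_usage_count;
  // Visit users before operands by walking the post-order in reverse.
  for (auto it = post_order.rbegin(); it != post_order.rend(); ++it) {
    const Node* node = *it;
    int64_t total = 0;
    for (const Node* user : indexing_users[node]) {
      total += emit_count[user];
    }
    emit_count[node] = total;
    for (const Node* operand : node->operands) {
      if (operand->shape == node->shape) {
        // Index reuse: the operand is read with the same indices as `node`,
        // so it inherits `node`'s indexing users.
        indexing_users[operand].insert(indexing_users[node].begin(),
                                       indexing_users[node].end());
      } else {
        // Otherwise `node` computes a fresh index for this operand.
        indexing_users[operand].insert(node);
      }
    }
  }
  return emit_count;
}

int main() {
  // root = tuple(reshape_a(param), reshape_b(param)). Both reshapes change
  // the shape, so 'param' is read with two different indices and its
  // estimated emit count is 2; everything else is emitted once.
  Node param{"param", "f32[16]", {}};
  Node reshape_a{"reshape_a", "f32[4,4]", {&param}};
  Node reshape_b{"reshape_b", "f32[2,8]", {&param}};
  Node root{"root", "(f32[4,4], f32[2,8])", {&reshape_a, &reshape_b}};
  Node fusion{"fusion", "(f32[4,4], f32[2,8])", {}};

  std::vector<const Node*> post_order = {&param, &reshape_a, &reshape_b,
                                         &root};
  std::map<const Node*, int64_t> counts =
      EstimateEmitCounts(post_order, &fusion, /*root_usage_count=*/1);

  // Same spirit as kAllowedCodeDuplication in the patch.
  const int64_t kAllowedCodeDuplication = 15;
  for (const Node* node : post_order) {
    std::cout << node->name << ": emitted " << counts[node] << "x"
              << (counts[node] > kAllowedCodeDuplication ? " (too high)" : "")
              << "\n";
  }
  return 0;
}

Running the toy prints an emit count of 2 for param and 1 for the other nodes; that per-node count is exactly the quantity the patched code compares against the duplication threshold.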