diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
index f8514a6cba3..164c8f7e1c8 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
@@ -17,7 +17,6 @@ limitations under the License.
 
 #include
 #include
-#include
 
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Value.h"
@@ -44,32 +43,37 @@ using llvm_ir::IrArray;
 Status FusedIrEmitter::DefaultAction(const HloInstruction* hlo) {
   indexed_generators_[hlo] =
       [=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
-    auto cache = generated_value_cache_.find(hlo);
-    if (cache != generated_value_cache_.end()) {
-      auto key = std::make_pair(b_->GetInsertBlock(), index.multidim());
-      if (llvm::Value* generated_value =
-              FindOrDefault(cache->second, key, nullptr)) {
-        VLOG(3) << "The cached generated value is reused.";
-        return generated_value;
-      }
-      auto null_key = std::make_pair(nullptr, index.multidim());
-      if (llvm::Value* generated_value =
-              FindOrDefault(cache->second, null_key, nullptr)) {
+    if (llvm::Value* generated_value = FindOrDefault(
+            generated_value_cache_[hlo], index.multidim(), nullptr)) {
+      llvm::BasicBlock* generated_value_bb = nullptr;
+      if (auto* generated_instruction =
+              llvm::dyn_cast<llvm::Instruction>(generated_value)) {
+        generated_value_bb = generated_instruction->getParent();
+      }
+      // Ideally, we should be able to reuse the cached generated value if it
+      // dominates the current insertion block. However, the check for dominance
+      // can be expensive and unreliable when the function is being constructed.
+      //
+      // It's also worth experimenting with not doing caching at all.
+      // LLVM's CSE or GVN should be able to easily merge common subexpressions
+      // that would be regenerated without caching. But this might increase the
+      // JIT compilation time.
+      if (generated_value_bb == nullptr ||
+          generated_value_bb == b_->GetInsertBlock()) {
         VLOG(3) << "The cached generated value is reused.";
         return generated_value;
       }
+      VLOG(3) << "The cached generated value can't be reused, because it is in "
+                 "a different BB ("
+              << generated_value_bb->getName().str()
+              << ") from the current insertion block ("
+              << b_->GetInsertBlock()->getName().str() << ").";
     }
 
     TF_ASSIGN_OR_RETURN(llvm::Value* const generated_value,
                         elemental_emitter_->MakeElementGenerator(
                             hlo, indexed_generators_)(index));
-    llvm::BasicBlock* generated_value_bb = nullptr;
-    if (auto* generated_instruction =
-            llvm::dyn_cast<llvm::Instruction>(generated_value)) {
-      generated_value_bb = generated_instruction->getParent();
-    }
-    generated_value_cache_[hlo][std::make_pair(
-        generated_value_bb, index.multidim())] = generated_value;
+    generated_value_cache_[hlo][index.multidim()] = generated_value;
     return generated_value;
   };
   return Status::OK();
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
index e19e970cb24..d13b0262180 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
@@ -18,7 +18,6 @@ limitations under the License.
 
 #include
 #include
-#include
 
 #include "absl/container/flat_hash_map.h"
 #include "absl/types/optional.h"
@@ -154,10 +153,9 @@ class FusedIrEmitter : public ConstDfsHloVisitorWithDefault {
 
   // Cache of generated values, lest we regenerate an element of a node with
   // multiple outgoing edges
-  absl::flat_hash_map<const HloInstruction*,
-                      absl::flat_hash_map<std::pair<llvm::BasicBlock*,
-                                                    std::vector<llvm::Value*>>,
-                                          llvm::Value*>>
+  absl::flat_hash_map<
+      const HloInstruction*,
+      absl::flat_hash_map<std::vector<llvm::Value*>, llvm::Value*>>
       generated_value_cache_;
 };
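
For reference, below is a minimal standalone sketch of the reuse policy the new code implements: the cache is keyed on the multidim index alone, and a cache hit is honoured only when the cached value either has no parent basic block (e.g. a constant) or was emitted in the current insertion block. The names and types in the sketch (Block, Value, CachingEmitter, GetOrGenerate) are simplified stand-ins invented for illustration, not the XLA or LLVM APIs; Block plays the role of llvm::BasicBlock, Value of llvm::Value, and Index of IrArray::Index::multidim().

// Standalone sketch (plain C++, no LLVM/XLA dependencies) of the caching
// policy after this patch. All names here are illustrative stand-ins.
#include <deque>
#include <iostream>
#include <map>
#include <vector>
#include <string>

struct Block { std::string name; };       // stands in for llvm::BasicBlock

struct Value {                            // stands in for llvm::Value
  const Block* parent;  // nullptr models a non-Instruction value (constant)
  int id;
};

using Index = std::vector<int>;           // stands in for multidim index

class CachingEmitter {
 public:
  // Per-index cache; the real code keeps one such map per HloInstruction.
  const Value* GetOrGenerate(const Index& index, const Block* insert_block) {
    auto it = cache_.find(index);
    if (it != cache_.end()) {
      const Value* cached = it->second;
      // Reuse only when the cached value has no parent block or was emitted
      // in the current insertion block; otherwise it might not dominate the
      // point we are emitting into, so regenerate instead.
      if (cached->parent == nullptr || cached->parent == insert_block) {
        std::cout << "reusing cached value " << cached->id << "\n";
        return cached;
      }
      std::cout << "cached value " << cached->id << " is in "
                << cached->parent->name << ", not reusable in "
                << insert_block->name << "\n";
    }
    values_.push_back(Value{insert_block, next_id_++});
    cache_[index] = &values_.back();  // key is the index alone, not (BB, index)
    return &values_.back();
  }

 private:
  std::deque<Value> values_;  // deque keeps element pointers stable
  std::map<Index, const Value*> cache_;
  int next_id_ = 0;
};

int main() {
  Block a{"entry"}, b{"loop_body"};
  CachingEmitter emitter;
  emitter.GetOrGenerate({0, 1}, &a);  // generated in "entry"
  emitter.GetOrGenerate({0, 1}, &a);  // cache hit, same block: reused
  emitter.GetOrGenerate({0, 1}, &b);  // cache hit, different block: regenerated
}

Comparing the cached value's parent block against the current insertion block is a deliberately conservative substitute for a real dominance query, which, as the comment in the patch notes, is expensive and unreliable while the function is still being constructed.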