diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index 7fbd01e1b21..0371ce71874 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -43,9 +44,8 @@ using llvm_ir::IrArray; Status FusedIrEmitter::DefaultAction(const HloInstruction* hlo) { indexed_generators_[hlo] = [=](const IrArray::Index& index) -> StatusOr { - if (generated_value_cache_[hlo].contains(index.multidim())) { - llvm::Value* generated_value = - generated_value_cache_[hlo][index.multidim()]; + if (llvm::Value* generated_value = FindOrDefault( + generated_value_cache_[hlo], index.multidim(), nullptr)) { llvm::BasicBlock* generated_value_bb = nullptr; if (auto* generated_instruction = llvm::dyn_cast(generated_value)) { @@ -71,10 +71,11 @@ Status FusedIrEmitter::DefaultAction(const HloInstruction* hlo) { << b_->GetInsertBlock()->getName().str() << ")."; } - TF_ASSIGN_OR_RETURN(generated_value_cache_[hlo][index.multidim()], + TF_ASSIGN_OR_RETURN(llvm::Value* const generated_value, elemental_emitter_->MakeElementGenerator( hlo, indexed_generators_)(index)); - return generated_value_cache_[hlo][index.multidim()]; + generated_value_cache_[hlo][index.multidim()] = generated_value; + return generated_value; }; return Status::OK(); }