[XLA] Reduce number of lookups into nested hash table

The nested lookup is actually expensive, and there is no reason to do it 4
times when 2 times is sufficient.

PiperOrigin-RevId: 307379544
Change-Id: I1f02bcc3d014e832eee9f817d70ea64f47bd02bd
This commit is contained in:
Benjamin Kramer 2020-04-20 05:19:39 -07:00 committed by TensorFlower Gardener
parent 8630cd9742
commit ccd260cd82

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "absl/container/flat_hash_set.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@ -43,9 +44,8 @@ using llvm_ir::IrArray;
Status FusedIrEmitter::DefaultAction(const HloInstruction* hlo) {
indexed_generators_[hlo] =
[=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
if (generated_value_cache_[hlo].contains(index.multidim())) {
llvm::Value* generated_value =
generated_value_cache_[hlo][index.multidim()];
if (llvm::Value* generated_value = FindOrDefault(
generated_value_cache_[hlo], index.multidim(), nullptr)) {
llvm::BasicBlock* generated_value_bb = nullptr;
if (auto* generated_instruction =
llvm::dyn_cast<llvm::Instruction>(generated_value)) {
@ -71,10 +71,11 @@ Status FusedIrEmitter::DefaultAction(const HloInstruction* hlo) {
<< b_->GetInsertBlock()->getName().str() << ").";
}
TF_ASSIGN_OR_RETURN(generated_value_cache_[hlo][index.multidim()],
TF_ASSIGN_OR_RETURN(llvm::Value* const generated_value,
elemental_emitter_->MakeElementGenerator(
hlo, indexed_generators_)(index));
return generated_value_cache_[hlo][index.multidim()];
generated_value_cache_[hlo][index.multidim()] = generated_value;
return generated_value;
};
return Status::OK();
}