Use a pair of (basic block, index) as the key for the cache.

Previously, we used only the index as the key and then checked whether the
basic block matched. If it didn't match, we overwrote the cache entry. In
pathological cases, this meant we *never* reused a cached entry: even when we
had already computed a value that could have been reused, it had already been
replaced by a value that could not be used. This led to exponential compile
time and *huge* generated LLVM IR.

PiperOrigin-RevId: 338237297
Change-Id: I258202c6238056bfb4466a95c30b6cd48700fe21
Author: Adrian Kuegel, 2020-10-21 04:15:01 -07:00 (committed by TensorFlower Gardener)
parent 1048db0ceb
commit 87c1cf490f
2 changed files with 26 additions and 24 deletions
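
To make the pathological case concrete, here is a minimal standalone sketch, not the XLA code: Block, Index, and std::map are illustrative stand-ins for llvm::BasicBlock, the multidimensional index, and the absl::flat_hash_map cache. With an index-only key, a value emitted in one basic block is clobbered as soon as the same index is emitted in another block, so the next request from the first block misses and the value is regenerated; with a (basic block, index) key, both entries coexist.

// Minimal sketch of the two cache-key schemes. Everything here is
// illustrative: Block stands in for llvm::BasicBlock, Index for the
// multidimensional index, and std::map for the real absl::flat_hash_map.
#include <cstdio>
#include <map>
#include <string>
#include <utility>
#include <vector>

struct Block {};  // stand-in for the insertion basic block
using Index = std::vector<int>;

int main() {
  Block a, b;

  // Old scheme: the index alone is the key. Emitting the same index from a
  // second block overwrites the entry, so a later request from block `a`
  // misses and the value is regenerated (recursively, in the real emitter).
  std::map<Index, std::string> index_only;
  index_only[{0, 1}] = "value emitted in block a";
  index_only[{0, 1}] = "value emitted in block b";  // clobbers a's entry
  std::printf("index-only entries: %zu\n", index_only.size());  // prints 1

  // New scheme: (block, index) is the key. Values generated in different
  // blocks coexist, so lookups from either block keep hitting the cache.
  std::map<std::pair<const Block*, Index>, std::string> per_block;
  per_block[{&a, {0, 1}}] = "value emitted in block a";
  per_block[{&b, {0, 1}}] = "value emitted in block b";
  std::printf("(block, index) entries: %zu\n", per_block.size());  // prints 2
  return 0;
}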


@@ -17,7 +17,6 @@ limitations under the License.
#include <algorithm>
#include <functional>
#include <utility>
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Value.h"
@@ -44,32 +43,37 @@ using llvm_ir::IrArray;
 Status FusedIrEmitter::DefaultAction(const HloInstruction* hlo) {
   indexed_generators_[hlo] =
       [=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
+    auto cache = generated_value_cache_.find(hlo);
+    if (cache != generated_value_cache_.end()) {
+      auto key = std::make_pair(b_->GetInsertBlock(), index.multidim());
+      if (llvm::Value* generated_value =
+              FindOrDefault(cache->second, key, nullptr)) {
+        VLOG(3) << "The cached generated value is reused.";
+        return generated_value;
+      }
+      auto null_key = std::make_pair(nullptr, index.multidim());
+      if (llvm::Value* generated_value =
+              FindOrDefault(cache->second, null_key, nullptr)) {
-    if (llvm::Value* generated_value = FindOrDefault(
-            generated_value_cache_[hlo], index.multidim(), nullptr)) {
-      llvm::BasicBlock* generated_value_bb = nullptr;
-      if (auto* generated_instruction =
-              llvm::dyn_cast<llvm::Instruction>(generated_value)) {
-        generated_value_bb = generated_instruction->getParent();
-      }
-      // Ideally, we should be able to reuse the cached generated value if it
-      // dominates the current insertion block. However, the check for dominance
-      // can be expensive and unreliable when the function is being constructed.
-      //
-      // It's also worth experimenting what if we don't do caching at all.
-      // LLVM's CSE or GVN should be able to easily merge common subexpressions
-      // that would be regenerated without caching. But this might increase the
-      // JIT compilation time.
-      if (generated_value_bb == nullptr ||
-          generated_value_bb == b_->GetInsertBlock()) {
         VLOG(3) << "The cached generated value is reused.";
         return generated_value;
       }
-      VLOG(3) << "The cached generated value can't be reused, because it is in "
-                 "a different BB ("
-              << generated_value_bb->getName().str()
-              << ") from the current insertion block ("
-              << b_->GetInsertBlock()->getName().str() << ").";
     }
     TF_ASSIGN_OR_RETURN(llvm::Value* const generated_value,
                         elemental_emitter_->MakeElementGenerator(
                             hlo, indexed_generators_)(index));
+    llvm::BasicBlock* generated_value_bb = nullptr;
+    if (auto* generated_instruction =
+            llvm::dyn_cast<llvm::Instruction>(generated_value)) {
+      generated_value_bb = generated_instruction->getParent();
+    }
+    generated_value_cache_[hlo][std::make_pair(
+        generated_value_bb, index.multidim())] = generated_value;
-    generated_value_cache_[hlo][index.multidim()] = generated_value;
     return generated_value;
   };
   return Status::OK();


@@ -18,7 +18,6 @@ limitations under the License.
#include <map>
#include <unordered_map>
#include <utility>
#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
@@ -154,10 +153,9 @@ class FusedIrEmitter : public ConstDfsHloVisitorWithDefault {
   // Cache of generated values, lest we regenerate an element of a node with
   // multiple outgoing edges
+  absl::flat_hash_map<const HloInstruction*,
+                      absl::flat_hash_map<std::pair<const llvm::BasicBlock*,
+                                                    std::vector<llvm::Value*>>,
+                                          llvm::Value*>>
-  absl::flat_hash_map<
-      const HloInstruction*,
-      absl::flat_hash_map<std::vector<llvm::Value*>, llvm::Value*>>
       generated_value_cache_;
 };
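
A side observation on the new key type, not part of the change itself: absl::flat_hash_map accepts std::pair<const llvm::BasicBlock*, std::vector<llvm::Value*>> with its default hasher, because absl::Hash already composes over std::pair, std::vector, and raw pointers, so the diff needs no custom hash functor. A small sketch with stand-in Block/Value types:

// Sketch only: Block/Value stand in for llvm::BasicBlock/llvm::Value. The
// point is that the pair-of-(pointer, vector) key hashes out of the box.
#include <cstdio>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"

struct Block {};
struct Value {};

int main() {
  using Key = std::pair<const Block*, std::vector<Value*>>;
  absl::flat_hash_map<Key, Value*> cache;  // default absl::Hash handles Key

  Block block;
  Value value;
  Key key{&block, std::vector<Value*>{&value}};
  cache[key] = &value;
  std::printf("cache entries: %zu\n", cache.size());  // prints 1
  return 0;
}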