Use a pair of (building block, index) as key for the cache.
Previously, we only used index as key, and then checked whether the building block matches. If it didn't match, we overwrote the cache entry. In pathological cases, this means we *never* used a cached entry, because even in cases we had already computed a value that could be used, it was already replaced by a value which could not be used. This lead to exponential compile time and *huge* generated LLVM IR. PiperOrigin-RevId: 338237297 Change-Id: I258202c6238056bfb4466a95c30b6cd48700fe21
This commit is contained in:
parent
1048db0ceb
commit
87c1cf490f
@ -17,7 +17,6 @@ limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
|
||||
#include "llvm/IR/BasicBlock.h"
|
||||
#include "llvm/IR/Value.h"
|
||||
@ -44,32 +43,37 @@ using llvm_ir::IrArray;
|
||||
Status FusedIrEmitter::DefaultAction(const HloInstruction* hlo) {
|
||||
indexed_generators_[hlo] =
|
||||
[=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
|
||||
auto cache = generated_value_cache_.find(hlo);
|
||||
if (cache != generated_value_cache_.end()) {
|
||||
auto key = std::make_pair(b_->GetInsertBlock(), index.multidim());
|
||||
if (llvm::Value* generated_value =
|
||||
FindOrDefault(cache->second, key, nullptr)) {
|
||||
VLOG(3) << "The cached generated value is reused.";
|
||||
return generated_value;
|
||||
}
|
||||
auto null_key = std::make_pair(nullptr, index.multidim());
|
||||
if (llvm::Value* generated_value =
|
||||
FindOrDefault(cache->second, null_key, nullptr)) {
|
||||
if (llvm::Value* generated_value = FindOrDefault(
|
||||
generated_value_cache_[hlo], index.multidim(), nullptr)) {
|
||||
llvm::BasicBlock* generated_value_bb = nullptr;
|
||||
if (auto* generated_instruction =
|
||||
llvm::dyn_cast<llvm::Instruction>(generated_value)) {
|
||||
generated_value_bb = generated_instruction->getParent();
|
||||
}
|
||||
// Ideally, we should be able to reuse the cached generated value if it
|
||||
// dominates the current insertion block. However, the check for dominance
|
||||
// can be expensive and unreliable when the function is being constructed.
|
||||
//
|
||||
// It's also worth experimenting what if we don't do caching at all.
|
||||
// LLVM's CSE or GVN should be able to easily merge common subexpressions
|
||||
// that would be regenerated without caching. But this might increase the
|
||||
// JIT compilation time.
|
||||
if (generated_value_bb == nullptr ||
|
||||
generated_value_bb == b_->GetInsertBlock()) {
|
||||
VLOG(3) << "The cached generated value is reused.";
|
||||
return generated_value;
|
||||
}
|
||||
VLOG(3) << "The cached generated value can't be reused, because it is in "
|
||||
"a different BB ("
|
||||
<< generated_value_bb->getName().str()
|
||||
<< ") from the current insertion block ("
|
||||
<< b_->GetInsertBlock()->getName().str() << ").";
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(llvm::Value* const generated_value,
|
||||
elemental_emitter_->MakeElementGenerator(
|
||||
hlo, indexed_generators_)(index));
|
||||
llvm::BasicBlock* generated_value_bb = nullptr;
|
||||
if (auto* generated_instruction =
|
||||
llvm::dyn_cast<llvm::Instruction>(generated_value)) {
|
||||
generated_value_bb = generated_instruction->getParent();
|
||||
}
|
||||
generated_value_cache_[hlo][std::make_pair(
|
||||
generated_value_bb, index.multidim())] = generated_value;
|
||||
generated_value_cache_[hlo][index.multidim()] = generated_value;
|
||||
return generated_value;
|
||||
};
|
||||
return Status::OK();
|
||||
|
||||
@ -18,7 +18,6 @@ limitations under the License.
|
||||
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/types/optional.h"
|
||||
@ -154,10 +153,9 @@ class FusedIrEmitter : public ConstDfsHloVisitorWithDefault {
|
||||
|
||||
// Cache of generated values, lest we regenerate an element of a node with
|
||||
// multiple outgoing edges
|
||||
absl::flat_hash_map<const HloInstruction*,
|
||||
absl::flat_hash_map<std::pair<const llvm::BasicBlock*,
|
||||
std::vector<llvm::Value*>>,
|
||||
llvm::Value*>>
|
||||
absl::flat_hash_map<
|
||||
const HloInstruction*,
|
||||
absl::flat_hash_map<std::vector<llvm::Value*>, llvm::Value*>>
|
||||
generated_value_cache_;
|
||||
};
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user