[MLIR][NFC] Split RegionAliasAnalysis and SideEffectAnalysis into separate files
- Also changed the name of the build target to tensorflow_analysis PiperOrigin-RevId: 324254818 Change-Id: Ie661873942c3ccf7b406f6be796b6d5fe1b79cea
This commit is contained in:
parent
ce3315d20c
commit
f3e7bc6a0b
@ -702,6 +702,30 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensorflow_analysis",
|
||||
srcs = [
|
||||
"analysis/per_function_aggregate_analysis.h",
|
||||
"analysis/resource_alias_analysis.cc",
|
||||
"analysis/side_effect_analysis.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"analysis/resource_alias_analysis.h",
|
||||
"analysis/side_effect_analysis.h",
|
||||
],
|
||||
deps = [
|
||||
":tensorflow",
|
||||
":tensorflow_types",
|
||||
"//tensorflow/compiler/tf2xla:resource_operation_table",
|
||||
"//tensorflow/core:framework",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensorflow_passes",
|
||||
srcs = [
|
||||
@ -789,8 +813,8 @@ cc_library(
|
||||
":error_util",
|
||||
":export_tf_dialect_op",
|
||||
":mangling_util",
|
||||
":side_effect_analysis",
|
||||
":tensorflow",
|
||||
":tensorflow_analysis",
|
||||
":tensorflow_optimize_inc_gen",
|
||||
":tensorflow_types",
|
||||
":tf_data_optimization",
|
||||
@ -1754,23 +1778,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "side_effect_analysis",
|
||||
srcs = ["analysis/side_effect_analysis.cc"],
|
||||
hdrs = ["analysis/side_effect_analysis.h"],
|
||||
deps = [
|
||||
":tensorflow",
|
||||
":tensorflow_types",
|
||||
"//tensorflow/compiler/tf2xla:resource_operation_table",
|
||||
"//tensorflow/core:framework",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_sharding_util",
|
||||
srcs = [
|
||||
|
||||
@ -0,0 +1,76 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_PER_FUNCTION_AGGREGATE_ANALYSIS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_PER_FUNCTION_AGGREGATE_ANALYSIS_H_
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
namespace detail {
|
||||
|
||||
// This template defines an aggregate analysis base class, which analyzes a
|
||||
// module but the analysis info is stored per function.
|
||||
template <typename InfoT>
|
||||
class PerFunctionAggregateAnalysis {
|
||||
public:
|
||||
using Info = InfoT;
|
||||
|
||||
// Returns the analysis info for the given function.
|
||||
const Info& GetAnalysisForFunc(FuncOp func) const {
|
||||
auto it = info_map_.find(func);
|
||||
assert(it != info_map_.end());
|
||||
return it->second;
|
||||
}
|
||||
|
||||
protected:
|
||||
llvm::SmallDenseMap<FuncOp, InfoT, 8> info_map_;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Base CRTP class to help write passes that are consumes a per-function
|
||||
// aggregate analysis and operate on all non-extern functions (similar to a
|
||||
// FunctionPass, but with no concurrency between functions). The derived classes
|
||||
// need to provide a runOnFunction() method that accepts the function and the
|
||||
// analysis information for that function.
|
||||
template <typename DerivedT, typename AnalysisT>
|
||||
class PerFunctionAggregateAnalysisConsumerPass
|
||||
: public PassWrapper<
|
||||
PerFunctionAggregateAnalysisConsumerPass<DerivedT, AnalysisT>,
|
||||
OperationPass<ModuleOp>> {
|
||||
void runOnOperation() override {
|
||||
ModuleOp op = this->getOperation();
|
||||
DerivedT& derived = *static_cast<DerivedT*>(this);
|
||||
auto& analysis = this->template getAnalysis<AnalysisT>();
|
||||
|
||||
for (auto func : op.getOps<FuncOp>())
|
||||
if (!func.isExternal())
|
||||
derived.runOnFunction(func, analysis.GetAnalysisForFunc(func));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_PER_FUNCTION_AGGREGATE_ANALYSIS_H_
|
||||
@ -0,0 +1,406 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <initializer_list>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/iterator_range.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Block.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
|
||||
namespace {
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BacktrackAnalysisInfo
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Class to hold backtrack analysis for a results of a region. Backtrack
|
||||
// analysis will trace back the definition of return values of regions through
|
||||
// pass-through operations, so that the return value of the region will have the
|
||||
// same value as the backtracked value.
|
||||
class BacktrackAnalysisInfo {
|
||||
public:
|
||||
// Initializes the backtrack analysis for the given region.
|
||||
explicit BacktrackAnalysisInfo(Region& region,
|
||||
detail::BacktrackAnalysis& backtrack_analysis);
|
||||
|
||||
BacktrackAnalysisInfo(BacktrackAnalysisInfo&&) = default;
|
||||
|
||||
// Returns the value to which the given result number of the region can be
|
||||
// backtracked to.
|
||||
Value GetValue(int result_index) const {
|
||||
return backtracked_values_[result_index];
|
||||
}
|
||||
|
||||
// Returns the argument index of the region to which the given result number
|
||||
// can backtracked to. Such results will be called "function passthrough". If
|
||||
// the result cannot be backtracked to a region argument, returns llvm::None.
|
||||
llvm::Optional<int> GetArg(int result_index) const {
|
||||
if (auto arg = GetValue(result_index).dyn_cast<BlockArgument>())
|
||||
if (arg.getParentBlock() == ®ion_->front()) return arg.getArgNumber();
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
private:
|
||||
friend class detail::BacktrackAnalysis;
|
||||
|
||||
// Region for which this object holds the analysis info.
|
||||
Region* region_;
|
||||
|
||||
// Backtracked values indexed by the result number.
|
||||
llvm::SmallVector<Value, 4> backtracked_values_;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace detail {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BacktrackAnalysis
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Holds backtrack analysis for all functions and regions within a module.
|
||||
class BacktrackAnalysis {
|
||||
public:
|
||||
using InfoT = BacktrackAnalysisInfo;
|
||||
|
||||
// Constructs the analysis by analyzing the given module.
|
||||
explicit BacktrackAnalysis(ModuleOp module);
|
||||
|
||||
// Returns backtracking analysis for the given region.
|
||||
const InfoT& GetAnalysisForRegion(Region& region) const {
|
||||
auto it = info_map_.find(®ion);
|
||||
assert(it != info_map_.end());
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Returns backtracking analysis for the given function.
|
||||
const InfoT& GetAnalysisForFunc(FuncOp func) const {
|
||||
return GetAnalysisForRegion(func.getBody());
|
||||
}
|
||||
|
||||
// Backtracks the given value.
|
||||
Value BacktrackValue(Value value);
|
||||
|
||||
private:
|
||||
// Returns the analysis for the given region (analyzing the region if it has
|
||||
// not yet been analyzed).
|
||||
const InfoT& GetOrCreateAnalysis(Region& region) {
|
||||
auto it = info_map_.find(®ion);
|
||||
if (it == info_map_.end()) {
|
||||
// Note: Keep object construction and insertion separate. If we use
|
||||
// emplace() to construct and insert in a single shot, when analyzing
|
||||
// this region, calls to BacktrackValue() may end up inserting additional
|
||||
// entries in the map, causing the underlying storage to be moved. This
|
||||
// would also include this pertially constructed object that we have just
|
||||
// inserted into the map and are constructing it. To avoid this issue,
|
||||
// construct the analysis object separately and then insert it into the
|
||||
// map.
|
||||
InfoT info(region, *this);
|
||||
info_map_.insert({®ion, std::move(info)});
|
||||
}
|
||||
|
||||
return GetAnalysisForRegion(region);
|
||||
}
|
||||
|
||||
private:
|
||||
llvm::SmallDenseMap<Region*, InfoT> info_map_;
|
||||
};
|
||||
|
||||
// Analyzes all regions attached to all operations in the module.
|
||||
BacktrackAnalysis::BacktrackAnalysis(ModuleOp module) {
|
||||
module.walk([this](Operation* op) {
|
||||
for (Region& region : op->getRegions()) GetOrCreateAnalysis(region);
|
||||
});
|
||||
}
|
||||
|
||||
// Backtracks the definition of `value` looking through passthrough ops.
|
||||
// Returns a non-null value and can return `value` if backtracking is not
|
||||
// possible.
|
||||
Value BacktrackAnalysis::BacktrackValue(Value value) {
|
||||
while (Operation* op = value.getDefiningOp()) {
|
||||
int res_index = value.cast<OpResult>().getResultNumber();
|
||||
if (auto graph = dyn_cast<tf_executor::GraphOp>(op)) {
|
||||
value = graph.GetFetch().getOperand(res_index);
|
||||
} else if (auto island = dyn_cast<tf_executor::IslandOp>(op)) {
|
||||
// Control output is generated by the IslandOp, not the yield in
|
||||
// in the Island body.
|
||||
if (value == island.control()) break;
|
||||
value = island.GetYield().getOperand(res_index);
|
||||
} else if (isa<TF::IdentityNOp, TF::IdentityOp>(op)) {
|
||||
value = op->getOperand(res_index);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return value;
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
namespace {
|
||||
|
||||
// Analyze the region.
|
||||
BacktrackAnalysisInfo::BacktrackAnalysisInfo(
|
||||
Region& region, detail::BacktrackAnalysis& backtrack_analysis)
|
||||
: region_(®ion) {
|
||||
if (region.empty()) return;
|
||||
|
||||
assert(llvm::hasSingleElement(region.getBlocks()));
|
||||
auto results = region.front().getTerminator()->getOperands();
|
||||
if (results.empty()) return;
|
||||
|
||||
backtracked_values_.reserve(results.size());
|
||||
for (auto result : results)
|
||||
backtracked_values_.push_back(backtrack_analysis.BacktrackValue(result));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ResourceAliasAnalysisInfo helper functions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id";
|
||||
|
||||
// Returns if a VarHandleOp is anonymous, which means it always creates a new
|
||||
// variable.
|
||||
bool IsResourceHandleAnonymous(TF::VarHandleOp handle) {
|
||||
return handle.shared_name() == tensorflow::ResourceHandle::ANONYMOUS_NAME;
|
||||
}
|
||||
|
||||
// Returns a string unique identifier for a non-anonymous VarHandleOp.
|
||||
std::string GetVarHandleStringId(TF::VarHandleOp handle) {
|
||||
auto device = handle.getAttrOfType<StringAttr>("device");
|
||||
return absl::StrCat(handle.container().str(), "/", handle.shared_name().str(),
|
||||
"/", device ? device.getValue().str() : std::string(""));
|
||||
}
|
||||
|
||||
// Finds a unique ID for a VarHandleOp's output. If it is anonymous, always
|
||||
// creates a new ID; otherwise, tries to reuse the existing ID for the
|
||||
// referenced variable if it exists, or creates a new one if not.
|
||||
int64_t GetOrCreateIdForVarHandle(TF::VarHandleOp handle, int64_t* next_id,
|
||||
llvm::StringMap<int64_t>* name_id_map) {
|
||||
// Always create a new ID for anonymous handle.
|
||||
if (IsResourceHandleAnonymous(handle)) return (*next_id)++;
|
||||
|
||||
auto name = GetVarHandleStringId(handle);
|
||||
auto emplace_res = name_id_map->try_emplace(name, *next_id);
|
||||
// New ID created, increment next_id.
|
||||
if (emplace_res.second) ++(*next_id);
|
||||
return emplace_res.first->second;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace detail {
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ResourceAliasAnalysisInfo
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Constructs the analysis info by analyzing the given function.
|
||||
ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo(
|
||||
FuncOp func_op, const detail::BacktrackAnalysis& backtrack_analysis) {
|
||||
// This function populates resource_value_to_ids_ and id_to_resource_values_.
|
||||
|
||||
// If the "tf.resource_arg_unique_id" argument attributes are present for
|
||||
// resource-type arguments, respect them when choosing IDs; otherwise, they
|
||||
// must not alias.
|
||||
int64_t next_unique_id = 0;
|
||||
const bool has_arg_unique_id_attrs =
|
||||
llvm::any_of(func_op.getArguments(), [&](const BlockArgument& arg) {
|
||||
return func_op.getArgAttr(arg.getArgNumber(), kResourceArgUniqueIdAttr);
|
||||
});
|
||||
// Maps the kResourceArgUniqueIdAttr attribute value to the internal integer
|
||||
// ID used by this pass.
|
||||
llvm::SmallDenseMap<int64_t, int64_t> attr_id_to_internal_id;
|
||||
for (auto arg : func_op.getArguments()) {
|
||||
if (!mlir::getElementTypeOrSelf(arg.getType()).isa<TF::ResourceType>())
|
||||
continue;
|
||||
if (has_arg_unique_id_attrs) {
|
||||
auto id_attr = func_op.getArgAttrOfType<IntegerAttr>(
|
||||
arg.getArgNumber(), kResourceArgUniqueIdAttr);
|
||||
assert(id_attr &&
|
||||
"tf.resource_arg_unique_id attribute should exist on either none "
|
||||
"or all arguments.");
|
||||
auto emplace_res = attr_id_to_internal_id.try_emplace(id_attr.getInt(),
|
||||
next_unique_id++);
|
||||
AddValueUniqueIDMapping(arg, emplace_res.first->getSecond());
|
||||
} else {
|
||||
AddValueUniqueIDMapping(arg, next_unique_id++);
|
||||
}
|
||||
}
|
||||
llvm::StringMap<int64_t> var_handle_name_id_map;
|
||||
auto forward_input_to_output = [&](const Value& operand,
|
||||
const Value& result) {
|
||||
if (!mlir::getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>())
|
||||
return;
|
||||
auto& result_ids = resource_value_to_ids_[result];
|
||||
auto operand_it = resource_value_to_ids_.find(operand);
|
||||
assert(operand_it != resource_value_to_ids_.end() &&
|
||||
"A resource-type output does not have the corresponding "
|
||||
"resource-type input.");
|
||||
result_ids.insert(operand_it->getSecond().begin(),
|
||||
operand_it->getSecond().end());
|
||||
};
|
||||
|
||||
func_op.walk([&](Operation* op) {
|
||||
if (auto var_handle = llvm::dyn_cast<TF::VarHandleOp>(op)) {
|
||||
AddValueUniqueIDMapping(
|
||||
var_handle.resource(),
|
||||
GetOrCreateIdForVarHandle(var_handle, &next_unique_id,
|
||||
&var_handle_name_id_map));
|
||||
} else if (llvm::isa<TF::IdentityNOp, TF::IdentityOp>(op)) {
|
||||
for (auto operand_and_result :
|
||||
llvm::zip(op->getOperands(), op->getResults())) {
|
||||
forward_input_to_output(std::get<0>(operand_and_result),
|
||||
std::get<1>(operand_and_result));
|
||||
}
|
||||
} else if (auto replicate = llvm::dyn_cast<tf_device::ReplicateOp>(op)) {
|
||||
// The nested block for ReplicateOp is handled separately in side-effect
|
||||
// analysis. Inside that block, we can still treat its block arguments as
|
||||
// different resources.
|
||||
for (auto arg : replicate.GetBody().getArguments()) {
|
||||
if (mlir::getElementTypeOrSelf(arg.getType()).isa<TF::ResourceType>()) {
|
||||
AddValueUniqueIDMapping(arg, next_unique_id++);
|
||||
}
|
||||
}
|
||||
} else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(op)) {
|
||||
const auto& body_info =
|
||||
backtrack_analysis.GetAnalysisForFunc(while_op.body_func());
|
||||
// If a result is a passthrough of the body input, use the corresponding
|
||||
// operand's resource IDs.
|
||||
for (auto result : llvm::enumerate(while_op.getResults())) {
|
||||
if (!mlir::getElementTypeOrSelf(result.value().getType())
|
||||
.isa<TF::ResourceType>()) {
|
||||
continue;
|
||||
}
|
||||
auto passthrough_arg = body_info.GetArg(result.index());
|
||||
if (passthrough_arg) {
|
||||
forward_input_to_output(
|
||||
while_op.getOperand(passthrough_arg.getValue()), result.value());
|
||||
} else {
|
||||
AddValueUniqueIDMapping(result.value(), kUnknownResourceId);
|
||||
}
|
||||
}
|
||||
} else if (auto if_op = llvm::dyn_cast<TF::IfOp>(op)) {
|
||||
const auto& then_info =
|
||||
backtrack_analysis.GetAnalysisForFunc(if_op.then_func());
|
||||
const auto& else_info =
|
||||
backtrack_analysis.GetAnalysisForFunc(if_op.else_func());
|
||||
// If a result is a passthrough of both branches' inputs, merge the
|
||||
// resource IDs of corresponding operands for the two inputs.
|
||||
for (auto result : llvm::enumerate(if_op.getResults())) {
|
||||
if (!mlir::getElementTypeOrSelf(result.value().getType())
|
||||
.isa<TF::ResourceType>()) {
|
||||
continue;
|
||||
}
|
||||
auto passthrough_then_arg = then_info.GetArg(result.index());
|
||||
auto passthrough_else_arg = else_info.GetArg(result.index());
|
||||
if (passthrough_then_arg && passthrough_else_arg) {
|
||||
Value then_operand = if_op.input()[passthrough_then_arg.getValue()];
|
||||
Value else_operand = if_op.input()[passthrough_else_arg.getValue()];
|
||||
forward_input_to_output(then_operand, result.value());
|
||||
forward_input_to_output(else_operand, result.value());
|
||||
} else {
|
||||
AddValueUniqueIDMapping(result.value(), kUnknownResourceId);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (auto result : op->getResults()) {
|
||||
if (!mlir::getElementTypeOrSelf(result.getType())
|
||||
.isa<TF::ResourceType>())
|
||||
continue;
|
||||
AddValueUniqueIDMapping(result, kUnknownResourceId);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
bool ResourceAliasAnalysisInfo::IsUnknownResource(const Value resource) const {
|
||||
auto it = resource_value_to_ids_.find(resource);
|
||||
assert(it != resource_value_to_ids_.end() && !it->getSecond().empty());
|
||||
// The set is sorted so we only need to check the first element since
|
||||
// kUnknownResourceId < 0.
|
||||
static_assert(kUnknownResourceId < 0,
|
||||
"kUnknownResourceId should be negative");
|
||||
return *it->getSecond().begin() == kUnknownResourceId;
|
||||
}
|
||||
|
||||
const llvm::SmallSet<int64_t, 8>&
|
||||
ResourceAliasAnalysisInfo::GetResourceUniqueIds(Value resource) const {
|
||||
auto it = resource_value_to_ids_.find(resource);
|
||||
assert(it != resource_value_to_ids_.end() && "Unseen resource was queried");
|
||||
return it->getSecond();
|
||||
}
|
||||
|
||||
const llvm::SmallSetVector<Value, 8>&
|
||||
ResourceAliasAnalysisInfo::GetUniqueIdResources(const int64_t id) const {
|
||||
auto it = id_to_resource_values_.find(id);
|
||||
assert(it != id_to_resource_values_.end() && "Unseen id was queried");
|
||||
return it->getSecond();
|
||||
}
|
||||
|
||||
llvm::SmallSetVector<Value, 8> ResourceAliasAnalysisInfo::GetResourceAliases(
|
||||
const Value resource) const {
|
||||
assert(!IsUnknownResource(resource) && "Unseen resource was queried");
|
||||
llvm::SmallSetVector<Value, 8> aliases;
|
||||
for (int64_t id : GetResourceUniqueIds(resource)) {
|
||||
const llvm::SmallSetVector<Value, 8>& resources_aliasing_id =
|
||||
GetUniqueIdResources(id);
|
||||
aliases.insert(resources_aliasing_id.begin(), resources_aliasing_id.end());
|
||||
}
|
||||
return aliases;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ResourceAliasAnalysis
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ResourceAliasAnalysis::ResourceAliasAnalysis(Operation* op) {
|
||||
auto module = dyn_cast<ModuleOp>(op);
|
||||
assert(module);
|
||||
|
||||
// Analyze all regions for backtracking info.
|
||||
detail::BacktrackAnalysis backtrack_analysis(module);
|
||||
|
||||
// Analyze each function.
|
||||
for (auto func : module.getOps<FuncOp>())
|
||||
this->info_map_.try_emplace(func, func, backtrack_analysis);
|
||||
}
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
||||
@ -0,0 +1,97 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_RESOURCE_ALIAS_ANALYSIS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_RESOURCE_ALIAS_ANALYSIS_H_
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/Region.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/analysis/per_function_aggregate_analysis.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
namespace detail {
|
||||
class BacktrackAnalysis;
|
||||
|
||||
// Resource alias analysis information for a single function.
|
||||
class ResourceAliasAnalysisInfo {
|
||||
public:
|
||||
// Constructs analysis info by analyzing the given function.
|
||||
ResourceAliasAnalysisInfo(FuncOp func,
|
||||
const BacktrackAnalysis& backtrack_analysis);
|
||||
|
||||
ResourceAliasAnalysisInfo(ResourceAliasAnalysisInfo&&) = default;
|
||||
|
||||
// Returns if the analysis fails to resolve a resource-type value.
|
||||
bool IsUnknownResource(const Value resource) const;
|
||||
|
||||
// Returns the set unique IDs which `resource` could alias. Requires that
|
||||
// IsUnknownResource(resource) == false.
|
||||
const llvm::SmallSet<int64_t, 8>& GetResourceUniqueIds(Value resource) const;
|
||||
|
||||
// Returns the set of values that are potentially aliases of `value`. Requires
|
||||
// that IsUnknownResource(resource) == false.
|
||||
llvm::SmallSetVector<Value, 8> GetResourceAliases(Value resource) const;
|
||||
|
||||
private:
|
||||
// Maps resource value to unique ID and vice-versa.
|
||||
void AddValueUniqueIDMapping(Value value, int64_t id) {
|
||||
resource_value_to_ids_[value].insert(id);
|
||||
id_to_resource_values_[id].insert(value);
|
||||
}
|
||||
|
||||
// Returns the set unique Values which map to `id`.
|
||||
const llvm::SmallSetVector<Value, 8>& GetUniqueIdResources(int64_t id) const;
|
||||
|
||||
// Maps each resource-type value to a set of unique IDs that it could alias.
|
||||
llvm::SmallDenseMap<Value, llvm::SmallSet<int64_t, 8>, 8>
|
||||
resource_value_to_ids_;
|
||||
|
||||
// Maps each unique ID to a set of resource-type values that could alias to
|
||||
// it. This is inverse of `resource_value_to_ids_` map.
|
||||
llvm::SmallDenseMap<int64_t, llvm::SmallSetVector<Value, 8>, 8>
|
||||
id_to_resource_values_;
|
||||
|
||||
public:
|
||||
static constexpr int64_t kUnknownResourceId = -1;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// An analysis that runs on a module and maps each resource-type value to a
|
||||
// set of unique IDs representing the possible resources it could alias.
|
||||
//
|
||||
// Note that this is not an inter-procedural or inter-regional analysis, i.e.,
|
||||
// each function and region are handled separately and cross-function or cross-
|
||||
// region aliasing cannot be checked by this analysis.
|
||||
class ResourceAliasAnalysis : public detail::PerFunctionAggregateAnalysis<
|
||||
detail::ResourceAliasAnalysisInfo> {
|
||||
public:
|
||||
// Constructs analysis by analyzing the given module operation.
|
||||
explicit ResourceAliasAnalysis(Operation* op);
|
||||
};
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_RESOURCE_ALIAS_ANALYSIS_H_
|
||||
@ -45,368 +45,10 @@ limitations under the License.
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
|
||||
namespace {
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BacktrackAnalysisInfo
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Class to hold backtrack analysis for a results of a region. Backtrack
|
||||
// analysis will trace back the definition of return values of regions through
|
||||
// pass-through operations, so that the return value of the region will have the
|
||||
// same value as the backtracked value.
|
||||
class BacktrackAnalysisInfo {
|
||||
public:
|
||||
// Initializes the backtrack analysis for the given region.
|
||||
explicit BacktrackAnalysisInfo(Region& region,
|
||||
detail::BacktrackAnalysis& backtrack_analysis);
|
||||
|
||||
BacktrackAnalysisInfo(BacktrackAnalysisInfo&&) = default;
|
||||
|
||||
// Returns the value to which the given result number of the region can be
|
||||
// backtracked to.
|
||||
Value GetValue(int result_index) const {
|
||||
return backtracked_values_[result_index];
|
||||
}
|
||||
|
||||
// Returns the argument index of the region to which the given result number
|
||||
// can backtracked to. Such results will be called "function passthrough". If
|
||||
// the result cannot be backtracked to a region argument, returns llvm::None.
|
||||
llvm::Optional<int> GetArg(int result_index) const {
|
||||
if (auto arg = GetValue(result_index).dyn_cast<BlockArgument>())
|
||||
if (arg.getParentBlock() == ®ion_->front()) return arg.getArgNumber();
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
private:
|
||||
friend class detail::BacktrackAnalysis;
|
||||
|
||||
// Region for which this object holds the analysis info.
|
||||
Region* region_;
|
||||
|
||||
// Backtracked values indexed by the result number.
|
||||
llvm::SmallVector<Value, 4> backtracked_values_;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace detail {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BacktrackAnalysis
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Holds backtrack analysis for all functions and regions within a module.
|
||||
class BacktrackAnalysis {
|
||||
public:
|
||||
using InfoT = BacktrackAnalysisInfo;
|
||||
|
||||
// Constructs the analysis by analyzing the given module.
|
||||
explicit BacktrackAnalysis(ModuleOp module);
|
||||
|
||||
// Returns backtracking analysis for the given region.
|
||||
const InfoT& GetAnalysisForRegion(Region& region) const {
|
||||
auto it = info_map_.find(®ion);
|
||||
assert(it != info_map_.end());
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Returns backtracking analysis for the given function.
|
||||
const InfoT& GetAnalysisForFunc(FuncOp func) const {
|
||||
return GetAnalysisForRegion(func.getBody());
|
||||
}
|
||||
|
||||
// Backtracks the given value.
|
||||
Value BacktrackValue(Value value);
|
||||
|
||||
private:
|
||||
// Returns the analysis for the given region (analyzing the region if it has
|
||||
// not yet been analyzed).
|
||||
const InfoT& GetOrCreateAnalysis(Region& region) {
|
||||
auto it = info_map_.find(®ion);
|
||||
if (it == info_map_.end()) {
|
||||
// Note: Keep object construction and insertion separate. If we use
|
||||
// emplace() to construct and insert in a single shot, when analyzing
|
||||
// this region, calls to BacktrackValue() may end up inserting additional
|
||||
// entries in the map, causing the underlying storage to be moved. This
|
||||
// would also include this pertially constructed object that we have just
|
||||
// inserted into the map and are constructing it. To avoid this issue,
|
||||
// construct the analysis object separately and then insert it into the
|
||||
// map.
|
||||
InfoT info(region, *this);
|
||||
info_map_.insert({®ion, std::move(info)});
|
||||
}
|
||||
|
||||
return GetAnalysisForRegion(region);
|
||||
}
|
||||
|
||||
private:
|
||||
llvm::SmallDenseMap<Region*, InfoT> info_map_;
|
||||
};
|
||||
|
||||
// Analyzes all regions attached to all operations in the module.
|
||||
BacktrackAnalysis::BacktrackAnalysis(ModuleOp module) {
|
||||
module.walk([this](Operation* op) {
|
||||
for (Region& region : op->getRegions()) GetOrCreateAnalysis(region);
|
||||
});
|
||||
}
|
||||
|
||||
// Backtracks the definition of `value` looking through passthrough ops.
|
||||
// Returns a non-null value and can return `value` if backtracking is not
|
||||
// possible.
|
||||
Value BacktrackAnalysis::BacktrackValue(Value value) {
|
||||
while (Operation* op = value.getDefiningOp()) {
|
||||
int res_index = value.cast<OpResult>().getResultNumber();
|
||||
if (auto graph = dyn_cast<tf_executor::GraphOp>(op)) {
|
||||
value = graph.GetFetch().getOperand(res_index);
|
||||
} else if (auto island = dyn_cast<tf_executor::IslandOp>(op)) {
|
||||
// Control output is generated by the IslandOp, not the yield in
|
||||
// in the Island body.
|
||||
if (value == island.control()) break;
|
||||
value = island.GetYield().getOperand(res_index);
|
||||
} else if (isa<TF::IdentityNOp, TF::IdentityOp>(op)) {
|
||||
value = op->getOperand(res_index);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return value;
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
namespace {
|
||||
|
||||
// Analyze the region.
|
||||
BacktrackAnalysisInfo::BacktrackAnalysisInfo(
|
||||
Region& region, detail::BacktrackAnalysis& backtrack_analysis)
|
||||
: region_(®ion) {
|
||||
if (region.empty()) return;
|
||||
|
||||
assert(llvm::hasSingleElement(region.getBlocks()));
|
||||
auto results = region.front().getTerminator()->getOperands();
|
||||
if (results.empty()) return;
|
||||
|
||||
backtracked_values_.reserve(results.size());
|
||||
for (auto result : results)
|
||||
backtracked_values_.push_back(backtrack_analysis.BacktrackValue(result));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ResourceAliasAnalysisInfo helper functions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
constexpr int64_t kUnknownResourceId = -1;
|
||||
constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id";
|
||||
|
||||
// Returns if a VarHandleOp is anonymous, which means it always creates a new
|
||||
// variable.
|
||||
bool IsResourceHandleAnonymous(TF::VarHandleOp handle) {
|
||||
return handle.shared_name() == tensorflow::ResourceHandle::ANONYMOUS_NAME;
|
||||
}
|
||||
|
||||
// Returns a string unique identifier for a non-anonymous VarHandleOp.
|
||||
std::string GetVarHandleStringId(TF::VarHandleOp handle) {
|
||||
auto device = handle.getAttrOfType<StringAttr>("device");
|
||||
return absl::StrCat(handle.container().str(), "/", handle.shared_name().str(),
|
||||
"/", device ? device.getValue().str() : std::string(""));
|
||||
}
|
||||
|
||||
// Finds a unique ID for a VarHandleOp's output. If it is anonymous, always
|
||||
// creates a new ID; otherwise, tries to reuse the existing ID for the
|
||||
// referenced variable if it exists, or creates a new one if not.
|
||||
int64_t GetOrCreateIdForVarHandle(TF::VarHandleOp handle, int64_t* next_id,
|
||||
llvm::StringMap<int64_t>* name_id_map) {
|
||||
// Always create a new ID for anonymous handle.
|
||||
if (IsResourceHandleAnonymous(handle)) return (*next_id)++;
|
||||
|
||||
auto name = GetVarHandleStringId(handle);
|
||||
auto emplace_res = name_id_map->try_emplace(name, *next_id);
|
||||
// New ID created, increment next_id.
|
||||
if (emplace_res.second) ++(*next_id);
|
||||
return emplace_res.first->second;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace detail {
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ResourceAliasAnalysisInfo
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Constructs the analysis info by analyzing the given function.
|
||||
ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo(
|
||||
FuncOp func_op, const detail::BacktrackAnalysis& backtrack_analysis) {
|
||||
// This function populates resource_value_to_ids_ and id_to_resource_values_.
|
||||
|
||||
// If the "tf.resource_arg_unique_id" argument attributes are present for
|
||||
// resource-type arguments, respect them when choosing IDs; otherwise, they
|
||||
// must not alias.
|
||||
int64_t next_unique_id = 0;
|
||||
const bool has_arg_unique_id_attrs =
|
||||
llvm::any_of(func_op.getArguments(), [&](const BlockArgument& arg) {
|
||||
return func_op.getArgAttr(arg.getArgNumber(), kResourceArgUniqueIdAttr);
|
||||
});
|
||||
// Maps the kResourceArgUniqueIdAttr attribute value to the internal integer
|
||||
// ID used by this pass.
|
||||
llvm::SmallDenseMap<int64_t, int64_t> attr_id_to_internal_id;
|
||||
for (auto arg : func_op.getArguments()) {
|
||||
if (!mlir::getElementTypeOrSelf(arg.getType()).isa<TF::ResourceType>())
|
||||
continue;
|
||||
if (has_arg_unique_id_attrs) {
|
||||
auto id_attr = func_op.getArgAttrOfType<IntegerAttr>(
|
||||
arg.getArgNumber(), kResourceArgUniqueIdAttr);
|
||||
assert(id_attr &&
|
||||
"tf.resource_arg_unique_id attribute should exist on either none "
|
||||
"or all arguments.");
|
||||
auto emplace_res = attr_id_to_internal_id.try_emplace(id_attr.getInt(),
|
||||
next_unique_id++);
|
||||
AddValueUniqueIDMapping(arg, emplace_res.first->getSecond());
|
||||
} else {
|
||||
AddValueUniqueIDMapping(arg, next_unique_id++);
|
||||
}
|
||||
}
|
||||
llvm::StringMap<int64_t> var_handle_name_id_map;
|
||||
auto forward_input_to_output = [&](const Value& operand,
|
||||
const Value& result) {
|
||||
if (!mlir::getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>())
|
||||
return;
|
||||
auto& result_ids = resource_value_to_ids_[result];
|
||||
auto operand_it = resource_value_to_ids_.find(operand);
|
||||
assert(operand_it != resource_value_to_ids_.end() &&
|
||||
"A resource-type output does not have the corresponding "
|
||||
"resource-type input.");
|
||||
result_ids.insert(operand_it->getSecond().begin(),
|
||||
operand_it->getSecond().end());
|
||||
};
|
||||
|
||||
func_op.walk([&](Operation* op) {
|
||||
if (auto var_handle = llvm::dyn_cast<TF::VarHandleOp>(op)) {
|
||||
AddValueUniqueIDMapping(
|
||||
var_handle.resource(),
|
||||
GetOrCreateIdForVarHandle(var_handle, &next_unique_id,
|
||||
&var_handle_name_id_map));
|
||||
} else if (llvm::isa<TF::IdentityNOp, TF::IdentityOp>(op)) {
|
||||
for (auto operand_and_result :
|
||||
llvm::zip(op->getOperands(), op->getResults())) {
|
||||
forward_input_to_output(std::get<0>(operand_and_result),
|
||||
std::get<1>(operand_and_result));
|
||||
}
|
||||
} else if (auto replicate = llvm::dyn_cast<tf_device::ReplicateOp>(op)) {
|
||||
// The nested block for ReplicateOp is handled separately in side-effect
|
||||
// analysis. Inside that block, we can still treat its block arguments as
|
||||
// different resources.
|
||||
for (auto arg : replicate.GetBody().getArguments()) {
|
||||
if (mlir::getElementTypeOrSelf(arg.getType()).isa<TF::ResourceType>()) {
|
||||
AddValueUniqueIDMapping(arg, next_unique_id++);
|
||||
}
|
||||
}
|
||||
} else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(op)) {
|
||||
const auto& body_info =
|
||||
backtrack_analysis.GetAnalysisForFunc(while_op.body_func());
|
||||
// If a result is a passthrough of the body input, use the corresponding
|
||||
// operand's resource IDs.
|
||||
for (auto result : llvm::enumerate(while_op.getResults())) {
|
||||
if (!mlir::getElementTypeOrSelf(result.value().getType())
|
||||
.isa<TF::ResourceType>()) {
|
||||
continue;
|
||||
}
|
||||
auto passthrough_arg = body_info.GetArg(result.index());
|
||||
if (passthrough_arg) {
|
||||
forward_input_to_output(
|
||||
while_op.getOperand(passthrough_arg.getValue()), result.value());
|
||||
} else {
|
||||
AddValueUniqueIDMapping(result.value(), kUnknownResourceId);
|
||||
}
|
||||
}
|
||||
} else if (auto if_op = llvm::dyn_cast<TF::IfOp>(op)) {
|
||||
const auto& then_info =
|
||||
backtrack_analysis.GetAnalysisForFunc(if_op.then_func());
|
||||
const auto& else_info =
|
||||
backtrack_analysis.GetAnalysisForFunc(if_op.else_func());
|
||||
// If a result is a passthrough of both branches' inputs, merge the
|
||||
// resource IDs of corresponding operands for the two inputs.
|
||||
for (auto result : llvm::enumerate(if_op.getResults())) {
|
||||
if (!mlir::getElementTypeOrSelf(result.value().getType())
|
||||
.isa<TF::ResourceType>()) {
|
||||
continue;
|
||||
}
|
||||
auto passthrough_then_arg = then_info.GetArg(result.index());
|
||||
auto passthrough_else_arg = else_info.GetArg(result.index());
|
||||
if (passthrough_then_arg && passthrough_else_arg) {
|
||||
Value then_operand = if_op.input()[passthrough_then_arg.getValue()];
|
||||
Value else_operand = if_op.input()[passthrough_else_arg.getValue()];
|
||||
forward_input_to_output(then_operand, result.value());
|
||||
forward_input_to_output(else_operand, result.value());
|
||||
} else {
|
||||
AddValueUniqueIDMapping(result.value(), kUnknownResourceId);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (auto result : op->getResults()) {
|
||||
if (!mlir::getElementTypeOrSelf(result.getType())
|
||||
.isa<TF::ResourceType>())
|
||||
continue;
|
||||
AddValueUniqueIDMapping(result, kUnknownResourceId);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
bool ResourceAliasAnalysisInfo::IsUnknownResource(const Value resource) const {
|
||||
auto it = resource_value_to_ids_.find(resource);
|
||||
assert(it != resource_value_to_ids_.end() && !it->getSecond().empty());
|
||||
// The set is sorted so we only need to check the first element since
|
||||
// kUnknownResourceId < 0.
|
||||
static_assert(kUnknownResourceId < 0,
|
||||
"kUnknownResourceId should be negative");
|
||||
return *it->getSecond().begin() == kUnknownResourceId;
|
||||
}
|
||||
|
||||
const llvm::SmallSet<int64_t, 8>&
|
||||
ResourceAliasAnalysisInfo::GetResourceUniqueIds(Value resource) const {
|
||||
auto it = resource_value_to_ids_.find(resource);
|
||||
assert(it != resource_value_to_ids_.end() && "Unseen resource was queried");
|
||||
return it->getSecond();
|
||||
}
|
||||
|
||||
const llvm::SmallSetVector<Value, 8>&
|
||||
ResourceAliasAnalysisInfo::GetUniqueIdResources(const int64_t id) const {
|
||||
auto it = id_to_resource_values_.find(id);
|
||||
assert(it != id_to_resource_values_.end() && "Unseen id was queried");
|
||||
return it->getSecond();
|
||||
}
|
||||
|
||||
llvm::SmallSetVector<Value, 8> ResourceAliasAnalysisInfo::GetResourceAliases(
|
||||
const Value resource) const {
|
||||
assert(!IsUnknownResource(resource) && "Unseen resource was queried");
|
||||
llvm::SmallSetVector<Value, 8> aliases;
|
||||
for (int64_t id : GetResourceUniqueIds(resource)) {
|
||||
const llvm::SmallSetVector<Value, 8>& resources_aliasing_id =
|
||||
GetUniqueIdResources(id);
|
||||
aliases.insert(resources_aliasing_id.begin(), resources_aliasing_id.end());
|
||||
}
|
||||
return aliases;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ResourceAliasAnalysis
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ResourceAliasAnalysis::ResourceAliasAnalysis(Operation* op) {
|
||||
auto module = dyn_cast<ModuleOp>(op);
|
||||
assert(module);
|
||||
|
||||
// Analyze all regions for backtracking info.
|
||||
detail::BacktrackAnalysis backtrack_analysis(module);
|
||||
|
||||
// Analyze each function.
|
||||
for (auto func : module.getOps<FuncOp>())
|
||||
this->info_map_.try_emplace(func, func, backtrack_analysis);
|
||||
}
|
||||
|
||||
namespace {
|
||||
constexpr auto kUnknownResourceId =
|
||||
ResourceAliasAnalysis::Info::kUnknownResourceId;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SideEffectAnalysisInfo helper functions.
|
||||
|
||||
@ -20,99 +20,19 @@ limitations under the License.
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/Region.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
|
||||
namespace detail {
|
||||
|
||||
// This template defines an aggregate analysis base class, which analyzes a
|
||||
// module but the analysis info is stored per function.
|
||||
template <typename InfoT>
|
||||
class PerFunctionAggregateAnalysis {
|
||||
public:
|
||||
using Info = InfoT;
|
||||
|
||||
// Returns the analysis info for the given function.
|
||||
const Info& GetAnalysisForFunc(FuncOp func) const {
|
||||
auto it = info_map_.find(func);
|
||||
assert(it != info_map_.end());
|
||||
return it->second;
|
||||
}
|
||||
|
||||
protected:
|
||||
llvm::SmallDenseMap<FuncOp, InfoT, 8> info_map_;
|
||||
};
|
||||
|
||||
class BacktrackAnalysis;
|
||||
|
||||
// Resource alias analysis information for a single function.
|
||||
class ResourceAliasAnalysisInfo {
|
||||
public:
|
||||
// Constructs analysis info by analyzing the given function.
|
||||
ResourceAliasAnalysisInfo(FuncOp func,
|
||||
const BacktrackAnalysis& backtrack_analysis);
|
||||
|
||||
ResourceAliasAnalysisInfo(ResourceAliasAnalysisInfo&&) = default;
|
||||
|
||||
// Returns if the analysis fails to resolve a resource-type value.
|
||||
bool IsUnknownResource(const Value resource) const;
|
||||
|
||||
// Returns the set unique IDs which `resource` could alias. Requires that
|
||||
// IsUnknownResource(resource) == false.
|
||||
const llvm::SmallSet<int64_t, 8>& GetResourceUniqueIds(Value resource) const;
|
||||
|
||||
// Returns the set of values that are potentially aliases of `value`. Requires
|
||||
// that IsUnknownResource(resource) == false.
|
||||
llvm::SmallSetVector<Value, 8> GetResourceAliases(Value resource) const;
|
||||
|
||||
private:
|
||||
// Maps resource value to unique ID and vice-versa.
|
||||
void AddValueUniqueIDMapping(Value value, int64_t id) {
|
||||
resource_value_to_ids_[value].insert(id);
|
||||
id_to_resource_values_[id].insert(value);
|
||||
}
|
||||
|
||||
// Returns the set unique Values which map to `id`.
|
||||
const llvm::SmallSetVector<Value, 8>& GetUniqueIdResources(int64_t id) const;
|
||||
|
||||
// Maps each resource-type value to a set of unique IDs that it could alias.
|
||||
llvm::SmallDenseMap<Value, llvm::SmallSet<int64_t, 8>, 8>
|
||||
resource_value_to_ids_;
|
||||
|
||||
// Maps each unique ID to a set of resource-type values that could alias to
|
||||
// it. This is inverse of `resource_value_to_ids_` map.
|
||||
llvm::SmallDenseMap<int64_t, llvm::SmallSetVector<Value, 8>, 8>
|
||||
id_to_resource_values_;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// An analysis that runs on a module and maps each resource-type value to a
|
||||
// set of unique IDs representing the possible resources it could alias.
|
||||
//
|
||||
// Note that this is not an inter-procedural or inter-regional analysis, i.e.,
|
||||
// each function and region are handled separately and cross-function or cross-
|
||||
// region aliasing cannot be checked by this analysis.
|
||||
class ResourceAliasAnalysis : public detail::PerFunctionAggregateAnalysis<
|
||||
detail::ResourceAliasAnalysisInfo> {
|
||||
public:
|
||||
// Constructs analysis by analyzing the given module operation.
|
||||
explicit ResourceAliasAnalysis(Operation* op);
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
// Side effect analysis info for a single function.
|
||||
class SideEffectAnalysisInfo {
|
||||
public:
|
||||
@ -213,27 +133,6 @@ class SideEffectAnalysis : public detail::PerFunctionAggregateAnalysis<
|
||||
explicit SideEffectAnalysis(Operation* op);
|
||||
};
|
||||
|
||||
// Base CRTP class to help write passes that are consumes a per-function
|
||||
// aggregate analysis and operate on all non-extern functions (similar to a
|
||||
// FunctionPass, but with no concurrency between functions). The derived classes
|
||||
// need to provide a runOnFunction() method that accepts the function and the
|
||||
// analysis information for that function.
|
||||
template <typename DerivedT, typename AnalysisT>
|
||||
class PerFunctionAggregateAnalysisConsumerPass
|
||||
: public PassWrapper<
|
||||
PerFunctionAggregateAnalysisConsumerPass<DerivedT, AnalysisT>,
|
||||
OperationPass<ModuleOp>> {
|
||||
void runOnOperation() override {
|
||||
ModuleOp op = this->getOperation();
|
||||
DerivedT& derived = *static_cast<DerivedT*>(this);
|
||||
auto& analysis = this->template getAnalysis<AnalysisT>();
|
||||
|
||||
for (auto func : op.getOps<FuncOp>())
|
||||
if (!func.isExternal())
|
||||
derived.runOnFunction(func, analysis.GetAnalysisForFunc(func));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@ -36,7 +36,7 @@ limitations under the License.
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user