[MLIR] Convert ResourceAliasAnalysis and SideEffectAnalysis to module-scoped analyses.

- ResourceAliasAnalysis analyzes the functions attached to If/While ops, but is used from
  passes that operate on functions. So when analyzing one function, the analysis can
  inspect another function while it is being modified, which is not thread safe. Convert
  ResourceAliasAnalysis and its users to module-scoped analyses/passes.
- In general, the pattern for these analyses is a module-scoped analysis that produces
  per-function or per-region analysis information. Create a PerFunctionAggregateAnalysis
  helper class to define such analyses, and a PerFunctionAggregateAnalysisConsumerPass
  class to help define passes that consume such analyses.
- Extract the function passthrough analysis into a generic backtracking analysis.
- Convert several function passes that use either ResourceAliasAnalysis or
  SideEffectAnalysis to module passes using the PerFunctionAggregateAnalysisConsumerPass
  class (see the usage sketch below).
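
A minimal usage sketch of the new pattern (illustrative only: the pass name
ExampleSideEffectRemarkPass and the remark text are hypothetical; the base class
and analysis types are the ones introduced in this change):

    // Module-scoped pass that is handed the per-function slice of the
    // module-scoped SideEffectAnalysis; functions are visited serially.
    struct ExampleSideEffectRemarkPass
        : public TF::PerFunctionAggregateAnalysisConsumerPass<
              ExampleSideEffectRemarkPass, TF::SideEffectAnalysis> {
      void runOnFunction(FuncOp func,
                         const TF::SideEffectAnalysis::Info& analysis) {
        func.walk([&](Operation* op) {
          // Flag ops that the analysis says have direct control predecessors.
          if (!analysis.DirectControlPredecessors(op).empty())
            op->emitRemark("has direct control predecessors");
        });
      }
    };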

PiperOrigin-RevId: 324068834
Change-Id: I3f5ceb63ea44f7e3aa0581a5a3d2a559b1e28458
Rahul Joshi 2020-07-30 13:19:07 -07:00 committed by TensorFlower Gardener
parent 34bd3aaad4
commit cc7694a3ef
12 changed files with 417 additions and 188 deletions


@ -1765,6 +1765,7 @@ cc_library(
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
],
)


@ -47,6 +47,153 @@ namespace mlir {
namespace TF {
namespace {
//===----------------------------------------------------------------------===//
// BacktrackAnalysisInfo
//===----------------------------------------------------------------------===//
// Class to hold backtrack analysis for the results of a region. Backtrack
// analysis traces the definitions of a region's return values back through
// pass-through operations, so that each return value is known to carry the
// same value as the value it was backtracked to.
class BacktrackAnalysisInfo {
public:
// Initializes the backtrack analysis for the given region.
explicit BacktrackAnalysisInfo(Region& region,
detail::BacktrackAnalysis& backtrack_analysis);
BacktrackAnalysisInfo(BacktrackAnalysisInfo&&) = default;
// Returns the value to which the given result of the region can be
// backtracked.
Value GetValue(int result_index) const {
return backtracked_values_[result_index];
}
// Returns the index of the region argument to which the given result can be
// backtracked. Such results are called "function passthrough". If the result
// cannot be backtracked to a region argument, returns llvm::None.
llvm::Optional<int> GetArg(int result_index) const {
if (auto arg = GetValue(result_index).dyn_cast<BlockArgument>())
if (arg.getParentBlock() == &region_->front()) return arg.getArgNumber();
return llvm::None;
}
private:
friend class detail::BacktrackAnalysis;
// Region for which this object holds the analysis info.
Region* region_;
// Backtracked values indexed by the result number.
llvm::SmallVector<Value, 4> backtracked_values_;
};
} // namespace
namespace detail {
//===----------------------------------------------------------------------===//
// BacktrackAnalysis
//===----------------------------------------------------------------------===//
// Holds backtrack analysis for all functions and regions within a module.
class BacktrackAnalysis {
public:
using InfoT = BacktrackAnalysisInfo;
// Constructs the analysis by analyzing the given module.
explicit BacktrackAnalysis(ModuleOp module);
// Returns backtracking analysis for the given region.
const InfoT& GetAnalysisForRegion(Region& region) const {
auto it = info_map_.find(&region);
assert(it != info_map_.end());
return it->second;
}
// Returns backtracking analysis for the given function.
const InfoT& GetAnalysisForFunc(FuncOp func) const {
return GetAnalysisForRegion(func.getBody());
}
// Backtracks the given value.
Value BacktrackValue(Value value);
private:
// Returns the analysis for the given region (analyzing the region if it has
// not yet been analyzed).
const InfoT& GetOrCreateAnalysis(Region& region) {
auto it = info_map_.find(&region);
if (it == info_map_.end()) {
// Note: Keep object construction and insertion separate. If we use
// emplace() to construct and insert in a single shot, when analyzing
// this region, calls to BacktrackValue() may end up inserting additional
// entries in the map, causing the underlying storage to be moved. This
// would also include this partially constructed object that we have just
// inserted into the map and are still constructing. To avoid this issue,
// construct the analysis object separately and then insert it into the
// map.
InfoT info(region, *this);
info_map_.insert({&region, std::move(info)});
}
return GetAnalysisForRegion(region);
}
private:
llvm::SmallDenseMap<Region*, InfoT> info_map_;
};
// Analyzes all regions attached to all operations in the module.
BacktrackAnalysis::BacktrackAnalysis(ModuleOp module) {
module.walk([this](Operation* op) {
for (Region& region : op->getRegions()) GetOrCreateAnalysis(region);
});
}
// Backtracks the definition of `value` looking through passthrough ops.
// Always returns a non-null value; the result may be `value` itself if no
// backtracking is possible.
Value BacktrackAnalysis::BacktrackValue(Value value) {
while (Operation* op = value.getDefiningOp()) {
int res_index = value.cast<OpResult>().getResultNumber();
if (auto graph = dyn_cast<tf_executor::GraphOp>(op)) {
value = graph.GetFetch().getOperand(res_index);
} else if (auto island = dyn_cast<tf_executor::IslandOp>(op)) {
// Control output is generated by the IslandOp, not by the yield in the
// Island body.
if (value == island.control()) break;
value = island.GetYield().getOperand(res_index);
} else if (isa<TF::IdentityNOp, TF::IdentityOp>(op)) {
value = op->getOperand(res_index);
} else {
break;
}
}
return value;
}
} // namespace detail
namespace {
// Analyze the region.
BacktrackAnalysisInfo::BacktrackAnalysisInfo(
Region& region, detail::BacktrackAnalysis& backtrack_analysis)
: region_(&region) {
if (region.empty()) return;
assert(llvm::hasSingleElement(region.getBlocks()));
auto results = region.front().getTerminator()->getOperands();
if (results.empty()) return;
backtracked_values_.reserve(results.size());
for (auto result : results)
backtracked_values_.push_back(backtrack_analysis.BacktrackValue(result));
}
} // namespace
namespace {
//===----------------------------------------------------------------------===//
// ResourceAliasAnalysisInfo helper functions.
//===----------------------------------------------------------------------===//
constexpr int64_t kUnknownResourceId = -1;
constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id";
@ -79,45 +226,16 @@ int64_t GetOrCreateIdForVarHandle(TF::VarHandleOp handle, int64_t* next_id,
return emplace_res.first->second;
}
// If the return value for `func_op` at `return_index` is a pass-through of an
// argument of this function, returns the argument index; otherwise, returns -1.
int64_t FindPassthroughArgumentForReturnValue(int64_t return_index,
FuncOp func_op) {
auto value =
func_op.getBody().front().getTerminator()->getOperand(return_index);
assert(mlir::getElementTypeOrSelf(value.getType()).isa<TF::ResourceType>());
int64_t arg_index = -1;
auto try_parse_arg_index = [&arg_index](Value v) {
auto resource_arg = v.dyn_cast<BlockArgument>();
if (resource_arg) arg_index = resource_arg.getArgNumber();
return arg_index;
};
while (try_parse_arg_index(value) == -1) {
auto op = value.getDefiningOp();
assert(op);
int64_t res_num = value.cast<OpResult>().getResultNumber();
if (auto graph = llvm::dyn_cast<tf_executor::GraphOp>(op)) {
value = graph.GetFetch().getOperand(res_num);
} else if (auto island = llvm::dyn_cast<tf_executor::IslandOp>(op)) {
value = island.GetYield().getOperand(res_num);
} else if (llvm::isa<TF::IdentityNOp, TF::IdentityOp>(op)) {
value = op->getOperand(res_num);
} else {
return -1;
}
}
return arg_index;
}
} // namespace
ResourceAliasAnalysis::ResourceAliasAnalysis(Operation* op) {
auto func_op = llvm::dyn_cast<FuncOp>(op);
if (!func_op) return;
AnalyzeFunction(func_op);
}
namespace detail {
//===----------------------------------------------------------------------===//
// ResourceAliasAnalysisInfo
//===----------------------------------------------------------------------===//
void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
// Constructs the analysis info by analyzing the given function.
ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo(
FuncOp func_op, const detail::BacktrackAnalysis& backtrack_analysis) {
// This function populates resource_value_to_ids_ and id_to_resource_values_.
// If the "tf.resource_arg_unique_id" argument attributes are present for
@ -160,7 +278,6 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
result_ids.insert(operand_it->getSecond().begin(),
operand_it->getSecond().end());
};
auto module = func_op.getParentOfType<ModuleOp>();
func_op.walk([&](Operation* op) {
if (auto var_handle = llvm::dyn_cast<TF::VarHandleOp>(op)) {
@ -184,7 +301,8 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
}
}
} else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(op)) {
auto body = llvm::cast<FuncOp>(module.lookupSymbol(while_op.body()));
const auto& body_info =
backtrack_analysis.GetAnalysisForFunc(while_op.body_func());
// If a result is a passthrough of the body input, use the corresponding
// operand's resource IDs.
for (auto result : llvm::enumerate(while_op.getResults())) {
@ -192,20 +310,19 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
.isa<TF::ResourceType>()) {
continue;
}
int64_t passthrough_operand =
FindPassthroughArgumentForReturnValue(result.index(), body);
if (passthrough_operand >= 0) {
forward_input_to_output(while_op.getOperand(passthrough_operand),
result.value());
auto passthrough_arg = body_info.GetArg(result.index());
if (passthrough_arg) {
forward_input_to_output(
while_op.getOperand(passthrough_arg.getValue()), result.value());
} else {
AddValueUniqueIDMapping(result.value(), kUnknownResourceId);
}
}
} else if (auto if_op = llvm::dyn_cast<TF::IfOp>(op)) {
auto then_branch =
llvm::cast<FuncOp>(module.lookupSymbol(if_op.then_branch()));
auto else_branch =
llvm::cast<FuncOp>(module.lookupSymbol(if_op.else_branch()));
const auto& then_info =
backtrack_analysis.GetAnalysisForFunc(if_op.then_func());
const auto& else_info =
backtrack_analysis.GetAnalysisForFunc(if_op.else_func());
// If a result is a passthrough of both branches' inputs, merge the
// resource IDs of corresponding operands for the two inputs.
for (auto result : llvm::enumerate(if_op.getResults())) {
@ -213,15 +330,13 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
.isa<TF::ResourceType>()) {
continue;
}
int64_t passthrough_then_arg =
FindPassthroughArgumentForReturnValue(result.index(), then_branch);
int64_t passthrough_else_arg =
FindPassthroughArgumentForReturnValue(result.index(), else_branch);
if (passthrough_then_arg >= 0 && passthrough_else_arg >= 0) {
forward_input_to_output(if_op.getOperand(passthrough_then_arg + 1),
result.value());
forward_input_to_output(if_op.getOperand(passthrough_else_arg + 1),
result.value());
auto passthrough_then_arg = then_info.GetArg(result.index());
auto passthrough_else_arg = else_info.GetArg(result.index());
if (passthrough_then_arg && passthrough_else_arg) {
Value then_operand = if_op.input()[passthrough_then_arg.getValue()];
Value else_operand = if_op.input()[passthrough_else_arg.getValue()];
forward_input_to_output(then_operand, result.value());
forward_input_to_output(else_operand, result.value());
} else {
AddValueUniqueIDMapping(result.value(), kUnknownResourceId);
}
@ -237,7 +352,7 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
});
}
bool ResourceAliasAnalysis::IsUnknownResource(const Value resource) const {
bool ResourceAliasAnalysisInfo::IsUnknownResource(const Value resource) const {
auto it = resource_value_to_ids_.find(resource);
assert(it != resource_value_to_ids_.end() && !it->getSecond().empty());
// The set is sorted so we only need to check the first element since
@ -247,21 +362,21 @@ bool ResourceAliasAnalysis::IsUnknownResource(const Value resource) const {
return *it->getSecond().begin() == kUnknownResourceId;
}
const llvm::SmallSet<int64_t, 8>& ResourceAliasAnalysis::GetResourceUniqueIds(
const Value resource) const {
const llvm::SmallSet<int64_t, 8>&
ResourceAliasAnalysisInfo::GetResourceUniqueIds(Value resource) const {
auto it = resource_value_to_ids_.find(resource);
assert(it != resource_value_to_ids_.end() && "Unseen resource was queried");
return it->getSecond();
}
const llvm::SmallSetVector<Value, 8>&
ResourceAliasAnalysis::GetUniqueIdResources(const int64_t id) const {
ResourceAliasAnalysisInfo::GetUniqueIdResources(const int64_t id) const {
auto it = id_to_resource_values_.find(id);
assert(it != id_to_resource_values_.end() && "Unseen id was queried");
return it->getSecond();
}
llvm::SmallSetVector<Value, 8> ResourceAliasAnalysis::GetResourceAliases(
llvm::SmallSetVector<Value, 8> ResourceAliasAnalysisInfo::GetResourceAliases(
const Value resource) const {
assert(!IsUnknownResource(resource) && "Unseen resource was queried");
llvm::SmallSetVector<Value, 8> aliases;
@ -272,8 +387,31 @@ llvm::SmallSetVector<Value, 8> ResourceAliasAnalysis::GetResourceAliases(
}
return aliases;
}
} // namespace detail
//===----------------------------------------------------------------------===//
// ResourceAliasAnalysis
//===----------------------------------------------------------------------===//
ResourceAliasAnalysis::ResourceAliasAnalysis(Operation* op) {
auto module = dyn_cast<ModuleOp>(op);
assert(module);
// Analyze all regions for backtracking info.
detail::BacktrackAnalysis backtrack_analysis(module);
// Analyze each function.
for (auto func : module.getOps<FuncOp>())
this->info_map_.try_emplace(func, func, backtrack_analysis);
}
namespace {
//===----------------------------------------------------------------------===//
// SideEffectAnalysisInfo helper functions.
//===----------------------------------------------------------------------===//
// Returns a set that contains only kUnknownResourceId.
llvm::SmallDenseSet<int64_t, 8> UnknownResourceSet() {
llvm::SmallDenseSet<int64_t, 8> unknown_set;
@ -284,7 +422,7 @@ llvm::SmallDenseSet<int64_t, 8> UnknownResourceSet() {
// Returns all resources that could be accessed by op, or UnknownResourceSet()
// if we cannot find all of them.
llvm::SmallDenseSet<int64_t, 8> FindAccessedResources(
Operation* op, const ResourceAliasAnalysis& alias_analysis) {
Operation* op, const ResourceAliasAnalysis::Info& alias_analysis) {
llvm::SmallDenseSet<int64_t, 8> resources;
for (auto operand : op->getOperands()) {
@ -311,7 +449,6 @@ llvm::SmallDenseSet<int64_t, 8> FindAccessedResources(
// TODO(yuanzx): Define this information in a different place. Currently we use
// tensorflow/compiler/tf2xla/resource_operation_table.h.
const tensorflow::XlaResourceOpInfo* GetResourceInfoForOp(Operation* op) {
auto op_name = op->getName().getStringRef().str();
if (op->getName().getDialect() !=
TF::TensorFlowDialect::getDialectNamespace()) {
return nullptr;
@ -329,7 +466,7 @@ bool OpIsReadOnly(Operation* op) {
// Returns if `op` is a resource declaration.
bool OpIsDeclaration(Operation* op,
const ResourceAliasAnalysis& alias_analysis) {
const ResourceAliasAnalysis::Info& alias_analysis) {
// TODO(yuanzx): Add other types of resources.
return llvm::isa<TF::VarHandleOp>(op) ||
(llvm::isa<TF::IdentityNOp, TF::IdentityOp>(op) &&
@ -370,8 +507,13 @@ bool OpIsKnownToHaveNoSideEffect(Operation* op) {
} // namespace
void SideEffectAnalysis::TrackAccess(int64_t resource_id, Operation* op,
bool read_only) {
namespace detail {
//===----------------------------------------------------------------------===//
// SideEffectAnalysisInfo
//===----------------------------------------------------------------------===//
void SideEffectAnalysisInfo::TrackAccess(int64_t resource_id, Operation* op,
bool read_only) {
if (resource_id == kUnknownResourceId) {
if (read_only) {
// New unknown read is not tracked by any known resource access.
@ -402,9 +544,9 @@ void SideEffectAnalysis::TrackAccess(int64_t resource_id, Operation* op,
}
}
void SideEffectAnalysis::AddPredecessorsForAccess(int64_t resource_id,
Operation* op,
bool read_only) {
void SideEffectAnalysisInfo::AddPredecessorsForAccess(int64_t resource_id,
Operation* op,
bool read_only) {
auto it = per_resource_access_info_.find(resource_id);
if (it == per_resource_access_info_.end()) return;
const auto& access_info = it->getSecond();
@ -420,8 +562,8 @@ void SideEffectAnalysis::AddPredecessorsForAccess(int64_t resource_id,
}
}
void SideEffectAnalysis::AnalyzeFunction(
FuncOp func_op, const ResourceAliasAnalysis& alias_analysis) {
void SideEffectAnalysisInfo::AnalyzeFunction(
FuncOp func_op, const TF::ResourceAliasAnalysis::Info& alias_analysis) {
// AnalyzeRegion() recursively analyzes the function body, and only populates
// control_predecessors_.
AnalyzeRegion(&func_op.getBody(), alias_analysis);
@ -448,8 +590,8 @@ void SideEffectAnalysis::AnalyzeFunction(
}
}
void SideEffectAnalysis::AnalyzeRegion(
Region* region, const ResourceAliasAnalysis& alias_analysis) {
void SideEffectAnalysisInfo::AnalyzeRegion(
Region* region, const TF::ResourceAliasAnalysis::Info& alias_analysis) {
// This function populates control_predecessors_ by walking through the
// region, and tracking resource accesses in per_resource_access_info_.
@ -476,13 +618,12 @@ void SideEffectAnalysis::AnalyzeRegion(
// different nested regions separately.
for (auto& block : *region) {
for (auto& op : block) {
if (op.getNumRegions() > 0) {
llvm::SmallVector<SideEffectAnalysis, 4> child_analyses;
for (auto& child_region : op.getRegions()) {
child_analyses.emplace_back();
child_analyses.back().AnalyzeRegion(&child_region, alias_analysis);
}
ConsumeChildAnalyses(std::move(child_analyses));
for (Region& child : op.getRegions()) {
SideEffectAnalysisInfo child_analysis(&child, alias_analysis);
// Move the control_predecessors_ entries of the child region into the
// current region.
for (auto& entry : child_analysis.control_predecessors_)
control_predecessors_[entry.first] = std::move(entry.second);
}
// We do not need explicit control edges for declaration ops.
@ -529,16 +670,8 @@ void SideEffectAnalysis::AnalyzeRegion(
}
}
void SideEffectAnalysis::ConsumeChildAnalyses(
llvm::SmallVector<SideEffectAnalysis, 4>&& children) {
for (auto& child : children) {
for (auto& entry : child.control_predecessors_) {
control_predecessors_[entry.getFirst()] = std::move(entry.getSecond());
}
}
}
llvm::SmallVector<Operation*, 4> SideEffectAnalysis::DirectControlPredecessors(
llvm::SmallVector<Operation*, 4>
SideEffectAnalysisInfo::DirectControlPredecessors(
Operation* op, llvm::function_ref<bool(Operation*)> filter) const {
llvm::SmallVector<Operation*, 4> result;
auto it = sorted_control_predecessors_.find(op);
@ -550,7 +683,8 @@ llvm::SmallVector<Operation*, 4> SideEffectAnalysis::DirectControlPredecessors(
return result;
}
llvm::SmallVector<Operation*, 4> SideEffectAnalysis::DirectControlSuccessors(
llvm::SmallVector<Operation*, 4>
SideEffectAnalysisInfo::DirectControlSuccessors(
Operation* op, llvm::function_ref<bool(Operation*)> filter) const {
llvm::SmallVector<Operation*, 4> result;
auto it = sorted_control_successors_.find(op);
@ -561,12 +695,19 @@ llvm::SmallVector<Operation*, 4> SideEffectAnalysis::DirectControlSuccessors(
}
return result;
}
} // namespace detail
SideEffectAnalysis::SideEffectAnalysis(Operation* op) {
auto func_op = llvm::dyn_cast<FuncOp>(op);
if (!func_op) return;
ResourceAliasAnalysis alias_analysis(op);
AnalyzeFunction(func_op, alias_analysis);
auto module = dyn_cast<ModuleOp>(op);
assert(module);
// Analyze entire module for alias analysis info.
ResourceAliasAnalysis alias_analysis(module);
// Analyze all functions.
for (auto func : module.getOps<FuncOp>())
this->info_map_.try_emplace(func, func,
alias_analysis.GetAnalysisForFunc(func));
}
} // namespace TF


@ -16,51 +16,68 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_SIDE_EFFECT_ANALYSIS_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_SIDE_EFFECT_ANALYSIS_H_
#include <cstddef>
#include <cstdint>
#include <memory>
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/Region.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
namespace mlir {
namespace TF {
// An analysis that runs on a function and maps each resource-type value to a
// set of unique int64_t IDs representing the possible resources it could alias.
//
// If there are nested regions, each region is handled separately. This means
// cross-region aliasing cannot be checked by this analysis.
class ResourceAliasAnalysis {
namespace detail {
// This template defines an aggregate analysis base class that analyzes a
// module but stores the analysis information per function.
template <typename InfoT>
class PerFunctionAggregateAnalysis {
public:
explicit ResourceAliasAnalysis(Operation* op);
~ResourceAliasAnalysis() = default;
ResourceAliasAnalysis(ResourceAliasAnalysis&&) = default;
using Info = InfoT;
// Returns the analysis info for the given function.
const Info& GetAnalysisForFunc(FuncOp func) const {
auto it = info_map_.find(func);
assert(it != info_map_.end());
return it->second;
}
protected:
llvm::SmallDenseMap<FuncOp, InfoT, 8> info_map_;
};
class BacktrackAnalysis;
// Resource alias analysis information for a single function.
class ResourceAliasAnalysisInfo {
public:
// Constructs analysis info by analyzing the given function.
ResourceAliasAnalysisInfo(FuncOp func,
const BacktrackAnalysis& backtrack_analysis);
ResourceAliasAnalysisInfo(ResourceAliasAnalysisInfo&&) = default;
// Returns if the analysis fails to resolve a resource-type value.
bool IsUnknownResource(const Value resource) const;
// Returns the set of unique IDs which `resource` could alias. Requires that
// IsUnknownResource(resource) == true.
const llvm::SmallSet<int64_t, 8>& GetResourceUniqueIds(
const Value resource) const;
// IsUnknownResource(resource) == false.
const llvm::SmallSet<int64_t, 8>& GetResourceUniqueIds(Value resource) const;
// Returns the set of values that are potentially aliases of `value`. Requires
// that IsUnknownResource(resource) == true.
llvm::SmallSetVector<Value, 8> GetResourceAliases(const Value resource) const;
// that IsUnknownResource(resource) == false.
llvm::SmallSetVector<Value, 8> GetResourceAliases(Value resource) const;
private:
ResourceAliasAnalysis() = default;
// Runs the analysis on `func_op` and populates two way resource values to
// unique ID mapping.
void AnalyzeFunction(FuncOp func_op);
// Maps resource value to unique ID and vice-versa.
void AddValueUniqueIDMapping(Value value, int64_t id) {
resource_value_to_ids_[value].insert(id);
@ -80,21 +97,40 @@ class ResourceAliasAnalysis {
id_to_resource_values_;
};
// An analysis that runs on a function and infers the control predecessors and
// successors for each op, based on side-effects on known and unknown resources.
// Side-effecting ops on unknown resources are conservatively treated as
// interfering with all known resource op accesses. It distinguishes accesses
// based on whether they are read-only, and read-only ops do not interfere with
// each other.
} // namespace detail
// An analysis that runs on a module and maps each resource-type value to a
// set of unique IDs representing the possible resources it could alias.
//
// If there are nested regions, each region is handled separately, and control
// dependencies are only tracked for ops under the same parent op.
class SideEffectAnalysis {
// Note that this is not an inter-procedural or inter-regional analysis, i.e.,
// each function and region is handled separately and cross-function or cross-
// region aliasing cannot be checked by this analysis.
class ResourceAliasAnalysis : public detail::PerFunctionAggregateAnalysis<
detail::ResourceAliasAnalysisInfo> {
public:
explicit SideEffectAnalysis() = default;
explicit SideEffectAnalysis(Operation* op);
SideEffectAnalysis(SideEffectAnalysis&& other) = default;
~SideEffectAnalysis() = default;
// Constructs analysis by analyzing the given module operation.
explicit ResourceAliasAnalysis(Operation* op);
};
namespace detail {
// Side effect analysis info for a single function.
class SideEffectAnalysisInfo {
public:
SideEffectAnalysisInfo() = default;
// Constructs analysis info by analyzing the given function.
SideEffectAnalysisInfo(
FuncOp func_op, const TF::ResourceAliasAnalysis::Info& alias_analysis) {
AnalyzeFunction(func_op, alias_analysis);
}
// Constructs analysis info by analyzing the given region.
SideEffectAnalysisInfo(
Region* region, const TF::ResourceAliasAnalysis::Info& alias_analysis) {
AnalyzeRegion(region, alias_analysis);
}
SideEffectAnalysisInfo(SideEffectAnalysisInfo&&) = default;
// Returns a vector of ops that are direct control predecessors of `op`,
// sorted in program order. If `filter` is provided, only predecessors that
@ -103,9 +139,9 @@ class SideEffectAnalysis {
Operation* op,
llvm::function_ref<bool(Operation*)> filter = nullptr) const;
// Returns a vector of ops that are direct control successors of `op`, sorted
// in program order. If `filter` is provided, only successors that pass the
// filter (returning true) will be included.
// Returns a vector of ops that are direct control successors of `op`,
// sorted in program order. If `filter` is provided, only successors that
// pass the filter (returning true) will be included.
llvm::SmallVector<Operation*, 4> DirectControlSuccessors(
Operation* op,
llvm::function_ref<bool(Operation*)> filter = nullptr) const;
@ -114,16 +150,11 @@ class SideEffectAnalysis {
// Runs the analysis on `func_op` and populates sorted_control_predecessors_
// and sorted_control_successors_.
void AnalyzeFunction(FuncOp func_op,
const ResourceAliasAnalysis& alias_analysis);
const TF::ResourceAliasAnalysis::Info& alias_analysis);
// Runs the analysis on `region` and populates control_predecessors_.
void AnalyzeRegion(Region* region,
const ResourceAliasAnalysis& alias_analysis);
// Moves the control_predecessors_ fields in `children` analyses to this
// current analysis.
void ConsumeChildAnalyses(
llvm::SmallVector<SideEffectAnalysis, 4>&& children);
const TF::ResourceAliasAnalysis::Info& alias_analysis);
// Updates control_predecessors_ for `op` that is being visited, on the given
// `resource_id`.
@ -159,10 +190,50 @@ class SideEffectAnalysis {
// write for the current write being analyzed.
bool tracked_last_unknown_write_for_write = false;
};
llvm::SmallDenseMap<int64_t, PerResourceAccessInfo, 8>
per_resource_access_info_;
};
} // namespace detail
// An analysis that runs on a module and, for each function, infers the
// control predecessors and successors for each op, based on side-effects on
// known and unknown resources.
// Side-effecting ops on unknown resources are conservatively treated as
// interfering with all known resource op accesses. It distinguishes accesses
// based on whether they are read-only, and read-only ops do not interfere with
// each other.
//
// If there are nested regions, each region is handled separately, and control
// dependencies are only tracked for ops under the same parent op.
class SideEffectAnalysis : public detail::PerFunctionAggregateAnalysis<
detail::SideEffectAnalysisInfo> {
public:
// Constructs analysis by analyzing the given module operation.
explicit SideEffectAnalysis(Operation* op);
};
// Base CRTP class to help write passes that consume a per-function aggregate
// analysis and operate on all non-external functions (similar to a
// FunctionPass, but with no concurrency between functions). The derived class
// needs to provide a runOnFunction() method that accepts the function and the
// analysis information for that function.
template <typename DerivedT, typename AnalysisT>
class PerFunctionAggregateAnalysisConsumerPass
: public PassWrapper<
PerFunctionAggregateAnalysisConsumerPass<DerivedT, AnalysisT>,
OperationPass<ModuleOp>> {
void runOnOperation() override {
ModuleOp op = this->getOperation();
DerivedT& derived = *static_cast<DerivedT*>(this);
auto& analysis = this->template getAnalysis<AnalysisT>();
for (auto func : op.getOps<FuncOp>())
if (!func.isExternal())
derived.runOnFunction(func, analysis.GetAnalysisForFunc(func));
}
};
} // namespace TF
} // namespace mlir


@ -562,7 +562,7 @@ Status MlirFunction::GetFunctionDef(tensorflow::FunctionDef** f) {
}
PassManager pm(func_.getContext());
pm.addNestedPass<FuncOp>(CreateFunctionalToExecutorDialectConversionPass());
pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());
pm.addPass(CreateBreakUpIslandsPass());
// In case of failure, the `diag_handler` converts MLIR errors emitted to
// the MLIRContext into a tensorflow::Status.


@ -40,15 +40,16 @@ void EnableLogging(PassManager *pm) {
namespace TFTPU {
namespace {
void AddGraphExportLoweringPasses(OpPassManager &pm) {
auto add_pass = [&](std::unique_ptr<Pass> pass) {
pm.addNestedPass<FuncOp>(std::move(pass));
pm.addPass(CreateBreakUpIslandsPass());
};
pm.addNestedPass<FuncOp>(CreateFunctionalToExecutorDialectConversionPass());
pm.addNestedPass<FuncOp>(TFDevice::CreateParallelizeEmbeddingParamsOpsPass());
pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());
pm.addNestedPass<FuncOp>(TFDevice::CreateReplicateToIslandPass());
pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());
pm.addNestedPass<FuncOp>(TFDevice::CreateParallelExecuteToIslandsPass());
pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());
pm.addNestedPass<FuncOp>(TFDevice::CreateLaunchToDeviceAttributePass());
pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());
add_pass(TFDevice::CreateParallelizeEmbeddingParamsOpsPass());
add_pass(TFDevice::CreateReplicateToIslandPass());
add_pass(TFDevice::CreateParallelExecuteToIslandsPass());
add_pass(TFDevice::CreateLaunchToDeviceAttributePass());
}
tensorflow::Status RunTPUBridge(


@ -51,7 +51,7 @@ Status MlirGraphOptimizationPass::Run(const ConfigProto& config_proto,
CreateLayoutOptimizationPipeline(pm, layout_optimization_options);
// Prepare IR for exporting.
pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());
pm.addPass(CreateBreakUpIslandsPass());
// In case of failure, the `diag_handler` converts MLIR errors emitted to the
// MLIRContext into a tensorflow::Status.


@ -26,7 +26,7 @@ namespace mlir {
// Creates a pass that breaks up an island with multiple ops into multiple
// islands, each with a single op.
std::unique_ptr<OperationPass<FuncOp>> CreateBreakUpIslandsPass();
std::unique_ptr<OperationPass<ModuleOp>> CreateBreakUpIslandsPass();
// Creates a pass that converts mlir functions consisting of mlir ops into a
// tf_executor dialect as a single island.
@ -270,7 +270,7 @@ std::unique_ptr<OperationPass<FuncOp>> CreateTPUClusterFormationPass();
// Creates a pass that allows TPU program inputs to have layouts determined at
// run time.
std::unique_ptr<OperationPass<FuncOp>> CreateTPUDynamicLayoutPass();
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDynamicLayoutPass();
// Creates a pass that remaps and assigns padding map from a
// `tf_device.launch_func` `padding_map` attribute to its encapsulated function.


@ -61,7 +61,9 @@ struct ResourceDeviceInference
// A class that records each resource's device assignment in a function.
class PerFunctionResult {
public:
explicit PerFunctionResult(FuncOp func_op) : alias_analysis_(func_op) {}
explicit PerFunctionResult(
FuncOp func_op, const TF::ResourceAliasAnalysis::Info& alias_analysis)
: alias_analysis_(alias_analysis) {}
// Returns the recorded device assignment for a resource, if any.
llvm::Optional<llvm::StringRef> DeviceForResource(
@ -105,7 +107,7 @@ class PerFunctionResult {
private:
llvm::SmallDenseMap<int64_t, llvm::StringRef, 8> resource_id_to_device_;
TF::ResourceAliasAnalysis alias_analysis_;
const TF::ResourceAliasAnalysis::Info& alias_analysis_;
};
// Tries to record device assignment for a resource.
@ -193,11 +195,15 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
void ResourceDeviceInference::runOnOperation() {
auto module = getOperation();
const auto& resource_alias_analysis =
getAnalysis<TF::ResourceAliasAnalysis>();
llvm::SmallDenseMap<Operation*, PerFunctionResult, 4> per_function_results;
llvm::SetVector<FuncOp> worklist;
module.walk([&](FuncOp func_op) {
worklist.insert(func_op);
per_function_results.try_emplace(func_op, func_op);
per_function_results.try_emplace(
func_op, func_op, resource_alias_analysis.GetAnalysisForFunc(func_op));
});
// Helper that propagates an op's recorded operand device assignments to its
// called function's arguments.


@ -39,11 +39,13 @@ namespace {
// A pass that adds "Predecessors" and "Successors" remarks for each op based
// on the SideEffectAnalysis result. For testing purposes only.
struct TestSideEffectAnalysis
: public mlir::PassWrapper<TestSideEffectAnalysis, FunctionPass> {
void runOnFunction() override {
: public TF::PerFunctionAggregateAnalysisConsumerPass<
TestSideEffectAnalysis, TF::SideEffectAnalysis> {
void runOnFunction(FuncOp func,
const TF::SideEffectAnalysis::Info& analysis) {
int64_t next_id = 0;
llvm::SmallDenseMap<Operation*, int64_t, 8> ids;
getFunction().walk([&](Operation* op) {
func.walk([&](Operation* op) {
ids[op] = next_id++;
op->emitRemark("ID: ") << ids[op];
});
@ -53,8 +55,7 @@ struct TestSideEffectAnalysis
for (auto op : ops) id_vec.push_back(std::to_string(ids[op]));
return llvm::join(id_vec, ",");
};
auto& analysis = getAnalysis<TF::SideEffectAnalysis>();
getFunction().walk([&](Operation* op) {
func.walk([&](Operation* op) {
if (!analysis.DirectControlPredecessors(op).empty()) {
op->emitRemark("Predecessors: ")
<< "{" << join_ids(analysis.DirectControlPredecessors(op)) << "}";


@ -36,6 +36,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
@ -77,24 +78,28 @@ constexpr char kFuncDeviceAttr[] = "tf.device";
// because tf.TPUCopyWithLayout accepts a host input and produces a device
// output.
struct TPUDynamicLayoutPass
: public PassWrapper<TPUDynamicLayoutPass, FunctionPass> {
void runOnFunction() override;
: public TF::PerFunctionAggregateAnalysisConsumerPass<
TPUDynamicLayoutPass, TF::ResourceAliasAnalysis> {
void runOnFunction(
FuncOp func,
const TF::ResourceAliasAnalysis::Info& resource_alias_analysis);
};
// Checks if the input producer op is supported in this transform. Right now, we
// only check if it is a tf.IteratorGetNext where the resource input comes from
// a VarHandle on CPU or a function argument assigned to CPU.
bool IsSupportedInputOp(Operation* op,
TF::ResourceAliasAnalysis* resource_alias_analysis) {
bool IsSupportedInputOp(
Operation* op,
const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
TF::IteratorGetNextOp iterator_op = llvm::dyn_cast<TF::IteratorGetNextOp>(op);
if (!iterator_op) return false;
Value resource_iterator = iterator_op.iterator();
if (resource_alias_analysis->IsUnknownResource(resource_iterator))
if (resource_alias_analysis.IsUnknownResource(resource_iterator))
return false;
llvm::SmallSetVector<Value, 8> aliases =
resource_alias_analysis->GetResourceAliases(resource_iterator);
resource_alias_analysis.GetResourceAliases(resource_iterator);
auto is_generator = [](Value val) {
if (val.isa<BlockArgument>()) return true;
@ -177,7 +182,7 @@ bool HandleReplicatedInputs(
const int64_t execute_arg_index, Value compilation_key,
tf_device::LaunchOp execute_launch, tf_device::LaunchOp compile_launch,
const int64_t replicate_arg_index, tf_device::ReplicateOp replicate,
TF::ResourceAliasAnalysis* resource_alias_analysis) {
const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
// We need to know the devices to copy to.
if (!replicate.devices()) return false;
int64_t num_replicas = replicate.n().getZExtValue();
@ -215,7 +220,7 @@ bool HandleReplicatedInputs(
void HandleCompileAndExecutes(
tf_device::LaunchOp compile_launch,
llvm::MutableArrayRef<tf_device::LaunchOp> execute_launches,
TF::ResourceAliasAnalysis* resource_alias_analysis) {
const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
auto compile =
llvm::cast<TF::_TPUCompileMlirOp>(compile_launch.GetBody().front());
tensorflow::tpu::TPUCompileMetadataProto metadata;
@ -273,9 +278,10 @@ void HandleCompileAndExecutes(
compile.getContext()));
}
void TPUDynamicLayoutPass::runOnFunction() {
TF::ResourceAliasAnalysis resource_alias_analysis(getFunction());
getFunction().walk([&](TF::_TPUCompileMlirOp compile) {
void TPUDynamicLayoutPass::runOnFunction(
FuncOp func,
const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
func.walk([&](TF::_TPUCompileMlirOp compile) {
// Detect tf._TPUCompileMlir -> tf.TPUExecute(s).
auto compile_launch =
llvm::dyn_cast<tf_device::LaunchOp>(compile.getParentOp());
@ -295,13 +301,13 @@ void TPUDynamicLayoutPass::runOnFunction() {
}
HandleCompileAndExecutes(compile_launch, execute_launches,
&resource_alias_analysis);
resource_alias_analysis);
});
}
} // namespace
std::unique_ptr<OperationPass<FuncOp>> CreateTPUDynamicLayoutPass() {
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDynamicLayoutPass() {
return std::make_unique<TPUDynamicLayoutPass>();
}


@ -41,25 +41,28 @@ namespace mlir {
namespace {
struct BreakUpIslands : PassWrapper<BreakUpIslands, FunctionPass> {
void runOnFunction() final;
class BreakUpIslands : public TF::PerFunctionAggregateAnalysisConsumerPass<
BreakUpIslands, TF::SideEffectAnalysis> {
public:
void runOnFunction(FuncOp func,
const TF::SideEffectAnalysis::Info& side_effect_analysis);
void BreakUpIsland(tf_executor::IslandOp island_op,
const TF::SideEffectAnalysis& side_effect_analysis,
const TF::SideEffectAnalysis::Info& side_effect_analysis,
llvm::DenseMap<Operation*, llvm::SmallVector<Value, 4>>*
new_control_inputs);
};
void BreakUpIslands::runOnFunction() {
auto graph_op_range = getFunction().getBody().front().without_terminator();
void BreakUpIslands::runOnFunction(
FuncOp func, const TF::SideEffectAnalysis::Info& side_effect_analysis) {
auto graph_op_range = func.front().without_terminator();
tf_executor::GraphOp graph_op;
if (graph_op_range.begin() != graph_op_range.end() &&
std::next(graph_op_range.begin()) == graph_op_range.end()) {
graph_op = dyn_cast<tf_executor::GraphOp>(
getOperation().getBody().front().front());
}
if (llvm::hasSingleElement(graph_op_range))
graph_op = dyn_cast<tf_executor::GraphOp>(func.front().front());
if (!graph_op) {
getOperation().emitError("expected function to contain only a graph_op");
func.emitError("expected function to contain only a graph_op");
signalPassFailure();
return;
}
@ -67,7 +70,6 @@ void BreakUpIslands::runOnFunction() {
// New control inputs to be added. For an operation x, new_control_inputs[x]
// contains all control inputs that need to be added to x as operands.
llvm::DenseMap<Operation*, llvm::SmallVector<Value, 4>> new_control_inputs;
auto& side_effect_analysis = getAnalysis<TF::SideEffectAnalysis>();
// Iterate in reverse order to avoid invalidating Operation* stored in
// new_control_inputs.
for (auto& item :
@ -76,7 +78,7 @@ void BreakUpIslands::runOnFunction() {
BreakUpIsland(island, side_effect_analysis, &new_control_inputs);
}
}
OpBuilder builder(getOperation());
OpBuilder builder(func);
// For every op, add new control inputs in reverse order so that the ops don't
// get invalidated.
@ -181,7 +183,7 @@ struct IslandSourcesAndSinks {
// Finds IslandSourcesAndSinks for an unmodified island.
IslandSourcesAndSinks FindSourcesAndSinksInIsland(
tf_executor::IslandOp island,
const TF::SideEffectAnalysis& side_effect_analysis) {
const TF::SideEffectAnalysis::Info& side_effect_analysis) {
IslandSourcesAndSinks result;
auto island_body = island.GetBody().without_terminator();
for (Operation& sub_op : island_body) {
@ -208,7 +210,7 @@ IslandSourcesAndSinks FindSourcesAndSinksInIsland(
// are chained together by control flow values.
void BreakUpIslands::BreakUpIsland(
tf_executor::IslandOp island_op,
const TF::SideEffectAnalysis& side_effect_analysis,
const TF::SideEffectAnalysis::Info& side_effect_analysis,
llvm::DenseMap<Operation*, llvm::SmallVector<Value, 4>>*
new_control_inputs) {
auto island_body = island_op.GetBody().without_terminator();
@ -323,7 +325,7 @@ void BreakUpIslands::BreakUpIsland(
} // namespace
std::unique_ptr<OperationPass<FuncOp>> CreateBreakUpIslandsPass() {
std::unique_ptr<OperationPass<ModuleOp>> CreateBreakUpIslandsPass() {
return std::make_unique<BreakUpIslands>();
}


@ -45,7 +45,7 @@ void AddTFToTFJSConversionPasses(mlir::OpPassManager* pm) {
// raise to executor dialect in order to use GraphDef converter
pm->addNestedPass<mlir::FuncOp>(
mlir::CreateFunctionalToExecutorDialectConversionPass());
pm->addNestedPass<mlir::FuncOp>(mlir::CreateBreakUpIslandsPass());
pm->addPass(mlir::CreateBreakUpIslandsPass());
}
} // namespace tensorflow