[MLIR] Convert ResourceAliasAnalysis and SideEffectAnalysis to module-scoped analyses.

- ResourceAliasAnalysis analyzes functions attached to If/While ops, but is used from
  passes that operate on functions. As a result, while analyzing one function, the analysis
  can inspect another function that is concurrently being modified, which is not thread-safe.
  Convert ResourceAliasAnalysis and its users to module-scoped analyses/passes.
- In general, the pattern for these analyses is a module-scoped analysis that produces
  per-function or per-region analysis information. Create a PerFunctionAggregateAnalysis
  helper class to define such analyses, and a PerFunctionAggregateAnalysisConsumerPass class
  to help define passes that consume such analyses (see the sketch after this list).
- Extract the function passthrough analysis into a generic backtracking analysis.
- Convert several function passes that use either ResourceAliasAnalysis or
  SideEffectAnalysis to module passes using the PerFunctionAggregateAnalysisConsumerPass class.
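
A minimal sketch of how a consumer pass plugs into this pattern, based on the classes
introduced below in side_effect_analysis.h (the pass name and body here are hypothetical,
not part of this change):

    #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"

    namespace mlir {
    namespace {

    // Runs at module scope, but is handed per-function analysis results, so no
    // other function is being mutated while this one is inspected.
    class MySideEffectConsumerPass
        : public TF::PerFunctionAggregateAnalysisConsumerPass<
              MySideEffectConsumerPass, TF::SideEffectAnalysis> {
     public:
      // Invoked once per non-external function with that function's info.
      void runOnFunction(FuncOp func,
                         const TF::SideEffectAnalysis::Info& analysis) {
        func.walk([&](Operation* op) {
          // Example query: ops that must execute before `op`.
          (void)analysis.DirectControlPredecessors(op);
        });
      }
    };

    }  // namespace
    }  // namespace mlir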

PiperOrigin-RevId: 324068834
Change-Id: I3f5ceb63ea44f7e3aa0581a5a3d2a559b1e28458
Rahul Joshi 2020-07-30 13:19:07 -07:00 committed by TensorFlower Gardener
parent 34bd3aaad4
commit cc7694a3ef
12 changed files with 417 additions and 188 deletions


@@ -1765,6 +1765,7 @@ cc_library(
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Support",
    ],
)


@@ -47,6 +47,153 @@ namespace mlir {
namespace TF {
namespace {
//===----------------------------------------------------------------------===//
// BacktrackAnalysisInfo
//===----------------------------------------------------------------------===//
// Class to hold backtrack analysis for the results of a region. Backtrack
// analysis will trace back the definition of return values of regions through
// pass-through operations, so that the return value of the region will have the
// same value as the backtracked value.
class BacktrackAnalysisInfo {
public:
// Initializes the backtrack analysis for the given region.
explicit BacktrackAnalysisInfo(Region& region,
detail::BacktrackAnalysis& backtrack_analysis);
BacktrackAnalysisInfo(BacktrackAnalysisInfo&&) = default;
// Returns the value to which the given result number of the region can be
// backtracked.
Value GetValue(int result_index) const {
return backtracked_values_[result_index];
}
// Returns the argument index of the region to which the given result number
// can be backtracked. Such results are called "function passthroughs". If
// the result cannot be backtracked to a region argument, returns llvm::None.
llvm::Optional<int> GetArg(int result_index) const {
if (auto arg = GetValue(result_index).dyn_cast<BlockArgument>())
if (arg.getParentBlock() == &region_->front()) return arg.getArgNumber();
return llvm::None;
}
private:
friend class detail::BacktrackAnalysis;
// Region for which this object holds the analysis info.
Region* region_;
// Backtracked values indexed by the result number.
llvm::SmallVector<Value, 4> backtracked_values_;
};
} // namespace
namespace detail {
//===----------------------------------------------------------------------===//
// BacktrackAnalysis
//===----------------------------------------------------------------------===//
// Holds backtrack analysis for all functions and regions within a module.
class BacktrackAnalysis {
public:
using InfoT = BacktrackAnalysisInfo;
// Constructs the analysis by analyzing the given module.
explicit BacktrackAnalysis(ModuleOp module);
// Returns backtracking analysis for the given region.
const InfoT& GetAnalysisForRegion(Region& region) const {
auto it = info_map_.find(&region);
assert(it != info_map_.end());
return it->second;
}
// Returns backtracking analysis for the given function.
const InfoT& GetAnalysisForFunc(FuncOp func) const {
return GetAnalysisForRegion(func.getBody());
}
// Backtracks the given value.
Value BacktrackValue(Value value);
private:
// Returns the analysis for the given region (analyzing the region if it has
// not yet been analyzed).
const InfoT& GetOrCreateAnalysis(Region& region) {
auto it = info_map_.find(&region);
if (it == info_map_.end()) {
// Note: Keep object construction and insertion separate. If we use
// emplace() to construct and insert in a single shot, when analyzing
// this region, calls to BacktrackValue() may end up inserting additional
// entries in the map, causing the underlying storage to be moved. This
// would also include this partially constructed object that we have just
// inserted into the map and are still constructing. To avoid this issue,
// construct the analysis object separately and then insert it into the
// map.
InfoT info(region, *this);
info_map_.insert({&region, std::move(info)});
}
return GetAnalysisForRegion(region);
}
private:
llvm::SmallDenseMap<Region*, InfoT> info_map_;
};
// Analyzes all regions attached to all operations in the module.
BacktrackAnalysis::BacktrackAnalysis(ModuleOp module) {
module.walk([this](Operation* op) {
for (Region& region : op->getRegions()) GetOrCreateAnalysis(region);
});
}
// Backtracks the definition of `value` looking through passthrough ops.
// Always returns a non-null value; this may be `value` itself if backtracking
// is not possible.
Value BacktrackAnalysis::BacktrackValue(Value value) {
while (Operation* op = value.getDefiningOp()) {
int res_index = value.cast<OpResult>().getResultNumber();
if (auto graph = dyn_cast<tf_executor::GraphOp>(op)) {
value = graph.GetFetch().getOperand(res_index);
} else if (auto island = dyn_cast<tf_executor::IslandOp>(op)) {
// The control output is generated by the IslandOp itself, not by the yield
// in the island body.
if (value == island.control()) break;
value = island.GetYield().getOperand(res_index);
} else if (isa<TF::IdentityNOp, TF::IdentityOp>(op)) {
value = op->getOperand(res_index);
} else {
break;
}
}
return value;
}
} // namespace detail
namespace {
// Analyze the region.
BacktrackAnalysisInfo::BacktrackAnalysisInfo(
Region& region, detail::BacktrackAnalysis& backtrack_analysis)
: region_(&region) {
if (region.empty()) return;
assert(llvm::hasSingleElement(region.getBlocks()));
auto results = region.front().getTerminator()->getOperands();
if (results.empty()) return;
backtracked_values_.reserve(results.size());
for (auto result : results)
backtracked_values_.push_back(backtrack_analysis.BacktrackValue(result));
}
} // namespace
namespace {
//===----------------------------------------------------------------------===//
// ResourceAliasAnalysisInfo helper functions.
//===----------------------------------------------------------------------===//
constexpr int64_t kUnknownResourceId = -1;
constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id";
@@ -79,45 +226,16 @@ int64_t GetOrCreateIdForVarHandle(TF::VarHandleOp handle, int64_t* next_id,
return emplace_res.first->second;
}
- // If the return value for `func_op` at `return_index` is a pass-through of an
- // argument of this function, returns the argument index; otherwise, returns -1.
- int64_t FindPassthroughArgumentForReturnValue(int64_t return_index,
- FuncOp func_op) {
- auto value =
- func_op.getBody().front().getTerminator()->getOperand(return_index);
- assert(mlir::getElementTypeOrSelf(value.getType()).isa<TF::ResourceType>());
- int64_t arg_index = -1;
- auto try_parse_arg_index = [&arg_index](Value v) {
- auto resource_arg = v.dyn_cast<BlockArgument>();
- if (resource_arg) arg_index = resource_arg.getArgNumber();
- return arg_index;
- };
- while (try_parse_arg_index(value) == -1) {
- auto op = value.getDefiningOp();
- assert(op);
- int64_t res_num = value.cast<OpResult>().getResultNumber();
- if (auto graph = llvm::dyn_cast<tf_executor::GraphOp>(op)) {
- value = graph.GetFetch().getOperand(res_num);
- } else if (auto island = llvm::dyn_cast<tf_executor::IslandOp>(op)) {
- value = island.GetYield().getOperand(res_num);
- } else if (llvm::isa<TF::IdentityNOp, TF::IdentityOp>(op)) {
- value = op->getOperand(res_num);
- } else {
- return -1;
- }
- }
- return arg_index;
- }
} // namespace

- ResourceAliasAnalysis::ResourceAliasAnalysis(Operation* op) {
- auto func_op = llvm::dyn_cast<FuncOp>(op);
- if (!func_op) return;
- AnalyzeFunction(func_op);
- }
- void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
namespace detail {

//===----------------------------------------------------------------------===//
// ResourceAliasAnalysisInfo
//===----------------------------------------------------------------------===//

// Constructs the analysis info by analyzing the given function.
ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo(
FuncOp func_op, const detail::BacktrackAnalysis& backtrack_analysis) {
// This function populates resource_value_to_ids_ and id_to_resource_values_.
// If the "tf.resource_arg_unique_id" argument attributes are present for
@@ -160,7 +278,6 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
result_ids.insert(operand_it->getSecond().begin(),
operand_it->getSecond().end());
};
- auto module = func_op.getParentOfType<ModuleOp>();
func_op.walk([&](Operation* op) {
if (auto var_handle = llvm::dyn_cast<TF::VarHandleOp>(op)) {
@@ -184,7 +301,8 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
}
}
} else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(op)) {
- auto body = llvm::cast<FuncOp>(module.lookupSymbol(while_op.body()));
const auto& body_info =
backtrack_analysis.GetAnalysisForFunc(while_op.body_func());
// If a result is a passthrough of the body input, use the corresponding
// operand's resource IDs.
for (auto result : llvm::enumerate(while_op.getResults())) {
@@ -192,20 +310,19 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
.isa<TF::ResourceType>()) {
continue;
}
- int64_t passthrough_operand =
- FindPassthroughArgumentForReturnValue(result.index(), body);
- if (passthrough_operand >= 0) {
- forward_input_to_output(while_op.getOperand(passthrough_operand),
- result.value());
auto passthrough_arg = body_info.GetArg(result.index());
if (passthrough_arg) {
forward_input_to_output(
while_op.getOperand(passthrough_arg.getValue()), result.value());
} else {
AddValueUniqueIDMapping(result.value(), kUnknownResourceId);
}
}
} else if (auto if_op = llvm::dyn_cast<TF::IfOp>(op)) {
- auto then_branch =
- llvm::cast<FuncOp>(module.lookupSymbol(if_op.then_branch()));
- auto else_branch =
- llvm::cast<FuncOp>(module.lookupSymbol(if_op.else_branch()));
const auto& then_info =
backtrack_analysis.GetAnalysisForFunc(if_op.then_func());
const auto& else_info =
backtrack_analysis.GetAnalysisForFunc(if_op.else_func());
// If a result is a passthrough of both branches' inputs, merge the
// resource IDs of corresponding operands for the two inputs.
for (auto result : llvm::enumerate(if_op.getResults())) {
@@ -213,15 +330,13 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
.isa<TF::ResourceType>()) {
continue;
}
- int64_t passthrough_then_arg =
- FindPassthroughArgumentForReturnValue(result.index(), then_branch);
- int64_t passthrough_else_arg =
- FindPassthroughArgumentForReturnValue(result.index(), else_branch);
- if (passthrough_then_arg >= 0 && passthrough_else_arg >= 0) {
- forward_input_to_output(if_op.getOperand(passthrough_then_arg + 1),
- result.value());
- forward_input_to_output(if_op.getOperand(passthrough_else_arg + 1),
- result.value());
auto passthrough_then_arg = then_info.GetArg(result.index());
auto passthrough_else_arg = else_info.GetArg(result.index());
if (passthrough_then_arg && passthrough_else_arg) {
Value then_operand = if_op.input()[passthrough_then_arg.getValue()];
Value else_operand = if_op.input()[passthrough_else_arg.getValue()];
forward_input_to_output(then_operand, result.value());
forward_input_to_output(else_operand, result.value());
} else {
AddValueUniqueIDMapping(result.value(), kUnknownResourceId);
}
@@ -237,7 +352,7 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
});
}
- bool ResourceAliasAnalysis::IsUnknownResource(const Value resource) const {
bool ResourceAliasAnalysisInfo::IsUnknownResource(const Value resource) const {
auto it = resource_value_to_ids_.find(resource);
assert(it != resource_value_to_ids_.end() && !it->getSecond().empty());
// The set is sorted so we only need to check the first element since
@@ -247,21 +362,21 @@ bool ResourceAliasAnalysis::IsUnknownResource(const Value resource) const {
return *it->getSecond().begin() == kUnknownResourceId;
}
- const llvm::SmallSet<int64_t, 8>& ResourceAliasAnalysis::GetResourceUniqueIds(
- const Value resource) const {
const llvm::SmallSet<int64_t, 8>&
ResourceAliasAnalysisInfo::GetResourceUniqueIds(Value resource) const {
auto it = resource_value_to_ids_.find(resource);
assert(it != resource_value_to_ids_.end() && "Unseen resource was queried");
return it->getSecond();
}
const llvm::SmallSetVector<Value, 8>&
- ResourceAliasAnalysis::GetUniqueIdResources(const int64_t id) const {
ResourceAliasAnalysisInfo::GetUniqueIdResources(const int64_t id) const {
auto it = id_to_resource_values_.find(id);
assert(it != id_to_resource_values_.end() && "Unseen id was queried");
return it->getSecond();
}
- llvm::SmallSetVector<Value, 8> ResourceAliasAnalysis::GetResourceAliases(
llvm::SmallSetVector<Value, 8> ResourceAliasAnalysisInfo::GetResourceAliases(
const Value resource) const {
assert(!IsUnknownResource(resource) && "Unseen resource was queried");
llvm::SmallSetVector<Value, 8> aliases;
@@ -272,8 +387,31 @@ llvm::SmallSetVector<Value, 8> ResourceAliasAnalysis::GetResourceAliases(
}
return aliases;
}
} // namespace detail
//===----------------------------------------------------------------------===//
// ResourceAliasAnalysis
//===----------------------------------------------------------------------===//
ResourceAliasAnalysis::ResourceAliasAnalysis(Operation* op) {
auto module = dyn_cast<ModuleOp>(op);
assert(module);
// Analyze all regions for backtracking info.
detail::BacktrackAnalysis backtrack_analysis(module);
// Analyze each function.
for (auto func : module.getOps<FuncOp>())
this->info_map_.try_emplace(func, func, backtrack_analysis);
}
namespace {
//===----------------------------------------------------------------------===//
// SideEffectAnalysisInfo helper functions.
//===----------------------------------------------------------------------===//
// Returns a set that contains only kUnknownResourceId.
llvm::SmallDenseSet<int64_t, 8> UnknownResourceSet() {
llvm::SmallDenseSet<int64_t, 8> unknown_set;
@@ -284,7 +422,7 @@ llvm::SmallDenseSet<int64_t, 8> UnknownResourceSet() {
// Returns all resources that could be accessed by op, or UnknownResourceSet()
// if we cannot find all of them.
llvm::SmallDenseSet<int64_t, 8> FindAccessedResources(
- Operation* op, const ResourceAliasAnalysis& alias_analysis) {
Operation* op, const ResourceAliasAnalysis::Info& alias_analysis) {
llvm::SmallDenseSet<int64_t, 8> resources;
for (auto operand : op->getOperands()) {
@@ -311,7 +449,6 @@ llvm::SmallDenseSet<int64_t, 8> FindAccessedResources(
// TODO(yuanzx): Define this information in a different place. Currently we use
// tensorflow/compiler/tf2xla/resource_operation_table.h.
const tensorflow::XlaResourceOpInfo* GetResourceInfoForOp(Operation* op) {
- auto op_name = op->getName().getStringRef().str();
if (op->getName().getDialect() !=
TF::TensorFlowDialect::getDialectNamespace()) {
return nullptr;
@@ -329,7 +466,7 @@ bool OpIsReadOnly(Operation* op) {
// Returns if `op` is a resource declaration.
bool OpIsDeclaration(Operation* op,
- const ResourceAliasAnalysis& alias_analysis) {
const ResourceAliasAnalysis::Info& alias_analysis) {
// TODO(yuanzx): Add other types of resources.
return llvm::isa<TF::VarHandleOp>(op) ||
(llvm::isa<TF::IdentityNOp, TF::IdentityOp>(op) &&
@@ -370,7 +507,12 @@ bool OpIsKnownToHaveNoSideEffect(Operation* op) {
} // namespace
- void SideEffectAnalysis::TrackAccess(int64_t resource_id, Operation* op,
namespace detail {

//===----------------------------------------------------------------------===//
// SideEffectAnalysisInfo
//===----------------------------------------------------------------------===//

void SideEffectAnalysisInfo::TrackAccess(int64_t resource_id, Operation* op,
bool read_only) {
if (resource_id == kUnknownResourceId) {
if (read_only) {
@@ -402,7 +544,7 @@ void SideEffectAnalysis::TrackAccess(int64_t resource_id, Operation* op,
}
}
- void SideEffectAnalysis::AddPredecessorsForAccess(int64_t resource_id,
void SideEffectAnalysisInfo::AddPredecessorsForAccess(int64_t resource_id,
Operation* op,
bool read_only) {
auto it = per_resource_access_info_.find(resource_id);
@@ -420,8 +562,8 @@ void SideEffectAnalysis::AddPredecessorsForAccess(int64_t resource_id,
}
}
- void SideEffectAnalysis::AnalyzeFunction(
- FuncOp func_op, const ResourceAliasAnalysis& alias_analysis) {
void SideEffectAnalysisInfo::AnalyzeFunction(
FuncOp func_op, const TF::ResourceAliasAnalysis::Info& alias_analysis) {
// AnalyzeRegion() recursively analyzes the function body, and only populates
// control_predecessors_.
AnalyzeRegion(&func_op.getBody(), alias_analysis);
@@ -448,8 +590,8 @@ void SideEffectAnalysis::AnalyzeFunction(
}
}
- void SideEffectAnalysis::AnalyzeRegion(
- Region* region, const ResourceAliasAnalysis& alias_analysis) {
void SideEffectAnalysisInfo::AnalyzeRegion(
Region* region, const TF::ResourceAliasAnalysis::Info& alias_analysis) {
// This function populates control_predecessors_ by walking through the
// region, and tracking resource accesses in per_resource_access_info_.
@@ -476,13 +618,12 @@ void SideEffectAnalysis::AnalyzeRegion(
// different nested regions separately.
for (auto& block : *region) {
for (auto& op : block) {
- if (op.getNumRegions() > 0) {
- llvm::SmallVector<SideEffectAnalysis, 4> child_analyses;
- for (auto& child_region : op.getRegions()) {
- child_analyses.emplace_back();
- child_analyses.back().AnalyzeRegion(&child_region, alias_analysis);
- }
- ConsumeChildAnalyses(std::move(child_analyses));
for (Region& child : op.getRegions()) {
SideEffectAnalysisInfo child_analysis(&child, alias_analysis);
// Moves the control_predecessors_ fields in child region to current
// region
for (auto& entry : child_analysis.control_predecessors_)
control_predecessors_[entry.first] = std::move(entry.second);
}
// We do not need explicit control edges for declaration ops.
@@ -529,16 +670,8 @@ void SideEffectAnalysis::AnalyzeRegion(
}
}
- void SideEffectAnalysis::ConsumeChildAnalyses(
- llvm::SmallVector<SideEffectAnalysis, 4>&& children) {
- for (auto& child : children) {
- for (auto& entry : child.control_predecessors_) {
- control_predecessors_[entry.getFirst()] = std::move(entry.getSecond());
- }
- }
- }
- llvm::SmallVector<Operation*, 4> SideEffectAnalysis::DirectControlPredecessors(
llvm::SmallVector<Operation*, 4>
SideEffectAnalysisInfo::DirectControlPredecessors(
Operation* op, llvm::function_ref<bool(Operation*)> filter) const {
llvm::SmallVector<Operation*, 4> result;
auto it = sorted_control_predecessors_.find(op);
@@ -550,7 +683,8 @@ llvm::SmallVector<Operation*, 4> SideEffectAnalysis::DirectControlPredecessors(
return result;
}
- llvm::SmallVector<Operation*, 4> SideEffectAnalysis::DirectControlSuccessors(
llvm::SmallVector<Operation*, 4>
SideEffectAnalysisInfo::DirectControlSuccessors(
Operation* op, llvm::function_ref<bool(Operation*)> filter) const {
llvm::SmallVector<Operation*, 4> result;
auto it = sorted_control_successors_.find(op);
@@ -561,12 +695,19 @@ llvm::SmallVector<Operation*, 4> SideEffectAnalysis::DirectControlSuccessors(
}
return result;
}
} // namespace detail

SideEffectAnalysis::SideEffectAnalysis(Operation* op) {
- auto func_op = llvm::dyn_cast<FuncOp>(op);
- if (!func_op) return;
- ResourceAliasAnalysis alias_analysis(op);
- AnalyzeFunction(func_op, alias_analysis);
auto module = dyn_cast<ModuleOp>(op);
assert(module);

// Analyze entire module for alias analysis info.
ResourceAliasAnalysis alias_analysis(module);

// Analyze all functions.
for (auto func : module.getOps<FuncOp>())
this->info_map_.try_emplace(func, func,
alias_analysis.GetAnalysisForFunc(func));
}
} // namespace TF


@@ -16,51 +16,68 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_SIDE_EFFECT_ANALYSIS_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_SIDE_EFFECT_ANALYSIS_H_
#include <cstddef>
#include <cstdint>
#include <memory>
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "mlir/IR/Function.h"  // from @llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Region.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
namespace mlir {
namespace TF {
- // An analysis that runs on a function and maps each resource-type value to a
- // set of unique int64_t IDs representing the possible resources it could alias.
- //
- // If there are nested regions, each region is handled separately. This means
- // cross-region aliasing cannot be checked by this analysis.
- class ResourceAliasAnalysis {
namespace detail {

// This template defines an aggregate analysis base class, which analyzes a
// module but the analysis info is stored per function.
template <typename InfoT>
class PerFunctionAggregateAnalysis {
public:
- explicit ResourceAliasAnalysis(Operation* op);
- ~ResourceAliasAnalysis() = default;
- ResourceAliasAnalysis(ResourceAliasAnalysis&&) = default;
using Info = InfoT;

// Returns the analysis info for the given function.
const Info& GetAnalysisForFunc(FuncOp func) const {
auto it = info_map_.find(func);
assert(it != info_map_.end());
return it->second;
}
protected:
llvm::SmallDenseMap<FuncOp, InfoT, 8> info_map_;
};
class BacktrackAnalysis;
// Resource alias analysis information for a single function.
class ResourceAliasAnalysisInfo {
public:
// Constructs analysis info by analyzing the given function.
ResourceAliasAnalysisInfo(FuncOp func,
const BacktrackAnalysis& backtrack_analysis);
ResourceAliasAnalysisInfo(ResourceAliasAnalysisInfo&&) = default;
// Returns if the analysis fails to resolve a resource-type value.
bool IsUnknownResource(const Value resource) const;
// Returns the set unique IDs which `resource` could alias. Requires that
- // IsUnknownResource(resource) == true.
- const llvm::SmallSet<int64_t, 8>& GetResourceUniqueIds(
- const Value resource) const;
// IsUnknownResource(resource) == false.
const llvm::SmallSet<int64_t, 8>& GetResourceUniqueIds(Value resource) const;
// Returns the set of values that are potentially aliases of `value`. Requires
- // that IsUnknownResource(resource) == true.
- llvm::SmallSetVector<Value, 8> GetResourceAliases(const Value resource) const;
// that IsUnknownResource(resource) == false.
llvm::SmallSetVector<Value, 8> GetResourceAliases(Value resource) const;
private:
- ResourceAliasAnalysis() = default;
- // Runs the analysis on `func_op` and populates two way resource values to
- // unique ID mapping.
- void AnalyzeFunction(FuncOp func_op);
// Maps resource value to unique ID and vice-versa.
void AddValueUniqueIDMapping(Value value, int64_t id) {
resource_value_to_ids_[value].insert(id);
@@ -80,21 +97,40 @@ class ResourceAliasAnalysis {
id_to_resource_values_;
};
- // An analysis that runs on a function and infers the control predecessors and
- // successors for each op, based on side-effects on known and unknown resources.
- // Side-effecting ops on unknown resources are conservatively treated as
- // interfering with all known resource op accesses. It distinguishes accesses
- // based on whether they are read-only, and read-only ops do not interfere with
- // each other.
- //
- // If there are nested regions, each region is handled separately, and control
- // dependencies are only tracked for ops under the same parent op.
- class SideEffectAnalysis {
} // namespace detail

// An analysis that runs on a module and maps each resource-type value to a
// set of unique IDs representing the possible resources it could alias.
//
// Note that this is not an inter-procedural or inter-regional analysis, i.e.,
// each function and region are handled separately and cross-function or cross-
// region aliasing cannot be checked by this analysis.
class ResourceAliasAnalysis : public detail::PerFunctionAggregateAnalysis<
detail::ResourceAliasAnalysisInfo> {
public:
- explicit SideEffectAnalysis() = default;
- explicit SideEffectAnalysis(Operation* op);
- SideEffectAnalysis(SideEffectAnalysis&& other) = default;
- ~SideEffectAnalysis() = default;
// Constructs analysis by analyzing the given module operation.
explicit ResourceAliasAnalysis(Operation* op);
};
namespace detail {
// Side effect analysis info for a single function.
class SideEffectAnalysisInfo {
public:
SideEffectAnalysisInfo() = default;
// Constructs analysis info by analyzing the given function.
SideEffectAnalysisInfo(
FuncOp func_op, const TF::ResourceAliasAnalysis::Info& alias_analysis) {
AnalyzeFunction(func_op, alias_analysis);
}
// Constructs analysis info by analyzing the given region.
SideEffectAnalysisInfo(
Region* region, const TF::ResourceAliasAnalysis::Info& alias_analysis) {
AnalyzeRegion(region, alias_analysis);
}
SideEffectAnalysisInfo(SideEffectAnalysisInfo&&) = default;
// Returns a vector of ops that are direct control predecessors of `op`,
// sorted in program order. If `filter` is provided, only predecessors that
@@ -103,9 +139,9 @@ class SideEffectAnalysis {
Operation* op,
llvm::function_ref<bool(Operation*)> filter = nullptr) const;
- // Returns a vector of ops that are direct control successors of `op`, sorted
- // in program order. If `filter` is provided, only successors that pass the
- // filter (returning true) will be included.
// Returns a vector of ops that are direct control successors of `op`,
// sorted in program order. If `filter` is provided, only successors that
// pass the filter (returning true) will be included.
llvm::SmallVector<Operation*, 4> DirectControlSuccessors(
Operation* op,
llvm::function_ref<bool(Operation*)> filter = nullptr) const;
@@ -114,16 +150,11 @@ class SideEffectAnalysis {
// Runs the analysis on `func_op` and populates sorted_control_predecessors_
// and sorted_control_successors_.
void AnalyzeFunction(FuncOp func_op,
- const ResourceAliasAnalysis& alias_analysis);
const TF::ResourceAliasAnalysis::Info& alias_analysis);
// Runs the analysis on `region` and populates control_predecessors_.
void AnalyzeRegion(Region* region,
- const ResourceAliasAnalysis& alias_analysis);
const TF::ResourceAliasAnalysis::Info& alias_analysis);
- // Moves the control_predecessors_ fields in `children` analyses to this
- // current analysis.
- void ConsumeChildAnalyses(
- llvm::SmallVector<SideEffectAnalysis, 4>&& children);
// Updates control_predecessors_ for `op` that is being visited, on the given
// `resource_id`.
@@ -159,10 +190,50 @@ class SideEffectAnalysis {
// write for a the current write being analyzed.
bool tracked_last_unknown_write_for_write = false;
};
llvm::SmallDenseMap<int64_t, PerResourceAccessInfo, 8>
per_resource_access_info_;
};
} // namespace detail
// An analysis that runs on a function and infers the control predecessors and
// successors for each op, based on side-effects on known and unknown resources.
// Side-effecting ops on unknown resources are conservatively treated as
// interfering with all known resource op accesses. It distinguishes accesses
// based on whether they are read-only, and read-only ops do not interfere with
// each other.
//
// If there are nested regions, each region is handled separately, and control
// dependencies are only tracked for ops under the same parent op.
class SideEffectAnalysis : public detail::PerFunctionAggregateAnalysis<
detail::SideEffectAnalysisInfo> {
public:
// Constructs analysis by analyzing the given module operation.
explicit SideEffectAnalysis(Operation* op);
};
// Base CRTP class to help write passes that consume a per-function
// aggregate analysis and operate on all non-extern functions (similar to a
// FunctionPass, but with no concurrency between functions). The derived classes
// need to provide a runOnFunction() method that accepts the function and the
// analysis information for that function.
template <typename DerivedT, typename AnalysisT>
class PerFunctionAggregateAnalysisConsumerPass
: public PassWrapper<
PerFunctionAggregateAnalysisConsumerPass<DerivedT, AnalysisT>,
OperationPass<ModuleOp>> {
void runOnOperation() override {
ModuleOp op = this->getOperation();
DerivedT& derived = *static_cast<DerivedT*>(this);
auto& analysis = this->template getAnalysis<AnalysisT>();
for (auto func : op.getOps<FuncOp>())
if (!func.isExternal())
derived.runOnFunction(func, analysis.GetAnalysisForFunc(func));
}
};
} // namespace TF
} // namespace mlir


@@ -562,7 +562,7 @@ Status MlirFunction::GetFunctionDef(tensorflow::FunctionDef** f) {
}
PassManager pm(func_.getContext());
pm.addNestedPass<FuncOp>(CreateFunctionalToExecutorDialectConversionPass());
- pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());
pm.addPass(CreateBreakUpIslandsPass());
// In case of failure, the `diag_handler` converts MLIR errors emitted to
// the MLIRContext into a tensorflow::Status.


@@ -40,15 +40,16 @@ void EnableLogging(PassManager *pm) {
namespace TFTPU {
namespace {
void AddGraphExportLoweringPasses(OpPassManager &pm) {
auto add_pass = [&](std::unique_ptr<Pass> pass) {
pm.addNestedPass<FuncOp>(std::move(pass));
pm.addPass(CreateBreakUpIslandsPass());
};
pm.addNestedPass<FuncOp>(CreateFunctionalToExecutorDialectConversionPass());
- pm.addNestedPass<FuncOp>(TFDevice::CreateParallelizeEmbeddingParamsOpsPass());
- pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());
- pm.addNestedPass<FuncOp>(TFDevice::CreateReplicateToIslandPass());
- pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());
- pm.addNestedPass<FuncOp>(TFDevice::CreateParallelExecuteToIslandsPass());
- pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());
- pm.addNestedPass<FuncOp>(TFDevice::CreateLaunchToDeviceAttributePass());
- pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());
add_pass(TFDevice::CreateParallelizeEmbeddingParamsOpsPass());
add_pass(TFDevice::CreateReplicateToIslandPass());
add_pass(TFDevice::CreateParallelExecuteToIslandsPass());
add_pass(TFDevice::CreateLaunchToDeviceAttributePass());
}
tensorflow::Status RunTPUBridge(


@@ -51,7 +51,7 @@ Status MlirGraphOptimizationPass::Run(const ConfigProto& config_proto,
CreateLayoutOptimizationPipeline(pm, layout_optimization_options);
// Prepare IR for exporting.
- pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());
pm.addPass(CreateBreakUpIslandsPass());
// In case of failure, the `diag_handler` converts MLIR errors emitted to the
// MLIRContext into a tensorflow::Status.


@@ -26,7 +26,7 @@ namespace mlir {
// Creates a pass that breaks up an island with multiple ops into multiple
// islands, each with a single op.
- std::unique_ptr<OperationPass<FuncOp>> CreateBreakUpIslandsPass();
std::unique_ptr<OperationPass<ModuleOp>> CreateBreakUpIslandsPass();
// Creates a pass that converts mlir functions consisting of mlir ops into a
// tf_executor dialect as a single island.
@@ -270,7 +270,7 @@ std::unique_ptr<OperationPass<FuncOp>> CreateTPUClusterFormationPass();
// Creates a pass that allows TPU program inputs to have layouts determined at
// run time.
- std::unique_ptr<OperationPass<FuncOp>> CreateTPUDynamicLayoutPass();
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDynamicLayoutPass();
// Creates a pass that remaps and assigns padding map from a
// `tf_device.launch_func` `padding_map` attribute to its encapsulated function.


@@ -61,7 +61,9 @@ struct ResourceDeviceInference
// A class that records each resource's device assignment in a function.
class PerFunctionResult {
public:
- explicit PerFunctionResult(FuncOp func_op) : alias_analysis_(func_op) {}
explicit PerFunctionResult(
FuncOp func_op, const TF::ResourceAliasAnalysis::Info& alias_analysis)
: alias_analysis_(alias_analysis) {}
// Returns the recorded device assignment for a resource, if any.
llvm::Optional<llvm::StringRef> DeviceForResource(
@@ -105,7 +107,7 @@ class PerFunctionResult {
private:
llvm::SmallDenseMap<int64_t, llvm::StringRef, 8> resource_id_to_device_;
- TF::ResourceAliasAnalysis alias_analysis_;
const TF::ResourceAliasAnalysis::Info& alias_analysis_;
};
// Tries to record device assignment for a resource.
@@ -193,11 +195,15 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
void ResourceDeviceInference::runOnOperation() {
auto module = getOperation();
const auto& resource_alias_analysis =
getAnalysis<TF::ResourceAliasAnalysis>();
llvm::SmallDenseMap<Operation*, PerFunctionResult, 4> per_function_results;
llvm::SetVector<FuncOp> worklist;
module.walk([&](FuncOp func_op) {
worklist.insert(func_op);
- per_function_results.try_emplace(func_op, func_op);
per_function_results.try_emplace(
func_op, func_op, resource_alias_analysis.GetAnalysisForFunc(func_op));
});
// Helper that propagates an op's recorded operand device assignments to its
// called function's arguments.


@@ -39,11 +39,13 @@ namespace {
// A pass that adds "Predecessors" and "Successors" remarks for each op based on
// SideEffectAnalysis result. For testing purpose only.
struct TestSideEffectAnalysis
- : public mlir::PassWrapper<TestSideEffectAnalysis, FunctionPass> {
- void runOnFunction() override {
: public TF::PerFunctionAggregateAnalysisConsumerPass<
TestSideEffectAnalysis, TF::SideEffectAnalysis> {
void runOnFunction(FuncOp func,
const TF::SideEffectAnalysis::Info& analysis) {
int64_t next_id = 0;
llvm::SmallDenseMap<Operation*, int64_t, 8> ids;
- getFunction().walk([&](Operation* op) {
func.walk([&](Operation* op) {
ids[op] = next_id++;
op->emitRemark("ID: ") << ids[op];
});
@@ -53,8 +55,7 @@ struct TestSideEffectAnalysis
for (auto op : ops) id_vec.push_back(std::to_string(ids[op]));
return llvm::join(id_vec, ",");
};
- auto& analysis = getAnalysis<TF::SideEffectAnalysis>();
- getFunction().walk([&](Operation* op) {
func.walk([&](Operation* op) {
if (!analysis.DirectControlPredecessors(op).empty()) {
op->emitRemark("Predecessors: ")
<< "{" << join_ids(analysis.DirectControlPredecessors(op)) << "}";


@@ -36,6 +36,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
@@ -77,24 +78,28 @@ constexpr char kFuncDeviceAttr[] = "tf.device";
// because tf.TPUCopyWithLayout accepts a host input and produces a device
// output.
struct TPUDynamicLayoutPass
- : public PassWrapper<TPUDynamicLayoutPass, FunctionPass> {
- void runOnFunction() override;
: public TF::PerFunctionAggregateAnalysisConsumerPass<
TPUDynamicLayoutPass, TF::ResourceAliasAnalysis> {
void runOnFunction(
FuncOp func,
const TF::ResourceAliasAnalysis::Info& resource_alias_analysis);
};
// Checks if the input producer op is supported in this transform. Right now, we
// only check if it is a tf.IteratorGetNext where resource input is coming from
// a VarHandle on CPU or a function argument assigned to CPU.
- bool IsSupportedInputOp(Operation* op,
- TF::ResourceAliasAnalysis* resource_alias_analysis) {
bool IsSupportedInputOp(
Operation* op,
const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
TF::IteratorGetNextOp iterator_op = llvm::dyn_cast<TF::IteratorGetNextOp>(op);
if (!iterator_op) return false;
Value resource_iterator = iterator_op.iterator();
- if (resource_alias_analysis->IsUnknownResource(resource_iterator))
if (resource_alias_analysis.IsUnknownResource(resource_iterator))
return false;
llvm::SmallSetVector<Value, 8> aliases =
- resource_alias_analysis->GetResourceAliases(resource_iterator);
resource_alias_analysis.GetResourceAliases(resource_iterator);
auto is_generator = [](Value val) {
if (val.isa<BlockArgument>()) return true;
@@ -177,7 +182,7 @@ bool HandleReplicatedInputs(
const int64_t execute_arg_index, Value compilation_key,
tf_device::LaunchOp execute_launch, tf_device::LaunchOp compile_launch,
const int64_t replicate_arg_index, tf_device::ReplicateOp replicate,
- TF::ResourceAliasAnalysis* resource_alias_analysis) {
const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
// We need to know the devices to copy to.
if (!replicate.devices()) return false;
int64_t num_replicas = replicate.n().getZExtValue();
@@ -215,7 +220,7 @@ bool HandleReplicatedInputs(
void HandleCompileAndExecutes(
tf_device::LaunchOp compile_launch,
llvm::MutableArrayRef<tf_device::LaunchOp> execute_launches,
- TF::ResourceAliasAnalysis* resource_alias_analysis) {
const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
auto compile =
llvm::cast<TF::_TPUCompileMlirOp>(compile_launch.GetBody().front());
tensorflow::tpu::TPUCompileMetadataProto metadata;
@@ -273,9 +278,10 @@ void HandleCompileAndExecutes(
compile.getContext()));
}
- void TPUDynamicLayoutPass::runOnFunction() {
- TF::ResourceAliasAnalysis resource_alias_analysis(getFunction());
- getFunction().walk([&](TF::_TPUCompileMlirOp compile) {
void TPUDynamicLayoutPass::runOnFunction(
FuncOp func,
const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
func.walk([&](TF::_TPUCompileMlirOp compile) {
// Detect tf._TPUCompileMlir -> tf.TPUExecute(s).
auto compile_launch =
llvm::dyn_cast<tf_device::LaunchOp>(compile.getParentOp());
@@ -295,13 +301,13 @@ void TPUDynamicLayoutPass::runOnFunction() {
}
HandleCompileAndExecutes(compile_launch, execute_launches,
- &resource_alias_analysis);
resource_alias_analysis);
});
}
} // namespace
- std::unique_ptr<OperationPass<FuncOp>> CreateTPUDynamicLayoutPass() {
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDynamicLayoutPass() {
return std::make_unique<TPUDynamicLayoutPass>();
}


@@ -41,25 +41,28 @@ namespace mlir {
namespace {
- struct BreakUpIslands : PassWrapper<BreakUpIslands, FunctionPass> {
- void runOnFunction() final;
class BreakUpIslands : public TF::PerFunctionAggregateAnalysisConsumerPass<
BreakUpIslands, TF::SideEffectAnalysis> {
public:
void runOnFunction(FuncOp func,
const TF::SideEffectAnalysis::Info& side_effect_analysis);
void BreakUpIsland(tf_executor::IslandOp island_op,
- const TF::SideEffectAnalysis& side_effect_analysis,
const TF::SideEffectAnalysis::Info& side_effect_analysis,
llvm::DenseMap<Operation*, llvm::SmallVector<Value, 4>>*
new_control_inputs);
};
- void BreakUpIslands::runOnFunction() {
- auto graph_op_range = getFunction().getBody().front().without_terminator();
void BreakUpIslands::runOnFunction(
FuncOp func, const TF::SideEffectAnalysis::Info& side_effect_analysis) {
auto graph_op_range = func.front().without_terminator();
tf_executor::GraphOp graph_op;
- if (graph_op_range.begin() != graph_op_range.end() &&
- std::next(graph_op_range.begin()) == graph_op_range.end()) {
- graph_op = dyn_cast<tf_executor::GraphOp>(
- getOperation().getBody().front().front());
- }
if (llvm::hasSingleElement(graph_op_range))
graph_op = dyn_cast<tf_executor::GraphOp>(func.front().front());
if (!graph_op) {
- getOperation().emitError("expected function to contain only a graph_op");
func.emitError("expected function to contain only a graph_op");
signalPassFailure();
return;
}
@@ -67,7 +70,6 @@ void BreakUpIslands::runOnFunction() {
// New control inputs to be added. For an operation x, new_control_inputs[x]
// contains all control inputs that need to be added to x as operands.
llvm::DenseMap<Operation*, llvm::SmallVector<Value, 4>> new_control_inputs;
- auto& side_effect_analysis = getAnalysis<TF::SideEffectAnalysis>();
// Iterate in reverse order to avoid invalidating Operation* stored in
// new_control_inputs.
for (auto& item :
@@ -76,7 +78,7 @@ void BreakUpIslands::runOnFunction() {
BreakUpIsland(island, side_effect_analysis, &new_control_inputs);
}
}
- OpBuilder builder(getOperation());
OpBuilder builder(func);
// For every op, add new control inputs in reverse order so that the ops don't
// get invalidated.
@@ -181,7 +183,7 @@ struct IslandSourcesAndSinks {
// Finds IslandSourcesAndSinks for an unmodified island.
IslandSourcesAndSinks FindSourcesAndSinksInIsland(
tf_executor::IslandOp island,
- const TF::SideEffectAnalysis& side_effect_analysis) {
const TF::SideEffectAnalysis::Info& side_effect_analysis) {
IslandSourcesAndSinks result;
auto island_body = island.GetBody().without_terminator();
for (Operation& sub_op : island_body) {
@@ -208,7 +210,7 @@ IslandSourcesAndSinks FindSourcesAndSinksInIsland(
// are chained together by control flow values.
void BreakUpIslands::BreakUpIsland(
tf_executor::IslandOp island_op,
- const TF::SideEffectAnalysis& side_effect_analysis,
const TF::SideEffectAnalysis::Info& side_effect_analysis,
llvm::DenseMap<Operation*, llvm::SmallVector<Value, 4>>*
new_control_inputs) {
auto island_body = island_op.GetBody().without_terminator();
@@ -323,7 +325,7 @@ void BreakUpIslands::BreakUpIsland(
} // namespace
- std::unique_ptr<OperationPass<FuncOp>> CreateBreakUpIslandsPass() {
std::unique_ptr<OperationPass<ModuleOp>> CreateBreakUpIslandsPass() {
return std::make_unique<BreakUpIslands>();
}


@@ -45,7 +45,7 @@ void AddTFToTFJSConversionPasses(mlir::OpPassManager* pm) {
// raise to executor dialect in order to use GraphDef converter
pm->addNestedPass<mlir::FuncOp>(
mlir::CreateFunctionalToExecutorDialectConversionPass());
- pm->addNestedPass<mlir::FuncOp>(mlir::CreateBreakUpIslandsPass());
pm->addPass(mlir::CreateBreakUpIslandsPass());
}
} // namespace tensorflow