[MLIR:TF/XLA] Resource lifting for PartitionedCallOp/StatefulPartitionedCallOp
If a called function involves resources, clone it, then lift the resource ops outside. Multiple call sites will share the same lifted callee function.

PiperOrigin-RevId: 295793372
Change-Id: I39b00dab43815216a5fa5b2d594f3d391f871290
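As a rough before/after sketch of what this does (hand-written for illustration, not taken from the change itself, and omitting the surrounding tf_device.launch for brevity), a call site and callee such as

    %v = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
    %x = "tf.PartitionedCall"(%v) {f = @callee, config = "", config_proto = "", executor_type = ""}
      : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>

    func @callee(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32> {
      %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
      return %0 : tensor<f32>
    }

would be rewritten so that the read is hoisted to the call site and every call site targets a clone of the callee (named with the "_resource_lifted" suffix) that takes the loaded value directly:

    %v = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
    %read = "tf.ReadVariableOp"(%v) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
    %x = "tf.PartitionedCall"(%read) {f = @callee_resource_lifted, config = "", config_proto = "", executor_type = ""}
      : (tensor<f32>) -> tensor<f32>

    func @callee_resource_lifted(%arg0: tensor<f32>) -> tensor<f32> {
      return %arg0 : tensor<f32>
    }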
commit 11b27dd35a (parent caad1b7a45)
@@ -542,3 +542,116 @@ func @if_else(%arg0: tensor<*x!tf.resource<tensor<4xf32>>>, %arg1: tensor<*x!tf.
     -> (tensor<*x!tf.resource<tensor<4xf32>>>) {
   return %arg1 : tensor<*x!tf.resource<tensor<4xf32>>>
 }
+
+// -----
+
+// Tests that the pass lifts resources on two partitioned call ops sharing the
+// same callee. The lifting should clone the callee then modify the clone.
+
+// CHECK-LABEL: @launch_with_partitioned_call
+func @launch_with_partitioned_call() -> tensor<f32> {
+  // CHECK: %[[VH:.*]] = "tf.VarHandleOp"()
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
+  // CHECK: %[[CONST:.*]] = "tf.Const"()
+  %1 = "tf.Const"() {value = dense<10.0> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VH]])
+  // CHECK: %[[LAUNCH:.*]] = "tf_device.launch"()
+  %2 = "tf_device.launch"() ( {
+    // CHECK: %[[PC0:.*]] = "tf.PartitionedCall"(%[[CONST]], %[[READ]], %[[CONST]])
+    // CHECK-SAME: f = @callee_resource_lifted
+    %3 = "tf.PartitionedCall"(%1, %0, %1) {f = @callee, config = "", config_proto = "", executor_type = ""}
+      : (tensor<f32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<f32>
+    // CHECK: %[[PC1:.*]] = "tf.PartitionedCall"(%[[CONST]], %[[READ]], %[[CONST]])
+    // CHECK-SAME: f = @callee_resource_lifted
+    %4 = "tf.PartitionedCall"(%1, %0, %1) {f = @callee, config = "", config_proto = "", executor_type = ""}
+      : (tensor<f32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<f32>
+    // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[PC0]], %[[PC1]])
+    %5 = "tf.AddV2"(%3, %4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    // CHECK: tf_device.return %[[ADD]] : tensor<f32>
+    tf_device.return %5 : tensor<f32>
+  }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<f32>
+  return %2 : tensor<f32>
+}
+// CHECK: @callee(%[[OA0:.*]]: tensor<f32>, %[[OA1:.*]]: tensor<*x!tf.resource<tensor<f32>>>, %[[OA2:.*]]: tensor<f32>) -> tensor<f32>
+func @callee(%arg0: tensor<f32>, %arg1: tensor<*x!tf.resource<tensor<f32>>>, %arg2: tensor<f32>) -> tensor<f32> {
+  // CHECK: "tf.ReadVariableOp"(%[[OA1]])
+  %0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
+  %1 = "tf.AddV2"(%0, %arg0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  %2 = "tf.AddV2"(%1, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  return %2 : tensor<f32>
+}
+// CHECK: func @callee_resource_lifted(%[[A0:.*]]: tensor<f32>, %[[A1:.*]]: tensor<f32>, %[[A2:.*]]: tensor<f32>) -> tensor<f32>
+// CHECK-NEXT: %[[ADD0:.*]] = "tf.AddV2"(%[[A1]], %[[A0]])
+// CHECK-NEXT: %[[ADD1:.*]] = "tf.AddV2"(%[[ADD0]], %[[A2]])
+// CHECK-NEXT: return %[[ADD1]]
+
+
+// -----
+
+// Tests that the pass lifts resources on two stateful partitioned call ops
+// sharing the same callee. The lifting should clone the callee then modify the
+// clone.
+
+// CHECK-LABEL: @launch_with_stateful_partitioned_call
+func @launch_with_stateful_partitioned_call() -> () {
+  // CHECK: %[[VH0:.*]] = "tf.VarHandleOp"()
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
+  // CHECK: %[[VH1:.*]] = "tf.VarHandleOp"()
+  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource<tensor<f32>>>
+  // CHECK: %[[CONST:.*]] = "tf.Const"()
+  %2 = "tf.Const"() {value = dense<10.0> : tensor<f32>} : () -> tensor<f32>
+  // CHECK-DAG: %[[READ0:.*]] = "tf.ReadVariableOp"(%[[VH0]])
+  // CHECK-DAG: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[VH1]])
+  // CHECK: %[[LAUNCH:.*]] = "tf_device.launch"()
+  "tf_device.launch"() ( {
+    // CHECK: %[[PC0:.*]] = "tf.StatefulPartitionedCall"(%[[READ0]], %[[READ1]], %[[CONST]])
+    // CHECK-SAME: f = @callee_resource_lifted
+    %3 = "tf.StatefulPartitionedCall"(%0, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""}
+      : (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
+    // CHECK: %[[PC1:.*]] = "tf.StatefulPartitionedCall"(%[[PC0]], %[[READ1]], %[[CONST]])
+    // CHECK-SAME: f = @callee_resource_lifted
+    %4 = "tf.StatefulPartitionedCall"(%3, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""}
+      : (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
+    // CHECK: tf_device.return %[[PC1]] : tensor<f32>
+    tf_device.return
+  // CHECK: {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<f32>
+  }) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
+  // CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[LAUNCH]])
+  return
+}
+// CHECK: @callee(%[[OA0:.*]]: tensor<*x!tf.resource<tensor<f32>>>, %[[OA1:.*]]: tensor<*x!tf.resource<tensor<f32>>>, %[[OA2:.*]]: tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
+func @callee(%arg0: tensor<*x!tf.resource<tensor<f32>>>, %arg1: tensor<*x!tf.resource<tensor<f32>>>, %arg2: tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>> {
+  // CHECK: "tf.ReadVariableOp"(%[[OA1]])
+  %0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
+  %1 = "tf.AddV2"(%0, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  "tf.AssignVariableOp"(%arg0, %1) {dtype = i32} : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
+  return %arg0 : tensor<*x!tf.resource<tensor<f32>>>
+}
+// CHECK: func @callee_resource_lifted(%[[A0:.*]]: tensor<f32>, %[[A1:.*]]: tensor<f32>, %[[A2:.*]]: tensor<f32>) -> tensor<f32>
+// CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[A1]], %[[A2]])
+// CHECK-NEXT: return %[[ADD]]
+
+
+// -----
+
+// Tests that the pass reports error on called function that has resource output
+// which doesn't alias an input.
+
+func @launch_with_stateful_partitioned_call() -> () {
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
+  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource<tensor<f32>>>
+  %2 = "tf.Const"() {value = dense<10.0> : tensor<f32>} : () -> tensor<f32>
+  "tf_device.launch"() ( {
+    %3 = "tf.StatefulPartitionedCall"(%0, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""}
+      : (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
+    %4 = "tf.StatefulPartitionedCall"(%3, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""}
+      : (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
+    tf_device.return
+  }) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
+  return
+}
+// expected-error @+1 {{Unsupported function call: resource return value does not alias an input.}}
+func @callee(%arg0: tensor<*x!tf.resource<tensor<f32>>>, %arg1: tensor<*x!tf.resource<tensor<f32>>>, %arg2: tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>> {
  %0 = "tf._Unknown_"() : () -> tensor<*x!tf.resource<tensor<f32>>>
+  return %0 : tensor<*x!tf.resource<tensor<f32>>>
+}
@@ -31,6 +31,7 @@ limitations under the License.
 #include "mlir/IR/Function.h"  // TF:llvm-project
 #include "mlir/IR/Module.h"  // TF:llvm-project
 #include "mlir/IR/StandardTypes.h"  // TF:llvm-project
+#include "mlir/IR/SymbolTable.h"  // TF:llvm-project
 #include "mlir/IR/TypeUtilities.h"  // TF:llvm-project
 #include "mlir/IR/Types.h"  // TF:llvm-project
 #include "mlir/IR/Value.h"  // TF:llvm-project
@@ -811,16 +812,185 @@ LogicalResult HanldeIfOP(TF::IfOp if_op, FuncOp then_branch,
   return success();
 }
 
+// A resource-lifted function for (potentially multiple) PartitionedCallOps and
+// information about the lifting changes.
+struct PartitionedCallLiftingInfo {
+  // Function with resources lifted. Can be nullptr if nothing needs to change.
+  FuncOp lifted_callee;
+  // Mapping from old resource outputs to the input arguments they alias.
+  llvm::SmallDenseMap<int64_t, int64_t> old_outputs_aliasing_old_inputs;
+  // Mapping from old to new output indices in case any output is removed.
+  llvm::SmallVector<int64_t, 4> old_to_new_output_indices;
+  // ResourceArgUseInfo for each old resource argument.
+  llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> use_info;
+  // Input for AddLoadsStoresOutsideControlFlowOp(), see its comment.
+  llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
+      arg_data_type_and_updated_output_index;
+};
+
+// Lifts loads/stores from a PartitionedCallOp's callee function. If anything
+// needs to be changed, the original function will be preserved, and the lifting
+// happens on a clone, which will be stored in `result`.
+LogicalResult HandlePartitionedCallOpCallee(
+    FuncOp callee, PartitionedCallLiftingInfo* result) {
+  // Remove identity nodes to avoid aliasing.
+  RemoveIdentity(&callee.front());
+  // Sanity check: return of resources should be aliases of inputs. Such outputs
+  // will be removed later.
+  int64_t non_resource_results = 0;
+  for (auto entry :
+       llvm::enumerate(callee.front().getTerminator()->getOperands())) {
+    auto retval = entry.value();
+    if (!getElementTypeOrSelf(retval.getType()).isa<TF::ResourceType>()) {
+      result->old_to_new_output_indices.push_back(non_resource_results++);
+      continue;
+    }
+    auto aliasing_arg = retval.dyn_cast<BlockArgument>();
+    if (!aliasing_arg) {
+      return callee.emitOpError(
+          "Unsupported function call: resource return value does not alias an "
+          "input.");
+    }
+    result->old_outputs_aliasing_old_inputs[entry.index()] =
+        aliasing_arg.getArgNumber();
+    result->old_to_new_output_indices.push_back(-1);
+  }
+
+  if (failed(FindResourceArgUseInfo(callee, &result->use_info))) {
+    return failure();
+  }
+  if (result->use_info.empty()) {
+    result->lifted_callee = nullptr;
+    return success();
+  }
+
+  // Clone the callee before making changes.
+  SmallString<64> name_base = callee.getName();
+  auto module = callee.getParentOfType<ModuleOp>();
+  name_base += "_resource_lifted";
+  auto name = name_base;
+  {
+    int64_t counter = 0;
+    while (module.lookupSymbol(name)) {
+      name = name_base;
+      name += "_" + std::to_string(counter++);
+    }
+  }
+  callee = callee.clone();
+  callee.setName(name);
+  SymbolTable(module).insert(callee);
+  result->lifted_callee = callee;
+
+  // Remove unused resources in functions.
+  llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types;
+  RemoveUnusedResourceArgumentsAndForwardedRetvals(
+      result->use_info, callee, /*old_to_new_arg_indices=*/nullptr,
+      &remaining_resource_data_types);
+  for (const auto& entry : remaining_resource_data_types) {
+    result->arg_data_type_and_updated_output_index[entry.getFirst()] = {
+        entry.getSecond(), -1};
+  }
+  llvm::SmallVector<Value, 4> new_retvals;
+  for (auto val : callee.front().getTerminator()->getOperands()) {
+    // Remove resource type outputs.
+    if (getElementTypeOrSelf(val.getType()).isa<TF::ResourceType>()) continue;
+    new_retvals.push_back(val);
+  }
+  // Lift resources.
+  LiftArgRetResourcesForFunction(
+      callee, remaining_resource_data_types, [&](int64_t index, Value value) {
+        result->arg_data_type_and_updated_output_index[index].second =
+            new_retvals.size();
+        new_retvals.push_back(value);
+      });
+  auto old_return = callee.front().getTerminator();
+  // Replace the old return with a new one that returns the updated values.
+  OpBuilder builder(old_return);
+  auto new_return = builder.create<ReturnOp>(old_return->getLoc(), new_retvals);
+  old_return->erase();
+  callee.setType(FunctionType::get(
+      callee.getType().getInputs(),
+      llvm::to_vector<4>(new_return.getOperandTypes()), callee.getContext()));
+  return success();
+}
+
+// Updates a PartitionedCallOp/StatefulPartitionedCallOp according to the
+// resource-lifted new callee function in lifting_info.
+template <typename CallOpType>
+void UpdatePartitionedCallOpWithNewCallee(
+    CallOpType call_op, const PartitionedCallLiftingInfo& lifting_info) {
+  if (lifting_info.lifted_callee == nullptr) return;
+  // Replace output resource uses with the aliasing input, so that we can remove
+  // this output.
+  for (const auto& entry : lifting_info.old_outputs_aliasing_old_inputs) {
+    call_op.getResult(entry.getFirst())
+        .replaceAllUsesWith(call_op.getOperand(entry.getSecond()));
+  }
+  // Recreate the call op.
+  OpBuilder builder(call_op);
+  // Use the filtered original operands for now; AddLoadsStoresOutsideControlFlowOp()
+  // will replace the remaining resource operands with loaded values.
+  auto new_operands =
+      FilterRange<Value, OperandRange>(call_op.args(), lifting_info.use_info);
+  auto new_call = builder.create<CallOpType>(
+      call_op.getLoc(),
+      const_cast<FuncOp&>(lifting_info.lifted_callee).getType().getResults(),
+      new_operands, call_op.getAttrs());
+  new_call.setAttr(
+      "f", builder.getSymbolRefAttr(
+               const_cast<FuncOp&>(lifting_info.lifted_callee).getName()));
+  AddLoadsStoresOutsideControlFlowOp(
+      new_call, lifting_info.arg_data_type_and_updated_output_index);
+  // Replace uses.
+  for (int64_t i = 0; i < lifting_info.old_to_new_output_indices.size(); ++i) {
+    if (lifting_info.old_to_new_output_indices[i] >= 0) {
+      call_op.getResult(i).replaceAllUsesWith(
+          new_call.getResult(lifting_info.old_to_new_output_indices[i]));
+    }
+  }
+  call_op.erase();
+}
+
+LogicalResult HoistForFunctionalControlFlow(
+    Block*, ModuleOp, llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>*);
+
+// A templated routine for handling both PartitionedCallOp and
+// StatefulPartitionedCallOp. If the callee is already lifted, it just updates
+// the caller op itself; otherwise, it first recursively handles nested control
+// flow, then performs lifting on the callee.
+template <typename CallOpType>
+LogicalResult HandlePartitionedCallOp(
+    CallOpType call_op, FuncOp callee, ModuleOp module,
+    llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>* lifted_callees) {
+  auto emplace_res =
+      lifted_callees->try_emplace(callee, PartitionedCallLiftingInfo());
+  if (emplace_res.second) {
+    // Unseen callee. Perform resource lifting on it.
+    HoistForFunctionalControlFlow(&callee.front(), module, lifted_callees);
+    if (failed(HandlePartitionedCallOpCallee(
+            callee, &emplace_res.first->getSecond()))) {
+      return failure();
+    }
+  }
+  UpdatePartitionedCallOpWithNewCallee(call_op, emplace_res.first->getSecond());
+  return success();
+}
+
 // Hoists resource loads/stores from control flow ops in `block` outside the
-// body/cond/branch functions.
-LogicalResult HoistForFunctionalControlFlow(Block* block, ModuleOp module) {
+// body/cond/branch/callee functions.
+LogicalResult HoistForFunctionalControlFlow(
+    Block* block, ModuleOp module,
+    llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>*
+        lifted_partitioned_call_callees) {
   for (Operation& op : llvm::make_early_inc_range(*block)) {
     if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
       auto body = llvm::cast<FuncOp>(module.lookupSymbol(while_op.body()));
       auto cond = llvm::cast<FuncOp>(module.lookupSymbol(while_op.cond()));
       // Recursively handle the nested control flow.
-      HoistForFunctionalControlFlow(&body.front(), module);
-      HoistForFunctionalControlFlow(&cond.front(), module);
+      HoistForFunctionalControlFlow(&body.front(), module,
+                                    lifted_partitioned_call_callees);
+      HoistForFunctionalControlFlow(&cond.front(), module,
+                                    lifted_partitioned_call_callees);
       if (failed(HanldeWhileLoop(while_op, body, cond))) return failure();
     } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
       auto then_branch =
@@ -828,9 +998,30 @@ LogicalResult HoistForFunctionalControlFlow(Block* block, ModuleOp module) {
       auto else_branch =
           llvm::cast<FuncOp>(module.lookupSymbol(if_op.else_branch()));
       // Recursively handle the nested control flow.
-      HoistForFunctionalControlFlow(&then_branch.front(), module);
-      HoistForFunctionalControlFlow(&else_branch.front(), module);
+      HoistForFunctionalControlFlow(&then_branch.front(), module,
+                                    lifted_partitioned_call_callees);
+      HoistForFunctionalControlFlow(&else_branch.front(), module,
+                                    lifted_partitioned_call_callees);
       if (failed(HanldeIfOP(if_op, then_branch, else_branch))) return failure();
+    } else if (auto call_op = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
+      if (!call_op.f().isa<FlatSymbolRefAttr>()) {
+        return call_op.emitError(
+            "Resource lifting does not support call with nested references.");
+      }
+      auto callee = llvm::cast<FuncOp>(
+          module.lookupSymbol(call_op.f().getRootReference()));
+      if (failed(HandlePartitionedCallOp(call_op, callee, module,
+                                         lifted_partitioned_call_callees))) {
+        // Nested control flow handling is done in HandlePartitionedCallOp().
+        return failure();
+      }
+    } else if (auto call_op =
+                   llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
+      auto callee = llvm::cast<FuncOp>(module.lookupSymbol(call_op.f()));
+      if (failed(HandlePartitionedCallOp(call_op, callee, module,
+                                         lifted_partitioned_call_callees))) {
+        return failure();
+      }
+    }
+  }
+  return success();
@@ -840,10 +1031,13 @@ LogicalResult HoistForFunctionalControlFlow(Block* block, ModuleOp module) {
 // outside. Returns failure if there are remaining resource-type values that can
 // not be lifted.
 void ResourceOpLiftingPass::runOnModule() {
+  llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>
+      lifted_partitioned_call_callees;
   auto result = getModule().walk([&](FuncOp func_op) {
     return func_op.walk([&](tf_device::LaunchOp launch_op) {
-      if (failed(HoistForFunctionalControlFlow(&launch_op.GetBody(),
-                                               getModule())) ||
+      if (failed(HoistForFunctionalControlFlow(
+              &launch_op.GetBody(), getModule(),
+              &lifted_partitioned_call_callees)) ||
           failed(HoistResourceOpsFromLaunchOp(launch_op))) {
         return WalkResult::interrupt();
       }
@@ -901,8 +1095,11 @@ LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function) {
            << function.getBlocks().size();
 }
 
+  llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>
+      lifted_partitioned_call_callees;
   return HoistForFunctionalControlFlow(&function.front(),
-                                       cast<ModuleOp>(function.getParentOp()));
+                                       cast<ModuleOp>(function.getParentOp()),
+                                       &lifted_partitioned_call_callees);
 }
 }  // namespace TF
 