[MLIR:TF/XLA] Resource lifting for PartitionedCallOp/StatefulPartitionedCallOp

If a called function involves resources, clone it then lift the resource ops outside. Multiple call sites will share the same lifted callee function.

PiperOrigin-RevId: 295793372
Change-Id: I39b00dab43815216a5fa5b2d594f3d391f871290
This commit is contained in:
Yuanzhong Xu 2020-02-18 12:16:05 -08:00 committed by TensorFlower Gardener
parent caad1b7a45
commit 11b27dd35a
2 changed files with 319 additions and 9 deletions

View File

@ -542,3 +542,116 @@ func @if_else(%arg0: tensor<*x!tf.resource<tensor<4xf32>>>, %arg1: tensor<*x!tf.
-> (tensor<*x!tf.resource<tensor<4xf32>>>) {
return %arg1 : tensor<*x!tf.resource<tensor<4xf32>>>
}
// -----
// Tests that the pass lifts resources on two partitioned call ops sharing the
// same callee. The lifting should clone the callee then modify the clone.
// Expectation: the resource handle (%0) never enters the launch; its value is
// read once outside, and both call sites are retargeted to the lifted clone.
// CHECK-LABEL: @launch_with_partitioned_call
func @launch_with_partitioned_call() -> tensor<f32> {
// CHECK: %[[VH:.*]] = "tf.VarHandleOp"()
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
// CHECK: %[[CONST:.*]] = "tf.Const"()
%1 = "tf.Const"() {value = dense<10.0> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VH]])
// CHECK: %[[LAUNCH:.*]] = "tf_device.launch"()
%2 = "tf_device.launch"() ( {
// CHECK: %[[PC0:.*]] = "tf.PartitionedCall"(%[[CONST]], %[[READ]], %[[CONST]])
// CHECK-SAME: f = @callee_resource_lifted
%3 = "tf.PartitionedCall"(%1, %0, %1) {f = @callee, config = "", config_proto = "", executor_type = ""}
: (tensor<f32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<f32>
// CHECK: %[[PC1:.*]] = "tf.PartitionedCall"(%[[CONST]], %[[READ]], %[[CONST]])
// CHECK-SAME: f = @callee_resource_lifted
%4 = "tf.PartitionedCall"(%1, %0, %1) {f = @callee, config = "", config_proto = "", executor_type = ""}
: (tensor<f32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<f32>
// CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[PC0]], %[[PC1]])
%5 = "tf.AddV2"(%3, %4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: tf_device.return %[[ADD]] : tensor<f32>
tf_device.return %5 : tensor<f32>
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<f32>
return %2 : tensor<f32>
}
// The original @callee must be preserved unchanged: it still takes a resource
// argument and still contains the ReadVariableOp. The pass instead emits the
// lifted clone @callee_resource_lifted below, which is resource-free.
// CHECK: @callee(%[[OA0:.*]]: tensor<f32>, %[[OA1:.*]]: tensor<*x!tf.resource<tensor<f32>>>, %[[OA2:.*]]: tensor<f32>) -> tensor<f32>
func @callee(%arg0: tensor<f32>, %arg1: tensor<*x!tf.resource<tensor<f32>>>, %arg2: tensor<f32>) -> tensor<f32> {
// CHECK: "tf.ReadVariableOp"(%[[OA1]])
%0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
%1 = "tf.AddV2"(%0, %arg0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
%2 = "tf.AddV2"(%1, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
return %2 : tensor<f32>
}
// In the clone the resource argument is replaced by the read value, so only
// the two adds remain.
// CHECK: func @callee_resource_lifted(%[[A0:.*]]: tensor<f32>, %[[A1:.*]]: tensor<f32>, %[[A2:.*]]: tensor<f32>) -> tensor<f32>
// CHECK-NEXT: %[[ADD0:.*]] = "tf.AddV2"(%[[A1]], %[[A0]])
// CHECK-NEXT: %[[ADD1:.*]] = "tf.AddV2"(%[[ADD0]], %[[A2]])
// CHECK-NEXT: return %[[ADD1]]
// -----
// Tests that the pass lifts resources on two stateful partitioned call ops
// sharing the same callee. The lifting should clone the callee then modify the
// clone.
// The callee assigns to its first resource argument, so after lifting the
// launch must return the updated value and the store (AssignVariableOp) must
// be re-created outside the launch.
// CHECK-LABEL: @launch_with_stateful_partitioned_call
func @launch_with_stateful_partitioned_call() -> () {
// CHECK: %[[VH0:.*]] = "tf.VarHandleOp"()
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
// CHECK: %[[VH1:.*]] = "tf.VarHandleOp"()
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource<tensor<f32>>>
// CHECK: %[[CONST:.*]] = "tf.Const"()
%2 = "tf.Const"() {value = dense<10.0> : tensor<f32>} : () -> tensor<f32>
// CHECK-DAG: %[[READ0:.*]] = "tf.ReadVariableOp"(%[[VH0]])
// CHECK-DAG: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[VH1]])
// CHECK: %[[LAUNCH:.*]] = "tf_device.launch"()
"tf_device.launch"() ( {
// CHECK: %[[PC0:.*]] = "tf.StatefulPartitionedCall"(%[[READ0]], %[[READ1]], %[[CONST]])
// CHECK-SAME: f = @callee_resource_lifted
%3 = "tf.StatefulPartitionedCall"(%0, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""}
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
// The second call's resource operand %3 aliases input %0, so the lifted call
// chains on the first call's data result %[[PC0]].
// CHECK: %[[PC1:.*]] = "tf.StatefulPartitionedCall"(%[[PC0]], %[[READ1]], %[[CONST]])
// CHECK-SAME: f = @callee_resource_lifted
%4 = "tf.StatefulPartitionedCall"(%3, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""}
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
// CHECK: tf_device.return %[[PC1]] : tensor<f32>
tf_device.return
// CHECK: {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<f32>
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
// CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[LAUNCH]])
return
}
// The original @callee is preserved: it reads %arg1, adds the constant, and
// assigns the result into %arg0, returning %arg0 (a resource aliasing an
// input, which the pass requires).
// CHECK: @callee(%[[OA0:.*]]: tensor<*x!tf.resource<tensor<f32>>>, %[[OA1:.*]]: tensor<*x!tf.resource<tensor<f32>>>, %[[OA2:.*]]: tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
func @callee(%arg0: tensor<*x!tf.resource<tensor<f32>>>, %arg1: tensor<*x!tf.resource<tensor<f32>>>, %arg2: tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>> {
// CHECK: "tf.ReadVariableOp"(%[[OA1]])
%0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
%1 = "tf.AddV2"(%0, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"tf.AssignVariableOp"(%arg0, %1) {dtype = i32} : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
return %arg0 : tensor<*x!tf.resource<tensor<f32>>>
}
// In the lifted clone the read/assign disappear: resource args become data
// args and the stored value becomes the (single) data return value.
// CHECK: func @callee_resource_lifted(%[[A0:.*]]: tensor<f32>, %[[A1:.*]]: tensor<f32>, %[[A2:.*]]: tensor<f32>) -> tensor<f32>
// CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[A1]], %[[A2]])
// CHECK-NEXT: return %[[ADD]]
// -----
// Tests that the pass reports error on called function that has resource output
// which doesn't alias an input.
// Same call structure as the previous test, but the @callee below produces its
// returned resource inside the body, which the pass cannot lift.
func @launch_with_stateful_partitioned_call() -> () {
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource<tensor<f32>>>
%2 = "tf.Const"() {value = dense<10.0> : tensor<f32>} : () -> tensor<f32>
"tf_device.launch"() ( {
%3 = "tf.StatefulPartitionedCall"(%0, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""}
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
%4 = "tf.StatefulPartitionedCall"(%3, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""}
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
tf_device.return
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
return
}
// The returned resource comes from "tf._Unknown_", not from a block argument,
// so the alias check fails and the pass emits an error on the function.
// expected-error @+1 {{Unsupported function call: resource return value does not alias an input.}}
func @callee(%arg0: tensor<*x!tf.resource<tensor<f32>>>, %arg1: tensor<*x!tf.resource<tensor<f32>>>, %arg2: tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>> {
%0 = "tf._Unknown_"() : () -> tensor<*x!tf.resource<tensor<f32>>>
return %0 : tensor<*x!tf.resource<tensor<f32>>>
}

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
#include "mlir/IR/Types.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
@ -811,16 +812,185 @@ LogicalResult HanldeIfOP(TF::IfOp if_op, FuncOp then_branch,
return success();
}
// A resource-lifted function for (potentially multiple) PartitionedCallOps and
// information about the lifting changes. One instance is shared by all call
// sites that target the same callee.
struct PartitionedCallLiftingInfo {
// Function with resources lifted. Can be nullptr if nothing needs to change.
FuncOp lifted_callee;
// Mapping from old resource output indices to the indices of the inputs they
// alias; such outputs are removed and callers use the aliased input instead.
llvm::SmallDenseMap<int64_t, int64_t> old_outputs_aliasing_old_inputs;
// Mapping from old to new output indices in case any output is removed;
// an entry of -1 means the old output was removed.
llvm::SmallVector<int64_t, 4> old_to_new_output_indices;
// ResourceArgUseInfo for each old resource argument.
llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> use_info;
// Input for AddLoadsStoresOutsideControlFlowOp(), see its comment.
llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
arg_data_type_and_updated_output_index;
};
// Lifts loads/stores from a PartitionedCallOp's callee function. If anything
// needs to be changed, the original function will be preserved, and the
// lifting happens on a clone, which will be stored in `result`.
LogicalResult HandlePartitionedCallOpCallee(
    FuncOp callee, PartitionedCallLiftingInfo* result) {
  // Remove identity nodes to avoid aliasing.
  RemoveIdentity(&callee.front());
  // Sanity check: returned resources must alias function inputs. Such outputs
  // will be removed later, with callers using the aliased input instead.
  int64_t non_resource_results = 0;
  for (auto entry :
       llvm::enumerate(callee.front().getTerminator()->getOperands())) {
    auto retval = entry.value();
    if (!getElementTypeOrSelf(retval.getType()).isa<TF::ResourceType>()) {
      result->old_to_new_output_indices.push_back(non_resource_results++);
      continue;
    }
    auto aliasing_arg = retval.dyn_cast<BlockArgument>();
    if (!aliasing_arg) {
      return callee.emitOpError(
          "Unsupported function call: resource return value does not alias an "
          "input.");
    }
    result->old_outputs_aliasing_old_inputs[entry.index()] =
        aliasing_arg.getArgNumber();
    // -1 marks a resource output that will be removed from the new function.
    result->old_to_new_output_indices.push_back(-1);
  }
  if (failed(FindResourceArgUseInfo(callee, &result->use_info))) {
    return failure();
  }
  if (result->use_info.empty()) {
    // No resource usage: call sites can keep using the original function.
    result->lifted_callee = nullptr;
    return success();
  }

  // Clone the callee before making changes. Pick a unique symbol name for the
  // clone by appending "_resource_lifted" and, on collision, a counter suffix.
  SmallString<64> name_base = callee.getName();
  auto module = callee.getParentOfType<ModuleOp>();
  name_base += "_resource_lifted";
  auto name = name_base;
  {
    int64_t counter = 0;
    while (module.lookupSymbol(name)) {
      // Rebuild the candidate from the base on each iteration. Assign to the
      // outer `name` (a shadowing local here previously left the loop
      // condition unchanged, looping forever on any collision).
      name = name_base;
      name += "_" + std::to_string(counter++);
    }
  }
  callee = callee.clone();
  callee.setName(name);
  SymbolTable(module).insert(callee);
  result->lifted_callee = callee;

  // Remove unused resource arguments (and return values forwarded from
  // arguments), recording the data type of each remaining resource argument.
  llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types;
  RemoveUnusedResourceArgumentsAndForwardedRetvals(
      result->use_info, callee, /*old_to_new_arg_indices=*/nullptr,
      &remaining_resource_data_types);
  for (const auto& entry : remaining_resource_data_types) {
    result->arg_data_type_and_updated_output_index[entry.getFirst()] = {
        entry.getSecond(), -1};
  }
  llvm::SmallVector<Value, 4> new_retvals;
  for (auto val : callee.front().getTerminator()->getOperands()) {
    // Remove resource type outputs.
    if (getElementTypeOrSelf(val.getType()).isa<TF::ResourceType>()) continue;
    new_retvals.push_back(val);
  }
  // Lift resources: each value stored to a resource becomes an extra data
  // return value, and its output index is recorded so the caller can re-create
  // the store outside the function.
  LiftArgRetResourcesForFunction(
      callee, remaining_resource_data_types, [&](int64_t index, Value value) {
        result->arg_data_type_and_updated_output_index[index].second =
            new_retvals.size();
        new_retvals.push_back(value);
      });
  auto old_return = callee.front().getTerminator();
  // Replace the old return with one carrying the updated values, then fix up
  // the function type to match.
  OpBuilder builder(old_return);
  auto new_return = builder.create<ReturnOp>(old_return->getLoc(), new_retvals);
  old_return->erase();
  callee.setType(FunctionType::get(
      callee.getType().getInputs(),
      llvm::to_vector<4>(new_return.getOperandTypes()), callee.getContext()));
  return success();
}
// Updates a PartitionedCallOp/StatefulPartitionedCallOp according to the
// resource-lifted new callee function in lifting_info.
template <typename CallOpType>
void UpdatePartitionedCallOpWithNewCallee(
CallOpType call_op, const PartitionedCallLiftingInfo& lifting_info) {
if (lifting_info.lifted_callee == nullptr) return;
// Replace output resource uses with the aliasing input, so that we can remove
// this output.
for (const auto& entry : lifting_info.old_outputs_aliasing_old_inputs) {
call_op.getResult(entry.getFirst())
.replaceAllUsesWith(call_op.getOperand(entry.getSecond()));
}
// Recreate the call op.
OpBuilder builder(call_op);
// Now use the filtered original operands, which will be replaced by
// AddLoadsStoresOutsideControlFlowOp().
auto new_operands =
FilterRange<Value, OperandRange>(call_op.args(), lifting_info.use_info);
auto new_call = builder.create<CallOpType>(
call_op.getLoc(),
const_cast<FuncOp&>(lifting_info.lifted_callee).getType().getResults(),
new_operands, call_op.getAttrs());
new_call.setAttr(
"f", builder.getSymbolRefAttr(
const_cast<FuncOp&>(lifting_info.lifted_callee).getName()));
AddLoadsStoresOutsideControlFlowOp(
new_call, lifting_info.arg_data_type_and_updated_output_index);
// Replace uses.
for (int64_t i = 0; i < lifting_info.old_to_new_output_indices.size(); ++i) {
if (lifting_info.old_to_new_output_indices[i] >= 0) {
call_op.getResult(i).replaceAllUsesWith(
new_call.getResult(lifting_info.old_to_new_output_indices[i]));
}
}
call_op.erase();
}
LogicalResult HoistForFunctionalControlFlow(
Block*, ModuleOp, llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>*);
// A templated routine for handling both PartitionedCallOp and
// StatefulPartitionedCallOp. If the callee is already lifted, it just updates
// the caller op itself; otherwise, it first recursively handles nested control
// flow, then performs lifting on the callee.
template <typename CallOpType>
LogicalResult HandlePartitionedCallOp(
    CallOpType call_op, FuncOp callee, ModuleOp module,
    llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>* lifted_callees) {
  auto emplace_res =
      lifted_callees->try_emplace(callee, PartitionedCallLiftingInfo());
  if (emplace_res.second) {
    // Unseen callee: recursively lift nested control flow inside it, then
    // perform resource lifting on the callee itself. Propagate failures from
    // the recursive hoisting rather than silently dropping its LogicalResult.
    if (failed(HoistForFunctionalControlFlow(&callee.front(), module,
                                             lifted_callees)) ||
        failed(HandlePartitionedCallOpCallee(
            callee, &emplace_res.first->getSecond()))) {
      return failure();
    }
  }
  // Rewrite this call site to target the lifted callee (no-op if the callee
  // needed no changes).
  UpdatePartitionedCallOpWithNewCallee(call_op, emplace_res.first->getSecond());
  return success();
}
// Hoists resource loads/stores from control flow ops in `block` outside the
// body/cond/branch functions.
LogicalResult HoistForFunctionalControlFlow(Block* block, ModuleOp module) {
// body/cond/branch/callee functions.
LogicalResult HoistForFunctionalControlFlow(
Block* block, ModuleOp module,
llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>*
lifted_partitioned_call_callees) {
for (Operation& op : llvm::make_early_inc_range(*block)) {
if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
auto body = llvm::cast<FuncOp>(module.lookupSymbol(while_op.body()));
auto cond = llvm::cast<FuncOp>(module.lookupSymbol(while_op.cond()));
// Recursively handle the nested control flow.
HoistForFunctionalControlFlow(&body.front(), module);
HoistForFunctionalControlFlow(&cond.front(), module);
HoistForFunctionalControlFlow(&body.front(), module,
lifted_partitioned_call_callees);
HoistForFunctionalControlFlow(&cond.front(), module,
lifted_partitioned_call_callees);
if (failed(HanldeWhileLoop(while_op, body, cond))) return failure();
} else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
auto then_branch =
@ -828,9 +998,30 @@ LogicalResult HoistForFunctionalControlFlow(Block* block, ModuleOp module) {
auto else_branch =
llvm::cast<FuncOp>(module.lookupSymbol(if_op.else_branch()));
// Recursively handle the nested control flow.
HoistForFunctionalControlFlow(&then_branch.front(), module);
HoistForFunctionalControlFlow(&else_branch.front(), module);
HoistForFunctionalControlFlow(&then_branch.front(), module,
lifted_partitioned_call_callees);
HoistForFunctionalControlFlow(&else_branch.front(), module,
lifted_partitioned_call_callees);
if (failed(HanldeIfOP(if_op, then_branch, else_branch))) return failure();
} else if (auto call_op = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
if (!call_op.f().isa<FlatSymbolRefAttr>()) {
return call_op.emitError(
"Resource lifting does not support call with nested references.");
}
auto callee = llvm::cast<FuncOp>(
module.lookupSymbol(call_op.f().getRootReference()));
if (failed(HandlePartitionedCallOp(call_op, callee, module,
lifted_partitioned_call_callees))) {
// Nested control flow handling is done in HandlePartitionedCallOp().
return failure();
}
} else if (auto call_op =
llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
auto callee = llvm::cast<FuncOp>(module.lookupSymbol(call_op.f()));
if (failed(HandlePartitionedCallOp(call_op, callee, module,
lifted_partitioned_call_callees))) {
return failure();
}
}
}
return success();
@ -840,10 +1031,13 @@ LogicalResult HoistForFunctionalControlFlow(Block* block, ModuleOp module) {
// outside. Returns failure if there are remaining resource-type values that can
// not be lifted.
void ResourceOpLiftingPass::runOnModule() {
llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>
lifted_partitioned_call_callees;
auto result = getModule().walk([&](FuncOp func_op) {
return func_op.walk([&](tf_device::LaunchOp launch_op) {
if (failed(HoistForFunctionalControlFlow(&launch_op.GetBody(),
getModule())) ||
if (failed(HoistForFunctionalControlFlow(
&launch_op.GetBody(), getModule(),
&lifted_partitioned_call_callees)) ||
failed(HoistResourceOpsFromLaunchOp(launch_op))) {
return WalkResult::interrupt();
}
@ -901,8 +1095,11 @@ LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function) {
<< function.getBlocks().size();
}
llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>
lifted_partitioned_call_callees;
return HoistForFunctionalControlFlow(&function.front(),
cast<ModuleOp>(function.getParentOp()));
cast<ModuleOp>(function.getParentOp()),
&lifted_partitioned_call_callees);
}
} // namespace TF