diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir
index 016b06b662a..52bc0f878fc 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir
@@ -542,3 +542,116 @@ func @if_else(%arg0: tensor<*x!tf.resource<tensor<f32>>>, %arg1: tensor<*x!tf.
     -> (tensor<*x!tf.resource<tensor<f32>>>) {
   return %arg1 : tensor<*x!tf.resource<tensor<f32>>>
 }
+
+// -----
+
+// Tests that the pass lifts resources on two partitioned call ops sharing the
+// same callee. The lifting should clone the callee then modify the clone.
+
+// CHECK-LABEL: @launch_with_partitioned_call
+func @launch_with_partitioned_call() -> tensor<f32> {
+  // CHECK: %[[VH:.*]] = "tf.VarHandleOp"()
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
+  // CHECK: %[[CONST:.*]] = "tf.Const"()
+  %1 = "tf.Const"() {value = dense<10.0> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VH]])
+  // CHECK: %[[LAUNCH:.*]] = "tf_device.launch"()
+  %2 = "tf_device.launch"() ( {
+    // CHECK: %[[PC0:.*]] = "tf.PartitionedCall"(%[[CONST]], %[[READ]], %[[CONST]])
+    // CHECK-SAME: f = @callee_resource_lifted
+    %3 = "tf.PartitionedCall"(%1, %0, %1) {f = @callee, config = "", config_proto = "", executor_type = ""}
+      : (tensor<f32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<f32>
+    // CHECK: %[[PC1:.*]] = "tf.PartitionedCall"(%[[CONST]], %[[READ]], %[[CONST]])
+    // CHECK-SAME: f = @callee_resource_lifted
+    %4 = "tf.PartitionedCall"(%1, %0, %1) {f = @callee, config = "", config_proto = "", executor_type = ""}
+      : (tensor<f32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<f32>
+    // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[PC0]], %[[PC1]])
+    %5 = "tf.AddV2"(%3, %4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    // CHECK: tf_device.return %[[ADD]] : tensor<f32>
+    tf_device.return %5 : tensor<f32>
+  }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<f32>
+  return %2 : tensor<f32>
+}
+// CHECK: @callee(%[[OA0:.*]]: tensor<f32>, %[[OA1:.*]]: tensor<*x!tf.resource<tensor<f32>>>, %[[OA2:.*]]: tensor<f32>) -> tensor<f32>
+func @callee(%arg0: tensor<f32>, %arg1: tensor<*x!tf.resource<tensor<f32>>>, %arg2: tensor<f32>) -> tensor<f32> {
+  // CHECK: "tf.ReadVariableOp"(%[[OA1]])
+  %0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
+  %1 = "tf.AddV2"(%0, %arg0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  %2 = "tf.AddV2"(%1, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  return %2 : tensor<f32>
+}
+// CHECK: func @callee_resource_lifted(%[[A0:.*]]: tensor<f32>, %[[A1:.*]]: tensor<f32>, %[[A2:.*]]: tensor<f32>) -> tensor<f32>
+// CHECK-NEXT: %[[ADD0:.*]] = "tf.AddV2"(%[[A1]], %[[A0]])
+// CHECK-NEXT: %[[ADD1:.*]] = "tf.AddV2"(%[[ADD0]], %[[A2]])
+// CHECK-NEXT: return %[[ADD1]]
+
+
+// -----
+
+// Tests that the pass lifts resources on two stateful partitioned call ops
+// sharing the same callee. The lifting should clone the callee then modify the
+// clone.
+
+// CHECK-LABEL: @launch_with_stateful_partitioned_call
+func @launch_with_stateful_partitioned_call() -> () {
+  // CHECK: %[[VH0:.*]] = "tf.VarHandleOp"()
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
+  // CHECK: %[[VH1:.*]] = "tf.VarHandleOp"()
+  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource<tensor<f32>>>
+  // CHECK: %[[CONST:.*]] = "tf.Const"()
+  %2 = "tf.Const"() {value = dense<10.0> : tensor<f32>} : () -> tensor<f32>
+  // CHECK-DAG: %[[READ0:.*]] = "tf.ReadVariableOp"(%[[VH0]])
+  // CHECK-DAG: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[VH1]])
+  // CHECK: %[[LAUNCH:.*]] = "tf_device.launch"()
+  "tf_device.launch"() ( {
+    // CHECK: %[[PC0:.*]] = "tf.StatefulPartitionedCall"(%[[READ0]], %[[READ1]], %[[CONST]])
+    // CHECK-SAME: f = @callee_resource_lifted
+    %3 = "tf.StatefulPartitionedCall"(%0, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""}
+      : (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
+    // CHECK: %[[PC1:.*]] = "tf.StatefulPartitionedCall"(%[[PC0]], %[[READ1]], %[[CONST]])
+    // CHECK-SAME: f = @callee_resource_lifted
+    %4 = "tf.StatefulPartitionedCall"(%3, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""}
+      : (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
+    // CHECK: tf_device.return %[[PC1]] : tensor<f32>
+    tf_device.return
+    // CHECK: {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<f32>
+  }) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
+  // CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[LAUNCH]])
+  return
+}
+// CHECK: @callee(%[[OA0:.*]]: tensor<*x!tf.resource<tensor<f32>>>, %[[OA1:.*]]: tensor<*x!tf.resource<tensor<f32>>>, %[[OA2:.*]]: tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
+func @callee(%arg0: tensor<*x!tf.resource<tensor<f32>>>, %arg1: tensor<*x!tf.resource<tensor<f32>>>, %arg2: tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>> {
+  // CHECK: "tf.ReadVariableOp"(%[[OA1]])
+  %0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
+  %1 = "tf.AddV2"(%0, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  "tf.AssignVariableOp"(%arg0, %1) {dtype = i32} : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
+  return %arg0 : tensor<*x!tf.resource<tensor<f32>>>
+}
+// CHECK: func @callee_resource_lifted(%[[A0:.*]]: tensor<f32>, %[[A1:.*]]: tensor<f32>, %[[A2:.*]]: tensor<f32>) -> tensor<f32>
+// CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[A1]], %[[A2]])
+// CHECK-NEXT: return %[[ADD]]
+
+
+// -----
+
+// Tests that the pass reports an error on a called function that has a
+// resource output which doesn't alias an input.
+
+func @launch_with_stateful_partitioned_call() -> () {
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
+  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource<tensor<f32>>>
+  %2 = "tf.Const"() {value = dense<10.0> : tensor<f32>} : () -> tensor<f32>
+  "tf_device.launch"() ( {
+    %3 = "tf.StatefulPartitionedCall"(%0, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""}
+      : (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
+    %4 = "tf.StatefulPartitionedCall"(%3, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""}
+      : (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
+    tf_device.return
+  }) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
+  return
+}
+// expected-error @+1 {{Unsupported function call: resource return value does not alias an input.}}
+func @callee(%arg0: tensor<*x!tf.resource<tensor<f32>>>, %arg1: tensor<*x!tf.resource<tensor<f32>>>, %arg2: tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>> {
+  %0 = "tf._Unknown_"() : () -> tensor<*x!tf.resource<tensor<f32>>>
+  return %0 : tensor<*x!tf.resource<tensor<f32>>>
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
index 7f0b1b96560..8dc21feca90 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
@@ -31,6 +31,7 @@ limitations under the License.
 #include "mlir/IR/Function.h"  // TF:llvm-project
 #include "mlir/IR/Module.h"  // TF:llvm-project
 #include "mlir/IR/StandardTypes.h"  // TF:llvm-project
+#include "mlir/IR/SymbolTable.h"  // TF:llvm-project
 #include "mlir/IR/TypeUtilities.h"  // TF:llvm-project
 #include "mlir/IR/Types.h"  // TF:llvm-project
 #include "mlir/IR/Value.h"  // TF:llvm-project
@@ -811,16 +812,185 @@ LogicalResult HanldeIfOP(TF::IfOp if_op, FuncOp then_branch,
   return success();
 }
 
+// A resource-lifted function for (potentially multiple) PartitionedCallOps and
+// information about the lifting changes.
+struct PartitionedCallLiftingInfo {
+  // Function with resources lifted. Can be nullptr if nothing needs to change.
+  FuncOp lifted_callee;
+  // Mapping from old resource outputs to their aliasing inputs.
+  llvm::SmallDenseMap<int64_t, int64_t> old_outputs_aliasing_old_inputs;
+  // Mapping from old to new output indices in case any output is removed.
+  llvm::SmallVector<int64_t, 4> old_to_new_output_indices;
+  // ResourceArgUseInfo for each old resource argument.
+  llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> use_info;
+  // Input for AddLoadsStoresOutsideControlFlowOp(), see its comment.
+  llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
+      arg_data_type_and_updated_output_index;
+};
+
+// Lifts loads/stores from a PartitionedCallOp's callee function. If anything
+// needs to be changed, the original function is preserved, and the lifting
+// happens on a clone, which will be stored in `result`.
+LogicalResult HandlePartitionedCallOpCallee(
+    FuncOp callee, PartitionedCallLiftingInfo* result) {
+  // Remove identity nodes to avoid aliasing.
+  RemoveIdentity(&callee.front());
+  // Sanity check: returned resources should alias inputs. Such outputs
+  // will be removed later.
+  int64_t non_resource_results = 0;
+  for (auto entry :
+       llvm::enumerate(callee.front().getTerminator()->getOperands())) {
+    auto retval = entry.value();
+    if (!getElementTypeOrSelf(retval.getType()).isa<TF::ResourceType>()) {
+      result->old_to_new_output_indices.push_back(non_resource_results++);
+      continue;
+    }
+    auto aliasing_arg = retval.dyn_cast<BlockArgument>();
+    if (!aliasing_arg) {
+      return callee.emitOpError(
+          "Unsupported function call: resource return value does not alias an "
+          "input.");
+    }
+    result->old_outputs_aliasing_old_inputs[entry.index()] =
+        aliasing_arg.getArgNumber();
+    result->old_to_new_output_indices.push_back(-1);
+  }
+
+  if (failed(FindResourceArgUseInfo(callee, &result->use_info))) {
+    return failure();
+  }
+  if (result->use_info.empty()) {
+    result->lifted_callee = nullptr;
+    return success();
+  }
+
+  // Clone the callee before making changes, and give the clone a unique name.
+  SmallString<64> name_base = callee.getName();
+  auto module = callee.getParentOfType<ModuleOp>();
+  name_base += "_resource_lifted";
+  auto name = name_base;
+  {
+    int64_t counter = 0;
+    while (module.lookupSymbol(name)) {
+      name = name_base;
+      name += "_" + std::to_string(counter++);
+    }
+  }
+  callee = callee.clone();
+  callee.setName(name);
+  SymbolTable(module).insert(callee);
+  result->lifted_callee = callee;
+
+  // Remove unused resources in functions.
+  llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types;
+  RemoveUnusedResourceArgumentsAndForwardedRetvals(
+      result->use_info, callee, /*old_to_new_arg_indices=*/nullptr,
+      &remaining_resource_data_types);
+  for (const auto& entry : remaining_resource_data_types) {
+    result->arg_data_type_and_updated_output_index[entry.getFirst()] = {
+        entry.getSecond(), -1};
+  }
+  llvm::SmallVector<Value, 4> new_retvals;
+  for (auto val : callee.front().getTerminator()->getOperands()) {
+    // Remove resource type outputs.
+    if (getElementTypeOrSelf(val.getType()).isa<TF::ResourceType>()) continue;
+    new_retvals.push_back(val);
+  }
+  // Lift resources.
+  LiftArgRetResourcesForFunction(
+      callee, remaining_resource_data_types, [&](int64_t index, Value value) {
+        result->arg_data_type_and_updated_output_index[index].second =
+            new_retvals.size();
+        new_retvals.push_back(value);
+      });
+  auto old_return = callee.front().getTerminator();
+  // Replace the old return with a new one carrying the updated values.
+  OpBuilder builder(old_return);
+  auto new_return = builder.create<ReturnOp>(old_return->getLoc(), new_retvals);
+  old_return->erase();
+  callee.setType(FunctionType::get(
+      callee.getType().getInputs(),
+      llvm::to_vector<4>(new_return.getOperandTypes()), callee.getContext()));
+  return success();
+}
+
+// Updates a PartitionedCallOp/StatefulPartitionedCallOp according to the
+// resource-lifted new callee function in lifting_info.
+template <typename CallOpType>
+void UpdatePartitionedCallOpWithNewCallee(
+    CallOpType call_op, const PartitionedCallLiftingInfo& lifting_info) {
+  if (lifting_info.lifted_callee == nullptr) return;
+  // Replace output resource uses with the aliasing input, so that we can remove
+  // this output.
+  for (const auto& entry : lifting_info.old_outputs_aliasing_old_inputs) {
+    call_op.getResult(entry.getFirst())
+        .replaceAllUsesWith(call_op.getOperand(entry.getSecond()));
+  }
+  // Recreate the call op.
+  OpBuilder builder(call_op);
+  // Now use the filtered original operands, which will be replaced by
+  // AddLoadsStoresOutsideControlFlowOp().
+  auto new_operands =
+      FilterRange<Value, OperandRange>(call_op.args(), lifting_info.use_info);
+  auto new_call = builder.create<CallOpType>(
+      call_op.getLoc(),
+      const_cast<FuncOp&>(lifting_info.lifted_callee).getType().getResults(),
+      new_operands, call_op.getAttrs());
+  new_call.setAttr(
+      "f", builder.getSymbolRefAttr(
+               const_cast<FuncOp&>(lifting_info.lifted_callee).getName()));
+  AddLoadsStoresOutsideControlFlowOp(
+      new_call, lifting_info.arg_data_type_and_updated_output_index);
+  // Replace uses.
+  for (int64_t i = 0; i < lifting_info.old_to_new_output_indices.size(); ++i) {
+    if (lifting_info.old_to_new_output_indices[i] >= 0) {
+      call_op.getResult(i).replaceAllUsesWith(
+          new_call.getResult(lifting_info.old_to_new_output_indices[i]));
+    }
+  }
+  call_op.erase();
+}
+
+LogicalResult HoistForFunctionalControlFlow(
+    Block*, ModuleOp, llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>*);
+
+// A templated routine for handling both PartitionedCallOp and
+// StatefulPartitionedCallOp. If the callee is already lifted, it just updates
+// the caller op itself; otherwise, it first recursively handles nested control
+// flow, then performs lifting on the callee.
+template <typename CallOpType>
+LogicalResult HandlePartitionedCallOp(
+    CallOpType call_op, FuncOp callee, ModuleOp module,
+    llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>* lifted_callees) {
+  auto emplace_res =
+      lifted_callees->try_emplace(callee, PartitionedCallLiftingInfo());
+  if (emplace_res.second) {
+    // Unseen callee. Perform resource lifting on it.
+    HoistForFunctionalControlFlow(&callee.front(), module, lifted_callees);
+    if (failed(HandlePartitionedCallOpCallee(
+            callee, &emplace_res.first->getSecond()))) {
+      return failure();
+    }
+  }
+  UpdatePartitionedCallOpWithNewCallee(call_op, emplace_res.first->getSecond());
+  return success();
+}
+
 // Hoists resource loads/stores from control flow ops in `block` outside the
-// body/cond/branch functions.
-LogicalResult HoistForFunctionalControlFlow(Block* block, ModuleOp module) {
+// body/cond/branch/callee functions.
+LogicalResult HoistForFunctionalControlFlow(
+    Block* block, ModuleOp module,
+    llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>*
+        lifted_partitioned_call_callees) {
   for (Operation& op : llvm::make_early_inc_range(*block)) {
     if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
       auto body = llvm::cast<FuncOp>(module.lookupSymbol(while_op.body()));
       auto cond = llvm::cast<FuncOp>(module.lookupSymbol(while_op.cond()));
       // Recursively handle the nested control flow.
-      HoistForFunctionalControlFlow(&body.front(), module);
-      HoistForFunctionalControlFlow(&cond.front(), module);
+      HoistForFunctionalControlFlow(&body.front(), module,
+                                    lifted_partitioned_call_callees);
+      HoistForFunctionalControlFlow(&cond.front(), module,
+                                    lifted_partitioned_call_callees);
       if (failed(HanldeWhileLoop(while_op, body, cond))) return failure();
     } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
       auto then_branch =
@@ -828,9 +998,30 @@ LogicalResult HoistForFunctionalControlFlow(Block* block, ModuleOp module) {
       auto else_branch =
           llvm::cast<FuncOp>(module.lookupSymbol(if_op.else_branch()));
       // Recursively handle the nested control flow.
-      HoistForFunctionalControlFlow(&then_branch.front(), module);
-      HoistForFunctionalControlFlow(&else_branch.front(), module);
+      HoistForFunctionalControlFlow(&then_branch.front(), module,
+                                    lifted_partitioned_call_callees);
+      HoistForFunctionalControlFlow(&else_branch.front(), module,
+                                    lifted_partitioned_call_callees);
       if (failed(HanldeIfOP(if_op, then_branch, else_branch))) return failure();
+    } else if (auto call_op = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
+      if (!call_op.f().isa<FlatSymbolRefAttr>()) {
+        return call_op.emitError(
+            "Resource lifting does not support call with nested references.");
+      }
+      auto callee = llvm::cast<FuncOp>(
+          module.lookupSymbol(call_op.f().getRootReference()));
+      if (failed(HandlePartitionedCallOp(call_op, callee, module,
+                                         lifted_partitioned_call_callees))) {
+        // Nested control flow handling is done in HandlePartitionedCallOp().
+        return failure();
+      }
+    } else if (auto call_op =
+                   llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
+      auto callee = llvm::cast<FuncOp>(module.lookupSymbol(call_op.f()));
+      if (failed(HandlePartitionedCallOp(call_op, callee, module,
+                                         lifted_partitioned_call_callees))) {
+        return failure();
+      }
     }
   }
   return success();
@@ -840,10 +1031,13 @@ LogicalResult HoistForFunctionalControlFlow(Block* block, ModuleOp module) {
 // outside. Returns failure if there are remaining resource-type values that can
 // not be lifted.
 void ResourceOpLiftingPass::runOnModule() {
+  llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>
+      lifted_partitioned_call_callees;
   auto result = getModule().walk([&](FuncOp func_op) {
     return func_op.walk([&](tf_device::LaunchOp launch_op) {
-      if (failed(HoistForFunctionalControlFlow(&launch_op.GetBody(),
-                                               getModule())) ||
+      if (failed(HoistForFunctionalControlFlow(
+              &launch_op.GetBody(), getModule(),
+              &lifted_partitioned_call_callees)) ||
           failed(HoistResourceOpsFromLaunchOp(launch_op))) {
         return WalkResult::interrupt();
       }
@@ -901,8 +1095,11 @@ LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function) {
            << function.getBlocks().size();
   }
 
+  llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>
+      lifted_partitioned_call_callees;
   return HoistForFunctionalControlFlow(&function.front(),
-                                       cast<ModuleOp>(function.getParentOp()));
+                                       cast<ModuleOp>(function.getParentOp()),
+                                       &lifted_partitioned_call_callees);
 }
 
 }  // namespace TF
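
Note (not part of the patch): the new tests above encode the intended rewrite as FileCheck directives, which can be hard to read at a glance. The sketch below restates the stateful-partitioned-call case as plain before/after IR. It is illustrative only: it assumes an f32 variable and uses placeholder values %var (a variable handle) and %c (a constant) that would be defined elsewhere.

// Before lifting: the callee reads/writes the resource inside the launch.
"tf_device.launch"() ( {
  %r = "tf.StatefulPartitionedCall"(%var, %c) {f = @callee, config = "", config_proto = "", executor_type = ""}
    : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>>
  tf_device.return
}) {device = "tpu0"} : () -> ()

// After lifting (sketch): only tensor values cross the launch boundary. The
// read is hoisted above the launch, the write sinks below it, and the call
// targets the cloned @callee_resource_lifted that takes and returns tensors.
%read = "tf.ReadVariableOp"(%var) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
%new_val = "tf_device.launch"() ( {
  %r = "tf.StatefulPartitionedCall"(%read, %c) {f = @callee_resource_lifted, config = "", config_proto = "", executor_type = ""}
    : (tensor<f32>, tensor<f32>) -> tensor<f32>
  tf_device.return %r : tensor<f32>
}) {device = "tpu0"} : () -> tensor<f32>
"tf.AssignVariableOp"(%var, %new_val) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()

Because lifted callees are cached in lifted_partitioned_call_callees keyed by the callee FuncOp, two call sites sharing @callee get a single @callee_resource_lifted clone rather than one clone per call site.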