Add support for tf.Case op in Resource Op Lifting pass.
PiperOrigin-RevId: 314196123
Change-Id: I8eef7aad5b6da7cd69b821d92f940a8f08fef08c
Parent: e1a3f9a9e7
Commit: bfc2553173
Changed files under tensorflow/compiler/mlir/tensorflow:
@@ -406,6 +406,61 @@ func @while_cond(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32> {
+// -----
+
+// CHECK: func @cluster_with_case(%[[ARG0:.*]]: tensor<i32>) -> tensor<4xf32>
+func @cluster_with_case(%arg0: tensor<i32>) -> tensor<4xf32> {
+  // CHECK: %[[VH0:.*]] = "tf.VarHandleOp"()
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<4xf32>>>
+  // CHECK: %[[VH1:.*]] = "tf.VarHandleOp"()
+  %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource<tensor<4xf32>>>
+  // CHECK-DAG: %[[READ0:.*]] = "tf.ReadVariableOp"(%[[VH0]])
+  // CHECK-DAG: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[VH1]])
+  // CHECK: %[[CLUSTER:.*]]:2 = "tf_device.cluster"()
+  %2 = "tf_device.cluster"() ( {
+    // CHECK: %[[CASE:.*]]:2 = "tf.Case"(%[[ARG0]], %[[READ0]], %[[READ1]])
+    %3:2 = "tf.Case"(%arg0, %0, %1) {branches = [@branch_0, @branch_1, @branch_2]}
+      : (tensor<i32>, tensor<*x!tf.resource<tensor<4xf32>>>, tensor<*x!tf.resource<tensor<4xf32>>>)
+      -> (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>)
+    // CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[CASE]]#1, %[[CASE]]#0)
+    %4 = "tf.ReadVariableOp"(%3#0) : (tensor<*x!tf.resource<tensor<4xf32>>>) -> tensor<4xf32>
+    %5 = "tf.AddV2"(%4, %3#1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+    // CHECK-NEXT: tf_device.return %[[ADD]], %[[CASE]]#1
+    tf_device.return %5 : tensor<4xf32>
+  // CHECK: {cluster_attr = "cluster_attr"} : () -> (tensor<4xf32>, tensor<4xf32>)
+  }) {cluster_attr = "cluster_attr"} : () -> tensor<4xf32>
+  // CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[CLUSTER]]#1)
+  // CHECK: return %[[CLUSTER]]#0
+  return %2 : tensor<4xf32>
+}
+
+// CHECK: func @branch_0(%[[TARG0:.*]]: tensor<4xf32>, %[[TARG1:.*]]: tensor<4xf32>)
+func @branch_0(%arg0: tensor<*x!tf.resource<tensor<4xf32>>>, %arg1: tensor<*x!tf.resource<tensor<4xf32>>>)
+    -> (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>) {
+  // CHECK-NEXT: %[[CONST:.*]] = "tf.Const"()
+  %constant = "tf.Const"() {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32>
+  "tf.AssignVariableOp"(%arg0, %constant) : (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>) -> ()
+  // CHECK-NEXT: return %[[CONST]], %[[CONST]]
+  return %arg0, %constant : tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>
+}
+
+// CHECK: func @branch_1(%[[EARG0:.*]]: tensor<4xf32>, %[[EARG1:.*]]: tensor<4xf32>)
+func @branch_1(%arg0: tensor<*x!tf.resource<tensor<4xf32>>>, %arg1: tensor<*x!tf.resource<tensor<4xf32>>>)
+    -> (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>) {
+  %id = "tf.Identity"(%arg1) : (tensor<*x!tf.resource<tensor<4xf32>>>) -> tensor<*x!tf.resource<tensor<4xf32>>>
+  %read = "tf.ReadVariableOp"(%id) : (tensor<*x!tf.resource<tensor<4xf32>>>) -> tensor<4xf32>
+  "tf.AssignVariableOp"(%arg0, %read) : (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>) -> ()
+  // CHECK-NEXT: return %[[EARG1]], %[[EARG1]]
+  return %arg0, %read : tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>
+}
+
+// CHECK: func @branch_2(%[[EARG0:.*]]: tensor<4xf32>, %[[EARG1:.*]]: tensor<4xf32>)
+func @branch_2(%arg0: tensor<*x!tf.resource<tensor<4xf32>>>, %arg1: tensor<*x!tf.resource<tensor<4xf32>>>)
+    -> (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>) {
+  %id = "tf.Identity"(%arg1) : (tensor<*x!tf.resource<tensor<4xf32>>>) -> tensor<*x!tf.resource<tensor<4xf32>>>
+  %read = "tf.ReadVariableOp"(%id) : (tensor<*x!tf.resource<tensor<4xf32>>>) -> tensor<4xf32>
+  "tf.AssignVariableOp"(%arg0, %read) : (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>) -> ()
+  // CHECK-NEXT: return %[[EARG1]], %[[EARG1]]
+  return %arg0, %read : tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>
+}
+
 // -----
 
 // Tests that pass lifts resource reads from if branches.
 
 // CHECK: func @cluster_with_if(%[[ARG0:.*]]: tensor<i1>) -> tensor<4xf32>
@@ -524,7 +579,7 @@ func @cluster_with_if(%arg0: tensor<i1>) -> tensor<4xf32> {
   %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<4xf32>>>
   %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource<tensor<4xf32>>>
   %2 = "tf_device.cluster"() ( {
-    // expected-error @+1 {{unsupported tf.IfOp output: resource does not alias a single input.}}
+    // expected-error @+1 {{unsupported output: resource does not alias a single input}}
     %3 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else,
                                  is_stateless = false}
       : (tensor<i1>, tensor<*x!tf.resource<tensor<4xf32>>>, tensor<*x!tf.resource<tensor<4xf32>>>)
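All three branch functions in the test forward their first resource argument (%arg0) straight through to the resource result; that is the invariant the pass verifies below before it removes resource outputs. A minimal standalone sketch of that per-result check, with plain integer indices standing in for MLIR block arguments (CommonAliasedArg is an illustrative name, not a helper in the pass):

#include <cstdio>
#include <vector>

// For each branch, the index of the block argument its resource retval
// forwards, or -1 if the retval is not a plain forwarded argument.
int CommonAliasedArg(const std::vector<int>& per_branch_arg) {
  int common = -1;
  for (int arg : per_branch_arg) {
    if (arg < 0) return -1;                        // not a forwarded argument
    if (common != -1 && arg != common) return -1;  // branches disagree
    common = arg;
  }
  return common;
}

int main() {
  printf("%d\n", CommonAliasedArg({0, 0, 0}));  // 0: supported, aliases arg 0
  printf("%d\n", CommonAliasedArg({0, 1, 0}));  // -1: unsupported output
}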
@@ -668,67 +668,74 @@ LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
   return success();
 }
 
-// Lifts loads/stores from an IfOp's branches.
-LogicalResult HandleIfOP(TF::IfOp if_op, FuncOp then_branch,
-                         FuncOp else_branch) {
+// Lifts loads/stores from an IfOp or CaseOp's branches.
+template <class CaseOrIfOp>
+LogicalResult HandleCaseOrIfOp(CaseOrIfOp op, ArrayRef<FuncOp> branches) {
   // Remove identity nodes to avoid aliasing.
-  RemoveIdentity(&then_branch.front());
-  RemoveIdentity(&else_branch.front());
+  for (auto func : branches) RemoveIdentity(&func.front());
 
   // Sanity check: branch return of resources should be aliases of inputs. If
   // so, replace the output uses with the input so that we can remove these
   // outputs.
-  for (auto entry : llvm::enumerate(
-           llvm::zip(then_branch.front().getTerminator()->getOperands(),
-                     else_branch.front().getTerminator()->getOperands()))) {
-    auto then_retval = std::get<0>(entry.value());
-    auto else_retval = std::get<1>(entry.value());
-    assert(then_retval.getType() == else_retval.getType());
-    if (!getElementTypeOrSelf(then_retval.getType()).isa<TF::ResourceType>()) {
-      continue;
-    }
-    auto then_aliasing_arg = then_retval.dyn_cast<BlockArgument>();
-    auto else_aliasing_arg = else_retval.dyn_cast<BlockArgument>();
-    if (!then_aliasing_arg || !else_aliasing_arg ||
-        then_aliasing_arg.getArgNumber() != else_aliasing_arg.getArgNumber()) {
-      return if_op.emitOpError("unsupported tf.IfOp output: ")
-             << "resource does not alias a single input.";
-    }
-    if_op.getResult(entry.index())
-        .replaceAllUsesWith(
-            if_op.getOperand(then_aliasing_arg.getArgNumber() + 1));
+  for (OpResult result : op.getResults()) {
+    if (!getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>())
+      continue;
+    unsigned result_index = result.getResultNumber();
+    constexpr unsigned kUnassigned = -1;
+    unsigned common_aliasing_arg_num = kUnassigned;
+    for (auto func : branches) {
+      auto retval = func.front().getTerminator()->getOperand(result_index);
+      assert(result.getType() == retval.getType());
+      auto aliasing_arg = retval.dyn_cast<BlockArgument>();
+      if (!aliasing_arg ||
+          (common_aliasing_arg_num != kUnassigned &&
+           aliasing_arg.getArgNumber() != common_aliasing_arg_num))
+        return op.emitOpError("unsupported output: ")
+               << "resource does not alias a single input";
+      common_aliasing_arg_num = aliasing_arg.getArgNumber();
+    }
+    assert(common_aliasing_arg_num != kUnassigned);
+    result.replaceAllUsesWith(op.getOperand(common_aliasing_arg_num + 1));
   }
 
   // Erase the resource outputs from the branches.
   int64_t non_resource_results = 0;
   llvm::SmallVector<int64_t, 4> old_to_new_output_indices;
   bool output_removed = false;
-  for (auto result : if_op.getResults()) {
-    if (!getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
+  for (auto result : op.getResults()) {
+    if (!getElementTypeOrSelf(result.getType())
+             .template isa<TF::ResourceType>()) {
      old_to_new_output_indices.push_back(non_resource_results++);
      continue;
    }
    old_to_new_output_indices.push_back(-1);
-    then_branch.front().getTerminator()->eraseOperand(non_resource_results);
-    else_branch.front().getTerminator()->eraseOperand(non_resource_results);
+    for (auto func : branches)
+      func.front().getTerminator()->eraseOperand(non_resource_results);
    output_removed = true;
   }
 
-  llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> then_use_info;
-  llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> else_use_info;
-  if (failed(FindResourceArgUseInfo(then_branch, &then_use_info)) ||
-      failed(FindResourceArgUseInfo(else_branch, &else_use_info))) {
-    return failure();
-  }
-  // A resource is considered used as long as it is used in either branch.
-  auto resource_arg_uses =
-      MergeArgResourceUseInfo(then_use_info, else_use_info);
+  llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> resource_arg_uses;
+  if (failed(FindResourceArgUseInfo(branches.front(), &resource_arg_uses)))
+    return failure();
+
+  for (auto func : branches.drop_front()) {
+    llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> branch_use_info;
+    if (failed(FindResourceArgUseInfo(func, &branch_use_info)))
+      return failure();
+    // A resource is considered used as long as it is used in any branch.
+    resource_arg_uses =
+        MergeArgResourceUseInfo(resource_arg_uses, branch_use_info);
+  }
 
   if (resource_arg_uses.empty() && !output_removed) return success();
   // Remove unused resources in functions.
   llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types;
   RemoveUnusedResourceArgumentsAndForwardedRetvals(
-      resource_arg_uses, then_branch, /*old_to_new_arg_indices=*/nullptr,
+      resource_arg_uses, branches.front(), /*old_to_new_arg_indices=*/nullptr,
       &remaining_resource_data_types);
-  RemoveUnusedResourceArgumentsAndForwardedRetvals(resource_arg_uses,
-                                                   else_branch);
+  for (auto func : branches.drop_front())
+    RemoveUnusedResourceArgumentsAndForwardedRetvals(resource_arg_uses, func);
 
   // Forward resource inputs updated in any branch to the outputs of all
   // branches. First prepare the mapping from arg to new update output.
   llvm::SmallDenseMap<int64_t, int64_t> resource_arg_to_new_output;
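The hunk above generalizes the old two-map then/else merge into a fold over the whole branch list. A self-contained sketch of that fold; UseInfo and MergeUseInfo below are illustrative stand-ins, not the pass's real ResourceArgUseInfo or MergeArgResourceUseInfo:

#include <cstdio>
#include <map>
#include <vector>

// Stand-in for the pass's per-argument use summary.
struct UseInfo {
  bool read = false;
  bool write = false;
};
using UseMap = std::map<int, UseInfo>;  // resource arg index -> observed uses

// A resource argument counts as read/written if any branch reads/writes it.
UseMap MergeUseInfo(UseMap acc, const UseMap& branch) {
  for (const auto& [arg, info] : branch) {
    acc[arg].read |= info.read;
    acc[arg].write |= info.write;
  }
  return acc;
}

int main() {
  // Branch 0 reads arg 0; branch 1 writes arg 0; branch 2 touches nothing.
  std::vector<UseMap> branches = {
      {{0, {true, false}}}, {{0, {false, true}}}, {}};
  UseMap acc = branches.front();
  for (size_t i = 1; i < branches.size(); ++i)
    acc = MergeUseInfo(acc, branches[i]);
  printf("arg0: read=%d write=%d\n", acc[0].read, acc[0].write);  // 1 1
}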
@@ -746,10 +753,11 @@ LogicalResult HandleIfOP(TF::IfOp if_op, FuncOp then_branch,
            new_output_index;
     }
   }
 
   // Append resource updates to the return ops: now they are just forwarded
   // input resources, but will be replaced by the data value in
   // LiftArgRetResourcesForFunction().
-  for (auto branch : {then_branch, else_branch}) {
+  for (auto branch : branches) {
     auto new_retvals =
         llvm::to_vector<4>(branch.front().getTerminator()->getOperands());
     for (const auto& entry : resource_arg_to_new_output) {
@@ -766,16 +774,17 @@ LogicalResult HandleIfOP(TF::IfOp if_op, FuncOp then_branch,
     });
   }
 
-  // Recreate the if op.
-  OpBuilder builder(if_op);
+  // Recreate the op without resource operands.
+  OpBuilder builder(op);
   // Now use the filtered original operands, which will be replaced by
   // AddLoadsStoresOutsideControlFlowOp().
   auto new_operands =
-      FilterRange<Value, OperandRange>(if_op.input(), resource_arg_uses);
-  new_operands.insert(new_operands.begin(), if_op.cond());
-  auto new_if = builder.create<TF::IfOp>(if_op.getLoc(),
-                                         then_branch.getType().getResults(),
-                                         new_operands, if_op.getAttrs());
+      FilterRange<Value, OperandRange>(op.input(), resource_arg_uses);
+  new_operands.insert(new_operands.begin(), op.getOperand(0));
+  FuncOp first_func = branches.front();
+  auto new_op =
+      builder.create<CaseOrIfOp>(op.getLoc(), first_func.getType().getResults(),
+                                 new_operands, op.getAttrs());
   // Prepare for AddLoadsStoresOutsideControlFlowOp()
   llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
       arg_data_type_and_updated_output_index;
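FilterRange is used above to drop the resource operands that the use analysis found dead; as I read the call site, an operand survives only if its index is still a key in resource_arg_uses, and this sketch assumes that reading. A toy version with an explicit index set (FilterByIndex is an illustrative name, not the pass's helper):

#include <cstdio>
#include <set>
#include <vector>

// Keep only elements whose index appears in `kept_indices`.
std::vector<const char*> FilterByIndex(const std::vector<const char*>& values,
                                       const std::set<int>& kept_indices) {
  std::vector<const char*> result;
  for (int i = 0; i < static_cast<int>(values.size()); ++i)
    if (kept_indices.count(i)) result.push_back(values[i]);
  return result;
}

int main() {
  // Three resource operands; the use analysis kept indices 0 and 2.
  auto filtered = FilterByIndex({"res0", "res1", "res2"}, {0, 2});
  for (const char* s : filtered) printf("%s ", s);  // prints: res0 res2
}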
@@ -787,16 +796,16 @@ LogicalResult HandleIfOP(TF::IfOp if_op, FuncOp then_branch,
     arg_data_type_and_updated_output_index[entry.getFirst() + 1] = {
         entry.getSecond(), update_index};
   }
-  AddLoadsStoresOutsideControlFlowOp(new_if,
+  AddLoadsStoresOutsideControlFlowOp(new_op,
                                      arg_data_type_and_updated_output_index);
   // Replace uses.
   for (int64_t i = 0; i < old_to_new_output_indices.size(); ++i) {
     if (old_to_new_output_indices[i] >= 0) {
-      if_op.getResult(i).replaceAllUsesWith(
-          new_if.getResult(old_to_new_output_indices[i]));
+      op.getResult(i).replaceAllUsesWith(
+          new_op.getResult(old_to_new_output_indices[i]));
     }
   }
-  if_op.erase();
+  op.erase();
   return success();
 }
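The replace-uses loop above depends on old_to_new_output_indices, built when resource results were erased earlier in the function: non-resource results are compacted to consecutive new positions, and resource results map to -1 because their uses were already rewired to the aliased operands. A compact sketch of that remapping:

#include <cstdio>
#include <vector>

int main() {
  // Result kinds of the original op: {tensor, resource, tensor}.
  std::vector<bool> is_resource = {false, true, false};
  std::vector<int> old_to_new;
  int next_new_index = 0;
  for (bool r : is_resource)
    old_to_new.push_back(r ? -1 : next_new_index++);
  // Old result 0 -> new 0, old 1 -> removed (-1), old 2 -> new 1.
  for (int idx : old_to_new) printf("%d ", idx);  // prints: 0 -1 1
}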
@@ -985,7 +994,20 @@ LogicalResult HoistForFunctionalControlFlow(
                                   lifted_partitioned_call_callees);
     HoistForFunctionalControlFlow(&else_branch.front(), module,
                                   lifted_partitioned_call_callees);
-    if (failed(HandleIfOP(if_op, then_branch, else_branch))) return failure();
+    if (failed(HandleCaseOrIfOp(if_op, {then_branch, else_branch})))
+      return failure();
+  } else if (auto case_op = llvm::dyn_cast<TF::CaseOp>(&op)) {
+    SmallVector<FuncOp, 4> branch_functions;
+    branch_functions.reserve(case_op.branches().size());
+    for (const Attribute& branch : case_op.branches()) {
+      FuncOp func =
+          module.lookupSymbol<FuncOp>(branch.cast<FlatSymbolRefAttr>());
+      // Recursively handle the nested control flow.
+      HoistForFunctionalControlFlow(&func.front(), module,
+                                    lifted_partitioned_call_callees);
+      branch_functions.push_back(func);
+    }
+    if (failed(HandleCaseOrIfOp(case_op, branch_functions))) return failure();
   } else if (auto call_op = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
     if (!call_op.f().isa<FlatSymbolRefAttr>()) {
       return call_op.emitOpError(