Support region based control flow ops in stack-ops-decomposition pass.
Since regions support implicit capture, there is no need to explicitly capture resource operands. It is expected that explicitly captured resource arguments have been canonicalized away for region based control flow ops. Now, we only need to run the pass on the regions associated with such ops. This change is meant to support WhileRegion, IfRegion and CaseRegion ops in this pass. PiperOrigin-RevId: 337339660 Change-Id: I3a3313f7afeb79ef72a24b8c01a12a5f1acbc10b
This commit is contained in:
parent
9adbacfd41
commit
e2c2ef461b
tensorflow/compiler/mlir/tensorflow
@ -122,6 +122,108 @@ func @while_cond(%arg0: tensor<!tf.resource>, %arg1: tensor<i32>) -> tensor<i32>
|
||||
|
||||
// -----
|
||||
|
||||
// Tests WhileRegion Op.
|
||||
|
||||
// CHECK-LABEL: func @main()
|
||||
func @main() -> () {
|
||||
%max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NOT: tf.Stack
|
||||
// CHECK: %[[BUFFER:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<10xf32>>>
|
||||
// CHECK: %[[SIZE:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<1xi32>>>
|
||||
// CHECK: tf.AssignVariableOp
|
||||
// CHECK: tf.AssignVariableOp
|
||||
%stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf.resource>
|
||||
// CHECK: tf.WhileRegion
|
||||
%while = "tf.WhileRegion"(%max_size) ({
|
||||
// CHECK: ^bb0(%[[BARG0:.*]]: tensor<i32>
|
||||
^bb0(%barg0: tensor<i32>):
|
||||
// CHECK: "tf._SomeOp"(%[[BARG0]])
|
||||
%pred = "tf._SomeOp"(%barg0) : (tensor<i32>) -> tensor<i1>
|
||||
"tf.Yield"(%pred) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
// CHECK: ^bb0(%[[BARG0:.*]]: tensor<i32>
|
||||
^bb0(%barg0: tensor<i32>):
|
||||
// CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
%const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[BARG0]], %[[CONST1]])
|
||||
%sub = "tf.Sub"(%barg0, %const1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
%elem = "tf._SomeOp"() : () -> tensor<f32>
|
||||
// CHECK-NOT: "tf.StackPushV2"
|
||||
// CHECK: %[[BUFFER_VAL:.*]] = "tf.ReadVariableOp"(%[[BUFFER]])
|
||||
// CHECK: %[[SIZE_VAL:.*]] = "tf.ReadVariableOp"(%[[SIZE]])
|
||||
// CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[BUFFER_VAL]]
|
||||
// CHECK: "tf.AssignVariableOp"(%[[BUFFER]], %[[UPDATE]])
|
||||
// CHECK: "tf.AssignVariableOp"(%[[SIZE]]
|
||||
// CHECK-NOT: "tf.StackPushV2"
|
||||
%push = "tf.StackPushV2"(%stack, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: "tf.Yield"(%[[SUB]])
|
||||
"tf.Yield"(%sub) : (tensor<i32>) -> ()
|
||||
}) {is_stateless = false}
|
||||
: (tensor<i32>) -> tensor<i32>
|
||||
// CHECK-NOT: tf.StackPopV2
|
||||
// CHECK: %[[BUFFER_VAL:.*]] = "tf.ReadVariableOp"(%[[BUFFER]])
|
||||
// CHECK: %[[SIZE_VAL:.*]] = "tf.ReadVariableOp"(%[[SIZE]])
|
||||
// CHECK: %[[POP_VAL:.*]] = "tf.Slice"(%[[BUFFER_VAL]]
|
||||
// CHECK: "tf.AssignVariableOp"(%[[SIZE]]
|
||||
%pop = "tf.StackPopV2"(%stack) : (tensor<!tf.resource>) -> tensor<f32>
|
||||
// CHECK-NOT: tf.StackCloseV2
|
||||
"tf.StackCloseV2"(%stack) : (tensor<!tf.resource>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test CaseRegionOp
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[BRANCH_INDEX:.*]]: tensor<i32>
|
||||
func @main(%arg0: tensor<i32>) -> () {
|
||||
%max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NOT: tf.StackV2
|
||||
// CHECK: %[[BUFFER:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<10xf32>>>
|
||||
// CHECK: %[[SIZE:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<1xi32>>>
|
||||
// CHECK: tf.AssignVariableOp
|
||||
// CHECK: tf.AssignVariableOp
|
||||
%stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf.resource>
|
||||
// CHECK: %[[CASE_OUTPUT:.*]] = "tf.CaseRegion"(%[[BRANCH_INDEX]]) ( {
|
||||
%case_op = "tf.CaseRegion"(%arg0) ({
|
||||
%elem = "tf._SomeOp"() : () -> tensor<f32>
|
||||
// CHECK-NOT: tf.StackPushV2
|
||||
// CHECK: %[[BUFFER_VAL:.*]] = "tf.ReadVariableOp"(%[[BUFFER]])
|
||||
// CHECK: %[[SIZE_VAL:.*]] = "tf.ReadVariableOp"(%[[SIZE]])
|
||||
// CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[BUFFER_VAL]]
|
||||
// CHECK: "tf.AssignVariableOp"(%[[BUFFER]], %[[UPDATE]])
|
||||
// CHECK: "tf.AssignVariableOp"(%[[SIZE]]
|
||||
%push = "tf.StackPushV2"(%stack, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<f32>) -> tensor<f32>
|
||||
"tf.Yield"(%elem) : (tensor<f32>) -> ()
|
||||
}, {
|
||||
%elem = "tf._SomeOtherOp"() : () -> tensor<f32>
|
||||
// CHECK-NOT: tf.StackPushV2
|
||||
// CHECK: %[[BUFFER_VAL:.*]] = "tf.ReadVariableOp"(%[[BUFFER]])
|
||||
// CHECK: %[[SIZE_VAL:.*]] = "tf.ReadVariableOp"(%[[SIZE]])
|
||||
// CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[BUFFER_VAL]]
|
||||
// CHECK: "tf.AssignVariableOp"(%[[BUFFER]], %[[UPDATE]])
|
||||
// CHECK: "tf.AssignVariableOp"(%[[SIZE]]
|
||||
%push = "tf.StackPushV2"(%stack, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<f32>) -> tensor<f32>
|
||||
"tf.Yield"(%elem) : (tensor<f32>) -> ()
|
||||
}, {
|
||||
// CHECK-NOT: tf.StackPopV2
|
||||
// CHECK: %[[BUFFER_VAL:.*]] = "tf.ReadVariableOp"(%[[BUFFER]])
|
||||
// CHECK: %[[SIZE_VAL:.*]] = "tf.ReadVariableOp"(%[[SIZE]])
|
||||
// CHECK: %[[POP_VAL:.*]] = "tf.Slice"(%[[BUFFER_VAL]]
|
||||
// CHECK: "tf.AssignVariableOp"(%[[SIZE]]
|
||||
%pop = "tf.StackPopV2"(%stack) : (tensor<!tf.resource>) -> tensor<f32>
|
||||
"tf.Yield"(%pop) : (tensor<f32>) -> ()
|
||||
}) {is_stateless = false}
|
||||
: (tensor<i32>) -> tensor<f32>
|
||||
// CHECK-NOT: tf.StackPopV2
|
||||
%pop = "tf.StackPopV2"(%stack) : (tensor<!tf.resource>) -> tensor<f32>
|
||||
// CHECK-NOT: tf.StackCloseV2
|
||||
"tf.StackCloseV2"(%stack) : (tensor<!tf.resource>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
// Tests IfOp.
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
@ -308,3 +410,53 @@ func @if_else(%arg0: tensor<!tf.resource>, %arg1: tensor<!tf.resource>) -> tenso
|
||||
%push = "tf.StackPushV2"(%arg1, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<f32>) -> tensor<f32>
|
||||
return %arg1 : tensor<!tf.resource>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that the pass returns meaningful error message when WhileRegion op has
|
||||
// resource arguments.
|
||||
func @main() -> () {
|
||||
%max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
|
||||
%stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf.resource>
|
||||
%elem = "tf._SomeOp"() : () -> tensor<f32>
|
||||
%push_0 = "tf.StackPushV2"(%stack, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<f32>) -> tensor<f32>
|
||||
// expected-error @+1 {{found unexpected type 'tensor<!tf.resource<tensor<10xf32>>>' of operand #0, resource type operands are expected to have been canonicalized away for region based control flow ops}}
|
||||
%1:2 = "tf.WhileRegion"(%stack, %max_size) ({
|
||||
^bb0 (%carg0: tensor<!tf.resource>, %carg1: tensor<i32>):
|
||||
%pred = "tf._SomeOp"(%carg1) : (tensor<i32>) -> tensor<i1>
|
||||
"tf.Yield"(%pred) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
^bb0 (%carg0: tensor<!tf.resource>, %carg1: tensor<i32>):
|
||||
%const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
%sub = "tf.Sub"(%carg1, %const1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
%push_1 = "tf.StackPushV2"(%carg0, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<f32>) -> tensor<f32>
|
||||
"tf.Yield"(%carg0, %sub) : (tensor<!tf.resource>, tensor<i32>) -> ()
|
||||
}) {is_stateless = false}
|
||||
: (tensor<!tf.resource>, tensor<i32>) -> (tensor<!tf.resource>, tensor<i32>)
|
||||
%pop = "tf.StackPopV2"(%1#0) : (tensor<!tf.resource>) -> tensor<f32>
|
||||
"tf.StackCloseV2"(%stack) : (tensor<!tf.resource>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that the pass returns meaningful error message when IfRegion op has
|
||||
// resource returns.
|
||||
|
||||
func @main(%arg0: tensor<i1>) -> () {
|
||||
%max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
|
||||
%stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf.resource>
|
||||
// expected-error @+1 {{found unexpected type 'tensor<!tf.resource>' of result #0, resource type results are expected to have been canonicalized away for region based control flow ops}}
|
||||
%if_op = "tf.IfRegion"(%arg0) ({
|
||||
%elem = "tf._SomeOp"() : () -> tensor<f32>
|
||||
%push = "tf.StackPushV2"(%stack, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<f32>) -> tensor<f32>
|
||||
"tf.Yield"(%stack) : (tensor<!tf.resource>) -> ()
|
||||
}, {
|
||||
%pop = "tf.StackPopV2"(%stack) : (tensor<!tf.resource>) -> tensor<f32>
|
||||
"tf.Yield"(%stack) : (tensor<!tf.resource>) -> ()
|
||||
}) {is_stateless = false}
|
||||
: (tensor<i1>) -> tensor<!tf.resource>
|
||||
%pop = "tf.StackPopV2"(%if_op) : (tensor<!tf.resource>) -> tensor<f32>
|
||||
"tf.StackCloseV2"(%stack) : (tensor<!tf.resource>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -464,6 +464,38 @@ LogicalResult HandleStackPopV2Op(
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult HandleRegionControlFlowOps(
|
||||
Operation& op, ModuleOp module,
|
||||
llvm::SmallDenseMap<Value, Value>* data_var_to_size_var,
|
||||
llvm::StringMap<PartitionedCallStackOpsInfo>*
|
||||
decomposed_partitioned_call_callees) {
|
||||
for (OpOperand& operand : op.getOpOperands()) {
|
||||
if (getElementTypeOrSelf(operand.get().getType()).isa<TF::ResourceType>()) {
|
||||
return op.emitOpError()
|
||||
<< "found unexpected type " << operand.get().getType()
|
||||
<< " of operand #" << operand.getOperandNumber()
|
||||
<< ", resource type operands are expected to have been "
|
||||
"canonicalized away for region based control flow ops";
|
||||
}
|
||||
}
|
||||
for (OpResult result : op.getResults()) {
|
||||
if (getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
|
||||
return op.emitOpError()
|
||||
<< "found unexpected type " << result.getType() << " of result #"
|
||||
<< result.getResultNumber()
|
||||
<< ", resource type results are expected to have been "
|
||||
"canonicalized away for region based control flow ops";
|
||||
}
|
||||
}
|
||||
for (Region& region : op.getRegions()) {
|
||||
if (failed(DecomposeStackOpsInternal(®ion.front(), module,
|
||||
data_var_to_size_var,
|
||||
decomposed_partitioned_call_callees)))
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
// Decomposes stack ops on a region and recursively decomposes called functions.
|
||||
// data_var_to_size_var: a mapping from stacks' buffer local variables to size
|
||||
// local variables.
|
||||
@ -505,6 +537,13 @@ LogicalResult DecomposeStackOpsInternal(
|
||||
decomposed_partitioned_call_callees))) {
|
||||
return failure();
|
||||
}
|
||||
} else if (llvm::isa<TF::WhileRegionOp>(op) ||
|
||||
llvm::isa<TF::IfRegionOp>(op) ||
|
||||
llvm::isa<TF::CaseRegionOp>(op)) {
|
||||
if (failed(
|
||||
HandleRegionControlFlowOps(op, module, data_var_to_size_var,
|
||||
decomposed_partitioned_call_callees)))
|
||||
return failure();
|
||||
} else if (auto pcall = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
|
||||
if (!pcall.func()) {
|
||||
return pcall.emitOpError(
|
||||
|
Loading…
Reference in New Issue
Block a user