[MLIR] Fix incorrect tfl.while canonicalization
- Move operand -> result forwarding to the correct place where we know we are dealing with a pass through operand PiperOrigin-RevId: 319792098 Change-Id: I46846a791b8ed22995808aa63b73b2459590c777
This commit is contained in:
parent
cff750f2f6
commit
d69c8e8c9c
@ -2213,7 +2213,8 @@ struct WhileResultOperandsMatchAndImplicitCapture
|
||||
|
||||
LogicalResult matchAndRewrite(WhileOp while_op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Replace values simply passed through the body with extern values. The
|
||||
// Replace values simply passed through the body with extern values
|
||||
// (in both body and condition regions as well as while result). The
|
||||
// block arguments of body and while match and so the corresponding cond
|
||||
// argument can be easily found.
|
||||
bool unchanged = true;
|
||||
@ -2221,18 +2222,23 @@ struct WhileResultOperandsMatchAndImplicitCapture
|
||||
auto &cond_block = while_op.cond().front();
|
||||
auto &yield = *body_block.getTerminator();
|
||||
for (auto ba : body_block.getArguments()) {
|
||||
if (ba == yield.getOperand(ba.getArgNumber())) {
|
||||
int arg_no = ba.getArgNumber();
|
||||
if (ba == yield.getOperand(arg_no)) {
|
||||
unchanged = false;
|
||||
auto value = while_op.getOperand(ba.getArgNumber());
|
||||
auto value = while_op.getOperand(arg_no);
|
||||
ba.replaceAllUsesWith(value);
|
||||
cond_block.getArgument(ba.getArgNumber()).replaceAllUsesWith(value);
|
||||
cond_block.getArgument(arg_no).replaceAllUsesWith(value);
|
||||
|
||||
// This could be relaxed and casts inserted.
|
||||
if (while_op.getResult(arg_no).getType() == value.getType())
|
||||
while_op.getResult(arg_no).replaceAllUsesWith(value);
|
||||
}
|
||||
}
|
||||
|
||||
// The While ops operands and result types need to match
|
||||
SmallVector<Value, 4> new_operands;
|
||||
SmallVector<Value, 4> new_body_yield;
|
||||
SmallVector<bool, 4> const_operand(while_op.getNumOperands(), false);
|
||||
SmallVector<bool, 4> removed_operand(while_op.getNumOperands(), false);
|
||||
llvm::SmallVector<Type, 4> types;
|
||||
new_operands.reserve(while_op.getNumOperands());
|
||||
new_body_yield.reserve(while_op.getNumOperands());
|
||||
@ -2246,15 +2252,15 @@ struct WhileResultOperandsMatchAndImplicitCapture
|
||||
auto value = while_op.getOperand(while_index);
|
||||
if (body_block.getArgument(arg_index).use_empty() &&
|
||||
cond_block.getArgument(arg_index).use_empty() &&
|
||||
// This could be relaxed and casts inserted.
|
||||
while_op.getResult(while_index).getType() == value.getType()) {
|
||||
// Note: since we are not erasing results, need to use while_index
|
||||
// to check if the corresponding result is unused.
|
||||
while_op.getResult(while_index).use_empty()) {
|
||||
unchanged = false;
|
||||
body_block.eraseArgument(arg_index);
|
||||
cond_block.eraseArgument(arg_index);
|
||||
|
||||
// Mark operand as constant and replace all uses with input to while.
|
||||
while_op.getResult(while_index).replaceAllUsesWith(value);
|
||||
const_operand[while_index] = true;
|
||||
// Mark operand for removal.
|
||||
removed_operand[while_index] = true;
|
||||
} else {
|
||||
new_operands.push_back(value);
|
||||
new_body_yield.push_back(yield.getOperand(while_index));
|
||||
@ -2276,7 +2282,7 @@ struct WhileResultOperandsMatchAndImplicitCapture
|
||||
for (int i = 0; i < 2; ++i) new_op->getRegion(i).takeBody(op->getRegion(i));
|
||||
int new_index = 0;
|
||||
for (int op_index = 0, e = op->getNumResults(); op_index < e; ++op_index) {
|
||||
if (const_operand[op_index]) continue;
|
||||
if (removed_operand[op_index]) continue;
|
||||
op->getResult(op_index).replaceAllUsesWith(new_op->getResult(new_index));
|
||||
++new_index;
|
||||
}
|
||||
|
@ -111,3 +111,53 @@ func @Int64SliceBeginSize(%arg0: tensor<4x128x32xf32>) -> tensor<1x128x32xf32> {
|
||||
// CHECK: [[VAL_2:%.*]] = constant dense<[1, 128, 32]> : tensor<3xi32>
|
||||
// CHECK: [[VAL_3:%.*]] = "tfl.slice"(%arg0, [[VAL_1]], [[VAL_2]]) : (tensor<4x128x32xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x128x32xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @WhileCanonicalizeBug
|
||||
// Make sure that second output of the tf.while is not incorrectly inferred as
|
||||
// pass through just because the corresponding input is not used in either
|
||||
// condition or body. The tensor<f32> result of the loop can be either %arg1
|
||||
// (if the body never executes, or 22.0 if the body executes atleast once).
|
||||
func @WhileCanonicalizeBug(%arg0: tensor<i32>, %arg1: tensor<f32>) -> tensor<f32> {
|
||||
%0:2 = "tfl.while"(%arg0, %arg1) ( {
|
||||
^bb0(%arg2: tensor<i32>, %arg3: tensor<f32>):
|
||||
%limit = constant dense<100> : tensor<i32>
|
||||
%test = "tfl.less"(%arg0, %limit) : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
"tfl.yield"(%test) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
^bb0(%arg2: tensor<i32>, %arg3: tensor<f32>):
|
||||
%cst = constant dense<22.0> : tensor<f32>
|
||||
%stride = constant dense<1> : tensor<i32>
|
||||
%inc = tfl.add %arg2, %stride {fused_activation_function = "NONE"} : tensor<i32>
|
||||
"tfl.yield"(%inc, %cst) : (tensor<i32>, tensor<f32>) -> ()
|
||||
}) : (tensor<i32>, tensor<f32>) -> (tensor<i32>, tensor<f32>)
|
||||
// CHECK: return %0#1 : tensor<f32>
|
||||
return %0#1 : tensor<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test case to test bug due to checking
|
||||
// `while_op.getResult(arg_index).use_empty()` instead of
|
||||
// `while_op.getResult(while_index).use_empty()` in the tfl.while
|
||||
// canonicalization.
|
||||
// arg0 is a pass through. After first iteration, arg_index = 0
|
||||
// and while_index = 1. Make arg1 use empty in block and condition, but not in
|
||||
// result. Canonicalize will think it can remove both slot#0 and slot#1 and do
|
||||
// so without replacing all operands, and in assert builds it will fail an
|
||||
// assert failure ( op->use_empty() && "expected 'op' to have no uses")
|
||||
// CHECK-LABEL: WhileCanonicalizeBug1
|
||||
func @WhileCanonicalizeBug1(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
|
||||
%0:2 = "tfl.while"(%arg0, %arg1) ( {
|
||||
^bb0(%carg0: tensor<f32>, %carg1: tensor<f32>):
|
||||
%limit = constant dense<100> : tensor<i32>
|
||||
%test = "tfl.less"(%limit, %limit) : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
"tfl.yield"(%test) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
^bb0(%barg0: tensor<f32>, %barg1: tensor<f32>):
|
||||
%cst = constant dense<22.0> : tensor<f32>
|
||||
"tfl.yield"(%barg0, %cst) : (tensor<f32>, tensor<f32>) -> ()
|
||||
}) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
|
||||
return %0#1 : tensor<f32>
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user