[MLIR] Fix incorrect tfl.while canonicalization

- Move operand -> result forwarding to the correct place where we know we
  are dealing with a pass through operand

PiperOrigin-RevId: 319792098
Change-Id: I46846a791b8ed22995808aa63b73b2459590c777
This commit is contained in:
Rahul Joshi 2020-07-06 09:17:22 -07:00 committed by TensorFlower Gardener
parent cff750f2f6
commit d69c8e8c9c
2 changed files with 67 additions and 11 deletions

View File

@ -2213,7 +2213,8 @@ struct WhileResultOperandsMatchAndImplicitCapture
LogicalResult matchAndRewrite(WhileOp while_op,
PatternRewriter &rewriter) const override {
// Replace values simply passed through the body with extern values. The
// Replace values simply passed through the body with extern values
// (in both body and condition regions as well as while result). The
// block arguments of body and while match and so the corresponding cond
// argument can be easily found.
bool unchanged = true;
@ -2221,18 +2222,23 @@ struct WhileResultOperandsMatchAndImplicitCapture
auto &cond_block = while_op.cond().front();
auto &yield = *body_block.getTerminator();
for (auto ba : body_block.getArguments()) {
if (ba == yield.getOperand(ba.getArgNumber())) {
int arg_no = ba.getArgNumber();
if (ba == yield.getOperand(arg_no)) {
unchanged = false;
auto value = while_op.getOperand(ba.getArgNumber());
auto value = while_op.getOperand(arg_no);
ba.replaceAllUsesWith(value);
cond_block.getArgument(ba.getArgNumber()).replaceAllUsesWith(value);
cond_block.getArgument(arg_no).replaceAllUsesWith(value);
// This could be relaxed and casts inserted.
if (while_op.getResult(arg_no).getType() == value.getType())
while_op.getResult(arg_no).replaceAllUsesWith(value);
}
}
// The While ops operands and result types need to match
SmallVector<Value, 4> new_operands;
SmallVector<Value, 4> new_body_yield;
SmallVector<bool, 4> const_operand(while_op.getNumOperands(), false);
SmallVector<bool, 4> removed_operand(while_op.getNumOperands(), false);
llvm::SmallVector<Type, 4> types;
new_operands.reserve(while_op.getNumOperands());
new_body_yield.reserve(while_op.getNumOperands());
@ -2246,15 +2252,15 @@ struct WhileResultOperandsMatchAndImplicitCapture
auto value = while_op.getOperand(while_index);
if (body_block.getArgument(arg_index).use_empty() &&
cond_block.getArgument(arg_index).use_empty() &&
// This could be relaxed and casts inserted.
while_op.getResult(while_index).getType() == value.getType()) {
// Note: since we are not erasing results, need to use while_index
// to check if the corresponding result is unused.
while_op.getResult(while_index).use_empty()) {
unchanged = false;
body_block.eraseArgument(arg_index);
cond_block.eraseArgument(arg_index);
// Mark operand as constant and replace all uses with input to while.
while_op.getResult(while_index).replaceAllUsesWith(value);
const_operand[while_index] = true;
// Mark operand for removal.
removed_operand[while_index] = true;
} else {
new_operands.push_back(value);
new_body_yield.push_back(yield.getOperand(while_index));
@ -2276,7 +2282,7 @@ struct WhileResultOperandsMatchAndImplicitCapture
for (int i = 0; i < 2; ++i) new_op->getRegion(i).takeBody(op->getRegion(i));
int new_index = 0;
for (int op_index = 0, e = op->getNumResults(); op_index < e; ++op_index) {
if (const_operand[op_index]) continue;
if (removed_operand[op_index]) continue;
op->getResult(op_index).replaceAllUsesWith(new_op->getResult(new_index));
++new_index;
}

View File

@ -111,3 +111,53 @@ func @Int64SliceBeginSize(%arg0: tensor<4x128x32xf32>) -> tensor<1x128x32xf32> {
// CHECK: [[VAL_2:%.*]] = constant dense<[1, 128, 32]> : tensor<3xi32>
// CHECK: [[VAL_3:%.*]] = "tfl.slice"(%arg0, [[VAL_1]], [[VAL_2]]) : (tensor<4x128x32xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x128x32xf32>
}
// -----
// CHECK-LABEL: @WhileCanonicalizeBug
// Make sure that second output of the tf.while is not incorrectly inferred as
// pass through just because the corresponding input is not used in either
// condition or body. The tensor<f32> result of the loop can be either %arg1
// (if the body never executes, or 22.0 if the body executes atleast once).
func @WhileCanonicalizeBug(%arg0: tensor<i32>, %arg1: tensor<f32>) -> tensor<f32> {
%0:2 = "tfl.while"(%arg0, %arg1) ( {
^bb0(%arg2: tensor<i32>, %arg3: tensor<f32>):
%limit = constant dense<100> : tensor<i32>
%test = "tfl.less"(%arg0, %limit) : (tensor<i32>, tensor<i32>) -> tensor<i1>
"tfl.yield"(%test) : (tensor<i1>) -> ()
}, {
^bb0(%arg2: tensor<i32>, %arg3: tensor<f32>):
%cst = constant dense<22.0> : tensor<f32>
%stride = constant dense<1> : tensor<i32>
%inc = tfl.add %arg2, %stride {fused_activation_function = "NONE"} : tensor<i32>
"tfl.yield"(%inc, %cst) : (tensor<i32>, tensor<f32>) -> ()
}) : (tensor<i32>, tensor<f32>) -> (tensor<i32>, tensor<f32>)
// CHECK: return %0#1 : tensor<f32>
return %0#1 : tensor<f32>
}
// -----
// Test case to test bug due to checking
// `while_op.getResult(arg_index).use_empty()` instead of
// `while_op.getResult(while_index).use_empty()` in the tfl.while
// canonicalization.
// arg0 is a pass through. After first iteration, arg_index = 0
// and while_index = 1. Make arg1 use empty in block and condition, but not in
// result. Canonicalize will think it can remove both slot#0 and slot#1 and do
// so without replacing all operands, and in assert builds it will fail an
// assert failure ( op->use_empty() && "expected 'op' to have no uses")
// CHECK-LABEL: WhileCanonicalizeBug1
func @WhileCanonicalizeBug1(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
%0:2 = "tfl.while"(%arg0, %arg1) ( {
^bb0(%carg0: tensor<f32>, %carg1: tensor<f32>):
%limit = constant dense<100> : tensor<i32>
%test = "tfl.less"(%limit, %limit) : (tensor<i32>, tensor<i32>) -> tensor<i1>
"tfl.yield"(%test) : (tensor<i1>) -> ()
}, {
^bb0(%barg0: tensor<f32>, %barg1: tensor<f32>):
%cst = constant dense<22.0> : tensor<f32>
"tfl.yield"(%barg0, %cst) : (tensor<f32>, tensor<f32>) -> ()
}) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
return %0#1 : tensor<f32>
}