[TF:MLIR] Fold constant cond in IfOp
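
If the cond operand of a tf.If is a compile-time constant, only one branch can ever run, so the op is now canonicalized into a direct call to that branch function (a later inliner pass then inlines the call). A minimal sketch of the rewrite, assuming branch functions @add and @sub as in the test case below; operand names and types are illustrative:

// Before: cond is a constant, so the else branch is always taken.
%cond = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
%0 = "tf.If"(%cond, %x, %y) {then_branch = @add, else_branch = @sub, is_stateless = true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>

// After canonicalization: a direct call to @sub. A stateful If (is_stateless = false)
// becomes a tf.StatefulPartitionedCall instead.
%0 = "tf.PartitionedCall"(%x, %y) {f = @sub, config = "", config_proto = "", executor_type = ""} : (tensor<f32>, tensor<f32>) -> tensor<f32>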

PiperOrigin-RevId: 321499354
Change-Id: Ie644b7f25d1c05d43c43f90983c7f5ac00866c47
Eugene Zhulenev 2020-07-15 21:50:03 -07:00 committed by TensorFlower Gardener
parent 59215389ce
commit 9c20fbff6c
7 changed files with 85 additions and 6 deletions


@ -175,6 +175,11 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
// Add a shape inference pass to optimize away the unnecessary casts.
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
}
// Inline function calls that are left in the graph after folding functional
// control flow ops (IfOp, CaseOp).
pass_manager->addPass(mlir::createInlinerPass());
pass_manager->addPass(
mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
pass_manager->addPass(mlir::TFL::CreateOptimizePass());
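
The fold rewrites tf.If (and tf.Case) into a call op rather than splicing the branch body in directly, so the inliner pass registered above is what actually removes the call. A rough sketch of its effect, assuming a hypothetical branch function @sub that simply wraps a tf.Sub:

// After folding the functional control flow:
%0 = "tf.PartitionedCall"(%x, %y) {f = @sub, config = "", config_proto = "", executor_type = ""} : (tensor<f32>, tensor<f32>) -> tensor<f32>

// After the inliner pass:
%0 = "tf.Sub"(%x, %y) : (tensor<f32>, tensor<f32>) -> tensor<f32>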


@ -33,7 +33,7 @@ def : Pat<(TFL_QuantizeOp (TFL_DequantizeOp $in), $qt), (replaceWithValue $in)>;
// point constant.
def : Pat<(TFL_DequantizeOp
(TFL_QuantizeOp (ConstantOp F32ElementsAttr:$cst), $qt)),
-(ConstantOp $cst)>;
+(TFL_ConstOp $cst)>;
// Quantize the value of a constant op if the quantization parameters have been
// propagated to the output.


@ -229,6 +229,8 @@ else_branch: A function that takes 'inputs' and returns a list of
let verifier = [{
return Verify(*this);
}];
let hasCanonicalizer = 1;
}
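
Setting hasCanonicalizer here makes ODS declare IfOp::getCanonicalizationPatterns, whose definition below registers the new FoldConstantIfOp pattern.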
def TF_YieldOp : TF_Op<"Yield",


@ -463,7 +463,7 @@ LogicalResult FoldConstantCaseOp::matchAndRewrite(
auto call_op = rewriter.create<PartitionedCallOp>(
op.getLoc(), op.getResultTypes(), op.getOperands().drop_front(), func,
/*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty);
-PropagateAttributes(op.getOperation(), call_op);
+PropagateDeviceAndInternalAttrs(op.getOperation(), call_op);
rewriter.replaceOp(op, call_op.getResults());
return success();
}
@ -1615,6 +1615,56 @@ static LogicalResult Verify(IfOp op) {
return success();
}
class FoldConstantIfOp : public OpRewritePattern<TF::IfOp> {
public:
explicit FoldConstantIfOp(MLIRContext *context)
: OpRewritePattern<TF::IfOp>(context) {}
LogicalResult matchAndRewrite(TF::IfOp op,
PatternRewriter &rewriter) const override;
private:
template <typename T>
struct CallOpType {
using CallOp = T;
};
};
LogicalResult FoldConstantIfOp::matchAndRewrite(
TF::IfOp op, PatternRewriter &rewriter) const {
// Extract the constant cond value.
DenseIntElementsAttr cond_attr;
if (!matchPattern(op.cond(), m_Constant(&cond_attr))) return failure();
// Cond value must be a scalar.
if (cond_attr.getNumElements() != 1) return failure();
// Select a branch function.
bool cond = cond_attr.getSplatValue<BoolAttr>().getValue();
FlatSymbolRefAttr func = cond ? op.then_branchAttr() : op.else_branchAttr();
// Replace IfOp with PartitionedCallOp or StatefulPartitionedCallOp.
auto rewrite = [&](auto op_type) {
auto empty = rewriter.getStringAttr("");
auto call_op = rewriter.create<typename decltype(op_type)::CallOp>(
op.getLoc(), op.getResultTypes(), op.getOperands().drop_front(), func,
/*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty);
PropagateDeviceAndInternalAttrs(op.getOperation(), call_op);
rewriter.replaceOp(op, call_op.getResults());
};
if (op.is_stateless())
rewrite(CallOpType<PartitionedCallOp>{});
else
rewrite(CallOpType<StatefulPartitionedCallOp>{});
return success();
}
void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<FoldConstantIfOp>(context);
}
//===----------------------------------------------------------------------===//
// IfRegionOp
//===----------------------------------------------------------------------===//


@ -21,7 +21,7 @@ limitations under the License.
// Propagates underscore and device attributes from src to dst.
// TODO(b/158769932): This should instead be a general feature, pending some policy
// discussion.
-static void PropagateAttributes(Operation *src, Operation *dst) {
+static void PropagateDeviceAndInternalAttrs(Operation *src, Operation *dst) {
auto device = mlir::Identifier::get("device", src->getContext());
for (auto named_attr : src->getAttrs()) {
if (*named_attr.first.begin() == '_' || named_attr.first == device)
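
Concretely, the renamed helper copies the device attribute and every attribute whose name starts with an underscore from the folded op onto its replacement. A rough illustration in IR form (both attributes shown on one op for brevity; the test below exercises them separately):

// Original op carries a device assignment and an internal "_" attribute.
%0 = "tf.If"(%cond, %x, %y) {then_branch = @add, else_branch = @sub, device = "noodle", _underscore_attr = "something", is_stateless = true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>

// The call produced by the fold keeps both.
%0 = "tf.PartitionedCall"(%x, %y) {f = @sub, device = "noodle", _underscore_attr = "something", config = "", config_proto = "", executor_type = ""} : (tensor<f32>, tensor<f32>) -> tensor<f32>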


@ -755,6 +755,27 @@ func @foldFill() -> (tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>>) {
return %2, %3, %4 : tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>>
}
// CHECK-LABEL: foldIf
func @foldIf(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> (tensor<f32>) {
%0 = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
%1 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
// CHECK: %0 = "tf.PartitionedCall"(%arg0, %arg1)
// CHECK-SAME: device = "noodle"
// CHECK-SAME: f = @sub
%2 = "tf.If"(%0, %arg0, %arg1) {then_branch = @add, else_branch = @sub, output_shapes = [#tf.shape<>], device = "noodle", is_stateless = true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: %1 = "tf.StatefulPartitionedCall"(%0, %arg1)
// CHECK-SAME: _underscore_attr = "something"
// CHECK-SAME: f = @add
%3 = "tf.If"(%1, %2, %arg1) {then_branch = @add, else_branch = @sub, output_shapes = [#tf.shape<>], _underscore_attr = "something", is_stateless = false} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: %2 = "tf.If"
%4 = "tf.If"(%arg2, %3, %arg1) {then_branch = @add, else_branch = @sub, is_stateless = false} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: return %2
return %4 : tensor<f32>
}
// CHECK-LABEL: foldCase
func @foldCase(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
%2 = constant dense<1> : tensor<i32>


@ -33,9 +33,10 @@ from tensorflow.python.ops import control_flow_ops
def Test():
data = tf.constant([1, 2, 3, 4, 5, 6])
-zero = tf.convert_to_tensor(0)
-one = tf.convert_to_tensor(1)
-less_op = tf.less(zero, one)
+# Create placeholders to prevent constant folding.
+x_op = tf.placeholder(dtype=tf.int32)
+y_op = tf.placeholder(dtype=tf.int32)
+less_op = tf.less(x_op, y_op)
switch_op = control_flow_ops.switch(data, less_op)
merge_op = control_flow_ops.merge(switch_op)[0]
result = tf.transpose(merge_op)