diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
index 1e1c431822d..fc44e778b92 100644
--- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
+++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
@@ -175,6 +175,11 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
     // Add a shape inference pass to optimize away the unnecessary casts.
     pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
   }
+
+  // Inline function calls left in the graph after folding functional
+  // control flow ops (IfOp, CaseOp).
+  pass_manager->addPass(mlir::createInlinerPass());
+
   pass_manager->addPass(
       mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
   pass_manager->addPass(mlir::TFL::CreateOptimizePass());
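Why the inliner belongs here: folding a functional control flow op replaces it with a (Stateful)PartitionedCall to the selected branch, so a call op survives in the graph until something inlines it. A minimal sketch of that ordering, assuming a standalone MLIR pass manager; `BuildFoldAndInlinePipeline` is a hypothetical helper, while the two `create*` factories are real MLIR APIs of this era:

```cpp
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

// Hypothetical helper, not TensorFlow code: illustrates the fold-then-inline
// ordering that this change relies on.
void BuildFoldAndInlinePipeline(mlir::OpPassManager &pm) {
  // Canonicalization runs IfOp::getCanonicalizationPatterns (added below),
  // folding tf.If/tf.Case with constant predicates into
  // (Stateful)PartitionedCall ops.
  pm.addPass(mlir::createCanonicalizerPass());
  // The inliner then replaces those calls with the branch function bodies.
  pm.addPass(mlir::createInlinerPass());
}
```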
diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td
index 22bcc563f7b..38c754ed08c 100644
--- a/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td
@@ -33,7 +33,7 @@ def : Pat<(TFL_QuantizeOp (TFL_DequantizeOp $in), $qt), (replaceWithValue $in)>;
 // point constant.
 def : Pat<(TFL_DequantizeOp
            (TFL_QuantizeOp (ConstantOp F32ElementsAttr:$cst), $qt)),
-          (ConstantOp $cst)>;
+          (TFL_ConstOp $cst)>;
 
 // Quantize the value of a constant op if the quantization parameters have been
 // propagated to the output.
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
index 1fe301696a7..71b30ae8090 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
@@ -229,6 +229,8 @@ else_branch: A function that takes 'inputs' and returns a list of
   let verifier = [{
     return Verify(*this);
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 def TF_YieldOp : TF_Op<"Yield",
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
index af7a16ba127..f4f9ec42864 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
@@ -463,7 +463,7 @@ LogicalResult FoldConstantCaseOp::matchAndRewrite(
   auto call_op = rewriter.create<PartitionedCallOp>(
       op.getLoc(), op.getResultTypes(), op.getOperands().drop_front(), func,
       /*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty);
-  PropagateAttributes(op.getOperation(), call_op);
+  PropagateDeviceAndInternalAttrs(op.getOperation(), call_op);
   rewriter.replaceOp(op, call_op.getResults());
   return success();
 }
@@ -1615,6 +1615,56 @@ static LogicalResult Verify(IfOp op) {
   return success();
 }
 
+class FoldConstantIfOp : public OpRewritePattern<TF::IfOp> {
+ public:
+  explicit FoldConstantIfOp(MLIRContext *context)
+      : OpRewritePattern<TF::IfOp>(context) {}
+  LogicalResult matchAndRewrite(TF::IfOp op,
+                                PatternRewriter &rewriter) const override;
+
+ private:
+  template <typename T>
+  struct CallOpType {
+    using CallOp = T;
+  };
+};
+
+LogicalResult FoldConstantIfOp::matchAndRewrite(
+    TF::IfOp op, PatternRewriter &rewriter) const {
+  // Extract the constant cond value.
+  DenseIntElementsAttr cond_attr;
+  if (!matchPattern(op.cond(), m_Constant(&cond_attr))) return failure();
+
+  // Cond value must be a scalar.
+  if (cond_attr.getNumElements() != 1) return failure();
+
+  // Select a branch function.
+  bool cond = cond_attr.getSplatValue<BoolAttr>().getValue();
+  FlatSymbolRefAttr func = cond ? op.then_branchAttr() : op.else_branchAttr();
+
+  // Replace IfOp with PartitionedCallOp or StatefulPartitionedCallOp.
+  auto rewrite = [&](auto op_type) {
+    auto empty = rewriter.getStringAttr("");
+    auto call_op = rewriter.create<typename decltype(op_type)::CallOp>(
+        op.getLoc(), op.getResultTypes(), op.getOperands().drop_front(), func,
+        /*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty);
+    PropagateDeviceAndInternalAttrs(op.getOperation(), call_op);
+    rewriter.replaceOp(op, call_op.getResults());
+  };
+
+  if (op.is_stateless())
+    rewrite(CallOpType<PartitionedCallOp>{});
+  else
+    rewrite(CallOpType<StatefulPartitionedCallOp>{});
+
+  return success();
+}
+
+void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                       MLIRContext *context) {
+  results.insert<FoldConstantIfOp>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // IfRegionOp
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc
index cea2aa17d46..33d51301208 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc
@@ -21,7 +21,7 @@ limitations under the License.
 // Propagates underscore and device attributes from src to dst.
 // TODO(b/158769932): This should be a general feature instead post some policy
 // discussion.
-static void PropagateAttributes(Operation *src, Operation *dst) {
+static void PropagateDeviceAndInternalAttrs(Operation *src, Operation *dst) {
   auto device = mlir::Identifier::get("device", src->getContext());
   for (auto named_attr : src->getAttrs()) {
     if (*named_attr.first.begin() == '_' || named_attr.first == device)
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
index c67725fbccf..17a19c50998 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
@@ -755,6 +755,27 @@ func @foldFill() -> (tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>>) {
   return %2, %3, %4 : tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>>
 }
 
+// CHECK-LABEL: foldIf
+func @foldIf(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> (tensor<f32>) {
+  %0 = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
+  %1 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
+
+  // CHECK: %0 = "tf.PartitionedCall"(%arg0, %arg1)
+  // CHECK-SAME: device = "noodle"
+  // CHECK-SAME: f = @sub
+  %2 = "tf.If"(%0, %arg0, %arg1) {then_branch = @add, else_branch = @sub, output_shapes = [#tf.shape<>], device = "noodle", is_stateless = true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: %1 = "tf.StatefulPartitionedCall"(%0, %arg1)
+  // CHECK-SAME: _underscore_attr = "something"
+  // CHECK-SAME: f = @add
+  %3 = "tf.If"(%1, %2, %arg1) {then_branch = @add, else_branch = @sub, output_shapes = [#tf.shape<>], _underscore_attr = "something", is_stateless = false} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+
+  // CHECK: %2 = "tf.If"
+  %4 = "tf.If"(%arg2, %3, %arg1) {then_branch = @add, else_branch = @sub, is_stateless = false} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+
+  // CHECK: return %2
+  return %4 : tensor<f32>
+}
+
 // CHECK-LABEL: foldCase
 func @foldCase(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
   %2 = constant dense<1> : tensor<i32>
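One detail of the pattern worth calling out: the `rewrite` lambda in `FoldConstantIfOp` selects the call op class at compile time through an empty tag struct, so a single generic lambda body serves both the stateless and stateful cases. A self-contained sketch of that idiom, using illustrative stand-in types rather than the real TensorFlow op classes:

```cpp
#include <iostream>

// Tag type carrying the op class to build, mirroring the private
// CallOpType in FoldConstantIfOp above.
template <typename T>
struct CallOpType {
  using CallOp = T;
};

// Illustrative stand-ins, not the real TensorFlow op classes.
struct PartitionedCall {
  static const char *Name() { return "tf.PartitionedCall"; }
};
struct StatefulPartitionedCall {
  static const char *Name() { return "tf.StatefulPartitionedCall"; }
};

int main() {
  bool is_stateless = true;
  // The generic lambda recovers the op class from the tag's nested typedef,
  // mirroring rewriter.create<typename decltype(op_type)::CallOp>(...).
  auto rewrite = [&](auto op_type) {
    using CallOp = typename decltype(op_type)::CallOp;
    std::cout << "building " << CallOp::Name() << "\n";
  };
  if (is_stateless)
    rewrite(CallOpType<PartitionedCall>{});
  else
    rewrite(CallOpType<StatefulPartitionedCall>{});
  return 0;
}
```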
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/control_flow_upgrade_legacy_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/control_flow_upgrade_legacy_v1.py
index 209ed3492e8..19e7a90c1e1 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/control_flow_upgrade_legacy_v1.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/control_flow_upgrade_legacy_v1.py
@@ -33,9 +33,10 @@ from tensorflow.python.ops import control_flow_ops
 def Test():
 
   data = tf.constant([1, 2, 3, 4, 5, 6])
-  zero = tf.convert_to_tensor(0)
-  one = tf.convert_to_tensor(1)
-  less_op = tf.less(zero, one)
+  # Create placeholders to prevent constant folding.
+  x_op = tf.placeholder(dtype=tf.int32)
+  y_op = tf.placeholder(dtype=tf.int32)
+  less_op = tf.less(x_op, y_op)
   switch_op = control_flow_ops.switch(data, less_op)
   merge_op = control_flow_ops.merge(switch_op)[0]
   result = tf.transpose(merge_op)
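The test change above follows from the new canonicalization: with a predicate computed from constants, the control flow could now be folded away and the test would stop exercising the Switch/Merge upgrade path, so placeholders keep the predicate non-constant. A hedged sketch of the matching step that draws that line (header paths per the MLIR of this era; `PredicateWouldFold` is a hypothetical helper, not part of the change):

```cpp
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Value.h"

static bool PredicateWouldFold(mlir::Value pred) {
  mlir::DenseIntElementsAttr attr;
  // m_Constant binds only when `pred` is defined by a constant-like op
  // (e.g. tf.Const); a value fed from tf.placeholder never matches.
  if (!mlir::matchPattern(pred, mlir::m_Constant(&attr))) return false;
  // FoldConstantIfOp additionally requires a scalar predicate.
  return attr.getNumElements() == 1;
}
```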