diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/if_op.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/if_op.pbtxt deleted file mode 100644 index 79159f53070..00000000000 --- a/tensorflow/compiler/mlir/lite/tests/end2end/if_op.pbtxt +++ /dev/null @@ -1,422 +0,0 @@ -# RUN: tf_tfl_translate -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=4:4 -tf-output-arrays=StatefulIf,StatelessIf %s -o - --output-mlir | FileCheck %s --dump-input-on-failure - -node { - name: "tf.Less" - op: "Less" - input: "a" - input: "b" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - experimental_debug_info { - } -} -node { - name: "my_equal" - op: "Equal" - input: "a" - input: "b" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - experimental_debug_info { - } -} -node { - name: "cst0" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 4 - } - } - float_val: 1.0 - float_val: 2.0 - float_val: 3.0 - float_val: 4.0 - } - } - } -} -node { - name: "cst1" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 4 - } - } - float_val: 5.0 - float_val: 6.0 - float_val: 7.0 - float_val: 8.0 - } - } - } -} -node { - name: "StatefulIf" - op: "If" - input: "tf.Less" - input: "a" - input: "b" - input: "cst0" - input: "cst1" - attr { - key: "Tcond" - value { - type: DT_BOOL - } - } - attr { - key: "Tin" - value { - list { - type: DT_FLOAT - type: DT_FLOAT - type: DT_FLOAT - type: DT_FLOAT - } - } - } - attr { - key: "Tout" - value { - list { - type: DT_FLOAT - } - } - } - attr { - key: "else_branch" - value { - func { - name: "cond_false" - } - } - } - attr { - key: "then_branch" - value { - func { - name: "cond_true" - } - } - } - experimental_debug_info { - } -} -node { - name: "StatelessIf" - op: "StatelessIf" - input: "my_equal" - input: "a" - input: "b" - attr { - key: "Tcond" - value { - type: DT_BOOL - } - } - attr { - key: "Tin" - value { - list { - type: DT_FLOAT - type: DT_FLOAT - } - } - } - attr { - key: "Tout" - value { - list { - type: DT_FLOAT - } - } - } - attr { - key: "else_branch" - value { - func { - name: "cond_false_1" - } - } - } - attr { - key: "then_branch" - value { - func { - name: "cond_true_1" - } - } - } - experimental_debug_info { - } -} -node { - name: "main" - op: "_Retval" - input: "StatefulIf" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "index" - value { - i: 0 - } - } -} -node { - name: "main1" - op: "_Retval" - input: "StatelessIf" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "index" - value { - i: 1 - } - } -} -node { - name: "a" - op: "Placeholder" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - experimental_debug_info { - } -} -node { - name: "b" - op: "Placeholder" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - experimental_debug_info { - } -} -library { - function { - signature { - name: "cond_true" - input_arg { - name: "cond_true_arg0" - type: DT_FLOAT - } - input_arg { - name: "cond_true_arg1" - type: DT_FLOAT - } - input_arg { - name: "cond_true_arg2" - type: DT_FLOAT - } - input_arg { - name: "cond_true_arg3" - type: DT_FLOAT - } - output_arg { - name: "cond_true_ret" - type: DT_FLOAT - } - } - node_def { - name: "tf.Add" - op: "Add" - input: "cond_true_arg2" - input: "cond_true_arg3" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - experimental_debug_info { - original_node_names: "tf.Add" - } - } - ret { - key: "cond_true_ret" - value: "tf.Add:z:0" - } - } - function { - signature { - name: "cond_false" - input_arg { - name: "cond_false_arg0" - type: DT_FLOAT - } - input_arg { - name: "cond_false_arg1" - type: DT_FLOAT - } - input_arg { - name: "cond_false_arg2" - type: DT_FLOAT - } - input_arg { - name: "cond_false_arg3" - type: DT_FLOAT - } - output_arg { - name: "cond_false_ret" - type: DT_FLOAT - } - } - node_def { - name: "tf.Mul" - op: "Mul" - input: "cond_false_arg0" - input: "cond_false_arg3" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - experimental_debug_info { - original_node_names: "tf.Mul" - } - } - ret { - key: "cond_false_ret" - value: "tf.Mul:z:0" - } - } - function { - signature { - name: "cond_true_1" - input_arg { - name: "cond_true_arg0" - type: DT_FLOAT - } - input_arg { - name: "cond_true_arg1" - type: DT_FLOAT - } - output_arg { - name: "cond_true_ret" - type: DT_FLOAT - } - } - node_def { - name: "tf.Sub" - op: "Sub" - input: "cond_true_arg0" - input: "cond_true_arg1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - experimental_debug_info { - original_node_names: "tf.Sub" - } - } - ret { - key: "cond_true_ret" - value: "tf.Sub:z:0" - } - } - function { - signature { - name: "cond_false_1" - input_arg { - name: "cond_false_arg0" - type: DT_FLOAT - } - input_arg { - name: "cond_false_arg1" - type: DT_FLOAT - } - output_arg { - name: "cond_false_ret" - type: DT_FLOAT - } - } - node_def { - name: "tf.Div" - op: "Div" - input: "cond_false_arg0" - input: "cond_false_arg1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - experimental_debug_info { - original_node_names: "tf.Div" - } - } - ret { - key: "cond_false_ret" - value: "tf.Div:z:0" - } - } -} -versions { - producer: 115 - min_consumer: 12 -} - -# CHECK: func @StatefulIf_else -# CHECK-NEXT: constant dense<[5.000000e+00, 6.000000e+00, 7.000000e+00, 8.000000e+00]> -# CHECK-NEXT: tfl.mul -# CHECK: func @StatefulIf_then -# CHECK-NEXT: constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]> -# CHECK-NEXT: return -# CHECK: func @StatelessIf_else -# CHECK-NEXT: tfl.div -# CHECK: func @StatelessIf_then -# CHECK-NEXT: tfl.sub -# CHECK: "tf.If"{{.+}}else_branch = @StatelessIf_else{{.+}}then_branch = @StatelessIf_then -# CHECK: "tf.If"{{.+}}else_branch = @StatefulIf_else{{.+}}then_branch = @StatefulIf_then - diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 32f57a53851..40420eee697 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -85,8 +85,6 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, pass_config.quant_specs.serialized_quant_stats)); } - pass_manager->addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions()); - // The conversion pipeline has to follow the following orders: // 1) Saved model related optimization like decompose resource ops // 2) Convert composite functions like lstm/rnns, along with proper function @@ -130,9 +128,6 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, // Add a shape inference pass to optimize away the unnecessary casts. pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass()); } - - pass_manager->addPass(mlir::TF::CreateTFRegionControlFlowToFunctional()); - // Legalize while early to allow further constant folding. // TODO(jpienaar): This may not actually matter as we do canonicalization // after the legalize below, for now it needs to be below the above passes diff --git a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc index 106b0f9af83..707f4aba881 100644 --- a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc +++ b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc @@ -54,6 +54,7 @@ class WhileOutlinePass tensorflow::OpOrArgLocNameMapper mapper_; }; +} // namespace std::string WhileOutlinePass::GetName(Operation* op, StringRef suffix) { return (mapper_.GetUniqueName(op) + suffix).str(); @@ -61,7 +62,7 @@ std::string WhileOutlinePass::GetName(Operation* op, StringRef suffix) { // Returns whether the WhileOp is already outlined (e.g., only consists of calls // to functions). -bool IsAlreadyOutlined(WhileOp while_op) { +static bool IsAlreadyOutlinedd(WhileOp while_op) { auto just_call = [](Region& region) { auto it = region.front().begin(); if (!isa(*it)) return false; @@ -119,7 +120,7 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) { } // Skip if already just calls. - if (extra_operands.empty() && IsAlreadyOutlined(while_op)) return; + if (extra_operands.empty() && IsAlreadyOutlinedd(while_op)) return; // Collect new types. SmallVector types; @@ -237,7 +238,6 @@ void WhileOutlinePass::runOnOperation() { getOperation().walk( [&](mlir::TFL::WhileOp while_op) { OutlineWhile(while_op); }); } -} // namespace // Creates an instance of the TensorFlow Lite dialect WhileOp outline pass. std::unique_ptr> CreateWhileOutlinePass() { diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index c1ce7cf3374..40add34393b 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -418,7 +418,6 @@ cc_library( "transforms/fold_switch.cc", "transforms/freeze_global_tensors.cc", "transforms/functional_control_flow_to_cfg.cc", - "transforms/functional_control_flow_to_regions.cc", "transforms/generated_canonicalize.inc", "transforms/generated_optimize.inc", "transforms/gpu_fusion.cc", @@ -434,7 +433,6 @@ cc_library( "transforms/promote_resources_to_args.cc", "transforms/raise_control_flow.cc", "transforms/readonly_references_to_resources.cc", - "transforms/region_control_flow_to_functional.cc", "transforms/replicate_invariant_op_hoisting.cc", "transforms/replicate_to_island.cc", "transforms/resource_device_inference.cc", @@ -493,7 +491,6 @@ cc_library( ":translate_utils", ":unroll_batch_matmul_pass", ":xla_sharding_util", - "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/lite:validators", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla:xla_proto_cc", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 1ee1b1bd526..7b3e1508efb 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -4011,15 +4011,6 @@ struct TFInlinerInterface : public DialectInlinerInterface { // Analysis Hooks //===--------------------------------------------------------------------===// - // Defines the legality of inlinining 'src' region into the 'dest' region - // attached to a TF operation - bool isLegalToInline(Region *dest, Region *src, - BlockAndValueMapping &valueMapping) const final { - // Allow inlining in regions attached to region based control flow - // operations only if the src region is a single block region - return isa(dest->getParentOp()) && src->getBlocks().size() == 1; - } - // Defines the legality of inlining TF operations. bool isLegalToInline(Operation *, Region *, BlockAndValueMapping &) const final { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir deleted file mode 100644 index 09dfa61b29d..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir +++ /dev/null @@ -1,48 +0,0 @@ -// RUN: tf-opt %s -tf-functional-control-flow-to-regions -split-input-file | FileCheck %s --dump-input=fail - -// CHECK: func @testIf1Then{{.+}} {sym_visibility = "private"} -// CHECK: func @testIf1Else{{.+}} {sym_visibility = "private"} -func @testIf1Then(tensor<*xf32>) -> tensor<*xf32> -func @testIf1Else(tensor<*xf32>) -> tensor<*xf32> - -// CHECK-LABEL: func @testIf1Result(%arg0: tensor, %arg1: tensor<*xf32>) -func @testIf1Result(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { - %0 = "tf.If"(%arg0, %arg1) { - then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false - } : (tensor, tensor<*xf32>) -> tensor<*xf32> - - // CHECK: "tf.IfRegion" - // CHECK: [[Result0:%.*]] = call @testIf1Then - // CHECK: "tf.Yield"([[Result0]]) - // CHECK: [[Result1:%.*]] = call @testIf1Else - // CHECK: "tf.Yield"([[Result1]]) - return %0 : tensor<*xf32> -} - -// ----- - -// With mismatching input types - -// CHECK: func @testIf1Then{{.+}} {sym_visibility = "private"} -// CHECK: func @testIf1Else{{.+}} {sym_visibility = "private"} -func @testIf1Then(tensor<*xf32>) -> tensor<*xf32> -func @testIf1Else(tensor<*xf32>) -> tensor<*xf32> - -// CHECK-LABEL: func @testIf2Result(%arg0: tensor, %arg1: tensor<2xf32>) -func @testIf2Result(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { - %0 = "tf.If"(%arg0, %arg1) { - then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false - } : (tensor, tensor<2xf32>) -> tensor<2xf32> - - // CHECK: "tf.IfRegion" - // CHECK: "tf.Cast" - // CHECK: [[Result0:%.*]] = call @testIf1Then - // CHECK: "tf.Yield"([[Result0]]) - // CHECK: "tf.Cast" - // CHECK: [[Result1:%.*]] = call @testIf1Else - // CHECK: "tf.Yield"([[Result1]]) - return %0 : tensor<2xf32> -} - - - diff --git a/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir b/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir deleted file mode 100644 index c9b6543e0a3..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir +++ /dev/null @@ -1,142 +0,0 @@ -// RUN: tf-opt %s -tf-region-control-flow-to-functional -split-input-file -//| FileCheck %s --dump-input=fail - -// CHECK: func @tf.IfRegion_else(%arg0: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} -// CHECK-NEXT: "tf.Neg" -// CHECK: func @tf.IfRegion_then(%arg0: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} -// CHECK-NEXT: "tf.Abs" -func @testSimple(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "tf.If"{{.+}}else_branch = @tf.IfRegion_else{{.+}}then_branch = @tf.IfRegion_then - %0 = "tf.IfRegion"(%arg0) ({ - %1 = "tf.Abs"(%arg1) : (tensor<*xf32>) -> tensor<*xf32> - "tf.Yield"(%1) : (tensor<*xf32>) -> () - }, { - %2 = "tf.Neg"(%arg1) : (tensor<*xf32>) -> tensor<*xf32> - "tf.Yield"(%2) : (tensor<*xf32>) -> () - }) { is_stateless = true } : (tensor) -> tensor<*xf32> - return %0 : tensor<*xf32> -} - -// ----- - -// Use if condition inside the regions -// CHECK: func @tf.IfRegion_else(%arg0: tensor, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>) -> tensor<2xf32> attributes {sym_visibility = "private"} -// CHECK-NEXT: "tf.Select"(%arg0, %arg2, %arg3) -// CHECK: func @tf.IfRegion_then(%arg0: tensor, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>) -> tensor<2xf32> attributes {sym_visibility = "private"} -// CHECK-NEXT: "tf.Select"(%arg0, %arg1, %arg2) -func @testIfCondition(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { - %0 = "tf.Add"(%arg1, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - %1 = "tf.Mul"(%arg1, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - %2 = "tf.Div"(%arg1, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - - // CHECK: "tf.If"{{.+}}else_branch = @tf.IfRegion_else{{.+}}then_branch = @tf.IfRegion_then - %3 = "tf.IfRegion"(%arg0) ({ - %4 = "tf.Select"(%arg0, %0, %1) : (tensor, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - "tf.Yield"(%4) : (tensor<2xf32>) -> () - }, { - %5 = "tf.Select"(%arg0, %1, %2): (tensor, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - "tf.Yield"(%5) : (tensor<2xf32>) -> () - }) { is_stateless = true} : (tensor) -> tensor<2xf32> - return %3 : tensor<2xf32> -} - -// ----- - -// Constant sinking - -// CHECK: func @tf.IfRegion_else() -> tensor<2xf32> -// CHECK-NEXT: constant dense<1.0 -// CHECK: func @tf.IfRegion_then() -> tensor<2xf32> -// CHECK-NEXT: constant dense<0.0 -func @testIfConstant(%arg0: tensor) -> tensor<2xf32> { - %cst_zero = constant dense<0.0> : tensor<2xf32> - // CHECK: "tf.If"(%arg0) {else_branch = @tf.IfRegion_else{{.+}}then_branch = @tf.IfRegion_then - %0 = "tf.IfRegion"(%arg0) ({ - "tf.Yield"(%cst_zero) : (tensor<2xf32>) -> () - }, { - %cst_one = constant dense<1.0> : tensor<2xf32> - "tf.Yield"(%cst_one) : (tensor<2xf32>) -> () - }) { is_stateless = true} : (tensor) -> tensor<2xf32> - return %0 : tensor<2xf32> -} - -// ----- - -// Nested IfRegions -// CHECK: func @tf.IfRegion1_else -// CHECK-NEXT: "tf.Acos" -// CHECK-NEXT: "tf.Abs" - -// CHECK: func @tf.IfRegion1_then -// CHECK-NEXT: "tf.LogicalNot" -// CHECK-NEXT: "tf.Asin" -// CHECK-NEXT: "tf.If"({{.+}}) {else_branch = @tf.IfRegion_else, {{.+}} then_branch = @tf.IfRegion_then} - -// CHECK: func @tf.IfRegion_else -// CHECK-NEXT: "tf.Neg" -// CHECK: func @tf.IfRegion_then -// CHECK-NEXT: "tf.Abs" - -func @testNested(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "tf.If"({{.+}}) {else_branch = @tf.IfRegion1_else, {{.+}} then_branch = @tf.IfRegion1_then} - %0 = "tf.IfRegion"(%arg0) ({ - // Outer Then - %cond = "tf.LogicalNot"(%arg0) : (tensor) -> tensor - %asin = "tf.Asin"(%arg1) : (tensor<*xf32>) -> tensor<*xf32> - - // nested IfRegion - %1 = "tf.IfRegion"(%cond) ({ - %2 = "tf.Abs"(%asin) : (tensor<*xf32>) -> tensor<*xf32> - "tf.Yield"(%2) : (tensor<*xf32>) -> () - }, { - %2 = "tf.Neg"(%arg1) : (tensor<*xf32>) -> tensor<*xf32> - "tf.Yield"(%2) : (tensor<*xf32>) -> () - }) { is_stateless = true } : (tensor) -> tensor<*xf32> - - "tf.Yield"(%1) : (tensor<*xf32>) -> () - }, { - // Outer Else - %acos = "tf.Acos"(%arg1) : (tensor<*xf32>) -> tensor<*xf32> - %3 = "tf.Abs"(%acos) : (tensor<*xf32>) -> tensor<*xf32> - "tf.Yield"(%3) : (tensor<*xf32>) -> () - }) { is_stateless = true } : (tensor) -> tensor<*xf32> - return %0 : tensor<*xf32> -} - -// ----- - -// Match existing function->Region pattern (simple) -func @testIf1Then(tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} -func @testIf1Else(tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} -func @testIf1Result(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "tf.If"({{.+}}) {else_branch = @testIf1Else, {{.+}} then_branch = @testIf1Then} - %0 = "tf.IfRegion"(%arg0) ( { - %1 = call @testIf1Then(%arg1) : (tensor<*xf32>) -> tensor<*xf32> - "tf.Yield"(%1) : (tensor<*xf32>) -> () - }, { - %1 = call @testIf1Else(%arg1) : (tensor<*xf32>) -> tensor<*xf32> - "tf.Yield"(%1) : (tensor<*xf32>) -> () - }) {is_stateless = false} : (tensor) -> tensor<*xf32> - return %0 : tensor<*xf32> -} - -// ----- - -// Match existing function->Region pattern (with casts) - -func @testIf1Then(tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} -func @testIf1Else(tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} -func @testIf2Result(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: "tf.If"({{.+}}) {else_branch = @testIf1Else, {{.+}} then_branch = @testIf1Then} - %0 = "tf.IfRegion"(%arg0) ( { - %1 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xf32>) -> tensor<*xf32> - %2 = call @testIf1Then(%1) : (tensor<*xf32>) -> tensor<*xf32> - "tf.Yield"(%2) : (tensor<*xf32>) -> () - }, { - %1 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xf32>) -> tensor<*xf32> - %2 = call @testIf1Else(%1) : (tensor<*xf32>) -> tensor<*xf32> - "tf.Yield"(%2) : (tensor<*xf32>) -> () - }) {is_stateless = false} : (tensor) -> tensor<2xf32> - return %0 : tensor<2xf32> -} - diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc index a0be88cc564..91bbac235e9 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ // This transformation pass transforms functional control flow operations in the -// TensorFlow dialect to MLIR Control Flow Graph (CFG) form. +// standard TensorFlow dialect to MLIR Control Flow Graph (CFG) form. #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project @@ -52,6 +52,7 @@ static Value LowerCondition(Location loc, Value value, OpBuilder* builder) { // // Requires the function to provide arguments for each of the `fn` operands // that is compatible for tensor cast. +// static Operation* CallFn(Location loc, const std::function& get_arg, FuncOp fn, OpBuilder* builder) { FunctionType fn_type = fn.getType(); @@ -112,6 +113,7 @@ static void JumpToBlock(Location loc, const std::function& get_arg, // Requires that the block has same number of arguments as number of results of // the operation and either they have same types or are more generic types and // it is possible to cast them to results' types. +// static void ReplaceOpResultWithBlockArgs(Location loc, Operation* op, Block* block, OpBuilder* builder) { assert(op->getNumResults() == block->getNumArguments()); @@ -130,6 +132,9 @@ static void ReplaceOpResultWithBlockArgs(Location loc, Operation* op, // Given a functional IfOp, transforms the enclosing code to eliminate it // completely from the IR, breaking it into operations to evaluate the condition // as a bool, plus some branches. +// +// This returns true on failure. +// static LogicalResult LowerIfOp(IfOp op) { Operation* op_inst = op.getOperation(); Location loc = op_inst->getLoc(); @@ -188,6 +193,9 @@ static LogicalResult LowerIfOp(IfOp op) { // Given a functional WhileOp, transforms the enclosing code to eliminate it // completely from the IR, breaking it into operations to execute the loop body // repeatedly while the loop condition is true. +// +// This returns true on failure. +// static LogicalResult LowerWhileOp(WhileOp op) { Operation* op_inst = op.getOperation(); Location loc = op_inst->getLoc(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc deleted file mode 100644 index 88f411ccd8a..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc +++ /dev/null @@ -1,118 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This transformation pass transforms functional control flow operations in the -// TensorFlow dialect to their region based counterparts, i.e., -// tf.If -> tf.IfRegion - -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/TypeUtilities.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassRegistry.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" - -namespace mlir { -namespace TF { - -namespace { - -struct FunctionalControlFlowToRegions - : public PassWrapper { - void runOnFunction() override; -}; - -// Create a call to function `fn` with arguments `args` and return the CallOp. -// The arguments are cast to the required type before the call. -CallOp CreateCall(Location loc, Operation::operand_range args, FuncOp fn, - OpBuilder* builder) { - FunctionType fn_type = fn.getType(); - llvm::SmallVector operands; - int num_operands = fn_type.getNumInputs(); - operands.reserve(num_operands); - for (int i = 0; i < num_operands; ++i) { - Value arg = args[i]; - Type expected = fn_type.getInput(i); - if (arg.getType() != expected) { - arg = builder->create(loc, expected, arg, - /*Truncate=*/builder->getBoolAttr(false)); - } - operands.push_back(arg); - } - return builder->create(loc, fn, operands); -} - -// Transform a functional IfOp to a region based IfRegionOp -LogicalResult ConvertIfOp(IfOp if_op) { - auto if_region = OpBuilder(if_op).create( - if_op.getLoc(), if_op.getResultTypes(), if_op.cond(), - if_op.is_stateless()); - - // Insert call to the given function into the 'region'. - auto create_region_with_call = [&if_op](FlatSymbolRefAttr symbol, - Region& region) { - OpBuilder builder(region); - builder.createBlock(®ion); - auto func = if_op.getParentOfType().lookupSymbol( - symbol.getValue()); - auto call = CreateCall(if_op.getLoc(), if_op.input(), func, &builder); - builder.create(if_op.getLoc(), call.getResults()); - // Mark old function as private so that it can be DCE'd if not called. - func.setVisibility(SymbolTable::Visibility::Private); - }; - - create_region_with_call(if_op.then_branchAttr(), if_region.then_branch()); - create_region_with_call(if_op.else_branchAttr(), if_region.else_branch()); - - if_op.replaceAllUsesWith(if_region.getResults()); - if_op.erase(); - return success(); -} - -void FunctionalControlFlowToRegions::runOnFunction() { - for (Block& block : getFunction()) { - auto result = block.walk([](Operation* op) { - if (IfOp if_op = llvm::dyn_cast(op)) { - if (failed(ConvertIfOp(if_op))) { - if_op.emitOpError() << " failed to convert to region form"; - return WalkResult::interrupt(); - } - } - return WalkResult::advance(); - }); - - if (result.wasInterrupted()) return signalPassFailure(); - } -} - -} // namespace - -std::unique_ptr> -CreateTFFunctionalControlFlowToRegions() { - return std::make_unique(); -} - -static PassRegistration pass( - "tf-functional-control-flow-to-regions", - "Transform functional control flow Ops to Region based counterparts"); - -} // namespace TF -} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index f39eba1eac0..3973eb60707 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -32,18 +32,10 @@ std::unique_ptr> CreateFunctionalToExecutorDialectConversionPass(); namespace TF { -// Transforms functional control flow operations in the TensorFlow dialect to -// MLIR Control Flow Graph (CFG) form. +// Transforms functional control flow operations in the standard TensorFlow +// dialect to MLIR Control Flow Graph (CFG) form. std::unique_ptr> CreateTFFunctionalControlFlowToCFG(); -// Transforms functional control flow operations in the TensorFlow dialect to -// their region based counterparts. -std::unique_ptr> CreateTFFunctionalControlFlowToRegions(); - -// Transforms region bases control flow operations in the TensorFlow dialect to -// their functional counterparts. -std::unique_ptr> CreateTFRegionControlFlowToFunctional(); - // Materialize the MlirPassthroughOp by replacing it with the MLIR module // attached as an attribute. std::unique_ptr> CreateMaterializePassthroughOpPass(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc deleted file mode 100644 index 07dfa2903dd..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc +++ /dev/null @@ -1,321 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This transformation pass transforms region bases control flow operations in -// the TensorFlow dialect to their functional counterparts, i.e., -// tf.IfRegion -> tf.If - -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/Module.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/TypeUtilities.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassRegistry.h" // from @llvm-project -#include "mlir/Transforms/RegionUtils.h" // from @llvm-project -#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" - -namespace mlir { -namespace TF { - -namespace { - -struct RegionControlFlowToFunctional - : public PassWrapper { - void runOnFunction() override; - - private: - LogicalResult ConvertIfOp(IfRegionOp if_region); - - // Get unique name by using the loc to name mapping. - std::string GetName(Operation* op, StringRef suffix); - - tensorflow::OpOrArgLocNameMapper mapper; -}; - -std::string RegionControlFlowToFunctional::GetName(Operation* op, - StringRef suffix) { - return (mapper.GetUniqueName(op) + suffix).str(); -} - -// Given a list of regions, return all the external values referenced from the -// region. If the external value is a constant, sink it into the region instead -// (and do not add it to the returned vector). -llvm::SmallVector CollectExternValues(ArrayRef regions) { - llvm::SetVector extern_values_set; - - for (auto region : regions) { - llvm::SetVector region_extern_values; - getUsedValuesDefinedAbove(*region, region_extern_values); - - // Sink down constants into the functions. - for (auto extern_value : region_extern_values) { - if (!matchPattern(extern_value, m_Constant())) { - extern_values_set.insert(extern_value); - continue; - } - // Add constant at start of region. - auto const_builder = OpBuilder::atBlockBegin(®ion->front()); - auto const_value = const_builder.clone(*extern_value.getDefiningOp()); - replaceAllUsesInRegionWith(extern_value, const_value->getResult(0), - *region); - } - } - - return {extern_values_set.begin(), extern_values_set.end()}; -} - -// Extract the contents of a region with a single block into a new function. -// `extern_values` is the set of external values that the region refers to. -// -// Any inputs to the terminator of the region are converted to return values of -// the function. If any of these values is not exact type as the function's -// return type, appropriate cast operations will be inserted -void ExtractSingleBlockRegion(Region& region, FunctionType type, StringRef name, - llvm::SmallVector& extern_values) { - ModuleOp module = region.getParentOfType(); - auto builder = OpBuilder::atBlockBegin(module.getBody()); - auto loc = region.getParentOp()->getLoc(); - - // Create new function and extract region body into the function - auto outlined_func = - builder.create(loc, name, type, ArrayRef{}); - - outlined_func.getBody().takeBody(region); - Region& func_region = outlined_func.getBody(); - Block& first_block = func_region.front(); - - // Replace all external uses with function arguments. - for (auto it : llvm::enumerate(extern_values)) { - Value arg = first_block.addArgument(it.value().getType()); - replaceAllUsesInRegionWith(it.value(), arg, func_region); - } - - // Replace the existing terminator with a return. - Operation* terminator = outlined_func.getBody().front().getTerminator(); - builder.setInsertionPoint(terminator); - - SmallVector return_values; - return_values.reserve(terminator->getNumOperands()); - for (auto it : llvm::enumerate(type.getResults())) { - Value ret_val = terminator->getOperand(it.index()); - // Add a cast operation if types do not match. - if (ret_val.getType() != it.value()) { - ret_val = - builder.create(terminator->getLoc(), it.value(), ret_val); - } - return_values.push_back(ret_val); - } - builder.create(terminator->getLoc(), return_values); - terminator->erase(); - outlined_func.setVisibility(FuncOp::Visibility::Private); -} - -// Returns call for region with single call whose result feeds into the -// terminator of the region. Returns none if the region doesn't contain just -// call and casts ops. -llvm::Optional IsSingleCallRegion(Region& region) { - if (region.getBlocks().size() != 1) return llvm::None; - - auto it = region.front().rbegin(); - YieldOp yield = dyn_cast(*it++); - if (yield.getNumOperands() == 0) return llvm::None; - - if (it == region.front().rend()) return llvm::None; - - // and a Call before that - CallOp call = dyn_cast(*it++); - if (!call) return llvm::None; - - // There can be cast op's prior to that - for (; it != region.front().rend(); ++it) - if (!isa(*it)) return llvm::None; - - // all results of the call should feed into the yield - if (call.getNumResults() != yield.getNumOperands()) return llvm::None; - - for (auto res_it : llvm::zip(call.getResults(), yield.getOperands())) - if (std::get<0>(res_it) != std::get<1>(res_it)) return llvm::None; - - return call; -} - -// Returns whether the arguments of the given call are same as the given list of -// arguments (after looking through cast ops). -bool MatchCallArgs(CallOp call, llvm::SmallVectorImpl& args) { - if (call.getNumOperands() != args.size()) return false; - - for (auto it : llvm::enumerate(args)) { - Value arg = call.getOperand(it.index()); - if (auto cast = dyn_cast_or_null(arg.getDefiningOp())) - arg = cast.getOperand(); - - if (arg != it.value()) return false; - } - return true; -} - -// Summary information for trivially transforming region based op's to -// functional ops. A trivial transformation can be done when the regions are -// just calls to functions, in which case no outlining is needed. -struct TrivialTransformInfo { - // can the op be transformed trivially - bool can_transform = false; - - // list of callee names (one for each region) - llvm::SmallVector callee_names; - - // list of arguments used in these call (each call uses the same arguments - // potentially through casts) - llvm::SmallVector call_args; -}; - -// Analyze the given set of regions (attached to the same parent op) to check -// if the parent op be transformed to functional form trivially (i.e., reusing -// existing functions and without outlining). This is possible when all the -// regions are single call regions and the all the calls have the same -// arguments. -// -// If this trivial transformation is possible, return the relevant information -// needed for the transformation (in `TrivialTransformInfo`), else indicate that -// a trivial transformation is not possible by setting `can_transform` false. -TrivialTransformInfo AnalyzeForTrivialTransform(ArrayRef regions) { - const TrivialTransformInfo cannot_transform; - - if (regions.empty()) return cannot_transform; - - llvm::SmallVector calls; - calls.reserve(regions.size()); - - // Verify each region is a single call and collect these calls. - for (Region* region : regions) { - auto call = IsSingleCallRegion(*region); - if (!call.hasValue()) return cannot_transform; - calls.push_back(call.getValue()); - } - - llvm::SmallVector callees; - callees.reserve(regions.size()); - - CallOp call0 = calls[0]; - int num_args = call0.getNumOperands(); - callees.push_back(call0.getCallee()); - - // collect call0 arguments - llvm::SmallVector call0_args; - call0_args.reserve(num_args); - for (int arg_idx = 0; arg_idx < num_args; ++arg_idx) { - Value arg = call0.getOperand(arg_idx); - if (auto cast = dyn_cast_or_null(arg.getDefiningOp())) - arg = cast.getOperand(); - call0_args.push_back(arg); - } - - // match other call arguments with call0 arguments - for (int idx = 1; idx < calls.size(); ++idx) { - if (!MatchCallArgs(calls[idx], call0_args)) return cannot_transform; - callees.push_back(calls[idx].getCallee()); - } - - return {true, callees, call0_args}; -} - -// Transform IfRegionOp to IfOp -LogicalResult RegionControlFlowToFunctional::ConvertIfOp(IfRegionOp if_region) { - const TrivialTransformInfo tti = AnalyzeForTrivialTransform( - {&if_region.then_branch(), &if_region.else_branch()}); - - std::string then_name, else_name; - llvm::SmallVector extern_values; - - if (tti.can_transform) { - // We can transform to functional form trivially without outlining - then_name = tti.callee_names[0].str(); - else_name = tti.callee_names[1].str(); - extern_values = tti.call_args; - } else { - // Collect external values that are used within the else and then bodies. - extern_values = CollectExternValues( - {&if_region.then_branch(), &if_region.else_branch()}); - - // These external values need to be added as inputs to the generated If. The - // order is determined by the order of these values the `extern_vales`. - - // Build the type for the outlined function - llvm::SmallVector input_types; - input_types.reserve(extern_values.size()); - for (auto input : extern_values) input_types.push_back(input.getType()); - - FunctionType func_type = FunctionType::get( - input_types, if_region.getResultTypes(), if_region.getContext()); - - // Create 2 new functions with the input signature matching this order, - // and outline the `then` and `else` regions by moving the bodies of these - // regions into these functions. Replace tf.yield with a regular return. - then_name = GetName(if_region, "_then"); - ExtractSingleBlockRegion(if_region.then_branch(), func_type, then_name, - extern_values); - - else_name = GetName(if_region, "_else"); - ExtractSingleBlockRegion(if_region.else_branch(), func_type, else_name, - extern_values); - } - - // Once we have the `then` and `else` functions ready (either outlined or - // existing ones), replace the region based op with a functional control flow - // op - OpBuilder builder(if_region); - auto if_op = builder.create( - if_region.getLoc(), if_region.getResultTypes(), if_region.cond(), - extern_values, then_name, else_name, if_region.is_stateless()); - if_region.replaceAllUsesWith(if_op.getResults()); - if_region.erase(); - return success(); -} - -void RegionControlFlowToFunctional::runOnFunction() { - for (Block& block : getFunction()) { - auto result = block.walk([&](Operation* op) { - if (IfRegionOp if_region = llvm::dyn_cast(op)) { - if (failed(ConvertIfOp(if_region))) { - if_region.emitOpError() << " failed to convert to functional form"; - return WalkResult::interrupt(); - } - } - return WalkResult::advance(); - }); - - if (result.wasInterrupted()) return signalPassFailure(); - } -} - -} // namespace - -std::unique_ptr> CreateTFRegionControlFlowToFunctional() { - return std::make_unique(); -} - -static PassRegistration pass( - "tf-region-control-flow-to-functional", - "Transform region bases control flow Ops to functional counterparts"); - -} // namespace TF -} // namespace mlir