From 191628f0e5f34f71db15804d2aa83bfb326ad7bf Mon Sep 17 00:00:00 2001 From: Karim Nosir <karimnosseir@google.com> Date: Tue, 9 Jun 2020 13:51:17 -0700 Subject: [PATCH] - Remove executor_to_control pass. - Remove raise control flow pass. - Cleanup usage in TFLite and other referneces. - Remove skip_control_dialect member in PassConfig PiperOrigin-RevId: 315552807 Change-Id: I4994f6a3c26cbe4845b97e7933272a860d3f15c2 --- .../mlir/lite/common/tfl_pass_config.h | 10 +- .../compiler/mlir/lite/tf_tfl_passes.cc | 28 +- .../mlir/lite/tf_to_tfl_flatbuffer.cc | 6 - tensorflow/compiler/mlir/tensorflow/BUILD | 2 - .../mlir/tensorflow/tests/empty-main.mlir | 2 +- .../tests/executor_to_control_dialect.mlir | 188 -------------- .../tensorflow/tests/raise-control-flow.mlir | 57 ----- .../mlir/tensorflow/transforms/optimize.cc | 2 + .../mlir/tensorflow/transforms/passes.h | 10 +- .../transforms/raise_control_flow.cc | 159 ------------ .../translate/executor_to_control_dialect.cc | 242 ------------------ .../compiler/mlir/tfjs/tf_tfjs_passes.cc | 6 - 12 files changed, 14 insertions(+), 698 deletions(-) delete mode 100644 tensorflow/compiler/mlir/tensorflow/tests/executor_to_control_dialect.mlir delete mode 100644 tensorflow/compiler/mlir/tensorflow/tests/raise-control-flow.mlir delete mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/raise_control_flow.cc delete mode 100644 tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc diff --git a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h index 2ed63fcc794..83ff9971246 100644 --- a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h +++ b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h @@ -32,7 +32,6 @@ struct PassConfig { lower_tensor_list_ops(false), trim_functions_whitelist({}), quant_specs(std::move(specs)), - skip_control_dialect(false), form_clusters(false), unfold_batch_matmul(true), legalize_tf_while(true), @@ -49,13 +48,8 @@ struct PassConfig { llvm::ArrayRef<std::string> trim_functions_whitelist; // All information about quantization. QuantizationSpecs quant_specs; - // If `skip_control_dialect` is true, TF executor dialect is not converted to - // TF control dialect prior to legalization to TF Lite. - // TODO(b/142911013): Remove flag once control dialect is removed. - bool skip_control_dialect; - // If `form_clusters` is true (and `skip_control_dialect` is true), clusters - // are formed by grouping consecutive ops of the same device, under a - // `tf_device.launch` op. + // If `form_clusters` is true , clusters are formed by grouping consecutive + // ops of the same device, under a `tf_device.launch` op. bool form_clusters; // if `unfold_batch_matmul` is true, the tf.BatchMatMul is unfolded to a set // of tfl.fully_connected ops. diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 40420eee697..f23898d9530 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -58,21 +58,10 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs, void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, mlir::OpPassManager* pass_manager) { - pass_manager->addPass(mlir::tf_executor::CreateSwitchFoldPass()); - if (pass_config.skip_control_dialect) { - // Merge islands. - pass_manager->addPass( - mlir::tf_executor::CreateTFExecutorIslandCoarseningPass()); - // Assuming island coarsening above results in a graph with a single island, - // a canonicalization can be ran to hoist the ops of the single island out. - pass_manager->addPass(mlir::createCanonicalizerPass()); - - if (pass_config.form_clusters) - pass_manager->addPass(mlir::TFDevice::CreateClusterFormationPass()); - } else { - pass_manager->addPass(mlir::CreateTFExecutorToControlDialectConversion()); - pass_manager->addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass()); - } + mlir::TF::StandardPipelineOptions standard_pipeline_options; + standard_pipeline_options.enable_inliner = false; + standard_pipeline_options.form_clusters = pass_config.form_clusters; + mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options); if (pass_config.shape_inference) { pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass()); @@ -213,13 +202,8 @@ void CreateTFLStandardPipeline(OpPassManager& pm, OpPassManager& func_pm = pm.nest<FuncOp>(); // tf_executor dialect passes - Cleaning up the IR. - func_pm.addPass(tf_executor::CreateSwitchFoldPass()); - func_pm.addPass(tf_executor::CreateTFExecutorGraphPruningPass()); - func_pm.addPass(tf_executor::CreateTFExecutorIslandCoarseningPass()); - - // more cleanup of executor dialect and raise to control flow. - pm.addPass(mlir::CreateTFExecutorToControlDialectConversion()); - pm.addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass()); + mlir::TF::StandardPipelineOptions standard_pipeline_options; + mlir::TF::CreateTFStandardPipeline(func_pm, standard_pipeline_options); // This is needed for control flow support with TF TensorList. pm.addPass(mlir::TFL::CreateLowerStaticTensorListPass()); diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index 62f64ab63b4..38b96cf833f 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -38,12 +38,6 @@ limitations under the License. #include "tensorflow/lite/tools/optimize/quantize_weights.h" #include "tensorflow/stream_executor/lib/statusor.h" -namespace mlir { -/// Create a pass to convert from the TFExecutor to the TF control dialect. -std::unique_ptr<OperationPass<FuncOp>> -CreateTFExecutorToControlDialectConversion(); -} // namespace mlir - namespace tensorflow { using mlir::MLIRContext; diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 40add34393b..c74c13de0c2 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -431,7 +431,6 @@ cc_library( "transforms/optimize_global_tensors.cc", "transforms/parallel_execute_to_islands.cc", "transforms/promote_resources_to_args.cc", - "transforms/raise_control_flow.cc", "transforms/readonly_references_to_resources.cc", "transforms/replicate_invariant_op_hoisting.cc", "transforms/replicate_to_island.cc", @@ -460,7 +459,6 @@ cc_library( "transforms/tpu_variable_runtime_reformatting.cc", "translate/breakup-islands.cc", "translate/control_to_executor_dialect.cc", - "translate/executor_to_control_dialect.cc", "translate/tf_functional_to_executor.cc", ], hdrs = [ diff --git a/tensorflow/compiler/mlir/tensorflow/tests/empty-main.mlir b/tensorflow/compiler/mlir/tensorflow/tests/empty-main.mlir index 4a4aa277067..b5a9b84bc4a 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/empty-main.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/empty-main.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -tf-executor-to-control-conversion %s | FileCheck %s --check-prefix=CONTROL --dump-input=fail +// RUN: tf-opt -tf-executor-graph-pruning %s | FileCheck %s --check-prefix=CONTROL --dump-input=fail // RUN: tf-opt -tf-control-to-executor-conversion %s | FileCheck %s --check-prefix=EXECUTOR --dump-input=fail // CONTROL-LABEL: func @main diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_to_control_dialect.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_to_control_dialect.mlir deleted file mode 100644 index 5ecef050055..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/tests/executor_to_control_dialect.mlir +++ /dev/null @@ -1,188 +0,0 @@ -// RUN: tf-opt -tf-executor-to-control-conversion %s | FileCheck %s --dump-input=fail -// CHECK-LABEL: func @LoopTest() { -func @LoopTest() { - tf_executor.graph { - %0:2 = tf_executor.island { - %cst = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32> - tf_executor.yield %cst : tensor<i32> - } - %1:2 = tf_executor.Enter %0#0 frame "while/while_context" : (tensor<i32>) -> (tensor<*xi32>, !tf_executor.control) {T = "tfdtype$DT_INT32", device = "", name = "while/Enter"} - %2 = tf_executor.island { - "tf.NoOp"() {device = "", name = "cluster/pivot"} : () -> () - tf_executor.yield - } - %3:3 = tf_executor.NextIteration.Source : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"} - %4:3 = tf_executor.Merge %3#0, %1#0 : tensor<*xi32> {N = 2 : i64, T = "tfdtype$DT_INT32", device = "", name = "while/Merge"} - %5:2 = tf_executor.island(%4#2) { - %cst = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/Less/y", value = dense<2> : tensor<i32>} : () -> tensor<i32> - tf_executor.yield %cst : tensor<i32> - } - %6:2 = tf_executor.island { - %14 = "tf.Less"(%4#0, %5#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Less"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1> - tf_executor.yield %14 : tensor<*xi1> - } - %7:2 = tf_executor.LoopCond %6#0 : (tensor<*xi1>) -> (tensor<i1>, !tf_executor.control) {device = "", name = "while/LoopCond"} - %8:3 = tf_executor.Switch %4#0, %7#0 : tensor<*xi32> {T = "tfdtype$DT_INT32", _class = ["loc = @while/Merge"], device = "", name = "while/Switch"} - %9:2 = tf_executor.Exit %8#0 : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", name = "while/Exit"} - %10:2 = tf_executor.island { - %14 = "tf.Identity"(%8#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Identity"} : (tensor<*xi32>) -> tensor<*xi32> - tf_executor.yield %14 : tensor<*xi32> - } - %11:2 = tf_executor.island(%10#1) { - %cst = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/Add/y", value = dense<3> : tensor<i32>} : () -> tensor<i32> - tf_executor.yield %cst : tensor<i32> - } - %12:2 = tf_executor.island { - %14 = "tf.Add"(%10#0, %11#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Add"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32> - tf_executor.yield %14 : tensor<*xi32> - } - %13 = tf_executor.ControlTrigger %2, %12#1, %9#1 {_tpu_replicate = "cluster", device = "", name = "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/b_sync"} - tf_executor.NextIteration.Sink [%3#1] %12#0, %13 : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"} - tf_executor.fetch - } - return -} -// CHECK-NEXT: %[[CONST:[0-9]*]]:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> (tensor<i32>, !_tf.control) -// CHECK-NEXT: %[[ENTER:[0-9]*]]:2 = "_tf.Enter"(%[[CONST]]#0) {T = "tfdtype$DT_INT32", device = "", frame_name = "while/while_context", is_constant = false, name = "while/Enter", parallel_iterations = 10 : i64} : (tensor<i32>) -> (tensor<*xi32>, !_tf.control) -// CHECK-NEXT: %[[NOOP:[0-9]*]] = "_tf.NoOp"() {device = "", name = "cluster/pivot"} : () -> !_tf.control -// CHECK-NEXT: %[[SOURCE:[0-9]*]]:2 = "_tf.NextIteration.source"() {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"} : () -> (tensor<*xi32>, !_tf.control) -// CHECK-NEXT: %[[MERGE:[0-9]*]]:3 = "_tf.Merge"(%[[SOURCE]]#0, %[[ENTER]]#0) {N = 2 : i64, T = "tfdtype$DT_INT32", device = "", name = "while/Merge"} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<i32>, !_tf.control) -// CHECK-NEXT: %[[CONST_LESS:[0-9]*]]:2 = "_tf.Const"(%[[MERGE]]#2) {device = "", dtype = "tfdtype$DT_INT32", name = "while/Less/y", value = dense<2> : tensor<i32>} : (!_tf.control) -> (tensor<i32>, !_tf.control) -// CHECK-NEXT: %[[LESS:[0-9]*]]:2 = "_tf.Less"(%[[MERGE]]#0, %[[CONST_LESS]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Less"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi1>, !_tf.control) -// CHECK-NEXT: %[[COND:[0-9]*]]:2 = "_tf.LoopCond"(%[[LESS]]#0) {device = "", name = "while/LoopCond"} : (tensor<*xi1>) -> (tensor<i1>, !_tf.control) -// CHECK-NEXT: %[[SWITCH:[0-9]*]]:3 = "_tf.Switch"(%[[MERGE]]#0, %[[COND]]#0) {T = "tfdtype$DT_INT32", _class = ["loc = @while/Merge"], device = "", name = "while/Switch"} : (tensor<*xi32>, tensor<i1>) -> (tensor<*xi32>, tensor<*xi32>, !_tf.control) -// CHECK-NEXT: %[[EXIT:[0-9]*]]:2 = "_tf.Exit"(%[[SWITCH]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Exit"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control) -// CHECK-NEXT: %[[IDENTITY:[0-9]*]]:2 = "_tf.Identity"(%[[SWITCH]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Identity"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control) -// CHECK-NEXT: %[[CONST_ADD:[0-9]*]]:2 = "_tf.Const"(%[[IDENTITY]]#1) {device = "", dtype = "tfdtype$DT_INT32", name = "while/Add/y", value = dense<3> : tensor<i32>} : (!_tf.control) -> (tensor<i32>, !_tf.control) -// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = "_tf.Add"(%[[IDENTITY]]#0, %[[CONST_ADD]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Add"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi32>, !_tf.control) -// CHECK-NEXT: %[[CT:[0-9]*]] = "_tf.ControlTrigger"(%[[NOOP]], %[[ADD]]#1, %[[EXIT]]#1) {_tpu_replicate = "cluster", device = "", name = "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/b_sync"} : (!_tf.control, !_tf.control, !_tf.control) -> !_tf.control -// CHECK-NEXT: %[[SINK:[0-9]*]] = "_tf.NextIteration.sink"(%[[ADD]]#0, %[[CT]]) {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"} : (tensor<*xi32>, !_tf.control) -> !_tf.control -// CHECK-NEXT: return - -// ----- - -// CHECK-LABEL: func @multiple_ops_region -func @multiple_ops_region(%arg0 : tensor<*xi32>, %arg1 : tensor<i32>) { - tf_executor.graph { - %0:2 = tf_executor.island { - // The 4 operations are independent, but the current conversion will add - // control dependencies conservatively. - %1 = "tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add1"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32> - %2 = "tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add2"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32> - %3 = "tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add3"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32> - %4 = "tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add4"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32> - tf_executor.yield %4 : tensor<*xi32> - } - tf_executor.fetch - } - return -} -// CHECK-NEXT: %[[ADD1:[0-9]*]]:2 = "_tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add1"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi32>, !_tf.control) -// CHECK-NEXT: %[[ADD2:[0-9]*]]:2 = "_tf.Add"(%arg0, %arg1, %[[ADD1]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add2"} : (tensor<*xi32>, tensor<i32>, !_tf.control) -> (tensor<*xi32>, !_tf.control) -// CHECK-NEXT: %[[ADD3:[0-9]*]]:2 = "_tf.Add"(%arg0, %arg1, %[[ADD2]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add3"} : (tensor<*xi32>, tensor<i32>, !_tf.control) -> (tensor<*xi32>, !_tf.control) -// CHECK-NEXT: %[[ADD4:[0-9]*]]:2 = "_tf.Add"(%arg0, %arg1, %[[ADD3]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add4"} : (tensor<*xi32>, tensor<i32>, !_tf.control) -> (tensor<*xi32>, !_tf.control) - -// ----- - -// CHECK-LABEL: func @switchN( -func @switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { - %fetches = tf_executor.graph { - // CHECK: [[S1:%.*]]:6 = "_tf._SwitchN"(%arg1, %arg0) {num_outs = 5 : i64} - %1:6 = tf_executor.SwitchN %arg1, %arg0 of 5 : tensor<*xf32> - // CHECK: "_tf._SwitchN"(%arg1, %arg0, [[S1]]#5) {num_outs = 12 : i64} - %2:13 = tf_executor.SwitchN %arg1, %arg0 of 12 (%1#5) : tensor<*xf32> - tf_executor.fetch %2#0 : tensor<*xf32> - } - return %fetches : tensor<*xf32> -} - -// ----- - -// Test if tf_executor dialect ops with Ref types are mapped correctly to the ops in control dialect. -// CHECK-LABEL: func @ref_tf_executor_ops -func @ref_tf_executor_ops(%arg0: tensor<4x!tf.f32ref>, %arg1: tensor<4x!tf.f32ref>, %arg3: tensor<i32>, %arg4: tensor<i1> ) -> tensor<4x!tf.f32ref> { - %result = tf_executor.graph { - // CHECK: _tf.Enter - %0:2 = tf_executor.Enter %arg0 frame "while/while_context" : (tensor<4x!tf.f32ref>) -> (tensor<4x!tf.f32ref>, !tf_executor.control) - // CHECK: _tf.Exit - %1:2 = tf_executor.Exit %arg0 : tensor<4x!tf.f32ref> - // CHECK: _tf.Switch - %2:3 = tf_executor.Switch %arg0, %arg4 : (tensor<4x!tf.f32ref>, tensor<i1>) -> (tensor<4x!tf.f32ref>, tensor<4x!tf.f32ref>, !tf_executor.control) - // CHECK: _tf.Merge - %3:3 = tf_executor.Merge %arg0, %arg1 : (tensor<4x!tf.f32ref>, tensor<4x!tf.f32ref>) -> (tensor<4x!tf.f32ref>, tensor<i32>, !tf_executor.control) - // CHECK: _tf.NextIteration.source - %4:3 = tf_executor.NextIteration.Source : tensor<4x!tf.f32ref> - // CHECK: _tf.NextIteration.sink - tf_executor.NextIteration.Sink [%4#1] %4#0 : tensor<4x!tf.f32ref> - tf_executor.fetch %0#0 : tensor<4x!tf.f32ref> - } - return %result : tensor<4x!tf.f32ref> -} - -// ----- - -// Tests if empty island with just one control dependency input and output is -// handled correctly. -// CHECK-LABEL: func @empty_island_control_dep_only -func @empty_island_control_dep_only() -> tensor<i32> { - %fetch = tf_executor.graph { - %0:2 = tf_executor.island { - %4 = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32> - tf_executor.yield %4 : tensor<i32> - } - // CHECK-NEXT: %[[CONST1:[0-9]*]]:2 = "_tf.Const"() - // CHECK-SAME: () -> (tensor<i32>, !_tf.control) - %1:2 = tf_executor.island { - %5 = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32> - tf_executor.yield %5 : tensor<i32> - } - // CHECK-NEXT: %[[CONST2:[0-9]*]]:2 = "_tf.Const"() - // CHECK-SAME: () -> (tensor<i32>, !_tf.control) - %2 = tf_executor.island(%0#1) { - tf_executor.yield - } - %3:2 = tf_executor.island(%2, %1#1) { - %6 = "tf.Add"(%0#0, %1#0) : (tensor<i32>, tensor<i32>) -> tensor<i32> - tf_executor.yield %6 : tensor<i32> - } - // CHECK-NEXT: %[[ADD:[0-9]*]]:2 = "_tf.Add"(%[[CONST1]]#0, %[[CONST2]]#0, %[[CONST1]]#1, %[[CONST2]]#1) - // CHECK-SAME: (tensor<i32>, tensor<i32>, !_tf.control, !_tf.control) -> (tensor<i32>, !_tf.control) - tf_executor.fetch %3#0 : tensor<i32> - } - return %fetch : tensor<i32> -} - -// ----- - -// Tests if empty island with multiple control inputs will be replaced with a -// no-op. -// CHECK-LABEL: func @empty_island_multi_control_inputs -func @empty_island_multi_control_inputs() -> tensor<i32> { - %fetch = tf_executor.graph { - %0:2 = tf_executor.island { - %4 = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32> - tf_executor.yield %4 : tensor<i32> - } - // CHECK-NEXT: %[[CONST1:[0-9]*]]:2 = "_tf.Const"() - // CHECK-SAME: () -> (tensor<i32>, !_tf.control) - %1:2 = tf_executor.island { - %5 = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32> - tf_executor.yield %5 : tensor<i32> - } - // CHECK-NEXT: %[[CONST2:[0-9]*]]:2 = "_tf.Const"() - // CHECK-SAME: () -> (tensor<i32>, !_tf.control) - %2 = tf_executor.island(%0#1, %1#1) { - tf_executor.yield - } - // CHECK-NEXT: %[[NOOP:[0-9]*]] = "_tf.NoOp"(%[[CONST1]]#1, %[[CONST2]]#1) - // CHECK-SAME: (!_tf.control, !_tf.control) -> !_tf.control - %3:2 = tf_executor.island(%2) { - %6 = "tf.Add"(%0#0, %1#0) : (tensor<i32>, tensor<i32>) -> tensor<i32> - tf_executor.yield %6 : tensor<i32> - } - // CHECK-NEXT: %[[ADD:[0-9]*]]:2 = "_tf.Add"(%[[CONST1]]#0, %[[CONST2]]#0, %[[NOOP]]) - // CHECK-SAME: (tensor<i32>, tensor<i32>, !_tf.control) -> (tensor<i32>, !_tf.control) - tf_executor.fetch %3#0 : tensor<i32> - } - return %fetch : tensor<i32> -} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/raise-control-flow.mlir b/tensorflow/compiler/mlir/tensorflow/tests/raise-control-flow.mlir deleted file mode 100644 index a6c7bdd72ed..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/tests/raise-control-flow.mlir +++ /dev/null @@ -1,57 +0,0 @@ -// RUN: tf-opt %s -tf-raise-control-flow -split-input-file | FileCheck %s - -// Test that we remove underscores. - -// CHECK-LABEL: func @testSimpleAddsAndIdentity(%arg0: tensor<*xf32>) -func @testSimpleAddsAndIdentity(tensor<*xf32>) -> tensor<*xf32> { -^bb0(%0: tensor<*xf32>): - - // CHECK: %0 = "tf.Identity"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> - %1 = "_tf.Identity"(%0) : (tensor<*xf32>) -> tensor<*xf32> - - // CHECK: %1 = "tf.Add"(%arg0, %arg0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - %2 = "_tf.Add"(%0, %0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - - // CHECK: %2 = "tf.Add"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - %3 = "_tf.Add"(%1, %2) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - - // CHECK: return %2 : tensor<*xf32> - return %3 : tensor<*xf32> -} - -// CHECK-LABEL: func @testAddWithControlDependency(%arg0: tensor<*xf32>) -func @testAddWithControlDependency(tensor<*xf32>) -> tensor<*xf32> { -^bb0(%0: tensor<*xf32>): - - // CHECK: %0 = "tf.Identity"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> - %1:2 = "_tf.Identity"(%0) : (tensor<*xf32>) -> (tensor<*xf32>, !_tf.control) - - // CHECK: %1 = "tf.Add"(%arg0, %arg0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - %2:2 = "_tf.Add"(%0, %0, %1#1) : (tensor<*xf32>, tensor<*xf32>, !_tf.control) -> (tensor<*xf32>, !_tf.control) - - // CHECK: %2 = "tf.Add"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - %3:2 = "_tf.Add"(%1#0, %2, %1#1, %2#1) : (tensor<*xf32>, tensor<*xf32>, !_tf.control, !_tf.control) -> (tensor<*xf32>, !_tf.control) - - // CHECK: return %2 : tensor<*xf32> - return %3 : tensor<*xf32> -} - -// TODO(clattner): simplify and expand these tests. This is mostly a placeholder. -func @LoopTest() { - %0:2 = "_tf.Const"() {device = "", name = "Const", dtype = "tfdtype$DT_INT32", value = dense<1> : tensor<i32>} : () -> (tensor<i32>, !_tf.control) - %1:2 = "_tf.Enter"(%0#0) {device = "", name = "while/Enter", T = "tfdtype$DT_INT32", frame_name = "while/while_context", is_constant = false, parallel_iterations = 10} : (tensor<i32>) -> (tensor<*xi32>, !_tf.control) - - %11:2 = "_tf.NextIteration.source"() {device = "", name = "while/NextIteration", T = "tfdtype$DT_INT32", id = 0} : () -> (tensor<*xi32>, !_tf.control) - - %2:3 = "_tf.Merge"(%11#0, %1#0) {device = "", name = "while/Merge", N = 2, T = "tfdtype$DT_INT32"} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<i32>, !_tf.control) - %3:2 = "_tf.Const"(%2#2) {device = "", name = "while/Less/y", dtype = "tfdtype$DT_INT32", value = dense<2> : tensor<i32>} : (!_tf.control) -> (tensor<i32>, !_tf.control) - %4:2 = "_tf.Less"(%2#0, %3#0) {device = "", name = "while/Less", T = "tfdtype$DT_INT32"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi1>, !_tf.control) - %5:2 = "_tf.LoopCond"(%4#0) {device = "", name = "while/LoopCond"} : (tensor<*xi1>) -> (tensor<i1>, !_tf.control) - %6:3 = "_tf.Switch"(%2#0, %5#0) {device = "", name = "while/Switch", T = "tfdtype$DT_INT32", _class = ["loc:@while/Merge"]} : (tensor<*xi32>, tensor<i1>) -> (tensor<*xi32>, tensor<*xi32>, !_tf.control) - %7:2 = "_tf.Exit"(%6#0) {device = "", name = "while/Exit", T = "tfdtype$DT_INT32"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control) - %8:2 = "_tf.Identity"(%6#1) {device = "", name = "while/Identity", T = "tfdtype$DT_INT32"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control) - %9:2 = "_tf.Const"(%8#1) {device = "", name = "while/Add/y", dtype = "tfdtype$DT_INT32", value = dense<3> : tensor<i32>} : (!_tf.control) -> (tensor<i32>, !_tf.control) - %10:2 = "_tf.Add"(%8#0, %9#0) {device = "", name = "while/Add", T = "tfdtype$DT_INT32"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi32>, !_tf.control) - %ctl = "_tf.NextIteration.sink"(%10#0) {device = "", name = "while/NextIteration", T = "tfdtype$DT_INT32", id = 0} : (tensor<*xi32>) -> (!_tf.control) - return -} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc index 849f1487c6e..24e77d31e7c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc @@ -58,6 +58,8 @@ void CreateTFStandardPipeline(OpPassManager &pm, func_pm.addPass(tf_executor::CreateTFExecutorGraphPruningPass()); func_pm.addPass(tf_executor::CreateTFExecutorIslandCoarseningPass()); func_pm.addPass(CreateMaterializePassthroughOpPass()); + if (options.form_clusters) + func_pm.addPass(TFDevice::CreateClusterFormationPass()); // Hopefully there is a single island left, or there wasn't any to begin with. // We now run the optimizer which operates mostly inside islands. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 08c95bd8b0e..5ca3b3fc06c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -77,6 +77,9 @@ struct StandardPipelineOptions Option<bool> enable_inliner{*this, "enable-inliner", llvm::cl::desc("Enable inliner."), llvm::cl::init(false)}; + Option<bool> form_clusters{*this, "form-clusters", + llvm::cl::desc("Enable Cluster Formation pass."), + llvm::cl::init(false)}; }; // Propagates the pass manager with the passes involved in transforming or @@ -149,13 +152,6 @@ std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass(); std::unique_ptr<OperationPass<FuncOp>> CreateOpFusionPass(); } // namespace TF -namespace TFControlFlow { -// Raises from the "TensorFlow Control Flow" dialect to the standard TensorFlow -// dialect. -std::unique_ptr<OperationPass<FuncOp>> CreateRaiseTFControlFlowPass(); - -} // namespace TFControlFlow - namespace tf_executor { class GraphOp; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/raise_control_flow.cc b/tensorflow/compiler/mlir/tensorflow/transforms/raise_control_flow.cc deleted file mode 100644 index ca234818e10..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/transforms/raise_control_flow.cc +++ /dev/null @@ -1,159 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file implements logic for raising from the "TensorFlow control flow" -// dialect of MLIR to the standard TensorFlow dialect. The TensorFlow control -// flow dialect represents control flow with Switch/Merge and a few related -// control flow nodes, along with control dependencies. -// -// This pass rebuilds them code in terms of MLIR branches and blocks, -// eliminating control dependencies, and results in the code being in the -// canonical TensorFlow dialect. - -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" - -namespace mlir { -namespace TFControlFlow { - -namespace { -struct RaiseTFControlFlow - : public PassWrapper<RaiseTFControlFlow, FunctionPass> { - void runOnFunction() { - // First start by recognizing loops and reconstructing a loop tree. - buildLoopNests(); - - // Next, transform Switch/Merge and other control flow ops into proper - // conditional control flow. - buildConditionals(); - - // Now that we have proper conditional control flow ops, the control edges - // can be dropped, and the underscores removed from operation names. - rewriteOps(); - } - - void buildLoopNests(); - void buildConditionals(); - void rewriteOps(); -}; - -//===----------------------------------------------------------------------===// -// Loop nest reconstruction -//===----------------------------------------------------------------------===// - -void RaiseTFControlFlow::buildLoopNests() { - // TODO(clattner) -} - -//===----------------------------------------------------------------------===// -// Conditional Reconstruction -//===----------------------------------------------------------------------===// - -void RaiseTFControlFlow::buildConditionals() { - // TODO. -} - -//===----------------------------------------------------------------------===// -// Final rewrite from TF Control Flow form to canonical TensorFlow form -//===----------------------------------------------------------------------===// - -static bool isUnderscoredTFOp(Operation &op) { - return op.getName().getStringRef().startswith("_tf."); -} - -// Drop control edges, and remove underscores from operation names. -void RaiseTFControlFlow::rewriteOps() { - auto function = getFunction(); - OpBuilder builder(function.getBody()); - - // On the first pass, create replacement operations for every one we are going - // to replace, updating anything that uses the normal results with the newly - // created operation. - for (auto &bb : function) { - for (auto &op : bb) { - // Ignore any operations that we aren't looking for. - if (!isUnderscoredTFOp(op)) continue; - - // We always insert the replacement operation next to the operation it - // is replacing. - builder.setInsertionPoint(&op); - - // Drop the leading _ off the name. - OperationState result(op.getLoc(), - op.getName().getStringRef().drop_front()); - - // Add an operand for each non-control input we find. Control values - // aren't necessary any more since the order within a block encodes the - // same information. - for (auto &operand : op.getOpOperands()) { - if (!operand.get().getType().isa<TFControlType>()) - result.operands.push_back(operand.get()); - - // Drop all operands from the old operation, eliminating any - // inter-dependencies after this pass. - operand.drop(); - } - - // Add a result type for each non-control result we find. - bool sawControlResult = false; - for (auto opResult : op.getResults()) { - if (opResult.getType().isa<TFControlType>()) { - sawControlResult = true; - } else { - // We assume all control inputs are at the end of the result list. - assert(!sawControlResult && "all control results must be last"); - (void)sawControlResult; - result.types.push_back(opResult.getType()); - } - } - - result.attributes.append(op.getAttrs().begin(), op.getAttrs().end()); - - // Create the replacement operation. - auto *replacement = builder.createOperation(result); - - // We know that all the control results are last, so we can just rewrite - // the first results. - for (unsigned i = 0, e = result.types.size(); i != e; ++i) - op.getResult(i).replaceAllUsesWith(replacement->getResult(i)); - } - } - - // In the second pass, we can safely remove all of the old operations, because - // we know that all inter-dependencies are dropped. - for (auto &bb : function) { - // Advance the iterator so we don't invalidate it when we remove an - // operation later in the loop. - for (auto &op : llvm::make_early_inc_range(bb)) - if (isUnderscoredTFOp(op)) op.erase(); - } -} - -} // namespace - -std::unique_ptr<OperationPass<FuncOp>> CreateRaiseTFControlFlowPass() { - return std::make_unique<RaiseTFControlFlow>(); -} - -static PassRegistration<RaiseTFControlFlow> pass( - "tf-raise-control-flow", - "Raise from the TensorFlow Control Flow " - "dialect to the standard TensorFlow dialect"); - -} // namespace TFControlFlow -} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc b/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc deleted file mode 100644 index 481f1fac7b8..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc +++ /dev/null @@ -1,242 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This transformation pass transforms from TF executor dialect to MLIR TF -// control dialect. - -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/Sequence.h" -#include "llvm/ADT/SmallString.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/ErrorHandling.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassRegistry.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" - -#define DEBUG_TYPE "tf-executor-to-ctl" - -namespace mlir { - -namespace { -struct ExecutorToControlDialectConversion - : public PassWrapper<ExecutorToControlDialectConversion, FunctionPass> { - void runOnFunction() override; -}; -} // end anonymous namespace - -static bool HasSingleGraph(FuncOp function) { - // We expect the function has only one region with one block, - if (function.getBlocks().size() != 1) return false; - auto &block = function.front(); - // and the block contains two ops, - if (std::next(block.begin()) == block.end()) return false; - // one GraphOp, - if (!isa<tf_executor::GraphOp>(block.begin())) return false; - // followed by a terminator. - if (!std::next(block.begin())->isKnownTerminator()) return false; - return true; -} - -void ExecutorToControlDialectConversion::runOnFunction() { - if (!HasSingleGraph(getFunction())) { - LLVM_DEBUG(llvm::dbgs() - << "Expect a Function with a single block and a single graph op," - " skip tf_executor dialect conversion\n"); - return; - } - Type control_type = TFControlFlow::TFControlType::get(&getContext()); - - Block &body = getFunction().front(); - auto graph = cast<tf_executor::GraphOp>(body.front()); - OpBuilder builder = OpBuilder::atBlockEnd(&body); - SmallString<64> new_op_name; - for (auto &op : llvm::make_early_inc_range(llvm::reverse(graph.GetBody()))) { - LLVM_DEBUG(llvm::dbgs() << "Process: " << op.getName() << "\n"); - - if (auto fetch = dyn_cast<tf_executor::FetchOp>(op)) { - // Replace all the operands of the fetch op with the uses of the graph - // results, remove the fetch op afterwards. - for (auto ops_and_ret_vals : - llvm::zip(graph.getResults(), fetch.getOperands())) - std::get<0>(ops_and_ret_vals) - .replaceAllUsesWith(std::get<1>(ops_and_ret_vals)); - op.erase(); - continue; - } - - builder.setInsertionPoint(&op); - - if (auto island = dyn_cast<tf_executor::IslandOp>(op)) { - Value ctl_sequence = nullptr; - if (island.GetBody().without_terminator().empty() && - island.getNumOperands() > 1) { - // For an empty island with multiple control inputs, we create a no-op - // inside it which will group all the inputs into one control output. - // This helps reducing the number of edges when there are multiple - // islands depending on this one. - builder.setInsertionPointToStart(&island.GetBody()); - builder.create<TF::NoOp>(op.getLoc(), ArrayRef<Type>{}, - ArrayRef<Value>{}, ArrayRef<NamedAttribute>{}); - builder.setInsertionPoint(&op); - } - for (Operation &wrapped_op : island.GetBody()) { - LLVM_DEBUG(llvm::dbgs() - << " In island: " << wrapped_op.getName() << "\n"); - if (isa<tf_executor::YieldOp>(wrapped_op)) { - for (auto ops_and_ret_vals : - llvm::zip(island.getResults(), wrapped_op.getOperands())) - std::get<0>(ops_and_ret_vals) - .replaceAllUsesWith(std::get<1>(ops_and_ret_vals)); - break; - } - // Add a leading _ off the name. - new_op_name = "_"; - new_op_name += wrapped_op.getName().getStringRef(); - OperationState state(wrapped_op.getLoc(), new_op_name); - - // Add an operand for each non-control input we find. Collect control - // values separately to add them to the island operands - state.operands.append(wrapped_op.getOperands().begin(), - wrapped_op.getOperands().end()); - - // Chain operations through a control dependency, except for the first - // operations in the sequence that carry the control dependencies held - // by the island itself. - if (ctl_sequence) { - state.operands.push_back(ctl_sequence); - } else { - for (Value ctl_operand : island.getOperands()) - state.operands.push_back(ctl_operand); - } - - // Add a result type for each result - state.types.append(wrapped_op.getResultTypes().begin(), - wrapped_op.getResultTypes().end()); - state.types.push_back(control_type); - - // Create the replacement operation. - auto *replacement = builder.createOperation(state); - replacement->setAttrs(wrapped_op.getMutableAttrDict()); - - for (auto ops_and_ret_vals : - llvm::zip(wrapped_op.getResults(), replacement->getResults())) - std::get<0>(ops_and_ret_vals) - .replaceAllUsesWith(std::get<1>(ops_and_ret_vals)); - - ctl_sequence = replacement->getResult(replacement->getNumResults() - 1); - } - - if (ctl_sequence) { - // If ctl_sequence is non-null, this means at least one operation has - // been rewritten from ops in island. Last op rewritten must logically - // carry // all the island control inputs, we can simply use it to - // replace all uses of island's control output. - island.control().replaceAllUsesWith(ctl_sequence); - } else if (island.getNumOperands() > 0) { - // Getting here means island had an effectively empty body and there is - // just one control input. In this case, island's control output should - // be replaced with the control input. - assert(island.getNumOperands() == 1); - island.control().replaceAllUsesWith(island.getOperand(0)); - } - - op.erase(); - continue; - } - - new_op_name.clear(); - if (isa<tf_executor::SwitchOp>(op)) { - new_op_name = "_tf.Switch"; - } else if (isa<tf_executor::SwitchNOp>(op)) { - new_op_name = "_tf._SwitchN"; - } else if (isa<tf_executor::MergeOp>(op)) { - new_op_name = "_tf.Merge"; - } else if (isa<tf_executor::NextIterationSourceOp>(op)) { - new_op_name = "_tf.NextIteration.source"; - } else if (isa<tf_executor::NextIterationSinkOp>(op)) { - new_op_name = "_tf.NextIteration.sink"; - } else if (isa<tf_executor::LoopCondOp>(op)) { - new_op_name = "_tf.LoopCond"; - } else if (isa<tf_executor::EnterOp>(op)) { - new_op_name = "_tf.Enter"; - } else if (isa<tf_executor::ExitOp>(op)) { - new_op_name = "_tf.Exit"; - } else if (isa<tf_executor::ControlTriggerOp>(op)) { - new_op_name = "_tf.ControlTrigger"; - } else { - op.emitOpError() << "unhandled op in tf_executor to _tf conversion"; - return signalPassFailure(); - } - OperationState state(op.getLoc(), new_op_name); - // Drop all TokenType operands since they don't exist in the control - // dialect. - auto non_null_operands = llvm::make_filter_range( - op.getOperands(), - [](Value v) { return !v.getType().isa<tf_executor::TokenType>(); }); - state.operands.append(non_null_operands.begin(), non_null_operands.end()); - for (Type result_type : op.getResultTypes()) { - // Filter out TokenType, they don't exist in the control dialect. - if (result_type.isa<tf_executor::TokenType>()) continue; - if (!result_type.isa<tf_executor::ControlType>()) - state.types.push_back(result_type); - else - state.types.push_back(control_type); - } - // The control dialect has a control result for the sink operation. - if (isa<tf_executor::NextIterationSinkOp>(op)) - state.types.push_back(control_type); - - // Create the replacement operation. - auto *replacement = builder.createOperation(state); - replacement->setAttrs(op.getMutableAttrDict()); - - if (auto next_iteration = - dyn_cast<tf_executor::NextIterationSourceOp>(op)) { - next_iteration.output().replaceAllUsesWith(replacement->getResult(0)); - next_iteration.token().dropAllUses(); - next_iteration.control().replaceAllUsesWith(replacement->getResult(1)); - } else { - for (auto ops_and_ret_vals : - llvm::zip(op.getResults(), replacement->getResults())) - std::get<0>(ops_and_ret_vals) - .replaceAllUsesWith(std::get<1>(ops_and_ret_vals)); - } - op.erase(); - } - - // Now we have rewritten all ops inside GraphOp to TF Control dialect. We need - // to move all operations outside of GraphOp and remove it. - body.getOperations().splice(body.begin(), graph.GetBody().getOperations()); - graph.erase(); -} - -std::unique_ptr<OperationPass<FuncOp>> -CreateTFExecutorToControlDialectConversion() { - return std::make_unique<ExecutorToControlDialectConversion>(); -} - -} // namespace mlir - -static mlir::PassRegistration<mlir::ExecutorToControlDialectConversion> pass( - "tf-executor-to-control-conversion", - "Convert from TF executor dialect to TF control dialect"); diff --git a/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc b/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc index a445937570e..d48d90997de 100644 --- a/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc +++ b/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc @@ -23,12 +23,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tfjs/transforms/passes.h" -namespace mlir { -/// Create a pass to convert from the TFExecutor to the TF control dialect. -std::unique_ptr<OperationPass<FuncOp>> -CreateTFExecutorToControlDialectConversion(); -} // namespace mlir - namespace tensorflow { void AddTFToTFJSConversionPasses(mlir::OpPassManager* pm) {