- Remove executor_to_control pass.

- Remove raise control flow pass.
- Clean up usage in TFLite and other references.
- Remove skip_control_dialect member in PassConfig.

PiperOrigin-RevId: 315552807
Change-Id: I4994f6a3c26cbe4845b97e7933272a860d3f15c2
Karim Nosir 2020-06-09 13:51:17 -07:00 committed by TensorFlower Gardener
parent b8a267a9fe
commit 191628f0e5
12 changed files with 14 additions and 698 deletions

View File

@@ -32,7 +32,6 @@ struct PassConfig {
lower_tensor_list_ops(false),
trim_functions_whitelist({}),
quant_specs(std::move(specs)),
skip_control_dialect(false),
form_clusters(false),
unfold_batch_matmul(true),
legalize_tf_while(true),
@@ -49,13 +48,8 @@ struct PassConfig {
llvm::ArrayRef<std::string> trim_functions_whitelist;
// All information about quantization.
QuantizationSpecs quant_specs;
// If `skip_control_dialect` is true, TF executor dialect is not converted to
// TF control dialect prior to legalization to TF Lite.
// TODO(b/142911013): Remove flag once control dialect is removed.
bool skip_control_dialect;
// If `form_clusters` is true (and `skip_control_dialect` is true), clusters
// are formed by grouping consecutive ops of the same device, under a
// `tf_device.launch` op.
// If `form_clusters` is true, clusters are formed by grouping consecutive
// ops of the same device, under a `tf_device.launch` op.
bool form_clusters;
// if `unfold_batch_matmul` is true, the tf.BatchMatMul is unfolded to a set
// of tfl.fully_connected ops.
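
With `skip_control_dialect` removed, `form_clusters` is consulted on its own. A minimal usage sketch, assuming `PassConfig` keeps the constructor implied by the initializer list above:

    mlir::TFL::QuantizationSpecs specs;
    mlir::TFL::PassConfig pass_config(std::move(specs));
    // Group consecutive same-device ops under a tf_device.launch op.
    pass_config.form_clusters = true;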

View File

@@ -58,21 +58,10 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
mlir::OpPassManager* pass_manager) {
pass_manager->addPass(mlir::tf_executor::CreateSwitchFoldPass());
if (pass_config.skip_control_dialect) {
// Merge islands.
pass_manager->addPass(
mlir::tf_executor::CreateTFExecutorIslandCoarseningPass());
// Assuming island coarsening above results in a graph with a single island,
// a canonicalization can be run to hoist the ops of the single island out.
pass_manager->addPass(mlir::createCanonicalizerPass());
if (pass_config.form_clusters)
pass_manager->addPass(mlir::TFDevice::CreateClusterFormationPass());
} else {
pass_manager->addPass(mlir::CreateTFExecutorToControlDialectConversion());
pass_manager->addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass());
}
mlir::TF::StandardPipelineOptions standard_pipeline_options;
standard_pipeline_options.enable_inliner = false;
standard_pipeline_options.form_clusters = pass_config.form_clusters;
mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options);
if (pass_config.shape_inference) {
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
@@ -213,13 +202,8 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
OpPassManager& func_pm = pm.nest<FuncOp>();
// tf_executor dialect passes - Cleaning up the IR.
func_pm.addPass(tf_executor::CreateSwitchFoldPass());
func_pm.addPass(tf_executor::CreateTFExecutorGraphPruningPass());
func_pm.addPass(tf_executor::CreateTFExecutorIslandCoarseningPass());
// more cleanup of executor dialect and raise to control flow.
pm.addPass(mlir::CreateTFExecutorToControlDialectConversion());
pm.addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass());
mlir::TF::StandardPipelineOptions standard_pipeline_options;
mlir::TF::CreateTFStandardPipeline(func_pm, standard_pipeline_options);
// This is needed for control flow support with TF TensorList.
pm.addPass(mlir::TFL::CreateLowerStaticTensorListPass());
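
Both conversion entry points above now funnel executor-dialect cleanup through the TF standard pipeline rather than the control dialect. A hedged sketch of invoking the simplified pipeline, assuming `AddTFToTFLConversionPasses` lives in the `tensorflow` namespace and keeps the signature shown at the top of this file:

    mlir::MLIRContext context;
    mlir::PassManager pm(&context);  // a PassManager is an OpPassManager
    mlir::TFL::QuantizationSpecs specs;
    mlir::TFL::PassConfig pass_config(std::move(specs));
    tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
    // The pipeline starts with the switch-fold pass and proceeds directly to
    // CreateTFStandardPipeline; no executor-to-control conversion is scheduled.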

View File

@@ -38,12 +38,6 @@ limitations under the License.
#include "tensorflow/lite/tools/optimize/quantize_weights.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace mlir {
/// Create a pass to convert from the TFExecutor to the TF control dialect.
std::unique_ptr<OperationPass<FuncOp>>
CreateTFExecutorToControlDialectConversion();
} // namespace mlir
namespace tensorflow {
using mlir::MLIRContext;

View File

@@ -431,7 +431,6 @@ cc_library(
"transforms/optimize_global_tensors.cc",
"transforms/parallel_execute_to_islands.cc",
"transforms/promote_resources_to_args.cc",
"transforms/raise_control_flow.cc",
"transforms/readonly_references_to_resources.cc",
"transforms/replicate_invariant_op_hoisting.cc",
"transforms/replicate_to_island.cc",
@@ -460,7 +459,6 @@ cc_library(
"transforms/tpu_variable_runtime_reformatting.cc",
"translate/breakup-islands.cc",
"translate/control_to_executor_dialect.cc",
"translate/executor_to_control_dialect.cc",
"translate/tf_functional_to_executor.cc",
],
hdrs = [

View File

@@ -1,4 +1,4 @@
// RUN: tf-opt -tf-executor-to-control-conversion %s | FileCheck %s --check-prefix=CONTROL --dump-input=fail
// RUN: tf-opt -tf-executor-graph-pruning %s | FileCheck %s --check-prefix=CONTROL --dump-input=fail
// RUN: tf-opt -tf-control-to-executor-conversion %s | FileCheck %s --check-prefix=EXECUTOR --dump-input=fail
// CONTROL-LABEL: func @main

View File

@@ -1,188 +0,0 @@
// RUN: tf-opt -tf-executor-to-control-conversion %s | FileCheck %s --dump-input=fail
// CHECK-LABEL: func @LoopTest() {
func @LoopTest() {
tf_executor.graph {
%0:2 = tf_executor.island {
%cst = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32>
tf_executor.yield %cst : tensor<i32>
}
%1:2 = tf_executor.Enter %0#0 frame "while/while_context" : (tensor<i32>) -> (tensor<*xi32>, !tf_executor.control) {T = "tfdtype$DT_INT32", device = "", name = "while/Enter"}
%2 = tf_executor.island {
"tf.NoOp"() {device = "", name = "cluster/pivot"} : () -> ()
tf_executor.yield
}
%3:3 = tf_executor.NextIteration.Source : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"}
%4:3 = tf_executor.Merge %3#0, %1#0 : tensor<*xi32> {N = 2 : i64, T = "tfdtype$DT_INT32", device = "", name = "while/Merge"}
%5:2 = tf_executor.island(%4#2) {
%cst = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/Less/y", value = dense<2> : tensor<i32>} : () -> tensor<i32>
tf_executor.yield %cst : tensor<i32>
}
%6:2 = tf_executor.island {
%14 = "tf.Less"(%4#0, %5#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Less"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
tf_executor.yield %14 : tensor<*xi1>
}
%7:2 = tf_executor.LoopCond %6#0 : (tensor<*xi1>) -> (tensor<i1>, !tf_executor.control) {device = "", name = "while/LoopCond"}
%8:3 = tf_executor.Switch %4#0, %7#0 : tensor<*xi32> {T = "tfdtype$DT_INT32", _class = ["loc = @while/Merge"], device = "", name = "while/Switch"}
%9:2 = tf_executor.Exit %8#0 : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", name = "while/Exit"}
%10:2 = tf_executor.island {
%14 = "tf.Identity"(%8#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Identity"} : (tensor<*xi32>) -> tensor<*xi32>
tf_executor.yield %14 : tensor<*xi32>
}
%11:2 = tf_executor.island(%10#1) {
%cst = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/Add/y", value = dense<3> : tensor<i32>} : () -> tensor<i32>
tf_executor.yield %cst : tensor<i32>
}
%12:2 = tf_executor.island {
%14 = "tf.Add"(%10#0, %11#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Add"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
tf_executor.yield %14 : tensor<*xi32>
}
%13 = tf_executor.ControlTrigger %2, %12#1, %9#1 {_tpu_replicate = "cluster", device = "", name = "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/b_sync"}
tf_executor.NextIteration.Sink [%3#1] %12#0, %13 : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"}
tf_executor.fetch
}
return
}
// CHECK-NEXT: %[[CONST:[0-9]*]]:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> (tensor<i32>, !_tf.control)
// CHECK-NEXT: %[[ENTER:[0-9]*]]:2 = "_tf.Enter"(%[[CONST]]#0) {T = "tfdtype$DT_INT32", device = "", frame_name = "while/while_context", is_constant = false, name = "while/Enter", parallel_iterations = 10 : i64} : (tensor<i32>) -> (tensor<*xi32>, !_tf.control)
// CHECK-NEXT: %[[NOOP:[0-9]*]] = "_tf.NoOp"() {device = "", name = "cluster/pivot"} : () -> !_tf.control
// CHECK-NEXT: %[[SOURCE:[0-9]*]]:2 = "_tf.NextIteration.source"() {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"} : () -> (tensor<*xi32>, !_tf.control)
// CHECK-NEXT: %[[MERGE:[0-9]*]]:3 = "_tf.Merge"(%[[SOURCE]]#0, %[[ENTER]]#0) {N = 2 : i64, T = "tfdtype$DT_INT32", device = "", name = "while/Merge"} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<i32>, !_tf.control)
// CHECK-NEXT: %[[CONST_LESS:[0-9]*]]:2 = "_tf.Const"(%[[MERGE]]#2) {device = "", dtype = "tfdtype$DT_INT32", name = "while/Less/y", value = dense<2> : tensor<i32>} : (!_tf.control) -> (tensor<i32>, !_tf.control)
// CHECK-NEXT: %[[LESS:[0-9]*]]:2 = "_tf.Less"(%[[MERGE]]#0, %[[CONST_LESS]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Less"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi1>, !_tf.control)
// CHECK-NEXT: %[[COND:[0-9]*]]:2 = "_tf.LoopCond"(%[[LESS]]#0) {device = "", name = "while/LoopCond"} : (tensor<*xi1>) -> (tensor<i1>, !_tf.control)
// CHECK-NEXT: %[[SWITCH:[0-9]*]]:3 = "_tf.Switch"(%[[MERGE]]#0, %[[COND]]#0) {T = "tfdtype$DT_INT32", _class = ["loc = @while/Merge"], device = "", name = "while/Switch"} : (tensor<*xi32>, tensor<i1>) -> (tensor<*xi32>, tensor<*xi32>, !_tf.control)
// CHECK-NEXT: %[[EXIT:[0-9]*]]:2 = "_tf.Exit"(%[[SWITCH]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Exit"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
// CHECK-NEXT: %[[IDENTITY:[0-9]*]]:2 = "_tf.Identity"(%[[SWITCH]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Identity"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
// CHECK-NEXT: %[[CONST_ADD:[0-9]*]]:2 = "_tf.Const"(%[[IDENTITY]]#1) {device = "", dtype = "tfdtype$DT_INT32", name = "while/Add/y", value = dense<3> : tensor<i32>} : (!_tf.control) -> (tensor<i32>, !_tf.control)
// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = "_tf.Add"(%[[IDENTITY]]#0, %[[CONST_ADD]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Add"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi32>, !_tf.control)
// CHECK-NEXT: %[[CT:[0-9]*]] = "_tf.ControlTrigger"(%[[NOOP]], %[[ADD]]#1, %[[EXIT]]#1) {_tpu_replicate = "cluster", device = "", name = "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/b_sync"} : (!_tf.control, !_tf.control, !_tf.control) -> !_tf.control
// CHECK-NEXT: %[[SINK:[0-9]*]] = "_tf.NextIteration.sink"(%[[ADD]]#0, %[[CT]]) {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"} : (tensor<*xi32>, !_tf.control) -> !_tf.control
// CHECK-NEXT: return
// -----
// CHECK-LABEL: func @multiple_ops_region
func @multiple_ops_region(%arg0 : tensor<*xi32>, %arg1 : tensor<i32>) {
tf_executor.graph {
%0:2 = tf_executor.island {
// The 4 operations are independent, but the current conversion will add
// control dependencies conservatively.
%1 = "tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add1"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%2 = "tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add2"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%3 = "tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add3"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%4 = "tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add4"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
tf_executor.yield %4 : tensor<*xi32>
}
tf_executor.fetch
}
return
}
// CHECK-NEXT: %[[ADD1:[0-9]*]]:2 = "_tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add1"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi32>, !_tf.control)
// CHECK-NEXT: %[[ADD2:[0-9]*]]:2 = "_tf.Add"(%arg0, %arg1, %[[ADD1]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add2"} : (tensor<*xi32>, tensor<i32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)
// CHECK-NEXT: %[[ADD3:[0-9]*]]:2 = "_tf.Add"(%arg0, %arg1, %[[ADD2]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add3"} : (tensor<*xi32>, tensor<i32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)
// CHECK-NEXT: %[[ADD4:[0-9]*]]:2 = "_tf.Add"(%arg0, %arg1, %[[ADD3]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add4"} : (tensor<*xi32>, tensor<i32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)
// -----
// CHECK-LABEL: func @switchN(
func @switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%fetches = tf_executor.graph {
// CHECK: [[S1:%.*]]:6 = "_tf._SwitchN"(%arg1, %arg0) {num_outs = 5 : i64}
%1:6 = tf_executor.SwitchN %arg1, %arg0 of 5 : tensor<*xf32>
// CHECK: "_tf._SwitchN"(%arg1, %arg0, [[S1]]#5) {num_outs = 12 : i64}
%2:13 = tf_executor.SwitchN %arg1, %arg0 of 12 (%1#5) : tensor<*xf32>
tf_executor.fetch %2#0 : tensor<*xf32>
}
return %fetches : tensor<*xf32>
}
// -----
// Test if tf_executor dialect ops with Ref types are mapped correctly to the ops in control dialect.
// CHECK-LABEL: func @ref_tf_executor_ops
func @ref_tf_executor_ops(%arg0: tensor<4x!tf.f32ref>, %arg1: tensor<4x!tf.f32ref>, %arg3: tensor<i32>, %arg4: tensor<i1> ) -> tensor<4x!tf.f32ref> {
%result = tf_executor.graph {
// CHECK: _tf.Enter
%0:2 = tf_executor.Enter %arg0 frame "while/while_context" : (tensor<4x!tf.f32ref>) -> (tensor<4x!tf.f32ref>, !tf_executor.control)
// CHECK: _tf.Exit
%1:2 = tf_executor.Exit %arg0 : tensor<4x!tf.f32ref>
// CHECK: _tf.Switch
%2:3 = tf_executor.Switch %arg0, %arg4 : (tensor<4x!tf.f32ref>, tensor<i1>) -> (tensor<4x!tf.f32ref>, tensor<4x!tf.f32ref>, !tf_executor.control)
// CHECK: _tf.Merge
%3:3 = tf_executor.Merge %arg0, %arg1 : (tensor<4x!tf.f32ref>, tensor<4x!tf.f32ref>) -> (tensor<4x!tf.f32ref>, tensor<i32>, !tf_executor.control)
// CHECK: _tf.NextIteration.source
%4:3 = tf_executor.NextIteration.Source : tensor<4x!tf.f32ref>
// CHECK: _tf.NextIteration.sink
tf_executor.NextIteration.Sink [%4#1] %4#0 : tensor<4x!tf.f32ref>
tf_executor.fetch %0#0 : tensor<4x!tf.f32ref>
}
return %result : tensor<4x!tf.f32ref>
}
// -----
// Tests if empty island with just one control dependency input and output is
// handled correctly.
// CHECK-LABEL: func @empty_island_control_dep_only
func @empty_island_control_dep_only() -> tensor<i32> {
%fetch = tf_executor.graph {
%0:2 = tf_executor.island {
%4 = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32>
tf_executor.yield %4 : tensor<i32>
}
// CHECK-NEXT: %[[CONST1:[0-9]*]]:2 = "_tf.Const"()
// CHECK-SAME: () -> (tensor<i32>, !_tf.control)
%1:2 = tf_executor.island {
%5 = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32>
tf_executor.yield %5 : tensor<i32>
}
// CHECK-NEXT: %[[CONST2:[0-9]*]]:2 = "_tf.Const"()
// CHECK-SAME: () -> (tensor<i32>, !_tf.control)
%2 = tf_executor.island(%0#1) {
tf_executor.yield
}
%3:2 = tf_executor.island(%2, %1#1) {
%6 = "tf.Add"(%0#0, %1#0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_executor.yield %6 : tensor<i32>
}
// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = "_tf.Add"(%[[CONST1]]#0, %[[CONST2]]#0, %[[CONST1]]#1, %[[CONST2]]#1)
// CHECK-SAME: (tensor<i32>, tensor<i32>, !_tf.control, !_tf.control) -> (tensor<i32>, !_tf.control)
tf_executor.fetch %3#0 : tensor<i32>
}
return %fetch : tensor<i32>
}
// -----
// Tests if empty island with multiple control inputs will be replaced with a
// no-op.
// CHECK-LABEL: func @empty_island_multi_control_inputs
func @empty_island_multi_control_inputs() -> tensor<i32> {
%fetch = tf_executor.graph {
%0:2 = tf_executor.island {
%4 = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32>
tf_executor.yield %4 : tensor<i32>
}
// CHECK-NEXT: %[[CONST1:[0-9]*]]:2 = "_tf.Const"()
// CHECK-SAME: () -> (tensor<i32>, !_tf.control)
%1:2 = tf_executor.island {
%5 = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32>
tf_executor.yield %5 : tensor<i32>
}
// CHECK-NEXT: %[[CONST2:[0-9]*]]:2 = "_tf.Const"()
// CHECK-SAME: () -> (tensor<i32>, !_tf.control)
%2 = tf_executor.island(%0#1, %1#1) {
tf_executor.yield
}
// CHECK-NEXT: %[[NOOP:[0-9]*]] = "_tf.NoOp"(%[[CONST1]]#1, %[[CONST2]]#1)
// CHECK-SAME: (!_tf.control, !_tf.control) -> !_tf.control
%3:2 = tf_executor.island(%2) {
%6 = "tf.Add"(%0#0, %1#0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_executor.yield %6 : tensor<i32>
}
// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = "_tf.Add"(%[[CONST1]]#0, %[[CONST2]]#0, %[[NOOP]])
// CHECK-SAME: (tensor<i32>, tensor<i32>, !_tf.control) -> (tensor<i32>, !_tf.control)
tf_executor.fetch %3#0 : tensor<i32>
}
return %fetch : tensor<i32>
}

View File

@@ -1,57 +0,0 @@
// RUN: tf-opt %s -tf-raise-control-flow -split-input-file | FileCheck %s
// Test that we remove underscores.
// CHECK-LABEL: func @testSimpleAddsAndIdentity(%arg0: tensor<*xf32>)
func @testSimpleAddsAndIdentity(tensor<*xf32>) -> tensor<*xf32> {
^bb0(%0: tensor<*xf32>):
// CHECK: %0 = "tf.Identity"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%1 = "_tf.Identity"(%0) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: %1 = "tf.Add"(%arg0, %arg0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
%2 = "_tf.Add"(%0, %0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: %2 = "tf.Add"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
%3 = "_tf.Add"(%1, %2) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: return %2 : tensor<*xf32>
return %3 : tensor<*xf32>
}
// CHECK-LABEL: func @testAddWithControlDependency(%arg0: tensor<*xf32>)
func @testAddWithControlDependency(tensor<*xf32>) -> tensor<*xf32> {
^bb0(%0: tensor<*xf32>):
// CHECK: %0 = "tf.Identity"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%1:2 = "_tf.Identity"(%0) : (tensor<*xf32>) -> (tensor<*xf32>, !_tf.control)
// CHECK: %1 = "tf.Add"(%arg0, %arg0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
%2:2 = "_tf.Add"(%0, %0, %1#1) : (tensor<*xf32>, tensor<*xf32>, !_tf.control) -> (tensor<*xf32>, !_tf.control)
// CHECK: %2 = "tf.Add"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
%3:2 = "_tf.Add"(%1#0, %2, %1#1, %2#1) : (tensor<*xf32>, tensor<*xf32>, !_tf.control, !_tf.control) -> (tensor<*xf32>, !_tf.control)
// CHECK: return %2 : tensor<*xf32>
return %3 : tensor<*xf32>
}
// TODO(clattner): simplify and expand these tests. This is mostly a placeholder.
func @LoopTest() {
%0:2 = "_tf.Const"() {device = "", name = "Const", dtype = "tfdtype$DT_INT32", value = dense<1> : tensor<i32>} : () -> (tensor<i32>, !_tf.control)
%1:2 = "_tf.Enter"(%0#0) {device = "", name = "while/Enter", T = "tfdtype$DT_INT32", frame_name = "while/while_context", is_constant = false, parallel_iterations = 10} : (tensor<i32>) -> (tensor<*xi32>, !_tf.control)
%11:2 = "_tf.NextIteration.source"() {device = "", name = "while/NextIteration", T = "tfdtype$DT_INT32", id = 0} : () -> (tensor<*xi32>, !_tf.control)
%2:3 = "_tf.Merge"(%11#0, %1#0) {device = "", name = "while/Merge", N = 2, T = "tfdtype$DT_INT32"} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<i32>, !_tf.control)
%3:2 = "_tf.Const"(%2#2) {device = "", name = "while/Less/y", dtype = "tfdtype$DT_INT32", value = dense<2> : tensor<i32>} : (!_tf.control) -> (tensor<i32>, !_tf.control)
%4:2 = "_tf.Less"(%2#0, %3#0) {device = "", name = "while/Less", T = "tfdtype$DT_INT32"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi1>, !_tf.control)
%5:2 = "_tf.LoopCond"(%4#0) {device = "", name = "while/LoopCond"} : (tensor<*xi1>) -> (tensor<i1>, !_tf.control)
%6:3 = "_tf.Switch"(%2#0, %5#0) {device = "", name = "while/Switch", T = "tfdtype$DT_INT32", _class = ["loc:@while/Merge"]} : (tensor<*xi32>, tensor<i1>) -> (tensor<*xi32>, tensor<*xi32>, !_tf.control)
%7:2 = "_tf.Exit"(%6#0) {device = "", name = "while/Exit", T = "tfdtype$DT_INT32"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
%8:2 = "_tf.Identity"(%6#1) {device = "", name = "while/Identity", T = "tfdtype$DT_INT32"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
%9:2 = "_tf.Const"(%8#1) {device = "", name = "while/Add/y", dtype = "tfdtype$DT_INT32", value = dense<3> : tensor<i32>} : (!_tf.control) -> (tensor<i32>, !_tf.control)
%10:2 = "_tf.Add"(%8#0, %9#0) {device = "", name = "while/Add", T = "tfdtype$DT_INT32"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi32>, !_tf.control)
%ctl = "_tf.NextIteration.sink"(%10#0) {device = "", name = "while/NextIteration", T = "tfdtype$DT_INT32", id = 0} : (tensor<*xi32>) -> (!_tf.control)
return
}

View File

@@ -58,6 +58,8 @@ void CreateTFStandardPipeline(OpPassManager &pm,
func_pm.addPass(tf_executor::CreateTFExecutorGraphPruningPass());
func_pm.addPass(tf_executor::CreateTFExecutorIslandCoarseningPass());
func_pm.addPass(CreateMaterializePassthroughOpPass());
if (options.form_clusters)
func_pm.addPass(TFDevice::CreateClusterFormationPass());
// Hopefully there is a single island left, or there wasn't any to begin with.
// We now run the optimizer which operates mostly inside islands.
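
Cluster formation is now the standard pipeline's job. A short sketch of driving it with the new option, mirroring the `enable_inliner` assignment pattern already used in this commit:

    mlir::MLIRContext context;
    mlir::PassManager pm(&context);
    mlir::TF::StandardPipelineOptions options;
    options.form_clusters = true;  // schedules TFDevice::CreateClusterFormationPass
    mlir::TF::CreateTFStandardPipeline(pm, options);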

View File

@@ -77,6 +77,9 @@ struct StandardPipelineOptions
Option<bool> enable_inliner{*this, "enable-inliner",
llvm::cl::desc("Enable inliner."),
llvm::cl::init(false)};
Option<bool> form_clusters{*this, "form-clusters",
llvm::cl::desc("Enable Cluster Formation pass."),
llvm::cl::init(false)};
};
// Propagates the pass manager with the passes involved in transforming or
@@ -149,13 +152,6 @@ std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass();
std::unique_ptr<OperationPass<FuncOp>> CreateOpFusionPass();
} // namespace TF
namespace TFControlFlow {
// Raises from the "TensorFlow Control Flow" dialect to the standard TensorFlow
// dialect.
std::unique_ptr<OperationPass<FuncOp>> CreateRaiseTFControlFlowPass();
} // namespace TFControlFlow
namespace tf_executor {
class GraphOp;

View File

@@ -1,159 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements logic for raising from the "TensorFlow control flow"
// dialect of MLIR to the standard TensorFlow dialect. The TensorFlow control
// flow dialect represents control flow with Switch/Merge and a few related
// control flow nodes, along with control dependencies.
//
// This pass rebuilds the code in terms of MLIR branches and blocks,
// eliminating control dependencies, and results in the code being in the
// canonical TensorFlow dialect.
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
namespace mlir {
namespace TFControlFlow {
namespace {
struct RaiseTFControlFlow
: public PassWrapper<RaiseTFControlFlow, FunctionPass> {
void runOnFunction() {
// First start by recognizing loops and reconstructing a loop tree.
buildLoopNests();
// Next, transform Switch/Merge and other control flow ops into proper
// conditional control flow.
buildConditionals();
// Now that we have proper conditional control flow ops, the control edges
// can be dropped, and the underscores removed from operation names.
rewriteOps();
}
void buildLoopNests();
void buildConditionals();
void rewriteOps();
};
//===----------------------------------------------------------------------===//
// Loop nest reconstruction
//===----------------------------------------------------------------------===//
void RaiseTFControlFlow::buildLoopNests() {
// TODO(clattner)
}
//===----------------------------------------------------------------------===//
// Conditional Reconstruction
//===----------------------------------------------------------------------===//
void RaiseTFControlFlow::buildConditionals() {
// TODO.
}
//===----------------------------------------------------------------------===//
// Final rewrite from TF Control Flow form to canonical TensorFlow form
//===----------------------------------------------------------------------===//
static bool isUnderscoredTFOp(Operation &op) {
return op.getName().getStringRef().startswith("_tf.");
}
// Drop control edges, and remove underscores from operation names.
void RaiseTFControlFlow::rewriteOps() {
auto function = getFunction();
OpBuilder builder(function.getBody());
// On the first pass, create replacement operations for every one we are going
// to replace, updating anything that uses the normal results with the newly
// created operation.
for (auto &bb : function) {
for (auto &op : bb) {
// Ignore any operations that we aren't looking for.
if (!isUnderscoredTFOp(op)) continue;
// We always insert the replacement operation next to the operation it
// is replacing.
builder.setInsertionPoint(&op);
// Drop the leading _ off the name.
OperationState result(op.getLoc(),
op.getName().getStringRef().drop_front());
// Add an operand for each non-control input we find. Control values
// aren't necessary any more since the order within a block encodes the
// same information.
for (auto &operand : op.getOpOperands()) {
if (!operand.get().getType().isa<TFControlType>())
result.operands.push_back(operand.get());
// Drop all operands from the old operation, eliminating any
// inter-dependencies after this pass.
operand.drop();
}
// Add a result type for each non-control result we find.
bool sawControlResult = false;
for (auto opResult : op.getResults()) {
if (opResult.getType().isa<TFControlType>()) {
sawControlResult = true;
} else {
// We assume all control results are at the end of the result list.
assert(!sawControlResult && "all control results must be last");
(void)sawControlResult;
result.types.push_back(opResult.getType());
}
}
result.attributes.append(op.getAttrs().begin(), op.getAttrs().end());
// Create the replacement operation.
auto *replacement = builder.createOperation(result);
// We know that all the control results are last, so we can just rewrite
// the first results.
for (unsigned i = 0, e = result.types.size(); i != e; ++i)
op.getResult(i).replaceAllUsesWith(replacement->getResult(i));
}
}
// In the second pass, we can safely remove all of the old operations, because
// we know that all inter-dependencies are dropped.
for (auto &bb : function) {
// Advance the iterator so we don't invalidate it when we remove an
// operation later in the loop.
for (auto &op : llvm::make_early_inc_range(bb))
if (isUnderscoredTFOp(op)) op.erase();
}
}
} // namespace
std::unique_ptr<OperationPass<FuncOp>> CreateRaiseTFControlFlowPass() {
return std::make_unique<RaiseTFControlFlow>();
}
static PassRegistration<RaiseTFControlFlow> pass(
"tf-raise-control-flow",
"Raise from the TensorFlow Control Flow "
"dialect to the standard TensorFlow dialect");
} // namespace TFControlFlow
} // namespace mlir

View File

@@ -1,242 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass transforms from TF executor dialect to MLIR TF
// control dialect.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#define DEBUG_TYPE "tf-executor-to-ctl"
namespace mlir {
namespace {
struct ExecutorToControlDialectConversion
: public PassWrapper<ExecutorToControlDialectConversion, FunctionPass> {
void runOnFunction() override;
};
} // end anonymous namespace
static bool HasSingleGraph(FuncOp function) {
// We expect the function to have only one region with one block,
if (function.getBlocks().size() != 1) return false;
auto &block = function.front();
// and the block contains two ops,
if (std::next(block.begin()) == block.end()) return false;
// one GraphOp,
if (!isa<tf_executor::GraphOp>(block.begin())) return false;
// followed by a terminator.
if (!std::next(block.begin())->isKnownTerminator()) return false;
return true;
}
void ExecutorToControlDialectConversion::runOnFunction() {
if (!HasSingleGraph(getFunction())) {
LLVM_DEBUG(llvm::dbgs()
<< "Expect a Function with a single block and a single graph op,"
" skip tf_executor dialect conversion\n");
return;
}
Type control_type = TFControlFlow::TFControlType::get(&getContext());
Block &body = getFunction().front();
auto graph = cast<tf_executor::GraphOp>(body.front());
OpBuilder builder = OpBuilder::atBlockEnd(&body);
SmallString<64> new_op_name;
for (auto &op : llvm::make_early_inc_range(llvm::reverse(graph.GetBody()))) {
LLVM_DEBUG(llvm::dbgs() << "Process: " << op.getName() << "\n");
if (auto fetch = dyn_cast<tf_executor::FetchOp>(op)) {
// Replace all the operands of the fetch op with the uses of the graph
// results, remove the fetch op afterwards.
for (auto ops_and_ret_vals :
llvm::zip(graph.getResults(), fetch.getOperands()))
std::get<0>(ops_and_ret_vals)
.replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
op.erase();
continue;
}
builder.setInsertionPoint(&op);
if (auto island = dyn_cast<tf_executor::IslandOp>(op)) {
Value ctl_sequence = nullptr;
if (island.GetBody().without_terminator().empty() &&
island.getNumOperands() > 1) {
// For an empty island with multiple control inputs, we create a no-op
// inside it which will group all the inputs into one control output.
// This helps reduce the number of edges when there are multiple
// islands depending on this one.
builder.setInsertionPointToStart(&island.GetBody());
builder.create<TF::NoOp>(op.getLoc(), ArrayRef<Type>{},
ArrayRef<Value>{}, ArrayRef<NamedAttribute>{});
builder.setInsertionPoint(&op);
}
for (Operation &wrapped_op : island.GetBody()) {
LLVM_DEBUG(llvm::dbgs()
<< " In island: " << wrapped_op.getName() << "\n");
if (isa<tf_executor::YieldOp>(wrapped_op)) {
for (auto ops_and_ret_vals :
llvm::zip(island.getResults(), wrapped_op.getOperands()))
std::get<0>(ops_and_ret_vals)
.replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
break;
}
// Add a leading _ to the name.
new_op_name = "_";
new_op_name += wrapped_op.getName().getStringRef();
OperationState state(wrapped_op.getLoc(), new_op_name);
// Forward all operands of the wrapped op; ops inside an island carry no
// control operands, and the island's own control inputs are handled below.
state.operands.append(wrapped_op.getOperands().begin(),
wrapped_op.getOperands().end());
// Chain operations through a control dependency, except for the first
// operations in the sequence that carry the control dependencies held
// by the island itself.
if (ctl_sequence) {
state.operands.push_back(ctl_sequence);
} else {
for (Value ctl_operand : island.getOperands())
state.operands.push_back(ctl_operand);
}
// Add a result type for each result
state.types.append(wrapped_op.getResultTypes().begin(),
wrapped_op.getResultTypes().end());
state.types.push_back(control_type);
// Create the replacement operation.
auto *replacement = builder.createOperation(state);
replacement->setAttrs(wrapped_op.getMutableAttrDict());
for (auto ops_and_ret_vals :
llvm::zip(wrapped_op.getResults(), replacement->getResults()))
std::get<0>(ops_and_ret_vals)
.replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
ctl_sequence = replacement->getResult(replacement->getNumResults() - 1);
}
if (ctl_sequence) {
// If ctl_sequence is non-null, at least one operation has been rewritten
// from the ops in the island. The last op rewritten must logically carry
// all the island control inputs, so we can simply use it to replace all
// uses of the island's control output.
island.control().replaceAllUsesWith(ctl_sequence);
} else if (island.getNumOperands() > 0) {
// Getting here means island had an effectively empty body and there is
// just one control input. In this case, island's control output should
// be replaced with the control input.
assert(island.getNumOperands() == 1);
island.control().replaceAllUsesWith(island.getOperand(0));
}
op.erase();
continue;
}
new_op_name.clear();
if (isa<tf_executor::SwitchOp>(op)) {
new_op_name = "_tf.Switch";
} else if (isa<tf_executor::SwitchNOp>(op)) {
new_op_name = "_tf._SwitchN";
} else if (isa<tf_executor::MergeOp>(op)) {
new_op_name = "_tf.Merge";
} else if (isa<tf_executor::NextIterationSourceOp>(op)) {
new_op_name = "_tf.NextIteration.source";
} else if (isa<tf_executor::NextIterationSinkOp>(op)) {
new_op_name = "_tf.NextIteration.sink";
} else if (isa<tf_executor::LoopCondOp>(op)) {
new_op_name = "_tf.LoopCond";
} else if (isa<tf_executor::EnterOp>(op)) {
new_op_name = "_tf.Enter";
} else if (isa<tf_executor::ExitOp>(op)) {
new_op_name = "_tf.Exit";
} else if (isa<tf_executor::ControlTriggerOp>(op)) {
new_op_name = "_tf.ControlTrigger";
} else {
op.emitOpError() << "unhandled op in tf_executor to _tf conversion";
return signalPassFailure();
}
OperationState state(op.getLoc(), new_op_name);
// Drop all TokenType operands since they don't exist in the control
// dialect.
auto non_null_operands = llvm::make_filter_range(
op.getOperands(),
[](Value v) { return !v.getType().isa<tf_executor::TokenType>(); });
state.operands.append(non_null_operands.begin(), non_null_operands.end());
for (Type result_type : op.getResultTypes()) {
// Filter out TokenType, they don't exist in the control dialect.
if (result_type.isa<tf_executor::TokenType>()) continue;
if (!result_type.isa<tf_executor::ControlType>())
state.types.push_back(result_type);
else
state.types.push_back(control_type);
}
// The control dialect has a control result for the sink operation.
if (isa<tf_executor::NextIterationSinkOp>(op))
state.types.push_back(control_type);
// Create the replacement operation.
auto *replacement = builder.createOperation(state);
replacement->setAttrs(op.getMutableAttrDict());
if (auto next_iteration =
dyn_cast<tf_executor::NextIterationSourceOp>(op)) {
next_iteration.output().replaceAllUsesWith(replacement->getResult(0));
next_iteration.token().dropAllUses();
next_iteration.control().replaceAllUsesWith(replacement->getResult(1));
} else {
for (auto ops_and_ret_vals :
llvm::zip(op.getResults(), replacement->getResults()))
std::get<0>(ops_and_ret_vals)
.replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
}
op.erase();
}
// Now we have rewritten all ops inside GraphOp to TF Control dialect. We need
// to move all operations outside of GraphOp and remove it.
body.getOperations().splice(body.begin(), graph.GetBody().getOperations());
graph.erase();
}
std::unique_ptr<OperationPass<FuncOp>>
CreateTFExecutorToControlDialectConversion() {
return std::make_unique<ExecutorToControlDialectConversion>();
}
} // namespace mlir
static mlir::PassRegistration<mlir::ExecutorToControlDialectConversion> pass(
"tf-executor-to-control-conversion",
"Convert from TF executor dialect to TF control dialect");

View File

@@ -23,12 +23,6 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tfjs/transforms/passes.h"
namespace mlir {
/// Create a pass to convert from the TFExecutor to the TF control dialect.
std::unique_ptr<OperationPass<FuncOp>>
CreateTFExecutorToControlDialectConversion();
} // namespace mlir
namespace tensorflow {
void AddTFToTFJSConversionPasses(mlir::OpPassManager* pm) {