- Remove executor_to_control pass.
- Remove raise control flow pass. - Cleanup usage in TFLite and other referneces. - Remove skip_control_dialect member in PassConfig PiperOrigin-RevId: 315552807 Change-Id: I4994f6a3c26cbe4845b97e7933272a860d3f15c2
This commit is contained in:
parent
b8a267a9fe
commit
191628f0e5
tensorflow/compiler/mlir
@ -32,7 +32,6 @@ struct PassConfig {
|
||||
lower_tensor_list_ops(false),
|
||||
trim_functions_whitelist({}),
|
||||
quant_specs(std::move(specs)),
|
||||
skip_control_dialect(false),
|
||||
form_clusters(false),
|
||||
unfold_batch_matmul(true),
|
||||
legalize_tf_while(true),
|
||||
@ -49,13 +48,8 @@ struct PassConfig {
|
||||
llvm::ArrayRef<std::string> trim_functions_whitelist;
|
||||
// All information about quantization.
|
||||
QuantizationSpecs quant_specs;
|
||||
// If `skip_control_dialect` is true, TF executor dialect is not converted to
|
||||
// TF control dialect prior to legalization to TF Lite.
|
||||
// TODO(b/142911013): Remove flag once control dialect is removed.
|
||||
bool skip_control_dialect;
|
||||
// If `form_clusters` is true (and `skip_control_dialect` is true), clusters
|
||||
// are formed by grouping consecutive ops of the same device, under a
|
||||
// `tf_device.launch` op.
|
||||
// If `form_clusters` is true , clusters are formed by grouping consecutive
|
||||
// ops of the same device, under a `tf_device.launch` op.
|
||||
bool form_clusters;
|
||||
// if `unfold_batch_matmul` is true, the tf.BatchMatMul is unfolded to a set
|
||||
// of tfl.fully_connected ops.
|
||||
|
@ -58,21 +58,10 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
|
||||
|
||||
void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
mlir::OpPassManager* pass_manager) {
|
||||
pass_manager->addPass(mlir::tf_executor::CreateSwitchFoldPass());
|
||||
if (pass_config.skip_control_dialect) {
|
||||
// Merge islands.
|
||||
pass_manager->addPass(
|
||||
mlir::tf_executor::CreateTFExecutorIslandCoarseningPass());
|
||||
// Assuming island coarsening above results in a graph with a single island,
|
||||
// a canonicalization can be ran to hoist the ops of the single island out.
|
||||
pass_manager->addPass(mlir::createCanonicalizerPass());
|
||||
|
||||
if (pass_config.form_clusters)
|
||||
pass_manager->addPass(mlir::TFDevice::CreateClusterFormationPass());
|
||||
} else {
|
||||
pass_manager->addPass(mlir::CreateTFExecutorToControlDialectConversion());
|
||||
pass_manager->addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass());
|
||||
}
|
||||
mlir::TF::StandardPipelineOptions standard_pipeline_options;
|
||||
standard_pipeline_options.enable_inliner = false;
|
||||
standard_pipeline_options.form_clusters = pass_config.form_clusters;
|
||||
mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options);
|
||||
|
||||
if (pass_config.shape_inference) {
|
||||
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
|
||||
@ -213,13 +202,8 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
|
||||
OpPassManager& func_pm = pm.nest<FuncOp>();
|
||||
|
||||
// tf_executor dialect passes - Cleaning up the IR.
|
||||
func_pm.addPass(tf_executor::CreateSwitchFoldPass());
|
||||
func_pm.addPass(tf_executor::CreateTFExecutorGraphPruningPass());
|
||||
func_pm.addPass(tf_executor::CreateTFExecutorIslandCoarseningPass());
|
||||
|
||||
// more cleanup of executor dialect and raise to control flow.
|
||||
pm.addPass(mlir::CreateTFExecutorToControlDialectConversion());
|
||||
pm.addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass());
|
||||
mlir::TF::StandardPipelineOptions standard_pipeline_options;
|
||||
mlir::TF::CreateTFStandardPipeline(func_pm, standard_pipeline_options);
|
||||
|
||||
// This is needed for control flow support with TF TensorList.
|
||||
pm.addPass(mlir::TFL::CreateLowerStaticTensorListPass());
|
||||
|
@ -38,12 +38,6 @@ limitations under the License.
|
||||
#include "tensorflow/lite/tools/optimize/quantize_weights.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace mlir {
|
||||
/// Create a pass to convert from the TFExecutor to the TF control dialect.
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
CreateTFExecutorToControlDialectConversion();
|
||||
} // namespace mlir
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using mlir::MLIRContext;
|
||||
|
@ -431,7 +431,6 @@ cc_library(
|
||||
"transforms/optimize_global_tensors.cc",
|
||||
"transforms/parallel_execute_to_islands.cc",
|
||||
"transforms/promote_resources_to_args.cc",
|
||||
"transforms/raise_control_flow.cc",
|
||||
"transforms/readonly_references_to_resources.cc",
|
||||
"transforms/replicate_invariant_op_hoisting.cc",
|
||||
"transforms/replicate_to_island.cc",
|
||||
@ -460,7 +459,6 @@ cc_library(
|
||||
"transforms/tpu_variable_runtime_reformatting.cc",
|
||||
"translate/breakup-islands.cc",
|
||||
"translate/control_to_executor_dialect.cc",
|
||||
"translate/executor_to_control_dialect.cc",
|
||||
"translate/tf_functional_to_executor.cc",
|
||||
],
|
||||
hdrs = [
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: tf-opt -tf-executor-to-control-conversion %s | FileCheck %s --check-prefix=CONTROL --dump-input=fail
|
||||
// RUN: tf-opt -tf-executor-graph-pruning %s | FileCheck %s --check-prefix=CONTROL --dump-input=fail
|
||||
// RUN: tf-opt -tf-control-to-executor-conversion %s | FileCheck %s --check-prefix=EXECUTOR --dump-input=fail
|
||||
|
||||
// CONTROL-LABEL: func @main
|
||||
|
@ -1,188 +0,0 @@
|
||||
// RUN: tf-opt -tf-executor-to-control-conversion %s | FileCheck %s --dump-input=fail
|
||||
// CHECK-LABEL: func @LoopTest() {
|
||||
func @LoopTest() {
|
||||
tf_executor.graph {
|
||||
%0:2 = tf_executor.island {
|
||||
%cst = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
tf_executor.yield %cst : tensor<i32>
|
||||
}
|
||||
%1:2 = tf_executor.Enter %0#0 frame "while/while_context" : (tensor<i32>) -> (tensor<*xi32>, !tf_executor.control) {T = "tfdtype$DT_INT32", device = "", name = "while/Enter"}
|
||||
%2 = tf_executor.island {
|
||||
"tf.NoOp"() {device = "", name = "cluster/pivot"} : () -> ()
|
||||
tf_executor.yield
|
||||
}
|
||||
%3:3 = tf_executor.NextIteration.Source : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"}
|
||||
%4:3 = tf_executor.Merge %3#0, %1#0 : tensor<*xi32> {N = 2 : i64, T = "tfdtype$DT_INT32", device = "", name = "while/Merge"}
|
||||
%5:2 = tf_executor.island(%4#2) {
|
||||
%cst = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/Less/y", value = dense<2> : tensor<i32>} : () -> tensor<i32>
|
||||
tf_executor.yield %cst : tensor<i32>
|
||||
}
|
||||
%6:2 = tf_executor.island {
|
||||
%14 = "tf.Less"(%4#0, %5#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Less"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
|
||||
tf_executor.yield %14 : tensor<*xi1>
|
||||
}
|
||||
%7:2 = tf_executor.LoopCond %6#0 : (tensor<*xi1>) -> (tensor<i1>, !tf_executor.control) {device = "", name = "while/LoopCond"}
|
||||
%8:3 = tf_executor.Switch %4#0, %7#0 : tensor<*xi32> {T = "tfdtype$DT_INT32", _class = ["loc = @while/Merge"], device = "", name = "while/Switch"}
|
||||
%9:2 = tf_executor.Exit %8#0 : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", name = "while/Exit"}
|
||||
%10:2 = tf_executor.island {
|
||||
%14 = "tf.Identity"(%8#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Identity"} : (tensor<*xi32>) -> tensor<*xi32>
|
||||
tf_executor.yield %14 : tensor<*xi32>
|
||||
}
|
||||
%11:2 = tf_executor.island(%10#1) {
|
||||
%cst = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/Add/y", value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||
tf_executor.yield %cst : tensor<i32>
|
||||
}
|
||||
%12:2 = tf_executor.island {
|
||||
%14 = "tf.Add"(%10#0, %11#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Add"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
tf_executor.yield %14 : tensor<*xi32>
|
||||
}
|
||||
%13 = tf_executor.ControlTrigger %2, %12#1, %9#1 {_tpu_replicate = "cluster", device = "", name = "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/b_sync"}
|
||||
tf_executor.NextIteration.Sink [%3#1] %12#0, %13 : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"}
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-NEXT: %[[CONST:[0-9]*]]:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> (tensor<i32>, !_tf.control)
|
||||
// CHECK-NEXT: %[[ENTER:[0-9]*]]:2 = "_tf.Enter"(%[[CONST]]#0) {T = "tfdtype$DT_INT32", device = "", frame_name = "while/while_context", is_constant = false, name = "while/Enter", parallel_iterations = 10 : i64} : (tensor<i32>) -> (tensor<*xi32>, !_tf.control)
|
||||
// CHECK-NEXT: %[[NOOP:[0-9]*]] = "_tf.NoOp"() {device = "", name = "cluster/pivot"} : () -> !_tf.control
|
||||
// CHECK-NEXT: %[[SOURCE:[0-9]*]]:2 = "_tf.NextIteration.source"() {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"} : () -> (tensor<*xi32>, !_tf.control)
|
||||
// CHECK-NEXT: %[[MERGE:[0-9]*]]:3 = "_tf.Merge"(%[[SOURCE]]#0, %[[ENTER]]#0) {N = 2 : i64, T = "tfdtype$DT_INT32", device = "", name = "while/Merge"} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<i32>, !_tf.control)
|
||||
// CHECK-NEXT: %[[CONST_LESS:[0-9]*]]:2 = "_tf.Const"(%[[MERGE]]#2) {device = "", dtype = "tfdtype$DT_INT32", name = "while/Less/y", value = dense<2> : tensor<i32>} : (!_tf.control) -> (tensor<i32>, !_tf.control)
|
||||
// CHECK-NEXT: %[[LESS:[0-9]*]]:2 = "_tf.Less"(%[[MERGE]]#0, %[[CONST_LESS]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Less"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi1>, !_tf.control)
|
||||
// CHECK-NEXT: %[[COND:[0-9]*]]:2 = "_tf.LoopCond"(%[[LESS]]#0) {device = "", name = "while/LoopCond"} : (tensor<*xi1>) -> (tensor<i1>, !_tf.control)
|
||||
// CHECK-NEXT: %[[SWITCH:[0-9]*]]:3 = "_tf.Switch"(%[[MERGE]]#0, %[[COND]]#0) {T = "tfdtype$DT_INT32", _class = ["loc = @while/Merge"], device = "", name = "while/Switch"} : (tensor<*xi32>, tensor<i1>) -> (tensor<*xi32>, tensor<*xi32>, !_tf.control)
|
||||
// CHECK-NEXT: %[[EXIT:[0-9]*]]:2 = "_tf.Exit"(%[[SWITCH]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Exit"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
|
||||
// CHECK-NEXT: %[[IDENTITY:[0-9]*]]:2 = "_tf.Identity"(%[[SWITCH]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Identity"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
|
||||
// CHECK-NEXT: %[[CONST_ADD:[0-9]*]]:2 = "_tf.Const"(%[[IDENTITY]]#1) {device = "", dtype = "tfdtype$DT_INT32", name = "while/Add/y", value = dense<3> : tensor<i32>} : (!_tf.control) -> (tensor<i32>, !_tf.control)
|
||||
// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = "_tf.Add"(%[[IDENTITY]]#0, %[[CONST_ADD]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Add"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi32>, !_tf.control)
|
||||
// CHECK-NEXT: %[[CT:[0-9]*]] = "_tf.ControlTrigger"(%[[NOOP]], %[[ADD]]#1, %[[EXIT]]#1) {_tpu_replicate = "cluster", device = "", name = "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/b_sync"} : (!_tf.control, !_tf.control, !_tf.control) -> !_tf.control
|
||||
// CHECK-NEXT: %[[SINK:[0-9]*]] = "_tf.NextIteration.sink"(%[[ADD]]#0, %[[CT]]) {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"} : (tensor<*xi32>, !_tf.control) -> !_tf.control
|
||||
// CHECK-NEXT: return
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @multiple_ops_region
|
||||
func @multiple_ops_region(%arg0 : tensor<*xi32>, %arg1 : tensor<i32>) {
|
||||
tf_executor.graph {
|
||||
%0:2 = tf_executor.island {
|
||||
// The 4 operations are independent, but the current conversion will add
|
||||
// control dependencies conservatively.
|
||||
%1 = "tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add1"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
%2 = "tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add2"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
%3 = "tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add3"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
%4 = "tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add4"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
tf_executor.yield %4 : tensor<*xi32>
|
||||
}
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-NEXT: %[[ADD1:[0-9]*]]:2 = "_tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add1"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi32>, !_tf.control)
|
||||
// CHECK-NEXT: %[[ADD2:[0-9]*]]:2 = "_tf.Add"(%arg0, %arg1, %[[ADD1]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add2"} : (tensor<*xi32>, tensor<i32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)
|
||||
// CHECK-NEXT: %[[ADD3:[0-9]*]]:2 = "_tf.Add"(%arg0, %arg1, %[[ADD2]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add3"} : (tensor<*xi32>, tensor<i32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)
|
||||
// CHECK-NEXT: %[[ADD4:[0-9]*]]:2 = "_tf.Add"(%arg0, %arg1, %[[ADD3]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add4"} : (tensor<*xi32>, tensor<i32>, !_tf.control) -> (tensor<*xi32>, !_tf.control)
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @switchN(
|
||||
func @switchN(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%fetches = tf_executor.graph {
|
||||
// CHECK: [[S1:%.*]]:6 = "_tf._SwitchN"(%arg1, %arg0) {num_outs = 5 : i64}
|
||||
%1:6 = tf_executor.SwitchN %arg1, %arg0 of 5 : tensor<*xf32>
|
||||
// CHECK: "_tf._SwitchN"(%arg1, %arg0, [[S1]]#5) {num_outs = 12 : i64}
|
||||
%2:13 = tf_executor.SwitchN %arg1, %arg0 of 12 (%1#5) : tensor<*xf32>
|
||||
tf_executor.fetch %2#0 : tensor<*xf32>
|
||||
}
|
||||
return %fetches : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test if tf_executor dialect ops with Ref types are mapped correctly to the ops in control dialect.
|
||||
// CHECK-LABEL: func @ref_tf_executor_ops
|
||||
func @ref_tf_executor_ops(%arg0: tensor<4x!tf.f32ref>, %arg1: tensor<4x!tf.f32ref>, %arg3: tensor<i32>, %arg4: tensor<i1> ) -> tensor<4x!tf.f32ref> {
|
||||
%result = tf_executor.graph {
|
||||
// CHECK: _tf.Enter
|
||||
%0:2 = tf_executor.Enter %arg0 frame "while/while_context" : (tensor<4x!tf.f32ref>) -> (tensor<4x!tf.f32ref>, !tf_executor.control)
|
||||
// CHECK: _tf.Exit
|
||||
%1:2 = tf_executor.Exit %arg0 : tensor<4x!tf.f32ref>
|
||||
// CHECK: _tf.Switch
|
||||
%2:3 = tf_executor.Switch %arg0, %arg4 : (tensor<4x!tf.f32ref>, tensor<i1>) -> (tensor<4x!tf.f32ref>, tensor<4x!tf.f32ref>, !tf_executor.control)
|
||||
// CHECK: _tf.Merge
|
||||
%3:3 = tf_executor.Merge %arg0, %arg1 : (tensor<4x!tf.f32ref>, tensor<4x!tf.f32ref>) -> (tensor<4x!tf.f32ref>, tensor<i32>, !tf_executor.control)
|
||||
// CHECK: _tf.NextIteration.source
|
||||
%4:3 = tf_executor.NextIteration.Source : tensor<4x!tf.f32ref>
|
||||
// CHECK: _tf.NextIteration.sink
|
||||
tf_executor.NextIteration.Sink [%4#1] %4#0 : tensor<4x!tf.f32ref>
|
||||
tf_executor.fetch %0#0 : tensor<4x!tf.f32ref>
|
||||
}
|
||||
return %result : tensor<4x!tf.f32ref>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests if empty island with just one control dependency input and output is
|
||||
// handled correctly.
|
||||
// CHECK-LABEL: func @empty_island_control_dep_only
|
||||
func @empty_island_control_dep_only() -> tensor<i32> {
|
||||
%fetch = tf_executor.graph {
|
||||
%0:2 = tf_executor.island {
|
||||
%4 = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
tf_executor.yield %4 : tensor<i32>
|
||||
}
|
||||
// CHECK-NEXT: %[[CONST1:[0-9]*]]:2 = "_tf.Const"()
|
||||
// CHECK-SAME: () -> (tensor<i32>, !_tf.control)
|
||||
%1:2 = tf_executor.island {
|
||||
%5 = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
tf_executor.yield %5 : tensor<i32>
|
||||
}
|
||||
// CHECK-NEXT: %[[CONST2:[0-9]*]]:2 = "_tf.Const"()
|
||||
// CHECK-SAME: () -> (tensor<i32>, !_tf.control)
|
||||
%2 = tf_executor.island(%0#1) {
|
||||
tf_executor.yield
|
||||
}
|
||||
%3:2 = tf_executor.island(%2, %1#1) {
|
||||
%6 = "tf.Add"(%0#0, %1#0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
tf_executor.yield %6 : tensor<i32>
|
||||
}
|
||||
// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = "_tf.Add"(%[[CONST1]]#0, %[[CONST2]]#0, %[[CONST1]]#1, %[[CONST2]]#1)
|
||||
// CHECK-SAME: (tensor<i32>, tensor<i32>, !_tf.control, !_tf.control) -> (tensor<i32>, !_tf.control)
|
||||
tf_executor.fetch %3#0 : tensor<i32>
|
||||
}
|
||||
return %fetch : tensor<i32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests if empty island with multiple control inputs will be replaced with a
|
||||
// no-op.
|
||||
// CHECK-LABEL: func @empty_island_multi_control_inputs
|
||||
func @empty_island_multi_control_inputs() -> tensor<i32> {
|
||||
%fetch = tf_executor.graph {
|
||||
%0:2 = tf_executor.island {
|
||||
%4 = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
tf_executor.yield %4 : tensor<i32>
|
||||
}
|
||||
// CHECK-NEXT: %[[CONST1:[0-9]*]]:2 = "_tf.Const"()
|
||||
// CHECK-SAME: () -> (tensor<i32>, !_tf.control)
|
||||
%1:2 = tf_executor.island {
|
||||
%5 = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
tf_executor.yield %5 : tensor<i32>
|
||||
}
|
||||
// CHECK-NEXT: %[[CONST2:[0-9]*]]:2 = "_tf.Const"()
|
||||
// CHECK-SAME: () -> (tensor<i32>, !_tf.control)
|
||||
%2 = tf_executor.island(%0#1, %1#1) {
|
||||
tf_executor.yield
|
||||
}
|
||||
// CHECK-NEXT: %[[NOOP:[0-9]*]] = "_tf.NoOp"(%[[CONST1]]#1, %[[CONST2]]#1)
|
||||
// CHECK-SAME: (!_tf.control, !_tf.control) -> !_tf.control
|
||||
%3:2 = tf_executor.island(%2) {
|
||||
%6 = "tf.Add"(%0#0, %1#0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
tf_executor.yield %6 : tensor<i32>
|
||||
}
|
||||
// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = "_tf.Add"(%[[CONST1]]#0, %[[CONST2]]#0, %[[NOOP]])
|
||||
// CHECK-SAME: (tensor<i32>, tensor<i32>, !_tf.control) -> (tensor<i32>, !_tf.control)
|
||||
tf_executor.fetch %3#0 : tensor<i32>
|
||||
}
|
||||
return %fetch : tensor<i32>
|
||||
}
|
@ -1,57 +0,0 @@
|
||||
// RUN: tf-opt %s -tf-raise-control-flow -split-input-file | FileCheck %s
|
||||
|
||||
// Test that we remove underscores.
|
||||
|
||||
// CHECK-LABEL: func @testSimpleAddsAndIdentity(%arg0: tensor<*xf32>)
|
||||
func @testSimpleAddsAndIdentity(tensor<*xf32>) -> tensor<*xf32> {
|
||||
^bb0(%0: tensor<*xf32>):
|
||||
|
||||
// CHECK: %0 = "tf.Identity"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%1 = "_tf.Identity"(%0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
|
||||
// CHECK: %1 = "tf.Add"(%arg0, %arg0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
%2 = "_tf.Add"(%0, %0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
|
||||
// CHECK: %2 = "tf.Add"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
%3 = "_tf.Add"(%1, %2) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
|
||||
// CHECK: return %2 : tensor<*xf32>
|
||||
return %3 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @testAddWithControlDependency(%arg0: tensor<*xf32>)
|
||||
func @testAddWithControlDependency(tensor<*xf32>) -> tensor<*xf32> {
|
||||
^bb0(%0: tensor<*xf32>):
|
||||
|
||||
// CHECK: %0 = "tf.Identity"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%1:2 = "_tf.Identity"(%0) : (tensor<*xf32>) -> (tensor<*xf32>, !_tf.control)
|
||||
|
||||
// CHECK: %1 = "tf.Add"(%arg0, %arg0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
%2:2 = "_tf.Add"(%0, %0, %1#1) : (tensor<*xf32>, tensor<*xf32>, !_tf.control) -> (tensor<*xf32>, !_tf.control)
|
||||
|
||||
// CHECK: %2 = "tf.Add"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
%3:2 = "_tf.Add"(%1#0, %2, %1#1, %2#1) : (tensor<*xf32>, tensor<*xf32>, !_tf.control, !_tf.control) -> (tensor<*xf32>, !_tf.control)
|
||||
|
||||
// CHECK: return %2 : tensor<*xf32>
|
||||
return %3 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// TODO(clattner): simplify and expand these tests. This is mostly a placeholder.
|
||||
func @LoopTest() {
|
||||
%0:2 = "_tf.Const"() {device = "", name = "Const", dtype = "tfdtype$DT_INT32", value = dense<1> : tensor<i32>} : () -> (tensor<i32>, !_tf.control)
|
||||
%1:2 = "_tf.Enter"(%0#0) {device = "", name = "while/Enter", T = "tfdtype$DT_INT32", frame_name = "while/while_context", is_constant = false, parallel_iterations = 10} : (tensor<i32>) -> (tensor<*xi32>, !_tf.control)
|
||||
|
||||
%11:2 = "_tf.NextIteration.source"() {device = "", name = "while/NextIteration", T = "tfdtype$DT_INT32", id = 0} : () -> (tensor<*xi32>, !_tf.control)
|
||||
|
||||
%2:3 = "_tf.Merge"(%11#0, %1#0) {device = "", name = "while/Merge", N = 2, T = "tfdtype$DT_INT32"} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<i32>, !_tf.control)
|
||||
%3:2 = "_tf.Const"(%2#2) {device = "", name = "while/Less/y", dtype = "tfdtype$DT_INT32", value = dense<2> : tensor<i32>} : (!_tf.control) -> (tensor<i32>, !_tf.control)
|
||||
%4:2 = "_tf.Less"(%2#0, %3#0) {device = "", name = "while/Less", T = "tfdtype$DT_INT32"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi1>, !_tf.control)
|
||||
%5:2 = "_tf.LoopCond"(%4#0) {device = "", name = "while/LoopCond"} : (tensor<*xi1>) -> (tensor<i1>, !_tf.control)
|
||||
%6:3 = "_tf.Switch"(%2#0, %5#0) {device = "", name = "while/Switch", T = "tfdtype$DT_INT32", _class = ["loc:@while/Merge"]} : (tensor<*xi32>, tensor<i1>) -> (tensor<*xi32>, tensor<*xi32>, !_tf.control)
|
||||
%7:2 = "_tf.Exit"(%6#0) {device = "", name = "while/Exit", T = "tfdtype$DT_INT32"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
|
||||
%8:2 = "_tf.Identity"(%6#1) {device = "", name = "while/Identity", T = "tfdtype$DT_INT32"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control)
|
||||
%9:2 = "_tf.Const"(%8#1) {device = "", name = "while/Add/y", dtype = "tfdtype$DT_INT32", value = dense<3> : tensor<i32>} : (!_tf.control) -> (tensor<i32>, !_tf.control)
|
||||
%10:2 = "_tf.Add"(%8#0, %9#0) {device = "", name = "while/Add", T = "tfdtype$DT_INT32"} : (tensor<*xi32>, tensor<i32>) -> (tensor<*xi32>, !_tf.control)
|
||||
%ctl = "_tf.NextIteration.sink"(%10#0) {device = "", name = "while/NextIteration", T = "tfdtype$DT_INT32", id = 0} : (tensor<*xi32>) -> (!_tf.control)
|
||||
return
|
||||
}
|
@ -58,6 +58,8 @@ void CreateTFStandardPipeline(OpPassManager &pm,
|
||||
func_pm.addPass(tf_executor::CreateTFExecutorGraphPruningPass());
|
||||
func_pm.addPass(tf_executor::CreateTFExecutorIslandCoarseningPass());
|
||||
func_pm.addPass(CreateMaterializePassthroughOpPass());
|
||||
if (options.form_clusters)
|
||||
func_pm.addPass(TFDevice::CreateClusterFormationPass());
|
||||
|
||||
// Hopefully there is a single island left, or there wasn't any to begin with.
|
||||
// We now run the optimizer which operates mostly inside islands.
|
||||
|
@ -77,6 +77,9 @@ struct StandardPipelineOptions
|
||||
Option<bool> enable_inliner{*this, "enable-inliner",
|
||||
llvm::cl::desc("Enable inliner."),
|
||||
llvm::cl::init(false)};
|
||||
Option<bool> form_clusters{*this, "form-clusters",
|
||||
llvm::cl::desc("Enable Cluster Formation pass."),
|
||||
llvm::cl::init(false)};
|
||||
};
|
||||
|
||||
// Propagates the pass manager with the passes involved in transforming or
|
||||
@ -149,13 +152,6 @@ std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass();
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateOpFusionPass();
|
||||
} // namespace TF
|
||||
|
||||
namespace TFControlFlow {
|
||||
// Raises from the "TensorFlow Control Flow" dialect to the standard TensorFlow
|
||||
// dialect.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateRaiseTFControlFlowPass();
|
||||
|
||||
} // namespace TFControlFlow
|
||||
|
||||
namespace tf_executor {
|
||||
class GraphOp;
|
||||
|
||||
|
@ -1,159 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file implements logic for raising from the "TensorFlow control flow"
|
||||
// dialect of MLIR to the standard TensorFlow dialect. The TensorFlow control
|
||||
// flow dialect represents control flow with Switch/Merge and a few related
|
||||
// control flow nodes, along with control dependencies.
|
||||
//
|
||||
// This pass rebuilds them code in terms of MLIR branches and blocks,
|
||||
// eliminating control dependencies, and results in the code being in the
|
||||
// canonical TensorFlow dialect.
|
||||
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFControlFlow {
|
||||
|
||||
namespace {
|
||||
struct RaiseTFControlFlow
|
||||
: public PassWrapper<RaiseTFControlFlow, FunctionPass> {
|
||||
void runOnFunction() {
|
||||
// First start by recognizing loops and reconstructing a loop tree.
|
||||
buildLoopNests();
|
||||
|
||||
// Next, transform Switch/Merge and other control flow ops into proper
|
||||
// conditional control flow.
|
||||
buildConditionals();
|
||||
|
||||
// Now that we have proper conditional control flow ops, the control edges
|
||||
// can be dropped, and the underscores removed from operation names.
|
||||
rewriteOps();
|
||||
}
|
||||
|
||||
void buildLoopNests();
|
||||
void buildConditionals();
|
||||
void rewriteOps();
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Loop nest reconstruction
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void RaiseTFControlFlow::buildLoopNests() {
|
||||
// TODO(clattner)
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Conditional Reconstruction
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void RaiseTFControlFlow::buildConditionals() {
|
||||
// TODO.
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Final rewrite from TF Control Flow form to canonical TensorFlow form
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static bool isUnderscoredTFOp(Operation &op) {
|
||||
return op.getName().getStringRef().startswith("_tf.");
|
||||
}
|
||||
|
||||
// Drop control edges, and remove underscores from operation names.
|
||||
void RaiseTFControlFlow::rewriteOps() {
|
||||
auto function = getFunction();
|
||||
OpBuilder builder(function.getBody());
|
||||
|
||||
// On the first pass, create replacement operations for every one we are going
|
||||
// to replace, updating anything that uses the normal results with the newly
|
||||
// created operation.
|
||||
for (auto &bb : function) {
|
||||
for (auto &op : bb) {
|
||||
// Ignore any operations that we aren't looking for.
|
||||
if (!isUnderscoredTFOp(op)) continue;
|
||||
|
||||
// We always insert the replacement operation next to the operation it
|
||||
// is replacing.
|
||||
builder.setInsertionPoint(&op);
|
||||
|
||||
// Drop the leading _ off the name.
|
||||
OperationState result(op.getLoc(),
|
||||
op.getName().getStringRef().drop_front());
|
||||
|
||||
// Add an operand for each non-control input we find. Control values
|
||||
// aren't necessary any more since the order within a block encodes the
|
||||
// same information.
|
||||
for (auto &operand : op.getOpOperands()) {
|
||||
if (!operand.get().getType().isa<TFControlType>())
|
||||
result.operands.push_back(operand.get());
|
||||
|
||||
// Drop all operands from the old operation, eliminating any
|
||||
// inter-dependencies after this pass.
|
||||
operand.drop();
|
||||
}
|
||||
|
||||
// Add a result type for each non-control result we find.
|
||||
bool sawControlResult = false;
|
||||
for (auto opResult : op.getResults()) {
|
||||
if (opResult.getType().isa<TFControlType>()) {
|
||||
sawControlResult = true;
|
||||
} else {
|
||||
// We assume all control inputs are at the end of the result list.
|
||||
assert(!sawControlResult && "all control results must be last");
|
||||
(void)sawControlResult;
|
||||
result.types.push_back(opResult.getType());
|
||||
}
|
||||
}
|
||||
|
||||
result.attributes.append(op.getAttrs().begin(), op.getAttrs().end());
|
||||
|
||||
// Create the replacement operation.
|
||||
auto *replacement = builder.createOperation(result);
|
||||
|
||||
// We know that all the control results are last, so we can just rewrite
|
||||
// the first results.
|
||||
for (unsigned i = 0, e = result.types.size(); i != e; ++i)
|
||||
op.getResult(i).replaceAllUsesWith(replacement->getResult(i));
|
||||
}
|
||||
}
|
||||
|
||||
// In the second pass, we can safely remove all of the old operations, because
|
||||
// we know that all inter-dependencies are dropped.
|
||||
for (auto &bb : function) {
|
||||
// Advance the iterator so we don't invalidate it when we remove an
|
||||
// operation later in the loop.
|
||||
for (auto &op : llvm::make_early_inc_range(bb))
|
||||
if (isUnderscoredTFOp(op)) op.erase();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateRaiseTFControlFlowPass() {
|
||||
return std::make_unique<RaiseTFControlFlow>();
|
||||
}
|
||||
|
||||
static PassRegistration<RaiseTFControlFlow> pass(
|
||||
"tf-raise-control-flow",
|
||||
"Raise from the TensorFlow Control Flow "
|
||||
"dialect to the standard TensorFlow dialect");
|
||||
|
||||
} // namespace TFControlFlow
|
||||
} // namespace mlir
|
@ -1,242 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This transformation pass transforms from TF executor dialect to MLIR TF
|
||||
// control dialect.
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
#include "llvm/ADT/SmallString.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
|
||||
#define DEBUG_TYPE "tf-executor-to-ctl"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace {
|
||||
struct ExecutorToControlDialectConversion
|
||||
: public PassWrapper<ExecutorToControlDialectConversion, FunctionPass> {
|
||||
void runOnFunction() override;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
static bool HasSingleGraph(FuncOp function) {
|
||||
// We expect the function has only one region with one block,
|
||||
if (function.getBlocks().size() != 1) return false;
|
||||
auto &block = function.front();
|
||||
// and the block contains two ops,
|
||||
if (std::next(block.begin()) == block.end()) return false;
|
||||
// one GraphOp,
|
||||
if (!isa<tf_executor::GraphOp>(block.begin())) return false;
|
||||
// followed by a terminator.
|
||||
if (!std::next(block.begin())->isKnownTerminator()) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
void ExecutorToControlDialectConversion::runOnFunction() {
|
||||
if (!HasSingleGraph(getFunction())) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "Expect a Function with a single block and a single graph op,"
|
||||
" skip tf_executor dialect conversion\n");
|
||||
return;
|
||||
}
|
||||
Type control_type = TFControlFlow::TFControlType::get(&getContext());
|
||||
|
||||
Block &body = getFunction().front();
|
||||
auto graph = cast<tf_executor::GraphOp>(body.front());
|
||||
OpBuilder builder = OpBuilder::atBlockEnd(&body);
|
||||
SmallString<64> new_op_name;
|
||||
for (auto &op : llvm::make_early_inc_range(llvm::reverse(graph.GetBody()))) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Process: " << op.getName() << "\n");
|
||||
|
||||
if (auto fetch = dyn_cast<tf_executor::FetchOp>(op)) {
|
||||
// Replace all the operands of the fetch op with the uses of the graph
|
||||
// results, remove the fetch op afterwards.
|
||||
for (auto ops_and_ret_vals :
|
||||
llvm::zip(graph.getResults(), fetch.getOperands()))
|
||||
std::get<0>(ops_and_ret_vals)
|
||||
.replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
|
||||
op.erase();
|
||||
continue;
|
||||
}
|
||||
|
||||
builder.setInsertionPoint(&op);
|
||||
|
||||
if (auto island = dyn_cast<tf_executor::IslandOp>(op)) {
|
||||
Value ctl_sequence = nullptr;
|
||||
if (island.GetBody().without_terminator().empty() &&
|
||||
island.getNumOperands() > 1) {
|
||||
// For an empty island with multiple control inputs, we create a no-op
|
||||
// inside it which will group all the inputs into one control output.
|
||||
// This helps reducing the number of edges when there are multiple
|
||||
// islands depending on this one.
|
||||
builder.setInsertionPointToStart(&island.GetBody());
|
||||
builder.create<TF::NoOp>(op.getLoc(), ArrayRef<Type>{},
|
||||
ArrayRef<Value>{}, ArrayRef<NamedAttribute>{});
|
||||
builder.setInsertionPoint(&op);
|
||||
}
|
||||
for (Operation &wrapped_op : island.GetBody()) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< " In island: " << wrapped_op.getName() << "\n");
|
||||
if (isa<tf_executor::YieldOp>(wrapped_op)) {
|
||||
for (auto ops_and_ret_vals :
|
||||
llvm::zip(island.getResults(), wrapped_op.getOperands()))
|
||||
std::get<0>(ops_and_ret_vals)
|
||||
.replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
|
||||
break;
|
||||
}
|
||||
// Add a leading _ off the name.
|
||||
new_op_name = "_";
|
||||
new_op_name += wrapped_op.getName().getStringRef();
|
||||
OperationState state(wrapped_op.getLoc(), new_op_name);
|
||||
|
||||
// Add an operand for each non-control input we find. Collect control
|
||||
// values separately to add them to the island operands
|
||||
state.operands.append(wrapped_op.getOperands().begin(),
|
||||
wrapped_op.getOperands().end());
|
||||
|
||||
// Chain operations through a control dependency, except for the first
|
||||
// operations in the sequence that carry the control dependencies held
|
||||
// by the island itself.
|
||||
if (ctl_sequence) {
|
||||
state.operands.push_back(ctl_sequence);
|
||||
} else {
|
||||
for (Value ctl_operand : island.getOperands())
|
||||
state.operands.push_back(ctl_operand);
|
||||
}
|
||||
|
||||
// Add a result type for each result
|
||||
state.types.append(wrapped_op.getResultTypes().begin(),
|
||||
wrapped_op.getResultTypes().end());
|
||||
state.types.push_back(control_type);
|
||||
|
||||
// Create the replacement operation.
|
||||
auto *replacement = builder.createOperation(state);
|
||||
replacement->setAttrs(wrapped_op.getMutableAttrDict());
|
||||
|
||||
for (auto ops_and_ret_vals :
|
||||
llvm::zip(wrapped_op.getResults(), replacement->getResults()))
|
||||
std::get<0>(ops_and_ret_vals)
|
||||
.replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
|
||||
|
||||
ctl_sequence = replacement->getResult(replacement->getNumResults() - 1);
|
||||
}
|
||||
|
||||
if (ctl_sequence) {
|
||||
// If ctl_sequence is non-null, this means at least one operation has
|
||||
// been rewritten from ops in island. Last op rewritten must logically
|
||||
// carry // all the island control inputs, we can simply use it to
|
||||
// replace all uses of island's control output.
|
||||
island.control().replaceAllUsesWith(ctl_sequence);
|
||||
} else if (island.getNumOperands() > 0) {
|
||||
// Getting here means island had an effectively empty body and there is
|
||||
// just one control input. In this case, island's control output should
|
||||
// be replaced with the control input.
|
||||
assert(island.getNumOperands() == 1);
|
||||
island.control().replaceAllUsesWith(island.getOperand(0));
|
||||
}
|
||||
|
||||
op.erase();
|
||||
continue;
|
||||
}
|
||||
|
||||
new_op_name.clear();
|
||||
if (isa<tf_executor::SwitchOp>(op)) {
|
||||
new_op_name = "_tf.Switch";
|
||||
} else if (isa<tf_executor::SwitchNOp>(op)) {
|
||||
new_op_name = "_tf._SwitchN";
|
||||
} else if (isa<tf_executor::MergeOp>(op)) {
|
||||
new_op_name = "_tf.Merge";
|
||||
} else if (isa<tf_executor::NextIterationSourceOp>(op)) {
|
||||
new_op_name = "_tf.NextIteration.source";
|
||||
} else if (isa<tf_executor::NextIterationSinkOp>(op)) {
|
||||
new_op_name = "_tf.NextIteration.sink";
|
||||
} else if (isa<tf_executor::LoopCondOp>(op)) {
|
||||
new_op_name = "_tf.LoopCond";
|
||||
} else if (isa<tf_executor::EnterOp>(op)) {
|
||||
new_op_name = "_tf.Enter";
|
||||
} else if (isa<tf_executor::ExitOp>(op)) {
|
||||
new_op_name = "_tf.Exit";
|
||||
} else if (isa<tf_executor::ControlTriggerOp>(op)) {
|
||||
new_op_name = "_tf.ControlTrigger";
|
||||
} else {
|
||||
op.emitOpError() << "unhandled op in tf_executor to _tf conversion";
|
||||
return signalPassFailure();
|
||||
}
|
||||
OperationState state(op.getLoc(), new_op_name);
|
||||
// Drop all TokenType operands since they don't exist in the control
|
||||
// dialect.
|
||||
auto non_null_operands = llvm::make_filter_range(
|
||||
op.getOperands(),
|
||||
[](Value v) { return !v.getType().isa<tf_executor::TokenType>(); });
|
||||
state.operands.append(non_null_operands.begin(), non_null_operands.end());
|
||||
for (Type result_type : op.getResultTypes()) {
|
||||
// Filter out TokenType, they don't exist in the control dialect.
|
||||
if (result_type.isa<tf_executor::TokenType>()) continue;
|
||||
if (!result_type.isa<tf_executor::ControlType>())
|
||||
state.types.push_back(result_type);
|
||||
else
|
||||
state.types.push_back(control_type);
|
||||
}
|
||||
// The control dialect has a control result for the sink operation.
|
||||
if (isa<tf_executor::NextIterationSinkOp>(op))
|
||||
state.types.push_back(control_type);
|
||||
|
||||
// Create the replacement operation.
|
||||
auto *replacement = builder.createOperation(state);
|
||||
replacement->setAttrs(op.getMutableAttrDict());
|
||||
|
||||
if (auto next_iteration =
|
||||
dyn_cast<tf_executor::NextIterationSourceOp>(op)) {
|
||||
next_iteration.output().replaceAllUsesWith(replacement->getResult(0));
|
||||
next_iteration.token().dropAllUses();
|
||||
next_iteration.control().replaceAllUsesWith(replacement->getResult(1));
|
||||
} else {
|
||||
for (auto ops_and_ret_vals :
|
||||
llvm::zip(op.getResults(), replacement->getResults()))
|
||||
std::get<0>(ops_and_ret_vals)
|
||||
.replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
|
||||
}
|
||||
op.erase();
|
||||
}
|
||||
|
||||
// Now we have rewritten all ops inside GraphOp to TF Control dialect. We need
|
||||
// to move all operations outside of GraphOp and remove it.
|
||||
body.getOperations().splice(body.begin(), graph.GetBody().getOperations());
|
||||
graph.erase();
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
CreateTFExecutorToControlDialectConversion() {
|
||||
return std::make_unique<ExecutorToControlDialectConversion>();
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
static mlir::PassRegistration<mlir::ExecutorToControlDialectConversion> pass(
|
||||
"tf-executor-to-control-conversion",
|
||||
"Convert from TF executor dialect to TF control dialect");
|
@ -23,12 +23,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tfjs/transforms/passes.h"
|
||||
|
||||
namespace mlir {
|
||||
/// Create a pass to convert from the TFExecutor to the TF control dialect.
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
CreateTFExecutorToControlDialectConversion();
|
||||
} // namespace mlir
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
void AddTFToTFJSConversionPasses(mlir::OpPassManager* pm) {
|
||||
|
Loading…
Reference in New Issue
Block a user