Add transforms to convert between functional and region based If.
- Add this pair of transforms in TF -> TFLite conversion pass manager - Add inliner hook to the TF dialect to allow inlining of call's within the regions of the IfRegion ops' - Add a end-to-end test case using 2 If's that demonstrate inlining and followed by constant sinking and constant folding. PiperOrigin-RevId: 314966564 Change-Id: I633b8a24cb68d5c2c54b889f9505e7d6fcb905a4
This commit is contained in:
parent
b6b9f0815e
commit
b055058610
tensorflow/compiler/mlir
lite
tensorflow
422
tensorflow/compiler/mlir/lite/tests/end2end/if_op.pbtxt
Normal file
422
tensorflow/compiler/mlir/lite/tests/end2end/if_op.pbtxt
Normal file
@ -0,0 +1,422 @@
|
||||
# RUN: tf_tfl_translate -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=4:4 -tf-output-arrays=StatefulIf,StatelessIf %s -o - --output-mlir | FileCheck %s --dump-input-on-failure
|
||||
|
||||
node {
|
||||
name: "tf.Less"
|
||||
op: "Less"
|
||||
input: "a"
|
||||
input: "b"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "my_equal"
|
||||
op: "Equal"
|
||||
input: "a"
|
||||
input: "b"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cst0"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
dim {
|
||||
size: 4
|
||||
}
|
||||
}
|
||||
float_val: 1.0
|
||||
float_val: 2.0
|
||||
float_val: 3.0
|
||||
float_val: 4.0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cst1"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
dim {
|
||||
size: 4
|
||||
}
|
||||
}
|
||||
float_val: 5.0
|
||||
float_val: 6.0
|
||||
float_val: 7.0
|
||||
float_val: 8.0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "StatefulIf"
|
||||
op: "If"
|
||||
input: "tf.Less"
|
||||
input: "a"
|
||||
input: "b"
|
||||
input: "cst0"
|
||||
input: "cst1"
|
||||
attr {
|
||||
key: "Tcond"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "Tin"
|
||||
value {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_FLOAT
|
||||
type: DT_FLOAT
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "Tout"
|
||||
value {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "else_branch"
|
||||
value {
|
||||
func {
|
||||
name: "cond_false"
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "then_branch"
|
||||
value {
|
||||
func {
|
||||
name: "cond_true"
|
||||
}
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "StatelessIf"
|
||||
op: "StatelessIf"
|
||||
input: "my_equal"
|
||||
input: "a"
|
||||
input: "b"
|
||||
attr {
|
||||
key: "Tcond"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "Tin"
|
||||
value {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "Tout"
|
||||
value {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "else_branch"
|
||||
value {
|
||||
func {
|
||||
name: "cond_false_1"
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "then_branch"
|
||||
value {
|
||||
func {
|
||||
name: "cond_true_1"
|
||||
}
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "main"
|
||||
op: "_Retval"
|
||||
input: "StatefulIf"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "index"
|
||||
value {
|
||||
i: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "main1"
|
||||
op: "_Retval"
|
||||
input: "StatelessIf"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "index"
|
||||
value {
|
||||
i: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "a"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "b"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
library {
|
||||
function {
|
||||
signature {
|
||||
name: "cond_true"
|
||||
input_arg {
|
||||
name: "cond_true_arg0"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
input_arg {
|
||||
name: "cond_true_arg1"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
input_arg {
|
||||
name: "cond_true_arg2"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
input_arg {
|
||||
name: "cond_true_arg3"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
output_arg {
|
||||
name: "cond_true_ret"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "tf.Add"
|
||||
op: "Add"
|
||||
input: "cond_true_arg2"
|
||||
input: "cond_true_arg3"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
original_node_names: "tf.Add"
|
||||
}
|
||||
}
|
||||
ret {
|
||||
key: "cond_true_ret"
|
||||
value: "tf.Add:z:0"
|
||||
}
|
||||
}
|
||||
function {
|
||||
signature {
|
||||
name: "cond_false"
|
||||
input_arg {
|
||||
name: "cond_false_arg0"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
input_arg {
|
||||
name: "cond_false_arg1"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
input_arg {
|
||||
name: "cond_false_arg2"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
input_arg {
|
||||
name: "cond_false_arg3"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
output_arg {
|
||||
name: "cond_false_ret"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "tf.Mul"
|
||||
op: "Mul"
|
||||
input: "cond_false_arg0"
|
||||
input: "cond_false_arg3"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
original_node_names: "tf.Mul"
|
||||
}
|
||||
}
|
||||
ret {
|
||||
key: "cond_false_ret"
|
||||
value: "tf.Mul:z:0"
|
||||
}
|
||||
}
|
||||
function {
|
||||
signature {
|
||||
name: "cond_true_1"
|
||||
input_arg {
|
||||
name: "cond_true_arg0"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
input_arg {
|
||||
name: "cond_true_arg1"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
output_arg {
|
||||
name: "cond_true_ret"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "tf.Sub"
|
||||
op: "Sub"
|
||||
input: "cond_true_arg0"
|
||||
input: "cond_true_arg1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
original_node_names: "tf.Sub"
|
||||
}
|
||||
}
|
||||
ret {
|
||||
key: "cond_true_ret"
|
||||
value: "tf.Sub:z:0"
|
||||
}
|
||||
}
|
||||
function {
|
||||
signature {
|
||||
name: "cond_false_1"
|
||||
input_arg {
|
||||
name: "cond_false_arg0"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
input_arg {
|
||||
name: "cond_false_arg1"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
output_arg {
|
||||
name: "cond_false_ret"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "tf.Div"
|
||||
op: "Div"
|
||||
input: "cond_false_arg0"
|
||||
input: "cond_false_arg1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
original_node_names: "tf.Div"
|
||||
}
|
||||
}
|
||||
ret {
|
||||
key: "cond_false_ret"
|
||||
value: "tf.Div:z:0"
|
||||
}
|
||||
}
|
||||
}
|
||||
versions {
|
||||
producer: 115
|
||||
min_consumer: 12
|
||||
}
|
||||
|
||||
# CHECK: func @StatefulIf_else
|
||||
# CHECK-NEXT: constant dense<[5.000000e+00, 6.000000e+00, 7.000000e+00, 8.000000e+00]>
|
||||
# CHECK-NEXT: tfl.mul
|
||||
# CHECK: func @StatefulIf_then
|
||||
# CHECK-NEXT: constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]>
|
||||
# CHECK-NEXT: return
|
||||
# CHECK: func @StatelessIf_else
|
||||
# CHECK-NEXT: tfl.div
|
||||
# CHECK: func @StatelessIf_then
|
||||
# CHECK-NEXT: tfl.sub
|
||||
# CHECK: "tf.If"{{.+}}else_branch = @StatelessIf_else{{.+}}then_branch = @StatelessIf_then
|
||||
# CHECK: "tf.If"{{.+}}else_branch = @StatefulIf_else{{.+}}then_branch = @StatefulIf_then
|
||||
|
@ -85,6 +85,8 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
pass_config.quant_specs.serialized_quant_stats));
|
||||
}
|
||||
|
||||
pass_manager->addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());
|
||||
|
||||
// The conversion pipeline has to follow the following orders:
|
||||
// 1) Saved model related optimization like decompose resource ops
|
||||
// 2) Convert composite functions like lstm/rnns, along with proper function
|
||||
@ -128,6 +130,9 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
// Add a shape inference pass to optimize away the unnecessary casts.
|
||||
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
|
||||
}
|
||||
|
||||
pass_manager->addPass(mlir::TF::CreateTFRegionControlFlowToFunctional());
|
||||
|
||||
// Legalize while early to allow further constant folding.
|
||||
// TODO(jpienaar): This may not actually matter as we do canonicalization
|
||||
// after the legalize below, for now it needs to be below the above passes
|
||||
|
@ -54,7 +54,6 @@ class WhileOutlinePass
|
||||
|
||||
tensorflow::OpOrArgLocNameMapper mapper_;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::string WhileOutlinePass::GetName(Operation* op, StringRef suffix) {
|
||||
return (mapper_.GetUniqueName(op) + suffix).str();
|
||||
@ -62,7 +61,7 @@ std::string WhileOutlinePass::GetName(Operation* op, StringRef suffix) {
|
||||
|
||||
// Returns whether the WhileOp is already outlined (e.g., only consists of calls
|
||||
// to functions).
|
||||
static bool IsAlreadyOutlinedd(WhileOp while_op) {
|
||||
bool IsAlreadyOutlined(WhileOp while_op) {
|
||||
auto just_call = [](Region& region) {
|
||||
auto it = region.front().begin();
|
||||
if (!isa<CallOp>(*it)) return false;
|
||||
@ -120,7 +119,7 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
|
||||
}
|
||||
|
||||
// Skip if already just calls.
|
||||
if (extra_operands.empty() && IsAlreadyOutlinedd(while_op)) return;
|
||||
if (extra_operands.empty() && IsAlreadyOutlined(while_op)) return;
|
||||
|
||||
// Collect new types.
|
||||
SmallVector<Type, 4> types;
|
||||
@ -238,6 +237,7 @@ void WhileOutlinePass::runOnOperation() {
|
||||
getOperation().walk(
|
||||
[&](mlir::TFL::WhileOp while_op) { OutlineWhile(while_op); });
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass() {
|
||||
|
@ -418,6 +418,7 @@ cc_library(
|
||||
"transforms/fold_switch.cc",
|
||||
"transforms/freeze_global_tensors.cc",
|
||||
"transforms/functional_control_flow_to_cfg.cc",
|
||||
"transforms/functional_control_flow_to_regions.cc",
|
||||
"transforms/generated_canonicalize.inc",
|
||||
"transforms/generated_optimize.inc",
|
||||
"transforms/gpu_fusion.cc",
|
||||
@ -432,6 +433,7 @@ cc_library(
|
||||
"transforms/promote_resources_to_args.cc",
|
||||
"transforms/raise_control_flow.cc",
|
||||
"transforms/readonly_references_to_resources.cc",
|
||||
"transforms/region_control_flow_to_functional.cc",
|
||||
"transforms/replicate_invariant_op_hoisting.cc",
|
||||
"transforms/replicate_to_island.cc",
|
||||
"transforms/resource_device_inference.cc",
|
||||
@ -490,6 +492,7 @@ cc_library(
|
||||
":translate_utils",
|
||||
":unroll_batch_matmul_pass",
|
||||
":xla_sharding_util",
|
||||
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
|
||||
"//tensorflow/compiler/mlir/lite:validators",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla:xla_proto_cc",
|
||||
|
@ -4011,6 +4011,15 @@ struct TFInlinerInterface : public DialectInlinerInterface {
|
||||
// Analysis Hooks
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
// Defines the legality of inlinining 'src' region into the 'dest' region
|
||||
// attached to a TF operation
|
||||
bool isLegalToInline(Region *dest, Region *src,
|
||||
BlockAndValueMapping &valueMapping) const final {
|
||||
// Allow inlining in regions attached to region based control flow
|
||||
// operations only if the src region is a single block region
|
||||
return isa<IfRegionOp>(dest->getParentOp()) && src->getBlocks().size() == 1;
|
||||
}
|
||||
|
||||
// Defines the legality of inlining TF operations.
|
||||
bool isLegalToInline(Operation *, Region *,
|
||||
BlockAndValueMapping &) const final {
|
||||
|
@ -0,0 +1,48 @@
|
||||
// RUN: tf-opt %s -tf-functional-control-flow-to-regions -split-input-file | FileCheck %s --dump-input=fail
|
||||
|
||||
// CHECK: func @testIf1Then{{.+}} {sym_visibility = "private"}
|
||||
// CHECK: func @testIf1Else{{.+}} {sym_visibility = "private"}
|
||||
func @testIf1Then(tensor<*xf32>) -> tensor<*xf32>
|
||||
func @testIf1Else(tensor<*xf32>) -> tensor<*xf32>
|
||||
|
||||
// CHECK-LABEL: func @testIf1Result(%arg0: tensor<i1>, %arg1: tensor<*xf32>)
|
||||
func @testIf1Result(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%0 = "tf.If"(%arg0, %arg1) {
|
||||
then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false
|
||||
} : (tensor<i1>, tensor<*xf32>) -> tensor<*xf32>
|
||||
|
||||
// CHECK: "tf.IfRegion"
|
||||
// CHECK: [[Result0:%.*]] = call @testIf1Then
|
||||
// CHECK: "tf.Yield"([[Result0]])
|
||||
// CHECK: [[Result1:%.*]] = call @testIf1Else
|
||||
// CHECK: "tf.Yield"([[Result1]])
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// With mismatching input types
|
||||
|
||||
// CHECK: func @testIf1Then{{.+}} {sym_visibility = "private"}
|
||||
// CHECK: func @testIf1Else{{.+}} {sym_visibility = "private"}
|
||||
func @testIf1Then(tensor<*xf32>) -> tensor<*xf32>
|
||||
func @testIf1Else(tensor<*xf32>) -> tensor<*xf32>
|
||||
|
||||
// CHECK-LABEL: func @testIf2Result(%arg0: tensor<i1>, %arg1: tensor<2xf32>)
|
||||
func @testIf2Result(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
%0 = "tf.If"(%arg0, %arg1) {
|
||||
then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false
|
||||
} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||
|
||||
// CHECK: "tf.IfRegion"
|
||||
// CHECK: "tf.Cast"
|
||||
// CHECK: [[Result0:%.*]] = call @testIf1Then
|
||||
// CHECK: "tf.Yield"([[Result0]])
|
||||
// CHECK: "tf.Cast"
|
||||
// CHECK: [[Result1:%.*]] = call @testIf1Else
|
||||
// CHECK: "tf.Yield"([[Result1]])
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
|
||||
|
@ -0,0 +1,142 @@
|
||||
// RUN: tf-opt %s -tf-region-control-flow-to-functional -split-input-file
|
||||
//| FileCheck %s --dump-input=fail
|
||||
|
||||
// CHECK: func @tf.IfRegion_else(%arg0: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"}
|
||||
// CHECK-NEXT: "tf.Neg"
|
||||
// CHECK: func @tf.IfRegion_then(%arg0: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"}
|
||||
// CHECK-NEXT: "tf.Abs"
|
||||
func @testSimple(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: "tf.If"{{.+}}else_branch = @tf.IfRegion_else{{.+}}then_branch = @tf.IfRegion_then
|
||||
%0 = "tf.IfRegion"(%arg0) ({
|
||||
%1 = "tf.Abs"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
"tf.Yield"(%1) : (tensor<*xf32>) -> ()
|
||||
}, {
|
||||
%2 = "tf.Neg"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
"tf.Yield"(%2) : (tensor<*xf32>) -> ()
|
||||
}) { is_stateless = true } : (tensor<i1>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Use if condition inside the regions
|
||||
// CHECK: func @tf.IfRegion_else(%arg0: tensor<i1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>) -> tensor<2xf32> attributes {sym_visibility = "private"}
|
||||
// CHECK-NEXT: "tf.Select"(%arg0, %arg2, %arg3)
|
||||
// CHECK: func @tf.IfRegion_then(%arg0: tensor<i1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>) -> tensor<2xf32> attributes {sym_visibility = "private"}
|
||||
// CHECK-NEXT: "tf.Select"(%arg0, %arg1, %arg2)
|
||||
func @testIfCondition(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
%0 = "tf.Add"(%arg1, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
|
||||
%1 = "tf.Mul"(%arg1, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
|
||||
%2 = "tf.Div"(%arg1, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
|
||||
|
||||
// CHECK: "tf.If"{{.+}}else_branch = @tf.IfRegion_else{{.+}}then_branch = @tf.IfRegion_then
|
||||
%3 = "tf.IfRegion"(%arg0) ({
|
||||
%4 = "tf.Select"(%arg0, %0, %1) : (tensor<i1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%4) : (tensor<2xf32>) -> ()
|
||||
}, {
|
||||
%5 = "tf.Select"(%arg0, %1, %2): (tensor<i1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%5) : (tensor<2xf32>) -> ()
|
||||
}) { is_stateless = true} : (tensor<i1>) -> tensor<2xf32>
|
||||
return %3 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Constant sinking
|
||||
|
||||
// CHECK: func @tf.IfRegion_else() -> tensor<2xf32>
|
||||
// CHECK-NEXT: constant dense<1.0
|
||||
// CHECK: func @tf.IfRegion_then() -> tensor<2xf32>
|
||||
// CHECK-NEXT: constant dense<0.0
|
||||
func @testIfConstant(%arg0: tensor<i1>) -> tensor<2xf32> {
|
||||
%cst_zero = constant dense<0.0> : tensor<2xf32>
|
||||
// CHECK: "tf.If"(%arg0) {else_branch = @tf.IfRegion_else{{.+}}then_branch = @tf.IfRegion_then
|
||||
%0 = "tf.IfRegion"(%arg0) ({
|
||||
"tf.Yield"(%cst_zero) : (tensor<2xf32>) -> ()
|
||||
}, {
|
||||
%cst_one = constant dense<1.0> : tensor<2xf32>
|
||||
"tf.Yield"(%cst_one) : (tensor<2xf32>) -> ()
|
||||
}) { is_stateless = true} : (tensor<i1>) -> tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Nested IfRegions
|
||||
// CHECK: func @tf.IfRegion1_else
|
||||
// CHECK-NEXT: "tf.Acos"
|
||||
// CHECK-NEXT: "tf.Abs"
|
||||
|
||||
// CHECK: func @tf.IfRegion1_then
|
||||
// CHECK-NEXT: "tf.LogicalNot"
|
||||
// CHECK-NEXT: "tf.Asin"
|
||||
// CHECK-NEXT: "tf.If"({{.+}}) {else_branch = @tf.IfRegion_else, {{.+}} then_branch = @tf.IfRegion_then}
|
||||
|
||||
// CHECK: func @tf.IfRegion_else
|
||||
// CHECK-NEXT: "tf.Neg"
|
||||
// CHECK: func @tf.IfRegion_then
|
||||
// CHECK-NEXT: "tf.Abs"
|
||||
|
||||
func @testNested(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: "tf.If"({{.+}}) {else_branch = @tf.IfRegion1_else, {{.+}} then_branch = @tf.IfRegion1_then}
|
||||
%0 = "tf.IfRegion"(%arg0) ({
|
||||
// Outer Then
|
||||
%cond = "tf.LogicalNot"(%arg0) : (tensor<i1>) -> tensor<i1>
|
||||
%asin = "tf.Asin"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
|
||||
// nested IfRegion
|
||||
%1 = "tf.IfRegion"(%cond) ({
|
||||
%2 = "tf.Abs"(%asin) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
"tf.Yield"(%2) : (tensor<*xf32>) -> ()
|
||||
}, {
|
||||
%2 = "tf.Neg"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
"tf.Yield"(%2) : (tensor<*xf32>) -> ()
|
||||
}) { is_stateless = true } : (tensor<i1>) -> tensor<*xf32>
|
||||
|
||||
"tf.Yield"(%1) : (tensor<*xf32>) -> ()
|
||||
}, {
|
||||
// Outer Else
|
||||
%acos = "tf.Acos"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%3 = "tf.Abs"(%acos) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
"tf.Yield"(%3) : (tensor<*xf32>) -> ()
|
||||
}) { is_stateless = true } : (tensor<i1>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Match existing function->Region pattern (simple)
|
||||
func @testIf1Then(tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"}
|
||||
func @testIf1Else(tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"}
|
||||
func @testIf1Result(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: "tf.If"({{.+}}) {else_branch = @testIf1Else, {{.+}} then_branch = @testIf1Then}
|
||||
%0 = "tf.IfRegion"(%arg0) ( {
|
||||
%1 = call @testIf1Then(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
"tf.Yield"(%1) : (tensor<*xf32>) -> ()
|
||||
}, {
|
||||
%1 = call @testIf1Else(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
"tf.Yield"(%1) : (tensor<*xf32>) -> ()
|
||||
}) {is_stateless = false} : (tensor<i1>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Match existing function->Region pattern (with casts)
|
||||
|
||||
func @testIf1Then(tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"}
|
||||
func @testIf1Else(tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"}
|
||||
func @testIf2Result(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// CHECK: "tf.If"({{.+}}) {else_branch = @testIf1Else, {{.+}} then_branch = @testIf1Then}
|
||||
%0 = "tf.IfRegion"(%arg0) ( {
|
||||
%1 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xf32>) -> tensor<*xf32>
|
||||
%2 = call @testIf1Then(%1) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
"tf.Yield"(%2) : (tensor<*xf32>) -> ()
|
||||
}, {
|
||||
%1 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xf32>) -> tensor<*xf32>
|
||||
%2 = call @testIf1Else(%1) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
"tf.Yield"(%2) : (tensor<*xf32>) -> ()
|
||||
}) {is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
@ -14,7 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This transformation pass transforms functional control flow operations in the
|
||||
// standard TensorFlow dialect to MLIR Control Flow Graph (CFG) form.
|
||||
// TensorFlow dialect to MLIR Control Flow Graph (CFG) form.
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
@ -52,7 +52,6 @@ static Value LowerCondition(Location loc, Value value, OpBuilder* builder) {
|
||||
//
|
||||
// Requires the function to provide arguments for each of the `fn` operands
|
||||
// that is compatible for tensor cast.
|
||||
//
|
||||
static Operation* CallFn(Location loc, const std::function<Value(int)>& get_arg,
|
||||
FuncOp fn, OpBuilder* builder) {
|
||||
FunctionType fn_type = fn.getType();
|
||||
@ -113,7 +112,6 @@ static void JumpToBlock(Location loc, const std::function<Value(int)>& get_arg,
|
||||
// Requires that the block has same number of arguments as number of results of
|
||||
// the operation and either they have same types or are more generic types and
|
||||
// it is possible to cast them to results' types.
|
||||
//
|
||||
static void ReplaceOpResultWithBlockArgs(Location loc, Operation* op,
|
||||
Block* block, OpBuilder* builder) {
|
||||
assert(op->getNumResults() == block->getNumArguments());
|
||||
@ -132,9 +130,6 @@ static void ReplaceOpResultWithBlockArgs(Location loc, Operation* op,
|
||||
// Given a functional IfOp, transforms the enclosing code to eliminate it
|
||||
// completely from the IR, breaking it into operations to evaluate the condition
|
||||
// as a bool, plus some branches.
|
||||
//
|
||||
// This returns true on failure.
|
||||
//
|
||||
static LogicalResult LowerIfOp(IfOp op) {
|
||||
Operation* op_inst = op.getOperation();
|
||||
Location loc = op_inst->getLoc();
|
||||
@ -193,9 +188,6 @@ static LogicalResult LowerIfOp(IfOp op) {
|
||||
// Given a functional WhileOp, transforms the enclosing code to eliminate it
|
||||
// completely from the IR, breaking it into operations to execute the loop body
|
||||
// repeatedly while the loop condition is true.
|
||||
//
|
||||
// This returns true on failure.
|
||||
//
|
||||
static LogicalResult LowerWhileOp(WhileOp op) {
|
||||
Operation* op_inst = op.getOperation();
|
||||
Location loc = op_inst->getLoc();
|
||||
|
@ -0,0 +1,118 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This transformation pass transforms functional control flow operations in the
|
||||
// TensorFlow dialect to their region based counterparts, i.e.,
|
||||
// tf.If -> tf.IfRegion
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/IR/Visitors.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
|
||||
namespace {
|
||||
|
||||
struct FunctionalControlFlowToRegions
|
||||
: public PassWrapper<FunctionalControlFlowToRegions, FunctionPass> {
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
// Create a call to function `fn` with arguments `args` and return the CallOp.
|
||||
// The arguments are cast to the required type before the call.
|
||||
CallOp CreateCall(Location loc, Operation::operand_range args, FuncOp fn,
|
||||
OpBuilder* builder) {
|
||||
FunctionType fn_type = fn.getType();
|
||||
llvm::SmallVector<Value, 4> operands;
|
||||
int num_operands = fn_type.getNumInputs();
|
||||
operands.reserve(num_operands);
|
||||
for (int i = 0; i < num_operands; ++i) {
|
||||
Value arg = args[i];
|
||||
Type expected = fn_type.getInput(i);
|
||||
if (arg.getType() != expected) {
|
||||
arg = builder->create<CastOp>(loc, expected, arg,
|
||||
/*Truncate=*/builder->getBoolAttr(false));
|
||||
}
|
||||
operands.push_back(arg);
|
||||
}
|
||||
return builder->create<CallOp>(loc, fn, operands);
|
||||
}
|
||||
|
||||
// Transform a functional IfOp to a region based IfRegionOp
|
||||
LogicalResult ConvertIfOp(IfOp if_op) {
|
||||
auto if_region = OpBuilder(if_op).create<TF::IfRegionOp>(
|
||||
if_op.getLoc(), if_op.getResultTypes(), if_op.cond(),
|
||||
if_op.is_stateless());
|
||||
|
||||
// Insert call to the given function into the 'region'.
|
||||
auto create_region_with_call = [&if_op](FlatSymbolRefAttr symbol,
|
||||
Region& region) {
|
||||
OpBuilder builder(region);
|
||||
builder.createBlock(®ion);
|
||||
auto func = if_op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
|
||||
symbol.getValue());
|
||||
auto call = CreateCall(if_op.getLoc(), if_op.input(), func, &builder);
|
||||
builder.create<YieldOp>(if_op.getLoc(), call.getResults());
|
||||
// Mark old function as private so that it can be DCE'd if not called.
|
||||
func.setVisibility(SymbolTable::Visibility::Private);
|
||||
};
|
||||
|
||||
create_region_with_call(if_op.then_branchAttr(), if_region.then_branch());
|
||||
create_region_with_call(if_op.else_branchAttr(), if_region.else_branch());
|
||||
|
||||
if_op.replaceAllUsesWith(if_region.getResults());
|
||||
if_op.erase();
|
||||
return success();
|
||||
}
|
||||
|
||||
void FunctionalControlFlowToRegions::runOnFunction() {
|
||||
for (Block& block : getFunction()) {
|
||||
auto result = block.walk([](Operation* op) {
|
||||
if (IfOp if_op = llvm::dyn_cast<IfOp>(op)) {
|
||||
if (failed(ConvertIfOp(if_op))) {
|
||||
if_op.emitOpError() << " failed to convert to region form";
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
if (result.wasInterrupted()) return signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
CreateTFFunctionalControlFlowToRegions() {
|
||||
return std::make_unique<FunctionalControlFlowToRegions>();
|
||||
}
|
||||
|
||||
static PassRegistration<FunctionalControlFlowToRegions> pass(
|
||||
"tf-functional-control-flow-to-regions",
|
||||
"Transform functional control flow Ops to Region based counterparts");
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
@ -32,10 +32,18 @@ std::unique_ptr<OperationPass<FuncOp>>
|
||||
CreateFunctionalToExecutorDialectConversionPass();
|
||||
|
||||
namespace TF {
|
||||
// Transforms functional control flow operations in the standard TensorFlow
|
||||
// dialect to MLIR Control Flow Graph (CFG) form.
|
||||
// Transforms functional control flow operations in the TensorFlow dialect to
|
||||
// MLIR Control Flow Graph (CFG) form.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateTFFunctionalControlFlowToCFG();
|
||||
|
||||
// Transforms functional control flow operations in the TensorFlow dialect to
|
||||
// their region based counterparts.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateTFFunctionalControlFlowToRegions();
|
||||
|
||||
// Transforms region bases control flow operations in the TensorFlow dialect to
|
||||
// their functional counterparts.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateTFRegionControlFlowToFunctional();
|
||||
|
||||
// Materialize the MlirPassthroughOp by replacing it with the MLIR module
|
||||
// attached as an attribute.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateMaterializePassthroughOpPass();
|
||||
|
@ -0,0 +1,321 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This transformation pass transforms region bases control flow operations in
|
||||
// the TensorFlow dialect to their functional counterparts, i.e.,
|
||||
// tf.IfRegion -> tf.If
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/IR/Visitors.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
|
||||
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
|
||||
namespace {
|
||||
|
||||
struct RegionControlFlowToFunctional
|
||||
: public PassWrapper<RegionControlFlowToFunctional, FunctionPass> {
|
||||
void runOnFunction() override;
|
||||
|
||||
private:
|
||||
LogicalResult ConvertIfOp(IfRegionOp if_region);
|
||||
|
||||
// Get unique name by using the loc to name mapping.
|
||||
std::string GetName(Operation* op, StringRef suffix);
|
||||
|
||||
tensorflow::OpOrArgLocNameMapper mapper;
|
||||
};
|
||||
|
||||
std::string RegionControlFlowToFunctional::GetName(Operation* op,
|
||||
StringRef suffix) {
|
||||
return (mapper.GetUniqueName(op) + suffix).str();
|
||||
}
|
||||
|
||||
// Given a list of regions, return all the external values referenced from the
|
||||
// region. If the external value is a constant, sink it into the region instead
|
||||
// (and do not add it to the returned vector).
|
||||
llvm::SmallVector<Value, 4> CollectExternValues(ArrayRef<Region*> regions) {
|
||||
llvm::SetVector<Value> extern_values_set;
|
||||
|
||||
for (auto region : regions) {
|
||||
llvm::SetVector<Value> region_extern_values;
|
||||
getUsedValuesDefinedAbove(*region, region_extern_values);
|
||||
|
||||
// Sink down constants into the functions.
|
||||
for (auto extern_value : region_extern_values) {
|
||||
if (!matchPattern(extern_value, m_Constant())) {
|
||||
extern_values_set.insert(extern_value);
|
||||
continue;
|
||||
}
|
||||
// Add constant at start of region.
|
||||
auto const_builder = OpBuilder::atBlockBegin(®ion->front());
|
||||
auto const_value = const_builder.clone(*extern_value.getDefiningOp());
|
||||
replaceAllUsesInRegionWith(extern_value, const_value->getResult(0),
|
||||
*region);
|
||||
}
|
||||
}
|
||||
|
||||
return {extern_values_set.begin(), extern_values_set.end()};
|
||||
}
|
||||
|
||||
// Extract the contents of a region with a single block into a new function.
|
||||
// `extern_values` is the set of external values that the region refers to.
|
||||
//
|
||||
// Any inputs to the terminator of the region are converted to return values of
|
||||
// the function. If any of these values is not exact type as the function's
|
||||
// return type, appropriate cast operations will be inserted
|
||||
void ExtractSingleBlockRegion(Region& region, FunctionType type, StringRef name,
|
||||
llvm::SmallVector<Value, 4>& extern_values) {
|
||||
ModuleOp module = region.getParentOfType<ModuleOp>();
|
||||
auto builder = OpBuilder::atBlockBegin(module.getBody());
|
||||
auto loc = region.getParentOp()->getLoc();
|
||||
|
||||
// Create new function and extract region body into the function
|
||||
auto outlined_func =
|
||||
builder.create<FuncOp>(loc, name, type, ArrayRef<NamedAttribute>{});
|
||||
|
||||
outlined_func.getBody().takeBody(region);
|
||||
Region& func_region = outlined_func.getBody();
|
||||
Block& first_block = func_region.front();
|
||||
|
||||
// Replace all external uses with function arguments.
|
||||
for (auto it : llvm::enumerate(extern_values)) {
|
||||
Value arg = first_block.addArgument(it.value().getType());
|
||||
replaceAllUsesInRegionWith(it.value(), arg, func_region);
|
||||
}
|
||||
|
||||
// Replace the existing terminator with a return.
|
||||
Operation* terminator = outlined_func.getBody().front().getTerminator();
|
||||
builder.setInsertionPoint(terminator);
|
||||
|
||||
SmallVector<Value, 4> return_values;
|
||||
return_values.reserve(terminator->getNumOperands());
|
||||
for (auto it : llvm::enumerate(type.getResults())) {
|
||||
Value ret_val = terminator->getOperand(it.index());
|
||||
// Add a cast operation if types do not match.
|
||||
if (ret_val.getType() != it.value()) {
|
||||
ret_val =
|
||||
builder.create<CastOp>(terminator->getLoc(), it.value(), ret_val);
|
||||
}
|
||||
return_values.push_back(ret_val);
|
||||
}
|
||||
builder.create<ReturnOp>(terminator->getLoc(), return_values);
|
||||
terminator->erase();
|
||||
outlined_func.setVisibility(FuncOp::Visibility::Private);
|
||||
}
|
||||
|
||||
// Returns call for region with single call whose result feeds into the
|
||||
// terminator of the region. Returns none if the region doesn't contain just
|
||||
// call and casts ops.
|
||||
llvm::Optional<CallOp> IsSingleCallRegion(Region& region) {
|
||||
if (region.getBlocks().size() != 1) return llvm::None;
|
||||
|
||||
auto it = region.front().rbegin();
|
||||
YieldOp yield = dyn_cast<YieldOp>(*it++);
|
||||
if (yield.getNumOperands() == 0) return llvm::None;
|
||||
|
||||
if (it == region.front().rend()) return llvm::None;
|
||||
|
||||
// and a Call before that
|
||||
CallOp call = dyn_cast<CallOp>(*it++);
|
||||
if (!call) return llvm::None;
|
||||
|
||||
// There can be cast op's prior to that
|
||||
for (; it != region.front().rend(); ++it)
|
||||
if (!isa<CastOp>(*it)) return llvm::None;
|
||||
|
||||
// all results of the call should feed into the yield
|
||||
if (call.getNumResults() != yield.getNumOperands()) return llvm::None;
|
||||
|
||||
for (auto res_it : llvm::zip(call.getResults(), yield.getOperands()))
|
||||
if (std::get<0>(res_it) != std::get<1>(res_it)) return llvm::None;
|
||||
|
||||
return call;
|
||||
}
|
||||
|
||||
// Returns whether the arguments of the given call are same as the given list of
|
||||
// arguments (after looking through cast ops).
|
||||
bool MatchCallArgs(CallOp call, llvm::SmallVectorImpl<Value>& args) {
|
||||
if (call.getNumOperands() != args.size()) return false;
|
||||
|
||||
for (auto it : llvm::enumerate(args)) {
|
||||
Value arg = call.getOperand(it.index());
|
||||
if (auto cast = dyn_cast_or_null<CastOp>(arg.getDefiningOp()))
|
||||
arg = cast.getOperand();
|
||||
|
||||
if (arg != it.value()) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Summary information for trivially transforming region based op's to
|
||||
// functional ops. A trivial transformation can be done when the regions are
|
||||
// just calls to functions, in which case no outlining is needed.
|
||||
struct TrivialTransformInfo {
|
||||
// can the op be transformed trivially
|
||||
bool can_transform = false;
|
||||
|
||||
// list of callee names (one for each region)
|
||||
llvm::SmallVector<StringRef, 4> callee_names;
|
||||
|
||||
// list of arguments used in these call (each call uses the same arguments
|
||||
// potentially through casts)
|
||||
llvm::SmallVector<Value, 4> call_args;
|
||||
};
|
||||
|
||||
// Analyze the given set of regions (attached to the same parent op) to check
|
||||
// if the parent op be transformed to functional form trivially (i.e., reusing
|
||||
// existing functions and without outlining). This is possible when all the
|
||||
// regions are single call regions and the all the calls have the same
|
||||
// arguments.
|
||||
//
|
||||
// If this trivial transformation is possible, return the relevant information
|
||||
// needed for the transformation (in `TrivialTransformInfo`), else indicate that
|
||||
// a trivial transformation is not possible by setting `can_transform` false.
|
||||
TrivialTransformInfo AnalyzeForTrivialTransform(ArrayRef<Region*> regions) {
|
||||
const TrivialTransformInfo cannot_transform;
|
||||
|
||||
if (regions.empty()) return cannot_transform;
|
||||
|
||||
llvm::SmallVector<CallOp, 2> calls;
|
||||
calls.reserve(regions.size());
|
||||
|
||||
// Verify each region is a single call and collect these calls.
|
||||
for (Region* region : regions) {
|
||||
auto call = IsSingleCallRegion(*region);
|
||||
if (!call.hasValue()) return cannot_transform;
|
||||
calls.push_back(call.getValue());
|
||||
}
|
||||
|
||||
llvm::SmallVector<StringRef, 4> callees;
|
||||
callees.reserve(regions.size());
|
||||
|
||||
CallOp call0 = calls[0];
|
||||
int num_args = call0.getNumOperands();
|
||||
callees.push_back(call0.getCallee());
|
||||
|
||||
// collect call0 arguments
|
||||
llvm::SmallVector<Value, 4> call0_args;
|
||||
call0_args.reserve(num_args);
|
||||
for (int arg_idx = 0; arg_idx < num_args; ++arg_idx) {
|
||||
Value arg = call0.getOperand(arg_idx);
|
||||
if (auto cast = dyn_cast_or_null<CastOp>(arg.getDefiningOp()))
|
||||
arg = cast.getOperand();
|
||||
call0_args.push_back(arg);
|
||||
}
|
||||
|
||||
// match other call arguments with call0 arguments
|
||||
for (int idx = 1; idx < calls.size(); ++idx) {
|
||||
if (!MatchCallArgs(calls[idx], call0_args)) return cannot_transform;
|
||||
callees.push_back(calls[idx].getCallee());
|
||||
}
|
||||
|
||||
return {true, callees, call0_args};
|
||||
}
|
||||
|
||||
// Transform IfRegionOp to IfOp
|
||||
LogicalResult RegionControlFlowToFunctional::ConvertIfOp(IfRegionOp if_region) {
|
||||
const TrivialTransformInfo tti = AnalyzeForTrivialTransform(
|
||||
{&if_region.then_branch(), &if_region.else_branch()});
|
||||
|
||||
std::string then_name, else_name;
|
||||
llvm::SmallVector<Value, 4> extern_values;
|
||||
|
||||
if (tti.can_transform) {
|
||||
// We can transform to functional form trivially without outlining
|
||||
then_name = tti.callee_names[0].str();
|
||||
else_name = tti.callee_names[1].str();
|
||||
extern_values = tti.call_args;
|
||||
} else {
|
||||
// Collect external values that are used within the else and then bodies.
|
||||
extern_values = CollectExternValues(
|
||||
{&if_region.then_branch(), &if_region.else_branch()});
|
||||
|
||||
// These external values need to be added as inputs to the generated If. The
|
||||
// order is determined by the order of these values the `extern_vales`.
|
||||
|
||||
// Build the type for the outlined function
|
||||
llvm::SmallVector<Type, 4> input_types;
|
||||
input_types.reserve(extern_values.size());
|
||||
for (auto input : extern_values) input_types.push_back(input.getType());
|
||||
|
||||
FunctionType func_type = FunctionType::get(
|
||||
input_types, if_region.getResultTypes(), if_region.getContext());
|
||||
|
||||
// Create 2 new functions with the input signature matching this order,
|
||||
// and outline the `then` and `else` regions by moving the bodies of these
|
||||
// regions into these functions. Replace tf.yield with a regular return.
|
||||
then_name = GetName(if_region, "_then");
|
||||
ExtractSingleBlockRegion(if_region.then_branch(), func_type, then_name,
|
||||
extern_values);
|
||||
|
||||
else_name = GetName(if_region, "_else");
|
||||
ExtractSingleBlockRegion(if_region.else_branch(), func_type, else_name,
|
||||
extern_values);
|
||||
}
|
||||
|
||||
// Once we have the `then` and `else` functions ready (either outlined or
|
||||
// existing ones), replace the region based op with a functional control flow
|
||||
// op
|
||||
OpBuilder builder(if_region);
|
||||
auto if_op = builder.create<IfOp>(
|
||||
if_region.getLoc(), if_region.getResultTypes(), if_region.cond(),
|
||||
extern_values, then_name, else_name, if_region.is_stateless());
|
||||
if_region.replaceAllUsesWith(if_op.getResults());
|
||||
if_region.erase();
|
||||
return success();
|
||||
}
|
||||
|
||||
void RegionControlFlowToFunctional::runOnFunction() {
|
||||
for (Block& block : getFunction()) {
|
||||
auto result = block.walk([&](Operation* op) {
|
||||
if (IfRegionOp if_region = llvm::dyn_cast<IfRegionOp>(op)) {
|
||||
if (failed(ConvertIfOp(if_region))) {
|
||||
if_region.emitOpError() << " failed to convert to functional form";
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
if (result.wasInterrupted()) return signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateTFRegionControlFlowToFunctional() {
|
||||
return std::make_unique<RegionControlFlowToFunctional>();
|
||||
}
|
||||
|
||||
static PassRegistration<RegionControlFlowToFunctional> pass(
|
||||
"tf-region-control-flow-to-functional",
|
||||
"Transform region bases control flow Ops to functional counterparts");
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
Loading…
Reference in New Issue
Block a user