Add transforms to convert between functional and region based If.
- Add this pair of transforms in TF -> TFLite conversion pass manager - Add inliner hook to the TF dialect to allow inlining of call's within the regions of the IfRegion ops' - Add a end-to-end test case using 2 If's that demonstrate inlining and followed by constant sinking and constant folding. PiperOrigin-RevId: 314966564 Change-Id: I633b8a24cb68d5c2c54b889f9505e7d6fcb905a4
This commit is contained in:
parent
b6b9f0815e
commit
b055058610
422
tensorflow/compiler/mlir/lite/tests/end2end/if_op.pbtxt
Normal file
422
tensorflow/compiler/mlir/lite/tests/end2end/if_op.pbtxt
Normal file
@ -0,0 +1,422 @@
|
|||||||
|
# RUN: tf_tfl_translate -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=4:4 -tf-output-arrays=StatefulIf,StatelessIf %s -o - --output-mlir | FileCheck %s --dump-input-on-failure
|
||||||
|
|
||||||
|
node {
|
||||||
|
name: "tf.Less"
|
||||||
|
op: "Less"
|
||||||
|
input: "a"
|
||||||
|
input: "b"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
experimental_debug_info {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "my_equal"
|
||||||
|
op: "Equal"
|
||||||
|
input: "a"
|
||||||
|
input: "b"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
experimental_debug_info {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "cst0"
|
||||||
|
op: "Const"
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "value"
|
||||||
|
value {
|
||||||
|
tensor {
|
||||||
|
dtype: DT_FLOAT
|
||||||
|
tensor_shape {
|
||||||
|
dim {
|
||||||
|
size: 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
float_val: 1.0
|
||||||
|
float_val: 2.0
|
||||||
|
float_val: 3.0
|
||||||
|
float_val: 4.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "cst1"
|
||||||
|
op: "Const"
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "value"
|
||||||
|
value {
|
||||||
|
tensor {
|
||||||
|
dtype: DT_FLOAT
|
||||||
|
tensor_shape {
|
||||||
|
dim {
|
||||||
|
size: 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
float_val: 5.0
|
||||||
|
float_val: 6.0
|
||||||
|
float_val: 7.0
|
||||||
|
float_val: 8.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "StatefulIf"
|
||||||
|
op: "If"
|
||||||
|
input: "tf.Less"
|
||||||
|
input: "a"
|
||||||
|
input: "b"
|
||||||
|
input: "cst0"
|
||||||
|
input: "cst1"
|
||||||
|
attr {
|
||||||
|
key: "Tcond"
|
||||||
|
value {
|
||||||
|
type: DT_BOOL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "Tin"
|
||||||
|
value {
|
||||||
|
list {
|
||||||
|
type: DT_FLOAT
|
||||||
|
type: DT_FLOAT
|
||||||
|
type: DT_FLOAT
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "Tout"
|
||||||
|
value {
|
||||||
|
list {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "else_branch"
|
||||||
|
value {
|
||||||
|
func {
|
||||||
|
name: "cond_false"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "then_branch"
|
||||||
|
value {
|
||||||
|
func {
|
||||||
|
name: "cond_true"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
experimental_debug_info {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "StatelessIf"
|
||||||
|
op: "StatelessIf"
|
||||||
|
input: "my_equal"
|
||||||
|
input: "a"
|
||||||
|
input: "b"
|
||||||
|
attr {
|
||||||
|
key: "Tcond"
|
||||||
|
value {
|
||||||
|
type: DT_BOOL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "Tin"
|
||||||
|
value {
|
||||||
|
list {
|
||||||
|
type: DT_FLOAT
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "Tout"
|
||||||
|
value {
|
||||||
|
list {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "else_branch"
|
||||||
|
value {
|
||||||
|
func {
|
||||||
|
name: "cond_false_1"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "then_branch"
|
||||||
|
value {
|
||||||
|
func {
|
||||||
|
name: "cond_true_1"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
experimental_debug_info {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "main"
|
||||||
|
op: "_Retval"
|
||||||
|
input: "StatefulIf"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "index"
|
||||||
|
value {
|
||||||
|
i: 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "main1"
|
||||||
|
op: "_Retval"
|
||||||
|
input: "StatelessIf"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "index"
|
||||||
|
value {
|
||||||
|
i: 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "a"
|
||||||
|
op: "Placeholder"
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
experimental_debug_info {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "b"
|
||||||
|
op: "Placeholder"
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
experimental_debug_info {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
library {
|
||||||
|
function {
|
||||||
|
signature {
|
||||||
|
name: "cond_true"
|
||||||
|
input_arg {
|
||||||
|
name: "cond_true_arg0"
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "cond_true_arg1"
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "cond_true_arg2"
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "cond_true_arg3"
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "cond_true_ret"
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node_def {
|
||||||
|
name: "tf.Add"
|
||||||
|
op: "Add"
|
||||||
|
input: "cond_true_arg2"
|
||||||
|
input: "cond_true_arg3"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
experimental_debug_info {
|
||||||
|
original_node_names: "tf.Add"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ret {
|
||||||
|
key: "cond_true_ret"
|
||||||
|
value: "tf.Add:z:0"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
function {
|
||||||
|
signature {
|
||||||
|
name: "cond_false"
|
||||||
|
input_arg {
|
||||||
|
name: "cond_false_arg0"
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "cond_false_arg1"
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "cond_false_arg2"
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "cond_false_arg3"
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "cond_false_ret"
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node_def {
|
||||||
|
name: "tf.Mul"
|
||||||
|
op: "Mul"
|
||||||
|
input: "cond_false_arg0"
|
||||||
|
input: "cond_false_arg3"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
experimental_debug_info {
|
||||||
|
original_node_names: "tf.Mul"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ret {
|
||||||
|
key: "cond_false_ret"
|
||||||
|
value: "tf.Mul:z:0"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
function {
|
||||||
|
signature {
|
||||||
|
name: "cond_true_1"
|
||||||
|
input_arg {
|
||||||
|
name: "cond_true_arg0"
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "cond_true_arg1"
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "cond_true_ret"
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node_def {
|
||||||
|
name: "tf.Sub"
|
||||||
|
op: "Sub"
|
||||||
|
input: "cond_true_arg0"
|
||||||
|
input: "cond_true_arg1"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
experimental_debug_info {
|
||||||
|
original_node_names: "tf.Sub"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ret {
|
||||||
|
key: "cond_true_ret"
|
||||||
|
value: "tf.Sub:z:0"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
function {
|
||||||
|
signature {
|
||||||
|
name: "cond_false_1"
|
||||||
|
input_arg {
|
||||||
|
name: "cond_false_arg0"
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "cond_false_arg1"
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "cond_false_ret"
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node_def {
|
||||||
|
name: "tf.Div"
|
||||||
|
op: "Div"
|
||||||
|
input: "cond_false_arg0"
|
||||||
|
input: "cond_false_arg1"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
experimental_debug_info {
|
||||||
|
original_node_names: "tf.Div"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ret {
|
||||||
|
key: "cond_false_ret"
|
||||||
|
value: "tf.Div:z:0"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
versions {
|
||||||
|
producer: 115
|
||||||
|
min_consumer: 12
|
||||||
|
}
|
||||||
|
|
||||||
|
# CHECK: func @StatefulIf_else
|
||||||
|
# CHECK-NEXT: constant dense<[5.000000e+00, 6.000000e+00, 7.000000e+00, 8.000000e+00]>
|
||||||
|
# CHECK-NEXT: tfl.mul
|
||||||
|
# CHECK: func @StatefulIf_then
|
||||||
|
# CHECK-NEXT: constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]>
|
||||||
|
# CHECK-NEXT: return
|
||||||
|
# CHECK: func @StatelessIf_else
|
||||||
|
# CHECK-NEXT: tfl.div
|
||||||
|
# CHECK: func @StatelessIf_then
|
||||||
|
# CHECK-NEXT: tfl.sub
|
||||||
|
# CHECK: "tf.If"{{.+}}else_branch = @StatelessIf_else{{.+}}then_branch = @StatelessIf_then
|
||||||
|
# CHECK: "tf.If"{{.+}}else_branch = @StatefulIf_else{{.+}}then_branch = @StatefulIf_then
|
||||||
|
|
@ -85,6 +85,8 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
|||||||
pass_config.quant_specs.serialized_quant_stats));
|
pass_config.quant_specs.serialized_quant_stats));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pass_manager->addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());
|
||||||
|
|
||||||
// The conversion pipeline has to follow the following orders:
|
// The conversion pipeline has to follow the following orders:
|
||||||
// 1) Saved model related optimization like decompose resource ops
|
// 1) Saved model related optimization like decompose resource ops
|
||||||
// 2) Convert composite functions like lstm/rnns, along with proper function
|
// 2) Convert composite functions like lstm/rnns, along with proper function
|
||||||
@ -128,6 +130,9 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
|||||||
// Add a shape inference pass to optimize away the unnecessary casts.
|
// Add a shape inference pass to optimize away the unnecessary casts.
|
||||||
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
|
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pass_manager->addPass(mlir::TF::CreateTFRegionControlFlowToFunctional());
|
||||||
|
|
||||||
// Legalize while early to allow further constant folding.
|
// Legalize while early to allow further constant folding.
|
||||||
// TODO(jpienaar): This may not actually matter as we do canonicalization
|
// TODO(jpienaar): This may not actually matter as we do canonicalization
|
||||||
// after the legalize below, for now it needs to be below the above passes
|
// after the legalize below, for now it needs to be below the above passes
|
||||||
|
@ -54,7 +54,6 @@ class WhileOutlinePass
|
|||||||
|
|
||||||
tensorflow::OpOrArgLocNameMapper mapper_;
|
tensorflow::OpOrArgLocNameMapper mapper_;
|
||||||
};
|
};
|
||||||
} // namespace
|
|
||||||
|
|
||||||
std::string WhileOutlinePass::GetName(Operation* op, StringRef suffix) {
|
std::string WhileOutlinePass::GetName(Operation* op, StringRef suffix) {
|
||||||
return (mapper_.GetUniqueName(op) + suffix).str();
|
return (mapper_.GetUniqueName(op) + suffix).str();
|
||||||
@ -62,7 +61,7 @@ std::string WhileOutlinePass::GetName(Operation* op, StringRef suffix) {
|
|||||||
|
|
||||||
// Returns whether the WhileOp is already outlined (e.g., only consists of calls
|
// Returns whether the WhileOp is already outlined (e.g., only consists of calls
|
||||||
// to functions).
|
// to functions).
|
||||||
static bool IsAlreadyOutlinedd(WhileOp while_op) {
|
bool IsAlreadyOutlined(WhileOp while_op) {
|
||||||
auto just_call = [](Region& region) {
|
auto just_call = [](Region& region) {
|
||||||
auto it = region.front().begin();
|
auto it = region.front().begin();
|
||||||
if (!isa<CallOp>(*it)) return false;
|
if (!isa<CallOp>(*it)) return false;
|
||||||
@ -120,7 +119,7 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Skip if already just calls.
|
// Skip if already just calls.
|
||||||
if (extra_operands.empty() && IsAlreadyOutlinedd(while_op)) return;
|
if (extra_operands.empty() && IsAlreadyOutlined(while_op)) return;
|
||||||
|
|
||||||
// Collect new types.
|
// Collect new types.
|
||||||
SmallVector<Type, 4> types;
|
SmallVector<Type, 4> types;
|
||||||
@ -238,6 +237,7 @@ void WhileOutlinePass::runOnOperation() {
|
|||||||
getOperation().walk(
|
getOperation().walk(
|
||||||
[&](mlir::TFL::WhileOp while_op) { OutlineWhile(while_op); });
|
[&](mlir::TFL::WhileOp while_op) { OutlineWhile(while_op); });
|
||||||
}
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
|
// Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass() {
|
std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass() {
|
||||||
|
@ -418,6 +418,7 @@ cc_library(
|
|||||||
"transforms/fold_switch.cc",
|
"transforms/fold_switch.cc",
|
||||||
"transforms/freeze_global_tensors.cc",
|
"transforms/freeze_global_tensors.cc",
|
||||||
"transforms/functional_control_flow_to_cfg.cc",
|
"transforms/functional_control_flow_to_cfg.cc",
|
||||||
|
"transforms/functional_control_flow_to_regions.cc",
|
||||||
"transforms/generated_canonicalize.inc",
|
"transforms/generated_canonicalize.inc",
|
||||||
"transforms/generated_optimize.inc",
|
"transforms/generated_optimize.inc",
|
||||||
"transforms/gpu_fusion.cc",
|
"transforms/gpu_fusion.cc",
|
||||||
@ -432,6 +433,7 @@ cc_library(
|
|||||||
"transforms/promote_resources_to_args.cc",
|
"transforms/promote_resources_to_args.cc",
|
||||||
"transforms/raise_control_flow.cc",
|
"transforms/raise_control_flow.cc",
|
||||||
"transforms/readonly_references_to_resources.cc",
|
"transforms/readonly_references_to_resources.cc",
|
||||||
|
"transforms/region_control_flow_to_functional.cc",
|
||||||
"transforms/replicate_invariant_op_hoisting.cc",
|
"transforms/replicate_invariant_op_hoisting.cc",
|
||||||
"transforms/replicate_to_island.cc",
|
"transforms/replicate_to_island.cc",
|
||||||
"transforms/resource_device_inference.cc",
|
"transforms/resource_device_inference.cc",
|
||||||
@ -490,6 +492,7 @@ cc_library(
|
|||||||
":translate_utils",
|
":translate_utils",
|
||||||
":unroll_batch_matmul_pass",
|
":unroll_batch_matmul_pass",
|
||||||
":xla_sharding_util",
|
":xla_sharding_util",
|
||||||
|
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
|
||||||
"//tensorflow/compiler/mlir/lite:validators",
|
"//tensorflow/compiler/mlir/lite:validators",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||||
"//tensorflow/compiler/xla:xla_proto_cc",
|
"//tensorflow/compiler/xla:xla_proto_cc",
|
||||||
|
@ -4011,6 +4011,15 @@ struct TFInlinerInterface : public DialectInlinerInterface {
|
|||||||
// Analysis Hooks
|
// Analysis Hooks
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Defines the legality of inlinining 'src' region into the 'dest' region
|
||||||
|
// attached to a TF operation
|
||||||
|
bool isLegalToInline(Region *dest, Region *src,
|
||||||
|
BlockAndValueMapping &valueMapping) const final {
|
||||||
|
// Allow inlining in regions attached to region based control flow
|
||||||
|
// operations only if the src region is a single block region
|
||||||
|
return isa<IfRegionOp>(dest->getParentOp()) && src->getBlocks().size() == 1;
|
||||||
|
}
|
||||||
|
|
||||||
// Defines the legality of inlining TF operations.
|
// Defines the legality of inlining TF operations.
|
||||||
bool isLegalToInline(Operation *, Region *,
|
bool isLegalToInline(Operation *, Region *,
|
||||||
BlockAndValueMapping &) const final {
|
BlockAndValueMapping &) const final {
|
||||||
|
@ -0,0 +1,48 @@
|
|||||||
|
// RUN: tf-opt %s -tf-functional-control-flow-to-regions -split-input-file | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
|
// CHECK: func @testIf1Then{{.+}} {sym_visibility = "private"}
|
||||||
|
// CHECK: func @testIf1Else{{.+}} {sym_visibility = "private"}
|
||||||
|
func @testIf1Then(tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
func @testIf1Else(tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @testIf1Result(%arg0: tensor<i1>, %arg1: tensor<*xf32>)
|
||||||
|
func @testIf1Result(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "tf.If"(%arg0, %arg1) {
|
||||||
|
then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false
|
||||||
|
} : (tensor<i1>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
|
||||||
|
// CHECK: "tf.IfRegion"
|
||||||
|
// CHECK: [[Result0:%.*]] = call @testIf1Then
|
||||||
|
// CHECK: "tf.Yield"([[Result0]])
|
||||||
|
// CHECK: [[Result1:%.*]] = call @testIf1Else
|
||||||
|
// CHECK: "tf.Yield"([[Result1]])
|
||||||
|
return %0 : tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// With mismatching input types
|
||||||
|
|
||||||
|
// CHECK: func @testIf1Then{{.+}} {sym_visibility = "private"}
|
||||||
|
// CHECK: func @testIf1Else{{.+}} {sym_visibility = "private"}
|
||||||
|
func @testIf1Then(tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
func @testIf1Else(tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @testIf2Result(%arg0: tensor<i1>, %arg1: tensor<2xf32>)
|
||||||
|
func @testIf2Result(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||||
|
%0 = "tf.If"(%arg0, %arg1) {
|
||||||
|
then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false
|
||||||
|
} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||||
|
|
||||||
|
// CHECK: "tf.IfRegion"
|
||||||
|
// CHECK: "tf.Cast"
|
||||||
|
// CHECK: [[Result0:%.*]] = call @testIf1Then
|
||||||
|
// CHECK: "tf.Yield"([[Result0]])
|
||||||
|
// CHECK: "tf.Cast"
|
||||||
|
// CHECK: [[Result1:%.*]] = call @testIf1Else
|
||||||
|
// CHECK: "tf.Yield"([[Result1]])
|
||||||
|
return %0 : tensor<2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -0,0 +1,142 @@
|
|||||||
|
// RUN: tf-opt %s -tf-region-control-flow-to-functional -split-input-file
|
||||||
|
//| FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
|
// CHECK: func @tf.IfRegion_else(%arg0: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"}
|
||||||
|
// CHECK-NEXT: "tf.Neg"
|
||||||
|
// CHECK: func @tf.IfRegion_then(%arg0: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"}
|
||||||
|
// CHECK-NEXT: "tf.Abs"
|
||||||
|
func @testSimple(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
|
// CHECK: "tf.If"{{.+}}else_branch = @tf.IfRegion_else{{.+}}then_branch = @tf.IfRegion_then
|
||||||
|
%0 = "tf.IfRegion"(%arg0) ({
|
||||||
|
%1 = "tf.Abs"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
"tf.Yield"(%1) : (tensor<*xf32>) -> ()
|
||||||
|
}, {
|
||||||
|
%2 = "tf.Neg"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
"tf.Yield"(%2) : (tensor<*xf32>) -> ()
|
||||||
|
}) { is_stateless = true } : (tensor<i1>) -> tensor<*xf32>
|
||||||
|
return %0 : tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Use if condition inside the regions
|
||||||
|
// CHECK: func @tf.IfRegion_else(%arg0: tensor<i1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>) -> tensor<2xf32> attributes {sym_visibility = "private"}
|
||||||
|
// CHECK-NEXT: "tf.Select"(%arg0, %arg2, %arg3)
|
||||||
|
// CHECK: func @tf.IfRegion_then(%arg0: tensor<i1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>) -> tensor<2xf32> attributes {sym_visibility = "private"}
|
||||||
|
// CHECK-NEXT: "tf.Select"(%arg0, %arg1, %arg2)
|
||||||
|
func @testIfCondition(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||||
|
%0 = "tf.Add"(%arg1, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
|
||||||
|
%1 = "tf.Mul"(%arg1, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
|
||||||
|
%2 = "tf.Div"(%arg1, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
|
||||||
|
|
||||||
|
// CHECK: "tf.If"{{.+}}else_branch = @tf.IfRegion_else{{.+}}then_branch = @tf.IfRegion_then
|
||||||
|
%3 = "tf.IfRegion"(%arg0) ({
|
||||||
|
%4 = "tf.Select"(%arg0, %0, %1) : (tensor<i1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
|
||||||
|
"tf.Yield"(%4) : (tensor<2xf32>) -> ()
|
||||||
|
}, {
|
||||||
|
%5 = "tf.Select"(%arg0, %1, %2): (tensor<i1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
|
||||||
|
"tf.Yield"(%5) : (tensor<2xf32>) -> ()
|
||||||
|
}) { is_stateless = true} : (tensor<i1>) -> tensor<2xf32>
|
||||||
|
return %3 : tensor<2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Constant sinking
|
||||||
|
|
||||||
|
// CHECK: func @tf.IfRegion_else() -> tensor<2xf32>
|
||||||
|
// CHECK-NEXT: constant dense<1.0
|
||||||
|
// CHECK: func @tf.IfRegion_then() -> tensor<2xf32>
|
||||||
|
// CHECK-NEXT: constant dense<0.0
|
||||||
|
func @testIfConstant(%arg0: tensor<i1>) -> tensor<2xf32> {
|
||||||
|
%cst_zero = constant dense<0.0> : tensor<2xf32>
|
||||||
|
// CHECK: "tf.If"(%arg0) {else_branch = @tf.IfRegion_else{{.+}}then_branch = @tf.IfRegion_then
|
||||||
|
%0 = "tf.IfRegion"(%arg0) ({
|
||||||
|
"tf.Yield"(%cst_zero) : (tensor<2xf32>) -> ()
|
||||||
|
}, {
|
||||||
|
%cst_one = constant dense<1.0> : tensor<2xf32>
|
||||||
|
"tf.Yield"(%cst_one) : (tensor<2xf32>) -> ()
|
||||||
|
}) { is_stateless = true} : (tensor<i1>) -> tensor<2xf32>
|
||||||
|
return %0 : tensor<2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Nested IfRegions
|
||||||
|
// CHECK: func @tf.IfRegion1_else
|
||||||
|
// CHECK-NEXT: "tf.Acos"
|
||||||
|
// CHECK-NEXT: "tf.Abs"
|
||||||
|
|
||||||
|
// CHECK: func @tf.IfRegion1_then
|
||||||
|
// CHECK-NEXT: "tf.LogicalNot"
|
||||||
|
// CHECK-NEXT: "tf.Asin"
|
||||||
|
// CHECK-NEXT: "tf.If"({{.+}}) {else_branch = @tf.IfRegion_else, {{.+}} then_branch = @tf.IfRegion_then}
|
||||||
|
|
||||||
|
// CHECK: func @tf.IfRegion_else
|
||||||
|
// CHECK-NEXT: "tf.Neg"
|
||||||
|
// CHECK: func @tf.IfRegion_then
|
||||||
|
// CHECK-NEXT: "tf.Abs"
|
||||||
|
|
||||||
|
func @testNested(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
|
// CHECK: "tf.If"({{.+}}) {else_branch = @tf.IfRegion1_else, {{.+}} then_branch = @tf.IfRegion1_then}
|
||||||
|
%0 = "tf.IfRegion"(%arg0) ({
|
||||||
|
// Outer Then
|
||||||
|
%cond = "tf.LogicalNot"(%arg0) : (tensor<i1>) -> tensor<i1>
|
||||||
|
%asin = "tf.Asin"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
|
||||||
|
// nested IfRegion
|
||||||
|
%1 = "tf.IfRegion"(%cond) ({
|
||||||
|
%2 = "tf.Abs"(%asin) : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
"tf.Yield"(%2) : (tensor<*xf32>) -> ()
|
||||||
|
}, {
|
||||||
|
%2 = "tf.Neg"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
"tf.Yield"(%2) : (tensor<*xf32>) -> ()
|
||||||
|
}) { is_stateless = true } : (tensor<i1>) -> tensor<*xf32>
|
||||||
|
|
||||||
|
"tf.Yield"(%1) : (tensor<*xf32>) -> ()
|
||||||
|
}, {
|
||||||
|
// Outer Else
|
||||||
|
%acos = "tf.Acos"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
%3 = "tf.Abs"(%acos) : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
"tf.Yield"(%3) : (tensor<*xf32>) -> ()
|
||||||
|
}) { is_stateless = true } : (tensor<i1>) -> tensor<*xf32>
|
||||||
|
return %0 : tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Match existing function->Region pattern (simple)
|
||||||
|
func @testIf1Then(tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"}
|
||||||
|
func @testIf1Else(tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"}
|
||||||
|
func @testIf1Result(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
|
// CHECK: "tf.If"({{.+}}) {else_branch = @testIf1Else, {{.+}} then_branch = @testIf1Then}
|
||||||
|
%0 = "tf.IfRegion"(%arg0) ( {
|
||||||
|
%1 = call @testIf1Then(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
"tf.Yield"(%1) : (tensor<*xf32>) -> ()
|
||||||
|
}, {
|
||||||
|
%1 = call @testIf1Else(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
"tf.Yield"(%1) : (tensor<*xf32>) -> ()
|
||||||
|
}) {is_stateless = false} : (tensor<i1>) -> tensor<*xf32>
|
||||||
|
return %0 : tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Match existing function->Region pattern (with casts)
|
||||||
|
|
||||||
|
func @testIf1Then(tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"}
|
||||||
|
func @testIf1Else(tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"}
|
||||||
|
func @testIf2Result(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||||
|
// CHECK: "tf.If"({{.+}}) {else_branch = @testIf1Else, {{.+}} then_branch = @testIf1Then}
|
||||||
|
%0 = "tf.IfRegion"(%arg0) ( {
|
||||||
|
%1 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xf32>) -> tensor<*xf32>
|
||||||
|
%2 = call @testIf1Then(%1) : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
"tf.Yield"(%2) : (tensor<*xf32>) -> ()
|
||||||
|
}, {
|
||||||
|
%1 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xf32>) -> tensor<*xf32>
|
||||||
|
%2 = call @testIf1Else(%1) : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
"tf.Yield"(%2) : (tensor<*xf32>) -> ()
|
||||||
|
}) {is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
|
||||||
|
return %0 : tensor<2xf32>
|
||||||
|
}
|
||||||
|
|
@ -14,7 +14,7 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
// This transformation pass transforms functional control flow operations in the
|
// This transformation pass transforms functional control flow operations in the
|
||||||
// standard TensorFlow dialect to MLIR Control Flow Graph (CFG) form.
|
// TensorFlow dialect to MLIR Control Flow Graph (CFG) form.
|
||||||
|
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||||
@ -52,7 +52,6 @@ static Value LowerCondition(Location loc, Value value, OpBuilder* builder) {
|
|||||||
//
|
//
|
||||||
// Requires the function to provide arguments for each of the `fn` operands
|
// Requires the function to provide arguments for each of the `fn` operands
|
||||||
// that is compatible for tensor cast.
|
// that is compatible for tensor cast.
|
||||||
//
|
|
||||||
static Operation* CallFn(Location loc, const std::function<Value(int)>& get_arg,
|
static Operation* CallFn(Location loc, const std::function<Value(int)>& get_arg,
|
||||||
FuncOp fn, OpBuilder* builder) {
|
FuncOp fn, OpBuilder* builder) {
|
||||||
FunctionType fn_type = fn.getType();
|
FunctionType fn_type = fn.getType();
|
||||||
@ -113,7 +112,6 @@ static void JumpToBlock(Location loc, const std::function<Value(int)>& get_arg,
|
|||||||
// Requires that the block has same number of arguments as number of results of
|
// Requires that the block has same number of arguments as number of results of
|
||||||
// the operation and either they have same types or are more generic types and
|
// the operation and either they have same types or are more generic types and
|
||||||
// it is possible to cast them to results' types.
|
// it is possible to cast them to results' types.
|
||||||
//
|
|
||||||
static void ReplaceOpResultWithBlockArgs(Location loc, Operation* op,
|
static void ReplaceOpResultWithBlockArgs(Location loc, Operation* op,
|
||||||
Block* block, OpBuilder* builder) {
|
Block* block, OpBuilder* builder) {
|
||||||
assert(op->getNumResults() == block->getNumArguments());
|
assert(op->getNumResults() == block->getNumArguments());
|
||||||
@ -132,9 +130,6 @@ static void ReplaceOpResultWithBlockArgs(Location loc, Operation* op,
|
|||||||
// Given a functional IfOp, transforms the enclosing code to eliminate it
|
// Given a functional IfOp, transforms the enclosing code to eliminate it
|
||||||
// completely from the IR, breaking it into operations to evaluate the condition
|
// completely from the IR, breaking it into operations to evaluate the condition
|
||||||
// as a bool, plus some branches.
|
// as a bool, plus some branches.
|
||||||
//
|
|
||||||
// This returns true on failure.
|
|
||||||
//
|
|
||||||
static LogicalResult LowerIfOp(IfOp op) {
|
static LogicalResult LowerIfOp(IfOp op) {
|
||||||
Operation* op_inst = op.getOperation();
|
Operation* op_inst = op.getOperation();
|
||||||
Location loc = op_inst->getLoc();
|
Location loc = op_inst->getLoc();
|
||||||
@ -193,9 +188,6 @@ static LogicalResult LowerIfOp(IfOp op) {
|
|||||||
// Given a functional WhileOp, transforms the enclosing code to eliminate it
|
// Given a functional WhileOp, transforms the enclosing code to eliminate it
|
||||||
// completely from the IR, breaking it into operations to execute the loop body
|
// completely from the IR, breaking it into operations to execute the loop body
|
||||||
// repeatedly while the loop condition is true.
|
// repeatedly while the loop condition is true.
|
||||||
//
|
|
||||||
// This returns true on failure.
|
|
||||||
//
|
|
||||||
static LogicalResult LowerWhileOp(WhileOp op) {
|
static LogicalResult LowerWhileOp(WhileOp op) {
|
||||||
Operation* op_inst = op.getOperation();
|
Operation* op_inst = op.getOperation();
|
||||||
Location loc = op_inst->getLoc();
|
Location loc = op_inst->getLoc();
|
||||||
|
@ -0,0 +1,118 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This transformation pass transforms functional control flow operations in the
|
||||||
|
// TensorFlow dialect to their region based counterparts, i.e.,
|
||||||
|
// tf.If -> tf.IfRegion
|
||||||
|
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Value.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Visitors.h" // from @llvm-project
|
||||||
|
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||||
|
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace TF {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct FunctionalControlFlowToRegions
|
||||||
|
: public PassWrapper<FunctionalControlFlowToRegions, FunctionPass> {
|
||||||
|
void runOnFunction() override;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Create a call to function `fn` with arguments `args` and return the CallOp.
|
||||||
|
// The arguments are cast to the required type before the call.
|
||||||
|
CallOp CreateCall(Location loc, Operation::operand_range args, FuncOp fn,
|
||||||
|
OpBuilder* builder) {
|
||||||
|
FunctionType fn_type = fn.getType();
|
||||||
|
llvm::SmallVector<Value, 4> operands;
|
||||||
|
int num_operands = fn_type.getNumInputs();
|
||||||
|
operands.reserve(num_operands);
|
||||||
|
for (int i = 0; i < num_operands; ++i) {
|
||||||
|
Value arg = args[i];
|
||||||
|
Type expected = fn_type.getInput(i);
|
||||||
|
if (arg.getType() != expected) {
|
||||||
|
arg = builder->create<CastOp>(loc, expected, arg,
|
||||||
|
/*Truncate=*/builder->getBoolAttr(false));
|
||||||
|
}
|
||||||
|
operands.push_back(arg);
|
||||||
|
}
|
||||||
|
return builder->create<CallOp>(loc, fn, operands);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transform a functional IfOp to a region based IfRegionOp
|
||||||
|
LogicalResult ConvertIfOp(IfOp if_op) {
|
||||||
|
auto if_region = OpBuilder(if_op).create<TF::IfRegionOp>(
|
||||||
|
if_op.getLoc(), if_op.getResultTypes(), if_op.cond(),
|
||||||
|
if_op.is_stateless());
|
||||||
|
|
||||||
|
// Insert call to the given function into the 'region'.
|
||||||
|
auto create_region_with_call = [&if_op](FlatSymbolRefAttr symbol,
|
||||||
|
Region& region) {
|
||||||
|
OpBuilder builder(region);
|
||||||
|
builder.createBlock(®ion);
|
||||||
|
auto func = if_op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
|
||||||
|
symbol.getValue());
|
||||||
|
auto call = CreateCall(if_op.getLoc(), if_op.input(), func, &builder);
|
||||||
|
builder.create<YieldOp>(if_op.getLoc(), call.getResults());
|
||||||
|
// Mark old function as private so that it can be DCE'd if not called.
|
||||||
|
func.setVisibility(SymbolTable::Visibility::Private);
|
||||||
|
};
|
||||||
|
|
||||||
|
create_region_with_call(if_op.then_branchAttr(), if_region.then_branch());
|
||||||
|
create_region_with_call(if_op.else_branchAttr(), if_region.else_branch());
|
||||||
|
|
||||||
|
if_op.replaceAllUsesWith(if_region.getResults());
|
||||||
|
if_op.erase();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
void FunctionalControlFlowToRegions::runOnFunction() {
|
||||||
|
for (Block& block : getFunction()) {
|
||||||
|
auto result = block.walk([](Operation* op) {
|
||||||
|
if (IfOp if_op = llvm::dyn_cast<IfOp>(op)) {
|
||||||
|
if (failed(ConvertIfOp(if_op))) {
|
||||||
|
if_op.emitOpError() << " failed to convert to region form";
|
||||||
|
return WalkResult::interrupt();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return WalkResult::advance();
|
||||||
|
});
|
||||||
|
|
||||||
|
if (result.wasInterrupted()) return signalPassFailure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>>
|
||||||
|
CreateTFFunctionalControlFlowToRegions() {
|
||||||
|
return std::make_unique<FunctionalControlFlowToRegions>();
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<FunctionalControlFlowToRegions> pass(
|
||||||
|
"tf-functional-control-flow-to-regions",
|
||||||
|
"Transform functional control flow Ops to Region based counterparts");
|
||||||
|
|
||||||
|
} // namespace TF
|
||||||
|
} // namespace mlir
|
@ -32,10 +32,18 @@ std::unique_ptr<OperationPass<FuncOp>>
|
|||||||
CreateFunctionalToExecutorDialectConversionPass();
|
CreateFunctionalToExecutorDialectConversionPass();
|
||||||
|
|
||||||
namespace TF {
|
namespace TF {
|
||||||
// Transforms functional control flow operations in the standard TensorFlow
|
// Transforms functional control flow operations in the TensorFlow dialect to
|
||||||
// dialect to MLIR Control Flow Graph (CFG) form.
|
// MLIR Control Flow Graph (CFG) form.
|
||||||
std::unique_ptr<OperationPass<FuncOp>> CreateTFFunctionalControlFlowToCFG();
|
std::unique_ptr<OperationPass<FuncOp>> CreateTFFunctionalControlFlowToCFG();
|
||||||
|
|
||||||
|
// Transforms functional control flow operations in the TensorFlow dialect to
|
||||||
|
// their region based counterparts.
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> CreateTFFunctionalControlFlowToRegions();
|
||||||
|
|
||||||
|
// Transforms region bases control flow operations in the TensorFlow dialect to
|
||||||
|
// their functional counterparts.
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> CreateTFRegionControlFlowToFunctional();
|
||||||
|
|
||||||
// Materialize the MlirPassthroughOp by replacing it with the MLIR module
|
// Materialize the MlirPassthroughOp by replacing it with the MLIR module
|
||||||
// attached as an attribute.
|
// attached as an attribute.
|
||||||
std::unique_ptr<OperationPass<FuncOp>> CreateMaterializePassthroughOpPass();
|
std::unique_ptr<OperationPass<FuncOp>> CreateMaterializePassthroughOpPass();
|
||||||
|
@ -0,0 +1,321 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This transformation pass transforms region bases control flow operations in
|
||||||
|
// the TensorFlow dialect to their functional counterparts, i.e.,
|
||||||
|
// tf.IfRegion -> tf.If
|
||||||
|
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Module.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Value.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Visitors.h" // from @llvm-project
|
||||||
|
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||||
|
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
|
||||||
|
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
|
||||||
|
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace TF {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct RegionControlFlowToFunctional
|
||||||
|
: public PassWrapper<RegionControlFlowToFunctional, FunctionPass> {
|
||||||
|
void runOnFunction() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
LogicalResult ConvertIfOp(IfRegionOp if_region);
|
||||||
|
|
||||||
|
// Get unique name by using the loc to name mapping.
|
||||||
|
std::string GetName(Operation* op, StringRef suffix);
|
||||||
|
|
||||||
|
tensorflow::OpOrArgLocNameMapper mapper;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::string RegionControlFlowToFunctional::GetName(Operation* op,
|
||||||
|
StringRef suffix) {
|
||||||
|
return (mapper.GetUniqueName(op) + suffix).str();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Given a list of regions, return all the external values referenced from the
|
||||||
|
// region. If the external value is a constant, sink it into the region instead
|
||||||
|
// (and do not add it to the returned vector).
|
||||||
|
llvm::SmallVector<Value, 4> CollectExternValues(ArrayRef<Region*> regions) {
|
||||||
|
llvm::SetVector<Value> extern_values_set;
|
||||||
|
|
||||||
|
for (auto region : regions) {
|
||||||
|
llvm::SetVector<Value> region_extern_values;
|
||||||
|
getUsedValuesDefinedAbove(*region, region_extern_values);
|
||||||
|
|
||||||
|
// Sink down constants into the functions.
|
||||||
|
for (auto extern_value : region_extern_values) {
|
||||||
|
if (!matchPattern(extern_value, m_Constant())) {
|
||||||
|
extern_values_set.insert(extern_value);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// Add constant at start of region.
|
||||||
|
auto const_builder = OpBuilder::atBlockBegin(®ion->front());
|
||||||
|
auto const_value = const_builder.clone(*extern_value.getDefiningOp());
|
||||||
|
replaceAllUsesInRegionWith(extern_value, const_value->getResult(0),
|
||||||
|
*region);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {extern_values_set.begin(), extern_values_set.end()};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the contents of a region with a single block into a new function.
|
||||||
|
// `extern_values` is the set of external values that the region refers to.
|
||||||
|
//
|
||||||
|
// Any inputs to the terminator of the region are converted to return values of
|
||||||
|
// the function. If any of these values is not exact type as the function's
|
||||||
|
// return type, appropriate cast operations will be inserted
|
||||||
|
void ExtractSingleBlockRegion(Region& region, FunctionType type, StringRef name,
|
||||||
|
llvm::SmallVector<Value, 4>& extern_values) {
|
||||||
|
ModuleOp module = region.getParentOfType<ModuleOp>();
|
||||||
|
auto builder = OpBuilder::atBlockBegin(module.getBody());
|
||||||
|
auto loc = region.getParentOp()->getLoc();
|
||||||
|
|
||||||
|
// Create new function and extract region body into the function
|
||||||
|
auto outlined_func =
|
||||||
|
builder.create<FuncOp>(loc, name, type, ArrayRef<NamedAttribute>{});
|
||||||
|
|
||||||
|
outlined_func.getBody().takeBody(region);
|
||||||
|
Region& func_region = outlined_func.getBody();
|
||||||
|
Block& first_block = func_region.front();
|
||||||
|
|
||||||
|
// Replace all external uses with function arguments.
|
||||||
|
for (auto it : llvm::enumerate(extern_values)) {
|
||||||
|
Value arg = first_block.addArgument(it.value().getType());
|
||||||
|
replaceAllUsesInRegionWith(it.value(), arg, func_region);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace the existing terminator with a return.
|
||||||
|
Operation* terminator = outlined_func.getBody().front().getTerminator();
|
||||||
|
builder.setInsertionPoint(terminator);
|
||||||
|
|
||||||
|
SmallVector<Value, 4> return_values;
|
||||||
|
return_values.reserve(terminator->getNumOperands());
|
||||||
|
for (auto it : llvm::enumerate(type.getResults())) {
|
||||||
|
Value ret_val = terminator->getOperand(it.index());
|
||||||
|
// Add a cast operation if types do not match.
|
||||||
|
if (ret_val.getType() != it.value()) {
|
||||||
|
ret_val =
|
||||||
|
builder.create<CastOp>(terminator->getLoc(), it.value(), ret_val);
|
||||||
|
}
|
||||||
|
return_values.push_back(ret_val);
|
||||||
|
}
|
||||||
|
builder.create<ReturnOp>(terminator->getLoc(), return_values);
|
||||||
|
terminator->erase();
|
||||||
|
outlined_func.setVisibility(FuncOp::Visibility::Private);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns call for region with single call whose result feeds into the
|
||||||
|
// terminator of the region. Returns none if the region doesn't contain just
|
||||||
|
// call and casts ops.
|
||||||
|
llvm::Optional<CallOp> IsSingleCallRegion(Region& region) {
|
||||||
|
if (region.getBlocks().size() != 1) return llvm::None;
|
||||||
|
|
||||||
|
auto it = region.front().rbegin();
|
||||||
|
YieldOp yield = dyn_cast<YieldOp>(*it++);
|
||||||
|
if (yield.getNumOperands() == 0) return llvm::None;
|
||||||
|
|
||||||
|
if (it == region.front().rend()) return llvm::None;
|
||||||
|
|
||||||
|
// and a Call before that
|
||||||
|
CallOp call = dyn_cast<CallOp>(*it++);
|
||||||
|
if (!call) return llvm::None;
|
||||||
|
|
||||||
|
// There can be cast op's prior to that
|
||||||
|
for (; it != region.front().rend(); ++it)
|
||||||
|
if (!isa<CastOp>(*it)) return llvm::None;
|
||||||
|
|
||||||
|
// all results of the call should feed into the yield
|
||||||
|
if (call.getNumResults() != yield.getNumOperands()) return llvm::None;
|
||||||
|
|
||||||
|
for (auto res_it : llvm::zip(call.getResults(), yield.getOperands()))
|
||||||
|
if (std::get<0>(res_it) != std::get<1>(res_it)) return llvm::None;
|
||||||
|
|
||||||
|
return call;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns whether the arguments of the given call are same as the given list of
|
||||||
|
// arguments (after looking through cast ops).
|
||||||
|
bool MatchCallArgs(CallOp call, llvm::SmallVectorImpl<Value>& args) {
|
||||||
|
if (call.getNumOperands() != args.size()) return false;
|
||||||
|
|
||||||
|
for (auto it : llvm::enumerate(args)) {
|
||||||
|
Value arg = call.getOperand(it.index());
|
||||||
|
if (auto cast = dyn_cast_or_null<CastOp>(arg.getDefiningOp()))
|
||||||
|
arg = cast.getOperand();
|
||||||
|
|
||||||
|
if (arg != it.value()) return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Summary information for trivially transforming region based op's to
|
||||||
|
// functional ops. A trivial transformation can be done when the regions are
|
||||||
|
// just calls to functions, in which case no outlining is needed.
|
||||||
|
struct TrivialTransformInfo {
|
||||||
|
// can the op be transformed trivially
|
||||||
|
bool can_transform = false;
|
||||||
|
|
||||||
|
// list of callee names (one for each region)
|
||||||
|
llvm::SmallVector<StringRef, 4> callee_names;
|
||||||
|
|
||||||
|
// list of arguments used in these call (each call uses the same arguments
|
||||||
|
// potentially through casts)
|
||||||
|
llvm::SmallVector<Value, 4> call_args;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Analyze the given set of regions (attached to the same parent op) to check
|
||||||
|
// if the parent op be transformed to functional form trivially (i.e., reusing
|
||||||
|
// existing functions and without outlining). This is possible when all the
|
||||||
|
// regions are single call regions and the all the calls have the same
|
||||||
|
// arguments.
|
||||||
|
//
|
||||||
|
// If this trivial transformation is possible, return the relevant information
|
||||||
|
// needed for the transformation (in `TrivialTransformInfo`), else indicate that
|
||||||
|
// a trivial transformation is not possible by setting `can_transform` false.
|
||||||
|
TrivialTransformInfo AnalyzeForTrivialTransform(ArrayRef<Region*> regions) {
|
||||||
|
const TrivialTransformInfo cannot_transform;
|
||||||
|
|
||||||
|
if (regions.empty()) return cannot_transform;
|
||||||
|
|
||||||
|
llvm::SmallVector<CallOp, 2> calls;
|
||||||
|
calls.reserve(regions.size());
|
||||||
|
|
||||||
|
// Verify each region is a single call and collect these calls.
|
||||||
|
for (Region* region : regions) {
|
||||||
|
auto call = IsSingleCallRegion(*region);
|
||||||
|
if (!call.hasValue()) return cannot_transform;
|
||||||
|
calls.push_back(call.getValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<StringRef, 4> callees;
|
||||||
|
callees.reserve(regions.size());
|
||||||
|
|
||||||
|
CallOp call0 = calls[0];
|
||||||
|
int num_args = call0.getNumOperands();
|
||||||
|
callees.push_back(call0.getCallee());
|
||||||
|
|
||||||
|
// collect call0 arguments
|
||||||
|
llvm::SmallVector<Value, 4> call0_args;
|
||||||
|
call0_args.reserve(num_args);
|
||||||
|
for (int arg_idx = 0; arg_idx < num_args; ++arg_idx) {
|
||||||
|
Value arg = call0.getOperand(arg_idx);
|
||||||
|
if (auto cast = dyn_cast_or_null<CastOp>(arg.getDefiningOp()))
|
||||||
|
arg = cast.getOperand();
|
||||||
|
call0_args.push_back(arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
// match other call arguments with call0 arguments
|
||||||
|
for (int idx = 1; idx < calls.size(); ++idx) {
|
||||||
|
if (!MatchCallArgs(calls[idx], call0_args)) return cannot_transform;
|
||||||
|
callees.push_back(calls[idx].getCallee());
|
||||||
|
}
|
||||||
|
|
||||||
|
return {true, callees, call0_args};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transform IfRegionOp to IfOp
|
||||||
|
LogicalResult RegionControlFlowToFunctional::ConvertIfOp(IfRegionOp if_region) {
|
||||||
|
const TrivialTransformInfo tti = AnalyzeForTrivialTransform(
|
||||||
|
{&if_region.then_branch(), &if_region.else_branch()});
|
||||||
|
|
||||||
|
std::string then_name, else_name;
|
||||||
|
llvm::SmallVector<Value, 4> extern_values;
|
||||||
|
|
||||||
|
if (tti.can_transform) {
|
||||||
|
// We can transform to functional form trivially without outlining
|
||||||
|
then_name = tti.callee_names[0].str();
|
||||||
|
else_name = tti.callee_names[1].str();
|
||||||
|
extern_values = tti.call_args;
|
||||||
|
} else {
|
||||||
|
// Collect external values that are used within the else and then bodies.
|
||||||
|
extern_values = CollectExternValues(
|
||||||
|
{&if_region.then_branch(), &if_region.else_branch()});
|
||||||
|
|
||||||
|
// These external values need to be added as inputs to the generated If. The
|
||||||
|
// order is determined by the order of these values the `extern_vales`.
|
||||||
|
|
||||||
|
// Build the type for the outlined function
|
||||||
|
llvm::SmallVector<Type, 4> input_types;
|
||||||
|
input_types.reserve(extern_values.size());
|
||||||
|
for (auto input : extern_values) input_types.push_back(input.getType());
|
||||||
|
|
||||||
|
FunctionType func_type = FunctionType::get(
|
||||||
|
input_types, if_region.getResultTypes(), if_region.getContext());
|
||||||
|
|
||||||
|
// Create 2 new functions with the input signature matching this order,
|
||||||
|
// and outline the `then` and `else` regions by moving the bodies of these
|
||||||
|
// regions into these functions. Replace tf.yield with a regular return.
|
||||||
|
then_name = GetName(if_region, "_then");
|
||||||
|
ExtractSingleBlockRegion(if_region.then_branch(), func_type, then_name,
|
||||||
|
extern_values);
|
||||||
|
|
||||||
|
else_name = GetName(if_region, "_else");
|
||||||
|
ExtractSingleBlockRegion(if_region.else_branch(), func_type, else_name,
|
||||||
|
extern_values);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Once we have the `then` and `else` functions ready (either outlined or
|
||||||
|
// existing ones), replace the region based op with a functional control flow
|
||||||
|
// op
|
||||||
|
OpBuilder builder(if_region);
|
||||||
|
auto if_op = builder.create<IfOp>(
|
||||||
|
if_region.getLoc(), if_region.getResultTypes(), if_region.cond(),
|
||||||
|
extern_values, then_name, else_name, if_region.is_stateless());
|
||||||
|
if_region.replaceAllUsesWith(if_op.getResults());
|
||||||
|
if_region.erase();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
void RegionControlFlowToFunctional::runOnFunction() {
|
||||||
|
for (Block& block : getFunction()) {
|
||||||
|
auto result = block.walk([&](Operation* op) {
|
||||||
|
if (IfRegionOp if_region = llvm::dyn_cast<IfRegionOp>(op)) {
|
||||||
|
if (failed(ConvertIfOp(if_region))) {
|
||||||
|
if_region.emitOpError() << " failed to convert to functional form";
|
||||||
|
return WalkResult::interrupt();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return WalkResult::advance();
|
||||||
|
});
|
||||||
|
|
||||||
|
if (result.wasInterrupted()) return signalPassFailure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> CreateTFRegionControlFlowToFunctional() {
|
||||||
|
return std::make_unique<RegionControlFlowToFunctional>();
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<RegionControlFlowToFunctional> pass(
|
||||||
|
"tf-region-control-flow-to-functional",
|
||||||
|
"Transform region bases control flow Ops to functional counterparts");
|
||||||
|
|
||||||
|
} // namespace TF
|
||||||
|
} // namespace mlir
|
Loading…
x
Reference in New Issue
Block a user