Add transforms to convert between functional and region-based If.

- Add this pair of transforms to the TF -> TFLite conversion pass manager.
- Add an inliner hook to the TF dialect to allow inlining of calls within the
  regions of IfRegion ops.
- Add an end-to-end test case using 2 Ifs that demonstrates inlining followed
  by constant sinking and constant folding (see the sketch after this list).
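
A condensed sketch of the round trip (adapted from the tests in this change;
types and attributes abbreviated, function names illustrative):

  // Functional form:
  %0 = "tf.If"(%cond, %arg) {then_branch = @then, else_branch = @else, ...}

  // Region-based form produced by the first transform:
  %0 = "tf.IfRegion"(%cond) ({
    %1 = call @then(%arg) : (tensor<*xf32>) -> tensor<*xf32>
    "tf.Yield"(%1) : (tensor<*xf32>) -> ()
  }, {
    %2 = call @else(%arg) : (tensor<*xf32>) -> tensor<*xf32>
    "tf.Yield"(%2) : (tensor<*xf32>) -> ()
  }) {is_stateless = false} : (tensor<i1>) -> tensor<*xf32>

After inlining and constant sinking have run on the region form, the second
transform outlines the regions back into private functions and rebuilds a
functional tf.If.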

PiperOrigin-RevId: 314966564
Change-Id: I633b8a24cb68d5c2c54b889f9505e7d6fcb905a4
Rahul Joshi 2020-06-05 11:43:21 -07:00 committed by TensorFlower Gardener
parent b6b9f0815e
commit b055058610
11 changed files with 1082 additions and 14 deletions

View File

@ -0,0 +1,422 @@
# RUN: tf_tfl_translate -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=4:4 -tf-output-arrays=StatefulIf,StatelessIf %s -o - --output-mlir | FileCheck %s --dump-input-on-failure
node {
name: "tf.Less"
op: "Less"
input: "a"
input: "b"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
}
}
node {
name: "my_equal"
op: "Equal"
input: "a"
input: "b"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
}
}
node {
name: "cst0"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 4
}
}
float_val: 1.0
float_val: 2.0
float_val: 3.0
float_val: 4.0
}
}
}
}
node {
name: "cst1"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 4
}
}
float_val: 5.0
float_val: 6.0
float_val: 7.0
float_val: 8.0
}
}
}
}
node {
name: "StatefulIf"
op: "If"
input: "tf.Less"
input: "a"
input: "b"
input: "cst0"
input: "cst1"
attr {
key: "Tcond"
value {
type: DT_BOOL
}
}
attr {
key: "Tin"
value {
list {
type: DT_FLOAT
type: DT_FLOAT
type: DT_FLOAT
type: DT_FLOAT
}
}
}
attr {
key: "Tout"
value {
list {
type: DT_FLOAT
}
}
}
attr {
key: "else_branch"
value {
func {
name: "cond_false"
}
}
}
attr {
key: "then_branch"
value {
func {
name: "cond_true"
}
}
}
experimental_debug_info {
}
}
node {
name: "StatelessIf"
op: "StatelessIf"
input: "my_equal"
input: "a"
input: "b"
attr {
key: "Tcond"
value {
type: DT_BOOL
}
}
attr {
key: "Tin"
value {
list {
type: DT_FLOAT
type: DT_FLOAT
}
}
}
attr {
key: "Tout"
value {
list {
type: DT_FLOAT
}
}
}
attr {
key: "else_branch"
value {
func {
name: "cond_false_1"
}
}
}
attr {
key: "then_branch"
value {
func {
name: "cond_true_1"
}
}
}
experimental_debug_info {
}
}
node {
name: "main"
op: "_Retval"
input: "StatefulIf"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "index"
value {
i: 0
}
}
}
node {
name: "main1"
op: "_Retval"
input: "StatelessIf"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "index"
value {
i: 1
}
}
}
node {
name: "a"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
}
}
node {
name: "b"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
}
}
library {
function {
signature {
name: "cond_true"
input_arg {
name: "cond_true_arg0"
type: DT_FLOAT
}
input_arg {
name: "cond_true_arg1"
type: DT_FLOAT
}
input_arg {
name: "cond_true_arg2"
type: DT_FLOAT
}
input_arg {
name: "cond_true_arg3"
type: DT_FLOAT
}
output_arg {
name: "cond_true_ret"
type: DT_FLOAT
}
}
node_def {
name: "tf.Add"
op: "Add"
input: "cond_true_arg2"
input: "cond_true_arg3"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
original_node_names: "tf.Add"
}
}
ret {
key: "cond_true_ret"
value: "tf.Add:z:0"
}
}
function {
signature {
name: "cond_false"
input_arg {
name: "cond_false_arg0"
type: DT_FLOAT
}
input_arg {
name: "cond_false_arg1"
type: DT_FLOAT
}
input_arg {
name: "cond_false_arg2"
type: DT_FLOAT
}
input_arg {
name: "cond_false_arg3"
type: DT_FLOAT
}
output_arg {
name: "cond_false_ret"
type: DT_FLOAT
}
}
node_def {
name: "tf.Mul"
op: "Mul"
input: "cond_false_arg0"
input: "cond_false_arg3"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
original_node_names: "tf.Mul"
}
}
ret {
key: "cond_false_ret"
value: "tf.Mul:z:0"
}
}
function {
signature {
name: "cond_true_1"
input_arg {
name: "cond_true_arg0"
type: DT_FLOAT
}
input_arg {
name: "cond_true_arg1"
type: DT_FLOAT
}
output_arg {
name: "cond_true_ret"
type: DT_FLOAT
}
}
node_def {
name: "tf.Sub"
op: "Sub"
input: "cond_true_arg0"
input: "cond_true_arg1"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
original_node_names: "tf.Sub"
}
}
ret {
key: "cond_true_ret"
value: "tf.Sub:z:0"
}
}
function {
signature {
name: "cond_false_1"
input_arg {
name: "cond_false_arg0"
type: DT_FLOAT
}
input_arg {
name: "cond_false_arg1"
type: DT_FLOAT
}
output_arg {
name: "cond_false_ret"
type: DT_FLOAT
}
}
node_def {
name: "tf.Div"
op: "Div"
input: "cond_false_arg0"
input: "cond_false_arg1"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
original_node_names: "tf.Div"
}
}
ret {
key: "cond_false_ret"
value: "tf.Div:z:0"
}
}
}
versions {
producer: 115
min_consumer: 12
}
# CHECK: func @StatefulIf_else
# CHECK-NEXT: constant dense<[5.000000e+00, 6.000000e+00, 7.000000e+00, 8.000000e+00]>
# CHECK-NEXT: tfl.mul
# CHECK: func @StatefulIf_then
# CHECK-NEXT: constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]>
# CHECK-NEXT: return
# CHECK: func @StatelessIf_else
# CHECK-NEXT: tfl.div
# CHECK: func @StatelessIf_then
# CHECK-NEXT: tfl.sub
# CHECK: "tf.If"{{.+}}else_branch = @StatelessIf_else{{.+}}then_branch = @StatelessIf_then
# CHECK: "tf.If"{{.+}}else_branch = @StatefulIf_else{{.+}}then_branch = @StatefulIf_then

View File

@ -85,6 +85,8 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
pass_config.quant_specs.serialized_quant_stats));
}
pass_manager->addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());
// The conversion pipeline has to follow this order:
// 1) Saved model related optimization like decompose resource ops
// 2) Convert composite functions like lstm/rnns, along with proper function
@ -128,6 +130,9 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
// Add a shape inference pass to optimize away the unnecessary casts.
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
}
pass_manager->addPass(mlir::TF::CreateTFRegionControlFlowToFunctional());
// Legalize while early to allow further constant folding.
// TODO(jpienaar): This may not actually matter as we do canonicalization
// after the legalize below, for now it needs to be below the above passes
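// In effect, the TF-level portion of the pipeline is now bracketed by the two
// new passes (a condensed view, not the verbatim pass list):
//
//   pass_manager->addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());
//   // ... TF-level passes (inlining, shape inference, ...) on region form ...
//   pass_manager->addPass(mlir::TF::CreateTFRegionControlFlowToFunctional());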

View File

@ -54,7 +54,6 @@ class WhileOutlinePass
tensorflow::OpOrArgLocNameMapper mapper_;
};
} // namespace
std::string WhileOutlinePass::GetName(Operation* op, StringRef suffix) {
return (mapper_.GetUniqueName(op) + suffix).str();
@ -62,7 +61,7 @@ std::string WhileOutlinePass::GetName(Operation* op, StringRef suffix) {
// Returns whether the WhileOp is already outlined (e.g., only consists of calls
// to functions).
static bool IsAlreadyOutlinedd(WhileOp while_op) {
bool IsAlreadyOutlined(WhileOp while_op) {
auto just_call = [](Region& region) {
auto it = region.front().begin();
if (!isa<CallOp>(*it)) return false;
@ -120,7 +119,7 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
}
// Skip if already just calls.
if (extra_operands.empty() && IsAlreadyOutlinedd(while_op)) return;
if (extra_operands.empty() && IsAlreadyOutlined(while_op)) return;
// Collect new types.
SmallVector<Type, 4> types;
@ -238,6 +237,7 @@ void WhileOutlinePass::runOnOperation() {
getOperation().walk(
[&](mlir::TFL::WhileOp while_op) { OutlineWhile(while_op); });
}
} // namespace
// Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass() {

View File

@ -418,6 +418,7 @@ cc_library(
"transforms/fold_switch.cc",
"transforms/freeze_global_tensors.cc",
"transforms/functional_control_flow_to_cfg.cc",
"transforms/functional_control_flow_to_regions.cc",
"transforms/generated_canonicalize.inc",
"transforms/generated_optimize.inc",
"transforms/gpu_fusion.cc",
@ -432,6 +433,7 @@ cc_library(
"transforms/promote_resources_to_args.cc",
"transforms/raise_control_flow.cc",
"transforms/readonly_references_to_resources.cc",
"transforms/region_control_flow_to_functional.cc",
"transforms/replicate_invariant_op_hoisting.cc",
"transforms/replicate_to_island.cc",
"transforms/resource_device_inference.cc",
@ -490,6 +492,7 @@ cc_library(
":translate_utils",
":unroll_batch_matmul_pass",
":xla_sharding_util",
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
"//tensorflow/compiler/mlir/lite:validators",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla:xla_proto_cc",

View File

@ -4011,6 +4011,15 @@ struct TFInlinerInterface : public DialectInlinerInterface {
// Analysis Hooks
//===--------------------------------------------------------------------===//
// Defines the legality of inlining the 'src' region into the 'dest' region
// attached to a TF operation.
bool isLegalToInline(Region *dest, Region *src,
BlockAndValueMapping &valueMapping) const final {
// Allow inlining in regions attached to region-based control flow
// operations only if the src region is a single-block region.
return isa<IfRegionOp>(dest->getParentOp()) && src->getBlocks().size() == 1;
}
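// For example (a condensed sketch based on this change's end-to-end test;
// @then_fn and its tf.Abs body are illustrative), this hook allows the
// inliner to rewrite
//
//   %0 = "tf.IfRegion"(%cond) ({
//     %1 = call @then_fn(%arg) : (tensor<*xf32>) -> tensor<*xf32>
//     "tf.Yield"(%1) : (tensor<*xf32>) -> ()
//   }, ...
//
// into the single-block body of @then_fn spliced directly into the region:
//
//   %0 = "tf.IfRegion"(%cond) ({
//     %1 = "tf.Abs"(%arg) : (tensor<*xf32>) -> tensor<*xf32>
//     "tf.Yield"(%1) : (tensor<*xf32>) -> ()
//   }, ...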
// Defines the legality of inlining TF operations.
bool isLegalToInline(Operation *, Region *,
BlockAndValueMapping &) const final {

View File

@ -0,0 +1,48 @@
// RUN: tf-opt %s -tf-functional-control-flow-to-regions -split-input-file | FileCheck %s --dump-input=fail
// CHECK: func @testIf1Then{{.+}} {sym_visibility = "private"}
// CHECK: func @testIf1Else{{.+}} {sym_visibility = "private"}
func @testIf1Then(tensor<*xf32>) -> tensor<*xf32>
func @testIf1Else(tensor<*xf32>) -> tensor<*xf32>
// CHECK-LABEL: func @testIf1Result(%arg0: tensor<i1>, %arg1: tensor<*xf32>)
func @testIf1Result(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = "tf.If"(%arg0, %arg1) {
then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false
} : (tensor<i1>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: "tf.IfRegion"
// CHECK: [[Result0:%.*]] = call @testIf1Then
// CHECK: "tf.Yield"([[Result0]])
// CHECK: [[Result1:%.*]] = call @testIf1Else
// CHECK: "tf.Yield"([[Result1]])
return %0 : tensor<*xf32>
}
// -----
// With mismatching input types
// CHECK: func @testIf1Then{{.+}} {sym_visibility = "private"}
// CHECK: func @testIf1Else{{.+}} {sym_visibility = "private"}
func @testIf1Then(tensor<*xf32>) -> tensor<*xf32>
func @testIf1Else(tensor<*xf32>) -> tensor<*xf32>
// CHECK-LABEL: func @testIf2Result(%arg0: tensor<i1>, %arg1: tensor<2xf32>)
func @testIf2Result(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%0 = "tf.If"(%arg0, %arg1) {
then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false
} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
// CHECK: "tf.IfRegion"
// CHECK: "tf.Cast"
// CHECK: [[Result0:%.*]] = call @testIf1Then
// CHECK: "tf.Yield"([[Result0]])
// CHECK: "tf.Cast"
// CHECK: [[Result1:%.*]] = call @testIf1Else
// CHECK: "tf.Yield"([[Result1]])
return %0 : tensor<2xf32>
}

View File

@ -0,0 +1,142 @@
// RUN: tf-opt %s -tf-region-control-flow-to-functional -split-input-file | FileCheck %s --dump-input=fail
// CHECK: func @tf.IfRegion_else(%arg0: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"}
// CHECK-NEXT: "tf.Neg"
// CHECK: func @tf.IfRegion_then(%arg0: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"}
// CHECK-NEXT: "tf.Abs"
func @testSimple(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: "tf.If"{{.+}}else_branch = @tf.IfRegion_else{{.+}}then_branch = @tf.IfRegion_then
%0 = "tf.IfRegion"(%arg0) ({
%1 = "tf.Abs"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
"tf.Yield"(%1) : (tensor<*xf32>) -> ()
}, {
%2 = "tf.Neg"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
"tf.Yield"(%2) : (tensor<*xf32>) -> ()
}) { is_stateless = true } : (tensor<i1>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
// Use if condition inside the regions
// CHECK: func @tf.IfRegion_else(%arg0: tensor<i1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>) -> tensor<2xf32> attributes {sym_visibility = "private"}
// CHECK-NEXT: "tf.Select"(%arg0, %arg2, %arg3)
// CHECK: func @tf.IfRegion_then(%arg0: tensor<i1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>) -> tensor<2xf32> attributes {sym_visibility = "private"}
// CHECK-NEXT: "tf.Select"(%arg0, %arg1, %arg2)
func @testIfCondition(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%0 = "tf.Add"(%arg1, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
%1 = "tf.Mul"(%arg1, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
%2 = "tf.Div"(%arg1, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
// CHECK: "tf.If"{{.+}}else_branch = @tf.IfRegion_else{{.+}}then_branch = @tf.IfRegion_then
%3 = "tf.IfRegion"(%arg0) ({
%4 = "tf.Select"(%arg0, %0, %1) : (tensor<i1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%4) : (tensor<2xf32>) -> ()
}, {
%5 = "tf.Select"(%arg0, %1, %2): (tensor<i1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%5) : (tensor<2xf32>) -> ()
}) { is_stateless = true} : (tensor<i1>) -> tensor<2xf32>
return %3 : tensor<2xf32>
}
// -----
// Constant sinking
// CHECK: func @tf.IfRegion_else() -> tensor<2xf32>
// CHECK-NEXT: constant dense<1.0
// CHECK: func @tf.IfRegion_then() -> tensor<2xf32>
// CHECK-NEXT: constant dense<0.0
func @testIfConstant(%arg0: tensor<i1>) -> tensor<2xf32> {
%cst_zero = constant dense<0.0> : tensor<2xf32>
// CHECK: "tf.If"(%arg0) {else_branch = @tf.IfRegion_else{{.+}}then_branch = @tf.IfRegion_then
%0 = "tf.IfRegion"(%arg0) ({
"tf.Yield"(%cst_zero) : (tensor<2xf32>) -> ()
}, {
%cst_one = constant dense<1.0> : tensor<2xf32>
"tf.Yield"(%cst_one) : (tensor<2xf32>) -> ()
}) { is_stateless = true} : (tensor<i1>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// -----
// Nested IfRegions
// CHECK: func @tf.IfRegion1_else
// CHECK-NEXT: "tf.Acos"
// CHECK-NEXT: "tf.Abs"
// CHECK: func @tf.IfRegion1_then
// CHECK-NEXT: "tf.LogicalNot"
// CHECK-NEXT: "tf.Asin"
// CHECK-NEXT: "tf.If"({{.+}}) {else_branch = @tf.IfRegion_else, {{.+}} then_branch = @tf.IfRegion_then}
// CHECK: func @tf.IfRegion_else
// CHECK-NEXT: "tf.Neg"
// CHECK: func @tf.IfRegion_then
// CHECK-NEXT: "tf.Abs"
func @testNested(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: "tf.If"({{.+}}) {else_branch = @tf.IfRegion1_else, {{.+}} then_branch = @tf.IfRegion1_then}
%0 = "tf.IfRegion"(%arg0) ({
// Outer Then
%cond = "tf.LogicalNot"(%arg0) : (tensor<i1>) -> tensor<i1>
%asin = "tf.Asin"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
// nested IfRegion
%1 = "tf.IfRegion"(%cond) ({
%2 = "tf.Abs"(%asin) : (tensor<*xf32>) -> tensor<*xf32>
"tf.Yield"(%2) : (tensor<*xf32>) -> ()
}, {
%2 = "tf.Neg"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
"tf.Yield"(%2) : (tensor<*xf32>) -> ()
}) { is_stateless = true } : (tensor<i1>) -> tensor<*xf32>
"tf.Yield"(%1) : (tensor<*xf32>) -> ()
}, {
// Outer Else
%acos = "tf.Acos"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
%3 = "tf.Abs"(%acos) : (tensor<*xf32>) -> tensor<*xf32>
"tf.Yield"(%3) : (tensor<*xf32>) -> ()
}) { is_stateless = true } : (tensor<i1>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
// Match existing function->Region pattern (simple)
func @testIf1Then(tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"}
func @testIf1Else(tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"}
func @testIf1Result(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: "tf.If"({{.+}}) {else_branch = @testIf1Else, {{.+}} then_branch = @testIf1Then}
%0 = "tf.IfRegion"(%arg0) ( {
%1 = call @testIf1Then(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
"tf.Yield"(%1) : (tensor<*xf32>) -> ()
}, {
%1 = call @testIf1Else(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
"tf.Yield"(%1) : (tensor<*xf32>) -> ()
}) {is_stateless = false} : (tensor<i1>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
// Match existing function->Region pattern (with casts)
func @testIf1Then(tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"}
func @testIf1Else(tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"}
func @testIf2Result(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: "tf.If"({{.+}}) {else_branch = @testIf1Else, {{.+}} then_branch = @testIf1Then}
%0 = "tf.IfRegion"(%arg0) ( {
%1 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xf32>) -> tensor<*xf32>
%2 = call @testIf1Then(%1) : (tensor<*xf32>) -> tensor<*xf32>
"tf.Yield"(%2) : (tensor<*xf32>) -> ()
}, {
%1 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xf32>) -> tensor<*xf32>
%2 = call @testIf1Else(%1) : (tensor<*xf32>) -> tensor<*xf32>
"tf.Yield"(%2) : (tensor<*xf32>) -> ()
}) {is_stateless = false} : (tensor<i1>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

View File

@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
// This transformation pass transforms functional control flow operations in the
// standard TensorFlow dialect to MLIR Control Flow Graph (CFG) form.
// TensorFlow dialect to MLIR Control Flow Graph (CFG) form.
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
@ -52,7 +52,6 @@ static Value LowerCondition(Location loc, Value value, OpBuilder* builder) {
//
// Requires the function to provide arguments for each of the `fn` operands
// that is compatible for tensor cast.
//
static Operation* CallFn(Location loc, const std::function<Value(int)>& get_arg,
FuncOp fn, OpBuilder* builder) {
FunctionType fn_type = fn.getType();
@ -113,7 +112,6 @@ static void JumpToBlock(Location loc, const std::function<Value(int)>& get_arg,
// Requires that the block has same number of arguments as number of results of
// the operation and either they have same types or are more generic types and
// it is possible to cast them to results' types.
//
static void ReplaceOpResultWithBlockArgs(Location loc, Operation* op,
Block* block, OpBuilder* builder) {
assert(op->getNumResults() == block->getNumArguments());
@ -132,9 +130,6 @@ static void ReplaceOpResultWithBlockArgs(Location loc, Operation* op,
// Given a functional IfOp, transforms the enclosing code to eliminate it
// completely from the IR, breaking it into operations to evaluate the condition
// as a bool, plus some branches.
//
// This returns true on failure.
//
static LogicalResult LowerIfOp(IfOp op) {
Operation* op_inst = op.getOperation();
Location loc = op_inst->getLoc();
@ -193,9 +188,6 @@ static LogicalResult LowerIfOp(IfOp op) {
// Given a functional WhileOp, transforms the enclosing code to eliminate it
// completely from the IR, breaking it into operations to execute the loop body
// repeatedly while the loop condition is true.
//
// This returns true on failure.
//
static LogicalResult LowerWhileOp(WhileOp op) {
Operation* op_inst = op.getOperation();
Location loc = op_inst->getLoc();

View File

@ -0,0 +1,118 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass transforms functional control flow operations in the
// TensorFlow dialect to their region-based counterparts, i.e.,
// tf.If -> tf.IfRegion
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
namespace mlir {
namespace TF {
namespace {
struct FunctionalControlFlowToRegions
: public PassWrapper<FunctionalControlFlowToRegions, FunctionPass> {
void runOnFunction() override;
};
// Create a call to function `fn` with arguments `args` and return the CallOp.
// The arguments are cast to the required type before the call.
CallOp CreateCall(Location loc, Operation::operand_range args, FuncOp fn,
OpBuilder* builder) {
FunctionType fn_type = fn.getType();
llvm::SmallVector<Value, 4> operands;
int num_operands = fn_type.getNumInputs();
operands.reserve(num_operands);
for (int i = 0; i < num_operands; ++i) {
Value arg = args[i];
Type expected = fn_type.getInput(i);
if (arg.getType() != expected) {
arg = builder->create<CastOp>(loc, expected, arg,
/*Truncate=*/builder->getBoolAttr(false));
}
operands.push_back(arg);
}
return builder->create<CallOp>(loc, fn, operands);
}
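// For example (a sketch; @fn is an illustrative name), calling
// func @fn(tensor<*xf32>) -> tensor<*xf32> with a tensor<2xf32> argument
// produces:
//
//   %0 = "tf.Cast"(%arg) {Truncate = false} : (tensor<2xf32>) -> tensor<*xf32>
//   %1 = call @fn(%0) : (tensor<*xf32>) -> tensor<*xf32>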
// Transform a functional IfOp to a region-based IfRegionOp.
LogicalResult ConvertIfOp(IfOp if_op) {
auto if_region = OpBuilder(if_op).create<TF::IfRegionOp>(
if_op.getLoc(), if_op.getResultTypes(), if_op.cond(),
if_op.is_stateless());
// Insert call to the given function into the 'region'.
auto create_region_with_call = [&if_op](FlatSymbolRefAttr symbol,
Region& region) {
OpBuilder builder(region);
builder.createBlock(&region);
auto func = if_op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
symbol.getValue());
auto call = CreateCall(if_op.getLoc(), if_op.input(), func, &builder);
builder.create<YieldOp>(if_op.getLoc(), call.getResults());
// Mark old function as private so that it can be DCE'd if not called.
func.setVisibility(SymbolTable::Visibility::Private);
};
create_region_with_call(if_op.then_branchAttr(), if_region.then_branch());
create_region_with_call(if_op.else_branchAttr(), if_region.else_branch());
if_op.replaceAllUsesWith(if_region.getResults());
if_op.erase();
return success();
}
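// For example (condensed from the accompanying test; attributes abbreviated):
//
//   %0 = "tf.If"(%cond, %arg) {then_branch = @then_fn,
//       else_branch = @else_fn, is_stateless = false} ...
//
// becomes
//
//   %0 = "tf.IfRegion"(%cond) ({
//     %1 = call @then_fn(%arg) : ...
//     "tf.Yield"(%1) : ...
//   }, {
//     %1 = call @else_fn(%arg) : ...
//     "tf.Yield"(%1) : ...
//   }) {is_stateless = false} : ...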
void FunctionalControlFlowToRegions::runOnFunction() {
for (Block& block : getFunction()) {
auto result = block.walk([](Operation* op) {
if (IfOp if_op = llvm::dyn_cast<IfOp>(op)) {
if (failed(ConvertIfOp(if_op))) {
if_op.emitOpError() << " failed to convert to region form";
return WalkResult::interrupt();
}
}
return WalkResult::advance();
});
if (result.wasInterrupted()) return signalPassFailure();
}
}
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
CreateTFFunctionalControlFlowToRegions() {
return std::make_unique<FunctionalControlFlowToRegions>();
}
static PassRegistration<FunctionalControlFlowToRegions> pass(
"tf-functional-control-flow-to-regions",
"Transform functional control flow Ops to Region based counterparts");
} // namespace TF
} // namespace mlir

View File

@ -32,10 +32,18 @@ std::unique_ptr<OperationPass<FuncOp>>
CreateFunctionalToExecutorDialectConversionPass();
namespace TF {
// Transforms functional control flow operations in the standard TensorFlow
// dialect to MLIR Control Flow Graph (CFG) form.
// Transforms functional control flow operations in the TensorFlow dialect to
// MLIR Control Flow Graph (CFG) form.
std::unique_ptr<OperationPass<FuncOp>> CreateTFFunctionalControlFlowToCFG();
// Transforms functional control flow operations in the TensorFlow dialect to
// their region based counterparts.
std::unique_ptr<OperationPass<FuncOp>> CreateTFFunctionalControlFlowToRegions();
// Transforms region-based control flow operations in the TensorFlow dialect to
// their functional counterparts.
std::unique_ptr<OperationPass<FuncOp>> CreateTFRegionControlFlowToFunctional();
// Materialize the MlirPassthroughOp by replacing it with the MLIR module
// attached as an attribute.
std::unique_ptr<OperationPass<FuncOp>> CreateMaterializePassthroughOpPass();

View File

@ -0,0 +1,321 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass transforms region-based control flow operations in
// the TensorFlow dialect to their functional counterparts, i.e.,
// tf.IfRegion -> tf.If
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
namespace mlir {
namespace TF {
namespace {
struct RegionControlFlowToFunctional
: public PassWrapper<RegionControlFlowToFunctional, FunctionPass> {
void runOnFunction() override;
private:
LogicalResult ConvertIfOp(IfRegionOp if_region);
// Get unique name by using the loc to name mapping.
std::string GetName(Operation* op, StringRef suffix);
tensorflow::OpOrArgLocNameMapper mapper;
};
std::string RegionControlFlowToFunctional::GetName(Operation* op,
StringRef suffix) {
return (mapper.GetUniqueName(op) + suffix).str();
}
// Given a list of regions, return all the external values referenced from
// these regions. If an external value is a constant, sink it into the region
// instead (and do not add it to the returned vector).
llvm::SmallVector<Value, 4> CollectExternValues(ArrayRef<Region*> regions) {
llvm::SetVector<Value> extern_values_set;
for (auto region : regions) {
llvm::SetVector<Value> region_extern_values;
getUsedValuesDefinedAbove(*region, region_extern_values);
// Sink down constants into the functions.
for (auto extern_value : region_extern_values) {
if (!matchPattern(extern_value, m_Constant())) {
extern_values_set.insert(extern_value);
continue;
}
// Add constant at start of region.
auto const_builder = OpBuilder::atBlockBegin(&region->front());
auto const_value = const_builder.clone(*extern_value.getDefiningOp());
replaceAllUsesInRegionWith(extern_value, const_value->getResult(0),
*region);
}
}
return {extern_values_set.begin(), extern_values_set.end()};
}
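// For example (condensed from the constant sinking test in this change):
//
//   %cst = constant dense<0.0> : tensor<2xf32>
//   %0 = "tf.IfRegion"(%cond) ({
//     "tf.Yield"(%cst) : (tensor<2xf32>) -> ()
//   }, ...
//
// becomes
//
//   %0 = "tf.IfRegion"(%cond) ({
//     %cst = constant dense<0.0> : tensor<2xf32>
//     "tf.Yield"(%cst) : (tensor<2xf32>) -> ()
//   }, ...
//
// and %cst is not added to the returned extern values.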
// Extract the contents of a region with a single block into a new function.
// `extern_values` is the set of external values that the region refers to.
//
// Any inputs to the terminator of the region are converted to return values of
// the function. If any of these values does not exactly match the function's
// return type, an appropriate cast operation is inserted.
void ExtractSingleBlockRegion(Region& region, FunctionType type, StringRef name,
llvm::SmallVector<Value, 4>& extern_values) {
ModuleOp module = region.getParentOfType<ModuleOp>();
auto builder = OpBuilder::atBlockBegin(module.getBody());
auto loc = region.getParentOp()->getLoc();
// Create new function and extract region body into the function
auto outlined_func =
builder.create<FuncOp>(loc, name, type, ArrayRef<NamedAttribute>{});
outlined_func.getBody().takeBody(region);
Region& func_region = outlined_func.getBody();
Block& first_block = func_region.front();
// Replace all external uses with function arguments.
for (auto it : llvm::enumerate(extern_values)) {
Value arg = first_block.addArgument(it.value().getType());
replaceAllUsesInRegionWith(it.value(), arg, func_region);
}
// Replace the existing terminator with a return.
Operation* terminator = outlined_func.getBody().front().getTerminator();
builder.setInsertionPoint(terminator);
SmallVector<Value, 4> return_values;
return_values.reserve(terminator->getNumOperands());
for (auto it : llvm::enumerate(type.getResults())) {
Value ret_val = terminator->getOperand(it.index());
// Add a cast operation if types do not match.
if (ret_val.getType() != it.value()) {
ret_val =
builder.create<CastOp>(terminator->getLoc(), it.value(), ret_val);
}
return_values.push_back(ret_val);
}
builder.create<ReturnOp>(terminator->getLoc(), return_values);
terminator->erase();
outlined_func.setVisibility(FuncOp::Visibility::Private);
}
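// For example (a sketch with one extern value %x : tensor<2xf32>), a region
//
//   { %0 = "tf.Abs"(%x) : (tensor<2xf32>) -> tensor<2xf32>
//     "tf.Yield"(%0) : (tensor<2xf32>) -> () }
//
// is extracted into
//
//   func @name(%arg0: tensor<2xf32>) -> tensor<2xf32>
//       attributes {sym_visibility = "private"} {
//     %0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
//     return %0 : tensor<2xf32>
//   }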
// Returns the call for a region with a single call whose results feed into
// the terminator of the region. Returns None if the region doesn't consist of
// just a call and cast ops.
llvm::Optional<CallOp> IsSingleCallRegion(Region& region) {
if (region.getBlocks().size() != 1) return llvm::None;
auto it = region.front().rbegin();
YieldOp yield = dyn_cast<YieldOp>(*it++);
if (!yield || yield.getNumOperands() == 0) return llvm::None;
if (it == region.front().rend()) return llvm::None;
// The yield should be preceded by a call.
CallOp call = dyn_cast<CallOp>(*it++);
if (!call) return llvm::None;
// Any ops before the call must be cast ops.
for (; it != region.front().rend(); ++it)
if (!isa<CastOp>(*it)) return llvm::None;
// All results of the call should feed into the yield.
if (call.getNumResults() != yield.getNumOperands()) return llvm::None;
for (auto res_it : llvm::zip(call.getResults(), yield.getOperands()))
if (std::get<0>(res_it) != std::get<1>(res_it)) return llvm::None;
return call;
}
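// For example, this region is a single-call region (casts before the call are
// allowed):
//
//   {
//     %0 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xf32>) -> tensor<*xf32>
//     %1 = call @then_fn(%0) : (tensor<*xf32>) -> tensor<*xf32>
//     "tf.Yield"(%1) : (tensor<*xf32>) -> ()
//   }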
// Returns whether the arguments of the given call are the same as the given
// list of arguments (after looking through cast ops).
bool MatchCallArgs(CallOp call, llvm::SmallVectorImpl<Value>& args) {
if (call.getNumOperands() != args.size()) return false;
for (auto it : llvm::enumerate(args)) {
Value arg = call.getOperand(it.index());
if (auto cast = dyn_cast_or_null<CastOp>(arg.getDefiningOp()))
arg = cast.getOperand();
if (arg != it.value()) return false;
}
return true;
}
// Summary information for trivially transforming region-based ops to
// functional ops. A trivial transformation can be done when the regions are
// just calls to functions, in which case no outlining is needed.
struct TrivialTransformInfo {
// Whether the op can be transformed trivially.
bool can_transform = false;
// List of callee names (one for each region).
llvm::SmallVector<StringRef, 4> callee_names;
// List of arguments used in these calls (each call uses the same arguments,
// potentially through casts).
llvm::SmallVector<Value, 4> call_args;
};
// Analyze the given set of regions (attached to the same parent op) to check
// if the parent op can be transformed to functional form trivially (i.e.,
// reusing existing functions and without outlining). This is possible when all
// the regions are single-call regions and all the calls have the same
// arguments.
//
// If this trivial transformation is possible, return the relevant information
// needed for the transformation (in `TrivialTransformInfo`), else indicate that
// a trivial transformation is not possible by setting `can_transform` false.
TrivialTransformInfo AnalyzeForTrivialTransform(ArrayRef<Region*> regions) {
const TrivialTransformInfo cannot_transform;
if (regions.empty()) return cannot_transform;
llvm::SmallVector<CallOp, 2> calls;
calls.reserve(regions.size());
// Verify each region is a single call and collect these calls.
for (Region* region : regions) {
auto call = IsSingleCallRegion(*region);
if (!call.hasValue()) return cannot_transform;
calls.push_back(call.getValue());
}
llvm::SmallVector<StringRef, 4> callees;
callees.reserve(regions.size());
CallOp call0 = calls[0];
int num_args = call0.getNumOperands();
callees.push_back(call0.getCallee());
// Collect call0's arguments.
llvm::SmallVector<Value, 4> call0_args;
call0_args.reserve(num_args);
for (int arg_idx = 0; arg_idx < num_args; ++arg_idx) {
Value arg = call0.getOperand(arg_idx);
if (auto cast = dyn_cast_or_null<CastOp>(arg.getDefiningOp()))
arg = cast.getOperand();
call0_args.push_back(arg);
}
// Match the other calls' arguments against call0's arguments.
for (int idx = 1; idx < calls.size(); ++idx) {
if (!MatchCallArgs(calls[idx], call0_args)) return cannot_transform;
callees.push_back(calls[idx].getCallee());
}
return {true, callees, call0_args};
}
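// For example (condensed from the tests in this change; names illustrative),
// an IfRegion whose two regions are
//
//   ({ %1 = call @then_fn(%x); yield %1 }, { %1 = call @else_fn(%x); yield %1 })
//
// is trivially transformable: callee_names = [@then_fn, @else_fn] and
// call_args = [%x], so the existing functions are reused without outlining.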
// Transform IfRegionOp to IfOp
LogicalResult RegionControlFlowToFunctional::ConvertIfOp(IfRegionOp if_region) {
const TrivialTransformInfo tti = AnalyzeForTrivialTransform(
{&if_region.then_branch(), &if_region.else_branch()});
std::string then_name, else_name;
llvm::SmallVector<Value, 4> extern_values;
if (tti.can_transform) {
// We can transform to functional form trivially without outlining
then_name = tti.callee_names[0].str();
else_name = tti.callee_names[1].str();
extern_values = tti.call_args;
} else {
// Collect external values that are used within the else and then bodies.
extern_values = CollectExternValues(
{&if_region.then_branch(), &if_region.else_branch()});
// These external values need to be added as inputs to the generated If. The
// order is determined by the order of these values in `extern_values`.
// Build the type for the outlined function
llvm::SmallVector<Type, 4> input_types;
input_types.reserve(extern_values.size());
for (auto input : extern_values) input_types.push_back(input.getType());
FunctionType func_type = FunctionType::get(
input_types, if_region.getResultTypes(), if_region.getContext());
// Create 2 new functions with the input signature matching this order,
// and outline the `then` and `else` regions by moving the bodies of these
// regions into these functions. Replace tf.yield with a regular return.
then_name = GetName(if_region, "_then");
ExtractSingleBlockRegion(if_region.then_branch(), func_type, then_name,
extern_values);
else_name = GetName(if_region, "_else");
ExtractSingleBlockRegion(if_region.else_branch(), func_type, else_name,
extern_values);
}
// Once we have the `then` and `else` functions ready (either outlined or
// existing ones), replace the region-based op with a functional control flow
// op.
OpBuilder builder(if_region);
auto if_op = builder.create<IfOp>(
if_region.getLoc(), if_region.getResultTypes(), if_region.cond(),
extern_values, then_name, else_name, if_region.is_stateless());
if_region.replaceAllUsesWith(if_op.getResults());
if_region.erase();
return success();
}
void RegionControlFlowToFunctional::runOnFunction() {
for (Block& block : getFunction()) {
auto result = block.walk([&](Operation* op) {
if (IfRegionOp if_region = llvm::dyn_cast<IfRegionOp>(op)) {
if (failed(ConvertIfOp(if_region))) {
if_region.emitOpError() << " failed to convert to functional form";
return WalkResult::interrupt();
}
}
return WalkResult::advance();
});
if (result.wasInterrupted()) return signalPassFailure();
}
}
} // namespace
std::unique_ptr<OperationPass<FuncOp>> CreateTFRegionControlFlowToFunctional() {
return std::make_unique<RegionControlFlowToFunctional>();
}
static PassRegistration<RegionControlFlowToFunctional> pass(
"tf-region-control-flow-to-functional",
"Transform region bases control flow Ops to functional counterparts");
} // namespace TF
} // namespace mlir