Add support for tf.CaseRegion in LegalizeTFControlFlow pass.

This adds a legalization for tf.CaseRegion -> mhlo.case with special handling for implicitly captured/used inputs.

PiperOrigin-RevId: 333292773
Change-Id: I963de7446d80f10acd437caa39f040f81438dda4
This commit is contained in:
Andy Ly 2020-09-23 08:03:04 -07:00 committed by TensorFlower Gardener
parent 4509a4b7c9
commit 77c5c05aef
2 changed files with 70 additions and 10 deletions
tensorflow/compiler/mlir/xla

View File

@ -117,6 +117,43 @@ func @floor(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>
}
// CHECK-LABEL: func @caseRegion
// CHECK-SAME: ([[BRANCH_INDEX:%.+]]: tensor<i32>, [[ARG0:.+]]: tensor<f32>, [[ARG1:%.+]]: tensor<f32>)
func @caseRegion(%index: tensor<i32>, %arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
// CHECK: [[VAL0:%.+]] = "mhlo.tuple"([[ARG1]])
// CHECK: [[VAL1:%.+]] = "mhlo.tuple"([[ARG0]], [[ARG1]])
// CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[ARG0]], [[ARG1]])
// CHECK: [[VAL3:%.+]]:2 = "mhlo.case"([[BRANCH_INDEX]], [[VAL0]], [[VAL1]], [[VAL2]]) ( {
%0:2 = "tf.CaseRegion"(%index) ( {
// CHECK: ^{{[a-z0-9]+}}([[BRANCH0_ARG:%.+]]: tuple<tensor<f32>>):
// CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[BRANCH0_ARG]]) {index = 0 : i32}
// CHECK: [[VAL5:%.+]] = "mhlo.exponential"([[VAL4]])
%1 = "mhlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
// CHECK: "mhlo.return"([[VAL5]], [[VAL4]])
"tf.Yield"(%1, %arg1) : (tensor<f32>, tensor<f32>) -> ()
}, {
// CHECK: ^{{[a-z0-9]+}}([[BRANCH1_ARG:%.+]]: tuple<tensor<f32>, tensor<f32>>):
// CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[BRANCH1_ARG]]) {index = 0 : i32}
// CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[BRANCH1_ARG]]) {index = 1 : i32}
// CHECK: [[VAL6:%.+]] = "mhlo.log"([[VAL4]])
%1 = "mhlo.log"(%arg0) : (tensor<f32>) -> tensor<f32>
// CHECK: "mhlo.return"([[VAL6]], [[VAL5]])
"tf.Yield"(%1, %arg1) : (tensor<f32>, tensor<f32>) -> ()
}, {
// CHECK: ^{{[a-z0-9]+}}([[BRANCH2_ARG:%.+]]: tuple<tensor<f32>, tensor<f32>>):
// CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[BRANCH2_ARG]]) {index = 0 : i32}
// CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[BRANCH2_ARG]]) {index = 1 : i32}
// CHECK: [[VAL6:%.+]] = "mhlo.floor"([[VAL4]])
%1 = "mhlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
// CHECK: "mhlo.return"([[VAL6]], [[VAL5]])
"tf.Yield"(%1, %arg1) : (tensor<f32>, tensor<f32>) -> ()
// CHECK: }) : (tensor<i32>, tuple<tensor<f32>>, tuple<tensor<f32>, tensor<f32>>, tuple<tensor<f32>, tensor<f32>>) -> (tensor<f32>, tensor<f32>)
}) {is_stateless = true} : (tensor<i32>) -> (tensor<f32>, tensor<f32>)
// CHECK: return [[VAL3]]#0, [[VAL3]]#1 : tensor<f32>, tensor<f32>
return %0#0, %0#1 : tensor<f32>, tensor<f32>
}
// CHECK-LABEL: func @while
func @while() -> tensor<i32> {
// CHECK: [[VAL0:%.+]] = mhlo.constant dense<0>

View File

@ -20,33 +20,24 @@ limitations under the License.
#include <cstdint>
#include <iterator>
#include <numeric>
#include <tuple>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/iterator_range.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/util/tensor_format.h"
using mlir::PassRegistration;
@ -290,6 +281,34 @@ void LowerIfRegion(TF::IfRegionOp op) {
op.erase();
}
void LowerCaseRegion(TF::CaseRegionOp op) {
Location loc = op.getLoc();
OpBuilder builder(op);
llvm::SmallVector<Value, 4> branch_inputs;
branch_inputs.reserve(op.branches().size());
// Tuple implicit inputs per region and update terminators.
for (Region& region : op.branches()) {
builder.setInsertionPoint(op);
Value branch_input = TupleImplicitInputs(region, loc, &builder);
branch_inputs.emplace_back(branch_input);
ReplaceTerminator(&region.front(), /*extra_results=*/{}, &builder,
/*tuple_return=*/false);
}
// Create the new `mhlo.case` op with tuple inputs and take ownership of
// regions from `tf.CaseRegion` op.
builder.setInsertionPoint(op);
auto case_op =
builder.create<mhlo::CaseOp>(loc, op.getResultTypes(), op.branch_index(),
branch_inputs, branch_inputs.size());
for (auto region : llvm::zip(case_op.branches(), op.branches()))
std::get<0>(region).takeBody(std::get<1>(region));
op.replaceAllUsesWith(case_op.getResults());
op.erase();
}
void LowerWhileRegion(TF::WhileRegionOp op) {
Location loc = op.getLoc();
OpBuilder builder(op);
@ -370,6 +389,10 @@ void LegalizeTFControlFlow::runOnOperation() {
LowerCase(case_op);
return;
}
if (auto case_region_op = dyn_cast<TF::CaseRegionOp>(op)) {
LowerCaseRegion(case_region_op);
return;
}
});
}
} // namespace mhlo