Add support for tf.CaseRegion in LegalizeTFControlFlow pass.
This adds a legalization for tf.CaseRegion -> mhlo.case with special handling for implicitly captured/used inputs. PiperOrigin-RevId: 333292773 Change-Id: I963de7446d80f10acd437caa39f040f81438dda4
This commit is contained in:
parent
4509a4b7c9
commit
77c5c05aef
tensorflow/compiler/mlir/xla
@ -117,6 +117,43 @@ func @floor(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: func @caseRegion
|
||||
// CHECK-SAME: ([[BRANCH_INDEX:%.+]]: tensor<i32>, [[ARG0:.+]]: tensor<f32>, [[ARG1:%.+]]: tensor<f32>)
|
||||
func @caseRegion(%index: tensor<i32>, %arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
|
||||
// CHECK: [[VAL0:%.+]] = "mhlo.tuple"([[ARG1]])
|
||||
// CHECK: [[VAL1:%.+]] = "mhlo.tuple"([[ARG0]], [[ARG1]])
|
||||
// CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[ARG0]], [[ARG1]])
|
||||
// CHECK: [[VAL3:%.+]]:2 = "mhlo.case"([[BRANCH_INDEX]], [[VAL0]], [[VAL1]], [[VAL2]]) ( {
|
||||
%0:2 = "tf.CaseRegion"(%index) ( {
|
||||
// CHECK: ^{{[a-z0-9]+}}([[BRANCH0_ARG:%.+]]: tuple<tensor<f32>>):
|
||||
// CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[BRANCH0_ARG]]) {index = 0 : i32}
|
||||
// CHECK: [[VAL5:%.+]] = "mhlo.exponential"([[VAL4]])
|
||||
%1 = "mhlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
|
||||
// CHECK: "mhlo.return"([[VAL5]], [[VAL4]])
|
||||
"tf.Yield"(%1, %arg1) : (tensor<f32>, tensor<f32>) -> ()
|
||||
}, {
|
||||
// CHECK: ^{{[a-z0-9]+}}([[BRANCH1_ARG:%.+]]: tuple<tensor<f32>, tensor<f32>>):
|
||||
// CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[BRANCH1_ARG]]) {index = 0 : i32}
|
||||
// CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[BRANCH1_ARG]]) {index = 1 : i32}
|
||||
// CHECK: [[VAL6:%.+]] = "mhlo.log"([[VAL4]])
|
||||
%1 = "mhlo.log"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||
// CHECK: "mhlo.return"([[VAL6]], [[VAL5]])
|
||||
"tf.Yield"(%1, %arg1) : (tensor<f32>, tensor<f32>) -> ()
|
||||
}, {
|
||||
// CHECK: ^{{[a-z0-9]+}}([[BRANCH2_ARG:%.+]]: tuple<tensor<f32>, tensor<f32>>):
|
||||
// CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[BRANCH2_ARG]]) {index = 0 : i32}
|
||||
// CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[BRANCH2_ARG]]) {index = 1 : i32}
|
||||
// CHECK: [[VAL6:%.+]] = "mhlo.floor"([[VAL4]])
|
||||
%1 = "mhlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||
// CHECK: "mhlo.return"([[VAL6]], [[VAL5]])
|
||||
"tf.Yield"(%1, %arg1) : (tensor<f32>, tensor<f32>) -> ()
|
||||
// CHECK: }) : (tensor<i32>, tuple<tensor<f32>>, tuple<tensor<f32>, tensor<f32>>, tuple<tensor<f32>, tensor<f32>>) -> (tensor<f32>, tensor<f32>)
|
||||
}) {is_stateless = true} : (tensor<i32>) -> (tensor<f32>, tensor<f32>)
|
||||
// CHECK: return [[VAL3]]#0, [[VAL3]]#1 : tensor<f32>, tensor<f32>
|
||||
return %0#0, %0#1 : tensor<f32>, tensor<f32>
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: func @while
|
||||
func @while() -> tensor<i32> {
|
||||
// CHECK: [[VAL0:%.+]] = mhlo.constant dense<0>
|
||||
|
@ -20,33 +20,24 @@ limitations under the License.
|
||||
#include <cstdint>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include <tuple>
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "llvm/ADT/iterator_range.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/OperationSupport.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
||||
#include "mlir/IR/Types.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
|
||||
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
|
||||
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
|
||||
using mlir::PassRegistration;
|
||||
|
||||
@ -290,6 +281,34 @@ void LowerIfRegion(TF::IfRegionOp op) {
|
||||
op.erase();
|
||||
}
|
||||
|
||||
void LowerCaseRegion(TF::CaseRegionOp op) {
|
||||
Location loc = op.getLoc();
|
||||
OpBuilder builder(op);
|
||||
|
||||
llvm::SmallVector<Value, 4> branch_inputs;
|
||||
branch_inputs.reserve(op.branches().size());
|
||||
// Tuple implicit inputs per region and update terminators.
|
||||
for (Region& region : op.branches()) {
|
||||
builder.setInsertionPoint(op);
|
||||
Value branch_input = TupleImplicitInputs(region, loc, &builder);
|
||||
branch_inputs.emplace_back(branch_input);
|
||||
ReplaceTerminator(®ion.front(), /*extra_results=*/{}, &builder,
|
||||
/*tuple_return=*/false);
|
||||
}
|
||||
|
||||
// Create the new `mhlo.case` op with tuple inputs and take ownership of
|
||||
// regions from `tf.CaseRegion` op.
|
||||
builder.setInsertionPoint(op);
|
||||
auto case_op =
|
||||
builder.create<mhlo::CaseOp>(loc, op.getResultTypes(), op.branch_index(),
|
||||
branch_inputs, branch_inputs.size());
|
||||
for (auto region : llvm::zip(case_op.branches(), op.branches()))
|
||||
std::get<0>(region).takeBody(std::get<1>(region));
|
||||
|
||||
op.replaceAllUsesWith(case_op.getResults());
|
||||
op.erase();
|
||||
}
|
||||
|
||||
void LowerWhileRegion(TF::WhileRegionOp op) {
|
||||
Location loc = op.getLoc();
|
||||
OpBuilder builder(op);
|
||||
@ -370,6 +389,10 @@ void LegalizeTFControlFlow::runOnOperation() {
|
||||
LowerCase(case_op);
|
||||
return;
|
||||
}
|
||||
if (auto case_region_op = dyn_cast<TF::CaseRegionOp>(op)) {
|
||||
LowerCaseRegion(case_region_op);
|
||||
return;
|
||||
}
|
||||
});
|
||||
}
|
||||
} // namespace mhlo
|
||||
|
Loading…
Reference in New Issue
Block a user