Added tf.If and xla_hlo.Conditional with appropriate lowerings
This includes:
- lowering tf.If to xla_hlo.Conditional
- importing xla_hlo via the HloFunctionImporter
- lowering from xla_hlo.Conditional to MLIR control flow operations

PiperOrigin-RevId: 280553558
Change-Id: I3d86c981176d7a97f1ae1a02ea292e8da562abae
This commit is contained in: parent 2dc5684649, commit 3c26f1a410
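For orientation, a minimal sketch of what the tf.If lowering produces, condensed from the new legalize-tf-control-flow test in this change (attribute lists are trimmed and SSA value names are illustrative): operands are packed into a tuple, the referenced branch functions are inlined as regions terminated by "xla_hlo.return", and the result is unpacked with "xla_hlo.get_tuple_element".

// Before: functional control flow coming from TensorFlow (attributes trimmed).
%1 = "tf.If"(%0, %arg0, %arg1) {then_branch = @cond_true, else_branch = @cond_false, is_stateless = true}
    : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>

// After -xla-legalize-tf-control-flow (and before -xla-legalize-control-flow
// rewrites the regions into blocks and branches):
%t = "xla_hlo.tuple"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>>
%2 = "xla_hlo.conditional"(%0, %t, %t) ( {
^bb0(%a: tensor<f32>, %b: tensor<f32>):
  %log = "xla_hlo.log"(%a) : (tensor<f32>) -> tensor<f32>
  %lt = "xla_hlo.tuple"(%log) : (tensor<f32>) -> tuple<tensor<f32>>
  "xla_hlo.return"(%lt) : (tuple<tensor<f32>>) -> ()
}, {
^bb0(%a: tensor<f32>, %b: tensor<f32>):
  %exp = "xla_hlo.exp"(%b) : (tensor<f32>) -> tensor<f32>
  %et = "xla_hlo.tuple"(%exp) : (tensor<f32>) -> tuple<tensor<f32>>
  "xla_hlo.return"(%et) : (tuple<tensor<f32>>) -> ()
}) : (tensor<i1>, tuple<tensor<f32>, tensor<f32>>, tuple<tensor<f32>, tensor<f32>>) -> tuple<tensor<f32>>
%3 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>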
@@ -102,6 +102,7 @@ cc_library(
    srcs = [
        "transforms/generated_legalize_tf.inc",
        "transforms/legalize_tf.cc",
        "transforms/legalize_tf_control_flow.cc",
    ],
    deps = [
        ":convert_op_folder",
@@ -210,6 +211,7 @@ cc_library(
        "@local_config_mlir//:IR",
        "@local_config_mlir//:Pass",
        "@local_config_mlir//:StandardOps",
        "@local_config_mlir//:Support",
    ],
    alwayslink = 1,
)
@@ -331,6 +331,19 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
              ConvertDimensions(instruction->slice_strides()))
          .getOperation();
    }
    case HloOpcode::kConditional: {
      llvm::SmallVector<Type, 4> rets;
      TF_RETURN_IF_ERROR(GetMlirTypes(
          {instruction->true_computation()->root_instruction()}, &rets));

      auto op = func_builder->create<mlir::xla_hlo::ConditionalOp>(
          loc, rets, operands, attributes);
      TF_RETURN_IF_ERROR(ImportComputation(instruction->true_computation(),
                                           &op.true_branch()));
      TF_RETURN_IF_ERROR(ImportComputation(instruction->false_computation(),
                                           &op.false_branch()));
      return op.getOperation();
    }
    case HloOpcode::kConcatenate: {
      // TODO(b/132057942): Support taking an uint64_t instead of an IntegerAttr
      // for concatenate dimension.
@@ -887,6 +887,38 @@ static LogicalResult Verify(TransposeOp op) {
  return success();
}

//===----------------------------------------------------------------------===//
// GetTupleElementOp
//===----------------------------------------------------------------------===//

void GetTupleElementOp::build(Builder* builder, OperationState& result,
                              Value* tuple, int32_t index) {
  if (auto tuple_type = tuple->getType().dyn_cast<TupleType>()) {
    auto element_type = tuple_type.getType(index);
    build(builder, result, element_type, tuple,
          builder->getI32IntegerAttr(index));
    return;
  }

  build(builder, result, tuple->getType(), tuple,
        builder->getI32IntegerAttr(index));
}

//===----------------------------------------------------------------------===//
// TupleOp
//===----------------------------------------------------------------------===//

void TupleOp::build(Builder* builder, OperationState& result,
                    ArrayRef<Value*> values) {
  SmallVector<Type, 4> types;
  types.reserve(values.size());
  for (auto val : values) {
    types.push_back(val->getType());
  }

  build(builder, result, builder->getTupleType(types), values);
}

//===----------------------------------------------------------------------===//
// CompareOp
//===----------------------------------------------------------------------===//
@@ -318,6 +318,30 @@ def HLO_XorOp : HLO_BinaryLogicalElementwiseOp<"xor">, BASE_HLO_XorOp;
//===----------------------------------------------------------------------===//
// XLA control flow op definitions.
//===----------------------------------------------------------------------===//
def HLO_ConditionalOp: HLO_Op<"conditional", [NoSideEffect]> {
  string summary = "Conditional operator";

  string description = [{
    Returns the result of executing either a true or false function depending on
    the result of a condition function.

    See https://www.tensorflow.org/xla/operation_semantics#conditional.
  }];

  let arguments = (ins
    HLO_PredTensor:$pred,
    HLO_TensorOrTuple:$true_arg,
    HLO_TensorOrTuple:$false_arg
  );

  let regions = (region AnyRegion:$true_branch, AnyRegion:$false_branch);

  let results = (outs HLO_TensorOrTuple);

  // TODO(b/129422361): ConditionalOp has special conversion logic to HLO.
  let hasCustomHLOConverter = 1;
}

def HLO_WhileOp: HLO_Op<"while", [NoSideEffect, SameOperandsAndResultType]> {
  string summary = "While operator";

@@ -381,6 +405,10 @@ def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]>, BASE_HLO

  let hasFolder = 1;

  let builders = [OpBuilder<
    "Builder *builder, OperationState &results, "
    "Value* value, int32_t index">];

  // GetTupleElementOp has special conversion logic to HLO.
  let hasCustomHLOConverter = 1;
}
@@ -389,6 +417,10 @@ def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp {
  let arguments = (ins Variadic<HLO_TensorOrTuple>:$val);
  let results = (outs HLO_Tuple);

  let builders = [OpBuilder<
    "Builder *builder, OperationState &results, "
    "ArrayRef<Value*> values">];

  // TupleOp has special conversion logic to HLO.
  let hasCustomHLOConverter = 1;
}
@@ -369,6 +369,24 @@ LogicalResult ExportXlaOp(ConcatenateOp op, OpLoweringContext ctx) {
  return success();
}

LogicalResult ExportXlaOp(ConditionalOp op, OpLoweringContext ctx) {
  xla::XlaComputation true_branch;
  xla::XlaComputation false_branch;
  auto& value_map = *ctx.values;
  if (failed(ctx.converter->LowerRegionAsComputation(&op.true_branch(),
                                                     &true_branch)) ||
      failed(ctx.converter->LowerRegionAsComputation(&op.false_branch(),
                                                     &false_branch))) {
    return failure();
  }

  value_map[op] =
      xla::Conditional(value_map[op.pred()], value_map[op.true_arg()],
                       true_branch, value_map[op.false_arg()], false_branch);

  return success();
}

LogicalResult ExportXlaOp(ConstOp op, OpLoweringContext ctx) {
  return failure();
}
@@ -1,7 +1,7 @@
// RUN: tf-opt -xla-legalize-control-flow %s -o - | FileCheck %s

-// CHECK-LABEL: func @main(%arg0: tensor<i64>) -> tensor<i64> {
-func @main(%arg0: tensor<i64>) -> tensor<i64> {
+// CHECK-LABEL: func @while(%arg0: tensor<i64>) -> tensor<i64> {
+func @while(%arg0: tensor<i64>) -> tensor<i64> {
  //CHECK: br ^bb1(%arg0 : tensor<i64>)
  //CHECK: ^bb1([[VAL0:%.+]]: tensor<i64>):
  //CHECK: [[VAL1:%.+]] = "xla_hlo.compare"([[VAL0]], [[VAL0]])
@@ -24,3 +24,37 @@ func @main(%arg0: tensor<i64>) -> tensor<i64> {
  // CHECK-NEXT: return [[VAL5]]
  return %0 : tensor<i64>
}

// CHECK-LABEL: func @conditional
func @conditional(%arg0: tensor<f32>) -> tensor<f32> {
  // CHECK: [[C0:%.+]] = constant dense<1.000000e+01> : tensor<f32>
  %cst = constant dense<1.000000e+01> : tensor<f32>

  // CHECK: [[VAL0:%.+]] = "xla_hlo.compare"(%arg0, [[C0]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
  %0 = "xla_hlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>

  // CHECK: [[VAL1:%.+]] = extract_element [[VAL0]][] : tensor<i1>
  // CHECK: cond_br [[VAL1]], ^bb1(%arg0 : tensor<f32>), ^bb2(%arg0 : tensor<f32>)
  %1 = "xla_hlo.conditional"(%0, %arg0, %arg0) ( {

  ^bb0(%arg1: tensor<f32>):
    // CHECK: ^bb1([[VAL2:%.+]]: tensor<f32>):
    // CHECK: [[VAL3:%.+]] = "xla_hlo.log"([[VAL2]]) : (tensor<f32>) -> tensor<f32>
    // CHECK: br ^bb3([[VAL3]] : tensor<f32>)
    %2 = "xla_hlo.log"(%arg1) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%2) : (tensor<f32>) -> ()
  }, {

  ^bb0(%arg1: tensor<f32>):
    // CHECK: ^bb2([[VAL4:%.+]]: tensor<f32>):
    // CHECK: [[VAL5:%.+]] = "xla_hlo.exp"([[VAL4]]) : (tensor<f32>) -> tensor<f32>
    // CHECK: br ^bb3([[VAL5]] : tensor<f32>)
    %2 = "xla_hlo.exp"(%arg1) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%2) : (tensor<f32>) -> ()
  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>

  // CHECK: ^bb3([[VAL6:%.+]]: tensor<f32>):
  // CHECK: return [[VAL6]] : tensor<f32>
  return %1 : tensor<f32>
}

@@ -0,0 +1,38 @@
// RUN: tf-opt -xla-legalize-tf-control-flow %s | FileCheck %s --dump-input-on-failure

// CHECK-LABEL: @conditional
func @conditional(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>)
attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} {
  // CHECK: [[VAL0:%.+]] = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
  %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>

  // CHECK: [[VAL1:%.+]] = "xla_hlo.tuple"(%arg0, %arg1)
  // CHECK: [[VAL2:%.+]] = "xla_hlo.conditional"([[VAL0]], %1, %1) ( {
  // CHECK: ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
  // CHECK: [[VAL4:%.+]] = "xla_hlo.log"(%arg2)
  // CHECK: [[VAL5:%.+]] = "xla_hlo.tuple"([[VAL4]])
  // CHECK: "xla_hlo.return"([[VAL5]])
  // CHECK: }, {
  // CHECK: ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>)
  // CHECK: [[VAL4:%.+]] = "xla_hlo.exp"(%arg3)
  // CHECK: [[VAL5:%.+]] = "xla_hlo.tuple"([[VAL4]])
  // CHECK: "xla_hlo.return"([[VAL5]])
  // CHECK: })
  %1 = "tf.If"(%0, %arg0, %arg1) {Tcond = "tfdtype$DT_BOOL", Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _lower_using_switch_merge = true, _output_shapes = ["tfshape$"], device = "", else_branch = @cond_false, is_stateless = true, name = "cond", output_shapes = ["tfshape$"], then_branch = @cond_true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>

  // CHECK: [[VAL3:%.+]] = "xla_hlo.get_tuple_element"([[VAL2]]) {index = 0 : i32}
  // CHECK: return [[VAL3]]
  return %1 : tensor<f32>
}

func @cond_false(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32>
attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} {
  %0 = "xla_hlo.exp"(%arg1) : (tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}

func @cond_true(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32>
attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} {
  %0 = "xla_hlo.log"(%arg0) : (tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}
@@ -0,0 +1,53 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s

HloModule tfcompile.20

%then_branch {
  %arg_tuple.7 = (f32[]) parameter(0), metadata={op_name="XLA_Args"}
  %get-tuple-element.8 = f32[] get-tuple-element(%arg_tuple.7), index=0, metadata={op_name="XLA_Args"}
  %log.9 = f32[] log(%get-tuple-element.8), metadata={op_type="Log" op_name="cond/Log"}
  ROOT %tuple.10 = (f32[]) tuple(%log.9), metadata={op_name="XLA_Retvals"}
}

%else_branch {
  %arg_tuple.12 = (f32[]) parameter(0), metadata={op_name="XLA_Args"}
  %get-tuple-element.13 = f32[] get-tuple-element(%arg_tuple.12), index=0, metadata={op_name="XLA_Args"}
  %exponential.14 = f32[] exponential(%get-tuple-element.13), metadata={op_type="Exp" op_name="cond/Exp"}
  ROOT %tuple.15 = (f32[]) tuple(%exponential.14), metadata={op_name="XLA_Retvals"}
}

// CHECK: func @main([[A0:%.+]]: tensor<f32>)
ENTRY %tfcompile.20 {
  %arg0.1 = f32[] parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"}

  // CHECK: [[C0:%.+]] = constant
  %constant.3 = f32[] constant(10), metadata={op_type="Less" op_name="Less"}

  // CHECK: [[R1:%.+]] = "xla_hlo.compare"([[A0]], [[C0]])
  %compare.4 = pred[] compare(%arg0.1, %constant.3), direction=LT, metadata={op_type="Less" op_name="Less"}

  // CHECK: [[R2:%.+]] = "xla_hlo.tuple"([[A0]])
  %tuple.5 = (f32[]) tuple(%arg0.1), metadata={op_type="If" op_name="cond/Merge_if"}

  // CHECK: [[R3:%.+]] = "xla_hlo.conditional"([[R1]], [[R2]], [[R2]]) ( {
  // CHECK: ^bb0([[A1:%.+]]: tuple<tensor<f32>>):
  // CHECK: [[R7:%.+]] = "xla_hlo.get_tuple_element"([[A1]])
  // CHECK: [[R8:%.+]] = "xla_hlo.log"([[R7]])
  // CHECK: [[R9:%.+]] = "xla_hlo.tuple"([[R8]])
  // CHECK: "xla_hlo.return"([[R9]])
  // CHECK: }, {
  // CHECK: ^bb0([[A1:%.+]]: tuple<tensor<f32>>):
  // CHECK: [[R7:%.+]] = "xla_hlo.get_tuple_element"([[A1]])
  // CHECK: [[R8:%.+]] = "xla_hlo.exp"([[R7]])
  // CHECK: [[R9:%.+]] = "xla_hlo.tuple"([[R8]])
  // CHECK: "xla_hlo.return"([[R9]])
  // CHECK: })
  %conditional.16 = (f32[]) conditional(%compare.4, %tuple.5, %tuple.5), true_computation=%then_branch, false_computation=%else_branch, metadata={op_type="If" op_name="cond/Merge_if"}

  // CHECK: [[R4:%.+]] = "xla_hlo.get_tuple_element"([[R3]])
  %get-tuple-element.17 = f32[] get-tuple-element(%conditional.16), index=0, metadata={op_type="If" op_name="cond/Merge_if"}

  // CHECK: [[R5:%.+]] = "xla_hlo.tuple"([[R4]])
  // CHECK: return [[R5]]
  ROOT %tuple.19 = (f32[]) tuple(%get-tuple-element.17), metadata={op_name="XLA_Retvals"}
}
@@ -0,0 +1,64 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s

// CHECK: [[R0:%.+]] ([[A0:.+]]: (f32[])) -> (f32[]) {
// CHECK: %[[A0]] = (f32[]) parameter(0)
func @then_branch(%arg0: tuple<tensor<f32>>) -> tuple<tensor<f32>> {
  // CHECK: %[[VAL0:.+]] = f32[] get-tuple-element((f32[]) %[[A0]]), index=0
  %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>

  // CHECK: %[[VAL1:.+]] = f32[] log(f32[] %[[VAL0]])
  %1 = "xla_hlo.log"(%0) : (tensor<f32>) -> tensor<f32>

  // CHECK: ROOT %[[VAL2:.+]] = (f32[]) tuple(f32[] %[[VAL1]])
  %2 = "xla_hlo.tuple"(%1) : (tensor<f32>) -> tuple<tensor<f32>>
  return %2 : tuple<tensor<f32>>
}

// CHECK: [[R1:%.+]] ([[A0:.+]]: (f32[])) -> (f32[]) {
// CHECK: %[[A0]] = (f32[]) parameter(0)
func @else_branch(%arg0: tuple<tensor<f32>>) -> tuple<tensor<f32>> {
  // CHECK: %[[VAL0:.+]] = f32[] get-tuple-element((f32[]) %[[A0]]), index=0
  %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>

  // CHECK: %[[VAL1:.+]] = f32[] exponential(f32[] %[[VAL0]])
  %1 = "xla_hlo.exp"(%0) : (tensor<f32>) -> tensor<f32>

  // CHECK: ROOT %[[VAL2:.+]] = (f32[]) tuple(f32[] %[[VAL1]])
  %2 = "xla_hlo.tuple"(%1) : (tensor<f32>) -> tuple<tensor<f32>>
  return %2 : tuple<tensor<f32>>
}

// CHECK: ENTRY [[R3:%.+]] ([[A0:.+]]: f32[]) -> (f32[]) {
// CHECK: %[[A0]] = f32[] parameter(0)
func @main(%arg0: tensor<f32>) -> tuple<tensor<f32>> {
  // CHECK: %[[VAL0:.+]] = f32[] constant(10)
  %cst = constant dense<1.000000e+01> : tensor<f32>

  // CHECK: %[[VAL1:.+]] = pred[] compare(f32[] %[[A0]], f32[] %[[VAL0]]), direction=LT
  %0 = "xla_hlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>

  // CHECK: %[[VAL2:.+]] = (f32[]) tuple(f32[] %[[A0]])
  %1 = "xla_hlo.tuple"(%arg0) : (tensor<f32>) -> tuple<tensor<f32>>

  // CHECK: %[[VAL3:.+]] = (f32[]) conditional(pred[] %[[VAL1]], (f32[]) %[[VAL2]], (f32[]) %[[VAL2]]), true_computation=[[R0]], false_computation=[[R1]]
  %2 = "xla_hlo.conditional"(%0, %1, %1) ( {
  ^bb0(%arg1: tuple<tensor<f32>>): // no predecessors
    %6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
    %7 = "xla_hlo.log"(%6) : (tensor<f32>) -> tensor<f32>
    %8 = "xla_hlo.tuple"(%7) : (tensor<f32>) -> tuple<tensor<f32>>
    "xla_hlo.return"(%8) : (tuple<tensor<f32>>) -> ()
  }, {
  ^bb0(%arg1: tuple<tensor<f32>>): // no predecessors
    %6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
    %7 = "xla_hlo.exp"(%6) : (tensor<f32>) -> tensor<f32>
    %8 = "xla_hlo.tuple"(%7) : (tensor<f32>) -> tuple<tensor<f32>>
    "xla_hlo.return"(%8) : (tuple<tensor<f32>>) -> ()
  }) : (tensor<i1>, tuple<tensor<f32>>, tuple<tensor<f32>>) -> tuple<tensor<f32>>

  // CHECK: %[[VAL4:.+]] = f32[] get-tuple-element((f32[]) %[[VAL3]]), index=0
  %3 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>

  // CHECK: ROOT %[[VAL5:.+]] = (f32[]) tuple(f32[] %[[VAL4]])
  %4 = "xla_hlo.tuple"(%3) : (tensor<f32>) -> tuple<tensor<f32>>
  return %4 : tuple<tensor<f32>>
}
@@ -26,6 +26,7 @@ limitations under the License.
#include "mlir/IR/TypeUtilities.h"  // TF:local_config_mlir
#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
#include "mlir/Pass/PassRegistry.h"  // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
@@ -39,6 +40,71 @@ struct LegalizeControlFlow : public mlir::FunctionPass<LegalizeControlFlow> {
  void runOnFunction() override;
};

// Replaces terminators for the newly created blocks from a target region.
// These terminators are replaced with branch operations to a target block.
LogicalResult ReplaceTerminators(Region* region, Block* target_block,
                                 Location loc,
                                 const BlockAndValueMapping& mapper,
                                 OpBuilder* builder) {
  for (auto& old_block : region->getBlocks()) {
    Block* block = mapper.lookup(&old_block);
    auto return_op = dyn_cast<xla_hlo::ReturnOp>(block->getTerminator());
    if (!return_op) return failure();
    builder->setInsertionPointToEnd(block);

    SmallVector<Value*, 4> args(return_op.getOperands());
    builder->create<mlir::BranchOp>(loc, target_block, args);
    return_op.getOperation()->erase();
  }

  return success();
}

LogicalResult LowerConditionalOp(mlir::xla_hlo::ConditionalOp conditional_op) {
  Operation* op_inst = conditional_op.getOperation();
  mlir::OpBuilder builder(conditional_op);
  auto orig_block = op_inst->getBlock();
  auto* tail_block = orig_block->splitBlock(op_inst);
  auto loc = conditional_op.getLoc();

  // Duplicate the true and false regions in the block between the sections
  // before and after the conditional op.
  BlockAndValueMapping mapper;
  conditional_op.true_branch().cloneInto(orig_block->getParent(),
                                         Region::iterator(tail_block), mapper);
  conditional_op.false_branch().cloneInto(orig_block->getParent(),
                                          Region::iterator(tail_block), mapper);

  Block* true_block = mapper.lookup(&conditional_op.true_branch().front());
  Block* false_block = mapper.lookup(&conditional_op.false_branch().front());

  // Perform the conditional branch into the true/false cases.
  builder.setInsertionPointToEnd(orig_block);

  // Extract the predicate for checking branching, then branch to the true and
  // false blocks appropriately.
  auto cond_value =
      builder.create<mlir::ExtractElementOp>(loc, conditional_op.pred());
  builder.create<mlir::CondBranchOp>(loc, cond_value, true_block,
                                     conditional_op.true_arg(), false_block,
                                     conditional_op.false_arg());

  // Replace the true and false case's return operations with a branch to the
  // tail of the conditional.
  if (failed(ReplaceTerminators(&conditional_op.true_branch(), tail_block, loc,
                                mapper, &builder)))
    return failure();
  if (failed(ReplaceTerminators(&conditional_op.false_branch(), tail_block, loc,
                                mapper, &builder)))
    return failure();

  tail_block->addArguments(conditional_op.getResult()->getType());
  conditional_op.getResult()->replaceAllUsesWith(tail_block->getArgument(0));

  op_inst->erase();
  return success();
}

bool LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
  // Converts an xla while loop into control flow. This mostly generates the
  // right MLIR boilerplate for calling the body / condition functions, then
@@ -48,7 +114,7 @@ bool LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
  // <prior operations>
  // %0 = "xla_hlo.while"(%arg0) {body: @loop, cond: @cond}
  // <post operations>
-  auto* opInst = while_op.getOperation();
+  auto* op_inst = while_op.getOperation();
  mlir::OpBuilder builder(while_op);
  auto loc = while_op.getLoc();

@@ -64,8 +130,8 @@ bool LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
  // cond_block - check the looping condition, then conditionally branch into
  // the loop or, if condition is false, jump to the tail branch.
  // body_block - call the loop body, then jump back to the condition block.
-  auto* orig_block = opInst->getBlock();
-  auto* tail_block = orig_block->splitBlock(opInst);
+  auto* orig_block = op_inst->getBlock();
+  auto* tail_block = orig_block->splitBlock(op_inst);

  BlockAndValueMapping mapper;
  while_op.cond().cloneInto(orig_block->getParent(),
@@ -144,28 +210,29 @@ bool LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
    return_op.getOperation()->erase();
  }

  // Setup the tail block:
  // ^tail(%5):
  // <post operations>
  llvm::SmallVector<Value*, 4> tail_block_arguments;
  tail_block_arguments.reserve(while_op.getNumOperands());

  // Erase the original while loop.
  for (int i = 0; i < while_op.getNumOperands(); i++) {
    tail_block->addArgument(while_op.getOperand(i)->getType());
    while_op.getResult(i)->replaceAllUsesWith(tail_block->getArgument(i));
  }
-  opInst->erase();
+  op_inst->erase();

  return false;
}

void LegalizeControlFlow::runOnFunction() {
  auto func = getFunction();
-  llvm::SmallVector<WhileOp, 4> control_flow_ops;
-  func.walk([&](WhileOp op) { control_flow_ops.push_back(op); });
+  llvm::SmallVector<ConditionalOp, 4> conditional_ops;
+  func.walk([&](ConditionalOp op) { conditional_ops.push_back(op); });

-  for (auto& op : control_flow_ops) {
+  for (auto& op : conditional_ops) {
+    if (failed(LowerConditionalOp(op))) return signalPassFailure();
+  }
+
+  llvm::SmallVector<WhileOp, 4> while_ops;
+  func.walk([&](WhileOp op) { while_ops.push_back(op); });
+
+  for (auto& op : while_ops) {
    if (LowerWhileOp(op)) return signalPassFailure();
  }
}
@@ -0,0 +1,120 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering TensorFlow dialect's control flow to
// the XLA dialect.

#include <cstddef>
#include <cstdint>
#include <iterator>
#include <numeric>

#include "llvm/ADT/ArrayRef.h"
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:local_config_mlir
#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/BlockAndValueMapping.h"  // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
#include "mlir/IR/Module.h"  // TF:local_config_mlir
#include "mlir/IR/Operation.h"  // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h"  // TF:local_config_mlir
#include "mlir/IR/Types.h"  // TF:local_config_mlir
#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
#include "mlir/Pass/PassRegistry.h"  // TF:local_config_mlir
#include "mlir/Transforms/DialectConversion.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/core/util/tensor_format.h"

using mlir::PassRegistration;

namespace mlir {
namespace xla_hlo {
namespace {
class LegalizeTFControlFlow : public ModulePass<LegalizeTFControlFlow> {
 public:
  void runOnModule() override;
};
}  // namespace

std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
createLegalizeTFControlFlowPass() {
  return std::make_unique<LegalizeTFControlFlow>();
}

namespace {
void ImportXlaRegion(Region* src_region, Region* dest_region) {
  BlockAndValueMapping mapper;
  src_region->cloneInto(dest_region, mapper);
  dest_region->walk([&](mlir::ReturnOp op) -> void {
    OpBuilder builder(op);
    llvm::SmallVector<Value*, 4> operands(op.operands());
    auto tuple = builder.create<xla_hlo::TupleOp>(op.getLoc(), operands);
    builder.create<xla_hlo::ReturnOp>(op.getLoc(), tuple.getResult());
    op.erase();
  });
}

void LowerIf(TF::IfOp op, ModuleOp module) {
  Location loc = op.getLoc();

  OpBuilder builder(op);
  // XLA prefers tuple arguments for control flow due to XLA not supporting
  // multiple return values.
  SmallVector<Value*, 3> inputs(op.input());
  builder.setInsertionPoint(op);
  auto tuple_input = builder.create<xla_hlo::TupleOp>(loc, inputs);

  // Create the new conditional op with tuple inputs.
  SmallVector<Value*, 3> operands(op.getOperands());
  SmallVector<Type, 4> types(op.getResultTypes());
  auto result_type = builder.getTupleType(types);
  auto conditional_result = builder.create<xla_hlo::ConditionalOp>(
      loc, result_type, op.cond(), tuple_input, tuple_input);

  // Import the regions for both the true and false cases. These regions
  // must be updated to tuple the return results together and use the xla hlo
  // return op.
  BlockAndValueMapping mapper;
  auto then_branch = module.lookupSymbol<mlir::FuncOp>(op.then_branch());
  auto else_branch = module.lookupSymbol<mlir::FuncOp>(op.else_branch());
  ImportXlaRegion(&then_branch.getBody(), &conditional_result.true_branch());
  ImportXlaRegion(&else_branch.getBody(), &conditional_result.false_branch());

  // De-tuple the results of the xla hlo conditional result.
  builder.setInsertionPointAfter(op);
  for (auto result_it : llvm::enumerate(op.getResults())) {
    auto get_tuple_value = builder.create<xla_hlo::GetTupleElementOp>(
        loc, conditional_result, result_it.index());
    result_it.value()->replaceAllUsesWith(get_tuple_value);
  }

  op.erase();
}
}  // namespace

void LegalizeTFControlFlow::runOnModule() {
  auto module = getModule();

  TypeConverter type_converter;
  module.walk([&](TF::IfOp op) -> void { LowerIf(op, module); });
}
}  // namespace xla_hlo
}  // namespace mlir

static PassRegistration<mlir::xla_hlo::LegalizeTFControlFlow> cfpass(
    "xla-legalize-tf-control-flow",
    "Legalize TensorFlow control flow to the XLA dialect");
@@ -24,6 +24,7 @@ limitations under the License.
namespace mlir {

class FuncOp;
class ModuleOp;
class Operation;
template <typename T>
class OpPassBase;
@@ -33,6 +34,9 @@ namespace xla_hlo {
/// Lowers from TF dialect to HLO dialect.
std::unique_ptr<OpPassBase<FuncOp>> createLegalizeTFPass();

/// Lowers from TF dialect's control flow to HLO dialect's control flow.
std::unique_ptr<OpPassBase<ModuleOp>> createLegalizeTFControlFlowPass();

/// Converts the provided Operation as well as all nested operations into HLO
/// dialect using the conversion patterns registered by the HLO dialect.
LogicalResult legalizeTF(Operation* op);
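Taken together with the existing -xla-legalize-control-flow pass, the new TF control flow lowering can be exercised end to end. The pass names below come from the RUN lines in the new tests; chaining them in a single invocation is an illustrative sketch, not a command taken from this change:

// RUN: tf-opt -xla-legalize-tf-control-flow -xla-legalize-control-flow %s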