Added tf.If and xla_hlo.Conditional with appropriate lowerings

- includes lowering tf.If to xla_hlo.Conditional
- importing xla_hlo via the HloFunctionImporter
- lowering from xla_hlo.Conditional to MLIR control flow operation

PiperOrigin-RevId: 280553558
Change-Id: I3d86c981176d7a97f1ae1a02ea292e8da562abae
This commit is contained in:
A. Unique TensorFlower 2019-11-14 17:39:04 -08:00 committed by TensorFlower Gardener
parent 2dc5684649
commit 3c26f1a410
12 changed files with 494 additions and 17 deletions

View File

@ -102,6 +102,7 @@ cc_library(
srcs = [
"transforms/generated_legalize_tf.inc",
"transforms/legalize_tf.cc",
"transforms/legalize_tf_control_flow.cc",
],
deps = [
":convert_op_folder",
@ -210,6 +211,7 @@ cc_library(
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
],
alwayslink = 1,
)

View File

@ -331,6 +331,19 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
ConvertDimensions(instruction->slice_strides()))
.getOperation();
}
case HloOpcode::kConditional: {
llvm::SmallVector<Type, 4> rets;
TF_RETURN_IF_ERROR(GetMlirTypes(
{instruction->true_computation()->root_instruction()}, &rets));
auto op = func_builder->create<mlir::xla_hlo::ConditionalOp>(
loc, rets, operands, attributes);
TF_RETURN_IF_ERROR(ImportComputation(instruction->true_computation(),
&op.true_branch()));
TF_RETURN_IF_ERROR(ImportComputation(instruction->false_computation(),
&op.false_branch()));
return op.getOperation();
}
case HloOpcode::kConcatenate: {
// TODO(b/132057942): Support taking an uint64_t instead of an IntegerAttr
// for concatenate dimension.

View File

@ -887,6 +887,38 @@ static LogicalResult Verify(TransposeOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// GetTupleElementOp
//===----------------------------------------------------------------------===//
void GetTupleElementOp::build(Builder* builder, OperationState& result,
Value* tuple, int32_t index) {
if (auto tuple_type = tuple->getType().dyn_cast<TupleType>()) {
auto element_type = tuple_type.getType(index);
build(builder, result, element_type, tuple,
builder->getI32IntegerAttr(index));
return;
}
build(builder, result, tuple->getType(), tuple,
builder->getI32IntegerAttr(index));
}
//===----------------------------------------------------------------------===//
// TupleOp
//===----------------------------------------------------------------------===//
void TupleOp::build(Builder* builder, OperationState& result,
ArrayRef<Value*> values) {
SmallVector<Type, 4> types;
types.reserve(values.size());
for (auto val : values) {
types.push_back(val->getType());
}
build(builder, result, builder->getTupleType(types), values);
}
//===----------------------------------------------------------------------===//
// CompareOp
//===----------------------------------------------------------------------===//

View File

@ -318,6 +318,30 @@ def HLO_XorOp : HLO_BinaryLogicalElementwiseOp<"xor">, BASE_HLO_XorOp;
//===----------------------------------------------------------------------===//
// XLA control flow op definitions.
//===----------------------------------------------------------------------===//
def HLO_ConditionalOp: HLO_Op<"conditional", [NoSideEffect]> {
string summary = "Conditional operator";
string description = [{
Returns the result of executing either a true or false function depending on
the result of a condition function.
See https://www.tensorflow.org/xla/operation_semantics#conditional.
}];
let arguments = (ins
HLO_PredTensor:$pred,
HLO_TensorOrTuple:$true_arg,
HLO_TensorOrTuple:$false_arg
);
let regions = (region AnyRegion:$true_branch, AnyRegion:$false_branch);
let results = (outs HLO_TensorOrTuple);
// TODO(b/129422361): ConditionalOp has special conversion logic to HLO.
let hasCustomHLOConverter = 1;
}
def HLO_WhileOp: HLO_Op<"while", [NoSideEffect, SameOperandsAndResultType]> {
string summary = "While operator";
@ -381,6 +405,10 @@ def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]>, BASE_HLO
let hasFolder = 1;
let builders = [OpBuilder<
"Builder *builder, OperationState &results, "
"Value* value, int32_t index">];
// GetTupleElementOp has special conversion logic to HLO.
let hasCustomHLOConverter = 1;
}
@ -389,6 +417,10 @@ def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp {
let arguments = (ins Variadic<HLO_TensorOrTuple>:$val);
let results = (outs HLO_Tuple);
let builders = [OpBuilder<
"Builder *builder, OperationState &results, "
"ArrayRef<Value*> values">];
// TupleOp has special conversion logic to HLO.
let hasCustomHLOConverter = 1;
}

View File

@ -369,6 +369,24 @@ LogicalResult ExportXlaOp(ConcatenateOp op, OpLoweringContext ctx) {
return success();
}
LogicalResult ExportXlaOp(ConditionalOp op, OpLoweringContext ctx) {
xla::XlaComputation true_branch;
xla::XlaComputation false_branch;
auto& value_map = *ctx.values;
if (failed(ctx.converter->LowerRegionAsComputation(&op.true_branch(),
&true_branch)) ||
failed(ctx.converter->LowerRegionAsComputation(&op.false_branch(),
&false_branch))) {
return failure();
}
value_map[op] =
xla::Conditional(value_map[op.pred()], value_map[op.true_arg()],
true_branch, value_map[op.false_arg()], false_branch);
return success();
}
LogicalResult ExportXlaOp(ConstOp op, OpLoweringContext ctx) {
return failure();
}

View File

@ -1,7 +1,7 @@
// RUN: tf-opt -xla-legalize-control-flow %s -o - | FileCheck %s
// CHECK-LABEL: func @main(%arg0: tensor<i64>) -> tensor<i64> {
func @main(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK-LABEL: func @while(%arg0: tensor<i64>) -> tensor<i64> {
func @while(%arg0: tensor<i64>) -> tensor<i64> {
//CHECK: br ^bb1(%arg0 : tensor<i64>)
//CHECK: ^bb1([[VAL0:%.+]]: tensor<i64>):
//CHECK: [[VAL1:%.+]] = "xla_hlo.compare"([[VAL0]], [[VAL0]])
@ -24,3 +24,37 @@ func @main(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK-NEXT: return [[VAL5]]
return %0 : tensor<i64>
}
// CHECK-LABEL: func @conditional
func @conditional(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK: [[C0:%.+]] = constant dense<1.000000e+01> : tensor<f32>
%cst = constant dense<1.000000e+01> : tensor<f32>
// CHECK: [[VAL0:%.+]] = "xla_hlo.compare"(%arg0, [[C0]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
%0 = "xla_hlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: [[VAL1:%.+]] = extract_element [[VAL0]][] : tensor<i1>
// CHECK: cond_br [[VAL1]], ^bb1(%arg0 : tensor<f32>), ^bb2(%arg0 : tensor<f32>)
%1 = "xla_hlo.conditional"(%0, %arg0, %arg0) ( {
^bb0(%arg1: tensor<f32>):
// CHECK: ^bb1([[VAL2:%.+]]: tensor<f32>):
// CHECK: [[VAL3:%.+]] = "xla_hlo.log"([[VAL2]]) : (tensor<f32>) -> tensor<f32>
// CHECK: br ^bb3([[VAL3]] : tensor<f32>)
%2 = "xla_hlo.log"(%arg1) : (tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
}, {
^bb0(%arg1: tensor<f32>):
// CHECK: ^bb2([[VAL4:%.+]]: tensor<f32>):
// CHECK: [[VAL5:%.+]] = "xla_hlo.exp"([[VAL4]]) : (tensor<f32>) -> tensor<f32>
// CHECK: br ^bb3([[VAL5]] : tensor<f32>)
%2 = "xla_hlo.exp"(%arg1) : (tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: ^bb3([[VAL6:%.+]]: tensor<f32>):
// CHECK: return [[VAL6]] : tensor<f32>
return %1 : tensor<f32>
}

View File

@ -0,0 +1,38 @@
// RUN: tf-opt -xla-legalize-tf-control-flow %s | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: @conditional
func @conditional(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>)
attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} {
// CHECK: [[VAL0:%.+]] = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
%0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: [[VAL1:%.+]] = "xla_hlo.tuple"(%arg0, %arg1)
// CHECK: [[VAL2:%.+]] = "xla_hlo.conditional"([[VAL0]], %1, %1) ( {
// CHECK: ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
// CHECK: [[VAL4:%.+]] = "xla_hlo.log"(%arg2)
// CHECK: [[VAL5:%.+]] = "xla_hlo.tuple"([[VAL4]])
// CHECK: "xla_hlo.return"([[VAL5]])
// CHECK: }, {
// CHECK: ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>)
// CHECK: [[VAL4:%.+]] = "xla_hlo.exp"(%arg3)
// CHECK: [[VAL5:%.+]] = "xla_hlo.tuple"([[VAL4]])
// CHECK: "xla_hlo.return"([[VAL5]])
// CHECK: })
%1 = "tf.If"(%0, %arg0, %arg1) {Tcond = "tfdtype$DT_BOOL", Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _lower_using_switch_merge = true, _output_shapes = ["tfshape$"], device = "", else_branch = @cond_false, is_stateless = true, name = "cond", output_shapes = ["tfshape$"], then_branch = @cond_true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: [[VAL3:%.+]] = "xla_hlo.get_tuple_element"([[VAL2]]) {index = 0 : i32}
// CHECK: return [[VAL3]]
return %1 : tensor<f32>
}
func @cond_false(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32>
attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} {
%0 = "xla_hlo.exp"(%arg1) : (tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
func @cond_true(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32>
attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} {
%0 = "xla_hlo.log"(%arg0) : (tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}

View File

@ -0,0 +1,53 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule tfcompile.20
%then_branch {
%arg_tuple.7 = (f32[]) parameter(0), metadata={op_name="XLA_Args"}
%get-tuple-element.8 = f32[] get-tuple-element(%arg_tuple.7), index=0, metadata={op_name="XLA_Args"}
%log.9 = f32[] log(%get-tuple-element.8), metadata={op_type="Log" op_name="cond/Log"}
ROOT %tuple.10 = (f32[]) tuple(%log.9), metadata={op_name="XLA_Retvals"}
}
%else_branch {
%arg_tuple.12 = (f32[]) parameter(0), metadata={op_name="XLA_Args"}
%get-tuple-element.13 = f32[] get-tuple-element(%arg_tuple.12), index=0, metadata={op_name="XLA_Args"}
%exponential.14 = f32[] exponential(%get-tuple-element.13), metadata={op_type="Exp" op_name="cond/Exp"}
ROOT %tuple.15 = (f32[]) tuple(%exponential.14), metadata={op_name="XLA_Retvals"}
}
// CHECK: func @main([[A0:%.+]]: tensor<f32>)
ENTRY %tfcompile.20 {
%arg0.1 = f32[] parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"}
// CHECK: [[C0:%.+]] = constant
%constant.3 = f32[] constant(10), metadata={op_type="Less" op_name="Less"}
// CHECK: [[R1:%.+]] = "xla_hlo.compare"([[A0]], [[C0]])
%compare.4 = pred[] compare(%arg0.1, %constant.3), direction=LT, metadata={op_type="Less" op_name="Less"}
// CHECK: [[R2:%.+]] = "xla_hlo.tuple"([[A0]])
%tuple.5 = (f32[]) tuple(%arg0.1), metadata={op_type="If" op_name="cond/Merge_if"}
// CHECK: [[R3:%.+]] = "xla_hlo.conditional"([[R1]], [[R2]], [[R2]]) ( {
// CHECK: ^bb0([[A1:%.+]]: tuple<tensor<f32>>):
// CHECK: [[R7:%.+]] = "xla_hlo.get_tuple_element"([[A1]])
// CHECK: [[R8:%.+]] = "xla_hlo.log"([[R7]])
// CHECK: [[R9:%.+]] = "xla_hlo.tuple"([[R8]])
// CHECK: "xla_hlo.return"([[R9]])
// CHECK: }, {
// CHECK: ^bb0([[A1:%.+]]: tuple<tensor<f32>>):
// CHECK: [[R7:%.+]] = "xla_hlo.get_tuple_element"([[A1]])
// CHECK: [[R8:%.+]] = "xla_hlo.exp"([[R7]])
// CHECK: [[R9:%.+]] = "xla_hlo.tuple"([[R8]])
// CHECK: "xla_hlo.return"([[R9]])
// CHECK: })
%conditional.16 = (f32[]) conditional(%compare.4, %tuple.5, %tuple.5), true_computation=%then_branch, false_computation=%else_branch, metadata={op_type="If" op_name="cond/Merge_if"}
// CHECK: [[R4:%.+]] = "xla_hlo.get_tuple_element"([[R3]])
%get-tuple-element.17 = f32[] get-tuple-element(%conditional.16), index=0, metadata={op_type="If" op_name="cond/Merge_if"}
// CHECK: [[R5:%.+]] = "xla_hlo.tuple"([[R4]])
// CHECK: return [[R5]]
ROOT %tuple.19 = (f32[]) tuple(%get-tuple-element.17), metadata={op_name="XLA_Retvals"}
}

View File

@ -0,0 +1,64 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
// CHECK: [[R0:%.+]] ([[A0:.+]]: (f32[])) -> (f32[]) {
// CHECK: %[[A0]] = (f32[]) parameter(0)
func @then_branch(%arg0: tuple<tensor<f32>>) -> tuple<tensor<f32>> {
// CHECK: %[[VAL0:.+]] = f32[] get-tuple-element((f32[]) %[[A0]]), index=0
%0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
// CHECK: %[[VAL1:.+]] = f32[] log(f32[] %[[VAL0]])
%1 = "xla_hlo.log"(%0) : (tensor<f32>) -> tensor<f32>
// CHECK: ROOT %[[VAl2:.+]] = (f32[]) tuple(f32[] %[[VAL1]])
%2 = "xla_hlo.tuple"(%1) : (tensor<f32>) -> tuple<tensor<f32>>
return %2 : tuple<tensor<f32>>
}
// CHECK: [[R1:%.+]] ([[A0:.+]]: (f32[])) -> (f32[]) {
// CHECK: %[[A0]] = (f32[]) parameter(0)
func @else_branch(%arg0: tuple<tensor<f32>>) -> tuple<tensor<f32>> {
// CHECK: %[[VAL0:.+]] = f32[] get-tuple-element((f32[]) %[[A0]]), index=0
%0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
// CHECK: %[[VAL1:.+]] = f32[] exponential(f32[] %[[VAL0]])
%1 = "xla_hlo.exp"(%0) : (tensor<f32>) -> tensor<f32>
// CHECK: ROOT %[[VAL2:.+]] = (f32[]) tuple(f32[] %[[VAL1]])
%2 = "xla_hlo.tuple"(%1) : (tensor<f32>) -> tuple<tensor<f32>>
return %2 : tuple<tensor<f32>>
}
// CHECK: ENTRY [[R3:%.+]] ([[A0:.+]]: f32[]) -> (f32[]) {
// CHECK: %[[A0]] = f32[] parameter(0)
func @main(%arg0: tensor<f32>) -> tuple<tensor<f32>> {
// CHECK: %[[VAL0:.+]] = f32[] constant(10)
%cst = constant dense<1.000000e+01> : tensor<f32>
// CHECK: %[[VAL1:.+]] = pred[] compare(f32[] %[[A0]], f32[] %[[VAL0]]), direction=LT
%0 = "xla_hlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL2:.+]] = (f32[]) tuple(f32[] %[[A0]])
%1 = "xla_hlo.tuple"(%arg0) : (tensor<f32>) -> tuple<tensor<f32>>
// CHECK: %[[VAL3:.+]] = (f32[]) conditional(pred[] %[[VAL1]], (f32[]) %[[VAL2]], (f32[]) %[[VAL2]]), true_computation=[[R0]], false_computation=[[R1]]
%2 = "xla_hlo.conditional"(%0, %1, %1) ( {
^bb0(%arg1: tuple<tensor<f32>>): // no predecessors
%6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
%7 = "xla_hlo.log"(%6) : (tensor<f32>) -> tensor<f32>
%8 = "xla_hlo.tuple"(%7) : (tensor<f32>) -> tuple<tensor<f32>>
"xla_hlo.return"(%8) : (tuple<tensor<f32>>) -> ()
}, {
^bb0(%arg1: tuple<tensor<f32>>): // no predecessors
%6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
%7 = "xla_hlo.exp"(%6) : (tensor<f32>) -> tensor<f32>
%8 = "xla_hlo.tuple"(%7) : (tensor<f32>) -> tuple<tensor<f32>>
"xla_hlo.return"(%8) : (tuple<tensor<f32>>) -> ()
}) : (tensor<i1>, tuple<tensor<f32>>, tuple<tensor<f32>>) -> tuple<tensor<f32>>
// CHECK: %[[VAL4:.+]] = f32[] get-tuple-element((f32[]) %[[VAL3]]), index=0
%3 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
// CHECK: ROOT %[[VAL5:.+]] = (f32[]) tuple(f32[] %[[VAL4]])
%4 = "xla_hlo.tuple"(%3) : (tensor<f32>) -> tuple<tensor<f32>>
return %4 : tuple<tensor<f32>>
}

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
@ -39,6 +40,71 @@ struct LegalizeControlFlow : public mlir::FunctionPass<LegalizeControlFlow> {
void runOnFunction() override;
};
// Replaces terminators for the newly created blocks from a targe region.
// These terminators are replaced with branch operations to a target block.
LogicalResult ReplaceTerminators(Region* region, Block* target_block,
Location loc,
const BlockAndValueMapping& mapper,
OpBuilder* builder) {
for (auto& old_block : region->getBlocks()) {
Block* block = mapper.lookup(&old_block);
auto return_op = dyn_cast<xla_hlo::ReturnOp>(block->getTerminator());
if (!return_op) return failure();
builder->setInsertionPointToEnd(block);
SmallVector<Value*, 4> args(return_op.getOperands());
builder->create<mlir::BranchOp>(loc, target_block, args);
return_op.getOperation()->erase();
}
return success();
}
LogicalResult LowerConditionalOp(mlir::xla_hlo::ConditionalOp conditional_op) {
Operation* op_inst = conditional_op.getOperation();
mlir::OpBuilder builder(conditional_op);
auto orig_block = op_inst->getBlock();
auto* tail_block = orig_block->splitBlock(op_inst);
auto loc = conditional_op.getLoc();
// Duplicate the true and false regions in the block between the sections
// before and after the while loop.
BlockAndValueMapping mapper;
conditional_op.true_branch().cloneInto(orig_block->getParent(),
Region::iterator(tail_block), mapper);
conditional_op.false_branch().cloneInto(orig_block->getParent(),
Region::iterator(tail_block), mapper);
Block* true_block = mapper.lookup(&conditional_op.true_branch().front());
Block* false_block = mapper.lookup(&conditional_op.false_branch().front());
// Perform the conditional branch into the true/false cases.
builder.setInsertionPointToEnd(orig_block);
// Extract the predicate for checking branching, then branch to the true and
// false blocks appropriately.
auto cond_value =
builder.create<mlir::ExtractElementOp>(loc, conditional_op.pred());
builder.create<mlir::CondBranchOp>(loc, cond_value, true_block,
conditional_op.true_arg(), false_block,
conditional_op.false_arg());
// Replace the true case's return operations with a branche to the tail of
// the condition.
if (failed(ReplaceTerminators(&conditional_op.true_branch(), tail_block, loc,
mapper, &builder)))
return failure();
if (failed(ReplaceTerminators(&conditional_op.false_branch(), tail_block, loc,
mapper, &builder)))
return failure();
tail_block->addArguments(conditional_op.getResult()->getType());
conditional_op.getResult()->replaceAllUsesWith(tail_block->getArgument(0));
op_inst->erase();
return success();
}
bool LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
// Converts an xla while loop into control flow. This mostly generates the
// right MLIR boilerplate for calling the body / condition functions, then
@ -48,7 +114,7 @@ bool LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
// <prior operations>
// %0 = "xla_hlo.while"(%arg0) {body: @loop, cond: @cond}
// <post operations>
auto* opInst = while_op.getOperation();
auto* op_inst = while_op.getOperation();
mlir::OpBuilder builder(while_op);
auto loc = while_op.getLoc();
@ -64,8 +130,8 @@ bool LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
// cond_block - check the looping condition, then conditionally branch into
// the loop or, if condition is false, jump to the tail branch.
// body_block - call the loop body, then jump back to the condition block.
auto* orig_block = opInst->getBlock();
auto* tail_block = orig_block->splitBlock(opInst);
auto* orig_block = op_inst->getBlock();
auto* tail_block = orig_block->splitBlock(op_inst);
BlockAndValueMapping mapper;
while_op.cond().cloneInto(orig_block->getParent(),
@ -144,28 +210,29 @@ bool LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
return_op.getOperation()->erase();
}
// Setup the tail block:
// ^tail(%5):
// <post operations>
llvm::SmallVector<Value*, 4> tail_block_arguments;
tail_block_arguments.reserve(while_op.getNumOperands());
// Erase the original while loop.
for (int i = 0; i < while_op.getNumOperands(); i++) {
tail_block->addArgument(while_op.getOperand(i)->getType());
while_op.getResult(i)->replaceAllUsesWith(tail_block->getArgument(i));
}
opInst->erase();
op_inst->erase();
return false;
}
void LegalizeControlFlow::runOnFunction() {
auto func = getFunction();
llvm::SmallVector<WhileOp, 4> control_flow_ops;
func.walk([&](WhileOp op) { control_flow_ops.push_back(op); });
llvm::SmallVector<ConditionalOp, 4> conditional_ops;
func.walk([&](ConditionalOp op) { conditional_ops.push_back(op); });
for (auto& op : control_flow_ops) {
for (auto& op : conditional_ops) {
if (failed(LowerConditionalOp(op))) return signalPassFailure();
}
llvm::SmallVector<WhileOp, 4> while_ops;
func.walk([&](WhileOp op) { while_ops.push_back(op); });
for (auto& op : while_ops) {
if (LowerWhileOp(op)) return signalPassFailure();
}
}

View File

@ -0,0 +1,120 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements logic for lowering TensorFlow dialect's control flow to
// the XLA dialect.
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <numeric>
#include "llvm/ADT/ArrayRef.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
#include "mlir/Transforms/DialectConversion.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/core/util/tensor_format.h"
using mlir::PassRegistration;
namespace mlir {
namespace xla_hlo {
namespace {
class LegalizeTFControlFlow : public ModulePass<LegalizeTFControlFlow> {
public:
void runOnModule() override;
};
} // namespace
std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
createLegalizeTFControlFlowPass() {
return std::make_unique<LegalizeTFControlFlow>();
}
namespace {
void ImportXlaRegion(Region* src_region, Region* dest_region) {
BlockAndValueMapping mapper;
src_region->cloneInto(dest_region, mapper);
dest_region->walk([&](mlir::ReturnOp op) -> void {
OpBuilder builder(op);
llvm::SmallVector<Value*, 4> operands(op.operands());
auto tuple = builder.create<xla_hlo::TupleOp>(op.getLoc(), operands);
builder.create<xla_hlo::ReturnOp>(op.getLoc(), tuple.getResult());
op.erase();
});
}
void LowerIf(TF::IfOp op, ModuleOp module) {
Location loc = op.getLoc();
OpBuilder builder(op);
// XLA prefers tuple arguments for control flow due to XLA not supporting
// multiple return values.
SmallVector<Value*, 3> inputs(op.input());
builder.setInsertionPoint(op);
auto tuple_input = builder.create<xla_hlo::TupleOp>(loc, inputs);
// Create the new conditional op with tuple inputs.
SmallVector<Value*, 3> operands(op.getOperands());
SmallVector<Type, 4> types(op.getResultTypes());
auto result_type = builder.getTupleType(types);
auto conditional_result = builder.create<xla_hlo::ConditionalOp>(
loc, result_type, op.cond(), tuple_input, tuple_input);
// Import the regions for both the true and false cases. These regions
// must be updated to tuple the return results together and use the xla hlo
// return op.
BlockAndValueMapping mapper;
auto then_branch = module.lookupSymbol<mlir::FuncOp>(op.then_branch());
auto else_branch = module.lookupSymbol<mlir::FuncOp>(op.else_branch());
ImportXlaRegion(&then_branch.getBody(), &conditional_result.true_branch());
ImportXlaRegion(&else_branch.getBody(), &conditional_result.false_branch());
// De-tuple the results of the xla hlo conditional result.
builder.setInsertionPointAfter(op);
for (auto result_it : llvm::enumerate(op.getResults())) {
auto get_tuple_value = builder.create<xla_hlo::GetTupleElementOp>(
loc, conditional_result, result_it.index());
result_it.value()->replaceAllUsesWith(get_tuple_value);
}
op.erase();
}
} // namespace
void LegalizeTFControlFlow::runOnModule() {
auto module = getModule();
TypeConverter type_converter;
module.walk([&](TF::IfOp op) -> void { LowerIf(op, module); });
}
} // namespace xla_hlo
} // namespace mlir
static PassRegistration<mlir::xla_hlo::LegalizeTFControlFlow> cfpass(
"xla-legalize-tf-control-flow",
"Legalize TensorFlow control flow to the XLA dialect");

View File

@ -24,6 +24,7 @@ limitations under the License.
namespace mlir {
class FuncOp;
class ModuleOp;
class Operation;
template <typename T>
class OpPassBase;
@ -33,6 +34,9 @@ namespace xla_hlo {
/// Lowers from TF dialect to HLO dialect.
std::unique_ptr<OpPassBase<FuncOp>> createLegalizeTFPass();
/// Lowers from TF dialect's control flow to HLO dialect's control flow.
std::unique_ptr<OpPassBase<ModuleOp>> createLegalizeTFControlFlowPass();
/// Converts the provided Operation as well as all nested operations into HLO
/// dialect using the conversion patterns registered by the HLO dialect.
LogicalResult legalizeTF(Operation* op);