Add xla_hlo.case op for indexed conditional HLO.

Adds import, export and verifier support for this op. It is exported to indexed conditional HLO.

PiperOrigin-RevId: 312480515
Change-Id: I8306e8f7f24b0a304de00547d3022d4fe804deb9
This commit is contained in:
Prakalp Srivastava 2020-05-20 07:49:31 -07:00 committed by TensorFlower Gardener
parent 9c313c4d2d
commit 6e509432c0
10 changed files with 413 additions and 9 deletions

View File

@ -420,15 +420,37 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
}
case HloOpcode::kConditional: {
  llvm::SmallVector<Type, 4> rets;

  // XLA's kConditional HLO is overloaded: with an i1 predicate it is a
  // two-way predicated conditional; with an integer index it is an N-way
  // indexed conditional. Dispatch on the element type of the first operand.
  mlir::Type pred_or_index_type =
      operands[0].getType().cast<mlir::TensorType>().getElementType();

  // It is a predicated conditional if the first argument is a boolean and
  // should be mapped to If op.
  if (pred_or_index_type.isInteger(1)) {
    TF_RETURN_IF_ERROR(GetMlirTypes(
        {instruction->true_computation()->root_instruction()}, &rets));
    auto op = func_builder->create<mlir::xla_hlo::IfOp>(loc, rets, operands,
                                                        attributes);
    TF_RETURN_IF_ERROR(ImportComputation(instruction->true_computation(),
                                         &op.true_branch()));
    TF_RETURN_IF_ERROR(ImportComputation(instruction->false_computation(),
                                         &op.false_branch()));
    return op.getOperation();
  }

  // Otherwise, it is an indexed conditional and should be mapped to Case op.
  // All branch computations must return the same type, so the op's result
  // types are taken from the first branch's root instruction.
  TF_RETURN_IF_ERROR(GetMlirTypes(
      {instruction->branch_computation(0)->root_instruction()}, &rets));

  int num_branches = instruction->branch_count();
  auto op = func_builder->create<mlir::xla_hlo::CaseOp>(
      loc, rets, operands, attributes, num_branches);
  // Import each branch computation into its corresponding region.
  for (auto index_and_computation :
       llvm::enumerate(instruction->branch_computations())) {
    auto index = index_and_computation.index();
    HloComputation* computation = index_and_computation.value();
    TF_RETURN_IF_ERROR(
        ImportComputation(computation, &op.branches()[index]));
  }
  return op.getOperation();
}
case HloOpcode::kConcatenate: {

View File

@ -1396,6 +1396,47 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
return {};
}
//===----------------------------------------------------------------------===//
// Case Op
//===----------------------------------------------------------------------===//
// Verifies xla_hlo.case:
//  - the number of branch regions equals the number of branch operands,
//  - every branch region is non-empty and its entry block takes exactly one
//    argument whose type matches the corresponding branch operand,
//  - every return inside a branch yields exactly the op's result types.
static LogicalResult Verify(CaseOp op) {
  auto num_branches = op.branches().size();
  if (op.branch_operands().size() != num_branches)
    return op.emitOpError() << "expects number of branches " << num_branches
                            << " to be same as number of branch operands "
                            << op.branch_operands().size();

  MutableArrayRef<Region> branches = op.branches();
  OperandRange branch_operands = op.branch_operands();
  for (unsigned i = 0; i < num_branches; ++i) {
    mlir::Region& branch_region = branches[i];
    if (branch_region.empty())
      return op.emitOpError() << "cannot have empty regions";

    // Each branch is invoked with exactly its matching branch operand, so the
    // entry block must take a single argument of that operand's type.
    mlir::Block& entry_block = branch_region.front();
    if (entry_block.getNumArguments() != 1)
      return op.emitOpError()
             << "expects branch regions to have single argument, but found "
             << entry_block.getNumArguments() << " for branch " << i;
    auto operand = branch_operands[i];
    if (entry_block.getArgument(0).getType() != operand.getType())
      return op.emitOpError()
             << "expects operand " << i + 1 << " to be of type "
             << entry_block.getArgument(0).getType() << ", but found "
             << operand.getType();

    // NOTE(review): walk() recurses into nested regions, so ReturnOps that
    // terminate region-holding ops nested inside the branch (not the branch
    // itself) are also checked against the CaseOp result types — confirm this
    // is intended for branches containing such ops.
    WalkResult walker = branch_region.walk([&](ReturnOp return_op) {
      if (return_op.getOperands().getTypes() != op.getResultTypes())
        return WalkResult::interrupt();
      return WalkResult::advance();
    });
    if (walker.wasInterrupted())
      return op.emitOpError()
             << "branch " << i
             << " returned values do not match op result types";
  }
  return success();
}
//===----------------------------------------------------------------------===//
// BinaryOps
//===----------------------------------------------------------------------===//

View File

@ -497,7 +497,8 @@ def HLO_IfOp: HLO_Op<"if", []> {
HLO_TensorOrTuple:$false_arg
);
let regions = (region AnyRegion:$true_branch, AnyRegion:$false_branch);
let regions = (region AnyRegion:$true_branch,
AnyRegion:$false_branch);
let results = (outs HLO_TensorOrTuple);
@ -505,6 +506,25 @@ def HLO_IfOp: HLO_Op<"if", []> {
let hasCustomHLOConverter = 1;
}
// Xla Client API has two separate calls for indexed and predicated
// conditional, although both eventually map to kConditional HLO. CaseOp maps
// to indexed conditional use of kConditional HLO; the predicated form is
// modeled by HLO_IfOp above.
def HLO_CaseOp: HLO_Op<"case", []>,
    BASE_HLO_CaseOp {

  let arguments = (ins
    // Scalar tensor selecting which branch region to execute.
    I32Tensor:$index,
    // One operand per branch; branch_operands[b] is fed to branches[b].
    Variadic<HLO_TensorOrTuple>:$branch_operands
  );

  // One region per branch; the verifier requires each entry block to take a
  // single argument matching its branch operand.
  let regions = (region VariadicRegion<AnyRegion>:$branches);

  let results = (outs Variadic<HLO_TensorOrTuple>);

  // Export to HLO needs custom logic (see ExportXlaOp for CaseOp).
  let hasCustomHLOConverter = 1;
}
def HLO_WhileOp: HLO_Op<"while", [SameOperandsAndResultType]> {
string summary = "While operator";

View File

@ -555,6 +555,29 @@ class BASE_HLO_XorOp {
}];
}
//===----------------------------------------------------------------------===//
// XLA control flow related op definitions.
//===----------------------------------------------------------------------===//
// Base class carrying the shared summary/description for the HLO and LHLO
// case ops.
class BASE_HLO_CaseOp {
  string summary = "Switch-Case operator";

  string description = [{
    Returns the result of executing `branches[index]`. If `index` is < 0 or
    >= N, then `branches[N-1]` is executed as the default branch.

    Each branch `branches[b]` must take in a single argument of the same type
    as `branch_operands[b]` and will be invoked with `branch_operands[b]`.
    The type of the returned value of each branch must be the same.

    Note that only one of the branches will be executed depending on the
    value of index.

    See https://www.tensorflow.org/xla/operation_semantics#conditional.
  }];
}
//===----------------------------------------------------------------------===//
// XLA parallelism related op definitions.
//===----------------------------------------------------------------------===//

View File

@ -196,6 +196,19 @@ def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [
let regions = (region SizedRegion<1>:$body);
}
def LHLO_CaseOp: LHLO_Op<"case", [
      SingleBlockImplicitTerminator<"TerminatorOp">
    ]>, BASE_HLO_CaseOp {

  let arguments = (ins
    // Buffer holding the scalar branch index (read-only).
    Arg<LHLO_Buffer, "", [MemRead]>:$index,
    // One input buffer per branch (read-only).
    Arg<Variadic<LHLO_BufferOrTuple>, "", [MemRead]>:$branch_operands,
    // Output buffer written by whichever branch executes.
    Arg<LHLO_BufferOrTuple, "", [MemWrite]>:$out
  );

  // One single-block region per branch, each implicitly terminated by
  // TerminatorOp.
  let regions = (region VariadicRegion<SizedRegion<1>>:$branches);
}
//===----------------------------------------------------------------------===//
// XLA tuple op definitions.
//===----------------------------------------------------------------------===//

View File

@ -636,6 +636,33 @@ LogicalResult ExportXlaOp(IfOp op, OpLoweringContext ctx) {
return success();
}
// Lowers xla_hlo.case to an indexed kConditional HLO. Every branch region is
// lowered to its own XlaComputation; when the op has several results, the
// tuple produced by xla::Conditional is unpacked with GetTupleElement.
LogicalResult ExportXlaOp(CaseOp op, OpLoweringContext ctx) {
  llvm::DenseMap<mlir::Value, xla::XlaOp>& value_map = *ctx.values;
  MutableArrayRef<Region> branches = op.branches();
  OperandRange branch_inputs = op.branch_operands();

  // Lower each branch region and collect the XlaOp fed to it.
  std::vector<xla::XlaComputation> computations(branches.size());
  std::vector<xla::XlaComputation*> computations_p(branches.size());
  llvm::SmallVector<xla::XlaOp, 4> branch_operands(branches.size());
  for (unsigned i = 0; i < branches.size(); ++i) {
    branch_operands[i] = value_map[branch_inputs[i]];
    computations_p[i] = &computations[i];
    if (failed(ctx.converter->LowerRegionAsComputation(&branches[i],
                                                       computations_p[i])))
      return failure();
  }

  xla::XlaOp result =
      xla::Conditional(value_map[op.index()], computations_p, branch_operands);
  if (op.getNumResults() == 1) {
    // Single result: map it directly to the conditional's XlaOp.
    value_map[op.getResult(0)] = result;
  } else {
    // Multiple results: XLA packs them into a tuple; unpack element-wise.
    for (unsigned i = 0; i < op.getNumResults(); ++i) {
      value_map[op.getResult(i)] = xla::GetTupleElement(result, i);
    }
  }
  return success();
}
// ConstOp lowering through this path always fails — presumably constants are
// materialized by a separate mechanism before reaching the custom exporter.
// TODO(review): confirm against the converter's constant handling.
LogicalResult ExportXlaOp(ConstOp op, OpLoweringContext ctx) {
  return failure();
}

View File

@ -178,3 +178,24 @@ func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: m
} ) : () -> ()
return
}
// -----
// CHECK-LABEL: func @case_memref
// Round-trip smoke test: xla_lhlo.case with three single-block branches
// (negate / copy / add), each reading its single block argument and writing
// its result into the shared output buffer %out.
func @case_memref(%index: memref<i32>, %operand_1: memref<f32>, %operand_2: memref<f32>, %operand_3: memref<f32>, %out: memref<f32>) -> () {
  "xla_lhlo.case"(%index, %operand_1, %operand_2, %operand_3, %out) ( {
  ^bb0(%arg0: memref<f32>):
    "xla_lhlo.negate"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
    "xla_lhlo.terminator"() : () -> ()
  }, {
  ^bb0(%arg0: memref<f32>):
    "xla_lhlo.copy"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
    "xla_lhlo.terminator"() : () -> ()
  }, {
  ^bb0(%arg0: memref<f32>):
    "xla_lhlo.add"(%arg0, %arg0, %out) : (memref<f32>, memref<f32>, memref<f32>) -> ()
    "xla_lhlo.terminator"() : () -> ()
  }
  ) : (memref<i32>, memref<f32>, memref<f32>, memref<f32>, memref<f32>) -> ()
  return
}

View File

@ -156,6 +156,98 @@ func @broadcast_in_dim_bad_shape_mismatch(%arg0: tensor<3xi32>) -> tensor<1x2x3x
// -----
// Verifier test: branch 1's entry block takes two arguments, violating the
// single-argument-per-branch requirement.
func @case_mismatch_num_args(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> {
  // expected-error@+1 {{expects branch regions to have single argument, but found 2 for branch 1}}
  %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }, {
  ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
    %1 = "xla_hlo.copy"(%arg0) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }, {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }
  ) : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}
// -----
// Verifier test: branch 1 returns two values while the op declares a single
// result.
func @case_mismatch_num_results(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> {
  // expected-error@+1 {{branch 1 returned values do not match op result types}}
  %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }, {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.copy"(%arg0) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1, %arg0) : (tensor<f32>, tensor<f32>) -> ()
  }, {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }
  ) : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}
// -----
// Verifier test: branch 1's entry argument is tensor<i32> but the matching
// branch operand (operand 2 of the op) is tensor<f32>.
func @case_mismatch_arg_type(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> {
  // expected-error@+1 {{expects operand 2 to be of type 'tensor<i32>', but found 'tensor<f32>'}}
  %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }, {
  ^bb0(%arg0: tensor<i32>):
    %1 = xla_hlo.constant dense<2.0> : tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }, {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }
  ) : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}
// -----
// Verifier test: branch 1 returns tensor<i32> while the op's result type is
// tensor<f32>.
func @case_mismatch_return_type(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> {
  // expected-error@+1 {{branch 1 returned values do not match op result types}}
  %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }, {
  ^bb0(%arg0: tensor<f32>):
    %1 = xla_hlo.constant dense<2> : tensor<i32>
    "xla_hlo.return"(%1) : (tensor<i32>) -> ()
  }, {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }
  ) : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}
// -----
// Verifier test: a case op whose single branch region has no blocks.
func @case_empty_region(%index: tensor<i32>, %operand_1: tensor<f32>) -> () {
  // expected-error@+1 {{cannot have empty regions}}
  "xla_hlo.case"(%index, %operand_1) ( {} ) : (tensor<i32>, tensor<f32>) -> tensor<f32>
  return
}
// -----
// CHECK-LABEL: func @comp_eq
func @comp_eq(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>

View File

@ -0,0 +1,99 @@
// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s
// Export test: a single-result xla_hlo.case lowers to an indexed kConditional
// whose result maps directly to the op's one result (no tuple unpacking).
func @main() -> tensor<f32> {
  %cst = constant {name = "constant"} dense<1> : tensor<i32>
  %cst_0 = constant {name = "constant.1"} dense<5.600000e+01> : tensor<f32>
  %cst_1 = constant {name = "constant.2"} dense<1.200000e+01> : tensor<f32>
  %cst_2 = constant {name = "constant.3"} dense<1.300000e+01> : tensor<f32>
  %0 = "xla_hlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }, {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.copy"(%arg0) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }, {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }) {name = "conditional"} : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}
// CHECK: %[[NEGATE_BRANCH:.*]] ({{.*}}: f32[]) -> f32[] {
// CHECK: %[[ARG:.*]] = f32[] parameter(0)
// CHECK: ROOT %[[RESULT:.*]] = f32[] negate(f32[] %[[ARG]])
// CHECK: }
// CHECK: %[[COPY_BRANCH:.*]] ({{.*}}: f32[]) -> f32[] {
// CHECK: %[[ARG:.*]] = f32[] parameter(0)
// CHECK: ROOT %[[RESULT:.*]] = f32[] copy(f32[] %[[ARG]])
// CHECK: }
// CHECK: %[[FLOOR_BRANCH:.*]] ({{.*}}: f32[]) -> f32[] {
// CHECK: %[[ARG:.*]] = f32[] parameter(0)
// CHECK: ROOT %[[RESULT:.*]] = f32[] floor(f32[] %[[ARG]])
// CHECK: }
// CHECK-LABEL: ENTRY
// CHECK-SAME: () -> f32[]
// CHECK: %[[INDEX:.*]] = s32[] constant(1)
// CHECK: %[[OPERAND_1:.*]] = f32[] constant(56)
// CHECK: %[[OPERAND_2:.*]] = f32[] constant(12)
// CHECK: %[[OPERAND_3:.*]] = f32[] constant(13)
// CHECK: ROOT %[[RESULT:.*]] = f32[] conditional(s32[] %[[INDEX]], f32[] %[[OPERAND_1]], f32[] %[[OPERAND_2]], f32[] %[[OPERAND_3]]), branch_computations={%[[NEGATE_BRANCH]], %[[COPY_BRANCH]], %[[FLOOR_BRANCH]]}
// -----
// Export test: a multi-result xla_hlo.case lowers to a kConditional that
// yields a tuple, which the exporter unpacks with get-tuple-element.
func @main() -> (tensor<f32>, tensor<f32>) {
  %cst = constant {name = "constant"} dense<1> : tensor<i32>
  %cst_0 = constant {name = "constant.1"} dense<5.600000e+01> : tensor<f32>
  %cst_1 = constant {name = "constant.2"} dense<1.200000e+01> : tensor<f32>
  %cst_2 = constant {name = "constant.3"} dense<1.300000e+01> : tensor<f32>
  %0:2 = "xla_hlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.negate"(%arg0) {name = "negate"} : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
  }, {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.copy"(%arg0) {name = "copy"} : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
  }, {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.floor"(%arg0) {name = "floor"} : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
  }) {name = "conditional"} : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
  return %0#0, %0#1 : tensor<f32>, tensor<f32>
}
// CHECK: %[[NEGATE_BRANCH:.*]] ({{.*}}: f32[]) -> (f32[], f32[]) {
// CHECK: %[[ARG:.*]] = f32[] parameter(0)
// CHECK: %[[NEGATE:.*]] = f32[] negate(f32[] %[[ARG]])
// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[NEGATE]], f32[] %[[NEGATE]])
// CHECK: }
// CHECK: %[[COPY_BRANCH:.*]] ({{.*}}: f32[]) -> (f32[], f32[]) {
// CHECK: %[[ARG:.*]] = f32[] parameter(0)
// CHECK: %[[COPY:.*]] = f32[] copy(f32[] %[[ARG]])
// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[COPY]], f32[] %[[COPY]])
// CHECK: }
// CHECK: %[[FLOOR_BRANCH:.*]] ({{.*}}: f32[]) -> (f32[], f32[]) {
// CHECK: %[[ARG:.*]] = f32[] parameter(0)
// CHECK: %[[FLOOR:.*]] = f32[] floor(f32[] %[[ARG]])
// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[FLOOR]], f32[] %[[FLOOR]])
// CHECK: }
// CHECK-LABEL: ENTRY
// CHECK-SAME: () -> (f32[], f32[])
// CHECK: %[[INDEX:.*]] = s32[] constant(1)
// CHECK: %[[OPERAND_1:.*]] = f32[] constant(56)
// CHECK: %[[OPERAND_2:.*]] = f32[] constant(12)
// CHECK: %[[OPERAND_3:.*]] = f32[] constant(13)
// CHECK: %[[TUPLE:.*]] = (f32[], f32[]) conditional(s32[] %[[INDEX]], f32[] %[[OPERAND_1]], f32[] %[[OPERAND_2]], f32[] %[[OPERAND_3]]), branch_computations={%[[NEGATE_BRANCH]], %[[COPY_BRANCH]], %[[FLOOR_BRANCH]]}
// CHECK: %[[RES_1:.*]] = f32[] get-tuple-element((f32[], f32[]) %[[TUPLE]]), index=0
// CHECK: %[[RES_2:.*]] = f32[] get-tuple-element((f32[], f32[]) %[[TUPLE]]), index=1
// CHECK: ROOT %[[RESULT:.*]] = (f32[], f32[]) tuple(f32[] %[[RES_1]], f32[] %[[RES_2]])

View File

@ -0,0 +1,46 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule Indexed_Conditional

// Branch computation 0: negation.
%Negate (x: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  ROOT %negate = f32[] negate(f32[] %x)
}

// Branch computation 1: identity (copy).
%Identity (y: f32[]) -> f32[] {
  %y = f32[] parameter(0)
  ROOT %copy = f32[] copy(f32[] %y)
}

// Branch computation 2: floor.
%Floor (z: f32[]) -> f32[] {
  %z = f32[] parameter(0)
  ROOT %floor = f32[] floor(f32[] %z)
}

// Indexed conditional (s32 index + branch_computations) — the importer maps
// this to xla_hlo.case rather than xla_hlo.if.
ENTRY %indexed_conditional () -> f32[] {
  %constant = s32[] constant(1)
  %constant.1 = f32[] constant(56)
  %constant.2 = f32[] constant(12)
  %constant.3 = f32[] constant(13)
  ROOT %conditional = f32[] conditional(s32[] %constant, f32[] %constant.1, f32[] %constant.2, f32[] %constant.3), branch_computations={%Negate, %Identity, %Floor}
}
// CHECK-LABEL: func @main() -> tensor<f32>
// CHECK: %[[INDEX:.*]] = constant {name = "constant"} dense<1> : tensor<i32>
// CHECK: %[[OPERAND_1:.*]] = constant {name = "{{.*}}"} dense<5.600000e+01> : tensor<f32>
// CHECK: %[[OPERAND_2:.*]] = constant {name = "{{.*}}"} dense<1.200000e+01> : tensor<f32>
// CHECK: %[[OPERAND_3:.*]] = constant {name = "{{.*}}"} dense<1.300000e+01> : tensor<f32>
// CHECK: %[[RESULT:.*]] = "xla_hlo.case"(%[[INDEX]], %[[OPERAND_1]], %[[OPERAND_2]], %[[OPERAND_3]]) ( {
// CHECK: ^bb0(%[[ARG_1:.*]]: tensor<f32>):
// CHECK: %[[RES_1:.*]] = "xla_hlo.negate"(%[[ARG_1]]) {name = "{{.*}}"} : (tensor<f32>) -> tensor<f32>
// CHECK: "xla_hlo.return"(%[[RES_1]]) : (tensor<f32>) -> ()
// CHECK: }, {
// CHECK: ^bb0(%[[ARG_2:.*]]: tensor<f32>):
// CHECK: %[[RES_2:.*]] = "xla_hlo.copy"(%[[ARG_2]]) {name = "{{.*}}"} : (tensor<f32>) -> tensor<f32>
// CHECK: "xla_hlo.return"(%[[RES_2]]) : (tensor<f32>) -> ()
// CHECK: }, {
// CHECK: ^bb0(%[[ARG_3:.*]]: tensor<f32>):
// CHECK: %[[RES_3:.*]] = "xla_hlo.floor"(%[[ARG_3]]) {name = "{{.*}}"} : (tensor<f32>) -> tensor<f32>
// CHECK: "xla_hlo.return"(%[[RES_3]]) : (tensor<f32>) -> ()
// CHECK: }) {name = "{{.*}}"} : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: return %[[RESULT]] : tensor<f32>