From 6e509432c07b79a254a40493585ff964e4df4461 Mon Sep 17 00:00:00 2001 From: Prakalp Srivastava Date: Wed, 20 May 2020 07:49:31 -0700 Subject: [PATCH] Add xla_hlo.case op for indexed conditional HLO. Adds import, export and verifier support for this op. It is exported to indexed conditional HLO. PiperOrigin-RevId: 312480515 Change-Id: I8306e8f7f24b0a304de00547d3022d4fe804deb9 --- .../mlir/xla/hlo_function_importer.cc | 38 +++++-- tensorflow/compiler/mlir/xla/ir/hlo_ops.cc | 41 ++++++++ tensorflow/compiler/mlir/xla/ir/hlo_ops.td | 22 ++++- .../compiler/mlir/xla/ir/hlo_ops_base.td | 23 +++++ tensorflow/compiler/mlir/xla/ir/lhlo_ops.td | 13 +++ .../compiler/mlir/xla/mlir_hlo_to_hlo.cc | 27 +++++ .../compiler/mlir/xla/tests/lhlo_ops.mlir | 21 ++++ tensorflow/compiler/mlir/xla/tests/ops.mlir | 92 +++++++++++++++++ .../mlir/xla/tests/translate/case.mlir | 99 +++++++++++++++++++ .../tests/translate/case_conditional.hlotxt | 46 +++++++++ 10 files changed, 413 insertions(+), 9 deletions(-) create mode 100644 tensorflow/compiler/mlir/xla/tests/translate/case.mlir create mode 100644 tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index 718db1597cf..22a0b038833 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -420,15 +420,37 @@ StatusOr HloFunctionImporter::ImportInstruction( } case HloOpcode::kConditional: { llvm::SmallVector rets; - TF_RETURN_IF_ERROR(GetMlirTypes( - {instruction->true_computation()->root_instruction()}, &rets)); + mlir::Type pred_or_index_type = + operands[0].getType().cast().getElementType(); + // It is a predicated conditional if first argument is a boolean and + // should be mapped to If op. + if (pred_or_index_type.isInteger(1)) { + TF_RETURN_IF_ERROR(GetMlirTypes( + {instruction->true_computation()->root_instruction()}, &rets)); - auto op = func_builder->create(loc, rets, operands, - attributes); - TF_RETURN_IF_ERROR(ImportComputation(instruction->true_computation(), - &op.true_branch())); - TF_RETURN_IF_ERROR(ImportComputation(instruction->false_computation(), - &op.false_branch())); + auto op = func_builder->create(loc, rets, operands, + attributes); + TF_RETURN_IF_ERROR(ImportComputation(instruction->true_computation(), + &op.true_branch())); + TF_RETURN_IF_ERROR(ImportComputation(instruction->false_computation(), + &op.false_branch())); + return op.getOperation(); + } + + // Otherwise, it is a indexed conditional and should be mapped to Case op. + TF_RETURN_IF_ERROR(GetMlirTypes( + {instruction->branch_computation(0)->root_instruction()}, &rets)); + + int num_branches = instruction->branch_count(); + auto op = func_builder->create( + loc, rets, operands, attributes, num_branches); + for (auto index_and_computation : + llvm::enumerate(instruction->branch_computations())) { + auto index = index_and_computation.index(); + HloComputation* computation = index_and_computation.value(); + TF_RETURN_IF_ERROR( + ImportComputation(computation, &op.branches()[index])); + } return op.getOperation(); } case HloOpcode::kConcatenate: { diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc index 03928467cff..d20f1713eba 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc @@ -1396,6 +1396,47 @@ OpFoldResult ReshapeOp::fold(ArrayRef operands) { return {}; } +//===----------------------------------------------------------------------===// +// Case Op +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(CaseOp op) { + auto num_branches = op.branches().size(); + if (op.branch_operands().size() != num_branches) + return op.emitOpError() << "expects number of branches " << num_branches + << " to be same as number of branch operands " + << op.branch_operands().size(); + + MutableArrayRef branches = op.branches(); + OperandRange branch_operands = op.branch_operands(); + for (unsigned i = 0; i < num_branches; ++i) { + mlir::Region& branch_region = branches[i]; + if (branch_region.empty()) + return op.emitOpError() << "cannot have empty regions"; + mlir::Block& entry_block = branch_region.front(); + if (entry_block.getNumArguments() != 1) + return op.emitOpError() + << "expects branch regions to have single argument, but found " + << entry_block.getNumArguments() << " for branch " << i; + auto operand = branch_operands[i]; + if (entry_block.getArgument(0).getType() != operand.getType()) + return op.emitOpError() + << "expects operand " << i + 1 << " to be of type " + << entry_block.getArgument(0).getType() << ", but found " + << operand.getType(); + WalkResult walker = branch_region.walk([&](ReturnOp return_op) { + if (return_op.getOperands().getTypes() != op.getResultTypes()) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + if (walker.wasInterrupted()) + return op.emitOpError() + << "branch " << i + << " returned values do not match op result types"; + } + return success(); +} + //===----------------------------------------------------------------------===// // BinaryOps //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td index 99801f1618e..093e79a8613 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td @@ -497,7 +497,8 @@ def HLO_IfOp: HLO_Op<"if", []> { HLO_TensorOrTuple:$false_arg ); - let regions = (region AnyRegion:$true_branch, AnyRegion:$false_branch); + let regions = (region AnyRegion:$true_branch, + AnyRegion:$false_branch); let results = (outs HLO_TensorOrTuple); @@ -505,6 +506,25 @@ def HLO_IfOp: HLO_Op<"if", []> { let hasCustomHLOConverter = 1; } +// Xla Client API has two separate calls for indexed and predicated conditional, +// although both eventually map to kConditional HLO. CaseOp maps to indexed +// conditional use of kConditional HLO. +def HLO_CaseOp: HLO_Op<"case", []>, + BASE_HLO_CaseOp { + + let arguments = (ins + I32Tensor:$index, + Variadic:$branch_operands + ); + + let regions = (region VariadicRegion:$branches); + + let results = (outs Variadic); + + let hasCustomHLOConverter = 1; +} + + def HLO_WhileOp: HLO_Op<"while", [SameOperandsAndResultType]> { string summary = "While operator"; diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td index b5130eafd0e..bad1bf16ec3 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td @@ -555,6 +555,29 @@ class BASE_HLO_XorOp { }]; } +//===----------------------------------------------------------------------===// +// XLA control flow related op definitions. +//===----------------------------------------------------------------------===// + +class BASE_HLO_CaseOp { + string summary = "Switch-Case operator"; + + string description = [{ + Returns the result of executing `branches[index]`. If + `index` is < 0 or >= N, then `branches[N-1] is executed as + the default branch. + + Each branch `branches[b]` must take in a single argument of same type as + `branch_operands[b]` and will be invoked with `branch_operands[b]`. The type + of the returned value of each branch must be the same. + + Note that only one of the branches will be executed depending on the value + of index. + See https://www.tensorflow.org/xla/operation_semantics#conditional. + }]; + +} + //===----------------------------------------------------------------------===// // XLA parallelism related op definitions. //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td index db75bbd1f67..020859aa0bf 100644 --- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td @@ -196,6 +196,19 @@ def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [ let regions = (region SizedRegion<1>:$body); } +def LHLO_CaseOp: LHLO_Op<"case", [ + SingleBlockImplicitTerminator<"TerminatorOp"> + ]>, BASE_HLO_CaseOp { + + let arguments = (ins + Arg:$index, + Arg, "", [MemRead]>:$branch_operands, + Arg:$out + ); + + let regions = (region VariadicRegion>:$branches); +} + //===----------------------------------------------------------------------===// // XLA tuple op definitions. //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index 9e30d830602..8150d719f3e 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -636,6 +636,33 @@ LogicalResult ExportXlaOp(IfOp op, OpLoweringContext ctx) { return success(); } +LogicalResult ExportXlaOp(CaseOp op, OpLoweringContext ctx) { + llvm::DenseMap& value_map = *ctx.values; + OperandRange operands = op.branch_operands(); + MutableArrayRef branches = op.branches(); + llvm::SmallVector branch_operands(branches.size()); + std::vector computations(branches.size()); + std::vector computations_p(branches.size()); + + for (unsigned i = 0; i < branches.size(); ++i) { + branch_operands[i] = value_map[operands[i]]; + computations_p[i] = &computations[i]; + if (failed(ctx.converter->LowerRegionAsComputation(&branches[i], + computations_p[i]))) + return failure(); + } + xla::XlaOp result = + xla::Conditional(value_map[op.index()], computations_p, branch_operands); + if (op.getNumResults() == 1) { + value_map[op.getResult(0)] = result; + } else { + for (auto item : llvm::enumerate(op.getResults())) { + value_map[item.value()] = xla::GetTupleElement(result, item.index()); + } + } + return success(); +} + LogicalResult ExportXlaOp(ConstOp op, OpLoweringContext ctx) { return failure(); } diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir index 23e9d9b68e0..d4d775731c8 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir @@ -178,3 +178,24 @@ func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: m } ) : () -> () return } + +// ----- + +// CHECK-LABEL: func @case_memref +func @case_memref(%index: memref, %operand_1: memref, %operand_2: memref, %operand_3: memref, %out: memref) -> () { + "xla_lhlo.case"(%index, %operand_1, %operand_2, %operand_3, %out) ( { + ^bb0(%arg0: memref): + "xla_lhlo.negate"(%arg0, %out) : (memref, memref) -> () + "xla_lhlo.terminator"() : () -> () + }, { + ^bb0(%arg0: memref): + "xla_lhlo.copy"(%arg0, %out) : (memref, memref) -> () + "xla_lhlo.terminator"() : () -> () + }, { + ^bb0(%arg0: memref): + "xla_lhlo.add"(%arg0, %arg0, %out) : (memref, memref, memref) -> () + "xla_lhlo.terminator"() : () -> () + } + ) : (memref, memref, memref, memref, memref) -> () + return +} diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir b/tensorflow/compiler/mlir/xla/tests/ops.mlir index f09ec62c8dc..e6ae074f922 100644 --- a/tensorflow/compiler/mlir/xla/tests/ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/ops.mlir @@ -156,6 +156,98 @@ func @broadcast_in_dim_bad_shape_mismatch(%arg0: tensor<3xi32>) -> tensor<1x2x3x // ----- +func @case_mismatch_num_args(%index: tensor, %operand_1: tensor, %operand_2: tensor, %operand_3: tensor) -> tensor { + // expected-error@+1 {{expects branch regions to have single argument, but found 2 for branch 1}} + %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.negate"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg0: tensor, %arg1: tensor): + %1 = "xla_hlo.copy"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.floor"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + } + ) : (tensor, tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + +func @case_mismatch_num_results(%index: tensor, %operand_1: tensor, %operand_2: tensor, %operand_3: tensor) -> tensor { + // expected-error@+1 {{branch 1 returned values do not match op result types}} + %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.negate"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.copy"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1, %arg0) : (tensor, tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.floor"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + } + ) : (tensor, tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + +func @case_mismatch_arg_type(%index: tensor, %operand_1: tensor, %operand_2: tensor, %operand_3: tensor) -> tensor { + // expected-error@+1 {{expects operand 2 to be of type 'tensor', but found 'tensor'}} + %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.negate"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = xla_hlo.constant dense<2.0> : tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.floor"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + } + ) : (tensor, tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + +func @case_mismatch_return_type(%index: tensor, %operand_1: tensor, %operand_2: tensor, %operand_3: tensor) -> tensor { + // expected-error@+1 {{branch 1 returned values do not match op result types}} + %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.negate"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = xla_hlo.constant dense<2> : tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.floor"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + } + ) : (tensor, tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + +func @case_empty_region(%index: tensor, %operand_1: tensor) -> () { + // expected-error@+1 {{cannot have empty regions}} + "xla_hlo.case"(%index, %operand_1) ( {} ) : (tensor, tensor) -> tensor + return +} + +// ----- + // CHECK-LABEL: func @comp_eq func @comp_eq(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi1> { %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> diff --git a/tensorflow/compiler/mlir/xla/tests/translate/case.mlir b/tensorflow/compiler/mlir/xla/tests/translate/case.mlir new file mode 100644 index 00000000000..dba9e8b61ca --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/case.mlir @@ -0,0 +1,99 @@ +// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s + +func @main() -> tensor { + %cst = constant {name = "constant"} dense<1> : tensor + %cst_0 = constant {name = "constant.1"} dense<5.600000e+01> : tensor + %cst_1 = constant {name = "constant.2"} dense<1.200000e+01> : tensor + %cst_2 = constant {name = "constant.3"} dense<1.300000e+01> : tensor + %0 = "xla_hlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.negate"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.copy"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.floor"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }) {name = "conditional"} : (tensor, tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK: %[[NEGATE_BRANCH:.*]] ({{.*}}: f32[]) -> f32[] { +// CHECK: %[[ARG:.*]] = f32[] parameter(0) +// CHECK: ROOT %[[RESULT:.*]] = f32[] negate(f32[] %[[ARG]]) +// CHECK: } + +// CHECK: %[[COPY_BRANCH:.*]] ({{.*}}: f32[]) -> f32[] { +// CHECK: %[[ARG:.*]] = f32[] parameter(0) +// CHECK: ROOT %[[RESULT:.*]] = f32[] copy(f32[] %[[ARG]]) +// CHECK: } + +// CHECK: %[[FLOOR_BRANCH:.*]] ({{.*}}: f32[]) -> f32[] { +// CHECK: %[[ARG:.*]] = f32[] parameter(0) +// CHECK: ROOT %[[RESULT:.*]] = f32[] floor(f32[] %[[ARG]]) +// CHECK: } + +// CHECK-LABEL: ENTRY +// CHECK-SAME: () -> f32[] + +// CHECK: %[[INDEX:.*]] = s32[] constant(1) +// CHECK: %[[OPERAND_1:.*]] = f32[] constant(56) +// CHECK: %[[OPERAND_2:.*]] = f32[] constant(12) +// CHECK: %[[OPERAND_3:.*]] = f32[] constant(13) +// CHECK: ROOT %[[RESULT:.*]] = f32[] conditional(s32[] %[[INDEX]], f32[] %[[OPERAND_1]], f32[] %[[OPERAND_2]], f32[] %[[OPERAND_3]]), branch_computations={%[[NEGATE_BRANCH]], %[[COPY_BRANCH]], %[[FLOOR_BRANCH]]} + +// ----- + +func @main() -> (tensor, tensor) { + %cst = constant {name = "constant"} dense<1> : tensor + %cst_0 = constant {name = "constant.1"} dense<5.600000e+01> : tensor + %cst_1 = constant {name = "constant.2"} dense<1.200000e+01> : tensor + %cst_2 = constant {name = "constant.3"} dense<1.300000e+01> : tensor + %0:2 = "xla_hlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.negate"(%arg0) {name = "negate"} : (tensor) -> tensor + "xla_hlo.return"(%1, %1) : (tensor, tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.copy"(%arg0) {name = "copy"} : (tensor) -> tensor + "xla_hlo.return"(%1, %1) : (tensor, tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.floor"(%arg0) {name = "floor"} : (tensor) -> tensor + "xla_hlo.return"(%1, %1) : (tensor, tensor) -> () + }) {name = "conditional"} : (tensor, tensor, tensor, tensor) -> (tensor, tensor) + return %0#0, %0#1 : tensor, tensor +} + +// CHECK: %[[NEGATE_BRANCH:.*]] ({{.*}}: f32[]) -> (f32[], f32[]) { +// CHECK: %[[ARG:.*]] = f32[] parameter(0) +// CHECK: %[[NEGATE:.*]] = f32[] negate(f32[] %[[ARG]]) +// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[NEGATE]], f32[] %[[NEGATE]]) +// CHECK: } + +// CHECK: %[[COPY_BRANCH:.*]] ({{.*}}: f32[]) -> (f32[], f32[]) { +// CHECK: %[[ARG:.*]] = f32[] parameter(0) +// CHECK: %[[COPY:.*]] = f32[] copy(f32[] %[[ARG]]) +// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[COPY]], f32[] %[[COPY]]) +// CHECK: } + +// CHECK: %[[FLOOR_BRANCH:.*]] ({{.*}}: f32[]) -> (f32[], f32[]) { +// CHECK: %[[ARG:.*]] = f32[] parameter(0) +// CHECK: %[[FLOOR:.*]] = f32[] floor(f32[] %[[ARG]]) +// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[FLOOR]], f32[] %[[FLOOR]]) +// CHECK: } + +// CHECK-LABEL: ENTRY +// CHECK-SAME: () -> (f32[], f32[]) + +// CHECK: %[[INDEX:.*]] = s32[] constant(1) +// CHECK: %[[OPERAND_1:.*]] = f32[] constant(56) +// CHECK: %[[OPERAND_2:.*]] = f32[] constant(12) +// CHECK: %[[OPERAND_3:.*]] = f32[] constant(13) +// CHECK: %[[TUPLE:.*]] = (f32[], f32[]) conditional(s32[] %[[INDEX]], f32[] %[[OPERAND_1]], f32[] %[[OPERAND_2]], f32[] %[[OPERAND_3]]), branch_computations={%[[NEGATE_BRANCH]], %[[COPY_BRANCH]], %[[FLOOR_BRANCH]]} +// CHECK: %[[RES_1:.*]] = f32[] get-tuple-element((f32[], f32[]) %[[TUPLE]]), index=0 +// CHECK: %[[RES_2:.*]] = f32[] get-tuple-element((f32[], f32[]) %[[TUPLE]]), index=1 +// CHECK: ROOT %[[RESULT:.*]] = (f32[], f32[]) tuple(f32[] %[[RES_1]], f32[] %[[RES_2]]) diff --git a/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt new file mode 100644 index 00000000000..2ff223cd480 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt @@ -0,0 +1,46 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule Indexed_Conditional + +%Negate (x: f32[]) -> f32[] { + %x = f32[] parameter(0) + ROOT %negate = f32[] negate(f32[] %x) +} + +%Identity (y: f32[]) -> f32[] { + %y = f32[] parameter(0) + ROOT %copy = f32[] copy(f32[] %y) +} + +%Floor (z: f32[]) -> f32[] { + %z = f32[] parameter(0) + ROOT %floor = f32[] floor(f32[] %z) +} + +ENTRY %indexed_conditional () -> f32[] { + %constant = s32[] constant(1) + %constant.1 = f32[] constant(56) + %constant.2 = f32[] constant(12) + %constant.3 = f32[] constant(13) + ROOT %conditional = f32[] conditional(s32[] %constant, f32[] %constant.1, f32[] %constant.2, f32[] %constant.3), branch_computations={%Negate, %Identity, %Floor} +} + +// CHECK-LABEL: func @main() -> tensor +// CHECK: %[[INDEX:.*]] = constant {name = "constant"} dense<1> : tensor +// CHECK: %[[OPERAND_1:.*]] = constant {name = "{{.*}}"} dense<5.600000e+01> : tensor +// CHECK: %[[OPERAND_2:.*]] = constant {name = "{{.*}}"} dense<1.200000e+01> : tensor +// CHECK: %[[OPERAND_3:.*]] = constant {name = "{{.*}}"} dense<1.300000e+01> : tensor +// CHECK: %[[RESULT:.*]] = "xla_hlo.case"(%[[INDEX]], %[[OPERAND_1]], %[[OPERAND_2]], %[[OPERAND_3]]) ( { +// CHECK: ^bb0(%[[ARG_1:.*]]: tensor): +// CHECK: %[[RES_1:.*]] = "xla_hlo.negate"(%[[ARG_1]]) {name = "{{.*}}"} : (tensor) -> tensor +// CHECK: "xla_hlo.return"(%[[RES_1]]) : (tensor) -> () +// CHECK: }, { +// CHECK: ^bb0(%[[ARG_2:.*]]: tensor): +// CHECK: %[[RES_2:.*]] = "xla_hlo.copy"(%[[ARG_2]]) {name = "{{.*}}"} : (tensor) -> tensor +// CHECK: "xla_hlo.return"(%[[RES_2]]) : (tensor) -> () +// CHECK: }, { +// CHECK: ^bb0(%[[ARG_3:.*]]: tensor): +// CHECK: %[[RES_3:.*]] = "xla_hlo.floor"(%[[ARG_3]]) {name = "{{.*}}"} : (tensor) -> tensor +// CHECK: "xla_hlo.return"(%[[RES_3]]) : (tensor) -> () +// CHECK: }) {name = "{{.*}}"} : (tensor, tensor, tensor, tensor) -> tensor +// CHECK: return %[[RESULT]] : tensor