Add xla_hlo.case op for indexed conditional HLO.

Adds import, export, and verifier support for this op. It is exported to the indexed form of the kConditional HLO.

PiperOrigin-RevId: 312480515
Change-Id: I8306e8f7f24b0a304de00547d3022d4fe804deb9
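For orientation, a minimal sketch of the new op in MLIR textual form, condensed from the tests added in this change (operand names are illustrative, not taken from the diff):

    %result = "xla_hlo.case"(%index, %branch0_arg, %branch1_arg) ( {
    ^bb0(%arg0: tensor<f32>):
      %0 = "xla_hlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
      "xla_hlo.return"(%0) : (tensor<f32>) -> ()
    }, {
    ^bb0(%arg0: tensor<f32>):
      %0 = "xla_hlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
      "xla_hlo.return"(%0) : (tensor<f32>) -> ()
    }) : (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>

Each region is one branch; `%index` selects which branch runs on its corresponding operand.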
@@ -420,6 +420,11 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
    }
    case HloOpcode::kConditional: {
      llvm::SmallVector<Type, 4> rets;
      mlir::Type pred_or_index_type =
          operands[0].getType().cast<mlir::TensorType>().getElementType();
      // It is a predicated conditional if the first argument is a boolean and
      // should be mapped to an If op.
      if (pred_or_index_type.isInteger(1)) {
        TF_RETURN_IF_ERROR(GetMlirTypes(
            {instruction->true_computation()->root_instruction()}, &rets));

@@ -431,6 +436,23 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
            &op.false_branch()));
        return op.getOperation();
      }

      // Otherwise, it is an indexed conditional and should be mapped to a
      // Case op.
      TF_RETURN_IF_ERROR(GetMlirTypes(
          {instruction->branch_computation(0)->root_instruction()}, &rets));

      int num_branches = instruction->branch_count();
      auto op = func_builder->create<mlir::xla_hlo::CaseOp>(
          loc, rets, operands, attributes, num_branches);
      for (auto index_and_computation :
           llvm::enumerate(instruction->branch_computations())) {
        auto index = index_and_computation.index();
        HloComputation* computation = index_and_computation.value();
        TF_RETURN_IF_ERROR(
            ImportComputation(computation, &op.branches()[index]));
      }
      return op.getOperation();
    }
    case HloOpcode::kConcatenate: {
      // TODO(b/132057942): Support taking a uint64_t instead of an
      // IntegerAttr for concatenate dimension.
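For context on the dispatch above: HLO has a single kConditional opcode, and the importer distinguishes the two MLIR ops purely by the element type of the first operand (i1 predicate vs. i32 branch index). In HLO text the two forms look roughly as follows; the predicated line is an approximation shown only for contrast and is not part of this diff:

    // Predicated form, imported as xla_hlo.if (approximate syntax):
    %r = f32[] conditional(pred[] %p, f32[] %t, f32[] %f), true_computation=%T, false_computation=%F
    // Indexed form, imported as xla_hlo.case (as in the translate test below):
    %r = f32[] conditional(s32[] %i, f32[] %a, f32[] %b), branch_computations={%B0, %B1}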
@@ -1396,6 +1396,47 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
  return {};
}

//===----------------------------------------------------------------------===//
// Case Op
//===----------------------------------------------------------------------===//

static LogicalResult Verify(CaseOp op) {
  auto num_branches = op.branches().size();
  if (op.branch_operands().size() != num_branches)
    return op.emitOpError() << "expects number of branches " << num_branches
                            << " to be same as number of branch operands "
                            << op.branch_operands().size();

  MutableArrayRef<Region> branches = op.branches();
  OperandRange branch_operands = op.branch_operands();
  for (unsigned i = 0; i < num_branches; ++i) {
    mlir::Region& branch_region = branches[i];
    if (branch_region.empty())
      return op.emitOpError() << "cannot have empty regions";
    mlir::Block& entry_block = branch_region.front();
    if (entry_block.getNumArguments() != 1)
      return op.emitOpError()
             << "expects branch regions to have single argument, but found "
             << entry_block.getNumArguments() << " for branch " << i;
    auto operand = branch_operands[i];
    if (entry_block.getArgument(0).getType() != operand.getType())
      return op.emitOpError()
             << "expects operand " << i + 1 << " to be of type "
             << entry_block.getArgument(0).getType() << ", but found "
             << operand.getType();
    WalkResult walker = branch_region.walk([&](ReturnOp return_op) {
      if (return_op.getOperands().getTypes() != op.getResultTypes())
        return WalkResult::interrupt();
      return WalkResult::advance();
    });
    if (walker.wasInterrupted())
      return op.emitOpError()
             << "branch " << i
             << " returned values do not match op result types";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// BinaryOps
//===----------------------------------------------------------------------===//
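One note on the branch-count check at the top of the verifier: the negative tests added below do not exercise it. Per the diagnostic text above, a hypothetical op with two regions but three branch operands would be rejected along these lines (a sketch, not a test from this change):

    // expected-error {{expects number of branches 2 to be same as number of branch operands 3}}
    "xla_hlo.case"(%index, %a, %b, %c) ( { ... }, { ... }
    ) : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>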
@@ -497,7 +497,8 @@ def HLO_IfOp: HLO_Op<"if", []> {
    HLO_TensorOrTuple:$false_arg
  );

-  let regions = (region AnyRegion:$true_branch, AnyRegion:$false_branch);
+  let regions = (region AnyRegion:$true_branch,
+                        AnyRegion:$false_branch);

  let results = (outs HLO_TensorOrTuple);
@@ -505,6 +506,25 @@ def HLO_IfOp: HLO_Op<"if", []> {
  let hasCustomHLOConverter = 1;
}

// The XLA Client API has two separate calls for indexed and predicated
// conditionals, although both eventually map to kConditional HLO. CaseOp maps
// to the indexed conditional use of kConditional HLO.
def HLO_CaseOp: HLO_Op<"case", []>,
    BASE_HLO_CaseOp {

  let arguments = (ins
    I32Tensor:$index,
    Variadic<HLO_TensorOrTuple>:$branch_operands
  );

  let regions = (region VariadicRegion<AnyRegion>:$branches);

  let results = (outs Variadic<HLO_TensorOrTuple>);

  let hasCustomHLOConverter = 1;
}

def HLO_WhileOp: HLO_Op<"while", [SameOperandsAndResultType]> {
  string summary = "While operator";
@@ -555,6 +555,29 @@ class BASE_HLO_XorOp {
  }];
}

//===----------------------------------------------------------------------===//
// XLA control flow related op definitions.
//===----------------------------------------------------------------------===//

class BASE_HLO_CaseOp {
  string summary = "Switch-Case operator";

  string description = [{
    Returns the result of executing `branches[index]`. If `index` is < 0 or
    >= N, then `branches[N-1]` is executed as the default branch.

    Each branch `branches[b]` must take in a single argument of the same type
    as `branch_operands[b]` and will be invoked with `branch_operands[b]`. The
    type of the returned value of each branch must be the same.

    Note that only one of the branches will be executed depending on the value
    of `index`.
    See https://www.tensorflow.org/xla/operation_semantics#conditional.
  }];
}

//===----------------------------------------------------------------------===//
// XLA parallelism related op definitions.
//===----------------------------------------------------------------------===//
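A quick worked reading of the out-of-range rule above, assuming N = 3 branches:

    index = 1   ->  branches[1] runs
    index = 5   ->  branches[2] runs (5 >= N, so the default branch branches[N-1])
    index = -1  ->  branches[2] runs (index < 0, same default)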
@@ -196,6 +196,19 @@ def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [
  let regions = (region SizedRegion<1>:$body);
}

def LHLO_CaseOp: LHLO_Op<"case", [
    SingleBlockImplicitTerminator<"TerminatorOp">
  ]>, BASE_HLO_CaseOp {

  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$index,
    Arg<Variadic<LHLO_BufferOrTuple>, "", [MemRead]>:$branch_operands,
    Arg<LHLO_BufferOrTuple, "", [MemWrite]>:$out
  );

  let regions = (region VariadicRegion<SizedRegion<1>>:$branches);
}

//===----------------------------------------------------------------------===//
// XLA tuple op definitions.
//===----------------------------------------------------------------------===//
@@ -636,6 +636,33 @@ LogicalResult ExportXlaOp(IfOp op, OpLoweringContext ctx) {
  return success();
}

LogicalResult ExportXlaOp(CaseOp op, OpLoweringContext ctx) {
  llvm::DenseMap<mlir::Value, xla::XlaOp>& value_map = *ctx.values;
  OperandRange operands = op.branch_operands();
  MutableArrayRef<Region> branches = op.branches();
  llvm::SmallVector<xla::XlaOp, 4> branch_operands(branches.size());
  std::vector<xla::XlaComputation> computations(branches.size());
  std::vector<xla::XlaComputation*> computations_p(branches.size());

  for (unsigned i = 0; i < branches.size(); ++i) {
    branch_operands[i] = value_map[operands[i]];
    computations_p[i] = &computations[i];
    if (failed(ctx.converter->LowerRegionAsComputation(&branches[i],
                                                       computations_p[i])))
      return failure();
  }
  xla::XlaOp result =
      xla::Conditional(value_map[op.index()], computations_p, branch_operands);
  if (op.getNumResults() == 1) {
    value_map[op.getResult(0)] = result;
  } else {
    for (auto item : llvm::enumerate(op.getResults())) {
      value_map[item.value()] = xla::GetTupleElement(result, item.index());
    }
  }
  return success();
}

LogicalResult ExportXlaOp(ConstOp op, OpLoweringContext ctx) {
  return failure();
}
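For reference, the xla::Conditional overload targeted above is the XLA client builder's indexed form; its declaration is approximately the following (paraphrased from xla_builder.h, not part of this diff), which also explains why the branch computations are collected as a vector of pointers before the call:

    // Approximate declaration from the XLA client builder:
    XlaOp Conditional(XlaOp branch_index,
                      absl::Span<const XlaComputation* const> branch_computations,
                      absl::Span<const XlaOp> branch_operands);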
@@ -178,3 +178,24 @@ func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: m
  } ) : () -> ()
  return
}

// -----

// CHECK-LABEL: func @case_memref
func @case_memref(%index: memref<i32>, %operand_1: memref<f32>, %operand_2: memref<f32>, %operand_3: memref<f32>, %out: memref<f32>) -> () {
  "xla_lhlo.case"(%index, %operand_1, %operand_2, %operand_3, %out) ( {
  ^bb0(%arg0: memref<f32>):
    "xla_lhlo.negate"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
    "xla_lhlo.terminator"() : () -> ()
  }, {
  ^bb0(%arg0: memref<f32>):
    "xla_lhlo.copy"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
    "xla_lhlo.terminator"() : () -> ()
  }, {
  ^bb0(%arg0: memref<f32>):
    "xla_lhlo.add"(%arg0, %arg0, %out) : (memref<f32>, memref<f32>, memref<f32>) -> ()
    "xla_lhlo.terminator"() : () -> ()
  }
  ) : (memref<i32>, memref<f32>, memref<f32>, memref<f32>, memref<f32>) -> ()
  return
}
@@ -156,6 +156,98 @@ func @broadcast_in_dim_bad_shape_mismatch(%arg0: tensor<3xi32>) -> tensor<1x2x3x

// -----

func @case_mismatch_num_args(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> {
  // expected-error@+1 {{expects branch regions to have single argument, but found 2 for branch 1}}
  %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }, {
  ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
    %1 = "xla_hlo.copy"(%arg0) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }, {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }
  ) : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}

// -----

func @case_mismatch_num_results(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> {
  // expected-error@+1 {{branch 1 returned values do not match op result types}}
  %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }, {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.copy"(%arg0) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1, %arg0) : (tensor<f32>, tensor<f32>) -> ()
  }, {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }
  ) : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}

// -----

func @case_mismatch_arg_type(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> {
  // expected-error@+1 {{expects operand 2 to be of type 'tensor<i32>', but found 'tensor<f32>'}}
  %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }, {
  ^bb0(%arg0: tensor<i32>):
    %1 = xla_hlo.constant dense<2.0> : tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }, {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }
  ) : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}

// -----

func @case_mismatch_return_type(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> {
  // expected-error@+1 {{branch 1 returned values do not match op result types}}
  %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }, {
  ^bb0(%arg0: tensor<f32>):
    %1 = xla_hlo.constant dense<2> : tensor<i32>
    "xla_hlo.return"(%1) : (tensor<i32>) -> ()
  }, {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }
  ) : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}

// -----

func @case_empty_region(%index: tensor<i32>, %operand_1: tensor<f32>) -> () {
  // expected-error@+1 {{cannot have empty regions}}
  "xla_hlo.case"(%index, %operand_1) ( {} ) : (tensor<i32>, tensor<f32>) -> tensor<f32>
  return
}

// -----

// CHECK-LABEL: func @comp_eq
func @comp_eq(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi1> {
  %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
@@ -0,0 +1,99 @@
// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s

func @main() -> tensor<f32> {
  %cst = constant {name = "constant"} dense<1> : tensor<i32>
  %cst_0 = constant {name = "constant.1"} dense<5.600000e+01> : tensor<f32>
  %cst_1 = constant {name = "constant.2"} dense<1.200000e+01> : tensor<f32>
  %cst_2 = constant {name = "constant.3"} dense<1.300000e+01> : tensor<f32>
  %0 = "xla_hlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }, {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.copy"(%arg0) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }, {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }) {name = "conditional"} : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}

// CHECK: %[[NEGATE_BRANCH:.*]] ({{.*}}: f32[]) -> f32[] {
// CHECK: %[[ARG:.*]] = f32[] parameter(0)
// CHECK: ROOT %[[RESULT:.*]] = f32[] negate(f32[] %[[ARG]])
// CHECK: }

// CHECK: %[[COPY_BRANCH:.*]] ({{.*}}: f32[]) -> f32[] {
// CHECK: %[[ARG:.*]] = f32[] parameter(0)
// CHECK: ROOT %[[RESULT:.*]] = f32[] copy(f32[] %[[ARG]])
// CHECK: }

// CHECK: %[[FLOOR_BRANCH:.*]] ({{.*}}: f32[]) -> f32[] {
// CHECK: %[[ARG:.*]] = f32[] parameter(0)
// CHECK: ROOT %[[RESULT:.*]] = f32[] floor(f32[] %[[ARG]])
// CHECK: }

// CHECK-LABEL: ENTRY
// CHECK-SAME: () -> f32[]

// CHECK: %[[INDEX:.*]] = s32[] constant(1)
// CHECK: %[[OPERAND_1:.*]] = f32[] constant(56)
// CHECK: %[[OPERAND_2:.*]] = f32[] constant(12)
// CHECK: %[[OPERAND_3:.*]] = f32[] constant(13)
// CHECK: ROOT %[[RESULT:.*]] = f32[] conditional(s32[] %[[INDEX]], f32[] %[[OPERAND_1]], f32[] %[[OPERAND_2]], f32[] %[[OPERAND_3]]), branch_computations={%[[NEGATE_BRANCH]], %[[COPY_BRANCH]], %[[FLOOR_BRANCH]]}

// -----

func @main() -> (tensor<f32>, tensor<f32>) {
  %cst = constant {name = "constant"} dense<1> : tensor<i32>
  %cst_0 = constant {name = "constant.1"} dense<5.600000e+01> : tensor<f32>
  %cst_1 = constant {name = "constant.2"} dense<1.200000e+01> : tensor<f32>
  %cst_2 = constant {name = "constant.3"} dense<1.300000e+01> : tensor<f32>
  %0:2 = "xla_hlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.negate"(%arg0) {name = "negate"} : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
  }, {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.copy"(%arg0) {name = "copy"} : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
  }, {
  ^bb0(%arg0: tensor<f32>):
    %1 = "xla_hlo.floor"(%arg0) {name = "floor"} : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
  }) {name = "conditional"} : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
  return %0#0, %0#1 : tensor<f32>, tensor<f32>
}

// CHECK: %[[NEGATE_BRANCH:.*]] ({{.*}}: f32[]) -> (f32[], f32[]) {
// CHECK: %[[ARG:.*]] = f32[] parameter(0)
// CHECK: %[[NEGATE:.*]] = f32[] negate(f32[] %[[ARG]])
// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[NEGATE]], f32[] %[[NEGATE]])
// CHECK: }

// CHECK: %[[COPY_BRANCH:.*]] ({{.*}}: f32[]) -> (f32[], f32[]) {
// CHECK: %[[ARG:.*]] = f32[] parameter(0)
// CHECK: %[[COPY:.*]] = f32[] copy(f32[] %[[ARG]])
// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[COPY]], f32[] %[[COPY]])
// CHECK: }

// CHECK: %[[FLOOR_BRANCH:.*]] ({{.*}}: f32[]) -> (f32[], f32[]) {
// CHECK: %[[ARG:.*]] = f32[] parameter(0)
// CHECK: %[[FLOOR:.*]] = f32[] floor(f32[] %[[ARG]])
// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[FLOOR]], f32[] %[[FLOOR]])
// CHECK: }

// CHECK-LABEL: ENTRY
// CHECK-SAME: () -> (f32[], f32[])

// CHECK: %[[INDEX:.*]] = s32[] constant(1)
// CHECK: %[[OPERAND_1:.*]] = f32[] constant(56)
// CHECK: %[[OPERAND_2:.*]] = f32[] constant(12)
// CHECK: %[[OPERAND_3:.*]] = f32[] constant(13)
// CHECK: %[[TUPLE:.*]] = (f32[], f32[]) conditional(s32[] %[[INDEX]], f32[] %[[OPERAND_1]], f32[] %[[OPERAND_2]], f32[] %[[OPERAND_3]]), branch_computations={%[[NEGATE_BRANCH]], %[[COPY_BRANCH]], %[[FLOOR_BRANCH]]}
// CHECK: %[[RES_1:.*]] = f32[] get-tuple-element((f32[], f32[]) %[[TUPLE]]), index=0
// CHECK: %[[RES_2:.*]] = f32[] get-tuple-element((f32[], f32[]) %[[TUPLE]]), index=1
// CHECK: ROOT %[[RESULT:.*]] = (f32[], f32[]) tuple(f32[] %[[RES_1]], f32[] %[[RES_2]])
@@ -0,0 +1,46 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s

HloModule Indexed_Conditional

%Negate (x: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  ROOT %negate = f32[] negate(f32[] %x)
}

%Identity (y: f32[]) -> f32[] {
  %y = f32[] parameter(0)
  ROOT %copy = f32[] copy(f32[] %y)
}

%Floor (z: f32[]) -> f32[] {
  %z = f32[] parameter(0)
  ROOT %floor = f32[] floor(f32[] %z)
}

ENTRY %indexed_conditional () -> f32[] {
  %constant = s32[] constant(1)
  %constant.1 = f32[] constant(56)
  %constant.2 = f32[] constant(12)
  %constant.3 = f32[] constant(13)
  ROOT %conditional = f32[] conditional(s32[] %constant, f32[] %constant.1, f32[] %constant.2, f32[] %constant.3), branch_computations={%Negate, %Identity, %Floor}
}

// CHECK-LABEL: func @main() -> tensor<f32>
// CHECK: %[[INDEX:.*]] = constant {name = "constant"} dense<1> : tensor<i32>
// CHECK: %[[OPERAND_1:.*]] = constant {name = "{{.*}}"} dense<5.600000e+01> : tensor<f32>
// CHECK: %[[OPERAND_2:.*]] = constant {name = "{{.*}}"} dense<1.200000e+01> : tensor<f32>
// CHECK: %[[OPERAND_3:.*]] = constant {name = "{{.*}}"} dense<1.300000e+01> : tensor<f32>
// CHECK: %[[RESULT:.*]] = "xla_hlo.case"(%[[INDEX]], %[[OPERAND_1]], %[[OPERAND_2]], %[[OPERAND_3]]) ( {
// CHECK: ^bb0(%[[ARG_1:.*]]: tensor<f32>):
// CHECK: %[[RES_1:.*]] = "xla_hlo.negate"(%[[ARG_1]]) {name = "{{.*}}"} : (tensor<f32>) -> tensor<f32>
// CHECK: "xla_hlo.return"(%[[RES_1]]) : (tensor<f32>) -> ()
// CHECK: }, {
// CHECK: ^bb0(%[[ARG_2:.*]]: tensor<f32>):
// CHECK: %[[RES_2:.*]] = "xla_hlo.copy"(%[[ARG_2]]) {name = "{{.*}}"} : (tensor<f32>) -> tensor<f32>
// CHECK: "xla_hlo.return"(%[[RES_2]]) : (tensor<f32>) -> ()
// CHECK: }, {
// CHECK: ^bb0(%[[ARG_3:.*]]: tensor<f32>):
// CHECK: %[[RES_3:.*]] = "xla_hlo.floor"(%[[ARG_3]]) {name = "{{.*}}"} : (tensor<f32>) -> tensor<f32>
// CHECK: "xla_hlo.return"(%[[RES_3]]) : (tensor<f32>) -> ()
// CHECK: }) {name = "{{.*}}"} : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: return %[[RESULT]] : tensor<f32>