[MLIR] Remove TupleSelectOp from LHLO.
LHLO uses output-parameter, but TupleSelectOp outputs into a tuple on the device. The current type constraints are wrong, and there is not enough expressiveness in LHLO to define a device-memory representation that can be passed in between kernel launches. PiperOrigin-RevId: 317749542 Change-Id: I6d3350cbf9decf006f239a7208f6bbef0175ac61
This commit is contained in:
parent
e653495365
commit
3980537192
tensorflow/compiler/mlir/xla
@ -760,15 +760,6 @@ def LHLO_SortOp: LHLO_Op<"sort", []>, BASE_HLO_SortOp {
|
||||
let regions = (region SizedRegion<1>:$comparator);
|
||||
}
|
||||
|
||||
def LHLO_TupleSelectOp: LHLO_Op<"tuple_select", [SameOperandsShape]> {
|
||||
let arguments = (ins
|
||||
Arg<LHLO_PredBuffer, "", [MemRead]>:$pred,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$on_true,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$on_false,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output
|
||||
);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Late operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -964,23 +964,3 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
|
||||
}) : (memref<16x16xf32>, memref<16x16xf16>, tuple<memref<16x16xf32>, memref<16x16xf16>>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @tuple_select_memrefs
|
||||
func @tuple_select_memrefs(%pred: memref<20xi1>, %true_values: memref<20xf32>,
|
||||
%false_values: memref<20xf32>, %arg_out: memref<20xf32>) -> () {
|
||||
"xla_lhlo.tuple_select"(%pred, %true_values, %false_values, %arg_out)
|
||||
: (memref<20xi1>, memref<20xf32>, memref<20xf32>, memref<20xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @tuple_select_memrefs(%pred: memref<10xi1>, %true_values: memref<20xf32>,
|
||||
%false_values: memref<20xf32>, %arg_out: memref<20xf32>) -> () {
|
||||
// expected-error@+1{{requires the same shape for all operands}}
|
||||
"xla_lhlo.tuple_select"(%pred, %true_values, %false_values, %arg_out)
|
||||
: (memref<10xi1>, memref<20xf32>, memref<20xf32>, memref<20xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user