[MLIR] Remove TupleSelectOp from LHLO.

LHLO uses output-parameter, but TupleSelectOp outputs into a tuple on the device. The current type constraints are wrong, and there is not enough expressiveness in LHLO to define a device-memory representation that can be passed in between kernel launches.

PiperOrigin-RevId: 317749542
Change-Id: I6d3350cbf9decf006f239a7208f6bbef0175ac61
This commit is contained in:
Tim Shen 2020-06-22 15:38:24 -07:00 committed by TensorFlower Gardener
parent e653495365
commit 3980537192
2 changed files with 0 additions and 29 deletions
tensorflow/compiler/mlir/xla

View File

@ -760,15 +760,6 @@ def LHLO_SortOp: LHLO_Op<"sort", []>, BASE_HLO_SortOp {
let regions = (region SizedRegion<1>:$comparator);
}
def LHLO_TupleSelectOp: LHLO_Op<"tuple_select", [SameOperandsShape]> {
let arguments = (ins
Arg<LHLO_PredBuffer, "", [MemRead]>:$pred,
Arg<LHLO_Buffer, "", [MemRead]>:$on_true,
Arg<LHLO_Buffer, "", [MemRead]>:$on_false,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
//===----------------------------------------------------------------------===//
// Late operations
//===----------------------------------------------------------------------===//

View File

@ -964,23 +964,3 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
}) : (memref<16x16xf32>, memref<16x16xf16>, tuple<memref<16x16xf32>, memref<16x16xf16>>) -> ()
return
}
// -----
// CHECK-LABEL: func @tuple_select_memrefs
func @tuple_select_memrefs(%pred: memref<20xi1>, %true_values: memref<20xf32>,
%false_values: memref<20xf32>, %arg_out: memref<20xf32>) -> () {
"xla_lhlo.tuple_select"(%pred, %true_values, %false_values, %arg_out)
: (memref<20xi1>, memref<20xf32>, memref<20xf32>, memref<20xf32>) -> ()
return
}
// -----
func @tuple_select_memrefs(%pred: memref<10xi1>, %true_values: memref<20xf32>,
%false_values: memref<20xf32>, %arg_out: memref<20xf32>) -> () {
// expected-error@+1{{requires the same shape for all operands}}
"xla_lhlo.tuple_select"(%pred, %true_values, %false_values, %arg_out)
: (memref<10xi1>, memref<20xf32>, memref<20xf32>, memref<20xf32>) -> ()
return
}