diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td
index aed7c83570e..95ad97118ef 100644
--- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td
@@ -760,15 +760,6 @@ def LHLO_SortOp: LHLO_Op<"sort", []>, BASE_HLO_SortOp {
   let regions = (region SizedRegion<1>:$comparator);
 }
 
-def LHLO_TupleSelectOp: LHLO_Op<"tuple_select", [SameOperandsShape]> {
-  let arguments = (ins
-    Arg<LHLO_PredBuffer, "", [MemRead]>:$pred,
-    Arg<LHLO_Buffer, "", [MemRead]>:$on_true,
-    Arg<LHLO_Buffer, "", [MemRead]>:$on_false,
-    Arg<LHLO_Buffer, "", [MemWrite]>:$output
-  );
-}
-
 //===----------------------------------------------------------------------===//
 // Late operations
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir
index 0ed8b36466e..1e803da4ac6 100644
--- a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir
@@ -964,23 +964,3 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
   }) : (memref<16x16xf32>, memref<16x16xf16>, tuple<memref<16x16xf32>, memref<16x16xf16>>) -> ()
   return
 }
-
-// -----
-
-// CHECK-LABEL: func @tuple_select_memrefs
-func @tuple_select_memrefs(%pred: memref<20xi1>, %true_values: memref<20xf32>,
-                           %false_values: memref<20xf32>, %arg_out: memref<20xf32>) -> () {
-  "xla_lhlo.tuple_select"(%pred, %true_values, %false_values, %arg_out)
-      : (memref<20xi1>, memref<20xf32>, memref<20xf32>, memref<20xf32>) -> ()
-  return
-}
-
-// -----
-
-func @tuple_select_memrefs(%pred: memref<10xi1>, %true_values: memref<20xf32>,
-                           %false_values: memref<20xf32>, %arg_out: memref<20xf32>) -> () {
-  // expected-error@+1{{requires the same shape for all operands}}
-  "xla_lhlo.tuple_select"(%pred, %true_values, %false_values, %arg_out)
-      : (memref<10xi1>, memref<20xf32>, memref<20xf32>, memref<20xf32>) -> ()
-  return
-}