diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc
index 059fdc3edbe..14d89a7e196 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc
@@ -30,13 +30,14 @@ namespace {
 
 // A pass that sinks constants implicitly captured in control flow regions. This
 // is necessary to export to XLA.
+//
 // TODO(hinsu): Generalize this pass to handle all the ops with regions. Any
 // value used within the region that is defined outside of op's region should be
 // sank to the regions and not just the constants. Ops such as If and While
 // whose computations doesn't require fixed signature like Sort or Reduce have
 // an option to pass outside values as operands of the op to avoid recomputing
 // those within internally. Note that doing so is the only option in case of
-// BlockArguments.
+// values defined outside that are BlockArguments of any of the parent regions.
 class SinkConstantsToControlFlowPass
     : public mlir::PassWrapper<SinkConstantsToControlFlowPass, FunctionPass> {
   void runOnFunction() override {
@@ -60,7 +61,7 @@ class SinkConstantsToControlFlowPass
     visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) {
       Value constant = use->get();
       auto op = constant.getDefiningOp();
-      if (!op || !isa<ConstOp, ConstantOp>(op)) return;
+      if (!op || !op->hasTrait<OpTrait::ConstantLike>()) return;
       auto map_entry = sunk_constant.try_emplace(constant, nullptr);
       if (!map_entry.second) {
         // This constant has already been cloned into the region, reuse it.
@@ -82,6 +83,8 @@ class SinkConstantsToControlFlowPass
 
 }  // anonymous namespace
 
+// TODO(hinsu): Rename this pass and move to a different file along with the
+// generalization to make all ops isolated from above.
 std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass() {
   return std::make_unique<SinkConstantsToControlFlowPass>();
 }
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
index 31512c90f09..c94110d9102 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
@@ -206,6 +206,15 @@ XlaOp MlirHloBuilder::Iota(const Shape& shape, int64 iota_dimension) {
   });
 }
 
+// Creates an mhlo::BitcastConvertOp on `operand` typed from `shape`.
+StatusOr<XlaOp> MlirHloBuilder::BitcastConvertTypeInternal(
+    const Shape& shape, XlaOp operand) {
+  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
+                                         shape, builder_));
+  return MakeXlaOp(builder_.create<mlir::mhlo::BitcastConvertOp>(
+      loc_, ty, GetValue(operand)));
+}
+
 StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
     const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) {
   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
@@ -224,6 +233,31 @@ StatusOr<XlaOp> MlirHloBuilder::RevInternal(
   return MakeXlaOp(op);
 }
 
+// Creates an mhlo::SortOp over `operands` and imports `comparator` into it.
+StatusOr<XlaOp> MlirHloBuilder::SortInternal(
+    const Shape& shape, absl::Span<const XlaOp> operands,
+    const XlaComputation& comparator, int64 dimension, bool is_stable) {
+  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
+                                         shape, builder_));
+  auto sort = builder_.create<mlir::mhlo::SortOp>(
+      loc_, ty, GetValues(operands),
+      builder_.getI64IntegerAttr(dimension), builder_.getBoolAttr(is_stable));
+  TF_RETURN_IF_ERROR(ImportComputation(comparator.proto(), &sort.comparator()));
+  return MakeXlaOp(sort);
+}
+
+// Creates an mhlo::WhileOp on `init`, importing `condition` then `body`.
+StatusOr<XlaOp> MlirHloBuilder::WhileInternal(
+    const Shape& shape, const XlaComputation& condition,
+    const XlaComputation& body, XlaOp init) {
+  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
+                                         shape, builder_));
+  auto loop = builder_.create<mlir::mhlo::WhileOp>(loc_, ty, GetValue(init));
+  TF_RETURN_IF_ERROR(ImportComputation(condition.proto(), &loop.cond()));
+  TF_RETURN_IF_ERROR(ImportComputation(body.proto(), &loop.body()));
+  return MakeXlaOp(loop);
+}
+
 StatusOr<XlaOp> MlirHloBuilder::GatherInternal(
     const Shape& shape, XlaOp input, XlaOp start_indices,
     const GatherDimensionNumbers& dimension_numbers,
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
index ab1a0d2c9b3..a12eb723465 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
@@ -142,6 +142,9 @@ class MlirHloBuilder : public XlaBuilder {
 
   XlaOp Iota(const Shape& shape, int64 iota_dimension) override;
 
+  StatusOr<XlaOp> BitcastConvertTypeInternal(const Shape& shape,
+                                             XlaOp operand) override;
+
   StatusOr<XlaOp> TransposeInternal(
       const Shape& shape, XlaOp operand,
       absl::Span<const int64> permutation) override;
@@ -149,6 +152,16 @@ class MlirHloBuilder : public XlaBuilder {
   StatusOr<XlaOp> RevInternal(const Shape& shape, XlaOp operand,
                               absl::Span<const int64> dimensions) override;
 
+  StatusOr<XlaOp> SortInternal(const Shape& shape,
+                               absl::Span<const XlaOp> operands,
+                               const XlaComputation& comparator,
+                               int64 dimension, bool is_stable) override;
+
+  StatusOr<XlaOp> WhileInternal(const Shape& shape,
+                                const XlaComputation& condition,
+                                const XlaComputation& body,
+                                XlaOp init) override;
+
   StatusOr<XlaOp> GatherInternal(
       const Shape& shape, XlaOp input, XlaOp start_indices,
       const GatherDimensionNumbers& dimension_numbers,
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
index 5a1edc0d933..cd351447303 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
@@ -257,6 +257,14 @@ func @arg_min(%arg0: tensor<6xf64>) -> tensor<i32> {
   return %1 : tensor<i32>
 }
 
+// CHECK-LABEL: non_max_suppression_v4
+func @non_max_suppression_v4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<2xi32> {
+  %max_size = mhlo.constant dense<2> : tensor<i32>
+  // CHECK-NOT: tf.NonMaxSuppressionV4
+  %0:2 = "tf.NonMaxSuppressionV4"(%arg0, %arg1, %max_size, %arg2, %arg3) {pad_to_max_output_size = true}: (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
+  return %0#0 : tensor<2xi32>
+}
+
 // TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
 // available but doesn't support this instance.
 }
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
index 1743ae7be17..bb50fc198c8 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
@@ -159,6 +159,7 @@ static bool IsOpAllowlisted(Operation* op) {
     TypeID::get<TF::MirrorPadOp>(),
     TypeID::get<TF::MulOp>(),
     TypeID::get<TF::NegOp>(),
+    TypeID::get<TF::NonMaxSuppressionV4Op>(),
     TypeID::get<TF::NotEqualOp>(),
     TypeID::get<TF::PadOp>(),
     TypeID::get<TF::PlaceholderWithDefaultOp>(),
@@ -178,6 +179,7 @@ static bool IsOpAllowlisted(Operation* op) {
     TypeID::get<TF::RintOp>(),
     TypeID::get<TF::RoundOp>(),
     TypeID::get<TF::SelectV2Op>(),
+    TypeID::get<TF::SelfAdjointEigV2Op>(),
     TypeID::get<TF::SeluGradOp>(),
     TypeID::get<TF::SeluOp>(),
     TypeID::get<TF::SigmoidGradOp>(),
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index c2b5000647d..a3134fc1c94 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -324,6 +324,7 @@ tf_xla_py_test(
     name = "self_adjoint_eig_op_test",
     size = "medium",
     srcs = ["self_adjoint_eig_op_test.py"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py
index 326c3ec4929..9590688fda7 100644
--- a/tensorflow/compiler/tests/image_ops_test.py
+++ b/tensorflow/compiler/tests/image_ops_test.py
@@ -30,7 +30,6 @@ from six.moves import xrange  # pylint: disable=redefined-builtin
 from tensorflow.compiler.tests import xla_test
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_image_ops
 from tensorflow.python.ops import image_ops
@@ -775,7 +774,6 @@ class ResizeBilinearNonAlignCornersTest(xla_test.XLATestCase):
 
 class NonMaxSuppressionTest(xla_test.XLATestCase):
 
-  @test_util.disable_mlir_bridge("%1")
   def testNMS128From1024(self):
     num_boxes = 1024
     boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4")
@@ -810,7 +808,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
 
       self.assertEqual(indices_tf.size, max_output_size)
 
-  @test_util.disable_mlir_bridge("%1")
   def testNMS3From6Boxes(self):
     # Three boxes are selected based on IOU.
     boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
@@ -852,7 +849,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
       self.assertEqual(num_valid, 3)
       self.assertAllClose(indices_tf[:num_valid], [3, 0, 5])
 
-  @test_util.disable_mlir_bridge("%1")
   def testNMS3Then2WithScoreThresh(self):
     # Three boxes are selected based on IOU.
     # One is filtered out by score threshold.
@@ -895,7 +891,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
       self.assertEqual(num_valid, 2)
       self.assertAllClose(indices_tf[:num_valid], [3, 0])
 
-  @test_util.disable_mlir_bridge("%1")
   def testNMS3Then1WithScoreMaxThresh(self):
     # Three boxes are selected based on IOU.
     # One is filtered out by score threshold.
@@ -939,7 +934,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
       self.assertEqual(num_valid, 1)
       self.assertAllClose(indices_tf[:num_valid], [3])
 
-  @test_util.disable_mlir_bridge("%1")
   def testSelectFromContinuousOverLap(self):
     # Tests that a suppressed box does not itself suppress other boxes.
 
@@ -984,7 +978,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
 
 class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
 
-  @test_util.disable_mlir_bridge("%1")
   def testBatchedNMSFrom6(self):
     boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
                    [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
@@ -1022,7 +1015,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
                         indices_output)
     self.assertAllEqual([5, 4], num_valid_output)
 
-  @test_util.disable_mlir_bridge("%1")
   def testBatchedNMSFrom6Max3(self):
     boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
                    [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
@@ -1056,7 +1048,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
     self.assertAllEqual([[0, 1, 2], [0, 1, 3]], indices_output)
     self.assertAllEqual([3, 3], num_valid_output)
 
-  @test_util.disable_mlir_bridge("%1")
   def testBatchedNMSSingleFrom6Max3(self):
     boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
                   [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]]
@@ -1087,7 +1078,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
     self.assertAllEqual([0, 1, 2], indices_output)
     self.assertAllEqual(3, num_valid_output)
 
-  @test_util.disable_mlir_bridge("%1")
   def testBatchedNMSSingleFrom6NoPad(self):
     boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
                   [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]]
@@ -1117,7 +1107,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
     self.assertAllEqual([0, 1, 2, 4, 5], indices_output)
     self.assertAllEqual(5, num_valid_output)
 
-  @test_util.disable_mlir_bridge("%1")
   def testBatchedNMSBatchDimsFrom6Max3(self):
     boxes_data = [[[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
                     [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
@@ -1151,7 +1140,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
     self.assertAllEqual([[[0, 1, 2], [0, 1, 3]]], indices_output)
     self.assertAllEqual([[3, 3]], num_valid_output)
 
-  @test_util.disable_mlir_bridge("%1")
   def testBatchedNMSScoreThresholdFrom6Max3(self):
     boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
                    [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
@@ -1187,7 +1175,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
     self.assertAllEqual([3, 2], num_valid_output)
     self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output)
 
-  @test_util.disable_mlir_bridge("%1")
   def testBatchedNMSUnsortedInputFrom6(self):
     boxes_data = [[[0, 2, 1, 2], [3, 3, 4, 4], [0, 0, 1, 1],
                    [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8]],
@@ -1224,7 +1211,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
                         indices_output)
     self.assertAllEqual([5, 4], num_valid_output)
 
-  @test_util.disable_mlir_bridge("%1")
   def testBatchedNMSNoncanonicalizedInputFrom6(self):
     boxes_data = [[[1, 0, 0, 1], [4, 3, 3, 4], [1, 0.4, 0, 1.4],
                    [1, 0.6, 0, 1.6], [1, 0.8, 0, 1.8], [1, 2, 0, 2]],
@@ -1262,7 +1248,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
                         indices_output)
     self.assertAllEqual([5, 4], num_valid_output)
 
-  @test_util.disable_mlir_bridge("%1")
   def testBatchedNMSScoreThresholdCanInputsFrom6Max3(self):
     boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
                    [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
@@ -1298,7 +1283,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
     self.assertAllEqual([3, 2], num_valid_output)
     self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output)
 
-  @test_util.disable_mlir_bridge("%1")
   def testBatchedNMSFrom6DynamicInput(self):
     boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
                    [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 0fc299f031f..52f61408cbb 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -1728,8 +1728,6 @@ XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands,
                        const XlaComputation& comparator, int64 dimension,
                        bool is_stable) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
-    instr.set_is_stable(is_stable);
     std::vector<const Shape*> operand_shape_ptrs;
     TF_ASSIGN_OR_RETURN(std::vector<Shape> operand_shapes,
                         GetOperandShapes(operands));
@@ -1737,17 +1735,26 @@ XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands,
                       [](const Shape& shape) { return &shape; });
     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferVariadicOpShape(
                                          HloOpcode::kSort, operand_shape_ptrs));
-    *instr.mutable_shape() = shape.ToProto();
-    if (dimension == -1) {
-      TF_ASSIGN_OR_RETURN(const Shape* keys_shape, GetShapePtr(operands[0]));
-      dimension = keys_shape->rank() - 1;
-    }
-    instr.add_dimensions(dimension);
-    AddCalledComputation(comparator, &instr);
-    return AddInstruction(std::move(instr), HloOpcode::kSort, operands);
+    return SortInternal(shape, operands, comparator, dimension, is_stable);
   });
 }
 
+// Adds a kSort instruction over `operands` with the given result `shape`.
+StatusOr<XlaOp> XlaBuilder::SortInternal(
+    const Shape& shape, absl::Span<const XlaOp> operands,
+    const XlaComputation& comparator, int64 dimension, bool is_stable) {
+  if (dimension == -1) {  // -1 means the first operand's last dimension.
+    TF_ASSIGN_OR_RETURN(const Shape* first_shape, GetShapePtr(operands[0]));
+    dimension = first_shape->rank() - 1;
+  }
+  HloInstructionProto sort_instr;
+  *sort_instr.mutable_shape() = shape.ToProto();
+  sort_instr.set_is_stable(is_stable);
+  sort_instr.add_dimensions(dimension);
+  AddCalledComputation(comparator, &sort_instr);
+  return AddInstruction(std::move(sort_instr), HloOpcode::kSort, operands);
+}
+
 XlaOp XlaBuilder::ConvertElementType(XlaOp operand,
                                      PrimitiveType new_element_type) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@@ -1761,16 +1768,21 @@ XlaOp XlaBuilder::ConvertElementType(XlaOp operand,
 XlaOp XlaBuilder::BitcastConvertType(XlaOp operand,
                                      PrimitiveType new_element_type) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape(
                                          *operand_shape, new_element_type));
-    *instr.mutable_shape() = shape.ToProto();
-    return AddInstruction(std::move(instr), HloOpcode::kBitcastConvert,
-                          {operand});
+    return BitcastConvertTypeInternal(shape, operand);
   });
 }
 
+// Adds a kBitcastConvert instruction on `operand` with result `shape`.
+StatusOr<XlaOp> XlaBuilder::BitcastConvertTypeInternal(
+    const Shape& shape, XlaOp operand) {
+  HloInstructionProto inst;
+  *inst.mutable_shape() = shape.ToProto();
+  return AddInstruction(std::move(inst), HloOpcode::kBitcastConvert, {operand});
+}
+
 XlaOp XlaBuilder::Clamp(XlaOp min, XlaOp operand, XlaOp max) {
   return TernaryOp(HloOpcode::kClamp, min, operand, max);
 }
@@ -1892,8 +1904,6 @@ XlaOp XlaBuilder::RngBitGenerator(RandomAlgorithm algorithm,
 XlaOp XlaBuilder::While(const XlaComputation& condition,
                         const XlaComputation& body, XlaOp init) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
-
     // Infer shape.
     TF_ASSIGN_OR_RETURN(const auto& body_program_shape, body.GetProgramShape());
     TF_ASSIGN_OR_RETURN(const auto& condition_program_shape,
@@ -1902,14 +1912,22 @@ XlaOp XlaBuilder::While(const XlaComputation& condition,
     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferWhileShape(
                                          condition_program_shape,
                                          body_program_shape, *init_shape));
-    *instr.mutable_shape() = shape.ToProto();
-    // Body comes before condition computation in the vector.
-    AddCalledComputation(body, &instr);
-    AddCalledComputation(condition, &instr);
-    return AddInstruction(std::move(instr), HloOpcode::kWhile, {init});
+    return WhileInternal(shape, condition, body, init);
   });
 }
 
+// Adds a kWhile instruction taking `init`, with the given result `shape`.
+StatusOr<XlaOp> XlaBuilder::WhileInternal(
+    const Shape& shape, const XlaComputation& condition,
+    const XlaComputation& body, XlaOp init) {
+  HloInstructionProto while_instr;
+  *while_instr.mutable_shape() = shape.ToProto();
+  // The called-computation list holds the body first, then the condition.
+  AddCalledComputation(body, &while_instr);
+  AddCalledComputation(condition, &while_instr);
+  return AddInstruction(std::move(while_instr), HloOpcode::kWhile, {init});
+}
+
 XlaOp XlaBuilder::Gather(XlaOp input, XlaOp start_indices,
                          const GatherDimensionNumbers& dimension_numbers,
                          absl::Span<const int64> slice_sizes,
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 7b96c6dfed6..1960d0c4632 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -639,6 +639,8 @@ class XlaBuilder {
   XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type);
 
   XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type);
+  virtual StatusOr<XlaOp> BitcastConvertTypeInternal(const Shape& shape,
+                                                     XlaOp operand);
 
   XlaOp Transpose(XlaOp operand, absl::Span<const int64> permutation);
   virtual StatusOr<XlaOp> TransposeInternal(
@@ -650,6 +652,10 @@ class XlaBuilder {
 
   XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
              int64 dimension = -1, bool is_stable = false);
+  virtual StatusOr<XlaOp> SortInternal(const Shape& shape,
+                                       absl::Span<const XlaOp> operands,
+                                       const XlaComputation& comparator,
+                                       int64 dimension, bool is_stable);
 
   XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max);
 
@@ -666,6 +672,9 @@ class XlaBuilder {
 
   XlaOp While(const XlaComputation& condition, const XlaComputation& body,
               XlaOp init);
+  virtual StatusOr<XlaOp> WhileInternal(const Shape& shape,
+                                        const XlaComputation& condition,
+                                        const XlaComputation& body, XlaOp init);
 
   XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
                     const XlaComputation& true_computation, XlaOp false_operand,