diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc index 059fdc3edbe..14d89a7e196 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc @@ -30,13 +30,14 @@ namespace { // A pass that sinks constants implicitly captured in control flow regions. This // is necessary to export to XLA. +// // TODO(hinsu): Generalize this pass to handle all the ops with regions. Any // value used within the region that is defined outside of op's region should be // sank to the regions and not just the constants. Ops such as If and While // whose computations doesn't require fixed signature like Sort or Reduce have // an option to pass outside values as operands of the op to avoid recomputing // those within internally. Note that doing so is the only option in case of -// BlockArguments. +// values defined outside that are BlockArguments of any of the parent regions. class SinkConstantsToControlFlowPass : public mlir::PassWrapper<SinkConstantsToControlFlowPass, FunctionPass> { void runOnFunction() override { @@ -60,7 +61,7 @@ class SinkConstantsToControlFlowPass visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) { Value constant = use->get(); auto op = constant.getDefiningOp(); - if (!op || !isa<ConstOp, ConstantOp>(op)) return; + if (!op || !op->hasTrait<OpTrait::ConstantLike>()) return; auto map_entry = sunk_constant.try_emplace(constant, nullptr); if (!map_entry.second) { // This constant has already been cloned into the region, reuse it. @@ -82,6 +83,8 @@ class SinkConstantsToControlFlowPass } // anonymous namespace +// TODO(hinsu): Rename this pass and move to a different file along with the +// generalization to make all ops isolated from above. 
std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass() { return std::make_unique<SinkConstantsToControlFlowPass>(); } diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc index 31512c90f09..c94110d9102 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc @@ -206,6 +206,15 @@ XlaOp MlirHloBuilder::Iota(const Shape& shape, int64 iota_dimension) { }); } +StatusOr<XlaOp> MlirHloBuilder::BitcastConvertTypeInternal(const Shape& shape, + XlaOp operand) { + TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>( + shape, builder_)); + auto op = builder_.create<mlir::mhlo::BitcastConvertOp>(loc_, ty, + GetValue(operand)); + return MakeXlaOp(op); +} + StatusOr<XlaOp> MlirHloBuilder::TransposeInternal( const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) { TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>( @@ -224,6 +233,31 @@ StatusOr<XlaOp> MlirHloBuilder::RevInternal( return MakeXlaOp(op); } +StatusOr<XlaOp> MlirHloBuilder::SortInternal(const Shape& shape, + absl::Span<const XlaOp> operands, + const XlaComputation& comparator, + int64 dimension, bool is_stable) { + TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>( + shape, builder_)); + auto op = builder_.create<mlir::mhlo::SortOp>( + loc_, ty, GetValues(operands), builder_.getI64IntegerAttr(dimension), + builder_.getBoolAttr(is_stable)); + TF_RETURN_IF_ERROR(ImportComputation(comparator.proto(), &op.comparator())); + return MakeXlaOp(op); +} + +StatusOr<XlaOp> MlirHloBuilder::WhileInternal(const Shape& shape, + const XlaComputation& condition, + const XlaComputation& body, + XlaOp init) { + TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>( + shape, builder_)); + auto op = builder_.create<mlir::mhlo::WhileOp>(loc_, ty, GetValue(init)); + 
TF_RETURN_IF_ERROR(ImportComputation(condition.proto(), &op.cond())); + TF_RETURN_IF_ERROR(ImportComputation(body.proto(), &op.body())); + return MakeXlaOp(op); +} + StatusOr<XlaOp> MlirHloBuilder::GatherInternal( const Shape& shape, XlaOp input, XlaOp start_indices, const GatherDimensionNumbers& dimension_numbers, diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h index ab1a0d2c9b3..a12eb723465 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h @@ -142,6 +142,9 @@ class MlirHloBuilder : public XlaBuilder { XlaOp Iota(const Shape& shape, int64 iota_dimension) override; + StatusOr<XlaOp> BitcastConvertTypeInternal(const Shape& shape, + XlaOp operand) override; + StatusOr<XlaOp> TransposeInternal( const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) override; @@ -149,6 +152,16 @@ class MlirHloBuilder : public XlaBuilder { StatusOr<XlaOp> RevInternal(const Shape& shape, XlaOp operand, absl::Span<const int64> dimensions) override; + StatusOr<XlaOp> SortInternal(const Shape& shape, + absl::Span<const XlaOp> operands, + const XlaComputation& comparator, + int64 dimension, bool is_stable) override; + + StatusOr<XlaOp> WhileInternal(const Shape& shape, + const XlaComputation& condition, + const XlaComputation& body, + XlaOp init) override; + StatusOr<XlaOp> GatherInternal( const Shape& shape, XlaOp input, XlaOp start_indices, const GatherDimensionNumbers& dimension_numbers, diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir index 5a1edc0d933..cd351447303 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir @@ -257,6 +257,14 @@ func @arg_min(%arg0: tensor<6xf64>) -> tensor<i32> { return %1 : tensor<i32> } +// CHECK-LABEL: 
non_max_suppression_v4 +func @non_max_suppression_v4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<2xi32> { + %max_size = mhlo.constant dense<2> : tensor<i32> + // CHECK-NOT: tf.NonMaxSuppressionV4 + %0:2 = "tf.NonMaxSuppressionV4"(%arg0, %arg1, %max_size, %arg2, %arg3) {pad_to_max_output_size = true}: (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>) + return %0#0 : tensor<2xi32> +} + // TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is // available but doesn't support this instance. } diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index 1743ae7be17..bb50fc198c8 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -159,6 +159,7 @@ static bool IsOpAllowlisted(Operation* op) { TypeID::get<TF::MirrorPadOp>(), TypeID::get<TF::MulOp>(), TypeID::get<TF::NegOp>(), + TypeID::get<TF::NonMaxSuppressionV4Op>(), TypeID::get<TF::NotEqualOp>(), TypeID::get<TF::PadOp>(), TypeID::get<TF::PlaceholderWithDefaultOp>(), @@ -178,6 +179,7 @@ static bool IsOpAllowlisted(Operation* op) { TypeID::get<TF::RintOp>(), TypeID::get<TF::RoundOp>(), TypeID::get<TF::SelectV2Op>(), + TypeID::get<TF::SelfAdjointEigV2Op>(), TypeID::get<TF::SeluGradOp>(), TypeID::get<TF::SeluOp>(), TypeID::get<TF::SigmoidGradOp>(), diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index c2b5000647d..a3134fc1c94 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -324,6 +324,7 @@ tf_xla_py_test( name = "self_adjoint_eig_op_test", size = "medium", srcs = ["self_adjoint_eig_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip diff 
--git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 326c3ec4929..9590688fda7 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -30,7 +30,6 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_image_ops from tensorflow.python.ops import image_ops @@ -775,7 +774,6 @@ class ResizeBilinearNonAlignCornersTest(xla_test.XLATestCase): class NonMaxSuppressionTest(xla_test.XLATestCase): - @test_util.disable_mlir_bridge("%1") def testNMS128From1024(self): num_boxes = 1024 boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4") @@ -810,7 +808,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): self.assertEqual(indices_tf.size, max_output_size) - @test_util.disable_mlir_bridge("%1") def testNMS3From6Boxes(self): # Three boxes are selected based on IOU. boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], @@ -852,7 +849,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): self.assertEqual(num_valid, 3) self.assertAllClose(indices_tf[:num_valid], [3, 0, 5]) - @test_util.disable_mlir_bridge("%1") def testNMS3Then2WithScoreThresh(self): # Three boxes are selected based on IOU. # One is filtered out by score threshold. @@ -895,7 +891,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): self.assertEqual(num_valid, 2) self.assertAllClose(indices_tf[:num_valid], [3, 0]) - @test_util.disable_mlir_bridge("%1") def testNMS3Then1WithScoreMaxThresh(self): # Three boxes are selected based on IOU. # One is filtered out by score threshold. 
@@ -939,7 +934,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): self.assertEqual(num_valid, 1) self.assertAllClose(indices_tf[:num_valid], [3]) - @test_util.disable_mlir_bridge("%1") def testSelectFromContinuousOverLap(self): # Tests that a suppressed box does not itself suppress other boxes. @@ -984,7 +978,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): - @test_util.disable_mlir_bridge("%1") def testBatchedNMSFrom6(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1022,7 +1015,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): indices_output) self.assertAllEqual([5, 4], num_valid_output) - @test_util.disable_mlir_bridge("%1") def testBatchedNMSFrom6Max3(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1056,7 +1048,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([[0, 1, 2], [0, 1, 3]], indices_output) self.assertAllEqual([3, 3], num_valid_output) - @test_util.disable_mlir_bridge("%1") def testBatchedNMSSingleFrom6Max3(self): boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]] @@ -1087,7 +1078,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([0, 1, 2], indices_output) self.assertAllEqual(3, num_valid_output) - @test_util.disable_mlir_bridge("%1") def testBatchedNMSSingleFrom6NoPad(self): boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]] @@ -1117,7 +1107,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([0, 1, 2, 4, 5], indices_output) self.assertAllEqual(5, num_valid_output) - @test_util.disable_mlir_bridge("%1") def testBatchedNMSBatchDimsFrom6Max3(self): boxes_data = 
[[[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1151,7 +1140,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([[[0, 1, 2], [0, 1, 3]]], indices_output) self.assertAllEqual([[3, 3]], num_valid_output) - @test_util.disable_mlir_bridge("%1") def testBatchedNMSScoreThresholdFrom6Max3(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1187,7 +1175,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([3, 2], num_valid_output) self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output) - @test_util.disable_mlir_bridge("%1") def testBatchedNMSUnsortedInputFrom6(self): boxes_data = [[[0, 2, 1, 2], [3, 3, 4, 4], [0, 0, 1, 1], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8]], @@ -1224,7 +1211,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): indices_output) self.assertAllEqual([5, 4], num_valid_output) - @test_util.disable_mlir_bridge("%1") def testBatchedNMSNoncanonicalizedInputFrom6(self): boxes_data = [[[1, 0, 0, 1], [4, 3, 3, 4], [1, 0.4, 0, 1.4], [1, 0.6, 0, 1.6], [1, 0.8, 0, 1.8], [1, 2, 0, 2]], @@ -1262,7 +1248,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): indices_output) self.assertAllEqual([5, 4], num_valid_output) - @test_util.disable_mlir_bridge("%1") def testBatchedNMSScoreThresholdCanInputsFrom6Max3(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1298,7 +1283,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([3, 2], num_valid_output) self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output) - @test_util.disable_mlir_bridge("%1") def testBatchedNMSFrom6DynamicInput(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 
0.8, 1, 1.8], [0, 2, 1, 2]], diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 0fc299f031f..52f61408cbb 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -1728,8 +1728,6 @@ XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator, int64 dimension, bool is_stable) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { - HloInstructionProto instr; - instr.set_is_stable(is_stable); std::vector<const Shape*> operand_shape_ptrs; TF_ASSIGN_OR_RETURN(std::vector<Shape> operand_shapes, GetOperandShapes(operands)); @@ -1737,17 +1735,26 @@ XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands, [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferVariadicOpShape( HloOpcode::kSort, operand_shape_ptrs)); - *instr.mutable_shape() = shape.ToProto(); - if (dimension == -1) { - TF_ASSIGN_OR_RETURN(const Shape* keys_shape, GetShapePtr(operands[0])); - dimension = keys_shape->rank() - 1; - } - instr.add_dimensions(dimension); - AddCalledComputation(comparator, &instr); - return AddInstruction(std::move(instr), HloOpcode::kSort, operands); + return SortInternal(shape, operands, comparator, dimension, is_stable); }); } +StatusOr<XlaOp> XlaBuilder::SortInternal(const Shape& shape, + absl::Span<const XlaOp> operands, + const XlaComputation& comparator, + int64 dimension, bool is_stable) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + instr.set_is_stable(is_stable); + if (dimension == -1) { + TF_ASSIGN_OR_RETURN(const Shape* keys_shape, GetShapePtr(operands[0])); + dimension = keys_shape->rank() - 1; + } + instr.add_dimensions(dimension); + AddCalledComputation(comparator, &instr); + return AddInstruction(std::move(instr), HloOpcode::kSort, operands); +} + XlaOp XlaBuilder::ConvertElementType(XlaOp operand, PrimitiveType new_element_type) { return 
ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { @@ -1761,16 +1768,21 @@ XlaOp XlaBuilder::ConvertElementType(XlaOp operand, XlaOp XlaBuilder::BitcastConvertType(XlaOp operand, PrimitiveType new_element_type) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { - HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape( *operand_shape, new_element_type)); - *instr.mutable_shape() = shape.ToProto(); - return AddInstruction(std::move(instr), HloOpcode::kBitcastConvert, - {operand}); + return BitcastConvertTypeInternal(shape, operand); }); } +StatusOr<XlaOp> XlaBuilder::BitcastConvertTypeInternal(const Shape& shape, + XlaOp operand) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + return AddInstruction(std::move(instr), HloOpcode::kBitcastConvert, + {operand}); +} + XlaOp XlaBuilder::Clamp(XlaOp min, XlaOp operand, XlaOp max) { return TernaryOp(HloOpcode::kClamp, min, operand, max); } @@ -1892,8 +1904,6 @@ XlaOp XlaBuilder::RngBitGenerator(RandomAlgorithm algorithm, XlaOp XlaBuilder::While(const XlaComputation& condition, const XlaComputation& body, XlaOp init) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { - HloInstructionProto instr; - // Infer shape. TF_ASSIGN_OR_RETURN(const auto& body_program_shape, body.GetProgramShape()); TF_ASSIGN_OR_RETURN(const auto& condition_program_shape, @@ -1902,14 +1912,22 @@ XlaOp XlaBuilder::While(const XlaComputation& condition, TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferWhileShape( condition_program_shape, body_program_shape, *init_shape)); - *instr.mutable_shape() = shape.ToProto(); - // Body comes before condition computation in the vector. 
- AddCalledComputation(body, &instr); - AddCalledComputation(condition, &instr); - return AddInstruction(std::move(instr), HloOpcode::kWhile, {init}); + return WhileInternal(shape, condition, body, init); }); } +StatusOr<XlaOp> XlaBuilder::WhileInternal(const Shape& shape, + const XlaComputation& condition, + const XlaComputation& body, + XlaOp init) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + // Body comes before condition computation in the vector. + AddCalledComputation(body, &instr); + AddCalledComputation(condition, &instr); + return AddInstruction(std::move(instr), HloOpcode::kWhile, {init}); +} + XlaOp XlaBuilder::Gather(XlaOp input, XlaOp start_indices, const GatherDimensionNumbers& dimension_numbers, absl::Span<const int64> slice_sizes, diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 7b96c6dfed6..1960d0c4632 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -639,6 +639,8 @@ class XlaBuilder { XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type); XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type); + virtual StatusOr<XlaOp> BitcastConvertTypeInternal(const Shape& shape, + XlaOp operand); XlaOp Transpose(XlaOp operand, absl::Span<const int64> permutation); virtual StatusOr<XlaOp> TransposeInternal( @@ -650,6 +652,10 @@ class XlaBuilder { XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator, int64 dimension = -1, bool is_stable = false); + virtual StatusOr<XlaOp> SortInternal(const Shape& shape, + absl::Span<const XlaOp> operands, + const XlaComputation& comparator, + int64 dimension, bool is_stable); XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max); @@ -666,6 +672,9 @@ class XlaBuilder { XlaOp While(const XlaComputation& condition, const XlaComputation& body, XlaOp init); + virtual StatusOr<XlaOp> WhileInternal(const Shape& shape, + const 
XlaComputation& condition, + const XlaComputation& body, XlaOp init); XlaOp Conditional(XlaOp predicate, XlaOp true_operand, const XlaComputation& true_computation, XlaOp false_operand,