Legalize TensorFlow NonMaxSuppressionV4 and SelfAdjointEigV2Op ops to HLO

Added support for the HLO ops bitcast-convert, sort, and while in MlirHloBuilder, and enabled the tests for NonMaxSuppressionV4 and SelfAdjointEigV2Op, whose legalizations use these ops.

PiperOrigin-RevId: 324360651
Change-Id: I300b67cfea37a1a4362cd543e8ba7c82b00273a7

commit d3323e54e2
parent 8a449bdb65
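The mechanism behind the commit message: XlaBuilder's public op methods perform shape inference and then delegate emission to virtual `*Internal` hooks, which MlirHloBuilder overrides to build MHLO dialect ops instead of `HloInstructionProto`s. A minimal sketch of that pattern, with a hypothetical `RecordingBuilder` standing in for MlirHloBuilder (not part of this commit):

```cpp
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

// Hypothetical subclass, analogous to MlirHloBuilder: only the virtual
// *Internal hook changes; callers keep using the ordinary XlaBuilder API,
// and shape inference has already happened by the time the hook runs.
class RecordingBuilder : public XlaBuilder {
 public:
  using XlaBuilder::XlaBuilder;

 protected:
  StatusOr<XlaOp> WhileInternal(const Shape& shape,
                                const XlaComputation& condition,
                                const XlaComputation& body,
                                XlaOp init) override {
    // A real override would materialize the loop in its own IR here
    // (MlirHloBuilder creates an mhlo.while and imports both computations).
    return Unimplemented("while is not supported in this sketch");
  }
};

}  // namespace xla
```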
@@ -30,13 +30,14 @@ namespace {
 // A pass that sinks constants implicitly captured in control flow regions. This
 // is necessary to export to XLA.
 //
 // TODO(hinsu): Generalize this pass to handle all the ops with regions. Any
 // value used within the region that is defined outside of op's region should be
 // sank to the regions and not just the constants. Ops such as If and While
 // whose computations doesn't require fixed signature like Sort or Reduce have
 // an option to pass outside values as operands of the op to avoid recomputing
 // those within internally. Note that doing so is the only option in case of
-// BlockArguments.
+// values defined outside that are BlockArguments of any of the parent region.
 class SinkConstantsToControlFlowPass
     : public mlir::PassWrapper<SinkConstantsToControlFlowPass, FunctionPass> {
   void runOnFunction() override {
@@ -60,7 +61,7 @@ class SinkConstantsToControlFlowPass
       visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) {
         Value constant = use->get();
         auto op = constant.getDefiningOp();
-        if (!op || !isa<ConstOp, ConstantOp>(op)) return;
+        if (!op || !op->hasTrait<OpTrait::ConstantLike>()) return;
         auto map_entry = sunk_constant.try_emplace(constant, nullptr);
         if (!map_entry.second) {
           // This constant has already been cloned into the region, reuse it.
@@ -82,6 +83,8 @@ class SinkConstantsToControlFlowPass
 
 }  // anonymous namespace
 
+// TODO(hinsu): Rename this pass and move to a different file along with the
+// generalization to make all ops isolated from above.
 std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass() {
   return std::make_unique<SinkConstantsToControlFlowPass>();
 }
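Worth noting in the hunk above: the check moves from a fixed list (`isa<ConstOp, ConstantOp>`) to the generic `ConstantLike` op trait, so constant ops from any dialect that declares the trait get sunk. A small illustrative helper (mine, not the commit's) showing the trait-based test in isolation:

```cpp
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"

// Returns true if `value` is produced by a constant-like op and can
// therefore be cloned into a control flow region.
static bool IsSinkableConstant(mlir::Value value) {
  mlir::Operation* op = value.getDefiningOp();
  // BlockArguments have no defining op; they cannot be sunk by cloning.
  return op && op->hasTrait<mlir::OpTrait::ConstantLike>();
}
```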
@@ -206,6 +206,15 @@ XlaOp MlirHloBuilder::Iota(const Shape& shape, int64 iota_dimension) {
   });
 }
 
+StatusOr<XlaOp> MlirHloBuilder::BitcastConvertTypeInternal(const Shape& shape,
+                                                           XlaOp operand) {
+  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
+                                         shape, builder_));
+  auto op = builder_.create<mlir::mhlo::BitcastConvertOp>(loc_, ty,
+                                                          GetValue(operand));
+  return MakeXlaOp(op);
+}
+
 StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
     const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) {
   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
@@ -224,6 +233,31 @@ StatusOr<XlaOp> MlirHloBuilder::RevInternal(
   return MakeXlaOp(op);
 }
 
+StatusOr<XlaOp> MlirHloBuilder::SortInternal(const Shape& shape,
+                                             absl::Span<const XlaOp> operands,
+                                             const XlaComputation& comparator,
+                                             int64 dimension, bool is_stable) {
+  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
+                                         shape, builder_));
+  auto op = builder_.create<mlir::mhlo::SortOp>(
+      loc_, ty, GetValues(operands), builder_.getI64IntegerAttr(dimension),
+      builder_.getBoolAttr(is_stable));
+  TF_RETURN_IF_ERROR(ImportComputation(comparator.proto(), &op.comparator()));
+  return MakeXlaOp(op);
+}
+
+StatusOr<XlaOp> MlirHloBuilder::WhileInternal(const Shape& shape,
+                                              const XlaComputation& condition,
+                                              const XlaComputation& body,
+                                              XlaOp init) {
+  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
+                                         shape, builder_));
+  auto op = builder_.create<mlir::mhlo::WhileOp>(loc_, ty, GetValue(init));
+  TF_RETURN_IF_ERROR(ImportComputation(condition.proto(), &op.cond()));
+  TF_RETURN_IF_ERROR(ImportComputation(body.proto(), &op.body()));
+  return MakeXlaOp(op);
+}
+
 StatusOr<XlaOp> MlirHloBuilder::GatherInternal(
     const Shape& shape, XlaOp input, XlaOp start_indices,
     const GatherDimensionNumbers& dimension_numbers,
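With `SortInternal` and `WhileInternal` overridden as above, client code written against the XlaBuilder API runs unchanged under MlirHloBuilder; the comparator computation is imported into the `mhlo.sort` region rather than attached as a called computation. A sketch under that assumption, where `BuildStableSort` is a hypothetical helper:

```cpp
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Builds a stable ascending sort of a rank-1 f32 operand. The same code
// works whether `builder` records HloInstructionProtos or MHLO ops.
xla::XlaOp BuildStableSort(xla::XlaBuilder* builder, xla::XlaOp values) {
  xla::XlaComputation comparator =
      xla::CreateScalarLtComputation({xla::F32}, builder);
  return xla::Sort({values}, comparator, /*dimension=*/-1,
                   /*is_stable=*/true);
}
```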
@@ -142,6 +142,9 @@ class MlirHloBuilder : public XlaBuilder {
 
   XlaOp Iota(const Shape& shape, int64 iota_dimension) override;
 
+  StatusOr<XlaOp> BitcastConvertTypeInternal(const Shape& shape,
+                                             XlaOp operand) override;
+
   StatusOr<XlaOp> TransposeInternal(
       const Shape& shape, XlaOp operand,
       absl::Span<const int64> permutation) override;
@@ -149,6 +152,16 @@ class MlirHloBuilder : public XlaBuilder {
   StatusOr<XlaOp> RevInternal(const Shape& shape, XlaOp operand,
                               absl::Span<const int64> dimensions) override;
 
+  StatusOr<XlaOp> SortInternal(const Shape& shape,
+                               absl::Span<const XlaOp> operands,
+                               const XlaComputation& comparator,
+                               int64 dimension, bool is_stable) override;
+
+  StatusOr<XlaOp> WhileInternal(const Shape& shape,
+                                const XlaComputation& condition,
+                                const XlaComputation& body,
+                                XlaOp init) override;
+
   StatusOr<XlaOp> GatherInternal(
       const Shape& shape, XlaOp input, XlaOp start_indices,
       const GatherDimensionNumbers& dimension_numbers,
@@ -257,6 +257,14 @@ func @arg_min(%arg0: tensor<6xf64>) -> tensor<i32> {
   return %1 : tensor<i32>
 }
 
+// CHECK-LABEL: non_max_suppression_v4
+func @non_max_suppression_v4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<2xi32> {
+  %max_size = mhlo.constant dense<2> : tensor<i32>
+  // CHECK-NOT: tf.NonMaxSuppressionV4
+  %0:2 = "tf.NonMaxSuppressionV4"(%arg0, %arg1, %max_size, %arg2, %arg3) {pad_to_max_output_size = true} : (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
+  return %0#0 : tensor<2xi32>
+}
+
 // TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
 // available but doesn't support this instance.
 }
@@ -159,6 +159,7 @@ static bool IsOpAllowlisted(Operation* op) {
     TypeID::get<TF::MirrorPadOp>(),
     TypeID::get<TF::MulOp>(),
     TypeID::get<TF::NegOp>(),
+    TypeID::get<TF::NonMaxSuppressionV4Op>(),
     TypeID::get<TF::NotEqualOp>(),
     TypeID::get<TF::PadOp>(),
     TypeID::get<TF::PlaceholderWithDefaultOp>(),
@@ -178,6 +179,7 @@ static bool IsOpAllowlisted(Operation* op) {
     TypeID::get<TF::RintOp>(),
     TypeID::get<TF::RoundOp>(),
     TypeID::get<TF::SelectV2Op>(),
+    TypeID::get<TF::SelfAdjointEigV2Op>(),
     TypeID::get<TF::SeluGradOp>(),
     TypeID::get<TF::SeluOp>(),
     TypeID::get<TF::SigmoidGradOp>(),
@@ -324,6 +324,7 @@ tf_xla_py_test(
     name = "self_adjoint_eig_op_test",
     size = "medium",
     srcs = ["self_adjoint_eig_op_test.py"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@@ -30,7 +30,6 @@ from six.moves import xrange  # pylint: disable=redefined-builtin
 from tensorflow.compiler.tests import xla_test
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_image_ops
 from tensorflow.python.ops import image_ops
@@ -775,7 +774,6 @@ class ResizeBilinearNonAlignCornersTest(xla_test.XLATestCase):
 
 class NonMaxSuppressionTest(xla_test.XLATestCase):
 
-  @test_util.disable_mlir_bridge("%1")
   def testNMS128From1024(self):
     num_boxes = 1024
     boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4")
@@ -810,7 +808,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
 
     self.assertEqual(indices_tf.size, max_output_size)
 
-  @test_util.disable_mlir_bridge("%1")
   def testNMS3From6Boxes(self):
     # Three boxes are selected based on IOU.
     boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
@@ -852,7 +849,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
     self.assertEqual(num_valid, 3)
     self.assertAllClose(indices_tf[:num_valid], [3, 0, 5])
 
-  @test_util.disable_mlir_bridge("%1")
   def testNMS3Then2WithScoreThresh(self):
     # Three boxes are selected based on IOU.
     # One is filtered out by score threshold.
@@ -895,7 +891,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
     self.assertEqual(num_valid, 2)
     self.assertAllClose(indices_tf[:num_valid], [3, 0])
 
-  @test_util.disable_mlir_bridge("%1")
   def testNMS3Then1WithScoreMaxThresh(self):
     # Three boxes are selected based on IOU.
     # One is filtered out by score threshold.
@@ -939,7 +934,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
     self.assertEqual(num_valid, 1)
     self.assertAllClose(indices_tf[:num_valid], [3])
 
-  @test_util.disable_mlir_bridge("%1")
   def testSelectFromContinuousOverLap(self):
     # Tests that a suppressed box does not itself suppress other boxes.
 
@@ -984,7 +978,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
 
 class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
 
-  @test_util.disable_mlir_bridge("%1")
   def testBatchedNMSFrom6(self):
     boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
                    [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
@@ -1022,7 +1015,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
                         indices_output)
     self.assertAllEqual([5, 4], num_valid_output)
 
-  @test_util.disable_mlir_bridge("%1")
   def testBatchedNMSFrom6Max3(self):
     boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
                    [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
@@ -1056,7 +1048,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
     self.assertAllEqual([[0, 1, 2], [0, 1, 3]], indices_output)
     self.assertAllEqual([3, 3], num_valid_output)
 
-  @test_util.disable_mlir_bridge("%1")
   def testBatchedNMSSingleFrom6Max3(self):
     boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
                   [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]]
@@ -1087,7 +1078,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
     self.assertAllEqual([0, 1, 2], indices_output)
     self.assertAllEqual(3, num_valid_output)
 
-  @test_util.disable_mlir_bridge("%1")
   def testBatchedNMSSingleFrom6NoPad(self):
     boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
                   [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]]
@@ -1117,7 +1107,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
     self.assertAllEqual([0, 1, 2, 4, 5], indices_output)
     self.assertAllEqual(5, num_valid_output)
 
-  @test_util.disable_mlir_bridge("%1")
   def testBatchedNMSBatchDimsFrom6Max3(self):
     boxes_data = [[[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
                     [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
@@ -1151,7 +1140,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
     self.assertAllEqual([[[0, 1, 2], [0, 1, 3]]], indices_output)
     self.assertAllEqual([[3, 3]], num_valid_output)
 
-  @test_util.disable_mlir_bridge("%1")
   def testBatchedNMSScoreThresholdFrom6Max3(self):
     boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
                    [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
@@ -1187,7 +1175,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
     self.assertAllEqual([3, 2], num_valid_output)
     self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output)
 
-  @test_util.disable_mlir_bridge("%1")
   def testBatchedNMSUnsortedInputFrom6(self):
     boxes_data = [[[0, 2, 1, 2], [3, 3, 4, 4], [0, 0, 1, 1],
                    [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8]],
@@ -1224,7 +1211,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
                         indices_output)
     self.assertAllEqual([5, 4], num_valid_output)
 
-  @test_util.disable_mlir_bridge("%1")
   def testBatchedNMSNoncanonicalizedInputFrom6(self):
     boxes_data = [[[1, 0, 0, 1], [4, 3, 3, 4], [1, 0.4, 0, 1.4],
                    [1, 0.6, 0, 1.6], [1, 0.8, 0, 1.8], [1, 2, 0, 2]],
@@ -1262,7 +1248,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
                         indices_output)
     self.assertAllEqual([5, 4], num_valid_output)
 
-  @test_util.disable_mlir_bridge("%1")
   def testBatchedNMSScoreThresholdCanInputsFrom6Max3(self):
     boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
                    [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
@@ -1298,7 +1283,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
     self.assertAllEqual([3, 2], num_valid_output)
     self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output)
 
-  @test_util.disable_mlir_bridge("%1")
   def testBatchedNMSFrom6DynamicInput(self):
     boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
                    [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
@@ -1728,8 +1728,6 @@ XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands,
                        const XlaComputation& comparator, int64 dimension,
                        bool is_stable) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
-    instr.set_is_stable(is_stable);
     std::vector<const Shape*> operand_shape_ptrs;
     TF_ASSIGN_OR_RETURN(std::vector<Shape> operand_shapes,
                         GetOperandShapes(operands));
@@ -1737,17 +1735,26 @@ XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands,
                       [](const Shape& shape) { return &shape; });
     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferVariadicOpShape(
                                          HloOpcode::kSort, operand_shape_ptrs));
-    *instr.mutable_shape() = shape.ToProto();
-    if (dimension == -1) {
-      TF_ASSIGN_OR_RETURN(const Shape* keys_shape, GetShapePtr(operands[0]));
-      dimension = keys_shape->rank() - 1;
-    }
-    instr.add_dimensions(dimension);
-    AddCalledComputation(comparator, &instr);
-    return AddInstruction(std::move(instr), HloOpcode::kSort, operands);
+    return SortInternal(shape, operands, comparator, dimension, is_stable);
   });
 }
 
+StatusOr<XlaOp> XlaBuilder::SortInternal(const Shape& shape,
+                                         absl::Span<const XlaOp> operands,
+                                         const XlaComputation& comparator,
+                                         int64 dimension, bool is_stable) {
+  HloInstructionProto instr;
+  *instr.mutable_shape() = shape.ToProto();
+  instr.set_is_stable(is_stable);
+  if (dimension == -1) {
+    TF_ASSIGN_OR_RETURN(const Shape* keys_shape, GetShapePtr(operands[0]));
+    dimension = keys_shape->rank() - 1;
+  }
+  instr.add_dimensions(dimension);
+  AddCalledComputation(comparator, &instr);
+  return AddInstruction(std::move(instr), HloOpcode::kSort, operands);
+}
+
 XlaOp XlaBuilder::ConvertElementType(XlaOp operand,
                                      PrimitiveType new_element_type) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
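One behavior the Sort hunk above preserves: `dimension == -1` is resolved against the first operand's rank inside the builder, selecting the minor-most dimension. A short sketch of the client-side contract (`SortMinorDim` is a hypothetical helper):

```cpp
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Sorts `operand` along its last dimension; for a rank-2 operand each
// row is sorted independently, since -1 resolves to rank - 1.
xla::XlaOp SortMinorDim(xla::XlaOp operand,
                        const xla::XlaComputation& comparator) {
  return xla::Sort({operand}, comparator, /*dimension=*/-1);
}
```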
@@ -1761,16 +1768,21 @@ XlaOp XlaBuilder::ConvertElementType(XlaOp operand,
 XlaOp XlaBuilder::BitcastConvertType(XlaOp operand,
                                      PrimitiveType new_element_type) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape(
                                          *operand_shape, new_element_type));
-    *instr.mutable_shape() = shape.ToProto();
-    return AddInstruction(std::move(instr), HloOpcode::kBitcastConvert,
-                          {operand});
+    return BitcastConvertTypeInternal(shape, operand);
   });
 }
 
+StatusOr<XlaOp> XlaBuilder::BitcastConvertTypeInternal(const Shape& shape,
+                                                       XlaOp operand) {
+  HloInstructionProto instr;
+  *instr.mutable_shape() = shape.ToProto();
+  return AddInstruction(std::move(instr), HloOpcode::kBitcastConvert,
+                        {operand});
+}
+
 XlaOp XlaBuilder::Clamp(XlaOp min, XlaOp operand, XlaOp max) {
   return TernaryOp(HloOpcode::kClamp, min, operand, max);
 }
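For context, what `BitcastConvertType` expresses: reinterpret each element's bits without value conversion, e.g. viewing f32 values as their u32 bit patterns (same bit width required). A minimal sketch using the real client API; `FloatBits` is a hypothetical helper:

```cpp
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Returns the raw bit patterns of a rank-1 f32 parameter as u32.
// Emitted as HloOpcode::kBitcastConvert here, or as mhlo.bitcast_convert
// under MlirHloBuilder.
xla::XlaOp FloatBits(xla::XlaBuilder* builder) {
  xla::XlaOp x = xla::Parameter(
      builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {4}), "x");
  return xla::BitcastConvertType(x, xla::U32);
}
```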
@@ -1892,8 +1904,6 @@ XlaOp XlaBuilder::RngBitGenerator(RandomAlgorithm algorithm,
 XlaOp XlaBuilder::While(const XlaComputation& condition,
                         const XlaComputation& body, XlaOp init) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
-
     // Infer shape.
     TF_ASSIGN_OR_RETURN(const auto& body_program_shape, body.GetProgramShape());
     TF_ASSIGN_OR_RETURN(const auto& condition_program_shape,
@@ -1902,14 +1912,22 @@ XlaOp XlaBuilder::While(const XlaComputation& condition,
     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferWhileShape(
                                          condition_program_shape,
                                          body_program_shape, *init_shape));
-    *instr.mutable_shape() = shape.ToProto();
-    // Body comes before condition computation in the vector.
-    AddCalledComputation(body, &instr);
-    AddCalledComputation(condition, &instr);
-    return AddInstruction(std::move(instr), HloOpcode::kWhile, {init});
+    return WhileInternal(shape, condition, body, init);
   });
 }
 
+StatusOr<XlaOp> XlaBuilder::WhileInternal(const Shape& shape,
+                                          const XlaComputation& condition,
+                                          const XlaComputation& body,
+                                          XlaOp init) {
+  HloInstructionProto instr;
+  *instr.mutable_shape() = shape.ToProto();
+  // Body comes before condition computation in the vector.
+  AddCalledComputation(body, &instr);
+  AddCalledComputation(condition, &instr);
+  return AddInstruction(std::move(instr), HloOpcode::kWhile, {init});
+}
+
 XlaOp XlaBuilder::Gather(XlaOp input, XlaOp start_indices,
                          const GatherDimensionNumbers& dimension_numbers,
                          absl::Span<const int64> slice_sizes,
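For reference, the client-side While API whose emission now funnels through `WhileInternal` in the hunk above: condition and body are built as separate computations over the loop-carried value. A minimal sketch (`CountToTen` is a hypothetical helper) using the real client API:

```cpp
#include <cstdint>

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Builds while(i < 10) { i = i + 1; } starting from i = 0.
xla::XlaOp CountToTen(xla::XlaBuilder* builder) {
  const xla::Shape s32 = xla::ShapeUtil::MakeShape(xla::S32, {});

  // Condition computation: i < 10 (root is the last op built).
  xla::XlaBuilder cond_builder("cond");
  xla::Lt(xla::Parameter(&cond_builder, 0, s32, "i"),
          xla::ConstantR0<int32_t>(&cond_builder, 10));
  xla::XlaComputation cond = cond_builder.Build().ConsumeValueOrDie();

  // Body computation: i + 1.
  xla::XlaBuilder body_builder("body");
  xla::Add(xla::Parameter(&body_builder, 0, s32, "i"),
           xla::ConstantR0<int32_t>(&body_builder, 1));
  xla::XlaComputation body = body_builder.Build().ConsumeValueOrDie();

  return xla::While(cond, body, xla::ConstantR0<int32_t>(builder, 0));
}
```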
@@ -639,6 +639,8 @@ class XlaBuilder {
   XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type);
 
   XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type);
+  virtual StatusOr<XlaOp> BitcastConvertTypeInternal(const Shape& shape,
+                                                     XlaOp operand);
 
   XlaOp Transpose(XlaOp operand, absl::Span<const int64> permutation);
   virtual StatusOr<XlaOp> TransposeInternal(
@@ -650,6 +652,10 @@ class XlaBuilder {
 
   XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
              int64 dimension = -1, bool is_stable = false);
+  virtual StatusOr<XlaOp> SortInternal(const Shape& shape,
+                                       absl::Span<const XlaOp> operands,
+                                       const XlaComputation& comparator,
+                                       int64 dimension, bool is_stable);
 
   XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max);
 
@@ -666,6 +672,9 @@ class XlaBuilder {
 
   XlaOp While(const XlaComputation& condition, const XlaComputation& body,
               XlaOp init);
+  virtual StatusOr<XlaOp> WhileInternal(const Shape& shape,
+                                        const XlaComputation& condition,
+                                        const XlaComputation& body, XlaOp init);
 
   XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
                     const XlaComputation& true_computation, XlaOp false_operand,