Legalize TensorFlow NonMaxSuppressionV4 and SelfAdjointEigV2Op ops to HLO

Added support for the HLO bitcast-convert, sort, and while ops in MlirHloBuilder, and enabled tests for NonMaxSuppressionV4 and SelfAdjointEigV2Op, whose legalizations use these ops.

PiperOrigin-RevId: 324360651
Change-Id: I300b67cfea37a1a4362cd543e8ba7c82b00273a7
Smit Hinsu, 2020-07-31 23:10:33 -07:00 (committed by TensorFlower Gardener)
parent 8a449bdb65
commit d3323e54e2
9 changed files with 111 additions and 39 deletions
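For context, a minimal client-side sketch of the sort path this change wires up. This uses the standard XLA builder API (CreateScalarGtComputation comes from xla/client/lib/comparators.h and is not part of this change); when the builder is an MlirHloBuilder, the new SortInternal override emits mhlo.sort and imports the comparator computation into its region:

#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Stable descending sort of one f32 operand along its last dimension.
xla::XlaOp StableDescendingSort(xla::XlaBuilder* builder, xla::XlaOp keys) {
  // The comparator is a separate XlaComputation; SortInternal turns it into
  // the comparator region of the mhlo.sort op.
  xla::XlaComputation gt = xla::CreateScalarGtComputation({xla::F32}, builder);
  return xla::Sort({keys}, gt, /*dimension=*/-1, /*is_stable=*/true);
}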

File: sink_constants_to_control_flow.cc

@@ -30,13 +30,14 @@ namespace {
// A pass that sinks constants implicitly captured in control flow regions. This
// is necessary to export to XLA.
//
// TODO(hinsu): Generalize this pass to handle all the ops with regions. Any
// value used within the region that is defined outside of the op's region
// should be sunk into the region, not just the constants. Ops such as If and
// While, whose computations don't require a fixed signature (unlike Sort or
// Reduce), have the option to pass outside values as operands of the op to
// avoid recomputing them internally. Note that doing so is the only option for
// values defined outside that are BlockArguments of any of the parent regions.
class SinkConstantsToControlFlowPass
: public mlir::PassWrapper<SinkConstantsToControlFlowPass, FunctionPass> {
void runOnFunction() override {
@@ -60,7 +61,7 @@ class SinkConstantsToControlFlowPass
visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) {
Value constant = use->get();
auto op = constant.getDefiningOp();
if (!op || !isa<ConstOp, ConstantOp>(op)) return;
if (!op || !op->hasTrait<OpTrait::ConstantLike>()) return;
auto map_entry = sunk_constant.try_emplace(constant, nullptr);
if (!map_entry.second) {
// This constant has already been cloned into the region, reuse it.
@@ -82,6 +83,8 @@ class SinkConstantsToControlFlowPass
} // anonymous namespace
// TODO(hinsu): Rename this pass and move to a different file along with the
// generalization to make all ops isolated from above.
std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass() {
return std::make_unique<SinkConstantsToControlFlowPass>();
}
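The core of the pass, condensed into a minimal sketch (assuming MLIR's RegionUtils API of this period; the real pass additionally caches one clone per constant via the sunk_constant map shown above):

#include "mlir/IR/Builders.h"
#include "mlir/Transforms/RegionUtils.h"

// Clone every constant-like op used inside `region` but defined outside it,
// and rewire the use to the clone, making the region self-contained.
void SinkConstantsInto(mlir::Region& region) {
  mlir::OpBuilder builder(&region.front(), region.front().begin());
  mlir::visitUsedValuesDefinedAbove({region}, [&](mlir::OpOperand* use) {
    mlir::Operation* def = use->get().getDefiningOp();
    // Only constant-like (recomputable, side-effect-free) ops are duplicated.
    if (!def || !def->hasTrait<mlir::OpTrait::ConstantLike>()) return;
    use->set(builder.clone(*def)->getResult(0));
  });
}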

File: mlir_hlo_builder.cc

@@ -206,6 +206,15 @@ XlaOp MlirHloBuilder::Iota(const Shape& shape, int64 iota_dimension) {
});
}
StatusOr<XlaOp> MlirHloBuilder::BitcastConvertTypeInternal(const Shape& shape,
XlaOp operand) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::mhlo::BitcastConvertOp>(loc_, ty,
GetValue(operand));
return MakeXlaOp(op);
}
StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
@@ -224,6 +233,31 @@ StatusOr<XlaOp> MlirHloBuilder::RevInternal(
return MakeXlaOp(op);
}
StatusOr<XlaOp> MlirHloBuilder::SortInternal(const Shape& shape,
absl::Span<const XlaOp> operands,
const XlaComputation& comparator,
int64 dimension, bool is_stable) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::mhlo::SortOp>(
loc_, ty, GetValues(operands), builder_.getI64IntegerAttr(dimension),
builder_.getBoolAttr(is_stable));
TF_RETURN_IF_ERROR(ImportComputation(comparator.proto(), &op.comparator()));
return MakeXlaOp(op);
}
StatusOr<XlaOp> MlirHloBuilder::WhileInternal(const Shape& shape,
const XlaComputation& condition,
const XlaComputation& body,
XlaOp init) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::mhlo::WhileOp>(loc_, ty, GetValue(init));
TF_RETURN_IF_ERROR(ImportComputation(condition.proto(), &op.cond()));
TF_RETURN_IF_ERROR(ImportComputation(body.proto(), &op.body()));
return MakeXlaOp(op);
}
StatusOr<XlaOp> MlirHloBuilder::GatherInternal(
const Shape& shape, XlaOp input, XlaOp start_indices,
const GatherDimensionNumbers& dimension_numbers,

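A hedged sketch of what the new WhileInternal makes expressible end to end (standard XlaBuilder client API, not code from this change): a loop counting a scalar up to 10, whose condition and body computations WhileInternal imports into the cond and body regions of mhlo.while:

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

xla::XlaOp LoopToTen(xla::XlaBuilder* b) {
  const xla::Shape s32 = xla::ShapeUtil::MakeShape(xla::S32, {});
  // Condition computation: x < 10.
  auto cond_b = b->CreateSubBuilder("cond");
  xla::Lt(xla::Parameter(cond_b.get(), 0, s32, "x"),
          xla::ConstantR0<int32_t>(cond_b.get(), 10));
  xla::XlaComputation cond = cond_b->Build().ValueOrDie();
  // Body computation: x + 1.
  auto body_b = b->CreateSubBuilder("body");
  xla::Add(xla::Parameter(body_b.get(), 0, s32, "x"),
           xla::ConstantR0<int32_t>(body_b.get(), 1));
  xla::XlaComputation body = body_b->Build().ValueOrDie();
  return xla::While(cond, body, xla::ConstantR0<int32_t>(b, 0));
}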
File: mlir_hlo_builder.h

@@ -142,6 +142,9 @@ class MlirHloBuilder : public XlaBuilder {
XlaOp Iota(const Shape& shape, int64 iota_dimension) override;
StatusOr<XlaOp> BitcastConvertTypeInternal(const Shape& shape,
XlaOp operand) override;
StatusOr<XlaOp> TransposeInternal(
const Shape& shape, XlaOp operand,
absl::Span<const int64> permutation) override;
@@ -149,6 +152,16 @@ class MlirHloBuilder : public XlaBuilder {
StatusOr<XlaOp> RevInternal(const Shape& shape, XlaOp operand,
absl::Span<const int64> dimensions) override;
StatusOr<XlaOp> SortInternal(const Shape& shape,
absl::Span<const XlaOp> operands,
const XlaComputation& comparator,
int64 dimension, bool is_stable) override;
StatusOr<XlaOp> WhileInternal(const Shape& shape,
const XlaComputation& condition,
const XlaComputation& body,
XlaOp init) override;
StatusOr<XlaOp> GatherInternal(
const Shape& shape, XlaOp input, XlaOp start_indices,
const GatherDimensionNumbers& dimension_numbers,

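The third new override, BitcastConvertTypeInternal, is reached through the existing public entry point; a one-function sketch (standard client API):

#include "tensorflow/compiler/xla/client/xla_builder.h"

// Reinterpret the bits of an f32 value as u32; the shape is unchanged. Under
// MlirHloBuilder this now lowers to mhlo.bitcast_convert.
xla::XlaOp BitsOf(xla::XlaOp x_f32) {
  return xla::BitcastConvertType(x_f32, xla::U32);
}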
File: legalize-tf-with-tf2xla.mlir

@@ -257,6 +257,14 @@ func @arg_min(%arg0: tensor<6xf64>) -> tensor<i32> {
return %1 : tensor<i32>
}
// CHECK-LABEL: non_max_suppression_v4
func @non_max_suppression_v4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<2xi32> {
%max_size = mhlo.constant dense<2> : tensor<i32>
// CHECK-NOT: tf.NonMaxSuppressionV4
%0:2 = "tf.NonMaxSuppressionV4"(%arg0, %arg1, %max_size, %arg2, %arg3) {pad_to_max_output_size = true}: (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
return %0#0 : tensor<2xi32>
}
// TODO(hinsu): Add a test with a valid TF op for which a tf2xla kernel is
// available but doesn't support this instance.
}

File: legalize_tf_with_tf2xla.cc

@@ -159,6 +159,7 @@ static bool IsOpAllowlisted(Operation* op) {
TypeID::get<TF::MirrorPadOp>(),
TypeID::get<TF::MulOp>(),
TypeID::get<TF::NegOp>(),
TypeID::get<TF::NonMaxSuppressionV4Op>(),
TypeID::get<TF::NotEqualOp>(),
TypeID::get<TF::PadOp>(),
TypeID::get<TF::PlaceholderWithDefaultOp>(),
@@ -178,6 +179,7 @@ static bool IsOpAllowlisted(Operation* op) {
TypeID::get<TF::RintOp>(),
TypeID::get<TF::RoundOp>(),
TypeID::get<TF::SelectV2Op>(),
TypeID::get<TF::SelfAdjointEigV2Op>(),
TypeID::get<TF::SeluGradOp>(),
TypeID::get<TF::SeluOp>(),
TypeID::get<TF::SigmoidGradOp>(),

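For readers of the allowlist: membership is a TypeID set lookup, so enabling an op for the tf2xla fallback is a one-line addition. A sketch of the lookup, assuming MLIR's AbstractOperation API of this period (op list abridged):

static bool IsOpAllowlisted(mlir::Operation* op) {
  static auto* ops = new llvm::SmallDenseSet<mlir::TypeID, 512>{
      mlir::TypeID::get<mlir::TF::NonMaxSuppressionV4Op>(),
      mlir::TypeID::get<mlir::TF::SelfAdjointEigV2Op>(),
      // ... the remaining allowlisted ops ...
  };
  auto* abstract_op = op->getAbstractOperation();
  return abstract_op && ops->count(abstract_op->typeID);
}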
File: tensorflow/compiler/tests/BUILD

@@ -324,6 +324,7 @@ tf_xla_py_test(
name = "self_adjoint_eig_op_test",
size = "medium",
srcs = ["self_adjoint_eig_op_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip

File: image_ops_test.py

@@ -30,7 +30,6 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_image_ops
from tensorflow.python.ops import image_ops
@@ -775,7 +774,6 @@ class ResizeBilinearNonAlignCornersTest(xla_test.XLATestCase):
class NonMaxSuppressionTest(xla_test.XLATestCase):
@test_util.disable_mlir_bridge("%1")
def testNMS128From1024(self):
num_boxes = 1024
boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4")
@@ -810,7 +808,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
self.assertEqual(indices_tf.size, max_output_size)
@test_util.disable_mlir_bridge("%1")
def testNMS3From6Boxes(self):
# Three boxes are selected based on IOU.
boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
@@ -852,7 +849,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
self.assertEqual(num_valid, 3)
self.assertAllClose(indices_tf[:num_valid], [3, 0, 5])
@test_util.disable_mlir_bridge("%1")
def testNMS3Then2WithScoreThresh(self):
# Three boxes are selected based on IOU.
# One is filtered out by score threshold.
@@ -895,7 +891,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
self.assertEqual(num_valid, 2)
self.assertAllClose(indices_tf[:num_valid], [3, 0])
@test_util.disable_mlir_bridge("%1")
def testNMS3Then1WithScoreMaxThresh(self):
# Three boxes are selected based on IOU.
# One is filtered out by score threshold.
@@ -939,7 +934,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
self.assertEqual(num_valid, 1)
self.assertAllClose(indices_tf[:num_valid], [3])
@test_util.disable_mlir_bridge("%1")
def testSelectFromContinuousOverLap(self):
# Tests that a suppressed box does not itself suppress other boxes.
@@ -984,7 +978,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
@test_util.disable_mlir_bridge("%1")
def testBatchedNMSFrom6(self):
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
@@ -1022,7 +1015,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
indices_output)
self.assertAllEqual([5, 4], num_valid_output)
@test_util.disable_mlir_bridge("%1")
def testBatchedNMSFrom6Max3(self):
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
@@ -1056,7 +1048,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
self.assertAllEqual([[0, 1, 2], [0, 1, 3]], indices_output)
self.assertAllEqual([3, 3], num_valid_output)
@test_util.disable_mlir_bridge("%1")
def testBatchedNMSSingleFrom6Max3(self):
boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]]
@@ -1087,7 +1078,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
self.assertAllEqual([0, 1, 2], indices_output)
self.assertAllEqual(3, num_valid_output)
@test_util.disable_mlir_bridge("%1")
def testBatchedNMSSingleFrom6NoPad(self):
boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]]
@@ -1117,7 +1107,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
self.assertAllEqual([0, 1, 2, 4, 5], indices_output)
self.assertAllEqual(5, num_valid_output)
@test_util.disable_mlir_bridge("%1")
def testBatchedNMSBatchDimsFrom6Max3(self):
boxes_data = [[[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
@@ -1151,7 +1140,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
self.assertAllEqual([[[0, 1, 2], [0, 1, 3]]], indices_output)
self.assertAllEqual([[3, 3]], num_valid_output)
@test_util.disable_mlir_bridge("%1")
def testBatchedNMSScoreThresholdFrom6Max3(self):
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
@@ -1187,7 +1175,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
self.assertAllEqual([3, 2], num_valid_output)
self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output)
@test_util.disable_mlir_bridge("%1")
def testBatchedNMSUnsortedInputFrom6(self):
boxes_data = [[[0, 2, 1, 2], [3, 3, 4, 4], [0, 0, 1, 1],
[0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8]],
@@ -1224,7 +1211,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
indices_output)
self.assertAllEqual([5, 4], num_valid_output)
@test_util.disable_mlir_bridge("%1")
def testBatchedNMSNoncanonicalizedInputFrom6(self):
boxes_data = [[[1, 0, 0, 1], [4, 3, 3, 4], [1, 0.4, 0, 1.4],
[1, 0.6, 0, 1.6], [1, 0.8, 0, 1.8], [1, 2, 0, 2]],
@@ -1262,7 +1248,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
indices_output)
self.assertAllEqual([5, 4], num_valid_output)
@test_util.disable_mlir_bridge("%1")
def testBatchedNMSScoreThresholdCanInputsFrom6Max3(self):
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
@@ -1298,7 +1283,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
self.assertAllEqual([3, 2], num_valid_output)
self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output)
@test_util.disable_mlir_bridge("%1")
def testBatchedNMSFrom6DynamicInput(self):
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],

File: xla_builder.cc

@@ -1728,8 +1728,6 @@ XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands,
const XlaComputation& comparator, int64 dimension,
bool is_stable) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
instr.set_is_stable(is_stable);
std::vector<const Shape*> operand_shape_ptrs;
TF_ASSIGN_OR_RETURN(std::vector<Shape> operand_shapes,
GetOperandShapes(operands));
@@ -1737,17 +1735,26 @@ XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands,
[](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferVariadicOpShape(
HloOpcode::kSort, operand_shape_ptrs));
*instr.mutable_shape() = shape.ToProto();
if (dimension == -1) {
TF_ASSIGN_OR_RETURN(const Shape* keys_shape, GetShapePtr(operands[0]));
dimension = keys_shape->rank() - 1;
}
instr.add_dimensions(dimension);
AddCalledComputation(comparator, &instr);
return AddInstruction(std::move(instr), HloOpcode::kSort, operands);
return SortInternal(shape, operands, comparator, dimension, is_stable);
});
}
StatusOr<XlaOp> XlaBuilder::SortInternal(const Shape& shape,
absl::Span<const XlaOp> operands,
const XlaComputation& comparator,
int64 dimension, bool is_stable) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
instr.set_is_stable(is_stable);
if (dimension == -1) {
TF_ASSIGN_OR_RETURN(const Shape* keys_shape, GetShapePtr(operands[0]));
dimension = keys_shape->rank() - 1;
}
instr.add_dimensions(dimension);
AddCalledComputation(comparator, &instr);
return AddInstruction(std::move(instr), HloOpcode::kSort, operands);
}
XlaOp XlaBuilder::ConvertElementType(XlaOp operand,
PrimitiveType new_element_type) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@@ -1761,16 +1768,21 @@ XlaOp XlaBuilder::ConvertElementType(XlaOp operand,
XlaOp XlaBuilder::BitcastConvertType(XlaOp operand,
PrimitiveType new_element_type) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape(
*operand_shape, new_element_type));
*instr.mutable_shape() = shape.ToProto();
return AddInstruction(std::move(instr), HloOpcode::kBitcastConvert,
{operand});
return BitcastConvertTypeInternal(shape, operand);
});
}
StatusOr<XlaOp> XlaBuilder::BitcastConvertTypeInternal(const Shape& shape,
XlaOp operand) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
return AddInstruction(std::move(instr), HloOpcode::kBitcastConvert,
{operand});
}
XlaOp XlaBuilder::Clamp(XlaOp min, XlaOp operand, XlaOp max) {
return TernaryOp(HloOpcode::kClamp, min, operand, max);
}
@@ -1892,8 +1904,6 @@ XlaOp XlaBuilder::RngBitGenerator(RandomAlgorithm algorithm,
XlaOp XlaBuilder::While(const XlaComputation& condition,
const XlaComputation& body, XlaOp init) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
// Infer shape.
TF_ASSIGN_OR_RETURN(const auto& body_program_shape, body.GetProgramShape());
TF_ASSIGN_OR_RETURN(const auto& condition_program_shape,
@@ -1902,14 +1912,22 @@ XlaOp XlaBuilder::While(const XlaComputation& condition,
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferWhileShape(
condition_program_shape,
body_program_shape, *init_shape));
*instr.mutable_shape() = shape.ToProto();
// Body comes before condition computation in the vector.
AddCalledComputation(body, &instr);
AddCalledComputation(condition, &instr);
return AddInstruction(std::move(instr), HloOpcode::kWhile, {init});
return WhileInternal(shape, condition, body, init);
});
}
StatusOr<XlaOp> XlaBuilder::WhileInternal(const Shape& shape,
const XlaComputation& condition,
const XlaComputation& body,
XlaOp init) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
// Body comes before condition computation in the vector.
AddCalledComputation(body, &instr);
AddCalledComputation(condition, &instr);
return AddInstruction(std::move(instr), HloOpcode::kWhile, {init});
}
XlaOp XlaBuilder::Gather(XlaOp input, XlaOp start_indices,
const GatherDimensionNumbers& dimension_numbers,
absl::Span<const int64> slice_sizes,

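The refactoring above applies one pattern throughout; a sketch with a hypothetical op Foo (InferFooShape and HloOpcode::kFoo are placeholders, not real API). The public method keeps error wrapping and shape inference, while the virtual *Internal hook owns instruction construction, which is exactly the seam MlirHloBuilder overrides to emit MLIR instead of HLO protos:

XlaOp XlaBuilder::Foo(XlaOp operand) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferFooShape(*operand_shape));
    return FooInternal(shape, operand);  // virtual; overridable by subclasses
  });
}

StatusOr<XlaOp> XlaBuilder::FooInternal(const Shape& shape, XlaOp operand) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  return AddInstruction(std::move(instr), HloOpcode::kFoo, {operand});
}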
File: xla_builder.h

@@ -639,6 +639,8 @@ class XlaBuilder {
XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type);
XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type);
virtual StatusOr<XlaOp> BitcastConvertTypeInternal(const Shape& shape,
XlaOp operand);
XlaOp Transpose(XlaOp operand, absl::Span<const int64> permutation);
virtual StatusOr<XlaOp> TransposeInternal(
@@ -650,6 +652,10 @@ class XlaBuilder {
XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
int64 dimension = -1, bool is_stable = false);
virtual StatusOr<XlaOp> SortInternal(const Shape& shape,
absl::Span<const XlaOp> operands,
const XlaComputation& comparator,
int64 dimension, bool is_stable);
XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max);
@@ -666,6 +672,9 @@ class XlaBuilder {
XlaOp While(const XlaComputation& condition, const XlaComputation& body,
XlaOp init);
virtual StatusOr<XlaOp> WhileInternal(const Shape& shape,
const XlaComputation& condition,
const XlaComputation& body, XlaOp init);
XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
const XlaComputation& true_computation, XlaOp false_operand,