From fd934e4895ae8ee3697da1d558fc8ecb6ba1205f Mon Sep 17 00:00:00 2001
From: Smit Hinsu
Date: Fri, 12 Jun 2020 17:51:44 -0700
Subject: [PATCH] Create Iota, Rev and Fft HLO ops with MlirHloBuilder

White-list some of the ops enabled through this and enable corresponding
compiler tests.

PiperOrigin-RevId: 316209927
Change-Id: I4f12a197425d3b1766a12c29d06e64f78caec307
---
 .../compiler/mlir/xla/ir/mlir_hlo_builder.cc  | 33 ++++++++++++++++
 .../compiler/mlir/xla/ir/mlir_hlo_builder.h   |  9 +++++
 .../xla/tests/legalize-tf-with-tf2xla.mlir    | 22 +++++++++++
 .../xla/transforms/legalize_tf_with_tf2xla.cc | 13 +++++++
 tensorflow/compiler/tests/BUILD               |  2 +
 tensorflow/compiler/tests/binary_ops_test.py  |  4 --
 tensorflow/compiler/xla/client/xla_builder.cc | 39 ++++++++++++-------
 tensorflow/compiler/xla/client/xla_builder.h  |  7 +++-
 8 files changed, 110 insertions(+), 19 deletions(-)

diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
index d98e6375f7e..63a277516ac 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
@@ -120,6 +120,30 @@ StatusOr<XlaOp> MlirHloBuilder::ConvGeneralDilatedInternal(
   return MakeXlaOp(op);
 }
 
+StatusOr<XlaOp> MlirHloBuilder::FftInternal(
+    const Shape& shape, XlaOp operand, FftType fft_type,
+    absl::Span<const int64> fft_length) {
+  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
+                                         shape, builder_));
+  auto op = builder_.create<mlir::xla_hlo::FftOp>(
+      loc_, ty, GetValue(operand),
+      builder_.getStringAttr(FftType_Name(fft_type)),
+      GetI64ElementsAttr(fft_length, &builder_));
+  return MakeXlaOp(op);
+}
+
+XlaOp MlirHloBuilder::Iota(const Shape& shape, int64 iota_dimension) {
+  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    TF_ASSIGN_OR_RETURN(
+        mlir::Type ty,
+        ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
+    auto op = builder_.create<mlir::xla_hlo::IotaOp>(
+        loc_, ty,
+        builder_.getIntegerAttr(builder_.getI64Type(), iota_dimension));
+    return MakeXlaOp(op);
+  });
+}
+
 StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
     const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) {
   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                          shape, builder_));
@@ -129,6 +153,15 @@ StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
   return MakeXlaOp(op);
 }
 
+StatusOr<XlaOp> MlirHloBuilder::RevInternal(
+    const Shape& shape, XlaOp operand, absl::Span<const int64> dimensions) {
+  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
+                                         shape, builder_));
+  auto op = builder_.create<mlir::xla_hlo::ReverseOp>(
+      loc_, ty, GetValue(operand), GetI64ElementsAttr(dimensions, &builder_));
+  return MakeXlaOp(op);
+}
+
 StatusOr<XlaOp> MlirHloBuilder::GatherInternal(
     const Shape& shape, XlaOp input, XlaOp start_indices,
     const GatherDimensionNumbers& dimension_numbers,
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
index 0b6bacbfff6..7d93f0b1eae 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
@@ -120,10 +120,19 @@ class MlirHloBuilder : public XlaBuilder {
       int64 feature_group_count, int64 batch_group_count,
       const PrecisionConfig* precision_config) override;
 
+  StatusOr<XlaOp> FftInternal(const Shape& shape, XlaOp operand,
+                              FftType fft_type,
+                              absl::Span<const int64> fft_length) override;
+
+  XlaOp Iota(const Shape& shape, int64 iota_dimension) override;
+
   StatusOr<XlaOp> TransposeInternal(
       const Shape& shape, XlaOp operand,
       absl::Span<const int64> permutation) override;
 
+  StatusOr<XlaOp> RevInternal(const Shape& shape, XlaOp operand,
+                              absl::Span<const int64> dimensions) override;
+
   StatusOr<XlaOp> GatherInternal(
       const Shape& shape, XlaOp input, XlaOp start_indices,
      const GatherDimensionNumbers& dimension_numbers,
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
index 3f99d71494e..d7c92b95b40 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
@@ -214,6 +214,28 @@ func @sparse_to_dense(%arg0: tensor<3x2xi32>, %arg1: tensor<3xf32>, %arg2: tensor<f32>) -> tensor<3x3xf32> {
   return %0 : tensor<3x3xf32>
 }
 
+// CHECK-LABEL: fft
+func @fft(%arg0: tensor<3x5x8xcomplex<f32>>) -> tensor<3x5x8xcomplex<f32>> {
+  // CHECK: "xla_hlo.fft"(%arg0)
+  %0 = "tf.FFT"(%arg0) : (tensor<3x5x8xcomplex<f32>>) -> tensor<3x5x8xcomplex<f32>>
+  return %0 : tensor<3x5x8xcomplex<f32>>
+}
+
+// CHECK-LABEL: reverse_sequence
+func @reverse_sequence(%arg0: tensor<4x2x3x1x1xi32>, %arg1: tensor<3xi32>) -> tensor<4x2x3x1x1xi32> {
+  // CHECK-NOT: tf.ReverseSequence
+  %0 = "tf.ReverseSequence"(%arg0, %arg1) {batch_dim = 2 : i64, seq_dim = 0 : i64}: (tensor<4x2x3x1x1xi32>, tensor<3xi32>) -> tensor<4x2x3x1x1xi32>
+  return %0 : tensor<4x2x3x1x1xi32>
+}
+
+// CHECK-LABEL: mirror_pad
+func @mirror_pad(%arg0: tensor<2x3xcomplex<f64>>) -> tensor<4x7xcomplex<f64>> {
+  %0 = xla_hlo.constant dense<[[1, 1], [2, 2]]> : tensor<2x2xi32>
+  // CHECK-NOT: tf.MirrorPad
+  %1 = "tf.MirrorPad"(%arg0, %0) {mode = "SYMMETRIC"} : (tensor<2x3xcomplex<f64>>, tensor<2x2xi32>) -> tensor<4x7xcomplex<f64>>
+  return %1 : tensor<4x7xcomplex<f64>>
+}
+
 // TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
 // available but doesn't support this instance.
 }
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
index 659cbbe8ebc..477bc654914 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
@@ -116,11 +116,20 @@ static bool IsOpWhitelisted(Operation* op) {
     TypeID::get(),
     TypeID::get(),
     TypeID::get(),
+    TypeID::get(),
+    TypeID::get(),
+    TypeID::get(),
     TypeID::get(),
     TypeID::get(),
     TypeID::get(),
     TypeID::get(),
     TypeID::get(),
+    TypeID::get(),
+    TypeID::get(),
+    TypeID::get(),
+    TypeID::get(),
+    TypeID::get(),
+    TypeID::get(),
     TypeID::get(),
     TypeID::get(),
     TypeID::get(),
@@ -134,16 +143,20 @@ static bool IsOpWhitelisted(Operation* op) {
     TypeID::get(),
     TypeID::get(),
     TypeID::get(),
+    TypeID::get(),
     TypeID::get(),
     TypeID::get(),
     TypeID::get(),
     TypeID::get(),
     TypeID::get(),
     TypeID::get(),
+    TypeID::get(),
+    TypeID::get(),
     TypeID::get(),
     TypeID::get(),
     TypeID::get(),
     TypeID::get(),
+    TypeID::get(),
     TypeID::get(),
     TypeID::get(),
     TypeID::get(),
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 595bef42a5a..c9876035da9 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -692,6 +692,7 @@ tf_xla_py_test(
     name = "fft_test",
     size = "medium",
     srcs = ["fft_test.py"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     shard_count = 6,
     tags = [
@@ -1129,6 +1130,7 @@ tf_xla_py_test(
     name = "reverse_sequence_op_test",
     size = "medium",
     srcs = ["reverse_sequence_op_test.py"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 70390bc6cda..789309bb3bc 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -1225,8 +1225,6 @@ class BinaryOpsTest(xla_test.XLATestCase):
                 [7, 7, 7, 7, 7, 7]],
                dtype=dtype))
 
-  @test_util.disable_mlir_bridge(
-      "Requires concatenate op support in MlirHloBuilder")
   def testSymmetricMirrorPad(self):
     mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "SYMMETRIC")
     for dtype in self.numeric_types:
@@ -1258,8 +1256,6 @@ class BinaryOpsTest(xla_test.XLATestCase):
           np.array([[0, 0], [0, 0]], dtype=np.int32),
           expected=np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype))
 
-  @test_util.disable_mlir_bridge(
-      "Requires concatenate op support in MlirHloBuilder")
   def testReflectMirrorPad(self):
     mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "REFLECT")
     for dtype in self.numeric_types:
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 440fbebfa5e..77556a72442 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -1323,20 +1323,26 @@ StatusOr<XlaOp> XlaBuilder::ConvGeneralDilatedInternal(
 XlaOp XlaBuilder::Fft(XlaOp operand, const FftType fft_type,
                       const absl::Span<const int64> fft_length) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferFftShape(
                                          *operand_shape, fft_type, fft_length));
-    *instr.mutable_shape() = shape.ToProto();
-    instr.set_fft_type(fft_type);
-    for (int64 i : fft_length) {
-      instr.add_fft_length(i);
-    }
-
-    return AddInstruction(std::move(instr), HloOpcode::kFft, {operand});
+    return FftInternal(shape, operand, fft_type, fft_length);
   });
 }
 
+StatusOr<XlaOp> XlaBuilder::FftInternal(
+    const Shape& shape, XlaOp operand, const FftType fft_type,
+    const absl::Span<const int64> fft_length) {
+  HloInstructionProto instr;
+  *instr.mutable_shape() = shape.ToProto();
+  instr.set_fft_type(fft_type);
+  for (int64 i : fft_length) {
+    instr.add_fft_length(i);
+  }
+
+  return AddInstruction(std::move(instr), HloOpcode::kFft, {operand});
+}
+
 XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     HloInstructionProto instr;
@@ -1664,18 +1670,23 @@ StatusOr<XlaOp> XlaBuilder::TransposeInternal(
 XlaOp XlaBuilder::Rev(XlaOp operand, absl::Span<const int64> dimensions) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReverseShape(
                                          *operand_shape, dimensions));
-    *instr.mutable_shape() = shape.ToProto();
-    for (int64 dim : dimensions) {
-      instr.add_dimensions(dim);
-    }
-    return AddInstruction(std::move(instr), HloOpcode::kReverse, {operand});
+    return RevInternal(shape, operand, dimensions);
   });
 }
 
+StatusOr<XlaOp> XlaBuilder::RevInternal(const Shape& shape, XlaOp operand,
+                                        absl::Span<const int64> dimensions) {
+  HloInstructionProto instr;
+  *instr.mutable_shape() = shape.ToProto();
+  for (int64 dim : dimensions) {
+    instr.add_dimensions(dim);
+  }
+  return AddInstruction(std::move(instr), HloOpcode::kReverse, {operand});
+}
+
 XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands,
                        const XlaComputation& comparator, int64 dimension,
                        bool is_stable) {
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 82f8cdbabce..d21ae66d365 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -502,6 +502,9 @@ class XlaBuilder {
   XlaOp Fft(XlaOp operand, FftType fft_type,
             absl::Span<const int64> fft_length);
+  virtual StatusOr<XlaOp> FftInternal(const Shape& shape, XlaOp operand,
+                                      FftType fft_type,
+                                      absl::Span<const int64> fft_length);
 
   XlaOp Infeed(const Shape& shape, const string& config = "");
   XlaOp InfeedWithToken(XlaOp token, const Shape& shape, const string& config);
@@ -594,7 +597,7 @@ class XlaBuilder {
       absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
       XlaOp init_value, const XlaComputation& scatter);
 
-  XlaOp Iota(const Shape& shape, int64 iota_dimension);
+  virtual XlaOp Iota(const Shape& shape, int64 iota_dimension);
 
   XlaOp Iota(PrimitiveType type, int64 size);
 
@@ -607,6 +610,8 @@ class XlaBuilder {
       const Shape& shape, XlaOp operand, absl::Span<const int64> permutation);
 
   XlaOp Rev(XlaOp operand, absl::Span<const int64> dimensions);
+  virtual StatusOr<XlaOp> RevInternal(const Shape& shape, XlaOp operand,
+                                      absl::Span<const int64> dimensions);
 
   XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
              int64 dimension = -1, bool is_stable = false);
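
Note on the pattern this patch extends: XlaBuilder keeps the shared
shape-inference step in the public entry points (Fft, Rev, and the now-virtual
Iota) and moves instruction emission into virtual *Internal hooks, which
MlirHloBuilder overrides to emit MLIR operations instead of
HloInstructionProtos. The standalone C++ sketch below illustrates only that
layering; the Builder, ProtoBuilder, MlirBuilder, Op, and Shape names are
invented stand-ins for illustration and are not TensorFlow APIs.

#include <cstdint>
#include <iostream>
#include <vector>

// Toy stand-ins for xla::XlaOp and xla::Shape; illustrative only.
struct Op { int id; };
struct Shape { std::vector<int64_t> dims; };

class Builder {
 public:
  virtual ~Builder() = default;

  // Public entry point: does the shared work (shape inference in the real
  // XlaBuilder::Rev), then delegates emission to the virtual hook.
  Op Rev(Op operand, const std::vector<int64_t>& dimensions) {
    Shape inferred;  // Real code: ShapeInference::InferReverseShape(...).
    return RevInternal(inferred, operand, dimensions);
  }

 protected:
  // Subclasses override only the emission step.
  virtual Op RevInternal(const Shape& shape, Op operand,
                         const std::vector<int64_t>& dimensions) = 0;
};

// Mirrors XlaBuilder's default path: serialize an instruction "proto".
class ProtoBuilder : public Builder {
 protected:
  Op RevInternal(const Shape& /*shape*/, Op operand,
                 const std::vector<int64_t>& dims) override {
    std::cout << "proto: kReverse of op " << operand.id << " over "
              << dims.size() << " dims\n";
    return Op{operand.id + 1};
  }
};

// Mirrors MlirHloBuilder: emit an MLIR operation (here just a printout).
class MlirBuilder : public Builder {
 protected:
  Op RevInternal(const Shape& /*shape*/, Op operand,
                 const std::vector<int64_t>& dims) override {
    std::cout << "mlir: xla_hlo.reverse of %" << operand.id << " over "
              << dims.size() << " dims\n";
    return Op{operand.id + 1};
  }
};

int main() {
  ProtoBuilder proto;
  MlirBuilder mlir;
  // Callers (e.g., tf2xla kernels) use the same Rev entry point either way.
  proto.Rev(Op{0}, {0, 2});
  mlir.Rev(Op{0}, {0, 2});
  return 0;
}

This split is why the patch only needs to add FftInternal/RevInternal
overrides and a virtual Iota for the existing tf2xla kernels to start
producing MLIR through MlirHloBuilder, with no change to calling code.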