Create Iota, Rev and Fft HLO ops with MlirHloBuilder

White-list some of the ops enabled through this and enable corresponding compiler tests.

PiperOrigin-RevId: 316209927
Change-Id: I4f12a197425d3b1766a12c29d06e64f78caec307
This commit is contained in:
Smit Hinsu 2020-06-12 17:51:44 -07:00 committed by TensorFlower Gardener
parent fe34364219
commit fd934e4895
8 changed files with 110 additions and 19 deletions

View File

@ -120,6 +120,30 @@ StatusOr<XlaOp> MlirHloBuilder::ConvGeneralDilatedInternal(
return MakeXlaOp(op);
}
StatusOr<XlaOp> MlirHloBuilder::FftInternal(
const Shape& shape, XlaOp operand, FftType fft_type,
absl::Span<const int64> fft_length) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::xla_hlo::FftOp>(
loc_, ty, GetValue(operand),
builder_.getStringAttr(FftType_Name(fft_type)),
GetI64ElementsAttr(fft_length, &builder_));
return MakeXlaOp(op);
}
XlaOp MlirHloBuilder::Iota(const Shape& shape, int64 iota_dimension) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(
mlir::Type ty,
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
auto op = builder_.create<mlir::xla_hlo::IotaOp>(
loc_, ty,
builder_.getIntegerAttr(builder_.getI64Type(), iota_dimension));
return MakeXlaOp(op);
});
}
StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
@ -129,6 +153,15 @@ StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
return MakeXlaOp(op);
}
StatusOr<XlaOp> MlirHloBuilder::RevInternal(
const Shape& shape, XlaOp operand, absl::Span<const int64> dimensions) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::xla_hlo::ReverseOp>(
loc_, ty, GetValue(operand), GetI64ElementsAttr(dimensions, &builder_));
return MakeXlaOp(op);
}
StatusOr<XlaOp> MlirHloBuilder::GatherInternal(
const Shape& shape, XlaOp input, XlaOp start_indices,
const GatherDimensionNumbers& dimension_numbers,

View File

@ -120,10 +120,19 @@ class MlirHloBuilder : public XlaBuilder {
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config) override;
StatusOr<XlaOp> FftInternal(const Shape& shape, XlaOp operand,
FftType fft_type,
absl::Span<const int64> fft_length) override;
XlaOp Iota(const Shape& shape, int64 iota_dimension) override;
StatusOr<XlaOp> TransposeInternal(
const Shape& shape, XlaOp operand,
absl::Span<const int64> permutation) override;
StatusOr<XlaOp> RevInternal(const Shape& shape, XlaOp operand,
absl::Span<const int64> dimensions) override;
StatusOr<XlaOp> GatherInternal(
const Shape& shape, XlaOp input, XlaOp start_indices,
const GatherDimensionNumbers& dimension_numbers,

View File

@ -214,6 +214,28 @@ func @sparse_to_dense(%arg0: tensor<3x2xi32>, %arg1: tensor<3xf32>, %arg2: tenso
return %0 : tensor<3x3xf32>
}
// CHECK-LABEL: fft
func @fft(%arg0: tensor<3x5x8xcomplex<f32>>) -> tensor<3x5x8xcomplex<f32>> {
// CHECK: "xla_hlo.fft"(%arg0)
%0 = "tf.FFT"(%arg0) : (tensor<3x5x8xcomplex<f32>>) -> tensor<3x5x8xcomplex<f32>>
return %0 : tensor<3x5x8xcomplex<f32>>
}
// CHECK-LABEL: reverse_sequence
func @reverse_sequence(%arg0: tensor<4x2x3x1x1xi32>, %arg1: tensor<3xi32>) -> tensor<4x2x3x1x1xi32> {
// CHECK-NOT: tf.ReverseSequence
%0 = "tf.ReverseSequence"(%arg0, %arg1) {batch_dim = 2 : i64, seq_dim = 0 : i64}: (tensor<4x2x3x1x1xi32>, tensor<3xi32>) -> tensor<4x2x3x1x1xi32>
return %0 : tensor<4x2x3x1x1xi32>
}
// CHECK-LABEL: mirror_pad
func @mirror_pad(%arg0: tensor<2x3xcomplex<f64>>) -> tensor<4x7xcomplex<f64>> {
%0 = xla_hlo.constant dense<[[1, 1], [2, 2]]> : tensor<2x2xi32>
// CHECK-NOT: tf.MirrorPad
%1 = "tf.MirrorPad"(%arg0, %0) {mode = "SYMMETRIC"} : (tensor<2x3xcomplex<f64>>, tensor<2x2xi32>) -> tensor<4x7xcomplex<f64>>
return %1 : tensor<4x7xcomplex<f64>>
}
// TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
// available but doesn't support this instance.
}

View File

@ -116,11 +116,20 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::ErfcOp>(),
TypeID::get<TF::ErfOp>(),
TypeID::get<TF::Expm1Op>(),
TypeID::get<TF::FFT2DOp>(),
TypeID::get<TF::FFT3DOp>(),
TypeID::get<TF::FFTOp>(),
TypeID::get<TF::FloorDivOp>(),
TypeID::get<TF::FloorModOp>(),
TypeID::get<TF::GatherNdOp>(),
TypeID::get<TF::GreaterEqualOp>(),
TypeID::get<TF::GreaterOp>(),
TypeID::get<TF::IFFT2DOp>(),
TypeID::get<TF::IFFT3DOp>(),
TypeID::get<TF::IFFTOp>(),
TypeID::get<TF::IRFFT2DOp>(),
TypeID::get<TF::IRFFT3DOp>(),
TypeID::get<TF::IRFFTOp>(),
TypeID::get<TF::InvertOp>(),
TypeID::get<TF::InvOp>(),
TypeID::get<TF::LeakyReluGradOp>(),
@ -134,16 +143,20 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::LogicalOrOp>(),
TypeID::get<TF::LogOp>(),
TypeID::get<TF::MatMulOp>(),
TypeID::get<TF::MirrorPadOp>(),
TypeID::get<TF::MulOp>(),
TypeID::get<TF::NegOp>(),
TypeID::get<TF::NotEqualOp>(),
TypeID::get<TF::PadOp>(),
TypeID::get<TF::PlaceholderWithDefaultOp>(),
TypeID::get<TF::PowOp>(),
TypeID::get<TF::RFFT2DOp>(),
TypeID::get<TF::RFFT3DOp>(),
TypeID::get<TF::RealDivOp>(),
TypeID::get<TF::ReciprocalOp>(),
TypeID::get<TF::ReciprocalGradOp>(),
TypeID::get<TF::Relu6GradOp>(),
TypeID::get<TF::ReverseSequenceOp>(),
TypeID::get<TF::RightShiftOp>(),
TypeID::get<TF::RintOp>(),
TypeID::get<TF::RoundOp>(),

View File

@ -692,6 +692,7 @@ tf_xla_py_test(
name = "fft_test",
size = "medium",
srcs = ["fft_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
shard_count = 6,
tags = [
@ -1129,6 +1130,7 @@ tf_xla_py_test(
name = "reverse_sequence_op_test",
size = "medium",
srcs = ["reverse_sequence_op_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip

View File

@ -1225,8 +1225,6 @@ class BinaryOpsTest(xla_test.XLATestCase):
[7, 7, 7, 7, 7, 7]],
dtype=dtype))
@test_util.disable_mlir_bridge(
"Requires concatenate op support in MlirHloBuilder")
def testSymmetricMirrorPad(self):
mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "SYMMETRIC")
for dtype in self.numeric_types:
@ -1258,8 +1256,6 @@ class BinaryOpsTest(xla_test.XLATestCase):
np.array([[0, 0], [0, 0]], dtype=np.int32),
expected=np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype))
@test_util.disable_mlir_bridge(
"Requires concatenate op support in MlirHloBuilder")
def testReflectMirrorPad(self):
mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "REFLECT")
for dtype in self.numeric_types:

View File

@ -1323,20 +1323,26 @@ StatusOr<XlaOp> XlaBuilder::ConvGeneralDilatedInternal(
XlaOp XlaBuilder::Fft(XlaOp operand, const FftType fft_type,
const absl::Span<const int64> fft_length) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferFftShape(
*operand_shape, fft_type, fft_length));
*instr.mutable_shape() = shape.ToProto();
instr.set_fft_type(fft_type);
for (int64 i : fft_length) {
instr.add_fft_length(i);
}
return AddInstruction(std::move(instr), HloOpcode::kFft, {operand});
return FftInternal(shape, operand, fft_type, fft_length);
});
}
StatusOr<XlaOp> XlaBuilder::FftInternal(
const Shape& shape, XlaOp operand, const FftType fft_type,
const absl::Span<const int64> fft_length) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
instr.set_fft_type(fft_type);
for (int64 i : fft_length) {
instr.add_fft_length(i);
}
return AddInstruction(std::move(instr), HloOpcode::kFft, {operand});
}
XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
@ -1664,18 +1670,23 @@ StatusOr<XlaOp> XlaBuilder::TransposeInternal(
XlaOp XlaBuilder::Rev(XlaOp operand, absl::Span<const int64> dimensions) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReverseShape(
*operand_shape, dimensions));
*instr.mutable_shape() = shape.ToProto();
for (int64 dim : dimensions) {
instr.add_dimensions(dim);
}
return AddInstruction(std::move(instr), HloOpcode::kReverse, {operand});
return RevInternal(shape, operand, dimensions);
});
}
StatusOr<XlaOp> XlaBuilder::RevInternal(const Shape& shape, XlaOp operand,
absl::Span<const int64> dimensions) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
for (int64 dim : dimensions) {
instr.add_dimensions(dim);
}
return AddInstruction(std::move(instr), HloOpcode::kReverse, {operand});
}
XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands,
const XlaComputation& comparator, int64 dimension,
bool is_stable) {

View File

@ -502,6 +502,9 @@ class XlaBuilder {
XlaOp Fft(XlaOp operand, FftType fft_type,
absl::Span<const int64> fft_length);
virtual StatusOr<XlaOp> FftInternal(const Shape& shape, XlaOp operand,
FftType fft_type,
absl::Span<const int64> fft_length);
XlaOp Infeed(const Shape& shape, const string& config = "");
XlaOp InfeedWithToken(XlaOp token, const Shape& shape, const string& config);
@ -594,7 +597,7 @@ class XlaBuilder {
absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
XlaOp init_value, const XlaComputation& scatter);
XlaOp Iota(const Shape& shape, int64 iota_dimension);
virtual XlaOp Iota(const Shape& shape, int64 iota_dimension);
XlaOp Iota(PrimitiveType type, int64 size);
@ -607,6 +610,8 @@ class XlaBuilder {
const Shape& shape, XlaOp operand, absl::Span<const int64> permutation);
XlaOp Rev(XlaOp operand, absl::Span<const int64> dimensions);
virtual StatusOr<XlaOp> RevInternal(const Shape& shape, XlaOp operand,
absl::Span<const int64> dimensions);
XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
int64 dimension = -1, bool is_stable = false);