Create Iota, Rev and Fft HLO ops with MlirHloBuilder
Whitelist some of the ops enabled through this change and enable the corresponding compiler tests.

PiperOrigin-RevId: 316209927
Change-Id: I4f12a197425d3b1766a12c29d06e64f78caec307
parent fe34364219
commit fd934e4895
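Note on the pattern used throughout this change: each public XlaBuilder method (Fft, Iota, Rev) keeps the shared shape inference and error handling, and delegates instruction construction to a new virtual *Internal hook, which MlirHloBuilder overrides to emit xla_hlo MLIR ops instead of HLO protos. A minimal standalone sketch of that shape, using hypothetical stand-in types (Shape, Op, the builders below), not the real XLA classes:

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-ins for illustration only; the real types are XLA's
// Shape, XlaOp, StatusOr, etc.
struct Shape { std::vector<int64_t> dims; };
struct Op { int id; };

class Builder {
 public:
  // Public entry point: runs the shared shape inference once, then hands the
  // inferred shape to the virtual hook. This mirrors XlaBuilder::Rev calling
  // RevInternal in the diff below.
  Op Rev(Op operand, const std::vector<int64_t>& dimensions) {
    Shape inferred = InferReverseShape(operand, dimensions);
    return RevInternal(inferred, operand, dimensions);
  }

 protected:
  // Default lowering: build the HLO-proto-style instruction.
  virtual Op RevInternal(const Shape& shape, Op operand,
                         const std::vector<int64_t>& dimensions) {
    std::cout << "proto path: kReverse, rank " << shape.dims.size() << ", "
              << dimensions.size() << " dims reversed\n";
    return Op{operand.id + 1};
  }

  // Toy inference: reverse preserves the operand shape.
  Shape InferReverseShape(Op, const std::vector<int64_t>&) {
    return Shape{{2, 3, 4}};
  }
};

// The MLIR-emitting subclass overrides only the hook; shape inference and
// error handling stay shared with the base class.
class MlirBuilder : public Builder {
 protected:
  Op RevInternal(const Shape& shape, Op operand,
                 const std::vector<int64_t>& dimensions) override {
    std::cout << "mlir path: xla_hlo.reverse, rank " << shape.dims.size()
              << ", " << dimensions.size() << " dims reversed\n";
    return Op{operand.id + 1};
  }
};

int main() {
  MlirBuilder b;
  b.Rev(Op{0}, {0, 2});  // dispatches to the MLIR override
}

The payoff is that shape inference and validation run exactly once, in one place, no matter which backend representation the builder targets.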
@@ -120,6 +120,30 @@ StatusOr<XlaOp> MlirHloBuilder::ConvGeneralDilatedInternal(
   return MakeXlaOp(op);
 }
 
+StatusOr<XlaOp> MlirHloBuilder::FftInternal(
+    const Shape& shape, XlaOp operand, FftType fft_type,
+    absl::Span<const int64> fft_length) {
+  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
+                                         shape, builder_));
+  auto op = builder_.create<mlir::xla_hlo::FftOp>(
+      loc_, ty, GetValue(operand),
+      builder_.getStringAttr(FftType_Name(fft_type)),
+      GetI64ElementsAttr(fft_length, &builder_));
+  return MakeXlaOp(op);
+}
+
+XlaOp MlirHloBuilder::Iota(const Shape& shape, int64 iota_dimension) {
+  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    TF_ASSIGN_OR_RETURN(
+        mlir::Type ty,
+        ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
+    auto op = builder_.create<mlir::xla_hlo::IotaOp>(
+        loc_, ty,
+        builder_.getIntegerAttr(builder_.getI64Type(), iota_dimension));
+    return MakeXlaOp(op);
+  });
+}
+
 StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
     const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) {
   TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
@@ -129,6 +153,15 @@ StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
   return MakeXlaOp(op);
 }
 
+StatusOr<XlaOp> MlirHloBuilder::RevInternal(
+    const Shape& shape, XlaOp operand, absl::Span<const int64> dimensions) {
+  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
+                                         shape, builder_));
+  auto op = builder_.create<mlir::xla_hlo::ReverseOp>(
+      loc_, ty, GetValue(operand), GetI64ElementsAttr(dimensions, &builder_));
+  return MakeXlaOp(op);
+}
+
 StatusOr<XlaOp> MlirHloBuilder::GatherInternal(
     const Shape& shape, XlaOp input, XlaOp start_indices,
     const GatherDimensionNumbers& dimension_numbers,
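The fft_length and dimensions arguments above are attached to the new ops as dense i64 element attributes. A sketch of what a helper like GetI64ElementsAttr plausibly does with the core MLIR builder APIs; the real helper lives elsewhere in TensorFlow, and the name MakeI64ElementsAttr and the header paths here are assumptions (header names in particular vary across MLIR versions):

#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"

// Pack an int64 span into a 1-D dense elements attribute, e.g. for the
// fft_length attribute of xla_hlo.fft. Sketch only.
mlir::DenseIntElementsAttr MakeI64ElementsAttr(llvm::ArrayRef<int64_t> values,
                                               mlir::Builder* builder) {
  auto ty = mlir::RankedTensorType::get(
      {static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
  return mlir::DenseIntElementsAttr::get(ty, values);
}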
@@ -120,10 +120,19 @@ class MlirHloBuilder : public XlaBuilder {
       int64 feature_group_count, int64 batch_group_count,
       const PrecisionConfig* precision_config) override;
 
+  StatusOr<XlaOp> FftInternal(const Shape& shape, XlaOp operand,
+                              FftType fft_type,
+                              absl::Span<const int64> fft_length) override;
+
+  XlaOp Iota(const Shape& shape, int64 iota_dimension) override;
+
   StatusOr<XlaOp> TransposeInternal(
       const Shape& shape, XlaOp operand,
       absl::Span<const int64> permutation) override;
 
+  StatusOr<XlaOp> RevInternal(const Shape& shape, XlaOp operand,
+                              absl::Span<const int64> dimensions) override;
+
   StatusOr<XlaOp> GatherInternal(
       const Shape& shape, XlaOp input, XlaOp start_indices,
       const GatherDimensionNumbers& dimension_numbers,
@@ -214,6 +214,28 @@ func @sparse_to_dense(%arg0: tensor<3x2xi32>, %arg1: tensor<3xf32>, %arg2: tenso
   return %0 : tensor<3x3xf32>
 }
 
+// CHECK-LABEL: fft
+func @fft(%arg0: tensor<3x5x8xcomplex<f32>>) -> tensor<3x5x8xcomplex<f32>> {
+  // CHECK: "xla_hlo.fft"(%arg0)
+  %0 = "tf.FFT"(%arg0) : (tensor<3x5x8xcomplex<f32>>) -> tensor<3x5x8xcomplex<f32>>
+  return %0 : tensor<3x5x8xcomplex<f32>>
+}
+
+// CHECK-LABEL: reverse_sequence
+func @reverse_sequence(%arg0: tensor<4x2x3x1x1xi32>, %arg1: tensor<3xi32>) -> tensor<4x2x3x1x1xi32> {
+  // CHECK-NOT: tf.ReverseSequence
+  %0 = "tf.ReverseSequence"(%arg0, %arg1) {batch_dim = 2 : i64, seq_dim = 0 : i64}: (tensor<4x2x3x1x1xi32>, tensor<3xi32>) -> tensor<4x2x3x1x1xi32>
+  return %0 : tensor<4x2x3x1x1xi32>
+}
+
+// CHECK-LABEL: mirror_pad
+func @mirror_pad(%arg0: tensor<2x3xcomplex<f64>>) -> tensor<4x7xcomplex<f64>> {
+  %0 = xla_hlo.constant dense<[[1, 1], [2, 2]]> : tensor<2x2xi32>
+  // CHECK-NOT: tf.MirrorPad
+  %1 = "tf.MirrorPad"(%arg0, %0) {mode = "SYMMETRIC"} : (tensor<2x3xcomplex<f64>>, tensor<2x2xi32>) -> tensor<4x7xcomplex<f64>>
+  return %1 : tensor<4x7xcomplex<f64>>
+}
+
 // TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
 // available but doesn't support this instance.
 }
@@ -116,11 +116,20 @@ static bool IsOpWhitelisted(Operation* op) {
     TypeID::get<TF::ErfcOp>(),
     TypeID::get<TF::ErfOp>(),
     TypeID::get<TF::Expm1Op>(),
+    TypeID::get<TF::FFT2DOp>(),
+    TypeID::get<TF::FFT3DOp>(),
+    TypeID::get<TF::FFTOp>(),
     TypeID::get<TF::FloorDivOp>(),
     TypeID::get<TF::FloorModOp>(),
     TypeID::get<TF::GatherNdOp>(),
     TypeID::get<TF::GreaterEqualOp>(),
     TypeID::get<TF::GreaterOp>(),
+    TypeID::get<TF::IFFT2DOp>(),
+    TypeID::get<TF::IFFT3DOp>(),
+    TypeID::get<TF::IFFTOp>(),
+    TypeID::get<TF::IRFFT2DOp>(),
+    TypeID::get<TF::IRFFT3DOp>(),
+    TypeID::get<TF::IRFFTOp>(),
     TypeID::get<TF::InvertOp>(),
     TypeID::get<TF::InvOp>(),
     TypeID::get<TF::LeakyReluGradOp>(),
@@ -134,16 +143,20 @@ static bool IsOpWhitelisted(Operation* op) {
     TypeID::get<TF::LogicalOrOp>(),
     TypeID::get<TF::LogOp>(),
     TypeID::get<TF::MatMulOp>(),
+    TypeID::get<TF::MirrorPadOp>(),
     TypeID::get<TF::MulOp>(),
     TypeID::get<TF::NegOp>(),
     TypeID::get<TF::NotEqualOp>(),
     TypeID::get<TF::PadOp>(),
     TypeID::get<TF::PlaceholderWithDefaultOp>(),
     TypeID::get<TF::PowOp>(),
+    TypeID::get<TF::RFFT2DOp>(),
+    TypeID::get<TF::RFFT3DOp>(),
     TypeID::get<TF::RealDivOp>(),
     TypeID::get<TF::ReciprocalOp>(),
     TypeID::get<TF::ReciprocalGradOp>(),
     TypeID::get<TF::Relu6GradOp>(),
+    TypeID::get<TF::ReverseSequenceOp>(),
     TypeID::get<TF::RightShiftOp>(),
     TypeID::get<TF::RintOp>(),
     TypeID::get<TF::RoundOp>(),
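IsOpWhitelisted gates which TF ops may be legalized through the tf2xla fallback: a set of op type identities, built once and queried per operation. A toy standalone analogue of that membership test, using std::type_index in place of mlir::TypeID (the names and types here are illustrative, not the real API):

#include <typeindex>
#include <typeinfo>
#include <unordered_set>

// Toy stand-in op classes; the real code keys on the mlir::TypeID of each
// registered TF op class.
struct FFTOp {};
struct MirrorPadOp {};
struct AddOp {};

// Membership test in the spirit of IsOpWhitelisted: a static set of type
// identities, constructed once, O(1) lookup per query.
template <typename OpT>
bool IsOpWhitelisted() {
  static const std::unordered_set<std::type_index> kWhitelist = {
      std::type_index(typeid(FFTOp)),
      std::type_index(typeid(MirrorPadOp)),
  };
  return kWhitelist.count(std::type_index(typeid(OpT))) > 0;
}

int main() {
  bool fft_ok = IsOpWhitelisted<FFTOp>();  // true: handled via the fallback
  bool add_ok = IsOpWhitelisted<AddOp>();  // false: not in the set
  return (fft_ok && !add_ok) ? 0 : 1;
}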
@@ -692,6 +692,7 @@ tf_xla_py_test(
     name = "fft_test",
     size = "medium",
     srcs = ["fft_test.py"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     shard_count = 6,
     tags = [
@@ -1129,6 +1130,7 @@ tf_xla_py_test(
     name = "reverse_sequence_op_test",
     size = "medium",
     srcs = ["reverse_sequence_op_test.py"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@@ -1225,8 +1225,6 @@ class BinaryOpsTest(xla_test.XLATestCase):
                  [7, 7, 7, 7, 7, 7]],
                 dtype=dtype))
 
-  @test_util.disable_mlir_bridge(
-      "Requires concatenate op support in MlirHloBuilder")
   def testSymmetricMirrorPad(self):
     mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "SYMMETRIC")
     for dtype in self.numeric_types:
@@ -1258,8 +1256,6 @@ class BinaryOpsTest(xla_test.XLATestCase):
           np.array([[0, 0], [0, 0]], dtype=np.int32),
           expected=np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype))
 
-  @test_util.disable_mlir_bridge(
-      "Requires concatenate op support in MlirHloBuilder")
   def testReflectMirrorPad(self):
     mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "REFLECT")
     for dtype in self.numeric_types:
@@ -1323,20 +1323,26 @@ StatusOr<XlaOp> XlaBuilder::ConvGeneralDilatedInternal(
 XlaOp XlaBuilder::Fft(XlaOp operand, const FftType fft_type,
                       const absl::Span<const int64> fft_length) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferFftShape(
                                          *operand_shape, fft_type, fft_length));
-    *instr.mutable_shape() = shape.ToProto();
-    instr.set_fft_type(fft_type);
-    for (int64 i : fft_length) {
-      instr.add_fft_length(i);
-    }
-
-    return AddInstruction(std::move(instr), HloOpcode::kFft, {operand});
+    return FftInternal(shape, operand, fft_type, fft_length);
   });
 }
 
+StatusOr<XlaOp> XlaBuilder::FftInternal(
+    const Shape& shape, XlaOp operand, const FftType fft_type,
+    const absl::Span<const int64> fft_length) {
+  HloInstructionProto instr;
+  *instr.mutable_shape() = shape.ToProto();
+  instr.set_fft_type(fft_type);
+  for (int64 i : fft_length) {
+    instr.add_fft_length(i);
+  }
+
+  return AddInstruction(std::move(instr), HloOpcode::kFft, {operand});
+}
+
 XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     HloInstructionProto instr;
@@ -1664,18 +1670,23 @@ StatusOr<XlaOp> XlaBuilder::TransposeInternal(
 
 XlaOp XlaBuilder::Rev(XlaOp operand, absl::Span<const int64> dimensions) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReverseShape(
                                          *operand_shape, dimensions));
-    *instr.mutable_shape() = shape.ToProto();
-    for (int64 dim : dimensions) {
-      instr.add_dimensions(dim);
-    }
-    return AddInstruction(std::move(instr), HloOpcode::kReverse, {operand});
+    return RevInternal(shape, operand, dimensions);
   });
 }
 
+StatusOr<XlaOp> XlaBuilder::RevInternal(const Shape& shape, XlaOp operand,
+                                        absl::Span<const int64> dimensions) {
+  HloInstructionProto instr;
+  *instr.mutable_shape() = shape.ToProto();
+  for (int64 dim : dimensions) {
+    instr.add_dimensions(dim);
+  }
+  return AddInstruction(std::move(instr), HloOpcode::kReverse, {operand});
+}
+
 XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands,
                        const XlaComputation& comparator, int64 dimension,
                        bool is_stable) {
@@ -502,6 +502,9 @@ class XlaBuilder {
 
   XlaOp Fft(XlaOp operand, FftType fft_type,
             absl::Span<const int64> fft_length);
+  virtual StatusOr<XlaOp> FftInternal(const Shape& shape, XlaOp operand,
+                                      FftType fft_type,
+                                      absl::Span<const int64> fft_length);
 
   XlaOp Infeed(const Shape& shape, const string& config = "");
   XlaOp InfeedWithToken(XlaOp token, const Shape& shape, const string& config);
@@ -594,7 +597,7 @@ class XlaBuilder {
       absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
       XlaOp init_value, const XlaComputation& scatter);
 
-  XlaOp Iota(const Shape& shape, int64 iota_dimension);
+  virtual XlaOp Iota(const Shape& shape, int64 iota_dimension);
 
   XlaOp Iota(PrimitiveType type, int64 size);
 
@@ -607,6 +610,8 @@ class XlaBuilder {
       const Shape& shape, XlaOp operand, absl::Span<const int64> permutation);
 
   XlaOp Rev(XlaOp operand, absl::Span<const int64> dimensions);
+  virtual StatusOr<XlaOp> RevInternal(const Shape& shape, XlaOp operand,
+                                      absl::Span<const int64> dimensions);
 
   XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
              int64 dimension = -1, bool is_stable = false);
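With the virtual hooks in place, client code is unchanged. For example, a sketch assuming the usual XLA client API (not part of this diff):

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Builds iota<f32>[8] and reverses it: 7, 6, ..., 0. Whether this lowers to
// an HLO proto or to xla_hlo MLIR ops now depends only on which builder
// subclass `b` points to.
xla::XlaOp BuildReversedIota(xla::XlaBuilder* b) {
  xla::XlaOp iota = xla::Iota(b, xla::ShapeUtil::MakeShape(xla::F32, {8}),
                              /*iota_dimension=*/0);
  return xla::Rev(iota, /*dimensions=*/{0});
}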