Support Float and Complex types in Iota op xla_hlo to std lowering.
PiperOrigin-RevId: 297685371 Change-Id: I7aaba02c0f484a645230cebe026fb4856c9eeda0
This commit is contained in:
parent
384f1a5507
commit
b396f0ed32
@ -183,3 +183,47 @@ func @iota.const.6() -> tensor<2x3x4xi32> {
|
||||
// CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
|
||||
return %0 : tensor<2x3x4xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @iota.const.f32
|
||||
func @iota.const.f32() -> tensor<4xf32> {
|
||||
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>
|
||||
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
|
||||
// CHECK-NEXT: return %[[CST]] : tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @iota.const.f64
|
||||
func @iota.const.f64() -> tensor<4xf64> {
|
||||
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64>
|
||||
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf64>
|
||||
// CHECK-NEXT: return %[[CST]] : tensor<4xf64>
|
||||
return %0 : tensor<4xf64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @iota.const.bf16
|
||||
func @iota.const.bf16() -> tensor<4xbf16> {
|
||||
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xbf16>
|
||||
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xbf16>
|
||||
// CHECK-NEXT: return %[[CST]] : tensor<4xbf16>
|
||||
return %0 : tensor<4xbf16>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @iota.const.complex.f32
|
||||
func @iota.const.complex.f32() -> tensor<4xcomplex<f32>> {
|
||||
// CHECK-NEXT: [[REAL:%.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>
|
||||
// CHECK-NEXT: [[IMAG:%.*]] = constant dense<0.000000e+00> : tensor<4xf32>
|
||||
// CHECK-NEXT: [[COMPLEX:%.*]] = "xla_hlo.complex"([[REAL]], [[IMAG]])
|
||||
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f32>>
|
||||
// CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex<f32>>
|
||||
return %0 : tensor<4xcomplex<f32>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @iota.const.complex.f64
|
||||
func @iota.const.complex.f64() -> tensor<4xcomplex<f64>> {
|
||||
// CHECK-NEXT: [[REAL:%.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64>
|
||||
// CHECK-NEXT: [[IMAG:%.*]] = constant dense<0.000000e+00> : tensor<4xf64>
|
||||
// CHECK-NEXT: [[COMPLEX:%.*]] = "xla_hlo.complex"([[REAL]], [[IMAG]])
|
||||
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f64>>
|
||||
// CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex<f64>>
|
||||
return %0 : tensor<4xcomplex<f64>>
|
||||
}
|
||||
|
@ -105,6 +105,10 @@ class CompareFConvert : public OpRewritePattern<xla_hlo::CompareOp> {
|
||||
}
|
||||
};
|
||||
|
||||
// Replace IotaOp with an integer constant. A ConvertOp is added to
|
||||
// convert the integer constant to iota result type. For complex types, the real
|
||||
// part is replaced with the generated constant and the imaginary part is
|
||||
// replaced with zero tensor.
|
||||
class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
@ -112,14 +116,18 @@ class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
|
||||
PatternMatchResult matchAndRewrite(xla_hlo::IotaOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto output_type = op.getType().cast<ShapedType>();
|
||||
// TODO(prakalps): Handle FP and ComplexType iota ops.
|
||||
if (!output_type.getElementType().isSignlessInteger())
|
||||
return matchFailure();
|
||||
auto output_size = output_type.getNumElements();
|
||||
auto dimension = op.iota_dimension().getSExtValue();
|
||||
auto max_dim_size = output_type.getDimSize(dimension);
|
||||
int bitwidth = output_type.getElementType().getIntOrFloatBitWidth();
|
||||
|
||||
auto element_type = output_type.getElementType();
|
||||
int bitwidth;
|
||||
|
||||
auto complex_ty = element_type.dyn_cast<ComplexType>();
|
||||
Type int_or_float_ty = element_type;
|
||||
if (complex_ty) int_or_float_ty = complex_ty.getElementType();
|
||||
|
||||
bitwidth = int_or_float_ty.getIntOrFloatBitWidth();
|
||||
llvm::SmallVector<APInt, 10> values;
|
||||
values.reserve(output_size);
|
||||
|
||||
@ -135,8 +143,33 @@ class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
|
||||
++current_value;
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<mlir::ConstantOp>(
|
||||
op, DenseIntElementsAttr::get(output_type, values));
|
||||
auto int_shape_type = RankedTensorType::get(
|
||||
output_type.getShape(),
|
||||
IntegerType::get(bitwidth, rewriter.getContext()));
|
||||
auto loc = op.getLoc();
|
||||
auto integer_const = rewriter.create<mlir::ConstantOp>(
|
||||
loc, DenseIntElementsAttr::get(int_shape_type, values));
|
||||
|
||||
auto int_or_float_shape_ty =
|
||||
RankedTensorType::get(output_type.getShape(), int_or_float_ty);
|
||||
|
||||
auto iota_const =
|
||||
rewriter.create<ConvertOp>(loc, int_or_float_shape_ty, integer_const);
|
||||
|
||||
// For int/float types we are done, replace op and return.
|
||||
if (!complex_ty) {
|
||||
rewriter.replaceOp(op, iota_const.getResult());
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
// For complex types, generate a constant tensor of zeroes for the imaginary
|
||||
// part and use iota_const for real part.
|
||||
auto zeroes = rewriter.create<mlir::ConstantOp>(
|
||||
loc, DenseIntElementsAttr::get(int_shape_type, APInt(bitwidth, 0)));
|
||||
auto imag_zeroes =
|
||||
rewriter.create<ConvertOp>(loc, int_or_float_shape_ty, zeroes);
|
||||
rewriter.replaceOpWithNewOp<xla_hlo::ComplexOp>(op, iota_const,
|
||||
imag_zeroes);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user