Support Float and Complex types in Iota op xla_hlo to std lowering.
PiperOrigin-RevId: 297685371 Change-Id: I7aaba02c0f484a645230cebe026fb4856c9eeda0
This commit is contained in:
parent
384f1a5507
commit
b396f0ed32
@ -183,3 +183,47 @@ func @iota.const.6() -> tensor<2x3x4xi32> {
|
|||||||
// CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
|
// CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
|
||||||
return %0 : tensor<2x3x4xi32>
|
return %0 : tensor<2x3x4xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @iota.const.f32
|
||||||
|
func @iota.const.f32() -> tensor<4xf32> {
|
||||||
|
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>
|
||||||
|
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
|
||||||
|
// CHECK-NEXT: return %[[CST]] : tensor<4xf32>
|
||||||
|
return %0 : tensor<4xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @iota.const.f64
|
||||||
|
func @iota.const.f64() -> tensor<4xf64> {
|
||||||
|
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64>
|
||||||
|
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf64>
|
||||||
|
// CHECK-NEXT: return %[[CST]] : tensor<4xf64>
|
||||||
|
return %0 : tensor<4xf64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @iota.const.bf16
|
||||||
|
func @iota.const.bf16() -> tensor<4xbf16> {
|
||||||
|
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xbf16>
|
||||||
|
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xbf16>
|
||||||
|
// CHECK-NEXT: return %[[CST]] : tensor<4xbf16>
|
||||||
|
return %0 : tensor<4xbf16>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @iota.const.complex.f32
|
||||||
|
func @iota.const.complex.f32() -> tensor<4xcomplex<f32>> {
|
||||||
|
// CHECK-NEXT: [[REAL:%.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>
|
||||||
|
// CHECK-NEXT: [[IMAG:%.*]] = constant dense<0.000000e+00> : tensor<4xf32>
|
||||||
|
// CHECK-NEXT: [[COMPLEX:%.*]] = "xla_hlo.complex"([[REAL]], [[IMAG]])
|
||||||
|
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f32>>
|
||||||
|
// CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex<f32>>
|
||||||
|
return %0 : tensor<4xcomplex<f32>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @iota.const.complex.f64
|
||||||
|
func @iota.const.complex.f64() -> tensor<4xcomplex<f64>> {
|
||||||
|
// CHECK-NEXT: [[REAL:%.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64>
|
||||||
|
// CHECK-NEXT: [[IMAG:%.*]] = constant dense<0.000000e+00> : tensor<4xf64>
|
||||||
|
// CHECK-NEXT: [[COMPLEX:%.*]] = "xla_hlo.complex"([[REAL]], [[IMAG]])
|
||||||
|
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f64>>
|
||||||
|
// CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex<f64>>
|
||||||
|
return %0 : tensor<4xcomplex<f64>>
|
||||||
|
}
|
||||||
|
@ -105,6 +105,10 @@ class CompareFConvert : public OpRewritePattern<xla_hlo::CompareOp> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Replace IotaOp with an integer constant. A ConvertOp is added to
|
||||||
|
// convert the integer constant to iota result type. For complex types, the real
|
||||||
|
// part is replaced with the generated constant and the imaginary part is
|
||||||
|
// replaced with zero tensor.
|
||||||
class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
|
class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
|
||||||
public:
|
public:
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
@ -112,14 +116,18 @@ class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
|
|||||||
PatternMatchResult matchAndRewrite(xla_hlo::IotaOp op,
|
PatternMatchResult matchAndRewrite(xla_hlo::IotaOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto output_type = op.getType().cast<ShapedType>();
|
auto output_type = op.getType().cast<ShapedType>();
|
||||||
// TODO(prakalps): Handle FP and ComplexType iota ops.
|
|
||||||
if (!output_type.getElementType().isSignlessInteger())
|
|
||||||
return matchFailure();
|
|
||||||
auto output_size = output_type.getNumElements();
|
auto output_size = output_type.getNumElements();
|
||||||
auto dimension = op.iota_dimension().getSExtValue();
|
auto dimension = op.iota_dimension().getSExtValue();
|
||||||
auto max_dim_size = output_type.getDimSize(dimension);
|
auto max_dim_size = output_type.getDimSize(dimension);
|
||||||
int bitwidth = output_type.getElementType().getIntOrFloatBitWidth();
|
|
||||||
|
|
||||||
|
auto element_type = output_type.getElementType();
|
||||||
|
int bitwidth;
|
||||||
|
|
||||||
|
auto complex_ty = element_type.dyn_cast<ComplexType>();
|
||||||
|
Type int_or_float_ty = element_type;
|
||||||
|
if (complex_ty) int_or_float_ty = complex_ty.getElementType();
|
||||||
|
|
||||||
|
bitwidth = int_or_float_ty.getIntOrFloatBitWidth();
|
||||||
llvm::SmallVector<APInt, 10> values;
|
llvm::SmallVector<APInt, 10> values;
|
||||||
values.reserve(output_size);
|
values.reserve(output_size);
|
||||||
|
|
||||||
@ -135,8 +143,33 @@ class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
|
|||||||
++current_value;
|
++current_value;
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<mlir::ConstantOp>(
|
auto int_shape_type = RankedTensorType::get(
|
||||||
op, DenseIntElementsAttr::get(output_type, values));
|
output_type.getShape(),
|
||||||
|
IntegerType::get(bitwidth, rewriter.getContext()));
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
auto integer_const = rewriter.create<mlir::ConstantOp>(
|
||||||
|
loc, DenseIntElementsAttr::get(int_shape_type, values));
|
||||||
|
|
||||||
|
auto int_or_float_shape_ty =
|
||||||
|
RankedTensorType::get(output_type.getShape(), int_or_float_ty);
|
||||||
|
|
||||||
|
auto iota_const =
|
||||||
|
rewriter.create<ConvertOp>(loc, int_or_float_shape_ty, integer_const);
|
||||||
|
|
||||||
|
// For int/float types we are done, replace op and return.
|
||||||
|
if (!complex_ty) {
|
||||||
|
rewriter.replaceOp(op, iota_const.getResult());
|
||||||
|
return matchSuccess();
|
||||||
|
}
|
||||||
|
|
||||||
|
// For complex types, generate a constant tensor of zeroes for the imaginary
|
||||||
|
// part and use iota_const for real part.
|
||||||
|
auto zeroes = rewriter.create<mlir::ConstantOp>(
|
||||||
|
loc, DenseIntElementsAttr::get(int_shape_type, APInt(bitwidth, 0)));
|
||||||
|
auto imag_zeroes =
|
||||||
|
rewriter.create<ConvertOp>(loc, int_or_float_shape_ty, zeroes);
|
||||||
|
rewriter.replaceOpWithNewOp<xla_hlo::ComplexOp>(op, iota_const,
|
||||||
|
imag_zeroes);
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user