Support Float and Complex types in Iota op xla_hlo to std lowering.

PiperOrigin-RevId: 297685371
Change-Id: I7aaba02c0f484a645230cebe026fb4856c9eeda0
This commit is contained in:
Prakalp Srivastava 2020-02-27 14:14:20 -08:00 committed by TensorFlower Gardener
parent 384f1a5507
commit b396f0ed32
2 changed files with 83 additions and 6 deletions

View File

@ -183,3 +183,47 @@ func @iota.const.6() -> tensor<2x3x4xi32> {
// CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32> // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
return %0 : tensor<2x3x4xi32> return %0 : tensor<2x3x4xi32>
} }
// CHECK-LABEL: func @iota.const.f32
func @iota.const.f32() -> tensor<4xf32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
// CHECK-NEXT: return %[[CST]] : tensor<4xf32>
return %0 : tensor<4xf32>
}
// CHECK-LABEL: func @iota.const.f64
func @iota.const.f64() -> tensor<4xf64> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf64>
// CHECK-NEXT: return %[[CST]] : tensor<4xf64>
return %0 : tensor<4xf64>
}
// CHECK-LABEL: func @iota.const.bf16
func @iota.const.bf16() -> tensor<4xbf16> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xbf16>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xbf16>
// CHECK-NEXT: return %[[CST]] : tensor<4xbf16>
return %0 : tensor<4xbf16>
}
// CHECK-LABEL: func @iota.const.complex.f32
func @iota.const.complex.f32() -> tensor<4xcomplex<f32>> {
// CHECK-NEXT: [[REAL:%.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>
// CHECK-NEXT: [[IMAG:%.*]] = constant dense<0.000000e+00> : tensor<4xf32>
// CHECK-NEXT: [[COMPLEX:%.*]] = "xla_hlo.complex"([[REAL]], [[IMAG]])
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f32>>
// CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex<f32>>
return %0 : tensor<4xcomplex<f32>>
}
// CHECK-LABEL: func @iota.const.complex.f64
func @iota.const.complex.f64() -> tensor<4xcomplex<f64>> {
// CHECK-NEXT: [[REAL:%.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64>
// CHECK-NEXT: [[IMAG:%.*]] = constant dense<0.000000e+00> : tensor<4xf64>
// CHECK-NEXT: [[COMPLEX:%.*]] = "xla_hlo.complex"([[REAL]], [[IMAG]])
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f64>>
// CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex<f64>>
return %0 : tensor<4xcomplex<f64>>
}

View File

@ -105,6 +105,10 @@ class CompareFConvert : public OpRewritePattern<xla_hlo::CompareOp> {
} }
}; };
// Replace IotaOp with an integer constant. A ConvertOp is added to
// convert the integer constant to iota result type. For complex types, the real
// part is replaced with the generated constant and the imaginary part is
// replaced with zero tensor.
class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> { class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
public: public:
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
@ -112,14 +116,18 @@ class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
PatternMatchResult matchAndRewrite(xla_hlo::IotaOp op, PatternMatchResult matchAndRewrite(xla_hlo::IotaOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
auto output_type = op.getType().cast<ShapedType>(); auto output_type = op.getType().cast<ShapedType>();
// TODO(prakalps): Handle FP and ComplexType iota ops.
if (!output_type.getElementType().isSignlessInteger())
return matchFailure();
auto output_size = output_type.getNumElements(); auto output_size = output_type.getNumElements();
auto dimension = op.iota_dimension().getSExtValue(); auto dimension = op.iota_dimension().getSExtValue();
auto max_dim_size = output_type.getDimSize(dimension); auto max_dim_size = output_type.getDimSize(dimension);
int bitwidth = output_type.getElementType().getIntOrFloatBitWidth();
auto element_type = output_type.getElementType();
int bitwidth;
auto complex_ty = element_type.dyn_cast<ComplexType>();
Type int_or_float_ty = element_type;
if (complex_ty) int_or_float_ty = complex_ty.getElementType();
bitwidth = int_or_float_ty.getIntOrFloatBitWidth();
llvm::SmallVector<APInt, 10> values; llvm::SmallVector<APInt, 10> values;
values.reserve(output_size); values.reserve(output_size);
@ -135,8 +143,33 @@ class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
++current_value; ++current_value;
} }
rewriter.replaceOpWithNewOp<mlir::ConstantOp>( auto int_shape_type = RankedTensorType::get(
op, DenseIntElementsAttr::get(output_type, values)); output_type.getShape(),
IntegerType::get(bitwidth, rewriter.getContext()));
auto loc = op.getLoc();
auto integer_const = rewriter.create<mlir::ConstantOp>(
loc, DenseIntElementsAttr::get(int_shape_type, values));
auto int_or_float_shape_ty =
RankedTensorType::get(output_type.getShape(), int_or_float_ty);
auto iota_const =
rewriter.create<ConvertOp>(loc, int_or_float_shape_ty, integer_const);
// For int/float types we are done, replace op and return.
if (!complex_ty) {
rewriter.replaceOp(op, iota_const.getResult());
return matchSuccess();
}
// For complex types, generate a constant tensor of zeroes for the imaginary
// part and use iota_const for real part.
auto zeroes = rewriter.create<mlir::ConstantOp>(
loc, DenseIntElementsAttr::get(int_shape_type, APInt(bitwidth, 0)));
auto imag_zeroes =
rewriter.create<ConvertOp>(loc, int_or_float_shape_ty, zeroes);
rewriter.replaceOpWithNewOp<xla_hlo::ComplexOp>(op, iota_const,
imag_zeroes);
return matchSuccess(); return matchSuccess();
} }
}; };