From b396f0ed32bf17160cde2602457ff14a5cbbc7d3 Mon Sep 17 00:00:00 2001 From: Prakalp Srivastava Date: Thu, 27 Feb 2020 14:14:20 -0800 Subject: [PATCH] Support Float and Complex types in Iota op xla_hlo to std lowering. PiperOrigin-RevId: 297685371 Change-Id: I7aaba02c0f484a645230cebe026fb4856c9eeda0 --- .../mlir/xla/tests/legalize-to-std.mlir | 44 ++++++++++++++++++ .../xla/transforms/legalize_to_standard.cc | 45 ++++++++++++++++--- 2 files changed, 83 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir index f56174ae075..da6adf8cbe1 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir @@ -183,3 +183,47 @@ func @iota.const.6() -> tensor<2x3x4xi32> { // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32> return %0 : tensor<2x3x4xi32> } + +// CHECK-LABEL: func @iota.const.f32 +func @iota.const.f32() -> tensor<4xf32> { + // CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32> + %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32> + // CHECK-NEXT: return %[[CST]] : tensor<4xf32> + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @iota.const.f64 +func @iota.const.f64() -> tensor<4xf64> { + // CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64> + %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf64> + // CHECK-NEXT: return %[[CST]] : tensor<4xf64> + return %0 : tensor<4xf64> +} + +// CHECK-LABEL: func @iota.const.bf16 +func @iota.const.bf16() -> tensor<4xbf16> { + // CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xbf16> + %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xbf16> + // CHECK-NEXT: return %[[CST]] : tensor<4xbf16> + return %0 : tensor<4xbf16> +} + +// CHECK-LABEL: func @iota.const.complex.f32 +func @iota.const.complex.f32() -> tensor<4xcomplex> { + // CHECK-NEXT: [[REAL:%.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32> + // CHECK-NEXT: [[IMAG:%.*]] = constant dense<0.000000e+00> : tensor<4xf32> + // CHECK-NEXT: [[COMPLEX:%.*]] = "xla_hlo.complex"([[REAL]], [[IMAG]]) + %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex> + // CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex> + return %0 : tensor<4xcomplex> +} + +// CHECK-LABEL: func @iota.const.complex.f64 +func @iota.const.complex.f64() -> tensor<4xcomplex> { + // CHECK-NEXT: [[REAL:%.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64> + // CHECK-NEXT: [[IMAG:%.*]] = constant dense<0.000000e+00> : tensor<4xf64> + // CHECK-NEXT: [[COMPLEX:%.*]] = "xla_hlo.complex"([[REAL]], [[IMAG]]) + %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex> + // CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex> + return %0 : tensor<4xcomplex> +} diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc index 3c15f0be7e8..1c0f3d8f242 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc @@ -105,6 +105,10 @@ class CompareFConvert : public OpRewritePattern { } }; +// Replace IotaOp with an integer constant. A ConvertOp is added to +// convert the integer constant to iota result type. For complex types, the real +// part is replaced with the generated constant and the imaginary part is +// replaced with zero tensor. class ConvertIotaOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -112,14 +116,18 @@ class ConvertIotaOp : public OpRewritePattern { PatternMatchResult matchAndRewrite(xla_hlo::IotaOp op, PatternRewriter &rewriter) const override { auto output_type = op.getType().cast(); - // TODO(prakalps): Handle FP and ComplexType iota ops. - if (!output_type.getElementType().isSignlessInteger()) - return matchFailure(); auto output_size = output_type.getNumElements(); auto dimension = op.iota_dimension().getSExtValue(); auto max_dim_size = output_type.getDimSize(dimension); - int bitwidth = output_type.getElementType().getIntOrFloatBitWidth(); + auto element_type = output_type.getElementType(); + int bitwidth; + + auto complex_ty = element_type.dyn_cast(); + Type int_or_float_ty = element_type; + if (complex_ty) int_or_float_ty = complex_ty.getElementType(); + + bitwidth = int_or_float_ty.getIntOrFloatBitWidth(); llvm::SmallVector values; values.reserve(output_size); @@ -135,8 +143,33 @@ class ConvertIotaOp : public OpRewritePattern { ++current_value; } - rewriter.replaceOpWithNewOp( - op, DenseIntElementsAttr::get(output_type, values)); + auto int_shape_type = RankedTensorType::get( + output_type.getShape(), + IntegerType::get(bitwidth, rewriter.getContext())); + auto loc = op.getLoc(); + auto integer_const = rewriter.create( + loc, DenseIntElementsAttr::get(int_shape_type, values)); + + auto int_or_float_shape_ty = + RankedTensorType::get(output_type.getShape(), int_or_float_ty); + + auto iota_const = + rewriter.create(loc, int_or_float_shape_ty, integer_const); + + // For int/float types we are done, replace op and return. + if (!complex_ty) { + rewriter.replaceOp(op, iota_const.getResult()); + return matchSuccess(); + } + + // For complex types, generate a constant tensor of zeroes for the imaginary + // part and use iota_const for real part. + auto zeroes = rewriter.create( + loc, DenseIntElementsAttr::get(int_shape_type, APInt(bitwidth, 0))); + auto imag_zeroes = + rewriter.create(loc, int_or_float_shape_ty, zeroes); + rewriter.replaceOpWithNewOp(op, iota_const, + imag_zeroes); return matchSuccess(); } };