diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 35c159da0b9..beb50599154 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -1545,61 +1545,33 @@ func @rfft_1D(%arg0: tensor<8xf32>) -> tensor<8xcomplex<f32>> {
 
 // CHECK-LABEL: func @shape_1D
 func @shape_1D(%arg0: tensor<?xf32>) -> tensor<1xi32> {
   // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
-  // CHECK: %[[C0:.*]] = shape.const_size 0
-  // CHECK-DAG: [[EXTENT:%.+]] = shape.get_extent [[SHAPE]], %[[C0]]
-  // CHECK-DAG: [[TO_INDEX:%.+]] = shape.size_to_index [[EXTENT]]
-  // CHECK-DAG: [[CAST:%.+]] = index_cast [[TO_INDEX]]
-  // CHECK-DAG: [[TENSOR:%.+]] = tensor_from_elements([[CAST]])
-  // CHECK-DAG: [[RESHAPE:%.+]] = "xla_hlo.reshape"([[TENSOR]])
-  // CHECK-DAG: [[CONCAT:%.+]] = "xla_hlo.concatenate"([[RESHAPE]]) {dimension = 0 : i64}
+  // CHECK: [[TENSOR:%.+]] = "shape.to_extent_tensor"([[SHAPE]])
+  // CHECK: [[CAST:%.+]] = index_cast [[TENSOR]]
   %0 = "tf.Shape"(%arg0) : (tensor<?xf32>) -> tensor<1xi32>
 
-  // CHECK: return [[CONCAT]]
+  // CHECK: return [[CAST]]
   return %0 : tensor<1xi32>
 }
 
 // CHECK-LABEL: func @shape_2D
 func @shape_2D(%arg0: tensor<?x?xf32>) -> tensor<2xi32> {
   // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
-  // CHECK: %[[C0:.*]] = shape.const_size 0
-  // CHECK: [[EXTENT0:%.+]] = shape.get_extent [[SHAPE]], %[[C0]]
-  // CHECK-DAG: [[TO_INDEX0:%.+]] = shape.size_to_index [[EXTENT0]]
-  // CHECK-DAG: [[CAST0:%.+]] = index_cast [[TO_INDEX0]]
-  // CHECK-DAG: [[TENSOR0:%.+]] = tensor_from_elements([[CAST0]])
-  // CHECK-DAG: [[RESHAPE0:%.+]] = "xla_hlo.reshape"([[TENSOR0]])
-  // CHECK: %[[C1:.*]] = shape.const_size 1
-  // CHECK-DAG: [[EXTENT1:%.+]] = shape.get_extent [[SHAPE]], %[[C1]]
-  // CHECK-DAG: [[TO_INDEX1:%.+]] = shape.size_to_index [[EXTENT1]]
-  // CHECK-DAG: [[CAST1:%.+]] = index_cast [[TO_INDEX1]]
-  // CHECK-DAG: [[TENSOR1:%.+]] = tensor_from_elements([[CAST1]])
-  // CHECK-DAG: [[RESHAPE1:%.+]] = "xla_hlo.reshape"([[TENSOR1]])
-  // CHECK-DAG: [[CONCAT:%.+]] = "xla_hlo.concatenate"([[RESHAPE0]], [[RESHAPE1]]) {dimension = 0 : i64}
+  // CHECK: [[TENSOR:%.+]] = "shape.to_extent_tensor"([[SHAPE]])
+  // CHECK: [[CAST:%.+]] = index_cast [[TENSOR]]
   %0 = "tf.Shape"(%arg0) : (tensor<?x?xf32>) -> tensor<2xi32>
 
-  // CHECK: return [[CONCAT]]
-  return %0 : tensor<2xi32>
-}
-
-// CHECK-LABEL: func @shape_with_const
-func @shape_with_const(%arg0: tensor<?x3xf32>) -> tensor<2xi32> {
-  // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
-  // CHECK: %[[C0:.*]] = shape.const_size 0
-  // CHECK-DAG: [[EXTENT:%.+]] = shape.get_extent [[SHAPE]], %[[C0]]
-  // CHECK-DAG: [[TO_INDEX:%.+]] = shape.size_to_index [[EXTENT]]
-  // CHECK-DAG: [[CAST:%.+]] = index_cast [[TO_INDEX]]
-  // CHECK-DAG: [[TENSOR:%.+]] = tensor_from_elements([[CAST]])
-  // CHECK-DAG: [[RESHAPE:%.+]] = "xla_hlo.reshape"([[TENSOR]])
-  // CHECK-DAG: [[CONST:%.+]] = xla_hlo.constant dense<3>
-  // CHECK-DAG: [[CONCAT:%.+]] = "xla_hlo.concatenate"([[RESHAPE]], [[CONST]]) {dimension = 0 : i64}
-  %0 = "tf.Shape"(%arg0) : (tensor<?x3xf32>) -> tensor<2xi32>
-
-  // CHECK: return [[CONCAT]]
+  // CHECK: return [[CAST]]
   return %0 : tensor<2xi32>
 }
 
 // CHECK-LABEL: func @shape_rankless
 func @shape_rankless(%arg0: tensor<*xf32>) -> tensor<?xi32> {
+  // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
+  // CHECK: [[TENSOR:%.+]] = "shape.to_extent_tensor"([[SHAPE]])
+  // CHECK: [[CAST:%.+]] = index_cast [[TENSOR]]
   %0 = "tf.Shape"(%arg0) : (tensor<*xf32>) -> tensor<?xi32>
+
+  // CHECK: return [[CAST]]
   return %0 : tensor<?xi32>
 }
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 36243a1bc57..200e04d5413 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -4610,45 +4610,16 @@ class ConvertShapeOp : public OpRewritePattern<TF::ShapeOp> {
   LogicalResult matchAndRewrite(TF::ShapeOp op,
                                 PatternRewriter &rewriter) const override {
     Value input = op.input();
 
-    auto input_ty = input.getType().dyn_cast<RankedTensorType>();
-    // If the shape is static it can be canonicalized.
-    if (!input_ty || input_ty.hasStaticShape()) {
-      return failure();
-    }
-    auto result_ty = op.getResult().getType().cast<RankedTensorType>();
-    auto element_ty = result_ty.getElementType();
-
-    int64_t rank = input_ty.getRank();
     auto shape_op = rewriter.create<shape::ShapeOfOp>(op.getLoc(), input);
+    auto result_ty = op.getResult().getType().cast<RankedTensorType>();
 
-    auto index_ty = RankedTensorType::get({1}, element_ty);
-    llvm::SmallVector<Value, 4> dim_values;
-    for (int64_t i = 0; i < rank; ++i) {
-      if (!input_ty.isDynamicDim(i)) {
-        auto dim_attr = DenseElementsAttr::get(
-            index_ty,
-            rewriter.getIntegerAttr(element_ty, input_ty.getDimSize(i)));
-        auto index = rewriter.create<ConstOp>(op.getLoc(), dim_attr);
-        dim_values.push_back(index);
-        continue;
-      }
+    auto index_tensor =
+        RankedTensorType::get(result_ty.getShape(), rewriter.getIndexType());
+    auto extent_tensor = rewriter.create<shape::ToExtentTensorOp>(
+        op.getLoc(), index_tensor, shape_op);
 
-      auto extent_op =
-          rewriter.create<shape::GetExtentOp>(op.getLoc(), shape_op, i);
-      auto index_op = rewriter.create<shape::SizeToIndexOp>(
-          op.getLoc(), rewriter.getIndexType(), extent_op);
-      auto int_op =
-          rewriter.create<IndexCastOp>(op.getLoc(), element_ty, index_op);
-      auto from_tensor = rewriter.create<TensorFromElementsOp>(
-          op.getLoc(), int_op.getResult());
-      auto reshape_op =
-          rewriter.create<ReshapeOp>(op.getLoc(), index_ty, from_tensor);
-      dim_values.push_back(reshape_op);
-    }
-
-    rewriter.replaceOpWithNewOp<ConcatenateOp>(op, result_ty, dim_values,
-                                               rewriter.getI64IntegerAttr(0));
+    rewriter.replaceOpWithNewOp<IndexCastOp>(op, result_ty, extent_tensor);
     return success();
   }
 };
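Note (not part of the diff): to make the new CHECK lines concrete, here is a hand-written sketch of roughly what the rewritten pattern should emit for the rank-2 dynamic case, written in MLIR's generic op form. The SSA names and the generic syntax are illustrative assumptions (the pass output uses each op's custom assembly where one exists); only the op names and types follow the updated test and the new ConvertShapeOp code above.

func @shape_2D(%arg0: tensor<?x?xf32>) -> tensor<2xi32> {
  // shape.shape_of yields a !shape.shape; shape.to_extent_tensor materializes
  // it as a rank-1 tensor of index extents, and index_cast converts to the
  // i32 element type expected by the tf.Shape result.
  %0 = "shape.shape_of"(%arg0) : (tensor<?x?xf32>) -> !shape.shape
  %1 = "shape.to_extent_tensor"(%0) : (!shape.shape) -> tensor<2xindex>
  %2 = "std.index_cast"(%1) : (tensor<2xindex>) -> tensor<2xi32>
  return %2 : tensor<2xi32>
}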